# Tensor Checks (Host-Side Auto-Validation)

This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass a `torch.Tensor` or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more, so you don't need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind.

## Why Host-Side Checks

- ABI stability: the entry point is based on TVM FFI + DLPack, consistently accepting tensors and scalars.
- Lower overhead: shifting checks from Python into C reduces interpreter/property-access costs; the call overhead is lower than pybind-based approaches.
- Focused error reporting: assertions are raised close to the call site with precise "which field failed" messages.

## How To Inspect Host Source

You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging:

```python
print(matmul_relu_kernel.get_host_source())
```

---

## What The Host Checks

### 1) Argument count and pointer kind

- `num_args` must match the number of formal parameters; otherwise the kernel returns `-1` with an error message.
- Each argument's FFI type must be a pointer kind (for a DLTensor/handle) or a valid scalar type; otherwise you'll see errors like `Expect arg[i] to be pointer` or a scalar type error.

### 2) Tensor checks (per tensor, after the nullability decision)

- Nullability
  - If the tensor is statically reachable/used by the function body, the handle must be non-NULL; otherwise: `xxx is expected to have non-NULL pointer`.
  - If an input tensor is not used by the function (statically unreachable), NULL is allowed; the other field checks run only when `handle != NULL`.
- Rank (`ndim`)
  - Runtime `ndim` must equal the compile-time rank.
- Data type (`dtype`)
  - Match the triple `(code, bits, lanes)`, with tolerance:
    - `float8_e4m3`: accept `e4m3`, `e4m3fn`, `e4m3fnuz`.
    - `float8_e5m2`: accept `e5m2`, `e5m2fnuz`.
    - `bool`: accept `int8`/`uint8` with `bits=8` (same lanes), `kDLBool` (`code=6`, `bits=1 or 8`), and any `bitwidth=1` (lanes must match).
  - For packed-bit dtypes (e.g., `Int(1)`, `Int(4)`, `UInt(4)`), strict dtype checking is skipped.
- Shape
  - Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency.
  - Linear equations among symbolic dims can be solved on the fly (when there is only one unknown at a given check point), enabling cross-tensor constraints.
- Strides
  - If `buffer_type = AutoBroadcast`: `strides == NULL` is allowed and strides are derived from `shape`. If explicit `strides` are present, they are bound to compile-time constraints and checked for equality.
  - Otherwise: check per dimension; if `strides == NULL`, derive strides from `shape` and compare (e.g., contiguous: `strides[-1] == 1`, `strides[-2] == shape[-1]`). A Python sketch of this derivation appears below, after the scalar checks.
- `byte_offset`
  - Must be 0 (non-zero raises an error) to keep addressing simple and aligned.
- Device info
  - Assert `device_type == target backend` (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend.
  - When multiple tensors participate, assert that `device_id` matches across them.
- Data pointer
  - Must be non-NULL whenever the tensor is required to be non-NULL by the nullability rule.

### 3) Scalar checks

- `T.int*` family: require an integer; error: `Expect arg[i] to be int`.
- `T.bool`: require a boolean; error: `Expect arg[i] to be boolean`.
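For intuition, here is a plain-Python illustration of the stride rules from (2): how the host derives expected strides from `shape` when `strides == NULL`, and how it compares explicit strides against the contiguous expectation. This is a sketch of the logic, not the generated C code, and the helper names are ours:

```python
def contiguous_strides(shape):
    """Derive the strides of a row-major contiguous tensor from its shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

def check_strides(shape, strides):
    """Mimic the host check: if strides is NULL (None), derive from shape; else compare."""
    expected = contiguous_strides(shape)
    actual = expected if strides is None else list(strides)
    for j, (e, a) in enumerate(zip(expected, actual)):
        assert a == e, f"strides[{j}] has an unsatisfied constraint: {a} == {e}"

check_strides((1024, 512), None)        # NULL strides: derived from shape, passes
check_strides((1024, 512), (512, 1))    # contiguous layout: passes
try:
    check_strides((1024, 512), (1, 1024))  # transposed-view layout: fails
except AssertionError as e:
    print(e)  # strides[0] has an unsatisfied constraint: 1 == 512
```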
---

## Shapes and Symbolic Equations: Linear Solving

When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime, provided there is only one unknown at each check point. Example:

```python
@T.prim_func
def main(
    A: T.Tensor((m,), dtype),
    B: T.Tensor((m + n,), dtype),
    C: T.Tensor((n * k,), dtype),
):
    ...
```

This enforces cross-tensor relationships like `len(B) == m + n` and `len(C) == n * k` at runtime. For example, with `len(A) == 4` the host binds `m = 4`; from `len(B) == 12` it solves `n = 8`; with `n` known, `len(C) == 16` yields `k = 2`. Each equation is solvable because it contains a single unknown at the point where it is checked.

---

## Nullability Rules and Examples

Which tensors may be NULL?

- Rule: if an input tensor is not used by the function under static analysis (i.e., every access is statically unreachable), it is considered nullable; otherwise it must be non-NULL.
- Examples:

1) Must be non-NULL (used)

```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    A[0] = 1
```

Passing `None` raises: `main.A_handle is expected to have non-NULL pointer`.

2) Still must be non-NULL (constant-true branch)

```python
some_cond: bool = True

@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    if some_cond:
        A[0] = 1
```

3) Nullable (constant-false branch, statically unreachable)

```python
some_cond: bool = False

@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
    if some_cond:
        A[0] = 1
```

4) Must be non-NULL (runtime condition)

```python
@T.prim_func
def main(A: T.Tensor((M, K), dtype), some_cond: T.bool):
    if some_cond:
        A[0] = 1
```

Since `some_cond` is only known at runtime, static analysis cannot prove `A` is unused; `A` is therefore non-nullable.
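To see the rule in action, you can pass `None` for the nullable variant. A minimal sketch, assuming `main` is the function from example 3 and that it is compiled via `tilelang.compile` (adjust to your build entry point):

```python
import tilelang

# Example 3 above: the guarded store is statically unreachable,
# so A is classified as nullable.
kernel = tilelang.compile(main)  # assumed entry point; adjust to your setup

kernel(None)  # accepted: the host skips the field checks for a NULL nullable tensor
```

For examples 1, 2, and 4, the same call raises `main.A_handle is expected to have non-NULL pointer`.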
---

## Device Type Codes (DLPack)

Device codes referenced in error messages: `1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU`. Kernels assert that `device_type` matches the target backend, and require `device_id` consistency across tensors.

---

## Common Error Examples (What You'll See)

- Argument count mismatch (`num_args`)
  - Trigger: missing/extra argument
  - Error: `: num_args should be N; expected: , got: N`
- Pointer-typed argument expected
  - Trigger: scalar passed where a tensor is expected
  - Error: `: Expect arg[i] to be pointer`
- Rank (`ndim`) mismatch
  - Trigger: runtime rank differs from compile-time rank
  - Error: `..ndim is expected to equal R, but got mismatched ndim`
- Dtype mismatch
  - Trigger: dtype not equal to the compiled dtype and not within the tolerance set
  - Error: `..dtype is expected to be , but got incompatible dtype`
- Shape constraint violation
  - Trigger: a dimension doesn't match a constant/symbol binding
  - Error: `Argument ..shape[i] has an unsatisfied constraint: ... == `
- Strides check failed (e.g., non-contiguous layout)
  - Trigger: transposed/sliced tensors that violate expected strides
  - Error: `Argument ..strides[j] has an unsatisfied constraint: ... == `
- Device type mismatch
  - Trigger: calling a CUDA kernel with CPU tensors, etc.
  - Error: `..device_type mismatch [expected: ()] ...`
- Device id mismatch
  - Trigger: mixing tensors from different GPUs
  - Error: `Argument ..device_id has an unsatisfied constraint: ... == ...`
- NULL data pointer
  - Trigger: a tensor required to be non-NULL has a NULL data pointer
  - Error: `. is expected to have non-NULL data pointer, but got NULL`
- Scalar type mismatch
  - Trigger: passing a float to `T.int32`, or a non-boolean to `T.bool`
  - Error: `: Expect arg[i] to be int/boolean`

---

## Troubleshooting Tips

- Print the host source: `print(fn.get_host_source())` to see the exact assertion and the expected vs. actual fields.
- Fix strides: call `.contiguous()` on non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions.
- Align devices: ensure all participating tensors share the same `device_type` and `device_id`.
- Align dtypes: use `.to()` or construct tensors with the correct dtype; pay attention to the `float8` and `bool` tolerance rules.
- Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at the check point (only one unknown at a time).

---

## FAQ

- Can I disable the checks?
  - Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and to fail early, close to the device call.
- Is the overhead noticeable?
  - The checks are lightweight (branches and field reads); the dominating cost remains the Python→C boundary. Overall this is cheaper than performing equivalent checks in Python.

---

## Reference Example (Matmul + ReLU)

```python
@T.prim_func
def matmul_relu_kernel(
    A: T.Tensor((M, K), dtype),
    B: T.Tensor((K, N), dtype),
    C: T.Tensor((M, N), dtype),
):
    # Initialize kernel context
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), dtype)
        B_shared = T.alloc_shared((block_K, block_N), dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

        T.clear(C_local)
        for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
            T.copy(A[by * block_M, ko * block_K], A_shared)
            T.copy(B[ko * block_K, bx * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        T.copy(C_local, C[by * block_M, bx * block_N])

# For debugging, print the host source
print(matmul_relu_kernel.get_host_source())
```

The host will insert all the checks described above for this example.

---

## Quick Error Reference (Short List)

- Argument count
  - Trigger: missing/extra args; Error: `num_args should be N; expected: , got: N`.
- Pointer kind
  - Trigger: scalar passed to a tensor arg; Error: `Expect arg[i] to be pointer`.
- Rank (ndim)
  - Trigger: runtime rank != compile-time rank; Error: `ndim ... expected to equal R`.
- Dtype
  - Trigger: mismatch and not tolerated; Error: `dtype ... expected to be `.
- Shape
  - Trigger: constant/symbol binding violated; Error: `shape[i] ... == `.
- Strides
  - Trigger: layout mismatch; Error: `strides[j] ... == `.
- Device type
  - Trigger: wrong backend device; Error: `device_type mismatch [expected: ...]`.
- Device id
  - Trigger: tensors on different GPUs; Error: `device_id ... == ...`.
- Data pointer
  - Trigger: required non-NULL but NULL; Error: `non-NULL data pointer`.
- Scalar types
  - Trigger: wrong scalar type; Error: `Expect arg[i] to be int/boolean`.

---

## Host Error Troubleshooting (Minimal Repros)

Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like `matmul_relu_kernel` with:

```python
# Convention:
#   A: float16 [M, K]
#   B: float16 [K, N]
#   C: float16 [M, N]
#   Target: CUDA (device_type=2)
fn = matmul_relu_kernel  # your compiled function
M = N = K = 1024
```

Adjust the dtype/device if your kernel differs.

### 0. Tip: print the host source

```python
print(fn.get_host_source())
```

### 1. num_args mismatch

```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float16)
B = torch.empty((K, N), device='cuda', dtype=torch.float16)

# Missing C
fn(A, B)
```

Expected: `: num_args should be 3; expected: , got: 3`.

Fix: pass all arguments per the signature.
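To turn a repro like this into a self-checking test, you can assert on the message. A minimal sketch; the exact exception type surfaced by the FFI layer may vary, so it catches broadly:

```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float16)
B = torch.empty((K, N), device='cuda', dtype=torch.float16)

try:
    fn(A, B)  # deliberately missing C
except Exception as e:  # exception type depends on the FFI layer
    assert "num_args" in str(e)
else:
    raise AssertionError("host check did not fire")
```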
### 2. Expect pointer (tensor) but got scalar

```python
import torch

B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)

fn(1, B, C)
```

Expected: `: Expect arg[0] to be pointer`.

Fix: pass a DLPack-compatible tensor (e.g., a `torch.Tensor`).

### 3. ndim mismatch

```python
import torch

A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16)  # rank=3
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)

fn(A, B, C)
```

Expected: `.A_handle.ndim is expected to equal 2, but got mismatched ndim`.

Fix: ensure the runtime rank equals the compiled rank.

### 4. dtype mismatch

```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float32)  # should be float16
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)

fn(A, B, C)
```

Expected: `.A_handle.dtype is expected to be float16, but got incompatible dtype`.

Fix: `A = A.to(torch.float16)`, or create the tensor with the correct dtype.

### 5. Shape constant/symbol mismatch

```python
import torch

A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16)  # K mismatched
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)

fn(A, B, C)
```

Expected: `Argument .A_handle.shape[i] has an unsatisfied constraint: ... == `.

Fix: satisfy the linear constraints and constants across tensors.

### 6. Strides check failure (non-contiguous)

```python
import torch

A = torch.empty((M, K), device='cuda', dtype=torch.float16)
A_nc = A.t()  # transpose -> non-contiguous
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)

fn(A_nc, B, C)
```

Expected: `Argument .A_handle.strides[1] has an unsatisfied constraint: ... == 1`.

Fix: pass `A_nc.contiguous()`, or align the layout expectation in the kernel.

### 7. device_type mismatch

```python
import torch

A = torch.empty((M, K), device='cpu', dtype=torch.float16)
B = torch.empty((K, N), device='cpu', dtype=torch.float16)
C = torch.empty((M, N), device='cpu', dtype=torch.float16)

fn(A, B, C)  # CUDA-targeted kernel
```

Expected: `.A_handle.device_type mismatch [expected: 2 (cuda)] ...`.

Fix: move the tensors to the CUDA device.

### 8. device_id mismatch (multi-GPU)

```python
import torch

A = torch.empty((M, K), device='cuda:0', dtype=torch.float16)
B = torch.empty((K, N), device='cuda:1', dtype=torch.float16)
C = torch.empty((M, N), device='cuda:0', dtype=torch.float16)

fn(A, B, C)
```

Expected: `Argument .B_handle.device_id has an unsatisfied constraint: ... == ...`.

Fix: place all tensors on the same GPU (e.g., `cuda:0`).

### 9. NULL data pointer (advanced)

This usually comes from hand-constructed DLTensor/NDArray objects, or from external frameworks passing unallocated/freed storage. Regular `torch.Tensor` allocations rarely hit this.

Expected: `. is expected to have non-NULL data pointer, but got NULL`.

Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles.

### 10. Scalar type mismatch (int / bool)

```python
import tilelang.language as T

@T.prim_func
def scalar_check(x: T.int32, flag: T.bool()):
    T.evaluate(0)

scalar_check(1.0, True)  # x is a float -> Expect arg[0] to be int
scalar_check(1, 2.5)     # flag is a float -> Expect arg[1] to be boolean
```

Fix: pass the correct scalar types, e.g., `scalar_check(1, True)`.
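As a practical pattern, several of the fixes above (device, dtype, contiguity) can be applied in one place before the call. A hypothetical helper, not part of TileLang:

```python
import torch

def normalize(t, dtype=torch.float16, device='cuda:0'):
    """Apply the common fixes up front: device type/id, dtype, contiguous strides."""
    return t.to(device=device, dtype=dtype).contiguous()

M = N = K = 1024
A = torch.empty((K, M), dtype=torch.float32).t()  # CPU, float32, non-contiguous
B = torch.empty((K, N), device='cuda:0', dtype=torch.float16)
C = torch.empty((M, N), device='cuda:0', dtype=torch.float16)

fn(normalize(A), normalize(B), normalize(C))  # device/dtype/stride checks now pass
```

Note that this cannot repair rank or shape mismatches; those require changing the tensors themselves.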
---

## Closing Notes

- Cross-check shape/strides/device/dtype against the kernel signature to localize issues efficiently.
- For complex symbolic relations, print the host source to confirm the binding/solving order, then adjust runtime shapes/layouts accordingly.