Tensor Checks (Host-Side Auto-Validation)¶
This page explains the host-side checks that TileLang automatically inserts into the generated host stub for kernels. When you pass torch.Tensor or any DLPack-compatible object to a TileLang kernel, the host stub validates argument count, pointer kinds, dtype, shape, strides, device, and more — so you don’t need to handwrite Python checks. This keeps the ABI stable and significantly reduces Python overhead compared to doing equivalent checks in Python or via pybind.
Why Host-Side Checks¶
ABI stability: the entry is based on TVM FFI + DLPack, consistently accepting tensors and scalars.
Lower overhead: shifting checks from Python into C reduces interpreter/property-access costs; the call overhead is lower than pybind-based approaches.
Focused error reporting: assertions are raised close to the call site with precise “which field failed” messages.
How To Inspect Host Source¶
You can inspect the auto-generated host source (with all checks and the final device-kernel call) for debugging:
print(matmul_relu_kernel.get_host_source())
What The Host Checks¶
1) Argument count and pointer kind¶
num_args must match the number of formal parameters; otherwise the kernel returns -1 with an error message.
Each argument's FFI type must be a pointer kind (for DLTensor/handle) or a valid scalar type; otherwise you'll see errors like Expect arg[i] to be pointer or a scalar type error.
2) Tensor checks (per tensor, after nullability decision)¶
Nullability
If the tensor is "statically reachable/used" by the function body, the handle must be non-NULL; otherwise you get: xxx is expected to have non-NULL pointer.
If an input tensor is not used by the function (statically unreachable), NULL is allowed; the other field checks run only when handle != NULL.
Rank (ndim)
Runtime ndim must equal the compile-time rank.
Data type (dtype)
Match the triple (code, bits, lanes), with tolerance:
float8_e4m3: accept e4m3, e4m3fn, e4m3fnuz.
float8_e5m2: accept e5m2, e5m2fnuz.
bool: accept int8/uint8 with bits=8 (same lanes), kDLBool (code=6, bits=1 or 8), and any dtype with bitwidth=1 (lanes must match).
For packed-bit dtypes (e.g., Int(1), Int(4), UInt(4)), strict dtype checking is skipped.
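The bool tolerance rule above can be sketched in plain Python. This is an illustrative helper over DLPack-style (code, bits, lanes) triples, not the actual generated host C code; all names here are ours.

```python
# Sketch of the dtype-tolerance rule for a compiled bool tensor, assuming
# DLPack-style (code, bits, lanes) triples. Illustrative only.

# Aliases accepted for the float8 families (per the list above).
FLOAT8_E4M3_OK = {"e4m3", "e4m3fn", "e4m3fnuz"}
FLOAT8_E5M2_OK = {"e5m2", "e5m2fnuz"}

def bool_compatible(code: int, bits: int, lanes: int, expected_lanes: int) -> bool:
    """A compiled bool tensor tolerates several runtime encodings."""
    K_DL_INT, K_DL_UINT, K_DL_BOOL = 0, 1, 6  # DLPack type codes
    if code in (K_DL_INT, K_DL_UINT) and bits == 8 and lanes == expected_lanes:
        return True  # int8/uint8 with bits=8, same lanes
    if code == K_DL_BOOL and bits in (1, 8):
        return True  # kDLBool (code=6, bits=1 or 8)
    if bits == 1 and lanes == expected_lanes:
        return True  # any bitwidth=1 dtype, lanes must match
    return False
```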
Shape
Each runtime dimension is bound to the compile-time shape (constants or symbols) and checked for consistency.
Linear equations among symbolic dims can be solved on the fly (when there’s only one unknown at a given check point), enabling cross-tensor constraints.
Strides
If buffer_type = AutoBroadcast: allow strides == NULL and derive strides from shape. If explicit strides are present, bind them to compile-time constraints and check for equality.
Otherwise: check per dimension; if strides == NULL, derive from shape and compare (e.g., contiguous: strides[-1] == 1, strides[-2] == shape[-1]).
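The shape-to-strides derivation used when strides == NULL is just the row-major (contiguous) rule; a minimal Python sketch (the function name is ours, not the host's):

```python
def contiguous_strides(shape):
    """Row-major strides derived from shape, as the host does when
    strides == NULL: strides[-1] == 1 and, moving left,
    strides[i] == strides[i + 1] * shape[i + 1]."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

# e.g. a contiguous (M, K) tensor satisfies
# strides[-1] == 1 and strides[-2] == shape[-1]
```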
byte_offset
Must be 0 (non-zero raises an error) to keep addressing simple and aligned.
Device info
Assert device_type == target backend (CUDA/ROCM/Metal/OneAPI/WebGPU/CPU, etc.). Error messages include a DLPack code legend.
When multiple tensors participate, assert that device_id matches across them.
Data pointer
Must be non-NULL when the tensor is required to be non-null by the nullability rule.
3) Scalar checks¶
T.int* family: require an integer; error: Expect arg[i] to be int.
T.bool: require a boolean; error: Expect arg[i] to be boolean.
Shapes and Symbolic Equations: Linear Solving¶
When shapes are symbolic, the host binds and (when possible) solves linear relations at runtime (only one unknown per check point). Example:
@T.prim_func
def main(
A: T.Tensor((m,), dtype),
B: T.Tensor((m + n,), dtype),
C: T.Tensor((n * k,), dtype),
):
...
This enables enforcing cross-tensor relationships like len(B) == m + n and len(C) == n * k at runtime.
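The one-unknown-per-check-point solving order can be illustrated with plain Python. This is a hypothetical solver sketch mirroring the example signature above (A: (m,), B: (m+n,), C: (n*k,)), not the generated host code:

```python
# Sketch: bind symbols from runtime shapes, solving at most one unknown
# per constraint, in argument order. Names are illustrative.
def bind_shapes(len_A, len_B, len_C):
    """m binds directly from A; then B's equation m + n has one unknown (n);
    then C's equation n * k has one unknown (k)."""
    m = len_A                      # A: shape[0] == m      -> m bound
    n = len_B - m                  # B: shape[0] == m + n  -> n bound
    assert n > 0, "len(B) must exceed len(A)"
    assert len_C % n == 0, "len(C) must be divisible by n"
    k = len_C // n                 # C: shape[0] == n * k  -> k bound
    return m, n, k

# e.g. len(A)=4, len(B)=10, len(C)=12 binds m=4, n=6, k=2
```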
Nullability Rules and Examples¶
Which tensors may be NULL?
Rule: If an input tensor is not used by the function under static analysis (i.e., the access is statically unreachable), it is considered Nullable; otherwise it must be non-NULL.
Examples:
Must be non-NULL (used)
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
A[0] = 1
Passing None raises: main.A_handle is expected to have non-NULL pointer.
Still must be non-NULL (constant-true branch)
some_cond: bool = True
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
if some_cond:
A[0] = 1
Nullable (constant-false branch, statically unreachable)
some_cond: bool = False
@T.prim_func
def main(A: T.Tensor((M, K), dtype)):
if some_cond:
A[0] = 1
Must be non-NULL (runtime condition)
@T.prim_func
def main(A: T.Tensor((M, K), dtype), some_cond: T.bool):
if some_cond:
A[0] = 1
Since some_cond is only known at runtime, static analysis cannot prove A is unused; A is thus non-nullable.
Device Type Codes (DLPack)¶
Supported and referenced device codes in error messages: 1=CPU, 2=CUDA, 7=Vulkan, 8=Metal, 10=ROCM, 14=OneAPI, 15=WebGPU.
Kernels assert that device_type matches the target backend, and require device_id consistency across tensors.
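The codes above come from DLPack's DLDeviceType enum; a small lookup table (restating the list above, with an illustrative helper name) can help decode the legend printed in device-mismatch errors:

```python
# DLPack device_type codes as listed above; handy for decoding the
# legend shown in device-mismatch error messages.
DLPACK_DEVICE_NAMES = {
    1: "CPU",
    2: "CUDA",
    7: "Vulkan",
    8: "Metal",
    10: "ROCM",
    14: "OneAPI",
    15: "WebGPU",
}

def describe_device(code: int) -> str:
    """Format a device_type code the way error legends present it."""
    return f"{code} ({DLPACK_DEVICE_NAMES.get(code, 'unknown')})"
```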
Common Error Examples (What you’ll see)¶
Argument count mismatch (num_args)
Trigger: missing/extra argument
Error:
<kernel>: num_args should be N; expected: <num_args>, got: N
Pointer-typed argument expected
Trigger: scalar passed where a tensor is expected
Error:
<kernel>: Expect arg[i] to be pointer
Rank (ndim) mismatch
Trigger: runtime rank differs from compile-time rank
Error:
<kernel>.<name>.ndim is expected to equal R, but got mismatched ndim
Dtype mismatch
Trigger: dtype not equal to the compiled dtype and not within the tolerance set
Error:
<kernel>.<name>.dtype is expected to be <dtype>, but got incompatible dtype
Shape constraint violation
Trigger: a dimension doesn’t match a constant/symbol binding
Error:
Argument <kernel>.<name>.shape[i] has an unsatisfied constraint: ... == <expected>
Strides check failed (e.g., non-contiguous layout)
Trigger: transposed/sliced tensors that violate expected strides
Error:
Argument <kernel>.<name>.strides[j] has an unsatisfied constraint: ... == <expected>
Device type mismatch
Trigger: calling a CUDA kernel with CPU tensors, etc.
Error:
<kernel>.<name>.device_type mismatch [expected: <code> (<name>)] ...
Device id mismatch
Trigger: mixing tensors from different GPUs
Error:
Argument <kernel>.<name>.device_id has an unsatisfied constraint: ... == ...
NULL data pointer
Trigger: tensor required to be non-null has a NULL data pointer
Error:
<kernel>.<name> is expected to have non-NULL data pointer, but got NULL
Scalar type mismatch
Trigger: passing a float to T.int32, or a non-boolean to T.bool
Error:
<kernel>: Expect arg[i] to be int/boolean
Troubleshooting Tips¶
Print the host source: print(fn.get_host_source()) to see the exact assertion and expected vs. actual fields.
Fix strides: call .contiguous() on non-contiguous tensors, or avoid generating transposed/sliced layouts that break assumptions.
Align devices: ensure all participating tensors share the same device_type and device_id.
Align dtype: use .to(<dtype>) or construct tensors with the correct dtype; pay attention to the float8 and bool tolerance rules.
Dynamic shapes: ensure cross-tensor linear relations can be uniquely determined at each check point (only one unknown at a time).
FAQ¶
Can I disable the checks?
Not recommended and usually not supported. Checks are done on the host to preserve ABI stability and fail early close to the device call.
Is the overhead noticeable?
The checks are lightweight (branches and field reads), and the dominant cost remains the Python→C boundary crossing. Overall this is cheaper than performing the equivalent checks in Python.
Reference Example (Matmul + ReLU)¶
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
# For debugging, print the host source
print(matmul_relu_kernel.get_host_source())
The host will insert all checks described above for this example.
Quick Error Reference (Short List)¶
Argument count
Trigger: missing/extra args; Error:
num_args should be N; expected: <num_args>, got: N.
Pointer kind
Trigger: scalar passed to tensor arg; Error:
Expect arg[i] to be pointer.
Rank (ndim)
Trigger: runtime rank != compile-time; Error:
ndim ... expected to equal R.
Dtype
Trigger: mismatch and not tolerated; Error:
dtype ... expected to be <dtype>.
Shape
Trigger: constant/symbol binding violated; Error:
shape[i] ... == <expected>.
Strides
Trigger: layout mismatch; Error:
strides[j] ... == <expected>.
Device type
Trigger: wrong backend device; Error:
device_type mismatch [expected: ...].
Device id
Trigger: tensors on different GPUs; Error:
device_id ... == ....
Data pointer
Trigger: required non-NULL but NULL; Error:
non-NULL data pointer.
Scalar types
Trigger: wrong scalar type; Error:
Expect arg[i] to be int/boolean.
Host Error Troubleshooting (Minimal Repros)¶
Below are minimal repro snippets for common host-side errors, assuming a CUDA-targeted kernel like matmul_relu_kernel with:
# Convention:
# A: float16 [M, K]
# B: float16 [K, N]
# C: float16 [M, N]
# Target: CUDA (device_type=2)
fn = matmul_relu_kernel # your compiled function
M = N = K = 1024
Adjust dtype/device if your kernel differs.
0. Tip: print the host source¶
print(fn.get_host_source())
1. num_args mismatch¶
import torch
A = torch.empty((M, K), device='cuda', dtype=torch.float16)
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
# Missing C
fn(A, B)
Expected: <kernel>: num_args should be 3; expected: <num_args>, got: 2.
Fix: pass all arguments per the signature.
2. Expect pointer (tensor) but got scalar¶
import torch
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(1, B, C)
Expected: <kernel>: Expect arg[0] to be pointer.
Fix: pass a DLPack-compatible tensor (e.g., torch.Tensor).
3. ndim mismatch¶
import torch
A = torch.empty((M, K, 1), device='cuda', dtype=torch.float16) # rank=3
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
Expected: <kernel>.A_handle.ndim is expected to equal 2, but got mismatched ndim.
Fix: ensure runtime rank equals compiled rank.
4. dtype mismatch¶
import torch
A = torch.empty((M, K), device='cuda', dtype=torch.float32) # should be float16
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
Expected: <kernel>.A_handle.dtype is expected to be float16, but got incompatible dtype.
Fix: A = A.to(torch.float16) or create with the correct dtype.
5. Shape constant/symbol mismatch¶
import torch
A = torch.empty((M, K + 1), device='cuda', dtype=torch.float16) # K mismatched
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A, B, C)
Expected: Argument <kernel>.A_handle.shape[i] has an unsatisfied constraint: ... == <expected>.
Fix: satisfy linear constraints and constants across tensors.
6. Strides check failure (non-contiguous)¶
import torch
A = torch.empty((M, K), device='cuda', dtype=torch.float16)
A_nc = A.t() # transpose -> non-contiguous
B = torch.empty((K, N), device='cuda', dtype=torch.float16)
C = torch.empty((M, N), device='cuda', dtype=torch.float16)
fn(A_nc, B, C)
Expected: Argument <kernel>.A_handle.strides[1] has an unsatisfied constraint: ... == 1.
Fix: pass A_nc.contiguous() or align the layout expectation in the kernel.
7. device_type mismatch¶
import torch
A = torch.empty((M, K), device='cpu', dtype=torch.float16)
B = torch.empty((K, N), device='cpu', dtype=torch.float16)
C = torch.empty((M, N), device='cpu', dtype=torch.float16)
fn(A, B, C) # CUDA-targeted kernel
Expected: <kernel>.A_handle.device_type mismatch [expected: 2 (cuda)] ....
Fix: move tensors to the CUDA device.
8. device_id mismatch (multi-GPU)¶
import torch
A = torch.empty((M, K), device='cuda:0', dtype=torch.float16)
B = torch.empty((K, N), device='cuda:1', dtype=torch.float16)
C = torch.empty((M, N), device='cuda:0', dtype=torch.float16)
fn(A, B, C)
Expected: Argument <kernel>.B_handle.device_id has an unsatisfied constraint: ... == ....
Fix: place all tensors on the same GPU (e.g., cuda:0).
9. NULL data pointer (advanced)¶
This usually comes from hand-constructed DLTensor/NDArray, or external frameworks passing unallocated/freed storage. Regular torch.Tensor allocations rarely hit this.
Expected: <kernel>.<name> is expected to have non-NULL data pointer, but got NULL.
Fix: ensure valid underlying storage; in PyTorch scenarios, avoid constructing tensors from invalid external handles.
10. Scalar type mismatch (int / bool)¶
import tilelang.language as T
@T.prim_func
def scalar_check(x: T.int32, flag: T.bool()):
T.evaluate(0)
scalar_check(1.0, True) # x is float -> Expect arg[0] to be int
scalar_check(1, 2.5) # flag is float -> Expect arg[1] to be boolean
Fix: pass correct scalar types, e.g., scalar_check(1, True).
Closing Notes¶
Cross-check “shape / strides / device / dtype” against the kernel signature to localize issues efficiently.
For complex symbolic relations, print the host source to confirm binding/solving order, then adjust runtime shapes/layouts accordingly.