tilelang.language.builtin¶
The language interface for TileLang (tl) programs.
Functions¶
create_list_of_mbarrier | Create a list of memory barrier handles.
get_mbarrier | Retrieve a memory barrier operation.
create_tma_descriptor | Create a Tensor Memory Access (TMA) descriptor.
tma_load | Perform a Tensor Memory Access (TMA) load operation.
fence_proxy_async | Create a fence for asynchronous proxy operations.
tma_store_arrive | Signal the arrival of a TMA store operation.
tma_store_wait | Wait for completion of TMA store operations.
set_max_nreg | Set the maximum number of registers to use.
inc_max_nreg | Increment the maximum number of registers to use.
dec_max_nreg | Decrement the maximum number of registers to use.
annotate_producer_reg_dealloc | Annotate the producer register deallocation.
annotate_consumer_reg_alloc | Annotate the consumer register allocation.
no_set_max_nreg | Disable the maximum register limit setting.
disable_warp_group_reg_alloc | Disable warp-group register allocation.
mbarrier_wait_parity | Wait for a memory barrier parity condition.
mbarrier_arrive | Arrive at a memory barrier.
mbarrier_expect_tx | Set the expected transaction count for a memory barrier.
warpgroup_arrive | Signal warpgroup readiness for subsequent WGMMA operations.
warpgroup_commit_batch | Commit the current warpgroup batch for WGMMA operations.
warpgroup_wait | Wait for completion of the specified warpgroup batch.
get_lane_idx | Return the logical lane index of the calling thread within a warp.
get_warp_idx_sync | Return the canonical warp index, assuming the warp's threads are converged.
get_warp_idx | Return the canonical warp index without synchronizing the warp.
get_warp_group_idx | Return the canonical warp group index for the calling thread.
shuffle_elect | Elect exactly one lane within a logical thread group.
warpgroup_fence_operand | Insert a warpgroup fence for the destination accumulator registers.
wait_wgmma | Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
barrier_wait | Wait for a memory barrier to complete.
barrier_arrive | Arrive at a memory barrier.
shfl_xor | Perform a shuffle operation with XOR offset.
shfl_down | Perform a shuffle operation with down offset.
shfl_up | Perform a shuffle operation with up offset.
sync_threads | Synchronize all threads in a block.
sync_global | Synchronize all threads in the entire grid.
sync_grid | Synchronize all threads in a grid.
initialize_wgmma_descriptor | Initialize a WGMMA/UTCMMA shared-memory descriptor.
initialize_tcgen05_descriptor | Initialize a TCGEN05 shared-memory descriptor.
increase_descriptor_offset | Increase the offset of a memory descriptor.
loop_break | Break out of the innermost loop.
cp_async_barrier_noinc | Perform a PTX async-copy barrier using cp.async.mbarrier.arrive.noinc.
tcgen05_mma_arrive | Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.
ptx_mma_sm70 | TVM intrinsic for PTX tensor core MMA instructions on SM70 (Volta).
Module Contents¶
- tilelang.language.builtin.create_list_of_mbarrier(*args)¶
Create a list of memory barrier handles.
- Parameters:
*args (list or Any) – Either a single list of arguments, or multiple arguments directly.
- Returns:
Handle to the created list of memory barriers.
- Return type:
tvm.tir.Call
- Raises:
TypeError – If the input is not a list or variadic arguments.
Examples
>>> create_list_of_mbarrier([128, 128])
>>> create_list_of_mbarrier(128, 128)
- tilelang.language.builtin.get_mbarrier(*args)¶
Retrieve a memory barrier operation.
- Parameters:
*args – Variable arguments to specify which memory barrier to retrieve
- Returns:
A handle to the requested memory barrier
- Return type:
tir.Call
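Example
A minimal sketch, assuming a barrier list was created earlier in the kernel with create_list_of_mbarrier:
>>> T.create_list_of_mbarrier(128, 128)
>>> barrier = T.get_mbarrier(0)  # handle to the first barrier in the list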
- tilelang.language.builtin.create_tma_descriptor(*args)¶
Create a Tensor Memory Access (TMA) descriptor.
- Parameters:
*args – Variable arguments defining the TMA descriptor configuration
- Returns:
A handle to the created TMA descriptor
- Return type:
tir.Call
- tilelang.language.builtin.tma_load(*args)¶
Perform a Tensor Memory Access (TMA) load operation.
- Parameters:
*args – Variable arguments specifying the TMA load parameters
- Returns:
A handle to the TMA load operation
- Return type:
tir.Call
- tilelang.language.builtin.fence_proxy_async(*args)¶
Create a fence for asynchronous proxy operations.
- Parameters:
*args – Variable arguments for fence configuration
- Returns:
A handle to the fence operation
- Return type:
tir.Call
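Example
A minimal sketch of the usual pattern: order generic-proxy shared-memory writes before async-proxy (e.g., TMA) accesses. The buffer name is hypothetical:
>>> A_shared[0] = 1.0      # generic-proxy write to shared memory
>>> T.fence_proxy_async()  # make the write visible to the async proxy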
- tilelang.language.builtin.tma_store_arrive(*args)¶
Signal the arrival of a TMA store operation.
- Parameters:
*args – Variable arguments for the store arrival operation
- Returns:
A handle to the store arrive operation
- Return type:
tir.Call
- tilelang.language.builtin.tma_store_wait(*args)¶
Wait for completion of TMA store operations.
- Parameters:
*args – Variable arguments specifying which store operations to wait for
- Returns:
A handle to the store wait operation
- Return type:
tir.Call
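Example
A minimal sketch of the usual arrive/wait pairing after issuing TMA stores; the copy and buffer names are hypothetical:
>>> T.copy(C_shared, C_global)  # assume this lowers to TMA store instructions
>>> T.tma_store_arrive()        # commit the outstanding TMA store group
>>> T.tma_store_wait(0)         # block until all committed store groups complete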
- tilelang.language.builtin.set_max_nreg(reg_count, is_inc)¶
Set the maximum number of registers to use.
Detailed documentation: https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg
- Parameters:
reg_count (int) – The number of registers to allocate.
is_inc (int) – Whether to increment or decrement the register count: 0 to decrement, 1 to increment.
- Returns:
A handle to the register setting operation
- Return type:
tir.Call
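Example
A minimal sketch of the warp-specialized register reallocation pattern described in the PTX setmaxnreg documentation; the register counts are illustrative:
>>> T.set_max_nreg(240, 1)  # consumer warps: grow the register budget to 240
>>> T.set_max_nreg(24, 0)   # producer warps: shrink the register budget to 24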
- tilelang.language.builtin.inc_max_nreg(reg_count)¶
Increment the maximum number of registers to use.
- Parameters:
reg_count (int)
- tilelang.language.builtin.dec_max_nreg(reg_count)¶
Decrement the maximum number of registers to use.
- Parameters:
reg_count (int)
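Example
A minimal sketch; judging by their names, these are presumably thin wrappers over set_max_nreg with is_inc fixed to 1 and 0 respectively (an assumption):
>>> T.inc_max_nreg(240)  # raise the register ceiling
>>> T.dec_max_nreg(24)   # lower the register ceiling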
- tilelang.language.builtin.annotate_producer_reg_dealloc(reg_count=24)¶
Annotate the producer register deallocation.
- Parameters:
reg_count (int)
- tilelang.language.builtin.annotate_consumer_reg_alloc(reg_count=240)¶
Annotate the consumer register allocation.
- Parameters:
reg_count (int)
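Example
A minimal sketch annotating a warp-specialized pipeline with the default budgets (24 registers for producers, 240 for consumers):
>>> T.annotate_producer_reg_dealloc()  # producers give registers back
>>> T.annotate_consumer_reg_alloc()    # consumers claim the freed registers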
- tilelang.language.builtin.no_set_max_nreg()¶
Disable the maximum register limit setting.
- tilelang.language.builtin.disable_warp_group_reg_alloc()¶
Disable warp-group register allocation.
- tilelang.language.builtin.mbarrier_wait_parity(mbarrier, parity)¶
Wait for memory barrier parity condition.
- Parameters:
mbarrier (int | tvm.tir.PrimExpr | tvm.tir.Call) – The memory barrier to wait on.
parity (int | tvm.tir.Var) – The parity value to wait for.
Examples
# Wait for parity 0 on barrier 0
T.mbarrier_wait_parity(0, 0)

# Wait for parity value in variable ko on barrier 1
T.mbarrier_wait_parity(1, ko)

# Wait using a barrier handle
barrier = T.get_mbarrier(0)
T.mbarrier_wait_parity(barrier, 1)

# Common usage in pipelined kernels:
for ko in range(num_stages):
    # Producer waits for consumer to finish previous iteration
    T.mbarrier_wait_parity(1, ko ^ 1)
    # Producer copies data
    T.copy(A_global, A_shared)
    # Producer signals data ready
    T.mbarrier_arrive(0)

    # Consumer waits for producer data
    T.mbarrier_wait_parity(0, ko)
    # Consumer computes
    T.gemm(A_shared, B_shared, C_local)
    # Consumer signals completion
    T.mbarrier_arrive(1)
- Returns:
A handle to the barrier wait operation
- Return type:
tir.Call
- tilelang.language.builtin.mbarrier_arrive(mbarrier)¶
Arrive at memory barrier.
- Parameters:
mbarrier (int | tvm.tir.PrimExpr | tvm.tir.Call) – The memory barrier to arrive at.
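Example
A minimal sketch pairing an arrival with a matching wait; the barrier index and the loop variable ko are illustrative:
>>> T.mbarrier_arrive(0)           # producer signals that data is ready
>>> T.mbarrier_wait_parity(0, ko)  # consumer waits on the same barrier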
- tilelang.language.builtin.mbarrier_expect_tx(*args)¶
Set expected transaction count for memory barrier.
- Parameters:
*args – Variable arguments specifying the expected transaction count
- Returns:
A handle to the barrier expectation operation
- Return type:
tir.Call
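Example
A hedged sketch, assuming the variadic form takes the barrier followed by the expected byte count, mirroring PTX mbarrier.arrive.expect_tx:
>>> # expect a 128x128 float16 tile (32 KiB) to arrive via TMA on barrier 0
>>> T.mbarrier_expect_tx(0, 128 * 128 * 2)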
- tilelang.language.builtin.warpgroup_arrive()¶
Signal warpgroup readiness for subsequent WGMMA operations.
- Returns:
A handle to the warpgroup arrive operation.
- Return type:
tir.Call
- tilelang.language.builtin.warpgroup_commit_batch()¶
Commit the current warpgroup batch for WGMMA operations.
- Returns:
A handle to the warpgroup commit batch operation.
- Return type:
tir.Call
- tilelang.language.builtin.warpgroup_wait(num_mma)¶
Wait for completion of the specified warpgroup batch.
- Parameters:
num_mma (int) – Identifier of the warpgroup MMA batch to wait on.
- Returns:
A handle to the warpgroup wait operation.
- Return type:
tir.Call
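Example
A minimal sketch of the usual WGMMA sequencing; the buffer names are hypothetical:
>>> T.warpgroup_arrive()                 # registers/smem are ready for WGMMA
>>> T.gemm(A_shared, B_shared, C_local)  # assume this issues WGMMA instructions
>>> T.warpgroup_commit_batch()           # close the current WGMMA batch
>>> T.warpgroup_wait(0)                  # wait until no batches remain in flight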
- tilelang.language.builtin.get_lane_idx(warp_size=None)¶
Return the logical lane index of the calling thread within a warp.
- Parameters:
warp_size (Optional[int, PrimExpr]) – Logical warp (or wavefront) size. Defaults to 32 on NVIDIA and 64 on AMD.
- Return type:
tvm.tir.PrimExpr
Example
>>> lane = T.get_lane_idx()
>>> custom_lane = T.get_lane_idx(64)  # override the warp size explicitly
Implementation Notes¶
Lowers to the CUDA helper tl::get_lane_idx(warp_size) defined in src/tl_templates/cuda/intrin.h, which computes the lane index from the linear thread id using the provided warp_size.
- tilelang.language.builtin.get_warp_idx_sync(warp_size=None)¶
Return the canonical warp index, assuming the warp’s threads are converged.
- Parameters:
warp_size (Optional[int, PrimExpr]) – Logical warp size used for the index calculation.
- Return type:
tvm.tir.PrimExpr
Example
>>> warp = T.get_warp_idx_sync()
>>> custom_warp = T.get_warp_idx_sync(64)
Implementation Notes¶
Emits tl::get_warp_idx_sync(warp_size) which divides the block-linear thread id by warp_size, matching the semantics of CUTLASS’ canonical helpers.
- tilelang.language.builtin.get_warp_idx(warp_size=None)¶
Return the canonical warp index without synchronizing the warp.
- Parameters:
warp_size (Optional[int, PrimExpr]) – Logical warp size used for the index calculation.
- Return type:
tvm.tir.PrimExpr
Example
>>> warp = T.get_warp_idx()
>>> custom_warp = T.get_warp_idx(64)
Implementation Notes¶
Lowers to tl::get_warp_idx(warp_size) which divides the block-linear thread id by the provided warp_size without requiring warp convergence.
- tilelang.language.builtin.get_warp_group_idx(warp_size=None, warps_per_group=None)¶
Return the canonical warp group index for the calling thread.
- Parameters:
warp_size (Optional[int, PrimExpr]) – Logical warp size to use (defaults to 32 on NVIDIA / 64 on AMD).
warps_per_group (Optional[int, PrimExpr]) – Number of warps per warp-group. Defaults to 4 on NVIDIA architectures.
- Return type:
tvm.tir.PrimExpr
Example
>>> group = T.get_warp_group_idx()
>>> custom_group = T.get_warp_group_idx(32, 6)  # treat 6 warps as a group
Implementation Notes¶
Generates tl::get_warp_group_idx(warp_size, warps_per_group) which divides the block-linear thread id by warp_size * warps_per_group, matching the canonical ordering while allowing architecture-specific overrides.
- tilelang.language.builtin.shuffle_elect(thread_extent)¶
Elect exactly one lane within a logical thread group.
- Parameters:
thread_extent (int) – Size (in threads) of the group in which a single lane should be elected. Passing 0 elects a single lane in the entire thread block.
- Return type:
tvm.tir.PrimExpr
Example
>>> is_leader = T.shuffle_elect(64)
>>> T.if_then_else(is_leader, do_leader_work(), T.evaluate(0))
Implementation Notes¶
Lowered to the CUDA helper tl::tl_shuffle_elect<thread_extent>() defined in src/tl_templates/cuda/intrin.h, which relies on cutlass::canonical_warp_idx_sync() and cute::elect_one_sync() (or __shfl_sync) to pick one lane per group.
- tilelang.language.builtin.warpgroup_fence_operand(buffer_or_ptr, offset=0, num_regs=None, dtype=None)¶
Insert a warpgroup fence for the destination accumulator registers.
This prevents NVCC from sinking uses of accumulator fragments past the corresponding WGMMA operations by issuing an empty inline assembly barrier on every register.
- Parameters:
buffer_or_ptr (Buffer | BufferLoad | BufferRegion | PrimExpr) – A buffer representing the accumulator fragment, a buffer load/region that identifies a starting element within the fragment, or a pointer expression (e.g., tvm_access_ptr/address_of/typed Var).
offset (int | PrimExpr) – Element offset from the start of the accumulator fragment.
num_regs (int | PrimExpr | None) – Number of 32-bit registers to fence. If None and a Buffer is provided, it is derived from the buffer shape and dtype.
dtype (str | None) – Data type string of the accumulator elements. Inferred when a buffer or buffer-derived expression is passed; required only for a raw pointer expression whose dtype cannot be inferred.
- Returns:
A handle to the warpgroup fence operation.
- Return type:
tir.Call
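Example
A minimal sketch fencing a hypothetical accumulator fragment C_local around a WGMMA sequence:
>>> T.warpgroup_fence_operand(C_local)  # keep NVCC from sinking accumulator uses
>>> T.warpgroup_arrive()
>>> T.gemm(A_shared, B_shared, C_local)
>>> T.warpgroup_commit_batch()
>>> T.warpgroup_wait(0)
>>> T.warpgroup_fence_operand(C_local)  # fence again before reading the results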
- tilelang.language.builtin.wait_wgmma(id)¶
Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
- Parameters:
id (int) – The id of the WGMMA operation to wait for.
- Returns:
A handle to the WGMMA wait operation
- Return type:
tir.Call
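Example
A minimal sketch; passing 0 presumably waits until no WGMMA operations remain pending, matching PTX wgmma.wait_group semantics:
>>> T.wait_wgmma(0)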
- tilelang.language.builtin.barrier_wait(barrier_id, parity=None)¶
Wait for a memory barrier to complete.
- Parameters:
barrier_id (int | tvm.tir.PrimExpr | tvm.tir.Call) – The memory barrier to wait on.
parity (int | tvm.tir.Var | None) – The parity value to wait for.
- Returns:
A handle to the barrier wait operation
- Return type:
tir.Call
The current implementation is syntactic sugar for mbarrier_wait_parity, since only parity values 0 and 1 are supported.
- tilelang.language.builtin.barrier_arrive(barrier_id)¶
Arrive at a memory barrier.
- Parameters:
barrier_id (int | tvm.tir.PrimExpr | tvm.tir.Call) – The memory barrier to arrive at.
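Example
A minimal sketch of the sugar pair; the barrier index and phase expression are illustrative:
>>> T.barrier_arrive(0)        # signal barrier 0
>>> T.barrier_wait(0, ko % 2)  # wait for the matching parity phase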
- tilelang.language.builtin.shfl_xor(value, offset)¶
Perform a shuffle operation with XOR offset.
- Parameters:
value (int | tvm.tir.PrimExpr | tvm.tir.Call) – The value to shuffle.
offset (int | tvm.tir.PrimExpr | tvm.tir.Call) – The offset for the shuffle operation.
- Returns:
A handle to the shuffle operation
- Return type:
tir.Call
- tilelang.language.builtin.shfl_down(value, offset)¶
Perform a shuffle operation with down offset.
- Parameters:
value (int | tvm.tir.PrimExpr | tvm.tir.Call) – The value to shuffle.
offset (int | tvm.tir.PrimExpr | tvm.tir.Call) – The offset for the shuffle operation.
- tilelang.language.builtin.shfl_up(value, offset)¶
Perform a shuffle operation with up offset.
- Parameters:
value (int | tvm.tir.PrimExpr | tvm.tir.Call) – The value to shuffle.
offset (int | tvm.tir.PrimExpr | tvm.tir.Call) – The offset for the shuffle operation.
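Example
A sketch of a butterfly warp reduction built from shfl_xor, assuming a full 32-lane warp and a hypothetical per-lane value val:
>>> for offset in [16, 8, 4, 2, 1]:
...     val += T.shfl_xor(val, offset)  # after the loop, every lane holds the warp-wide sum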
- tilelang.language.builtin.sync_threads(barrier_id=None, arrive_count=None)¶
Synchronize all threads in a block.
- Parameters:
barrier_id (int)
arrive_count (int)
- tilelang.language.builtin.sync_global()¶
Synchronize all threads in the entire grid.
- tilelang.language.builtin.sync_grid()¶
Synchronize all threads in a grid.
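Example
A minimal sketch contrasting the block- and grid-level barriers; the named-barrier arguments are optional:
>>> T.sync_threads()        # barrier across the thread block
>>> T.sync_threads(1, 128)  # named barrier 1, expecting 128 arriving threads
>>> T.sync_grid()           # barrier across the whole grid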
- tilelang.language.builtin.initialize_wgmma_descriptor(descriptor, start_address, layout_type_=0, leading_byte_offset=0, stride_byte_offset=0)¶
Initialize a WGMMA/UTCMMA shared-memory descriptor.
- Parameters:
descriptor (tvm.tir.Buffer)
start_address (tvm.tir.PrimExpr)
layout_type_ (int)
leading_byte_offset (int)
stride_byte_offset (int)
- Return type:
tvm.tir.PrimExpr
- tilelang.language.builtin.initialize_tcgen05_descriptor(descriptor, start_address, leading_byte_offset, stride_byte_offset, base_offset=0, leading_is_absolute=False, swizzle_mode=0)¶
Initialize a TCGEN05 shared-memory descriptor.
- Parameters:
descriptor (tvm.tir.Buffer)
start_address (tvm.tir.PrimExpr)
leading_byte_offset (int)
stride_byte_offset (int)
base_offset (int)
leading_is_absolute (bool)
swizzle_mode (int)
- Return type:
tvm.tir.PrimExpr
- tilelang.language.builtin.increase_descriptor_offset(descriptor, offset)¶
Increase the offset of a memory descriptor.
- Parameters:
descriptor (PrimExpr) – The memory descriptor to modify.
offset (PrimExpr) – The offset value to increase.
- Returns:
A handle representing the modified descriptor.
- Return type:
PrimExpr
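Example
A hedged sketch wiring initialize_wgmma_descriptor and increase_descriptor_offset together; the descriptor storage, start address, and byte offsets are illustrative assumptions, not a verified configuration:
>>> desc = T.alloc_fragment([1], "uint64")  # assumed 64-bit descriptor storage
>>> T.initialize_wgmma_descriptor(desc, T.address_of(A_shared), 1, 16, 1024)
>>> T.increase_descriptor_offset(desc, 32)  # advance to the next tile slice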
- tilelang.language.builtin.loop_break()¶
Break out of the innermost loop.
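Example
A minimal sketch of an early exit; the search condition and buffers are hypothetical:
>>> for i in T.serial(N):
...     if A[i] == target:
...         T.loop_break()  # exit the innermost loop early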
- tilelang.language.builtin.cp_async_barrier_noinc(barrier_id)¶
Perform a PTX async-copy barrier using cp.async.mbarrier.arrive.noinc.
- Parameters:
barrier_id (int | tvm.tir.PrimExpr | tvm.tir.Call)
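Example
A minimal sketch; the barrier index is illustrative:
>>> T.cp_async_barrier_noinc(0)  # route pending cp.async copies to barrier 0 without incrementing its arrival count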
- tilelang.language.builtin.tcgen05_mma_arrive(mbar_ptr)¶
Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.
- Parameters:
mbar_ptr (PrimExpr) – Pointer to the mbarrier object in shared memory (e.g., Barrier*).
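Example
A hedged sketch, assuming mbar is a hypothetical barrier object resident in shared memory:
>>> T.tcgen05_mma_arrive(T.address_of(mbar))  # signal UMMA completion on the barrier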
- tilelang.language.builtin.ptx_mma_sm70(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index)¶
TVM intrinsic for PTX tensor core MMA instructions on SM70 (Volta).
This intrinsic provides SM70-specific MMA operations that support m16n16k4 shape with FP16 inputs and FP16/FP32 accumulation.
- Parameters:
shape (str) – The shape of mma fragment (e.g., “m16n16k4”).
A_layout (str) – The layout of multiplicand fragment A (“row” or “col”).
B_layout (str) – The layout of multiplicand fragment B (“row” or “col”).
A_dtype (str) – The data type of multiplicand fragment A (typically “fp16”).
B_dtype (str) – The data type of multiplicand fragment B (typically “fp16”).
C_dtype (str) – The data type of accumulator fragment C (“fp16” or “fp32”).
multiplicand_a (Var) – The multiplicand fragment A variable.
a_index (Expr) – The index of multiplicand fragment A.
multiplicand_b (Var) – The multiplicand fragment B variable.
b_index (Expr) – The index of multiplicand fragment B.
accumulator (Var) – The accumulator fragment C variable.
c_index (Expr) – The index of accumulator fragment C.
- Returns:
call – The call expression.
- Return type:
PrimExpr
Examples
>>> T.ptx_mma_sm70(
...     "float16",
...     "m16n16k4",
...     "row",
...     "col",
...     "fp16",
...     "fp16",
...     "fp16",
...     A_local.data,
...     0,
...     B_local.data,
...     0,
...     C_local.data,
...     0,
... )