tilelang.language.builtin

The language interface for TileLang (tl) programs.

Functions

create_list_of_mbarrier(*args)

Create a list of memory barrier handles.

get_mbarrier(*args)

Retrieve a memory barrier operation.

create_tma_descriptor(*args)

Create a Tensor Memory Access (TMA) descriptor.

tma_load(*args)

Perform a Tensor Memory Access (TMA) load operation.

fence_proxy_async(*args)

Create a fence for asynchronous proxy operations.

tma_store_arrive(*args)

Signal the arrival of a TMA store operation.

tma_store_wait(*args)

Wait for completion of TMA store operations.

set_max_nreg(reg_count, is_inc)

Set the maximum number of registers to use.

inc_max_nreg(reg_count)

Increment the maximum number of registers to use.

dec_max_nreg(reg_count)

Decrement the maximum number of registers to use.

annotate_producer_reg_dealloc([reg_count])

Annotate the producer's register deallocation count.

annotate_consumer_reg_alloc([reg_count])

Annotate the consumer's register allocation count.

no_set_max_nreg()

Disable the maximum register limit setting.

disable_warp_group_reg_alloc()

Disable warp group register allocation.

mbarrier_wait_parity(mbarrier, parity)

Wait for memory barrier parity condition.

mbarrier_arrive(mbarrier)

Arrive at memory barrier.

mbarrier_expect_tx(*args)

Set expected transaction count for memory barrier.

warpgroup_arrive()

Signal warpgroup readiness for subsequent WGMMA operations.

warpgroup_commit_batch()

Commit the current warpgroup batch for WGMMA operations.

warpgroup_wait(num_mma)

Wait for completion of the specified warpgroup batch.

get_lane_idx([warp_size])

Return the logical lane index of the calling thread within a warp.

get_warp_idx_sync([warp_size])

Return the canonical warp index, assuming the warp's threads are converged.

get_warp_idx([warp_size])

Return the canonical warp index without synchronizing the warp.

get_warp_group_idx([warp_size, warps_per_group])

Return the canonical warp group index for the calling thread.

shuffle_elect(thread_extent)

Elect exactly one lane within a logical thread group.

warpgroup_fence_operand(buffer_or_ptr[, offset, ...])

Insert a warpgroup fence for the destination accumulator registers.

wait_wgmma(id)

Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.

barrier_wait(barrier_id[, parity])

Wait for a memory barrier to complete.

barrier_arrive(barrier_id)

Arrive at a memory barrier.

shfl_xor(value, offset)

Perform a shuffle operation with XOR offset.

shfl_down(value, offset)

Perform a shuffle operation with down offset.

shfl_up(value, offset)

Perform a shuffle operation with up offset.

sync_threads([barrier_id, arrive_count])

Synchronize all threads in a block.

sync_global()

Synchronize all threads in the entire grid.

sync_grid()

Synchronize all threads in a grid.

initialize_wgmma_descriptor(descriptor, start_address)

Initialize a WGMMA/UTCMMA shared-memory descriptor.

initialize_tcgen05_descriptor(descriptor, ...[, ...])

Initialize a TCGEN05 shared-memory descriptor.

increase_descriptor_offset(descriptor, offset)

Increase the offset of a memory descriptor.

loop_break()

Break out of the innermost loop.

cp_async_barrier_noinc(barrier_id)

Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.

tcgen05_mma_arrive(mbar_ptr)

Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.

ptx_mma_sm70(shape, A_layout, B_layout, A_dtype, ...)

TVM intrinsic for ptx tensor core mma instructions on SM70 (Volta).

Module Contents

tilelang.language.builtin.create_list_of_mbarrier(*args)

Create a list of memory barrier handles.

Parameters:

*args (list or Any) – Either a single list of arguments, or multiple arguments directly.

Returns:

Handle to the created list of memory barriers.

Return type:

tvm.tir.Call

Raises:

TypeError – If the input is not a list or variadic arguments.

Examples

>>> create_list_of_mbarrier([128, 128])
>>> create_list_of_mbarrier(128, 128)
tilelang.language.builtin.get_mbarrier(*args)

Retrieve a memory barrier operation.

Parameters:

*args – Variable arguments to specify which memory barrier to retrieve

Returns:

A handle to the requested memory barrier

Return type:

tir.Call

tilelang.language.builtin.create_tma_descriptor(*args)

Create a Tensor Memory Access (TMA) descriptor.

Parameters:

*args – Variable arguments defining the TMA descriptor configuration

Returns:

A handle to the created TMA descriptor

Return type:

tir.Call

tilelang.language.builtin.tma_load(*args)

Perform a Tensor Memory Access (TMA) load operation.

Parameters:

*args – Variable arguments specifying the TMA load parameters

Returns:

A handle to the TMA load operation

Return type:

tir.Call

tilelang.language.builtin.fence_proxy_async(*args)

Create a fence for asynchronous proxy operations.

Parameters:

*args – Variable arguments for fence configuration

Returns:

A handle to the fence operation

Return type:

tir.Call

tilelang.language.builtin.tma_store_arrive(*args)

Signal the arrival of a TMA store operation.

Parameters:

*args – Variable arguments for the store arrival operation

Returns:

A handle to the store arrive operation

Return type:

tir.Call

tilelang.language.builtin.tma_store_wait(*args)

Wait for completion of TMA store operations.

Parameters:

*args – Variable arguments specifying which store operations to wait for

Returns:

A handle to the store wait operation

Return type:

tir.Call

tilelang.language.builtin.set_max_nreg(reg_count, is_inc)

Set the maximum number of registers to use. See the PTX ISA documentation: https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg

Parameters:
  • reg_count (int) – The number of registers to allocate.

  • is_inc (int) – Whether to increment or decrement the register count: 0 to decrement, 1 to increment.

Returns:

A handle to the register setting operation

Return type:

tir.Call
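
The PTX setmaxnreg instruction only accepts register counts that are multiples of 8 in the range [24, 256], so reg_count arguments should respect that constraint. A minimal sanity checker (is_valid_nreg is an illustrative helper, not part of tilelang):

```python
def is_valid_nreg(reg_count: int) -> bool:
    """Check a register count against the PTX setmaxnreg constraints:
    the value must be a multiple of 8 in the range [24, 256]."""
    return 24 <= reg_count <= 256 and reg_count % 8 == 0

# The defaults used by annotate_producer_reg_dealloc (24) and
# annotate_consumer_reg_alloc (240) both satisfy the constraint.
assert is_valid_nreg(24) and is_valid_nreg(240)
assert not is_valid_nreg(20)   # below the minimum
assert not is_valid_nreg(100)  # not a multiple of 8
```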

tilelang.language.builtin.inc_max_nreg(reg_count)

Increment the maximum number of registers to use.

Parameters:

reg_count (int)

tilelang.language.builtin.dec_max_nreg(reg_count)

Decrement the maximum number of registers to use.

Parameters:

reg_count (int)

tilelang.language.builtin.annotate_producer_reg_dealloc(reg_count=24)

Annotate the producer's register deallocation count.

Parameters:

reg_count (int)

tilelang.language.builtin.annotate_consumer_reg_alloc(reg_count=240)

Annotate the consumer's register allocation count.

Parameters:

reg_count (int)

tilelang.language.builtin.no_set_max_nreg()

Disable the maximum register limit setting.

tilelang.language.builtin.disable_warp_group_reg_alloc()

Disable warp group register allocation.

tilelang.language.builtin.mbarrier_wait_parity(mbarrier, parity)

Wait for memory barrier parity condition.

Parameters:
  • mbarrier (int | tvm.tir.PrimExpr | tvm.tir.Call) – The memory barrier to wait on

  • parity (int | tvm.tir.Var) – The parity value to wait for

Examples

# Wait for parity 0 on barrier 0
T.mbarrier_wait_parity(0, 0)

# Wait for parity value in variable ko on barrier 1
T.mbarrier_wait_parity(1, ko)

# Wait using barrier handle
barrier = T.get_mbarrier(0)
T.mbarrier_wait_parity(barrier, 1)

# Common usage in pipelined kernels:
for ko in range(num_stages):
    # Producer waits for consumer to finish previous iteration
    T.mbarrier_wait_parity(1, ko ^ 1)
    # Producer copies data
    T.copy(A_global, A_shared)
    # Producer signals data ready
    T.mbarrier_arrive(0)

    # Consumer waits for producer data
    T.mbarrier_wait_parity(0, ko)
    # Consumer computes
    T.gemm(A_shared, B_shared, C_local)
    # Consumer signals completion
    T.mbarrier_arrive(1)
Returns:

A handle to the barrier wait operation

Return type:

tir.Call


tilelang.language.builtin.mbarrier_arrive(mbarrier)

Arrive at memory barrier.

Parameters:

mbarrier (int | tvm.tir.PrimExpr | tvm.tir.Call) – The memory barrier to arrive at

tilelang.language.builtin.mbarrier_expect_tx(*args)

Set expected transaction count for memory barrier.

Parameters:

*args – Variable arguments specifying the expected transaction count

Returns:

A handle to the barrier expectation operation

Return type:

tir.Call

tilelang.language.builtin.warpgroup_arrive()

Signal warpgroup readiness for subsequent WGMMA operations.

Returns:

A handle to the warpgroup arrive operation.

Return type:

tir.Call

tilelang.language.builtin.warpgroup_commit_batch()

Commit the current warpgroup batch for WGMMA operations.

Returns:

A handle to the warpgroup commit batch operation.

Return type:

tir.Call

tilelang.language.builtin.warpgroup_wait(num_mma)

Wait for completion of the specified warpgroup batch.

Parameters:

num_mma (int) – Identifier of the warpgroup MMA batch to wait on.

Returns:

A handle to the warpgroup wait operation.

Return type:

tir.Call

tilelang.language.builtin.get_lane_idx(warp_size=None)

Return the logical lane index of the calling thread within a warp.

Parameters:

warp_size (Optional[int, PrimExpr]) – Logical warp (or wavefront) size. Defaults to 32 on NVIDIA and 64 on AMD.

Return type:

tvm.tir.PrimExpr

Example

>>> lane = T.get_lane_idx()
>>> custom_lane = T.get_lane_idx(64)  # override warp size explicitly

Implementation Notes

Lowers to the CUDA helper tl::get_lane_idx(warp_size) defined in src/tl_templates/cuda/intrin.h, which computes the lane index from the linear thread id using the provided warp_size.
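
The lowering reduces to modular arithmetic on the block-linear thread id. A plain-Python model of the computation (lane_idx is an illustrative stand-in, not a tilelang API):

```python
def lane_idx(linear_tid: int, warp_size: int = 32) -> int:
    # Mirrors tl::get_lane_idx: the lane index is the block-linear
    # thread id modulo the logical warp size.
    return linear_tid % warp_size

assert lane_idx(0) == 0
assert lane_idx(33) == 1
assert lane_idx(95, warp_size=64) == 31  # AMD-style wavefront of 64
```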

tilelang.language.builtin.get_warp_idx_sync(warp_size=None)

Return the canonical warp index, assuming the warp’s threads are converged.

Parameters:

warp_size (Optional[int, PrimExpr]) – Logical warp size used for the index calculation.

Return type:

tvm.tir.PrimExpr

Example

>>> warp = T.get_warp_idx_sync()
>>> custom_warp = T.get_warp_idx_sync(64)

Implementation Notes

Emits tl::get_warp_idx_sync(warp_size) which divides the block-linear thread id by warp_size, matching the semantics of CUTLASS’ canonical helpers.

tilelang.language.builtin.get_warp_idx(warp_size=None)

Return the canonical warp index without synchronizing the warp.

Parameters:

warp_size (Optional[int, PrimExpr]) – Logical warp size used for the index calculation.

Return type:

tvm.tir.PrimExpr

Example

>>> warp = T.get_warp_idx()
>>> custom_warp = T.get_warp_idx(64)

Implementation Notes

Lowers to tl::get_warp_idx(warp_size) which divides the block-linear thread id by the provided warp_size without requiring warp convergence.

tilelang.language.builtin.get_warp_group_idx(warp_size=None, warps_per_group=None)

Return the canonical warp group index for the calling thread.

Parameters:
  • warp_size (Optional[int, PrimExpr]) – Logical warp size to use (defaults to 32 on NVIDIA / 64 on AMD).

  • warps_per_group (Optional[int, PrimExpr]) – Number of warps per warp-group. Defaults to 4 on NVIDIA architectures.

Return type:

tvm.tir.PrimExpr

Example

>>> group = T.get_warp_group_idx()
>>> custom_group = T.get_warp_group_idx(32, 6)  # treat 6 warps as a group

Implementation Notes

Generates tl::get_warp_group_idx(warp_size, warps_per_group) which divides the block-linear thread id by warp_size * warps_per_group, matching the canonical ordering while allowing architecture-specific overrides.
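
The same integer division underlies the warp group index, just with a larger divisor. A plain-Python model (warp_group_idx is an illustrative stand-in, not a tilelang API):

```python
def warp_group_idx(linear_tid: int, warp_size: int = 32,
                   warps_per_group: int = 4) -> int:
    # Mirrors tl::get_warp_group_idx: divide the block-linear thread id
    # by the number of threads in one warp group.
    return linear_tid // (warp_size * warps_per_group)

assert warp_group_idx(0) == 0
assert warp_group_idx(127) == 0   # last thread of the first 128-thread group
assert warp_group_idx(128) == 1
assert warp_group_idx(300, warps_per_group=6) == 1  # 6 warps = 192 threads/group
```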

tilelang.language.builtin.shuffle_elect(thread_extent)

Elect exactly one lane within a logical thread group.

Parameters:

thread_extent (int) – Size (in threads) of the group in which a single lane should be elected. Passing 0 elects a single lane in the entire thread block.

Return type:

tvm.tir.PrimExpr

Example

>>> is_leader = T.shuffle_elect(64)
>>> T.if_then_else(is_leader, do_leader_work(), T.evaluate(0))

Implementation Notes

Lowered to the CUDA helper tl::tl_shuffle_elect<thread_extent>() defined in src/tl_templates/cuda/intrin.h, which relies on cutlass::canonical_warp_idx_sync() and cute::elect_one_sync() (or __shfl_sync) to pick one lane per group.
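
The guarantee is that exactly one lane per group observes a true result; which lane the hardware elects is unspecified. A reference model of that contract, arbitrarily electing lane 0 of each group for illustration (elected is not a tilelang API):

```python
def elected(linear_tid: int, thread_extent: int, block_size: int = 256) -> bool:
    # Reference semantics: exactly one lane per group of `thread_extent`
    # threads sees True; thread_extent == 0 means one lane per block.
    # The real election may pick a different lane within each group.
    group = thread_extent if thread_extent != 0 else block_size
    return linear_tid % group == 0

# Exactly one leader per 64-thread group in a 256-thread block:
leaders = [t for t in range(256) if elected(t, 64)]
assert leaders == [0, 64, 128, 192]
# thread_extent == 0 elects a single leader for the whole block:
assert [t for t in range(256) if elected(t, 0)] == [0]
```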

tilelang.language.builtin.warpgroup_fence_operand(buffer_or_ptr, offset=0, num_regs=None, dtype=None)

Insert a warpgroup fence for the destination accumulator registers.

This prevents NVCC from sinking uses of accumulator fragments past the corresponding WGMMA operations by issuing an empty inline assembly barrier on every register.

Parameters:
  • buffer_or_ptr (tvm.tir.Buffer | tvm.tir.PrimExpr) – Buffer | BufferLoad | BufferRegion | PrimExpr A buffer representing the accumulator fragment, a buffer load/region that identifies a starting element within the fragment, or a pointer expression (e.g., tvm_access_ptr/address_of/typed Var).

  • offset (int | tvm.tir.PrimExpr) – int | PrimExpr Element offset from the start of the accumulator fragment.

  • num_regs (int | tvm.tir.PrimExpr | None) – int | PrimExpr | None Number of 32-bit registers to fence. If None and a Buffer is provided, it will be derived from the buffer shape and dtype.

  • dtype (str | None) – str | None Data type string of the accumulator elements. When passing a buffer or buffer-derived expression, dtype is inferred. It is required only when passing a raw pointer expression that cannot be inferred.

Returns:

A handle to the warpgroup fence operation.

Return type:

tir.Call
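
When num_regs is None and a Buffer is supplied, the register count is derived from the buffer shape and dtype. Assuming each 32-bit register holds 4 bytes of accumulator data, the derivation can be modeled in plain Python (num_accum_regs is an illustrative helper, not a tilelang API):

```python
def num_accum_regs(shape, dtype: str) -> int:
    # Number of 32-bit registers needed to hold the accumulator fragment:
    # ceil(total element bytes / 4).
    bytes_per_elem = {"float16": 2, "bfloat16": 2, "float32": 4}[dtype]
    elems = 1
    for dim in shape:
        elems *= dim
    return (elems * bytes_per_elem + 3) // 4

assert num_accum_regs((32, 8), "float32") == 256
assert num_accum_regs((32, 8), "float16") == 128
```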

tilelang.language.builtin.wait_wgmma(id)

Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.

Parameters:

id (int) – The id of the WGMMA operation to wait for

Returns:

A handle to the WGMMA wait operation

Return type:

tir.Call

tilelang.language.builtin.barrier_wait(barrier_id, parity=None)

Wait for a memory barrier to complete.

Parameters:
  • barrier_id (int | tvm.tir.PrimExpr | tvm.tir.Call) – The memory barrier to wait on

  • parity (int | tvm.tir.Var | None) – The parity value to wait for

Returns:

A handle to the barrier wait operation

Return type:

tir.Call

The current implementation is syntactic sugar for mbarrier_wait_parity, since only parity values 0 and 1 are supported.
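
Because only parities 0 and 1 are supported, pipelined code typically derives the parity from the loop iteration, flipping it each time a barrier is reused. A minimal sketch of that arithmetic (parity_for is an illustrative helper, not a tilelang API):

```python
def parity_for(iteration: int) -> int:
    # The phase a waiter expects alternates 0, 1, 0, 1, ... as the
    # barrier is reused across pipeline iterations.
    return iteration & 1

assert [parity_for(k) for k in range(4)] == [0, 1, 0, 1]
```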

tilelang.language.builtin.barrier_arrive(barrier_id)

Arrive at a memory barrier.

Parameters:

barrier_id (int | tvm.tir.PrimExpr | tvm.tir.Call) – The memory barrier to arrive at

tilelang.language.builtin.shfl_xor(value, offset)

Perform a shuffle operation with XOR offset.

Parameters:
  • value (int | tvm.tir.PrimExpr | tvm.tir.Call) – The value to shuffle

  • offset (int | tvm.tir.PrimExpr | tvm.tir.Call) – The offset for the shuffle operation

Returns:

A handle to the shuffle operation

Return type:

tir.Call
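
shfl_xor is the building block of butterfly reductions, where every lane exchanges with its XOR partner at halving offsets. A plain-Python model of the exchange pattern (shfl_xor_reduce is illustrative, not a tilelang API):

```python
def shfl_xor_reduce(values):
    # Butterfly all-reduce built from shfl_xor semantics: at each round,
    # lane i adds the value held by its partner lane (i ^ offset).
    width = len(values)
    offset = width // 2
    while offset:
        values = [values[i] + values[i ^ offset] for i in range(width)]
        offset //= 2
    return values

# After log2(width) rounds, every lane holds the full sum.
assert shfl_xor_reduce(list(range(8))) == [28] * 8
```

The same pattern generalizes to any associative operator (max, min, etc.) by replacing the addition.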

tilelang.language.builtin.shfl_down(value, offset)

Perform a shuffle operation with down offset.

Parameters:
  • value (int | tvm.tir.PrimExpr | tvm.tir.Call) – The value to shuffle

  • offset (int | tvm.tir.PrimExpr | tvm.tir.Call) – The offset for the shuffle operation

tilelang.language.builtin.shfl_up(value, offset)

Perform a shuffle operation with up offset.

Parameters:
  • value (int | tvm.tir.PrimExpr | tvm.tir.Call) – The value to shuffle

  • offset (int | tvm.tir.PrimExpr | tvm.tir.Call) – The offset for the shuffle operation

tilelang.language.builtin.sync_threads(barrier_id=None, arrive_count=None)

Synchronize all threads in a block.

Parameters:
  • barrier_id (int)

  • arrive_count (int)

tilelang.language.builtin.sync_global()

Synchronize all threads in the entire grid.

tilelang.language.builtin.sync_grid()

Synchronize all threads in a grid.

tilelang.language.builtin.initialize_wgmma_descriptor(descriptor, start_address, layout_type_=0, leading_byte_offset=0, stride_byte_offset=0)

Initialize a WGMMA/UTCMMA shared-memory descriptor.

Parameters:
  • descriptor (tvm.tir.Buffer)

  • start_address (tvm.tir.PrimExpr)

  • layout_type_ (int)

  • leading_byte_offset (int)

  • stride_byte_offset (int)

Return type:

tvm.tir.PrimExpr

tilelang.language.builtin.initialize_tcgen05_descriptor(descriptor, start_address, leading_byte_offset, stride_byte_offset, base_offset=0, leading_is_absolute=False, swizzle_mode=0)

Initialize a TCGEN05 shared-memory descriptor.

Parameters:
  • descriptor (tvm.tir.Buffer)

  • start_address (tvm.tir.PrimExpr)

  • leading_byte_offset (int)

  • stride_byte_offset (int)

  • base_offset (int)

  • leading_is_absolute (bool)

  • swizzle_mode (int)

Return type:

tvm.tir.PrimExpr

tilelang.language.builtin.increase_descriptor_offset(descriptor, offset)

Increase the offset of a memory descriptor.

Parameters:
  • descriptor (PrimExpr) – The memory descriptor to modify.

  • offset (PrimExpr) – The offset value to add.

Returns:

A handle representing the modified descriptor.

Return type:

PrimExpr

tilelang.language.builtin.loop_break()

Break out of the innermost loop.

tilelang.language.builtin.cp_async_barrier_noinc(barrier_id)

Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.

Parameters:

barrier_id (int | tvm.tir.PrimExpr | tvm.tir.Call)

tilelang.language.builtin.tcgen05_mma_arrive(mbar_ptr)

Signal UMMA (TCGEN05) barrier arrival for a shared-memory mbarrier pointer.

Parameters:

mbar_ptr (PrimExpr) – Pointer to the mbarrier object in shared memory (e.g., Barrier*).

tilelang.language.builtin.ptx_mma_sm70(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, multiplicand_a, a_index, multiplicand_b, b_index, accumulator, c_index)

TVM intrinsic for ptx tensor core mma instructions on SM70 (Volta).

This intrinsic provides SM70-specific MMA operations that support m16n16k4 shape with FP16 inputs and FP16/FP32 accumulation.

Parameters:
  • shape (str) – The shape of mma fragment (e.g., “m16n16k4”).

  • A_layout (str) – The layout of multiplicand fragment A (“row” or “col”).

  • B_layout (str) – The layout of multiplicand fragment B (“row” or “col”).

  • A_dtype (str) – The data type of multiplicand fragment A (typically “fp16”).

  • B_dtype (str) – The data type of multiplicand fragment B (typically “fp16”).

  • C_dtype (str) – The data type of accumulator fragment C (“fp16” or “fp32”).

  • multiplicand_a (Var) – The multiplicand fragment A variable.

  • a_index (Expr) – The index of multiplicand fragment A.

  • multiplicand_b (Var) – The multiplicand fragment B variable.

  • b_index (Expr) – The index of multiplicand fragment B.

  • accumulator (Var) – The accumulator fragment C variable.

  • c_index (Expr) – The index of accumulator fragment C.

Returns:

call – The call expression.

Return type:

PrimExpr

Examples

>>> T.ptx_mma_sm70(
...     "m16n16k4",
...     "row",
...     "col",
...     "fp16",
...     "fp16",
...     "fp16",
...     A_local.data,
...     0,
...     B_local.data,
...     0,
...     C_local.data,
...     0,
... )