tilelang.language.builtin
The language interface for tl programs.
Functions
create_list_of_mbarrier(*args) | Create a list of memory barrier handles.
get_mbarrier(*args) | Retrieve a memory barrier operation.
create_tma_descriptor(*args) | Create a Tensor Memory Access (TMA) descriptor.
tma_load(*args) | Perform a Tensor Memory Access (TMA) load operation.
fence_proxy_async(*args) | Create a fence for asynchronous proxy operations.
tma_store_arrive(*args) | Signal the arrival of a TMA store operation.
tma_store_wait(*args) | Wait for completion of TMA store operations.
set_max_nreg(reg_count, is_inc) | Set the maximum number of registers to use.
inc_max_nreg(reg_count) | Increment the maximum number of registers to use.
dec_max_nreg(reg_count) | Decrement the maximum number of registers to use.
no_set_max_nreg() | Disable the maximum register limit setting.
mbarrier_wait_parity(mbarrier, parity) | Wait for memory barrier parity condition.
mbarrier_arrive(mbarrier) | Arrive at memory barrier.
mbarrier_expect_tx(*args) | Set expected transaction count for memory barrier.
wait_wgmma(id) | Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
barrier_wait(barrier_id[, parity]) | Wait for a memory barrier to complete.
barrier_arrive(barrier_id) | Arrive at a memory barrier.
shfl_xor(value, offset) | Perform a shuffle operation with XOR offset.
shfl_down(value, offset) | Perform a shuffle operation with down offset.
shfl_up(value, offset) | Perform a shuffle operation with up offset.
sync_threads() | Synchronize all threads in a block.
sync_thread_partial(barrier_id) | Synchronize threads within a warp.
sync_global() | Synchronize all threads in a block.
sync_grid() | Synchronize all threads in a grid.
Module Contents
- tilelang.language.builtin.create_list_of_mbarrier(*args)
Create a list of memory barrier handles.
- Parameters:
*args (list or Any) – Either a single list of arguments, or multiple arguments directly.
- Returns:
Handle to the created list of memory barriers.
- Return type:
tvm.tir.Call
- Raises:
TypeError – If the input is not a list or variadic arguments.
Examples
>>> create_list_of_mbarrier([128, 128])
>>> create_list_of_mbarrier(128, 128)
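The two call forms in the example are meant to be equivalent. As a rough illustration of how a list-or-variadic signature like this can be dispatched, here is a pure-Python sketch. The helper `normalize_mbarrier_args` is hypothetical and not part of tilelang, and treating an empty call as a `TypeError` is an assumption; the real function returns a `tvm.tir.Call`, not a tuple.

```python
def normalize_mbarrier_args(*args):
    """Hypothetical helper: flatten either calling convention into one tuple."""
    # Form 1: a single list, e.g. normalize_mbarrier_args([128, 128])
    if len(args) == 1 and isinstance(args[0], list):
        return tuple(args[0])
    # Form 2: values passed directly, e.g. normalize_mbarrier_args(128, 128)
    if len(args) >= 1:
        return args
    # Assumption: calling with no arguments at all is rejected.
    raise TypeError("expected a list or one or more arguments")


print(normalize_mbarrier_args([128, 128]))
print(normalize_mbarrier_args(128, 128))
```

Both calls yield the same tuple, which is the property the two documented examples rely on.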
- tilelang.language.builtin.get_mbarrier(*args)
Retrieve a memory barrier operation.
- Parameters:
*args – Variable arguments to specify which memory barrier to retrieve
- Returns:
A handle to the requested memory barrier
- Return type:
tir.Call
- tilelang.language.builtin.create_tma_descriptor(*args)
Create a Tensor Memory Access (TMA) descriptor.
- Parameters:
*args – Variable arguments defining the TMA descriptor configuration
- Returns:
A handle to the created TMA descriptor
- Return type:
tir.Call
- tilelang.language.builtin.tma_load(*args)
Perform a Tensor Memory Access (TMA) load operation.
- Parameters:
*args – Variable arguments specifying the TMA load parameters
- Returns:
A handle to the TMA load operation
- Return type:
tir.Call
- tilelang.language.builtin.fence_proxy_async(*args)
Create a fence for asynchronous proxy operations.
- Parameters:
*args – Variable arguments for fence configuration
- Returns:
A handle to the fence operation
- Return type:
tir.Call
- tilelang.language.builtin.tma_store_arrive(*args)
Signal the arrival of a TMA store operation.
- Parameters:
*args – Variable arguments for the store arrival operation
- Returns:
A handle to the store arrive operation
- Return type:
tir.Call
- tilelang.language.builtin.tma_store_wait(*args)
Wait for completion of TMA store operations.
- Parameters:
*args – Variable arguments specifying which store operations to wait for
- Returns:
A handle to the store wait operation
- Return type:
tir.Call
- tilelang.language.builtin.set_max_nreg(reg_count, is_inc)
Set the maximum number of registers to use. Detailed documentation: https://docs.nvidia.com/cuda/parallel-thread-execution/#miscellaneous-instructions-setmaxnreg
- Parameters:
reg_count (int) – The number of registers to allocate
is_inc (int) – Whether to increment or decrement the register count: 0 to decrement, 1 to increment
- Returns:
A handle to the register setting operation
- Return type:
tir.Call
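To make the `(reg_count, is_inc)` pair concrete, the sketch below maps it to the PTX `setmaxnreg` instruction text described in the linked NVIDIA documentation. The helper `setmaxnreg_ptx` is hypothetical and not part of tilelang. The range check reflects the PTX rule that the count is a multiple of 8 between 24 and 256; whether tilelang validates this, and how it actually lowers the call, is not stated here.

```python
def setmaxnreg_ptx(reg_count: int, is_inc: int) -> str:
    """Hypothetical helper: PTX text denoted by a (reg_count, is_inc) pair."""
    if is_inc not in (0, 1):
        raise ValueError("is_inc must be 0 (decrement) or 1 (increment)")
    # PTX requires the immediate to be a multiple of 8 in the range [24, 256].
    if not (24 <= reg_count <= 256) or reg_count % 8 != 0:
        raise ValueError("reg_count must be a multiple of 8 in [24, 256]")
    action = "inc" if is_inc == 1 else "dec"
    return f"setmaxnreg.{action}.sync.aligned.u32 {reg_count};"


print(setmaxnreg_ptx(240, 1))  # setmaxnreg.inc.sync.aligned.u32 240;
print(setmaxnreg_ptx(24, 0))   # setmaxnreg.dec.sync.aligned.u32 24;
```

In warp-specialized kernels a common pattern is for producer warps to decrement their register budget and consumer warps to increment theirs, which is what `dec_max_nreg` and `inc_max_nreg` are named for.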
- tilelang.language.builtin.inc_max_nreg(reg_count)
Increment the maximum number of registers to use.
- Parameters:
reg_count (int)
- tilelang.language.builtin.dec_max_nreg(reg_count)
Decrement the maximum number of registers to use.
- Parameters:
reg_count (int)
- tilelang.language.builtin.no_set_max_nreg()
Disable the maximum register limit setting.
- tilelang.language.builtin.mbarrier_wait_parity(mbarrier, parity)
Wait for memory barrier parity condition.
- Parameters:
mbarrier (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The memory barrier to wait on
parity (Union[int, tvm.tir.Var]) – The parity value to wait for
Examples
# Wait for parity 0 on barrier 0
T.mbarrier_wait_parity(0, 0)

# Wait for parity value in variable ko on barrier 1
T.mbarrier_wait_parity(1, ko)

# Wait using barrier handle
barrier = T.get_mbarrier(0)
T.mbarrier_wait_parity(barrier, 1)

# Common usage in pipelined kernels:
for ko in range(num_stages):
    # Producer waits for consumer to finish previous iteration
    T.mbarrier_wait_parity(1, ko ^ 1)
    # Producer copies data
    T.copy(A_global, A_shared)
    # Producer signals data ready
    T.mbarrier_arrive(0)

    # Consumer waits for producer data
    T.mbarrier_wait_parity(0, ko)
    # Consumer computes
    T.gemm(A_shared, B_shared, C_local)
    # Consumer signals completion
    T.mbarrier_arrive(1)
- Returns:
A handle to the barrier wait operation
- Return type:
tir.Call
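The parity argument can be easier to follow with a toy model. The sketch below is a single-threaded, pure-Python approximation of the phase bit of a hardware mbarrier, following the PTX convention that even phases have parity 0 and odd phases have parity 1. `MBarrierModel` and its methods are hypothetical names used only for illustration; they are not tilelang APIs, and the model only tracks the current phase against the one immediately before it.

```python
class MBarrierModel:
    """Toy model of an mbarrier phase bit (illustration only)."""

    def __init__(self, arrive_count: int = 1):
        self.arrive_count = arrive_count
        self.pending = arrive_count  # arrivals still needed in this phase
        self.phase = 0               # index of the phase currently in progress

    def arrive(self) -> None:
        """Record one arrival; the phase completes when none are pending."""
        self.pending -= 1
        if self.pending == 0:
            self.phase += 1
            self.pending = self.arrive_count

    def parity_done(self, parity: int) -> bool:
        """True if a wait on `parity` would return without blocking."""
        # The phase with this parity has completed exactly when the phase
        # now in progress has the opposite parity.
        return (self.phase & 1) != parity


full = MBarrierModel()
assert not full.parity_done(0)  # phase 0 still in progress: a wait would block
assert full.parity_done(1)      # parity 1 is treated as already completed
full.arrive()                   # producer signals "data ready"
assert full.parity_done(0)      # consumer's wait on parity 0 now succeeds
assert not full.parity_done(1)  # phase 1 is now in progress
```

This is why pipelined kernels flip the parity they wait on from one iteration to the next: each completed phase toggles the bit.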
- tilelang.language.builtin.mbarrier_arrive(mbarrier)
Arrive at memory barrier.
- Parameters:
mbarrier (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The memory barrier to arrive at
- tilelang.language.builtin.mbarrier_expect_tx(*args)
Set expected transaction count for memory barrier.
- Parameters:
*args – Variable arguments specifying the expected transaction count
- Returns:
A handle to the barrier expectation operation
- Return type:
tir.Call
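As background for choosing a transaction count: in the underlying PTX mbarrier model the expected-transaction count is measured in bytes, so it is normally derived from the size of the tiles being copied. The helper below is a hypothetical pure-Python sketch of that arithmetic, not a tilelang function, and the exact arguments `mbarrier_expect_tx` accepts are not specified here.

```python
from math import prod


def expected_tx_bytes(tile_shape, dtype_bits):
    """Hypothetical helper: bytes moved by one bulk copy of a dense tile."""
    total_bits = prod(tile_shape) * dtype_bits
    if total_bits % 8 != 0:
        raise ValueError("tile size is not a whole number of bytes")
    return total_bits // 8


# A 128x64 float16 tile: 128 * 64 elements * 2 bytes each.
a_bytes = expected_tx_bytes((128, 64), 16)
# If two tiles signal the same barrier, their byte counts add up.
b_bytes = expected_tx_bytes((64, 128), 16)
print(a_bytes, a_bytes + b_bytes)  # 16384 32768
```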
- tilelang.language.builtin.wait_wgmma(id)
Wait for WGMMA (Warp Group Matrix Multiply-Accumulate) operations to complete.
- Parameters:
id (int) – The id of the WGMMA operation to wait for
- Returns:
A handle to the WGMMA wait operation
- Return type:
tir.Call
- tilelang.language.builtin.barrier_wait(barrier_id, parity=None)
Wait for a memory barrier to complete.
- Parameters:
barrier_id (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The memory barrier to wait on
parity (Union[int, tvm.tir.Var, None]) – The parity value to wait for
- Returns:
A handle to the barrier wait operation
- Return type:
tir.Call
The current implementation is syntactic sugar for mbarrier_wait_parity, since only parity values 0 and 1 are supported.
- tilelang.language.builtin.barrier_arrive(barrier_id)
Arrive at a memory barrier.
- Parameters:
barrier_id (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The memory barrier to arrive at
- tilelang.language.builtin.shfl_xor(value, offset)
Perform a shuffle operation with XOR offset.
- Parameters:
value (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The value to shuffle
offset (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The offset for the shuffle operation
- Returns:
A handle to the shuffle operation
- Return type:
tir.Call
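An XOR shuffle is most often used for butterfly reductions, where every lane ends up holding the full result. The following pure-Python sketch models the lanes of a warp as a list to show the data movement. `shfl_xor_model` and `warp_all_reduce_sum` are hypothetical illustration names, not tilelang functions, and the model assumes a power-of-two lane count.

```python
WARP_SIZE = 32


def shfl_xor_model(values, offset):
    """Lane i receives the value held by lane (i XOR offset)."""
    return [values[lane ^ offset] for lane in range(len(values))]


def warp_all_reduce_sum(values):
    """Butterfly reduction: after log2(n) steps every lane holds the sum."""
    vals = list(values)
    offset = len(vals) // 2
    while offset > 0:
        partner = shfl_xor_model(vals, offset)
        vals = [own + other for own, other in zip(vals, partner)]
        offset //= 2
    return vals


lanes = list(range(WARP_SIZE))
result = warp_all_reduce_sum(lanes)
assert result == [sum(lanes)] * WARP_SIZE  # every lane holds 496
```

In a kernel, each list element corresponds to the per-thread `value`, and each loop step to one `shfl_xor(value, offset)` call followed by an add.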
- tilelang.language.builtin.shfl_down(value, offset)
Perform a shuffle operation with down offset.
- Parameters:
value (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The value to shuffle
offset (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The offset for the shuffle operation
- tilelang.language.builtin.shfl_up(value, offset)
Perform a shuffle operation with up offset.
- Parameters:
value (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The value to shuffle
offset (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The offset for the shuffle operation
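The down and up variants differ only in direction: with a down shuffle lane i reads from lane i + offset, and with an up shuffle it reads from lane i - offset. The pure-Python sketch below models both, assuming the CUDA convention that a lane whose source would fall outside the warp keeps its own value. All function names here are hypothetical and exist only for illustration.

```python
def shfl_down_model(values, offset):
    """Lane i reads lane i + offset; out-of-range lanes keep their own value."""
    n = len(values)
    return [values[i + offset] if i + offset < n else values[i] for i in range(n)]


def shfl_up_model(values, offset):
    """Lane i reads lane i - offset; out-of-range lanes keep their own value."""
    return [values[i - offset] if i - offset >= 0 else values[i]
            for i in range(len(values))]


def reduce_to_lane0(values):
    """Tree reduction with down shuffles: only lane 0 holds the final sum."""
    vals = list(values)
    offset = len(vals) // 2
    while offset > 0:
        shifted = shfl_down_model(vals, offset)
        vals = [own + other for own, other in zip(vals, shifted)]
        offset //= 2
    return vals[0]


def inclusive_scan(values):
    """Prefix sum with up shuffles: lane i ends with sum(values[:i + 1])."""
    vals = list(values)
    offset = 1
    while offset < len(vals):
        shifted = shfl_up_model(vals, offset)
        # Lanes below `offset` have no valid source and must not accumulate.
        vals = [v + s if lane >= offset else v
                for lane, (v, s) in enumerate(zip(vals, shifted))]
        offset *= 2
    return vals


assert reduce_to_lane0([1, 2, 3, 4, 5, 6, 7, 8]) == 36
assert inclusive_scan([1, 2, 3, 4]) == [1, 3, 6, 10]
```

Unlike the XOR butterfly, the down-shuffle reduction leaves a valid total only in lane 0, which is usually enough when a single thread writes the result.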
- tilelang.language.builtin.sync_threads()
Synchronize all threads in a block.
- tilelang.language.builtin.sync_thread_partial(barrier_id)
Synchronize threads within a warp.
- Parameters:
barrier_id (Union[int, tvm.tir.PrimExpr, tvm.tir.Call]) – The memory barrier to synchronize
- Returns:
A handle to the synchronization operation
- Return type:
tir.Call
- tilelang.language.builtin.sync_global()
Synchronize all threads in a block.
- tilelang.language.builtin.sync_grid()
Synchronize all threads in a grid.