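tilelang.intrinsics.mma_macro_generator
Attributes
lift
Classes
TensorCoreIntrinEmitter | To eliminate Python syntax within TIR Macro.
TensorCoreIntrinEmitterWithLadderTransform | To eliminate Python syntax within TIR Macro.
INT4TensorCoreIntrinEmitter | To eliminate Python syntax within TIR Macro.
INT4TensorCoreIntrinEmitterWithLadderTransform | To eliminate Python syntax within TIR Macro.
Module Contents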
- tilelang.intrinsics.mma_macro_generator.lift
- class tilelang.intrinsics.mma_macro_generator.TensorCoreIntrinEmitter(a_dtype='float16', b_dtype='float16', accum_dtype='float16', a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False)
Bases: object
To eliminate Python syntax within TIR Macro.
- Parameters:
a_dtype (str)
b_dtype (str)
accum_dtype (str)
a_transposed (bool)
b_transposed (bool)
block_row_warps (int)
block_col_warps (int)
warp_row_tiles (int)
warp_col_tiles (int)
chunk (int)
reduce_k (int)
num_elems_per_byte (int)
is_m_first (Optional[bool])
- M_DIM = 16
- N_DIM = 16
- WARP_SIZE = 32
- dtype_abbrv
- is_m_first = False
- a_dtype = 'float16'
- b_dtype = 'float16'
- accum_dtype = 'float16'
- a_transposed = False
- b_transposed = False
- block_row_warps = 2
- block_col_warps = 2
- warp_row_tiles = 8
- warp_col_tiles = 8
- chunk = 16
- warp_rows = 0
- warp_cols = 0
- reduce_k = 1
- threads = 128
- num_elems_per_byte = 1
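The derived attributes listed above (threads = 128, warp_rows = 0, warp_cols = 0) are consistent with simple arithmetic on the constructor defaults. The sketch below shows that arithmetic. The formulas are an assumption inferred from the listed values, not a copy of the library source, and derive_launch_shape is a hypothetical helper name.

```python
# Sketch only: formulas inferred from the documented defaults, not taken
# from the tilelang source. `derive_launch_shape` is a hypothetical name.
M_DIM = 16      # rows of one MMA tile (class constant listed above)
N_DIM = 16      # columns of one MMA tile (class constant listed above)
WARP_SIZE = 32  # threads per warp (class constant listed above)


def derive_launch_shape(block_row_warps=2, block_col_warps=2,
                        warp_row_tiles=8, warp_col_tiles=8):
    """Return (threads, warp_rows, warp_cols) under the assumed formulas."""
    threads = WARP_SIZE * block_row_warps * block_col_warps
    warp_rows = warp_row_tiles // M_DIM  # assumed: MMA tiles per warp along M
    warp_cols = warp_col_tiles // N_DIM  # assumed: MMA tiles per warp along N
    return threads, warp_rows, warp_cols


print(derive_launch_shape())  # (128, 0, 0), matching the listed defaults
print(derive_launch_shape(warp_row_tiles=32, warp_col_tiles=32))  # (128, 2, 2)
```

If this reading is right, the default warp_row_tiles=8 and warp_col_tiles=8 give zero tiles per warp, so practical configurations would likely pass multiples of 16.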
- get_store_index_map(inverse=False)
- Parameters:
inverse (bool)
- Return type:
tvm.tir.IndexMap
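To illustrate what a store index map and its inverse express, here is a plain-Python sketch of one commonly used accumulator layout for a 16x16 tile held by 32 lanes with 8 register slots each. Whether get_store_index_map returns exactly this mapping is an assumption; store_forward and store_inverse are hypothetical names, and plain integers stand in for TIR expressions.

```python
# Sketch only: a commonly used accumulator layout for 16x16 MMA tiles.
# Whether get_store_index_map returns exactly this map is an assumption.


def store_forward(thread_id, local_id):
    """Map (lane 0..31, register slot 0..7) to (row, col) of a 16x16 tile."""
    row = 8 * ((local_id % 4) // 2) + thread_id // 4
    col = 8 * (local_id // 4) + (thread_id % 4) * 2 + local_id % 2
    return row, col


def store_inverse(row, col):
    """Map (row, col) of a 16x16 tile back to (lane, register slot)."""
    thread_id = 4 * (row % 8) + (col % 8) // 2
    local_id = 4 * (col // 8) + 2 * (row // 8) + col % 2
    return thread_id, local_id


# Every (lane, slot) pair lands on a distinct tile element, and the
# inverse map recovers the original pair.
cells = {store_forward(t, l) for t in range(32) for l in range(8)}
print(len(cells))  # 256
print(store_inverse(*store_forward(5, 3)))  # (5, 3)
```

The inverse=True form is useful when code starts from a tile coordinate and needs to know which lane and register slot own it.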
- extract_thread_binding(thread_id, is_m_first=None)
is_m_first: True if the thread binding is in the form (tx, warp_n, warp_m), which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]. Otherwise, it is in the form [warp_size, block_col_warps (split m), block_row_warps (split n)].
- Parameters:
thread_id (tvm.tir.PrimExpr)
is_m_first (Optional[bool])
- Return type:
Tuple[tvm.tir.PrimExpr, tvm.tir.PrimExpr, tvm.tir.PrimExpr]
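The thread binding described above treats a flat thread index as a coordinate in a three-level layout whose fastest-varying level is the lane within a warp. The sketch below shows that decomposition on plain integers. It is a minimal model, not the library code: split_thread_id is a hypothetical helper, and which of the two warp coordinates is warp_m or warp_n depends on is_m_first as described above.

```python
# Sketch only: mixed-radix split of a flat thread index over
# [warp_size, inner_warps, outer_warps]. `split_thread_id` is hypothetical.


def split_thread_id(thread_id, warp_size=32, inner_warps=2, outer_warps=2):
    """Return (lane, inner_warp, outer_warp) for a flat thread index."""
    lane = thread_id % warp_size
    inner = (thread_id // warp_size) % inner_warps
    outer = (thread_id // (warp_size * inner_warps)) % outer_warps
    return lane, inner, outer


print(split_thread_id(37))   # (5, 1, 0)
print(split_thread_id(70))   # (6, 0, 1)
print(split_thread_id(127))  # (31, 1, 1)
```

Swapping the roles of the inner and outer extents is all that distinguishes the two orderings, which is why a single boolean is enough to select between them.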
- ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
- Parameters:
A_local_buf (tvm.tir.Buffer)
A_shared_buf (tvm.tir.Buffer)
ki (tvm.tir.PrimExpr)
rk (Optional[tvm.tir.PrimExpr])
- ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
- Parameters:
B_local_buf (tvm.tir.Buffer)
B_shared_buf (tvm.tir.Buffer)
ki (tvm.tir.PrimExpr)
rk (Optional[tvm.tir.PrimExpr])
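The ki and rk arguments are not described above. One plausible reading, sketched below, is that ki selects an MMA-sized step along K inside a chunk and rk selects which chunk of a split-K reduction to read. This is an assumption, not a statement of the library's behavior: a_tile_origin is a hypothetical helper, K_DIM = 16 is an assumed tile depth for float16, and the default sizes are illustrative only.

```python
# Sketch only: assumed addressing of the A tile in shared memory.
# `a_tile_origin`, K_DIM, and the default sizes are illustrative assumptions.
M_DIM = 16  # rows of one MMA tile (class constant listed above)
K_DIM = 16  # assumed depth of one MMA tile for float16


def a_tile_origin(warp_m, i, ki, rk=0, warp_row_tiles=32, chunk=32):
    """Return the assumed (row, col) origin of the tile a warp loads."""
    row = warp_m * warp_row_tiles + i * M_DIM  # warp offset plus tile index
    col = rk * chunk + ki * K_DIM              # reduce-K chunk plus K step
    return row, col


print(a_tile_origin(warp_m=1, i=0, ki=1))        # (32, 16)
print(a_tile_origin(warp_m=0, i=1, ki=0, rk=1))  # (16, 32)
```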
- mma(A_local_buf, B_local_buf, C_local_buf, k_inner=0)
- Parameters:
A_local_buf (tvm.tir.Buffer)
B_local_buf (tvm.tir.Buffer)
C_local_buf (tvm.tir.Buffer)
k_inner (Optional[tvm.tir.PrimExpr])
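As a mental model for mma, each tile operation accumulates a matrix product into the C fragment. The sketch below is a plain-Python reference of that arithmetic on nested lists. It is not the emitted code: the real method works on fragment buffers and hardware intrinsics, mma_tile_reference is a hypothetical name, and the transposition flags are not modeled.

```python
# Sketch only: mathematical reference for one tile-level multiply-accumulate.
# `mma_tile_reference` is a hypothetical helper, not part of tilelang.


def mma_tile_reference(a_tile, b_tile, c_tile):
    """Accumulate c_tile += a_tile @ b_tile in place and return c_tile."""
    m, k, n = len(a_tile), len(b_tile), len(b_tile[0])
    for i in range(m):
        for j in range(n):
            acc = c_tile[i][j]
            for kk in range(k):
                acc += a_tile[i][kk] * b_tile[kk][j]
            c_tile[i][j] = acc
    return c_tile


size = 16  # M_DIM and N_DIM from the class constants; K is assumed equal
identity = [[1 if r == c else 0 for c in range(size)] for r in range(size)]
b = [[r * size + c for c in range(size)] for r in range(size)]
c = [[0] * size for _ in range(size)]

mma_tile_reference(identity, b, c)  # c now equals b
mma_tile_reference(identity, b, c)  # accumulation: c now equals 2 * b
print(c[3][5])  # 106
```

Calling the function twice shows the accumulate semantics: the second call adds to the existing contents of C rather than overwriting them.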
- stmatrix(C_local_buf, C_buf, pid_m=None, pid_n=None)
- make_mma_load_layout(local_buf, matrix='A')
Create a layout function for loading an MMA operand (matrix A or B, selected by matrix) into a fragment buffer. The layout maps fragment indices to threads and local indices.
- Parameters:
local_buf (tir.Buffer) – The local buffer representing a fragment of a matrix.
matrix (Literal['A', 'B'])
- Returns:
A fragment object that describes how threads and indices in local_buf are laid out.
- Return type:
T.Fragment
- Raises:
AssertionError – If local_buf is not detected to be a fragment buffer.
- make_mma_store_layout(local_buf)
Create a layout function for storing MMA results into a fragment buffer. This layout is used in conjunction with inverse_mma_store_layout to map fragment indices to threads and local indices.
- Parameters:
local_buf (tir.Buffer) – The local buffer representing a fragment of a matrix.
- Returns:
A fragment object that describes how threads and indices in local_buf are laid out.
- Return type:
T.Fragment
- Raises:
AssertionError – If local_buf is not detected to be a fragment buffer.
- class tilelang.intrinsics.mma_macro_generator.TensorCoreIntrinEmitterWithLadderTransform(a_dtype='float16', b_dtype='float16', accum_dtype='float16', a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, transform_kind_a=0, transform_kind_b=0)
Bases: TensorCoreIntrinEmitter
To eliminate Python syntax within TIR Macro. With Ladder Transform Plugin.
- Parameters:
a_dtype (str)
b_dtype (str)
accum_dtype (str)
a_transposed (bool)
b_transposed (bool)
block_row_warps (int)
block_col_warps (int)
warp_row_tiles (int)
warp_col_tiles (int)
chunk (int)
reduce_k (int)
num_elems_per_byte (int)
is_m_first (Optional[bool])
transform_kind_a (Union[int, tilelang.common.TransformKind])
transform_kind_b (Union[int, tilelang.common.TransformKind])
- ldmatrix_a(A_local_buf, A_shared_buf, ki, rk=0)
- ldmatrix_b(B_local_buf, B_shared_buf, ki, rk=0)
- mma(A_local_buf, B_local_buf, C_local_buf)
- class tilelang.intrinsics.mma_macro_generator.INT4TensorCoreIntrinEmitter(a_dtype='float16', b_dtype='float16', accum_dtype='float16', a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False)
Bases: TensorCoreIntrinEmitter
To eliminate Python syntax within TIR Macro.
- Parameters:
a_dtype (str)
b_dtype (str)
accum_dtype (str)
a_transposed (bool)
b_transposed (bool)
block_row_warps (int)
block_col_warps (int)
warp_row_tiles (int)
warp_col_tiles (int)
chunk (int)
reduce_k (int)
num_elems_per_byte (int)
is_m_first (Optional[bool])
- mma(A_local_buf, B_local_buf, C_local_buf)
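The parameter name num_elems_per_byte suggests that sub-byte types such as int4 are stored several elements to a byte, with a value of 2 meaning two 4-bit elements per byte. The sketch below illustrates that idea in plain Python. The nibble order (low nibble first) and both helper names are assumptions made for illustration; the page above does not state how this class packs its operands.

```python
# Sketch only: two signed 4-bit values per byte, low nibble first.
# The nibble order and helper names are assumptions, not tilelang behavior.


def pack_int4_pairs(values):
    """Pack signed 4-bit ints (-8..7) two per byte, low nibble first."""
    assert len(values) % 2 == 0, "need an even number of elements"
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        out.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(out)


def unpack_int4_pairs(packed):
    """Inverse of pack_int4_pairs: recover the signed 4-bit values."""
    def sign_extend(nibble):
        return nibble - 16 if nibble >= 8 else nibble

    values = []
    for byte in packed:
        values.append(sign_extend(byte & 0xF))
        values.append(sign_extend(byte >> 4))
    return values


packed = pack_int4_pairs([-8, 7, 0, -1])
print(packed.hex())               # 78f0
print(unpack_int4_pairs(packed))  # [-8, 7, 0, -1]
```

With this packing, a buffer of N int4 elements occupies N / 2 bytes, which is the kind of size adjustment a num_elems_per_byte factor would account for.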
- class tilelang.intrinsics.mma_macro_generator.INT4TensorCoreIntrinEmitterWithLadderTransform(a_dtype='float16', b_dtype='float16', accum_dtype='float16', a_transposed=False, b_transposed=False, block_row_warps=2, block_col_warps=2, warp_row_tiles=8, warp_col_tiles=8, chunk=16, reduce_k=1, num_elems_per_byte=1, is_m_first=False, transform_kind_a=0, transform_kind_b=0)
Bases: TensorCoreIntrinEmitterWithLadderTransform
To eliminate Python syntax within TIR Macro. With Ladder Transform Plugin.
- Parameters:
a_dtype (str)
b_dtype (str)
accum_dtype (str)
a_transposed (bool)
b_transposed (bool)
block_row_warps (int)
block_col_warps (int)
warp_row_tiles (int)
warp_col_tiles (int)
chunk (int)
reduce_k (int)
num_elems_per_byte (int)
is_m_first (Optional[bool])
transform_kind_a (Union[int, tilelang.common.TransformKind])
transform_kind_b (Union[int, tilelang.common.TransformKind])
- mma(A_local_buf, B_local_buf, C_local_buf)