tilelang.carver.matmul_analysis¶
A GEMM schedule rule for GPU operators.
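Attributes¶
logger
Classes¶
IterKind | Iter kinds for GEMM-like programs.
Functions¶
collect_vars_from_expr
auto_inline_producers
auto_inline_consumers
auto_inline_consumer_chain
find_first_similar_region
find_first_similar_buffer
find_last_producer_from_buffer
find_arg_idx_from_buffer_chain | Traverse the chain of producer blocks to find the index of the function argument that the buffer comes from.
make_iter_fusion_index_map
detect_iter_traits | Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K].
get_index_map | Get index maps for the block.
get_in_out_dtypes | Detect the input/output data types of the given block by analyzing its read/write buffers.
get_dequantize_block
is_identity_or_transpose_block
is_identity_block
is_transpose_block
inline_transpose_block
normalize_to_matmul
get_tensorized_func_and_tags | Transform the function to matmul if necessary (e.g. transform conv2d with im2col).
get_propagate_map
get_ladder_stage3_map
layout_propagate_chain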
Module Contents¶
- tilelang.carver.matmul_analysis.logger¶
- tilelang.carver.matmul_analysis.collect_vars_from_expr(prim_expr)¶
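collect_vars_from_expr has no docstring; its name suggests it gathers the variables referenced by a TIR PrimExpr (for example, the loop variables inside a buffer index). As a rough analogy only, the sketch below does the same for a Python expression string using the standard ast module. It does not touch TVM, and collect_vars_from_index_expr is a hypothetical name.

```python
import ast
from typing import List


def collect_vars_from_index_expr(expr_src: str) -> List[str]:
    """Toy analogue: names of the variables referenced by an index expression,
    each reported once, in the order ast.walk (breadth-first) meets them."""
    names: List[str] = []
    for node in ast.walk(ast.parse(expr_src, mode="eval")):
        if isinstance(node, ast.Name) and node.id not in names:
            names.append(node.id)
    return names


print(collect_vars_from_index_expr("i * 16 + k // 4"))  # ['i', 'k']
```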
- tilelang.carver.matmul_analysis.auto_inline_producers(sch, block, skip_blocks=None)¶
- Parameters:
sch (tvm.tir.Schedule)
block (tvm.tir.schedule.BlockRV)
skip_blocks (Optional[List[tvm.tir.schedule.BlockRV]])
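The name auto_inline_producers suggests a fixed-point loop: keep inlining the producers of `block` (except those in `skip_blocks`) until no further producer can be inlined, because inlining one producer can expose its own producers as new candidates. The toy model below illustrates only that loop on a plain dict graph. The `inlinable` set is a hypothetical stand-in for "TVM's `Schedule.compute_inline` would succeed on this block"; the function and argument names are illustrative, not the library's.

```python
from typing import Dict, List, Optional, Set


def toy_auto_inline_producers(
    producers: Dict[str, List[str]],
    block: str,
    inlinable: Set[str],
    skip_blocks: Optional[Set[str]] = None,
) -> List[str]:
    """Repeatedly inline producers of `block` until a full pass inlines nothing.

    `producers` maps each block to the blocks/buffers it reads from and is
    updated in place. Returns the inlined blocks in the order they were inlined.
    """
    skip_blocks = skip_blocks or set()
    inlined: List[str] = []
    while True:
        progress = False
        for p in list(producers.get(block, [])):  # snapshot: the list changes below
            if p in skip_blocks or p not in inlinable:
                continue
            # "Inline" p: block now reads directly from p's own producers.
            new_producers = [q for q in producers[block] if q != p]
            for q in producers.get(p, []):
                if q not in new_producers:
                    new_producers.append(q)
            producers[block] = new_producers
            inlined.append(p)
            progress = True
        if not progress:
            return inlined


graph = {"matmul": ["decode_B", "A"], "decode_B": ["cast_B"], "cast_B": ["B"]}
print(toy_auto_inline_producers(graph, "matmul", inlinable={"decode_B", "cast_B"}))
# ['decode_B', 'cast_B']; graph["matmul"] is now ['A', 'B']
```

The second pass is what picks up `cast_B`: it only becomes a direct producer of `matmul` after `decode_B` has been inlined.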
- tilelang.carver.matmul_analysis.auto_inline_consumers(sch, block)¶
- Parameters:
sch (tvm.tir.Schedule)
block (tvm.tir.schedule.BlockRV)
- tilelang.carver.matmul_analysis.auto_inline_consumer_chain(sch, block)¶
- Parameters:
sch (tvm.tir.Schedule)
block (tvm.tir.schedule.BlockRV)
- tilelang.carver.matmul_analysis.find_first_similar_region(regions, buffer)¶
- Parameters:
regions (List[tvm.tir.BufferRegion])
buffer (tvm.tir.Buffer)
- tilelang.carver.matmul_analysis.find_first_similar_buffer(regions, buffer)¶
- Parameters:
regions (List[tvm.tir.BufferRegion])
buffer (tvm.tir.Buffer)
- tilelang.carver.matmul_analysis.find_last_producer_from_buffer(sch, main_block, buffer)¶
- Parameters:
buffer (tvm.tir.Buffer)
- Return type:
Optional[tvm.tir.schedule.schedule.BlockRV]
- tilelang.carver.matmul_analysis.find_arg_idx_from_buffer_chain(sch, main_block, buffer)¶
Traverse the chain of producer blocks to find the index of the function argument that the buffer comes from.
- Parameters:
sch (tvm.tir.Schedule)
main_block (tvm.tir.schedule.BlockRV)
buffer (tvm.tir.Buffer)
- Return type:
int
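A toy model of the traversal described above, under stated assumptions: starting from a buffer read by the main block, walk backwards through producer blocks that have exactly one input and one output (e.g. a decode or transpose step) until reaching a buffer that is a function parameter, then return that parameter's position. Blocks are plain (reads, writes) tuples rather than TVM BlockRVs, the main_block argument is folded away, and toy_find_arg_idx is a hypothetical name.

```python
from typing import Dict, List, Optional, Tuple

# Toy IR: block name -> (buffers read, buffers written).
Blocks = Dict[str, Tuple[List[str], List[str]]]


def toy_find_arg_idx(blocks: Blocks, params: List[str], buffer: str) -> Optional[int]:
    """Follow single-input/single-output producers back to a parameter buffer."""
    visited = set()
    while buffer not in params:
        if buffer in visited:
            return None  # cycle guard
        visited.add(buffer)
        # The block that writes exactly this one buffer, if any.
        reads = next((r for r, w in blocks.values() if w == [buffer]), None)
        if reads is None or len(reads) != 1:
            return None  # chain broken: no simple one-to-one producer
        buffer = reads[0]
    return params.index(buffer)


blocks = {
    "decode": (["B_packed"], ["B_decoded"]),
    "transpose": (["B_decoded"], ["B_t"]),
    "matmul": (["A", "B_t"], ["C"]),
}
print(toy_find_arg_idx(blocks, ["A", "B_packed", "C"], "B_t"))  # 1
```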
- class tilelang.carver.matmul_analysis.IterKind¶
Bases:
enum.Enum
Iter kinds for GEMM-like programs. The computation can be simplified to C[S, I, J] += A[S, I, K] * B[S, J, K], where I, J and K are the fundamental GEMM axes and S represents all other spatial axes (e.g. batches):
kIter_S: spatial axes
kIter_I: I axes
kIter_J: J axes
kIter_K: K axes
kIter_T: trivial axes (i.e. with extent 1)
- kIter_S = 0¶
- kIter_I = 1¶
- kIter_J = 2¶
- kIter_K = 3¶
- kIter_T = 4¶
- tilelang.carver.matmul_analysis.make_iter_fusion_index_map(traits, kind_order)¶
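make_iter_fusion_index_map is undocumented. Judging from its name and from IterKind, it presumably builds an index map that fuses all axes of the same kind into a single axis, ordered by kind_order (e.g. several batch axes collapse into one S axis). The plain-Python sketch below shows only the index arithmetic such a row-major fusion would perform; it is an assumption about what the map computes, not the library's code, and fuse_by_kind is a hypothetical name.

```python
from typing import Dict, List, Sequence, Tuple

# Each trait is (kind, extent), with kind one of "S", "I", "J", "K", "T".
Trait = Tuple[str, int]


def fuse_by_kind(traits: Sequence[Trait], values: Sequence[int],
                 kind_order: Sequence[str]) -> Tuple[List[int], List[int]]:
    """Fuse loop indices of the same kind, row-major in original axis order.

    Returns (fused_indices, fused_extents), one entry per kind in kind_order.
    Kinds with no axes get index 0 and extent 1; axes whose kind is not in
    kind_order (e.g. trivial "T" axes) are ignored.
    """
    index: Dict[str, int] = {k: 0 for k in kind_order}
    extent: Dict[str, int] = {k: 1 for k in kind_order}
    for (kind, ext), v in zip(traits, values):
        if kind not in index:
            continue
        # Earlier axes of this kind become the more significant digits.
        index[kind] = index[kind] * ext + v
        extent[kind] *= ext
    return [index[k] for k in kind_order], [extent[k] for k in kind_order]


# Two batch axes (2 x 3) fuse into one S axis of extent 6; (1, 2) -> 1 * 3 + 2 = 5.
print(fuse_by_kind([("S", 2), ("S", 3), ("I", 4), ("J", 5), ("K", 6)],
                   [1, 2, 3, 4, 5], ["S", "I", "J", "K"]))
# ([5, 3, 4, 5], [6, 4, 5, 6])
```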
- tilelang.carver.matmul_analysis.detect_iter_traits(block)¶
Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K]
- Parameters:
block (tir.Block) – The block to be analyzed
- Returns:
traits – The detected iter traits for axes in A, B and C. None if the block does not match the pattern.
- Return type:
Optional[Tuple[List[IterTrait]]]
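The pattern above implies a simple classification rule: an axis appearing in A, B and C is S; in A and C only, I; in B and C only, J; a reduction axis in A and B but not C, K; any axis of extent 1, T. The sketch below encodes that rule over plain axis-name lists. It is a simplified model, not the library's implementation: the documented return type is per-buffer lists of IterTrait objects, whereas this returns a flat dict, and IterKind is redefined locally (mirroring the enum documented on this page) so the code runs without TVM or tilelang. classify_axes is a hypothetical name.

```python
from enum import Enum
from typing import Dict, Optional, Sequence, Set


class IterKind(Enum):  # local mirror of the documented enum
    kIter_S = 0
    kIter_I = 1
    kIter_J = 2
    kIter_K = 3
    kIter_T = 4


def classify_axes(extents: Dict[str, int], reduce_axes: Set[str],
                  a_axes: Sequence[str], b_axes: Sequence[str],
                  c_axes: Sequence[str]) -> Optional[Dict[str, IterKind]]:
    """Classify block axes for C[S, I, J] += A[S, I, K] * B[S, J, K].

    Returns None when some axis fits none of the GEMM roles.
    """
    a, b, c = set(a_axes), set(b_axes), set(c_axes)
    kinds: Dict[str, IterKind] = {}
    for name, extent in extents.items():
        if extent == 1:
            kinds[name] = IterKind.kIter_T      # trivial axis
        elif name in reduce_axes:
            if name in a and name in b and name not in c:
                kinds[name] = IterKind.kIter_K  # reduced, absent from C
            else:
                return None
        elif name in a and name in b and name in c:
            kinds[name] = IterKind.kIter_S      # batch-like, shared by all three
        elif name in a and name in c:
            kinds[name] = IterKind.kIter_I
        elif name in b and name in c:
            kinds[name] = IterKind.kIter_J
        else:
            return None
    return kinds


# Batched matmul: C[b, i, j] += A[b, i, k] * B[b, j, k]
print(classify_axes({"b": 8, "i": 128, "j": 64, "k": 32}, {"k"},
                    ["b", "i", "k"], ["b", "j", "k"], ["b", "i", "j"]))
```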
- tilelang.carver.matmul_analysis.get_index_map(block, layout=None)¶
Get index maps for the block
- Parameters:
block (tir.Block) – The block to be analyzed
layout (List[str]) – The target layouts for the index maps: ‘n’ for the [i, k] layout, ‘t’ for the [k, j] layout, or ‘a’ to infer the layout automatically from whether the last axis is a reduction axis.
- Returns:
index_maps – The index maps for the block, or None if the block is not a GEMM-like kernel.
- Return type:
Optional[Tuple[tir.IndexMap]]
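The layout codes above are terse. One consistent reading, assumed here and not confirmed by the docstring, is that ‘n’ places an input operand's reduction axis K last (A as [i, k]), ‘t’ places it first (B as [k, j]), and ‘a’ picks between the two depending on whether the operand's last axis is a reduction axis, with the batch-like S axis leading as in C[S, I, J]. The sketch encodes only that reading as axis-name lists; the real function returns tir.IndexMap objects, and operand_axis_order is a hypothetical name.

```python
from typing import List


def operand_axis_order(layout: str, spatial: str, reduction: str,
                       last_axis_is_reduction: bool = True) -> List[str]:
    """Assumed axis order for a GEMM input operand under a layout code.

    'n': reduction axis last   (e.g. A as [S, I, K])
    't': reduction axis first  (e.g. B as [S, K, J])
    'a': 'n' if the operand's last axis is a reduction axis, else 't'
    """
    if layout == "a":
        layout = "n" if last_axis_is_reduction else "t"
    if layout == "n":
        return ["S", spatial, reduction]
    if layout == "t":
        return ["S", reduction, spatial]
    raise ValueError(f"unknown layout code {layout!r}")


print(operand_axis_order("n", "I", "K"))  # ['S', 'I', 'K']
print(operand_axis_order("t", "J", "K"))  # ['S', 'K', 'J']
```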
- tilelang.carver.matmul_analysis.get_in_out_dtypes(block)¶
Detect the input/output data types of the given block by analyzing its read/write buffers.
- Parameters:
block (tvm.tir.Block)
- Return type:
Tuple[str]
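A minimal sketch of how such a detection could work, assuming the result is (input dtype, output dtype) taken from the first buffer the block reads and the first buffer it writes. That is an assumption; the documented return type is only Tuple[str]. ToyBuffer and ToyBlock are hypothetical stand-ins: in TVM, tir.Block.reads and tir.Block.writes hold BufferRegion objects whose .buffer carries a .dtype.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyBuffer:  # stand-in for tvm.tir.Buffer (only the fields used here)
    name: str
    dtype: str


@dataclass
class ToyBlock:  # stand-in for tvm.tir.Block: the buffers it reads and writes
    reads: List[ToyBuffer]
    writes: List[ToyBuffer]


def toy_in_out_dtypes(block: ToyBlock) -> Tuple[str, str]:
    """Assumed rule: input dtype from the first read, output from the first write."""
    if not block.reads or not block.writes:
        raise ValueError("block must read and write at least one buffer")
    return block.reads[0].dtype, block.writes[0].dtype


# Mixed-precision GEMM: float16 inputs accumulated into a float32 output.
gemm = ToyBlock(reads=[ToyBuffer("A", "float16"), ToyBuffer("B", "float16")],
                writes=[ToyBuffer("C", "float32")])
print(toy_in_out_dtypes(gemm))  # ('float16', 'float32')
```

Taking only the first read buffer is a simplification; a block that also reads, say, a scale buffer of another dtype would need a more careful rule.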
- tilelang.carver.matmul_analysis.get_dequantize_block(sch, blocks)¶
- Return type:
Optional[tvm.tir.schedule.schedule.BlockRV]
- tilelang.carver.matmul_analysis.is_identity_or_transpose_block(block_stmt)¶
- Parameters:
block_stmt (tvm.tir.Block)
- Return type:
bool
- tilelang.carver.matmul_analysis.is_identity_block(block_stmt)¶
- Parameters:
block_stmt (tvm.tir.Block)
- Return type:
bool
- tilelang.carver.matmul_analysis.is_transpose_block(block_stmt)¶
- Parameters:
block_stmt (tvm.tir.Block)
- Return type:
bool
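is_identity_or_transpose_block, is_identity_block and is_transpose_block are documented only by their bool return type. Going by their names, they recognise blocks that merely copy a buffer (identity) or permute its axes (transpose), presumably so that steps like inline_transpose_block can fold them away. The sketch below models only the index comparison for a block whose body is Out[write_vars] = In[read_vars], looking at the last two access variables. The other checks a real implementation would need (all iterators spatial, body a single load and store) are omitted, and classify_copy is a hypothetical name.

```python
from typing import Sequence, Tuple


def classify_copy(read_vars: Sequence[str], write_vars: Sequence[str]) -> Tuple[bool, bool]:
    """Return (is_identity, is_transpose) for Out[write_vars] = In[read_vars]."""
    r, w = list(read_vars[-2:]), list(write_vars[-2:])
    is_identity = r == w
    # Same variables in a different order: the copy permutes (transposes) the axes.
    is_transpose = not is_identity and set(r) == set(w)
    return is_identity, is_transpose


print(classify_copy(["k", "j"], ["j", "k"]))  # (False, True): B_t[j, k] = B[k, j]
```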
- tilelang.carver.matmul_analysis.inline_transpose_block(sch, blocks)¶
- Parameters:
sch (tvm.tir.Schedule)
blocks (List[tvm.tir.schedule.BlockRV])
- tilelang.carver.matmul_analysis.normalize_to_matmul(sch, main_block, layout=None)¶
- Parameters:
sch (tvm.tir.Schedule)
main_block (tvm.tir.schedule.schedule.BlockRV)
layout (Optional[List[str]])
- Return type:
Optional[tvm.tir.Schedule]
- tilelang.carver.matmul_analysis.get_tensorized_func_and_tags(func, target, layout=None, skip_normalize=False, allow_gemv=False)¶
Transform the function to matmul if necessary (e.g. transform conv2d with im2col).
- Parameters:
func (tvm.tir.PrimFunc)
target (tvm.target.target.Target)
layout (Optional[List[str]])
skip_normalize (bool)
allow_gemv (bool)
- Return type:
Tuple[tvm.tir.PrimFunc, Dict[str, Union[List[int], int]]]
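The docstring cites conv2d via im2col as an example of such a transformation. The plain-Python sketch below illustrates the im2col idea itself, not how get_tensorized_func_and_tags rewrites a PrimFunc: unrolling each filter window into a row turns a single-channel, stride-1, unpadded convolution into C[I, J] += A[I, K] * B[J, K], the pattern described under IterKind, with I = output pixels, J = filters and K = filter taps. All function names are illustrative.

```python
from typing import List

Matrix = List[List[float]]


def conv2d_direct(x: Matrix, w: Matrix) -> Matrix:
    """Reference: stride-1, unpadded, single-channel 2-D convolution (cross-correlation)."""
    kh, kw = len(w), len(w[0])
    oh, ow = len(x) - kh + 1, len(x[0]) - kw + 1
    return [[sum(x[i + r][j + s] * w[r][s] for r in range(kh) for s in range(kw))
             for j in range(ow)]
            for i in range(oh)]


def im2col(x: Matrix, kh: int, kw: int) -> Matrix:
    """One row per output pixel (i, j): the kh*kw input values under the window."""
    oh, ow = len(x) - kh + 1, len(x[0]) - kw + 1
    return [[x[i + r][j + s] for r in range(kh) for s in range(kw)]
            for i in range(oh) for j in range(ow)]


def conv2d_as_gemm(x: Matrix, filters: List[Matrix]) -> Matrix:
    """Convolve x with every filter via one GEMM; result is [OH*OW, num_filters]."""
    kh, kw = len(filters[0]), len(filters[0][0])
    a = im2col(x, kh, kw)                                 # A[I, K], I = OH*OW, K = KH*KW
    b = [[f[r][s] for r in range(kh) for s in range(kw)]  # B[J, K], J = number of filters
         for f in filters]
    # C[i, j] = sum_k A[i, k] * B[j, k]: the C[I, J] += A[I, K] * B[J, K] pattern
    return [[sum(p * q for p, q in zip(a_row, b_row)) for b_row in b] for a_row in a]


x = [[float(4 * r + c) for c in range(4)] for r in range(4)]  # 4x4 input
filters = [[[1.0, 0.0], [0.0, -1.0]],  # difference filter
           [[1.0, 1.0], [1.0, 1.0]]]   # 2x2 box sum
c = conv2d_as_gemm(x, filters)
print(c[0])  # [-5.0, 10.0]
```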
- tilelang.carver.matmul_analysis.get_propagate_map(trans=True, dtype='float16', matrix_name='A', index_dtype='int32')¶
- Parameters:
trans (bool)
- tilelang.carver.matmul_analysis.get_ladder_stage3_map(dtype='float16', index_dtype='int32')¶
- tilelang.carver.matmul_analysis.layout_propagate_chain(sch, start_block, start_buffer, end_block, index_map)¶
- Parameters:
sch (tvm.tir.Schedule)
start_block (tvm.tir.schedule.schedule.BlockRV)
start_buffer (tvm.tir.Buffer)
end_block (tvm.tir.schedule.schedule.BlockRV)
index_map (tvm.tir.IndexMap)