tilelang.tileop.gemm.gemm_base

Classes

  • GemmBase

Module Contents

class tilelang.tileop.gemm.gemm_base.GemmBase
gemm_node: tvm.ir.base.Node
abstract infer_layout(target, thread_nums)
Parameters:
  • target (tvm.target.Target)

  • thread_nums (int)

abstract lower(target, thread_nums, thread_var)
Parameters:
  • target (tvm.target.Target)

  • thread_nums (int)

  • thread_var (tvm.tir.Var)
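Because infer_layout and lower are marked abstract, a concrete GEMM implementation is expected to subclass GemmBase and provide both. The sketch below mimics only that contract using Python's abc module. The class names GemmBaseSketch and ToyGemm, the string return values, and passing strings for target and thread_var are hypothetical placeholders, not TileLang's real layout or lowering objects; whether the real GemmBase refuses direct instantiation depends on how it is declared.

```python
from abc import ABC, abstractmethod


class GemmBaseSketch(ABC):
    """Hypothetical stand-in for GemmBase: only the abstract contract is modeled."""

    @abstractmethod
    def infer_layout(self, target, thread_nums):
        """Return layout information for the operands (placeholder)."""

    @abstractmethod
    def lower(self, target, thread_nums, thread_var):
        """Return the lowered body for this GEMM (placeholder)."""


class ToyGemm(GemmBaseSketch):
    """Hypothetical concrete subclass that returns descriptive strings."""

    def infer_layout(self, target, thread_nums):
        return f"layout for {thread_nums} threads on {target}"

    def lower(self, target, thread_nums, thread_var):
        return f"lowered gemm on {target} using {thread_var} ({thread_nums} threads)"


# In this sketch the abstract base cannot be instantiated directly.
try:
    GemmBaseSketch()
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)

gemm = ToyGemm()
print(gemm.infer_layout("cuda", 128))
print(gemm.lower("cuda", 128, "tx"))
```

Keeping the two hooks abstract lets each backend-specific subclass decide both the operand layouts and the lowered code for its own hardware.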

is_gemm_ss()
Return type:

bool

is_gemm_sr()
Return type:

bool

is_gemm_rs()
Return type:

bool

is_gemm_rr()
Return type:

bool
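The four predicates name where the two input operands live. A common reading of TileLang's naming, assumed here rather than stated on this page, is that the first letter refers to A and the second to B, with s meaning a shared-memory buffer and r a register (fragment) buffer. A minimal sketch of that classification, keyed on buffer scope strings (the helper names and the exact scope strings are assumptions):

```python
def operand_kind(scope: str) -> str:
    """Map a buffer scope string to 's' (shared) or 'r' (register fragment).

    Assumption: shared buffers use scopes starting with "shared" (e.g. "shared",
    "shared.dyn") and register fragments use "local.fragment".
    """
    if scope.startswith("shared"):
        return "s"
    if scope == "local.fragment":
        return "r"
    raise ValueError(f"unsupported scope for a GEMM operand: {scope!r}")


def gemm_kind(scope_a: str, scope_b: str) -> str:
    """Return 'ss', 'sr', 'rs' or 'rr' (first letter for A, second for B)."""
    return operand_kind(scope_a) + operand_kind(scope_b)


print(gemm_kind("shared.dyn", "shared.dyn"))  # -> ss
print(gemm_kind("local.fragment", "shared"))  # -> rs
```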

property M: int
Return type:

int

property N: int
Return type:

int

property K: int
Return type:

int

property trans_A: bool
Return type:

bool

property trans_B: bool
Return type:

bool
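M, N and K describe the logical problem C[M, N] = A[M, K] · B[K, N], while trans_A and trans_B indicate that an operand is stored transposed. Assuming the usual convention (a transposed A is stored as [K, M] and a transposed B as [N, K]), the expected stored operand shapes can be sketched as follows; the helper name expected_operand_shapes is hypothetical.

```python
def expected_operand_shapes(M, N, K, trans_A=False, trans_B=False):
    """Return the stored 2-D shapes of A, B and C for C[M, N] = A[M, K] @ B[K, N].

    Assumption: trans_A=True means A is stored as [K, M]; trans_B=True means
    B is stored as [N, K]. C is always [M, N].
    """
    a_shape = (K, M) if trans_A else (M, K)
    b_shape = (N, K) if trans_B else (K, N)
    return a_shape, b_shape, (M, N)


# A typical "NT" tile: A stored as [M, K], B stored transposed as [N, K].
print(expected_operand_shapes(128, 256, 64, trans_A=False, trans_B=True))
# -> ((128, 64), (256, 64), (128, 256))
```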

property in_dtype: str
Return type:

str

property accum_dtype: str
Return type:

str

property chunk: int
Return type:

int

property A: tvm.tir.Buffer
Return type:

tvm.tir.Buffer

property B: tvm.tir.Buffer
Return type:

tvm.tir.Buffer

property C: tvm.tir.Buffer
Return type:

tvm.tir.Buffer

property ARegion
property BRegion
property CRegion
property stride_A: int
Return type:

int

property stride_B: int
Return type:

int

property offset_A: int
Return type:

int

property offset_B: int
Return type:

int

property clear_accum: tvm.ir.PrimExpr
Return type:

tvm.ir.PrimExpr
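clear_accum controls whether the accumulator C is zeroed before the product is added, that is, C = A·B rather than C += A·B. A small pure-Python reference of that behavior, using plain nested lists with no transposes or dtype handling (the function name reference_gemm is hypothetical):

```python
def reference_gemm(A, B, C, clear_accum=False):
    """Compute C = A @ B if clear_accum else C += A @ B, writing into C.

    A is M x K, B is K x N, C is M x N, all plain nested lists.
    """
    M, K, N = len(A), len(B), len(B[0])
    for i in range(M):
        for j in range(N):
            # Start from zero when clearing, otherwise from the existing value.
            acc = 0 if clear_accum else C[i][j]
            for k in range(K):
                acc += A[i][k] * B[k][j]
            C[i][j] = acc
    return C


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
C = [[100, 100], [100, 100]]
print(reference_gemm(A, B, [row[:] for row in C], clear_accum=True))   # -> [[19, 22], [43, 50]]
print(reference_gemm(A, B, [row[:] for row in C], clear_accum=False))  # -> [[119, 122], [143, 150]]
```

In a real kernel the accumulator typically uses accum_dtype (for example float32) even when in_dtype is narrower; this sketch ignores dtypes entirely.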

property k_pack: int
Return type:

int

property wg_wait: int
Return type:

int

property policy: tilelang.ir.GemmWarpPolicy
Return type:

tilelang.ir.GemmWarpPolicy
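policy selects how the block's warps are split between the M and N dimensions of the tile. The toy sketch below illustrates the idea behind three policy names (Square, FullRow, FullCol), assuming FullRow places all warps along M and FullCol all warps along N. ToyWarpPolicy and toy_warp_partition are hypothetical, and the real TileLang partitioning additionally checks divisibility against the MMA instruction shape and can differ by target.

```python
from enum import Enum


class ToyWarpPolicy(Enum):
    """Hypothetical mirror of three policy names; values are illustrative."""
    Square = 0
    FullRow = 1
    FullCol = 2


def toy_warp_partition(policy, num_warps, M, N):
    """Split num_warps into (m_warps, n_warps) with m_warps * n_warps == num_warps.

    Toy rule: FullRow puts every warp along M, FullCol every warp along N, and
    Square picks the factorization whose per-warp tile (M/m, N/n) is closest to
    square (the first such factorization wins on ties).
    """
    if policy is ToyWarpPolicy.FullRow:
        return num_warps, 1
    if policy is ToyWarpPolicy.FullCol:
        return 1, num_warps
    best = None
    for m in range(1, num_warps + 1):
        if num_warps % m:
            continue
        n = num_warps // m
        imbalance = abs(M / m - N / n)
        if best is None or imbalance < best[0]:
            best = (imbalance, m, n)
    return best[1], best[2]


print(toy_warp_partition(ToyWarpPolicy.Square, 4, 128, 128))   # -> (2, 2)
print(toy_warp_partition(ToyWarpPolicy.FullRow, 4, 128, 128))  # -> (4, 1)
```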

property mbarptr: tvm.ir.PrimExpr
Return type:

tvm.ir.PrimExpr

property C_coords
get_region_base_offsets(region)

Get the base offset (start index) for each dimension from a BufferRegion.

For example, if region is A_shared[ko % 2, 0:128, 0:64], this returns [ko % 2, 0, 0].

Parameters:

region (tvm.tir.BufferRegion) – the region whose per-dimension start indices are returned

Returns:

List of PrimExpr representing the base offset for each dimension
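In TVM, a tvm.tir.BufferRegion keeps its per-dimension ranges in its region field, and each entry is a tvm.ir.Range with min and extent, so the base offsets are the min of each range. The sketch below reproduces that with a hypothetical ToyRange dataclass and plain integers in place of PrimExpr so it runs without TVM; with TVM the equivalent would be roughly [r.min for r in buffer_region.region].

```python
from dataclasses import dataclass


@dataclass
class ToyRange:
    """Hypothetical stand-in for tvm.ir.Range: the interval [min, min + extent)."""
    min: int
    extent: int


def region_base_offsets(ranges):
    """Return the start index of every dimension of a region (list of ToyRange)."""
    return [r.min for r in ranges]


# Models A_shared[ko % 2, 0:128, 0:64] with ko = 3: the first dimension is a
# single-element slice starting at ko % 2, the others start at 0.
ko = 3
region = [ToyRange(ko % 2, 1), ToyRange(0, 128), ToyRange(0, 64)]
print(region_base_offsets(region))  # -> [1, 0, 0]
```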

property A_base_offsets

Get the base offset (start index) for each dimension of the A region.

property B_base_offsets

Get the base offset (start index) for each dimension of the B region.

property C_base_offsets

Get the base offset (start index) for each dimension of the C region.