tilelang.tileop.gemm

Submodules

Classes

GemmInst

Enumeration of the hardware matrix-multiply instruction families a GEMM can be lowered to. As an IntEnum, each member is also an int.

GemmPy

Functions

gemm_py_infer_layout(gemm_py, target, thread_bounds)

gemm_py_lower(gemm_py, layout_map, target, ...)

Package Contents

tilelang.tileop.gemm.gemm_py_infer_layout(gemm_py, target, thread_bounds)
tilelang.tileop.gemm.gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var)
class tilelang.tileop.gemm.GemmInst

Bases: enum.IntEnum

Enumeration of the hardware matrix-multiply instruction families a GEMM can be lowered to. As an IntEnum, each member is also an int.

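MMA = 0

Warp-level mma.sync tensor-core instructions (NVIDIA).

WGMMA = 1

Warpgroup-level wgmma instructions (NVIDIA Hopper).

TCGEN5MMA = 2

Fifth-generation tensor-core tcgen05.mma instructions (NVIDIA Blackwell).

MFMA = 3

Matrix fused multiply-add instructions (AMD CDNA).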
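is_mma()

Return whether this member is GemmInst.MMA.

Return type:

bool

is_wgmma()

Return whether this member is GemmInst.WGMMA.

Return type:

bool

is_tcgen5mma()

Return whether this member is GemmInst.TCGEN5MMA.

Return type:

bool

is_mfma()

Return whether this member is GemmInst.MFMA.

Return type:

bool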
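GemmInst behaves like any IntEnum. The sketch below is a hypothetical standalone mirror (GemmInstSketch is not part of TileLang) with the same member names and values, assuming each is_* predicate is a simple identity check. It shows that members compare equal to plain ints and can be reconstructed from an integer value.

```python
from enum import IntEnum


class GemmInstSketch(IntEnum):
    """Hypothetical standalone mirror of GemmInst (same member names and values)."""

    MMA = 0
    WGMMA = 1
    TCGEN5MMA = 2
    MFMA = 3

    # Assumption: each predicate is a plain identity check on the member.
    def is_mma(self) -> bool:
        return self is GemmInstSketch.MMA

    def is_wgmma(self) -> bool:
        return self is GemmInstSketch.WGMMA

    def is_tcgen5mma(self) -> bool:
        return self is GemmInstSketch.TCGEN5MMA

    def is_mfma(self) -> bool:
        return self is GemmInstSketch.MFMA


# Because this is an IntEnum, members compare equal to plain ints and can be
# rebuilt from the integer value (e.g. one passed through an FFI boundary).
inst = GemmInstSketch(1)
print(inst.name, int(inst), inst == 1, inst.is_wgmma(), inst.is_mma())
# prints: WGMMA 1 True True False
```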

class tilelang.tileop.gemm.GemmPy

Bases: tvm.ir.base.Node, tvm.runtime.Scriptable

A: tvm.tir.Buffer
B: tvm.tir.Buffer
C: tvm.tir.Buffer
APtr: tvm.tir.PrimExpr
BPtr: tvm.tir.PrimExpr
CPtr: tvm.tir.PrimExpr
M: int
N: int
K: int
trans_A: bool
trans_B: bool
stride_A: int
stride_B: int
offset_A: int
offset_B: int
clear_accum: bool
k_pack: int
wg_wait: int
policy: tilelang.ir.GemmWarpPolicy
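The semantics of trans_A, trans_B and clear_accum can be modeled in plain Python. This is a minimal reference sketch, not TileLang code: reference_gemm and transpose are hypothetical helpers, and the operand orientation (A stored as K x M when trans_A is set, B stored as N x K when trans_B is set) is an assumption that follows the usual convention. The stride, offset, k_pack, wg_wait and policy fields are not modeled here.

```python
from typing import List

Matrix = List[List[float]]


def transpose(x: Matrix) -> Matrix:
    """Return the transpose of a row-major list-of-lists matrix."""
    return [list(row) for row in zip(*x)]


def reference_gemm(A: Matrix, B: Matrix, C: Matrix,
                   trans_A: bool = False, trans_B: bool = False,
                   clear_accum: bool = False) -> Matrix:
    """Hypothetical model of C = op(A) @ op(B), plus C unless clear_accum.

    Assumption: with trans_A the stored A is K x M, with trans_B the stored B
    is N x K; otherwise A is M x K and B is K x N.
    """
    a = transpose(A) if trans_A else A   # now M x K
    b = transpose(B) if trans_B else B   # now K x N
    M, K, N = len(a), len(b), len(b[0])
    assert all(len(row) == K for row in a), "K mismatch between A and B"
    # clear_accum=True starts from zero; otherwise accumulate into existing C.
    out = [[0.0 if clear_accum else C[i][j] for j in range(N)] for i in range(M)]
    for i in range(M):
        for j in range(N):
            for k in range(K):
                out[i][j] += a[i][k] * b[k][j]
    return out


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
C = [[1, 1], [1, 1]]
print(reference_gemm(A, B, C))  # prints: [[20, 23], [44, 51]]
```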
infer_layout(target, thread_nums)

Infer the layout for the GEMM operation based on target architecture.

Parameters:
  • target (tvm.target.Target)

  • thread_nums (int)
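infer_layout receives the thread count, which determines how many warps share the M x N tile. The sketch below is a simplified, hypothetical warp-partitioning rule: partition_warps, the string policy names (borrowed from GemmWarpPolicy), the 32-thread warp and the 16 x 8 per-warp atom are all assumptions, and TileLang's actual layout inference may use different rules.

```python
from typing import Tuple


def partition_warps(M: int, N: int, thread_nums: int, policy: str,
                    warp_size: int = 32, tile_m: int = 16,
                    tile_n: int = 8) -> Tuple[int, int]:
    """Hypothetical split of thread_nums threads into an m_warp x n_warp grid.

    policy is "FullRow", "FullCol" or "Square"; warp_size and the per-warp
    atom tile_m x tile_n are assumptions, not TileLang constants.
    """
    if thread_nums < warp_size or thread_nums % warp_size:
        raise ValueError("thread_nums must be a positive multiple of warp_size")
    num_warps = thread_nums // warp_size
    if policy == "FullRow":
        m_warp, n_warp = num_warps, 1      # all warps stacked along M
    elif policy == "FullCol":
        m_warp, n_warp = 1, num_warps      # all warps spread along N
    elif policy == "Square":
        # Pick the valid factorization whose per-warp sub-tile is closest to square.
        best = None
        for m_c in range(1, num_warps + 1):
            if num_warps % m_c:
                continue
            n_c = num_warps // m_c
            if M % (m_c * tile_m) or N % (n_c * tile_n):
                continue
            score = abs(M / m_c - N / n_c)
            if best is None or score < best[0]:
                best = (score, m_c, n_c)
        if best is None:
            raise ValueError("no valid warp partition for this tile shape")
        _, m_warp, n_warp = best
    else:
        raise ValueError(f"unknown policy: {policy}")
    if M % (m_warp * tile_m) or N % (n_warp * tile_n):
        raise ValueError("tile shape is not divisible by the warp grid")
    return m_warp, n_warp


print(partition_warps(128, 128, 128, "Square"))  # prints: (2, 2)
```

Preferring the most square per-warp sub-tile is one simple heuristic for balancing operand reuse between the A and B fragments; it is shown only to illustrate how thread_nums feeds into a layout decision.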

lower(layout_map, target, thread_nums, thread_var)

Lower the GEMM operation to TIR statements based on target architecture.

Parameters:
  • layout_map (dict)

  • target (tvm.target.Target)

  • thread_nums (int)

  • thread_var (tvm.tir.Var)
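Both infer_layout and lower depend on the target architecture. As a rough illustration of that dispatch, the sketch below maps an architecture string to one of the GemmInst families. select_gemm_inst and the Inst mirror are hypothetical, and the real selection in TileLang likely also considers tile shape, thread count and data types.

```python
from enum import IntEnum


class Inst(IntEnum):
    # Hypothetical mirror of GemmInst's members and values.
    MMA = 0
    WGMMA = 1
    TCGEN5MMA = 2
    MFMA = 3


def select_gemm_inst(arch: str) -> Inst:
    """Hypothetical mapping from an arch string to a GEMM instruction family.

    Assumption: only the architecture is considered here; a real selector
    would also check tile shape, thread count and data types.
    """
    if arch.startswith("gfx"):            # AMD GPUs: matrix fused multiply-add
        return Inst.MFMA
    if arch.startswith("sm_"):
        digits = "".join(ch for ch in arch[3:] if ch.isdigit())
        if not digits:
            raise ValueError(f"cannot parse compute capability from {arch!r}")
        cc = int(digits)
        if cc // 10 == 10:                # datacenter Blackwell (sm_10x): tcgen05.mma
            return Inst.TCGEN5MMA
        if cc == 90:                      # Hopper: warpgroup-level wgmma
            return Inst.WGMMA
        return Inst.MMA                   # other tensor-core GPUs: warp-level mma.sync
    raise ValueError(f"unsupported arch: {arch}")


for arch in ("sm_80", "sm_90a", "sm_100a", "gfx942"):
    print(arch, select_gemm_inst(arch).name)
```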