tilelang.tileop.gemm
Package Contents
- tilelang.tileop.gemm.gemm_py_infer_layout(gemm_py, target, thread_bounds)
- tilelang.tileop.gemm.gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var)
- class tilelang.tileop.gemm.GemmInst
Bases: enum.IntEnum
Integer enumeration of the GEMM instruction families. As with any IntEnum, its members are also ints.
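- MMA = 0 (warp-level Tensor Core mma)
- WGMMA = 1 (Hopper warpgroup-level wgmma)
- TCGEN5MMA = 2 (Blackwell tcgen05.mma)
- MFMA = 3 (AMD CDNA matrix fused multiply-add)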
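- is_mma()
Return True if the member is MMA.
- Return type: bool
- is_wgmma()
Return True if the member is WGMMA.
- Return type: bool
- is_tcgen5mma()
Return True if the member is TCGEN5MMA.
- Return type: bool
- is_mfma()
Return True if the member is MFMA.
- Return type: bool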
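Because GemmInst is an IntEnum, a member converts to and from a plain integer, and each is_* predicate tests for one member. The instruction family that can be used depends on the GPU. The sketch below illustrates that dependence with two hypothetical names, a stand-in enum `GemmInstSketch` and a selection rule `select_inst(arch, thread_nums)`. The architecture strings, the 128-thread warpgroup check, and the mapping itself are assumptions for illustration. They are not TileLang's actual selection logic.

```python
from enum import IntEnum


class GemmInstSketch(IntEnum):
    """Hypothetical stand-in mirroring the documented GemmInst members."""

    MMA = 0
    WGMMA = 1
    TCGEN5MMA = 2
    MFMA = 3

    # Assumed behavior: each predicate compares against a single member.
    def is_mma(self) -> bool:
        return self == GemmInstSketch.MMA

    def is_wgmma(self) -> bool:
        return self == GemmInstSketch.WGMMA

    def is_tcgen5mma(self) -> bool:
        return self == GemmInstSketch.TCGEN5MMA

    def is_mfma(self) -> bool:
        return self == GemmInstSketch.MFMA


# AMD CDNA targets with MFMA units (MI100, MI200, MI300).
MFMA_ARCHS = {"gfx908", "gfx90a", "gfx942"}


def select_inst(arch: str, thread_nums: int) -> GemmInstSketch:
    """Hypothetical selection rule; TileLang's real logic may differ."""
    if arch in MFMA_ARCHS:
        return GemmInstSketch.MFMA
    if not arch.startswith("sm_"):
        raise ValueError(f"unsupported arch: {arch!r}")
    # Keep only the digits so that "sm_90a" parses as 90.
    sm = int("".join(ch for ch in arch[3:] if ch.isdigit()))
    # A Hopper warpgroup is 4 warps (128 threads); WGMMA needs whole warpgroups.
    if sm == 90 and thread_nums % 128 == 0:
        return GemmInstSketch.WGMMA
    if sm == 100:  # datacenter Blackwell: tcgen05 MMA (assumed mapping)
        return GemmInstSketch.TCGEN5MMA
    return GemmInstSketch.MMA  # assumed fallback: warp-level mma


for arch, threads in [("sm_80", 128), ("sm_90a", 256), ("sm_90a", 96),
                      ("sm_100a", 128), ("gfx942", 256)]:
    inst = select_inst(arch, threads)
    print(arch, threads, inst.name, int(inst))
```

A rule like this needs the thread count as well as the target. On Hopper, a single wgmma instruction is issued cooperatively by a warpgroup of four warps, so a block with 96 threads falls back to MMA in this sketch.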
- class tilelang.tileop.gemm.GemmPy
Bases: tvm.ir.base.Node, tvm.runtime.Scriptable
- A: tvm.tir.Buffer
Left input operand of the GEMM.
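- B: tvm.tir.Buffer
Right input operand of the GEMM.
- C: tvm.tir.Buffer
Accumulator buffer that receives the result.
- APtr: tvm.tir.PrimExpr
- BPtr: tvm.tir.PrimExpr
- CPtr: tvm.tir.PrimExpr
- M: int
- N: int
- K: int
Problem dimensions: C is M × N, and the reduction runs over K.
- trans_A: bool
- trans_B: bool
Whether A and B, respectively, are stored transposed.
- stride_A: int
- stride_B: int
- offset_A: int
- offset_B: int
- clear_accum: bool
If true, C is zeroed before the product is accumulated into it.
- k_pack: int
- wg_wait: int
- policy: tilelang.ir.GemmWarpPolicy
Policy for partitioning warps across the M and N dimensions of the tile.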
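- infer_layout(target, thread_nums)
Infer the buffer layouts for the GEMM operation based on the target architecture.
- Parameters:
target (tvm.target.Target): compilation target whose architecture determines the layouts.
thread_nums (int): number of threads that cooperate on the GEMM.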
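- lower(layout_map, target, thread_nums, thread_var)
Lower the GEMM operation to TIR statements based on the target architecture.
- Parameters:
layout_map (dict): previously inferred buffer layouts, keyed by buffer.
target (tvm.target.Target): compilation target.
thread_nums (int): number of threads that cooperate on the GEMM.
thread_var (tvm.tir.Var): thread index variable.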
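The GemmPy attributes describe a tile-level matrix multiply. The storage convention here is an assumption: A is stored M × K (K × M when trans_A is set), B is stored K × N (N × K when trans_B is set), and C is M × N. Under that convention, the arithmetic is C = op(A) · op(B) when clear_accum is true, and C += op(A) · op(B) otherwise. The hypothetical pure-Python function `reference_gemm` below models only that arithmetic. It deliberately ignores strides, offsets, k_pack, wg_wait, layouts, and the warp policy, which concern how the lowered code computes the result rather than what it computes.

```python
from typing import List

Matrix = List[List[float]]


def reference_gemm(A: Matrix, B: Matrix, C: Matrix, *,
                   trans_A: bool = False, trans_B: bool = False,
                   clear_accum: bool = False) -> Matrix:
    """Hypothetical model of C = op(A) @ op(B) (+ C unless clear_accum).

    Assumed storage: A is M x K (K x M if trans_A), B is K x N (N x K if
    trans_B), C is M x N. Returns a new matrix instead of writing C in place.
    """
    M, N = len(C), len(C[0])
    K = len(A) if trans_A else len(A[0])
    # clear_accum: start from zeros instead of the existing accumulator values.
    out = [[0.0] * N for _ in range(M)] if clear_accum else [list(row) for row in C]
    for i in range(M):
        for j in range(N):
            acc = 0.0
            for k in range(K):
                a = A[k][i] if trans_A else A[i][k]
                b = B[j][k] if trans_B else B[k][j]
                acc += a * b
            out[i][j] += acc
    return out


A = [[1, 2], [3, 4]]    # M = K = 2
B = [[5, 6], [7, 8]]    # K = N = 2
C = [[1, 1], [1, 1]]
print(reference_gemm(A, B, C))                    # [[20.0, 23.0], [44.0, 51.0]]
print(reference_gemm(A, B, C, clear_accum=True))  # [[19.0, 22.0], [43.0, 50.0]]
```

Passing B already transposed (`[[5, 7], [6, 8]]`) with `trans_B=True` yields the same product. That is the sense in which the transpose flags describe storage rather than a different computation.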