tilelang.tileop.gemm_sp

Submodules

Classes

GemmSPPy

Functions

gemm_sp_py_infer_layout(gemm_sp_py, target, thread_bounds)

gemm_sp_py_lower(gemm_sp_py, target, thread_bounds, ...)

Package Contents

tilelang.tileop.gemm_sp.gemm_sp_py_infer_layout(gemm_sp_py, target, thread_bounds)
Parameters:
  • gemm_sp_py (gemm_sp_mma.GemmSPMMA)

  • target (tvm.target.Target)

  • thread_bounds (tvm.ir.Range)

tilelang.tileop.gemm_sp.gemm_sp_py_lower(gemm_sp_py, target, thread_bounds, thread_var)
Parameters:
  • gemm_sp_py (gemm_sp_mma.GemmSPMMA)

  • target (tvm.target.Target)

  • thread_bounds (tvm.ir.Range)

  • thread_var (tvm.tir.Var)
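
The module-level functions take a thread_bounds range, whereas the GemmSPPy methods documented below take an integer thread_nums. Below is a minimal sketch of how such a wrapper could bridge the two, assuming the thread count is simply the extent of the range. Range and FakeGemmSP are hypothetical stand-ins, not the TVM or tilelang classes, and their string results only record what was passed in.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Range:
    """Hypothetical stand-in for tvm.ir.Range: the half-open interval [min, min + extent)."""
    min: int
    extent: int


class FakeGemmSP:
    """Hypothetical stand-in for the sparse GEMM op object; echoes its arguments."""

    def infer_layout(self, target, thread_nums):
        return f"layout(target={target}, threads={thread_nums})"

    def lower(self, target, thread_nums, thread_var):
        return f"lowered(target={target}, threads={thread_nums}, var={thread_var})"


def infer_layout_wrapper(op, target, thread_bounds):
    # Assumption: the number of threads is the extent of the bound range.
    thread_nums = thread_bounds.extent
    return op.infer_layout(target, thread_nums)


def lower_wrapper(op, target, thread_bounds, thread_var):
    # Same conversion; thread_var is forwarded unchanged.
    thread_nums = thread_bounds.extent
    return op.lower(target, thread_nums, thread_var)


if __name__ == "__main__":
    op = FakeGemmSP()
    bounds = Range(min=0, extent=128)
    print(infer_layout_wrapper(op, "cuda", bounds))
    print(lower_wrapper(op, "cuda", bounds, "tx"))
```

Keeping the range-to-count conversion in the wrapper lets the op methods work with a plain thread count, which is the split the two sets of signatures suggest.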

class tilelang.tileop.gemm_sp.GemmSPPy

Bases: tvm.ir.base.Node, tvm.runtime.Scriptable

A: tvm.tir.Buffer
E: tvm.tir.Buffer
B: tvm.tir.Buffer
C: tvm.tir.Buffer
APtr: tvm.tir.PrimExpr
EPtr: tvm.tir.PrimExpr
BPtr: tvm.tir.PrimExpr
CPtr: tvm.tir.PrimExpr
M: int
N: int
K: int
trans_A: bool
trans_B: bool
stride_A: int
stride_B: int
offset_A: int
offset_B: int
clear_accum: bool
k_pack: int
wg_wait: int
policy: tilelang.ir.GemmWarpPolicy
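
The attribute names suggest the usual tile-GEMM contract: C accumulates the product of A and B over an M x N x K tile, trans_A and trans_B select transposed operands, clear_accum zeroes C before accumulating, and E carries the metadata for the sparse operand A. The pure-Python reference below is a sketch under stated assumptions: A uses 2:4 structured sparsity (two values kept in every group of four along K), E is modeled as a plain list of kept column offsets instead of the packed bit layout that hardware expects, and only trans_B and clear_accum are modeled. compress_2_4 and sparse_gemm_ref are hypothetical helpers, not tilelang functions.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def compress_2_4(a_dense: Matrix) -> Tuple[Matrix, List[List[Tuple[int, int]]]]:
    """Keep the 2 largest-magnitude values in every group of 4 along K.

    Returns (values, meta): values is M x K/2, and meta[i][g] holds the two kept
    offsets (0..3) inside group g of row i. This metadata format is a simplified,
    hypothetical encoding, not the hardware bit layout.
    """
    values, meta = [], []
    for row in a_dense:
        if len(row) % 4:
            raise ValueError("K must be a multiple of 4 for 2:4 sparsity")
        row_vals, row_meta = [], []
        for g in range(0, len(row), 4):
            group = row[g:g + 4]
            # Two largest magnitudes (ties keep the lower offset), stored in ascending order.
            keep = sorted(sorted(range(4), key=lambda j: -abs(group[j]))[:2])
            row_vals.extend(group[j] for j in keep)
            row_meta.append((keep[0], keep[1]))
        values.append(row_vals)
        meta.append(row_meta)
    return values, meta


def sparse_gemm_ref(a_vals, meta, b, c, trans_b=False, clear_accum=False):
    """Reference: C (+)= A @ op(B), reading A straight from (values, meta)."""
    rows, cols = len(a_vals), len(c[0])
    if clear_accum:
        for i in range(rows):
            for j in range(cols):
                c[i][j] = 0.0
    for i in range(rows):
        for g, offsets in enumerate(meta[i]):
            for slot, offset in enumerate(offsets):
                k = 4 * g + offset           # column of A in the dense K dimension
                a = a_vals[i][2 * g + slot]  # matching compressed value
                for j in range(cols):
                    # trans_b: B is stored N x K, so element (k, j) lives at b[j][k].
                    b_kj = b[j][k] if trans_b else b[k][j]
                    c[i][j] += a * b_kj
    return c


if __name__ == "__main__":
    a = [[1.0, 0.0, 2.0, 0.0],
         [0.0, 3.0, 0.0, 4.0]]                              # M=2, K=4, already 2:4 sparse
    b = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]    # K=4, N=2
    vals, meta = compress_2_4(a)
    c = [[100.0, 100.0], [100.0, 100.0]]
    print(sparse_gemm_ref(vals, meta, b, c, clear_accum=True))  # [[11.0, 14.0], [37.0, 44.0]]
```

For an A that is already 2:4 sparse, the result matches the dense product A @ B, so comparing the two is a quick check that the compression kept the right positions.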

infer_layout(target, thread_nums)
Parameters:
  • target (tvm.target.Target)

  • thread_nums (int)

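infer_layout receives only a thread count, so any per-warp layout has to be derived from it together with the tile shape and the policy attribute. As an illustration, the sketch below splits thread_nums // 32 warps into an (m_warp, n_warp) grid over an M x N tile. The warp size of 32 assumes an NVIDIA target, and the strategy names "all_m", "all_n" and "balanced" are hypothetical; they do not reproduce the GemmWarpPolicy algorithm.

```python
from typing import Tuple

WARP_SIZE = 32  # NVIDIA warp size; other targets may differ


def partition_warps(M: int, N: int, thread_nums: int, strategy: str = "balanced") -> Tuple[int, int]:
    """Split thread_nums // WARP_SIZE warps into an (m_warp, n_warp) grid over an M x N tile.

    Hypothetical simplified policy, not tilelang's GemmWarpPolicy.
    """
    if thread_nums % WARP_SIZE:
        raise ValueError("thread_nums must be a multiple of the warp size")
    num_warps = thread_nums // WARP_SIZE
    if strategy == "all_m":
        return num_warps, 1
    if strategy == "all_n":
        return 1, num_warps
    if strategy != "balanced":
        raise ValueError(f"unknown strategy: {strategy}")
    # Among factorizations m_warp * n_warp == num_warps that evenly divide the tile,
    # pick the one whose per-warp sub-tile is closest to square.
    best = None
    for m_warp in range(1, num_warps + 1):
        if num_warps % m_warp:
            continue
        n_warp = num_warps // m_warp
        if M % m_warp or N % n_warp:
            continue
        score = abs(M // m_warp - N // n_warp)
        if best is None or score < best[0]:
            best = (score, m_warp, n_warp)
    if best is None:
        raise ValueError("no warp grid evenly divides the tile")
    return best[1], best[2]


if __name__ == "__main__":
    print(partition_warps(128, 128, 128))  # (2, 2): four warps, each owning a 64 x 64 block
```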
lower(target, thread_nums, thread_var)
Parameters:
  • target (tvm.target.Target)

  • thread_nums (int)

  • thread_var (tvm.tir.Var)
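
Unlike infer_layout, lower also receives thread_var, presumably the variable standing for the thread index inside the generated kernel. The sketch below uses plain integers in place of that symbolic variable to show how a lowering could map a thread index to the origin of the sub-tile its warp computes. The warp ordering (M coordinate varying fastest) and the helper warp_tile_origin are hypothetical; a real lowering would build TIR expressions over thread_var rather than compute Python integers.

```python
from typing import Tuple

WARP_SIZE = 32  # assumption: NVIDIA warp size


def warp_tile_origin(tx: int, M: int, N: int, m_warp: int, n_warp: int) -> Tuple[int, int]:
    """Return the (row, col) origin of the sub-tile computed by the warp containing thread tx.

    Hypothetical layout: warp ids advance along M first, then along N, and each
    warp owns an (M // m_warp) x (N // n_warp) block of the output tile.
    """
    if M % m_warp or N % n_warp:
        raise ValueError("warp grid must evenly divide the tile")
    warp_id = tx // WARP_SIZE
    if not 0 <= warp_id < m_warp * n_warp:
        raise ValueError("thread index falls outside the warp grid")
    warp_m = warp_id % m_warp   # position along M
    warp_n = warp_id // m_warp  # position along N
    return warp_m * (M // m_warp), warp_n * (N // n_warp)


if __name__ == "__main__":
    # 128 threads = 4 warps arranged as a 2 x 2 grid over a 128 x 128 tile.
    for tx in (0, 32, 64, 96):
        print(tx, warp_tile_origin(tx, 128, 128, 2, 2))
```

All 32 threads of a warp map to the same origin; splitting that block further across lanes is left out of this sketch.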