tilelang.tileop.gemm

Submodules

Classes

GemmPy

Functions

gemm_py_infer_layout(gemm_py, target, thread_bounds)

gemm_py_lower(gemm_py, target, thread_bounds, thread_var)

Package Contents

tilelang.tileop.gemm.gemm_py_infer_layout(gemm_py, target, thread_bounds)

tilelang.tileop.gemm.gemm_py_lower(gemm_py, target, thread_bounds, thread_var)

class tilelang.tileop.gemm.GemmPy

Bases: tvm.ir.base.Node, tvm.runtime.Scriptable

A: tvm.tir.Buffer
B: tvm.tir.Buffer
C: tvm.tir.Buffer
APtr: tvm.tir.PrimExpr
BPtr: tvm.tir.PrimExpr
CPtr: tvm.tir.PrimExpr
M: int
N: int
K: int
trans_A: bool
trans_B: bool
stride_A: int
stride_B: int
offset_A: int
offset_B: int
clear_accum: bool
k_pack: int
wg_wait: int
policy: tilelang.ir.GemmWarpPolicy
infer_layout(target, thread_nums)
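The field names suggest the conventional GEMM contract C = op(A) · op(B), accumulated into C unless clear_accum is set, where op transposes an operand when trans_A or trans_B is true. The sketch below is a minimal pure-Python reference model under that assumption; it is not TileLang's implementation, the helper name reference_gemm is hypothetical, and it deliberately ignores stride_A, stride_B, offset_A, offset_B, k_pack, wg_wait, and policy, whose exact semantics are not described here.

```python
from typing import List

Matrix = List[List[float]]


def reference_gemm(
    A: Matrix,
    B: Matrix,
    C: Matrix,
    M: int,
    N: int,
    K: int,
    trans_A: bool = False,
    trans_B: bool = False,
    clear_accum: bool = False,
) -> Matrix:
    """Hypothetical reference model: C = op(A) @ op(B), plus old C unless clear_accum.

    Assumed storage shapes (not confirmed by the TileLang docs):
      A is M x K, or K x M when trans_A is True
      B is K x N, or N x K when trans_B is True
      C is M x N
    """
    out = [row[:] for row in C]  # copy so the caller's C is left unchanged
    for i in range(M):
        for j in range(N):
            # clear_accum decides whether we start from zero or from the existing C value
            acc = 0.0 if clear_accum else out[i][j]
            for k in range(K):
                a = A[k][i] if trans_A else A[i][k]  # read op(A)[i][k]
                b = B[j][k] if trans_B else B[k][j]  # read op(B)[k][j]
                acc += a * b
            out[i][j] = acc
    return out


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
C = [[1, 1], [1, 1]]
print(reference_gemm(A, B, C, M=2, N=2, K=2, clear_accum=True))
print(reference_gemm(A, B, C, M=2, N=2, K=2))
```

With clear_accum=True the result is just the product; with clear_accum=False each entry of the product is added to the existing value in C.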
Parameters:
  • target (tvm.target.Target)

  • thread_nums (int)

lower(target, thread_nums, thread_var)
Parameters:
  • target (tvm.target.Target)

  • thread_nums (int)

  • thread_var (tvm.tir.Var)
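The module-level functions take thread_bounds, while the GemmPy methods take thread_nums. One plausible reading, stated here as an assumption rather than as the library's code, is that each function forwards to the matching method after reducing thread_bounds (a range with min and extent, like tvm.ir.Range) to a thread count. The sketch models that forwarding with hypothetical stand-ins, FakeRange and FakeGemm, so it runs without TVM or TileLang installed.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class FakeRange:
    """Hypothetical stand-in for a thread range with a start and an extent."""
    min: int
    extent: int


class FakeGemm:
    """Hypothetical stand-in for GemmPy that echoes back the arguments it receives."""

    def infer_layout(self, target: Any, thread_nums: int) -> Tuple[str, Any, int]:
        return ("infer_layout", target, thread_nums)

    def lower(self, target: Any, thread_nums: int, thread_var: Any) -> Tuple[str, Any, int, Any]:
        return ("lower", target, thread_nums, thread_var)


def gemm_py_infer_layout(gemm_py, target, thread_bounds):
    # Assumption: the method only needs the number of threads, i.e. the range's extent.
    thread_nums = thread_bounds.extent
    return gemm_py.infer_layout(target, thread_nums)


def gemm_py_lower(gemm_py, target, thread_bounds, thread_var):
    # Same reduction of thread_bounds to a count; thread_var is passed through unchanged.
    thread_nums = thread_bounds.extent
    return gemm_py.lower(target, thread_nums, thread_var)


g = FakeGemm()
print(gemm_py_infer_layout(g, "cuda", FakeRange(0, 128)))
print(gemm_py_lower(g, "cuda", FakeRange(0, 256), "tx"))
```

If this reading holds, only the size of the thread range reaches the GemmPy methods, not its starting offset.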