tilelang.language.gemm
The language interface for tl programs.
Functions
- gemm: Perform a General Matrix Multiplication (GEMM) operation.
Module Contents
- tilelang.language.gemm.gemm(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, k_pack=1, wg_wait=0)
Perform a General Matrix Multiplication (GEMM) operation.
This function computes C = A @ B, where A and B can optionally be transposed via transpose_A and transpose_B. The policy argument selects the warp policy, and clear_accum controls whether the accumulator is cleared before the computation.
- Parameters:
A (Union[tir.Buffer, tir.Var]) – First input matrix
B (Union[tir.Buffer, tir.Var]) – Second input matrix
C (Union[tir.Buffer, tir.Var]) – Output matrix for results
transpose_A (bool, optional) – Whether to transpose matrix A. Defaults to False.
transpose_B (bool, optional) – Whether to transpose matrix B. Defaults to False.
policy (GemmWarpPolicy, optional) – Warp execution policy. Defaults to GemmWarpPolicy.Square.
clear_accum (bool, optional) – Whether to clear accumulator before computation. Defaults to False.
k_pack (int, optional) – Number of k dimensions packed into a single warp. Defaults to 1.
wg_wait (int, optional) – Warp group wait count. Defaults to 0.
- Returns:
A handle to the GEMM operation
- Return type:
tir.Call
- Raises:
AssertionError – If the K dimensions of matrices A and B don't match
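The transpose flags, the accumulator behaviour, and the K-dimension check described above can be illustrated with a small host-side reference model. This is only a sketch in plain Python, not TileLang code: `gemm_reference` and `transpose` are hypothetical helper names, the matrices are ordinary lists of lists rather than `tir.Buffer` objects, and it assumes that leaving `clear_accum` at False means the product is added to the existing contents of C. The `policy`, `k_pack`, and `wg_wait` parameters are omitted from the model.

```python
def transpose(m):
    """Return the transpose of a row-major list-of-lists matrix."""
    return [list(col) for col in zip(*m)]


def gemm_reference(A, B, C, transpose_A=False, transpose_B=False, clear_accum=False):
    """Hypothetical model of the documented semantics; updates C in place."""
    a = transpose(A) if transpose_A else A
    b = transpose(B) if transpose_B else B
    # Mirror the documented failure mode: the K dimensions must agree.
    assert len(a[0]) == len(b), "K dimensions of A and B don't match"
    m, k, n = len(a), len(b), len(b[0])
    assert len(C) == m and len(C[0]) == n, "C must have shape (M, N)"
    for i in range(m):
        for j in range(n):
            # ASSUMPTION: without clear_accum, the product is accumulated
            # onto whatever C already holds.
            acc = 0 if clear_accum else C[i][j]
            for p in range(k):
                acc += a[i][p] * b[p][j]
            C[i][j] = acc
    return C


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]

C = [[1, 1], [1, 1]]
gemm_reference(A, B, C)                    # adds A @ B to the existing ones
C_cleared = [[1, 1], [1, 1]]
gemm_reference(A, B, C_cleared, clear_accum=True)   # overwrites with A @ B
C_bt = [[0, 0], [0, 0]]
gemm_reference(A, B, C_bt, transpose_B=True)        # A @ B^T
```

Under this model, a caller that wants a plain C = A @ B either passes `clear_accum=True` or zeroes C beforehand, while a loop over K tiles would leave `clear_accum` at False so that partial products add up.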