tilelang.language.gemm

The language interface for tl programs.

Functions

gemm(A, B, C[, transpose_A, transpose_B, policy, ...])

Perform a General Matrix Multiplication (GEMM) operation.

Module Contents

tilelang.language.gemm.gemm(A, B, C, transpose_A=False, transpose_B=False, policy=GemmWarpPolicy.Square, clear_accum=False, k_pack=1, wg_wait=0)

Perform a General Matrix Multiplication (GEMM) operation.

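This function computes the matrix product of A and B, either of which can optionally be transposed (controlled by transpose_A and transpose_B), and writes it into C, which acts as the accumulator. By default the product is added to the existing contents of C; setting clear_accum=True zeroes C first. The policy argument controls how warps are assigned to the computation.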
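Element-wise, the update can be written as below. This assumes C acts as an accumulator, as the clear_accum parameter implies. Here op_A(A) is A, or its transpose when transpose_A is True (likewise op_B for B), so op_A(A) has shape M×K and op_B(B) has shape K×N.

```latex
C_{ij} \leftarrow \alpha \, C_{ij} + \sum_{k=0}^{K-1} \mathrm{op}_A(A)_{ik} \, \mathrm{op}_B(B)_{kj},
\qquad
\alpha =
\begin{cases}
0 & \text{if clear\_accum is True} \\
1 & \text{otherwise}
\end{cases}
```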

Parameters:
  • A (Union[tir.Buffer, tir.Var]) – First input matrix

  • B (Union[tir.Buffer, tir.Var]) – Second input matrix

  • C (Union[tir.Buffer, tir.Var]) – Output matrix for results

  • transpose_A (bool, optional) – Whether to transpose matrix A. Defaults to False.

  • transpose_B (bool, optional) – Whether to transpose matrix B. Defaults to False.

  • policy (GemmWarpPolicy, optional) – Warp execution policy. Defaults to GemmWarpPolicy.Square.

  • clear_accum (bool, optional) – Whether to clear accumulator before computation. Defaults to False.

  • k_pack (int, optional) – Number of k dimensions packed into a single warp. Defaults to 1.

  • wg_wait (int, optional) – Warp group wait count. Defaults to 0.
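To make the effect of transpose_A, transpose_B and clear_accum concrete, here is a plain-Python reference sketch of the arithmetic. It is not part of the tilelang API: `gemm_reference` and `transpose` are hypothetical helpers that run on the host with nested lists. They ignore policy, k_pack and wg_wait, which affect how the GPU schedules the work rather than the numerical result.

```python
from typing import List

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    # Swap rows and columns of a row-major nested list.
    return [list(row) for row in zip(*m)]


def gemm_reference(A: Matrix, B: Matrix, C: Matrix,
                   transpose_A: bool = False, transpose_B: bool = False,
                   clear_accum: bool = False) -> Matrix:
    """Hypothetical host-side reference for the gemm semantics (not tilelang API).

    Updates C in place with op(A) @ op(B), zeroing C first when clear_accum is True.
    """
    a = transpose(A) if transpose_A else A  # a is M x K
    b = transpose(B) if transpose_B else B  # b is K x N
    M, K = len(a), len(a[0])
    assert len(b) == K, "K dimensions of A and B must match"
    N = len(b[0])
    for i in range(M):
        for j in range(N):
            # Start from the existing value unless the accumulator is cleared.
            acc = 0.0 if clear_accum else C[i][j]
            for k in range(K):
                acc += a[i][k] * b[k][j]
            C[i][j] = acc
    return C
```

For example, with A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]] and C initialised to all ones, the default call yields [[20, 23], [44, 51]] (the product plus the old C). Passing clear_accum=True instead yields the bare product [[19, 22], [43, 50]].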

Returns:

A handle to the GEMM operation

Return type:

tir.Call

Raises:

AssertionError – If the K dimensions of matrices A and B don’t match (after accounting for transpose_A and transpose_B)
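The K-dimension check depends on the transpose flags. The sketch below shows one way to infer (M, N, K) from two 2-D shapes and reproduce that assertion. `gemm_shapes` is a hypothetical helper, not tilelang's own code, and the real implementation may index shapes differently (for example, by their last two dimensions).

```python
from typing import Sequence, Tuple


def gemm_shapes(a_shape: Sequence[int], b_shape: Sequence[int],
                transpose_A: bool = False,
                transpose_B: bool = False) -> Tuple[int, int, int]:
    """Hypothetical helper: infer (M, N, K) from 2-D buffer shapes.

    A is read as (M, K), or (K, M) when transpose_A is True;
    B is read as (K, N), or (N, K) when transpose_B is True.
    """
    M, K_a = (a_shape[1], a_shape[0]) if transpose_A else (a_shape[0], a_shape[1])
    K_b, N = (b_shape[1], b_shape[0]) if transpose_B else (b_shape[0], b_shape[1])
    # Mirrors the documented failure mode: mismatched K raises AssertionError.
    assert K_a == K_b, f"K mismatch: A gives {K_a}, B gives {K_b}"
    return M, N, K_a
```

For instance, a 128×32 A and a 32×64 B give (M, N, K) = (128, 64, 32). Storing both operands transposed, as 32×128 and 64×32 with both flags set, yields the same triple.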