tilelang.primitives.gemm.gemm_mma

Classes

GemmPrimitiveMMA

A GEMM (General Matrix Multiply) primitive that uses Tensor Core MMA (Matrix Multiply and Accumulate) instructions.

Module Contents

class tilelang.primitives.gemm.gemm_mma.GemmPrimitiveMMA

Bases: tilelang.primitives.gemm.base.GemmBaseParams

A GEMM (General Matrix Multiply) primitive that uses Tensor Core MMA (Matrix Multiply and Accumulate) instructions. Inherits from GemmBaseParams, which provides basic parameters such as the A, B, and C buffers and the transposition flags.
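As a point of reference, the arithmetic such a primitive targets can be written in plain Python. This is a sketch, not tilelang code: reference_gemm and its trans_A/trans_B arguments are hypothetical names, and it assumes the usual convention that a set transposition flag means the operand is stored transposed and that C accumulates op(A)·op(B).

```python
# Plain-Python reference for the arithmetic of a GEMM primitive:
#   C[i][j] += sum_k op(A)[i][k] * op(B)[k][j]
# where op(X) is X transposed when the matching flag is set.
# reference_gemm, trans_A and trans_B are illustrative names, not tilelang's API.

def transpose(X):
    """Return the transpose of a row-major list-of-lists matrix."""
    return [list(row) for row in zip(*X)]

def reference_gemm(A, B, C, trans_A=False, trans_B=False):
    """Accumulate op(A) @ op(B) into C in place and return C."""
    a = transpose(A) if trans_A else A
    b = transpose(B) if trans_B else B
    M, K = len(a), len(a[0])
    assert len(b) == K, "inner dimensions must match"
    N = len(b[0])
    for i in range(M):
        for j in range(N):
            acc = C[i][j]
            for k in range(K):
                acc += a[i][k] * b[k][j]
            C[i][j] = acc
    return C

A = [[1, 2], [3, 4]]
B_T = [[5, 7], [6, 8]]  # B stored transposed; logically B = [[5, 6], [7, 8]]
C = [[0, 0], [0, 0]]
reference_gemm(A, B_T, C, trans_B=True)
print(C)  # [[19, 22], [43, 50]]
```

Because C is updated in place rather than overwritten, calling the function with a nonzero C adds to its existing contents, matching the accumulate semantics described for the MMA methods.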

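abstract gemm_rrr(A, B, C, mma_emitter)
Parameters:
  • A, B, C, mma_emitter – The two input buffers, the accumulation buffer, and the Tensor Core instruction emitter, with the same roles as in gemm_ssr.

Return type: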

tvm.tir.PrimExpr

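gemm_rsr(A, B, C, mma_emitter)
Parameters:
  • A, B, C, mma_emitter – The two input buffers, the accumulation buffer, and the Tensor Core instruction emitter, with the same roles as in gemm_ssr.

Return type: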

tvm.tir.PrimExpr

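abstract gemm_srr(A, B, C, mma_emitter)
Parameters:
  • A, B, C, mma_emitter – The two input buffers, the accumulation buffer, and the Tensor Core instruction emitter, with the same roles as in gemm_ssr.

Return type: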

tvm.tir.PrimExpr

gemm_ssr(A, B, C, mma_emitter)

Performs a single-step reduction (SSR) GEMM using Tensor Core MMA primitives. It loads fragments of A and B from shared memory, multiplies them, and accumulates the products into C.

Parameters:
  • A (tir.Buffer) – The buffer for matrix A (in shared memory).

  • B (tir.Buffer) – The buffer for matrix B (in shared memory).

  • C (tir.Buffer) – The buffer for the accumulation results.

  • mma_emitter (TensorCoreIntrinEmitter) – A helper object responsible for generating Tensor Core MMA instructions (ldmatrix, mma, etc.).

Returns:

The generated IR expression (macro) representing the GEMM loop.

Return type:

tir.PrimExpr
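The load, multiply, and accumulate sequence described for gemm_ssr can be modeled with plain Python lists. In this hypothetical sketch, slicing stands in for ldmatrix and an inner-product update stands in for mma; micro_size_k and the helper names are assumptions, and the real method emits an IR expression rather than computing values.

```python
# Model of the load -> multiply -> accumulate loop described for gemm_ssr.
# Shared-memory tiles are plain lists; "fragments" are per-step slices
# (standing in for ldmatrix) and mma() stands in for the MMA instruction.
# micro_size_k is a hypothetical chunk size along K, not a documented parameter.

def load_fragment_a(A_shared, ki, micro_size_k):
    """Copy columns [ki*mk, (ki+1)*mk) of the M x K tile A_shared."""
    lo = ki * micro_size_k
    return [row[lo:lo + micro_size_k] for row in A_shared]

def load_fragment_b(B_shared, ki, micro_size_k):
    """Copy rows [ki*mk, (ki+1)*mk) of the K x N tile B_shared."""
    lo = ki * micro_size_k
    return [list(row) for row in B_shared[lo:lo + micro_size_k]]

def mma(A_frag, B_frag, C_local):
    """Accumulate A_frag @ B_frag into C_local in place."""
    for i, a_row in enumerate(A_frag):
        for j in range(len(C_local[0])):
            C_local[i][j] += sum(a * B_frag[k][j] for k, a in enumerate(a_row))

def gemm_ssr_sketch(A_shared, B_shared, C_local, micro_size_k=2):
    """Step over K in chunks: load both fragments, then accumulate into C."""
    block_K = len(B_shared)
    assert block_K % micro_size_k == 0, "K must be a multiple of the chunk size"
    for ki in range(block_K // micro_size_k):
        A_frag = load_fragment_a(A_shared, ki, micro_size_k)
        B_frag = load_fragment_b(B_shared, ki, micro_size_k)
        mma(A_frag, B_frag, C_local)
    return C_local

A = [[1, 2, 3, 4], [5, 6, 7, 8]]
B = [[1, 0], [0, 1], [1, 0], [0, 1]]
print(gemm_ssr_sketch(A, B, [[0, 0], [0, 0]]))  # [[4, 6], [12, 14]]
```

The chunk size changes only how many steps the loop takes, not the result, since each step adds a disjoint slice of the K-dimension sum into C.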

invoke()

Entry point that generates an SSR (single-step reduction) GEMM using Tensor Core instructions. It performs the following steps:

  1. Infers block partition parameters if necessary.

  2. Creates a TensorCoreIntrinEmitter with the correct data types and dimensions.

  3. Invokes the GEMM SSR function to generate the final IR expression.

Returns:

The generated GEMM IR expression.

Return type:

tir.PrimExpr
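The three steps of invoke() can be sketched as ordinary Python control flow. Everything here is hypothetical: GemmSketch, EmitterConfig, and the warp-split heuristic in infer_block_partition are illustrative stand-ins, not the logic tilelang actually uses, and gemm_ssr returns a string instead of an IR expression.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class EmitterConfig:
    """Hypothetical stand-in for the settings passed to TensorCoreIntrinEmitter."""
    a_dtype: str
    b_dtype: str
    accum_dtype: str
    block_row_warps: int
    block_col_warps: int

@dataclass
class GemmSketch:
    in_dtype: str
    accum_dtype: str
    num_warps: int
    block_row_warps: Optional[int] = None
    block_col_warps: Optional[int] = None

    def infer_block_partition(self):
        # Hypothetical heuristic: fill in any missing warp split so that
        # block_row_warps * block_col_warps == num_warps.
        if self.block_row_warps is None and self.block_col_warps is None:
            rows = 1
            while (2 * rows) ** 2 <= self.num_warps:
                rows *= 2
            self.block_row_warps = rows
        if self.block_row_warps is None:
            self.block_row_warps = self.num_warps // self.block_col_warps
        if self.block_col_warps is None:
            self.block_col_warps = self.num_warps // self.block_row_warps
        assert self.block_row_warps * self.block_col_warps == self.num_warps

    def gemm_ssr(self, emitter):
        # Placeholder for the IR-generating step.
        return f"ssr gemm with {emitter.block_row_warps}x{emitter.block_col_warps} warps"

    def invoke(self):
        self.infer_block_partition()                       # step 1
        emitter = EmitterConfig(                           # step 2
            a_dtype=self.in_dtype,
            b_dtype=self.in_dtype,
            accum_dtype=self.accum_dtype,
            block_row_warps=self.block_row_warps,
            block_col_warps=self.block_col_warps,
        )
        return self.gemm_ssr(emitter)                      # step 3

print(GemmSketch("float16", "float32", num_warps=8).invoke())  # ssr gemm with 2x4 warps
```

Explicitly supplied partition values are kept, and only the missing ones are derived, which mirrors the "if necessary" qualifier in step 1.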

property in_dtype: str

Returns the input data type of A and B, which are assumed to share the same dtype.

Raises:

AssertionError – If A and B do not share the same dtype.

Return type:

str
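The dtype check behind in_dtype amounts to a guarded property. In this sketch, Buffer and DtypeParams are hypothetical stand-ins (Buffer carries only a dtype string in place of a real tir.Buffer).

```python
from dataclasses import dataclass

@dataclass
class Buffer:
    """Hypothetical stand-in for tir.Buffer; only the dtype matters here."""
    name: str
    dtype: str

@dataclass
class DtypeParams:
    A: Buffer
    B: Buffer

    @property
    def in_dtype(self) -> str:
        # Documented contract: A and B are assumed to share one dtype.
        assert self.A.dtype == self.B.dtype, (
            f"A and B dtypes differ: {self.A.dtype} vs {self.B.dtype}")
        return self.A.dtype

p = DtypeParams(A=Buffer("A", "float16"), B=Buffer("B", "float16"))
print(p.in_dtype)  # float16
```

Accessing in_dtype on a mismatched pair, such as float16 with bfloat16, raises AssertionError, as in the Raises entry above.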

property accum_dtype: str

Returns the accumulation data type of C.

Return type:

str
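Keeping accum_dtype separate from in_dtype matters because Tensor Core GEMMs commonly read float16 inputs but accumulate in float32. The stdlib-only sketch below (not tilelang code) models a float16 accumulator by rounding after every addition, and shows the running sum stalling at 2048 once the spacing between representable float16 values reaches 2.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]

def dot(xs, ys, accumulate_in_fp16: bool) -> float:
    """Dot product whose running sum is optionally rounded to fp16 each step."""
    acc = 0.0
    for x, y in zip(xs, ys):
        acc += x * y
        if accumulate_in_fp16:
            # Rounding after each add models a half-precision accumulator.
            acc = to_fp16(acc)
    return acc

ones = [to_fp16(1.0)] * 4096
print(dot(ones, ones, accumulate_in_fp16=True))   # 2048.0
print(dot(ones, ones, accumulate_in_fp16=False))  # 4096.0
```

Past 2048, each 2048 + 1 is a tie between 2048 and 2050 and rounds to the even value 2048, so the half-precision sum never advances; the wider accumulator keeps the exact total.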