tilelang.primitives.gemm.base¶

Classes¶

GemmWarpPolicy

Enumeration for GEMM Warp Partitioning Policies.

GemmBaseParams

Module Contents¶

class tilelang.primitives.gemm.base.GemmWarpPolicy¶

Bases: enum.IntEnum

Enumeration for GEMM Warp Partitioning Policies.

Square = 0¶
FullRow = 1¶
FullCol = 2¶
is_square()¶

Check if the policy is a square partitioning.

Returns:

True if the policy is square, False otherwise.

Return type:

bool

is_full_row()¶

Check if the policy is a full row partitioning.

Returns:

True if the policy is full row, False otherwise.

Return type:

bool

is_full_col()¶

Check if the policy is a full column partitioning.

Returns:

True if the policy is full column, False otherwise.

Return type:

bool

static to_prime_factors(num)¶

Compute the prime factorization of a given number.

Parameters:

num (int) – The number to factorize.

Returns:

A list of prime factors of the number.

Return type:

list
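
As an illustration of what such a factorization looks like, here is a minimal trial-division sketch. It is not taken from the TileLang source: `prime_factors_sketch` is a hypothetical name, and returning the factors in ascending order with multiplicity is an assumption, since the documentation only promises "a list of prime factors".

```python
from typing import List


def prime_factors_sketch(num: int) -> List[int]:
    """Return the prime factors of num in ascending order, with multiplicity."""
    factors: List[int] = []
    divisor = 2
    # Only divisors up to sqrt(num) need to be tried.
    while divisor * divisor <= num:
        while num % divisor == 0:
            factors.append(divisor)
            num //= divisor
        divisor += 1
    # Whatever remains above 1 is itself prime.
    if num > 1:
        factors.append(num)
    return factors


print(prime_factors_sketch(12))
```

For warp counts, which are usually powers of two, a list like this is just a run of 2s that can be handed out one factor at a time to the row or column dimension.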

compute_warp_partition(M, N, num_warps)¶

Compute the warp partition (m_warp, n_warp) based on the given policy.

Parameters:
  • M (int) – The number of rows in the GEMM workload.

  • N (int) – The number of columns in the GEMM workload.

  • num_warps (int) – The total number of warps available.

Returns:

A tuple (m_warp, n_warp) representing the partitioning of warps.

Return type:

tuple

Raises:
  • ValueError – If the policy is invalid or the partitioning fails.

  • AssertionError – If M or N is not divisible by the required factor for FullRow or FullCol policies.
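
The behaviour described above can be sketched as follows. This is a hedged illustration rather than the library's implementation: `WarpPolicySketch` and `warp_partition_sketch` are hypothetical stand-ins so the snippet runs without TileLang, the "required factor" for FullRow and FullCol is assumed to be `num_warps`, and the greedy balancing used for Square is only one plausible strategy.

```python
from enum import IntEnum
from typing import Tuple


class WarpPolicySketch(IntEnum):
    """Hypothetical stand-in mirroring the documented member values."""

    Square = 0
    FullRow = 1
    FullCol = 2


def warp_partition_sketch(policy: int, M: int, N: int, num_warps: int) -> Tuple[int, int]:
    """Return (m_warp, n_warp) with m_warp * n_warp == num_warps."""
    policy = WarpPolicySketch(policy)  # raises ValueError for an unknown policy
    if policy is WarpPolicySketch.FullRow:
        # Assumption: every warp is placed along the row dimension.
        assert M % num_warps == 0, "M must be divisible by num_warps"
        return num_warps, 1
    if policy is WarpPolicySketch.FullCol:
        # Assumption: every warp is placed along the column dimension.
        assert N % num_warps == 0, "N must be divisible by num_warps"
        return 1, num_warps

    # Square: hand out the prime factors of num_warps one at a time,
    # giving each to the dimension that currently has the larger per-warp extent.
    m_warp, n_warp = 1, 1
    remaining, factor = num_warps, 2
    while remaining > 1:
        if remaining % factor != 0:
            factor += 1
            continue
        remaining //= factor
        m_ok = M % (m_warp * factor) == 0
        n_ok = N % (n_warp * factor) == 0
        if m_ok and (not n_ok or M // m_warp >= N // n_warp):
            m_warp *= factor
        elif n_ok:
            n_warp *= factor
        else:
            raise ValueError("cannot partition the workload evenly across warps")
    return m_warp, n_warp


print(warp_partition_sketch(WarpPolicySketch.Square, 128, 128, 4))
```

Under these assumptions, a wide workload such as M=64, N=256 with 8 warps ends up with more warps along the columns than along the rows.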

classmethod from_warp_partition(m_warp, n_warp)¶

Determine the warp policy based on the given warp partitioning.

Parameters:
  • m_warp (int) – Number of warps in the row dimension.

  • n_warp (int) – Number of warps in the column dimension.

Returns:

The corresponding warp policy.

Return type:

GemmWarpPolicy

Examples

>>> GemmWarpPolicy.from_warp_partition(4, 1)  # All warps in rows
GemmWarpPolicy.FullRow
>>> GemmWarpPolicy.from_warp_partition(1, 4)  # All warps in columns
GemmWarpPolicy.FullCol
>>> GemmWarpPolicy.from_warp_partition(2, 2)  # Balanced distribution
GemmWarpPolicy.Square
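
A classification rule consistent with the three examples above can be sketched as follows. The names `WarpPolicySketch` and `policy_from_partition_sketch` are hypothetical stand-ins, and treating every partition that is not a single row or single column of warps (including 1 x 1) as Square is an assumption, not something this page states.

```python
from enum import IntEnum


class WarpPolicySketch(IntEnum):
    """Hypothetical stand-in mirroring the documented member values."""

    Square = 0
    FullRow = 1
    FullCol = 2


def policy_from_partition_sketch(m_warp: int, n_warp: int) -> WarpPolicySketch:
    """Map a warp grid (m_warp, n_warp) to a policy."""
    if n_warp == 1 and m_warp > 1:
        return WarpPolicySketch.FullRow  # all warps along the row dimension
    if m_warp == 1 and n_warp > 1:
        return WarpPolicySketch.FullCol  # all warps along the column dimension
    return WarpPolicySketch.Square  # assumed fallback for everything else


print(policy_from_partition_sketch(4, 1).name)
```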
class tilelang.primitives.gemm.base.GemmBaseParams¶
A: tvm.tir.Buffer¶
B: tvm.tir.Buffer¶
C: tvm.tir.Buffer¶
transpose_A: bool = False¶
transpose_B: bool = False¶
block_row_warps: int | None = None¶
block_col_warps: int | None = None¶
warp_row_tiles: int | None = None¶
warp_col_tiles: int | None = None¶
chunk: int | None = None¶
policy: GemmWarpPolicy¶
k_pack: int = 1¶
get_warp_size()¶
Return type:

int

params_as_dict()¶
infer_block_partition(threads)¶

Infer and set block partition parameters (e.g., block_row_warps, block_col_warps, warp_row_tiles, warp_col_tiles, chunk) from the shapes of A and B. Any of these parameters that is not already specified is inferred from the given threads.

Parameters:

threads (Optional[int]) – The total number of threads in a block. Must be provided if any block partition parameter is not already set.

Raises:

AssertionError – If threads is None while any block partition parameter is missing, or if the shapes of A and B are inconsistent for GEMM.

Return type:

None
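
The precondition on threads can be pictured with a small stand-in. This sketch covers only the documented check, not the shape-based inference itself; `PartitionSketch`, `warps_to_distribute`, and the warp size of 32 are assumptions introduced for illustration (the real value would come from get_warp_size()).

```python
from dataclasses import dataclass, fields
from typing import Optional

WARP_SIZE_ASSUMED = 32  # assumption: stands in for the value of get_warp_size()


@dataclass
class PartitionSketch:
    """Hypothetical stand-in holding only the optional partition fields."""

    block_row_warps: Optional[int] = None
    block_col_warps: Optional[int] = None
    warp_row_tiles: Optional[int] = None
    warp_col_tiles: Optional[int] = None
    chunk: Optional[int] = None


def warps_to_distribute(params: PartitionSketch, threads: Optional[int]) -> Optional[int]:
    """Return the warp count available for inference, or None if nothing is missing."""
    missing = [f.name for f in fields(params) if getattr(params, f.name) is None]
    if not missing:
        return None  # everything was given explicitly, so threads is not needed
    assert threads is not None, f"threads is required to infer {missing}"
    assert threads % WARP_SIZE_ASSUMED == 0, "threads must be a whole number of warps"
    return threads // WARP_SIZE_ASSUMED


print(warps_to_distribute(PartitionSketch(), threads=128))
```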

property class_attributes¶
__repr__()¶
Return type:

str