tilelang.primitives.gemm.base
=============================

.. py:module:: tilelang.primitives.gemm.base


Classes
-------

.. autoapisummary::

   tilelang.primitives.gemm.base.GemmWarpPolicy
   tilelang.primitives.gemm.base.GemmBaseParams


Module Contents
---------------

.. py:class:: GemmWarpPolicy

   Bases: :py:obj:`enum.IntEnum`


   Enumeration of GEMM warp partitioning policies.


   .. py:attribute:: Square
      :value: 0


   .. py:attribute:: FullRow
      :value: 1


   .. py:attribute:: FullCol
      :value: 2


   .. py:method:: is_square()

      Check whether the policy is a square partitioning.

      :returns: True if the policy is square, False otherwise.
      :rtype: bool


   .. py:method:: is_full_row()

      Check whether the policy is a full-row partitioning.

      :returns: True if the policy is full row, False otherwise.
      :rtype: bool


   .. py:method:: is_full_col()

      Check whether the policy is a full-column partitioning.

      :returns: True if the policy is full column, False otherwise.
      :rtype: bool


   .. py:method:: to_prime_factors(num)
      :staticmethod:


      Compute the prime factorization of a given number.

      :param num: The number to factorize.
      :type num: int

      :returns: A list of prime factors of the number.
      :rtype: list


   .. py:method:: compute_warp_partition(M, N, num_warps)

      Compute the warp partition (m_warp, n_warp) based on the given policy.

      :param M: The number of rows in the GEMM workload.
      :type M: int
      :param N: The number of columns in the GEMM workload.
      :type N: int
      :param num_warps: The total number of warps available.
      :type num_warps: int

      :returns: A tuple (m_warp, n_warp) representing the partitioning of warps.
      :rtype: tuple

      :raises ValueError: If the policy is invalid or the partitioning fails.
      :raises AssertionError: If M or N is not divisible by the required factor for FullRow or FullCol policies.


   .. py:method:: from_warp_partition(m_warp, n_warp)
      :classmethod:


      Determine the warp policy based on the given warp partitioning.

      :param m_warp: Number of warps in the row dimension.
      :type m_warp: int
      :param n_warp: Number of warps in the column dimension.
      :type n_warp: int

      :returns: The corresponding warp policy.
      :rtype: GemmWarpPolicy

      .. rubric:: Examples

      >>> GemmWarpPolicy.from_warp_partition(4, 1)  # All warps in rows
      GemmWarpPolicy.FullRow

      >>> GemmWarpPolicy.from_warp_partition(1, 4)  # All warps in columns
      GemmWarpPolicy.FullCol

      >>> GemmWarpPolicy.from_warp_partition(2, 2)  # Balanced distribution
      GemmWarpPolicy.Square


.. py:class:: GemmBaseParams

   .. py:attribute:: A
      :type: tvm.tir.Buffer


   .. py:attribute:: B
      :type: tvm.tir.Buffer


   .. py:attribute:: C
      :type: tvm.tir.Buffer


   .. py:attribute:: transpose_A
      :type: bool
      :value: False


   .. py:attribute:: transpose_B
      :type: bool
      :value: False


   .. py:attribute:: block_row_warps
      :type: Optional[int]
      :value: None


   .. py:attribute:: block_col_warps
      :type: Optional[int]
      :value: None


   .. py:attribute:: warp_row_tiles
      :type: Optional[int]
      :value: None


   .. py:attribute:: warp_col_tiles
      :type: Optional[int]
      :value: None


   .. py:attribute:: chunk
      :type: Optional[int]
      :value: None


   .. py:attribute:: policy
      :type: GemmWarpPolicy


   .. py:attribute:: k_pack
      :type: int
      :value: 1


   .. py:method:: get_warp_size()


   .. py:method:: params_as_dict()


   .. py:method:: infer_block_partition(threads)

      Infer and set the block partition parameters (``block_row_warps``,
      ``block_col_warps``, ``warp_row_tiles``, ``warp_col_tiles``, ``chunk``)
      from the shapes of A and B. Parameters that are already specified are
      kept; any that are missing are inferred from the given ``threads``.

      :param threads: The total number of threads in a block. Must be provided
                      if any block partition parameter is not already set.
      :type threads: Optional[int]

      :raises AssertionError: If ``threads`` is None but any block partition parameter is missing,
          or if A and B have inconsistent shapes for GEMM.


   .. py:property:: class_attributes


   .. py:method:: __repr__()
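
The documented behavior of ``GemmWarpPolicy`` can be illustrated with a small standalone sketch. It does not import tilelang and is not the library's implementation: ``WarpPolicySketch`` is a hypothetical stand-in, the trial-division factorization is one possible way to realize ``to_prime_factors``, and the classification rule in ``from_warp_partition`` is an assumption chosen only to reproduce the three documented examples.

```python
from enum import IntEnum


class WarpPolicySketch(IntEnum):
    """Hypothetical stand-in for GemmWarpPolicy (not the tilelang code)."""

    Square = 0
    FullRow = 1
    FullCol = 2

    def is_square(self) -> bool:
        return self is WarpPolicySketch.Square

    def is_full_row(self) -> bool:
        return self is WarpPolicySketch.FullRow

    def is_full_col(self) -> bool:
        return self is WarpPolicySketch.FullCol

    @staticmethod
    def to_prime_factors(num: int) -> list:
        """Return the prime factors of num in ascending order (trial division)."""
        factors = []
        divisor = 2
        while divisor * divisor <= num:
            while num % divisor == 0:
                factors.append(divisor)
                num //= divisor
            divisor += 1
        if num > 1:
            # Whatever remains after trial division is itself prime.
            factors.append(num)
        return factors

    @classmethod
    def from_warp_partition(cls, m_warp: int, n_warp: int) -> "WarpPolicySketch":
        # Assumed rule: all warps along rows -> FullRow, all warps along
        # columns -> FullCol, anything else -> Square.
        if n_warp == 1 and m_warp > 1:
            return cls.FullRow
        if m_warp == 1 and n_warp > 1:
            return cls.FullCol
        return cls.Square


if __name__ == "__main__":
    print(WarpPolicySketch.from_warp_partition(4, 1).is_full_row())  # True
    print(WarpPolicySketch.to_prime_factors(12))  # [2, 2, 3]
```

How the real class treats partitions outside the documented examples, such as ``(2, 4)``, is not stated in this reference, so the fallback to ``Square`` above should be read as a guess.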