tilelang.carver.roller.policy.tensorcore
========================================

.. py:module:: tilelang.carver.roller.policy.tensorcore

.. autoapi-nested-parse::

   Policy for tensorcore schedule


Attributes
----------

.. autoapisummary::

   tilelang.carver.roller.policy.tensorcore.logger


Classes
-------

.. autoapisummary::

   tilelang.carver.roller.policy.tensorcore.TensorCorePolicy


Module Contents
---------------

.. py:data:: logger

.. py:class:: TensorCorePolicy(arch, tags = None)

   Bases: :py:obj:`tilelang.carver.roller.policy.default.DefaultPolicy`

   Default policy for fastdlight: a heuristic plan that tries to minimize
   memory traffic and maximize parallelism for the BitBLAS schedule.


   .. py:attribute:: wmma_k
      :type:  int
      :value: 16


   .. py:attribute:: pipeline_stage
      :type:  int
      :value: 1


   .. py:attribute:: use_async_copy
      :type:  bool
      :value: False


   .. py:attribute:: block_reduction_depth
      :type:  Optional[int]
      :value: None


   .. py:method:: infer_node_smem_usage(td, node)

      Infers the shared memory usage of a node given a TileDict configuration.

      :param td: The TileDict object containing the tile configuration.
      :type td: TileDict
      :param node: The node for which to infer the shared memory usage.
      :type node: PrimFuncNode

      :returns: The estimated amount of shared memory used by the node.
      :rtype: int


   .. py:method:: get_node_reduce_step_candidates(node)

      Calculates reduction step candidates for each reduction axis in a PrimFuncNode.

      General idea: use factors of the extent first, since they do not
      require an extra boundary check. For a large prime extent, which is a
      rare case, use powers of 2 instead.

      :param node: The node for which to calculate reduction step candidates.
                   It contains reduction axes (raxis) with their domains (dom.extent).
      :type node: PrimFuncNode

      :returns: A dictionary mapping axis variable names to lists of step candidates.
                For each axis in the node, this function calculates possible step sizes.
                For axes with a large prime domain, it uses powers of 2 as step candidates;
                for others, it uses all factors of the domain.
      :rtype: Dict[str, List[int]]


   .. py:method:: check_tile_shape_isvalid(td)

      Checks whether the tile shapes in the TileDict are valid for the nodes in this context.

      :param td: The TileDict object containing tile shapes and other configurations.
      :type td: TileDict

      :returns: True if all tile shapes are valid, False otherwise.
      :rtype: bool


   .. py:method:: compute_node_stride_map(node, td)

      Computes the stride map for a given node based on the TileDict configuration.

      :param node: The node for which to compute the stride map.
      :type node: PrimFuncNode
      :param td: The TileDict object containing the tile configuration.
      :type td: TileDict

      :returns: A tuple of dictionaries containing the output strides and tensor strides.
      :rtype: Tuple[Dict, Dict]


   .. py:method:: plan_rasterization(td)

      Plans the rasterization for the given TileDict.
      This function is not implemented yet.

      :param td: The TileDict object to plan rasterization for.
      :type td: TileDict

      :raises RasterRationPlan: This function is not implemented yet.
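
The candidate-selection rule described for ``get_node_reduce_step_candidates`` (all factors of the extent, falling back to powers of 2 for a large prime extent) can be illustrated with a small standalone sketch. It is not the library's implementation: the helper names ``all_factors`` and ``reduce_step_candidates``, the plain ``dict`` of axis extents standing in for ``node.raxis``, and the ``prime_threshold`` value of 64 are all assumptions made for illustration.

```python
import math
from typing import Dict, List


def all_factors(n: int) -> List[int]:
    """Return every positive divisor of n in ascending order."""
    small = [d for d in range(1, math.isqrt(n) + 1) if n % d == 0]
    large = [n // d for d in reversed(small) if d * d != n]
    return small + large


def reduce_step_candidates(
    extents: Dict[str, int],
    prime_threshold: int = 64,  # assumed cutoff for a "large" prime
) -> Dict[str, List[int]]:
    """Map each reduction axis name to its list of step candidates.

    `extents` is a hypothetical stand-in for the reduction axes of a
    node: axis name -> integer extent.
    """
    candidates: Dict[str, List[int]] = {}
    for name, extent in extents.items():
        factors = all_factors(extent)
        # Exactly two divisors (1 and the extent itself) means the extent is prime.
        if len(factors) == 2 and extent > prime_threshold:
            # Large prime: no useful factors, so use powers of 2 below the
            # extent. These steps do not divide the extent evenly, so the
            # last iteration needs a boundary check.
            steps = []
            power = 1
            while power < extent:
                steps.append(power)
                power *= 2
            candidates[name] = steps
        else:
            # Factors divide the extent evenly, so no boundary check is needed.
            candidates[name] = factors
    return candidates
```

For example, ``reduce_step_candidates({"k": 12, "r": 127})`` yields the factors of 12 for ``k`` and the powers of 2 below 127 for ``r``, since 127 is prime and above the assumed threshold.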