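tilelang.jit.adapter.utils
Classes
- TMADescriptorParams – Parsed TMA descriptor parameters.
Functions
- match_global_kernel
- match_declare_kernel
- match_declare_kernel_cpu
- is_cuda_target
- is_hip_target
- is_cpu_target
- is_metal_target
- get_annotated_mod
- pythonic_expr – Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
- maybe_desc_name – Check if a parameter name corresponds to a TMA descriptor.
- parse_function_call_args – Parse function call arguments from a kernel declaration.
- parse_tma_descriptor_args – Parse TMA descriptor arguments into structured parameters.
Module Contents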
- tilelang.jit.adapter.utils.match_global_kernel(source, annotation='__global__')
- Parameters:
source (str)
annotation (str)
- Return type:
int
- tilelang.jit.adapter.utils.match_declare_kernel(source, annotation='__global__')
- Parameters:
source (str)
annotation (str)
- Return type:
int
- tilelang.jit.adapter.utils.match_declare_kernel_cpu(source, annotation='int32_t')
- Parameters:
source (str)
annotation (str)
- Return type:
int
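The three match_* helpers take generated source text plus an annotation (`'__global__'` for GPU kernels, `'int32_t'` for the CPU variant) and return an int. The sketch below assumes that int is the character offset where the first annotated declaration begins. `find_declaration_offset` and its single-line regex are hypothetical and are not tilelang's own matching logic.

```python
import re


def find_declaration_offset(source: str, annotation: str = "__global__") -> int:
    """Hypothetical helper, not tilelang's implementation.

    Returns the character offset of the first `annotation` that starts a
    single-line declaration of the form `<annotation> <tokens...> <name>(`.
    """
    # Stay on one line ([ \t] instead of \s) so a stray mention of the
    # annotation in a comment cannot swallow the next line's declaration.
    pattern = re.compile(re.escape(annotation) + r"[ \t]+[\w \t*]*?\w+[ \t]*\(")
    match = pattern.search(source)
    if match is None:
        raise ValueError(f"no declaration annotated with {annotation!r} found")
    return match.start()


cuda_src = 'extern "C" __global__ void main_kernel(float* A) {}'
cpu_src = "int32_t main_kernel(void* args) { return 0; }"
print(find_declaration_offset(cuda_src))            # 11
print(find_declaration_offset(cpu_src, "int32_t"))  # 0
```

Passing the annotation as a parameter lets one routine serve both the CUDA-style and the C-style declarations, mirroring the shared `(source, annotation)` signature of the documented functions.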
- tilelang.jit.adapter.utils.is_cuda_target(target)
- Parameters:
target (tvm.target.Target)
- Return type:
bool
- tilelang.jit.adapter.utils.is_hip_target(target)
- Parameters:
target (tvm.target.Target)
- Return type:
bool
- tilelang.jit.adapter.utils.is_cpu_target(target)
- Parameters:
target (tvm.target.Target)
- Return type:
bool
- tilelang.jit.adapter.utils.is_metal_target(target)
- Parameters:
target (tvm.target.Target)
- Return type:
bool
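Each is_*_target predicate inspects a `tvm.target.Target`. The following is a minimal sketch of such a dispatch, assuming the checks compare the target's kind name. `Target.kind.name` is a real TVM attribute; for example, it is `"cuda"` for `Target("cuda")`. The helper `target_family`, the kind-to-family table, and the choice of which kind names count as CPU are all hypothetical. A `SimpleNamespace` stands in for a Target so the sketch runs without TVM installed.

```python
from types import SimpleNamespace

# ASSUMED mapping from TVM target kind names to backend families (hypothetical).
_KIND_TO_FAMILY = {
    "cuda": "cuda",
    "hip": "hip",
    "c": "cpu",
    "llvm": "cpu",
    "metal": "metal",
}


def target_family(target) -> str:
    """Classify anything exposing `.kind.name`, as tvm.target.Target does."""
    kind = target.kind.name
    try:
        return _KIND_TO_FAMILY[kind]
    except KeyError:
        raise ValueError(f"unsupported target kind: {kind!r}") from None


def fake_target(kind_name: str):
    # Stand-in for tvm.target.Target(kind_name); only `.kind.name` is modeled.
    return SimpleNamespace(kind=SimpleNamespace(name=kind_name))


print(target_family(fake_target("cuda")))   # cuda
print(target_family(fake_target("metal")))  # metal
```

With TVM available, the same helper would accept a real `tvm.target.Target`, because it only reads `.kind.name`.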
- tilelang.jit.adapter.utils.get_annotated_mod(func_or_mod, target='auto', target_host=None, model_type='all')
- Parameters:
func_or_mod (tvm.tir.PrimFunc | tilelang.tvm.IRModule)
target (str | tvm.target.Target)
target_host (str | tvm.target.Target | None)
model_type (Literal['device', 'host', 'all'])
- Return type:
tvm.IRModule | tuple[tvm.IRModule, tvm.IRModule]
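The return annotation suggests that model_type selects a single IRModule (`'device'` or `'host'`) or a pair (`'all'`). The sketch below shows that dispatch under two assumptions: `'all'` yields `(host, device)` in that order, and plain strings stand in for IRModule objects. `select_modules` is a hypothetical helper, and the tuple order should be checked against tilelang's source.

```python
from typing import Literal, Tuple, TypeVar, Union

M = TypeVar("M")  # stands in for tvm.IRModule


def select_modules(
    host_mod: M,
    device_mod: M,
    model_type: Literal["device", "host", "all"] = "all",
) -> Union[M, Tuple[M, M]]:
    """Hypothetical dispatch mirroring get_annotated_mod's return annotation.

    ASSUMPTION: 'all' returns (host, device); verify the real order in tilelang.
    """
    if model_type == "device":
        return device_mod
    if model_type == "host":
        return host_mod
    if model_type == "all":
        return host_mod, device_mod
    raise ValueError(f"model_type must be 'device', 'host' or 'all', got {model_type!r}")


print(select_modules("host_ir", "device_ir", "device"))  # device_ir
print(select_modules("host_ir", "device_ir"))            # ('host_ir', 'device_ir')
```

Because the result type varies with model_type, callers that need both modules should request `'all'` explicitly and unpack the pair.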
- tilelang.jit.adapter.utils.pythonic_expr(expr, dtype_map=None, ignore_cast=False)
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
- Parameters:
expr (tilelang.tvm.tir.PrimExpr) – The TVM PrimExpr to convert.
dtype_map (dict[str, str] | None) – A dictionary mapping data types to their string representations.
ignore_cast (bool) – If True, omit cast operators and render only the value being cast.
- Returns:
A string representation of the expression.
- Return type:
str
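Precedence-aware printing means a sub-expression gets parentheses only when it binds more loosely than its parent operator. For left-associative operators, a right operand of equal precedence also gets them. The self-contained sketch below applies that rule to a hypothetical mini expression tree. `Var`, `BinOp`, `Cast`, `to_python`, and the precedence table are illustrative, not TVM's PrimExpr classes. Rendering a cast as `dtype_map[dtype](value)` is an assumption about how a dtype_map might be used.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Union


# Hypothetical mini expression tree standing in for TVM PrimExpr nodes.
@dataclass
class Var:
    name: str


@dataclass
class BinOp:
    op: str  # a key of PRECEDENCE
    lhs: "Expr"
    rhs: "Expr"


@dataclass
class Cast:
    dtype: str
    value: "Expr"


Expr = Union[Var, BinOp, Cast]

# Higher numbers bind tighter; every operator is treated as left-associative.
PRECEDENCE = {"+": 1, "-": 1, "*": 2, "/": 2, "//": 2, "%": 2}
ATOM = 3  # names and call-style casts never need parentheses


def _prec(e: Expr, ignore_cast: bool) -> int:
    if isinstance(e, BinOp):
        return PRECEDENCE[e.op]
    if isinstance(e, Cast) and ignore_cast:
        return _prec(e.value, ignore_cast)  # the cast disappears, so its operand decides
    return ATOM


def to_python(e: Expr, dtype_map: Optional[Dict[str, str]] = None, ignore_cast: bool = False) -> str:
    if isinstance(e, Var):
        return e.name
    if isinstance(e, Cast):
        inner = to_python(e.value, dtype_map, ignore_cast)
        if ignore_cast:
            return inner
        name = (dtype_map or {}).get(e.dtype, e.dtype)
        return f"{name}({inner})"
    parent = PRECEDENCE[e.op]
    lhs = to_python(e.lhs, dtype_map, ignore_cast)
    rhs = to_python(e.rhs, dtype_map, ignore_cast)
    if _prec(e.lhs, ignore_cast) < parent:  # looser left operand: (a + b) * c
        lhs = f"({lhs})"
    if _prec(e.rhs, ignore_cast) <= parent:  # equal precedence on the right too: a - (b - c)
        rhs = f"({rhs})"
    return f"{lhs} {e.op} {rhs}"


a, b, c = Var("a"), Var("b"), Var("c")
print(to_python(BinOp("*", BinOp("+", a, b), c)))  # (a + b) * c
print(to_python(BinOp("-", a, BinOp("-", b, c))))  # a - (b - c)
```

Note that when ignore_cast drops a cast, the precedence of the cast's operand takes over. Otherwise `int(a + b) * c` would silently become `a + b * c`.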
- tilelang.jit.adapter.utils.maybe_desc_name(name, matches, i, desc_name_map=None)
Check if a parameter name corresponds to a TMA descriptor.
- Parameters:
name (str) – The parameter name to check.
matches (list[str]) – List of all matched parameter names.
i (int) – Index of the current match.
desc_name_map (dict[str, str] | None) – Optional mapping to store descriptor name relationships.
- Returns:
True if the parameter is a TMA descriptor.
- Return type:
bool
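The following is a minimal sketch of such a check, showing how the optional desc_name_map can act as an output parameter. It assumes, hypothetically, that `name` is an original kernel parameter such as `A` and that its TMA descriptor appears in `matches` as `<name>_desc`. That naming rule and the helper `looks_like_desc` are illustrative, not tilelang's actual logic.

```python
from typing import Dict, List, Optional


def looks_like_desc(
    name: str,
    matches: List[str],
    i: int,
    desc_name_map: Optional[Dict[str, str]] = None,
) -> bool:
    """Hypothetical check: is matches[i] the TMA descriptor for parameter `name`?

    ASSUMPTION: descriptors are declared as `<name>_desc`; the real rule may differ.
    """
    match = matches[i]
    is_desc = match == f"{name}_desc"
    if is_desc and desc_name_map is not None:
        # Record descriptor -> original parameter so later passes can look it up.
        desc_name_map[match] = name
    return is_desc


names = {}
print(looks_like_desc("A", ["A_desc", "B"], 0, names), names)  # True {'A_desc': 'A'}
```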
- tilelang.jit.adapter.utils.parse_function_call_args(declaration, function_args, function_params, desc_name_map=None, desc_name_var_map=None, transform_arg=None)
Parse function call arguments from a kernel declaration.
- Parameters:
declaration (str) – The kernel function declaration string.
function_args (list[dict[str, str]]) – List of function argument specifications.
function_params (list[Any]) – List of function parameters from TVM IR.
desc_name_map (dict[str, str] | None) – Optional mapping for descriptor names.
desc_name_var_map (dict[str, tilelang.tvm.tir.Var] | None) – Optional mapping from descriptor names to TVM variables.
transform_arg (Callable[[str, str], Any] | None) – Optional function to transform each argument (name, type) -> result.
- Returns:
List of parsed call arguments.
- Return type:
list[Any]
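The sketch below covers only the declaration-parsing step. It extracts the last parenthesized parameter list, splits it into (name, type) pairs, and passes each pair through an optional `transform_arg(name, type)` callback, as the transform_arg parameter describes. `split_declaration_params` is hypothetical. It assumes a single declaration whose parameter list contains no parentheses, so templates, default arguments, and function-pointer parameters are out of scope. Returning `(name, type)` tuples when no callback is given is also an illustrative choice.

```python
import re
from typing import Any, Callable, List, Optional


def split_declaration_params(
    declaration: str,
    transform_arg: Optional[Callable[[str, str], Any]] = None,
) -> List[Any]:
    """Hypothetical helper: map each `type name` in the last (...) group to a result.

    Assumes a single declaration whose parameter list contains no parentheses.
    Without transform_arg, each argument becomes a (name, type) tuple.
    """
    end = declaration.rindex(")")
    start = declaration.rindex("(", 0, end)  # skips e.g. __launch_bounds__(128)
    params = declaration[start + 1:end].strip()
    if params in ("", "void"):
        return []
    results: List[Any] = []
    for param in params.split(","):
        # The trailing identifier is the name; everything before it is the type.
        m = re.match(r"^(.*?)(\w+)\s*$", param.strip())
        if m is None:
            raise ValueError(f"cannot parse parameter {param!r}")
        arg_name, arg_type = m.group(2), m.group(1).strip()
        results.append(transform_arg(arg_name, arg_type) if transform_arg else (arg_name, arg_type))
    return results


decl = "__global__ void __launch_bounds__(128) main_kernel(half_t* __restrict__ A, int n)"
print(split_declaration_params(decl))
# [('A', 'half_t* __restrict__'), ('n', 'int')]
```

Passing a callback such as `lambda name, typ: name` collapses the result to bare names. This is the kind of per-argument customization the transform_arg hook allows.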
- class tilelang.jit.adapter.utils.TMADescriptorParams(handle_name, dtype, tensor_rank, global_address, is_img2col=False)
Parsed TMA descriptor parameters.
- Parameters:
handle_name (str)
dtype (str)
tensor_rank (int)
global_address (Any)
is_img2col (bool)
- handle_name
- dtype
- tensor_rank
- global_address
- is_img2col = False
- global_dim: list[str] = []
- global_stride: list[str] = []
- element_strides: list[str] = []
- interleave: str = ''
- swizzle: str = ''
- l2_promotion: str = ''
- oob_fill: str = ''
- box_dim: list[str] = []
- lower_corner: list[str] = []
- upper_corner: list[str] = []
- smem_box_channel: str = ''
- smem_box_pixel: str = ''
- tilelang.jit.adapter.utils.parse_tma_descriptor_args(tma_descriptor_args, desc_name_map, desc_name_var_map, pythonic_expr_func)
Parse TMA descriptor arguments into structured parameters.
- Parameters:
tma_descriptor_args (dict[tilelang.tvm.tir.Var, list[Any]]) – Dictionary mapping TMA descriptor variables to their arguments.
desc_name_map (dict[str, str]) – Mapping from descriptor handles to parameter names.
desc_name_var_map (dict[str, tilelang.tvm.tir.Var]) – Mapping from descriptor handles to TVM variables.
pythonic_expr_func (Callable[[Any], str]) – Function to convert TVM expressions to strings.
- Returns:
List of parsed TMA descriptor parameters.
- Return type:
list[TMADescriptorParams]
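The sketch below shows how one descriptor's argument list could be sliced into TMADescriptorParams-style fields. It rests on a loudly stated, hypothetical layout: a flat list holding dtype, tensor rank, and global address; then `tensor_rank` entries each for global_dim, global_stride, box_dim, and element_strides; then interleave, swizzle, l2_promotion, and oob_fill. That layout, the `TiledDescSketch` stand-in (tiled fields only, no img2col), and `slice_tiled_desc` are assumptions for illustration. The real argument order is defined by tilelang's lowering. The `to_str` callback plays the role the documented pythonic_expr_func plays.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List


@dataclass
class TiledDescSketch:
    """Hypothetical stand-in for TMADescriptorParams (tiled fields only, no img2col)."""

    handle_name: str
    dtype: str
    tensor_rank: int
    global_address: Any
    # default_factory gives every instance its own list.
    global_dim: List[str] = field(default_factory=list)
    global_stride: List[str] = field(default_factory=list)
    box_dim: List[str] = field(default_factory=list)
    element_strides: List[str] = field(default_factory=list)
    interleave: str = ""
    swizzle: str = ""
    l2_promotion: str = ""
    oob_fill: str = ""


def slice_tiled_desc(handle_name: str, args: List[Any], to_str: Callable[[Any], str]) -> TiledDescSketch:
    """Split an ASSUMED flat layout: dtype, rank, address, 4 rank-sized groups, 4 enums."""
    dtype, rank, address, *rest = args
    rank = int(rank)
    expected = 4 * rank + 4
    if len(rest) != expected:
        raise ValueError(f"rank {rank} needs {expected} trailing args, got {len(rest)}")
    # Four consecutive groups of `rank` entries each, rendered via to_str.
    groups = [[to_str(x) for x in rest[k * rank:(k + 1) * rank]] for k in range(4)]
    interleave, swizzle, l2_promotion, oob_fill = (to_str(x) for x in rest[4 * rank:])
    return TiledDescSketch(
        handle_name, to_str(dtype), rank, address,
        global_dim=groups[0], global_stride=groups[1],
        box_dim=groups[2], element_strides=groups[3],
        interleave=interleave, swizzle=swizzle,
        l2_promotion=l2_promotion, oob_fill=oob_fill,
    )


args = ["float16", 2, "A", "d0", "d1", "s0", "s1", 64, 64, 1, 1, 0, 3, 2, 0]
desc = slice_tiled_desc("A_desc", args, to_str=str)
print(desc.global_dim, desc.box_dim)  # ['d0', 'd1'] ['64', '64']
```

Checking the trailing-argument count against the rank up front turns a layout mismatch into an immediate error rather than silently misassigned fields.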