tilelang.jit.adapter.utils

Classes

TMADescriptorParams

Parsed TMA descriptor parameters.

Functions

match_global_kernel(source[, annotation])

match_declare_kernel(source[, annotation])

match_declare_kernel_cpu(source[, annotation])

is_cuda_target(target)

is_hip_target(target)

is_cpu_target(target)

is_metal_target(target)

get_annotated_mod(func_or_mod[, target, target_host, ...])

pythonic_expr(expr[, dtype_map, ignore_cast])

Convert a TVM PrimExpr into a Python-style string, adding parentheses where operator precedence requires them.

maybe_desc_name(name, matches, i[, desc_name_map])

Check if a parameter name corresponds to a TMA descriptor.

parse_function_call_args(declaration, function_args, ...)

Parse function call arguments from a kernel declaration.

parse_tma_descriptor_args(tma_descriptor_args, ...)

Parse TMA descriptor arguments into structured parameters.

Module Contents

tilelang.jit.adapter.utils.match_global_kernel(source, annotation='__global__')
Parameters:
  • source (str)

  • annotation (str)

Return type:

int

tilelang.jit.adapter.utils.match_declare_kernel(source, annotation='__global__')
Parameters:
  • source (str)

  • annotation (str)

Return type:

int

tilelang.jit.adapter.utils.match_declare_kernel_cpu(source, annotation='int32_t')
Parameters:
  • source (str)

  • annotation (str)

Return type:

int
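The three matchers above share one shape: they take source text plus an annotation string and return an int. This page does not say what that int means, so the sketch below assumes it is the character offset of the first annotated declaration. `find_annotated_decl` is a hypothetical stand-in written for illustration, not the library's implementation.

```python
import re


def find_annotated_decl(source: str, annotation: str = "__global__") -> int:
    """Hypothetical helper: offset of the first declaration led by `annotation`.

    Assumption: the returned int is a character offset into `source`.
    """
    pattern = re.compile(
        re.escape(annotation)                        # literal annotation text
        + r"\s+(?:void\s+)?"                         # optional 'void' return type
        + r"(?:__launch_bounds__\([^)]*\)\s+)?"      # optional CUDA launch bounds
        + r"\w+\s*\("                                # function name and opening paren
    )
    match = pattern.search(source)
    if match is None:
        raise ValueError(f"no declaration annotated with {annotation!r} found")
    return match.start()


cuda_src = 'extern "C" __global__ void __launch_bounds__(128) main_kernel(float* A) {\n}\n'
cpu_src = "int32_t main_kernel(void* args);\n"

cuda_offset = find_annotated_decl(cuda_src)
cpu_offset = find_annotated_decl(cpu_src, annotation="int32_t")
```

Escaping the annotation with `re.escape` keeps the sketch safe for annotations that contain regex metacharacters.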

tilelang.jit.adapter.utils.is_cuda_target(target)
Parameters:

target (tvm.target.Target)

Return type:

bool

tilelang.jit.adapter.utils.is_hip_target(target)
Parameters:

target (tvm.target.Target)

Return type:

bool

tilelang.jit.adapter.utils.is_cpu_target(target)
Parameters:

target (tvm.target.Target)

Return type:

bool

tilelang.jit.adapter.utils.is_metal_target(target)
Parameters:

target (tvm.target.Target)

Return type:

bool
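The four predicates above each take a `tvm.target.Target` and return a bool. A `Target` exposes its backend through `target.kind.name`, so one plausible reading is that each predicate compares that name against a fixed string. The sketch below uses a stand-in object instead of a real `Target`, and the kind strings in `ASSUMED_KINDS` are assumptions, since this page does not list the values each predicate checks.

```python
from types import SimpleNamespace

# Assumed mapping from target.kind.name to a backend label (not confirmed here).
ASSUMED_KINDS = {"cuda": "cuda", "hip": "hip", "c": "cpu", "metal": "metal"}


def fake_target(kind_name: str) -> SimpleNamespace:
    """Stand-in exposing only the `.kind.name` attribute of a Target."""
    return SimpleNamespace(kind=SimpleNamespace(name=kind_name))


def classify_target(target) -> str:
    """Hypothetical helper: map a target's kind name to a backend label."""
    return ASSUMED_KINDS.get(target.kind.name, "unknown")


labels = [classify_target(fake_target(k)) for k in ("cuda", "hip", "c", "vulkan")]
```

A table lookup like this is one way to collapse the four separate checks when a caller needs a single backend label.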

tilelang.jit.adapter.utils.get_annotated_mod(func_or_mod, target='auto', target_host=None, model_type='all')
Parameters:
  • func_or_mod (tvm.tir.PrimFunc | tilelang.tvm.IRModule)

  • target (str | tvm.target.Target)

  • target_host (str | tvm.target.Target | None)

  • model_type (Literal['device', 'host', 'all'])

Return type:

tvm.IRModule | tuple[tvm.IRModule, tvm.IRModule]

tilelang.jit.adapter.utils.pythonic_expr(expr, dtype_map=None, ignore_cast=False)

Convert a TVM PrimExpr into a Python-style string, adding parentheses where operator precedence requires them.

Parameters:
  • expr (tilelang.tvm.tir.PrimExpr) – The TVM PrimExpr to convert.

  • dtype_map (dict[str, str] | None) – A dictionary mapping data types to their string representations.

  • ignore_cast (bool) – If True, drop cast operators and render the underlying value without the cast.

Returns:

A string representation of the expression.

Return type:

str
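Precedence-aware printing can be illustrated without TVM. The sketch below works on a toy expression tree made of tuples, not on real `PrimExpr` nodes. `to_pythonic`, the tuple encoding, and the way casts are rendered are all assumptions made for illustration. It parenthesizes a sub-expression only when its operator binds more loosely than its parent, or equally tightly on the right-hand side of a left-associative operator.

```python
# Toy precedence table: higher binds tighter. All operators are left-associative.
PRECEDENCE = {"+": 1, "-": 1, "*": 2, "//": 2, "%": 2}


def to_pythonic(expr, dtype_map=None, ignore_cast=False, parent_prec=0, is_right=False):
    """Render a toy tuple tree as a Python-style string (hypothetical helper).

    Nodes are (op, lhs, rhs) or ("cast", dtype, value); leaves are names or ints.
    """
    if not isinstance(expr, tuple):
        return str(expr)
    op, lhs, rhs = expr
    if op == "cast":
        dtype, value = lhs, rhs
        if ignore_cast:
            # Drop the cast and render the value in the cast's own context.
            return to_pythonic(value, dtype_map, ignore_cast, parent_prec, is_right)
        name = (dtype_map or {}).get(dtype, dtype)
        return f"{name}({to_pythonic(value, dtype_map, ignore_cast)})"
    prec = PRECEDENCE[op]
    left = to_pythonic(lhs, dtype_map, ignore_cast, prec)
    right = to_pythonic(rhs, dtype_map, ignore_cast, prec, is_right=True)
    text = f"{left} {op} {right}"
    # Parenthesize if the parent binds tighter, or equally tight on its right side.
    if prec < parent_prec or (prec == parent_prec and is_right):
        return f"({text})"
    return text


product = to_pythonic(("*", ("+", "m", 1), "n"))
nested_sub = to_pythonic(("-", "a", ("-", "b", "c")))
with_cast = to_pythonic(("*", ("cast", "int64", ("+", "m", 1)), "n"), dtype_map={"int64": "int"})
no_cast = to_pythonic(("*", ("cast", "int64", ("+", "m", 1)), "n"), ignore_cast=True)
```

Passing the parent's precedence down the recursion is what lets `a * b + c` stay unparenthesized while `(m + 1) * n` keeps its parentheses.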

tilelang.jit.adapter.utils.maybe_desc_name(name, matches, i, desc_name_map=None)

Check if a parameter name corresponds to a TMA descriptor.

Parameters:
  • name (str) – The parameter name to check.

  • matches (list[str]) – List of all matched parameter names.

  • i (int) – Index of the current match.

  • desc_name_map (dict[str, str] | None) – Optional mapping to store descriptor name relationships.

Returns:

True if the parameter is a TMA descriptor.

Return type:

bool

tilelang.jit.adapter.utils.parse_function_call_args(declaration, function_args, function_params, desc_name_map=None, desc_name_var_map=None, transform_arg=None)

Parse function call arguments from a kernel declaration.

Parameters:
  • declaration (str) – The kernel function declaration string.

  • function_args (list[dict[str, str]]) – List of function argument specifications.

  • function_params (list[Any]) – List of function parameters from TVM IR.

  • desc_name_map (dict[str, str] | None) – Optional mapping for descriptor names.

  • desc_name_var_map (dict[str, tilelang.tvm.tir.Var] | None) – Optional mapping from descriptor names to TVM variables.

  • transform_arg (Callable[[str, str], Any] | None) – Optional function to transform each argument (name, type) -> result.

Returns:

List of parsed call arguments.

Return type:

list[Any]
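The role of `transform_arg` as a `(name, type) -> result` callback can be sketched on plain strings. The helper below is hypothetical and much simpler than the documented function: it ignores `function_params` and the descriptor maps, and it assumes each entry of `function_args` carries `"name"` and `"type"` keys, which this page does not confirm.

```python
import re


def collect_call_args(declaration, function_args, transform_arg=None):
    """Hypothetical helper: list call arguments in declaration order.

    Assumption: each spec in `function_args` has "name" and "type" keys.
    """
    # Text between the first "(" and the last ")" is the parameter list.
    params_text = declaration[declaration.index("(") + 1 : declaration.rindex(")")]
    # The last identifier of each comma-separated parameter is its name.
    names = [re.findall(r"\w+", p)[-1] for p in params_text.split(",") if p.strip()]
    types = {spec["name"]: spec["type"] for spec in function_args}
    call_args = []
    for name in names:
        arg_type = types[name]  # raises KeyError for an undeclared name
        call_args.append(transform_arg(name, arg_type) if transform_arg else name)
    return call_args


declaration = "void main_kernel(float* __restrict__ A, float* __restrict__ B, int m)"
specs = [
    {"name": "A", "type": "float*"},
    {"name": "B", "type": "float*"},
    {"name": "m", "type": "int"},
]

plain = collect_call_args(declaration, specs)
typed = collect_call_args(declaration, specs, transform_arg=lambda n, t: f"({t}){n}")
```

Leaving `transform_arg` as `None` returns the bare names, which mirrors the optional default in the documented signature.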

class tilelang.jit.adapter.utils.TMADescriptorParams(handle_name, dtype, tensor_rank, global_address, is_img2col=False)

Parsed TMA descriptor parameters.

Parameters:
  • handle_name (str)

  • dtype (str)

  • tensor_rank (int)

  • global_address (Any)

  • is_img2col (bool)

handle_name
dtype
tensor_rank
global_address
is_img2col = False
global_dim: list[str] = []
global_stride: list[str] = []
element_strides: list[str] = []
interleave: str = ''
swizzle: str = ''
l2_promotion: str = ''
oob_fill: str = ''
box_dim: list[str] = []
lower_corner: list[str] = []
upper_corner: list[str] = []
smem_box_channel: str = ''
smem_box_pixel: str = ''
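The attribute list above can be read as a record with four required fields and a set of optional ones that start empty. The dataclass below is a hypothetical mirror of that layout, written for illustration. It is not the library's class, and the handle and dtype values in the usage lines are made up. It uses `default_factory` so that each instance gets its own lists.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class DescriptorParamsSketch:
    """Hypothetical mirror of the documented attribute layout."""

    handle_name: str
    dtype: str
    tensor_rank: int
    global_address: Any
    is_img2col: bool = False
    global_dim: list[str] = field(default_factory=list)
    global_stride: list[str] = field(default_factory=list)
    element_strides: list[str] = field(default_factory=list)
    interleave: str = ""
    swizzle: str = ""
    l2_promotion: str = ""
    oob_fill: str = ""
    box_dim: list[str] = field(default_factory=list)
    lower_corner: list[str] = field(default_factory=list)
    upper_corner: list[str] = field(default_factory=list)
    smem_box_channel: str = ""
    smem_box_pixel: str = ""


# Illustrative values only; real handles and dtypes come from the compiled kernel.
tiled = DescriptorParamsSketch("A_desc", "float16", 2, "A")
tiled.global_dim.extend(["m", "k"])
other = DescriptorParamsSketch("B_desc", "float16", 2, "B")
```

Because the list fields are created per instance, filling `tiled.global_dim` leaves `other.global_dim` empty.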

tilelang.jit.adapter.utils.parse_tma_descriptor_args(tma_descriptor_args, desc_name_map, desc_name_var_map, pythonic_expr_func)

Parse TMA descriptor arguments into structured parameters.

Parameters:
  • tma_descriptor_args (dict[tilelang.tvm.tir.Var, list[Any]]) – Dictionary mapping TMA descriptor variables to their arguments.

  • desc_name_map (dict[str, str]) – Mapping from descriptor handles to parameter names.

  • desc_name_var_map (dict[str, tilelang.tvm.tir.Var]) – Mapping from descriptor handles to TVM variables.

  • pythonic_expr_func (Callable[[Any], str]) – Function to convert TVM expressions to strings.

Returns:

List of parsed TMA descriptor parameters.

Return type:

list[TMADescriptorParams]