tilelang.utils.language

Functions

is_global(buffer)

Check if the buffer is in the global memory scope.

is_shared(buffer[, allow_dynamic])

Check if the buffer is in the shared memory scope.

is_shared_dynamic(buffer)

Check if the buffer is in the dynamic shared memory scope.

is_tensor_memory(buffer)

Check if the buffer is in tensor memory scope (e.g., shared.tmem).

is_local(buffer)

Check if the buffer is in the local memory scope.

is_fragment(buffer)

Check if the buffer is a fragment (e.g., for matrix multiplication operations).

get_buffer_elems(buffer)

Get the number of elements in the buffer.

array_reduce(array)

Reduce an array of integers to a single integer.

retrieve_func_from_module(ir_module)

Retrieve the single PrimFunc from an IRModule.

get_buffer_region_from_load(buffer_load[, extents])

Get the buffer region from a buffer load.

to_buffer_region(obj[, access_type, extents])

Convert to/from the tl.region representation.

retrieve_shape(obj)

Retrieve shape-like extents for a buffer-like object.

retrieve_stride(obj)

Retrieve row-major strides for a buffer-like object based on its buffer.shape.

retrive_ptr_from_buffer_region(buffer_or_load_or_region[, access_type])

retrieve_ptr(obj[, access_type, ignore_last_ndim])

Retrieve a pointer to the start of a (possibly sliced) buffer region.

retrieve_offset(obj)

Retrieve the per-dimension minimum offsets.

bits_product(shape, dtype)

Compute the total number of bits in a buffer with the given shape and dtype.

prim_expr_equal(lhs, rhs)

Robust equality for PrimExpr shapes/extents.

legalize_pairwise_extents(src_extents, dst_extents)

Right-align and broadcast two extent lists to be mutually compatible.

is_full_region(buffer_region)

Check whether a BufferRegion covers the entire buffer.

Module Contents

tilelang.utils.language.is_global(buffer)

Check if the buffer is in the global memory scope.

Parameters:

buffer (tvm.tir.Buffer | tvm.tir.BufferLoad | tvm.tir.BufferRegion) – The TVM buffer, BufferLoad, or BufferRegion to check.

Returns:

True if the buffer is in global memory, False otherwise.

Return type:

bool

tilelang.utils.language.is_shared(buffer, allow_dynamic=True)

Check if the buffer is in the shared memory scope.

Parameters:
  • buffer (tvm.tir.Buffer | tvm.tir.BufferLoad | tvm.tir.BufferRegion) – The TVM buffer, BufferLoad, or BufferRegion to check.

  • allow_dynamic (bool) – If True (the default), buffers in the dynamic shared memory scope also count as shared.

Returns:

True if the buffer is in shared memory, False otherwise.

Return type:

bool
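The effect of allow_dynamic can be pictured with a minimal sketch over plain scope strings. It assumes static and dynamic shared memory use the TVM scope names "shared" and "shared.dyn"; `scope_is_shared` is a hypothetical helper, whereas the library function reads the scope from the buffer itself.

```python
def scope_is_shared(scope: str, allow_dynamic: bool = True) -> bool:
    """Hypothetical model of is_shared operating on a scope string.

    Assumes static shared memory uses the scope "shared" and dynamic
    shared memory uses "shared.dyn".
    """
    if scope == "shared":
        return True
    # Dynamic shared memory only counts when allow_dynamic is enabled.
    return allow_dynamic and scope == "shared.dyn"


print(scope_is_shared("shared"))                           # True
print(scope_is_shared("shared.dyn"))                       # True
print(scope_is_shared("shared.dyn", allow_dynamic=False))  # False
print(scope_is_shared("global"))                           # False
```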

tilelang.utils.language.is_shared_dynamic(buffer)

Check if the buffer is in the dynamic shared memory scope.

Parameters:

buffer (tvm.tir.Buffer | tvm.tir.BufferLoad | tvm.tir.BufferRegion) – The TVM buffer, BufferLoad, or BufferRegion to check.

Returns:

True if the buffer is in dynamic shared memory, False otherwise.

Return type:

bool

tilelang.utils.language.is_tensor_memory(buffer)

Check if the buffer is in tensor memory scope (e.g., shared.tmem).

Parameters:

buffer (tvm.tir.Buffer | tvm.tir.BufferLoad | tvm.tir.BufferRegion) – The TVM buffer, BufferLoad, or BufferRegion to check.

Returns:

True if the buffer is in tensor memory, False otherwise.

Return type:

bool

tilelang.utils.language.is_local(buffer)

Check if the buffer is in the local memory scope.

Parameters:

buffer (tvm.tir.Buffer | tvm.tir.BufferLoad | tvm.tir.BufferRegion) – The TVM buffer, BufferLoad, or BufferRegion to check.

Returns:

True if the buffer is in local memory, False otherwise.

Return type:

bool

tilelang.utils.language.is_fragment(buffer)

Check if the buffer is a fragment (e.g., for matrix multiplication operations).

Parameters:

buffer (tvm.tir.Buffer | tvm.tir.BufferLoad | tvm.tir.BufferRegion) – The TVM buffer, BufferLoad, or BufferRegion to check.

Returns:

True if the buffer is a fragment, False otherwise.

Return type:

bool

tilelang.utils.language.get_buffer_elems(buffer)

Get the number of elements in the buffer.

Parameters:

buffer (tvm.tir.Buffer)

Return type:

int

tilelang.utils.language.array_reduce(array)

Reduce an array of integers to a single integer.

Parameters:

array (List[int]) – The array of integers to reduce.

Returns:

The reduced integer.

Return type:

int
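A minimal sketch of array_reduce, assuming the reduction operator is multiplication (the operation needed to turn a static shape into an element count, as get_buffer_elems does). `array_reduce_sketch` is a hypothetical stand-in, and treating an empty list as 1 is a choice of this sketch.

```python
import operator
from functools import reduce


def array_reduce_sketch(array: list[int]) -> int:
    """Hypothetical model of array_reduce: multiply all entries together.

    Assumes the reduction operator is multiplication; an empty list is
    treated as the empty product (1).
    """
    return reduce(operator.mul, array, 1)


# A static 128 x 64 shape holds 8192 elements.
print(array_reduce_sketch([128, 64]))  # 8192
```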

tilelang.utils.language.retrieve_func_from_module(ir_module)

Retrieve the single PrimFunc from an IRModule.

Parameters:

ir_module (IRModule) – The TVM IRModule to extract the function from. The module should contain exactly one global function.

Returns:

The single function contained in the module.

Return type:

PrimFunc

Raises:
  • ValueError – If ir_module is not an IRModule.

  • AssertionError – If the module contains more than one global function.

tilelang.utils.language.get_buffer_region_from_load(buffer_load, extents=None)

Get the buffer region from a buffer load.

The load may use slice indices, as in C[0:128, 0:32] (see the pull request on buffer-wise ops: https://github.com/apache/tvm/pull/14693); such a load is converted to the equivalent buffer region.

Parameters:
  • buffer_load (tvm.tir.BufferLoad)

  • extents (list[tvm.tir.PrimExpr] | None)

Return type:

tvm.tir.BufferRegion | None
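The load-to-region conversion can be sketched as below, under several assumptions: a sliced index appears as a unit-stride ramp (base, stride, lanes), a scalar index covers one element, and None is returned when nothing is sliced, which is one plausible reading of the `| None` return type. `Ramp` and `load_to_region` are hypothetical stand-ins for the TVM types, and the extents parameter is not modeled.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(frozen=True)
class Ramp:
    """Hypothetical stand-in for a vectorized index base, base+stride, ..."""
    base: int
    stride: int
    lanes: int


Index = Union[int, Ramp]


def load_to_region(indices: list[Index]) -> Optional[list[tuple[int, int]]]:
    """Hypothetical model: map load indices to (min, extent) ranges.

    A unit-stride Ramp becomes (base, lanes); a scalar index becomes
    (index, 1). Returns None when no index is a Ramp.
    """
    region = []
    sliced = False
    for idx in indices:
        if isinstance(idx, Ramp):
            if idx.stride != 1:
                raise ValueError("only unit-stride slices are modeled here")
            region.append((idx.base, idx.lanes))
            sliced = True
        else:
            region.append((idx, 1))
    return region if sliced else None


# C[0:128, 0:32] is assumed to appear as two unit-stride ramps.
print(load_to_region([Ramp(0, 1, 128), Ramp(0, 1, 32)]))  # [(0, 128), (0, 32)]
print(load_to_region([Ramp(16, 1, 8), 3]))                # [(16, 8), (3, 1)]
print(load_to_region([5, 7]))                             # None
```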

tilelang.utils.language.to_buffer_region(obj, access_type='rw', extents=None)

Convert to/from the tl.region representation.

  • Buffer/BufferLoad/BufferRegion -> returns a tl.region call (PrimExpr)

  • tl.region Call -> returns the decoded BufferRegion for analysis

Parameters:
  • obj (tvm.tir.Buffer | tvm.tir.BufferLoad | tvm.tir.BufferRegion | tvm.tir.Var)

  • access_type (str)

  • extents (list[tvm.tir.PrimExpr] | None)

Return type:

tvm.tir.PrimExpr | tvm.tir.BufferRegion

tilelang.utils.language.retrieve_shape(obj)

Retrieve shape-like extents for a buffer-like object.

  • Buffer -> its shape

  • BufferRegion -> list of each range’s extent

  • BufferLoad -> extents from get_buffer_region_from_load(obj)

Parameters:

obj (tvm.tir.Buffer | tvm.tir.BufferRegion | tvm.tir.BufferLoad)

Return type:

list

tilelang.utils.language.retrieve_stride(obj)

Retrieve row-major strides for a buffer-like object based on its buffer.shape.

For BufferRegion and BufferLoad, uses the underlying buffer’s shape.

Parameters:

obj (tvm.tir.Buffer | tvm.tir.BufferRegion | tvm.tir.BufferLoad)

Return type:

list
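For a static shape, row-major strides follow from a right-to-left running product. A minimal sketch over plain integers (the library works on buffer.shape, whose entries may be symbolic PrimExprs); `row_major_strides` is a hypothetical name. For a BufferRegion or BufferLoad it would be applied to the underlying buffer's full shape, not to the region's extents.

```python
def row_major_strides(shape: list[int]) -> list[int]:
    """Hypothetical model of retrieve_stride for a static shape.

    The last dimension has stride 1; each earlier stride is the product
    of all extents to its right.
    """
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


print(row_major_strides([4, 128, 64]))  # [8192, 64, 1]
```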

tilelang.utils.language.retrive_ptr_from_buffer_region(buffer_or_load_or_region, access_type='r')

Parameters:
  • buffer_or_load_or_region (tvm.tir.Buffer | tvm.tir.BufferLoad | tvm.tir.BufferRegion)

  • access_type (str)

Return type:

tvm.tir.PrimExpr

tilelang.utils.language.retrieve_ptr(obj, access_type='r', ignore_last_ndim=0)

Retrieve a pointer to the start of a (possibly sliced) buffer region.

  • Buffer -> base pointer

  • BufferRegion -> pointer with byte offset computed from region minima

  • BufferLoad -> pointer offset computed from indices or derived region

Parameters:
  • obj (tvm.tir.Buffer | tvm.tir.BufferRegion | tvm.tir.BufferLoad) – Buffer-like object

  • access_type (str) – TVM Buffer access mask, e.g. “r”, “w”, “rw”

  • ignore_last_ndim (int) – Number of trailing dimensions to exclude from the offset computation.

Return type:

tvm.tir.PrimExpr
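The offset applied for a sliced region can be pictured as a dot product of the region minima with the buffer's row-major strides. This is a minimal sketch, assuming a row-major layout and counting the offset in elements (multiplying by the element size in bytes would give a byte offset); `region_element_offset` is hypothetical and does not build the actual access pointer.

```python
def region_element_offset(offsets: list[int], shape: list[int],
                          ignore_last_ndim: int = 0) -> int:
    """Hypothetical model of the offset retrieve_ptr applies.

    offsets are the per-dimension minima of the region; shape is the full
    buffer shape. The last ignore_last_ndim dimensions contribute nothing.
    """
    assert len(offsets) == len(shape)
    assert 0 <= ignore_last_ndim <= len(shape)
    # Row-major strides of the full buffer.
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    kept = len(shape) - ignore_last_ndim
    return sum(o * s for o, s in zip(offsets[:kept], strides[:kept]))


# Region A[32:64, 16:48] of a 128 x 64 buffer starts at 32*64 + 16.
print(region_element_offset([32, 16], [128, 64]))                      # 2064
print(region_element_offset([32, 16], [128, 64], ignore_last_ndim=1))  # 2048
```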

tilelang.utils.language.retrieve_offset(obj)

Retrieve the per-dimension minimum offsets.

  • Buffer -> [0, 0, …]

  • BufferRegion -> [r.min for r in region]

  • BufferLoad -> indices (or derived region minima)

Parameters:

obj (tvm.tir.Buffer | tvm.tir.BufferRegion | tvm.tir.BufferLoad)

Return type:

list
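Together, retrieve_offset and retrieve_shape describe a region as a corner plus a size. A minimal sketch of the Buffer and BufferRegion cases, using hypothetical dataclasses in place of the TVM types and plain integers in place of PrimExprs; the BufferLoad case is not modeled.

```python
from dataclasses import dataclass


@dataclass
class BufferModel:
    """Hypothetical stand-in for a Buffer: just a static shape."""
    shape: list[int]


@dataclass
class RegionModel:
    """Hypothetical stand-in for a BufferRegion: (min, extent) per dim."""
    buffer: BufferModel
    ranges: list[tuple[int, int]]


def offsets_of(obj) -> list[int]:
    """Mirror retrieve_offset: zeros for a whole buffer, minima for a region."""
    if isinstance(obj, BufferModel):
        return [0] * len(obj.shape)
    return [lo for lo, _ in obj.ranges]


def shape_of(obj) -> list[int]:
    """Mirror retrieve_shape: full shape for a buffer, extents for a region."""
    if isinstance(obj, BufferModel):
        return list(obj.shape)
    return [ext for _, ext in obj.ranges]


A = BufferModel([128, 64])
tile = RegionModel(A, [(32, 32), (16, 16)])  # A[32:64, 16:32]
print(offsets_of(A), shape_of(A))        # [0, 0] [128, 64]
print(offsets_of(tile), shape_of(tile))  # [32, 16] [32, 16]
```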

tilelang.utils.language.bits_product(shape, dtype)

Compute the total number of bits in a buffer with the given shape and dtype.

Parameters:
  • shape (list[tvm.tir.PrimExpr])

  • dtype (str)

Return type:

tvm.tir.PrimExpr
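The computation amounts to element count times element width. A minimal sketch over static integer shapes (the library returns a PrimExpr, so symbolic extents are also possible there); `dtype_bits` and `bits_product_sketch` are hypothetical, and reading the width from the dtype string's trailing digits is an assumption of the sketch that only covers simple scalar dtypes.

```python
import math
import re


def dtype_bits(dtype: str) -> int:
    """Hypothetical parser for scalar dtype strings like "float16" or "int8".

    Assumes the bit width is the trailing number. Vector suffixes and
    other dtype spellings are rejected.
    """
    match = re.fullmatch(r"[A-Za-z]+(\d+)", dtype)
    if match is None:
        raise ValueError(f"unsupported dtype in this sketch: {dtype}")
    return int(match.group(1))


def bits_product_sketch(shape: list[int], dtype: str) -> int:
    """Total bits = number of elements * bits per element (static shapes only)."""
    return math.prod(shape) * dtype_bits(dtype)


# A 128 x 64 float16 tile occupies 131072 bits, i.e. 16384 bytes.
bits = bits_product_sketch([128, 64], "float16")
print(bits, bits // 8)  # 131072 16384
```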

tilelang.utils.language.prim_expr_equal(lhs, rhs)

Robust equality for PrimExpr shapes/extents.

Tries structural_equal first, then falls back to expr_deep_equal. Python ints are converted to IntImm for comparison.

Return type:

bool

tilelang.utils.language.legalize_pairwise_extents(src_extents, dst_extents)

Right-align and broadcast two extent lists to be mutually compatible.

Early-exit rule: if the number of non-1 dimensions in src_extents equals that in dst_extents, no adjustment is made and the original extents are returned unchanged. This preserves the per-dimension iteration mapping (one loop var per non-1 dim) and avoids creating extra varying axes on either side.

Otherwise, for each pair of tail-aligned dimensions (x, y):
  • if x == y: keep both

  • elif x == 1: set x = y

  • elif y == 1: set y = x

  • else: promote both to tir.max(x, y) to handle dynamic-vs-static safely

Leading unmatched dimensions are kept as-is.

Returns a tuple of new lists (src_new, dst_new).

Parameters:
  • src_extents (list)

  • dst_extents (list)

Return type:

tuple[list, list]
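The rules above can be traced with an integer-only sketch. It follows the described early exit, tail alignment, and per-pair cases, but uses Python's max() where the library uses tir.max, so it covers static extents only; `legalize_pairwise_extents_sketch` is a hypothetical name.

```python
def legalize_pairwise_extents_sketch(
    src: list[int], dst: list[int]
) -> tuple[list[int], list[int]]:
    """Hypothetical integer-only model of legalize_pairwise_extents."""

    def non_one(xs: list[int]) -> int:
        return sum(1 for x in xs if x != 1)

    # Early exit: same number of non-1 dims -> extents are left unchanged.
    if non_one(src) == non_one(dst):
        return list(src), list(dst)

    src_new, dst_new = list(src), list(dst)
    # Walk the tail-aligned dimension pairs from the right.
    for k in range(1, min(len(src), len(dst)) + 1):
        x, y = src_new[-k], dst_new[-k]
        if x == y:
            continue
        elif x == 1:
            src_new[-k] = y
        elif y == 1:
            dst_new[-k] = x
        else:
            src_new[-k] = dst_new[-k] = max(x, y)
    # Leading dimensions without a partner are left as-is.
    return src_new, dst_new


print(legalize_pairwise_extents_sketch([1, 64], [128, 64]))  # ([128, 64], [128, 64])
print(legalize_pairwise_extents_sketch([128, 1], [64]))      # ([128, 1], [64])
print(legalize_pairwise_extents_sketch([2, 1, 1], [8, 32]))  # ([2, 8, 32], [8, 32])
```

The second call hits the early exit: both sides have exactly one non-1 dimension, so no broadcasting is applied even though the shapes differ.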

tilelang.utils.language.is_full_region(buffer_region)

Check whether a BufferRegion covers the entire buffer.

A full region means each dimension has start 0 and extent equal to the corresponding dimension in the buffer’s shape.

Parameters:

buffer_region (tvm.tir.BufferRegion) – The TVM BufferRegion to check.

Returns:

True if the region is full; otherwise False.

Return type:

bool
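The full-region test can be sketched with (min, extent) pairs. This is a minimal model, assuming plain integers in place of PrimExprs so that a simple == comparison suffices; `is_full_region_sketch` is hypothetical, and returning False on a rank mismatch is a choice of this sketch.

```python
def is_full_region_sketch(ranges: list[tuple[int, int]], shape: list[int]) -> bool:
    """Hypothetical model: every dim starts at 0 and spans the whole extent.

    ranges holds (min, extent) per dimension; shape is the buffer's shape.
    """
    if len(ranges) != len(shape):
        return False
    return all(lo == 0 and ext == dim for (lo, ext), dim in zip(ranges, shape))


print(is_full_region_sketch([(0, 128), (0, 64)], [128, 64]))   # True
print(is_full_region_sketch([(0, 128), (16, 48)], [128, 64]))  # False
```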