tilelang.language.v2.annot

Attributes

Scope

Classes

Annot

Base class for tilelang kernel annotations

ArgVarTable

ArgVarTable is used to manage the mapping from argument names to tir.Var objects

Value

Base class for tilelang kernel annotations

BufferAnnot

Base class for tilelang kernel annotations

TensorAnnot

Base class for tilelang kernel annotations

StridedTensorAnnot

Base class for tilelang kernel annotations

FragmentBufferAnnot

Base class for tilelang kernel annotations

SharedBufferAnnot

Base class for tilelang kernel annotations

LocalBufferAnnot

Base class for tilelang kernel annotations

DynAnnot

Dynamic variable annotation represents a tvm tir.Var argument

DTypeAnnot

Data type annotation that automatically converts AnyDType values to dtype

TIRAnnot

TIR annotation is used to directly pass tir.Buffer or tir.Var as kernel arguments

Buffer

Generic buffer type parameterized by shape and dtype.

FuncAnnot

Module Contents

tilelang.language.v2.annot.Scope
class tilelang.language.v2.annot.Annot

Bases: abc.ABC

Base class for tilelang kernel annotations. Tilelang kernel annotations are used to specify how to interpret each argument of the jit kernel.

It provides three main functionalities:

  1. Determine whether the argument is a kernel argument (i.e., whether it needs to be passed at kernel launch time).

  2. Parse the argument value into a hash key for jit caching.

  3. Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation.
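The three responsibilities of an annotation can be illustrated with a toy class. This is a minimal, hypothetical sketch that does not import tilelang or TVM: ToyTensorAnnot, FakeTensor, and FakeBuffer are invented stand-ins for an Annot subclass, a runtime tensor, and tvm.tir.Buffer. The real create_prim_func_arg also receives vt (an ArgVarTable), which the toy omits.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class FakeBuffer:
    # Hypothetical stand-in for tvm.tir.Buffer so the sketch runs without TVM.
    name: str
    shape: tuple
    dtype: str


@dataclass(frozen=True)
class FakeTensor:
    # Hypothetical stand-in for a runtime tensor (e.g. a torch.Tensor).
    shape: tuple
    dtype: str


class ToyTensorAnnot:
    """Hypothetical annotation showing the three Annot responsibilities."""

    def is_kernel_arg(self) -> bool:
        # 1. A tensor is always passed at kernel launch time.
        return True

    def get_key_parser(self) -> Callable[[str, Any], tuple]:
        # 2. Only properties that change the generated code go into the
        #    cache key; here that is assumed to be the shape and dtype.
        def parse(name: str, value: Any) -> tuple:
            return (name, tuple(value.shape), str(value.dtype))

        return parse

    def create_prim_func_arg(self, name: str, value: Any) -> FakeBuffer:
        # 3. Describe the argument as a buffer for prim func generation.
        return FakeBuffer(name, tuple(value.shape), str(value.dtype))
```

Because the key holds only the name, shape, and dtype, two calls with different tensor contents but the same shape and dtype would share one compiled kernel in this sketch.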

is_kernel_arg()

Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time)

Return type:

bool

abstract with_name(name)
Return type:

Self

abstract get_key_parser()

Return a parser function that converts the argument value into a hash key for jit caching

Return type:

Callable[[str, Any], tuple[Any, ...]]

abstract create_prim_func_arg(name, value, vt)

Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation

Parameters:
Return type:

tvm.tir.Var | tvm.tir.Buffer

promote()

Try to promote the annotation into a TIRAnnot (a fixed annotation) if possible; return None if it is not promotable.

Return type:

TIRAnnot | None

class tilelang.language.v2.annot.ArgVarTable

ArgVarTable is used to manage the mapping from argument names to tir.Var objects

var_tab: dict[str, tvm.tir.Var]
tmp_name_idx: int = 0
get_or_create_var(name, dtype)
Parameters:
Return type:

tvm.tir.Var

create_tmp_name()
Return type:

str
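ArgVarTable's fields (var_tab and tmp_name_idx) suggest a memoizing name-to-variable table plus a counter for fresh names. The sketch below is hypothetical: ToyArgVarTable and FakeVar (a stand-in for tvm.tir.Var) are invented for illustration, and the tmp_<n> naming scheme is an assumption, not tilelang's actual format.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass(frozen=True)
class FakeVar:
    # Hypothetical stand-in for tvm.tir.Var.
    name: str
    dtype: str


@dataclass
class ToyArgVarTable:
    var_tab: Dict[str, FakeVar] = field(default_factory=dict)
    tmp_name_idx: int = 0

    def get_or_create_var(self, name: str, dtype: str) -> FakeVar:
        # Reuse an existing variable so every occurrence of a symbolic
        # dimension such as "M" refers to the same object.
        if name not in self.var_tab:
            self.var_tab[name] = FakeVar(name, dtype)
        return self.var_tab[name]

    def create_tmp_name(self) -> str:
        # Produce a fresh name for an otherwise anonymous variable.
        name = f"tmp_{self.tmp_name_idx}"  # naming scheme is an assumption
        self.tmp_name_idx += 1
        return name
```

Sharing one variable per name is what lets two buffers annotated with the same symbolic dimension be tied to a single runtime value.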

class tilelang.language.v2.annot.Value

Bases: Annot

Base class for tilelang kernel annotations. Tilelang kernel annotations are used to specify how to interpret each argument of the jit kernel.

It provides three main functionalities:

  1. Determine whether the argument is a kernel argument (i.e., whether it needs to be passed at kernel launch time).

  2. Parse the argument value into a hash key for jit caching.

  3. Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation.

kind: Literal['static', 'dynamic'] = 'dynamic'
name: str | None = None
dtype: tilelang.language.v2.dtypes.dtype | None
value: int | tvm.tir.Var | None = None
creator: Callable[[], Any] | None = None
is_kernel_arg()

Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time)

Return type:

bool

classmethod from_value(value, prefer_name=None)
Parameters:
  • value (Any)

  • prefer_name (str)

Return type:

Value

with_name(name)
Parameters:

name (str)

Return type:

Value

get_key_parser()

Return a parser function that converts the argument value into a hash key for jit caching

parse_key(target)
Parameters:

target (Any)

create_prim_func_arg(name, value, vt, create_arg=True)

Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation

Parameters:
__repr__()
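Value carries a kind field that is either 'static' or 'dynamic'. One plausible reading, sketched below purely as an assumption (ToyValue and its rules are hypothetical, not taken from tilelang), is that a static value is specialized into the compiled kernel and therefore belongs in the jit cache key, while a dynamic value becomes a launch-time tir.Var argument and does not.

```python
from dataclasses import dataclass
from typing import Any, Literal, Optional


@dataclass
class ToyValue:
    """Hypothetical stand-in for Value, keeping only kind and name."""

    kind: Literal["static", "dynamic"] = "dynamic"
    name: Optional[str] = None

    def is_kernel_arg(self) -> bool:
        # Assumption: only dynamic values are passed at kernel launch time.
        return self.kind == "dynamic"

    def parse_key(self, target: Any) -> tuple:
        # Assumption: a static value changes the generated code, so it is
        # part of the cache key; a dynamic value is left out so that one
        # compiled kernel can serve every runtime value.
        return (target,) if self.kind == "static" else ()
```

Under this reading, changing a static argument triggers a new compilation, while changing a dynamic one only changes what is passed at launch.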
class tilelang.language.v2.annot.BufferAnnot

Bases: Annot

Base class for tilelang kernel annotations. Tilelang kernel annotations are used to specify how to interpret each argument of the jit kernel.

It provides three main functionalities:

  1. Determine whether the argument is a kernel argument (i.e., whether it needs to be passed at kernel launch time).

  2. Parse the argument value into a hash key for jit caching.

  3. Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation.

shape: tuple = None
strides: tuple = None
dtype: tilelang.language.v2.dtypes.dtype = None
is_kernel_arg()

Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time)

Return type:

bool

property scope
__call__(shape, dtype='float32', data=None, strides=None, elem_offset=None, scope=None, align=0, offset_factor=0, buffer_type='', axis_separators=None)
Parameters:
  • shape (tuple[Unpack[_Shapes]])

  • dtype (_DType)

Return type:

Tensor[Callable[[Unpack[_Shapes]]], _DType]

__getitem__(params)
with_name(name)
Parameters:

name (str)

get_key_parser()

Return a parser function that converts the argument value into a hash key for jit caching

parse_key(target)
Parameters:

target (Any)

static match_shape(shape, target_shape, vt)
Parameters:
  • shape (tuple[Value, ...])

  • target_shape (tuple[int, ...])

  • vt (ArgVarTable)
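match_shape compares an annotated shape (a tuple of Value) with the concrete shape of an argument, using an ArgVarTable for symbolic dimensions. The sketch below is a hypothetical simplification: plain ints stand in for static Value entries, strings stand in for dynamic ones, a dict stands in for the ArgVarTable, and the binding rules and error messages are assumptions.

```python
from typing import Dict, Tuple, Union


def match_shape(
    shape: Tuple[Union[int, str], ...],
    target_shape: Tuple[int, ...],
    bindings: Dict[str, int],
) -> None:
    """Hypothetical sketch: check a concrete shape against an annotated one."""
    if len(shape) != len(target_shape):
        raise ValueError(f"rank mismatch: expected {len(shape)}, got {len(target_shape)}")
    for expected, actual in zip(shape, target_shape):
        if isinstance(expected, int):
            # Static dimension: must match exactly.
            if expected != actual:
                raise ValueError(f"expected dim {expected}, got {actual}")
        # Symbolic dimension: bound on first use, must agree afterwards.
        elif bindings.setdefault(expected, actual) != actual:
            raise ValueError(f"{expected} is bound to {bindings[expected]}, got {actual}")
```

Sharing one bindings table across all arguments is what makes a dimension named M in two different buffers refer to the same size.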

create_prim_func_arg(name, value, vt)

Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation

Parameters:
promote()

Try to promote the annotation into a TIRAnnot (a fixed annotation) if possible; return None if it is not promotable.

class tilelang.language.v2.annot.TensorAnnot

Bases: BufferAnnot

Base class for tilelang kernel annotations. Tilelang kernel annotations are used to specify how to interpret each argument of the jit kernel.

It provides three main functionalities:

  1. Determine whether the argument is a kernel argument (i.e., whether it needs to be passed at kernel launch time).

  2. Parse the argument value into a hash key for jit caching.

  3. Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation.

__call__(shape, dtype='float32', data=None, strides=None, elem_offset=None, scope=None, align=0, offset_factor=0, buffer_type='', axis_separators=None)
Parameters:
  • shape (tuple[Unpack[_Shapes]])

  • dtype (_DType)

promote()

Try to promote the annotation into a TIRAnnot (a fixed annotation) if possible; return None if it is not promotable.

class tilelang.language.v2.annot.StridedTensorAnnot

Bases: BufferAnnot

Base class for tilelang kernel annotations. Tilelang kernel annotations are used to specify how to interpret each argument of the jit kernel.

It provides three main functionalities:

  1. Determine whether the argument is a kernel argument (i.e., whether it needs to be passed at kernel launch time).

  2. Parse the argument value into a hash key for jit caching.

  3. Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation.

__call__(shape, strides, dtype='float32', data=None, elem_offset=None, scope=None, align=0, offset_factor=0, buffer_type='', axis_separators=None)
Parameters:

dtype (_DType)

__getitem__(params)
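StridedTensorAnnot's __call__ takes explicit strides in addition to a shape. For reference, a C-contiguous (row-major) tensor has strides, counted in elements, equal to the product of all trailing dimensions; a transposed view, or a slice that narrows an inner dimension, does not, and that is the kind of argument an explicit-stride annotation can describe. The helper below is a hypothetical name, not part of tilelang, and only computes the contiguous strides for comparison.

```python
from typing import Tuple


def contiguous_strides(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Row-major (C-contiguous) strides, in elements, for the given shape."""
    strides = []
    step = 1
    # Walk from the innermost dimension outward, accumulating the product
    # of all dimensions to the right of the current one.
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))
```

For example, the transpose of a contiguous (4, 8) tensor has shape (8, 4) and strides (1, 8), whereas contiguous_strides((8, 4)) returns (4, 1).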
class tilelang.language.v2.annot.FragmentBufferAnnot

Bases: BufferAnnot

Base class for tilelang kernel annotations. Tilelang kernel annotations are used to specify how to interpret each argument of the jit kernel.

It provides three main functionalities:

  1. Determine whether the argument is a kernel argument (i.e., whether it needs to be passed at kernel launch time).

  2. Parse the argument value into a hash key for jit caching.

  3. Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation.

property scope
class tilelang.language.v2.annot.SharedBufferAnnot

Bases: BufferAnnot

Base class for tilelang kernel annotations. Tilelang kernel annotations are used to specify how to interpret each argument of the jit kernel.

It provides three main functionalities:

  1. Determine whether the argument is a kernel argument (i.e., whether it needs to be passed at kernel launch time).

  2. Parse the argument value into a hash key for jit caching.

  3. Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation.

property scope
class tilelang.language.v2.annot.LocalBufferAnnot

Bases: BufferAnnot

Base class for tilelang kernel annotations. Tilelang kernel annotations are used to specify how to interpret each argument of the jit kernel.

It provides three main functionalities:

  1. Determine whether the argument is a kernel argument (i.e., whether it needs to be passed at kernel launch time).

  2. Parse the argument value into a hash key for jit caching.

  3. Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation.

property scope
class tilelang.language.v2.annot.DynAnnot

Bases: Value

Dynamic variable annotation represents a tvm tir.Var argument

__call__(dtype=dt.float32, name=None)
Parameters:
  • dtype (tilelang.language.v2.dtypes.AnyDType)

  • name (str | None)

Return type:

DynAnnot

__getitem__(params)
class tilelang.language.v2.annot.DTypeAnnot

Bases: Annot

Data type annotation that automatically converts AnyDType values to dtype:

>>> def foo(A: T.dtype): print(A)
>>> foo(torch.float32)
dtype('float32')
>>> foo(T.float32)
dtype('float32')
>>> foo('float32')
dtype('float32')
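The doctest shows three spellings (a torch dtype, a tilelang dtype, and a string) all resolving to the same dtype. A hypothetical, much-simplified version of that normalization is sketched below; normalize_dtype and its accepted names are assumptions, it returns a plain string rather than tilelang's dtype object, and it relies only on torch dtypes printing as 'torch.float32'.

```python
# Hypothetical set of accepted names; the real conversion supports more.
_KNOWN_DTYPES = {
    "bool", "int8", "int16", "int32", "int64",
    "float16", "bfloat16", "float32", "float64",
}


def normalize_dtype(value: object) -> str:
    """Hypothetical sketch: reduce several dtype spellings to one name."""
    # torch dtypes print with a module prefix ("torch.float32"); drop it.
    name = str(value).rsplit(".", 1)[-1]
    if name not in _KNOWN_DTYPES:
        raise ValueError(f"cannot interpret {value!r} as a dtype")
    return name
```

Normalizing before hashing means foo(torch.float32) and foo('float32') would map to the same jit cache entry.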

name: str | None = None
is_kernel_arg()

Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time)

Return type:

bool

with_name(name)
get_key_parser()

Return a parser function that converts the argument value into a hash key for jit caching

create_prim_func_arg(name, value, vt)

Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation

__repr__()
class tilelang.language.v2.annot.TIRAnnot

Bases: Annot

TIR annotation is used to directly pass tir.Buffer or tir.Var as kernel arguments:

>>> def foo(A: T.Buffer((128,), T.float32)): ...

data: tvm.tir.Buffer | tvm.tir.Var¶
is_kernel_arg()

Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time)

Return type:

bool

get_key_parser()

Return a parser function that converts the argument value into a hash key for jit caching

create_prim_func_arg(name, value, vt)

Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation

with_name(name)
Parameters:

name (str)

__repr__()
class tilelang.language.v2.annot.Buffer(dtype='float32', data=None, strides=None, elem_offset=None, scope=None, align=0, offset_factor=0, buffer_type='', axis_separators=None)

Bases: Generic[_Shape, _DType]

Generic buffer type parameterized by a shape type (_Shape) and a data type (_DType). It exposes the buffer's shape, dtype, strides, and scope.
Parameters:

dtype (_DType)

property shape: tuple[Unpack[_Shapes]]
Return type:

tuple[Unpack[_Shapes]]

property dtype: tilelang.language.v2.dtypes.dtype[_DType]
Return type:

tilelang.language.v2.dtypes.dtype[_DType]

property strides: tuple[tvm.tir.PrimExpr]
Return type:

tuple[tvm.tir.PrimExpr]

scope()
Return type:

Scope

class tilelang.language.v2.annot.FuncAnnot
sig: inspect.Signature
arg_names: list[str]
annots: dict[str, Annot]
arg_parser: dict[str, Callable[[Any], tuple[Any, ...]]]
ker_arg_names: list[str]
classmethod from_sig_annots(sig, func_annots)
Parameters:
  • sig (inspect.Signature)

  • func_annots (dict[str, Any])

Return type:

FuncAnnot

parse_key(*args, **kws)

Parse the arguments and generate the cache key for jit caching.

convert_to_kernel_args(*args, **kws)
create_argument(name, value, vt)

Convert the argument into a tvm tir argument (tir.Var | tir.Buffer) for prim func generation

Parameters:
is_all_static()

Check if all arguments are static (i.e., can be fully determined at compile time)

get_all_static_args()
get_compile_time_unknown_args()
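FuncAnnot stores the function's inspect.Signature together with a per-argument arg_parser map, and parse_key turns a call's arguments into a jit cache key. A minimal sketch of that idea follows, assuming arguments are bound through the signature and each parser's output is collected in signature order; build_parse_key, kernel, and the example parsers are hypothetical.

```python
import inspect
from typing import Any, Callable, Dict


def build_parse_key(func: Callable, arg_parser: Dict[str, Callable[[Any], tuple]]):
    """Hypothetical sketch: combine per-argument key parsers into one cache key."""
    sig = inspect.signature(func)

    def parse_key(*args: Any, **kws: Any) -> tuple:
        bound = sig.bind(*args, **kws)
        # Fill in defaults so that omitting an argument and passing its
        # default explicitly produce the same key.
        bound.apply_defaults()
        # Walk parameters in signature order so positional and keyword
        # calls agree.
        return tuple(arg_parser[name](bound.arguments[name]) for name in sig.parameters)

    return parse_key


def kernel(A, block=128):  # hypothetical kernel signature
    ...


# Hypothetical parsers: A contributes only its length, block its value.
parse_key = build_parse_key(kernel, {"A": lambda a: (len(a),), "block": lambda b: (b,)})
```

With this scheme, calls that differ only in values the parsers ignore (here, the contents of A) reuse one cache entry, while changing block produces a new key.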