tilelang.language.v2.annot
Attributes
- Scope
Classes
- Annot: Base class for tilelang kernel annotations.
- ArgVarTable: Manages the mapping from argument names to tir.Var objects.
- Value: Base class for tilelang kernel annotations.
- BufferAnnot: Base class for tilelang kernel annotations.
- TensorAnnot: Base class for tilelang kernel annotations.
- StridedTensorAnnot: Base class for tilelang kernel annotations.
- FragmentBufferAnnot: Base class for tilelang kernel annotations.
- …: Base class for tilelang kernel annotations.
- LocalBufferAnnot: Base class for tilelang kernel annotations.
- DynAnnot: Dynamic variable annotation; represents a TVM tir.Var argument.
- DTypeAnnot: Data type annotation; ensures automatic conversion from AnyDType to dtype.
- TIRAnnot: TIR annotation, used to pass a tir.Buffer or tir.Var directly as a kernel argument.
- Buffer: Abstract base class for generic types.

Module Contents
- tilelang.language.v2.annot.Scope
- class tilelang.language.v2.annot.Annot
Bases:
abc.ABC
Base class for tilelang kernel annotations. Tilelang kernel annotations are used to specify how to interpret each argument of the jit kernel.
It provides three main functionalities:
1. Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time).
2. Parse the argument value into a hash key for jit caching.
3. Convert the argument into a TVM TIR argument (tir.Var | tir.Buffer) for prim func generation.
- is_kernel_arg()
Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time).
- Return type:
- abstract with_name(name)
- Return type:
Self
- abstract get_key_parser()
Return a parser function that converts the argument value into a hash key for jit caching.
- Return type:
Callable[[str, Any], tuple[Any, ...]]
- abstract create_prim_func_arg(name, value, vt)
Convert the argument into a TVM TIR argument (tir.Var | tir.Buffer) for prim func generation.
- Parameters:
name (str)
value (Any)
vt (ArgVarTable)
- Return type:
tvm.tir.Var | tvm.tir.Buffer
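The three responsibilities of Annot can be illustrated with a self-contained sketch that does not import tilelang or TVM. `FakeVar` stands in for tvm.tir.Var, and `MiniAnnot` and `IntArgAnnot` are hypothetical names, not part of the library. The default `is_kernel_arg()` returning True is an assumption. The key-parser shape follows the documented `Callable[[str, Any], tuple[Any, ...]]` return type.

```python
import abc
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class FakeVar:
    """Stand-in for tvm.tir.Var (hypothetical, for illustration only)."""
    name: str
    dtype: str


class MiniAnnot(abc.ABC):
    """Hypothetical mirror of the three Annot responsibilities."""

    def is_kernel_arg(self) -> bool:
        # 1. Assumed default: the argument is passed at kernel launch time.
        return True

    @abc.abstractmethod
    def get_key_parser(self) -> Callable[[str, Any], Tuple[Any, ...]]:
        """2. Return a function mapping (name, value) to a cache-key tuple."""

    @abc.abstractmethod
    def create_prim_func_arg(self, name: str, value: Any, vt: Any) -> FakeVar:
        """3. Build the prim func parameter for this argument."""


class IntArgAnnot(MiniAnnot):
    """Hypothetical annotation for a runtime int32 scalar."""

    def get_key_parser(self):
        # Only the Python type keys the cache, so every int reuses one kernel.
        return lambda name, value: (name, type(value).__name__)

    def create_prim_func_arg(self, name, value, vt):
        # `vt` (the variable table) is unused in this simplified sketch.
        return FakeVar(name, "int32")


annot = IntArgAnnot()
parser = annot.get_key_parser()
print(annot.is_kernel_arg())                       # True
print(parser("n", 128))                            # ('n', 'int')
print(annot.create_prim_func_arg("n", 128, None))  # FakeVar(name='n', dtype='int32')
```

Because `get_key_parser` and `create_prim_func_arg` are abstract, instantiating `MiniAnnot` directly raises a TypeError. This matches the `abstract` markers on the corresponding Annot methods.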
- class tilelang.language.v2.annot.ArgVarTable
ArgVarTable manages the mapping from argument names to tir.Var objects.
- var_tab: dict[str, tvm.tir.Var]
- tmp_name_idx: int = 0
- get_or_create_var(name, dtype)
- Parameters:
name (str)
- Return type:
tvm.tir.Var
- create_tmp_name()
- Return type:
str
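Below is a minimal sketch of what `get_or_create_var` and `create_tmp_name` plausibly do. It assumes that asking for the same name twice returns the same variable, so that two buffers sharing a symbolic dimension such as `M` bind to one tir.Var. `FakeVar` stands in for tvm.tir.Var, `ArgVarTableSketch` is a hypothetical name, and the `tmp_<n>` naming scheme is an assumption.

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass(frozen=True)
class FakeVar:
    """Stand-in for tvm.tir.Var (hypothetical)."""
    name: str
    dtype: str


@dataclass
class ArgVarTableSketch:
    """Hypothetical mirror of ArgVarTable: one variable per argument name."""
    var_tab: Dict[str, FakeVar] = field(default_factory=dict)
    tmp_name_idx: int = 0

    def get_or_create_var(self, name: str, dtype: str) -> FakeVar:
        # Reuse an existing variable so every occurrence of `name` is unified.
        if name not in self.var_tab:
            self.var_tab[name] = FakeVar(name, dtype)
        return self.var_tab[name]

    def create_tmp_name(self) -> str:
        # Fresh name for an anonymous dimension (naming scheme assumed).
        name = f"tmp_{self.tmp_name_idx}"
        self.tmp_name_idx += 1
        return name


vt = ArgVarTableSketch()
m1 = vt.get_or_create_var("M", "int32")
m2 = vt.get_or_create_var("M", "int32")
print(m1 is m2)              # True: both uses of "M" share one variable
print(vt.create_tmp_name())  # tmp_0
print(vt.create_tmp_name())  # tmp_1
```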
- class tilelang.language.v2.annot.Value
Bases:
Annot
- kind: Literal['static', 'dynamic'] = 'dynamic'
- name: str | None = None
- dtype: tilelang.language.v2.dtypes.dtype | None
- value: int | tvm.tir.Var | None = None
- creator: Callable[[], Any] | None = None
- is_kernel_arg()
Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time).
- Return type:
- classmethod from_value(value, prefer_name=None)
- Parameters:
value (Any)
prefer_name (str)
- Return type:
- get_key_parser()
Return a parser function that converts the argument value into a hash key for jit caching.
- parse_key(target)
- Parameters:
target (Any)
- create_prim_func_arg(name, value, vt, create_arg=True)
Convert the argument into a TVM TIR argument (tir.Var | tir.Buffer) for prim func generation.
- Parameters:
name (str)
value (Any)
vt (ArgVarTable)
create_arg (bool)
- __repr__()
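The `kind` field separates static values from dynamic ones. One plausible reading, offered as an assumption rather than documented behavior: a static value is baked into the compiled kernel, so its concrete value must be part of the jit cache key. A dynamic value is passed at launch time, so only its type needs to be. The sketch below uses a hypothetical `ValueSketch` class and is not tilelang's implementation.

```python
from dataclasses import dataclass
from typing import Any, Literal, Optional


@dataclass
class ValueSketch:
    """Hypothetical stand-in for Value; not tilelang's implementation."""
    kind: Literal["static", "dynamic"] = "dynamic"
    dtype: Optional[str] = None

    def is_kernel_arg(self) -> bool:
        # Assumption: only dynamic values are passed at kernel launch time.
        return self.kind == "dynamic"

    def parse_key(self, target: Any) -> tuple:
        if self.kind == "static":
            # A static value changes the generated code, so it keys the cache.
            return (target,)
        # A dynamic value reuses one kernel; only its type keys the cache.
        return (type(target).__name__,)


static_n = ValueSketch(kind="static")
dynamic_n = ValueSketch(kind="dynamic", dtype="int32")
print(static_n.parse_key(64) == static_n.parse_key(128))    # False: two kernels
print(dynamic_n.parse_key(64) == dynamic_n.parse_key(128))  # True: one kernel
```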
- class tilelang.language.v2.annot.BufferAnnot
Bases:
Annot
- shape: tuple = None
- strides: tuple = None
- dtype: tilelang.language.v2.dtypes.dtype = None
- is_kernel_arg()
Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time).
- Return type:
- property scope
- __call__(shape, dtype='float32', data=None, strides=None, elem_offset=None, scope=None, align=0, offset_factor=0, buffer_type='', axis_separators=None)
- Parameters:
shape (tuple[Unpack[_Shapes]])
dtype (_DType)
- Return type:
Tensor[Callable[[Unpack[_Shapes]]], _DType]
- __getitem__(params)
- with_name(name)
- Parameters:
name (str)
- get_key_parser()
Return a parser function that converts the argument value into a hash key for jit caching.
- parse_key(target)
- Parameters:
target (Any)
- static match_shape(shape, target_shape, vt)
- Parameters:
shape (tuple[Value, ...])
target_shape (tuple[int, ...])
vt (ArgVarTable)
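match_shape receives the annotated shape (a tuple of Value), a concrete target_shape, and an ArgVarTable. The function below is a minimal sketch of symbolic shape matching under stated assumptions, not tilelang's code. Integer entries must match exactly, and string entries (standing in for dynamic Value dimensions) must bind consistently across arguments. A plain dict plays the role of ArgVarTable, and `match_shape_sketch` is a hypothetical name.

```python
from typing import Dict, Tuple, Union

# int: fixed extent; str: name of a symbolic dimension (stand-in for Value).
Dim = Union[int, str]


def match_shape_sketch(shape: Tuple[Dim, ...], target_shape: Tuple[int, ...],
                       bindings: Dict[str, int]) -> Dict[str, int]:
    """Bind the symbolic dims of `shape` to the extents in `target_shape`."""
    if len(shape) != len(target_shape):
        raise ValueError(f"rank mismatch: {len(shape)} vs {len(target_shape)}")
    for dim, extent in zip(shape, target_shape):
        if isinstance(dim, int):
            # Fixed dimension: the runtime extent must agree exactly.
            if dim != extent:
                raise ValueError(f"expected extent {dim}, got {extent}")
        elif dim in bindings:
            # A symbol seen before must bind to the same extent again.
            if bindings[dim] != extent:
                raise ValueError(f"{dim} is {bindings[dim]}, got {extent}")
        else:
            # First occurrence of the symbol: record its extent.
            bindings[dim] = extent
    return bindings


table: Dict[str, int] = {}
match_shape_sketch(("M", "K"), (128, 64), table)  # A has shape (M, K)
match_shape_sketch(("K", 32), (64, 32), table)    # B has shape (K, 32); K reused
print(table)                                      # {'M': 128, 'K': 64}
```

Matching a third argument annotated ("K", 32) against a tensor of shape (32, 32) would raise, because K is already bound to 64.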
- create_prim_func_arg(name, value, vt)
Convert the argument into a TVM TIR argument (tir.Var | tir.Buffer) for prim func generation.
- Parameters:
name (str)
value (Any)
vt (ArgVarTable)
- promote()
Try to promote the annotation into a FixedAnnot if possible. Return None if the annotation is not promotable.
- class tilelang.language.v2.annot.TensorAnnot
Bases:
BufferAnnot
- __call__(shape, dtype='float32', data=None, strides=None, elem_offset=None, scope=None, align=0, offset_factor=0, buffer_type='', axis_separators=None)
- Parameters:
shape (tuple[Unpack[_Shapes]])
dtype (_DType)
- promote()
Try to promote the annotation into a FixedAnnot if possible. Return None if the annotation is not promotable.
- class tilelang.language.v2.annot.StridedTensorAnnot
Bases:
BufferAnnot
- __call__(shape, strides, dtype='float32', data=None, elem_offset=None, scope=None, align=0, offset_factor=0, buffer_type='', axis_separators=None)
- Parameters:
dtype (_DType)
- __getitem__(params)
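StridedTensorAnnot.__call__ takes `strides` as a required argument, whereas TensorAnnot.__call__ defaults it to None. For reference, the helper below computes the strides of a contiguous row-major tensor, measured in elements as torch.Tensor.stride() reports them. Whether tilelang assumes this layout when strides is None is an assumption here, not something this page states. `contiguous_strides` is a hypothetical helper.

```python
from typing import Tuple


def contiguous_strides(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Row-major (C-order) strides, measured in elements."""
    strides = []
    step = 1
    # Walk from the innermost dimension outwards, accumulating the step size.
    for extent in reversed(shape):
        strides.append(step)
        step *= extent
    return tuple(reversed(strides))


print(contiguous_strides((4, 8, 16)))  # (128, 16, 1)
print(contiguous_strides((8, 16)))     # (16, 1)
```

A transposed view of a (16, 8) row-major tensor has shape (8, 16) but strides (1, 8), not (16, 1). Layouts like this one cannot be described by shape alone and need explicit strides.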
- class tilelang.language.v2.annot.FragmentBufferAnnot
Bases:
BufferAnnot
- property scope
- class tilelang.language.v2.annot.…
Bases:
BufferAnnot
- class tilelang.language.v2.annot.LocalBufferAnnot
Bases:
BufferAnnot
- property scope
- class tilelang.language.v2.annot.DynAnnot
Bases:
Value
Dynamic variable annotation; represents a TVM tir.Var argument.
- __call__(dtype=dt.float32, name=None)
- Parameters:
dtype (tilelang.language.v2.dtypes.AnyDType)
name (str | None)
- Return type:
- __getitem__(params)
- class tilelang.language.v2.annot.DTypeAnnot
Bases:
Annot
Data type annotation; ensures automatic conversion from AnyDType to dtype.
>>> def foo(A: T.dtype): print(A)
>>> foo(torch.float32)
dtype('float32')
>>> foo(T.float32)
dtype('float32')
>>> foo('float32')
dtype('float32')
- name: str | None = None
- is_kernel_arg()
Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time).
- Return type:
- with_name(name)
- get_key_parser()
Return a parser function that converts the argument value into a hash key for jit caching.
- create_prim_func_arg(name, value, vt)
Convert the argument into a TVM TIR argument (tir.Var | tir.Buffer) for prim func generation.
- __repr__()
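The DTypeAnnot docstring shows torch.float32, T.float32 and 'float32' all normalizing to dtype('float32'). The sketch below reproduces that idea without importing torch or tilelang. It assumes the canonical form is the bare type name and relies on the fact that torch dtypes stringify as 'torch.float32'. `normalize_dtype` and `FakeTorchDType` are hypothetical. This string-based approach is a simplification; a real implementation would more likely use an explicit lookup table.

```python
def normalize_dtype(value) -> str:
    """Map 'float32', 'torch.float32'-like objects, etc. to 'float32'."""
    text = value if isinstance(value, str) else str(value)
    # Drop a framework prefix such as 'torch.' if present.
    return text.rsplit(".", 1)[-1]


class FakeTorchDType:
    """Stand-in for torch.float32, whose str() is 'torch.float32'."""

    def __str__(self) -> str:
        return "torch.float32"


print(normalize_dtype("float32"))         # float32
print(normalize_dtype(FakeTorchDType()))  # float32
```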
- class tilelang.language.v2.annot.TIRAnnot
Bases:
Annot
TIR annotation, used to pass a tir.Buffer or tir.Var directly as a kernel argument.
>>> def foo(A: T.Buffer((128,), T.float32)): ...
- data: tvm.tir.Buffer | tvm.tir.Var
- is_kernel_arg()
Determine whether the argument is a kernel argument (i.e., needs to be passed at kernel launch time).
- Return type:
- get_key_parser()
Return a parser function that converts the argument value into a hash key for jit caching.
- create_prim_func_arg(name, value, vt)
Convert the argument into a TVM TIR argument (tir.Var | tir.Buffer) for prim func generation.
- with_name(name)
- Parameters:
name (str)
- __repr__()
- class tilelang.language.v2.annot.Buffer(dtype='float32', data=None, strides=None, elem_offset=None, scope=None, align=0, offset_factor=0, buffer_type='', axis_separators=None)
Bases:
Generic[_Shape, _DType]
Abstract base class for generic types.
A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as:
    class Mapping(Generic[KT, VT]):
        def __getitem__(self, key: KT) -> VT:
            ...
        # Etc.
This class can then be used as follows:
    def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
        try:
            return mapping[key]
        except KeyError:
            return default
- Parameters:
dtype (_DType)
- property shape: tuple[Unpack[_Shapes]]
- Return type:
tuple[Unpack[_Shapes]]
- property dtype: tilelang.language.v2.dtypes.dtype[_DType]
- Return type:
- property strides: tuple[tvm.tir.PrimExpr]
- Return type:
tuple[tvm.tir.PrimExpr]
- scope()
- Return type:
Scope
- class tilelang.language.v2.annot.FuncAnnot
- sig: inspect.Signature
- arg_names: list[str]
- arg_parser: dict[str, Callable[[Any], tuple[Any, ...]]]
- ker_arg_names: list[str]
- classmethod from_sig_annots(sig, func_annots)
- Parameters:
sig (inspect.Signature)
func_annots (dict[str, Any])
- Return type:
- parse_key(*args, **kws)
Parse the arguments and generate the cache key for jit caching.
- convert_to_kernel_args(*args, **kws)
- create_argument(name, value, vt)
Convert the argument into a TVM TIR argument (tir.Var | tir.Buffer) for prim func generation.
- Parameters:
name (str)
value (Any)
vt (ArgVarTable)
- is_all_static()
Check whether all arguments are static (i.e., can be fully determined at compile time).
- get_all_static_args()
- get_compile_time_unknown_args()
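FuncAnnot stores the function signature (`sig`), the argument names, and a per-argument parser (`arg_parser`). The sketch below shows how `parse_key` can combine them, under these assumptions: arguments are bound with inspect.Signature.bind, defaults are applied, and each parser's tuple is concatenated into one hashable key. `FuncAnnotSketch`, its parsers, and `kernel` are hypothetical, not tilelang code.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


class FuncAnnotSketch:
    """Hypothetical mirror of FuncAnnot.parse_key; not tilelang code."""

    def __init__(self, func: Callable,
                 parsers: Dict[str, Callable[[Any], Tuple[Any, ...]]]):
        self.sig = inspect.signature(func)
        self.arg_names = list(self.sig.parameters)
        self.arg_parser = parsers

    def parse_key(self, *args, **kws) -> Tuple[Any, ...]:
        bound = self.sig.bind(*args, **kws)
        bound.apply_defaults()  # omitted arguments fall back to their defaults
        key = []
        for name in self.arg_names:
            # Each parser returns a tuple fragment; fragments are concatenated.
            key.extend(self.arg_parser[name](bound.arguments[name]))
        return tuple(key)


def kernel(shape, block_m=64):
    ...


fa = FuncAnnotSketch(kernel, {
    "shape": lambda v: (tuple(v),),  # lists become tuples so the key is hashable
    "block_m": lambda v: (v,),
})
print(fa.parse_key((128, 64)))  # ((128, 64), 64)
print(fa.parse_key(shape=[128, 64], block_m=64) == fa.parse_key((128, 64)))  # True
```

Binding through the signature means positional, keyword, and defaulted calls that describe the same arguments produce the same key. This keeps equivalent calls from compiling duplicate kernels.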