tilelang.language.v2.annot
==========================

.. py:module:: tilelang.language.v2.annot


Attributes
----------

.. autoapisummary::

   tilelang.language.v2.annot.Scope


Classes
-------

.. autoapisummary::

   tilelang.language.v2.annot.Annot
   tilelang.language.v2.annot.ArgVarTable
   tilelang.language.v2.annot.Value
   tilelang.language.v2.annot.BufferAnnot
   tilelang.language.v2.annot.TensorAnnot
   tilelang.language.v2.annot.StridedTensorAnnot
   tilelang.language.v2.annot.FragmentBufferAnnot
   tilelang.language.v2.annot.SharedBufferAnnot
   tilelang.language.v2.annot.LocalBufferAnnot
   tilelang.language.v2.annot.DynAnnot
   tilelang.language.v2.annot.DTypeAnnot
   tilelang.language.v2.annot.TIRAnnot
   tilelang.language.v2.annot.Buffer
   tilelang.language.v2.annot.FuncAnnot


Module Contents
---------------

.. py:data:: Scope

.. py:class:: Annot

   Bases: :py:obj:`abc.ABC`


   Base class for tilelang kernel annotations.

   A tilelang kernel annotation specifies how each argument of a JIT kernel is interpreted. Every annotation provides three main functions:

   1. Determine whether the argument is a kernel argument (i.e., whether it must be passed at kernel launch time).
   2. Parse the argument value into a hash key for JIT caching.
   3. Convert the argument into a TVM TIR argument (``tir.Var`` or ``tir.Buffer``) for prim func generation.


   .. py:method:: is_kernel_arg()

      Determine whether the argument is a kernel argument (i.e., whether it must be passed at kernel launch time).


   .. py:method:: with_name(name)
      :abstractmethod:


   .. py:method:: get_key_parser()
      :abstractmethod:

      Return a parser function that converts the argument value into a hash key for JIT caching.


   .. py:method:: create_prim_func_arg(name, value, vt)
      :abstractmethod:

      Convert the argument into a TVM TIR argument (``tir.Var`` or ``tir.Buffer``) for prim func generation.


   .. py:method:: promote()

      Try to promote the annotation into a FixedAnnot if possible. Return None if it cannot be promoted.


.. py:class:: ArgVarTable

   Manages the mapping from argument names to ``tir.Var`` objects.


   .. py:attribute:: var_tab
      :type: dict[str, tvm.tir.Var]


   .. py:attribute:: tmp_name_idx
      :type: int
      :value: 0


   .. py:method:: get_or_create_var(name, dtype)


   .. py:method:: create_tmp_name()


.. py:class:: Value

   Bases: :py:obj:`Annot`


   Subclass of :py:class:`Annot`; see that class for the annotation interface it implements.


   .. py:attribute:: kind
      :type: Literal['static', 'dynamic']
      :value: 'dynamic'


   .. py:attribute:: name
      :type: str | None
      :value: None


   .. py:attribute:: dtype
      :type: tilelang.language.v2.dtypes.dtype | None


   .. py:attribute:: value
      :type: int | tvm.tir.Var | None
      :value: None


   .. py:attribute:: creator
      :type: Callable[[], Any] | None
      :value: None


   .. py:method:: is_kernel_arg()

      Determine whether the argument is a kernel argument (i.e., whether it must be passed at kernel launch time).


   .. py:method:: from_value(value, prefer_name = None)
      :classmethod:


   .. py:method:: with_name(name)


   .. py:method:: get_key_parser()

      Return a parser function that converts the argument value into a hash key for JIT caching.


   .. py:method:: parse_key(target)


   .. py:method:: create_prim_func_arg(name, value, vt, create_arg = True)

      Convert the argument into a TVM TIR argument (``tir.Var`` or ``tir.Buffer``) for prim func generation.


   .. py:method:: __repr__()


.. py:class:: BufferAnnot

   Bases: :py:obj:`Annot`


   Subclass of :py:class:`Annot`; see that class for the annotation interface it implements.


   .. py:attribute:: shape
      :type: tuple
      :value: None


   .. py:attribute:: strides
      :type: tuple
      :value: None


   .. py:attribute:: dtype
      :type: tilelang.language.v2.dtypes.dtype
      :value: None


   .. py:method:: is_kernel_arg()

      Determine whether the argument is a kernel argument (i.e., whether it must be passed at kernel launch time).


   .. py:property:: scope


   .. py:method:: __call__(shape, dtype = 'float32', data=None, strides=None, elem_offset=None, scope=None, align=0, offset_factor=0, buffer_type='', axis_separators=None)


   .. py:method:: __getitem__(params)


   .. py:method:: with_name(name)


   .. py:method:: get_key_parser()

      Return a parser function that converts the argument value into a hash key for JIT caching.


   .. py:method:: parse_key(target)


   .. py:method:: match_shape(shape, target_shape, vt)
      :staticmethod:


   .. py:method:: create_prim_func_arg(name, value, vt)

      Convert the argument into a TVM TIR argument (``tir.Var`` or ``tir.Buffer``) for prim func generation.


   .. py:method:: promote()

      Try to promote the annotation into a FixedAnnot if possible. Return None if it cannot be promoted.


.. py:class:: TensorAnnot

   Bases: :py:obj:`BufferAnnot`


   Subclass of :py:class:`BufferAnnot`; see :py:class:`Annot` for the annotation interface it implements.


   .. py:method:: __call__(shape, dtype = 'float32', data=None, strides=None, elem_offset=None, scope=None, align=0, offset_factor=0, buffer_type='', axis_separators=None)


   .. py:method:: promote()

      Try to promote the annotation into a FixedAnnot if possible. Return None if it cannot be promoted.


.. py:class:: StridedTensorAnnot

   Bases: :py:obj:`BufferAnnot`


   Subclass of :py:class:`BufferAnnot`; see :py:class:`Annot` for the annotation interface it implements.


   .. py:method:: __call__(shape, strides, dtype = 'float32', data=None, elem_offset=None, scope=None, align=0, offset_factor=0, buffer_type='', axis_separators=None)


   .. py:method:: __getitem__(params)


.. py:class:: FragmentBufferAnnot

   Bases: :py:obj:`BufferAnnot`


   Subclass of :py:class:`BufferAnnot`; see :py:class:`Annot` for the annotation interface it implements.


   .. py:property:: scope


.. py:class:: SharedBufferAnnot

   Bases: :py:obj:`BufferAnnot`


   Subclass of :py:class:`BufferAnnot`; see :py:class:`Annot` for the annotation interface it implements.


   .. py:property:: scope


.. py:class:: LocalBufferAnnot

   Bases: :py:obj:`BufferAnnot`


   Subclass of :py:class:`BufferAnnot`; see :py:class:`Annot` for the annotation interface it implements.


   .. py:property:: scope


.. py:class:: DynAnnot

   Bases: :py:obj:`Value`


   Dynamic variable annotation; it represents a TVM ``tir.Var`` argument.


   .. py:method:: __call__(dtype = dt.float32, name = None)


   .. py:method:: __getitem__(params)


.. py:class:: DTypeAnnot

   Bases: :py:obj:`Annot`


   Data type annotation that automatically converts any ``AnyDType`` value into a ``dtype``.

   >>> def foo(A: T.dtype): print(A)
   >>> foo(torch.float32)
   dtype('float32')
   >>> foo(T.float32)
   dtype('float32')
   >>> foo('float32')
   dtype('float32')


   .. py:attribute:: name
      :type: str | None
      :value: None


   .. py:method:: is_kernel_arg()

      Determine whether the argument is a kernel argument (i.e., whether it must be passed at kernel launch time).


   .. py:method:: with_name(name)


   .. py:method:: get_key_parser()

      Return a parser function that converts the argument value into a hash key for JIT caching.


   .. py:method:: create_prim_func_arg(name, value, vt)

      Convert the argument into a TVM TIR argument (``tir.Var`` or ``tir.Buffer``) for prim func generation.


   .. py:method:: __repr__()


.. py:class:: TIRAnnot

   Bases: :py:obj:`Annot`


   TIR annotation, used to pass a ``tir.Buffer`` or ``tir.Var`` directly as a kernel argument.

   >>> def foo(A: T.Buffer((128,), T.float32)): ...


   .. py:attribute:: data
      :type: tvm.tir.Buffer | tvm.tir.Var


   .. py:method:: is_kernel_arg()

      Determine whether the argument is a kernel argument (i.e., whether it must be passed at kernel launch time).


   .. py:method:: get_key_parser()

      Return a parser function that converts the argument value into a hash key for JIT caching.


   .. py:method:: create_prim_func_arg(name, value, vt)

      Convert the argument into a TVM TIR argument (``tir.Var`` or ``tir.Buffer``) for prim func generation.


   .. py:method:: with_name(name)


   .. py:method:: __repr__()


.. py:class:: Buffer(dtype = 'float32', data=None, strides=None, elem_offset=None, scope=None, align=0, offset_factor=0, buffer_type='', axis_separators=None)

   Bases: :py:obj:`Generic`\ [\ :py:obj:`_Shape`\ , :py:obj:`_DType`\ ]


   .. py:property:: shape
      :type: tuple[Unpack[_Shapes]]


   .. py:property:: dtype
      :type: tilelang.language.v2.dtypes.dtype[_DType]


   .. py:property:: strides
      :type: tuple[tvm.tir.PrimExpr]


   .. py:method:: scope()


.. py:class:: FuncAnnot

   .. py:attribute:: sig
      :type: inspect.Signature


   .. py:attribute:: arg_names
      :type: list[str]


   .. py:attribute:: annots
      :type: dict[str, Annot]


   .. py:attribute:: arg_parser
      :type: dict[str, Callable[[Any], tuple[Any, Ellipsis]]]


   .. py:attribute:: ker_arg_names
      :type: list[str]


   .. py:method:: from_sig_annots(sig, func_annots)
      :classmethod:


   .. py:method:: parse_key(*args, **kws)

      Parse the arguments and generate the cache key for JIT caching.


   .. py:method:: convert_to_kernel_args(*args, **kws)


   .. py:method:: create_argument(name, value, vt)

      Convert the argument into a TVM TIR argument (``tir.Var`` or ``tir.Buffer``) for prim func generation.


   .. py:method:: is_all_static()

      Check whether all arguments are static (i.e., can be fully determined at compile time).


   .. py:method:: get_all_static_args()


   .. py:method:: get_compile_time_unknown_args()
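
A small pure-Python sketch can show how :py:meth:`Annot.is_kernel_arg`, :py:meth:`Annot.get_key_parser` and :py:meth:`FuncAnnot.parse_key` might interact in a JIT cache. It does not use tilelang, and it is not tilelang's implementation. ``MiniAnnot``, ``StaticInt``, ``DynInt``, ``parse_key`` and ``get_kernel`` are hypothetical names. The sketch also assumes that a compile-time constant contributes its value to the cache key, while a dynamic scalar only becomes a launch argument.

.. code-block:: python

   from __future__ import annotations

   import abc
   from typing import Any, Callable


   class MiniAnnot(abc.ABC):
       """Hypothetical stand-in for Annot (not tilelang's class)."""

       def is_kernel_arg(self) -> bool:
           # By default an annotated argument is a compile-time parameter.
           return False

       @abc.abstractmethod
       def get_key_parser(self) -> Callable[[Any], tuple]:
           """Return a function mapping an argument value to its part of the cache key."""


   class StaticInt(MiniAnnot):
       # Compile-time constant: baked into the generated code, so it is part of the key.
       def get_key_parser(self) -> Callable[[Any], tuple]:
           return lambda value: (int(value),)


   class DynInt(MiniAnnot):
       # Runtime scalar (the role a tir.Var plays): passed at launch, excluded from the key.
       def is_kernel_arg(self) -> bool:
           return True

       def get_key_parser(self) -> Callable[[Any], tuple]:
           return lambda value: ()


   ANNOTS: dict[str, MiniAnnot] = {"block": StaticInt(), "n": DynInt()}
   _cache: dict[tuple, str] = {}
   compile_count = 0


   def parse_key(**kws: Any) -> tuple:
       # Concatenate each argument's key fragment in declaration order.
       key: list = []
       for name, annot in ANNOTS.items():
           key.extend(annot.get_key_parser()(kws[name]))
       return tuple(key)


   def get_kernel(**kws: Any) -> tuple[str, list]:
       global compile_count
       key = parse_key(**kws)
       if key not in _cache:
           compile_count += 1  # cache miss: "compile" a new specialization
           _cache[key] = f"kernel<block={kws['block']}>"
       # Only kernel arguments are forwarded at launch time.
       launch_args = [kws[name] for name, annot in ANNOTS.items() if annot.is_kernel_arg()]
       return _cache[key], launch_args


   kernel_a, args_a = get_kernel(block=128, n=1000)
   kernel_b, args_b = get_kernel(block=128, n=4096)  # same key: reuses kernel_a
   kernel_c, args_c = get_kernel(block=64, n=1000)   # new key: second compile

Under these assumptions, changing ``n`` reuses the cached kernel and only changes the launch arguments. Changing ``block`` triggers a new compilation.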
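
:py:class:`ArgVarTable` maps argument names to ``tir.Var`` objects through ``var_tab``, ``get_or_create_var`` and ``create_tmp_name``. The sketch below assumes that its purpose is to let buffers that name the same symbolic dimension share one variable. ``FakeVar`` and ``MiniArgVarTable`` are hypothetical: ``FakeVar`` stands in for ``tvm.tir.Var`` so the code runs without TVM. The ``tmp_<n>`` naming scheme is an illustrative choice, not necessarily tilelang's.

.. code-block:: python

   from __future__ import annotations

   from dataclasses import dataclass, field


   @dataclass(frozen=True)
   class FakeVar:
       """Placeholder for tvm.tir.Var so the sketch runs without TVM."""
       name: str
       dtype: str


   @dataclass
   class MiniArgVarTable:
       """Hypothetical re-creation of ArgVarTable's bookkeeping."""
       var_tab: dict[str, FakeVar] = field(default_factory=dict)
       tmp_name_idx: int = 0

       def get_or_create_var(self, name: str, dtype: str) -> FakeVar:
           # Return the existing variable for this name, creating it on first use,
           # so every buffer that mentions dimension "K" receives the same object.
           if name not in self.var_tab:
               self.var_tab[name] = FakeVar(name, dtype)
           return self.var_tab[name]

       def create_tmp_name(self) -> str:
           # Produce a fresh name for an anonymous symbolic dimension.
           name = f"tmp_{self.tmp_name_idx}"
           self.tmp_name_idx += 1
           return name


   vt = MiniArgVarTable()
   # A has shape (M, K); B has shape (K, <anonymous>).
   a_shape = (vt.get_or_create_var("M", "int32"), vt.get_or_create_var("K", "int32"))
   b_shape = (vt.get_or_create_var("K", "int32"),
              vt.get_or_create_var(vt.create_tmp_name(), "int32"))

In this sketch, reusing one object per name means the shared dimension ``K`` is represented once. Both shapes therefore refer to the same symbol.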