tilelang.jit.kernel
Attributes
- logger
Classes
- JITKernel: A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.
Module Contents
- tilelang.jit.kernel.logger
- class tilelang.jit.kernel.JITKernel(func=None, out_idx=None, execution_backend='tvm_ffi', target='auto', target_host=None, verbose=False, pass_configs=None, from_database=False, compile_flags=None)
Bases: Generic[_P, _T]
A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.
- Parameters:
func (tvm.tir.PrimFunc)
out_idx (list[int] | int)
execution_backend (Literal['tvm_ffi', 'ctypes', 'cython', 'nvrtc', 'torch'])
target (str | tvm.target.Target)
target_host (str | tvm.target.Target)
verbose (bool)
pass_configs (dict[str, Any] | None)
from_database (bool)
compile_flags (list[str] | None)
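The out_idx parameter accepts either a single index or a list of indices into the PrimFunc's parameter list; TileLang code often writes it as a negative index such as -1 or [-1] to mark the last parameter as an output. The sketch below shows one way such an argument can be normalized into non-negative positions, assuming Python-style negative indexing and treating None as "no outputs". normalize_out_idx is a hypothetical helper for illustration, not part of TileLang's API.

```python
from typing import List, Optional, Sequence, Union


def normalize_out_idx(out_idx: Optional[Union[int, Sequence[int]]], num_params: int) -> List[int]:
    """Hypothetical helper: turn an int or a list of ints into non-negative positions.

    Negative values count from the end, as in Python sequence indexing.
    This illustrates the convention only; it is not TileLang's own code.
    """
    if out_idx is None:
        return []  # assumption of this sketch: no index means no outputs
    if isinstance(out_idx, int):
        out_idx = [out_idx]
    positions = []
    for idx in out_idx:
        pos = idx + num_params if idx < 0 else idx
        if not 0 <= pos < num_params:
            raise ValueError(f"out_idx {idx} is out of range for {num_params} parameters")
        positions.append(pos)
    return positions


# A GEMM-like PrimFunc with parameters (A, B, C), where C is the output.
print(normalize_out_idx(-1, 3))       # [2]
print(normalize_out_idx([0, -1], 3))  # [0, 2]
```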
- artifact
The compiled artifact containing the runtime module and parameters.
- Type:
tilelang.engine.param.CompiledArtifact
- adapter
The adapter for the compiled function.
- Type:
tilelang.jit.adapter.BaseKernelAdapter
- torch_function
The compiled function that can be invoked as a PyTorch-compatible function.
- Type:
Callable
- prim_func: tvm.tir.PrimFunc = None
- artifact: tilelang.engine.param.CompiledArtifact = None
- adapter: tilelang.jit.adapter.BaseKernelAdapter = None
- torch_function: Callable = None
- latency: float = None
- config: dict[str, Any] = None
- ref_latency: float = None
- execution_backend = 'tvm_ffi'
- target_host = None
- verbose = False
- pass_configs = None
- compile_flags = None
- target = 'auto'
- classmethod from_database(func, host_kernel_source, device_kernel_source, kernel_lib_path, params, target, target_host, out_idx, execution_backend, pass_configs=None, compile_flags=None)
Alternative constructor that creates a JITKernel directly from a database entry: previously generated host and device kernel sources plus the path to the compiled kernel library.
- Parameters:
func (tvm.tir.PrimFunc)
host_kernel_source (str)
device_kernel_source (str)
kernel_lib_path (str)
params (list[tilelang.engine.param.KernelParam])
target (str | tvm.target.Target)
target_host (str | tvm.target.Target)
out_idx (list[int] | int)
execution_backend (Literal['tvm_ffi', 'ctypes', 'cython', 'nvrtc', 'torch'])
pass_configs (dict[str, Any] | None)
compile_flags (list[str] | None)
- __call__(*args, **kwds)
Invokes the compiled function with the given arguments.
- Parameters:
*args (Any): Positional arguments for the function.
**kwds (Any): Keyword arguments for the function.
- Returns:
The result of the function execution.
- Return type:
Any
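Typically, when out_idx is set, the parameters it names are treated as outputs: the caller passes only the remaining arguments, and the wrapper allocates the output buffers and returns them. The pure-Python sketch below models that calling convention with lists standing in for tensors. call_with_outputs, vec_add, and the fixed-size list allocator are hypothetical, and returning a bare value for a single output (a list of outputs otherwise) is an assumption of the sketch rather than a documented guarantee.

```python
from typing import Any, Callable, List, Sequence


def call_with_outputs(body: Callable[..., None], num_params: int,
                      out_positions: Sequence[int],
                      alloc: Callable[[], Any], *inputs: Any) -> Any:
    """Hypothetical model of an out_idx-style calling convention."""
    expected = num_params - len(out_positions)
    if len(inputs) != expected:
        raise TypeError(f"expected {expected} inputs, got {len(inputs)}")
    args: List[Any] = []
    outputs: List[Any] = []
    remaining = iter(inputs)
    for pos in range(num_params):
        if pos in out_positions:
            buf = alloc()  # the wrapper, not the caller, allocates each output
            outputs.append(buf)
            args.append(buf)
        else:
            args.append(next(remaining))
    body(*args)  # the "kernel" writes its results into the output buffers
    return outputs[0] if len(outputs) == 1 else outputs


def vec_add(a, b, c):
    # Stand-in kernel body: c[i] = a[i] + b[i]
    for i in range(len(a)):
        c[i] = a[i] + b[i]


# Parameters (a, b, c) with c at position 2 as the output.
c = call_with_outputs(vec_add, 3, [2], lambda: [0] * 4, [1, 2, 3, 4], [10, 20, 30, 40])
print(c)  # [11, 22, 33, 44]
```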
- classmethod from_tilelang_function(tilelang_func, **kwargs)
Alternative constructor that creates a JITKernel directly from a TileLang PrimFunc.
- Parameters:
tilelang_func (tvm.tir.PrimFunc): The TileLang (TVM TIR) function to compile.
**kwargs (dict): Additional keyword arguments to pass to the constructor.
- Returns:
A JITKernel instance wrapping the compiled function.
- Return type:
JITKernel
- get_profiler(tensor_supply_type=TensorSupplyType.Auto)
Creates a profiler to benchmark the compiled runtime module.
- Parameters:
tensor_supply_type (TensorSupplyType, optional): The type of input tensors to supply for profiling (default: TensorSupplyType.Auto).
- Returns:
A Profiler instance for benchmarking the runtime module.
- Return type:
Profiler
- get_kernel_source(kernel_only=True)
Returns the source code of the compiled kernel function.
- Returns:
The source code of the compiled kernel function.
- Return type:
str
- Parameters:
kernel_only (bool)
- get_host_source()
Returns the source code of the host function.
- Return type:
str
- run_once(func=None)
- Parameters:
func (Callable | None)
- Return type:
None
- show_source(which='kernel')
Print generated source code to stdout.
- Parameters:
which (Literal["kernel", "host", "both"], optional): Select which source to print. Defaults to "kernel".
- Return type:
None
Examples
>>> jit_kernel.show_source()        # print kernel source
>>> jit_kernel.show_source("host")  # print host source
>>> jit_kernel.show_source("both")  # print both sources
- export_sources(kernel_path=None, host_path=None)
Export generated source code to files.
- Parameters:
kernel_path (Optional[str]): Destination file path for the kernel source. If None, the kernel source is not written.
host_path (Optional[str]): Destination file path for the host source. If None, the host source is not written.
- Return type:
None
Examples
>>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
>>> jit_kernel.export_sources(host_path="/tmp/host.cc")
>>> jit_kernel.export_sources(
...     kernel_path="/tmp/kernel.cu",
...     host_path="/tmp/host.cc",
... )
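The None-means-skip behavior of kernel_path and host_path can be sketched in plain Python. export_sources_sketch is a hypothetical function that takes the source strings as arguments; the documented method takes only the paths and writes the sources held by the compiled kernel.

```python
import os
import tempfile
from typing import Optional


def export_sources_sketch(kernel_source: str, host_source: str,
                          kernel_path: Optional[str] = None,
                          host_path: Optional[str] = None) -> None:
    """Hypothetical sketch: write each source only when a destination path is given."""
    for source, path in ((kernel_source, kernel_path), (host_source, host_path)):
        if path is None:
            continue  # None means "skip this source"
        with open(path, "w", encoding="utf-8") as f:
            f.write(source)


with tempfile.TemporaryDirectory() as tmp:
    kernel_file = os.path.join(tmp, "kernel.cu")
    # Only kernel_path is given, so no host file is created.
    export_sources_sketch("__global__ void k() {}", "int host() { return 0; }",
                          kernel_path=kernel_file)
    print(sorted(os.listdir(tmp)))  # ['kernel.cu']
```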
- print_source_code(which='kernel', file=None)
Deprecated: use show_source() or export_sources() instead.
- Parameters:
which (Literal["kernel", "host", "both"], optional): Kept for backward compatibility with printing behavior.
file (Optional[str]): If provided, behaves like export_sources(kernel_path=file).
- Return type:
None
Examples
>>> # New API (preferred)
>>> jit_kernel.show_source("both")
>>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
>>> # Old API (still works but deprecated)
>>> jit_kernel.print_source_code(file="/tmp/kernel.cu")
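A common way to keep a deprecated method working is a thin shim that warns and then forwards to its replacements. The sketch below applies that pattern to the forwarding rule documented above (file maps to export_sources(kernel_path=file)); sending which to show_source when no file is given, and emitting a DeprecationWarning specifically, are assumptions of this sketch. LegacySourceAPI is hypothetical and records calls instead of printing or writing files.

```python
import warnings
from typing import List, Optional, Tuple


class LegacySourceAPI:
    """Hypothetical stand-in that records calls instead of printing or writing files."""

    def __init__(self) -> None:
        self.calls: List[Tuple] = []

    def show_source(self, which: str = "kernel") -> None:
        self.calls.append(("show_source", which))

    def export_sources(self, kernel_path: Optional[str] = None,
                       host_path: Optional[str] = None) -> None:
        self.calls.append(("export_sources", kernel_path, host_path))

    def print_source_code(self, which: str = "kernel", file: Optional[str] = None) -> None:
        # Deprecated alias: warn, then forward to the newer methods.
        warnings.warn("print_source_code() is deprecated; use show_source() or export_sources()",
                      DeprecationWarning, stacklevel=2)
        if file is not None:
            self.export_sources(kernel_path=file)
        else:
            self.show_source(which)


api = LegacySourceAPI()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is recorded
    api.print_source_code(file="/tmp/kernel.cu")
print(api.calls)                    # [('export_sources', '/tmp/kernel.cu', None)]
print(caught[0].category.__name__)  # DeprecationWarning
```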
- update_tuner_result(latency, config, ref_latency)
Updates the tuning results for this kernel.
- Parameters:
latency (float): The measured latency of this kernel configuration.
config (Dict[str, Any]): The configuration parameters used for this kernel.
ref_latency (float): The reference latency to compare against.
- Return type:
None
- get_tuner_result()
Gets the tuning results for this kernel.
- Returns:
A dictionary containing:
  - latency: The measured latency of this kernel
  - config: The configuration parameters used
  - ref_latency: The reference latency for comparison
- Return type:
Dict[str, Any]
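The update/get pair above stores and returns a flat record with three keys. The sketch below models that record with a small dataclass and shows one way to read it, computing a speedup as ref_latency / latency (values above 1 mean faster than the reference). TunerRecord and the config keys block_M and num_stages are hypothetical, and no latency unit is implied.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class TunerRecord:
    """Hypothetical record mirroring update_tuner_result / get_tuner_result."""
    latency: Optional[float] = None
    config: Dict[str, Any] = field(default_factory=dict)
    ref_latency: Optional[float] = None

    def update_tuner_result(self, latency: float, config: Dict[str, Any], ref_latency: float) -> None:
        self.latency, self.config, self.ref_latency = latency, config, ref_latency

    def get_tuner_result(self) -> Dict[str, Any]:
        return {"latency": self.latency, "config": self.config, "ref_latency": self.ref_latency}


record = TunerRecord()
record.update_tuner_result(latency=0.8, config={"block_M": 128, "num_stages": 3}, ref_latency=1.2)
result = record.get_tuner_result()
speedup = result["ref_latency"] / result["latency"]
print(f"{speedup:.2f}x")  # 1.50x
```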
- property out_idx: list[int]
- Return type:
list[int]
- property params: list[tilelang.engine.param.KernelParam]
- Return type:
list[tilelang.engine.param.KernelParam]
- property kernel_source: str
- Return type:
str
- property host_source: str
- Return type:
str
- export_library(kernel_file)
Exports the compiled kernel function to a shared library file.
- Parameters:
kernel_file (str): The path to the shared library file to create.
- Return type:
None
- show_ptx()
Print compiled PTX for the kernel (CUDA only).
Examples
>>> jit_kernel.show_ptx()
- Return type:
None
- export_ptx(path)
Export compiled PTX to a file (CUDA only).
- Parameters:
path (str): Destination file path to write PTX.
- Return type:
None
Examples
>>> jit_kernel.export_ptx("/tmp/kernel.ptx")
- show_sass()
Print disassembled SASS for the kernel (CUDA only).
Examples
>>> jit_kernel.show_sass()
- Return type:
None
- export_sass(path)
Export disassembled SASS to a file (CUDA only).
- Parameters:
path (str): Destination file path to write SASS.
- Return type:
None
Examples
>>> jit_kernel.export_sass("/tmp/kernel.sass")