tilelang.jit.kernel

Attributes

Classes

JITKernel

A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.

Module Contents

tilelang.jit.kernel.logger

class tilelang.jit.kernel.JITKernel(func=None, out_idx=None, execution_backend='tvm_ffi', target='auto', target_host=None, verbose=False, pass_configs=None, from_database=False, compile_flags=None)

Bases: Generic[_P, _T]

A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.

Parameters:
  • func (tvm.tir.PrimFunc)

  • out_idx (list[int] | int)

  • execution_backend (Literal['tvm_ffi', 'ctypes', 'cython', 'nvrtc', 'torch'])

  • target (str | tvm.target.Target)

  • target_host (str | tvm.target.Target)

  • verbose (bool)

  • pass_configs (dict[str, Any] | None)

  • from_database (bool)

  • compile_flags (list[str] | None)
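
The constructor parameters above can be assembled and checked before compilation. The following is a minimal sketch, not part of TileLang: `jit_kernel_kwargs` and `compile_prim_func` are hypothetical helper names, and the only values taken from this page are the parameter names, their defaults, and the backend names in the `execution_backend` literal.

```python
from typing import Any

# Backend names copied from the execution_backend Literal documented above.
EXECUTION_BACKENDS = ("tvm_ffi", "ctypes", "cython", "nvrtc", "torch")


def jit_kernel_kwargs(out_idx, execution_backend="tvm_ffi", target="auto",
                      verbose=False, pass_configs=None,
                      compile_flags=None) -> dict[str, Any]:
    """Build keyword arguments for JITKernel (hypothetical helper)."""
    # Fail early on a backend name that is not in the documented Literal.
    if execution_backend not in EXECUTION_BACKENDS:
        raise ValueError(f"unknown execution_backend: {execution_backend!r}")
    return {
        "out_idx": out_idx,
        "execution_backend": execution_backend,
        "target": target,
        "verbose": verbose,
        "pass_configs": pass_configs,
        "compile_flags": compile_flags,
    }


def compile_prim_func(prim_func, **options):
    """Compile prim_func by forwarding the validated options to JITKernel."""
    # Imported lazily so this sketch can be loaded without TileLang installed.
    from tilelang.jit.kernel import JITKernel

    return JITKernel(func=prim_func, **jit_kernel_kwargs(**options))
```

Validating the backend name up front keeps a typo from surfacing later as a compilation error; whether JITKernel performs its own check is not stated on this page.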

artifact

The compiled artifact containing the runtime module and parameters.

Type:

CompiledArtifact

adapter

The adapter for the compiled function.

Type:

BaseKernelAdapter

torch_function

The compiled function that can be invoked as a PyTorch-compatible function.

Type:

Callable

prim_func: tvm.tir.PrimFunc = None
artifact: tilelang.engine.param.CompiledArtifact = None
adapter: tilelang.jit.adapter.BaseKernelAdapter = None
torch_function: Callable = None
latency: float = None
config: dict[str, Any] = None
ref_latency: float = None
execution_backend = 'tvm_ffi'
target_host = None
verbose = False
pass_configs = None
compile_flags = None
target = 'auto'

classmethod from_database(func, host_kernel_source, device_kernel_source, kernel_lib_path, params, target, target_host, out_idx, execution_backend, pass_configs=None, compile_flags=None)

Alternative constructor that creates a JITKernel directly from a database.

Parameters:
  • func (tvm.tir.PrimFunc)

  • host_kernel_source (str)

  • device_kernel_source (str)

  • kernel_lib_path (str)

  • params (list[tilelang.engine.param.KernelParam])

  • target (str | tvm.target.Target)

  • target_host (str | tvm.target.Target)

  • out_idx (list[int] | int)

  • execution_backend (Literal['tvm_ffi', 'ctypes', 'cython', 'nvrtc', 'torch'])

  • pass_configs (dict[str, Any] | None)

  • compile_flags (list[str] | None)

__call__(*args, **kwds)

Invokes the compiled function with the given arguments.

Parameters:
  • *args (Any) – Positional arguments for the function.

  • **kwds (Any) – Keyword arguments for the function.

Returns:

The result of the function execution.

Return type:

Any

classmethod from_tilelang_function(tilelang_func, **kwargs)

Alternative constructor that creates a JITKernel directly from a TileLang PrimFunc.

Parameters:
  • tilelang_func (tvm.tir.PrimFunc) – The TileLang (TVM TIR) function to compile.

  • **kwargs (dict) – Additional keyword arguments to pass to the constructor.

Returns:

An instance of JITKernel wrapping the compiled function.

Return type:

JITKernel

get_profiler(tensor_supply_type=TensorSupplyType.Auto)

Creates a profiler to benchmark the compiled runtime module.

Parameters:

tensor_supply_type (TensorSupplyType, optional) – The type of input tensors to supply for profiling (default: TensorSupplyType.Auto).

Returns:

A Profiler instance for benchmarking the runtime module.

Return type:

Profiler
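
As an illustration of how the profiler might be used, the sketch below wraps `get_profiler()` in a hypothetical helper, `benchmark_kernel`. It assumes the returned Profiler exposes a `do_bench()` method that returns a latency; that method is not described on this page, so treat it as an assumption to verify against the Profiler documentation.

```python
def benchmark_kernel(kernel, tensor_supply_type=None):
    """Return the latency reported by the kernel's profiler (hypothetical helper)."""
    if tensor_supply_type is None:
        # Rely on the documented default, TensorSupplyType.Auto.
        profiler = kernel.get_profiler()
    else:
        profiler = kernel.get_profiler(tensor_supply_type=tensor_supply_type)
    # ASSUMPTION: Profiler provides do_bench(); it is not documented on this page.
    return profiler.do_bench()
```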

get_kernel_source(kernel_only=True)

Returns the source code of the compiled kernel function.

Returns:

The source code of the compiled kernel function.

Return type:

str

Parameters:

kernel_only (bool)

get_host_source()

Returns the source code of the host function.

Return type:

str
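
Both getters return plain strings, so the generated code can be inspected or compared programmatically. A minimal sketch follows; `collect_sources` and `source_line_counts` are hypothetical helper names, and only the two documented methods are called on the kernel.

```python
def collect_sources(kernel):
    """Return the kernel and host source of a compiled kernel as a dict."""
    return {
        # kernel_only=True is the documented default.
        "kernel": kernel.get_kernel_source(kernel_only=True),
        "host": kernel.get_host_source(),
    }


def source_line_counts(kernel):
    """Map each source name to its number of lines, e.g. for a quick size check."""
    return {name: len(text.splitlines())
            for name, text in collect_sources(kernel).items()}
```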

run_once(func=None)
Parameters:

func (Callable | None)

Return type:

None

show_source(which='kernel')

Print generated source code to stdout.

Parameters:

which (Literal["kernel", "host", "both"], optional) – Select which source to print. Defaults to "kernel".

Return type:

None

Examples

>>> jit_kernel.show_source()            # print kernel source
>>> jit_kernel.show_source("host")      # print host source
>>> jit_kernel.show_source("both")      # print both sources

export_sources(kernel_path=None, host_path=None)

Export generated source code to files.

Parameters:
  • kernel_path (Optional[str]) – Destination file path to write the kernel source. If None, skips writing kernel code.

  • host_path (Optional[str]) – Destination file path to write the host source. If None, skips writing host code.

Return type:

None

Examples

>>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
>>> jit_kernel.export_sources(host_path="/tmp/host.cc")
>>> jit_kernel.export_sources(
...     kernel_path="/tmp/kernel.cu",
...     host_path="/tmp/host.cc",
... )

print_source_code(which='kernel', file=None)

Deprecated: use show_source() or export_sources() instead.

Parameters:
  • which (Literal["kernel", "host", "both"], optional) – Kept for backward compatibility with printing behavior.

  • file (Optional[str]) – If provided, behaves like export_sources(kernel_path=file).

Return type:

None

Examples

>>> # New API (preferred)
>>> jit_kernel.show_source("both")
>>> jit_kernel.export_sources(kernel_path="/tmp/kernel.cu")
>>> # Old API (still works but deprecated)
>>> jit_kernel.print_source_code(file="/tmp/kernel.cu")
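
Code that still calls the deprecated method can be migrated mechanically. The sketch below restates the documented behavior in terms of the new API; `print_source_code_compat` is a hypothetical helper name, and the mapping follows only what the parameter descriptions above say (a given `file` behaves like `export_sources(kernel_path=file)`, otherwise the selected source is printed).

```python
def print_source_code_compat(kernel, which="kernel", file=None):
    """Emulate the deprecated print_source_code() using the new API."""
    if file is not None:
        # Documented: a file argument behaves like export_sources(kernel_path=file).
        kernel.export_sources(kernel_path=file)
    else:
        # Otherwise print the selected source to stdout.
        kernel.show_source(which)
```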

update_tuner_result(latency, config, ref_latency)

Updates the tuning results for this kernel.

Parameters:
  • latency (float) – The measured latency of this kernel configuration.

  • config (Dict[str, Any]) – The configuration parameters used for this kernel.

  • ref_latency (float) – The reference latency to compare against.

Return type:

None

get_tuner_result()

Gets the tuning results for this kernel.

Returns:

A dictionary containing:
  • latency: The measured latency of this kernel

  • config: The configuration parameters used

  • ref_latency: The reference latency for comparison

Return type:

Dict[str, Any]
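
The three keys returned by get_tuner_result() are enough to compute a speedup over the reference. A minimal sketch follows; `tuner_speedup` is a hypothetical helper name, and it assumes both latencies are expressed in the same unit, which this page does not specify.

```python
def tuner_speedup(kernel):
    """Return ref_latency / latency, or None if either value is unavailable."""
    result = kernel.get_tuner_result()
    latency = result["latency"]
    ref_latency = result["ref_latency"]
    # Both attributes default to None until update_tuner_result() is called.
    if latency is None or ref_latency is None or latency == 0:
        return None
    return ref_latency / latency
```

A value above 1 means the tuned kernel is faster than the reference under this definition.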

property out_idx: list[int]
Return type:

list[int]

property params: list[tilelang.engine.param.KernelParam]
Return type:

list[tilelang.engine.param.KernelParam]

property kernel_source: str
Return type:

str

property host_source: str
Return type:

str

export_library(kernel_file)

Exports the compiled kernel function to a shared library file.

Parameters:

kernel_file (str) – The path to the shared library file to create.

Return type:

None

show_ptx()

Print compiled PTX for the kernel (CUDA only).

Examples

>>> jit_kernel.show_ptx()
Return type:

None

export_ptx(path)

Export compiled PTX to a file (CUDA only).

Parameters:

path (str) – Destination file path to write PTX.

Return type:

None

Examples

>>> jit_kernel.export_ptx("/tmp/kernel.ptx")

show_sass()

Print disassembled SASS for the kernel (CUDA only).

Examples

>>> jit_kernel.show_sass()
Return type:

None

export_sass(path)

Export disassembled SASS to a file (CUDA only).

Parameters:

path (str) – Destination file path to write SASS.

Return type:

None

Examples

>>> jit_kernel.export_sass("/tmp/kernel.sass")
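
The export methods above can be combined to write every generated artifact into one directory for debugging. The following is a sketch under stated assumptions: `export_debug_bundle` and the file names it chooses are hypothetical, and the `cuda` flag exists only because export_ptx() and export_sass() are documented as CUDA only.

```python
from pathlib import Path


def export_debug_bundle(kernel, out_dir, cuda=True):
    """Write kernel/host sources (and PTX/SASS for CUDA) into out_dir."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    paths = {"kernel": out / "kernel.cu", "host": out / "host.cc"}
    kernel.export_sources(kernel_path=str(paths["kernel"]),
                          host_path=str(paths["host"]))

    if cuda:
        # PTX and SASS export are documented as CUDA only, so they are optional here.
        paths["ptx"] = out / "kernel.ptx"
        paths["sass"] = out / "kernel.sass"
        kernel.export_ptx(str(paths["ptx"]))
        kernel.export_sass(str(paths["sass"]))

    # Return the written paths so callers can log or archive them.
    return paths
```

For a non-CUDA target, pass `cuda=False` so only the two source files are written.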