tilelang.carver.template.flashattention

Classes

FlashAttentionTemplate

Template describing a FlashAttention workload for hardware-aware configuration.

Module Contents

class tilelang.carver.template.flashattention.FlashAttentionTemplate

Bases: tilelang.carver.template.base.BaseTemplate

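Template that describes a FlashAttention workload (batch size, number of heads, head dimension, query and key/value sequence lengths, causal masking, and data types) for hardware-aware configuration. It derives from BaseTemplate, the abstract base class (ABC) that defines the structure for templates implementing hardware-specific optimizations.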

batch_size: int = 1
num_heads: int = 1
head_dim: int = 1
seq_length: int = 1
seq_kv_length: int = 1
is_causal: bool = False
in_dtype: str = 'float16'
out_dtype: str = 'float16'
accum_dtype: str = 'float16'
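
To make the role of the shape attributes concrete, the following sketch derives tensor shapes and a dense matmul FLOP count from them. `AttentionProblem` is a hypothetical stand-in written for this page, not the tilelang class, and the (batch, heads, sequence, head_dim) layout is an assumption; the template itself may order the axes differently.

```python
from dataclasses import dataclass


@dataclass
class AttentionProblem:
    """Hypothetical stand-in that mirrors the template's shape attributes."""

    batch_size: int = 1
    num_heads: int = 1
    head_dim: int = 1
    seq_length: int = 1
    seq_kv_length: int = 1

    def shapes(self):
        # Assumed layout: (batch, heads, sequence, head_dim).
        q = (self.batch_size, self.num_heads, self.seq_length, self.head_dim)
        kv = (self.batch_size, self.num_heads, self.seq_kv_length, self.head_dim)
        return {"Q": q, "K": kv, "V": kv, "O": q}

    def matmul_flops(self):
        # Two matmuls per head (Q @ K^T, then P @ V), counting 2 FLOPs per
        # multiply-add. This is the dense count and ignores causal masking.
        per_head = 4 * self.seq_length * self.seq_kv_length * self.head_dim
        return self.batch_size * self.num_heads * per_head


problem = AttentionProblem(
    batch_size=2, num_heads=8, head_dim=64, seq_length=1024, seq_kv_length=1024
)
print(problem.shapes()["Q"], problem.matmul_flops())
```

With causal masking, roughly half of the score matrix is skipped when the two sequence lengths are equal, so the dense count is an upper bound in that case.
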
get_hardware_aware_configs(arch=None, topk=10)

Retrieves optimized hardware-aware configurations.

Parameters:
  • arch (TileDevice, optional) – The target hardware architecture.

  • topk (int, optional) – Number of top configurations to consider.

Returns:

A list of optimization hints for hardware acceleration.

Return type:

List[Hint]
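
The meaning of topk can be illustrated with a toy ranking: enumerate candidate tile sizes, score each with a cost function, and keep the topk cheapest. Everything in this sketch is invented for illustration (`toy_cost`, `top_configs`, the candidate tile sizes, and the cost formula); it does not reproduce how Carver scores or represents its Hint objects.

```python
import heapq
from itertools import product


def toy_cost(block_m, block_n, seq_length, seq_kv_length):
    """Invented cost: padded score-matrix area plus a fixed per-tile overhead."""
    tiles_m = -(-seq_length // block_m)      # ceiling division
    tiles_n = -(-seq_kv_length // block_n)
    padded_area = (tiles_m * block_m) * (tiles_n * block_n)
    return padded_area + 64 * tiles_m * tiles_n


def top_configs(seq_length, seq_kv_length, topk=10):
    """Return the topk (block_m, block_n) pairs with the lowest toy cost."""
    candidates = list(product((32, 64, 128), repeat=2))
    return heapq.nsmallest(
        topk,
        candidates,
        key=lambda c: toy_cost(c[0], c[1], seq_length, seq_kv_length),
    )


best = top_configs(seq_length=1000, seq_kv_length=1000, topk=3)
print(best)
```

A smaller topk narrows the set of candidates that a later tuning step would have to compile and benchmark, at the risk of discarding a configuration that the cost model ranked poorly but that runs well in practice.
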

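initialize_function()

Defines and initializes the attention computation for this template.

This method sets up placeholders for the input tensors and builds the computation with TVM’s compute API, using the dimensions and data types stored on the instance.

Raises:

AssertionError – If the problem dimensions are not positive integers. For this template the dimensions are batch_size, num_heads, head_dim, seq_length, and seq_kv_length; the class has no M, N, or K attributes.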

Return type:

None
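
For orientation, the standard attention computation that FlashAttention evaluates can be written as a small single-head reference in plain Python. This is a sketch of the textbook formula, softmax(Q Kᵀ / sqrt(head_dim)) V, and not the function that initialize_function builds; the helper name `attention` and the causal convention (query i sees keys 0 through i) are assumptions.

```python
import math


def attention(q, k, v, is_causal=False):
    """Reference single-head attention over nested lists of floats."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, q_row in enumerate(q):
        # Assumed causal convention: query i attends to keys 0..i only.
        limit = min(i + 1, len(k)) if is_causal else len(k)
        visible = range(limit)
        scores = [scale * sum(a * b for a, b in zip(q_row, k[j])) for j in visible]
        peak = max(scores)  # subtract the maximum for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append(
            [
                sum(w * v[j][c] for w, j in zip(weights, visible)) / total
                for c in range(len(v[0]))
            ]
        )
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(attention(q, k, v, is_causal=True)[0])
```

Under the causal mask the first query sees only the first key, so its output row is exactly the first row of v.
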

params_as_dict()

Returns the template parameters as a dictionary.

Returns:

Dictionary containing template parameter values.

Return type:

dict
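
As a rough picture of what such a dictionary can look like, the sketch below collects the documented attributes of a dataclass into a dict. `TemplateSketch` is a hypothetical stand-in, and building the dict from dataclass fields is an assumption about one reasonable approach, not a statement of how tilelang implements params_as_dict.

```python
from dataclasses import dataclass, fields


@dataclass
class TemplateSketch:
    """Hypothetical stand-in carrying the documented attributes and defaults."""

    batch_size: int = 1
    num_heads: int = 1
    head_dim: int = 1
    seq_length: int = 1
    seq_kv_length: int = 1
    is_causal: bool = False
    in_dtype: str = "float16"
    out_dtype: str = "float16"
    accum_dtype: str = "float16"

    def params_as_dict(self):
        # Map every declared field name to its current value on this instance.
        return {f.name: getattr(self, f.name) for f in fields(self)}


params = TemplateSketch(num_heads=8, head_dim=64, is_causal=True).params_as_dict()
print(params)
```

A plain dict of parameters is convenient for logging a tuning run or for using the workload description as a cache key.
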

property class_attributes

Returns the class attributes in dictionary form.

Returns:

Dictionary of class attributes.

Return type:

dict

__repr__()

Returns a string representation of the class instance.

Returns:

A formatted string representation of the class.

Return type:

str