tilelang.utils.tensor

Classes

TensorSupplyType

Enumeration of the strategies for supplying (generating) tensor data.

Functions

map_torch_type(intype)

adapt_torch2tvm(arg)

get_tensor_supply([supply_type])

torch_assert_close(tensor_a, tensor_b[, rtol, atol, ...])

Custom function to assert that two tensors are "close enough," allowing a specified percentage of mismatched elements.

Module Contents

class tilelang.utils.tensor.TensorSupplyType

Bases: enum.Enum

Enumeration of the strategies for supplying (generating) tensor data; its members are accepted by get_tensor_supply() through the supply_type parameter.

Integer = 1
Uniform = 2
Normal = 3
Randn = 4
Zero = 5
One = 6
Auto = 7

tilelang.utils.tensor.map_torch_type(intype)
Parameters:

intype (str)

Return type:

torch.dtype

tilelang.utils.tensor.adapt_torch2tvm(arg)

tilelang.utils.tensor.get_tensor_supply(supply_type=TensorSupplyType.Integer)
Parameters:

supply_type (TensorSupplyType)

tilelang.utils.tensor.torch_assert_close(tensor_a, tensor_b, rtol=0.01, atol=0.001, max_mismatched_ratio=0.001, verbose=False, equal_nan=True, check_device=True, check_dtype=True, check_layout=True, check_stride=False, base_name='LHS', ref_name='RHS')

Custom function to assert that two tensors are "close enough," allowing a specified percentage of mismatched elements.

Parameters:

tensor_a : torch.Tensor

The first tensor to compare.

tensor_b : torch.Tensor

The second tensor to compare.

rtol : float, optional

Relative tolerance for comparison. Default is 1e-2.

atol : float, optional

Absolute tolerance for comparison. Default is 1e-3.

max_mismatched_ratio : float, optional

Maximum ratio of mismatched elements allowed (relative to the total number of elements). Default is 0.001 (0.1% of total elements).

Raises:

AssertionError:

If the ratio of mismatched elements exceeds max_mismatched_ratio.
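Written out as a formula, and assuming each element is judged with the torch.isclose convention (an assumption; the page does not state the per-element rule), the check over N elements is:

```latex
m_i = \bigl[\, |a_i - b_i| > \mathrm{atol} + \mathrm{rtol}\,|b_i| \,\bigr], \qquad
\frac{1}{N}\sum_{i=1}^{N} m_i > \mathrm{max\_mismatched\_ratio} \;\Longrightarrow\; \text{AssertionError}
```

For example, with the defaults and a reference value b_i = 1, an element mismatches when it differs by more than 0.001 + 0.01 · 1 = 0.011; and for a tensor of N = 10^6 elements, up to 0.001 · 10^6 = 1000 elements may mismatch before the assertion fails.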

Parameters:
  • verbose (bool)

  • equal_nan (bool)

  • check_device (bool)

  • check_dtype (bool)

  • check_layout (bool)

  • check_stride (bool)

  • base_name (str)

  • ref_name (str)
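The pure-Python sketch below shows the mismatch-ratio rule described above on plain lists, so it runs without PyTorch. It assumes the per-element test follows the torch.isclose convention (|a - b| <= atol + rtol * |b|, with NaNs matching only when equal_nan is set) and that the call fails when mismatched / total exceeds max_mismatched_ratio. `assert_mostly_close` is a hypothetical name; it ignores verbose, the check_* flags, and base_name/ref_name, and it is not TileLang's implementation.

```python
import math
from typing import Sequence


def assert_mostly_close(a: Sequence[float], b: Sequence[float],
                        rtol: float = 0.01, atol: float = 0.001,
                        max_mismatched_ratio: float = 0.001,
                        equal_nan: bool = True) -> None:
    """Raise AssertionError if too large a fraction of elements differ."""
    if len(a) != len(b):
        raise AssertionError(f"length mismatch: {len(a)} vs {len(b)}")
    mismatched = 0
    for x, y in zip(a, b):
        if math.isnan(x) or math.isnan(y):
            # NaNs only match each other, and only when equal_nan is set.
            ok = equal_nan and math.isnan(x) and math.isnan(y)
        else:
            # Exact equality (covers matching infinities) or within tolerance.
            ok = x == y or abs(x - y) <= atol + rtol * abs(y)
        if not ok:
            mismatched += 1
    ratio = mismatched / len(a) if len(a) else 0.0
    if ratio > max_mismatched_ratio:
        raise AssertionError(
            f"{mismatched}/{len(a)} elements mismatched "
            f"(ratio {ratio:.4%} > {max_mismatched_ratio:.4%})")


ref = [1.0] * 1000
out = list(ref)
out[0] = 2.0  # 1 of 1000 mismatched: ratio 0.001, within the default limit
assert_mostly_close(out, ref)
out[1] = 2.0  # 2 of 1000 mismatched: ratio 0.002, exceeds the limit
try:
    assert_mostly_close(out, ref)
except AssertionError as err:
    print(err)  # reports the mismatch count and ratio
```

Because the limit is a ratio rather than a count, the same default tolerates more outliers in larger tensors, which suits low-precision kernels where a few elements may round differently.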