tilelang.carver.roller.shape_inference.tir
Classes
- Statement
- TensorDepNode: For tensor dependency analysis.
- DependencyAnalysis
- InputShapeInference
Functions
Module Contents
- class tilelang.carver.roller.shape_inference.tir.Statement(block_analyzer, block)
  - Parameters:
    block (tvm.tir.schedule.schedule.BlockRV)
  - block_analyzer
  - block
  - dep_name
  - dependent_region
  - reverse_bound_inference
  - make_reverse(input_name, input_iter)
    - Parameters:
      input_name (str)
      input_iter (List[tvm.tir.PrimExpr])
- class tilelang.carver.roller.shape_inference.tir.TensorDepNode(name)
  Bases: object
  For tensor dependency analysis.
  - name
  - add_next(node)
  - add_prev(node)
  - deduplicate(lst)
  - __str__()
  - __repr__()
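The member list above suggests a graph node holding forward (add_next) and backward (add_prev) edges, plus a deduplicate helper. The sketch below shows that data structure in plain Python. It is a hypothetical stand-in, not the tilelang source: the class name DepNodeSketch, the attribute names next and prev, and the choice that deduplicate works in place and keeps first-seen order are all assumptions.

```python
from typing import Hashable, List


class DepNodeSketch:
    """Hypothetical stand-in for TensorDepNode (assumed behavior, not the tilelang source)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.next: List["DepNodeSketch"] = []  # assumed: tensors computed from this one
        self.prev: List["DepNodeSketch"] = []  # assumed: tensors this one is computed from

    @staticmethod
    def deduplicate(lst: List[Hashable]) -> None:
        # Remove repeated entries in place, keeping first-seen order (assumed semantics).
        lst[:] = list(dict.fromkeys(lst))

    def add_next(self, node: "DepNodeSketch") -> None:
        self.next.append(node)
        self.deduplicate(self.next)

    def add_prev(self, node: "DepNodeSketch") -> None:
        self.prev.append(node)
        self.deduplicate(self.prev)

    def __repr__(self) -> str:
        # str() falls back to __repr__, so both print the tensor name.
        return self.name


a, b = DepNodeSketch("A"), DepNodeSketch("B")
a.add_next(b)
a.add_next(b)  # the duplicate edge is dropped
b.add_prev(a)
print(a.next)  # prints [B]
```

In this sketch, adding an edge only updates one side; whether the library links both directions in a single call is not stated above.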
- class tilelang.carver.roller.shape_inference.tir.DependencyAnalysis(deps)
  Bases: object
  - deps
  - name2dep
  - mapping
  - get_or_create_node(name)
  - traverse_dependencies(compute)
  - analyze()
  - print_dependencies()
  - find_path_from_source(start_name, target_name)
    Finds the path (if it exists) from a starting node (source) to a target node. Returns the path as a list of nodes.
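find_path_from_source is described as returning a source-to-target path as a list of nodes. The sketch below illustrates that with breadth-first search over a plain adjacency dict. This is a swapped-in technique, not necessarily what the library does. The graph representation (tensor name mapped to the names of the tensors computed from it), the function name find_path, and returning [] when no path exists are assumptions.

```python
from collections import deque
from typing import Dict, List


def find_path(graph: Dict[str, List[str]], start: str, target: str) -> List[str]:
    """Hypothetical helper: node names on a path from start to target, or [] if none."""
    parent: Dict[str, str] = {}
    visited = {start}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if node == target:
            # Follow parent links back to the source, then reverse the result.
            path = [node]
            while node != start:
                node = parent[node]
                path.append(node)
            return path[::-1]
        for succ in graph.get(node, []):
            if succ not in visited:
                visited.add(succ)
                parent[succ] = node
                queue.append(succ)
    return []


# Hypothetical tensors: A feeds B and C, and B feeds D.
deps = {"A": ["B", "C"], "B": ["D"]}
print(find_path(deps, "A", "D"))  # prints ['A', 'B', 'D']
print(find_path(deps, "C", "D"))  # prints []
```

BFS returns a path with the fewest edges. A depth-first traversal would also satisfy the description above, but it may return a longer path.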
- class tilelang.carver.roller.shape_inference.tir.InputShapeInference(deps)
  - Parameters:
    deps (List[Statement])
  - deps
  - target_mapping
  - buffer_mapping
  - reduce_axes = []
  - dep_analysis
  - construct_dependency_target(targets)
    - Parameters:
      targets (Tuple[str])
  - infer(shape, rstep=None, targets=None)
    - Parameters:
      shape (Dict[str, List[tvm.arith.ConstIntBound]])
      rstep (Dict[str, int])
  - get_input_exprs(output_exprs)
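infer takes per-dimension bounds of an output tile (shape) and a reduction step per reduce axis (rstep). The signature suggests it derives the input regions that tile reads. The sketch below shows the underlying idea with interval arithmetic in plain Python. It is not the library's implementation and does not use TVM. It assumes that every index expression is affine with integer coefficients, that every iterator starts at 0, and that spatial iterators span the output tile while reduction iterators span one rstep. The names expr_bound and infer_input_shape are hypothetical.

```python
from typing import Dict, List, Tuple

# Hypothetical encoding of an affine index expression: ({iterator: coefficient}, constant).
Affine = Tuple[Dict[str, int], int]


def expr_bound(expr: Affine, bounds: Dict[str, Tuple[int, int]]) -> Tuple[int, int]:
    """Inclusive [min, max] of an affine expression given per-iterator bounds."""
    coeffs, const = expr
    lo = hi = const
    for var, c in coeffs.items():
        vmin, vmax = bounds[var]
        # A negative coefficient swaps which end of the iterator range gives the minimum.
        lo += min(c * vmin, c * vmax)
        hi += max(c * vmin, c * vmax)
    return lo, hi


def infer_input_shape(
    index_exprs: List[Affine],
    spatial_tile: Dict[str, int],
    rstep: Dict[str, int],
) -> List[int]:
    """Extent of the bounding box each input dimension needs for one output tile."""
    # Assumed: spatial iterators cover the output tile, reduction iterators one rstep.
    bounds = {v: (0, n - 1) for v, n in {**spatial_tile, **rstep}.items()}
    extents = []
    for e in index_exprs:
        lo, hi = expr_bound(e, bounds)
        extents.append(hi - lo + 1)
    return extents


# Matmul C[i, j] += A[i, k] * B[k, j] with a 16x32 output tile and rstep k=8.
tile, step = {"i": 16, "j": 32}, {"k": 8}
print(infer_input_shape([({"i": 1}, 0), ({"k": 1}, 0)], tile, step))  # prints [16, 8]
print(infer_input_shape([({"k": 1}, 0), ({"j": 1}, 0)], tile, step))  # prints [8, 32]

# Strided window A[2*i + k]: 16 outputs, stride 2, window of 3 -> (16 - 1) * 2 + 3.
print(infer_input_shape([({"i": 2, "k": 1}, 0)], {"i": 16}, {"k": 3}))  # prints [33]
```

The result is the bounding box of accessed indices, so strided accesses may include elements that are never read. The real method works on TIR expressions and tvm.arith.ConstIntBound values rather than this toy encoding.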
- tilelang.carver.roller.shape_inference.tir.region_exist_in_list(a, list)
  - Return type:
    bool
- tilelang.carver.roller.shape_inference.tir.walk_indice(expr)
- tilelang.carver.roller.shape_inference.tir.get_analyzer_by_tir(block_analyzer, args)
  - Return type: