Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/xtc/backends/mlir/MlirCompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from xtc.backends.mlir.MlirCompilerPasses import (
MlirProgramInsertTransformPass,
MlirProgramApplyTransformPass,
apply_bufferization_passes,
_POST_BUFFERIZE_SEQ_NAME,
)

from xtc.backends.mlir.MlirTarget import (
Expand Down Expand Up @@ -141,6 +143,7 @@ def mlir_insert_transform_pass(self) -> None:
always_vectorize=self._config.always_vectorize,
vectors_size=self._config.vectors_size,
target=self._target,
using_tensors_hint=self._config.using_tensors_hint,
)
insert_transform_pass.run()
if self._config.print_source_ir:
Expand All @@ -149,11 +152,27 @@ def mlir_insert_transform_pass(self) -> None:
def mlir_apply_transform_pass(self) -> None:
apply_transform_pass = MlirProgramApplyTransformPass(
mlir_program=self._mlir_program,
clean_all=not self._config.using_tensors_hint,
)
apply_transform_pass.run()
if self._config.print_transformed_ir:
self.dump_ir("IR Dump After transform")

def mlir_apply_tensor_lowering_pass(self) -> None:
apply_bufferization_passes(self._mlir_program)
if self._config.print_bufferization_ir:
self.dump_ir("IR Dump After Tensor Lowering")

def mlir_apply_post_bufferize_transform_pass(self) -> None:
apply_transform_pass = MlirProgramApplyTransformPass(
mlir_program=self._mlir_program,
clean_all=True,
custom_sequence=_POST_BUFFERIZE_SEQ_NAME,
)
apply_transform_pass.run()
if self._config.print_transformed_ir:
self.dump_ir("IR Dump After Post-Bufferize transform")

def _save_temp(self, fname: str, content: Any) -> None:
if not self._config.save_temps:
return
Expand Down Expand Up @@ -194,6 +213,7 @@ def compile(self) -> None:
src_ir_dump_file = f"{dump_base}.mlir"
mlir_btrn_dump_file = f"{dump_base}.before_trn.mlir"
mlir_atrn_dump_file = f"{dump_base}.after_trn.mlir"
mlir_tlwr_dump_file = f"{dump_base}.bufferized.mlir"

save_temp(src_ir_dump_file, self._mlir_program.mlir_module)

Expand All @@ -203,4 +223,9 @@ def compile(self) -> None:
self.mlir_apply_transform_pass()
save_temp(mlir_atrn_dump_file, self._mlir_program.mlir_module)

self.mlir_apply_tensor_lowering_pass()
if self._config.using_tensors_hint:
self.mlir_apply_post_bufferize_transform_pass()
save_temp(mlir_tlwr_dump_file, self._mlir_program.mlir_module)

self._target.generate_code_for_target(self._mlir_program, dump_file=dump_file)
123 changes: 110 additions & 13 deletions src/xtc/backends/mlir/MlirCompilerPasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
TileUsingForOp,
VectorizeOp,
)
from mlir.dialects.transform.structured import structured_match
from mlir.dialects.transform.structured import (
structured_match,
ApplyFoldUnitExtentDimsViaSlicesPatternsOp,
MatchInterfaceEnum,
)
from mlir.dialects.transform.loop import loop_unroll
from mlir.dialects.transform import SplitHandleOp
from mlir.ir import (
Expand All @@ -43,6 +47,7 @@

_VECTO_SEQ_NAME = "_vecto"
_SUPER_VECTORIZE_SEQ_NAME = "_super_vectorize"
_POST_BUFFERIZE_SEQ_NAME = "_post_bufferize"


@dataclass
Expand Down Expand Up @@ -105,6 +110,7 @@ def __init__(
concluding_passes: list[str] = [],
always_vectorize: bool = True,
vectors_size: int | None = None,
using_tensors_hint: bool = False,
) -> None:
self._mlir_program = mlir_program
self._target = target
Expand All @@ -113,9 +119,11 @@ def __init__(
self._concluding_passes = concluding_passes
self._always_vectorize = always_vectorize
self._vectors_size = vectors_size
self._using_tensors_hint = using_tensors_hint
self._super_vectorize = self._vectors_size is not None
self._vecto_sequence: NamedSequenceOp | None = None
self._super_vectorize_sequence: NamedSequenceOp | None = None
self._post_vecto_sequence: NamedSequenceOp | None = None
self._named_sequence: NamedSequenceOp | None = None
self._nodes_schedules = (
self._mlir_schedule.schedule_impl if self._mlir_schedule is not None else []
Expand All @@ -139,6 +147,16 @@ def run(self) -> None:
[],
arg_attrs=[{"transform.consumed": UnitAttr.get()}],
)
if self._using_tensors_hint:
self._post_vecto_sequence = NamedSequenceOp(
_POST_BUFFERIZE_SEQ_NAME,
[transform.AnyOpType.get()],
[],
arg_attrs=[{"transform.readonly": UnitAttr.get()}],
)
assert self._post_vecto_sequence is not None
with InsertionPoint(self._post_vecto_sequence.body):
transform.YieldOp([])
if self._super_vectorize:
self._super_vectorize_sequence = NamedSequenceOp(
_SUPER_VECTORIZE_SEQ_NAME,
Expand Down Expand Up @@ -202,7 +220,7 @@ def _generate_scheduling(self) -> OpResult:
handle=handle,
)
if schedule.vectorization or self._always_vectorize:
self._post_vectorize(scheduling_state)
self._post_vectorize(scheduling_state, schedule)
handle = scheduling_state.handle
assert handle, "At least 1 operation should have been processed"
return handle
Expand Down Expand Up @@ -442,21 +460,39 @@ def _strip_mine(
def _vectorize(self, sched_state: SchedulingState):
if self._vectors_size is not None:
return
assert self._named_sequence is not None
vecto_handle = sched_state.handle

if self._using_tensors_hint:
parent_op = get_parent_op(
transform.AnyOpType.get(),
sched_state.handle,
isolated_from_above=True,
)
with InsertionPoint(transform.ApplyPatternsOp(parent_op).patterns):
ApplyFoldUnitExtentDimsViaSlicesPatternsOp()
vecto_handle = structured_match(
results_=transform.AnyOpType.get(),
target=parent_op, # self._named_sequence.bodyTarget,
interface=MatchInterfaceEnum.LinalgOp,
)

if self._target.has_custom_vectorize():
self._target.apply_custom_vectorize(sched_state.handle)
self._target.apply_custom_vectorize(vecto_handle)
else:
transform.IncludeOp(
results_=[],
target=_VECTO_SEQ_NAME,
failure_propagation_mode=2,
operands_=[sched_state.handle],
operands_=[vecto_handle],
)

def _post_vectorize(self, sched_state: SchedulingState):
def _post_vectorize(self, sched_state: SchedulingState, schedule: MlirNodeSchedule):
if self._vectors_size is not None:
return

post_vec_annotation = "_apply_post_vectorize_patterns"
if self._using_tensors_hint:
transform.AnnotateOp(sched_state.handle, post_vec_annotation)
parent_op = get_parent_op(
transform.AnyOpType.get(),
sched_state.handle,
Expand All @@ -465,9 +501,25 @@ def _post_vectorize(self, sched_state: SchedulingState):
with InsertionPoint(transform.ApplyPatternsOp(parent_op).patterns):
vector.ApplyVectorReductionToContractPatternsOp()
vector.ApplyTransferPermutationPatternsOp()
with InsertionPoint(transform.ApplyPatternsOp(parent_op).patterns):
vector.ApplyLowerOuterProductPatternsOp()
vector.ApplyLowerContractionPatternsOp()

if not self._post_vecto_sequence:
with InsertionPoint(transform.ApplyPatternsOp(parent_op).patterns):
vector.ApplyLowerOuterProductPatternsOp()
vector.ApplyLowerContractionPatternsOp()
else:
with (
InsertionPoint.at_block_begin(self._post_vecto_sequence.body),
self._mlir_program.mlir_context,
self._loc,
):
handle = structured_match(
results_=transform.AnyOpType.get(),
target=self._post_vecto_sequence.bodyTarget,
op_attrs={post_vec_annotation: UnitAttr.get()},
)
with InsertionPoint(transform.ApplyPatternsOp(handle).patterns):
vector.ApplyLowerOuterProductPatternsOp()
vector.ApplyLowerContractionPatternsOp()

def _unroll(
self,
Expand Down Expand Up @@ -543,16 +595,24 @@ class MlirProgramApplyTransformPass:
def __init__(
self,
mlir_program: RawMlirProgram,
clean_all: bool = False,
custom_sequence: None | str = None,
) -> None:
self._mlir_program = mlir_program
self._clean_all = clean_all
self._custom_sequence = custom_sequence

def run(self) -> None:
transform_op = [op for op in self._mlir_program.mlir_module.body.operations][-1]
transform = isinstance(transform_op, NamedSequenceOp)
assert transform
pm = PassManager(context=self._mlir_program.mlir_context)
for opt in transform_opts:
pm.add(opt) # type: ignore # no attribte add?
if self._custom_sequence:
for opt in transform_opts:
pm.add(f"{opt}{{entry-point={self._custom_sequence}}}") # type: ignore
else:
for opt in transform_opts:
pm.add(opt) # type: ignore
pm.run(self._mlir_program.mlir_module.operation)

while True:
Expand All @@ -561,5 +621,42 @@ def run(self) -> None:
][-1]
if isinstance(transform_op, NamedSequenceOp):
transform_op.erase()
else:
break
if self._clean_all:
continue
break


class MlirProgramApplyPasses:
def __init__(
self,
mlir_program: RawMlirProgram,
) -> None:
self._mlir_program = mlir_program

def run(self, pass_names: list[str]) -> None:
ctx = self._mlir_program.mlir_context
pm = PassManager(context=ctx)
for name in pass_names:
pm.add(name) # type: ignore # no attribute add
pm.run(self._mlir_program.mlir_module.operation)


def apply_bufferization_passes(mlir_program: RawMlirProgram):
apply_passes = MlirProgramApplyPasses(mlir_program)
bufferize_options = [
"bufferize-function-boundaries=1",
"function-boundary-type-conversion=identity-layout-map",
"buffer-alignment=256",
]
apply_passes.run(
[
"canonicalize",
"cse",
"eliminate-empty-tensors", # causes ops to write directly to out buffer
f"one-shot-bufferize{{{' '.join(bufferize_options)}}}",
# "func.func(buffer-hoisting)",
# "func.func(buffer-loop-hoisting)",
"drop-equivalent-buffer-results",
"func.func(promote-buffers-to-stack)",
]
)
2 changes: 2 additions & 0 deletions src/xtc/backends/mlir/MlirConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ class MlirConfig:
print_assembly: bool = False
visualize_jumps: bool = True
print_lowered_ir: bool = False
print_bufferization_ir: bool = False
debug: bool = False
color: bool = False
concluding_passes: list[str] = field(default_factory=list)
always_vectorize: bool = False
vectors_size: int | None = None
using_tensors_hint: bool = False
arch: str = "native"
cpu: str = "native"
selected_device: int | None = None
Expand Down
Loading
Loading