diff --git a/pytato/array.py b/pytato/array.py index 3c66ef251..86310f2e5 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -609,7 +609,7 @@ class ReductionDescriptor(Taggable): Records information about a reduction dimension in an :class:`~pytato.Array`'. """ - tags: frozenset[Tag] + tags: frozenset[Tag] = frozenset() @override def _with_new_tags(self, tags: frozenset[Tag]) -> ReductionDescriptor: @@ -1766,7 +1766,7 @@ def einsum(subscripts: str, *operands: Array, for descr in index_to_descr.values(): if (isinstance(descr, EinsumReductionAxis) and descr not in redn_axis_to_redn_descr): - redn_axis_to_redn_descr[descr] = ReductionDescriptor(frozenset()) + redn_axis_to_redn_descr[descr] = ReductionDescriptor() # }}} @@ -2339,9 +2339,21 @@ class CSRMatmul(SparseMatmul): .. attribute:: array The :class:`Array` to which the sparse matrix is being applied. + + .. attribute:: reduction_var + + The index variable for the per-row reduction. + + .. attribute:: reduction_descr + + The :class:`ReductionDescriptor` for the per-row reduction. """ matrix: CSRMatrix array: Array + reduction_var: str = "_r0" + reduction_descr: ReductionDescriptor = dataclasses.field( + kw_only=True, + default_factory=ReductionDescriptor) @property @override @@ -2353,6 +2365,24 @@ def _matrix(self) -> SparseMatrix: def _array(self) -> Array: return self.array + def with_tagged_reduction(self, tags: Tag | Iterable[Tag]) -> CSRMatmul: + """ + Returns a copy of *self* with its :class:`ReductionDescriptor` + tagged with *tag*. + """ + new_redn_descr = self.reduction_descr.tagged(tags) + if new_redn_descr is not self.reduction_descr: + return type(self)( + matrix=self.matrix, + array=self.array, + axes=self.axes, + reduction_descr=new_redn_descr, + tags=self.tags, + non_equality_tags=self.non_equality_tags) + else: + return self + + # }}} @@ -3271,7 +3301,7 @@ def make_index_lambda( for redn_var in redn_vars: redn_descr = var_to_reduction_descr.get(redn_var, - ReductionDescriptor(frozenset())) + ReductionDescriptor()) if not isinstance(redn_descr, ReductionDescriptor): raise TypeError(f"reduction_dim for {redn_var} expected to be" f" of type ReductionDescriptor, got {type(redn_descr)}.") diff --git a/pytato/reductions.py b/pytato/reductions.py index 6efa45ac2..a616c6e95 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -261,7 +261,7 @@ def _get_var_to_redn_descr( idx = f"_r{n_redn_dims}" redn_descr = axis_to_reduction_descr.get( idim, - ReductionDescriptor(frozenset())) + ReductionDescriptor()) if not isinstance(redn_descr, ReductionDescriptor): raise TypeError(f"'axis_to_reduction_descr[{idim}]': " "expected an instance of ReductionDescriptor, " diff --git a/pytato/stringifier.py b/pytato/stringifier.py index de08ade1b..763155d6b 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -137,7 +137,7 @@ def _map_generic_array(self, expr: Array, depth: int) -> str: fields = tuple(field for field in fields if field != "axes") if (isinstance(expr, IndexLambda) - and all(redn_descr == ReductionDescriptor(frozenset()) + and all(redn_descr == ReductionDescriptor() for redn_descr in expr.var_to_reduction_descr.values())): # prettify: if trivial 'expr.var_to_reduction_descr' => don't print. fields = tuple(field diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 773ee4cbe..3fef2907e 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -51,7 +51,6 @@ IndexExpr, IndexLambda, NormalizedSlice, - ReductionDescriptor, Reshape, Roll, ShapeComponent, @@ -732,6 +731,8 @@ def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda: non_equality_tags=expr.non_equality_tags) def map_csr_matmul(self, expr: CSRMatmul) -> IndexLambda: + redn_var = expr.reduction_var + rec_matrix_elem_values = self.rec(expr.matrix.elem_values) rec_matrix_elem_col_indices = self.rec(expr.matrix.elem_col_indices) rec_matrix_row_starts = self.rec(expr.matrix.row_starts) @@ -740,15 +741,15 @@ def map_csr_matmul(self, expr: CSRMatmul) -> IndexLambda: from pytato.reductions import SumReductionOperation from pytato.scalar_expr import Reduce index_expr = Reduce( - prim.Variable("_in0")[prim.Variable("_r0"),] + prim.Variable("_in0")[prim.Variable(redn_var),] * prim.Variable("_in3")[( - prim.Variable("_in1")[prim.Variable("_r0"),], + prim.Variable("_in1")[prim.Variable(redn_var),], *( prim.Variable(f"_{idim}") for idim in range(1, rec_array.ndim)))], SumReductionOperation(), constantdict({ - "_r0": ( + redn_var: ( prim.Variable("_in2")[prim.Variable("_0"),], prim.Variable("_in2")[prim.Variable("_0") + 1,])})) @@ -762,7 +763,7 @@ def map_csr_matmul(self, expr: CSRMatmul) -> IndexLambda: "_in3": rec_array}), axes=expr.axes, var_to_reduction_descr=constantdict({ - "_r0": ReductionDescriptor(tags=frozenset())}), + redn_var: expr.reduction_descr}), tags=expr.tags, non_equality_tags=expr.non_equality_tags) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 36aab2391..0bb99aaf2 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -507,6 +507,11 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: self.axis_to_tags.get((expr, redn_var), []) ) + if isinstance(expr, CSRMatmul): + assert isinstance(result, CSRMatmul) + tags = self.axis_to_tags.get((expr, expr.reduction_var), []) + result = result.with_tagged_reduction(tags) + # }}} return result