diff --git a/gtep/tests/unit/pyomo_object_testing.py b/gtep/tests/unit/pyomo_object_testing.py index d518535..2ac4c92 100644 --- a/gtep/tests/unit/pyomo_object_testing.py +++ b/gtep/tests/unit/pyomo_object_testing.py @@ -12,11 +12,9 @@ ################################################################################# -from io import StringIO -import re - from pyomo.environ import units as u from pyomo.core.base.block import BlockData +from pyomo.core.base.component import ComponentData import pyomo.environ as pyo @@ -160,140 +158,35 @@ def check_all_objects(self): else: self._check_does_not_exist(properties) - @classmethod - def _extract_terms_from_string_expression(cls, expr: str) -> list[tuple]: - """ - Helper function for `parse_constraint_pprint` that extracts - individual terms from a string representation of an expression, - including flattening nested parentheses. - - :param expr: Expression to be parsed. - :type expr: str - :returns: List of tuples. Each tuple corresponds to a term in the - expression and is of the form `(sign, term)` where `sign` - is either `1` or `-1` and `term` is the term itself - (including index, e.g. `"m.lines[branch_2_1]"`) - """ - expr = expr.replace(" ", "") - stack = [1] # sign context - sign = 1 - term = "" - out = [] - - i = 0 - while i < len(expr): - c = expr[i] - if c in "+-": - if term: - out.append((sign * stack[-1], term)) - term = "" - sign = 1 if c == "+" else -1 - elif c == "(": - stack.append(stack[-1] * sign) - sign = 1 - elif c == ")": - if term: - out.append((sign * stack[-1], term)) - term = "" - stack.pop() - else: - term += c - i += 1 - if term: - out.append((sign * stack[-1], term)) - - return out - - @classmethod - def parse_constraint_pprint( - cls, block_name: str, constraint: pyo.Constraint - ) -> dict: - """ - Parses the output of a constraint's `.pprint()` function, returning - a dictionary representing the contents of the constraints' expressions, - of the form: - ``` - { - i: { - "expr": [(sign, term), (sign, term), ...], - "val": val - } - } - ``` - where - - `i` is an element of the constraint's index (per its `.pprint()` function) - - `val` is the value of the expression associated with `i` - - `term` is an individual term of the expression associated with `i` - - `sign` is the sign of an individual term of the expression associated with `i` - - :param block_name: `.name` attribute of the block that `constraint` is on. - :param constraint: Constraint to be parsed. - :type block_name: str - :type constraint: pyomo.environ.Constraint + def check_expr_contains(self, c: ComponentData, expected: list | dict): """ - buf = StringIO() - constraint.pprint(ostream=buf) - pprinted = buf.getvalue().replace(block_name, "b") - - out = {} - for index_expr_pprinted in pprinted.split("\n")[3:-1]: - index_expr_split = index_expr_pprinted.split(":") - i = index_expr_split[0].strip() - expr = index_expr_split[2].strip() - val = index_expr_split[3].strip() - # print(i, ":", expr) - - out[i] = { - "expr": cls._extract_terms_from_string_expression(expr), - "val": val, - } - return out - - def check_constraint_for_terms( - self, - constraint_expr: list, - term_to_find: str, - expected_signs: list[int], - expected_indices: list[str], - ): + Checks that the given component expr has exactly the given params and vars. + + :param c: Component + :param expected: A list containing every param and var expected to be + in the expr of this component, or a dict mapping + index elements to such lists for indexed components. + :type c: ComponentData + :type expected: list | dict """ - Checks that the given constraint expression contains expected term(s). Intended as a helper - function for constraint check functions. - - :param constraint_expr: Constraint expression (`"expr"` value from an element of `parse_constraint_pprint`). - :param term_to_find: Name of term to find, not including the index (e.g., `"loads"`). - :param expected_signs: Expected signs of matching terms. - :param expected_indices: Expected indices of matching terms. - :type constraint_expr: dict - :type term_to_find: str - :type expected_signs: list[int] - :type expected_indices: list - """ - if len(expected_signs) != len(expected_indices): - raise ValueError( - "Expected signs and indices must be lists of matching length." + self._iter_func_over_index(self._nonindexed_expr_contains, c, expected=expected) + + def _nonindexed_expr_contains(self, c: ComponentData, expected: list): + actual = [ + a + for a in ( + list(pyo.visitor.identify_variables(c.expr)) + + list(pyo.visitor.identify_mutable_parameters(c.expr)) ) + if isinstance(a, ComponentData) + ] - matching_terms = [] - for sign, term in constraint_expr: - match = re.fullmatch(rf"{term_to_find}\[([^\]]+)\]", term) - if match: - matching_terms.append( - { - "sign": sign, - "index": match.group(1), - } - ) + ids = [id(a) for a in actual] + names = [a.name for a in actual] + expected_names = [e.name for e in expected] - self.test_class.assertEqual(len(matching_terms), len(expected_signs)) - for s, i in zip(expected_signs, expected_indices): - matching_term = [ - term - for term in matching_terms - if term["sign"] == s and term["index"] == i - ] - self.test_class.assertEqual( - len(matching_term), - 1, - f"There should be only one term matching {'-' if s == -1 else ''}{term_to_find}[{i}]", - ) + self.test_class.assertEqual( + len(actual), len(expected), f"{expected_names} vs {names}" + ) + for exp, exp_name in zip(expected, expected_names): + self.test_class.assertIn(id(exp), ids, f"{exp_name} not in {names}")