Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 28 additions & 135 deletions gtep/tests/unit/pyomo_object_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
#################################################################################


from io import StringIO
import re

from pyomo.environ import units as u
from pyomo.core.base.block import BlockData
from pyomo.core.base.component import ComponentData
import pyomo.environ as pyo


Expand Down Expand Up @@ -160,140 +158,35 @@ def check_all_objects(self):
else:
self._check_does_not_exist(properties)

@classmethod
def _extract_terms_from_string_expression(cls, expr: str) -> list[tuple]:
"""
Helper function for `parse_constraint_pprint` that extracts
individual terms from a string representation of an expression,
including flattening nested parentheses.

:param expr: Expression to be parsed.
:type expr: str
:returns: List of tuples. Each tuple corresponds to a term in the
expression and is of the form `(sign, term)` where `sign`
is either `1` or `-1` and `term` is the term itself
(including index, e.g. `"m.lines[branch_2_1]"`)
"""
expr = expr.replace(" ", "")
stack = [1] # sign context
sign = 1
term = ""
out = []

i = 0
while i < len(expr):
c = expr[i]
if c in "+-":
if term:
out.append((sign * stack[-1], term))
term = ""
sign = 1 if c == "+" else -1
elif c == "(":
stack.append(stack[-1] * sign)
sign = 1
elif c == ")":
if term:
out.append((sign * stack[-1], term))
term = ""
stack.pop()
else:
term += c
i += 1
if term:
out.append((sign * stack[-1], term))

return out

@classmethod
def parse_constraint_pprint(
cls, block_name: str, constraint: pyo.Constraint
) -> dict:
"""
Parses the output of a constraint's `.pprint()` function, returning
a dictionary representing the contents of the constraints' expressions,
of the form:
```
{
i: {
"expr": [(sign, term), (sign, term), ...],
"val": val
}
}
```
where
- `i` is an element of the constraint's index (per its `.pprint()` function)
- `val` is the value of the expression associated with `i`
- `term` is an individual term of the expression associated with `i`
- `sign` is the sign of an individual term of the expression associated with `i`

:param block_name: `.name` attribute of the block that `constraint` is on.
:param constraint: Constraint to be parsed.
:type block_name: str
:type constraint: pyomo.environ.Constraint
def check_expr_contains(self, c: ComponentData, expected: list | dict):
"""
buf = StringIO()
constraint.pprint(ostream=buf)
pprinted = buf.getvalue().replace(block_name, "b")

out = {}
for index_expr_pprinted in pprinted.split("\n")[3:-1]:
index_expr_split = index_expr_pprinted.split(":")
i = index_expr_split[0].strip()
expr = index_expr_split[2].strip()
val = index_expr_split[3].strip()
# print(i, ":", expr)

out[i] = {
"expr": cls._extract_terms_from_string_expression(expr),
"val": val,
}
return out

def check_constraint_for_terms(
self,
constraint_expr: list,
term_to_find: str,
expected_signs: list[int],
expected_indices: list[str],
):
Checks that the given component expr has exactly the given params and vars.

:param c: Component
:param expected: A list containing every param and var expected to be
in the expr of this component, or a dict mapping
index elements to such lists for indexed components.
:type c: ComponentData
:type expected: list | dict
"""
Checks that the given constraint expression contains expected term(s). Intended as a helper
function for constraint check functions.

:param constraint_expr: Constraint expression (`"expr"` value from an element of `parse_constraint_pprint`).
:param term_to_find: Name of term to find, not including the index (e.g., `"loads"`).
:param expected_signs: Expected signs of matching terms.
:param expected_indices: Expected indices of matching terms.
:type constraint_expr: dict
:type term_to_find: str
:type expected_signs: list[int]
:type expected_indices: list
"""
if len(expected_signs) != len(expected_indices):
raise ValueError(
"Expected signs and indices must be lists of matching length."
self._iter_func_over_index(self._nonindexed_expr_contains, c, expected=expected)

def _nonindexed_expr_contains(self, c: ComponentData, expected: list):
actual = [
a
for a in (
list(pyo.visitor.identify_variables(c.expr))
+ list(pyo.visitor.identify_mutable_parameters(c.expr))
)
if isinstance(a, ComponentData)
]

matching_terms = []
for sign, term in constraint_expr:
match = re.fullmatch(rf"{term_to_find}\[([^\]]+)\]", term)
if match:
matching_terms.append(
{
"sign": sign,
"index": match.group(1),
}
)
ids = [id(a) for a in actual]
names = [a.name for a in actual]
expected_names = [e.name for e in expected]

self.test_class.assertEqual(len(matching_terms), len(expected_signs))
for s, i in zip(expected_signs, expected_indices):
matching_term = [
term
for term in matching_terms
if term["sign"] == s and term["index"] == i
]
self.test_class.assertEqual(
len(matching_term),
1,
f"There should be only one term matching {'-' if s == -1 else ''}{term_to_find}[{i}]",
)
self.test_class.assertEqual(
len(actual), len(expected), f"{expected_names} vs {names}"
)
for exp, exp_name in zip(expected, expected_names):
self.test_class.assertIn(id(exp), ids, f"{exp_name} not in {names}")
Loading