diff --git a/mps-cli-py/src/mpscli/model/SModel.py b/mps-cli-py/src/mpscli/model/SModel.py index d54ab04..333de27 100644 --- a/mps-cli-py/src/mpscli/model/SModel.py +++ b/mps-cli-py/src/mpscli/model/SModel.py @@ -2,9 +2,10 @@ class SModel: - def __init__(self, name, uuid, is_do_not_generate): + def __init__(self, name, uuid, is_do_not_generate, imported_models=None): self.name = name self.uuid = uuid + self.imported_models = {} if imported_models is None else imported_models self.root_nodes = [] self.path_to_model_file = "" self.is_do_not_generate = is_do_not_generate diff --git a/mps-cli-py/src/mpscli/model/builder/SModelBuilderBase.py b/mps-cli-py/src/mpscli/model/builder/SModelBuilderBase.py index 1230b90..2210afe 100644 --- a/mps-cli-py/src/mpscli/model/builder/SModelBuilderBase.py +++ b/mps-cli-py/src/mpscli/model/builder/SModelBuilderBase.py @@ -51,21 +51,32 @@ def is_model_generatable(model_xml_node): for attribute in model_xml_node.findall("attribute") ) + @staticmethod + def extract_imported_models(self, model_xml_node): + imported_models = {} + imports_xml_node = model_xml_node.find("imports") + if imports_xml_node is None: + return imported_models + + for import_xml_node in imports_xml_node.findall("import"): + import_index = import_xml_node.get("index") + imported_model_ref = import_xml_node.get("ref") + imported_model_uuid = imported_model_ref[0 : imported_model_ref.find("(")] + imported_models[import_index] = imported_model_uuid + self.index_2_imported_model_uuid[import_index] = imported_model_uuid + + return imported_models + def extract_model_core_info(self, model_xml_node): model_ref = model_xml_node.get("ref") model_name = model_ref[model_ref.find("(") + 1 : len(model_ref) - 1] model_uuid = model_ref[0 : model_ref.find("(")] model_is_do_not_generate = self.is_model_generatable(model_xml_node) - model = SModel(model_name, model_uuid, model_is_do_not_generate) + model_imported_models = self.extract_imported_models(self, model_xml_node) + model = SModel(model_name, model_uuid, model_is_do_not_generate, model_imported_models) return model - def extract_imports_and_registry(self, model_xml_node): - imports_xml_node = model_xml_node.find("imports") - for import_xml_node in imports_xml_node.findall("import"): - import_index = import_xml_node.get("index") - imported_model_ref = import_xml_node.get("ref") - imported_model_uuid = imported_model_ref[0: imported_model_ref.find("(")] - self.index_2_imported_model_uuid[import_index] = imported_model_uuid + def extract_registry(self, model_xml_node): registry_xml_node = model_xml_node.find("registry") for language_xml_node in registry_xml_node.findall("language"): language_id = language_xml_node.get("id") diff --git a/mps-cli-py/src/mpscli/model/builder/SModelBuilderBinaryPersistency.py b/mps-cli-py/src/mpscli/model/builder/SModelBuilderBinaryPersistency.py index 794a7bd..0931349 100644 --- a/mps-cli-py/src/mpscli/model/builder/SModelBuilderBinaryPersistency.py +++ b/mps-cli-py/src/mpscli/model/builder/SModelBuilderBinaryPersistency.py @@ -93,12 +93,12 @@ def build(self, path_to_model: str): uuid_str = model_uuid or "r:unknown" name_str = model_name or "unknown.model" - model = SModel(name_str, uuid_str, False) # Import index 0 is always the current model's own uuid.. # Java: SModel.importedModels() lists imports starting from index 1 and index 0 is implicitly # the model itself used when resolving REF_THIS_MODEL self.index_2_imported_model_uuid["0"] = uuid_str + model = SModel(name_str, uuid_str, False) # 2. registry - builds concept/property/reference/child index maps load_registry(reader, self) @@ -116,6 +116,12 @@ def build(self, path_to_model: str): advance_until_after(reader, MODEL_START) return model + model.imported_models = { + index: imported_model_uuid + for index, imported_model_uuid in self.index_2_imported_model_uuid.items() + if index != "0" + } + # 4. MODEL_START token = reader.read_u32() if token != MODEL_START: diff --git a/mps-cli-py/src/mpscli/model/builder/SModelBuilderDefaultPersistency.py b/mps-cli-py/src/mpscli/model/builder/SModelBuilderDefaultPersistency.py index bb530ae..9dcff52 100644 --- a/mps-cli-py/src/mpscli/model/builder/SModelBuilderDefaultPersistency.py +++ b/mps-cli-py/src/mpscli/model/builder/SModelBuilderDefaultPersistency.py @@ -10,7 +10,7 @@ def build(self, path): model_xml_node = tree.getroot() model = self.extract_model_core_info(model_xml_node) model.path_to_model_file = path - self.extract_imports_and_registry(model_xml_node) + self.extract_registry(model_xml_node) for node_xml_node in model_xml_node.findall("node"): root_node = self.extract_node(model, node_xml_node, None) diff --git a/mps-cli-py/src/mpscli/model/builder/SModelBuilderFilePerRootPersistency.py b/mps-cli-py/src/mpscli/model/builder/SModelBuilderFilePerRootPersistency.py index 2ba25fc..76e503f 100644 --- a/mps-cli-py/src/mpscli/model/builder/SModelBuilderFilePerRootPersistency.py +++ b/mps-cli-py/src/mpscli/model/builder/SModelBuilderFilePerRootPersistency.py @@ -22,7 +22,7 @@ def build(self, path): def extract_root_node(self, model, mpsr_file): tree = ET.parse(mpsr_file) model_xml_node = tree.getroot() - self.extract_imports_and_registry(model_xml_node) + self.extract_registry(model_xml_node) root_node = model_xml_node.find("node") return self.extract_node(model, root_node, None) diff --git a/mps-cli-py/tests/binary/test_binary_model_imports.py b/mps-cli-py/tests/binary/test_binary_model_imports.py new file mode 100644 index 0000000..82fcce5 --- /dev/null +++ b/mps-cli-py/tests/binary/test_binary_model_imports.py @@ -0,0 +1,25 @@ +import unittest + +from parameterized import parameterized + +from tests.test_base import TestBase + + +class TestBinaryModelImports(TestBase): + + @parameterized.expand( + [ + ( + "mps_cli_binary_persistency_generated", + "mps.cli.lanuse.library_top.binary_persistency.library_top", + "r:cf91f372-8bfd-44b8-8e34-024eb23e64a8", + ), + ] + ) + def test_model_imports(self, test_data_location, model_name, imported_model_uuid): + self.doSetUp(test_data_location) + + model = self.repo.find_model_by_name(model_name) + self.assertNotEqual(None, model) + self.assertIsInstance(model.imported_models, dict) + self.assertIn(imported_model_uuid, model.imported_models.values()) \ No newline at end of file diff --git a/mps-cli-py/tests/test_model_imports.py b/mps-cli-py/tests/test_model_imports.py new file mode 100644 index 0000000..323c8d6 --- /dev/null +++ b/mps-cli-py/tests/test_model_imports.py @@ -0,0 +1,30 @@ +import unittest + +from parameterized import parameterized + +from tests.test_base import TestBase + + +class TestModelImports(TestBase): + + @parameterized.expand( + [ + ( + "mps_cli_lanuse_file_per_root", + "mps.cli.lanuse.library_top.library_top", + "r:ec5f093b-9d83-43a1-9b41-b5952da8b1ed", + ), + ( + "mps_cli_lanuse_default_persistency", + "mps.cli.lanuse.library_top.default_persistency.library_top", + "r:ca00da79-915e-4bdb-9c30-11a341daf779", + ), + ] + ) + def test_model_imports(self, test_data_location, model_name, imported_model_uuid): + self.doSetUp(test_data_location) + + model = self.repo.find_model_by_name(model_name) + self.assertNotEqual(None, model) + self.assertIsInstance(model.imported_models, dict) + self.assertIn(imported_model_uuid, model.imported_models.values()) \ No newline at end of file