diff --git a/ingestion/setup.py b/ingestion/setup.py index 1c817e78ece0..9776a8f053f6 100644 --- a/ingestion/setup.py +++ b/ingestion/setup.py @@ -165,6 +165,7 @@ "sqlalchemy>=2.0.0,<3", "collate-sqllineage>=2.1.1", "tabulate==0.9.0", + "tenacity>=8.0,<10", "typing-inspect", "packaging", # For version parsing "setuptools>=78.1.1,<81", # <81 required: pkg_resources removed in setuptools 81+ diff --git a/ingestion/src/metadata/domain/__init__.py b/ingestion/src/metadata/domain/__init__.py new file mode 100644 index 000000000000..8cc65ba75bbb --- /dev/null +++ b/ingestion/src/metadata/domain/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OpenMetadata domain utilities. + +In-memory helpers operating on OpenMetadata's data model, reusable across +service-source bases and features. A module belongs here when it satisfies +ALL of: + +1. Knows OM concepts (operates on OM-generated types or OM-specific ideas). +2. Owns no I/O infrastructure. May use an INJECTED OM client for read-only + queries; the client's lifecycle is the caller's. +3. Framework-independent — no topology, stages, or sinks. +4. Cross-cutting — used by more than one service-source base or feature. +""" diff --git a/ingestion/src/metadata/domain/tags/__init__.py b/ingestion/src/metadata/domain/tags/__init__.py new file mode 100644 index 000000000000..2ee6134b0bca --- /dev/null +++ b/ingestion/src/metadata/domain/tags/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tag and Classification domain utilities.""" + +from metadata.domain.tags.canonicalizer import Canonical, TagCanonicalizer +from metadata.domain.tags.registry import ScopeAlreadyClearedError, TagRegistry + +__all__ = [ + "Canonical", + "ScopeAlreadyClearedError", + "TagCanonicalizer", + "TagRegistry", +] diff --git a/ingestion/src/metadata/domain/tags/canonicalizer.py b/ingestion/src/metadata/domain/tags/canonicalizer.py new file mode 100644 index 000000000000..199b704d2ec6 --- /dev/null +++ b/ingestion/src/metadata/domain/tags/canonicalizer.py @@ -0,0 +1,136 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TagCanonicalizer — case-corrected name resolution against OpenMetadata. + +Resolves source-system Classification and Tag names to the canonical form +of any matching system-provider entity in OM (e.g., source reports +``pii.sensitive`` → returns ``PII.Sensitive``). Persistent ES failures +raise after retry exhaustion. +""" + +import logging +import threading +from collections.abc import Iterable +from typing import Any, NamedTuple, cast + +from tenacity import ( + before_sleep_log, + retry, + stop_after_attempt, + wait_random_exponential, +) + +from metadata.generated.schema.entity.classification.classification import Classification +from metadata.generated.schema.entity.classification.tag import Tag +from metadata.generated.schema.type.basic import ProviderType +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.utils import fqn +from metadata.utils.logger import ingestion_logger + +logger = ingestion_logger() + + +_es_retry = retry( + stop=stop_after_attempt(5), + wait=wait_random_exponential(multiplier=2, max=30), + reraise=True, + before_sleep=before_sleep_log(logger, logging.WARNING), +) + + +class Canonical(NamedTuple): + """Canonical (name, description) pair returned from OpenMetadata.""" + + name: str + description: str | None + + +class TagCanonicalizer: + """Case-corrected name resolution for system Classifications and Tags. + + Persistent ES failures raise; callers should wrap in ``Either`` to + surface them to workflow status. + """ + + def __init__(self, metadata: OpenMetadata) -> None: + self._metadata = metadata + self._classification_cache: dict[str, Canonical] = {} + self._tag_cache: dict[str, Canonical] = {} + self._lock = threading.RLock() + + def classification( + self, + name: str, + description: str | None = None, + ) -> Canonical: + """Return canonical classification name + description from OM, cached.""" + key = name.lower() + with self._lock: + cached = self._classification_cache.get(key) + if cached is not None: + return cached + + results = self._es_search(Classification, name) + canonical = Canonical(name=name, description=description) + for entity in results: + if entity.provider == ProviderType.system and entity.name.root.lower() == key: + canonical = Canonical( + name=entity.name.root, + description=entity.description.root if entity.description else description, + ) + break + + with self._lock: + self._classification_cache.setdefault(key, canonical) + return canonical + + def tag( + self, + classification_name: str, + tag_name: str, + tag_description: str | None = None, + ) -> Canonical: + """Return canonical tag name + description from OM, cached. + + ``classification_name`` must already be canonical (call ``classification`` first). + """ + tag_fqn = cast( + "str", + fqn.build(None, Tag, classification_name=classification_name, tag_name=tag_name), + ) + key = tag_fqn.lower() + with self._lock: + cached = self._tag_cache.get(key) + if cached is not None: + return cached + + results = self._es_search(Tag, tag_fqn) + canonical = Canonical(name=tag_name, description=tag_description) + for entity in results: + if ( + entity.provider == ProviderType.system + and entity.classification.name == classification_name + and entity.name.root.lower() == tag_name.lower() + ): + canonical = Canonical( + name=entity.name.root, + description=entity.description.root if entity.description else tag_description, + ) + break + + with self._lock: + self._tag_cache.setdefault(key, canonical) + return canonical + + @_es_retry + def _es_search(self, entity_type: Any, search_string: str) -> Iterable[Any]: + """Run an ES search by FQN with retries.""" + return self._metadata.es_search_from_fqn(entity_type=entity_type, fqn_search_string=search_string) or [] diff --git a/ingestion/src/metadata/domain/tags/registry.py b/ingestion/src/metadata/domain/tags/registry.py new file mode 100644 index 000000000000..a959e72132bb --- /dev/null +++ b/ingestion/src/metadata/domain/tags/registry.py @@ -0,0 +1,235 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TagRegistry — per-Source bookkeeping for Tag and Classification ingestion. + +Holds two concerns: + +* a queue of classification/tag create-payloads bound for the sink + (deduped by FQN, drained per scope), and +* a per-entity-FQN lookup of ``TagLabel`` instances for inheritance + reads, dropped at scope boundaries. + +Dedup is case-sensitive, matching OpenMetadata's tag-identity rule. +Safe for concurrent use across the topology's parallel schema workers. +""" + +import threading +from collections.abc import Iterable +from typing import NamedTuple, cast + +from metadata.generated.schema.api.classification.createClassification import ( + CreateClassificationRequest, +) +from metadata.generated.schema.api.classification.createTag import CreateTagRequest +from metadata.generated.schema.entity.classification.tag import Tag +from metadata.generated.schema.type.basic import ( + EntityName, + FullyQualifiedEntityName, + Markdown, +) +from metadata.generated.schema.type.tagLabel import ( + LabelType, + State, + TagFQN, + TagLabel, + TagSource, +) +from metadata.ingestion.models.ometa_classification import OMetaTagAndClassification +from metadata.ingestion.ometa.ometa_api import OpenMetadata +from metadata.ingestion.ometa.utils import model_str +from metadata.utils import fqn +from metadata.utils.logger import ingestion_logger + +logger = ingestion_logger() + + +class _TagLabelKey(NamedTuple): + """Identity tuple for the TagLabel cache.""" + + classification_name: str + tag_name: str + label_type: LabelType + state: State + + +class ScopeAlreadyClearedError(RuntimeError): + """Raised when 'attach' is called for a previously cleared scope. + + Surfaces topology lifecycle bug loudly rather than silently re-creating a cleared scope. + """ + + +class TagRegistry: + """Registry for Tag and Classification ingestion bookkeeping.""" + + def __init__(self, metadata: OpenMetadata) -> None: + self._metadata = metadata + + self._known_tag_fqns: set[str] = set() + self._tag_label_cache: dict[_TagLabelKey, TagLabel] = {} + self._pending: list[OMetaTagAndClassification] = [] + self._cleared_scopes: set[str] = set() + self._labels_by_entity: dict[str, list[TagLabel]] = {} + + self._lock = threading.Lock() + + def _intern_tag_label_locked( + self, *, classification_name: str, tag_name: str, label_type: LabelType, state: State + ) -> TagLabel: + """Return the shared ``TagLabel`` for the given key. Caller must hold ``self._lock``.""" + key = _TagLabelKey(classification_name, tag_name, label_type, state) + cached = self._tag_label_cache.get(key) + if cached is not None: + return cached + tag_fqn = cast("str", fqn.build(None, Tag, classification_name=classification_name, tag_name=tag_name)) + cached = TagLabel( # pyright: ignore[reportCallIssue] + tagFQN=TagFQN(tag_fqn), + labelType=label_type, + state=state, + source=TagSource.Classification, + ) + self._tag_label_cache[key] = cached + return cached + + def attach( + self, + *, + scope_fqn: str, + entity_fqn: str, + classification_name: str, + tag_name: str, + classification_description: str | None, + tag_description: str | None, + label_type: LabelType = LabelType.Automated, + state: State = State.Suggested, + ) -> None: + """Register a tag <-> entity association.""" + if not tag_name or not tag_name.strip(): + logger.debug("TagRegistry: skipping empty tag for classification %s", classification_name) + return + + with self._lock: + if scope_fqn in self._cleared_scopes: + raise ScopeAlreadyClearedError( + f"Tag attach called for cleared scope '{scope_fqn!r}' for entity '{entity_fqn!r}'" + ) + tag_label = self._intern_tag_label_locked( + classification_name=classification_name, + tag_name=tag_name, + label_type=label_type, + state=state, + ) + self._labels_by_entity.setdefault(entity_fqn, []).append(tag_label) + + tag_fqn = model_str(tag_label.tagFQN) + if tag_fqn not in self._known_tag_fqns: + self._known_tag_fqns.add(tag_fqn) + self._pending.append( + self._build_pending_record( + classification_name=classification_name, + classification_description=classification_description, + tag_name=tag_name, + tag_description=tag_description, + ) + ) + + def labels_for(self, entity_fqn: str) -> list[TagLabel]: + """Return tag labels attached to ``entity_fqn`` (idempotent; returns a copy).""" + with self._lock: + return list(self._labels_by_entity.get(entity_fqn, [])) + + def drain(self) -> Iterable[OMetaTagAndClassification]: + """Yield all queued create payloads and clear the queue.""" + with self._lock: + pending, self._pending = self._pending, [] + + if pending: + logger.debug("TagRegistry: drained %d pending tag payloads.", len(pending)) + yield from pending + + def clear_scope(self, scope_fqn: str) -> None: + """Drop labels under ``scope_fqn`` and mark the scope cleared. + + Subsequent ``attach`` calls for this scope will raise. + """ + prefix = scope_fqn + fqn.FQN_SEPARATOR + + with self._lock: + self._cleared_scopes.add(scope_fqn) + kept = {k: v for k, v in self._labels_by_entity.items() if k != scope_fqn and not k.startswith(prefix)} + dropped = len(self._labels_by_entity) - len(kept) + self._labels_by_entity = kept + if dropped: + logger.debug("TagRegistry: cleared scope %s (%d entity labels dropped)", scope_fqn, dropped) + + def is_known(self, tag_fqn: str) -> bool: + """Return True if the tag FQN has been recorded (case-sensitive match).""" + with self._lock: + return tag_fqn in self._known_tag_fqns + + def ensure_known(self, tag_fqn: str) -> bool: + """Return True if the tag exists server-side, caching positive results. + + Returns False (and does NOT cache) on 404 or transport error. + """ + if self.is_known(tag_fqn): + return True + + logger.debug("TagRegistry: cache miss for %s; fetching from OpenMetadata.", tag_fqn) + try: + entity = self._metadata.get_by_name(entity=Tag, fqn=tag_fqn) + except Exception: + logger.exception("TagRegistry: tag lookup failed for %s.", tag_fqn) + return False + + if entity is None: + logger.warning( + "TagRegistry: tag %s not found in OpenMetadata; labels referencing it will be skipped.", tag_fqn + ) + return False + + with self._lock: + self._known_tag_fqns.add(tag_fqn) + return True + + def stats(self) -> dict[str, int]: + """Return current state counts for instrumentation.""" + with self._lock: + return { + "known_tag_fqns": len(self._known_tag_fqns), + "tag_label_cache": len(self._tag_label_cache), + "pending": len(self._pending), + "cleared_scopes": len(self._cleared_scopes), + "live_entities": len(self._labels_by_entity), + "live_labels": sum(len(v) for v in self._labels_by_entity.values()), + } + + @staticmethod + def _build_pending_record( + *, + classification_name: str, + classification_description: str | None, + tag_name: str, + tag_description: str | None, + ) -> OMetaTagAndClassification: + """Compose the sink-bound create-payload for a classification + tag.""" + return OMetaTagAndClassification( + fqn=None, + classification_request=CreateClassificationRequest( # pyright: ignore[reportCallIssue] + name=EntityName(classification_name), + description=Markdown(classification_description or ""), + ), + tag_request=CreateTagRequest( # pyright: ignore[reportCallIssue] + classification=FullyQualifiedEntityName(classification_name), + name=EntityName(tag_name), + description=Markdown(tag_description or ""), + ), + ) diff --git a/ingestion/src/metadata/ingestion/models/topology.py b/ingestion/src/metadata/ingestion/models/topology.py index ebc919b6b203..003682b3400d 100644 --- a/ingestion/src/metadata/ingestion/models/topology.py +++ b/ingestion/src/metadata/ingestion/models/topology.py @@ -15,7 +15,7 @@ import queue import threading from functools import cache, singledispatchmethod -from typing import Any, Dict, Generic, List, Optional, Type, TypeVar # noqa: UP035 +from typing import Annotated, Any, Dict, Generic, List, Optional, Type, TypeVar # noqa: UP035 from pydantic import BaseModel, ConfigDict, Field, create_model @@ -111,14 +111,18 @@ class TopologyNode(BaseModel): "Each stage accepts the producer results as an argument" ), ) - children: Optional[List[str]] = Field(None, description="Nodes to execute next") # noqa: UP006, UP045 - post_process: Optional[List[str]] = Field( # noqa: UP006, UP045 - None, description="Method to be run after the node has been fully processed" - ) - threads: bool = Field( - False, - description="Flag that defines if a node is open to MultiThreading processing.", - ) + children: Annotated[ + list[str] | None, + Field(description="Nodes to execute next"), + ] = None + post_process: Annotated[ + list[str] | None, + Field(description="Method to be run after the node has been fully processed"), + ] = None + threads: Annotated[ + bool, + Field(description="Flag that defines if a node is open to MultiThreading processing."), + ] = False class ServiceTopology(BaseModel): diff --git a/ingestion/src/metadata/ingestion/source/database/database_service.py b/ingestion/src/metadata/ingestion/source/database/database_service.py index 4472404a0473..31628601ac6a 100644 --- a/ingestion/src/metadata/ingestion/source/database/database_service.py +++ b/ingestion/src/metadata/ingestion/source/database/database_service.py @@ -14,12 +14,13 @@ import traceback from abc import ABC, abstractmethod -from typing import Any, Iterable, List, Optional, Set, Tuple # noqa: UP035 +from typing import Any, Iterable, List, Optional, Set, Tuple, cast # noqa: UP035 from pydantic import BaseModel, Field from sqlalchemy.engine import Inspector from typing_extensions import Annotated # noqa: UP035 +from metadata.domain.tags import TagCanonicalizer, TagRegistry from metadata.generated.schema.api.data.createDatabase import CreateDatabaseRequest from metadata.generated.schema.api.data.createDatabaseSchema import ( CreateDatabaseSchemaRequest, @@ -160,6 +161,7 @@ class DatabaseServiceTopology(ServiceTopology): "mark_schemas_as_deleted", "mark_tables_as_deleted", "mark_stored_procedures_as_deleted", + "clear_database_tag_scope", ], threads=True, ) @@ -186,6 +188,7 @@ class DatabaseServiceTopology(ServiceTopology): nullable=True, ), ], + post_process=["clear_schema_tag_scope"], ) stored_procedure: Annotated[TopologyNode, Field(description="Stored Procedure Node")] = TopologyNode( producer="get_stored_procedures", @@ -224,6 +227,26 @@ class DatabaseServiceSource(TopologyRunnerMixin, Source, ABC): # pylint: disabl topology = DatabaseServiceTopology() context = TopologyContextManager(topology) + # ``vars(self).setdefault(...)`` for thread-safe lazy init. + # See: https://docs.python.org/3/library/threadsafety.html + @property + def tags_registry(self) -> TagRegistry: + """Per-Source registry tracking tag/classification ingestion state.""" + instance_dict = vars(self) + cached = instance_dict.get("tags_registry") + if cached is not None: + return cached + return instance_dict.setdefault("tags_registry", TagRegistry(metadata=self.metadata)) + + @property + def tag_canonicalizer(self) -> TagCanonicalizer: + """Per-Source canonicalizer for case-corrected tag/classification names.""" + instance_dict = vars(self) + cached = instance_dict.get("tag_canonicalizer") + if cached is not None: + return cached + return instance_dict.setdefault("tag_canonicalizer", TagCanonicalizer(metadata=self.metadata)) + @property def name(self) -> str: return self.service_connection.type.name @@ -811,6 +834,39 @@ def yield_life_cycle_data(self, _) -> Iterable[Either[OMetaLifeCycleData]]: Get the life cycle data of the table """ + def clear_schema_tag_scope(self): + """Drop tag-registry state for the current schema scope.""" + schema_name = self.context.get().database_schema # pyright: ignore[reportAttributeAccessIssue] + if schema_name: + schema_fqn = cast( + "str", + fqn.build( + self.metadata, + entity_type=DatabaseSchema, + service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue] + database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue] + schema_name=schema_name, + ), + ) + self.tags_registry.clear_scope(schema_fqn) + yield from () + + def clear_database_tag_scope(self): + """Drop tag-registry state for the current database scope.""" + database_name = self.context.get().database # pyright: ignore[reportAttributeAccessIssue] + if database_name: + database_fqn = cast( + "str", + fqn.build( + self.metadata, + entity_type=Database, + service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue] + database_name=database_name, + ), + ) + self.tags_registry.clear_scope(database_fqn) + yield from () + def yield_external_table_lineage(self) -> Iterable[Either[AddLineageRequest]]: """ Process external table lineage diff --git a/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py b/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py index 817407e568bf..b48c751e852f 100644 --- a/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py +++ b/ingestion/src/metadata/ingestion/source/database/snowflake/metadata.py @@ -15,7 +15,7 @@ import json # noqa: I001 import traceback from datetime import datetime -from typing import Iterable, List, Optional, Tuple # noqa: UP035 +from typing import Iterable, List, Optional, Tuple, cast # noqa: UP035 import sqlalchemy.types as sqltypes import sqlparse @@ -37,6 +37,7 @@ StoredProcedureType, ) from metadata.generated.schema.entity.data.table import ( + Column, PartitionColumnDetails, PartitionIntervalTypes, Table, @@ -54,7 +55,6 @@ ) from metadata.generated.schema.type.basic import ( EntityName, - FullyQualifiedEntityName, SourceUrl, ) from metadata.generated.schema.type.entityReferenceList import EntityReferenceList @@ -135,7 +135,6 @@ get_all_table_ddls, get_all_view_definitions, ) -from metadata.utils.tag_utils import get_ometa_tag_and_classification, get_tag_label class MAP(StructuredType): @@ -548,9 +547,20 @@ def yield_tag(self, schema_name: str) -> Iterable[Either[OMetaTagAndClassificati logger.debug(traceback.format_exc()) logger.error(f"Failed to fetch tags due to [{inner_exc}]") + schema_fqn = cast( + "str", + fqn.build( + self.metadata, + entity_type=DatabaseSchema, + service_name=self.context.get().database_service, + database_name=self.context.get().database, + schema_name=schema_name, + ), + ) for res in result: row = list(res) fqn_elements = [name for name in row[2:] if name] + # row[0] = TAG_NAME, row[1] = TAG_VALUE if not row[1]: logger.warning( @@ -558,62 +568,105 @@ def yield_tag(self, schema_name: str) -> Iterable[Either[OMetaTagAndClassificati "TAG_VALUE is empty. Snowflake tags require a value to be ingested." ) continue - yield from get_ometa_tag_and_classification( - tag_fqn=FullyQualifiedEntityName( - fqn._build( # pylint: disable=protected-access - self.context.get().database_service, *fqn_elements - ) - ), - tags=[row[1]], - classification_name=row[0], - tag_description=SNOWFLAKE_TAG_DESCRIPTION, - classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION, - metadata=self.metadata, - system_tags=True, - ) + + entity_fqn = fqn._build(self.context.get().database_service, *fqn_elements) # pyright: ignore[reportAttributeAccessIssue] + try: + classification = self.tag_canonicalizer.classification(row[0], SNOWFLAKE_CLASSIFICATION_DESCRIPTION) + tag = self.tag_canonicalizer.tag(classification.name, row[1], SNOWFLAKE_TAG_DESCRIPTION) + + self.tags_registry.attach( + scope_fqn=schema_fqn, + entity_fqn=entity_fqn, + classification_name=classification.name, + tag_name=tag.name, + classification_description=classification.description, + tag_description=tag.description, + ) + except Exception as exc: + logger.debug(traceback.format_exc()) + yield Either( + left=StackTraceError( + name=f"{row[0]}.{row[1]}", + error=f"Tag canonicalization failed for {row[0]}.{row[1]}: {exc}", + stackTrace=traceback.format_exc(), + ), + right=None, + ) # Yield schema-level tags if schema_name in self.schema_tags_map: - schema_fqn = fqn.build( - self.metadata, - entity_type=DatabaseSchema, - service_name=self.context.get().database_service, - database_name=self.context.get().database, - schema_name=schema_name, - ) for tag_info in self.schema_tags_map[schema_name]: - yield from get_ometa_tag_and_classification( - tag_fqn=FullyQualifiedEntityName(schema_fqn), - tags=[tag_info["tag_value"]], - classification_name=tag_info["tag_name"], - tag_description=SNOWFLAKE_TAG_DESCRIPTION, - classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION, - metadata=self.metadata, - system_tags=True, - ) + try: + classification = self.tag_canonicalizer.classification( + tag_info["tag_name"], SNOWFLAKE_CLASSIFICATION_DESCRIPTION + ) + tag = self.tag_canonicalizer.tag( + classification.name, tag_info["tag_value"], SNOWFLAKE_TAG_DESCRIPTION + ) + + self.tags_registry.attach( + scope_fqn=schema_fqn, + entity_fqn=schema_fqn, + classification_name=classification.name, + tag_name=tag.name, + classification_description=classification.description, + tag_description=tag.description, + ) + except Exception as exc: + logger.debug(traceback.format_exc()) + yield Either( + left=StackTraceError( + name=f"{tag_info['tag_name']}.{tag_info['tag_value']}", + error=f"Tag canonicalization failed for {tag_info['tag_name']}.{tag_info['tag_value']}: {exc}", + stackTrace=traceback.format_exc(), + ), + right=None, + ) + yield from (Either(left=None, right=record) for record in self.tags_registry.drain()) - def yield_database_tag(self, database_entity: str) -> Iterable[Either[OMetaTagAndClassification]]: + def yield_database_tag(self, database_name: str) -> Iterable[Either[OMetaTagAndClassification]]: """Yield database-level tags for the topology.""" if not self.source_config.includeTags: return - if database_entity in self.database_tags_map: - database_fqn = fqn.build( + if database_name not in self.database_tags_map: + return + + database_fqn = cast( + "str", + fqn.build( self.metadata, entity_type=Database, - service_name=self.context.get().database_service, - database_name=database_entity, - ) - for tag_info in self.database_tags_map[database_entity]: - yield from get_ometa_tag_and_classification( - tag_fqn=FullyQualifiedEntityName(database_fqn), - tags=[tag_info["tag_value"]], - classification_name=tag_info["tag_name"], - tag_description=SNOWFLAKE_TAG_DESCRIPTION, - classification_description=SNOWFLAKE_CLASSIFICATION_DESCRIPTION, - metadata=self.metadata, - system_tags=True, + service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue] + database_name=database_name, + ), + ) + for tag_info in self.database_tags_map[database_name]: + try: + classification = self.tag_canonicalizer.classification( + tag_info["tag_name"], SNOWFLAKE_CLASSIFICATION_DESCRIPTION + ) + tag = self.tag_canonicalizer.tag(classification.name, tag_info["tag_value"], SNOWFLAKE_TAG_DESCRIPTION) + + self.tags_registry.attach( + scope_fqn=database_fqn, + entity_fqn=database_fqn, + classification_name=classification.name, + tag_name=tag.name, + classification_description=classification.description, + tag_description=tag.description, ) + except Exception as exc: + logger.debug(traceback.format_exc()) + yield Either( + left=StackTraceError( + name=f"{tag_info['tag_name']}.{tag_info['tag_value']}", + error=f"Tag canonicalization failed for {tag_info['tag_name']}.{tag_info['tag_value']}: {exc}", + stackTrace=traceback.format_exc(), + ), + right=None, + ) + yield from (Either(left=None, right=record) for record in self.tags_registry.drain()) def _get_table_names_and_types( self, schema_name: str, table_type: TableType = TableType.Regular @@ -1049,42 +1102,72 @@ def _has_classification(self, classification_name: str, tag_list: List[TagLabel] return True return False + def get_database_tag_labels(self, database_name: str) -> Optional[List[TagLabel]]: # noqa: UP006, UP045 + """Return tags for the database entity from registry.""" + database_fqn = cast( + "str", + fqn.build( + self.metadata, + entity_type=Database, + service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue] + database_name=database_name, + ), + ) + return self.tags_registry.labels_for(database_fqn) or None + + def get_column_tag_labels(self, table_name: str, column: dict) -> Optional[List[TagLabel]]: # noqa: UP006, UP045 + """Return tags for a column entity from the registry. + + Column tags don't inherit from parent entities (table/schema/database) + — those have separate semantic meaning at their own level. Direct + lookup is sufficient. + """ + col_fqn = cast( + "str", + fqn.build( + self.metadata, + entity_type=Column, + service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue] + database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue] + schema_name=self.context.get().database_schema, # pyright: ignore[reportAttributeAccessIssue] + table_name=table_name, + column_name=column["name"], + ), + ) + return self.tags_registry.labels_for(col_fqn) or None + def get_schema_tag_labels(self, schema_name: str) -> Optional[List[TagLabel]]: # noqa: UP006, UP045 """ Return tags for schema entity including: 1. Snowflake schema-level tags 2. Inherited database-level tags (only if no tag with same classification exists) """ - schema_tags = [] + schema_fqn = cast( + "str", + fqn.build( + self.metadata, + entity_type=DatabaseSchema, + service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue] + database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue] + schema_name=schema_name, + ), + ) + database_fqn = cast( + "str", + fqn.build( + self.metadata, + entity_type=Database, + service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue] + database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue] + ), + ) - if schema_name in self.schema_tags_map: - for tag_info in self.schema_tags_map[schema_name]: - tag_label = get_tag_label( - metadata=self.metadata, - tag_name=tag_info["tag_value"], - classification_name=tag_info["tag_name"], - ) - if tag_label: - schema_tags.append(tag_label) + schema_tags = self.tags_registry.labels_for(schema_fqn) # Add inherited database tags (only if classification doesn't already exist) - database_name = self.context.get().database - if database_name and database_name in self.database_tags_map: - for tag_info in self.database_tags_map[database_name]: - if not self._has_classification(tag_info["tag_name"], schema_tags): - tag_label = get_tag_label( - metadata=self.metadata, - tag_name=tag_info["tag_value"], - classification_name=tag_info["tag_name"], - ) - if tag_label: - schema_tags.append(tag_label) - - # Include parent tags from context - parent_tags = super().get_schema_tag_labels(schema_name) or [] - for tag in parent_tags: - if not self._has_classification(self._get_classification_name(tag), schema_tags): - schema_tags.append(tag) + for label in self.tags_registry.labels_for(database_fqn): + if not self._has_classification(self._get_classification_name(label), schema_tags): + schema_tags.append(label) return schema_tags if schema_tags else None @@ -1098,32 +1181,48 @@ def get_tag_labels(self, table_name: str) -> Optional[List[TagLabel]]: # noqa: Tag values at lower levels take precedence over inherited values. """ - table_tags = super().get_tag_labels(table_name) or [] + table_fqn = cast( + "str", + fqn.build( + self.metadata, + entity_type=Table, + service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue] + database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue] + schema_name=self.context.get().database_schema, # pyright: ignore[reportAttributeAccessIssue] + table_name=table_name, + skip_es_search=True, + ), + ) + schema_fqn = cast( + "str", + fqn.build( + self.metadata, + entity_type=DatabaseSchema, + service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue] + database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue] + schema_name=self.context.get().database_schema, # pyright: ignore[reportAttributeAccessIssue] + ), + ) + database_fqn = cast( + "str", + fqn.build( + self.metadata, + entity_type=Database, + service_name=self.context.get().database_service, # pyright: ignore[reportAttributeAccessIssue] + database_name=self.context.get().database, # pyright: ignore[reportAttributeAccessIssue] + ), + ) + + table_tags = self.tags_registry.labels_for(table_fqn) # Add inherited schema tags (only if classification doesn't already exist) - schema_name = self.context.get().database_schema - if schema_name and schema_name in self.schema_tags_map: - for tag_info in self.schema_tags_map[schema_name]: - if not self._has_classification(tag_info["tag_name"], table_tags): - tag_label = get_tag_label( - metadata=self.metadata, - tag_name=tag_info["tag_value"], - classification_name=tag_info["tag_name"], - ) - if tag_label: - table_tags.append(tag_label) + for label in self.tags_registry.labels_for(schema_fqn): + if not self._has_classification(self._get_classification_name(label), table_tags): + table_tags.append(label) # Add inherited database tags (only if classification doesn't already exist) - database_name = self.context.get().database - if database_name and database_name in self.database_tags_map: - for tag_info in self.database_tags_map[database_name]: - if not self._has_classification(tag_info["tag_name"], table_tags): - tag_label = get_tag_label( - metadata=self.metadata, - tag_name=tag_info["tag_value"], - classification_name=tag_info["tag_name"], - ) - if tag_label: - table_tags.append(tag_label) + for label in self.tags_registry.labels_for(database_fqn): + if not self._has_classification(self._get_classification_name(label), table_tags): + table_tags.append(label) return table_tags if table_tags else None diff --git a/ingestion/tests/unit/domain/__init__.py b/ingestion/tests/unit/domain/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/ingestion/tests/unit/domain/tags/__init__.py b/ingestion/tests/unit/domain/tags/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/ingestion/tests/unit/domain/tags/test_canonicalizer.py b/ingestion/tests/unit/domain/tags/test_canonicalizer.py new file mode 100644 index 000000000000..3c97f1cd0d1e --- /dev/null +++ b/ingestion/tests/unit/domain/tags/test_canonicalizer.py @@ -0,0 +1,161 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for ``metadata.domain.tags.TagCanonicalizer``.""" + +from unittest.mock import MagicMock + +import pytest + +from metadata.domain.tags import Canonical, TagCanonicalizer +from metadata.generated.schema.entity.classification.classification import Classification +from metadata.generated.schema.type.basic import ProviderType + + +@pytest.fixture(autouse=True) +def _no_retry_sleep(monkeypatch: pytest.MonkeyPatch) -> None: + """Skip tenacity's between-retry sleeps so retry-tests run instantly.""" + monkeypatch.setattr("time.sleep", lambda *_args, **_kwargs: None) + + +@pytest.fixture +def mock_metadata() -> MagicMock: + return MagicMock() + + +@pytest.fixture +def canonicalizer(mock_metadata: MagicMock) -> TagCanonicalizer: + return TagCanonicalizer(metadata=mock_metadata) + + +def _system_classification(name: str, description: str = "") -> MagicMock: + m = MagicMock() + m.provider = ProviderType.system + m.name.root = name + if description: + m.description.root = description + else: + m.description = None + return m + + +def _system_tag(classification: str, name: str, description: str = "") -> MagicMock: + m = MagicMock() + m.provider = ProviderType.system + m.classification.name = classification + m.name.root = name + if description: + m.description.root = description + else: + m.description = None + return m + + +class TestClassification: + def test_no_match_returns_source_unchanged(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock): + mock_metadata.es_search_from_fqn.return_value = [] + result = canonicalizer.classification("MyClass", "Source desc") + assert result == Canonical(name="MyClass", description="Source desc") + + def test_system_match_uses_canonical_case(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock): + mock_metadata.es_search_from_fqn.return_value = [_system_classification("PII", "Canonical desc")] + result = canonicalizer.classification("pii", "Source desc") + assert result == Canonical(name="PII", description="Canonical desc") + + def test_caches_per_case_insensitive_key(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock): + mock_metadata.es_search_from_fqn.return_value = [_system_classification("PII", "Canonical desc")] + canonicalizer.classification("pii") + canonicalizer.classification("PII") + canonicalizer.classification("Pii") + # Three case variants share the same case-insensitive cache key + assert mock_metadata.es_search_from_fqn.call_count == 1 + + def test_non_system_match_ignored(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock): + non_system = _system_classification("PII", "Canonical desc") + non_system.provider = ProviderType.user + mock_metadata.es_search_from_fqn.return_value = [non_system] + result = canonicalizer.classification("pii", "Source desc") + assert result == Canonical(name="pii", description="Source desc") + + def test_classification_es_called_with_correct_args( + self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock + ): + mock_metadata.es_search_from_fqn.return_value = [] + canonicalizer.classification("Foo") + mock_metadata.es_search_from_fqn.assert_called_once_with(entity_type=Classification, fqn_search_string="Foo") + + +class TestTag: + def test_no_match_returns_source_unchanged(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock): + mock_metadata.es_search_from_fqn.return_value = [] + result = canonicalizer.tag("PII", "MyTag", "Source desc") + assert result == Canonical(name="MyTag", description="Source desc") + + def test_system_match_uses_canonical_case(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock): + mock_metadata.es_search_from_fqn.return_value = [_system_tag("PII", "Sensitive", "Canonical desc")] + result = canonicalizer.tag("PII", "sensitive", "Source desc") + assert result == Canonical(name="Sensitive", description="Canonical desc") + + def test_caches_per_case_insensitive_key(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock): + mock_metadata.es_search_from_fqn.return_value = [_system_tag("PII", "Sensitive", "")] + canonicalizer.tag("PII", "sensitive") + canonicalizer.tag("PII", "SENSITIVE") + canonicalizer.tag("PII", "Sensitive") + # Three case variants share the same case-insensitive cache key + assert mock_metadata.es_search_from_fqn.call_count == 1 + + def test_match_requires_classification_match(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock): + # ES returns a tag but for a different classification — no canonicalization + wrong_class_tag = _system_tag("OtherClass", "Sensitive", "Canonical desc") + mock_metadata.es_search_from_fqn.return_value = [wrong_class_tag] + result = canonicalizer.tag("PII", "sensitive", "Source desc") + assert result == Canonical(name="sensitive", description="Source desc") + + def test_non_system_match_ignored(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock): + non_system = _system_tag("PII", "Sensitive", "Canonical desc") + non_system.provider = ProviderType.user + mock_metadata.es_search_from_fqn.return_value = [non_system] + result = canonicalizer.tag("PII", "sensitive", "Source desc") + assert result == Canonical(name="sensitive", description="Source desc") + + +class TestRetryAndFailure: + def test_transient_failure_recovers_within_retry_budget( + self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock + ): + # First two ES calls raise; third succeeds. + mock_metadata.es_search_from_fqn.side_effect = [ + RuntimeError("transient 1"), + RuntimeError("transient 2"), + [_system_classification("PII", "Canonical desc")], + ] + result = canonicalizer.classification("pii", "Source desc") + assert result == Canonical(name="PII", description="Canonical desc") + assert mock_metadata.es_search_from_fqn.call_count == 3 + + def test_persistent_failure_raises_after_retries_exhaust( + self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock + ): + mock_metadata.es_search_from_fqn.side_effect = RuntimeError("persistent") + with pytest.raises(RuntimeError, match="persistent"): + canonicalizer.classification("MyClass", "Source desc") + assert mock_metadata.es_search_from_fqn.call_count == 5 + + def test_persistent_failure_does_not_poison_cache(self, canonicalizer: TagCanonicalizer, mock_metadata: MagicMock): + # First call: ES persistently fails -> raises. + mock_metadata.es_search_from_fqn.side_effect = RuntimeError("persistent") + with pytest.raises(RuntimeError): + canonicalizer.classification("MyClass", "Source desc") + + # ES recovers; subsequent call must reach ES again, not return a cached fallback. + mock_metadata.es_search_from_fqn.side_effect = None + mock_metadata.es_search_from_fqn.return_value = [_system_classification("MyClass", "Canonical desc")] + result = canonicalizer.classification("MyClass", "Source desc") + assert result == Canonical(name="MyClass", description="Canonical desc") diff --git a/ingestion/tests/unit/domain/tags/test_registry.py b/ingestion/tests/unit/domain/tags/test_registry.py new file mode 100644 index 000000000000..7cddc99eab0a --- /dev/null +++ b/ingestion/tests/unit/domain/tags/test_registry.py @@ -0,0 +1,375 @@ +# Copyright 2025 Collate +# Licensed under the Collate Community License, Version 1.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# https://github.com/open-metadata/OpenMetadata/blob/main/ingestion/LICENSE +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for ``metadata.domain.tags.TagRegistry``. + +Covers attach/labels_for/drain/clear_scope/ensure_known semantics plus +basic thread-safety stress scenarios. The OM client is mocked; no +network or schema validation against a real backend. +""" + +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import MagicMock + +import pytest + +from metadata.domain.tags import ScopeAlreadyClearedError, TagRegistry +from metadata.generated.schema.type.tagLabel import LabelType, State + + +@pytest.fixture +def mock_metadata() -> MagicMock: + return MagicMock() + + +@pytest.fixture +def registry(mock_metadata: MagicMock) -> TagRegistry: + return TagRegistry(metadata=mock_metadata) + + +def _attach_kwargs( + scope: str, + entity: str, + classification: str = "TestClass", + tag: str = "TestTag", +) -> dict: + return { + "scope_fqn": scope, + "entity_fqn": entity, + "classification_name": classification, + "tag_name": tag, + "classification_description": "test classification", + "tag_description": "test tag", + } + + +class TestAttachAndLabelsFor: + def test_attach_then_labels_for_returns_one_label(self, registry: TagRegistry): + registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table")) + labels = registry.labels_for("svc.db.schema.table") + assert len(labels) == 1 + + def test_attach_multiple_tags_same_entity_returns_all(self, registry: TagRegistry): + registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table", tag="Tag1")) + registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table", tag="Tag2")) + labels = registry.labels_for("svc.db.schema.table") + assert len(labels) == 2 + + def test_labels_for_unattached_entity_returns_empty_list(self, registry: TagRegistry): + assert registry.labels_for("svc.db.schema.unknown") == [] + + def test_labels_for_is_idempotent(self, registry: TagRegistry): + registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table")) + first = registry.labels_for("svc.db.schema.table") + second = registry.labels_for("svc.db.schema.table") + # Read-and-leave: both reads return the same labels. + # Cleanup is the responsibility of clear_scope, not labels_for. + assert len(first) == 1 + assert second == first + + def test_labels_for_returns_copy_not_internal_list(self, registry: TagRegistry): + # Mutating the returned list must not affect registry state. + registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.table")) + first = registry.labels_for("svc.db.schema.table") + first.clear() + second = registry.labels_for("svc.db.schema.table") + assert len(second) == 1 + + +class TestDrain: + def test_drain_yields_pending_then_clears(self, registry: TagRegistry): + registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_a")) + first = list(registry.drain()) + second = list(registry.drain()) + assert len(first) == 1 + assert second == [] + + def test_drain_dedupes_same_tag_across_entities(self, registry: TagRegistry): + for i in range(100): + registry.attach(**_attach_kwargs("svc.db", f"svc.db.schema.tbl_{i}")) + pending = list(registry.drain()) + assert len(pending) == 1 + + def test_drain_yields_distinct_payloads_for_distinct_tags(self, registry: TagRegistry): + registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_1", tag="TagA")) + registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_2", tag="TagB")) + pending = list(registry.drain()) + assert len(pending) == 2 + + def test_drain_does_not_dedup_across_case_variants(self, registry: TagRegistry): + # OM stores tags case-sensitively; our dedup must follow that rule. + registry.attach(**_attach_kwargs("svc.db", "svc.db.t1", tag="Sensitive")) + registry.attach(**_attach_kwargs("svc.db", "svc.db.t2", tag="sensitive")) + pending = list(registry.drain()) + assert len(pending) == 2 # both must PUT — they're distinct tags server-side + + def test_drain_dedupes_same_fqn_across_label_types(self, registry: TagRegistry): + # Different cache keys (label_type varies) but identical tag_fqn → ONE PUT. + # Cache key is (class, tag, label_type, state); tag_fqn is class.tag. + registry.attach( + **_attach_kwargs("svc.db", "svc.db.t1"), + label_type=LabelType.Manual, + ) + registry.attach( + **_attach_kwargs("svc.db", "svc.db.t2"), + label_type=LabelType.Automated, + ) + pending = list(registry.drain()) + assert len(pending) == 1, "fqn-level dedup must collapse PUTs across label_type variants" + + +class TestClearScope: + def test_clear_scope_drops_descendant_labels(self, registry: TagRegistry): + registry.attach(**_attach_kwargs("svc.db.schema", "svc.db.schema.tbl_1")) + registry.attach(**_attach_kwargs("svc.db.schema", "svc.db.schema.tbl_2")) + registry.clear_scope("svc.db.schema") + assert registry.labels_for("svc.db.schema.tbl_1") == [] + assert registry.labels_for("svc.db.schema.tbl_2") == [] + + def test_clear_scope_drops_scope_itself(self, registry: TagRegistry): + registry.attach(**_attach_kwargs("svc.db.schema", "svc.db.schema")) + registry.clear_scope("svc.db.schema") + assert registry.labels_for("svc.db.schema") == [] + + def test_clear_scope_preserves_other_scopes(self, registry: TagRegistry): + registry.attach(**_attach_kwargs("svc.db.schema_a", "svc.db.schema_a.tbl")) + registry.attach(**_attach_kwargs("svc.db.schema_b", "svc.db.schema_b.tbl")) + registry.clear_scope("svc.db.schema_a") + assert registry.labels_for("svc.db.schema_a.tbl") == [] + assert len(registry.labels_for("svc.db.schema_b.tbl")) == 1 + + def test_clear_scope_no_false_prefix_match(self, registry: TagRegistry): + # 'schema_a' is NOT a prefix of 'schema_alpha' once the FQN + # separator is taken into account. + registry.attach(**_attach_kwargs("svc.db.schema_alpha", "svc.db.schema_alpha.tbl")) + registry.clear_scope("svc.db.schema_a") + assert len(registry.labels_for("svc.db.schema_alpha.tbl")) == 1 + + def test_clear_scope_idempotent_on_unattached_scope(self, registry: TagRegistry): + registry.clear_scope("svc.db.never_attached") # must not raise + + def test_attach_after_clear_raises(self, registry: TagRegistry): + registry.clear_scope("svc.db.schema") + with pytest.raises(ScopeAlreadyClearedError): + registry.attach(**_attach_kwargs("svc.db.schema", "svc.db.schema.tbl")) + + +class TestEnsureKnown: + def test_is_known_empty_returns_false(self, registry: TagRegistry): + assert registry.is_known("Class.Tag") is False + + def test_is_known_after_attach_returns_true(self, registry: TagRegistry): + registry.attach( + **_attach_kwargs( + "svc.db", + "svc.db.schema.tbl", + classification="Class", + tag="Tag", + ) + ) + assert registry.is_known("Class.Tag") is True + + def test_is_known_is_case_sensitive(self, registry: TagRegistry): + # Reflects OM's case-sensitive identity rule. + registry.attach( + **_attach_kwargs( + "svc.db", + "svc.db.schema.tbl", + classification="Class", + tag="Tag", + ) + ) + assert registry.is_known("Class.Tag") is True + assert registry.is_known("class.tag") is False # different tag server-side + + def test_ensure_known_cache_hit_skips_io(self, registry: TagRegistry, mock_metadata: MagicMock): + registry.attach( + **_attach_kwargs( + "svc.db", + "svc.db.schema.tbl", + classification="Class", + tag="Tag", + ) + ) + assert registry.ensure_known("Class.Tag") is True + mock_metadata.get_by_name.assert_not_called() + + def test_ensure_known_cache_miss_calls_get_by_name_once(self, registry: TagRegistry, mock_metadata: MagicMock): + mock_metadata.get_by_name.return_value = MagicMock() + assert registry.ensure_known("Other.Tag") is True + assert registry.ensure_known("Other.Tag") is True # cached now + assert mock_metadata.get_by_name.call_count == 1 + + def test_ensure_known_404_returns_false_and_does_not_cache(self, registry: TagRegistry, mock_metadata: MagicMock): + mock_metadata.get_by_name.return_value = None + assert registry.ensure_known("Missing.Tag") is False + assert registry.ensure_known("Missing.Tag") is False + # Re-queries on each miss; not cached. + assert mock_metadata.get_by_name.call_count == 2 + + def test_ensure_known_swallows_exception(self, registry: TagRegistry, mock_metadata: MagicMock): + mock_metadata.get_by_name.side_effect = RuntimeError("network down") + assert registry.ensure_known("Crashed.Tag") is False + + +class TestThreadSafety: + def test_concurrent_attach_same_tag_dedupes_pending(self, registry: TagRegistry): + def worker(thread_idx: int) -> None: + for i in range(100): + registry.attach( + **_attach_kwargs( + "svc.db", + f"svc.db.schema.tbl_{thread_idx}_{i}", + ) + ) + + with ThreadPoolExecutor(max_workers=8) as pool: + list(pool.map(worker, range(8))) + + pending = list(registry.drain()) + assert len(pending) == 1 + + def test_concurrent_disjoint_scopes_no_label_loss(self, registry: TagRegistry): + def worker(scope_idx: int) -> None: + scope = f"svc.db.schema_{scope_idx}" + for i in range(50): + registry.attach( + **_attach_kwargs( + scope, + f"{scope}.tbl_{i}", + tag=f"Tag_{scope_idx}_{i}", + ) + ) + + with ThreadPoolExecutor(max_workers=8) as pool: + list(pool.map(worker, range(8))) + + for scope_idx in range(8): + scope = f"svc.db.schema_{scope_idx}" + for i in range(50): + entity = f"{scope}.tbl_{i}" + labels = registry.labels_for(entity) + assert len(labels) == 1, f"missing label for {entity}" + + +class TestStats: + def test_initial_stats_all_zero(self, registry: TagRegistry): + assert registry.stats() == { + "known_tag_fqns": 0, + "tag_label_cache": 0, + "pending": 0, + "cleared_scopes": 0, + "live_entities": 0, + "live_labels": 0, + } + + def test_stats_reflect_attach(self, registry: TagRegistry): + registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_1")) + registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_2")) + s = registry.stats() + # Both attaches share the same tag — known + pending dedup to 1 + assert s["known_tag_fqns"] == 1 + assert s["pending"] == 1 + # Two entities, each with one label + assert s["live_entities"] == 2 + assert s["live_labels"] == 2 + + def test_labels_for_does_not_decrease_live_state(self, registry: TagRegistry): + # labels_for is idempotent (read-and-leave); clear_scope is the + # only mechanism that reduces live state. + registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl")) + registry.labels_for("svc.db.schema.tbl") + s = registry.stats() + assert s["live_entities"] == 1 + assert s["live_labels"] == 1 + assert s["known_tag_fqns"] == 1 + assert s["pending"] == 1 + + def test_drain_decreases_pending_only(self, registry: TagRegistry): + registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl")) + list(registry.drain()) + s = registry.stats() + assert s["pending"] == 0 + assert s["known_tag_fqns"] == 1 # still tracked for dedup + + def test_clear_scope_zeroes_live_state_for_scope(self, registry: TagRegistry): + # Critical invariant: after clear_scope, no live_entities for that scope. + for i in range(50): + registry.attach(**_attach_kwargs("svc.db.schema", f"svc.db.schema.tbl_{i}")) + assert registry.stats()["live_entities"] == 50 + + registry.clear_scope("svc.db.schema") + s = registry.stats() + assert s["live_entities"] == 0 + assert s["live_labels"] == 0 + assert s["cleared_scopes"] == 1 + + +class TestInterning: + """TagLabel interning — multiple attaches with the same key share one + underlying ``TagLabel`` instance. Memory bound depends on this; the + `is`-identity assertion is the load-bearing check.""" + + def test_attach_interns_identical_tag_labels(self, registry: TagRegistry): + # Same (classification, tag, label_type, state) across two entities + # must return the exact same TagLabel object — not just an equal one. + registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_1")) + registry.attach(**_attach_kwargs("svc.db", "svc.db.schema.tbl_2")) + + label_1 = registry.labels_for("svc.db.schema.tbl_1")[0] + label_2 = registry.labels_for("svc.db.schema.tbl_2")[0] + + assert label_1 is label_2, "expected shared TagLabel instance via interning" + + def test_attach_does_not_intern_across_label_types(self, registry: TagRegistry): + # Cache key includes label_type — non-default values must not collide. + registry.attach( + **_attach_kwargs("svc.db", "svc.db.schema.tbl_1"), + label_type=LabelType.Manual, + ) + registry.attach( + **_attach_kwargs("svc.db", "svc.db.schema.tbl_2"), + label_type=LabelType.Automated, + ) + + label_manual = registry.labels_for("svc.db.schema.tbl_1")[0] + label_auto = registry.labels_for("svc.db.schema.tbl_2")[0] + + assert label_manual is not label_auto + assert label_manual.labelType == LabelType.Manual + assert label_auto.labelType == LabelType.Automated + + def test_attach_does_not_intern_across_states(self, registry: TagRegistry): + registry.attach( + **_attach_kwargs("svc.db", "svc.db.schema.tbl_1"), + state=State.Suggested, + ) + registry.attach( + **_attach_kwargs("svc.db", "svc.db.schema.tbl_2"), + state=State.Confirmed, + ) + + label_suggested = registry.labels_for("svc.db.schema.tbl_1")[0] + label_confirmed = registry.labels_for("svc.db.schema.tbl_2")[0] + + assert label_suggested is not label_confirmed + + def test_intern_cache_survives_clear_scope(self, registry: TagRegistry): + # Cache lifetime is registry lifetime, NOT scope lifetime — next scope + # reuses the same TagLabel instance for the same (class, tag, ...) key. + registry.attach(**_attach_kwargs("svc.db.schema_1", "svc.db.schema_1.tbl")) + label_first = registry.labels_for("svc.db.schema_1.tbl")[0] + + registry.clear_scope("svc.db.schema_1") + + registry.attach(**_attach_kwargs("svc.db.schema_2", "svc.db.schema_2.tbl")) + label_second = registry.labels_for("svc.db.schema_2.tbl")[0] + + assert label_first is label_second, "intern cache should survive clear_scope" diff --git a/ingestion/tests/unit/topology/database/test_snowflake.py b/ingestion/tests/unit/topology/database/test_snowflake.py index 7dcf6c7a18cd..ddba04e38a19 100644 --- a/ingestion/tests/unit/topology/database/test_snowflake.py +++ b/ingestion/tests/unit/topology/database/test_snowflake.py @@ -19,7 +19,9 @@ import sqlalchemy.types as sqltypes -from metadata.generated.schema.entity.data.table import TableType +from metadata.generated.schema.entity.data.database import Database +from metadata.generated.schema.entity.data.databaseSchema import DatabaseSchema +from metadata.generated.schema.entity.data.table import Table, TableType from metadata.generated.schema.entity.services.ingestionPipelines.ingestionPipeline import ( PipelineStatus, ) @@ -27,14 +29,9 @@ OpenMetadataWorkflowConfig, ) from metadata.generated.schema.type.filterPattern import FilterPattern -from metadata.generated.schema.type.tagLabel import ( - LabelType, - State, - TagLabel, - TagSource, -) from metadata.ingestion.source.database.snowflake.metadata import MAP, SnowflakeSource from metadata.ingestion.source.database.snowflake.models import SnowflakeStoredProcedure +from metadata.utils import fqn SNOWFLAKE_CONFIGURATION = { "source": { @@ -491,18 +488,39 @@ def test_map_class_partial_custom_initialization(self): self.assertEqual(map_type.value_type, sqltypes.VARCHAR) # default self.assertFalse(map_type.not_null) # default - @patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_tag_labels") - @patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_schema_tag_labels") - @patch("metadata.ingestion.source.database.snowflake.metadata.get_tag_label") - def test_schema_tag_inheritance( - self, - mock_get_tag_label, - mock_parent_get_schema_tag_labels, - mock_parent_get_tag_labels, - ): - """Test schema tag inheritance""" + def _setup_tag_context(self, source, service_name="local_snowflake"): + """Populate the topology context for schema-stage tag tests and return the FQN trio.""" + source.context.get().__dict__["database_service"] = service_name + source.context.get().__dict__["database"] = "TEST_DATABASE" + source.context.get().__dict__["database_schema"] = "TEST_SCHEMA" + + database_fqn = fqn.build( + source.metadata, + entity_type=Database, + service_name=service_name, + database_name="TEST_DATABASE", + ) + schema_fqn = fqn.build( + source.metadata, + entity_type=DatabaseSchema, + service_name=service_name, + database_name="TEST_DATABASE", + schema_name="TEST_SCHEMA", + ) + table_fqn = fqn.build( + source.metadata, + entity_type=Table, + service_name=service_name, + database_name="TEST_DATABASE", + schema_name="TEST_SCHEMA", + table_name="TEST_TABLE", + skip_es_search=True, + ) + return database_fqn, schema_fqn, table_fqn + + def test_schema_tag_inheritance(self): + """Schema tags propagate to tables; classification dedup is preserved.""" for source in self.sources.values(): - # Verify tags are fetched and stored mock_schema_tags = [ Mock(SCHEMA_NAME="TEST_SCHEMA", TAG_NAME="SCHEMA_TAG", TAG_VALUE="VALUE"), ] @@ -519,48 +537,39 @@ def test_schema_tag_inheritance( {"tag_name": "SCHEMA_TAG", "tag_value": "VALUE"}, ) - # Verify schema tag labels - mock_get_tag_label.return_value = TagLabel( - tagFQN="SnowflakeTag.SCHEMA_TAG", - labelType=LabelType.Automated, - state=State.Suggested, - source=TagSource.Classification, + _, schema_fqn, table_fqn = self._setup_tag_context(source) + + source.tags_registry.attach( + scope_fqn=schema_fqn, + entity_fqn=schema_fqn, + classification_name="SCHEMA_CLASSIFICATION", + tag_name="SCHEMA_TAG", + classification_description="", + tag_description="", + ) + source.tags_registry.attach( + scope_fqn=schema_fqn, + entity_fqn=table_fqn, + classification_name="TABLE_CLASSIFICATION", + tag_name="TABLE_TAG", + classification_description="", + tag_description="", ) - mock_parent_get_schema_tag_labels.return_value = None schema_labels = source.get_schema_tag_labels(schema_name="TEST_SCHEMA") self.assertIsNotNone(schema_labels) self.assertEqual(len(schema_labels), 1) - - # Verify tag inheritance - source.context.get().__dict__["database_schema"] = "TEST_SCHEMA" - mock_parent_get_tag_labels.return_value = [ - TagLabel( - tagFQN="SnowflakeTag.TABLE_TAG", - labelType=LabelType.Automated, - state=State.Suggested, - source=TagSource.Classification, - ) - ] + self.assertEqual(schema_labels[0].tagFQN.root, "SCHEMA_CLASSIFICATION.SCHEMA_TAG") table_labels = source.get_tag_labels(table_name="TEST_TABLE") self.assertEqual(len(table_labels), 2) tag_fqns = [tag.tagFQN.root for tag in table_labels] - self.assertIn("SnowflakeTag.SCHEMA_TAG", tag_fqns) - self.assertIn("SnowflakeTag.TABLE_TAG", tag_fqns) - - @patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_tag_labels") - @patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_schema_tag_labels") - @patch("metadata.ingestion.source.database.snowflake.metadata.get_tag_label") - def test_database_tag_inheritance( - self, - mock_get_tag_label, - mock_parent_get_schema_tag_labels, - mock_parent_get_tag_labels, - ): - """Test database tag inheritance to schemas and tables""" + self.assertIn("SCHEMA_CLASSIFICATION.SCHEMA_TAG", tag_fqns) + self.assertIn("TABLE_CLASSIFICATION.TABLE_TAG", tag_fqns) + + def test_database_tag_inheritance(self): + """Database tags propagate to schemas and tables when classifications don't overlap.""" for source in self.sources.values(): - # Setup mock database tags mock_database_tags = [ Mock( DATABASE_NAME="TEST_DATABASE", @@ -574,7 +583,6 @@ def test_database_tag_inheritance( source.engine.connect.return_value.__enter__ = MagicMock(return_value=mock_conn) source.engine.connect.return_value.__exit__ = MagicMock(return_value=False) - # Test set_database_tags_map source.set_database_tags_map("TEST_DATABASE") self.assertEqual(len(source.database_tags_map["TEST_DATABASE"]), 1) self.assertEqual( @@ -582,23 +590,33 @@ def test_database_tag_inheritance( {"tag_name": "DATABASE_TAG", "tag_value": "DB_VALUE"}, ) - # Setup schema tags for combined testing - source.schema_tags_map = {"TEST_SCHEMA": [{"tag_name": "SCHEMA_TAG", "tag_value": "SCHEMA_VALUE"}]} + database_fqn, schema_fqn, table_fqn = self._setup_tag_context(source) - # Mock tag label creation - def mock_tag_label_side_effect(metadata, tag_name, classification_name): - return TagLabel( - tagFQN=f"{classification_name}.{tag_name}", - labelType=LabelType.Automated, - state=State.Suggested, - source=TagSource.Classification, - ) - - mock_get_tag_label.side_effect = mock_tag_label_side_effect - mock_parent_get_schema_tag_labels.return_value = None + source.tags_registry.attach( + scope_fqn=database_fqn, + entity_fqn=database_fqn, + classification_name="DATABASE_TAG", + tag_name="DB_VALUE", + classification_description="", + tag_description="", + ) + source.tags_registry.attach( + scope_fqn=schema_fqn, + entity_fqn=schema_fqn, + classification_name="SCHEMA_TAG", + tag_name="SCHEMA_VALUE", + classification_description="", + tag_description="", + ) + source.tags_registry.attach( + scope_fqn=schema_fqn, + entity_fqn=table_fqn, + classification_name="TABLE_TAG", + tag_name="TABLE_VALUE", + classification_description="", + tag_description="", + ) - # Test schema inherits database tags - source.context.get().__dict__["database"] = "TEST_DATABASE" schema_labels = source.get_schema_tag_labels(schema_name="TEST_SCHEMA") self.assertIsNotNone(schema_labels) self.assertEqual(len(schema_labels), 2) @@ -606,17 +624,6 @@ def mock_tag_label_side_effect(metadata, tag_name, classification_name): self.assertIn("SCHEMA_TAG.SCHEMA_VALUE", tag_fqns) self.assertIn("DATABASE_TAG.DB_VALUE", tag_fqns) - # Test table inherits both schema and database tags - source.context.get().__dict__["database_schema"] = "TEST_SCHEMA" - mock_parent_get_tag_labels.return_value = [ - TagLabel( - tagFQN="TABLE_TAG.TABLE_VALUE", - labelType=LabelType.Automated, - state=State.Suggested, - source=TagSource.Classification, - ) - ] - table_labels = source.get_tag_labels(table_name="TEST_TABLE") self.assertEqual(len(table_labels), 3) tag_fqns = [tag.tagFQN.root for tag in table_labels] @@ -624,59 +631,44 @@ def mock_tag_label_side_effect(metadata, tag_name, classification_name): self.assertIn("SCHEMA_TAG.SCHEMA_VALUE", tag_fqns) self.assertIn("DATABASE_TAG.DB_VALUE", tag_fqns) - @patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_tag_labels") - @patch("metadata.ingestion.source.database.database_service.DatabaseServiceSource.get_schema_tag_labels") - @patch("metadata.ingestion.source.database.snowflake.metadata.get_tag_label") - def test_tag_value_precedence( - self, - mock_get_tag_label, - mock_parent_get_schema_tag_labels, - mock_parent_get_tag_labels, - ): - """Test that tag values at lower levels take precedence over inherited values. + def test_tag_value_precedence(self): + """Lower-level tags override inherited values for the same classification. - When database, schema, and table all have the same tag name (classification) - but different values, the object's own value should take precedence. + Database: ENV=dev, Schema: ENV=staging, Table: ENV=production. + Schema lookup must return only ENV.staging; table lookup only ENV.production. """ for source in self.sources.values(): - # Setup: Database, schema, and table all have ENV tag with different values - # Database: ENV=dev - # Schema: ENV=staging - # Table: ENV=production - - source.database_tags_map = {"TEST_DATABASE": [{"tag_name": "ENV", "tag_value": "dev"}]} - - source.schema_tags_map = {"TEST_SCHEMA": [{"tag_name": "ENV", "tag_value": "staging"}]} - - def mock_tag_label_side_effect(metadata, tag_name, classification_name): - return TagLabel( - tagFQN=f"{classification_name}.{tag_name}", - labelType=LabelType.Automated, - state=State.Suggested, - source=TagSource.Classification, - ) - - mock_get_tag_label.side_effect = mock_tag_label_side_effect - mock_parent_get_schema_tag_labels.return_value = None - - source.context.get().__dict__["database"] = "TEST_DATABASE" - source.context.get().__dict__["database_schema"] = "TEST_SCHEMA" + database_fqn, schema_fqn, table_fqn = self._setup_tag_context(source) + + source.tags_registry.attach( + scope_fqn=database_fqn, + entity_fqn=database_fqn, + classification_name="ENV", + tag_name="dev", + classification_description="", + tag_description="", + ) + source.tags_registry.attach( + scope_fqn=schema_fqn, + entity_fqn=schema_fqn, + classification_name="ENV", + tag_name="staging", + classification_description="", + tag_description="", + ) + source.tags_registry.attach( + scope_fqn=schema_fqn, + entity_fqn=table_fqn, + classification_name="ENV", + tag_name="production", + classification_description=None, + tag_description=None, + ) - # Test schema level: schema's own value takes precedence over database schema_labels = source.get_schema_tag_labels(schema_name="TEST_SCHEMA") self.assertEqual(len(schema_labels), 1) self.assertEqual(schema_labels[0].tagFQN.root, "ENV.staging") - # Test table level: table's own value takes precedence over schema and database - mock_parent_get_tag_labels.return_value = [ - TagLabel( - tagFQN="ENV.production", - labelType=LabelType.Automated, - state=State.Suggested, - source=TagSource.Classification, - ) - ] - table_labels = source.get_tag_labels(table_name="TEST_TABLE") self.assertEqual(len(table_labels), 1) self.assertEqual(table_labels[0].tagFQN.root, "ENV.production")