diff --git a/lightrag/kg/opensearch_impl.py b/lightrag/kg/opensearch_impl.py index 8a579f67d3..884eeae6d7 100644 --- a/lightrag/kg/opensearch_impl.py +++ b/lightrag/kg/opensearch_impl.py @@ -1988,8 +1988,25 @@ async def _bfs_subgraph_ppl( @staticmethod def _escape_ppl(value: str) -> str: - """Escape a string for safe inclusion in a PPL single-quoted literal.""" - return value.replace("\\", "\\\\").replace("'", "\\'") + """Escape a string for safe inclusion in a PPL single-quoted literal. + + Escapes backslashes, single quotes, and control characters that could + interfere with PPL query parsing. + """ + value = value.replace("\\", "\\\\").replace("'", "\\'") + # Strip control characters that could break the PPL string literal + value = value.replace("\n", " ").replace("\r", " ").replace("\t", " ") + return value + + @staticmethod + def _escape_wildcard(value: str) -> str: + """Escape OpenSearch wildcard special characters in user input. + + Escapes \\, *, and ? so they are treated as literal characters + rather than wildcard operators, preventing DoS via expensive patterns. + """ + # Escape backslash first, then wildcard metacharacters + return value.replace("\\", "\\\\").replace("*", "\\*").replace("?", "\\?") async def _bfs_subgraph( self, start_label: str, max_depth: int, max_nodes: int @@ -2248,7 +2265,7 @@ async def search_labels(self, query: str, limit: int = 50) -> list[str]: { "wildcard": { "entity_id": { - "value": f"*{query.lower()}*", + "value": f"*{self._escape_wildcard(query.lower())}*", "case_insensitive": True, "boost": 2, } diff --git a/tests/test_cwe89_opensearch_injection.py b/tests/test_cwe89_opensearch_injection.py new file mode 100644 index 0000000000..6257304268 --- /dev/null +++ b/tests/test_cwe89_opensearch_injection.py @@ -0,0 +1,227 @@ +""" +PoC test: CWE-89 OpenSearch injection via unsanitized entity names in query construction. + +The test validates that: +1. Wildcard special characters (*, ?) in user input to search_labels are escaped + before being used in OpenSearch wildcard queries, preventing DoS via expensive + wildcard patterns. +2. PPL escape handles control characters and additional metacharacters beyond + just backslash and single-quote. + +Run with: pytest tests/test_cwe89_opensearch_injection.py -v +""" + +import re +import pytest +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, patch +import numpy as np + +pytest.importorskip( + "opensearchpy", + reason="opensearchpy is required for OpenSearch storage tests", +) + +from opensearchpy.exceptions import NotFoundError, OpenSearchException # type: ignore +from lightrag.kg.opensearch_impl import ( + OpenSearchGraphStorage, + ClientManager, +) + +pytestmark = pytest.mark.offline + + +@asynccontextmanager +async def _mock_lock(): + yield + + +def _mock_lock_factory(): + return _mock_lock() + + +@pytest.fixture(autouse=True) +def patch_data_init_lock(): + with patch( + "lightrag.kg.opensearch_impl.get_data_init_lock", side_effect=_mock_lock_factory + ): + yield + + +class MockEmbeddingFunc: + def __init__(self, dim=128): + self.embedding_dim = dim + self.max_token_size = 512 + self.model_name = "mock-embed" + + async def __call__(self, texts, **kwargs): + return np.random.rand(len(texts), self.embedding_dim).astype(np.float32) + + +@pytest.fixture +def global_config(): + return { + "embedding_batch_num": 10, + "max_graph_nodes": 1000, + } + + +@pytest.fixture +def embed_func(): + return MockEmbeddingFunc() + + +def _make_client(): + from opensearchpy import AsyncOpenSearch + client = AsyncMock(spec=AsyncOpenSearch) + client.indices = AsyncMock() + client.indices.exists.return_value = True + client.transport = AsyncMock() + return client + + +@pytest.fixture +def graph_storage(global_config, embed_func): + with patch.object(ClientManager, "get_client") as mock_get: + client = _make_client() + mock_get.return_value = client + + storage = OpenSearchGraphStorage( + namespace="test_graph", + global_config=global_config, + embedding_func=embed_func, + ) + storage.client = client + storage._indices_ready = True + storage._ppl_graphlookup_available = True + yield storage + + +class TestWildcardInjection: + """Test that wildcard special chars are escaped in search_labels.""" + + @pytest.mark.asyncio + async def test_wildcard_chars_escaped_in_search_labels(self, graph_storage): + """Wildcard metacharacters *, ? in user input must be escaped.""" + client = graph_storage.client + + # Setup mock to return empty results + client.search.return_value = {"hits": {"hits": []}} + + # Malicious query with wildcard chars that could cause expensive patterns + malicious_query = "test*?foo" + await graph_storage.search_labels(malicious_query) + + # Inspect the query body that was sent to OpenSearch + assert client.search.called, "search should have been called" + call_kwargs = client.search.call_args + body = call_kwargs.kwargs.get("body") or call_kwargs[1].get("body") + + # Extract the wildcard clause + should_clauses = body["query"]["bool"]["should"] + wildcard_clause = None + for clause in should_clauses: + if "wildcard" in clause: + wildcard_clause = clause["wildcard"]["entity_id"]["value"] + break + + assert wildcard_clause is not None, "wildcard clause should exist" + + # The wildcard value should NOT contain unescaped * or ? from the user input + # The outer * wrapping is fine (those are the intentional wildcards), + # but the inner user-provided * and ? must be escaped + # Expected: *test\*\?foo* (with the user's * and ? escaped) + inner_value = wildcard_clause[1:-1] # strip leading and trailing * + assert "\\*" in inner_value, ( + f"User's '*' should be escaped as '\\*' in wildcard, got: {wildcard_clause}" + ) + assert "\\?" in inner_value, ( + f"User's '?' should be escaped as '\\?' in wildcard, got: {wildcard_clause}" + ) + + @pytest.mark.asyncio + async def test_wildcard_heavy_pattern_not_exploitable(self, graph_storage): + """A series of ? chars should be escaped, not passed raw to OpenSearch.""" + client = graph_storage.client + client.search.return_value = {"hits": {"hits": []}} + + # Attack: many single-char wildcards cause exponential matching + attack_query = "?" * 50 + await graph_storage.search_labels(attack_query) + + call_kwargs = client.search.call_args + body = call_kwargs.kwargs.get("body") or call_kwargs[1].get("body") + + should_clauses = body["query"]["bool"]["should"] + wildcard_clause = None + for clause in should_clauses: + if "wildcard" in clause: + wildcard_clause = clause["wildcard"]["entity_id"]["value"] + break + + # None of the user's ? should appear as unescaped wildcards + # The value between the outer * delimiters should have all ? escaped + inner = wildcard_clause[1:-1] + # Count unescaped ? (i.e., ? not preceded by \) + unescaped_q = re.findall(r'(?