-
Notifications
You must be signed in to change notification settings - Fork 4.9k
fix(opensearch): escape wildcard metacharacters in search_labels to prevent DoS (CWE-89) #3026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sebastiondev
wants to merge
2
commits into
HKUDS:main
Choose a base branch
from
sebastiondev:fix/cwe89-opensearch-impl-unsanitize-d709
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+247
−3
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,227 @@ | ||
| """ | ||
| PoC test: CWE-89 OpenSearch injection via unsanitized entity names in query construction. | ||
|
|
||
| The test validates that: | ||
| 1. Wildcard special characters (*, ?) in user input to search_labels are escaped | ||
| before being used in OpenSearch wildcard queries, preventing DoS via expensive | ||
| wildcard patterns. | ||
| 2. PPL escape handles control characters and additional metacharacters beyond | ||
| just backslash and single-quote. | ||
|
|
||
| Run with: pytest tests/test_cwe89_opensearch_injection.py -v | ||
| """ | ||
|
|
||
| import re | ||
| import pytest | ||
| from contextlib import asynccontextmanager | ||
| from unittest.mock import AsyncMock, patch | ||
| import numpy as np | ||
|
|
||
| pytest.importorskip( | ||
| "opensearchpy", | ||
| reason="opensearchpy is required for OpenSearch storage tests", | ||
| ) | ||
|
|
||
| from opensearchpy.exceptions import NotFoundError, OpenSearchException # type: ignore | ||
| from lightrag.kg.opensearch_impl import ( | ||
| OpenSearchGraphStorage, | ||
| ClientManager, | ||
| ) | ||
|
|
||
| pytestmark = pytest.mark.offline | ||
|
|
||
|
|
||
@asynccontextmanager
async def _mock_lock():
    """Provide a no-op async context manager used in place of a real lock.

    Entering yields ``None`` and exiting does nothing, so code written as
    ``async with <lock>:`` runs without any actual synchronization.
    """
    yield
|
|
||
|
|
||
def _mock_lock_factory():
    """Return a fresh no-op lock context manager on every call.

    Intended as a ``side_effect`` for a patched lock getter, so each caller
    receives its own single-use ``_mock_lock()`` context manager.
    """
    return _mock_lock()
|
|
||
|
|
||
@pytest.fixture(autouse=True)
def patch_data_init_lock():
    """Patch ``get_data_init_lock`` in the OpenSearch module for every test.

    The patch routes calls to ``_mock_lock_factory`` and stays active for the
    duration of the test, then is undone when the fixture finalizes.
    """
    with patch(
        "lightrag.kg.opensearch_impl.get_data_init_lock",
        side_effect=_mock_lock_factory,
    ):
        yield
|
|
||
|
|
||
class MockEmbeddingFunc:
    """Minimal embedding-function stand-in that returns random vectors.

    Exposes the ``embedding_dim``, ``max_token_size`` and ``model_name``
    attributes and is awaitable-callable with a list of texts.
    """

    def __init__(self, dim=128):
        """Store the embedding width and fixed mock metadata.

        Args:
            dim: Number of components in each generated vector.
        """
        self.embedding_dim = dim
        self.max_token_size = 512
        self.model_name = "mock-embed"

    async def __call__(self, texts, **kwargs):
        """Return one random ``float32`` vector per input text.

        Args:
            texts: Sized collection of inputs; only its length is used.
            **kwargs: Accepted and ignored.

        Returns:
            A ``numpy`` array of shape ``(len(texts), embedding_dim)``.
        """
        return np.random.rand(len(texts), self.embedding_dim).astype(np.float32)
|
|
||
|
|
||
@pytest.fixture
def global_config():
    """Return the configuration dict handed to the storage under test."""
    return {
        "embedding_batch_num": 10,
        "max_graph_nodes": 1000,
    }
|
|
||
|
|
||
@pytest.fixture
def embed_func():
    """Return a ``MockEmbeddingFunc`` with its default dimension."""
    return MockEmbeddingFunc()
|
|
||
|
|
||
def _make_client():
    """Build an ``AsyncMock`` that mimics an ``AsyncOpenSearch`` client.

    The mock is specced against ``AsyncOpenSearch``; its ``indices`` and
    ``transport`` attributes are async mocks, and ``indices.exists`` is
    preset to report that indices exist.
    """
    # Imported here so the module-level ``importorskip`` guard runs first.
    from opensearchpy import AsyncOpenSearch

    client = AsyncMock(spec=AsyncOpenSearch)
    client.indices = AsyncMock()
    client.indices.exists.return_value = True
    client.transport = AsyncMock()
    return client
|
|
||
|
|
||
@pytest.fixture
def graph_storage(global_config, embed_func):
    """Yield an ``OpenSearchGraphStorage`` wired to a mocked client.

    ``ClientManager.get_client`` is patched for the lifetime of the fixture
    so no real connection is attempted. The storage is marked as having its
    indices ready and PPL graph lookup available, so tests can call query
    methods directly.
    """
    with patch.object(ClientManager, "get_client") as mock_get:
        client = _make_client()
        mock_get.return_value = client

        storage = OpenSearchGraphStorage(
            namespace="test_graph",
            global_config=global_config,
            embedding_func=embed_func,
        )
        # Bypass initialization: inject the mock client and readiness flags.
        storage.client = client
        storage._indices_ready = True
        storage._ppl_graphlookup_available = True
        yield storage
|
|
||
|
|
||
class TestWildcardInjection:
    """Test that wildcard special chars are escaped in search_labels."""

    @staticmethod
    def _sent_wildcard_inner(client):
        """Return the user-derived part of the wildcard pattern sent to OpenSearch.

        Reads the ``body`` keyword argument of the most recent
        ``client.search`` call, finds the first ``wildcard`` clause on
        ``entity_id`` and strips the intentional leading/trailing ``*``.
        Fails with an assertion (rather than a ``TypeError``) when the call
        or the clause is missing.
        """
        assert client.search.called, "search should have been called"
        body = client.search.call_args.kwargs.get("body")
        assert body is not None, "search should have been called with a body"

        patterns = [
            clause["wildcard"]["entity_id"]["value"]
            for clause in body["query"]["bool"]["should"]
            if "wildcard" in clause
        ]
        assert patterns, "wildcard clause should exist"

        pattern = patterns[0]
        # The outer * wrapping is the intentional substring match.
        assert len(pattern) >= 2 and pattern[0] == "*" and pattern[-1] == "*", (
            f"wildcard value should be wrapped in '*', got: {pattern}"
        )
        return pattern[1:-1]

    @staticmethod
    def _unescaped_metachars(pattern):
        """Return the ``*``/``?`` characters in *pattern* that are still live.

        Scans left to right honoring backslash escapes, so an escaped
        backslash followed by a metacharacter (``\\\\*``) is correctly reported
        as unescaped, which a simple look-behind check would miss.
        """
        live = []
        escaped = False
        for char in pattern:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char in "*?":
                live.append(char)
        return live

    @pytest.mark.asyncio
    async def test_wildcard_chars_escaped_in_search_labels(self, graph_storage):
        """Wildcard metacharacters *, ? in user input must be escaped."""
        client = graph_storage.client
        client.search.return_value = {"hits": {"hits": []}}

        # Malicious query with wildcard chars that could cause expensive patterns
        await graph_storage.search_labels("test*?foo")

        # Expected inner pattern: test\*\?foo (user's * and ? escaped)
        inner = self._sent_wildcard_inner(client)
        assert "\\*" in inner, (
            f"User's '*' should be escaped as '\\*' in wildcard, got: {inner}"
        )
        assert "\\?" in inner, (
            f"User's '?' should be escaped as '\\?' in wildcard, got: {inner}"
        )
        assert self._unescaped_metachars(inner) == [], (
            f"Unescaped wildcard metacharacters remain in pattern: {inner}"
        )

    @pytest.mark.asyncio
    async def test_wildcard_heavy_pattern_not_exploitable(self, graph_storage):
        """A series of ? chars should be escaped, not passed raw to OpenSearch."""
        client = graph_storage.client
        client.search.return_value = {"hits": {"hits": []}}

        # Attack: many single-char wildcards make matching expensive
        await graph_storage.search_labels("?" * 50)

        inner = self._sent_wildcard_inner(client)
        live = self._unescaped_metachars(inner)
        assert not live, (
            f"Found {len(live)} unescaped wildcard chars in pattern: {inner}"
        )

    @pytest.mark.asyncio
    async def test_backslash_escaped_in_wildcard(self, graph_storage):
        """Backslashes in user input must be double-escaped for the wildcard query."""
        client = graph_storage.client
        client.search.return_value = {"hits": {"hits": []}}

        # User input is the five characters t e s t \ followed by *
        await graph_storage.search_labels("test\\*")

        # Backslash must be escaped first, then the *: the user's \ becomes
        # \\ and the user's * becomes \*, i.e. three backslashes before the *.
        inner = self._sent_wildcard_inner(client)
        assert "\\\\" in inner, (
            f"Backslash from user should be doubled in wildcard: {inner}"
        )
        # Checking both conditions matters: an unescaped backslash followed
        # by an escaped * also contains "\\\\" but leaves the * live.
        assert self._unescaped_metachars(inner) == [], (
            f"Backslash and * from user should be escaped in wildcard: {inner}"
        )
|
|
||
|
|
||
class TestPPLInjection:
    """Test that PPL string escape handles additional metacharacters."""

    def test_escape_ppl_basic_quote(self, graph_storage):
        """Single quotes should be escaped."""
        result = graph_storage._escape_ppl("it's a test")
        # After removing every escaped quote, no bare quote may remain.
        assert "'" not in result.replace("\\'", ""), (
            f"Unescaped quote found in: {result}"
        )

    def test_escape_ppl_backslash(self, graph_storage):
        """Backslashes should be escaped."""
        result = graph_storage._escape_ppl("test\\path")
        assert result == "test\\\\path"

    def test_escape_ppl_newline_and_control_chars(self, graph_storage):
        """Newlines and carriage returns must not appear raw in the literal."""
        result = graph_storage._escape_ppl("line1\nline2\rline3\t")
        # Control chars should either be stripped or escaped, no raw newlines
        assert "\n" not in result, f"Raw newline in PPL literal: {result!r}"
        assert "\r" not in result, f"Raw carriage return in PPL literal: {result!r}"

    def test_escape_ppl_pipe_in_quotes_safe(self, graph_storage):
        """Pipe character inside a quoted string literal poses no injection risk."""
        result = graph_storage._escape_ppl("entity | stats count()")
        assert isinstance(result, str)
|
|
||
|
|
||
if __name__ == "__main__":
    # Propagate pytest's exit code so running this file as a script fails
    # visibly when a test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When PPL graph lookup is enabled, a node label that legitimately contains `\n`, `\r`, or `\t` is first found by `get_node(start_label)`, but this replacement changes the value used in `where entity_id = '...'`. Since `upsert_node` indexes the original `node_id` unchanged, the PPL query searches for a different label and returns no connected edges instead of falling back to the BFS path. This affects any extracted or manually created entity names with internal tabs/newlines; escape these characters without changing their value, or force the fallback path for labels PPL cannot represent.

Useful? React with 👍 / 👎.