Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions lightrag/kg/opensearch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1988,8 +1988,25 @@ async def _bfs_subgraph_ppl(

@staticmethod
def _escape_ppl(value: str) -> str:
"""Escape a string for safe inclusion in a PPL single-quoted literal."""
return value.replace("\\", "\\\\").replace("'", "\\'")
"""Escape a string for safe inclusion in a PPL single-quoted literal.

Escapes backslashes, single quotes, and control characters that could
interfere with PPL query parsing.
"""
value = value.replace("\\", "\\\\").replace("'", "\\'")
# Strip control characters that could break the PPL string literal
value = value.replace("\n", " ").replace("\r", " ").replace("\t", " ")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve control characters when escaping PPL labels

When PPL graph lookup is enabled, a node label that legitimately contains \n, \r, or \t is first found by get_node(start_label), but this replacement changes the value used in where entity_id = '...'. Since upsert_node indexes the original node_id unchanged, the PPL query searches for a different label and returns no connected edges instead of falling back to the BFS path. This affects any extracted or manually created entity names with internal tabs/newlines; escape these characters without changing their value, or force the fallback path for labels PPL cannot represent.

Useful? React with 👍 / 👎.

return value

@staticmethod
def _escape_wildcard(value: str) -> str:
"""Escape OpenSearch wildcard special characters in user input.

Escapes \\, *, and ? so they are treated as literal characters
rather than wildcard operators, preventing DoS via expensive patterns.
"""
# Escape backslash first, then wildcard metacharacters
return value.replace("\\", "\\\\").replace("*", "\\*").replace("?", "\\?")

async def _bfs_subgraph(
self, start_label: str, max_depth: int, max_nodes: int
Expand Down Expand Up @@ -2248,7 +2265,7 @@ async def search_labels(self, query: str, limit: int = 50) -> list[str]:
{
"wildcard": {
"entity_id": {
"value": f"*{query.lower()}*",
"value": f"*{self._escape_wildcard(query.lower())}*",
"case_insensitive": True,
"boost": 2,
}
Expand Down
227 changes: 227 additions & 0 deletions tests/test_cwe89_opensearch_injection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
"""
PoC test: CWE-89 OpenSearch injection via unsanitized entity names in query construction.

The test validates that:
1. Wildcard special characters (*, ?) in user input to search_labels are escaped
before being used in OpenSearch wildcard queries, preventing DoS via expensive
wildcard patterns.
2. PPL escape handles control characters and additional metacharacters beyond
just backslash and single-quote.

Run with: pytest tests/test_cwe89_opensearch_injection.py -v
"""

import re
import pytest
from contextlib import asynccontextmanager
from unittest.mock import AsyncMock, patch
import numpy as np

pytest.importorskip(
"opensearchpy",
reason="opensearchpy is required for OpenSearch storage tests",
)

from opensearchpy.exceptions import NotFoundError, OpenSearchException # type: ignore
from lightrag.kg.opensearch_impl import (
OpenSearchGraphStorage,
ClientManager,
)

pytestmark = pytest.mark.offline


@asynccontextmanager
async def _mock_lock():
yield


def _mock_lock_factory():
return _mock_lock()


@pytest.fixture(autouse=True)
def patch_data_init_lock():
with patch(
"lightrag.kg.opensearch_impl.get_data_init_lock", side_effect=_mock_lock_factory
):
yield


class MockEmbeddingFunc:
def __init__(self, dim=128):
self.embedding_dim = dim
self.max_token_size = 512
self.model_name = "mock-embed"

async def __call__(self, texts, **kwargs):
return np.random.rand(len(texts), self.embedding_dim).astype(np.float32)


@pytest.fixture
def global_config():
return {
"embedding_batch_num": 10,
"max_graph_nodes": 1000,
}


@pytest.fixture
def embed_func():
return MockEmbeddingFunc()


def _make_client():
from opensearchpy import AsyncOpenSearch
client = AsyncMock(spec=AsyncOpenSearch)
client.indices = AsyncMock()
client.indices.exists.return_value = True
client.transport = AsyncMock()
return client


@pytest.fixture
def graph_storage(global_config, embed_func):
with patch.object(ClientManager, "get_client") as mock_get:
client = _make_client()
mock_get.return_value = client

storage = OpenSearchGraphStorage(
namespace="test_graph",
global_config=global_config,
embedding_func=embed_func,
)
storage.client = client
storage._indices_ready = True
storage._ppl_graphlookup_available = True
yield storage


class TestWildcardInjection:
"""Test that wildcard special chars are escaped in search_labels."""

@pytest.mark.asyncio
async def test_wildcard_chars_escaped_in_search_labels(self, graph_storage):
"""Wildcard metacharacters *, ? in user input must be escaped."""
client = graph_storage.client

# Setup mock to return empty results
client.search.return_value = {"hits": {"hits": []}}

# Malicious query with wildcard chars that could cause expensive patterns
malicious_query = "test*?foo"
await graph_storage.search_labels(malicious_query)

# Inspect the query body that was sent to OpenSearch
assert client.search.called, "search should have been called"
call_kwargs = client.search.call_args
body = call_kwargs.kwargs.get("body") or call_kwargs[1].get("body")

# Extract the wildcard clause
should_clauses = body["query"]["bool"]["should"]
wildcard_clause = None
for clause in should_clauses:
if "wildcard" in clause:
wildcard_clause = clause["wildcard"]["entity_id"]["value"]
break

assert wildcard_clause is not None, "wildcard clause should exist"

# The wildcard value should NOT contain unescaped * or ? from the user input
# The outer * wrapping is fine (those are the intentional wildcards),
# but the inner user-provided * and ? must be escaped
# Expected: *test\*\?foo* (with the user's * and ? escaped)
inner_value = wildcard_clause[1:-1] # strip leading and trailing *
assert "\\*" in inner_value, (
f"User's '*' should be escaped as '\\*' in wildcard, got: {wildcard_clause}"
)
assert "\\?" in inner_value, (
f"User's '?' should be escaped as '\\?' in wildcard, got: {wildcard_clause}"
)

@pytest.mark.asyncio
async def test_wildcard_heavy_pattern_not_exploitable(self, graph_storage):
"""A series of ? chars should be escaped, not passed raw to OpenSearch."""
client = graph_storage.client
client.search.return_value = {"hits": {"hits": []}}

# Attack: many single-char wildcards cause exponential matching
attack_query = "?" * 50
await graph_storage.search_labels(attack_query)

call_kwargs = client.search.call_args
body = call_kwargs.kwargs.get("body") or call_kwargs[1].get("body")

should_clauses = body["query"]["bool"]["should"]
wildcard_clause = None
for clause in should_clauses:
if "wildcard" in clause:
wildcard_clause = clause["wildcard"]["entity_id"]["value"]
break

# None of the user's ? should appear as unescaped wildcards
# The value between the outer * delimiters should have all ? escaped
inner = wildcard_clause[1:-1]
# Count unescaped ? (i.e., ? not preceded by \)
unescaped_q = re.findall(r'(?<!\\)\?', inner)
assert len(unescaped_q) == 0, (
f"Found {len(unescaped_q)} unescaped '?' in wildcard pattern: {wildcard_clause}"
)

@pytest.mark.asyncio
async def test_backslash_escaped_in_wildcard(self, graph_storage):
"""Backslashes in user input must be double-escaped for the wildcard query."""
client = graph_storage.client
client.search.return_value = {"hits": {"hits": []}}

attack_query = "test\\*"
await graph_storage.search_labels(attack_query)

call_kwargs = client.search.call_args
body = call_kwargs.kwargs.get("body") or call_kwargs[1].get("body")

should_clauses = body["query"]["bool"]["should"]
wildcard_clause = None
for clause in should_clauses:
if "wildcard" in clause:
wildcard_clause = clause["wildcard"]["entity_id"]["value"]
break

# The backslash should be escaped first, then the * — so we get \\\\\\*
# In the final pattern between outer *...*, user's \ becomes \\ and * becomes \*
inner = wildcard_clause[1:-1]
assert "\\\\" in inner or "\\*" in inner, (
f"Backslash and * from user should be escaped in wildcard: {wildcard_clause}"
)


class TestPPLInjection:
"""Test that PPL string escape handles additional metacharacters."""

def test_escape_ppl_basic_quote(self, graph_storage):
"""Single quotes should be escaped."""
result = graph_storage._escape_ppl("it's a test")
assert "'" not in result.replace("\\'", ""), (
f"Unescaped quote found in: {result}"
)

def test_escape_ppl_backslash(self, graph_storage):
"""Backslashes should be escaped."""
result = graph_storage._escape_ppl("test\\path")
assert result == "test\\\\path"

def test_escape_ppl_newline_and_control_chars(self, graph_storage):
"""Newlines and control characters should be escaped/stripped."""
result = graph_storage._escape_ppl("line1\nline2\rline3\t")
# Control chars should either be stripped or escaped — no raw newlines
assert "\n" not in result, f"Raw newline in PPL literal: {repr(result)}"
assert "\r" not in result, f"Raw carriage return in PPL literal: {repr(result)}"

def test_escape_ppl_pipe_in_quotes_safe(self, graph_storage):
"""Pipe character inside a quoted string literal poses no injection risk."""
result = graph_storage._escape_ppl("entity | stats count()")
assert isinstance(result, str)


if __name__ == "__main__":
pytest.main([__file__, "-v"])
Loading