diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index d957146ec4..4bd8a0b112 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import final import configparser +import re from ..utils import logger from ..base import BaseGraphStorage @@ -59,18 +60,20 @@ def __init__(self, namespace, global_config, embedding_func, workspace=None): self._driver = None def _get_workspace_label(self) -> str: - """Return sanitized workspace label safe for use as a backtick-quoted identifier in Cypher queries. + """Return a workspace label safe for use in Cypher queries. - Escapes backticks by doubling them to prevent Cypher injection - via the LIGHTRAG-WORKSPACE header, while preserving a 1-to-1 mapping - for all other characters. The returned value is intended to be used - inside backticks (for example, MATCH (n:`{label}`)) and is not - validated as a standalone unquoted identifier. + Cypher labels cannot be parameterized, so constrain workspace labels to + alphanumeric characters and underscores before interpolating them into + query strings. This prevents injection through MEMGRAPH_WORKSPACE or + direct library callers that bypass API-level workspace sanitization. 
""" workspace = self.workspace.strip() if not workspace: return "base" - return workspace.replace("`", "``") + sanitized = re.sub(r"[^A-Za-z0-9_]", "_", workspace) + if not sanitized: + return "base" + return sanitized async def initialize(self): async with get_data_init_lock(): @@ -1093,7 +1096,7 @@ async def get_knowledge_graph( WHERE start.entity_id = $entity_id OPTIONAL MATCH path = (start)-[*BFS 0..{max_depth}]-(end:`{workspace_label}`) - WHERE path IS NULL OR ALL(n IN nodes(path) WHERE '{workspace_label}' IN labels(n)) + WHERE path IS NULL OR ALL(n IN nodes(path) WHERE $workspace_label IN labels(n)) WITH start, collect(DISTINCT end) AS discovered_nodes WITH start, [node IN discovered_nodes WHERE node IS NOT NULL AND node <> start] AS other_nodes WITH @@ -1122,6 +1125,7 @@ async def get_knowledge_graph( "entity_id": node_label, "max_nodes": max_nodes, "max_other_nodes": max(max_nodes - 1, 0), + "workspace_label": workspace_label, }, ) record = await result_set.single() diff --git a/tests/test_workspace_sanitization.py b/tests/test_workspace_sanitization.py index 7959f1d206..2e7af870ba 100644 --- a/tests/test_workspace_sanitization.py +++ b/tests/test_workspace_sanitization.py @@ -1,64 +1,67 @@ """ -Unit tests for workspace label sanitization in Memgraph and Neo4j implementations. +Unit tests for workspace label sanitization in graph implementations. -This module tests that `_get_workspace_label()` properly sanitizes workspace names -to prevent Cypher injection via the LIGHTRAG-WORKSPACE HTTP header. +Memgraph labels are used in Cypher query strings where labels cannot be +parameterized, so the Memgraph helper must reduce workspace names to an +alphanumeric/underscore allowlist. -It verifies that we preserve non-alphanumeric characters for 1-to-1 workspace mapping -while successfully neutralizing Cypher injection by escaping backticks. 
# Implementation source text keyed by file name, so each file is read from
# disk at most once per test session instead of once per assertion.
_SOURCE_TEXT_CACHE = {}


def _source_contains(file_name: str, pattern: str) -> bool:
    """Report whether ``pattern`` matches the text of a ``lightrag/kg`` module.

    Args:
        file_name: File name inside ``lightrag/kg`` (for example
            ``"memgraph_impl.py"``), resolved relative to the directory two
            levels above this test file.
        pattern: Regular expression searched for with ``re.search``.

    Returns:
        ``True`` if the pattern occurs anywhere in the file, else ``False``.
    """
    if file_name not in _SOURCE_TEXT_CACHE:
        base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        file_path = os.path.join(base_path, "lightrag/kg", file_name)
        with open(file_path, encoding="utf-8") as f:
            _SOURCE_TEXT_CACHE[file_name] = f.read()
    return re.search(pattern, _SOURCE_TEXT_CACHE[file_name]) is not None


def sanitize_memgraph(workspace: str) -> str:
    """Mirror MemgraphStorage._get_workspace_label() for dependency-free tests.

    Raises:
        RuntimeError: If the allowlist ``re.sub`` call is no longer present in
            ``memgraph_impl.py``, meaning this mirror may be out of date.
    """
    if not _source_contains(
        "memgraph_impl.py",
        r"re\.sub\(r[\"']\[\^A-Za-z0-9_\][\"'], [\"']_[\"'], workspace\)",
    ):
        raise RuntimeError("Could not find Memgraph allowlist sanitization logic")

    safe = workspace.strip()
    if not safe:
        return "base"
    # Each disallowed character is replaced one-for-one, so a non-empty input
    # never becomes empty; a second "base" fallback would be unreachable.
    return re.sub(r"[^A-Za-z0-9_]", "_", safe)


def sanitize_neo4j(workspace: str) -> str:
    """Mirror Neo4jStorage._get_workspace_label() for dependency-free tests.

    Raises:
        RuntimeError: If the backtick-doubling ``replace`` call is no longer
            present in ``neo4j_impl.py``.
    """
    if not _source_contains(
        "neo4j_impl.py", r"return workspace\.replace\(\"`\", \"``\"\)"
    ):
        raise RuntimeError("Could not find Neo4j backtick escaping logic")

    safe = workspace.strip()
    if not safe:
        return "base"
    # Doubling backticks keeps the name inside a backtick-quoted identifier.
    return safe.replace("`", "``")
"""Dots should be preserved.""" - self.assert_logic("my.workspace", "my.workspace") + def test_dots_replaced(self): + """Dots should be replaced.""" + self.assert_logic("my.workspace", "my_workspace") - def test_mixed_special_chars_preserved(self): - """Multiple different special characters should be preserved.""" - self.assert_logic("a-b.c d@e!f", "a-b.c d@e!f") + def test_mixed_special_chars_replaced(self): + """Multiple different special characters should be replaced.""" + self.assert_logic("a-b.c d@e!f", "a_b_c_d_e_f") # --- Cypher injection payloads --- - def test_cypher_injection_backtick_escaped(self): - """Backtick injection attempt should be neutralized by doubling backticks.""" + def test_cypher_injection_backtick_replaced(self): + """Backtick injection attempt should be neutralized by replacement.""" malicious = "test`}) MATCH (n) DETACH DELETE n //" - # The single backtick should become a double backtick - expected = "test``}) MATCH (n) DETACH DELETE n //" + expected = "test____MATCH__n__DETACH_DELETE_n___" self.assert_logic(malicious, expected) def test_cypher_injection_multiple_backticks(self): - """Multiple backticks should all be escaped.""" + """Multiple backticks should all be replaced.""" malicious = "`DROP`DATABASE`" - expected = "``DROP``DATABASE``" + expected = "_DROP_DATABASE_" self.assert_logic(malicious, expected) - def test_cypher_injection_curly_braces_preserved(self): - """Curly brace injection is harmless when enclosed in backticks, so preserved.""" + def test_cypher_injection_curly_braces_replaced(self): + """Curly brace injection should be replaced for identifier safety.""" malicious = "test}) RETURN 1 //" - self.assert_logic(malicious, malicious) + self.assert_logic(malicious, "test___RETURN_1___") - def test_cypher_injection_semicolon_preserved(self): - """Semicolon injection is harmless when enclosed in backticks, so preserved.""" + def test_cypher_injection_semicolon_replaced(self): + """Semicolon injection should be replaced for 
identifier safety.""" malicious = "test; DROP DATABASE neo4j" - self.assert_logic(malicious, malicious) + self.assert_logic(malicious, "test__DROP_DATABASE_neo4j") - def test_cypher_injection_quotes_preserved(self): - """Quote injection is harmless when enclosed in backticks, so preserved.""" + def test_cypher_injection_quotes_replaced(self): + """Quote injection should be replaced for identifier safety.""" malicious = 'test" OR 1=1 //' - self.assert_logic(malicious, malicious) + self.assert_logic(malicious, "test__OR_1_1___") # --- Empty / whitespace fallback --- @@ -136,9 +138,9 @@ def test_whitespace_only_fallback(self): """Whitespace-only workspace should fall back to 'base'.""" self.assert_logic(" ", "base") - def test_special_chars_only_preserved(self): - """Workspace with only special characters should be preserved.""" - self.assert_logic("---", "---") + def test_special_chars_only_replaced(self): + """Workspace with only special characters should be replaced with underscores.""" + self.assert_logic("---", "___") # --- Edge cases --- @@ -146,22 +148,22 @@ def test_leading_trailing_whitespace_stripped(self): """Leading/trailing whitespace should be stripped before sanitization.""" self.assert_logic(" myworkspace ", "myworkspace") - def test_unicode_characters_preserved(self): - """Non-ASCII/Chinese characters should be preserved.""" - self.assert_logic("工作区_test", "工作区_test") + def test_unicode_characters_replaced(self): + """Non-ASCII/Chinese characters should be replaced.""" + self.assert_logic("工作区_test", "____test") def test_very_long_workspace(self): """Very long workspace names should still be sanitized correctly.""" long_name = "a" * 1000 + "`" - expected = "a" * 1000 + "``" + expected = "a" * 1000 + "_" self.assert_logic(long_name, expected) def test_single_underscore(self): """Single underscore should be valid.""" self.assert_logic("_", "_") - def test_result_always_escapes_backticks(self): - """Parametric check: any output must not contain unescaped 
class TestNeo4jWorkspaceLabelSanitization:
    """Checks for the mirrored Neo4jStorage._get_workspace_label() logic."""

    def test_neo4j_preserves_non_backtick_characters(self):
        """Names without backticks pass through unchanged."""
        for name in ("my workspace", "my-workspace", "工作区_test"):
            assert sanitize_neo4j(name) == name

    def test_neo4j_escapes_backticks_and_falls_back(self):
        """Backticks are doubled and blank input falls back to 'base'."""
        payload = "test`}) MATCH (n) DETACH DELETE n //"
        escaped = "test``}) MATCH (n) DETACH DELETE n //"
        assert sanitize_neo4j(payload) == escaped
        assert sanitize_neo4j(" ") == "base"