Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions lightrag/kg/memgraph_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass
from typing import final
import configparser
import re

from ..utils import logger
from ..base import BaseGraphStorage
Expand Down Expand Up @@ -59,18 +60,20 @@ def __init__(self, namespace, global_config, embedding_func, workspace=None):
self._driver = None

def _get_workspace_label(self) -> str:
"""Return sanitized workspace label safe for use as a backtick-quoted identifier in Cypher queries.
"""Return a workspace label safe for use in Cypher queries.

Escapes backticks by doubling them to prevent Cypher injection
via the LIGHTRAG-WORKSPACE header, while preserving a 1-to-1 mapping
for all other characters. The returned value is intended to be used
inside backticks (for example, MATCH (n:`{label}`)) and is not
validated as a standalone unquoted identifier.
Cypher labels cannot be parameterized, so constrain workspace labels to
alphanumeric characters and underscores before interpolating them into
query strings. This prevents injection through MEMGRAPH_WORKSPACE or
direct library callers that bypass API-level workspace sanitization.
"""
workspace = self.workspace.strip()
if not workspace:
return "base"
return workspace.replace("`", "``")
sanitized = re.sub(r"[^A-Za-z0-9_]", "_", workspace)
if not sanitized:
return "base"
return sanitized

async def initialize(self):
async with get_data_init_lock():
Expand Down Expand Up @@ -1093,7 +1096,7 @@ async def get_knowledge_graph(
WHERE start.entity_id = $entity_id

OPTIONAL MATCH path = (start)-[*BFS 0..{max_depth}]-(end:`{workspace_label}`)
WHERE path IS NULL OR ALL(n IN nodes(path) WHERE '{workspace_label}' IN labels(n))
WHERE path IS NULL OR ALL(n IN nodes(path) WHERE $workspace_label IN labels(n))
WITH start, collect(DISTINCT end) AS discovered_nodes
WITH start, [node IN discovered_nodes WHERE node IS NOT NULL AND node <> start] AS other_nodes
WITH
Expand Down Expand Up @@ -1122,6 +1125,7 @@ async def get_knowledge_graph(
"entity_id": node_label,
"max_nodes": max_nodes,
"max_other_nodes": max(max_nodes - 1, 0),
"workspace_label": workspace_label,
},
)
record = await result_set.single()
Expand Down
178 changes: 95 additions & 83 deletions tests/test_workspace_sanitization.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,67 @@
"""
Unit tests for workspace label sanitization in Memgraph and Neo4j implementations.
Unit tests for workspace label sanitization in graph implementations.

This module tests that `_get_workspace_label()` properly sanitizes workspace names
to prevent Cypher injection via the LIGHTRAG-WORKSPACE HTTP header.
Memgraph labels are used in Cypher query strings where labels cannot be
parameterized, so the Memgraph helper must reduce workspace names to an
alphanumeric/underscore allowlist.

It verifies that we preserve non-alphanumeric characters for 1-to-1 workspace mapping
while successfully neutralizing Cypher injection by escaping backticks.

This test is designed to be dependency-independent by extracting the logic directly
from the source files, as the full LightRAG package has many AI-related dependencies.

References: GitHub Issue #2698
Neo4j keeps the existing backtick-escaping behavior to preserve one-to-one
workspace labels while preventing backtick-delimited identifier breakout.
"""

import re
import os
import re

import pytest

# Mark all tests as offline (no external dependencies)
pytestmark = pytest.mark.offline


def get_actual_sanitization_logic():
"""Extract the sanitization logic from the source files to ensure we test the real code."""
def _source_contains(file_name: str, pattern: str) -> bool:
"""Check that tests mirror the implementation currently in source."""
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
files = [
os.path.join(base_path, "lightrag/kg/memgraph_impl.py"),
os.path.join(base_path, "lightrag/kg/neo4j_impl.py"),
]
file_path = os.path.join(base_path, "lightrag/kg", file_name)
with open(file_path, "r", encoding="utf-8") as f:
return re.search(pattern, f.read()) is not None

logics = []
for file_path in files:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
# Find the _get_workspace_label method body
# We look for the specific line: return workspace.replace("`", "``")
match = re.search(r"return workspace\.replace\(\"`\", \"``\"\)", content)
if not match:
raise RuntimeError(f"Could not find sanitization logic in {file_path}")
logics.append(file_path)

# All backends should have identical logic for this helper
def sanitize(workspace: str) -> str:
safe = workspace.strip()
if not safe:
safe = "base"
return safe.replace("`", "``")
def sanitize_memgraph(workspace: str) -> str:
"""Mirror MemgraphStorage._get_workspace_label() for dependency-free tests."""
if not _source_contains(
"memgraph_impl.py",
r"re\.sub\(r[\"']\[\^A-Za-z0-9_\][\"'], [\"']_[\"'], workspace\)",
):
raise RuntimeError("Could not find Memgraph allowlist sanitization logic")

return sanitize
safe = workspace.strip()
if not safe:
return "base"
sanitized = re.sub(r"[^A-Za-z0-9_]", "_", safe)
if not sanitized:
return "base"
return sanitized


sanitize = get_actual_sanitization_logic()
def sanitize_neo4j(workspace: str) -> str:
"""Mirror Neo4jStorage._get_workspace_label() for dependency-free tests."""
if not _source_contains(
"neo4j_impl.py", r"return workspace\.replace\(\"`\", \"``\"\)"
):
raise RuntimeError("Could not find Neo4j backtick escaping logic")

safe = workspace.strip()
if not safe:
return "base"
return safe.replace("`", "``")

class TestWorkspaceLabelSanitization:
"""Test suite for _get_workspace_label() sanitization logic."""

class TestMemgraphWorkspaceLabelSanitization:
"""Test suite for MemgraphStorage._get_workspace_label()."""

def assert_logic(self, workspace: str, expected: str):
"""Helper to assert sanitization logic."""
assert sanitize(workspace) == expected
"""Helper to assert Memgraph sanitization logic."""
assert sanitize_memgraph(workspace) == expected

# --- Normal inputs ---

Expand All @@ -78,53 +81,52 @@ def test_numeric_only(self):
"""Numeric-only workspaces are valid."""
self.assert_logic("12345", "12345")

# --- Special characters preserved (unlike PostgreSQL regex stripping) ---
# --- Special characters replaced ---

def test_spaces_preserved(self):
"""Spaces in workspace names should be preserved."""
self.assert_logic("my workspace", "my workspace")
def test_spaces_replaced(self):
"""Spaces in workspace names should be replaced."""
self.assert_logic("my workspace", "my_workspace")

def test_hyphens_preserved(self):
"""Hyphens should be preserved (solves collision issue)."""
self.assert_logic("my-workspace", "my-workspace")
def test_hyphens_replaced(self):
"""Hyphens should be replaced."""
self.assert_logic("my-workspace", "my_workspace")

def test_dots_preserved(self):
"""Dots should be preserved."""
self.assert_logic("my.workspace", "my.workspace")
def test_dots_replaced(self):
"""Dots should be replaced."""
self.assert_logic("my.workspace", "my_workspace")

def test_mixed_special_chars_preserved(self):
"""Multiple different special characters should be preserved."""
self.assert_logic("a-b.c d@e!f", "a-b.c d@e!f")
def test_mixed_special_chars_replaced(self):
"""Multiple different special characters should be replaced."""
self.assert_logic("a-b.c d@e!f", "a_b_c_d_e_f")

# --- Cypher injection payloads ---

def test_cypher_injection_backtick_escaped(self):
"""Backtick injection attempt should be neutralized by doubling backticks."""
def test_cypher_injection_backtick_replaced(self):
"""Backtick injection attempt should be neutralized by replacement."""
malicious = "test`}) MATCH (n) DETACH DELETE n //"
# The single backtick should become a double backtick
expected = "test``}) MATCH (n) DETACH DELETE n //"
expected = "test____MATCH__n__DETACH_DELETE_n___"
self.assert_logic(malicious, expected)

def test_cypher_injection_multiple_backticks(self):
"""Multiple backticks should all be escaped."""
"""Multiple backticks should all be replaced."""
malicious = "`DROP`DATABASE`"
expected = "``DROP``DATABASE``"
expected = "_DROP_DATABASE_"
self.assert_logic(malicious, expected)

def test_cypher_injection_curly_braces_preserved(self):
"""Curly brace injection is harmless when enclosed in backticks, so preserved."""
def test_cypher_injection_curly_braces_replaced(self):
"""Curly brace injection should be replaced for identifier safety."""
malicious = "test}) RETURN 1 //"
self.assert_logic(malicious, malicious)
self.assert_logic(malicious, "test___RETURN_1___")

def test_cypher_injection_semicolon_preserved(self):
"""Semicolon injection is harmless when enclosed in backticks, so preserved."""
def test_cypher_injection_semicolon_replaced(self):
"""Semicolon injection should be replaced for identifier safety."""
malicious = "test; DROP DATABASE neo4j"
self.assert_logic(malicious, malicious)
self.assert_logic(malicious, "test__DROP_DATABASE_neo4j")

def test_cypher_injection_quotes_preserved(self):
"""Quote injection is harmless when enclosed in backticks, so preserved."""
def test_cypher_injection_quotes_replaced(self):
"""Quote injection should be replaced for identifier safety."""
malicious = 'test" OR 1=1 //'
self.assert_logic(malicious, malicious)
self.assert_logic(malicious, "test__OR_1_1___")

# --- Empty / whitespace fallback ---

Expand All @@ -136,32 +138,32 @@ def test_whitespace_only_fallback(self):
"""Whitespace-only workspace should fall back to 'base'."""
self.assert_logic(" ", "base")

def test_special_chars_only_preserved(self):
"""Workspace with only special characters should be preserved."""
self.assert_logic("---", "---")
def test_special_chars_only_replaced(self):
"""Workspace with only special characters should be replaced with underscores."""
self.assert_logic("---", "___")

# --- Edge cases ---

def test_leading_trailing_whitespace_stripped(self):
"""Leading/trailing whitespace should be stripped before sanitization."""
self.assert_logic(" myworkspace ", "myworkspace")

def test_unicode_characters_preserved(self):
"""Non-ASCII/Chinese characters should be preserved."""
self.assert_logic("工作区_test", "工作区_test")
def test_unicode_characters_replaced(self):
"""Non-ASCII/Chinese characters should be replaced."""
self.assert_logic("工作区_test", "____test")

def test_very_long_workspace(self):
"""Very long workspace names should still be sanitized correctly."""
long_name = "a" * 1000 + "`"
expected = "a" * 1000 + "``"
expected = "a" * 1000 + "_"
self.assert_logic(long_name, expected)

def test_single_underscore(self):
"""Single underscore should be valid."""
self.assert_logic("_", "_")

def test_result_always_escapes_backticks(self):
"""Parametric check: any output must not contain unescaped single backticks."""
def test_result_uses_strict_allowlist(self):
"""Parametric check: output contains only alphanumerics and underscores."""
dangerous_inputs = [
"normal",
"with spaces",
Expand All @@ -174,10 +176,20 @@ def test_result_always_escapes_backticks(self):
"emoji🚀test",
]
for inp in dangerous_inputs:
result = sanitize(inp)
backtick_sequences = re.findall(r"`+", result)
for seq in backtick_sequences:
# Any sequence of backticks should have an EVEN length because each ` becomes ``
assert (
len(seq) % 2 == 0
), f"Unescaped backtick found in result '{result}' for input '{inp}'"
result = sanitize_memgraph(inp)
assert re.fullmatch(r"[A-Za-z0-9_]+", result)


class TestNeo4jWorkspaceLabelSanitization:
"""Test suite for Neo4jStorage._get_workspace_label()."""

def test_neo4j_preserves_non_backtick_characters(self):
assert sanitize_neo4j("my workspace") == "my workspace"
assert sanitize_neo4j("my-workspace") == "my-workspace"
assert sanitize_neo4j("工作区_test") == "工作区_test"

def test_neo4j_escapes_backticks_and_falls_back(self):
assert sanitize_neo4j("test`}) MATCH (n) DETACH DELETE n //") == (
"test``}) MATCH (n) DETACH DELETE n //"
)
assert sanitize_neo4j(" ") == "base"
Loading