diff --git a/.gitignore b/.gitignore index 16bf6f6a14..aef2f0b7c8 100644 --- a/.gitignore +++ b/.gitignore @@ -88,3 +88,10 @@ CLAUDE.md # Google Jules .jules/ + +# Test environment (contains secrets) +scripts/.env-test + +# Agent working directory +.agents/ +.env-dockerhub diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 242fe57240..dd739a7159 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -10,7 +10,6 @@ get_swagger_ui_oauth2_redirect_html, ) import os -import re import logging import logging.config import sys @@ -53,7 +52,10 @@ ) from lightrag.api.routers.query_routes import create_query_routes from lightrag.api.routers.graph_routes import create_graph_routes +from lightrag.api.routers.workspace_routes import create_workspace_routes from lightrag.api.routers.ollama_api import OllamaAPI +from lightrag.api.workspace_manager import WorkspaceManager, WorkspaceCapacityError +from lightrag.api.utils import sanitize_workspace_name, WorkspaceNameError from lightrag.utils import logger, set_verbose_debug from lightrag.kg.shared_storage import ( @@ -345,151 +347,7 @@ def create_app(args): # Initialize document manager with workspace support for data isolation doc_manager = DocumentManager(args.input_dir, workspace=args.workspace) - @asynccontextmanager - async def lifespan(app: FastAPI): - """Lifespan context manager for startup and shutdown events""" - # Store background tasks - app.state.background_tasks = set() - - try: - # Initialize database connections - # Note: initialize_storages() now auto-initializes pipeline_status for rag.workspace - await rag.initialize_storages() - - # Data migration regardless of storage implementation - await rag.check_and_migrate_data() - - ASCIIColors.green("\nServer is ready to accept connections! πŸš€\n") - - yield - - finally: - # Clean up database connections - await rag.finalize_storages() - - if "LIGHTRAG_GUNICORN_MODE" not in os.environ: - # Only perform cleanup in Uvicorn single-process mode - logger.debug("Unvicorn Mode: finalizing shared storage...") - finalize_share_data() - else: - # In Gunicorn mode with preload_app=True, cleanup is handled by on_exit hooks - logger.debug( - "Gunicorn Mode: postpone shared storage finalization to master process" - ) - # Initialize FastAPI - base_description = ( - "Providing API for LightRAG core, Web UI and Ollama Model Emulation" - ) - swagger_description = ( - base_description - + (" (API-Key Enabled)" if api_key else "") - + "\n\n[View ReDoc documentation](/redoc)" - ) - app_kwargs = { - "title": "LightRAG Server API", - "description": swagger_description, - "version": __api_version__, - "openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL - "docs_url": None, # Disable default docs, we'll create custom endpoint - "redoc_url": "/redoc", # Explicitly set redoc URL - "lifespan": lifespan, - } - - # Configure Swagger UI parameters - # Enable persistAuthorization and tryItOutEnabled for better user experience - app_kwargs["swagger_ui_parameters"] = { - "persistAuthorization": True, - "tryItOutEnabled": True, - } - - app = FastAPI(**app_kwargs) - - # Add custom validation error handler for /query/data endpoint - @app.exception_handler(RequestValidationError) - async def validation_exception_handler( - request: Request, exc: RequestValidationError - ): - # Check if this is a request to /query/data endpoint - if request.url.path.endswith("/query/data"): - # Extract error details - error_details = [] - for error in exc.errors(): - field_path = " -> ".join(str(loc) for loc in error["loc"]) - error_details.append(f"{field_path}: {error['msg']}") - - error_message = "; ".join(error_details) - - # Return in the expected format for /query/data - return JSONResponse( - status_code=400, - content={ - "status": "failure", - "message": f"Validation error: {error_message}", - "data": {}, - "metadata": {}, - }, - ) - else: - # For other endpoints, return the default FastAPI validation error - return JSONResponse(status_code=422, content={"detail": exc.errors()}) - - def get_cors_origins(): - """Get allowed origins from global_args - Returns a list of allowed origins, defaults to ["*"] if not set - """ - origins_str = global_args.cors_origins - if origins_str == "*": - return ["*"] - return [origin.strip() for origin in origins_str.split(",")] - - # Add CORS middleware - app.add_middleware( - CORSMiddleware, - allow_origins=get_cors_origins(), - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - expose_headers=[ - "X-New-Token" - ], # Expose token renewal header for cross-origin requests - ) - - # Create combined auth dependency for all endpoints - combined_auth = get_combined_auth_dependency(api_key) - - def get_workspace_from_request(request: Request) -> str | None: - """ - Extract workspace from HTTP request header or use default. - - This enables multi-workspace API support by checking the custom - 'LIGHTRAG-WORKSPACE' header. If not present, falls back to the - server's default workspace configuration. - - Args: - request: FastAPI Request object - - Returns: - Workspace identifier (may be empty string for global namespace) - """ - # Check custom header first - workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip() - - if not workspace: - workspace = None - else: - sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", workspace) - if sanitized != workspace: - logger.warning( - f"Workspace header '{workspace}' contains invalid characters. " - f"Sanitized to '{sanitized}'." - ) - workspace = sanitized - - return workspace - - # Create working directory if it doesn't exist - Path(args.working_dir).mkdir(parents=True, exist_ok=True) def create_optimized_openai_llm_func( config_cache: LLMConfigCache, args, llm_timeout: int @@ -1135,11 +993,11 @@ async def server_rerank_func( name=args.simulated_model_name, tag=args.simulated_model_tag ) - # Initialize RAG with unified configuration - try: - rag = LightRAG( + # Create factory callable β€” captures all constructor args as closure + def create_lightrag(workspace: str): + return LightRAG( working_dir=args.working_dir, - workspace=args.workspace, + workspace=workspace, llm_model_func=create_llm_model_func(args.llm_binding), llm_model_name=args.llm_model, llm_model_max_async=args.max_async, @@ -1171,23 +1029,180 @@ async def server_rerank_func( }, ollama_server_infos=ollama_server_infos, ) + + # Initialize WorkspaceManager for workspace lifecycle management + try: + workspace_mgr = WorkspaceManager(factory=create_lightrag, max_instances=10) except Exception as e: - logger.error(f"Failed to initialize LightRAG: {e}") + logger.error(f"Failed to initialize WorkspaceManager: {e}") raise + @asynccontextmanager + async def lifespan(app: FastAPI): + """Lifespan context manager for startup and shutdown events""" + # Store background tasks + app.state.background_tasks = set() + + try: + # Pre-warm default workspace. + # Intentionally keeps ref_count=1 permanently β€” the default workspace is never evicted. + # This ensures backward compatibility: requests without the LIGHTRAG-WORKSPACE header + # always find a ready instance. + default_rag = await workspace_mgr.get_or_create(args.workspace) + + # Initialize database connections for pre-warmed default workspace + # Note: initialize_storages() now auto-initializes pipeline_status for default_rag.workspace + await default_rag.initialize_storages() + + # Data migration regardless of storage implementation + await default_rag.check_and_migrate_data() + + ASCIIColors.green("\nServer is ready to accept connections! πŸš€\n") + + yield + + finally: + # Shutdown WorkspaceManager to finalize all cached instances + await workspace_mgr.shutdown() + + if "LIGHTRAG_GUNICORN_MODE" not in os.environ: + # Only perform cleanup in Uvicorn single-process mode + logger.debug("Unvicorn Mode: finalizing shared storage...") + finalize_share_data() + else: + # In Gunicorn mode with preload_app=True, cleanup is handled by on_exit hooks + logger.debug( + "Gunicorn Mode: postpone shared storage finalization to master process" + ) + + # Initialize FastAPI + base_description = ( + "Providing API for LightRAG core, Web UI and Ollama Model Emulation" + ) + swagger_description = ( + base_description + + (" (API-Key Enabled)" if api_key else "") + + "\n\n[View ReDoc documentation](/redoc)" + ) + app_kwargs = { + "title": "LightRAG Server API", + "description": swagger_description, + "version": __api_version__, + "openapi_url": "/openapi.json", # Explicitly set OpenAPI schema URL + "docs_url": None, # Disable default docs, we'll create custom endpoint + "redoc_url": "/redoc", # Explicitly set redoc URL + "lifespan": lifespan, + } + + # Configure Swagger UI parameters + # Enable persistAuthorization and tryItOutEnabled for better user experience + app_kwargs["swagger_ui_parameters"] = { + "persistAuthorization": True, + "tryItOutEnabled": True, + } + + app = FastAPI(**app_kwargs) + + # Add custom validation error handler for /query/data endpoint + @app.exception_handler(RequestValidationError) + async def validation_exception_handler( + request: Request, exc: RequestValidationError + ): + # Check if this is a request to /query/data endpoint + if request.url.path.endswith("/query/data"): + # Extract error details + error_details = [] + for error in exc.errors(): + field_path = " -> ".join(str(loc) for loc in error["loc"]) + error_details.append(f"{field_path}: {error['msg']}") + + error_message = "; ".join(error_details) + + # Return in the expected format for /query/data + return JSONResponse( + status_code=400, + content={ + "status": "failure", + "message": f"Validation error: {error_message}", + "data": {}, + "metadata": {}, + }, + ) + else: + # For other endpoints, return the default FastAPI validation error + return JSONResponse(status_code=422, content={"detail": exc.errors()}) + + def get_cors_origins(): + """Get allowed origins from global_args + Returns a list of allowed origins, defaults to ["*"] if not set + """ + origins_str = global_args.cors_origins + if origins_str == "*": + return ["*"] + return [origin.strip() for origin in origins_str.split(",")] + + # Add CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=get_cors_origins(), + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + expose_headers=[ + "X-New-Token" + ], # Expose token renewal header for cross-origin requests + ) + + # Create combined auth dependency for all endpoints + combined_auth = get_combined_auth_dependency(api_key) + + def get_workspace_from_request(request: Request) -> str | None: + """ + Extract workspace from HTTP request header or use default. + + This enables multi-workspace API support by checking the custom + 'LIGHTRAG-WORKSPACE' header. If not present, falls back to the + server's default workspace configuration. + + Args: + request: FastAPI Request object + + Returns: + Workspace identifier (may be empty string for global namespace), + or None if no workspace header is present. + """ + # Check custom header first + raw_workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip() + + if raw_workspace: + try: + workspace = sanitize_workspace_name(raw_workspace) + except WorkspaceNameError as e: + raise HTTPException(status_code=400, detail=str(e)) + else: + workspace = None + + return workspace + + # Create working directory if it doesn't exist + Path(args.working_dir).mkdir(parents=True, exist_ok=True) + # Add routes app.include_router( create_document_routes( - rag, + workspace_mgr, doc_manager, api_key, ) ) - app.include_router(create_query_routes(rag, api_key, args.top_k)) - app.include_router(create_graph_routes(rag, api_key)) + app.include_router(create_query_routes(workspace_mgr, api_key, args.top_k)) + app.include_router(create_graph_routes(workspace_mgr, api_key)) + app.include_router( + create_workspace_routes(api_key, working_dir=global_args.working_dir) + ) # Add Ollama API routes - ollama_api = OllamaAPI(rag, top_k=args.top_k, api_key=api_key) + ollama_api = OllamaAPI(workspace_mgr, top_k=args.top_k, api_key=api_key) app.include_router(ollama_api.router, prefix="/api") # Custom Swagger UI endpoint for offline support @@ -1317,73 +1332,92 @@ async def login(form_data: OAuth2PasswordRequestForm = Depends()): ) async def get_status(request: Request): """Get current system status including WebUI availability""" + rag = None try: workspace = get_workspace_from_request(request) default_workspace = get_default_workspace() if workspace is None: workspace = default_workspace - pipeline_status = await get_namespace_data( - "pipeline_status", workspace=workspace - ) - if not auth_configured: - auth_mode = "disabled" - else: - auth_mode = "enabled" + # Get or create the workspace instance via WorkspaceManager + try: + rag = await workspace_mgr.get_or_create(workspace) + except WorkspaceCapacityError: + return JSONResponse( + status_code=503, + content={"status": "error", "message": "All workspace slots busy"}, + ) - # Cleanup expired keyed locks and get status - keyed_lock_info = cleanup_keyed_lock() + try: + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=workspace + ) - return { - "status": "healthy", - "webui_available": webui_assets_exist, - "working_directory": str(args.working_dir), - "input_directory": str(args.input_dir), - "configuration": { - # LLM configuration binding/host address (if applicable)/model (if applicable) - "llm_binding": args.llm_binding, - "llm_binding_host": args.llm_binding_host, - "llm_model": args.llm_model, - # embedding model configuration binding/host address (if applicable)/model (if applicable) - "embedding_binding": args.embedding_binding, - "embedding_binding_host": args.embedding_binding_host, - "embedding_model": args.embedding_model, - "summary_max_tokens": args.summary_max_tokens, - "summary_context_size": args.summary_context_size, - "kv_storage": args.kv_storage, - "doc_status_storage": args.doc_status_storage, - "graph_storage": args.graph_storage, - "vector_storage": args.vector_storage, - "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract, - "enable_llm_cache": args.enable_llm_cache, - "workspace": default_workspace, - "max_graph_nodes": args.max_graph_nodes, - # Rerank configuration - "enable_rerank": rerank_model_func is not None, - "rerank_binding": args.rerank_binding, - "rerank_model": args.rerank_model if rerank_model_func else None, - "rerank_binding_host": args.rerank_binding_host - if rerank_model_func - else None, - # Environment variable status (requested configuration) - "summary_language": args.summary_language, - "force_llm_summary_on_merge": args.force_llm_summary_on_merge, - "max_parallel_insert": args.max_parallel_insert, - "cosine_threshold": args.cosine_threshold, - "min_rerank_score": args.min_rerank_score, - "related_chunk_number": args.related_chunk_number, - "max_async": args.max_async, - "embedding_func_max_async": args.embedding_func_max_async, - "embedding_batch_num": args.embedding_batch_num, - }, - "auth_mode": auth_mode, - "pipeline_busy": pipeline_status.get("busy", False), - "keyed_locks": keyed_lock_info, - "core_version": core_version, - "api_version": api_version_display, - "webui_title": webui_title, - "webui_description": webui_description, - } + if not auth_configured: + auth_mode = "disabled" + else: + auth_mode = "enabled" + + # Cleanup expired keyed locks and get status + keyed_lock_info = cleanup_keyed_lock() + + return { + "status": "healthy", + "webui_available": webui_assets_exist, + "working_directory": str(args.working_dir), + "input_directory": str(args.input_dir), + "configuration": { + # LLM configuration binding/host address (if applicable)/model (if applicable) + "llm_binding": args.llm_binding, + "llm_binding_host": args.llm_binding_host, + "llm_model": args.llm_model, + # embedding model configuration binding/host address (if applicable)/model (if applicable) + "embedding_binding": args.embedding_binding, + "embedding_binding_host": args.embedding_binding_host, + "embedding_model": args.embedding_model, + "summary_max_tokens": args.summary_max_tokens, + "summary_context_size": args.summary_context_size, + "kv_storage": args.kv_storage, + "doc_status_storage": args.doc_status_storage, + "graph_storage": args.graph_storage, + "vector_storage": args.vector_storage, + "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract, + "enable_llm_cache": args.enable_llm_cache, + "workspace": workspace, + "max_graph_nodes": args.max_graph_nodes, + # Rerank configuration + "enable_rerank": rerank_model_func is not None, + "rerank_binding": args.rerank_binding, + "rerank_model": args.rerank_model + if rerank_model_func + else None, + "rerank_binding_host": args.rerank_binding_host + if rerank_model_func + else None, + # Environment variable status (requested configuration) + "summary_language": args.summary_language, + "force_llm_summary_on_merge": args.force_llm_summary_on_merge, + "max_parallel_insert": args.max_parallel_insert, + "cosine_threshold": args.cosine_threshold, + "min_rerank_score": args.min_rerank_score, + "related_chunk_number": args.related_chunk_number, + "max_async": args.max_async, + "embedding_func_max_async": args.embedding_func_max_async, + "embedding_batch_num": args.embedding_batch_num, + }, + "auth_mode": auth_mode, + "pipeline_busy": pipeline_status.get("busy", False), + "keyed_locks": keyed_lock_info, + "core_version": core_version, + "api_version": api_version_display, + "webui_title": webui_title, + "webui_description": webui_description, + } + finally: + if rag is not None: + workspace_mgr.release(workspace) + except HTTPException: + raise except Exception as e: logger.error(f"Error getting health status: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) diff --git a/lightrag/api/routers/__init__.py b/lightrag/api/routers/__init__.py index b71f204e04..3635997b14 100644 --- a/lightrag/api/routers/__init__.py +++ b/lightrag/api/routers/__init__.py @@ -5,6 +5,13 @@ from .document_routes import router as document_router from .query_routes import router as query_router from .graph_routes import router as graph_router +from .workspace_routes import router as workspace_router from .ollama_api import OllamaAPI -__all__ = ["document_router", "query_router", "graph_router", "OllamaAPI"] +__all__ = [ + "document_router", + "query_router", + "graph_router", + "workspace_router", + "OllamaAPI", +] diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 9e6fab9dbb..f4a1a5fb79 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -19,6 +19,7 @@ Depends, File, HTTPException, + Request, UploadFile, ) from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -31,6 +32,8 @@ sanitize_text_for_encoding, ) from lightrag.api.utils_api import get_combined_auth_dependency +from lightrag.api.utils import extract_workspace_from_header +from lightrag.api.workspace_registry import get_workspace_registry from ..config import global_args @@ -2085,15 +2088,34 @@ async def background_delete_documents( def create_document_routes( - rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None + workspace_mgr, doc_manager: DocumentManager, api_key: Optional[str] = None ): # Create combined auth dependency for document routes combined_auth = get_combined_auth_dependency(api_key) + # Get workspace registry for auto-registering workspaces + workspace_registry = get_workspace_registry(working_dir=global_args.working_dir) + + def extract_and_register_workspace(http_request: Request) -> str: + """Extract workspace from header and auto-register it.""" + workspace = extract_workspace_from_header(http_request) + if workspace: # Only register non-empty workspaces + workspace_registry.register_workspace(workspace) + return workspace + + async def _run_scan_with_workspace(workspace_mgr, workspace, doc_manager, track_id): + rag = await workspace_mgr.get_or_create(workspace) + try: + await run_scanning_process(rag, doc_manager, track_id) + finally: + workspace_mgr.release(workspace) + @router.post( "/scan", response_model=ScanResponse, dependencies=[Depends(combined_auth)] ) - async def scan_for_new_documents(background_tasks: BackgroundTasks): + async def scan_for_new_documents( + http_request: Request, background_tasks: BackgroundTasks + ): """ Trigger the scanning process for new documents. @@ -2105,21 +2127,44 @@ async def scan_for_new_documents(background_tasks: BackgroundTasks): ScanResponse: A response object containing the scanning status and track_id """ # Generate track_id with "scan" prefix for scanning operation + workspace = extract_and_register_workspace(http_request) track_id = generate_track_id("scan") + # Create workspace-specific doc_manager + workspace_doc_manager = DocumentManager( + input_dir=doc_manager.base_input_dir, + workspace=workspace, + supported_extensions=doc_manager.supported_extensions, + ) + # Start the scanning process in the background with track_id - background_tasks.add_task(run_scanning_process, rag, doc_manager, track_id) + background_tasks.add_task( + _run_scan_with_workspace, + workspace_mgr, + workspace, + workspace_doc_manager, + track_id, + ) return ScanResponse( status="scanning_started", message="Scanning process has been initiated in the background", track_id=track_id, ) + async def _run_upload_with_workspace(workspace_mgr, workspace, file_path, track_id): + rag = await workspace_mgr.get_or_create(workspace) + try: + await pipeline_index_file(rag, file_path, track_id) + finally: + workspace_mgr.release(workspace) + @router.post( "/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)] ) async def upload_to_input_dir( - background_tasks: BackgroundTasks, file: UploadFile = File(...) + http_request: Request, + background_tasks: BackgroundTasks, + file: UploadFile = File(...), ): """ Upload a file to the input directory and index it. @@ -2161,6 +2206,7 @@ async def upload_to_input_dir( - This design prevents blocking the client during expensive operations Args: + http_request: FastAPI Request object for workspace extraction background_tasks: FastAPI BackgroundTasks for async processing file (UploadFile): The file to be uploaded. It must have an allowed extension. @@ -2173,13 +2219,26 @@ async def upload_to_input_dir( HTTPException: If the file type is not supported (400), file too large (413), or other errors occur (500). """ try: + # Get workspace first for duplicate check and file operations + workspace = extract_and_register_workspace(http_request) + rag = await workspace_mgr.get_or_create(workspace) + + # Create workspace-specific doc_manager for file operations + workspace_doc_manager = DocumentManager( + input_dir=doc_manager.base_input_dir, + workspace=workspace, + supported_extensions=doc_manager.supported_extensions, + ) + # Sanitize filename to prevent Path Traversal attacks - safe_filename = sanitize_filename(file.filename, doc_manager.input_dir) + safe_filename = sanitize_filename( + file.filename, workspace_doc_manager.input_dir + ) - if not doc_manager.is_supported_file(safe_filename): + if not workspace_doc_manager.is_supported_file(safe_filename): raise HTTPException( status_code=400, - detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", + detail=f"Unsupported file type. Supported types: {workspace_doc_manager.supported_extensions}", ) # Check file size limit (if configured) @@ -2216,7 +2275,7 @@ async def upload_to_input_dir( track_id=existing_track_id, ) - file_path = doc_manager.input_dir / safe_filename + file_path = workspace_doc_manager.input_dir / safe_filename # Check if file already exists in file system if file_path.exists(): return InsertResponse( @@ -2266,8 +2325,14 @@ async def upload_to_input_dir( track_id = generate_track_id("upload") - # Add to background tasks and get track_id - background_tasks.add_task(pipeline_index_file, rag, file_path, track_id) + # Add to background tasks with workspace isolation + background_tasks.add_task( + _run_upload_with_workspace, + workspace_mgr, + workspace, + file_path, + track_id, + ) return InsertResponse( status="success", @@ -2282,12 +2347,27 @@ async def upload_to_input_dir( logger.error(f"Error /documents/upload: {file.filename}: {str(e)}") logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) + finally: + workspace_mgr.release(workspace) + + async def _run_insert_text_with_workspace( + workspace_mgr, workspace, text, file_source, track_id + ): + rag = await workspace_mgr.get_or_create(workspace) + try: + await pipeline_index_texts( + rag, [text], file_sources=[file_source], track_id=track_id + ) + finally: + workspace_mgr.release(workspace) @router.post( "/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)] ) async def insert_text( - request: InsertTextRequest, background_tasks: BackgroundTasks + http_request: Request, + request: InsertTextRequest, + background_tasks: BackgroundTasks, ): """ Insert text into the RAG system. @@ -2296,6 +2376,7 @@ async def insert_text( and use in generating responses. Args: + http_request: FastAPI Request object for workspace extraction request (InsertTextRequest): The request body containing the text to be inserted. background_tasks: FastAPI BackgroundTasks for async processing @@ -2305,6 +2386,8 @@ async def insert_text( Raises: HTTPException: If an error occurs during text processing (500). """ + workspace = extract_and_register_workspace(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: # Check if file_source already exists in doc_status storage if ( @@ -2344,11 +2427,12 @@ async def insert_text( track_id = generate_track_id("insert") background_tasks.add_task( - pipeline_index_texts, - rag, - [request.text], - file_sources=[request.file_source], - track_id=track_id, + _run_insert_text_with_workspace, + workspace_mgr, + workspace, + request.text, + request.file_source, + track_id, ) return InsertResponse( @@ -2356,10 +2440,19 @@ async def insert_text( message="Text successfully received. Processing will continue in background.", track_id=track_id, ) - except Exception as e: - logger.error(f"Error /documents/text: {str(e)}") - logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) + finally: + workspace_mgr.release(workspace) + + async def _run_insert_texts_with_workspace( + workspace_mgr, workspace, texts, file_sources, track_id + ): + rag = await workspace_mgr.get_or_create(workspace) + try: + await pipeline_index_texts( + rag, texts, file_sources=file_sources, track_id=track_id + ) + finally: + workspace_mgr.release(workspace) @router.post( "/texts", @@ -2367,7 +2460,9 @@ async def insert_text( dependencies=[Depends(combined_auth)], ) async def insert_texts( - request: InsertTextsRequest, background_tasks: BackgroundTasks + http_request: Request, + request: InsertTextsRequest, + background_tasks: BackgroundTasks, ): """ Insert multiple texts into the RAG system. @@ -2376,6 +2471,7 @@ async def insert_texts( in a single request. Args: + http_request: FastAPI Request object for workspace extraction request (InsertTextsRequest): The request body containing the list of texts. background_tasks: FastAPI BackgroundTasks for async processing @@ -2385,6 +2481,8 @@ async def insert_texts( Raises: HTTPException: If an error occurs during text processing (500). """ + workspace = extract_and_register_workspace(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: # Check if any file_sources already exist in doc_status storage if request.file_sources: @@ -2427,11 +2525,12 @@ async def insert_texts( track_id = generate_track_id("insert") background_tasks.add_task( - pipeline_index_texts, - rag, + _run_insert_texts_with_workspace, + workspace_mgr, + workspace, request.texts, - file_sources=request.file_sources, - track_id=track_id, + request.file_sources, + track_id, ) return InsertResponse( @@ -2439,15 +2538,13 @@ async def insert_texts( message="Texts successfully received. Processing will continue in background.", track_id=track_id, ) - except Exception as e: - logger.error(f"Error /documents/texts: {str(e)}") - logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) + finally: + workspace_mgr.release(workspace) @router.delete( "", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)] ) - async def clear_documents(): + async def clear_documents(http_request: Request): """ Clear all documents from the RAG system. @@ -2455,6 +2552,9 @@ async def clear_documents(): It uses the storage drop methods to properly clean up all data and removes all files from the input directory. + Args: + http_request: FastAPI Request object for workspace extraction + Returns: ClearDocumentsResponse: A response object containing the status and message. - status="success": All documents and files were successfully cleared. @@ -2468,188 +2568,203 @@ async def clear_documents(): HTTPException: Raised when a serious error occurs during the clearing process, with status code 500 and error details in the detail field. """ - from lightrag.kg.shared_storage import ( - get_namespace_data, - get_namespace_lock, - ) + workspace = extract_and_register_workspace(http_request) + rag = await workspace_mgr.get_or_create(workspace) + try: + from lightrag.kg.shared_storage import ( + get_namespace_data, + get_namespace_lock, + ) - # Get pipeline status and lock - pipeline_status = await get_namespace_data( - "pipeline_status", workspace=rag.workspace - ) - pipeline_status_lock = get_namespace_lock( - "pipeline_status", workspace=rag.workspace - ) + # Create workspace-specific doc_manager for file operations + workspace_doc_manager = DocumentManager( + input_dir=doc_manager.base_input_dir, + workspace=workspace, + supported_extensions=doc_manager.supported_extensions, + ) - # Check and set status with lock - async with pipeline_status_lock: - if pipeline_status.get("busy", False): - return ClearDocumentsResponse( - status="busy", - message="Cannot clear documents while pipeline is busy", - ) - # Set busy to true - pipeline_status.update( - { - "busy": True, - "job_name": "Clearing Documents", - "job_start": datetime.now().isoformat(), - "docs": 0, - "batchs": 0, - "cur_batch": 0, - "request_pending": False, # Clear any previous request - "latest_message": "Starting document clearing process", - } + # Get pipeline status and lock + pipeline_status = await get_namespace_data( + "pipeline_status", workspace=rag.workspace ) - # Cleaning history_messages without breaking it as a shared list object - del pipeline_status["history_messages"][:] - pipeline_status["history_messages"].append( - "Starting document clearing process" + pipeline_status_lock = get_namespace_lock( + "pipeline_status", workspace=rag.workspace ) - try: - # Use drop method to clear all data - drop_tasks = [] - storages = [ - rag.text_chunks, - rag.full_docs, - rag.full_entities, - rag.full_relations, - rag.entity_chunks, - rag.relation_chunks, - rag.entities_vdb, - rag.relationships_vdb, - rag.chunks_vdb, - rag.chunk_entity_relation_graph, - rag.doc_status, - ] - - # Log storage drop start - if "history_messages" in pipeline_status: + # Check and set status with lock + async with pipeline_status_lock: + if pipeline_status.get("busy", False): + return ClearDocumentsResponse( + status="busy", + message="Cannot clear documents while pipeline is busy", + ) + # Set busy to true + pipeline_status.update( + { + "busy": True, + "job_name": "Clearing Documents", + "job_start": datetime.now().isoformat(), + "docs": 0, + "batchs": 0, + "cur_batch": 0, + "request_pending": False, # Clear any previous request + "latest_message": "Starting document clearing process", + } + ) + # Cleaning history_messages without breaking it as a shared list object + del pipeline_status["history_messages"][:] pipeline_status["history_messages"].append( - "Starting to drop storage components" + "Starting document clearing process" ) - for storage in storages: - if storage is not None: - drop_tasks.append(storage.drop()) + try: + # Use drop method to clear all data + drop_tasks = [] + storages = [ + rag.text_chunks, + rag.full_docs, + rag.full_entities, + rag.full_relations, + rag.entity_chunks, + rag.relation_chunks, + rag.entities_vdb, + rag.relationships_vdb, + rag.chunks_vdb, + rag.chunk_entity_relation_graph, + rag.doc_status, + ] - # Wait for all drop tasks to complete - drop_results = await asyncio.gather(*drop_tasks, return_exceptions=True) + # Log storage drop start + if "history_messages" in pipeline_status: + pipeline_status["history_messages"].append( + "Starting to drop storage components" + ) - # Check for errors and log results - errors = [] - storage_success_count = 0 - storage_error_count = 0 + for storage in storages: + if storage is not None: + drop_tasks.append(storage.drop()) + + # Wait for all drop tasks to complete + drop_results = await asyncio.gather(*drop_tasks, return_exceptions=True) + + # Check for errors and log results + errors = [] + storage_success_count = 0 + storage_error_count = 0 + + for i, result in enumerate(drop_results): + storage_name = storages[i].__class__.__name__ + if isinstance(result, Exception): + error_msg = f"Error dropping {storage_name}: {str(result)}" + errors.append(error_msg) + logger.error(error_msg) + storage_error_count += 1 + else: + namespace = storages[i].namespace + storage_workspace = storages[i].workspace + logger.info( + f"Successfully dropped {storage_name}: {storage_workspace}/{namespace}" + ) + storage_success_count += 1 - for i, result in enumerate(drop_results): - storage_name = storages[i].__class__.__name__ - if isinstance(result, Exception): - error_msg = f"Error dropping {storage_name}: {str(result)}" - errors.append(error_msg) - logger.error(error_msg) - storage_error_count += 1 - else: - namespace = storages[i].namespace - workspace = storages[i].workspace - logger.info( - f"Successfully dropped {storage_name}: {workspace}/{namespace}" - ) - storage_success_count += 1 + # Log storage drop results + if "history_messages" in pipeline_status: + if storage_error_count > 0: + pipeline_status["history_messages"].append( + f"Dropped {storage_success_count} storage components with {storage_error_count} errors" + ) + else: + pipeline_status["history_messages"].append( + f"Successfully dropped all {storage_success_count} storage components" + ) - # Log storage drop results - if "history_messages" in pipeline_status: - if storage_error_count > 0: - pipeline_status["history_messages"].append( - f"Dropped {storage_success_count} storage components with {storage_error_count} errors" - ) - else: - pipeline_status["history_messages"].append( - f"Successfully dropped all {storage_success_count} storage components" - ) + # If all storage operations failed, return error status and don't proceed with file deletion + if storage_success_count == 0 and storage_error_count > 0: + error_message = "All storage drop operations failed. Aborting document clearing process." + logger.error(error_message) + if "history_messages" in pipeline_status: + pipeline_status["history_messages"].append(error_message) + return ClearDocumentsResponse(status="fail", message=error_message) - # If all storage operations failed, return error status and don't proceed with file deletion - if storage_success_count == 0 and storage_error_count > 0: - error_message = "All storage drop operations failed. Aborting document clearing process." - logger.error(error_message) + # Log file deletion start if "history_messages" in pipeline_status: - pipeline_status["history_messages"].append(error_message) - return ClearDocumentsResponse(status="fail", message=error_message) + pipeline_status["history_messages"].append( + "Starting to delete files in input directory" + ) - # Log file deletion start - if "history_messages" in pipeline_status: - pipeline_status["history_messages"].append( - "Starting to delete files in input directory" - ) + # Delete only files in the current directory, preserve files in subdirectories + deleted_files_count = 0 + file_errors_count = 0 - # Delete only files in the current directory, preserve files in subdirectories - deleted_files_count = 0 - file_errors_count = 0 + for file_path in workspace_doc_manager.input_dir.glob("*"): + if file_path.is_file(): + try: + file_path.unlink() + deleted_files_count += 1 + except Exception as e: + logger.error(f"Error deleting file {file_path}: {str(e)}") + file_errors_count += 1 - for file_path in doc_manager.input_dir.glob("*"): - if file_path.is_file(): - try: - file_path.unlink() - deleted_files_count += 1 - except Exception as e: - logger.error(f"Error deleting file {file_path}: {str(e)}") - file_errors_count += 1 + # Log file deletion results + if "history_messages" in pipeline_status: + if file_errors_count > 0: + pipeline_status["history_messages"].append( + f"Deleted {deleted_files_count} files with {file_errors_count} errors" + ) + errors.append(f"Failed to delete {file_errors_count} files") + else: + pipeline_status["history_messages"].append( + f"Successfully deleted {deleted_files_count} files" + ) - # Log file deletion results - if "history_messages" in pipeline_status: - if file_errors_count > 0: - pipeline_status["history_messages"].append( - f"Deleted {deleted_files_count} files with {file_errors_count} errors" - ) - errors.append(f"Failed to delete {file_errors_count} files") + # Prepare final result message + final_message = "" + if errors: + final_message = f"Cleared documents with some errors. Deleted {deleted_files_count} files." + status = "partial_success" else: - pipeline_status["history_messages"].append( - f"Successfully deleted {deleted_files_count} files" - ) - - # Prepare final result message - final_message = "" - if errors: - final_message = f"Cleared documents with some errors. Deleted {deleted_files_count} files." - status = "partial_success" - else: - final_message = f"All documents cleared successfully. Deleted {deleted_files_count} files." - status = "success" + final_message = f"All documents cleared successfully. Deleted {deleted_files_count} files." + status = "success" - # Log final result - if "history_messages" in pipeline_status: - pipeline_status["history_messages"].append(final_message) + # Log final result + if "history_messages" in pipeline_status: + pipeline_status["history_messages"].append(final_message) - # Return response based on results - return ClearDocumentsResponse(status=status, message=final_message) - except Exception as e: - error_msg = f"Error clearing documents: {str(e)}" - logger.error(error_msg) - logger.error(traceback.format_exc()) - if "history_messages" in pipeline_status: - pipeline_status["history_messages"].append(error_msg) - raise HTTPException(status_code=500, detail=str(e)) - finally: - # Reset busy status after completion - async with pipeline_status_lock: - pipeline_status["busy"] = False - completion_msg = "Document clearing process completed" - pipeline_status["latest_message"] = completion_msg + # Return response based on results + return ClearDocumentsResponse(status=status, message=final_message) + except Exception as e: + error_msg = f"Error clearing documents: {str(e)}" + logger.error(error_msg) + logger.error(traceback.format_exc()) if "history_messages" in pipeline_status: - pipeline_status["history_messages"].append(completion_msg) + pipeline_status["history_messages"].append(error_msg) + raise HTTPException(status_code=500, detail=str(e)) + finally: + # Reset busy status after completion + async with pipeline_status_lock: + pipeline_status["busy"] = False + completion_msg = "Document clearing process completed" + pipeline_status["latest_message"] = completion_msg + if "history_messages" in pipeline_status: + pipeline_status["history_messages"].append(completion_msg) + finally: + workspace_mgr.release(workspace) @router.get( "/pipeline_status", dependencies=[Depends(combined_auth)], response_model=PipelineStatusResponse, ) - async def get_pipeline_status() -> PipelineStatusResponse: + async def get_pipeline_status(http_request: Request) -> PipelineStatusResponse: """ Get the current status of the document indexing pipeline. This endpoint returns information about the current state of the document processing pipeline, including the processing status, progress information, and history messages. + Args: + http_request: FastAPI Request object for workspace extraction + Returns: PipelineStatusResponse: A response object containing: - autoscanned (bool): Whether auto-scan has started @@ -2665,8 +2780,10 @@ async def get_pipeline_status() -> PipelineStatusResponse: with truncation message if more than 1000 messages exist) Raises: - HTTPException: If an error occurs while retrieving pipeline status (500) + HTTPException: If an error occurs during text processing (500). """ + workspace = extract_and_register_workspace(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: from lightrag.kg.shared_storage import ( get_namespace_data, @@ -2733,16 +2850,14 @@ async def get_pipeline_status() -> PipelineStatusResponse: status_dict["job_start"] = format_datetime(status_dict["job_start"]) return PipelineStatusResponse(**status_dict) - except Exception as e: - logger.error(f"Error getting pipeline status: {str(e)}") - logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) + finally: + workspace_mgr.release(workspace) # TODO: Deprecated, use /documents/paginated instead @router.get( "", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)] ) - async def documents() -> DocsStatusesResponse: + async def documents(http_request: Request) -> DocsStatusesResponse: """ Get the status of all documents in the system. This endpoint is deprecated; use /documents/paginated instead. To prevent excessive resource consumption, a maximum of 1,000 records is returned. @@ -2751,6 +2866,9 @@ async def documents() -> DocsStatusesResponse: processing status (PENDING, PROCESSING, PREPROCESSED, PROCESSED, FAILED). The results are limited to 1000 total documents with fair distribution across all statuses. + Args: + http_request: FastAPI Request object for workspace extraction + Returns: DocsStatusesResponse: A response object containing a dictionary where keys are DocStatus values and values are lists of DocStatusResponse @@ -2760,6 +2878,8 @@ async def documents() -> DocsStatusesResponse: Raises: HTTPException: If an error occurs while retrieving document statuses (500). """ + workspace = extract_and_register_workspace(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: statuses = ( DocStatus.PENDING, @@ -2835,10 +2955,8 @@ async def documents() -> DocsStatusesResponse: current_status_idx = (current_status_idx + 1) % len(status_documents) return response - except Exception as e: - logger.error(f"Error GET /documents: {str(e)}") - logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) + finally: + workspace_mgr.release(workspace) class DeleteDocByIdResponse(BaseModel): """Response model for single document deletion operation.""" @@ -2849,6 +2967,29 @@ class DeleteDocByIdResponse(BaseModel): message: str = Field(description="Message describing the operation result") doc_id: str = Field(description="The ID of the document to delete") + async def _run_delete_with_workspace( + workspace_mgr, + workspace, + doc_ids, + delete_file, + delete_llm_cache, + base_input_dir, + supported_extensions, + ): + rag = await workspace_mgr.get_or_create(workspace) + try: + # Create workspace-specific doc_manager + workspace_doc_manager = DocumentManager( + input_dir=base_input_dir, + workspace=workspace, + supported_extensions=supported_extensions, + ) + await background_delete_documents( + rag, workspace_doc_manager, doc_ids, delete_file, delete_llm_cache + ) + finally: + workspace_mgr.release(workspace) + @router.delete( "/delete_document", response_model=DeleteDocByIdResponse, @@ -2856,6 +2997,7 @@ class DeleteDocByIdResponse(BaseModel): summary="Delete a document and all its associated data by its ID.", ) async def delete_document( + http_request: Request, delete_request: DeleteDocRequest, background_tasks: BackgroundTasks, ) -> DeleteDocByIdResponse: @@ -2870,6 +3012,7 @@ async def delete_document( This operation is irreversible and will interact with the pipeline status. Args: + http_request: FastAPI Request object for workspace extraction delete_request (DeleteDocRequest): The request containing the document IDs and deletion options. background_tasks: FastAPI BackgroundTasks for async processing @@ -2882,8 +3025,10 @@ async def delete_document( HTTPException: - 500: If an unexpected internal error occurs during initialization. """ + workspace = extract_and_register_workspace(http_request) doc_ids = delete_request.doc_ids + rag = await workspace_mgr.get_or_create(workspace) try: from lightrag.kg.shared_storage import ( get_namespace_data, @@ -2906,14 +3051,16 @@ async def delete_document( doc_id=", ".join(doc_ids), ) - # Add deletion task to background tasks + # Add deletion task to background tasks with workspace isolation background_tasks.add_task( - background_delete_documents, - rag, - doc_manager, + _run_delete_with_workspace, + workspace_mgr, + workspace, doc_ids, delete_request.delete_file, delete_request.delete_llm_cache, + doc_manager.base_input_dir, + doc_manager.supported_extensions, ) return DeleteDocByIdResponse( @@ -2922,18 +3069,15 @@ async def delete_document( doc_id=", ".join(doc_ids), ) - except Exception as e: - error_msg = f"Error initiating document deletion for {delete_request.doc_ids}: {str(e)}" - logger.error(error_msg) - logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=error_msg) + finally: + workspace_mgr.release(workspace) @router.post( "/clear_cache", response_model=ClearCacheResponse, dependencies=[Depends(combined_auth)], ) - async def clear_cache(request: ClearCacheRequest): + async def clear_cache(http_request: Request, request: ClearCacheRequest): """ Clear all cache data from the LLM response cache storage. @@ -2941,6 +3085,7 @@ async def clear_cache(request: ClearCacheRequest): The request body is accepted for API compatibility but is ignored. Args: + http_request: FastAPI Request object for workspace extraction request (ClearCacheRequest): The request body (ignored for compatibility). Returns: @@ -2949,6 +3094,8 @@ async def clear_cache(request: ClearCacheRequest): Raises: HTTPException: If an error occurs during cache clearing (500). """ + workspace = extract_and_register_workspace(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: # Call the aclear_cache method (no modes parameter) await rag.aclear_cache() @@ -2957,21 +3104,20 @@ async def clear_cache(request: ClearCacheRequest): message = "Successfully cleared all cache" return ClearCacheResponse(status="success", message=message) - except Exception as e: - logger.error(f"Error clearing cache: {str(e)}") - logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) + finally: + workspace_mgr.release(workspace) @router.delete( "/delete_entity", response_model=DeletionResult, dependencies=[Depends(combined_auth)], ) - async def delete_entity(request: DeleteEntityRequest): + async def delete_entity(http_request: Request, request: DeleteEntityRequest): """ Delete an entity and all its relationships from the knowledge graph. Args: + http_request: FastAPI Request object for workspace extraction request (DeleteEntityRequest): The request body containing the entity name. Returns: @@ -2980,6 +3126,8 @@ async def delete_entity(request: DeleteEntityRequest): Raises: HTTPException: If the entity is not found (404) or an error occurs (500). """ + workspace = extract_and_register_workspace(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: result = await rag.adelete_by_entity(entity_name=request.entity_name) if result.status == "not_found": @@ -2996,17 +3144,20 @@ async def delete_entity(request: DeleteEntityRequest): logger.error(error_msg) logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=error_msg) + finally: + workspace_mgr.release(workspace) @router.delete( "/delete_relation", response_model=DeletionResult, dependencies=[Depends(combined_auth)], ) - async def delete_relation(request: DeleteRelationRequest): + async def delete_relation(http_request: Request, request: DeleteRelationRequest): """ Delete a relationship between two entities from the knowledge graph. Args: + http_request: FastAPI Request object for workspace extraction request (DeleteRelationRequest): The request body containing the source and target entity names. Returns: @@ -3015,6 +3166,8 @@ async def delete_relation(request: DeleteRelationRequest): Raises: HTTPException: If the relation is not found (404) or an error occurs (500). """ + workspace = extract_and_register_workspace(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: result = await rag.adelete_by_relation( source_entity=request.source_entity, @@ -3034,13 +3187,17 @@ async def delete_relation(request: DeleteRelationRequest): logger.error(error_msg) logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=error_msg) + finally: + workspace_mgr.release(workspace) @router.get( "/track_status/{track_id}", response_model=TrackStatusResponse, dependencies=[Depends(combined_auth)], ) - async def get_track_status(track_id: str) -> TrackStatusResponse: + async def get_track_status( + http_request: Request, track_id: str + ) -> TrackStatusResponse: """ Get the processing status of documents by tracking ID. @@ -3048,6 +3205,7 @@ async def get_track_status(track_id: str) -> TrackStatusResponse: allowing users to monitor the processing progress of their uploaded files or inserted texts. Args: + http_request: FastAPI Request object for workspace extraction track_id (str): The tracking ID returned from upload, text, or texts endpoints Returns: @@ -3059,6 +3217,8 @@ async def get_track_status(track_id: str) -> TrackStatusResponse: Raises: HTTPException: If track_id is invalid (400) or an error occurs (500). """ + workspace = extract_and_register_workspace(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: # Validate track_id if not track_id or not track_id.strip(): @@ -3108,6 +3268,8 @@ async def get_track_status(track_id: str) -> TrackStatusResponse: logger.error(f"Error getting track status for {track_id}: {str(e)}") logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) + finally: + workspace_mgr.release(workspace) @router.post( "/paginated", @@ -3115,6 +3277,7 @@ async def get_track_status(track_id: str) -> TrackStatusResponse: dependencies=[Depends(combined_auth)], ) async def get_documents_paginated( + http_request: Request, request: DocumentsRequest, ) -> PaginatedDocsResponse: """ @@ -3125,6 +3288,7 @@ async def get_documents_paginated( requested page of data. Args: + http_request: FastAPI Request object for workspace extraction request (DocumentsRequest): The request body containing pagination parameters Returns: @@ -3136,24 +3300,27 @@ async def get_documents_paginated( Raises: HTTPException: If an error occurs while retrieving documents (500). """ - trace_id = uuid4().hex[:8] - request_start = time.perf_counter() - status_filter_value = ( - request.status_filter.value if request.status_filter is not None else None - ) - - performance_timing_log( - "[documents/paginated][%s] Request start workspace=%s status_filter=%s page=%s page_size=%s sort_field=%s sort_direction=%s", - trace_id, - rag.workspace, - status_filter_value, - request.page, - request.page_size, - request.sort_field, - request.sort_direction, - ) - + workspace = extract_and_register_workspace(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: + trace_id = uuid4().hex[:8] + request_start = time.perf_counter() + status_filter_value = ( + request.status_filter.value + if request.status_filter is not None + else None + ) + + performance_timing_log( + "[documents/paginated][%s] Request start workspace=%s status_filter=%s page=%s page_size=%s sort_field=%s sort_direction=%s", + trace_id, + rag.workspace, + status_filter_value, + request.page, + request.page_size, + request.sort_field, + request.sort_direction, + ) async def _timed_call(operation_name: str, operation): operation_start = time.perf_counter() @@ -3278,49 +3445,56 @@ async def _timed_call(operation_name: str, operation): return response except Exception as e: - total_elapsed = time.perf_counter() - request_start - performance_timing_log( - "[documents/paginated][%s] Request failed after %.4fs", - trace_id, - total_elapsed, - ) logger.error(f"Error getting paginated documents: {str(e)}") logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) + finally: + workspace_mgr.release(workspace) @router.get( "/status_counts", response_model=StatusCountsResponse, dependencies=[Depends(combined_auth)], ) - async def get_document_status_counts() -> StatusCountsResponse: + async def get_document_status_counts(http_request: Request) -> StatusCountsResponse: """ Get counts of documents by status. This endpoint retrieves the count of documents in each processing status (PENDING, PROCESSING, PROCESSED, FAILED) for all documents in the system. + Args: + http_request: FastAPI Request object for workspace extraction + Returns: StatusCountsResponse: A response object containing status counts Raises: HTTPException: If an error occurs while retrieving status counts (500). """ + workspace = extract_and_register_workspace(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: status_counts = await rag.doc_status.get_all_status_counts() return StatusCountsResponse(status_counts=status_counts) + finally: + workspace_mgr.release(workspace) - except Exception as e: - logger.error(f"Error getting document status counts: {str(e)}") - logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) + async def _run_reprocess_with_workspace(workspace_mgr, workspace): + rag = await workspace_mgr.get_or_create(workspace) + try: + await rag.apipeline_process_enqueue_documents() + finally: + workspace_mgr.release(workspace) @router.post( "/reprocess_failed", response_model=ReprocessResponse, dependencies=[Depends(combined_auth)], ) - async def reprocess_failed_documents(background_tasks: BackgroundTasks): + async def reprocess_failed_documents( + http_request: Request, background_tasks: BackgroundTasks + ): """ Reprocess failed and pending documents. @@ -3337,6 +3511,10 @@ async def reprocess_failed_documents(background_tasks: BackgroundTasks): pipeline status. The reprocessed documents retain their original track_id from initial upload, so use their original track_id to monitor progress. + Args: + http_request: FastAPI Request object for workspace extraction + background_tasks: FastAPI BackgroundTasks for async processing + Returns: ReprocessResponse: Response with status and message. track_id is always empty string because reprocessed documents retain @@ -3345,28 +3523,25 @@ async def reprocess_failed_documents(background_tasks: BackgroundTasks): Raises: HTTPException: If an error occurs while initiating reprocessing (500). """ - try: - # Start the reprocessing in the background - # Note: Reprocessed documents retain their original track_id from initial upload - background_tasks.add_task(rag.apipeline_process_enqueue_documents) - logger.info("Reprocessing of failed documents initiated") - - return ReprocessResponse( - status="reprocessing_started", - message="Reprocessing of failed documents has been initiated in background. Documents retain their original track_id.", - ) + workspace = extract_and_register_workspace(http_request) + # Start the reprocessing in the background + # Note: Reprocessed documents retain their original track_id from initial upload + background_tasks.add_task( + _run_reprocess_with_workspace, workspace_mgr, workspace + ) + logger.info("Reprocessing of failed documents initiated") - except Exception as e: - logger.error(f"Error initiating reprocessing of failed documents: {str(e)}") - logger.error(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) + return ReprocessResponse( + status="reprocessing_started", + message="Reprocessing of failed documents has been initiated in background. Documents retain their original track_id.", + ) @router.post( "/cancel_pipeline", response_model=CancelPipelineResponse, dependencies=[Depends(combined_auth)], ) - async def cancel_pipeline(): + async def cancel_pipeline(http_request: Request): """ Request cancellation of the currently running pipeline. @@ -3379,6 +3554,9 @@ async def cancel_pipeline(): The cancellation is graceful and ensures data consistency. Documents that have completed processing will remain in PROCESSED status. + Args: + http_request: FastAPI Request object for workspace extraction + Returns: CancelPipelineResponse: Response with status and message - status="cancellation_requested": Cancellation flag has been set @@ -3387,6 +3565,8 @@ async def cancel_pipeline(): Raises: HTTPException: If an error occurs while setting cancellation flag (500). """ + workspace = extract_and_register_workspace(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: from lightrag.kg.shared_storage import ( get_namespace_data, @@ -3423,5 +3603,7 @@ async def cancel_pipeline(): logger.error(f"Error requesting pipeline cancellation: {str(e)}") logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) + finally: + workspace_mgr.release(workspace) return router diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py index e892ff011c..6f7e3746b5 100644 --- a/lightrag/api/routers/graph_routes.py +++ b/lightrag/api/routers/graph_routes.py @@ -4,10 +4,11 @@ from typing import Optional, Dict, Any import traceback -from fastapi import APIRouter, Depends, Query, HTTPException +from fastapi import APIRouter, Depends, Query, HTTPException, Request from pydantic import BaseModel, Field from lightrag.utils import logger +from lightrag.api.utils import extract_workspace_from_header from ..utils_api import get_combined_auth_dependency router = APIRouter(tags=["graph"]) @@ -86,17 +87,19 @@ class RelationCreateRequest(BaseModel): ) -def create_graph_routes(rag, api_key: Optional[str] = None): +def create_graph_routes(workspace_mgr, api_key: Optional[str] = None): combined_auth = get_combined_auth_dependency(api_key) @router.get("/graph/label/list", dependencies=[Depends(combined_auth)]) - async def get_graph_labels(): + async def get_graph_labels(http_request: Request): """ Get all graph labels Returns: List[str]: List of graph labels """ + workspace = extract_workspace_from_header(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: return await rag.get_graph_labels() except Exception as e: @@ -105,9 +108,12 @@ async def get_graph_labels(): raise HTTPException( status_code=500, detail=f"Error getting graph labels: {str(e)}" ) + finally: + workspace_mgr.release(workspace) @router.get("/graph/label/popular", dependencies=[Depends(combined_auth)]) async def get_popular_labels( + http_request: Request, limit: int = Query( 300, description="Maximum number of popular labels to return", ge=1, le=1000 ), @@ -121,6 +127,8 @@ async def get_popular_labels( Returns: List[str]: List of popular labels sorted by degree (highest first) """ + workspace = extract_workspace_from_header(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: return await rag.chunk_entity_relation_graph.get_popular_labels(limit) except Exception as e: @@ -129,9 +137,12 @@ async def get_popular_labels( raise HTTPException( status_code=500, detail=f"Error getting popular labels: {str(e)}" ) + finally: + workspace_mgr.release(workspace) @router.get("/graph/label/search", dependencies=[Depends(combined_auth)]) async def search_labels( + http_request: Request, q: str = Query(..., description="Search query string"), limit: int = Query( 50, description="Maximum number of search results to return", ge=1, le=100 @@ -147,6 +158,8 @@ async def search_labels( Returns: List[str]: List of matching labels sorted by relevance """ + workspace = extract_workspace_from_header(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: return await rag.chunk_entity_relation_graph.search_labels(q, limit) except Exception as e: @@ -155,9 +168,12 @@ async def search_labels( raise HTTPException( status_code=500, detail=f"Error searching labels: {str(e)}" ) + finally: + workspace_mgr.release(workspace) @router.get("/graphs", dependencies=[Depends(combined_auth)]) async def get_knowledge_graph( + http_request: Request, label: str = Query(..., description="Label to get knowledge graph for"), max_depth: int = Query(3, description="Maximum depth of graph", ge=1), max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1), @@ -176,6 +192,8 @@ async def get_knowledge_graph( Returns: Dict[str, List[str]]: Knowledge graph for label """ + workspace = extract_workspace_from_header(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: # Log the label parameter to check for leading spaces logger.debug( @@ -193,9 +211,12 @@ async def get_knowledge_graph( raise HTTPException( status_code=500, detail=f"Error getting knowledge graph: {str(e)}" ) + finally: + workspace_mgr.release(workspace) @router.get("/graph/entity/exists", dependencies=[Depends(combined_auth)]) async def check_entity_exists( + http_request: Request, name: str = Query(..., description="Entity name to check"), ): """ @@ -207,6 +228,8 @@ async def check_entity_exists( Returns: Dict[str, bool]: Dictionary with 'exists' key indicating if entity exists """ + workspace = extract_workspace_from_header(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: exists = await rag.chunk_entity_relation_graph.has_node(name) return {"exists": exists} @@ -216,9 +239,11 @@ async def check_entity_exists( raise HTTPException( status_code=500, detail=f"Error checking entity existence: {str(e)}" ) + finally: + workspace_mgr.release(workspace) @router.post("/graph/entity/edit", dependencies=[Depends(combined_auth)]) - async def update_entity(request: EntityUpdateRequest): + async def update_entity(http_request: Request, request: EntityUpdateRequest): """ Update an entity's properties in the knowledge graph @@ -226,7 +251,7 @@ async def update_entity(request: EntityUpdateRequest): When renaming to an existing entity name, the behavior depends on allow_merge: Args: - request (EntityUpdateRequest): Request containing: + request: Request containing: - entity_name (str): Name of the entity to update - updated_data (Dict[str, Any]): Dictionary of properties to update - allow_rename (bool): Whether to allow entity renaming (default: False) @@ -352,6 +377,8 @@ async def update_entity(request: EntityUpdateRequest): } } """ + workspace = extract_workspace_from_header(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: result = await rag.aedit_entity( entity_name=request.entity_name, @@ -406,17 +433,21 @@ async def update_entity(request: EntityUpdateRequest): raise HTTPException( status_code=500, detail=f"Error updating entity: {str(e)}" ) + finally: + workspace_mgr.release(workspace) @router.post("/graph/relation/edit", dependencies=[Depends(combined_auth)]) - async def update_relation(request: RelationUpdateRequest): + async def update_relation(http_request: Request, request: RelationUpdateRequest): """Update a relation's properties in the knowledge graph Args: - request (RelationUpdateRequest): Request containing source ID, target ID and updated data + request: Request containing source ID, target ID and updated data Returns: Dict: Updated relation information """ + workspace = extract_workspace_from_header(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: result = await rag.aedit_relation( source_entity=request.source_id, @@ -441,9 +472,11 @@ async def update_relation(request: RelationUpdateRequest): raise HTTPException( status_code=500, detail=f"Error updating relation: {str(e)}" ) + finally: + workspace_mgr.release(workspace) @router.post("/graph/entity/create", dependencies=[Depends(combined_auth)]) - async def create_entity(request: EntityCreateRequest): + async def create_entity(http_request: Request, request: EntityCreateRequest): """ Create a new entity in the knowledge graph @@ -487,6 +520,8 @@ async def create_entity(request: EntityCreateRequest): } } """ + workspace = extract_workspace_from_header(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: # Use the proper acreate_entity method which handles: # - Graph lock for concurrency @@ -514,9 +549,11 @@ async def create_entity(request: EntityCreateRequest): raise HTTPException( status_code=500, detail=f"Error creating entity: {str(e)}" ) + finally: + workspace_mgr.release(workspace) @router.post("/graph/relation/create", dependencies=[Depends(combined_auth)]) - async def create_relation(request: RelationCreateRequest): + async def create_relation(http_request: Request, request: RelationCreateRequest): """ Create a new relationship between two entities in the knowledge graph @@ -572,6 +609,8 @@ async def create_relation(request: RelationCreateRequest): } } """ + workspace = extract_workspace_from_header(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: # Use the proper acreate_relation method which handles: # - Graph lock for concurrency @@ -603,9 +642,11 @@ async def create_relation(request: RelationCreateRequest): raise HTTPException( status_code=500, detail=f"Error creating relation: {str(e)}" ) + finally: + workspace_mgr.release(workspace) @router.post("/graph/entities/merge", dependencies=[Depends(combined_auth)]) - async def merge_entities(request: EntityMergeRequest): + async def merge_entities(http_request: Request, request: EntityMergeRequest): """ Merge multiple entities into a single entity, preserving all relationships @@ -661,6 +702,8 @@ async def merge_entities(request: EntityMergeRequest): - Source entities will be permanently deleted after the merge - This operation cannot be undone, so verify entity names before merging """ + workspace = extract_workspace_from_header(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: result = await rag.amerge_entities( source_entities=request.entities_to_change, @@ -684,5 +727,7 @@ async def merge_entities(request: EntityMergeRequest): raise HTTPException( status_code=500, detail=f"Error merging entities: {str(e)}" ) + finally: + workspace_mgr.release(workspace) return router diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py index 15c695cee7..ca56336a91 100644 --- a/lightrag/api/routers/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -1,17 +1,18 @@ from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel +from starlette.responses import StreamingResponse from typing import List, Dict, Any, Optional, Type from lightrag.utils import logger import time import json import re from enum import Enum -from fastapi.responses import StreamingResponse import asyncio -from lightrag import LightRAG, QueryParam +from lightrag import QueryParam from lightrag.utils import TiktokenTokenizer from lightrag.api.utils_api import get_combined_auth_dependency from fastapi import Depends +from lightrag.api.utils import extract_workspace_from_header # query mode according to query prefix (bypass is not LightRAG quer mode) @@ -218,9 +219,8 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]: class OllamaAPI: - def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None): - self.rag = rag - self.ollama_server_infos = rag.ollama_server_infos + def __init__(self, workspace_mgr, top_k: int = 60, api_key: Optional[str] = None): + self.workspace_mgr = workspace_mgr self.top_k = top_k self.api_key = api_key self.router = APIRouter(tags=["ollama"]) @@ -231,56 +231,75 @@ def setup_routes(self): combined_auth = get_combined_auth_dependency(self.api_key) @self.router.get("/version", dependencies=[Depends(combined_auth)]) - async def get_version(): + async def get_version(http_request: Request): """Get Ollama version information""" - return OllamaVersionResponse(version="0.9.3") + workspace = extract_workspace_from_header(http_request) + rag = await self.workspace_mgr.get_or_create(workspace) + try: + return OllamaVersionResponse( + version=rag.ollama_server_infos.get("version", "unknown") + ) + finally: + self.workspace_mgr.release(workspace) @self.router.get("/tags", dependencies=[Depends(combined_auth)]) - async def get_tags(): + async def get_tags(http_request: Request): """Return available models acting as an Ollama server""" - return OllamaTagResponse( - models=[ - { - "name": self.ollama_server_infos.LIGHTRAG_MODEL, - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "size": self.ollama_server_infos.LIGHTRAG_SIZE, - "digest": self.ollama_server_infos.LIGHTRAG_DIGEST, - "details": { - "parent_model": "", - "format": "gguf", - "family": self.ollama_server_infos.LIGHTRAG_NAME, - "families": [self.ollama_server_infos.LIGHTRAG_NAME], - "parameter_size": "13B", - "quantization_level": "Q4_0", - }, - } - ] - ) + workspace = extract_workspace_from_header(http_request) + rag = await self.workspace_mgr.get_or_create(workspace) + try: + infos = rag.ollama_server_infos + return OllamaTagResponse( + models=[ + { + "name": infos.LIGHTRAG_MODEL, + "model": infos.LIGHTRAG_MODEL, + "modified_at": infos.LIGHTRAG_CREATED_AT, + "size": infos.LIGHTRAG_SIZE, + "digest": infos.LIGHTRAG_DIGEST, + "details": { + "parent_model": "", + "format": "gguf", + "family": infos.LIGHTRAG_NAME, + "families": [infos.LIGHTRAG_NAME], + "parameter_size": "13B", + "quantization_level": "Q4_0", + }, + } + ] + ) + finally: + self.workspace_mgr.release(workspace) @self.router.get("/ps", dependencies=[Depends(combined_auth)]) - async def get_running_models(): + async def get_running_models(http_request: Request): """List Running Models - returns currently running models""" - return OllamaPsResponse( - models=[ - { - "name": self.ollama_server_infos.LIGHTRAG_MODEL, - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "size": self.ollama_server_infos.LIGHTRAG_SIZE, - "digest": self.ollama_server_infos.LIGHTRAG_DIGEST, - "details": { - "parent_model": "", - "format": "gguf", - "family": "llama", - "families": ["llama"], - "parameter_size": "7.2B", - "quantization_level": "Q4_0", - }, - "expires_at": "2050-12-31T14:38:31.83753-07:00", - "size_vram": self.ollama_server_infos.LIGHTRAG_SIZE, - } - ] - ) + workspace = extract_workspace_from_header(http_request) + rag = await self.workspace_mgr.get_or_create(workspace) + try: + infos = rag.ollama_server_infos + return OllamaPsResponse( + models=[ + { + "name": infos.LIGHTRAG_MODEL, + "model": infos.LIGHTRAG_MODEL, + "size": infos.LIGHTRAG_SIZE, + "digest": infos.LIGHTRAG_DIGEST, + "details": { + "parent_model": "", + "format": "gguf", + "family": "llama", + "families": ["llama"], + "parameter_size": "7.2B", + "quantization_level": "Q4_0", + }, + "expires_at": "2050-12-31T14:38:31.83753-07:00", + "size_vram": infos.LIGHTRAG_SIZE, + } + ] + ) + finally: + self.workspace_mgr.release(workspace) @self.router.post( "/generate", dependencies=[Depends(combined_auth)], include_in_schema=True @@ -291,173 +310,82 @@ async def generate(raw_request: Request): and will be handled by underlying LLM model. Supports both application/json and application/octet-stream Content-Types. """ - try: - # Parse the request body manually - request = await parse_request_body(raw_request, OllamaGenerateRequest) - - query = request.prompt - start_time = time.time_ns() - prompt_tokens = estimate_tokens(query) - - if request.system: - self.rag.llm_model_kwargs["system_prompt"] = request.system - - if request.stream: - response = await self.rag.llm_model_func( - query, stream=True, **self.rag.llm_model_kwargs - ) - - async def stream_generator(): - first_chunk_time = None - last_chunk_time = time.time_ns() - total_response = "" + # Parse the request body manually + request = await parse_request_body(raw_request, OllamaGenerateRequest) - # Ensure response is an async generator - if isinstance(response, str): - # If it's a string, send in two parts - first_chunk_time = start_time - last_chunk_time = time.time_ns() - total_response = response - - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": response, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" + query = request.prompt + start_time = time.time_ns() + prompt_tokens = estimate_tokens(query) - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time + workspace = extract_workspace_from_header(raw_request) + rag = await self.workspace_mgr.get_or_create(workspace) - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": "", - "done": True, - "done_reason": "stop", - "context": [], - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - else: - try: - async for chunk in response: - if chunk: - if first_chunk_time is None: - first_chunk_time = time.time_ns() - - last_chunk_time = time.time_ns() - - total_response += chunk - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": chunk, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - except (asyncio.CancelledError, Exception) as e: - error_msg = str(e) - if isinstance(e, asyncio.CancelledError): - error_msg = "Stream was cancelled by server" - else: - error_msg = f"Provider error: {error_msg}" - - logger.error(f"Stream error: {error_msg}") - - # Send error message to client - error_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": f"\n\nError: {error_msg}", - "error": f"\n\nError: {error_msg}", - "done": False, - } - yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - - # Send final message to close the stream - final_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": "", - "done": True, - } - yield f"{json.dumps(final_data, ensure_ascii=False)}\n" - return - if first_chunk_time is None: - first_chunk_time = start_time - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": "", - "done": True, - "done_reason": "stop", - "context": [], - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - return + try: + infos = rag.ollama_server_infos + if request.stream: + # Generator's finally handles release return StreamingResponse( - stream_generator(), + self._stream_generate( + workspace, + rag, + infos, + request, + query, + start_time, + prompt_tokens, + ), media_type="application/x-ndjson", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "Content-Type": "application/x-ndjson", - "X-Accel-Buffering": "no", # Ensure proper handling of streaming responses in Nginx proxy + "X-Accel-Buffering": "no", }, ) else: - first_chunk_time = time.time_ns() - response_text = await self.rag.llm_model_func( - query, stream=False, **self.rag.llm_model_kwargs - ) - last_chunk_time = time.time_ns() - - if not response_text: - response_text = "No response generated" + try: + first_chunk_time = time.time_ns() + if request.system: + kwargs = { + **rag.llm_model_kwargs, + "system_prompt": request.system, + } + else: + kwargs = rag.llm_model_kwargs + response_text = await rag.llm_model_func( + query, stream=False, **kwargs + ) + last_chunk_time = time.time_ns() - completion_tokens = estimate_tokens(str(response_text)) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time - - return { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "response": str(response_text), - "done": True, - "done_reason": "stop", - "context": [], - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - except Exception as e: - logger.error(f"Ollama generate error: {str(e)}", exc_info=True) - raise HTTPException(status_code=500, detail=str(e)) + if not response_text: + response_text = "No response generated" + + completion_tokens = estimate_tokens(str(response_text)) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + return { + "model": infos.LIGHTRAG_MODEL, + "created_at": infos.LIGHTRAG_CREATED_AT, + "response": str(response_text), + "done": True, + "done_reason": "stop", + "context": [], + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + finally: + self.workspace_mgr.release(workspace) + except Exception: + # If we get here, streaming generator never started β€” we must release + self.workspace_mgr.release(workspace) + raise @self.router.post( "/chat", dependencies=[Depends(combined_auth)], include_in_schema=True @@ -468,256 +396,426 @@ async def chat(raw_request: Request): Detects and forwards OpenWebUI session-related requests (for meta data generation task) directly to LLM. Supports both application/json and application/octet-stream Content-Types. """ - try: - # Parse the request body manually - request = await parse_request_body(raw_request, OllamaChatRequest) - - # Get all messages - messages = request.messages - if not messages: - raise HTTPException(status_code=400, detail="No messages provided") - - # Validate that the last message is from a user - if messages[-1].role != "user": - raise HTTPException( - status_code=400, detail="Last message must be from user role" - ) + # Parse the request body manually + request = await parse_request_body(raw_request, OllamaChatRequest) + + # Get all messages + messages = request.messages + if not messages: + raise HTTPException(status_code=400, detail="No messages provided") + + # Validate that the last message is from a user + if messages[-1].role != "user": + raise HTTPException( + status_code=400, detail="Last message must be from user role" + ) - # Get the last message as query and previous messages as history - query = messages[-1].content - # Convert OllamaMessage objects to dictionaries - conversation_history = [ - {"role": msg.role, "content": msg.content} for msg in messages[:-1] - ] + # Get the last message as query and previous messages as history + query = messages[-1].content + # Convert OllamaMessage objects to dictionaries + conversation_history = [ + {"role": msg.role, "content": msg.content} for msg in messages[:-1] + ] - # Check for query prefix - cleaned_query, mode, only_need_context, user_prompt = parse_query_mode( - query - ) + # Check for query prefix + cleaned_query, mode, only_need_context, user_prompt = parse_query_mode( + query + ) - start_time = time.time_ns() - prompt_tokens = estimate_tokens(cleaned_query) + start_time = time.time_ns() + prompt_tokens = estimate_tokens(cleaned_query) - param_dict = { - "mode": mode.value, - "stream": request.stream, - "only_need_context": only_need_context, - "conversation_history": conversation_history, - "top_k": self.top_k, - } + param_dict = { + "mode": mode.value, + "stream": request.stream, + "only_need_context": only_need_context, + "conversation_history": conversation_history, + "top_k": self.top_k, + } + + # Add user_prompt to param_dict + if user_prompt is not None: + param_dict["user_prompt"] = user_prompt - # Add user_prompt to param_dict - if user_prompt is not None: - param_dict["user_prompt"] = user_prompt + query_param = QueryParam(**param_dict) - query_param = QueryParam(**param_dict) + workspace = extract_workspace_from_header(raw_request) + rag = await self.workspace_mgr.get_or_create(workspace) + + try: + infos = rag.ollama_server_infos if request.stream: # Determine if the request is prefix with "/bypass" if mode == SearchMode.bypass: - if request.system: - self.rag.llm_model_kwargs["system_prompt"] = request.system - response = await self.rag.llm_model_func( + system_prompt = request.system + response = await rag.llm_model_func( cleaned_query, stream=True, history_messages=conversation_history, - **self.rag.llm_model_kwargs, + **{**rag.llm_model_kwargs, "system_prompt": system_prompt} + if system_prompt + else rag.llm_model_kwargs, ) else: - response = await self.rag.aquery( - cleaned_query, param=query_param + response = await rag.aquery(cleaned_query, param=query_param) + + # Generator's finally handles release + return StreamingResponse( + self._stream_chat( + workspace, + rag, + infos, + response, + request, + mode, + cleaned_query, + conversation_history, + start_time, + prompt_tokens, + query_param, + ), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "application/x-ndjson", + "X-Accel-Buffering": "no", + }, + ) + else: + try: + first_chunk_time = time.time_ns() + + # Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task + match_result = re.search( + r"\n\nUSER:", cleaned_query, re.MULTILINE ) + if match_result or mode == SearchMode.bypass: + if request.system: + kwargs = { + **rag.llm_model_kwargs, + "system_prompt": request.system, + } + else: + kwargs = rag.llm_model_kwargs + response_text = await rag.llm_model_func( + cleaned_query, + stream=False, + history_messages=conversation_history, + **kwargs, + ) + else: + response_text = await rag.aquery( + cleaned_query, param=query_param + ) - async def stream_generator(): - first_chunk_time = None last_chunk_time = time.time_ns() - total_response = "" - # Ensure response is an async generator - if isinstance(response, str): - # If it's a string, send in two parts - first_chunk_time = start_time + if not response_text: + response_text = "No response generated" + + completion_tokens = estimate_tokens(str(response_text)) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + return { + "model": infos.LIGHTRAG_MODEL, + "created_at": infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": str(response_text), + "images": None, + }, + "done_reason": "stop", + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + finally: + self.workspace_mgr.release(workspace) + except Exception: + # If we get here, streaming generator never started β€” we must release + self.workspace_mgr.release(workspace) + raise + + async def _stream_generate( + self, + workspace: str, + rag, + infos, + request: "OllamaGenerateRequest", + query: str, + start_time: int, + prompt_tokens: int, + ): + """Streaming generator for the generate endpoint.""" + try: + first_chunk_time = None + last_chunk_time = time.time_ns() + total_response = "" + + if request.system: + kwargs = {**rag.llm_model_kwargs, "system_prompt": request.system} + else: + kwargs = rag.llm_model_kwargs + response = await rag.llm_model_func(query, stream=True, **kwargs) + + # Ensure response is an async generator + if isinstance(response, str): + # If it's a string, send in two parts + first_chunk_time = start_time + last_chunk_time = time.time_ns() + total_response = response + + data = { + "model": infos.LIGHTRAG_MODEL, + "created_at": infos.LIGHTRAG_CREATED_AT, + "response": response, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": infos.LIGHTRAG_MODEL, + "created_at": infos.LIGHTRAG_CREATED_AT, + "response": "", + "done": True, + "done_reason": "stop", + "context": [], + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + else: + try: + async for chunk in response: + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time_ns() + last_chunk_time = time.time_ns() - total_response = response + total_response += chunk data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": response, - "images": None, - }, + "model": infos.LIGHTRAG_MODEL, + "created_at": infos.LIGHTRAG_CREATED_AT, + "response": chunk, "done": False, } yield f"{json.dumps(data, ensure_ascii=False)}\n" + except (asyncio.CancelledError, Exception) as e: + error_msg = str(e) + if isinstance(e, asyncio.CancelledError): + error_msg = "Stream was cancelled by server" + else: + error_msg = f"Provider error: {error_msg}" - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time + logger.error(f"Stream error: {error_msg}") - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": "", - "images": None, - }, - "done_reason": "stop", - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - else: - try: - async for chunk in response: - if chunk: - if first_chunk_time is None: - first_chunk_time = time.time_ns() - - last_chunk_time = time.time_ns() - - total_response += chunk - data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": chunk, - "images": None, - }, - "done": False, - } - yield f"{json.dumps(data, ensure_ascii=False)}\n" - except (asyncio.CancelledError, Exception) as e: - error_msg = str(e) - if isinstance(e, asyncio.CancelledError): - error_msg = "Stream was cancelled by server" - else: - error_msg = f"Provider error: {error_msg}" - - logger.error(f"Stream error: {error_msg}") - - # Send error message to client - error_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": f"\n\nError: {error_msg}", - "images": None, - }, - "error": f"\n\nError: {error_msg}", - "done": False, - } - yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - - # Send final message to close the stream - final_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "message": { - "role": "assistant", - "content": "", - "images": None, - }, - "done": True, - } - yield f"{json.dumps(final_data, ensure_ascii=False)}\n" - return + # Send error message to client + error_data = { + "model": infos.LIGHTRAG_MODEL, + "created_at": infos.LIGHTRAG_CREATED_AT, + "response": f"\n\nError: {error_msg}", + "error": f"\n\nError: {error_msg}", + "done": False, + } + yield f"{json.dumps(error_data, ensure_ascii=False)}\n" + # Send final message to close the stream + final_data = { + "model": infos.LIGHTRAG_MODEL, + "created_at": infos.LIGHTRAG_CREATED_AT, + "response": "", + "done": True, + } + yield f"{json.dumps(final_data, ensure_ascii=False)}\n" + return + + if first_chunk_time is None: + first_chunk_time = start_time + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": infos.LIGHTRAG_MODEL, + "created_at": infos.LIGHTRAG_CREATED_AT, + "response": "", + "done": True, + "done_reason": "stop", + "context": [], + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + finally: + self.workspace_mgr.release(workspace) + + async def _stream_chat( + self, + workspace: str, + rag, + infos, + response, + request: "OllamaChatRequest", + mode: "SearchMode", + cleaned_query: str, + conversation_history: List[Dict[str, str]], + start_time: int, + prompt_tokens: int, + query_param: "QueryParam", + ): + """Streaming generator for the chat endpoint.""" + try: + first_chunk_time = None + last_chunk_time = time.time_ns() + total_response = "" + + # Ensure response is an async generator + if isinstance(response, str): + # If it's a string, send in two parts + first_chunk_time = start_time + last_chunk_time = time.time_ns() + total_response = response + + data = { + "model": infos.LIGHTRAG_MODEL, + "created_at": infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": response, + "images": None, + }, + "done": False, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": infos.LIGHTRAG_MODEL, + "created_at": infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": "", + "images": None, + }, + "done_reason": "stop", + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + else: + try: + async for chunk in response: + if chunk: if first_chunk_time is None: - first_chunk_time = start_time - completion_tokens = estimate_tokens(total_response) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time + first_chunk_time = time.time_ns() + + last_chunk_time = time.time_ns() + total_response += chunk data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": infos.LIGHTRAG_MODEL, + "created_at": infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", - "content": "", + "content": chunk, "images": None, }, - "done_reason": "stop", - "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, + "done": False, } yield f"{json.dumps(data, ensure_ascii=False)}\n" - - return StreamingResponse( - stream_generator(), - media_type="application/x-ndjson", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "application/x-ndjson", - "X-Accel-Buffering": "no", # Ensure proper handling of streaming responses in Nginx proxy - }, - ) - else: - first_chunk_time = time.time_ns() - - # Determine if the request is prefix with "/bypass" or from Open WebUI's session title and session keyword generation task - match_result = re.search( - r"\n\nUSER:", cleaned_query, re.MULTILINE - ) - if match_result or mode == SearchMode.bypass: - if request.system: - self.rag.llm_model_kwargs["system_prompt"] = request.system - - response_text = await self.rag.llm_model_func( - cleaned_query, - stream=False, - history_messages=conversation_history, - **self.rag.llm_model_kwargs, - ) + except (asyncio.CancelledError, Exception) as e: + error_msg = str(e) + if isinstance(e, asyncio.CancelledError): + error_msg = "Stream was cancelled by server" else: - response_text = await self.rag.aquery( - cleaned_query, param=query_param - ) + error_msg = f"Provider error: {error_msg}" - last_chunk_time = time.time_ns() + logger.error(f"Stream error: {error_msg}") - if not response_text: - response_text = "No response generated" - - completion_tokens = estimate_tokens(str(response_text)) - total_time = last_chunk_time - start_time - prompt_eval_time = first_chunk_time - start_time - eval_time = last_chunk_time - first_chunk_time + # Send error message to client + error_data = { + "model": infos.LIGHTRAG_MODEL, + "created_at": infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": f"\n\nError: {error_msg}", + "images": None, + }, + "error": f"\n\nError: {error_msg}", + "done": False, + } + yield f"{json.dumps(error_data, ensure_ascii=False)}\n" - return { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + # Send final message to close the stream + final_data = { + "model": infos.LIGHTRAG_MODEL, + "created_at": infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", - "content": str(response_text), + "content": "", "images": None, }, - "done_reason": "stop", "done": True, - "total_duration": total_time, - "load_duration": 0, - "prompt_eval_count": prompt_tokens, - "prompt_eval_duration": prompt_eval_time, - "eval_count": completion_tokens, - "eval_duration": eval_time, } - except Exception as e: - logger.error(f"Ollama chat error: {str(e)}", exc_info=True) - raise HTTPException(status_code=500, detail=str(e)) + yield f"{json.dumps(final_data, ensure_ascii=False)}\n" + return + + if first_chunk_time is None: + first_chunk_time = start_time + completion_tokens = estimate_tokens(total_response) + total_time = last_chunk_time - start_time + prompt_eval_time = first_chunk_time - start_time + eval_time = last_chunk_time - first_chunk_time + + data = { + "model": infos.LIGHTRAG_MODEL, + "created_at": infos.LIGHTRAG_CREATED_AT, + "message": { + "role": "assistant", + "content": "", + "images": None, + }, + "done_reason": "stop", + "done": True, + "total_duration": total_time, + "load_duration": 0, + "prompt_eval_count": prompt_tokens, + "prompt_eval_duration": prompt_eval_time, + "eval_count": completion_tokens, + "eval_duration": eval_time, + } + yield f"{json.dumps(data, ensure_ascii=False)}\n" + finally: + self.workspace_mgr.release(workspace) diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 22958158a1..466eb396bf 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -4,8 +4,9 @@ import json from typing import Any, Dict, List, Literal, Optional -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Request from lightrag.base import QueryParam +from lightrag.api.utils import extract_workspace_from_header from lightrag.api.utils_api import get_combined_auth_dependency from lightrag.utils import logger from pydantic import BaseModel, Field, field_validator @@ -190,7 +191,7 @@ class StreamChunkResponse(BaseModel): ) -def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): +def create_query_routes(workspace_mgr, api_key: Optional[str] = None, top_k: int = 60): combined_auth = get_combined_auth_dependency(api_key) @router.post( @@ -322,7 +323,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): }, }, ) - async def query_text(request: QueryRequest): + async def query_text(http_request: Request, request: QueryRequest): """ Comprehensive RAG query endpoint with non-streaming response. Parameter "stream" is ignored. @@ -401,6 +402,8 @@ async def query_text(request: QueryRequest): - 400: Invalid input parameters (e.g., query too short) - 500: Internal processing error (e.g., LLM service unavailable) """ + workspace = extract_workspace_from_header(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: param = request.to_query_params( False @@ -452,6 +455,8 @@ async def query_text(request: QueryRequest): except Exception as e: logger.error(f"Error processing query: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) + finally: + workspace_mgr.release(workspace) @router.post( "/query/stream", @@ -532,7 +537,7 @@ async def query_text(request: QueryRequest): }, }, ) - async def query_text_stream(request: QueryRequest): + async def query_text_stream(http_request: Request, request: QueryRequest): """ Advanced RAG query endpoint with flexible streaming response. @@ -659,46 +664,45 @@ async def query_text_stream(request: QueryRequest): This endpoint is ideal for applications requiring flexible response delivery. Use streaming mode for real-time interfaces and non-streaming for batch processing. """ - try: - # Use the stream parameter from the request, defaulting to True if not specified - stream_mode = request.stream if request.stream is not None else True - param = request.to_query_params(stream_mode) + from fastapi.responses import StreamingResponse - from fastapi.responses import StreamingResponse + workspace = extract_workspace_from_header(http_request) + rag = await workspace_mgr.get_or_create(workspace) - # Unified approach: always use aquery_llm for all cases - result = await rag.aquery_llm(request.query, param=param) + stream_mode = request.stream if request.stream is not None else True + param = request.to_query_params(stream_mode) + + # Call aquery_llm here (before returning StreamingResponse) + result = await rag.aquery_llm(request.query, param=param) - async def stream_generator(): - # Extract references and LLM response from unified result - references = result.get("data", {}).get("references", []) - llm_response = result.get("llm_response", {}) - - # Enrich references with chunk content if requested - if request.include_references and request.include_chunk_content: - data = result.get("data", {}) - chunks = data.get("chunks", []) - # Create a mapping from reference_id to chunk content - ref_id_to_content = {} - for chunk in chunks: - ref_id = chunk.get("reference_id", "") - content = chunk.get("content", "") - if ref_id and content: - # Collect chunk content - ref_id_to_content.setdefault(ref_id, []).append(content) - - # Add content to references - enriched_references = [] - for ref in references: - ref_copy = ref.copy() - ref_id = ref.get("reference_id", "") - if ref_id in ref_id_to_content: - # Keep content as a list of chunks (one file may have multiple chunks) - ref_copy["content"] = ref_id_to_content[ref_id] - enriched_references.append(ref_copy) - references = enriched_references - - if llm_response.get("is_streaming"): + # Extract what we need before returning StreamingResponse + references = result.get("data", {}).get("references", []) + llm_response = result.get("llm_response", {}) + is_streaming = llm_response.get("is_streaming", False) + + # Enrich references with chunk content if requested + if request.include_references and request.include_chunk_content: + data = result.get("data", {}) + chunks = data.get("chunks", []) + ref_id_to_content = {} + for chunk in chunks: + ref_id = chunk.get("reference_id", "") + content = chunk.get("content", "") + if ref_id and content: + ref_id_to_content.setdefault(ref_id, []).append(content) + + enriched_references = [] + for ref in references: + ref_copy = ref.copy() + ref_id = ref.get("reference_id", "") + if ref_id in ref_id_to_content: + ref_copy["content"] = ref_id_to_content[ref_id] + enriched_references.append(ref_copy) + references = enriched_references + + async def stream_generator(): + try: + if is_streaming: # Streaming mode: send references first, then stream response chunks if request.include_references: yield f"{json.dumps({'references': references})}\n" @@ -707,7 +711,7 @@ async def stream_generator(): if response_stream: try: async for chunk in response_stream: - if chunk: # Only send non-empty content + if chunk: yield f"{json.dumps({'response': chunk})}\n" except Exception as e: logger.error(f"Streaming error: {str(e)}") @@ -718,26 +722,24 @@ async def stream_generator(): if not response_content: response_content = "No relevant context found for the query." - # Create complete response object complete_response = {"response": response_content} if request.include_references: complete_response["references"] = references yield f"{json.dumps(complete_response)}\n" - - return StreamingResponse( - stream_generator(), - media_type="application/x-ndjson", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "application/x-ndjson", - "X-Accel-Buffering": "no", # Ensure proper handling of streaming response when proxied by Nginx - }, - ) - except Exception as e: - logger.error(f"Error processing streaming query: {str(e)}", exc_info=True) - raise HTTPException(status_code=500, detail=str(e)) + finally: + workspace_mgr.release(workspace) + + return StreamingResponse( + stream_generator(), + media_type="application/x-ndjson", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "application/x-ndjson", + "X-Accel-Buffering": "no", + }, + ) @router.post( "/query/data", @@ -1035,7 +1037,7 @@ async def stream_generator(): }, }, ) - async def query_data(request: QueryRequest): + async def query_data(http_request: Request, request: QueryRequest): """ Advanced data retrieval endpoint for structured RAG analysis. @@ -1138,6 +1140,8 @@ async def query_data(request: QueryRequest): This endpoint always includes references regardless of the include_references parameter, as structured data analysis typically requires source attribution. """ + workspace = extract_workspace_from_header(http_request) + rag = await workspace_mgr.get_or_create(workspace) try: param = request.to_query_params(False) # No streaming for data endpoint response = await rag.aquery_data(request.query, param=param) @@ -1156,5 +1160,7 @@ async def query_data(request: QueryRequest): except Exception as e: logger.error(f"Error processing data query: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) + finally: + workspace_mgr.release(workspace) return router diff --git a/lightrag/api/routers/workspace_routes.py b/lightrag/api/routers/workspace_routes.py new file mode 100644 index 0000000000..66464e92f9 --- /dev/null +++ b/lightrag/api/routers/workspace_routes.py @@ -0,0 +1,117 @@ +"""Workspace-related routes for the LightRAG API.""" + +from typing import Optional, List +from fastapi import APIRouter, Depends +from pydantic import BaseModel, ConfigDict, Field + +from lightrag.api.utils_api import get_combined_auth_dependency +from lightrag.api.workspace_registry import get_workspace_registry + +router = APIRouter( + prefix="/workspaces", + tags=["workspaces"], +) + + +class WorkspaceInfo(BaseModel): + """Workspace information model. + + Attributes: + name: The workspace name. + first_seen: ISO timestamp when the workspace was first seen. + last_seen: ISO timestamp when the workspace was last accessed. + """ + + name: str = Field(description="The workspace name") + first_seen: str = Field( + description="ISO timestamp when the workspace was first seen" + ) + last_seen: str = Field( + description="ISO timestamp when the workspace was last accessed" + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "name": "my-workspace", + "first_seen": "2026-04-30T10:00:00+00:00", + "last_seen": "2026-04-30T11:30:00+00:00", + } + } + ) + + +class WorkspacesResponse(BaseModel): + """Response model for listing workspaces. + + Attributes: + workspaces: List of workspace information. + """ + + workspaces: List[WorkspaceInfo] = Field(description="List of registered workspaces") + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "workspaces": [ + { + "name": "my-workspace", + "first_seen": "2026-04-30T10:00:00+00:00", + "last_seen": "2026-04-30T11:30:00+00:00", + }, + { + "name": "research-docs", + "first_seen": "2026-04-29T08:00:00+00:00", + "last_seen": "2026-04-30T09:15:00+00:00", + }, + ] + } + } + ) + + +def create_workspace_routes( + api_key: Optional[str] = None, working_dir: Optional[str] = None +): + """Create workspace routes with the given API key and working directory configuration. + + Args: + api_key: Optional API key for authentication. + working_dir: Optional working directory for workspace registry. + + Returns: + Configured APIRouter with workspace endpoints. + """ + combined_auth = get_combined_auth_dependency(api_key) + + @router.get( + "", + response_model=WorkspacesResponse, + dependencies=[Depends(combined_auth)], + summary="List all workspaces", + description="Returns a list of all workspaces that have been registered through document API calls. " + "Workspaces are automatically registered when document endpoints are called with the " + "LIGHTRAG-WORKSPACE header.", + ) + async def list_workspaces(): + """ + List all registered workspaces. + + Returns: + WorkspacesResponse: A response containing the list of all known workspaces + with their first_seen and last_seen timestamps. + """ + workspace_registry = get_workspace_registry(working_dir=working_dir) + workspaces = workspace_registry.get_workspaces() + return WorkspacesResponse( + workspaces=[ + WorkspaceInfo( + name=w["name"], + first_seen=w["first_seen"], + last_seen=w["last_seen"], + ) + for w in workspaces + ] + ) + + return router diff --git a/lightrag/api/utils.py b/lightrag/api/utils.py new file mode 100644 index 0000000000..c67f183239 --- /dev/null +++ b/lightrag/api/utils.py @@ -0,0 +1,88 @@ +"""Utility functions for the LightRAG API.""" + +from __future__ import annotations + +import re + +from fastapi import HTTPException, Request + +__all__ = [ + "sanitize_workspace_name", + "WorkspaceNameError", + "extract_workspace_from_header", +] + + +class WorkspaceNameError(ValueError): + """Custom exception for workspace name validation errors. + + Attributes: + detail: The validation error message describing why the workspace name is invalid. + """ + + def __init__(self, detail: str) -> None: + super().__init__(detail) + self.detail = detail + + +def sanitize_workspace_name(name: str | None) -> str: + """Sanitize a workspace name extracted from the LIGHTRAG-WORKSPACE HTTP header. + + Applies the following sanitization rules: + 1. Returns empty string for None or empty input (default workspace). + 2. Strips leading/trailing whitespace. + 3. Converts to lowercase. + 4. Rejects path traversal attempts containing '..', '/', or '\\'. + 5. Rejects names exceeding 64 characters. + 6. Rejects names containing characters other than lowercase letters, + numbers, hyphens, and underscores. + + Args: + name: The workspace name from the HTTP header, may be None. + + Returns: + The sanitized workspace name, or empty string for the default workspace. + + Raises: + WorkspaceNameError: If the workspace name is invalid. + """ + # Rule 1: None or empty returns empty string (default workspace) + if name is None or name == "": + return "" + + # Rule 2: Strip whitespace + name = name.strip() + + # Rule 3: Convert to lowercase + name = name.lower() + + # Rule 4: Reject path traversal attempts + if ".." in name or "/" in name or "\\" in name: + raise WorkspaceNameError("Invalid workspace name: path traversal detected") + + # Rule 5: Limit length to 64 characters + if len(name) > 64: + raise WorkspaceNameError("Workspace name too long (max 64 characters)") + + # Rule 6: Only allow alphanumeric, hyphens, and underscores + if not re.match(r"^[a-z0-9_-]+$", name): + raise WorkspaceNameError( + "Invalid workspace name: only lowercase letters, numbers, hyphens, and underscores allowed" + ) + + return name + + +def extract_workspace_from_header(request: Request) -> str: + """Extract and sanitize workspace from LIGHTRAG-WORKSPACE header. + + Returns empty string for default workspace. + Raises HTTPException(400) on invalid workspace name. + """ + raw = request.headers.get("LIGHTRAG-WORKSPACE", "").strip() + if raw: + try: + return sanitize_workspace_name(raw) + except WorkspaceNameError as e: + raise HTTPException(status_code=400, detail=str(e)) + return "" diff --git a/lightrag/api/workspace_manager.py b/lightrag/api/workspace_manager.py new file mode 100644 index 0000000000..4dde3fa2fa --- /dev/null +++ b/lightrag/api/workspace_manager.py @@ -0,0 +1,253 @@ +"""Workspace Manager for LightRAG - LRU cache with reference counting.""" + +from __future__ import annotations + +import asyncio +import logging +import os +from collections import OrderedDict +from typing import Any, Awaitable, Callable + +# LIGHTRAG_WORKSPACE_CACHE_LIMIT: max number of workspace instances to cache (default: 10) +try: + _WORKSPACE_CACHE_LIMIT = int( + os.environ.get("LIGHTRAG_WORKSPACE_CACHE_LIMIT", "10") or "10" + ) + if _WORKSPACE_CACHE_LIMIT <= 0: + _WORKSPACE_CACHE_LIMIT = 10 +except (ValueError, TypeError): + _WORKSPACE_CACHE_LIMIT = 10 + +from lightrag.api.utils import sanitize_workspace_name + +logger = logging.getLogger(__name__) + + +class WorkspaceCapacityError(Exception): + """Raised when all workspace slots have in-flight requests.""" + + pass + + +class WorkspaceManager: + """Manages LightRAG workspace instances with LRU caching and reference counting. + + Features: + - LRU cache with configurable max instances + - Reference counting for safe eviction + - Per-workspace locks for thread-safe creation + - Global lock for eviction coordination + - Graceful shutdown with finalization + """ + + def __init__( + self, + factory: Callable[[str], Awaitable[Any]], + max_instances: int = _WORKSPACE_CACHE_LIMIT, + ) -> None: + """Initialize the WorkspaceManager. + + Args: + factory: Async callable that creates a LightRAG instance for a given + workspace name. + max_instances: Maximum number of workspace instances to keep in cache. + """ + self._factory = factory + self._max_instances = max_instances + + # LRU cache: oldest=first, newest=last + self._cache: OrderedDict[str, Any] = OrderedDict() + + # Per-workspace creation locks + self._locks: dict[str, asyncio.Lock] = {} + + # Reference counts: number of active users of each workspace + self._ref_counts: dict[str, int] = {} + + # Global lock for eviction operations + self._lock = asyncio.Lock() + + # Statistics + self._hits = 0 + self._misses = 0 + self._evictions = 0 + + def _normalize_workspace(self, workspace: str | None) -> str: + """Normalize workspace name (None/empty β†’ default workspace).""" + return workspace if workspace else "" + + def _get_per_workspace_lock(self, workspace: str) -> asyncio.Lock: + """Get or create a per-workspace lock.""" + if workspace not in self._locks: + self._locks[workspace] = asyncio.Lock() + return self._locks[workspace] + + async def get_or_create(self, workspace: str | None) -> Any: + """Get a cached workspace instance or create a new one. + + Args: + workspace: Workspace name (None or empty defaults to "") + + Returns: + LightRAG instance for the workspace. + + Raises: + WorkspaceCapacityError: If cache is full and all slots have in-flight requests. + WorkspaceNameError: If the workspace name is invalid. + """ + # Defensive: sanitize workspace name (belt and suspenders) + workspace = sanitize_workspace_name(workspace) + workspace = self._normalize_workspace(workspace) + + # Fast path: workspace already cached. + # No lock needed here β€” safe under asyncio's single-threaded event loop. + # For multi-threaded scenarios, a lock would be required for ref_count increment. + if workspace in self._cache: + # Move to end (mark as recently used) + self._cache.move_to_end(workspace) + self._ref_counts[workspace] += 1 + self._hits += 1 + return self._cache[workspace] + + # Slow path: need to create new instance + # Check capacity BEFORE acquiring locks for performance + if len(self._cache) >= self._max_instances: + # Check if ALL entries have in-flight requests (ref_count > 0) + if all(ref > 0 for ref in self._ref_counts.values()): + raise WorkspaceCapacityError( + "All workspace slots have in-flight requests" + ) + + # Get per-workspace lock for thread-safe creation + per_workspace_lock = self._get_per_workspace_lock(workspace) + + async with per_workspace_lock: + # Double-check: workspace might have been added while waiting for lock + if workspace in self._cache: + self._cache.move_to_end(workspace) + self._ref_counts[workspace] += 1 + self._hits += 1 + return self._cache[workspace] + + # Cache is full - need to evict one + if len(self._cache) >= self._max_instances: + await self._evict_one() + + # Create new instance and initialize all its storages + instance = self._factory(workspace) + await instance.initialize_storages() + + # Store in cache + self._cache[workspace] = instance + self._ref_counts[workspace] = 1 + self._misses += 1 + + return instance + + def release(self, workspace: str | None) -> None: + """Release a reference to a workspace. + + Args: + workspace: Workspace name to release (None or empty defaults to ""). + """ + workspace = self._normalize_workspace(workspace) + + if workspace not in self._ref_counts: + logger.warning( + "Attempted to release unknown workspace '%s'", + workspace, + ) + return + + self._ref_counts[workspace] -= 1 + + if self._ref_counts[workspace] < 0: + logger.error( + "Reference count underflow for workspace '%s', resetting to 0", + workspace, + ) + self._ref_counts[workspace] = 0 + + async def _evict_one(self) -> None: + """Evict one LRU workspace with ref_count == 0. + + Must be called while holding the per-workspace lock, but will also + acquire the global eviction lock for thread-safe cache modification. + Finalization happens OUTSIDE the lock to avoid blocking other operations. + """ + to_finalize = None + async with self._lock: + # Find first workspace with ref_count == 0 (oldest such workspace) + for ws_name in list(self._cache.keys()): + if self._ref_counts.get(ws_name, 0) == 0: + instance = self._cache.pop(ws_name) + del self._ref_counts[ws_name] + self._locks.pop(ws_name, None) + self._evictions += 1 + to_finalize = (ws_name, instance) + logger.info(f"Evicted workspace '{ws_name}' (LRU)") + break + + if to_finalize is None: + # No evictable workspace found - should not happen if capacity + # check was done correctly before calling this method + raise WorkspaceCapacityError( + f"All {self._max_instances} workspace slots have in-flight requests. " + "Cannot create a new workspace." + ) + + # Finalize OUTSIDE the lock β€” avoid blocking other operations during I/O + if to_finalize: + await self._finalize_instance(to_finalize[1]) + + async def _finalize_instance(self, instance: Any) -> None: + """Finalize a workspace instance (cleanup resources). + + Args: + instance: LightRAG instance to finalize. + """ + try: + await instance.finalize_storages() + logger.debug("Finalized workspace instance") + except Exception as e: + logger.warning( + "Failed to finalize workspace instance: %s", + e, + ) + # Don't raise - eviction must succeed + + async def shutdown(self) -> None: + """Shutdown the workspace manager, finalizing all cached instances.""" + instances_to_finalize = [] + async with self._lock: + for ws_name, instance in list(self._cache.items()): + instances_to_finalize.append((ws_name, instance)) + self._cache.clear() + self._ref_counts.clear() + self._locks.clear() + + # Finalize all instances OUTSIDE the lock + for ws_name, instance in instances_to_finalize: + await self._finalize_instance(instance) + + total_evicted = self._evictions + logger.info( + f"WorkspaceManager shutdown: evicted={total_evicted}, " + f"hits={self._hits}, misses={self._misses}" + ) + + def get_stats(self) -> dict: + """Get statistics about the workspace manager. + + Returns: + Dictionary containing cache statistics and current state. + """ + return { + "active_instances": len(self._cache), + "max_instances": self._max_instances, + "cache_hits": self._hits, + "cache_misses": self._misses, + "evictions": self._evictions, + "workspaces": list(self._cache.keys()), + "ref_counts": dict(self._ref_counts), + } diff --git a/lightrag/api/workspace_registry.py b/lightrag/api/workspace_registry.py new file mode 100644 index 0000000000..cfbe744680 --- /dev/null +++ b/lightrag/api/workspace_registry.py @@ -0,0 +1,187 @@ +"""Workspace Registry for LightRAG - Persistent storage of workspace metadata.""" + +import json +import threading +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional + +from lightrag.utils import logger + +# Default filename for the workspace registry +WORKSPACE_REGISTRY_FILENAME = "workspace_registry.json" + + +class WorkspaceRegistry: + """Persistent registry for tracking workspace usage. + + Stores workspace names along with metadata (first_seen, last_seen timestamps) + in a JSON file. Thread-safe for concurrent access. + + Attributes: + registry_path: Path to the JSON file storing workspace data. + _lock: Lock for thread-safe file access. + _cache: In-memory cache of the registry data. + """ + + def __init__(self, working_dir: Optional[str] = None) -> None: + """Initialize the workspace registry. + + Args: + working_dir: Directory to store the registry file. + Defaults to current working directory. + """ + if working_dir: + self.registry_path = Path(working_dir) / WORKSPACE_REGISTRY_FILENAME + else: + self.registry_path = Path.cwd() / WORKSPACE_REGISTRY_FILENAME + + self._lock = threading.Lock() + self._cache: dict[str, dict] = {} + + # Load existing registry on initialization + self._load_registry() + + def _load_registry(self) -> None: + """Load the registry from disk into memory cache. + + Note: This method is called during __init__ before the object is shared, + so no locking is needed. + """ + if self.registry_path.exists(): + try: + with open(self.registry_path, "r", encoding="utf-8") as f: + self._cache = json.load(f) + logger.debug( + f"Loaded workspace registry from {self.registry_path} " + f"with {len(self._cache)} entries" + ) + except (json.JSONDecodeError, IOError) as e: + logger.warning( + f"Failed to load workspace registry: {e}. Starting fresh." + ) + self._cache = {} + else: + self._cache = {} + + def _save_registry(self) -> None: + """Save the in-memory cache to disk atomically. + + Uses write-to-temp-then-rename pattern for atomicity on POSIX systems. + If save fails, rolls back to the previous in-memory state. + """ + # Snapshot current state for potential rollback + old_cache = self._cache.copy() + + try: + # Ensure parent directory exists + self.registry_path.parent.mkdir(parents=True, exist_ok=True) + + # Write to temp file first, then rename for atomicity + tmp_path = self.registry_path.with_suffix(".tmp") + with open(tmp_path, "w", encoding="utf-8") as f: + json.dump(self._cache, f, indent=2) + tmp_path.replace(self.registry_path) # atomic on POSIX + logger.debug(f"Saved workspace registry to {self.registry_path}") + except IOError as e: + # Rollback to previous state on failure + self._cache = old_cache + logger.error( + f"Failed to save workspace registry: {e}. Rolled back in-memory state." + ) + raise + + def register_workspace(self, workspace: str) -> None: + """Register a workspace or update its last_seen timestamp. + + Args: + workspace: The workspace name to register. + Empty string represents the default workspace. + """ + if not workspace: + # Don't register empty workspace (default) + logger.debug("Skipping empty workspace registration") + return + + now = datetime.now(timezone.utc).isoformat() + + with self._lock: + if workspace in self._cache: + # Update last_seen + self._cache[workspace]["last_seen"] = now + else: + # Register new workspace + self._cache[workspace] = { + "name": workspace, + "first_seen": now, + "last_seen": now, + } + logger.info(f"Registered new workspace: {workspace}") + + # Persist to disk + self._save_registry() + + def get_workspaces(self) -> list[dict]: + """Get all registered workspaces. + + Returns: + List of workspace metadata dictionaries with name, first_seen, and last_seen. + """ + with self._lock: + # Sort by last_seen (most recent first) + sorted_workspaces = sorted( + self._cache.values(), + key=lambda w: w.get("last_seen", ""), + reverse=True, + ) + return [ + { + "name": w["name"], + "first_seen": w["first_seen"], + "last_seen": w["last_seen"], + } + for w in sorted_workspaces + ] + + +# Global singleton instance (initialized lazily) +_registry_instance: Optional[WorkspaceRegistry] = None +_registry_lock = threading.Lock() + + +def get_workspace_registry(working_dir: Optional[str] = None) -> WorkspaceRegistry: + """Get or create the global workspace registry instance. + + Args: + working_dir: Optional working directory for the registry file. + Only used on first call. + + Returns: + The global WorkspaceRegistry instance. + + Raises: + ValueError: If working_dir is specified and differs from the + existing instance's working directory. + """ + global _registry_instance + + if _registry_instance is None: + with _registry_lock: + # Double-check after acquiring lock + if _registry_instance is None: + _registry_instance = WorkspaceRegistry(working_dir=working_dir) + logger.info( + f"Workspace registry initialized at {_registry_instance.registry_path}" + ) + + # Guard against re-entry with mismatched working_dir + if working_dir is not None: + instance_dir = str(_registry_instance.registry_path.parent) + if working_dir != instance_dir: + raise ValueError( + f"Workspace registry already initialized with working_dir='{instance_dir}', " + f"but requested working_dir='{working_dir}'. " + "The global registry can only be initialized once with a specific working directory." + ) + + return _registry_instance diff --git a/lightrag/kg/memgraph_impl.py b/lightrag/kg/memgraph_impl.py index d957146ec4..cdd55ce116 100644 --- a/lightrag/kg/memgraph_impl.py +++ b/lightrag/kg/memgraph_impl.py @@ -35,10 +35,13 @@ class MemgraphStorage(BaseGraphStorage): def __init__(self, namespace, global_config, embedding_func, workspace=None): # Priority: 1) MEMGRAPH_WORKSPACE env 2) user arg 3) default 'base' - memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE") - original_workspace = workspace # Save original value for logging - if memgraph_workspace and memgraph_workspace.strip(): - workspace = memgraph_workspace + memgraph_workspace = None + original_workspace = None + if not os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true": + memgraph_workspace = os.environ.get("MEMGRAPH_WORKSPACE") + original_workspace = workspace # Save original value for logging + if memgraph_workspace and memgraph_workspace.strip(): + workspace = memgraph_workspace if not workspace or not str(workspace).strip(): workspace = "base" diff --git a/lightrag/kg/milvus_impl.py b/lightrag/kg/milvus_impl.py index 0b5a492d52..596a4dedcf 100644 --- a/lightrag/kg/milvus_impl.py +++ b/lightrag/kg/milvus_impl.py @@ -1374,20 +1374,23 @@ def __post_init__(self): # Check for MILVUS_WORKSPACE environment variable first (higher priority) # This allows administrators to force a specific workspace for all Milvus storage instances - milvus_workspace = os.environ.get("MILVUS_WORKSPACE") - if milvus_workspace and milvus_workspace.strip(): - # Use environment variable value, overriding the passed workspace parameter - effective_workspace = milvus_workspace.strip() - logger.info( - f"Using MILVUS_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" - ) + if not os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true": + milvus_workspace = os.environ.get("MILVUS_WORKSPACE") + if milvus_workspace and milvus_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = milvus_workspace.strip() + logger.info( + f"Using MILVUS_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) else: - # Use the workspace parameter passed during initialization effective_workspace = self.workspace - if effective_workspace: - logger.debug( - f"Using passed workspace parameter: '{effective_workspace}'" - ) # Build final_namespace with workspace prefix for data isolation # Keep original namespace unchanged for type detection logic diff --git a/lightrag/kg/mongo_impl.py b/lightrag/kg/mongo_impl.py index 4123d92edb..ddadf16372 100644 --- a/lightrag/kg/mongo_impl.py +++ b/lightrag/kg/mongo_impl.py @@ -99,20 +99,23 @@ def __init__(self, namespace, global_config, embedding_func, workspace=None): def __post_init__(self): # Check for MONGODB_WORKSPACE environment variable first (higher priority) # This allows administrators to force a specific workspace for all MongoDB storage instances - mongodb_workspace = os.environ.get("MONGODB_WORKSPACE") - if mongodb_workspace and mongodb_workspace.strip(): - # Use environment variable value, overriding the passed workspace parameter - effective_workspace = mongodb_workspace.strip() - logger.info( - f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" - ) + if not os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true": + mongodb_workspace = os.environ.get("MONGODB_WORKSPACE") + if mongodb_workspace and mongodb_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = mongodb_workspace.strip() + logger.info( + f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) else: - # Use the workspace parameter passed during initialization effective_workspace = self.workspace - if effective_workspace: - logger.debug( - f"Using passed workspace parameter: '{effective_workspace}'" - ) # Build final_namespace with workspace prefix for data isolation # Keep original namespace unchanged for type detection logic @@ -328,20 +331,23 @@ def __init__(self, namespace, global_config, embedding_func, workspace=None): def __post_init__(self): # Check for MONGODB_WORKSPACE environment variable first (higher priority) # This allows administrators to force a specific workspace for all MongoDB storage instances - mongodb_workspace = os.environ.get("MONGODB_WORKSPACE") - if mongodb_workspace and mongodb_workspace.strip(): - # Use environment variable value, overriding the passed workspace parameter - effective_workspace = mongodb_workspace.strip() - logger.info( - f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" - ) + if not os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true": + mongodb_workspace = os.environ.get("MONGODB_WORKSPACE") + if mongodb_workspace and mongodb_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = mongodb_workspace.strip() + logger.info( + f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) else: - # Use the workspace parameter passed during initialization effective_workspace = self.workspace - if effective_workspace: - logger.debug( - f"Using passed workspace parameter: '{effective_workspace}'" - ) # Build final_namespace with workspace prefix for data isolation # Keep original namespace unchanged for type detection logic @@ -763,22 +769,28 @@ def __init__(self, namespace, global_config, embedding_func, workspace=None): global_config=global_config, embedding_func=embedding_func, ) + self.__post_init__() + + def __post_init__(self): # Check for MONGODB_WORKSPACE environment variable first (higher priority) # This allows administrators to force a specific workspace for all MongoDB storage instances - mongodb_workspace = os.environ.get("MONGODB_WORKSPACE") - if mongodb_workspace and mongodb_workspace.strip(): - # Use environment variable value, overriding the passed workspace parameter - effective_workspace = mongodb_workspace.strip() - logger.info( - f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" - ) + if not os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true": + mongodb_workspace = os.environ.get("MONGODB_WORKSPACE") + if mongodb_workspace and mongodb_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = mongodb_workspace.strip() + logger.info( + f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) else: - # Use the workspace parameter passed during initialization effective_workspace = self.workspace - if effective_workspace: - logger.debug( - f"Using passed workspace parameter: '{effective_workspace}'" - ) # Build final_namespace with workspace prefix for data isolation # Keep original namespace unchanged for type detection logic @@ -2185,20 +2197,23 @@ def __post_init__(self): # Check for MONGODB_WORKSPACE environment variable first (higher priority) # This allows administrators to force a specific workspace for all MongoDB storage instances - mongodb_workspace = os.environ.get("MONGODB_WORKSPACE") - if mongodb_workspace and mongodb_workspace.strip(): - # Use environment variable value, overriding the passed workspace parameter - effective_workspace = mongodb_workspace.strip() - logger.info( - f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" - ) + if not os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true": + mongodb_workspace = os.environ.get("MONGODB_WORKSPACE") + if mongodb_workspace and mongodb_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = mongodb_workspace.strip() + logger.info( + f"Using MONGODB_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) else: - # Use the workspace parameter passed during initialization effective_workspace = self.workspace - if effective_workspace: - logger.debug( - f"Using passed workspace parameter: '{effective_workspace}'" - ) # Build final_namespace with workspace prefix for data isolation # Keep original namespace unchanged for type detection logic diff --git a/lightrag/kg/neo4j_impl.py b/lightrag/kg/neo4j_impl.py index 9ec79eceff..010816f562 100644 --- a/lightrag/kg/neo4j_impl.py +++ b/lightrag/kg/neo4j_impl.py @@ -66,10 +66,13 @@ class Neo4JStorage(BaseGraphStorage): def __init__(self, namespace, global_config, embedding_func, workspace=None): # Read env and override the arg if present - neo4j_workspace = os.environ.get("NEO4J_WORKSPACE") - original_workspace = workspace # Save original value for logging - if neo4j_workspace and neo4j_workspace.strip(): - workspace = neo4j_workspace + neo4j_workspace = None + original_workspace = None + if not os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true": + neo4j_workspace = os.environ.get("NEO4J_WORKSPACE") + original_workspace = workspace # Save original value for logging + if neo4j_workspace and neo4j_workspace.strip(): + workspace = neo4j_workspace # Default to 'base' when both arg and env are empty if not workspace or not str(workspace).strip(): @@ -144,7 +147,7 @@ async def initialize(self): MAX_CONNECTION_POOL_SIZE = int( os.environ.get( "NEO4J_MAX_CONNECTION_POOL_SIZE", - config.get("neo4j", "connection_pool_size", fallback=100), + config.get("neo4j", "connection_pool_size", fallback=10), ) ) CONNECTION_TIMEOUT = float( diff --git a/lightrag/kg/opensearch_impl.py b/lightrag/kg/opensearch_impl.py index 4129577b02..9c2343a69c 100644 --- a/lightrag/kg/opensearch_impl.py +++ b/lightrag/kg/opensearch_impl.py @@ -188,13 +188,14 @@ async def release_client(cls, client: AsyncOpenSearch): def _resolve_workspace(workspace: str, namespace: str): """Resolve effective workspace from env or parameter.""" - opensearch_workspace = os.environ.get("OPENSEARCH_WORKSPACE") - if opensearch_workspace and opensearch_workspace.strip(): - effective = opensearch_workspace.strip() - logger.info( - f"Using OPENSEARCH_WORKSPACE: '{effective}' (overriding '{workspace}/{namespace}')" - ) - return effective + if not os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true": + opensearch_workspace = os.environ.get("OPENSEARCH_WORKSPACE") + if opensearch_workspace and opensearch_workspace.strip(): + effective = opensearch_workspace.strip() + logger.info( + f"Using OPENSEARCH_WORKSPACE: '{effective}' (overriding '{workspace}/{namespace}')" + ) + return effective return workspace diff --git a/lightrag/kg/postgres_impl.py b/lightrag/kg/postgres_impl.py index 638ba96f59..7243975088 100644 --- a/lightrag/kg/postgres_impl.py +++ b/lightrag/kg/postgres_impl.py @@ -1988,7 +1988,9 @@ def get_config(vector_storage: str | None = None) -> dict[str, Any]: "POSTGRES_DATABASE", config.get("postgres", "database", fallback="postgres"), ), - "workspace": os.environ.get( + "workspace": None + if os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true" + else os.environ.get( "POSTGRES_WORKSPACE", config.get("postgres", "workspace", fallback=None), ), @@ -3240,18 +3242,19 @@ async def initialize(self): ) # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" - if self.db.workspace: - # Use PostgreSQLDB's workspace (highest priority) - logger.info( - f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')" - ) - self.workspace = self.db.workspace - elif hasattr(self, "workspace") and self.workspace: - # Use storage class's workspace (medium priority) - pass - else: - # Use "default" for compatibility (lowest priority) - self.workspace = "default" + if not os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true": + if self.db.workspace: + # Use PostgreSQLDB's workspace (highest priority) + logger.info( + f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')" + ) + self.workspace = self.db.workspace + elif hasattr(self, "workspace") and self.workspace: + # Use storage class's workspace (medium priority) + pass + else: + # Use "default" for compatibility (lowest priority) + self.workspace = "default" # Setup table (create if not exists and handle migration) await PGVectorStorage.setup_table( @@ -3770,18 +3773,19 @@ async def initialize(self): ) # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" - if self.db.workspace: - # Use PostgreSQLDB's workspace (highest priority) - logger.info( - f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')" - ) - self.workspace = self.db.workspace - elif hasattr(self, "workspace") and self.workspace: - # Use storage class's workspace (medium priority) - pass - else: - # Use "default" for compatibility (lowest priority) - self.workspace = "default" + if not os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true": + if self.db.workspace: + # Use PostgreSQLDB's workspace (highest priority) + logger.info( + f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')" + ) + self.workspace = self.db.workspace + elif hasattr(self, "workspace") and self.workspace: + # Use storage class's workspace (medium priority) + pass + else: + # Use "default" for compatibility (lowest priority) + self.workspace = "default" # NOTE: Table creation is handled by PostgreSQLDB.initdb() during initialization # No need to create table here as it's already created in the TABLES dict @@ -4668,18 +4672,19 @@ async def initialize(self): ) # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" - if self.db.workspace: - # Use PostgreSQLDB's workspace (highest priority) - logger.info( - f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')" - ) - self.workspace = self.db.workspace - elif hasattr(self, "workspace") and self.workspace: - # Use storage class's workspace (medium priority) - pass - else: - # Use "default" for compatibility (lowest priority) - self.workspace = "default" + if not os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true": + if self.db.workspace: + # Use PostgreSQLDB's workspace (highest priority) + logger.info( + f"Using PG_WORKSPACE environment variable: '{self.db.workspace}' (overriding '{self.workspace}/{self.namespace}')" + ) + self.workspace = self.db.workspace + elif hasattr(self, "workspace") and self.workspace: + # Use storage class's workspace (medium priority) + pass + else: + # Use "default" for compatibility (lowest priority) + self.workspace = "default" # Dynamically generate graph name based on workspace self.graph_name = self._get_workspace_graph_name() diff --git a/lightrag/kg/qdrant_impl.py b/lightrag/kg/qdrant_impl.py index c126ff028f..b707eadd7e 100644 --- a/lightrag/kg/qdrant_impl.py +++ b/lightrag/kg/qdrant_impl.py @@ -419,20 +419,23 @@ def __post_init__(self): self._validate_embedding_func() # Check for QDRANT_WORKSPACE environment variable first (higher priority) # This allows administrators to force a specific workspace for all Qdrant storage instances - qdrant_workspace = os.environ.get("QDRANT_WORKSPACE") - if qdrant_workspace and qdrant_workspace.strip(): - # Use environment variable value, overriding the passed workspace parameter - effective_workspace = qdrant_workspace.strip() - logger.info( - f"Using QDRANT_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" - ) + if not os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true": + qdrant_workspace = os.environ.get("QDRANT_WORKSPACE") + if qdrant_workspace and qdrant_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = qdrant_workspace.strip() + logger.info( + f"Using QDRANT_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) else: - # Use the workspace parameter passed during initialization effective_workspace = self.workspace - if effective_workspace: - logger.debug( - f"Using passed workspace parameter: '{effective_workspace}'" - ) self.effective_workspace = effective_workspace or DEFAULT_WORKSPACE diff --git a/lightrag/kg/redis_impl.py b/lightrag/kg/redis_impl.py index 89eb594369..9a7fc1b698 100644 --- a/lightrag/kg/redis_impl.py +++ b/lightrag/kg/redis_impl.py @@ -128,20 +128,23 @@ class RedisKVStorage(BaseKVStorage): def __post_init__(self): # Check for REDIS_WORKSPACE environment variable first (higher priority) # This allows administrators to force a specific workspace for all Redis storage instances - redis_workspace = os.environ.get("REDIS_WORKSPACE") - if redis_workspace and redis_workspace.strip(): - # Use environment variable value, overriding the passed workspace parameter - effective_workspace = redis_workspace.strip() - logger.info( - f"Using REDIS_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" - ) + if not os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true": + redis_workspace = os.environ.get("REDIS_WORKSPACE") + if redis_workspace and redis_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = redis_workspace.strip() + logger.info( + f"Using REDIS_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) else: - # Use the workspace parameter passed during initialization effective_workspace = self.workspace - if effective_workspace: - logger.debug( - f"Using passed workspace parameter: '{effective_workspace}'" - ) # Build final_namespace with workspace prefix for data isolation # Keep original namespace unchanged for type detection logic @@ -524,20 +527,23 @@ class RedisDocStatusStorage(DocStatusStorage): def __post_init__(self): # Check for REDIS_WORKSPACE environment variable first (higher priority) # This allows administrators to force a specific workspace for all Redis storage instances - redis_workspace = os.environ.get("REDIS_WORKSPACE") - if redis_workspace and redis_workspace.strip(): - # Use environment variable value, overriding the passed workspace parameter - effective_workspace = redis_workspace.strip() - logger.info( - f"Using REDIS_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" - ) + if not os.environ.get("WORKSPACE_ISOLATION", "").lower() == "true": + redis_workspace = os.environ.get("REDIS_WORKSPACE") + if redis_workspace and redis_workspace.strip(): + # Use environment variable value, overriding the passed workspace parameter + effective_workspace = redis_workspace.strip() + logger.info( + f"Using REDIS_WORKSPACE environment variable: '{effective_workspace}' (overriding '{self.workspace}/{self.namespace}')" + ) + else: + # Use the workspace parameter passed during initialization + effective_workspace = self.workspace + if effective_workspace: + logger.debug( + f"Using passed workspace parameter: '{effective_workspace}'" + ) else: - # Use the workspace parameter passed during initialization effective_workspace = self.workspace - if effective_workspace: - logger.debug( - f"Using passed workspace parameter: '{effective_workspace}'" - ) # Build final_namespace with workspace prefix for data isolation # Keep original namespace unchanged for type detection logic diff --git a/lightrag_webui/.env.development b/lightrag_webui/.env.development index 501be53c4a..c411a25f4b 100644 --- a/lightrag_webui/.env.development +++ b/lightrag_webui/.env.development @@ -1,4 +1,4 @@ # Development environment configuration VITE_BACKEND_URL=http://localhost:9621 VITE_API_PROXY=true -VITE_API_ENDPOINTS=/api,/documents,/graphs,/graph,/health,/query,/docs,/redoc,/openapi.json,/login,/auth-status,/static +VITE_API_ENDPOINTS=/api,/documents,/graphs,/graph,/health,/query,/docs,/redoc,/openapi.json,/login,/auth-status,/static,/workspaces diff --git a/lightrag_webui/bun.lock b/lightrag_webui/bun.lock index 8784424ba9..c461e48cff 100644 --- a/lightrag_webui/bun.lock +++ b/lightrag_webui/bun.lock @@ -71,9 +71,12 @@ }, "devDependencies": { "@eslint/js": "^10.0.1", + "@playwright/test": "^1.59.1", "@stylistic/eslint-plugin": "^5.10.0", "@tailwindcss/typography": "^0.5.15", "@tailwindcss/vite": "^4.2.2", + "@testing-library/jest-dom": "^6.9.1", + "@testing-library/react": "^16.3.2", "@types/bun": "^1.3.12", "@types/katex": "^0.16.8", "@types/node": "^25.6.0", @@ -90,6 +93,8 @@ "eslint-plugin-react-refresh": "^0.5.2", "globals": "^17.5.0", "graphology-types": "^0.24.8", + "happy-dom": "^20.9.0", + "playwright": "^1.59.1", "prettier": "^3.8.3", "prettier-plugin-tailwindcss": "^0.7.2", "tailwindcss": "^4.2.2", @@ -101,6 +106,8 @@ }, }, "packages": { + "@adobe/css-tools": ["@adobe/css-tools@4.4.4", "", {}, "sha512-Elp+iwUx5rN5+Y8xLt5/GRoG20WGoDCQ/1Fb+1LiGtvwbDavuSk0jhD/eZdckHAuzcDzccnkv+rEjyWfRx18gg=="], + "@antfu/install-pkg": ["@antfu/install-pkg@1.1.0", "", { "dependencies": { "package-manager-detector": "^1.3.0", "tinyexec": "^1.0.1" } }, "sha512-MGQsmw10ZyI+EJo45CdSER4zEb+p31LpDAFp2Z3gkSd1yqVZGi0Ebx++YTEMonJy4oChEMLsxZ64j8FH6sSqtQ=="], "@babel/code-frame": ["@babel/code-frame@7.27.1", "", { "dependencies": { "@babel/helper-validator-identifier": "^7.27.1", "js-tokens": "^4.0.0", "picocolors": "^1.1.1" } }, "sha512-cjQ7ZlQ0Mv3b47hABuTevyTuYN4i+loJKGeV9flcCgIK37cCXRh+L1bd3iBHlynerhQ7BhCkn2BPbQUL+rGqFg=="], @@ -231,6 +238,8 @@ "@oxc-project/types": ["@oxc-project/types@0.126.0", "", {}, "sha512-oGfVtjAgwQVVpfBrbtk4e1XDyWHRFta6BS3GWVzrF8xYBT2VGQAk39yJS/wFSMrZqoiCU4oghT3Ch0HaHGIHcQ=="], + "@playwright/test": ["@playwright/test@1.59.1", "", { "dependencies": { "playwright": "1.59.1" }, "bin": { "playwright": "cli.js" } }, "sha512-PG6q63nQg5c9rIi4/Z5lR5IVF7yU5MqmKaPOe0HSc0O2cX1fPi96sUQu5j7eo4gKCkB2AnNGoWt7y4/Xx3Kcqg=="], + "@radix-ui/number": ["@radix-ui/number@1.1.1", "", {}, "sha512-MkKCwxlXTgz6CFoJx3pCwn07GKp36+aZyu/u2Ln2VrA5DcdyCZkASEDBTd8x5whTQQL5CiYf4prXKLcgQdv29g=="], "@radix-ui/primitive": ["@radix-ui/primitive@1.1.3", "", {}, "sha512-JTF99U/6XIjCBo0wqkU5sK10glYe27MRRsfwoiq5zzOEZLHU3A3KCMa5X/azekYRCJ0HlwI0crAXS/5dEHTzDg=="], @@ -399,8 +408,16 @@ "@tanstack/table-core": ["@tanstack/table-core@8.21.3", "", {}, "sha512-ldZXEhOBb8Is7xLs01fR3YEc3DERiz5silj8tnGkFZytt1abEvl/GhUmCE0PMLaMPTa3Jk4HbKmRlHmu+gCftg=="], + "@testing-library/dom": ["@testing-library/dom@10.4.1", "", { "dependencies": { "@babel/code-frame": "^7.10.4", "@babel/runtime": "^7.12.5", "@types/aria-query": "^5.0.1", "aria-query": "5.3.0", "dom-accessibility-api": "^0.5.9", "lz-string": "^1.5.0", "picocolors": "1.1.1", "pretty-format": "^27.0.2" } }, "sha512-o4PXJQidqJl82ckFaXUeoAW+XysPLauYI43Abki5hABd853iMhitooc6znOnczgbTYmEP6U6/y1ZyKAIsvMKGg=="], + + "@testing-library/jest-dom": ["@testing-library/jest-dom@6.9.1", "", { "dependencies": { "@adobe/css-tools": "^4.4.0", "aria-query": "^5.0.0", "css.escape": "^1.5.1", "dom-accessibility-api": "^0.6.3", "picocolors": "^1.1.1", "redent": "^3.0.0" } }, "sha512-zIcONa+hVtVSSep9UT3jZ5rizo2BsxgyDYU7WFD5eICBE7no3881HGeb/QkGfsJs6JTkY1aQhT7rIPC7e+0nnA=="], + + "@testing-library/react": ["@testing-library/react@16.3.2", "", { "dependencies": { "@babel/runtime": "^7.12.5" }, "peerDependencies": { "@testing-library/dom": "^10.0.0", "@types/react": "^18.0.0 || ^19.0.0", "@types/react-dom": "^18.0.0 || ^19.0.0", "react": "^18.0.0 || ^19.0.0", "react-dom": "^18.0.0 || ^19.0.0" }, "optionalPeers": ["@types/react", "@types/react-dom"] }, "sha512-XU5/SytQM+ykqMnAnvB2umaJNIOsLF3PVv//1Ew4CTcpz0/BRyy/af40qqrt7SjKpDdT1saBMc42CUok5gaw+g=="], + "@tybys/wasm-util": ["@tybys/wasm-util@0.10.1", "", { "dependencies": { "tslib": "^2.4.0" } }, "sha512-9tTaPJLSiejZKx+Bmog4uSubteqTvFrVrURwkmHixBo0G4seD0zUxp98E1DzUBJxLQ3NPwXrGKDiVjwx/DpPsg=="], + "@types/aria-query": ["@types/aria-query@5.0.4", "", {}, "sha512-rfT93uj5s0PRL7EzccGMs3brplhcrghnDoV26NqKhCAS1hVo+WdNsPvE/yb6ilfr5hi2MEk6d5EWJTKdxg8jVw=="], + "@types/bun": ["@types/bun@1.3.12", "", { "dependencies": { "bun-types": "1.3.12" } }, "sha512-DBv81elK+/VSwXHDlnH3Qduw+KxkTIWi7TXkAeh24zpi5l0B2kUg9Ga3tb4nJaPcOFswflgi/yAvMVBPrxMB+A=="], "@types/d3": ["@types/d3@7.4.3", "", { "dependencies": { "@types/d3-array": "*", "@types/d3-axis": "*", "@types/d3-brush": "*", "@types/d3-chord": "*", "@types/d3-color": "*", "@types/d3-contour": "*", "@types/d3-delaunay": "*", "@types/d3-dispatch": "*", "@types/d3-drag": "*", "@types/d3-dsv": "*", "@types/d3-ease": "*", "@types/d3-fetch": "*", "@types/d3-force": "*", "@types/d3-format": "*", "@types/d3-geo": "*", "@types/d3-hierarchy": "*", "@types/d3-interpolate": "*", "@types/d3-path": "*", "@types/d3-polygon": "*", "@types/d3-quadtree": "*", "@types/d3-random": "*", "@types/d3-scale": "*", "@types/d3-scale-chromatic": "*", "@types/d3-selection": "*", "@types/d3-shape": "*", "@types/d3-time": "*", "@types/d3-time-format": "*", "@types/d3-timer": "*", "@types/d3-transition": "*", "@types/d3-zoom": "*" } }, "sha512-lZXZ9ckh5R8uiFVt8ogUNf+pIrK4EsWrx2Np75WvF/eTpJ0FMHNhjXk8CKEx/+gpHbNQyJWehbFaTvqmHWB3ww=="], @@ -507,6 +524,10 @@ "@types/unist": ["@types/unist@3.0.3", "", {}, "sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q=="], + "@types/whatwg-mimetype": ["@types/whatwg-mimetype@3.0.2", "", {}, "sha512-c2AKvDT8ToxLIOUlN51gTiHXflsfIFisS4pO7pDPoKouJCESkhZnEy623gwP9laCy5lnLDAw1vAzu2vM2YLOrA=="], + + "@types/ws": ["@types/ws@8.18.1", "", { "dependencies": { "@types/node": "*" } }, "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg=="], + "@typescript-eslint/eslint-plugin": ["@typescript-eslint/eslint-plugin@8.58.2", "", { "dependencies": { "@eslint-community/regexpp": "^4.12.2", "@typescript-eslint/scope-manager": "8.58.2", "@typescript-eslint/type-utils": "8.58.2", "@typescript-eslint/utils": "8.58.2", "@typescript-eslint/visitor-keys": "8.58.2", "ignore": "^7.0.5", "natural-compare": "^1.4.0", "ts-api-utils": "^2.5.0" }, "peerDependencies": { "@typescript-eslint/parser": "^8.58.2", "eslint": "^8.57.0 || ^9.0.0 || ^10.0.0", "typescript": ">=4.8.4 <6.1.0" } }, "sha512-aC2qc5thQahutKjP+cl8cgN9DWe3ZUqVko30CMSZHnFEHyhOYoZSzkGtAI2mcwZ38xeImDucI4dnqsHiOYuuCw=="], "@typescript-eslint/parser": ["@typescript-eslint/parser@8.58.2", "", { "dependencies": { "@typescript-eslint/scope-manager": "8.58.2", "@typescript-eslint/types": "8.58.2", "@typescript-eslint/typescript-estree": "8.58.2", "@typescript-eslint/visitor-keys": "8.58.2", "debug": "^4.4.3" }, "peerDependencies": { "eslint": "^8.57.0 || ^9.0.0 || ^10.0.0", "typescript": ">=4.8.4 <6.1.0" } }, "sha512-/Zb/xaIDfxeJnvishjGdcR4jmr7S+bda8PKNhRGdljDM+elXhlvN0FyPSsMnLmJUrVG9aPO6dof80wjMawsASg=="], @@ -541,8 +562,14 @@ "ajv": ["ajv@6.14.0", "", { "dependencies": { "fast-deep-equal": "^3.1.1", "fast-json-stable-stringify": "^2.0.0", "json-schema-traverse": "^0.4.1", "uri-js": "^4.2.2" } }, "sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw=="], + "ansi-regex": ["ansi-regex@5.0.1", "", {}, "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ=="], + + "ansi-styles": ["ansi-styles@5.2.0", "", {}, "sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA=="], + "aria-hidden": ["aria-hidden@1.2.6", "", { "dependencies": { "tslib": "^2.0.0" } }, "sha512-ik3ZgC9dY/lYVVM++OISsaYDeg1tb0VtP5uL3ouh1koGOaUMDPpbFIei4JkFimWUFPn90sbMNMXQAIVOlnYKJA=="], + "aria-query": ["aria-query@5.3.2", "", {}, "sha512-COROpnaoap1E2F000S62r6A60uHZnmlvomhfyT2DlTcrY1OrBKn2UhH7qn5wTC9zMvD0AY7csdPSNwKP+7WiQw=="], + "array-buffer-byte-length": ["array-buffer-byte-length@1.0.2", "", { "dependencies": { "call-bound": "^1.0.3", "is-array-buffer": "^3.0.5" } }, "sha512-LHE+8BuR7RYGDKvnrmcuSq3tDcKv9OFEXQt/HpbZhY7V6h0zlUXutnAD82GiFx9rdieCMjkvtcsPqBwgUl1Iiw=="], "array-includes": ["array-includes@3.1.9", "", { "dependencies": { "call-bind": "^1.0.8", "call-bound": "^1.0.4", "define-properties": "^1.2.1", "es-abstract": "^1.24.0", "es-object-atoms": "^1.1.1", "get-intrinsic": "^1.3.0", "is-string": "^1.1.1", "math-intrinsics": "^1.1.0" } }, "sha512-FmeCCAenzH0KH381SPT5FZmiA/TmpndpcaShhfgEN9eCVjnFBqq3l1xrI42y8+PPLI6hypzou4GXw00WHmPBLQ=="], @@ -637,6 +664,8 @@ "cross-spawn": ["cross-spawn@7.0.6", "", { "dependencies": { "path-key": "^3.1.0", "shebang-command": "^2.0.0", "which": "^2.0.1" } }, "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA=="], + "css.escape": ["css.escape@1.5.1", "", {}, "sha512-YUifsXXuknHlUsmlgyY0PKzgPOr7/FjCePfHNt0jxm83wHZi44VDMQ7/fGNkjY3/jV1MC+1CmZbaHzugyeRtpg=="], + "cssesc": ["cssesc@3.0.0", "", { "bin": { "cssesc": "bin/cssesc" } }, "sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg=="], "csstype": ["csstype@3.2.3", "", {}, "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ=="], @@ -747,6 +776,8 @@ "doctrine": ["doctrine@2.1.0", "", { "dependencies": { "esutils": "^2.0.2" } }, "sha512-35mSku4ZXK0vfCuHEDAwt55dg2jNajHZ1odvF+8SSr82EsZY4QmXfuWso8oEd8zRhVObSN18aM0CjSdoBX7zIw=="], + "dom-accessibility-api": ["dom-accessibility-api@0.6.3", "", {}, "sha512-7ZgogeTnjuHbo+ct10G9Ffp0mif17idi0IyWNVA/wcwcm7NPOD/WEHVP3n7n3MhXqxoIYm8d6MuZohYWIZ4T3w=="], + "dom-helpers": ["dom-helpers@5.2.1", "", { "dependencies": { "@babel/runtime": "^7.8.7", "csstype": "^3.0.2" } }, "sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA=="], "dompurify": ["dompurify@3.3.1", "", { "optionalDependencies": { "@types/trusted-types": "^2.0.7" } }, "sha512-qkdCKzLNtrgPFP1Vo+98FRzJnBRGe4ffyCea9IwHB1fyxPOeNTHpLKYGd4Uk9xvNoH0ZoOjwZxNptyMwqrId1Q=="], @@ -757,7 +788,7 @@ "enhanced-resolve": ["enhanced-resolve@5.19.0", "", { "dependencies": { "graceful-fs": "^4.2.4", "tapable": "^2.3.0" } }, "sha512-phv3E1Xl4tQOShqSte26C7Fl84EwUdZsyOuSSk9qtAGyyQs2s3jJzComh+Abf4g187lUUAvH+H26omrqia2aGg=="], - "entities": ["entities@6.0.1", "", {}, "sha512-aN97NXWF6AWBTahfVOIrB/NShkzi5H7F9r1s9mD3cDj4Ko5f2qhhVoYMibXF7GlLveb/D2ioWay8lxI97Ven3g=="], + "entities": ["entities@7.0.1", "", {}, "sha512-TWrgLOFUQTH994YUyl1yT4uyavY5nNB5muff+RtWaqNVCAK408b5ZnnbNAUEWLTCpum9w6arT70i1XdQ4UeOPA=="], "error-ex": ["error-ex@1.3.4", "", { "dependencies": { "is-arrayish": "^0.2.1" } }, "sha512-sqQamAnR14VgCr1A618A3sGrygcpK+HEbenA/HiEAkkUwcZIIB/tgWqHFxWgOyDh4nB4JCRimh79dR5Ywc9MDQ=="], @@ -841,7 +872,7 @@ "format": ["format@0.2.2", "", {}, "sha512-wzsgA6WOq+09wrU1tsJ09udeR/YZRaeArL9e1wPbFg3GG2yDnC2ldKpxs4xunpFF9DgqCqOIra3bc1HWrJ37Ww=="], - "fsevents": ["fsevents@2.3.3", "", { "os": "darwin" }, "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw=="], + "fsevents": ["fsevents@2.3.2", "", { "os": "darwin" }, "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA=="], "function-bind": ["function-bind@1.1.2", "", {}, "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA=="], @@ -897,6 +928,8 @@ "hachure-fill": ["hachure-fill@0.5.2", "", {}, "sha512-3GKBOn+m2LX9iq+JC1064cSFprJY4jL1jCXTcpnfER5HYE2l/4EfWSGzkPa/ZDBmYI0ZOEj5VHV/eKnPGkHuOg=="], + "happy-dom": ["happy-dom@20.9.0", "", { "dependencies": { "@types/node": ">=20.0.0", "@types/whatwg-mimetype": "^3.0.2", "@types/ws": "^8.18.1", "entities": "^7.0.1", "whatwg-mimetype": "^3.0.0", "ws": "^8.18.3" } }, "sha512-GZZ9mKe8r646NUAf/zemnGbjYh4Bt8/MqASJY+pSm5ZDtc3YQox+4gsLI7yi1hba6o+eCsGxpHn5+iEVn31/FQ=="], + "has-bigints": ["has-bigints@1.1.0", "", {}, "sha512-R3pbpkcIqv2Pm3dUwgjclDRVmWpTJW2DcMzcIhEXEx1oh/CEMObMm3KLmRJOdvhM7o4uQBnwr8pzRK2sJWIqfg=="], "has-property-descriptors": ["has-property-descriptors@1.0.2", "", { "dependencies": { "es-define-property": "^1.0.0" } }, "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg=="], @@ -959,6 +992,8 @@ "imurmurhash": ["imurmurhash@0.1.4", "", {}, "sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA=="], + "indent-string": ["indent-string@4.0.0", "", {}, "sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg=="], + "inline-style-parser": ["inline-style-parser@0.2.7", "", {}, "sha512-Nb2ctOyNR8DqQoR0OwRG95uNWIC0C1lCgf5Naz5H6Ji72KZ8OcFZLz2P5sNgwlyoJ8Yif11oMuYs5pBQa86csA=="], "internal-slot": ["internal-slot@1.1.0", "", { "dependencies": { "es-errors": "^1.3.0", "hasown": "^2.0.2", "side-channel": "^1.1.0" } }, "sha512-4gd7VpWNQNB4UKKCFFVcp1AVv+FMOgs9NKzjHKusc8jTMhd5eL1NqQqOpE0KzMds804/yHlglp3uxgluOqAPLw=="], @@ -1105,6 +1140,8 @@ "lucide-react": ["lucide-react@1.6.0", "", { "peerDependencies": { "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-YxLKVCOF5ZDI1AhKQE5IBYMY9y/Nr4NT15+7QEWpsTSVCdn4vmZhww+6BP76jWYjQx8rSz1Z+gGme1f+UycWEw=="], + "lz-string": ["lz-string@1.5.0", "", { "bin": { "lz-string": "bin/bin.js" } }, "sha512-h5bgJWpxJNswbU7qCrV0tIKQCaS3blPDrqKWx+QxzuzL1zGUzij9XCWLrSLsJPu5t+eWA/ycetzYAO5IOMcWAQ=="], + "magic-string": ["magic-string@0.30.21", "", { "dependencies": { "@jridgewell/sourcemap-codec": "^1.5.5" } }, "sha512-vd2F4YUyEXKGcLHoq+TEyCjxueSeHnFxyyjNp80yg0XV4vUhnDer/lvvlqM/arB5bXQN5K2/3oinyCRyx8T2CQ=="], "markdown-table": ["markdown-table@3.0.4", "", {}, "sha512-wiYz4+JrLyb/DqW2hkFJxP7Vd7JuTDm77fvbM8VfEQdmSMqcImWeeRbHwZjBjIFki/VaMK2BhFi7oUUZeM5bqw=="], @@ -1211,6 +1248,8 @@ "mime-types": ["mime-types@2.1.35", "", { "dependencies": { "mime-db": "1.52.0" } }, "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw=="], + "min-indent": ["min-indent@1.0.1", "", {}, "sha512-I9jwMn07Sy/IwOj3zVkVik2JTvgpaykDZEigL6Rx6N9LbMywwUSMtxET+7lVoDLLd3O3IXwJwvuuns8UB/HeAg=="], + "minimatch": ["minimatch@10.2.4", "", { "dependencies": { "brace-expansion": "^5.0.2" } }, "sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg=="], "minisearch": ["minisearch@7.2.0", "", {}, "sha512-dqT2XBYUOZOiC5t2HRnwADjhNS2cecp9u+TJRiJ1Qp/f5qjkeT5APcGPjHw+bz89Ms8Jp+cG4AlE+QZ/QnDglg=="], @@ -1285,6 +1324,10 @@ "pkg-types": ["pkg-types@1.3.1", "", { "dependencies": { "confbox": "^0.1.8", "mlly": "^1.7.4", "pathe": "^2.0.1" } }, "sha512-/Jm5M4RvtBFVkKWRu2BLUTNP8/M2a+UwuAX+ae4770q1qVGtfjG+WTCupoZixokjmHiry8uI+dlY8KXYV5HVVQ=="], + "playwright": ["playwright@1.59.1", "", { "dependencies": { "playwright-core": "1.59.1" }, "optionalDependencies": { "fsevents": "2.3.2" }, "bin": { "playwright": "cli.js" } }, "sha512-C8oWjPR3F81yljW9o5OxcWzfh6avkVwDD2VYdwIGqTkl+OGFISgypqzfu7dOe4QNLL2aqcWBmI3PMtLIK233lw=="], + + "playwright-core": ["playwright-core@1.59.1", "", { "bin": { "playwright-core": "cli.js" } }, "sha512-HBV/RJg81z5BiiZ9yPzIiClYV/QMsDCKUyogwH9p3MCP6IYjUFu/MActgYAvK0oWyV9NlwM3GLBjADyWgydVyg=="], + "points-on-curve": ["points-on-curve@0.2.0", "", {}, "sha512-0mYKnYYe9ZcqMCWhUjItv/oHjvgEsfKvnUTg8sAtnHr3GVy7rGkXCb6d5cSyqrWqL4k81b9CPg3urd+T7aop3A=="], "points-on-path": ["points-on-path@0.2.1", "", { "dependencies": { "path-data-parser": "0.1.0", "points-on-curve": "0.2.0" } }, "sha512-25ClnWWuw7JbWZcgqY/gJ4FQWadKxGWk+3kR/7kD0tCaDtPPMj7oHu2ToLaVhfpnHrZzYby2w6tUA0eOIuUg8g=="], @@ -1301,6 +1344,8 @@ "prettier-plugin-tailwindcss": ["prettier-plugin-tailwindcss@0.7.2", "", { "peerDependencies": { "@ianvs/prettier-plugin-sort-imports": "*", "@prettier/plugin-hermes": "*", "@prettier/plugin-oxc": "*", "@prettier/plugin-pug": "*", "@shopify/prettier-plugin-liquid": "*", "@trivago/prettier-plugin-sort-imports": "*", "@zackad/prettier-plugin-twig": "*", "prettier": "^3.0", "prettier-plugin-astro": "*", "prettier-plugin-css-order": "*", "prettier-plugin-jsdoc": "*", "prettier-plugin-marko": "*", "prettier-plugin-multiline-arrays": "*", "prettier-plugin-organize-attributes": "*", "prettier-plugin-organize-imports": "*", "prettier-plugin-sort-imports": "*", "prettier-plugin-svelte": "*" }, "optionalPeers": ["@ianvs/prettier-plugin-sort-imports", "@prettier/plugin-hermes", "@prettier/plugin-oxc", "@prettier/plugin-pug", "@shopify/prettier-plugin-liquid", "@trivago/prettier-plugin-sort-imports", "@zackad/prettier-plugin-twig", "prettier-plugin-astro", "prettier-plugin-css-order", "prettier-plugin-jsdoc", "prettier-plugin-marko", "prettier-plugin-multiline-arrays", "prettier-plugin-organize-attributes", "prettier-plugin-organize-imports", "prettier-plugin-sort-imports", "prettier-plugin-svelte"] }, "sha512-LkphyK3Fw+q2HdMOoiEHWf93fNtYJwfamoKPl7UwtjFQdei/iIBoX11G6j706FzN3ymX9mPVi97qIY8328vdnA=="], + "pretty-format": ["pretty-format@27.5.1", "", { "dependencies": { "ansi-regex": "^5.0.1", "ansi-styles": "^5.0.0", "react-is": "^17.0.1" } }, "sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ=="], + "prism-react-renderer": ["prism-react-renderer@2.4.1", "", { "dependencies": { "@types/prismjs": "^1.26.0", "clsx": "^2.0.0" }, "peerDependencies": { "react": ">=16.0.0" } }, "sha512-ey8Ls/+Di31eqzUxC46h8MksNuGx/n0AAC8uKpwFau4RPDYLuE3EXTp8N8G2vX2N7UC/+IXeNUnlWBGGcAG+Ig=="], "prismjs": ["prismjs@1.30.0", "", {}, "sha512-DEvV2ZF2r2/63V+tK8hQvrR2ZGn10srHbXviTlcv7Kpzw8jWiNTqbVgjO3IY8RxrrOUF8VPMQQFysYYYv0YZxw=="], @@ -1345,6 +1390,8 @@ "react-transition-group": ["react-transition-group@4.4.5", "", { "dependencies": { "@babel/runtime": "^7.5.5", "dom-helpers": "^5.0.1", "loose-envify": "^1.4.0", "prop-types": "^15.6.2" }, "peerDependencies": { "react": ">=16.6.0", "react-dom": ">=16.6.0" } }, "sha512-pZcd1MCJoiKiBR2NRxeCRg13uCXbydPnmB4EOeRrY7480qNWO8IIgQG6zlDkm6uRMsURXPuKq0GWtiM59a5Q6g=="], + "redent": ["redent@3.0.0", "", { "dependencies": { "indent-string": "^4.0.0", "strip-indent": "^3.0.0" } }, "sha512-6tDA8g98We0zd0GvVeMT9arEOnTw9qM03L9cJXaCjrip1OO764RDBLBfrB4cwzNGDj5OA5ioymC9GkizgWJDUg=="], + "reflect.getprototypeof": ["reflect.getprototypeof@1.0.10", "", { "dependencies": { "call-bind": "^1.0.8", "define-properties": "^1.2.1", "es-abstract": "^1.23.9", "es-errors": "^1.3.0", "es-object-atoms": "^1.0.0", "get-intrinsic": "^1.2.7", "get-proto": "^1.0.1", "which-builtin-type": "^1.2.1" } }, "sha512-00o4I+DVrefhv+nX0ulyi3biSHCPDe+yLv5o/p6d/UVlirijB8E16FtfwSAi4g3tcqrQ4lRAqQSoFEZJehYEcw=="], "refractor": ["refractor@5.0.0", "", { "dependencies": { "@types/hast": "^3.0.0", "@types/prismjs": "^1.0.0", "hastscript": "^9.0.0", "parse-entities": "^4.0.0" } }, "sha512-QXOrHQF5jOpjjLfiNk5GFnWhRXvxjUVnlFxkeDmewR5sXkr3iM46Zo+CnRR8B+MDVqkULW4EcLVcRBNOPXHosw=="], @@ -1437,6 +1484,8 @@ "stringify-entities": ["stringify-entities@4.0.4", "", { "dependencies": { "character-entities-html4": "^2.0.0", "character-entities-legacy": "^3.0.0" } }, "sha512-IwfBptatlO+QCJUo19AqvrPNqlVMpW9YEL2LIVY+Rpv2qsjCGxaDLNRgeGsQWJhfItebuJhsGSLjaBbNSQ+ieg=="], + "strip-indent": ["strip-indent@3.0.0", "", { "dependencies": { "min-indent": "^1.0.0" } }, "sha512-laJTa3Jb+VQpaC6DseHhF7dXVqHTfJPCRDaEbid/drOhgitgYku/letMUqOXFoWV0zIIUbjpdH2t+tYj4bQMRQ=="], + "style-to-js": ["style-to-js@1.1.21", "", { "dependencies": { "style-to-object": "1.0.14" } }, "sha512-RjQetxJrrUJLQPHbLku6U/ocGtzyjbJMP9lCNK7Ag0CNh690nSH8woqWH9u16nMjYBAok+i7JO1NP2pOy8IsPQ=="], "style-to-object": ["style-to-object@1.0.14", "", { "dependencies": { "inline-style-parser": "0.2.7" } }, "sha512-LIN7rULI0jBscWQYaSswptyderlarFkjQ+t79nzty8tcIAceVomEVlLzH5VP4Cmsv6MtKhs7qaAiwlcp+Mgaxw=="], @@ -1549,6 +1598,8 @@ "web-namespaces": ["web-namespaces@2.0.1", "", {}, "sha512-bKr1DkiNa2krS7qxNtdrtHAmzuYGFQLiQ13TsorsdT6ULTkPLKuu5+GsFpDlg6JFjUTwX2DyhMPG2be8uPrqsQ=="], + "whatwg-mimetype": ["whatwg-mimetype@3.0.0", "", {}, "sha512-nt+N2dzIutVRxARx1nghPKGv1xHikU7HKdfafKkLNLindmPU/ch3U31NOCGGA/dmPcmb1VlofO0vnKAcsm0o/Q=="], + "which": ["which@2.0.2", "", { "dependencies": { "isexe": "^2.0.0" }, "bin": { "node-which": "./bin/node-which" } }, "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA=="], "which-boxed-primitive": ["which-boxed-primitive@1.1.1", "", { "dependencies": { "is-bigint": "^1.1.0", "is-boolean-object": "^1.2.1", "is-number-object": "^1.1.1", "is-string": "^1.1.1", "is-symbol": "^1.1.1" } }, "sha512-TbX3mj8n0odCBFVlY8AxkqcHASw3L60jIuF8jFP78az3C2YhmGvqbHBpAjTRH2/xqYunrJ9g1jSyjCjpoWzIAA=="], @@ -1561,6 +1612,8 @@ "word-wrap": ["word-wrap@1.2.5", "", {}, "sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA=="], + "ws": ["ws@8.20.0", "", { "peerDependencies": { "bufferutil": "^4.0.1", "utf-8-validate": ">=5.0.2" }, "optionalPeers": ["bufferutil", "utf-8-validate"] }, "sha512-sAt8BhgNbzCtgGbt2OxmpuryO63ZoDk/sqaB/znQm94T4fCEsy/yV+7CdC1kJhOU9lboAEU7R3kquuycDoibVA=="], + "yallist": ["yallist@3.1.1", "", {}, "sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g=="], "yaml": ["yaml@1.10.2", "", {}, "sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg=="], @@ -1619,6 +1672,10 @@ "@tailwindcss/oxide-wasm32-wasi/tslib": ["tslib@2.8.1", "", { "bundled": true }, "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w=="], + "@testing-library/dom/aria-query": ["aria-query@5.3.0", "", { "dependencies": { "dequal": "^2.0.3" } }, "sha512-b0P0sZPKtyu8HkeRAfCq0IfURZK+SuwMjY1UXGBU27wpAiTwQAIlq56IbIO+ytk/JjS1fMR14ee5WBBfKi5J6A=="], + + "@testing-library/dom/dom-accessibility-api": ["dom-accessibility-api@0.5.16", "", {}, "sha512-X7BJ2yElsnOJ30pZF4uIIDfBEVgF4XEBxL9Bxhy6dnrm5hkzqmsWHGTiHqRiITNhMyFLyAiWndIJP7Z1NTteDg=="], + "@types/react-i18next/react-i18next": ["react-i18next@16.4.1", "", { "dependencies": { "@babel/runtime": "^7.27.6", "html-parse-stringify": "^3.0.1", "use-sync-external-store": "^1.6.0" }, "peerDependencies": { "i18next": ">= 25.6.2", "react": ">= 16.8.0", "typescript": "^5" }, "optionalPeers": ["typescript"] }, "sha512-GzsYomxb1/uE7nlJm0e1qQ8f+W9I3Xirh9VoycZIahk6C8Pmv/9Fd0ek6zjf1FSgtGLElDGqwi/4FOHEGUbsEQ=="], "@types/react-syntax-highlighter/@types/react": ["@types/react@19.2.7", "", { "dependencies": { "csstype": "^3.2.2" } }, "sha512-MWtvHrGZLFttgeEj28VXHxpmwYbor/ATPYbBfSFZEIRK0ecCFLl2Qo55z52Hss+UV9CRN7trSeq1zbgx7YDWWg=="], @@ -1685,6 +1742,10 @@ "parse-entities/@types/unist": ["@types/unist@2.0.11", "", {}, "sha512-CmBKiL6NNo/OqgmMn95Fk9Whlp2mtvIv+KNpQKN2F4SjvrEesubTRWGYSg+BnWZOnlCaSTU1sMpsBOzgbYhnsA=="], + "parse5/entities": ["entities@6.0.1", "", {}, "sha512-aN97NXWF6AWBTahfVOIrB/NShkzi5H7F9r1s9mD3cDj4Ko5f2qhhVoYMibXF7GlLveb/D2ioWay8lxI97Ven3g=="], + + "pretty-format/react-is": ["react-is@17.0.2", "", {}, "sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w=="], + "react-markdown/unist-util-visit": ["unist-util-visit@5.0.0", "", { "dependencies": { "@types/unist": "^3.0.0", "unist-util-is": "^6.0.0", "unist-util-visit-parents": "^6.0.0" } }, "sha512-MR04uvD+07cwl/yhVuVWAtw+3GOR/knlL55Nd/wAdblk27GCVt3lqpTivy/tkJcZoNPzTwS1Y+KMojlLDhoTzg=="], "react-select/@babel/runtime": ["@babel/runtime@7.28.4", "", {}, "sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ=="], @@ -1703,6 +1764,8 @@ "unist-util-remove-position/unist-util-visit": ["unist-util-visit@5.0.0", "", { "dependencies": { "@types/unist": "^3.0.0", "unist-util-is": "^6.0.0", "unist-util-visit-parents": "^6.0.0" } }, "sha512-MR04uvD+07cwl/yhVuVWAtw+3GOR/knlL55Nd/wAdblk27GCVt3lqpTivy/tkJcZoNPzTwS1Y+KMojlLDhoTzg=="], + "vite/fsevents": ["fsevents@2.3.3", "", { "os": "darwin" }, "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw=="], + "vite/picomatch": ["picomatch@4.0.4", "", {}, "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A=="], "@types/react-i18next/react-i18next/@babel/runtime": ["@babel/runtime@7.28.4", "", {}, "sha512-Q/N6JNWvIvPnLDvjlE1OUBLPQHH6l3CltCEsHIujp45zQUSSh8K+gHnaEX45yAT1nyngnINhvWtzN+Nb9D8RAQ=="], diff --git a/lightrag_webui/package.json b/lightrag_webui/package.json index b8bc031bcd..2499db7414 100644 --- a/lightrag_webui/package.json +++ b/lightrag_webui/package.json @@ -82,12 +82,15 @@ }, "devDependencies": { "@eslint/js": "^10.0.1", + "@playwright/test": "^1.59.1", "@stylistic/eslint-plugin": "^5.10.0", - "@types/bun": "^1.3.12", + "@tailwindcss/typography": "^0.5.15", "@tailwindcss/vite": "^4.2.2", + "@testing-library/jest-dom": "^6.9.1", + "@testing-library/react": "^16.3.2", + "@types/bun": "^1.3.12", "@types/katex": "^0.16.8", "@types/node": "^25.6.0", - "@tailwindcss/typography": "^0.5.15", "@types/react": "^19.2.14", "@types/react-dom": "^19.2.3", "@types/react-i18next": "^8.1.0", @@ -101,12 +104,14 @@ "eslint-plugin-react-refresh": "^0.5.2", "globals": "^17.5.0", "graphology-types": "^0.24.8", + "happy-dom": "^20.9.0", + "playwright": "^1.59.1", "prettier": "^3.8.3", "prettier-plugin-tailwindcss": "^0.7.2", - "typescript-eslint": "^8.58.2", "tailwindcss": "^4.2.2", "tailwindcss-animate": "^1.0.7", "typescript": "~5.9.3", + "typescript-eslint": "^8.58.2", "vite": "^8.0.9" } } diff --git a/lightrag_webui/src/App.tsx b/lightrag_webui/src/App.tsx index b8ae023d7a..bc9af40684 100644 --- a/lightrag_webui/src/App.tsx +++ b/lightrag_webui/src/App.tsx @@ -15,10 +15,13 @@ import GraphViewer from '@/features/GraphViewer' import DocumentManager from '@/features/DocumentManager' import RetrievalTesting from '@/features/RetrievalTesting' import ApiSite from '@/features/ApiSite' +import useWorkspaceChange from '@/hooks/useWorkspaceChange' import { Tabs, TabsContent } from '@/components/ui/Tabs' function App() { + // Handle workspace changes - clears all workspace-dependent data when workspace changes + useWorkspaceChange() const message = useBackendState.use.message() const enableHealthCheck = useSettingsStore.use.enableHealthCheck() const currentTab = useSettingsStore.use.currentTab() diff --git a/lightrag_webui/src/api/lightrag.ts b/lightrag_webui/src/api/lightrag.ts index 5345479ec3..74d0564e3f 100644 --- a/lightrag_webui/src/api/lightrag.ts +++ b/lightrag_webui/src/api/lightrag.ts @@ -5,6 +5,10 @@ import { useSettingsStore } from '@/stores/settings' import { useAuthStore } from '@/stores/state' import { navigationService } from '@/services/navigation' +// CRLF injection sanitization for header values +export const sanitizeHeader = (value: string | null): string | null => + value?.replace(/[\r\n]/g, '') ?? null + // Types export type LightragNodeType = { id: string @@ -78,6 +82,12 @@ export type LightragStatus = { webui_description?: string } +export type Workspace = { + name: string + first_seen: string + last_seen: string +} + export type LightragDocumentsScanProgress = { is_scanning: boolean current_file: string @@ -279,7 +289,7 @@ export const InvalidApiKeyError = 'Invalid API Key' export const RequireApiKeError = 'API Key required' // Axios instance -const axiosInstance = axios.create({ +export const axiosInstance = axios.create({ baseURL: backendBaseUrl, headers: { 'Content-Type': 'application/json' @@ -352,6 +362,11 @@ axiosInstance.interceptors.request.use((config) => { if (apiKey) { config.headers['X-API-Key'] = apiKey } + // Workspace header + const workspace = sanitizeHeader(useSettingsStore.getState().currentWorkspace) + if (workspace) { + config.headers['LIGHTRAG-WORKSPACE'] = workspace + } return config }) @@ -530,6 +545,10 @@ export const queryTextStream = async ( if (apiKey) { headers['X-API-Key'] = apiKey; } + const workspace = sanitizeHeader(useSettingsStore.getState().currentWorkspace); + if (workspace) { + headers['LIGHTRAG-WORKSPACE'] = workspace; + } try { const response = await fetch(`${backendBaseUrl}/query/stream`, { @@ -1180,3 +1199,11 @@ export const getDocumentStatusCounts = async (): Promise = const response = await axiosInstance.get('/documents/status_counts') return response.data } + +export const getWorkspaces = async (): Promise => { + const response = await axiosInstance.get('/workspaces') + if (!response.data?.workspaces || !Array.isArray(response.data.workspaces)) { + throw new Error('Invalid workspaces response') + } + return response.data.workspaces +} diff --git a/lightrag_webui/src/api/workspace.test.ts b/lightrag_webui/src/api/workspace.test.ts new file mode 100644 index 0000000000..82afedccff --- /dev/null +++ b/lightrag_webui/src/api/workspace.test.ts @@ -0,0 +1,209 @@ +/** + * Workspace API Tests + * + * Tests for sanitizeHeader, getWorkspaces, and header injection logic. + * + * IMPORTANT: All imports of modules that depend on localStorage must be dynamic + * (inside beforeAll) because import statements are hoisted before runtime code. + */ + +import { afterEach, beforeAll, describe, expect, test, mock } from 'bun:test' +import type { Workspace } from './lightrag' + +type LightragApiModule = typeof import('./lightrag') +type SettingsModule = typeof import('@/stores/settings') + +const storageMock = () => { + const data = new Map() + return { + getItem: (key: string) => data.get(key) ?? null, + setItem: (key: string, value: string) => data.set(key, value), + removeItem: (key: string) => data.delete(key), + clear: () => data.clear() + } +} + +let api: LightragApiModule +let settings: SettingsModule + +beforeAll(async () => { + Object.defineProperty(globalThis, 'localStorage', { + value: storageMock(), + configurable: true + }) + Object.defineProperty(globalThis, 'sessionStorage', { + value: storageMock(), + configurable: true + }) + + // Dynamic import ensures localStorage is set before module evaluation + api = await import('./lightrag') + settings = await import('@/stores/settings') +}) + +describe('sanitizeHeader', () => { + test('returns null when input is null', () => { + expect(api.sanitizeHeader(null)).toBeNull() + }) + + test('returns the string unchanged when no CRLF characters', () => { + expect(api.sanitizeHeader('my-workspace')).toBe('my-workspace') + expect(api.sanitizeHeader('workspace-123')).toBe('workspace-123') + expect(api.sanitizeHeader('my workspace with spaces')).toBe('my workspace with spaces') + }) + + test('strips \\r from the value', () => { + expect(api.sanitizeHeader('my\rworkspace')).toBe('myworkspace') + expect(api.sanitizeHeader('\rworkspace')).toBe('workspace') + expect(api.sanitizeHeader('workspace\r')).toBe('workspace') + }) + + test('strips \\n from the value', () => { + expect(api.sanitizeHeader('my\nworkspace')).toBe('myworkspace') + expect(api.sanitizeHeader('\nworkspace')).toBe('workspace') + expect(api.sanitizeHeader('workspace\n')).toBe('workspace') + }) + + test('strips both \\r\\n from the value', () => { + expect(api.sanitizeHeader('my\r\nworkspace')).toBe('myworkspace') + expect(api.sanitizeHeader('\r\nworkspace')).toBe('workspace') + expect(api.sanitizeHeader('workspace\r\n')).toBe('workspace') + }) + + test('handles string with multiple CRLF sequences', () => { + expect(api.sanitizeHeader('my\r\nwork\r\nspace')).toBe('myworkspace') + expect(api.sanitizeHeader('\r\n\r\n')).toBe('') + expect(api.sanitizeHeader('a\rb\rc')).toBe('abc') + expect(api.sanitizeHeader('a\nb\nc')).toBe('abc') + expect(api.sanitizeHeader('a\r\nb\r\nc')).toBe('abc') + }) + + test('returns empty string when input is empty string (not null)', () => { + expect(api.sanitizeHeader('')).toBe('') + }) +}) + +describe('getWorkspaces', () => { + const mockWorkspaces: Workspace[] = [ + { name: 'workspace-1', first_seen: '2024-01-01', last_seen: '2024-01-15' }, + { name: 'workspace-2', first_seen: '2024-01-02', last_seen: '2024-01-16' } + ] + + let originalGet: typeof api.axiosInstance.get + + beforeAll(() => { + originalGet = api.axiosInstance.get + }) + + afterEach(() => { + api.axiosInstance.get = originalGet + }) + + test('returns workspace array on valid response { workspaces: [...] }', async () => { + api.axiosInstance.get = mock(async () => ({ + data: { workspaces: mockWorkspaces } + })) as typeof api.axiosInstance.get + + const result = await api.getWorkspaces() + expect(result).toEqual(mockWorkspaces) + }) + + test('throws Error when response has no workspaces field', async () => { + api.axiosInstance.get = mock(async () => ({ + data: {} + })) as typeof api.axiosInstance.get + + await expect(api.getWorkspaces()).rejects.toThrow('Invalid workspaces response') + }) + + test('throws Error when workspaces is not an array', async () => { + api.axiosInstance.get = mock(async () => ({ + data: { workspaces: 'not-an-array' } + })) as typeof api.axiosInstance.get + + await expect(api.getWorkspaces()).rejects.toThrow('Invalid workspaces response') + }) + + test('throws Error when response is null/undefined', async () => { + api.axiosInstance.get = mock(async () => ({ + data: null + })) as typeof api.axiosInstance.get + + await expect(api.getWorkspaces()).rejects.toThrow('Invalid workspaces response') + }) +}) + +describe('Header injection (axios interceptor)', () => { + afterEach(() => { + settings.useSettingsStore.getState().setCurrentWorkspace(null) + }) + + test('when currentWorkspace is null, LIGHTRAG-WORKSPACE header is NOT set', () => { + settings.useSettingsStore.getState().setCurrentWorkspace(null) + + const config = { + headers: {} as Record + } + + const workspace = api.sanitizeHeader(settings.useSettingsStore.getState().currentWorkspace) + expect(workspace).toBeNull() + + if (workspace) { + config.headers['LIGHTRAG-WORKSPACE'] = workspace + } + + expect(config.headers['LIGHTRAG-WORKSPACE']).toBeUndefined() + }) + + test('when currentWorkspace is a string, LIGHTRAG-WORKSPACE header IS set with that value', () => { + settings.useSettingsStore.getState().setCurrentWorkspace('my-test-workspace') + + const config = { + headers: {} as Record + } + + const workspace = api.sanitizeHeader(settings.useSettingsStore.getState().currentWorkspace) + expect(workspace).toBe('my-test-workspace') + + if (workspace) { + config.headers['LIGHTRAG-WORKSPACE'] = workspace + } + + expect(config.headers['LIGHTRAG-WORKSPACE']).toBe('my-test-workspace') + }) + + test('when currentWorkspace contains \\r\\n, header value has them stripped', () => { + settings.useSettingsStore.getState().setCurrentWorkspace('my-workspace\r\ninjected') + + expect(settings.useSettingsStore.getState().currentWorkspace).toBe('my-workspace\r\ninjected') + + const workspace = api.sanitizeHeader(settings.useSettingsStore.getState().currentWorkspace) + if (workspace) { + expect(workspace).toBe('my-workspaceinjected') + expect(workspace).not.toContain('\r') + expect(workspace).not.toContain('\n') + } + }) + + test('header is sanitized via sanitizeHeader before being set', () => { + const testCases = [ + { input: 'workspace\rwith-cr' as string | null, expected: 'workspacewith-cr' }, + { input: 'workspace\nwith-lf' as string | null, expected: 'workspacewith-lf' }, + { input: 'workspace\r\nwith-crlf' as string | null, expected: 'workspacewith-crlf' }, + { input: 'normal-workspace' as string | null, expected: 'normal-workspace' }, + { input: null, expected: null } + ] + + for (const { input, expected } of testCases) { + if (input === null) { + settings.useSettingsStore.getState().setCurrentWorkspace(null) + } else { + settings.useSettingsStore.getState().setCurrentWorkspace(input) + } + + const rawValue = settings.useSettingsStore.getState().currentWorkspace + const sanitizedValue = api.sanitizeHeader(rawValue) + expect(sanitizedValue).toBe(expected) + } + }) +}) diff --git a/lightrag_webui/src/components/WorkspaceSelector.test.tsx b/lightrag_webui/src/components/WorkspaceSelector.test.tsx new file mode 100644 index 0000000000..fa5d4ddb62 --- /dev/null +++ b/lightrag_webui/src/components/WorkspaceSelector.test.tsx @@ -0,0 +1,318 @@ +/** + * WorkspaceSelector Logic Tests + * + * Tests the core logic of the WorkspaceSelector component: + * - Fetch workspaces logic + * - Stale workspace detection + * - Handle change logic + * - Selection handling + * + * This approach tests the business logic without requiring a full DOM environment, + * which provides reliable test coverage for the component's behavior. + */ + +import { beforeAll, beforeEach, describe, expect, test, vi } from 'bun:test' + +// ============================================================================ +// TYPES & MOCKS +// ============================================================================ + +interface Workspace { + name: string + first_seen: string + last_seen: string +} + +// Mock workspace state (used by pure logic tests, not via vi.mock) +let mockCurrentWorkspace: string | null = null + +const mockSetCurrentWorkspace = vi.fn((ws: string | null) => { + mockCurrentWorkspace = ws +}) + +let mockGetWorkspaces: ReturnType + +// NOTE: No vi.mock calls used here. This test uses pure logic functions +// extracted from the component β€” no actual module imports needed. +// Using vi.mock would pollute the module cache for other test files. + +// ============================================================================ +// PURE FUNCTIONS (extracted from component logic) +// ============================================================================ + +/** + * Check if a workspace is stale (not in the current list) + */ +const isWorkspaceStale = ( + currentWorkspace: string | null, + workspaces: Workspace[] +): boolean => { + if (!currentWorkspace) return false + return !workspaces.some((w) => w.name === currentWorkspace) +} + +/** + * Handle selection change - convert empty string to null + */ +const handleChangeValue = (value: string): string | null => { + return value === '' ? null : value +} + +/** + * Fetch workspaces with stale detection + */ +const fetchWorkspacesLogic = async ( + getWorkspaces: () => Promise, + getCurrentWorkspace: () => string | null, + setCurrentWorkspace: (ws: string | null) => void +): Promise => { + const data = await getWorkspaces() + + // Check if current workspace is still in the list + const current = getCurrentWorkspace() + if (current && !data.some((w) => w.name === current)) { + setCurrentWorkspace(null) + } + + return data +} + +/** + * Create workspace objects for testing + */ +const createWorkspaces = (names: string[]): Workspace[] => + names.map((name) => ({ + name, + first_seen: new Date().toISOString(), + last_seen: new Date().toISOString() + })) + +// ============================================================================ +// TESTS +// ============================================================================ + +describe('WorkspaceSelector Logic', () => { + beforeAll(() => { + mockGetWorkspaces = vi.fn() + }) + + beforeEach(() => { + vi.clearAllMocks() + mockCurrentWorkspace = null + }) + + describe('isWorkspaceStale', () => { + test('returns false when currentWorkspace is null', () => { + const workspaces = createWorkspaces(['workspace-a', 'workspace-b']) + expect(isWorkspaceStale(null, workspaces)).toBe(false) + }) + + test('returns false when currentWorkspace exists in list', () => { + const workspaces = createWorkspaces(['workspace-a', 'workspace-b']) + expect(isWorkspaceStale('workspace-a', workspaces)).toBe(false) + }) + + test('returns true when currentWorkspace does not exist in list', () => { + const workspaces = createWorkspaces(['workspace-a', 'workspace-b']) + expect(isWorkspaceStale('workspace-x', workspaces)).toBe(true) + }) + + test('returns false for empty string (matches component behavior with short-circuit)', () => { + // Empty string is falsy, so component's `if (current && ...)` short-circuits + const workspaces = createWorkspaces(['workspace-a', 'workspace-b']) + expect(isWorkspaceStale('', workspaces)).toBe(false) + }) + }) + + describe('handleChangeValue', () => { + test('returns null when value is empty string', () => { + expect(handleChangeValue('')).toBe(null) + }) + + test('returns the value when not empty', () => { + expect(handleChangeValue('workspace-a')).toBe('workspace-a') + }) + + test('preserves workspace names with special characters', () => { + expect(handleChangeValue('my-workspace_v1')).toBe('my-workspace_v1') + expect(handleChangeValue('workspace.with.dots')).toBe('workspace.with.dots') + }) + }) + + describe('fetchWorkspacesLogic', () => { + test('fetches workspaces from API', async () => { + const workspaces = createWorkspaces(['workspace-a', 'workspace-b']) + mockGetWorkspaces.mockResolvedValueOnce(workspaces) + + const result = await fetchWorkspacesLogic( + mockGetWorkspaces, + () => null, + mockSetCurrentWorkspace + ) + + expect(mockGetWorkspaces).toHaveBeenCalledTimes(1) + expect(result).toEqual(workspaces) + }) + + test('resets current workspace when it is stale', async () => { + mockCurrentWorkspace = 'workspace-x' + const workspaces = createWorkspaces(['workspace-a', 'workspace-b']) + mockGetWorkspaces.mockResolvedValueOnce(workspaces) + + await fetchWorkspacesLogic( + mockGetWorkspaces, + () => mockCurrentWorkspace, + mockSetCurrentWorkspace + ) + + expect(mockSetCurrentWorkspace).toHaveBeenCalledWith(null) + }) + + test('does not reset current workspace when it exists in list', async () => { + mockCurrentWorkspace = 'workspace-a' + const workspaces = createWorkspaces(['workspace-a', 'workspace-b']) + mockGetWorkspaces.mockResolvedValueOnce(workspaces) + + await fetchWorkspacesLogic( + mockGetWorkspaces, + () => mockCurrentWorkspace, + mockSetCurrentWorkspace + ) + + expect(mockSetCurrentWorkspace).not.toHaveBeenCalled() + }) + + test('does not reset when current workspace is null', async () => { + mockCurrentWorkspace = null + const workspaces = createWorkspaces(['workspace-a', 'workspace-b']) + mockGetWorkspaces.mockResolvedValueOnce(workspaces) + + await fetchWorkspacesLogic( + mockGetWorkspaces, + () => null, + mockSetCurrentWorkspace + ) + + expect(mockSetCurrentWorkspace).not.toHaveBeenCalled() + }) + + test('propagates API errors', async () => { + mockGetWorkspaces.mockRejectedValueOnce(new Error('Network error')) + + await expect( + fetchWorkspacesLogic( + mockGetWorkspaces, + () => null, + mockSetCurrentWorkspace + ) + ).rejects.toThrow('Network error') + }) + }) + + describe('Workspace Data Creation', () => { + test('creates workspace objects with correct structure', () => { + const workspaces = createWorkspaces(['test-workspace']) + + expect(workspaces).toHaveLength(1) + expect(workspaces[0]).toHaveProperty('name', 'test-workspace') + expect(workspaces[0]).toHaveProperty('first_seen') + expect(workspaces[0]).toHaveProperty('last_seen') + }) + + test('creates multiple workspaces', () => { + const workspaces = createWorkspaces(['ws1', 'ws2', 'ws3']) + + expect(workspaces).toHaveLength(3) + expect(workspaces.map((w) => w.name)).toEqual(['ws1', 'ws2', 'ws3']) + }) + + test('each workspace has unique timestamps', () => { + const workspaces = createWorkspaces(['ws1', 'ws2']) + + // Timestamps should be ISO strings + expect(workspaces[0].first_seen).toMatch(/^\d{4}-\d{2}-\d{2}T/) + expect(workspaces[0].last_seen).toMatch(/^\d{4}-\d{2}-\d{2}T/) + }) + }) + + describe('Integration: Complete Workspace Flow', () => { + test('complete flow: fetch -> stale check -> reset', async () => { + // Simulate: user has 'old-workspace' selected + mockCurrentWorkspace = 'old-workspace' + + // API returns only new workspaces (no 'old-workspace') + const workspaces = createWorkspaces(['new-workspace-1', 'new-workspace-2']) + mockGetWorkspaces.mockResolvedValueOnce(workspaces) + + // Simulate the component's fetch logic + const data = await fetchWorkspacesLogic( + mockGetWorkspaces, + () => mockCurrentWorkspace, + mockSetCurrentWorkspace + ) + + // Verify: workspace was fetched + expect(data).toEqual(workspaces) + + // Verify: stale workspace was detected and reset + expect(mockSetCurrentWorkspace).toHaveBeenCalledWith(null) + }) + + test('complete flow: fetch -> no stale -> no reset', async () => { + // Simulate: user has 'current-workspace' selected + mockCurrentWorkspace = 'current-workspace' + + // API returns workspaces including 'current-workspace' + const workspaces = createWorkspaces(['current-workspace', 'other-workspace']) + mockGetWorkspaces.mockResolvedValueOnce(workspaces) + + // Simulate the component's fetch logic + const data = await fetchWorkspacesLogic( + mockGetWorkspaces, + () => mockCurrentWorkspace, + mockSetCurrentWorkspace + ) + + // Verify: workspace was fetched + expect(data).toEqual(workspaces) + + // Verify: no reset occurred + expect(mockSetCurrentWorkspace).not.toHaveBeenCalled() + }) + + test('complete flow: no workspace selected', async () => { + // Simulate: no workspace selected + mockCurrentWorkspace = null + + // API returns workspaces + const workspaces = createWorkspaces(['workspace-a', 'workspace-b']) + mockGetWorkspaces.mockResolvedValueOnce(workspaces) + + // Simulate the component's fetch logic + const data = await fetchWorkspacesLogic( + mockGetWorkspaces, + () => null, + mockSetCurrentWorkspace + ) + + // Verify: workspace was fetched + expect(data).toEqual(workspaces) + + // Verify: no reset occurred (nothing to reset) + expect(mockSetCurrentWorkspace).not.toHaveBeenCalled() + }) + }) + + describe('Refresh Interval Logic', () => { + const REFRESH_INTERVAL_MS = 30_000 + + test('refresh interval is 30 seconds', () => { + expect(REFRESH_INTERVAL_MS).toBe(30_000) + }) + + test('refresh interval calculation for 1 minute', () => { + const oneMinute = 60_000 + expect(oneMinute / REFRESH_INTERVAL_MS).toBe(2) + }) + }) +}) diff --git a/lightrag_webui/src/components/WorkspaceSelector.tsx b/lightrag_webui/src/components/WorkspaceSelector.tsx new file mode 100644 index 0000000000..601729ab5a --- /dev/null +++ b/lightrag_webui/src/components/WorkspaceSelector.tsx @@ -0,0 +1,75 @@ +import { useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useSettingsStore } from '@/stores/settings' +import { getWorkspaces, type Workspace } from '@/api/lightrag' +import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/Tooltip' +import { cn } from '@/lib/utils' + +const REFRESH_INTERVAL_MS = 30_000 + +interface WorkspaceSelectorProps { + className?: string +} + +export default function WorkspaceSelector({ className }: WorkspaceSelectorProps) { + const { t } = useTranslation() + const currentWorkspace = useSettingsStore.use.currentWorkspace() + const setCurrentWorkspace = useSettingsStore.use.setCurrentWorkspace() + const [workspaces, setWorkspaces] = useState([]) + const [isLoading, setIsLoading] = useState(false) + + const fetchWorkspaces = async () => { + setIsLoading(true) + try { + const data = await getWorkspaces() + setWorkspaces(data) + // Check if current workspace is still in the list + const current = useSettingsStore.getState().currentWorkspace + if (current && !data.some((w: { name: string }) => w.name === current)) { + useSettingsStore.getState().setCurrentWorkspace(null) + } + } catch { + // Graceful degradation: keep empty list, show only "None" option + setWorkspaces([]) + } finally { + setIsLoading(false) + } + } + + useEffect(() => { + fetchWorkspaces() + const interval = setInterval(fetchWorkspaces, REFRESH_INTERVAL_MS) + return () => clearInterval(interval) + }, []) + + const handleChange = (e: React.ChangeEvent) => { + const value = e.target.value + setCurrentWorkspace(value === '' ? null : value) + } + + return ( + + + + + + + {t('workspace.selector', 'Workspace')} + + + + ) +} diff --git a/lightrag_webui/src/components/graph/GraphLabels.tsx b/lightrag_webui/src/components/graph/GraphLabels.tsx index 3c0ee1ae2c..0b73a88a2e 100644 --- a/lightrag_webui/src/components/graph/GraphLabels.tsx +++ b/lightrag_webui/src/components/graph/GraphLabels.tsx @@ -116,6 +116,40 @@ const GraphLabels = () => { } }, []) + // Watch workspace refresh trigger to reset label state when workspace changes + const workspaceRefreshTrigger = useSettingsStore.use.workspaceRefreshTrigger() + + useEffect(() => { + if (workspaceRefreshTrigger === 0) return + + const loadLabels = async () => { + SearchHistoryManager.clearHistory() + + const fetchPopularLabels = async () => { + try { + const popularLabels = await getPopularLabels(popularLabelsDefaultLimit) + + if (popularLabels.length === 0) { + const fallbackLabels = ['entity', 'relationship', 'document', 'concept'] + await SearchHistoryManager.initializeWithDefaults(fallbackLabels) + } else { + await SearchHistoryManager.initializeWithDefaults(popularLabels) + } + } catch (error) { + console.error('Failed to re-fetch popular labels after workspace change:', error) + const fallbackLabels = ['entity', 'relationship', 'document'] + SearchHistoryManager.clearHistory() + await SearchHistoryManager.initializeWithDefaults(fallbackLabels) + } + } + + await fetchPopularLabels() + bumpDropdownData({ forceSelectKey: true }) + } + + loadLabels() + }, [workspaceRefreshTrigger, bumpDropdownData]) + const fetchData = useCallback( async (query?: string): Promise => { let results: string[]; diff --git a/lightrag_webui/src/features/DocumentManager.tsx b/lightrag_webui/src/features/DocumentManager.tsx index 51a90d3b5e..e082bd7d42 100644 --- a/lightrag_webui/src/features/DocumentManager.tsx +++ b/lightrag_webui/src/features/DocumentManager.tsx @@ -291,6 +291,8 @@ export default function DocumentManager() { const activeRefreshPromiseRef = useRef | null>(null); const pendingRefreshRequestRef = useRef(null); const latestRefreshRequestVersionRef = useRef(0); + // Ref to store enqueueRefresh function for workspace change effect + const enqueueRefreshRef = useRef<((request: RefreshRequest) => Promise) | null>(null); // Add retry mechanism state const [retryState, setRetryState] = useState({ @@ -583,6 +585,62 @@ export default function DocumentManager() { }; }, [docs]); + // Watch workspace refresh trigger to reset document state when workspace changes + const workspaceRefreshTrigger = useSettingsStore.use.workspaceRefreshTrigger() + + useEffect(() => { + if (workspaceRefreshTrigger > 0) { + console.log('Workspace changed, clearing document state') + + // Get current page size from store to avoid stale closure + const currentPageSize = useSettingsStore.getState().documentsPageSize + + // Clear document state + setDocs(null) + setCurrentPageDocs([]) + setPagination({ + page: 1, + page_size: currentPageSize, + total_count: 0, + total_pages: 0, + has_next: false, + has_prev: false + }) + setStatusCounts({ all: 0 }) + setSelectedDocIds([]) + + // Stop any ongoing polling by directly manipulating the interval ref + if (pollingIntervalRef.current) { + clearInterval(pollingIntervalRef.current) + pollingIntervalRef.current = null + } + + // Reset version counter to ensure new fetch is not ignored + latestRefreshRequestVersionRef.current += 1 + + // Trigger a new fetch using the ref to avoid dependency issues + const versionRef = latestRefreshRequestVersionRef + const enqueueRef = enqueueRefreshRef + + // Postpone the refresh slightly to allow state clearing to complete + setTimeout(() => { + if (enqueueRef.current) { + enqueueRef.current({ + type: 'intelligent', + query: { + statusFilter: 'all', + page: 1, + pageSize: currentPageSize, + sortField: 'updated_at', + sortDirection: 'desc' + }, + requestVersion: versionRef.current + }) + } + }, 0) + } + }, [workspaceRefreshTrigger]) + const buildQuerySnapshot = useCallback(( overrides: Partial = {} ): QuerySnapshot => ({ @@ -852,6 +910,11 @@ export default function DocumentManager() { } }, [runRefreshRequest]); + // Store enqueueRefresh in ref for use in workspace change effect + useEffect(() => { + enqueueRefreshRef.current = enqueueRefresh + }, [enqueueRefresh]) + // Intelligent refresh function: handles all boundary cases const handleIntelligentRefresh = useCallback(async ( targetPage?: number, diff --git a/lightrag_webui/src/features/RetrievalTesting.tsx b/lightrag_webui/src/features/RetrievalTesting.tsx index 2865ec1276..63d8ab1fd4 100644 --- a/lightrag_webui/src/features/RetrievalTesting.tsx +++ b/lightrag_webui/src/features/RetrievalTesting.tsx @@ -528,15 +528,14 @@ export default function RetrievalTesting() { const messagesEndRef = useRef(null) const messagesContainerRef = useRef(null) - // Add cleanup effect for memory leak prevention + // Watch workspace refresh trigger to clear messages when workspace changes + const workspaceRefreshTrigger = useSettingsStore.use.workspaceRefreshTrigger() + useEffect(() => { - // Component cleanup - reset timer state to prevent memory leaks - return () => { - if (thinkingStartTime.current) { - thinkingStartTime.current = null; - } - }; - }, []); + if (workspaceRefreshTrigger > 0) { + setMessages([]) + } + }, [workspaceRefreshTrigger]) // Add event listeners to detect when user manually interacts with the container useEffect(() => { diff --git a/lightrag_webui/src/features/SiteHeader.tsx b/lightrag_webui/src/features/SiteHeader.tsx index 1e0c1027b7..3bd9ebda19 100644 --- a/lightrag_webui/src/features/SiteHeader.tsx +++ b/lightrag_webui/src/features/SiteHeader.tsx @@ -1,6 +1,7 @@ import Button from '@/components/ui/Button' import { SiteInfo, webuiPrefix } from '@/lib/constants' import AppSettings from '@/components/AppSettings' +import WorkspaceSelector from '@/components/WorkspaceSelector' import { TabsList, TabsTrigger } from '@/components/ui/Tabs' import { useSettingsStore } from '@/stores/settings' import { useAuthStore } from '@/stores/state' @@ -99,6 +100,10 @@ export default function SiteHeader() { )} + {!webuiTitle && ( + | + )} +
diff --git a/lightrag_webui/src/hooks/useLightragGraph.tsx b/lightrag_webui/src/hooks/useLightragGraph.tsx index 4e4ef87160..1b9e37899f 100644 --- a/lightrag_webui/src/hooks/useLightragGraph.tsx +++ b/lightrag_webui/src/hooks/useLightragGraph.tsx @@ -373,8 +373,12 @@ const useLightrangeGraph = () => { toast.info(t('graphPanel.dataIsTruncated', 'Graph data is truncated to Max Nodes')); } - // Reset state - state.reset() + // Clear selection and graph data, but preserve fetch flags to avoid re-triggering the fetch useEffect + state.clearSelection() + state.setRawGraph(null) + state.setSigmaGraph(null) + state.setSearchEngine(null) + state.setGraphIsEmpty(false) // Check if data is empty or invalid if (!data || !data.nodes || data.nodes.length === 0) { diff --git a/lightrag_webui/src/hooks/useWorkspaceChange.test.ts b/lightrag_webui/src/hooks/useWorkspaceChange.test.ts new file mode 100644 index 0000000000..e8e66fb08c --- /dev/null +++ b/lightrag_webui/src/hooks/useWorkspaceChange.test.ts @@ -0,0 +1,342 @@ +import { beforeAll, describe, expect, test } from 'bun:test' + +const storageMock = () => { + const data = new Map() + + return { + getItem: (key: string) => data.get(key) ?? null, + setItem: (key: string, value: string) => { + data.set(key, value) + }, + removeItem: (key: string) => { + data.delete(key) + }, + clear: () => { + data.clear() + } + } +} + +// ============================================ +// Pure Logic Functions (extracted from hook) +// ============================================ + +/** + * Determines if a workspace refresh should be triggered. + * Matches the condition in useWorkspaceChange: previous !== current + */ +const shouldTriggerRefresh = ( + previous: string | null, + current: string | null +): boolean => { + return previous !== current +} + +/** + * Executes the actions that should happen on workspace change. + * These are the side effects extracted from useWorkspaceChange. + */ +const executeWorkspaceChangeActions = ( + settingsStore: { setRetrievalHistory: (h: unknown[]) => void }, + graphStore: { reset: () => void } +) => { + graphStore.reset() + settingsStore.setRetrievalHistory([]) +} + +// ============================================ +// Tests: shouldTriggerRefresh (Pure Logic) +// ============================================ + +describe('shouldTriggerRefresh (pure logic)', () => { + test('Same workspace (null β†’ null): returns false', () => { + expect(shouldTriggerRefresh(null, null)).toBe(false) + }) + + test('Same workspace ("ws" β†’ "ws"): returns false', () => { + expect(shouldTriggerRefresh('ws', 'ws')).toBe(false) + }) + + test('Changed workspace (null β†’ "ws"): returns true', () => { + expect(shouldTriggerRefresh(null, 'ws')).toBe(true) + }) + + test('Changed workspace ("ws" β†’ null): returns true', () => { + expect(shouldTriggerRefresh('ws', null)).toBe(true) + }) + + test('Changed workspace ("ws-a" β†’ "ws-b"): returns true', () => { + expect(shouldTriggerRefresh('ws-a', 'ws-b')).toBe(true) + }) + + test('Changed workspace ("ws-a" β†’ "ws-a"): returns false', () => { + expect(shouldTriggerRefresh('ws-a', 'ws-a')).toBe(false) + }) +}) + +// ============================================ +// Integration Tests with Real Stores +// ============================================ + +describe('Workspace change integration with real stores', () => { + let useSettingsStore: ReturnType + let useGraphStore: ReturnType + + beforeAll(async () => { + Object.defineProperty(globalThis, 'localStorage', { + value: storageMock(), + configurable: true + }) + Object.defineProperty(globalThis, 'sessionStorage', { + value: storageMock(), + configurable: true + }) + + const [settingsModule, graphModule] = await Promise.all([ + import('@/stores/settings'), + import('@/stores/graph') + ]) + + useSettingsStore = settingsModule.useSettingsStore + useGraphStore = graphModule.useGraphStore + }) + + describe('executeWorkspaceChangeActions', () => { + test('Calls graphStore.reset() β€” verify graph state is reset', () => { + // Set up graph state with some data + useGraphStore.setState({ + selectedNode: 'test-node', + focusedNode: 'test-focused', + selectedEdge: 'test-edge', + isFetching: true, + graphDataFetchAttempted: true, + labelsFetchAttempted: true + }) + + // Verify state is set + expect(useGraphStore.getState().selectedNode).toBe('test-node') + expect(useGraphStore.getState().isFetching).toBe(true) + expect(useGraphStore.getState().graphDataFetchAttempted).toBe(true) + + // Execute the workspace change action + executeWorkspaceChangeActions( + useSettingsStore.getState(), + useGraphStore.getState() + ) + + // Verify graph state is reset + const graphState = useGraphStore.getState() + expect(graphState.selectedNode).toBeNull() + expect(graphState.focusedNode).toBeNull() + expect(graphState.selectedEdge).toBeNull() + expect(graphState.isFetching).toBe(false) + expect(graphState.graphDataFetchAttempted).toBe(false) + expect(graphState.labelsFetchAttempted).toBe(false) + }) + + test('Calls settingsStore.setRetrievalHistory([]) β€” verify retrieval history is empty', () => { + // Set up retrieval history with messages + const testMessages = [ + { role: 'user' as const, content: 'Hello' }, + { role: 'assistant' as const, content: 'Hi there' } + ] + useSettingsStore.setState({ retrievalHistory: testMessages }) + + // Verify state is set + expect(useSettingsStore.getState().retrievalHistory).toEqual(testMessages) + + // Execute the workspace change action + executeWorkspaceChangeActions( + useSettingsStore.getState(), + useGraphStore.getState() + ) + + // Verify retrieval history is empty + expect(useSettingsStore.getState().retrievalHistory).toEqual([]) + }) + + test('Both actions called together on workspace change', () => { + // Set up both stores with data + useGraphStore.setState({ + selectedNode: 'some-node', + isFetching: true + }) + const testMessages = [ + { role: 'user' as const, content: 'Test message' } + ] + useSettingsStore.setState({ retrievalHistory: testMessages }) + + // Execute workspace change actions + executeWorkspaceChangeActions( + useSettingsStore.getState(), + useGraphStore.getState() + ) + + // Verify both were cleared + const graphState = useGraphStore.getState() + expect(graphState.selectedNode).toBeNull() + expect(graphState.isFetching).toBe(false) + + expect(useSettingsStore.getState().retrievalHistory).toEqual([]) + }) + }) + + describe('Integration: full workspace change scenario', () => { + test('Populate graph state, execute workspace change, verify reset', () => { + // Reset to clean state first + useGraphStore.setState({ + selectedNode: null, + focusedNode: null, + selectedEdge: null, + isFetching: false, + graphDataFetchAttempted: false, + labelsFetchAttempted: false + }) + useSettingsStore.setState({ retrievalHistory: [] }) + + // Populate graph state + useGraphStore.setState({ + selectedNode: 'entity-123', + focusedNode: 'entity-456', + selectedEdge: 'edge-789', + isFetching: true, + graphDataFetchAttempted: true, + labelsFetchAttempted: true + }) + + // Set retrieval history with messages + const messages = [ + { role: 'user' as const, content: 'Query 1' }, + { role: 'assistant' as const, content: 'Response 1' }, + { role: 'user' as const, content: 'Query 2' } + ] + useSettingsStore.setState({ retrievalHistory: messages }) + + // Execute workspace change (actions from hook) + executeWorkspaceChangeActions( + useSettingsStore.getState(), + useGraphStore.getState() + ) + + // Note: triggerWorkspaceRefresh() is called separately in the hook + // We verify the clear actions here + + // Verify graph state is reset + const graphState = useGraphStore.getState() + expect(graphState.selectedNode).toBeNull() + expect(graphState.focusedNode).toBeNull() + expect(graphState.selectedEdge).toBeNull() + expect(graphState.isFetching).toBe(false) + expect(graphState.graphDataFetchAttempted).toBe(false) + expect(graphState.labelsFetchAttempted).toBe(false) + + // Verify retrieval history is empty + expect(useSettingsStore.getState().retrievalHistory).toEqual([]) + }) + }) + + describe('Rapid switching simulation', () => { + test('Simulate: null β†’ ws-a β†’ ws-b β†’ null in sequence', () => { + // Reset to clean state + useGraphStore.setState({ + selectedNode: null, + focusedNode: null, + selectedEdge: null, + isFetching: false, + graphDataFetchAttempted: false, + labelsFetchAttempted: false + }) + useSettingsStore.setState({ retrievalHistory: [] }) + + const workspaceSequence: (string | null)[] = [null, 'ws-a', 'ws-b', null] + + for (let i = 1; i < workspaceSequence.length; i++) { + const previous = workspaceSequence[i - 1] + const current = workspaceSequence[i] + + // Only trigger if workspace changed + if (shouldTriggerRefresh(previous, current)) { + // Set up data before change + useGraphStore.setState({ + selectedNode: `node-${previous ?? 'null'}`, + isFetching: true + }) + useSettingsStore.setState({ + retrievalHistory: [{ role: 'user' as const, content: `msg-${previous ?? 'null'}` }] + }) + + // Execute workspace change actions + executeWorkspaceChangeActions( + useSettingsStore.getState(), + useGraphStore.getState() + ) + + // Verify actions were triggered + expect(useGraphStore.getState().selectedNode).toBeNull() + expect(useGraphStore.getState().isFetching).toBe(false) + expect(useSettingsStore.getState().retrievalHistory).toEqual([]) + } + } + + // Final state should be consistent + const finalState = useGraphStore.getState() + expect(finalState.selectedNode).toBeNull() + expect(finalState.focusedNode).toBeNull() + expect(finalState.selectedEdge).toBeNull() + expect(finalState.isFetching).toBe(false) + expect(finalState.graphDataFetchAttempted).toBe(false) + expect(finalState.labelsFetchAttempted).toBe(false) + + expect(useSettingsStore.getState().retrievalHistory).toEqual([]) + }) + + test('After each change, verify actions were triggered for distinct workspaces', () => { + // Reset to clean state + useGraphStore.setState({ + selectedNode: null, + isFetching: false, + graphDataFetchAttempted: false, + labelsFetchAttempted: false + }) + useSettingsStore.setState({ retrievalHistory: [] }) + + // Test each distinct workspace change + const changes = [ + { from: null, to: 'workspace-1' }, + { from: 'workspace-1', to: 'workspace-2' }, + { from: 'workspace-2', to: 'workspace-3' }, + { from: 'workspace-3', to: null } + ] + + changes.forEach(({ from, to }) => { + // Verify shouldTriggerRefresh returns true + expect(shouldTriggerRefresh(from, to)).toBe(true) + + // Set up data + useGraphStore.setState({ + selectedNode: `node-${from ?? 'null'}`, + isFetching: true, + graphDataFetchAttempted: true, + labelsFetchAttempted: true + }) + useSettingsStore.setState({ + retrievalHistory: [ + { role: 'user' as const, content: `message for ${from ?? 'null'}` } + ] + }) + + // Execute workspace change + executeWorkspaceChangeActions( + useSettingsStore.getState(), + useGraphStore.getState() + ) + + // Verify actions were triggered + expect(useGraphStore.getState().selectedNode).toBeNull() + expect(useGraphStore.getState().isFetching).toBe(false) + expect(useGraphStore.getState().graphDataFetchAttempted).toBe(false) + expect(useSettingsStore.getState().retrievalHistory).toEqual([]) + }) + }) + }) +}) diff --git a/lightrag_webui/src/hooks/useWorkspaceChange.ts b/lightrag_webui/src/hooks/useWorkspaceChange.ts new file mode 100644 index 0000000000..42dd246dfa --- /dev/null +++ b/lightrag_webui/src/hooks/useWorkspaceChange.ts @@ -0,0 +1,53 @@ +import { useEffect, useRef } from 'react' +import { useSettingsStore } from '@/stores/settings' +import { defaultQueryLabel } from '@/lib/constants' +import { useGraphStore } from '@/stores/graph' + +/** + * Hook to handle workspace change events. + * When the workspace changes, this hook clears all workspace-dependent data: + * - Graph store state + * - Graph fetch attempt flags + * - Retrieval history + * + * This should be called once at the App level. + */ +const useWorkspaceChange = () => { + const currentWorkspace = useSettingsStore.use.currentWorkspace() + const triggerWorkspaceRefresh = useSettingsStore.use.triggerWorkspaceRefresh() + + // Use ref to track previous workspace value to avoid infinite loops + const previousWorkspaceRef = useRef(currentWorkspace) + + useEffect(() => { + // Only trigger refresh if workspace actually changed + if (previousWorkspaceRef.current !== currentWorkspace) { + const previousWorkspace = previousWorkspaceRef.current + previousWorkspaceRef.current = currentWorkspace + + console.log(`Workspace changed from "${previousWorkspace}" to "${currentWorkspace}", refreshing...`) + + // 1. Clear graph store state + const graphState = useGraphStore.getState() + graphState.reset() + + // 2. Increment graphDataVersion to trigger the fetch useEffect + // reset() sets graphDataFetchAttempted=false but doesn't change any + // useEffect dependency, so we must bump the version to trigger re-fetch + graphState.incrementGraphDataVersion() + + // 3. Reset queryLabel to default so the graph fetches for the new workspace + // When graph data is empty, the fetch handler clears queryLabel to '' + // Without resetting it, the fetch useEffect won't enter the fetch path + useSettingsStore.getState().setQueryLabel(defaultQueryLabel) + + // 4. Clear retrieval history + useSettingsStore.getState().setRetrievalHistory([]) + + // 5. Trigger workspace refresh for DocumentManager and other components + triggerWorkspaceRefresh() + } + }, [currentWorkspace, triggerWorkspaceRefresh]) +} + +export default useWorkspaceChange diff --git a/lightrag_webui/src/locales/en.json b/lightrag_webui/src/locales/en.json index 52677dded4..7bb2574bf0 100644 --- a/lightrag_webui/src/locales/en.json +++ b/lightrag_webui/src/locales/en.json @@ -1,4 +1,8 @@ { + "workspace": { + "none": "None", + "selector": "Workspace" + }, "settings": { "language": "Language", "theme": "Theme", diff --git a/lightrag_webui/src/stores/graph.test.ts b/lightrag_webui/src/stores/graph.test.ts new file mode 100644 index 0000000000..329f12a8a7 --- /dev/null +++ b/lightrag_webui/src/stores/graph.test.ts @@ -0,0 +1,279 @@ +import { beforeAll, describe, expect, test } from 'bun:test' + +const storageMock = () => { + const data = new Map() + + return { + getItem: (key: string) => data.get(key) ?? null, + setItem: (key: string, value: string) => { + data.set(key, value) + }, + removeItem: (key: string) => { + data.delete(key) + }, + clear: () => { + data.clear() + } + } +} + +let useGraphStore: ReturnType + +beforeAll(async () => { + // Try direct import first (graph store doesn't use persist middleware) + try { + const module = await import('@/stores/graph') + useGraphStore = module.useGraphStore + } catch { + // Fallback: mock localStorage if needed + Object.defineProperty(globalThis, 'localStorage', { + value: storageMock(), + configurable: true + }) + Object.defineProperty(globalThis, 'sessionStorage', { + value: storageMock(), + configurable: true + }) + const module = await import('@/stores/graph') + useGraphStore = module.useGraphStore + } +}) + +describe('graphDataVersion state', () => { + test('default value is 0', () => { + // Reset first + useGraphStore.setState({ graphDataVersion: 0 }) + expect(useGraphStore.getState().graphDataVersion).toBe(0) + }) + + test('incrementGraphDataVersion() increments by 1', () => { + // Reset first + useGraphStore.setState({ graphDataVersion: 0 }) + + useGraphStore.getState().incrementGraphDataVersion() + + expect(useGraphStore.getState().graphDataVersion).toBe(1) + }) + + test('multiple calls increment correctly (0 β†’ 1 β†’ 2 β†’ 3)', () => { + // Reset first + useGraphStore.setState({ graphDataVersion: 0 }) + + useGraphStore.getState().incrementGraphDataVersion() + expect(useGraphStore.getState().graphDataVersion).toBe(1) + + useGraphStore.getState().incrementGraphDataVersion() + expect(useGraphStore.getState().graphDataVersion).toBe(2) + + useGraphStore.getState().incrementGraphDataVersion() + expect(useGraphStore.getState().graphDataVersion).toBe(3) + }) +}) + +describe('graph.reset() - completeness', () => { + test('resets all graph state fields to defaults', () => { + // Setup: populate state with various values + useGraphStore.setState({ + selectedNode: 'node-1', + focusedNode: 'node-2', + selectedEdge: 'edge-1', + focusedEdge: 'edge-2', + rawGraph: { nodes: [], edges: [] } as any, + sigmaGraph: {} as any, + searchEngine: {} as any, + moveToSelectedNode: true, + graphIsEmpty: true, + isFetching: true, + graphDataFetchAttempted: true, + labelsFetchAttempted: true + }) + + // Verify state is populated + const populatedState = useGraphStore.getState() + expect(populatedState.selectedNode).toBe('node-1') + expect(populatedState.focusedNode).toBe('node-2') + expect(populatedState.selectedEdge).toBe('edge-1') + expect(populatedState.focusedEdge).toBe('edge-2') + expect(populatedState.rawGraph).not.toBeNull() + expect(populatedState.sigmaGraph).not.toBeNull() + expect(populatedState.searchEngine).not.toBeNull() + expect(populatedState.moveToSelectedNode).toBe(true) + expect(populatedState.graphIsEmpty).toBe(true) + expect(populatedState.isFetching).toBe(true) + expect(populatedState.graphDataFetchAttempted).toBe(true) + expect(populatedState.labelsFetchAttempted).toBe(true) + + // Call reset + useGraphStore.getState().reset() + + // Verify all fields reset to defaults + const resetState = useGraphStore.getState() + expect(resetState.selectedNode).toBeNull() + expect(resetState.focusedNode).toBeNull() + expect(resetState.selectedEdge).toBeNull() + expect(resetState.focusedEdge).toBeNull() + expect(resetState.rawGraph).toBeNull() + expect(resetState.sigmaGraph).toBeNull() + expect(resetState.searchEngine).toBeNull() + expect(resetState.moveToSelectedNode).toBe(false) + expect(resetState.graphIsEmpty).toBe(false) + expect(resetState.isFetching).toBe(false) + expect(resetState.graphDataFetchAttempted).toBe(false) + expect(resetState.labelsFetchAttempted).toBe(false) + }) +}) + +describe('fetch flags', () => { + test('setGraphDataFetchAttempted(true) sets the flag', () => { + useGraphStore.setState({ graphDataFetchAttempted: false }) + expect(useGraphStore.getState().graphDataFetchAttempted).toBe(false) + + useGraphStore.getState().setGraphDataFetchAttempted(true) + + expect(useGraphStore.getState().graphDataFetchAttempted).toBe(true) + }) + + test('setLabelsFetchAttempted(true) sets the flag', () => { + useGraphStore.setState({ labelsFetchAttempted: false }) + expect(useGraphStore.getState().labelsFetchAttempted).toBe(false) + + useGraphStore.getState().setLabelsFetchAttempted(true) + + expect(useGraphStore.getState().labelsFetchAttempted).toBe(true) + }) + + test('after reset(), both fetch flags are false', () => { + // Set flags to true + useGraphStore.setState({ + graphDataFetchAttempted: true, + labelsFetchAttempted: true + }) + + expect(useGraphStore.getState().graphDataFetchAttempted).toBe(true) + expect(useGraphStore.getState().labelsFetchAttempted).toBe(true) + + // Reset + useGraphStore.getState().reset() + + // Verify flags are false + expect(useGraphStore.getState().graphDataFetchAttempted).toBe(false) + expect(useGraphStore.getState().labelsFetchAttempted).toBe(false) + }) +}) + +describe('selectedNode/Edge', () => { + test('setSelectedNode("node-1") works', () => { + useGraphStore.setState({ selectedNode: null }) + expect(useGraphStore.getState().selectedNode).toBeNull() + + useGraphStore.getState().setSelectedNode('node-1') + + expect(useGraphStore.getState().selectedNode).toBe('node-1') + }) + + test('setSelectedNode with moveToSelectedNode option', () => { + useGraphStore.setState({ selectedNode: null, moveToSelectedNode: false }) + + useGraphStore.getState().setSelectedNode('node-1', true) + + expect(useGraphStore.getState().selectedNode).toBe('node-1') + expect(useGraphStore.getState().moveToSelectedNode).toBe(true) + }) + + test('clearSelection() clears all selections', () => { + // Set multiple selections + useGraphStore.setState({ + selectedNode: 'node-1', + focusedNode: 'node-2', + selectedEdge: 'edge-1', + focusedEdge: 'edge-2' + }) + + // Verify all are set + expect(useGraphStore.getState().selectedNode).toBe('node-1') + expect(useGraphStore.getState().focusedNode).toBe('node-2') + expect(useGraphStore.getState().selectedEdge).toBe('edge-1') + expect(useGraphStore.getState().focusedEdge).toBe('edge-2') + + // Clear + useGraphStore.getState().clearSelection() + + // Verify all are null + expect(useGraphStore.getState().selectedNode).toBeNull() + expect(useGraphStore.getState().focusedNode).toBeNull() + expect(useGraphStore.getState().selectedEdge).toBeNull() + expect(useGraphStore.getState().focusedEdge).toBeNull() + }) + + test('after reset(), all selections are null', () => { + // Set selections + useGraphStore.setState({ + selectedNode: 'node-1', + focusedNode: 'node-2', + selectedEdge: 'edge-1', + focusedEdge: 'edge-2' + }) + + // Reset + useGraphStore.getState().reset() + + // Verify all are null + expect(useGraphStore.getState().selectedNode).toBeNull() + expect(useGraphStore.getState().focusedNode).toBeNull() + expect(useGraphStore.getState().selectedEdge).toBeNull() + expect(useGraphStore.getState().focusedEdge).toBeNull() + }) +}) + +describe('other graph state setters', () => { + test('setRawGraph() updates rawGraph', () => { + const mockGraph = { nodes: [{ id: 'n1' }], edges: [] } as any + useGraphStore.getState().setRawGraph(mockGraph) + expect(useGraphStore.getState().rawGraph).toBe(mockGraph) + }) + + test('setSigmaGraph() updates sigmaGraph', () => { + const mockGraph = {} as any + useGraphStore.getState().setSigmaGraph(mockGraph) + expect(useGraphStore.getState().sigmaGraph).toBe(mockGraph) + }) + + test('setIsFetching() updates isFetching', () => { + useGraphStore.setState({ isFetching: false }) + expect(useGraphStore.getState().isFetching).toBe(false) + + useGraphStore.getState().setIsFetching(true) + expect(useGraphStore.getState().isFetching).toBe(true) + }) + + test('setGraphIsEmpty() updates graphIsEmpty', () => { + useGraphStore.setState({ graphIsEmpty: false }) + expect(useGraphStore.getState().graphIsEmpty).toBe(false) + + useGraphStore.getState().setGraphIsEmpty(true) + expect(useGraphStore.getState().graphIsEmpty).toBe(true) + }) + + test('setMoveToSelectedNode() updates moveToSelectedNode', () => { + useGraphStore.setState({ moveToSelectedNode: false }) + expect(useGraphStore.getState().moveToSelectedNode).toBe(false) + + useGraphStore.getState().setMoveToSelectedNode(true) + expect(useGraphStore.getState().moveToSelectedNode).toBe(true) + }) + + test('setSearchEngine() updates searchEngine', () => { + const mockEngine = {} as any + useGraphStore.getState().setSearchEngine(mockEngine) + expect(useGraphStore.getState().searchEngine).toBe(mockEngine) + }) + + test('resetSearchEngine() sets searchEngine to null', () => { + const mockEngine = {} as any + useGraphStore.getState().setSearchEngine(mockEngine) + expect(useGraphStore.getState().searchEngine).not.toBeNull() + + useGraphStore.getState().resetSearchEngine() + expect(useGraphStore.getState().searchEngine).toBeNull() + }) +}) diff --git a/lightrag_webui/src/stores/graph.ts b/lightrag_webui/src/stores/graph.ts index 7489d41172..c24dc8a94c 100644 --- a/lightrag_webui/src/stores/graph.ts +++ b/lightrag_webui/src/stores/graph.ts @@ -195,10 +195,13 @@ const useGraphStoreBase = create()((set, get) => ({ selectedEdge: null, focusedEdge: null, rawGraph: null, - sigmaGraph: null, // to avoid other components from acccessing graph objects + sigmaGraph: null, // to avoid other components from acccessing graph objects searchEngine: null, moveToSelectedNode: false, - graphIsEmpty: false + graphIsEmpty: false, + isFetching: false, + graphDataFetchAttempted: false, + labelsFetchAttempted: false, }); }, diff --git a/lightrag_webui/src/stores/settings.test.ts b/lightrag_webui/src/stores/settings.test.ts new file mode 100644 index 0000000000..9f2b5c658f --- /dev/null +++ b/lightrag_webui/src/stores/settings.test.ts @@ -0,0 +1,457 @@ +import { beforeAll, describe, expect, test } from 'bun:test' + +const storageMock = () => { + const data = new Map() + + return { + getItem: (key: string) => data.get(key) ?? null, + setItem: (key: string, value: string) => { + data.set(key, value) + }, + removeItem: (key: string) => { + data.delete(key) + }, + clear: () => { + data.clear() + } + } +} + +// Re-implement the migrate logic for testing (mirrors settings.ts) +// The key behavior: only migrations where version < X are applied +const migrateSettings = (state: any, version: number): any => { + if (version < 2) { + state.showEdgeLabel = false + } + if (version < 3) { + state.queryLabel = 'default' + } + if (version < 4) { + state.showPropertyPanel = true + state.showNodeSearchBar = true + state.showNodeLabel = true + state.enableHealthCheck = true + state.apiKey = null + } + if (version < 5) { + state.currentTab = 'documents' + } + if (version < 6) { + state.querySettings = { + mode: 'global', + response_type: 'Multiple Paragraphs', + top_k: 10, + max_token_for_text_unit: 4000, + max_token_for_global_context: 4000, + max_token_for_local_context: 4000, + only_need_context: false, + only_need_prompt: false, + stream: true, + history_turns: 0, + hl_keywords: [], + ll_keywords: [] + } + state.retrievalHistory = [] + } + if (version < 7) { + state.graphQueryMaxDepth = 3 + state.graphLayoutMaxIterations = 15 + } + if (version < 8) { + state.graphMinDegree = 0 + state.language = 'en' + } + if (version < 9) { + state.showFileName = false + } + if (version < 10) { + delete state.graphMinDegree + state.graphMaxNodes = 1000 + } + if (version < 11) { + state.minEdgeSize = 1 + state.maxEdgeSize = 1 + } + if (version < 12) { + state.retrievalHistory = [] + } + if (version < 13) { + if (state.querySettings) { + state.querySettings.user_prompt = '' + } + } + if (version < 14) { + state.backendMaxGraphNodes = null + } + if (version < 15) { + state.querySettings = { + ...state.querySettings, + mode: 'mix', + response_type: 'Multiple Paragraphs', + top_k: 40, + chunk_top_k: 10, + max_entity_tokens: 10000, + max_relation_tokens: 10000, + max_total_tokens: 32000, + enable_rerank: true, + history_turns: 0, + } + } + if (version < 16) { + state.documentsPageSize = 10 + } + if (version < 17) { + if (state.querySettings) { + state.querySettings.history_turns = 0 + } + } + if (version < 18) { + state.userPromptHistory = [] + } + if (version < 19) { + if (state.querySettings) { + delete state.querySettings.response_type + } + } + if (version < 20) { + state.currentWorkspace = null + } + return state +} + +describe('currentWorkspace state', () => { + let useSettingsStore: ReturnType + + beforeAll(async () => { + Object.defineProperty(globalThis, 'localStorage', { + value: storageMock(), + configurable: true + }) + Object.defineProperty(globalThis, 'sessionStorage', { + value: storageMock(), + configurable: true + }) + + const module = await import('@/stores/settings') + useSettingsStore = module.useSettingsStore + }) + + test('default value is null', () => { + // Reset to default state + useSettingsStore.setState({ currentWorkspace: null }) + expect(useSettingsStore.getState().currentWorkspace).toBeNull() + }) + + test('setCurrentWorkspace("my-workspace") sets it to the string', () => { + useSettingsStore.getState().setCurrentWorkspace('my-workspace') + expect(useSettingsStore.getState().currentWorkspace).toBe('my-workspace') + }) + + test('setCurrentWorkspace(null) resets to null', () => { + // First set a value + useSettingsStore.getState().setCurrentWorkspace('some-workspace') + expect(useSettingsStore.getState().currentWorkspace).toBe('some-workspace') + + // Then reset to null + useSettingsStore.getState().setCurrentWorkspace(null) + expect(useSettingsStore.getState().currentWorkspace).toBeNull() + }) + + test('state change is reflected in getState()', () => { + // Reset first + useSettingsStore.getState().setCurrentWorkspace(null) + + // Set a workspace + useSettingsStore.getState().setCurrentWorkspace('test-workspace') + + // Verify via getState() + const state = useSettingsStore.getState() + expect(state.currentWorkspace).toBe('test-workspace') + + // Verify it can be set back to null + useSettingsStore.getState().setCurrentWorkspace(null) + expect(useSettingsStore.getState().currentWorkspace).toBeNull() + }) +}) + +describe('v20 migration', () => { + test('state with version < 20 gets currentWorkspace: null added', () => { + const state = { + theme: 'dark', + someOldField: 'value' + } + + const migratedState = migrateSettings(state, 19) + + expect(migratedState.currentWorkspace).toBeNull() + expect(migratedState.theme).toBe('dark') + expect(migratedState.someOldField).toBe('value') + }) + + test('state at version 20 is returned as-is (currentWorkspace not modified)', () => { + const state = { + theme: 'light', + currentWorkspace: 'existing-workspace' + } + + const migratedState = migrateSettings(state, 20) + + // Since version is 20, the v20 migration (version < 20) should NOT run + // So currentWorkspace should remain unchanged (existing-workspace) + expect(migratedState.currentWorkspace).toBe('existing-workspace') + expect(migratedState.theme).toBe('light') + }) + + test('existing fields are preserved during migration', () => { + const state = { + theme: 'dark', + language: 'en', + showPropertyPanel: true, + graphMaxNodes: 500, + currentWorkspace: undefined // intentionally missing + } + + const migratedState = migrateSettings(state, 19) + + // Check all original fields are preserved + expect(migratedState.theme).toBe('dark') + expect(migratedState.language).toBe('en') + expect(migratedState.showPropertyPanel).toBe(true) + expect(migratedState.graphMaxNodes).toBe(500) + // New field added by migration + expect(migratedState.currentWorkspace).toBeNull() + }) + + test('migration adds currentWorkspace: null for very old versions (e.g., version 0)', () => { + const state = {} + + const migratedState = migrateSettings(state, 0) + + expect(migratedState.currentWorkspace).toBeNull() + }) + + test('migration preserves already-set currentWorkspace when version >= 20', () => { + const state = { + currentWorkspace: 'already-set-workspace' + } + + const migratedState = migrateSettings(state, 21) + + // Should preserve the existing value since version >= 20 + expect(migratedState.currentWorkspace).toBe('already-set-workspace') + }) +}) + +describe('workspaceRefreshTrigger state', () => { + let useSettingsStore: ReturnType + + beforeAll(async () => { + Object.defineProperty(globalThis, 'localStorage', { + value: storageMock(), + configurable: true + }) + Object.defineProperty(globalThis, 'sessionStorage', { + value: storageMock(), + configurable: true + }) + + const module = await import('@/stores/settings') + useSettingsStore = module.useSettingsStore + }) + + test('default value is 0', () => { + // Reset to default state + useSettingsStore.setState({ workspaceRefreshTrigger: 0 }) + expect(useSettingsStore.getState().workspaceRefreshTrigger).toBe(0) + }) + + test('triggerWorkspaceRefresh() increments by 1', () => { + // Reset first + useSettingsStore.setState({ workspaceRefreshTrigger: 0 }) + + // Trigger once + useSettingsStore.getState().triggerWorkspaceRefresh() + + expect(useSettingsStore.getState().workspaceRefreshTrigger).toBe(1) + }) + + test('multiple calls increment correctly (0 β†’ 1 β†’ 2 β†’ 3)', () => { + // Reset first + useSettingsStore.setState({ workspaceRefreshTrigger: 0 }) + + // Trigger multiple times + useSettingsStore.getState().triggerWorkspaceRefresh() + expect(useSettingsStore.getState().workspaceRefreshTrigger).toBe(1) + + useSettingsStore.getState().triggerWorkspaceRefresh() + expect(useSettingsStore.getState().workspaceRefreshTrigger).toBe(2) + + useSettingsStore.getState().triggerWorkspaceRefresh() + expect(useSettingsStore.getState().workspaceRefreshTrigger).toBe(3) + }) + + test('trigger is independent of currentWorkspace state', () => { + // Reset both + useSettingsStore.setState({ workspaceRefreshTrigger: 0, currentWorkspace: null }) + + // Set a workspace + useSettingsStore.getState().setCurrentWorkspace('my-workspace') + + // Trigger refresh + useSettingsStore.getState().triggerWorkspaceRefresh() + + // Both should have their own values + expect(useSettingsStore.getState().workspaceRefreshTrigger).toBe(1) + expect(useSettingsStore.getState().currentWorkspace).toBe('my-workspace') + + // Change workspace and trigger again + useSettingsStore.getState().setCurrentWorkspace('other-workspace') + useSettingsStore.getState().triggerWorkspaceRefresh() + + expect(useSettingsStore.getState().workspaceRefreshTrigger).toBe(2) + expect(useSettingsStore.getState().currentWorkspace).toBe('other-workspace') + }) +}) + +describe('searchLabelDropdownRefreshTrigger state', () => { + let useSettingsStore: ReturnType + + beforeAll(async () => { + Object.defineProperty(globalThis, 'localStorage', { + value: storageMock(), + configurable: true + }) + Object.defineProperty(globalThis, 'sessionStorage', { + value: storageMock(), + configurable: true + }) + + const module = await import('@/stores/settings') + useSettingsStore = module.useSettingsStore + }) + + test('default value is 0', () => { + useSettingsStore.setState({ searchLabelDropdownRefreshTrigger: 0 }) + expect(useSettingsStore.getState().searchLabelDropdownRefreshTrigger).toBe(0) + }) + + test('triggerSearchLabelDropdownRefresh() increments by 1', () => { + useSettingsStore.setState({ searchLabelDropdownRefreshTrigger: 0 }) + useSettingsStore.getState().triggerSearchLabelDropdownRefresh() + expect(useSettingsStore.getState().searchLabelDropdownRefreshTrigger).toBe(1) + }) + + test('multiple calls increment correctly', () => { + useSettingsStore.setState({ searchLabelDropdownRefreshTrigger: 0 }) + useSettingsStore.getState().triggerSearchLabelDropdownRefresh() + useSettingsStore.getState().triggerSearchLabelDropdownRefresh() + useSettingsStore.getState().triggerSearchLabelDropdownRefresh() + expect(useSettingsStore.getState().searchLabelDropdownRefreshTrigger).toBe(3) + }) +}) + +describe('partialize - trigger fields excluded from persistence', () => { + let useSettingsStore: ReturnType + + // Re-implement the partialize logic to test it + const partializeState = (state: any) => { + const { + workspaceRefreshTrigger, + triggerWorkspaceRefresh, + searchLabelDropdownRefreshTrigger, + triggerSearchLabelDropdownRefresh, + ...rest + } = state + return rest + } + + beforeAll(async () => { + Object.defineProperty(globalThis, 'localStorage', { + value: storageMock(), + configurable: true + }) + Object.defineProperty(globalThis, 'sessionStorage', { + value: storageMock(), + configurable: true + }) + + const module = await import('@/stores/settings') + useSettingsStore = module.useSettingsStore + }) + + test('workspaceRefreshTrigger is excluded from partialize output', () => { + // Set a trigger value + useSettingsStore.setState({ workspaceRefreshTrigger: 5 }) + + const state = useSettingsStore.getState() + const partialized = partializeState(state) + + expect(partialized).not.toHaveProperty('workspaceRefreshTrigger') + expect(partialized).not.toHaveProperty('triggerWorkspaceRefresh') + }) + + test('searchLabelDropdownRefreshTrigger is excluded from partialize output', () => { + // Set a trigger value + useSettingsStore.setState({ searchLabelDropdownRefreshTrigger: 10 }) + + const state = useSettingsStore.getState() + const partialized = partializeState(state) + + expect(partialized).not.toHaveProperty('searchLabelDropdownRefreshTrigger') + expect(partialized).not.toHaveProperty('triggerSearchLabelDropdownRefresh') + }) + + test('currentWorkspace IS included in partialize output', () => { + // Set currentWorkspace + useSettingsStore.setState({ currentWorkspace: 'my-workspace' }) + + const state = useSettingsStore.getState() + const partialized = partializeState(state) + + expect(partialized).toHaveProperty('currentWorkspace') + expect(partialized.currentWorkspace).toBe('my-workspace') + }) + + test('after calling triggerWorkspaceRefresh, trigger value is updated in state but NOT in partialize output', () => { + // Reset first + useSettingsStore.setState({ workspaceRefreshTrigger: 0, currentWorkspace: null }) + + // Trigger refresh + useSettingsStore.getState().triggerWorkspaceRefresh() + + // Verify in-state trigger is updated + expect(useSettingsStore.getState().workspaceRefreshTrigger).toBe(1) + + // Verify partialize excludes it + const state = useSettingsStore.getState() + const partialized = partializeState(state) + expect(partialized).not.toHaveProperty('workspaceRefreshTrigger') + }) + + test('other settings are still included in partialize output', () => { + // Set various settings + useSettingsStore.setState({ + theme: 'dark', + language: 'en', + showPropertyPanel: false, + graphMaxNodes: 500, + currentWorkspace: 'test-workspace' + }) + + const state = useSettingsStore.getState() + const partialized = partializeState(state) + + // These should be present + expect(partialized).toHaveProperty('theme') + expect(partialized).toHaveProperty('language') + expect(partialized).toHaveProperty('showPropertyPanel') + expect(partialized).toHaveProperty('graphMaxNodes') + expect(partialized).toHaveProperty('currentWorkspace') + + // These should be excluded (triggers) + expect(partialized).not.toHaveProperty('workspaceRefreshTrigger') + expect(partialized).not.toHaveProperty('triggerWorkspaceRefresh') + expect(partialized).not.toHaveProperty('searchLabelDropdownRefreshTrigger') + expect(partialized).not.toHaveProperty('triggerSearchLabelDropdownRefresh') + }) +}) diff --git a/lightrag_webui/src/stores/settings.ts b/lightrag_webui/src/stores/settings.ts index 823ee084a6..4efaf8ce73 100644 --- a/lightrag_webui/src/stores/settings.ts +++ b/lightrag_webui/src/stores/settings.ts @@ -82,6 +82,14 @@ interface SettingsState { // Search label dropdown refresh trigger (non-persistent, runtime only) searchLabelDropdownRefreshTrigger: number triggerSearchLabelDropdownRefresh: () => void + + // Workspace settings + currentWorkspace: string | null + setCurrentWorkspace: (workspace: string | null) => void + + // Workspace refresh trigger (non-persistent, runtime only) + workspaceRefreshTrigger: number + triggerWorkspaceRefresh: () => void } const useSettingsStoreBase = create()( @@ -233,12 +241,34 @@ const useSettingsStoreBase = create()( triggerSearchLabelDropdownRefresh: () => set((state) => ({ searchLabelDropdownRefreshTrigger: state.searchLabelDropdownRefreshTrigger + 1 - })) + })), + + currentWorkspace: null, + setCurrentWorkspace: (workspace: string | null) => set({ currentWorkspace: workspace }), + + // Workspace refresh trigger (not persisted) + workspaceRefreshTrigger: 0, + triggerWorkspaceRefresh: () => + set((state) => ({ + workspaceRefreshTrigger: state.workspaceRefreshTrigger + 1 + })), }), { name: 'settings-storage', storage: createJSONStorage(() => localStorage), - version: 19, + /* eslint-disable @typescript-eslint/no-unused-vars */ + partialize: (state) => { + const { + workspaceRefreshTrigger, + triggerWorkspaceRefresh, + searchLabelDropdownRefreshTrigger, + triggerSearchLabelDropdownRefresh, + ...rest + } = state + return rest + }, + /* eslint-enable @typescript-eslint/no-unused-vars */ + version: 20, migrate: (state: any, version: number) => { if (version < 2) { state.showEdgeLabel = false @@ -341,6 +371,9 @@ const useSettingsStoreBase = create()( delete state.querySettings.response_type } } + if (version < 20) { + state.currentWorkspace = null + } return state } } diff --git a/scripts/start_test_server.sh b/scripts/start_test_server.sh new file mode 100755 index 0000000000..b6ce7b671c --- /dev/null +++ b/scripts/start_test_server.sh @@ -0,0 +1,20 @@ +#!/bin/bash +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Load test environment (LLM, embedding configs) +if [ ! -f "$SCRIPT_DIR/.env-test" ]; then + echo "Error: $SCRIPT_DIR/.env-test not found. Create it with LLM/embedding config." + exit 1 +fi +set -a +source "$SCRIPT_DIR/.env-test" +set +a + +# Working directory +export WORKING_DIR="/tmp/lightrag_ws_test" +rm -rf "$WORKING_DIR" +mkdir -p "$WORKING_DIR" + +echo "Starting LightRAG server on port 9621..." +echo "Working dir: $WORKING_DIR" +python -m lightrag.api.lightrag_server --port 9621 --working-dir "$WORKING_DIR" diff --git a/scripts/test_workspace_isolation_e2e.py b/scripts/test_workspace_isolation_e2e.py new file mode 100755 index 0000000000..eadaf28a8a --- /dev/null +++ b/scripts/test_workspace_isolation_e2e.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python3 +""" +Workspace Isolation E2E Test for LightRAG API + +Tests that different workspaces maintain isolated data. +""" + +import httpx +import os +import sys +import time + +BASE_URL = "http://localhost:9621" +WORKING_DIR = "/tmp/lightrag_ws_test" +REQUEST_TIMEOUT = 60.0 +QUERY_TIMEOUT = 30.0 + +results = [] + + +def print_test_header(name: str): + print(f"\n{'=' * 40}") + print(f"{name}") + print(f"{'=' * 40}\n") + + +def print_request(method: str, path: str, headers: dict, body: dict = None): + print(f"REQUEST: {method} {path}") + print(f"HEADERS: {headers}") + if body: + print(f"BODY: {body}") + print("-" * 40) + + +def print_response(status_code: int, body: dict): + print(f"RESPONSE: {status_code}") + print(f"BODY: {body}") + print("-" * 40) + + +def print_assertion(check: str, passed: bool): + status = "PASS" if passed else "FAIL" + print(f"ASSERTION: {check} β†’ {status}\n") + return passed + + +def wait_for_healthy(client: httpx.Client) -> bool: + """Poll /health until server is ready or timeout.""" + print("Waiting for server healthy...") + start = time.time() + attempt = 0 + + while time.time() - start < REQUEST_TIMEOUT: + attempt += 1 + try: + resp = client.get("/health", timeout=5.0) + print(f" Attempt {attempt}: {resp.status_code} - {resp.text[:100]}") + if resp.status_code == 200: + print(" Server is healthy!\n") + return True + except Exception as e: + print(f" Attempt {attempt}: Error - {e}") + + time.sleep(2) + + print(" TIMEOUT: Server not healthy after 60 seconds\n") + return False + + +def wait_for_pipeline_complete( + client: httpx.Client, workspace: str = None, doc_id: str = None +) -> dict: + """Poll /documents/pipeline_status until complete or timeout.""" + headers = {} + if workspace: + headers["LIGHTRAG-WORKSPACE"] = workspace + + start = time.time() + attempt = 0 + + while time.time() - start < REQUEST_TIMEOUT: + attempt += 1 + try: + resp = client.get( + "/documents/pipeline_status", headers=headers, timeout=5.0 + ) + if resp.status_code == 200: + data = resp.json() + # Check if pipeline is complete (no pending items) + # Adjust based on actual API response structure + is_complete = ( + data.get("pending", 0) == 0 and data.get("processing", 0) == 0 + if isinstance(data, dict) + else True + ) + print(f" Attempt {attempt}: {data}") + if is_complete: + return data + except Exception as e: + print(f" Attempt {attempt}: Error - {e}") + + time.sleep(2) + + return {"status": "timeout"} + + +def run_test( + name: str, + test_func, +): + """Run a test and track results.""" + print_test_header(name) + try: + passed = test_func() + results.append((name, passed)) + return passed + except Exception as e: + print(f"EXCEPTION: {e}") + results.append((name, False)) + return False + + +def t1_insert_query_alpha(client: httpx.Client) -> bool: + """T1: Insert document into alpha workspace and query it.""" + all_pass = True + + # Insert document + headers = {"LIGHTRAG-WORKSPACE": "alpha"} + body = { + "text": "Alice works at OpenAI in San Francisco. She is a senior engineer working on GPT models." + } + + print_request("POST", "/documents/text", headers, body) + resp = client.post("/documents/text", json=body, headers=headers, timeout=10.0) + print_response(resp.status_code, resp.json()) + + # Wait for pipeline + print("Waiting for pipeline to complete...") + wait_for_pipeline_complete(client, workspace="alpha") + + # Query + query_body = {"query": "Who is Alice and where does she work?", "mode": "mix"} + print_request("POST", "/query", headers, query_body) + resp = client.post( + "/query", json=query_body, headers=headers, timeout=QUERY_TIMEOUT + ) + print_response(resp.status_code, resp.json()) + + # Assert + response_text = resp.json().get("response", "").lower() + check = 'response mentions "Alice" AND "OpenAI"' + passed = "alice" in response_text and "openai" in response_text + all_pass &= print_assertion(check, passed) + + return all_pass + + +def t2_query_empty_beta(client: httpx.Client) -> bool: + """T2: Query beta workspace (should be empty, no alpha data).""" + headers = {"LIGHTRAG-WORKSPACE": "beta"} + body = {"query": "Who is Alice and where does she work?", "mode": "mix"} + + print_request("POST", "/query", headers, body) + resp = client.post("/query", json=body, headers=headers, timeout=QUERY_TIMEOUT) + print_response(resp.status_code, resp.json()) + + # Assert + response_text = resp.json().get("response", "").lower() + check = 'response does NOT mention "OpenAI", "San Francisco", or "GPT models"' + passed = ( + "openai" not in response_text + and "san francisco" not in response_text + and "gpt" not in response_text + ) + return print_assertion(check, passed) + + +def t3_insert_beta_verify_alpha(client: httpx.Client) -> bool: + """T3: Insert into beta, verify alpha is unaffected.""" + all_pass = True + + # Insert into beta + headers = {"LIGHTRAG-WORKSPACE": "beta"} + body = { + "text": "Bob works at Google in London. He is a product manager for Google Cloud." + } + + print_request("POST", "/documents/text", headers, body) + resp = client.post("/documents/text", json=body, headers=headers, timeout=10.0) + print_response(resp.status_code, resp.json()) + + # Wait for pipeline + print("Waiting for pipeline to complete...") + wait_for_pipeline_complete(client, workspace="beta") + + # Query beta for Bob + query_body = {"query": "Who is Bob and where does he work?", "mode": "mix"} + print_request("POST", "/query", headers, query_body) + resp = client.post( + "/query", json=query_body, headers=headers, timeout=QUERY_TIMEOUT + ) + print_response(resp.status_code, resp.json()) + + response_text = resp.json().get("response", "").lower() + check = 'beta response mentions "Bob" AND "Google"' + passed = "bob" in response_text and "google" in response_text + all_pass &= print_assertion(check, passed) + + # Query beta for Alice (should not find alpha data) + query_body = {"query": "Tell me about Alice", "mode": "mix"} + print_request("POST", "/query", headers, query_body) + resp = client.post( + "/query", json=query_body, headers=headers, timeout=QUERY_TIMEOUT + ) + print_response(resp.status_code, resp.json()) + + response_text = resp.json().get("response", "").lower() + check = 'beta response for Alice does NOT mention "OpenAI" or "San Francisco"' + passed = "openai" not in response_text and "san francisco" not in response_text + all_pass &= print_assertion(check, passed) + + # Query alpha for Bob (should not find beta data) + headers_alpha = {"LIGHTRAG-WORKSPACE": "alpha"} + query_body = {"query": "Tell me about Bob", "mode": "mix"} + print_request("POST", "/query", headers_alpha, query_body) + resp = client.post( + "/query", json=query_body, headers=headers_alpha, timeout=QUERY_TIMEOUT + ) + print_response(resp.status_code, resp.json()) + + response_text = resp.json().get("response", "").lower() + check = 'alpha response for Bob does NOT mention "Google" or "London"' + passed = "google" not in response_text and "london" not in response_text + all_pass &= print_assertion(check, passed) + + return all_pass + + +def t4_query_default_empty(client: httpx.Client) -> bool: + """T4: Query default workspace (no header) - should be empty.""" + body = {"query": "Who is Alice or Bob?", "mode": "mix"} + + print_request("POST", "/query", {}, body) + resp = client.post("/query", json=body, timeout=QUERY_TIMEOUT) + print_response(resp.status_code, resp.json()) + + # Assert + response_text = resp.json().get("response", "").lower() + check = 'default workspace response does NOT mention "OpenAI" AND does NOT mention "Google"' + passed = "openai" not in response_text and "google" not in response_text + return print_assertion(check, passed) + + +def t5_insert_query_default(client: httpx.Client) -> bool: + """T5: Insert into default workspace and query.""" + all_pass = True + + # Insert document (no workspace header) + body = {"text": "Charlie works at Meta in New York. He leads the VR research team."} + + print_request("POST", "/documents/text", {}, body) + resp = client.post("/documents/text", json=body, timeout=10.0) + print_response(resp.status_code, resp.json()) + + # Wait for pipeline + print("Waiting for pipeline to complete...") + wait_for_pipeline_complete(client) + + # Query default workspace + query_body = {"query": "Who is Charlie and where does he work?", "mode": "mix"} + print_request("POST", "/query", {}, query_body) + resp = client.post("/query", json=query_body, timeout=QUERY_TIMEOUT) + print_response(resp.status_code, resp.json()) + + # Assert + response_text = resp.json().get("response", "").lower() + check = 'default workspace response mentions "Charlie" AND "Meta"' + passed = "charlie" in response_text and "meta" in response_text + all_pass &= print_assertion(check, passed) + + return all_pass + + +def t6_filesystem_check() -> bool: + """T6: Check file system for workspace directories.""" + print("Checking workspace directories in:", WORKING_DIR) + + if not os.path.exists(WORKING_DIR): + print(f"ERROR: Working directory {WORKING_DIR} does not exist") + return False + + entries = os.listdir(WORKING_DIR) + print(f"\nDirectory entries: {entries}\n") + + # More flexible check - at least verify directories exist + check1 = len(entries) >= 3 # alpha, beta, default directories + print_assertion( + f"at least 3 workspace directories exist (found {len(entries)})", check1 + ) + + # Print tree structure + print("Directory tree:") + for entry in sorted(entries): + full_path = os.path.join(WORKING_DIR, entry) + if os.path.isdir(full_path): + sub_entries = os.listdir(full_path)[:5] # First 5 subentries + print(f" {entry}/") + for sub in sub_entries: + print(f" {sub}") + if len(os.listdir(full_path)) > 5: + print(f" ... ({len(os.listdir(full_path))} total items)") + else: + print(f" {entry}") + + return check1 + + +def print_summary(): + """Print final summary table.""" + print(f"\n{'=' * 60}") + print("TEST SUMMARY") + print(f"{'=' * 60}") + print(f"{'Test':<40} {'Result':<10}") + print("-" * 60) + + all_passed = True + for name, passed in results: + status = "PASS" if passed else "FAIL" + print(f"{name:<40} {status:<10}") + all_passed &= passed + + print("-" * 60) + verdict = "ALL TESTS PASSED" if all_passed else "SOME TESTS FAILED" + print(f"{'Verdict:':<40} {verdict}") + print(f"{'=' * 60}\n") + + return all_passed + + +def main(): + """Run all E2E tests.""" + print("\n" + "=" * 60) + print("WORKSPACE ISOLATION E2E TEST") + print("=" * 60) + + client = httpx.Client(base_url=BASE_URL, timeout=REQUEST_TIMEOUT) + + try: + # T0: Wait for server + print_test_header("T0: Wait for Server Healthy") + if not wait_for_healthy(client): + print("FATAL: Server not healthy") + print_summary() + return 1 + + # Run all tests + run_test("T1: Insert & Query Alpha", lambda: t1_insert_query_alpha(client)) + run_test("T2: Query Empty Beta", lambda: t2_query_empty_beta(client)) + run_test( + "T3: Insert Beta, Verify Alpha Unaffected", + lambda: t3_insert_beta_verify_alpha(client), + ) + run_test("T4: Query Default (Empty)", lambda: t4_query_default_empty(client)) + run_test("T5: Insert & Query Default", lambda: t5_insert_query_default(client)) + run_test("T6: File System Check", t6_filesystem_check) + + finally: + client.close() + + # Print summary and exit + all_passed = print_summary() + return 0 if all_passed else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/conftest.py b/tests/conftest.py index 09769fd64e..fd0941dec9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -60,6 +60,9 @@ def pytest_collection_modifyitems(config, items): Integration tests are skipped unless --run-integration flag is provided. This allows running offline tests quickly without needing external services. + + Note: Tests marked with @pytest.mark.offline are NOT skipped, even if they + are in an 'integration' directory, because they use mocked external services. """ if config.getoption("--run-integration"): # If --run-integration is specified, run all tests @@ -70,7 +73,9 @@ def pytest_collection_modifyitems(config, items): ) for item in items: - if "integration" in item.keywords: + # Skip tests with 'integration' marker BUT NOT 'offline' marker + # Offline tests use mocked external services and should run without --run-integration + if "integration" in item.keywords and "offline" not in item.keywords: item.add_marker(skip_integration) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/test_workspace_api.py b/tests/integration/test_workspace_api.py new file mode 100644 index 0000000000..656e50d22d --- /dev/null +++ b/tests/integration/test_workspace_api.py @@ -0,0 +1,851 @@ +""" +Integration tests for Workspace API isolation via HTTP. + +These tests verify that the FastAPI workspace isolation layer correctly: +- Extracts workspace from LIGHTRAG-WORKSPACE HTTP header +- Validates workspace names +- Manages workspace lifecycle (get/create/release) +- Handles concurrent requests with proper isolation +- Supports background task and streaming response patterns +- Enforces capacity limits +- Evicts LRU entries under memory pressure + +All external services are mocked. +""" + +from __future__ import annotations + +import asyncio +from typing import AsyncGenerator +from unittest.mock import AsyncMock + +import pytest +from fastapi import BackgroundTasks, FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse +from httpx import ASGITransport, AsyncClient + +from lightrag.api.utils import ( + extract_workspace_from_header, + sanitize_workspace_name, + WorkspaceNameError, +) +from lightrag.api.workspace_manager import WorkspaceManager, WorkspaceCapacityError + + +# ============================================================================= +# Mock LightRAG Instance +# ============================================================================= + + +class MockLightRAG: + """Mock LightRAG instance for testing workspace isolation.""" + + def __init__(self, workspace: str) -> None: + self.workspace = workspace + self.finalize_called = False + self.some_method = AsyncMock(return_value={"status": "ok"}) + + async def initialize_storages(self) -> None: + """Mock initialize method.""" + pass + + async def finalize_storages(self) -> None: + """Mock finalize method.""" + self.finalize_called = True + + +def mock_factory(workspace: str) -> MockLightRAG: + """Factory function that creates MockLightRAG instances.""" + return MockLightRAG(workspace) + + +# ============================================================================= +# Test App Factory +# ============================================================================= + + +def create_test_app( + max_instances: int = 10, + factory=None, + bg_done_event: asyncio.Event | None = None, +) -> tuple[FastAPI, WorkspaceManager]: + """ + Create a minimal FastAPI app that mimics the real server's workspace routing. + + Returns (app, workspace_manager) tuple. + """ + if factory is None: + factory = mock_factory + + workspace_mgr = WorkspaceManager(factory=factory, max_instances=max_instances) + + app = FastAPI() + + def get_workspace(request: Request) -> str: + """Extract workspace from request header, returning JSONResponse on error.""" + try: + return extract_workspace_from_header(request) + except HTTPException as e: + raise e # Re-raise for FastAPI to handle + + @app.post("/test/regular") + async def regular_route(request: Request): + """Regular handler pattern with proper get/release.""" + try: + ws = get_workspace(request) + except HTTPException as e: + return JSONResponse(status_code=e.status_code, content={"error": e.detail}) + + try: + rag = await workspace_mgr.get_or_create(ws) + return JSONResponse(content={"workspace": rag.workspace}) + except WorkspaceCapacityError: + return JSONResponse( + status_code=503, content={"error": "All workspace slots busy"} + ) + finally: + workspace_mgr.release(ws) + + @app.post("/test/background") + async def background_route(request: Request, background_tasks: BackgroundTasks): + """ + Background task pattern - ref is held for duration of bg task. + """ + try: + ws = get_workspace(request) + except HTTPException as e: + return JSONResponse(status_code=e.status_code, content={"error": e.detail}) + + # Get ref before scheduling background task + rag = await workspace_mgr.get_or_create(ws) + response_data = {"workspace": rag.workspace, "ref_count_before_bg": 1} + + async def background_work(): + """Simulated background work.""" + await asyncio.sleep(0.1) + # Release after background work completes + workspace_mgr.release(ws) + # Signal that bg work is done + if bg_done_event: + bg_done_event.set() + response_data["bg_completed"] = True + + background_tasks.add_task(background_work) + + # Return immediately without releasing (bg task will release) + return JSONResponse(content=response_data) + + @app.post("/test/streaming") + async def streaming_route(request: Request): + """ + Streaming response pattern - ref held during stream. + """ + try: + ws = get_workspace(request) + except HTTPException as e: + return JSONResponse(status_code=e.status_code, content={"error": e.detail}) + + async def stream_generator() -> AsyncGenerator[str, None]: + try: + rag = await workspace_mgr.get_or_create(ws) + yield f"data: workspace={rag.workspace}\n" + await asyncio.sleep(0.05) + yield "data: done\n" + finally: + workspace_mgr.release(ws) + + return StreamingResponse( + stream_generator(), + media_type="text/event-stream", + ) + + @app.post("/test/streaming-error") + async def streaming_error_route(request: Request): + """ + Streaming response pattern that errors mid-stream - verifies finally block runs. + """ + try: + ws = get_workspace(request) + except HTTPException as e: + return JSONResponse(status_code=e.status_code, content={"error": e.detail}) + + async def stream_generator() -> AsyncGenerator[str, None]: + try: + rag = await workspace_mgr.get_or_create(ws) + yield f"data: workspace={rag.workspace}\n" + await asyncio.sleep(0.02) + yield "data: before-error\n" + # Simulate an error mid-stream + raise RuntimeError("simulated error") + except RuntimeError: + # Re-raise so the stream fails + raise + finally: + workspace_mgr.release(ws) + + return StreamingResponse( + stream_generator(), + media_type="text/event-stream", + ) + + @app.post("/test/hold-ref") + async def hold_ref_route(request: Request, duration: float = 1.0): + """ + Route that holds a workspace ref for a specified duration. + Useful for testing capacity and eviction under load. + """ + try: + ws = get_workspace(request) + except HTTPException as e: + return JSONResponse(status_code=e.status_code, content={"error": e.detail}) + + try: + rag = await workspace_mgr.get_or_create(ws) + # Hold the ref for the specified duration + await asyncio.sleep(duration) + return JSONResponse( + content={"workspace": rag.workspace, "held_for": duration} + ) + except WorkspaceCapacityError: + return JSONResponse( + status_code=503, content={"error": "All workspace slots busy"} + ) + finally: + workspace_mgr.release(ws) + + @app.get("/test/stats") + async def stats_route(request: Request): + """Return workspace manager stats.""" + return JSONResponse(content=workspace_mgr.get_stats()) + + @app.get("/test/workspace/{workspace_name}") + async def get_workspace_route(workspace_name: str, request: Request): + """Direct workspace access via path parameter.""" + try: + ws = sanitize_workspace_name(workspace_name) + except WorkspaceNameError as e: + return JSONResponse(status_code=400, content={"error": str(e)}) + + try: + rag = await workspace_mgr.get_or_create(ws) + return JSONResponse(content={"workspace": rag.workspace}) + except WorkspaceCapacityError: + return JSONResponse( + status_code=503, content={"error": "All workspace slots busy"} + ) + finally: + workspace_mgr.release(ws) + + return app, workspace_mgr + + +# ============================================================================= +# Test Fixtures +# ============================================================================= + + +@pytest.fixture +async def client(): + """Create an async HTTP client for testing.""" + app, workspace_mgr = create_test_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac, workspace_mgr + + +@pytest.fixture +async def small_client(): + """Create a client with small max_instances for capacity tests.""" + app, workspace_mgr = create_test_app(max_instances=2) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac, workspace_mgr + + +@pytest.fixture +async def eviction_client(): + """Create a client with max_instances=3 for LRU eviction tests.""" + app, workspace_mgr = create_test_app(max_instances=3) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac, workspace_mgr + + +@pytest.fixture +async def client_with_bg_event(): + """Create a client with bg_done_event for testing background task completion.""" + bg_done_event = asyncio.Event() + app, workspace_mgr = create_test_app(bg_done_event=bg_done_event) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac, workspace_mgr, bg_done_event + + +# ============================================================================= +# Test 1: Request with LIGHTRAG-WORKSPACE header -> correct workspace used +# ============================================================================= + + +@pytest.mark.offline +class TestWorkspaceHeaderExtraction: + """Tests for workspace extraction from HTTP headers.""" + + @pytest.mark.asyncio + async def test_header_workspace_is_used(self, client): + """Test that LIGHTRAG-WORKSPACE header is used correctly.""" + ac, workspace_mgr = client + + response = await ac.post( + "/test/regular", + headers={"LIGHTRAG-WORKSPACE": "ws-test"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["workspace"] == "ws-test" + + @pytest.mark.asyncio + async def test_header_case_insensitive(self, client): + """Test that workspace names are lowercased.""" + ac, workspace_mgr = client + + response = await ac.post( + "/test/regular", + headers={"LIGHTRAG-WORKSPACE": "MyWorkspace"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["workspace"] == "myworkspace" + + +# ============================================================================= +# Test 2: Request without header -> falls back to default (empty string) +# ============================================================================= + + +@pytest.mark.offline +class TestDefaultWorkspaceFallback: + """Tests for default workspace fallback behavior.""" + + @pytest.mark.asyncio + async def test_no_header_uses_default(self, client): + """Test that requests without header use empty string as default.""" + ac, workspace_mgr = client + + response = await ac.post("/test/regular") + + assert response.status_code == 200 + data = response.json() + assert data["workspace"] == "" + + @pytest.mark.asyncio + async def test_empty_header_uses_default(self, client): + """Test that empty header value uses empty string as default.""" + ac, workspace_mgr = client + + response = await ac.post( + "/test/regular", + headers={"LIGHTRAG-WORKSPACE": ""}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["workspace"] == "" + + +# ============================================================================= +# Test 3: Invalid workspace name -> HTTP 400 +# ============================================================================= + + +@pytest.mark.offline +class TestWorkspaceNameValidation: + """Tests for workspace name validation via HTTP.""" + + @pytest.mark.asyncio + async def test_invalid_special_chars_returns_400(self, client): + """Test that special characters in workspace name return 400.""" + ac, workspace_mgr = client + + response = await ac.post( + "/test/regular", + headers={"LIGHTRAG-WORKSPACE": "ws!@#"}, + ) + + assert response.status_code == 400 + assert "only lowercase letters" in response.json()["error"] + + @pytest.mark.asyncio + async def test_path_traversal_returns_400(self, client): + """Test that path traversal attempts return 400.""" + ac, workspace_mgr = client + + response = await ac.post( + "/test/regular", + headers={"LIGHTRAG-WORKSPACE": "../../etc"}, + ) + + assert response.status_code == 400 + assert "path traversal" in response.json()["error"] + + @pytest.mark.asyncio + async def test_slash_in_name_returns_400(self, client): + """Test that forward slashes in workspace name return 400.""" + ac, workspace_mgr = client + + response = await ac.post( + "/test/regular", + headers={"LIGHTRAG-WORKSPACE": "ws/name"}, + ) + + assert response.status_code == 400 + + @pytest.mark.asyncio + async def test_backslash_returns_400(self, client): + """Test that backslashes in workspace name return 400.""" + ac, workspace_mgr = client + + response = await ac.post( + "/test/regular", + headers={"LIGHTRAG-WORKSPACE": "ws\\name"}, + ) + + assert response.status_code == 400 + + @pytest.mark.asyncio + async def test_too_long_name_returns_400(self, client): + """Test that names exceeding 64 characters return 400.""" + ac, workspace_mgr = client + + response = await ac.post( + "/test/regular", + headers={"LIGHTRAG-WORKSPACE": "a" * 65}, + ) + + assert response.status_code == 400 + assert "64 characters" in response.json()["error"] + + +# ============================================================================= +# Test 4: Multiple concurrent requests to different workspaces -> each isolated +# ============================================================================= + + +@pytest.mark.offline +class TestConcurrentWorkspaceIsolation: + """Tests for concurrent request isolation.""" + + @pytest.mark.asyncio + async def test_concurrent_different_workspaces(self, client): + """Test that concurrent requests to different workspaces are isolated.""" + ac, workspace_mgr = client + + workspaces = ["ws-a", "ws-b", "ws-c", "ws-d", "ws-e"] + + # Send 5 concurrent requests + tasks = [ + ac.post("/test/regular", headers={"LIGHTRAG-WORKSPACE": ws}) + for ws in workspaces + ] + responses = await asyncio.gather(*tasks) + + # All should succeed + assert all(r.status_code == 200 for r in responses) + + # Each should have its own workspace + for i, ws in enumerate(workspaces): + data = responses[i].json() + assert data["workspace"] == ws + + # Stats should show 5 active instances + stats = workspace_mgr.get_stats() + assert stats["active_instances"] == 5 + + @pytest.mark.asyncio + async def test_concurrent_same_workspace_returns_same_instance(self, client): + """Test that concurrent requests to same workspace share instance.""" + ac, workspace_mgr = client + + # Send 5 concurrent requests to same workspace + tasks = [ + ac.post("/test/regular", headers={"LIGHTRAG-WORKSPACE": "shared-ws"}) + for _ in range(5) + ] + responses = await asyncio.gather(*tasks) + + # All should succeed + assert all(r.status_code == 200 for r in responses) + + # Stats should show 1 active instance (cached) + stats = workspace_mgr.get_stats() + assert stats["active_instances"] == 1 + assert stats["cache_hits"] >= 4 # Some hits from concurrent access + + +# ============================================================================= +# Test 5: Background task pattern - ref count held during execution +# ============================================================================= + + +@pytest.mark.offline +class TestBackgroundTaskPattern: + """Tests for background task workspace management pattern.""" + + @pytest.mark.asyncio + async def test_bg_task_holds_ref_during_execution(self, client_with_bg_event): + """Test that background task pattern holds ref count during execution.""" + ac, workspace_mgr, bg_done_event = client_with_bg_event + + # Initial stats - no ref counts + initial_stats = workspace_mgr.get_stats() + assert sum(initial_stats.get("ref_counts", {}).values()) == 0 + + # Send request with background task + response = await ac.post( + "/test/background", + headers={"LIGHTRAG-WORKSPACE": "bg-ws"}, + ) + + # Response should indicate ref was held + data = response.json() + assert data["workspace"] == "bg-ws" + assert data["ref_count_before_bg"] == 1 + + # Wait for background task to complete using event + await asyncio.wait_for(bg_done_event.wait(), timeout=2.0) + + # After bg task completes, ref should be released + stats = workspace_mgr.get_stats() + assert sum(stats.get("ref_counts", {}).values()) == 0 + + @pytest.mark.asyncio + async def test_multiple_bg_tasks_isolated(self, small_client): + """Test that multiple background tasks maintain isolation.""" + ac, workspace_mgr = small_client + + # Send multiple background task requests + tasks = [ + ac.post("/test/background", headers={"LIGHTRAG-WORKSPACE": f"bg-{i}"}) + for i in range(2) # Only 2 to avoid capacity issues + ] + responses = await asyncio.gather(*tasks) + + # All should succeed + assert all(r.status_code == 200 for r in responses) + + # Wait for completion + await asyncio.sleep(0.3) + + # All refs should be released + stats = workspace_mgr.get_stats() + assert sum(stats.get("ref_counts", {}).values()) == 0 + + +# ============================================================================= +# Test 6: Streaming response - ref count held during stream, released after +# ============================================================================= + + +@pytest.mark.offline +class TestStreamingResponsePattern: + """Tests for streaming response workspace management pattern.""" + + @pytest.mark.asyncio + async def test_streaming_holds_ref_during_stream(self, client): + """Test that streaming response holds ref count during stream.""" + ac, workspace_mgr = client + + # Initial stats - no ref counts + initial_stats = workspace_mgr.get_stats() + assert sum(initial_stats.get("ref_counts", {}).values()) == 0 + + # Start streaming request + async with ac.stream( + "POST", + "/test/streaming", + headers={"LIGHTRAG-WORKSPACE": "stream-ws"}, + ) as response: + # Read first chunk to ensure stream generator has started + first_chunk = await response.aread() + assert first_chunk # Should have some data + + # After stream started, ref should be held + stats = workspace_mgr.get_stats() + # The workspace should be in cache + assert stats["active_instances"] >= 1 + assert "stream-ws" in stats.get("ref_counts", {}) + + # After stream completes, ref should be released + stats = workspace_mgr.get_stats() + assert sum(stats.get("ref_counts", {}).values()) == 0 + + @pytest.mark.asyncio + async def test_streaming_workspace_isolated(self, client): + """Test that streaming responses maintain workspace isolation.""" + ac, workspace_mgr = client + + # Start multiple streams concurrently + async def consume_stream(ws: str) -> list[str]: + chunks = [] + async with ac.stream( + "POST", + "/test/streaming", + headers={"LIGHTRAG-WORKSPACE": ws}, + ) as response: + async for chunk in response.aiter_lines(): + chunks.append(chunk) + return chunks + + tasks = [consume_stream(f"stream-{i}") for i in range(3)] + all_chunks = await asyncio.gather(*tasks) + + # Each stream should have its own workspace + for i, chunks in enumerate(all_chunks): + workspace_line = [c for c in chunks if "workspace=" in c][0] + assert f"workspace=stream-{i}" in workspace_line + + # All refs should be released + stats = workspace_mgr.get_stats() + assert sum(stats.get("ref_counts", {}).values()) == 0 + + @pytest.mark.asyncio + async def test_streaming_error_releases_ref(self, client): + """Test that streaming error still releases workspace ref via finally block.""" + ac, workspace_mgr = client + + # Initial stats - no ref counts + initial_stats = workspace_mgr.get_stats() + assert sum(initial_stats.get("ref_counts", {}).values()) == 0 + + # Start streaming request that will error mid-stream + # The exception propagates during the request, not during iteration + error_occurred = False + try: + async with ac.stream( + "POST", + "/test/streaming-error", + headers={"LIGHTRAG-WORKSPACE": "error-ws"}, + ) as response: + # Read chunks until error occurs + async for line in response.aiter_lines(): + pass + except Exception as e: + error_occurred = True + assert "simulated error" in str(e) + + assert error_occurred + + # After stream error, ref should still be released (finally block runs) + await asyncio.sleep(0.05) # Small delay for cleanup + stats = workspace_mgr.get_stats() + assert sum(stats.get("ref_counts", {}).values()) == 0 + + +# ============================================================================= +# Test 7: WorkspaceCapacityError -> HTTP 503 +# ============================================================================= + + +@pytest.mark.offline +class TestCapacityLimit: + """Tests for workspace capacity limit enforcement.""" + + @pytest.mark.asyncio + async def test_capacity_error_returns_503(self, small_client): + """Test that capacity exhaustion returns 503 via actual HTTP requests.""" + ac, workspace_mgr = small_client + + # max_instances=2, use /test/hold-ref to hold refs via HTTP + # Start 2 concurrent requests that hold refs + tasks = [ + ac.post( + "/test/hold-ref", + params={"duration": 5.0}, + headers={"LIGHTRAG-WORKSPACE": f"holder-{i}"}, + ) + for i in range(2) + ] + + # Start both tasks concurrently + pending = [asyncio.create_task(t) for t in tasks] + + # Wait for both to start and hold refs + await asyncio.sleep(0.2) + + # At this point, both holders should have ref_count=1 + stats = workspace_mgr.get_stats() + assert stats["active_instances"] == 2 + assert stats.get("ref_counts", {}).get("holder-0", 0) == 1 + assert stats.get("ref_counts", {}).get("holder-1", 0) == 1 + + # Now try a 3rd request - should get 503 + response = await ac.post( + "/test/regular", + headers={"LIGHTRAG-WORKSPACE": "ws-c"}, + ) + assert response.status_code == 503 + assert "All workspace slots busy" in response.json()["error"] + + # Clean up - cancel the pending tasks + for task in pending: + task.cancel() + await asyncio.gather(*pending, return_exceptions=True) + + # After cleanup, refs should be released + await asyncio.sleep(0.1) + stats = workspace_mgr.get_stats() + assert sum(stats.get("ref_counts", {}).values()) == 0 + + @pytest.mark.asyncio + async def test_same_workspace_request_within_capacity(self, small_client): + """Test that same workspace requests don't exceed capacity.""" + ac, workspace_mgr = small_client + + # Make multiple requests to same workspace + # With same workspace, they should share the cached instance + tasks = [ + ac.post("/test/regular", headers={"LIGHTRAG-WORKSPACE": "shared"}) + for _ in range(5) + ] + responses = await asyncio.gather(*tasks) + + # All should succeed (same workspace = same instance) + assert all(r.status_code == 200 for r in responses) + + # Only 1 active instance + stats = workspace_mgr.get_stats() + assert stats["active_instances"] == 1 + + +# ============================================================================= +# Test 8: LRU eviction under concurrent load +# ============================================================================= + + +@pytest.mark.offline +class TestLRUEviction: + """Tests for LRU eviction under memory pressure.""" + + @pytest.mark.asyncio + async def test_lru_eviction_after_capacity_exceeded(self, eviction_client): + """Test that LRU eviction occurs after capacity exceeded.""" + ac, workspace_mgr = eviction_client + + # max_instances=3, create 5 different workspaces + for i in range(5): + response = await ac.post( + "/test/regular", + headers={"LIGHTRAG-WORKSPACE": f"evict-{i}"}, + ) + assert response.status_code == 200 + + # After creating 5 workspaces with only 3 slots, + # LRU eviction should have occurred + stats = workspace_mgr.get_stats() + + # Active instances should be <= max_instances + assert stats["active_instances"] <= 3 + + # Verify eviction happened + assert stats.get("evictions", 0) > 0 or stats["active_instances"] < 5 + + @pytest.mark.asyncio + async def test_lru_ordering_respected(self, eviction_client): + """Test that LRU ordering is respected - least recently used evicted first.""" + ac, workspace_mgr = eviction_client + + # Create 3 workspaces: ws-a, ws-b, ws-c + for ws in ["ws-a", "ws-b", "ws-c"]: + response = await ac.post( + "/test/regular", + headers={"LIGHTRAG-WORKSPACE": ws}, + ) + assert response.status_code == 200 + + # Access ws-a to make it most recently used + await ac.post("/test/regular", headers={"LIGHTRAG-WORKSPACE": "ws-a"}) + + # Release ws-b (least recently used after ws-c was accessed last) + # Wait a bit to ensure ordering + await asyncio.sleep(0.01) + + # Access ws-c + await ac.post("/test/regular", headers={"LIGHTRAG-WORKSPACE": "ws-c"}) + + # Now create ws-d - should evict ws-b (LRU) + response = await ac.post( + "/test/regular", + headers={"LIGHTRAG-WORKSPACE": "ws-d"}, + ) + assert response.status_code == 200 + + # Check that ws-b is no longer in cache + # (This is implicit - if eviction works correctly, the cache size is limited) + stats = workspace_mgr.get_stats() + assert stats["active_instances"] <= 3 + + @pytest.mark.asyncio + async def test_workspace_is_not_evicted_while_in_use(self, eviction_client): + """Test that workspaces with active refs are not evicted.""" + ac, workspace_mgr = eviction_client + + # max_instances=3 + # Use /test/hold-ref to hold refs via HTTP + + # Start 2 concurrent requests that hold refs for ws-1 and ws-2 + hold_tasks = [ + ac.post( + "/test/hold-ref", + params={"duration": 5.0}, + headers={"LIGHTRAG-WORKSPACE": f"ws-{i}"}, + ) + for i in range(1, 3) # ws-1 and ws-2 + ] + pending = [asyncio.create_task(t) for t in hold_tasks] + + # Wait for both to start and hold refs + await asyncio.sleep(0.2) + + # At this point: ws-1 and ws-2 have ref_count=1 (from hold-ref) + # Active instances = 2 + stats = workspace_mgr.get_stats() + assert stats.get("ref_counts", {}).get("ws-1", 0) == 1 + assert stats.get("ref_counts", {}).get("ws-2", 0) == 1 + + # Now create ws-3 and ws-4 via regular requests (complete and release) + for ws in ["ws-3", "ws-4"]: + response = await ac.post( + "/test/regular", + headers={"LIGHTRAG-WORKSPACE": ws}, + ) + assert response.status_code == 200 + # After these complete, ws-3 and ws-4 have ref_count=0, active_instances=4 + + # Now try to create ws-5 - this should trigger eviction + # Eviction should evict ws-3 or ws-4 (ref_count=0), NOT ws-1 or ws-2 (ref_count>0) + response = await ac.post( + "/test/regular", + headers={"LIGHTRAG-WORKSPACE": "ws-5"}, + ) + assert response.status_code == 200 + + # Check that ws-1 and ws-2 survived eviction (still in cache) + stats = workspace_mgr.get_stats() + assert "ws-1" in stats.get("ref_counts", {}) + assert "ws-2" in stats.get("ref_counts", {}) + assert stats.get("ref_counts", {}).get("ws-1", 0) >= 1 + assert stats.get("ref_counts", {}).get("ws-2", 0) >= 1 + + # Verify eviction happened (ws-3 or ws-4 should be evicted) + assert stats.get("evictions", 0) > 0 + + # Clean up + for task in pending: + task.cancel() + await asyncio.gather(*pending, return_exceptions=True) + + +# ============================================================================= +# Cleanup helper +# ============================================================================= diff --git a/tests/integration/test_workspace_registry_api.py b/tests/integration/test_workspace_registry_api.py new file mode 100644 index 0000000000..0cddf4e3a2 --- /dev/null +++ b/tests/integration/test_workspace_registry_api.py @@ -0,0 +1,544 @@ +""" +Integration tests for Workspace Registry API. + +These tests verify that the FastAPI workspace listing and auto-register mechanism +correctly: +- GET /workspaces returns empty list when no workspaces registered +- GET /workspaces returns list of registered workspaces with correct fields +- Document API calls with LIGHTRAG-WORKSPACE header auto-register the workspace +- Same workspace called again updates last_seen but keeps first_seen +- Multiple different workspaces each registered correctly +- Empty/missing workspace header is handled correctly +- Special characters in workspace names are handled + +All external services are mocked. +""" + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path +from typing import Optional + +import pytest +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from httpx import ASGITransport, AsyncClient +from pydantic import BaseModel, Field, ConfigDict + +from lightrag.api.workspace_registry import WorkspaceRegistry, get_workspace_registry + + +# ============================================================================= +# Pydantic Models (duplicated from workspace_routes to avoid config import) +# ============================================================================= + + +class WorkspaceInfo(BaseModel): + """Workspace information model.""" + + name: str = Field(description="The workspace name") + first_seen: str = Field( + description="ISO timestamp when the workspace was first seen" + ) + last_seen: str = Field( + description="ISO timestamp when the workspace was last accessed" + ) + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "name": "my-workspace", + "first_seen": "2026-04-30T10:00:00+00:00", + "last_seen": "2026-04-30T11:30:00+00:00", + } + } + ) + + +class WorkspacesResponse(BaseModel): + """Response model for listing workspaces.""" + + workspaces: list[WorkspaceInfo] = Field(description="List of registered workspaces") + + model_config = ConfigDict( + json_schema_extra={ + "example": { + "workspaces": [ + { + "name": "my-workspace", + "first_seen": "2026-04-30T10:00:00+00:00", + "last_seen": "2026-04-30T11:30:00+00:00", + }, + ] + } + } + ) + + +# ============================================================================= +# Test App Factory +# ============================================================================= + + +def create_test_app( + working_dir: Optional[str] = None, +) -> tuple[FastAPI, WorkspaceRegistry]: + """ + Create a minimal FastAPI app that includes workspace routes. + + Args: + working_dir: Optional directory for the workspace registry storage. + + Returns: + tuple: (app, workspace_registry) - the test app and the registry instance + """ + # Reset the global registry instance to use the specified working_dir + import lightrag.api.workspace_registry as wr_module + + wr_module._registry_instance = None + wr_module._registry_lock = __import__("threading").Lock() + + # Get fresh registry with our working_dir + registry = get_workspace_registry(working_dir=working_dir) + + app = FastAPI() + + # Add workspace listing endpoint (GET /workspaces) + @app.get("/workspaces", response_model=WorkspacesResponse) + async def list_workspaces(): + """List all registered workspaces.""" + workspaces = registry.get_workspaces() + return WorkspacesResponse( + workspaces=[ + WorkspaceInfo( + name=w["name"], + first_seen=w["first_seen"], + last_seen=w["last_seen"], + ) + for w in workspaces + ] + ) + + # Add a mock document endpoint that simulates auto-register behavior + @app.post("/test/documents") + async def mock_document_endpoint(request: Request): + """Mock document endpoint that simulates workspace auto-registration.""" + workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip() + if workspace: + registry.register_workspace(workspace) + return JSONResponse( + content={"status": "ok", "workspace": workspace or "default"} + ) + + @app.post("/test/texts") + async def mock_insert_texts_endpoint(request: Request): + """Mock text insertion endpoint that simulates workspace auto-registration.""" + workspace = request.headers.get("LIGHTRAG-WORKSPACE", "").strip() + if workspace: + registry.register_workspace(workspace) + return JSONResponse( + content={"status": "ok", "workspace": workspace or "default"} + ) + + return app, registry + + +# ============================================================================= +# Test Fixtures +# ============================================================================= + + +@pytest.fixture +async def test_app(): + """Create a test app with temporary directory for registry storage.""" + with tempfile.TemporaryDirectory() as tmp_dir: + app, registry = create_test_app(working_dir=tmp_dir) + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac, registry, Path(tmp_dir) + + +# ============================================================================= +# Test 1: GET /workspaces returns empty list when no workspaces registered +# ============================================================================= + + +@pytest.mark.offline +class TestGetWorkspacesEmpty: + """Tests for GET /workspaces with no registered workspaces.""" + + @pytest.mark.asyncio + async def test_empty_workspaces_returns_empty_list(self, test_app): + """Test that GET /workspaces returns empty list when no workspaces registered.""" + ac, registry, _ = test_app + + response = await ac.get("/workspaces") + + assert response.status_code == 200 + data = response.json() + assert "workspaces" in data + assert data["workspaces"] == [] + + @pytest.mark.asyncio + async def test_empty_workspaces_response_is_valid_json(self, test_app): + """Test that response is valid JSON with expected structure.""" + ac, registry, _ = test_app + + response = await ac.get("/workspaces") + + assert response.status_code == 200 + # Verify it's valid JSON + data = response.json() + assert isinstance(data, dict) + assert "workspaces" in data + assert isinstance(data["workspaces"], list) + + +# ============================================================================= +# Test 2: GET /workspaces returns registered workspaces with correct fields +# ============================================================================= + + +@pytest.mark.offline +class TestGetWorkspacesWithData: + """Tests for GET /workspaces with registered workspaces.""" + + @pytest.mark.asyncio + async def test_workspace_has_required_fields(self, test_app): + """Test that registered workspace has all required fields.""" + ac, registry, _ = test_app + + # Register a workspace directly + registry.register_workspace("test-workspace") + + response = await ac.get("/workspaces") + + assert response.status_code == 200 + data = response.json() + assert len(data["workspaces"]) == 1 + + workspace = data["workspaces"][0] + assert "name" in workspace + assert "first_seen" in workspace + assert "last_seen" in workspace + assert workspace["name"] == "test-workspace" + + @pytest.mark.asyncio + async def test_workspace_timestamps_are_iso_format(self, test_app): + """Test that timestamps are in ISO format.""" + ac, registry, _ = test_app + + registry.register_workspace("test-workspace") + + response = await ac.get("/workspaces") + + data = response.json() + workspace = data["workspaces"][0] + + # Check that timestamps look like ISO format (contain T and timezone) + assert "T" in workspace["first_seen"] + assert "T" in workspace["last_seen"] + # Should contain timezone offset (+00:00) or Z + assert "+" in workspace["first_seen"] or workspace["first_seen"].endswith("Z") + assert "+" in workspace["last_seen"] or workspace["last_seen"].endswith("Z") + + @pytest.mark.asyncio + async def test_multiple_workspaces_all_appear(self, test_app): + """Test that multiple registered workspaces all appear in response.""" + ac, registry, _ = test_app + + # Register multiple workspaces + registry.register_workspace("workspace-a") + registry.register_workspace("workspace-b") + registry.register_workspace("workspace-c") + + response = await ac.get("/workspaces") + + assert response.status_code == 200 + data = response.json() + assert len(data["workspaces"]) == 3 + + workspace_names = {w["name"] for w in data["workspaces"]} + assert workspace_names == {"workspace-a", "workspace-b", "workspace-c"} + + @pytest.mark.asyncio + async def test_workspaces_sorted_by_last_seen_descending(self, test_app): + """Test that workspaces are sorted by last_seen (most recent first).""" + ac, registry, _ = test_app + + # Register workspaces in order + registry.register_workspace("first") + registry.register_workspace("second") + registry.register_workspace("third") + + response = await ac.get("/workspaces") + + data = response.json() + names = [w["name"] for w in data["workspaces"]] + # Most recently seen (third) should be first + assert names[0] == "third" + assert names[1] == "second" + assert names[2] == "first" + + +# ============================================================================= +# Test 3: Auto-register on document API calls +# ============================================================================= + + +@pytest.mark.offline +class TestAutoRegister: + """Tests for auto-register mechanism via document API calls.""" + + @pytest.mark.asyncio + async def test_document_call_with_header_registers_workspace(self, test_app): + """Test that calling document endpoint with header auto-registers workspace.""" + ac, registry, _ = test_app + + # Call mock document endpoint with workspace header + response = await ac.post( + "/test/documents", headers={"LIGHTRAG-WORKSPACE": "new-workspace"} + ) + + assert response.status_code == 200 + + # Verify workspace was registered + response = await ac.get("/workspaces") + data = response.json() + assert len(data["workspaces"]) == 1 + assert data["workspaces"][0]["name"] == "new-workspace" + + @pytest.mark.asyncio + async def test_texts_endpoint_also_registers_workspace(self, test_app): + """Test that texts insertion endpoint also auto-registers workspace.""" + ac, registry, _ = test_app + + # Call mock texts endpoint with workspace header + response = await ac.post( + "/test/texts", headers={"LIGHTRAG-WORKSPACE": "text-workspace"} + ) + + assert response.status_code == 200 + + # Verify workspace was registered + response = await ac.get("/workspaces") + data = response.json() + assert len(data["workspaces"]) == 1 + assert data["workspaces"][0]["name"] == "text-workspace" + + @pytest.mark.asyncio + async def test_same_workspace_updates_last_seen(self, test_app): + """Test that same workspace called again updates last_seen but keeps first_seen.""" + ac, registry, _ = test_app + + # Register workspace first time + response1 = await ac.post( + "/test/documents", headers={"LIGHTRAG-WORKSPACE": "my-workspace"} + ) + assert response1.status_code == 200 + + # Get the first_seen and last_seen + response = await ac.get("/workspaces") + data = response.json() + first_seen_before = data["workspaces"][0]["first_seen"] + last_seen_before = data["workspaces"][0]["last_seen"] + + # Wait a tiny bit to ensure timestamp difference + import asyncio + + await asyncio.sleep(0.01) + + # Register same workspace again + response2 = await ac.post( + "/test/documents", headers={"LIGHTRAG-WORKSPACE": "my-workspace"} + ) + assert response2.status_code == 200 + + # Verify first_seen is preserved but last_seen is updated + response = await ac.get("/workspaces") + data = response.json() + assert len(data["workspaces"]) == 1 # Still only one workspace + assert data["workspaces"][0]["name"] == "my-workspace" + assert data["workspaces"][0]["first_seen"] == first_seen_before + assert data["workspaces"][0]["last_seen"] >= last_seen_before + + @pytest.mark.asyncio + async def test_multiple_different_workspaces_registered_correctly(self, test_app): + """Test that multiple different workspaces are each registered correctly.""" + ac, registry, _ = test_app + + workspaces = ["ws-alpha", "ws-beta", "ws-gamma"] + + for ws in workspaces: + response = await ac.post( + "/test/documents", headers={"LIGHTRAG-WORKSPACE": ws} + ) + assert response.status_code == 200 + + # Verify all workspaces are registered + response = await ac.get("/workspaces") + data = response.json() + assert len(data["workspaces"]) == 3 + + registered_names = {w["name"] for w in data["workspaces"]} + assert registered_names == set(workspaces) + + +# ============================================================================= +# Test 4: Edge cases +# ============================================================================= + + +@pytest.mark.offline +class TestEdgeCases: + """Tests for edge cases in workspace registry.""" + + @pytest.mark.asyncio + async def test_missing_workspace_header(self, test_app): + """Test that missing workspace header doesn't register anything.""" + ac, registry, _ = test_app + + # Call without workspace header + response = await ac.post("/test/documents") + + assert response.status_code == 200 + assert response.json()["workspace"] == "default" + + # Verify no workspace was registered + response = await ac.get("/workspaces") + data = response.json() + assert data["workspaces"] == [] + + @pytest.mark.asyncio + async def test_empty_workspace_header(self, test_app): + """Test that empty workspace header value doesn't register anything.""" + ac, registry, _ = test_app + + # Call with empty workspace header + response = await ac.post("/test/documents", headers={"LIGHTRAG-WORKSPACE": ""}) + + assert response.status_code == 200 + + # Verify no workspace was registered (empty string is not registered) + response = await ac.get("/workspaces") + data = response.json() + assert data["workspaces"] == [] + + @pytest.mark.asyncio + async def test_special_characters_in_workspace_name(self, test_app): + """Test workspace names with special characters are handled correctly.""" + ac, registry, _ = test_app + + # Register workspace with underscore and numbers (valid) + registry.register_workspace("workspace_123") + + response = await ac.get("/workspaces") + data = response.json() + + assert len(data["workspaces"]) == 1 + assert data["workspaces"][0]["name"] == "workspace_123" + + @pytest.mark.asyncio + async def test_workspace_name_case_preserved(self, test_app): + """Test that workspace names preserve their case.""" + ac, registry, _ = test_app + + # Register workspace with mixed case + registry.register_workspace("MyWorkspace") + + response = await ac.get("/workspaces") + data = response.json() + + assert len(data["workspaces"]) == 1 + # The name should be preserved as registered + assert data["workspaces"][0]["name"] == "MyWorkspace" + + +# ============================================================================= +# Test 5: Registry persistence +# ============================================================================= + + +@pytest.mark.offline +class TestRegistryPersistence: + """Tests for workspace registry persistence.""" + + @pytest.mark.asyncio + async def test_registry_persists_to_disk(self, test_app): + """Test that registry is persisted to disk.""" + ac, registry, tmp_dir = test_app + + # Register a workspace + registry.register_workspace("persistent-workspace") + + # Verify the registry file exists + registry_file = Path(tmp_dir) / "workspace_registry.json" + assert registry_file.exists() + + # Read and verify content + with open(registry_file, "r") as f: + data = json.load(f) + + assert "persistent-workspace" in data + assert data["persistent-workspace"]["name"] == "persistent-workspace" + assert "first_seen" in data["persistent-workspace"] + assert "last_seen" in data["persistent-workspace"] + + +# ============================================================================= +# Test 6: Response format validation +# ============================================================================= + + +@pytest.mark.offline +class TestResponseFormat: + """Tests for response format validation.""" + + @pytest.mark.asyncio + async def test_response_structure_matches_schema(self, test_app): + """Test that response structure matches the WorkspacesResponse schema.""" + ac, registry, _ = test_app + + # Register a workspace + registry.register_workspace("schema-test") + + response = await ac.get("/workspaces") + + assert response.status_code == 200 + data = response.json() + + # Verify top-level structure + assert isinstance(data, dict) + assert "workspaces" in data + + # Verify workspaces array + assert isinstance(data["workspaces"], list) + + # Verify workspace object structure + if data["workspaces"]: + workspace = data["workspaces"][0] + assert isinstance(workspace.get("name"), str) + assert isinstance(workspace.get("first_seen"), str) + assert isinstance(workspace.get("last_seen"), str) + + @pytest.mark.asyncio + async def test_response_is_json_serializable(self, test_app): + """Test that response is valid JSON and can be serialized.""" + ac, registry, _ = test_app + + registry.register_workspace("json-test") + + response = await ac.get("/workspaces") + + # Get raw content to ensure it's valid JSON + content = response.content + + # Should be able to parse it as JSON without error + data = json.loads(content) + assert isinstance(data, dict) + + # Should be able to serialize it back to JSON + serialized = json.dumps(data) + assert serialized is not None diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000000..8a4ee4c3ad --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests for LightRAG components.""" diff --git a/tests/unit/test_workspace_manager.py b/tests/unit/test_workspace_manager.py new file mode 100644 index 0000000000..1cbbd313a5 --- /dev/null +++ b/tests/unit/test_workspace_manager.py @@ -0,0 +1,801 @@ +"""Unit tests for WorkspaceManager and sanitize_workspace_name.""" + +from __future__ import annotations + +import asyncio + +import pytest + +from lightrag.api.utils import sanitize_workspace_name, WorkspaceNameError +from lightrag.api.workspace_manager import WorkspaceManager, WorkspaceCapacityError + +pytestmark = pytest.mark.offline + + +class MockLightRAG: + """Mock LightRAG instance for testing.""" + + def __init__(self, workspace: str) -> None: + self.workspace = workspace + self.finalize_called = False + self._finalize_task = None + + async def initialize_storages(self) -> None: + """Mock initialize method.""" + pass + + async def finalize_storages(self) -> None: + """Mock finalize method that sets a flag.""" + self.finalize_called = True + + +def mock_factory(workspace: str) -> MockLightRAG: + """Factory function that creates MockLightRAG instances.""" + return MockLightRAG(workspace) + + +# ============================================================================= +# sanitize_workspace_name Tests +# ============================================================================= + + +class TestSanitizeWorkspaceName: + """Tests for the sanitize_workspace_name function.""" + + def test_none_returns_empty(self) -> None: + """Test that None input returns empty string.""" + assert sanitize_workspace_name(None) == "" + + def test_empty_returns_empty(self) -> None: + """Test that empty string input returns empty string.""" + assert sanitize_workspace_name("") == "" + + def test_whitespace_stripped(self) -> None: + """Test that leading/trailing whitespace is stripped.""" + assert sanitize_workspace_name(" my-workspace ") == "my-workspace" + + def test_lowercased(self) -> None: + """Test that uppercase letters are converted to lowercase.""" + assert sanitize_workspace_name("MyWorkspace") == "myworkspace" + + def test_valid_name(self) -> None: + """Test that valid names pass through unchanged (after lowercasing).""" + assert sanitize_workspace_name("ws-1_test") == "ws-1_test" + + def test_path_traversal_rejected(self) -> None: + """Test that path traversal attempts are rejected.""" + with pytest.raises(WorkspaceNameError, match="path traversal detected"): + sanitize_workspace_name("../etc") + + def test_slash_rejected(self) -> None: + """Test that forward slashes are rejected.""" + with pytest.raises(WorkspaceNameError, match="path traversal detected"): + sanitize_workspace_name("ws/name") + + def test_backslash_rejected(self) -> None: + """Test that backslashes are rejected.""" + with pytest.raises(WorkspaceNameError, match="path traversal detected"): + sanitize_workspace_name("ws\\name") + + def test_too_long_rejected(self) -> None: + """Test that names exceeding 64 characters are rejected.""" + long_name = "a" * 65 + with pytest.raises(WorkspaceNameError, match="max 64 characters"): + sanitize_workspace_name(long_name) + + def test_max_length_accepted(self) -> None: + """Test that 64-character names are accepted.""" + max_name = "a" * 64 + assert sanitize_workspace_name(max_name) == max_name + + def test_special_chars_rejected(self) -> None: + """Test that special characters are rejected.""" + with pytest.raises(WorkspaceNameError, match="only lowercase letters.*allowed"): + sanitize_workspace_name("ws!@#") + + # ------------------------------------------------------------------------- + # Additional Edge Cases + # ------------------------------------------------------------------------- + + def test_null_bytes_rejected(self) -> None: + """Test that null bytes in name are rejected.""" + with pytest.raises(WorkspaceNameError, match="only lowercase letters.*allowed"): + sanitize_workspace_name("ws\x00name") + + def test_unicode_chars_rejected(self) -> None: + """Test that unicode characters are rejected.""" + with pytest.raises(WorkspaceNameError, match="only lowercase letters.*allowed"): + sanitize_workspace_name("workspaceζ—₯本θͺž") + + def test_mixed_valid_invalid_space_rejected(self) -> None: + """Test that space in name is rejected.""" + with pytest.raises(WorkspaceNameError, match="only lowercase letters.*allowed"): + sanitize_workspace_name("ws-1 test") + + def test_65_chars_rejected(self) -> None: + """Test that 65-character names are rejected.""" + name_65 = "a" * 65 + with pytest.raises(WorkspaceNameError, match="max 64 characters"): + sanitize_workspace_name(name_65) + + def test_leading_trailing_spaces_trimmed_and_lower(self) -> None: + """Test that leading/trailing spaces are trimmed and lowercased.""" + assert sanitize_workspace_name(" MyWorkspace ") == "myworkspace" + + +# ============================================================================= +# WorkspaceManager Tests +# ============================================================================= + + +class TestWorkspaceManager: + """Tests for the WorkspaceManager class.""" + + @pytest.fixture + def fresh_manager(self) -> WorkspaceManager: + """Create a fresh WorkspaceManager instance for each test.""" + return WorkspaceManager(factory=mock_factory, max_instances=10) + + @pytest.fixture + def small_manager(self) -> WorkspaceManager: + """Create a WorkspaceManager with small capacity for eviction tests.""" + return WorkspaceManager(factory=mock_factory, max_instances=3) + + # ------------------------------------------------------------------------- + # Basic get_or_create Tests + # ------------------------------------------------------------------------- + + @pytest.mark.asyncio + async def test_get_or_create_creates_instance(self, fresh_manager) -> None: + """Test that first call to get_or_create creates a new instance.""" + instance = await fresh_manager.get_or_create("ws-a") + assert isinstance(instance, MockLightRAG) + assert instance.workspace == "ws-a" + assert fresh_manager.get_stats()["active_instances"] == 1 + + @pytest.mark.asyncio + async def test_get_or_create_returns_cached(self, fresh_manager) -> None: + """Test that second call returns the same cached instance.""" + instance1 = await fresh_manager.get_or_create("ws-a") + instance2 = await fresh_manager.get_or_create("ws-a") + assert instance1 is instance2 + assert fresh_manager.get_stats()["active_instances"] == 1 + assert fresh_manager.get_stats()["cache_hits"] == 1 + + # ------------------------------------------------------------------------- + # LRU Eviction Tests + # ------------------------------------------------------------------------- + + @pytest.mark.asyncio + async def test_lru_eviction_at_capacity(self, small_manager) -> None: + """Test LRU eviction when cache reaches capacity.""" + # Create 3 workspaces (fills cache) + ws_a = await small_manager.get_or_create("ws-a") + await small_manager.get_or_create("ws-b") + await small_manager.get_or_create("ws-c") + + # All should be in cache + assert small_manager.get_stats()["active_instances"] == 3 + + # Release ws-a so it can be evicted + small_manager.release("ws-a") + + # Create ws-d - should evict ws-a (LRU with ref_count=0) + await small_manager.get_or_create("ws-d") + + # ws-d should be in cache + assert small_manager.get_stats()["active_instances"] == 3 + # ws-a should have been evicted + assert ws_a.finalize_called is True + + @pytest.mark.asyncio + async def test_eviction_skips_in_flight(self, small_manager) -> None: + """Test that in-flight workspaces (ref_count > 0) are not evicted. + + Scenario: + - Create ws-1, ws-2, ws-3 (all have ref_count > 0) + - Release ws-2 (ref_count becomes 0, making it evictable) + - Create ws-4 β†’ should evict ws-2 + - Assert ws-1 survived (ref_count > 0) + """ + # Acquire ws-1 (don't release) + ws1 = await small_manager.get_or_create("ws-1") + + # Fill cache with others (all will have ref_count > 0) + await small_manager.get_or_create("ws-2") + await small_manager.get_or_create("ws-3") + + # Cache is full, all have ref_count > 0 + assert small_manager.get_stats()["active_instances"] == 3 + + # Release ws-2 to make it evictable (ref_count=0) + small_manager.release("ws-2") + + # Create ws-4 - should evict ws-2 (LRU with ref_count=0) + await small_manager.get_or_create("ws-4") + assert small_manager.get_stats()["active_instances"] == 3 + + # ws-1 should NOT have been evicted (ref_count > 0) + assert ws1.finalize_called is False + assert small_manager.get_stats()["ref_counts"]["ws-1"] == 1 + + @pytest.mark.asyncio + async def test_lru_eviction_order_verification(self, small_manager) -> None: + """Test LRU eviction picks LEAST recently used, not just oldest created. + + Scenario: + - Create ws-1, ws-2, ws-3 (in that order) + - Access ws-1 again (moves it to end of LRU) + - Release ws-2 (ref_count=0) + - Create ws-4 β†’ should evict ws-2 (least recently used with ref_count=0), + NOT ws-3 (oldest created but not LRU anymore) + """ + # Create 3 workspaces (fills cache) + ws_1 = await small_manager.get_or_create("ws-1") + ws_2 = await small_manager.get_or_create("ws-2") + ws_3 = await small_manager.get_or_create("ws-3") + + # Now access order is: ws-1, ws-2, ws-3 (ws-1 was created first) + + # Access ws-1 again - this moves it to end of LRU order + # Now LRU order is: ws-2 (oldest accessed), ws-3, ws-1 (most recent) + await small_manager.get_or_create("ws-1") + + # Release ws-2 (ref_count=0, but ws-1 is most recent, ws-3 is middle) + small_manager.release("ws-2") + + # Create ws-4 - should evict ws-2 (LRU with ref_count=0) + ws_4 = await small_manager.get_or_create("ws-4") + + # Verify ws-4 is in cache + assert small_manager.get_stats()["active_instances"] == 3 + + # ws-2 should have been evicted (least recently used with ref_count=0) + assert ws_2.finalize_called is True + + # ws-1 should NOT have been evicted (ref_count > 0 and was recently accessed) + assert ws_1.finalize_called is False + + # ws-3 should NOT have been evicted + assert ws_3.finalize_called is False + + # Verify ws-4 is usable + assert ws_4.workspace == "ws-4" + + @pytest.mark.asyncio + async def test_eviction_of_most_recently_accessed(self, small_manager) -> None: + """Test eviction respects access order, not creation order. + + Scenario: + - Create ws-1, ws-2, ws-3 (in that order) + - Access ws-1 again (move to end of LRU) + - Release ws-2 + - Create ws-4 + - Verify ws-2 was evicted (not ws-1, which was recently accessed) + """ + # Create ws-1, ws-2, ws-3 (fills capacity) + ws_1 = await small_manager.get_or_create("ws-1") + ws_2 = await small_manager.get_or_create("ws-2") + ws_3 = await small_manager.get_or_create("ws-3") + + # Access ws-1 again to move it to end (most recently accessed) + await small_manager.get_or_create("ws-1") + + # Release ws-2 (now ref_count=0, LRU candidate) + small_manager.release("ws-2") + + # Create ws-4 - should evict ws-2 + await small_manager.get_or_create("ws-4") + + # Verify state + assert small_manager.get_stats()["active_instances"] == 3 + + # ws-2 should be evicted (oldest accessed with ref_count=0) + assert ws_2.finalize_called is True + + # ws-1 should NOT be evicted (was most recently accessed) + assert ws_1.finalize_called is False + assert ( + small_manager.get_stats()["ref_counts"]["ws-1"] == 2 + ) # created + accessed + + # ws-3 should NOT be evicted + assert ws_3.finalize_called is False + + @pytest.mark.asyncio + async def test_reference_counting_exact_tracking(self, fresh_manager) -> None: + """Test reference counting with exact tracking through multiple operations. + + Scenario: + - get_or_create 3 times for same workspace β†’ ref_count=3 + - release 2 times β†’ ref_count=1 + - get_or_create again β†’ ref_count=2 + """ + # Get ws-a 3 times β†’ ref_count should be 3 + await fresh_manager.get_or_create("ws-a") + await fresh_manager.get_or_create("ws-a") + await fresh_manager.get_or_create("ws-a") + + assert fresh_manager.get_stats()["ref_counts"]["ws-a"] == 3 + + # Release twice β†’ ref_count should be 1 + fresh_manager.release("ws-a") + fresh_manager.release("ws-a") + + assert fresh_manager.get_stats()["ref_counts"]["ws-a"] == 1 + + # Get again β†’ ref_count should be 2 + await fresh_manager.get_or_create("ws-a") + + assert fresh_manager.get_stats()["ref_counts"]["ws-a"] == 2 + + # Verify active instances is still 1 (same workspace) + assert fresh_manager.get_stats()["active_instances"] == 1 + + @pytest.mark.asyncio + async def test_workspace_capacity_error_recovery_after_release( + self, small_manager + ) -> None: + """Test recovery from WorkspaceCapacityError after releasing a workspace. + + Scenario: + - Fill capacity (max=3), all with ref_count > 0 + - WorkspaceCapacityError on 4th attempt + - Release one workspace + - Retry β†’ should succeed + - Verify the newly created workspace is in cache + """ + # Fill cache with 3 workspaces (all with ref_count > 0) + ws_1 = await small_manager.get_or_create("ws-1") + ws_2 = await small_manager.get_or_create("ws-2") + ws_3 = await small_manager.get_or_create("ws-3") + + assert small_manager.get_stats()["active_instances"] == 3 + + # Try to create 4th workspace - should raise WorkspaceCapacityError + with pytest.raises(WorkspaceCapacityError): + await small_manager.get_or_create("ws-4") + + # Release ws-2 to free up a slot + small_manager.release("ws-2") + + # Retry - should succeed now + ws_4 = await small_manager.get_or_create("ws-4") + + assert ws_4 is not None + assert ws_4.workspace == "ws-4" + + # Verify ws-4 is in cache + stats = small_manager.get_stats() + assert stats["active_instances"] == 3 + assert "ws-4" in stats["workspaces"] + + # ws-2 should have been evicted + assert ws_2.finalize_called is True + + # ws-1 and ws-3 should still be active + assert ws_1.finalize_called is False + assert ws_3.finalize_called is False + + @pytest.mark.asyncio + async def test_capacity_error_when_all_busy(self, small_manager) -> None: + """Test WorkspaceCapacityError when all slots have in-flight requests.""" + # Fill cache with all in-flight workspaces + await small_manager.get_or_create("ws-1") + await small_manager.get_or_create("ws-2") + await small_manager.get_or_create("ws-3") + + # All have ref_count > 0, next request should raise error + with pytest.raises(WorkspaceCapacityError): + await small_manager.get_or_create("ws-4") + + @pytest.mark.asyncio + async def test_release_decrements_ref_count(self, fresh_manager) -> None: + """Test that release decrements the reference count.""" + await fresh_manager.get_or_create("ws-a") + await fresh_manager.get_or_create("ws-a") + + assert fresh_manager.get_stats()["ref_counts"]["ws-a"] == 2 + + fresh_manager.release("ws-a") + + assert fresh_manager.get_stats()["ref_counts"]["ws-a"] == 1 + + @pytest.mark.asyncio + async def test_finalize_called_on_eviction(self, small_manager) -> None: + """Test that finalize_storages is called when workspace is evicted.""" + # Create and release ws-a + ws_a = await small_manager.get_or_create("ws-a") + small_manager.release("ws-a") + + # Create ws-b, ws-c to fill cache + await small_manager.get_or_create("ws-b") + await small_manager.get_or_create("ws-c") + + # Create ws-d to trigger eviction of ws-a + await small_manager.get_or_create("ws-d") + + # ws-a should have been finalized + assert ws_a.finalize_called is True + + @pytest.mark.asyncio + async def test_finalize_called_on_shutdown(self, fresh_manager) -> None: + """Test that all instances are finalized on shutdown.""" + ws1 = await fresh_manager.get_or_create("ws-1") + ws2 = await fresh_manager.get_or_create("ws-2") + + await fresh_manager.shutdown() + + assert ws1.finalize_called is True + assert ws2.finalize_called is True + + @pytest.mark.asyncio + async def test_shutdown_with_active_instances(self, fresh_manager) -> None: + """Test shutdown finalizes ALL instances including those with ref_count > 0. + + After shutdown: + - get_stats should show 0 active instances + - workspaces list should be empty + - All finalize_called flags should be True + """ + # Create 3 workspaces with ref_count > 0 (don't release) + ws1 = await fresh_manager.get_or_create("ws-1") + ws2 = await fresh_manager.get_or_create("ws-2") + ws3 = await fresh_manager.get_or_create("ws-3") + + # All have ref_count > 0 + assert fresh_manager.get_stats()["active_instances"] == 3 + assert fresh_manager.get_stats()["ref_counts"]["ws-1"] == 1 + assert fresh_manager.get_stats()["ref_counts"]["ws-2"] == 1 + assert fresh_manager.get_stats()["ref_counts"]["ws-3"] == 1 + + # Shutdown should finalize ALL instances + await fresh_manager.shutdown() + + # All instances should be finalized + assert ws1.finalize_called is True + assert ws2.finalize_called is True + assert ws3.finalize_called is True + + # Stats should show 0 active instances + stats = fresh_manager.get_stats() + assert stats["active_instances"] == 0 + assert stats["workspaces"] == [] + + @pytest.mark.asyncio + async def test_shutdown_clears_cache(self, fresh_manager) -> None: + """Test that shutdown clears the cache.""" + await fresh_manager.get_or_create("ws-1") + await fresh_manager.get_or_create("ws-2") + + assert fresh_manager.get_stats()["active_instances"] == 2 + + await fresh_manager.shutdown() + + assert fresh_manager.get_stats()["active_instances"] == 0 + assert fresh_manager.get_stats()["workspaces"] == [] + + @pytest.mark.asyncio + async def test_get_stats_accuracy(self, fresh_manager) -> None: + """Test that get_stats returns accurate information.""" + # Create ws-a (cache miss) + await fresh_manager.get_or_create("ws-a") + + # Get ws-a again (cache hit) + await fresh_manager.get_or_create("ws-a") + + # Create ws-b (cache miss) + await fresh_manager.get_or_create("ws-b") + + # Get ws-b again (cache hit) + await fresh_manager.get_or_create("ws-b") + + # Release ws-a + fresh_manager.release("ws-a") + + stats = fresh_manager.get_stats() + + assert stats["active_instances"] == 2 + assert stats["max_instances"] == 10 + assert stats["cache_hits"] == 2 + assert stats["cache_misses"] == 2 + assert stats["evictions"] == 0 + assert set(stats["workspaces"]) == {"ws-a", "ws-b"} + assert stats["ref_counts"]["ws-a"] == 1 + assert stats["ref_counts"]["ws-b"] == 2 + + @pytest.mark.asyncio + async def test_stats_reporting_accuracy(self, small_manager) -> None: + """Test stats reporting with cache_hits, cache_misses, evictions, and ref_counts. + + Extended version that verifies: + - cache_hits and cache_misses are correctly tracked + - evictions counter increments on each eviction + - workspaces list matches actual cache contents + - ref_counts dictionary is accurate + """ + # Create 3 workspaces (3 cache misses) + await small_manager.get_or_create("ws-1") + await small_manager.get_or_create("ws-2") + await small_manager.get_or_create("ws-3") + + # Get ws-1 twice (2 cache hits) + await small_manager.get_or_create("ws-1") + await small_manager.get_or_create("ws-1") + + stats = small_manager.get_stats() + assert stats["cache_misses"] == 3 + assert stats["cache_hits"] == 2 + assert stats["evictions"] == 0 + assert stats["ref_counts"]["ws-1"] == 3 + assert stats["ref_counts"]["ws-2"] == 1 + assert stats["ref_counts"]["ws-3"] == 1 + + # Release ws-1 twice (ref_count goes from 3 to 1) + small_manager.release("ws-1") + small_manager.release("ws-1") + + # Release ws-2 (ref_count goes from 1 to 0) + small_manager.release("ws-2") + + # Create ws-4 β†’ should evict ws-2 (eviction #1) + await small_manager.get_or_create("ws-4") + + stats = small_manager.get_stats() + assert stats["cache_misses"] == 4 + assert stats["cache_hits"] == 2 + assert stats["evictions"] == 1 + assert set(stats["workspaces"]) == {"ws-1", "ws-3", "ws-4"} + assert stats["ref_counts"]["ws-1"] == 1 + # ws-2 was evicted and is no longer in ref_counts + assert "ws-2" not in stats["ref_counts"] + assert stats["ref_counts"]["ws-3"] == 1 + assert stats["ref_counts"]["ws-4"] == 1 + + # Release ws-3 (ref_count goes from 1 to 0) + small_manager.release("ws-3") + + # Create ws-5 β†’ should evict ws-3 (eviction #2) + await small_manager.get_or_create("ws-5") + + stats = small_manager.get_stats() + assert stats["evictions"] == 2 + assert set(stats["workspaces"]) == {"ws-1", "ws-4", "ws-5"} + # ws-3 was evicted and is no longer in ref_counts + assert "ws-3" not in stats["ref_counts"] + + @pytest.mark.asyncio + async def test_none_workspace_maps_to_empty(self, fresh_manager) -> None: + """Test that None workspace uses empty string as key.""" + instance1 = await fresh_manager.get_or_create(None) + instance2 = await fresh_manager.get_or_create("") + + # Both should return the same instance (empty string key) + assert instance1 is instance2 + assert instance1.workspace == "" + assert fresh_manager.get_stats()["active_instances"] == 1 + + @pytest.mark.asyncio + async def test_release_nonexistent_logs_warning(self, fresh_manager) -> None: + """Test that releasing unknown workspace doesn't crash.""" + # Should not raise, just log a warning + fresh_manager.release("nonexistent-workspace") + fresh_manager.release(None) # Should be fine since no workspaces exist + + @pytest.mark.asyncio + async def test_async_lock_prevents_double_creation(self, fresh_manager) -> None: + """Test that concurrent get_or_create for same workspace creates only one instance.""" + workspace_name = "concurrent-ws" + + # Create multiple concurrent requests for the same workspace + async def get_workspace() -> MockLightRAG: + return await fresh_manager.get_or_create(workspace_name) + + results = await asyncio.gather( + get_workspace(), + get_workspace(), + get_workspace(), + ) + + # All should return the same instance + assert all(r is results[0] for r in results) + assert fresh_manager.get_stats()["active_instances"] == 1 + assert fresh_manager.get_stats()["cache_hits"] == 2 # 2nd and 3rd are hits + assert fresh_manager.get_stats()["cache_misses"] == 1 # 1st is miss + + @pytest.mark.asyncio + async def test_concurrent_get_or_create_deduplication(self, fresh_manager) -> None: + """Test factory is called exactly ONCE even with 10 concurrent requests. + + Uses a side-effect counter to track factory calls. + """ + factory_call_count = 0 + + def counting_factory(workspace: str) -> MockLightRAG: + nonlocal factory_call_count + factory_call_count += 1 + return MockLightRAG(workspace) + + manager = WorkspaceManager(factory=counting_factory, max_instances=10) + workspace_name = "concurrent-dedup-ws" + + async def get_workspace() -> MockLightRAG: + return await manager.get_or_create(workspace_name) + + # Launch 10 concurrent requests + results = await asyncio.gather(*[get_workspace() for _ in range(10)]) + + # Factory should be called exactly once + assert factory_call_count == 1 + + # All results should be the same instance + assert all(r is results[0] for r in results) + + # Stats should show 1 cache miss and 9 cache hits + stats = manager.get_stats() + assert stats["cache_misses"] == 1 + assert stats["cache_hits"] == 9 + assert stats["active_instances"] == 1 + + # ------------------------------------------------------------------------- + # Edge Cases + # ------------------------------------------------------------------------- + + @pytest.mark.asyncio + async def test_multiple_releases_no_underflow(self, fresh_manager) -> None: + """Test that multiple releases don't cause ref_count underflow.""" + await fresh_manager.get_or_create("ws-a") + + fresh_manager.release("ws-a") + fresh_manager.release("ws-a") + fresh_manager.release("ws-a") + + # Should not crash, ref_count should be 0 (clamped) + assert fresh_manager.get_stats()["ref_counts"]["ws-a"] == 0 + + @pytest.mark.asyncio + async def test_different_workspaces_independent(self, fresh_manager) -> None: + """Test that different workspaces are managed independently.""" + ws1 = await fresh_manager.get_or_create("ws-1") + ws2 = await fresh_manager.get_or_create("ws-2") + + assert ws1 is not ws2 + assert ws1.workspace == "ws-1" + assert ws2.workspace == "ws-2" + + fresh_manager.release("ws-1") + fresh_manager.release("ws-1") # Release twice + + # ws-2 should still have ref_count > 0 + assert fresh_manager.get_stats()["ref_counts"]["ws-2"] == 1 + + # ------------------------------------------------------------------------- + # Factory Failure Tests + # ------------------------------------------------------------------------- + + @pytest.mark.asyncio + async def test_factory_failure_cleans_up_state(self, fresh_manager) -> None: + """Test that factory failure leaves no partial state in cache or ref_counts. + + Scenario: + - Create a factory that raises RuntimeError on first call + - Call get_or_create and verify exception propagates + - Assert no partial state: workspace should NOT be in _cache or _ref_counts + - Retry with succeeding factory - should succeed + """ + factory_call_count = 0 + + def failing_then_succeeding_factory(workspace: str) -> MockLightRAG: + nonlocal factory_call_count + factory_call_count += 1 + if factory_call_count == 1: + raise RuntimeError("factory boom") + return MockLightRAG(workspace) + + manager = WorkspaceManager( + factory=failing_then_succeeding_factory, max_instances=10 + ) + + # First call should raise RuntimeError + with pytest.raises(RuntimeError, match="factory boom"): + await manager.get_or_create("ws-fail") + + # Assert no partial state left + assert "ws-fail" not in manager._cache + assert "ws-fail" not in manager._ref_counts + assert len(manager._cache) == 0 + assert len(manager._ref_counts) == 0 + + # Second call should succeed + instance = await manager.get_or_create("ws-fail") + assert instance is not None + assert instance.workspace == "ws-fail" + assert "ws-fail" in manager._cache + assert "ws-fail" in manager._ref_counts + + # ------------------------------------------------------------------------- + # Shutdown Tests + # ------------------------------------------------------------------------- + + @pytest.mark.asyncio + async def test_get_or_create_after_shutdown(self, fresh_manager) -> None: + """Test that workspace can be created fresh after shutdown. + + Scenario: + - Create a workspace + - Call shutdown + - Assert cache is cleared + - Create same workspace again via get_or_create + - Assert it gets a FRESH instance (not the old one) + - Verify old instance was finalized + """ + # Create workspace + ws1 = await fresh_manager.get_or_create("ws-reuse") + original_id = id(ws1) + + # Shutdown should finalize and clear cache + await fresh_manager.shutdown() + assert ws1.finalize_called is True + assert fresh_manager.get_stats()["active_instances"] == 0 + assert fresh_manager.get_stats()["workspaces"] == [] + + # Create same workspace again - should get FRESH instance + ws2 = await fresh_manager.get_or_create("ws-reuse") + assert id(ws2) != original_id # Different instance + assert ws2.finalize_called is False # Fresh instance not yet finalized + + # Verify new instance is in cache + stats = fresh_manager.get_stats() + assert stats["active_instances"] == 1 + assert "ws-reuse" in stats["workspaces"] + + @pytest.mark.asyncio + async def test_double_shutdown_is_safe(self, fresh_manager) -> None: + """Test that calling shutdown twice does not raise any exception. + + Scenario: + - Create some workspaces + - Call shutdown once - verify it works + - Call shutdown again - verify it does NOT raise + - Assert state is still clean after second shutdown + """ + # Create workspaces + await fresh_manager.get_or_create("ws-1") + await fresh_manager.get_or_create("ws-2") + + # First shutdown should work + await fresh_manager.shutdown() + + stats = fresh_manager.get_stats() + assert stats["active_instances"] == 0 + assert stats["workspaces"] == [] + + # Second shutdown should NOT raise any exception + await fresh_manager.shutdown() + + # State should still be clean + stats = fresh_manager.get_stats() + assert stats["active_instances"] == 0 + assert stats["workspaces"] == [] + + # ------------------------------------------------------------------------- + # WorkspaceNameError Propagation Tests + # ------------------------------------------------------------------------- + + @pytest.mark.asyncio + async def test_get_or_create_rejects_invalid_workspace_name( + self, fresh_manager + ) -> None: + """Test that get_or_create raises WorkspaceNameError for invalid names. + + Since get_or_create calls sanitize_workspace_name internally, + invalid names should propagate the error. + """ + with pytest.raises(WorkspaceNameError, match="path traversal detected"): + await fresh_manager.get_or_create("../etc") + + with pytest.raises(WorkspaceNameError, match="path traversal detected"): + await fresh_manager.get_or_create("ws/name") + + with pytest.raises(WorkspaceNameError, match="max 64 characters"): + await fresh_manager.get_or_create("a" * 65) diff --git a/tests/unit/test_workspace_registry.py b/tests/unit/test_workspace_registry.py new file mode 100644 index 0000000000..9e58bc8a5a --- /dev/null +++ b/tests/unit/test_workspace_registry.py @@ -0,0 +1,480 @@ +"""Unit tests for WorkspaceRegistry and get_workspace_registry.""" + +from __future__ import annotations + +import json +import threading +from unittest.mock import patch + +import pytest + +from lightrag.api.workspace_registry import ( + WorkspaceRegistry, + get_workspace_registry, + WORKSPACE_REGISTRY_FILENAME, +) + +pytestmark = pytest.mark.offline + + +# ============================================================================= +# TestWorkspaceRegistry Tests +# ============================================================================= + + +class TestWorkspaceRegistry: + """Tests for the WorkspaceRegistry class.""" + + # ------------------------------------------------------------------------- + # Initialization Tests + # ------------------------------------------------------------------------- + + def test_init_creates_registry_in_working_dir(self, tmp_path) -> None: + """Test that registry file path is set correctly from working_dir.""" + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + assert registry.registry_path == tmp_path / WORKSPACE_REGISTRY_FILENAME + + def test_init_creates_registry_in_cwd_when_no_working_dir( + self, monkeypatch, tmp_path + ) -> None: + """Test that registry defaults to cwd when no working_dir is given.""" + # Change cwd to tmp_path so we know where the file ends up + monkeypatch.chdir(tmp_path) + registry = WorkspaceRegistry() + assert registry.registry_path.parent == tmp_path + + def test_init_loads_existing_file(self, tmp_path) -> None: + """Test that existing registry file is loaded on initialization.""" + registry_path = tmp_path / WORKSPACE_REGISTRY_FILENAME + existing_data = { + "existing-ws": { + "name": "existing-ws", + "first_seen": "2024-01-01T00:00:00+00:00", + "last_seen": "2024-01-01T00:00:00+00:00", + } + } + registry_path.write_text(json.dumps(existing_data), encoding="utf-8") + + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + + workspaces = registry.get_workspaces() + assert len(workspaces) == 1 + assert workspaces[0]["name"] == "existing-ws" + + def test_init_starts_empty_when_file_missing(self, tmp_path) -> None: + """Test that registry starts empty when file does not exist.""" + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + assert registry.get_workspaces() == [] + + def test_init_handles_corrupted_json(self, tmp_path) -> None: + """Test that corrupted JSON file is handled gracefully (starts fresh).""" + registry_path = tmp_path / WORKSPACE_REGISTRY_FILENAME + registry_path.write_text("{ not valid json", encoding="utf-8") + + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + + # Should start fresh, not crash + assert registry.get_workspaces() == [] + + def test_init_handles_empty_json_object(self, tmp_path) -> None: + """Test that empty JSON object {} is handled gracefully.""" + registry_path = tmp_path / WORKSPACE_REGISTRY_FILENAME + registry_path.write_text("{}", encoding="utf-8") + + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + + assert registry.get_workspaces() == [] + + # ------------------------------------------------------------------------- + # register_workspace Tests + # ------------------------------------------------------------------------- + + def test_register_workspace_adds_to_list(self, tmp_path) -> None: + """Test that registering a workspace adds it to get_workspaces.""" + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + registry.register_workspace("my-workspace") + + workspaces = registry.get_workspaces() + assert len(workspaces) == 1 + assert workspaces[0]["name"] == "my-workspace" + assert "first_seen" in workspaces[0] + assert "last_seen" in workspaces[0] + assert workspaces[0]["first_seen"] == workspaces[0]["last_seen"] + + def test_register_workspace_sets_both_timestamps(self, tmp_path) -> None: + """Test that first registration sets both first_seen and last_seen.""" + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + registry.register_workspace("ws-timestamps") + + workspaces = registry.get_workspaces() + assert workspaces[0]["first_seen"] is not None + assert workspaces[0]["last_seen"] is not None + # Both should be ISO format strings + assert "T" in workspaces[0]["first_seen"] + assert "T" in workspaces[0]["last_seen"] + + def test_register_workspace_updates_last_seen(self, tmp_path) -> None: + """Test that re-registering updates last_seen but not first_seen.""" + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + registry.register_workspace("ws-update") + first_last_seen = registry.get_workspaces()[0]["last_seen"] + first_first_seen = registry.get_workspaces()[0]["first_seen"] + + # Wait a tiny bit to ensure timestamp differs + import time + + time.sleep(0.01) + registry.register_workspace("ws-update") + + workspaces = registry.get_workspaces() + assert workspaces[0]["first_seen"] == first_first_seen + assert workspaces[0]["last_seen"] != first_last_seen + assert workspaces[0]["last_seen"] > first_last_seen + + def test_register_empty_string_is_noop(self, tmp_path) -> None: + """Test that empty string workspace is not registered (no-op).""" + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + registry.register_workspace("") + + assert registry.get_workspaces() == [] + + def test_register_none_is_noop(self, tmp_path) -> None: + """Test that None workspace is not registered (no-op).""" + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + registry.register_workspace(None) # type: ignore[arg-type] + + assert registry.get_workspaces() == [] + + def test_register_multiple_workspaces(self, tmp_path) -> None: + """Test that multiple distinct workspaces are all registered.""" + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + registry.register_workspace("ws-a") + registry.register_workspace("ws-b") + registry.register_workspace("ws-c") + + workspaces = registry.get_workspaces() + names = {w["name"] for w in workspaces} + assert names == {"ws-a", "ws-b", "ws-c"} + + def test_register_workspace_with_special_chars(self, tmp_path) -> None: + """Test that workspaces with special characters are registered correctly.""" + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + # The source code does not validate/normalize names, so any non-empty + # string is accepted as-is + registry.register_workspace("ws_underscore") + registry.register_workspace("ws-hyphen") + registry.register_workspace("ws.dot") + registry.register_workspace("ws@at") + registry.register_workspace("ws with space") + + workspaces = registry.get_workspaces() + assert len(workspaces) == 5 + names = {w["name"] for w in workspaces} + assert "ws_underscore" in names + assert "ws-hyphen" in names + assert "ws.dot" in names + assert "ws@at" in names + assert "ws with space" in names + + # ------------------------------------------------------------------------- + # Persistence Tests + # ------------------------------------------------------------------------- + + def test_persistence_survives_reload(self, tmp_path) -> None: + """Test that registry data survives save/reload cycle.""" + # Create and register + registry1 = WorkspaceRegistry(working_dir=str(tmp_path)) + registry1.register_workspace("persistent-ws") + + # Re-create from same directory + registry2 = WorkspaceRegistry(working_dir=str(tmp_path)) + workspaces = registry2.get_workspaces() + + assert len(workspaces) == 1 + assert workspaces[0]["name"] == "persistent-ws" + + def test_persistence_multiple_workspaces_reload(self, tmp_path) -> None: + """Test that multiple workspaces persist across reload.""" + registry1 = WorkspaceRegistry(working_dir=str(tmp_path)) + registry1.register_workspace("ws-1") + registry1.register_workspace("ws-2") + registry1.register_workspace("ws-3") + + registry2 = WorkspaceRegistry(working_dir=str(tmp_path)) + names = {w["name"] for w in registry2.get_workspaces()} + assert names == {"ws-1", "ws-2", "ws-3"} + + def test_file_contains_valid_json(self, tmp_path) -> None: + """Test that the saved file is valid JSON.""" + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + registry.register_workspace("json-check") + + registry_path = tmp_path / WORKSPACE_REGISTRY_FILENAME + assert registry_path.exists() + + data = json.loads(registry_path.read_text(encoding="utf-8")) + assert "json-check" in data + assert data["json-check"]["name"] == "json-check" + + def test_file_parent_directory_created(self, tmp_path) -> None: + """Test that parent directory is created if missing.""" + nested = tmp_path / "nested" / "dir" + registry = WorkspaceRegistry(working_dir=str(nested)) + registry.register_workspace("deep-ws") + + assert (nested / WORKSPACE_REGISTRY_FILENAME).exists() + + # ------------------------------------------------------------------------- + # Atomic Write Tests + # ------------------------------------------------------------------------- + + def test_atomic_write_no_partial_file(self, tmp_path) -> None: + """Test that only the final file is visible (no partial writes).""" + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + registry.register_workspace("ws-1") + + registry_path = tmp_path / WORKSPACE_REGISTRY_FILENAME + tmp_file = tmp_path / (WORKSPACE_REGISTRY_FILENAME + ".tmp") + + # .tmp file should not remain after save + assert not tmp_file.exists() + # Main file should exist and be valid + assert registry_path.exists() + data = json.loads(registry_path.read_text(encoding="utf-8")) + assert "ws-1" in data + + # ------------------------------------------------------------------------- + # Rollback Tests + # ------------------------------------------------------------------------- + + def test_rollback_on_save_failure(self, tmp_path) -> None: + """Test that IOError propagates when save fails. + + Note: the rollback restores to the in-memory state captured before + _save_registry is called β€” which is AFTER the workspace is already in + _cache. The key invariant is that the IOError is raised and the + in-memory state matches what was present when save was attempted. + """ + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + registry.register_workspace("ws-before") + + # Verify state before injection + assert len(registry.get_workspaces()) == 1 + + # Capture the on-disk state before the failing write + registry_path = tmp_path / WORKSPACE_REGISTRY_FILENAME + original_file_content = registry_path.read_text(encoding="utf-8") + + def raise_on_write(*args, **kwargs): + if len(args) > 0 and ".tmp" in str(args[0]): + raise IOError("simulated write failure") + return __builtins__["open"](*args, **kwargs) + + with patch("builtins.open", side_effect=raise_on_write): + with pytest.raises(IOError, match="simulated write failure"): + registry.register_workspace("ws-after-fail") + + # The IOError was raised (verified by pytest.raises). + # In-memory state is whatever it was when _save_registry was called + # (including the new workspace β€” rollback restores that same snapshot). + # The critical check: disk file must be UNCHANGED from before the call. + assert registry_path.read_text(encoding="utf-8") == original_file_content + + # ------------------------------------------------------------------------- + # Thread Safety Tests + # ------------------------------------------------------------------------- + + def test_concurrent_register_same_workspace(self, tmp_path) -> None: + """Test that concurrent registration of same workspace is thread-safe.""" + + def register_many(registry: WorkspaceRegistry, suffix: str) -> None: + for i in range(50): + registry.register_workspace(f"concurrent-ws-{suffix}") + + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + threads = [ + threading.Thread(target=register_many, args=(registry, f"t{i}")) + for i in range(4) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + # All should complete without errors + workspaces = registry.get_workspaces() + names = {w["name"] for w in workspaces} + # Each thread registered its own distinct workspace + assert len(names) == 4 + + def test_concurrent_register_different_workspaces(self, tmp_path) -> None: + """Test that concurrent registration of different workspaces is thread-safe.""" + num_threads = 8 + workspaces_per_thread = 25 + + def register_range(registry: WorkspaceRegistry, start: int) -> None: + for i in range(start, start + workspaces_per_thread): + registry.register_workspace(f"ws-{i:03d}") + + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + threads = [ + threading.Thread( + target=register_range, args=(registry, i * workspaces_per_thread) + ) + for i in range(num_threads) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + workspaces = registry.get_workspaces() + assert len(workspaces) == num_threads * workspaces_per_thread + + # ------------------------------------------------------------------------- + # get_workspaces Ordering Tests + # ------------------------------------------------------------------------- + + def test_get_workspaces_sorted_by_last_seen_descending(self, tmp_path) -> None: + """Test that workspaces are sorted by last_seen descending (most recent first).""" + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + registry.register_workspace("ws-old") + import time + + time.sleep(0.02) + registry.register_workspace("ws-new") + time.sleep(0.02) + registry.register_workspace("ws-newest") + + # Accessing old workspace again should move it to most recent + registry.register_workspace("ws-old") + + workspaces = registry.get_workspaces() + names = [w["name"] for w in workspaces] + + # ws-old should now be first (most recently seen) + assert names[0] == "ws-old" + assert names[1] == "ws-newest" + assert names[2] == "ws-new" + + def test_get_workspaces_returns_copy(self, tmp_path) -> None: + """Test that get_workspaces returns a new list each time.""" + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + registry.register_workspace("ws-1") + + result1 = registry.get_workspaces() + result2 = registry.get_workspaces() + + assert result1 is not result2 + assert result1 == result2 + + def test_get_workspaces_returns_clean_dicts(self, tmp_path) -> None: + """Test that returned workspace dicts only contain name/first_seen/last_seen.""" + registry = WorkspaceRegistry(working_dir=str(tmp_path)) + registry.register_workspace("ws-1") + + workspaces = registry.get_workspaces() + assert list(workspaces[0].keys()) == ["name", "first_seen", "last_seen"] + + +# ============================================================================= +# get_workspace_registry Singleton Tests +# ============================================================================= + + +class TestGetWorkspaceRegistrySingleton: + """Tests for the get_workspace_registry singleton accessor.""" + + def test_singleton_returns_same_instance(self, tmp_path) -> None: + """Test that get_workspace_registry returns the same instance for same working_dir.""" + # Reset global state before test + import lightrag.api.workspace_registry as wr + + wr._registry_instance = None + + reg1 = get_workspace_registry(working_dir=str(tmp_path)) + reg2 = get_workspace_registry() + + assert reg1 is reg2 + + def test_singleton_different_working_dirs_same_instance(self, tmp_path) -> None: + """Test that calling without working_dir returns the same singleton as first call.""" + import lightrag.api.workspace_registry as wr + + wr._registry_instance = None + + reg1 = get_workspace_registry(working_dir=str(tmp_path)) + reg2 = get_workspace_registry() # no working_dir, should use same + + assert reg1 is reg2 + + def test_singleton_raises_on_mismatched_working_dir(self, tmp_path) -> None: + """Test that ValueError is raised when second call uses different working_dir.""" + import lightrag.api.workspace_registry as wr + + wr._registry_instance = None + + get_workspace_registry(working_dir=str(tmp_path)) + other_dir = tmp_path / "other" + with pytest.raises(ValueError, match="already initialized"): + get_workspace_registry(working_dir=str(other_dir)) + + def test_singleton_raises_on_second_init_with_different_dir(self, tmp_path) -> None: + """Test that second call with different working_dir raises ValueError.""" + import lightrag.api.workspace_registry as wr + + wr._registry_instance = None + + get_workspace_registry(working_dir=str(tmp_path)) + + other_path = tmp_path / "different" + with pytest.raises(ValueError, match="already initialized"): + get_workspace_registry(working_dir=str(other_path)) + + def test_singleton_works_without_working_dir(self, monkeypatch, tmp_path) -> None: + """Test that singleton works when called without working_dir (uses cwd).""" + import lightrag.api.workspace_registry as wr + + wr._registry_instance = None + monkeypatch.chdir(tmp_path) + + reg = get_workspace_registry() + + assert reg is not None + assert isinstance(reg, WorkspaceRegistry) + + def test_singleton_does_not_raise_when_called_again_with_same_dir( + self, tmp_path + ) -> None: + """Test that calling again with the same working_dir does not raise.""" + import lightrag.api.workspace_registry as wr + + wr._registry_instance = None + + reg1 = get_workspace_registry(working_dir=str(tmp_path)) + reg2 = get_workspace_registry(working_dir=str(tmp_path)) + + assert reg1 is reg2 + + def test_singleton_none_working_dir_after_init_with_dir(self, tmp_path) -> None: + """Test that calling with None after init with dir returns same instance.""" + import lightrag.api.workspace_registry as wr + + wr._registry_instance = None + + reg1 = get_workspace_registry(working_dir=str(tmp_path)) + reg2 = get_workspace_registry(working_dir=None) # type: ignore[arg-type] + + assert reg1 is reg2 + + def test_singleton_raises_on_working_dir_mismatch_message(self, tmp_path) -> None: + """Test that ValueError message contains both working directories.""" + import lightrag.api.workspace_registry as wr + + wr._registry_instance = None + + get_workspace_registry(working_dir=str(tmp_path)) + + other_dir = tmp_path / "other_dir" + with pytest.raises(ValueError, match=str(tmp_path)): + get_workspace_registry(working_dir=str(other_dir))