diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 26ae0032b8..cf3a966890 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -9,6 +9,7 @@ get_swagger_ui_html, get_swagger_ui_oauth2_redirect_html, ) +import asyncio import os import re import logging @@ -19,6 +20,8 @@ from fastapi.staticfiles import StaticFiles from fastapi.responses import RedirectResponse from pathlib import Path +import configparser +from fnmatch import fnmatch from ascii_colors import ASCIIColors from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager @@ -73,6 +76,10 @@ webui_title = os.getenv("WEBUI_TITLE") webui_description = os.getenv("WEBUI_DESCRIPTION") +# Initialize config parser +config = configparser.ConfigParser() +config.read("config.ini") + # Global authentication configuration auth_configured = bool(auth_handler.accounts) @@ -339,8 +346,90 @@ def create_app(args): # Check if API key is provided either through env var or args api_key = os.getenv("LIGHTRAG_API_KEY") or args.key - # Initialize document manager with workspace support for data isolation - doc_manager = DocumentManager(args.input_dir, workspace=args.workspace) + def normalize_workspace(workspace: str | None) -> str: + if workspace is None: + return args.workspace or "" + return workspace + + def detect_workspace_dirs(working_dir: Path) -> list[str]: + workspace_names: set[str] = set() + if not working_dir.exists(): + return [] + + storage_patterns = ( + "kv_store_*.json", + "graph_*.graphml", + "vdb_*.json", + "faiss_index_*.index", + ) + + for child in working_dir.iterdir(): + if not child.is_dir(): + continue + try: + if any( + nested.is_file() + and any(fnmatch(nested.name, pattern) for pattern in storage_patterns) + for nested in child.iterdir() + ): + workspace_names.add(child.name) + except OSError: + continue + + return sorted(workspace_names) + + def build_workspace_inventory(current_workspace: str | None = 
None) -> dict: + default_workspace = get_default_workspace() or "" + effective_workspace = normalize_workspace(current_workspace) + + sources_by_workspace: dict[str, set[str]] = {} + + def add_workspace(name: str | None, source: str): + workspace_name = normalize_workspace(name) + sources_by_workspace.setdefault(workspace_name, set()).add(source) + + add_workspace(default_workspace, "default-config") + add_workspace(effective_workspace, "current-request") + + for workspace_name in rag_registry.keys(): + add_workspace(workspace_name, "rag-registry") + + for workspace_name in doc_manager_registry.keys(): + add_workspace(workspace_name, "document-manager") + + for workspace_name in detect_workspace_dirs(Path(args.working_dir)): + add_workspace(workspace_name, "working-dir-scan") + + workspace_entries = [] + for workspace_name in sorted(sources_by_workspace.keys()): + workspace_entries.append( + { + "id": workspace_name, + "label": workspace_name if workspace_name else "default", + "is_default": workspace_name == default_workspace, + "is_current": workspace_name == effective_workspace, + "sources": sorted(sources_by_workspace[workspace_name]), + } + ) + + return { + "current_workspace": effective_workspace, + "default_workspace": default_workspace, + "workspaces": workspace_entries, + } + + # Initialize document manager registry with workspace support for data isolation + doc_manager_registry: dict[str, DocumentManager] = {} + + def get_doc_manager_for_workspace(workspace: str | None) -> DocumentManager: + workspace_key = normalize_workspace(workspace) + manager = doc_manager_registry.get(workspace_key) + if manager is None: + manager = DocumentManager(args.input_dir, workspace=workspace_key) + doc_manager_registry[workspace_key] = manager + return manager + + doc_manager = get_doc_manager_for_workspace(args.workspace) @asynccontextmanager async def lifespan(app: FastAPI): @@ -361,8 +450,14 @@ async def lifespan(app: FastAPI): yield finally: - # Clean up database 
def build_rag(workspace: str | None = None) -> "LightRAG":
    """Construct a LightRAG instance whose storages are namespaced by *workspace*.

    Every model/storage setting comes from the parsed CLI/env ``args``; only
    the workspace key differs between instances. Any constructor failure is
    logged with the offending workspace and re-raised unchanged.
    """
    workspace_key = normalize_workspace(workspace)
    try:
        return LightRAG(
            working_dir=args.working_dir,
            workspace=workspace_key,
            llm_model_func=create_llm_model_func(args.llm_binding),
            llm_model_name=args.llm_model,
            llm_model_max_async=args.max_async,
            summary_max_tokens=args.summary_max_tokens,
            summary_context_size=args.summary_context_size,
            chunk_token_size=int(args.chunk_size),
            chunk_overlap_token_size=int(args.chunk_overlap_size),
            llm_model_kwargs=create_llm_model_kwargs(
                args.llm_binding, args, llm_timeout
            ),
            embedding_func=embedding_func,
            default_llm_timeout=llm_timeout,
            default_embedding_timeout=embedding_timeout,
            kv_storage=args.kv_storage,
            graph_storage=args.graph_storage,
            vector_storage=args.vector_storage,
            doc_status_storage=args.doc_status_storage,
            vector_db_storage_cls_kwargs={
                "cosine_better_than_threshold": args.cosine_threshold
            },
            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
            enable_llm_cache=args.enable_llm_cache,
            rerank_model_func=rerank_model_func,
            max_parallel_insert=args.max_parallel_insert,
            max_graph_nodes=args.max_graph_nodes,
            addon_params={
                "language": args.summary_language,
                "entity_types": args.entity_types,
            },
            ollama_server_infos=ollama_server_infos,
        )
    except Exception as e:
        logger.error(
            f"Failed to initialize LightRAG for workspace '{workspace_key}': {e}"
        )
        raise


# Shared registries: one LightRAG per workspace, plus per-workspace creation
# locks so concurrent first requests never double-build an instance.
rag_registry: dict[str, "LightRAG"] = {}
rag_registry_locks: dict[str, asyncio.Lock] = {}
default_workspace = normalize_workspace(args.workspace)
rag = build_rag(default_workspace)
rag_registry[default_workspace] = rag


async def get_rag_for_workspace(workspace: str | None = None) -> "LightRAG":
    """Return the (lazily created) LightRAG instance for *workspace*.

    Fast path: a registry hit without taking any lock. Slow path: acquire the
    workspace's creation lock, re-check the registry, then build, initialize
    storages and run data migration before publishing the new instance
    (check -> lock -> re-check, so at most one build per workspace).
    """
    key = normalize_workspace(workspace)
    cached = rag_registry.get(key)
    if cached is not None:
        return cached

    lock = rag_registry_locks.setdefault(key, asyncio.Lock())
    async with lock:
        cached = rag_registry.get(key)
        if cached is not None:
            return cached

        instance = build_rag(key)
        await instance.initialize_storages()
        await instance.check_and_migrate_data()
        rag_registry[key] = instance
        logger.info(
            f"Initialized workspace-specific LightRAG instance for workspace '{key}'"
        )
        return instance
"enable_llm_cache": args.enable_llm_cache, - "workspace": default_workspace, + "workspace": workspace, + "default_workspace": default_workspace, "max_graph_nodes": args.max_graph_nodes, # Rerank configuration "enable_rerank": rerank_model_func is not None, @@ -1308,6 +1457,17 @@ async def get_status(request: Request): logger.error(f"Error getting health status: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) + @app.get( + "/workspaces", + dependencies=[Depends(combined_auth)], + tags=["workspace"], + summary="List known workspaces", + description="Return the currently known workspaces from active registries and local storage directories.", + ) + async def list_workspaces(request: Request): + workspace = get_workspace_from_request(request) + return build_workspace_inventory(workspace) + # Custom StaticFiles class for smart caching class SmartStaticFiles(StaticFiles): # Renamed from NoCacheStaticFiles async def get_response(self, path: str, scope): diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 9e6fab9dbb..97cc60fc29 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -3,10 +3,8 @@ """ import asyncio -import time -from uuid import uuid4 from functools import lru_cache -from lightrag.utils import logger, get_pinyin_sort_key, performance_timing_log +from lightrag.utils import logger, get_pinyin_sort_key import aiofiles import traceback from datetime import datetime, timezone @@ -19,6 +17,7 @@ Depends, File, HTTPException, + Request, UploadFile, ) from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -1003,13 +1002,12 @@ def _extract_pdf_pypdf(file_bytes: bytes, password: str = None) -> str: # Check if PDF is encrypted if reader.is_encrypted: - # Try empty password first (covers permission-only encrypted PDFs) - decrypt_result = reader.decrypt(password or "") + if not password: + raise Exception("PDF is encrypted but no password 
provided") + + decrypt_result = reader.decrypt(password) if decrypt_result == 0: - if password: - raise Exception("Incorrect PDF password") - else: - raise Exception("PDF is encrypted but no password provided") + raise Exception("Incorrect PDF password") # Extract text from all pages content = "" @@ -1252,8 +1250,7 @@ async def pipeline_enqueue_file( # Get file size for error reporting try: - stat = await asyncio.to_thread(file_path.stat) - file_size = stat.st_size + file_size = file_path.stat().st_size except Exception: file_size = 0 @@ -1344,8 +1341,8 @@ async def pipeline_enqueue_file( | ".less" ): try: - # Try to decode as UTF-8 (offloaded to thread to avoid blocking the event loop) - content = await asyncio.to_thread(file.decode, "utf-8") + # Try to decode as UTF-8 + content = file.decode("utf-8") # Validate content if not content or len(content.strip()) == 0: @@ -1612,7 +1609,7 @@ async def pipeline_enqueue_file( # Move file to __enqueued__ directory after enqueuing try: enqueued_dir = file_path.parent / "__enqueued__" - await asyncio.to_thread(enqueued_dir.mkdir, exist_ok=True) + enqueued_dir.mkdir(exist_ok=True) # Generate unique filename to avoid conflicts unique_filename = get_unique_filename_in_enqueued( @@ -1621,7 +1618,7 @@ async def pipeline_enqueue_file( target_path = enqueued_dir / unique_filename # Move the file - await asyncio.to_thread(file_path.rename, target_path) + file_path.rename(target_path) logger.debug( f"Moved file to enqueued directory: {file_path.name} -> {unique_filename}" ) @@ -2085,15 +2082,40 @@ async def background_delete_documents( def create_document_routes( - rag: LightRAG, doc_manager: DocumentManager, api_key: Optional[str] = None + rag: LightRAG, + doc_manager: DocumentManager, + api_key: Optional[str] = None, + get_rag_for_workspace: Optional[Any] = None, + get_doc_manager_for_workspace: Optional[Any] = None, + get_workspace_from_request: Optional[Any] = None, ): # Create combined auth dependency for document routes 
async def resolve_rag(raw_request: "Request") -> "LightRAG":
    """Return the LightRAG bound to this request's workspace.

    Falls back to the route factory's default ``rag`` when the workspace
    plumbing (resolver callables) was not supplied by the caller.
    """
    if get_rag_for_workspace is None or get_workspace_from_request is None:
        return rag
    return await get_rag_for_workspace(get_workspace_from_request(raw_request))

def resolve_doc_manager(raw_request: "Request") -> "DocumentManager":
    """Return the DocumentManager for this request's workspace, or the
    factory's default manager when no workspace plumbing is configured."""
    if get_doc_manager_for_workspace is None or get_workspace_from_request is None:
        return doc_manager
    return get_doc_manager_for_workspace(get_workspace_from_request(raw_request))

def resolve_workspace(raw_request: "Request") -> str:
    """Return the effective workspace name for this request, defaulting to
    the primary rag's workspace when the request carries none."""
    if get_workspace_from_request is None:
        return rag.workspace
    requested = get_workspace_from_request(raw_request)
    return rag.workspace if requested is None else requested
@@ -2106,9 +2128,13 @@ async def scan_for_new_documents(background_tasks: BackgroundTasks): """ # Generate track_id with "scan" prefix for scanning operation track_id = generate_track_id("scan") + active_rag = await resolve_rag(raw_request) + active_doc_manager = resolve_doc_manager(raw_request) # Start the scanning process in the background with track_id - background_tasks.add_task(run_scanning_process, rag, doc_manager, track_id) + background_tasks.add_task( + run_scanning_process, active_rag, active_doc_manager, track_id + ) return ScanResponse( status="scanning_started", message="Scanning process has been initiated in the background", @@ -2119,7 +2145,9 @@ async def scan_for_new_documents(background_tasks: BackgroundTasks): "/upload", response_model=InsertResponse, dependencies=[Depends(combined_auth)] ) async def upload_to_input_dir( - background_tasks: BackgroundTasks, file: UploadFile = File(...) + raw_request: Request, + background_tasks: BackgroundTasks, + file: UploadFile = File(...), ): """ Upload a file to the input directory and index it. @@ -2173,13 +2201,17 @@ async def upload_to_input_dir( HTTPException: If the file type is not supported (400), file too large (413), or other errors occur (500). """ try: + active_rag = await resolve_rag(raw_request) + active_doc_manager = resolve_doc_manager(raw_request) # Sanitize filename to prevent Path Traversal attacks - safe_filename = sanitize_filename(file.filename, doc_manager.input_dir) + safe_filename = sanitize_filename( + file.filename, active_doc_manager.input_dir + ) - if not doc_manager.is_supported_file(safe_filename): + if not active_doc_manager.is_supported_file(safe_filename): raise HTTPException( status_code=400, - detail=f"Unsupported file type. Supported types: {doc_manager.supported_extensions}", + detail=f"Unsupported file type. 
Supported types: {active_doc_manager.supported_extensions}", ) # Check file size limit (if configured) @@ -2204,7 +2236,9 @@ async def upload_to_input_dir( ) # Check if filename already exists in doc_status storage - existing_doc_data = await rag.doc_status.get_doc_by_file_path(safe_filename) + existing_doc_data = await active_rag.doc_status.get_doc_by_file_path( + safe_filename + ) if existing_doc_data: # Get document status and track_id from existing document status = existing_doc_data.get("status", "unknown") @@ -2216,7 +2250,7 @@ async def upload_to_input_dir( track_id=existing_track_id, ) - file_path = doc_manager.input_dir / safe_filename + file_path = active_doc_manager.input_dir / safe_filename # Check if file already exists in file system if file_path.exists(): return InsertResponse( @@ -2267,7 +2301,9 @@ async def upload_to_input_dir( track_id = generate_track_id("upload") # Add to background tasks and get track_id - background_tasks.add_task(pipeline_index_file, rag, file_path, track_id) + background_tasks.add_task( + pipeline_index_file, active_rag, file_path, track_id + ) return InsertResponse( status="success", @@ -2287,7 +2323,9 @@ async def upload_to_input_dir( "/text", response_model=InsertResponse, dependencies=[Depends(combined_auth)] ) async def insert_text( - request: InsertTextRequest, background_tasks: BackgroundTasks + raw_request: Request, + request: InsertTextRequest, + background_tasks: BackgroundTasks, ): """ Insert text into the RAG system. @@ -2306,13 +2344,14 @@ async def insert_text( HTTPException: If an error occurs during text processing (500). 
""" try: + active_rag = await resolve_rag(raw_request) # Check if file_source already exists in doc_status storage if ( request.file_source and request.file_source.strip() and request.file_source != "unknown_source" ): - existing_doc_data = await rag.doc_status.get_doc_by_file_path( + existing_doc_data = await active_rag.doc_status.get_doc_by_file_path( request.file_source ) if existing_doc_data: @@ -2329,7 +2368,7 @@ async def insert_text( # Check if content already exists by computing content hash (doc_id) sanitized_text = sanitize_text_for_encoding(request.text) content_doc_id = compute_mdhash_id(sanitized_text, prefix="doc-") - existing_doc = await rag.doc_status.get_by_id(content_doc_id) + existing_doc = await active_rag.doc_status.get_by_id(content_doc_id) if existing_doc: # Content already exists, return duplicated with existing track_id status = existing_doc.get("status", "unknown") @@ -2345,7 +2384,7 @@ async def insert_text( background_tasks.add_task( pipeline_index_texts, - rag, + active_rag, [request.text], file_sources=[request.file_source], track_id=track_id, @@ -2367,7 +2406,9 @@ async def insert_text( dependencies=[Depends(combined_auth)], ) async def insert_texts( - request: InsertTextsRequest, background_tasks: BackgroundTasks + raw_request: Request, + request: InsertTextsRequest, + background_tasks: BackgroundTasks, ): """ Insert multiple texts into the RAG system. @@ -2386,6 +2427,7 @@ async def insert_texts( HTTPException: If an error occurs during text processing (500). 
""" try: + active_rag = await resolve_rag(raw_request) # Check if any file_sources already exist in doc_status storage if request.file_sources: for file_source in request.file_sources: @@ -2394,7 +2436,7 @@ async def insert_texts( and file_source.strip() and file_source != "unknown_source" ): - existing_doc_data = await rag.doc_status.get_doc_by_file_path( + existing_doc_data = await active_rag.doc_status.get_doc_by_file_path( file_source ) if existing_doc_data: @@ -2412,7 +2454,7 @@ async def insert_texts( for text in request.texts: sanitized_text = sanitize_text_for_encoding(text) content_doc_id = compute_mdhash_id(sanitized_text, prefix="doc-") - existing_doc = await rag.doc_status.get_by_id(content_doc_id) + existing_doc = await active_rag.doc_status.get_by_id(content_doc_id) if existing_doc: # Content already exists, return duplicated with existing track_id status = existing_doc.get("status", "unknown") @@ -2428,7 +2470,7 @@ async def insert_texts( background_tasks.add_task( pipeline_index_texts, - rag, + active_rag, request.texts, file_sources=request.file_sources, track_id=track_id, @@ -2447,7 +2489,7 @@ async def insert_texts( @router.delete( "", response_model=ClearDocumentsResponse, dependencies=[Depends(combined_auth)] ) - async def clear_documents(): + async def clear_documents(raw_request: Request): """ Clear all documents from the RAG system. 
@@ -2472,13 +2514,16 @@ async def clear_documents(): get_namespace_data, get_namespace_lock, ) + active_rag = await resolve_rag(raw_request) + active_doc_manager = resolve_doc_manager(raw_request) + workspace = resolve_workspace(raw_request) # Get pipeline status and lock pipeline_status = await get_namespace_data( - "pipeline_status", workspace=rag.workspace + "pipeline_status", workspace=workspace ) pipeline_status_lock = get_namespace_lock( - "pipeline_status", workspace=rag.workspace + "pipeline_status", workspace=workspace ) # Check and set status with lock @@ -2511,17 +2556,17 @@ async def clear_documents(): # Use drop method to clear all data drop_tasks = [] storages = [ - rag.text_chunks, - rag.full_docs, - rag.full_entities, - rag.full_relations, - rag.entity_chunks, - rag.relation_chunks, - rag.entities_vdb, - rag.relationships_vdb, - rag.chunks_vdb, - rag.chunk_entity_relation_graph, - rag.doc_status, + active_rag.text_chunks, + active_rag.full_docs, + active_rag.full_entities, + active_rag.full_relations, + active_rag.entity_chunks, + active_rag.relation_chunks, + active_rag.entities_vdb, + active_rag.relationships_vdb, + active_rag.chunks_vdb, + active_rag.chunk_entity_relation_graph, + active_rag.doc_status, ] # Log storage drop start @@ -2586,7 +2631,7 @@ async def clear_documents(): deleted_files_count = 0 file_errors_count = 0 - for file_path in doc_manager.input_dir.glob("*"): + for file_path in active_doc_manager.input_dir.glob("*"): if file_path.is_file(): try: file_path.unlink() @@ -2643,7 +2688,7 @@ async def clear_documents(): dependencies=[Depends(combined_auth)], response_model=PipelineStatusResponse, ) - async def get_pipeline_status() -> PipelineStatusResponse: + async def get_pipeline_status(raw_request: Request) -> PipelineStatusResponse: """ Get the current status of the document indexing pipeline. 
@@ -2673,16 +2718,17 @@ async def get_pipeline_status() -> PipelineStatusResponse: get_namespace_lock, get_all_update_flags_status, ) + workspace = resolve_workspace(raw_request) pipeline_status = await get_namespace_data( - "pipeline_status", workspace=rag.workspace + "pipeline_status", workspace=workspace ) pipeline_status_lock = get_namespace_lock( - "pipeline_status", workspace=rag.workspace + "pipeline_status", workspace=workspace ) # Get update flags status for all namespaces - update_status = await get_all_update_flags_status(workspace=rag.workspace) + update_status = await get_all_update_flags_status(workspace=workspace) # Convert MutableBoolean objects to regular boolean values processed_update_status = {} @@ -2742,7 +2788,7 @@ async def get_pipeline_status() -> PipelineStatusResponse: @router.get( "", response_model=DocsStatusesResponse, dependencies=[Depends(combined_auth)] ) - async def documents() -> DocsStatusesResponse: + async def documents(raw_request: Request) -> DocsStatusesResponse: """ Get the status of all documents in the system. This endpoint is deprecated; use /documents/paginated instead. To prevent excessive resource consumption, a maximum of 1,000 records is returned. @@ -2761,6 +2807,7 @@ async def documents() -> DocsStatusesResponse: HTTPException: If an error occurs while retrieving document statuses (500). 
""" try: + active_rag = await resolve_rag(raw_request) statuses = ( DocStatus.PENDING, DocStatus.PROCESSING, @@ -2769,7 +2816,7 @@ async def documents() -> DocsStatusesResponse: DocStatus.FAILED, ) - tasks = [rag.get_docs_by_status(status) for status in statuses] + tasks = [active_rag.get_docs_by_status(status) for status in statuses] results: List[Dict[str, DocProcessingStatus]] = await asyncio.gather(*tasks) response = DocsStatusesResponse() @@ -2856,6 +2903,7 @@ class DeleteDocByIdResponse(BaseModel): summary="Delete a document and all its associated data by its ID.", ) async def delete_document( + raw_request: Request, delete_request: DeleteDocRequest, background_tasks: BackgroundTasks, ) -> DeleteDocByIdResponse: @@ -2889,12 +2937,15 @@ async def delete_document( get_namespace_data, get_namespace_lock, ) + active_rag = await resolve_rag(raw_request) + active_doc_manager = resolve_doc_manager(raw_request) + workspace = resolve_workspace(raw_request) pipeline_status = await get_namespace_data( - "pipeline_status", workspace=rag.workspace + "pipeline_status", workspace=workspace ) pipeline_status_lock = get_namespace_lock( - "pipeline_status", workspace=rag.workspace + "pipeline_status", workspace=workspace ) # Check if pipeline is busy with proper lock @@ -2909,8 +2960,8 @@ async def delete_document( # Add deletion task to background tasks background_tasks.add_task( background_delete_documents, - rag, - doc_manager, + active_rag, + active_doc_manager, doc_ids, delete_request.delete_file, delete_request.delete_llm_cache, @@ -2933,7 +2984,7 @@ async def delete_document( response_model=ClearCacheResponse, dependencies=[Depends(combined_auth)], ) - async def clear_cache(request: ClearCacheRequest): + async def clear_cache(raw_request: Request, request: ClearCacheRequest): """ Clear all cache data from the LLM response cache storage. @@ -2950,8 +3001,9 @@ async def clear_cache(request: ClearCacheRequest): HTTPException: If an error occurs during cache clearing (500). 
""" try: + active_rag = await resolve_rag(raw_request) # Call the aclear_cache method (no modes parameter) - await rag.aclear_cache() + await active_rag.aclear_cache() # Prepare success message message = "Successfully cleared all cache" @@ -2967,7 +3019,7 @@ async def clear_cache(request: ClearCacheRequest): response_model=DeletionResult, dependencies=[Depends(combined_auth)], ) - async def delete_entity(request: DeleteEntityRequest): + async def delete_entity(raw_request: Request, request: DeleteEntityRequest): """ Delete an entity and all its relationships from the knowledge graph. @@ -2981,7 +3033,8 @@ async def delete_entity(request: DeleteEntityRequest): HTTPException: If the entity is not found (404) or an error occurs (500). """ try: - result = await rag.adelete_by_entity(entity_name=request.entity_name) + active_rag = await resolve_rag(raw_request) + result = await active_rag.adelete_by_entity(entity_name=request.entity_name) if result.status == "not_found": raise HTTPException(status_code=404, detail=result.message) if result.status == "fail": @@ -3002,7 +3055,7 @@ async def delete_entity(request: DeleteEntityRequest): response_model=DeletionResult, dependencies=[Depends(combined_auth)], ) - async def delete_relation(request: DeleteRelationRequest): + async def delete_relation(raw_request: Request, request: DeleteRelationRequest): """ Delete a relationship between two entities from the knowledge graph. @@ -3016,7 +3069,8 @@ async def delete_relation(request: DeleteRelationRequest): HTTPException: If the relation is not found (404) or an error occurs (500). 
""" try: - result = await rag.adelete_by_relation( + active_rag = await resolve_rag(raw_request) + result = await active_rag.adelete_by_relation( source_entity=request.source_entity, target_entity=request.target_entity, ) @@ -3040,7 +3094,9 @@ async def delete_relation(request: DeleteRelationRequest): response_model=TrackStatusResponse, dependencies=[Depends(combined_auth)], ) - async def get_track_status(track_id: str) -> TrackStatusResponse: + async def get_track_status( + raw_request: Request, track_id: str + ) -> TrackStatusResponse: """ Get the processing status of documents by tracking ID. @@ -3060,6 +3116,7 @@ async def get_track_status(track_id: str) -> TrackStatusResponse: HTTPException: If track_id is invalid (400) or an error occurs (500). """ try: + active_rag = await resolve_rag(raw_request) # Validate track_id if not track_id or not track_id.strip(): raise HTTPException(status_code=400, detail="Track ID cannot be empty") @@ -3067,7 +3124,7 @@ async def get_track_status(track_id: str) -> TrackStatusResponse: track_id = track_id.strip() # Get documents by track_id - docs_by_track_id = await rag.aget_docs_by_track_id(track_id) + docs_by_track_id = await active_rag.aget_docs_by_track_id(track_id) # Convert to response format documents = [] @@ -3115,6 +3172,7 @@ async def get_track_status(track_id: str) -> TrackStatusResponse: dependencies=[Depends(combined_auth)], ) async def get_documents_paginated( + raw_request: Request, request: DocumentsRequest, ) -> PaginatedDocsResponse: """ @@ -3136,92 +3194,24 @@ async def get_documents_paginated( Raises: HTTPException: If an error occurs while retrieving documents (500). 
""" - trace_id = uuid4().hex[:8] - request_start = time.perf_counter() - status_filter_value = ( - request.status_filter.value if request.status_filter is not None else None - ) - - performance_timing_log( - "[documents/paginated][%s] Request start workspace=%s status_filter=%s page=%s page_size=%s sort_field=%s sort_direction=%s", - trace_id, - rag.workspace, - status_filter_value, - request.page, - request.page_size, - request.sort_field, - request.sort_direction, - ) - try: - - async def _timed_call(operation_name: str, operation): - operation_start = time.perf_counter() - performance_timing_log( - "[documents/paginated][%s] %s started", - trace_id, - operation_name, - ) - try: - result = await operation - except Exception: - elapsed = time.perf_counter() - operation_start - performance_timing_log( - "[documents/paginated][%s] %s failed after %.4fs", - trace_id, - operation_name, - elapsed, - ) - raise - - elapsed = time.perf_counter() - operation_start - performance_timing_log( - "[documents/paginated][%s] %s completed in %.4fs", - trace_id, - operation_name, - elapsed, - ) - return result - - query_task_create_start = time.perf_counter() - docs_task = asyncio.create_task( - _timed_call( - "get_docs_paginated", - rag.doc_status.get_docs_paginated( - status_filter=request.status_filter, - page=request.page, - page_size=request.page_size, - sort_field=request.sort_field, - sort_direction=request.sort_direction, - ), - ) - ) - status_counts_task = asyncio.create_task( - _timed_call( - "get_all_status_counts", - rag.doc_status.get_all_status_counts(), - ) - ) - query_task_create_elapsed = time.perf_counter() - query_task_create_start - performance_timing_log( - "[documents/paginated][%s] Query tasks created in %.4fs", - trace_id, - query_task_create_elapsed, + active_rag = await resolve_rag(raw_request) + # Get paginated documents and status counts in parallel + docs_task = active_rag.doc_status.get_docs_paginated( + status_filter=request.status_filter, + 
page=request.page, + page_size=request.page_size, + sort_field=request.sort_field, + sort_direction=request.sort_direction, ) + status_counts_task = active_rag.doc_status.get_all_status_counts() - query_await_start = time.perf_counter() + # Execute both queries in parallel (documents_with_ids, total_count), status_counts = await asyncio.gather( docs_task, status_counts_task ) - query_await_elapsed = time.perf_counter() - query_await_start - performance_timing_log( - "[documents/paginated][%s] Query tasks awaited in %.4fs", - trace_id, - query_await_elapsed, - ) # Convert documents to response format - response_assembly_start = time.perf_counter() doc_responses = [] for doc_id, doc in documents_with_ids: doc_responses.append( @@ -3253,37 +3243,14 @@ async def _timed_call(operation_name: str, operation): has_next=has_next, has_prev=has_prev, ) - response = PaginatedDocsResponse( + + return PaginatedDocsResponse( documents=doc_responses, pagination=pagination, status_counts=status_counts, ) - response_assembly_elapsed = time.perf_counter() - response_assembly_start - total_elapsed = time.perf_counter() - request_start - - performance_timing_log( - "[documents/paginated][%s] Response assembled in %.4fs", - trace_id, - response_assembly_elapsed, - ) - performance_timing_log( - "[documents/paginated][%s] Request completed in %.4fs returned_rows=%s total_count=%s status_count_keys=%s", - trace_id, - total_elapsed, - len(doc_responses), - total_count, - sorted(status_counts.keys()), - ) - - return response except Exception as e: - total_elapsed = time.perf_counter() - request_start - performance_timing_log( - "[documents/paginated][%s] Request failed after %.4fs", - trace_id, - total_elapsed, - ) logger.error(f"Error getting paginated documents: {str(e)}") logger.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @@ -3293,7 +3260,9 @@ async def _timed_call(operation_name: str, operation): response_model=StatusCountsResponse, 
dependencies=[Depends(combined_auth)], ) - async def get_document_status_counts() -> StatusCountsResponse: + async def get_document_status_counts( + raw_request: Request, + ) -> StatusCountsResponse: """ Get counts of documents by status. @@ -3307,7 +3276,8 @@ async def get_document_status_counts() -> StatusCountsResponse: HTTPException: If an error occurs while retrieving status counts (500). """ try: - status_counts = await rag.doc_status.get_all_status_counts() + active_rag = await resolve_rag(raw_request) + status_counts = await active_rag.doc_status.get_all_status_counts() return StatusCountsResponse(status_counts=status_counts) except Exception as e: @@ -3320,7 +3290,9 @@ async def get_document_status_counts() -> StatusCountsResponse: response_model=ReprocessResponse, dependencies=[Depends(combined_auth)], ) - async def reprocess_failed_documents(background_tasks: BackgroundTasks): + async def reprocess_failed_documents( + raw_request: Request, background_tasks: BackgroundTasks + ): """ Reprocess failed and pending documents. @@ -3346,9 +3318,10 @@ async def reprocess_failed_documents(background_tasks: BackgroundTasks): HTTPException: If an error occurs while initiating reprocessing (500). """ try: + active_rag = await resolve_rag(raw_request) # Start the reprocessing in the background # Note: Reprocessed documents retain their original track_id from initial upload - background_tasks.add_task(rag.apipeline_process_enqueue_documents) + background_tasks.add_task(active_rag.apipeline_process_enqueue_documents) logger.info("Reprocessing of failed documents initiated") return ReprocessResponse( @@ -3366,7 +3339,7 @@ async def reprocess_failed_documents(background_tasks: BackgroundTasks): response_model=CancelPipelineResponse, dependencies=[Depends(combined_auth)], ) - async def cancel_pipeline(): + async def cancel_pipeline(raw_request: Request): """ Request cancellation of the currently running pipeline. 
@@ -3392,12 +3365,13 @@ async def cancel_pipeline(): get_namespace_data, get_namespace_lock, ) + workspace = resolve_workspace(raw_request) pipeline_status = await get_namespace_data( - "pipeline_status", workspace=rag.workspace + "pipeline_status", workspace=workspace ) pipeline_status_lock = get_namespace_lock( - "pipeline_status", workspace=rag.workspace + "pipeline_status", workspace=workspace ) async with pipeline_status_lock: diff --git a/lightrag/api/routers/graph_routes.py b/lightrag/api/routers/graph_routes.py index e892ff011c..943ff96c58 100644 --- a/lightrag/api/routers/graph_routes.py +++ b/lightrag/api/routers/graph_routes.py @@ -2,9 +2,9 @@ This module contains all graph-related routes for the LightRAG API. """ -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, Callable, Awaitable import traceback -from fastapi import APIRouter, Depends, Query, HTTPException +from fastapi import APIRouter, Depends, Query, HTTPException, Request from pydantic import BaseModel, Field from lightrag.utils import logger @@ -86,11 +86,22 @@ class RelationCreateRequest(BaseModel): ) -def create_graph_routes(rag, api_key: Optional[str] = None): +def create_graph_routes( + rag, + api_key: Optional[str] = None, + get_rag_for_workspace: Optional[Callable[[str | None], Awaitable[Any]]] = None, + get_workspace_from_request: Optional[Callable[[Request], str | None]] = None, +): combined_auth = get_combined_auth_dependency(api_key) + async def resolve_rag(raw_request: Request): + if get_rag_for_workspace is None or get_workspace_from_request is None: + return rag + workspace = get_workspace_from_request(raw_request) + return await get_rag_for_workspace(workspace) + @router.get("/graph/label/list", dependencies=[Depends(combined_auth)]) - async def get_graph_labels(): + async def get_graph_labels(raw_request: Request): """ Get all graph labels @@ -98,7 +109,8 @@ async def get_graph_labels(): List[str]: List of graph labels """ try: - return await 
rag.get_graph_labels() + active_rag = await resolve_rag(raw_request) + return await active_rag.get_graph_labels() except Exception as e: logger.error(f"Error getting graph labels: {str(e)}") logger.error(traceback.format_exc()) @@ -108,6 +120,7 @@ async def get_graph_labels(): @router.get("/graph/label/popular", dependencies=[Depends(combined_auth)]) async def get_popular_labels( + raw_request: Request, limit: int = Query( 300, description="Maximum number of popular labels to return", ge=1, le=1000 ), @@ -122,7 +135,8 @@ async def get_popular_labels( List[str]: List of popular labels sorted by degree (highest first) """ try: - return await rag.chunk_entity_relation_graph.get_popular_labels(limit) + active_rag = await resolve_rag(raw_request) + return await active_rag.chunk_entity_relation_graph.get_popular_labels(limit) except Exception as e: logger.error(f"Error getting popular labels: {str(e)}") logger.error(traceback.format_exc()) @@ -132,6 +146,7 @@ async def get_popular_labels( @router.get("/graph/label/search", dependencies=[Depends(combined_auth)]) async def search_labels( + raw_request: Request, q: str = Query(..., description="Search query string"), limit: int = Query( 50, description="Maximum number of search results to return", ge=1, le=100 @@ -148,7 +163,8 @@ async def search_labels( List[str]: List of matching labels sorted by relevance """ try: - return await rag.chunk_entity_relation_graph.search_labels(q, limit) + active_rag = await resolve_rag(raw_request) + return await active_rag.chunk_entity_relation_graph.search_labels(q, limit) except Exception as e: logger.error(f"Error searching labels with query '{q}': {str(e)}") logger.error(traceback.format_exc()) @@ -158,6 +174,7 @@ async def search_labels( @router.get("/graphs", dependencies=[Depends(combined_auth)]) async def get_knowledge_graph( + raw_request: Request, label: str = Query(..., description="Label to get knowledge graph for"), max_depth: int = Query(3, description="Maximum depth of 
graph", ge=1), max_nodes: int = Query(1000, description="Maximum nodes to return", ge=1), @@ -182,7 +199,8 @@ async def get_knowledge_graph( f"get_knowledge_graph called with label: '{label}' (length: {len(label)}, repr: {repr(label)})" ) - return await rag.get_knowledge_graph( + active_rag = await resolve_rag(raw_request) + return await active_rag.get_knowledge_graph( node_label=label, max_depth=max_depth, max_nodes=max_nodes, @@ -196,6 +214,7 @@ async def get_knowledge_graph( @router.get("/graph/entity/exists", dependencies=[Depends(combined_auth)]) async def check_entity_exists( + raw_request: Request, name: str = Query(..., description="Entity name to check"), ): """ @@ -208,7 +227,8 @@ async def check_entity_exists( Dict[str, bool]: Dictionary with 'exists' key indicating if entity exists """ try: - exists = await rag.chunk_entity_relation_graph.has_node(name) + active_rag = await resolve_rag(raw_request) + exists = await active_rag.chunk_entity_relation_graph.has_node(name) return {"exists": exists} except Exception as e: logger.error(f"Error checking entity existence for '{name}': {str(e)}") @@ -218,7 +238,7 @@ async def check_entity_exists( ) @router.post("/graph/entity/edit", dependencies=[Depends(combined_auth)]) - async def update_entity(request: EntityUpdateRequest): + async def update_entity(raw_request: Request, request: EntityUpdateRequest): """ Update an entity's properties in the knowledge graph @@ -353,7 +373,8 @@ async def update_entity(request: EntityUpdateRequest): } """ try: - result = await rag.aedit_entity( + active_rag = await resolve_rag(raw_request) + result = await active_rag.aedit_entity( entity_name=request.entity_name, updated_data=request.updated_data, allow_rename=request.allow_rename, @@ -408,7 +429,7 @@ async def update_entity(request: EntityUpdateRequest): ) @router.post("/graph/relation/edit", dependencies=[Depends(combined_auth)]) - async def update_relation(request: RelationUpdateRequest): + async def 
update_relation(raw_request: Request, request: RelationUpdateRequest): """Update a relation's properties in the knowledge graph Args: @@ -418,7 +439,8 @@ async def update_relation(request: RelationUpdateRequest): Dict: Updated relation information """ try: - result = await rag.aedit_relation( + active_rag = await resolve_rag(raw_request) + result = await active_rag.aedit_relation( source_entity=request.source_id, target_entity=request.target_id, updated_data=request.updated_data, @@ -443,7 +465,7 @@ async def update_relation(request: RelationUpdateRequest): ) @router.post("/graph/entity/create", dependencies=[Depends(combined_auth)]) - async def create_entity(request: EntityCreateRequest): + async def create_entity(raw_request: Request, request: EntityCreateRequest): """ Create a new entity in the knowledge graph @@ -488,12 +510,13 @@ async def create_entity(request: EntityCreateRequest): } """ try: + active_rag = await resolve_rag(raw_request) # Use the proper acreate_entity method which handles: # - Graph lock for concurrency # - Vector embedding creation in entities_vdb # - Metadata population and defaults # - Index consistency via _edit_entity_done - result = await rag.acreate_entity( + result = await active_rag.acreate_entity( entity_name=request.entity_name, entity_data=request.entity_data, ) @@ -516,7 +539,7 @@ async def create_entity(request: EntityCreateRequest): ) @router.post("/graph/relation/create", dependencies=[Depends(combined_auth)]) - async def create_relation(request: RelationCreateRequest): + async def create_relation(raw_request: Request, request: RelationCreateRequest): """ Create a new relationship between two entities in the knowledge graph @@ -573,13 +596,14 @@ async def create_relation(request: RelationCreateRequest): } """ try: + active_rag = await resolve_rag(raw_request) # Use the proper acreate_relation method which handles: # - Graph lock for concurrency # - Entity existence validation # - Duplicate relation checks # - Vector 
embedding creation in relationships_vdb # - Index consistency via _edit_relation_done - result = await rag.acreate_relation( + result = await active_rag.acreate_relation( source_entity=request.source_entity, target_entity=request.target_entity, relation_data=request.relation_data, @@ -605,7 +629,7 @@ async def create_relation(request: RelationCreateRequest): ) @router.post("/graph/entities/merge", dependencies=[Depends(combined_auth)]) - async def merge_entities(request: EntityMergeRequest): + async def merge_entities(raw_request: Request, request: EntityMergeRequest): """ Merge multiple entities into a single entity, preserving all relationships @@ -662,7 +686,8 @@ async def merge_entities(request: EntityMergeRequest): - This operation cannot be undone, so verify entity names before merging """ try: - result = await rag.amerge_entities( + active_rag = await resolve_rag(raw_request) + result = await active_rag.amerge_entities( source_entities=request.entities_to_change, target_entity=request.entity_to_change_into, ) diff --git a/lightrag/api/routers/ollama_api.py b/lightrag/api/routers/ollama_api.py index 15c695cee7..87ed7eaee0 100644 --- a/lightrag/api/routers/ollama_api.py +++ b/lightrag/api/routers/ollama_api.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel -from typing import List, Dict, Any, Optional, Type +from typing import List, Dict, Any, Optional, Type, Callable, Awaitable from lightrag.utils import logger import time import json @@ -218,14 +218,41 @@ def parse_query_mode(query: str) -> tuple[str, SearchMode, bool, Optional[str]]: class OllamaAPI: - def __init__(self, rag: LightRAG, top_k: int = 60, api_key: Optional[str] = None): + def __init__( + self, + rag: LightRAG, + top_k: int = 60, + api_key: Optional[str] = None, + get_rag_for_workspace: Optional[ + Callable[[str | None], Awaitable[LightRAG]] + ] = None, + get_workspace_from_request: Optional[Callable[[Request], str | None]] = None, + ): 
self.rag = rag self.ollama_server_infos = rag.ollama_server_infos self.top_k = top_k self.api_key = api_key + self.get_rag_for_workspace = get_rag_for_workspace + self.get_workspace_from_request = get_workspace_from_request self.router = APIRouter(tags=["ollama"]) self.setup_routes() + async def resolve_rag(self, raw_request: Request) -> LightRAG: + if ( + self.get_rag_for_workspace is None + or self.get_workspace_from_request is None + ): + return self.rag + workspace = self.get_workspace_from_request(raw_request) + return await self.get_rag_for_workspace(workspace) + + @staticmethod + def build_llm_kwargs(rag: LightRAG, system_prompt: Optional[str]) -> Dict[str, Any]: + llm_kwargs = dict(rag.llm_model_kwargs) + if system_prompt: + llm_kwargs["system_prompt"] = system_prompt + return llm_kwargs + def setup_routes(self): # Create combined auth dependency for Ollama API routes combined_auth = get_combined_auth_dependency(self.api_key) @@ -236,21 +263,23 @@ async def get_version(): return OllamaVersionResponse(version="0.9.3") @self.router.get("/tags", dependencies=[Depends(combined_auth)]) - async def get_tags(): + async def get_tags(raw_request: Request): """Return available models acting as an Ollama server""" + active_rag = await self.resolve_rag(raw_request) + active_infos = active_rag.ollama_server_infos return OllamaTagResponse( models=[ { - "name": self.ollama_server_infos.LIGHTRAG_MODEL, - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "modified_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, - "size": self.ollama_server_infos.LIGHTRAG_SIZE, - "digest": self.ollama_server_infos.LIGHTRAG_DIGEST, + "name": active_infos.LIGHTRAG_MODEL, + "model": active_infos.LIGHTRAG_MODEL, + "modified_at": active_infos.LIGHTRAG_CREATED_AT, + "size": active_infos.LIGHTRAG_SIZE, + "digest": active_infos.LIGHTRAG_DIGEST, "details": { "parent_model": "", "format": "gguf", - "family": self.ollama_server_infos.LIGHTRAG_NAME, - "families": 
[self.ollama_server_infos.LIGHTRAG_NAME], + "family": active_infos.LIGHTRAG_NAME, + "families": [active_infos.LIGHTRAG_NAME], "parameter_size": "13B", "quantization_level": "Q4_0", }, @@ -259,15 +288,17 @@ async def get_tags(): ) @self.router.get("/ps", dependencies=[Depends(combined_auth)]) - async def get_running_models(): + async def get_running_models(raw_request: Request): """List Running Models - returns currently running models""" + active_rag = await self.resolve_rag(raw_request) + active_infos = active_rag.ollama_server_infos return OllamaPsResponse( models=[ { - "name": self.ollama_server_infos.LIGHTRAG_MODEL, - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "size": self.ollama_server_infos.LIGHTRAG_SIZE, - "digest": self.ollama_server_infos.LIGHTRAG_DIGEST, + "name": active_infos.LIGHTRAG_MODEL, + "model": active_infos.LIGHTRAG_MODEL, + "size": active_infos.LIGHTRAG_SIZE, + "digest": active_infos.LIGHTRAG_DIGEST, "details": { "parent_model": "", "format": "gguf", @@ -277,7 +308,7 @@ async def get_running_models(): "quantization_level": "Q4_0", }, "expires_at": "2050-12-31T14:38:31.83753-07:00", - "size_vram": self.ollama_server_infos.LIGHTRAG_SIZE, + "size_vram": active_infos.LIGHTRAG_SIZE, } ] ) @@ -294,17 +325,17 @@ async def generate(raw_request: Request): try: # Parse the request body manually request = await parse_request_body(raw_request, OllamaGenerateRequest) + active_rag = await self.resolve_rag(raw_request) + active_infos = active_rag.ollama_server_infos query = request.prompt start_time = time.time_ns() prompt_tokens = estimate_tokens(query) - - if request.system: - self.rag.llm_model_kwargs["system_prompt"] = request.system + llm_kwargs = self.build_llm_kwargs(active_rag, request.system) if request.stream: - response = await self.rag.llm_model_func( - query, stream=True, **self.rag.llm_model_kwargs + response = await active_rag.llm_model_func( + query, stream=True, **llm_kwargs ) async def stream_generator(): @@ -320,8 +351,8 @@ async 
def stream_generator(): total_response = response data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": active_infos.LIGHTRAG_MODEL, + "created_at": active_infos.LIGHTRAG_CREATED_AT, "response": response, "done": False, } @@ -333,8 +364,8 @@ async def stream_generator(): eval_time = last_chunk_time - first_chunk_time data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": active_infos.LIGHTRAG_MODEL, + "created_at": active_infos.LIGHTRAG_CREATED_AT, "response": "", "done": True, "done_reason": "stop", @@ -358,8 +389,8 @@ async def stream_generator(): total_response += chunk data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": active_infos.LIGHTRAG_MODEL, + "created_at": active_infos.LIGHTRAG_CREATED_AT, "response": chunk, "done": False, } @@ -375,8 +406,8 @@ async def stream_generator(): # Send error message to client error_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": active_infos.LIGHTRAG_MODEL, + "created_at": active_infos.LIGHTRAG_CREATED_AT, "response": f"\n\nError: {error_msg}", "error": f"\n\nError: {error_msg}", "done": False, @@ -385,8 +416,8 @@ async def stream_generator(): # Send final message to close the stream final_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": active_infos.LIGHTRAG_MODEL, + "created_at": active_infos.LIGHTRAG_CREATED_AT, "response": "", "done": True, } @@ -400,8 +431,8 @@ async def stream_generator(): eval_time = last_chunk_time - first_chunk_time data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": active_infos.LIGHTRAG_MODEL, + "created_at": 
active_infos.LIGHTRAG_CREATED_AT, "response": "", "done": True, "done_reason": "stop", @@ -428,8 +459,8 @@ async def stream_generator(): ) else: first_chunk_time = time.time_ns() - response_text = await self.rag.llm_model_func( - query, stream=False, **self.rag.llm_model_kwargs + response_text = await active_rag.llm_model_func( + query, stream=False, **llm_kwargs ) last_chunk_time = time.time_ns() @@ -442,8 +473,8 @@ async def stream_generator(): eval_time = last_chunk_time - first_chunk_time return { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": active_infos.LIGHTRAG_MODEL, + "created_at": active_infos.LIGHTRAG_CREATED_AT, "response": str(response_text), "done": True, "done_reason": "stop", @@ -471,6 +502,8 @@ async def chat(raw_request: Request): try: # Parse the request body manually request = await parse_request_body(raw_request, OllamaChatRequest) + active_rag = await self.resolve_rag(raw_request) + active_infos = active_rag.ollama_server_infos # Get all messages messages = request.messages @@ -511,20 +544,19 @@ async def chat(raw_request: Request): param_dict["user_prompt"] = user_prompt query_param = QueryParam(**param_dict) + llm_kwargs = self.build_llm_kwargs(active_rag, request.system) if request.stream: # Determine if the request is prefix with "/bypass" if mode == SearchMode.bypass: - if request.system: - self.rag.llm_model_kwargs["system_prompt"] = request.system - response = await self.rag.llm_model_func( + response = await active_rag.llm_model_func( cleaned_query, stream=True, history_messages=conversation_history, - **self.rag.llm_model_kwargs, + **llm_kwargs, ) else: - response = await self.rag.aquery( + response = await active_rag.aquery( cleaned_query, param=query_param ) @@ -541,8 +573,8 @@ async def stream_generator(): total_response = response data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + 
"model": active_infos.LIGHTRAG_MODEL, + "created_at": active_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": response, @@ -558,8 +590,8 @@ async def stream_generator(): eval_time = last_chunk_time - first_chunk_time data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": active_infos.LIGHTRAG_MODEL, + "created_at": active_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": "", @@ -586,8 +618,8 @@ async def stream_generator(): total_response += chunk data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": active_infos.LIGHTRAG_MODEL, + "created_at": active_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": chunk, @@ -607,8 +639,8 @@ async def stream_generator(): # Send error message to client error_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": active_infos.LIGHTRAG_MODEL, + "created_at": active_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": f"\n\nError: {error_msg}", @@ -621,8 +653,8 @@ async def stream_generator(): # Send final message to close the stream final_data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": active_infos.LIGHTRAG_MODEL, + "created_at": active_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": "", @@ -641,8 +673,8 @@ async def stream_generator(): eval_time = last_chunk_time - first_chunk_time data = { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": active_infos.LIGHTRAG_MODEL, + "created_at": active_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": "", @@ -677,17 +709,14 @@ async def stream_generator(): r"\n\nUSER:", 
cleaned_query, re.MULTILINE ) if match_result or mode == SearchMode.bypass: - if request.system: - self.rag.llm_model_kwargs["system_prompt"] = request.system - - response_text = await self.rag.llm_model_func( + response_text = await active_rag.llm_model_func( cleaned_query, stream=False, history_messages=conversation_history, - **self.rag.llm_model_kwargs, + **llm_kwargs, ) else: - response_text = await self.rag.aquery( + response_text = await active_rag.aquery( cleaned_query, param=query_param ) @@ -702,8 +731,8 @@ async def stream_generator(): eval_time = last_chunk_time - first_chunk_time return { - "model": self.ollama_server_infos.LIGHTRAG_MODEL, - "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT, + "model": active_infos.LIGHTRAG_MODEL, + "created_at": active_infos.LIGHTRAG_CREATED_AT, "message": { "role": "assistant", "content": str(response_text), diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 22958158a1..6fd55f4524 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -3,8 +3,8 @@ """ import json -from typing import Any, Dict, List, Literal, Optional -from fastapi import APIRouter, Depends, HTTPException +from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional +from fastapi import APIRouter, Depends, HTTPException, Request from lightrag.base import QueryParam from lightrag.api.utils_api import get_combined_auth_dependency from lightrag.utils import logger @@ -190,9 +190,21 @@ class StreamChunkResponse(BaseModel): ) -def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): +def create_query_routes( + rag, + api_key: Optional[str] = None, + top_k: int = 60, + get_rag_for_workspace: Optional[Callable[[str | None], Awaitable[Any]]] = None, + get_workspace_from_request: Optional[Callable[[Request], str | None]] = None, +): combined_auth = get_combined_auth_dependency(api_key) + async def resolve_rag(raw_request: Request): + 
if get_rag_for_workspace is None or get_workspace_from_request is None: + return rag + workspace = get_workspace_from_request(raw_request) + return await get_rag_for_workspace(workspace) + @router.post( "/query", response_model=QueryResponse, @@ -322,7 +334,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): }, }, ) - async def query_text(request: QueryRequest): + async def query_text(raw_request: Request, request: QueryRequest): """ Comprehensive RAG query endpoint with non-streaming response. Parameter "stream" is ignored. @@ -409,7 +421,8 @@ async def query_text(request: QueryRequest): param.stream = False # Unified approach: always use aquery_llm for both cases - result = await rag.aquery_llm(request.query, param=param) + active_rag = await resolve_rag(raw_request) + result = await active_rag.aquery_llm(request.query, param=param) # Extract LLM response and references from unified result llm_response = result.get("llm_response", {}) @@ -532,7 +545,7 @@ async def query_text(request: QueryRequest): }, }, ) - async def query_text_stream(request: QueryRequest): + async def query_text_stream(raw_request: Request, request: QueryRequest): """ Advanced RAG query endpoint with flexible streaming response. @@ -667,7 +680,8 @@ async def query_text_stream(request: QueryRequest): from fastapi.responses import StreamingResponse # Unified approach: always use aquery_llm for all cases - result = await rag.aquery_llm(request.query, param=param) + active_rag = await resolve_rag(raw_request) + result = await active_rag.aquery_llm(request.query, param=param) async def stream_generator(): # Extract references and LLM response from unified result @@ -1035,7 +1049,7 @@ async def stream_generator(): }, }, ) - async def query_data(request: QueryRequest): + async def query_data(raw_request: Request, request: QueryRequest): """ Advanced data retrieval endpoint for structured RAG analysis. 
@@ -1140,7 +1154,8 @@ async def query_data(request: QueryRequest): """ try: param = request.to_query_params(False) # No streaming for data endpoint - response = await rag.aquery_data(request.query, param=param) + active_rag = await resolve_rag(raw_request) + response = await active_rag.aquery_data(request.query, param=param) # aquery_data returns the new format with status, message, data, and metadata if isinstance(response, dict): @@ -1151,7 +1166,6 @@ async def query_data(request: QueryRequest): status="failure", message="Invalid response type", data={}, - metadata={}, ) except Exception as e: logger.error(f"Error processing data query: {str(e)}", exc_info=True) diff --git a/lightrag_webui/src/api/lightrag.ts b/lightrag_webui/src/api/lightrag.ts index ea90020628..e92fe62ec3 100644 --- a/lightrag_webui/src/api/lightrag.ts +++ b/lightrag_webui/src/api/lightrag.ts @@ -41,6 +41,7 @@ export type LightragStatus = { graph_storage: string vector_storage: string workspace?: string + default_workspace?: string max_graph_nodes?: string enable_rerank?: boolean rerank_binding?: string | null @@ -86,6 +87,20 @@ export type LightragDocumentsScanProgress = { progress: number } +export type LightragWorkspace = { + id: string + label: string + is_default: boolean + is_current: boolean + sources: string[] +} + +export type LightragWorkspaceList = { + current_workspace: string + default_workspace: string + workspaces: LightragWorkspace[] +} + /** * Specifies the retrieval mode: * - "naive": Performs a basic search without advanced techniques. @@ -286,6 +301,11 @@ const axiosInstance = axios.create({ } }) +const getWorkspaceHeader = (): string | null => { + const workspace = useSettingsStore.getState().workspace?.trim() + return workspace ? 
workspace : null +} + // ========== Token Management ========== // Prevent multiple requests from triggering token refresh simultaneously let isRefreshingGuestToken = false; @@ -352,6 +372,12 @@ axiosInstance.interceptors.request.use((config) => { if (apiKey) { config.headers['X-API-Key'] = apiKey } + const workspace = getWorkspaceHeader() + if (workspace) { + config.headers['LIGHTRAG-WORKSPACE'] = workspace + } else if (config.headers['LIGHTRAG-WORKSPACE']) { + delete config.headers['LIGHTRAG-WORKSPACE'] + } return config }) @@ -488,6 +514,11 @@ export const checkHealth = async (): Promise< } } +export const listWorkspaces = async (): Promise => { + const response = await axiosInstance.get('/workspaces') + return response.data +} + export const getDocuments = async (): Promise => { const response = await axiosInstance.get('/documents') return response.data @@ -530,6 +561,10 @@ export const queryTextStream = async ( if (apiKey) { headers['X-API-Key'] = apiKey; } + const workspace = getWorkspaceHeader(); + if (workspace) { + headers['LIGHTRAG-WORKSPACE'] = workspace; + } try { const response = await fetch(`${backendBaseUrl}/query/stream`, { diff --git a/lightrag_webui/src/components/AppSettings.tsx b/lightrag_webui/src/components/AppSettings.tsx index 93f507efce..d5ccf63d36 100644 --- a/lightrag_webui/src/components/AppSettings.tsx +++ b/lightrag_webui/src/components/AppSettings.tsx @@ -1,9 +1,12 @@ -import { useState, useCallback } from 'react' +import { useState, useCallback, useEffect } from 'react' import { Popover, PopoverContent, PopoverTrigger } from '@/components/ui/Popover' import Button from '@/components/ui/Button' +import { listWorkspaces, type LightragWorkspace } from '@/api/lightrag' import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/components/ui/Select' +import Input from '@/components/ui/Input' import { useSettingsStore } from '@/stores/settings' -import { PaletteIcon } from 'lucide-react' +import { useBackendState } 
from '@/stores/state' +import { Loader2, PaletteIcon, RefreshCw } from 'lucide-react' import { useTranslation } from 'react-i18next' import { cn } from '@/lib/utils' @@ -13,6 +16,9 @@ interface AppSettingsProps { export default function AppSettings({ className }: AppSettingsProps) { const [opened, setOpened] = useState(false) + const [knownWorkspaces, setKnownWorkspaces] = useState([]) + const [loadingWorkspaces, setLoadingWorkspaces] = useState(false) + const [workspaceListLoaded, setWorkspaceListLoaded] = useState(false) const { t } = useTranslation() const language = useSettingsStore.use.language() @@ -21,6 +27,10 @@ export default function AppSettings({ className }: AppSettingsProps) { const theme = useSettingsStore.use.theme() const setTheme = useSettingsStore.use.setTheme() + const workspace = useSettingsStore.use.workspace() + const setWorkspace = useSettingsStore.use.setWorkspace() + const [workspaceDraft, setWorkspaceDraft] = useState(workspace) + const handleLanguageChange = useCallback((value: string) => { setLanguage(value as 'en' | 'zh' | 'fr' | 'ar' | 'zh_TW' | 'ru' | 'ja' | 'de' | 'uk' | 'ko' | 'vi') }, [setLanguage]) @@ -29,15 +39,135 @@ export default function AppSettings({ className }: AppSettingsProps) { setTheme(value as 'light' | 'dark' | 'system') }, [setTheme]) + const handleWorkspaceChange = useCallback((value: string) => { + setWorkspace(value) + useBackendState.getState().resetHealthCheckTimerDelayed(0) + }, [setWorkspace]) + + const handleWorkspaceCandidateSelect = useCallback((value: string) => { + setWorkspaceDraft(value) + handleWorkspaceChange(value) + setOpened(false) + }, [handleWorkspaceChange]) + + const commitWorkspaceDraft = useCallback(() => { + if (workspaceDraft !== workspace) { + handleWorkspaceChange(workspaceDraft) + } + }, [handleWorkspaceChange, workspace, workspaceDraft]) + + const handleOpenChange = useCallback((nextOpened: boolean) => { + if (!nextOpened) { + commitWorkspaceDraft() + } + setOpened(nextOpened) + }, 
[commitWorkspaceDraft]) + + const handleWorkspaceKeyDown = useCallback((event: React.KeyboardEvent) => { + if (event.key === 'Enter') { + commitWorkspaceDraft() + } + }, [commitWorkspaceDraft]) + + const handleWorkspaceBlur = useCallback(() => { + commitWorkspaceDraft() + }, [commitWorkspaceDraft]) + + const handleWorkspaceDraftChange = useCallback((value: string) => { + setWorkspaceDraft(value) + }, []) + + useEffect(() => { + setWorkspaceDraft(workspace) + }, [workspace]) + + const refreshWorkspaceList = useCallback(async () => { + setLoadingWorkspaces(true) + try { + const result = await listWorkspaces() + setKnownWorkspaces(result.workspaces || []) + setWorkspaceListLoaded(true) + } catch (error) { + console.error('Failed to load workspaces:', error) + setKnownWorkspaces([]) + setWorkspaceListLoaded(false) + } finally { + setLoadingWorkspaces(false) + } + }, []) + + useEffect(() => { + if (opened && !workspaceListLoaded && !loadingWorkspaces) { + void refreshWorkspaceList() + } + }, [loadingWorkspaces, opened, refreshWorkspaceList, workspaceListLoaded]) + return ( - + - +
+
+ + handleWorkspaceDraftChange(event.target.value)} + onBlur={handleWorkspaceBlur} + onKeyDown={handleWorkspaceKeyDown} + placeholder={t('settings.workspacePlaceholder', { + defaultValue: 'Leave empty to use the server default workspace' + })} + autoCapitalize="off" + autoCorrect="off" + spellCheck={false} + /> +

+ {t('settings.workspaceDescription', { + defaultValue: 'All document, graph, and retrieval requests use this workspace header.' + })} +

+
+ + {t('settings.detectedWorkspaces', { defaultValue: 'Detected workspaces' })} + + +
+
+ {knownWorkspaces.length > 0 ? knownWorkspaces.map((candidate) => ( + + )) : ( + + {loadingWorkspaces + ? t('settings.loadingWorkspaces', { defaultValue: 'Loading workspaces...' }) + : t('settings.noDetectedWorkspaces', { defaultValue: 'No detected workspaces yet' })} + + )} +
+
+