Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
254 changes: 207 additions & 47 deletions lightrag/api/lightrag_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
get_swagger_ui_html,
get_swagger_ui_oauth2_redirect_html,
)
import asyncio
import os
import re
import logging
Expand All @@ -19,6 +20,8 @@
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse
from pathlib import Path
import configparser
from fnmatch import fnmatch
from ascii_colors import ASCIIColors
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -73,6 +76,10 @@
webui_title = os.getenv("WEBUI_TITLE")
webui_description = os.getenv("WEBUI_DESCRIPTION")

# Initialize config parser
config = configparser.ConfigParser()
config.read("config.ini")

# Global authentication configuration
auth_configured = bool(auth_handler.accounts)

Expand Down Expand Up @@ -339,8 +346,90 @@ def create_app(args):
# Check if API key is provided either through env var or args
api_key = os.getenv("LIGHTRAG_API_KEY") or args.key

# Initialize document manager with workspace support for data isolation
doc_manager = DocumentManager(args.input_dir, workspace=args.workspace)
def normalize_workspace(workspace: str | None) -> str:
    """Map a possibly-missing workspace name to a concrete registry key.

    ``None`` falls back to the server's configured default workspace
    (``args.workspace``), with the empty string standing in for "no
    workspace" when no default is configured either.
    """
    return workspace if workspace is not None else (args.workspace or "")

def detect_workspace_dirs(working_dir: Path) -> list[str]:
    """Scan *working_dir* for subdirectories that look like LightRAG workspaces.

    A child directory counts as a workspace when it directly contains at
    least one regular file matching a known storage filename pattern.
    Directories that cannot be read are silently skipped.

    Returns the matching directory names, sorted alphabetically.
    """
    if not working_dir.exists():
        return []

    storage_patterns = (
        "kv_store_*.json",
        "graph_*.graphml",
        "vdb_*.json",
        "faiss_index_*.index",
    )

    def holds_storage_files(directory: Path) -> bool:
        # True when any regular file directly inside `directory` matches
        # one of the known storage patterns; unreadable dirs count as no.
        try:
            for entry in directory.iterdir():
                if entry.is_file() and any(
                    fnmatch(entry.name, pattern) for pattern in storage_patterns
                ):
                    return True
        except OSError:
            pass
        return False

    found = {
        child.name
        for child in working_dir.iterdir()
        if child.is_dir() and holds_storage_files(child)
    }
    return sorted(found)

def build_workspace_inventory(current_workspace: str | None = None) -> dict:
    """Assemble a snapshot of every workspace the server currently knows about.

    Names are gathered from the configured default, the current request,
    both in-memory registries, and a filesystem scan of the working
    directory. Each returned entry records which of those sources named
    it, plus default/current flags.
    """
    default_workspace = get_default_workspace() or ""
    effective_workspace = normalize_workspace(current_workspace)

    discovered: dict[str, set[str]] = {}

    def record(name: str | None, source: str) -> None:
        # Normalize before keying so None and "" collapse consistently.
        discovered.setdefault(normalize_workspace(name), set()).add(source)

    record(default_workspace, "default-config")
    record(effective_workspace, "current-request")
    for name in rag_registry:
        record(name, "rag-registry")
    for name in doc_manager_registry:
        record(name, "document-manager")
    for name in detect_workspace_dirs(Path(args.working_dir)):
        record(name, "working-dir-scan")

    entries = [
        {
            "id": name,
            "label": name if name else "default",
            "is_default": name == default_workspace,
            "is_current": name == effective_workspace,
            "sources": sorted(discovered[name]),
        }
        for name in sorted(discovered)
    ]

    return {
        "current_workspace": effective_workspace,
        "default_workspace": default_workspace,
        "workspaces": entries,
    }

# Per-workspace DocumentManager cache, keyed by normalized workspace name,
# giving each workspace its own isolated document store.
doc_manager_registry: dict[str, DocumentManager] = {}

def get_doc_manager_for_workspace(workspace: str | None) -> DocumentManager:
    """Return the DocumentManager for *workspace*, creating and caching it on first use."""
    workspace_key = normalize_workspace(workspace)
    if workspace_key not in doc_manager_registry:
        doc_manager_registry[workspace_key] = DocumentManager(
            args.input_dir, workspace=workspace_key
        )
    return doc_manager_registry[workspace_key]

# Eagerly create the manager for the server's default workspace.
doc_manager = get_doc_manager_for_workspace(args.workspace)

@asynccontextmanager
async def lifespan(app: FastAPI):
Expand All @@ -361,8 +450,14 @@ async def lifespan(app: FastAPI):
yield

finally:
# Clean up database connections
await rag.finalize_storages()
# Clean up database connections for all initialized workspaces
for workspace_key, rag_instance in list(rag_registry.items()):
try:
await rag_instance.finalize_storages()
except Exception as e:
logger.error(
f"Failed to finalize storages for workspace '{workspace_key}': {e}"
)

if "LIGHTRAG_GUNICORN_MODE" not in os.environ:
# Only perform cleanup in Uvicorn single-process mode
Expand Down Expand Up @@ -1055,59 +1150,111 @@ async def server_rerank_func(
name=args.simulated_model_name, tag=args.simulated_model_tag
)

# Initialize RAG with unified configuration
try:
rag = LightRAG(
working_dir=args.working_dir,
workspace=args.workspace,
llm_model_func=create_llm_model_func(args.llm_binding),
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
summary_max_tokens=args.summary_max_tokens,
summary_context_size=args.summary_context_size,
chunk_token_size=int(args.chunk_size),
chunk_overlap_token_size=int(args.chunk_overlap_size),
llm_model_kwargs=create_llm_model_kwargs(
args.llm_binding, args, llm_timeout
),
embedding_func=embedding_func,
default_llm_timeout=llm_timeout,
default_embedding_timeout=embedding_timeout,
kv_storage=args.kv_storage,
graph_storage=args.graph_storage,
vector_storage=args.vector_storage,
doc_status_storage=args.doc_status_storage,
vector_db_storage_cls_kwargs={
"cosine_better_than_threshold": args.cosine_threshold
},
enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
enable_llm_cache=args.enable_llm_cache,
rerank_model_func=rerank_model_func,
max_parallel_insert=args.max_parallel_insert,
max_graph_nodes=args.max_graph_nodes,
addon_params={
"language": args.summary_language,
"entity_types": args.entity_types,
},
ollama_server_infos=ollama_server_infos,
)
except Exception as e:
logger.error(f"Failed to initialize LightRAG: {e}")
raise
def build_rag(workspace: str | None = None) -> LightRAG:
    """Construct a LightRAG instance bound to *workspace* using the server-wide settings.

    All tuning knobs come from the parsed CLI/env ``args`` plus the
    already-resolved LLM/embedding/rerank callables. Initialization
    failures are logged with the offending workspace and re-raised.
    """
    workspace_key = normalize_workspace(workspace)
    try:
        # Gather the full configuration first, then construct in one call.
        rag_config = dict(
            working_dir=args.working_dir,
            workspace=workspace_key,
            llm_model_func=create_llm_model_func(args.llm_binding),
            llm_model_name=args.llm_model,
            llm_model_max_async=args.max_async,
            summary_max_tokens=args.summary_max_tokens,
            summary_context_size=args.summary_context_size,
            chunk_token_size=int(args.chunk_size),
            chunk_overlap_token_size=int(args.chunk_overlap_size),
            llm_model_kwargs=create_llm_model_kwargs(
                args.llm_binding, args, llm_timeout
            ),
            embedding_func=embedding_func,
            default_llm_timeout=llm_timeout,
            default_embedding_timeout=embedding_timeout,
            kv_storage=args.kv_storage,
            graph_storage=args.graph_storage,
            vector_storage=args.vector_storage,
            doc_status_storage=args.doc_status_storage,
            vector_db_storage_cls_kwargs={
                "cosine_better_than_threshold": args.cosine_threshold
            },
            enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
            enable_llm_cache=args.enable_llm_cache,
            rerank_model_func=rerank_model_func,
            max_parallel_insert=args.max_parallel_insert,
            max_graph_nodes=args.max_graph_nodes,
            addon_params={
                "language": args.summary_language,
                "entity_types": args.entity_types,
            },
            ollama_server_infos=ollama_server_infos,
        )
        return LightRAG(**rag_config)
    except Exception as e:
        logger.error(f"Failed to initialize LightRAG for workspace '{workspace_key}': {e}")
        raise

# Registry of live LightRAG instances and the locks guarding their lazy creation.
rag_registry: dict[str, LightRAG] = {}
rag_registry_locks: dict[str, asyncio.Lock] = {}
default_workspace = normalize_workspace(args.workspace)
# The default-workspace instance is built eagerly; it backs the base routes.
rag = build_rag(default_workspace)
rag_registry[default_workspace] = rag

async def get_rag_for_workspace(workspace: str | None = None) -> LightRAG:
    """Return the LightRAG instance for *workspace*, lazily creating it on first use.

    A per-workspace ``asyncio.Lock`` with a double-check ensures that
    concurrent requests for the same new workspace initialize it exactly
    once. The instance is registered only after its storages are
    initialized and migrated, so a failed setup is retried on the next
    request instead of caching a half-built instance.
    """
    workspace_key = normalize_workspace(workspace)

    existing = rag_registry.get(workspace_key)
    if existing is not None:
        return existing

    lock = rag_registry_locks.setdefault(workspace_key, asyncio.Lock())
    async with lock:
        # Re-check under the lock: another coroutine may have won the race.
        existing = rag_registry.get(workspace_key)
        if existing is not None:
            return existing

        instance = build_rag(workspace_key)
        await instance.initialize_storages()
        await instance.check_and_migrate_data()
        rag_registry[workspace_key] = instance
        logger.info(
            f"Initialized workspace-specific LightRAG instance for workspace '{workspace_key}'"
        )
        return instance

# Add routes. Each router receives the default-workspace instances plus the
# accessors it needs to resolve a per-request workspace lazily.
#
# NOTE: the previous code registered create_query_routes and
# create_graph_routes a second time without the workspace accessors
# (legacy calls), double-registering every query/graph endpoint; those
# duplicate registrations are removed here.
app.include_router(
    create_document_routes(
        rag,
        doc_manager,
        api_key,
        get_rag_for_workspace=get_rag_for_workspace,
        get_doc_manager_for_workspace=get_doc_manager_for_workspace,
        get_workspace_from_request=get_workspace_from_request,
    )
)
app.include_router(
    create_query_routes(
        rag,
        api_key,
        args.top_k,
        get_rag_for_workspace=get_rag_for_workspace,
        get_workspace_from_request=get_workspace_from_request,
    )
)
app.include_router(
    create_graph_routes(
        rag,
        api_key,
        get_rag_for_workspace=get_rag_for_workspace,
        get_workspace_from_request=get_workspace_from_request,
    )
)

# Add Ollama API routes
ollama_api = OllamaAPI(
    rag,
    top_k=args.top_k,
    api_key=api_key,
    get_rag_for_workspace=get_rag_for_workspace,
    get_workspace_from_request=get_workspace_from_request,
)
app.include_router(ollama_api.router, prefix="/api")

# Custom Swagger UI endpoint for offline support
Expand Down Expand Up @@ -1242,6 +1389,7 @@ async def get_status(request: Request):
default_workspace = get_default_workspace()
if workspace is None:
workspace = default_workspace
await get_rag_for_workspace(workspace)
pipeline_status = await get_namespace_data(
"pipeline_status", workspace=workspace
)
Expand Down Expand Up @@ -1276,7 +1424,8 @@ async def get_status(request: Request):
"vector_storage": args.vector_storage,
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
"enable_llm_cache": args.enable_llm_cache,
"workspace": default_workspace,
"workspace": workspace,
"default_workspace": default_workspace,
"max_graph_nodes": args.max_graph_nodes,
# Rerank configuration
"enable_rerank": rerank_model_func is not None,
Expand Down Expand Up @@ -1308,6 +1457,17 @@ async def get_status(request: Request):
logger.error(f"Error getting health status: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))

@app.get(
    "/workspaces",
    dependencies=[Depends(combined_auth)],
    tags=["workspace"],
    summary="List known workspaces",
    description="Return the currently known workspaces from active registries and local storage directories.",
)
async def list_workspaces(request: Request):
    """Report every known workspace, flagging the request's current one."""
    return build_workspace_inventory(get_workspace_from_request(request))

# Custom StaticFiles class for smart caching
class SmartStaticFiles(StaticFiles): # Renamed from NoCacheStaticFiles
async def get_response(self, path: str, scope):
Expand Down
Loading