diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py
index af8ddb4099..049dcb1222 100644
--- a/lightrag/lightrag.py
+++ b/lightrag/lightrag.py
@@ -530,6 +530,9 @@ class LightRAG(_RoleLLMMixin, _StorageMigrationMixin, _PipelineMixin):
         repr=False,
     )

+    prompts: dict[str, Any] = field(default_factory=dict)
+    """Per-instance prompt overrides. Keys matching PROMPTS dict take precedence over the global PROMPTS."""
+
     # Storages Management
     # ---
diff --git a/lightrag/operate.py b/lightrag/operate.py
index 2334ddcc17..0dedf70c36 100644
--- a/lightrag/operate.py
+++ b/lightrag/operate.py
@@ -55,7 +55,13 @@
     QueryResult,
     QueryContextResult,
 )
-from lightrag.prompt import PROMPTS, resolve_entity_extraction_prompt_profile
+from lightrag.prompt import (
+    PROMPTS,
+    _get_prompt,
+    resolve_entity_extraction_prompt_profile,
+)
+
+
 from lightrag.constants import (
     GRAPH_FIELD_SEP,
     DEFAULT_MAX_ENTITY_TOKENS,
@@ -400,7 +406,7 @@ async def _summarize_descriptions(
     summary_length_recommended = global_config["summary_length_recommended"]

-    prompt_template = PROMPTS["summarize_entity_descriptions"]
+    prompt_template = _get_prompt("summarize_entity_descriptions", global_config)

     # Convert descriptions to JSONL format and apply token-based truncation
     tokenizer = global_config["tokenizer"]
@@ -3303,7 +3309,9 @@ async def extract_entities(
     # path that LightRAG.__post_init__ populates, instead of duplicating
     # guidance/override logic here.
     prompt_profile = resolve_entity_extraction_prompt_profile(
-        addon_params, use_json_extraction
+        addon_params,
+        use_json_extraction,
+        global_config,
     )

     entity_types_guidance = prompt_profile["entity_types_guidance"]
@@ -3324,8 +3332,10 @@
         # Text mode: use traditional delimiter-based prompts
         examples = "\n".join(prompt_profile["entity_extraction_examples"])
         example_context_base = dict(
-            tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
-            completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
+            tuple_delimiter=_get_prompt("DEFAULT_TUPLE_DELIMITER", global_config),
+            completion_delimiter=_get_prompt(
+                "DEFAULT_COMPLETION_DELIMITER", global_config
+            ),
             entity_types_guidance=entity_types_guidance,
             language=language,
         )
@@ -3333,8 +3343,10 @@
         examples = examples.format(**example_context_base)

     context_base = dict(
-        tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
-        completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
+        tuple_delimiter=_get_prompt("DEFAULT_TUPLE_DELIMITER", global_config),
+        completion_delimiter=_get_prompt(
+            "DEFAULT_COMPLETION_DELIMITER", global_config
+        ),
         entity_types_guidance=entity_types_guidance,
         examples=examples,
         language=language,
@@ -3365,26 +3377,26 @@ async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
         if use_json_extraction:
             # JSON mode: use JSON prompts and pass entity_extraction flag to LLM provider
-            entity_extraction_system_prompt = PROMPTS[
-                "entity_extraction_json_system_prompt"
-            ].format(**context_base)
-            entity_extraction_user_prompt = PROMPTS[
-                "entity_extraction_json_user_prompt"
-            ].format(**{**context_base, "input_text": content})
-            entity_continue_extraction_user_prompt = PROMPTS[
-                "entity_continue_extraction_json_user_prompt"
-            ].format(**context_base)
+            entity_extraction_system_prompt = _get_prompt(
+                "entity_extraction_json_system_prompt", global_config
+            ).format(**context_base)
+            entity_extraction_user_prompt = _get_prompt(
+                "entity_extraction_json_user_prompt", global_config
+            ).format(**{**context_base, "input_text": content})
+            entity_continue_extraction_user_prompt = _get_prompt(
+                "entity_continue_extraction_json_user_prompt", global_config
+            ).format(**context_base)
         else:
             # Text mode: use traditional delimiter-based prompts
-            entity_extraction_system_prompt = PROMPTS[
-                "entity_extraction_system_prompt"
-            ].format(**context_base)
-            entity_extraction_user_prompt = PROMPTS[
-                "entity_extraction_user_prompt"
-            ].format(**{**context_base, "input_text": content})
-            entity_continue_extraction_user_prompt = PROMPTS[
-                "entity_continue_extraction_user_prompt"
-            ].format(**{**context_base, "input_text": content})
+            entity_extraction_system_prompt = _get_prompt(
+                "entity_extraction_system_prompt", global_config
+            ).format(**context_base)
+            entity_extraction_user_prompt = _get_prompt(
+                "entity_extraction_user_prompt", global_config
+            ).format(**{**context_base, "input_text": content})
+            entity_continue_extraction_user_prompt = _get_prompt(
+                "entity_continue_extraction_user_prompt", global_config
+            ).format(**{**context_base, "input_text": content})

         final_result, timestamp = await use_llm_func_with_cache(
             entity_extraction_user_prompt,
@@ -3631,7 +3643,7 @@ async def kg_query(
         Returns None when no relevant context could be constructed for the query.
     """
     if not query:
-        return QueryResult(content=PROMPTS["fail_response"])
+        return QueryResult(content=_get_prompt("fail_response", global_config))

     if query_param.model_func:
         use_model_func = query_param.model_func
@@ -3660,7 +3672,7 @@
             logger.warning(f"Forced low_level_keywords to origin query: {query}")
             ll_keywords = [query]
         else:
-            return QueryResult(content=PROMPTS["fail_response"])
+            return QueryResult(content=_get_prompt("fail_response", global_config))

     ll_keywords_str = ", ".join(ll_keywords) if ll_keywords else ""
     hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
@@ -3696,7 +3708,9 @@
     )

     # Build system prompt
-    sys_prompt_temp = system_prompt if system_prompt else PROMPTS["rag_response"]
+    sys_prompt_temp = (
+        system_prompt if system_prompt else _get_prompt("rag_response", global_config)
+    )
     sys_prompt = sys_prompt_temp.format(
         response_type=response_type,
         user_prompt=user_prompt,
@@ -3973,7 +3987,7 @@ async def extract_keywords_only(
     """

     # 1. Build the examples
-    examples = "\n".join(PROMPTS["keywords_extraction_examples"])
+    examples = "\n".join(_get_prompt("keywords_extraction_examples", global_config))

     addon_params = global_config.get("addon_params") or {}
     language = global_config.get("_resolved_summary_language")
@@ -4007,7 +4021,7 @@
     )

     # 3. Build the keyword-extraction prompt
-    kw_prompt = PROMPTS["keywords_extraction"].format(
+    kw_prompt = _get_prompt("keywords_extraction", global_config).format(
         query=text,
         examples=examples,
         language=language,
@@ -4646,10 +4660,10 @@ async def _build_context_str(
     # Get the system prompt template from PROMPTS or global_config
     sys_prompt_template = global_config.get(
-        "system_prompt_template", PROMPTS["rag_response"]
+        "system_prompt_template", _get_prompt("rag_response", global_config)
     )

-    kg_context_template = PROMPTS["kg_query_context"]
+    kg_context_template = _get_prompt("kg_query_context", global_config)

     user_prompt = query_param.user_prompt if query_param.user_prompt else ""
     response_type = (
         query_param.response_type
@@ -5534,7 +5548,7 @@ async def naive_query(
     """

     if not query:
-        return QueryResult(content=PROMPTS["fail_response"])
+        return QueryResult(content=_get_prompt("fail_response", global_config))

     if query_param.model_func:
         use_model_func = query_param.model_func
@@ -5549,7 +5563,7 @@
     tokenizer: Tokenizer = global_config["tokenizer"]
     if not tokenizer:
         logger.error("Tokenizer not found in global configuration.")
-        return QueryResult(content=PROMPTS["fail_response"])
+        return QueryResult(content=_get_prompt("fail_response", global_config))

     chunks = await _get_vector_context(query, chunks_vdb, query_param, None)
@@ -5576,7 +5590,9 @@

     # Use the provided system prompt or default
     sys_prompt_template = (
-        system_prompt if system_prompt else PROMPTS["naive_rag_response"]
+        system_prompt
+        if system_prompt
+        else _get_prompt("naive_rag_response", global_config)
     )

     # Create a preliminary system prompt with empty content_data to calculate overhead
@@ -5655,7 +5671,7 @@
         if ref["reference_id"]
     )

-    naive_context_template = PROMPTS["naive_query_context"]
+    naive_context_template = _get_prompt("naive_query_context", global_config)
     context_content = naive_context_template.format(
         text_chunks_str=text_units_str,
         reference_list_str=reference_list_str,
diff --git a/lightrag/prompt.py b/lightrag/prompt.py
index fbfe61e0b0..6fb416eab2 100644
--- a/lightrag/prompt.py
+++ b/lightrag/prompt.py
@@ -698,16 +698,31 @@ class EntityExtractionPromptProfile(TypedDict):
     entity_extraction_json_examples: list[str]


-def get_default_entity_extraction_prompt_profile() -> EntityExtractionPromptProfile:
+def _get_prompt(key: str, global_config: dict | None = None) -> Any:
+    """Return per-instance prompt override if set, else fall back to global PROMPTS."""
+    if global_config:
+        override = global_config.get("prompts", {}).get(key)
+        if override is not None:
+            return override
+    return PROMPTS[key]
+
+
+def get_default_entity_extraction_prompt_profile(
+    global_config: dict | None = None,
+) -> EntityExtractionPromptProfile:
     """Return a copy of the built-in entity extraction prompt profile."""
     return {
-        "entity_types_guidance": PROMPTS["default_entity_types_guidance"].rstrip(),
+        "entity_types_guidance": _get_prompt(
+            "default_entity_types_guidance", global_config
+        ).rstrip(),
         "entity_extraction_examples": [
-            example.rstrip() for example in PROMPTS["entity_extraction_examples"]
+            example.rstrip()
+            for example in _get_prompt("entity_extraction_examples", global_config)
         ],
         "entity_extraction_json_examples": [
-            example.rstrip() for example in PROMPTS["entity_extraction_json_examples"]
+            example.rstrip()
+            for example in _get_prompt("entity_extraction_json_examples", global_config)
         ],
     }
@@ -847,10 +862,11 @@ def load_entity_extraction_prompt_profile(

 def resolve_entity_extraction_prompt_profile(
     addon_params: Mapping[str, Any] | None,
     use_json: bool,
+    global_config: dict | None = None,
 ) -> EntityExtractionPromptProfile:
     """Resolve and merge the configured entity extraction prompt profile."""

-    default_profile = get_default_entity_extraction_prompt_profile()
+    default_profile = get_default_entity_extraction_prompt_profile(global_config)
     addon_params = addon_params or {}
     prompt_file = addon_params.get("entity_type_prompt_file")
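
Usage sketch: the new `prompts` field lets a caller override any key in `lightrag.prompt.PROMPTS` for a single `LightRAG` instance, while untouched keys keep the global defaults. A minimal example, assuming the usual dataclass constructor (required arguments such as the LLM and embedding functions are elided, and the override text is illustrative, not a shipped default):

    from lightrag import LightRAG

    rag = LightRAG(
        working_dir="./rag_storage",
        prompts={
            # The key must match an entry in lightrag.prompt.PROMPTS;
            # _get_prompt() then returns this value at every call site above
            # instead of the global default.
            "fail_response": "No relevant information was found for this query.",
        },
    )

Note that overridden templates are still passed through str.format() by the call sites in operate.py, so an override for a key such as "rag_response" must preserve the placeholders the code fills in (e.g. {response_type} and {user_prompt}).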