diff --git a/rag-service/main.py b/rag-service/main.py index 2efd192..2f846a8 100644 --- a/rag-service/main.py +++ b/rag-service/main.py @@ -3840,286 +3840,166 @@ def _sse_frame(text: str, event: str | None = None) -> str: def _sse_done() -> str: return "data: [DONE]\n\n" - question = (data.question or "").strip() - if not question: - raise HTTPException(status_code=400, detail="Question is required.") - - intent = detect_question_intent(question) - mode = data.mode - normalized_query = normalize_query(question) - - session_id_list = [str(sid) for sid in data.session_ids] if data.session_ids else ([str(data.session_id)] if data.session_id else []) - secret_list = data.session_secrets if data.session_secrets else ([data.session_secret] if data.session_secret else []) - - if not session_id_list: - raise HTTPException(status_code=400, detail="At least one session ID must be provided.") + def _error_stream(error_msg: str): + yield _sse_frame(error_msg, event="error") + yield _sse_done() - all_scored_candidates = [] - all_indexed_documents = [] - - cache_key = f"{mode}:{normalized_query}" + try: + question = (data.question or "").strip() + if not question: + raise HTTPException(status_code=400, detail="Question is required.") - for idx, session_id in enumerate(session_id_list): - secret = secret_list[idx] if idx < len(secret_list) else None + intent = detect_question_intent(question) + mode = data.mode + normalized_query = normalize_query(question) - with sessions_lock: - session = _touch_session_unlocked(session_id) - if not session: - raise HTTPException( - status_code=404, - detail=f"Session {session_id} expired or invalid. 
Please re-upload your PDFs.", - ) - - _require_session_secret(session, secret) - - if "lock" not in session: - session["lock"] = threading.Lock() - - session_lock = session["lock"] - if not session.get("vectorstore"): - try: - session["vectorstore"] = _load_vectorstore_for_session_unlocked(session_id, session) - except Exception as exc: - logger.error("Failed to lazy load vectorstore session_id=%s error=%s", session_id, exc) - raise HTTPException(status_code=500, detail="Failed to load session index.") - vectorstore = session["vectorstore"] - - # Session-level retrieval cache for streaming path - retrieval_cache = ensure_retrieval_cache(session) - with session_lock: - cleanup_retrieval_cache(retrieval_cache) - cached_value = retrieval_cache.get(cache_key) - cache_hit = False - if isinstance(cached_value, dict) and "scored_candidates" in cached_value: - logger.info( - "Stream retrieval cache hit session_id=%s cache_key=%s", - session_id, - cache_key, - ) - scored_candidates = cached_value["scored_candidates"] - cache_hit = True - elif cached_value is not None: - logger.info( - "Stream retrieval cache invalidated session_id=%s cache_key=%s", - session_id, - cache_key, - ) - retrieval_cache.pop(cache_key, None) - cache_hit = False + session_id_list = [str(sid) for sid in data.session_ids] if data.session_ids else ([str(data.session_id)] if data.session_id else []) + secret_list = data.session_secrets if data.session_secrets else ([data.session_secret] if data.session_secret else []) + + if not session_id_list: + raise HTTPException(status_code=400, detail="At least one session ID must be provided.") - try: - with session_lock: - indexed_documents = collect_index_documents(vectorstore) - if not cache_hit: - logger.info( - "Stream retrieval cache miss session_id=%s cache_key=%s", - session_id, - cache_key, - ) - scored_candidates = search_retrieval_candidates( - vectorstore, - question, - ASK_RETRIEVAL_CANDIDATES, - ) + all_scored_candidates = [] + all_indexed_documents 
= [] + + cache_key = f"{mode}:{normalized_query}" - if not cache_hit: - with sessions_lock: - current_session = sessions.get(session_id) - if current_session: - rc = ensure_retrieval_cache(current_session) - if len(rc) >= RETRIEVAL_CACHE_LIMIT: - oldest = next(iter(rc)) - del rc[oldest] - rc[cache_key] = { - "cached_at": now_ts(), - "scored_candidates": scored_candidates, - } - except Exception: - logger.exception("Stream similarity search failed session_id=%s", session_id) - raise HTTPException(status_code=500, detail="Failed to search the uploaded documents.") + for idx, session_id in enumerate(session_id_list): + secret = secret_list[idx] if idx < len(secret_list) else None - all_scored_candidates.extend(scored_candidates) - all_indexed_documents.extend(indexed_documents) - - # Sort all candidates from multiple documents by score ascending (lower is better in FAISS L2) - all_scored_candidates.sort(key=lambda x: x[1]) - scored_candidates = all_scored_candidates[:ASK_RETRIEVAL_CANDIDATES] - indexed_documents = all_indexed_documents - - docs = ( - representative_documents_by_source(indexed_documents) - if intent == "overview" - else diversify_retrieved_documents(scored_candidates, question) - ) - - best_score = scored_candidates[0][1] if scored_candidates else None - if not passes_evidence_gate(question, docs, best_score, intent): - logger.info( - "Stream evidence gate refused session_id=%s intent=%s best_score=%s", - session_id, - intent, - best_score, - ) - with sessions_lock: - current_session = sessions.get(session_id) - if current_session: - append_chat_exchange( - current_session, - question, - INSUFFICIENT_CONTEXT_MESSAGE, - [], - mode, - ) - _mark_session_dirty(session_id) - - def _refuse_stream(): - yield _sse_frame(INSUFFICIENT_CONTEXT_MESSAGE) - yield _sse_done() - - return StreamingResponse(_refuse_stream(), media_type="text/event-stream; charset=utf-8") - - context = format_context(docs) + with sessions_lock: + session = 
_touch_session_unlocked(session_id) + if not session: + raise HTTPException( + status_code=404, + detail=f"Session {session_id} expired or invalid. Please re-upload your PDFs.", + ) - grounded_answer = None # Disabled to force LLM usage + _require_session_secret(session, secret) - # For grounded (non-LLM) answers, stream the result directly without - # spinning up a generation thread — there are no tokens to generate. - if grounded_answer != INSUFFICIENT_CONTEXT_MESSAGE and grounded_answer: - citation_sources = [citation_source_for_document(doc, idx) for idx, doc in enumerate(docs)] - framed = apply_mode_framing(grounded_answer, question, mode, docs, context) - if ASK_REQUIRE_CITATIONS and not answer_contains_citation(framed, len(docs)): - framed = grounded_answer + if "lock" not in session: + session["lock"] = threading.Lock() - with sessions_lock: - current_session = sessions.get(session_id) - if current_session: - ensure_retrieval_cache(current_session) + session_lock = session["lock"] + if not session.get("vectorstore"): + try: + session["vectorstore"] = _load_vectorstore_for_session_unlocked(session_id, session) + except Exception as exc: + logger.error("Failed to lazy load vectorstore session_id=%s error=%s", session_id, exc) + raise HTTPException(status_code=500, detail="Failed to load session index.") + vectorstore = session["vectorstore"] - append_chat_exchange( - current_session, - question, - framed, - citation_sources, - mode, - ) + # Session-level retrieval cache for streaming path + retrieval_cache = ensure_retrieval_cache(session) + with session_lock: + cleanup_retrieval_cache(retrieval_cache) + cached_value = retrieval_cache.get(cache_key) + cache_hit = False + if isinstance(cached_value, dict) and "scored_candidates" in cached_value: + logger.info( + "Stream retrieval cache hit session_id=%s cache_key=%s", + session_id, + cache_key, + ) + scored_candidates = cached_value["scored_candidates"] + cache_hit = True + elif cached_value is not None: + 
logger.info( + "Stream retrieval cache invalidated session_id=%s cache_key=%s", + session_id, + cache_key, + ) + retrieval_cache.pop(cache_key, None) + cache_hit = False - _mark_session_dirty(session_id) + try: + with session_lock: + indexed_documents = collect_index_documents(vectorstore) + if not cache_hit: + logger.info( + "Stream retrieval cache miss session_id=%s cache_key=%s", + session_id, + cache_key, + ) + scored_candidates = search_retrieval_candidates( + vectorstore, + question, + ASK_RETRIEVAL_CANDIDATES, + ) - def _grounded_stream(): - yield _sse_frame(framed) - yield _sse_done() + if not cache_hit: + with sessions_lock: + current_session = sessions.get(session_id) + if current_session: + rc = ensure_retrieval_cache(current_session) + if len(rc) >= RETRIEVAL_CACHE_LIMIT: + oldest = next(iter(rc)) + del rc[oldest] + rc[cache_key] = { + "cached_at": now_ts(), + "scored_candidates": scored_candidates, + } + except Exception: + logger.exception("Stream similarity search failed session_id=%s", session_id) + raise HTTPException(status_code=500, detail="Failed to search the uploaded documents.") + + all_scored_candidates.extend(scored_candidates) + all_indexed_documents.extend(indexed_documents) - return StreamingResponse(_grounded_stream(), media_type="text/event-stream; charset=utf-8") + # Sort all candidates from multiple documents by score ascending (lower is better in FAISS L2) + all_scored_candidates.sort(key=lambda x: x[1]) + scored_candidates = all_scored_candidates[:ASK_RETRIEVAL_CANDIDATES] + indexed_documents = all_indexed_documents - # LLM generation path — run in a background thread so we can stream tokens - # back to the caller as they are produced rather than waiting for the full - # completion before sending anything. - followup_instructions = "" - if mode in ["tutor", "socratic"]: - followup_instructions = ( - "You MUST append an interactive multiple-choice question to test their understanding. 
" - "Format it exactly like this at the very end of your response:\n" - "\nQuestion: [Question text]\nOptions:\n- [Option A]\n- [Option B]\n\n\n" - ) - elif mode in ["default", "eli5"]: - followup_instructions = ( - "If there is a deterministic follow-up question that would be helpful, you MAY append an interactive block. " - "Format it exactly like this at the very end of your response:\n" - "\nQuestion: [Question text]\nOptions:\n- [Option A]\n- [Option B]\n\n\n" + docs = ( + representative_documents_by_source(indexed_documents) + if intent == "overview" + else diversify_retrieved_documents(scored_candidates, question) ) - prompt = ( - "You are a careful assistant answering questions over one or more uploaded PDF documents. " - "Use only the provided context. The context may include excerpts from multiple PDFs. " - "When the question asks for a relationship, comparison, or synthesis, connect the relevant facts across documents. " - "If the context does not contain enough information, say that briefly and do not invent details.\n\n" - "Reference the provided source numbers naturally whenever the answer is directly supported by the context.\n" - "Cite sources using formats like 'According to Source 1' or 'Source 2 explains that...'\n" - "You are a helpful AI assistant.\n" - "Give clear, conversational, human-friendly answers.\n" - "Do not return raw PDF text or chunks.\n" - "Summarize properly in readable sentences.\n\n" - f"{followup_instructions}" - f"Context:\n{context}\n\n" - f"Question: {question}\n" - "Answer:" - ) + best_score = scored_candidates[0][1] if scored_candidates else None + if not passes_evidence_gate(question, docs, best_score, intent): + logger.info( + "Stream evidence gate refused session_id=%s intent=%s best_score=%s", + session_id, + intent, + best_score, + ) + with sessions_lock: + current_session = sessions.get(session_id) + if current_session: + append_chat_exchange( + current_session, + question, + INSUFFICIENT_CONTEXT_MESSAGE, + [], + 
mode, + ) + _mark_session_dirty(session_id) - logger.info( - "Stream executing query session_id=%s retrieved_chunks=%s", - session_id, - len(docs), - ) + def _refuse_stream(): + yield _sse_frame(INSUFFICIENT_CONTEXT_MESSAGE) + yield _sse_done() - def _generate_and_stream(): - groq_api_key = os.environ.get("GROQ_API_KEY") - if not groq_api_key: - err = "Groq API Key is missing! Please provide your GROQ_API_KEY in the environment." - yield err - return + return StreamingResponse(_refuse_stream(), media_type="text/event-stream; charset=utf-8") - full_answer_parts = [] - try: - import urllib.request - import json - - url = "https://api.groq.com/openai/v1/chat/completions" - headers = { - "Authorization": f"Bearer {groq_api_key}", - "Content-Type": "application/json", - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)" - } - payload = json.dumps({ - "model": "llama-3.1-8b-instant", - "messages": [{"role": "user", "content": prompt}], - "stream": True, - "temperature": 0 - }).encode("utf-8") - - req = urllib.request.Request(url, data=payload, headers=headers, method="POST") - with urllib.request.urlopen(req, timeout=30) as resp: - for line in resp: - decoded = line.decode('utf-8').strip() - if decoded.startswith('data: '): - data_str = decoded[6:] - if data_str == '[DONE]': - break - try: - data = json.loads(data_str) - token = data['choices'][0]['delta'].get('content', '') - if token: - full_answer_parts.append(token) - yield _sse_frame(token) - except Exception: - pass - except Exception as e: - err = f"Groq API Error: {str(e)}" - yield err - full_answer_parts.append(err) + context = format_context(docs) - full_answer = "".join(full_answer_parts).strip() - - try: - # Streamed tokens were already yielded above as they arrived. - # Now produce the final framed answer, persist the chat exchange, - # and send the final framed answer + done event. 
- framed = apply_mode_framing(full_answer, question, mode, docs, context) + grounded_answer = None # Disabled to force LLM usage + # For grounded (non-LLM) answers, stream the result directly without + # spinning up a generation thread — there are no tokens to generate. + if grounded_answer != INSUFFICIENT_CONTEXT_MESSAGE and grounded_answer: + citation_sources = [citation_source_for_document(doc, idx) for idx, doc in enumerate(docs)] + framed = apply_mode_framing(grounded_answer, question, mode, docs, context) if ASK_REQUIRE_CITATIONS and not answer_contains_citation(framed, len(docs)): - framed = full_answer - - citation_sources = [ - citation_source_for_document(doc, idx) - for idx, doc in enumerate(docs) - ] - - # stream the final framed answer once at the end - yield _sse_frame(framed) + framed = grounded_answer with sessions_lock: current_session = sessions.get(session_id) if current_session: ensure_retrieval_cache(current_session) + append_chat_exchange( current_session, question, @@ -4127,21 +4007,150 @@ def _generate_and_stream(): citation_sources, mode, ) + _mark_session_dirty(session_id) - yield _sse_done() - except Exception: - logger.exception("Stream generation failed session_id=%s", session_id) - yield _sse_frame("Generation error. Please try again.", event="error") - # Emit an explicit done marker after the error so SSE clients - # that rely on an in-band completion token can handle the - # terminal state deterministically. + def _grounded_stream(): + yield _sse_frame(framed) + yield _sse_done() + + return StreamingResponse(_grounded_stream(), media_type="text/event-stream; charset=utf-8") + + # LLM generation path — run in a background thread so we can stream tokens + # back to the caller as they are produced rather than waiting for the full + # completion before sending anything. 
+ followup_instructions = "" + if mode in ["tutor", "socratic"]: + followup_instructions = ( + "You MUST append an interactive multiple-choice question to test their understanding. " + "Format it exactly like this at the very end of your response:\n" + "\nQuestion: [Question text]\nOptions:\n- [Option A]\n- [Option B]\n\n\n" + ) + elif mode in ["default", "eli5"]: + followup_instructions = ( + "If there is a deterministic follow-up question that would be helpful, you MAY append an interactive block. " + "Format it exactly like this at the very end of your response:\n" + "\nQuestion: [Question text]\nOptions:\n- [Option A]\n- [Option B]\n\n\n" + ) + + prompt = ( + "You are a careful assistant answering questions over one or more uploaded PDF documents. " + "Use only the provided context. The context may include excerpts from multiple PDFs. " + "When the question asks for a relationship, comparison, or synthesis, connect the relevant facts across documents. " + "If the context does not contain enough information, say that briefly and do not invent details.\n\n" + "Reference the provided source numbers naturally whenever the answer is directly supported by the context.\n" + "Cite sources using formats like 'According to Source 1' or 'Source 2 explains that...'\n" + "You are a helpful AI assistant.\n" + "Give clear, conversational, human-friendly answers.\n" + "Do not return raw PDF text or chunks.\n" + "Summarize properly in readable sentences.\n\n" + f"{followup_instructions}" + f"Context:\n{context}\n\n" + f"Question: {question}\n" + "Answer:" + ) + + logger.info( + "Stream executing query session_id=%s retrieved_chunks=%s", + session_id, + len(docs), + ) + + def _generate_and_stream(): + groq_api_key = os.environ.get("GROQ_API_KEY") + if not groq_api_key: + err = "Groq API Key is missing! Please provide your GROQ_API_KEY in the environment." 
+ yield _sse_frame(err, event="error") + yield _sse_done() + return + + full_answer_parts = [] + try: + import urllib.request + import json + + url = "https://api.groq.com/openai/v1/chat/completions" + headers = { + "Authorization": f"Bearer {groq_api_key}", + "Content-Type": "application/json", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64)" + } + payload = json.dumps({ + "model": "llama-3.1-8b-instant", + "messages": [{"role": "user", "content": prompt}], + "stream": True, + "temperature": 0 + }).encode("utf-8") + + req = urllib.request.Request(url, data=payload, headers=headers, method="POST") + with urllib.request.urlopen(req, timeout=30) as resp: + for line in resp: + decoded = line.decode('utf-8').strip() + if decoded.startswith('data: '): + data_str = decoded[6:] + if data_str == '[DONE]': + break + try: + data = json.loads(data_str) + token = data['choices'][0]['delta'].get('content', '') + if token: + full_answer_parts.append(token) + yield _sse_frame(token) + except Exception: + pass + except Exception as e: + err = f"Groq API Error: {str(e)}" + yield _sse_frame(err, event="error") + yield _sse_done() + return + + full_answer = "".join(full_answer_parts).strip() + try: + # Streamed tokens were already yielded above as they arrived. + # Now produce the final framed answer for persistence (logging, citations, analytics). + # Do NOT re-emit the full answer to the client to avoid duplicate content. 
+ framed = apply_mode_framing(full_answer, question, mode, docs, context) + + if ASK_REQUIRE_CITATIONS and not answer_contains_citation(framed, len(docs)): + framed = full_answer + + citation_sources = [ + citation_source_for_document(doc, idx) + for idx, doc in enumerate(docs) + ] + + with sessions_lock: + current_session = sessions.get(session_id) + if current_session: + ensure_retrieval_cache(current_session) + append_chat_exchange( + current_session, + question, + framed, + citation_sources, + mode, + ) + _mark_session_dirty(session_id) + yield _sse_done() except Exception: - pass + logger.exception("Stream generation failed session_id=%s", session_id) + yield _sse_frame("Generation error. Please try again.", event="error") + # Emit an explicit done marker after the error so SSE clients + # that rely on an in-band completion token can handle the + # terminal state deterministically. + try: + yield _sse_done() + except Exception: + pass + + return StreamingResponse(_generate_and_stream(), media_type="text/event-stream; charset=utf-8") - return StreamingResponse(_generate_and_stream(), media_type="text/event-stream; charset=utf-8") + except HTTPException as e: + # Convert HTTPExceptions to SSE-formatted errors for streaming endpoint + error_msg = e.detail if isinstance(e.detail, str) else str(e.detail) + return StreamingResponse(_error_stream(error_msg), media_type="text/event-stream; charset=utf-8") def _run_generation_locked(model, generate_kwargs): diff --git a/rag-service/test_main.py b/rag-service/test_main.py index 49ae100..89bddbe 100644 --- a/rag-service/test_main.py +++ b/rag-service/test_main.py @@ -2,6 +2,7 @@ import sys from unittest.mock import MagicMock import multiprocessing +import threading os.environ.setdefault("INTERNAL_RAG_TOKEN", "test-secret") @@ -1187,3 +1188,266 @@ def test_vectorstore_loader_rejects_path_traversal_like_session_id(): with pytest.raises(ValueError, match=r"(badly formed hexadecimal UUID string|Invalid persisted session 
id)"):
         main_module._load_vectorstore_from_snapshot("../escape", MagicMock())
+
+
+# ============================================================================
+# Streaming SSE Protocol Tests
+# ============================================================================
+
+def test_stream_no_duplicate_content():
+    """Verify that /ask/stream does not emit duplicate content.
+
+    Regression test for bug where tokens were streamed individually,
+    then the full answer was streamed again at the end.
+
+    This test inspects the source code to ensure the problematic pattern
+    (yielding framed answer after tokens) is not present.
+    """
+    import main as main_module
+    import inspect
+
+    source = inspect.getsource(main_module.ask_question_stream)
+
+    # The bug was: after streaming tokens, the code would yield the full framed answer again
+    # This pattern should NOT exist in the streaming path
+    # We check for the specific pattern that was causing the bug
+
+    # The buggy pattern was: yield _sse_frame(framed) after token streaming
+    # We verify that after the token streaming loop, we don't yield the framed answer
+    lines = source.split('\n')
+
+    # Find the token streaming section
+    in_token_stream = False
+    after_token_stream = False
+    found_framed_yield_after_tokens = False
+
+    for i, line in enumerate(lines):
+        if 'yield _sse_frame(token)' in line:
+            in_token_stream = True
+
+        if in_token_stream and 'full_answer = ' in line:
+            after_token_stream = True
+
+        if after_token_stream and 'yield _sse_frame(framed)' in line:
+            # Ignore matches that sit within ten lines of the refuse/grounded stream helpers:
+            # those are separate functions that legitimately yield the framed answer.
+            # Substring search per line is required; list membership would never match.
+            if not any(tag in ctx for ctx in lines[max(0, i - 10):i + 10] for tag in ('_refuse_stream', '_grounded_stream')):
+                found_framed_yield_after_tokens = True
+
+    assert not found_framed_yield_after_tokens, (
+        "Found yield _sse_frame(framed) after token streaming, which causes duplicate content. "
+        "The full framed answer should only be used for persistence, not re-emitted to the client."
+    )
+
+
+def test_stream_error_uses_sse_event():
+    """Verify that error paths use proper SSE event format, not raw text."""
+    import main as main_module
+
+    # Temporarily remove GROQ_API_KEY to trigger the error path
+    original_groq_key = os.environ.get("GROQ_API_KEY")
+    if "GROQ_API_KEY" in os.environ:
+        del os.environ["GROQ_API_KEY"]
+
+    try:
+        # Setup a minimal session
+        session_id = str(_secrets.token_hex(16))
+        session_secret = _secrets.token_urlsafe(32)
+
+        with main_module.sessions_lock:
+            main_module.sessions[session_id] = {
+                "session_secret": session_secret,
+                "hashed_session_secret": main_module._hash_secret(session_secret),
+                "documents": [],
+                "chat": [],
+                "vectorstore": MagicMock(),
+                "lock": threading.Lock(),
+            }
+
+        try:
+            client = TestClient(main_module.app)
+            response = client.post(
+                "/ask/stream",
+                json={
+                    "question": "test",
+                    "session_id": session_id,
+                    "session_secret": session_secret,
+                },
+                headers={"X-Internal-Token": "test-secret"},
+            )
+
+            # Join the raw bytes before decoding: decoding chunk-by-chunk can
+            # split a multi-byte UTF-8 sequence across a chunk boundary.
+            raw_stream = b''.join(response.iter_bytes())
+
+            full_stream = raw_stream.decode('utf-8')
+
+            # Verify error uses SSE event format
+            assert 'event: error' in full_stream, "Error should use SSE event format"
+
+            # Verify no raw text yields (all data should be prefixed with 'data:')
+            lines = full_stream.split('\n')
+            for line in lines:
+                if line and not line.startswith(('event:', 'data:')):
+                    # Only SSE comment lines (leading ':') may lack a field prefix
+                    if not line.startswith(':'):
+                        raise AssertionError(f"Found raw text without SSE prefix: {line}")
+
+            # Verify exactly one [DONE] event after error
+            done_count = full_stream.count('[DONE]')
+            assert done_count == 1, f"[DONE] appeared {done_count} times, expected 1"
+
+        finally:
+            with main_module.sessions_lock:
+                main_module.sessions.pop(session_id, None)
+    finally:
+        # Restore original GROQ_API_KEY (an empty-string value must be restored too)
+        if original_groq_key is not None:
+            os.environ["GROQ_API_KEY"] = original_groq_key
+
+
+def test_stream_single_completion_event():
+    """Verify that exactly one [DONE] event is emitted per stream."""
+    import main as main_module
+    from unittest.mock import patch, MagicMock
+
+    # Mock the Groq API; urlopen() is used as a context manager, so the lines must come from __enter__()
+    mock_response = MagicMock()
+    mock_response.__enter__.return_value.__iter__.return_value = [
+        b'data: {"choices":[{"delta":{"content":"Test"}}]}\n\n',
+        b'data: [DONE]\n\n',
+    ]
+
+    with patch('urllib.request.urlopen', return_value=mock_response), patch.dict(os.environ, {"GROQ_API_KEY": "test-key"}):
+        with patch.object(main_module, 'apply_mode_framing', return_value='Test'):
+            session_id = str(_secrets.token_hex(16))
+            session_secret = _secrets.token_urlsafe(32)
+
+            with main_module.sessions_lock:
+                main_module.sessions[session_id] = {
+                    "session_secret": session_secret,
+                    "hashed_session_secret": main_module._hash_secret(session_secret),
+                    "documents": [],
+                    "chat": [],
+                    "vectorstore": MagicMock(),
+                    "lock": threading.Lock(),
+                }
+
+            try:
+                client = TestClient(main_module.app)
+                response = client.post(
+                    "/ask/stream",
+                    json={
+                        "question": "test",
+                        "session_id": session_id,
+                        "session_secret": session_secret,
+                    },
+                    headers={"X-Internal-Token": "test-secret"},
+                )
+
+                # Join the raw bytes before decoding: decoding chunk-by-chunk can
+                # split a multi-byte UTF-8 sequence across a chunk boundary.
+                raw_stream = b''.join(response.iter_bytes())
+
+                full_stream = raw_stream.decode('utf-8')
+
+                # Count [DONE] events
+                done_count = full_stream.count('[DONE]')
+                assert done_count == 1, f"Expected exactly 1 [DONE] event, got {done_count}"
+
+            finally:
+                with main_module.sessions_lock:
+                    main_module.sessions.pop(session_id, None)
+
+
+def test_stream_frontend_parser_compatibility():
+    """Verify that stream output is compatible with frontend SSE parser.
+
+    This test ensures the output format matches what ragService.js expects:
+    - event: message for regular content
+    - event: error for errors
+    - data: [DONE] for completion
+    """
+    import main as main_module
+    from unittest.mock import patch, MagicMock
+
+    # Mock the Groq API; urlopen() is used as a context manager, so the lines must come from __enter__()
+    mock_response = MagicMock()
+    mock_response.__enter__.return_value.__iter__.return_value = [
+        b'data: {"choices":[{"delta":{"content":"Response"}}]}\n\n',
+        b'data: [DONE]\n\n',
+    ]
+
+    with patch('urllib.request.urlopen', return_value=mock_response), patch.dict(os.environ, {"GROQ_API_KEY": "test-key"}):
+        with patch.object(main_module, 'apply_mode_framing', return_value='Response'):
+            session_id = str(_secrets.token_hex(16))
+            session_secret = _secrets.token_urlsafe(32)
+
+            with main_module.sessions_lock:
+                main_module.sessions[session_id] = {
+                    "session_secret": session_secret,
+                    "hashed_session_secret": main_module._hash_secret(session_secret),
+                    "documents": [],
+                    "chat": [],
+                    "vectorstore": MagicMock(),
+                    "lock": threading.Lock(),
+                }
+
+            try:
+                client = TestClient(main_module.app)
+                response = client.post(
+                    "/ask/stream",
+                    json={
+                        "question": "test",
+                        "session_id": session_id,
+                        "session_secret": session_secret,
+                    },
+                    headers={"X-Internal-Token": "test-secret"},
+                )
+
+                # Join the raw bytes before decoding: decoding chunk-by-chunk can
+                # split a multi-byte UTF-8 sequence across a chunk boundary.
+                raw_stream = b''.join(response.iter_bytes())
+
+                full_stream = raw_stream.decode('utf-8')
+
+                # Parse the stream similar to ragService.js
+                events = full_stream.split('\n\n')
+                parsed_events = []
+
+                for event_text in events:
+                    if not event_text.strip():
+                        continue
+
+                    event_name = 'message'
+                    data_lines = []
+
+                    for line in event_text.split('\n'):
+                        line = line.strip()
+                        if line.startswith('event:'):
+                            event_name = line[6:].strip()
+                        elif line.startswith('data:'):
+                            data_lines.append(line[5:].strip())
+
+                    if data_lines:
+                        parsed_events.append({
+                            'event': event_name,
+                            'data': '\n'.join(data_lines)
+                        })
+
+                # Verify we got the expected events
+                assert len(parsed_events) >= 2, "Should have at least 2 events (content + done)"
+
+                # Find the done event
+                done_events = [e for e in parsed_events if e['data'] == '[DONE]']
+                assert len(done_events) == 1, "Should have exactly one [DONE] event"
+
+                # Verify content events use proper format
+                content_events = [e for e in parsed_events if e['data'] != '[DONE]']
+                for event in content_events:
+                    assert event['event'] in ['message', 'error'], f"Unexpected event type: {event['event']}"
+
+            finally:
+                with main_module.sessions_lock:
+                    main_module.sessions.pop(session_id, None)