diff --git a/autotest/interface/restful/test_restful_abort_request.py b/autotest/interface/restful/test_restful_abort_request.py
new file mode 100644
index 0000000000..a79d3ae743
--- /dev/null
+++ b/autotest/interface/restful/test_restful_abort_request.py
@@ -0,0 +1,643 @@
+import json
+import os
+import random
+import re
+import threading
+import time
+from collections.abc import Callable
+from datetime import datetime
+
+import allure
+import pytest
+import requests
+from utils.constant import BACKEND_LIST, DEFAULT_PORT, DEFAULT_SERVER, RESTFUL_MODEL_LIST
+from utils.restful_return_check import assert_chat_completions_batch_return
+
+from lmdeploy.serve.openai.api_client import APIClient
+
+BASE_URL = f'http://{DEFAULT_SERVER}:{DEFAULT_PORT}'
+JSON_HEADERS = {'Content-Type': 'application/json'}
+_REQUEST_TIMEOUT = 300
+_ABORT_TIMEOUT = 60
+_SESSION_RETRY = 25
+_SESSION_RETRY_INTERVAL = 0.3
+_NONSTREAM_ABORT_LEAD_S = 2.0
+_THREAD_JOIN_EXTRA_S = 30
+_POST_ABORT_LOGPROBS_NUM = 10
+_MAX_LOG_TEXT = 2000
+_LOG_HOOK: Callable[[dict], None] | None = None
+
+
+def _set_log_hook(hook: Callable[[dict], None] | None) -> None:
+    global _LOG_HOOK
+    _LOG_HOOK = hook
+
+
+def _truncate_text(text: str, limit: int = _MAX_LOG_TEXT) -> str:
+    if len(text) <= limit:
+        return text
+    return text[:limit] + '...'
+
+
+def _response_snapshot(resp: requests.Response) -> dict:
+    snap = {'status_code': resp.status_code}
+    try:
+        snap['json'] = resp.json()
+    except Exception:
+        snap['text'] = _truncate_text(resp.text or '')
+    return snap
+
+
+def _emit_log(event: str, **kwargs) -> None:
+    if _LOG_HOOK is None:
+        return
+    payload = {'timestamp': datetime.now().isoformat(), 'event': event, **kwargs}
+    try:
+        _LOG_HOOK(payload)
+    except Exception:
+        # Logging should never break test assertions.
+        pass
+
+
+def _post_abort_request(payload: dict) -> requests.Response:
+    resp = requests.post(
+        f'{BASE_URL}/abort_request',
+        headers=JSON_HEADERS,
+        json=payload,
+        timeout=_ABORT_TIMEOUT,
+    )
+    _emit_log('post_abort_request', request={'payload': payload}, response=_response_snapshot(resp))
+    return resp
+
+
+def _chat_non_stream(
+    model_name: str,
+    session_id: int,
+    *,
+    max_tokens: int = 32,
+    logprobs: bool = False,
+    top_logprobs: int = _POST_ABORT_LOGPROBS_NUM,
+) -> requests.Response:
+    body: dict = {
+        'model': model_name,
+        'messages': [{'role': 'user', 'content': 'Say OK in one word.'}],
+        'max_tokens': max_tokens,
+        'temperature': 0.01,
+        'stream': False,
+        'session_id': session_id,
+    }
+    if logprobs:
+        body['logprobs'] = True
+        body['top_logprobs'] = top_logprobs
+    resp = requests.post(
+        f'{BASE_URL}/v1/chat/completions',
+        headers=JSON_HEADERS,
+        json=body,
+        timeout=_REQUEST_TIMEOUT,
+    )
+    _emit_log('chat_non_stream', request={'payload': body}, response=_response_snapshot(resp))
+    return resp
+
+
+def _consume_first_nonempty_sse_data_line(resp: requests.Response) -> None:
+    idx = 0
+    for raw in resp.iter_lines(decode_unicode=True):
+        if not raw or not raw.startswith('data:'):
+            continue
+        chunk = raw[5:].strip()
+        if chunk == '[DONE]':
+            _emit_log('stream_chunk_done_before_abort')
+            break
+        if not chunk:
+            continue
+        try:
+            parsed = json.loads(chunk)
+        except json.JSONDecodeError:
+            _emit_log('stream_chunk_parse_error', raw_preview=_truncate_text(chunk))
+            continue
+        idx += 1
+        _emit_log('stream_chunk_before_abort',
+                  chunk_index=idx,
+                  chunk_preview=_truncate_text(chunk),
+                  finish_reason=(parsed.get('choices') or [{}])[0].get('finish_reason')
+                  if isinstance(parsed, dict) else None,
+                  meta_finish_reason=(parsed.get('meta_info') or {}).get('finish_reason')
+                  if isinstance(parsed, dict) else None)
+        return
+    assert False, 'expected at least one parsable SSE data line before abort'
+
+
+def _post_abort_explicit_session_or_skip(session_id: int) -> None:
+    abort_r = _post_abort_request({'session_id': session_id, 'abort_all': False})
+    if abort_r.status_code == 501:
+        pytest.skip('api_server started without --enable-abort-handling')
+    assert abort_r.status_code == 200, f'abort_request: {abort_r.status_code} {abort_r.text!r}'
+
+
+def _post_abort_nonexistent_session(session_id: int) -> requests.Response:
+    return _post_abort_request({'session_id': session_id, 'abort_all': False})
+
+
+def _post_abort_all_or_skip() -> None:
+    abort_r = _post_abort_request({'abort_all': True})
+    if abort_r.status_code == 501:
+        pytest.skip('api_server started without --enable-abort-handling')
+    assert abort_r.status_code == 200, f'abort_request abort_all: {abort_r.status_code} {abort_r.text!r}'
+
+
+def _assert_session_reusable_after_abort(model_name: str, session_id: int) -> None:
+    last = None
+    for _ in range(_SESSION_RETRY):
+        last = _chat_non_stream(
+            model_name,
+            session_id,
+            max_tokens=16,
+            logprobs=True,
+            top_logprobs=_POST_ABORT_LOGPROBS_NUM,
+        )
+        if last.status_code == 200:
+            data = last.json()
+            assert 'choices' in data and data['choices'], last.text
+            assert_chat_completions_batch_return(
+                data,
+                model_name,
+                check_logprobs=True,
+                logprobs_num=_POST_ABORT_LOGPROBS_NUM,
+            )
+            return
+        if last.status_code == 400 and 'occupied' in (last.text or '').lower():
+            time.sleep(_SESSION_RETRY_INTERVAL)
+            continue
+        break
+    assert False, f'session {session_id} not reusable after abort: last={last.status_code} {last.text!r}'
+
+
+def _long_user_prompt() -> str:
+    return 'Write a long 
numbered list from 1 to 500, one number per line, no other text.' + + +def _finish_reason_indicates_abort(finish_reason) -> bool: + """LMDeploy may use OpenAI-style ``'abort'`` or nested ``{'type': + + 'abort'}``. + """ + if finish_reason == 'abort': + return True + if isinstance(finish_reason, dict) and finish_reason.get('type') == 'abort': + return True + return False + + +def _sse_chunk_indicates_abort(chunk: dict) -> bool: + choices = chunk.get('choices') or [] + if choices: + if _finish_reason_indicates_abort(choices[0].get('finish_reason')): + return True + meta = chunk.get('meta_info') or {} + return _finish_reason_indicates_abort(meta.get('finish_reason')) + + +def _verify_stream_abort_finish_reason(resp: requests.Response) -> None: + found_abort = False + chunk_idx = 0 + chunk_summaries: list[dict] = [] + for raw in resp.iter_lines(decode_unicode=True): + if not raw or not raw.startswith('data:'): + continue + chunk_str = raw[5:].strip() + if chunk_str == '[DONE]': + _emit_log('stream_chunk_done_after_abort', seen_chunks=chunk_idx) + break + if chunk_str: + try: + chunk = json.loads(chunk_str) + except json.JSONDecodeError: + _emit_log('stream_chunk_parse_error_after_abort', raw_preview=_truncate_text(chunk_str)) + continue + chunk_idx += 1 + choice_fr = None + choices = chunk.get('choices') or [] + if choices and isinstance(choices[0], dict): + choice_fr = choices[0].get('finish_reason') + meta_fr = (chunk.get('meta_info') or {}).get('finish_reason') + summary = { + 'idx': chunk_idx, + 'choice_finish_reason': choice_fr, + 'meta_finish_reason': meta_fr, + 'preview': _truncate_text(chunk_str), + } + chunk_summaries.append(summary) + _emit_log('stream_chunk_after_abort', **summary) + if _sse_chunk_indicates_abort(chunk): + found_abort = True + break + if not found_abort: + _emit_log('stream_abort_not_found', + seen_chunks=chunk_idx, + chunk_summaries=chunk_summaries[-10:]) + assert found_abort, "Expected finish_reason 'abort' in stream response" + + +def _verify_non_stream_abort_finish_reason(resp: requests.Response) -> None: + data = resp.json() + if 'choices' in data and data['choices']: + finish_reason = data['choices'][0].get('finish_reason') + assert _finish_reason_indicates_abort(finish_reason), ( + f'Expected abort finish_reason, got {finish_reason!r}') + return + # Legacy ``/generate`` body: ``text`` + ``meta_info.finish_reason`` + meta = data.get('meta_info') or {} + fr = meta.get('finish_reason') + assert _finish_reason_indicates_abort(fr), ( + f'Expected abort in meta_info.finish_reason, got {fr!r}; keys={list(data)!r}') + + +def _send_nonstream_request_with_abort(model_name: str, session_id: int, endpoint: str) -> requests.Response: + payload: dict + if endpoint == '/v1/chat/completions': + payload = { + 'model': model_name, + 'messages': [{'role': 'user', 'content': _long_user_prompt()}], + 'max_tokens': 2048, + 'temperature': 0.3, + 'stream': False, + 'session_id': session_id, + } + elif endpoint == '/generate': + payload = { + 'prompt': _long_user_prompt(), + 'max_tokens': 2048, + 'temperature': 0.3, + 'stream': False, + 'session_id': session_id, + } + else: + payload = { + 'model': model_name, + 'prompt': _long_user_prompt(), + 'max_tokens': 2048, + 'temperature': 0.3, + 'stream': False, + 'session_id': session_id, + } + resp = requests.post( + f'{BASE_URL}{endpoint}', + headers=JSON_HEADERS, + json=payload, + timeout=_REQUEST_TIMEOUT, + ) + _emit_log('send_nonstream_request_with_abort', + request={'endpoint': endpoint, 'payload': payload}, + 
response=_response_snapshot(resp)) + return resp + + +@pytest.mark.order(9) +@pytest.mark.flaky(reruns=2) +@pytest.mark.parametrize('backend', BACKEND_LIST) +@pytest.mark.parametrize('model_case', RESTFUL_MODEL_LIST) +class TestRestfulAbortRequest: + + @pytest.fixture(autouse=True) + def setup_abort_log(self, request, config, backend, model_case): + test_name = re.sub(r'[^\w\.-]', '_', request.node.name) + model_name = str(model_case).replace('/', '_') + log_base = config.get('log_path', './logs') + log_dir = os.path.join(log_base, model_name) + os.makedirs(log_dir, exist_ok=True) + timestamp = time.strftime('%Y%m%d_%H%M%S') + self.log_file = os.path.join(log_dir, f'restful_abort_{backend}_{test_name}_{timestamp}.log') + + def _writer(entry: dict) -> None: + with open(self.log_file, 'a', encoding='utf-8') as f: + f.write(json.dumps(entry, ensure_ascii=False) + '\n') + + _set_log_hook(_writer) + _emit_log('test_start', test=test_name, backend=backend, model_case=model_case, base_url=BASE_URL) + yield + _emit_log('test_end', test=test_name) + _set_log_hook(None) + if os.path.isfile(self.log_file): + allure.attach.file( + self.log_file, + name=os.path.basename(self.log_file), + attachment_type=allure.attachment_type.TEXT, + ) + + def test_abort_running_stream_chat_request_returns_abort_finish_reason(self, backend, model_case): + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + session_id = 8_000_000 + random.randint(0, 99_999) + + stream_payload = { + 'model': model_name, + 'messages': [{'role': 'user', 'content': _long_user_prompt()}], + 'max_tokens': 2048, + 'temperature': 0.3, + 'stream': True, + 'session_id': session_id, + } + resp = requests.post( + f'{BASE_URL}/v1/chat/completions', + headers=JSON_HEADERS, + json=stream_payload, + stream=True, + timeout=_REQUEST_TIMEOUT, + ) + resp.raise_for_status() + + try: + _consume_first_nonempty_sse_data_line(resp) + _post_abort_explicit_session_or_skip(session_id) + _verify_stream_abort_finish_reason(resp) + finally: + resp.close() + + _assert_session_reusable_after_abort(model_name, session_id) + + def test_abort_running_stream_generate_request_returns_abort_finish_reason(self, backend, model_case): + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + session_id = 7_000_000 + random.randint(0, 99_999) + + stream_payload = { + 'prompt': _long_user_prompt(), + 'max_tokens': 2048, + 'temperature': 0.3, + 'stream': True, + 'session_id': session_id, + } + resp = requests.post( + f'{BASE_URL}/generate', + headers=JSON_HEADERS, + json=stream_payload, + stream=True, + timeout=_REQUEST_TIMEOUT, + ) + resp.raise_for_status() + + try: + _consume_first_nonempty_sse_data_line(resp) + _post_abort_explicit_session_or_skip(session_id) + _verify_stream_abort_finish_reason(resp) + finally: + resp.close() + + _assert_session_reusable_after_abort(model_name, session_id) + + def test_abort_running_stream_completions_request_returns_abort_finish_reason(self, backend, model_case): + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + session_id = 6_000_000 + random.randint(0, 99_999) + + stream_payload = { + 'model': model_name, + 'prompt': _long_user_prompt(), + 'max_tokens': 2048, + 'temperature': 0.3, + 'stream': True, + 'session_id': session_id, + } + resp = requests.post( + f'{BASE_URL}/v1/completions', + headers=JSON_HEADERS, + json=stream_payload, + stream=True, + timeout=_REQUEST_TIMEOUT, + ) + resp.raise_for_status() + + try: + 
_consume_first_nonempty_sse_data_line(resp) + _post_abort_explicit_session_or_skip(session_id) + _verify_stream_abort_finish_reason(resp) + finally: + resp.close() + + _assert_session_reusable_after_abort(model_name, session_id) + + def test_abort_running_non_stream_chat_request_returns_abort_finish_reason(self, backend, model_case): + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + session_id = 5_000_000 + random.randint(0, 99_999) + + results = {'resp': None, 'exc': None, 'completed': False} + + def worker(out: dict, sid: int) -> None: + try: + out['resp'] = _send_nonstream_request_with_abort(model_name, sid, '/v1/chat/completions') + out['completed'] = True + except Exception as e: + out['exc'] = e + + thread = threading.Thread(target=worker, args=(results, session_id), daemon=True) + thread.start() + + time.sleep(_NONSTREAM_ABORT_LEAD_S) + + abort_r = _post_abort_request({'session_id': session_id, 'abort_all': False}) + if abort_r.status_code == 501: + thread.join(timeout=_REQUEST_TIMEOUT + _THREAD_JOIN_EXTRA_S) + pytest.skip('api_server started without --enable-abort-handling') + + assert abort_r.status_code == 200 + thread.join(timeout=_REQUEST_TIMEOUT + _THREAD_JOIN_EXTRA_S) + + assert not thread.is_alive() + assert results['resp'] is not None, 'Request should complete even after abort' + _verify_non_stream_abort_finish_reason(results['resp']) + + _assert_session_reusable_after_abort(model_name, session_id) + + def test_abort_running_non_stream_generate_request_returns_abort_finish_reason(self, backend, model_case): + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + session_id = 5_000_001 + random.randint(0, 99_999) + + results = {'resp': None, 'exc': None, 'completed': False} + + def worker(out: dict, sid: int) -> None: + try: + out['resp'] = _send_nonstream_request_with_abort(model_name, sid, '/generate') + out['completed'] = True + except Exception as e: + out['exc'] = e + + thread = threading.Thread(target=worker, args=(results, session_id), daemon=True) + thread.start() + + time.sleep(_NONSTREAM_ABORT_LEAD_S) + + abort_r = _post_abort_request({'session_id': session_id, 'abort_all': False}) + if abort_r.status_code == 501: + thread.join(timeout=_REQUEST_TIMEOUT + _THREAD_JOIN_EXTRA_S) + pytest.skip('api_server started without --enable-abort-handling') + + assert abort_r.status_code == 200 + thread.join(timeout=_REQUEST_TIMEOUT + _THREAD_JOIN_EXTRA_S) + + assert not thread.is_alive() + assert results['resp'] is not None + _verify_non_stream_abort_finish_reason(results['resp']) + + _assert_session_reusable_after_abort(model_name, session_id) + + def test_abort_running_non_stream_completions_request_returns_abort_finish_reason(self, backend, model_case): + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + session_id = 5_000_002 + random.randint(0, 99_999) + + results = {'resp': None, 'exc': None, 'completed': False} + + def worker(out: dict, sid: int) -> None: + try: + out['resp'] = _send_nonstream_request_with_abort(model_name, sid, '/v1/completions') + out['completed'] = True + except Exception as e: + out['exc'] = e + + thread = threading.Thread(target=worker, args=(results, session_id), daemon=True) + thread.start() + + time.sleep(_NONSTREAM_ABORT_LEAD_S) + + abort_r = _post_abort_request({'session_id': session_id, 'abort_all': False}) + if abort_r.status_code == 501: + thread.join(timeout=_REQUEST_TIMEOUT + _THREAD_JOIN_EXTRA_S) + pytest.skip('api_server started without 
--enable-abort-handling') + + assert abort_r.status_code == 200 + thread.join(timeout=_REQUEST_TIMEOUT + _THREAD_JOIN_EXTRA_S) + + assert not thread.is_alive() + assert results['resp'] is not None + _verify_non_stream_abort_finish_reason(results['resp']) + + _assert_session_reusable_after_abort(model_name, session_id) + + def test_abort_nonexistent_session_returns_bad_request(self, backend, model_case): + nonexistent_session_id = 999_999_999 + + abort_r = _post_abort_nonexistent_session(nonexistent_session_id) + + if abort_r.status_code == 501: + pytest.skip('api_server started without --enable-abort-handling') + + assert abort_r.status_code == 400 + error_data = abort_r.json() + assert 'error' in error_data or 'message' in error_data + + def test_abort_invalid_session_id_format_returns_bad_request(self, backend, model_case): + invalid_session_ids = [-1, 'invalid', None, 3.14] + + for invalid_id in invalid_session_ids: + abort_r = _post_abort_request({'session_id': invalid_id, 'abort_all': False}) + + if abort_r.status_code == 501: + pytest.skip('api_server started without --enable-abort-handling') + + assert abort_r.status_code in (400, 422), ( + f'expected 400 or 422 for invalid session_id, got {abort_r.status_code}') + + def test_abort_all_terminates_multiple_requests_with_abort_finish_reason(self, backend, model_case): + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + + sessions = [5_000_000 + random.randint(0, 99_999) for _ in range(3)] + responses = [] + + for session_id in sessions: + stream_payload = { + 'model': model_name, + 'messages': [{'role': 'user', 'content': _long_user_prompt()}], + 'max_tokens': 2048, + 'temperature': 0.3, + 'stream': True, + 'session_id': session_id, + } + resp = requests.post( + f'{BASE_URL}/v1/chat/completions', + headers=JSON_HEADERS, + json=stream_payload, + stream=True, + timeout=_REQUEST_TIMEOUT, + ) + resp.raise_for_status() + responses.append(resp) + _consume_first_nonempty_sse_data_line(resp) + + _post_abort_all_or_skip() + + for i, resp in enumerate(responses): + try: + _verify_stream_abort_finish_reason(resp) + finally: + resp.close() + + for session_id in sessions: + _assert_session_reusable_after_abort(model_name, session_id) + + def test_abort_all_terminates_non_stream_requests_with_abort_finish_reason(self, backend, model_case): + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + sessions = [3_000_000 + random.randint(0, 99_999) for _ in range(2)] + results_list = [] + + for session_id in sessions: + results = {'resp': None, 'exc': None, 'completed': False} + results_list.append(results) + + def worker(out: dict, sid: int): + try: + out['resp'] = _send_nonstream_request_with_abort(model_name, sid, '/v1/chat/completions') + out['completed'] = True + except Exception as e: + out['exc'] = e + + thread = threading.Thread(target=worker, args=(results, session_id), daemon=True) + thread.start() + results['thread'] = thread + + time.sleep(_NONSTREAM_ABORT_LEAD_S) + + _post_abort_all_or_skip() + + for results in results_list: + results['thread'].join(timeout=_REQUEST_TIMEOUT + _THREAD_JOIN_EXTRA_S) + assert not results['thread'].is_alive() + assert results['resp'] is not None + _verify_non_stream_abort_finish_reason(results['resp']) + + for session_id in sessions: + _assert_session_reusable_after_abort(model_name, session_id) + + def test_session_immediately_reusable_after_abort(self, backend, model_case): + api_client = APIClient(BASE_URL) + model_name = 
api_client.available_models[0] + session_id = 4_000_000 + random.randint(0, 99_999) + + stream_payload = { + 'model': model_name, + 'messages': [{'role': 'user', 'content': _long_user_prompt()}], + 'max_tokens': 2048, + 'temperature': 0.3, + 'stream': True, + 'session_id': session_id, + } + resp = requests.post( + f'{BASE_URL}/v1/chat/completions', + headers=JSON_HEADERS, + json=stream_payload, + stream=True, + timeout=_REQUEST_TIMEOUT, + ) + resp.raise_for_status() + + try: + _consume_first_nonempty_sse_data_line(resp) + _post_abort_explicit_session_or_skip(session_id) + finally: + resp.close() + + new_resp = _chat_non_stream(model_name, session_id, max_tokens=16) + assert new_resp.status_code == 200, f'Session should be immediately reusable, got {new_resp.status_code}' diff --git a/autotest/interface/restful/test_restful_sleep_wakeup.py b/autotest/interface/restful/test_restful_sleep_wakeup.py new file mode 100644 index 0000000000..0f417a603d --- /dev/null +++ b/autotest/interface/restful/test_restful_sleep_wakeup.py @@ -0,0 +1,432 @@ +import time +from pathlib import Path + +import pytest +import requests +import torch +from utils.constant import ( + DEFAULT_PORT, + DEFAULT_SERVER, + SLEEP_WAKEUP_BACKENDS, + SLEEP_WAKEUP_MODEL_LIST, +) +from utils.restful_return_check import assert_chat_completions_batch_return +from utils.sleep_utils import ( + LEVEL2_BASELINE_RUNS, + LEVEL2_GREEDY_MESSAGES, + LEVEL2_MAX_TOKENS, + apply_serialized_hf_segments_for_level2_weights, + apply_serialized_hf_segments_for_turbomind_level2_weights, + assert_assistant_not_degenerate, + assert_chat_decode_unchanged, + assistant_content_from_openai_completion_dict, + level2_update_weights_request_dict, + resolve_hf_checkpoint_dir, +) + +from lmdeploy.serve.openai.api_client import APIClient + +BASE_URL = f'http://{DEFAULT_SERVER}:{DEFAULT_PORT}' +JSON_HEADERS = {'Content-Type': 'application/json'} +_REQUEST_TIMEOUT = 120 +_UPDATE_WEIGHTS_TIMEOUT = 600 + + +def _assert_status_200(resp: requests.Response) -> None: + assert resp.status_code == 200, f'status={resp.status_code} body={resp.text!r}' + + +def _post_sleep(*, level: int | None = None) -> requests.Response: + url = f'{BASE_URL}/sleep' + if level is not None: + url = f'{url}?level={level}' + return requests.post(url, headers=JSON_HEADERS, json={}, timeout=_REQUEST_TIMEOUT) + + +def _post_sleep_level2() -> requests.Response: + return requests.post( + f'{BASE_URL}/sleep', + headers=JSON_HEADERS, + json={}, + params=[('tags', 'weights'), ('tags', 'kv_cache'), ('level', 2)], + timeout=_REQUEST_TIMEOUT, + ) + + +def _post_sleep_query_raw(query: str) -> requests.Response: + q = query.lstrip('?') + url = f'{BASE_URL}/sleep?{q}' if q else f'{BASE_URL}/sleep' + return requests.post(url, headers=JSON_HEADERS, json={}, timeout=_REQUEST_TIMEOUT) + + +def _post_wakeup(*, tags: list[str] | None = None) -> requests.Response: + params = [('tags', t) for t in tags] if tags else None + return requests.post( + f'{BASE_URL}/wakeup', + headers=JSON_HEADERS, + json={}, + params=params, + timeout=_REQUEST_TIMEOUT, + ) + + +def _post_update_weights_from_hf_dir(model_dir: Path, *, engine: str) -> None: + def _emit(serialized_data: object, finished: bool) -> None: + data = level2_update_weights_request_dict(serialized_data, finished) + r = requests.post( + f'{BASE_URL}/update_weights', + headers=JSON_HEADERS, + json=data, + timeout=_UPDATE_WEIGHTS_TIMEOUT, + ) + _assert_status_200(r) + + if engine == 'pytorch': + apply_serialized_hf_segments_for_level2_weights(model_dir, _emit) + elif 
engine == 'turbomind': + apply_serialized_hf_segments_for_turbomind_level2_weights(model_dir, _emit) + else: + pytest.skip(f'unsupported engine for update_weights: {engine!r}') + + +def _level2_reload_hf_weights(backend: str, config: dict, model_case: str) -> None: + if not torch.cuda.is_available(): + pytest.skip('level-2 reload needs CUDA for serialize_state_dict / weight upload') + model_dir = resolve_hf_checkpoint_dir(config, model_case) + if not model_dir.is_dir(): + pytest.skip(f'HF checkpoint not found for update_weights: {model_dir}') + try: + _post_update_weights_from_hf_dir(model_dir, engine=backend) + except FileNotFoundError as e: + pytest.skip(str(e)) + except RuntimeError as e: + pytest.skip(str(e)) + + +def _fetch_is_sleeping() -> bool: + r = requests.get(f'{BASE_URL}/is_sleeping', timeout=30) + _assert_status_200(r) + return bool(r.json().get('is_sleeping')) + + +def _ensure_awake(max_attempts: int = 8) -> None: + for _ in range(max_attempts): + _assert_status_200(_post_wakeup()) + if not _fetch_is_sleeping(): + return + time.sleep(0.25) + raise AssertionError( + f'engine still is_sleeping=true after {max_attempts} POST /wakeup attempts; ' + f'BASE_URL={BASE_URL!r}') + + +def _chat_completion_collect(api_client: APIClient, model_name: str, **kwargs) -> dict: + kw = dict(kwargs) + kw['stream'] = False + output = None + for output in api_client.chat_completions_v1(model=model_name, **kw): + continue + assert output is not None, 'chat_completions_v1 returned no chunk' + return output + + +def _assert_level2_greedy_baseline_stable(api_client: APIClient, model_name: str, *, label: str) -> dict: + kwargs = dict( + messages=LEVEL2_GREEDY_MESSAGES, + max_tokens=LEVEL2_MAX_TOKENS, + temperature=0.0, + top_p=1.0, + top_k=1, + ) + refs: list[dict] = [] + contents: list[str] = [] + for i in range(LEVEL2_BASELINE_RUNS): + out = _chat_completion_collect(api_client, model_name, **kwargs) + assert_chat_completions_batch_return(out, model_name) + text = assistant_content_from_openai_completion_dict(out) + assert_assistant_not_degenerate(text, label=f'{label} baseline run {i + 1}') + refs.append(out) + contents.append(text) + assert len(set(contents)) == 1, ( + f'{label}: greedy REST baseline not stable (fix prompt/model for this case):\n' + + '\n'.join(f' run{j + 1}={c!r}' for j, c in enumerate(contents))) + return refs[0] + + +def _should_enforce_level2_greedy_checks(backend: str) -> bool: + # Known issue: TurboMind may produce non-stable outputs even in + # temperature=0 greedy-style requests. Keep the staged wakeup / reload + # flow coverage, but skip strict determinism assertions for this backend. 
+ return backend != 'turbomind' + + +@pytest.mark.order(8) +@pytest.mark.flaky(reruns=2) +@pytest.mark.parametrize('backend', SLEEP_WAKEUP_BACKENDS) +@pytest.mark.parametrize('model_case', SLEEP_WAKEUP_MODEL_LIST) +class TestRestfulSleepWakeup: + + def test_sleep_wakeup_is_sleeping_roundtrip(self, model_case, backend): + try: + _ensure_awake() + r_sleep = _post_sleep() + _assert_status_200(r_sleep) + + assert _fetch_is_sleeping() is True + + r_wake = _post_wakeup() + _assert_status_200(r_wake) + + assert _fetch_is_sleeping() is False + finally: + _ensure_awake() + + def test_sleep_with_level_query_wakeup_and_chat(self, model_case, backend): + try: + _ensure_awake() + r_sleep = _post_sleep(level=1) + _assert_status_200(r_sleep) + + assert _fetch_is_sleeping() is True + + r_wake = _post_wakeup() + _assert_status_200(r_wake) + assert _fetch_is_sleeping() is False + + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + output = None + for output in api_client.chat_completions_v1( + model=model_name, + messages=[{'role': 'user', 'content': 'Hi, reply with one short sentence.'}], + max_tokens=32, + temperature=0.01): + continue + assert output is not None + assert_chat_completions_batch_return(output, model_name) + finally: + _ensure_awake() + + def test_sleep_partial_wakeup_with_tags(self, model_case, backend): + try: + _ensure_awake() + r_sleep = _post_sleep(level=1) + _assert_status_200(r_sleep) + assert _fetch_is_sleeping() is True + + r_w = _post_wakeup(tags=['weights']) + _assert_status_200(r_w) + assert _fetch_is_sleeping() is True + + r_kv = _post_wakeup(tags=['kv_cache']) + _assert_status_200(r_kv) + assert _fetch_is_sleeping() is False + finally: + _ensure_awake() + + def test_wakeup_unknown_tags_is_noop_then_full_wakeup(self, model_case, backend): + try: + _ensure_awake() + _assert_status_200(_post_sleep(level=1)) + assert _fetch_is_sleeping() is True + + _assert_status_200(_post_wakeup(tags=['not_a_valid_tag'])) + assert _fetch_is_sleeping() is True + + _assert_status_200(_post_wakeup()) + assert _fetch_is_sleeping() is False + finally: + _ensure_awake() + + def test_wakeup_mixed_valid_and_invalid_tags_entire_call_noop(self, model_case, backend): + try: + _ensure_awake() + _assert_status_200(_post_sleep(level=1)) + assert _fetch_is_sleeping() is True + + _assert_status_200(_post_wakeup(tags=['weights', 'not_a_valid_tag'])) + assert _fetch_is_sleeping() is True + + _assert_status_200(_post_wakeup(tags=['not_a_valid_tag', 'weights'])) + assert _fetch_is_sleeping() is True + + _assert_status_200(_post_wakeup()) + assert _fetch_is_sleeping() is False + finally: + _ensure_awake() + + def test_wakeup_both_valid_tags_in_one_request(self, model_case, backend): + try: + _ensure_awake() + _assert_status_200(_post_sleep(level=1)) + assert _fetch_is_sleeping() is True + + _assert_status_200(_post_wakeup(tags=['weights', 'kv_cache'])) + assert _fetch_is_sleeping() is False + + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + output = None + for output in api_client.chat_completions_v1( + model=model_name, + messages=[{'role': 'user', 'content': 'Hi, reply with one short sentence.'}], + max_tokens=32, + temperature=0.01): + continue + assert output is not None + assert_chat_completions_batch_return(output, model_name) + finally: + _ensure_awake() + + def test_wakeup_redundant_tag_after_partial_wake_is_noop(self, model_case, backend): + try: + _ensure_awake() + _assert_status_200(_post_sleep(level=1)) + assert _fetch_is_sleeping() is True 
+ + _assert_status_200(_post_wakeup(tags=['weights'])) + assert _fetch_is_sleeping() is True + + _assert_status_200(_post_wakeup(tags=['weights'])) + assert _fetch_is_sleeping() is True + + _assert_status_200(_post_wakeup(tags=['kv_cache'])) + assert _fetch_is_sleeping() is False + finally: + _ensure_awake() + + def test_wakeup_empty_string_tag_is_noop_when_sleeping(self, model_case, backend): + try: + _ensure_awake() + _assert_status_200(_post_sleep(level=1)) + assert _fetch_is_sleeping() is True + + r = requests.post( + f'{BASE_URL}/wakeup', + headers=JSON_HEADERS, + json={}, + params=[('tags', '')], + timeout=_REQUEST_TIMEOUT, + ) + _assert_status_200(r) + assert _fetch_is_sleeping() is True + + _assert_status_200(_post_wakeup()) + assert _fetch_is_sleeping() is False + finally: + _ensure_awake() + + def test_full_wakeup_when_already_awake(self, model_case, backend): + try: + _ensure_awake() + assert _fetch_is_sleeping() is False + _assert_status_200(_post_wakeup()) + assert _fetch_is_sleeping() is False + _assert_status_200(_post_wakeup()) + assert _fetch_is_sleeping() is False + finally: + _ensure_awake() + + def test_sleep_second_call_while_sleeping_still_ok(self, model_case, backend): + try: + _ensure_awake() + _assert_status_200(_post_sleep(level=1)) + assert _fetch_is_sleeping() is True + _assert_status_200(_post_sleep(level=1)) + assert _fetch_is_sleeping() is True + _assert_status_200(_post_wakeup()) + assert _fetch_is_sleeping() is False + finally: + _ensure_awake() + + def test_sleep_non_integer_level_is_http_error(self, model_case, backend): + try: + _ensure_awake() + resp = _post_sleep_query_raw('level=not_an_int') + assert resp.status_code != 200, f'expected non-200, got {resp.status_code} body={resp.text!r}' + finally: + _ensure_awake() + + def test_sleep_level_2_full_wakeup_and_chat(self, model_case, backend, config): + try: + _ensure_awake() + api_client = APIClient(BASE_URL) + model_name = api_client.available_models[0] + + baseline = None + if _should_enforce_level2_greedy_checks(backend): + baseline = _assert_level2_greedy_baseline_stable( + api_client, model_name, label='level2 REST') + + _assert_status_200(_post_sleep_level2()) + assert _fetch_is_sleeping() is True + + _assert_status_200(_post_wakeup(tags=['weights'])) + assert _fetch_is_sleeping() is True + _level2_reload_hf_weights(backend, config, model_case) + + _assert_status_200(_post_wakeup(tags=['kv_cache'])) + assert _fetch_is_sleeping() is False + + after = _chat_completion_collect( + api_client, + model_name, + messages=LEVEL2_GREEDY_MESSAGES, + max_tokens=LEVEL2_MAX_TOKENS, + temperature=0.0, + top_p=1.0, + top_k=1, + ) + assert_chat_completions_batch_return(after, model_name) + assert_assistant_not_degenerate( + assistant_content_from_openai_completion_dict(after), + label='level2 REST after staged wakeup (1st chat)') + if baseline is not None: + assert_chat_decode_unchanged(baseline, after, label='level2 REST 1st infer after staged wakeup') + + after2 = _chat_completion_collect( + api_client, + model_name, + messages=LEVEL2_GREEDY_MESSAGES, + max_tokens=LEVEL2_MAX_TOKENS, + temperature=0.0, + top_p=1.0, + top_k=1, + ) + assert_chat_completions_batch_return(after2, model_name) + if baseline is not None: + assert_chat_decode_unchanged(baseline, after2, label='level2 REST 2nd infer after staged wakeup') + + _assert_status_200(_post_sleep_level2()) + assert _fetch_is_sleeping() is True + _assert_status_200(_post_wakeup(tags=['weights'])) + assert _fetch_is_sleeping() is True + 
_level2_reload_hf_weights(backend, config, model_case) + _assert_status_200(_post_wakeup(tags=['kv_cache'])) + assert _fetch_is_sleeping() is False + + after_full = _chat_completion_collect( + api_client, + model_name, + messages=LEVEL2_GREEDY_MESSAGES, + max_tokens=LEVEL2_MAX_TOKENS, + temperature=0.0, + top_p=1.0, + top_k=1, + ) + assert_chat_completions_batch_return(after_full, model_name) + label2 = 'level2 REST infer after 2nd sleep cycle (staged wakeup)' + if baseline is not None: + assert_chat_decode_unchanged(baseline, after_full, label=label2) + + output = None + for output in api_client.chat_completions_v1( + model=model_name, + messages=[{'role': 'user', 'content': 'Hi, reply with one short sentence.'}], + max_tokens=32, + temperature=0.01): + continue + assert output is not None + assert_chat_completions_batch_return(output, model_name) + finally: + _ensure_awake() diff --git a/autotest/utils/constant.py b/autotest/utils/constant.py index bc3ebb0ad5..e8e9087d3b 100644 --- a/autotest/utils/constant.py +++ b/autotest/utils/constant.py @@ -201,6 +201,14 @@ } } +SLEEP_WAKEUP_MODEL_LIST = [ + 'Qwen/Qwen3.5-35B-A3B', + 'Qwen/Qwen3.5-35B-A3B-FP8', + 'Qwen/Qwen3.5-122B-A10B', +] + +SLEEP_WAKEUP_BACKENDS = ['pytorch', 'turbomind'] + BACKEND_LIST = ['turbomind', 'pytorch'] RESTFUL_MODEL_LIST_LATEST = [ diff --git a/autotest/utils/sleep_utils.py b/autotest/utils/sleep_utils.py new file mode 100644 index 0000000000..f289f3d558 --- /dev/null +++ b/autotest/utils/sleep_utils.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import json +import os +from collections import Counter +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import torch +from safetensors.torch import safe_open +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME + +from lmdeploy.utils import serialize_state_dict + +UPDATE_WEIGHTS_CUDA_DEVICE_ENV = 'LMDEPLOY_UPDATE_WEIGHTS_CUDA_DEVICE' + +LEVEL2_GREEDY_MESSAGES = [{'role': 'user', 'content': '424242'}] +LEVEL2_MAX_TOKENS = 64 +LEVEL2_BASELINE_RUNS = 3 +MAX_SINGLE_CHAR_FRACTION = 0.75 + + +def resolve_update_weights_cuda_device_index() -> int: + raw = os.environ.get(UPDATE_WEIGHTS_CUDA_DEVICE_ENV, '').strip() + if not raw: + return torch.cuda.current_device() + try: + idx = int(raw) + except ValueError as e: + raise AssertionError( + f'{UPDATE_WEIGHTS_CUDA_DEVICE_ENV} must be an int, got {raw!r}') from e + n = torch.cuda.device_count() + assert 0 <= idx < n, ( + f'{UPDATE_WEIGHTS_CUDA_DEVICE_ENV}={idx} out of range for cuda.device_count()={n}') + return idx + + +def resolve_hf_checkpoint_dir(config: dict, model_case: str) -> Path: + if os.environ.get('LMDEPLOY_USE_MODELSCOPE', 'False') == 'True': + return Path(model_case) + return Path(config['model_path']) / model_case + + +def shard_paths(model_dir: Path) -> tuple[str, list[Path]]: + if (model_dir / SAFE_WEIGHTS_NAME).is_file(): + return 'safetensors', [model_dir / SAFE_WEIGHTS_NAME] + if (model_dir / SAFE_WEIGHTS_INDEX_NAME).is_file(): + with open(model_dir / SAFE_WEIGHTS_INDEX_NAME, encoding='utf-8') as f: + index = json.load(f) + paths = sorted(set(index['weight_map'].values())) + return 'safetensors', [model_dir / p for p in paths] + if (model_dir / WEIGHTS_NAME).is_file(): + return 'pytorch', [model_dir / WEIGHTS_NAME] + if (model_dir / WEIGHTS_INDEX_NAME).is_file(): + with open(model_dir / WEIGHTS_INDEX_NAME, encoding='utf-8') as f: + index = json.load(f) + paths = sorted(set(index['weight_map'].values())) + 
        return 'pytorch', [model_dir / p for p in paths]
+    raise FileNotFoundError(f'No HF weights under {model_dir}')
+
+
+def load_shard_tensors(kind: str, path: Path) -> dict[str, torch.Tensor]:
+    out: dict[str, torch.Tensor] = {}
+    if kind == 'safetensors':
+        with safe_open(str(path), framework='pt') as f:
+            for key in f.keys():
+                out[key] = f.get_tensor(key)
+    else:
+        state = torch.load(str(path), weights_only=True, map_location='cpu')
+        try:
+            out.update(state)
+        finally:
+            del state
+    return out
+
+
+def assistant_content_from_openai_completion_dict(output: dict) -> str:
+    choices = output.get('choices') or []
+    assert len(choices) == 1, f'expected 1 choice, got {len(choices)}'
+    msg = choices[0].get('message') or {}
+    return (msg.get('content') or '').strip()
+
+
+def assert_assistant_not_degenerate(content: str, *, label: str) -> None:
+    assert content, f'{label}: empty assistant content'
+    compact = content.replace('\n', ' ').strip()
+    assert len(set(compact)) >= 4, (
+        f'{label}: degenerate assistant text (too few distinct chars): {content!r}')
+    top_cnt = Counter(compact).most_common(1)[0][1]
+    assert top_cnt / len(compact) <= MAX_SINGLE_CHAR_FRACTION, (
+        f'{label}: one token/char dominates assistant text: {content!r}')
+
+
+def level2_update_weights_request_dict(serialized_data: object, finished: bool) -> dict[str, Any]:
+    return {
+        'serialized_named_tensors': serialized_data,
+        'finished': finished,
+    }
+
+
+def assert_chat_decode_unchanged(ref: dict, cur: dict, *, label: str) -> None:
+    a, b = assistant_content_from_openai_completion_dict(ref), assistant_content_from_openai_completion_dict(cur)
+    assert a == b, f'{label}: assistant content changed\n before={a!r}\n after={b!r}'
+    rt = ref.get('usage', {}).get('completion_tokens')
+    ct = cur.get('usage', {}).get('completion_tokens')
+    assert rt == ct, f'{label}: completion_tokens changed {rt} -> {ct}'
+    rfr = ref['choices'][0].get('finish_reason')
+    cfr = cur['choices'][0].get('finish_reason')
+    if rfr is not None and cfr is not None:
+        assert rfr == cfr, f'{label}: finish_reason changed {rfr!r} -> {cfr!r}'
+
+
+def apply_serialized_hf_segments_for_level2_weights(
+    model_dir: Path,
+    emit_segment: Callable[[Any, bool], None],
+) -> None:
+    kind, shards = shard_paths(model_dir)
+    num_segment = len(shards)
+    dev_idx = resolve_update_weights_cuda_device_index()
+    device = torch.device('cuda', dev_idx)
+    with torch.cuda.device(dev_idx):
+        for seg_idx in range(num_segment):
+            cpu_dict = load_shard_tensors(kind, shards[seg_idx])
+            seg_gpu = {k: v.to(device, non_blocking=True) for k, v in cpu_dict.items()}
+            del cpu_dict
+            serialized_data = serialize_state_dict(seg_gpu)
+            del seg_gpu
+            torch.cuda.empty_cache()
+            emit_segment(serialized_data, seg_idx == num_segment - 1)
+
+
+def apply_serialized_hf_segments_for_turbomind_level2_weights(
+    model_dir: Path,
+    emit_segment: Callable[[Any, bool], None],
+) -> None:
+    from lmdeploy.turbomind.deploy.converter import get_input_model_registered_name
+    from lmdeploy.turbomind.deploy.source_model.base import INPUT_MODELS
+
+    root = str(model_dir.resolve())
+    try:
+        input_model_name = get_input_model_registered_name(root, 'hf')
+        if input_model_name == 'qwen3_5-moe':
+            raise RuntimeError(
+                'turbomind update_weights is unsupported for qwen3_5-moe in the current server build: '
+                'server-side StateDictLoader has no `index`, but Qwen3_5MoeModel.readers() accesses loader.index')
+        input_model_cls = INPUT_MODELS.get(input_model_name)
+        input_model = input_model_cls(model_path=root, tokenizer_path=root)
+    except Exception as e:
+        raise RuntimeError(
+            f'turbomind update_weights: failed to build input_model readers for {model_dir}: {e}') from e
+
+    dev_idx = resolve_update_weights_cuda_device_index()
+    device = torch.device('cuda', dev_idx)
+    with torch.cuda.device(dev_idx):
+        it = iter(dict(reader.params) for _, reader in input_model.readers())
+        try:
+            chunk = next(it)
+        except StopIteration:
+            raise RuntimeError(f'no turbomind weight chunks to emit under {model_dir}') from None
+
+        for cpu_dict_next in it:
+            seg_gpu = {k: v.to(device, non_blocking=True) for k, v in chunk.items()}
+            try:
+                emit_segment(serialize_state_dict(seg_gpu), False)
+            finally:
+                del seg_gpu
+                torch.cuda.empty_cache()
+            chunk = cpu_dict_next
+
+        seg_gpu = {k: v.to(device, non_blocking=True) for k, v in chunk.items()}
+        try:
+            emit_segment(serialize_state_dict(seg_gpu), True)
+        finally:
+            del seg_gpu
+            torch.cuda.empty_cache()