diff --git a/project/moirai-agent/ctx_forecast/README.md b/project/moirai-agent/ctx_forecast/README.md index 08fef020..d639b9ad 100644 --- a/project/moirai-agent/ctx_forecast/README.md +++ b/project/moirai-agent/ctx_forecast/README.md @@ -6,11 +6,26 @@ First, set up the environment by running: ``` pip install -r requirement.txt ``` -And export your OpenAI key by: + +#### OpenAI (default) +Export your OpenAI key: ``` export OPENAI_API_KEY="..." ``` +#### MiniMax (alternative) +To use [MiniMax](https://www.minimaxi.com/) as the LLM provider, export your MiniMax API key and update the config: +```bash +export MINIMAX_API_KEY="..." +``` +Then set `"provider": "minimax"` and `"model_name": "MiniMax-M2.7"` in your config file (see `src/ctx_forecast/config.py` for a full example). Available models: +| Model | Context | Notes | +|---|---|---| +| `MiniMax-M2.7` | 1M tokens | Latest, recommended | +| `MiniMax-M2.7-highspeed` | 1M tokens | Faster variant | +| `MiniMax-M2.5` | 204K tokens | Previous generation | +| `MiniMax-M2.5-highspeed` | 204K tokens | Fast | + ### Prepare the dataset - Download the `gift_ctx.parquet` data here: `https://huggingface.co/datasets/Salesforce/GIFT-CTX` - Plot the historical data before runtime: diff --git a/project/moirai-agent/ctx_forecast/src/ctx_forecast/config.py b/project/moirai-agent/ctx_forecast/src/ctx_forecast/config.py index 5ad836c8..5c202020 100644 --- a/project/moirai-agent/ctx_forecast/src/ctx_forecast/config.py +++ b/project/moirai-agent/ctx_forecast/src/ctx_forecast/config.py @@ -1,6 +1,14 @@ max_iterations = 6 + +# Default configuration – uses OpenAI Responses API. +# Set "provider" to "minimax" and "model_name" to a MiniMax model +# (e.g. "MiniMax-M2.7") to use MiniMax via the Chat Completions API. CONFIG = { "llm": { + # "provider" selects the LLM backend. + # Supported values: "openai" (default), "minimax", "openai_compatible". + # For "minimax", export MINIMAX_API_KEY and set model_name accordingly. 
+ "provider": "openai", "model_name": "gpt-5.1", "model_params_type": { "temperature": 1.0, @@ -35,3 +43,27 @@ "\n - Only when exact mathematical structures are inferred from the context, you may write and execute codes in the python-sandbox. Always include a print function in your codes to return valid messages. " f"\n - Ensure that all reasoning and tool usage is complete within a maximum of {max_iterations} steps. Each step should either advance your understanding or gather necessary information. Be systematic and thorough in your approach. A final and fully cited answer has to be output before step {max_iterations}. ", } + +# ---- Example: MiniMax configuration ---- +# To use MiniMax instead of OpenAI, export MINIMAX_API_KEY and use: +# +# MINIMAX_CONFIG = { +# "llm": { +# "provider": "minimax", +# "model_name": "MiniMax-M2.7", # or MiniMax-M2.5-highspeed (204K ctx) +# "model_params_type": { +# "temperature": 0.7, +# "top_p": 1.0, +# "max_output_tokens": 8192, +# }, +# }, +# "servers": CONFIG["servers"], +# "max_iterations": max_iterations, +# "system_prompt": CONFIG["system_prompt"], +# } +# +# Available MiniMax models: +# - MiniMax-M2.7 (latest, 1M context) +# - MiniMax-M2.7-highspeed (faster variant) +# - MiniMax-M2.5 (previous gen) +# - MiniMax-M2.5-highspeed (204K context, fast) diff --git a/project/moirai-agent/ctx_forecast/src/ctx_forecast/llm_provider.py b/project/moirai-agent/ctx_forecast/src/ctx_forecast/llm_provider.py new file mode 100644 index 00000000..287d44c4 --- /dev/null +++ b/project/moirai-agent/ctx_forecast/src/ctx_forecast/llm_provider.py @@ -0,0 +1,327 @@ +""" +LLM Provider abstraction for Moirai Agent. + +Supports OpenAI (native Responses API) and OpenAI-compatible providers +(e.g., MiniMax) via Chat Completions API with automatic format translation. 
+""" + +import json +import os +import re +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + + +# --------------------------------------------------------------------------- +# Lightweight response dataclasses – attribute-compatible with the objects +# returned by OpenAI's Responses API so that existing MCPClient code works +# unchanged. +# --------------------------------------------------------------------------- + + +@dataclass +class OutputMessage: + """Text message output item.""" + + type: str = "message" + content: str = "" + + +@dataclass +class OutputFunctionCall: + """Function/tool call output item.""" + + type: str = "function_call" + call_id: str = "" + name: str = "" + arguments: str = "" + + +@dataclass +class OutputReasoning: + """Reasoning trace output item (only populated by OpenAI reasoning models).""" + + type: str = "reasoning" + summary: str = "" + + +@dataclass +class LLMResponse: + """Unified response envelope – drop-in replacement for ``openai.Response``.""" + + output: list = field(default_factory=list) + output_text: str = "" + + +# --------------------------------------------------------------------------- +# Provider presets +# --------------------------------------------------------------------------- + +PROVIDER_PRESETS: Dict[str, Dict[str, Any]] = { + "openai": { + "env_key": "OPENAI_API_KEY", + "base_url": None, + }, + "minimax": { + "env_key": "MINIMAX_API_KEY", + "base_url": "https://api.minimax.io/v1", + "default_model": "MiniMax-M2.7", + "models": [ + "MiniMax-M2.7", + "MiniMax-M2.7-highspeed", + "MiniMax-M2.5", + "MiniMax-M2.5-highspeed", + ], + }, +} + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + + +def create_llm_provider(config: dict): + """Instantiate the appropriate LLM provider based on ``config["llm"]["provider"]``.""" + + provider_name = config.get("llm", 
{}).get("provider", "openai") + + if provider_name == "openai": + return OpenAIResponsesProvider(config) + + if provider_name in PROVIDER_PRESETS or provider_name == "openai_compatible": + return ChatCompletionsProvider(config) + + raise ValueError( + f"Unknown provider: {provider_name!r}. " + f"Supported: {sorted(PROVIDER_PRESETS.keys())}" + ) + + +# --------------------------------------------------------------------------- +# OpenAI – native Responses API (no translation needed) +# --------------------------------------------------------------------------- + + +class OpenAIResponsesProvider: + """Thin wrapper around the OpenAI Responses API.""" + + def __init__(self, config: dict): + from openai import OpenAI + + self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) + + def create(self, model, instructions, input, tools, tool_choice, **kwargs): + return self.client.responses.create( + model=model, + instructions=instructions, + input=input, + tools=tools, + tool_choice=tool_choice, + **kwargs, + ) + + +# --------------------------------------------------------------------------- +# Chat Completions – for MiniMax and other OpenAI-compatible providers +# --------------------------------------------------------------------------- + + +class ChatCompletionsProvider: + """Translates between Responses-API semantics and Chat Completions.""" + + # Parameters specific to OpenAI Responses API that are not supported + # by the standard Chat Completions endpoint. 
+ _UNSUPPORTED_PARAMS = frozenset({"reasoning", "reasoning_effort", "summary"}) + + def __init__(self, config: dict): + from openai import OpenAI + + provider_name = config.get("llm", {}).get("provider", "openai_compatible") + preset = PROVIDER_PRESETS.get(provider_name, {}) + + env_key = preset.get("env_key", f"{provider_name.upper()}_API_KEY") + api_key = os.environ.get(env_key, "") + base_url = config.get("llm", {}).get( + "base_url", preset.get("base_url", "") + ) + + if not api_key: + raise ValueError( + f"API key not found. Set the {env_key} environment variable." + ) + if not base_url: + raise ValueError( + f"base_url is required for provider {provider_name!r}. " + "Set it in config['llm']['base_url']." + ) + + self.client = OpenAI(api_key=api_key, base_url=base_url) + self.provider_name = provider_name + + # ---- public interface -------------------------------------------------- + + def create(self, model, instructions, input, tools, tool_choice, **kwargs): + messages = self._to_chat_messages(instructions, input) + chat_tools = self._to_chat_tools(tools) + + clean_kwargs = {} + for k, v in kwargs.items(): + if k in self._UNSUPPORTED_PARAMS: + continue + if k == "temperature": + v = max(0.0, min(float(v), 1.0)) + if k == "max_output_tokens": + clean_kwargs["max_tokens"] = v + continue + clean_kwargs[k] = v + + create_kwargs: Dict[str, Any] = { + "model": model, + "messages": messages, + **clean_kwargs, + } + if chat_tools: + create_kwargs["tools"] = chat_tools + create_kwargs["tool_choice"] = "auto" + + response = self.client.chat.completions.create(**create_kwargs) + return self._to_llm_response(response) + + # ---- input conversion -------------------------------------------------- + + def _to_chat_messages(self, instructions, input_items): + """Convert Responses-API ``input`` list to Chat Completions ``messages``.""" + + messages: List[Dict[str, Any]] = [] + if instructions: + messages.append({"role": "system", "content": instructions}) + + i = 0 + while 
i < len(input_items): + item = input_items[i] + + # --- dict-based items --- + if isinstance(item, dict): + if item.get("type") == "function_call_output": + messages.append( + { + "role": "tool", + "tool_call_id": item["call_id"], + "content": str(item.get("output", "")), + } + ) + elif "role" in item: + content = item.get("content", "") + if isinstance(content, list): + converted = [] + for part in content: + ptype = part.get("type", "") + if ptype == "input_image": + converted.append( + { + "type": "image_url", + "image_url": {"url": part["image_url"]}, + } + ) + elif ptype == "input_text": + converted.append( + {"type": "text", "text": part["text"]} + ) + else: + converted.append(part) + messages.append({"role": item["role"], "content": converted}) + else: + messages.append({"role": item["role"], "content": content}) + i += 1 + + # --- output objects from a previous LLM response --- + elif hasattr(item, "type"): + text_content = "" + tool_calls: List[Dict[str, Any]] = [] + + # Group consecutive output objects into a single assistant msg + while i < len(input_items) and hasattr(input_items[i], "type"): + curr = input_items[i] + if curr.type == "function_call": + tool_calls.append( + { + "id": curr.call_id, + "type": "function", + "function": { + "name": curr.name, + "arguments": curr.arguments, + }, + } + ) + elif curr.type == "message": + text_content = getattr(curr, "content", "") + # reasoning items are silently skipped + i += 1 + + msg: Dict[str, Any] = {"role": "assistant"} + if text_content: + msg["content"] = text_content + if tool_calls: + msg["tool_calls"] = tool_calls + if text_content or tool_calls: + messages.append(msg) + else: + i += 1 + + return messages + + # ---- tools conversion -------------------------------------------------- + + @staticmethod + def _to_chat_tools(tools): + """Convert Responses-API tool dicts to Chat Completions format.""" + + if not tools: + return [] + chat_tools = [] + for tool in tools: + if isinstance(tool, dict) and 
tool.get("type") == "function": + chat_tools.append( + { + "type": "function", + "function": { + "name": tool["name"], + "description": tool.get("description", ""), + "parameters": tool.get("parameters", {}), + }, + } + ) + return chat_tools + + # ---- response conversion ----------------------------------------------- + + @staticmethod + def _to_llm_response(response): + """Convert a Chat Completions response to the ``LLMResponse`` envelope.""" + + message = response.choices[0].message + output: list = [] + output_text = message.content or "" + + # Strip blocks that some models emit + if output_text: + output_text = re.sub( + r".*?\s*", "", output_text, flags=re.DOTALL + ).strip() + + if output_text: + output.append(OutputMessage(type="message", content=output_text)) + + if message.tool_calls: + for tc in message.tool_calls: + output.append( + OutputFunctionCall( + type="function_call", + call_id=tc.id, + name=tc.function.name, + arguments=tc.function.arguments, + ) + ) + + return LLMResponse(output=output, output_text=output_text) diff --git a/project/moirai-agent/ctx_forecast/src/ctx_forecast/mcp_client.py b/project/moirai-agent/ctx_forecast/src/ctx_forecast/mcp_client.py index 126499c6..b649c07d 100644 --- a/project/moirai-agent/ctx_forecast/src/ctx_forecast/mcp_client.py +++ b/project/moirai-agent/ctx_forecast/src/ctx_forecast/mcp_client.py @@ -6,14 +6,13 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client -from openai import OpenAI +from src.ctx_forecast.llm_provider import create_llm_provider from src.ctx_forecast.utils import encode_image, parse_values_from_string class MCPClient: def __init__(self, config): - # self.openai = OpenAI() - self.openai = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) + self.provider = create_llm_provider(config) self.system_prompt = config["system_prompt"] self.model_name = config["llm"]["model_name"] self.model_params = config["llm"]["model_params_type"] @@ -148,7 +147,7 @@ async def 
"""Unit tests for the LLM provider abstraction."""

import json
import os
import sys
import unittest
from dataclasses import dataclass
from unittest.mock import MagicMock, patch

# Adjust path so imports work from the test directory
sys.path.insert(
    0,
    os.path.join(os.path.dirname(__file__), "..", "src"),
)

from ctx_forecast.llm_provider import (
    PROVIDER_PRESETS,
    ChatCompletionsProvider,
    LLMResponse,
    OpenAIResponsesProvider,
    OutputFunctionCall,
    OutputMessage,
    OutputReasoning,
    create_llm_provider,
)


# ---------------------------------------------------------------------------
# Factory tests
# ---------------------------------------------------------------------------


class TestCreateLLMProvider(unittest.TestCase):
    """Tests for the ``create_llm_provider`` factory."""

    @patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"})
    def test_default_provider_is_openai(self):
        config = {"llm": {"model_name": "gpt-5.1", "model_params_type": {}}}
        provider = create_llm_provider(config)
        self.assertIsInstance(provider, OpenAIResponsesProvider)

    @patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"})
    def test_explicit_openai_provider(self):
        config = {
            "llm": {"provider": "openai", "model_name": "gpt-5.1", "model_params_type": {}}
        }
        provider = create_llm_provider(config)
        self.assertIsInstance(provider, OpenAIResponsesProvider)

    @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"})
    def test_minimax_provider(self):
        config = {
            "llm": {
                "provider": "minimax",
                "model_name": "MiniMax-M2.7",
                "model_params_type": {},
            }
        }
        provider = create_llm_provider(config)
        self.assertIsInstance(provider, ChatCompletionsProvider)

    @patch.dict(os.environ, {"OPENAI_COMPATIBLE_API_KEY": "test-key"})
    def test_openai_compatible_provider(self):
        config = {
            "llm": {
                "provider": "openai_compatible",
                "model_name": "custom-model",
                "base_url": "https://api.example.com/v1",
                "model_params_type": {},
            }
        }
        provider = create_llm_provider(config)
        self.assertIsInstance(provider, ChatCompletionsProvider)

    def test_unknown_provider_raises(self):
        config = {
            "llm": {
                "provider": "nonexistent_provider",
                "model_name": "x",
                "model_params_type": {},
            }
        }
        with self.assertRaises(ValueError):
            create_llm_provider(config)


# ---------------------------------------------------------------------------
# Message conversion tests
# ---------------------------------------------------------------------------


class TestChatCompletionsMessageConversion(unittest.TestCase):
    """Tests for ``ChatCompletionsProvider._to_chat_messages``."""

    def _make_provider(self):
        """Create a provider instance with mocked client."""
        with patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}):
            config = {
                "llm": {
                    "provider": "minimax",
                    "model_name": "MiniMax-M2.7",
                    "model_params_type": {},
                }
            }
            return ChatCompletionsProvider(config)

    def test_system_instruction(self):
        provider = self._make_provider()
        msgs = provider._to_chat_messages("You are helpful.", [])
        self.assertEqual(len(msgs), 1)
        self.assertEqual(msgs[0]["role"], "system")
        self.assertEqual(msgs[0]["content"], "You are helpful.")

    def test_simple_user_message(self):
        provider = self._make_provider()
        input_items = [{"role": "user", "content": "Hello"}]
        msgs = provider._to_chat_messages("sys", input_items)
        self.assertEqual(len(msgs), 2)  # system + user
        self.assertEqual(msgs[1]["role"], "user")
        self.assertEqual(msgs[1]["content"], "Hello")

    def test_image_message_conversion(self):
        provider = self._make_provider()
        input_items = [
            {
                "role": "user",
                "content": [
                    {"type": "input_image", "image_url": "data:image/png;base64,abc"},
                    {"type": "input_text", "text": "Describe this"},
                ],
            }
        ]
        msgs = provider._to_chat_messages(None, input_items)
        self.assertEqual(len(msgs), 1)
        content = msgs[0]["content"]
        self.assertEqual(content[0]["type"], "image_url")
        self.assertEqual(
            content[0]["image_url"]["url"], "data:image/png;base64,abc"
        )
        self.assertEqual(content[1]["type"], "text")
        self.assertEqual(content[1]["text"], "Describe this")

    def test_function_call_output_conversion(self):
        provider = self._make_provider()
        input_items = [
            {
                "type": "function_call_output",
                "call_id": "call_123",
                "output": "result text",
            }
        ]
        msgs = provider._to_chat_messages(None, input_items)
        self.assertEqual(len(msgs), 1)
        self.assertEqual(msgs[0]["role"], "tool")
        self.assertEqual(msgs[0]["tool_call_id"], "call_123")
        self.assertEqual(msgs[0]["content"], "result text")

    def test_output_objects_grouped_into_assistant_message(self):
        """Consecutive output objects should merge into a single assistant msg."""
        provider = self._make_provider()
        input_items = [
            OutputMessage(content="I'll call the tool"),
            OutputFunctionCall(
                call_id="c1", name="forecast", arguments='{"x": 1}'
            ),
            OutputFunctionCall(
                call_id="c2", name="sandbox", arguments='{"code": "print(1)"}'
            ),
        ]
        msgs = provider._to_chat_messages(None, input_items)
        self.assertEqual(len(msgs), 1)
        self.assertEqual(msgs[0]["role"], "assistant")
        self.assertEqual(msgs[0]["content"], "I'll call the tool")
        self.assertEqual(len(msgs[0]["tool_calls"]), 2)
        self.assertEqual(msgs[0]["tool_calls"][0]["id"], "c1")
        self.assertEqual(msgs[0]["tool_calls"][1]["id"], "c2")

    def test_reasoning_output_skipped(self):
        provider = self._make_provider()
        input_items = [
            OutputReasoning(summary="thinking..."),
            OutputMessage(content="Done"),
        ]
        msgs = provider._to_chat_messages(None, input_items)
        self.assertEqual(len(msgs), 1)
        self.assertEqual(msgs[0]["content"], "Done")

    def test_multi_round_context(self):
        """Simulate a two-round conversation context."""
        provider = self._make_provider()
        input_items = [
            # Round 1: user + assistant + tool result
            {"role": "user", "content": "step: 1"},
            {"role": "user", "content": "What is 2+2?"},
            OutputMessage(content="Let me calculate"),
            OutputFunctionCall(
                call_id="c1", name="calc", arguments='{"expr": "2+2"}'
            ),
            {"type": "function_call_output", "call_id": "c1", "output": "4"},
            # Round 2: user step marker
            {"role": "user", "content": "step: 2"},
        ]
        msgs = provider._to_chat_messages("You are helpful.", input_items)
        # system, user(step1), user(query), assistant(content+tool), tool, user(step2)
        self.assertEqual(len(msgs), 6)
        self.assertEqual(msgs[0]["role"], "system")
        self.assertEqual(msgs[1]["role"], "user")
        self.assertEqual(msgs[2]["role"], "user")
        self.assertEqual(msgs[3]["role"], "assistant")
        self.assertEqual(msgs[4]["role"], "tool")
        self.assertEqual(msgs[5]["role"], "user")


# ---------------------------------------------------------------------------
# Tool conversion tests
# ---------------------------------------------------------------------------


class TestChatCompletionsToolConversion(unittest.TestCase):
    def test_tools_converted(self):
        tools = [
            {
                "type": "function",
                "name": "forecast",
                "description": "Run forecast",
                "parameters": {"type": "object", "properties": {"x": {"type": "number"}}},
            }
        ]
        result = ChatCompletionsProvider._to_chat_tools(tools)
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0]["type"], "function")
        self.assertEqual(result[0]["function"]["name"], "forecast")
        self.assertEqual(result[0]["function"]["description"], "Run forecast")

    def test_empty_tools(self):
        self.assertEqual(ChatCompletionsProvider._to_chat_tools([]), [])
        self.assertEqual(ChatCompletionsProvider._to_chat_tools(None), [])


# ---------------------------------------------------------------------------
# Response conversion tests
# ---------------------------------------------------------------------------


class TestChatCompletionsResponseConversion(unittest.TestCase):
    def _mock_response(self, content=None, tool_calls=None):
        """Create a mock Chat Completions response."""
        message = MagicMock()
        message.content = content
        message.tool_calls = tool_calls

        choice = MagicMock()
        choice.message = message

        response = MagicMock()
        response.choices = [choice]
        return response

    def test_text_only_response(self):
        resp = self._mock_response(content="The answer is 42")
        result = ChatCompletionsProvider._to_llm_response(resp)
        self.assertIsInstance(result, LLMResponse)
        self.assertEqual(result.output_text, "The answer is 42")
        self.assertEqual(len(result.output), 1)
        self.assertIsInstance(result.output[0], OutputMessage)

    def test_tool_call_response(self):
        tc = MagicMock()
        tc.id = "call_abc"
        tc.function.name = "forecast"
        tc.function.arguments = '{"x": 1}'

        resp = self._mock_response(content=None, tool_calls=[tc])
        result = ChatCompletionsProvider._to_llm_response(resp)
        self.assertEqual(result.output_text, "")
        self.assertEqual(len(result.output), 1)
        self.assertIsInstance(result.output[0], OutputFunctionCall)
        self.assertEqual(result.output[0].call_id, "call_abc")
        self.assertEqual(result.output[0].name, "forecast")

    def test_text_with_tool_calls(self):
        tc = MagicMock()
        tc.id = "call_1"
        tc.function.name = "calc"
        tc.function.arguments = '{"a": 1}'

        resp = self._mock_response(content="Calling calc", tool_calls=[tc])
        result = ChatCompletionsProvider._to_llm_response(resp)
        self.assertEqual(result.output_text, "Calling calc")
        self.assertEqual(len(result.output), 2)  # message + function_call

    def test_think_tag_stripping(self):
        # The reasoning must be wrapped in <think> tags, otherwise this test
        # does not exercise the stripping path at all.
        resp = self._mock_response(
            content="<think>Let me think about this...</think>\nThe answer is 42"
        )
        result = ChatCompletionsProvider._to_llm_response(resp)
        self.assertEqual(result.output_text, "The answer is 42")

    def test_untagged_text_keeps_inner_whitespace(self):
        """Text without a think block must come back unchanged."""
        resp = self._mock_response(content="line one\nline two  end")
        result = ChatCompletionsProvider._to_llm_response(resp)
        self.assertEqual(result.output_text, "line one\nline two  end")
        self.assertEqual(result.output[0].content, "line one\nline two  end")


# ---------------------------------------------------------------------------
# Parameter handling tests
# ---------------------------------------------------------------------------


class TestParameterHandling(unittest.TestCase):
    def _make_provider(self):
        with patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}):
            config = {
                "llm": {
                    "provider": "minimax",
                    "model_name": "MiniMax-M2.7",
                    "model_params_type": {},
                }
            }
            return ChatCompletionsProvider(config)

    @patch("ctx_forecast.llm_provider.ChatCompletionsProvider._to_llm_response")
    def test_reasoning_param_filtered(self, mock_resp):
        provider = self._make_provider()
        mock_resp.return_value = LLMResponse(output=[], output_text="")
        provider.client = MagicMock()
        provider.client.chat.completions.create.return_value = MagicMock()

        provider.create(
            model="MiniMax-M2.7",
            instructions="sys",
            input=[],
            tools=[],
            tool_choice="auto",
            reasoning={"effort": "medium"},
            temperature=0.7,
        )

        call_kwargs = provider.client.chat.completions.create.call_args[1]
        self.assertNotIn("reasoning", call_kwargs)
        self.assertNotIn("reasoning_effort", call_kwargs)

    @patch("ctx_forecast.llm_provider.ChatCompletionsProvider._to_llm_response")
    def test_max_output_tokens_converted(self, mock_resp):
        provider = self._make_provider()
        mock_resp.return_value = LLMResponse(output=[], output_text="")
        provider.client = MagicMock()
        provider.client.chat.completions.create.return_value = MagicMock()

        provider.create(
            model="MiniMax-M2.7",
            instructions="sys",
            input=[],
            tools=[],
            tool_choice="auto",
            max_output_tokens=4096,
        )

        call_kwargs = provider.client.chat.completions.create.call_args[1]
        self.assertNotIn("max_output_tokens", call_kwargs)
        self.assertEqual(call_kwargs["max_tokens"], 4096)

    @patch("ctx_forecast.llm_provider.ChatCompletionsProvider._to_llm_response")
    def test_temperature_clamped(self, mock_resp):
        provider = self._make_provider()
        mock_resp.return_value = LLMResponse(output=[], output_text="")
        provider.client = MagicMock()
        provider.client.chat.completions.create.return_value = MagicMock()

        provider.create(
            model="MiniMax-M2.7",
            instructions="sys",
            input=[],
            tools=[],
            tool_choice="auto",
            temperature=1.5,
        )

        call_kwargs = provider.client.chat.completions.create.call_args[1]
        self.assertLessEqual(call_kwargs["temperature"], 1.0)


# ---------------------------------------------------------------------------
# Provider presets tests
# ---------------------------------------------------------------------------


class TestProviderPresets(unittest.TestCase):
    def test_minimax_preset_exists(self):
        self.assertIn("minimax", PROVIDER_PRESETS)

    def test_minimax_base_url(self):
        self.assertEqual(
            PROVIDER_PRESETS["minimax"]["base_url"],
            "https://api.minimax.io/v1",
        )

    def test_minimax_models(self):
        models = PROVIDER_PRESETS["minimax"]["models"]
        self.assertIn("MiniMax-M2.7", models)
        self.assertIn("MiniMax-M2.5-highspeed", models)

    def test_minimax_env_key(self):
        self.assertEqual(PROVIDER_PRESETS["minimax"]["env_key"], "MINIMAX_API_KEY")


# ---------------------------------------------------------------------------
# Dataclass tests
# ---------------------------------------------------------------------------


class TestResponseDataclasses(unittest.TestCase):
    def test_output_message_attributes(self):
        msg = OutputMessage(content="hello")
        self.assertEqual(msg.type, "message")
        self.assertEqual(msg.content, "hello")

    def test_output_function_call_attributes(self):
        fc = OutputFunctionCall(call_id="c1", name="fn", arguments='{"a":1}')
        self.assertEqual(fc.type, "function_call")
        self.assertEqual(fc.call_id, "c1")
        self.assertEqual(fc.name, "fn")
        self.assertEqual(fc.arguments, '{"a":1}')

    def test_output_reasoning_attributes(self):
        r = OutputReasoning(summary="thinking")
        self.assertEqual(r.type, "reasoning")
        self.assertEqual(r.summary, "thinking")

    def test_llm_response_attributes(self):
        resp = LLMResponse(
            output=[OutputMessage(content="hi")],
            output_text="hi",
        )
        self.assertEqual(len(resp.output), 1)
        self.assertEqual(resp.output_text, "hi")

    def test_llm_response_iteration(self):
        """Verify output items can be iterated and type-checked like the OpenAI SDK."""
        items = [
            OutputReasoning(summary="step 1"),
            OutputMessage(content="result"),
            OutputFunctionCall(call_id="c1", name="fn", arguments="{}"),
        ]
        resp = LLMResponse(output=items, output_text="result")

        reasoning = [r for r in resp.output if r.type == "reasoning"]
        self.assertEqual(len(reasoning), 1)
        self.assertEqual(reasoning[0].summary, "step 1")

        tool_calls = [r for r in resp.output if r.type == "function_call"]
        self.assertEqual(len(tool_calls), 1)
        self.assertEqual(tool_calls[0].name, "fn")


if __name__ == "__main__":
    unittest.main()
+""" + +import json +import os +import sys +import unittest + +sys.path.insert( + 0, + os.path.join(os.path.dirname(__file__), "..", "src"), +) + +from ctx_forecast.llm_provider import ( + ChatCompletionsProvider, + LLMResponse, + OutputFunctionCall, + OutputMessage, + create_llm_provider, +) + +MINIMAX_API_KEY = os.environ.get("MINIMAX_API_KEY", "") +SKIP_REASON = "MINIMAX_API_KEY not set" + + +@unittest.skipUnless(MINIMAX_API_KEY, SKIP_REASON) +class TestMiniMaxBasicCompletion(unittest.TestCase): + """Test basic text completion with MiniMax.""" + + def setUp(self): + self.config = { + "llm": { + "provider": "minimax", + "model_name": "MiniMax-M2.5-highspeed", + "model_params_type": { + "temperature": 0.1, + "max_output_tokens": 256, + }, + } + } + self.provider = create_llm_provider(self.config) + + def test_simple_text_response(self): + result = self.provider.create( + model="MiniMax-M2.5-highspeed", + instructions="You are a helpful assistant. Reply concisely.", + input=[{"role": "user", "content": "What is 2+2? 
Reply with just the number."}], + tools=[], + tool_choice="auto", + temperature=0.1, + max_output_tokens=64, + ) + self.assertIsInstance(result, LLMResponse) + self.assertIn("4", result.output_text) + self.assertTrue(len(result.output) > 0) + + def test_response_output_iteration(self): + """Verify the response output can be iterated and type-checked.""" + result = self.provider.create( + model="MiniMax-M2.5-highspeed", + instructions="Reply briefly.", + input=[{"role": "user", "content": "Say hi"}], + tools=[], + tool_choice="auto", + temperature=0.1, + max_output_tokens=32, + ) + messages = [r for r in result.output if r.type == "message"] + self.assertTrue(len(messages) >= 1) + self.assertTrue(len(messages[0].content) > 0) + + +@unittest.skipUnless(MINIMAX_API_KEY, SKIP_REASON) +class TestMiniMaxToolCalling(unittest.TestCase): + """Test function/tool calling with MiniMax.""" + + def setUp(self): + self.provider = create_llm_provider( + { + "llm": { + "provider": "minimax", + "model_name": "MiniMax-M2.5-highspeed", + "model_params_type": {}, + } + } + ) + self.tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get current weather for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + }, + "required": ["city"], + }, + } + ] + + def test_tool_call_triggered(self): + result = self.provider.create( + model="MiniMax-M2.5-highspeed", + instructions="You have access to tools. 
Use them when appropriate.", + input=[ + {"role": "user", "content": "What is the weather in Tokyo?"}, + ], + tools=self.tools, + tool_choice="auto", + temperature=0.1, + max_output_tokens=256, + ) + self.assertIsInstance(result, LLMResponse) + tool_calls = [r for r in result.output if r.type == "function_call"] + self.assertTrue(len(tool_calls) >= 1, "Expected at least one tool call") + tc = tool_calls[0] + self.assertEqual(tc.name, "get_weather") + args = json.loads(tc.arguments) + self.assertIn("city", args) + + +@unittest.skipUnless(MINIMAX_API_KEY, SKIP_REASON) +class TestMiniMaxMultiTurn(unittest.TestCase): + """Test multi-turn conversation with tool results fed back.""" + + def setUp(self): + self.provider = create_llm_provider( + { + "llm": { + "provider": "minimax", + "model_name": "MiniMax-M2.5-highspeed", + "model_params_type": {}, + } + } + ) + + def test_multi_turn_with_tool_result(self): + tools = [ + { + "type": "function", + "name": "calculate", + "description": "Evaluate a math expression", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Math expression", + } + }, + "required": ["expression"], + }, + } + ] + + # Round 1: ask question, expect tool call + r1 = self.provider.create( + model="MiniMax-M2.5-highspeed", + instructions="Use the calculate tool when asked math questions.", + input=[ + {"role": "user", "content": "What is 123 * 456?"}, + ], + tools=tools, + tool_choice="auto", + temperature=0.1, + max_output_tokens=256, + ) + + tc_items = [r for r in r1.output if r.type == "function_call"] + if not tc_items: + # Model answered directly – still valid + self.assertIn("56088", r1.output_text) + return + + # Build context for round 2 + context = [ + {"role": "user", "content": "What is 123 * 456?"}, + ] + context += r1.output # add assistant output objects + context.append( + { + "type": "function_call_output", + "call_id": tc_items[0].call_id, + "output": "56088", + } + ) + + # 
Round 2: feed tool result back + r2 = self.provider.create( + model="MiniMax-M2.5-highspeed", + instructions="Use the calculate tool when asked math questions.", + input=context, + tools=tools, + tool_choice="auto", + temperature=0.1, + max_output_tokens=256, + ) + # Model may format the number with commas (e.g., "56,088") + normalized = r2.output_text.replace(",", "") + self.assertIn("56088", normalized) + + +if __name__ == "__main__": + unittest.main()