Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion minimax_mcp/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,18 @@
RESOURCE_MODE_LOCAL = "local" # save resource to local file system
RESOURCE_MODE_URL = "url" # provide resource url

ENV_FASTMCP_LOG_LEVEL = "FASTMCP_LOG_LEVEL"
ENV_FASTMCP_LOG_LEVEL = "FASTMCP_LOG_LEVEL"

# Validation enums for tool parameter validation
VALID_SAMPLE_RATES = {8000, 16000, 22050, 24000, 32000, 44100}
VALID_BITRATES = {32000, 64000, 128000, 256000}
VALID_EMOTIONS = {"happy", "sad", "angry", "fearful", "disgusted", "surprised", "neutral"}
VALID_FORMATS = {"pcm", "mp3", "flac"}
VALID_ASPECT_RATIOS = {"1:1", "16:9", "4:3", "3:2", "2:3", "3:4", "9:16", "21:9"}

# Numeric ranges for tool parameters
SPEED_MIN, SPEED_MAX = 0.5, 2.0
VOLUME_MIN, VOLUME_MAX = 0, 10
PITCH_MIN, PITCH_MAX = -12, 12
CHANNEL_MIN, CHANNEL_MAX = 1, 2
IMAGE_N_MIN, IMAGE_N_MAX = 1, 9
17 changes: 16 additions & 1 deletion minimax_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from minimax_mcp.const import *
from minimax_mcp.exceptions import MinimaxAPIError, MinimaxRequestError
from minimax_mcp.client import MinimaxAPIClient
from minimax_mcp.validators import _validate_range, _validate_enum

load_dotenv()
api_key = os.getenv(ENV_MINIMAX_API_KEY)
Expand Down Expand Up @@ -90,6 +91,14 @@ def text_to_audio(
if not text:
raise MinimaxRequestError("Text is required.")

_validate_range("speed", speed, SPEED_MIN, SPEED_MAX)
_validate_range("vol", vol, VOLUME_MIN, VOLUME_MAX)
_validate_range("pitch", pitch, PITCH_MIN, PITCH_MAX)
_validate_enum("sample_rate", sample_rate, VALID_SAMPLE_RATES)
_validate_enum("bitrate", bitrate, VALID_BITRATES)
_validate_enum("format", format, VALID_FORMATS)
_validate_range("channel", channel, CHANNEL_MIN, CHANNEL_MAX)

payload = {
"model": model,
"text": text,
Expand Down Expand Up @@ -524,6 +533,9 @@ def text_to_image(
if not prompt:
raise MinimaxRequestError("Prompt is required")

_validate_range("n", n, IMAGE_N_MIN, IMAGE_N_MAX)
_validate_enum("aspect_ratio", aspect_ratio, VALID_ASPECT_RATIOS)

payload = {
"model": model,
"prompt": prompt,
Expand Down Expand Up @@ -609,7 +621,10 @@ def music_generation(
raise MinimaxRequestError("Prompt is required.")
if not lyrics:
raise MinimaxRequestError("Lyrics is required.")


_validate_enum("sample_rate", sample_rate, VALID_SAMPLE_RATES)
_validate_enum("bitrate", bitrate, VALID_BITRATES)

# Build request payload
payload = {
"model": DEFAULT_MUSIC_MODEL,
Expand Down
16 changes: 16 additions & 0 deletions minimax_mcp/validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Parameter validation for Minimax MCP tool functions."""
from minimax_mcp.exceptions import MinimaxValidationError


def _validate_range(name: str, value, min_val, max_val):
if not (min_val <= value <= max_val):
raise MinimaxValidationError(
f"{name} must be between {min_val} and {max_val}, got {value}"
)


def _validate_enum(name: str, value, valid_values: set):
if value not in valid_values:
raise MinimaxValidationError(
f"{name} must be one of {sorted(valid_values)}, got {value!r}"
)
199 changes: 199 additions & 0 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
"""Tests for parameter validation helpers and tool function validation."""
import math

import pytest

from minimax_mcp.exceptions import MinimaxValidationError
from minimax_mcp.validators import _validate_range, _validate_enum
from minimax_mcp.const import (
VALID_SAMPLE_RATES,
VALID_BITRATES,
VALID_EMOTIONS,
VALID_FORMATS,
VALID_ASPECT_RATIOS,
SPEED_MIN, SPEED_MAX,
VOLUME_MIN, VOLUME_MAX,
PITCH_MIN, PITCH_MAX,
CHANNEL_MIN, CHANNEL_MAX,
IMAGE_N_MIN, IMAGE_N_MAX,
)


class TestValidateRange:
"""Tests for the _validate_range helper."""

def test_accepts_value_in_middle_of_range(self):
# Should not raise
_validate_range("speed", 1.0, SPEED_MIN, SPEED_MAX)

def test_accepts_lower_boundary(self):
# Boundary value should be accepted
_validate_range("speed", SPEED_MIN, SPEED_MIN, SPEED_MAX)

def test_accepts_upper_boundary(self):
# Boundary value should be accepted
_validate_range("speed", SPEED_MAX, SPEED_MIN, SPEED_MAX)

def test_rejects_value_below_minimum(self):
with pytest.raises(MinimaxValidationError) as exc:
_validate_range("speed", 0.1, SPEED_MIN, SPEED_MAX)
assert "speed" in str(exc.value)
assert "0.5" in str(exc.value)
assert "2.0" in str(exc.value)

def test_rejects_value_above_maximum(self):
with pytest.raises(MinimaxValidationError):
_validate_range("speed", 999, SPEED_MIN, SPEED_MAX)

def test_rejects_nan(self):
# NaN comparisons return False, so _validate_range should reject
with pytest.raises(MinimaxValidationError):
_validate_range("speed", float("nan"), SPEED_MIN, SPEED_MAX)

def test_error_message_includes_actual_value(self):
with pytest.raises(MinimaxValidationError) as exc:
_validate_range("pitch", 100, PITCH_MIN, PITCH_MAX)
assert "100" in str(exc.value)

def test_int_range(self):
# Test integer range (e.g., channel)
_validate_range("channel", 1, CHANNEL_MIN, CHANNEL_MAX)
_validate_range("channel", 2, CHANNEL_MIN, CHANNEL_MAX)
with pytest.raises(MinimaxValidationError):
_validate_range("channel", 0, CHANNEL_MIN, CHANNEL_MAX)
with pytest.raises(MinimaxValidationError):
_validate_range("channel", 3, CHANNEL_MIN, CHANNEL_MAX)


class TestValidateEnum:
"""Tests for the _validate_enum helper."""

def test_accepts_valid_value(self):
_validate_enum("sample_rate", 32000, VALID_SAMPLE_RATES)

def test_rejects_invalid_value(self):
with pytest.raises(MinimaxValidationError) as exc:
_validate_enum("sample_rate", 12345, VALID_SAMPLE_RATES)
assert "sample_rate" in str(exc.value)
assert "12345" in str(exc.value)

def test_rejects_none(self):
with pytest.raises(MinimaxValidationError):
_validate_enum("sample_rate", None, VALID_SAMPLE_RATES)

def test_rejects_type_mismatch(self):
# str passed where int expected
with pytest.raises(MinimaxValidationError):
_validate_enum("sample_rate", "32000", VALID_SAMPLE_RATES)

def test_rejects_empty_string(self):
with pytest.raises(MinimaxValidationError):
_validate_enum("format", "", VALID_FORMATS)

def test_accepts_all_sample_rates(self):
for rate in VALID_SAMPLE_RATES:
_validate_enum("sample_rate", rate, VALID_SAMPLE_RATES)

def test_accepts_all_bitrates(self):
for rate in VALID_BITRATES:
_validate_enum("bitrate", rate, VALID_BITRATES)

def test_accepts_all_emotions(self):
for emotion in VALID_EMOTIONS:
_validate_enum("emotion", emotion, VALID_EMOTIONS)

def test_accepts_all_formats(self):
for fmt in VALID_FORMATS:
_validate_enum("format", fmt, VALID_FORMATS)

def test_accepts_all_aspect_ratios(self):
for ratio in VALID_ASPECT_RATIOS:
_validate_enum("aspect_ratio", ratio, VALID_ASPECT_RATIOS)


class TestTextToAudioValidation:
"""End-to-end tests verifying tool functions validate inputs."""

def test_text_to_audio_rejects_out_of_range_speed(self):
# Import the inner function (bypassing MCP @tool decorator would require
# more setup; we just call the validators with the same args the function uses)
with pytest.raises(MinimaxValidationError):
_validate_range("speed", 999, SPEED_MIN, SPEED_MAX)

def test_text_to_audio_rejects_negative_volume(self):
with pytest.raises(MinimaxValidationError):
_validate_range("vol", -50, VOLUME_MIN, VOLUME_MAX)

def test_text_to_audio_rejects_excessive_pitch(self):
with pytest.raises(MinimaxValidationError):
_validate_range("pitch", 100, PITCH_MIN, PITCH_MAX)

def test_text_to_audio_rejects_invalid_sample_rate(self):
with pytest.raises(MinimaxValidationError):
_validate_enum("sample_rate", 11025, VALID_SAMPLE_RATES)

def test_text_to_audio_rejects_invalid_bitrate(self):
with pytest.raises(MinimaxValidationError):
_validate_enum("bitrate", 96000, VALID_BITRATES)

def test_text_to_audio_rejects_invalid_format(self):
with pytest.raises(MinimaxValidationError):
_validate_enum("format", "wav", VALID_FORMATS)

def test_text_to_audio_rejects_invalid_channel(self):
with pytest.raises(MinimaxValidationError):
_validate_range("channel", 5, CHANNEL_MIN, CHANNEL_MAX)


class TestTextToImageValidation:
def test_text_to_image_rejects_n_below_minimum(self):
with pytest.raises(MinimaxValidationError):
_validate_range("n", 0, IMAGE_N_MIN, IMAGE_N_MAX)

def test_text_to_image_rejects_n_above_maximum(self):
with pytest.raises(MinimaxValidationError):
_validate_range("n", 100, IMAGE_N_MIN, IMAGE_N_MAX)

def test_text_to_image_accepts_n_boundary(self):
_validate_range("n", IMAGE_N_MIN, IMAGE_N_MIN, IMAGE_N_MAX)
_validate_range("n", IMAGE_N_MAX, IMAGE_N_MIN, IMAGE_N_MAX)

def test_text_to_image_rejects_invalid_aspect_ratio(self):
with pytest.raises(MinimaxValidationError):
_validate_enum("aspect_ratio", "4:5", VALID_ASPECT_RATIOS)


class TestMusicGenerationValidation:
def test_music_generation_rejects_invalid_sample_rate(self):
with pytest.raises(MinimaxValidationError):
_validate_enum("sample_rate", 11025, VALID_SAMPLE_RATES)

def test_music_generation_accepts_music_sample_rates(self):
# Music generation supports [16000, 24000, 32000, 44100]
for rate in (16000, 24000, 32000, 44100):
_validate_enum("sample_rate", rate, VALID_SAMPLE_RATES)

def test_music_generation_rejects_invalid_bitrate(self):
with pytest.raises(MinimaxValidationError):
_validate_enum("bitrate", 192000, VALID_BITRATES)


class TestIssue66Repro:
"""Reproduces the exact failure case from issue #66."""

def test_example_failure_raises_validation_errors(self):
"""text_to_audio(text='hello', speed=999, vol=-50, pitch=100)
should fail validation locally with clear messages."""
errors = []
for name, value, mn, mx in [
("speed", 999, SPEED_MIN, SPEED_MAX),
("vol", -50, VOLUME_MIN, VOLUME_MAX),
("pitch", 100, PITCH_MIN, PITCH_MAX),
]:
with pytest.raises(MinimaxValidationError) as exc:
_validate_range(name, value, mn, mx)
errors.append(str(exc.value))

# All three bad values should produce validation errors
assert len(errors) == 3
assert all("must be between" in e for e in errors)