Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig


# NOTE(review): nothing in the visible part of this module reads this constant,
# and the contextualized call below does not pass a chunk size. Presumably a
# chunk-size limit for contextualized embeddings -- confirm it is still needed.
CONTEXTUALIZED_CHUNK_SIZE = 32000


class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for VoyageAI models."""

Expand Down Expand Up @@ -50,9 +53,25 @@ def __call__(self, input: Documents) -> Embeddings:
if isinstance(input, str):
input = [input]

model = self._config.get("model", "voyage-2")

if model.startswith("voyage-context"):
inputs = [[s] for s in input]
ctx_result = self._client.contextualized_embed(
inputs=inputs,
model=model,
input_type="document",
output_dtype=self._config.get("output_dtype"),
output_dimension=self._config.get("output_dimension"),
)
return cast(
Embeddings,
[r.embeddings[0] for r in ctx_result.results],
)

result = self._client.embed(
texts=input,
model=self._config.get("model", "voyage-2"),
model=model,
input_type=self._config.get("input_type"),
truncation=self._config.get("truncation", True),
output_dtype=self._config.get("output_dtype"),
Expand Down
107 changes: 107 additions & 0 deletions lib/crewai/tests/rag/embeddings/test_voyageai_embedding_callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""Tests for the VoyageAI embedding function."""

from unittest.mock import MagicMock, patch

import numpy as np

from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
CONTEXTUALIZED_CHUNK_SIZE,
VoyageAIEmbeddingFunction,
)


class TestVoyageAIEmbeddingFunction:
    """Test the VoyageAI embedding function call routing.

    Every test patches ``voyageai.Client`` so no request leaves the process;
    the assertions only inspect which client method was invoked and with
    which keyword arguments.
    """

    @staticmethod
    def _contextualized_response(*vectors):
        """Build a fake ``contextualized_embed`` return value.

        Args:
            *vectors: One embedding vector per input document.

        Returns:
            A mock whose ``results`` list holds one entry per vector, each
            exposing ``embeddings`` as a single-element list (one chunk per
            document), which is the shape the embedding function indexes
            with ``r.embeddings[0]``.
        """
        return MagicMock(
            results=[MagicMock(embeddings=[list(vector)]) for vector in vectors]
        )

    def test_standard_model_uses_embed(self):
        """Standard models should call the regular embed endpoint."""
        with patch("voyageai.Client") as mock_client_class:
            mock_client = mock_client_class.return_value
            # One embedding per input document, so the result shape mirrors
            # what the real endpoint would return for two texts.
            mock_client.embed.return_value = MagicMock(
                embeddings=[[0.1, 0.2], [0.3, 0.4]]
            )

            fn = VoyageAIEmbeddingFunction(api_key="voyage-key", model="voyage-2")
            result = fn(["aa", "bb"])

            mock_client.embed.assert_called_once()
            mock_client.contextualized_embed.assert_not_called()
            assert np.allclose(result, [[0.1, 0.2], [0.3, 0.4]])

    def test_contextualized_model_uses_contextualized_embed(self):
        """voyage-context-4 should call the contextualized embeddings endpoint."""
        with patch("voyageai.Client") as mock_client_class:
            mock_client = mock_client_class.return_value
            mock_client.contextualized_embed.return_value = (
                self._contextualized_response([0.1, 0.2], [0.3, 0.4])
            )

            fn = VoyageAIEmbeddingFunction(
                api_key="voyage-key", model="voyage-context-4"
            )
            result = fn(["aa", "bb"])

            mock_client.embed.assert_not_called()
            mock_client.contextualized_embed.assert_called_once()
            assert np.allclose(result, [[0.1, 0.2], [0.3, 0.4]])

    def test_contextualized_call_wraps_inputs_as_list_of_lists(self):
        """Each input string is wrapped as its own single-chunk document."""
        with patch("voyageai.Client") as mock_client_class:
            mock_client = mock_client_class.return_value
            mock_client.contextualized_embed.return_value = (
                self._contextualized_response([0.1, 0.2])
            )

            fn = VoyageAIEmbeddingFunction(
                api_key="voyage-key", model="voyage-context-4"
            )
            fn(["aa"])

            _, kwargs = mock_client.contextualized_embed.call_args
            # Each string is wrapped as its own single-chunk document
            assert kwargs["inputs"] == [["aa"]]
            # The configured model and the fixed document input type are
            # forwarded to the contextualized endpoint.
            assert kwargs["model"] == "voyage-context-4"
            assert kwargs["input_type"] == "document"
            # chunk_size and enable_auto_chunking must NOT be passed
            assert "chunk_size" not in kwargs
            assert "enable_auto_chunking" not in kwargs

    def test_contextualized_input_is_list_of_lists(self):
        """Input must be passed as List[List[str]], one inner list per document."""
        with patch("voyageai.Client") as mock_client_class:
            mock_client = mock_client_class.return_value
            mock_client.contextualized_embed.return_value = (
                self._contextualized_response([0.1, 0.2], [0.3, 0.4])
            )

            fn = VoyageAIEmbeddingFunction(
                api_key="voyage-key", model="voyage-context-4"
            )
            fn(["aa", "bb"])

            _, kwargs = mock_client.contextualized_embed.call_args
            assert kwargs["inputs"] == [["aa"], ["bb"]]

    def test_contextualized_string_input_normalized_with_wrapping(self):
        """A single string input is normalized and wrapped as one document."""
        with patch("voyageai.Client") as mock_client_class:
            mock_client = mock_client_class.return_value
            mock_client.contextualized_embed.return_value = (
                self._contextualized_response([0.1, 0.2])
            )

            fn = VoyageAIEmbeddingFunction(
                api_key="voyage-key", model="voyage-context-4"
            )
            fn("aa")

            _, kwargs = mock_client.contextualized_embed.call_args
            assert kwargs["inputs"] == [["aa"]]
Loading