Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig


# Value passed as ``chunk_size`` when calling the client's
# ``contextualized_embed`` endpoint with auto-chunking enabled.
# NOTE(review): the unit (tokens vs. characters) is not stated here; confirm
# against the VoyageAI client documentation.
CONTEXTUALIZED_CHUNK_SIZE = 32000


class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for VoyageAI models."""

Expand Down Expand Up @@ -50,9 +53,26 @@ def __call__(self, input: Documents) -> Embeddings:
if isinstance(input, str):
input = [input]

model = self._config.get("model", "voyage-2")

if model.startswith("voyage-context"):
result = self._client.contextualized_embed(
inputs=input,
model=model,
input_type="document",
output_dtype=self._config.get("output_dtype"),
output_dimension=self._config.get("output_dimension"),
enable_auto_chunking=True,
chunk_size=CONTEXTUALIZED_CHUNK_SIZE,
Comment thread
fzowl marked this conversation as resolved.
Outdated
)
return cast(
Embeddings,
[embedding for r in result.results for embedding in r.embeddings],
)

result = self._client.embed(
texts=input,
model=self._config.get("model", "voyage-2"),
model=model,
input_type=self._config.get("input_type"),
truncation=self._config.get("truncation", True),
output_dtype=self._config.get("output_dtype"),
Expand Down
104 changes: 104 additions & 0 deletions lib/crewai/tests/rag/embeddings/test_voyageai_embedding_callable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Tests for the VoyageAI embedding function."""

from unittest.mock import MagicMock, patch

import numpy as np

from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
CONTEXTUALIZED_CHUNK_SIZE,
VoyageAIEmbeddingFunction,
)


class TestVoyageAIEmbeddingFunction:
    """Test the VoyageAI embedding function call routing.

    Every test replaces ``voyageai.Client`` with a mock and inspects which
    endpoint of the mocked client instance the embedding function calls, and
    with which keyword arguments.
    """

    # Model names used to select the standard and contextualized code paths.
    STANDARD_MODEL = "voyage-2"
    CONTEXT_MODEL = "voyage-context-4"

    @staticmethod
    def _embedding_function(model):
        """Build the function under test for ``model`` with a dummy API key."""
        return VoyageAIEmbeddingFunction(api_key="voyage-key", model=model)

    @staticmethod
    def _contextualized_response(*embeddings):
        """Return a mock contextualized response with one result per embedding.

        Each result exposes ``embeddings`` as a one-element list holding the
        given vector, mirroring the ``results[i].embeddings`` shape the
        embedding function iterates over.
        """
        return MagicMock(
            results=[MagicMock(embeddings=[embedding]) for embedding in embeddings]
        )

    def test_standard_model_uses_embed(self):
        """Standard models should call the regular embed endpoint."""
        with patch("voyageai.Client") as client_class:
            client = client_class.return_value
            # One vector per input document so the fixture matches the request.
            client.embed.return_value = MagicMock(
                embeddings=[[0.1, 0.2], [0.3, 0.4]]
            )

            result = self._embedding_function(self.STANDARD_MODEL)(["aa", "bb"])

            client.embed.assert_called_once()
            client.contextualized_embed.assert_not_called()
            assert np.allclose(result, [[0.1, 0.2], [0.3, 0.4]])

    def test_contextualized_model_uses_contextualized_embed(self):
        """voyage-context-4 should call the contextualized embeddings endpoint."""
        with patch("voyageai.Client") as client_class:
            client = client_class.return_value
            client.contextualized_embed.return_value = (
                self._contextualized_response([0.1, 0.2], [0.3, 0.4])
            )

            result = self._embedding_function(self.CONTEXT_MODEL)(["aa", "bb"])

            client.embed.assert_not_called()
            client.contextualized_embed.assert_called_once()
            # Per-result embeddings are expected back as one flat list.
            assert np.allclose(result, [[0.1, 0.2], [0.3, 0.4]])

    def test_contextualized_call_sets_chunk_size_to_max(self):
        """chunk_size must be set to 32000 on every contextualized call."""
        with patch("voyageai.Client") as client_class:
            client = client_class.return_value
            client.contextualized_embed.return_value = (
                self._contextualized_response([0.1, 0.2])
            )

            self._embedding_function(self.CONTEXT_MODEL)(["aa"])

            _, kwargs = client.contextualized_embed.call_args
            assert kwargs["chunk_size"] == CONTEXTUALIZED_CHUNK_SIZE
            # Pin the constant itself so an accidental change is caught here.
            assert CONTEXTUALIZED_CHUNK_SIZE == 32000

    def test_contextualized_input_is_flat_list(self):
        """Input must be passed as a flat List[str], not wrapped in an extra list."""
        with patch("voyageai.Client") as client_class:
            client = client_class.return_value
            client.contextualized_embed.return_value = (
                self._contextualized_response([0.1, 0.2], [0.3, 0.4])
            )

            self._embedding_function(self.CONTEXT_MODEL)(["aa", "bb"])

            _, kwargs = client.contextualized_embed.call_args
            assert kwargs["inputs"] == ["aa", "bb"]

    def test_contextualized_string_input_normalized_to_flat_list(self):
        """A single string input is normalized to a flat list of one string."""
        with patch("voyageai.Client") as client_class:
            client = client_class.return_value
            client.contextualized_embed.return_value = (
                self._contextualized_response([0.1, 0.2])
            )

            self._embedding_function(self.CONTEXT_MODEL)("aa")

            _, kwargs = client.contextualized_embed.call_args
            assert kwargs["inputs"] == ["aa"]
Loading