Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Any

from pydantic import BaseModel, Field
Expand All @@ -6,6 +7,26 @@
from crewai_tools.tools.rag.rag_tool import RagTool


_MYSQL_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_$]*$")


def _quote_mysql_table_name(table_name: str) -> str:
identifier_parts = table_name.split(".")
if (
not identifier_parts
or len(identifier_parts) > 2
or any(
not _MYSQL_IDENTIFIER_PATTERN.fullmatch(part) for part in identifier_parts
)
):
raise ValueError(
"MySQL table_name must be a valid table identifier or schema.table "
"identifier"
)
Comment on lines +22 to +25

return ".".join(f"`{part}`" for part in identifier_parts)


class MySQLSearchToolSchema(BaseModel):
"""Input for MySQLSearchTool."""

Expand All @@ -32,7 +53,8 @@ def add( # type: ignore[override]
table_name: str,
**kwargs: Any,
) -> None:
super().add(f"SELECT * FROM {table_name};", **kwargs) # noqa: S608
quoted_table_name = _quote_mysql_table_name(table_name)
super().add(f"SELECT * FROM {quoted_table_name};", **kwargs) # noqa: S608

def _run( # type: ignore[override]
self,
Expand Down
110 changes: 110 additions & 0 deletions lib/crewai-tools/tests/tools/test_mysql_search_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from unittest.mock import MagicMock, patch

import pytest

from crewai_tools.rag.data_types import DataType
from crewai_tools.tools.mysql_search_tool.mysql_search_tool import MySQLSearchTool
from crewai_tools.tools.rag.rag_tool import RagTool


@pytest.fixture
def mock_rag_client() -> MagicMock:
mock_client = MagicMock()
mock_client.get_or_create_collection = MagicMock(return_value=None)
mock_client.add_documents = MagicMock(return_value=None)
mock_client.search = MagicMock(return_value=[])
return mock_client


def create_mysql_search_tool(
mock_rag_client: MagicMock, table_name: str
) -> MySQLSearchTool:
with (
patch(
"crewai_tools.adapters.crewai_rag_adapter.get_rag_client",
return_value=mock_rag_client,
),
patch(
"crewai_tools.adapters.crewai_rag_adapter.create_client",
return_value=mock_rag_client,
),
):
return MySQLSearchTool(
db_uri="mysql://user:password@localhost:3306/test_database",
table_name=table_name,
)


@pytest.mark.parametrize(
("table_name", "expected_query"),
[
("users", "SELECT * FROM `users`;"),
("user_profiles_2026", "SELECT * FROM `user_profiles_2026`;"),
("schema_name.users", "SELECT * FROM `schema_name`.`users`;"),
("information_schema.tables", "SELECT * FROM `information_schema`.`tables`;"),
],
)
def test_mysql_search_tool_quotes_valid_table_identifiers(
mock_rag_client: MagicMock, table_name: str, expected_query: str
) -> None:
with patch.object(RagTool, "add", return_value=None) as mock_add:
create_mysql_search_tool(mock_rag_client, table_name)

mock_add.assert_called_once_with(
expected_query,
data_type=DataType.MYSQL,
metadata={"db_uri": "mysql://user:password@localhost:3306/test_database"},
)


@pytest.mark.parametrize(
"table_name",
[
"users where 1=1",
"users; drop table users;--",
"users -- comment",
"users/*comment*/",
"`users`",
"schema.users.extra",
"schema.",
".users",
"123users",
],
)
def test_mysql_search_tool_rejects_invalid_table_identifiers(
mock_rag_client: MagicMock, table_name: str
) -> None:
with (
patch.object(RagTool, "add", return_value=None) as mock_add,
pytest.raises(ValueError, match="MySQL table_name must be a valid"),
):
create_mysql_search_tool(mock_rag_client, table_name)

mock_add.assert_not_called()


def test_mysql_search_tool_still_runs_search_queries(
mock_rag_client: MagicMock,
) -> None:
with patch.object(RagTool, "add", return_value=None):
tool = create_mysql_search_tool(mock_rag_client, "users")

with patch.object(RagTool, "_run", return_value="Alice") as mock_run:
result = tool._run("alice")

assert "Alice" in result
mock_run.assert_called_once_with(
query="alice", similarity_threshold=None, limit=None
)


def test_mysql_search_tool_uses_mysql_data_type_metadata(
mock_rag_client: MagicMock,
) -> None:
with patch.object(RagTool, "add", return_value=None) as mock_add:
create_mysql_search_tool(mock_rag_client, "users")

assert mock_add.call_args.kwargs == {
"data_type": DataType.MYSQL,
"metadata": {"db_uri": "mysql://user:password@localhost:3306/test_database"},
}
Loading