Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 62 additions & 5 deletions python/uagents-core/uagents_core/utils/resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""This module provides methods to resolve an agent address."""

import urllib.parse
import ipaddress
import socket
from typing import Any

import requests
Expand All @@ -19,6 +21,47 @@
logger = get_logger("uagents_core.utils.resolver")


def is_safe_url(url: str) -> bool:
"""
Check if a URL is safe for an agent to connect to.
Prevents SSRF by blocking private/internal IP addresses and cloud metadata.
"""
try:
parsed_url = urllib.parse.urlsplit(url)
hostname = parsed_url.hostname
if not hostname:
return False

# Block common local hostnames
if hostname.lower() in ["localhost", "127.0.0.1", "::1"]:
return False

# Check if hostname is an IP address
try:
ip = ipaddress.ip_address(hostname)
# is_private covers RFC 1918 (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16)
# it also covers loopback and link-local (169.254.0.0/16)
return not (ip.is_private or ip.is_loopback or ip.is_link_local)
except ValueError:
# If not an IP address, resolve hostname to check IPs
# Note: This doesn't prevent DNS Rebinding, but adds a basic layer
try:
# Basic check: hostname shouldn't resolve to a private IP
addr_info = socket.getaddrinfo(hostname, None)
for info in addr_info:
ip_str = info[4][0]
ip = ipaddress.ip_address(ip_str)
if ip.is_private or ip.is_loopback or ip.is_link_local:
return False
except socket.gaierror:
# If resolution fails, we let it pass for now as the request will fail anyway
pass

return True
except Exception:
return False


def lookup_address_for_domain(
agent_identifier: str,
*,
Expand Down Expand Up @@ -69,15 +112,15 @@ def lookup_endpoint_for_agent(
max_endpoints: int = DEFAULT_MAX_ENDPOINTS,
agentverse_config: AgentverseConfig | None = None,
) -> list[str]:
"""
\"\"\"
Resolve the endpoints for an agent using the Almanac API.

Args:
destination (str): The destination address to look up.

Returns:
List[str]: The endpoint(s) for the agent.
"""
\"\"\"

agentverse_config = agentverse_config or AgentverseConfig()
almanac_api = urllib.parse.urljoin(agentverse_config.url, DEFAULT_ALMANAC_API_PATH)
Expand Down Expand Up @@ -124,12 +167,26 @@ def lookup_endpoint_for_agent(
endpoints: list = response.json().get("endpoints", [])

if len(endpoints) > 0:
urls = [val.get("url") for val in endpoints]
weights = [val.get("weight") for val in endpoints]
safe_endpoints = []
for val in endpoints:
url = val.get("url")
if is_safe_url(url):
safe_endpoints.append(val)
else:
logger.warning(
"Skipping unsafe endpoint",
extra={"agent_address": agent_address, "url": url},
)

if len(safe_endpoints) == 0:
return []

urls = [val.get("url") for val in safe_endpoints]
weights = [val.get("weight") for val in safe_endpoints]
return weighted_random_sample(
items=urls,
weights=weights,
k=min(max_endpoints, len(endpoints)),
k=min(max_endpoints, len(safe_endpoints)),
)

return []
Expand Down
Loading