diff --git a/python/uagents-core/uagents_core/utils/resolver.py b/python/uagents-core/uagents_core/utils/resolver.py index 05262316..5e29e935 100644 --- a/python/uagents-core/uagents_core/utils/resolver.py +++ b/python/uagents-core/uagents_core/utils/resolver.py @@ -1,6 +1,8 @@ """This module provides methods to resolve an agent address.""" import urllib.parse +import ipaddress +import socket from typing import Any import requests @@ -19,6 +21,47 @@ logger = get_logger("uagents_core.utils.resolver") +def is_safe_url(url: str) -> bool: + """ + Check if a URL is safe for an agent to connect to. + Prevents SSRF by blocking private/internal IP addresses and cloud metadata. + """ + try: + parsed_url = urllib.parse.urlsplit(url) + hostname = parsed_url.hostname + if not hostname: + return False + + # Block common local hostnames + if hostname.lower() in ["localhost", "127.0.0.1", "::1"]: + return False + + # Check if hostname is an IP address + try: + ip = ipaddress.ip_address(hostname) + # is_private covers RFC 1918 (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16) + # it also covers loopback and link-local (169.254.0.0/16) + return not (ip.is_private or ip.is_loopback or ip.is_link_local) + except ValueError: + # If not an IP address, resolve hostname to check IPs + # Note: This doesn't prevent DNS Rebinding, but adds a basic layer + try: + # Basic check: hostname shouldn't resolve to a private IP + addr_info = socket.getaddrinfo(hostname, None) + for info in addr_info: + ip_str = info[4][0] + ip = ipaddress.ip_address(ip_str) + if ip.is_private or ip.is_loopback or ip.is_link_local: + return False + except socket.gaierror: + # If resolution fails, we let it pass for now as the request will fail anyway + pass + + return True + except Exception: + return False + + def lookup_address_for_domain( agent_identifier: str, *, @@ -69,7 +112,7 @@ def lookup_endpoint_for_agent( max_endpoints: int = DEFAULT_MAX_ENDPOINTS, agentverse_config: AgentverseConfig | None = None, ) -> list[str]: - """ + \"\"\" Resolve the endpoints for an agent using the Almanac API. Args: @@ -77,7 +120,7 @@ def lookup_endpoint_for_agent( Returns: List[str]: The endpoint(s) for the agent. - """ + \"\"\" agentverse_config = agentverse_config or AgentverseConfig() almanac_api = urllib.parse.urljoin(agentverse_config.url, DEFAULT_ALMANAC_API_PATH) @@ -124,12 +167,26 @@ def lookup_endpoint_for_agent( endpoints: list = response.json().get("endpoints", []) if len(endpoints) > 0: - urls = [val.get("url") for val in endpoints] - weights = [val.get("weight") for val in endpoints] + safe_endpoints = [] + for val in endpoints: + url = val.get("url") + if is_safe_url(url): + safe_endpoints.append(val) + else: + logger.warning( + "Skipping unsafe endpoint", + extra={"agent_address": agent_address, "url": url}, + ) + + if len(safe_endpoints) == 0: + return [] + + urls = [val.get("url") for val in safe_endpoints] + weights = [val.get("weight") for val in safe_endpoints] return weighted_random_sample( items=urls, weights=weights, - k=min(max_endpoints, len(endpoints)), + k=min(max_endpoints, len(safe_endpoints)), ) return []