diff --git a/src/openhuman/inference/provider/ops.rs b/src/openhuman/inference/provider/ops.rs index c6d02b7353..29e8185840 100644 --- a/src/openhuman/inference/provider/ops.rs +++ b/src/openhuman/inference/provider/ops.rs @@ -58,6 +58,7 @@ pub async fn list_configured_models( let api_key = crate::openhuman::inference::provider::factory::lookup_key_for_slug(&entry.slug, &config) .unwrap_or_default(); + let api_key = api_key.trim().to_string(); let client = crate::openhuman::config::build_runtime_proxy_client_with_timeouts( "providers.list_models", @@ -65,9 +66,13 @@ pub async fn list_configured_models( 10, ); + use crate::openhuman::config::schema::cloud_providers::AuthStyle; + if is_openrouter_provider(&entry) { + validate_openrouter_api_key(&client, base, &api_key).await?; + } + let mut request = client.get(&models_url); - use crate::openhuman::config::schema::cloud_providers::AuthStyle; request = match entry.auth_style { AuthStyle::Bearer => { if !api_key.is_empty() { @@ -182,6 +187,81 @@ pub async fn list_configured_models( )) } +fn is_openrouter_provider( + entry: &crate::openhuman::config::schema::cloud_providers::CloudProviderCreds, +) -> bool { + if entry.slug.eq_ignore_ascii_case("openrouter") { + return true; + } + + reqwest::Url::parse(&entry.endpoint) + .ok() + .and_then(|url| url.host_str().map(|host| host.to_ascii_lowercase())) + .is_some_and(|host| host == "openrouter.ai" || host.ends_with(".openrouter.ai")) +} + +async fn validate_openrouter_api_key( + client: &reqwest::Client, + base: &str, + api_key: &str, +) -> Result<(), String> { + if api_key.is_empty() { + return Err("OpenRouter API key is required before enabling the provider".to_string()); + } + + let key_url = format!("{}/key", base); + log::debug!("[providers][list_models] validating OpenRouter API key"); + let response = client + .get(&key_url) + .header("Authorization", format!("Bearer {api_key}")) + .send() + .await + .map_err(|e| format!("[providers][list_models] OpenRouter key validation failed: {e}"))?; + + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + if !status.is_success() { + let sanitized = sanitize_api_error(&text); + let truncated = crate::openhuman::util::truncate_with_ellipsis(&sanitized, 300); + log::debug!( + "[providers][list_models] OpenRouter key validation failed status={} body={}", + status.as_u16(), + truncated + ); + return Err(format!( + "OpenRouter key validation returned {}: {}", + status.as_u16(), + truncated + )); + } + + if let Ok(body) = serde_json::from_str::(&text) { + if let Some(err_field) = body.get("error") { + let msg = err_field + .as_str() + .map(|s| s.to_string()) + .or_else(|| { + err_field + .get("message") + .and_then(|m| m.as_str()) + .map(|s| s.to_string()) + }) + .unwrap_or_else(|| err_field.to_string()); + let sanitized = sanitize_api_error(&msg); + log::debug!( + "[providers][list_models] OpenRouter key validation returned error payload={}", + sanitized + ); + return Err(format!( + "OpenRouter key validation returned error payload: {}", + sanitized + )); + } + } + + Ok(()) +} + impl Default for ProviderRuntimeOptions { fn default() -> Self { Self { @@ -823,6 +903,169 @@ pub fn canonical_china_provider_name(_name: &str) -> Option<&'static str> { #[cfg(test)] mod tests { use super::*; + use crate::openhuman::config::schema::cloud_providers::{AuthStyle, CloudProviderCreds}; + use crate::openhuman::config::Config; + use crate::openhuman::credentials::AuthService; + use axum::{ + extract::State, + http::{HeaderMap, StatusCode}, + response::{IntoResponse, Response}, + routing::get, + Json, Router, + }; + use std::collections::HashMap; + use std::sync::{ + atomic::{AtomicUsize, Ordering as AtomicOrdering}, + Arc, Mutex, + }; + use tempfile::TempDir; + + #[derive(Clone)] + struct ModelProbeState { + key_status: StatusCode, + key_calls: Arc, + model_calls: Arc, + key_authorization: Arc>>>, + model_authorization: Arc>>>, + } + + struct WorkspaceEnvGuard { + prev: Option, + _lock: std::sync::MutexGuard<'static, ()>, + } + + impl Drop for WorkspaceEnvGuard { + fn drop(&mut self) { + unsafe { + match self.prev.take() { + Some(value) => std::env::set_var("OPENHUMAN_WORKSPACE", value), + None => std::env::remove_var("OPENHUMAN_WORKSPACE"), + } + } + } + } + + fn set_workspace_env(path: &std::path::Path) -> WorkspaceEnvGuard { + let lock = crate::openhuman::config::TEST_ENV_LOCK + .lock() + .unwrap_or_else(|e| e.into_inner()); + let prev = std::env::var_os("OPENHUMAN_WORKSPACE"); + unsafe { + std::env::set_var("OPENHUMAN_WORKSPACE", path); + } + WorkspaceEnvGuard { prev, _lock: lock } + } + + async fn openrouter_key_handler( + State(state): State, + headers: HeaderMap, + ) -> Response { + state.key_calls.fetch_add(1, AtomicOrdering::SeqCst); + state + .key_authorization + .lock() + .unwrap_or_else(|e| e.into_inner()) + .push(authorization_header(&headers)); + if state.key_status.is_success() { + Json(serde_json::json!({ + "data": { + "label": "test-key", + "usage": 0 + } + })) + .into_response() + } else { + ( + state.key_status, + Json(serde_json::json!({ + "error": { + "message": "No auth credentials found" + } + })), + ) + .into_response() + } + } + + async fn models_handler(State(state): State, headers: HeaderMap) -> Response { + state.model_calls.fetch_add(1, AtomicOrdering::SeqCst); + state + .model_authorization + .lock() + .unwrap_or_else(|e| e.into_inner()) + .push(authorization_header(&headers)); + Json(serde_json::json!({ + "data": [{ + "id": "openrouter/test-model", + "owned_by": "openrouter", + "context_length": 128000 + }] + })) + .into_response() + } + + fn authorization_header(headers: &HeaderMap) -> Option { + headers + .get("authorization") + .and_then(|value| value.to_str().ok()) + .map(|value| value.to_string()) + } + + async fn spawn_openrouter_probe_server(key_status: StatusCode) -> (String, ModelProbeState) { + let state = ModelProbeState { + key_status, + key_calls: Arc::new(AtomicUsize::new(0)), + model_calls: Arc::new(AtomicUsize::new(0)), + key_authorization: Arc::new(Mutex::new(Vec::new())), + model_authorization: Arc::new(Mutex::new(Vec::new())), + }; + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind"); + let addr = listener.local_addr().expect("local_addr"); + let app = Router::new() + .route("/key", get(openrouter_key_handler)) + .route("/models", get(models_handler)) + .with_state(state.clone()); + tokio::spawn(async move { + axum::serve(listener, app).await.expect("serve"); + }); + (format!("http://{addr}"), state) + } + + async fn configure_openrouter_workspace( + tmp: &TempDir, + endpoint: String, + token: &str, + ) -> Config { + let mut config = Config { + config_path: tmp.path().join("config.toml"), + workspace_dir: tmp.path().join("workspace"), + ..Config::default() + }; + config.secrets.encrypt = false; + config.cloud_providers.push(CloudProviderCreds { + id: "p_openrouter_test".to_string(), + slug: "openrouter".to_string(), + label: "OpenRouter".to_string(), + endpoint, + auth_style: AuthStyle::Bearer, + legacy_type: None, + default_model: None, + }); + config.save().await.expect("save config"); + + let auth = AuthService::from_config(&config); + auth.store_provider_token( + &crate::openhuman::inference::provider::factory::auth_key_for_slug("openrouter"), + "default", + token, + HashMap::new(), + true, + ) + .expect("store provider key"); + config + } #[test] fn list_configured_models_accepts_slug() { @@ -863,6 +1106,106 @@ mod tests { assert!(found_by_id.is_some(), "id lookup must still work"); } + #[test] + fn openrouter_detection_matches_builtin_slug_or_host() { + let provider = |slug: &str, endpoint: &str| CloudProviderCreds { + id: format!("p_{slug}"), + slug: slug.to_string(), + label: slug.to_string(), + endpoint: endpoint.to_string(), + auth_style: AuthStyle::Bearer, + legacy_type: None, + default_model: None, + }; + + assert!(is_openrouter_provider(&provider( + "openrouter", + "http://127.0.0.1:1234" + ))); + assert!(is_openrouter_provider(&provider( + "custom-router", + "https://openrouter.ai/api/v1" + ))); + assert!(is_openrouter_provider(&provider( + "custom-router", + "https://oauth.openrouter.ai/api/v1" + ))); + assert!(!is_openrouter_provider(&provider( + "custom-openai", + "https://api.openai.com/v1" + ))); + } + + #[tokio::test] + async fn openrouter_invalid_key_fails_before_models_catalog_probe() { + let tmp = tempfile::tempdir().expect("tempdir"); + let _env = set_workspace_env(tmp.path()); + let (endpoint, state) = spawn_openrouter_probe_server(StatusCode::UNAUTHORIZED).await; + configure_openrouter_workspace(&tmp, endpoint, "bad-openrouter-key").await; + + let err = list_configured_models("openrouter") + .await + .expect_err("invalid OpenRouter key must fail"); + + assert!( + err.contains("OpenRouter key validation returned 401"), + "unexpected error: {err}" + ); + assert_eq!(state.key_calls.load(AtomicOrdering::SeqCst), 1); + assert_eq!( + state.model_calls.load(AtomicOrdering::SeqCst), + 0, + "invalid OpenRouter credentials must not fall through to /models" + ); + } + + #[tokio::test] + async fn openrouter_valid_key_allows_models_catalog_probe() { + let tmp = tempfile::tempdir().expect("tempdir"); + let _env = set_workspace_env(tmp.path()); + let (endpoint, state) = spawn_openrouter_probe_server(StatusCode::OK).await; + configure_openrouter_workspace(&tmp, endpoint, "valid-openrouter-key").await; + + let outcome = list_configured_models("openrouter") + .await + .expect("valid OpenRouter key should list models"); + + assert_eq!(state.key_calls.load(AtomicOrdering::SeqCst), 1); + assert_eq!(state.model_calls.load(AtomicOrdering::SeqCst), 1); + assert_eq!(outcome.value["models"][0]["id"], "openrouter/test-model"); + } + + #[tokio::test] + async fn openrouter_key_is_trimmed_for_validation_and_catalog_probe() { + let tmp = tempfile::tempdir().expect("tempdir"); + let _env = set_workspace_env(tmp.path()); + let (endpoint, state) = spawn_openrouter_probe_server(StatusCode::OK).await; + configure_openrouter_workspace(&tmp, endpoint, " valid-openrouter-key\r\n").await; + + list_configured_models("openrouter") + .await + .expect("trimmed OpenRouter key should list models"); + + let key_authorization = state + .key_authorization + .lock() + .unwrap_or_else(|e| e.into_inner()) + .clone(); + let model_authorization = state + .model_authorization + .lock() + .unwrap_or_else(|e| e.into_inner()) + .clone(); + assert_eq!( + key_authorization, + vec![Some("Bearer valid-openrouter-key".to_string())] + ); + assert_eq!( + model_authorization, + vec![Some("Bearer valid-openrouter-key".to_string())] + ); + } + #[test] fn factory_backend() { assert!(create_backend_inference_provider(