Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
345 changes: 344 additions & 1 deletion src/openhuman/inference/provider/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,21 @@ pub async fn list_configured_models(
let api_key =
crate::openhuman::inference::provider::factory::lookup_key_for_slug(&entry.slug, &config)
.unwrap_or_default();
let api_key = api_key.trim().to_string();

let client = crate::openhuman::config::build_runtime_proxy_client_with_timeouts(
"providers.list_models",
30,
10,
);

use crate::openhuman::config::schema::cloud_providers::AuthStyle;
if is_openrouter_provider(&entry) {
validate_openrouter_api_key(&client, base, &api_key).await?;
}

let mut request = client.get(&models_url);

use crate::openhuman::config::schema::cloud_providers::AuthStyle;
request = match entry.auth_style {
AuthStyle::Bearer => {
if !api_key.is_empty() {
Expand Down Expand Up @@ -182,6 +187,81 @@ pub async fn list_configured_models(
))
}

fn is_openrouter_provider(
entry: &crate::openhuman::config::schema::cloud_providers::CloudProviderCreds,
) -> bool {
if entry.slug.eq_ignore_ascii_case("openrouter") {
return true;
}

reqwest::Url::parse(&entry.endpoint)
.ok()
.and_then(|url| url.host_str().map(|host| host.to_ascii_lowercase()))
.is_some_and(|host| host == "openrouter.ai" || host.ends_with(".openrouter.ai"))
}

async fn validate_openrouter_api_key(
client: &reqwest::Client,
base: &str,
api_key: &str,
) -> Result<(), String> {
if api_key.is_empty() {
return Err("OpenRouter API key is required before enabling the provider".to_string());
}

let key_url = format!("{}/key", base);
log::debug!("[providers][list_models] validating OpenRouter API key");
let response = client
.get(&key_url)
.header("Authorization", format!("Bearer {api_key}"))
.send()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
.await
.map_err(|e| format!("[providers][list_models] OpenRouter key validation failed: {e}"))?;

let status = response.status();
let text = response.text().await.unwrap_or_default();
if !status.is_success() {
let sanitized = sanitize_api_error(&text);
let truncated = crate::openhuman::util::truncate_with_ellipsis(&sanitized, 300);
log::debug!(
"[providers][list_models] OpenRouter key validation failed status={} body={}",
status.as_u16(),
truncated
);
return Err(format!(
"OpenRouter key validation returned {}: {}",
status.as_u16(),
truncated
));
}

if let Ok(body) = serde_json::from_str::<serde_json::Value>(&text) {
if let Some(err_field) = body.get("error") {
let msg = err_field
.as_str()
.map(|s| s.to_string())
.or_else(|| {
err_field
.get("message")
.and_then(|m| m.as_str())
.map(|s| s.to_string())
})
.unwrap_or_else(|| err_field.to_string());
let sanitized = sanitize_api_error(&msg);
log::debug!(
"[providers][list_models] OpenRouter key validation returned error payload={}",
sanitized
);
return Err(format!(
"OpenRouter key validation returned error payload: {}",
sanitized
));
}
}

Ok(())
}

impl Default for ProviderRuntimeOptions {
fn default() -> Self {
Self {
Expand Down Expand Up @@ -823,6 +903,169 @@ pub fn canonical_china_provider_name(_name: &str) -> Option<&'static str> {
#[cfg(test)]
mod tests {
use super::*;
use crate::openhuman::config::schema::cloud_providers::{AuthStyle, CloudProviderCreds};
use crate::openhuman::config::Config;
use crate::openhuman::credentials::AuthService;
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
routing::get,
Json, Router,
};
use std::collections::HashMap;
use std::sync::{
atomic::{AtomicUsize, Ordering as AtomicOrdering},
Arc, Mutex,
};
use tempfile::TempDir;

#[derive(Clone)]
struct ModelProbeState {
key_status: StatusCode,
key_calls: Arc<AtomicUsize>,
model_calls: Arc<AtomicUsize>,
key_authorization: Arc<Mutex<Vec<Option<String>>>>,
model_authorization: Arc<Mutex<Vec<Option<String>>>>,
}

struct WorkspaceEnvGuard {
prev: Option<std::ffi::OsString>,
_lock: std::sync::MutexGuard<'static, ()>,
}

impl Drop for WorkspaceEnvGuard {
fn drop(&mut self) {
unsafe {
match self.prev.take() {
Some(value) => std::env::set_var("OPENHUMAN_WORKSPACE", value),
None => std::env::remove_var("OPENHUMAN_WORKSPACE"),
}
}
}
}

fn set_workspace_env(path: &std::path::Path) -> WorkspaceEnvGuard {
let lock = crate::openhuman::config::TEST_ENV_LOCK
.lock()
.unwrap_or_else(|e| e.into_inner());
let prev = std::env::var_os("OPENHUMAN_WORKSPACE");
unsafe {
std::env::set_var("OPENHUMAN_WORKSPACE", path);
}
WorkspaceEnvGuard { prev, _lock: lock }
}

async fn openrouter_key_handler(
State(state): State<ModelProbeState>,
headers: HeaderMap,
) -> Response {
state.key_calls.fetch_add(1, AtomicOrdering::SeqCst);
state
.key_authorization
.lock()
.unwrap_or_else(|e| e.into_inner())
.push(authorization_header(&headers));
if state.key_status.is_success() {
Json(serde_json::json!({
"data": {
"label": "test-key",
"usage": 0
}
}))
.into_response()
} else {
(
state.key_status,
Json(serde_json::json!({
"error": {
"message": "No auth credentials found"
}
})),
)
.into_response()
}
}

async fn models_handler(State(state): State<ModelProbeState>, headers: HeaderMap) -> Response {
state.model_calls.fetch_add(1, AtomicOrdering::SeqCst);
state
.model_authorization
.lock()
.unwrap_or_else(|e| e.into_inner())
.push(authorization_header(&headers));
Json(serde_json::json!({
"data": [{
"id": "openrouter/test-model",
"owned_by": "openrouter",
"context_length": 128000
}]
}))
.into_response()
}

fn authorization_header(headers: &HeaderMap) -> Option<String> {
headers
.get("authorization")
.and_then(|value| value.to_str().ok())
.map(|value| value.to_string())
}

async fn spawn_openrouter_probe_server(key_status: StatusCode) -> (String, ModelProbeState) {
let state = ModelProbeState {
key_status,
key_calls: Arc::new(AtomicUsize::new(0)),
model_calls: Arc::new(AtomicUsize::new(0)),
key_authorization: Arc::new(Mutex::new(Vec::new())),
model_authorization: Arc::new(Mutex::new(Vec::new())),
};
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind");
let addr = listener.local_addr().expect("local_addr");
let app = Router::new()
.route("/key", get(openrouter_key_handler))
.route("/models", get(models_handler))
.with_state(state.clone());
tokio::spawn(async move {
axum::serve(listener, app).await.expect("serve");
});
(format!("http://{addr}"), state)
}

async fn configure_openrouter_workspace(
tmp: &TempDir,
endpoint: String,
token: &str,
) -> Config {
let mut config = Config {
config_path: tmp.path().join("config.toml"),
workspace_dir: tmp.path().join("workspace"),
..Config::default()
};
config.secrets.encrypt = false;
config.cloud_providers.push(CloudProviderCreds {
id: "p_openrouter_test".to_string(),
slug: "openrouter".to_string(),
label: "OpenRouter".to_string(),
endpoint,
auth_style: AuthStyle::Bearer,
legacy_type: None,
default_model: None,
});
config.save().await.expect("save config");

let auth = AuthService::from_config(&config);
auth.store_provider_token(
&crate::openhuman::inference::provider::factory::auth_key_for_slug("openrouter"),
"default",
token,
HashMap::new(),
true,
)
.expect("store provider key");
config
}

#[test]
fn list_configured_models_accepts_slug() {
Expand Down Expand Up @@ -863,6 +1106,106 @@ mod tests {
assert!(found_by_id.is_some(), "id lookup must still work");
}

#[test]
fn openrouter_detection_matches_builtin_slug_or_host() {
let provider = |slug: &str, endpoint: &str| CloudProviderCreds {
id: format!("p_{slug}"),
slug: slug.to_string(),
label: slug.to_string(),
endpoint: endpoint.to_string(),
auth_style: AuthStyle::Bearer,
legacy_type: None,
default_model: None,
};

assert!(is_openrouter_provider(&provider(
"openrouter",
"http://127.0.0.1:1234"
)));
assert!(is_openrouter_provider(&provider(
"custom-router",
"https://openrouter.ai/api/v1"
)));
assert!(is_openrouter_provider(&provider(
"custom-router",
"https://oauth.openrouter.ai/api/v1"
)));
assert!(!is_openrouter_provider(&provider(
"custom-openai",
"https://api.openai.com/v1"
)));
}

#[tokio::test]
async fn openrouter_invalid_key_fails_before_models_catalog_probe() {
let tmp = tempfile::tempdir().expect("tempdir");
let _env = set_workspace_env(tmp.path());
let (endpoint, state) = spawn_openrouter_probe_server(StatusCode::UNAUTHORIZED).await;
configure_openrouter_workspace(&tmp, endpoint, "bad-openrouter-key").await;

let err = list_configured_models("openrouter")
.await
.expect_err("invalid OpenRouter key must fail");

assert!(
err.contains("OpenRouter key validation returned 401"),
"unexpected error: {err}"
);
assert_eq!(state.key_calls.load(AtomicOrdering::SeqCst), 1);
assert_eq!(
state.model_calls.load(AtomicOrdering::SeqCst),
0,
"invalid OpenRouter credentials must not fall through to /models"
);
}

#[tokio::test]
async fn openrouter_valid_key_allows_models_catalog_probe() {
let tmp = tempfile::tempdir().expect("tempdir");
let _env = set_workspace_env(tmp.path());
let (endpoint, state) = spawn_openrouter_probe_server(StatusCode::OK).await;
configure_openrouter_workspace(&tmp, endpoint, "valid-openrouter-key").await;

let outcome = list_configured_models("openrouter")
.await
.expect("valid OpenRouter key should list models");

assert_eq!(state.key_calls.load(AtomicOrdering::SeqCst), 1);
assert_eq!(state.model_calls.load(AtomicOrdering::SeqCst), 1);
assert_eq!(outcome.value["models"][0]["id"], "openrouter/test-model");
}

#[tokio::test]
async fn openrouter_key_is_trimmed_for_validation_and_catalog_probe() {
let tmp = tempfile::tempdir().expect("tempdir");
let _env = set_workspace_env(tmp.path());
let (endpoint, state) = spawn_openrouter_probe_server(StatusCode::OK).await;
configure_openrouter_workspace(&tmp, endpoint, " valid-openrouter-key\r\n").await;

list_configured_models("openrouter")
.await
.expect("trimmed OpenRouter key should list models");

let key_authorization = state
.key_authorization
.lock()
.unwrap_or_else(|e| e.into_inner())
.clone();
let model_authorization = state
.model_authorization
.lock()
.unwrap_or_else(|e| e.into_inner())
.clone();
assert_eq!(
key_authorization,
vec![Some("Bearer valid-openrouter-key".to_string())]
);
assert_eq!(
model_authorization,
vec![Some("Bearer valid-openrouter-key".to_string())]
);
}

#[test]
fn factory_backend() {
assert!(create_backend_inference_provider(
Expand Down
Loading