diff --git a/src/openhuman/agent/harness/session/turn.rs b/src/openhuman/agent/harness/session/turn.rs index 5c52aa2b23..eb5c39dbbe 100644 --- a/src/openhuman/agent/harness/session/turn.rs +++ b/src/openhuman/agent/harness/session/turn.rs @@ -25,7 +25,9 @@ use crate::openhuman::agent::harness; use crate::openhuman::agent::hooks::{self, ToolCallRecord, TurnContext}; use crate::openhuman::agent::memory_loader::collect_recall_citations; use crate::openhuman::agent::progress::AgentProgress; -use crate::openhuman::agent::tool_policy::{ToolPolicyDecision, ToolPolicyRequest}; +use crate::openhuman::agent::tool_policy::{ + ToolCallContext, ToolPolicyDecision, ToolPolicyRequest, +}; use crate::openhuman::agent_experience::{ prepend_experience_block, render_experience_hits, AgentExperienceStore, ExperienceQuery, }; @@ -1163,13 +1165,15 @@ impl Agent { false, ) } else { - let policy_request = ToolPolicyRequest { - tool_name: call.name.clone(), - arguments: call.arguments.clone(), - session_id: self.event_session_id().to_string(), - channel: self.event_channel().to_string(), - agent_definition_id: self.agent_definition_id.to_string(), - }; + let context = ToolCallContext::session( + self.event_session_id(), + self.event_channel(), + self.agent_definition_id.to_string(), + call_id.clone(), + (iteration + 1) as u32, + ); + let policy_request = + ToolPolicyRequest::new(call.name.clone(), call.arguments.clone(), context); if let ToolPolicyDecision::Deny { reason } = self.tool_policy.check(&policy_request).await { diff --git a/src/openhuman/agent/harness/session/turn_tests.rs b/src/openhuman/agent/harness/session/turn_tests.rs index ec3cc0a6bd..ee55a518a2 100644 --- a/src/openhuman/agent/harness/session/turn_tests.rs +++ b/src/openhuman/agent/harness/session/turn_tests.rs @@ -142,9 +142,11 @@ impl ToolPolicy for DenyCountingPolicy { async fn check(&self, request: &ToolPolicyRequest) -> ToolPolicyDecision { assert_eq!(request.tool_name, "counting"); - assert_eq!(request.session_id, "turn-test-session"); - assert_eq!(request.channel, "turn-test-channel"); - assert_eq!(request.agent_definition_id, "main"); + assert_eq!(request.context.session_id, "turn-test-session"); + assert_eq!(request.context.channel, "turn-test-channel"); + assert_eq!(request.context.agent_definition_id, "main"); + assert_eq!(request.context.call_id, "policy-1"); + assert_eq!(request.context.iteration, 1); ToolPolicyDecision::deny("locked by test policy") } } diff --git a/src/openhuman/agent/tool_policy.rs b/src/openhuman/agent/tool_policy.rs index 3028ad65c0..e9e26049ac 100644 --- a/src/openhuman/agent/tool_policy.rs +++ b/src/openhuman/agent/tool_policy.rs @@ -5,17 +5,126 @@ //! deny a tool before any side effect reaches the tool implementation. use async_trait::async_trait; +use std::fmt; + +/// Structured context for a tool call before it reaches the tool +/// implementation. +#[derive(Clone, PartialEq, Eq)] +pub struct ToolCallContext { + pub session_id: String, + pub channel: String, + pub agent_definition_id: String, + pub call_id: String, + pub iteration: u32, + pub source: ToolCallSource, +} + +impl ToolCallContext { + pub fn session( + session_id: impl Into, + channel: impl Into, + agent_definition_id: impl Into, + call_id: impl Into, + iteration: u32, + ) -> Self { + Self { + session_id: session_id.into(), + channel: channel.into(), + agent_definition_id: agent_definition_id.into(), + call_id: call_id.into(), + iteration, + source: ToolCallSource::Session, + } + } +} + +impl fmt::Debug for ToolCallContext { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ToolCallContext") + .field("session_id", &redact_for_debug(&self.session_id)) + .field("channel", &redact_for_debug(&self.channel)) + .field("agent_definition_id", &self.agent_definition_id) + .field("call_id", &self.call_id) + .field("iteration", &self.iteration) + .field("source", &self.source) + .finish() + } +} + +/// Entry point that produced a tool call. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[allow(dead_code)] // Reserved for non-session tool ingress paths wired in follow-up PRs. +pub enum ToolCallSource { + Session, + Bus, + Channel, + Cron, + Webhook, + Unknown, +} /// Snapshot of the tool call and session context a policy can inspect. -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct ToolPolicyRequest { pub tool_name: String, pub arguments: serde_json::Value, + pub context: ToolCallContext, + /// Backward-compatible mirror of `context.session_id`. + #[deprecated(note = "use context.session_id")] pub session_id: String, + /// Backward-compatible mirror of `context.channel`. + #[deprecated(note = "use context.channel")] pub channel: String, + /// Backward-compatible mirror of `context.agent_definition_id`. + #[deprecated(note = "use context.agent_definition_id")] pub agent_definition_id: String, } +impl fmt::Debug for ToolPolicyRequest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + #[allow(deprecated)] + { + f.debug_struct("ToolPolicyRequest") + .field("tool_name", &self.tool_name) + .field("arguments", &"") + .field("context", &self.context) + .field("session_id", &redact_for_debug(&self.session_id)) + .field("channel", &redact_for_debug(&self.channel)) + .field("agent_definition_id", &self.agent_definition_id) + .finish() + } + } +} + +impl ToolPolicyRequest { + pub fn new( + tool_name: impl Into, + arguments: serde_json::Value, + context: ToolCallContext, + ) -> Self { + #[allow(deprecated)] + { + Self { + tool_name: tool_name.into(), + arguments, + session_id: context.session_id.clone(), + channel: context.channel.clone(), + agent_definition_id: context.agent_definition_id.clone(), + context, + } + } + } +} + +fn redact_for_debug(value: &str) -> String { + let trimmed = value.trim(); + if trimmed.is_empty() { + return "".to_string(); + } + let prefix: String = trimmed.chars().take(4).collect(); + format!("{prefix}...") +} + /// Decision returned by a [`ToolPolicy`]. #[derive(Debug, Clone, PartialEq, Eq)] pub enum ToolPolicyDecision { @@ -63,14 +172,45 @@ mod tests { #[tokio::test] async fn allow_all_policy_allows_every_call() { let policy = AllowAllToolPolicy; - let request = ToolPolicyRequest { - tool_name: "echo".into(), - arguments: serde_json::json!({ "value": 1 }), - session_id: "session".into(), - channel: "chat".into(), - agent_definition_id: "orchestrator".into(), - }; + let request = ToolPolicyRequest::new( + "echo", + serde_json::json!({ "value": 1 }), + ToolCallContext::session("session", "chat", "orchestrator", "call-1", 1), + ); assert_eq!(policy.check(&request).await, ToolPolicyDecision::Allow); + #[allow(deprecated)] + { + assert_eq!(request.session_id, request.context.session_id); + assert_eq!(request.channel, request.context.channel); + assert_eq!( + request.agent_definition_id, + request.context.agent_definition_id + ); + } + assert_eq!(request.context.source, ToolCallSource::Session); + assert_eq!(request.context.call_id, "call-1"); + } + + #[test] + fn debug_redacts_sensitive_context_fields() { + let request = ToolPolicyRequest::new( + "secrets.lookup", + serde_json::json!({ "secret": "super-secret-token" }), + ToolCallContext::session( + "session-secret-123", + "private-channel", + "orchestrator", + "call-1", + 1, + ), + ); + + let rendered = format!("{request:?}"); + assert!(rendered.contains("sess...")); + assert!(rendered.contains("priv...")); + assert!(!rendered.contains("session-secret-123")); + assert!(!rendered.contains("private-channel")); + assert!(!rendered.contains("super-secret-token")); } }