diff --git a/src/openhuman/agent/harness/session/turn.rs b/src/openhuman/agent/harness/session/turn.rs index d6ce84e20d..60d1b20d59 100644 --- a/src/openhuman/agent/harness/session/turn.rs +++ b/src/openhuman/agent/harness/session/turn.rs @@ -1137,20 +1137,28 @@ impl Agent { call_id.clone(), (iteration + 1) as u32, ); - let policy_request = + let mut policy_request = ToolPolicyRequest::new(call.name.clone(), call.arguments.clone(), context); - if let ToolPolicyDecision::Deny { reason } = - self.tool_policy.check(&policy_request).await - { + if let Some(generated_context) = tool.generated_runtime_context(&call.arguments) { + policy_request = policy_request.with_generated_tool_context(generated_context); + } + let policy_decision = self.tool_policy.check(&policy_request).await; + if let Some(reason) = policy_decision.blocking_reason() { + let blocked_action = match &policy_decision { + ToolPolicyDecision::RequireApproval { .. } => "requires approval", + ToolPolicyDecision::Deny { .. } => "denied", + ToolPolicyDecision::Allow => "allowed", + }; tracing::debug!( tool = call.name.as_str(), policy = self.tool_policy.name(), + action = blocked_action, reason = %reason, - "[agent_loop] tool denied by policy" + "[agent_loop] tool blocked by policy" ); ( format!( - "Tool '{}' denied by policy '{}': {reason}", + "Tool '{}' {blocked_action} by policy '{}': {reason}", call.name, self.tool_policy.name() ), diff --git a/src/openhuman/agent/harness/session/turn_tests.rs b/src/openhuman/agent/harness/session/turn_tests.rs index 741a00a2ae..b3e81dd17f 100644 --- a/src/openhuman/agent/harness/session/turn_tests.rs +++ b/src/openhuman/agent/harness/session/turn_tests.rs @@ -3,7 +3,10 @@ use crate::core::event_bus::{global, init_global, DomainEvent}; use crate::openhuman::agent::dispatcher::XmlToolDispatcher; use crate::openhuman::agent::hooks::{PostTurnHook, TurnContext}; use crate::openhuman::agent::memory_loader::MemoryLoader; -use crate::openhuman::agent::tool_policy::{ToolPolicy, ToolPolicyDecision, ToolPolicyRequest}; +use crate::openhuman::agent::tool_policy::{ + GeneratedToolRuntimeContext, GeneratedToolRuntimeRisk, ToolPolicy, ToolPolicyDecision, + ToolPolicyRequest, +}; use crate::openhuman::inference::provider::{ChatRequest, ChatResponse, Provider}; use crate::openhuman::memory::Memory; use crate::openhuman::tools::ToolResult; @@ -200,6 +203,64 @@ impl Tool for CountingWriteTool { } } +struct GeneratedContextTool { + calls: Arc, +} + +#[async_trait] +impl Tool for GeneratedContextTool { + fn name(&self) -> &str { + "generated_send" + } + + fn description(&self) -> &str { + "generated send" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({"type":"object"}) + } + + async fn execute(&self, _args: serde_json::Value) -> Result { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok(ToolResult::success("generated-output")) + } + + fn generated_runtime_context( + &self, + _args: &serde_json::Value, + ) -> Option { + Some(GeneratedToolRuntimeContext { + provider_id: "mail.runtime".to_string(), + capability_id: "email.send".to_string(), + risk: GeneratedToolRuntimeRisk::ExternalWrite, + source_digest: Some("sha256:abc".to_string()), + approval_id: Some("approval-1".to_string()), + }) + } +} + +struct RequireGeneratedContextPolicy; + +#[async_trait] +impl ToolPolicy for RequireGeneratedContextPolicy { + fn name(&self) -> &str { + "require_generated_context" + } + + async fn check(&self, request: &ToolPolicyRequest) -> ToolPolicyDecision { + let context = request + .generated_tool + .as_ref() + .expect("generated tool context should be threaded"); + assert_eq!(context.provider_id, "mail.runtime"); + assert_eq!(context.capability_id, "email.send"); + assert_eq!(context.risk, GeneratedToolRuntimeRisk::ExternalWrite); + assert_eq!(context.approval_id.as_deref(), Some("approval-1")); + ToolPolicyDecision::require_approval("generated context requires approval") + } +} + struct RecordingHook { calls: Arc>>, notify: Arc, @@ -594,6 +655,49 @@ async fn execute_tool_call_denies_by_policy_before_tool_runs() { assert!(!record.success); } +#[tokio::test] +async fn execute_tool_call_threads_generated_tool_context_into_policy() { + let workspace = tempfile::TempDir::new().expect("temp workspace"); + let workspace_path = workspace.path().to_path_buf(); + std::mem::forget(workspace); + let memory_cfg = crate::openhuman::config::MemoryConfig { + backend: "none".into(), + ..crate::openhuman::config::MemoryConfig::default() + }; + let mem: Arc = Arc::from( + crate::openhuman::memory_store::create_memory(&memory_cfg, &workspace_path).unwrap(), + ); + let calls = Arc::new(AtomicUsize::new(0)); + + let agent = Agent::builder() + .provider(Box::new(DummyProvider)) + .tools(vec![Box::new(GeneratedContextTool { + calls: Arc::clone(&calls), + })]) + .memory(mem) + .tool_dispatcher(Box::new(XmlToolDispatcher)) + .workspace_dir(workspace_path) + .event_context("turn-test-session", "turn-test-channel") + .tool_policy(Arc::new(RequireGeneratedContextPolicy)) + .build() + .unwrap(); + let call = ParsedToolCall { + name: "generated_send".into(), + arguments: serde_json::json!({ "value": 1 }), + tool_call_id: Some("policy-generated-1".into()), + }; + + let (result, record) = agent.execute_tool_call(&call, 0).await; + assert!(!result.success); + assert!(result.output.contains("requires approval by policy")); + assert!(result + .output + .contains("generated context requires approval")); + assert_eq!(calls.load(Ordering::SeqCst), 0); + assert_eq!(record.name, "generated_send"); + assert!(!record.success); +} + #[tokio::test] async fn turn_runs_full_tool_cycle_with_context_and_hooks() { let provider_impl = Arc::new(SequenceProvider { diff --git a/src/openhuman/agent/tool_policy.rs b/src/openhuman/agent/tool_policy.rs index e9e26049ac..3396e50cde 100644 --- a/src/openhuman/agent/tool_policy.rs +++ b/src/openhuman/agent/tool_policy.rs @@ -5,6 +5,7 @@ //! deny a tool before any side effect reaches the tool implementation. use async_trait::async_trait; +use std::collections::{BTreeMap, BTreeSet}; use std::fmt; /// Structured context for a tool call before it reaches the tool @@ -69,6 +70,7 @@ pub struct ToolPolicyRequest { pub tool_name: String, pub arguments: serde_json::Value, pub context: ToolCallContext, + pub generated_tool: Option, /// Backward-compatible mirror of `context.session_id`. #[deprecated(note = "use context.session_id")] pub session_id: String, @@ -88,6 +90,7 @@ impl fmt::Debug for ToolPolicyRequest { .field("tool_name", &self.tool_name) .field("arguments", &"") .field("context", &self.context) + .field("generated_tool", &self.generated_tool) .field("session_id", &redact_for_debug(&self.session_id)) .field("channel", &redact_for_debug(&self.channel)) .field("agent_definition_id", &self.agent_definition_id) @@ -111,9 +114,15 @@ impl ToolPolicyRequest { channel: context.channel.clone(), agent_definition_id: context.agent_definition_id.clone(), context, + generated_tool: None, } } } + + pub fn with_generated_tool_context(mut self, context: GeneratedToolRuntimeContext) -> Self { + self.generated_tool = Some(context); + self + } } fn redact_for_debug(value: &str) -> String { @@ -129,15 +138,39 @@ fn redact_for_debug(value: &str) -> String { #[derive(Debug, Clone, PartialEq, Eq)] pub enum ToolPolicyDecision { Allow, - Deny { reason: String }, + /// The policy requires an approval handoff before execution. + /// + /// Session execution currently treats this as fail-closed through + /// [`ToolPolicyDecision::blocking_reason`]. Callers that can prompt for + /// approval may branch on this variant and retry after approval is granted. + RequireApproval { + reason: String, + }, + Deny { + reason: String, + }, } impl ToolPolicyDecision { + pub fn require_approval(reason: impl Into) -> Self { + Self::RequireApproval { + reason: reason.into(), + } + } + pub fn deny(reason: impl Into) -> Self { Self::Deny { reason: reason.into(), } } + + /// Reason used by fail-closed executors that cannot complete approvals inline. + pub fn blocking_reason(&self) -> Option<&str> { + match self { + Self::Allow => None, + Self::RequireApproval { reason } | Self::Deny { reason } => Some(reason.as_str()), + } + } } /// Policy middleware invoked before an agent executes a tool. @@ -165,6 +198,196 @@ impl ToolPolicy for AllowAllToolPolicy { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GeneratedToolRuntimeContext { + pub provider_id: String, + pub capability_id: String, + pub risk: GeneratedToolRuntimeRisk, + pub source_digest: Option, + pub approval_id: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum GeneratedToolRuntimeRisk { + Read, + Write, + ExternalWrite, + Execute, + Dangerous, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RuntimeToolPolicyAction { + Allow, + RequireApproval, + Deny, +} + +#[derive(Debug, Clone, Default)] +pub struct GeneratedToolRuntimePolicyConfig { + pub enabled: bool, + pub revoked_providers: BTreeSet, + pub revoked_capabilities: BTreeSet, + pub provider_actions: BTreeMap, + pub capability_actions: BTreeMap, + pub risk_actions: BTreeMap, +} + +#[derive(Debug, Clone)] +pub struct GeneratedToolRuntimePolicy { + config: GeneratedToolRuntimePolicyConfig, +} + +impl GeneratedToolRuntimePolicy { + pub fn new(config: GeneratedToolRuntimePolicyConfig) -> Self { + Self { config } + } + + fn action_for( + &self, + tool_name: &str, + context: &GeneratedToolRuntimeContext, + ) -> (RuntimeToolPolicyAction, String) { + if self + .config + .revoked_providers + .contains(context.provider_id.as_str()) + { + tracing::debug!( + tool = tool_name, + provider_id = context.provider_id.as_str(), + capability_id = context.capability_id.as_str(), + risk = ?context.risk, + action = ?RuntimeToolPolicyAction::Deny, + "[generated_tool_runtime] provider revoked" + ); + return ( + RuntimeToolPolicyAction::Deny, + format!("provider `{}` is revoked", context.provider_id), + ); + } + if self + .config + .revoked_capabilities + .contains(context.capability_id.as_str()) + { + tracing::debug!( + tool = tool_name, + provider_id = context.provider_id.as_str(), + capability_id = context.capability_id.as_str(), + risk = ?context.risk, + action = ?RuntimeToolPolicyAction::Deny, + "[generated_tool_runtime] capability revoked" + ); + return ( + RuntimeToolPolicyAction::Deny, + format!("capability `{}` is revoked", context.capability_id), + ); + } + if let Some(action) = self.config.capability_actions.get(&context.capability_id) { + tracing::debug!( + tool = tool_name, + provider_id = context.provider_id.as_str(), + capability_id = context.capability_id.as_str(), + risk = ?context.risk, + action = ?action, + "[generated_tool_runtime] capability action matched" + ); + return ( + *action, + format!( + "capability `{}` matched runtime policy", + context.capability_id + ), + ); + } + if let Some(action) = self.config.provider_actions.get(&context.provider_id) { + tracing::debug!( + tool = tool_name, + provider_id = context.provider_id.as_str(), + capability_id = context.capability_id.as_str(), + risk = ?context.risk, + action = ?action, + "[generated_tool_runtime] provider action matched" + ); + return ( + *action, + format!("provider `{}` matched runtime policy", context.provider_id), + ); + } + if let Some(action) = self.config.risk_actions.get(&context.risk) { + tracing::debug!( + tool = tool_name, + provider_id = context.provider_id.as_str(), + capability_id = context.capability_id.as_str(), + risk = ?context.risk, + action = ?action, + "[generated_tool_runtime] risk action matched" + ); + return ( + *action, + format!("risk `{:?}` matched runtime policy", context.risk), + ); + } + tracing::trace!( + tool = tool_name, + provider_id = context.provider_id.as_str(), + capability_id = context.capability_id.as_str(), + risk = ?context.risk, + action = ?RuntimeToolPolicyAction::Allow, + "[generated_tool_runtime] default allow" + ); + ( + RuntimeToolPolicyAction::Allow, + format!("tool `{tool_name}` allowed"), + ) + } +} + +#[async_trait] +impl ToolPolicy for GeneratedToolRuntimePolicy { + fn name(&self) -> &str { + "generated_tool_runtime" + } + + async fn check(&self, request: &ToolPolicyRequest) -> ToolPolicyDecision { + if !self.config.enabled { + tracing::trace!( + policy = self.name(), + tool = request.tool_name.as_str(), + "[generated_tool_runtime] policy disabled" + ); + return ToolPolicyDecision::Allow; + } + let Some(context) = request.generated_tool.as_ref() else { + tracing::trace!( + policy = self.name(), + tool = request.tool_name.as_str(), + "[generated_tool_runtime] context missing" + ); + return ToolPolicyDecision::Allow; + }; + let (action, reason) = self.action_for(&request.tool_name, context); + tracing::debug!( + policy = self.name(), + tool = request.tool_name.as_str(), + provider_id = context.provider_id.as_str(), + capability_id = context.capability_id.as_str(), + risk = ?context.risk, + action = ?action, + reason = reason.as_str(), + "[generated_tool_runtime] policy decision" + ); + match action { + RuntimeToolPolicyAction::Allow => ToolPolicyDecision::Allow, + RuntimeToolPolicyAction::RequireApproval => { + ToolPolicyDecision::require_approval(reason) + } + RuntimeToolPolicyAction::Deny => ToolPolicyDecision::deny(reason), + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -213,4 +436,89 @@ mod tests { assert!(!rendered.contains("private-channel")); assert!(!rendered.contains("super-secret-token")); } + + fn generated_request() -> ToolPolicyRequest { + ToolPolicyRequest::new( + "email.send", + serde_json::json!({ "to": "user@example.com" }), + ToolCallContext::session("session", "chat", "orchestrator", "call-1", 1), + ) + .with_generated_tool_context(GeneratedToolRuntimeContext { + provider_id: "mail.runtime".to_string(), + capability_id: "email.send".to_string(), + risk: GeneratedToolRuntimeRisk::ExternalWrite, + source_digest: Some("sha256:abc".to_string()), + approval_id: None, + }) + } + + #[tokio::test] + async fn generated_runtime_policy_allows_when_disabled() { + let policy = GeneratedToolRuntimePolicy::new(GeneratedToolRuntimePolicyConfig::default()); + + assert_eq!( + policy.check(&generated_request()).await, + ToolPolicyDecision::Allow + ); + } + + #[tokio::test] + async fn generated_runtime_policy_allows_when_enabled_but_missing_context() { + let policy = GeneratedToolRuntimePolicy::new(GeneratedToolRuntimePolicyConfig { + enabled: true, + ..Default::default() + }); + + let request = ToolPolicyRequest::new( + "echo", + serde_json::json!({ "value": 1 }), + ToolCallContext::session("session", "chat", "orchestrator", "call-1", 1), + ); + + assert_eq!(policy.check(&request).await, ToolPolicyDecision::Allow); + } + + #[tokio::test] + async fn generated_runtime_policy_denies_revoked_provider() { + let policy = GeneratedToolRuntimePolicy::new(GeneratedToolRuntimePolicyConfig { + enabled: true, + revoked_providers: BTreeSet::from(["mail.runtime".to_string()]), + ..Default::default() + }); + + let decision = policy.check(&generated_request()).await; + assert!(matches!(decision, ToolPolicyDecision::Deny { .. })); + assert!(decision.blocking_reason().unwrap().contains("revoked")); + } + + #[tokio::test] + async fn generated_runtime_policy_denies_revoked_capability() { + let policy = GeneratedToolRuntimePolicy::new(GeneratedToolRuntimePolicyConfig { + enabled: true, + revoked_capabilities: BTreeSet::from(["email.send".to_string()]), + ..Default::default() + }); + + let decision = policy.check(&generated_request()).await; + assert!(matches!(decision, ToolPolicyDecision::Deny { .. })); + assert!(decision.blocking_reason().unwrap().contains("capability")); + } + + #[tokio::test] + async fn generated_runtime_policy_requires_approval_by_risk() { + let policy = GeneratedToolRuntimePolicy::new(GeneratedToolRuntimePolicyConfig { + enabled: true, + risk_actions: BTreeMap::from([( + GeneratedToolRuntimeRisk::ExternalWrite, + RuntimeToolPolicyAction::RequireApproval, + )]), + ..Default::default() + }); + + let decision = policy.check(&generated_request()).await; + assert!(matches!( + decision, + ToolPolicyDecision::RequireApproval { .. } + )); + } } diff --git a/src/openhuman/security/audit.rs b/src/openhuman/security/audit.rs index e32da17614..02e014957e 100644 --- a/src/openhuman/security/audit.rs +++ b/src/openhuman/security/audit.rs @@ -40,6 +40,14 @@ pub struct Action { pub risk_level: Option, pub approved: bool, pub allowed: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub provider_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub capability_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub policy_decision: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub approval_id: Option, } /// Execution result @@ -117,6 +125,25 @@ impl AuditEvent { risk_level: Some(risk_level), approved, allowed, + provider_id: None, + capability_id: None, + policy_decision: None, + approval_id: None, + }); + self + } + + /// Set action metadata for a generated tool execution. + pub fn with_generated_tool_action(mut self, entry: GeneratedToolExecutionLog<'_>) -> Self { + self.action = Some(Action { + command: Some(entry.tool_name.to_string()), + risk_level: Some(entry.risk_level.to_string()), + approved: entry.approved, + allowed: entry.allowed, + provider_id: Some(entry.provider_id.to_string()), + capability_id: Some(entry.capability_id.to_string()), + policy_decision: Some(entry.policy_decision.to_string()), + approval_id: entry.approval_id.map(str::to_string), }); self } @@ -221,6 +248,22 @@ pub struct CommandExecutionLog<'a> { pub duration_ms: u64, } +/// Structured generated tool execution details for audit correlation. +#[derive(Debug, Clone)] +pub struct GeneratedToolExecutionLog<'a> { + pub channel: &'a str, + pub tool_name: &'a str, + pub provider_id: &'a str, + pub capability_id: &'a str, + pub risk_level: &'a str, + pub policy_decision: &'a str, + pub approval_id: Option<&'a str>, + pub approved: bool, + pub allowed: bool, + pub success: bool, + pub duration_ms: u64, +} + impl AuditLogger { /// Build a disabled `Arc` for tests and contexts that need a /// handle but should not write to disk. The `enabled = false` flag @@ -301,6 +344,17 @@ impl AuditLogger { self.log(&event) } + /// Log a generated tool execution event with provider/capability + /// provenance suitable for runtime policy audits. + pub fn log_generated_tool_event(&self, entry: GeneratedToolExecutionLog<'_>) -> Result<()> { + let event = AuditEvent::new(AuditEventType::CommandExecution) + .with_actor(entry.channel.to_string(), None, None) + .with_generated_tool_action(entry.clone()) + .with_result(entry.success, None, entry.duration_ms, None); + + self.log(&event) + } + /// Backward-compatible helper to log a command execution event. #[allow(clippy::too_many_arguments)] pub fn log_command( @@ -513,6 +567,42 @@ mod tests { Ok(()) } + #[tokio::test] + async fn audit_log_generated_tool_event_writes_correlation_fields() -> Result<()> { + let tmp = TempDir::new()?; + let config = AuditConfig { + enabled: true, + max_size_mb: 10, + ..Default::default() + }; + let logger = AuditLogger::new(config, tmp.path().to_path_buf())?; + + logger.log_generated_tool_event(GeneratedToolExecutionLog { + channel: "chat", + tool_name: "email.send", + provider_id: "mail.runtime", + capability_id: "email.send", + risk_level: "external_write", + policy_decision: "require_approval", + approval_id: Some("approval-1"), + approved: true, + allowed: true, + success: true, + duration_ms: 13, + })?; + + let log_path = tmp.path().join("audit.log"); + let content = tokio::fs::read_to_string(&log_path).await?; + let parsed: AuditEvent = serde_json::from_str(content.trim())?; + let action = parsed.action.unwrap(); + assert_eq!(action.command, Some("email.send".to_string())); + assert_eq!(action.provider_id, Some("mail.runtime".to_string())); + assert_eq!(action.capability_id, Some("email.send".to_string())); + assert_eq!(action.policy_decision, Some("require_approval".to_string())); + assert_eq!(action.approval_id, Some("approval-1".to_string())); + Ok(()) + } + #[test] fn audit_rotation_creates_numbered_backup() -> Result<()> { let tmp = TempDir::new()?; diff --git a/src/openhuman/tools/traits.rs b/src/openhuman/tools/traits.rs index ab1f7d1318..cbf98e3d32 100644 --- a/src/openhuman/tools/traits.rs +++ b/src/openhuman/tools/traits.rs @@ -1,6 +1,8 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; +use crate::openhuman::agent::tool_policy::GeneratedToolRuntimeContext; + // Re-export the unified ToolResult from the lightweight skills types module so all tools use one type. pub use crate::openhuman::skills::types::{ToolContent, ToolResult}; @@ -225,6 +227,18 @@ pub trait Tool: Send + Sync { self.external_effect() } + /// Optional generated-tool runtime metadata for policy enforcement. + /// + /// Generated or externally supplied tools can override this to let + /// the agent policy layer apply provider/capability/risk rules before + /// execution. Built-in tools leave it unset. + fn generated_runtime_context( + &self, + _args: &serde_json::Value, + ) -> Option { + None + } + /// Per-tool cap on the character length of the result body sent /// back to the model. ///