Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/openhuman/agent/harness/session/turn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ use crate::openhuman::agent::harness;
use crate::openhuman::agent::hooks::{self, ToolCallRecord, TurnContext};
use crate::openhuman::agent::memory_loader::collect_recall_citations;
use crate::openhuman::agent::progress::AgentProgress;
use crate::openhuman::agent::tool_policy::{ToolPolicyDecision, ToolPolicyRequest};
use crate::openhuman::agent::tool_policy::{
ToolCallContext, ToolPolicyDecision, ToolPolicyRequest,
};
use crate::openhuman::agent_experience::{
prepend_experience_block, render_experience_hits, AgentExperienceStore, ExperienceQuery,
};
Expand Down Expand Up @@ -1145,12 +1147,20 @@ impl Agent {
false,
)
} else if let Some(tool) = self.tools.iter().find(|t| t.name() == call.name) {
let context = ToolCallContext::session(
self.event_session_id(),
self.event_channel(),
self.agent_definition_id.to_string(),
call_id.clone(),
(iteration + 1) as u32,
);
let policy_request = ToolPolicyRequest {
tool_name: call.name.clone(),
arguments: call.arguments.clone(),
session_id: self.event_session_id().to_string(),
channel: self.event_channel().to_string(),
agent_definition_id: self.agent_definition_id.to_string(),
session_id: context.session_id.clone(),
channel: context.channel.clone(),
agent_definition_id: context.agent_definition_id.clone(),
context,
};
if let ToolPolicyDecision::Deny { reason } =
self.tool_policy.check(&policy_request).await
Expand Down
5 changes: 5 additions & 0 deletions src/openhuman/agent/harness/session/turn_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ impl ToolPolicy for DenyCountingPolicy {
assert_eq!(request.session_id, "turn-test-session");
assert_eq!(request.channel, "turn-test-channel");
assert_eq!(request.agent_definition_id, "main");
assert_eq!(request.context.session_id, "turn-test-session");
assert_eq!(request.context.channel, "turn-test-channel");
assert_eq!(request.context.agent_definition_id, "main");
assert_eq!(request.context.call_id, "policy-1");
assert_eq!(request.context.iteration, 1);
ToolPolicyDecision::deny("locked by test policy")
}
}
Expand Down
49 changes: 49 additions & 0 deletions src/openhuman/agent/tool_policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,59 @@

use async_trait::async_trait;

/// Structured context for a tool call before it reaches the tool
/// implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolCallContext {
Comment thread
vaddisrinivas marked this conversation as resolved.
pub session_id: String,
pub channel: String,
pub agent_definition_id: String,
pub call_id: String,
pub iteration: u32,
pub source: ToolCallSource,
}

impl ToolCallContext {
pub fn session(
session_id: impl Into<String>,
channel: impl Into<String>,
agent_definition_id: impl Into<String>,
call_id: impl Into<String>,
iteration: u32,
) -> Self {
Self {
session_id: session_id.into(),
channel: channel.into(),
agent_definition_id: agent_definition_id.into(),
call_id: call_id.into(),
iteration,
source: ToolCallSource::Session,
}
}
}

/// Entry point that produced a tool call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolCallSource {
Session,
Bus,
Channel,
Comment thread
vaddisrinivas marked this conversation as resolved.
Cron,
Webhook,
Unknown,
}

/// Snapshot of the tool call and session context a policy can inspect.
#[derive(Debug, Clone)]
pub struct ToolPolicyRequest {
pub tool_name: String,
pub arguments: serde_json::Value,
pub context: ToolCallContext,
/// Backward-compatible mirror of `context.session_id`.
pub session_id: String,
/// Backward-compatible mirror of `context.channel`.
Comment thread
vaddisrinivas marked this conversation as resolved.
pub channel: String,
/// Backward-compatible mirror of `context.agent_definition_id`.
pub agent_definition_id: String,
}

Expand Down Expand Up @@ -66,11 +112,14 @@ mod tests {
let request = ToolPolicyRequest {
tool_name: "echo".into(),
arguments: serde_json::json!({ "value": 1 }),
context: ToolCallContext::session("session", "chat", "orchestrator", "call-1", 1),
session_id: "session".into(),
channel: "chat".into(),
agent_definition_id: "orchestrator".into(),
};

assert_eq!(policy.check(&request).await, ToolPolicyDecision::Allow);
assert_eq!(request.context.source, ToolCallSource::Session);
assert_eq!(request.context.call_id, "call-1");
}
}
Loading