Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions src/openhuman/agent/harness/session/turn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ use crate::openhuman::agent::harness;
use crate::openhuman::agent::hooks::{self, ToolCallRecord, TurnContext};
use crate::openhuman::agent::memory_loader::collect_recall_citations;
use crate::openhuman::agent::progress::AgentProgress;
use crate::openhuman::agent::tool_policy::{ToolPolicyDecision, ToolPolicyRequest};
use crate::openhuman::agent::tool_policy::{
ToolCallContext, ToolPolicyDecision, ToolPolicyRequest,
};
use crate::openhuman::agent_experience::{
prepend_experience_block, render_experience_hits, AgentExperienceStore, ExperienceQuery,
};
Expand Down Expand Up @@ -1145,13 +1147,15 @@ impl Agent {
false,
)
} else if let Some(tool) = self.tools.iter().find(|t| t.name() == call.name) {
let policy_request = ToolPolicyRequest {
tool_name: call.name.clone(),
arguments: call.arguments.clone(),
session_id: self.event_session_id().to_string(),
channel: self.event_channel().to_string(),
agent_definition_id: self.agent_definition_id.to_string(),
};
let context = ToolCallContext::session(
self.event_session_id(),
self.event_channel(),
self.agent_definition_id.to_string(),
call_id.clone(),
(iteration + 1) as u32,
);
let policy_request =
ToolPolicyRequest::new(call.name.clone(), call.arguments.clone(), context);
if let ToolPolicyDecision::Deny { reason } =
self.tool_policy.check(&policy_request).await
{
Expand Down
5 changes: 5 additions & 0 deletions src/openhuman/agent/harness/session/turn_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ impl ToolPolicy for DenyCountingPolicy {
assert_eq!(request.session_id, "turn-test-session");
assert_eq!(request.channel, "turn-test-channel");
assert_eq!(request.agent_definition_id, "main");
assert_eq!(request.context.session_id, "turn-test-session");
assert_eq!(request.context.channel, "turn-test-channel");
assert_eq!(request.context.agent_definition_id, "main");
assert_eq!(request.context.call_id, "policy-1");
assert_eq!(request.context.iteration, 1);
ToolPolicyDecision::deny("locked by test policy")
}
}
Expand Down
83 changes: 76 additions & 7 deletions src/openhuman/agent/tool_policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,79 @@

use async_trait::async_trait;

/// Structured context for a tool call before it reaches the tool
/// implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ToolCallContext {
Comment thread
vaddisrinivas marked this conversation as resolved.
pub session_id: String,
pub channel: String,
pub agent_definition_id: String,
pub call_id: String,
pub iteration: u32,
pub source: ToolCallSource,
}

impl ToolCallContext {
pub fn session(
session_id: impl Into<String>,
channel: impl Into<String>,
agent_definition_id: impl Into<String>,
call_id: impl Into<String>,
iteration: u32,
) -> Self {
Self {
session_id: session_id.into(),
channel: channel.into(),
agent_definition_id: agent_definition_id.into(),
call_id: call_id.into(),
iteration,
source: ToolCallSource::Session,
}
}
}

/// Entry point that produced a tool call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolCallSource {
Session,
Bus,
Channel,
Comment thread
vaddisrinivas marked this conversation as resolved.
Cron,
Webhook,
Unknown,
}

/// Snapshot of the tool call and session context a policy can inspect.
#[derive(Debug, Clone)]
pub struct ToolPolicyRequest {
pub tool_name: String,
pub arguments: serde_json::Value,
pub context: ToolCallContext,
/// Backward-compatible mirror of `context.session_id`.
pub session_id: String,
/// Backward-compatible mirror of `context.channel`.
Comment thread
vaddisrinivas marked this conversation as resolved.
pub channel: String,
/// Backward-compatible mirror of `context.agent_definition_id`.
pub agent_definition_id: String,
}

impl ToolPolicyRequest {
pub fn new(
tool_name: impl Into<String>,
arguments: serde_json::Value,
context: ToolCallContext,
) -> Self {
Self {
tool_name: tool_name.into(),
arguments,
session_id: context.session_id.clone(),
channel: context.channel.clone(),
agent_definition_id: context.agent_definition_id.clone(),
context,
}
}
}

/// Decision returned by a [`ToolPolicy`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolPolicyDecision {
Expand Down Expand Up @@ -63,14 +126,20 @@ mod tests {
#[tokio::test]
async fn allow_all_policy_allows_every_call() {
let policy = AllowAllToolPolicy;
let request = ToolPolicyRequest {
tool_name: "echo".into(),
arguments: serde_json::json!({ "value": 1 }),
session_id: "session".into(),
channel: "chat".into(),
agent_definition_id: "orchestrator".into(),
};
let request = ToolPolicyRequest::new(
"echo",
serde_json::json!({ "value": 1 }),
ToolCallContext::session("session", "chat", "orchestrator", "call-1", 1),
);

assert_eq!(policy.check(&request).await, ToolPolicyDecision::Allow);
assert_eq!(request.session_id, request.context.session_id);
assert_eq!(request.channel, request.context.channel);
assert_eq!(
request.agent_definition_id,
request.context.agent_definition_id
);
assert_eq!(request.context.source, ToolCallSource::Session);
assert_eq!(request.context.call_id, "call-1");
}
}
Loading