Skip to content
14 changes: 7 additions & 7 deletions src/openhuman/agent/harness/session/turn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ use crate::openhuman::agent::harness;
use crate::openhuman::agent::hooks::{self, ToolCallRecord, TurnContext};
use crate::openhuman::agent::memory_loader::collect_recall_citations;
use crate::openhuman::agent::progress::AgentProgress;
use crate::openhuman::agent::tool_policy::{
ToolCallContext, ToolPolicyDecision, ToolPolicyRequest,
};
use crate::openhuman::agent::tool_policy::{ToolCallContext, ToolPolicyRequest};
use crate::openhuman::agent_experience::{
prepend_experience_block, render_experience_hits, AgentExperienceStore, ExperienceQuery,
};
Expand Down Expand Up @@ -1201,11 +1199,13 @@ impl Agent {
call_id.clone(),
(iteration + 1) as u32,
);
let policy_request =
let mut policy_request =
ToolPolicyRequest::new(call.name.clone(), call.arguments.clone(), context);
if let ToolPolicyDecision::Deny { reason } =
self.tool_policy.check(&policy_request).await
{
if let Some(generated_context) = tool.generated_runtime_context(&call.arguments) {
policy_request = policy_request.with_generated_tool_context(generated_context);
}
let policy_decision = self.tool_policy.check(&policy_request).await;
if let Some(reason) = policy_decision.blocking_reason() {
tracing::debug!(
tool = call.name.as_str(),
policy = self.tool_policy.name(),
Expand Down
104 changes: 103 additions & 1 deletion src/openhuman/agent/harness/session/turn_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use crate::core::event_bus::{global, init_global, DomainEvent};
use crate::openhuman::agent::dispatcher::XmlToolDispatcher;
use crate::openhuman::agent::hooks::{PostTurnHook, TurnContext};
use crate::openhuman::agent::memory_loader::MemoryLoader;
use crate::openhuman::agent::tool_policy::{ToolPolicy, ToolPolicyDecision, ToolPolicyRequest};
use crate::openhuman::agent::tool_policy::{
GeneratedToolRuntimeContext, GeneratedToolRuntimeRisk, ToolPolicy, ToolPolicyDecision,
ToolPolicyRequest,
};
use crate::openhuman::inference::provider::{ChatRequest, ChatResponse, Provider};
use crate::openhuman::memory::Memory;
use crate::openhuman::tools::ToolResult;
Expand Down Expand Up @@ -200,6 +203,64 @@ impl Tool for CountingWriteTool {
}
}

struct GeneratedContextTool {
calls: Arc<AtomicUsize>,
}

#[async_trait]
impl Tool for GeneratedContextTool {
fn name(&self) -> &str {
"generated_send"
}

fn description(&self) -> &str {
"generated send"
}

fn parameters_schema(&self) -> serde_json::Value {
serde_json::json!({"type":"object"})
}

async fn execute(&self, _args: serde_json::Value) -> Result<ToolResult> {
self.calls.fetch_add(1, Ordering::SeqCst);
Ok(ToolResult::success("generated-output"))
}

fn generated_runtime_context(
&self,
_args: &serde_json::Value,
) -> Option<GeneratedToolRuntimeContext> {
Some(GeneratedToolRuntimeContext {
provider_id: "mail.runtime".to_string(),
capability_id: "email.send".to_string(),
risk: GeneratedToolRuntimeRisk::ExternalWrite,
source_digest: Some("sha256:abc".to_string()),
approval_id: Some("approval-1".to_string()),
})
}
}

struct RequireGeneratedContextPolicy;

#[async_trait]
impl ToolPolicy for RequireGeneratedContextPolicy {
fn name(&self) -> &str {
"require_generated_context"
}

async fn check(&self, request: &ToolPolicyRequest) -> ToolPolicyDecision {
let context = request
.generated_tool
.as_ref()
.expect("generated tool context should be threaded");
assert_eq!(context.provider_id, "mail.runtime");
assert_eq!(context.capability_id, "email.send");
assert_eq!(context.risk, GeneratedToolRuntimeRisk::ExternalWrite);
assert_eq!(context.approval_id.as_deref(), Some("approval-1"));
ToolPolicyDecision::require_approval("generated context requires approval")
}
}

struct RecordingHook {
calls: Arc<AsyncMutex<Vec<TurnContext>>>,
notify: Arc<Notify>,
Expand Down Expand Up @@ -591,6 +652,47 @@ async fn execute_tool_call_denies_by_policy_before_tool_runs() {
assert!(!record.success);
}

#[tokio::test]
async fn execute_tool_call_threads_generated_tool_context_into_policy() {
let workspace = tempfile::TempDir::new().expect("temp workspace");
let workspace_path = workspace.path().to_path_buf();
std::mem::forget(workspace);
let memory_cfg = crate::openhuman::config::MemoryConfig {
backend: "none".into(),
..crate::openhuman::config::MemoryConfig::default()
};
let mem: Arc<dyn Memory> =
Arc::from(crate::openhuman::memory::create_memory(&memory_cfg, &workspace_path).unwrap());
let calls = Arc::new(AtomicUsize::new(0));

let agent = Agent::builder()
.provider(Box::new(DummyProvider))
.tools(vec![Box::new(GeneratedContextTool {
calls: Arc::clone(&calls),
})])
.memory(mem)
.tool_dispatcher(Box::new(XmlToolDispatcher))
.workspace_dir(workspace_path)
.event_context("turn-test-session", "turn-test-channel")
.tool_policy(Arc::new(RequireGeneratedContextPolicy))
.build()
.unwrap();
let call = ParsedToolCall {
name: "generated_send".into(),
arguments: serde_json::json!({ "value": 1 }),
tool_call_id: Some("policy-generated-1".into()),
};

let (result, record) = agent.execute_tool_call(&call, 0).await;
assert!(!result.success);
assert!(result
.output
.contains("generated context requires approval"));
assert_eq!(calls.load(Ordering::SeqCst), 0);
assert_eq!(record.name, "generated_send");
assert!(!record.success);
}

#[tokio::test]
async fn turn_runs_full_tool_cycle_with_context_and_hooks() {
let provider_impl = Arc::new(SequenceProvider {
Expand Down
Loading
Loading