Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,16 @@ impl SessionManager for LocalSessionManager {
Ok(response)
}
async fn close_session(&self, id: &SessionId) -> Result<(), Self::Error> {
let mut sessions = self.sessions.write().await;
if let Some(handle) = sessions.remove(id) {
handle.close().await?;
let handle = {
let mut sessions = self.sessions.write().await;
sessions.remove(id)
};
if let Some(handle) = handle {
match handle.close().await {
// Worker already exited — nothing left to clean up.
Ok(()) | Err(SessionError::SessionServiceTerminated) => {}
Err(e) => return Err(e.into()),
}
}
Ok(())
}
Expand Down Expand Up @@ -928,8 +935,6 @@ pub enum LocalSessionWorkerError {
FailToSendInitializeRequest(SessionError),
#[error("fail to handle message: {0}")]
FailToHandleMessage(SessionError),
#[error("keep alive timeout after {}ms", _0.as_millis())]
KeepAliveTimeout(Duration),
Comment on lines -931 to -932
Copy link
Copy Markdown
Member

@DaleSeo DaleSeo May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> The removed LocalSessionWorkerError::KeepAliveTimeout variant was not part of the public API.

I think it's reachable via a fully `pub` module chain:
rmcp::transport::streamable_http_server::session::local::LocalSessionWorkerError::KeepAliveTimeout

Could we keep the variant with #[deprecated] to avoid a breaking change?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

#[error("Transport closed")]
TransportClosed,
#[error("Tokio join error {0}")]
Expand Down Expand Up @@ -1008,7 +1013,7 @@ impl Worker for LocalSessionWorker {
return Err(WorkerQuitReason::Cancelled)
}
_ = keep_alive_timeout => {
return Err(WorkerQuitReason::fatal(LocalSessionWorkerError::KeepAliveTimeout(keep_alive), "poll next session event"))
return Err(WorkerQuitReason::IdleTimeout(keep_alive))
}
};
match event {
Expand Down
7 changes: 5 additions & 2 deletions crates/rmcp/src/transport/worker.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::borrow::Cow;
use std::{borrow::Cow, time::Duration};

use tokio_util::sync::CancellationToken;
use tracing::{Instrument, Level};
Expand All @@ -22,6 +22,8 @@ pub enum WorkerQuitReason<E> {
TransportClosed,
#[error("Handler terminated")]
HandlerTerminated,
#[error("Worker idle timeout ({}ms)", _0.as_millis())]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#[error("Worker idle timeout ({}ms)", _0.as_millis())]
#[error("Worker idle timeout after {}ms", _0.as_millis())]

IdleTimeout(Duration),
}

impl<E: std::error::Error + Send + 'static> WorkerQuitReason<E> {
Expand Down Expand Up @@ -122,7 +124,8 @@ impl<W: Worker> WorkerTransport<W> {
.inspect_err(|e| match e {
WorkerQuitReason::Cancelled
| WorkerQuitReason::TransportClosed
| WorkerQuitReason::HandlerTerminated => {
| WorkerQuitReason::HandlerTerminated
| WorkerQuitReason::IdleTimeout(_) => {
tracing::debug!("worker quit with reason: {:?}", e);
}
WorkerQuitReason::Join(e) => {
Expand Down
246 changes: 246 additions & 0 deletions crates/rmcp/tests/test_streamable_http_idle_timeout_log.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
#![cfg(all(
feature = "transport-streamable-http-server",
feature = "transport-streamable-http-client-reqwest",
not(feature = "local")
))]

use std::{
sync::{Arc, Mutex},
time::Duration,
};

use rmcp::transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService,
session::{SessionManager, local::LocalSessionManager},
};
use tokio_util::sync::CancellationToken;
use tracing_subscriber::layer::SubscriberExt;

mod common;
use common::calculator::Calculator;

/// One tracing event recorded by `CapturingLayer`: its level and the
/// rendered `message` field.
struct CapturedEvent {
/// Severity copied from the event's metadata.
level: tracing::Level,
/// Debug-formatted value of the event's `message` field; stays empty when
/// the event carried no such field.
message: String,
}

/// `tracing_subscriber` layer that appends every event it observes to a
/// shared, mutex-guarded vector so a test can inspect the events afterwards.
struct CapturingLayer {
/// Shared sink; the test holds another `Arc` handle to read it back.
events: Arc<Mutex<Vec<CapturedEvent>>>,
}

impl<S: tracing::Subscriber> tracing_subscriber::Layer<S> for CapturingLayer {
fn on_event(
&self,
event: &tracing::Event<'_>,
_ctx: tracing_subscriber::layer::Context<'_, S>,
) {
let mut visitor = MessageVisitor(String::new());
event.record(&mut visitor);
self.events.lock().unwrap().push(CapturedEvent {
level: *event.metadata().level(),
message: visitor.0,
});
}
}

/// Field visitor that keeps only the Debug rendering of an event's `message` field.
struct MessageVisitor(String);

impl tracing::field::Visit for MessageVisitor {
    /// Stores the Debug rendering of `value` when the field is named
    /// `message`; every other field is ignored.
    fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
        if field.name() != "message" {
            return;
        }
        self.0 = format!("{value:?}");
    }
}

/// Checks that when a session's keep-alive elapses the session is removed
/// from the manager, no ERROR-level event is captured, and at least one
/// DEBUG event whose message contains `IdleTimeout` is captured.
// NOTE(review): `set_default` below is scoped to the current thread, which is
// presumably why the test uses the `current_thread` runtime flavor — confirm.
#[tokio::test(flavor = "current_thread")]
async fn test_keep_alive_timeout_does_not_emit_error_log() {
// Shared buffer that the capturing layer fills and the assertions read.
let events = Arc::new(Mutex::new(Vec::<CapturedEvent>::new()));

let subscriber = tracing_subscriber::registry().with(CapturingLayer {
events: events.clone(),
});

// The subscriber stays installed until `_guard` is dropped at the end of the test.
let _guard = tracing::subscriber::set_default(subscriber);

let ct = CancellationToken::new();
let mut session_manager = LocalSessionManager::default();
// Short keep-alive (200 ms) so the session idles out quickly.
session_manager.session_config.keep_alive = Some(Duration::from_millis(200));
let session_manager = Arc::new(session_manager);

let service = StreamableHttpService::new(
|| Ok(Calculator::new()),
session_manager.clone(),
StreamableHttpServerConfig::default()
.with_sse_keep_alive(None)
.with_cancellation_token(ct.child_token()),
);

let router = axum::Router::new().nest_service("/mcp", service);
// Port 0 lets the OS choose a free port; the real address is read back below.
let tcp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = tcp_listener.local_addr().unwrap();

// Serve in the background until `ct` is cancelled at the end of the test.
tokio::spawn({
let ct = ct.clone();
async move {
let _ = axum::serve(tcp_listener, router)
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
.await;
}
});

let client = reqwest::Client::new();

// Step 1: `initialize` request; the response carries the new session id header.
let response = client
.post(format!("http://{addr}/mcp"))
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}"#)
.send()
.await
.unwrap();
assert_eq!(response.status(), 200);
let session_id = response.headers()["mcp-session-id"]
.to_str()
.unwrap()
.to_string();

// Step 2: `notifications/initialized` for that session; only a transport
// failure is checked here, the response itself is discarded.
client
.post(format!("http://{addr}/mcp"))
.header("Content-Type", "application/json")
.header("Accept", "application/json, text/event-stream")
.header("mcp-session-id", &session_id)
.header("Mcp-Protocol-Version", "2025-06-18")
.body(r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#)
.send()
.await
.unwrap();

// Stay idle for twice the 200 ms keep-alive configured above.
tokio::time::sleep(Duration::from_millis(400)).await;

// Wait until close_session() has completed so all logs are captured.
// Polls up to 20 times at 50 ms intervals (about 1 s extra at most).
let session_id_parsed: Arc<str> = Arc::from(session_id.as_str());
for _ in 0..20 {
if !session_manager
.has_session(&session_id_parsed)
.await
.unwrap()
{
break;
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
assert!(
!session_manager
.has_session(&session_id_parsed)
.await
.unwrap(),
"session should have been removed after idle reap"
);

let captured = events.lock().unwrap();

// NOTE(review): this filter looks only at the level, not at the event's
// target, so an ERROR logged by any other crate on this thread would also
// fail the assertion below.
let error_events: Vec<_> = captured
.iter()
.filter(|e| e.level == tracing::Level::ERROR)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we also filter errors by target? Any unrelated subsystem that logs at ERROR would flake this test.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added target field and scoped both filters to starts_with("rmcp")

.collect();
assert!(
error_events.is_empty(),
"idle reap should not produce any ERROR logs, found {}: {:?}",
error_events.len(),
error_events.iter().map(|e| &e.message).collect::<Vec<_>>()
);

// The match is on the captured message text, so it relies on the quit
// reason being rendered with its variant name (`IdleTimeout`) in a DEBUG log.
let debug_events: Vec<_> = captured
.iter()
.filter(|e| e.level == tracing::Level::DEBUG && e.message.contains("IdleTimeout"))
.collect();
assert!(
!debug_events.is_empty(),
"expected a DEBUG log with IdleTimeout, but found none"
);

// Trigger the graceful shutdown of the background server task.
ct.cancel();
}

/// Explicitly closing a session whose worker is still running must return
/// `Ok` and remove the session from the manager.
#[tokio::test(flavor = "current_thread")]
async fn test_explicit_close_on_live_session_succeeds() {
    let ct = CancellationToken::new();

    // Long (60 s) keep-alive, presumably so the session cannot idle out
    // before the explicit close below.
    let session_manager = {
        let mut manager = LocalSessionManager::default();
        manager.session_config.keep_alive = Some(Duration::from_secs(60));
        Arc::new(manager)
    };

    let server_config = StreamableHttpServerConfig::default()
        .with_sse_keep_alive(None)
        .with_cancellation_token(ct.child_token());
    let service = StreamableHttpService::new(
        || Ok(Calculator::new()),
        session_manager.clone(),
        server_config,
    );

    let app = axum::Router::new().nest_service("/mcp", service);
    // Port 0 lets the OS choose a free port; the real address is read back.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();

    // Serve in the background until `ct` is cancelled at the end of the test.
    let shutdown = ct.clone();
    tokio::spawn(async move {
        let _ = axum::serve(listener, app)
            .with_graceful_shutdown(async move { shutdown.cancelled_owned().await })
            .await;
    });

    let endpoint = format!("http://{addr}/mcp");
    let http = reqwest::Client::new();

    // Step 1: `initialize`; the response header carries the session id.
    let init_response = http
        .post(endpoint.as_str())
        .header("Content-Type", "application/json")
        .header("Accept", "application/json, text/event-stream")
        .body(r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"test","version":"1.0"}}}"#)
        .send()
        .await
        .unwrap();
    assert_eq!(init_response.status(), 200);
    let session_id = init_response.headers()["mcp-session-id"]
        .to_str()
        .unwrap()
        .to_owned();

    // Step 2: `notifications/initialized`; the response is discarded.
    http.post(endpoint.as_str())
        .header("Content-Type", "application/json")
        .header("Accept", "application/json, text/event-stream")
        .header("mcp-session-id", &session_id)
        .header("Mcp-Protocol-Version", "2025-06-18")
        .body(r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#)
        .send()
        .await
        .unwrap();

    let sid: Arc<str> = session_id.as_str().into();

    let exists_before = session_manager.has_session(&sid).await.unwrap();
    assert!(exists_before, "session should exist before explicit close");

    let result = session_manager.close_session(&sid).await;
    assert!(
        result.is_ok(),
        "close_session on a live worker should succeed: {result:?}"
    );

    let exists_after = session_manager.has_session(&sid).await.unwrap();
    assert!(!exists_after, "session should not exist after explicit close");

    // Trigger the graceful shutdown of the background server task.
    ct.cancel();
}
Loading