Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cmd/thv/app/run_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"log/slog"
"strings"
"time"

"github.com/spf13/cobra"

Expand Down Expand Up @@ -104,6 +105,9 @@ type RunFlags struct {
// Endpoint prefix for SSE endpoint URLs
EndpointPrefix string

// SessionTTL is the session inactivity timeout. Zero uses the transport default.
SessionTTL time.Duration

// Network mode
Network string

Expand Down Expand Up @@ -264,6 +268,8 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) {
cmd.Flags().BoolVar(&config.Stateless, "stateless", false,
"Declare the server as stateless (POST-only, no SSE). "+
"Use for MCP servers implementing streamable-HTTP stateless mode.")
cmd.Flags().DurationVar(&config.SessionTTL, "session-ttl", 0,
"Session inactivity timeout (e.g., 30m, 2h); zero uses the default (2h)")
cmd.Flags().StringVar(&config.EndpointPrefix, "endpoint-prefix", "",
"Path prefix to prepend to SSE endpoint URLs (e.g., /playwright)")
cmd.Flags().StringVar(&config.Network, "network", "",
Expand Down Expand Up @@ -665,6 +671,7 @@ func buildRunnerConfig(
runner.WithAllowDockerGateway(runFlags.AllowDockerGateway),
runner.WithTrustProxyHeaders(runFlags.TrustProxyHeaders),
runner.WithStateless(runFlags.Stateless),
runner.WithSessionTTL(runFlags.SessionTTL),
runner.WithEndpointPrefix(runFlags.EndpointPrefix),
runner.WithNetworkMode(runFlags.Network),
runner.WithK8sPodPatch(runFlags.K8sPodPatch),
Expand Down
5 changes: 5 additions & 0 deletions cmd/thv/app/vmcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package app

import (
"fmt"
"time"

"github.com/spf13/cobra"

Expand Down Expand Up @@ -39,6 +40,7 @@ func newVMCPServeCommand() *cobra.Command {
enableEmbedding bool
embeddingModel string
embeddingImage string
sessionTTL time.Duration
)
cmd := &cobra.Command{
Use: "serve",
Expand All @@ -64,6 +66,7 @@ configuration file is needed for the common case of aggregating a local group.`,
EnableEmbedding: enableEmbedding,
EmbeddingModel: embeddingModel,
EmbeddingImage: embeddingImage,
SessionTTL: sessionTTL,
})
},
}
Expand All @@ -80,6 +83,8 @@ configuration file is needed for the common case of aggregating a local group.`,
cmd.Flags().StringVar(&host, "host", "127.0.0.1", "Host address to bind to")
cmd.Flags().IntVar(&port, "port", 4483, "Port to listen on")
cmd.Flags().BoolVar(&enableAudit, "enable-audit", false, "Enable audit logging with default configuration")
cmd.Flags().DurationVar(&sessionTTL, "session-ttl", 0,
"Session inactivity timeout (e.g., 30m, 2h); zero uses the default (30m)")
return cmd
}

Expand Down
5 changes: 5 additions & 0 deletions cmd/vmcp/app/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package app
import (
"fmt"
"log/slog"
"time"

"github.com/spf13/cobra"
"github.com/spf13/viper"
Expand Down Expand Up @@ -96,12 +97,14 @@ from all configured backend MCP servers.`,
host, _ := cmd.Flags().GetString("host")
port, _ := cmd.Flags().GetInt("port")
enableAudit, _ := cmd.Flags().GetBool("enable-audit")
sessionTTL, _ := cmd.Flags().GetDuration("session-ttl")

return vmcpcli.Serve(cmd.Context(), vmcpcli.ServeConfig{
ConfigPath: configPath,
Host: host,
Port: port,
EnableAudit: enableAudit,
SessionTTL: sessionTTL,
})
},
}
Expand All @@ -110,6 +113,8 @@ from all configured backend MCP servers.`,
cmd.Flags().String("host", "127.0.0.1", "Host address to bind to")
cmd.Flags().Int("port", 4483, "Port to listen on")
cmd.Flags().Bool("enable-audit", false, "Enable audit logging with default configuration")
cmd.Flags().Duration("session-ttl", time.Duration(0),
"Session inactivity timeout (e.g., 30m, 2h); zero uses the default (30m)")

return cmd
}
Expand Down
1 change: 1 addition & 0 deletions docs/cli/thv_run.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions docs/cli/thv_vmcp_serve.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions docs/server/docs.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions docs/server/swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions docs/server/swagger.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions pkg/runner/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"io"
"log/slog"
"time"

"github.com/stacklok/toolhive-core/permissions"
v1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1"
Expand Down Expand Up @@ -203,6 +204,10 @@ type RunConfig struct {
// Applies to both remote URLs and local container workloads.
Stateless bool `json:"stateless,omitempty" yaml:"stateless,omitempty"`

// SessionTTL is the inactivity timeout for proxy sessions.
// Zero uses the transport default (2h). Negative values are rejected by the builder.
SessionTTL time.Duration `json:"session_ttl,omitempty" yaml:"session_ttl,omitempty" swaggertype:"primitive,integer"`

// ProxyMode is the effective HTTP protocol the proxy uses.
// For stdio transports, this is the configured mode (sse or streamable-http).
// For direct transports (sse/streamable-http), this matches the transport type.
Expand Down
14 changes: 14 additions & 0 deletions pkg/runner/config_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"path/filepath"
"slices"
"strings"
"time"

"github.com/stacklok/toolhive-core/permissions"
regtypes "github.com/stacklok/toolhive-core/registry/types"
Expand Down Expand Up @@ -362,6 +363,19 @@ func WithEndpointPrefix(prefix string) RunConfigBuilderOption {
}
}

// WithSessionTTL sets the inactivity timeout for proxy sessions.
// Zero is valid and means "use the transport default" (2h).
// Negative values return an error.
func WithSessionTTL(ttl time.Duration) RunConfigBuilderOption {
return func(b *runConfigBuilder) error {
if ttl < 0 {
return fmt.Errorf("session-ttl must be non-negative, got %s", ttl)
}
b.config.SessionTTL = ttl
return nil
}
}

// WithNetworkMode sets the network mode for the container.
// The network mode will be applied to the permission profile after it is loaded.
func WithNetworkMode(networkMode string) RunConfigBuilderOption {
Expand Down
51 changes: 51 additions & 0 deletions pkg/runner/config_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,57 @@ func TestWithRegistryServerName(t *testing.T) {
}
}

func TestWithSessionTTL(t *testing.T) {
t.Parallel()

tests := []struct {
name string
ttl time.Duration
expectErr bool
expectedTTL time.Duration
}{
{
name: "zero is accepted and means use default",
ttl: 0,
expectErr: false,
expectedTTL: 0,
},
{
name: "positive duration is accepted",
ttl: 45 * time.Minute,
expectErr: false,
expectedTTL: 45 * time.Minute,
},
{
name: "large positive duration is accepted",
ttl: 24 * time.Hour,
expectErr: false,
expectedTTL: 24 * time.Hour,
},
{
name: "negative duration returns an error",
ttl: -1 * time.Second,
expectErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

builder := &runConfigBuilder{config: NewRunConfig()}
err := WithSessionTTL(tt.ttl)(builder)

if tt.expectErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expectedTTL, builder.config.SessionTTL)
})
}
}

func TestResolveRegistryServerName(t *testing.T) {
t.Parallel()

Expand Down
11 changes: 10 additions & 1 deletion pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ func (c *RunConfig) GetPort() int {
//
//nolint:gocyclo // This function is complex but manageable
func (r *Runner) Run(ctx context.Context) error {
// Resolve session TTL once so both the transport proxy and Redis storage use
// the same effective value, rather than each applying their own zero-fallback
// independently.
effectiveSessionTTL := r.Config.SessionTTL
if effectiveSessionTTL <= 0 {
effectiveSessionTTL = session.DefaultSessionTTL
}

// Create transport with runtime
transportConfig := types.Config{
Type: r.Config.Transport,
Expand All @@ -177,6 +185,7 @@ func (r *Runner) Run(ctx context.Context) error {
Debug: r.Config.Debug,
TrustProxyHeaders: r.Config.TrustProxyHeaders,
EndpointPrefix: r.Config.EndpointPrefix,
SessionTTL: effectiveSessionTTL,
}

// Set proxy mode for stdio transport
Expand Down Expand Up @@ -368,7 +377,7 @@ func (r *Runner) Run(ctx context.Context) error {
Password: os.Getenv(session.RedisPasswordEnvVar),
DB: int(redisCfg.DB),
KeyPrefix: keyPrefix,
}, session.DefaultSessionTTL)
}, effectiveSessionTTL)
if err != nil {
return fmt.Errorf("failed to create Redis session storage: %w", err)
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/transport/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func (*Factory) Create(config types.Config, opts ...Option) (types.Transport, er
if config.SessionStorage != nil {
stdio.SetSessionStorage(config.SessionStorage)
}
stdio.SetSessionTTL(config.SessionTTL)
tr = stdio
case types.TransportTypeSSE:
httpTransport := NewHTTPTransport(
Expand All @@ -73,6 +74,7 @@ func (*Factory) Create(config types.Config, opts ...Option) (types.Transport, er
config.Middlewares...,
)
httpTransport.sessionStorage = config.SessionStorage
httpTransport.sessionTTL = config.SessionTTL
tr = httpTransport
case types.TransportTypeStreamableHTTP:
httpTransport := NewHTTPTransport(
Expand All @@ -91,6 +93,7 @@ func (*Factory) Create(config types.Config, opts ...Option) (types.Transport, er
config.Middlewares...,
)
httpTransport.sessionStorage = config.SessionStorage
httpTransport.sessionTTL = config.SessionTTL
tr = httpTransport
case types.TransportTypeInspector:
// HTTP transport is not implemented yet
Expand Down
8 changes: 8 additions & 0 deletions pkg/transport/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"os"
"strings"
"sync"
"time"

"golang.org/x/oauth2"

Expand Down Expand Up @@ -81,6 +82,10 @@ type HTTPTransport struct {
// Used for Redis-backed session sharing across replicas.
sessionStorage session.Storage

// sessionTTL overrides the inactivity timeout for sessions managed by the
// underlying proxy. Zero uses the proxy's default.
sessionTTL time.Duration

// Transparent proxy
proxy types.Proxy

Expand Down Expand Up @@ -432,6 +437,9 @@ func (t *HTTPTransport) buildProxyOptions(remoteBasePath, remoteRawQuery string)
if t.stateless {
opts = append(opts, transparent.WithStateless())
}
if t.sessionTTL > 0 {
opts = append(opts, transparent.WithSessionTTL(t.sessionTTL))
}
if t.sessionStorage != nil {
opts = append(opts, transparent.WithSessionStorage(t.sessionStorage))
}
Expand Down
Loading
Loading