Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 32 additions & 176 deletions pkg/auth/tokenexchange/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,17 @@ package tokenexchange

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/url"
"strconv"
"strings"
"time"

"golang.org/x/oauth2"

"github.com/stacklok/toolhive/pkg/oauthproto"
)

// maxResponseBodySize bounds io.LimitReader in executeTokenExchangeRequest so
// a pathological server cannot exhaust memory. The shared pkg/oauthproto
// package has an identical unexported constant, but we cannot import it yet —
// the shared one is consumed by oauthproto.DoTokenRequest, which will replace
// executeTokenExchangeRequest in a follow-up commit.
// TODO: drop when executeTokenExchangeRequest is replaced by oauthproto.DoTokenRequest.
const maxResponseBodySize = 1 << 20

// NormalizeTokenType converts a short token type name to its full URN.
// Accepts both short forms ("access_token", "id_token", "jwt") and full URNs.
// Returns the full URN or an error if the token type is invalid.
Expand Down Expand Up @@ -94,22 +82,6 @@ func (r exchangeRequest) String() string {
r.GrantType, r.Audience, r.Resource, r.Scope, oauthproto.Redact(r.SubjectToken), actorToken)
}

// response is used to decode the remote server response during an OAuth 2.0 token exchange.
type response struct {
AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
RefreshToken string `json:"refresh_token"`
}

// String implements fmt.Stringer for response, redacting sensitive tokens.
func (r response) String() string {
return fmt.Sprintf("response{AccessToken: %s, TokenType: %s, ExpiresIn: %d, RefreshToken: %s}",
oauthproto.Redact(r.AccessToken), r.TokenType, r.ExpiresIn, oauthproto.Redact(r.RefreshToken))
}

// clientAuthentication represents OAuth client credentials for token exchange.
type clientAuthentication struct {
ClientID string
Expand Down Expand Up @@ -166,6 +138,13 @@ type ExchangeConfig struct {
}

// Validate checks if the ExchangeConfig contains all required fields.
//
// Side effect: when SubjectTokenType is provided as a short form
// ("access_token", "id_token", "jwt"), Validate normalizes it to the full
// RFC 8693 URN and writes the result back onto the receiver. Callers in
// pkg/vmcp/auth/strategies/tokenexchange.go and pkg/runner/config_builder.go
// read the normalized value after Validate returns, so this mutation is part
// of the documented contract.
func (c *ExchangeConfig) Validate() error {
if c.TokenURL == "" {
return fmt.Errorf("TokenURL is required")
Expand Down Expand Up @@ -267,33 +246,18 @@ func (ts *tokenSource) Token() (*oauth2.Token, error) {
return nil, err
}

// Validate required RFC 8693 response fields
if resp.AccessToken == "" {
return nil, fmt.Errorf("token exchange: server returned empty access_token")
}
if resp.TokenType == "" {
return nil, fmt.Errorf("token exchange: server returned empty token_type")
// RFC 8693 Section 2.2.1 requires token_type in the response. The shared
// oauthproto.ParseTokenResponse is intentionally permissive on this field
// (matching x/oauth2); token exchange tightens it back.
if resp.Token.TokenType == "" {
return nil, fmt.Errorf("token exchange: server returned empty token_type (required by RFC 8693)")
}
// RFC 8693 Section 2.2.1 requires issued_token_type in the response.
if resp.IssuedTokenType == "" {
return nil, fmt.Errorf("token exchange: server returned empty issued_token_type (required by RFC 8693)")
}

// Build oauth2.Token
token := &oauth2.Token{
AccessToken: resp.AccessToken,
TokenType: resp.TokenType,
}

// Set expiry if provided
if resp.ExpiresIn > 0 {
token.Expiry = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second)
}

if resp.RefreshToken != "" {
token.RefreshToken = resp.RefreshToken
}

return token, nil
return resp.Token, nil
}

// TokenSource returns an oauth2.TokenSource that performs token exchange.
Expand All @@ -312,36 +276,35 @@ func exchangeToken(
request *exchangeRequest,
auth clientAuthentication,
client *http.Client,
) (*response, error) {
data, err := buildTokenExchangeFormData(request)
) (*oauthproto.TokenResponse, error) {
data, err := buildFormData(request)
if err != nil {
return nil, err
}

req, err := createTokenExchangeRequest(ctx, endpoint, data, auth)
req, err := oauthproto.NewFormRequest(ctx, endpoint, data, auth.ClientID, auth.ClientSecret)
if err != nil {
return nil, err
}

if client == nil {
client = oauthproto.DefaultHTTPClient()
return nil, fmt.Errorf("tokenexchange: build request: %w", err)
}

body, err := executeTokenExchangeRequest(client, req)
if err != nil {
return nil, err
}

tokenResp, err := parseTokenExchangeResponse(body)
resp, err := oauthproto.DoTokenRequest(client, req)
if err != nil {
// Preserve the pre-refactor behavior: scrub RetrieveError.Body so raw
// upstream content cannot leak into error strings via err.Error().
// pkg/oauthproto deliberately preserves Body for general-purpose
// callers; tokenexchange opts back into the stricter behavior because
// its errors propagate through vmcp / runner paths that may log them.
var retrieveErr *oauth2.RetrieveError
if errors.As(err, &retrieveErr) {
retrieveErr.Body = nil
}
return nil, err
}

return tokenResp, nil
return resp, nil
}

// buildTokenExchangeFormData constructs the form data for a token exchange request according to RFC 8693.
func buildTokenExchangeFormData(request *exchangeRequest) (url.Values, error) {
// buildFormData constructs the form data for a token exchange request according to RFC 8693.
func buildFormData(request *exchangeRequest) (url.Values, error) {
data := url.Values{}

// Grant type is always token exchange
Expand Down Expand Up @@ -393,110 +356,3 @@ func addOptionalFields(data url.Values, request *exchangeRequest) {
}
}
}

// createTokenExchangeRequest creates an HTTP POST request for token exchange.
// Client credentials are sent via HTTP Basic Authentication as recommended by RFC 6749 Section 2.3.1.
func createTokenExchangeRequest(
ctx context.Context,
endpoint string,
data url.Values,
auth clientAuthentication,
) (*http.Request, error) {
encodedData := data.Encode()
req, err := http.NewRequestWithContext(ctx, "POST", endpoint, strings.NewReader(encodedData))
if err != nil {
return nil, fmt.Errorf("failed to create token exchange request: %w", err)
}

req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Content-Length", strconv.Itoa(len(encodedData)))

// Add client authentication via HTTP Basic Auth per RFC 6749 Section 2.3.1
// Per RFC 6749 and Go's SetBasicAuth documentation, credentials must be URL-encoded
// before being passed to SetBasicAuth for OAuth2 compatibility
if auth.ClientID != "" && auth.ClientSecret != "" {
req.SetBasicAuth(url.QueryEscape(auth.ClientID), url.QueryEscape(auth.ClientSecret))
}

return req, nil
}

// executeTokenExchangeRequest sends the HTTP request and returns the response body.
func executeTokenExchangeRequest(client *http.Client, req *http.Request) ([]byte, error) {
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("token exchange request failed: %w", err)
}
defer func() {
// Close without draining — matches oauthproto.DoTokenRequest. The
// LimitReader below caps how much we read; draining the remainder
// would be unbounded on oversized or never-terminating bodies and
// defeat the 1 MiB memory cap. Connection reuse is the tradeoff.
if closeErr := resp.Body.Close(); closeErr != nil {
slog.Debug("token exchange: close response body", "error", closeErr)
}
}()

body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize))
if err != nil {
return nil, fmt.Errorf("failed to read token exchange response: %w", err)
}

if err := validateResponseStatus(resp, body); err != nil {
return nil, err
}

return body, nil
}

// validateResponseStatus checks the HTTP status code and returns an error if not successful.
// On non-2xx responses it extracts RFC 6749 §5.2 fields (error, error_description, error_uri)
// onto the structured fields of the returned *oauth2.RetrieveError. Body is always cleared so
// callers cannot interpolate raw upstream content into error strings — matching the pattern used
// by Ory Hydra, which never surfaces raw error bodies through its public error type.
func validateResponseStatus(resp *http.Response, body []byte) error {
if resp.StatusCode >= 200 && resp.StatusCode <= 299 {
return nil
}

retrieveErr := &oauth2.RetrieveError{
Response: resp,
Body: body,
}

// Best-effort parse of the RFC 6749 Section 5.2 error response. Non-JSON or
// non-error-shaped bodies leave ErrorCode/ErrorDescription/ErrorURI empty.
var oauthErr struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description,omitempty"`
ErrorURI string `json:"error_uri,omitempty"`
}
if err := json.Unmarshal(body, &oauthErr); err == nil {
retrieveErr.ErrorCode = oauthErr.Error
retrieveErr.ErrorDescription = oauthErr.ErrorDescription
retrieveErr.ErrorURI = oauthErr.ErrorURI
}

if retrieveErr.ErrorCode != "" {
slog.Debug("Token exchange OAuth error",
"oauth_error_code", retrieveErr.ErrorCode,
"description", retrieveErr.ErrorDescription)
} else {
slog.Debug("Token exchange failed", "status", resp.StatusCode, "body_length", len(body), "body", string(body))
}

retrieveErr.Body = nil

return retrieveErr
}

// parseTokenExchangeResponse parses the token exchange response body.
func parseTokenExchangeResponse(body []byte) (*response, error) {
var tokenResp response
if err := json.Unmarshal(body, &tokenResp); err != nil {
slog.Debug("Failed to parse token exchange response", "error", err)
return nil, fmt.Errorf("failed to parse token exchange response: %w", err)
}

return &tokenResp, nil
}
Loading
Loading