diff --git a/identity_provider.go b/identity_provider.go index abaaad68..16da85b3 100644 --- a/identity_provider.go +++ b/identity_provider.go @@ -914,7 +914,11 @@ func (req *IdpAuthnRequest) PostBinding() (IdpAuthnRequestForm, error) { req.ACSEndpoint.Binding) } - form.URL = req.ACSEndpoint.Location + if req.ACSEndpoint.ResponseLocation != nil { + form.URL = *req.ACSEndpoint.ResponseLocation + } else { + form.URL = req.ACSEndpoint.Location + } form.SAMLResponse = base64.StdEncoding.EncodeToString(responseBuf) form.RelayState = req.RelayState diff --git a/samlsp/middleware.go b/samlsp/middleware.go index f5eabb16..e057a031 100644 --- a/samlsp/middleware.go +++ b/samlsp/middleware.go @@ -3,6 +3,8 @@ package samlsp import ( "bytes" "encoding/xml" + "errors" + "fmt" "net/http" "github.com/crewjam/saml" @@ -193,24 +195,31 @@ func (m *Middleware) HandleStartAuthFlow(w http.ResponseWriter, r *http.Request) // CreateSessionFromAssertion is invoked by ServeHTTP when we have a new, valid SAML assertion. func (m *Middleware) CreateSessionFromAssertion(w http.ResponseWriter, r *http.Request, assertion *saml.Assertion, redirectURI string) { + var err error + + trackedRequest := &TrackedRequest{ + Method: "GET", + URI: redirectURI, + } + if trackedRequestIndex := r.Form.Get("RelayState"); trackedRequestIndex != "" { - trackedRequest, err := m.RequestTracker.GetTrackedRequest(r, trackedRequestIndex) + trackedRequest, err = m.RequestTracker.GetTrackedRequest(r, trackedRequestIndex) if err != nil { - if err == http.ErrNoCookie && m.ServiceProvider.AllowIDPInitiated { - if uri := r.Form.Get("RelayState"); uri != "" { - redirectURI = uri + if errors.Is(err, http.ErrNoCookie) && m.ServiceProvider.AllowIDPInitiated { + // We don't need to re-read RelayState from the form and check it for nil. The test above did that + trackedRequest = &TrackedRequest{ + Method: "GET", + URI: trackedRequestIndex, } } else { m.OnError(w, r, err) return } } else { - if err := m.RequestTracker.StopTrackingRequest(w, r, trackedRequestIndex); err != nil { + if err = m.RequestTracker.StopTrackingRequest(w, r, trackedRequestIndex); err != nil { m.OnError(w, r, err) return } - - redirectURI = trackedRequest.URI } } @@ -218,8 +227,36 @@ func (m *Middleware) CreateSessionFromAssertion(w http.ResponseWriter, r *http.R m.OnError(w, r, err) return } + m.HandleRedirectAfterAssertion(w, r, trackedRequest) +} + +// HandleRedirectAfterAssertion is called after we've handled receiving a SAML assertion and created a session with the +// browser. Most normal cases are just a redirect, but if the original request was a POST, it's a little more tricky. +func (m *Middleware) HandleRedirectAfterAssertion(w http.ResponseWriter, r *http.Request, trackedRequest *TrackedRequest) { + switch trackedRequest.Method { + case "POST": + text := fmt.Sprintf(``+ + `
` + + `` + + `` + + `` - http.Redirect(w, r, redirectURI, http.StatusFound) + if _, err := w.Write([]byte(text)); err != nil { + m.OnError(w, r, err) + return + } + return + // TODO: Handle HEAD, DELETE, etc. + default: + http.Redirect(w, r, trackedRequest.URI, http.StatusFound) + } } // RequireAttribute returns a middleware function that requires that the diff --git a/samlsp/middleware_test.go b/samlsp/middleware_test.go index fdb05b20..59f8f0b1 100644 --- a/samlsp/middleware_test.go +++ b/samlsp/middleware_test.go @@ -115,6 +115,7 @@ func (test *MiddlewareTest) makeTrackedRequest(id string) string { Index: "KCosLjAyNDY4Ojw-QEJERkhKTE5QUlRWWFpcXmBiZGZoamxucHJ0dnh6", SAMLRequestID: id, URI: "/frob", + Method: "GET", }) if err != nil { panic(err) diff --git a/samlsp/request_tracker.go b/samlsp/request_tracker.go index f5477c8d..b3cfc5a2 100644 --- a/samlsp/request_tracker.go +++ b/samlsp/request_tracker.go @@ -2,6 +2,7 @@ package samlsp import ( "net/http" + "net/url" ) // RequestTracker tracks pending authentication requests. @@ -31,9 +32,11 @@ type RequestTracker interface { // TrackedRequest holds the data we store for each pending request. type TrackedRequest struct { - Index string `json:"-"` - SAMLRequestID string `json:"id"` - URI string `json:"uri"` + Index string `json:"-"` + SAMLRequestID string `json:"id"` + URI string `json:"uri"` + Method string `json:"method"` + PostData url.Values `json:"post_data"` } // TrackedRequestCodec handles encoding and decoding of a TrackedRequest. diff --git a/samlsp/request_tracker_cookie.go b/samlsp/request_tracker_cookie.go index d9189f63..13bb6fa0 100644 --- a/samlsp/request_tracker_cookie.go +++ b/samlsp/request_tracker_cookie.go @@ -26,10 +26,17 @@ type CookieRequestTracker struct { // TrackRequest starts tracking the SAML request with the given ID. It returns an // `index` that should be used as the RelayState in the SAMl request flow. func (t CookieRequestTracker) TrackRequest(w http.ResponseWriter, r *http.Request, samlRequestID string) (string, error) { + if r.Method == "POST" { + if err := r.ParseForm(); err != nil { + return "", err + } + } trackedRequest := TrackedRequest{ Index: base64.RawURLEncoding.EncodeToString(randomBytes(42)), SAMLRequestID: samlRequestID, URI: r.URL.String(), + Method: r.Method, + PostData: r.PostForm, } if t.RelayStateFunc != nil {