@@ -17,6 +17,14 @@ import (
1717 "time"
1818)
1919
20+ type Clock interface {
21+ Now () time.Time
22+ }
23+
24+ type realClock struct {}
25+
26+ func (_ realClock ) Now () time.Time { return time .Now () }
27+
2028const tokenRefreshMargin = 10 * time .Second
2129
2230// KeycloakConfig holds the credentials and configuration details
@@ -44,30 +52,24 @@ type KeycloakClient struct {
4452 cfg KeycloakConfig
4553 tokenInfo tokenInfo
4654 tokenMutex sync.RWMutex
55+
56+ clock Clock // to mock time in tests
4757}
4858
4959// NewKeycloakClient creates a new KeycloakClient.
5060func NewKeycloakClient (httpClient * http.Client , cfg KeycloakConfig ) * KeycloakClient {
5161 return & KeycloakClient {
5262 httpClient : httpClient ,
5363 cfg : cfg ,
64+ clock : realClock {},
5465 }
5566}
5667
5768// GetToken retrieves a valid access token. The token is cached and refreshed before expiry.
5869// The token is obtained by `Resource owner password credentials grant` flow.
5970// Ref: https://www.keycloak.org/docs/latest/server_admin/index.html#_oidc-auth-flows-direct
6071func (c * KeycloakClient ) GetToken (ctx context.Context ) (string , error ) {
61- getCachedToken := func () (token string , ok bool ) {
62- c .tokenMutex .RLock ()
63- defer c .tokenMutex .RUnlock ()
64- if time .Now ().Before (c .tokenInfo .ExpiresAt .Add (- tokenRefreshMargin )) {
65- return c .tokenInfo .AccessToken , true
66- }
67- return "" , false
68- }
69-
70- token , ok := getCachedToken ()
72+ token , ok := c .getCachedToken ()
7173 if ok {
7274 return token , nil
7375 }
@@ -77,7 +79,7 @@ func (c *KeycloakClient) GetToken(ctx context.Context) (string, error) {
7779 defer c .tokenMutex .Unlock ()
7880
7981 // check again in case another goroutine already refreshed the token
80- if time .Now ().Before (c .tokenInfo .ExpiresAt .Add (- tokenRefreshMargin )) {
82+ if c . clock .Now ().Before (c .tokenInfo .ExpiresAt .Add (- tokenRefreshMargin )) {
8183 return c .tokenInfo .AccessToken , nil
8284 }
8385
@@ -88,15 +90,22 @@ func (c *KeycloakClient) GetToken(ctx context.Context) (string, error) {
8890
8991 c .tokenInfo = tokenInfo {
9092 AccessToken : authResponse .AccessToken ,
91- ExpiresAt : time . Now ().Add (time .Duration (authResponse .ExpiresIn ) * time .Second ),
93+ ExpiresAt : c . clock . Now (). UTC ().Add (time .Duration (authResponse .ExpiresIn ) * time .Second ),
9294 }
9395
9496 return authResponse .AccessToken , nil
9597}
9698
97- func (c * KeycloakClient ) requestToken (ctx context.Context ) (authResponse , error ) {
98- var empty authResponse
99+ func (c * KeycloakClient ) getCachedToken () (token string , ok bool ) {
100+ c .tokenMutex .RLock ()
101+ defer c .tokenMutex .RUnlock ()
102+ if c .clock .Now ().Before (c .tokenInfo .ExpiresAt .Add (- tokenRefreshMargin )) {
103+ return c .tokenInfo .AccessToken , true
104+ }
105+ return "" , false
106+ }
99107
108+ func (c * KeycloakClient ) requestToken (ctx context.Context ) (* authResponse , error ) {
100109 data := url.Values {}
101110 data .Set ("client_id" , c .cfg .ClientID )
102111 data .Set ("password" , c .cfg .Password )
@@ -108,13 +117,13 @@ func (c *KeycloakClient) requestToken(ctx context.Context) (authResponse, error)
108117
109118 req , err := http .NewRequestWithContext (ctx , http .MethodPost , authenticationURL , strings .NewReader (data .Encode ()))
110119 if err != nil {
111- return empty , fmt .Errorf ("failed to create authentication request: %w" , err )
120+ return nil , fmt .Errorf ("failed to create authentication request: %w" , err )
112121 }
113122 req .Header .Set ("Content-Type" , "application/x-www-form-urlencoded" )
114123
115124 resp , err := c .httpClient .Do (req )
116125 if err != nil {
117- return empty , fmt .Errorf ("failed to execute authentication request with retry : %w" , err )
126+ return nil , fmt .Errorf ("failed to execute authentication request: %w" , err )
118127 }
119128 defer func () { _ = resp .Body .Close () }()
120129
@@ -123,13 +132,13 @@ func (c *KeycloakClient) requestToken(ctx context.Context) (authResponse, error)
123132 if err != nil {
124133 respBody = []byte ("failed to read response body: " + err .Error ())
125134 }
126- return empty , fmt .Errorf ("authentication request failed with status: %s: %s" , resp .Status , string (respBody ))
135+ return nil , fmt .Errorf ("authentication request failed with status: %s: %s" , resp .Status , string (respBody ))
127136 }
128137
129- var authResponse authResponse
130- if err := json .NewDecoder (resp .Body ).Decode (& authResponse ); err != nil {
131- return empty , fmt .Errorf ("failed to parse authentication response: %w" , err )
138+ var authResp authResponse
139+ if err := json .NewDecoder (resp .Body ).Decode (& authResp ); err != nil {
140+ return nil , fmt .Errorf ("failed to parse authentication response: %w" , err )
132141 }
133142
134- return authResponse , nil
143+ return & authResp , nil
135144}
0 commit comments