Skip to content

Commit 2445a35

Browse files
committed
code review changes
1 parent 598390e commit 2445a35

2 files changed

Lines changed: 62 additions & 27 deletions

File tree

pkg/auth/auth_client.go

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ import (
1717
"time"
1818
)
1919

20+
type Clock interface {
21+
Now() time.Time
22+
}
23+
24+
type realClock struct{}
25+
26+
func (_ realClock) Now() time.Time { return time.Now() }
27+
2028
const tokenRefreshMargin = 10 * time.Second
2129

2230
// KeycloakConfig holds the credentials and configuration details
@@ -44,30 +52,24 @@ type KeycloakClient struct {
4452
cfg KeycloakConfig
4553
tokenInfo tokenInfo
4654
tokenMutex sync.RWMutex
55+
56+
clock Clock // to mock time in tests
4757
}
4858

4959
// NewKeycloakClient creates a new KeycloakClient.
5060
func NewKeycloakClient(httpClient *http.Client, cfg KeycloakConfig) *KeycloakClient {
5161
return &KeycloakClient{
5262
httpClient: httpClient,
5363
cfg: cfg,
64+
clock: realClock{},
5465
}
5566
}
5667

5768
// GetToken retrieves a valid access token. The token is cached and refreshed before expiry.
5869
// The token is obtained by `Resource owner password credentials grant` flow.
5970
// Ref: https://www.keycloak.org/docs/latest/server_admin/index.html#_oidc-auth-flows-direct
6071
func (c *KeycloakClient) GetToken(ctx context.Context) (string, error) {
61-
getCachedToken := func() (token string, ok bool) {
62-
c.tokenMutex.RLock()
63-
defer c.tokenMutex.RUnlock()
64-
if time.Now().Before(c.tokenInfo.ExpiresAt.Add(-tokenRefreshMargin)) {
65-
return c.tokenInfo.AccessToken, true
66-
}
67-
return "", false
68-
}
69-
70-
token, ok := getCachedToken()
72+
token, ok := c.getCachedToken()
7173
if ok {
7274
return token, nil
7375
}
@@ -77,7 +79,7 @@ func (c *KeycloakClient) GetToken(ctx context.Context) (string, error) {
7779
defer c.tokenMutex.Unlock()
7880

7981
// check again in case another goroutine already refreshed the token
80-
if time.Now().Before(c.tokenInfo.ExpiresAt.Add(-tokenRefreshMargin)) {
82+
if c.clock.Now().Before(c.tokenInfo.ExpiresAt.Add(-tokenRefreshMargin)) {
8183
return c.tokenInfo.AccessToken, nil
8284
}
8385

@@ -88,15 +90,22 @@ func (c *KeycloakClient) GetToken(ctx context.Context) (string, error) {
8890

8991
c.tokenInfo = tokenInfo{
9092
AccessToken: authResponse.AccessToken,
91-
ExpiresAt: time.Now().Add(time.Duration(authResponse.ExpiresIn) * time.Second),
93+
ExpiresAt: c.clock.Now().UTC().Add(time.Duration(authResponse.ExpiresIn) * time.Second),
9294
}
9395

9496
return authResponse.AccessToken, nil
9597
}
9698

97-
func (c *KeycloakClient) requestToken(ctx context.Context) (authResponse, error) {
98-
var empty authResponse
99+
func (c *KeycloakClient) getCachedToken() (token string, ok bool) {
100+
c.tokenMutex.RLock()
101+
defer c.tokenMutex.RUnlock()
102+
if c.clock.Now().Before(c.tokenInfo.ExpiresAt.Add(-tokenRefreshMargin)) {
103+
return c.tokenInfo.AccessToken, true
104+
}
105+
return "", false
106+
}
99107

108+
func (c *KeycloakClient) requestToken(ctx context.Context) (*authResponse, error) {
100109
data := url.Values{}
101110
data.Set("client_id", c.cfg.ClientID)
102111
data.Set("password", c.cfg.Password)
@@ -108,13 +117,13 @@ func (c *KeycloakClient) requestToken(ctx context.Context) (authResponse, error)
108117

109118
req, err := http.NewRequestWithContext(ctx, http.MethodPost, authenticationURL, strings.NewReader(data.Encode()))
110119
if err != nil {
111-
return empty, fmt.Errorf("failed to create authentication request: %w", err)
120+
return nil, fmt.Errorf("failed to create authentication request: %w", err)
112121
}
113122
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
114123

115124
resp, err := c.httpClient.Do(req)
116125
if err != nil {
117-
return empty, fmt.Errorf("failed to execute authentication request with retry: %w", err)
126+
return nil, fmt.Errorf("failed to execute authentication request: %w", err)
118127
}
119128
defer func() { _ = resp.Body.Close() }()
120129

@@ -123,13 +132,13 @@ func (c *KeycloakClient) requestToken(ctx context.Context) (authResponse, error)
123132
if err != nil {
124133
respBody = []byte("failed to read response body: " + err.Error())
125134
}
126-
return empty, fmt.Errorf("authentication request failed with status: %s: %s", resp.Status, string(respBody))
135+
return nil, fmt.Errorf("authentication request failed with status: %s: %s", resp.Status, string(respBody))
127136
}
128137

129-
var authResponse authResponse
130-
if err := json.NewDecoder(resp.Body).Decode(&authResponse); err != nil {
131-
return empty, fmt.Errorf("failed to parse authentication response: %w", err)
138+
var authResp authResponse
139+
if err := json.NewDecoder(resp.Body).Decode(&authResp); err != nil {
140+
return nil, fmt.Errorf("failed to parse authentication response: %w", err)
132141
}
133142

134-
return authResponse, nil
143+
return &authResp, nil
135144
}

pkg/auth/auth_client_test.go

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ package auth
66

77
import (
88
"context"
9+
"fmt"
910
"net/http"
1011
"net/http/httptest"
1112
"sync/atomic"
1213
"testing"
14+
"time"
1315

1416
"github.com/stretchr/testify/assert"
1517
"github.com/stretchr/testify/require"
@@ -62,34 +64,54 @@ func TestKeycloakClient_GetToken(t *testing.T) {
6264
}
6365
}
6466

67+
type fakeClock struct {
68+
currentTime time.Time
69+
}
70+
71+
func (fc *fakeClock) Now() time.Time {
72+
return fc.currentTime
73+
}
74+
75+
func (fc *fakeClock) Advance(d time.Duration) {
76+
fc.currentTime = fc.currentTime.Add(d)
77+
}
78+
79+
func NewFakeClock(startTime time.Time) *fakeClock {
80+
return &fakeClock{currentTime: startTime}
81+
}
82+
6583
func TestKeycloakClient_GetToken_Refresh(t *testing.T) {
84+
tokenValidity := 60 * time.Second
85+
kcMockResponse := []byte(fmt.Sprintf(`{"access_token": "test-token", "expires_in": %d}`, int(tokenValidity.Seconds())))
86+
6687
tests := map[string]struct {
6788
responseBody string
6889
responseCode int
90+
requestAfter time.Duration
6991
wantServerCalled int
7092
wantToken string
7193
}{
7294
"token is cached": {
73-
responseBody: `{"access_token": "test-token", "expires_in": 3600}`,
74-
responseCode: http.StatusOK,
95+
requestAfter: tokenValidity - tokenRefreshMargin - time.Nanosecond,
7596
wantServerCalled: 1, // should be called only once due to caching
7697
wantToken: "test-token",
7798
},
7899
"token expiry handling": {
79-
responseBody: `{"access_token": "test-token", "expires_in": 0}`, // expires immediately
80-
responseCode: http.StatusOK,
100+
requestAfter: tokenValidity - tokenRefreshMargin + time.Nanosecond,
81101
wantServerCalled: 2, // should be called twice due to expiry
82102
wantToken: "test-token",
83103
},
84104
}
85105

86106
for name, tc := range tests {
87107
t.Run(name, func(t *testing.T) {
108+
fakeClock := NewFakeClock(time.Now())
109+
88110
var serverCallCount atomic.Int32
89111
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
90112
serverCallCount.Add(1)
91-
w.WriteHeader(tc.responseCode)
92-
_, err := w.Write([]byte(tc.responseBody))
113+
w.WriteHeader(200)
114+
_, err := w.Write(kcMockResponse)
93115
require.NoError(t, err)
94116
}))
95117
defer mockServer.Close()
@@ -98,9 +120,13 @@ func TestKeycloakClient_GetToken_Refresh(t *testing.T) {
98120
AuthURL: mockServer.URL,
99121
// the other fields are also required in real scenario, but omit here for brevity
100122
})
123+
client.clock = fakeClock
124+
101125
_, err := client.GetToken(context.Background())
102126
require.NoError(t, err)
103127

128+
fakeClock.Advance(tc.requestAfter)
129+
104130
gotToken, err := client.GetToken(context.Background()) // second call to test caching/refresh
105131
require.NoError(t, err)
106132

0 commit comments

Comments
 (0)