Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -210,19 +210,20 @@ public void setAuthorizationHeaderSync(HttpPipelineCallContext context, TokenReq
private Mono<Void> setAuthorizationHeaderHelper(HttpPipelineCallContext context,
TokenRequestContext tokenRequestContext, boolean checkToForceFetchToken) {
return cache.getToken(tokenRequestContext, checkToForceFetchToken).flatMap(token -> {
setAuthorizationHeader(context.getHttpRequest().getHeaders(), token.getToken());
setAuthorizationHeader(context.getHttpRequest().getHeaders(), token);
return Mono.empty();
});
}

private void setAuthorizationHeaderHelperSync(HttpPipelineCallContext context,
TokenRequestContext tokenRequestContext, boolean checkToForceFetchToken) {
AccessToken token = cache.getTokenSync(tokenRequestContext, checkToForceFetchToken);
setAuthorizationHeader(context.getHttpRequest().getHeaders(), token.getToken());
setAuthorizationHeader(context.getHttpRequest().getHeaders(), token);
}

private static void setAuthorizationHeader(HttpHeaders headers, String token) {
headers.set(HttpHeaderName.AUTHORIZATION, BEARER + " " + token);
private static void setAuthorizationHeader(HttpHeaders headers, AccessToken token) {
String tokenType = CoreUtils.isNullOrEmpty(token.getTokenType()) ? BEARER : token.getTokenType();
headers.set(HttpHeaderName.AUTHORIZATION, tokenType + " " + token.getToken());
}

private TokenRequestContext getTokenRequestContextForCaeChallenge(HttpResponse response) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.azure.core.http.MockHttpResponse;
import com.azure.core.implementation.http.policy.AuthorizationChallengeParser;
import com.azure.core.util.Context;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
Expand All @@ -32,6 +33,45 @@

public class BearerTokenAuthenticationPolicyTests {

@Test
public void usesAccessTokenTypeInAuthorizationHeader() {
TokenCredential credential
= request -> Mono.just(new AccessToken("token", OffsetDateTime.now().plusHours(2), null, "Pop"));
BearerTokenAuthenticationPolicy policy = new BearerTokenAuthenticationPolicy(credential, "scope");
AtomicReference<String> authorizationHeader = new AtomicReference<>();
HttpClient client = request -> {
authorizationHeader.set(request.getHeaders().getValue(HttpHeaderName.AUTHORIZATION));
return Mono.just(new MockHttpResponse(request, 200));
};

HttpPipeline pipeline = new HttpPipelineBuilder().policies(policy).httpClient(client).build();

StepVerifier.create(pipeline.send(new HttpRequest(HttpMethod.GET, "https://localhost")))
.assertNext(response -> assertEquals(200, response.getStatusCode()))
.verifyComplete();
assertEquals("Pop token", authorizationHeader.get());
}

@Test
public void usesAccessTokenTypeInAuthorizationHeaderSync() {
TokenCredential credential
= request -> Mono.just(new AccessToken("token", OffsetDateTime.now().plusHours(2), null, "Pop"));
BearerTokenAuthenticationPolicy policy = new BearerTokenAuthenticationPolicy(credential, "scope");
AtomicReference<String> authorizationHeader = new AtomicReference<>();
HttpClient client = request -> {
authorizationHeader.set(request.getHeaders().getValue(HttpHeaderName.AUTHORIZATION));
return Mono.just(new MockHttpResponse(request, 200));
};

HttpPipeline pipeline = new HttpPipelineBuilder().policies(policy).httpClient(client).build();

try (HttpResponse response
= pipeline.sendSync(new HttpRequest(HttpMethod.GET, "https://localhost"), Context.NONE)) {
assertEquals(200, response.getStatusCode());
}
assertEquals("Pop token", authorizationHeader.get());
}

@ParameterizedTest
@MethodSource("caeTestArguments")
public void testDefaultCae(String challenge, int expectedStatusCode, String expectedClaims, String encodedClaims) {
Expand Down
Loading