diff --git a/src/main/java/org/seedstack/oauth/OAuthConfig.java b/src/main/java/org/seedstack/oauth/OAuthConfig.java index 2e3d2be..455b767 100644 --- a/src/main/java/org/seedstack/oauth/OAuthConfig.java +++ b/src/main/java/org/seedstack/oauth/OAuthConfig.java @@ -23,6 +23,7 @@ public class OAuthConfig { @NotNull private ProviderConfig provider = new ProviderConfig(); + private RetrieverConfig retriever = new RetrieverConfig(); @NotNull private AlgorithmConfig algorithms = new AlgorithmConfig(); private URI discoveryDocument; @@ -223,7 +224,15 @@ public OAuthConfig setDiscloseUnauthorizedReason(boolean discloseUnauthorizedRea return this; } - @Config("provider") + public RetrieverConfig getRetriever() { + return retriever; + } + + public void setRetriever(RetrieverConfig retriever) { + this.retriever = retriever; + } + + @Config("provider") public static class ProviderConfig { private URI authorization; private URI token; @@ -318,4 +327,23 @@ public AlgorithmConfig setPlainTokenAllowed(boolean plainTokenAllowed) { return this; } } + @Config("retriever") + public static class RetrieverConfig { + private String connectTimeout; + private String readTimeout; + public String getConnectTimeout() { + return connectTimeout; + } + public void setConnectTimeout(String connectTimeout) { + this.connectTimeout = connectTimeout; + } + public String getReadTimeout() { + return readTimeout; + } + public void setReadTimeout(String readTimeout) { + this.readTimeout = readTimeout; + } + + + } } diff --git a/src/main/java/org/seedstack/oauth/internal/OAuthServiceImpl.java b/src/main/java/org/seedstack/oauth/internal/OAuthServiceImpl.java index dc08474..9e99da4 100644 --- a/src/main/java/org/seedstack/oauth/internal/OAuthServiceImpl.java +++ b/src/main/java/org/seedstack/oauth/internal/OAuthServiceImpl.java @@ -16,6 +16,7 @@ import com.nimbusds.jose.proc.JWSKeySelector; import com.nimbusds.jose.proc.JWSVerificationKeySelector; import com.nimbusds.jose.proc.SecurityContext; +import com.nimbusds.jose.util.DefaultResourceRetriever; import com.nimbusds.jwt.*; import com.nimbusds.jwt.proc.ConfigurableJWTProcessor; import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier; @@ -178,14 +179,11 @@ private JWTClaimsSet validateJwtAccessToken(JWT accessToken, Algorithm algorithm // Signing key selector oauthProvider.getJwksEndpoint().ifPresent(jwksEndpoint -> { - try { - JWKSource keySource = new RemoteJWKSet<>(jwksEndpoint.toURL()); + JWKSource keySource = getkeySource(jwksEndpoint); JWSAlgorithm expectedAlg = JWSAlgorithm.parse(oauthConfig.algorithms().getAccessSigningAlgorithm()); JWSKeySelector keySelector = new JWSVerificationKeySelector<>(expectedAlg, keySource); jwtProcessor.setJWSKeySelector(keySelector); - } catch (MalformedURLException e) { - throw new TokenValidationException("Invalid JWKS endpoint: " + e.getMessage()); - } + }); // Claims verification @@ -219,6 +217,24 @@ private JWTClaimsSet validateJwtAccessToken(JWT accessToken, Algorithm algorithm throw new TokenValidationException("Unable to validate JWT access token: " + e.getMessage(), e); } } + + private JWKSource getkeySource(URI jwksEndpoint){ + try { + JWKSource keySource = new RemoteJWKSet<>(jwksEndpoint.toURL()); + String connectTimeout =oauthConfig.getRetriever().getConnectTimeout(); + String readTimeOut=oauthConfig.getRetriever().getReadTimeout(); + if(connectTimeout!=null && ! connectTimeout.equals("") && + readTimeOut!=null && ! readTimeOut.equals("")) { + DefaultResourceRetriever defaultResourceRetriever= + new DefaultResourceRetriever(Integer.parseInt(connectTimeout), + Integer.parseInt(readTimeOut)); + keySource = new RemoteJWKSet<>(jwksEndpoint.toURL(),defaultResourceRetriever); + } + return keySource; + } catch (MalformedURLException e) { + throw new TokenValidationException("Invalid JWKS endpoint: " + e.getMessage()); + } + } private JWTClaimsSet validateOpaqueAccessToken(AccessToken accessToken) { AccessTokenValidator accessTokenValidator = accessTokenValidatorProvider.get(); @@ -313,8 +329,9 @@ Optional fetchUserInfo(String accessToken) { if (userInfoResponse.indicatesSuccess()) { return Optional.of(((UserInfoSuccessResponse) userInfoResponse).getUserInfo()); } else { - LOGGER.warn("Unable to fetch user info: {}", OAuthUtils.buildGenericError(((ErrorResponse) userInfoResponse)).getDescription()); - return Optional.empty(); + throw new TokenValidationException("Unable to validate the access token (HTTP status " + + userInfoResponse.toErrorResponse().getErrorObject().getHTTPStatusCode() + "): " + + userInfoResponse.toErrorResponse().getErrorObject().getDescription()); } } return Optional.empty();