diff --git a/pom.xml b/pom.xml index fd68196..82b8393 100644 --- a/pom.xml +++ b/pom.xml @@ -19,7 +19,7 @@ org.seedstack.addons.oauth oauth - 3.3.0-SNAPSHOT + 3.3.1 3.12.0 diff --git a/src/main/java/org/seedstack/oauth/internal/OAuthAuthenticationFilter.java b/src/main/java/org/seedstack/oauth/internal/OAuthAuthenticationFilter.java index a309871..eea30f4 100644 --- a/src/main/java/org/seedstack/oauth/internal/OAuthAuthenticationFilter.java +++ b/src/main/java/org/seedstack/oauth/internal/OAuthAuthenticationFilter.java @@ -7,18 +7,23 @@ */ package org.seedstack.oauth.internal; -import com.google.common.base.Strings; -import com.nimbusds.oauth2.sdk.AuthorizationRequest; -import com.nimbusds.oauth2.sdk.ParseException; -import com.nimbusds.oauth2.sdk.ResponseType; -import com.nimbusds.oauth2.sdk.Scope; -import com.nimbusds.oauth2.sdk.id.ClientID; -import com.nimbusds.oauth2.sdk.id.State; -import com.nimbusds.oauth2.sdk.token.AccessToken; -import com.nimbusds.oauth2.sdk.token.BearerAccessToken; -import com.nimbusds.oauth2.sdk.token.TypelessAccessToken; -import com.nimbusds.openid.connect.sdk.AuthenticationRequest; -import com.nimbusds.openid.connect.sdk.Nonce; +import static com.google.common.base.Preconditions.checkNotNull; +import static org.apache.shiro.web.util.WebUtils.issueRedirect; +import static org.seedstack.oauth.internal.OAuthUtils.OPENID_SCOPE; +import static org.seedstack.oauth.internal.OAuthUtils.createScope; + +import java.io.IOException; +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import javax.inject.Inject; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.apache.shiro.SecurityUtils; import org.apache.shiro.authc.AuthenticationException; import org.apache.shiro.authc.AuthenticationToken; @@ -35,21 +40,18 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import javax.inject.Inject; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; -import java.net.URI; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import static com.google.common.base.Preconditions.checkNotNull; -import static org.apache.shiro.web.util.WebUtils.issueRedirect; -import static org.seedstack.oauth.internal.OAuthUtils.OPENID_SCOPE; -import static org.seedstack.oauth.internal.OAuthUtils.createScope; +import com.google.common.base.Strings; +import com.nimbusds.oauth2.sdk.AuthorizationRequest; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.ResponseType; +import com.nimbusds.oauth2.sdk.Scope; +import com.nimbusds.oauth2.sdk.id.ClientID; +import com.nimbusds.oauth2.sdk.id.State; +import com.nimbusds.oauth2.sdk.token.AccessToken; +import com.nimbusds.oauth2.sdk.token.BearerAccessToken; +import com.nimbusds.oauth2.sdk.token.TypelessAccessToken; +import com.nimbusds.openid.connect.sdk.AuthenticationRequest; +import com.nimbusds.openid.connect.sdk.Nonce; @SecurityFilter("oauth") public class OAuthAuthenticationFilter extends AuthenticatingFilter implements SessionRegeneratingFilter { @@ -101,31 +103,31 @@ protected boolean onAccessDenied(ServletRequest request, ServletResponse respons loggedIn = executeLogin(request, response); } if (!loggedIn) { - if (oauthConfig.getRedirect() != null) { - redirectToAuthorizationEndpoint(request, response); - } else { - try { - ((HttpServletResponse) response).sendError( - HttpServletResponse.SC_UNAUTHORIZED, - OAuthUtils.formatUnauthorizedMessage(request, oauthConfig.isDiscloseUnauthorizedReason()) - ); - } catch (IOException e1) { - LOGGER.debug("Unable to send {} HTTP code to client", HttpServletResponse.SC_UNAUTHORIZED, e1); - } + // if (oauthConfig.getRedirect() != null) { + if (redirectToAuthorizationEndpoint(request, response)) + return loggedIn; + try { + ((HttpServletResponse) response).sendError( + HttpServletResponse.SC_UNAUTHORIZED, + OAuthUtils.formatUnauthorizedMessage(request, oauthConfig.isDiscloseUnauthorizedReason())); + } catch (IOException e1) { + LOGGER.debug("Unable to send {} HTTP code to client", HttpServletResponse.SC_UNAUTHORIZED, e1); } } return loggedIn; + } @Override protected boolean onLoginSuccess(AuthenticationToken token, Subject subject, ServletRequest request, - ServletResponse response) { + ServletResponse response) { regenerateSession(subject); return true; } + @Override protected boolean onLoginFailure(AuthenticationToken token, AuthenticationException e, - ServletRequest request, ServletResponse response) { + ServletRequest request, ServletResponse response) { if (LOGGER.isDebugEnabled()) { LOGGER.debug("Authentication exception", e); } @@ -133,24 +135,27 @@ protected boolean onLoginFailure(AuthenticationToken token, AuthenticationExcept return false; } - private void redirectToAuthorizationEndpoint(ServletRequest request, ServletResponse response) throws IOException { + private boolean redirectToAuthorizationEndpoint(ServletRequest request, ServletResponse response) throws IOException { State state = new State(); Nonce nonce = new Nonce(); Scope scope = createScope(oauthConfig.getScopes()); + URI callback = OAuthUtils.createRedirectCallback(request); + URI uri; if (scope.contains(OPENID_SCOPE)) { - uri = buildAuthenticationURI(state, nonce, scope); + uri = buildAuthenticationURI(state, nonce, scope, callback); } else { - uri = buildAuthorizationURI(state, scope); + uri = buildAuthorizationURI(state, scope, callback); } saveState(state, nonce); saveRequest(request); issueRedirect(request, response, uri.toString()); + return true; } - private URI buildAuthorizationURI(State state, Scope scope) { + private URI buildAuthorizationURI(State state, Scope scope, URI callback) { OAuthProvider oauthProvider = oAuthService.getOAuthProvider(); URI endpointURI = oauthProvider.getAuthorizationEndpoint(); Map> parameters = OAuthUtils.extractQueryParameters(endpointURI); @@ -159,10 +164,11 @@ private URI buildAuthorizationURI(State state, Scope scope) { AuthorizationRequest.Builder builder = new AuthorizationRequest.Builder( new ResponseType(ResponseType.Value.CODE), new ClientID(checkNotNull(oauthConfig.getClientId(), "Missing client identifier"))) - .scope(scope) - .redirectionURI(checkNotNull(oauthConfig.getRedirect(), "Missing redirect URI")) - .endpointURI(endpointURI) - .state(state); + .scope(scope) + .redirectionURI( + checkNotNull(oauthConfig.getRedirect() != null ? oauthConfig.getRedirect() : callback, "Missing redirect URI")) + .endpointURI(endpointURI) + .state(state); for (Map.Entry> parameter : parameters.entrySet()) { builder.customParameter(parameter.getKey(), parameter.getValue().toArray(new String[0])); @@ -171,7 +177,7 @@ private URI buildAuthorizationURI(State state, Scope scope) { return builder.build().toURI(); } - private URI buildAuthenticationURI(State state, Nonce nonce, Scope scope) { + private URI buildAuthenticationURI(State state, Nonce nonce, Scope scope, URI callback) { OAuthProvider oauthProvider = oAuthService.getOAuthProvider(); URI endpointURI = oauthProvider.getAuthorizationEndpoint(); Map> parameters = OAuthUtils.extractQueryParameters(endpointURI); @@ -181,10 +187,10 @@ private URI buildAuthenticationURI(State state, Nonce nonce, Scope scope) { new ResponseType(ResponseType.Value.CODE), scope, new ClientID(checkNotNull(oauthConfig.getClientId(), "Missing client identifier")), - checkNotNull(oauthConfig.getRedirect(), "Missing redirect URI")) - .endpointURI(endpointURI) - .state(state) - .nonce(nonce); + checkNotNull(oauthConfig.getRedirect() != null ? oauthConfig.getRedirect() : callback, "Missing redirect URI")) + .endpointURI(endpointURI) + .state(state) + .nonce(nonce); for (Map.Entry> parameter : parameters.entrySet()) { builder.customParameter(parameter.getKey(), parameter.getValue().toArray(new String[0])); diff --git a/src/main/java/org/seedstack/oauth/internal/OAuthCallbackFilter.java b/src/main/java/org/seedstack/oauth/internal/OAuthCallbackFilter.java index 96d4c76..adb61dd 100644 --- a/src/main/java/org/seedstack/oauth/internal/OAuthCallbackFilter.java +++ b/src/main/java/org/seedstack/oauth/internal/OAuthCallbackFilter.java @@ -7,9 +7,27 @@ */ package org.seedstack.oauth.internal; -import com.nimbusds.oauth2.sdk.*; -import com.nimbusds.oauth2.sdk.id.State; -import com.nimbusds.openid.connect.sdk.Nonce; +import static com.google.common.base.Preconditions.checkNotNull; +import static org.apache.shiro.web.util.WebUtils.toHttp; +import static org.seedstack.oauth.internal.OAuthUtils.buildGenericError; +import static org.seedstack.oauth.internal.OAuthUtils.createScope; +import static org.seedstack.oauth.internal.OAuthUtils.requestTokens; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.inject.Inject; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.apache.shiro.SecurityUtils; import org.apache.shiro.authc.AuthenticationException; import org.apache.shiro.authc.AuthenticationToken; @@ -25,19 +43,14 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import javax.inject.Inject; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; -import java.net.URI; -import java.net.URISyntaxException; -import java.util.*; - -import static com.google.common.base.Preconditions.checkNotNull; -import static org.apache.shiro.web.util.WebUtils.toHttp; -import static org.seedstack.oauth.internal.OAuthUtils.*; +import com.nimbusds.oauth2.sdk.AuthorizationCode; +import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant; +import com.nimbusds.oauth2.sdk.AuthorizationResponse; +import com.nimbusds.oauth2.sdk.AuthorizationSuccessResponse; +import com.nimbusds.oauth2.sdk.ErrorResponse; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.id.State; +import com.nimbusds.openid.connect.sdk.Nonce; @SecurityFilter("oauthCallback") public class OAuthCallbackFilter extends AuthenticatingFilter implements SessionRegeneratingFilter { @@ -56,10 +69,10 @@ protected AuthenticationToken createToken(ServletRequest request, ServletRespons oauthConfig, new AuthorizationCodeGrant( authorizationCode, - checkNotNull(oauthConfig.getRedirect(), "Missing redirect URI")), + checkNotNull(oauthConfig.getRedirect() != null ? oauthConfig.getRedirect() : OAuthUtils.createRedirectCallback(request), + "Missing redirect URI")), getNonce(), - createScope(oauthConfig.getScopes()) - ); + createScope(oauthConfig.getScopes())); } catch (Exception e) { return OAuthAuthenticationTokenImpl.ERRORED.apply(new AuthenticationException(e)); } @@ -72,8 +85,7 @@ protected boolean onAccessDenied(ServletRequest request, ServletResponse respons try { ((HttpServletResponse) response).sendError( HttpServletResponse.SC_UNAUTHORIZED, - OAuthUtils.formatUnauthorizedMessage(request, oauthConfig.isDiscloseUnauthorizedReason()) - ); + OAuthUtils.formatUnauthorizedMessage(request, oauthConfig.isDiscloseUnauthorizedReason())); } catch (IOException e1) { LOGGER.debug("Unable to send {} HTTP code to client", HttpServletResponse.SC_UNAUTHORIZED, e1); } @@ -83,14 +95,15 @@ protected boolean onAccessDenied(ServletRequest request, ServletResponse respons @Override protected boolean onLoginSuccess(AuthenticationToken token, Subject subject, ServletRequest request, - ServletResponse response) throws Exception { + ServletResponse response) throws Exception { regenerateSession(subject); issueSuccessRedirect(request, response); return false; } + @Override protected boolean onLoginFailure(AuthenticationToken token, AuthenticationException e, - ServletRequest request, ServletResponse response) { + ServletRequest request, ServletResponse response) { if (LOGGER.isDebugEnabled()) { LOGGER.debug("Authentication exception", e); } @@ -118,9 +131,8 @@ private AuthorizationCode parseAuthorizationCode(HttpServletRequest request) thr throw new IllegalStateException("OAuth state mismatch"); } return ((AuthorizationSuccessResponse) authorizationResponse).getAuthorizationCode(); - } else { - throw buildGenericError((ErrorResponse) authorizationResponse); } + throw buildGenericError((ErrorResponse) authorizationResponse); } private Nonce getNonce() { diff --git a/src/main/java/org/seedstack/oauth/internal/OAuthUtils.java b/src/main/java/org/seedstack/oauth/internal/OAuthUtils.java index 4f300a1..9e9f39c 100644 --- a/src/main/java/org/seedstack/oauth/internal/OAuthUtils.java +++ b/src/main/java/org/seedstack/oauth/internal/OAuthUtils.java @@ -7,23 +7,8 @@ */ package org.seedstack.oauth.internal; -import com.google.common.base.Strings; -import com.nimbusds.oauth2.sdk.*; -import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; -import com.nimbusds.oauth2.sdk.auth.Secret; -import com.nimbusds.oauth2.sdk.http.HTTPResponse; -import com.nimbusds.oauth2.sdk.id.ClientID; -import com.nimbusds.oauth2.sdk.token.Tokens; -import com.nimbusds.openid.connect.sdk.Nonce; -import com.nimbusds.openid.connect.sdk.OIDCTokenResponse; -import com.nimbusds.openid.connect.sdk.token.OIDCTokens; -import org.apache.shiro.authc.AuthenticationException; -import org.seedstack.oauth.OAuthConfig; -import org.seedstack.oauth.OAuthProvider; -import org.seedstack.seed.SeedException; -import org.seedstack.shed.exception.BaseException; +import static com.google.common.base.Preconditions.checkNotNull; -import javax.servlet.ServletRequest; import java.io.IOException; import java.io.UnsupportedEncodingException; import java.net.URI; @@ -34,7 +19,31 @@ import java.util.List; import java.util.Map; -import static com.google.common.base.Preconditions.checkNotNull; +import javax.servlet.ServletRequest; + +import org.apache.shiro.authc.AuthenticationException; +import org.seedstack.oauth.OAuthConfig; +import org.seedstack.oauth.OAuthProvider; +import org.seedstack.seed.SeedException; +import org.seedstack.shed.exception.BaseException; + +import com.google.common.base.Strings; +import com.nimbusds.oauth2.sdk.AccessTokenResponse; +import com.nimbusds.oauth2.sdk.AuthorizationGrant; +import com.nimbusds.oauth2.sdk.ErrorObject; +import com.nimbusds.oauth2.sdk.ErrorResponse; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.Scope; +import com.nimbusds.oauth2.sdk.TokenRequest; +import com.nimbusds.oauth2.sdk.TokenResponse; +import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; +import com.nimbusds.oauth2.sdk.auth.Secret; +import com.nimbusds.oauth2.sdk.http.HTTPResponse; +import com.nimbusds.oauth2.sdk.id.ClientID; +import com.nimbusds.oauth2.sdk.token.Tokens; +import com.nimbusds.openid.connect.sdk.Nonce; +import com.nimbusds.openid.connect.sdk.OIDCTokenResponse; +import com.nimbusds.openid.connect.sdk.token.OIDCTokens; final class OAuthUtils { public static final String OPENID_SCOPE = "openid"; @@ -85,7 +94,7 @@ static Map> extractQueryParameters(URI uri) { } static OAuthAuthenticationTokenImpl requestTokens(OAuthProvider oauthProvider, OAuthConfig oauthConfig, - AuthorizationGrant authorizationGrant, Nonce nonce, Scope scope) { + AuthorizationGrant authorizationGrant, Nonce nonce, Scope scope) { URI endpointURI = oauthProvider.getTokenEndpoint(); Map> parameters = OAuthUtils.extractQueryParameters(endpointURI); endpointURI = OAuthUtils.stripQueryString(endpointURI); @@ -161,4 +170,16 @@ static String formatUnauthorizedMessage(ServletRequest request, boolean includeD } return msg; } + + static URI createRedirectCallback(ServletRequest request) { + String scheme = request.getScheme(); + String host = request.getServerName(); + int port = request.getServerPort(); + try { + String portPart = (port == 80 || port == 443) ? "" : ":" + port; + return new URI(scheme + "://" + host + portPart + "/callback"); + } catch (URISyntaxException e) { + throw new IllegalStateException("Invalid redirect URI", e); + } + } }