From 04eb627a334dd64bb4ce2eaa6dd603b3fa3cfb5c Mon Sep 17 00:00:00 2001 From: Jake Low Date: Fri, 25 Jul 2025 22:17:04 -0700 Subject: [PATCH] Make OAuth flow work for local development This commit modifies the OAuth redirect logic to support the workflow where a developer runs the frontend locally while using the production website's backend API. When the initial request's Origin header is localhost or 127.0.0.1, that hostname is used as the redirect target, rather than the configured value that would normally be used. --- osmchadjango/users/views.py | 47 +++++++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/osmchadjango/users/views.py b/osmchadjango/users/views.py index f388f19d..29a34b6c 100644 --- a/osmchadjango/users/views.py +++ b/osmchadjango/users/views.py @@ -3,6 +3,7 @@ from django.contrib.auth import get_user_model from django.conf import settings +from urllib.parse import urlparse from rest_framework.authtoken.models import Token from rest_framework.generics import ( @@ -77,14 +78,39 @@ class SocialAuthAPIView(GenericAPIView): base_oauth2_url = "{}/oauth2".format(settings.OSM_SERVER_URL) token_url = "{}/token".format(base_oauth2_url) auth_url = "{}/authorize".format(base_oauth2_url) - consumer = OAuth2Session( - client_id=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_KEY, - scope=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_SCOPE, - redirect_uri=settings.OAUTH_REDIRECT_URI, - ) - def get_access_token(self, code): - return self.consumer.fetch_token( + def get_redirect_uri(self, request): + """ + Get the redirect URI for the OAuth flow. + + Normally, after login we ask the OAuth provider to redirect back to + osmcha.org (or whatever domain you're running OSMCha at). But to allow + developers to run the frontend locally while using the production server + as the backend, we have a special case for when the HTTP Origin header + has a hostname of localhost or 127.0.0.1. In those cases we redirect + back to that hostname. + """ + origin = request.META.get('HTTP_ORIGIN') + if origin: + try: + url = urlparse(origin) + if url.hostname in {'127.0.0.1', 'localhost'}: + return f"{origin}/authorized" + except (ValueError, AttributeError): + pass + return settings.OAUTH_REDIRECT_URI + + def get_oauth_consumer(self, request, state=None): + return OAuth2Session( + client_id=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_KEY, + scope=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_SCOPE, + redirect_uri=self.get_redirect_uri(request), + state=state, + ) + + def get_access_token(self, request, code): + consumer = self.get_oauth_consumer(request) + return consumer.fetch_token( token_url=self.token_url, code=code, client_secret=settings.SOCIAL_AUTH_OPENSTREETMAP_OAUTH2_SECRET, @@ -95,7 +121,7 @@ def get_user_token(self, request, access_token, *args, **kwargs): backend = load_backend( strategy=load_strategy(request), name="openstreetmap-oauth2", - redirect_uri=settings.OAUTH_REDIRECT_URI, + redirect_uri=self.get_redirect_uri(request), ) user = backend.do_auth(access_token, *args, **kwargs) token, created = Token.objects.get_or_create(user=user) @@ -103,13 +129,14 @@ def get_user_token(self, request, access_token, *args, **kwargs): def post(self, request, *args, **kwargs): if 'code' not in request.data.keys() or not request.data['code']: - login_url, state = self.consumer.authorization_url(self.auth_url) + consumer = self.get_oauth_consumer(request) + login_url, state = consumer.authorization_url(self.auth_url) return Response({"auth_url": login_url, "state": state}) else: serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) access_token = self.get_access_token( - request.data['code'], + request, request.data['code'], ).get('access_token') return Response(self.get_user_token(request, access_token))