diff --git a/backend/.env.example b/backend/.env.example index a1fa1477..59ae1283 100644 --- a/backend/.env.example +++ b/backend/.env.example @@ -65,11 +65,34 @@ GITHUB_CLIENT_ID=your_github_client_id GITHUB_CLIENT_SECRET=your_github_client_secret GITHUB_REPO_TO_STAR=genlayerlabs/genlayer-project-boilerplate -# GitHub OAuth Token Encryption Key -# This key encrypts GitHub access tokens before storing them in the database. +# Social Connections Encryption Key +# This key encrypts OAuth access tokens before storing them in the database. # IMPORTANT: This key must remain constant - changing it will make all stored tokens unreadable. # Generate a new key with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())" # The key is a base64-encoded 32-byte value (e.g., "07NNDpobC9R40oJnBxwvb6ifUENNGJTiaz3E562SgVw=") +# Note: Falls back to GITHUB_ENCRYPTION_KEY if not set (for backward compatibility) +SOCIAL_ENCRYPTION_KEY=your_encryption_key_here + +# Twitter/X OAuth 2.0 Configuration +# Create a Twitter OAuth App at https://developer.twitter.com/en/portal/dashboard +# 1. Create a project/app and enable OAuth 2.0 +# 2. Set the callback URL to: {BACKEND_URL}/api/auth/twitter/callback/ +# 3. Enable scopes: tweet.read, users.read, offline.access +TWITTER_CLIENT_ID=your_twitter_client_id +TWITTER_CLIENT_SECRET=your_twitter_client_secret + +# Discord OAuth 2.0 Configuration +# Create a Discord OAuth App at https://discord.com/developers/applications +# 1. Create a new application +# 2. Go to OAuth2 > General and add redirect URL: {BACKEND_URL}/api/auth/discord/callback/ +# 3. Required scopes: identify, guilds +DISCORD_CLIENT_ID=your_discord_client_id +DISCORD_CLIENT_SECRET=your_discord_client_secret +# Discord Guild (Server) ID for membership checks +# To get: Enable Developer Mode in Discord, right-click server > Copy Server ID +DISCORD_GUILD_ID=your_discord_server_id + +# Legacy: GitHub OAuth Token Encryption Key (deprecated, use SOCIAL_ENCRYPTION_KEY) GITHUB_ENCRYPTION_KEY=your_encryption_key_here # Google reCAPTCHA Configuration diff --git a/backend/social_connections/__init__.py b/backend/social_connections/__init__.py new file mode 100644 index 00000000..6b3f9732 --- /dev/null +++ b/backend/social_connections/__init__.py @@ -0,0 +1 @@ +default_app_config = 'social_connections.apps.SocialConnectionsConfig' diff --git a/backend/social_connections/admin.py b/backend/social_connections/admin.py new file mode 100644 index 00000000..9308a511 --- /dev/null +++ b/backend/social_connections/admin.py @@ -0,0 +1,4 @@ +# Admin registrations are handled in each sub-app: +# - social_connections.github.admin +# - social_connections.twitter.admin +# - social_connections.discord.admin diff --git a/backend/social_connections/apps.py b/backend/social_connections/apps.py new file mode 100644 index 00000000..68ea2453 --- /dev/null +++ b/backend/social_connections/apps.py @@ -0,0 +1,7 @@ +from django.apps import AppConfig + + +class SocialConnectionsConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'social_connections' + verbose_name = 'Social Connections' diff --git a/backend/social_connections/discord/__init__.py b/backend/social_connections/discord/__init__.py new file mode 100644 index 00000000..91229806 --- /dev/null +++ b/backend/social_connections/discord/__init__.py @@ -0,0 +1 @@ +default_app_config = 'social_connections.discord.apps.DiscordConfig' diff --git a/backend/social_connections/discord/admin.py b/backend/social_connections/discord/admin.py new file mode 100644 index 00000000..cd4af957 --- /dev/null +++ b/backend/social_connections/discord/admin.py @@ -0,0 +1,25 @@ +from django.contrib import admin +from .models import DiscordConnection + + +@admin.register(DiscordConnection) +class DiscordConnectionAdmin(admin.ModelAdmin): + list_display = ['user', 'username', 'discriminator', 'platform_user_id', 'linked_at', 'created_at'] + list_filter = ['linked_at'] + search_fields = ['user__email', 'user__name', 'username', 'platform_user_id'] + readonly_fields = ['platform_user_id', 'access_token', 'refresh_token', 'avatar_hash', 'linked_at', 'created_at', 'updated_at'] + raw_id_fields = ['user'] + + fieldsets = ( + (None, { + 'fields': ('user', 'username', 'discriminator', 'platform_user_id') + }), + ('Avatar', { + 'fields': ('avatar_hash',), + 'classes': ('collapse',) + }), + ('Timestamps', { + 'fields': ('linked_at', 'created_at', 'updated_at'), + 'classes': ('collapse',) + }), + ) diff --git a/backend/social_connections/discord/apps.py b/backend/social_connections/discord/apps.py new file mode 100644 index 00000000..67ad9af4 --- /dev/null +++ b/backend/social_connections/discord/apps.py @@ -0,0 +1,8 @@ +from django.apps import AppConfig + + +class DiscordConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'social_connections.discord' + label = 'social_discord' + verbose_name = 'Discord Connection' diff --git a/backend/social_connections/discord/migrations/0001_initial.py b/backend/social_connections/discord/migrations/0001_initial.py new file mode 100644 index 00000000..e8e8e50b --- /dev/null +++ b/backend/social_connections/discord/migrations/0001_initial.py @@ -0,0 +1,37 @@ +# Generated by Django 5.2.7 on 2026-01-30 15:31 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name='DiscordConnection', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('username', models.CharField(blank=True, help_text='Platform username', max_length=100)), + ('platform_user_id', models.CharField(blank=True, help_text='Platform user ID for unique identification', max_length=50)), + ('access_token', models.TextField(blank=True, help_text='Encrypted access token')), + ('refresh_token', models.TextField(blank=True, help_text='Encrypted refresh token (if applicable)')), + ('linked_at', models.DateTimeField(blank=True, help_text='When the account was linked', null=True)), + ('discriminator', models.CharField(blank=True, help_text='Discord discriminator (legacy - newer accounts may not have this)', max_length=10)), + ('avatar_hash', models.CharField(blank=True, help_text='Discord avatar hash for constructing avatar URL', max_length=100)), + ('user', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, related_name='discord_connection', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'verbose_name': 'Discord Connection', + 'verbose_name_plural': 'Discord Connections', + }, + ), + ] diff --git a/backend/social_connections/discord/migrations/__init__.py b/backend/social_connections/discord/migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/social_connections/discord/models.py b/backend/social_connections/discord/models.py new file mode 100644 index 00000000..9e75ca98 --- /dev/null +++ b/backend/social_connections/discord/models.py @@ -0,0 +1,40 @@ +""" +Discord OAuth connection model. +""" +from django.db import models +from social_connections.models import SocialConnection + + +class DiscordConnection(SocialConnection): + """ + Discord OAuth connection. + + Stores the user's Discord account information obtained through OAuth 2.0. + Includes Discord-specific fields like discriminator and avatar hash. + """ + user = models.OneToOneField( + 'users.User', + on_delete=models.CASCADE, + related_name='discord_connection' + ) + discriminator = models.CharField( + max_length=10, + blank=True, + help_text="Discord discriminator (legacy - newer accounts may not have this)" + ) + avatar_hash = models.CharField( + max_length=100, + blank=True, + help_text="Discord avatar hash for constructing avatar URL" + ) + + class Meta: + verbose_name = 'Discord Connection' + verbose_name_plural = 'Discord Connections' + + @property + def avatar_url(self): + """Construct the Discord avatar URL from the avatar hash.""" + if not self.avatar_hash or not self.platform_user_id: + return None + return f"https://cdn.discordapp.com/avatars/{self.platform_user_id}/{self.avatar_hash}.png" diff --git a/backend/social_connections/discord/oauth.py b/backend/social_connections/discord/oauth.py new file mode 100644 index 00000000..3cd577c8 --- /dev/null +++ b/backend/social_connections/discord/oauth.py @@ -0,0 +1,362 @@ +""" +Discord OAuth 2.0 authentication handling. +""" +import secrets +from urllib.parse import urlencode + +import requests +from django.conf import settings +from django.shortcuts import redirect, render +from django.utils import timezone +from django.views.decorators.csrf import csrf_exempt +from django.core import signing +from rest_framework.decorators import api_view, permission_classes +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework import status + +from cryptography.fernet import InvalidToken +from users.models import User +from social_connections.encryption import encrypt_token, decrypt_token +from social_connections.discord.models import DiscordConnection +from tally.middleware.logging_utils import get_app_logger +from tally.middleware.tracing import trace_external + +logger = get_app_logger('discord_oauth') + +# Cache to track used OAuth codes (prevents duplicate exchanges) +_used_oauth_codes = {} + + +@api_view(['GET']) +@permission_classes([IsAuthenticated]) +def discord_oauth_initiate(request): + """Initiate Discord OAuth 2.0 flow.""" + user_id = request.user.id + logger.debug("Discord OAuth initiated") + + # Generate state token with user ID embedded + state_data = { + 'user_id': user_id, + 'nonce': secrets.token_urlsafe(32) + } + + # Sign the state data to make it tamper-proof + state = signing.dumps(state_data, salt='discord_oauth_state') + logger.debug("Generated signed OAuth state") + + # Build Discord OAuth URL + discord_oauth_url = "https://discord.com/api/oauth2/authorize" + params = { + 'response_type': 'code', + 'client_id': settings.DISCORD_CLIENT_ID, + 'redirect_uri': settings.DISCORD_REDIRECT_URI, + 'scope': 'identify guilds', + 'state': state + } + + auth_url = f"{discord_oauth_url}?{urlencode(params)}" + return redirect(auth_url) + + +@csrf_exempt +def discord_oauth_callback(request): + """Handle Discord OAuth callback.""" + code = request.GET.get('code') + state = request.GET.get('state') + error = request.GET.get('error') + error_description = request.GET.get('error_description') + + logger.debug("Discord OAuth callback received") + + template_context = { + 'platform': 'discord', + 'frontend_origin': settings.FRONTEND_URL + } + + # Handle errors from Discord + if error: + logger.error(f"Discord OAuth error: {error} - {error_description}") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'authorization_failed', + 'message': 'Authorization failed. This window will close automatically.', + }) + + # Validate state token + if not state: + logger.error("No state token received from Discord") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'invalid_state', + 'message': 'Invalid state token. This window will close automatically.', + }) + + try: + state_data = signing.loads(state, salt='discord_oauth_state', max_age=600) + user_id = state_data.get('user_id') + logger.debug("Successfully validated signed OAuth state") + except signing.SignatureExpired: + logger.error("OAuth state token has expired") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'state_expired', + 'message': 'Session expired. This window will close automatically.', + }) + except signing.BadSignature: + logger.error("OAuth state token has invalid signature") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'invalid_state', + 'message': 'Invalid state token. This window will close automatically.', + }) + + # Clean up old codes (older than 10 minutes) + cutoff = timezone.now() - timezone.timedelta(minutes=10) + expired_codes = [k for k, v in _used_oauth_codes.items() if v <= cutoff] + for expired_code in expired_codes: + del _used_oauth_codes[expired_code] + + # Check if this code has already been used + if code in _used_oauth_codes: + logger.warning("OAuth code already used, rejecting duplicate request") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'code_already_used', + 'message': 'This code has already been used. This window will close automatically.', + }) + + # Mark code as used immediately + _used_oauth_codes[code] = timezone.now() + + # Exchange code for access token + logger.debug("Attempting to exchange OAuth code for access token") + token_url = "https://discord.com/api/oauth2/token" + + token_params = { + 'client_id': settings.DISCORD_CLIENT_ID, + 'client_secret': settings.DISCORD_CLIENT_SECRET, + 'grant_type': 'authorization_code', + 'code': code, + 'redirect_uri': settings.DISCORD_REDIRECT_URI + } + + headers = { + 'Content-Type': 'application/x-www-form-urlencoded' + } + + try: + with trace_external('discord', 'token_exchange'): + token_response = requests.post(token_url, data=token_params, headers=headers) + token_data = token_response.json() + + if token_response.status_code != 200 or 'error' in token_data: + error_msg = token_data.get('error_description', token_data.get('error', 'Unknown error')) + logger.error(f"Discord token exchange error: {error_msg}") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'token_exchange_failed', + 'message': 'Failed to exchange token. This window will close automatically.', + }) + + access_token = token_data.get('access_token') + refresh_token = token_data.get('refresh_token', '') + + if not access_token: + logger.error("No access token received from Discord") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'no_access_token', + 'message': 'No access token received. This window will close automatically.', + }) + + # Fetch Discord user data + user_url = "https://discord.com/api/v10/users/@me" + user_headers = { + 'Authorization': f'Bearer {access_token}', + } + + with trace_external('discord', 'get_user'): + user_response = requests.get(user_url, headers=user_headers) + user_response.raise_for_status() + discord_user = user_response.json() + + # Get user from state token + if not user_id: + logger.error("No user_id in state token") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'invalid_state', + 'message': 'Invalid authentication state. This window will close automatically.', + }) + + try: + user = User.objects.get(id=user_id) + logger.debug("Found user from state token") + except User.DoesNotExist: + logger.error("User from state not found") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'user_not_found', + 'message': 'User not found. This window will close automatically.', + }) + + # Check if this Discord account is already linked to another user + existing_connection = DiscordConnection.objects.filter( + platform_user_id=discord_user.get('id') + ).exclude(user=user).first() + + if existing_connection: + logger.warning("Discord account already linked to another user") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'already_linked', + 'message': 'Discord account already linked. This window will close automatically.', + }) + + # Create or update DiscordConnection + connection, created = DiscordConnection.objects.update_or_create( + user=user, + defaults={ + 'username': discord_user.get('username', ''), + 'platform_user_id': discord_user.get('id', ''), + 'discriminator': discord_user.get('discriminator', '0'), + 'avatar_hash': discord_user.get('avatar', ''), + 'access_token': encrypt_token(access_token), + 'refresh_token': encrypt_token(refresh_token) if refresh_token else '', + 'linked_at': timezone.now() + } + ) + + logger.debug("Discord account linked successfully") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': True, + 'error': '', + 'message': 'Discord account linked successfully! This window will close automatically.', + }) + + except requests.RequestException as e: + logger.error(f"Discord API request failed: {e}") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'api_request_failed', + 'message': 'API request failed. This window will close automatically.', + }) + + +@api_view(['POST']) +@permission_classes([IsAuthenticated]) +def disconnect_discord(request): + """Disconnect Discord account from user profile.""" + try: + user = request.user + try: + connection = user.discord_connection + connection.delete() + logger.debug("Discord connection deleted") + except DiscordConnection.DoesNotExist: + pass + + return Response({ + 'message': 'Discord account disconnected successfully' + }, status=status.HTTP_200_OK) + except Exception as e: + logger.error(f"Failed to disconnect Discord: {e}") + return Response({ + 'error': 'Failed to disconnect Discord account' + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +@api_view(['GET']) +@permission_classes([IsAuthenticated]) +def check_guild_membership(request): + """Check if user is a member of the configured Discord guild/server.""" + user = request.user + + # Check if guild ID is configured + guild_id = getattr(settings, 'DISCORD_GUILD_ID', '') + if not guild_id: + return Response({ + 'is_member': False, + 'error': 'Discord guild not configured' + }, status=status.HTTP_200_OK) + + # Try to get Discord connection + try: + connection = user.discord_connection + except DiscordConnection.DoesNotExist: + return Response({ + 'is_member': False, + 'guild_id': guild_id, + 'error': 'Discord account not linked' + }, status=status.HTTP_200_OK) + + if not connection.access_token: + return Response({ + 'is_member': False, + 'guild_id': guild_id, + 'error': 'Discord access token missing' + }, status=status.HTTP_200_OK) + + try: + token = decrypt_token(connection.access_token) + except InvalidToken: + logger.warning("Failed to decrypt Discord token") + connection.access_token = '' + connection.save() + return Response({ + 'is_member': False, + 'guild_id': guild_id, + 'error': 'Discord token invalid' + }, status=status.HTTP_200_OK) + + try: + # Fetch user's guilds + guilds_url = "https://discord.com/api/v10/users/@me/guilds" + headers = { + 'Authorization': f'Bearer {token}', + } + + with trace_external('discord', 'get_guilds'): + response = requests.get(guilds_url, headers=headers) + + if response.status_code == 401: + # Token expired or revoked + logger.warning("Discord token expired or revoked") + return Response({ + 'is_member': False, + 'guild_id': guild_id, + 'error': 'Discord authorization expired. Please reconnect your account.' + }, status=status.HTTP_200_OK) + + response.raise_for_status() + guilds = response.json() + + # Check if user is in the configured guild + is_member = any(g.get('id') == guild_id for g in guilds) + + return Response({ + 'is_member': is_member, + 'guild_id': guild_id, + 'discord_username': connection.username + }, status=status.HTTP_200_OK) + + except requests.RequestException as e: + logger.error(f"Failed to check guild membership: {e}") + return Response({ + 'is_member': False, + 'guild_id': guild_id, + 'error': 'Failed to check guild membership' + }, status=status.HTTP_200_OK) diff --git a/backend/social_connections/discord/urls.py b/backend/social_connections/discord/urls.py new file mode 100644 index 00000000..4980c029 --- /dev/null +++ b/backend/social_connections/discord/urls.py @@ -0,0 +1,14 @@ +from django.urls import path +from . import oauth + +urlpatterns = [ + # OAuth endpoints (under /api/auth/discord/) + path('', oauth.discord_oauth_initiate, name='discord_oauth'), + path('callback/', oauth.discord_oauth_callback, name='discord_callback'), +] + +# API endpoints (under /api/v1/social/discord/) +api_urlpatterns = [ + path('disconnect/', oauth.disconnect_discord, name='discord_disconnect'), + path('check-guild/', oauth.check_guild_membership, name='discord_check_guild'), +] diff --git a/backend/social_connections/encryption.py b/backend/social_connections/encryption.py new file mode 100644 index 00000000..6a364323 --- /dev/null +++ b/backend/social_connections/encryption.py @@ -0,0 +1,41 @@ +""" +Shared encryption utilities for social connection tokens. +""" +from django.conf import settings +from cryptography.fernet import Fernet + + +def get_fernet(): + """ + Get Fernet encryption instance using configured key. + + Uses SOCIAL_ENCRYPTION_KEY setting, falling back to GITHUB_ENCRYPTION_KEY + for backward compatibility. + """ + key = getattr(settings, 'SOCIAL_ENCRYPTION_KEY', None) or \ + getattr(settings, 'GITHUB_ENCRYPTION_KEY', None) + + if not key: + raise RuntimeError( + "SOCIAL_ENCRYPTION_KEY is not set. " + "Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\"" + ) + + key = key.encode() if isinstance(key, str) else key + return Fernet(key) + + +def encrypt_token(token): + """Encrypt a token for secure storage.""" + if not token: + return "" + fernet = get_fernet() + return fernet.encrypt(token.encode()).decode() + + +def decrypt_token(encrypted_token): + """Decrypt a stored token.""" + if not encrypted_token: + return "" + fernet = get_fernet() + return fernet.decrypt(encrypted_token.encode()).decode() diff --git a/backend/social_connections/github/__init__.py b/backend/social_connections/github/__init__.py new file mode 100644 index 00000000..e8470527 --- /dev/null +++ b/backend/social_connections/github/__init__.py @@ -0,0 +1 @@ +default_app_config = 'social_connections.github.apps.GitHubConfig' diff --git a/backend/social_connections/github/admin.py b/backend/social_connections/github/admin.py new file mode 100644 index 00000000..1fdcc0fc --- /dev/null +++ b/backend/social_connections/github/admin.py @@ -0,0 +1,21 @@ +from django.contrib import admin +from .models import GitHubConnection + + +@admin.register(GitHubConnection) +class GitHubConnectionAdmin(admin.ModelAdmin): + list_display = ['user', 'username', 'platform_user_id', 'linked_at', 'created_at'] + list_filter = ['linked_at'] + search_fields = ['user__email', 'user__name', 'username', 'platform_user_id'] + readonly_fields = ['platform_user_id', 'access_token', 'linked_at', 'created_at', 'updated_at'] + raw_id_fields = ['user'] + + fieldsets = ( + (None, { + 'fields': ('user', 'username', 'platform_user_id') + }), + ('Timestamps', { + 'fields': ('linked_at', 'created_at', 'updated_at'), + 'classes': ('collapse',) + }), + ) diff --git a/backend/social_connections/github/apps.py b/backend/social_connections/github/apps.py new file mode 100644 index 00000000..495fdea2 --- /dev/null +++ b/backend/social_connections/github/apps.py @@ -0,0 +1,8 @@ +from django.apps import AppConfig + + +class GitHubConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'social_connections.github' + label = 'social_github' + verbose_name = 'GitHub Connection' diff --git a/backend/social_connections/github/migrations/0001_initial.py b/backend/social_connections/github/migrations/0001_initial.py new file mode 100644 index 00000000..3b5c0577 --- /dev/null +++ b/backend/social_connections/github/migrations/0001_initial.py @@ -0,0 +1,35 @@ +# Generated by Django 5.2.7 on 2026-01-30 15:27 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name='GitHubConnection', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('username', models.CharField(blank=True, help_text='Platform username', max_length=100)), + ('platform_user_id', models.CharField(blank=True, help_text='Platform user ID for unique identification', max_length=50)), + ('access_token', models.TextField(blank=True, help_text='Encrypted access token')), + ('refresh_token', models.TextField(blank=True, help_text='Encrypted refresh token (if applicable)')), + ('linked_at', models.DateTimeField(blank=True, help_text='When the account was linked', null=True)), + ('user', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, related_name='github_connection', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'verbose_name': 'GitHub Connection', + 'verbose_name_plural': 'GitHub Connections', + }, + ), + ] diff --git a/backend/social_connections/github/migrations/0002_migrate_github_data.py b/backend/social_connections/github/migrations/0002_migrate_github_data.py new file mode 100644 index 00000000..a224e5ad --- /dev/null +++ b/backend/social_connections/github/migrations/0002_migrate_github_data.py @@ -0,0 +1,81 @@ +# Generated by Django 5.2.7 on 2026-01-30 + +from django.db import migrations + + +def migrate_github_data(apps, schema_editor): + """ + Migrate GitHub OAuth data from User model to GitHubConnection model. + + This migration handles the case where the User model may or may not have + the old GitHub fields (github_user_id, github_username, etc.) depending + on whether this is a fresh database or an existing one. + """ + User = apps.get_model('users', 'User') + GitHubConnection = apps.get_model('social_github', 'GitHubConnection') + + # Check if the User model has the old GitHub fields + # In a fresh database (like during tests), these fields won't exist + user_fields = [f.name for f in User._meta.get_fields()] + + if 'github_user_id' not in user_fields: + # Fresh database - no data to migrate + print("No GitHub fields on User model - skipping migration") + return + + # Find users with GitHub data + users_with_github = User.objects.exclude(github_user_id='').exclude(github_user_id__isnull=True) + + count = 0 + for user in users_with_github: + # Create GitHubConnection if it doesn't already exist + _, created = GitHubConnection.objects.get_or_create( + user=user, + defaults={ + 'username': user.github_username or '', + 'platform_user_id': user.github_user_id or '', + 'access_token': user.github_access_token or '', + 'linked_at': user.github_linked_at + } + ) + if created: + count += 1 + + print(f"Migrated {count} GitHub connections") + + +def reverse_migrate_github_data(apps, schema_editor): + """ + Reverse migration: copy data back to User model. + Note: This is mainly for safety; the User fields will be removed later. + """ + User = apps.get_model('users', 'User') + GitHubConnection = apps.get_model('social_github', 'GitHubConnection') + + # Check if the User model has the old GitHub fields + user_fields = [f.name for f in User._meta.get_fields()] + + if 'github_user_id' not in user_fields: + # Can't reverse to fields that don't exist + print("No GitHub fields on User model - skipping reverse migration") + return + + for connection in GitHubConnection.objects.all(): + user = connection.user + user.github_username = connection.username + user.github_user_id = connection.platform_user_id + user.github_access_token = connection.access_token + user.github_linked_at = connection.linked_at + user.save() + + +class Migration(migrations.Migration): + + dependencies = [ + ('social_github', '0001_initial'), + ('users', '0001_initial'), # Ensure users migration has run + ] + + operations = [ + migrations.RunPython(migrate_github_data, reverse_migrate_github_data), + ] diff --git a/backend/social_connections/github/migrations/__init__.py b/backend/social_connections/github/migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/social_connections/github/models.py b/backend/social_connections/github/models.py new file mode 100644 index 00000000..c7bbaa5c --- /dev/null +++ b/backend/social_connections/github/models.py @@ -0,0 +1,22 @@ +""" +GitHub OAuth connection model. +""" +from django.db import models +from social_connections.models import SocialConnection + + +class GitHubConnection(SocialConnection): + """ + GitHub OAuth connection. + + Stores the user's GitHub account information obtained through OAuth. + """ + user = models.OneToOneField( + 'users.User', + on_delete=models.CASCADE, + related_name='github_connection' + ) + + class Meta: + verbose_name = 'GitHub Connection' + verbose_name_plural = 'GitHub Connections' diff --git a/backend/social_connections/github/oauth.py b/backend/social_connections/github/oauth.py new file mode 100644 index 00000000..2ecc70b7 --- /dev/null +++ b/backend/social_connections/github/oauth.py @@ -0,0 +1,362 @@ +""" +GitHub OAuth authentication handling. +""" +import secrets +from urllib.parse import urlencode + +import requests +from django.conf import settings +from django.shortcuts import redirect, render +from django.utils import timezone +from django.views.decorators.csrf import csrf_exempt +from django.core import signing +from rest_framework.decorators import api_view, permission_classes +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework import status + +from cryptography.fernet import InvalidToken +from users.models import User +from social_connections.encryption import encrypt_token, decrypt_token +from social_connections.github.models import GitHubConnection +from tally.middleware.logging_utils import get_app_logger +from tally.middleware.tracing import trace_external + +logger = get_app_logger('github_oauth') + +# Cache to track used OAuth codes (prevents duplicate exchanges) +# Format: {code: timestamp} +_used_oauth_codes = {} + + +@api_view(['GET']) +@permission_classes([IsAuthenticated]) +def github_oauth_initiate(request): + """Initiate GitHub OAuth flow.""" + user_id = request.user.id + logger.debug("GitHub OAuth initiated") + + # Generate state token with user ID embedded + state_data = { + 'user_id': user_id, + 'nonce': secrets.token_urlsafe(32) + } + + # Sign the state data to make it tamper-proof + state = signing.dumps(state_data, salt='github_oauth_state') + logger.debug("Generated signed OAuth state") + + # Build GitHub OAuth URL with minimal read-only permissions + github_oauth_url = "https://github.com/login/oauth/authorize" + params = { + 'client_id': settings.GITHUB_CLIENT_ID, + 'redirect_uri': settings.GITHUB_REDIRECT_URI, + 'scope': '', # Empty scope = read-only public info + 'state': state, + 'allow_signup': 'false' + } + + auth_url = f"{github_oauth_url}?{urlencode(params)}" + return redirect(auth_url) + + +@csrf_exempt +def github_oauth_callback(request): + """Handle GitHub OAuth callback.""" + code = request.GET.get('code') + state = request.GET.get('state') + error = request.GET.get('error') + error_description = request.GET.get('error_description') + + logger.debug("OAuth callback received") + + template_context = { + 'platform': 'github', + 'frontend_origin': settings.FRONTEND_URL + } + + # Handle errors from GitHub + if error: + logger.error(f"GitHub OAuth error: {error} - {error_description}") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'authorization_failed', + 'message': 'Authorization failed. This window will close automatically.', + }) + + # Validate state token + if not state: + logger.error("No state token received from GitHub") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'invalid_state', + 'message': 'Invalid state token. This window will close automatically.', + }) + + try: + state_data = signing.loads(state, salt='github_oauth_state', max_age=600) + user_id = state_data.get('user_id') + logger.debug("Successfully validated signed OAuth state") + except signing.SignatureExpired: + logger.error("OAuth state token has expired") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'state_expired', + 'message': 'Session expired. This window will close automatically.', + }) + except signing.BadSignature: + logger.error("OAuth state token has invalid signature") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'invalid_state', + 'message': 'Invalid state token. This window will close automatically.', + }) + + # Clean up old codes (older than 10 minutes) + cutoff = timezone.now() - timezone.timedelta(minutes=10) + expired_codes = [k for k, v in _used_oauth_codes.items() if v <= cutoff] + for expired_code in expired_codes: + del _used_oauth_codes[expired_code] + + if expired_codes: + logger.debug(f"Cleaned up {len(expired_codes)} expired OAuth codes") + + # Check if this code has already been used + if code in _used_oauth_codes: + logger.warning("OAuth code already used, rejecting duplicate request") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'code_already_used', + 'message': 'This code has already been used. This window will close automatically.', + }) + + # Mark code as used immediately + _used_oauth_codes[code] = timezone.now() + logger.debug(f"Marked OAuth code as used ({len(_used_oauth_codes)} codes in cache)") + + # Exchange code for access token + logger.debug("Attempting to exchange OAuth code for access token") + token_url = "https://github.com/login/oauth/access_token" + token_params = { + 'client_id': settings.GITHUB_CLIENT_ID, + 'client_secret': settings.GITHUB_CLIENT_SECRET, + 'code': code, + 'redirect_uri': settings.GITHUB_REDIRECT_URI + } + + headers = {'Accept': 'application/json'} + + try: + with trace_external('github', 'token_exchange'): + token_response = requests.post(token_url, data=token_params, headers=headers) + token_response.raise_for_status() + token_data = token_response.json() + + if 'error' in token_data: + logger.error("GitHub token exchange error") + if token_data.get('error') == 'bad_verification_code': + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'code_already_used', + 'message': 'This code has already been used. This window will close automatically.', + }) + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'token_exchange_failed', + 'message': 'Failed to exchange token. This window will close automatically.', + }) + + access_token = token_data.get('access_token') + if not access_token: + logger.error("No access token received from GitHub") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'no_access_token', + 'message': 'No access token received. This window will close automatically.', + }) + + # Fetch GitHub user data + user_url = "https://api.github.com/user" + user_headers = { + 'Authorization': f'token {access_token}', + 'Accept': 'application/json' + } + + with trace_external('github', 'get_user'): + user_response = requests.get(user_url, headers=user_headers) + user_response.raise_for_status() + github_user = user_response.json() + + # Get user from state token + if not user_id: + logger.error("No user_id in state token") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'invalid_state', + 'message': 'Invalid authentication state. This window will close automatically.', + }) + + try: + user = User.objects.get(id=user_id) + logger.debug("Found user from state token") + except User.DoesNotExist: + logger.error("User from state not found") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'user_not_found', + 'message': 'User not found. This window will close automatically.', + }) + + # Check if this GitHub account is already linked to another user + existing_connection = GitHubConnection.objects.filter( + platform_user_id=str(github_user['id']) + ).exclude(user=user).first() + + if existing_connection: + logger.warning("GitHub account already linked to another user") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'already_linked', + 'message': 'GitHub account already linked. This window will close automatically.', + }) + + # Create or update GitHubConnection + connection, created = GitHubConnection.objects.update_or_create( + user=user, + defaults={ + 'username': github_user['login'], + 'platform_user_id': str(github_user['id']), + 'access_token': encrypt_token(access_token), + 'linked_at': timezone.now() + } + ) + + logger.debug("GitHub account linked successfully") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': True, + 'error': '', + 'message': 'GitHub account linked successfully! This window will close automatically.', + }) + + except requests.RequestException as e: + logger.error(f"GitHub API request failed: {e}") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'api_request_failed', + 'message': 'API request failed. This window will close automatically.', + }) + + +@api_view(['POST']) +@permission_classes([IsAuthenticated]) +def disconnect_github(request): + """Disconnect GitHub account from user profile.""" + try: + user = request.user + try: + connection = user.github_connection + connection.delete() + logger.debug("GitHub connection deleted") + except GitHubConnection.DoesNotExist: + pass # No connection to delete + + return Response({ + 'message': 'GitHub account disconnected successfully' + }, status=status.HTTP_200_OK) + except Exception as e: + logger.error(f"Failed to disconnect GitHub: {e}") + return Response({ + 'error': 'Failed to disconnect GitHub account' + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +@api_view(['GET']) +@permission_classes([IsAuthenticated]) +def check_repo_star(request): + """Check if user has starred the required repository.""" + user = request.user + + # Try to get GitHub connection + try: + connection = user.github_connection + github_username = connection.username + github_access_token = connection.access_token + except GitHubConnection.DoesNotExist: + return Response({ + 'has_starred': False, + 'repo': settings.GITHUB_REPO_TO_STAR, + 'error': 'GitHub account not linked' + }, status=status.HTTP_200_OK) + + if not github_username: + return Response({ + 'has_starred': False, + 'repo': settings.GITHUB_REPO_TO_STAR, + 'error': 'GitHub account not linked' + }, status=status.HTTP_200_OK) + + try: + headers = {'Accept': 'application/json'} + if github_access_token: + try: + token = decrypt_token(github_access_token) + headers['Authorization'] = f'token {token}' + except InvalidToken: + logger.warning("Failed to decrypt GitHub token, clearing token") + connection.access_token = '' + connection.save() + + # Parse the repo owner and name + repo_parts = settings.GITHUB_REPO_TO_STAR.split('/') + if len(repo_parts) != 2: + logger.error(f"Invalid GITHUB_REPO_TO_STAR format: {settings.GITHUB_REPO_TO_STAR}") + return Response({ + 'error': 'Invalid repository configuration' + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + owner, repo = repo_parts + + # Check if user has starred the specific repo + if github_access_token and 'Authorization' in headers: + url = f'https://api.github.com/user/starred/{owner}/{repo}' + with trace_external('github', 'check_star'): + response = requests.get(url, headers=headers) + has_starred = response.status_code == 204 + else: + url = f'https://api.github.com/users/{github_username}/starred' + with trace_external('github', 'check_star_public'): + response = requests.get(url, headers={'Accept': 'application/json'}) + + if response.status_code == 200: + starred_repos = response.json() + repo_full_name = f'{owner}/{repo}' + has_starred = any(r.get('full_name') == repo_full_name for r in starred_repos) + else: + has_starred = False + + return Response({ + 'has_starred': has_starred, + 'repo': settings.GITHUB_REPO_TO_STAR, + 'github_username': github_username + }, status=status.HTTP_200_OK) + + except requests.RequestException as e: + logger.error(f"Failed to check star status: {e}") + return Response({ + 'has_starred': False, + 'repo': settings.GITHUB_REPO_TO_STAR, + 'error': 'Failed to check star status' + }, status=status.HTTP_200_OK) diff --git a/backend/social_connections/github/urls.py b/backend/social_connections/github/urls.py new file mode 100644 index 00000000..048fbdcb --- /dev/null +++ b/backend/social_connections/github/urls.py @@ -0,0 +1,14 @@ +from django.urls import path +from . import oauth + +urlpatterns = [ + # OAuth endpoints (under /api/auth/github/) + path('', oauth.github_oauth_initiate, name='github_oauth'), + path('callback/', oauth.github_oauth_callback, name='github_callback'), +] + +# API endpoints (under /api/v1/social/github/) +api_urlpatterns = [ + path('disconnect/', oauth.disconnect_github, name='github_disconnect'), + path('check-star/', oauth.check_repo_star, name='github_check_star'), +] diff --git a/backend/social_connections/models.py b/backend/social_connections/models.py new file mode 100644 index 00000000..38b15ea9 --- /dev/null +++ b/backend/social_connections/models.py @@ -0,0 +1,49 @@ +""" +Abstract base model for social connections. +""" +from django.db import models +from django.conf import settings +from utils.models import BaseModel + + +class SocialConnection(BaseModel): + """ + Abstract base model for social connections. + + Provides common fields for storing OAuth connection data across + different social platforms (GitHub, Twitter, Discord, etc.). + """ + user = models.OneToOneField( + settings.AUTH_USER_MODEL, + on_delete=models.CASCADE, + related_name='%(class)s' + ) + username = models.CharField( + max_length=100, + blank=True, + help_text="Platform username" + ) + platform_user_id = models.CharField( + max_length=50, + blank=True, + help_text="Platform user ID for unique identification" + ) + access_token = models.TextField( + blank=True, + help_text="Encrypted access token" + ) + refresh_token = models.TextField( + blank=True, + help_text="Encrypted refresh token (if applicable)" + ) + linked_at = models.DateTimeField( + null=True, + blank=True, + help_text="When the account was linked" + ) + + class Meta: + abstract = True + + def __str__(self): + return f"{self.__class__.__name__}: {self.username or self.platform_user_id}" diff --git a/backend/social_connections/serializers.py b/backend/social_connections/serializers.py new file mode 100644 index 00000000..68aa33f7 --- /dev/null +++ b/backend/social_connections/serializers.py @@ -0,0 +1,38 @@ +""" +Serializers for social connection models. +""" +from rest_framework import serializers +from social_connections.github.models import GitHubConnection +from social_connections.twitter.models import TwitterConnection +from social_connections.discord.models import DiscordConnection + + +class GitHubConnectionSerializer(serializers.ModelSerializer): + """Serializer for GitHub connection data.""" + + class Meta: + model = GitHubConnection + fields = ['username', 'linked_at'] + read_only_fields = ['username', 'linked_at'] + + +class TwitterConnectionSerializer(serializers.ModelSerializer): + """Serializer for Twitter connection data.""" + + class Meta: + model = TwitterConnection + fields = ['username', 'linked_at'] + read_only_fields = ['username', 'linked_at'] + + +class DiscordConnectionSerializer(serializers.ModelSerializer): + """Serializer for Discord connection data.""" + avatar_url = serializers.SerializerMethodField() + + class Meta: + model = DiscordConnection + fields = ['username', 'discriminator', 'avatar_url', 'linked_at'] + read_only_fields = ['username', 'discriminator', 'avatar_url', 'linked_at'] + + def get_avatar_url(self, obj): + return obj.avatar_url diff --git a/backend/social_connections/templates/social_connections/oauth_callback.html b/backend/social_connections/templates/social_connections/oauth_callback.html new file mode 100644 index 00000000..83b4607f --- /dev/null +++ b/backend/social_connections/templates/social_connections/oauth_callback.html @@ -0,0 +1,86 @@ + + + + + + {{ platform|title }} Connection + + + +
+ {% if success %} +
+

{{ message }}

+ {% else %} +
+

{{ message }}

+ {% endif %} +

This window will close automatically.

+
+ + + + diff --git a/backend/social_connections/tests/__init__.py b/backend/social_connections/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/social_connections/tests/test_discord_oauth.py b/backend/social_connections/tests/test_discord_oauth.py new file mode 100644 index 00000000..bcfe63e8 --- /dev/null +++ b/backend/social_connections/tests/test_discord_oauth.py @@ -0,0 +1,342 @@ +""" +Tests for Discord OAuth views. +""" +from unittest.mock import patch, MagicMock +from django.test import TestCase, Client, override_settings +from django.core import signing +from rest_framework.test import APIClient +from users.models import User +from social_connections.discord.models import DiscordConnection + + +@override_settings( + DISCORD_CLIENT_ID='test_client_id', + DISCORD_CLIENT_SECRET='test_client_secret', + DISCORD_REDIRECT_URI='http://localhost:8000/api/auth/discord/callback/', + DISCORD_GUILD_ID='123456789', + FRONTEND_URL='http://localhost:5173', + SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=' +) +class DiscordOAuthInitiateTest(TestCase): + """Tests for Discord OAuth initiation.""" + + def setUp(self): + self.client = APIClient() + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + + def test_initiate_requires_authentication(self): + """Test that OAuth initiation requires authentication.""" + response = self.client.get('/api/auth/discord/') + self.assertEqual(response.status_code, 403) + + def test_initiate_redirects_to_discord(self): + """Test that authenticated user is redirected to Discord.""" + self.client.force_authenticate(user=self.user) + response = self.client.get('/api/auth/discord/') + + self.assertEqual(response.status_code, 302) + self.assertIn('discord.com', response.url) + self.assertIn('client_id=test_client_id', response.url) + + def test_initiate_includes_correct_scopes(self): + """Test that redirect URL includes required scopes.""" + self.client.force_authenticate(user=self.user) + response = self.client.get('/api/auth/discord/') + + # URL-encoded scopes + self.assertIn('scope=', response.url) + # identify and guilds scopes + self.assertIn('identify', response.url) + self.assertIn('guilds', response.url) + + +@override_settings( + DISCORD_CLIENT_ID='test_client_id', + DISCORD_CLIENT_SECRET='test_client_secret', + DISCORD_REDIRECT_URI='http://localhost:8000/api/auth/discord/callback/', + DISCORD_GUILD_ID='123456789', + FRONTEND_URL='http://localhost:5173', + SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=' +) +class DiscordOAuthCallbackTest(TestCase): + """Tests for Discord OAuth callback.""" + + def setUp(self): + self.client = Client() + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + # Clear the used OAuth codes cache to prevent interference between tests + from social_connections.discord import oauth + oauth._used_oauth_codes.clear() + + def _create_valid_state(self): + """Create a valid signed state token.""" + state_data = { + 'user_id': self.user.id, + 'nonce': 'test_nonce_12345' + } + return signing.dumps(state_data, salt='discord_oauth_state') + + @patch('social_connections.discord.oauth.requests.post') + def test_callback_without_code_fails(self, mock_post): + """Test callback fails without authorization code.""" + # Mock token exchange to fail + mock_post.return_value = MagicMock( + status_code=400, + json=lambda: {'error': 'invalid_request'} + ) + + state = self._create_valid_state() + response = self.client.get('/api/auth/discord/callback/', {'state': state, 'code': ''}) + + self.assertEqual(response.status_code, 200) + # Should show an error + self.assertContains(response, 'error') + + def test_callback_without_state_fails(self): + """Test callback fails without state parameter.""" + response = self.client.get('/api/auth/discord/callback/', {'code': 'test_code'}) + + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'invalid_state') + + def test_callback_with_error_from_discord(self): + """Test callback handles error from Discord.""" + response = self.client.get('/api/auth/discord/callback/', { + 'error': 'access_denied', + 'error_description': 'User denied access' + }) + + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'authorization_failed') + + @patch('social_connections.discord.oauth.requests.post') + @patch('social_connections.discord.oauth.requests.get') + def test_callback_success_creates_connection(self, mock_get, mock_post): + """Test successful callback creates DiscordConnection.""" + # Mock token exchange + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: { + 'access_token': 'test_access_token', + 'refresh_token': 'test_refresh_token' + } + ) + + # Mock user info fetch + mock_get.return_value = MagicMock( + status_code=200, + json=lambda: { + 'username': 'testdiscorduser', + 'id': '111222333', + 'discriminator': '1234', + 'avatar': 'abc123hash' + } + ) + mock_get.return_value.raise_for_status = MagicMock() + + state = self._create_valid_state() + response = self.client.get('/api/auth/discord/callback/', { + 'code': 'valid_code', + 'state': state + }) + + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'var success = true;') + + # Verify connection was created + self.assertTrue(DiscordConnection.objects.filter(user=self.user).exists()) + connection = DiscordConnection.objects.get(user=self.user) + self.assertEqual(connection.username, 'testdiscorduser') + self.assertEqual(connection.platform_user_id, '111222333') + self.assertEqual(connection.discriminator, '1234') + self.assertEqual(connection.avatar_hash, 'abc123hash') + + @patch('social_connections.discord.oauth.requests.post') + @patch('social_connections.discord.oauth.requests.get') + def test_callback_rejects_already_linked_account(self, mock_get, mock_post): + """Test callback rejects Discord account already linked to another user.""" + # Create another user with this Discord account + other_user = User.objects.create_user( + email='other@example.com', + address='0x0987654321098765432109876543210987654321' + ) + DiscordConnection.objects.create( + user=other_user, + username='testdiscorduser', + platform_user_id='111222333' + ) + + # Mock token exchange + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: { + 'access_token': 'test_access_token', + 'refresh_token': 'test_refresh_token' + } + ) + + # Mock user info - same Discord ID as other_user + mock_get.return_value = MagicMock( + status_code=200, + json=lambda: { + 'username': 'testdiscorduser', + 'id': '111222333', + 'discriminator': '1234', + 'avatar': 'abc123hash' + } + ) + mock_get.return_value.raise_for_status = MagicMock() + + state = self._create_valid_state() + response = self.client.get('/api/auth/discord/callback/', { + 'code': 'valid_code', + 'state': state + }) + + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'already_linked') + + # Verify no new connection was created for test user + self.assertFalse(DiscordConnection.objects.filter(user=self.user).exists()) + + +@override_settings( + SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=', + DISCORD_GUILD_ID='123456789' +) +class DiscordDisconnectTest(TestCase): + """Tests for Discord disconnect endpoint.""" + + def setUp(self): + self.client = APIClient() + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + + def test_disconnect_requires_authentication(self): + """Test that disconnect requires authentication.""" + response = self.client.post('/api/v1/social/discord/disconnect/') + self.assertEqual(response.status_code, 403) + + def test_disconnect_removes_connection(self): + """Test that disconnect removes the Discord connection.""" + DiscordConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345' + ) + + self.client.force_authenticate(user=self.user) + response = self.client.post('/api/v1/social/discord/disconnect/') + + self.assertEqual(response.status_code, 200) + self.assertFalse(DiscordConnection.objects.filter(user=self.user).exists()) + + def test_disconnect_without_connection_succeeds(self): + """Test that disconnect succeeds even if no connection exists.""" + self.client.force_authenticate(user=self.user) + response = self.client.post('/api/v1/social/discord/disconnect/') + + self.assertEqual(response.status_code, 200) + + +@override_settings( + SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=', + DISCORD_GUILD_ID='123456789' +) +class DiscordCheckGuildTest(TestCase): + """Tests for Discord check-guild endpoint.""" + + def setUp(self): + self.client = APIClient() + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + + def test_check_guild_requires_authentication(self): + """Test that check-guild requires authentication.""" + response = self.client.get('/api/v1/social/discord/check-guild/') + self.assertEqual(response.status_code, 403) + + def test_check_guild_without_connection(self): + """Test check-guild when user has no Discord connection.""" + self.client.force_authenticate(user=self.user) + response = self.client.get('/api/v1/social/discord/check-guild/') + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data['is_member'], False) + self.assertIn('error', response.data) + + @patch('social_connections.discord.oauth.requests.get') + def test_check_guild_returns_true_when_member(self, mock_get): + """Test check-guild returns true when user is a guild member.""" + from social_connections.encryption import encrypt_token + + DiscordConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345', + access_token=encrypt_token('test_token') + ) + + # Mock Discord API response - user's guilds including our guild + mock_get.return_value = MagicMock( + status_code=200, + json=lambda: [ + {'id': '111111111', 'name': 'Other Server'}, + {'id': '123456789', 'name': 'Our Server'}, # This is the configured guild + ] + ) + mock_get.return_value.raise_for_status = MagicMock() + + self.client.force_authenticate(user=self.user) + response = self.client.get('/api/v1/social/discord/check-guild/') + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data['is_member'], True) + + @patch('social_connections.discord.oauth.requests.get') + def test_check_guild_returns_false_when_not_member(self, mock_get): + """Test check-guild returns false when user is not a guild member.""" + from social_connections.encryption import encrypt_token + + DiscordConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345', + access_token=encrypt_token('test_token') + ) + + # Mock Discord API response - user's guilds NOT including our guild + mock_get.return_value = MagicMock( + status_code=200, + json=lambda: [ + {'id': '111111111', 'name': 'Other Server'}, + {'id': '222222222', 'name': 'Another Server'}, + ] + ) + mock_get.return_value.raise_for_status = MagicMock() + + self.client.force_authenticate(user=self.user) + response = self.client.get('/api/v1/social/discord/check-guild/') + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data['is_member'], False) + + @override_settings(DISCORD_GUILD_ID='') + def test_check_guild_without_configured_guild(self): + """Test check-guild when no guild ID is configured.""" + self.client.force_authenticate(user=self.user) + response = self.client.get('/api/v1/social/discord/check-guild/') + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data['is_member'], False) + self.assertIn('not configured', response.data.get('error', '')) diff --git a/backend/social_connections/tests/test_encryption.py b/backend/social_connections/tests/test_encryption.py new file mode 100644 index 00000000..c7f2d9ee --- /dev/null +++ b/backend/social_connections/tests/test_encryption.py @@ -0,0 +1,60 @@ +""" +Tests for social connection encryption utilities. +""" +from django.test import TestCase, override_settings +from social_connections.encryption import encrypt_token, decrypt_token, get_fernet + + +class EncryptionTestCase(TestCase): + """Tests for token encryption and decryption.""" + + @override_settings(SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=') + def test_encrypt_and_decrypt_token(self): + """Test that a token can be encrypted and decrypted.""" + original_token = "ghp_test_token_12345" + encrypted = encrypt_token(original_token) + + # Encrypted token should be different from original + self.assertNotEqual(encrypted, original_token) + + # Decrypted token should match original + decrypted = decrypt_token(encrypted) + self.assertEqual(decrypted, original_token) + + @override_settings(SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=') + def test_encrypt_empty_token(self): + """Test that empty tokens return empty string.""" + self.assertEqual(encrypt_token(""), "") + self.assertEqual(encrypt_token(None), "") + + @override_settings(SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=') + def test_decrypt_empty_token(self): + """Test that empty encrypted tokens return empty string.""" + self.assertEqual(decrypt_token(""), "") + self.assertEqual(decrypt_token(None), "") + + @override_settings(SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=') + def test_different_tokens_produce_different_encryptions(self): + """Test that different tokens produce different encrypted values.""" + token1 = "token_one" + token2 = "token_two" + + encrypted1 = encrypt_token(token1) + encrypted2 = encrypt_token(token2) + + self.assertNotEqual(encrypted1, encrypted2) + + @override_settings(SOCIAL_ENCRYPTION_KEY=None, GITHUB_ENCRYPTION_KEY=None) + def test_missing_encryption_key_raises_error(self): + """Test that missing encryption key raises RuntimeError.""" + with self.assertRaises(RuntimeError) as context: + get_fernet() + + self.assertIn("SOCIAL_ENCRYPTION_KEY", str(context.exception)) + + @override_settings(SOCIAL_ENCRYPTION_KEY=None, GITHUB_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=') + def test_fallback_to_github_encryption_key(self): + """Test that GITHUB_ENCRYPTION_KEY is used as fallback.""" + # Should not raise an error + fernet = get_fernet() + self.assertIsNotNone(fernet) diff --git a/backend/social_connections/tests/test_github_oauth.py b/backend/social_connections/tests/test_github_oauth.py new file mode 100644 index 00000000..a0352ece --- /dev/null +++ b/backend/social_connections/tests/test_github_oauth.py @@ -0,0 +1,299 @@ +""" +Tests for GitHub OAuth views. +""" +from unittest.mock import patch, MagicMock +from django.test import TestCase, Client, override_settings +from django.urls import reverse +from django.core import signing +from django.utils import timezone +from rest_framework.test import APIClient +from users.models import User +from social_connections.github.models import GitHubConnection + + +@override_settings( + GITHUB_CLIENT_ID='test_client_id', + GITHUB_CLIENT_SECRET='test_client_secret', + GITHUB_REDIRECT_URI='http://localhost:8000/api/auth/github/callback/', + FRONTEND_URL='http://localhost:5173', + SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=' +) +class GitHubOAuthInitiateTest(TestCase): + """Tests for GitHub OAuth initiation.""" + + def setUp(self): + self.client = APIClient() + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + + def test_initiate_requires_authentication(self): + """Test that OAuth initiation requires authentication.""" + response = self.client.get('/api/auth/github/') + self.assertEqual(response.status_code, 403) + + def test_initiate_redirects_to_github(self): + """Test that authenticated user is redirected to GitHub.""" + self.client.force_authenticate(user=self.user) + response = self.client.get('/api/auth/github/') + + self.assertEqual(response.status_code, 302) + self.assertIn('github.com/login/oauth/authorize', response.url) + self.assertIn('client_id=test_client_id', response.url) + + def test_initiate_includes_state_parameter(self): + """Test that redirect URL includes signed state parameter.""" + self.client.force_authenticate(user=self.user) + response = self.client.get('/api/auth/github/') + + self.assertIn('state=', response.url) + + +@override_settings( + GITHUB_CLIENT_ID='test_client_id', + GITHUB_CLIENT_SECRET='test_client_secret', + GITHUB_REDIRECT_URI='http://localhost:8000/api/auth/github/callback/', + FRONTEND_URL='http://localhost:5173', + SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=' +) +class GitHubOAuthCallbackTest(TestCase): + """Tests for GitHub OAuth callback.""" + + def setUp(self): + self.client = Client() + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + # Clear the used OAuth codes cache to prevent interference between tests + from social_connections.github import oauth + oauth._used_oauth_codes.clear() + + def _create_valid_state(self): + """Create a valid signed state token.""" + state_data = { + 'user_id': self.user.id, + 'nonce': 'test_nonce_12345' + } + return signing.dumps(state_data, salt='github_oauth_state') + + @patch('social_connections.github.oauth.requests.post') + def test_callback_without_code_fails(self, mock_post): + """Test callback fails without authorization code.""" + # Mock token exchange to fail (no code provided) + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: {'error': 'bad_verification_code'} + ) + + state = self._create_valid_state() + response = self.client.get('/api/auth/github/callback/', {'state': state, 'code': ''}) + + self.assertEqual(response.status_code, 200) + # Should show an error + self.assertContains(response, 'error') + + def test_callback_without_state_fails(self): + """Test callback fails without state parameter.""" + response = self.client.get('/api/auth/github/callback/', {'code': 'test_code'}) + + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'invalid_state') + + def test_callback_with_expired_state_fails(self): + """Test callback fails with expired state.""" + # Create state that appears expired + state_data = { + 'user_id': self.user.id, + 'nonce': 'test_nonce' + } + # We can't easily create an expired token, but we can test invalid signature + response = self.client.get('/api/auth/github/callback/', { + 'code': 'test_code', + 'state': 'invalid_state_token' + }) + + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'invalid_state') + + def test_callback_with_error_from_github(self): + """Test callback handles error from GitHub.""" + response = self.client.get('/api/auth/github/callback/', { + 'error': 'access_denied', + 'error_description': 'User denied access' + }) + + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'authorization_failed') + + @patch('social_connections.github.oauth.requests.post') + @patch('social_connections.github.oauth.requests.get') + def test_callback_success_creates_connection(self, mock_get, mock_post): + """Test successful callback creates GitHubConnection.""" + # Mock token exchange + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: {'access_token': 'test_access_token'} + ) + mock_post.return_value.raise_for_status = MagicMock() + + # Mock user info fetch + mock_get.return_value = MagicMock( + status_code=200, + json=lambda: { + 'login': 'testgithubuser', + 'id': 12345678 + } + ) + mock_get.return_value.raise_for_status = MagicMock() + + state = self._create_valid_state() + response = self.client.get('/api/auth/github/callback/', { + 'code': 'valid_code', + 'state': state + }) + + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'var success = true;') + + # Verify connection was created + self.assertTrue(GitHubConnection.objects.filter(user=self.user).exists()) + connection = GitHubConnection.objects.get(user=self.user) + self.assertEqual(connection.username, 'testgithubuser') + self.assertEqual(connection.platform_user_id, '12345678') + + @patch('social_connections.github.oauth.requests.post') + @patch('social_connections.github.oauth.requests.get') + def test_callback_rejects_already_linked_account(self, mock_get, mock_post): + """Test callback rejects GitHub account already linked to another user.""" + # Create another user with this GitHub account + other_user = User.objects.create_user( + email='other@example.com', + address='0x0987654321098765432109876543210987654321' + ) + GitHubConnection.objects.create( + user=other_user, + username='testgithubuser', + platform_user_id='12345678' + ) + + # Mock token exchange + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: {'access_token': 'test_access_token'} + ) + mock_post.return_value.raise_for_status = MagicMock() + + # Mock user info - same GitHub ID as other_user + mock_get.return_value = MagicMock( + status_code=200, + json=lambda: { + 'login': 'testgithubuser', + 'id': 12345678 + } + ) + mock_get.return_value.raise_for_status = MagicMock() + + state = self._create_valid_state() + response = self.client.get('/api/auth/github/callback/', { + 'code': 'valid_code', + 'state': state + }) + + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'already_linked') + + # Verify no new connection was created for test user + self.assertFalse(GitHubConnection.objects.filter(user=self.user).exists()) + + +@override_settings( + SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=', + GITHUB_REPO_TO_STAR='genlayerlabs/test-repo' +) +class GitHubDisconnectTest(TestCase): + """Tests for GitHub disconnect endpoint.""" + + def setUp(self): + self.client = APIClient() + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + + def test_disconnect_requires_authentication(self): + """Test that disconnect requires authentication.""" + response = self.client.post('/api/v1/social/github/disconnect/') + self.assertEqual(response.status_code, 403) + + def test_disconnect_removes_connection(self): + """Test that disconnect removes the GitHub connection.""" + GitHubConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345' + ) + + self.client.force_authenticate(user=self.user) + response = self.client.post('/api/v1/social/github/disconnect/') + + self.assertEqual(response.status_code, 200) + self.assertFalse(GitHubConnection.objects.filter(user=self.user).exists()) + + def test_disconnect_without_connection_succeeds(self): + """Test that disconnect succeeds even if no connection exists.""" + self.client.force_authenticate(user=self.user) + response = self.client.post('/api/v1/social/github/disconnect/') + + self.assertEqual(response.status_code, 200) + + +@override_settings( + SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=', + GITHUB_REPO_TO_STAR='genlayerlabs/test-repo' +) +class GitHubCheckStarTest(TestCase): + """Tests for GitHub check star endpoint.""" + + def setUp(self): + self.client = APIClient() + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + + def test_check_star_requires_authentication(self): + """Test that check-star requires authentication.""" + response = self.client.get('/api/v1/social/github/check-star/') + self.assertEqual(response.status_code, 403) + + def test_check_star_without_connection(self): + """Test check-star when user has no GitHub connection.""" + self.client.force_authenticate(user=self.user) + response = self.client.get('/api/v1/social/github/check-star/') + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data['has_starred'], False) + self.assertIn('error', response.data) + + @patch('social_connections.github.oauth.requests.get') + def test_check_star_returns_true_when_starred(self, mock_get): + """Test check-star returns true when user has starred the repo.""" + from social_connections.encryption import encrypt_token + + GitHubConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345', + access_token=encrypt_token('test_token') + ) + + # Mock GitHub API response - 204 means starred + mock_get.return_value = MagicMock(status_code=204) + + self.client.force_authenticate(user=self.user) + response = self.client.get('/api/v1/social/github/check-star/') + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data['has_starred'], True) diff --git a/backend/social_connections/tests/test_models.py b/backend/social_connections/tests/test_models.py new file mode 100644 index 00000000..77e9b5d5 --- /dev/null +++ b/backend/social_connections/tests/test_models.py @@ -0,0 +1,166 @@ +""" +Tests for social connection models. +""" +from django.test import TestCase +from django.utils import timezone +from users.models import User +from social_connections.github.models import GitHubConnection +from social_connections.twitter.models import TwitterConnection +from social_connections.discord.models import DiscordConnection + + +class GitHubConnectionModelTest(TestCase): + """Tests for GitHubConnection model.""" + + def setUp(self): + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + + def test_create_github_connection(self): + """Test creating a GitHub connection.""" + connection = GitHubConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345', + access_token='encrypted_token', + linked_at=timezone.now() + ) + + self.assertEqual(connection.user, self.user) + self.assertEqual(connection.username, 'testuser') + self.assertEqual(connection.platform_user_id, '12345') + self.assertIsNotNone(connection.linked_at) + + def test_github_connection_str(self): + """Test string representation of GitHub connection.""" + connection = GitHubConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345' + ) + + self.assertIn('testuser', str(connection)) + + def test_one_to_one_relationship(self): + """Test that user can only have one GitHub connection.""" + GitHubConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345' + ) + + # Creating another connection for the same user should fail + with self.assertRaises(Exception): + GitHubConnection.objects.create( + user=self.user, + username='anotheruser', + platform_user_id='67890' + ) + + def test_access_via_user(self): + """Test accessing GitHub connection via user.github_connection.""" + connection = GitHubConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345' + ) + + self.assertEqual(self.user.github_connection, connection) + + +class TwitterConnectionModelTest(TestCase): + """Tests for TwitterConnection model.""" + + def setUp(self): + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + + def test_create_twitter_connection(self): + """Test creating a Twitter connection.""" + connection = TwitterConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345', + access_token='encrypted_token', + refresh_token='encrypted_refresh_token', + linked_at=timezone.now() + ) + + self.assertEqual(connection.user, self.user) + self.assertEqual(connection.username, 'testuser') + self.assertEqual(connection.platform_user_id, '12345') + self.assertIsNotNone(connection.refresh_token) + + def test_access_via_user(self): + """Test accessing Twitter connection via user.twitter_connection.""" + connection = TwitterConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345' + ) + + self.assertEqual(self.user.twitter_connection, connection) + + +class DiscordConnectionModelTest(TestCase): + """Tests for DiscordConnection model.""" + + def setUp(self): + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + + def test_create_discord_connection(self): + """Test creating a Discord connection.""" + connection = DiscordConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345', + discriminator='1234', + avatar_hash='abc123', + access_token='encrypted_token', + linked_at=timezone.now() + ) + + self.assertEqual(connection.user, self.user) + self.assertEqual(connection.username, 'testuser') + self.assertEqual(connection.discriminator, '1234') + self.assertEqual(connection.avatar_hash, 'abc123') + + def test_avatar_url_property(self): + """Test the avatar_url property.""" + connection = DiscordConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='123456789', + avatar_hash='abc123def456' + ) + + expected_url = 'https://cdn.discordapp.com/avatars/123456789/abc123def456.png' + self.assertEqual(connection.avatar_url, expected_url) + + def test_avatar_url_none_when_no_hash(self): + """Test avatar_url returns None when no avatar hash.""" + connection = DiscordConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345', + avatar_hash='' + ) + + self.assertIsNone(connection.avatar_url) + + def test_access_via_user(self): + """Test accessing Discord connection via user.discord_connection.""" + connection = DiscordConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345' + ) + + self.assertEqual(self.user.discord_connection, connection) diff --git a/backend/social_connections/tests/test_serializers.py b/backend/social_connections/tests/test_serializers.py new file mode 100644 index 00000000..994fd8c9 --- /dev/null +++ b/backend/social_connections/tests/test_serializers.py @@ -0,0 +1,209 @@ +""" +Tests for social connection serializers. +""" +from django.test import TestCase +from django.utils import timezone +from users.models import User +from users.serializers import UserSerializer +from social_connections.github.models import GitHubConnection +from social_connections.twitter.models import TwitterConnection +from social_connections.discord.models import DiscordConnection +from social_connections.serializers import ( + GitHubConnectionSerializer, + TwitterConnectionSerializer, + DiscordConnectionSerializer +) + + +class GitHubConnectionSerializerTest(TestCase): + """Tests for GitHubConnectionSerializer.""" + + def setUp(self): + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + self.connection = GitHubConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345', + linked_at=timezone.now() + ) + + def test_serializer_fields(self): + """Test that serializer returns expected fields.""" + serializer = GitHubConnectionSerializer(self.connection) + data = serializer.data + + self.assertIn('username', data) + self.assertIn('linked_at', data) + self.assertEqual(data['username'], 'testuser') + + def test_serializer_excludes_sensitive_fields(self): + """Test that sensitive fields are not exposed.""" + serializer = GitHubConnectionSerializer(self.connection) + data = serializer.data + + self.assertNotIn('access_token', data) + self.assertNotIn('platform_user_id', data) + + +class TwitterConnectionSerializerTest(TestCase): + """Tests for TwitterConnectionSerializer.""" + + def setUp(self): + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + self.connection = TwitterConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345', + linked_at=timezone.now() + ) + + def test_serializer_fields(self): + """Test that serializer returns expected fields.""" + serializer = TwitterConnectionSerializer(self.connection) + data = serializer.data + + self.assertIn('username', data) + self.assertIn('linked_at', data) + self.assertEqual(data['username'], 'testuser') + + +class DiscordConnectionSerializerTest(TestCase): + """Tests for DiscordConnectionSerializer.""" + + def setUp(self): + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + self.connection = DiscordConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='123456789', + discriminator='1234', + avatar_hash='abc123', + linked_at=timezone.now() + ) + + def test_serializer_fields(self): + """Test that serializer returns expected fields.""" + serializer = DiscordConnectionSerializer(self.connection) + data = serializer.data + + self.assertIn('username', data) + self.assertIn('discriminator', data) + self.assertIn('avatar_url', data) + self.assertIn('linked_at', data) + + def test_avatar_url_included(self): + """Test that avatar_url is correctly computed.""" + serializer = DiscordConnectionSerializer(self.connection) + data = serializer.data + + expected_url = 'https://cdn.discordapp.com/avatars/123456789/abc123.png' + self.assertEqual(data['avatar_url'], expected_url) + + +class UserSerializerSocialConnectionsTest(TestCase): + """Tests for UserSerializer with social connections.""" + + def setUp(self): + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890', + name='Test User' + ) + + def test_user_without_connections(self): + """Test UserSerializer when user has no social connections.""" + serializer = UserSerializer(self.user) + data = serializer.data + + self.assertIsNone(data['github_connection']) + self.assertIsNone(data['twitter_connection']) + self.assertIsNone(data['discord_connection']) + self.assertEqual(data['github_username'], '') + self.assertIsNone(data['github_linked_at']) + + def test_user_with_github_connection(self): + """Test UserSerializer includes GitHub connection data.""" + GitHubConnection.objects.create( + user=self.user, + username='githubuser', + platform_user_id='12345', + linked_at=timezone.now() + ) + + serializer = UserSerializer(self.user) + data = serializer.data + + self.assertIsNotNone(data['github_connection']) + self.assertEqual(data['github_connection']['username'], 'githubuser') + # Legacy fields + self.assertEqual(data['github_username'], 'githubuser') + self.assertIsNotNone(data['github_linked_at']) + + def test_user_with_twitter_connection(self): + """Test UserSerializer includes Twitter connection data.""" + TwitterConnection.objects.create( + user=self.user, + username='twitteruser', + platform_user_id='12345', + linked_at=timezone.now() + ) + + serializer = UserSerializer(self.user) + data = serializer.data + + self.assertIsNotNone(data['twitter_connection']) + self.assertEqual(data['twitter_connection']['username'], 'twitteruser') + + def test_user_with_discord_connection(self): + """Test UserSerializer includes Discord connection data.""" + DiscordConnection.objects.create( + user=self.user, + username='discorduser', + platform_user_id='12345', + discriminator='1234', + linked_at=timezone.now() + ) + + serializer = UserSerializer(self.user) + data = serializer.data + + self.assertIsNotNone(data['discord_connection']) + self.assertEqual(data['discord_connection']['username'], 'discorduser') + self.assertEqual(data['discord_connection']['discriminator'], '1234') + + def test_user_with_all_connections(self): + """Test UserSerializer with all social connections.""" + GitHubConnection.objects.create( + user=self.user, + username='githubuser', + platform_user_id='111', + linked_at=timezone.now() + ) + TwitterConnection.objects.create( + user=self.user, + username='twitteruser', + platform_user_id='222', + linked_at=timezone.now() + ) + DiscordConnection.objects.create( + user=self.user, + username='discorduser', + platform_user_id='333', + linked_at=timezone.now() + ) + + serializer = UserSerializer(self.user) + data = serializer.data + + self.assertEqual(data['github_connection']['username'], 'githubuser') + self.assertEqual(data['twitter_connection']['username'], 'twitteruser') + self.assertEqual(data['discord_connection']['username'], 'discorduser') diff --git a/backend/social_connections/tests/test_twitter_oauth.py b/backend/social_connections/tests/test_twitter_oauth.py new file mode 100644 index 00000000..cc4b184a --- /dev/null +++ b/backend/social_connections/tests/test_twitter_oauth.py @@ -0,0 +1,248 @@ +""" +Tests for Twitter OAuth views. +""" +from unittest.mock import patch, MagicMock +from django.test import TestCase, Client, override_settings +from django.core import signing +from rest_framework.test import APIClient +from users.models import User +from social_connections.twitter.models import TwitterConnection + + +@override_settings( + TWITTER_CLIENT_ID='test_client_id', + TWITTER_CLIENT_SECRET='test_client_secret', + TWITTER_REDIRECT_URI='http://localhost:8000/api/auth/twitter/callback/', + FRONTEND_URL='http://localhost:5173', + SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=' +) +class TwitterOAuthInitiateTest(TestCase): + """Tests for Twitter OAuth initiation.""" + + def setUp(self): + self.client = APIClient() + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + + def test_initiate_requires_authentication(self): + """Test that OAuth initiation requires authentication.""" + response = self.client.get('/api/auth/twitter/') + self.assertEqual(response.status_code, 403) + + def test_initiate_redirects_to_twitter(self): + """Test that authenticated user is redirected to Twitter.""" + self.client.force_authenticate(user=self.user) + response = self.client.get('/api/auth/twitter/') + + self.assertEqual(response.status_code, 302) + self.assertIn('twitter.com', response.url) + self.assertIn('client_id=test_client_id', response.url) + + def test_initiate_includes_pkce_parameters(self): + """Test that redirect URL includes PKCE code challenge.""" + self.client.force_authenticate(user=self.user) + response = self.client.get('/api/auth/twitter/') + + self.assertIn('code_challenge=', response.url) + self.assertIn('code_challenge_method=S256', response.url) + + def test_initiate_includes_correct_scopes(self): + """Test that redirect URL includes required scopes.""" + self.client.force_authenticate(user=self.user) + response = self.client.get('/api/auth/twitter/') + + # URL-encoded scopes + self.assertIn('scope=', response.url) + + +@override_settings( + TWITTER_CLIENT_ID='test_client_id', + TWITTER_CLIENT_SECRET='test_client_secret', + TWITTER_REDIRECT_URI='http://localhost:8000/api/auth/twitter/callback/', + FRONTEND_URL='http://localhost:5173', + SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=' +) +class TwitterOAuthCallbackTest(TestCase): + """Tests for Twitter OAuth callback.""" + + def setUp(self): + self.client = Client() + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + # Clear the used OAuth codes cache to prevent interference between tests + from social_connections.twitter import oauth + oauth._used_oauth_codes.clear() + + def _create_valid_state(self): + """Create a valid signed state token with PKCE code verifier.""" + state_data = { + 'user_id': self.user.id, + 'code_verifier': 'test_code_verifier_12345678901234567890', + 'nonce': 'test_nonce_12345' + } + return signing.dumps(state_data, salt='twitter_oauth_state') + + @patch('social_connections.twitter.oauth.requests.post') + def test_callback_without_code_fails(self, mock_post): + """Test callback fails without authorization code.""" + # Mock token exchange to fail + mock_post.return_value = MagicMock( + status_code=400, + json=lambda: {'error': 'invalid_request'} + ) + + state = self._create_valid_state() + response = self.client.get('/api/auth/twitter/callback/', {'state': state, 'code': ''}) + + self.assertEqual(response.status_code, 200) + # Should show an error + self.assertContains(response, 'error') + + def test_callback_without_state_fails(self): + """Test callback fails without state parameter.""" + response = self.client.get('/api/auth/twitter/callback/', {'code': 'test_code'}) + + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'invalid_state') + + def test_callback_with_error_from_twitter(self): + """Test callback handles error from Twitter.""" + response = self.client.get('/api/auth/twitter/callback/', { + 'error': 'access_denied', + 'error_description': 'User denied access' + }) + + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'authorization_failed') + + @patch('social_connections.twitter.oauth.requests.post') + @patch('social_connections.twitter.oauth.requests.get') + def test_callback_success_creates_connection(self, mock_get, mock_post): + """Test successful callback creates TwitterConnection.""" + # Mock token exchange + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: { + 'access_token': 'test_access_token', + 'refresh_token': 'test_refresh_token' + } + ) + + # Mock user info fetch + mock_get.return_value = MagicMock( + status_code=200, + json=lambda: { + 'data': { + 'username': 'testtwitteruser', + 'id': '987654321' + } + } + ) + mock_get.return_value.raise_for_status = MagicMock() + + state = self._create_valid_state() + response = self.client.get('/api/auth/twitter/callback/', { + 'code': 'valid_code', + 'state': state + }) + + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'var success = true;') + + # Verify connection was created + self.assertTrue(TwitterConnection.objects.filter(user=self.user).exists()) + connection = TwitterConnection.objects.get(user=self.user) + self.assertEqual(connection.username, 'testtwitteruser') + self.assertEqual(connection.platform_user_id, '987654321') + + @patch('social_connections.twitter.oauth.requests.post') + @patch('social_connections.twitter.oauth.requests.get') + def test_callback_rejects_already_linked_account(self, mock_get, mock_post): + """Test callback rejects Twitter account already linked to another user.""" + # Create another user with this Twitter account + other_user = User.objects.create_user( + email='other@example.com', + address='0x0987654321098765432109876543210987654321' + ) + TwitterConnection.objects.create( + user=other_user, + username='testtwitteruser', + platform_user_id='987654321' + ) + + # Mock token exchange + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: { + 'access_token': 'test_access_token', + 'refresh_token': 'test_refresh_token' + } + ) + + # Mock user info - same Twitter ID as other_user + mock_get.return_value = MagicMock( + status_code=200, + json=lambda: { + 'data': { + 'username': 'testtwitteruser', + 'id': '987654321' + } + } + ) + mock_get.return_value.raise_for_status = MagicMock() + + state = self._create_valid_state() + response = self.client.get('/api/auth/twitter/callback/', { + 'code': 'valid_code', + 'state': state + }) + + self.assertEqual(response.status_code, 200) + self.assertContains(response, 'already_linked') + + # Verify no new connection was created for test user + self.assertFalse(TwitterConnection.objects.filter(user=self.user).exists()) + + +@override_settings( + SOCIAL_ENCRYPTION_KEY='9GibasU7S9kA35HL7CovU1xOoAf-WoC-tNVDeQhJlik=' +) +class TwitterDisconnectTest(TestCase): + """Tests for Twitter disconnect endpoint.""" + + def setUp(self): + self.client = APIClient() + self.user = User.objects.create_user( + email='test@example.com', + address='0x1234567890123456789012345678901234567890' + ) + + def test_disconnect_requires_authentication(self): + """Test that disconnect requires authentication.""" + response = self.client.post('/api/v1/social/twitter/disconnect/') + self.assertEqual(response.status_code, 403) + + def test_disconnect_removes_connection(self): + """Test that disconnect removes the Twitter connection.""" + TwitterConnection.objects.create( + user=self.user, + username='testuser', + platform_user_id='12345' + ) + + self.client.force_authenticate(user=self.user) + response = self.client.post('/api/v1/social/twitter/disconnect/') + + self.assertEqual(response.status_code, 200) + self.assertFalse(TwitterConnection.objects.filter(user=self.user).exists()) + + def test_disconnect_without_connection_succeeds(self): + """Test that disconnect succeeds even if no connection exists.""" + self.client.force_authenticate(user=self.user) + response = self.client.post('/api/v1/social/twitter/disconnect/') + + self.assertEqual(response.status_code, 200) diff --git a/backend/social_connections/twitter/__init__.py b/backend/social_connections/twitter/__init__.py new file mode 100644 index 00000000..2eb7be0e --- /dev/null +++ b/backend/social_connections/twitter/__init__.py @@ -0,0 +1 @@ +default_app_config = 'social_connections.twitter.apps.TwitterConfig' diff --git a/backend/social_connections/twitter/admin.py b/backend/social_connections/twitter/admin.py new file mode 100644 index 00000000..240f2307 --- /dev/null +++ b/backend/social_connections/twitter/admin.py @@ -0,0 +1,21 @@ +from django.contrib import admin +from .models import TwitterConnection + + +@admin.register(TwitterConnection) +class TwitterConnectionAdmin(admin.ModelAdmin): + list_display = ['user', 'username', 'platform_user_id', 'linked_at', 'created_at'] + list_filter = ['linked_at'] + search_fields = ['user__email', 'user__name', 'username', 'platform_user_id'] + readonly_fields = ['platform_user_id', 'access_token', 'refresh_token', 'linked_at', 'created_at', 'updated_at'] + raw_id_fields = ['user'] + + fieldsets = ( + (None, { + 'fields': ('user', 'username', 'platform_user_id') + }), + ('Timestamps', { + 'fields': ('linked_at', 'created_at', 'updated_at'), + 'classes': ('collapse',) + }), + ) diff --git a/backend/social_connections/twitter/apps.py b/backend/social_connections/twitter/apps.py new file mode 100644 index 00000000..66215d23 --- /dev/null +++ b/backend/social_connections/twitter/apps.py @@ -0,0 +1,8 @@ +from django.apps import AppConfig + + +class TwitterConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'social_connections.twitter' + label = 'social_twitter' + verbose_name = 'Twitter Connection' diff --git a/backend/social_connections/twitter/migrations/0001_initial.py b/backend/social_connections/twitter/migrations/0001_initial.py new file mode 100644 index 00000000..a0dd027d --- /dev/null +++ b/backend/social_connections/twitter/migrations/0001_initial.py @@ -0,0 +1,35 @@ +# Generated by Django 5.2.7 on 2026-01-30 15:27 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name='TwitterConnection', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('username', models.CharField(blank=True, help_text='Platform username', max_length=100)), + ('platform_user_id', models.CharField(blank=True, help_text='Platform user ID for unique identification', max_length=50)), + ('access_token', models.TextField(blank=True, help_text='Encrypted access token')), + ('refresh_token', models.TextField(blank=True, help_text='Encrypted refresh token (if applicable)')), + ('linked_at', models.DateTimeField(blank=True, help_text='When the account was linked', null=True)), + ('user', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, related_name='twitter_connection', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'verbose_name': 'Twitter Connection', + 'verbose_name_plural': 'Twitter Connections', + }, + ), + ] diff --git a/backend/social_connections/twitter/migrations/__init__.py b/backend/social_connections/twitter/migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/social_connections/twitter/models.py b/backend/social_connections/twitter/models.py new file mode 100644 index 00000000..d6673fc4 --- /dev/null +++ b/backend/social_connections/twitter/models.py @@ -0,0 +1,22 @@ +""" +Twitter/X OAuth connection model. +""" +from django.db import models +from social_connections.models import SocialConnection + + +class TwitterConnection(SocialConnection): + """ + Twitter/X OAuth connection. + + Stores the user's Twitter account information obtained through OAuth 2.0 with PKCE. + """ + user = models.OneToOneField( + 'users.User', + on_delete=models.CASCADE, + related_name='twitter_connection' + ) + + class Meta: + verbose_name = 'Twitter Connection' + verbose_name_plural = 'Twitter Connections' diff --git a/backend/social_connections/twitter/oauth.py b/backend/social_connections/twitter/oauth.py new file mode 100644 index 00000000..f8a83784 --- /dev/null +++ b/backend/social_connections/twitter/oauth.py @@ -0,0 +1,306 @@ +""" +Twitter/X OAuth 2.0 authentication handling with PKCE. +""" +import secrets +import hashlib +import base64 +from urllib.parse import urlencode + +import requests +from django.conf import settings +from django.shortcuts import redirect, render +from django.utils import timezone +from django.views.decorators.csrf import csrf_exempt +from django.core import signing +from rest_framework.decorators import api_view, permission_classes +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework import status + +from users.models import User +from social_connections.encryption import encrypt_token +from social_connections.twitter.models import TwitterConnection +from tally.middleware.logging_utils import get_app_logger +from tally.middleware.tracing import trace_external + +logger = get_app_logger('twitter_oauth') + +# Cache to track used OAuth codes (prevents duplicate exchanges) +_used_oauth_codes = {} + + +def generate_code_verifier(): + """Generate a cryptographically random code verifier for PKCE.""" + return secrets.token_urlsafe(64)[:128] # Max 128 chars per spec + + +def generate_code_challenge(verifier): + """Generate code challenge from verifier using S256 method.""" + digest = hashlib.sha256(verifier.encode()).digest() + return base64.urlsafe_b64encode(digest).decode().rstrip('=') + + +@api_view(['GET']) +@permission_classes([IsAuthenticated]) +def twitter_oauth_initiate(request): + """Initiate Twitter OAuth 2.0 flow with PKCE.""" + user_id = request.user.id + logger.debug("Twitter OAuth initiated") + + # Generate PKCE code verifier and challenge + code_verifier = generate_code_verifier() + code_challenge = generate_code_challenge(code_verifier) + + # Generate state token with user ID and code_verifier embedded + state_data = { + 'user_id': user_id, + 'code_verifier': code_verifier, + 'nonce': secrets.token_urlsafe(32) + } + + # Sign the state data to make it tamper-proof + state = signing.dumps(state_data, salt='twitter_oauth_state') + logger.debug("Generated signed OAuth state with PKCE") + + # Build Twitter OAuth URL + twitter_oauth_url = "https://twitter.com/i/oauth2/authorize" + params = { + 'response_type': 'code', + 'client_id': settings.TWITTER_CLIENT_ID, + 'redirect_uri': settings.TWITTER_REDIRECT_URI, + 'scope': 'tweet.read users.read offline.access', + 'state': state, + 'code_challenge': code_challenge, + 'code_challenge_method': 'S256' + } + + auth_url = f"{twitter_oauth_url}?{urlencode(params)}" + return redirect(auth_url) + + +@csrf_exempt +def twitter_oauth_callback(request): + """Handle Twitter OAuth callback.""" + code = request.GET.get('code') + state = request.GET.get('state') + error = request.GET.get('error') + error_description = request.GET.get('error_description') + + logger.debug("Twitter OAuth callback received") + + template_context = { + 'platform': 'twitter', + 'frontend_origin': settings.FRONTEND_URL + } + + # Handle errors from Twitter + if error: + logger.error(f"Twitter OAuth error: {error} - {error_description}") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'authorization_failed', + 'message': 'Authorization failed. This window will close automatically.', + }) + + # Validate state token + if not state: + logger.error("No state token received from Twitter") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'invalid_state', + 'message': 'Invalid state token. This window will close automatically.', + }) + + try: + state_data = signing.loads(state, salt='twitter_oauth_state', max_age=600) + user_id = state_data.get('user_id') + code_verifier = state_data.get('code_verifier') + logger.debug("Successfully validated signed OAuth state") + except signing.SignatureExpired: + logger.error("OAuth state token has expired") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'state_expired', + 'message': 'Session expired. This window will close automatically.', + }) + except signing.BadSignature: + logger.error("OAuth state token has invalid signature") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'invalid_state', + 'message': 'Invalid state token. This window will close automatically.', + }) + + # Clean up old codes (older than 10 minutes) + cutoff = timezone.now() - timezone.timedelta(minutes=10) + expired_codes = [k for k, v in _used_oauth_codes.items() if v <= cutoff] + for expired_code in expired_codes: + del _used_oauth_codes[expired_code] + + # Check if this code has already been used + if code in _used_oauth_codes: + logger.warning("OAuth code already used, rejecting duplicate request") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'code_already_used', + 'message': 'This code has already been used. This window will close automatically.', + }) + + # Mark code as used immediately + _used_oauth_codes[code] = timezone.now() + + # Exchange code for access token + logger.debug("Attempting to exchange OAuth code for access token") + token_url = "https://api.twitter.com/2/oauth2/token" + + # Twitter requires Basic auth for confidential clients + auth = (settings.TWITTER_CLIENT_ID, settings.TWITTER_CLIENT_SECRET) + + token_params = { + 'code': code, + 'grant_type': 'authorization_code', + 'redirect_uri': settings.TWITTER_REDIRECT_URI, + 'code_verifier': code_verifier + } + + headers = { + 'Content-Type': 'application/x-www-form-urlencoded' + } + + try: + with trace_external('twitter', 'token_exchange'): + token_response = requests.post( + token_url, + data=token_params, + headers=headers, + auth=auth + ) + token_data = token_response.json() + + if token_response.status_code != 200 or 'error' in token_data: + error_msg = token_data.get('error_description', token_data.get('error', 'Unknown error')) + logger.error(f"Twitter token exchange error: {error_msg}") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'token_exchange_failed', + 'message': 'Failed to exchange token. This window will close automatically.', + }) + + access_token = token_data.get('access_token') + refresh_token = token_data.get('refresh_token', '') + + if not access_token: + logger.error("No access token received from Twitter") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'no_access_token', + 'message': 'No access token received. This window will close automatically.', + }) + + # Fetch Twitter user data + user_url = "https://api.twitter.com/2/users/me" + user_headers = { + 'Authorization': f'Bearer {access_token}', + } + + with trace_external('twitter', 'get_user'): + user_response = requests.get(user_url, headers=user_headers) + user_response.raise_for_status() + twitter_data = user_response.json() + + twitter_user = twitter_data.get('data', {}) + + # Get user from state token + if not user_id: + logger.error("No user_id in state token") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'invalid_state', + 'message': 'Invalid authentication state. This window will close automatically.', + }) + + try: + user = User.objects.get(id=user_id) + logger.debug("Found user from state token") + except User.DoesNotExist: + logger.error("User from state not found") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'user_not_found', + 'message': 'User not found. This window will close automatically.', + }) + + # Check if this Twitter account is already linked to another user + existing_connection = TwitterConnection.objects.filter( + platform_user_id=twitter_user.get('id') + ).exclude(user=user).first() + + if existing_connection: + logger.warning("Twitter account already linked to another user") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'already_linked', + 'message': 'Twitter account already linked. This window will close automatically.', + }) + + # Create or update TwitterConnection + connection, created = TwitterConnection.objects.update_or_create( + user=user, + defaults={ + 'username': twitter_user.get('username', ''), + 'platform_user_id': twitter_user.get('id', ''), + 'access_token': encrypt_token(access_token), + 'refresh_token': encrypt_token(refresh_token) if refresh_token else '', + 'linked_at': timezone.now() + } + ) + + logger.debug("Twitter account linked successfully") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': True, + 'error': '', + 'message': 'Twitter account linked successfully! This window will close automatically.', + }) + + except requests.RequestException as e: + logger.error(f"Twitter API request failed: {e}") + return render(request, 'social_connections/oauth_callback.html', { + **template_context, + 'success': False, + 'error': 'api_request_failed', + 'message': 'API request failed. This window will close automatically.', + }) + + +@api_view(['POST']) +@permission_classes([IsAuthenticated]) +def disconnect_twitter(request): + """Disconnect Twitter account from user profile.""" + try: + user = request.user + try: + connection = user.twitter_connection + connection.delete() + logger.debug("Twitter connection deleted") + except TwitterConnection.DoesNotExist: + pass + + return Response({ + 'message': 'Twitter account disconnected successfully' + }, status=status.HTTP_200_OK) + except Exception as e: + logger.error(f"Failed to disconnect Twitter: {e}") + return Response({ + 'error': 'Failed to disconnect Twitter account' + }, status=status.HTTP_500_INTERNAL_SERVER_ERROR) diff --git a/backend/social_connections/twitter/urls.py b/backend/social_connections/twitter/urls.py new file mode 100644 index 00000000..1de18446 --- /dev/null +++ b/backend/social_connections/twitter/urls.py @@ -0,0 +1,13 @@ +from django.urls import path +from . import oauth + +urlpatterns = [ + # OAuth endpoints (under /api/auth/twitter/) + path('', oauth.twitter_oauth_initiate, name='twitter_oauth'), + path('callback/', oauth.twitter_oauth_callback, name='twitter_callback'), +] + +# API endpoints (under /api/v1/social/twitter/) +api_urlpatterns = [ + path('disconnect/', oauth.disconnect_twitter, name='twitter_disconnect'), +] diff --git a/backend/social_connections/urls.py b/backend/social_connections/urls.py new file mode 100644 index 00000000..7b896630 --- /dev/null +++ b/backend/social_connections/urls.py @@ -0,0 +1,25 @@ +""" +URL configuration for social connections. + +This module provides two sets of URL patterns: +- auth_urlpatterns: OAuth flow endpoints (initiate, callback) under /api/auth/{platform}/ +- api_urlpatterns: API endpoints (disconnect, etc.) under /api/v1/social/{platform}/ +""" +from django.urls import path, include +from social_connections.github.urls import urlpatterns as github_auth_urls, api_urlpatterns as github_api_urls +from social_connections.twitter.urls import urlpatterns as twitter_auth_urls, api_urlpatterns as twitter_api_urls +from social_connections.discord.urls import urlpatterns as discord_auth_urls, api_urlpatterns as discord_api_urls + +# OAuth flow endpoints (under /api/auth/) +auth_urlpatterns = [ + path('github/', include(github_auth_urls)), + path('twitter/', include(twitter_auth_urls)), + path('discord/', include(discord_auth_urls)), +] + +# API endpoints (under /api/v1/social/) +api_urlpatterns = [ + path('github/', include(github_api_urls)), + path('twitter/', include(twitter_api_urls)), + path('discord/', include(discord_api_urls)), +] diff --git a/backend/tally/settings.py b/backend/tally/settings.py index b50ed155..be0bacdf 100644 --- a/backend/tally/settings.py +++ b/backend/tally/settings.py @@ -75,6 +75,10 @@ def get_required_env(key): 'builders', 'stewards', 'creators', + 'social_connections', + 'social_connections.github', + 'social_connections.twitter', + 'social_connections.discord', ] MIDDLEWARE = [ @@ -223,14 +227,29 @@ def get_required_env(key): # Backend URL (for constructing OAuth redirect URIs) BACKEND_URL = os.environ.get('BACKEND_URL', 'http://localhost:8000') +# Social OAuth Settings +# Shared encryption key for all social tokens (can reuse GITHUB_ENCRYPTION_KEY) +SOCIAL_ENCRYPTION_KEY = os.environ.get('SOCIAL_ENCRYPTION_KEY', + os.environ.get('GITHUB_ENCRYPTION_KEY', '')) + # GitHub OAuth settings GITHUB_CLIENT_ID = os.environ.get('GITHUB_CLIENT_ID', '') GITHUB_CLIENT_SECRET = os.environ.get('GITHUB_CLIENT_SECRET', '') -# Auto-calculate redirect URI from backend URL GITHUB_REDIRECT_URI = f"{BACKEND_URL}/api/auth/github/callback/" -GITHUB_ENCRYPTION_KEY = os.environ.get('GITHUB_ENCRYPTION_KEY', '') +GITHUB_ENCRYPTION_KEY = os.environ.get('GITHUB_ENCRYPTION_KEY', '') # Legacy, use SOCIAL_ENCRYPTION_KEY GITHUB_REPO_TO_STAR = os.environ.get('GITHUB_REPO_TO_STAR', 'genlayerlabs/genlayer-project-boilerplate') +# Twitter/X OAuth 2.0 settings +TWITTER_CLIENT_ID = os.environ.get('TWITTER_CLIENT_ID', '') +TWITTER_CLIENT_SECRET = os.environ.get('TWITTER_CLIENT_SECRET', '') +TWITTER_REDIRECT_URI = f"{BACKEND_URL}/api/auth/twitter/callback/" + +# Discord OAuth 2.0 settings +DISCORD_CLIENT_ID = os.environ.get('DISCORD_CLIENT_ID', '') +DISCORD_CLIENT_SECRET = os.environ.get('DISCORD_CLIENT_SECRET', '') +DISCORD_REDIRECT_URI = f"{BACKEND_URL}/api/auth/discord/callback/" +DISCORD_GUILD_ID = os.environ.get('DISCORD_GUILD_ID', '') # Optional: for checking guild membership + # Frontend URL for OAuth redirects FRONTEND_URL = os.environ.get('FRONTEND_URL', 'http://localhost:5173') diff --git a/backend/tally/urls.py b/backend/tally/urls.py index 88756811..33c8d6d7 100644 --- a/backend/tally/urls.py +++ b/backend/tally/urls.py @@ -27,8 +27,8 @@ from django.utils.decorators import method_decorator from django.views import View -# Import GitHub OAuth views -from users.github_oauth import github_oauth_initiate, github_oauth_callback, disconnect_github, check_repo_star +# Import social connections URLs +from social_connections.urls import auth_urlpatterns as social_auth_urls, api_urlpatterns as social_api_urls @csrf_exempt def health_check(request): @@ -61,11 +61,11 @@ def health_check(request): # Ethereum Authentication path('api/auth/', include('ethereum_auth.urls')), - # GitHub OAuth - path('api/auth/github/', github_oauth_initiate, name='github_oauth'), - path('api/auth/github/callback/', github_oauth_callback, name='github_callback'), - path('api/v1/users/github/disconnect/', disconnect_github, name='github_disconnect'), - path('api/v1/users/github/check-star/', check_repo_star, name='github_check_star'), + # Social OAuth (GitHub, Twitter, Discord) + path('api/auth/', include(social_auth_urls)), + + # Social API endpoints (disconnect, check-star, check-guild, etc.) + path('api/v1/social/', include(social_api_urls)), # Contributions app (includes both API and staff views) path('contributions/', include('contributions.urls')), diff --git a/backend/users/serializers.py b/backend/users/serializers.py index 7b96b30b..90426dd2 100644 --- a/backend/users/serializers.py +++ b/backend/users/serializers.py @@ -8,6 +8,11 @@ from contributions.node_upgrade.models import TargetNodeVersion from leaderboard.models import LeaderboardEntry from contributions.models import Category +from social_connections.serializers import ( + GitHubConnectionSerializer, + TwitterConnectionSerializer, + DiscordConnectionSerializer +) # ============================================================================ @@ -232,8 +237,7 @@ class UserProfileUpdateSerializer(serializers.ModelSerializer): class Meta: model = User fields = ['name', 'node_version', 'email', 'description', 'website', - 'twitter_handle', 'discord_handle', 'telegram_handle', 'linkedin_handle', - 'github_username'] + 'twitter_handle', 'discord_handle', 'telegram_handle', 'linkedin_handle'] def validate_email(self, value): """Validate email with DNS checks and block disposable providers""" @@ -514,6 +518,15 @@ class UserSerializer(serializers.ModelSerializer): has_builder_welcome = serializers.SerializerMethodField() email = serializers.SerializerMethodField() + # Social connections (new nested serializers) + github_connection = GitHubConnectionSerializer(read_only=True) + twitter_connection = TwitterConnectionSerializer(read_only=True) + discord_connection = DiscordConnectionSerializer(read_only=True) + + # Legacy GitHub fields (computed from github_connection for backward compatibility) + github_username = serializers.SerializerMethodField() + github_linked_at = serializers.SerializerMethodField() + # Referral system fields referred_by_info = serializers.SerializerMethodField() total_referrals = serializers.SerializerMethodField() @@ -528,8 +541,12 @@ class Meta: 'creator', 'has_validator_waitlist', 'has_builder_welcome', 'created_at', 'updated_at', # Profile fields 'description', 'banner_image_url', 'profile_image_url', 'website', - 'twitter_handle', 'discord_handle', 'telegram_handle', 'linkedin_handle', 'github_username', 'github_linked_at', + 'twitter_handle', 'discord_handle', 'telegram_handle', 'linkedin_handle', 'email', 'is_email_verified', + # Social connections (new) + 'github_connection', 'twitter_connection', 'discord_connection', + # Legacy GitHub fields (for backward compatibility) + 'github_username', 'github_linked_at', # Referral fields 'referral_code', 'referred_by_info', 'total_referrals', 'referral_details', # Working groups @@ -626,7 +643,27 @@ def get_email(self, obj): if obj.is_email_verified: return obj.email return '' - + + def get_github_username(self, obj): + """ + Legacy field: Get GitHub username from github_connection. + Returns empty string if not connected for backward compatibility. + """ + try: + return obj.github_connection.username or '' + except Exception: + return '' + + def get_github_linked_at(self, obj): + """ + Legacy field: Get GitHub linked_at from github_connection. + Returns None if not connected for backward compatibility. + """ + try: + return obj.github_connection.linked_at + except Exception: + return None + def get_referred_by_info(self, obj): """ Get information about who referred this user. diff --git a/frontend/src/App.svelte b/frontend/src/App.svelte index a30afc90..fc582739 100644 --- a/frontend/src/App.svelte +++ b/frontend/src/App.svelte @@ -42,6 +42,8 @@ import WaitlistParticipants from './routes/WaitlistParticipants.svelte'; import BuilderWelcome from './routes/BuilderWelcome.svelte'; import GitHubCallback from './routes/GitHubCallback.svelte'; + import TwitterCallback from './routes/TwitterCallback.svelte'; + import DiscordCallback from './routes/DiscordCallback.svelte'; import TermsOfUse from './routes/TermsOfUse.svelte'; import PrivacyPolicy from './routes/PrivacyPolicy.svelte'; import Referrals from './routes/Referrals.svelte'; @@ -54,6 +56,8 @@ // Auth callback routes '/auth/github/callback': GitHubCallback, + '/auth/twitter/callback': TwitterCallback, + '/auth/discord/callback': DiscordCallback, // Global/Testnet Asimov routes // Overview and Testnet Asimov routes diff --git a/frontend/src/components/DiscordLink.svelte b/frontend/src/components/DiscordLink.svelte new file mode 100644 index 00000000..382134ce --- /dev/null +++ b/frontend/src/components/DiscordLink.svelte @@ -0,0 +1,145 @@ + + + diff --git a/frontend/src/components/TwitterLink.svelte b/frontend/src/components/TwitterLink.svelte new file mode 100644 index 00000000..fe3eb392 --- /dev/null +++ b/frontend/src/components/TwitterLink.svelte @@ -0,0 +1,145 @@ + + + diff --git a/frontend/src/lib/api.js b/frontend/src/lib/api.js index f13eb9ba..c4c263a6 100644 --- a/frontend/src/lib/api.js +++ b/frontend/src/lib/api.js @@ -160,9 +160,24 @@ export const creatorAPI = { joinAsCreator: () => api.post('/creators/join/') }; -// GitHub OAuth API +// Social OAuth API +export const socialAPI = { + github: { + disconnect: () => api.post('/social/github/disconnect/'), + checkStar: () => api.get('/social/github/check-star/') + }, + twitter: { + disconnect: () => api.post('/social/twitter/disconnect/') + }, + discord: { + disconnect: () => api.post('/social/discord/disconnect/'), + checkGuild: () => api.get('/social/discord/check-guild/') + } +}; + +// Legacy GitHub OAuth API (for backward compatibility) export const githubAPI = { - checkStar: () => api.get('/users/github/check-star/') + checkStar: () => api.get('/social/github/check-star/') }; // Steward API diff --git a/frontend/src/routes/DiscordCallback.svelte b/frontend/src/routes/DiscordCallback.svelte new file mode 100644 index 00000000..2cbc3d07 --- /dev/null +++ b/frontend/src/routes/DiscordCallback.svelte @@ -0,0 +1,43 @@ + + +
+
+
+

Completing Discord connection...

+
+
diff --git a/frontend/src/routes/ProfileEdit.svelte b/frontend/src/routes/ProfileEdit.svelte index 84e5575c..5d341b17 100644 --- a/frontend/src/routes/ProfileEdit.svelte +++ b/frontend/src/routes/ProfileEdit.svelte @@ -9,6 +9,8 @@ import { userStore } from '../lib/userStore'; import { showSuccess, showError } from '../lib/toastStore'; import GitHubLink from '../components/GitHubLink.svelte'; + import TwitterLink from '../components/TwitterLink.svelte'; + import DiscordLink from '../components/DiscordLink.svelte'; // State management let user = $state(null); @@ -347,8 +349,8 @@ cropperImage = null; } - async function handleGitHubLinked(updatedUser) { - // Update local user state with the updated GitHub info + async function handleSocialLinked(updatedUser) { + // Update local user state with the updated social connection info user = updatedUser; } @@ -567,16 +569,16 @@
- {#if user.github_username} + {#if user.github_connection?.username}
- {user.github_username} + {user.github_connection.username}
- {#if user.github_linked_at} -

Linked on {new Date(user.github_linked_at).toLocaleDateString()}

+ {#if user.github_connection?.linked_at} +

Linked on {new Date(user.github_connection.linked_at).toLocaleDateString()}

{/if} {:else}

Link your GitHub to participate in builder programs

{/if}
+ + +
+ + {#if user.twitter_connection?.username} +
+
+ + + + @{user.twitter_connection.username} +
+ + + + + +
+ {#if user.twitter_connection?.linked_at} +

Linked on {new Date(user.twitter_connection.linked_at).toLocaleDateString()}

+ {/if} + {:else} + +

Link your Twitter for verified social presence

+ {/if} +
+ + +
+ + {#if user.discord_connection?.username} +
+
+ + + + {user.discord_connection.username} +
+
+ {#if user.discord_connection?.linked_at} +

Linked on {new Date(user.discord_connection.linked_at).toLocaleDateString()}

+ {/if} + {:else} + +

Link your Discord for verified community membership

+ {/if} +
diff --git a/frontend/src/routes/TwitterCallback.svelte b/frontend/src/routes/TwitterCallback.svelte new file mode 100644 index 00000000..923cd721 --- /dev/null +++ b/frontend/src/routes/TwitterCallback.svelte @@ -0,0 +1,43 @@ + + +
+
+
+

Completing Twitter connection...

+
+