diff --git a/alembic/versions/596bb368fc0d_add_orcid_id_to_users.py b/alembic/versions/596bb368fc0d_add_orcid_id_to_users.py new file mode 100644 index 0000000..6837c41 --- /dev/null +++ b/alembic/versions/596bb368fc0d_add_orcid_id_to_users.py @@ -0,0 +1,31 @@ +"""Add orcid_id column to users + +Revision ID: 596bb368fc0d +Revises: d3a7b8c1e2f4 +Create Date: 2026-03-28 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +revision: str = "596bb368fc0d" +down_revision: Union[str, Sequence[str], None] = "d3a7b8c1e2f4" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "users", + sa.Column("orcid_id", sa.String(20), nullable=True), + ) + op.create_index("ix_users_orcid_id", "users", ["orcid_id"], unique=True) + + +def downgrade() -> None: + op.drop_index("ix_users_orcid_id", table_name="users") + op.drop_column("users", "orcid_id") diff --git a/app/models/user.py b/app/models/user.py index 8dad6f0..5313f91 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -1,12 +1,27 @@ from datetime import datetime, timezone +import re import uuid from sqlalchemy import Boolean, Column, DateTime, String, TypeDecorator from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship, validates from app.database import Base +ORCID_PATTERN = re.compile(r"^\d{4}-\d{4}-\d{4}-\d{3}[\dX]$") + + +def validate_orcid_id(value: str | None) -> str | None: + """Validate ORCID iD format: XXXX-XXXX-XXXX-XXXX (last char may be X checksum).""" + if value is None: + return None + if not ORCID_PATTERN.match(value): + raise ValueError( + f"Invalid ORCID iD format: '{value}'. " + "Expected format: 0000-0002-1234-5678" + ) + return value + # Cross-platform UUID type that works with SQLite class GUID(TypeDecorator): @@ -58,11 +73,16 @@ class User(Base): onupdate=lambda: datetime.now(timezone.utc), nullable=False, ) + orcid_id = Column(String(20), nullable=True, unique=True, index=True) # Relationships scrolls = relationship("Scroll", back_populates="user") sessions = relationship("Session", back_populates="user") tokens = relationship("Token", back_populates="user") + @validates("orcid_id") + def _validate_orcid_id(self, _key: str, value: str | None) -> str | None: + return validate_orcid_id(value) + def __repr__(self): return f"" diff --git a/app/routes/auth.py b/app/routes/auth.py index baac466..d828812 100644 --- a/app/routes/auth.py +++ b/app/routes/auth.py @@ -89,7 +89,10 @@ async def login_page(request: Request, db: AsyncSession = Depends(get_db)): get_logger().info(f"Authenticated user {current_user.id} redirected from login page") return RedirectResponse(url="/", status_code=302) - return templates.TemplateResponse(request, "auth/login.html", {"current_user": current_user}) + error = request.query_params.get("error") + return templates.TemplateResponse( + request, "auth/login.html", {"current_user": current_user, "error": error} + ) @router.post("/logout") @@ -244,8 +247,9 @@ async def register_page(request: Request, db: AsyncSession = Depends(get_db)): get_logger().info(f"Authenticated user {current_user.id} redirected from register page") return RedirectResponse(url="/", status_code=302) + error = request.query_params.get("error") return templates.TemplateResponse( - request, "auth/register.html", {"current_user": current_user} + request, "auth/register.html", {"current_user": current_user, "error": error} ) diff --git a/app/routes/main.py b/app/routes/main.py index dced326..8829e58 100644 --- a/app/routes/main.py +++ b/app/routes/main.py @@ -597,6 +597,9 @@ async def dashboard(request: Request, db: AsyncSession = Depends(get_db)): is_boosted = request.headers.get("HX-Boosted") == "true" use_partial = is_htmx and not is_boosted + error = request.query_params.get("error") + orcid_status = request.query_params.get("orcid") + return templates.TemplateResponse( request, "dashboard_content.html" if use_partial else "dashboard.html", @@ -605,6 +608,8 @@ async def dashboard(request: Request, db: AsyncSession = Depends(get_db)): "papers": papers, "drafts": drafts, "csrf_token": csrf_token, + "error": error, + "orcid_status": orcid_status, }, ) diff --git a/app/routes/orcid.py b/app/routes/orcid.py new file mode 100644 index 0000000..428ef56 --- /dev/null +++ b/app/routes/orcid.py @@ -0,0 +1,209 @@ +"""ORCID OAuth2 authentication routes.""" + +import os +import secrets + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import RedirectResponse +import httpx +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.auth.session import create_session, get_current_user_from_session +from app.database import get_db +from app.logging_config import get_logger +from app.models.user import User + +router = APIRouter(prefix="/auth/orcid") + +IS_PRODUCTION = os.getenv("ENVIRONMENT") == "production" + +ORCID_CLIENT_ID = os.getenv("ORCID_CLIENT_ID", "") +ORCID_CLIENT_SECRET = os.getenv("ORCID_CLIENT_SECRET", "") +ORCID_BASE_URL = os.getenv("ORCID_BASE_URL", "https://sandbox.orcid.org") + +# Pending OAuth states: state_token -> True +# Short-lived, cleared on use. In production, use Redis/DB before horizontal scaling. +_pending_states: dict[str, bool] = {} + + +def _get_redirect_uri(request: Request) -> str: + """Build the ORCID callback URI from the current request.""" + return str(request.url_for("orcid_callback")) + + +@router.get("", name="orcid_redirect") +async def orcid_redirect(request: Request, db: AsyncSession = Depends(get_db)): + """Redirect to ORCID authorize URL with CSRF state.""" + if not ORCID_CLIENT_ID or not ORCID_CLIENT_SECRET: + current_user = await get_current_user_from_session(request, db) + error_url = "/dashboard?error=orcid_not_configured" if current_user else "/login?error=orcid_not_configured" + return RedirectResponse(url=error_url, status_code=302) + + state = secrets.token_urlsafe(32) + _pending_states[state] = True + + redirect_uri = _get_redirect_uri(request) + authorize_url = ( + f"{ORCID_BASE_URL}/oauth/authorize" + f"?client_id={ORCID_CLIENT_ID}" + f"&response_type=code" + f"&scope=/authenticate" + f"&redirect_uri={redirect_uri}" + f"&state={state}" + ) + + response = RedirectResponse(url=authorize_url, status_code=302) + response.set_cookie( + "orcid_state", state, httponly=True, secure=IS_PRODUCTION, + samesite="lax", max_age=600, + ) + return response + + +@router.get("/callback", name="orcid_callback") +async def orcid_callback( + request: Request, + code: str | None = None, + state: str | None = None, + db: AsyncSession = Depends(get_db), +): + """Handle ORCID OAuth2 callback.""" + logger = get_logger() + current_user = await get_current_user_from_session(request, db) + + def _error_redirect(error: str) -> RedirectResponse: + base = "/dashboard" if current_user else "/login" + return RedirectResponse(url=f"{base}?error={error}", status_code=302) + + # Validate state + cookie_state = request.cookies.get("orcid_state") + if not state or not cookie_state or state != cookie_state or state not in _pending_states: + logger.warning("ORCID callback: invalid or missing state") + return _error_redirect("orcid_state") + + _pending_states.pop(state, None) + + if not code: + logger.warning("ORCID callback: missing code") + return _error_redirect("orcid_missing_code") + + # Exchange code for token + redirect_uri = _get_redirect_uri(request) + try: + async with httpx.AsyncClient() as client: + token_resp = await client.post( + f"{ORCID_BASE_URL}/oauth/token", + data={ + "client_id": ORCID_CLIENT_ID, + "client_secret": ORCID_CLIENT_SECRET, + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + }, + headers={"Accept": "application/json"}, + ) + + if token_resp.status_code != 200: + logger.error(f"ORCID token exchange failed: {token_resp.status_code}") + return _error_redirect("orcid_token") + + token_data = token_resp.json() + except Exception as e: + logger.error(f"ORCID token exchange error: {e}") + return _error_redirect("orcid_token") + + orcid_id = token_data.get("orcid") + orcid_name = token_data.get("name", "") + + if not orcid_id: + logger.error("ORCID token response missing orcid field") + return _error_redirect("orcid_token") + + if current_user: + return await _link_orcid(db, current_user, orcid_id) + + return await _login_or_register(db, orcid_id, orcid_name) + + +async def _link_orcid(db: AsyncSession, user: User, orcid_id: str) -> RedirectResponse: + """Link ORCID to an existing logged-in user.""" + logger = get_logger() + + # Check if ORCID is already taken by another user + result = await db.execute(select(User).where(User.orcid_id == orcid_id)) + existing = result.scalar_one_or_none() + if existing and existing.id != user.id: + logger.warning(f"ORCID {orcid_id} already linked to user {existing.id}") + return RedirectResponse(url="/dashboard?error=orcid_taken", status_code=302) + + user.orcid_id = orcid_id + await db.commit() + logger.info(f"Linked ORCID {orcid_id} to user {user.id}") + return RedirectResponse(url="/dashboard?orcid=linked", status_code=302) + + +async def _login_or_register( + db: AsyncSession, orcid_id: str, orcid_name: str, +) -> RedirectResponse: + """Log in existing ORCID user or create a new account.""" + logger = get_logger() + + result = await db.execute(select(User).where(User.orcid_id == orcid_id)) + user = result.scalar_one_or_none() + + if user: + session_id = await create_session(db, user.id) + logger.info(f"ORCID login for user {user.id}") + response = RedirectResponse(url="/dashboard", status_code=302) + response.set_cookie( + "session_id", session_id, httponly=True, + secure=IS_PRODUCTION, samesite="lax", + ) + return response + + # Create new user + # Generate a placeholder email using ORCID (users can update it later) + placeholder_email = f"{orcid_id}@orcid.placeholder" + display_name = orcid_name.strip() if orcid_name else f"ORCID User {orcid_id[-4:]}" + + new_user = User( + email=placeholder_email, + password_hash="!orcid-only", + display_name=display_name, + email_verified=True, + orcid_id=orcid_id, + ) + db.add(new_user) + await db.commit() + await db.refresh(new_user) + + session_id = await create_session(db, new_user.id) + logger.info(f"Created new user {new_user.id} via ORCID {orcid_id}") + + response = RedirectResponse(url="/dashboard", status_code=302) + response.set_cookie( + "session_id", session_id, httponly=True, + secure=IS_PRODUCTION, samesite="lax", + ) + return response + + +@router.get("/unlink", name="orcid_unlink") +async def orcid_unlink(request: Request, db: AsyncSession = Depends(get_db)): + """Remove ORCID from the current user's account.""" + current_user = await get_current_user_from_session(request, db) + if not current_user: + return RedirectResponse(url="/login", status_code=302) + + logger = get_logger() + + # Block unlink if user has no password (would lock them out) + if not current_user.password_hash or current_user.password_hash == "!orcid-only": + logger.warning(f"User {current_user.id} tried to unlink ORCID without password") + return RedirectResponse(url="/dashboard?error=orcid_no_password", status_code=302) + + current_user.orcid_id = None + await db.commit() + logger.info(f"Unlinked ORCID from user {current_user.id}") + return RedirectResponse(url="/dashboard?orcid=unlinked", status_code=302) diff --git a/app/templates/auth/login.html b/app/templates/auth/login.html index 5ec2cb8..3149c53 100644 --- a/app/templates/auth/login.html +++ b/app/templates/auth/login.html @@ -9,8 +9,33 @@

Sign In

Sign in to your Press account

+ {% if error %} + + {% endif %} + {% include "auth/partials/login_form.html" %} +
or
+ + + Sign in with ORCID + + + {% if error %} + + {% endif %} + {% include "auth/partials/register_form.html" %} +
or
+ + + Sign up with ORCID + + diff --git a/app/templates/dashboard.html b/app/templates/dashboard.html index 0e3f8d5..e6d6212 100644 --- a/app/templates/dashboard.html +++ b/app/templates/dashboard.html @@ -98,6 +98,45 @@

No published papers yet

{% endif %} +
+

ORCID

+ {% if error == "orcid_taken" %} + + {% elif error == "orcid_no_password" %} + + {% elif error == "orcid_not_configured" %} + + {% elif error == "orcid_state" %} + + {% elif error == "orcid_token" %} + + {% elif error == "orcid_missing_code" %} + + {% endif %} + {% if orcid_status == "linked" %} + + {% elif orcid_status == "unlinked" %} + + {% endif %} + {% if current_user.orcid_id %} +

+ Linked: {{ current_user.orcid_id }} +

+ + {% else %} +

No ORCID linked to your account.

+ + {% endif %} +
+

Account Management

diff --git a/main.py b/main.py index 8092f81..1b6b306 100644 --- a/main.py +++ b/main.py @@ -14,12 +14,12 @@ from sentry_sdk.integrations.asyncio import AsyncioIntegration from sentry_sdk.integrations.fastapi import FastApiIntegration from sentry_sdk.integrations.sqlalchemy import SqlalchemyIntegration +from starlette.exceptions import HTTPException as StarletteHTTPException from app.exception_handlers import ( http_exception_handler, internal_server_error_handler, ) -from starlette.exceptions import HTTPException as StarletteHTTPException from app.logging_config import get_logger from app.memory_profiling_middleware import MemoryProfilingMiddleware from app.middleware import ( @@ -31,7 +31,7 @@ SecurityHeadersMiddleware, StaticFilesCacheMiddleware, ) -from app.routes import api, auth, main, scrolls +from app.routes import api, auth, main, orcid, scrolls from app.security.nonce_middleware import NonceMiddleware from app.sentry_config import before_send @@ -226,3 +226,4 @@ async def health_check(): app.include_router(auth.router) app.include_router(scrolls.router) app.include_router(api.router) +app.include_router(orcid.router) diff --git a/static/css/main.css b/static/css/main.css index b016eab..7721364 100644 --- a/static/css/main.css +++ b/static/css/main.css @@ -582,6 +582,26 @@ button.btn, a.btn { .form-errors li:last-child { margin-bottom: 0; } +.alert { + border-radius: var(--border-radius); + padding: var(--space-md); + margin-bottom: var(--space-md); + font-size: var(--text-sm); +} +.alert-error { + background: var(--error-bg); + border: 1px solid var(--error-border); + color: #991b1b; +} +[data-theme="dark"] .alert-error, +.dark .alert-error { + color: #fca5a5; +} +.alert-success { + background: var(--success-bg); + border: 1px solid var(--success-border); + color: var(--success-text); +} .success-message { background: var(--success-bg); border: 1px solid var(--success-border); @@ -614,6 +634,50 @@ button.btn, a.btn { .auth-links a:hover { text-decoration: underline; } +.auth-divider { + display: flex; + align-items: center; + margin: var(--space-lg) 0; + color: var(--gray); + font-size: var(--text-sm); +} +.auth-divider::before, +.auth-divider::after { + content: ""; + flex: 1; + border-bottom: 1px solid var(--gray-lightest); +} +.auth-divider span { + padding: 0 var(--space-md); +} +.btn-orcid { + display: flex; + align-items: center; + justify-content: center; + gap: var(--space-sm); + width: 100%; + padding: var(--space-md) var(--space-lg); + background: var(--white); + color: var(--text-color); + border: 1px solid var(--gray-lighter); + border-radius: var(--border-radius); + font-size: var(--text-base); + font-weight: 500; + text-decoration: none; + cursor: pointer; + transition: background-color 0.2s, border-color 0.2s; +} +.btn-orcid:hover { + background: var(--gray-bg); + border-color: #a6ce39; +} +.orcid-icon { + flex-shrink: 0; +} +.orcid-status { + color: var(--gray-dark); + margin-bottom: var(--space-md); +} .back-link { text-align: center; margin-top: var(--space-xl); @@ -711,15 +775,25 @@ button.btn, a.btn { margin-bottom: var(--mobile-space-lg); } + .auth-divider { + margin: var(--mobile-space-lg) 0; + } + + .btn-orcid { + padding: var(--mobile-space-lg); + font-size: var(--mobile-text-base); + min-height: 44px; + } + .auth-links { margin-top: var(--mobile-space-xl); padding-top: var(--mobile-space-xl); } - + .auth-links a { font-size: var(--mobile-text-base); } - + .back-link { margin-top: var(--mobile-space-xl); } diff --git a/tests/test_orcid.py b/tests/test_orcid.py new file mode 100644 index 0000000..f40283d --- /dev/null +++ b/tests/test_orcid.py @@ -0,0 +1,154 @@ +"""Tests for ORCID iD storage on the User model.""" + +import pytest +import pytest_asyncio +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError + +from app.models.user import User, validate_orcid_id + + +class TestValidateOrcidId: + """Tests for the orcid_id format validation function.""" + + def test_valid_orcid(self): + assert validate_orcid_id("0000-0002-1234-5678") == "0000-0002-1234-5678" + + def test_valid_orcid_with_checksum_x(self): + assert validate_orcid_id("0000-0002-1234-567X") == "0000-0002-1234-567X" + + def test_none_passes_through(self): + assert validate_orcid_id(None) is None + + def test_rejects_missing_hyphens(self): + with pytest.raises(ValueError, match="ORCID"): + validate_orcid_id("0000000212345678") + + def test_rejects_too_short(self): + with pytest.raises(ValueError, match="ORCID"): + validate_orcid_id("0000-0002-1234") + + def test_rejects_letters_in_body(self): + with pytest.raises(ValueError, match="ORCID"): + validate_orcid_id("000A-0002-1234-5678") + + def test_rejects_empty_string(self): + with pytest.raises(ValueError, match="ORCID"): + validate_orcid_id("") + + def test_rejects_lowercase_x(self): + with pytest.raises(ValueError, match="ORCID"): + validate_orcid_id("0000-0002-1234-567x") + + +@pytest.mark.asyncio +class TestUserOrcidColumn: + """Tests for the orcid_id column on the User model.""" + + async def test_user_created_without_orcid(self, test_db): + """Backward compat: users can be created with no orcid_id.""" + from app.auth.utils import get_password_hash + + user = User( + email="noorcid@example.com", + password_hash=get_password_hash("password123"), + display_name="No ORCID User", + email_verified=True, + ) + test_db.add(user) + await test_db.commit() + await test_db.refresh(user) + + assert user.orcid_id is None + + async def test_user_created_with_valid_orcid(self, test_db): + from app.auth.utils import get_password_hash + + user = User( + email="orcid@example.com", + password_hash=get_password_hash("password123"), + display_name="ORCID User", + email_verified=True, + orcid_id="0000-0002-1234-5678", + ) + test_db.add(user) + await test_db.commit() + await test_db.refresh(user) + + assert user.orcid_id == "0000-0002-1234-5678" + + async def test_orcid_with_checksum_x(self, test_db): + from app.auth.utils import get_password_hash + + user = User( + email="orcidx@example.com", + password_hash=get_password_hash("password123"), + display_name="ORCID X User", + email_verified=True, + orcid_id="0000-0002-1234-567X", + ) + test_db.add(user) + await test_db.commit() + await test_db.refresh(user) + + assert user.orcid_id == "0000-0002-1234-567X" + + async def test_orcid_uniqueness_constraint(self, test_db): + """Two users cannot share the same ORCID iD.""" + from app.auth.utils import get_password_hash + + user1 = User( + email="user1@example.com", + password_hash=get_password_hash("password123"), + display_name="User 1", + email_verified=True, + orcid_id="0000-0002-1234-5678", + ) + test_db.add(user1) + await test_db.commit() + + user2 = User( + email="user2@example.com", + password_hash=get_password_hash("password123"), + display_name="User 2", + email_verified=True, + orcid_id="0000-0002-1234-5678", + ) + test_db.add(user2) + + with pytest.raises(IntegrityError): + await test_db.commit() + + async def test_multiple_users_with_null_orcid(self, test_db): + """Multiple users can have NULL orcid_id (no false unique violation).""" + from app.auth.utils import get_password_hash + + for i in range(3): + user = User( + email=f"null_orcid_{i}@example.com", + password_hash=get_password_hash("password123"), + display_name=f"Null ORCID {i}", + email_verified=True, + ) + test_db.add(user) + + await test_db.commit() + + result = await test_db.execute( + select(User).where(User.orcid_id.is_(None)) + ) + users = result.scalars().all() + assert len(users) == 3 + + async def test_invalid_orcid_rejected_by_validator(self, test_db): + """The @validates decorator rejects bad formats before hitting the DB.""" + from app.auth.utils import get_password_hash + + with pytest.raises(ValueError, match="ORCID"): + User( + email="bad@example.com", + password_hash=get_password_hash("password123"), + display_name="Bad ORCID", + email_verified=True, + orcid_id="not-an-orcid", + ) diff --git a/tests/test_orcid_oauth.py b/tests/test_orcid_oauth.py new file mode 100644 index 0000000..43cba84 --- /dev/null +++ b/tests/test_orcid_oauth.py @@ -0,0 +1,435 @@ +"""Tests for ORCID OAuth2 authentication flow.""" + +from unittest.mock import AsyncMock, MagicMock, patch +from urllib.parse import parse_qs, urlparse + +import pytest +import pytest_asyncio +from sqlalchemy import select + +from app.auth.utils import get_password_hash +from app.models.user import User + +FAKE_ORCID = "0000-0002-1234-5678" +FAKE_ORCID_2 = "0000-0002-9999-0001" + + +@pytest.fixture(autouse=True) +def orcid_env(monkeypatch): + """Set ORCID env vars and reset module-level state for each test.""" + monkeypatch.setenv("ORCID_CLIENT_ID", "APP-TESTCLIENTID") + monkeypatch.setenv("ORCID_CLIENT_SECRET", "test-secret") + monkeypatch.setenv("ORCID_BASE_URL", "https://sandbox.orcid.org") + + import app.routes.orcid as orcid_mod + + monkeypatch.setattr(orcid_mod, "ORCID_CLIENT_ID", "APP-TESTCLIENTID") + monkeypatch.setattr(orcid_mod, "ORCID_CLIENT_SECRET", "test-secret") + monkeypatch.setattr(orcid_mod, "ORCID_BASE_URL", "https://sandbox.orcid.org") + orcid_mod._pending_states.clear() + yield + orcid_mod._pending_states.clear() + + +@pytest_asyncio.fixture +async def orcid_user(test_db): + """A user who already has an ORCID linked.""" + user = User( + email="orcid-linked@example.com", + password_hash=get_password_hash("password123"), + display_name="ORCID User", + email_verified=True, + orcid_id=FAKE_ORCID, + ) + test_db.add(user) + await test_db.commit() + await test_db.refresh(user) + return user + + +@pytest_asyncio.fixture +async def passwordless_orcid_user(test_db): + """A user created via ORCID login (no real password).""" + user = User( + email="orcid-only@example.com", + password_hash="!orcid-only", + display_name="ORCID Only User", + email_verified=True, + orcid_id=FAKE_ORCID_2, + ) + test_db.add(user) + await test_db.commit() + await test_db.refresh(user) + return user + + +def _mock_orcid_token_response(orcid_id=FAKE_ORCID, name="Jane Doe"): + """Build a mock httpx response for ORCID token exchange.""" + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = { + "access_token": "fake-access-token", + "token_type": "bearer", + "orcid": orcid_id, + "name": name, + } + return mock_resp + + +def _mock_orcid_token_error(): + mock_resp = MagicMock() + mock_resp.status_code = 401 + mock_resp.json.return_value = {"error": "invalid_grant"} + return mock_resp + + +@pytest.mark.asyncio +class TestOrcidRedirect: + """GET /auth/orcid should redirect to ORCID authorize URL.""" + + async def test_redirects_to_orcid(self, client): + resp = await client.get("/auth/orcid", follow_redirects=False) + assert resp.status_code == 302 + + location = resp.headers["location"] + parsed = urlparse(location) + assert "orcid.org" in parsed.netloc + assert parsed.path == "/oauth/authorize" + + params = parse_qs(parsed.query) + assert params["response_type"] == ["code"] + assert params["scope"] == ["/authenticate"] + assert "redirect_uri" in params + assert "state" in params + assert "client_id" in params + + async def test_redirect_blocked_when_not_configured(self, client, monkeypatch): + """Unauthenticated user gets redirected to /login when ORCID not configured.""" + import app.routes.orcid as orcid_mod + + monkeypatch.setattr(orcid_mod, "ORCID_CLIENT_ID", "") + resp = await client.get("/auth/orcid", follow_redirects=False) + assert resp.status_code == 302 + assert "/login" in resp.headers["location"] + assert "orcid_not_configured" in resp.headers["location"] + + async def test_redirect_blocked_when_not_configured_authenticated( + self, authenticated_client, monkeypatch + ): + """Authenticated user gets redirected to /dashboard when ORCID not configured.""" + import app.routes.orcid as orcid_mod + + monkeypatch.setattr(orcid_mod, "ORCID_CLIENT_ID", "") + resp = await authenticated_client.get("/auth/orcid", follow_redirects=False) + assert resp.status_code == 302 + assert "/dashboard" in resp.headers["location"] + assert "orcid_not_configured" in resp.headers["location"] + + async def test_state_param_is_random(self, client): + """Each redirect should generate a unique state.""" + resp1 = await client.get("/auth/orcid", follow_redirects=False) + resp2 = await client.get("/auth/orcid", follow_redirects=False) + + state1 = parse_qs(urlparse(resp1.headers["location"]).query)["state"][0] + state2 = parse_qs(urlparse(resp2.headers["location"]).query)["state"][0] + assert state1 != state2 + + +@pytest.mark.asyncio +class TestOrcidCallback: + """GET /auth/orcid/callback tests.""" + + async def test_rejects_missing_state(self, client): + resp = await client.get("/auth/orcid/callback?code=abc", follow_redirects=False) + assert resp.status_code == 302 + assert "/login" in resp.headers["location"] + + async def test_rejects_invalid_state(self, client): + # First, hit /auth/orcid to set a state in session + await client.get("/auth/orcid", follow_redirects=False) + + resp = await client.get( + "/auth/orcid/callback?code=abc&state=wrong-state", follow_redirects=False + ) + assert resp.status_code == 302 + assert "/login" in resp.headers["location"] + + async def test_rejects_missing_code(self, client): + # Get valid state + redir = await client.get("/auth/orcid", follow_redirects=False) + state = parse_qs(urlparse(redir.headers["location"]).query)["state"][0] + + resp = await client.get( + f"/auth/orcid/callback?state={state}", follow_redirects=False + ) + assert resp.status_code == 302 + assert "/login" in resp.headers["location"] + + async def test_login_existing_orcid_user(self, client, orcid_user, test_db): + """Callback with known ORCID logs in existing user.""" + redir = await client.get("/auth/orcid", follow_redirects=False) + state = parse_qs(urlparse(redir.headers["location"]).query)["state"][0] + + mock_resp = _mock_orcid_token_response(orcid_id=FAKE_ORCID) + with patch("app.routes.orcid.httpx.AsyncClient") as MockClient: + MockClient.return_value.__aenter__ = AsyncMock(return_value=MockClient.return_value) + MockClient.return_value.__aexit__ = AsyncMock(return_value=False) + MockClient.return_value.post = AsyncMock(return_value=mock_resp) + + resp = await client.get( + f"/auth/orcid/callback?code=valid-code&state={state}", + follow_redirects=False, + ) + + assert resp.status_code == 302 + assert "/dashboard" in resp.headers["location"] + assert "session_id" in resp.cookies + + async def test_creates_new_user_for_unknown_orcid(self, client, test_db): + """Callback with unknown ORCID creates a new account.""" + redir = await client.get("/auth/orcid", follow_redirects=False) + state = parse_qs(urlparse(redir.headers["location"]).query)["state"][0] + + mock_resp = _mock_orcid_token_response(orcid_id="0000-0003-0000-0001", name="New Researcher") + with patch("app.routes.orcid.httpx.AsyncClient") as MockClient: + MockClient.return_value.__aenter__ = AsyncMock(return_value=MockClient.return_value) + MockClient.return_value.__aexit__ = AsyncMock(return_value=False) + MockClient.return_value.post = AsyncMock(return_value=mock_resp) + + resp = await client.get( + f"/auth/orcid/callback?code=valid-code&state={state}", + follow_redirects=False, + ) + + assert resp.status_code == 302 + assert "/dashboard" in resp.headers["location"] + + # Verify user was created + result = await test_db.execute( + select(User).where(User.orcid_id == "0000-0003-0000-0001") + ) + user = result.scalar_one_or_none() + assert user is not None + assert user.display_name == "New Researcher" + assert user.email_verified is True + + async def test_links_orcid_when_logged_in(self, authenticated_client, test_user, test_db): + """Callback when logged in links ORCID to current account.""" + redir = await authenticated_client.get("/auth/orcid", follow_redirects=False) + state = parse_qs(urlparse(redir.headers["location"]).query)["state"][0] + + mock_resp = _mock_orcid_token_response(orcid_id="0000-0003-5555-6666", name="Ignored Name") + with patch("app.routes.orcid.httpx.AsyncClient") as MockClient: + MockClient.return_value.__aenter__ = AsyncMock(return_value=MockClient.return_value) + MockClient.return_value.__aexit__ = AsyncMock(return_value=False) + MockClient.return_value.post = AsyncMock(return_value=mock_resp) + + resp = await authenticated_client.get( + f"/auth/orcid/callback?code=valid-code&state={state}", + follow_redirects=False, + ) + + assert resp.status_code == 302 + + # Verify ORCID was linked + await test_db.refresh(test_user) + assert test_user.orcid_id == "0000-0003-5555-6666" + + async def test_token_exchange_failure(self, client): + """Callback handles ORCID token exchange failure.""" + redir = await client.get("/auth/orcid", follow_redirects=False) + state = parse_qs(urlparse(redir.headers["location"]).query)["state"][0] + + mock_resp = _mock_orcid_token_error() + with patch("app.routes.orcid.httpx.AsyncClient") as MockClient: + MockClient.return_value.__aenter__ = AsyncMock(return_value=MockClient.return_value) + MockClient.return_value.__aexit__ = AsyncMock(return_value=False) + MockClient.return_value.post = AsyncMock(return_value=mock_resp) + + resp = await client.get( + f"/auth/orcid/callback?code=bad-code&state={state}", + follow_redirects=False, + ) + + assert resp.status_code == 302 + assert "/login" in resp.headers["location"] + + async def test_token_exchange_failure_authenticated(self, authenticated_client): + """Authenticated user gets error redirect to /dashboard, not /login.""" + redir = await authenticated_client.get("/auth/orcid", follow_redirects=False) + state = parse_qs(urlparse(redir.headers["location"]).query)["state"][0] + + mock_resp = _mock_orcid_token_error() + with patch("app.routes.orcid.httpx.AsyncClient") as MockClient: + MockClient.return_value.__aenter__ = AsyncMock(return_value=MockClient.return_value) + MockClient.return_value.__aexit__ = AsyncMock(return_value=False) + MockClient.return_value.post = AsyncMock(return_value=mock_resp) + + resp = await authenticated_client.get( + f"/auth/orcid/callback?code=bad-code&state={state}", + follow_redirects=False, + ) + + assert resp.status_code == 302 + assert "/dashboard" in resp.headers["location"] + assert "orcid_token" in resp.headers["location"] + + async def test_duplicate_orcid_link_rejected(self, client, orcid_user, test_db): + """Cannot link an ORCID that already belongs to another user.""" + # Create and authenticate a second user + from app.auth.session import create_session + + user2 = User( + email="user2@example.com", + password_hash=get_password_hash("password123"), + display_name="User Two", + email_verified=True, + ) + test_db.add(user2) + await test_db.commit() + await test_db.refresh(user2) + + session_id = await create_session(test_db, user2.id) + client.cookies.set("session_id", session_id) + + redir = await client.get("/auth/orcid", follow_redirects=False) + state = parse_qs(urlparse(redir.headers["location"]).query)["state"][0] + + # Try to link orcid_user's ORCID + mock_resp = _mock_orcid_token_response(orcid_id=FAKE_ORCID) + with patch("app.routes.orcid.httpx.AsyncClient") as MockClient: + MockClient.return_value.__aenter__ = AsyncMock(return_value=MockClient.return_value) + MockClient.return_value.__aexit__ = AsyncMock(return_value=False) + MockClient.return_value.post = AsyncMock(return_value=mock_resp) + + resp = await client.get( + f"/auth/orcid/callback?code=valid-code&state={state}", + follow_redirects=False, + ) + + assert resp.status_code == 302 + # Should redirect with an error, not link + assert "/dashboard" in resp.headers["location"] or "/login" in resp.headers["location"] + + # ORCID should NOT be linked to user2 + await test_db.refresh(user2) + assert user2.orcid_id is None + + +@pytest.mark.asyncio +class TestOrcidUnlink: + """GET /auth/orcid/unlink tests.""" + + async def test_unlink_requires_auth(self, client): + resp = await client.get("/auth/orcid/unlink", follow_redirects=False) + assert resp.status_code == 302 + assert "/login" in resp.headers["location"] + + async def test_unlink_removes_orcid(self, authenticated_client, test_user, test_db): + """Unlink removes orcid_id when user has a password.""" + test_user.orcid_id = FAKE_ORCID + await test_db.commit() + + resp = await authenticated_client.get("/auth/orcid/unlink", follow_redirects=False) + assert resp.status_code == 302 + + await test_db.refresh(test_user) + assert test_user.orcid_id is None + + async def test_unlink_blocked_without_password(self, client, passwordless_orcid_user, test_db): + """Cannot unlink ORCID if user has no password (would be locked out).""" + from app.auth.session import create_session + + session_id = await create_session(test_db, passwordless_orcid_user.id) + client.cookies.set("session_id", session_id) + + resp = await client.get("/auth/orcid/unlink", follow_redirects=False) + assert resp.status_code == 302 + # ORCID should still be linked + await test_db.refresh(passwordless_orcid_user) + assert passwordless_orcid_user.orcid_id == FAKE_ORCID_2 + + +@pytest.mark.asyncio +class TestOrcidUI: + """ORCID buttons appear on login/register pages and dashboard.""" + + async def test_login_page_shows_orcid_button(self, client): + resp = await client.get("/login") + assert resp.status_code == 200 + body = resp.text + assert "/auth/orcid" in body + assert "Sign in with ORCID" in body + + async def test_register_page_shows_orcid_button(self, client): + resp = await client.get("/register") + assert resp.status_code == 200 + body = resp.text + assert "/auth/orcid" in body + assert "Sign up with ORCID" in body + + async def test_dashboard_shows_link_orcid(self, authenticated_client, test_user): + """Dashboard shows 'Link ORCID' when user has no ORCID linked.""" + resp = await authenticated_client.get("/dashboard") + assert resp.status_code == 200 + assert "Link ORCID" in resp.text + + async def test_dashboard_shows_linked_orcid(self, authenticated_client, test_user, test_db): + """Dashboard shows linked ORCID iD and unlink button.""" + test_user.orcid_id = FAKE_ORCID + await test_db.commit() + + resp = await authenticated_client.get("/dashboard") + assert resp.status_code == 200 + assert FAKE_ORCID in resp.text + assert "Unlink ORCID" in resp.text + + async def test_login_shows_orcid_not_configured_error(self, client): + resp = await client.get("/login?error=orcid_not_configured") + assert resp.status_code == 200 + assert "not available" in resp.text + + async def test_login_shows_orcid_state_error(self, client): + resp = await client.get("/login?error=orcid_state") + assert resp.status_code == 200 + assert "try again" in resp.text.lower() + + async def test_register_shows_orcid_error(self, client): + resp = await client.get("/register?error=orcid_not_configured") + assert resp.status_code == 200 + assert "not available" in resp.text + + async def test_dashboard_shows_orcid_taken_error(self, authenticated_client, test_user): + resp = await authenticated_client.get("/dashboard?error=orcid_taken") + assert resp.status_code == 200 + assert "already linked" in resp.text.lower() + + async def test_dashboard_shows_orcid_linked_success(self, authenticated_client, test_user): + resp = await authenticated_client.get("/dashboard?orcid=linked") + assert resp.status_code == 200 + assert "linked successfully" in resp.text.lower() + + async def test_dashboard_shows_orcid_unlinked_success(self, authenticated_client, test_user): + resp = await authenticated_client.get("/dashboard?orcid=unlinked") + assert resp.status_code == 200 + assert "unlinked" in resp.text.lower() + + async def test_dashboard_shows_no_password_error(self, authenticated_client, test_user): + resp = await authenticated_client.get("/dashboard?error=orcid_no_password") + assert resp.status_code == 200 + assert "password" in resp.text.lower() + + async def test_dashboard_shows_orcid_not_configured_error(self, authenticated_client, test_user): + resp = await authenticated_client.get("/dashboard?error=orcid_not_configured") + assert resp.status_code == 200 + assert "not available" in resp.text.lower() + + async def test_dashboard_shows_orcid_state_error(self, authenticated_client, test_user): + resp = await authenticated_client.get("/dashboard?error=orcid_state") + assert resp.status_code == 200 + assert "try again" in resp.text.lower() + + async def test_dashboard_shows_orcid_token_error(self, authenticated_client, test_user): + resp = await authenticated_client.get("/dashboard?error=orcid_token") + assert resp.status_code == 200 + assert "try again" in resp.text.lower()