Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions alembic/versions/596bb368fc0d_add_orcid_id_to_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Add orcid_id column to users

Revision ID: 596bb368fc0d
Revises: d3a7b8c1e2f4
Create Date: 2026-03-28

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

revision: str = "596bb368fc0d"
down_revision: Union[str, Sequence[str], None] = "d3a7b8c1e2f4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.add_column(
"users",
sa.Column("orcid_id", sa.String(20), nullable=True),
)
op.create_index("ix_users_orcid_id", "users", ["orcid_id"], unique=True)


def downgrade() -> None:
op.drop_index("ix_users_orcid_id", table_name="users")
op.drop_column("users", "orcid_id")
22 changes: 21 additions & 1 deletion app/models/user.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
from datetime import datetime, timezone
import re
import uuid

from sqlalchemy import Boolean, Column, DateTime, String, TypeDecorator
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from sqlalchemy.orm import relationship, validates

from app.database import Base

ORCID_PATTERN = re.compile(r"^\d{4}-\d{4}-\d{4}-\d{3}[\dX]$")


def validate_orcid_id(value: str | None) -> str | None:
"""Validate ORCID iD format: XXXX-XXXX-XXXX-XXXX (last char may be X checksum)."""
if value is None:
return None
if not ORCID_PATTERN.match(value):
raise ValueError(
f"Invalid ORCID iD format: '{value}'. "
"Expected format: 0000-0002-1234-5678"
)
return value


# Cross-platform UUID type that works with SQLite
class GUID(TypeDecorator):
Expand Down Expand Up @@ -58,11 +73,16 @@ class User(Base):
onupdate=lambda: datetime.now(timezone.utc),
nullable=False,
)
orcid_id = Column(String(20), nullable=True, unique=True, index=True)

# Relationships
scrolls = relationship("Scroll", back_populates="user")
sessions = relationship("Session", back_populates="user")
tokens = relationship("Token", back_populates="user")

@validates("orcid_id")
def _validate_orcid_id(self, _key: str, value: str | None) -> str | None:
return validate_orcid_id(value)

def __repr__(self):
return f"<User(email='{self.email}', display_name='{self.display_name}')>"
154 changes: 154 additions & 0 deletions tests/test_orcid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""Tests for ORCID iD storage on the User model."""

import pytest
import pytest_asyncio
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError

from app.models.user import User, validate_orcid_id


class TestValidateOrcidId:
"""Tests for the orcid_id format validation function."""

def test_valid_orcid(self):
assert validate_orcid_id("0000-0002-1234-5678") == "0000-0002-1234-5678"

def test_valid_orcid_with_checksum_x(self):
assert validate_orcid_id("0000-0002-1234-567X") == "0000-0002-1234-567X"

def test_none_passes_through(self):
assert validate_orcid_id(None) is None

def test_rejects_missing_hyphens(self):
with pytest.raises(ValueError, match="ORCID"):
validate_orcid_id("0000000212345678")

def test_rejects_too_short(self):
with pytest.raises(ValueError, match="ORCID"):
validate_orcid_id("0000-0002-1234")

def test_rejects_letters_in_body(self):
with pytest.raises(ValueError, match="ORCID"):
validate_orcid_id("000A-0002-1234-5678")

def test_rejects_empty_string(self):
with pytest.raises(ValueError, match="ORCID"):
validate_orcid_id("")

def test_rejects_lowercase_x(self):
with pytest.raises(ValueError, match="ORCID"):
validate_orcid_id("0000-0002-1234-567x")


@pytest.mark.asyncio
class TestUserOrcidColumn:
"""Tests for the orcid_id column on the User model."""

async def test_user_created_without_orcid(self, test_db):
"""Backward compat: users can be created with no orcid_id."""
from app.auth.utils import get_password_hash

user = User(
email="noorcid@example.com",
password_hash=get_password_hash("password123"),
display_name="No ORCID User",
email_verified=True,
)
test_db.add(user)
await test_db.commit()
await test_db.refresh(user)

assert user.orcid_id is None

async def test_user_created_with_valid_orcid(self, test_db):
from app.auth.utils import get_password_hash

user = User(
email="orcid@example.com",
password_hash=get_password_hash("password123"),
display_name="ORCID User",
email_verified=True,
orcid_id="0000-0002-1234-5678",
)
test_db.add(user)
await test_db.commit()
await test_db.refresh(user)

assert user.orcid_id == "0000-0002-1234-5678"

async def test_orcid_with_checksum_x(self, test_db):
from app.auth.utils import get_password_hash

user = User(
email="orcidx@example.com",
password_hash=get_password_hash("password123"),
display_name="ORCID X User",
email_verified=True,
orcid_id="0000-0002-1234-567X",
)
test_db.add(user)
await test_db.commit()
await test_db.refresh(user)

assert user.orcid_id == "0000-0002-1234-567X"

async def test_orcid_uniqueness_constraint(self, test_db):
"""Two users cannot share the same ORCID iD."""
from app.auth.utils import get_password_hash

user1 = User(
email="user1@example.com",
password_hash=get_password_hash("password123"),
display_name="User 1",
email_verified=True,
orcid_id="0000-0002-1234-5678",
)
test_db.add(user1)
await test_db.commit()

user2 = User(
email="user2@example.com",
password_hash=get_password_hash("password123"),
display_name="User 2",
email_verified=True,
orcid_id="0000-0002-1234-5678",
)
test_db.add(user2)

with pytest.raises(IntegrityError):
await test_db.commit()

async def test_multiple_users_with_null_orcid(self, test_db):
"""Multiple users can have NULL orcid_id (no false unique violation)."""
from app.auth.utils import get_password_hash

for i in range(3):
user = User(
email=f"null_orcid_{i}@example.com",
password_hash=get_password_hash("password123"),
display_name=f"Null ORCID {i}",
email_verified=True,
)
test_db.add(user)

await test_db.commit()

result = await test_db.execute(
select(User).where(User.orcid_id.is_(None))
)
users = result.scalars().all()
assert len(users) == 3

async def test_invalid_orcid_rejected_by_validator(self, test_db):
"""The @validates decorator rejects bad formats before hitting the DB."""
from app.auth.utils import get_password_hash

with pytest.raises(ValueError, match="ORCID"):
User(
email="bad@example.com",
password_hash=get_password_hash("password123"),
display_name="Bad ORCID",
email_verified=True,
orcid_id="not-an-orcid",
)
Loading