Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 14 additions & 21 deletions src/auth/src/supabase_auth/_async/gotrue_admin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from pydantic import TypeAdapter

from ..helpers import (
validate_uuid,
model_validate,
parse_link_response,
parse_user_response,
validate_uuid,
)
from ..http_clients import AsyncClient
from ..types import (
Expand Down Expand Up @@ -57,15 +57,15 @@ def __init__(
)
# TODO(@o-santi): why is is this done this way?
self.mfa = AsyncGoTrueAdminMFAAPI()
self.mfa.list_factors = self._list_factors # type: ignore
self.mfa.delete_factor = self._delete_factor # type: ignore
self.mfa.list_factors = self._list_factors # type: ignore
self.mfa.delete_factor = self._delete_factor # type: ignore
self.oauth = AsyncGoTrueAdminOAuthAPI()
self.oauth.list_clients = self._list_oauth_clients # type: ignore
self.oauth.create_client = self._create_oauth_client # type: ignore
self.oauth.get_client = self._get_oauth_client # type: ignore
self.oauth.update_client = self._update_oauth_client # type: ignore
self.oauth.delete_client = self._delete_oauth_client # type: ignore
self.oauth.regenerate_client_secret = self._regenerate_oauth_client_secret # type: ignore
self.oauth.list_clients = self._list_oauth_clients # type: ignore
self.oauth.create_client = self._create_oauth_client # type: ignore
self.oauth.get_client = self._get_oauth_client # type: ignore
self.oauth.update_client = self._update_oauth_client # type: ignore
self.oauth.delete_client = self._delete_oauth_client # type: ignore
self.oauth.regenerate_client_secret = self._regenerate_oauth_client_secret # type: ignore

async def sign_out(self, jwt: str, scope: SignOutScope = "global") -> None:
"""
Expand Down Expand Up @@ -276,9 +276,8 @@ async def _create_oauth_client(
body=params,
)

return OAuthClientResponse(
client=model_validate(OAuthClient, response.content)
)
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))

async def _get_oauth_client(
self,
client_id: str,
Expand All @@ -295,9 +294,7 @@ async def _get_oauth_client(
"GET",
f"admin/oauth/clients/{client_id}",
)
return OAuthClientResponse(
client=model_validate(OAuthClient, response.content)
)
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))

async def _update_oauth_client(
self,
Expand All @@ -317,9 +314,7 @@ async def _update_oauth_client(
f"admin/oauth/clients/{client_id}",
body=params,
)
return OAuthClientResponse(
client=model_validate(OAuthClient, response.content)
)
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))

async def _delete_oauth_client(
self,
Expand Down Expand Up @@ -354,6 +349,4 @@ async def _regenerate_oauth_client_secret(
"POST",
f"admin/oauth/clients/{client_id}/regenerate_secret",
)
return OAuthClientResponse(
client=model_validate(OAuthClient, response.content)
)
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))
3 changes: 2 additions & 1 deletion src/auth/src/supabase_auth/_async/gotrue_admin_oauth_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional

from ..types import (
CreateOAuthClientParams,
OAuthClientListResponse,
OAuthClientResponse,
PageParams,
UpdateOAuthClientParams,
)
from typing import Optional


class AsyncGoTrueAdminOAuthAPI:
Expand Down
6 changes: 4 additions & 2 deletions src/auth/src/supabase_auth/_async/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,9 @@ async def _get_url_for_provider(
query = query.set("provider", provider)
return f"{url}?{query}", query

async def exchange_code_for_session(self, params: CodeExchangeParams):
async def exchange_code_for_session(
self, params: CodeExchangeParams
) -> AuthResponse:
code_verifier = params.get("code_verifier") or await self._storage.get_item(
f"{self._storage_key}-code-verifier"
)
Expand All @@ -1184,7 +1186,7 @@ async def exchange_code_for_session(self, params: CodeExchangeParams):
if auth_response.session:
await self._save_session(auth_response.session)
self._notify_all_subscribers("SIGNED_IN", auth_response.session)
return response
return auth_response

async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
jwk: Optional[JWK] = None
Expand Down
35 changes: 14 additions & 21 deletions src/auth/src/supabase_auth/_sync/gotrue_admin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from pydantic import TypeAdapter

from ..helpers import (
validate_uuid,
model_validate,
parse_link_response,
parse_user_response,
validate_uuid,
)
from ..http_clients import SyncClient
from ..types import (
Expand Down Expand Up @@ -57,15 +57,15 @@ def __init__(
)
# TODO(@o-santi): why is is this done this way?
self.mfa = SyncGoTrueAdminMFAAPI()
self.mfa.list_factors = self._list_factors # type: ignore
self.mfa.delete_factor = self._delete_factor # type: ignore
self.mfa.list_factors = self._list_factors # type: ignore
self.mfa.delete_factor = self._delete_factor # type: ignore
self.oauth = SyncGoTrueAdminOAuthAPI()
self.oauth.list_clients = self._list_oauth_clients # type: ignore
self.oauth.create_client = self._create_oauth_client # type: ignore
self.oauth.get_client = self._get_oauth_client # type: ignore
self.oauth.update_client = self._update_oauth_client # type: ignore
self.oauth.delete_client = self._delete_oauth_client # type: ignore
self.oauth.regenerate_client_secret = self._regenerate_oauth_client_secret # type: ignore
self.oauth.list_clients = self._list_oauth_clients # type: ignore
self.oauth.create_client = self._create_oauth_client # type: ignore
self.oauth.get_client = self._get_oauth_client # type: ignore
self.oauth.update_client = self._update_oauth_client # type: ignore
self.oauth.delete_client = self._delete_oauth_client # type: ignore
self.oauth.regenerate_client_secret = self._regenerate_oauth_client_secret # type: ignore

def sign_out(self, jwt: str, scope: SignOutScope = "global") -> None:
"""
Expand Down Expand Up @@ -276,9 +276,8 @@ def _create_oauth_client(
body=params,
)

return OAuthClientResponse(
client=model_validate(OAuthClient, response.content)
)
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))

def _get_oauth_client(
self,
client_id: str,
Expand All @@ -295,9 +294,7 @@ def _get_oauth_client(
"GET",
f"admin/oauth/clients/{client_id}",
)
return OAuthClientResponse(
client=model_validate(OAuthClient, response.content)
)
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))

def _update_oauth_client(
self,
Expand All @@ -317,9 +314,7 @@ def _update_oauth_client(
f"admin/oauth/clients/{client_id}",
body=params,
)
return OAuthClientResponse(
client=model_validate(OAuthClient, response.content)
)
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))

def _delete_oauth_client(
self,
Expand Down Expand Up @@ -354,6 +349,4 @@ def _regenerate_oauth_client_secret(
"POST",
f"admin/oauth/clients/{client_id}/regenerate_secret",
)
return OAuthClientResponse(
client=model_validate(OAuthClient, response.content)
)
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))
3 changes: 2 additions & 1 deletion src/auth/src/supabase_auth/_sync/gotrue_admin_oauth_api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional

from ..types import (
CreateOAuthClientParams,
OAuthClientListResponse,
OAuthClientResponse,
PageParams,
UpdateOAuthClientParams,
)
from typing import Optional


class SyncGoTrueAdminOAuthAPI:
Expand Down
16 changes: 5 additions & 11 deletions src/auth/src/supabase_auth/_sync/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,7 @@ def sign_in_with_oauth(
)
return OAuthResponse(provider=provider, url=url_with_qs)

def link_identity(
self, credentials: SignInWithOAuthCredentials
) -> OAuthResponse:
def link_identity(self, credentials: SignInWithOAuthCredentials) -> OAuthResponse:
provider = credentials["provider"]
options = credentials.get("options", {})
redirect_to = options.get("redirect_to")
Expand Down Expand Up @@ -743,9 +741,7 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse:
self._notify_all_subscribers("TOKEN_REFRESHED", session)
return AuthResponse(session=session, user=session.user)

def refresh_session(
self, refresh_token: Optional[str] = None
) -> AuthResponse:
def refresh_session(self, refresh_token: Optional[str] = None) -> AuthResponse:
"""
Returns a new session, regardless of expiry status.

Expand Down Expand Up @@ -1153,9 +1149,7 @@ def _get_url_for_provider(
if self._flow_type == "pkce":
code_verifier = generate_pkce_verifier()
code_challenge = generate_pkce_challenge(code_verifier)
self._storage.set_item(
f"{self._storage_key}-code-verifier", code_verifier
)
self._storage.set_item(f"{self._storage_key}-code-verifier", code_verifier)
code_challenge_method = (
"plain" if code_verifier == code_challenge else "s256"
)
Expand All @@ -1165,7 +1159,7 @@ def _get_url_for_provider(
query = query.set("provider", provider)
return f"{url}?{query}", query

def exchange_code_for_session(self, params: CodeExchangeParams):
def exchange_code_for_session(self, params: CodeExchangeParams) -> AuthResponse:
code_verifier = params.get("code_verifier") or self._storage.get_item(
f"{self._storage_key}-code-verifier"
)
Expand All @@ -1184,7 +1178,7 @@ def exchange_code_for_session(self, params: CodeExchangeParams):
if auth_response.session:
self._save_session(auth_response.session)
self._notify_all_subscribers("SIGNED_IN", auth_response.session)
return response
return auth_response

def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
jwk: Optional[JWK] = None
Expand Down
3 changes: 2 additions & 1 deletion src/auth/src/supabase_auth/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,9 @@ def is_valid_uuid(value: str) -> bool:
except ValueError:
return False


def validate_uuid(id: str | None) -> None:
if id is None:
raise ValueError("Invalid id, id is None")
if not is_valid_uuid(id):
raise ValueError(f"Invalid id, '{id}' is not a valid uuid")
raise ValueError(f"Invalid id, '{id}' is not a valid uuid")
6 changes: 5 additions & 1 deletion src/auth/src/supabase_auth/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,9 @@ class JWKSet(TypedDict):
Only relevant when the OAuth 2.1 server is enabled in Supabase Auth.
"""

OAuthClientTokenEndpointAuthMethod = Literal["none", "client_secret_basic", "client_secret_post"]
OAuthClientTokenEndpointAuthMethod = Literal[
"none", "client_secret_basic", "client_secret_post"
]
"""
OAuth client token endpoint authentication method.
Only relevant when the OAuth 2.1 server is enabled in Supabase Auth.
Expand Down Expand Up @@ -957,6 +959,7 @@ class CreateOAuthClientParams(BaseModel):
scope: Optional[str] = None
"""Space-separated list of scope values"""


class UpdateOAuthClientParams(BaseModel):
"""
Parameters for updating an existing OAuth client.
Expand All @@ -974,6 +977,7 @@ class UpdateOAuthClientParams(BaseModel):
grant_types: Optional[List[OAuthClientGrantType]] = None
"""Array of allowed grant types"""


class OAuthClientResponse(BaseModel):
"""
Response type for OAuth client operations.
Expand Down
3 changes: 3 additions & 0 deletions src/auth/tests/_async/test_gotrue_admin_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
AuthWeakPasswordError,
)
from supabase_auth.types import CreateOAuthClientParams, UpdateOAuthClientParams

from .clients import (
auth_client,
auth_client_with_session,
Expand Down Expand Up @@ -649,6 +650,7 @@ async def test_get_oauth_client():
assert response.client is not None
assert response.client.client_id == client_id


# Server is not yet released, so this test is not yet relevant.
# async def test_update_oauth_client():
# """Test updating an OAuth client."""
Expand All @@ -671,6 +673,7 @@ async def test_get_oauth_client():
# assert response.client is not None
# assert response.client.client_name == "Updated Test OAuth Client"


async def test_delete_oauth_client():
"""Test deleting an OAuth client."""
# First create a client
Expand Down
4 changes: 1 addition & 3 deletions src/auth/tests/_sync/test_gotrue.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,7 @@ def test_exchange_code_for_session():
client._flow_type = "pkce"

# Test the PKCE URL generation which is needed for exchange_code_for_session
url, params = client._get_url_for_provider(
f"{client._url}/authorize", "github", {}
)
url, params = client._get_url_for_provider(f"{client._url}/authorize", "github", {})

# Verify PKCE parameters were added
assert "code_challenge" in params
Expand Down
Loading