Skip to content

Commit 8e739b3

Browse files
authored
Update oauth_providers.py to include Keycloak (#1525)
1 parent fd882b8 commit 8e739b3

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

backend/chainlit/oauth_providers.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,71 @@ async def get_user_info(self, token: str):
665665
return (gitlab_user, user)
666666

667667

668+
class KeycloakOAuthProvider(OAuthProvider):
669+
env = [
670+
"OAUTH_KEYCLOAK_CLIENT_ID",
671+
"OAUTH_KEYCLOAK_CLIENT_SECRET",
672+
"OAUTH_KEYCLOAK_REALM",
673+
"OAUTH_KEYCLOAK_BASE_URL",
674+
]
675+
id = os.environ.get("OAUTH_KEYCLOAK_NAME", "keycloak")
676+
677+
def __init__(self):
678+
self.client_id = os.environ.get("OAUTH_KEYCLOAK_CLIENT_ID")
679+
self.client_secret = os.environ.get("OAUTH_KEYCLOAK_CLIENT_SECRET")
680+
self.realm = os.environ.get("OAUTH_KEYCLOAK_REALM")
681+
self.base_url = os.environ.get("OAUTH_KEYCLOAK_BASE_URL")
682+
self.authorize_url = (
683+
f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/auth"
684+
)
685+
686+
self.authorize_params = {
687+
"scope": "profile email openid",
688+
"response_type": "code",
689+
}
690+
691+
if prompt := self.get_prompt():
692+
self.authorize_params["prompt"] = prompt
693+
694+
async def get_token(self, code: str, url: str):
695+
payload = {
696+
"client_id": self.client_id,
697+
"client_secret": self.client_secret,
698+
"code": code,
699+
"grant_type": "authorization_code",
700+
"redirect_uri": url,
701+
}
702+
async with httpx.AsyncClient() as client:
703+
response = await client.post(
704+
f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/token",
705+
data=payload,
706+
)
707+
response.raise_for_status()
708+
json = response.json()
709+
token = json.get("access_token")
710+
if not token:
711+
raise httpx.HTTPStatusError(
712+
"Failed to get the access token",
713+
request=response.request,
714+
response=response,
715+
)
716+
return token
717+
718+
async def get_user_info(self, token: str):
719+
async with httpx.AsyncClient() as client:
720+
response = await client.get(
721+
f"{self.base_url}/realms/{self.realm}/protocol/openid-connect/userinfo",
722+
headers={"Authorization": f"Bearer {token}"},
723+
)
724+
response.raise_for_status()
725+
kc_user = response.json()
726+
user = User(
727+
identifier=kc_user["email"],
728+
metadata={"provider": "keycloak"},
729+
)
730+
return (kc_user, user)
731+
732+
668733
providers = [
669734
GithubOAuthProvider(),
670735
GoogleOAuthProvider(),
@@ -675,6 +740,7 @@ async def get_user_info(self, token: str):
675740
DescopeOAuthProvider(),
676741
AWSCognitoOAuthProvider(),
677742
GitlabOAuthProvider(),
743+
KeycloakOAuthProvider(),
678744
]
679745

680746

0 commit comments

Comments
 (0)