11"""JSON Web Key."""
22import abc
3- import base64
43import json
54import logging
65import math
@@ -407,38 +406,45 @@ class JWKOKP(JWK):
407406 )
408407 required = ('crv' , JWK .type_field_name , 'x' )
409408
410- def __init__ (self , * args , ** kwargs ) -> None :
409+ def __init__ (self , * args , ** kwargs ):
411410 if 'key' in kwargs and not isinstance (kwargs ['key' ], util .ComparableOKPKey ):
412411 kwargs ['key' ] = util .ComparableOKPKey (kwargs ['key' ])
413412 super ().__init__ (* args , ** kwargs )
414413
415- def public_key (self ) -> Union [
416- ed25519 .Ed25519PublicKey , ed448 .Ed448PublicKey ,
417- x25519 .X25519PublicKey , x448 .X448PublicKey ,
418- ]:
419- return self ._wrapped .__class__ .public_key ()
414+ def public_key (self ):
415+ return self .key ._wrapped .__class__ .public_key ()
416+
417+ def _key_to_crv (self ):
418+ if isinstance (self .key ._wrapped , (ed25519 .Ed25519PrivateKey , ed25519 .Ed25519PrivateKey )):
419+ return "Ed25519"
420+ elif isinstance (self .key ._wrapped , (ed448 .Ed448PrivateKey , ed448 .Ed448PrivateKey )):
421+ return "Ed448"
422+ elif isinstance (self .key ._wrapped , (x25519 .X25519PrivateKey , x25519 .X25519PrivateKey )):
423+ return "X25519"
424+ elif isinstance (self .key ._wrapped , (x448 .X448PrivateKey , x448 .X448PrivateKey )):
425+ return "X448"
426+ return NotImplemented
420427
421428 def fields_to_partial_json (self ) -> Dict :
422- params = {} # type: Dict
429+ params = {}
430+ print (dir (self ))
423431 if self .key .is_private ():
424- params ['d' ] = base64 . b64encode (self .key .private_bytes (
432+ params ['d' ] = json_util . encode_b64jose (self .key .private_bytes (
425433 encoding = serialization .Encoding .PEM ,
426434 format = serialization .PrivateFormat .PKCS8 ,
427435 encryption_algorithm = serialization .NoEncryption ()
428436 ))
429437 params ['x' ] = self .key .public_key ().public_bytes (
430438 encoding = serialization .Encoding .PEM ,
431- format = serialization .PublicFormat .PKCS8 ,
432- encryption_algorithm = serialization .NoEncryption ()
439+ format = serialization .PublicFormat .SubjectPublicKeyInfo ,
433440 )
434441 else :
435- params ['x' ] = base64 . b64decode (self .key .public_bytes (
442+ params ['x' ] = json_util . encode_b64jose (self .key .public_bytes (
436443 serialization .Encoding .Raw ,
437444 serialization .PublicFormat .Raw ,
438445 serialization .NoEncryption (),
439446 ))
440- # TODO find a better way to get the curve name
441- params ['crv' ] = 'ed25519'
447+ params ['crv' ] = self ._key_to_crv ()
442448 return params
443449
444450 @classmethod
@@ -463,12 +469,12 @@ def fields_from_json(cls, jobj) -> ComparableOKPKey:
463469
464470 if "x" not in obj :
465471 raise errors .DeserializationError ('OKP should have "x" parameter' )
466- x = base64 . b64decode (jobj .get ("x" ))
472+ x = json_util . decode_b64jose (jobj .get ("x" ))
467473
468474 try :
469475 if "d" not in obj :
470476 return jobj ["key" ]._wrapped .__class__ .from_public_bytes (x ) # noqa
471- d = base64 . b64decode (obj .get ("d" ))
477+ d = json_util . decode_b64jose (obj .get ("d" ))
472478 return jobj ["key" ]._wrapped .__class__ .from_private_bytes (d ) # noqa
473479 except ValueError as err :
474480 raise errors .DeserializationError ("Invalid key parameter" ) from err
0 commit comments