4040 logger .warning ("Failed to remove semantic decoder for CBOR tag 258" , e )
4141 pass
4242
43- from cbor2 import CBOREncoder , CBORSimpleValue , CBORTag , dumps , loads , undefined
44- from frozendict import frozendict
43+ from cbor2 import (
44+ CBOREncoder ,
45+ CBORSimpleValue ,
46+ CBORTag ,
47+ FrozenDict ,
48+ dumps ,
49+ loads ,
50+ undefined ,
51+ )
4552from frozenlist import FrozenList
4653from pprintpp import pformat
4754
@@ -126,7 +133,7 @@ class RawCBOR:
126133 set ,
127134 Fraction ,
128135 frozenset ,
129- frozendict ,
136+ FrozenDict ,
130137 FrozenList ,
131138 IndefiniteFrozenList ,
132139 ByteString ,
@@ -154,7 +161,7 @@ class RawCBOR:
154161 CBORTag ,
155162 set ,
156163 frozenset ,
157- frozendict ,
164+ FrozenDict ,
158165 Fraction ,
159166 FrozenList ,
160167 IndefiniteFrozenList ,
@@ -205,7 +212,7 @@ def default_encoder(
205212 RawCBOR ,
206213 FrozenList ,
207214 IndefiniteFrozenList ,
208- frozendict ,
215+ FrozenDict ,
209216 ),
210217 ), (
211218 f"Type of input value is not CBORSerializable, " f"got { type (value )} instead."
@@ -231,7 +238,7 @@ def default_encoder(
231238 encoder .write (value .cbor )
232239 elif isinstance (value , FrozenList ):
233240 encoder .encode (list (value ))
234- elif isinstance (value , frozendict ):
241+ elif isinstance (value , FrozenDict ):
235242 encoder .encode (dict (value ))
236243 else :
237244 encoder .encode (value .to_validated_primitive ())
@@ -296,7 +303,7 @@ def _dfs(value, freeze=False):
296303 for k , v in value .items ():
297304 _dict [_dfs (k , freeze = True )] = _dfs (v , freeze )
298305 if freeze :
299- return frozendict (_dict )
306+ return FrozenDict (_dict )
300307 return _dict
301308 elif isinstance (value , set ):
302309 _set = set (_dfs (v , freeze = True ) for v in value )
@@ -348,7 +355,7 @@ def _check_recursive(value, type_hint):
348355 return _check_recursive (value , type_hint .__args__ [0 ])
349356 elif origin is Union :
350357 return any (_check_recursive (value , arg ) for arg in type_hint .__args__ )
351- elif origin is Dict or isinstance (value , (dict , frozendict )):
358+ elif origin is Dict or isinstance (value , (dict , FrozenDict )):
352359 key_type , value_type = type_hint .__args__
353360 return all (
354361 _check_recursive (k , key_type ) and _check_recursive (v , value_type )
@@ -814,8 +821,8 @@ def to_shallow_primitive(self) -> Primitive:
814821 return primitives
815822
816823 @classmethod
817- @limit_primitive_type (dict )
818- def from_primitive (cls : Type [MapBase ], values : dict ) -> MapBase :
824+ @limit_primitive_type (dict , FrozenDict )
825+ def from_primitive (cls : Type [MapBase ], values : Union [ dict , FrozenDict ] ) -> MapBase :
819826 """Restore a primitive value to its original class type.
820827
821828 Args:
@@ -1038,10 +1045,14 @@ def from_primitive(
10381045 value .value = [type_arg .from_primitive (v ) for v in value .value ]
10391046 return cls (value .value , use_tag = True )
10401047
1048+ use_tag = isinstance (value , set )
1049+
10411050 if isinstance (value , (list , tuple , set )):
10421051 if isclass (type_arg ) and issubclass (type_arg , CBORSerializable ):
10431052 value = [type_arg .from_primitive (v ) for v in value ]
1044- return cls (list (value ), use_tag = False )
1053+
1054+ # If the value is a set, we know it is coming from a CBORTag (#6.258)
1055+ return cls (list (value ), use_tag = use_tag )
10451056
10461057 raise ValueError (f"Cannot deserialize { value } to { cls } " )
10471058
0 commit comments