File tree Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Expand file tree Collapse file tree 1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -784,20 +784,21 @@ class Average(Metric):
784784
785785 @classmethod
786786 def empty (cls ) -> Average :
787- return cls (total = jnp .array (0 , jnp .float32 ), count = jnp .array (0 , jnp .int32 ))
787+ return cls (total = jnp .array (0 , jnp .float32 ), count = jnp .array (0 , jnp .float32 ))
788788
789789 @classmethod
790790 def from_model_output (
791791 cls , values : jnp .ndarray , mask : jnp .ndarray | None = None , ** _
792792 ) -> Average :
793793 values , mask = _broadcast_masks (values , mask )
794794 return cls (
795- total = jnp .where (mask , values , jnp .zeros_like (values )).sum (),
795+ total = jnp .where (mask , values , jnp .zeros_like (values )).sum ().astype (
796+ jnp .float32 ),
796797 count = jnp .where (
797798 mask ,
798799 jnp .ones_like (values , dtype = jnp .int32 ),
799800 jnp .zeros_like (values , dtype = jnp .int32 ),
800- ).sum (),
801+ ).sum (). astype ( jnp . float32 ) ,
801802 )
802803
803804 def merge (self , other : Average ) -> Average :
You can’t perform that action at this time.
0 commit comments