@@ -211,16 +211,20 @@ class CompilationResult:
   keys: List[str]
 
   def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
-    object.__setattr__(self, 'serialized_sequence_examples',
-                       [x.SerializeToString() for x in sequence_examples])
+    object.__setattr__(
+        self, 'serialized_sequence_examples',
+        [x.SerializeToString() for x in sequence_examples if x is not None])
     lengths = [
         len(next(iter(x.feature_lists.feature_list.values())).feature)
         for x in sequence_examples
+        if x is not None
     ]
     object.__setattr__(self, 'length', sum(lengths))
 
-    assert (len(self.serialized_sequence_examples) == len(self.rewards) ==
-            (len(self.keys)))
+    # TODO: is it necessary to return keys AND reward_stats (which has the
+    # keys)? sequence_examples' length could also just not be checked; this
+    # allows raw_reward_only to do less work.
+    assert (len(sequence_examples) == len(self.rewards) == len(self.keys))
     assert set(self.keys) == set(self.reward_stats.keys())
     assert not hasattr(self, 'sequence_examples')
 
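The `if x is not None` filters are the heart of this hunk: a worker running in raw-reward-only mode appends None in place of each trajectory, and __post_init__ must skip those slots while still validating the unfiltered list against rewards and keys. A minimal sketch of that behavior, using a hypothetical FakeExample stand-in for tf.train.SequenceExample (nothing below is part of the diff):

    # Illustrative only: FakeExample stands in for tf.train.SequenceExample
    # to show how the filtering above treats None placeholders.
    class FakeExample:

      def SerializeToString(self):
        return b'serialized'

    sequence_examples = [FakeExample(), None, FakeExample()]  # one raw slot
    rewards = [1.0, 2.0, 3.0]
    keys = ['f', 'g', 'h']

    # Mirrors __post_init__: None placeholders contribute no serialized bytes,
    serialized = [
        x.SerializeToString() for x in sequence_examples if x is not None
    ]
    assert len(serialized) == 2

    # but the unfiltered list still lines up one-to-one with rewards and keys,
    # which is why the new assert measures sequence_examples rather than the
    # serialized list.
    assert len(sequence_examples) == len(rewards) == len(keys)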
@@ -230,9 +234,11 @@ class CompilationRunnerStub(metaclass=abc.ABCMeta):
 
   @abc.abstractmethod
   def collect_data(
-      self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
-      reward_stat: Optional[Dict[str, RewardStat]]
-  ) -> WorkerFuture[CompilationResult]:
+      self,
+      module_spec: corpus.ModuleSpec,
+      tf_policy_path: str,
+      reward_stat: Optional[Dict[str, RewardStat]],
+      raw_reward_only: bool = False) -> WorkerFuture[CompilationResult]:
     raise NotImplementedError()
 
   @abc.abstractmethod
@@ -275,17 +281,18 @@ def enable(self):
   def cancel_all_work(self):
     self._cancellation_manager.kill_all_processes()
 
-  def collect_data(
-      self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
-      reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
+  def collect_data(self,
+                   module_spec: corpus.ModuleSpec,
+                   tf_policy_path: str,
+                   reward_stat: Optional[Dict[str, RewardStat]],
+                   raw_reward_only=False) -> CompilationResult:
     """Collect data for the given IR file and policy.
 
     Args:
       module_spec: a ModuleSpec.
       tf_policy_path: path to the tensorflow policy.
       reward_stat: reward stat of this module, None if unknown.
-      cancellation_token: a CancellationToken through which workers may be
-        signaled early termination
+      raw_reward_only: whether to return the raw reward value without examples.
 
     Returns:
       A CompilationResult. In particular:
@@ -311,7 +318,7 @@ def collect_data(
       policy_result = self._compile_fn(
           module_spec,
           tf_policy_path,
-          reward_only=False,
+          reward_only=raw_reward_only,
           cancellation_manager=self._cancellation_manager)
     else:
       policy_result = default_result
@@ -326,6 +333,11 @@ def collect_data(
         raise ValueError(
             (f'Example {k} does not exist under default policy for '
              f'module {module_spec.name}'))
+      if raw_reward_only:
+        sequence_example_list.append(None)
+        rewards.append(policy_reward)
+        keys.append(k)
+        continue
       default_reward = reward_stat[k].default_reward
       moving_average_reward = reward_stat[k].moving_average_reward
       sequence_example = _overwrite_trajectory_reward(
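End to end, a raw-reward-only call now compiles with reward_only enabled, records the raw policy reward per key, and returns no trajectories. A hedged caller sketch follows; runner, spec, and policy_path are placeholders, not part of this diff, and in this mode every sequence-example slot is a None placeholder, so nothing survives serialization:

    # Hypothetical usage; assumes runner is an instance of the concrete
    # CompilationRunner whose collect_data is shown above.
    result = runner.collect_data(
        module_spec=spec,
        tf_policy_path=policy_path,
        reward_stat=None,      # unknown: triggers a default-policy compile
        raw_reward_only=True)  # compile reward-only, skip trajectories

    assert result.serialized_sequence_examples == []  # None slots filtered out
    assert result.length == 0                         # no trajectory steps
    print(result.rewards, result.keys)                # raw per-key rewards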