1818
1919import random
2020import warnings
21+ from collections import defaultdict
2122from dataclasses import dataclass , field
2223from typing import Dict , Optional , Type , Union
2324
@@ -335,8 +336,7 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
335336
336337 # only sample within the mask, if the mask is in the batch
337338 all_indices = []
338- all_images = []
339- all_depth_images = []
339+ all_images = defaultdict (list )
340340
341341 assert num_rays_per_batch % 2 == 0 , "num_rays_per_batch must be divisible by 2"
342342 num_rays_per_image = divide_rays_per_image (num_rays_per_batch , num_images )
@@ -350,10 +350,11 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
350350 )
351351 indices [:, 0 ] = i
352352 all_indices .append (indices )
353- all_images .append (batch ["image" ][i ][indices [:, 1 ], indices [:, 2 ]])
354- if "depth_image" in batch :
355- all_depth_images .append (batch ["depth_image" ][i ][indices [:, 1 ], indices [:, 2 ]])
356353
354+ for key , value in batch .items ():
355+ if key in ["image_idx" , "mask" ]:
356+ continue
357+ all_images [key ].append (value [i ][indices [:, 1 ], indices [:, 2 ]])
357358 else :
358359 for i , num_rays in enumerate (num_rays_per_image ):
359360 image_height , image_width , _ = batch ["image" ][i ].shape
@@ -363,26 +364,19 @@ def collate_image_dataset_batch_list(self, batch: Dict, num_rays_per_batch: int,
363364 indices = self .sample_method (num_rays , 1 , image_height , image_width , device = device )
364365 indices [:, 0 ] = i
365366 all_indices .append (indices )
366- all_images .append (batch ["image" ][i ][indices [:, 1 ], indices [:, 2 ]])
367- if "depth_image" in batch :
368- all_depth_images .append (batch ["depth_image" ][i ][indices [:, 1 ], indices [:, 2 ]])
367+ for key , value in batch .items ():
368+ if key in ["image_idx" , "mask" ]:
369+ continue
370+ all_images [key ].append (value [i ][indices [:, 1 ], indices [:, 2 ]])
369371
370372 indices = torch .cat (all_indices , dim = 0 )
371373
372- c , y , x = (i .flatten () for i in torch .split (indices , 1 , dim = - 1 ))
373- collated_batch = {
374- key : value [c , y , x ]
375- for key , value in batch .items ()
376- if key not in ("image_idx" , "image" , "mask" , "depth_image" ) and value is not None
377- }
378-
379- collated_batch ["image" ] = torch .cat (all_images , dim = 0 )
380- if "depth_image" in batch :
381- collated_batch ["depth_image" ] = torch .cat (all_depth_images , dim = 0 )
374+ collated_batch = {key : torch .cat (all_images [key ], dim = 0 ) for key in all_images }
382375
383376 assert collated_batch ["image" ].shape [0 ] == num_rays_per_batch
384377
385378 # Needed to correct the random indices to their actual camera idx locations.
379+ c = indices [..., 0 ].flatten ()
386380 indices [:, 0 ] = batch ["image_idx" ][c ]
387381 collated_batch ["indices" ] = indices # with the abs camera indices
388382
0 commit comments