@@ -827,7 +827,7 @@ def set(
         if not self.initialized:
             if not isinstance(cursor, INT_CLASSES):
                 if is_tensor_collection(data):
-                    self._init(data[0])
+                    self._init(data, shape=data.shape[1:])
                 else:
                     self._init(tree_map(lambda x: x[0], data))
             else:
@@ -873,7 +873,7 @@ def set( # noqa: F811
                 )
         if not self.initialized:
             if not isinstance(cursor, INT_CLASSES):
-                self._init(data[0])
+                self._init(data, shape=data.shape[1:])
             else:
                 self._init(data)
         if not isinstance(cursor, (*INT_CLASSES, slice)):
@@ -993,6 +993,15 @@ class LazyTensorStorage(TensorStorage):
             Defaults to ``False``.
         consolidated (bool, optional): if ``True``, the storage will be consolidated after
             its first expansion. Defaults to ``False``.
+        empty_lazy (bool, optional): if ``True``, any lazy tensordict in the first tensordict
+            passed to the storage will be emptied of its content. This can be used to store
+            ragged data or content with exclusive keys (e.g., when some but not all environments
+            provide extra data to be stored in the buffer).
+            Setting ``empty_lazy`` to ``True`` requires :meth:`~.extend` to be called first (a call
+            to :meth:`~.add` will raise an exception).
+            Recall that data stored in lazy stacks is not contiguous in memory: indexing can be
+            slower than with contiguous data, and serialization is more error-prone. Use with caution!
+            Defaults to ``False``.
 
     Examples:
         >>> data = TensorDict({
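A minimal usage sketch of the `empty_lazy` flag documented above, assuming the new keyword lands on `LazyTensorStorage` as shown in this diff; the other imports and calls are existing TorchRL/tensordict API, and the tensordict contents are illustrative only:

    import torch
    from tensordict import TensorDict, LazyStackedTensorDict
    from torchrl.data import LazyTensorStorage, ReplayBuffer

    # empty_lazy is the new keyword added in this diff (assumed API)
    storage = LazyTensorStorage(100, empty_lazy=True)
    buffer = ReplayBuffer(storage=storage)

    # two elements with exclusive keys -> a lazy stack that cannot be stored densely
    td0 = TensorDict({"obs": torch.zeros(3), "extra": torch.ones(2)}, batch_size=[])
    td1 = TensorDict({"obs": torch.zeros(3)}, batch_size=[])
    data = LazyStackedTensorDict.lazy_stack([td0, td1])

    buffer.extend(data)  # extend must come first; calling add() before extend() raises with empty_lazy=True
    sample = buffer.sample(2)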
@@ -1054,6 +1063,7 @@ def __init__(
         ndim: int = 1,
         compilable: bool = False,
         consolidated: bool = False,
+        empty_lazy: bool = False,
     ):
         super().__init__(
             storage=None,
@@ -1062,11 +1072,13 @@ def __init__(
             ndim=ndim,
             compilable=compilable,
         )
+        self.empty_lazy = empty_lazy
         self.consolidated = consolidated
 
     def _init(
         self,
         data: TensorDictBase | torch.Tensor | PyTree,  # noqa: F821
+        shape: torch.Size | None = None,
     ) -> None:
         if not self._compilable:
             # TODO: Investigate why this seems to have a performance impact with
@@ -1087,8 +1099,14 @@ def max_size_along_dim0(data_shape):
 
         if is_tensor_collection(data):
             out = data.to(self.device)
-            out: TensorDictBase = torch.empty_like(
-                out.expand(max_size_along_dim0(data.shape))
+            if self.empty_lazy and shape is None:
+                raise RuntimeError(
+                    "Make sure you have called `extend` and not `add` first when setting `empty_lazy=True`."
+                )
+            elif shape is None:
+                shape = data.shape
+            out: TensorDictBase = out.new_empty(
+                max_size_along_dim0(shape), empty_lazy=self.empty_lazy
             )
             if self.consolidated:
                 out = out.consolidate()
@@ -1286,7 +1304,9 @@ def load_state_dict(self, state_dict):
         self.initialized = state_dict["initialized"]
         self._len = state_dict["_len"]
 
-    def _init(self, data: TensorDictBase | torch.Tensor) -> None:
+    def _init(
+        self, data: TensorDictBase | torch.Tensor, *, shape: torch.Size | None = None
+    ) -> None:
         torchrl_logger.debug("Creating a MemmapStorage...")
         if self.device == "auto":
             self.device = data.device
@@ -1304,8 +1324,14 @@ def max_size_along_dim0(data_shape):
             return (self.max_size, *data_shape)
 
         if is_tensor_collection(data):
+            if shape is None:
+                # Within add(): data is a single element, use its shape directly
+                shape = data.shape
+            else:
+                # Within extend(): take the first element; empty_lazy is not relevant for memmap storages
+                data = data[0]
             out = data.clone().to(self.device)
-            out = out.expand(max_size_along_dim0(data.shape))
+            out = out.expand(max_size_along_dim0(shape))
             out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
             if torchrl_logger.isEnabledFor(logging.DEBUG):
                 for key, tensor in sorted(
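For reference, a small sketch of the shape contract that the `set`/`_init` changes above rely on: the extend path hands `_init` the whole batch together with the per-element shape `data.shape[1:]`, while the add path passes a single element and leaves `shape` unset. The tensordicts below are illustrative only:

    import torch
    from tensordict import TensorDict

    # extend-style input: a batch of 4 elements, each carrying a 3-dim "obs" entry
    batch = TensorDict({"obs": torch.zeros(4, 3)}, batch_size=[4])
    # add-style input: a single element
    single = TensorDict({"obs": torch.zeros(3)}, batch_size=[])

    # what set() forwards as `shape` in the extend path
    per_element_shape = batch.shape[1:]
    assert per_element_shape == single.shape  # both describe one stored element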