55
66import numpy as np
77from numpy .lib .recfunctions import append_fields
8- from pandas import DataFrame
8+ from pandas import DataFrame , RangeIndex
99from root_numpy import root2array , list_trees
1010from fnmatch import fnmatch
1111from root_numpy import list_branches
@@ -199,11 +199,13 @@ def do_flatten(arr, flatten):
199199 # XXX could explicitly clean up the opened TFiles with TChain::Reset
200200
def genchunks():
    """Lazily yield one converted chunk per block of ``chunksize`` entries.

    Each iteration calls ``root2array`` with ``start``/``stop`` set to the
    bounds of the current block, optionally runs the result through
    ``do_flatten``, and hands it to ``convert_to_dataframe``.

    A running row counter is passed as ``start_index`` so that numbering
    continues from the end of the previous chunk instead of restarting
    for every chunk. The counter advances by the length of the array
    that was actually converted (i.e. after flattening, if requested).
    """
    n_chunks = int(ceil(float(n_entries) / chunksize))
    next_row = 0
    for chunk_number in range(n_chunks):
        first_entry = chunk_number * chunksize
        last_entry = (chunk_number + 1) * chunksize
        records = root2array(paths, key, all_vars, start=first_entry, stop=last_entry, selection=where, *args, **kwargs)
        if flatten:
            records = do_flatten(records, flatten)
        yield convert_to_dataframe(records, start_index=next_row)
        # Advance only after the consumer resumes the generator, using the
        # post-flatten length so the next chunk starts right after this one.
        next_row += len(records)
207209 return genchunks ()
208210
209211 arr = root2array (paths , key , all_vars , selection = where , * args , ** kwargs )
@@ -212,15 +214,17 @@ def genchunks():
212214 return convert_to_dataframe (arr )
213215
214216
215-
216- def convert_to_dataframe (array ):
217+ def convert_to_dataframe (array , start_index = None ):
217218 nonscalar_columns = get_nonscalar_columns (array )
218219 if nonscalar_columns :
219220 warnings .warn ("Ignored the following non-scalar branches: {bad_names}"
220221 .format (bad_names = ", " .join (nonscalar_columns )), UserWarning )
221222 indices = list (filter (lambda x : x .startswith ('__index__' ) and x not in nonscalar_columns , array .dtype .names ))
222223 if len (indices ) == 0 :
223- df = DataFrame .from_records (array , exclude = nonscalar_columns )
224+ index = None
225+ if start_index is not None :
226+ index = RangeIndex (start = start_index , stop = start_index + len (array ))
227+ df = DataFrame .from_records (array , exclude = nonscalar_columns , index = index )
224228 elif len (indices ) == 1 :
225229 # We store the index under the __index__* branch, where
226230 # * is the name of the index
@@ -235,7 +239,7 @@ def convert_to_dataframe(array):
235239 return df
236240
237241
238- def to_root (df , path , key = 'default' , mode = 'w' , * args , ** kwargs ):
242+ def to_root (df , path , key = 'default' , mode = 'w' , store_index = True , * args , ** kwargs ):
239243 """
240244 Write DataFrame to a ROOT file.
241245
@@ -247,6 +251,9 @@ def to_root(df, path, key='default', mode='w', *args, **kwargs):
247251 Name of tree that the DataFrame will be saved as
248252 mode: string, {'w', 'a'}
249253 Mode that the file should be opened in (default: 'w')
254+ store_index: bool (optional, default: True)
255+ Whether the index of the DataFrame should be stored as
256+ an __index__* branch in the tree
250257
251258 Notes
252259 -----
@@ -270,11 +277,12 @@ def to_root(df, path, key='default', mode='w', *args, **kwargs):
270277 from root_numpy import array2root
271278 # We don't want to modify the user's DataFrame here, so we make a shallow copy
272279 df_ = df .copy (deep = False )
273- name = df_ .index .name
274- if name is None :
275- # Handle the case where the index has no name
276- name = ''
277- df_ ['__index__' + name ] = df_ .index
280+ if store_index :
281+ name = df_ .index .name
282+ if name is None :
283+ # Handle the case where the index has no name
284+ name = ''
285+ df_ ['__index__' + name ] = df_ .index
278286 arr = df_ .to_records (index = False )
279287 array2root (arr , path , key , mode = mode , * args , ** kwargs )
280288
0 commit comments