Skip to content

Commit 90f3076

Browse files
Add mcap support to osi3trace (#8)
* Add mcap support for osi3trace * Deprecate single-channel only methods, unify new methods * Make implementation classes private Signed-off-by: Thomas Sedlmayer <tsedlmayer@pmsfit.de> Signed-off-by: Pierre R. Mai <pmai@pmsf.de> Co-authored-by: Pierre R. Mai <pmai@pmsf.de>
1 parent 77cc5d5 commit 90f3076

File tree

5 files changed

+1460
-31
lines changed

5 files changed

+1460
-31
lines changed

osi3trace/osi_trace.py

Lines changed: 295 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,15 @@
33
"""
44

55
import lzma
6+
from pathlib import Path
67
import struct
78

9+
from abc import ABC, abstractmethod
10+
from typing_extensions import deprecated
11+
12+
from mcap_protobuf.decoder import DecoderFactory
13+
from mcap.reader import make_reader
14+
815
from osi3.osi_sensorview_pb2 import SensorView
916
from osi3.osi_sensorviewconfiguration_pb2 import SensorViewConfiguration
1017
from osi3.osi_groundtruth_pb2 import GroundTruth
@@ -32,7 +39,7 @@
3239

3340

3441
class OSITrace:
35-
"""This class can import and decode OSI trace files."""
42+
"""This class can import and decode OSI single- and multi-channel trace files."""
3643

3744
@staticmethod
3845
def map_message_type(type_name):
@@ -44,30 +51,207 @@ def message_types():
4451
"""Message types that OSITrace supports."""
4552
return list(MESSAGES_TYPE.keys())
4653

54+
_legacy_ositrace_attributes = {
55+
"type",
56+
"file",
57+
"current_index",
58+
"message_offsets",
59+
"read_complete",
60+
"message_cache",
61+
}
62+
63+
def __getattr__(self, name):
64+
"""
65+
This method forwards the getattr call for unsuccessful legacy attribute
66+
name lookups to the reader in case it is an _OSITraceSingle instance.
67+
"""
68+
if name in self._legacy_ositrace_attributes and isinstance(
69+
self.reader, _OSITraceSingle
70+
):
71+
return getattr(self.reader, name)
72+
raise AttributeError(
73+
f"'{type(self).__name__}' object has no attribute '{name}'"
74+
)
75+
76+
def __setattr__(self, name, value):
77+
"""
78+
This method overwrites the default setter and forwards setattr calls for
79+
legacy attribute names to the reader in case the reader is an
80+
_OSITraceSingle instance. Otherwise it uses the default setter.
81+
"""
82+
reader = (
83+
super().__getattribute__("reader") if "reader" in self.__dict__ else None
84+
)
85+
if name in self._legacy_ositrace_attributes and isinstance(
86+
reader, _OSITraceSingle
87+
):
88+
setattr(reader, name, value)
89+
else:
90+
super().__setattr__(name, value)
91+
92+
def __dir__(self):
93+
attrs = super().__dir__()
94+
if isinstance(self.reader, _OSITraceSingle):
95+
attrs += list(self._legacy_ositrace_attributes)
96+
return attrs
97+
98+
def __init__(
99+
self, path=None, type_name="SensorView", cache_messages=False, topic=None
100+
):
101+
"""
102+
Initializes the trace reader depending on the trace file format.
103+
104+
Args:
105+
path (str): The path to the trace file.
106+
type_name (str): The type name of the messages in the trace; check supported message types with `OSITrace.message_types()`.
107+
cache_messages (bool): Whether to cache messages in memory (only applies to single-channel traces).
108+
topic (str): The topic name for multi-channel traces (only applies to multi-channel traces); Using the first available topic if not specified.
109+
"""
110+
self.reader = None
111+
112+
if path is not None:
113+
self.reader = self._init_reader(
114+
Path(path), type_name, cache_messages, topic
115+
)
116+
117+
def _init_reader(self, path, type_name, cache_messages, topic):
118+
if not path.exists():
119+
raise FileNotFoundError("File not found")
120+
121+
if path.suffix.lower() == ".mcap":
122+
return _OSITraceMulti(path, type_name, topic)
123+
elif path.suffix.lower() in [".osi", ".lzma", ".xz"]:
124+
return _OSITraceSingle(path, type_name, cache_messages)
125+
else:
126+
raise ValueError(f"Unsupported file format: '{path.suffix}'")
127+
128+
def from_file(self, path, type_name="SensorView", cache_messages=False, topic=None):
129+
"""
130+
Initializes the trace reader depending on the trace file format.
131+
132+
Args:
133+
path (str): The path to the trace file.
134+
type_name (str): The type name of the messages in the trace; check supported message types with `OSITrace.message_types()`.
135+
cache_messages (bool): Whether to cache messages in memory (only applies to single-channel traces).
136+
topic (str): The topic name for multi-channel traces (only applies to multi-channel traces); Using the first available topic if not specified.
137+
"""
138+
self.reader = self._init_reader(Path(path), type_name, cache_messages, topic)
139+
140+
def restart(self, index=None):
141+
"""
142+
Restart the trace reader.
143+
144+
Note:
145+
Multi-channel traces don't support restarting from a specific index.
146+
"""
147+
return self.reader.restart(index)
148+
149+
def __iter__(self):
150+
return self.reader.__iter__()
151+
152+
def close(self):
153+
return self.reader.close()
154+
155+
@deprecated(
156+
"This is a legacy interface only supported for single-channel traces, which will be removed in future versions."
157+
)
158+
def retrieve_offsets(self, limit=None):
159+
if isinstance(self.reader, _OSITraceSingle):
160+
return self.reader.retrieve_offsets(limit)
161+
raise NotImplementedError(
162+
"Offsets are only supported for single-channel traces."
163+
)
164+
165+
@deprecated(
166+
"This is a legacy interface only supported for single-channel traces, which will be removed in future versions."
167+
)
168+
def retrieve_message(self, index=None, skip=False):
169+
if isinstance(self.reader, _OSITraceSingle):
170+
return self.reader.retrieve_message(index, skip)
171+
raise NotImplementedError(
172+
"Index-based message retrieval is only supported for single-channel traces."
173+
)
174+
175+
@deprecated(
176+
"This is a legacy interface only supported for single-channel traces, which will be removed in future versions."
177+
)
178+
def get_message_by_index(self, index):
179+
if isinstance(self.reader, _OSITraceSingle):
180+
return self.reader.get_message_by_index(index)
181+
raise NotImplementedError(
182+
"Index-based message retrieval is only supported for single-channel traces."
183+
)
184+
185+
@deprecated(
186+
"This is a legacy interface only supported for single-channel traces, which will be removed in future versions."
187+
)
188+
def get_messages_in_index_range(self, begin, end):
189+
if isinstance(self.reader, _OSITraceSingle):
190+
return self.reader.get_messages_in_index_range(begin, end)
191+
raise NotImplementedError(
192+
"Index-based message retrieval is only supported for single-channel traces."
193+
)
194+
195+
def get_available_topics(self):
196+
return self.reader.get_available_topics()
197+
198+
def get_file_metadata(self):
199+
return self.reader.get_file_metadata()
200+
201+
def get_channel_metadata(self):
202+
return self.reader.get_channel_metadata()
203+
204+
205+
class _ReaderBase(ABC):
206+
"""Common interface for trace readers"""
207+
208+
@abstractmethod
209+
def restart(self, index=None):
210+
pass
211+
212+
@abstractmethod
213+
def __iter__(self):
214+
pass
215+
216+
@abstractmethod
217+
def close(self):
218+
pass
219+
220+
@abstractmethod
221+
def get_available_topics(self):
222+
pass
223+
224+
@abstractmethod
225+
def get_file_metadata(self):
226+
pass
227+
228+
@abstractmethod
229+
def get_channel_metadata(self):
230+
pass
231+
232+
233+
class _OSITraceSingle(_ReaderBase):
234+
"""OSI single-channel trace reader"""
235+
47236
def __init__(self, path=None, type_name="SensorView", cache_messages=False):
48-
self.type = self.map_message_type(type_name)
237+
self.type = OSITrace.map_message_type(type_name)
49238
self.file = None
50239
self.current_index = None
51240
self.message_offsets = None
52241
self.read_complete = False
53242
self.message_cache = {} if cache_messages else None
54243
self._header_length = 4
55244
if path:
56-
self.from_file(path, type_name, cache_messages)
245+
self.type = OSITrace.map_message_type(type_name)
57246

58-
def from_file(self, path, type_name="SensorView", cache_messages=False):
59-
"""Import a trace from a file"""
60-
self.type = self.map_message_type(type_name)
61-
62-
if path.lower().endswith((".lzma", ".xz")):
63-
self.file = lzma.open(path, "rb")
64-
else:
65-
self.file = open(path, "rb")
66-
67-
self.read_complete = False
68-
self.current_index = 0
69-
self.message_offsets = [0]
70-
self.message_cache = {} if cache_messages else None
247+
if path.suffix.lower() in [".lzma", ".xz"]:
248+
self.file = lzma.open(path, "rb")
249+
else:
250+
self.file = open(path, "rb")
251+
self.read_complete = False
252+
self.current_index = 0
253+
self.message_offsets = [0]
254+
self.message_cache = {} if cache_messages else None
71255

72256
def retrieve_offsets(self, limit=None):
73257
"""Retrieve the offsets of the messages from the file."""
@@ -186,3 +370,98 @@ def close(self):
186370
self.read_complete = False
187371
self.read_limit = None
188372
self.type = None
373+
374+
def get_available_topics(self):
375+
raise NotImplementedError(
376+
"Getting available topics is only supported for multi-channel traces."
377+
)
378+
379+
def get_file_metadata(self):
380+
raise NotImplementedError(
381+
"Getting file metadata is only supported for multi-channel traces."
382+
)
383+
384+
def get_channel_metadata(self):
385+
raise NotImplementedError(
386+
"Getting channel metadata is only supported for multi-channel traces."
387+
)
388+
389+
390+
class _OSITraceMulti(_ReaderBase):
391+
"""OSI multi-channel trace reader"""
392+
393+
def __init__(self, path, type_name, topic):
394+
self._file = open(path, "rb")
395+
self._mcap_reader = make_reader(self._file)
396+
self._iter = None
397+
self._summary = self._mcap_reader.get_summary()
398+
available_topics = self.get_available_topics(type_name)
399+
if topic == None:
400+
topic = next(iter(available_topics), None)
401+
if topic not in available_topics:
402+
raise ValueError(
403+
f"The requested topic '{topic}' is not present in the trace file or is not of type '{type_name}'."
404+
)
405+
self.topic = topic
406+
407+
def restart(self, index=None):
408+
if index != None:
409+
raise NotImplementedError(
410+
"Restarting from a given index is not supported for multi-channel traces."
411+
)
412+
self._iter = None
413+
414+
def __iter__(self):
415+
"""Stateful iterator over the channel's messages in log time order."""
416+
if self._iter is None:
417+
self._iter = self._mcap_reader.iter_messages(topics=[self.topic])
418+
419+
message_class = OSITrace.map_message_type(self.get_message_type())
420+
421+
for _, _, message in self._iter:
422+
msg = message_class()
423+
msg.ParseFromString(message.data)
424+
yield msg
425+
426+
def close(self):
427+
if self._file:
428+
self._file.close()
429+
self._file = None
430+
self._mcap_reader = None
431+
self._summary = None
432+
self._iter = None
433+
434+
def get_available_topics(self, type_name=None):
435+
return [
436+
channel.topic
437+
for channel in self._summary.channels.values()
438+
if self._channel_is_of_type(channel, type_name)
439+
]
440+
441+
def get_file_metadata(self):
442+
metadata = []
443+
for metadata_entry in self._mcap_reader.iter_metadata():
444+
metadata.append(metadata_entry)
445+
return metadata
446+
447+
def get_channel_metadata(self):
448+
for id, channel in self._summary.channels.items():
449+
if channel.topic == self.topic:
450+
return channel.metadata
451+
return None
452+
453+
def get_message_type(self):
454+
for channel in self._summary.channels.values():
455+
if channel.topic == self.topic:
456+
schema = self._summary.schemas[channel.schema_id]
457+
if schema.name.startswith("osi3."):
458+
return schema.name[len("osi3.") :]
459+
else:
460+
raise ValueError(
461+
f"Schema '{schema.name}' is not an 'osi3.' schema."
462+
)
463+
return None
464+
465+
def _channel_is_of_type(self, channel, type_name):
466+
schema = self._summary.schemas[channel.schema_id]
467+
return type_name is None or schema.name == f"osi3.{type_name}"

0 commit comments

Comments
 (0)