diff --git a/neo/rawio/micromedrawio.py b/neo/rawio/micromedrawio.py index 52dd222bd..aa9bae776 100644 --- a/neo/rawio/micromedrawio.py +++ b/neo/rawio/micromedrawio.py @@ -52,7 +52,6 @@ def __init__(self, filename=""): def _parse_header(self): - self._buffer_descriptions = {0: {0: {}}} with open(self.filename, "rb") as fid: f = StructFile(fid) @@ -67,6 +66,7 @@ def _parse_header(self): rec_datetime = datetime.datetime(year + 1900, month, day, hour, minute, sec) Data_Start_Offset, Num_Chan, Multiplexer, Rate_Min, Bytes = f.read_f("IHHHH", offset=138) + sig_dtype = "u" + str(Bytes) # header version (header_version,) = f.read_f("b", offset=175) @@ -99,25 +99,37 @@ def _parse_header(self): if zname != zname2.decode("ascii").strip(" "): raise NeoReadWriteError("expected the zone name to match") - # raw signals memmap - sig_dtype = "u" + str(Bytes) - signal_shape = get_memmap_shape(self.filename, sig_dtype, num_channels=Num_Chan, offset=Data_Start_Offset) - buffer_id = "0" - stream_id = "0" - self._buffer_descriptions[0][0][buffer_id] = { - "type": "raw", - "file_path": str(self.filename), - "dtype": sig_dtype, - "order": "C", - "file_offset": 0, - "shape": signal_shape, - } + + # "TRONCA" zone define segments + zname2, pos, length = zones["TRONCA"] + f.seek(pos) + # this number avoid a infinite loop in case of corrupted TRONCA zone (seg_start!=0 and trace_offset!=0) + max_segments = 100 + self.info_segments = [] + for i in range(max_segments): + # 4 bytes u4 each + seg_start = int(np.frombuffer(f.read(4), dtype="u4")[0]) + trace_offset = int(np.frombuffer(f.read(4), dtype="u4")[0]) + if seg_start == 0 and trace_offset == 0: + break + else: + self.info_segments.append((seg_start, trace_offset)) + + if len(self.info_segments) == 0: + # one unique segment = general case + self.info_segments.append((0, 0)) + + nb_segment = len(self.info_segments) # Reading Code Info zname2, pos, length = zones["ORDER"] f.seek(pos) code = np.frombuffer(f.read(Num_Chan * 2), dtype="u2") + # unique stream and buffer + buffer_id = "0" + stream_id = "0" + units_code = {-1: "nV", 0: "uV", 1: "mV", 2: 1, 100: "percent", 101: "dimensionless", 102: "dimensionless"} signal_channels = [] sig_grounds = [] @@ -140,10 +152,8 @@ def _parse_header(self): (sampling_rate,) = f.read_f("H") sampling_rate *= Rate_Min chan_id = str(c) + signal_channels.append((chan_name, chan_id, sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id)) - signal_channels.append( - (chan_name, chan_id, sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id) - ) signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype) @@ -155,6 +165,32 @@ def _parse_header(self): raise NeoReadWriteError("The sampling rates must be the same across signal channels") self._sampling_rate = float(np.unique(signal_channels["sampling_rate"])[0]) + # memmap traces buffer + full_signal_shape = get_memmap_shape(self.filename, sig_dtype, num_channels=Num_Chan, offset=Data_Start_Offset) + seg_limits = [trace_offset for seg_start, trace_offset in self.info_segments] + [full_signal_shape[0]] + self._t_starts = [] + self._buffer_descriptions = {0 :{}} + for seg_index in range(nb_segment): + seg_start, trace_offset = self.info_segments[seg_index] + self._t_starts.append(seg_start / self._sampling_rate) + + start = seg_limits[seg_index] + stop = seg_limits[seg_index + 1] + + shape = (stop - start, Num_Chan) + file_offset = Data_Start_Offset + ( start * np.dtype(sig_dtype).itemsize * Num_Chan) + self._buffer_descriptions[0][seg_index] = {} + self._buffer_descriptions[0][seg_index][buffer_id] = { + "type" : "raw", + "file_path" : str(self.filename), + "dtype" : sig_dtype, + "order": "C", + "file_offset" : file_offset, + "shape" : shape, + } + + + # Event channels event_channels = [] event_channels.append(("Trigger", "", "event")) @@ -176,13 +212,18 @@ def _parse_header(self): dtype = np.dtype(ev_dtype) rawevent = np.memmap(self.filename, dtype=dtype, mode="r", offset=pos, shape=length // dtype.itemsize) - keep = ( - (rawevent["start"] >= rawevent["start"][0]) - & (rawevent["start"] < signal_shape[0]) - & (rawevent["start"] != 0) - ) - rawevent = rawevent[keep] - self._raw_events.append(rawevent) + # important : all events timing are related to the first segment t_start + self._raw_events.append([]) + for seg_index in range(nb_segment): + left_lim = seg_limits[seg_index] + right_lim = seg_limits[seg_index + 1] + keep = ( + (rawevent["start"] >= left_lim) + & (rawevent["start"] < right_lim) + & (rawevent["start"] != 0) + ) + self._raw_events[-1].append(rawevent[keep]) + # No spikes spike_channels = [] @@ -191,7 +232,7 @@ def _parse_header(self): # fille into header dict self.header = {} self.header["nb_block"] = 1 - self.header["nb_segment"] = [1] + self.header["nb_segment"] = [nb_segment] self.header["signal_buffers"] = signal_buffers self.header["signal_streams"] = signal_streams self.header["signal_channels"] = signal_channels @@ -216,38 +257,40 @@ def _source_name(self): return self.filename def _segment_t_start(self, block_index, seg_index): - return 0.0 + return self._t_starts[seg_index] def _segment_t_stop(self, block_index, seg_index): - sig_size = self.get_signal_size(block_index, seg_index, 0) - t_stop = sig_size / self._sampling_rate - return t_stop + duration = self.get_signal_size(block_index, seg_index, stream_index=0) / self._sampling_rate + return duration + self.segment_t_start(block_index, seg_index) def _get_signal_t_start(self, block_index, seg_index, stream_index): - if stream_index != 0: - raise ValueError("`stream_index` must be 0") - return 0.0 + assert stream_index == 0 + return self._t_starts[seg_index] def _spike_count(self, block_index, seg_index, unit_index): return 0 def _event_count(self, block_index, seg_index, event_channel_index): - n = self._raw_events[event_channel_index].size + n = self._raw_events[event_channel_index][seg_index].size return n def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): - raw_event = self._raw_events[event_channel_index] + raw_event = self._raw_events[event_channel_index][seg_index] + + # important : all events timing are related to the first segment t_start + seg_start0, _ = self.info_segments[0] if t_start is not None: - keep = raw_event["start"] >= int(t_start * self._sampling_rate) + keep = raw_event["start"] + seg_start0 >= int(t_start * self._sampling_rate) raw_event = raw_event[keep] if t_stop is not None: - keep = raw_event["start"] <= int(t_stop * self._sampling_rate) + keep = raw_event["start"] + seg_start0 <= int(t_stop * self._sampling_rate) raw_event = raw_event[keep] - timestamp = raw_event["start"] + timestamp = raw_event["start"] + seg_start0 + if event_channel_index < 2: durations = None else: diff --git a/neo/test/rawiotest/test_micromedrawio.py b/neo/test/rawiotest/test_micromedrawio.py index c74ea857c..1d7f616e1 100644 --- a/neo/test/rawiotest/test_micromedrawio.py +++ b/neo/test/rawiotest/test_micromedrawio.py @@ -8,6 +8,7 @@ from neo.test.rawiotest.common_rawio_test import BaseTestRawIO +import numpy as np class TestMicromedRawIO( BaseTestRawIO, @@ -15,7 +16,42 @@ class TestMicromedRawIO( ): rawioclass = MicromedRawIO entities_to_download = ["micromed"] - entities_to_test = ["micromed/File_micromed_1.TRC"] + entities_to_test = [ + "micromed/File_micromed_1.TRC", + "micromed/File_mircomed2.TRC", + "micromed/File_mircomed2_2segments.TRC", + ] + + def test_micromed_multi_segments(self): + file_full = self.get_local_path("micromed/File_mircomed2.TRC") + file_splitted = self.get_local_path("micromed/File_mircomed2_2segments.TRC") + + # the second file contains 2 pieces of the first file + # so it is 2 segments with the same traces but reduced + # note that traces in the splited can differ at the very end of the cut + + reader1 = MicromedRawIO(file_full) + reader1.parse_header() + assert reader1.segment_count(block_index=0) == 1 + assert reader1.get_signal_t_start(block_index=0, seg_index=0, stream_index=0) == 0. + traces1 = reader1.get_analogsignal_chunk(stream_index=0) + + reader2 = MicromedRawIO(file_splitted) + reader2.parse_header() + print(reader2) + assert reader2.segment_count(block_index=0) == 2 + + # check that pieces of the second file is equal to the first file (except a truncation at the end) + for seg_index in range(2): + t_start = reader2.get_signal_t_start(block_index=0, seg_index=seg_index, stream_index=0) + assert t_start > 0 + sr = reader2.get_signal_sampling_rate(stream_index=0) + ind_start = int(t_start * sr) + traces2 = reader2.get_analogsignal_chunk(block_index=0, seg_index=seg_index, stream_index=0) + traces1_chunk = traces1[ind_start: ind_start+traces2.shape[0]] + # we remove the last 100 sample because tools that cut traces is truncating the last buffer + assert np.array_equal(traces2[:-100], traces1_chunk[:-100]) + if __name__ == "__main__":