Merge pull request #94 from chorus-ai/open_once_read_many
Open once read many
briangow authored Sep 20, 2024
2 parents 6e9e646 + 031aafa commit d78778d
Showing 9 changed files with 1,044 additions and 86 deletions.
34 changes: 33 additions & 1 deletion BENCHMARK.md
@@ -169,6 +169,38 @@ After loading the record written by `write_waveforms()`, you will need to return

Again, [`pickle.py`](./waveform_benchmark/formats/pickle.py) provides a simple example.

-#### 6. Add your new format to the GitHub repository

#### 6. Add an `open_waveforms()` method.

This method is used to test read performance when a record written by `write_waveforms()` is opened once and then read multiple times. The method takes two arguments, `path` and `signal_names`; see `read_waveforms()` in section 5 for their documentation.

The function should return `opened_files`, a `dict`. The exact keys and values in the dict are defined by the implementer. One example is a mapping from `filename` to file object (as returned by `open(filename)`). The output is passed directly as an argument to `read_opened_waveforms()`. Other types of data may also be stored in `opened_files`, including metadata and even the full signals, although this approach increases memory utilization, which is also benchmarked.

- `opened_files.keys()` -> `dict_keys(['file1', 'file2'])`
- `opened_files['file1']` -> `<_io.BufferedReader name='./wavetest-h5y7dasl/wavetest/WV000001'>`

Again, [`pickle.py`](./waveform_benchmark/formats/pickle.py) provides a simple example.
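
As a rough illustration of the `filename` to file object mapping described above, here is a minimal sketch. It is not the `pickle.py` implementation: the `ToyFormat` class, the one-pickle-file-per-signal layout, and the `.pkl` file names are all hypothetical choices made only so the example can run on its own.

```python
import os
import pickle
import tempfile


class ToyFormat:
    """Hypothetical format: one pickle file per signal, named '<signal>.pkl'."""

    def open_waveforms(self, path, signal_names, **kwargs):
        # Open each requested signal file once and keep the handle for later reads.
        opened_files = {}
        for name in signal_names:
            opened_files[name] = open(os.path.join(path, name + ".pkl"), "rb")
        return opened_files


with tempfile.TemporaryDirectory() as record_dir:
    # Write a tiny hypothetical record so the sketch runs end to end.
    for name in ("ECG", "ABP"):
        with open(os.path.join(record_dir, name + ".pkl"), "wb") as f:
            pickle.dump({"samples_per_second": 4.0, "samples": [0.0, 1.0, 2.0, 3.0]}, f)

    opened_files = ToyFormat().open_waveforms(record_dir, ["ECG", "ABP"])
    opened_keys = sorted(opened_files)
    all_open = all(not handle.closed for handle in opened_files.values())

    # Close the handles before the temporary directory is removed.
    for handle in opened_files.values():
        handle.close()
```

Keeping only file handles in the dict keeps memory use low; caching decoded samples instead would trade memory for faster repeated reads.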


#### 7. Add a `read_opened_waveforms()` method.

This method reads the record written by `write_waveforms()` after it has been opened with `open_waveforms()`. The method takes four arguments: `opened_files`, `start_time`, `end_time`, and `signal_names`. The last three have the same definition as for `read_waveforms()`; please refer to section 5.

- `opened_files` (`dict`) is the dictionary object that holds the relevant internal state and variables produced by your `open_waveforms()` method.
  Caching such data reduces overhead for repeated, consecutive read operations.

The function has the same `dict` return type as the `read_waveforms()` function; please refer to section 5 for details.

Again, [`pickle.py`](./waveform_benchmark/formats/pickle.py) provides a simple example.
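
The sketch below shows one way a reader could reuse an already opened handle for several reads. It assumes a hypothetical layout in which each handle holds a pickled dict with a `samples_per_second` value and a `samples` list, and it returns plain lists for brevity; the return type required by the benchmark is the one documented in section 5.

```python
import io
import pickle


def read_opened_waveforms(opened_files, start_time, end_time, signal_names):
    """Hypothetical reader: each handle holds a pickled dict with a rate and samples."""
    output = {}
    for name in signal_names:
        handle = opened_files[name]
        handle.seek(0)  # the handle is reused, so rewind before every read
        record = pickle.load(handle)
        rate = record["samples_per_second"]
        first = int(round(start_time * rate))
        last = int(round(end_time * rate))
        output[name] = record["samples"][first:last]
    return output


# An in-memory stand-in for a file opened by open_waveforms().
payload = pickle.dumps({"samples_per_second": 2.0, "samples": [10, 11, 12, 13, 14, 15]})
opened_files = {"ECG": io.BytesIO(payload)}

# Two reads against the same opened handle.
first_read = read_opened_waveforms(opened_files, 0.0, 1.0, ["ECG"])
second_read = read_opened_waveforms(opened_files, 1.0, 3.0, ["ECG"])
```

This toy reader still decodes the whole record on every call; a real format would normally seek to the requested samples instead.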

#### 8. Add a `close_waveforms()` method.

This method closes and cleans up any open files, internal state, and variables produced by `open_waveforms()`. The method takes a single argument, `opened_files`. See section 7, `read_opened_waveforms()`, for the argument description.

Again, [`pickle.py`](./waveform_benchmark/formats/pickle.py) provides a simple example.
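
A minimal cleanup sketch, assuming (hypothetically) that every value in `opened_files` is a file-like object with a `close()` method. If your `open_waveforms()` stores other state, release it here as well.

```python
import io


def close_waveforms(opened_files):
    """Hypothetical cleanup: close every handle, then drop all cached state."""
    for handle in opened_files.values():
        handle.close()
    opened_files.clear()


# In-memory stand-ins for files opened by open_waveforms().
handles = [io.BytesIO(b"a"), io.BytesIO(b"b")]
opened_files = {"file1": handles[0], "file2": handles[1]}
close_waveforms(opened_files)
```

Clearing the dict after closing means a later read against the same `opened_files` fails at the lookup rather than on a closed file.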



#### 9. Add your new format to the GitHub repository

Once you have created your new module, you should contribute it to the [GitHub repository](https://github.com/chorus-ai/chorus_waveform/) by opening a pull request.
2 changes: 1 addition & 1 deletion waveform_benchmark/__main__.py
@@ -121,7 +121,7 @@ def main():
test_list=test_list,
result_list=result_list,
test_only = opts.test_only,
-mem_profile = opts.memory_profile)
+mem_profile = opts.memory_profiling)

save_summary(format_list, waveform_list, test_list, result_list, opts.waveform_suite_summary_file)

416 changes: 334 additions & 82 deletions waveform_benchmark/benchmark.py

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions waveform_benchmark/formats/base.py
@@ -24,3 +24,26 @@ def write_waveforms(self, path: str, waveforms: dict):
    def read_waveforms(self, path: str, start_time: float, end_time: float,
                       signal_names: list):
        raise NotImplementedError

    # kwargs is a dictionary that can be used to pass additional arguments to the format
    # replaces the total_length, block_length, and block_size params which are either test specific or intrinsic to the format.
    @abc.abstractmethod
    def open_waveforms(self, path: str, signal_names:list, **kwargs):
        raise NotImplementedError

    @abc.abstractmethod
    def read_opened_waveforms(self, opened_files: dict, start_time: float, end_time: float,
                              signal_names: list):
        raise NotImplementedError

    @abc.abstractmethod
    def close_waveforms(self, opened_files: dict):
        raise NotImplementedError


    def open_read_close_waveforms(self, path, start_time, end_time, signal_names, **kwargs):
        opened_files = self.open_waveforms(path, signal_names, **kwargs)
        output = self.read_opened_waveforms(opened_files, start_time, end_time, signal_names)
        self.close_waveforms(opened_files)

        return output
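
The three abstract methods above support two access patterns: open once and read many times, or open, read, and close for every request. The sketch below contrasts them with a hypothetical in-memory `CountingFormat` class (not part of the repository) that only counts how often it is opened. Its `open_read_close_waveforms()` uses `try`/`finally`, which is a suggested variation and not what the commit does.

```python
class CountingFormat:
    """Hypothetical in-memory format that counts how often it is opened."""

    def __init__(self):
        self.open_calls = 0

    def open_waveforms(self, path, signal_names, **kwargs):
        self.open_calls += 1
        return {"samples": list(range(100)), "rate": 10.0}

    def read_opened_waveforms(self, opened_files, start_time, end_time, signal_names):
        rate = opened_files["rate"]
        first, last = int(start_time * rate), int(end_time * rate)
        return {name: opened_files["samples"][first:last] for name in signal_names}

    def close_waveforms(self, opened_files):
        opened_files.clear()

    def open_read_close_waveforms(self, path, start_time, end_time, signal_names, **kwargs):
        opened_files = self.open_waveforms(path, signal_names, **kwargs)
        try:
            return self.read_opened_waveforms(opened_files, start_time, end_time, signal_names)
        finally:
            # Runs even if the read raises, so handles are not leaked.
            self.close_waveforms(opened_files)


fmt = CountingFormat()

# Open once, read many: three reads share a single open.
opened = fmt.open_waveforms("unused-path", ["ECG"])
try:
    lengths = [
        len(fmt.read_opened_waveforms(opened, t, t + 1.0, ["ECG"])["ECG"])
        for t in (0.0, 2.0, 5.0)
    ]
finally:
    fmt.close_waveforms(opened)
opens_for_many_reads = fmt.open_calls

# Open per read: the convenience wrapper pays the open cost every time.
for t in (0.0, 2.0, 5.0):
    fmt.open_read_close_waveforms("unused-path", t, t + 1.0, ["ECG"])
opens_total = fmt.open_calls
```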
234 changes: 234 additions & 0 deletions waveform_benchmark/formats/dicom.py
@@ -634,6 +634,10 @@ def read_waveforms(self, path, start_time, end_time, signal_names):
        # then random access to read.
        t1 = time.time()

        if (not os.path.exists(path + "/DICOMDIR")):
            print("ERROR: DICOMDIR file not found")
            return {}

        ds = dcmread(path + "/DICOMDIR")

        file_info = {}
@@ -803,9 +807,239 @@ def read_waveforms(self, path, start_time, end_time, signal_names):
        return output



    # dicom value types are constrained by IOD type
    # https://dicom.nema.org/medical/dicom/current/output/chtml/part03/PS3.3.html

    def open_waveforms(self, path: str, signal_names:list, **kwargs):

        signal_set = set(signal_names)
        signal_set = {name.upper() : name for name in signal_set}
        # ========== read from dicomdir file
        # ideally - each file should have a list of channels inside, and the start and end time stamps.
        # but we may have to open each file and read to gather that info
        # read as file_set, then use metadata to get the table and see which files need to be accessed.
        # then random access to read.

        if (not os.path.exists(path + "/DICOMDIR")):
            print("ERROR: DICOMDIR file required but not found")
            return {}

        # each should use a different padding value.
        info_tags = ["WaveformSequence",
                     "MultiplexGroupTimeOffset",
                     "SamplingFrequency",
                     "WaveformPaddingValue",
                     "ChannelDefinitionSequence",
                     "ChannelSourceSequence",
                     "CodeMeaning",
                     ]

        ds = dcmread(path + "/DICOMDIR")

        file_info = {}
        for item in ds.DirectoryRecordSequence:
            # if there is private tag data, use it.

            if (item.DirectoryRecordType == "WAVEFORM") and (Tag(0x0099, 0x1001) in item):

                # keep same order, do not use set
                freqs = [float(x) for x in str.split(item[0x0099, 0x1011].value, sep = ',')]
                samples = [int(x) for x in str.split(item[0x0099, 0x1012].value, sep = ',')]
                stimes = [float(x) for x in str.split(item[0x0099, 0x1001].value, sep = ',')]
                etimes = [x + float(y) / z for x, y, z in zip(stimes, samples, freqs)]

                channels = str.split(item[0x0099, 0x1021].value, sep = ',')
                canonical_channels = [x.upper() for x in channels]
                group_ids = [ int(x) for x in str.split(item[0x0099, 0x1022].value, sep = ',')]
                chan_ids = [int(x) for x in str.split(item[0x0099, 0x1023].value, sep = ',')]

                # get group ids and set of channels for all available channels.
                for (chan, group_id, chan_id) in zip(channels, group_ids, chan_ids):
                    stime = stimes[group_id]
                    etime = etimes[group_id]

                    # filtering here reduces the number of files to open
                    if (chan.upper() not in signal_set.keys()):
                        continue

                    # only add key if this is a file to be opened.
                    if item.ReferencedFileID not in file_info.keys():
                        file_info[item.ReferencedFileID] = {}

                    if group_id not in file_info[item.ReferencedFileID].keys():
                        file_info[item.ReferencedFileID][group_id] = []

                    # original channel name
                    channel_info = {'channel': chan,
                                    'channel_idx': chan_id,
                                    'freq': freqs[group_id],
                                    'number_samples': samples[group_id],
                                    'start_time': stime,
                                    'end_time': etime}
                    file_info[item.ReferencedFileID][group_id].append(channel_info)

            else:
                # no metadata, so add mapping of None to indicate need to read metadata from file
                file_info[item.ReferencedFileID] = None



        # file_info contains either None (have to get from individual dicom file), or metadata for matched channel/time
        for file_name, finfo in file_info.items():
            fn = path + "/" + file_name

            read_meta_from_file = (finfo is None)

            # if metadata is in dicomdir, but no match by channel, then skip
            if (not read_meta_from_file) and (len(finfo) == 0):
                continue
            # else we open the file and save it.

            # if metadata is in dicomdir, then we have only required files in file_info.
            # if metadata is not in dicomdir, then all files are listed and metadata needs to be retrieved.
            # either way, need to read the file.
            fobj = open(fn, 'rb')

            # open the file
            ds = dcmread(fobj, defer_size = 1000, specific_tags = info_tags)
            seqs_raw = dcm_reader.get_tag(fobj, ds, 'WaveformSequence', defer_size = 1000)
            seqs = cast(list[Dataset], seqs_raw)

            if (not read_meta_from_file):
                # already has metadata, so just save the fobj and seqs
                finfo['fobj'] = fobj
                finfo['seqs'] = seqs
                continue

            # else need to parse the metadata and check channels.
            # finfo is None in this branch, so replace it with a dict before populating it.
            # (replacing the value of an existing key is safe while iterating over items().)
            finfo = {}
            file_info[file_name] = finfo
            for group_idx, seq in enumerate(seqs):
                # get the file metadata (can be saved in DICOMDIR in the future, but would need to change the channel metadata info.)
                channel_infos = dcm_reader.get_waveform_seq_info(fobj, seq)  # get channel info

                # iterate over the channel_infos now.
                for info in channel_infos:
                    chan = info['channel'].upper()

                    # channel not in selected. skip
                    if (chan not in signal_set.keys()):
                        continue
                    # else we should save it.
                    if group_idx not in finfo.keys():
                        finfo[group_idx] = []
                    finfo[group_idx].append(info)

            # keep the file open only if at least one requested channel matched
            # (the original test, (len(finfo) == 0) > 0, was inverted).
            if len(finfo) > 0:
                finfo['fobj'] = fobj
                finfo['seqs'] = seqs
            else:
                fobj.close()

        # drop files with no matching channels so every returned entry has 'fobj' and 'seqs'
        return {name: info for name, info in file_info.items() if info}
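
The DICOMDIR branch of `open_waveforms()` above decodes comma-separated private-tag values into per-group channel metadata. The sketch below mirrors that bookkeeping on plain strings so it can be traced without pydicom. The function name `parse_channel_metadata` and the sample values are hypothetical; only the resulting dictionary keys follow the code above.

```python
def parse_channel_metadata(freqs, samples, stimes, channels, group_ids, chan_ids, wanted):
    """Group per-channel metadata by multiplex group, keeping only wanted channels."""
    freqs = [float(x) for x in freqs.split(",")]
    samples = [int(x) for x in samples.split(",")]
    stimes = [float(x) for x in stimes.split(",")]
    # End time of each group = start time + number of samples / sampling frequency.
    etimes = [s + n / f for s, n, f in zip(stimes, samples, freqs)]
    wanted_upper = {name.upper() for name in wanted}

    groups = {}
    triples = zip(channels.split(","),
                  (int(x) for x in group_ids.split(",")),
                  (int(x) for x in chan_ids.split(",")))
    for chan, group_id, chan_id in triples:
        # Channel names are matched case-insensitively.
        if chan.upper() not in wanted_upper:
            continue
        groups.setdefault(group_id, []).append({
            "channel": chan,
            "channel_idx": chan_id,
            "freq": freqs[group_id],
            "number_samples": samples[group_id],
            "start_time": stimes[group_id],
            "end_time": etimes[group_id],
        })
    return groups


# Two multiplex groups: group 0 at 250 Hz with channels II and V, group 1 at 125 Hz with ABP.
groups = parse_channel_metadata(
    freqs="250,125", samples="2500,1250", stimes="0,0",
    channels="II,V,ABP", group_ids="0,0,1", chan_ids="0,1,0",
    wanted=["ii", "abp"],
)
```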



    def read_opened_waveforms(self, opened_files: dict, start_time: float, end_time: float, signal_names: list):

        signal_set = set(signal_names)
        signal_set = {name.upper() : name for name in signal_set}

        # each should use a different padding value.
        output = {}
        info_tags = ["WaveformSequence",
                     "MultiplexGroupTimeOffset",
                     "SamplingFrequency",
                     "WaveformPaddingValue",
                     "ChannelDefinitionSequence",
                     "ChannelSourceSequence",
                     "CodeMeaning",
                     ]
        # opened_files maps each file name to its matched channel metadata plus the cached 'fobj' and 'seqs'
        for file_name, finfo in opened_files.items():

            fobj = finfo['fobj']
            seqs = finfo['seqs']

            arrs = {}
            for group_idx, seq in enumerate(seqs):

                # if group is not in finfo, then channel did not match. skip.
                if group_idx not in finfo.keys():
                    continue

                channel_infos = finfo[group_idx]

                # iterate over the channel_infos now.
                for info in channel_infos:
                    channel = info['channel'].upper()

                    if (channel not in signal_set.keys()):
                        continue

                    # compute start and end offsets in the file using timestamps
                    freq = float(info['freq'])
                    max_len = int(np.round(end_time * freq)) - int(np.round(start_time * freq))

                    # get multiplex group time window
                    gstart = float(info['start_time'])
                    nsamples = int(info['number_samples'])
                    gend = gstart + float(nsamples) / freq

                    # calculate the intersection of the time window
                    win_start = max(gstart, start_time)
                    win_end = min(gend, end_time)

                    if (win_start >= win_end):
                        # window is not possible
                        continue
                    # else we have a valid window

                    # compute the start and end offset for the source and destination
                    start_offset = max(0, int(np.round(win_start * freq) - np.round(gstart * freq) ))
                    end_offset = min(nsamples, int(np.round(win_end * freq) - np.round(gstart * freq) ))
                    # compute the start and end offset in the output for this channel
                    target_start = max(0, int(np.round(win_start * freq) - np.round(start_time * freq) ))
                    target_end = min(max_len, int(np.round(win_end * freq) - np.round(start_time * freq) ))

                    # print(" start - end: src time ", gstart, gend, " target time ", start_time, end_time, " freq", freq, " window ", win_start, win_end, " samples : src ", start_offset, end_offset, " target ", target_start, target_end)

                    nsamps = min(end_offset - start_offset, target_end - target_start)
                    end_offset = start_offset + nsamps
                    target_end = target_start + nsamps

                    if nsamps <= 0:
                        continue
                    # else we have a valid window with positive number of samples

                    # get info about each channel present.
                    channel_idx = info['channel_idx']
                    requested_channel_name = signal_set[channel]
                    # load the data if never read. else use cached.
                    if group_idx not in arrs.keys():
                        item = cast(Dataset, seq)
                        arrs[group_idx] = dcm_reader.get_multiplex_array(fobj, item, start_offset, end_offset, as_raw = False)

                    # init the output if not previously allocated
                    if requested_channel_name not in output.keys():
                        output[requested_channel_name] = np.full(shape = max_len, fill_value = np.nan, dtype=np.float32)

                    # copy the data to the output
                    # print("copy ", arrs[group_idx].shape, " to ", output[channel].shape,
                    #       " from ", target_start, " to ", target_end)
                    new_vals = arrs[group_idx][channel_idx, 0:nsamps]
                    output[requested_channel_name][target_start:target_end] = np.where(np.isfinite(new_vals), new_vals, output[requested_channel_name][target_start:target_end])

        # now return output.
        return output
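
The time-window arithmetic in `read_opened_waveforms()` above can be traced with a small standalone sketch. It follows the same steps (intersect the group and request windows, convert to sample offsets, clamp, and equalize lengths) but uses Python's built-in `round()` in place of `np.round` so that it needs no dependencies. The function name `sample_window` and the numbers in the example are hypothetical.

```python
def sample_window(gstart, nsamples, freq, start_time, end_time):
    """Return (src_start, src_end, dst_start, dst_end) sample offsets, or None if no overlap."""
    # Length of the output buffer for the requested time range.
    max_len = int(round(end_time * freq)) - int(round(start_time * freq))
    gend = gstart + nsamples / freq

    # Intersection of the multiplex group's time span with the requested range.
    win_start = max(gstart, start_time)
    win_end = min(gend, end_time)
    if win_start >= win_end:
        return None

    # Offsets into the stored group (source) and into the output buffer (destination).
    start_offset = max(0, round(win_start * freq) - round(gstart * freq))
    end_offset = min(nsamples, round(win_end * freq) - round(gstart * freq))
    target_start = max(0, round(win_start * freq) - round(start_time * freq))
    target_end = min(max_len, round(win_end * freq) - round(start_time * freq))

    # Rounding can leave the two ranges one sample apart, so use the shorter one.
    nsamps = min(end_offset - start_offset, target_end - target_start)
    if nsamps <= 0:
        return None
    return start_offset, start_offset + nsamps, target_start, target_start + nsamps


# A group starting at t = 10 s with 1000 samples at 100 Hz (so it ends at t = 20 s).
overlap = sample_window(gstart=10.0, nsamples=1000, freq=100.0, start_time=5.0, end_time=12.0)
no_overlap = sample_window(gstart=10.0, nsamples=1000, freq=100.0, start_time=25.0, end_time=30.0)
```

In the first call the request covers 7 s (700 output samples) but the group only supplies the last 2 s, so source samples 0 to 200 land at output positions 500 to 700; the first 500 positions would stay at the `NaN` fill value.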



    def close_waveforms(self, opened_files: dict):
        for file_name, finfo in opened_files.items():

            finfo['fobj'].close()

        opened_files.clear()



class DICOMHighBits(BaseDICOMFormat):
    # waveform lead names to dicom IOD mapping. Incomplete.