From 0d324973c72cf9cf631a2b1b0370c01ee7a9b563 Mon Sep 17 00:00:00 2001 From: Georges Berenger Date: Tue, 19 Dec 2023 13:58:08 -0800 Subject: [PATCH] Add some support for relative stream notation Summary: We recently added the "TTT+NN" notation to reference the NNth stream of recordable type ID TTT in command line tools. Since pyvrs uses strings to designate streams, we can make pyvrs APIs take advantage of that notation. Reviewed By: kiminoue7 Differential Revision: D52295114 fbshipit-source-id: 78ac3a4f87e2292c1013ef0ccddbfbe197a28dde --- csrc/reader/MultiVRSReader.cpp | 25 ++++++++++++++----------- csrc/reader/VRSReader.cpp | 25 ++++++++++++++----------- csrc/utils/PyExceptions.h | 8 +++++--- 3 files changed, 33 insertions(+), 25 deletions(-) diff --git a/csrc/reader/MultiVRSReader.cpp b/csrc/reader/MultiVRSReader.cpp index 5506dd8..c8e9608 100644 --- a/csrc/reader/MultiVRSReader.cpp +++ b/csrc/reader/MultiVRSReader.cpp @@ -804,7 +804,10 @@ std::vector OssMultiVRSReader::regenerateEnabledIndices( } for (const auto& streamId : streamIds) { - const StreamId id = StreamId::fromNumericName(streamId); + const StreamId id = reader_.getStreamForName(streamId); + if (!id.isValid()) { + throw StreamNotFoundError(streamId, reader_.getStreams()); + } streamIdSet.insert(id); } @@ -839,7 +842,11 @@ string OssMultiVRSReader::getStreamIdForIndex(int recordIndex) { } string OssMultiVRSReader::getSerialNumberForStream(const string& streamId) const { - return reader_.getSerialNumber(UniqueStreamId::fromNumericName(streamId)); + const StreamId id = reader_.getStreamForName(streamId); + if (!id.isValid()) { + throw StreamNotFoundError(streamId, reader_.getStreams()); + } + return reader_.getSerialNumber(id); } string OssMultiVRSReader::getStreamForSerialNumber(const string& streamSerialNumber) const { @@ -1031,16 +1038,12 @@ void OssMultiVRSReader::initRecordSummaries() { } StreamId OssMultiVRSReader::getStreamId(const string& streamId) { - // Quick parsing of "NNN-DDD", two uint numbers separated by a '-'. - const StreamId id = StreamId::fromNumericName(streamId); - const auto& recordables = reader_.getStreams(); - if (id.getTypeId() == RecordableTypeId::Undefined) { - throw py::value_error("Invalid stream ID: " + streamId); - } - if (recordables.find(id) != recordables.end()) { - return id; + // "NNN-DDD" or "NNN+DDD", two uint numbers separated by a '-' or '+'. + const StreamId id = reader_.getStreamForName(streamId); + if (!id.isValid()) { + throw StreamNotFoundError(streamId, reader_.getStreams()); } - throw StreamNotFoundError(id.getTypeId(), recordables); + return id; } PyObject* OssMultiVRSReader::getRecordInfo( diff --git a/csrc/reader/VRSReader.cpp b/csrc/reader/VRSReader.cpp index e5c0da5..c64e419 100644 --- a/csrc/reader/VRSReader.cpp +++ b/csrc/reader/VRSReader.cpp @@ -837,7 +837,10 @@ vector OssVRSReader::regenerateEnabledIndices( } for (const auto& streamId : streamIds) { - const StreamId id = StreamId::fromNumericName(streamId); + const StreamId id = reader_.getStreamForName(streamId); + if (!id.isValid()) { + throw StreamNotFoundError(streamId, reader_.getStreams()); + } streamIdSet.insert(id); } @@ -875,7 +878,11 @@ string OssVRSReader::getStreamIdForIndex(int recordIndex) { } string OssVRSReader::getSerialNumberForStream(const string& streamId) const { - return reader_.getSerialNumber(StreamId::fromNumericName(streamId)); + const StreamId id = reader_.getStreamForName(streamId); + if (!id.isValid()) { + throw StreamNotFoundError(streamId, reader_.getStreams()); + } + return reader_.getSerialNumber(id); } string OssVRSReader::getStreamForSerialNumber(const string& streamSerialNumber) const { @@ -1007,16 +1014,12 @@ void OssVRSReader::initRecordSummaries() { } StreamId OssVRSReader::getStreamId(const string& streamId) { - // Quick parsing of "NNN-DDD", two uint numbers separated by a '-'. - const StreamId id = StreamId::fromNumericName(streamId); - const auto& recordables = reader_.getStreams(); - if (id.getTypeId() == RecordableTypeId::Undefined) { - throw py::value_error("Invalid stream ID: " + streamId); - } - if (recordables.find(id) != recordables.end()) { - return id; + // "NNN-DDD" or "NNN+DDD", two uint numbers separated by a '-' or '+'. + const StreamId id = reader_.getStreamForName(streamId); + if (!id.isValid()) { + throw StreamNotFoundError(streamId, reader_.getStreams()); } - throw StreamNotFoundError(id.getTypeId(), recordables); + return id; } bool OssVRSReader::match( diff --git a/csrc/utils/PyExceptions.h b/csrc/utils/PyExceptions.h index 26841e6..5c08dd5 100644 --- a/csrc/utils/PyExceptions.h +++ b/csrc/utils/PyExceptions.h @@ -72,11 +72,13 @@ class StreamNotFoundError : public std::exception { public: explicit StreamNotFoundError( vrs::RecordableTypeId recordableTypeId, + const std::set& availableStreamIds) + : StreamNotFoundError(vrs::toString(recordableTypeId), availableStreamIds) {} + explicit StreamNotFoundError( + const std::string& streamId, const std::set& availableStreamIds) { std::stringstream ss; - ss << fmt::format( - "Matching stream not found for ID {0}. Available streams are:\n", - vrs::toString(recordableTypeId)); + ss << fmt::format("No matching stream for {0}. Available streams are:\n", streamId); for (auto it : availableStreamIds) { ss << it.getName() << "\n"; }