Skip to content

Commit

Permalink
Add some support for relative stream notation
Browse files Browse the repository at this point in the history
Summary: We recently added the "TTT+NN" notation to reference the NNth stream of recordable type ID TTT in command line tools. Since pyvrs uses strings to designate streams, we can make pyvrs APIs take advantage of that notation.

Reviewed By: kiminoue7

Differential Revision: D52295114

fbshipit-source-id: 78ac3a4f87e2292c1013ef0ccddbfbe197a28dde
  • Loading branch information
Georges Berenger authored and facebook-github-bot committed Dec 19, 2023
1 parent 2b905e6 commit 0d32497
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 25 deletions.
25 changes: 14 additions & 11 deletions csrc/reader/MultiVRSReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,10 @@ std::vector<int32_t> OssMultiVRSReader::regenerateEnabledIndices(
}

for (const auto& streamId : streamIds) {
const StreamId id = StreamId::fromNumericName(streamId);
const StreamId id = reader_.getStreamForName(streamId);
if (!id.isValid()) {
throw StreamNotFoundError(streamId, reader_.getStreams());
}
streamIdSet.insert(id);
}

Expand Down Expand Up @@ -839,7 +842,11 @@ string OssMultiVRSReader::getStreamIdForIndex(int recordIndex) {
}

string OssMultiVRSReader::getSerialNumberForStream(const string& streamId) const {
return reader_.getSerialNumber(UniqueStreamId::fromNumericName(streamId));
const StreamId id = reader_.getStreamForName(streamId);
if (!id.isValid()) {
throw StreamNotFoundError(streamId, reader_.getStreams());
}
return reader_.getSerialNumber(id);
}

string OssMultiVRSReader::getStreamForSerialNumber(const string& streamSerialNumber) const {
Expand Down Expand Up @@ -1031,16 +1038,12 @@ void OssMultiVRSReader::initRecordSummaries() {
}

StreamId OssMultiVRSReader::getStreamId(const string& streamId) {
// Quick parsing of "NNN-DDD", two uint numbers separated by a '-'.
const StreamId id = StreamId::fromNumericName(streamId);
const auto& recordables = reader_.getStreams();
if (id.getTypeId() == RecordableTypeId::Undefined) {
throw py::value_error("Invalid stream ID: " + streamId);
}
if (recordables.find(id) != recordables.end()) {
return id;
// "NNN-DDD" or "NNN+DDD", two uint numbers separated by a '-' or '+'.
const StreamId id = reader_.getStreamForName(streamId);
if (!id.isValid()) {
throw StreamNotFoundError(streamId, reader_.getStreams());
}
throw StreamNotFoundError(id.getTypeId(), recordables);
return id;
}

PyObject* OssMultiVRSReader::getRecordInfo(
Expand Down
25 changes: 14 additions & 11 deletions csrc/reader/VRSReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,10 @@ vector<int32_t> OssVRSReader::regenerateEnabledIndices(
}

for (const auto& streamId : streamIds) {
const StreamId id = StreamId::fromNumericName(streamId);
const StreamId id = reader_.getStreamForName(streamId);
if (!id.isValid()) {
throw StreamNotFoundError(streamId, reader_.getStreams());
}
streamIdSet.insert(id);
}

Expand Down Expand Up @@ -875,7 +878,11 @@ string OssVRSReader::getStreamIdForIndex(int recordIndex) {
}

string OssVRSReader::getSerialNumberForStream(const string& streamId) const {
return reader_.getSerialNumber(StreamId::fromNumericName(streamId));
const StreamId id = reader_.getStreamForName(streamId);
if (!id.isValid()) {
throw StreamNotFoundError(streamId, reader_.getStreams());
}
return reader_.getSerialNumber(id);
}

string OssVRSReader::getStreamForSerialNumber(const string& streamSerialNumber) const {
Expand Down Expand Up @@ -1007,16 +1014,12 @@ void OssVRSReader::initRecordSummaries() {
}

StreamId OssVRSReader::getStreamId(const string& streamId) {
// Quick parsing of "NNN-DDD", two uint numbers separated by a '-'.
const StreamId id = StreamId::fromNumericName(streamId);
const auto& recordables = reader_.getStreams();
if (id.getTypeId() == RecordableTypeId::Undefined) {
throw py::value_error("Invalid stream ID: " + streamId);
}
if (recordables.find(id) != recordables.end()) {
return id;
// "NNN-DDD" or "NNN+DDD", two uint numbers separated by a '-' or '+'.
const StreamId id = reader_.getStreamForName(streamId);
if (!id.isValid()) {
throw StreamNotFoundError(streamId, reader_.getStreams());
}
throw StreamNotFoundError(id.getTypeId(), recordables);
return id;
}

bool OssVRSReader::match(
Expand Down
8 changes: 5 additions & 3 deletions csrc/utils/PyExceptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,13 @@ class StreamNotFoundError : public std::exception {
public:
explicit StreamNotFoundError(
vrs::RecordableTypeId recordableTypeId,
const std::set<vrs::StreamId>& availableStreamIds)
: StreamNotFoundError(vrs::toString(recordableTypeId), availableStreamIds) {}
explicit StreamNotFoundError(
const std::string& streamId,
const std::set<vrs::StreamId>& availableStreamIds) {
std::stringstream ss;
ss << fmt::format(
"Matching stream not found for ID {0}. Available streams are:\n",
vrs::toString(recordableTypeId));
ss << fmt::format("No matching stream for {0}. Available streams are:\n", streamId);
for (auto it : availableStreamIds) {
ss << it.getName() << "\n";
}
Expand Down

0 comments on commit 0d32497

Please sign in to comment.