Small changes: code reuse, simplify, doc comments, client timeouts #25
Signed-off-by: declark1 <[email protected]>
@@ -92,7 +94,8 @@ async fn stream_classification_with_gen(
    State(state): State<Arc<ServerState>>,
    Json(request): Json<models::GuardrailsHttpRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<String>)> {
    let task = StreamingClassificationWithGenTask::new(request);
    let request_id = Uuid::new_v4();
Ideally we would get a request ID / transaction ID from our upstream, include it in the logs for tracking purposes, and use this generated value only as a default when none is provided. Would you expect any issues with replacing this request_id later on? Since you have already added it to the downstream functions, the change would essentially be replacing request_id with something from the request (although it may be in a header, so we may need to do some processing).
We haven't instrumented tracing here yet, but I think the trace-id should probably be used for tracking across services? We should probably still have a unique service-level request identifier though, which is what this is for.
If there are additional transaction id(s) provided from upstream services via headers, we can still extract and record them. I haven't seen anything documented on this to know what to expect.
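The fallback discussed in this thread could look roughly like the sketch below. It is a std-only illustration under stated assumptions: the `x-request-id` header name, the `resolve_request_id` helper, and the plain `HashMap` standing in for the real header map are all hypothetical and not part of this PR. In the handler, the default closure would presumably wrap `Uuid::new_v4()`.

```rust
use std::collections::HashMap;

/// Hypothetical header name; nothing is documented upstream yet.
const REQUEST_ID_HEADER: &str = "x-request-id";

/// Returns the upstream-provided id if it is present and non-blank,
/// otherwise calls `default_id` to generate a service-level one.
/// Assumes header names in the map are already lowercased.
fn resolve_request_id<F>(headers: &HashMap<String, String>, default_id: F) -> String
where
    F: FnOnce() -> String,
{
    headers
        .get(REQUEST_ID_HEADER)
        .map(|value| value.trim())
        .filter(|value| !value.is_empty())
        .map(str::to_owned)
        .unwrap_or_else(default_id)
}
```

If both identifiers are wanted, the upstream value could instead be recorded as an additional log field next to the service-level request_id rather than replacing it.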
    model_id = %task.model_id,
    config = ?task.guardrails_config,
    "handling task"
);
Question: can we still tell from this info log, compared with the streaming one, whether this is the "unary" task or the "streaming" task?
Currently, no. I was planning to add detail once we implement streaming and/or additional methods.
I could make this simply "handling unary task" and "handling streaming task", but I assumed additional task types could be added. I was trying to avoid "handling classification with gen task" and "handling streaming classification with gen task", as these are verbose.
Ideally, these API method names (and the associated task handlers) can be renamed/simplified (generate and generate_stream would be nice).
I'll change it to "handling unary task" and "handling streaming task" for now.
LGTM!
This PR is a batch of several small changes:
- Added input_masks(), input_detectors(), and output_detectors() helper methods to GuardrailsConfig to reuse code shared by unary and streaming task handlers.
- Renamed chunk_and_detect() to detect(), consistent with chunk().
- Simplified detect() and chunk() code to be easier to follow for new Rustaceans
- [...] tasks and separately, try_join_all() is used to await them, collecting results to results
- Removed tokio::spawn from handle_chunk_task() and handle_detection_task()
- Added request_id to new requests
- [...] request_id and exclude inputs text (also addresses #18)
- [...] test_deserialize_config() unit test (tls is currently a required field, although it should be Option)
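As a rough illustration of the GuardrailsConfig helper methods mentioned above, accessors like these would let the unary and streaming handlers share the same lookup code. Only the three method names come from this PR. The struct layout (the `input`, `output`, `masks`, and `detectors` fields, and the use of `String` detector ids) is hypothetical and chosen just to keep the sketch self-contained.

```rust
/// Hypothetical shapes; the real types in the crate may differ.
#[derive(Debug, Default, Clone)]
struct GuardrailsConfigInput {
    masks: Option<Vec<(usize, usize)>>,
    detectors: Vec<String>,
}

#[derive(Debug, Default, Clone)]
struct GuardrailsConfigOutput {
    detectors: Vec<String>,
}

#[derive(Debug, Default, Clone)]
struct GuardrailsConfig {
    input: Option<GuardrailsConfigInput>,
    output: Option<GuardrailsConfigOutput>,
}

impl GuardrailsConfig {
    /// Input masks, if an input section with masks is configured.
    fn input_masks(&self) -> Option<&[(usize, usize)]> {
        self.input.as_ref().and_then(|input| input.masks.as_deref())
    }

    /// Detectors to run on the input text, if an input section exists.
    fn input_detectors(&self) -> Option<&[String]> {
        self.input.as_ref().map(|input| input.detectors.as_slice())
    }

    /// Detectors to run on the generated output, if an output section exists.
    fn output_detectors(&self) -> Option<&[String]> {
        self.output.as_ref().map(|output| output.detectors.as_slice())
    }
}
```

Returning borrowed slices wrapped in `Option` keeps the "section not configured" case distinct from "configured but empty", which is one possible design; callers that do not care could collapse it with `unwrap_or_default()`.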