Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small changes: code reuse, simplify, doc comments, client timeouts #25

Merged
merged 9 commits into from
May 9, 2024
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
url = "2.5.0"
validator = { version = "0.18.1", features = ["derive"] } # For API validation
uuid = { version = "1.8.0", features = ["v4", "fast-rng"] }

[build-dependencies]
tonic-build = "0.11.0"
Expand Down
12 changes: 9 additions & 3 deletions src/clients.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#![allow(dead_code)]
use std::collections::HashMap;
use std::{collections::HashMap, time::Duration};

use futures::future::try_join_all;
use ginepro::LoadBalancedChannel;
Expand All @@ -19,6 +19,8 @@ pub use nlp::NlpClient;
pub const DEFAULT_TGIS_PORT: u16 = 8033;
pub const DEFAULT_CAIKIT_NLP_PORT: u16 = 8085;
pub const DEFAULT_DETECTOR_PORT: u16 = 8080;
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);

#[derive(Debug, thiserror::Error)]
pub enum Error {
Expand Down Expand Up @@ -72,7 +74,9 @@ pub async fn create_http_clients(
let port = service_config.port.unwrap_or(default_port);
let mut base_url = Url::parse(&service_config.hostname).unwrap();
base_url.set_port(Some(port)).unwrap();
let mut builder = reqwest::ClientBuilder::new();
let mut builder = reqwest::ClientBuilder::new()
.connect_timeout(DEFAULT_CONNECT_TIMEOUT)
.timeout(DEFAULT_REQUEST_TIMEOUT);
if let Some(Tls::Config(tls_config)) = &service_config.tls {
let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
let cert_pem = tokio::fs::read(cert_path).await.unwrap_or_else(|error| {
Expand Down Expand Up @@ -100,7 +104,9 @@ async fn create_grpc_clients<C>(
let mut builder = LoadBalancedChannel::builder((
service_config.hostname.clone(),
service_config.port.unwrap_or(default_port),
));
))
.connect_timeout(DEFAULT_CONNECT_TIMEOUT)
.timeout(DEFAULT_REQUEST_TIMEOUT);
let client_tls_config = if let Some(Tls::Config(tls_config)) = &service_config.tls {
let cert_path = tls_config.cert_path.as_ref().unwrap().as_path();
let key_path = tls_config.key_path.as_ref().unwrap().as_path();
Expand Down
1 change: 1 addition & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ detectors:
port: 9000
chunker_id: sentence-en
config: {}
tls: {}
"#;
let config: OrchestratorConfig = serde_yml::from_str(s)?;
assert!(config.chunkers.len() == 2 && config.detectors.len() == 1);
Expand Down
16 changes: 16 additions & 0 deletions src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@ pub struct GuardrailsConfig {
pub output: Option<GuardrailsConfigOutput>,
}

impl GuardrailsConfig {
pub fn input_masks(&self) -> Option<&[(usize, usize)]> {
self.input.as_ref().and_then(|input| input.masks.as_deref())
}

pub fn input_detectors(&self) -> Option<&HashMap<String, DetectorParams>> {
self.input.as_ref().and_then(|input| input.models.as_ref())
}

pub fn output_detectors(&self) -> Option<&HashMap<String, DetectorParams>> {
self.output
.as_ref()
.and_then(|output| output.models.as_ref())
}
}

/// Configuration for detection on input to a text generation model (e.g. user prompt)
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize, validator::Validate)]
pub struct GuardrailsConfigInput {
Expand Down
Loading