Skip to content

Commit

Permalink
added TryFrom<String> and Display for ModelProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
erhant committed Jul 30, 2024
1 parent a45ba2d commit 90ff247
Showing 1 changed file with 61 additions and 22 deletions.
83 changes: 61 additions & 22 deletions src/program/atomics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,38 @@ pub enum Model {
GPT4oMini,
}

impl From<Model> for String {
fn from(model: Model) -> Self {
model.to_string() // via Display
}
}

impl fmt::Display for Model {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// guaranteed not to fail because this is enum to string serialization
let self_str = serde_json::to_string(&self).unwrap_or_default();

// remove quotes from JSON
write!(f, "{}", self_str.trim_matches('"'))
}
}

impl TryFrom<LocalModel> for Model {
type Error = String;
fn try_from(value: LocalModel) -> Result<Self, Self::Error> {
Model::try_from(value.name)
}
}

impl TryFrom<String> for Model {
type Error = String;
fn try_from(value: String) -> Result<Self, Self::Error> {
// serde requires quotes (for JSON)
serde_json::from_str::<Self>(&format!("\"{}\"", value))
.map_err(|e| format!("Model {} invalid: {}", value, e))
}
}

/// A model provider is a service that hosts the chosen Model.
/// It can be derived from the model name, e.g. GPT4o is hosted by OpenAI (via API), or Phi3 is hosted by Ollama (locally).
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
Expand All @@ -290,13 +322,16 @@ impl From<Model> for ModelProvider {
}
}

impl From<Model> for String {
fn from(model: Model) -> Self {
model.to_string() // via Display
impl TryFrom<String> for ModelProvider {
type Error = String;
fn try_from(value: String) -> Result<Self, Self::Error> {
// serde requires quotes (for JSON)
serde_json::from_str::<Self>(&format!("\"{}\"", value))
.map_err(|e| format!("Model provider {} invalid: {}", value, e))
}
}

impl fmt::Display for Model {
impl fmt::Display for ModelProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// guaranteed not to fail because this is enum to string serialization
let self_str = serde_json::to_string(&self).unwrap_or_default();
Expand All @@ -306,27 +341,12 @@ impl fmt::Display for Model {
}
}

impl TryFrom<LocalModel> for Model {
type Error = String;
fn try_from(value: LocalModel) -> Result<Self, Self::Error> {
Model::try_from(value.name)
}
}

impl TryFrom<String> for Model {
type Error = String;
fn try_from(value: String) -> Result<Self, Self::Error> {
// serde requires quotes (for JSON)
serde_json::from_str::<Self>(&format!("\"{}\"", value))
.map_err(|e| format!("Model {} invalid: {}", value, e))
}
}

#[cfg(test)]
mod tests {
use super::*;

const MODEL_NAME: &str = "phi3:3.8b";
const PROVIDER_NAME: &str = "openai";
#[test]
fn test_model_string_conversion() {
let model = Model::Phi3Mini;
Expand Down Expand Up @@ -357,7 +377,26 @@ mod tests {
assert_eq!(model_from, model);

// (try) deserialize from invalid model
let model = serde_json::from_str::<Model>("\"this-model-does-not-will-not-exist\"");
assert!(model.is_err());
let bad_model = serde_json::from_str::<Model>("\"this-model-does-not-will-not-exist\"");
assert!(bad_model.is_err());
}

#[test]
fn test_provider_string_serde() {
let provider = ModelProvider::OpenAI;

// serialize to string via serde
let provider_str = serde_json::to_string(&provider).expect("should serialize");
assert_eq!(provider_str, format!("\"{}\"", PROVIDER_NAME));

// deserialize from string via serde
let provider_from: ModelProvider =
serde_json::from_str(&provider_str).expect("should deserialize");
assert_eq!(provider_from, provider);

// (try) deserialize from invalid model
let bad_provider =
serde_json::from_str::<ModelProvider>("\"this-provider-does-not-will-not-exist\"");
assert!(bad_provider.is_err());
}
}

0 comments on commit 90ff247

Please sign in to comment.