diff --git a/src/program/atomics.rs b/src/program/atomics.rs index 8417805..a6f8379 100644 --- a/src/program/atomics.rs +++ b/src/program/atomics.rs @@ -264,6 +264,38 @@ pub enum Model { GPT4oMini, } +impl From for String { + fn from(model: Model) -> Self { + model.to_string() // via Display + } +} + +impl fmt::Display for Model { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // guaranteed not to fail because this is enum to string serialization + let self_str = serde_json::to_string(&self).unwrap_or_default(); + + // remove quotes from JSON + write!(f, "{}", self_str.trim_matches('"')) + } +} + +impl TryFrom for Model { + type Error = String; + fn try_from(value: LocalModel) -> Result { + Model::try_from(value.name) + } +} + +impl TryFrom for Model { + type Error = String; + fn try_from(value: String) -> Result { + // serde requires quotes (for JSON) + serde_json::from_str::(&format!("\"{}\"", value)) + .map_err(|e| format!("Model {} invalid: {}", value, e)) + } +} + /// A model provider is a service that hosts the chosen Model. /// It can be derived from the model name, e.g. GPT4o is hosted by OpenAI (via API), or Phi3 is hosted by Ollama (locally). #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] @@ -290,13 +322,16 @@ impl From for ModelProvider { } } -impl From for String { - fn from(model: Model) -> Self { - model.to_string() // via Display +impl TryFrom for ModelProvider { + type Error = String; + fn try_from(value: String) -> Result { + // serde requires quotes (for JSON) + serde_json::from_str::(&format!("\"{}\"", value)) + .map_err(|e| format!("Model provider {} invalid: {}", value, e)) } } -impl fmt::Display for Model { +impl fmt::Display for ModelProvider { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // guaranteed not to fail because this is enum to string serialization let self_str = serde_json::to_string(&self).unwrap_or_default(); @@ -306,27 +341,12 @@ impl fmt::Display for Model { } } -impl TryFrom for Model { - type Error = String; - fn try_from(value: LocalModel) -> Result { - Model::try_from(value.name) - } -} - -impl TryFrom for Model { - type Error = String; - fn try_from(value: String) -> Result { - // serde requires quotes (for JSON) - serde_json::from_str::(&format!("\"{}\"", value)) - .map_err(|e| format!("Model {} invalid: {}", value, e)) - } -} - #[cfg(test)] mod tests { use super::*; const MODEL_NAME: &str = "phi3:3.8b"; + const PROVIDER_NAME: &str = "openai"; #[test] fn test_model_string_conversion() { let model = Model::Phi3Mini; @@ -357,7 +377,26 @@ mod tests { assert_eq!(model_from, model); // (try) deserialize from invalid model - let model = serde_json::from_str::("\"this-model-does-not-will-not-exist\""); - assert!(model.is_err()); + let bad_model = serde_json::from_str::("\"this-model-does-not-will-not-exist\""); + assert!(bad_model.is_err()); + } + + #[test] + fn test_provider_string_serde() { + let provider = ModelProvider::OpenAI; + + // serialize to string via serde + let provider_str = serde_json::to_string(&provider).expect("should serialize"); + assert_eq!(provider_str, format!("\"{}\"", PROVIDER_NAME)); + + // deserialize from string via serde + let provider_from: ModelProvider = + serde_json::from_str(&provider_str).expect("should deserialize"); + assert_eq!(provider_from, provider); + + // (try) deserialize from invalid model + let bad_provider = + serde_json::from_str::("\"this-provider-does-not-will-not-exist\""); + assert!(bad_provider.is_err()); } }