Skip to content

Commit

Permalink
Update: remove unnecessary type wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongdongy committed Mar 28, 2023
1 parent 910b741 commit 1685851
Show file tree
Hide file tree
Showing 7 changed files with 172 additions and 111 deletions.
22 changes: 8 additions & 14 deletions examples/azure/src/batch-create-transcription.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,24 @@
use std::{error::Error, time::Duration};

use rust_ai::azure::{
types::speech::transcription::{Status, Transcription},
SpeechModel,
};
use rust_ai::azure::{types::speech::Status, Speech};
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
std::env::set_var("RUST_LOG", "debug");
std::env::set_var("RUST_BACKTRACE", "1");
log4rs::init_file("log4rs.yml", Default::default()).unwrap();

let mut trans = Transcription::default();
trans.display_name = "Test".into();
trans.content_urls = Some(vec![String::from(
"https://crbn.us/whatstheweatherlike.wav",
)]);
let trans = SpeechModel::default()
.transcription(trans)
.create_transcription()
let trans = Speech::new_transcription("Test".into())
.content_urls(vec![String::from(
"https://crbn.us/whatstheweatherlike.wav",
)])
.create()
.await?;

std::thread::sleep(Duration::from_secs(5));
let trans = trans.status().await?;
if let Some(Status::Succeeded) = trans.status {
let results = trans.results().await?;
println!("{:?}", results.values);
let results = trans.files().await?;
println!("{:#?}", results.values);
}

Ok(())
Expand Down
173 changes: 107 additions & 66 deletions rust-ai/src/azure/apis/speech.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
//!
//! For use of styles and roles, see [docs/azure-voices-n-roles.md](https://github.com/dongsxyz/rust-ai/blob/master/docs/azure-voices-n-roles.md).
use std::collections::HashMap;

use log::{error, warn};
use reqwest::header::HeaderMap;

Expand All @@ -39,12 +41,18 @@ use crate::azure::{
},
types::{
common::{MicrosoftOutputFormat, ResponseExpectation, ResponseType},
speech::{health::ServiceHealth, filter::FilterOperator, transcription::{Transcription, Status}, ErrorResponse, file::PaginatedFiles

speech::{
entity::EntityReference,
file::PaginatedFiles,
filter::FilterOperator,
health::ServiceHealth,
transcription::{Status, Transcription},
ErrorResponse,
},
tts::Voice,
SSML,
},
Locale,
};

/// The Speech service allows you to convert text into synthesized speech and
Expand Down Expand Up @@ -76,6 +84,10 @@ impl From<SSML> for Speech {
}

impl Speech {
pub fn new_transcription(display_name: String) -> Transcription {
Transcription::default().display_name(display_name)
}

pub fn format(self, f: MicrosoftOutputFormat) -> Self {
Self {
output_format: f,
Expand Down Expand Up @@ -157,31 +169,7 @@ impl Speech {
}
}

/// TODO: remove `allow(dead_code)` when `models()` implemented.
#[allow(dead_code)]
pub struct SpeechModel {
model_id: Option<String>,

skip: Option<usize>,
top: Option<usize>,
filter: Option<FilterOperator>,

transcription: Option<Transcription>,
}

impl Default for SpeechModel {
fn default() -> Self {
Self {
model_id: None,
skip: None,
top: None,
filter: None,
transcription: None,
}
}
}

impl SpeechModel {
impl Transcription {
pub fn skip(self, skip: usize) -> Self {
Self {
skip: Some(skip),
Expand All @@ -201,20 +189,64 @@ impl SpeechModel {
}
}

pub fn id(self, id: String) -> Self {
pub fn sas_validity_in_seconds(self, sec: u32) -> Self {
Self {
model_id: Some(id),
sas_validity_in_seconds: Some(sec),
..self
}
}

pub fn transcription(self, transcription: Transcription) -> Self {
pub fn model(self, model: String) -> Self {
Self {
transcription: Some(transcription),
model: Some(EntityReference::from(model)),
..self
}
}

pub fn content_container_url(self, url: String) -> Self {
Self {
content_container_url: Some(url),
..self
}
}
pub fn content_urls(self, urls: Vec<String>) -> Self {
Self {
content_urls: Some(urls),
..self
}
}

pub fn project(self, project: String) -> Self {
Self {
project: Some(EntityReference::from(project)),
..self
}
}

pub fn set_self(self, _self: String) -> Self {
Self {
_self: Some(_self),
..self
}
}

/// Change display name of current transcription job.
///
/// No effect after transcription submitted.
pub fn display_name(self, display_name: String) -> Self {
Self {
display_name,
..self
}
}

/// Change default locale of this transcription job.
///
/// No effect after transcription submitted.
pub fn locale(self, locale: Locale) -> Self {
Self { locale, ..self }
}

/// [Custom Speech]
/// Gets the list of custom models for the authenticated subscription.
///
Expand Down Expand Up @@ -252,50 +284,45 @@ impl SpeechModel {
// Ok(())
}

pub async fn create_transcription(self) -> Result<Transcription, Box<dyn std::error::Error>> {
return if let Some(transcription) = self.transcription {
if let ResponseType::Text(text) = request_post_endpoint(
&SpeechServiceEndpoint::Post_Batch_Transcriptions_v3_1,
transcription,
ResponseExpectation::Text,
None,
)
.await?
{
return match serde_json::from_str::<Transcription>(&text) {
Ok(trans) => Ok(trans),
Err(e) => {
warn!(target: "azure", "Unable to parse transcription creation result: `{:#?}`", e);
match serde_json::from_str::<ErrorResponse>(&text) {
Ok(error) => {
println!("{:#?}", error);
error!(target: "azure", "Error from Azure: `{:?}`", e);
Err(Box::new(e))
}
Err(e) => {
error!(target: "azure", "Unable to parse error response: `{:?}`", e);
Err(Box::new(e))
}
/// Create a new audio transcription job.
pub async fn create(self) -> Result<Transcription, Box<dyn std::error::Error>> {
return if let ResponseType::Text(text) = request_post_endpoint(
&SpeechServiceEndpoint::Post_Create_Transcription_v3_1,
self,
ResponseExpectation::Text,
None,
)
.await?
{
return match serde_json::from_str::<Transcription>(&text) {
Ok(trans) => Ok(trans),
Err(e) => {
warn!(target: "azure", "Unable to parse transcription creation result: `{:#?}`", e);
match serde_json::from_str::<ErrorResponse>(&text) {
Ok(error) => {
println!("{:#?}", error);
error!(target: "azure", "Error from Azure: `{:?}`", e);
Err(Box::new(e))
}
Err(e) => {
error!(target: "azure", "Unable to parse error response: `{:?}`", e);
Err(Box::new(e))
}
}
};
} else {
Err("Unable to load output from Azure speech service endpoint".into())
}
}
};
} else {
Err("You need to call `transcription()` before create on Azure".into())
Err("Unable to load output from Azure speech service endpoint".into())
};
}
}

impl Transcription {
/// Check transcription status
///
/// This will only succeed when you've submitted the initial batch create
/// request to Azure endpoint.
pub async fn status(self) -> Result<Transcription, Box<dyn std::error::Error>> {
let text = request_get_endpoint(
&SpeechServiceEndpoint::Get_Transcription_Status_v3_1,
&SpeechServiceEndpoint::Get_Transcription_v3_1,
None,
Some(self.transcription_id().unwrap()),
)
Expand All @@ -321,7 +348,7 @@ impl Transcription {
}

/// Get batch transcription result from Azure endpoint
pub async fn results(self) -> Result<PaginatedFiles, Box<dyn std::error::Error>> {
pub async fn files(self) -> Result<PaginatedFiles, Box<dyn std::error::Error>> {
if let None = self.status.clone() {
return Err("You should submit the create request first.".into());
} else {
Expand All @@ -335,9 +362,23 @@ impl Transcription {
}
}

let mut params = HashMap::<String, String>::new();
if let Some(sas) = self.sas_validity_in_seconds.clone() {
params.insert("sasValidityInSeconds".into(), sas.to_string());
}
if let Some(skip) = self.skip.clone() {
params.insert("skip".into(), skip.to_string());
}
if let Some(top) = self.top.clone() {
params.insert("top".into(), top.to_string());
}
if let Some(filter) = self.filter.clone() {
params.insert("filter".into(), filter.to_string());
}

let text = request_get_endpoint(
&SpeechServiceEndpoint::Get_Transcription_Results_v3_1,
None,
&SpeechServiceEndpoint::Get_Transcription_Files_v3_1,
Some(params),
Some(format!("{}/files", self.transcription_id().unwrap())),
)
.await?;
Expand Down
12 changes: 6 additions & 6 deletions rust-ai/src/azure/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ pub enum SpeechServiceEndpoint {
Post_Text_to_Speech_v1,
Get_Speech_to_Text_Health_Status_v3_1,
Get_List_of_Models_v3_1,
Post_Batch_Transcriptions_v3_1,
Get_Transcription_Status_v3_1,
Get_Transcription_Results_v3_1,
Post_Create_Transcription_v3_1,
Get_Transcription_v3_1,
Get_Transcription_Files_v3_1,
}

impl SpeechServiceEndpoint {
Expand All @@ -46,17 +46,17 @@ impl SpeechServiceEndpoint {
region
),

Self::Post_Batch_Transcriptions_v3_1 => format!(
Self::Post_Create_Transcription_v3_1 => format!(
"https://{}.api.cognitive.microsoft.com/speechtotext/v3.1/transcriptions",
region
),

Self::Get_Transcription_Status_v3_1 => format!(
Self::Get_Transcription_v3_1 => format!(
"https://{}.api.cognitive.microsoft.com/speechtotext/v3.1/transcriptions/",
region
),

Self::Get_Transcription_Results_v3_1 => format!(
Self::Get_Transcription_Files_v3_1 => format!(
"https://{}.api.cognitive.microsoft.com/speechtotext/v3.1/transcriptions/",
region
),
Expand Down
2 changes: 1 addition & 1 deletion rust-ai/src/azure/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub mod apis;
/// Azure types definition
pub mod types;

pub use apis::speech::{Speech, SpeechModel};
pub use apis::speech::Speech;
pub use types::ssml;
pub use types::Gender;
pub use types::Locale;
Expand Down
12 changes: 12 additions & 0 deletions rust-ai/src/azure/types/speech/entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ pub struct EntityReference {
pub _self: String,
}

impl From<String> for EntityReference {
fn from(value: String) -> Self {
Self { _self: value }
}
}

impl EntityReference {
pub fn from(s: String) -> Self {
Self { _self: s }
}
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EntityError {
/// The code of this error.
Expand Down
Loading

0 comments on commit 1685851

Please sign in to comment.