Skip to content

Commit

Permalink
fix: a bug where train.jsonl does not exist
Browse files Browse the repository at this point in the history
Signed-off-by: Hung-Han (Henry) Chen <[email protected]>
  • Loading branch information
chenhunghan committed Jan 10, 2024
1 parent e372610 commit a9cd04c
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "mlx-training-rs"
version = "0.2.3"
version = "0.2.4"
edition = "2021"
repository = "https://github.com/chenhunghan/mlx-training-rs"

Expand Down
33 changes: 18 additions & 15 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@ use mlx_training_rs::cli::CLI;
use serde::Deserialize;
use tokio::fs::{self, OpenOptions};
use tokio::io::AsyncWriteExt;
use tokio::runtime::Runtime;
use serde_json;
use async_openai::{Client, types::{CreateChatCompletionRequestArgs, ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs}};

fn main() {
let rt = Runtime::new().unwrap();
rt.block_on(main_async()).unwrap();
}

async fn main_async() -> Result<(), Box<dyn Error>> {
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
// Parse command line arguments
let cli = CLI::parse();
let topic = &cli.topic;
let n = cli.n;

tokio::fs::create_dir_all("./data").await?;
write_instruction_jsonl(topic, n).await?;
write_train_jsonl().await?;
create_valid_file().await?;
Expand Down Expand Up @@ -55,6 +51,12 @@ async fn write_instruction_jsonl(topic: &str, n: usize) -> Result<(), Box<dyn Er
let instructions = fs::read_to_string(&file_path).await?;
let instructions: Vec<Instruction> = instructions.lines().map(|line| serde_json::from_str(&line).unwrap()).collect();

// Open the file in append mode
let mut file = OpenOptions::new()
.append(true)
.open(&file_path)
.await?;

println!("------------------------------");
println!("{}", format!("Generating instructions on topic {}...", topic));
for _ in 0..n {
Expand All @@ -64,12 +66,6 @@ async fn write_instruction_jsonl(topic: &str, n: usize) -> Result<(), Box<dyn Er
// println!("Skipping duplicate instruction: {}", instruction);
continue;
} else {
// Open the file in append mode
let mut file = OpenOptions::new()
.create(true)
.append(true)
.open(&file_path)
.await?;

println!("------------------------------");
println!("Writing new instruction to file: {}", instruction);
Expand Down Expand Up @@ -107,9 +103,16 @@ async fn write_train_jsonl() -> Result<(), Box<dyn Error>> {
let total = instructions.len();

let train_file_path = PathBuf::from("./data/").join("train.jsonl");
if !train_file_path.exists() {
println!("Creating train.jsonl file...");
let _ = OpenOptions::new()
.create(true)
.append(true)
.open(&train_file_path)
.await?;
}
let trainings: Vec<Train> = fs::read_to_string(&train_file_path).await?.lines().filter_map(|line| serde_json::from_str(&line).ok()).collect();
print!("{} data found in train.jsonl. ", trainings.len());


for (i, instruction) in instructions.iter().enumerate() {
if let Some(_) = trainings.iter().find(|t| t.text.contains(&instruction.text)) {
// println!("Skipping processing instruction {} because it can be found in train.jsonl", instruction.text);
Expand Down

0 comments on commit a9cd04c

Please sign in to comment.