From 5e4bac26337ef6015e1a267042edc81c88dbb6d3 Mon Sep 17 00:00:00 2001 From: jeremyhi Date: Thu, 29 Aug 2024 17:32:21 +0800 Subject: [PATCH] feat: import cli tool (#4639) * feat: import create tables * feat: import databasse * fix: export view schema --- src/cmd/src/cli.rs | 29 +++-- src/cmd/src/cli/database.rs | 119 +++++++++++++++++ src/cmd/src/cli/export.rs | 246 ++++++++++++++---------------------- src/cmd/src/cli/import.rs | 204 ++++++++++++++++++++++++++++++ 4 files changed, 437 insertions(+), 161 deletions(-) create mode 100644 src/cmd/src/cli/database.rs create mode 100644 src/cmd/src/cli/import.rs diff --git a/src/cmd/src/cli.rs b/src/cmd/src/cli.rs index 3042b8370f77..cc7ddc47e11d 100644 --- a/src/cmd/src/cli.rs +++ b/src/cmd/src/cli.rs @@ -12,18 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod bench; - -// Wait for https://github.com/GreptimeTeam/greptimedb/issues/2373 -#[allow(unused)] -mod cmd; -mod export; -mod helper; - -// Wait for https://github.com/GreptimeTeam/greptimedb/issues/2373 -#[allow(unused)] -mod repl; - use async_trait::async_trait; use bench::BenchTableMetadataCommand; use clap::Parser; @@ -32,10 +20,25 @@ pub use repl::Repl; use tracing_appender::non_blocking::WorkerGuard; use self::export::ExportCommand; +use crate::cli::import::ImportCommand; use crate::error::Result; use crate::options::GlobalOptions; use crate::App; +mod bench; + +// Wait for https://github.com/GreptimeTeam/greptimedb/issues/2373 +#[allow(unused)] +mod cmd; +mod export; +mod helper; + +// Wait for https://github.com/GreptimeTeam/greptimedb/issues/2373 +mod database; +mod import; +#[allow(unused)] +mod repl; + pub const APP_NAME: &str = "greptime-cli"; #[async_trait] @@ -114,6 +117,7 @@ enum SubCommand { // Attach(AttachCommand), Bench(BenchTableMetadataCommand), Export(ExportCommand), + Import(ImportCommand), } impl SubCommand { @@ -122,6 +126,7 @@ impl SubCommand { // SubCommand::Attach(cmd) => cmd.build().await, SubCommand::Bench(cmd) => cmd.build(guard).await, SubCommand::Export(cmd) => cmd.build(guard).await, + SubCommand::Import(cmd) => cmd.build(guard).await, } } } diff --git a/src/cmd/src/cli/database.rs b/src/cmd/src/cli/database.rs new file mode 100644 index 000000000000..eb5647699ef0 --- /dev/null +++ b/src/cmd/src/cli/database.rs @@ -0,0 +1,119 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use base64::engine::general_purpose; +use base64::Engine; +use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use serde_json::Value; +use servers::http::greptime_result_v1::GreptimedbV1Response; +use servers::http::GreptimeQueryOutput; +use snafu::ResultExt; + +use crate::error::{HttpQuerySqlSnafu, Result, SerdeJsonSnafu}; + +pub(crate) struct DatabaseClient { + addr: String, + catalog: String, + auth_header: Option, +} + +impl DatabaseClient { + pub fn new(addr: String, catalog: String, auth_basic: Option) -> Self { + let auth_header = if let Some(basic) = auth_basic { + let encoded = general_purpose::STANDARD.encode(basic); + Some(format!("basic {}", encoded)) + } else { + None + }; + + Self { + addr, + catalog, + auth_header, + } + } + + pub async fn sql_in_public(&self, sql: &str) -> Result>>> { + self.sql(sql, DEFAULT_SCHEMA_NAME).await + } + + /// Execute sql query. + pub async fn sql(&self, sql: &str, schema: &str) -> Result>>> { + let url = format!("http://{}/v1/sql", self.addr); + let params = [ + ("db", format!("{}-{}", self.catalog, schema)), + ("sql", sql.to_string()), + ]; + let mut request = reqwest::Client::new() + .post(&url) + .form(¶ms) + .header("Content-Type", "application/x-www-form-urlencoded"); + if let Some(ref auth) = self.auth_header { + request = request.header("Authorization", auth); + } + + let response = request.send().await.with_context(|_| HttpQuerySqlSnafu { + reason: format!("bad url: {}", url), + })?; + let response = response + .error_for_status() + .with_context(|_| HttpQuerySqlSnafu { + reason: format!("query failed: {}", sql), + })?; + + let text = response.text().await.with_context(|_| HttpQuerySqlSnafu { + reason: "cannot get response text".to_string(), + })?; + + let body = serde_json::from_str::(&text).context(SerdeJsonSnafu)?; + Ok(body.output().first().and_then(|output| match output { + GreptimeQueryOutput::Records(records) => Some(records.rows().clone()), + GreptimeQueryOutput::AffectedRows(_) => None, + })) + } +} + +/// Split at `-`. +pub(crate) fn split_database(database: &str) -> Result<(String, Option)> { + let (catalog, schema) = match database.split_once('-') { + Some((catalog, schema)) => (catalog, schema), + None => (DEFAULT_CATALOG_NAME, database), + }; + + if schema == "*" { + Ok((catalog.to_string(), None)) + } else { + Ok((catalog.to_string(), Some(schema.to_string()))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_split_database() { + let result = split_database("catalog-schema").unwrap(); + assert_eq!(result, ("catalog".to_string(), Some("schema".to_string()))); + + let result = split_database("schema").unwrap(); + assert_eq!(result, ("greptime".to_string(), Some("schema".to_string()))); + + let result = split_database("catalog-*").unwrap(); + assert_eq!(result, ("catalog".to_string(), None)); + + let result = split_database("*").unwrap(); + assert_eq!(result, ("greptime".to_string(), None)); + } +} diff --git a/src/cmd/src/cli/export.rs b/src/cmd/src/cli/export.rs index 90699fae7746..5634d3f1b61a 100644 --- a/src/cmd/src/cli/export.rs +++ b/src/cmd/src/cli/export.rs @@ -17,26 +17,19 @@ use std::path::Path; use std::sync::Arc; use async_trait::async_trait; -use base64::engine::general_purpose; -use base64::Engine; use clap::{Parser, ValueEnum}; -use client::DEFAULT_SCHEMA_NAME; -use common_catalog::consts::DEFAULT_CATALOG_NAME; use common_telemetry::{debug, error, info}; use serde_json::Value; -use servers::http::greptime_result_v1::GreptimedbV1Response; -use servers::http::GreptimeQueryOutput; -use snafu::ResultExt; +use snafu::{OptionExt, ResultExt}; use tokio::fs::File; use tokio::io::{AsyncWriteExt, BufWriter}; use tokio::sync::Semaphore; use tokio::time::Instant; use tracing_appender::non_blocking::WorkerGuard; -use crate::cli::{Instance, Tool}; -use crate::error::{ - EmptyResultSnafu, Error, FileIoSnafu, HttpQuerySqlSnafu, Result, SerdeJsonSnafu, -}; +use crate::cli::database::DatabaseClient; +use crate::cli::{database, Instance, Tool}; +use crate::error::{EmptyResultSnafu, Error, FileIoSnafu, Result}; type TableReference = (String, String, String); @@ -94,26 +87,21 @@ pub struct ExportCommand { impl ExportCommand { pub async fn build(&self, guard: Vec) -> Result { - let (catalog, schema) = split_database(&self.database)?; + let (catalog, schema) = database::split_database(&self.database)?; - let auth_header = if let Some(basic) = &self.auth_basic { - let encoded = general_purpose::STANDARD.encode(basic); - Some(format!("basic {}", encoded)) - } else { - None - }; + let database_client = + DatabaseClient::new(self.addr.clone(), catalog.clone(), self.auth_basic.clone()); Ok(Instance::new( Box::new(Export { - addr: self.addr.clone(), catalog, schema, + database_client, output_dir: self.output_dir.clone(), parallelism: self.export_jobs, target: self.target.clone(), start_time: self.start_time.clone(), end_time: self.end_time.clone(), - auth_header, }), guard, )) @@ -121,78 +109,43 @@ impl ExportCommand { } pub struct Export { - addr: String, catalog: String, schema: Option, + database_client: DatabaseClient, output_dir: String, parallelism: usize, target: ExportTarget, start_time: Option, end_time: Option, - auth_header: Option, } impl Export { - /// Execute one single sql query. - async fn sql(&self, sql: &str) -> Result>>> { - let url = format!( - "http://{}/v1/sql?db={}-{}&sql={}", - self.addr, - self.catalog, - self.schema.as_deref().unwrap_or(DEFAULT_SCHEMA_NAME), - sql - ); - - let mut request = reqwest::Client::new() - .get(&url) - .header("Content-Type", "application/x-www-form-urlencoded"); - if let Some(ref auth) = self.auth_header { - request = request.header("Authorization", auth); + async fn get_db_names(&self) -> Result> { + if let Some(schema) = &self.schema { + Ok(vec![schema.clone()]) + } else { + self.all_db_names().await } - - let response = request.send().await.with_context(|_| HttpQuerySqlSnafu { - reason: format!("bad url: {}", url), - })?; - let response = response - .error_for_status() - .with_context(|_| HttpQuerySqlSnafu { - reason: format!("query failed: {}", sql), - })?; - - let text = response.text().await.with_context(|_| HttpQuerySqlSnafu { - reason: "cannot get response text".to_string(), - })?; - - let body = serde_json::from_str::(&text).context(SerdeJsonSnafu)?; - Ok(body.output().first().and_then(|output| match output { - GreptimeQueryOutput::Records(records) => Some(records.rows().clone()), - GreptimeQueryOutput::AffectedRows(_) => None, - })) } /// Iterate over all db names. - /// - /// Newbie: `db_name` is catalog + schema. - async fn iter_db_names(&self) -> Result> { - if let Some(schema) = &self.schema { - Ok(vec![(self.catalog.clone(), schema.clone())]) - } else { - let result = self.sql("SHOW DATABASES").await?; - let Some(records) = result else { - EmptyResultSnafu.fail()? + async fn all_db_names(&self) -> Result> { + let result = self.database_client.sql_in_public("SHOW DATABASES").await?; + let records = result.context(EmptyResultSnafu)?; + let mut result = Vec::with_capacity(records.len()); + for value in records { + let Value::String(schema) = &value[0] else { + unreachable!() }; - let mut result = Vec::with_capacity(records.len()); - for value in records { - let Value::String(schema) = &value[0] else { - unreachable!() - }; - if schema == common_catalog::consts::INFORMATION_SCHEMA_NAME { - continue; - } - result.push((self.catalog.clone(), schema.clone())); + if schema == common_catalog::consts::INFORMATION_SCHEMA_NAME { + continue; + } + if schema == common_catalog::consts::PG_CATALOG_NAME { + continue; } - Ok(result) + result.push(schema.clone()); } + Ok(result) } /// Return a list of [`TableReference`] to be exported. @@ -201,7 +154,11 @@ impl Export { &self, catalog: &str, schema: &str, - ) -> Result<(Vec, Vec)> { + ) -> Result<( + Vec, + Vec, + Vec, + )> { // Puts all metric table first let sql = format!( "SELECT table_catalog, table_schema, table_name \ @@ -210,15 +167,13 @@ impl Export { and table_catalog = \'{catalog}\' \ and table_schema = \'{schema}\'" ); - let result = self.sql(&sql).await?; - let Some(records) = result else { - EmptyResultSnafu.fail()? - }; + let result = self.database_client.sql_in_public(&sql).await?; + let records = result.context(EmptyResultSnafu)?; let mut metric_physical_tables = HashSet::with_capacity(records.len()); for value in records { let mut t = Vec::with_capacity(3); for v in &value { - let serde_json::Value::String(value) = v else { + let Value::String(value) = v else { unreachable!() }; t.push(value); @@ -228,54 +183,63 @@ impl Export { // TODO: SQL injection hurts let sql = format!( - "SELECT table_catalog, table_schema, table_name \ + "SELECT table_catalog, table_schema, table_name, table_type \ FROM information_schema.tables \ - WHERE table_type = \'BASE TABLE\' \ + WHERE (table_type = \'BASE TABLE\' OR table_type = \'VIEW\') \ and table_catalog = \'{catalog}\' \ and table_schema = \'{schema}\'", ); - let result = self.sql(&sql).await?; - let Some(records) = result else { - EmptyResultSnafu.fail()? - }; + let result = self.database_client.sql_in_public(&sql).await?; + let records = result.context(EmptyResultSnafu)?; - debug!("Fetched table list: {:?}", records); + debug!("Fetched table/view list: {:?}", records); if records.is_empty() { - return Ok((vec![], vec![])); + return Ok((vec![], vec![], vec![])); } let mut remaining_tables = Vec::with_capacity(records.len()); + let mut views = Vec::new(); for value in records { - let mut t = Vec::with_capacity(3); + let mut t = Vec::with_capacity(4); for v in &value { - let serde_json::Value::String(value) = v else { + let Value::String(value) = v else { unreachable!() }; t.push(value); } let table = (t[0].clone(), t[1].clone(), t[2].clone()); + let table_type = t[3].as_str(); // Ignores the physical table if !metric_physical_tables.contains(&table) { - remaining_tables.push(table); + if table_type == "VIEW" { + views.push(table); + } else { + remaining_tables.push(table); + } } } Ok(( metric_physical_tables.into_iter().collect(), remaining_tables, + views, )) } - async fn show_create_table(&self, catalog: &str, schema: &str, table: &str) -> Result { + async fn show_create( + &self, + show_type: &str, + catalog: &str, + schema: &str, + table: &str, + ) -> Result { let sql = format!( - r#"SHOW CREATE TABLE "{}"."{}"."{}""#, - catalog, schema, table + r#"SHOW CREATE {} "{}"."{}"."{}""#, + show_type, catalog, schema, table ); - let result = self.sql(&sql).await?; - let Some(records) = result else { - EmptyResultSnafu.fail()? - }; + let result = self.database_client.sql_in_public(&sql).await?; + let records = result.context(EmptyResultSnafu)?; let Value::String(create_table) = &records[0][1] else { unreachable!() }; @@ -286,18 +250,19 @@ impl Export { async fn export_create_table(&self) -> Result<()> { let timer = Instant::now(); let semaphore = Arc::new(Semaphore::new(self.parallelism)); - let db_names = self.iter_db_names().await?; + let db_names = self.get_db_names().await?; let db_count = db_names.len(); let mut tasks = Vec::with_capacity(db_names.len()); - for (catalog, schema) in db_names { + for schema in db_names { let semaphore_moved = semaphore.clone(); tasks.push(async move { let _permit = semaphore_moved.acquire().await.unwrap(); - let (metric_physical_tables, remaining_tables) = - self.get_table_list(&catalog, &schema).await?; - let table_count = metric_physical_tables.len() + remaining_tables.len(); + let (metric_physical_tables, remaining_tables, views) = + self.get_table_list(&self.catalog, &schema).await?; + let table_count = + metric_physical_tables.len() + remaining_tables.len() + views.len(); let output_dir = Path::new(&self.output_dir) - .join(&catalog) + .join(&self.catalog) .join(format!("{schema}/")); tokio::fs::create_dir_all(&output_dir) .await @@ -305,7 +270,7 @@ impl Export { let output_file = Path::new(&output_dir).join("create_tables.sql"); let mut file = File::create(output_file).await.context(FileIoSnafu)?; for (c, s, t) in metric_physical_tables.into_iter().chain(remaining_tables) { - match self.show_create_table(&c, &s, &t).await { + match self.show_create("TABLE", &c, &s, &t).await { Err(e) => { error!(e; r#"Failed to export table "{}"."{}"."{}""#, c, s, t) } @@ -316,9 +281,22 @@ impl Export { } } } + for (c, s, v) in views { + match self.show_create("VIEW", &c, &s, &v).await { + Err(e) => { + error!(e; r#"Failed to export view "{}"."{}"."{}""#, c, s, v) + } + Ok(create_view) => { + file.write_all(create_view.as_bytes()) + .await + .context(FileIoSnafu)?; + } + } + } info!( - "Finished exporting {catalog}.{schema} with {table_count} table schemas to path: {}", + "Finished exporting {}.{schema} with {table_count} table schemas to path: {}", + self.catalog, output_dir.to_string_lossy() ); @@ -332,7 +310,7 @@ impl Export { .filter(|r| match r { Ok(_) => true, Err(e) => { - error!(e; "export job failed"); + error!(e; "export schema job failed"); false } }) @@ -347,15 +325,15 @@ impl Export { async fn export_database_data(&self) -> Result<()> { let timer = Instant::now(); let semaphore = Arc::new(Semaphore::new(self.parallelism)); - let db_names = self.iter_db_names().await?; + let db_names = self.get_db_names().await?; let db_count = db_names.len(); - let mut tasks = Vec::with_capacity(db_names.len()); - for (catalog, schema) in db_names { + let mut tasks = Vec::with_capacity(db_count); + for schema in db_names { let semaphore_moved = semaphore.clone(); tasks.push(async move { let _permit = semaphore_moved.acquire().await.unwrap(); let output_dir = Path::new(&self.output_dir) - .join(&catalog) + .join(&self.catalog) .join(format!("{schema}/")); tokio::fs::create_dir_all(&output_dir) .await @@ -379,7 +357,7 @@ impl Export { let sql = format!( r#"COPY DATABASE "{}"."{}" TO '{}' {};"#, - catalog, + self.catalog, schema, output_dir.to_str().unwrap(), with_options @@ -387,10 +365,11 @@ impl Export { info!("Executing sql: {sql}"); - self.sql(&sql).await?; + self.database_client.sql_in_public(&sql).await?; info!( - "Finished exporting {catalog}.{schema} data into path: {}", + "Finished exporting {}.{schema} data into path: {}", + self.catalog, output_dir.to_string_lossy() ); @@ -400,7 +379,7 @@ impl Export { BufWriter::new(File::create(copy_from_file).await.context(FileIoSnafu)?); let copy_database_from_sql = format!( r#"COPY DATABASE "{}"."{}" FROM '{}' WITH (FORMAT='parquet');"#, - catalog, + self.catalog, schema, output_dir.to_str().unwrap() ); @@ -410,7 +389,7 @@ impl Export { .context(FileIoSnafu)?; writer.flush().await.context(FileIoSnafu)?; - info!("Finished exporting {catalog}.{schema} copy_from.sql"); + info!("Finished exporting {}.{schema} copy_from.sql", self.catalog); Ok::<(), Error>(()) }) @@ -429,13 +408,12 @@ impl Export { .count(); let elapsed = timer.elapsed(); - info!("Success {success}/{db_count} jobs, costs: {:?}", elapsed); + info!("Success {success}/{db_count} jobs, costs: {elapsed:?}"); Ok(()) } } -#[allow(deprecated)] #[async_trait] impl Tool for Export { async fn do_work(&self) -> Result<()> { @@ -450,20 +428,6 @@ impl Tool for Export { } } -/// Split at `-`. -fn split_database(database: &str) -> Result<(String, Option)> { - let (catalog, schema) = match database.split_once('-') { - Some((catalog, schema)) => (catalog, schema), - None => (DEFAULT_CATALOG_NAME, database), - }; - - if schema == "*" { - Ok((catalog.to_string(), None)) - } else { - Ok((catalog.to_string(), Some(schema.to_string()))) - } -} - #[cfg(test)] mod tests { use clap::Parser; @@ -471,26 +435,10 @@ mod tests { use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_telemetry::logging::LoggingOptions; - use crate::cli::export::split_database; use crate::error::Result as CmdResult; use crate::options::GlobalOptions; use crate::{cli, standalone, App}; - #[test] - fn test_split_database() { - let result = split_database("catalog-schema").unwrap(); - assert_eq!(result, ("catalog".to_string(), Some("schema".to_string()))); - - let result = split_database("schema").unwrap(); - assert_eq!(result, ("greptime".to_string(), Some("schema".to_string()))); - - let result = split_database("catalog-*").unwrap(); - assert_eq!(result, ("catalog".to_string(), None)); - - let result = split_database("*").unwrap(); - assert_eq!(result, ("greptime".to_string(), None)); - } - #[tokio::test(flavor = "multi_thread")] async fn test_export_create_table_with_quoted_names() -> CmdResult<()> { let output_dir = tempfile::tempdir().unwrap(); diff --git a/src/cmd/src/cli/import.rs b/src/cmd/src/cli/import.rs new file mode 100644 index 000000000000..920e225d7ae4 --- /dev/null +++ b/src/cmd/src/cli/import.rs @@ -0,0 +1,204 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::path::PathBuf; +use std::sync::Arc; + +use async_trait::async_trait; +use clap::{Parser, ValueEnum}; +use common_telemetry::{error, info, warn}; +use snafu::ResultExt; +use tokio::sync::Semaphore; +use tokio::time::Instant; +use tracing_appender::non_blocking::WorkerGuard; + +use crate::cli::database::DatabaseClient; +use crate::cli::{database, Instance, Tool}; +use crate::error::{Error, FileIoSnafu, Result}; + +#[derive(Debug, Default, Clone, ValueEnum)] +enum ImportTarget { + /// Import all table schemas into the database. + Schema, + /// Import all table data into the database. + Data, + /// Export all table schemas and data at once. + #[default] + All, +} + +#[derive(Debug, Default, Parser)] +pub struct ImportCommand { + /// Server address to connect + #[clap(long)] + addr: String, + + /// Directory of the data. E.g.: /tmp/greptimedb-backup + #[clap(long)] + input_dir: String, + + /// The name of the catalog to import. + #[clap(long, default_value = "greptime-*")] + database: String, + + /// Parallelism of the import. + #[clap(long, short = 'j', default_value = "1")] + import_jobs: usize, + + /// Max retry times for each job. + #[clap(long, default_value = "3")] + max_retry: usize, + + /// Things to export + #[clap(long, short = 't', value_enum, default_value = "all")] + target: ImportTarget, + + /// The basic authentication for connecting to the server + #[clap(long)] + auth_basic: Option, +} + +impl ImportCommand { + pub async fn build(&self, guard: Vec) -> Result { + let (catalog, schema) = database::split_database(&self.database)?; + let database_client = + DatabaseClient::new(self.addr.clone(), catalog.clone(), self.auth_basic.clone()); + + Ok(Instance::new( + Box::new(Import { + catalog, + schema, + database_client, + input_dir: self.input_dir.clone(), + parallelism: self.import_jobs, + target: self.target.clone(), + }), + guard, + )) + } +} + +pub struct Import { + catalog: String, + schema: Option, + database_client: DatabaseClient, + input_dir: String, + parallelism: usize, + target: ImportTarget, +} + +impl Import { + async fn import_create_table(&self) -> Result<()> { + self.do_sql_job("create_tables.sql").await + } + + async fn import_database_data(&self) -> Result<()> { + self.do_sql_job("copy_from.sql").await + } + + async fn do_sql_job(&self, filename: &str) -> Result<()> { + let timer = Instant::now(); + let semaphore = Arc::new(Semaphore::new(self.parallelism)); + let db_names = self.get_db_names().await?; + let db_count = db_names.len(); + let mut tasks = Vec::with_capacity(db_count); + for schema in db_names { + let semaphore_moved = semaphore.clone(); + tasks.push(async move { + let _permit = semaphore_moved.acquire().await.unwrap(); + let database_input_dir = self.catalog_path().join(&schema); + let sql_file = database_input_dir.join(filename); + let sql = tokio::fs::read_to_string(sql_file) + .await + .context(FileIoSnafu)?; + if sql.is_empty() { + info!("Empty `{filename}` {database_input_dir:?}"); + } else { + self.database_client.sql(&sql, &schema).await?; + info!("Imported `{filename}` for database {schema}"); + } + + Ok::<(), Error>(()) + }) + } + + let success = futures::future::join_all(tasks) + .await + .into_iter() + .filter(|r| match r { + Ok(_) => true, + Err(e) => { + error!(e; "import {filename} job failed"); + false + } + }) + .count(); + let elapsed = timer.elapsed(); + info!("Success {success}/{db_count} `{filename}` jobs, cost: {elapsed:?}"); + + Ok(()) + } + + fn catalog_path(&self) -> PathBuf { + PathBuf::from(&self.input_dir).join(&self.catalog) + } + + async fn get_db_names(&self) -> Result> { + if let Some(schema) = &self.schema { + Ok(vec![schema.clone()]) + } else { + self.all_db_names().await + } + } + + // Get all database names in the input directory. + // The directory structure should be like: + // /tmp/greptimedb-backup + // ├── greptime-1 + // │ ├── db1 + // │ └── db2 + async fn all_db_names(&self) -> Result> { + let mut db_names = vec![]; + let path = self.catalog_path(); + let mut entries = tokio::fs::read_dir(path).await.context(FileIoSnafu)?; + while let Some(entry) = entries.next_entry().await.context(FileIoSnafu)? { + let path = entry.path(); + if path.is_dir() { + let db_name = match path.file_name() { + Some(name) => name.to_string_lossy().to_string(), + None => { + warn!("Failed to get the file name of {:?}", path); + continue; + } + }; + db_names.push(db_name); + } + } + Ok(db_names) + } +} + +#[async_trait] +impl Tool for Import { + async fn do_work(&self) -> Result<()> { + match self.target { + ImportTarget::Schema => self.import_create_table().await, + ImportTarget::Data => self.import_database_data().await, + ImportTarget::All => { + self.import_create_table().await?; + self.import_database_data().await + } + } + } +}