Skip to content

Commit

Permalink
feat: add a timeout option
Browse files Browse the repository at this point in the history
  • Loading branch information
frectonz committed Jun 24, 2024
1 parent 0d30220 commit 82ab1f8
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 33 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ libsql = { git = "https://github.com/tursodatabase/libsql.git", features = ["rem
tokio-postgres = "0.7.10"
mysql_async = { version = "0.34.1", default-features = false, features = ["rustls-tls", "default-rustls"] }
duckdb = { git = "https://github.com/frectonz/duckdb-rs.git", features = ["bundled"] }
humantime = "2.1.0"

[profile.release]
strip = true
Expand Down
112 changes: 79 additions & 33 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ struct Args {
/// The address to bind to.
#[arg(short, long, default_value = "127.0.0.1:3030")]
address: String,

/// Timeout duration for queries sent from the query page.
#[clap(short, long, default_value = "5secs")]
timeout: humantime::Duration,
}

#[derive(Debug, Subcommand)]
Expand Down Expand Up @@ -58,7 +62,7 @@ async fn main() -> color_eyre::Result<()> {
color_eyre::install()?;

let filter = std::env::var("RUST_LOG")
.unwrap_or_else(|_| "tracing=info,warp=debug,sqlite_studio=debug".to_owned());
.unwrap_or_else(|_| "tracing=info,warp=debug,sql_studio=debug".to_owned());
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
Expand All @@ -69,16 +73,20 @@ async fn main() -> color_eyre::Result<()> {
let db = match args.db {
Command::Sqlite { database } => AllDbs::Sqlite(if database == "preview" {
tokio::fs::write("sample.db", SAMPLE_DB).await?;
sqlite::Db::open("sample.db".to_string()).await?
sqlite::Db::open("sample.db".to_string(), args.timeout.into()).await?
} else {
sqlite::Db::open(database).await?
sqlite::Db::open(database, args.timeout.into()).await?
}),
Command::Libsql { url, auth_token } => {
AllDbs::Libsql(libsql::Db::open(url, auth_token).await?)
AllDbs::Libsql(libsql::Db::open(url, auth_token, args.timeout.into()).await?)
}
Command::Postgres { url } => {
AllDbs::Postgres(postgres::Db::open(url, args.timeout.into()).await?)
}
Command::Mysql { url } => AllDbs::Mysql(mysql::Db::open(url, args.timeout.into()).await?),
Command::Duckdb { database } => {
AllDbs::Duckdb(duckdb::Db::open(database, args.timeout.into()).await?)
}
Command::Postgres { url } => AllDbs::Postgres(postgres::Db::open(url).await?),
Command::Mysql { url } => AllDbs::Mysql(mysql::Db::open(url).await?),
Command::Duckdb { database } => AllDbs::Duckdb(duckdb::Db::open(database).await?),
};

let cors = warp::cors()
Expand Down Expand Up @@ -247,7 +255,12 @@ impl Database for AllDbs {
mod sqlite {
use async_trait::async_trait;
use color_eyre::eyre::OptionExt;
use std::{collections::HashMap, path::Path, sync::Arc};
use std::{
collections::HashMap,
path::Path,
sync::Arc,
time::{Duration, SystemTime},
};
use tokio_rusqlite::{Connection, OpenFlags};

use crate::{helpers, responses, Database, ROWS_PER_PAGE};
Expand All @@ -256,10 +269,11 @@ mod sqlite {
pub struct Db {
path: String,
conn: Arc<Connection>,
query_timeout: Duration,
}

impl Db {
pub async fn open(path: String) -> color_eyre::Result<Self> {
pub async fn open(path: String, query_timeout: Duration) -> color_eyre::Result<Self> {
let conn = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_ONLY).await?;

// This is meant to test if the file at path is actually a DB.
Expand All @@ -279,6 +293,7 @@ mod sqlite {
tracing::info!("found {tables} tables in {path}");
Ok(Self {
path,
query_timeout,
conn: Arc::new(conn),
})
}
Expand Down Expand Up @@ -527,7 +542,10 @@ mod sqlite {
}

async fn query(&self, query: String) -> color_eyre::Result<responses::Query> {
Ok(self
let start = SystemTime::now();
let timeout = self.query_timeout;

let res = self
.conn
.call(move |conn| {
let mut stmt = conn.prepare(&query)?;
Expand All @@ -538,27 +556,35 @@ mod sqlite {
.collect::<Vec<_>>();

let columns_len = columns.len();
let rows = stmt
let rows: Result<Vec<_>, _> = stmt
.query_map((), |r| {
let now = SystemTime::now();
if now - timeout >= start {
// just used a random error, we just want to bail out
return Err(rusqlite::Error::InvalidQuery);
}

let mut rows = Vec::with_capacity(columns_len);
for i in 0..columns_len {
let val = helpers::rusqlite_value_to_json(r.get_ref(i)?);
rows.push(val);
}
Ok(rows)
})?
.filter_map(|x| x.ok())
.collect::<Vec<_>>();
.collect();
let rows = rows?;

Ok(responses::Query { columns, rows })
})
.await?)
.await?;

Ok(res)
}
}
}

mod libsql {
use std::{collections::HashMap, sync::Arc};
use std::{collections::HashMap, sync::Arc, time::Duration};

use async_trait::async_trait;
use color_eyre::eyre::OptionExt;
Expand All @@ -571,10 +597,15 @@ mod libsql {
pub struct Db {
url: String,
db: Arc<libsql::Database>,
query_timeout: Duration,
}

impl Db {
pub async fn open(url: String, auth_token: String) -> color_eyre::Result<Self> {
pub async fn open(
url: String,
auth_token: String,
query_timeout: Duration,
) -> color_eyre::Result<Self> {
let db = Builder::new_remote(url.to_owned(), auth_token)
.build()
.await?;
Expand All @@ -601,6 +632,7 @@ mod libsql {

Ok(Self {
url,
query_timeout,
db: Arc::new(db),
})
}
Expand Down Expand Up @@ -932,8 +964,11 @@ mod libsql {

color_eyre::eyre::Ok(rows)
})
.collect::<Vec<_>>()
.await
.collect::<Vec<_>>();

let rows = tokio::time::timeout(self.query_timeout, rows).await?;

let rows = rows
.into_iter()
.filter_map(|r| r.ok())
.filter_map(|r| r.ok())
Expand Down Expand Up @@ -961,7 +996,7 @@ mod libsql {
}

mod postgres {
use std::sync::Arc;
use std::{sync::Arc, time::Duration};

use async_trait::async_trait;
use tokio_postgres::{Client, NoTls};
Expand All @@ -975,10 +1010,11 @@ mod postgres {
#[derive(Clone)]
pub struct Db {
client: Arc<Client>,
query_timeout: Duration,
}

impl Db {
pub async fn open(url: String) -> color_eyre::Result<Self> {
pub async fn open(url: String, query_timeout: Duration) -> color_eyre::Result<Self> {
let (client, connection) = tokio_postgres::connect(&url, NoTls).await?;

// The connection object performs the actual communication with the database,
Expand Down Expand Up @@ -1008,6 +1044,7 @@ mod postgres {
);

Ok(Self {
query_timeout,
client: Arc::new(client),
})
}
Expand Down Expand Up @@ -1267,10 +1304,9 @@ mod postgres {
.collect::<Vec<_>>();

let columns_len = columns.len();
let rows = self
.client
.simple_query(&query)
.await?
let rows = self.client.simple_query(&query);
let rows = tokio::time::timeout(self.query_timeout, rows)
.await??
.into_iter()
.filter_map(|r| {
if let tokio_postgres::SimpleQueryMessage::Row(row) = r {
Expand All @@ -1296,6 +1332,8 @@ mod postgres {
}

mod mysql {
use std::time::Duration;

use async_trait::async_trait;
use color_eyre::eyre::OptionExt;
use mysql_async::{prelude::*, Pool};
Expand All @@ -1309,10 +1347,11 @@ mod mysql {
#[derive(Clone)]
pub struct Db {
pool: Pool,
query_timeout: Duration,
}

impl Db {
pub async fn open(url: String) -> color_eyre::Result<Self> {
pub async fn open(url: String, query_timeout: Duration) -> color_eyre::Result<Self> {
let pool = Pool::from_url(&url)?;
let conn = pool.get_conn().await?;

Expand All @@ -1333,7 +1372,10 @@ mod mysql {
if tables == 1 { "" } else { "s" }
);

Ok(Self { pool })
Ok(Self {
pool,
query_timeout,
})
}
}

Expand Down Expand Up @@ -1592,9 +1634,9 @@ mod mysql {
.collect::<Vec<_>>();

let columns_len = columns.len();
let rows = conn
.query_iter(query)
.await?
let rows = conn.query_iter(query);
let rows = tokio::time::timeout(self.query_timeout, rows)
.await??
.map_and_drop(|mut r| {
let mut row: Vec<mysql_async::Value> = Vec::with_capacity(columns_len);

Expand Down Expand Up @@ -1628,6 +1670,7 @@ mod duckdb {
use std::{
path::Path,
sync::{Arc, Mutex},
time::Duration,
};

use crate::{
Expand All @@ -1640,10 +1683,11 @@ mod duckdb {
pub struct Db {
path: String,
conn: Arc<Mutex<Connection>>,
query_timeout: Duration,
}

impl Db {
pub async fn open(path: String) -> color_eyre::Result<Self> {
pub async fn open(path: String, query_timeout: Duration) -> color_eyre::Result<Self> {
let p = path.to_owned();
let conn = tokio::task::spawn_blocking(move || {
let config = Config::default().access_mode(duckdb::AccessMode::ReadOnly)?;
Expand Down Expand Up @@ -1675,6 +1719,7 @@ mod duckdb {
);
Ok(Self {
path,
query_timeout,
conn: Arc::new(Mutex::new(conn)),
})
}
Expand Down Expand Up @@ -1909,7 +1954,7 @@ mod duckdb {
async fn query(&self, query: String) -> color_eyre::Result<responses::Query> {
let c = self.conn.clone();

let (columns, rows) = tokio::task::spawn_blocking(move || {
let future = tokio::task::spawn_blocking(move || {
let c = c.lock().expect("could not get lock on connection");

let mut stmt = c.prepare(&query)?;
Expand All @@ -1933,8 +1978,9 @@ mod duckdb {
let columns = stmt.column_names();

eyre::Ok((columns, rows))
})
.await??;
});

let (columns, rows) = tokio::time::timeout(self.query_timeout, future).await???;

Ok(responses::Query { columns, rows })
}
Expand Down

0 comments on commit 82ab1f8

Please sign in to comment.