Skip to content

Commit

Permalink
feat: customize postgres schema
Browse files Browse the repository at this point in the history
  • Loading branch information
frectonz committed Jul 13, 2024
1 parent 740a0de commit 7fde612
Showing 1 changed file with 81 additions and 43 deletions.
124 changes: 81 additions & 43 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,12 @@ enum Command {

/// A PostgreSQL database.
Postgres {
/// postgresql connection url [postgresql://postgres:[email protected]/sample]
/// PostgreSQL connection url [postgresql://postgres:[email protected]/sample]
url: String,

/// PostgreSQL schema
#[arg(short, long, default_value = "public")]
schema: String,
},

/// A MySQL/MariaDB database.
Expand Down Expand Up @@ -100,8 +104,8 @@ async fn main() -> color_eyre::Result<()> {
Command::Libsql { url, auth_token } => {
AllDbs::Libsql(libsql::Db::open(url, auth_token, args.timeout.into()).await?)
}
Command::Postgres { url } => {
AllDbs::Postgres(postgres::Db::open(url, args.timeout.into()).await?)
Command::Postgres { url, schema } => {
AllDbs::Postgres(postgres::Db::open(url, schema, args.timeout.into()).await?)
}
Command::Mysql { url } => AllDbs::Mysql(mysql::Db::open(url, args.timeout.into()).await?),
Command::Duckdb { database } => {
Expand Down Expand Up @@ -1200,11 +1204,16 @@ mod postgres {
#[derive(Clone)]
pub struct Db {
client: Arc<Client>,
schema: String,
query_timeout: Duration,
}

impl Db {
pub async fn open(url: String, query_timeout: Duration) -> color_eyre::Result<Self> {
pub async fn open(
url: String,
schema: String,
query_timeout: Duration,
) -> color_eyre::Result<Self> {
let (client, connection) = tokio_postgres::connect(&url, NoTls).await?;

// The connection object performs the actual communication with the database,
Expand All @@ -1217,12 +1226,14 @@ mod postgres {

let tables: i64 = client
.query_one(
r#"
&format!(
r#"
SELECT count(*)
FROM information_schema.tables
WHERE table_schema = 'public'
WHERE table_schema = '{schema}'
AND table_type = 'BASE TABLE'
"#,
"#
),
&[],
)
.await?
Expand All @@ -1234,6 +1245,7 @@ mod postgres {
);

Ok(Self {
schema,
query_timeout,
client: Arc::new(client),
})
Expand All @@ -1243,6 +1255,8 @@ mod postgres {
#[async_trait]
impl Database for Db {
async fn overview(&self) -> color_eyre::Result<responses::Overview> {
let schema = &self.schema;

let file_name: String = self
.client
.query_one("SELECT current_database()", &[])
Expand All @@ -1262,12 +1276,14 @@ mod postgres {
let tables: i64 = self
.client
.query_one(
r#"
&format!(
r#"
SELECT count(*)
FROM information_schema.tables
WHERE table_schema = 'public'
WHERE table_schema = '{schema}'
AND table_type = 'BASE TABLE'
"#,
"#
),
&[],
)
.await?
Expand All @@ -1276,11 +1292,13 @@ mod postgres {
let indexes: i64 = self
.client
.query_one(
r#"
&format!(
r#"
SELECT count(*)
FROM pg_indexes
WHERE schemaname = 'public'
"#,
WHERE schemaname = '{schema}'
"#
),
&[],
)
.await?
Expand All @@ -1289,11 +1307,13 @@ mod postgres {
let triggers: i64 = self
.client
.query_one(
r#"
&format!(
r#"
SELECT count(*)
FROM information_schema.triggers
WHERE trigger_schema = 'public'
"#,
WHERE trigger_schema = '{schema}'
"#
),
&[],
)
.await?
Expand All @@ -1302,11 +1322,13 @@ mod postgres {
let views: i64 = self
.client
.query_one(
r#"
&format!(
r#"
SELECT count(*)
FROM information_schema.views
WHERE table_schema = 'public';
"#,
WHERE table_schema = '{schema}';
"#
),
&[],
)
.await?
Expand All @@ -1315,11 +1337,13 @@ mod postgres {
let mut row_counts = self
.client
.query(
r#"
SELECT relname
FROM pg_stat_user_tables
WHERE schemaname = 'public'
"#,
&format!(
r#"
SELECT table_name
FROM information_schema.tables
WHERE table_schema = '{schema}'
"#
),
&[],
)
.await?
Expand All @@ -1344,11 +1368,13 @@ mod postgres {
let mut column_counts = self
.client
.query(
r#"
SELECT relname
FROM pg_stat_user_tables
WHERE schemaname = 'public'
"#,
&format!(
r#"
SELECT table_name
FROM information_schema.tables
WHERE table_schema = '{schema}'
"#
),
&[],
)
.await?
Expand All @@ -1367,7 +1393,7 @@ mod postgres {
r#"
SELECT count(*)
FROM information_schema.columns
WHERE table_schema = 'public'
WHERE table_schema = '{schema}'
AND table_name = '{}'
"#,
table.name
Expand All @@ -1385,11 +1411,13 @@ mod postgres {
let mut index_counts = self
.client
.query(
r#"
SELECT relname
FROM pg_stat_user_tables
WHERE schemaname = 'public'
"#,
&format!(
r#"
SELECT table_name
FROM information_schema.tables
WHERE table_schema = '{schema}'
"#
),
&[],
)
.await?
Expand Down Expand Up @@ -1439,14 +1467,18 @@ mod postgres {
}

async fn tables(&self) -> color_eyre::Result<responses::Tables> {
let schema = &self.schema;

let mut tables = self
.client
.query(
r#"
SELECT relname
FROM pg_stat_user_tables
WHERE schemaname = 'public'
"#,
&format!(
r#"
SELECT table_name
FROM information_schema.tables
WHERE table_schema = '{schema}'
"#
),
&[],
)
.await?
Expand All @@ -1472,6 +1504,8 @@ mod postgres {
}

async fn table(&self, name: String) -> color_eyre::Result<responses::Table> {
let schema = &self.schema;

let row_count: i64 = self
.client
.query_one(&format!(r#"SELECT count(*) FROM "{name}""#), &[])
Expand Down Expand Up @@ -1507,7 +1541,7 @@ mod postgres {
r#"
SELECT count(*)
FROM information_schema.columns
WHERE table_schema = 'public'
WHERE table_schema = '{schema}'
AND table_name = '{name}'
"#
),
Expand All @@ -1531,13 +1565,17 @@ mod postgres {
name: String,
page: i32,
) -> color_eyre::Result<responses::TableData> {
let schema = &self.schema;

let first_column: String = self
.client
.query_one(
&format!(
r#"
SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = '{name}'
SELECT column_name
FROM information_schema.columns
WHERE table_schema = '{schema}'
AND table_name = '{name}'
LIMIT 1
"#
),
Expand Down

0 comments on commit 7fde612

Please sign in to comment.