From 1bed87a8ea2c5d92f8a45ece093be8c0726d6963 Mon Sep 17 00:00:00 2001 From: Toby Hede Date: Fri, 6 Sep 2024 01:13:35 +1000 Subject: [PATCH] Suppor postgres `TRUNCATE` syntax (#1406) --- src/ast/mod.rs | 77 ++++++++++++++++++++++++++++++++-- src/keywords.rs | 1 + src/parser/mod.rs | 36 +++++++++++++++- tests/sqlparser_postgres.rs | 82 ++++++++++++++++++++++++++++++++++--- 4 files changed, 185 insertions(+), 11 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 0e6357cbc..2bb7a161a 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2011,11 +2011,19 @@ pub enum Statement { /// ``` /// Truncate (Hive) Truncate { - #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] - table_name: ObjectName, + table_names: Vec, partitions: Option>, /// TABLE - optional keyword; table: bool, + /// Postgres-specific option + /// [ TRUNCATE TABLE ONLY ] + only: bool, + /// Postgres-specific option + /// [ RESTART IDENTITY | CONTINUE IDENTITY ] + identity: Option, + /// Postgres-specific option + /// [ CASCADE | RESTRICT ] + cascade: Option, }, /// ```sql /// MSCK @@ -3131,12 +3139,35 @@ impl fmt::Display for Statement { Ok(()) } Statement::Truncate { - table_name, + table_names, partitions, table, + only, + identity, + cascade, } => { let table = if *table { "TABLE " } else { "" }; - write!(f, "TRUNCATE {table}{table_name}")?; + let only = if *only { "ONLY " } else { "" }; + + write!( + f, + "TRUNCATE {table}{only}{table_names}", + table_names = display_comma_separated(table_names) + )?; + + if let Some(identity) = identity { + match identity { + TruncateIdentityOption::Restart => write!(f, " RESTART IDENTITY")?, + TruncateIdentityOption::Continue => write!(f, " CONTINUE IDENTITY")?, + } + } + if let Some(cascade) = cascade { + match cascade { + TruncateCascadeOption::Cascade => write!(f, " CASCADE")?, + TruncateCascadeOption::Restrict => write!(f, " RESTRICT")?, + } + } + if let Some(ref parts) = partitions { if !parts.is_empty() { write!(f, " PARTITION ({})", display_comma_separated(parts))?; @@ -4587,6 +4618,44 @@ impl fmt::Display for SequenceOptions { } } +/// Target of a `TRUNCATE TABLE` command +/// +/// Note this is its own struct because `visit_relation` requires an `ObjectName` (not a `Vec`) +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct TruncateTableTarget { + /// name of the table being truncated + #[cfg_attr(feature = "visitor", visit(with = "visit_relation"))] + pub name: ObjectName, +} + +impl fmt::Display for TruncateTableTarget { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.name) + } +} + +/// PostgreSQL identity option for TRUNCATE table +/// [ RESTART IDENTITY | CONTINUE IDENTITY ] +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum TruncateIdentityOption { + Restart, + Continue, +} + +/// PostgreSQL cascade option for TRUNCATE table +/// [ CASCADE | RESTRICT ] +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub enum TruncateCascadeOption { + Cascade, + Restrict, +} + /// Can use to describe options in create sequence or table column type identity /// [ MINVALUE minvalue | NO MINVALUE ] [ MAXVALUE maxvalue | NO MAXVALUE ] #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] diff --git a/src/keywords.rs b/src/keywords.rs index ce4972f98..ae0f14f18 100644 --- a/src/keywords.rs +++ b/src/keywords.rs @@ -177,6 +177,7 @@ define_keywords!( CONNECTION, CONSTRAINT, CONTAINS, + CONTINUE, CONVERT, COPY, COPY_OPTIONS, diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 977372656..30e776787 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -681,17 +681,49 @@ impl<'a> Parser<'a> { pub fn parse_truncate(&mut self) -> Result { let table = self.parse_keyword(Keyword::TABLE); - let table_name = self.parse_object_name(false)?; + let only = self.parse_keyword(Keyword::ONLY); + + let table_names = self + .parse_comma_separated(|p| p.parse_object_name(false))? + .into_iter() + .map(|n| TruncateTableTarget { name: n }) + .collect(); + let mut partitions = None; if self.parse_keyword(Keyword::PARTITION) { self.expect_token(&Token::LParen)?; partitions = Some(self.parse_comma_separated(Parser::parse_expr)?); self.expect_token(&Token::RParen)?; } + + let mut identity = None; + let mut cascade = None; + + if dialect_of!(self is PostgreSqlDialect | GenericDialect) { + identity = if self.parse_keywords(&[Keyword::RESTART, Keyword::IDENTITY]) { + Some(TruncateIdentityOption::Restart) + } else if self.parse_keywords(&[Keyword::CONTINUE, Keyword::IDENTITY]) { + Some(TruncateIdentityOption::Continue) + } else { + None + }; + + cascade = if self.parse_keyword(Keyword::CASCADE) { + Some(TruncateCascadeOption::Cascade) + } else if self.parse_keyword(Keyword::RESTRICT) { + Some(TruncateCascadeOption::Restrict) + } else { + None + }; + }; + Ok(Statement::Truncate { - table_name, + table_names, partitions, table, + only, + identity, + cascade, }) } diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index d96211823..1ebb5d54c 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -571,6 +571,10 @@ fn parse_alter_table_constraints_rename() { fn parse_alter_table_disable() { pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE ROW LEVEL SECURITY"); pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE RULE rule_name"); +} + +#[test] +fn parse_alter_table_disable_trigger() { pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE TRIGGER ALL"); pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE TRIGGER USER"); pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE TRIGGER trigger_name"); @@ -589,6 +593,13 @@ fn parse_alter_table_enable() { pg_and_generic().verified_stmt("ALTER TABLE tab ENABLE TRIGGER trigger_name"); } +#[test] +fn parse_truncate_table() { + pg_and_generic() + .verified_stmt("TRUNCATE TABLE \"users\", \"orders\" RESTART IDENTITY RESTRICT"); + pg_and_generic().verified_stmt("TRUNCATE users, orders RESTART IDENTITY"); +} + #[test] fn parse_create_extension() { pg_and_generic().verified_stmt("CREATE EXTENSION extension_name"); @@ -3967,11 +3978,72 @@ fn parse_select_group_by_cube() { #[test] fn parse_truncate() { let truncate = pg_and_generic().verified_stmt("TRUNCATE db.table_name"); + let table_name = ObjectName(vec![Ident::new("db"), Ident::new("table_name")]); + let table_names = vec![TruncateTableTarget { + name: table_name.clone(), + }]; + assert_eq!( + Statement::Truncate { + table_names, + partitions: None, + table: false, + only: false, + identity: None, + cascade: None, + }, + truncate + ); +} + +#[test] +fn parse_truncate_with_options() { + let truncate = pg_and_generic() + .verified_stmt("TRUNCATE TABLE ONLY db.table_name RESTART IDENTITY CASCADE"); + + let table_name = ObjectName(vec![Ident::new("db"), Ident::new("table_name")]); + let table_names = vec![TruncateTableTarget { + name: table_name.clone(), + }]; + assert_eq!( Statement::Truncate { - table_name: ObjectName(vec![Ident::new("db"), Ident::new("table_name")]), + table_names, partitions: None, - table: false + table: true, + only: true, + identity: Some(TruncateIdentityOption::Restart), + cascade: Some(TruncateCascadeOption::Cascade) + }, + truncate + ); +} + +#[test] +fn parse_truncate_with_table_list() { + let truncate = pg().verified_stmt( + "TRUNCATE TABLE db.table_name, db.other_table_name RESTART IDENTITY CASCADE", + ); + + let table_name_a = ObjectName(vec![Ident::new("db"), Ident::new("table_name")]); + let table_name_b = ObjectName(vec![Ident::new("db"), Ident::new("other_table_name")]); + + let table_names = vec![ + TruncateTableTarget { + name: table_name_a.clone(), + }, + TruncateTableTarget { + name: table_name_b.clone(), + }, + ]; + + assert_eq!( + Statement::Truncate { + table_names, + partitions: None, + table: true, + only: false, + identity: Some(TruncateIdentityOption::Restart), + cascade: Some(TruncateCascadeOption::Cascade) }, truncate ); @@ -4745,12 +4817,12 @@ fn parse_trigger_related_functions() { IF NEW.salary IS NULL THEN RAISE EXCEPTION '% cannot have null salary', NEW.empname; END IF; - + -- Who works for us when they must pay for it? IF NEW.salary < 0 THEN RAISE EXCEPTION '% cannot have a negative salary', NEW.empname; END IF; - + -- Remember who changed the payroll when NEW.last_date := current_timestamp; NEW.last_user := current_user; @@ -4883,7 +4955,7 @@ fn parse_trigger_related_functions() { Expr::Value( Value::DollarQuotedString( DollarQuotedString { - value: "\n BEGIN\n -- Check that empname and salary are given\n IF NEW.empname IS NULL THEN\n RAISE EXCEPTION 'empname cannot be null';\n END IF;\n IF NEW.salary IS NULL THEN\n RAISE EXCEPTION '% cannot have null salary', NEW.empname;\n END IF;\n \n -- Who works for us when they must pay for it?\n IF NEW.salary < 0 THEN\n RAISE EXCEPTION '% cannot have a negative salary', NEW.empname;\n END IF;\n \n -- Remember who changed the payroll when\n NEW.last_date := current_timestamp;\n NEW.last_user := current_user;\n RETURN NEW;\n END;\n ".to_owned(), + value: "\n BEGIN\n -- Check that empname and salary are given\n IF NEW.empname IS NULL THEN\n RAISE EXCEPTION 'empname cannot be null';\n END IF;\n IF NEW.salary IS NULL THEN\n RAISE EXCEPTION '% cannot have null salary', NEW.empname;\n END IF;\n\n -- Who works for us when they must pay for it?\n IF NEW.salary < 0 THEN\n RAISE EXCEPTION '% cannot have a negative salary', NEW.empname;\n END IF;\n\n -- Remember who changed the payroll when\n NEW.last_date := current_timestamp;\n NEW.last_user := current_user;\n RETURN NEW;\n END;\n ".to_owned(), tag: Some( "emp_stamp".to_owned(), ),