Skip to content

Commit

Permalink
Suppor postgres TRUNCATE syntax (#1406)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobyhede authored Sep 5, 2024
1 parent 4d52ee7 commit 1bed87a
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 11 deletions.
77 changes: 73 additions & 4 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2011,11 +2011,19 @@ pub enum Statement {
/// ```
/// Truncate (Hive)
Truncate {
#[cfg_attr(feature = "visitor", visit(with = "visit_relation"))]
table_name: ObjectName,
table_names: Vec<TruncateTableTarget>,
partitions: Option<Vec<Expr>>,
/// TABLE - optional keyword;
table: bool,
/// Postgres-specific option
/// [ TRUNCATE TABLE ONLY ]
only: bool,
/// Postgres-specific option
/// [ RESTART IDENTITY | CONTINUE IDENTITY ]
identity: Option<TruncateIdentityOption>,
/// Postgres-specific option
/// [ CASCADE | RESTRICT ]
cascade: Option<TruncateCascadeOption>,
},
/// ```sql
/// MSCK
Expand Down Expand Up @@ -3131,12 +3139,35 @@ impl fmt::Display for Statement {
Ok(())
}
Statement::Truncate {
table_name,
table_names,
partitions,
table,
only,
identity,
cascade,
} => {
let table = if *table { "TABLE " } else { "" };
write!(f, "TRUNCATE {table}{table_name}")?;
let only = if *only { "ONLY " } else { "" };

write!(
f,
"TRUNCATE {table}{only}{table_names}",
table_names = display_comma_separated(table_names)
)?;

if let Some(identity) = identity {
match identity {
TruncateIdentityOption::Restart => write!(f, " RESTART IDENTITY")?,
TruncateIdentityOption::Continue => write!(f, " CONTINUE IDENTITY")?,
}
}
if let Some(cascade) = cascade {
match cascade {
TruncateCascadeOption::Cascade => write!(f, " CASCADE")?,
TruncateCascadeOption::Restrict => write!(f, " RESTRICT")?,
}
}

if let Some(ref parts) = partitions {
if !parts.is_empty() {
write!(f, " PARTITION ({})", display_comma_separated(parts))?;
Expand Down Expand Up @@ -4587,6 +4618,44 @@ impl fmt::Display for SequenceOptions {
}
}

/// Target of a `TRUNCATE TABLE` command
///
/// Note this is its own struct because `visit_relation` requires an `ObjectName` (not a `Vec<ObjectName>`)
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TruncateTableTarget {
/// name of the table being truncated
#[cfg_attr(feature = "visitor", visit(with = "visit_relation"))]
pub name: ObjectName,
}

impl fmt::Display for TruncateTableTarget {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.name)
}
}

/// PostgreSQL identity option for TRUNCATE table
/// [ RESTART IDENTITY | CONTINUE IDENTITY ]
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TruncateIdentityOption {
Restart,
Continue,
}

/// PostgreSQL cascade option for TRUNCATE table
/// [ CASCADE | RESTRICT ]
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum TruncateCascadeOption {
Cascade,
Restrict,
}

/// Can use to describe options in create sequence or table column type identity
/// [ MINVALUE minvalue | NO MINVALUE ] [ MAXVALUE maxvalue | NO MAXVALUE ]
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
Expand Down
1 change: 1 addition & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ define_keywords!(
CONNECTION,
CONSTRAINT,
CONTAINS,
CONTINUE,
CONVERT,
COPY,
COPY_OPTIONS,
Expand Down
36 changes: 34 additions & 2 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -681,17 +681,49 @@ impl<'a> Parser<'a> {

pub fn parse_truncate(&mut self) -> Result<Statement, ParserError> {
let table = self.parse_keyword(Keyword::TABLE);
let table_name = self.parse_object_name(false)?;
let only = self.parse_keyword(Keyword::ONLY);

let table_names = self
.parse_comma_separated(|p| p.parse_object_name(false))?
.into_iter()
.map(|n| TruncateTableTarget { name: n })
.collect();

let mut partitions = None;
if self.parse_keyword(Keyword::PARTITION) {
self.expect_token(&Token::LParen)?;
partitions = Some(self.parse_comma_separated(Parser::parse_expr)?);
self.expect_token(&Token::RParen)?;
}

let mut identity = None;
let mut cascade = None;

if dialect_of!(self is PostgreSqlDialect | GenericDialect) {
identity = if self.parse_keywords(&[Keyword::RESTART, Keyword::IDENTITY]) {
Some(TruncateIdentityOption::Restart)
} else if self.parse_keywords(&[Keyword::CONTINUE, Keyword::IDENTITY]) {
Some(TruncateIdentityOption::Continue)
} else {
None
};

cascade = if self.parse_keyword(Keyword::CASCADE) {
Some(TruncateCascadeOption::Cascade)
} else if self.parse_keyword(Keyword::RESTRICT) {
Some(TruncateCascadeOption::Restrict)
} else {
None
};
};

Ok(Statement::Truncate {
table_name,
table_names,
partitions,
table,
only,
identity,
cascade,
})
}

Expand Down
82 changes: 77 additions & 5 deletions tests/sqlparser_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,10 @@ fn parse_alter_table_constraints_rename() {
fn parse_alter_table_disable() {
pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE ROW LEVEL SECURITY");
pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE RULE rule_name");
}

#[test]
fn parse_alter_table_disable_trigger() {
pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE TRIGGER ALL");
pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE TRIGGER USER");
pg_and_generic().verified_stmt("ALTER TABLE tab DISABLE TRIGGER trigger_name");
Expand All @@ -589,6 +593,13 @@ fn parse_alter_table_enable() {
pg_and_generic().verified_stmt("ALTER TABLE tab ENABLE TRIGGER trigger_name");
}

#[test]
fn parse_truncate_table() {
pg_and_generic()
.verified_stmt("TRUNCATE TABLE \"users\", \"orders\" RESTART IDENTITY RESTRICT");
pg_and_generic().verified_stmt("TRUNCATE users, orders RESTART IDENTITY");
}

#[test]
fn parse_create_extension() {
pg_and_generic().verified_stmt("CREATE EXTENSION extension_name");
Expand Down Expand Up @@ -3967,11 +3978,72 @@ fn parse_select_group_by_cube() {
#[test]
fn parse_truncate() {
let truncate = pg_and_generic().verified_stmt("TRUNCATE db.table_name");
let table_name = ObjectName(vec![Ident::new("db"), Ident::new("table_name")]);
let table_names = vec![TruncateTableTarget {
name: table_name.clone(),
}];
assert_eq!(
Statement::Truncate {
table_names,
partitions: None,
table: false,
only: false,
identity: None,
cascade: None,
},
truncate
);
}

#[test]
fn parse_truncate_with_options() {
let truncate = pg_and_generic()
.verified_stmt("TRUNCATE TABLE ONLY db.table_name RESTART IDENTITY CASCADE");

let table_name = ObjectName(vec![Ident::new("db"), Ident::new("table_name")]);
let table_names = vec![TruncateTableTarget {
name: table_name.clone(),
}];

assert_eq!(
Statement::Truncate {
table_name: ObjectName(vec![Ident::new("db"), Ident::new("table_name")]),
table_names,
partitions: None,
table: false
table: true,
only: true,
identity: Some(TruncateIdentityOption::Restart),
cascade: Some(TruncateCascadeOption::Cascade)
},
truncate
);
}

#[test]
fn parse_truncate_with_table_list() {
let truncate = pg().verified_stmt(
"TRUNCATE TABLE db.table_name, db.other_table_name RESTART IDENTITY CASCADE",
);

let table_name_a = ObjectName(vec![Ident::new("db"), Ident::new("table_name")]);
let table_name_b = ObjectName(vec![Ident::new("db"), Ident::new("other_table_name")]);

let table_names = vec![
TruncateTableTarget {
name: table_name_a.clone(),
},
TruncateTableTarget {
name: table_name_b.clone(),
},
];

assert_eq!(
Statement::Truncate {
table_names,
partitions: None,
table: true,
only: false,
identity: Some(TruncateIdentityOption::Restart),
cascade: Some(TruncateCascadeOption::Cascade)
},
truncate
);
Expand Down Expand Up @@ -4745,12 +4817,12 @@ fn parse_trigger_related_functions() {
IF NEW.salary IS NULL THEN
RAISE EXCEPTION '% cannot have null salary', NEW.empname;
END IF;
-- Who works for us when they must pay for it?
IF NEW.salary < 0 THEN
RAISE EXCEPTION '% cannot have a negative salary', NEW.empname;
END IF;
-- Remember who changed the payroll when
NEW.last_date := current_timestamp;
NEW.last_user := current_user;
Expand Down Expand Up @@ -4883,7 +4955,7 @@ fn parse_trigger_related_functions() {
Expr::Value(
Value::DollarQuotedString(
DollarQuotedString {
value: "\n BEGIN\n -- Check that empname and salary are given\n IF NEW.empname IS NULL THEN\n RAISE EXCEPTION 'empname cannot be null';\n END IF;\n IF NEW.salary IS NULL THEN\n RAISE EXCEPTION '% cannot have null salary', NEW.empname;\n END IF;\n \n -- Who works for us when they must pay for it?\n IF NEW.salary < 0 THEN\n RAISE EXCEPTION '% cannot have a negative salary', NEW.empname;\n END IF;\n \n -- Remember who changed the payroll when\n NEW.last_date := current_timestamp;\n NEW.last_user := current_user;\n RETURN NEW;\n END;\n ".to_owned(),
value: "\n BEGIN\n -- Check that empname and salary are given\n IF NEW.empname IS NULL THEN\n RAISE EXCEPTION 'empname cannot be null';\n END IF;\n IF NEW.salary IS NULL THEN\n RAISE EXCEPTION '% cannot have null salary', NEW.empname;\n END IF;\n\n -- Who works for us when they must pay for it?\n IF NEW.salary < 0 THEN\n RAISE EXCEPTION '% cannot have a negative salary', NEW.empname;\n END IF;\n\n -- Remember who changed the payroll when\n NEW.last_date := current_timestamp;\n NEW.last_user := current_user;\n RETURN NEW;\n END;\n ".to_owned(),
tag: Some(
"emp_stamp".to_owned(),
),
Expand Down

0 comments on commit 1bed87a

Please sign in to comment.