Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Returning #292

Merged
merged 33 commits into from
Nov 17, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
c5468eb
Use "marlon-sousa/sea-query"
billy1624 Nov 5, 2021
c39a3b8
Insert with returning for Postgres
billy1624 Nov 5, 2021
52ff943
Docs
billy1624 Nov 5, 2021
a977572
Update with returning for Postgres
billy1624 Nov 5, 2021
50605c7
FIXME: breaking behaviors
billy1624 Nov 5, 2021
6238736
Handle "None of the database rows are affected" for Postgres
billy1624 Nov 8, 2021
2f7cffa
Fix test cases
billy1624 Nov 8, 2021
732d080
Update docs
billy1624 Nov 8, 2021
0eafacc
Try returning on MariaDB
billy1624 Nov 8, 2021
80c0d69
Merge remote-tracking branch 'origin/master' into returning
billy1624 Nov 8, 2021
30f43b6
Fixup
billy1624 Nov 8, 2021
2f0ac4c
Fixup
billy1624 Nov 8, 2021
afdb1af
This will fail loll
billy1624 Nov 8, 2021
1723206
This will fail loll
billy1624 Nov 8, 2021
3e6423a
This will fail loll
billy1624 Nov 8, 2021
30a50ca
Try
billy1624 Nov 9, 2021
429b920
Fixup
billy1624 Nov 9, 2021
24fab66
Try
billy1624 Nov 9, 2021
8020ae1
Fixup
billy1624 Nov 9, 2021
533c3cf
Try
billy1624 Nov 9, 2021
ec637b2
Returning support for SQLite
billy1624 Nov 9, 2021
c1fae1b
Debug print
billy1624 Nov 9, 2021
cc035d7
Refactoring
billy1624 Nov 9, 2021
66c23c8
Revert MySQL & SQLite returning support
billy1624 Nov 10, 2021
257a893
Use `sea-query` master
billy1624 Nov 10, 2021
4d44827
Docs
billy1624 Nov 11, 2021
fd50ffd
Merge remote-tracking branch 'origin/master' into returning
billy1624 Nov 16, 2021
d5de8b1
Should fail
billy1624 Nov 16, 2021
9655805
Will fail, as expected
billy1624 Nov 16, 2021
4c147a2
Rewrite doctests
billy1624 Nov 16, 2021
f9d04fc
Hotfix - separate counter for mock exec & query
billy1624 Nov 16, 2021
7298fde
Rewrite doctests
billy1624 Nov 16, 2021
42404eb
Fixup
billy1624 Nov 16, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Try returning on MariaDB
  • Loading branch information
billy1624 committed Nov 8, 2021
commit 0eafacc2a1bc34f0499c0b8a020fbe41504bb97d
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ sqlx = { version = "^0.5", optional = true }
uuid = { version = "0.8", features = ["serde", "v4"], optional = true }
ouroboros = "0.11"
url = "^2.2"
regex = "^1"

[dev-dependencies]
smol = { version = "^1.2" }
Expand Down
3 changes: 3 additions & 0 deletions src/database/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ pub trait ConnectionTrait<'a>: Sync {
T: Send,
E: std::error::Error + Send;

/// Check if the connection supports `RETURNING` syntax
fn support_returning(&self) -> bool;

/// Check if the connection is a test connection for the Mock database
fn is_mock_connection(&self) -> bool {
false
Expand Down
45 changes: 36 additions & 9 deletions src/database/db_connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ use std::sync::Arc;
pub enum DatabaseConnection {
/// Create a MYSQL database connection and pool
#[cfg(feature = "sqlx-mysql")]
SqlxMySqlPoolConnection(crate::SqlxMySqlPoolConnection),
SqlxMySqlPoolConnection {
/// A SQLx MySQL pool
conn: crate::SqlxMySqlPoolConnection,
/// A flag indicating whether `RETURNING` syntax is supported
support_returning: bool,
},
/// Create a PostgreSQL database connection and pool
#[cfg(feature = "sqlx-postgres")]
SqlxPostgresPoolConnection(crate::SqlxPostgresPoolConnection),
Expand Down Expand Up @@ -73,7 +78,7 @@ impl std::fmt::Debug for DatabaseConnection {
"{}",
match self {
#[cfg(feature = "sqlx-mysql")]
Self::SqlxMySqlPoolConnection(_) => "SqlxMySqlPoolConnection",
Self::SqlxMySqlPoolConnection { .. } => "SqlxMySqlPoolConnection",
#[cfg(feature = "sqlx-postgres")]
Self::SqlxPostgresPoolConnection(_) => "SqlxPostgresPoolConnection",
#[cfg(feature = "sqlx-sqlite")]
Expand All @@ -93,7 +98,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
fn get_database_backend(&self) -> DbBackend {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(_) => DbBackend::MySql,
DatabaseConnection::SqlxMySqlPoolConnection { .. } => DbBackend::MySql,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => DbBackend::Postgres,
#[cfg(feature = "sqlx-sqlite")]
Expand All @@ -107,7 +112,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn execute(&self, stmt: Statement) -> Result<ExecResult, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.execute(stmt).await,
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.execute(stmt).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.execute(stmt).await,
#[cfg(feature = "sqlx-sqlite")]
Expand All @@ -121,7 +126,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn query_one(&self, stmt: Statement) -> Result<Option<QueryResult>, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_one(stmt).await,
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.query_one(stmt).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_one(stmt).await,
#[cfg(feature = "sqlx-sqlite")]
Expand All @@ -135,7 +140,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn query_all(&self, stmt: Statement) -> Result<Vec<QueryResult>, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.query_all(stmt).await,
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.query_all(stmt).await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.query_all(stmt).await,
#[cfg(feature = "sqlx-sqlite")]
Expand All @@ -153,7 +158,9 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
Box::pin(async move {
Ok(match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.stream(stmt).await?,
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => {
conn.stream(stmt).await?
}
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.stream(stmt).await?,
#[cfg(feature = "sqlx-sqlite")]
Expand All @@ -170,7 +177,7 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
async fn begin(&self) -> Result<DatabaseTransaction, DbErr> {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.begin().await,
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => conn.begin().await,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => conn.begin().await,
#[cfg(feature = "sqlx-sqlite")]
Expand All @@ -196,7 +203,9 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
{
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection(conn) => conn.transaction(_callback).await,
DatabaseConnection::SqlxMySqlPoolConnection { conn, .. } => {
conn.transaction(_callback).await
}
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => {
conn.transaction(_callback).await
Expand All @@ -214,6 +223,24 @@ impl<'a> ConnectionTrait<'a> for DatabaseConnection {
}
}

fn support_returning(&self) -> bool {
match self {
#[cfg(feature = "sqlx-mysql")]
DatabaseConnection::SqlxMySqlPoolConnection { .. } => false,
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(_) => true,
#[cfg(feature = "sqlx-sqlite")]
DatabaseConnection::SqlxSqlitePoolConnection(_) => false,
#[cfg(feature = "mock")]
DatabaseConnection::MockDatabaseConnection(conn) => match conn.get_database_backend() {
DbBackend::MySql => false,
DbBackend::Postgres => true,
DbBackend::Sqlite => false,
},
DatabaseConnection::Disconnected => panic!("Disconnected"),
}
}

#[cfg(feature = "mock")]
fn is_mock_connection(&self) -> bool {
matches!(self, DatabaseConnection::MockDatabaseConnection(_))
Expand Down
8 changes: 8 additions & 0 deletions src/database/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,14 @@ impl<'a> ConnectionTrait<'a> for DatabaseTransaction {
let transaction = self.begin().await.map_err(TransactionError::Connection)?;
transaction.run(_callback).await
}

fn support_returning(&self) -> bool {
match self.backend {
DbBackend::MySql => false,
DbBackend::Postgres => true,
DbBackend::Sqlite => false,
}
}
}

/// Defines errors for handling transaction failures
Expand Down
51 changes: 45 additions & 6 deletions src/driver/sqlx_mysql.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use regex::Regex;
use std::{future::Future, pin::Pin};

use sqlx::{
Expand All @@ -10,7 +11,7 @@ use sea_query_driver_mysql::bind_query;

use crate::{
debug_print, error::*, executor::*, ConnectOptions, DatabaseConnection, DatabaseTransaction,
QueryStream, Statement, TransactionError,
DbBackend, QueryStream, Statement, TransactionError,
};

use super::sqlx_common::*;
Expand Down Expand Up @@ -42,9 +43,7 @@ impl SqlxMySqlConnector {
opt.disable_statement_logging();
}
if let Ok(pool) = options.pool_options().connect_with(opt).await {
Ok(DatabaseConnection::SqlxMySqlPoolConnection(
SqlxMySqlPoolConnection { pool },
))
into_db_connection(pool).await
} else {
Err(DbErr::Conn("Failed to connect.".to_owned()))
}
Expand All @@ -53,8 +52,8 @@ impl SqlxMySqlConnector {

impl SqlxMySqlConnector {
/// Instantiate a sqlx pool connection to a [DatabaseConnection]
pub fn from_sqlx_mysql_pool(pool: MySqlPool) -> DatabaseConnection {
DatabaseConnection::SqlxMySqlPoolConnection(SqlxMySqlPoolConnection { pool })
pub async fn from_sqlx_mysql_pool(pool: MySqlPool) -> Result<DatabaseConnection, DbErr> {
into_db_connection(pool).await
}
}

Expand Down Expand Up @@ -183,3 +182,43 @@ pub(crate) fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, MySql, MySq
}
query
}

async fn into_db_connection(pool: MySqlPool) -> Result<DatabaseConnection, DbErr> {
let conn = SqlxMySqlPoolConnection { pool };
let res = conn
.query_one(Statement::from_string(
DbBackend::MySql,
r#"SHOW VARIABLES LIKE "version""#.to_owned(),
))
.await?;
let support_returning = if let Some(query_result) = res {
let version: String = query_result.try_get("", "Value")?;
if !version.contains("MariaDB") {
// This is MySQL
false
} else {
// This is MariaDB
let regex = Regex::new(r"^(\d+)?.(\d+)?.(\*|\d+)").unwrap();
let captures = regex.captures(&version).unwrap();
macro_rules! parse_captures {
( $idx: expr ) => {
captures.get($idx).map_or(0, |m| {
m.as_str()
.parse::<usize>()
.map_err(|e| DbErr::Conn(e.to_string()))
.unwrap()
})
};
}
let ver_major = parse_captures!(1);
let ver_minor = parse_captures!(2);
ver_major >= 10 && ver_minor >= 5
}
} else {
return Err(DbErr::Conn("Fail to parse MySQL version".to_owned()));
};
Ok(DatabaseConnection::SqlxMySqlPoolConnection {
conn,
support_returning,
})
}
20 changes: 9 additions & 11 deletions src/executor/insert.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
error::*, ActiveModelTrait, ConnectionTrait, DbBackend, EntityTrait, Insert, IntoActiveModel,
Iterable, PrimaryKeyTrait, SelectModel, SelectorRaw, Statement, TryFromU64,
error::*, ActiveModelTrait, ConnectionTrait, EntityTrait, Insert, IntoActiveModel, Iterable,
PrimaryKeyTrait, SelectModel, SelectorRaw, Statement, TryFromU64,
};
use sea_query::{FromValueTuple, Iden, InsertStatement, IntoColumnRef, Returning, ValueTuple};
use std::{future::Future, marker::PhantomData};
Expand Down Expand Up @@ -39,9 +39,7 @@ where
{
// so that self is dropped before entering await
let mut query = self.query;
if db.get_database_backend() == DbBackend::Postgres
&& <A::Entity as EntityTrait>::PrimaryKey::iter().count() > 0
{
if db.support_returning() && <A::Entity as EntityTrait>::PrimaryKey::iter().count() > 0 {
query.returning(Returning::Columns(
<A::Entity as EntityTrait>::PrimaryKey::iter()
.map(|c| c.into_column_ref())
Expand Down Expand Up @@ -113,15 +111,15 @@ where
{
type PrimaryKey<A> = <<A as ActiveModelTrait>::Entity as EntityTrait>::PrimaryKey;
type ValueTypeOf<A> = <PrimaryKey<A> as PrimaryKeyTrait>::ValueType;
let last_insert_id_opt = match db.get_database_backend() {
DbBackend::Postgres => {
let last_insert_id_opt = match db.support_returning() {
true => {
let cols = PrimaryKey::<A>::iter()
.map(|col| col.to_string())
.collect::<Vec<_>>();
let res = db.query_one(statement).await?.unwrap();
res.try_get_many("", cols.as_ref()).ok()
}
_ => {
false => {
let last_insert_id = db.execute(statement).await?.last_insert_id();
ValueTypeOf::<A>::try_from_u64(last_insert_id).ok()
}
Expand All @@ -147,8 +145,8 @@ where
A: ActiveModelTrait,
{
let db_backend = db.get_database_backend();
let found = match db_backend {
DbBackend::Postgres => {
let found = match db.support_returning() {
true => {
insert_statement.returning(Returning::Columns(
<A::Entity as EntityTrait>::Column::iter()
.map(|c| c.into_column_ref())
Expand All @@ -160,7 +158,7 @@ where
.one(db)
.await?
}
_ => {
false => {
let insert_res =
exec_insert::<A, _>(primary_key, db_backend.build(&insert_statement), db).await?;
<A::Entity as EntityTrait>::find_by_id(insert_res.last_insert_id)
Expand Down
10 changes: 5 additions & 5 deletions src/executor/update.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
error::*, ActiveModelTrait, ConnectionTrait, DbBackend, EntityTrait, IntoActiveModel, Iterable,
error::*, ActiveModelTrait, ConnectionTrait, EntityTrait, IntoActiveModel, Iterable,
SelectModel, SelectorRaw, Statement, UpdateMany, UpdateOne,
};
use sea_query::{FromValueTuple, IntoColumnRef, Returning, UpdateStatement};
Expand Down Expand Up @@ -90,14 +90,14 @@ where
A: ActiveModelTrait,
C: ConnectionTrait<'a>,
{
let db_backend = db.get_database_backend();
match db_backend {
DbBackend::Postgres => {
match db.support_returning() {
true => {
query.returning(Returning::Columns(
<A::Entity as EntityTrait>::Column::iter()
.map(|c| c.into_column_ref())
.collect(),
));
let db_backend = db.get_database_backend();
let found: Option<<A::Entity as EntityTrait>::Model> =
SelectorRaw::<SelectModel<<A::Entity as EntityTrait>::Model>>::from_statement(
db_backend.build(&query),
Expand All @@ -112,7 +112,7 @@ where
)),
}
}
_ => {
false => {
// If we updating a row that does not exist then an error will be thrown here.
Updater::new(query).check_record_exists().exec(db).await?;
let primary_key_value = match model.get_primary_key_value() {
Expand Down