diff --git a/CHANGELOG.md b/CHANGELOG.md index 654e33a7cf37..f5ec67412388 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,9 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ functionality of `NonAggregate`. See [the upgrade notes](#2-0-0-upgrade-non-aggregate) for details. +* It is now possible to inspect the type of values returned from the database + in such a way to support constructing a dynamic value depending on this type. + ### Removed @@ -47,6 +50,7 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ * Support for `bigdecimal` < 0.0.13 has been removed. * Support for `pq-sys` < 0.4.0 has been removed. * Support for `mysqlclient-sys` < 0.2.0 has been removed. +* The `NonNull` for sql types has been removed in favour of the new `SqlType` trait. ### Changed @@ -68,8 +72,6 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ * The `RawValue` types for the `Mysql` and `Postgresql` backend where changed from `[u8]` to distinct opaque types. If you used the concrete `RawValue` type somewhere you need to change it to `mysql::MysqlValue` or `pg::PgValue`. - For the postgres backend additionally type information where added to the `RawValue` - type. This allows to dynamically deserialize `RawValues` in container types. * The uuidv07 feature was renamed to uuid, due to the removal of support for older uuid versions @@ -93,6 +95,26 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ `#[non_exhaustive]`. If you matched on one of those variants explicitly you need to introduce a wild card match instead. +* `FromSql::from_sql` is changed to construct value from non nullable database values. + To construct a rust value for nullable values use the new `FromSql::from_nullable_sql` + method instead. + +* Custom sql types are now required to implement the new `SqlType` trait. Diesel will + automatically create implementations of that trait for all types having a `#[derive(SqlType)]` + +* The workflow for manually implementing support custom types has changed. Implementing + `FromSqlRow` is not required anymore, as this is now implied by implementing + `FromSql`. The requirement of implementing `Queryable` remains + unchanged. For types using `#[derive(FromSqlRow)]` no changes are required as the + derive automatically generates the correct code + +* The structure of our deserialization trait has changed. Loading values from the database + requires now that the result type implements `FromSqlRow`. Diesel provides wild + card implementations for types implementing `Queryable` or `QueryableByName` + so non generic code does not require any change. For generic code you likely need to + replace a trait bound on `Queryable` with a trait bound on `FromSqlRow` + and a bound to `QueryableByName` with `FromSqlRow`. + ### Fixed @@ -129,6 +151,10 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ * We've refactored our type translation layer for Mysql to handle more types now. +* We've refactored our type level representation of nullable values. This allowed us to + fix multiple long standing bugs regarding the correct handling of nullable values in some + corner cases (#104, #2274) + ### Deprecated * `diesel_(prefix|postfix|infix)_operator!` have been deprecated. These macros @@ -138,7 +164,6 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ * `diesel::pg::upsert` has been deprecated to support upsert queries on more than one backend. Please use `diesel::upsert` instead. - ### Upgrade Notes #### Replacement of `NonAggregate` with `ValidGrouping` diff --git a/diesel/Cargo.toml b/diesel/Cargo.toml index 90a48319d4cb..b4f0259ca8bd 100644 --- a/diesel/Cargo.toml +++ b/diesel/Cargo.toml @@ -32,6 +32,7 @@ num-integer = { version = "0.1.39", optional = true } bigdecimal = { version = ">= 0.0.13, < 0.2.0", optional = true } bitflags = { version = "1.2.0", optional = true } r2d2 = { version = ">= 0.8, < 0.9", optional = true } +itoa = "0.4" [dependencies.diesel_derives] version = "~2.0.0" diff --git a/diesel/src/associations/belongs_to.rs b/diesel/src/associations/belongs_to.rs index d6e76e4ff12f..377c01ad0025 100644 --- a/diesel/src/associations/belongs_to.rs +++ b/diesel/src/associations/belongs_to.rs @@ -4,6 +4,7 @@ use crate::expression::array_comparison::AsInExpression; use crate::expression::AsExpression; use crate::prelude::*; use crate::query_dsl::methods::FilterDsl; +use crate::sql_types::SqlType; use std::borrow::Borrow; use std::hash::Hash; @@ -139,6 +140,7 @@ where Id<&'a Parent>: AsExpression<::SqlType>, Child::Table: FilterDsl>>, Child::ForeignKeyColumn: ExpressionMethods, + ::SqlType: SqlType, { type Output = FindBy>; @@ -154,6 +156,7 @@ where Vec>: AsInExpression<::SqlType>, ::Table: FilterDsl>>>, Child::ForeignKeyColumn: ExpressionMethods, + ::SqlType: SqlType, { type Output = Filter>>>; diff --git a/diesel/src/connection/mod.rs b/diesel/src/connection/mod.rs index 328b81328654..de1cbff09fa6 100644 --- a/diesel/src/connection/mod.rs +++ b/diesel/src/connection/mod.rs @@ -6,10 +6,10 @@ mod transaction_manager; use std::fmt::Debug; use crate::backend::Backend; -use crate::deserialize::{Queryable, QueryableByName}; +use crate::deserialize::FromSqlRow; +use crate::expression::QueryMetadata; use crate::query_builder::{AsQuery, QueryFragment, QueryId}; use crate::result::*; -use crate::sql_types::HasSqlType; #[doc(hidden)] pub use self::statement_cache::{MaybeCached, StatementCache, StatementCacheKey}; @@ -169,18 +169,12 @@ pub trait Connection: SimpleConnection + Sized + Send { fn execute(&self, query: &str) -> QueryResult; #[doc(hidden)] - fn query_by_index(&self, source: T) -> QueryResult> + fn load(&self, source: T) -> QueryResult> where T: AsQuery, T::Query: QueryFragment + QueryId, - Self::Backend: HasSqlType, - U: Queryable; - - #[doc(hidden)] - fn query_by_name(&self, source: &T) -> QueryResult> - where - T: QueryFragment + QueryId, - U: QueryableByName; + U: FromSqlRow, + Self::Backend: QueryMetadata; #[doc(hidden)] fn execute_returning_count(&self, source: &T) -> QueryResult diff --git a/diesel/src/deserialize.rs b/diesel/src/deserialize.rs index 27bbf21a7ed5..0e35bf996d55 100644 --- a/diesel/src/deserialize.rs +++ b/diesel/src/deserialize.rs @@ -5,6 +5,7 @@ use std::result; use crate::backend::{self, Backend}; use crate::row::{NamedRow, Row}; +use crate::sql_types::{SingleValue, SqlType, Untyped}; /// A specialized result type representing the result of deserializing /// a value from the database. @@ -56,7 +57,8 @@ pub type Result = result::Result>; /// # /// # use schema::users; /// # use diesel::backend::{self, Backend}; -/// # use diesel::deserialize::Queryable; +/// # use diesel::deserialize::{Queryable, FromSql}; +/// # use diesel::sql_types::Text; /// # /// struct LowercaseString(String); /// @@ -66,15 +68,15 @@ pub type Result = result::Result>; /// } /// } /// -/// impl Queryable for LowercaseString +/// impl Queryable for LowercaseString /// where /// DB: Backend, -/// String: Queryable, +/// String: FromSql, /// { -/// type Row = >::Row; +/// type Row = String; /// -/// fn build(row: Self::Row) -> Self { -/// LowercaseString(String::build(row).to_lowercase()) +/// fn build(s: String) -> Self { +/// LowercaseString(s.to_lowercase()) /// } /// } /// @@ -148,7 +150,7 @@ where /// The Rust type you'd like to map from. /// /// This is typically a tuple of all of your struct's fields. - type Row: FromSqlRow; + type Row: FromStaticSqlRow; /// Construct an instance of this type fn build(row: Self::Row) -> Self; @@ -216,7 +218,7 @@ pub use diesel_derives::Queryable; /// DB: Backend, /// String: FromSql, /// { -/// fn from_sql(bytes: Option>) -> deserialize::Result { +/// fn from_sql(bytes: backend::RawValue) -> deserialize::Result { /// String::from_sql(bytes) /// .map(|s| LowercaseString(s.to_lowercase())) /// } @@ -249,7 +251,7 @@ where DB: Backend, { /// Construct an instance of `Self` from the database row - fn build>(row: &R) -> Result; + fn build<'a>(row: &impl NamedRow<'a, DB>) -> Result; } #[doc(inline)] @@ -260,7 +262,8 @@ pub use diesel_derives::QueryableByName; /// When possible, implementations of this trait should prefer to use an /// existing implementation, rather than reading from `bytes`. (For example, if /// you are implementing this for an enum which is represented as an integer in -/// the database, prefer `i32::from_sql(bytes)` over reading from `bytes` +/// the database, prefer `i32::from_sql(bytes)` (or the explicit form +/// `>::from_sql(bytes)`) over reading from `bytes` /// directly) /// /// Types which implement this trait should also have `#[derive(FromSqlRow)]` @@ -285,10 +288,10 @@ pub use diesel_derives::QueryableByName; /// ```rust /// # use diesel::backend::{self, Backend}; /// # use diesel::sql_types::*; -/// # use diesel::deserialize::{self, FromSql}; +/// # use diesel::deserialize::{self, FromSql, FromSqlRow}; /// # /// #[repr(i32)] -/// #[derive(Debug, Clone, Copy)] +/// #[derive(Debug, Clone, Copy, FromSqlRow)] /// pub enum MyEnum { /// A = 1, /// B = 2, @@ -299,7 +302,7 @@ pub use diesel_derives::QueryableByName; /// DB: Backend, /// i32: FromSql, /// { -/// fn from_sql(bytes: Option>) -> deserialize::Result { +/// fn from_sql(bytes: backend::RawValue) -> deserialize::Result { /// match i32::from_sql(bytes)? { /// 1 => Ok(MyEnum::A), /// 2 => Ok(MyEnum::B), @@ -310,64 +313,130 @@ pub use diesel_derives::QueryableByName; /// ``` pub trait FromSql: Sized { /// See the trait documentation. - fn from_sql(bytes: Option>) -> Result; + fn from_sql(bytes: backend::RawValue) -> Result; + + /// A specialized variant of `from_sql` for handling null values. + /// + /// The default implementation returns an `UnexpectedNullError` for + /// an encountered null value and calls `Self::from_sql` otherwise + /// + /// If your custom type supports null values you need to provide a + /// custom implementation. + #[inline(always)] + fn from_nullable_sql(bytes: Option>) -> Result { + match bytes { + Some(bytes) => Self::from_sql(bytes), + None => Err(Box::new(crate::result::UnexpectedNullError)), + } + } } -/// Deserialize one or more fields. +/// Deserialize a database row into a rust data structure /// -/// All types which implement `FromSql` should also implement this trait. This -/// trait differs from `FromSql` in that it is also implemented by tuples. -/// Implementations of this trait are usually derived. -/// -/// In the future, we hope to be able to provide a blanket impl of this trait -/// for all types which implement `FromSql`. However, as of Diesel 1.0, such an -/// impl would conflict with our impl for tuples. -/// -/// This trait can be [derived](derive.FromSqlRow.html) -pub trait FromSqlRow: Sized { - /// The number of fields that this type will consume. Must be equal to - /// the number of times you would call `row.take()` in `build_from_row` - const FIELDS_NEEDED: usize = 1; - +/// Diesel provides wild card implementations of this trait for all types +/// that implement one of the following traits: +/// * [`Queryable`](trait.Queryable.html) +/// * [`QueryableByName`](trait.QueryableByName.html) +pub trait FromSqlRow: Sized { /// See the trait documentation. - fn build_from_row>(row: &mut T) -> Result; + fn build_from_row<'a>(row: &impl Row<'a, DB>) -> Result; } #[doc(inline)] pub use diesel_derives::FromSqlRow; -// Reasons we can't write this: -// -// impl FromSqlRow for T -// where -// DB: Backend + HasSqlType, -// T: FromSql, -// { -// fn build_from_row>(row: &mut T) -> Result { -// Self::from_sql(row.take()) -// } -// } -// -// (this is mostly here so @sgrif has a better reference every time they think -// they've somehow had a breakthrough on solving this problem): -// -// - It conflicts with our impl for tuples, because `DB` is a bare type -// parameter, it could in theory be a local type for some other impl. -// - This is fixed by replacing our impl with 3 impls, where `DB` is changed -// concrete backends. This would mean that any third party crates adding new -// backends would need to add the tuple impls, which sucks but is fine. -// - It conflicts with our impl for `Option` -// - So we could in theory fix this by both splitting the generic impl into -// backend specific impls, and removing the `FromSql` impls. In theory there -// is no reason that it needs to implement `FromSql`, since everything -// requires `FromSqlRow`, but it really feels like it should. -// - Specialization might also fix this one. The impl isn't quite a strict -// subset (the `FromSql` impl has `T: FromSql`, and the `FromSqlRow` impl -// has `T: FromSqlRow`), but if `FromSql` implies `FromSqlRow`, -// specialization might consider that a subset? -// - I don't know that we really need it. `#[derive(FromSqlRow)]` is probably -// good enough. That won't improve our own codebase, since 99% of our -// `FromSqlRow` impls are for types from another crate, but it's almost -// certainly good enough for user types. -// - Still, it really feels like `FromSql` *should* be able to imply both -// `FromSqlRow` and `Queryable` +/// A marker trait indicating that the corresponding type consumes a static at +/// compile time known number of field +/// +/// There is normally no need to implement this trait. Diesel provides +/// wild card impls for all types that implement `FromSql` or `Queryable` +/// where the size of `ST` is known +pub trait StaticallySizedRow: FromSqlRow { + /// The number of fields that this type will consume. + const FIELD_COUNT: usize; +} + +impl FromSqlRow for T +where + DB: Backend, + T: QueryableByName, +{ + fn build_from_row<'a>(row: &impl Row<'a, DB>) -> Result { + T::build(row) + } +} + +/// A helper trait to deserialize a statically sized row into an tuple +/// +/// **If you see an error message mentioning this trait you likly +/// trying to map the result of an query to an struct with missmatching +/// field types. Recheck your field order and the concrete field types** +/// +/// You should not need to implement this trait directly. +/// Diesel provides wild card implementations for any supported tuple size +/// and for any type that implements `FromSql`. +/// +// This is a distinct trait from `FromSqlRow` because otherwise we +// are getting conflicting implementation errors for our `FromSqlRow` +// implementation for tuples and our wild card impl for all types +// implementing `Queryable` +pub trait FromStaticSqlRow: Sized { + /// See the trait documentation + fn build_from_row<'a>(row: &impl Row<'a, DB>) -> Result; +} + +impl FromSqlRow for T +where + T: Queryable, + ST: SqlType, + DB: Backend, + T::Row: FromStaticSqlRow, +{ + fn build_from_row<'a>(row: &impl Row<'a, DB>) -> Result { + let row = >::build_from_row(row)?; + Ok(T::build(row)) + } +} + +impl FromStaticSqlRow for T +where + DB: Backend, + T: FromSql, + ST: SingleValue, +{ + fn build_from_row<'a>(row: &impl Row<'a, DB>) -> Result { + use crate::row::Field; + + let field = row.get(0).ok_or(crate::result::UnexpectedEndOfRow)?; + T::from_nullable_sql(field.value()) + } +} + +// We cannot have this impl because rustc +// then complains in third party crates that +// diesel may implement `SingleValue` for tuples +// in the future. While that is theoretically true, +// that will likly not happen in practice. +// If we get negative trait impls at some point in time +// it should be possible to make this work. +/*impl Queryable for T +where + DB: Backend, + T: FromStaticSqlRow, + ST: SingleValue, +{ + type Row = Self; + + fn build(row: Self::Row) -> Self { + row + } +}*/ + +impl StaticallySizedRow for T +where + ST: SqlType + crate::type_impls::tuples::TupleSize, + T: Queryable, + DB: Backend, +{ + const FIELD_COUNT: usize = ::SIZE; +} diff --git a/diesel/src/expression/array_comparison.rs b/diesel/src/expression/array_comparison.rs index 1eb5d1db0306..ca90e3950d30 100644 --- a/diesel/src/expression/array_comparison.rs +++ b/diesel/src/expression/array_comparison.rs @@ -2,6 +2,7 @@ use crate::backend::Backend; use crate::expression::subselect::Subselect; use crate::expression::*; use crate::query_builder::*; +use crate::query_builder::{BoxedSelectStatement, SelectStatement}; use crate::result::QueryResult; use crate::sql_types::Bool; @@ -92,9 +93,7 @@ where impl_selectable_expression!(In); impl_selectable_expression!(NotIn); -use crate::query_builder::{BoxedSelectStatement, SelectStatement}; - -pub trait AsInExpression { +pub trait AsInExpression { type InExpression: MaybeEmpty + Expression; fn as_in_expression(self) -> Self::InExpression; @@ -104,6 +103,7 @@ impl AsInExpression for I where I: IntoIterator, T: AsExpression, + ST: SqlType + TypedExpressionType, { type InExpression = Many; @@ -119,6 +119,7 @@ pub trait MaybeEmpty { impl AsInExpression for SelectStatement where + ST: SqlType + TypedExpressionType, Subselect: Expression, Self: SelectQuery, { @@ -131,6 +132,7 @@ where impl<'a, ST, QS, DB> AsInExpression for BoxedSelectStatement<'a, ST, QS, DB> where + ST: SqlType + TypedExpressionType, Subselect, ST>: Expression, { type InExpression = Subselect; diff --git a/diesel/src/expression/bound.rs b/diesel/src/expression/bound.rs index 4dbb605a2450..6c3d33d976e9 100644 --- a/diesel/src/expression/bound.rs +++ b/diesel/src/expression/bound.rs @@ -5,7 +5,7 @@ use crate::backend::Backend; use crate::query_builder::*; use crate::result::QueryResult; use crate::serialize::ToSql; -use crate::sql_types::{DieselNumericOps, HasSqlType}; +use crate::sql_types::{DieselNumericOps, HasSqlType, SqlType}; #[derive(Debug, Clone, Copy, DieselNumericOps)] pub struct Bound { @@ -22,7 +22,10 @@ impl Bound { } } -impl Expression for Bound { +impl Expression for Bound +where + T: SqlType + TypedExpressionType, +{ type SqlType = T; } diff --git a/diesel/src/expression/coerce.rs b/diesel/src/expression/coerce.rs index 3328a3f74f84..ecf3eacc5084 100644 --- a/diesel/src/expression/coerce.rs +++ b/diesel/src/expression/coerce.rs @@ -4,7 +4,7 @@ use crate::backend::Backend; use crate::expression::*; use crate::query_builder::*; use crate::result::QueryResult; -use crate::sql_types::DieselNumericOps; +use crate::sql_types::{DieselNumericOps, SqlType}; #[derive(Debug, Copy, Clone, QueryId, DieselNumericOps)] #[doc(hidden)] @@ -36,13 +36,24 @@ impl Coerce { impl Expression for Coerce where T: Expression, + ST: SqlType + TypedExpressionType, { type SqlType = ST; } -impl SelectableExpression for Coerce where T: SelectableExpression {} +impl SelectableExpression for Coerce +where + T: SelectableExpression, + Self: Expression, +{ +} -impl AppearsOnTable for Coerce where T: AppearsOnTable {} +impl AppearsOnTable for Coerce +where + T: AppearsOnTable, + Self: Expression, +{ +} impl QueryFragment for Coerce where diff --git a/diesel/src/expression/count.rs b/diesel/src/expression/count.rs index c75f2cbde64d..6502340fc677 100644 --- a/diesel/src/expression/count.rs +++ b/diesel/src/expression/count.rs @@ -3,7 +3,7 @@ use super::{Expression, ValidGrouping}; use crate::backend::Backend; use crate::query_builder::*; use crate::result::QueryResult; -use crate::sql_types::{BigInt, DieselNumericOps}; +use crate::sql_types::{BigInt, DieselNumericOps, SingleValue, SqlType}; sql_function! { /// Creates a SQL `COUNT` expression @@ -25,7 +25,7 @@ sql_function! { /// # } /// ``` #[aggregate] - fn count(expr: T) -> BigInt; + fn count(expr: T) -> BigInt; } /// Creates a SQL `COUNT(*)` expression diff --git a/diesel/src/expression/exists.rs b/diesel/src/expression/exists.rs index 8b4d83ed54fe..385065f6bf94 100644 --- a/diesel/src/expression/exists.rs +++ b/diesel/src/expression/exists.rs @@ -32,21 +32,21 @@ pub fn exists(query: T) -> Exists { Exists(Subselect::new(query)) } -#[derive(Debug, Clone, Copy, QueryId)] -pub struct Exists(pub Subselect); +#[derive(Clone, Copy, QueryId, Debug)] +pub struct Exists(pub Subselect); impl Expression for Exists where - Subselect: Expression, + Subselect: Expression, { type SqlType = Bool; } impl ValidGrouping for Exists where - Subselect: ValidGrouping, + Subselect: ValidGrouping, { - type IsAggregate = as ValidGrouping>::IsAggregate; + type IsAggregate = as ValidGrouping>::IsAggregate; } #[cfg(not(feature = "unstable"))] @@ -80,13 +80,13 @@ where impl SelectableExpression for Exists where Self: AppearsOnTable, - Subselect: SelectableExpression, + Subselect: SelectableExpression, { } impl AppearsOnTable for Exists where Self: Expression, - Subselect: AppearsOnTable, + Subselect: AppearsOnTable, { } diff --git a/diesel/src/expression/functions/aggregate_ordering.rs b/diesel/src/expression/functions/aggregate_ordering.rs index 6048d11f0ac1..bb2889590a94 100644 --- a/diesel/src/expression/functions/aggregate_ordering.rs +++ b/diesel/src/expression/functions/aggregate_ordering.rs @@ -1,5 +1,17 @@ use crate::expression::functions::sql_function; -use crate::sql_types::{IntoNullable, SqlOrd}; +use crate::sql_types::{IntoNullable, SingleValue, SqlOrd, SqlType}; + +pub trait SqlOrdAggregate: SingleValue { + type Ret: SqlType + SingleValue; +} + +impl SqlOrdAggregate for T +where + T: SqlOrd + IntoNullable + SingleValue, + T::Nullable: SqlType + SingleValue, +{ + type Ret = T::Nullable; +} sql_function! { /// Represents a SQL `MAX` function. This function can only take types which are @@ -17,7 +29,7 @@ sql_function! { /// assert_eq!(Ok(Some(8)), animals.select(max(legs)).first(&connection)); /// # } #[aggregate] - fn max(expr: ST) -> ST::Nullable; + fn max(expr: ST) -> ST::Ret; } sql_function! { @@ -36,5 +48,5 @@ sql_function! { /// assert_eq!(Ok(Some(4)), animals.select(min(legs)).first(&connection)); /// # } #[aggregate] - fn min(expr: ST) -> ST::Nullable; + fn min(expr: ST) -> ST::Ret; } diff --git a/diesel/src/expression/mod.rs b/diesel/src/expression/mod.rs index fab5ac0e17a5..542c5d8234b6 100644 --- a/diesel/src/expression/mod.rs +++ b/diesel/src/expression/mod.rs @@ -82,6 +82,7 @@ pub use self::sql_literal::{SqlLiteral, UncheckedBind}; use crate::backend::Backend; use crate::dsl::AsExprOf; +use crate::sql_types::{HasSqlType, SingleValue, SqlType}; /// Represents a typed fragment of SQL. /// @@ -92,7 +93,52 @@ use crate::dsl::AsExprOf; /// implementing this directly. pub trait Expression { /// The type that this expression represents in SQL - type SqlType; + type SqlType: TypedExpressionType; +} + +/// Marker trait for possible types of [`Expression::SqlType`] +/// +/// [`Expression::SqlType`]: trait.Expression.html#associatedtype.SqlType +pub trait TypedExpressionType {} + +/// Possible types for []`Expression::SqlType`] +/// +/// [`Expression::SqlType`]: trait.Expression.html#associatedtype.SqlType +pub mod expression_types { + use super::{QueryMetadata, TypedExpressionType}; + use crate::backend::Backend; + use crate::sql_types::SingleValue; + + /// Query nodes with this expression type do not have a statically at compile + /// time known expression type. + /// + /// An example for such a query node in diesel itself, is `sql_query` as + /// we do not know which fields are returned from such a query at compile time. + /// + /// For loading values from queries returning a type of this expression, consider + /// using [`#[derive(QueryableByName)]`] on the corresponding result type. + /// + /// [`#[derive(QueryableByName)]`]: ../deserialize/derive.QueryableByName.html + #[derive(Clone, Copy, Debug)] + pub struct Untyped; + + /// Query nodes witch cannot be part of a select clause. + /// + /// If you see an error message containing `FromSqlRow` and this type + /// recheck that you have written a valid select clause + #[derive(Debug, Clone, Copy)] + pub struct NotSelectable; + + impl TypedExpressionType for Untyped {} + impl TypedExpressionType for NotSelectable {} + + impl TypedExpressionType for ST where ST: SingleValue {} + + impl QueryMetadata for DB { + fn row_metadata(_: &DB::MetadataLookup, row: &mut Vec>) { + row.push(None) + } + } } impl Expression for Box { @@ -103,6 +149,28 @@ impl<'a, T: Expression + ?Sized> Expression for &'a T { type SqlType = T::SqlType; } +/// A helper to translate type level sql type information into +/// runtime type information for specific queries +/// +/// If you do not implement a custom backend implementation +/// this trait is likely not relevant for you. +pub trait QueryMetadata: Backend { + /// The exact return value of this function is considerded to be a + /// backend specific implementation detail. You should not rely on those + /// values if you not own the corresponding backend + fn row_metadata(lookup: &Self::MetadataLookup, out: &mut Vec>); +} + +impl QueryMetadata for DB +where + DB: Backend + HasSqlType, + T: SingleValue, +{ + fn row_metadata(lookup: &Self::MetadataLookup, out: &mut Vec>) { + out.push(Some(>::metadata(lookup))) + } +} + /// Converts a type to its representation for use in Diesel's query builder. /// /// This trait is used directly. Apps should typically use [`IntoSql`] instead. @@ -124,7 +192,10 @@ impl<'a, T: Expression + ?Sized> Expression for &'a T { /// /// This trait could be [derived](derive.AsExpression.html) -pub trait AsExpression { +pub trait AsExpression +where + T: SqlType + TypedExpressionType, +{ /// The expression being returned type Expression: Expression; @@ -135,7 +206,11 @@ pub trait AsExpression { #[doc(inline)] pub use diesel_derives::AsExpression; -impl AsExpression for T { +impl AsExpression for T +where + T: Expression, + ST: SqlType + TypedExpressionType, +{ type Expression = Self; fn as_expression(self) -> Self { @@ -177,6 +252,7 @@ pub trait IntoSql { fn into_sql(self) -> AsExprOf where Self: AsExpression + Sized, + T: SqlType + TypedExpressionType, { self.as_expression() } @@ -188,6 +264,7 @@ pub trait IntoSql { fn as_sql<'a, T>(&'a self) -> AsExprOf<&'a Self, T> where &'a Self: AsExpression, + T: SqlType + TypedExpressionType, { self.as_expression() } @@ -432,7 +509,7 @@ use crate::query_builder::{QueryFragment, QueryId}; /// type DB = diesel::sqlite::Sqlite; /// # */ /// -/// fn find_user(search: Search) -> Box> { +/// fn find_user(search: Search) -> Box> { /// match search { /// Search::Id(id) => Box::new(users::id.eq(id)), /// Search::Name(name) => Box::new(users::name.eq(name)), diff --git a/diesel/src/expression/nullable.rs b/diesel/src/expression/nullable.rs index 70c400cea638..13a5bc268882 100644 --- a/diesel/src/expression/nullable.rs +++ b/diesel/src/expression/nullable.rs @@ -1,4 +1,5 @@ use crate::backend::Backend; +use crate::expression::TypedExpressionType; use crate::expression::*; use crate::query_builder::*; use crate::query_source::joins::ToInnerJoin; @@ -17,9 +18,10 @@ impl Nullable { impl Expression for Nullable where T: Expression, - ::SqlType: IntoNullable, + T::SqlType: IntoNullable, + ::Nullable: TypedExpressionType, { - type SqlType = <::SqlType as IntoNullable>::Nullable; + type SqlType = ::Nullable; } impl QueryFragment for Nullable diff --git a/diesel/src/expression/operators.rs b/diesel/src/expression/operators.rs index b5d9116f21e8..02e9d9481df2 100644 --- a/diesel/src/expression/operators.rs +++ b/diesel/src/expression/operators.rs @@ -5,7 +5,7 @@ macro_rules! __diesel_operator_body { notation = $notation:ident, struct_name = $name:ident, operator = $operator:expr, - return_ty = ReturnBasedOnArgs, + return_ty = (ReturnBasedOnArgs), ty_params = ($($ty_param:ident,)+), field_names = $field_names:tt, backend_ty_params = $backend_ty_params:tt, @@ -15,7 +15,7 @@ macro_rules! __diesel_operator_body { notation = $notation, struct_name = $name, operator = $operator, - return_ty = ST, + return_ty = (ST), ty_params = ($($ty_param,)+), field_names = $field_names, backend_ty_params = $backend_ty_params, @@ -29,7 +29,7 @@ macro_rules! __diesel_operator_body { notation = $notation:ident, struct_name = $name:ident, operator = $operator:expr, - return_ty = $return_ty:ty, + return_ty = ($($return_ty:tt)+), ty_params = ($($ty_param:ident,)+), field_names = $field_names:tt, backend_ty_params = $backend_ty_params:tt, @@ -39,7 +39,7 @@ macro_rules! __diesel_operator_body { notation = $notation, struct_name = $name, operator = $operator, - return_ty = $return_ty, + return_ty = ($($return_ty)*), ty_params = ($($ty_param,)+), field_names = $field_names, backend_ty_params = $backend_ty_params, @@ -53,7 +53,7 @@ macro_rules! __diesel_operator_body { notation = $notation:ident, struct_name = $name:ident, operator = $operator:expr, - return_ty = $return_ty:ty, + return_ty = ($($return_ty:tt)+), ty_params = ($($ty_param:ident,)+), field_names = ($($field_name:ident,)+), backend_ty_params = ($($backend_ty_param:ident,)*), @@ -85,7 +85,7 @@ macro_rules! __diesel_operator_body { impl<$($ty_param,)+ $($expression_ty_params,)*> $crate::expression::Expression for $name<$($ty_param,)+> where $($expression_bounds)* { - type SqlType = $return_ty; + type SqlType = $($return_ty)*; } impl<$($ty_param,)+ $($backend_ty_param,)*> $crate::query_builder::QueryFragment<$backend_ty> @@ -187,6 +187,8 @@ macro_rules! __diesel_operator_to_sql { /// /// ```rust /// # include!("../doctest_setup.rs"); +/// # use diesel::sql_types::SqlType; +/// # use diesel::expression::TypedExpressionType; /// # /// # fn main() { /// # use schema::users::dsl::*; @@ -196,9 +198,10 @@ macro_rules! __diesel_operator_to_sql { /// use diesel::expression::AsExpression; /// /// // Normally you would put this on a trait instead -/// fn my_eq(left: T, right: U) -> MyEq where -/// T: Expression, -/// U: AsExpression, +/// fn my_eq(left: T, right: U) -> MyEq where +/// T: Expression, +/// U: AsExpression, +/// ST: SqlType + TypedExpressionType, /// { /// MyEq::new(left, right.as_expression()) /// } @@ -223,11 +226,37 @@ macro_rules! infix_operator { notation = infix, struct_name = $name, operator = $operator, - return_ty = $($return_ty)::*, + return_ty = ( + $crate::sql_types::is_nullable::MaybeNullable< + $crate::sql_types::is_nullable::IsOneNullable< + ::SqlType, + ::SqlType + >, + $($return_ty)::* + > + ), ty_params = (T, U,), field_names = (left, right,), backend_ty_params = (DB,), backend_ty = DB, + expression_ty_params = (), + expression_bounds = ( + T: $crate::expression::Expression, + U: $crate::expression::Expression, + ::SqlType: $crate::sql_types::SqlType, + ::SqlType: $crate::sql_types::SqlType, + $crate::sql_types::is_nullable::IsSqlTypeNullable< + ::SqlType + >: $crate::sql_types::OneIsNullable< + $crate::sql_types::is_nullable::IsSqlTypeNullable< + ::SqlType + > + >, + $crate::sql_types::is_nullable::IsOneNullable< + ::SqlType, + ::SqlType + >: $crate::sql_types::MaybeNullableType<$($return_ty)::*>, + ), ); }; @@ -236,11 +265,37 @@ macro_rules! infix_operator { notation = infix, struct_name = $name, operator = $operator, - return_ty = $return_ty, + return_ty = ( + $crate::sql_types::is_nullable::MaybeNullable< + $crate::sql_types::is_nullable::IsOneNullable< + ::SqlType, + ::SqlType + >, + $return_ty, + > + ), ty_params = (T, U,), field_names = (left, right,), backend_ty_params = (), backend_ty = $backend, + expression_ty_params = (), + expression_bounds = ( + T: $crate::expression::Expression, + U: $crate::expression::Expression, + ::SqlType: $crate::sql_types::SqlType, + ::SqlType: $crate::sql_types::SqlType, + $crate::sql_types::is_nullable::IsSqlTypeNullable< + ::SqlType + >: $crate::sql_types::OneIsNullable< + $crate::sql_types::is_nullable::IsSqlTypeNullable< + ::SqlType + > + >, + $crate::sql_types::is_nullable::IsOneNullable< + ::SqlType, + ::SqlType + >: $crate::sql_types::MaybeNullableType<$return_ty>, + ), ); }; } @@ -278,7 +333,7 @@ macro_rules! postfix_operator { notation = postfix, struct_name = $name, operator = $operator, - return_ty = $return_ty, + return_ty = ($return_ty), ty_params = (Expr,), field_names = (expr,), backend_ty_params = (DB,), @@ -291,7 +346,7 @@ macro_rules! postfix_operator { notation = postfix, struct_name = $name, operator = $operator, - return_ty = $return_ty, + return_ty = ($return_ty), ty_params = (Expr,), field_names = (expr,), backend_ty_params = (), @@ -333,7 +388,7 @@ macro_rules! prefix_operator { notation = prefix, struct_name = $name, operator = $operator, - return_ty = $return_ty, + return_ty = ($return_ty), ty_params = (Expr,), field_names = (expr,), backend_ty_params = (DB,), @@ -346,7 +401,7 @@ macro_rules! prefix_operator { notation = prefix, struct_name = $name, operator = $operator, - return_ty = $return_ty, + return_ty = ($return_ty), ty_params = (Expr,), field_names = (expr,), backend_ty_params = (), @@ -377,20 +432,30 @@ infix_operator!(LtEq, " <= "); infix_operator!(NotBetween, " NOT BETWEEN "); infix_operator!(NotEq, " != "); infix_operator!(NotLike, " NOT LIKE "); -infix_operator!(Or, " OR "); postfix_operator!(IsNull, " IS NULL"); postfix_operator!(IsNotNull, " IS NOT NULL"); -postfix_operator!(Asc, " ASC", ()); -postfix_operator!(Desc, " DESC", ()); +postfix_operator!( + Asc, + " ASC ", + crate::expression::expression_types::NotSelectable +); +postfix_operator!( + Desc, + " DESC ", + crate::expression::expression_types::NotSelectable +); prefix_operator!(Not, "NOT "); -use crate::expression::ValidGrouping; +use crate::expression::{TypedExpressionType, ValidGrouping}; use crate::insertable::{ColumnInsertValue, Insertable}; use crate::query_builder::{QueryId, ValuesClause}; use crate::query_source::Column; -use crate::sql_types::DieselNumericOps; +use crate::sql_types::{ + is_nullable, AllAreNullable, Bool, DieselNumericOps, MaybeNullableType, SqlType, +}; +use crate::Expression; impl Insertable for Eq where @@ -432,6 +497,7 @@ impl crate::expression::Expression for Concat where L: crate::expression::Expression, R: crate::expression::Expression, + ST: SqlType + TypedExpressionType, { type SqlType = ST; } @@ -458,3 +524,58 @@ where Ok(()) } } + +// or is different +// it only evaluates to null if both sides are null +#[derive( + Debug, + Clone, + Copy, + crate::query_builder::QueryId, + crate::sql_types::DieselNumericOps, + crate::expression::ValidGrouping, +)] +#[doc(hidden)] +pub struct Or { + pub(crate) left: T, + pub(crate) right: U, +} + +impl Or { + pub fn new(left: T, right: U) -> Self { + Or { left, right } + } +} + +impl_selectable_expression!(Or); + +impl Expression for Or +where + T: Expression, + U: Expression, + T::SqlType: SqlType, + U::SqlType: SqlType, + is_nullable::IsSqlTypeNullable: + AllAreNullable>, + is_nullable::AreAllNullable: MaybeNullableType, +{ + type SqlType = + is_nullable::MaybeNullable, Bool>; +} + +impl crate::query_builder::QueryFragment for Or +where + DB: crate::backend::Backend, + T: crate::query_builder::QueryFragment, + U: crate::query_builder::QueryFragment, +{ + fn walk_ast( + &self, + mut out: crate::query_builder::AstPass, + ) -> crate::result::QueryResult<()> { + self.left.walk_ast(out.reborrow())?; + out.push_sql(" OR "); + self.right.walk_ast(out.reborrow())?; + Ok(()) + } +} diff --git a/diesel/src/expression/ops/mod.rs b/diesel/src/expression/ops/mod.rs index eea4aa4ac9c5..580a2932da0a 100644 --- a/diesel/src/expression/ops/mod.rs +++ b/diesel/src/expression/ops/mod.rs @@ -2,7 +2,8 @@ macro_rules! generic_numeric_expr_inner { ($tpe: ident, ($($param: ident),*), $op: ident, $fn_name: ident) => { impl ::std::ops::$op for $tpe<$($param),*> where $tpe<$($param),*>: $crate::expression::Expression, - <$tpe<$($param),*> as $crate::Expression>::SqlType: $crate::sql_types::ops::$op, + <$tpe<$($param),*> as $crate::Expression>::SqlType: $crate::sql_types::SqlType + $crate::sql_types::ops::$op, + <<$tpe<$($param),*> as $crate::Expression>::SqlType as $crate::sql_types::ops::$op>::Rhs: $crate::expression::TypedExpressionType, Rhs: $crate::expression::AsExpression< <<$tpe<$($param),*> as $crate::Expression>::SqlType as $crate::sql_types::ops::$op>::Rhs, >, diff --git a/diesel/src/expression/ops/numeric.rs b/diesel/src/expression/ops/numeric.rs index 8a7b79793197..1d5a5d4153cd 100644 --- a/diesel/src/expression/ops/numeric.rs +++ b/diesel/src/expression/ops/numeric.rs @@ -1,5 +1,5 @@ use crate::backend::Backend; -use crate::expression::{Expression, ValidGrouping}; +use crate::expression::{Expression, TypedExpressionType, ValidGrouping}; use crate::query_builder::*; use crate::result::QueryResult; use crate::sql_types; @@ -26,6 +26,7 @@ macro_rules! numeric_operation { Lhs: Expression, Lhs::SqlType: sql_types::ops::$name, Rhs: Expression, + ::Output: TypedExpressionType, { type SqlType = ::Output; } diff --git a/diesel/src/expression/sql_literal.rs b/diesel/src/expression/sql_literal.rs index bbc3dfee3792..68c0e0abe9f5 100644 --- a/diesel/src/expression/sql_literal.rs +++ b/diesel/src/expression/sql_literal.rs @@ -5,7 +5,7 @@ use crate::expression::*; use crate::query_builder::*; use crate::query_dsl::RunQueryDsl; use crate::result::QueryResult; -use crate::sql_types::DieselNumericOps; +use crate::sql_types::{DieselNumericOps, SqlType}; #[derive(Debug, Clone, DieselNumericOps)] #[must_use = "Queries are only executed when calling `load`, `get_result`, or similar."] @@ -18,7 +18,10 @@ pub struct SqlLiteral { _marker: PhantomData, } -impl SqlLiteral { +impl SqlLiteral +where + ST: TypedExpressionType, +{ #[doc(hidden)] pub fn new(sql: String, inner: T) -> Self { SqlLiteral { @@ -51,11 +54,11 @@ impl SqlLiteral { /// # fn main() { /// # use self::users::dsl::*; /// # use diesel::dsl::sql; - /// # use diesel::sql_types::{Integer, Text}; + /// # use diesel::sql_types::{Integer, Text, Bool}; /// # let connection = establish_connection(); /// let seans_id = users /// .select(id) - /// .filter(sql("name = ").bind::("Sean")) + /// .filter(sql::("name = ").bind::("Sean")) /// .get_result(&connection); /// assert_eq!(Ok(1), seans_id); /// @@ -81,14 +84,14 @@ impl SqlLiteral { /// # fn main() { /// # use self::users::dsl::*; /// # use diesel::dsl::sql; - /// # use diesel::sql_types::{Integer, Text}; + /// # use diesel::sql_types::{Integer, Text, Bool}; /// # let connection = establish_connection(); /// # diesel::insert_into(users).values(name.eq("Ryan")) /// # .execute(&connection).unwrap(); /// let query = users /// .select(name) /// .filter( - /// sql("id > ") + /// sql::("id > ") /// .bind::(1) /// .sql(" AND name <> ") /// .bind::("Ryan") @@ -100,6 +103,7 @@ impl SqlLiteral { /// ``` pub fn bind(self, bind_value: U) -> UncheckedBind where + BindST: SqlType + TypedExpressionType, U: AsExpression, { UncheckedBind::new(self, bind_value.as_expression()) @@ -132,14 +136,14 @@ impl SqlLiteral { /// # fn main() { /// # use self::users::dsl::*; /// # use diesel::dsl::sql; - /// # use diesel::sql_types::{Integer, Text}; + /// # use diesel::sql_types::Bool; /// # let connection = establish_connection(); /// # diesel::insert_into(users).values(name.eq("Ryan")) /// # .execute(&connection).unwrap(); /// let query = users /// .select(name) /// .filter( - /// sql("id > 1") + /// sql::("id > 1") /// .sql(" AND name <> 'Ryan'") /// ) /// .get_results(&connection); @@ -152,7 +156,10 @@ impl SqlLiteral { } } -impl Expression for SqlLiteral { +impl Expression for SqlLiteral +where + ST: TypedExpressionType, +{ type SqlType = ST; } @@ -175,15 +182,18 @@ impl QueryId for SqlLiteral { const HAS_STATIC_QUERY_ID: bool = false; } -impl Query for SqlLiteral { +impl Query for SqlLiteral +where + Self: Expression, +{ type SqlType = ST; } impl RunQueryDsl for SqlLiteral {} -impl SelectableExpression for SqlLiteral {} +impl SelectableExpression for SqlLiteral where Self: Expression {} -impl AppearsOnTable for SqlLiteral {} +impl AppearsOnTable for SqlLiteral where Self: Expression {} impl ValidGrouping for SqlLiteral { type IsAggregate = is_aggregate::Never; @@ -215,15 +225,19 @@ impl ValidGrouping for SqlLiteral { /// # /// # fn run_test() -> QueryResult<()> { /// # use schema::users::dsl::*; +/// # use diesel::sql_types::Bool; /// use diesel::dsl::sql; /// # let connection = establish_connection(); -/// let user = users.filter(sql("name = 'Sean'")).first(&connection)?; +/// let user = users.filter(sql::("name = 'Sean'")).first(&connection)?; /// let expected = (1, String::from("Sean")); /// assert_eq!(expected, user); /// # Ok(()) /// # } /// ``` -pub fn sql(sql: &str) -> SqlLiteral { +pub fn sql(sql: &str) -> SqlLiteral +where + ST: TypedExpressionType, +{ SqlLiteral::new(sql.into(), ()) } @@ -272,14 +286,14 @@ where /// # fn main() { /// # use self::users::dsl::*; /// # use diesel::dsl::sql; - /// # use diesel::sql_types::{Integer, Text}; + /// # use diesel::sql_types::{Integer, Bool}; /// # let connection = establish_connection(); /// # diesel::insert_into(users).values(name.eq("Ryan")) /// # .execute(&connection).unwrap(); /// let query = users /// .select(name) /// .filter( - /// sql("id > ") + /// sql::("id > ") /// .bind::(1) /// .sql(" AND name <> 'Ryan'") /// ) diff --git a/diesel/src/expression/subselect.rs b/diesel/src/expression/subselect.rs index 754d80df9f1e..a0bab4e8b240 100644 --- a/diesel/src/expression/subselect.rs +++ b/diesel/src/expression/subselect.rs @@ -4,6 +4,7 @@ use crate::expression::array_comparison::MaybeEmpty; use crate::expression::*; use crate::query_builder::*; use crate::result::QueryResult; +use crate::sql_types::SqlType; #[derive(Debug, Copy, Clone, QueryId)] pub struct Subselect { @@ -20,7 +21,10 @@ impl Subselect { } } -impl Expression for Subselect { +impl Expression for Subselect +where + ST: SqlType + TypedExpressionType, +{ type SqlType = ST; } diff --git a/diesel/src/expression_methods/bool_expression_methods.rs b/diesel/src/expression_methods/bool_expression_methods.rs index 663da37b2ddd..3125e321bc2d 100644 --- a/diesel/src/expression_methods/bool_expression_methods.rs +++ b/diesel/src/expression_methods/bool_expression_methods.rs @@ -1,7 +1,7 @@ use crate::expression::grouped::Grouped; use crate::expression::operators::{And, Or}; -use crate::expression::{AsExpression, Expression}; -use crate::sql_types::{Bool, Nullable}; +use crate::expression::{AsExpression, Expression, TypedExpressionType}; +use crate::sql_types::{BoolOrNullableBool, SqlType}; /// Methods present on boolean expressions pub trait BoolExpressionMethods: Expression + Sized { @@ -36,7 +36,13 @@ pub trait BoolExpressionMethods: Expression + Sized { /// assert_eq!(expected, data); /// # Ok(()) /// # } - fn and>(self, other: T) -> And { + fn and(self, other: T) -> And + where + Self::SqlType: SqlType, + ST: SqlType + TypedExpressionType, + T: AsExpression, + And: Expression, + { And::new(self, other.as_expression()) } @@ -77,7 +83,13 @@ pub trait BoolExpressionMethods: Expression + Sized { /// assert_eq!(expected, data); /// # Ok(()) /// # } - fn or>(self, other: T) -> Grouped> { + fn or(self, other: T) -> Grouped> + where + Self::SqlType: SqlType, + ST: SqlType + TypedExpressionType, + T: AsExpression, + Or: Expression, + { Grouped(Or::new(self, other.as_expression())) } } @@ -88,12 +100,3 @@ where T::SqlType: BoolOrNullableBool, { } - -#[doc(hidden)] -/// Marker trait used to implement `BoolExpressionMethods` on the appropriate -/// types. Once coherence takes associated types into account, we can remove -/// this trait. -pub trait BoolOrNullableBool {} - -impl BoolOrNullableBool for Bool {} -impl BoolOrNullableBool for Nullable {} diff --git a/diesel/src/expression_methods/global_expression_methods.rs b/diesel/src/expression_methods/global_expression_methods.rs index 88f30701b514..c0d4f7c798e7 100644 --- a/diesel/src/expression_methods/global_expression_methods.rs +++ b/diesel/src/expression_methods/global_expression_methods.rs @@ -1,7 +1,7 @@ use crate::expression::array_comparison::{AsInExpression, In, NotIn}; use crate::expression::operators::*; use crate::expression::{nullable, AsExpression, Expression}; -use crate::sql_types::SingleValue; +use crate::sql_types::{SingleValue, SqlType}; /// Methods present on all expressions, except tuples pub trait ExpressionMethods: Expression + Sized { @@ -19,7 +19,11 @@ pub trait ExpressionMethods: Expression + Sized { /// assert_eq!(Ok(1), data.first(&connection)); /// # } /// ``` - fn eq>(self, other: T) -> Eq { + fn eq(self, other: T) -> Eq + where + Self::SqlType: SqlType, + T: AsExpression, + { Eq::new(self, other.as_expression()) } @@ -37,7 +41,11 @@ pub trait ExpressionMethods: Expression + Sized { /// assert_eq!(Ok(2), data.first(&connection)); /// # } /// ``` - fn ne>(self, other: T) -> NotEq { + fn ne(self, other: T) -> NotEq + where + Self::SqlType: SqlType, + T: AsExpression, + { NotEq::new(self, other.as_expression()) } @@ -68,6 +76,7 @@ pub trait ExpressionMethods: Expression + Sized { /// ``` fn eq_any(self, values: T) -> In where + Self::SqlType: SqlType, T: AsInExpression, { In::new(self, values.as_in_expression()) @@ -103,6 +112,7 @@ pub trait ExpressionMethods: Expression + Sized { /// ``` fn ne_all(self, values: T) -> NotIn where + Self::SqlType: SqlType, T: AsInExpression, { NotIn::new(self, values.as_in_expression()) @@ -182,7 +192,11 @@ pub trait ExpressionMethods: Expression + Sized { /// # Ok(()) /// # } /// ``` - fn gt>(self, other: T) -> Gt { + fn gt(self, other: T) -> Gt + where + Self::SqlType: SqlType, + T: AsExpression, + { Gt::new(self, other.as_expression()) } @@ -208,7 +222,11 @@ pub trait ExpressionMethods: Expression + Sized { /// # Ok(()) /// # } /// ``` - fn ge>(self, other: T) -> GtEq { + fn ge(self, other: T) -> GtEq + where + Self::SqlType: SqlType, + T: AsExpression, + { GtEq::new(self, other.as_expression()) } @@ -234,7 +252,11 @@ pub trait ExpressionMethods: Expression + Sized { /// # Ok(()) /// # } /// ``` - fn lt>(self, other: T) -> Lt { + fn lt(self, other: T) -> Lt + where + Self::SqlType: SqlType, + T: AsExpression, + { Lt::new(self, other.as_expression()) } @@ -259,7 +281,11 @@ pub trait ExpressionMethods: Expression + Sized { /// assert_eq!("Sean", data); /// # Ok(()) /// # } - fn le>(self, other: T) -> LtEq { + fn le(self, other: T) -> LtEq + where + Self::SqlType: SqlType, + T: AsExpression, + { LtEq::new(self, other.as_expression()) } @@ -285,6 +311,7 @@ pub trait ExpressionMethods: Expression + Sized { /// ``` fn between(self, lower: T, upper: U) -> Between> where + Self::SqlType: SqlType, T: AsExpression, U: AsExpression, { @@ -320,6 +347,7 @@ pub trait ExpressionMethods: Expression + Sized { upper: U, ) -> NotBetween> where + Self::SqlType: SqlType, T: AsExpression, U: AsExpression, { @@ -365,11 +393,12 @@ pub trait ExpressionMethods: Expression + Sized { /// /// ```rust /// # include!("../doctest_setup.rs"); + /// # use diesel::expression::expression_types::NotSelectable; /// # /// # fn main() { /// # use schema::users::dsl::*; /// # let order = "name"; - /// let ordering: Box> = + /// let ordering: Box> = /// if order == "name" { /// Box::new(name.desc()) /// } else { diff --git a/diesel/src/expression_methods/text_expression_methods.rs b/diesel/src/expression_methods/text_expression_methods.rs index 913690a9eb62..8527edc68038 100644 --- a/diesel/src/expression_methods/text_expression_methods.rs +++ b/diesel/src/expression_methods/text_expression_methods.rs @@ -1,6 +1,6 @@ use crate::expression::operators::{Concat, Like, NotLike}; use crate::expression::{AsExpression, Expression}; -use crate::sql_types::{Nullable, Text}; +use crate::sql_types::{Nullable, SqlType, Text}; /// Methods present on text expressions pub trait TextExpressionMethods: Expression + Sized { @@ -54,7 +54,11 @@ pub trait TextExpressionMethods: Expression + Sized { /// assert_eq!(Ok(expected_names), names); /// # } /// ``` - fn concat>(self, other: T) -> Concat { + fn concat(self, other: T) -> Concat + where + Self::SqlType: SqlType, + T: AsExpression, + { Concat::new(self, other.as_expression()) } @@ -86,8 +90,12 @@ pub trait TextExpressionMethods: Expression + Sized { /// # Ok(()) /// # } /// ``` - fn like>(self, other: T) -> Like { - Like::new(self.as_expression(), other.as_expression()) + fn like(self, other: T) -> Like + where + Self::SqlType: SqlType, + T: AsExpression, + { + Like::new(self, other.as_expression()) } /// Returns a SQL `NOT LIKE` expression @@ -118,8 +126,12 @@ pub trait TextExpressionMethods: Expression + Sized { /// # Ok(()) /// # } /// ``` - fn not_like>(self, other: T) -> NotLike { - NotLike::new(self.as_expression(), other.as_expression()) + fn not_like(self, other: T) -> NotLike + where + Self::SqlType: SqlType, + T: AsExpression, + { + NotLike::new(self, other.as_expression()) } } diff --git a/diesel/src/macros/mod.rs b/diesel/src/macros/mod.rs index 28d8f5d5c659..3a7e0424ff71 100644 --- a/diesel/src/macros/mod.rs +++ b/diesel/src/macros/mod.rs @@ -101,9 +101,9 @@ macro_rules! __diesel_column { impl $crate::EqAll for $column_name where T: $crate::expression::AsExpression<$($Type)*>, - $crate::dsl::Eq<$column_name, T>: $crate::Expression, + $crate::dsl::Eq<$column_name, T::Expression>: $crate::Expression, { - type Output = $crate::dsl::Eq; + type Output = $crate::dsl::Eq; fn eq_all(self, rhs: T) -> Self::Output { $crate::expression::operators::Eq::new(self, rhs.as_expression()) @@ -180,6 +180,7 @@ macro_rules! __diesel_column { /// /// ``` /// # mod diesel_full_text_search { +/// # #[derive(diesel::sql_types::SqlType)] /// # pub struct TsVector; /// # } /// @@ -819,7 +820,7 @@ macro_rules! __diesel_table_impl { } impl Expression for star { - type SqlType = (); + type SqlType = $crate::expression::expression_types::NotSelectable; } impl QueryFragment for star where @@ -1057,19 +1058,6 @@ macro_rules! allow_tables_to_appear_in_same_query { () => {}; } -/// Gets the value out of an option, or returns an error. -/// -/// This is used by `FromSql` implementations. -#[macro_export] -macro_rules! not_none { - ($bytes:expr) => { - match $bytes { - Some(bytes) => bytes, - None => return Err(Box::new($crate::result::UnexpectedNullError)), - } - }; -} - // The order of these modules is important (at least for those which have tests). // Utility macros which don't call any others need to come first. #[macro_use] @@ -1093,7 +1081,7 @@ mod tests { } mod my_types { - #[derive(Debug, Clone, Copy)] + #[derive(Debug, Clone, Copy, crate::sql_types::SqlType)] pub struct MyCustomType; } @@ -1141,11 +1129,11 @@ mod tests { table_with_arbitrarily_complex_types { id -> sql_types::Integer, qualified_nullable -> sql_types::Nullable, - deeply_nested_type -> Option>, + deeply_nested_type -> Nullable>, // This actually should work, but there appears to be a rustc bug // on the `AsExpression` bound for `EqAll` when the ty param is a projection // projected_type -> as sql_types::IntoNullable>::Nullable, - random_tuple -> (Integer, Integer), + //random_tuple -> (Integer, Integer), } } diff --git a/diesel/src/macros/ops.rs b/diesel/src/macros/ops.rs index 63292c31e60f..eac8f12f6ab2 100644 --- a/diesel/src/macros/ops.rs +++ b/diesel/src/macros/ops.rs @@ -39,7 +39,7 @@ macro_rules! numeric_expr { #[doc(hidden)] macro_rules! __diesel_generate_ops_impls_if_numeric { ($column_name:ident, Nullable<$($inner:tt)::*>) => { __diesel_generate_ops_impls_if_numeric!($column_name, $($inner)::*); }; - + ($column_name:ident, Unsigned<$($inner:tt)::*>) => { __diesel_generate_ops_impls_if_numeric!($column_name, $($inner)::*); }; ($column_name:ident, SmallInt) => { numeric_expr!($column_name); }; diff --git a/diesel/src/mysql/connection/bind.rs b/diesel/src/mysql/connection/bind.rs index 74c07ddf243f..816ceb1058b9 100644 --- a/diesel/src/mysql/connection/bind.rs +++ b/diesel/src/mysql/connection/bind.rs @@ -1,8 +1,10 @@ use mysqlclient_sys as ffi; use std::mem; +use std::ops::Index; use std::os::raw as libc; use super::stmt::Statement; +use crate::mysql::connection::stmt::StatementMetadata; use crate::mysql::types::MYSQL_TIME; use crate::mysql::{MysqlType, MysqlValue}; use crate::result::QueryResult; @@ -12,42 +14,30 @@ pub struct Binds { } impl Binds { - pub fn from_input_data(input: Iter) -> Self + pub fn from_input_data(input: Iter) -> QueryResult where Iter: IntoIterator>)>, { let data = input .into_iter() - .map(|(metadata, bytes)| BindData::for_input(metadata, bytes)) - .collect(); - - Binds { data } - } - - pub fn from_output_types(types: Vec) -> Self { - let data = types - .into_iter() - .map(|metadata| metadata.into()) - .map(BindData::for_output) - .collect(); + .map(BindData::for_input) + .collect::>(); - Binds { data } + Ok(Binds { data }) } - pub fn from_result_metadata(fields: &[ffi::MYSQL_FIELD]) -> Self { - let data = fields + pub fn from_output_types(types: Vec>, metadata: &StatementMetadata) -> Self { + let data = metadata + .fields() .iter() - .map(|field| { - ( - field.type_, - Flags::from_bits(field.flags).expect( - "We encountered a unknown type flag while parsing \ - Mysql's type information. If you see this error message \ - please open an issue at diesels github page.", - ), - ) + .zip(types.into_iter().chain(std::iter::repeat(None))) + .map(|(field, tpe)| { + if let Some(tpe) = tpe { + BindData::for_output(tpe.into()) + } else { + BindData::for_output((field.field_type(), field.flags())) + } }) - .map(BindData::for_output) .collect(); Binds { data } @@ -88,17 +78,20 @@ impl Binds { } } - pub fn field_data(&self, idx: usize) -> Option> { - let data = &self.data[idx]; - self.data[idx].bytes().map(|bytes| { - let tpe = (data.tpe, data.flags).into(); - MysqlValue::new(bytes, tpe) - }) + pub fn len(&self) -> usize { + self.data.len() + } +} + +impl Index for Binds { + type Output = BindData; + fn index(&self, index: usize) -> &Self::Output { + &self.data[index] } } bitflags::bitflags! { - struct Flags: u32 { + pub(crate) struct Flags: u32 { const NOT_NULL_FLAG = 1; const PRI_KEY_FAG = 2; const UNIQUE_KEY_FLAG = 4; @@ -123,7 +116,18 @@ bitflags::bitflags! { } } -struct BindData { +impl From for Flags { + fn from(flags: u32) -> Self { + Flags::from_bits(flags).expect( + "We encountered a unknown type flag while parsing \ + Mysql's type information. If you see this error message \ + please open an issue at diesels github page.", + ) + } +} + +#[derive(Debug)] +pub struct BindData { tpe: ffi::enum_field_types, bytes: Vec, length: libc::c_ulong, @@ -133,7 +137,7 @@ struct BindData { } impl BindData { - fn for_input(tpe: MysqlType, data: Option>) -> Self { + fn for_input((tpe, data): (MysqlType, Option>)) -> Self { let is_null = if data.is_none() { 1 } else { 0 }; let bytes = data.unwrap_or_default(); let length = bytes.len() as libc::c_ulong; @@ -172,14 +176,19 @@ impl BindData { known_buffer_size_for_ffi_type(self.tpe).is_some() } - fn bytes(&self) -> Option<&[u8]> { - if self.is_null == 0 { - Some(&*self.bytes) - } else { + pub fn value(&'_ self) -> Option> { + if self.is_null() { None + } else { + let tpe = (self.tpe, self.flags).into(); + Some(MysqlValue::new(&self.bytes, tpe)) } } + pub fn is_null(&self) -> bool { + self.is_null != 0 + } + fn update_buffer_length(&mut self) { use std::cmp::min; @@ -446,7 +455,7 @@ mod tests { let meta = (bind.tpe, bind.flags).into(); dbg!(meta); let value = MysqlValue::new(&bind.bytes, meta); - dbg!(T::from_sql(Some(value))) + dbg!(T::from_sql(value)) } #[cfg(feature = "extras")] @@ -552,11 +561,12 @@ mod tests { .unwrap(); let metadata = stmt.metadata().unwrap(); - let mut output_binds = Binds::from_result_metadata(metadata.fields()); + let mut output_binds = + Binds::from_output_types(vec![None; metadata.fields().len()], &metadata); stmt.execute_statement(&mut output_binds).unwrap(); stmt.populate_row_buffers(&mut output_binds).unwrap(); - let results: Vec<(BindData, &ffi::st_mysql_field)> = output_binds + let results: Vec<(BindData, &_)> = output_binds .data .into_iter() .zip(metadata.fields()) diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index b675a216f51a..e9aa3689dc4d 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -8,11 +8,11 @@ use self::stmt::Statement; use self::url::ConnectionOptions; use super::backend::Mysql; use crate::connection::*; -use crate::deserialize::{Queryable, QueryableByName}; +use crate::deserialize::FromSqlRow; +use crate::expression::QueryMetadata; use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::query_builder::*; use crate::result::*; -use crate::sql_types::HasSqlType; #[allow(missing_debug_implementations, missing_copy_implementations)] /// A connection to a MySQL database. Connection URLs should be in the form @@ -60,38 +60,20 @@ impl Connection for MysqlConnection { } #[doc(hidden)] - fn query_by_index(&self, source: T) -> QueryResult> + fn load(&self, source: T) -> QueryResult> where T: AsQuery, T::Query: QueryFragment + QueryId, - Self::Backend: HasSqlType, - U: Queryable, + U: FromSqlRow, + Self::Backend: QueryMetadata, { - use crate::deserialize::FromSqlRow; use crate::result::Error::DeserializationError; let mut stmt = self.prepare_query(&source.as_query())?; let mut metadata = Vec::new(); - Mysql::mysql_row_metadata(&mut metadata, &()); + Mysql::row_metadata(&(), &mut metadata); let results = unsafe { stmt.results(metadata)? }; - results.map(|mut row| { - U::Row::build_from_row(&mut row) - .map(U::build) - .map_err(DeserializationError) - }) - } - - #[doc(hidden)] - fn query_by_name(&self, source: &T) -> QueryResult> - where - T: QueryFragment + QueryId, - U: QueryableByName, - { - use crate::result::Error::DeserializationError; - - let mut stmt = self.prepare_query(source)?; - let results = unsafe { stmt.named_results()? }; - results.map(|row| U::build(&row).map_err(DeserializationError)) + results.map(|row| U::build_from_row(&row).map_err(DeserializationError)) } #[doc(hidden)] diff --git a/diesel/src/mysql/connection/stmt/iterator.rs b/diesel/src/mysql/connection/stmt/iterator.rs index 2321efcacd07..584b66f949ae 100644 --- a/diesel/src/mysql/connection/stmt/iterator.rs +++ b/diesel/src/mysql/connection/stmt/iterator.rs @@ -1,24 +1,29 @@ -use std::collections::HashMap; - -use super::{Binds, Statement, StatementMetadata}; -use crate::mysql::{Mysql, MysqlType, MysqlValue}; +use super::{metadata::MysqlFieldMetadata, BindData, Binds, Statement, StatementMetadata}; +use crate::mysql::{Mysql, MysqlType}; use crate::result::QueryResult; use crate::row::*; pub struct StatementIterator<'a> { stmt: &'a mut Statement, output_binds: Binds, + metadata: StatementMetadata, } #[allow(clippy::should_implement_trait)] // don't neet `Iterator` here impl<'a> StatementIterator<'a> { #[allow(clippy::new_ret_no_self)] - pub fn new(stmt: &'a mut Statement, types: Vec) -> QueryResult { - let mut output_binds = Binds::from_output_types(types); + pub fn new(stmt: &'a mut Statement, types: Vec>) -> QueryResult { + let metadata = stmt.metadata()?; + + let mut output_binds = Binds::from_output_types(types, &metadata); stmt.execute_statement(&mut output_binds)?; - Ok(StatementIterator { stmt, output_binds }) + Ok(StatementIterator { + stmt, + output_binds, + metadata, + }) } pub fn map(mut self, mut f: F) -> QueryResult> @@ -37,6 +42,7 @@ impl<'a> StatementIterator<'a> { Ok(Some(())) => Some(Ok(MysqlRow { col_idx: 0, binds: &mut self.output_binds, + metadata: &self.metadata, })), Ok(None) => None, Err(e) => Some(Err(e)), @@ -44,79 +50,73 @@ impl<'a> StatementIterator<'a> { } } +#[derive(Clone)] pub struct MysqlRow<'a> { col_idx: usize, binds: &'a Binds, + metadata: &'a StatementMetadata, } -impl<'a> Row for MysqlRow<'a> { - fn take(&mut self) -> Option> { - let current_idx = self.col_idx; - self.col_idx += 1; - self.binds.field_data(current_idx) - } +impl<'a> Row<'a, Mysql> for MysqlRow<'a> { + type Field = MysqlField<'a>; + type InnerPartialRow = Self; - fn next_is_null(&self, count: usize) -> bool { - (0..count).all(|i| self.binds.field_data(self.col_idx + i).is_none()) + fn field_count(&self) -> usize { + self.binds.len() } -} - -pub struct NamedStatementIterator<'a> { - stmt: &'a mut Statement, - output_binds: Binds, - metadata: StatementMetadata, -} -#[allow(clippy::should_implement_trait)] // don't need `Iterator` here -impl<'a> NamedStatementIterator<'a> { - #[allow(clippy::new_ret_no_self)] - pub fn new(stmt: &'a mut Statement) -> QueryResult { - let metadata = stmt.metadata()?; - let mut output_binds = Binds::from_result_metadata(metadata.fields()); - - stmt.execute_statement(&mut output_binds)?; - - Ok(NamedStatementIterator { - stmt, - output_binds, - metadata, + fn get(&self, idx: I) -> Option + where + Self: RowIndex, + { + let idx = self.idx(idx)?; + Some(MysqlField { + bind: &self.binds[idx], + metadata: &self.metadata.fields()[idx], }) } - pub fn map(mut self, mut f: F) -> QueryResult> - where - F: FnMut(NamedMysqlRow) -> QueryResult, - { - let mut results = Vec::new(); - while let Some(row) = self.next() { - results.push(f(row?)?); - } - Ok(results) + fn partial_row(&self, range: std::ops::Range) -> PartialRow { + PartialRow::new(self, range) } +} - fn next(&mut self) -> Option> { - match self.stmt.populate_row_buffers(&mut self.output_binds) { - Ok(Some(())) => Some(Ok(NamedMysqlRow { - binds: &self.output_binds, - column_indices: self.metadata.column_indices(), - })), - Ok(None) => None, - Err(e) => Some(Err(e)), +impl<'a> RowIndex for MysqlRow<'a> { + fn idx(&self, idx: usize) -> Option { + if idx < self.field_count() { + Some(idx) + } else { + None } } } -pub struct NamedMysqlRow<'a> { - binds: &'a Binds, - column_indices: &'a HashMap<&'a str, usize>, +impl<'a, 'b> RowIndex<&'a str> for MysqlRow<'b> { + fn idx(&self, idx: &'a str) -> Option { + self.metadata + .fields() + .iter() + .enumerate() + .find(|(_, field_meta)| field_meta.field_name() == Some(idx)) + .map(|(idx, _)| idx) + } +} + +pub struct MysqlField<'a> { + bind: &'a BindData, + metadata: &'a MysqlFieldMetadata<'a>, } -impl<'a> NamedRow for NamedMysqlRow<'a> { - fn index_of(&self, column_name: &str) -> Option { - self.column_indices.get(column_name).cloned() +impl<'a> Field<'a, Mysql> for MysqlField<'a> { + fn field_name(&self) -> Option<&'a str> { + self.metadata.field_name() + } + + fn is_null(&self) -> bool { + self.bind.is_null() } - fn get_raw_value(&self, idx: usize) -> Option> { - self.binds.field_data(idx) + fn value(&self) -> Option> { + self.bind.value() } } diff --git a/diesel/src/mysql/connection/stmt/metadata.rs b/diesel/src/mysql/connection/stmt/metadata.rs index 331c10013b94..7a79ee92c51f 100644 --- a/diesel/src/mysql/connection/stmt/metadata.rs +++ b/diesel/src/mysql/connection/stmt/metadata.rs @@ -1,52 +1,60 @@ -use std::collections::HashMap; use std::ffi::CStr; +use std::ptr::NonNull; use std::slice; use super::ffi; +use crate::mysql::connection::bind::Flags; pub struct StatementMetadata { - result: &'static mut ffi::MYSQL_RES, - column_indices: HashMap<&'static str, usize>, + result: NonNull, } impl StatementMetadata { - pub fn new(result: &'static mut ffi::MYSQL_RES) -> Self { - let mut res = StatementMetadata { - column_indices: HashMap::new(), - result, - }; - res.populate_column_indices(); - res + pub fn new(result: NonNull) -> Self { + StatementMetadata { result } } - pub fn fields(&self) -> &[ffi::MYSQL_FIELD] { + pub fn fields(&'_ self) -> &'_ [MysqlFieldMetadata<'_>] { unsafe { - let ptr = self.result as *const _ as *mut _; - let num_fields = ffi::mysql_num_fields(ptr); - let field_ptr = ffi::mysql_fetch_fields(ptr); - slice::from_raw_parts(field_ptr, num_fields as usize) + let num_fields = ffi::mysql_num_fields(self.result.as_ptr()); + let field_ptr = ffi::mysql_fetch_fields(self.result.as_ptr()); + if field_ptr.is_null() { + &[] + } else { + slice::from_raw_parts(field_ptr as _, num_fields as usize) + } } } +} - pub fn column_indices(&self) -> &HashMap<&str, usize> { - &self.column_indices +impl Drop for StatementMetadata { + fn drop(&mut self) { + unsafe { ffi::mysql_free_result(self.result.as_mut()) }; } +} + +#[repr(transparent)] +pub struct MysqlFieldMetadata<'a>(ffi::MYSQL_FIELD, std::marker::PhantomData<&'a ()>); - fn populate_column_indices(&mut self) { - self.column_indices = self - .fields() - .iter() - .enumerate() - .map(|(i, field)| { - let c_name = unsafe { CStr::from_ptr(field.name) }; - (c_name.to_str().unwrap_or_default(), i) - }) - .collect() +impl<'a> MysqlFieldMetadata<'a> { + pub fn field_name(&self) -> Option<&str> { + if self.0.name.is_null() { + None + } else { + unsafe { + Some(CStr::from_ptr(self.0.name).to_str().expect( + "Expect mysql field names to be UTF-8, because we \ + requested UTF-8 encoding on connection setup", + )) + } + } } -} -impl Drop for StatementMetadata { - fn drop(&mut self) { - unsafe { ffi::mysql_free_result(self.result) }; + pub fn field_type(&self) -> ffi::enum_field_types { + self.0.type_ + } + + pub(crate) fn flags(&self) -> Flags { + Flags::from(self.0.flags) } } diff --git a/diesel/src/mysql/connection/stmt/mod.rs b/diesel/src/mysql/connection/stmt/mod.rs index 713a4dd74fc6..7fb043327207 100644 --- a/diesel/src/mysql/connection/stmt/mod.rs +++ b/diesel/src/mysql/connection/stmt/mod.rs @@ -8,11 +8,12 @@ use std::os::raw as libc; use std::ptr::NonNull; use self::iterator::*; -use self::metadata::*; -use super::bind::Binds; +use super::bind::{BindData, Binds}; use crate::mysql::MysqlType; use crate::result::{DatabaseErrorKind, QueryResult}; +pub use self::metadata::StatementMetadata; + pub struct Statement { stmt: NonNull, input_binds: Option, @@ -21,7 +22,7 @@ pub struct Statement { impl Statement { pub(crate) fn new(stmt: NonNull) -> Self { Statement { - stmt: stmt, + stmt, input_binds: None, } } @@ -41,7 +42,7 @@ impl Statement { where Iter: IntoIterator>)>, { - let input_binds = Binds::from_input_data(binds); + let input_binds = Binds::from_input_data(binds)?; self.input_bind(input_binds) } @@ -76,17 +77,13 @@ impl Statement { /// This function should be called instead of `execute` for queries which /// have a return value. After calling this function, `execute` can never /// be called on this statement. - pub unsafe fn results(&mut self, types: Vec) -> QueryResult { + pub unsafe fn results( + &mut self, + types: Vec>, + ) -> QueryResult { StatementIterator::new(self, types) } - /// This function should be called instead of `execute` for queries which - /// have a return value. After calling this function, `execute` can never - /// be called on this statement. - pub unsafe fn named_results(&mut self) -> QueryResult { - NamedStatementIterator::new(self) - } - fn last_error_message(&self) -> String { unsafe { CStr::from_ptr(ffi::mysql_stmt_error(self.stmt.as_ptr())) } .to_string_lossy() @@ -118,9 +115,9 @@ impl Statement { pub(super) fn metadata(&self) -> QueryResult { use crate::result::Error::DeserializationError; - let result_ptr = unsafe { ffi::mysql_stmt_result_metadata(self.stmt.as_ptr()).as_mut() }; + let result_ptr = unsafe { ffi::mysql_stmt_result_metadata(self.stmt.as_ptr()) }; self.did_an_error_occur()?; - result_ptr + NonNull::new(result_ptr) .map(StatementMetadata::new) .ok_or_else(|| DeserializationError("No metadata exists".into())) } diff --git a/diesel/src/mysql/types/date_and_time.rs b/diesel/src/mysql/types/date_and_time.rs index 9709c1d2ef3f..093dd49fe0bd 100644 --- a/diesel/src/mysql/types/date_and_time.rs +++ b/diesel/src/mysql/types/date_and_time.rs @@ -24,9 +24,8 @@ macro_rules! mysql_time_impls { } impl FromSql<$ty, Mysql> for MYSQL_TIME { - fn from_sql(value: Option>) -> deserialize::Result { - let data = not_none!(value); - data.time_value() + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { + value.time_value() } } }; @@ -44,7 +43,7 @@ impl ToSql for NaiveDateTime { } impl FromSql for NaiveDateTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { >::from_sql(bytes) } } @@ -69,7 +68,7 @@ impl ToSql for NaiveDateTime { } impl FromSql for NaiveDateTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { let mysql_time = >::from_sql(bytes)?; NaiveDate::from_ymd_opt( @@ -109,7 +108,7 @@ impl ToSql for NaiveTime { } impl FromSql for NaiveTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { let mysql_time = >::from_sql(bytes)?; NaiveTime::from_hms_opt( mysql_time.hour as u32, @@ -140,7 +139,7 @@ impl ToSql for NaiveDate { } impl FromSql for NaiveDate { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { let mysql_time = >::from_sql(bytes)?; NaiveDate::from_ymd_opt( mysql_time.year as i32, diff --git a/diesel/src/mysql/types/json.rs b/diesel/src/mysql/types/json.rs index 237027b13b91..8d67636995c2 100644 --- a/diesel/src/mysql/types/json.rs +++ b/diesel/src/mysql/types/json.rs @@ -5,8 +5,7 @@ use crate::sql_types; use std::io::prelude::*; impl FromSql for serde_json::Value { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { serde_json::from_slice(value.as_bytes()).map_err(|_| "Invalid Json".into()) } } @@ -31,25 +30,24 @@ fn json_to_sql() { fn some_json_from_sql() { use crate::mysql::MysqlType; let input_json = b"true"; - let output_json: serde_json::Value = FromSql::::from_sql(Some( - MysqlValue::new(input_json, MysqlType::String), - )) - .unwrap(); + let output_json: serde_json::Value = + FromSql::::from_sql(MysqlValue::new(input_json, MysqlType::String)) + .unwrap(); assert_eq!(output_json, serde_json::Value::Bool(true)); } #[test] fn bad_json_from_sql() { use crate::mysql::MysqlType; - let uuid: Result = FromSql::::from_sql(Some( - MysqlValue::new(b"boom", MysqlType::String), - )); + let uuid: Result = + FromSql::::from_sql(MysqlValue::new(b"boom", MysqlType::String)); assert_eq!(uuid.unwrap_err().to_string(), "Invalid Json"); } #[test] fn no_json_from_sql() { - let uuid: Result = FromSql::::from_sql(None); + let uuid: Result = + FromSql::::from_nullable_sql(None); assert_eq!( uuid.unwrap_err().to_string(), "Unexpected null for non-null column" diff --git a/diesel/src/mysql/types/mod.rs b/diesel/src/mysql/types/mod.rs index 302bb8683ba7..701a5f32d15d 100644 --- a/diesel/src/mysql/types/mod.rs +++ b/diesel/src/mysql/types/mod.rs @@ -46,8 +46,7 @@ impl ToSql for i8 { } impl FromSql for i8 { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { let bytes = value.as_bytes(); Ok(bytes[0] as i8) } @@ -96,7 +95,7 @@ impl ToSql, Mysql> for u8 { } impl FromSql, Mysql> for u8 { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { let signed: i8 = FromSql::::from_sql(bytes)?; Ok(signed as u8) } @@ -109,7 +108,7 @@ impl ToSql, Mysql> for u16 { } impl FromSql, Mysql> for u16 { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { let signed: i16 = FromSql::::from_sql(bytes)?; Ok(signed as u16) } @@ -122,7 +121,7 @@ impl ToSql, Mysql> for u32 { } impl FromSql, Mysql> for u32 { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { let signed: i32 = FromSql::::from_sql(bytes)?; Ok(signed as u32) } @@ -135,7 +134,7 @@ impl ToSql, Mysql> for u64 { } impl FromSql, Mysql> for u64 { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { let signed: i64 = FromSql::::from_sql(bytes)?; Ok(signed as u64) } @@ -149,8 +148,8 @@ impl ToSql for bool { } impl FromSql for bool { - fn from_sql(bytes: Option>) -> deserialize::Result { - Ok(not_none!(bytes).as_bytes().iter().any(|x| *x != 0)) + fn from_sql(bytes: MysqlValue<'_>) -> deserialize::Result { + Ok(bytes.as_bytes().iter().any(|x| *x != 0)) } } diff --git a/diesel/src/mysql/types/numeric.rs b/diesel/src/mysql/types/numeric.rs index 38b6f23d9315..783f813383c7 100644 --- a/diesel/src/mysql/types/numeric.rs +++ b/diesel/src/mysql/types/numeric.rs @@ -19,10 +19,10 @@ pub mod bigdecimal { } impl FromSql for BigDecimal { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { use crate::mysql::NumericRepresentation::*; - let data = not_none!(value); - match data.numeric_value()? { + + match value.numeric_value()? { Tiny(x) => Ok(x.into()), Small(x) => Ok(x.into()), Medium(x) => Ok(x.into()), diff --git a/diesel/src/mysql/types/primitives.rs b/diesel/src/mysql/types/primitives.rs index 6e6d9c0f5733..d2583d91a1e9 100644 --- a/diesel/src/mysql/types/primitives.rs +++ b/diesel/src/mysql/types/primitives.rs @@ -29,11 +29,10 @@ where } impl FromSql for i16 { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { use crate::mysql::NumericRepresentation::*; - let data = not_none!(value); - match data.numeric_value()? { + match value.numeric_value()? { Tiny(x) => Ok(x.into()), Small(x) => Ok(x), Medium(x) => Ok(x as Self), @@ -46,11 +45,10 @@ impl FromSql for i16 { } impl FromSql for i32 { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { use crate::mysql::NumericRepresentation::*; - let data = not_none!(value); - match data.numeric_value()? { + match value.numeric_value()? { Tiny(x) => Ok(x.into()), Small(x) => Ok(x.into()), Medium(x) => Ok(x), @@ -63,11 +61,10 @@ impl FromSql for i32 { } impl FromSql for i64 { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { use crate::mysql::NumericRepresentation::*; - let data = not_none!(value); - match data.numeric_value()? { + match value.numeric_value()? { Tiny(x) => Ok(x.into()), Small(x) => Ok(x.into()), Medium(x) => Ok(x.into()), @@ -80,11 +77,10 @@ impl FromSql for i64 { } impl FromSql for f32 { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { use crate::mysql::NumericRepresentation::*; - let data = not_none!(value); - match data.numeric_value()? { + match value.numeric_value()? { Tiny(x) => Ok(x.into()), Small(x) => Ok(x.into()), Medium(x) => Ok(x as Self), @@ -97,11 +93,10 @@ impl FromSql for f32 { } impl FromSql for f64 { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { use crate::mysql::NumericRepresentation::*; - let data = not_none!(value); - match data.numeric_value()? { + match value.numeric_value()? { Tiny(x) => Ok(x.into()), Small(x) => Ok(x.into()), Medium(x) => Ok(x.into()), @@ -114,15 +109,13 @@ impl FromSql for f64 { } impl FromSql for String { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { String::from_utf8(value.as_bytes().into()).map_err(Into::into) } } impl FromSql for Vec { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: MysqlValue<'_>) -> deserialize::Result { Ok(value.as_bytes().into()) } } diff --git a/diesel/src/mysql/value.rs b/diesel/src/mysql/value.rs index 8b33bffc7543..91b2d3e51e60 100644 --- a/diesel/src/mysql/value.rs +++ b/diesel/src/mysql/value.rs @@ -20,6 +20,11 @@ impl<'a> MysqlValue<'a> { self.raw } + /// Get the mysql type of the current value + pub fn value_type(&self) -> MysqlType { + self.tpe + } + /// Checks that the type code is valid, and interprets the data as a /// `MYSQL_TIME` pointer // We use `ptr.read_unaligned()` to read the potential unaligned ptr, diff --git a/diesel/src/pg/backend.rs b/diesel/src/pg/backend.rs index 5d78758a561e..a374fae4d738 100644 --- a/diesel/src/pg/backend.rs +++ b/diesel/src/pg/backend.rs @@ -7,7 +7,7 @@ use super::{PgMetadataLookup, PgValue}; use crate::backend::*; use crate::deserialize::Queryable; use crate::query_builder::bind_collector::RawBytesBindCollector; -use crate::sql_types::{Oid, TypeMetadata}; +use crate::sql_types::TypeMetadata; /// The PostgreSQL backend #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] @@ -16,7 +16,7 @@ pub struct Pg; /// The [OIDs] for a SQL type /// /// [OIDs]: https://www.postgresql.org/docs/current/static/datatype-oid.html -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Default)] +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Default, Queryable)] pub struct PgTypeMetadata { /// The [OID] of `T` /// @@ -28,14 +28,6 @@ pub struct PgTypeMetadata { pub array_oid: u32, } -impl Queryable<(Oid, Oid), Pg> for PgTypeMetadata { - type Row = (u32, u32); - - fn build((oid, array_oid): Self::Row) -> Self { - PgTypeMetadata { oid, array_oid } - } -} - impl Backend for Pg { type QueryBuilder = PgQueryBuilder; type BindCollector = RawBytesBindCollector; diff --git a/diesel/src/pg/connection/cursor.rs b/diesel/src/pg/connection/cursor.rs index 698de6899b52..b345ec27a8e2 100644 --- a/diesel/src/pg/connection/cursor.rs +++ b/diesel/src/pg/connection/cursor.rs @@ -1,77 +1,48 @@ use super::result::PgResult; -use super::row::PgNamedRow; -use crate::deserialize::{FromSqlRow, Queryable, QueryableByName}; -use crate::pg::Pg; -use crate::result::Error::DeserializationError; -use crate::result::QueryResult; - -use std::marker::PhantomData; +use super::row::PgRow; /// The type returned by various [`Connection`](struct.Connection.html) methods. /// Acts as an iterator over `T`. -pub struct Cursor { +pub struct Cursor<'a> { current_row: usize, - db_result: PgResult, - _marker: PhantomData<(ST, T)>, + db_result: &'a PgResult, } -impl Cursor { - #[doc(hidden)] - pub fn new(db_result: PgResult) -> Self { +impl<'a> Cursor<'a> { + pub(super) fn new(db_result: &'a PgResult) -> Self { Cursor { current_row: 0, db_result, - _marker: PhantomData, } } } -impl Iterator for Cursor -where - T: Queryable, -{ - type Item = QueryResult; - - fn next(&mut self) -> Option { - if self.current_row >= self.db_result.num_rows() { - None - } else { - let mut row = self.db_result.get_row(self.current_row); - self.current_row += 1; - let value = T::Row::build_from_row(&mut row) - .map(T::build) - .map_err(DeserializationError); - Some(value) - } +impl<'a> ExactSizeIterator for Cursor<'a> { + fn len(&self) -> usize { + self.db_result.num_rows() - self.current_row } } -pub struct NamedCursor { - pub(crate) db_result: PgResult, -} +impl<'a> Iterator for Cursor<'a> { + type Item = PgRow<'a>; -impl NamedCursor { - pub fn new(db_result: PgResult) -> Self { - NamedCursor { db_result } - } - - pub fn collect(self) -> QueryResult> - where - T: QueryableByName, - { - (0..self.db_result.num_rows()) - .map(|i| { - let row = PgNamedRow::new(&self, i); - T::build(&row).map_err(DeserializationError) - }) - .collect() + fn next(&mut self) -> Option { + if self.current_row < self.db_result.num_rows() { + let row = self.db_result.get_row(self.current_row); + self.current_row += 1; + Some(row) + } else { + None + } } - pub fn index_of_column(&self, column_name: &str) -> Option { - self.db_result.field_number(column_name) + fn nth(&mut self, n: usize) -> Option { + self.current_row = (self.current_row + n).min(self.db_result.num_rows()); + self.next() } - pub fn get_value(&self, row: usize, column: usize) -> Option<&[u8]> { - self.db_result.get(row, column) + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) } } diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index 121180bf6248..c7788f21a7cf 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -13,13 +13,14 @@ use self::raw::RawConnection; use self::result::PgResult; use self::stmt::Statement; use crate::connection::*; -use crate::deserialize::{Queryable, QueryableByName}; +use crate::deserialize::FromSqlRow; +use crate::expression::QueryMetadata; use crate::pg::{metadata_lookup::PgMetadataCache, Pg, PgMetadataLookup, TransactionBuilder}; use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::query_builder::*; use crate::result::ConnectionError::CouldntSetupConfiguration; +use crate::result::Error::DeserializationError; use crate::result::*; -use crate::sql_types::HasSqlType; /// The connection string expected by `PgConnection::establish` /// should be a PostgreSQL connection string, as documented at @@ -67,29 +68,20 @@ impl Connection for PgConnection { } #[doc(hidden)] - fn query_by_index(&self, source: T) -> QueryResult> + fn load(&self, source: T) -> QueryResult> where T: AsQuery, - T::Query: QueryFragment + QueryId, - Pg: HasSqlType, - U: Queryable, + T::Query: QueryFragment + QueryId, + U: FromSqlRow, + Self::Backend: QueryMetadata, { let (query, params) = self.prepare_query(&source.as_query())?; - query - .execute(self, ¶ms) - .and_then(|r| Cursor::new(r).collect()) - } + let result = query.execute(self, ¶ms)?; + let cursor = Cursor::new(&result); - #[doc(hidden)] - fn query_by_name(&self, source: &T) -> QueryResult> - where - T: QueryFragment + QueryId, - U: QueryableByName, - { - let (query, params) = self.prepare_query(source)?; - query - .execute(self, ¶ms) - .and_then(|r| NamedCursor::new(r).collect()) + cursor + .map(|row| U::build_from_row(&row).map_err(DeserializationError)) + .collect::>>() } #[doc(hidden)] diff --git a/diesel/src/pg/connection/result.rs b/diesel/src/pg/connection/result.rs index 9a4eda32a814..782e1f315967 100644 --- a/diesel/src/pg/connection/result.rs +++ b/diesel/src/pg/connection/result.rs @@ -1,7 +1,7 @@ extern crate pq_sys; use self::pq_sys::*; -use std::ffi::{CStr, CString}; +use std::ffi::CStr; use std::num::NonZeroU32; use std::os::raw as libc; use std::{slice, str}; @@ -12,6 +12,8 @@ use crate::result::{DatabaseErrorInformation, DatabaseErrorKind, Error, QueryRes pub struct PgResult { internal_result: RawResult, + column_count: usize, + row_count: usize, } impl PgResult { @@ -21,7 +23,15 @@ impl PgResult { let result_status = unsafe { PQresultStatus(internal_result.as_ptr()) }; match result_status { - PGRES_COMMAND_OK | PGRES_TUPLES_OK => Ok(PgResult { internal_result }), + PGRES_COMMAND_OK | PGRES_TUPLES_OK => { + let column_count = unsafe { PQnfields(internal_result.as_ptr()) as usize }; + let row_count = unsafe { PQntuples(internal_result.as_ptr()) as usize }; + Ok(PgResult { + internal_result, + column_count, + row_count, + }) + } PGRES_EMPTY_QUERY => { let error_message = "Received an empty query".to_string(); Err(Error::DatabaseError( @@ -71,7 +81,7 @@ impl PgResult { } pub fn num_rows(&self) -> usize { - unsafe { PQntuples(self.internal_result.as_ptr()) as usize } + self.row_count } pub fn get_row(&self, idx: usize) -> PgRow { @@ -104,22 +114,29 @@ impl PgResult { } pub fn column_type(&self, col_idx: usize) -> NonZeroU32 { + let type_oid = unsafe { PQftype(self.internal_result.as_ptr(), col_idx as libc::c_int) }; + NonZeroU32::new(type_oid).expect( + "Got a zero oid from postgres. If you see this error message \ + please report it as issue on the diesel github bug tracker.", + ) + } + + pub fn column_name(&self, col_idx: usize) -> Option<&str> { unsafe { - NonZeroU32::new(PQftype( - self.internal_result.as_ptr(), - col_idx as libc::c_int, - )) - .expect("Oid's aren't zero") + let ptr = PQfname(self.internal_result.as_ptr(), col_idx as libc::c_int); + if ptr.is_null() { + None + } else { + Some(CStr::from_ptr(ptr).to_str().expect( + "Expect postgres field names to be UTF-8, because we \ + requested UTF-8 encoding on connection setup", + )) + } } } - pub fn field_number(&self, column_name: &str) -> Option { - let cstr = CString::new(column_name).unwrap_or_default(); - let fnum = unsafe { PQfnumber(self.internal_result.as_ptr(), cstr.as_ptr()) }; - match fnum { - -1 => None, - x => Some(x as usize), - } + pub fn column_count(&self) -> usize { + self.column_count } } diff --git a/diesel/src/pg/connection/row.rs b/diesel/src/pg/connection/row.rs index 1ebf50cfb51c..421fd4bf7c53 100644 --- a/diesel/src/pg/connection/row.rs +++ b/diesel/src/pg/connection/row.rs @@ -1,56 +1,75 @@ -use super::cursor::NamedCursor; use super::result::PgResult; use crate::pg::{Pg, PgValue}; use crate::row::*; +#[derive(Clone)] pub struct PgRow<'a> { db_result: &'a PgResult, row_idx: usize, - col_idx: usize, } impl<'a> PgRow<'a> { pub fn new(db_result: &'a PgResult, row_idx: usize) -> Self { - PgRow { - db_result, - row_idx, - col_idx: 0, - } + PgRow { row_idx, db_result } } } -impl<'a> Row for PgRow<'a> { - fn take(&mut self) -> Option> { - let current_idx = self.col_idx; - self.col_idx += 1; - let raw = self.db_result.get(self.row_idx, current_idx)?; +impl<'a> Row<'a, Pg> for PgRow<'a> { + type Field = PgField<'a>; + type InnerPartialRow = Self; + + fn field_count(&self) -> usize { + self.db_result.column_count() + } - Some(PgValue::new(raw, self.db_result.column_type(current_idx))) + fn get(&self, idx: I) -> Option + where + Self: RowIndex, + { + let idx = self.idx(idx)?; + Some(PgField { + db_result: self.db_result, + row_idx: self.row_idx, + col_idx: idx, + }) } - fn next_is_null(&self, count: usize) -> bool { - (0..count).all(|i| self.db_result.is_null(self.row_idx, self.col_idx + i)) + fn partial_row(&self, range: std::ops::Range) -> PartialRow { + PartialRow::new(self, range) } } -pub struct PgNamedRow<'a> { - cursor: &'a NamedCursor, - idx: usize, +impl<'a> RowIndex for PgRow<'a> { + fn idx(&self, idx: usize) -> Option { + if idx < self.field_count() { + Some(idx) + } else { + None + } + } } -impl<'a> PgNamedRow<'a> { - pub fn new(cursor: &'a NamedCursor, idx: usize) -> Self { - PgNamedRow { cursor, idx } +impl<'a, 'b> RowIndex<&'a str> for PgRow<'b> { + fn idx(&self, field_name: &'a str) -> Option { + (0..self.field_count()).find(|idx| self.db_result.column_name(*idx) == Some(field_name)) } } -impl<'a> NamedRow for PgNamedRow<'a> { - fn get_raw_value(&self, index: usize) -> Option> { - let raw = self.cursor.get_value(self.idx, index)?; - Some(PgValue::new(raw, self.cursor.db_result.column_type(index))) +pub struct PgField<'a> { + db_result: &'a PgResult, + row_idx: usize, + col_idx: usize, +} + +impl<'a> Field<'a, Pg> for PgField<'a> { + fn field_name(&self) -> Option<&'a str> { + self.db_result.column_name(self.col_idx) } - fn index_of(&self, column_name: &str) -> Option { - self.cursor.index_of_column(column_name) + fn value(&self) -> Option> { + let raw = self.db_result.get(self.row_idx, self.col_idx)?; + let type_oid = self.db_result.column_type(self.col_idx); + + Some(PgValue::new(raw, type_oid)) } } diff --git a/diesel/src/pg/expression/array_comparison.rs b/diesel/src/pg/expression/array_comparison.rs index db39198b3911..0dd1fc4cb047 100644 --- a/diesel/src/pg/expression/array_comparison.rs +++ b/diesel/src/pg/expression/array_comparison.rs @@ -1,9 +1,9 @@ use crate::expression::subselect::Subselect; -use crate::expression::{AsExpression, Expression, ValidGrouping}; +use crate::expression::{AsExpression, Expression, TypedExpressionType, ValidGrouping}; use crate::pg::Pg; use crate::query_builder::*; use crate::result::QueryResult; -use crate::sql_types::Array; +use crate::sql_types::{Array, SqlType}; /// Creates a PostgreSQL `ANY` expression. /// @@ -75,6 +75,7 @@ impl Any { impl Expression for Any where Expr: Expression>, + ST: SqlType + TypedExpressionType, { type SqlType = ST; } @@ -108,6 +109,7 @@ impl All { impl Expression for All where Expr: Expression>, + ST: SqlType + TypedExpressionType, { type SqlType = ST; } diff --git a/diesel/src/pg/expression/date_and_time.rs b/diesel/src/pg/expression/date_and_time.rs index f0cfff226b48..94ef7362ae46 100644 --- a/diesel/src/pg/expression/date_and_time.rs +++ b/diesel/src/pg/expression/date_and_time.rs @@ -2,15 +2,14 @@ use crate::expression::{Expression, ValidGrouping}; use crate::pg::Pg; use crate::query_builder::*; use crate::result::QueryResult; -use crate::sql_types::{Date, NotNull, Nullable, Timestamp, Timestamptz, VarChar}; +use crate::sql_types::{is_nullable, Date, Nullable, SqlType, Timestamp, Timestamptz, VarChar}; /// Marker trait for types which are valid in `AT TIME ZONE` expressions pub trait DateTimeLike {} impl DateTimeLike for Date {} impl DateTimeLike for Timestamp {} impl DateTimeLike for Timestamptz {} -impl DateTimeLike for Nullable {} - +impl DateTimeLike for Nullable where T: SqlType + DateTimeLike {} #[derive(Debug, Copy, Clone, QueryId, ValidGrouping)] pub struct AtTimeZone { timestamp: Ts, diff --git a/diesel/src/pg/expression/expression_methods.rs b/diesel/src/pg/expression/expression_methods.rs index 8b8931880bd1..9abf426f3717 100644 --- a/diesel/src/pg/expression/expression_methods.rs +++ b/diesel/src/pg/expression/expression_methods.rs @@ -1,8 +1,8 @@ //! PostgreSQL specific expression methods use super::operators::*; -use crate::expression::{AsExpression, Expression}; -use crate::sql_types::{Array, Nullable, Range, Text}; +use crate::expression::{AsExpression, Expression, TypedExpressionType}; +use crate::sql_types::{Array, Nullable, Range, SqlType, Text}; /// PostgreSQL specific methods which are present on all expressions. pub trait PgExpressionMethods: Expression + Sized { @@ -27,6 +27,7 @@ pub trait PgExpressionMethods: Expression + Sized { /// ``` fn is_not_distinct_from(self, other: T) -> IsNotDistinctFrom where + Self::SqlType: SqlType, T: AsExpression, { IsNotDistinctFrom::new(self, other.as_expression()) @@ -53,6 +54,7 @@ pub trait PgExpressionMethods: Expression + Sized { /// ``` fn is_distinct_from(self, other: T) -> IsDistinctFrom where + Self::SqlType: SqlType, T: AsExpression, { IsDistinctFrom::new(self, other.as_expression()) @@ -187,6 +189,7 @@ pub trait PgArrayExpressionMethods: Expression + Sized { /// ``` fn overlaps_with(self, other: T) -> OverlapsWith where + Self::SqlType: SqlType, T: AsExpression, { OverlapsWith::new(self, other.as_expression()) @@ -236,6 +239,7 @@ pub trait PgArrayExpressionMethods: Expression + Sized { /// ``` fn contains(self, other: T) -> Contains where + Self::SqlType: SqlType, T: AsExpression, { Contains::new(self, other.as_expression()) @@ -286,6 +290,7 @@ pub trait PgArrayExpressionMethods: Expression + Sized { /// ``` fn is_contained_by(self, other: T) -> IsContainedBy where + Self::SqlType: SqlType, T: AsExpression, { IsContainedBy::new(self, other.as_expression()) @@ -309,6 +314,7 @@ impl ArrayOrNullableArray for Array {} impl ArrayOrNullableArray for Nullable> {} use crate::expression::operators::{Asc, Desc}; +use crate::EscapeExpressionMethods; /// PostgreSQL expression methods related to sorting. /// @@ -440,8 +446,11 @@ pub trait PgTextExpressionMethods: Expression + Sized { /// # Ok(()) /// # } /// ``` - fn ilike>(self, other: T) -> ILike { - ILike::new(self.as_expression(), other.as_expression()) + fn ilike(self, other: T) -> ILike + where + T: AsExpression, + { + ILike::new(self, other.as_expression()) } /// Creates a PostgreSQL `NOT ILIKE` expression @@ -466,8 +475,11 @@ pub trait PgTextExpressionMethods: Expression + Sized { /// # Ok(()) /// # } /// ``` - fn not_ilike>(self, other: T) -> NotILike { - NotILike::new(self.as_expression(), other.as_expression()) + fn not_ilike(self, other: T) -> NotILike + where + T: AsExpression, + { + NotILike::new(self, other.as_expression()) } } @@ -487,10 +499,13 @@ where { } +impl EscapeExpressionMethods for ILike {} +impl EscapeExpressionMethods for NotILike {} + #[doc(hidden)] /// Marker trait used to extract the inner type /// of our `Range` sql type, used to implement `PgRangeExpressionMethods` -pub trait RangeHelper { +pub trait RangeHelper: SqlType { type Inner; } @@ -547,10 +562,25 @@ pub trait PgRangeExpressionMethods: Expression + Sized { fn contains(self, other: T) -> Contains where Self::SqlType: RangeHelper, + ::Inner: SqlType + TypedExpressionType, T: AsExpression<::Inner>, { Contains::new(self, other.as_expression()) } } -impl PgRangeExpressionMethods for T where T: Expression> {} +#[doc(hidden)] +/// Marker trait used to implement `PgRangeExpressionMethods` on the appropriate +/// types. Once coherence takes associated types into account, we can remove +/// this trait. +pub trait RangeOrNullableRange {} + +impl RangeOrNullableRange for Range {} +impl RangeOrNullableRange for Nullable> {} + +impl PgRangeExpressionMethods for T +where + T: Expression, + T::SqlType: RangeOrNullableRange, +{ +} diff --git a/diesel/src/pg/expression/operators.rs b/diesel/src/pg/expression/operators.rs index 7c9f93d7c174..611e3f462337 100644 --- a/diesel/src/pg/expression/operators.rs +++ b/diesel/src/pg/expression/operators.rs @@ -1,3 +1,4 @@ +use crate::expression::expression_types::NotSelectable; use crate::pg::Pg; infix_operator!(IsDistinctFrom, " IS DISTINCT FROM ", backend: Pg); @@ -7,5 +8,5 @@ infix_operator!(Contains, " @> ", backend: Pg); infix_operator!(IsContainedBy, " <@ ", backend: Pg); infix_operator!(ILike, " ILIKE ", backend: Pg); infix_operator!(NotILike, " NOT ILIKE ", backend: Pg); -postfix_operator!(NullsFirst, " NULLS FIRST", (), backend: Pg); -postfix_operator!(NullsLast, " NULLS LAST", (), backend: Pg); +postfix_operator!(NullsFirst, " NULLS FIRST", NotSelectable, backend: Pg); +postfix_operator!(NullsLast, " NULLS LAST", NotSelectable, backend: Pg); diff --git a/diesel/src/pg/query_builder/mod.rs b/diesel/src/pg/query_builder/mod.rs index 3e8ad828fff9..981c2745293b 100644 --- a/diesel/src/pg/query_builder/mod.rs +++ b/diesel/src/pg/query_builder/mod.rs @@ -37,8 +37,8 @@ impl QueryBuilder for PgQueryBuilder { fn push_bind_param(&mut self) { self.bind_idx += 1; - let sql = format!("${}", self.bind_idx); - self.push_sql(&sql); + self.sql += "$"; + itoa::fmt(&mut self.sql, self.bind_idx).expect("int formating does not fail"); } fn finish(self) -> String { diff --git a/diesel/src/pg/types/array.rs b/diesel/src/pg/types/array.rs index cddc6801ce1f..7d799b325ba0 100644 --- a/diesel/src/pg/types/array.rs +++ b/diesel/src/pg/types/array.rs @@ -23,8 +23,7 @@ impl FromSql, Pg> for Vec where T: FromSql, { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { let mut bytes = value.as_bytes(); let num_dimensions = bytes.read_i32::()?; let has_null = bytes.read_i32::()? != 0; @@ -45,11 +44,11 @@ where .map(|_| { let elem_size = bytes.read_i32::()?; if has_null && elem_size == -1 { - T::from_sql(None) + T::from_nullable_sql(None) } else { let (elem_bytes, new_bytes) = bytes.split_at(elem_size as usize); bytes = new_bytes; - T::from_sql(Some(PgValue::new(elem_bytes, value.get_oid()))) + T::from_sql(PgValue::new(elem_bytes, value.get_oid())) } }) .collect() diff --git a/diesel/src/pg/types/date_and_time/chrono.rs b/diesel/src/pg/types/date_and_time/chrono.rs index 686cb11cb8fb..7d4d4c061179 100644 --- a/diesel/src/pg/types/date_and_time/chrono.rs +++ b/diesel/src/pg/types/date_and_time/chrono.rs @@ -19,7 +19,7 @@ fn pg_epoch() -> NaiveDateTime { } impl FromSql for NaiveDateTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let PgTimestamp(offset) = FromSql::::from_sql(bytes)?; match pg_epoch().checked_add_signed(Duration::microseconds(offset)) { Some(v) => Ok(v), @@ -46,7 +46,7 @@ impl ToSql for NaiveDateTime { } impl FromSql for NaiveDateTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { FromSql::::from_sql(bytes) } } @@ -58,14 +58,14 @@ impl ToSql for NaiveDateTime { } impl FromSql for DateTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let naive_date_time = >::from_sql(bytes)?; Ok(DateTime::from_utc(naive_date_time, Utc)) } } impl FromSql for DateTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let naive_date_time = >::from_sql(bytes)?; Ok(Local::from_utc_datetime(&Local, &naive_date_time)) } @@ -92,7 +92,7 @@ impl ToSql for NaiveTime { } impl FromSql for NaiveTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let PgTime(offset) = FromSql::::from_sql(bytes)?; let duration = Duration::microseconds(offset); Ok(midnight() + duration) @@ -111,7 +111,7 @@ impl ToSql for NaiveDate { } impl FromSql for NaiveDate { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let PgDate(offset) = FromSql::::from_sql(bytes)?; match pg_epoch_date().checked_add_signed(Duration::days(i64::from(offset))) { Some(date) => Ok(date), diff --git a/diesel/src/pg/types/date_and_time/deprecated_time.rs b/diesel/src/pg/types/date_and_time/deprecated_time.rs index f6e372df3cfb..95ba1d70e4cb 100644 --- a/diesel/src/pg/types/date_and_time/deprecated_time.rs +++ b/diesel/src/pg/types/date_and_time/deprecated_time.rs @@ -10,7 +10,7 @@ use crate::pg::{Pg, PgValue}; use crate::serialize::{self, Output, ToSql}; use crate::sql_types; -#[derive(FromSqlRow, AsExpression)] +#[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "sql_types::Timestamp"] #[allow(dead_code)] @@ -28,7 +28,7 @@ impl ToSql for Timespec { } impl FromSql for Timespec { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let t = >::from_sql(bytes)?; let pg_epoch = Timespec::new(TIME_SEC_CONV, 0); let duration = Duration::microseconds(t); diff --git a/diesel/src/pg/types/date_and_time/mod.rs b/diesel/src/pg/types/date_and_time/mod.rs index 360f913c0c79..f67dda0de6d2 100644 --- a/diesel/src/pg/types/date_and_time/mod.rs +++ b/diesel/src/pg/types/date_and_time/mod.rs @@ -15,7 +15,7 @@ mod deprecated_time; mod quickcheck_impls; mod std_time; -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, FromSqlRow, AsExpression)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, AsExpression, FromSqlRow)] #[sql_type = "Timestamp"] #[sql_type = "Timestamptz"] /// Timestamps are represented in Postgres as a 64 bit signed integer representing the number of @@ -23,7 +23,7 @@ mod std_time; /// the integer's meaning. pub struct PgTimestamp(pub i64); -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, FromSqlRow, AsExpression)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, AsExpression, FromSqlRow)] #[sql_type = "Date"] /// Dates are represented in Postgres as a 32 bit signed integer representing the number of julian /// days since January 1st 2000. This struct is a dumb wrapper type, meant only to indicate the @@ -33,7 +33,7 @@ pub struct PgDate(pub i32); /// Time is represented in Postgres as a 64 bit signed integer representing the number of /// microseconds since midnight. This struct is a dumb wrapper type, meant only to indicate the /// integer's meaning. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, FromSqlRow, AsExpression)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, AsExpression, FromSqlRow)] #[sql_type = "Time"] pub struct PgTime(pub i64); @@ -41,7 +41,7 @@ pub struct PgTime(pub i64); /// microseconds, a 32 bit integer representing number of days, and a 32 bit integer /// representing number of months. This struct is a dumb wrapper type, meant only to indicate the /// meaning of these parts. -#[derive(Debug, Clone, Copy, PartialEq, Eq, FromSqlRow, AsExpression)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, AsExpression, FromSqlRow)] #[sql_type = "Interval"] pub struct PgInterval { /// The number of whole microseconds @@ -90,7 +90,7 @@ impl ToSql for PgTimestamp { } impl FromSql for PgTimestamp { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { FromSql::::from_sql(bytes).map(PgTimestamp) } } @@ -102,7 +102,7 @@ impl ToSql for PgTimestamp { } impl FromSql for PgTimestamp { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { FromSql::::from_sql(bytes) } } @@ -114,7 +114,7 @@ impl ToSql for PgDate { } impl FromSql for PgDate { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { FromSql::::from_sql(bytes).map(PgDate) } } @@ -126,7 +126,7 @@ impl ToSql for PgTime { } impl FromSql for PgTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { FromSql::::from_sql(bytes).map(PgTime) } } @@ -141,12 +141,11 @@ impl ToSql for PgInterval { } impl FromSql for PgInterval { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { Ok(PgInterval { - microseconds: FromSql::::from_sql(Some(value.subslice(0..8)))?, - days: FromSql::::from_sql(Some(value.subslice(8..12)))?, - months: FromSql::::from_sql(Some(value.subslice(12..16)))?, + microseconds: FromSql::::from_sql(value.subslice(0..8))?, + days: FromSql::::from_sql(value.subslice(8..12))?, + months: FromSql::::from_sql(value.subslice(12..16))?, }) } } diff --git a/diesel/src/pg/types/date_and_time/std_time.rs b/diesel/src/pg/types/date_and_time/std_time.rs index 1bf879bd83d2..8789aa4773e3 100644 --- a/diesel/src/pg/types/date_and_time/std_time.rs +++ b/diesel/src/pg/types/date_and_time/std_time.rs @@ -27,7 +27,7 @@ impl ToSql for SystemTime { } impl FromSql for SystemTime { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let usecs_passed = >::from_sql(bytes)?; let before_epoch = usecs_passed < 0; let time_passed = usecs_to_duration(usecs_passed.abs() as u64); diff --git a/diesel/src/pg/types/floats/mod.rs b/diesel/src/pg/types/floats/mod.rs index 21915594c900..aa1387946ce3 100644 --- a/diesel/src/pg/types/floats/mod.rs +++ b/diesel/src/pg/types/floats/mod.rs @@ -11,7 +11,7 @@ use crate::sql_types; #[cfg(feature = "quickcheck")] mod quickcheck_impls; -#[derive(Debug, Clone, PartialEq, Eq, FromSqlRow, AsExpression)] +#[derive(Debug, Clone, PartialEq, Eq, AsExpression, FromSqlRow)] #[sql_type = "sql_types::Numeric"] /// Represents a NUMERIC value, closely mirroring the PG wire protocol /// representation @@ -50,8 +50,7 @@ impl ::std::fmt::Display for InvalidNumericSign { impl Error for InvalidNumericSign {} impl FromSql for PgNumeric { - fn from_sql(bytes: Option>) -> deserialize::Result { - let bytes = not_none!(bytes); + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let mut bytes = bytes.as_bytes(); let digit_count = bytes.read_u16::()?; let mut digits = Vec::with_capacity(digit_count as usize); diff --git a/diesel/src/pg/types/integers.rs b/diesel/src/pg/types/integers.rs index 3e74e0473659..bc59ae8f51f4 100644 --- a/diesel/src/pg/types/integers.rs +++ b/diesel/src/pg/types/integers.rs @@ -7,8 +7,7 @@ use crate::serialize::{self, IsNull, Output, ToSql}; use crate::sql_types; impl FromSql for u32 { - fn from_sql(bytes: Option>) -> deserialize::Result { - let bytes = not_none!(bytes); + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { let mut bytes = bytes.as_bytes(); bytes.read_u32::().map_err(Into::into) } diff --git a/diesel/src/pg/types/json.rs b/diesel/src/pg/types/json.rs index 4b9c41648756..e483501cb295 100644 --- a/diesel/src/pg/types/json.rs +++ b/diesel/src/pg/types/json.rs @@ -10,8 +10,7 @@ use crate::serialize::{self, IsNull, Output, ToSql}; use crate::sql_types; impl FromSql for serde_json::Value { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { serde_json::from_slice(value.as_bytes()).map_err(|_| "Invalid Json".into()) } } @@ -25,8 +24,7 @@ impl ToSql for serde_json::Value { } impl FromSql for serde_json::Value { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { let bytes = value.as_bytes(); if bytes[0] != 1 { return Err("Unsupported JSONB encoding version".into()); @@ -56,20 +54,21 @@ fn json_to_sql() { fn some_json_from_sql() { let input_json = b"true"; let output_json: serde_json::Value = - FromSql::::from_sql(Some(PgValue::for_test(input_json))).unwrap(); + FromSql::::from_sql(PgValue::for_test(input_json)).unwrap(); assert_eq!(output_json, serde_json::Value::Bool(true)); } #[test] fn bad_json_from_sql() { let uuid: Result = - FromSql::::from_sql(Some(PgValue::for_test(b"boom"))); + FromSql::::from_sql(PgValue::for_test(b"boom")); assert_eq!(uuid.unwrap_err().to_string(), "Invalid Json"); } #[test] fn no_json_from_sql() { - let uuid: Result = FromSql::::from_sql(None); + let uuid: Result = + FromSql::::from_nullable_sql(None); assert_eq!( uuid.unwrap_err().to_string(), "Unexpected null for non-null column" @@ -88,21 +87,21 @@ fn jsonb_to_sql() { fn some_jsonb_from_sql() { let input_json = b"\x01true"; let output_json: serde_json::Value = - FromSql::::from_sql(Some(PgValue::for_test(input_json))).unwrap(); + FromSql::::from_sql(PgValue::for_test(input_json)).unwrap(); assert_eq!(output_json, serde_json::Value::Bool(true)); } #[test] fn bad_jsonb_from_sql() { let uuid: Result = - FromSql::::from_sql(Some(PgValue::for_test(b"\x01boom"))); + FromSql::::from_sql(PgValue::for_test(b"\x01boom")); assert_eq!(uuid.unwrap_err().to_string(), "Invalid Json"); } #[test] fn bad_jsonb_version_from_sql() { let uuid: Result = - FromSql::::from_sql(Some(PgValue::for_test(b"\x02true"))); + FromSql::::from_sql(PgValue::for_test(b"\x02true")); assert_eq!( uuid.unwrap_err().to_string(), "Unsupported JSONB encoding version" @@ -111,7 +110,8 @@ fn bad_jsonb_version_from_sql() { #[test] fn no_jsonb_from_sql() { - let uuid: Result = FromSql::::from_sql(None); + let uuid: Result = + FromSql::::from_nullable_sql(None); assert_eq!( uuid.unwrap_err().to_string(), "Unexpected null for non-null column" diff --git a/diesel/src/pg/types/mac_addr.rs b/diesel/src/pg/types/mac_addr.rs index a5a97751066c..d7230f99878b 100644 --- a/diesel/src/pg/types/mac_addr.rs +++ b/diesel/src/pg/types/mac_addr.rs @@ -12,15 +12,14 @@ mod foreign_derives { use crate::deserialize::FromSqlRow; use crate::expression::AsExpression; - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "MacAddr"] struct ByteArrayProxy([u8; 6]); } impl FromSql for [u8; 6] { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { value .as_bytes() .try_into() @@ -41,7 +40,6 @@ fn macaddr_roundtrip() { let mut bytes = Output::test(); let input_address = [0x52, 0x54, 0x00, 0xfb, 0xc6, 0x16]; ToSql::::to_sql(&input_address, &mut bytes).unwrap(); - let output_address: [u8; 6] = - FromSql::from_sql(Some(PgValue::for_test(bytes.as_ref()))).unwrap(); + let output_address: [u8; 6] = FromSql::from_sql(PgValue::for_test(bytes.as_ref())).unwrap(); assert_eq!(input_address, output_address); } diff --git a/diesel/src/pg/types/money.rs b/diesel/src/pg/types/money.rs index 8939ce1fda2d..b6f3ebb7d6d0 100644 --- a/diesel/src/pg/types/money.rs +++ b/diesel/src/pg/types/money.rs @@ -20,12 +20,12 @@ use crate::sql_types::{BigInt, Money}; /// use diesel::data_types::PgMoney as Pence; // 1/100th unit of Pound /// use diesel::data_types::PgMoney as Fils; // 1/1000th unit of Dinar /// ``` -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, FromSqlRow, AsExpression)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, AsExpression, FromSqlRow)] #[sql_type = "Money"] pub struct PgMoney(pub i64); impl FromSql for PgMoney { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { FromSql::::from_sql(bytes).map(PgMoney) } } diff --git a/diesel/src/pg/types/network_address.rs b/diesel/src/pg/types/network_address.rs index 563f1852fc51..575622c3641c 100644 --- a/diesel/src/pg/types/network_address.rs +++ b/diesel/src/pg/types/network_address.rs @@ -5,7 +5,7 @@ use self::ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network}; use std::io::prelude::*; use std::net::{Ipv4Addr, Ipv6Addr}; -use crate::deserialize::{self, FromSql}; +use crate::deserialize::{self, FromSql, FromSqlRow}; use crate::pg::{Pg, PgValue}; use crate::serialize::{self, IsNull, Output, ToSql}; use crate::sql_types::{Cidr, Inet}; @@ -24,10 +24,9 @@ const PGSQL_AF_INET6: u8 = AF_INET + 1; #[allow(dead_code)] mod foreign_derives { use super::*; - use crate::deserialize::FromSqlRow; use crate::expression::AsExpression; - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "Inet"] #[sql_type = "Cidr"] @@ -60,9 +59,8 @@ macro_rules! assert_or_error { macro_rules! impl_Sql { ($ty: ty, $net_type: expr) => { impl FromSql<$ty, Pg> for IpNetwork { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: PgValue<'_>) -> deserialize::Result { // https://github.com/postgres/postgres/blob/55c3391d1e6a201b5b891781d21fe682a8c64fe6/src/include/utils/inet.h#L23-L28 - let value = not_none!(value); let bytes = value.as_bytes(); assert_or_error!(4 <= bytes.len(), "input is too short."); let af = bytes[0]; @@ -161,7 +159,7 @@ fn some_v4address_from_sql() { let mut bytes = Output::test(); ToSql::<$ty, Pg>::to_sql(&input_address, &mut bytes).unwrap(); let output_address = - FromSql::<$ty, Pg>::from_sql(Some(PgValue::for_test(bytes.as_ref()))).unwrap(); + FromSql::<$ty, Pg>::from_sql(PgValue::for_test(bytes.as_ref())).unwrap(); assert_eq!(input_address, output_address); }; } @@ -219,7 +217,7 @@ fn some_v6address_from_sql() { let mut bytes = Output::test(); ToSql::<$ty, Pg>::to_sql(&input_address, &mut bytes).unwrap(); let output_address = - FromSql::<$ty, Pg>::from_sql(Some(PgValue::for_test(bytes.as_ref()))).unwrap(); + FromSql::<$ty, Pg>::from_sql(PgValue::for_test(bytes.as_ref())).unwrap(); assert_eq!(input_address, output_address); }; } @@ -233,7 +231,7 @@ fn bad_address_from_sql() { macro_rules! bad_address_from_sql { ($ty:tt) => { let address: Result = - FromSql::<$ty, Pg>::from_sql(Some(PgValue::for_test(&[7, PGSQL_AF_INET, 0]))); + FromSql::<$ty, Pg>::from_sql(PgValue::for_test(&[7, PGSQL_AF_INET, 0])); assert_eq!( address.unwrap_err().to_string(), "invalid network address format. input is too short." @@ -249,7 +247,7 @@ fn bad_address_from_sql() { fn no_address_from_sql() { macro_rules! test_no_address_from_sql { ($ty:ty) => { - let address: Result = FromSql::<$ty, Pg>::from_sql(None); + let address: Result = FromSql::<$ty, Pg>::from_nullable_sql(None); assert_eq!( address.unwrap_err().to_string(), "Unexpected null for non-null column" diff --git a/diesel/src/pg/types/numeric.rs b/diesel/src/pg/types/numeric.rs index c965a6cdfc85..72f0292069fb 100644 --- a/diesel/src/pg/types/numeric.rs +++ b/diesel/src/pg/types/numeric.rs @@ -153,7 +153,7 @@ mod bigdecimal { } impl FromSql for BigDecimal { - fn from_sql(numeric: Option>) -> deserialize::Result { + fn from_sql(numeric: PgValue<'_>) -> deserialize::Result { PgNumeric::from_sql(numeric)?.try_into() } } diff --git a/diesel/src/pg/types/primitives.rs b/diesel/src/pg/types/primitives.rs index 79e14798802f..58dc2dc93fd1 100644 --- a/diesel/src/pg/types/primitives.rs +++ b/diesel/src/pg/types/primitives.rs @@ -6,11 +6,8 @@ use crate::serialize::{self, IsNull, Output, ToSql}; use crate::sql_types; impl FromSql for bool { - fn from_sql(bytes: Option>) -> deserialize::Result { - match bytes { - Some(bytes) => Ok(bytes.as_bytes()[0] != 0), - None => Ok(false), - } + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result { + Ok(bytes.as_bytes()[0] != 0) } } @@ -31,7 +28,10 @@ fn bool_to_sql() { } #[test] -fn bool_from_sql_treats_null_as_false() { - let result = >::from_sql(None).unwrap(); - assert!(!result); +fn no_bool_from_sql() { + let result = >::from_nullable_sql(None); + assert_eq!( + result.unwrap_err().to_string(), + "Unexpected null for non-null column" + ); } diff --git a/diesel/src/pg/types/ranges.rs b/diesel/src/pg/types/ranges.rs index 81fc127548f1..94a3e31969a8 100644 --- a/diesel/src/pg/types/ranges.rs +++ b/diesel/src/pg/types/ranges.rs @@ -2,7 +2,7 @@ use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt}; use std::collections::Bound; use std::io::Write; -use crate::deserialize::{self, FromSql, FromSqlRow, Queryable}; +use crate::deserialize::{self, FromSql, Queryable}; use crate::expression::bound::Bound as SqlBound; use crate::expression::AsExpression; use crate::pg::{Pg, PgMetadataLookup, PgTypeMetadata, PgValue}; @@ -23,16 +23,6 @@ bitflags! { } } -impl Queryable, Pg> for (Bound, Bound) -where - T: FromSql + Queryable, -{ - type Row = Self; - fn build(row: Self) -> Self { - row - } -} - impl AsExpression> for (Bound, Bound) { type Expression = SqlBound, Self>; @@ -65,21 +55,11 @@ impl<'a, ST, T> AsExpression>> for &'a (Bound, Bound) { } } -impl FromSqlRow, Pg> for (Bound, Bound) -where - (Bound, Bound): FromSql, Pg>, -{ - fn build_from_row>(row: &mut R) -> deserialize::Result { - FromSql::, Pg>::from_sql(row.take()) - } -} - impl FromSql, Pg> for (Bound, Bound) where T: FromSql, { - fn from_sql(bytes: Option>) -> deserialize::Result { - let value = not_none!(bytes); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { let mut bytes = value.as_bytes(); let flags: RangeFlags = RangeFlags::from_bits_truncate(bytes.read_u8()?); let mut lower_bound = Bound::Unbounded; @@ -89,7 +69,7 @@ where let elem_size = bytes.read_i32::()?; let (elem_bytes, new_bytes) = bytes.split_at(elem_size as usize); bytes = new_bytes; - let value = T::from_sql(Some(PgValue::new(elem_bytes, value.get_oid())))?; + let value = T::from_sql(PgValue::new(elem_bytes, value.get_oid()))?; lower_bound = if flags.contains(RangeFlags::LB_INC) { Bound::Included(value) @@ -100,7 +80,7 @@ where if !flags.contains(RangeFlags::UB_INF) { let _size = bytes.read_i32::()?; - let value = T::from_sql(Some(PgValue::new(bytes, value.get_oid())))?; + let value = T::from_sql(PgValue::new(bytes, value.get_oid()))?; upper_bound = if flags.contains(RangeFlags::UB_INC) { Bound::Included(value) @@ -113,6 +93,17 @@ where } } +impl Queryable, Pg> for (Bound, Bound) +where + T: FromSql, +{ + type Row = Self; + + fn build(row: Self) -> Self { + row + } +} + impl ToSql, Pg> for (Bound, Bound) where T: ToSql, diff --git a/diesel/src/pg/types/record.rs b/diesel/src/pg/types/record.rs index 05fd2faab451..fb45530f356b 100644 --- a/diesel/src/pg/types/record.rs +++ b/diesel/src/pg/types/record.rs @@ -2,16 +2,16 @@ use byteorder::*; use std::io::Write; use std::num::NonZeroU32; -use crate::deserialize::{self, FromSql, FromSqlRow, Queryable}; +use crate::deserialize::{self, FromSql, Queryable}; use crate::expression::{ - AppearsOnTable, AsExpression, Expression, SelectableExpression, ValidGrouping, + AppearsOnTable, AsExpression, Expression, SelectableExpression, TypedExpressionType, + ValidGrouping, }; use crate::pg::{Pg, PgValue}; use crate::query_builder::{AstPass, QueryFragment, QueryId}; use crate::result::QueryResult; -use crate::row::Row; use crate::serialize::{self, IsNull, Output, ToSql, WriteTuple}; -use crate::sql_types::{HasSqlType, Record}; +use crate::sql_types::{HasSqlType, Record, SqlType}; macro_rules! tuple_impls { ($( @@ -27,8 +27,7 @@ macro_rules! tuple_impls { // but the only other option would be to use `mem::uninitialized` // and `ptr::write`. #[allow(clippy::eval_order_dependence)] - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { let mut bytes = value.as_bytes(); let num_elements = bytes.read_i32::()?; @@ -49,14 +48,14 @@ macro_rules! tuple_impls { let num_bytes = bytes.read_i32::()?; if num_bytes == -1 { - $T::from_sql(None)? + $T::from_nullable_sql(None)? } else { let (elem_bytes, new_bytes) = bytes.split_at(num_bytes as usize); bytes = new_bytes; - $T::from_sql(Some(PgValue::new( + $T::from_sql(PgValue::new( elem_bytes, oid, - )))? + ))? } },)+); @@ -69,20 +68,8 @@ macro_rules! tuple_impls { } } - impl<$($T,)+ $($ST,)+> FromSqlRow, Pg> for ($($T,)+) - where - Self: FromSql, Pg>, - { - const FIELDS_NEEDED: usize = 1; - - fn build_from_row>(row: &mut RowT) -> deserialize::Result { - Self::from_sql(row.take()) - } - } - impl<$($T,)+ $($ST,)+> Queryable, Pg> for ($($T,)+) - where - Self: FromSqlRow, Pg>, + where Self: FromSql, Pg> { type Row = Self; @@ -93,6 +80,7 @@ macro_rules! tuple_impls { impl<$($T,)+ $($ST,)+> AsExpression> for ($($T,)+) where + $($ST: SqlType + TypedExpressionType,)+ $($T: AsExpression<$ST>,)+ PgTuple<($($T::Expression,)+)>: Expression>, { diff --git a/diesel/src/pg/types/uuid.rs b/diesel/src/pg/types/uuid.rs index 5e5e86ec6956..d2577cebde8a 100644 --- a/diesel/src/pg/types/uuid.rs +++ b/diesel/src/pg/types/uuid.rs @@ -1,21 +1,19 @@ use std::io::prelude::*; -use crate::deserialize::FromSqlRow; -use crate::deserialize::{self, FromSql}; +use crate::deserialize::{self, FromSql, FromSqlRow}; use crate::expression::AsExpression; use crate::pg::{Pg, PgValue}; use crate::serialize::{self, IsNull, Output, ToSql}; use crate::sql_types::Uuid; -#[derive(FromSqlRow, AsExpression)] +#[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "Uuid"] #[allow(dead_code)] struct UuidProxy(uuid::Uuid); impl FromSql for uuid::Uuid { - fn from_sql(bytes: Option>) -> deserialize::Result { - let value = not_none!(bytes); + fn from_sql(value: PgValue<'_>) -> deserialize::Result { uuid::Uuid::from_slice(value.as_bytes()).map_err(Into::into) } } @@ -40,13 +38,13 @@ fn uuid_to_sql() { fn some_uuid_from_sql() { let input_uuid = uuid::Uuid::from_fields(0xFFFF_FFFF, 0xFFFF, 0xFFFF, b"abcdef12").unwrap(); let output_uuid = - FromSql::::from_sql(Some(PgValue::for_test(input_uuid.as_bytes()))).unwrap(); + FromSql::::from_sql(PgValue::for_test(input_uuid.as_bytes())).unwrap(); assert_eq!(input_uuid, output_uuid); } #[test] fn bad_uuid_from_sql() { - let uuid = uuid::Uuid::from_sql(Some(PgValue::for_test(b"boom"))); + let uuid = uuid::Uuid::from_sql(PgValue::for_test(b"boom")); assert_eq!( uuid.unwrap_err().to_string(), "invalid bytes length: expected 16, found 4" @@ -55,7 +53,7 @@ fn bad_uuid_from_sql() { #[test] fn no_uuid_from_sql() { - let uuid = uuid::Uuid::from_sql(None); + let uuid = uuid::Uuid::from_nullable_sql(None); assert_eq!( uuid.unwrap_err().to_string(), "Unexpected null for non-null column" diff --git a/diesel/src/query_builder/insert_statement/insert_from_select.rs b/diesel/src/query_builder/insert_statement/insert_from_select.rs index f5e1132068f1..faafffc4dce0 100644 --- a/diesel/src/query_builder/insert_statement/insert_from_select.rs +++ b/diesel/src/query_builder/insert_statement/insert_from_select.rs @@ -46,8 +46,8 @@ where impl QueryFragment for InsertFromSelect where DB: Backend, - Columns: ColumnList + Expression, - Select: Query + QueryFragment, + Columns: ColumnList + Expression, + Select: Query + QueryFragment, { fn walk_ast(&self, mut out: AstPass) -> QueryResult<()> { out.push_sql("("); @@ -60,7 +60,7 @@ where impl UndecoratedInsertRecord for InsertFromSelect where - Columns: ColumnList + Expression, - Select: Query, + Columns: ColumnList + Expression, + Select: Query, { } diff --git a/diesel/src/query_builder/insert_statement/mod.rs b/diesel/src/query_builder/insert_statement/mod.rs index da798643105b..b252c72714ca 100644 --- a/diesel/src/query_builder/insert_statement/mod.rs +++ b/diesel/src/query_builder/insert_statement/mod.rs @@ -162,8 +162,8 @@ impl InsertStatement, Op, Ret> { columns: C2, ) -> InsertStatement, Op, Ret> where - C2: ColumnList + Expression, - U: Query, + C2: ColumnList
+ Expression, + U: Query, { InsertStatement::new( self.target, diff --git a/diesel/src/query_builder/mod.rs b/diesel/src/query_builder/mod.rs index c4e520a6e497..c6291db87b9f 100644 --- a/diesel/src/query_builder/mod.rs +++ b/diesel/src/query_builder/mod.rs @@ -26,7 +26,7 @@ pub mod nodes; pub(crate) mod offset_clause; mod order_clause; mod returning_clause; -mod select_clause; +pub(crate) mod select_clause; mod select_statement; mod sql_query; mod update_statement; @@ -42,6 +42,10 @@ pub use self::insert_statement::{ IncompleteInsertStatement, InsertStatement, UndecoratedInsertRecord, ValuesClause, }; pub use self::query_id::QueryId; +#[doc(inline)] +pub use self::select_clause::{ + IntoBoxedSelectClause, SelectClauseExpression, SelectClauseQueryFragment, +}; #[doc(hidden)] pub use self::select_statement::{BoxedSelectStatement, SelectStatement}; pub use self::sql_query::{BoxedSqlQuery, SqlQuery}; diff --git a/diesel/src/query_builder/select_clause.rs b/diesel/src/query_builder/select_clause.rs index 54bff358174d..fea6ed31857d 100644 --- a/diesel/src/query_builder/select_clause.rs +++ b/diesel/src/query_builder/select_clause.rs @@ -8,8 +8,14 @@ pub struct DefaultSelectClause; #[derive(Debug, Clone, Copy, QueryId)] pub struct SelectClause(pub T); +/// Specialised variant of `Expression` for select clause types +/// +/// The difference to the normal `Expression` trait is the query source (`QS`) +/// generic type parameter. This allows to access the query source in generic code. pub trait SelectClauseExpression { + /// The expression represented by the given select clause type Selection: SelectableExpression; + /// SQL type of the select clause type SelectClauseSqlType; } @@ -29,7 +35,18 @@ where type SelectClauseSqlType = ::SqlType; } +/// Specialised variant of `QueryFragment` for select clause types +/// +/// The difference to the normal `QueryFragment` trait is the query source (`QS`) +/// generic type parameter. pub trait SelectClauseQueryFragment { + /// Walk over this `SelectClauseQueryFragment` for all passes. + /// + /// This method is where the actual behavior of an select clause is implemented. + /// This method will contain the behavior required for all possible AST + /// passes. See [`AstPass`] for more details. + /// + /// [`AstPass`]: struct.AstPass.html fn walk_ast(&self, source: &QS, pass: AstPass) -> QueryResult<()>; } @@ -53,3 +70,41 @@ where source.default_selection().walk_ast(pass) } } + +/// An internal helper trait to convert different select clauses +/// into their boxed counter part. +/// +/// You normally don't need this trait, at least as long as you +/// don't implement your own select clause representation +pub trait IntoBoxedSelectClause<'a, DB, QS> { + /// The sql type of the select clause + type SqlType; + + /// Convert the select clause into a the boxed representation + fn into_boxed(self, source: &QS) -> Box + Send + 'a>; +} + +impl<'a, DB, T, QS> IntoBoxedSelectClause<'a, DB, QS> for SelectClause +where + T: QueryFragment + SelectableExpression + Send + 'a, + DB: Backend, +{ + type SqlType = T::SqlType; + + fn into_boxed(self, _source: &QS) -> Box + Send + 'a> { + Box::new(self.0) + } +} + +impl<'a, DB, QS> IntoBoxedSelectClause<'a, DB, QS> for DefaultSelectClause +where + QS: QuerySource, + QS::DefaultSelection: QueryFragment + Send + 'a, + DB: Backend, +{ + type SqlType = ::SqlType; + + fn into_boxed(self, source: &QS) -> Box + Send + 'a> { + Box::new(source.default_selection()) + } +} diff --git a/diesel/src/query_builder/select_statement/boxed.rs b/diesel/src/query_builder/select_statement/boxed.rs index 5216e617d848..d853b80e5cd6 100644 --- a/diesel/src/query_builder/select_statement/boxed.rs +++ b/diesel/src/query_builder/select_statement/boxed.rs @@ -19,7 +19,7 @@ use crate::query_dsl::*; use crate::query_source::joins::*; use crate::query_source::{QuerySource, Table}; use crate::result::QueryResult; -use crate::sql_types::{BigInt, Bool, NotNull, Nullable}; +use crate::sql_types::{BigInt, BoolOrNullableBool, IntoNullable}; #[allow(missing_debug_implementations)] pub struct BoxedSelectStatement<'a, ST, QS, DB> { @@ -194,7 +194,8 @@ where impl<'a, ST, QS, DB, Predicate> FilterDsl for BoxedSelectStatement<'a, ST, QS, DB> where BoxedWhereClause<'a, DB>: WhereAnd>, - Predicate: AppearsOnTable + NonAggregate, + Predicate: AppearsOnTable + NonAggregate, + Predicate::SqlType: BoolOrNullableBool, { type Output = Self; @@ -207,7 +208,8 @@ where impl<'a, ST, QS, DB, Predicate> OrFilterDsl for BoxedSelectStatement<'a, ST, QS, DB> where BoxedWhereClause<'a, DB>: WhereOr>, - Predicate: AppearsOnTable + NonAggregate, + Predicate: AppearsOnTable + NonAggregate, + Predicate::SqlType: BoolOrNullableBool, { type Output = Self; @@ -331,9 +333,9 @@ where impl<'a, ST, QS, DB> SelectNullableDsl for BoxedSelectStatement<'a, ST, QS, DB> where - ST: NotNull, + ST: IntoNullable, { - type Output = BoxedSelectStatement<'a, Nullable, QS, DB>; + type Output = BoxedSelectStatement<'a, ST::Nullable, QS, DB>; fn nullable(self) -> Self::Output { BoxedSelectStatement { diff --git a/diesel/src/query_builder/select_statement/dsl_impls.rs b/diesel/src/query_builder/select_statement/dsl_impls.rs index 61d7a09766c8..4fabed4537e8 100644 --- a/diesel/src/query_builder/select_statement/dsl_impls.rs +++ b/diesel/src/query_builder/select_statement/dsl_impls.rs @@ -24,7 +24,7 @@ use crate::query_dsl::methods::*; use crate::query_dsl::*; use crate::query_source::joins::{Join, JoinOn, JoinTo}; use crate::query_source::QuerySource; -use crate::sql_types::{BigInt, Bool}; +use crate::sql_types::{BigInt, BoolOrNullableBool}; impl InternalJoinDsl for SelectStatement @@ -94,7 +94,8 @@ where impl FilterDsl for SelectStatement where - Predicate: Expression + NonAggregate, + Predicate: Expression + NonAggregate, + Predicate::SqlType: BoolOrNullableBool, W: WhereAnd, { type Output = SelectStatement; @@ -116,7 +117,8 @@ where impl OrFilterDsl for SelectStatement where - Predicate: Expression + NonAggregate, + Predicate: Expression + NonAggregate, + Predicate::SqlType: BoolOrNullableBool, W: WhereOr, { type Output = SelectStatement; @@ -340,12 +342,11 @@ impl ModifyLockDsl } } -impl<'a, F, S, D, W, O, LOf, G, DB> BoxedDsl<'a, DB> - for SelectStatement, D, W, O, LOf, G> +impl<'a, F, S, D, W, O, LOf, G, DB> BoxedDsl<'a, DB> for SelectStatement where Self: AsQuery, DB: Backend, - S: QueryFragment + SelectableExpression + Send + 'a, + S: IntoBoxedSelectClause<'a, DB, F>, D: QueryFragment + Send + 'a, W: Into>, O: Into + Send + 'a>>>, @@ -356,34 +357,7 @@ where fn internal_into_boxed(self) -> Self::Output { BoxedSelectStatement::new( - Box::new(self.select.0), - self.from, - Box::new(self.distinct), - self.where_clause.into(), - self.order.into(), - self.limit_offset.into_boxed(), - Box::new(self.group_by), - ) - } -} - -impl<'a, F, D, W, O, LOf, G, DB> BoxedDsl<'a, DB> - for SelectStatement -where - Self: AsQuery, - DB: Backend, - F: QuerySource, - F::DefaultSelection: QueryFragment + Send + 'a, - D: QueryFragment + Send + 'a, - W: Into>, - O: Into + Send + 'a>>>, - LOf: IntoBoxedClause<'a, DB, BoxedClause = BoxedLimitOffsetClause<'a, DB>>, - G: QueryFragment + Send + 'a, -{ - type Output = BoxedSelectStatement<'a, ::SqlType, F, DB>; - fn internal_into_boxed(self) -> Self::Output { - BoxedSelectStatement::new( - Box::new(self.from.default_selection()), + self.select.into_boxed(&self.from), self.from, Box::new(self.distinct), self.where_clause.into(), diff --git a/diesel/src/query_builder/sql_query.rs b/diesel/src/query_builder/sql_query.rs index f177d4a9eae4..f838cc3c5c8b 100644 --- a/diesel/src/query_builder/sql_query.rs +++ b/diesel/src/query_builder/sql_query.rs @@ -1,13 +1,13 @@ use std::marker::PhantomData; +use super::Query; use crate::backend::Backend; use crate::connection::Connection; -use crate::deserialize::QueryableByName; use crate::query_builder::{AstPass, QueryFragment, QueryId}; -use crate::query_dsl::{LoadQuery, RunQueryDsl}; +use crate::query_dsl::RunQueryDsl; use crate::result::QueryResult; use crate::serialize::ToSql; -use crate::sql_types::HasSqlType; +use crate::sql_types::{HasSqlType, Untyped}; #[derive(Debug, Clone)] #[must_use = "Queries are only executed when calling `load`, `get_result` or similar."] @@ -116,15 +116,8 @@ impl QueryId for SqlQuery { const HAS_STATIC_QUERY_ID: bool = false; } -impl LoadQuery for SqlQuery -where - Conn: Connection, - T: QueryableByName, - Self: QueryFragment, -{ - fn internal_load(self, conn: &Conn) -> QueryResult> { - conn.query_by_name(&self) - } +impl Query for SqlQuery { + type SqlType = Untyped; } impl RunQueryDsl for SqlQuery {} @@ -182,15 +175,8 @@ where } } -impl LoadQuery for UncheckedBind -where - Conn: Connection, - T: QueryableByName, - Self: QueryFragment + QueryId, -{ - fn internal_load(self, conn: &Conn) -> QueryResult> { - conn.query_by_name(&self) - } +impl Query for UncheckedBind { + type SqlType = Untyped; } impl RunQueryDsl for UncheckedBind {} @@ -260,15 +246,11 @@ impl QueryId for BoxedSqlQuery<'_, DB, Query> { const HAS_STATIC_QUERY_ID: bool = false; } -impl LoadQuery for BoxedSqlQuery<'_, Conn::Backend, Query> +impl Query for BoxedSqlQuery<'_, DB, Q> where - Conn: Connection, - T: QueryableByName, - Self: QueryFragment + QueryId, + DB: Backend, { - fn internal_load(self, conn: &Conn) -> QueryResult> { - conn.query_by_name(&self) - } + type SqlType = Untyped; } impl RunQueryDsl for BoxedSqlQuery<'_, Conn::Backend, Query> {} diff --git a/diesel/src/query_builder/where_clause.rs b/diesel/src/query_builder/where_clause.rs index 8c47086e44f9..a16e1704bb66 100644 --- a/diesel/src/query_builder/where_clause.rs +++ b/diesel/src/query_builder/where_clause.rs @@ -1,11 +1,9 @@ use super::*; use crate::backend::Backend; -use crate::dsl::Or; -use crate::expression::operators::And; +use crate::expression::operators::{And, Or}; use crate::expression::*; -use crate::expression_methods::*; use crate::result::QueryResult; -use crate::sql_types::Bool; +use crate::sql_types::BoolOrNullableBool; /// Add `Predicate` to the current `WHERE` clause, joining with `AND` if /// applicable. @@ -39,7 +37,8 @@ impl QueryFragment for NoWhereClause { impl WhereAnd for NoWhereClause where - Predicate: Expression, + Predicate: Expression, + Predicate::SqlType: BoolOrNullableBool, { type Output = WhereClause; @@ -50,7 +49,8 @@ where impl WhereOr for NoWhereClause where - Predicate: Expression, + Predicate: Expression, + Predicate::SqlType: BoolOrNullableBool, { type Output = WhereClause; @@ -83,25 +83,29 @@ where impl WhereAnd for WhereClause where - Expr: Expression, - Predicate: Expression, + Expr: Expression, + Expr::SqlType: BoolOrNullableBool, + Predicate: Expression, + Predicate::SqlType: BoolOrNullableBool, { type Output = WhereClause>; fn and(self, predicate: Predicate) -> Self::Output { - WhereClause(self.0.and(predicate)) + WhereClause(And::new(self.0, predicate)) } } impl WhereOr for WhereClause where - Expr: Expression, - Predicate: Expression, + Expr: Expression, + Expr::SqlType: BoolOrNullableBool, + Predicate: Expression, + Predicate::SqlType: BoolOrNullableBool, { type Output = WhereClause>; fn or(self, predicate: Predicate) -> Self::Output { - WhereClause(self.0.or(predicate)) + WhereClause(Or::new(self.0, predicate)) } } @@ -177,7 +181,6 @@ where fn or(self, predicate: Predicate) -> Self::Output { use self::BoxedWhereClause::Where; use crate::expression::grouped::Grouped; - use crate::expression::operators::Or; match self { Where(where_clause) => Where(Box::new(Grouped(Or::new(where_clause, predicate)))), diff --git a/diesel/src/query_dsl/load_dsl.rs b/diesel/src/query_dsl/load_dsl.rs index 4d74bae62da3..21246de94e9d 100644 --- a/diesel/src/query_dsl/load_dsl.rs +++ b/diesel/src/query_dsl/load_dsl.rs @@ -1,10 +1,10 @@ use super::RunQueryDsl; use crate::backend::Backend; use crate::connection::Connection; -use crate::deserialize::Queryable; +use crate::deserialize::FromSqlRow; +use crate::expression::QueryMetadata; use crate::query_builder::{AsQuery, QueryFragment, QueryId}; use crate::result::QueryResult; -use crate::sql_types::HasSqlType; /// The `load` method /// @@ -21,13 +21,13 @@ pub trait LoadQuery: RunQueryDsl { impl LoadQuery for T where Conn: Connection, - Conn::Backend: HasSqlType, T: AsQuery + RunQueryDsl, T::Query: QueryFragment + QueryId, - U: Queryable, + U: FromSqlRow, + Conn::Backend: QueryMetadata, { fn internal_load(self, conn: &Conn) -> QueryResult> { - conn.query_by_index(self) + conn.load(self) } } diff --git a/diesel/src/query_dsl/single_value_dsl.rs b/diesel/src/query_dsl/single_value_dsl.rs index be4c9b272a63..c220b6a93271 100644 --- a/diesel/src/query_dsl/single_value_dsl.rs +++ b/diesel/src/query_dsl/single_value_dsl.rs @@ -3,7 +3,7 @@ use crate::dsl::Limit; use crate::expression::grouped::Grouped; use crate::expression::subselect::Subselect; use crate::query_builder::SelectQuery; -use crate::sql_types::{IntoNullable, SingleValue}; +use crate::sql_types::IntoNullable; /// The `single_value` method /// @@ -20,13 +20,13 @@ pub trait SingleValueDsl { fn single_value(self) -> Self::Output; } -impl SingleValueDsl for T +impl SingleValueDsl for T where - Self: SelectQuery + LimitDsl, - ST: IntoNullable, - ST::Nullable: SingleValue, + Self: SelectQuery + LimitDsl, + ::SqlType: IntoNullable, { - type Output = Grouped, ST::Nullable>>; + type Output = + Grouped, <::SqlType as IntoNullable>::Nullable>>; fn single_value(self) -> Self::Output { Grouped(Subselect::new(self.limit(1))) diff --git a/diesel/src/query_source/joins.rs b/diesel/src/query_source/joins.rs index 90dec6e21a26..e22dac52c1e6 100644 --- a/diesel/src/query_source/joins.rs +++ b/diesel/src/query_source/joins.rs @@ -6,7 +6,7 @@ use crate::expression::SelectableExpression; use crate::prelude::*; use crate::query_builder::*; use crate::result::QueryResult; -use crate::sql_types::Bool; +use crate::sql_types::BoolOrNullableBool; use crate::util::TupleAppend; #[derive(Debug, Clone, Copy, QueryId)] @@ -84,7 +84,8 @@ where impl QuerySource for JoinOn where Join: QuerySource, - On: AppearsOnTable + Clone, + On: AppearsOnTable + Clone, + On::SqlType: BoolOrNullableBool, Join::DefaultSelection: SelectableExpression, { type FromClause = Grouped>; diff --git a/diesel/src/r2d2.rs b/diesel/src/r2d2.rs index bf3ba738550a..72a4b805d011 100644 --- a/diesel/src/r2d2.rs +++ b/diesel/src/r2d2.rs @@ -17,10 +17,10 @@ use std::fmt; use std::marker::PhantomData; use crate::connection::{SimpleConnection, TransactionManager}; -use crate::deserialize::{Queryable, QueryableByName}; +use crate::deserialize::FromSqlRow; +use crate::expression::QueryMetadata; use crate::prelude::*; use crate::query_builder::{AsQuery, QueryFragment, QueryId}; -use crate::sql_types::HasSqlType; /// An r2d2 connection manager for use with Diesel. /// @@ -142,22 +142,14 @@ where (&**self).execute(query) } - fn query_by_index(&self, source: T) -> QueryResult> + fn load(&self, source: T) -> QueryResult> where T: AsQuery, T::Query: QueryFragment + QueryId, - Self::Backend: HasSqlType, - U: Queryable, + U: FromSqlRow, + Self::Backend: QueryMetadata, { - (&**self).query_by_index(source) - } - - fn query_by_name(&self, source: &T) -> QueryResult> - where - T: QueryFragment + QueryId, - U: QueryableByName, - { - (&**self).query_by_name(source) + (&**self).load(source) } fn execute_returning_count(&self, source: &T) -> QueryResult diff --git a/diesel/src/result.rs b/diesel/src/result.rs index fb06e4a6a984..38b458f3c54d 100644 --- a/diesel/src/result.rs +++ b/diesel/src/result.rs @@ -355,3 +355,15 @@ impl fmt::Display for UnexpectedNullError { } impl StdError for UnexpectedNullError {} + +/// Expected more fields then present in the current row while deserialising results +#[derive(Debug, Clone, Copy)] +pub struct UnexpectedEndOfRow; + +impl fmt::Display for UnexpectedEndOfRow { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "Unexpected end of row") + } +} + +impl StdError for UnexpectedEndOfRow {} diff --git a/diesel/src/row.rs b/diesel/src/row.rs index f5e6a695c065..398372d5c8bb 100644 --- a/diesel/src/row.rs +++ b/diesel/src/row.rs @@ -1,30 +1,169 @@ //! Contains the `Row` trait -use crate::backend::{self, Backend}; -use crate::deserialize::{self, FromSql}; +use crate::{ + backend::{self, Backend}, + deserialize, +}; +use deserialize::FromSql; +use std::ops::Range; + +/// Representing a way to index into database rows +/// +/// * Crates using existing backends should use existing implementations of +/// this traits. Diesel provides `RowIndex` and `RowIndex<&str>` for +/// all bulit-in backends +/// +/// * Crates implementing custom backends need to provide `RowIndex` and +/// `RowIndex<&str>` impls for their [`Row`] type. +/// +/// [`Row`]: trait.Row.html +pub trait RowIndex { + /// Get the numeric index inside the current row for the provided index value + fn idx(&self, idx: I) -> Option; +} /// Represents a single database row. -/// Apps should not need to concern themselves with this trait. /// -/// This trait is only used as an argument to [`FromSqlRow`]. +/// This trait is used as an argument to [`FromSqlRow`]. /// /// [`FromSqlRow`]: ../deserialize/trait.FromSqlRow.html -pub trait Row { - /// Returns the value of the next column in the row. - fn take(&mut self) -> Option>; +pub trait Row<'a, DB: Backend>: RowIndex + for<'b> RowIndex<&'b str> + Sized { + /// Field type returned by a `Row` implementation + /// + /// * Crates using existing backend should not concern themself with the + /// concrete type of this associated type. + /// + /// * Crates implementing custom backends should provide their own type + /// meeting the required trait bounds + type Field: Field<'a, DB>; + + /// Return type of `PartialRow` + /// + /// For all implementations, beside of the `Row` implementation on `PartialRow` itself + /// this should be `Self`. + #[doc(hidden)] + type InnerPartialRow: Row<'a, DB>; + + /// Get the number of fields in the current row + fn field_count(&self) -> usize; + + /// Get the field with the provided index from the row. + /// + /// Returns `None` if there is no matching field for the given index + fn get(&self, idx: I) -> Option + where + Self: RowIndex; - /// Returns whether the next `count` columns are all `NULL`. + /// Returns a wrapping row that allows only to access fields, where the index is part of + /// the provided range. + #[doc(hidden)] + fn partial_row(&self, range: Range) -> PartialRow; +} + +/// Represents a single field in a database row. +/// +/// This trait allows retrieving information on the name of the colum and on the value of the +/// field. +pub trait Field<'a, DB: Backend> { + /// The name of the current field /// - /// If this method returns `true`, then the next `count` calls to `take` - /// would all return `None`. - fn next_is_null(&self, count: usize) -> bool; - - /// Skips the next `count` columns. This method must be called if you are - /// choosing not to call `take` as a result of `next_is_null` returning - /// `true`. - fn advance(&mut self, count: usize) { - for _ in 0..count { - self.take(); + /// Returns `None` if it's an unnamed field + fn field_name(&self) -> Option<&'a str>; + + /// Get the value representing the current field in the raw representation + /// as it is transmitted by the database + fn value(&self) -> Option>; + + /// Checks whether this field is null or not. + fn is_null(&self) -> bool { + self.value().is_none() + } +} + +/// A row type that wraps an inner row +/// +/// This type only allows to access fields of the inner row, whose index is +/// part of `range`. +/// +/// Indexing via `usize` starts with 0 for this row type. The index is then shifted +/// by `self.range.start` to match the corresponding field in the underlying row. +#[derive(Debug)] +#[doc(hidden)] +pub struct PartialRow<'a, R> { + inner: &'a R, + range: Range, +} + +impl<'a, R> PartialRow<'a, R> { + #[doc(hidden)] + pub fn new<'b, DB>(inner: &'a R, range: Range) -> Self + where + R: Row<'b, DB>, + DB: Backend, + { + let range_lower = std::cmp::min(range.start, inner.field_count()); + let range_upper = std::cmp::min(range.end, inner.field_count()); + Self { + inner, + range: range_lower..range_upper, + } + } +} + +impl<'a, 'b, DB, R> Row<'a, DB> for PartialRow<'b, R> +where + DB: Backend, + R: Row<'a, DB>, +{ + type Field = R::Field; + type InnerPartialRow = R; + + fn field_count(&self) -> usize { + self.range.len() + } + + fn get(&self, idx: I) -> Option + where + Self: RowIndex, + { + let idx = self.idx(idx)?; + self.inner.get(idx) + } + + fn partial_row(&self, range: Range) -> PartialRow { + let range_upper_bound = std::cmp::min(self.range.end, self.range.start + range.end); + let range = (self.range.start + range.start)..range_upper_bound; + PartialRow { + inner: self.inner, + range, + } + } +} + +impl<'a, 'b, R> RowIndex<&'a str> for PartialRow<'b, R> +where + R: RowIndex<&'a str>, +{ + fn idx(&self, idx: &'a str) -> Option { + let idx = self.inner.idx(idx)?; + if self.range.contains(&idx) { + Some(idx) + } else { + None + } + } +} + +impl<'a, R> RowIndex for PartialRow<'a, R> +where + R: RowIndex, +{ + fn idx(&self, idx: usize) -> Option { + let idx = self.inner.idx(idx + self.range.start)?; + if self.range.contains(&idx) { + Some(idx) + } else { + None } } } @@ -34,7 +173,7 @@ pub trait Row { /// /// This trait is used by implementations of /// [`QueryableByName`](../deserialize/trait.QueryableByName.html) -pub trait NamedRow { +pub trait NamedRow<'a, DB: Backend>: Row<'a, DB> { /// Retrieve and deserialize a single value from the query /// /// Note that `ST` *must* be the exact type of the value with that name in @@ -44,23 +183,23 @@ pub trait NamedRow { /// /// If two or more fields in the query have the given name, the result of /// this function is undefined. + fn get<'b, ST, T>(&self, column_name: &'b str) -> deserialize::Result + where + T: FromSql; +} + +impl<'a, R, DB> NamedRow<'a, DB> for R +where + R: Row<'a, DB>, + DB: Backend, +{ fn get(&self, column_name: &str) -> deserialize::Result where T: FromSql, { - let idx = self - .index_of(column_name) - .ok_or_else(|| format!("Column `{}` was not present in query", column_name).into()); - let idx = match idx { - Ok(x) => x, - Err(e) => return Err(e), - }; - let raw_value = self.get_raw_value(idx); - T::from_sql(raw_value) - } + let field = Row::get(self, column_name) + .ok_or_else(|| format!("Column `{}` was not present in query", column_name))?; - #[doc(hidden)] - fn index_of(&self, column_name: &str) -> Option; - #[doc(hidden)] - fn get_raw_value(&self, index: usize) -> Option>; + T::from_nullable_sql(field.value()) + } } diff --git a/diesel/src/serialize.rs b/diesel/src/serialize.rs index fe70c17b08d3..d87ca18155a1 100644 --- a/diesel/src/serialize.rs +++ b/diesel/src/serialize.rs @@ -145,7 +145,7 @@ where /// database, you should use `i32::to_sql(x, out)` instead of writing to `out` /// yourself. /// -/// Any types which implement this trait should also `#[derive(AsExpression)]`. +/// Any types which implement this trait should also [`#[derive(AsExpression)]`]. /// /// ### Backend specific details /// @@ -157,6 +157,7 @@ where /// - For third party backends, consult that backend's documentation. /// /// [`MysqlType`]: ../mysql/enum.MysqlType.html +/// [`#[derive(AsExpression)]`]: ../expression/derive.AsExpression.html; /// /// ### Examples /// @@ -165,12 +166,14 @@ where /// /// ```rust /// # use diesel::backend::Backend; +/// # use diesel::expression::AsExpression; /// # use diesel::sql_types::*; /// # use diesel::serialize::{self, ToSql, Output}; /// # use std::io::Write; /// # /// #[repr(i32)] -/// #[derive(Debug, Clone, Copy)] +/// #[derive(Debug, Clone, Copy, AsExpression)] +/// #[sql_type = "Integer"] /// pub enum MyEnum { /// A = 1, /// B = 2, diff --git a/diesel/src/sql_types/fold.rs b/diesel/src/sql_types/fold.rs index 023aedf142c3..91e0de2e8132 100644 --- a/diesel/src/sql_types/fold.rs +++ b/diesel/src/sql_types/fold.rs @@ -1,16 +1,16 @@ -use crate::sql_types::{self, NotNull}; +use crate::sql_types::{self, is_nullable, SingleValue, SqlType}; /// Represents SQL types which can be used with `SUM` and `AVG` -pub trait Foldable { +pub trait Foldable: SingleValue { /// The SQL type of `sum(this_type)` - type Sum; + type Sum: SqlType + SingleValue; /// The SQL type of `avg(this_type)` - type Avg; + type Avg: SqlType + SingleValue; } impl Foldable for sql_types::Nullable where - T: Foldable + NotNull, + T: Foldable + SqlType, { type Sum = T::Sum; type Avg = T::Avg; diff --git a/diesel/src/sql_types/mod.rs b/diesel/src/sql_types/mod.rs index bb5994e55332..db517968db37 100644 --- a/diesel/src/sql_types/mod.rs +++ b/diesel/src/sql_types/mod.rs @@ -20,6 +20,7 @@ mod ord; pub use self::fold::Foldable; pub use self::ord::SqlOrd; +use crate::expression::TypedExpressionType; use crate::query_builder::QueryId; /// The boolean SQL type. @@ -377,7 +378,14 @@ pub struct Json; /// /// - `Option` for any `T` which implements `FromSql` #[derive(Debug, Clone, Copy, Default)] -pub struct Nullable(ST); +pub struct Nullable(ST); + +impl SqlType for Nullable +where + ST: SqlType, +{ + type IsNull = is_nullable::IsNullable; +} #[cfg(feature = "postgres")] pub use crate::pg::types::sql_types::*; @@ -404,12 +412,6 @@ pub trait HasSqlType: TypeMetadata { /// This method may use `lookup` to do dynamic runtime lookup. Implementors /// of this method should not do dynamic lookup unless absolutely necessary fn metadata(lookup: &Self::MetadataLookup) -> Self::TypeMetadata; - - #[doc(hidden)] - #[cfg(feature = "mysql")] - fn mysql_row_metadata(out: &mut Vec, lookup: &Self::MetadataLookup) { - out.push(Self::metadata(lookup)) - } } /// Information about how a backend stores metadata about given SQL types @@ -427,15 +429,6 @@ pub trait TypeMetadata { type MetadataLookup; } -/// A marker trait indicating that a SQL type is not null. -/// -/// All SQL types must implement this trait. -/// -/// # Deriving -/// -/// This trait is automatically implemented by `#[derive(SqlType)]` -pub trait NotNull {} - /// Converts a type which may or may not be nullable into its nullable /// representation. pub trait IntoNullable { @@ -445,12 +438,41 @@ pub trait IntoNullable { type Nullable; } -impl IntoNullable for T { +impl IntoNullable for T +where + T: SqlType + SingleValue, +{ type Nullable = Nullable; } -impl IntoNullable for Nullable { - type Nullable = Nullable; +impl IntoNullable for Nullable +where + T: SqlType, +{ + type Nullable = Self; +} + +/// Converts a type which may or may not be nullable into its not nullable +/// representation. +pub trait IntoNotNullable { + /// The not nullable representation of this type. + /// + /// For `Nullable`, this will be `T` otherwise the type itself + type NotNullable; +} + +impl IntoNotNullable for T +where + T: SqlType, +{ + type NotNullable = T; +} + +impl IntoNotNullable for Nullable +where + T: SqlType, +{ + type NotNullable = T; } /// A marker trait indicating that a SQL type represents a single value, as @@ -462,12 +484,149 @@ impl IntoNullable for Nullable { /// /// # Deriving /// -/// This trait is automatically implemented by `#[derive(SqlType)]` -pub trait SingleValue {} +/// This trait is automatically implemented by [`#[derive(SqlType)]`] +/// +/// [`#[derive(SqlType)]`]: derive.SqlType.html +pub trait SingleValue: SqlType {} -impl SingleValue for Nullable {} +impl SingleValue for Nullable {} #[doc(inline)] pub use diesel_derives::DieselNumericOps; #[doc(inline)] pub use diesel_derives::SqlType; + +/// A marker trait for SQL types +/// +/// # Deriving +/// +/// This trait is automatically implemented by [`#[derive(SqlType)]`] +/// which sets `IsNull` to [`is_nullable::NotNull`] +/// +/// [`#[derive(SqlType)]`]: derive.SqlType.html +/// [`is_nullable::NotNull`]: is_nullable/struct.NotNull.html +pub trait SqlType { + /// Is this type nullable? + /// + /// This type should always be one of the structs in the ['is_nullable`] + /// module. See the documentation of those structs for more details. + /// + /// ['is_nullable`]: is_nullable/index.html + type IsNull: OneIsNullable + OneIsNullable; +} + +/// Is one value of `IsNull` nullable? +/// +/// You should never implement this trait. +pub trait OneIsNullable { + /// See the trait documentation + type Out: OneIsNullable + OneIsNullable; +} + +/// Are both values of `IsNull` are nullable? +pub trait AllAreNullable { + /// See the trait documentation + type Out: AllAreNullable + AllAreNullable; +} + +/// A type level constructor for maybe nullable types +/// +/// Constructs either `Nullable` (for `Self` == `is_nullable::IsNullable`) +/// or `O` (for `Self` == `is_nullable::NotNull`) +pub trait MaybeNullableType { + /// See the trait documentation + type Out: SqlType + TypedExpressionType; +} + +/// Possible values for `SqlType::IsNullable` +pub mod is_nullable { + use super::*; + + /// No, this type cannot be null as it is marked as `NOT NULL` at database level + /// + /// This should be choosen for basically all manual impls of `SqlType` + /// beside implementing your own `Nullable<>` wrapper type + #[derive(Debug, Clone, Copy)] + pub struct NotNull; + + /// Yes, this type can be null + /// + /// The only diesel provided `SqlType` that uses this value is [`Nullable`] + /// + /// [`Nullable`]: ../struct.Nullable.html + #[derive(Debug, Clone, Copy)] + pub struct IsNullable; + + impl OneIsNullable for NotNull { + type Out = NotNull; + } + + impl OneIsNullable for NotNull { + type Out = IsNullable; + } + + impl OneIsNullable for IsNullable { + type Out = IsNullable; + } + + impl OneIsNullable for IsNullable { + type Out = IsNullable; + } + + impl AllAreNullable for NotNull { + type Out = NotNull; + } + + impl AllAreNullable for NotNull { + type Out = NotNull; + } + + impl AllAreNullable for IsNullable { + type Out = NotNull; + } + + impl AllAreNullable for IsNullable { + type Out = IsNullable; + } + + impl MaybeNullableType for NotNull + where + O: SqlType + TypedExpressionType, + { + type Out = O; + } + + impl MaybeNullableType for IsNullable + where + O: SqlType, + Nullable: TypedExpressionType, + { + type Out = Nullable; + } + + /// Represents the output type of [`MaybeNullableType`](../trait.MaybeNullableType.html) + pub type MaybeNullable = >::Out; + + /// Represents the output type of [`OneIsNullable`](../trait.OneIsNullable.html) + /// for two given SQL types + pub type IsOneNullable = + as OneIsNullable>>::Out; + + /// Represents the output type of [`AllAreNullable`](../trait.AllAreNullable.html) + /// for two given SQL types + pub type AreAllNullable = + as AllAreNullable>>::Out; + + /// Represents if the SQL type is nullable or not + pub type IsSqlTypeNullable = ::IsNull; +} + +/// A marker trait for accepting expressions of the type `Bool` and +/// `Nullable` in the same place +pub trait BoolOrNullableBool {} + +impl BoolOrNullableBool for Bool {} +impl BoolOrNullableBool for Nullable {} + +#[doc(inline)] +pub use crate::expression::expression_types::Untyped; diff --git a/diesel/src/sql_types/ops.rs b/diesel/src/sql_types/ops.rs index ad8a6668126b..f548764c6aba 100644 --- a/diesel/src/sql_types/ops.rs +++ b/diesel/src/sql_types/ops.rs @@ -36,33 +36,33 @@ use super::*; /// Represents SQL types which can be added. pub trait Add { /// The SQL type which can be added to this one - type Rhs; + type Rhs: SqlType; /// The SQL type of the result of adding `Rhs` to `Self` - type Output; + type Output: SqlType; } /// Represents SQL types which can be subtracted. pub trait Sub { /// The SQL type which can be subtracted from this one - type Rhs; + type Rhs: SqlType; /// The SQL type of the result of subtracting `Rhs` from `Self` - type Output; + type Output: SqlType; } /// Represents SQL types which can be multiplied. pub trait Mul { /// The SQL type which this can be multiplied by - type Rhs; + type Rhs: SqlType; /// The SQL type of the result of multiplying `Self` by `Rhs` - type Output; + type Output: SqlType; } /// Represents SQL types which can be divided. pub trait Div { /// The SQL type which this one can be divided by - type Rhs; + type Rhs: SqlType; /// The SQL type of the result of dividing `Self` by `Rhs` - type Output; + type Output: SqlType; } macro_rules! numeric_type { @@ -145,9 +145,9 @@ impl Div for Interval { impl Add for Nullable where - T: Add + NotNull, - T::Rhs: NotNull, - T::Output: NotNull, + T: Add + SqlType, + T::Rhs: SqlType, + T::Output: SqlType, { type Rhs = Nullable; type Output = Nullable; @@ -155,9 +155,9 @@ where impl Sub for Nullable where - T: Sub + NotNull, - T::Rhs: NotNull, - T::Output: NotNull, + T: Sub + SqlType, + T::Rhs: SqlType, + T::Output: SqlType, { type Rhs = Nullable; type Output = Nullable; @@ -165,9 +165,9 @@ where impl Mul for Nullable where - T: Mul + NotNull, - T::Rhs: NotNull, - T::Output: NotNull, + T: Mul + SqlType, + T::Rhs: SqlType, + T::Output: SqlType, { type Rhs = Nullable; type Output = Nullable; @@ -175,9 +175,9 @@ where impl Div for Nullable where - T: Div + NotNull, - T::Rhs: NotNull, - T::Output: NotNull, + T: Div + SqlType, + T::Rhs: SqlType, + T::Output: SqlType, { type Rhs = Nullable; type Output = Nullable; diff --git a/diesel/src/sql_types/ord.rs b/diesel/src/sql_types/ord.rs index 7ce6293d14f6..e3ba8d3fdc6f 100644 --- a/diesel/src/sql_types/ord.rs +++ b/diesel/src/sql_types/ord.rs @@ -1,7 +1,7 @@ -use crate::sql_types::{self, NotNull}; +use crate::sql_types::{self, is_nullable, SqlType}; /// Marker trait for types which can be used with `MAX` and `MIN` -pub trait SqlOrd {} +pub trait SqlOrd: SqlType {} impl SqlOrd for sql_types::SmallInt {} impl SqlOrd for sql_types::Integer {} @@ -13,7 +13,7 @@ impl SqlOrd for sql_types::Date {} impl SqlOrd for sql_types::Interval {} impl SqlOrd for sql_types::Time {} impl SqlOrd for sql_types::Timestamp {} -impl SqlOrd for sql_types::Nullable {} +impl SqlOrd for sql_types::Nullable where T: SqlOrd + SqlType {} #[cfg(feature = "postgres")] impl SqlOrd for sql_types::Timestamptz {} diff --git a/diesel/src/sqlite/backend.rs b/diesel/src/sqlite/backend.rs index 24a1947b4714..1c6af1637e82 100644 --- a/diesel/src/sqlite/backend.rs +++ b/diesel/src/sqlite/backend.rs @@ -20,7 +20,7 @@ pub struct Sqlite; /// The variants of this struct determine what bytes are expected from /// `ToSql` impls. #[allow(missing_debug_implementations)] -#[derive(Hash, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)] pub enum SqliteType { /// Bind using `sqlite3_bind_blob` Binary, @@ -45,7 +45,7 @@ impl Backend for Sqlite { } impl<'a> HasRawValue<'a> for Sqlite { - type RawValue = &'a SqliteValue; + type RawValue = SqliteValue<'a>; } impl TypeMetadata for Sqlite { diff --git a/diesel/src/sqlite/connection/functions.rs b/diesel/src/sqlite/connection/functions.rs index e340b64c1388..c3fccc9dfce5 100644 --- a/diesel/src/sqlite/connection/functions.rs +++ b/diesel/src/sqlite/connection/functions.rs @@ -3,11 +3,12 @@ extern crate libsqlite3_sys as ffi; use super::raw::RawConnection; use super::serialized_value::SerializedValue; use super::{Sqlite, SqliteAggregateFunction, SqliteValue}; -use crate::deserialize::{FromSqlRow, Queryable}; +use crate::deserialize::{FromSqlRow, StaticallySizedRow}; use crate::result::{DatabaseErrorKind, Error, QueryResult}; -use crate::row::Row; +use crate::row::{Field, PartialRow, Row, RowIndex}; use crate::serialize::{IsNull, Output, ToSql}; use crate::sql_types::HasSqlType; +use std::marker::PhantomData; pub fn register( conn: &RawConnection, @@ -17,11 +18,11 @@ pub fn register( ) -> QueryResult<()> where F: FnMut(&RawConnection, Args) -> Ret + Send + 'static, - Args: Queryable, + Args: FromSqlRow + StaticallySizedRow, Ret: ToSql, Sqlite: HasSqlType, { - let fields_needed = Args::Row::FIELDS_NEEDED; + let fields_needed = Args::FIELD_COUNT; if fields_needed > 127 { return Err(Error::DatabaseError( DatabaseErrorKind::UnableToSendCommand, @@ -45,11 +46,11 @@ pub fn register_aggregate( ) -> QueryResult<()> where A: SqliteAggregateFunction + 'static + Send, - Args: Queryable, + Args: FromSqlRow + StaticallySizedRow, Ret: ToSql, Sqlite: HasSqlType, { - let fields_needed = Args::Row::FIELDS_NEEDED; + let fields_needed = Args::FIELD_COUNT; if fields_needed > 127 { return Err(Error::DatabaseError( DatabaseErrorKind::UnableToSendCommand, @@ -69,12 +70,10 @@ pub(crate) fn build_sql_function_args( args: &[*mut ffi::sqlite3_value], ) -> Result where - Args: Queryable, + Args: FromSqlRow, { - let mut row = FunctionRow { args }; - let args_row = Args::Row::build_from_row(&mut row).map_err(Error::DeserializationError)?; - - Ok(Args::build(args_row)) + let row = FunctionRow::new(args); + Args::build_from_row(&row).map_err(Error::DeserializationError) } pub(crate) fn process_sql_function_result( @@ -99,21 +98,73 @@ where }) } +#[derive(Clone)] struct FunctionRow<'a> { args: &'a [*mut ffi::sqlite3_value], } -impl<'a> Row for FunctionRow<'a> { - fn take(&mut self) -> Option<&SqliteValue> { - self.args.split_first().and_then(|(&first, rest)| { - self.args = rest; - unsafe { SqliteValue::new(first) } +impl<'a> FunctionRow<'a> { + fn new(args: &'a [*mut ffi::sqlite3_value]) -> Self { + Self { args } + } +} + +impl<'a> Row<'a, Sqlite> for FunctionRow<'a> { + type Field = FunctionArgument<'a>; + type InnerPartialRow = Self; + + fn field_count(&self) -> usize { + self.args.len() + } + + fn get(&self, idx: I) -> Option + where + Self: crate::row::RowIndex, + { + let idx = self.idx(idx)?; + + self.args.get(idx).map(|arg| FunctionArgument { + arg: *arg, + p: PhantomData, }) } - fn next_is_null(&self, count: usize) -> bool { - self.args[..count] - .iter() - .all(|&p| unsafe { SqliteValue::new(p) }.is_none()) + fn partial_row(&self, range: std::ops::Range) -> PartialRow { + PartialRow::new(self, range) + } +} + +impl<'a> RowIndex for FunctionRow<'a> { + fn idx(&self, idx: usize) -> Option { + if idx < self.args.len() { + Some(idx) + } else { + None + } + } +} + +impl<'a, 'b> RowIndex<&'a str> for FunctionRow<'b> { + fn idx(&self, _idx: &'a str) -> Option { + None + } +} + +struct FunctionArgument<'a> { + arg: *mut ffi::sqlite3_value, + p: PhantomData<&'a ()>, +} + +impl<'a> Field<'a, Sqlite> for FunctionArgument<'a> { + fn field_name(&self) -> Option<&'a str> { + None + } + + fn is_null(&self) -> bool { + self.value().is_none() + } + + fn value(&self) -> Option> { + unsafe { SqliteValue::new(self.arg) } } } diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index b200d15e573a..a52a1f842935 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -17,7 +17,8 @@ use self::statement_iterator::*; use self::stmt::{Statement, StatementUse}; use super::SqliteAggregateFunction; use crate::connection::*; -use crate::deserialize::{Queryable, QueryableByName}; +use crate::deserialize::{FromSqlRow, StaticallySizedRow}; +use crate::expression::QueryMetadata; use crate::query_builder::bind_collector::RawBytesBindCollector; use crate::query_builder::*; use crate::result::*; @@ -71,12 +72,12 @@ impl Connection for SqliteConnection { } #[doc(hidden)] - fn query_by_index(&self, source: T) -> QueryResult> + fn load(&self, source: T) -> QueryResult> where T: AsQuery, T::Query: QueryFragment + QueryId, - Self::Backend: HasSqlType, - U: Queryable, + U: FromSqlRow, + Self::Backend: QueryMetadata, { let mut statement = self.prepare_query(&source.as_query())?; let statement_use = StatementUse::new(&mut statement); @@ -84,18 +85,6 @@ impl Connection for SqliteConnection { iter.collect() } - #[doc(hidden)] - fn query_by_name(&self, source: &T) -> QueryResult> - where - T: QueryFragment + QueryId, - U: QueryableByName, - { - let mut statement = self.prepare_query(source)?; - let statement_use = StatementUse::new(&mut statement); - let iter = NamedStatementIterator::new(statement_use)?; - iter.collect() - } - #[doc(hidden)] fn execute_returning_count(&self, source: &T) -> QueryResult where @@ -227,7 +216,7 @@ impl SqliteConnection { ) -> QueryResult<()> where F: FnMut(Args) -> Ret + Send + 'static, - Args: Queryable, + Args: FromSqlRow + StaticallySizedRow, Ret: ToSql, Sqlite: HasSqlType, { @@ -246,7 +235,7 @@ impl SqliteConnection { ) -> QueryResult<()> where A: SqliteAggregateFunction + 'static + Send, - Args: Queryable, + Args: FromSqlRow + StaticallySizedRow, Ret: ToSql, Sqlite: HasSqlType, { diff --git a/diesel/src/sqlite/connection/raw.rs b/diesel/src/sqlite/connection/raw.rs index dde7f3c94b11..7d2484646344 100644 --- a/diesel/src/sqlite/connection/raw.rs +++ b/diesel/src/sqlite/connection/raw.rs @@ -9,7 +9,7 @@ use std::{mem, ptr, slice, str}; use super::functions::{build_sql_function_args, process_sql_function_result}; use super::serialized_value::SerializedValue; use super::{Sqlite, SqliteAggregateFunction}; -use crate::deserialize::Queryable; +use crate::deserialize::FromSqlRow; use crate::result::Error::DatabaseError; use crate::result::*; use crate::serialize::ToSql; @@ -116,7 +116,7 @@ impl RawConnection { ) -> QueryResult<()> where A: SqliteAggregateFunction + 'static + Send, - Args: Queryable, + Args: FromSqlRow, Ret: ToSql, Sqlite: HasSqlType, { @@ -266,7 +266,7 @@ extern "C" fn run_aggregator_step_function + 'static + Send, - Args: Queryable, + Args: FromSqlRow, Ret: ToSql, Sqlite: HasSqlType, { @@ -336,7 +336,7 @@ extern "C" fn run_aggregator_final_function + 'static + Send, - Args: Queryable, + Args: FromSqlRow, Ret: ToSql, Sqlite: HasSqlType, { diff --git a/diesel/src/sqlite/connection/sqlite_value.rs b/diesel/src/sqlite/connection/sqlite_value.rs index 247f49210d20..60e89885286d 100644 --- a/diesel/src/sqlite/connection/sqlite_value.rs +++ b/diesel/src/sqlite/connection/sqlite_value.rs @@ -1,39 +1,52 @@ extern crate libsqlite3_sys as ffi; -use std::collections::HashMap; +use std::marker::PhantomData; use std::os::raw as libc; use std::ptr::NonNull; use std::{slice, str}; use crate::row::*; -use crate::sqlite::Sqlite; +use crate::sqlite::{Sqlite, SqliteType}; +/// Raw sqlite value as received from the database +/// +/// Use existing `FromSql` implementations to convert this into +/// rust values: #[allow(missing_debug_implementations, missing_copy_implementations)] -pub struct SqliteValue { - value: ffi::sqlite3_value, +pub struct SqliteValue<'a> { + value: NonNull, + p: PhantomData<&'a ()>, } -pub struct SqliteRow { +#[derive(Clone)] +pub struct SqliteRow<'a> { stmt: NonNull, next_col_index: libc::c_int, + p: PhantomData<&'a ()>, } -impl SqliteValue { - #[allow(clippy::new_ret_no_self)] - pub(crate) unsafe fn new<'a>(inner: *mut ffi::sqlite3_value) -> Option<&'a Self> { - (inner as *const _ as *const Self).as_ref().and_then(|v| { - if v.is_null() { - None - } else { - Some(v) - } - }) +impl<'a> SqliteValue<'a> { + pub(crate) unsafe fn new(inner: *mut ffi::sqlite3_value) -> Option { + NonNull::new(inner) + .map(|value| SqliteValue { + value, + p: PhantomData, + }) + .and_then(|value| { + // We check here that the actual value represented by the inner + // `sqlite3_value` is not `NULL` (is sql meaning, not ptr meaning) + if value.is_null() { + None + } else { + Some(value) + } + }) } - pub fn read_text(&self) -> &str { + pub(crate) fn read_text(&self) -> &str { unsafe { - let ptr = ffi::sqlite3_value_text(self.value()); - let len = ffi::sqlite3_value_bytes(self.value()); + let ptr = ffi::sqlite3_value_text(self.value.as_ptr()); + let len = ffi::sqlite3_value_bytes(self.value.as_ptr()); let bytes = slice::from_raw_parts(ptr as *const u8, len as usize); // The string is guaranteed to be utf8 according to // https://www.sqlite.org/c3ref/value_blob.html @@ -41,86 +54,136 @@ impl SqliteValue { } } - pub fn read_blob(&self) -> &[u8] { + pub(crate) fn read_blob(&self) -> &[u8] { unsafe { - let ptr = ffi::sqlite3_value_blob(self.value()); - let len = ffi::sqlite3_value_bytes(self.value()); + let ptr = ffi::sqlite3_value_blob(self.value.as_ptr()); + let len = ffi::sqlite3_value_bytes(self.value.as_ptr()); slice::from_raw_parts(ptr as *const u8, len as usize) } } - pub fn read_integer(&self) -> i32 { - unsafe { ffi::sqlite3_value_int(self.value()) as i32 } + pub(crate) fn read_integer(&self) -> i32 { + unsafe { ffi::sqlite3_value_int(self.value.as_ptr()) as i32 } } - pub fn read_long(&self) -> i64 { - unsafe { ffi::sqlite3_value_int64(self.value()) as i64 } + pub(crate) fn read_long(&self) -> i64 { + unsafe { ffi::sqlite3_value_int64(self.value.as_ptr()) as i64 } } - pub fn read_double(&self) -> f64 { - unsafe { ffi::sqlite3_value_double(self.value()) as f64 } + pub(crate) fn read_double(&self) -> f64 { + unsafe { ffi::sqlite3_value_double(self.value.as_ptr()) as f64 } } - pub fn is_null(&self) -> bool { - let tpe = unsafe { ffi::sqlite3_value_type(self.value()) }; - tpe == ffi::SQLITE_NULL + /// Get the type of the value as returned by sqlite + pub fn value_type(&self) -> Option { + let tpe = unsafe { ffi::sqlite3_value_type(self.value.as_ptr()) }; + match tpe { + ffi::SQLITE_TEXT => Some(SqliteType::Text), + ffi::SQLITE_INTEGER => Some(SqliteType::Long), + ffi::SQLITE_FLOAT => Some(SqliteType::Double), + ffi::SQLITE_BLOB => Some(SqliteType::Binary), + ffi::SQLITE_NULL => None, + _ => unreachable!("Sqlite docs saying this is not reachable"), + } } - fn value(&self) -> *mut ffi::sqlite3_value { - &self.value as *const _ as _ + pub(crate) fn is_null(&self) -> bool { + self.value_type().is_none() } } -impl SqliteRow { - pub(crate) fn new(inner_statement: NonNull) -> Self { +impl<'a> SqliteRow<'a> { + pub(crate) unsafe fn new(inner_statement: NonNull) -> Self { SqliteRow { stmt: inner_statement, next_col_index: 0, + p: PhantomData, } } +} + +impl<'a> Row<'a, Sqlite> for SqliteRow<'a> { + type Field = SqliteField<'a>; + type InnerPartialRow = Self; + + fn field_count(&self) -> usize { + column_count(self.stmt) as usize + } - pub fn into_named<'a>(self, indices: &'a HashMap<&'a str, usize>) -> SqliteNamedRow<'a> { - SqliteNamedRow { + fn get(&self, idx: I) -> Option + where + Self: RowIndex, + { + let idx = self.idx(idx)?; + Some(SqliteField { stmt: self.stmt, - column_indices: indices, - } + col_idx: idx as i32, + p: PhantomData, + }) } -} -impl Row for SqliteRow { - fn take(&mut self) -> Option<&SqliteValue> { - let col_index = self.next_col_index; - self.next_col_index += 1; + fn partial_row(&self, range: std::ops::Range) -> PartialRow { + PartialRow::new(self, range) + } +} - unsafe { - let ptr = ffi::sqlite3_column_value(self.stmt.as_ptr(), col_index); - SqliteValue::new(ptr) +impl<'a> RowIndex for SqliteRow<'a> { + fn idx(&self, idx: usize) -> Option { + if idx < self.field_count() { + Some(idx) + } else { + None } } +} - fn next_is_null(&self, count: usize) -> bool { - (0..count).all(|i| { - let idx = self.next_col_index + i as libc::c_int; - let tpe = unsafe { ffi::sqlite3_column_type(self.stmt.as_ptr(), idx) }; - tpe == ffi::SQLITE_NULL - }) +impl<'a, 'b> RowIndex<&'a str> for SqliteRow<'b> { + fn idx(&self, field_name: &'a str) -> Option { + (0..column_count(self.stmt)) + .find(|idx| column_name(self.stmt, *idx) == Some(field_name)) + .map(|a| a as usize) } } -pub struct SqliteNamedRow<'a> { +pub struct SqliteField<'a> { stmt: NonNull, - column_indices: &'a HashMap<&'a str, usize>, + col_idx: i32, + p: PhantomData<&'a ()>, } -impl<'a> NamedRow for SqliteNamedRow<'a> { - fn index_of(&self, column_name: &str) -> Option { - self.column_indices.get(column_name).cloned() +impl<'a> Field<'a, Sqlite> for SqliteField<'a> { + fn field_name(&self) -> Option<&'a str> { + column_name(self.stmt, self.col_idx) } - fn get_raw_value(&self, idx: usize) -> Option<&SqliteValue> { + fn is_null(&self) -> bool { + self.value().is_none() + } + + fn value(&self) -> Option> { unsafe { - let ptr = ffi::sqlite3_column_value(self.stmt.as_ptr(), idx as libc::c_int); + let ptr = ffi::sqlite3_column_value(self.stmt.as_ptr(), self.col_idx); SqliteValue::new(ptr) } } } + +fn column_name<'a>(stmt: NonNull, field_number: i32) -> Option<&'a str> { + unsafe { + let ptr = ffi::sqlite3_column_name(stmt.as_ptr(), field_number); + if ptr.is_null() { + None + } else { + Some(std::ffi::CStr::from_ptr(ptr).to_str().expect( + "The Sqlite documentation states that this is UTF8. \ + If you see this error message something has gone \ + horribliy wrong. Please open an issue at the \ + diesel repository.", + )) + } + } +} + +fn column_count(stmt: NonNull) -> i32 { + unsafe { ffi::sqlite3_column_count(stmt.as_ptr()) } +} diff --git a/diesel/src/sqlite/connection/statement_iterator.rs b/diesel/src/sqlite/connection/statement_iterator.rs index 098b3c2baeaf..91195631faf9 100644 --- a/diesel/src/sqlite/connection/statement_iterator.rs +++ b/diesel/src/sqlite/connection/statement_iterator.rs @@ -1,8 +1,7 @@ -use std::collections::HashMap; use std::marker::PhantomData; use super::stmt::StatementUse; -use crate::deserialize::{FromSqlRow, Queryable, QueryableByName}; +use crate::deserialize::FromSqlRow; use crate::result::Error::DeserializationError; use crate::result::QueryResult; use crate::sqlite::Sqlite; @@ -23,7 +22,7 @@ impl<'a, ST, T> StatementIterator<'a, ST, T> { impl<'a, ST, T> Iterator for StatementIterator<'a, ST, T> where - T: Queryable, + T: FromSqlRow, { type Item = QueryResult; @@ -32,55 +31,6 @@ where Ok(row) => row, Err(e) => return Some(Err(e)), }; - row.map(|mut row| { - T::Row::build_from_row(&mut row) - .map(T::build) - .map_err(DeserializationError) - }) - } -} - -pub struct NamedStatementIterator<'a, T> { - stmt: StatementUse<'a>, - column_indices: HashMap<&'a str, usize>, - _marker: PhantomData, -} - -impl<'a, T> NamedStatementIterator<'a, T> { - #[allow(clippy::new_ret_no_self)] - pub fn new(stmt: StatementUse<'a>) -> QueryResult { - let column_indices = (0..stmt.num_fields()) - .filter_map(|i| { - stmt.field_name(i).map(|column| { - let column = column - .to_str() - .map_err(|e| DeserializationError(e.into()))?; - Ok((column, i)) - }) - }) - .collect::>()?; - Ok(NamedStatementIterator { - stmt, - column_indices, - _marker: PhantomData, - }) - } -} - -impl<'a, T> Iterator for NamedStatementIterator<'a, T> -where - T: QueryableByName, -{ - type Item = QueryResult; - - fn next(&mut self) -> Option { - let row = match self.stmt.step() { - Ok(row) => row, - Err(e) => return Some(Err(e)), - }; - row.map(|row| { - let row = row.into_named(&self.column_indices); - T::build(&row).map_err(DeserializationError) - }) + row.map(|row| T::build_from_row(&row).map_err(DeserializationError)) } } diff --git a/diesel/src/sqlite/connection/stmt.rs b/diesel/src/sqlite/connection/stmt.rs index 07fe2989847e..3ca28c8ecbc3 100644 --- a/diesel/src/sqlite/connection/stmt.rs +++ b/diesel/src/sqlite/connection/stmt.rs @@ -54,26 +54,13 @@ impl Statement { ensure_sqlite_ok(result, self.raw_connection()) } - fn num_fields(&self) -> usize { - unsafe { ffi::sqlite3_column_count(self.inner_statement.as_ptr()) as usize } - } - - /// The lifetime of the returned CStr is shorter than self. This function - /// should be tied to a lifetime that ends before the next call to `reset` - unsafe fn field_name<'a>(&self, idx: usize) -> Option<&'a CStr> { - let ptr = ffi::sqlite3_column_name(self.inner_statement.as_ptr(), idx as libc::c_int); - if ptr.is_null() { - None - } else { - Some(CStr::from_ptr(ptr)) - } - } - fn step(&mut self) -> QueryResult> { - match unsafe { ffi::sqlite3_step(self.inner_statement.as_ptr()) } { - ffi::SQLITE_DONE => Ok(None), - ffi::SQLITE_ROW => Ok(Some(SqliteRow::new(self.inner_statement))), - _ => Err(last_error(self.raw_connection())), + unsafe { + match ffi::sqlite3_step(self.inner_statement.as_ptr()) { + ffi::SQLITE_DONE => Ok(None), + ffi::SQLITE_ROW => Ok(Some(SqliteRow::new(self.inner_statement))), + _ => Err(last_error(self.raw_connection())), + } } } @@ -158,14 +145,6 @@ impl<'a> StatementUse<'a> { pub fn step(&mut self) -> QueryResult> { self.statement.step() } - - pub fn num_fields(&self) -> usize { - self.statement.num_fields() - } - - pub fn field_name(&self, idx: usize) -> Option<&'a CStr> { - unsafe { self.statement.field_name(idx) } - } } impl<'a> Drop for StatementUse<'a> { diff --git a/diesel/src/sqlite/mod.rs b/diesel/src/sqlite/mod.rs index cbd07231bc7c..931785b7e023 100644 --- a/diesel/src/sqlite/mod.rs +++ b/diesel/src/sqlite/mod.rs @@ -12,6 +12,7 @@ pub mod query_builder; pub use self::backend::{Sqlite, SqliteType}; pub use self::connection::SqliteConnection; +pub use self::connection::SqliteValue; pub use self::query_builder::SqliteQueryBuilder; /// Trait for the implementation of a SQLite aggregate function diff --git a/diesel/src/sqlite/types/date_and_time/chrono.rs b/diesel/src/sqlite/types/date_and_time/chrono.rs index 7cdb0a0eaadd..7d186c6f6130 100644 --- a/diesel/src/sqlite/types/date_and_time/chrono.rs +++ b/diesel/src/sqlite/types/date_and_time/chrono.rs @@ -12,7 +12,7 @@ use crate::sqlite::Sqlite; const SQLITE_DATE_FORMAT: &str = "%F"; impl FromSql for NaiveDate { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: backend::RawValue) -> deserialize::Result { let text_ptr = <*const str as FromSql>::from_sql(value)?; let text = unsafe { &*text_ptr }; Self::parse_from_str(text, SQLITE_DATE_FORMAT).map_err(Into::into) @@ -27,7 +27,7 @@ impl ToSql for NaiveDate { } impl FromSql for NaiveTime { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: backend::RawValue) -> deserialize::Result { let text_ptr = <*const str as FromSql>::from_sql(value)?; let text = unsafe { &*text_ptr }; let valid_time_formats = &[ @@ -54,7 +54,7 @@ impl ToSql for NaiveTime { } impl FromSql for NaiveDateTime { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: backend::RawValue) -> deserialize::Result { let text_ptr = <*const str as FromSql>::from_sql(value)?; let text = unsafe { &*text_ptr }; diff --git a/diesel/src/sqlite/types/date_and_time/mod.rs b/diesel/src/sqlite/types/date_and_time/mod.rs index 6679d434d052..18bdd45713f0 100644 --- a/diesel/src/sqlite/types/date_and_time/mod.rs +++ b/diesel/src/sqlite/types/date_and_time/mod.rs @@ -15,7 +15,7 @@ mod chrono; /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const str { - fn from_sql(value: Option<&SqliteValue>) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { FromSql::::from_sql(value) } } @@ -38,7 +38,7 @@ impl ToSql for String { /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const str { - fn from_sql(value: Option<&SqliteValue>) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { FromSql::::from_sql(value) } } @@ -61,7 +61,7 @@ impl ToSql for String { /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const str { - fn from_sql(value: Option<&SqliteValue>) -> deserialize::Result { + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { FromSql::::from_sql(value) } } diff --git a/diesel/src/sqlite/types/mod.rs b/diesel/src/sqlite/types/mod.rs index 5c73aeedf44a..326fc40bc5fe 100644 --- a/diesel/src/sqlite/types/mod.rs +++ b/diesel/src/sqlite/types/mod.rs @@ -15,8 +15,8 @@ use crate::sql_types; /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const str { - fn from_sql(value: Option<&SqliteValue>) -> deserialize::Result { - let text = not_none!(value).read_text(); + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + let text = value.read_text(); Ok(text as *const _) } } @@ -27,45 +27,45 @@ impl FromSql for *const str { /// raw pointer instead of a reference with a lifetime due to the structure of /// `FromSql` impl FromSql for *const [u8] { - fn from_sql(bytes: Option<&SqliteValue>) -> deserialize::Result { - let bytes = not_none!(bytes).read_blob(); + fn from_sql(bytes: SqliteValue<'_>) -> deserialize::Result { + let bytes = bytes.read_blob(); Ok(bytes as *const _) } } impl FromSql for i16 { - fn from_sql(value: Option<&SqliteValue>) -> deserialize::Result { - Ok(not_none!(value).read_integer() as i16) + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + Ok(value.read_integer() as i16) } } impl FromSql for i32 { - fn from_sql(value: Option<&SqliteValue>) -> deserialize::Result { - Ok(not_none!(value).read_integer()) + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + Ok(value.read_integer()) } } impl FromSql for bool { - fn from_sql(value: Option<&SqliteValue>) -> deserialize::Result { - Ok(not_none!(value).read_integer() != 0) + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + Ok(value.read_integer() != 0) } } impl FromSql for i64 { - fn from_sql(value: Option<&SqliteValue>) -> deserialize::Result { - Ok(not_none!(value).read_long()) + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + Ok(value.read_long()) } } impl FromSql for f32 { - fn from_sql(value: Option<&SqliteValue>) -> deserialize::Result { - Ok(not_none!(value).read_double() as f32) + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + Ok(value.read_double() as f32) } } impl FromSql for f64 { - fn from_sql(value: Option<&SqliteValue>) -> deserialize::Result { - Ok(not_none!(value).read_double()) + fn from_sql(value: SqliteValue<'_>) -> deserialize::Result { + Ok(value.read_double()) } } diff --git a/diesel/src/sqlite/types/numeric.rs b/diesel/src/sqlite/types/numeric.rs index 19d8a95757fc..976921e56d01 100644 --- a/diesel/src/sqlite/types/numeric.rs +++ b/diesel/src/sqlite/types/numeric.rs @@ -1,8 +1,6 @@ #![cfg(feature = "bigdecimal")] -extern crate bigdecimal; - -use self::bigdecimal::BigDecimal; +use bigdecimal::BigDecimal; use crate::deserialize::{self, FromSql}; use crate::sql_types::{Double, Numeric}; @@ -10,7 +8,7 @@ use crate::sqlite::connection::SqliteValue; use crate::sqlite::Sqlite; impl FromSql for BigDecimal { - fn from_sql(bytes: Option<&SqliteValue>) -> deserialize::Result { + fn from_sql(bytes: SqliteValue<'_>) -> deserialize::Result { let data = >::from_sql(bytes)?; Ok(data.into()) } diff --git a/diesel/src/type_impls/date_and_time.rs b/diesel/src/type_impls/date_and_time.rs index 49fc57d4c08a..c00fec6eae46 100644 --- a/diesel/src/type_impls/date_and_time.rs +++ b/diesel/src/type_impls/date_and_time.rs @@ -4,7 +4,7 @@ use crate::deserialize::FromSqlRow; use crate::expression::AsExpression; use std::time::SystemTime; -#[derive(FromSqlRow, AsExpression)] +#[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "crate::sql_types::Timestamp"] struct SystemTimeProxy(SystemTime); @@ -17,24 +17,24 @@ mod chrono { use crate::expression::AsExpression; use crate::sql_types::{Date, Time, Timestamp}; - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "Date"] struct NaiveDateProxy(NaiveDate); - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "Time"] struct NaiveTimeProxy(NaiveTime); - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "Timestamp"] #[cfg_attr(feature = "postgres", sql_type = "crate::sql_types::Timestamptz")] #[cfg_attr(feature = "mysql", sql_type = "crate::sql_types::Datetime")] struct NaiveDateTimeProxy(NaiveDateTime); - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[cfg_attr(feature = "postgres", sql_type = "crate::sql_types::Timestamptz")] struct DateTimeProxy(DateTime); diff --git a/diesel/src/type_impls/decimal.rs b/diesel/src/type_impls/decimal.rs index 17170b368f04..3bce367e4098 100644 --- a/diesel/src/type_impls/decimal.rs +++ b/diesel/src/type_impls/decimal.rs @@ -8,7 +8,7 @@ mod bigdecimal { use crate::expression::AsExpression; use crate::sql_types::Numeric; - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "Numeric"] struct BigDecimalProxy(BigDecimal); diff --git a/diesel/src/type_impls/floats.rs b/diesel/src/type_impls/floats.rs index 01d2d139e793..3fd60e680502 100644 --- a/diesel/src/type_impls/floats.rs +++ b/diesel/src/type_impls/floats.rs @@ -11,8 +11,7 @@ impl FromSql for f32 where DB: Backend + for<'a> BinaryRawValue<'a>, { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: crate::backend::RawValue) -> deserialize::Result { let mut bytes = DB::as_bytes(value); debug_assert!( bytes.len() <= 4, @@ -37,8 +36,7 @@ impl FromSql for f64 where DB: Backend + for<'a> BinaryRawValue<'a>, { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: crate::backend::RawValue) -> deserialize::Result { let mut bytes = DB::as_bytes(value); debug_assert!( bytes.len() <= 8, diff --git a/diesel/src/type_impls/integers.rs b/diesel/src/type_impls/integers.rs index 108e8ddcb9c0..bf08ce1878cb 100644 --- a/diesel/src/type_impls/integers.rs +++ b/diesel/src/type_impls/integers.rs @@ -11,8 +11,7 @@ impl FromSql for i16 where DB: Backend + for<'a> BinaryRawValue<'a>, { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: crate::backend::RawValue) -> deserialize::Result { let mut bytes = DB::as_bytes(value); debug_assert!( bytes.len() <= 2, @@ -43,8 +42,7 @@ impl FromSql for i32 where DB: Backend + for<'a> BinaryRawValue<'a>, { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: crate::backend::RawValue) -> deserialize::Result { let mut bytes = DB::as_bytes(value); debug_assert!( bytes.len() <= 4, @@ -74,8 +72,7 @@ impl FromSql for i64 where DB: Backend + for<'a> BinaryRawValue<'a>, { - fn from_sql(value: Option>) -> deserialize::Result { - let value = not_none!(value); + fn from_sql(value: crate::backend::RawValue) -> deserialize::Result { let mut bytes = DB::as_bytes(value); debug_assert!( bytes.len() <= 8, diff --git a/diesel/src/type_impls/json.rs b/diesel/src/type_impls/json.rs index c5ae63bf81cc..45d127c5be3c 100644 --- a/diesel/src/type_impls/json.rs +++ b/diesel/src/type_impls/json.rs @@ -6,7 +6,7 @@ use crate::sql_types::Json; #[cfg(feature = "postgres")] use crate::sql_types::Jsonb; -#[derive(FromSqlRow, AsExpression)] +#[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "Json"] #[cfg_attr(feature = "postgres", sql_type = "Jsonb")] diff --git a/diesel/src/type_impls/mod.rs b/diesel/src/type_impls/mod.rs index ac604b48260c..6b8421951bf3 100644 --- a/diesel/src/type_impls/mod.rs +++ b/diesel/src/type_impls/mod.rs @@ -6,4 +6,4 @@ mod integers; mod json; pub mod option; mod primitives; -mod tuples; +pub(crate) mod tuples; diff --git a/diesel/src/type_impls/option.rs b/diesel/src/type_impls/option.rs index 10414eefe5d2..1929df7b807d 100644 --- a/diesel/src/type_impls/option.rs +++ b/diesel/src/type_impls/option.rs @@ -1,33 +1,26 @@ use std::io::Write; use crate::backend::{self, Backend}; -use crate::deserialize::{self, FromSql, FromSqlRow, Queryable, QueryableByName}; +use crate::deserialize::{self, FromSql, Queryable, QueryableByName}; use crate::expression::bound::Bound; use crate::expression::*; use crate::query_builder::QueryId; -use crate::result::UnexpectedNullError; -use crate::row::NamedRow; use crate::serialize::{self, IsNull, Output, ToSql}; -use crate::sql_types::{HasSqlType, NotNull, Nullable}; +use crate::sql_types::{is_nullable, HasSqlType, Nullable, SingleValue, SqlType}; impl HasSqlType> for DB where DB: Backend + HasSqlType, - T: NotNull, + T: SqlType, { fn metadata(lookup: &DB::MetadataLookup) -> DB::TypeMetadata { >::metadata(lookup) } - - #[cfg(feature = "mysql")] - fn mysql_row_metadata(out: &mut Vec, lookup: &DB::MetadataLookup) { - >::mysql_row_metadata(out, lookup) - } } impl QueryId for Nullable where - T: QueryId + NotNull, + T: QueryId + SqlType, { type QueryId = T::QueryId; @@ -38,64 +31,16 @@ impl FromSql, DB> for Option where T: FromSql, DB: Backend, - ST: NotNull, -{ - fn from_sql(bytes: Option>) -> deserialize::Result { - match bytes { - Some(_) => T::from_sql(bytes).map(Some), - None => Ok(None), - } - } -} - -impl Queryable, DB> for Option -where - T: Queryable, - DB: Backend, - Option: FromSqlRow, DB>, - ST: NotNull, -{ - type Row = Option; - - fn build(row: Self::Row) -> Self { - row.map(T::build) - } -} - -impl QueryableByName for Option -where - T: QueryableByName, - DB: Backend, + ST: SqlType, { - fn build>(row: &R) -> deserialize::Result { - match T::build(row) { - Ok(v) => Ok(Some(v)), - Err(e) => { - if e.is::() { - Ok(None) - } else { - Err(e) - } - } - } + fn from_sql(bytes: backend::RawValue) -> deserialize::Result { + T::from_sql(bytes).map(Some) } -} -impl FromSqlRow, DB> for Option -where - T: FromSqlRow, - DB: Backend, - ST: NotNull, -{ - const FIELDS_NEEDED: usize = T::FIELDS_NEEDED; - - fn build_from_row>(row: &mut R) -> deserialize::Result { - let fields_needed = Self::FIELDS_NEEDED; - if row.next_is_null(fields_needed) { - row.advance(fields_needed); - Ok(None) - } else { - T::build_from_row(row).map(Some) + fn from_nullable_sql(bytes: Option>) -> deserialize::Result { + match bytes { + Some(bytes) => T::from_sql(bytes).map(Some), + None => Ok(None), } } } @@ -104,7 +49,7 @@ impl ToSql, DB> for Option where T: ToSql, DB: Backend, - ST: NotNull, + ST: SqlType, { fn to_sql(&self, out: &mut Output) -> serialize::Result { if let Some(ref value) = *self { @@ -117,7 +62,8 @@ where impl AsExpression> for Option where - ST: NotNull, + ST: SqlType, + Nullable: TypedExpressionType, { type Expression = Bound, Self>; @@ -128,7 +74,8 @@ where impl<'a, T, ST> AsExpression> for &'a Option where - ST: NotNull, + ST: SqlType, + Nullable: TypedExpressionType, { type Expression = Bound, Self>; @@ -137,6 +84,33 @@ where } } +impl QueryableByName for Option +where + DB: Backend, + T: QueryableByName, +{ + fn build<'a>(row: &impl crate::row::NamedRow<'a, DB>) -> deserialize::Result { + match T::build(row) { + Ok(v) => Ok(Some(v)), + Err(e) if e.is::() => Ok(None), + Err(e) => Err(e), + } + } +} + +impl Queryable for Option +where + ST: SingleValue, + DB: Backend, + Self: FromSql, +{ + type Row = Self; + + fn build(row: Self::Row) -> Self { + row + } +} + #[cfg(all(test, feature = "postgres"))] use crate::pg::Pg; #[cfg(all(test, feature = "postgres"))] diff --git a/diesel/src/type_impls/primitives.rs b/diesel/src/type_impls/primitives.rs index 0d889f0a4c95..9da2d45730bc 100644 --- a/diesel/src/type_impls/primitives.rs +++ b/diesel/src/type_impls/primitives.rs @@ -2,42 +2,43 @@ use std::error::Error; use std::io::Write; use crate::backend::{self, Backend, BinaryRawValue}; -use crate::deserialize::{self, FromSql, FromSqlRow, Queryable}; +use crate::deserialize::{self, FromSql, Queryable}; use crate::serialize::{self, IsNull, Output, ToSql}; use crate::sql_types::{ - self, BigInt, Binary, Bool, Double, Float, Integer, NotNull, SmallInt, Text, + self, BigInt, Binary, Bool, Double, Float, Integer, SingleValue, SmallInt, Text, }; #[allow(dead_code)] mod foreign_impls { use super::*; + use crate::deserialize::FromSqlRow; - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "Bool"] struct BoolProxy(bool); - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[cfg_attr(feature = "mysql", sql_type = "crate::sql_types::TinyInt")] struct I8Proxy(i8); - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "SmallInt"] struct I16Proxy(i16); - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "Integer"] struct I32Proxy(i32); - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "BigInt"] struct I64Proxy(i64); - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[cfg_attr( feature = "mysql", @@ -45,33 +46,33 @@ mod foreign_impls { )] struct U8Proxy(u8); - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[cfg_attr(feature = "mysql", sql_type = "crate::sql_types::Unsigned")] struct U16Proxy(u16); - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[cfg_attr(feature = "mysql", sql_type = "crate::sql_types::Unsigned")] #[cfg_attr(feature = "postgres", sql_type = "crate::sql_types::Oid")] struct U32Proxy(u32); - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[cfg_attr(feature = "mysql", sql_type = "crate::sql_types::Unsigned")] struct U64Proxy(u64); - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "Float"] struct F32Proxy(f32); - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "Double"] struct F64Proxy(f64); - #[derive(FromSqlRow, AsExpression)] + #[derive(AsExpression, FromSqlRow)] #[diesel(foreign_derive)] #[sql_type = "Text"] #[cfg_attr(feature = "sqlite", sql_type = "crate::sql_types::Date")] @@ -102,14 +103,12 @@ mod foreign_impls { struct BinarySliceProxy([u8]); } -impl NotNull for () {} - impl FromSql for String where DB: Backend, *const str: FromSql, { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: backend::RawValue) -> deserialize::Result { let str_ptr = <*const str as FromSql>::from_sql(bytes)?; // We know that the pointer impl will never return null let string = unsafe { &*str_ptr }; @@ -127,9 +126,8 @@ impl FromSql for *const str where DB: Backend + for<'a> BinaryRawValue<'a>, { - fn from_sql(value: Option>) -> deserialize::Result { + fn from_sql(value: crate::backend::RawValue) -> deserialize::Result { use std::str; - let value = not_none!(value); let string = str::from_utf8(DB::as_bytes(value))?; Ok(string as *const _) } @@ -140,9 +138,8 @@ impl FromSql for *const str where DB: Backend + for<'a> BinaryRawValue<'a>, { - default fn from_sql(value: Option>) -> deserialize::Result { + default fn from_sql(value: crate::backend::RawValue) -> deserialize::Result { use std::str; - let value = not_none!(value); let string = str::from_utf8(DB::as_bytes(value))?; Ok(string as *const _) } @@ -171,7 +168,7 @@ where DB: Backend, *const [u8]: FromSql, { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: backend::RawValue) -> deserialize::Result { let slice_ptr = <*const [u8] as FromSql>::from_sql(bytes)?; // We know that the pointer impl will never return null let bytes = unsafe { &*slice_ptr }; @@ -188,8 +185,8 @@ impl FromSql for *const [u8] where DB: Backend + for<'a> BinaryRawValue<'a>, { - fn from_sql(bytes: Option>) -> deserialize::Result { - Ok(DB::as_bytes(not_none!(bytes)) as *const _) + fn from_sql(bytes: backend::RawValue) -> deserialize::Result { + Ok(DB::as_bytes(bytes) as *const _) } } @@ -230,27 +227,17 @@ where DB: Backend, T::Owned: FromSql, { - fn from_sql(bytes: Option>) -> deserialize::Result { + fn from_sql(bytes: backend::RawValue) -> deserialize::Result { T::Owned::from_sql(bytes).map(Cow::Owned) } } -impl<'a, T: ?Sized, ST, DB> FromSqlRow for Cow<'a, T> -where - T: 'a + ToOwned, - DB: Backend, - Cow<'a, T>: FromSql, -{ - fn build_from_row>(row: &mut R) -> deserialize::Result { - FromSql::::from_sql(row.take()) - } -} - impl<'a, T: ?Sized, ST, DB> Queryable for Cow<'a, T> where T: 'a + ToOwned, + ST: SingleValue, DB: Backend, - Self: FromSqlRow, + Self: FromSql, { type Row = Self; @@ -260,12 +247,14 @@ where } use crate::expression::bound::Bound; -use crate::expression::{AsExpression, Expression}; +use crate::expression::{AsExpression, Expression, TypedExpressionType}; +use sql_types::SqlType; impl<'a, T: ?Sized, ST> AsExpression for Cow<'a, T> where T: 'a + ToOwned, Bound>: Expression, + ST: SqlType + TypedExpressionType, { type Expression = Bound; @@ -278,6 +267,7 @@ impl<'a, 'b, T: ?Sized, ST> AsExpression for &'b Cow<'a, T> where T: 'a + ToOwned, Bound: Expression, + ST: SqlType + TypedExpressionType, { type Expression = Bound; diff --git a/diesel/src/type_impls/tuples.rs b/diesel/src/type_impls/tuples.rs index a544c10dca74..e97ab157b85c 100644 --- a/diesel/src/type_impls/tuples.rs +++ b/diesel/src/type_impls/tuples.rs @@ -1,19 +1,29 @@ -use std::error::Error; - use crate::associations::BelongsTo; use crate::backend::Backend; -use crate::deserialize::{self, FromSqlRow, Queryable, QueryableByName}; +use crate::deserialize::{self, FromSqlRow, FromStaticSqlRow, Queryable, StaticallySizedRow}; use crate::expression::{ - AppearsOnTable, AsExpression, AsExpressionList, Expression, SelectableExpression, ValidGrouping, + AppearsOnTable, AsExpression, AsExpressionList, Expression, QueryMetadata, + SelectableExpression, TypedExpressionType, ValidGrouping, }; use crate::insertable::{CanInsertInSingleQuery, InsertValues, Insertable}; use crate::query_builder::*; use crate::query_source::*; use crate::result::QueryResult; use crate::row::*; -use crate::sql_types::{HasSqlType, NotNull}; +use crate::sql_types::{HasSqlType, IntoNullable, Nullable, OneIsNullable, SqlType}; use crate::util::TupleAppend; +pub trait TupleSize { + const SIZE: usize; +} + +impl TupleSize for T +where + T: crate::sql_types::SingleValue, +{ + const SIZE: usize = 1; +} + macro_rules! tuple_impls { ($( $Tuple:tt { @@ -28,50 +38,27 @@ macro_rules! tuple_impls { fn metadata(_: &__DB::MetadataLookup) -> __DB::TypeMetadata { unreachable!("Tuples should never implement `ToSql` directly"); } - - #[cfg(feature = "mysql")] - fn mysql_row_metadata(out: &mut Vec<__DB::TypeMetadata>, lookup: &__DB::MetadataLookup) { - $(<__DB as HasSqlType<$T>>::mysql_row_metadata(out, lookup);)+ - } } - impl<$($T),+> NotNull for ($($T,)+) { - } + impl_from_sql_row!(($($T,)+), ($($ST,)+)); - impl<$($T),+, $($ST),+, __DB> FromSqlRow<($($ST,)+), __DB> for ($($T,)+) where - __DB: Backend, - $($T: FromSqlRow<$ST, __DB>),+, - { - const FIELDS_NEEDED: usize = $($T::FIELDS_NEEDED +)+ 0; - fn build_from_row>(row: &mut RowT) -> Result> { - Ok(($($T::build_from_row(row)?,)+)) - } - } - - impl<$($T),+, $($ST),+, __DB> Queryable<($($ST,)+), __DB> for ($($T,)+) where - __DB: Backend, - $($T: Queryable<$ST, __DB>),+, + impl<$($T: Expression),+> Expression for ($($T,)+) + where ($($T::SqlType, )*): TypedExpressionType { - type Row = ($($T::Row,)+); - - fn build(row: Self::Row) -> Self { - ($($T::build(row.$idx),)+) - } + type SqlType = ($(<$T as Expression>::SqlType,)+); } - impl<$($T,)+ __DB> QueryableByName<__DB> for ($($T,)+) - where - __DB: Backend, - $($T: QueryableByName<__DB>,)+ + impl<$($T: TypedExpressionType,)*> TypedExpressionType for ($($T,)*) {} + impl<$($T: SqlType + TypedExpressionType,)*> TypedExpressionType for Nullable<($($T,)*)> + where ($($T,)*): SqlType { - fn build>(row: &RowT) -> deserialize::Result { - Ok(($($T::build(row)?,)+)) - } } - impl<$($T: Expression),+> Expression for ($($T,)+) { - type SqlType = ($(<$T as Expression>::SqlType,)+); + impl<$($T: SqlType,)*> IntoNullable for ($($T,)*) + where Self: SqlType, + { + type Nullable = Nullable<($($T,)*)>; } impl<$($T: QueryFragment<__DB>),+, __DB: Backend> QueryFragment<__DB> for ($($T,)+) { @@ -233,6 +220,7 @@ macro_rules! tuple_impls { impl<$($T,)+ ST> AsExpressionList for ($($T,)+) where $($T: AsExpression,)+ + ST: SqlType + TypedExpressionType, { type Expression = ($($T::Expression,)+); @@ -240,8 +228,189 @@ macro_rules! tuple_impls { ($(self.$idx.as_expression(),)+) } } + + impl_sql_type!($($T,)*); + + impl<$($T,)* __DB, $($ST,)*> Queryable<($($ST,)*), __DB> for ($($T,)*) + where __DB: Backend, + Self: FromStaticSqlRow<($($ST,)*), __DB>, + { + type Row = Self; + + fn build(row: Self::Row) -> Self { + row + } + } + + impl<__T, $($ST,)* __DB> FromStaticSqlRow, __DB> for Option<__T> where + __DB: Backend, + ($($ST,)*): SqlType, + __T: FromSqlRow<($($ST,)*), __DB>, + { + + #[allow(non_snake_case, unused_variables, unused_mut)] + fn build_from_row<'a>(row: &impl Row<'a, __DB>) + -> deserialize::Result + { + match <__T as FromSqlRow<($($ST,)*), __DB>>::build_from_row(row) { + Ok(v) => Ok(Some(v)), + Err(e) if e.is::() => Ok(None), + Err(e) => Err(e) + } + } + } + + impl<__T, __DB, $($ST,)*> Queryable, __DB> for Option<__T> + where __DB: Backend, + Self: FromStaticSqlRow, __DB>, + ($($ST,)*): SqlType, + { + type Row = Self; + + fn build(row: Self::Row) -> Self { + row + } + } + + impl<$($T,)*> TupleSize for ($($T,)*) + where $($T: TupleSize,)* + { + const SIZE: usize = $($T::SIZE +)* 0; + } + + impl<$($T,)*> TupleSize for Nullable<($($T,)*)> + where $($T: TupleSize,)* + ($($T,)*): SqlType, + { + const SIZE: usize = $($T::SIZE +)* 0; + } + + impl<$($T,)* __DB> QueryMetadata<($($T,)*)> for __DB + where __DB: Backend, + $(__DB: QueryMetadata<$T>,)* + { + fn row_metadata(lookup: &Self::MetadataLookup, row: &mut Vec>) { + $( + <__DB as QueryMetadata<$T>>::row_metadata(lookup, row); + )* + } + } + + impl<$($T,)* __DB> QueryMetadata> for __DB + where __DB: Backend, + $(__DB: QueryMetadata<$T>,)* + { + fn row_metadata(lookup: &Self::MetadataLookup, row: &mut Vec>) { + $( + <__DB as QueryMetadata<$T>>::row_metadata(lookup, row); + )* + } + } + + impl<$($T,)* __DB> deserialize::QueryableByName< __DB> for ($($T,)*) + where __DB: Backend, + $($T: deserialize::QueryableByName<__DB>,)* + { + fn build<'a>(row: &impl NamedRow<'a, __DB>) -> deserialize::Result { + Ok(($( + <$T as deserialize::QueryableByName<__DB>>::build(row)?, + )*)) + } + } + )+ } } +macro_rules! impl_from_sql_row { + (($T1: ident,), ($ST1: ident,)) => { + impl<$T1, $ST1, __DB> crate::deserialize::FromStaticSqlRow<($ST1,), __DB> for ($T1,) where + __DB: Backend, + $T1: FromSqlRow<$ST1, __DB>, + { + + #[allow(non_snake_case, unused_variables, unused_mut)] + fn build_from_row<'a>(row: &impl Row<'a, __DB>) + -> deserialize::Result + { + Ok(($T1::build_from_row(row)?,)) + } + } + }; + (($T1: ident, $($T: ident,)*), ($ST1: ident, $($ST: ident,)*)) => { + impl<$T1, $($T,)* $($ST,)* __DB> FromSqlRow<($($ST,)* crate::sql_types::Untyped), __DB> for ($($T,)* $T1) + where __DB: Backend, + $T1: FromSqlRow, + $( + $T: FromSqlRow<$ST, __DB> + StaticallySizedRow<$ST, __DB>, + )* + { + #[allow(non_snake_case, unused_variables, unused_mut)] + fn build_from_row<'a>(full_row: &impl Row<'a, __DB>) + -> deserialize::Result + { + let field_count = full_row.field_count(); + + let mut static_field_count = 0; + $( + let row = full_row.partial_row(static_field_count..static_field_count + $T::FIELD_COUNT); + static_field_count += $T::FIELD_COUNT; + let $T = $T::build_from_row(&row)?; + )* + + let row = full_row.partial_row(static_field_count..field_count); + + Ok(($($T,)* $T1::build_from_row(&row)?,)) + } + } + + impl<$T1, $ST1, $($T,)* $($ST,)* __DB> FromStaticSqlRow<($($ST,)* $ST1,), __DB> for ($($T,)* $T1,) where + __DB: Backend, + $T1: FromSqlRow<$ST1, __DB>, + $( + $T: FromSqlRow<$ST, __DB> + StaticallySizedRow<$ST, __DB>, + )* + + { + + #[allow(non_snake_case, unused_variables, unused_mut)] + fn build_from_row<'a>(full_row: &impl Row<'a, __DB>) + -> deserialize::Result + { + let field_count = full_row.field_count(); + + let mut static_field_count = 0; + $( + let row = full_row.partial_row(static_field_count..static_field_count + $T::FIELD_COUNT); + static_field_count += $T::FIELD_COUNT; + let $T = $T::build_from_row(&row)?; + )* + + let row = full_row.partial_row(static_field_count..field_count); + + Ok(($($T,)* $T1::build_from_row(&row)?,)) + } + } + } +} + +macro_rules! impl_sql_type { + ($T1: ident, $($T: ident,)+) => { + impl<$T1, $($T,)+> SqlType for ($T1, $($T,)*) + where $T1: SqlType, + ($($T,)*): SqlType, + $T1::IsNull: OneIsNullable<<($($T,)*) as SqlType>::IsNull>, + { + type IsNull = <$T1::IsNull as OneIsNullable<<($($T,)*) as SqlType>::IsNull>>::Out; + } + }; + ($T1: ident,) => { + impl<$T1> SqlType for ($T1,) + where $T1: SqlType, + { + type IsNull = $T1::IsNull; + } + } +} + __diesel_for_each_tuple!(tuple_impls); diff --git a/diesel_cli/src/infer_schema_internals/data_structures.rs b/diesel_cli/src/infer_schema_internals/data_structures.rs index 1e4fdd76a2a0..c503637b4272 100644 --- a/diesel_cli/src/infer_schema_internals/data_structures.rs +++ b/diesel_cli/src/infer_schema_internals/data_structures.rs @@ -1,6 +1,6 @@ #[cfg(feature = "uses_information_schema")] use diesel::backend::Backend; -use diesel::deserialize::{FromSqlRow, Queryable}; +use diesel::deserialize::{FromStaticSqlRow, Queryable}; #[cfg(feature = "sqlite")] use diesel::sqlite::Sqlite; @@ -76,7 +76,7 @@ impl ColumnInformation { impl Queryable for ColumnInformation where DB: Backend + UsesInformationSchema, - (String, String, String): FromSqlRow, + (String, String, String): FromStaticSqlRow, { type Row = (String, String, String); @@ -88,7 +88,7 @@ where #[cfg(feature = "sqlite")] impl Queryable for ColumnInformation where - (i32, String, String, bool, Option, bool): FromSqlRow, + (i32, String, String, bool, Option, bool): FromStaticSqlRow, { type Row = (i32, String, String, bool, Option, bool); diff --git a/diesel_cli/src/infer_schema_internals/information_schema.rs b/diesel_cli/src/infer_schema_internals/information_schema.rs index 91bee4395965..df63e1a055a0 100644 --- a/diesel_cli/src/infer_schema_internals/information_schema.rs +++ b/diesel_cli/src/infer_schema_internals/information_schema.rs @@ -2,9 +2,9 @@ use std::borrow::Cow; use std::error::Error; use diesel::backend::Backend; -use diesel::deserialize::FromSql; +use diesel::deserialize::{FromSql, FromSqlRow}; use diesel::dsl::*; -use diesel::expression::{is_aggregate, ValidGrouping}; +use diesel::expression::{is_aggregate, QueryMetadata, ValidGrouping}; #[cfg(feature = "mysql")] use diesel::mysql::Mysql; #[cfg(feature = "postgres")] @@ -78,7 +78,8 @@ mod information_schema { table_schema -> VarChar, table_name -> VarChar, column_name -> VarChar, - is_nullable -> VarChar, + #[sql_name = "is_nullable"] + __is_nullable -> VarChar, ordinal_position -> BigInt, udt_name -> VarChar, column_type -> VarChar, @@ -126,6 +127,14 @@ pub fn get_table_data<'a, Conn>( where Conn: Connection, Conn::Backend: UsesInformationSchema, + ColumnInformation: FromSqlRow< + SqlTypeOf<( + columns::column_name, + ::TypeColumn, + columns::__is_nullable, + )>, + Conn::Backend, + >, String: FromSql, Order< Filter< @@ -135,7 +144,7 @@ where ( columns::column_name, ::TypeColumn, - columns::is_nullable, + columns::__is_nullable, ), >, Eq, @@ -144,6 +153,7 @@ where >, columns::ordinal_position, >: QueryFragment, + Conn::Backend: QueryMetadata<(sql_types::Text, sql_types::Text, sql_types::Text)>, { use self::information_schema::columns::dsl::*; @@ -154,7 +164,7 @@ where let type_column = Conn::Backend::type_column(); columns - .select((column_name, type_column, is_nullable)) + .select((column_name, type_column, __is_nullable)) .filter(table_name.eq(&table.sql_name)) .filter(table_schema.eq(schema_name)) .order(ordinal_position) @@ -185,6 +195,7 @@ where >, key_column_usage::ordinal_position, >: QueryFragment, + Conn::Backend: QueryMetadata, { use self::information_schema::key_column_usage::dsl::*; use self::information_schema::table_constraints::constraint_type; @@ -225,6 +236,7 @@ where >, Like, >: QueryFragment, + Conn::Backend: QueryMetadata, { use self::information_schema::tables::dsl::*; diff --git a/diesel_cli/src/infer_schema_internals/sqlite.rs b/diesel_cli/src/infer_schema_internals/sqlite.rs index 765a903f38d4..80fef012598b 100644 --- a/diesel_cli/src/infer_schema_internals/sqlite.rs +++ b/diesel_cli/src/infer_schema_internals/sqlite.rs @@ -52,7 +52,7 @@ pub fn load_table_names( .select(name) .filter(name.not_like("\\_\\_%").escape('\\')) .filter(name.not_like("sqlite%")) - .filter(sql("type='table'")) + .filter(sql::("type='table'")) .order(name) .load::(connection)? .into_iter() diff --git a/diesel_cli/src/infer_schema_internals/table_data.rs b/diesel_cli/src/infer_schema_internals/table_data.rs index f6c6bad05bc4..2619783ca6d2 100644 --- a/diesel_cli/src/infer_schema_internals/table_data.rs +++ b/diesel_cli/src/infer_schema_internals/table_data.rs @@ -1,5 +1,5 @@ use diesel::backend::Backend; -use diesel::deserialize::{FromSqlRow, Queryable}; +use diesel::deserialize::{FromStaticSqlRow, Queryable}; use std::fmt; use std::str::FromStr; @@ -55,8 +55,8 @@ impl TableName { impl Queryable for TableName where + (String, String): FromStaticSqlRow, DB: Backend, - (String, String): FromSqlRow, { type Row = (String, String); diff --git a/diesel_compile_tests/tests/compile-fail/cannot_mix_aggregate_and_non_aggregate_selects.rs b/diesel_compile_tests/tests/compile-fail/cannot_mix_aggregate_and_non_aggregate_selects.rs index 360114dacb74..671fd4b90478 100644 --- a/diesel_compile_tests/tests/compile-fail/cannot_mix_aggregate_and_non_aggregate_selects.rs +++ b/diesel_compile_tests/tests/compile-fail/cannot_mix_aggregate_and_non_aggregate_selects.rs @@ -2,7 +2,7 @@ extern crate diesel; use diesel::*; -use diesel::dsl::count; +use diesel::dsl::count_star; table! { users { @@ -13,6 +13,6 @@ table! { fn main() { use self::users::dsl::*; - let source = users.select((id, count(users.star()))); + let source = users.select((id, count_star())); //~^ ERROR MixedAggregates } diff --git a/diesel_compile_tests/tests/compile-fail/filter_requires_bool_nonaggregate_expression.rs b/diesel_compile_tests/tests/compile-fail/filter_requires_bool_nonaggregate_expression.rs index b262cc08ea0a..83e7e8046293 100644 --- a/diesel_compile_tests/tests/compile-fail/filter_requires_bool_nonaggregate_expression.rs +++ b/diesel_compile_tests/tests/compile-fail/filter_requires_bool_nonaggregate_expression.rs @@ -14,7 +14,7 @@ fn main() { use diesel::dsl::sum; let _ = users::table.filter(users::name); - //~^ ERROR type mismatch resolving `::SqlType == diesel::sql_types::Bool` + //~^ ERROR the trait bound `diesel::sql_types::Text: diesel::sql_types::BoolOrNullableBool` is not satisfied let _ = users::table.filter(sum(users::id).eq(1)); //~^ ERROR MixedAggregates } diff --git a/diesel_compile_tests/tests/compile-fail/insert_statement_does_not_support_returning_methods_on_sqlite.rs b/diesel_compile_tests/tests/compile-fail/insert_statement_does_not_support_returning_methods_on_sqlite.rs index 3505a61c6ec3..c667bf7f1be7 100644 --- a/diesel_compile_tests/tests/compile-fail/insert_statement_does_not_support_returning_methods_on_sqlite.rs +++ b/diesel_compile_tests/tests/compile-fail/insert_statement_does_not_support_returning_methods_on_sqlite.rs @@ -1,10 +1,11 @@ #[macro_use] extern crate diesel; -use diesel::*; -use diesel::sqlite::SqliteConnection; use diesel::backend::Backend; use diesel::sql_types::{Integer, VarChar}; +use diesel::sqlite::SqliteConnection; +use diesel::deserialize::Queryable; +use diesel::*; table! { users { @@ -13,25 +14,12 @@ table! { } } +#[derive(Queryable)] pub struct User { id: i32, name: String, } -use diesel::deserialize::FromSqlRow; - -impl Queryable<(Integer, VarChar), DB> for User where - (i32, String): FromSqlRow<(Integer, VarChar), DB>, -{ - type Row = (i32, String); - - fn build(row: Self::Row) -> Self { - User { - id: row.0, - name: row.1, - } - } -} #[derive(Insertable)] #[table_name = "users"] diff --git a/diesel_compile_tests/tests/compile-fail/join_with_explicit_on_requires_valid_boolean_expression.rs b/diesel_compile_tests/tests/compile-fail/join_with_explicit_on_requires_valid_boolean_expression.rs index 509cfc64b5b9..949b33ada6f1 100644 --- a/diesel_compile_tests/tests/compile-fail/join_with_explicit_on_requires_valid_boolean_expression.rs +++ b/diesel_compile_tests/tests/compile-fail/join_with_explicit_on_requires_valid_boolean_expression.rs @@ -32,5 +32,5 @@ fn main() { //~^ ERROR E0271 // Invalid, type is not boolean let _ = users::table.inner_join(posts::table.on(users::id)); - //~^ ERROR E0271 + //~^ ERROR the trait bound `diesel::sql_types::Integer: diesel::sql_types::BoolOrNullableBool` is not satisfied [E0277] } diff --git a/diesel_compile_tests/tests/compile-fail/right_side_of_left_join_requires_nullable.rs b/diesel_compile_tests/tests/compile-fail/right_side_of_left_join_requires_nullable.rs index 7dc540cf90a9..0fd459b41718 100644 --- a/diesel_compile_tests/tests/compile-fail/right_side_of_left_join_requires_nullable.rs +++ b/diesel_compile_tests/tests/compile-fail/right_side_of_left_join_requires_nullable.rs @@ -52,7 +52,6 @@ fn direct_joins() { // Invalid, Nullable is selectable, but lower expects not-null let _ = join.select(lower(posts::title.nullable())); //~^ ERROR E0271 - //~| ERROR E0271 } fn nested_outer_joins_left_associative() { @@ -74,7 +73,6 @@ fn nested_outer_joins_left_associative() { // Invalid, Nullable<title> is selectable, but lower expects not-null let _ = join.select(lower(posts::title.nullable())); //~^ ERROR E0271 - //~| ERROR E0271 } fn nested_mixed_joins_left_associative() { @@ -95,7 +93,6 @@ fn nested_mixed_joins_left_associative() { // Invalid, Nullable<title> is selectable, but lower expects not-null let _ = join.select(lower(posts::title.nullable())); //~^ ERROR E0271 - //~| ERROR E0271 } fn nested_outer_joins_right_associative() { @@ -116,7 +113,6 @@ fn nested_outer_joins_right_associative() { // Invalid, Nullable<title> is selectable, but lower expects not-null let _ = join.select(lower(posts::title.nullable())); //~^ ERROR E0271 - //~| ERROR E0271 } fn nested_mixed_joins_right_associative() { @@ -137,5 +133,4 @@ fn nested_mixed_joins_right_associative() { // Invalid, Nullable<title> is selectable, but lower expects not-null let _ = join.select(lower(posts::title.nullable())); //~^ ERROR E0271 - //~| ERROR E0271 } diff --git a/diesel_compile_tests/tests/ui/queryable_by_name_requires_table_name_or_sql_type_annotation.stderr b/diesel_compile_tests/tests/ui/queryable_by_name_requires_table_name_or_sql_type_annotation.stderr index f800624d2f43..854455ef05b0 100644 --- a/diesel_compile_tests/tests/ui/queryable_by_name_requires_table_name_or_sql_type_annotation.stderr +++ b/diesel_compile_tests/tests/ui/queryable_by_name_requires_table_name_or_sql_type_annotation.stderr @@ -20,6 +20,12 @@ error: All fields of tuple structs must be annotated with `#[column_name]` 10 | struct Bar(i32, String); | ^^^ +error: All fields of tuple structs must be annotated with `#[column_name]` + --> $DIR/queryable_by_name_requires_table_name_or_sql_type_annotation.rs:10:17 + | +10 | struct Bar(i32, String); + | ^^^^^^ + error: Cannot determine the SQL type of field --> $DIR/queryable_by_name_requires_table_name_or_sql_type_annotation.rs:10:12 | @@ -28,12 +34,6 @@ error: Cannot determine the SQL type of field | = help: Your struct must either be annotated with `#[table_name = "foo"]` or have all of its fields annotated with `#[sql_type = "Integer"]` -error: All fields of tuple structs must be annotated with `#[column_name]` - --> $DIR/queryable_by_name_requires_table_name_or_sql_type_annotation.rs:10:17 - | -10 | struct Bar(i32, String); - | ^^^^^^ - error: Cannot determine the SQL type of field --> $DIR/queryable_by_name_requires_table_name_or_sql_type_annotation.rs:10:17 | diff --git a/diesel_compile_tests/tests/ui/queryable_type_missmatch.rs b/diesel_compile_tests/tests/ui/queryable_type_missmatch.rs new file mode 100644 index 000000000000..c767135a065b --- /dev/null +++ b/diesel_compile_tests/tests/ui/queryable_type_missmatch.rs @@ -0,0 +1,71 @@ +extern crate diesel; + +use diesel::prelude::*; + +table! { + users { + id -> Integer, + name -> Text, + bio -> Nullable<Text>, + } +} + +#[derive(Queryable)] +struct User { + id: i32, + name: String, + bio: Option<String> +} + +#[derive(Queryable)] +struct UserWithToFewFields { + id: i32, + name: String, +} + +#[derive(Queryable)] +struct UserWithToManyFields { + id: i32, + name: String, + bio: Option<String>, + age: i32, +} + +#[derive(Queryable)] +struct UserWrongOrder { + name: String, + id: i32, + bio: Option<String> +} + +#[derive(Queryable)] +struct UserTypeMissmatch { + id: i32, + name: i32, + bio: Option<String> +} + +#[derive(Queryable)] +struct UserNullableTypeMissmatch { + id: i32, + name: String, + bio: Option<String> +} + +fn test(conn: &PgConnection) { + // check that this works fine + let _ = users::table.load::<User>(conn); + + let _ = users::table.load::<UserWithToFewFields>(conn); + + let _ = users::table.load::<UserWithToManyFields>(conn); + + let _ = users::table.load::<UserWrongOrder>(conn); + + let _ = users::table.load::<UserTypeMissmatch>(conn); + + let _ = users::table.load::<UserNullableTypeMissmatch>(conn); +} + + +fn main() {} diff --git a/diesel_compile_tests/tests/ui/queryable_type_missmatch.stderr b/diesel_compile_tests/tests/ui/queryable_type_missmatch.stderr new file mode 100644 index 000000000000..adb85ae8d7fc --- /dev/null +++ b/diesel_compile_tests/tests/ui/queryable_type_missmatch.stderr @@ -0,0 +1,49 @@ +error[E0277]: the trait bound `UserWithToFewFields: diesel::Queryable<(diesel::sql_types::Integer, diesel::sql_types::Text, diesel::sql_types::Nullable<diesel::sql_types::Text>), _>` is not satisfied + --> $DIR/queryable_type_missmatch.rs:59:26 + | +59 | let _ = users::table.load::<UserWithToFewFields>(conn); + | ^^^^ the trait `diesel::Queryable<(diesel::sql_types::Integer, diesel::sql_types::Text, diesel::sql_types::Nullable<diesel::sql_types::Text>), _>` is not implemented for `UserWithToFewFields` + | + = help: the following implementations were found: + <UserWithToFewFields as diesel::Queryable<(__ST0, __ST1), __DB>> + = note: required because of the requirements on the impl of `diesel::deserialize::FromSqlRow<(diesel::sql_types::Integer, diesel::sql_types::Text, diesel::sql_types::Nullable<diesel::sql_types::Text>), _>` for `UserWithToFewFields` + = note: required because of the requirements on the impl of `diesel::query_dsl::LoadQuery<_, UserWithToFewFields>` for `users::table` + +error[E0277]: the trait bound `UserWithToManyFields: diesel::Queryable<(diesel::sql_types::Integer, diesel::sql_types::Text, diesel::sql_types::Nullable<diesel::sql_types::Text>), _>` is not satisfied + --> $DIR/queryable_type_missmatch.rs:61:26 + | +61 | let _ = users::table.load::<UserWithToManyFields>(conn); + | ^^^^ the trait `diesel::Queryable<(diesel::sql_types::Integer, diesel::sql_types::Text, diesel::sql_types::Nullable<diesel::sql_types::Text>), _>` is not implemented for `UserWithToManyFields` + | + = help: the following implementations were found: + <UserWithToManyFields as diesel::Queryable<(__ST0, __ST1, __ST2, __ST3), __DB>> + = note: required because of the requirements on the impl of `diesel::deserialize::FromSqlRow<(diesel::sql_types::Integer, diesel::sql_types::Text, diesel::sql_types::Nullable<diesel::sql_types::Text>), _>` for `UserWithToManyFields` + = note: required because of the requirements on the impl of `diesel::query_dsl::LoadQuery<_, UserWithToManyFields>` for `users::table` + +error[E0277]: the trait bound `(std::string::String, i32, std::option::Option<std::string::String>): diesel::deserialize::FromStaticSqlRow<(diesel::sql_types::Integer, diesel::sql_types::Text, diesel::sql_types::Nullable<diesel::sql_types::Text>), _>` is not satisfied + --> $DIR/queryable_type_missmatch.rs:63:26 + | +63 | let _ = users::table.load::<UserWrongOrder>(conn); + | ^^^^ the trait `diesel::deserialize::FromStaticSqlRow<(diesel::sql_types::Integer, diesel::sql_types::Text, diesel::sql_types::Nullable<diesel::sql_types::Text>), _>` is not implemented for `(std::string::String, i32, std::option::Option<std::string::String>)` + | + = help: the following implementations were found: + <(B, C, A) as diesel::deserialize::FromStaticSqlRow<(SB, SC, SA), __DB>> + = note: required because of the requirements on the impl of `diesel::Queryable<(diesel::sql_types::Integer, diesel::sql_types::Text, diesel::sql_types::Nullable<diesel::sql_types::Text>), _>` for `UserWrongOrder` + = note: required because of the requirements on the impl of `diesel::deserialize::FromSqlRow<(diesel::sql_types::Integer, diesel::sql_types::Text, diesel::sql_types::Nullable<diesel::sql_types::Text>), _>` for `UserWrongOrder` + = note: required because of the requirements on the impl of `diesel::query_dsl::LoadQuery<_, UserWrongOrder>` for `users::table` + +error[E0277]: the trait bound `(i32, i32, std::option::Option<std::string::String>): diesel::deserialize::FromStaticSqlRow<(diesel::sql_types::Integer, diesel::sql_types::Text, diesel::sql_types::Nullable<diesel::sql_types::Text>), _>` is not satisfied + --> $DIR/queryable_type_missmatch.rs:65:26 + | +65 | let _ = users::table.load::<UserTypeMissmatch>(conn); + | ^^^^ the trait `diesel::deserialize::FromStaticSqlRow<(diesel::sql_types::Integer, diesel::sql_types::Text, diesel::sql_types::Nullable<diesel::sql_types::Text>), _>` is not implemented for `(i32, i32, std::option::Option<std::string::String>)` + | + = help: the following implementations were found: + <(B, C, A) as diesel::deserialize::FromStaticSqlRow<(SB, SC, SA), __DB>> + = note: required because of the requirements on the impl of `diesel::Queryable<(diesel::sql_types::Integer, diesel::sql_types::Text, diesel::sql_types::Nullable<diesel::sql_types::Text>), _>` for `UserTypeMissmatch` + = note: required because of the requirements on the impl of `diesel::deserialize::FromSqlRow<(diesel::sql_types::Integer, diesel::sql_types::Text, diesel::sql_types::Nullable<diesel::sql_types::Text>), _>` for `UserTypeMissmatch` + = note: required because of the requirements on the impl of `diesel::query_dsl::LoadQuery<_, UserTypeMissmatch>` for `users::table` + +error: aborting due to 4 previous errors + +For more information about this error, try `rustc --explain E0277`. diff --git a/diesel_derives/src/diesel_numeric_ops.rs b/diesel_derives/src/diesel_numeric_ops.rs index f0b60a369c95..b8e02199097c 100644 --- a/diesel_derives/src/diesel_numeric_ops.rs +++ b/diesel_derives/src/diesel_numeric_ops.rs @@ -22,10 +22,13 @@ pub fn derive(mut item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Di Ok(wrap_in_dummy_mod(quote! { use diesel::expression::{ops, Expression, AsExpression}; use diesel::sql_types::ops::{Add, Sub, Mul, Div}; + use diesel::sql_types::{SqlType, SingleValue}; impl #impl_generics ::std::ops::Add<__Rhs> for #struct_name #ty_generics #where_clause + Self: Expression, <Self as Expression>::SqlType: Add, + <<Self as Expression>::SqlType as Add>::Rhs: SqlType + SingleValue, __Rhs: AsExpression<<<Self as Expression>::SqlType as Add>::Rhs>, { type Output = ops::Add<Self, __Rhs::Expression>; @@ -37,7 +40,9 @@ pub fn derive(mut item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Di impl #impl_generics ::std::ops::Sub<__Rhs> for #struct_name #ty_generics #where_clause + Self: Expression, <Self as Expression>::SqlType: Sub, + <<Self as Expression>::SqlType as Sub>::Rhs: SqlType + SingleValue, __Rhs: AsExpression<<<Self as Expression>::SqlType as Sub>::Rhs>, { type Output = ops::Sub<Self, __Rhs::Expression>; @@ -49,7 +54,9 @@ pub fn derive(mut item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Di impl #impl_generics ::std::ops::Mul<__Rhs> for #struct_name #ty_generics #where_clause + Self: Expression, <Self as Expression>::SqlType: Mul, + <<Self as Expression>::SqlType as Mul>::Rhs: SqlType + SingleValue, __Rhs: AsExpression<<<Self as Expression>::SqlType as Mul>::Rhs>, { type Output = ops::Mul<Self, __Rhs::Expression>; @@ -61,7 +68,9 @@ pub fn derive(mut item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Di impl #impl_generics ::std::ops::Div<__Rhs> for #struct_name #ty_generics #where_clause + Self: Expression, <Self as Expression>::SqlType: Div, + <<Self as Expression>::SqlType as Div>::Rhs: SqlType + SingleValue, __Rhs: AsExpression<<<Self as Expression>::SqlType as Div>::Rhs>, { type Output = ops::Div<Self, __Rhs::Expression>; diff --git a/diesel_derives/src/from_sql_row.rs b/diesel_derives/src/from_sql_row.rs index 823708cee78f..c7edb9b7de08 100644 --- a/diesel_derives/src/from_sql_row.rs +++ b/diesel_derives/src/from_sql_row.rs @@ -19,6 +19,9 @@ pub fn derive(mut item: syn::DeriveInput) -> Result<TokenStream, Diagnostic> { where_clause .predicates .push(parse_quote!(__DB: diesel::backend::Backend)); + where_clause + .predicates + .push(parse_quote!(__ST: diesel::sql_types::SingleValue)); where_clause .predicates .push(parse_quote!(Self: FromSql<__ST, __DB>)); @@ -26,17 +29,7 @@ pub fn derive(mut item: syn::DeriveInput) -> Result<TokenStream, Diagnostic> { let (impl_generics, _, where_clause) = item.generics.split_for_impl(); Ok(wrap_in_dummy_mod(quote! { - use diesel::deserialize::{self, FromSql, FromSqlRow, Queryable}; - - impl #impl_generics FromSqlRow<__ST, __DB> for #struct_ty - #where_clause - { - fn build_from_row<R: diesel::row::Row<__DB>>(row: &mut R) - -> deserialize::Result<Self> - { - FromSql::<__ST, __DB>::from_sql(row.take()) - } - } + use diesel::deserialize::{FromSql, Queryable}; impl #impl_generics Queryable<__ST, __DB> for #struct_ty #where_clause diff --git a/diesel_derives/src/lib.rs b/diesel_derives/src/lib.rs index aae4f5355220..7805f8bb0774 100644 --- a/diesel_derives/src/lib.rs +++ b/diesel_derives/src/lib.rs @@ -171,7 +171,7 @@ pub fn derive_diesel_numeric_ops(input: TokenStream) -> TokenStream { expand_proc_macro(input, diesel_numeric_ops::derive) } -/// Implements `FromSqlRow` and `Queryable` +/// Implements `Queryable` for primitive types /// /// This derive is mostly useful to implement support deserializing /// into rust types not supported by diesel itself. @@ -307,7 +307,7 @@ pub fn derive_query_id(input: TokenStream) -> TokenStream { expand_proc_macro(input, query_id::derive) } -/// Implements `Queryable` +/// Implements `Queryable` to load the result of statically typed queries /// /// This trait can only be derived for structs, not enums. /// @@ -331,12 +331,144 @@ pub fn derive_query_id(input: TokenStream) -> TokenStream { /// into the field type, the implementation will deserialize into `Type`. /// Then `Type` is converted via `.into()` into the field type. By default /// this derive will deserialize directly into the field type +/// +/// +/// # Examples +/// +/// If we just want to map a query to our struct, we can use `derive`. +/// +/// ```rust +/// # extern crate diesel; +/// # extern crate dotenv; +/// # include!("../../diesel/src/doctest_setup.rs"); +/// # +/// #[derive(Queryable, PartialEq, Debug)] +/// struct User { +/// id: i32, +/// name: String, +/// } +/// +/// # fn main() { +/// # run_test(); +/// # } +/// # +/// # fn run_test() -> QueryResult<()> { +/// # use schema::users::dsl::*; +/// # let connection = establish_connection(); +/// let first_user = users.first(&connection)?; +/// let expected = User { id: 1, name: "Sean".into() }; +/// assert_eq!(expected, first_user); +/// # Ok(()) +/// # } +/// ``` +/// +/// If we want to do additional work during deserialization, we can use +/// `deserialize_as` to use a different implementation. +/// +/// ```rust +/// # extern crate diesel; +/// # extern crate dotenv; +/// # include!("../../diesel/src/doctest_setup.rs"); +/// # +/// # use schema::users; +/// # use diesel::backend::{self, Backend}; +/// # use diesel::deserialize::{Queryable, FromSql}; +/// # use diesel::sql_types::Text; +/// # +/// struct LowercaseString(String); +/// +/// impl Into<String> for LowercaseString { +/// fn into(self) -> String { +/// self.0 +/// } +/// } +/// +/// impl<DB> Queryable<Text, DB> for LowercaseString +/// where +/// DB: Backend, +/// String: FromSql<Text, DB> +/// { +/// +/// type Row = String; +/// +/// fn build(s: String) -> Self { +/// LowercaseString(s.to_lowercase()) +/// } +/// } +/// +/// #[derive(Queryable, PartialEq, Debug)] +/// struct User { +/// id: i32, +/// #[diesel(deserialize_as = "LowercaseString")] +/// name: String, +/// } +/// +/// # fn main() { +/// # run_test(); +/// # } +/// # +/// # fn run_test() -> QueryResult<()> { +/// # use schema::users::dsl::*; +/// # let connection = establish_connection(); +/// let first_user = users.first(&connection)?; +/// let expected = User { id: 1, name: "sean".into() }; +/// assert_eq!(expected, first_user); +/// # Ok(()) +/// # } +/// ``` +/// +/// Alternatively, we can implement the trait for our struct manually. +/// +/// ```rust +/// # extern crate diesel; +/// # extern crate dotenv; +/// # include!("../../diesel/src/doctest_setup.rs"); +/// # +/// use schema::users; +/// use diesel::deserialize::{Queryable, FromSqlRow}; +/// use diesel::row::Row; +/// +/// # /* +/// type DB = diesel::sqlite::Sqlite; +/// # */ +/// +/// #[derive(PartialEq, Debug)] +/// struct User { +/// id: i32, +/// name: String, +/// } +/// +/// impl Queryable<users::SqlType, DB> for User +/// where +/// (i32, String): FromSqlRow<users::SqlType, DB>, +/// { +/// type Row = (i32, String); +/// +/// fn build((id, name): Self::Row) -> Self { +/// User { id, name: name.to_lowercase() } +/// } +/// } +/// +/// # fn main() { +/// # run_test(); +/// # } +/// # +/// # fn run_test() -> QueryResult<()> { +/// # use schema::users::dsl::*; +/// # let connection = establish_connection(); +/// let first_user = users.first(&connection)?; +/// let expected = User { id: 1, name: "sean".into() }; +/// assert_eq!(expected, first_user); +/// # Ok(()) +/// # } +/// ``` #[proc_macro_derive(Queryable, attributes(column_name, diesel))] pub fn derive_queryable(input: TokenStream) -> TokenStream { expand_proc_macro(input, queryable::derive) } -/// Implements `QueryableByName` +/// Implements `QueryableByName` for untyped sql queries, such as that one generated +/// by `sql_query` /// /// To derive this trait, Diesel needs to know the SQL type of each field. You /// can do this by either annotating your struct with `#[table_name = @@ -388,6 +520,137 @@ pub fn derive_queryable(input: TokenStream) -> TokenStream { /// * `#[diesel(embed)]`, specifies that the current field maps not only /// single database column, but is a type that implements /// `QueryableByName` on it's own +/// +/// /// # Examples +/// +/// If we just want to map a query to our struct, we can use `derive`. +/// +/// ```rust +/// # extern crate diesel; +/// # extern crate dotenv; +/// # include!("../../diesel/src/doctest_setup.rs"); +/// # use schema::users; +/// # use diesel::sql_query; +/// # +/// #[derive(QueryableByName, PartialEq, Debug)] +/// #[table_name = "users"] +/// struct User { +/// id: i32, +/// name: String, +/// } +/// +/// # fn main() { +/// # run_test(); +/// # } +/// # +/// # fn run_test() -> QueryResult<()> { +/// # let connection = establish_connection(); +/// let first_user = sql_query("SELECT * FROM users ORDER BY id LIMIT 1") +/// .get_result(&connection)?; +/// let expected = User { id: 1, name: "Sean".into() }; +/// assert_eq!(expected, first_user); +/// # Ok(()) +/// # } +/// ``` +/// +/// If we want to do additional work during deserialization, we can use +/// `deserialize_as` to use a different implementation. +/// +/// ```rust +/// # extern crate diesel; +/// # extern crate dotenv; +/// # include!("../../diesel/src/doctest_setup.rs"); +/// # use diesel::sql_query; +/// # use schema::users; +/// # use diesel::backend::{self, Backend}; +/// # use diesel::deserialize::{self, FromSql}; +/// # +/// struct LowercaseString(String); +/// +/// impl Into<String> for LowercaseString { +/// fn into(self) -> String { +/// self.0 +/// } +/// } +/// +/// impl<DB, ST> FromSql<ST, DB> for LowercaseString +/// where +/// DB: Backend, +/// String: FromSql<ST, DB>, +/// { +/// fn from_sql(bytes: backend::RawValue<DB>) -> deserialize::Result<Self> { +/// String::from_sql(bytes) +/// .map(|s| LowercaseString(s.to_lowercase())) +/// } +/// } +/// +/// #[derive(QueryableByName, PartialEq, Debug)] +/// #[table_name = "users"] +/// struct User { +/// id: i32, +/// #[diesel(deserialize_as = "LowercaseString")] +/// name: String, +/// } +/// +/// # fn main() { +/// # run_test(); +/// # } +/// # +/// # fn run_test() -> QueryResult<()> { +/// # let connection = establish_connection(); +/// let first_user = sql_query("SELECT * FROM users ORDER BY id LIMIT 1") +/// .get_result(&connection)?; +/// let expected = User { id: 1, name: "sean".into() }; +/// assert_eq!(expected, first_user); +/// # Ok(()) +/// # } +/// ``` +/// +/// The custom derive generates impls similar to the follownig one +/// +/// ```rust +/// # extern crate diesel; +/// # extern crate dotenv; +/// # include!("../../diesel/src/doctest_setup.rs"); +/// # use schema::users; +/// # use diesel::sql_query; +/// # use diesel::deserialize::{self, QueryableByName, FromSql}; +/// # use diesel::row::NamedRow; +/// # use diesel::backend::Backend; +/// # +/// #[derive(PartialEq, Debug)] +/// struct User { +/// id: i32, +/// name: String, +/// } +/// +/// impl<DB> QueryableByName<DB> for User +/// where +/// DB: Backend, +/// i32: FromSql<diesel::dsl::SqlTypeOf<users::id>, DB>, +/// String: FromSql<diesel::dsl::SqlTypeOf<users::name>, DB>, +/// { +/// fn build<'a>(row: &impl NamedRow<'a, DB>) -> deserialize::Result<Self> { +/// let id = NamedRow::get::<diesel::dsl::SqlTypeOf<users::id>, _>(row, "id")?; +/// let name = NamedRow::get::<diesel::dsl::SqlTypeOf<users::name>, _>(row, "name")?; +/// +/// Ok(Self { id, name }) +/// } +/// } +/// +/// # fn main() { +/// # run_test(); +/// # } +/// # +/// # fn run_test() -> QueryResult<()> { +/// # let connection = establish_connection(); +/// let first_user = sql_query("SELECT * FROM users ORDER BY id LIMIT 1") +/// .get_result(&connection)?; +/// let expected = User { id: 1, name: "Sean".into() }; +/// assert_eq!(expected, first_user); +/// # Ok(()) +/// # } +/// ``` #[proc_macro_derive(QueryableByName, attributes(table_name, column_name, sql_type, diesel))] pub fn derive_queryable_by_name(input: TokenStream) -> TokenStream { expand_proc_macro(input, queryable_by_name::derive) diff --git a/diesel_derives/src/queryable.rs b/diesel_derives/src/queryable.rs index 077fe3698ac8..388a2f8802fb 100644 --- a/diesel_derives/src/queryable.rs +++ b/diesel_derives/src/queryable.rs @@ -19,31 +19,41 @@ pub fn derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagno let i = syn::Index::from(i); f.name.assign(parse_quote!(row.#i.into())) }); + let sql_type = (0..model.fields().len()) + .map(|i| { + let i = syn::Ident::new(&format!("__ST{}", i), proc_macro2::Span::call_site()); + quote!(#i) + }) + .collect::<Vec<_>>(); + let sql_type = &sql_type; let (_, ty_generics, _) = item.generics.split_for_impl(); let mut generics = item.generics.clone(); generics .params .push(parse_quote!(__DB: diesel::backend::Backend)); - generics.params.push(parse_quote!(__ST)); + for id in 0..model.fields().len() { + let ident = syn::Ident::new(&format!("__ST{}", id), proc_macro2::Span::call_site()); + generics.params.push(parse_quote!(#ident)); + } { let where_clause = generics.where_clause.get_or_insert(parse_quote!(where)); where_clause .predicates - .push(parse_quote!((#(#field_ty,)*): Queryable<__ST, __DB>)); + .push(parse_quote!((#(#field_ty,)*): FromStaticSqlRow<(#(#sql_type,)*), __DB>)); } let (impl_generics, _, where_clause) = generics.split_for_impl(); Ok(wrap_in_dummy_mod(quote! { - use diesel::deserialize::Queryable; + use diesel::deserialize::{FromStaticSqlRow, Queryable}; + use diesel::row::{Row, Field}; - impl #impl_generics Queryable<__ST, __DB> for #struct_name #ty_generics - #where_clause + impl #impl_generics Queryable<(#(#sql_type,)*), __DB> for #struct_name #ty_generics + #where_clause { - type Row = <(#(#field_ty,)*) as Queryable<__ST, __DB>>::Row; + type Row = (#(#field_ty,)*); fn build(row: Self::Row) -> Self { - let row: (#(#field_ty,)*) = Queryable::build(row); Self { #(#build_expr,)* } diff --git a/diesel_derives/src/queryable_by_name.rs b/diesel_derives/src/queryable_by_name.rs index ef9a44f6fe2d..bc954d5b7358 100644 --- a/diesel_derives/src/queryable_by_name.rs +++ b/diesel_derives/src/queryable_by_name.rs @@ -9,11 +9,31 @@ pub fn derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagno let model = Model::from_item(&item)?; let struct_name = &item.ident; - let field_expr = model + let fields = model.fields().iter().map(get_ident).collect::<Vec<_>>(); + let field_names = model.fields().iter().map(|f| &f.name).collect::<Vec<_>>(); + + let initial_field_expr = model .fields() .iter() - .map(|f| field_expr(f, &model)) - .collect::<Result<Vec<_>, _>>()?; + .map(|f| { + if f.has_flag("embed") { + let field_ty = &f.ty; + Ok(quote!(<#field_ty as QueryableByName<__DB>>::build( + row, + )?)) + } else { + let name = f.column_name(); + let field_ty = &f.ty; + let deserialize_ty = f.ty_for_deserialize()?; + Ok(quote!( + { + let field = diesel::row::NamedRow::get(row, stringify!(#name))?; + <#deserialize_ty as Into<#field_ty>>::into(field) + } + )) + } + }) + .collect::<Result<Vec<_>, Diagnostic>>()?; let (_, ty_generics, ..) = item.generics.split_for_impl(); let mut generics = item.generics.clone(); @@ -40,33 +60,34 @@ pub fn derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagno Ok(wrap_in_dummy_mod(quote! { use diesel::deserialize::{self, QueryableByName}; - use diesel::row::NamedRow; + use diesel::row::{NamedRow}; + use diesel::sql_types::Untyped; impl #impl_generics QueryableByName<__DB> for #struct_name #ty_generics #where_clause { - fn build<__R: NamedRow<__DB>>(row: &__R) -> deserialize::Result<Self> { - std::result::Result::Ok(Self { - #(#field_expr,)* + fn build<'__a>(row: &impl NamedRow<'__a, __DB>) -> deserialize::Result<Self> + { + + + #( + let mut #fields = #initial_field_expr; + )* + deserialize::Result::Ok(Self { + #( + #field_names: #fields, + )* }) } } })) } -fn field_expr(field: &Field, model: &Model) -> Result<syn::FieldValue, Diagnostic> { - if field.has_flag("embed") { - Ok(field - .name - .assign(parse_quote!(QueryableByName::build(row)?))) - } else { - let column_name = field.column_name(); - let ty = field.ty_for_deserialize()?; - let st = sql_type(field, model); - Ok(field - .name - .assign(parse_quote!(row.get::<#st, #ty>(stringify!(#column_name))?.into()))) +fn get_ident(field: &Field) -> Ident { + match &field.name { + FieldName::Named(n) => n.clone(), + FieldName::Unnamed(i) => Ident::new(&format!("field_{}", i.index), Span::call_site()), } } diff --git a/diesel_derives/src/sql_function.rs b/diesel_derives/src/sql_function.rs index a41af4f9fffa..a64a50595fde 100644 --- a/diesel_derives/src/sql_function.rs +++ b/diesel_derives/src/sql_function.rs @@ -141,7 +141,7 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic> use diesel::sqlite::{Sqlite, SqliteConnection}; use diesel::serialize::ToSql; - use diesel::deserialize::Queryable; + use diesel::deserialize::{FromSqlRow, StaticallySizedRow}; use diesel::sqlite::SqliteAggregateFunction; use diesel::sql_types::IntoNullable; }; @@ -163,7 +163,8 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic> where A: SqliteAggregateFunction<(#(#arg_name,)*)> + Send + 'static, A::Output: ToSql<#return_type, Sqlite>, - (#(#arg_name,)*): Queryable<(#(#arg_type,)*), Sqlite>, + (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> + + StaticallySizedRow<(#(#arg_type,)*), Sqlite>, { conn.register_aggregate_function::<(#(#arg_type,)*), #return_type, _, _, A>(#sql_name) } @@ -188,7 +189,8 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic> where A: SqliteAggregateFunction<#arg_name> + Send + 'static, A::Output: ToSql<#return_type, Sqlite>, - #arg_name: Queryable<#arg_type, Sqlite>, + #arg_name: FromSqlRow<#arg_type, Sqlite> + + StaticallySizedRow<#arg_type, Sqlite>, { conn.register_aggregate_function::<#arg_type, #return_type, _, _, A>(#sql_name) } @@ -219,7 +221,7 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic> use diesel::sqlite::{Sqlite, SqliteConnection}; use diesel::serialize::ToSql; - use diesel::deserialize::Queryable; + use diesel::deserialize::{FromSqlRow, StaticallySizedRow}; #[allow(dead_code)] /// Registers an implementation for this function on the given connection @@ -235,7 +237,8 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic> ) -> QueryResult<()> where F: Fn(#(#arg_name,)*) -> Ret + Send + 'static, - (#(#arg_name,)*): Queryable<(#(#arg_type,)*), Sqlite>, + (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> + + StaticallySizedRow<(#(#arg_type,)*), Sqlite>, Ret: ToSql<#return_type, Sqlite>, { conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>( @@ -260,7 +263,8 @@ pub(crate) fn expand(input: SqlFunctionDecl) -> Result<TokenStream, Diagnostic> ) -> QueryResult<()> where F: FnMut(#(#arg_name,)*) -> Ret + Send + 'static, - (#(#arg_name,)*): Queryable<(#(#arg_type,)*), Sqlite>, + (#(#arg_name,)*): FromSqlRow<(#(#arg_type,)*), Sqlite> + + StaticallySizedRow<(#(#arg_type,)*), Sqlite>, Ret: ToSql<#return_type, Sqlite>, { conn.register_sql_function::<(#(#arg_type,)*), #return_type, _, _, _>( @@ -319,7 +323,7 @@ impl Parse for SqlFunctionDecl { let return_type = if Option::<Token![->]>::parse(input)?.is_some() { syn::Type::parse(input)? } else { - parse_quote!(()) + parse_quote!(diesel::expression::expression_types::NotSelectable) }; let _semi = Option::<Token![;]>::parse(input)?; diff --git a/diesel_derives/src/sql_type.rs b/diesel_derives/src/sql_type.rs index 5a488d4e1a18..2f8ea7a2bc38 100644 --- a/diesel_derives/src/sql_type.rs +++ b/diesel_derives/src/sql_type.rs @@ -13,10 +13,11 @@ pub fn derive(item: syn::DeriveInput) -> Result<proc_macro2::TokenStream, Diagno let pg_tokens = pg_tokens(&item); Ok(wrap_in_dummy_mod(quote! { - impl #impl_generics diesel::sql_types::NotNull + impl #impl_generics diesel::sql_types::SqlType for #struct_name #ty_generics #where_clause { + type IsNull = diesel::sql_types::is_nullable::NotNull; } impl #impl_generics diesel::sql_types::SingleValue diff --git a/diesel_migrations/migrations_internals/src/connection.rs b/diesel_migrations/migrations_internals/src/connection.rs index 0452a66ed0c8..7ce928d0153f 100644 --- a/diesel_migrations/migrations_internals/src/connection.rs +++ b/diesel_migrations/migrations_internals/src/connection.rs @@ -1,11 +1,12 @@ use diesel::deserialize::FromSql; use diesel::expression::bound::Bound; +use diesel::expression::QueryMetadata; use diesel::helper_types::{max, Limit, Select}; use diesel::insertable::ColumnInsertValue; use diesel::prelude::*; use diesel::query_builder::{InsertStatement, QueryFragment, ValuesClause}; use diesel::query_dsl::methods::{self, ExecuteDsl, LoadQuery}; -use diesel::sql_types::VarChar; +use diesel::sql_types::{Nullable, VarChar}; use std::collections::HashSet; use std::iter::FromIterator; @@ -36,6 +37,7 @@ where __diesel_schema_migrations: methods::SelectDsl<version>, Select<__diesel_schema_migrations, version>: LoadQuery<T, String>, Limit<Select<__diesel_schema_migrations, max<version>>>: QueryFragment<T::Backend>, + T::Backend: QueryMetadata<Nullable<VarChar>>, { fn previously_run_migration_versions(&self) -> QueryResult<HashSet<String>> { __diesel_schema_migrations diff --git a/diesel_tests/tests/custom_types.rs b/diesel_tests/tests/custom_types.rs index 50eaa1caa624..915e81703486 100644 --- a/diesel_tests/tests/custom_types.rs +++ b/diesel_tests/tests/custom_types.rs @@ -37,8 +37,8 @@ impl ToSql<MyType, Pg> for MyEnum { } impl FromSql<MyType, Pg> for MyEnum { - fn from_sql(bytes: Option<PgValue<'_>>) -> deserialize::Result<Self> { - match not_none!(bytes).as_bytes() { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result<Self> { + match bytes.as_bytes() { b"foo" => Ok(MyEnum::Foo), b"bar" => Ok(MyEnum::Bar), _ => Err("Unrecognized enum variant".into()), diff --git a/diesel_tests/tests/expressions/date_and_time.rs b/diesel_tests/tests/expressions/date_and_time.rs index 52d5b40bfab8..fb8624f7e8c7 100644 --- a/diesel_tests/tests/expressions/date_and_time.rs +++ b/diesel_tests/tests/expressions/date_and_time.rs @@ -131,7 +131,7 @@ fn now_can_be_used_as_nullable() { let nullable_timestamp = sql::<Nullable<Timestamp>>("CURRENT_TIMESTAMP"); let result = select(nullable_timestamp.eq(now)).get_result(&connection()); - assert_eq!(Ok(true), result); + assert_eq!(Ok(Some(true)), result); } #[test] diff --git a/diesel_tests/tests/expressions/mod.rs b/diesel_tests/tests/expressions/mod.rs index 314c5690e431..6d6702b7df28 100644 --- a/diesel_tests/tests/expressions/mod.rs +++ b/diesel_tests/tests/expressions/mod.rs @@ -10,7 +10,9 @@ use crate::schema::{ }; use diesel::backend::Backend; use diesel::dsl::*; +use diesel::expression::TypedExpressionType; use diesel::query_builder::*; +use diesel::sql_types::SqlType; use diesel::*; #[test] @@ -152,7 +154,10 @@ struct Arbitrary<T> { _marker: PhantomData<T>, } -impl<T> Expression for Arbitrary<T> { +impl<T> Expression for Arbitrary<T> +where + T: SqlType + TypedExpressionType, +{ type SqlType = T; } @@ -165,9 +170,9 @@ where } } -impl<T, QS> SelectableExpression<QS> for Arbitrary<T> {} +impl<T, QS> SelectableExpression<QS> for Arbitrary<T> where Self: Expression {} -impl<T, QS> AppearsOnTable<QS> for Arbitrary<T> {} +impl<T, QS> AppearsOnTable<QS> for Arbitrary<T> where Self: Expression {} fn arbitrary<T>() -> Arbitrary<T> { Arbitrary { diff --git a/diesel_tests/tests/filter.rs b/diesel_tests/tests/filter.rs index 7b010e10dcd0..9f7d72164974 100644 --- a/diesel_tests/tests/filter.rs +++ b/diesel_tests/tests/filter.rs @@ -328,14 +328,14 @@ fn or_doesnt_mess_with_precedence_of_previous_statements() { let f = false.into_sql::<sql_types::Bool>(); let count = users .filter(f) - .filter(f.or(true)) + .filter(f.or(true.into_sql::<sql_types::Bool>())) .count() .first(&connection); assert_eq!(Ok(0), count); let count = users - .filter(f.or(f).and(f.or(true))) + .filter(f.or(f).and(f.or(true.into_sql::<sql_types::Bool>()))) .count() .first(&connection); diff --git a/diesel_tests/tests/joins.rs b/diesel_tests/tests/joins.rs index c7d4d5a02981..9b4dd3762a42 100644 --- a/diesel_tests/tests/joins.rs +++ b/diesel_tests/tests/joins.rs @@ -300,6 +300,35 @@ fn select_right_side_with_nullable_column_first() { assert_eq!(expected_data, actual_data); } +#[test] +fn select_left_join_right_side_with_non_null_inside() { + let connection = connection_with_sean_and_tess_in_users_table(); + + connection + .execute( + "INSERT INTO posts (user_id, title, body) VALUES + (1, 'Hello', 'Content') + ", + ) + .unwrap(); + + let expected_data = vec![ + (None, 2), + (Some((1, "Hello".to_string(), "Hello".to_string())), 1), + ]; + + let source = users::table + .left_outer_join(posts::table) + .select(( + (users::id, posts::title, posts::title).nullable(), + users::id, + )) + .order_by((users::id.desc(), posts::id.asc())); + let actual_data: Vec<(Option<(i32, String, String)>, i32)> = source.load(&connection).unwrap(); + + assert_eq!(expected_data, actual_data); +} + #[test] fn select_then_join() { use crate::schema::users::dsl::*; diff --git a/diesel_tests/tests/types.rs b/diesel_tests/tests/types.rs index 8c528cc9bbe4..0e1bf368c912 100644 --- a/diesel_tests/tests/types.rs +++ b/diesel_tests/tests/types.rs @@ -5,6 +5,7 @@ extern crate bigdecimal; extern crate chrono; use crate::schema::*; +use diesel::deserialize::FromSqlRow; #[cfg(feature = "postgres")] use diesel::pg::Pg; use diesel::sql_types::*; @@ -144,12 +145,11 @@ fn boolean_from_sql() { } #[test] -#[cfg(feature = "postgres")] -fn boolean_treats_null_as_false_when_predicates_return_null() { +fn nullable_boolean_from_sql() { let connection = connection(); - let one = Some(1).into_sql::<Nullable<Integer>>(); + let one = Some(1).into_sql::<diesel::sql_types::Nullable<Integer>>(); let query = select(one.eq(None::<i32>)); - assert_eq!(Ok(false), query.first(&connection)); + assert_eq!(Ok(Option::<bool>::None), query.first(&connection)); } #[test] @@ -670,16 +670,16 @@ fn pg_specific_option_to_sql() { "'t'::bool", Some(true) )); - assert!(!query_to_sql_equality::<Nullable<Bool>, Option<bool>>( + assert!(query_to_sql_equality::<Nullable<Bool>, Option<bool>>( "'f'::bool", - Some(true) + Some(false) )); assert!(query_to_sql_equality::<Nullable<Bool>, Option<bool>>( "NULL", None )); - assert!(!query_to_sql_equality::<Nullable<Bool>, Option<bool>>( + assert!(query_to_sql_equality::<Nullable<Bool>, Option<bool>>( "NULL::bool", - Some(false) + None )); } @@ -1231,7 +1231,7 @@ fn third_party_crates_can_add_new_types() { } impl FromSql<MyInt, Pg> for i32 { - fn from_sql(bytes: Option<PgValue<'_>>) -> deserialize::Result<Self> { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result<Self> { FromSql::<Integer, Pg>::from_sql(bytes) } } @@ -1241,17 +1241,18 @@ fn third_party_crates_can_add_new_types() { assert_eq!(70_000, query_single_value::<MyInt, i32>("70000")); } -fn query_single_value<T, U: Queryable<T, TestBackend>>(sql_str: &str) -> U +fn query_single_value<T, U: FromSqlRow<T, TestBackend>>(sql_str: &str) -> U where TestBackend: HasSqlType<T>, - T: QueryId + SingleValue, + T: QueryId + SingleValue + SqlType, { use diesel::dsl::sql; let connection = connection(); select(sql::<T>(sql_str)).first(&connection).unwrap() } -use diesel::expression::{is_aggregate, AsExpression, ValidGrouping}; +use diesel::dsl::{And, AsExprOf, Eq, IsNull}; +use diesel::expression::{is_aggregate, AsExpression, SqlLiteral, ValidGrouping}; use diesel::query_builder::{QueryFragment, QueryId}; use std::fmt::Debug; @@ -1261,7 +1262,16 @@ where U::Expression: SelectableExpression<(), SqlType = T> + ValidGrouping<(), IsAggregate = is_aggregate::Never>, U::Expression: QueryFragment<TestBackend> + QueryId, - T: QueryId + SingleValue, + T: QueryId + SingleValue + SqlType, + T::IsNull: OneIsNullable<T::IsNull, Out = T::IsNull>, + T::IsNull: MaybeNullableType<Bool>, + <T::IsNull as MaybeNullableType<Bool>>::Out: SqlType, + diesel::sql_types::is_nullable::NotNull: diesel::sql_types::AllAreNullable< + <<T::IsNull as MaybeNullableType<Bool>>::Out as SqlType>::IsNull, + Out = diesel::sql_types::is_nullable::NotNull, + >, + Eq<SqlLiteral<T>, U>: Expression<SqlType = <T::IsNull as MaybeNullableType<Bool>>::Out>, + And<IsNull<SqlLiteral<T>>, IsNull<AsExprOf<U, T>>>: Expression<SqlType = Bool>, { use diesel::dsl::sql; let connection = connection(); @@ -1272,7 +1282,7 @@ where .or(sql::<T>(sql_str).eq(value.clone())), ); query - .get_result(&connection) + .get_result::<bool>(&connection) .expect(&format!("Error comparing {}, {:?}", sql_str, value)) } diff --git a/diesel_tests/tests/types_roundtrip.rs b/diesel_tests/tests/types_roundtrip.rs index ea2577e03ebd..67818efb7464 100644 --- a/diesel_tests/tests/types_roundtrip.rs +++ b/diesel_tests/tests/types_roundtrip.rs @@ -9,10 +9,11 @@ pub use crate::schema::{connection_without_transaction, TestConnection}; pub use diesel::data_types::*; pub use diesel::result::Error; pub use diesel::serialize::ToSql; -pub use diesel::sql_types::HasSqlType; +pub use diesel::sql_types::{HasSqlType, SingleValue, SqlType}; pub use diesel::*; -use diesel::expression::{AsExpression, NonAggregate}; +use deserialize::FromSqlRow; +use diesel::expression::{AsExpression, NonAggregate, TypedExpressionType}; use diesel::query_builder::{QueryFragment, QueryId}; #[cfg(feature = "postgres")] use std::collections::Bound; @@ -23,10 +24,10 @@ thread_local! { pub fn test_type_round_trips<ST, T>(value: T) -> bool where - ST: QueryId, + ST: QueryId + SqlType + TypedExpressionType + SingleValue, <TestConnection as Connection>::Backend: HasSqlType<ST>, T: AsExpression<ST> - + Queryable<ST, <TestConnection as Connection>::Backend> + + FromSqlRow<ST, <TestConnection as Connection>::Backend> + PartialEq + Clone + ::std::fmt::Debug, diff --git a/examples/postgres/advanced-blog-cli/src/post.rs b/examples/postgres/advanced-blog-cli/src/post.rs index 397a3a68dfb0..8c621813f365 100644 --- a/examples/postgres/advanced-blog-cli/src/post.rs +++ b/examples/postgres/advanced-blog-cli/src/post.rs @@ -14,6 +14,7 @@ pub struct Post { pub body: String, pub created_at: NaiveDateTime, pub updated_at: NaiveDateTime, + #[diesel(deserialize_as = "Option<NaiveDateTime>")] pub status: Status, } @@ -22,17 +23,11 @@ pub enum Status { Published { at: NaiveDateTime }, } -use diesel::deserialize::Queryable; -use diesel::pg::Pg; -use diesel::sql_types::{Nullable, Timestamp}; - -impl Queryable<Nullable<Timestamp>, Pg> for Status { - type Row = Option<NaiveDateTime>; - - fn build(row: Self::Row) -> Self { - match row { - Some(at) => Status::Published { at }, +impl Into<Status> for Option<NaiveDateTime> { + fn into(self) -> Status { + match self { None => Status::Draft, + Some(at) => Status::Published { at }, } } } diff --git a/examples/postgres/custom_types/src/main.rs b/examples/postgres/custom_types/src/main.rs index 584fd50fb857..b46f1441f4f4 100644 --- a/examples/postgres/custom_types/src/main.rs +++ b/examples/postgres/custom_types/src/main.rs @@ -1,8 +1,5 @@ -#[macro_use] -extern crate diesel; - +use self::schema::translations; use diesel::prelude::*; -use schema::translations::{self, dsl}; mod model; mod schema; @@ -20,7 +17,7 @@ fn main() { let conn = PgConnection::establish(&database_url) .unwrap_or_else(|e| panic!("Error connecting to {}: {}", database_url, e)); - let _ = diesel::insert_into(dsl::translations) + let _ = diesel::insert_into(translations::table) .values(&Translation { word_id: 1, translation_id: 1, @@ -28,8 +25,12 @@ fn main() { }) .execute(&conn); - let t = dsl::translations - .select((dsl::word_id, dsl::translation_id, dsl::language)) + let t = translations::table + .select(( + translations::word_id, + translations::translation_id, + translations::language, + )) .get_results::<Translation>(&conn) .expect("select"); println!("{:?}", t); diff --git a/examples/postgres/custom_types/src/model.rs b/examples/postgres/custom_types/src/model.rs index 2acc278faef2..2889c93c32af 100644 --- a/examples/postgres/custom_types/src/model.rs +++ b/examples/postgres/custom_types/src/model.rs @@ -13,7 +13,7 @@ pub mod exports { #[postgres(type_name = "Language")] pub struct LanguageType; -#[derive(Debug, FromSqlRow, AsExpression)] +#[derive(Debug, AsExpression, FromSqlRow)] #[sql_type = "LanguageType"] pub enum Language { En, @@ -33,8 +33,8 @@ impl ToSql<LanguageType, Pg> for Language { } impl FromSql<LanguageType, Pg> for Language { - fn from_sql(bytes: Option<PgValue>) -> deserialize::Result<Self> { - match not_none!(bytes).as_bytes() { + fn from_sql(bytes: PgValue) -> deserialize::Result<Self> { + match bytes.as_bytes() { b"en" => Ok(Language::En), b"ru" => Ok(Language::Ru), b"de" => Ok(Language::De), diff --git a/examples/postgres/custom_types/src/schema.rs b/examples/postgres/custom_types/src/schema.rs index 7c9bd89c7598..3294bb84f6d6 100644 --- a/examples/postgres/custom_types/src/schema.rs +++ b/examples/postgres/custom_types/src/schema.rs @@ -1,4 +1,4 @@ -table! { +diesel::table! { use diesel::sql_types::*; use crate::model::exports::*;