diff --git a/psl/builtin-connectors/src/lib.rs b/psl/builtin-connectors/src/lib.rs index 7c1aaec66ba3..9c43cd8b1244 100644 --- a/psl/builtin-connectors/src/lib.rs +++ b/psl/builtin-connectors/src/lib.rs @@ -9,6 +9,7 @@ pub use mongodb::MongoDbType; pub use mssql_datamodel_connector::{MsSqlType, MsSqlTypeParameter}; pub use mysql_datamodel_connector::MySqlType; pub use postgres_datamodel_connector::{PostgresDatasourceProperties, PostgresType}; +pub use psl_core::js_connector::JsConnector; mod mongodb; mod mssql_datamodel_connector; diff --git a/psl/psl-core/src/js_connector.rs b/psl/psl-core/src/js_connector.rs index 36c872804ec4..8798d1d3eeae 100644 --- a/psl/psl-core/src/js_connector.rs +++ b/psl/psl-core/src/js_connector.rs @@ -23,6 +23,15 @@ pub struct JsConnector { pub allowed_protocols: Option<&'static [&'static str]>, } +impl JsConnector { + /// Returns true if the given name is a valid provider name for a JsConnector. + /// We use the convention that if a provider starts with ´@prisma/´ (ex. ´@prisma/planetscale´) + /// then its a provider for a JS connector. + pub fn is_provider(name: &str) -> bool { + name.starts_with("@prisma/") + } +} + #[derive(Copy, Clone)] pub enum Flavor { MySQL, diff --git a/query-engine/connectors/sql-query-connector/src/database/connection.rs b/query-engine/connectors/sql-query-connector/src/database/connection.rs index c466b5ad7435..b7aeb43a3f76 100644 --- a/query-engine/connectors/sql-query-connector/src/database/connection.rs +++ b/query-engine/connectors/sql-query-connector/src/database/connection.rs @@ -22,7 +22,7 @@ pub(crate) struct SqlConnection { impl SqlConnection where - C: Queryable + TransactionCapable + Send + Sync + 'static, + C: TransactionCapable + Send + Sync + 'static, { pub fn new(inner: C, connection_info: &ConnectionInfo, features: psl::PreviewFeatures) -> Self { let connection_info = connection_info.clone(); diff --git a/query-engine/connectors/sql-query-connector/src/database/js.rs b/query-engine/connectors/sql-query-connector/src/database/js.rs new file mode 100644 index 000000000000..d2aa03832146 --- /dev/null +++ b/query-engine/connectors/sql-query-connector/src/database/js.rs @@ -0,0 +1,161 @@ +use super::connection::SqlConnection; +use crate::FromSource; +use async_trait::async_trait; +use connector_interface::{ + self as connector, + error::{ConnectorError, ErrorKind}, + Connection, Connector, +}; +use quaint::{ + connector::IsolationLevel, + prelude::{Queryable as QuaintQueryable, *}, +}; +use std::sync::Arc; + +// TODO: https://github.com/prisma/team-orm/issues/245 +// implement registry for client drivers, rather than a global variable, +// this would require the register_driver and registered_js_driver functions to +// receive an identifier for the specific driver +static QUERYABLE: once_cell::sync::OnceCell> = once_cell::sync::OnceCell::new(); + +pub fn registered_js_connector() -> Option<&'static Arc> { + QUERYABLE.get() +} + +pub fn register_js_connector(driver: Arc) { + if QUERYABLE.set(driver).is_err() { + panic!("Cannot register driver twice"); + } +} + +pub struct Js { + connector: JsConnector, + connection_info: ConnectionInfo, + features: psl::PreviewFeatures, + psl_connector: psl::builtin_connectors::JsConnector, +} + +fn get_connection_info(url: &str) -> connector::Result { + ConnectionInfo::from_url(url).map_err(|err| { + ConnectorError::from_kind(ErrorKind::InvalidDatabaseUrl { + details: err.to_string(), + url: url.to_string(), + }) + }) +} + +#[async_trait] +impl FromSource for Js { + async fn from_source( + source: &psl::Datasource, + url: &str, + features: psl::PreviewFeatures, + ) -> connector_interface::Result { + let psl_connector = source.active_connector.as_js_connector().unwrap_or_else(|| { + panic!( + "Connector for {} is not a JsConnector", + source.active_connector.provider_name() + ) + }); + + let connector = registered_js_connector().unwrap().clone(); + let connection_info = get_connection_info(url)?; + + return Ok(Js { + connector: JsConnector { queryable: connector }, + connection_info, + features: features.to_owned(), + psl_connector, + }); + } +} + +#[async_trait] +impl Connector for Js { + async fn get_connection<'a>(&'a self) -> connector::Result> { + super::catch(self.connection_info.clone(), async move { + let sql_conn = SqlConnection::new(self.connector.clone(), &self.connection_info, self.features); + Ok(Box::new(sql_conn) as Box) + }) + .await + } + + fn name(&self) -> &'static str { + self.psl_connector.name + } + + fn should_retry_on_transient_error(&self) -> bool { + false + } +} + +// TODO: miguelff: I haven´t found a better way to do this, yet... please continue reading. +// +// There is a bug in NAPI-rs by wich compiling a binary crate that links code using napi-rs +// bindings breaks. We could have used a JsQueryable from the `js-connectors` crate directly, as the +// `connection` field of a `Js` connector, but that will imply using napi-rs transitively, and break +// the tests (which are compiled as binary creates) +// +// To avoid the problem above I separated interface from implementation, making JsConnector +// independent on napi-rs. Initially, I tried having a field Arc<&dyn TransactionCabable> to hold +// JsQueryable at runtime. I did this, because TransactionCapable is the trait bounds required to +// create a value of `SqlConnection` (see [SqlConnection::new])) to actually performt the queries. +// using JSQueryable. However, this didn't work because TransactionCapable is not object safe. +// (has Sized as a supertrait) +// +// The thing is that TransactionCapable is not object safe and cannot be used in a dynamic type +// declaration, so finally I couldn't come up with anything better then wrapping a QuaintQueryable +// in this object, and implementing TransactionCapable (and quaint::Queryable) explicitly for it. +#[derive(Clone)] +struct JsConnector { + queryable: Arc, +} + +#[async_trait] +impl QuaintQueryable for JsConnector { + async fn query(&self, q: Query<'_>) -> quaint::Result { + self.queryable.query(q).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.queryable.query_raw(sql, params).await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.queryable.query_raw_typed(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> quaint::Result { + self.queryable.execute(q).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.queryable.execute_raw(sql, params).await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.queryable.execute_raw_typed(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { + self.queryable.raw_cmd(cmd).await + } + + async fn version(&self) -> quaint::Result> { + self.queryable.version().await + } + + fn is_healthy(&self) -> bool { + self.queryable.is_healthy() + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { + self.queryable.set_tx_isolation_level(isolation_level).await + } + + fn requires_isolation_first(&self) -> bool { + self.queryable.requires_isolation_first() + } +} + +impl TransactionCapable for JsConnector {} diff --git a/query-engine/connectors/sql-query-connector/src/database/js/mod.rs b/query-engine/connectors/sql-query-connector/src/database/js/mod.rs deleted file mode 100644 index 583a506aedf6..000000000000 --- a/query-engine/connectors/sql-query-connector/src/database/js/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod registry; - -pub use registry::{register_driver, registered_driver}; diff --git a/query-engine/connectors/sql-query-connector/src/database/js/registry.rs b/query-engine/connectors/sql-query-connector/src/database/js/registry.rs deleted file mode 100644 index 04609a27c3a9..000000000000 --- a/query-engine/connectors/sql-query-connector/src/database/js/registry.rs +++ /dev/null @@ -1,17 +0,0 @@ -use quaint::prelude::Queryable; -use std::sync::Arc; - -// TODO: implement registry for client drivers, rather than a global variable, -// this would require the register_driver and registered_js_driver functions to -// receive an identifier for the specific driver -static QUERYABLE: once_cell::sync::OnceCell> = once_cell::sync::OnceCell::new(); - -pub fn registered_driver() -> Option<&'static Arc> { - QUERYABLE.get() -} - -pub fn register_driver(driver: Arc) { - if QUERYABLE.set(driver).is_err() { - panic!("Cannot register driver twice"); - } -} diff --git a/query-engine/connectors/sql-query-connector/src/database/mod.rs b/query-engine/connectors/sql-query-connector/src/database/mod.rs index 8bd1dad7d6f3..96ab66334c51 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mod.rs +++ b/query-engine/connectors/sql-query-connector/src/database/mod.rs @@ -1,8 +1,6 @@ -#[cfg(feature = "js-connectors")] -pub mod js; -mod runtime; - mod connection; +#[cfg(feature = "js-connectors")] +mod js; mod mssql; mod mysql; mod postgresql; @@ -14,6 +12,8 @@ pub(crate) mod operations; use async_trait::async_trait; use connector_interface::{error::ConnectorError, Connector}; +#[cfg(feature = "js-connectors")] +pub use js::*; pub use mssql::*; pub use mysql::*; pub use postgresql::*; diff --git a/query-engine/connectors/sql-query-connector/src/database/mysql.rs b/query-engine/connectors/sql-query-connector/src/database/mysql.rs index a84de669ff0f..deb3e6a4f35f 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mysql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/mysql.rs @@ -1,5 +1,4 @@ use super::connection::SqlConnection; -use super::runtime::RuntimePool; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -11,7 +10,7 @@ use quaint::{pooled::Quaint, prelude::ConnectionInfo}; use std::time::Duration; pub struct Mysql { - pool: RuntimePool, + pool: Quaint, connection_info: ConnectionInfo, features: psl::PreviewFeatures, } @@ -39,31 +38,10 @@ fn get_connection_info(url: &str) -> connector::Result { #[async_trait] impl FromSource for Mysql { async fn from_source( - source: &psl::Datasource, + _: &psl::Datasource, url: &str, features: psl::PreviewFeatures, ) -> connector_interface::Result { - if source.provider == "@prisma/mysql" { - #[cfg(feature = "js-connectors")] - { - let driver = super::js::registered_driver(); - let connection_info = get_connection_info(url)?; - - return Ok(Mysql { - pool: RuntimePool::Js(driver.unwrap().clone()), - connection_info, - features: features.to_owned(), - }); - } - - #[cfg(not(feature = "js-connectors"))] - { - return Err(ConnectorError::from_kind(ErrorKind::UnsupportedConnector( - "The @prisma/mysql connector requires the `jsConnectors` preview feature to be enabled.".into(), - ))); - } - } - let connection_info = get_connection_info(url)?; let mut builder = Quaint::builder(url) @@ -77,7 +55,7 @@ impl FromSource for Mysql { let connection_info = pool.connection_info().to_owned(); Ok(Mysql { - pool: RuntimePool::Rust(pool), + pool, connection_info, features: features.to_owned(), }) @@ -99,11 +77,7 @@ impl Connector for Mysql { } fn name(&self) -> &'static str { - if self.pool.is_nodejs() { - "@prisma/mysql" - } else { - "mysql" - } + "mysql" } fn should_retry_on_transient_error(&self) -> bool { diff --git a/query-engine/connectors/sql-query-connector/src/database/runtime.rs b/query-engine/connectors/sql-query-connector/src/database/runtime.rs deleted file mode 100644 index 5aa8b14a666b..000000000000 --- a/query-engine/connectors/sql-query-connector/src/database/runtime.rs +++ /dev/null @@ -1,157 +0,0 @@ -use crate::SqlError; -use async_trait::async_trait; -use quaint::{ - connector::IsolationLevel, - pooled::{PooledConnection, Quaint}, - prelude::{Query, Queryable, TransactionCapable}, - Value, -}; - -#[cfg(feature = "js-connectors")] -type QueryableRef = std::sync::Arc; - -pub enum RuntimePool { - Rust(Quaint), - - #[cfg(feature = "js-connectors")] - Js(QueryableRef), -} - -impl RuntimePool { - pub fn is_nodejs(&self) -> bool { - match self { - Self::Rust(_) => false, - - #[cfg(feature = "js-connectors")] - Self::Js(_) => true, - } - } - - /// Reserve a connection from the pool - pub async fn check_out(&self) -> crate::Result { - match self { - Self::Rust(pool) => { - let conn: PooledConnection = pool.check_out().await.map_err(SqlError::from)?; - Ok(RuntimeConnection::Rust(conn)) - } - #[cfg(feature = "js-connectors")] - Self::Js(queryable) => Ok(RuntimeConnection::Js(queryable.clone())), - } - } -} - -pub enum RuntimeConnection { - Rust(PooledConnection), - - #[cfg(feature = "js-connectors")] - Js(QueryableRef), -} - -#[async_trait] -impl Queryable for RuntimeConnection { - async fn query(&self, q: Query<'_>) -> quaint::Result { - match self { - Self::Rust(conn) => conn.query(q).await, - - #[cfg(feature = "js-connectors")] - Self::Js(conn) => conn.query(q).await, - } - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { - match self { - Self::Rust(conn) => conn.query_raw(sql, params).await, - - #[cfg(feature = "js-connectors")] - Self::Js(conn) => conn.query_raw(sql, params).await, - } - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { - match self { - Self::Rust(conn) => conn.query_raw_typed(sql, params).await, - - #[cfg(feature = "js-connectors")] - Self::Js(conn) => conn.query_raw_typed(sql, params).await, - } - } - - async fn execute(&self, q: Query<'_>) -> quaint::Result { - match self { - Self::Rust(conn) => conn.execute(q).await, - - #[cfg(feature = "js-connectors")] - Self::Js(conn) => conn.execute(q).await, - } - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { - match self { - Self::Rust(conn) => conn.execute_raw(sql, params).await, - - #[cfg(feature = "js-connectors")] - Self::Js(conn) => conn.execute_raw(sql, params).await, - } - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { - match self { - Self::Rust(conn) => conn.execute_raw_typed(sql, params).await, - - #[cfg(feature = "js-connectors")] - Self::Js(conn) => conn.execute_raw_typed(sql, params).await, - } - } - - /// Run a command in the database, for queries that can't be run using - /// prepared statements. - async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { - match self { - Self::Rust(conn) => conn.raw_cmd(cmd).await, - - #[cfg(feature = "js-connectors")] - Self::Js(conn) => conn.raw_cmd(cmd).await, - } - } - - async fn version(&self) -> quaint::Result> { - match self { - Self::Rust(conn) => conn.version().await, - - #[cfg(feature = "js-connectors")] - Self::Js(conn) => conn.version().await, - } - } - - fn is_healthy(&self) -> bool { - match self { - Self::Rust(conn) => conn.is_healthy(), - - #[cfg(feature = "js-connectors")] - Self::Js(conn) => conn.is_healthy(), - } - } - - /// Sets the transaction isolation level to given value. - /// Implementers have to make sure that the passed isolation level is valid for the underlying database. - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { - match self { - Self::Rust(conn) => conn.set_tx_isolation_level(isolation_level).await, - - #[cfg(feature = "js-connectors")] - Self::Js(conn) => conn.set_tx_isolation_level(isolation_level).await, - } - } - - /// Signals if the isolation level SET needs to happen before or after the tx BEGIN. - fn requires_isolation_first(&self) -> bool { - match self { - Self::Rust(conn) => conn.requires_isolation_first(), - - #[cfg(feature = "js-connectors")] - Self::Js(conn) => conn.requires_isolation_first(), - } - } -} - -impl TransactionCapable for RuntimeConnection {} diff --git a/query-engine/connectors/sql-query-connector/src/lib.rs b/query-engine/connectors/sql-query-connector/src/lib.rs index bb728c85c6c9..fae5dff0f9f3 100644 --- a/query-engine/connectors/sql-query-connector/src/lib.rs +++ b/query-engine/connectors/sql-query-connector/src/lib.rs @@ -23,7 +23,7 @@ use self::{column_metadata::*, context::Context, filter_conversion::*, query_ext use quaint::prelude::Queryable; #[cfg(feature = "js-connectors")] -pub use database::js::register_driver; +pub use database::{register_js_connector, Js}; pub use database::{FromSource, Mssql, Mysql, PostgreSql, Sqlite}; pub use error::SqlError; diff --git a/query-engine/query-engine-node-api/Cargo.toml b/query-engine/query-engine-node-api/Cargo.toml index 5f19897c2660..3031c20f8d82 100644 --- a/query-engine/query-engine-node-api/Cargo.toml +++ b/query-engine/query-engine-node-api/Cargo.toml @@ -11,7 +11,7 @@ name = "query_engine" [features] default = ["js-connectors"] vendored-openssl = ["sql-connector/vendored-openssl"] -js-connectors = ["sql-connector/js-connectors"] +js-connectors = ["request-handlers/js-connectors", "sql-connector/js-connectors"] [dependencies] anyhow = "1" diff --git a/query-engine/query-engine-node-api/src/engine.rs b/query-engine/query-engine-node-api/src/engine.rs index 7d9f06824977..11f8590073df 100644 --- a/query-engine/query-engine-node-api/src/engine.rs +++ b/query-engine/query-engine-node-api/src/engine.rs @@ -151,12 +151,6 @@ impl QueryEngine { let log_callback = LogCallback::new(napi_env, callback)?; log_callback.unref(&napi_env)?; - #[cfg(feature = "js-connectors")] - if let Some(driver) = maybe_driver { - let queryable = js_connectors::JsQueryable::from(driver); - sql_connector::register_driver(Arc::new(queryable)); - } - let ConstructorOptions { datamodel, log_level, @@ -173,6 +167,12 @@ impl QueryEngine { let mut schema = psl::validate(datamodel.into()); let config = &mut schema.configuration; + #[cfg(feature = "js-connectors")] + if let Some(driver) = maybe_driver { + let queryable = js_connectors::JsQueryable::from(driver); + sql_connector::register_js_connector(Arc::new(queryable)); + } + schema .diagnostics .to_result() diff --git a/query-engine/request-handlers/Cargo.toml b/query-engine/request-handlers/Cargo.toml index 1f0436daaa6a..41568f126bea 100644 --- a/query-engine/request-handlers/Cargo.toml +++ b/query-engine/request-handlers/Cargo.toml @@ -35,6 +35,7 @@ codspeed-criterion-compat = "1.1.0" default = ["mongodb", "sql"] mongodb = ["mongodb-query-connector"] sql = ["sql-query-connector"] +js-connectors = ["sql-query-connector"] [[bench]] name = "query_planning_bench" diff --git a/query-engine/request-handlers/src/load_executor.rs b/query-engine/request-handlers/src/load_executor.rs index 8916669e8db4..8232e907e60f 100644 --- a/query-engine/request-handlers/src/load_executor.rs +++ b/query-engine/request-handlers/src/load_executor.rs @@ -24,6 +24,9 @@ pub async fn load( #[cfg(feature = "mongodb")] p if MONGODB.is_provider(p) => mongodb(source, url, features).await, + #[cfg(feature = "js-connectors")] + p if JsConnector::is_provider(p) => jsconnector(source, url, features).await, + x => Err(query_core::CoreError::ConfigurationError(format!( "Unsupported connector type: {x}" ))), @@ -38,7 +41,7 @@ async fn sqlite( trace!("Loading SQLite query connector..."); let sqlite = Sqlite::from_source(source, url, features).await?; trace!("Loaded SQLite query connector."); - Ok(sql_executor(sqlite, false)) + Ok(executor_for(sqlite, false)) } async fn postgres( @@ -47,7 +50,6 @@ async fn postgres( features: PreviewFeatures, ) -> query_core::Result> { trace!("Loading Postgres query connector..."); - let database_str = url; let psql = PostgreSql::from_source(source, url, features).await?; @@ -59,9 +61,8 @@ async fn postgres( .get("pgbouncer") .and_then(|flag| flag.parse().ok()) .unwrap_or(false); - trace!("Loaded Postgres query connector."); - Ok(sql_executor(psql, force_transactions)) + Ok(executor_for(psql, force_transactions)) } async fn mysql( @@ -71,7 +72,7 @@ async fn mysql( ) -> query_core::Result> { let mysql = Mysql::from_source(source, url, features).await?; trace!("Loaded MySQL query connector."); - Ok(sql_executor(mysql, false)) + Ok(executor_for(mysql, false)) } async fn mssql( @@ -82,10 +83,10 @@ async fn mssql( trace!("Loading SQL Server query connector..."); let mssql = Mssql::from_source(source, url, features).await?; trace!("Loaded SQL Server query connector."); - Ok(sql_executor(mssql, false)) + Ok(executor_for(mssql, false)) } -fn sql_executor(connector: T, force_transactions: bool) -> Box +fn executor_for(connector: T, force_transactions: bool) -> Box where T: Connector + Send + Sync + 'static, { @@ -101,5 +102,17 @@ async fn mongodb( trace!("Loading MongoDB query connector..."); let mongo = MongoDb::new(source, url).await?; trace!("Loaded MongoDB query connector."); - Ok(Box::new(InterpretingExecutor::new(mongo, false))) + Ok(executor_for(mongo, false)) +} + +#[cfg(feature = "js-connectors")] +async fn jsconnector( + source: &Datasource, + url: &str, + features: PreviewFeatures, +) -> Result, query_core::CoreError> { + trace!("Loading js connector ..."); + let js = Js::from_source(source, url, features).await?; + trace!("Loaded js connector ..."); + Ok(executor_for(js, false)) }