diff --git a/editoast/editoast_models/src/db_connection_pool.rs b/editoast/editoast_models/src/db_connection_pool.rs index 44760eed997..ae651012313 100644 --- a/editoast/editoast_models/src/db_connection_pool.rs +++ b/editoast/editoast_models/src/db_connection_pool.rs @@ -27,9 +27,6 @@ use url::Url; use tokio::sync::OwnedRwLockWriteGuard; use tokio::sync::RwLock; -use super::DbConnectionPool; -use super::DieselConnection; - pub type DbConnectionConfig = AsyncDieselConnectionManager<AsyncPgConnection>; #[derive(Clone)] @@ -110,7 +107,7 @@ impl DerefMut for WriteHandle { /// /// # Testing pool /// -/// In test mode, the [DbConnectionPool::get] function will always return the same connection that has +/// In test mode, the [Pool::<AsyncPgConnection>::get] function will always return the same connection that has /// been setup to drop all modification once the test ends. /// Since this connection will not commit any changes to the database, we ensure the isolation of each test. /// @@ -423,17 +420,17 @@ pub async fn ping_database(conn: &mut DbConnection) -> Result<(), PingError> { Ok(()) } -pub fn create_connection_pool( +fn create_connection_pool( url: Url, max_size: usize, -) -> Result<DbConnectionPool, DatabasePoolBuildError> { +) -> Result<Pool<AsyncPgConnection>, DatabasePoolBuildError> { let mut manager_config = ManagerConfig::default(); manager_config.custom_setup = Box::new(establish_connection); let manager = DbConnectionConfig::new_with_config(url, manager_config); Ok(Pool::builder(manager).max_size(max_size).build()?) } -fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<DieselConnection>> { +fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConnection>> { let fut = async { let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap(); connector_builder.set_verify(SslVerifyMode::NONE); @@ -448,7 +445,7 @@ fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<DieselConnec tracing::error!("connection error: {}", e); } }); - DieselConnection::try_from(client).await + AsyncPgConnection::try_from(client).await }; fut.boxed() } diff --git a/editoast/editoast_models/src/lib.rs b/editoast/editoast_models/src/lib.rs index 079c7f87bed..92df8442939 100644 --- a/editoast/editoast_models/src/lib.rs +++ b/editoast/editoast_models/src/lib.rs @@ -1,18 +1,12 @@ -use diesel_async::pooled_connection::deadpool::Pool; -use diesel_async::AsyncPgConnection; - pub mod db_connection_pool; pub mod tables; pub use db_connection_pool::DbConnection; pub use db_connection_pool::DbConnectionPoolV2; -type DieselConnection = AsyncPgConnection; -pub type DbConnectionPool = Pool<DieselConnection>; - /// Generic error type to forward errors from the database /// /// Useful for functions which only points of failure are the DB calls. #[derive(Debug, thiserror::Error)] -#[error("an error occured while querying the database: {0}")] +#[error("an error occurred while querying the database: {0}")] pub struct DatabaseError(#[from] diesel::result::Error); diff --git a/editoast/openapi.yaml b/editoast/openapi.yaml index 53f213dafad..29362fc945f 100644 --- a/editoast/openapi.yaml +++ b/editoast/openapi.yaml @@ -10072,7 +10072,7 @@ components: enum: - pathfinding_not_found - type: object - description: An error has occured during pathfinding + description: An error has occurred during pathfinding required: - core_error - status @@ -10084,7 +10084,7 @@ components: enum: - pathfinding_failure - type: object - description: An error has occured during computing + description: An error has occurred during computing required: - error_type - status diff --git a/editoast/src/models/fixtures.rs b/editoast/src/models/fixtures.rs index 2f32e11255d..5a93f662aa3 100644 --- a/editoast/src/models/fixtures.rs +++ b/editoast/src/models/fixtures.rs @@ -4,7 +4,7 @@ use std::ops::DerefMut; use chrono::Utc; use editoast_models::DbConnection; -use editoast_models::DbConnectionPool; + use editoast_models::DbConnectionPoolV2; use editoast_schemas::infra::Direction; use editoast_schemas::infra::DirectionalTrackRange; diff --git a/editoast/src/views/infra/attached.rs b/editoast/src/views/infra/attached.rs index 6595a55bc4a..ae0184eb32c 100644 --- a/editoast/src/views/infra/attached.rs +++ b/editoast/src/views/infra/attached.rs @@ -67,7 +67,7 @@ async fn attached( Path(InfraAttachedParams { infra_id, track_id }): Path<InfraAttachedParams>, State(AppState { infra_caches, - db_pool_v2: db_pool, + db_pool, .. }): State<AppState>, Extension(auth): AuthenticationExt, diff --git a/editoast/src/views/infra/auto_fixes/mod.rs b/editoast/src/views/infra/auto_fixes/mod.rs index b62802b87e7..cb9cd2649ab 100644 --- a/editoast/src/views/infra/auto_fixes/mod.rs +++ b/editoast/src/views/infra/auto_fixes/mod.rs @@ -87,7 +87,7 @@ async fn list_auto_fixes( Path(infra_id): Path<i64>, State(AppState { infra_caches, - db_pool_v2: db_pool, + db_pool, .. }): State<AppState>, Extension(auth): AuthenticationExt, diff --git a/editoast/src/views/infra/delimited_area.rs b/editoast/src/views/infra/delimited_area.rs index 0ac51ae9b87..7ef6104afc7 100644 --- a/editoast/src/views/infra/delimited_area.rs +++ b/editoast/src/views/infra/delimited_area.rs @@ -129,7 +129,7 @@ async fn delimited_area( Extension(auth): AuthenticationExt, State(AppState { infra_caches, - db_pool_v2: db_pool, + db_pool, .. }): State<AppState>, Path(InfraIdParam { infra_id }): Path<InfraIdParam>, diff --git a/editoast/src/views/infra/edition.rs b/editoast/src/views/infra/edition.rs index fc58345f4bf..241ce66917b 100644 --- a/editoast/src/views/infra/edition.rs +++ b/editoast/src/views/infra/edition.rs @@ -71,7 +71,7 @@ crate::routes! { async fn edit<'a>( Path(InfraIdParam { infra_id }): Path<InfraIdParam>, State(AppState { - db_pool_v2: db_pool, + db_pool, infra_caches, valkey, map_layers, @@ -126,7 +126,7 @@ async fn edit<'a>( pub async fn split_track_section<'a>( Path(InfraIdParam { infra_id }): Path<InfraIdParam>, State(AppState { - db_pool_v2: db_pool, + db_pool, infra_caches, valkey, map_layers, diff --git a/editoast/src/views/infra/lines.rs b/editoast/src/views/infra/lines.rs index ecf9ce840a4..5f5d89f57c3 100644 --- a/editoast/src/views/infra/lines.rs +++ b/editoast/src/views/infra/lines.rs @@ -45,7 +45,7 @@ async fn get_line_bbox( Path((infra_id, line_code)): Path<(i64, i64)>, State(AppState { infra_caches, - db_pool_v2: db_pool, + db_pool, .. }): State<AppState>, Extension(auth): AuthenticationExt, diff --git a/editoast/src/views/infra/mod.rs b/editoast/src/views/infra/mod.rs index 442886c72c6..dc64f075337 100644 --- a/editoast/src/views/infra/mod.rs +++ b/editoast/src/views/infra/mod.rs @@ -123,7 +123,13 @@ struct RefreshResponse { ) )] async fn refresh( - app_state: State<AppState>, + State(AppState { + db_pool, + valkey: valkey_client, + infra_caches, + map_layers, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, Query(query_params): Query<RefreshQueryParams>, ) -> Result<Json<RefreshResponse>> { @@ -135,11 +141,6 @@ async fn refresh( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); - let valkey_client = app_state.valkey.clone(); - let infra_caches = app_state.infra_caches.clone(); - let map_layers = app_state.map_layers.clone(); - // Use a transaction to give scope to infra list lock let RefreshQueryParams { force, @@ -160,7 +161,6 @@ async fn refresh( }; // Refresh each infras - let db_pool = db_pool; let mut infra_refreshed = vec![]; for mut infra in infras_list { @@ -201,9 +201,13 @@ struct InfraListResponse { ), )] async fn list( - app_state: State<AppState>, + State(AppState { + db_pool, + osrdyne_client, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, - pagination_params: Query<PaginationQueryParams>, + Query(pagination_params): Query<PaginationQueryParams>, ) -> Result<Json<InfraListResponse>> { let authorized = auth .check_roles([BuiltinRole::InfraRead].into()) @@ -212,8 +216,6 @@ async fn list( if !authorized { return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); - let osrdyne_client = app_state.osrdyne_client.clone(); let settings = pagination_params .validate(1000)? @@ -295,7 +297,11 @@ struct InfraIdParam { ), )] async fn get( - app_state: State<AppState>, + State(AppState { + db_pool, + osrdyne_client, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, Path(infra): Path<InfraIdParam>, ) -> Result<Json<InfraWithState>> { @@ -307,9 +313,6 @@ async fn get( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); - let osrdyne_client = app_state.osrdyne_client.clone(); - let infra_id = infra.infra_id; let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, infra_id, || { InfraApiError::NotFound { infra_id } @@ -344,7 +347,7 @@ impl From<InfraCreateForm> for Changeset<Infra> { ), )] async fn create( - db_pool: State<DbConnectionPoolV2>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, Json(infra_form): Json<InfraCreateForm>, ) -> Result<impl IntoResponse> { @@ -381,7 +384,7 @@ struct CloneQuery { async fn clone( Extension(auth): AuthenticationExt, Path(params): Path<InfraIdParam>, - db_pool: State<DbConnectionPoolV2>, + State(db_pool): State<DbConnectionPoolV2>, Query(CloneQuery { name }): Query<CloneQuery>, ) -> Result<Json<i64>> { let authorized = auth @@ -421,9 +424,13 @@ async fn clone( ), )] async fn delete( - app_state: State<AppState>, + State(AppState { + db_pool, + infra_caches, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, - infra: Path<InfraIdParam>, + Path(InfraIdParam { infra_id }): Path<InfraIdParam>, ) -> Result<impl IntoResponse> { let authorized = auth .check_roles([BuiltinRole::InfraWrite].into()) @@ -433,9 +440,6 @@ async fn delete( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); - let infra_caches = app_state.infra_caches.clone(); - let infra_id = infra.infra_id; if Infra::fast_delete_static(db_pool.get().await?, infra_id).await? { infra_caches.remove(&infra_id); Ok(StatusCode::NO_CONTENT) @@ -468,7 +472,7 @@ impl From<InfraPatchForm> for Changeset<Infra> { ), )] async fn put( - db_pool: State<DbConnectionPoolV2>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, Path(infra): Path<i64>, Json(patch): Json<InfraPatchForm>, @@ -501,7 +505,11 @@ async fn put( ) )] async fn get_switch_types( - app_state: State<AppState>, + State(AppState { + db_pool, + infra_caches, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, Path(infra): Path<InfraIdParam>, ) -> Result<Json<Vec<SwitchType>>> { @@ -513,9 +521,7 @@ async fn get_switch_types( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); let conn = &mut db_pool.get().await?; - let infra_caches = app_state.infra_caches.clone(); let infra = Infra::retrieve_or_fail(conn, infra.infra_id, || InfraApiError::NotFound { infra_id: infra.infra_id, @@ -546,7 +552,7 @@ async fn get_switch_types( async fn get_speed_limit_tags( Extension(auth): AuthenticationExt, Path(infra): Path<InfraIdParam>, - db_pool: State<DbConnectionPoolV2>, + State(db_pool): State<DbConnectionPoolV2>, ) -> Result<Json<Vec<String>>> { let authorized = auth .check_roles([BuiltinRole::InfraRead].into()) @@ -590,7 +596,7 @@ async fn get_voltages( Extension(auth): AuthenticationExt, Path(infra): Path<InfraIdParam>, Query(param): Query<GetVoltagesQueryParams>, - db_pool: State<DbConnectionPoolV2>, + State(db_pool): State<DbConnectionPoolV2>, ) -> Result<Json<Vec<String>>> { let authorized = auth .check_roles([BuiltinRole::InfraRead].into()) @@ -623,7 +629,7 @@ async fn get_voltages( ) )] async fn get_all_voltages( - db_pool: State<DbConnectionPoolV2>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, ) -> Result<Json<Vec<String>>> { let authorized = auth @@ -712,7 +718,11 @@ async fn unlock( ) )] async fn load( - app_state: State<AppState>, + State(AppState { + db_pool, + core_client, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, Path(path): Path<InfraIdParam>, ) -> Result<impl IntoResponse> { @@ -724,9 +734,6 @@ async fn load( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); - let core_client = app_state.core_client.clone(); - let infra_id = path.infra_id; let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, infra_id, || { InfraApiError::NotFound { infra_id } diff --git a/editoast/src/views/infra/pathfinding.rs b/editoast/src/views/infra/pathfinding.rs index f727541d07c..b4a77b80a0d 100644 --- a/editoast/src/views/infra/pathfinding.rs +++ b/editoast/src/views/infra/pathfinding.rs @@ -96,7 +96,11 @@ struct QueryParam { ) )] async fn pathfinding_view( - app_state: State<AppState>, + State(AppState { + db_pool, + infra_caches, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, Path(infra): Path<InfraIdParam>, Query(params): Query<QueryParam>, @@ -110,9 +114,6 @@ async fn pathfinding_view( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); - let infra_caches = app_state.infra_caches.clone(); - // Parse and check input let infra_id = infra.infra_id; let number = params.number.unwrap_or(DEFAULT_NUMBER_OF_PATHS); diff --git a/editoast/src/views/infra/railjson.rs b/editoast/src/views/infra/railjson.rs index 607c2bd7d55..ad859718765 100644 --- a/editoast/src/views/infra/railjson.rs +++ b/editoast/src/views/infra/railjson.rs @@ -55,7 +55,7 @@ enum ListErrorsRailjson { )] async fn get_railjson( Path(infra): Path<InfraIdParam>, - db_pool: State<DbConnectionPoolV2>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, ) -> Result<impl IntoResponse> { let authorized = auth @@ -168,7 +168,11 @@ struct PostRailjsonResponse { ) )] async fn post_railjson( - app_state: State<AppState>, + State(AppState { + db_pool, + infra_caches, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, Query(params): Query<PostRailjsonQueryParams>, Json(railjson): Json<RailJson>, @@ -181,8 +185,6 @@ async fn post_railjson( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); - let infra_caches = app_state.infra_caches.clone(); if railjson.version != RAILJSON_VERSION { return Err(ListErrorsRailjson::WrongRailjsonVersionProvided.into()); } diff --git a/editoast/src/views/infra/routes.rs b/editoast/src/views/infra/routes.rs index 91769ed92fd..22c7b52ac2c 100644 --- a/editoast/src/views/infra/routes.rs +++ b/editoast/src/views/infra/routes.rs @@ -67,7 +67,7 @@ struct RoutesResponse { )] async fn get_routes_from_waypoint( Path(path): Path<RoutesFromWaypointParams>, - db_pool: State<DbConnectionPoolV2>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, ) -> Result<Json<RoutesResponse>> { let authorized = auth @@ -145,7 +145,11 @@ struct RoutesFromNodesPositions { ), )] async fn get_routes_track_ranges( - app_state: State<AppState>, + State(AppState { + db_pool, + infra_caches, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, Path(infra): Path<i64>, Query(params): Query<RouteTrackRangesParams>, @@ -158,8 +162,8 @@ async fn get_routes_track_ranges( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); - let infra_caches = app_state.infra_caches.clone(); + let db_pool = db_pool.clone(); + let infra_caches = infra_caches.clone(); let infra_id = infra; let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, infra_id, || { InfraApiError::NotFound { infra_id } @@ -206,7 +210,11 @@ async fn get_routes_track_ranges( ), )] async fn get_routes_nodes( - app_state: State<AppState>, + State(AppState { + db_pool, + infra_caches, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, Path(params): Path<InfraIdParam>, Json(node_states): Json<HashMap<String, Option<String>>>, @@ -219,9 +227,6 @@ async fn get_routes_nodes( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); - let infra_caches = app_state.infra_caches.clone(); - let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, params.infra_id, || { InfraApiError::NotFound { infra_id: params.infra_id, diff --git a/editoast/src/views/layers.rs b/editoast/src/views/layers.rs index 78a031822a7..5b9bcf88ed4 100644 --- a/editoast/src/views/layers.rs +++ b/editoast/src/views/layers.rs @@ -179,7 +179,7 @@ struct TileParams { async fn cache_and_get_mvt_tile( State(AppState { map_layers, - db_pool_v2: db_pool, + db_pool, valkey, .. }): State<AppState>, diff --git a/editoast/src/views/mod.rs b/editoast/src/views/mod.rs index 522c23f9e71..6f68bcefc5a 100644 --- a/editoast/src/views/mod.rs +++ b/editoast/src/views/mod.rs @@ -41,7 +41,7 @@ use dashmap::DashMap; use editoast_authz::authorizer::Authorizer; use editoast_authz::authorizer::UserInfo; use editoast_authz::BuiltinRole; -use editoast_models::DbConnectionPool; + use editoast_osrdyne_client::OsrdyneClient; use futures::TryFutureExt; pub use openapi::OpenApiRoot; @@ -225,7 +225,7 @@ async fn authenticate( async fn authentication_middleware( State(AppState { - db_pool_v2: db_pool, + db_pool, disable_authorization, .. }): State<AppState>, @@ -278,7 +278,7 @@ pub enum AppHealthError { )] async fn health( State(AppState { - db_pool_v2: db_pool, + db_pool, valkey, health_check_timeout, core_client, @@ -334,8 +334,7 @@ async fn version() -> Json<Version> { (status = 200, description = "Return the core service version", body = Version), ), )] -async fn core_version(app_state: State<AppState>) -> Json<Version> { - let core = app_state.core_client.clone(); +async fn core_version(State(core): State<Arc<CoreClient>>) -> Json<Version> { let response = CoreVersionRequest {}.fetch(&core).await; let response = response.unwrap_or(Version { git_describe: None }); Json(response) @@ -384,8 +383,7 @@ pub struct Server { pub struct AppState { pub config: Arc<ServerConfig>, - pub db_pool_v1: Arc<DbConnectionPool>, - pub db_pool_v2: Arc<DbConnectionPoolV2>, + pub db_pool: Arc<DbConnectionPoolV2>, pub valkey: Arc<ValkeyClient>, pub infra_caches: Arc<DashMap<i64, InfraCache>>, pub map_layers: Arc<MapLayers>, @@ -398,7 +396,13 @@ pub struct AppState { impl FromRef<AppState> for DbConnectionPoolV2 { fn from_ref(input: &AppState) -> Self { - (*input.db_pool_v2).clone() + (*input.db_pool).clone() + } +} + +impl FromRef<AppState> for Arc<CoreClient> { + fn from_ref(input: &AppState) -> Self { + input.core_client.clone() } } @@ -409,16 +413,15 @@ impl AppState { // Config database let valkey = ValkeyClient::new(config.valkey_config.clone())?.into(); - // Create both database pools - let db_pool_v2 = { + // Create database pool + let db_pool = { let PostgresConfig { database_url, pool_size, } = config.postgres_config.clone(); - DbConnectionPoolV2::try_initialize(database_url, pool_size).await? + let pool = DbConnectionPoolV2::try_initialize(database_url, pool_size).await?; + Arc::new(pool) }; - let db_pool_v1 = db_pool_v2.pool_v1(); - let db_pool_v2 = Arc::new(db_pool_v2); // Setup infra cache map let infra_caches = DashMap::<i64, InfraCache>::default().into(); @@ -449,8 +452,7 @@ impl AppState { Ok(Self { valkey, - db_pool_v1, - db_pool_v2, + db_pool, infra_caches, core_client, osrdyne_client, diff --git a/editoast/src/views/path/pathfinding.rs b/editoast/src/views/path/pathfinding.rs index 7ecb6f5417e..220566072ef 100644 --- a/editoast/src/views/path/pathfinding.rs +++ b/editoast/src/views/path/pathfinding.rs @@ -158,7 +158,7 @@ pub enum PathfindingFailure { )] async fn post( State(AppState { - db_pool_v2: db_pool, + db_pool, valkey, core_client, .. diff --git a/editoast/src/views/path/properties.rs b/editoast/src/views/path/properties.rs index 4a32b467231..ca5b0aaff1c 100644 --- a/editoast/src/views/path/properties.rs +++ b/editoast/src/views/path/properties.rs @@ -164,7 +164,7 @@ type Properties = EnumSet<Property>; )] async fn post( State(AppState { - db_pool_v2: db_pool, + db_pool, valkey, core_client, .. diff --git a/editoast/src/views/stdcm_search_environment.rs b/editoast/src/views/stdcm_search_environment.rs index cfd94a4d604..e4d6db80f34 100644 --- a/editoast/src/views/stdcm_search_environment.rs +++ b/editoast/src/views/stdcm_search_environment.rs @@ -6,6 +6,7 @@ use axum::response::Response; use axum::Extension; use chrono::NaiveDateTime; use editoast_authz::BuiltinRole; +use editoast_models::DbConnectionPoolV2; use serde::de::Error as SerdeError; use serde::Deserialize; use std::result::Result as StdResult; @@ -19,7 +20,6 @@ use crate::models::stdcm_search_environment::StdcmSearchEnvironment; use crate::models::Changeset; use crate::views::AuthenticationExt; use crate::views::AuthorizationError; -use crate::AppState; use crate::Model; crate::routes! { @@ -106,7 +106,7 @@ impl From<StdcmSearchEnvironmentCreateForm> for Changeset<StdcmSearchEnvironment ) )] async fn overwrite( - State(app_state): State<AppState>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, Json(form): Json<StdcmSearchEnvironmentCreateForm>, ) -> Result<impl IntoResponse> { @@ -118,11 +118,8 @@ async fn overwrite( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); let conn = &mut db_pool.get().await?; - let changeset: Changeset<StdcmSearchEnvironment> = form.into(); - Ok((StatusCode::CREATED, Json(changeset.overwrite(conn).await?))) } @@ -135,7 +132,7 @@ async fn overwrite( ) )] async fn retrieve_latest( - State(app_state): State<AppState>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, ) -> Result<Response> { let authorized = auth @@ -146,9 +143,7 @@ async fn retrieve_latest( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); let conn = &mut db_pool.get().await?; - let search_env = StdcmSearchEnvironment::retrieve_latest(conn).await; if let Some(search_env) = search_env { Ok(Json(search_env).into_response()) diff --git a/editoast/src/views/test_app.rs b/editoast/src/views/test_app.rs index fb1f0372810..9ab47d93081 100644 --- a/editoast/src/views/test_app.rs +++ b/editoast/src/views/test_app.rs @@ -7,7 +7,6 @@ use std::sync::Arc; use axum::Router; use axum_tracing_opentelemetry::middleware::OtelAxumLayer; use dashmap::DashMap; -use editoast_models::db_connection_pool::create_connection_pool; use editoast_models::DbConnectionPoolV2; use editoast_osrdyne_client::OsrdyneClient; use serde::de::DeserializeOwned; @@ -40,7 +39,6 @@ pub(crate) struct TestAppBuilder { db_pool: Option<DbConnectionPoolV2>, core_client: Option<CoreClient>, osrdyne_client: Option<OsrdyneClient>, - db_pool_v1: bool, } impl TestAppBuilder { @@ -49,13 +47,11 @@ impl TestAppBuilder { db_pool: None, core_client: None, osrdyne_client: None, - db_pool_v1: false, } } pub fn db_pool(mut self, db_pool: DbConnectionPoolV2) -> Self { assert!(self.db_pool.is_none()); - assert!(!self.db_pool_v1); self.db_pool = Some(db_pool); self } @@ -126,24 +122,10 @@ impl TestAppBuilder { .expect("Could not build Valkey client") .into(); - // Create both database pools - let (db_pool_v2, db_pool_v1) = if self.db_pool_v1 { - let PostgresConfig { - database_url, - pool_size, - } = config.postgres_config.clone(); - let pool = create_connection_pool(database_url, pool_size) - .expect("could not create connection pool for tests"); - let v1 = Arc::new(pool); - let v2 = futures::executor::block_on(DbConnectionPoolV2::from_pool(v1.clone())); - (Arc::new(v2), v1) - } else { - let db_pool_v2 = self.db_pool.expect( - "No database pool provided to TestAppBuilder, use Default or provide a database pool" - ); - let db_pool_v1 = db_pool_v2.pool_v1(); - (Arc::new(db_pool_v2), db_pool_v1) - }; + // Create database pool + let db_pool_v2 = Arc::new(self.db_pool.expect( + "No database pool provided to TestAppBuilder, use Default or provide a database pool", + )); // Setup infra cache map let infra_caches = DashMap::<i64, InfraCache>::default().into(); @@ -163,8 +145,7 @@ impl TestAppBuilder { let osrdyne_client = Arc::new(osrdyne_client); let app_state = AppState { - db_pool_v1, - db_pool_v2: db_pool_v2.clone(), + db_pool: db_pool_v2.clone(), core_client: core_client.clone(), osrdyne_client, valkey, diff --git a/editoast/src/views/timetable.rs b/editoast/src/views/timetable.rs index 8b278f6111d..3b888627a9e 100644 --- a/editoast/src/views/timetable.rs +++ b/editoast/src/views/timetable.rs @@ -120,9 +120,9 @@ struct TimetableIdParam { ), )] async fn get( - db_pool: State<DbConnectionPoolV2>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, - Path(timetable_id): Path<TimetableIdParam>, + Path(TimetableIdParam { id: timetable_id }): Path<TimetableIdParam>, ) -> Result<Json<TimetableDetailedResult>> { let authorized = auth .check_roles([BuiltinRole::TimetableRead].into()) @@ -132,15 +132,11 @@ async fn get( return Err(AuthorizationError::Unauthorized.into()); } - let timetable_id = timetable_id.id; - // Return the timetable - let conn = &mut db_pool.get().await?; let timetable = TimetableWithTrains::retrieve_or_fail(conn, timetable_id, || { TimetableError::NotFound { timetable_id } }) .await?; - Ok(Json(timetable.into())) } @@ -154,7 +150,7 @@ async fn get( ), )] async fn post( - db_pool: State<DbConnectionPoolV2>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, ) -> Result<Json<TimetableResult>> { let authorized = auth @@ -183,9 +179,9 @@ async fn post( ), )] async fn delete( - db_pool: State<DbConnectionPoolV2>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, - timetable_id: Path<TimetableIdParam>, + Path(TimetableIdParam { id: timetable_id }): Path<TimetableIdParam>, ) -> Result<impl IntoResponse> { let authorized = auth .check_roles([BuiltinRole::TimetableWrite].into()) @@ -195,7 +191,6 @@ async fn delete( return Err(AuthorizationError::Unauthorized.into()); } - let timetable_id = timetable_id.id; let conn = &mut db_pool.get().await?; Timetable::delete_static_or_fail(conn, timetable_id, || TimetableError::NotFound { timetable_id, @@ -215,9 +210,9 @@ async fn delete( ) )] async fn train_schedule( - db_pool: State<DbConnectionPoolV2>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, - Path(timetable_id): Path<TimetableIdParam>, + Path(TimetableIdParam { id: timetable_id }): Path<TimetableIdParam>, Json(train_schedules): Json<Vec<TrainScheduleBase>>, ) -> Result<Json<Vec<TrainScheduleResult>>> { let authorized = auth @@ -230,7 +225,6 @@ async fn train_schedule( let conn = &mut db_pool.get().await?; - let timetable_id = timetable_id.id; TimetableWithTrains::retrieve_or_fail(conn, timetable_id, || TimetableError::NotFound { timetable_id, }) @@ -271,11 +265,18 @@ pub struct ElectricalProfileSetIdQueryParam { ), )] async fn conflicts( - app_state: State<AppState>, + State(AppState { + db_pool, + valkey: valkey_client, + core_client, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, - Path(timetable_id): Path<TimetableIdParam>, - Query(infra_id_query): Query<InfraIdQueryParam>, - Query(electrical_profile_set_id_query): Query<ElectricalProfileSetIdQueryParam>, + Path(TimetableIdParam { id: timetable_id }): Path<TimetableIdParam>, + Query(InfraIdQueryParam { infra_id }): Query<InfraIdQueryParam>, + Query(ElectricalProfileSetIdQueryParam { + electrical_profile_set_id, + }): Query<ElectricalProfileSetIdQueryParam>, ) -> Result<Json<Vec<Conflict>>> { let authorized = auth .check_roles([BuiltinRole::InfraRead, BuiltinRole::TimetableRead].into()) @@ -285,14 +286,6 @@ async fn conflicts( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); - let valkey_client = app_state.valkey.clone(); - let core_client = app_state.core_client.clone(); - - let timetable_id = timetable_id.id; - let infra_id = infra_id_query.infra_id; - let electrical_profile_set_id = electrical_profile_set_id_query.electrical_profile_set_id; - // 1. Retrieve Timetable / Infra / Trains / Simultion let timetable_trains = TimetableWithTrains::retrieve_or_fail(&mut db_pool.get().await?, timetable_id, || { diff --git a/editoast/src/views/timetable/stdcm.rs b/editoast/src/views/timetable/stdcm.rs index 34fe53fb625..1abf7ef1ae3 100644 --- a/editoast/src/views/timetable/stdcm.rs +++ b/editoast/src/views/timetable/stdcm.rs @@ -120,7 +120,12 @@ struct InfraIdQueryParam { ) )] async fn stdcm( - app_state: State<AppState>, + State(AppState { + db_pool, + valkey: valkey_client, + core_client, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, Path(id): Path<i64>, Query(query): Query<InfraIdQueryParam>, @@ -134,11 +139,8 @@ async fn stdcm( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); let conn = &mut db_pool.get().await?; - let valkey_client = app_state.valkey.clone(); - let core_client = app_state.core_client.clone(); let timetable_id = id; let infra_id = query.infra; diff --git a/editoast/src/views/train_schedule.rs b/editoast/src/views/train_schedule.rs index d537e715985..b6aa27a3ab3 100644 --- a/editoast/src/views/train_schedule.rs +++ b/editoast/src/views/train_schedule.rs @@ -167,9 +167,11 @@ impl From<TrainScheduleForm> for TrainScheduleChangeset { ) )] async fn get( - app_state: State<AppState>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, - train_schedule_id: Path<TrainScheduleIdParam>, + Path(TrainScheduleIdParam { + id: train_schedule_id, + }): Path<TrainScheduleIdParam>, ) -> Result<Json<TrainScheduleResult>> { let authorized = auth .check_roles([BuiltinRole::InfraRead, BuiltinRole::TimetableRead].into()) @@ -179,10 +181,7 @@ async fn get( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); - let train_schedule_id = train_schedule_id.id; let conn = &mut db_pool.get().await?; - let train_schedule = TrainSchedule::retrieve_or_fail(conn, train_schedule_id, || { TrainScheduleError::NotFound { train_schedule_id } }) @@ -205,7 +204,7 @@ struct BatchRequest { ) )] async fn get_batch( - app_state: State<AppState>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, Json(BatchRequest { ids: train_ids }): Json<BatchRequest>, ) -> Result<Json<Vec<TrainScheduleResult>>> { @@ -217,7 +216,6 @@ async fn get_batch( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); let conn = &mut db_pool.get().await?; let train_schedules: Vec<TrainSchedule> = TrainSchedule::retrieve_batch_or_fail(conn, train_ids, |missing| { @@ -239,7 +237,7 @@ async fn get_batch( ) )] async fn delete( - app_state: State<AppState>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, Json(BatchRequest { ids: train_ids }): Json<BatchRequest>, ) -> Result<impl IntoResponse> { @@ -251,8 +249,6 @@ async fn delete( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); - use crate::models::DeleteBatch; let conn = &mut db_pool.get().await?; TrainSchedule::delete_batch_or_fail(conn, train_ids, |number| { @@ -274,9 +270,11 @@ async fn delete( ) )] async fn put( - db_pool: State<DbConnectionPoolV2>, + State(db_pool): State<DbConnectionPoolV2>, Extension(auth): AuthenticationExt, - train_schedule_id: Path<TrainScheduleIdParam>, + Path(TrainScheduleIdParam { + id: train_schedule_id, + }): Path<TrainScheduleIdParam>, Json(train_schedule_form): Json<TrainScheduleForm>, ) -> Result<Json<TrainScheduleResult>> { let authorized = auth @@ -288,10 +286,7 @@ async fn put( } let conn = &mut db_pool.get().await?; - - let train_schedule_id = train_schedule_id.id; let ts_changeset: TrainScheduleChangeset = train_schedule_form.into(); - let ts_result = ts_changeset .update_or_fail(conn, train_schedule_id, || TrainScheduleError::NotFound { train_schedule_id, @@ -323,11 +318,20 @@ pub struct ElectricalProfileSetIdQueryParam { ), )] async fn simulation( - app_state: State<AppState>, + State(AppState { + valkey: valkey_client, + core_client, + db_pool, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, - Path(train_schedule_id): Path<TrainScheduleIdParam>, - Query(infra_id_query): Query<InfraIdQueryParam>, - Query(electrical_profile_set_id_query): Query<ElectricalProfileSetIdQueryParam>, + Path(TrainScheduleIdParam { + id: train_schedule_id, + }): Path<TrainScheduleIdParam>, + Query(InfraIdQueryParam { infra_id }): Query<InfraIdQueryParam>, + Query(ElectricalProfileSetIdQueryParam { + electrical_profile_set_id, + }): Query<ElectricalProfileSetIdQueryParam>, ) -> Result<Json<SimulationResponse>> { let authorized = auth .check_roles([BuiltinRole::InfraRead, BuiltinRole::TimetableRead].into()) @@ -337,14 +341,6 @@ async fn simulation( return Err(AuthorizationError::Unauthorized.into()); } - let valkey_client = app_state.valkey.clone(); - let core_client = app_state.core_client.clone(); - let db_pool = app_state.db_pool_v2.clone(); - - let infra_id = infra_id_query.infra_id; - let electrical_profile_set_id = electrical_profile_set_id_query.electrical_profile_set_id; - let train_schedule_id = train_schedule_id.id; - // Retrieve infra or fail let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, infra_id, || { TrainScheduleError::InfraNotFound { infra_id } @@ -638,9 +634,9 @@ enum SimulationSummaryResult { }, /// Pathfinding not found PathfindingNotFound(PathfindingNotFound), - /// An error has occured during pathfinding + /// An error has occurred during pathfinding PathfindingFailure { core_error: InternalError }, - /// An error has occured during computing + /// An error has occurred during computing SimulationFailed { error_type: String }, /// InputError PathfindingInputError(PathfindingInputError), @@ -657,7 +653,12 @@ enum SimulationSummaryResult { ), )] async fn simulation_summary( - app_state: State<AppState>, + State(AppState { + db_pool, + valkey: valkey_client, + core_client: core, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, Json(SimulationBatchForm { infra_id, @@ -673,10 +674,7 @@ async fn simulation_summary( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); let conn = &mut db_pool.get().await?; - let valkey_client = app_state.valkey.clone(); - let core = app_state.core_client.clone(); let infra = Infra::retrieve_or_fail(conn, infra_id, || TrainScheduleError::InfraNotFound { infra_id, @@ -759,7 +757,12 @@ async fn simulation_summary( ) )] async fn get_path( - app_state: State<AppState>, + State(AppState { + db_pool, + valkey: valkey_client, + core_client: core, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, Path(TrainScheduleIdParam { id: train_schedule_id, @@ -774,10 +777,6 @@ async fn get_path( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); - let valkey_client = app_state.valkey.clone(); - let core = app_state.core_client.clone(); - let conn = &mut db_pool.get().await?; let mut valkey_conn = valkey_client.get_connection().await?; diff --git a/editoast/src/views/train_schedule/projection.rs b/editoast/src/views/train_schedule/projection.rs index 46590b1f151..8bfa58abdae 100644 --- a/editoast/src/views/train_schedule/projection.rs +++ b/editoast/src/views/train_schedule/projection.rs @@ -126,7 +126,12 @@ struct CachedProjectPathTrainResult { ), )] async fn project_path( - app_state: State<AppState>, + State(AppState { + db_pool, + valkey: valkey_client, + core_client, + .. + }): State<AppState>, Extension(auth): AuthenticationExt, Json(ProjectPathForm { infra_id, @@ -150,10 +155,6 @@ async fn project_path( return Err(AuthorizationError::Unauthorized.into()); } - let db_pool = app_state.db_pool_v2.clone(); - let valkey_client = app_state.valkey.clone(); - let core_client = app_state.core_client.clone(); - let ProjectPathInput { track_section_ranges: path_track_ranges, routes: path_routes, @@ -287,11 +288,11 @@ async fn project_path( let cached_value = CachedProjectPathTrainResult { space_time_curves: space_time_curves .get(id) - .expect("Space time curves not availabe for train") + .expect("Space time curves not available for train") .clone(), signal_updates: signal_updates .get(id) - .expect("Signal update not availabe for train") + .expect("Signal update not available for train") .clone(), }; hit_cache.insert(*id, cached_value.clone());