Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor improvements of AppState management in editoast #9935

Merged
merged 4 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions editoast/editoast_models/src/db_connection_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ use url::Url;
use tokio::sync::OwnedRwLockWriteGuard;
use tokio::sync::RwLock;

use super::DbConnectionPool;
use super::DieselConnection;

pub type DbConnectionConfig = AsyncDieselConnectionManager<AsyncPgConnection>;

#[derive(Clone)]
Expand Down Expand Up @@ -110,7 +107,7 @@ impl DerefMut for WriteHandle {
///
/// # Testing pool
///
/// In test mode, the [DbConnectionPool::get] function will always return the same connection that has
/// In test mode, the [Pool::<AsyncPgConnection>::get] function will always return the same connection that has
/// been setup to drop all modification once the test ends.
/// Since this connection will not commit any changes to the database, we ensure the isolation of each test.
///
Expand Down Expand Up @@ -423,17 +420,17 @@ pub async fn ping_database(conn: &mut DbConnection) -> Result<(), PingError> {
Ok(())
}

pub fn create_connection_pool(
fn create_connection_pool(
url: Url,
max_size: usize,
) -> Result<DbConnectionPool, DatabasePoolBuildError> {
) -> Result<Pool<AsyncPgConnection>, DatabasePoolBuildError> {
let mut manager_config = ManagerConfig::default();
manager_config.custom_setup = Box::new(establish_connection);
let manager = DbConnectionConfig::new_with_config(url, manager_config);
Ok(Pool::builder(manager).max_size(max_size).build()?)
}

fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<DieselConnection>> {
fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<AsyncPgConnection>> {
let fut = async {
let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
connector_builder.set_verify(SslVerifyMode::NONE);
Expand All @@ -448,7 +445,7 @@ fn establish_connection(config: &str) -> BoxFuture<ConnectionResult<DieselConnec
tracing::error!("connection error: {}", e);
}
});
DieselConnection::try_from(client).await
AsyncPgConnection::try_from(client).await
};
fut.boxed()
}
8 changes: 1 addition & 7 deletions editoast/editoast_models/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
use diesel_async::pooled_connection::deadpool::Pool;
use diesel_async::AsyncPgConnection;

pub mod db_connection_pool;
pub mod tables;

pub use db_connection_pool::DbConnection;
pub use db_connection_pool::DbConnectionPoolV2;

type DieselConnection = AsyncPgConnection;
pub type DbConnectionPool = Pool<DieselConnection>;

/// Generic error type to forward errors from the database
///
/// Useful for functions which only points of failure are the DB calls.
#[derive(Debug, thiserror::Error)]
#[error("an error occured while querying the database: {0}")]
#[error("an error occurred while querying the database: {0}")]
pub struct DatabaseError(#[from] diesel::result::Error);
4 changes: 2 additions & 2 deletions editoast/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10072,7 +10072,7 @@ components:
enum:
- pathfinding_not_found
- type: object
description: An error has occured during pathfinding
description: An error has occurred during pathfinding
required:
- core_error
- status
Expand All @@ -10084,7 +10084,7 @@ components:
enum:
- pathfinding_failure
- type: object
description: An error has occured during computing
description: An error has occurred during computing
required:
- error_type
- status
Expand Down
2 changes: 1 addition & 1 deletion editoast/src/models/fixtures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::ops::DerefMut;

use chrono::Utc;
use editoast_models::DbConnection;
use editoast_models::DbConnectionPool;

use editoast_models::DbConnectionPoolV2;
use editoast_schemas::infra::Direction;
use editoast_schemas::infra::DirectionalTrackRange;
Expand Down
2 changes: 1 addition & 1 deletion editoast/src/views/infra/attached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async fn attached(
Path(InfraAttachedParams { infra_id, track_id }): Path<InfraAttachedParams>,
State(AppState {
infra_caches,
db_pool_v2: db_pool,
db_pool,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Expand Down
2 changes: 1 addition & 1 deletion editoast/src/views/infra/auto_fixes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async fn list_auto_fixes(
Path(infra_id): Path<i64>,
State(AppState {
infra_caches,
db_pool_v2: db_pool,
db_pool,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Expand Down
2 changes: 1 addition & 1 deletion editoast/src/views/infra/delimited_area.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ async fn delimited_area(
Extension(auth): AuthenticationExt,
State(AppState {
infra_caches,
db_pool_v2: db_pool,
db_pool,
..
}): State<AppState>,
Path(InfraIdParam { infra_id }): Path<InfraIdParam>,
Expand Down
4 changes: 2 additions & 2 deletions editoast/src/views/infra/edition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ crate::routes! {
async fn edit<'a>(
Path(InfraIdParam { infra_id }): Path<InfraIdParam>,
State(AppState {
db_pool_v2: db_pool,
db_pool,
infra_caches,
valkey,
map_layers,
Expand Down Expand Up @@ -126,7 +126,7 @@ async fn edit<'a>(
pub async fn split_track_section<'a>(
Path(InfraIdParam { infra_id }): Path<InfraIdParam>,
State(AppState {
db_pool_v2: db_pool,
db_pool,
infra_caches,
valkey,
map_layers,
Expand Down
2 changes: 1 addition & 1 deletion editoast/src/views/infra/lines.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async fn get_line_bbox(
Path((infra_id, line_code)): Path<(i64, i64)>,
State(AppState {
infra_caches,
db_pool_v2: db_pool,
db_pool,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Expand Down
73 changes: 40 additions & 33 deletions editoast/src/views/infra/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,13 @@ struct RefreshResponse {
)
)]
async fn refresh(
app_state: State<AppState>,
State(AppState {
db_pool,
valkey: valkey_client,
infra_caches,
map_layers,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Query(query_params): Query<RefreshQueryParams>,
) -> Result<Json<RefreshResponse>> {
Expand All @@ -135,11 +141,6 @@ async fn refresh(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool_v2.clone();
let valkey_client = app_state.valkey.clone();
let infra_caches = app_state.infra_caches.clone();
let map_layers = app_state.map_layers.clone();

// Use a transaction to give scope to infra list lock
let RefreshQueryParams {
force,
Expand All @@ -160,7 +161,6 @@ async fn refresh(
};

// Refresh each infras
let db_pool = db_pool;
let mut infra_refreshed = vec![];

for mut infra in infras_list {
Expand Down Expand Up @@ -201,9 +201,13 @@ struct InfraListResponse {
),
)]
async fn list(
app_state: State<AppState>,
State(AppState {
db_pool,
osrdyne_client,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
pagination_params: Query<PaginationQueryParams>,
Query(pagination_params): Query<PaginationQueryParams>,
) -> Result<Json<InfraListResponse>> {
let authorized = auth
.check_roles([BuiltinRole::InfraRead].into())
Expand All @@ -212,8 +216,6 @@ async fn list(
if !authorized {
return Err(AuthorizationError::Unauthorized.into());
}
let db_pool = app_state.db_pool_v2.clone();
let osrdyne_client = app_state.osrdyne_client.clone();

let settings = pagination_params
.validate(1000)?
Expand Down Expand Up @@ -295,7 +297,11 @@ struct InfraIdParam {
),
)]
async fn get(
app_state: State<AppState>,
State(AppState {
db_pool,
osrdyne_client,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Path(infra): Path<InfraIdParam>,
) -> Result<Json<InfraWithState>> {
Expand All @@ -307,9 +313,6 @@ async fn get(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool_v2.clone();
let osrdyne_client = app_state.osrdyne_client.clone();

let infra_id = infra.infra_id;
let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, infra_id, || {
InfraApiError::NotFound { infra_id }
Expand Down Expand Up @@ -344,7 +347,7 @@ impl From<InfraCreateForm> for Changeset<Infra> {
),
)]
async fn create(
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
Extension(auth): AuthenticationExt,
Json(infra_form): Json<InfraCreateForm>,
) -> Result<impl IntoResponse> {
Expand Down Expand Up @@ -381,7 +384,7 @@ struct CloneQuery {
async fn clone(
Extension(auth): AuthenticationExt,
Path(params): Path<InfraIdParam>,
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
Query(CloneQuery { name }): Query<CloneQuery>,
) -> Result<Json<i64>> {
let authorized = auth
Expand Down Expand Up @@ -421,9 +424,13 @@ async fn clone(
),
)]
async fn delete(
app_state: State<AppState>,
State(AppState {
db_pool,
infra_caches,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
infra: Path<InfraIdParam>,
Path(InfraIdParam { infra_id }): Path<InfraIdParam>,
) -> Result<impl IntoResponse> {
let authorized = auth
.check_roles([BuiltinRole::InfraWrite].into())
Expand All @@ -433,9 +440,6 @@ async fn delete(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool_v2.clone();
let infra_caches = app_state.infra_caches.clone();
let infra_id = infra.infra_id;
if Infra::fast_delete_static(db_pool.get().await?, infra_id).await? {
infra_caches.remove(&infra_id);
Ok(StatusCode::NO_CONTENT)
Expand Down Expand Up @@ -468,7 +472,7 @@ impl From<InfraPatchForm> for Changeset<Infra> {
),
)]
async fn put(
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
Extension(auth): AuthenticationExt,
Path(infra): Path<i64>,
Json(patch): Json<InfraPatchForm>,
Expand Down Expand Up @@ -501,7 +505,11 @@ async fn put(
)
)]
async fn get_switch_types(
app_state: State<AppState>,
State(AppState {
db_pool,
infra_caches,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Path(infra): Path<InfraIdParam>,
) -> Result<Json<Vec<SwitchType>>> {
Expand All @@ -513,9 +521,7 @@ async fn get_switch_types(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool_v2.clone();
let conn = &mut db_pool.get().await?;
let infra_caches = app_state.infra_caches.clone();

let infra = Infra::retrieve_or_fail(conn, infra.infra_id, || InfraApiError::NotFound {
infra_id: infra.infra_id,
Expand Down Expand Up @@ -546,7 +552,7 @@ async fn get_switch_types(
async fn get_speed_limit_tags(
Extension(auth): AuthenticationExt,
Path(infra): Path<InfraIdParam>,
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
) -> Result<Json<Vec<String>>> {
let authorized = auth
.check_roles([BuiltinRole::InfraRead].into())
Expand Down Expand Up @@ -590,7 +596,7 @@ async fn get_voltages(
Extension(auth): AuthenticationExt,
Path(infra): Path<InfraIdParam>,
Query(param): Query<GetVoltagesQueryParams>,
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
) -> Result<Json<Vec<String>>> {
let authorized = auth
.check_roles([BuiltinRole::InfraRead].into())
Expand Down Expand Up @@ -623,7 +629,7 @@ async fn get_voltages(
)
)]
async fn get_all_voltages(
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
Extension(auth): AuthenticationExt,
) -> Result<Json<Vec<String>>> {
let authorized = auth
Expand Down Expand Up @@ -712,7 +718,11 @@ async fn unlock(
)
)]
async fn load(
app_state: State<AppState>,
State(AppState {
db_pool,
core_client,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Path(path): Path<InfraIdParam>,
) -> Result<impl IntoResponse> {
Expand All @@ -724,9 +734,6 @@ async fn load(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool_v2.clone();
let core_client = app_state.core_client.clone();

let infra_id = path.infra_id;
let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, infra_id, || {
InfraApiError::NotFound { infra_id }
Expand Down
9 changes: 5 additions & 4 deletions editoast/src/views/infra/pathfinding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ struct QueryParam {
)
)]
async fn pathfinding_view(
app_state: State<AppState>,
State(AppState {
db_pool,
infra_caches,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Path(infra): Path<InfraIdParam>,
Query(params): Query<QueryParam>,
Expand All @@ -110,9 +114,6 @@ async fn pathfinding_view(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool_v2.clone();
let infra_caches = app_state.infra_caches.clone();

// Parse and check input
let infra_id = infra.infra_id;
let number = params.number.unwrap_or(DEFAULT_NUMBER_OF_PATHS);
Expand Down
Loading
Loading