Skip to content

Commit

Permalink
editoast: destructure AppState early in handlers for consistency
Browse files Browse the repository at this point in the history
Signed-off-by: Leo Valais <[email protected]>
  • Loading branch information
leovalais committed Dec 3, 2024
1 parent 614ab93 commit 4991144
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 88 deletions.
68 changes: 38 additions & 30 deletions editoast/src/views/infra/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,13 @@ struct RefreshResponse {
)
)]
async fn refresh(
app_state: State<AppState>,
State(AppState {
db_pool,
valkey: valkey_client,
infra_caches,
map_layers,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Query(query_params): Query<RefreshQueryParams>,
) -> Result<Json<RefreshResponse>> {
Expand All @@ -135,11 +141,6 @@ async fn refresh(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool.clone();
let valkey_client = app_state.valkey.clone();
let infra_caches = app_state.infra_caches.clone();
let map_layers = app_state.map_layers.clone();

// Use a transaction to give scope to infra list lock
let RefreshQueryParams {
force,
Expand All @@ -160,7 +161,6 @@ async fn refresh(
};

// Refresh each infras
let db_pool = db_pool;
let mut infra_refreshed = vec![];

for mut infra in infras_list {
Expand Down Expand Up @@ -201,7 +201,11 @@ struct InfraListResponse {
),
)]
async fn list(
app_state: State<AppState>,
State(AppState {
db_pool,
osrdyne_client,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
pagination_params: Query<PaginationQueryParams>,
) -> Result<Json<InfraListResponse>> {
Expand All @@ -212,8 +216,6 @@ async fn list(
if !authorized {
return Err(AuthorizationError::Unauthorized.into());
}
let db_pool = app_state.db_pool.clone();
let osrdyne_client = app_state.osrdyne_client.clone();

let settings = pagination_params
.validate(1000)?
Expand Down Expand Up @@ -295,7 +297,11 @@ struct InfraIdParam {
),
)]
async fn get(
app_state: State<AppState>,
State(AppState {
db_pool,
osrdyne_client,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Path(infra): Path<InfraIdParam>,
) -> Result<Json<InfraWithState>> {
Expand All @@ -307,9 +313,6 @@ async fn get(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool.clone();
let osrdyne_client = app_state.osrdyne_client.clone();

let infra_id = infra.infra_id;
let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, infra_id, || {
InfraApiError::NotFound { infra_id }
Expand Down Expand Up @@ -344,7 +347,7 @@ impl From<InfraCreateForm> for Changeset<Infra> {
),
)]
async fn create(
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
Extension(auth): AuthenticationExt,
Json(infra_form): Json<InfraCreateForm>,
) -> Result<impl IntoResponse> {
Expand Down Expand Up @@ -381,7 +384,7 @@ struct CloneQuery {
async fn clone(
Extension(auth): AuthenticationExt,
Path(params): Path<InfraIdParam>,
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
Query(CloneQuery { name }): Query<CloneQuery>,
) -> Result<Json<i64>> {
let authorized = auth
Expand Down Expand Up @@ -421,7 +424,11 @@ async fn clone(
),
)]
async fn delete(
app_state: State<AppState>,
State(AppState {
db_pool,
infra_caches,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
infra: Path<InfraIdParam>,
) -> Result<impl IntoResponse> {
Expand All @@ -433,8 +440,6 @@ async fn delete(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool.clone();
let infra_caches = app_state.infra_caches.clone();
let infra_id = infra.infra_id;
if Infra::fast_delete_static(db_pool.get().await?, infra_id).await? {
infra_caches.remove(&infra_id);
Expand Down Expand Up @@ -468,7 +473,7 @@ impl From<InfraPatchForm> for Changeset<Infra> {
),
)]
async fn put(
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
Extension(auth): AuthenticationExt,
Path(infra): Path<i64>,
Json(patch): Json<InfraPatchForm>,
Expand Down Expand Up @@ -501,7 +506,11 @@ async fn put(
)
)]
async fn get_switch_types(
app_state: State<AppState>,
State(AppState {
db_pool,
infra_caches,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Path(infra): Path<InfraIdParam>,
) -> Result<Json<Vec<SwitchType>>> {
Expand All @@ -513,9 +522,7 @@ async fn get_switch_types(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool.clone();
let conn = &mut db_pool.get().await?;
let infra_caches = app_state.infra_caches.clone();

let infra = Infra::retrieve_or_fail(conn, infra.infra_id, || InfraApiError::NotFound {
infra_id: infra.infra_id,
Expand Down Expand Up @@ -546,7 +553,7 @@ async fn get_switch_types(
async fn get_speed_limit_tags(
Extension(auth): AuthenticationExt,
Path(infra): Path<InfraIdParam>,
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
) -> Result<Json<Vec<String>>> {
let authorized = auth
.check_roles([BuiltinRole::InfraRead].into())
Expand Down Expand Up @@ -590,7 +597,7 @@ async fn get_voltages(
Extension(auth): AuthenticationExt,
Path(infra): Path<InfraIdParam>,
Query(param): Query<GetVoltagesQueryParams>,
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
) -> Result<Json<Vec<String>>> {
let authorized = auth
.check_roles([BuiltinRole::InfraRead].into())
Expand Down Expand Up @@ -623,7 +630,7 @@ async fn get_voltages(
)
)]
async fn get_all_voltages(
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
Extension(auth): AuthenticationExt,
) -> Result<Json<Vec<String>>> {
let authorized = auth
Expand Down Expand Up @@ -712,7 +719,11 @@ async fn unlock(
)
)]
async fn load(
app_state: State<AppState>,
State(AppState {
db_pool,
core_client,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Path(path): Path<InfraIdParam>,
) -> Result<impl IntoResponse> {
Expand All @@ -724,9 +735,6 @@ async fn load(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool.clone();
let core_client = app_state.core_client.clone();

let infra_id = path.infra_id;
let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, infra_id, || {
InfraApiError::NotFound { infra_id }
Expand Down
9 changes: 5 additions & 4 deletions editoast/src/views/infra/pathfinding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ struct QueryParam {
)
)]
async fn pathfinding_view(
app_state: State<AppState>,
State(AppState {
db_pool,
infra_caches,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Path(infra): Path<InfraIdParam>,
Query(params): Query<QueryParam>,
Expand All @@ -110,9 +114,6 @@ async fn pathfinding_view(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool.clone();
let infra_caches = app_state.infra_caches.clone();

// Parse and check input
let infra_id = infra.infra_id;
let number = params.number.unwrap_or(DEFAULT_NUMBER_OF_PATHS);
Expand Down
10 changes: 6 additions & 4 deletions editoast/src/views/infra/railjson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ enum ListErrorsRailjson {
)]
async fn get_railjson(
Path(infra): Path<InfraIdParam>,
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
Extension(auth): AuthenticationExt,
) -> Result<impl IntoResponse> {
let authorized = auth
Expand Down Expand Up @@ -168,7 +168,11 @@ struct PostRailjsonResponse {
)
)]
async fn post_railjson(
app_state: State<AppState>,
State(AppState {
db_pool,
infra_caches,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Query(params): Query<PostRailjsonQueryParams>,
Json(railjson): Json<RailJson>,
Expand All @@ -181,8 +185,6 @@ async fn post_railjson(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool.clone();
let infra_caches = app_state.infra_caches.clone();
if railjson.version != RAILJSON_VERSION {
return Err(ListErrorsRailjson::WrongRailjsonVersionProvided.into());
}
Expand Down
21 changes: 13 additions & 8 deletions editoast/src/views/infra/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ struct RoutesResponse {
)]
async fn get_routes_from_waypoint(
Path(path): Path<RoutesFromWaypointParams>,
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
Extension(auth): AuthenticationExt,
) -> Result<Json<RoutesResponse>> {
let authorized = auth
Expand Down Expand Up @@ -145,7 +145,11 @@ struct RoutesFromNodesPositions {
),
)]
async fn get_routes_track_ranges(
app_state: State<AppState>,
State(AppState {
db_pool,
infra_caches,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Path(infra): Path<i64>,
Query(params): Query<RouteTrackRangesParams>,
Expand All @@ -158,8 +162,8 @@ async fn get_routes_track_ranges(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool.clone();
let infra_caches = app_state.infra_caches.clone();
let db_pool = db_pool.clone();
let infra_caches = infra_caches.clone();
let infra_id = infra;
let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, infra_id, || {
InfraApiError::NotFound { infra_id }
Expand Down Expand Up @@ -206,7 +210,11 @@ async fn get_routes_track_ranges(
),
)]
async fn get_routes_nodes(
app_state: State<AppState>,
State(AppState {
db_pool,
infra_caches,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Path(params): Path<InfraIdParam>,
Json(node_states): Json<HashMap<String, Option<String>>>,
Expand All @@ -219,9 +227,6 @@ async fn get_routes_nodes(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool.clone();
let infra_caches = app_state.infra_caches.clone();

let infra = Infra::retrieve_or_fail(&mut db_pool.get().await?, params.infra_id, || {
InfraApiError::NotFound {
infra_id: params.infra_id,
Expand Down
7 changes: 5 additions & 2 deletions editoast/src/views/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,11 @@ async fn version() -> Json<Version> {
(status = 200, description = "Return the core service version", body = Version),
),
)]
async fn core_version(app_state: State<AppState>) -> Json<Version> {
let core = app_state.core_client.clone();
async fn core_version(
State(AppState {
core_client: core, ..
}): State<AppState>,
) -> Json<Version> {
let response = CoreVersionRequest {}.fetch(&core).await;
let response = response.unwrap_or(Version { git_describe: None });
Json(response)
Expand Down
19 changes: 10 additions & 9 deletions editoast/src/views/timetable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ struct TimetableIdParam {
),
)]
async fn get(
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
Extension(auth): AuthenticationExt,
Path(timetable_id): Path<TimetableIdParam>,
) -> Result<Json<TimetableDetailedResult>> {
Expand Down Expand Up @@ -154,7 +154,7 @@ async fn get(
),
)]
async fn post(
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
Extension(auth): AuthenticationExt,
) -> Result<Json<TimetableResult>> {
let authorized = auth
Expand Down Expand Up @@ -183,7 +183,7 @@ async fn post(
),
)]
async fn delete(
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
Extension(auth): AuthenticationExt,
timetable_id: Path<TimetableIdParam>,
) -> Result<impl IntoResponse> {
Expand Down Expand Up @@ -215,7 +215,7 @@ async fn delete(
)
)]
async fn train_schedule(
db_pool: State<DbConnectionPoolV2>,
State(db_pool): State<DbConnectionPoolV2>,
Extension(auth): AuthenticationExt,
Path(timetable_id): Path<TimetableIdParam>,
Json(train_schedules): Json<Vec<TrainScheduleBase>>,
Expand Down Expand Up @@ -271,7 +271,12 @@ pub struct ElectricalProfileSetIdQueryParam {
),
)]
async fn conflicts(
app_state: State<AppState>,
State(AppState {
db_pool,
valkey: valkey_client,
core_client,
..
}): State<AppState>,
Extension(auth): AuthenticationExt,
Path(timetable_id): Path<TimetableIdParam>,
Query(infra_id_query): Query<InfraIdQueryParam>,
Expand All @@ -285,10 +290,6 @@ async fn conflicts(
return Err(AuthorizationError::Unauthorized.into());
}

let db_pool = app_state.db_pool.clone();
let valkey_client = app_state.valkey.clone();
let core_client = app_state.core_client.clone();

let timetable_id = timetable_id.id;
let infra_id = infra_id_query.infra_id;
let electrical_profile_set_id = electrical_profile_set_id_query.electrical_profile_set_id;
Expand Down
Loading

0 comments on commit 4991144

Please sign in to comment.