diff --git a/contest_service/src/main.rs b/contest_service/src/main.rs index a9dabd2..cdb3cd0 100644 --- a/contest_service/src/main.rs +++ b/contest_service/src/main.rs @@ -1,3 +1,4 @@ +#![feature(async_closure)] use futures::stream::StreamExt; use std::convert::TryFrom; @@ -114,16 +115,31 @@ impl Contest for ContestService { &self, _request: Request, ) -> Result, Status> { - self.get_contest_metadata() + let metadata = self + .get_contest_metadata() .await .map(|contest_metadata_doc| { mappings::contest::ContestMetadata::from(contest_metadata_doc).into() - }) + })?; + let problems = self + .get_problems_collection() + .find(None, None) + .await + .map_err(internal_error)? + .filter_map(async move |x| x.ok()) + .map(mappings::problem::ProblemData::from) + .map(|x| x.get_problem().into()) + .collect() + .await; + Ok(Response::new(GetContestMetadataResponse { + metadata, + problems, + })) } - async fn get_problem( + async fn get_problem_info( &self, request: Request, - ) -> Result, Status> { + ) -> Result, Status> { let problem_id = request.into_inner().problem_id; self.get_problems_collection() .find_one(doc! {"_id": problem_id as i64}, None) @@ -131,13 +147,31 @@ impl Contest for ContestService { .map_err(internal_error)? .map(mappings::problem::ProblemData::from) .map(|x| { - Response::new(GetProblemResponse { + Response::new(GetProblemInfoResponse { info: x.get_problem().into(), + }) + }) + .ok_or_else(|| Status::not_found("Problem not found")) + } + + async fn get_problem_statement( + &self, + request: Request, + ) -> Result, Status> { + let problem_id = request.into_inner().problem_id; + self.get_problems_collection() + .find_one(doc! {"_id": problem_id as i64}, None) + .await + .map_err(internal_error)? + .map(mappings::problem::ProblemData::from) + .map(|x| { + Response::new(GetProblemStatementResponse { statement: x.get_statement(), }) }) - .ok_or_else(|| Status::internal("Problem not found")) + .ok_or_else(|| Status::not_found("Problem not found")) } + async fn get_announcement_list( &self, _request: Request, @@ -214,26 +248,22 @@ impl Contest for ContestService { request: Request, ) -> Result, Status> { let problem_data_from_req = request.into_inner(); - if let Some(p) = problem_data_from_req.info { - if let Some(bin) = problem_data_from_req.statement { - let problem_data: mappings::problem::ProblemData = (p.into(), bin).into(); + let problem_data: mappings::problem::ProblemData = ( + problem_data_from_req.info.into(), + problem_data_from_req.statement, + ) + .into(); - let document: Document = problem_data.into(); - self.get_problems_collection() - .update_one( - doc! { "_id": document.get_i32("_id").unwrap() }, - doc! { "$set": document }, - UpdateOptions::builder().upsert(true).build(), - ) - .await - .map_err(internal_error) - .map(|_| Response::new(SetProblemResponse {})) - } else { - Err(Status::invalid_argument("Missing required parameter")) - } - } else { - Err(Status::invalid_argument("Missing required parameter")) - } + let document: Document = problem_data.into(); + self.get_problems_collection() + .update_one( + doc! { "_id": document.get_i64("_id").unwrap() }, + doc! { "$set": document }, + UpdateOptions::builder().upsert(true).build(), + ) + .await + .map_err(internal_error) + .map(|_| Response::new(SetProblemResponse {})) } async fn add_message( &self, @@ -251,6 +281,45 @@ impl Contest for ContestService { .map_err(internal_error)?; Ok(Response::new(AddMessageResponse {})) } + + async fn update_problem_info( + &self, + request: Request, + ) -> Result, Status> { + let problem_data_from_req = request.into_inner(); + let problem_data: mappings::problem::Problem = problem_data_from_req.info.into(); + self.get_problems_collection() + .update_one( + doc! { "_id": problem_data.get_id() }, + doc! { "$set": doc!{"name": problem_data.name, "longName": problem_data.long_name} }, + UpdateOptions::builder().build(), + ) + .await + .map_err(internal_error) + .map(|_| Response::new(SetProblemResponse {})) + } + + async fn update_problem_statement( + &self, + request: Request, + ) -> Result, Status> { + let problem_data_from_req = request.into_inner(); + let problem_statement = problem_data_from_req.statement; + let problem_id = problem_data_from_req.problem_id; + + self.get_problems_collection() + .update_one( + doc! { "_id": problem_id as i64 }, + doc! { "$set": doc!{"statement": mongodb::bson::Binary { + subtype: mongodb::bson::spec::BinarySubtype::Generic, + bytes: problem_statement, + }} }, + UpdateOptions::builder().build(), + ) + .await + .map_err(internal_error) + .map(|_| Response::new(SetProblemResponse {})) + } } #[tokio::main] diff --git a/contest_service/src/mappings.rs b/contest_service/src/mappings.rs index 7e901d9..74b5261 100644 --- a/contest_service/src/mappings.rs +++ b/contest_service/src/mappings.rs @@ -1,8 +1,6 @@ use std::convert::TryFrom; use mongodb::bson::Document; -use protos::service::contest::GetContestMetadataResponse; -use tonic::Response; #[derive(Debug)] pub enum MappingError { @@ -40,16 +38,14 @@ pub mod contest { } } } - impl From for Response { + impl From for protos::service::contest::ContestMetadata { fn from(md: ContestMetadata) -> Self { - Response::new(GetContestMetadataResponse { - metadata: protos::service::contest::ContestMetadata { - name: md.name, - description: md.description, - start_time: md.start_time.map(protos::common::Timestamp::from), - end_time: md.end_time.map(protos::common::Timestamp::from), - }, - }) + protos::service::contest::ContestMetadata { + name: md.name, + description: md.description, + start_time: md.start_time.map(protos::common::Timestamp::from), + end_time: md.end_time.map(protos::common::Timestamp::from), + } } } @@ -199,8 +195,14 @@ pub mod problem { #[derive(Default, Clone)] pub struct Problem { id: u64, - name: String, - long_name: String, + pub name: String, + pub long_name: String, + } + + impl Problem { + pub fn get_id(&self) -> i64 { + self.id as i64 + } } impl From for Problem { diff --git a/protos/protos/service/contest.proto b/protos/protos/service/contest.proto index 3c98756..5b944dc 100644 --- a/protos/protos/service/contest.proto +++ b/protos/protos/service/contest.proto @@ -1,9 +1,9 @@ syntax = "proto2"; -import "common.proto"; - package service.contest; +import "common.proto"; + message AuthUserRequest { required string username = 1; required string password = 2; @@ -31,6 +31,7 @@ message ContestMetadata { message GetContestMetadataRequest {} message GetContestMetadataResponse { required ContestMetadata metadata = 1; + repeated Problem problems = 2; } message Problem { @@ -41,9 +42,12 @@ message Problem { message GetProblemRequest { required uint64 problem_id = 1; } -message GetProblemResponse { +message GetProblemInfoResponse { required Problem info = 1; - required bytes statement = 2; +} + +message GetProblemStatementResponse { + required bytes statement = 1; } message Message { // questions and announcements are the same @@ -90,9 +94,19 @@ message SetContestMetadataRequest { message SetContestMetadataResponse {} message SetProblemRequest { - optional Problem info = 1; - optional bytes statement = 2; + required Problem info = 1; + required bytes statement = 2; } + +message UpdateProblemStatementRequest { + required uint64 problem_id = 1; + required bytes statement = 2; +} + +message UpdateProblemInfoRequest { + required Problem info = 1; +} + message SetProblemResponse {} message AddMessageRequest { @@ -103,12 +117,15 @@ message AddMessageResponse {} service Contest { rpc auth_user(AuthUserRequest) returns (AuthUserResponse); rpc get_contest_metadata(GetContestMetadataRequest) returns (GetContestMetadataResponse); - rpc get_problem(GetProblemRequest) returns (GetProblemResponse); + rpc get_problem_info(GetProblemRequest) returns (GetProblemInfoResponse); + rpc get_problem_statement(GetProblemRequest) returns (GetProblemStatementResponse); rpc get_announcement_list(GetAnnouncementListRequest) returns (GetAnnouncementListResponse); rpc get_question_list(GetQuestionListRequest) returns (GetQuestionListResponse); rpc set_user(SetUserRequest) returns (SetUserResponse); rpc set_contest_metadata(SetContestMetadataRequest) returns (SetContestMetadataResponse); rpc set_problem(SetProblemRequest) returns (SetProblemResponse); + rpc update_problem_info(UpdateProblemInfoRequest) returns (SetProblemResponse); + rpc update_problem_statement(UpdateProblemStatementRequest) returns (SetProblemResponse); rpc add_message(AddMessageRequest) returns (AddMessageResponse); } diff --git a/protos/src/lib.rs b/protos/src/lib.rs index 519ff8e..4f3fec9 100644 --- a/protos/src/lib.rs +++ b/protos/src/lib.rs @@ -54,12 +54,15 @@ pub mod service { rpc_mock_server!(contest_server::Contest; MockContest; (auth_user,AuthUserRequest,AuthUserResponse), (get_contest_metadata,GetContestMetadataRequest,GetContestMetadataResponse), - (get_problem,GetProblemRequest,GetProblemResponse), + (get_problem_statement,GetProblemRequest,GetProblemStatementResponse), + (get_problem_info,GetProblemRequest,GetProblemInfoResponse), (get_announcement_list,GetAnnouncementListRequest,GetAnnouncementListResponse), (get_question_list,GetQuestionListRequest,GetQuestionListResponse), (set_user,SetUserRequest,SetUserResponse), (set_contest_metadata,SetContestMetadataRequest,SetContestMetadataResponse), (set_problem,SetProblemRequest,SetProblemResponse), + (update_problem_info,UpdateProblemInfoRequest,SetProblemResponse), + (update_problem_statement,UpdateProblemStatementRequest,SetProblemResponse), (add_message,AddMessageRequest,AddMessageResponse) ); }