diff --git a/README.md b/README.md index 1a0f66ac2..199459ae0 100644 --- a/README.md +++ b/README.md @@ -1 +1,75 @@ -# jwp-subway-path \ No newline at end of file +# jwp-subway-path + +## Quick Start + +```text +cd docker + +docker-compose -p subway up +``` + +## API 문서 + +```text +http://localhost:8080/swagger-ui/index.html#/ +``` + +## 기능 요구사항 + +### 역 + +- [x] 역은 고유한 식별자를 가진다. +- [x] 역 이름은 `역`으로 끝나야 한다. +- [x] 역 이름은 2글자에서 11글자까지 가능하다. +- [x] 역 이름은 한글 + 숫자로만 이루어져야 한다. + +### 구역 + +- [x] 구역은 두 역과 역 사이의 거리를 가진다. +- [x] 거리는 양의 정수이고, 단위는 km이다. + +### 노선 + +- [x] 노선 이름은 숫자 + `호선` 이다. + - [x] 숫자는 1 ~ 9까지 가능하다. + +- [x] 노선의 색은 `색`으로 끝나야 한다. + - [x] 색 이름은 2글자에서 11글자까지 가능하다. + +- [x] 역을 추가할 수 있어야 한다. + -[x] 역을 추가할 때, 상행, 하행 역의 정보와 거리 정보를 입력받는다. + - 최초 등록이 아닐 경우, 상행역 또는 하행역 어느 한 가지도 존재하지 않으면 예외를 던진다. + -[x] 하나의 역은 여러 노선에 등록될 수 있다. + -[x] 두 역의 가운데에 다른 역을 등록할 때, 기존 거리를 고려해야 한다. + +- [x] 역이 2개 이상 등록된 노선을 전부 보여주어야 한다. + +- [x] 구역은 역 순서대로 저장되어 있어야 한다. + - [x] 노선 번호를 입력받으면, 노선에 포함된 역을 순서대로 보여주어야 한다. + +- [x] 역을 제거할 수 있어야 한다. + -[x] 역을 제거하면 남은 역을 재배치 해야 한다. + -[x] 노선에서 역이 제거되면 역과 역 사이의 거리도 재배정되어야 한다. + -[x] 노선에 등록된 역이 2개인 경우 하나의 역을 제거할 때 두 역이 모두 제거되어야 한다. + +### 지하철 + +- [x] 출발역과 도착역 사이의 최단 경로를 구한다. + - [x] 총 거리 정보를 함께 응답한다. + +### 요금 정책 + +- [x] 거리에 따른 요금을 계산한다. + - 기본운임(10㎞ 이내): 기본운임 1,250원 + - 이용 거리 초과 시 추가운임 부과 + - 10km~50km: 5km 까지 마다 100원 추가 + - 50km 초과: 8km 까지 마다 100원 추가 + +- [x] 노선에 따른 요금을 계산한다. + - 추가 요금이 있는 노선을 이용하면 측정된 요금에 추가한다. + - 경로 중 추가 요금이 있는 노선을 환승하여 이용하면 가장 높은 금액의 추가 요금만 적용한다. + +- [x] 연령에 따른 요금 할인을 계산한다. + - 연령에 따른 요금 할인 정책을 반영한다. + - 청소년: 운임에서 350원을 공제한 금액의 20% 할인 + - 어린이: 운임에서 350원을 공제한 금액의 50% 할인 diff --git a/build.gradle b/build.gradle index 68d8f2558..c4c067cab 100644 --- a/build.gradle +++ b/build.gradle @@ -1,27 +1,30 @@ plugins { - id 'java' - id 'org.springframework.boot' version '2.7.9' - id 'io.spring.dependency-management' version '1.0.15.RELEASE' + id 'java' + id 'org.springframework.boot' version '2.7.9' + id 'io.spring.dependency-management' version '1.0.15.RELEASE' } sourceCompatibility = '11' repositories { - mavenCentral() + mavenCentral() } dependencies { - implementation 'org.springframework.boot:spring-boot-starter-web' - implementation 'org.springframework.boot:spring-boot-starter-jdbc' + implementation 'org.springframework.boot:spring-boot-starter-web' + implementation 'org.springframework.boot:spring-boot-starter-jdbc' + implementation 'org.springframework.boot:spring-boot-starter-validation' + implementation 'net.rakugakibox.spring.boot:logback-access-spring-boot-starter:2.7.1' + implementation 'org.springdoc:springdoc-openapi-ui:1.7.0' + implementation 'org.jgrapht:jgrapht-core:1.5.2' + implementation 'mysql:mysql-connector-java:8.0.33' - implementation 'net.rakugakibox.spring.boot:logback-access-spring-boot-starter:2.7.1' + testImplementation 'io.rest-assured:rest-assured:4.4.0' + testImplementation 'org.springframework.boot:spring-boot-starter-test' - testImplementation 'io.rest-assured:rest-assured:4.4.0' - testImplementation 'org.springframework.boot:spring-boot-starter-test' - - runtimeOnly 'com.h2database:h2' + runtimeOnly 'com.h2database:h2' } test { - useJUnitPlatform() -} \ No newline at end of file + useJUnitPlatform() +} diff --git a/docker/.env b/docker/.env new file mode 100644 index 000000000..534d3fe10 --- /dev/null +++ b/docker/.env @@ -0,0 +1,4 @@ +MYSQL_DATABASE=subway +MYSQL_ROOT_PASSWORD=1234 +MYSQL_USER=subway +MYSQL_PASSWORD=1234 diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml new file mode 100644 index 000000000..4ffcc96bb --- /dev/null +++ b/docker/docker-compose.yml @@ -0,0 +1,37 @@ +version: "3" +services: + app: + container_name: app + image: arm64v8/amazoncorretto:11-alpine-jdk + ports: + - "8080:8080" + environment: + SPRING_DATASOURCE_URL: jdbc:mysql://db:3306/${MYSQL_DATABASE}?serverTimezone=Asia/Seoul&characterEncoding=UTF-8 + SPRING_DATASOURCE_USERNAME: ${MYSQL_USER} + SPRING_DATASOURCE_PASSWORD: ${MYSQL_PASSWORD} + volumes: + - ../src:/app/src + - ../gradle:/app/gradle + - ../build.gradle:/app/build.gradle + - ../gradlew:/app/gradlew + - ../gradlew.bat:/app/gradlew.bat + working_dir: /app + command: [ "./gradlew", "bootrun" ] + depends_on: + - db + restart: always + + db: + container_name: db + image: mysql:8.0.33 + ports: + - "3307:3306" + environment: + MYSQL_DATABASE: ${MYSQL_DATABASE} + MYSQL_ROOT_PASSWORD: ${MYSQL_ROOT_PASSWORD} + MYSQL_USER: ${MYSQL_USER} + MYSQL_PASSWORD: ${MYSQL_PASSWORD} + volumes: + - ../src/main/resources/data.sql:/docker-entrypoint-initdb.d/init.sql + - ./mysql.cnf:/etc/mysql/conf.d/mysql.cnf + restart: always diff --git a/docker/mysql.cnf b/docker/mysql.cnf new file mode 100644 index 000000000..ea91bbc28 --- /dev/null +++ b/docker/mysql.cnf @@ -0,0 +1,44 @@ +# Copyright (c) 2015, 2021, Oracle and/or its affiliates. +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License, version 2.0, +# as published by the Free Software Foundation. +# +# This program is also distributed with certain software (including +# but not limited to OpenSSL) that is licensed under separate terms, +# as designated in a particular file or component or in included license +# documentation. The authors of MySQL hereby grant you an additional +# permission to link the program and your derivative works with the +# separately licensed software that they have included with MySQL. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License, version 2.0, for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA + +# +# The MySQL Client configuration file. +# +# For explanations see +# http://dev.mysql.com/doc/mysql/en/server-system-variables.html + +[client] +default-character-set=utf8mb4 + +[mysqld] +character-set-client-handshake=FALSE +init_connect="SET collation_connection=utf8mb4_unicode_ci" +init_connect="SET NAMES utf8mb4" +character-set-server=utf8mb4 +default_time_zone=Asia/Seoul +lower_case_table_names=1 + +[mysql] +default-character-set=utf8mb4 + +[mysqldump] +default-character-set=utf8mb4 diff --git a/src/main/java/subway/SubwayApplication.java b/src/main/java/subway/SubwayApplication.java index 5174a4245..5780e8fb1 100644 --- a/src/main/java/subway/SubwayApplication.java +++ b/src/main/java/subway/SubwayApplication.java @@ -6,8 +6,8 @@ @SpringBootApplication public class SubwayApplication { - public static void main(String[] args) { - SpringApplication.run(SubwayApplication.class, args); - } + public static void main(String[] args) { + SpringApplication.run(SubwayApplication.class, args); + } } diff --git a/src/main/java/subway/application/LineService.java b/src/main/java/subway/application/LineService.java deleted file mode 100644 index bdb006f53..000000000 --- a/src/main/java/subway/application/LineService.java +++ /dev/null @@ -1,53 +0,0 @@ -package subway.application; - -import org.springframework.stereotype.Service; -import subway.dao.LineDao; -import subway.domain.Line; -import subway.dto.LineRequest; -import subway.dto.LineResponse; - -import java.util.List; -import java.util.stream.Collectors; - -@Service -public class LineService { - private final LineDao lineDao; - - public LineService(LineDao lineDao) { - this.lineDao = lineDao; - } - - public LineResponse saveLine(LineRequest request) { - Line persistLine = lineDao.insert(new Line(request.getName(), request.getColor())); - return LineResponse.of(persistLine); - } - - public List findLineResponses() { - List persistLines = findLines(); - return persistLines.stream() - .map(LineResponse::of) - .collect(Collectors.toList()); - } - - public List findLines() { - return lineDao.findAll(); - } - - public LineResponse findLineResponseById(Long id) { - Line persistLine = findLineById(id); - return LineResponse.of(persistLine); - } - - public Line findLineById(Long id) { - return lineDao.findById(id); - } - - public void updateLine(Long id, LineRequest lineUpdateRequest) { - lineDao.update(new Line(id, lineUpdateRequest.getName(), lineUpdateRequest.getColor())); - } - - public void deleteLineById(Long id) { - lineDao.deleteById(id); - } - -} diff --git a/src/main/java/subway/application/StationService.java b/src/main/java/subway/application/StationService.java deleted file mode 100644 index 603d9daa7..000000000 --- a/src/main/java/subway/application/StationService.java +++ /dev/null @@ -1,44 +0,0 @@ -package subway.application; - -import org.springframework.stereotype.Service; -import subway.dao.StationDao; -import subway.domain.Station; -import subway.dto.StationRequest; -import subway.dto.StationResponse; - -import java.util.List; -import java.util.stream.Collectors; - -@Service -public class StationService { - private final StationDao stationDao; - - public StationService(StationDao stationDao) { - this.stationDao = stationDao; - } - - public StationResponse saveStation(StationRequest stationRequest) { - Station station = stationDao.insert(new Station(stationRequest.getName())); - return StationResponse.of(station); - } - - public StationResponse findStationResponseById(Long id) { - return StationResponse.of(stationDao.findById(id)); - } - - public List findAllStationResponses() { - List stations = stationDao.findAll(); - - return stations.stream() - .map(StationResponse::of) - .collect(Collectors.toList()); - } - - public void updateStation(Long id, StationRequest stationRequest) { - stationDao.update(new Station(id, stationRequest.getName())); - } - - public void deleteStationById(Long id) { - stationDao.deleteById(id); - } -} \ No newline at end of file diff --git a/src/main/java/subway/config/SwaggerConfiguration.java b/src/main/java/subway/config/SwaggerConfiguration.java new file mode 100644 index 000000000..da242fdc9 --- /dev/null +++ b/src/main/java/subway/config/SwaggerConfiguration.java @@ -0,0 +1,22 @@ +package subway.config; + +import io.swagger.v3.oas.models.Components; +import io.swagger.v3.oas.models.OpenAPI; +import io.swagger.v3.oas.models.info.Info; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class SwaggerConfiguration { + + @Bean + public OpenAPI openAPI() { + final Info info = new Info() + .title("지하철 API Document") + .version("v0.0.1") + .description("지하철 API 명세서입니다."); + return new OpenAPI() + .components(new Components()) + .info(info); + } +} diff --git a/src/main/java/subway/controller/ControllerAdvice.java b/src/main/java/subway/controller/ControllerAdvice.java new file mode 100644 index 000000000..81b3f69fb --- /dev/null +++ b/src/main/java/subway/controller/ControllerAdvice.java @@ -0,0 +1,58 @@ +package subway.controller; + +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.context.support.DefaultMessageSourceResolvable; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.MethodArgumentNotValidException; +import org.springframework.web.bind.MissingServletRequestParameterException; +import org.springframework.web.bind.annotation.ExceptionHandler; +import org.springframework.web.bind.annotation.RestControllerAdvice; +import org.springframework.web.method.annotation.MethodArgumentTypeMismatchException; +import subway.exception.SubwayException; + +@RestControllerAdvice +public class ControllerAdvice { + + private static final Logger LOGGER = LoggerFactory.getLogger(ControllerAdvice.class); + + @ExceptionHandler(Exception.class) + public ResponseEntity handleException(final Exception exception) { + final String message = exception.getMessage(); + LOGGER.error(message); + return ResponseEntity.internalServerError().body("알 수 없는 서버 에러가 발생했습니다."); + } + + @ExceptionHandler(MethodArgumentNotValidException.class) + public ResponseEntity handleException(final MethodArgumentNotValidException exception) { + final String message = exception.getBindingResult() + .getFieldErrors() + .stream() + .map(DefaultMessageSourceResolvable::getDefaultMessage) + .collect(Collectors.joining(System.lineSeparator())); + LOGGER.warn(message); + return ResponseEntity.badRequest().body(message); + } + + @ExceptionHandler(MethodArgumentTypeMismatchException.class) + public ResponseEntity handleException(final MethodArgumentTypeMismatchException exception) { + final String message = exception.getMessage(); + LOGGER.warn(message); + return ResponseEntity.badRequest().body(message); + } + + @ExceptionHandler(MissingServletRequestParameterException.class) + public ResponseEntity handleException(final MissingServletRequestParameterException exception) { + final String message = exception.getMessage(); + LOGGER.warn(message); + return ResponseEntity.badRequest().body(message); + } + + @ExceptionHandler(SubwayException.class) + public ResponseEntity handleException(final SubwayException exception) { + final String message = exception.getMessage(); + LOGGER.warn(message); + return ResponseEntity.badRequest().body(message); + } +} diff --git a/src/main/java/subway/controller/LineController.java b/src/main/java/subway/controller/LineController.java new file mode 100644 index 000000000..babdbf0d6 --- /dev/null +++ b/src/main/java/subway/controller/LineController.java @@ -0,0 +1,73 @@ +package subway.controller; + +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; +import java.net.URI; +import javax.validation.Valid; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.DeleteMapping; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; +import subway.controller.dto.request.LineCreateRequest; +import subway.controller.dto.request.SectionCreateRequest; +import subway.controller.dto.response.LineResponse; +import subway.controller.dto.response.LinesResponse; +import subway.service.LineService; + +@Tag(name = "Line", description = "노선 API Document") +@RequestMapping("/lines") +@RestController +public class LineController { + + private final LineService lineService; + + public LineController(final LineService lineService) { + this.lineService = lineService; + } + + @Operation(summary = "노선 추가 API", description = "새로운 노선을 추가합니다.") + @PostMapping + public ResponseEntity createLine(@Valid @RequestBody LineCreateRequest request) { + final Long lineId = lineService.createLine(request); + return ResponseEntity.created(URI.create("/lines/" + lineId)).build(); + } + + @Operation(summary = "노선 정보 조회 API", description = "노선의 정보를 조회합니다.") + @GetMapping("/{id}") + public ResponseEntity findLine(@PathVariable(name = "id") Long lineId) { + final LineResponse response = lineService.findLineById(lineId); + return ResponseEntity.ok(response); + } + + @Operation(summary = "모든 노선 정보 조회 API", description = "모든 노선의 정보를 조회합니다.") + @GetMapping + public ResponseEntity findLines() { + final LinesResponse response = lineService.findLines(); + return ResponseEntity.ok(response); + } + + @Operation(summary = "노선 구간 등록 API", description = "노선에 새로운 구간을 등록합니다.") + @PostMapping("/{id}/sections") + public ResponseEntity createSection( + @PathVariable(name = "id") Long lineId, + @Valid @RequestBody SectionCreateRequest request + ) { + lineService.createSection(lineId, request); + return ResponseEntity.created(URI.create("/lines/" + lineId)).build(); + } + + @Operation(summary = "노선 구간 삭제 API", description = "노선의 특정 구간을 삭제합니다.") + @DeleteMapping("/{id}") + public ResponseEntity deleteSection( + @PathVariable(name = "id") Long lineId, + @RequestParam Long stationId + ) { + lineService.deleteStation(lineId, stationId); + return ResponseEntity.noContent().build(); + } +} diff --git a/src/main/java/subway/controller/StationController.java b/src/main/java/subway/controller/StationController.java new file mode 100644 index 000000000..6e08e5363 --- /dev/null +++ b/src/main/java/subway/controller/StationController.java @@ -0,0 +1,42 @@ +package subway.controller; + +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; +import java.net.URI; +import javax.validation.Valid; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import subway.controller.dto.request.StationCreateRequest; +import subway.controller.dto.response.StationResponse; +import subway.service.StationService; + +@Tag(name = "Station", description = "역 API Document") +@RequestMapping("/stations") +@RestController +public class StationController { + + private final StationService stationService; + + private StationController(final StationService stationService) { + this.stationService = stationService; + } + + @Operation(summary = "역 등록 API", description = "새로운 역을 등록합니다.") + @PostMapping + public ResponseEntity createStation(@Valid @RequestBody StationCreateRequest request) { + final Long stationId = stationService.createStation(request); + return ResponseEntity.created(URI.create("/stations/" + stationId)).build(); + } + + @Operation(summary = "역 정보 조회 API", description = "역 정보를 조회합니다.") + @GetMapping("/{id}") + public ResponseEntity findStationById(@PathVariable(name = "id") Long stationId) { + final StationResponse response = stationService.findStationById(stationId); + return ResponseEntity.ok(response); + } +} diff --git a/src/main/java/subway/controller/SubwayController.java b/src/main/java/subway/controller/SubwayController.java new file mode 100644 index 000000000..b6c4326bc --- /dev/null +++ b/src/main/java/subway/controller/SubwayController.java @@ -0,0 +1,32 @@ +package subway.controller; + +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; +import javax.validation.Valid; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; +import subway.controller.dto.request.PassengerRequest; +import subway.controller.dto.response.ShortestPathResponse; +import subway.service.SubwayService; + +@Tag(name = "Subway", description = "지하철 API Document") +@RequestMapping("/subways") +@RestController +public class SubwayController { + + private final SubwayService subwayService; + + public SubwayController(final SubwayService subwayService) { + this.subwayService = subwayService; + } + + @Operation(summary = "경로 정보 조회 API", description = "출발역에서 도착역까지의 경로 정보를 조회합니다.") + @GetMapping("/shortest-path") + public ResponseEntity findShortestPath(@Valid @RequestBody PassengerRequest request) { + final ShortestPathResponse response = subwayService.findShortestPath(request); + return ResponseEntity.ok(response); + } +} diff --git a/src/main/java/subway/controller/dto/request/LineCreateRequest.java b/src/main/java/subway/controller/dto/request/LineCreateRequest.java new file mode 100644 index 000000000..5ce7726e2 --- /dev/null +++ b/src/main/java/subway/controller/dto/request/LineCreateRequest.java @@ -0,0 +1,47 @@ +package subway.controller.dto.request; + +import io.swagger.v3.oas.annotations.media.Schema; +import javax.validation.constraints.Min; +import javax.validation.constraints.NotBlank; +import javax.validation.constraints.NotNull; + +@Schema( + description = "노선 생성 요청 정보", + example = "{\"name\": \"2호선\", \"color\": \"초록색\", \"fare\": 1000}" +) +public class LineCreateRequest { + + @Schema(description = "노선 이름") + @NotBlank(message = "노선 이름은 공백일 수 없습니다.") + private String name; + + @Schema(description = "노선 색") + @NotBlank(message = "노선 색깔은 공백일 수 없습니다.") + private String color; + + @Schema(description = "노선 추가 요금") + @NotNull(message = "노선 추가 요금은 존재해야 합니다.") + @Min(value = 0, message = "노선 추가 요금은 0원 이상 가능합니다.") + private Integer fare; + + private LineCreateRequest() { + } + + public LineCreateRequest(final String name, final String color, final Integer fare) { + this.name = name; + this.color = color; + this.fare = fare; + } + + public String getName() { + return name; + } + + public String getColor() { + return color; + } + + public Integer getFare() { + return fare; + } +} diff --git a/src/main/java/subway/controller/dto/request/PassengerRequest.java b/src/main/java/subway/controller/dto/request/PassengerRequest.java new file mode 100644 index 000000000..daeb55195 --- /dev/null +++ b/src/main/java/subway/controller/dto/request/PassengerRequest.java @@ -0,0 +1,46 @@ +package subway.controller.dto.request; + +import io.swagger.v3.oas.annotations.media.Schema; +import javax.validation.constraints.Min; +import javax.validation.constraints.NotNull; + +@Schema( + description = "지하철 승객 정보", + example = "{\"age\": 15, \"startStationId\": 1, \"endStationId\": 2}" +) +public class PassengerRequest { + + @Schema(description = "탑승자 나이") + @NotNull(message = "탑승자 나이는 입력해야 합니다.") + @Min(value = 1, message = "탑승자 나이는 0보다 커야합니다.") + private Integer age; + + @Schema(description = "출발역 ID") + @NotNull(message = "출발역 ID는 존재해야 합니다.") + private Long startStationId; + + @Schema(description = "도착역 ID") + @NotNull(message = "도착역 ID는 존재해야 합니다.") + private Long endStationId; + + private PassengerRequest() { + } + + public PassengerRequest(final Integer age, final Long startStationId, final Long endStationId) { + this.age = age; + this.startStationId = startStationId; + this.endStationId = endStationId; + } + + public Integer getAge() { + return age; + } + + public Long getStartStationId() { + return startStationId; + } + + public Long getEndStationId() { + return endStationId; + } +} diff --git a/src/main/java/subway/controller/dto/request/SectionCreateRequest.java b/src/main/java/subway/controller/dto/request/SectionCreateRequest.java new file mode 100644 index 000000000..0bb854ef7 --- /dev/null +++ b/src/main/java/subway/controller/dto/request/SectionCreateRequest.java @@ -0,0 +1,46 @@ +package subway.controller.dto.request; + +import io.swagger.v3.oas.annotations.media.Schema; +import javax.validation.constraints.NotNull; +import javax.validation.constraints.Positive; + +@Schema( + description = "구간 정보 생성 요청 정보", + example = "{\"upwardStationId\": 1, \"downwardStationId\": 2, \"distance\": 5}" +) +public class SectionCreateRequest { + + @Schema(description = "상행 역 ID") + @NotNull(message = "상행 역 ID는 존재해야 합니다.") + private Long upwardStationId; + + @Schema(description = "하행 역 ID") + @NotNull(message = "하행 역 ID는 존재해야 합니다.") + private Long downwardStationId; + + @Schema(description = "상행 역, 하행 역 사이 거리") + @NotNull(message = "역 간의 거리는 존재해야 합니다.") + @Positive(message = "역 간의 거리는 0보다 커야합니다.") + private Integer distance; + + private SectionCreateRequest() { + } + + public SectionCreateRequest(final Long upwardStationId, final Long downwardStationId, final Integer distance) { + this.upwardStationId = upwardStationId; + this.downwardStationId = downwardStationId; + this.distance = distance; + } + + public Long getUpwardStationId() { + return upwardStationId; + } + + public Long getDownwardStationId() { + return downwardStationId; + } + + public Integer getDistance() { + return distance; + } +} diff --git a/src/main/java/subway/controller/dto/request/StationCreateRequest.java b/src/main/java/subway/controller/dto/request/StationCreateRequest.java new file mode 100644 index 000000000..07fc5e6f7 --- /dev/null +++ b/src/main/java/subway/controller/dto/request/StationCreateRequest.java @@ -0,0 +1,26 @@ +package subway.controller.dto.request; + +import io.swagger.v3.oas.annotations.media.Schema; +import javax.validation.constraints.NotBlank; + +@Schema( + description = "역 생성 요청 정보", + example = "{\"name\": \"잠실역\"}" +) +public class StationCreateRequest { + + @Schema(description = "역 이름") + @NotBlank(message = "역 이름은 공백일 수 없습니다.") + private String name; + + private StationCreateRequest() { + } + + public StationCreateRequest(final String name) { + this.name = name; + } + + public String getName() { + return name; + } +} diff --git a/src/main/java/subway/controller/dto/response/LineResponse.java b/src/main/java/subway/controller/dto/response/LineResponse.java new file mode 100644 index 000000000..0439caf4a --- /dev/null +++ b/src/main/java/subway/controller/dto/response/LineResponse.java @@ -0,0 +1,79 @@ +package subway.controller.dto.response; + +import io.swagger.v3.oas.annotations.media.Schema; +import java.util.List; +import java.util.stream.Collectors; +import subway.domain.line.Line; +import subway.domain.station.Station; + +@Schema( + description = "역 응답 정보", + example = "{\"id\": 1, \"name\": \"2호선\", \"color\": \"초록색\", \"fare\": 1000, \"stations\": [{\"id\": 1, \"name\": \"잠실역\"}]}" +) +public class LineResponse { + + @Schema(description = "노선 ID") + private Long id; + + @Schema(description = "노선 이름") + private String name; + + @Schema(description = "노선 색") + private String color; + + @Schema(description = "노선 추가 요금") + private Integer fare; + + @Schema(description = "노선의 역 목록") + private List stations; + + public LineResponse( + final Long id, + final String name, + final String color, + final Integer fare, + final List stations + ) { + this.id = id; + this.name = name; + this.color = color; + this.fare = fare; + this.stations = stations; + } + + public static LineResponse from(final Line line) { + return new LineResponse( + line.getId(), + line.getName(), + line.getColor(), + line.getFare(), + generateStations(line.getStations()) + ); + } + + private static List generateStations(final List stations) { + return stations.stream() + .map(StationResponse::from) + .collect(Collectors.toUnmodifiableList()); + } + + public Long getId() { + return id; + } + + public String getName() { + return name; + } + + public String getColor() { + return color; + } + + public Integer getFare() { + return fare; + } + + public List getStations() { + return stations; + } +} diff --git a/src/main/java/subway/controller/dto/response/LineSectionResponse.java b/src/main/java/subway/controller/dto/response/LineSectionResponse.java new file mode 100644 index 000000000..24e4cb148 --- /dev/null +++ b/src/main/java/subway/controller/dto/response/LineSectionResponse.java @@ -0,0 +1,45 @@ +package subway.controller.dto.response; + +import io.swagger.v3.oas.annotations.media.Schema; +import java.util.ArrayList; +import java.util.List; +import subway.domain.section.PathSection; + +@Schema( + description = "구간 응답 정보", + example = "{\"lineId\": 1, \"sections\": [{\"upwardStationName\": \"잠실역\", \"downwardStationName\": \"잠실새내역\", \"distance\": 5}]}" +) +public class LineSectionResponse { + + @Schema(description = "노선 ID") + private Long lineId; + + @Schema(description = "구간 정보 목록") + private List sections; + + public LineSectionResponse(final Long lineId, final List sections) { + this.lineId = lineId; + this.sections = sections; + } + + public static LineSectionResponse from(final List sections) { + return new LineSectionResponse(sections.get(0).getLineId(), generateSections(sections)); + } + + private static List generateSections(final List sections) { + final List result = new ArrayList<>(); + + for (final PathSection section : sections) { + result.add(SectionResponse.from(section)); + } + return result; + } + + public Long getLineId() { + return lineId; + } + + public List getSections() { + return sections; + } +} diff --git a/src/main/java/subway/controller/dto/response/LinesResponse.java b/src/main/java/subway/controller/dto/response/LinesResponse.java new file mode 100644 index 000000000..b58c6c78b --- /dev/null +++ b/src/main/java/subway/controller/dto/response/LinesResponse.java @@ -0,0 +1,25 @@ +package subway.controller.dto.response; + +import io.swagger.v3.oas.annotations.media.Schema; +import java.util.List; + +@Schema( + description = "노선 목록 응답 정보", + example = "{\"lines\": [{\"id\": 1, \"name\": \"2호선\", \"color\": \"초록색\", \"fare\": 1000, \"stations\": [{\"id\": 1, \"name\": \"잠실역\"}]}]}" +) +public class LinesResponse { + + @Schema(description = "노선 목록") + private List lines; + + public LinesResponse() { + } + + public LinesResponse(final List lines) { + this.lines = lines; + } + + public List getLines() { + return lines; + } +} diff --git a/src/main/java/subway/controller/dto/response/SectionResponse.java b/src/main/java/subway/controller/dto/response/SectionResponse.java new file mode 100644 index 000000000..40eae5e29 --- /dev/null +++ b/src/main/java/subway/controller/dto/response/SectionResponse.java @@ -0,0 +1,47 @@ +package subway.controller.dto.response; + +import io.swagger.v3.oas.annotations.media.Schema; +import subway.domain.section.PathSection; + +@Schema( + description = "구간 응답 정보", + example = "{\"upwardStationName\": \"잠실역\", \"downwardStationName\": \"잠실새내역\", \"distance\": 10}" +) +public class SectionResponse { + + @Schema(description = "상행역 이름") + private String upwardStationName; + + @Schema(description = "하행역 이름") + private String downwardStationName; + + @Schema(description = "상행역과 하행역 사이의 거리") + private int distance; + + public SectionResponse(final String upwardStationName, final String downwardStationName, final int distance) { + this.upwardStationName = upwardStationName; + this.downwardStationName = downwardStationName; + this.distance = distance; + } + + public static SectionResponse from(final PathSection pathSection) { + return new SectionResponse( + pathSection.getSource().getName(), + pathSection.getTarget().getName(), + pathSection.getDistance() + ); + } + + public String getUpwardStationName() { + return upwardStationName; + } + + public String getDownwardStationName() { + return downwardStationName; + } + + public int getDistance() { + return distance; + } +} + diff --git a/src/main/java/subway/controller/dto/response/ShortestPathResponse.java b/src/main/java/subway/controller/dto/response/ShortestPathResponse.java new file mode 100644 index 000000000..46e49e32b --- /dev/null +++ b/src/main/java/subway/controller/dto/response/ShortestPathResponse.java @@ -0,0 +1,82 @@ +package subway.controller.dto.response; + +import io.swagger.v3.oas.annotations.media.Schema; +import java.util.ArrayList; +import java.util.List; +import subway.domain.section.PathSection; + +@Schema( + description = "최단 경로 응답 정보", + example = "{\"transferCount\": 3, \"path\": [{\"lineId\": 1, \"sections\": [{\"upwardStationName\": \"잠실역\", \"downwardStationName\": \"잠실새내역\", \"distance\": 5}]}], \"totalDistance\": 10, \"subwayFare\": 2000}" +) +public class ShortestPathResponse { + + @Schema(description = "환승 횟수") + private int transferCount; + + @Schema(description = "구간 경로 목록") + private List path; + + @Schema(description = "출발역에서 도착역까지의 총 거리") + private long totalDistance; + + @Schema(description = "출발역에서 도착역까지 운임 요금") + private long subwayFare; + + public ShortestPathResponse( + final int transferCount, + final List path, + final long totalDistance, + final long subwayFare + ) { + this.transferCount = transferCount; + this.path = path; + this.totalDistance = totalDistance; + this.subwayFare = subwayFare; + } + + public static ShortestPathResponse of( + final List sections, + final long totalDistance, + final long subwayFare + ) { + final List path = generateLineSections(sections); + return new ShortestPathResponse(path.size() - 1, path, totalDistance, subwayFare); + } + + private static List generateLineSections(final List sections) { + final List result = new ArrayList<>(); + Long currentLineId = sections.get(0).getLineId(); + List currentSections = new ArrayList<>(); + + for (final PathSection section : sections) { + if (section.getLineId() != currentLineId) { + result.add(LineSectionResponse.from(currentSections)); + currentSections = new ArrayList<>(); + currentLineId = section.getLineId(); + } + currentSections.add(section); + } + + if (!currentSections.isEmpty()) { + result.add(LineSectionResponse.from(currentSections)); + } + return result; + } + + public int getTransferCount() { + return transferCount; + } + + public List getPath() { + return path; + } + + public long getTotalDistance() { + return totalDistance; + } + + public long getSubwayFare() { + return subwayFare; + } +} diff --git a/src/main/java/subway/controller/dto/response/StationResponse.java b/src/main/java/subway/controller/dto/response/StationResponse.java new file mode 100644 index 000000000..1ffeb10b4 --- /dev/null +++ b/src/main/java/subway/controller/dto/response/StationResponse.java @@ -0,0 +1,34 @@ +package subway.controller.dto.response; + +import io.swagger.v3.oas.annotations.media.Schema; +import subway.domain.station.Station; + +@Schema( + description = "역 응답 정보", + example = "{\"id\": 1, \"name\": \"잠실역\"}" +) +public class StationResponse { + + @Schema(description = "역 ID") + private Long id; + + @Schema(description = "역 이름") + private String name; + + public StationResponse(final Long id, final String name) { + this.id = id; + this.name = name; + } + + public static StationResponse from(final Station station) { + return new StationResponse(station.getId(), station.getName()); + } + + public Long getId() { + return id; + } + + public String getName() { + return name; + } +} diff --git a/src/main/java/subway/dao/LineDao.java b/src/main/java/subway/dao/LineDao.java index f644bac29..cf99f4a28 100644 --- a/src/main/java/subway/dao/LineDao.java +++ b/src/main/java/subway/dao/LineDao.java @@ -1,61 +1,67 @@ package subway.dao; +import java.util.List; +import java.util.Optional; +import org.springframework.dao.EmptyResultDataAccessException; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.namedparam.BeanPropertySqlParameterSource; +import org.springframework.jdbc.core.namedparam.SqlParameterSource; import org.springframework.jdbc.core.simple.SimpleJdbcInsert; import org.springframework.stereotype.Repository; -import subway.domain.Line; - -import javax.sql.DataSource; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import subway.entity.LineEntity; @Repository public class LineDao { - private final JdbcTemplate jdbcTemplate; - private final SimpleJdbcInsert insertAction; - private RowMapper rowMapper = (rs, rowNum) -> - new Line( - rs.getLong("id"), - rs.getString("name"), - rs.getString("color") - ); + private static final RowMapper ROW_MAPPER = (rs, rowNum) -> new LineEntity( + rs.getLong("id"), + rs.getString("name"), + rs.getString("color"), + rs.getInt("fare") + ); - public LineDao(JdbcTemplate jdbcTemplate, DataSource dataSource) { + private final JdbcTemplate jdbcTemplate; + private final SimpleJdbcInsert jdbcInsert; + + public LineDao(final JdbcTemplate jdbcTemplate) { this.jdbcTemplate = jdbcTemplate; - this.insertAction = new SimpleJdbcInsert(dataSource) + this.jdbcInsert = new SimpleJdbcInsert(jdbcTemplate) .withTableName("line") + .usingColumns("name", "color", "fare") .usingGeneratedKeyColumns("id"); } - public Line insert(Line line) { - Map params = new HashMap<>(); - params.put("id", line.getId()); - params.put("name", line.getName()); - params.put("color", line.getColor()); - - Long lineId = insertAction.executeAndReturnKey(params).longValue(); - return new Line(lineId, line.getName(), line.getColor()); + public LineEntity save(final LineEntity lineEntity) { + final SqlParameterSource parameterSource = new BeanPropertySqlParameterSource(lineEntity); + final Long lineId = jdbcInsert.executeAndReturnKey(parameterSource).longValue(); + return new LineEntity(lineId, lineEntity.getName(), lineEntity.getColor(), lineEntity.getFare()); } - public List findAll() { - String sql = "select id, name, color from LINE"; - return jdbcTemplate.query(sql, rowMapper); - } - - public Line findById(Long id) { - String sql = "select id, name, color from LINE WHERE id = ?"; - return jdbcTemplate.queryForObject(sql, rowMapper, id); + public Optional findById(final Long lineId) { + final String sql = "SELECT id, name, color, fare FROM line WHERE id = ?"; + try { + final LineEntity result = jdbcTemplate.queryForObject( + sql, + ROW_MAPPER, + lineId + ); + return Optional.ofNullable(result); + } catch (EmptyResultDataAccessException exception) { + return Optional.empty(); + } } - public void update(Line newLine) { - String sql = "update LINE set name = ?, color = ? where id = ?"; - jdbcTemplate.update(sql, new Object[]{newLine.getName(), newLine.getColor(), newLine.getId()}); + public List findAll() { + final String sql = "SELECT id, name, color, fare FROM line"; + return jdbcTemplate.query(sql, ROW_MAPPER); } - public void deleteById(Long id) { - jdbcTemplate.update("delete from Line where id = ?", id); + public int update(final LineEntity lineEntity) { + final String sql = "UPDATE line SET name = ?, color = ?, fare = ? WHERE id = ?"; + return jdbcTemplate.update( + sql, + lineEntity.getName(), lineEntity.getColor(), lineEntity.getFare(), lineEntity.getId() + ); } } diff --git a/src/main/java/subway/dao/SectionDao.java b/src/main/java/subway/dao/SectionDao.java new file mode 100644 index 000000000..53583e956 --- /dev/null +++ b/src/main/java/subway/dao/SectionDao.java @@ -0,0 +1,81 @@ +package subway.dao; + +import java.util.List; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.namedparam.BeanPropertySqlParameterSource; +import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate; +import org.springframework.jdbc.core.namedparam.SqlParameterSource; +import org.springframework.jdbc.core.namedparam.SqlParameterSourceUtils; +import org.springframework.jdbc.core.simple.SimpleJdbcInsert; +import org.springframework.stereotype.Repository; +import subway.entity.SectionEntity; + +@Repository +public class SectionDao { + + private static final RowMapper ROW_MAPPER = (rs, count) -> new SectionEntity( + rs.getLong("id"), + rs.getLong("line_id"), + rs.getLong("upward_station_id"), + rs.getString("upward_station_name"), + rs.getLong("downward_station_id"), + rs.getString("downward_station_name"), + rs.getInt("distance") + ); + + private final NamedParameterJdbcTemplate jdbcTemplate; + private final SimpleJdbcInsert jdbcInsert; + + public SectionDao(final JdbcTemplate jdbcTemplate) { + this.jdbcTemplate = new NamedParameterJdbcTemplate(jdbcTemplate); + this.jdbcInsert = new SimpleJdbcInsert(jdbcTemplate) + .withTableName("section") + .usingColumns("line_id", "upward_station_id", "downward_station_id", "distance") + .usingGeneratedKeyColumns("id"); + } + + public SectionEntity save(final SectionEntity sectionEntity) { + final SqlParameterSource sqlParameterSource = new BeanPropertySqlParameterSource(sectionEntity); + final Long sectionId = jdbcInsert.executeAndReturnKey(sqlParameterSource).longValue(); + final String sql = "SELECT s.id AS id," + + " s.line_id AS line_id," + + " us.id AS upward_station_id," + + " us.name AS upward_station_name," + + " ds.id AS downward_station_id," + + " ds.name AS downward_station_name," + + " s.distance AS distance" + + " FROM section s" + + " JOIN station us ON s.upward_station_id = us.id" + + " JOIN station ds ON s.downward_station_id = ds.id" + + " WHERE s.id = ?"; + return jdbcTemplate.getJdbcOperations().queryForObject(sql, ROW_MAPPER, sectionId); + } + + public void saveAll(final List sectionEntities) { + final String sql = "INSERT INTO section (line_id, upward_station_id, downward_station_id, distance)" + + " VALUES (:lineId, :upwardStationId, :downwardStationId, :distance)"; + jdbcTemplate.batchUpdate(sql, SqlParameterSourceUtils.createBatch(sectionEntities)); + } + + public List findAllByLineId(final Long lineId) { + final String sql = "SELECT s.id AS id," + + " s.line_id AS line_id," + + " us.id AS upward_station_id," + + " us.name AS upward_station_name," + + " ds.id AS downward_station_id," + + " ds.name AS downward_station_name," + + " s.distance AS distance" + + " FROM section s" + + " JOIN station us ON s.upward_station_id = us.id" + + " JOIN station ds ON s.downward_station_id = ds.id" + + " WHERE s.line_id = ?" + + " ORDER BY s.id"; + return jdbcTemplate.getJdbcOperations().query(sql, ROW_MAPPER, lineId); + } + + public void deleteAllByLineId(final Long lineId) { + final String sql = "DELETE FROM section WHERE line_id = ?"; + jdbcTemplate.getJdbcOperations().update(sql, lineId); + } +} diff --git a/src/main/java/subway/dao/StationDao.java b/src/main/java/subway/dao/StationDao.java index 07f7eab30..16f9206ca 100644 --- a/src/main/java/subway/dao/StationDao.java +++ b/src/main/java/subway/dao/StationDao.java @@ -1,58 +1,47 @@ package subway.dao; +import java.util.Optional; +import org.springframework.dao.EmptyResultDataAccessException; import org.springframework.jdbc.core.JdbcTemplate; -import org.springframework.jdbc.core.RowMapper; import org.springframework.jdbc.core.namedparam.BeanPropertySqlParameterSource; import org.springframework.jdbc.core.namedparam.SqlParameterSource; import org.springframework.jdbc.core.simple.SimpleJdbcInsert; import org.springframework.stereotype.Repository; -import subway.domain.Station; - -import javax.sql.DataSource; -import java.util.List; +import subway.entity.StationEntity; @Repository public class StationDao { - private final JdbcTemplate jdbcTemplate; - private final SimpleJdbcInsert insertAction; - - private RowMapper rowMapper = (rs, rowNum) -> - new Station( - rs.getLong("id"), - rs.getString("name") - ); + private final JdbcTemplate jdbcTemplate; + private final SimpleJdbcInsert jdbcInsert; - public StationDao(JdbcTemplate jdbcTemplate, DataSource dataSource) { + public StationDao(final JdbcTemplate jdbcTemplate) { this.jdbcTemplate = jdbcTemplate; - this.insertAction = new SimpleJdbcInsert(dataSource) + jdbcInsert = new SimpleJdbcInsert(jdbcTemplate) .withTableName("station") - .usingGeneratedKeyColumns("id"); - } - - public Station insert(Station station) { - SqlParameterSource params = new BeanPropertySqlParameterSource(station); - Long id = insertAction.executeAndReturnKey(params).longValue(); - return new Station(id, station.getName()); - } - - public List findAll() { - String sql = "select * from STATION"; - return jdbcTemplate.query(sql, rowMapper); - } - - public Station findById(Long id) { - String sql = "select * from STATION where id = ?"; - return jdbcTemplate.queryForObject(sql, rowMapper, id); + .usingGeneratedKeyColumns("id") + .usingColumns("name"); } - public void update(Station newStation) { - String sql = "update STATION set name = ? where id = ?"; - jdbcTemplate.update(sql, new Object[]{newStation.getName(), newStation.getId()}); + public StationEntity save(final StationEntity stationEntity) { + final SqlParameterSource parameterSource = new BeanPropertySqlParameterSource(stationEntity); + final Long stationId = jdbcInsert.executeAndReturnKey(parameterSource).longValue(); + return new StationEntity(stationId, stationEntity.getName()); } - public void deleteById(Long id) { - String sql = "delete from STATION where id = ?"; - jdbcTemplate.update(sql, id); + public Optional findById(final Long stationId) { + final String sql = "SELECT id, name FROM station WHERE id = ?"; + try { + final StationEntity result = jdbcTemplate.queryForObject( + sql, + (rs, rowNum) -> new StationEntity( + rs.getLong("id"), + rs.getString("name") + ), + stationId); + return Optional.ofNullable(result); + } catch (EmptyResultDataAccessException exception) { + return Optional.empty(); + } } } diff --git a/src/main/java/subway/domain/Line.java b/src/main/java/subway/domain/Line.java deleted file mode 100644 index 699d0b6df..000000000 --- a/src/main/java/subway/domain/Line.java +++ /dev/null @@ -1,48 +0,0 @@ -package subway.domain; - -import java.util.Objects; - -public class Line { - private Long id; - private String name; - private String color; - - public Line() { - } - - public Line(String name, String color) { - this.name = name; - this.color = color; - } - - public Line(Long id, String name, String color) { - this.id = id; - this.name = name; - this.color = color; - } - - public Long getId() { - return id; - } - - public String getName() { - return name; - } - - public String getColor() { - return color; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Line line = (Line) o; - return Objects.equals(id, line.id) && Objects.equals(name, line.name) && Objects.equals(color, line.color); - } - - @Override - public int hashCode() { - return Objects.hash(id, name, color); - } -} diff --git a/src/main/java/subway/domain/Station.java b/src/main/java/subway/domain/Station.java deleted file mode 100644 index dbf9d7835..000000000 --- a/src/main/java/subway/domain/Station.java +++ /dev/null @@ -1,41 +0,0 @@ -package subway.domain; - -import java.util.Objects; - -public class Station { - private Long id; - private String name; - - public Station() { - } - - public Station(Long id, String name) { - this.id = id; - this.name = name; - } - - public Station(String name) { - this.name = name; - } - - public Long getId() { - return id; - } - - public String getName() { - return name; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Station station = (Station) o; - return id.equals(station.id) && name.equals(station.name); - } - - @Override - public int hashCode() { - return Objects.hash(id, name); - } -} diff --git a/src/main/java/subway/domain/fare/AgeFareStrategy.java b/src/main/java/subway/domain/fare/AgeFareStrategy.java new file mode 100644 index 000000000..0f016f67e --- /dev/null +++ b/src/main/java/subway/domain/fare/AgeFareStrategy.java @@ -0,0 +1,13 @@ +package subway.domain.fare; + +import subway.domain.subway.Passenger; +import subway.domain.subway.Subway; + +class AgeFareStrategy implements FareStrategy { + + @Override + public double calculateFare(final double fare, final Passenger passenger, final Subway subway) { + final AgePolicy agePolicy = AgePolicy.search(passenger.getAge()); + return agePolicy.calculateDiscountFare(fare); + } +} diff --git a/src/main/java/subway/domain/fare/AgePolicy.java b/src/main/java/subway/domain/fare/AgePolicy.java new file mode 100644 index 000000000..33e02de78 --- /dev/null +++ b/src/main/java/subway/domain/fare/AgePolicy.java @@ -0,0 +1,42 @@ +package subway.domain.fare; + +import java.util.Arrays; +import java.util.function.Function; +import subway.exception.InvalidPolicyException; + +public enum AgePolicy { + + BABY(age -> age < 6, fare -> fare), + KID(age -> 6 <= age && age < 13, fare -> { + final double deductionFare = Math.max(0, fare - 350); + return Math.max(0, deductionFare - deductionFare * 0.5); + }), + TEEN(age -> 13 <= age && age < 19, fare -> { + final double deductionFare = Math.max(0, fare - 350); + return Math.max(0, deductionFare - deductionFare * 0.2); + }), + ADULT(age -> age >= 19, fare -> fare); + + private final Function validTarget; + private final Function discount; + + AgePolicy(final Function validTarget, final Function discount) { + this.validTarget = validTarget; + this.discount = discount; + } + + public static AgePolicy search(final int age) { + return Arrays.stream(AgePolicy.values()) + .filter(agePolicy -> agePolicy.canBeApplied(age)) + .findFirst() + .orElseThrow(() -> new InvalidPolicyException("적용할 수 있는 정책이 존재하지 않습니다.")); + } + + private boolean canBeApplied(final int age) { + return validTarget.apply(age); + } + + public double calculateDiscountFare(final double fare) { + return discount.apply(fare); + } +} diff --git a/src/main/java/subway/domain/fare/DistanceFareStrategy.java b/src/main/java/subway/domain/fare/DistanceFareStrategy.java new file mode 100644 index 000000000..3a37e254e --- /dev/null +++ b/src/main/java/subway/domain/fare/DistanceFareStrategy.java @@ -0,0 +1,20 @@ +package subway.domain.fare; + +import subway.domain.subway.Passenger; +import subway.domain.subway.Subway; + +class DistanceFareStrategy implements FareStrategy { + + private static final int DEFAULT_FARE = 1250; + + @Override + public double calculateFare(final double fare, final Passenger passenger, final Subway subway) { + final long distance = subway.calculateShortestDistance(passenger.getStart(), passenger.getEnd()); + + long totalFare = DEFAULT_FARE; + for (final DistancePolicy distancePolicy : DistancePolicy.values()) { + totalFare += distancePolicy.calculateAdditionFare(distance); + } + return fare + totalFare; + } +} diff --git a/src/main/java/subway/domain/fare/DistancePolicy.java b/src/main/java/subway/domain/fare/DistancePolicy.java new file mode 100644 index 000000000..215eecf75 --- /dev/null +++ b/src/main/java/subway/domain/fare/DistancePolicy.java @@ -0,0 +1,34 @@ +package subway.domain.fare; + +import java.util.function.Function; + +public enum DistancePolicy { + + BASE_FIFTY(50, distance -> { + final long additionDistance = distance - 50; + return calculateDistanceFarePerStep(additionDistance, 8); + }), + BASE_TEN(10, distance -> { + final long additionDistance = Math.min(BASE_FIFTY.base, distance) - 10; + return calculateDistanceFarePerStep(additionDistance, 5); + }); + + private final int base; + private final Function policy; + + DistancePolicy(final int base, final Function policy) { + this.base = base; + this.policy = policy; + } + + private static long calculateDistanceFarePerStep(final long distance, final int step) { + if (distance <= 0) { + return 0L; + } + return (long) ((Math.ceil((distance - 1) / step) + 1) * 100); + } + + public long calculateAdditionFare(final long distance) { + return policy.apply(distance); + } +} diff --git a/src/main/java/subway/domain/fare/FareStrategy.java b/src/main/java/subway/domain/fare/FareStrategy.java new file mode 100644 index 000000000..67949f1c6 --- /dev/null +++ b/src/main/java/subway/domain/fare/FareStrategy.java @@ -0,0 +1,9 @@ +package subway.domain.fare; + +import subway.domain.subway.Passenger; +import subway.domain.subway.Subway; + +public interface FareStrategy { + + double calculateFare(final double fare, final Passenger passenger, final Subway subway); +} diff --git a/src/main/java/subway/domain/fare/FareStrategyComposite.java b/src/main/java/subway/domain/fare/FareStrategyComposite.java new file mode 100644 index 000000000..8c4d52c4d --- /dev/null +++ b/src/main/java/subway/domain/fare/FareStrategyComposite.java @@ -0,0 +1,25 @@ +package subway.domain.fare; + +import java.util.List; +import org.springframework.stereotype.Component; +import subway.domain.subway.Passenger; +import subway.domain.subway.Subway; + +@Component +public class FareStrategyComposite implements FareStrategy { + + private final List strategies = List.of( + new DistanceFareStrategy(), + new RouteFareStrategy(), + new AgeFareStrategy() + ); + + @Override + public double calculateFare(final double fare, final Passenger passenger, final Subway subway) { + double baseFare = fare; + for (final FareStrategy strategy : strategies) { + baseFare = strategy.calculateFare(baseFare, passenger, subway); + } + return baseFare; + } +} diff --git a/src/main/java/subway/domain/fare/RouteFareStrategy.java b/src/main/java/subway/domain/fare/RouteFareStrategy.java new file mode 100644 index 000000000..b815a0409 --- /dev/null +++ b/src/main/java/subway/domain/fare/RouteFareStrategy.java @@ -0,0 +1,22 @@ +package subway.domain.fare; + +import java.util.List; +import subway.domain.section.PathSection; +import subway.domain.subway.Passenger; +import subway.domain.subway.Subway; + +class RouteFareStrategy implements FareStrategy { + + @Override + public double calculateFare(final double fare, final Passenger passenger, final Subway subway) { + final List sections = subway.findShortestPathSections(passenger.getStart(), passenger.getEnd()); + final int highestFareOfLine = findHighestFareOfLine(sections); + return fare + highestFareOfLine; + } + + private int findHighestFareOfLine(final List sections) { + return sections.stream() + .map(PathSection::getFareOfLine) + .reduce(0, Integer::max); + } +} diff --git a/src/main/java/subway/domain/line/Color.java b/src/main/java/subway/domain/line/Color.java new file mode 100644 index 000000000..5abc814ce --- /dev/null +++ b/src/main/java/subway/domain/line/Color.java @@ -0,0 +1,34 @@ +package subway.domain.line; + +import java.util.Objects; +import java.util.regex.Pattern; +import subway.exception.InvalidColorException; + +final class Color { + + private static final int MAXIMUM_LENGTH = 11; + private static final Pattern PATTERN = Pattern.compile("^[가-힣]+색$"); + + private final String value; + + public Color(final String value) { + validate(value); + this.value = value; + } + + private void validate(final String value) { + if (Objects.isNull(value) || value.isBlank()) { + throw new InvalidColorException("노선 색은 존재해야 합니다."); + } + if (value.length() > MAXIMUM_LENGTH) { + throw new InvalidColorException("노선 색은 " + MAXIMUM_LENGTH + "글자까지 가능합니다."); + } + if (!PATTERN.matcher(value).matches()) { + throw new InvalidColorException("색 이름은 한글로 이루어져 있고, '색'으로 끝나야 합니다."); + } + } + + public String getValue() { + return value; + } +} diff --git a/src/main/java/subway/domain/line/Fare.java b/src/main/java/subway/domain/line/Fare.java new file mode 100644 index 000000000..a968ed3b8 --- /dev/null +++ b/src/main/java/subway/domain/line/Fare.java @@ -0,0 +1,25 @@ +package subway.domain.line; + +import subway.exception.InvalidFareException; + +public class Fare { + + private static final int MINIMUM_VALUE = 0; + + private final int value; + + public Fare(final int value) { + validate(value); + this.value = value; + } + + private void validate(final int value) { + if (value < MINIMUM_VALUE) { + throw new InvalidFareException("노선 추가 요금은 " + MINIMUM_VALUE + "원보다 커야 합니다."); + } + } + + public int getValue() { + return value; + } +} diff --git a/src/main/java/subway/domain/line/Line.java b/src/main/java/subway/domain/line/Line.java new file mode 100644 index 000000000..1095073fc --- /dev/null +++ b/src/main/java/subway/domain/line/Line.java @@ -0,0 +1,192 @@ +package subway.domain.line; + +import java.util.LinkedList; +import java.util.List; +import subway.domain.section.Section; +import subway.domain.section.Sections; +import subway.domain.station.Station; +import subway.exception.InvalidDistanceException; +import subway.exception.InvalidSectionException; + +public final class Line { + + private static final int ADDITIONAL_INDEX = -1; + private static final int INITIAL_SECTION_SIZE = 2; + + private final Long id; + private final Name name; + private final Color color; + private final Fare fare; + private final Sections sections; + + public Line(final String name, final String color, final int fare) { + this(null, name, color, fare); + } + + public Line(final Long id, final String name, final String color, final int fare) { + this(id, name, color, fare, new LinkedList<>()); + } + + public Line(final Long id, final String name, final String color, final int fare, final List
sections) { + this.id = id; + this.name = new Name(name); + this.color = new Color(color); + this.fare = new Fare(fare); + this.sections = new Sections(sections); + } + + public void addSection(final Station upward, final Station downward, final int distance) { + if (sections.isEmpty()) { + sections.add(new Section(upward, downward, distance)); + sections.add(new Section(downward, Station.TERMINAL, 0)); + return; + } + + final int upwardPosition = sections.findPosition(upward); + final int downwardPosition = sections.findPosition(downward); + validateForAddSection(upwardPosition, downwardPosition); + + if (upwardPosition == ADDITIONAL_INDEX) { + if (isAddAtFront(downwardPosition)) { + addSectionEndPoints(true, upward, downward, distance); + return; + } + addUpwardSectionBetweenStations(upward, downward, distance, downwardPosition); + } + + if (downwardPosition == ADDITIONAL_INDEX) { + if (isAddAtEnd(upwardPosition)) { + addSectionEndPoints(false, upward, downward, distance); + return; + } + addDownwardSectionBetweenStations(upward, downward, distance, upwardPosition); + } + } + + private void validateForAddSection(final int upwardPosition, final int downwardPosition) { + if (upwardPosition != Sections.NOT_EXIST_INDEX && downwardPosition != Sections.NOT_EXIST_INDEX) { + throw new InvalidSectionException("두 역이 이미 노선에 존재합니다."); + } + if (upwardPosition == Sections.NOT_EXIST_INDEX && downwardPosition == Sections.NOT_EXIST_INDEX) { + throw new InvalidSectionException("연결할 역 정보가 없습니다."); + } + } + + private boolean isAddAtFront(final int downwardPosition) { + return downwardPosition == 0; + } + + private boolean isAddAtEnd(final int upwardPosition) { + return upwardPosition == sections.size() - 1; + } + + private void addSectionEndPoints( + final boolean isFirst, + final Station upward, + final Station downward, + final int distance + ) { + sections.deleteByPosition(sections.size() - 1); + sections.add(getEndPosition(isFirst), new Section(upward, downward, distance)); + final Section lastSection = sections.findSectionByPosition(sections.size() - 1); + sections.add(sections.size(), new Section(lastSection.getDownward(), Station.TERMINAL, 0)); + } + + public int getEndPosition(final boolean isFirst) { + if (isFirst) { + return 0; + } + return sections.size(); + } + + private void addUpwardSectionBetweenStations( + final Station upward, + final Station downward, + final int distance, + final int position + ) { + final int targetPosition = position - 1; + final Section section = sections.findSectionByPosition(targetPosition); + sections.deleteByPosition(targetPosition); + validateDistance(section.getDistance(), distance); + sections.add(targetPosition, new Section(upward, downward, distance)); + sections.add(targetPosition, new Section(section.getUpward(), upward, section.getDistance() - distance)); + } + + private void addDownwardSectionBetweenStations( + final Station upward, + final Station downward, + final int distance, + final int position + ) { + final Section section = sections.findSectionByPosition(position); + sections.deleteByPosition(position); + validateDistance(section.getDistance(), distance); + sections.add(position, new Section(downward, section.getDownward(), section.getDistance() - distance)); + sections.add(position, new Section(upward, downward, distance)); + } + + private void validateDistance(final int oldDistance, final int inputDistance) { + if (oldDistance <= inputDistance) { + throw new InvalidDistanceException("추가될 역의 거리는 추가될 위치의 두 역사이의 거리보다 작아야합니다."); + } + } + + public void deleteStation(final Station station) { + final int position = sections.findPosition(station); + if (position == Sections.NOT_EXIST_INDEX) { + throw new InvalidSectionException("노선에 해당 역이 존재하지 않습니다."); + } + + if (sections.size() == INITIAL_SECTION_SIZE) { + sections.clear(); + return; + } + + if (position == 0) { + sections.deleteByPosition(position); + return; + } + + final Section targetSection = sections.findSectionByPosition(position); + final Section previousSection = sections.findSectionByPosition(position - 1); + + sections.deleteByPosition(position - 1); + sections.deleteByPosition(position - 1); + + sections.add( + position - 1, + new Section( + previousSection.getUpward(), + targetSection.getDownward(), + targetSection.getDistance() + previousSection.getDistance() + ) + ); + } + + public Long getId() { + return id; + } + + public String getName() { + return name.getValue(); + } + + public String getColor() { + return color.getValue(); + } + + public int getFare() { + return fare.getValue(); + } + + public List getStations() { + return sections.getUpwards(); + } + + public List
getSections() { + final List
sections = this.sections.getValue(); + sections.removeIf(section -> section.getDownward() == Station.TERMINAL); + return sections; + } +} diff --git a/src/main/java/subway/domain/line/Name.java b/src/main/java/subway/domain/line/Name.java new file mode 100644 index 000000000..c44bc6f61 --- /dev/null +++ b/src/main/java/subway/domain/line/Name.java @@ -0,0 +1,30 @@ +package subway.domain.line; + +import java.util.Objects; +import java.util.regex.Pattern; +import subway.exception.InvalidLineNameException; + +final class Name { + + private static final Pattern PATTERN = Pattern.compile("^[1-9]호선$"); + + private final String value; + + public Name(final String value) { + validate(value); + this.value = value; + } + + private void validate(final String value) { + if (Objects.isNull(value) || value.isBlank()) { + throw new InvalidLineNameException("노선 이름은 존재해야 합니다."); + } + if (!PATTERN.matcher(value).matches()) { + throw new InvalidLineNameException("노선 이름은 1~9호선이어야 합니다."); + } + } + + public String getValue() { + return value; + } +} diff --git a/src/main/java/subway/domain/section/Distance.java b/src/main/java/subway/domain/section/Distance.java new file mode 100644 index 000000000..79d1f107d --- /dev/null +++ b/src/main/java/subway/domain/section/Distance.java @@ -0,0 +1,25 @@ +package subway.domain.section; + +import subway.exception.InvalidDistanceException; + +final class Distance { + + private static final int MINIMUM_VALUE = 0; + + private final int value; + + public Distance(final int value) { + validate(value); + this.value = value; + } + + private void validate(final int value) { + if (value < MINIMUM_VALUE) { + throw new InvalidDistanceException("역 사이의 거리는 0이상이어야합니다."); + } + } + + public int getValue() { + return value; + } +} diff --git a/src/main/java/subway/domain/section/PathSection.java b/src/main/java/subway/domain/section/PathSection.java new file mode 100644 index 000000000..1f8aced18 --- /dev/null +++ b/src/main/java/subway/domain/section/PathSection.java @@ -0,0 +1,58 @@ +package subway.domain.section; + +import subway.domain.line.Fare; +import subway.domain.station.Station; +import subway.domain.subway.LineWeightedEdge; + +public class PathSection { + + private final Long lineId; + private final Station source; + private final Station target; + private final Distance distance; + private final Fare fareOfLine; + + public PathSection( + final Long lineId, + final Station source, + final Station target, + final int distance, + final int fareOfLine + ) { + this.lineId = lineId; + this.source = source; + this.target = target; + this.distance = new Distance(distance); + this.fareOfLine = new Fare(fareOfLine); + } + + public static PathSection from(final LineWeightedEdge edge) { + return new PathSection( + edge.getLineId(), + edge.getSource(), + edge.getTarget(), + (int) edge.getWeight(), + edge.getFareOfLine() + ); + } + + public Long getLineId() { + return lineId; + } + + public Station getSource() { + return source; + } + + public Station getTarget() { + return target; + } + + public int getDistance() { + return distance.getValue(); + } + + public int getFareOfLine() { + return fareOfLine.getValue(); + } +} diff --git a/src/main/java/subway/domain/section/Section.java b/src/main/java/subway/domain/section/Section.java new file mode 100644 index 000000000..f6f22eab9 --- /dev/null +++ b/src/main/java/subway/domain/section/Section.java @@ -0,0 +1,35 @@ +package subway.domain.section; + +import subway.domain.station.Station; +import subway.entity.SectionEntity; + +public final class Section { + + private final Station upward; + private final Station downward; + private final Distance distance; + + public Section(final Station upward, final Station downward, final int distance) { + this.upward = upward; + this.downward = downward; + this.distance = new Distance(distance); + } + + public static Section from(final SectionEntity sectionEntity) { + final Station upward = new Station(sectionEntity.getUpwardStationId(), sectionEntity.getUpwardStation()); + final Station downward = new Station(sectionEntity.getDownwardStationId(), sectionEntity.getDownwardStation()); + return new Section(upward, downward, sectionEntity.getDistance()); + } + + public Station getUpward() { + return upward; + } + + public Station getDownward() { + return downward; + } + + public int getDistance() { + return distance.getValue(); + } +} diff --git a/src/main/java/subway/domain/section/Sections.java b/src/main/java/subway/domain/section/Sections.java new file mode 100644 index 000000000..2c02aeece --- /dev/null +++ b/src/main/java/subway/domain/section/Sections.java @@ -0,0 +1,63 @@ +package subway.domain.section; + +import java.util.LinkedList; +import java.util.List; +import java.util.stream.Collectors; +import subway.domain.station.Station; + +public final class Sections { + + public static final int NOT_EXIST_INDEX = -1; + + private final List
sections; + + public Sections(final List
sections) { + this.sections = sections; + } + + public void add(final Section section) { + sections.add(section); + } + + public void add(final int position, final Section section) { + sections.add(position, section); + } + + public Section findSectionByPosition(final int position) { + return sections.get(position); + } + + public void deleteByPosition(final int position) { + sections.remove(position); + } + + public int findPosition(final Station station) { + try { + return getUpwards().indexOf(station); + } catch (NullPointerException exception) { + return NOT_EXIST_INDEX; + } + } + + public boolean isEmpty() { + return sections.isEmpty(); + } + + public int size() { + return sections.size(); + } + + public void clear() { + sections.clear(); + } + + public List getUpwards() { + return sections.stream() + .map(Section::getUpward) + .collect(Collectors.toList()); + } + + public List
getValue() { + return new LinkedList<>(sections); + } +} diff --git a/src/main/java/subway/domain/station/Name.java b/src/main/java/subway/domain/station/Name.java new file mode 100644 index 000000000..e61a3470a --- /dev/null +++ b/src/main/java/subway/domain/station/Name.java @@ -0,0 +1,34 @@ +package subway.domain.station; + +import java.util.Objects; +import java.util.regex.Pattern; +import subway.exception.InvalidStationNameException; + +final class Name { + + private static final int MAXIMUM_LENGTH = 11; + private static final Pattern PATTERN = Pattern.compile("^[가-힣0-9]+역$"); + + private final String value; + + public Name(final String value) { + validate(value); + this.value = value; + } + + private void validate(final String value) { + if (Objects.isNull(value) || value.isBlank()) { + throw new InvalidStationNameException("역 이름은 공백일 수 없습니다."); + } + if (value.length() > MAXIMUM_LENGTH) { + throw new InvalidStationNameException("역 이름은 " + MAXIMUM_LENGTH + "글자까지 가능합니다."); + } + if (!PATTERN.matcher(value).matches()) { + throw new InvalidStationNameException("역 이름은 한글, 숫자만 가능하고, '역'으로 끝나야 합니다."); + } + } + + public String getValue() { + return value; + } +} diff --git a/src/main/java/subway/domain/station/Station.java b/src/main/java/subway/domain/station/Station.java new file mode 100644 index 000000000..77dd378d9 --- /dev/null +++ b/src/main/java/subway/domain/station/Station.java @@ -0,0 +1,45 @@ +package subway.domain.station; + +import java.util.Objects; + +public final class Station { + + public static final Station TERMINAL = new Station(0L, "종착역"); + + private final Long id; + private final Name name; + + public Station(final String name) { + this(null, name); + } + + public Station(final Long id, final String name) { + this.id = id; + this.name = new Name(name); + } + + public Long getId() { + return id; + } + + public String getName() { + return name.getValue(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final Station station = (Station) o; + return Objects.equals(id, station.id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } +} diff --git a/src/main/java/subway/domain/subway/LineWeightedEdge.java b/src/main/java/subway/domain/subway/LineWeightedEdge.java new file mode 100644 index 000000000..8187f0852 --- /dev/null +++ b/src/main/java/subway/domain/subway/LineWeightedEdge.java @@ -0,0 +1,46 @@ +package subway.domain.subway; + +import org.jgrapht.graph.DefaultWeightedEdge; +import subway.domain.station.Station; + +public class LineWeightedEdge extends DefaultWeightedEdge { + + private Long lineId; + private int fareOfLine; + + public LineWeightedEdge() { + } + + public LineWeightedEdge(final Long lineId, final int fareOfLine) { + this.lineId = lineId; + this.fareOfLine = fareOfLine; + } + + public Station getSource() { + return (Station) super.getSource(); + } + + public Station getTarget() { + return (Station) super.getTarget(); + } + + public double getWeight() { + return super.getWeight(); + } + + public Long getLineId() { + return lineId; + } + + public void setLineId(final Long lineId) { + this.lineId = lineId; + } + + public int getFareOfLine() { + return fareOfLine; + } + + public void setFareOfLine(final int fareOfLine) { + this.fareOfLine = fareOfLine; + } +} diff --git a/src/main/java/subway/domain/subway/Passenger.java b/src/main/java/subway/domain/subway/Passenger.java new file mode 100644 index 000000000..98a21033c --- /dev/null +++ b/src/main/java/subway/domain/subway/Passenger.java @@ -0,0 +1,28 @@ +package subway.domain.subway; + +import subway.domain.station.Station; + +public class Passenger { + + private final int age; + private final Station start; + private final Station end; + + public Passenger(final int age, final Station start, final Station end) { + this.age = age; + this.start = start; + this.end = end; + } + + public int getAge() { + return age; + } + + public Station getStart() { + return start; + } + + public Station getEnd() { + return end; + } +} diff --git a/src/main/java/subway/domain/subway/Subway.java b/src/main/java/subway/domain/subway/Subway.java new file mode 100644 index 000000000..c2cf8d1a6 --- /dev/null +++ b/src/main/java/subway/domain/subway/Subway.java @@ -0,0 +1,22 @@ +package subway.domain.subway; + +import java.util.List; +import subway.domain.section.PathSection; +import subway.domain.station.Station; + +public class Subway { + + private final SubwayGraph subwayGraph; + + public Subway(final SubwayGraph subwayGraph) { + this.subwayGraph = subwayGraph; + } + + public List findShortestPathSections(final Station start, final Station end) { + return subwayGraph.findShortestPathSections(start, end); + } + + public long calculateShortestDistance(final Station start, final Station end) { + return subwayGraph.calculateShortestDistance(start, end); + } +} diff --git a/src/main/java/subway/domain/subway/SubwayGraph.java b/src/main/java/subway/domain/subway/SubwayGraph.java new file mode 100644 index 000000000..89f692cc2 --- /dev/null +++ b/src/main/java/subway/domain/subway/SubwayGraph.java @@ -0,0 +1,12 @@ +package subway.domain.subway; + +import java.util.List; +import subway.domain.section.PathSection; +import subway.domain.station.Station; + +public interface SubwayGraph { + + List findShortestPathSections(final Station start, final Station end); + + long calculateShortestDistance(final Station start, final Station end); +} diff --git a/src/main/java/subway/domain/subway/SubwayJgraphtGraph.java b/src/main/java/subway/domain/subway/SubwayJgraphtGraph.java new file mode 100644 index 000000000..1141876bb --- /dev/null +++ b/src/main/java/subway/domain/subway/SubwayJgraphtGraph.java @@ -0,0 +1,80 @@ +package subway.domain.subway; + +import java.util.List; +import java.util.stream.Collectors; +import org.jgrapht.alg.shortestpath.DijkstraShortestPath; +import org.jgrapht.graph.DefaultDirectedWeightedGraph; +import subway.domain.line.Line; +import subway.domain.section.PathSection; +import subway.domain.section.Section; +import subway.domain.station.Station; +import subway.exception.InvalidStationException; + +public class SubwayJgraphtGraph implements SubwayGraph { + + private final DijkstraShortestPath dijkstraShortestPath; + + public SubwayJgraphtGraph(final List lines) { + this.dijkstraShortestPath = new DijkstraShortestPath(generateGraph(lines)); + } + + private DefaultDirectedWeightedGraph generateGraph(final List lines) { + final DefaultDirectedWeightedGraph graph = + new DefaultDirectedWeightedGraph<>(LineWeightedEdge.class); + + for (final Line line : lines) { + drawGraph(graph, line); + } + return graph; + } + + private void drawGraph(final DefaultDirectedWeightedGraph graph, final Line line) { + for (final Section section : line.getSections()) { + addVertex(graph, section); + addEdge(graph, line, section); + } + } + + private void addVertex(final DefaultDirectedWeightedGraph graph, final Section section) { + graph.addVertex(section.getUpward()); + graph.addVertex(section.getDownward()); + } + + private void addEdge( + final DefaultDirectedWeightedGraph graph, + final Line line, + final Section section + ) { + final LineWeightedEdge upwardEdge = graph.addEdge(section.getUpward(), section.getDownward()); + final LineWeightedEdge downwardEdge = graph.addEdge(section.getDownward(), section.getUpward()); + + upwardEdge.setLineId(line.getId()); + upwardEdge.setFareOfLine(line.getFare()); + downwardEdge.setLineId(line.getId()); + downwardEdge.setFareOfLine(line.getFare()); + + graph.setEdgeWeight(upwardEdge, section.getDistance()); + graph.setEdgeWeight(downwardEdge, section.getDistance()); + } + + @Override + public List findShortestPathSections(final Station start, final Station end) { + try { + final List edges = dijkstraShortestPath.getPath(start, end).getEdgeList(); + return edges.stream() + .map(PathSection::from) + .collect(Collectors.toList()); + } catch (IllegalArgumentException e) { + throw new InvalidStationException("노선 구간에 등록되지 않은 역 이름을 통해 경로를 조회할 수 없습니다."); + } + } + + @Override + public long calculateShortestDistance(final Station start, final Station end) { + try { + return (long) dijkstraShortestPath.getPath(start, end).getWeight(); + } catch (IllegalArgumentException e) { + throw new InvalidStationException("노선 구간에 등록되지 않은 역 이름을 통해 경로를 조회할 수 없습니다."); + } + } +} diff --git a/src/main/java/subway/dto/LineRequest.java b/src/main/java/subway/dto/LineRequest.java deleted file mode 100644 index 16cb5bf76..000000000 --- a/src/main/java/subway/dto/LineRequest.java +++ /dev/null @@ -1,23 +0,0 @@ -package subway.dto; - -public class LineRequest { - private String name; - private String color; - - public LineRequest() { - } - - public LineRequest(String name, String color) { - this.name = name; - this.color = color; - } - - public String getName() { - return name; - } - - public String getColor() { - return color; - } - -} diff --git a/src/main/java/subway/dto/LineResponse.java b/src/main/java/subway/dto/LineResponse.java deleted file mode 100644 index c9b668122..000000000 --- a/src/main/java/subway/dto/LineResponse.java +++ /dev/null @@ -1,31 +0,0 @@ -package subway.dto; - -import subway.domain.Line; - -public class LineResponse { - private Long id; - private String name; - private String color; - - public LineResponse(Long id, String name, String color) { - this.id = id; - this.name = name; - this.color = color; - } - - public static LineResponse of(Line line) { - return new LineResponse(line.getId(), line.getName(), line.getColor()); - } - - public Long getId() { - return id; - } - - public String getName() { - return name; - } - - public String getColor() { - return color; - } -} diff --git a/src/main/java/subway/dto/StationRequest.java b/src/main/java/subway/dto/StationRequest.java deleted file mode 100644 index 15175303d..000000000 --- a/src/main/java/subway/dto/StationRequest.java +++ /dev/null @@ -1,16 +0,0 @@ -package subway.dto; - -public class StationRequest { - private String name; - - public StationRequest() { - } - - public StationRequest(String name) { - this.name = name; - } - - public String getName() { - return name; - } -} diff --git a/src/main/java/subway/dto/StationResponse.java b/src/main/java/subway/dto/StationResponse.java deleted file mode 100644 index 5eec02fe0..000000000 --- a/src/main/java/subway/dto/StationResponse.java +++ /dev/null @@ -1,25 +0,0 @@ -package subway.dto; - -import subway.domain.Station; - -public class StationResponse { - private Long id; - private String name; - - public StationResponse(Long id, String name) { - this.id = id; - this.name = name; - } - - public static StationResponse of(Station station) { - return new StationResponse(station.getId(), station.getName()); - } - - public Long getId() { - return id; - } - - public String getName() { - return name; - } -} diff --git a/src/main/java/subway/entity/LineEntity.java b/src/main/java/subway/entity/LineEntity.java new file mode 100644 index 000000000..fc4a15df4 --- /dev/null +++ b/src/main/java/subway/entity/LineEntity.java @@ -0,0 +1,60 @@ +package subway.entity; + +import java.util.Objects; +import subway.domain.line.Line; + +public class LineEntity { + + private final Long id; + private final String name; + private final String color; + private final Integer fare; + + public LineEntity(final String name, final String color, final Integer fare) { + this(null, name, color, fare); + } + + public LineEntity(final Long id, final String name, final String color, final Integer fare) { + this.id = id; + this.name = name; + this.color = color; + this.fare = fare; + } + + public static LineEntity from(final Line line) { + return new LineEntity(line.getId(), line.getName(), line.getColor(), line.getFare()); + } + + public Long getId() { + return id; + } + + public String getName() { + return name; + } + + public String getColor() { + return color; + } + + public Integer getFare() { + return fare; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final LineEntity that = (LineEntity) o; + return Objects.equals(id, that.id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } +} diff --git a/src/main/java/subway/entity/SectionEntity.java b/src/main/java/subway/entity/SectionEntity.java new file mode 100644 index 000000000..cb9437030 --- /dev/null +++ b/src/main/java/subway/entity/SectionEntity.java @@ -0,0 +1,91 @@ +package subway.entity; + +import java.util.Objects; +import subway.domain.section.Section; + +public class SectionEntity { + + private final Long id; + private final Long lineId; + private final Long upwardStationId; + private final String upwardStation; + private final Long downwardStationId; + private final String downwardStation; + private final Integer distance; + + public SectionEntity( + final Long lineId, + final Long upwardStationId, + final Long downwardStationId, + final Integer distance + ) { + this(null, lineId, upwardStationId, null, downwardStationId, null, distance); + } + + public SectionEntity( + final Long id, final Long lineId, final Long upwardStationId, final String upwardStation, + final Long downwardStationId, final String downwardStation, final Integer distance + ) { + this.id = id; + this.lineId = lineId; + this.upwardStationId = upwardStationId; + this.upwardStation = upwardStation; + this.downwardStationId = downwardStationId; + this.downwardStation = downwardStation; + this.distance = distance; + } + + public static SectionEntity of(final Long lineId, final Section section) { + return new SectionEntity( + lineId, + section.getUpward().getId(), + section.getDownward().getId(), + section.getDistance() + ); + } + + public Long getId() { + return id; + } + + public Long getLineId() { + return lineId; + } + + public Long getUpwardStationId() { + return upwardStationId; + } + + public String getUpwardStation() { + return upwardStation; + } + + public Long getDownwardStationId() { + return downwardStationId; + } + + public String getDownwardStation() { + return downwardStation; + } + + public Integer getDistance() { + return distance; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final SectionEntity that = (SectionEntity) o; + return Objects.equals(id, that.id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } +} diff --git a/src/main/java/subway/entity/StationEntity.java b/src/main/java/subway/entity/StationEntity.java new file mode 100644 index 000000000..d622fd32f --- /dev/null +++ b/src/main/java/subway/entity/StationEntity.java @@ -0,0 +1,48 @@ +package subway.entity; + +import java.util.Objects; +import subway.domain.station.Station; + +public class StationEntity { + + private final Long id; + private final String name; + + public StationEntity(final String name) { + this(null, name); + } + + public StationEntity(final Long id, final String name) { + this.id = id; + this.name = name; + } + + public static StationEntity from(final Station station) { + return new StationEntity(station.getId(), station.getName()); + } + + public Long getId() { + return id; + } + + public String getName() { + return name; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final StationEntity that = (StationEntity) o; + return Objects.equals(id, that.id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } +} diff --git a/src/main/java/subway/exception/InvalidColorException.java b/src/main/java/subway/exception/InvalidColorException.java new file mode 100644 index 000000000..8bdfdea62 --- /dev/null +++ b/src/main/java/subway/exception/InvalidColorException.java @@ -0,0 +1,8 @@ +package subway.exception; + +public final class InvalidColorException extends SubwayException { + + public InvalidColorException(final String message) { + super(message); + } +} diff --git a/src/main/java/subway/exception/InvalidDistanceException.java b/src/main/java/subway/exception/InvalidDistanceException.java new file mode 100644 index 000000000..76f4ee13b --- /dev/null +++ b/src/main/java/subway/exception/InvalidDistanceException.java @@ -0,0 +1,8 @@ +package subway.exception; + +public final class InvalidDistanceException extends SubwayException { + + public InvalidDistanceException(final String message) { + super(message); + } +} diff --git a/src/main/java/subway/exception/InvalidFareException.java b/src/main/java/subway/exception/InvalidFareException.java new file mode 100644 index 000000000..c28a69ce9 --- /dev/null +++ b/src/main/java/subway/exception/InvalidFareException.java @@ -0,0 +1,8 @@ +package subway.exception; + +public class InvalidFareException extends SubwayException { + + public InvalidFareException(final String message) { + super(message); + } +} diff --git a/src/main/java/subway/exception/InvalidLineException.java b/src/main/java/subway/exception/InvalidLineException.java new file mode 100644 index 000000000..6b17e3ad9 --- /dev/null +++ b/src/main/java/subway/exception/InvalidLineException.java @@ -0,0 +1,8 @@ +package subway.exception; + +public class InvalidLineException extends SubwayException { + + public InvalidLineException(final String message) { + super(message); + } +} diff --git a/src/main/java/subway/exception/InvalidLineNameException.java b/src/main/java/subway/exception/InvalidLineNameException.java new file mode 100644 index 000000000..6d2ec852a --- /dev/null +++ b/src/main/java/subway/exception/InvalidLineNameException.java @@ -0,0 +1,8 @@ +package subway.exception; + +public final class InvalidLineNameException extends SubwayException { + + public InvalidLineNameException(final String message) { + super(message); + } +} diff --git a/src/main/java/subway/exception/InvalidPolicyException.java b/src/main/java/subway/exception/InvalidPolicyException.java new file mode 100644 index 000000000..3e46ab393 --- /dev/null +++ b/src/main/java/subway/exception/InvalidPolicyException.java @@ -0,0 +1,8 @@ +package subway.exception; + +public class InvalidPolicyException extends SubwayException { + + public InvalidPolicyException(final String message) { + super(message); + } +} diff --git a/src/main/java/subway/exception/InvalidSectionException.java b/src/main/java/subway/exception/InvalidSectionException.java new file mode 100644 index 000000000..1e955dac7 --- /dev/null +++ b/src/main/java/subway/exception/InvalidSectionException.java @@ -0,0 +1,8 @@ +package subway.exception; + +public class InvalidSectionException extends SubwayException { + + public InvalidSectionException(final String message) { + super(message); + } +} diff --git a/src/main/java/subway/exception/InvalidStationException.java b/src/main/java/subway/exception/InvalidStationException.java new file mode 100644 index 000000000..1ce9fd4bc --- /dev/null +++ b/src/main/java/subway/exception/InvalidStationException.java @@ -0,0 +1,8 @@ +package subway.exception; + +public class InvalidStationException extends SubwayException { + + public InvalidStationException(final String message) { + super(message); + } +} diff --git a/src/main/java/subway/exception/InvalidStationNameException.java b/src/main/java/subway/exception/InvalidStationNameException.java new file mode 100644 index 000000000..4e092c62a --- /dev/null +++ b/src/main/java/subway/exception/InvalidStationNameException.java @@ -0,0 +1,8 @@ +package subway.exception; + +public final class InvalidStationNameException extends SubwayException { + + public InvalidStationNameException(final String message) { + super(message); + } +} diff --git a/src/main/java/subway/exception/SubwayException.java b/src/main/java/subway/exception/SubwayException.java new file mode 100644 index 000000000..6c6023e1b --- /dev/null +++ b/src/main/java/subway/exception/SubwayException.java @@ -0,0 +1,8 @@ +package subway.exception; + +public class SubwayException extends RuntimeException { + + public SubwayException(final String message) { + super(message); + } +} diff --git a/src/main/java/subway/repository/LineRepository.java b/src/main/java/subway/repository/LineRepository.java new file mode 100644 index 000000000..9bb907638 --- /dev/null +++ b/src/main/java/subway/repository/LineRepository.java @@ -0,0 +1,86 @@ +package subway.repository; + +import java.util.List; +import java.util.stream.Collectors; +import org.springframework.stereotype.Repository; +import subway.dao.LineDao; +import subway.dao.SectionDao; +import subway.domain.line.Line; +import subway.domain.section.Section; +import subway.entity.LineEntity; +import subway.entity.SectionEntity; +import subway.exception.InvalidLineException; +import subway.exception.InvalidSectionException; + +@Repository +public class LineRepository { + + private final LineDao lineDao; + private final SectionDao sectionDao; + + public LineRepository(final LineDao lineDao, final SectionDao sectionDao) { + this.lineDao = lineDao; + this.sectionDao = sectionDao; + } + + public Line save(final Line line) { + final LineEntity lineEntity = lineDao.save(LineEntity.from(line)); + return new Line(lineEntity.getId(), lineEntity.getName(), lineEntity.getColor(), lineEntity.getFare()); + } + + public Line findById(final Long lineId) { + final LineEntity lineEntity = lineDao.findById(lineId) + .orElseThrow(() -> new InvalidLineException("존재하지 않는 노선 ID 입니다.")); + final List sectionEntities = sectionDao.findAllByLineId(lineId); + return generateLine(lineEntity, sectionEntities); + } + + private Line generateLine(final LineEntity lineEntity, final List sectionEntities) { + final Line line = new Line( + lineEntity.getId(), + lineEntity.getName(), + lineEntity.getColor(), + lineEntity.getFare() + ); + loadSections(line, generateSections(sectionEntities)); + return line; + } + + private List
generateSections(final List sectionEntities) { + return sectionEntities.stream() + .map(Section::from) + .collect(Collectors.toList()); + } + + private void loadSections(final Line line, final List
sections) { + while (!sections.isEmpty()) { + final Section section = sections.remove(0); + try { + line.addSection(section.getUpward(), section.getDownward(), section.getDistance()); + } catch (InvalidSectionException e) { + sections.add(section); + } + } + } + + public List findAll() { + return lineDao.findAll() + .stream() + .map(lineEntity -> generateLine(lineEntity, sectionDao.findAllByLineId(lineEntity.getId()))) + .collect(Collectors.toList()); + } + + public void update(final Line line) { + lineDao.update(LineEntity.from(line)); + sectionDao.deleteAllByLineId(line.getId()); + final List entities = generateSectionEntities(line); + sectionDao.saveAll(entities); + } + + private List generateSectionEntities(final Line line) { + return line.getSections() + .stream() + .map(section -> SectionEntity.of(line.getId(), section)) + .collect(Collectors.toUnmodifiableList()); + } +} diff --git a/src/main/java/subway/repository/StationRepository.java b/src/main/java/subway/repository/StationRepository.java new file mode 100644 index 000000000..505476bbe --- /dev/null +++ b/src/main/java/subway/repository/StationRepository.java @@ -0,0 +1,28 @@ +package subway.repository; + +import org.springframework.stereotype.Repository; +import subway.dao.StationDao; +import subway.domain.station.Station; +import subway.entity.StationEntity; +import subway.exception.InvalidStationException; + +@Repository +public class StationRepository { + + private final StationDao stationDao; + + public StationRepository(final StationDao stationDao) { + this.stationDao = stationDao; + } + + public Station save(final Station station) { + final StationEntity stationEntity = stationDao.save(StationEntity.from(station)); + return new Station(stationEntity.getId(), stationEntity.getName()); + } + + public Station findById(final Long stationId) { + final StationEntity stationEntity = stationDao.findById(stationId) + .orElseThrow(() -> new InvalidStationException("존재하지 않은 역 ID입니다.")); + return new Station(stationEntity.getId(), stationEntity.getName()); + } +} diff --git a/src/main/java/subway/service/LineService.java b/src/main/java/subway/service/LineService.java new file mode 100644 index 000000000..8b76a0cbc --- /dev/null +++ b/src/main/java/subway/service/LineService.java @@ -0,0 +1,66 @@ +package subway.service; + +import java.util.List; +import java.util.stream.Collectors; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; +import subway.controller.dto.request.LineCreateRequest; +import subway.controller.dto.request.SectionCreateRequest; +import subway.controller.dto.response.LineResponse; +import subway.controller.dto.response.LinesResponse; +import subway.domain.line.Line; +import subway.domain.station.Station; +import subway.repository.LineRepository; +import subway.repository.StationRepository; + +@Service +@Transactional(readOnly = true) +public class LineService { + + private final LineRepository lineRepository; + private final StationRepository stationRepository; + + public LineService(final LineRepository lineRepository, final StationRepository stationRepository) { + this.lineRepository = lineRepository; + this.stationRepository = stationRepository; + } + + @Transactional + public Long createLine(final LineCreateRequest request) { + final Line line = new Line(request.getName(), request.getColor(), request.getFare()); + return lineRepository.save(line).getId(); + } + + public LineResponse findLineById(final Long lineId) { + final Line line = lineRepository.findById(lineId); + return LineResponse.from(line); + } + + public LinesResponse findLines() { + final List lines = lineRepository.findAll(); + return new LinesResponse(generateLineResponses(lines)); + } + + private List generateLineResponses(final List lines) { + return lines.stream() + .map(LineResponse::from) + .collect(Collectors.toUnmodifiableList()); + } + + @Transactional + public void createSection(final Long lineId, final SectionCreateRequest request) { + final Line line = lineRepository.findById(lineId); + final Station upward = stationRepository.findById(request.getUpwardStationId()); + final Station downward = stationRepository.findById(request.getDownwardStationId()); + line.addSection(upward, downward, request.getDistance()); + lineRepository.update(line); + } + + @Transactional + public void deleteStation(final Long lineId, final Long stationId) { + final Line line = lineRepository.findById(lineId); + final Station station = stationRepository.findById(stationId); + line.deleteStation(station); + lineRepository.update(line); + } +} diff --git a/src/main/java/subway/service/StationService.java b/src/main/java/subway/service/StationService.java new file mode 100644 index 000000000..8ef4077a2 --- /dev/null +++ b/src/main/java/subway/service/StationService.java @@ -0,0 +1,30 @@ +package subway.service; + +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; +import subway.controller.dto.request.StationCreateRequest; +import subway.controller.dto.response.StationResponse; +import subway.domain.station.Station; +import subway.repository.StationRepository; + +@Service +@Transactional(readOnly = true) +public class StationService { + + private final StationRepository stationRepository; + + public StationService(final StationRepository stationRepository) { + this.stationRepository = stationRepository; + } + + @Transactional + public Long createStation(final StationCreateRequest request) { + final Station station = new Station(request.getName()); + return stationRepository.save(station).getId(); + } + + public StationResponse findStationById(final Long stationId) { + final Station station = stationRepository.findById(stationId); + return StationResponse.from(station); + } +} diff --git a/src/main/java/subway/service/SubwayService.java b/src/main/java/subway/service/SubwayService.java new file mode 100644 index 000000000..7337ff9b3 --- /dev/null +++ b/src/main/java/subway/service/SubwayService.java @@ -0,0 +1,61 @@ +package subway.service; + +import java.util.List; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; +import subway.controller.dto.request.PassengerRequest; +import subway.controller.dto.response.ShortestPathResponse; +import subway.domain.fare.FareStrategy; +import subway.domain.line.Line; +import subway.domain.section.PathSection; +import subway.domain.station.Station; +import subway.domain.subway.Passenger; +import subway.domain.subway.Subway; +import subway.domain.subway.SubwayGraph; +import subway.domain.subway.SubwayJgraphtGraph; +import subway.repository.LineRepository; +import subway.repository.StationRepository; + +@Service +@Transactional(readOnly = true) +public class SubwayService { + + private final LineRepository lineRepository; + private final StationRepository stationRepository; + private final FareStrategy fareStrategy; + + public SubwayService( + final LineRepository lineRepository, + final StationRepository stationRepository, + final FareStrategy fareStrategy + ) { + this.lineRepository = lineRepository; + this.stationRepository = stationRepository; + this.fareStrategy = fareStrategy; + } + + public ShortestPathResponse findShortestPath(final PassengerRequest request) { + final List lines = lineRepository.findAll(); + final Subway subway = generateSubway(lines); + final Passenger passenger = generatePassenger(request); + + final List pathSections = + subway.findShortestPathSections(passenger.getStart(), passenger.getEnd()); + final long totalDistance = subway.calculateShortestDistance(passenger.getStart(), passenger.getEnd()); + final long subwayFare = (long) fareStrategy.calculateFare(0, passenger, subway); + + return ShortestPathResponse.of(pathSections, totalDistance, subwayFare); + } + + private Subway generateSubway(final List lines) { + final SubwayGraph subwayGraph = new SubwayJgraphtGraph(lines); + return new Subway(subwayGraph); + } + + private Passenger generatePassenger(final PassengerRequest request) { + final Station start = stationRepository.findById(request.getStartStationId()); + final Station end = stationRepository.findById(request.getEndStationId()); + return new Passenger(request.getAge(), start, end); + } +} + diff --git a/src/main/java/subway/ui/LineController.java b/src/main/java/subway/ui/LineController.java deleted file mode 100644 index 3a335ee14..000000000 --- a/src/main/java/subway/ui/LineController.java +++ /dev/null @@ -1,55 +0,0 @@ -package subway.ui; - -import org.springframework.http.ResponseEntity; -import org.springframework.web.bind.annotation.*; -import subway.application.LineService; -import subway.dto.LineRequest; -import subway.dto.LineResponse; - -import java.net.URI; -import java.sql.SQLException; -import java.util.List; - -@RestController -@RequestMapping("/lines") -public class LineController { - - private final LineService lineService; - - public LineController(LineService lineService) { - this.lineService = lineService; - } - - @PostMapping - public ResponseEntity createLine(@RequestBody LineRequest lineRequest) { - LineResponse line = lineService.saveLine(lineRequest); - return ResponseEntity.created(URI.create("/lines/" + line.getId())).body(line); - } - - @GetMapping - public ResponseEntity> findAllLines() { - return ResponseEntity.ok(lineService.findLineResponses()); - } - - @GetMapping("/{id}") - public ResponseEntity findLineById(@PathVariable Long id) { - return ResponseEntity.ok(lineService.findLineResponseById(id)); - } - - @PutMapping("/{id}") - public ResponseEntity updateLine(@PathVariable Long id, @RequestBody LineRequest lineUpdateRequest) { - lineService.updateLine(id, lineUpdateRequest); - return ResponseEntity.ok().build(); - } - - @DeleteMapping("/{id}") - public ResponseEntity deleteLine(@PathVariable Long id) { - lineService.deleteLineById(id); - return ResponseEntity.noContent().build(); - } - - @ExceptionHandler(SQLException.class) - public ResponseEntity handleSQLException() { - return ResponseEntity.badRequest().build(); - } -} diff --git a/src/main/java/subway/ui/StationController.java b/src/main/java/subway/ui/StationController.java deleted file mode 100644 index 5bf52a9a9..000000000 --- a/src/main/java/subway/ui/StationController.java +++ /dev/null @@ -1,54 +0,0 @@ -package subway.ui; - -import org.springframework.http.ResponseEntity; -import org.springframework.web.bind.annotation.*; -import subway.dto.StationRequest; -import subway.dto.StationResponse; -import subway.application.StationService; - -import java.net.URI; -import java.sql.SQLException; -import java.util.List; - -@RestController -@RequestMapping("/stations") -public class StationController { - private final StationService stationService; - - public StationController(StationService stationService) { - this.stationService = stationService; - } - - @PostMapping - public ResponseEntity createStation(@RequestBody StationRequest stationRequest) { - StationResponse station = stationService.saveStation(stationRequest); - return ResponseEntity.created(URI.create("/stations/" + station.getId())).body(station); - } - - @GetMapping - public ResponseEntity> showStations() { - return ResponseEntity.ok().body(stationService.findAllStationResponses()); - } - - @GetMapping("/{id}") - public ResponseEntity showStation(@PathVariable Long id) { - return ResponseEntity.ok().body(stationService.findStationResponseById(id)); - } - - @PutMapping("/{id}") - public ResponseEntity updateStation(@PathVariable Long id, @RequestBody StationRequest stationRequest) { - stationService.updateStation(id, stationRequest); - return ResponseEntity.ok().build(); - } - - @DeleteMapping("/{id}") - public ResponseEntity deleteStation(@PathVariable Long id) { - stationService.deleteStationById(id); - return ResponseEntity.noContent().build(); - } - - @ExceptionHandler(SQLException.class) - public ResponseEntity handleSQLException() { - return ResponseEntity.badRequest().build(); - } -} diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml new file mode 100644 index 000000000..b3e5032b4 --- /dev/null +++ b/src/main/resources/application.yml @@ -0,0 +1,26 @@ +spring: + datasource: + driver-class-name: com.mysql.cj.jdbc.Driver + url: ${SPRING_DATASOURCE_URL} + username: ${SPRING_DATASOURCE_USERNAME} + password: ${SPRING_DATASOURCE_PASSWORD} + + sql: + init: + schema-locations: classpath:data.sql + data-locations: classpath:dummy.sql + + springdoc: + packages-to-scan: subway + default-consumes-media-type: application/json;charset=UTF-8 + default-produces-media-type: application/json;charset=UTF-8 + swagger-ui: + path: subway.html + tags-sorter: alpha + operations-sorter: alpha + api-docs: + path: /subway-api-docs + groups: + enabled: true + cache: + disabled: true diff --git a/src/main/resources/data.sql b/src/main/resources/data.sql new file mode 100644 index 000000000..cfabf10c3 --- /dev/null +++ b/src/main/resources/data.sql @@ -0,0 +1,28 @@ +CREATE TABLE IF NOT EXISTS station +( + `id` BIGINT AUTO_INCREMENT NOT NULL, + `name` VARCHAR(15) NOT NULL, + PRIMARY KEY (`id`) +); + +CREATE TABLE IF NOT EXISTS line +( + `id` BIGINT AUTO_INCREMENT NOT NULL, + `name` VARCHAR(15) NOT NULL, + `color` VARCHAR(15) NOT NULL, + `fare` INT NOT NULL, + PRIMARY KEY (`id`) +); + +CREATE TABLE IF NOT EXISTS section +( + `id` BIGINT AUTO_INCREMENT NOT NULL, + `line_id` BIGINT NOT NULL, + `upward_station_id` BIGINT NOT NULL, + `downward_station_id` BIGINT NOT NULL, + `distance` INT NOT NULL, + PRIMARY KEY (`id`), + FOREIGN KEY (`line_id`) REFERENCES `line` (`id`), + FOREIGN KEY (`upward_station_id`) REFERENCES `station` (`id`), + FOREIGN KEY (`downward_station_id`) REFERENCES `station` (`id`) +); diff --git a/src/main/resources/dummy.sql b/src/main/resources/dummy.sql new file mode 100644 index 000000000..b29e4530d --- /dev/null +++ b/src/main/resources/dummy.sql @@ -0,0 +1,5 @@ + +INSERT INTO line (id, name, color, fare) VALUES (1, '2호선', '초록색', 500); +INSERT INTO station (id, name) VALUES (1, '잠실역'); +INSERT INTO station (id, name) VALUES (2, '잠실새내역'); +INSERT INTO section (id, line_id, upward_station_id, downward_station_id, distance) VALUES (1, 1, 1, 2, 10); diff --git a/src/main/resources/schema.sql b/src/main/resources/schema.sql deleted file mode 100644 index fc6de5f5c..000000000 --- a/src/main/resources/schema.sql +++ /dev/null @@ -1,14 +0,0 @@ -create table if not exists STATION -( - id bigint auto_increment not null, - name varchar(255) not null unique, - primary key(id) -); - -create table if not exists LINE -( - id bigint auto_increment not null, - name varchar(255) not null unique, - color varchar(20) not null, - primary key(id) -); diff --git a/src/test/java/fixtures/StationFixtures.java b/src/test/java/fixtures/StationFixtures.java new file mode 100644 index 000000000..7722fe643 --- /dev/null +++ b/src/test/java/fixtures/StationFixtures.java @@ -0,0 +1,12 @@ +package fixtures; + +import subway.domain.station.Station; + +public class StationFixtures { + + public static final Station GANGNAM = new Station(1L, "강남역"); + public static final Station YANGJAE = new Station(2L, "양재역"); + public static final Station GYODAE = new Station(3L, "교대역"); + public static final Station NAMBU = new Station(4L, "남부터미널역"); + public static final Station JAMSIL = new Station(5L, "잠실역"); +} diff --git a/src/test/java/subway/Integration/IntegrationTest.java b/src/test/java/subway/Integration/IntegrationTest.java new file mode 100644 index 000000000..0a449c801 --- /dev/null +++ b/src/test/java/subway/Integration/IntegrationTest.java @@ -0,0 +1,20 @@ +package subway.Integration; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.transaction.annotation.Transactional; + +@SpringBootTest +@Transactional +@AutoConfigureMockMvc +public class IntegrationTest { + + @Autowired + protected MockMvc mockMvc; + + @Autowired + protected ObjectMapper objectMapper; +} diff --git a/src/test/java/subway/Integration/LineControllerIntegrationTest.java b/src/test/java/subway/Integration/LineControllerIntegrationTest.java new file mode 100644 index 000000000..7712f9ad3 --- /dev/null +++ b/src/test/java/subway/Integration/LineControllerIntegrationTest.java @@ -0,0 +1,279 @@ +package subway.Integration; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +import java.nio.charset.Charset; +import java.util.List; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.test.web.servlet.MvcResult; +import subway.controller.dto.request.LineCreateRequest; +import subway.controller.dto.request.SectionCreateRequest; +import subway.controller.dto.response.LineResponse; +import subway.controller.dto.response.LinesResponse; +import subway.dao.StationDao; +import subway.domain.line.Line; +import subway.domain.station.Station; +import subway.entity.StationEntity; +import subway.repository.LineRepository; +import subway.repository.StationRepository; + +public class LineControllerIntegrationTest extends IntegrationTest { + + @Autowired + private StationDao stationDao; + + @Autowired + private LineRepository lineRepository; + + @Autowired + private StationRepository stationRepository; + + private Line lineTwo; + private Station upward; + private Station downward; + + @BeforeEach + void setUp() { + lineTwo = lineRepository.save(new Line("2호선", "초록색", 500)); + upward = stationRepository.save(new Station("잠실역")); + downward = stationRepository.save(new Station("잠실새내역")); + lineTwo.addSection(upward, downward, 10); + lineRepository.update(lineTwo); + } + + @Test + @DisplayName("노선 목록을 조회한다.") + void findLines() throws Exception { + final Line lineFour = lineRepository.save(new Line("4호선", "하늘색", 1000)); + final Station lineFourUpward = stationRepository.save(new Station("이수역")); + final Station lineFourDownward = stationRepository.save(new Station("서울역")); + lineFour.addSection(lineFourUpward, lineFourDownward, 10); + lineRepository.update(lineFour); + + final MvcResult mvcResult = mockMvc.perform(get("/lines")) + .andDo(print()) + .andExpect(status().isOk()) + .andReturn(); + + final LinesResponse response = + new LinesResponse(List.of(LineResponse.from(lineTwo), LineResponse.from(lineFour))); + final String jsonResponse = mvcResult.getResponse().getContentAsString(Charset.forName("UTF-8")); + final LinesResponse result = objectMapper.readValue(jsonResponse, LinesResponse.class); + assertThat(result).usingRecursiveComparison().isEqualTo(response); + } + + @Nested + @DisplayName("노선 생성 요청시 ") + class CreateLine { + + @Test + @DisplayName("유효한 노선 정보라면 새로운 노선을 추가한다.") + void createLine() throws Exception { + final LineCreateRequest request = new LineCreateRequest("2호선", "초록색", 500); + + mockMvc.perform(post("/lines") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isCreated()) + .andExpect(header().string(HttpHeaders.LOCATION, containsString("/lines/"))); + } + + @Test + @DisplayName("이름이 공백이라면 400 상태를 반환한다.") + void createLineWithInvalidName() throws Exception { + final LineCreateRequest request = new LineCreateRequest(" ", "초록색", 500); + + mockMvc.perform(post("/lines") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("노선 이름은 공백일 수 없습니다.")); + } + + @Test + @DisplayName("색이 공백이라면 400 상태를 반환한다.") + void createLineWithInvalidColor() throws Exception { + final LineCreateRequest request = new LineCreateRequest("2호선", " ", 500); + + mockMvc.perform(post("/lines") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("노선 색깔은 공백일 수 없습니다.")); + } + + @Test + @DisplayName("역 간의 거리가 입력되지 않으면 400 상태를 반환한다.") + void createSectionWithoutDistance() throws Exception { + final SectionCreateRequest request = new SectionCreateRequest(1L, 2L, null); + + mockMvc.perform(post("/lines/{id}/sections", 1L) + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("역 간의 거리는 존재해야 합니다.")); + } + + @Test + @DisplayName("역 간의 거리가 0이하이면 400 상태를 반환한다.") + void createSectionWithNegativeDistance() throws Exception { + final SectionCreateRequest request = new SectionCreateRequest(1L, 2L, -1); + + mockMvc.perform(post("/lines/{id}/sections", 1L) + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("역 간의 거리는 0보다 커야합니다.")); + } + } + + @Nested + @DisplayName("노선 조회 시 ") + class FindLine { + + @Test + @DisplayName("존재하는 노선이라면 노선 정보를 조회한다.") + void findLine() throws Exception { + final MvcResult mvcResult = mockMvc.perform(get("/lines/{id}", lineTwo.getId())) + .andDo(print()) + .andExpect(status().isOk()) + .andReturn(); + + final LineResponse response = LineResponse.from(lineTwo); + final String jsonResponse = mvcResult.getResponse().getContentAsString(Charset.forName("UTF-8")); + final LineResponse result = objectMapper.readValue(jsonResponse, LineResponse.class); + assertThat(result).usingRecursiveComparison().isEqualTo(response); + } + + @Test + @DisplayName("ID로 변환할 수 없는 타입이라면 400 상태를 반환한다.") + void findLineWithInvalidIDType() throws Exception { + mockMvc.perform(get("/lines/{id}", "l")) + .andDo(print()) + .andExpect(status().isBadRequest()); + } + } + + @Nested + @DisplayName("노선에 역을 등록할 시 ") + class CreateSection { + + @Test + @DisplayName("유효한 정보가 입력되면 노선에 역을 등록한다.") + void createSection() throws Exception { + final StationEntity middle = stationDao.save(new StationEntity("종합운동장역")); + final SectionCreateRequest request = new SectionCreateRequest(upward.getId(), middle.getId(), 5); + + mockMvc.perform(post("/lines/{id}/sections", lineTwo.getId()) + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isCreated()) + .andExpect(header().string(HttpHeaders.LOCATION, containsString("/lines/" + lineTwo.getId()))); + } + + @Test + @DisplayName("상행 역 ID가 입력되지 않으면 400 상태를 반환한다.") + void createSectionWithoutUpwardStationId() throws Exception { + final SectionCreateRequest request = new SectionCreateRequest(null, downward.getId(), 10); + + mockMvc.perform(post("/lines/{id}/sections", lineTwo.getId()) + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("상행 역 ID는 존재해야 합니다.")); + } + + @Test + @DisplayName("하행 역 ID가 입력되지 않으면 400 상태를 반환한다.") + void createSectionWithoutDownwardStationId() throws Exception { + final SectionCreateRequest request = new SectionCreateRequest(upward.getId(), null, 10); + + mockMvc.perform(post("/lines/{id}/sections", lineTwo.getId()) + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("하행 역 ID는 존재해야 합니다.")); + } + + @Test + @DisplayName("역 간의 거리가 입력되지 않으면 400 상태를 반환한다.") + void createSectionWithoutDistance() throws Exception { + final StationEntity middle = stationDao.save(new StationEntity("종합운동장역")); + final SectionCreateRequest request = new SectionCreateRequest(upward.getId(), middle.getId(), null); + + mockMvc.perform(post("/lines/{id}/sections", lineTwo.getId()) + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("역 간의 거리는 존재해야 합니다.")); + } + + @Test + @DisplayName("역 간의 거리가 0이하이면 400 상태를 반환한다.") + void createSectionWithNegativeDistance() throws Exception { + final StationEntity middle = stationDao.save(new StationEntity("종합운동장역")); + final SectionCreateRequest request = new SectionCreateRequest(upward.getId(), middle.getId(), -1); + + mockMvc.perform(post("/lines/{id}/sections", lineTwo.getId()) + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("역 간의 거리는 0보다 커야합니다.")); + } + } + + @Nested + @DisplayName("노선에서 역 삭제 요청 시") + class DeleteStation { + + @Test + @DisplayName("유효한 요청이라면 역을 삭제한다.") + void deleteStation() throws Exception { + mockMvc.perform(delete("/lines/{lineId}", lineTwo.getId()) + .queryParam("stationId", String.valueOf(upward.getId()))) + .andDo(print()) + .andExpect(status().isNoContent()); + } + + @Test + @DisplayName("역 아이디가 존재하지 않으면 400 상태를 반환한다.") + void deleteStationWithoutStationId() throws Exception { + mockMvc.perform(delete("/lines/{lineId}", Long.MAX_VALUE)) + .andDo(print()) + .andExpect(status().isBadRequest()); + } + + @Test + @DisplayName("역 아이디로 변환할 수 없는 타입이면 400 상태를 반환한다.") + void deleteStationWithInvalidStationIDType() throws Exception { + mockMvc.perform(delete("/lines/{lineId}", lineTwo.getId()) + .queryParam("stationId", "s")) + .andDo(print()) + .andExpect(status().isBadRequest()); + } + } +} diff --git a/src/test/java/subway/Integration/StationControllerIntegrationTest.java b/src/test/java/subway/Integration/StationControllerIntegrationTest.java new file mode 100644 index 000000000..1e9959d13 --- /dev/null +++ b/src/test/java/subway/Integration/StationControllerIntegrationTest.java @@ -0,0 +1,88 @@ +package subway.Integration; + +import static org.hamcrest.Matchers.containsString; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import subway.controller.dto.request.StationCreateRequest; +import subway.domain.station.Station; +import subway.repository.StationRepository; + +class StationControllerIntegrationTest extends IntegrationTest { + + @Autowired + private StationRepository stationRepository; + + @Nested + @DisplayName("역 추가 요청시 ") + class CreateStation { + + @Test + @DisplayName("유효한 역 정보라면 새로운 역을 추가한다") + void createStation() throws Exception { + final StationCreateRequest request = new StationCreateRequest("잠실역"); + + mockMvc.perform(post("/stations") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isCreated()) + .andExpect(header().string(HttpHeaders.LOCATION, containsString("/stations/"))); + } + + @Test + @DisplayName("역 이름이 잘못되면 400 상태를 반환한다.") + void createStationWithInvalidName() throws Exception { + final StationCreateRequest request = new StationCreateRequest(" "); + + mockMvc.perform(post("/stations") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("역 이름은 공백일 수 없습니다.")); + } + } + + @Nested + @DisplayName("역 정보 조회 시 ") + class FindStation { + + private Station station; + + @BeforeEach + void setUp() { + station = stationRepository.save(new Station("잠실역")); + } + + @Test + @DisplayName("유효한 ID라면 역 정보를 조회한다.") + void findStation() throws Exception { + mockMvc.perform(get("/stations/{id}", station.getId())) + .andDo(print()) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.id").value(station.getId())) + .andExpect(jsonPath("$.name").value(station.getName())); + } + + @Test + @DisplayName("ID가 유효하지 않다면 400 상태를 반환한다.") + void findStationWithInvalidID() throws Exception { + mockMvc.perform(get("/stations/{id}", "poi")) + .andDo(print()) + .andExpect(status().isBadRequest()); + } + } +} diff --git a/src/test/java/subway/Integration/SubwayControllerIntegrationTest.java b/src/test/java/subway/Integration/SubwayControllerIntegrationTest.java new file mode 100644 index 000000000..89470b726 --- /dev/null +++ b/src/test/java/subway/Integration/SubwayControllerIntegrationTest.java @@ -0,0 +1,122 @@ +package subway.Integration; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +import java.nio.charset.Charset; +import java.util.List; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.http.MediaType; +import org.springframework.test.web.servlet.MvcResult; +import subway.controller.dto.request.PassengerRequest; +import subway.controller.dto.response.ShortestPathResponse; +import subway.domain.line.Line; +import subway.domain.section.PathSection; +import subway.domain.station.Station; +import subway.repository.LineRepository; +import subway.repository.StationRepository; + +class SubwayControllerIntegrationTest extends IntegrationTest { + + @Autowired + private LineRepository lineRepository; + + @Autowired + private StationRepository stationRepository; + + @Nested + @DisplayName("findShortestPath 메서드는 ") + class FindShortestPath { + + @Test + @DisplayName("유효한 요청이라면 최단 경로 정보를 반환한다.") + void findShortestPath() throws Exception { + final Line line = lineRepository.save(new Line(1L, "1호선", "빨간색", 1000)); + final Station upward = stationRepository.save(new Station("강남역")); + final Station downward = stationRepository.save(new Station("양재역")); + line.addSection(upward, downward, 10); + lineRepository.update(line); + + final PassengerRequest request = new PassengerRequest(10, upward.getId(), downward.getId()); + final ShortestPathResponse response = ShortestPathResponse.of( + List.of( + new PathSection(line.getId(), upward, downward, 10, 1000) + ), + 10, + 950 + ); + + final MvcResult mvcResult = mockMvc.perform(get("/subways/shortest-path") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isOk()) + .andReturn(); + + final String jsonResponse = mvcResult.getResponse().getContentAsString(Charset.forName("UTF-8")); + final ShortestPathResponse result = objectMapper.readValue(jsonResponse, ShortestPathResponse.class); + assertThat(result).usingRecursiveComparison().isEqualTo(response); + } + + @Test + @DisplayName("출발역 ID가 존재하지 않으면 400 상태를 반환한다.") + void findShortestPathWithInvalidStartStation() throws Exception { + final PassengerRequest request = new PassengerRequest(10, null, 2L); + + mockMvc.perform(get("/subways/shortest-path") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("출발역 ID는 존재해야 합니다.")); + } + + @Test + @DisplayName("도착역 ID가 존재하지 않으면 400 상태를 반환한다.") + void findShortestPathWithInvalidEndStation() throws Exception { + final PassengerRequest request = new PassengerRequest(10, 1L, null); + + mockMvc.perform(get("/subways/shortest-path") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("도착역 ID는 존재해야 합니다.")); + } + + @Test + @DisplayName("탑승자 나이가 존재하지 않으면 400 상태를 반환한다.") + void findShortestPathWithEmptyAge() throws Exception { + final PassengerRequest request = new PassengerRequest(null, 1L, 2L); + + mockMvc.perform(get("/subways/shortest-path") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("탑승자 나이는 입력해야 합니다.")); + } + + @ParameterizedTest + @ValueSource(ints = {Integer.MIN_VALUE, -1, 0}) + @DisplayName("탑승자 나이가 0보다 작거나 같으면 400 상태를 반환한다.") + void findShortestPathWithInvalidAge(final int age) throws Exception { + final PassengerRequest request = new PassengerRequest(age, 1L, 2L); + + mockMvc.perform(get("/subways/shortest-path") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("탑승자 나이는 0보다 커야합니다.")); + } + } +} diff --git a/src/test/java/subway/SubwayApplicationTests.java b/src/test/java/subway/SubwayApplicationTests.java index cdf84476f..92a2eb3b8 100644 --- a/src/test/java/subway/SubwayApplicationTests.java +++ b/src/test/java/subway/SubwayApplicationTests.java @@ -6,8 +6,8 @@ @SpringBootTest class SubwayApplicationTests { - @Test - void contextLoads() { - } + @Test + void contextLoads() { + } } diff --git a/src/test/java/subway/controller/LineControllerTest.java b/src/test/java/subway/controller/LineControllerTest.java new file mode 100644 index 000000000..88ae898d7 --- /dev/null +++ b/src/test/java/subway/controller/LineControllerTest.java @@ -0,0 +1,287 @@ +package subway.controller; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willDoNothing; +import static org.mockito.Mockito.mock; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +import com.fasterxml.jackson.databind.ObjectMapper; +import java.nio.charset.Charset; +import java.util.List; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.web.servlet.WebMvcTest; +import org.springframework.boot.test.mock.mockito.MockBean; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import subway.controller.dto.request.LineCreateRequest; +import subway.controller.dto.request.SectionCreateRequest; +import subway.controller.dto.response.LineResponse; +import subway.controller.dto.response.LinesResponse; +import subway.controller.dto.response.StationResponse; +import subway.service.LineService; + +@WebMvcTest(LineController.class) +class LineControllerTest { + + @Autowired + private MockMvc mockMvc; + + @Autowired + private ObjectMapper objectMapper; + + @MockBean + private LineService lineService; + + @Test + @DisplayName("노선 목록을 조회한다.") + void findLines() throws Exception { + final List stationsOfLineTwo = List.of( + new StationResponse(1L, "잠실역"), + new StationResponse(2L, "잠실새내역") + ); + final List stationsOfLineFour = List.of( + new StationResponse(3L, "이수역"), + new StationResponse(4L, "서울역") + ); + final List lines = List.of( + new LineResponse(1L, "2호선", "초록색", 500, stationsOfLineTwo), + new LineResponse(2L, "4호선", "하늘색", 1000, stationsOfLineFour)); + final LinesResponse response = new LinesResponse(lines); + + given(lineService.findLines()).willReturn(response); + + final MvcResult mvcResult = mockMvc.perform(get("/lines")) + .andDo(print()) + .andExpect(status().isOk()) + .andReturn(); + + final String jsonResponse = mvcResult.getResponse().getContentAsString(Charset.forName("UTF-8")); + final LinesResponse result = objectMapper.readValue(jsonResponse, LinesResponse.class); + assertThat(result).usingRecursiveComparison().isEqualTo(response); + } + + @Nested + @DisplayName("노선 생성 요청시 ") + class CreateLine { + + @Test + @DisplayName("유효한 노선 정보라면 새로운 노선을 추가한다.") + void createLine() throws Exception { + final LineCreateRequest request = new LineCreateRequest("2호선", "초록색", 500); + + given(lineService.createLine(any(LineCreateRequest.class))).willReturn(1L); + + mockMvc.perform(post("/lines") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isCreated()) + .andExpect(header().string(HttpHeaders.LOCATION, containsString("/lines/1"))); + } + + @Test + @DisplayName("이름이 존재하지 않으면 400 상태를 반환한다.") + void createLineWithInvalidName() throws Exception { + final LineCreateRequest request = new LineCreateRequest(" ", "초록색", 500); + + mockMvc.perform(post("/lines") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("노선 이름은 공백일 수 없습니다.")); + } + + @Test + @DisplayName("색이 존재하지 않으면 400 상태를 반환한다.") + void createLineWithInvalidColor() throws Exception { + final LineCreateRequest request = new LineCreateRequest("2호선", " ", 500); + + mockMvc.perform(post("/lines") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("노선 색깔은 공백일 수 없습니다.")); + } + + @Test + @DisplayName("추가 요금이 존재하지 않으면 400 상태를 반환한다.") + void createLineWithNotExistFare() throws Exception { + final LineCreateRequest request = new LineCreateRequest("2호선", "초록색", null); + + mockMvc.perform(post("/lines") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("노선 추가 요금은 존재해야 합니다.")); + } + + @Test + @DisplayName("추가 요금이 0보다 작으면 400 상태를 반환한다.") + void createLineWithNegativeFare() throws Exception { + final LineCreateRequest request = new LineCreateRequest("2호선", "초록색", -500); + + mockMvc.perform(post("/lines") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("노선 추가 요금은 0원 이상 가능합니다.")); + } + } + + @Nested + @DisplayName("노선 조회 시 ") + class FindLine { + + @Test + @DisplayName("존재하는 노선이라면 노선 정보를 조회한다.") + void findLine() throws Exception { + final List stations = List.of( + new StationResponse(1L, "잠실역"), + new StationResponse(2L, "잠실새내역") + ); + final LineResponse response = new LineResponse(1L, "2호선", "초록색", 500, stations); + + given(lineService.findLineById(1L)).willReturn(response); + + final MvcResult mvcResult = mockMvc.perform(get("/lines/{id}", 1L)) + .andDo(print()) + .andExpect(status().isOk()) + .andReturn(); + + final String jsonResponse = mvcResult.getResponse().getContentAsString(Charset.forName("UTF-8")); + final LineResponse result = objectMapper.readValue(jsonResponse, LineResponse.class); + assertThat(result).usingRecursiveComparison().isEqualTo(response); + } + + @Test + @DisplayName("ID로 변환할 수 없는 타입이라면 400 상태를 반환한다.") + void findLineWithInvalidIDType() throws Exception { + mockMvc.perform(get("/lines/{id}", "l")) + .andDo(print()) + .andExpect(status().isBadRequest()); + } + } + + @Nested + @DisplayName("노선에 역을 등록할 시 ") + class CreateSection { + + @Test + @DisplayName("유효한 정보가 입력되면 노선에 역을 등록한다.") + void createSection() throws Exception { + final SectionCreateRequest request = new SectionCreateRequest(1L, 2L, 10); + + willDoNothing().given(lineService).createSection(1L, mock(SectionCreateRequest.class)); + + mockMvc.perform(post("/lines/{id}/sections", 1L) + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isCreated()) + .andExpect(header().string(HttpHeaders.LOCATION, containsString("/lines/1"))); + } + + @Test + @DisplayName("상행 역 ID가 입력되지 않으면 400 상태를 반환한다.") + void createSectionWithoutUpwardStationId() throws Exception { + final SectionCreateRequest request = new SectionCreateRequest(null, 2L, 10); + + mockMvc.perform(post("/lines/{id}/sections", 1L) + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("상행 역 ID는 존재해야 합니다.")); + } + + @Test + @DisplayName("하행 역 ID가 입력되지 않으면 400 상태를 반환한다.") + void createSectionWithoutDownwardStationId() throws Exception { + final SectionCreateRequest request = new SectionCreateRequest(1L, null, 10); + + mockMvc.perform(post("/lines/{id}/sections", 1L) + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("하행 역 ID는 존재해야 합니다.")); + } + + @Test + @DisplayName("역 간의 거리가 입력되지 않으면 400 상태를 반환한다.") + void createSectionWithoutDistance() throws Exception { + final SectionCreateRequest request = new SectionCreateRequest(1L, 2L, null); + + mockMvc.perform(post("/lines/{id}/sections", 1L) + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("역 간의 거리는 존재해야 합니다.")); + } + + @Test + @DisplayName("역 간의 거리가 0이하이면 400 상태를 반환한다.") + void createSectionWithNegativeDistance() throws Exception { + final SectionCreateRequest request = new SectionCreateRequest(1L, 2L, -1); + + mockMvc.perform(post("/lines/{id}/sections", 1L) + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("역 간의 거리는 0보다 커야합니다.")); + } + } + + @Nested + @DisplayName("노선에서 역 삭제 요청 시") + class DeleteStation { + + @Test + @DisplayName("유효한 요청이라면 역을 삭제한다.") + void deleteStation() throws Exception { + willDoNothing().given(lineService).deleteStation(any(Long.class), any(Long.class)); + + mockMvc.perform(delete("/lines/{lineId}", 1L) + .queryParam("stationId", String.valueOf(1L))) + .andDo(print()) + .andExpect(status().isNoContent()); + } + + @Test + @DisplayName("역 아이디가 존재하지 않으면 400 상태를 반환한다.") + void deleteStationWithoutStationId() throws Exception { + mockMvc.perform(delete("/lines/{lineId}", 1L)) + .andDo(print()) + .andExpect(status().isBadRequest()); + } + + @Test + @DisplayName("역 아이디로 변환할 수 없는 타입이면 400 상태를 반환한다.") + void deleteStationWithInvalidStationIDType() throws Exception { + mockMvc.perform(delete("/lines/{lineId}", 1L) + .queryParam("stationId", "s")) + .andDo(print()) + .andExpect(status().isBadRequest()); + } + } +} diff --git a/src/test/java/subway/controller/StationControllerTest.java b/src/test/java/subway/controller/StationControllerTest.java new file mode 100644 index 000000000..d3f26e045 --- /dev/null +++ b/src/test/java/subway/controller/StationControllerTest.java @@ -0,0 +1,99 @@ +package subway.controller; + +import static org.hamcrest.Matchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.web.servlet.WebMvcTest; +import org.springframework.boot.test.mock.mockito.MockBean; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.test.web.servlet.MockMvc; +import subway.controller.dto.request.StationCreateRequest; +import subway.controller.dto.response.StationResponse; +import subway.service.StationService; + +@WebMvcTest(StationController.class) +class StationControllerTest { + + @Autowired + private MockMvc mockMvc; + + @Autowired + private ObjectMapper objectMapper; + + @MockBean + private StationService stationService; + + @Nested + @DisplayName("역 추가 요청시 ") + class CreateStation { + + @Test + @DisplayName("유효한 역 정보라면 새로운 역을 추가한다") + void createStation() throws Exception { + final StationCreateRequest request = new StationCreateRequest("잠실역"); + + given(stationService.createStation(any(StationCreateRequest.class))).willReturn(1L); + + mockMvc.perform(post("/stations") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isCreated()) + .andExpect(header().string(HttpHeaders.LOCATION, containsString("/stations/"))); + } + + @Test + @DisplayName("역 이름이 잘못되면 400 상태를 반환한다.") + void createStationWithInvalidName() throws Exception { + final StationCreateRequest request = new StationCreateRequest(" "); + + mockMvc.perform(post("/stations") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("역 이름은 공백일 수 없습니다.")); + } + } + + @Nested + @DisplayName("역 정보 조회 시 ") + class FindStation { + + @Test + @DisplayName("유효한 ID라면 역 정보를 조회한다.") + void findStation() throws Exception { + final StationResponse response = new StationResponse(1L, "잠실역"); + + given(stationService.findStationById(1L)).willReturn(response); + + mockMvc.perform(get("/stations/{id}", 1L)) + .andDo(print()) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.id").value(1L)) + .andExpect(jsonPath("$.name").value("잠실역")); + } + + @Test + @DisplayName("ID가 유효하지 않다면 400 상태를 반환한다.") + void findStationWithInvalidID() throws Exception { + mockMvc.perform(get("/stations/{id}", "poi")) + .andDo(print()) + .andExpect(status().isBadRequest()); + } + } +} diff --git a/src/test/java/subway/controller/SubwayControllerTest.java b/src/test/java/subway/controller/SubwayControllerTest.java new file mode 100644 index 000000000..fe25413d6 --- /dev/null +++ b/src/test/java/subway/controller/SubwayControllerTest.java @@ -0,0 +1,127 @@ +package subway.controller; + +import static fixtures.StationFixtures.GANGNAM; +import static fixtures.StationFixtures.YANGJAE; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +import com.fasterxml.jackson.databind.ObjectMapper; +import java.nio.charset.Charset; +import java.util.List; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.web.servlet.WebMvcTest; +import org.springframework.boot.test.mock.mockito.MockBean; +import org.springframework.http.MediaType; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import subway.controller.dto.request.PassengerRequest; +import subway.controller.dto.response.ShortestPathResponse; +import subway.domain.section.PathSection; +import subway.service.SubwayService; + +@WebMvcTest(SubwayController.class) +class SubwayControllerTest { + + @Autowired + private MockMvc mockMvc; + + @Autowired + private ObjectMapper objectMapper; + + @MockBean + private SubwayService subwayService; + + @Nested + @DisplayName("findShortestPath 메서드는 ") + class FindShortestPath { + + @Test + @DisplayName("유효한 요청이라면 최단 경로 정보를 반환한다.") + void findShortestPath() throws Exception { + final PassengerRequest request = new PassengerRequest(10, 1L, 2L); + final ShortestPathResponse response = ShortestPathResponse.of( + List.of( + new PathSection(1L, GANGNAM, YANGJAE, 10, 1000) + ), + 10, + 100 + ); + + given(subwayService.findShortestPath(any(PassengerRequest.class))).willReturn(response); + + final MvcResult mvcResult = mockMvc.perform(get("/subways/shortest-path") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isOk()) + .andReturn(); + + final String jsonResponse = mvcResult.getResponse().getContentAsString(Charset.forName("UTF-8")); + final ShortestPathResponse result = objectMapper.readValue(jsonResponse, ShortestPathResponse.class); + assertThat(result).usingRecursiveComparison().isEqualTo(response); + } + + @Test + @DisplayName("출발역 ID가 존재하지 않으면 400 상태를 반환한다.") + void findShortestPathWithInvalidStartStation() throws Exception { + final PassengerRequest request = new PassengerRequest(10, null, 2L); + + mockMvc.perform(get("/subways/shortest-path") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("출발역 ID는 존재해야 합니다.")); + } + + @Test + @DisplayName("도착역 ID가 존재하지 않으면 400 상태를 반환한다.") + void findShortestPathWithInvalidEndStation() throws Exception { + final PassengerRequest request = new PassengerRequest(10, 1L, null); + + mockMvc.perform(get("/subways/shortest-path") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("도착역 ID는 존재해야 합니다.")); + } + + @Test + @DisplayName("탑승자 나이가 존재하지 않으면 400 상태를 반환한다.") + void findShortestPathWithEmptyAge() throws Exception { + final PassengerRequest request = new PassengerRequest(null, 1L, 2L); + + mockMvc.perform(get("/subways/shortest-path") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("탑승자 나이는 입력해야 합니다.")); + } + + @ParameterizedTest + @ValueSource(ints = {Integer.MIN_VALUE, -1, 0}) + @DisplayName("탑승자 나이가 0보다 작거나 같으면 400 상태를 반환한다.") + void findShortestPathWithInvalidAge(final int age) throws Exception { + final PassengerRequest request = new PassengerRequest(age, 1L, 2L); + + mockMvc.perform(get("/subways/shortest-path") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(request))) + .andDo(print()) + .andExpect(status().isBadRequest()) + .andExpect(content().string("탑승자 나이는 0보다 커야합니다.")); + } + } +} diff --git a/src/test/java/subway/dao/LineDaoTest.java b/src/test/java/subway/dao/LineDaoTest.java new file mode 100644 index 000000000..bc66d1591 --- /dev/null +++ b/src/test/java/subway/dao/LineDaoTest.java @@ -0,0 +1,91 @@ +package subway.dao; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertAll; + +import java.util.List; +import java.util.Optional; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.jdbc.JdbcTest; +import org.springframework.jdbc.core.JdbcTemplate; +import subway.entity.LineEntity; + +@JdbcTest +class LineDaoTest { + + @Autowired + private JdbcTemplate jdbcTemplate; + + private LineDao lineDao; + + @BeforeEach + void setUp() { + lineDao = new LineDao(jdbcTemplate); + } + + @Test + @DisplayName("모든 노선을 조회한다.") + void findAll() { + final LineEntity lineTwo = lineDao.save(new LineEntity("2호선", "초록색", 500)); + final LineEntity lineFour = lineDao.save(new LineEntity("4호선", "하늘색", 1000)); + + final List lines = lineDao.findAll(); + + assertThat(lines).usingRecursiveComparison().isEqualTo(List.of(lineTwo, lineFour)); + } + + @Nested + @DisplayName("아이디로 조회시 ") + class FindById { + + @Test + @DisplayName("존재하는 ID라면 노선 정보를 반환한다.") + void findById() { + final LineEntity lineEntity = lineDao.save(new LineEntity("2호선", "초록색", 500)); + + final Optional line = lineDao.findById(lineEntity.getId()); + + assertThat(line).usingRecursiveComparison().isEqualTo(Optional.of(lineEntity)); + } + + @Test + @DisplayName("존재하지 않는 ID라면 빈 값을 반환한다.") + void findByWithInvalidId() { + final Optional line = lineDao.findById(-3L); + + assertThat(line).isEmpty(); + } + } + + @Nested + @DisplayName("노선 정보 업데이트시 ") + class Update { + + @Test + @DisplayName("존재하는 노선이라면 정보를 업데이트한다.") + void update() { + final LineEntity line = lineDao.save(new LineEntity("2호선", "초록색", 500)); + + final LineEntity updateLine = new LineEntity(line.getId(), "4호선", "하늘색", 1000); + final int numberOfUpdatedRow = lineDao.update(updateLine); + + final LineEntity updatedLine = lineDao.findById(line.getId()).get(); + assertAll( + () -> assertThat(numberOfUpdatedRow).isEqualTo(1), + () -> assertThat(updatedLine).usingRecursiveComparison().isEqualTo(updatedLine) + ); + } + + @Test + @DisplayName("존재하지 않는 노선이라면 0을 반환한다.") + void updateWithNotExistLine() { + final int numberOfUpdatedRow = lineDao.update(new LineEntity(1L, "4호선", "하늘색", 1000)); + + assertThat(numberOfUpdatedRow).isEqualTo(0); + } + } +} diff --git a/src/test/java/subway/dao/SectionDaoTest.java b/src/test/java/subway/dao/SectionDaoTest.java new file mode 100644 index 000000000..6a9f0161d --- /dev/null +++ b/src/test/java/subway/dao/SectionDaoTest.java @@ -0,0 +1,79 @@ +package subway.dao; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.jdbc.JdbcTest; +import org.springframework.jdbc.core.JdbcTemplate; +import subway.entity.LineEntity; +import subway.entity.SectionEntity; +import subway.entity.StationEntity; + +@JdbcTest +class SectionDaoTest { + + @Autowired + private JdbcTemplate jdbcTemplate; + + private SectionDao sectionDao; + private StationDao stationDao; + private LineDao lineDao; + + @BeforeEach + void setUp() { + sectionDao = new SectionDao(jdbcTemplate); + stationDao = new StationDao(jdbcTemplate); + lineDao = new LineDao(jdbcTemplate); + } + + @Test + @DisplayName("해당 노선의 모든 구간 정보를 조회한다.") + void findAllByLineId() { + final LineEntity lineEntity = lineDao.save(new LineEntity("2호선", "초록색", 500)); + final StationEntity upward = stationDao.save(new StationEntity("잠실역")); + final StationEntity downward = stationDao.save(new StationEntity("잠실새내역")); + final SectionEntity entity = new SectionEntity(lineEntity.getId(), upward.getId(), downward.getId(), 10); + final SectionEntity savedEntity = sectionDao.save(entity); + + final List sections = sectionDao.findAllByLineId(savedEntity.getLineId()); + + assertThat(sections).containsExactly(savedEntity); + } + + @Test + @DisplayName("모든 구간 정보를 저장한다.") + void saveAll() { + final LineEntity lineEntity = lineDao.save(new LineEntity("2호선", "초록색", 500)); + final StationEntity upward = stationDao.save(new StationEntity("잠실역")); + final StationEntity middle = stationDao.save(new StationEntity("잠실새내역")); + final StationEntity downward = stationDao.save(new StationEntity("종합운동장역")); + final List sections = List.of( + new SectionEntity(lineEntity.getId(), upward.getId(), middle.getId(), 10), + new SectionEntity(lineEntity.getId(), middle.getId(), downward.getId(), 10) + ); + + sectionDao.saveAll(sections); + + final List result = sectionDao.findAllByLineId(lineEntity.getId()); + assertThat(sections).usingRecursiveComparison().ignoringActualNullFields().isEqualTo(result); + } + + @Test + @DisplayName("노선의 구간 정보를 삭제한다.") + void deleteAllByLineId() { + final LineEntity lineEntity = lineDao.save(new LineEntity("2호선", "초록색", 500)); + final StationEntity upward = stationDao.save(new StationEntity("잠실역")); + final StationEntity downward = stationDao.save(new StationEntity("잠실새내역")); + final SectionEntity entity = new SectionEntity(lineEntity.getId(), upward.getId(), downward.getId(), 10); + final SectionEntity savedEntity = sectionDao.save(entity); + + sectionDao.deleteAllByLineId(savedEntity.getLineId()); + + final List sectionEntities = sectionDao.findAllByLineId(savedEntity.getLineId()); + assertThat(sectionEntities).isEmpty(); + } +} diff --git a/src/test/java/subway/dao/StationDaoTest.java b/src/test/java/subway/dao/StationDaoTest.java new file mode 100644 index 000000000..21dde1d73 --- /dev/null +++ b/src/test/java/subway/dao/StationDaoTest.java @@ -0,0 +1,50 @@ +package subway.dao; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Optional; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.jdbc.JdbcTest; +import org.springframework.jdbc.core.JdbcTemplate; +import subway.entity.StationEntity; + +@JdbcTest +class StationDaoTest { + + @Autowired + private JdbcTemplate jdbcTemplate; + + private StationDao stationDao; + + @BeforeEach + void setUp() { + stationDao = new StationDao(jdbcTemplate); + } + + @Nested + @DisplayName("아이디로 조회시 ") + class FindById { + + @Test + @DisplayName("존재하는 ID라면 역 정보를 반환한다.") + void findById() { + final StationEntity stationEntity = stationDao.save(new StationEntity("잠실역")); + + final Optional station = stationDao.findById(stationEntity.getId()); + + assertThat(station).usingRecursiveComparison().isEqualTo(Optional.of(stationEntity)); + } + + @Test + @DisplayName("존재하지 않는 ID라면 빈 값을 반환한다.") + void findByWithInvalidId() { + final Optional station = stationDao.findById(-3L); + + assertThat(station).isEmpty(); + } + } +} diff --git a/src/test/java/subway/domain/fare/AgeFareStrategyTest.java b/src/test/java/subway/domain/fare/AgeFareStrategyTest.java new file mode 100644 index 000000000..a3fa504b2 --- /dev/null +++ b/src/test/java/subway/domain/fare/AgeFareStrategyTest.java @@ -0,0 +1,83 @@ +package subway.domain.fare; + +import static fixtures.StationFixtures.GANGNAM; +import static fixtures.StationFixtures.YANGJAE; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import subway.domain.line.Line; +import subway.domain.subway.Passenger; +import subway.domain.subway.Subway; +import subway.domain.subway.SubwayJgraphtGraph; + +class AgeFareStrategyTest { + + private static final AgeFareStrategy ageFareStrategy = new AgeFareStrategy(); + + @Test + @DisplayName("이전 요금이 0이라면 요금은 없다.") + void zero() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 0); + lineOfTwo.addSection(GANGNAM, YANGJAE, 8); + final Subway subway = new Subway(new SubwayJgraphtGraph(List.of(lineOfTwo))); + final Passenger passenger = new Passenger(5, GANGNAM, YANGJAE); + + final double result = ageFareStrategy.calculateFare(0, passenger, subway); + + assertThat(result).isEqualTo(0); + } + + @Test + @DisplayName("6세 미만 어린이인 경우 할인이 적용되지 않는다.") + void baby() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 0); + lineOfTwo.addSection(GANGNAM, YANGJAE, 8); + final Subway subway = new Subway(new SubwayJgraphtGraph(List.of(lineOfTwo))); + final Passenger passenger = new Passenger(5, GANGNAM, YANGJAE); + + final double result = ageFareStrategy.calculateFare(1350, passenger, subway); + + assertThat(result).isEqualTo(1350); + } + + @Test + @DisplayName("6세 이상, 13세 미만의 어린이인 경우 350원 공제 후, 50% 할인이 적용된다.") + void kid() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 0); + lineOfTwo.addSection(GANGNAM, YANGJAE, 8); + final Subway subway = new Subway(new SubwayJgraphtGraph(List.of(lineOfTwo))); + final Passenger passenger = new Passenger(8, GANGNAM, YANGJAE); + + final double result = ageFareStrategy.calculateFare(1350, passenger, subway); + + assertThat(result).isEqualTo(500); + } + + @Test + @DisplayName("13세 이상, 19세 미만 청소년인 경우 350원 공제 후, 20% 할인이 적용된다.") + void teen() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 0); + lineOfTwo.addSection(GANGNAM, YANGJAE, 8); + final Subway subway = new Subway(new SubwayJgraphtGraph(List.of(lineOfTwo))); + final Passenger passenger = new Passenger(15, GANGNAM, YANGJAE); + + final double result = ageFareStrategy.calculateFare(1350, passenger, subway); + + assertThat(result).isEqualTo(800); + } + + @Test + @DisplayName("19세 이상 어른인 경우 할인이 적용되지 않는다.") + void adult() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 0); + lineOfTwo.addSection(GANGNAM, YANGJAE, 8); + final Subway subway = new Subway(new SubwayJgraphtGraph(List.of(lineOfTwo))); + final Passenger passenger = new Passenger(19, GANGNAM, YANGJAE); + + final double result = ageFareStrategy.calculateFare(1350, passenger, subway); + + assertThat(result).isEqualTo(1350); + } +} diff --git a/src/test/java/subway/domain/fare/AgePolicyTest.java b/src/test/java/subway/domain/fare/AgePolicyTest.java new file mode 100644 index 000000000..53b4dc71a --- /dev/null +++ b/src/test/java/subway/domain/fare/AgePolicyTest.java @@ -0,0 +1,66 @@ +package subway.domain.fare; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +class AgePolicyTest { + + @ParameterizedTest + @CsvSource(value = {"5,BABY", "8,KID", "15,TEEN", "19,ADULT"}) + @DisplayName("나이에 맞는 정책을 찾는다.") + void search(final int age, final AgePolicy expected) { + final AgePolicy result = AgePolicy.search(age); + + assertThat(result).isEqualTo(expected); + } + + @Nested + @DisplayName("할인 금액을 계산할 때 ") + class CalculateDiscountFare { + + @Test + @DisplayName("BABY는 할인되지 않는다.") + void baby() { + final AgePolicy agePolicy = AgePolicy.BABY; + + final double result = agePolicy.calculateDiscountFare(1000); + + assertThat(result).isEqualTo(1000); + } + + @Test + @DisplayName("KID는 350원 공제 후, 50% 할인된 금액을 반환한다.") + void kid() { + final AgePolicy agePolicy = AgePolicy.KID; + + final double result = agePolicy.calculateDiscountFare(1350); + + assertThat(result).isEqualTo(500); + } + + @Test + @DisplayName("TEEN은 350원 공제 후, 20% 할인된 금액을 반환한다.") + void teen() { + final AgePolicy agePolicy = AgePolicy.TEEN; + + final double result = agePolicy.calculateDiscountFare(1350); + + assertThat(result).isEqualTo(800); + } + + @Test + @DisplayName("ADULT는 할인되지 않는다.") + void adult() { + final AgePolicy agePolicy = AgePolicy.ADULT; + + final double result = agePolicy.calculateDiscountFare(1000); + + assertThat(result).isEqualTo(1000); + } + } +} diff --git a/src/test/java/subway/domain/fare/DistanceFareStrategyTest.java b/src/test/java/subway/domain/fare/DistanceFareStrategyTest.java new file mode 100644 index 000000000..afa81a265 --- /dev/null +++ b/src/test/java/subway/domain/fare/DistanceFareStrategyTest.java @@ -0,0 +1,94 @@ +package subway.domain.fare; + +import static fixtures.StationFixtures.GANGNAM; +import static fixtures.StationFixtures.YANGJAE; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import subway.domain.line.Line; +import subway.domain.subway.Passenger; +import subway.domain.subway.Subway; +import subway.domain.subway.SubwayJgraphtGraph; + +class DistanceFareStrategyTest { + + private static final DistanceFareStrategy distanceFareStrategy = new DistanceFareStrategy(); + + @Test + @DisplayName("거리가 10km 이내라면 요금은 기본운임 요금 1250원이다.") + void calculateFare() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 500); + lineOfTwo.addSection(GANGNAM, YANGJAE, 9); + final Subway subway = new Subway(new SubwayJgraphtGraph(List.of(lineOfTwo))); + final Passenger passenger = new Passenger(26, GANGNAM, YANGJAE); + + final double result = distanceFareStrategy.calculateFare(0, passenger, subway); + + assertThat(result).isEqualTo(1250); + } + + @Nested + @DisplayName("거리가 10km~50km일 때 5km마다 100원을 추가한다면 ") + class BaseTenDistance { + + @Test + @DisplayName("거리가 12km라면 요금은 1350원이다.") + void distanceTwelve() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 500); + lineOfTwo.addSection(GANGNAM, YANGJAE, 12); + final Subway subway = new Subway(new SubwayJgraphtGraph(List.of(lineOfTwo))); + final Passenger passenger = new Passenger(26, GANGNAM, YANGJAE); + + final double result = distanceFareStrategy.calculateFare(0, passenger, subway); + + assertThat(result).isEqualTo(1350); + } + + @Test + @DisplayName("거리가 16km라면 요금은 1450원이다.") + void distanceSixTeen() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 500); + lineOfTwo.addSection(GANGNAM, YANGJAE, 16); + final Subway subway = new Subway(new SubwayJgraphtGraph(List.of(lineOfTwo))); + final Passenger passenger = new Passenger(26, GANGNAM, YANGJAE); + + final double result = distanceFareStrategy.calculateFare(0, passenger, subway); + + assertThat(result).isEqualTo(1450); + } + } + + @Nested + @DisplayName("거리가 50km 초과할 때 8km마다 100원을 추가한다면") + class BaseFiftyDistance { + + @Test + @DisplayName("거리가 58km라면 요금은 2150원이다.") + void distanceFiftyEight() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 500); + lineOfTwo.addSection(GANGNAM, YANGJAE, 58); + final Subway subway = new Subway(new SubwayJgraphtGraph(List.of(lineOfTwo))); + final Passenger passenger = new Passenger(26, GANGNAM, YANGJAE); + + final double result = distanceFareStrategy.calculateFare(0, passenger, subway); + + assertThat(result).isEqualTo(2150); + } + + @Test + @DisplayName("거리가 66km라면 요금은 2250원이다.") + void distanceSixtySix() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 500); + lineOfTwo.addSection(GANGNAM, YANGJAE, 66); + final Subway subway = new Subway(new SubwayJgraphtGraph(List.of(lineOfTwo))); + final Passenger passenger = new Passenger(26, GANGNAM, YANGJAE); + + final double result = distanceFareStrategy.calculateFare(0, passenger, subway); + + assertThat(result).isEqualTo(2250); + } + } +} diff --git a/src/test/java/subway/domain/fare/DistancePolicyTest.java b/src/test/java/subway/domain/fare/DistancePolicyTest.java new file mode 100644 index 000000000..45e1c262a --- /dev/null +++ b/src/test/java/subway/domain/fare/DistancePolicyTest.java @@ -0,0 +1,32 @@ +package subway.domain.fare; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +class DistancePolicyTest { + + @ParameterizedTest + @CsvSource(value = {"12,100", "16,200", "50,800", "58,800"}) + @DisplayName("10km~50km는 5km마다 100원이 추가된다.") + void baseTen(final int distance, final long additionFare) { + final DistancePolicy distancePolicy = DistancePolicy.BASE_TEN; + + final long result = distancePolicy.calculateAdditionFare(distance); + + assertThat(result).isEqualTo(additionFare); + } + + @ParameterizedTest + @CsvSource(value = {"50,0", "55,100", "58,100", "66,200"}) + @DisplayName("50km부터는 8km마다 100원이 추가된다.") + void baseFifty(final int distance, final long additionFare) { + final DistancePolicy distancePolicy = DistancePolicy.BASE_FIFTY; + + final long result = distancePolicy.calculateAdditionFare(distance); + + assertThat(result).isEqualTo(additionFare); + } +} diff --git a/src/test/java/subway/domain/fare/FareStrategyCompositeTest.java b/src/test/java/subway/domain/fare/FareStrategyCompositeTest.java new file mode 100644 index 000000000..00c0133ca --- /dev/null +++ b/src/test/java/subway/domain/fare/FareStrategyCompositeTest.java @@ -0,0 +1,30 @@ +package subway.domain.fare; + +import static fixtures.StationFixtures.GANGNAM; +import static fixtures.StationFixtures.YANGJAE; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import subway.domain.line.Line; +import subway.domain.subway.Passenger; +import subway.domain.subway.Subway; +import subway.domain.subway.SubwayJgraphtGraph; + +class FareStrategyCompositeTest { + + @Test + @DisplayName("지정된 요금 정책을 수행 후 요금을 확인한다.") + void calculateFare() { + final FareStrategyComposite fareStrategyComposite = new FareStrategyComposite(); + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 1000); + lineOfTwo.addSection(GANGNAM, YANGJAE, 12); + final Subway subway = new Subway(new SubwayJgraphtGraph(List.of(lineOfTwo))); + final Passenger passenger = new Passenger(15, GANGNAM, YANGJAE); + + final double result = fareStrategyComposite.calculateFare(0, passenger, subway); + + assertThat(result).isEqualTo(1600); + } +} diff --git a/src/test/java/subway/domain/fare/RouteFareStrategyTest.java b/src/test/java/subway/domain/fare/RouteFareStrategyTest.java new file mode 100644 index 000000000..426202d07 --- /dev/null +++ b/src/test/java/subway/domain/fare/RouteFareStrategyTest.java @@ -0,0 +1,50 @@ +package subway.domain.fare; + +import static fixtures.StationFixtures.GANGNAM; +import static fixtures.StationFixtures.GYODAE; +import static fixtures.StationFixtures.NAMBU; +import static fixtures.StationFixtures.YANGJAE; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import subway.domain.line.Line; +import subway.domain.subway.Passenger; +import subway.domain.subway.Subway; +import subway.domain.subway.SubwayJgraphtGraph; + +class RouteFareStrategyTest { + + private static final RouteFareStrategy routeFareStrategy = new RouteFareStrategy(); + + @Test + @DisplayName("900원 추가 요금이 있는 노선을 이용하면 기존 요금에 900원이 추가된다.") + void addition900() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 900); + lineOfTwo.addSection(GANGNAM, YANGJAE, 8); + final Subway subway = new Subway(new SubwayJgraphtGraph(List.of(lineOfTwo))); + final Passenger passenger = new Passenger(26, GANGNAM, YANGJAE); + + final double result = routeFareStrategy.calculateFare(1250, passenger, subway); + + assertThat(result).isEqualTo(2150); + } + + @Test + @DisplayName("0, 500, 900원의 추가 요금이 있는 노선을 경유하면 900원이 추가된다.") + void additionHighest900() { + final Line lineOfOne = new Line(1L, "1호선", "빨간색", 0); + final Line lineOfTwo = new Line(2L, "2호선", "노란색", 500); + final Line lineOfThree = new Line(3L, "3호선", "파란색", 900); + lineOfOne.addSection(GANGNAM, YANGJAE, 3); + lineOfTwo.addSection(YANGJAE, GYODAE, 3); + lineOfThree.addSection(GYODAE, NAMBU, 2); + final Subway subway = new Subway(new SubwayJgraphtGraph(List.of(lineOfOne, lineOfTwo, lineOfThree))); + final Passenger passenger = new Passenger(26, GANGNAM, NAMBU); + + final double result = routeFareStrategy.calculateFare(1250, passenger, subway); + + assertThat(result).isEqualTo(2150); + } +} diff --git a/src/test/java/subway/domain/line/ColorTest.java b/src/test/java/subway/domain/line/ColorTest.java new file mode 100644 index 000000000..5bab2ece7 --- /dev/null +++ b/src/test/java/subway/domain/line/ColorTest.java @@ -0,0 +1,41 @@ +package subway.domain.line; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.NullAndEmptySource; +import org.junit.jupiter.params.provider.ValueSource; +import subway.exception.InvalidColorException; + +class ColorTest { + + @ParameterizedTest + @DisplayName("색 이름이 존재하지 않으면 예외를 던진다.") + @NullAndEmptySource + void validateWithNull(final String input) { + assertThatThrownBy(() -> new Color(input)) + .isInstanceOf(InvalidColorException.class) + .hasMessage("노선 색은 존재해야 합니다."); + } + + @Test + @DisplayName("색 이름이 최대 길이를 넘으면 예외를 던진다.") + void validateWithLength() { + final String input = "열한글자가넘는색이름입니다색"; + + assertThatThrownBy(() -> new Color(input)) + .isInstanceOf(InvalidColorException.class) + .hasMessage("노선 색은 11글자까지 가능합니다."); + } + + @ParameterizedTest + @DisplayName("색 이름 형식에 맞지 않는다면 예외를 던진다.") + @ValueSource(strings = {"색", "하양", "white색", "123색"}) + void validateWithInvalidColorFormat(final String input) { + assertThatThrownBy(() -> new Color(input)) + .isInstanceOf(InvalidColorException.class) + .hasMessage("색 이름은 한글로 이루어져 있고, '색'으로 끝나야 합니다."); + } +} diff --git a/src/test/java/subway/domain/line/FareTest.java b/src/test/java/subway/domain/line/FareTest.java new file mode 100644 index 000000000..dc1486307 --- /dev/null +++ b/src/test/java/subway/domain/line/FareTest.java @@ -0,0 +1,19 @@ +package subway.domain.line; + +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import subway.exception.InvalidFareException; + +class FareTest { + + @ParameterizedTest + @ValueSource(ints = {Integer.MIN_VALUE, -1}) + @DisplayName("추가 요금이 0보다 작으면 예외를 던진다.") + void validate(final int value) { + Assertions.assertThatThrownBy(() -> new Fare(value)) + .isInstanceOf(InvalidFareException.class) + .hasMessage("노선 추가 요금은 0원보다 커야 합니다."); + } +} diff --git a/src/test/java/subway/domain/line/LineTest.java b/src/test/java/subway/domain/line/LineTest.java new file mode 100644 index 000000000..b4e3907b2 --- /dev/null +++ b/src/test/java/subway/domain/line/LineTest.java @@ -0,0 +1,192 @@ +package subway.domain.line; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertAll; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import subway.domain.section.Section; +import subway.domain.station.Station; +import subway.exception.InvalidDistanceException; +import subway.exception.InvalidSectionException; + +class LineTest { + + private Line line; + private Station upward; + private Station downward; + + @BeforeEach + void setUp() { + upward = new Station(1L, "잠실역"); + downward = new Station(2L, "종합운동장역"); + final List
sections = List.of( + new Section(upward, downward, 10), + new Section(downward, Station.TERMINAL, 0) + ); + line = new Line(1L, "2호선", "초록색", 500, new ArrayList<>(sections)); + } + + @Nested + @DisplayName("노선 역 추가 시 ") + class AddSection { + + @Test + @DisplayName("노선에 역을 최초 추가한다.") + void addFirstSection() { + final Station upward = new Station(1L, "잠실역"); + final Station downward = new Station(2L, "종합운동장역"); + final Line line = new Line(1L, "2호선", "초록색", 500); + + line.addSection(upward, downward, 10); + + final List result = line.getStations(); + assertAll( + () -> assertThat(result).containsExactly(upward, downward), + () -> assertThat(line.getSections()).extracting(Section::getDistance).containsExactly(10) + ); + } + + @Test + @DisplayName("중간에 상행역을 추가한다.") + void addUpwardSection() { + final Station additionStation = new Station(3L, "잠실새내역"); + + line.addSection(additionStation, downward, 5); + + final List result = line.getStations(); + assertAll( + () -> assertThat(result).containsExactly(upward, additionStation, downward), + () -> assertThat(line.getSections()).extracting(Section::getDistance).containsExactly(5, 5) + ); + } + + @Test + @DisplayName("중간에 하행역을 추가한다.") + void addDownwardSection() { + final Station additionStation = new Station(3L, "잠실새내역"); + + line.addSection(upward, additionStation, 5); + + final List result = line.getStations(); + assertAll( + () -> assertThat(result).containsExactly(upward, additionStation, downward), + () -> assertThat(line.getSections()).extracting(Section::getDistance).containsExactly(5, 5) + ); + } + + @Test + @DisplayName("맨 앞에 역을 추가한다.") + void addSectionAtFirst() { + final Station additionStation = new Station(3L, "잠실새내역"); + + line.addSection(additionStation, upward, 5); + + final List result = line.getStations(); + assertAll( + () -> assertThat(result).containsExactly(additionStation, upward, downward), + () -> assertThat(line.getSections()).extracting(Section::getDistance).containsExactly(5, 10) + ); + } + + @Test + @DisplayName("맨 뒤에 역을 추가한다.") + void addSectionAtLast() { + final Station additionStation = new Station(3L, "잠실새내역"); + + line.addSection(downward, additionStation, 5); + + final List result = line.getStations(); + assertAll( + () -> assertThat(result).containsExactly(upward, downward, additionStation), + () -> assertThat(line.getSections()).extracting(Section::getDistance).containsExactly(10, 5) + ); + } + + @Test + @DisplayName("역이 둘다 존재한다면 예외를 던진다.") + void addSectionWithExistStations() { + assertThatThrownBy(() -> line.addSection(upward, downward, 5)) + .isInstanceOf(InvalidSectionException.class) + .hasMessage("두 역이 이미 노선에 존재합니다."); + } + + @Test + @DisplayName("역이 둘다 존재하지 않으면 예외를 던진다.") + void addSectionWithoutExistStations() { + final Station newUpward = new Station(3L, "잠실새내역"); + final Station newDownward = new Station(4L, "사당역"); + + assertThatThrownBy(() -> line.addSection(newUpward, newDownward, 5)) + .isInstanceOf(InvalidSectionException.class) + .hasMessage("연결할 역 정보가 없습니다."); + } + + @Test + @DisplayName("추가될 역의 거리가 추가될 위치의 두 역 사이보다 크거나 같으면 예외를 던진다.") + void addSectionWithInvalidRangeDistance() { + final Station additionStation = new Station(3L, "잠실새내역"); + + assertThatThrownBy(() -> line.addSection(upward, additionStation, 10)) + .isInstanceOf(InvalidDistanceException.class) + .hasMessage("추가될 역의 거리는 추가될 위치의 두 역사이의 거리보다 작아야합니다."); + } + } + + @Nested + @DisplayName("노선에서 역 제거할 시 ") + class DeleteStation { + + @Test + @DisplayName("역이 2개일 때 역을 제거한다.") + void deleteStationAtInitialState() { + line.deleteStation(upward); + + final List result = line.getStations(); + assertThat(result).isEmpty(); + } + + @Test + @DisplayName("역이 2개가 아닐 때 맨 앞의 역을 제거한다.") + void deleteStationAtFirst() { + final Station additionStation = new Station(3L, "잠실새내역"); + line.addSection(upward, additionStation, 3); + + line.deleteStation(upward); + + final List result = line.getStations(); + assertAll( + () -> assertThat(result).containsExactly(additionStation, downward), + () -> assertThat(line.getSections()).extracting(Section::getDistance).containsExactly(7) + ); + } + + @Test + @DisplayName("역이 2개가 아닐 때 중간의 역을 제거한다.") + void deleteStationBetweenStations() { + final Station additionStation = new Station(3L, "잠실새내역"); + line.addSection(upward, additionStation, 3); + + line.deleteStation(additionStation); + + final List result = line.getStations(); + assertAll( + () -> assertThat(result).containsExactly(upward, downward), + () -> assertThat(line.getSections()).extracting(Section::getDistance).containsExactly(10) + ); + } + + @Test + @DisplayName("역이 존재하지 않을 때 예외를 던진다.") + void deleteStationWithNotExistStation() { + assertThatThrownBy(() -> line.deleteStation(new Station(3L, "잠실새내역"))) + .isInstanceOf(InvalidSectionException.class) + .hasMessage("노선에 해당 역이 존재하지 않습니다."); + } + } +} diff --git a/src/test/java/subway/domain/line/NameTest.java b/src/test/java/subway/domain/line/NameTest.java new file mode 100644 index 000000000..9b017ab11 --- /dev/null +++ b/src/test/java/subway/domain/line/NameTest.java @@ -0,0 +1,38 @@ +package subway.domain.line; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.NullAndEmptySource; +import org.junit.jupiter.params.provider.ValueSource; +import subway.exception.InvalidLineNameException; + +class NameTest { + + @Test + @DisplayName("노선 이름을 정상적으로 생성한다.") + void name() { + assertDoesNotThrow(() -> new Name("5호선")); + } + + @ParameterizedTest + @NullAndEmptySource + @DisplayName("노선이 이름이 존재하지 않으면 예외를 던진다.") + void validateWithBlank(final String input) { + assertThatThrownBy(() -> new Name(input)) + .isInstanceOf(InvalidLineNameException.class) + .hasMessage("노선 이름은 존재해야 합니다."); + } + + @ParameterizedTest + @DisplayName("노선의 이름 형식이 맞지 않을 경우 예외를 던진다.") + @ValueSource(strings = {"3", "삼호선", "0호선", "A호선", "12호선"}) + void validateWithInvalidNameFormat(final String input) { + assertThatThrownBy(() -> new Name(input)) + .isInstanceOf(InvalidLineNameException.class) + .hasMessage("노선 이름은 1~9호선이어야 합니다."); + } +} diff --git a/src/test/java/subway/domain/section/DistanceTest.java b/src/test/java/subway/domain/section/DistanceTest.java new file mode 100644 index 000000000..07dfeb571 --- /dev/null +++ b/src/test/java/subway/domain/section/DistanceTest.java @@ -0,0 +1,30 @@ +package subway.domain.section; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import subway.exception.InvalidDistanceException; + +class DistanceTest { + + @Test + @DisplayName("거리를 정상적으로 생성한다.") + void distance() { + assertDoesNotThrow(() -> new Distance(10)); + } + + @ParameterizedTest + @DisplayName("거리가 0보다 작을 경우 예외를 던진다.") + @ValueSource(ints = {-1, -3}) + void validateWithInvalidRange(final int input) { + assertThatThrownBy(() -> new Distance(input)) + .isInstanceOf(InvalidDistanceException.class) + .hasMessage("역 사이의 거리는 0이상이어야합니다."); + } +} + + diff --git a/src/test/java/subway/domain/section/SectionsTest.java b/src/test/java/subway/domain/section/SectionsTest.java new file mode 100644 index 000000000..cd5daa5cf --- /dev/null +++ b/src/test/java/subway/domain/section/SectionsTest.java @@ -0,0 +1,155 @@ +package subway.domain.section; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertAll; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import subway.domain.station.Station; + +class SectionsTest { + + @Test + @DisplayName("새로운 구간을 추가한다.") + void add() { + final Sections sections = new Sections(new ArrayList<>()); + final Section section = new Section(new Station("잠실역"), new Station("잠실새내역"), 10); + + sections.add(section); + + assertThat(sections.size()).isEqualTo(1); + } + + @Test + @DisplayName("특정 위치에 새로운 구간을 추가한다.") + void addAtPosition() { + final Sections sections = new Sections(new ArrayList<>()); + final Section oldSection = new Section(new Station("잠실역"), new Station("잠실새내역"), 10); + sections.add(oldSection); + final Section newSection = new Section(new Station("강남역"), new Station("잠실역"), 10); + + sections.add(0, newSection); + + assertAll( + () -> assertThat(sections.size()).isEqualTo(2), + () -> assertThat(sections.findSectionByPosition(0)).isEqualTo(newSection), + () -> assertThat(sections.findSectionByPosition(1)).isEqualTo(oldSection) + ); + } + + @Test + @DisplayName("특정 위치의 구간을 조회한다.") + void findSectionByPosition() { + final Sections sections = new Sections(new ArrayList<>()); + final Section section = new Section(new Station("잠실역"), new Station("잠실새내역"), 10); + sections.add(0, section); + + final Section findSection = sections.findSectionByPosition(0); + + assertThat(findSection).isEqualTo(section); + } + + @Test + @DisplayName("특정 위치의 구간을 삭제한다.") + void deleteByPosition() { + final Sections sections = new Sections(new ArrayList<>()); + final Section section = new Section(new Station("잠실역"), new Station("잠실새내역"), 10); + sections.add(0, section); + + sections.deleteByPosition(0); + + assertThat(sections.size()).isEqualTo(0); + } + + @Nested + @DisplayName("findPosition 메서드는 ") + class FindPosition { + + @Test + @DisplayName("역이 상행역으로 존재하는 구간 위치를 찾는다.") + void findPosition() { + final Sections sections = new Sections(new ArrayList<>()); + final Station upward = new Station(1L, "잠실역"); + final Station downward = new Station(2L, "잠실새내역"); + final Section section = new Section(upward, downward, 10); + sections.add(0, section); + + final int position = sections.findPosition(upward); + + assertThat(position).isEqualTo(0); + } + + @Test + @DisplayName("역이 상행역으로 존재하는 구간이 존재하지 않으면 -1을 반환한다.") + void findPositionWithNotExistStation() { + final Sections sections = new Sections(new ArrayList<>()); + final Station upward = new Station(1L, "잠실역"); + final Station downward = new Station(2L, "잠실새내역"); + final Section section = new Section(upward, downward, 10); + sections.add(0, section); + + final int position = sections.findPosition(new Station(3L, "강남역")); + + assertThat(position).isEqualTo(-1); + } + } + + @Nested + @DisplayName("isEmpty 메서드는 ") + class IsEmpty { + + @Test + @DisplayName("구간 정보가 존재하지 않으면 true 반환한다.") + void isEmptyTrue() { + final Sections sections = new Sections(new ArrayList<>()); + + final boolean result = sections.isEmpty(); + + assertThat(result).isTrue(); + } + + @Test + @DisplayName("구간 정보가 존재하면 false 반환한다.") + void isEmptyFalse() { + final Sections sections = new Sections(new ArrayList<>()); + final Section section = new Section(new Station(1L, "잠실역"), new Station(2L, "잠실새내역"), 10); + sections.add(section); + + final boolean result = sections.isEmpty(); + + assertThat(result).isFalse(); + } + } + + @Test + @DisplayName("모든 구간 목록 정보를 삭제한다.") + void clear() { + final Sections sections = new Sections(new ArrayList<>()); + final Section section = new Section(new Station(1L, "잠실역"), new Station(2L, "잠실새내역"), 10); + sections.add(section); + + sections.clear(); + + assertThat(sections.size()).isEqualTo(0); + } + + @Test + @DisplayName("상행 역 목록을 조회한다.") + void getUpwards() { + final Sections sections = new Sections(new ArrayList<>()); + final Station jamsil = new Station(1L, "잠실역"); + final Station jamsilsaenae = new Station(2L, "잠실새내역"); + final Station geondae = new Station(3L, "건대역"); + final Section firstSection = new Section(jamsil, jamsilsaenae, 10); + final Section secondSection = new Section(jamsilsaenae, geondae, 10); + sections.add(firstSection); + sections.add(secondSection); + + final List result = sections.getUpwards(); + + assertThat(result).isEqualTo(List.of(jamsil, jamsilsaenae)); + } +} diff --git a/src/test/java/subway/domain/station/NameTest.java b/src/test/java/subway/domain/station/NameTest.java new file mode 100644 index 000000000..5b3c71e99 --- /dev/null +++ b/src/test/java/subway/domain/station/NameTest.java @@ -0,0 +1,56 @@ +package subway.domain.station; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.NullAndEmptySource; +import org.junit.jupiter.params.provider.ValueSource; +import subway.exception.InvalidStationNameException; + +class NameTest { + + @Test + @DisplayName("역 이름을 정상적으로 생성한다.") + void name() { + assertDoesNotThrow(() -> new Name("을지로3가역")); + } + + @Test + @DisplayName("역 이름 최대 글자를 넘으면 예외를 던진다.") + void validateWithInvalidLength() { + final String input = "열한글자가넘는역이름입니다역"; + + assertThatThrownBy(() -> new Name(input)) + .isInstanceOf(InvalidStationNameException.class) + .hasMessage("역 이름은 11글자까지 가능합니다."); + } + + @Test + @DisplayName("역 이름이 역으로 끝나지 않을 경우 예외를 던진다.") + void validateWithInvalidNameFormat() { + assertThatThrownBy(() -> new Name("선릉")) + .isInstanceOf(InvalidStationNameException.class) + .hasMessage("역 이름은 한글, 숫자만 가능하고, '역'으로 끝나야 합니다."); + } + + @ParameterizedTest + @DisplayName("역 이름이 한글, 숫자로 구성되지 않을 경우 예외를 던진다.") + @ValueSource(strings = {"역", "NewYork역", "!!역"}) + void validateWithInvalidNameElement(final String input) { + assertThatThrownBy(() -> new Name(input)) + .isInstanceOf(InvalidStationNameException.class) + .hasMessage("역 이름은 한글, 숫자만 가능하고, '역'으로 끝나야 합니다."); + } + + @ParameterizedTest + @DisplayName("이름이 빈 칸이거나 null 일 때 예외를 던진다.") + @NullAndEmptySource + void validateWithBlankName(final String input) { + assertThatThrownBy(() -> new Name(input)) + .isInstanceOf(InvalidStationNameException.class) + .hasMessage("역 이름은 공백일 수 없습니다."); + } +} diff --git a/src/test/java/subway/domain/subway/SubwayJgraphtGraphTest.java b/src/test/java/subway/domain/subway/SubwayJgraphtGraphTest.java new file mode 100644 index 000000000..4a2ab6335 --- /dev/null +++ b/src/test/java/subway/domain/subway/SubwayJgraphtGraphTest.java @@ -0,0 +1,121 @@ +package subway.domain.subway; + +import static fixtures.StationFixtures.GANGNAM; +import static fixtures.StationFixtures.GYODAE; +import static fixtures.StationFixtures.JAMSIL; +import static fixtures.StationFixtures.NAMBU; +import static fixtures.StationFixtures.YANGJAE; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.List; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import subway.domain.line.Line; +import subway.domain.section.PathSection; +import subway.exception.InvalidStationException; + +class SubwayJgraphtGraphTest { + + @Nested + @DisplayName("findShortestPathSections 메서드는 ") + class GetShortestPath { + + @Test + @DisplayName("호선 정보를 통해 최단 경로를 반환한다.") + void findShortestPathSections() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 500); + final Line lineOfThree = new Line(3L, "3호선", "주황색", 800); + final Line lineOfNine = new Line(9L, "9호선", "빨간색", 1000); + lineOfTwo.addSection(GYODAE, GANGNAM, 20); + lineOfThree.addSection(GYODAE, NAMBU, 5); + lineOfThree.addSection(NAMBU, YANGJAE, 5); + lineOfNine.addSection(GANGNAM, YANGJAE, 5); + + final SubwayJgraphtGraph subwayGraph = new SubwayJgraphtGraph(List.of(lineOfTwo, lineOfThree, lineOfNine)); + final List result = subwayGraph.findShortestPathSections(GYODAE, GANGNAM); + + final List expected = List.of( + new PathSection(3L, GYODAE, NAMBU, 5, 800), + new PathSection(3L, NAMBU, YANGJAE, 5, 800), + new PathSection(9L, YANGJAE, GANGNAM, 5, 1000) + ); + assertThat(result).usingRecursiveComparison().isEqualTo(expected); + } + + @Test + @DisplayName("출발역이 등록되어 있지 않은 경우 예외를 던진다.") + void findShortestPathSectionsWithInvalidStartStation() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 500); + lineOfTwo.addSection(GYODAE, GANGNAM, 20); + + final SubwayGraph subwayGraph = new SubwayJgraphtGraph(List.of(lineOfTwo)); + + assertThatThrownBy(() -> subwayGraph.findShortestPathSections(JAMSIL, GANGNAM)) + .isInstanceOf(InvalidStationException.class) + .hasMessage("노선 구간에 등록되지 않은 역 이름을 통해 경로를 조회할 수 없습니다."); + } + + @Test + @DisplayName("도착역이 등록되어 있지 않은 경우 예외를 던진다.") + void findShortestPathSectionsWithInvalidEndStation() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 500); + lineOfTwo.addSection(GYODAE, GANGNAM, 20); + + final SubwayGraph subwayGraph = new SubwayJgraphtGraph(List.of(lineOfTwo)); + + assertThatThrownBy(() -> subwayGraph.findShortestPathSections(GANGNAM, JAMSIL)) + .isInstanceOf(InvalidStationException.class) + .hasMessage("노선 구간에 등록되지 않은 역 이름을 통해 경로를 조회할 수 없습니다."); + } + } + + @Nested + @DisplayName("calculateShortestDistance 메서드는 ") + class GetShortestDistance { + + @Test + @DisplayName("호선 정보를 통해 최단 거리를 반환한다.") + void getShortestDistance() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 500); + final Line lineOfThree = new Line(3L, "3호선", "주황색", 800); + final Line lineOfNew = new Line(9L, "9호선", "빨간색", 1000); + lineOfTwo.addSection(GYODAE, GANGNAM, 20); + lineOfThree.addSection(GYODAE, NAMBU, 5); + lineOfThree.addSection(NAMBU, YANGJAE, 5); + lineOfNew.addSection(GANGNAM, YANGJAE, 5); + + final SubwayGraph subwayGraph = new SubwayJgraphtGraph(List.of(lineOfTwo, lineOfThree, lineOfNew)); + final long result = subwayGraph.calculateShortestDistance(GYODAE, GANGNAM); + + assertThat(result).isEqualTo(15); + } + + @Test + @DisplayName("출발역이 등록되어 있지 않은 경우 예외를 던진다.") + void getShortestDistanceWithInvalidStartStation() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 500); + lineOfTwo.addSection(GYODAE, GANGNAM, 20); + + final SubwayGraph subwayGraph = new SubwayJgraphtGraph(List.of(lineOfTwo)); + + assertThatThrownBy(() -> subwayGraph.calculateShortestDistance(JAMSIL, GANGNAM)) + .isInstanceOf(InvalidStationException.class) + .hasMessage("노선 구간에 등록되지 않은 역 이름을 통해 경로를 조회할 수 없습니다."); + } + + @Test + @DisplayName("도착역이 등록되어 있지 않은 경우 예외를 던진다.") + void getShortestDistanceWithInvalidEndStation() { + final Line lineOfTwo = new Line(2L, "2호선", "초록색", 500); + lineOfTwo.addSection(GYODAE, GANGNAM, 20); + + final SubwayGraph subwayGraph = new SubwayJgraphtGraph(List.of(lineOfTwo)); + + assertThatThrownBy(() -> subwayGraph.calculateShortestDistance(GANGNAM, JAMSIL)) + .isInstanceOf(InvalidStationException.class) + .hasMessage("노선 구간에 등록되지 않은 역 이름을 통해 경로를 조회할 수 없습니다."); + } + } +} diff --git a/src/test/java/subway/integration/IntegrationTest.java b/src/test/java/subway/integration/IntegrationTest.java deleted file mode 100644 index c30949402..000000000 --- a/src/test/java/subway/integration/IntegrationTest.java +++ /dev/null @@ -1,19 +0,0 @@ -package subway.integration; - -import io.restassured.RestAssured; -import org.junit.jupiter.api.BeforeEach; -import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.boot.web.server.LocalServerPort; -import org.springframework.test.annotation.DirtiesContext; - -@DirtiesContext(classMode = DirtiesContext.ClassMode.BEFORE_EACH_TEST_METHOD) -@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) -public class IntegrationTest { - @LocalServerPort - int port; - - @BeforeEach - public void setUp() { - RestAssured.port = port; - } -} diff --git a/src/test/java/subway/integration/LineIntegrationTest.java b/src/test/java/subway/integration/LineIntegrationTest.java deleted file mode 100644 index ad4170205..000000000 --- a/src/test/java/subway/integration/LineIntegrationTest.java +++ /dev/null @@ -1,190 +0,0 @@ -package subway.integration; - -import io.restassured.RestAssured; -import io.restassured.response.ExtractableResponse; -import io.restassured.response.Response; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import subway.dto.LineRequest; -import subway.dto.LineResponse; - -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import static org.assertj.core.api.Assertions.assertThat; - -@DisplayName("지하철 노선 관련 기능") -public class LineIntegrationTest extends IntegrationTest { - private LineRequest lineRequest1; - private LineRequest lineRequest2; - - @BeforeEach - public void setUp() { - super.setUp(); - - lineRequest1 = new LineRequest("신분당선", "bg-red-600"); - lineRequest2 = new LineRequest("구신분당선", "bg-red-600"); - } - - @DisplayName("지하철 노선을 생성한다.") - @Test - void createLine() { - // when - ExtractableResponse response = RestAssured - .given().log().all() - .contentType(MediaType.APPLICATION_JSON_VALUE) - .body(lineRequest1) - .when().post("/lines") - .then().log().all(). - extract(); - - // then - assertThat(response.statusCode()).isEqualTo(HttpStatus.CREATED.value()); - assertThat(response.header("Location")).isNotBlank(); - } - - @DisplayName("기존에 존재하는 지하철 노선 이름으로 지하철 노선을 생성한다.") - @Test - void createLineWithDuplicateName() { - // given - RestAssured - .given().log().all() - .contentType(MediaType.APPLICATION_JSON_VALUE) - .body(lineRequest1) - .when().post("/lines") - .then().log().all(). - extract(); - - // when - ExtractableResponse response = RestAssured - .given().log().all() - .contentType(MediaType.APPLICATION_JSON_VALUE) - .body(lineRequest1) - .when().post("/lines") - .then().log().all(). - extract(); - - // then - assertThat(response.statusCode()).isEqualTo(HttpStatus.BAD_REQUEST.value()); - } - - @DisplayName("지하철 노선 목록을 조회한다.") - @Test - void getLines() { - // given - ExtractableResponse createResponse1 = RestAssured - .given().log().all() - .contentType(MediaType.APPLICATION_JSON_VALUE) - .body(lineRequest1) - .when().post("/lines") - .then().log().all(). - extract(); - - ExtractableResponse createResponse2 = RestAssured - .given().log().all() - .contentType(MediaType.APPLICATION_JSON_VALUE) - .body(lineRequest2) - .when().post("/lines") - .then().log().all(). - extract(); - - // when - ExtractableResponse response = RestAssured - .given().log().all() - .accept(MediaType.APPLICATION_JSON_VALUE) - .when().get("/lines") - .then().log().all() - .extract(); - - // then - assertThat(response.statusCode()).isEqualTo(HttpStatus.OK.value()); - List expectedLineIds = Stream.of(createResponse1, createResponse2) - .map(it -> Long.parseLong(it.header("Location").split("/")[2])) - .collect(Collectors.toList()); - List resultLineIds = response.jsonPath().getList(".", LineResponse.class).stream() - .map(LineResponse::getId) - .collect(Collectors.toList()); - assertThat(resultLineIds).containsAll(expectedLineIds); - } - - @DisplayName("지하철 노선을 조회한다.") - @Test - void getLine() { - // given - ExtractableResponse createResponse = RestAssured - .given().log().all() - .contentType(MediaType.APPLICATION_JSON_VALUE) - .body(lineRequest1) - .when().post("/lines") - .then().log().all(). - extract(); - - // when - Long lineId = Long.parseLong(createResponse.header("Location").split("/")[2]); - ExtractableResponse response = RestAssured - .given().log().all() - .accept(MediaType.APPLICATION_JSON_VALUE) - .when().get("/lines/{lineId}", lineId) - .then().log().all() - .extract(); - - // then - assertThat(response.statusCode()).isEqualTo(HttpStatus.OK.value()); - LineResponse resultResponse = response.as(LineResponse.class); - assertThat(resultResponse.getId()).isEqualTo(lineId); - } - - @DisplayName("지하철 노선을 수정한다.") - @Test - void updateLine() { - // given - ExtractableResponse createResponse = RestAssured - .given().log().all() - .contentType(MediaType.APPLICATION_JSON_VALUE) - .body(lineRequest1) - .when().post("/lines") - .then().log().all(). - extract(); - - // when - Long lineId = Long.parseLong(createResponse.header("Location").split("/")[2]); - ExtractableResponse response = RestAssured - .given().log().all() - .contentType(MediaType.APPLICATION_JSON_VALUE) - .body(lineRequest2) - .when().put("/lines/{lineId}", lineId) - .then().log().all() - .extract(); - - // then - assertThat(response.statusCode()).isEqualTo(HttpStatus.OK.value()); - } - - @DisplayName("지하철 노선을 제거한다.") - @Test - void deleteLine() { - // given - ExtractableResponse createResponse = RestAssured - .given().log().all() - .contentType(MediaType.APPLICATION_JSON_VALUE) - .body(lineRequest1) - .when().post("/lines") - .then().log().all(). - extract(); - - // when - Long lineId = Long.parseLong(createResponse.header("Location").split("/")[2]); - ExtractableResponse response = RestAssured - .given().log().all() - .when().delete("/lines/{lineId}", lineId) - .then().log().all() - .extract(); - - // then - assertThat(response.statusCode()).isEqualTo(HttpStatus.NO_CONTENT.value()); - } -} diff --git a/src/test/java/subway/integration/StationIntegrationTest.java b/src/test/java/subway/integration/StationIntegrationTest.java deleted file mode 100644 index a97d184a0..000000000 --- a/src/test/java/subway/integration/StationIntegrationTest.java +++ /dev/null @@ -1,197 +0,0 @@ -package subway.integration; - -import io.restassured.RestAssured; -import io.restassured.response.ExtractableResponse; -import io.restassured.response.Response; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import subway.dto.StationResponse; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import static org.assertj.core.api.Assertions.assertThat; - -@DisplayName("지하철역 관련 기능") -public class StationIntegrationTest extends IntegrationTest { - @DisplayName("지하철역을 생성한다.") - @Test - void createStation() { - // given - Map params = new HashMap<>(); - params.put("name", "강남역"); - - // when - ExtractableResponse response = RestAssured.given().log().all() - .body(params) - .contentType(MediaType.APPLICATION_JSON_VALUE) - .when() - .post("/stations") - .then().log().all() - .extract(); - - // then - assertThat(response.statusCode()).isEqualTo(HttpStatus.CREATED.value()); - assertThat(response.header("Location")).isNotBlank(); - } - - @DisplayName("기존에 존재하는 지하철역 이름으로 지하철역을 생성한다.") - @Test - void createStationWithDuplicateName() { - // given - Map params = new HashMap<>(); - params.put("name", "강남역"); - RestAssured.given().log().all() - .body(params) - .contentType(MediaType.APPLICATION_JSON_VALUE) - .when() - .post("/stations") - .then().log().all() - .extract(); - - // when - ExtractableResponse response = RestAssured.given().log().all() - .body(params) - .contentType(MediaType.APPLICATION_JSON_VALUE) - .when() - .post("/stations") - .then() - .log().all() - .extract(); - - // then - assertThat(response.statusCode()).isEqualTo(HttpStatus.BAD_REQUEST.value()); - } - - @DisplayName("지하철역 목록을 조회한다.") - @Test - void getStations() { - /// given - Map params1 = new HashMap<>(); - params1.put("name", "강남역"); - ExtractableResponse createResponse1 = RestAssured.given().log().all() - .body(params1) - .contentType(MediaType.APPLICATION_JSON_VALUE) - .when() - .post("/stations") - .then().log().all() - .extract(); - - Map params2 = new HashMap<>(); - params2.put("name", "역삼역"); - ExtractableResponse createResponse2 = RestAssured.given().log().all() - .body(params2) - .contentType(MediaType.APPLICATION_JSON_VALUE) - .when() - .post("/stations") - .then().log().all() - .extract(); - - // when - ExtractableResponse response = RestAssured.given().log().all() - .when() - .get("/stations") - .then().log().all() - .extract(); - - // then - assertThat(response.statusCode()).isEqualTo(HttpStatus.OK.value()); - List expectedStationIds = Stream.of(createResponse1, createResponse2) - .map(it -> Long.parseLong(it.header("Location").split("/")[2])) - .collect(Collectors.toList()); - List resultStationIds = response.jsonPath().getList(".", StationResponse.class).stream() - .map(StationResponse::getId) - .collect(Collectors.toList()); - assertThat(resultStationIds).containsAll(expectedStationIds); - } - - @DisplayName("지하철역을 조회한다.") - @Test - void getStation() { - /// given - Map params1 = new HashMap<>(); - params1.put("name", "강남역"); - ExtractableResponse createResponse = RestAssured.given().log().all() - .body(params1) - .contentType(MediaType.APPLICATION_JSON_VALUE) - .when() - .post("/stations") - .then().log().all() - .extract(); - - // when - Long stationId = Long.parseLong(createResponse.header("Location").split("/")[2]); - ExtractableResponse response = RestAssured.given().log().all() - .when() - .get("/stations/{stationId}", stationId) - .then().log().all() - .extract(); - - // then - assertThat(response.statusCode()).isEqualTo(HttpStatus.OK.value()); - StationResponse stationResponse = response.as(StationResponse.class); - assertThat(stationResponse.getId()).isEqualTo(stationId); - } - - @DisplayName("지하철역을 수정한다.") - @Test - void updateStation() { - // given - Map params = new HashMap<>(); - params.put("name", "강남역"); - ExtractableResponse createResponse = RestAssured.given().log().all() - .body(params) - .contentType(MediaType.APPLICATION_JSON_VALUE) - .when() - .post("/stations") - .then().log().all() - .extract(); - - // when - Map otherParams = new HashMap<>(); - otherParams.put("name", "삼성역"); - String uri = createResponse.header("Location"); - ExtractableResponse response = RestAssured.given().log().all() - .contentType(MediaType.APPLICATION_JSON_VALUE) - .body(otherParams) - .when() - .put(uri) - .then().log().all() - .extract(); - - // then - assertThat(response.statusCode()).isEqualTo(HttpStatus.OK.value()); - } - - @DisplayName("지하철역을 제거한다.") - @Test - void deleteStation() { - // given - Map params = new HashMap<>(); - params.put("name", "강남역"); - ExtractableResponse createResponse = RestAssured.given().log().all() - .body(params) - .contentType(MediaType.APPLICATION_JSON_VALUE) - .when() - .post("/stations") - .then().log().all() - .extract(); - - // when - String uri = createResponse.header("Location"); - ExtractableResponse response = RestAssured.given().log().all() - .when() - .delete(uri) - .then().log().all() - .extract(); - - // then - assertThat(response.statusCode()).isEqualTo(HttpStatus.NO_CONTENT.value()); - } -} diff --git a/src/test/java/subway/repository/LineRepositoryTest.java b/src/test/java/subway/repository/LineRepositoryTest.java new file mode 100644 index 000000000..00161ea67 --- /dev/null +++ b/src/test/java/subway/repository/LineRepositoryTest.java @@ -0,0 +1,134 @@ +package subway.repository; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertAll; + +import java.util.List; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.jdbc.JdbcTest; +import org.springframework.jdbc.core.JdbcTemplate; +import subway.dao.LineDao; +import subway.dao.SectionDao; +import subway.dao.StationDao; +import subway.domain.line.Line; +import subway.domain.station.Station; +import subway.exception.InvalidLineException; + +@JdbcTest +class LineRepositoryTest { + + @Autowired + private JdbcTemplate jdbcTemplate; + + private LineRepository lineRepository; + private StationRepository stationRepository; + + @BeforeEach + void setUp() { + final StationDao stationDao = new StationDao(jdbcTemplate); + final LineDao lineDao = new LineDao(jdbcTemplate); + final SectionDao sectionDao = new SectionDao(jdbcTemplate); + lineRepository = new LineRepository(lineDao, sectionDao); + stationRepository = new StationRepository(stationDao); + } + + @Test + @DisplayName("노선을 저장한다.") + void save() { + final Line line = new Line("2호선", "초록색", 500); + + final Line result = lineRepository.save(line); + + assertAll( + () -> assertThat(result.getId()).isNotNull(), + () -> assertThat(result.getName()).isEqualTo(line.getName()), + () -> assertThat(result.getColor()).isEqualTo(line.getColor()), + () -> assertThat(result.getFare()).isEqualTo(line.getFare()) + ); + } + + @Nested + @DisplayName("노선 조회 시 ") + class FindById { + + private Line line; + + @BeforeEach + void setUp() { + line = lineRepository.save(new Line("2호선", "초록색", 500)); + final Station upward = stationRepository.save(new Station("잠실역")); + final Station downward = stationRepository.save(new Station("잠실새내역")); + line.addSection(upward, downward, 10); + lineRepository.update(line); + } + + @Test + @DisplayName("ID로 조회할 때 존재하는 노선이라면 노선 정보를 반환한다.") + void findById() { + final Line result = lineRepository.findById(line.getId()); + + assertThat(line).usingRecursiveComparison().isEqualTo(result); + } + + @Test + @DisplayName("ID로 조회할 때 존재하지 않는 노선이라면 예외를 던진다.") + void findByInvalidId() { + assertThatThrownBy(() -> lineRepository.findById(-2L)) + .isInstanceOf(InvalidLineException.class) + .hasMessage("존재하지 않는 노선 ID 입니다."); + } + + @Test + @DisplayName("모든 노선 정보를 조회한다.") + void findAll() { + final List lines = lineRepository.findAll(); + + assertThat(lines).usingRecursiveComparison().isEqualTo(List.of(line)); + } + } + + @Nested + @DisplayName("노선 정보 업데이트 시") + class Update { + + @Test + @DisplayName("섹션이 추가 됐을 때 노선 정보를 업데이트한다.") + void updateWhenStationAdded() { + final Line line = lineRepository.save(new Line("2호선", "초록색", 500)); + final Station upward = stationRepository.save(new Station("잠실역")); + final Station middle = stationRepository.save(new Station("종합운동장역")); + final Station downward = stationRepository.save(new Station("잠실새내역")); + line.addSection(upward, downward, 10); + lineRepository.update(line); + + line.addSection(upward, middle, 3); + lineRepository.update(line); + + final Line result = lineRepository.findById(line.getId()); + assertThat(result).usingRecursiveComparison().isEqualTo(line); + } + + + @Test + @DisplayName("섹션이 삭제 됐을 때 노선 정보를 업데이트한다.") + void updateWhenStationDeleted() { + final Line line = lineRepository.save(new Line("2호선", "초록색", 500)); + final Station upward = stationRepository.save(new Station("잠실역")); + final Station downward = stationRepository.save(new Station("잠실새내역")); + line.addSection(upward, downward, 10); + lineRepository.update(line); + + line.deleteStation(upward); + lineRepository.update(line); + + final Line result = lineRepository.findById(line.getId()); + final List stations = result.getStations(); + assertThat(stations).isEmpty(); + } + } +} diff --git a/src/test/java/subway/repository/StationRepositoryTest.java b/src/test/java/subway/repository/StationRepositoryTest.java new file mode 100644 index 000000000..7a7b08486 --- /dev/null +++ b/src/test/java/subway/repository/StationRepositoryTest.java @@ -0,0 +1,67 @@ +package subway.repository; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.jupiter.api.Assertions.assertAll; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.autoconfigure.jdbc.JdbcTest; +import org.springframework.jdbc.core.JdbcTemplate; +import subway.dao.StationDao; +import subway.domain.station.Station; +import subway.exception.InvalidStationException; + +@JdbcTest +class StationRepositoryTest { + + @Autowired + private JdbcTemplate jdbcTemplate; + + private StationRepository stationRepository; + + @BeforeEach + void setUp() { + final StationDao stationDao = new StationDao(jdbcTemplate); + stationRepository = new StationRepository(stationDao); + } + + @Test + @DisplayName("역을 저장한다.") + void save() { + final Station station = new Station("잠실역"); + + final Station result = stationRepository.save(station); + + assertAll( + () -> assertThat(result.getId()).isNotNull(), + () -> assertThat(result.getName()).isEqualTo(station.getName()) + ); + } + + @Nested + @DisplayName("역을 조회 시 ") + class FindById { + + @Test + @DisplayName("존재하는 역이라면 역 정보를 반환한다.") + void findById() { + final Station station = stationRepository.save(new Station("잠실역")); + + final Station result = stationRepository.findById(station.getId()); + + assertThat(result).usingRecursiveComparison().isEqualTo(station); + } + + @Test + @DisplayName("존재하지 않는 역이라면 예외를 던진다.") + void findByInvalidId() { + assertThatThrownBy(() -> stationRepository.findById(-2L)) + .isInstanceOf(InvalidStationException.class) + .hasMessage("존재하지 않은 역 ID입니다."); + } + } +} diff --git a/src/test/java/subway/service/LineServiceTest.java b/src/test/java/subway/service/LineServiceTest.java new file mode 100644 index 000000000..d982e60cd --- /dev/null +++ b/src/test/java/subway/service/LineServiceTest.java @@ -0,0 +1,192 @@ +package subway.service; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willDoNothing; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import subway.controller.dto.request.LineCreateRequest; +import subway.controller.dto.request.SectionCreateRequest; +import subway.controller.dto.response.LineResponse; +import subway.controller.dto.response.LinesResponse; +import subway.domain.line.Line; +import subway.domain.section.Section; +import subway.domain.station.Station; +import subway.exception.InvalidDistanceException; +import subway.exception.InvalidLineNameException; +import subway.exception.InvalidSectionException; +import subway.repository.LineRepository; +import subway.repository.StationRepository; + +@ExtendWith(MockitoExtension.class) +class LineServiceTest { + + @InjectMocks + private LineService lineService; + + @Mock + private LineRepository lineRepository; + + @Mock + private StationRepository stationRepository; + + @Test + @DisplayName("노선 목록을 조회한다.") + void findLines() { + final List
sectionsOfLineTwo = List.of( + new Section(new Station(1L, "잠실역"), new Station(2L, "잠실새내역"), 10), + new Section(new Station(2L, "잠실새내역"), Station.TERMINAL, 0) + ); + final List
sectionsOfLineFour = List.of( + new Section(new Station(3L, "이수역"), new Station(4L, "서울역"), 11), + new Section(new Station(4L, "서울역"), Station.TERMINAL, 0) + ); + final List lines = List.of( + new Line(1L, "2호선", "초록색", 500, sectionsOfLineTwo), + new Line(2L, "4호선", "하늘색", 500, sectionsOfLineFour) + ); + given(lineRepository.findAll()).willReturn(lines); + + final LinesResponse response = lineService.findLines(); + + final List lineResponses = lines.stream() + .map(LineResponse::from) + .collect(Collectors.toList()); + assertThat(response).usingRecursiveComparison().isEqualTo(new LinesResponse(lineResponses)); + } + + @Nested + @DisplayName("노선 추가시 ") + class CreateLine { + + @Test + @DisplayName("유효한 정보라면 노선을 추가한다.") + void createLine() { + final Line line = new Line(1L, "2호선", "초록색", 500); + final LineCreateRequest request = new LineCreateRequest("2호선", "초록색", 500); + given(lineRepository.save(any(Line.class))).willReturn(line); + + final Long lineId = lineService.createLine(request); + + assertThat(lineId).isEqualTo(1L); + } + + @Test + @DisplayName("유효하지 않은 정보라면 예외를 던진다.") + void createLineWithInvalidName() { + final LineCreateRequest request = new LineCreateRequest("경의중앙선", "초록색", 500); + + assertThatThrownBy(() -> lineService.createLine(request)) + .isInstanceOf(InvalidLineNameException.class); + } + } + + @Nested + @DisplayName("노선 조회시 ") + class FindLineById { + + @Test + @DisplayName("존재하는 노선이라면 노선 정보를 조회한다.") + void findLineById() { + final List
sections = List.of( + new Section(new Station(1L, "잠실역"), new Station(2L, "잠실새내역"), 10), + new Section(new Station(2L, "잠실새내역"), Station.TERMINAL, 0) + ); + final Line line = new Line(1L, "2호선", "초록색", 500, sections); + given(lineRepository.findById(1L)).willReturn(line); + + final LineResponse response = lineService.findLineById(1L); + + assertThat(response).usingRecursiveComparison().isEqualTo(LineResponse.from(line)); + } + } + + @Nested + @DisplayName("노선에 역 등록 시") + class CreateSection { + + @Test + @DisplayName("유효한 정보라면 노선에 역을 등록한다.") + void createSection() { + final SectionCreateRequest request = new SectionCreateRequest(1L, 3L, 2); + final List
sections = List.of( + new Section(new Station(1L, "잠실역"), new Station(2L, "잠실새내역"), 10), + new Section(new Station(2L, "잠실새내역"), Station.TERMINAL, 0) + ); + final Line line = new Line(1L, "2호선", "초록색", 500, new ArrayList<>(sections)); + given(lineRepository.findById(1L)).willReturn(line); + given(stationRepository.findById(1L)).willReturn(new Station(1L, "잠실역")); + given(stationRepository.findById(3L)).willReturn(new Station(3L, "종합운동장역")); + willDoNothing().given(lineRepository).update(any(Line.class)); + + lineService.createSection(1L, request); + + assertThat(line.getSections()).hasSize(2); + } + + @Test + @DisplayName("유효하지 않은 정보라면 예외를 던진다.") + void createSectionWithInvalidDistance() { + final SectionCreateRequest request = new SectionCreateRequest(1L, 3L, 10); + final List
sections = List.of( + new Section(new Station(1L, "잠실역"), new Station(2L, "잠실새내역"), 10), + new Section(new Station(2L, "잠실새내역"), Station.TERMINAL, 0) + ); + final Line line = new Line(1L, "2호선", "초록색", 500, new ArrayList<>(sections)); + given(lineRepository.findById(1L)).willReturn(line); + given(stationRepository.findById(1L)).willReturn(new Station(1L, "잠실역")); + given(stationRepository.findById(3L)).willReturn(new Station(3L, "종합운동장역")); + + assertThatThrownBy(() -> lineService.createSection(1L, request)) + .isInstanceOf(InvalidDistanceException.class); + } + } + + @Nested + @DisplayName("노선에서 역 삭제시 ") + class DeleteStation { + + @Test + @DisplayName("유효한 정보라면 역을 삭제한다.") + void deleteStation() { + final List
sections = List.of( + new Section(new Station(1L, "잠실역"), new Station(2L, "잠실새내역"), 10), + new Section(new Station(2L, "잠실새내역"), Station.TERMINAL, 0) + ); + final Line line = new Line(1L, "2호선", "초록색", 500, new ArrayList<>(sections)); + given(lineRepository.findById(1L)).willReturn(line); + given(stationRepository.findById(1L)).willReturn(new Station(1L, "잠실역")); + willDoNothing().given(lineRepository).update(any(Line.class)); + + lineService.deleteStation(1L, 1L); + + assertThat(line.getSections()).isEmpty(); + } + + @Test + @DisplayName("유효하지 않은 정보라면 예외를 던진다.") + void deleteStationWithInvalidStationId() { + final List
sections = List.of( + new Section(new Station(1L, "잠실역"), new Station(2L, "잠실새내역"), 10), + new Section(new Station(2L, "잠실새내역"), Station.TERMINAL, 0) + ); + final Line line = new Line(1L, "2호선", "초록색", 500, new ArrayList<>(sections)); + given(lineRepository.findById(1L)).willReturn(line); + given(stationRepository.findById(3L)).willReturn(new Station(3L, "종합운동장역")); + + assertThatThrownBy(() -> lineService.deleteStation(1L, 3L)) + .isInstanceOf(InvalidSectionException.class); + } + } +} diff --git a/src/test/java/subway/service/StationServiceTest.java b/src/test/java/subway/service/StationServiceTest.java new file mode 100644 index 000000000..44b42947e --- /dev/null +++ b/src/test/java/subway/service/StationServiceTest.java @@ -0,0 +1,85 @@ +package subway.service; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.BDDMockito.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willThrow; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import subway.controller.dto.request.StationCreateRequest; +import subway.controller.dto.response.StationResponse; +import subway.domain.station.Station; +import subway.exception.InvalidStationException; +import subway.exception.InvalidStationNameException; +import subway.repository.StationRepository; + +@ExtendWith(MockitoExtension.class) +class StationServiceTest { + + @InjectMocks + private StationService stationService; + + @Mock + private StationRepository stationRepository; + + @Nested + @DisplayName("역 생성 시 ") + class CreateStation { + + @Test + @DisplayName("정보가 유효하면 역을 생성한다.") + void createStation() { + final StationCreateRequest request = new StationCreateRequest("잠실역"); + final Station station = new Station(1L, "잠실역"); + given(stationRepository.save(any(Station.class))).willReturn(station); + + final Long stationId = stationService.createStation(request); + + assertThat(stationId).isEqualTo(1L); + } + + @Test + @DisplayName("역 이름이 유효하지 않으면 예외를 던진다.") + void createStationWithInvalidName() { + final StationCreateRequest request = new StationCreateRequest("잠실"); + + assertThatThrownBy(() -> stationService.createStation(request)) + .isInstanceOf(InvalidStationNameException.class); + } + } + + @Nested + @DisplayName("역을 조회 시 ") + class FindStationById { + + @Test + @DisplayName("존재하는 역일시 역 정보를 반환한다.") + void findStationById() { + final Station station = new Station(1L, "잠실역"); + given(stationRepository.findById(1L)).willReturn(station); + + final StationResponse result = stationService.findStationById(1L); + + assertThat(result).usingRecursiveComparison().isEqualTo(StationResponse.from(station)); + } + + @Test + @DisplayName("존재하지 않는 역일시 예외를 던진다.") + void findStationByInvalidId() { + final InvalidStationException exception = new InvalidStationException("존재하지 않는 역입니다."); + willThrow(exception).given(stationRepository).findById(anyLong()); + + assertThatThrownBy(() -> stationService.findStationById(1L)) + .isInstanceOf(InvalidStationException.class) + .hasMessage("존재하지 않는 역입니다."); + } + } +} diff --git a/src/test/java/subway/service/SubwayServiceTest.java b/src/test/java/subway/service/SubwayServiceTest.java new file mode 100644 index 000000000..6f45530cf --- /dev/null +++ b/src/test/java/subway/service/SubwayServiceTest.java @@ -0,0 +1,67 @@ +package subway.service; + +import static fixtures.StationFixtures.GANGNAM; +import static fixtures.StationFixtures.GYODAE; +import static fixtures.StationFixtures.YANGJAE; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.BDDMockito.given; + +import java.util.List; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import subway.controller.dto.request.PassengerRequest; +import subway.controller.dto.response.ShortestPathResponse; +import subway.domain.fare.FareStrategy; +import subway.domain.line.Line; +import subway.domain.section.PathSection; +import subway.domain.subway.Passenger; +import subway.domain.subway.Subway; +import subway.repository.LineRepository; +import subway.repository.StationRepository; + +@ExtendWith(MockitoExtension.class) +class SubwayServiceTest { + + @InjectMocks + private SubwayService subwayService; + + @Mock + private LineRepository lineRepository; + @Mock + private StationRepository stationRepository; + @Mock + private FareStrategy fareStrategy; + + @Test + @DisplayName("탑승자 정보를 바탕으로 경로 정보를 반환한다.") + void findShortestPath() { + final PassengerRequest request = new PassengerRequest(8, 1L, 3L); + final Line lineOfOne = new Line(1L, "1호선", "빨간색", 500); + final Line lineOfTwo = new Line(2L, "2호선", "파란색", 1000); + lineOfOne.addSection(GANGNAM, YANGJAE, 5); + lineOfTwo.addSection(YANGJAE, GYODAE, 7); + + given(lineRepository.findAll()).willReturn(List.of(lineOfOne, lineOfTwo)); + given(stationRepository.findById(1L)).willReturn(GANGNAM); + given(stationRepository.findById(3L)).willReturn(GYODAE); + given(fareStrategy.calculateFare(anyDouble(), any(Passenger.class), any(Subway.class))).willReturn(1000d); + + final ShortestPathResponse result = subwayService.findShortestPath(request); + + final ShortestPathResponse expected = ShortestPathResponse.of( + List.of( + new PathSection(1L, GANGNAM, YANGJAE, 5, 500), + new PathSection(2L, YANGJAE, GYODAE, 7, 1000) + ), + 12, + 1000 + ); + assertThat(result).usingRecursiveComparison().isEqualTo(expected); + } +} diff --git a/src/test/resources/application.yml b/src/test/resources/application.yml new file mode 100644 index 000000000..273a4b14c --- /dev/null +++ b/src/test/resources/application.yml @@ -0,0 +1,4 @@ +spring: + datasource: + url: jdbc:h2:mem:testdb;MODE=MySQL + driver-class-name: org.h2.Driver diff --git a/src/test/resources/data.sql b/src/test/resources/data.sql new file mode 100644 index 000000000..f0b85f685 --- /dev/null +++ b/src/test/resources/data.sql @@ -0,0 +1,28 @@ +CREATE TABLE IF NOT EXISTS STATION +( + id BIGINT AUTO_INCREMENT NOT NULL, + name VARCHAR(15) NOT NULL, + PRIMARY KEY (id) +); + +CREATE TABLE IF NOT EXISTS LINE +( + id BIGINT AUTO_INCREMENT NOT NULL, + name VARCHAR(15) NOT NULL, + color VARCHAR(15) NOT NULL, + fare INT NOT NULL, + PRIMARY KEY (id) +); + +CREATE TABLE IF NOT EXISTS SECTION +( + id BIGINT AUTO_INCREMENT NOT NULL, + line_id BIGINT NOT NULL, + upward_station_id BIGINT NOT NULL, + downward_station_id BIGINT NOT NULL, + distance INT NOT NULL, + PRIMARY KEY (id), + FOREIGN KEY (line_id) REFERENCES LINE (id), + FOREIGN KEY (upward_station_id) REFERENCES STATION (id), + FOREIGN KEY (downward_station_id) REFERENCES STATION (id) +); diff --git a/src/test/resources/logback-access.xml b/src/test/resources/logback-access.xml new file mode 100644 index 000000000..38e0823f4 --- /dev/null +++ b/src/test/resources/logback-access.xml @@ -0,0 +1,8 @@ + + + + %n###### HTTP Request ######%n%fullRequest%n###### HTTP Response ######%n%fullResponse%n%n + + + + \ No newline at end of file