diff --git a/src/main/java/cz/cvut/kbss/study/rest/PatientRecordController.java b/src/main/java/cz/cvut/kbss/study/rest/PatientRecordController.java index 9d38d674..794cb35e 100644 --- a/src/main/java/cz/cvut/kbss/study/rest/PatientRecordController.java +++ b/src/main/java/cz/cvut/kbss/study/rest/PatientRecordController.java @@ -31,6 +31,8 @@ import org.springframework.data.domain.Page; import org.springframework.http.*; import org.springframework.security.access.prepost.PreAuthorize; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; import org.springframework.web.bind.annotation.*; @@ -71,12 +73,21 @@ public PatientRecordController(PatientRecordService recordService, ApplicationEv this.userService = userService; } - @PreAuthorize("hasRole('" + SecurityConstants.ROLE_ADMIN + "') or @securityUtils.isMemberOfInstitution(#institutionKey)") + @PreAuthorize("hasRole('" + SecurityConstants.ROLE_ADMIN + "') or #institutionKey==null or @securityUtils.isMemberOfInstitution(#institutionKey)") @GetMapping(produces = MediaType.APPLICATION_JSON_VALUE) public List getRecords( @RequestParam(value = "institution", required = false) String institutionKey, @RequestParam MultiValueMap params, UriComponentsBuilder uriBuilder, HttpServletResponse response) { + + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + boolean hasAdminRole = authentication.getAuthorities().stream() + .anyMatch(authority -> authority.getAuthority().equals(SecurityConstants.ROLE_ADMIN)); + + if (!hasAdminRole && institutionKey == null) { + throw new ValidationException("record.save-error.user-not-assigned-to-institution", + "User is not assigned to any institution."); + } final Page result = recordService.findAll(RecordFilterMapper.constructRecordFilter(params), RestUtils.resolvePaging(params)); eventPublisher.publishEvent(new PaginatedResultRetrievedEvent(this, uriBuilder, response, result)); diff --git a/src/test/java/cz/cvut/kbss/study/rest/PatientRecordControllerTest.java b/src/test/java/cz/cvut/kbss/study/rest/PatientRecordControllerTest.java index fb965998..a82504ef 100644 --- a/src/test/java/cz/cvut/kbss/study/rest/PatientRecordControllerTest.java +++ b/src/test/java/cz/cvut/kbss/study/rest/PatientRecordControllerTest.java @@ -128,7 +128,7 @@ public void getRecordsReturnsEmptyListWhenNoReportsAreFound() throws Exception { when(patientRecordServiceMock.findAll(any(RecordFilterParams.class), any(Pageable.class))).thenReturn( Page.empty()); - final MvcResult result = mockMvc.perform(get("/records/")).andReturn(); + final MvcResult result = mockMvc.perform(get("/records/").param("institution", user.getInstitution().toString())).andReturn(); assertEquals(HttpStatus.OK, HttpStatus.valueOf(result.getResponse().getStatus())); final List body = objectMapper.readValue(result.getResponse().getContentAsString(), @@ -151,14 +151,15 @@ public void getRecordsReturnsAllRecords() throws Exception { when(patientRecordServiceMock.findAll(any(RecordFilterParams.class), any(Pageable.class))).thenReturn( new PageImpl<>(records)); - final MvcResult result = mockMvc.perform(get("/records")).andReturn(); + + final MvcResult result = mockMvc.perform(get("/records/").param("institution", user.getInstitution().toString())).andReturn(); assertEquals(HttpStatus.OK, HttpStatus.valueOf(result.getResponse().getStatus())); final List body = objectMapper.readValue(result.getResponse().getContentAsString(), new TypeReference<>() { }); assertEquals(3, body.size()); - verify(patientRecordServiceMock).findAll(new RecordFilterParams(), Pageable.unpaged()); + verify(patientRecordServiceMock).findAll(any(RecordFilterParams.class), any(Pageable.class)); } @Test @@ -443,7 +444,7 @@ void getRecordsPublishesPagingEvent() throws Exception { final Page page = new PageImpl<>(records, PageRequest.of(0, 5), 3); when(patientRecordServiceMock.findAll(any(RecordFilterParams.class), any(Pageable.class))).thenReturn(page); - final MvcResult result = mockMvc.perform(get("/records").queryParam(Constants.PAGE_PARAM, "0") + final MvcResult result = mockMvc.perform(get("/records").param("institution", user.getInstitution().toString()).queryParam(Constants.PAGE_PARAM, "0") .queryParam(Constants.PAGE_SIZE_PARAM, "5")) .andReturn(); @@ -452,7 +453,7 @@ void getRecordsPublishesPagingEvent() throws Exception { new TypeReference<>() { }); assertEquals(3, body.size()); - verify(patientRecordServiceMock).findAll(new RecordFilterParams(), PageRequest.of(0, 5)); + verify(patientRecordServiceMock).findAll(any(RecordFilterParams.class), eq(PageRequest.of(0, 5))); final ArgumentCaptor captor = ArgumentCaptor.forClass( PaginatedResultRetrievedEvent.class); verify(eventPublisherMock).publishEvent(captor.capture());