From 1c3f16c601fad60d928acd95b829d0f15b7f385e Mon Sep 17 00:00:00 2001 From: monst Date: Tue, 10 Oct 2023 18:30:52 +0200 Subject: [PATCH 1/3] Change Pyris response type --- .../iris/dto/IrisMessageResponseDTO.java | 25 ++++++++++- .../iris/session/IrisChatSessionService.java | 19 +++++++-- .../session/IrisHestiaSessionService.java | 25 ++++++++--- .../connector/IrisRequestMockProvider.java | 42 +++++++++++-------- .../iris/IrisConnectorServiceTest.java | 4 +- 5 files changed, 86 insertions(+), 29 deletions(-) diff --git a/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/dto/IrisMessageResponseDTO.java b/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/dto/IrisMessageResponseDTO.java index cac9b0849bad..979134d653a0 100644 --- a/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/dto/IrisMessageResponseDTO.java +++ b/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/dto/IrisMessageResponseDTO.java @@ -1,6 +1,27 @@ package de.tum.in.www1.artemis.service.connectors.iris.dto; -import de.tum.in.www1.artemis.domain.iris.IrisMessage; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; -public record IrisMessageResponseDTO(String usedModel, IrisMessage message) { +import java.time.ZonedDateTime; + +public record IrisMessageResponseDTO(String usedModel, ZonedDateTime sentAt, JsonNode content) { + + /** + * Create a new IrisMessageResponseDTO. Jackson uses this constructor to create the object. + * This is necessary because Jackson was not throwing an exception when the response from Iris did not contain + * the expected fields, which resulted in a NullPointerException when trying to access the fields. + * Not sure if this is a bug in Jackson or if it is intended behavior, either way this is a workaround. + */ + @JsonCreator + public IrisMessageResponseDTO( + @JsonProperty(value = "used_model", required = true) String usedModel, + @JsonProperty(value = "sent_at", required = true) ZonedDateTime sentAt, + @JsonProperty(value = "content", required = true) JsonNode content) { + this.usedModel = usedModel; + this.sentAt = sentAt; + this.content = content; + } + } diff --git a/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisChatSessionService.java b/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisChatSessionService.java index 508e8c5b7afa..218c7ff8be27 100644 --- a/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisChatSessionService.java +++ b/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisChatSessionService.java @@ -6,6 +6,9 @@ import javax.ws.rs.BadRequestException; +import de.tum.in.www1.artemis.domain.iris.IrisMessage; +import de.tum.in.www1.artemis.domain.iris.IrisMessageContent; +import de.tum.in.www1.artemis.service.connectors.iris.dto.IrisMessageResponseDTO; import org.eclipse.jgit.api.Git; import org.eclipse.jgit.api.errors.GitAPIException; import org.eclipse.jgit.treewalk.FileTreeIterator; @@ -147,13 +150,13 @@ public void requestAndHandleResponse(IrisSession session) { var irisSettings = irisSettingsService.getCombinedIrisSettings(exercise, false); irisConnectorService.sendRequest(irisSettings.getIrisChatSettings().getTemplate(), irisSettings.getIrisChatSettings().getPreferredModel(), parameters) - .handleAsync((irisMessage, throwable) -> { + .handleAsync((response, throwable) -> { if (throwable != null) { log.error("Error while getting response from Iris model", throwable); irisWebsocketService.sendException(fullSession, throwable.getCause()); } - else if (irisMessage != null) { - var irisMessageSaved = irisMessageService.saveMessage(irisMessage.message(), fullSession, IrisMessageSender.LLM); + else if (response != null) { + var irisMessageSaved = irisMessageService.saveMessage(toIrisMessage(response), fullSession, IrisMessageSender.LLM); irisWebsocketService.sendMessage(irisMessageSaved); } else { @@ -163,6 +166,16 @@ else if (irisMessage != null) { return null; }); } + + private static IrisMessage toIrisMessage(IrisMessageResponseDTO dto) { + var message = new IrisMessage(); + message.setSentAt(dto.sentAt()); + message.setSender(IrisMessageSender.LLM); + var irisMessageContent = new IrisMessageContent(); + irisMessageContent.setTextContent(dto.content().get("response").asText()); + message.setContent(List.of(irisMessageContent)); + return message; + } private void addDiffAndTemplatesForStudentAndExerciseIfPossible(User student, ProgrammingExercise exercise, Map parameters) { parameters.put("gitDiff", ""); diff --git a/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisHestiaSessionService.java b/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisHestiaSessionService.java index 5aeea216407b..e6ed6d15614b 100644 --- a/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisHestiaSessionService.java +++ b/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisHestiaSessionService.java @@ -8,6 +8,7 @@ import javax.ws.rs.BadRequestException; +import de.tum.in.www1.artemis.service.connectors.iris.dto.IrisMessageResponseDTO; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.context.annotation.Profile; @@ -89,16 +90,18 @@ public CodeHint generateDescription(CodeHint codeHint) { Map parameters = Map.of("codeHint", irisSession.getCodeHint()); var irisSettings = irisSettingsService.getCombinedIrisSettings(irisSession.getCodeHint().getExercise(), false); try { - var irisMessage1 = irisConnectorService + var response1 = irisConnectorService .sendRequest(irisSettings.getIrisHestiaSettings().getTemplate(), irisSettings.getIrisHestiaSettings().getPreferredModel(), parameters).get(); - irisMessageService.saveMessage(irisMessage1.message(), irisSession, IrisMessageSender.LLM); + var irisMessage1 = toIrisMessage(response1); + irisMessageService.saveMessage(irisMessage1, irisSession, IrisMessageSender.LLM); irisSession = (IrisHestiaSession) irisSessionRepository.findByIdWithMessagesAndContents(irisSession.getId()); - var irisMessage2 = irisConnectorService + var response2 = irisConnectorService .sendRequest(irisSettings.getIrisHestiaSettings().getTemplate(), irisSettings.getIrisHestiaSettings().getPreferredModel(), parameters).get(); - irisMessageService.saveMessage(irisMessage2.message(), irisSession, IrisMessageSender.LLM); + var irisMessage2 = toIrisMessage(response2); + irisMessageService.saveMessage(irisMessage2, irisSession, IrisMessageSender.LLM); - codeHint.setContent(irisMessage1.message().getContent().stream().map(IrisMessageContent::getTextContent).collect(Collectors.joining("\n"))); - codeHint.setDescription(irisMessage2.message().getContent().stream().map(IrisMessageContent::getTextContent).collect(Collectors.joining("\n"))); + codeHint.setContent(irisMessage1.getContent().stream().map(IrisMessageContent::getTextContent).collect(Collectors.joining("\n"))); + codeHint.setDescription(irisMessage2.getContent().stream().map(IrisMessageContent::getTextContent).collect(Collectors.joining("\n"))); return codeHint; } catch (InterruptedException | ExecutionException e) { @@ -106,6 +109,16 @@ public CodeHint generateDescription(CodeHint codeHint) { throw new InternalServerErrorException("Unable to generate description: " + e.getMessage()); } } + + private static IrisMessage toIrisMessage(IrisMessageResponseDTO dto) { + var message = new IrisMessage(); + message.setSentAt(dto.sentAt()); + message.setSender(IrisMessageSender.LLM); + var irisMessageContent = new IrisMessageContent(); + irisMessageContent.setTextContent(dto.content().get("response").asText()); + message.setContent(List.of(irisMessageContent)); + return message; + } private IrisMessage generateSystemMessage() { var irisMessage = new IrisMessage(); diff --git a/src/test/java/de/tum/in/www1/artemis/connector/IrisRequestMockProvider.java b/src/test/java/de/tum/in/www1/artemis/connector/IrisRequestMockProvider.java index f3bb1f04e1a6..615faa55b504 100644 --- a/src/test/java/de/tum/in/www1/artemis/connector/IrisRequestMockProvider.java +++ b/src/test/java/de/tum/in/www1/artemis/connector/IrisRequestMockProvider.java @@ -6,9 +6,9 @@ import java.net.URL; import java.time.ZonedDateTime; -import java.util.Collections; import java.util.Map; +import org.codehaus.jackson.node.ObjectNode; import org.mockito.MockitoAnnotations; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; @@ -24,9 +24,6 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; -import de.tum.in.www1.artemis.domain.iris.IrisMessage; -import de.tum.in.www1.artemis.domain.iris.IrisMessageContent; -import de.tum.in.www1.artemis.domain.iris.IrisMessageSender; import de.tum.in.www1.artemis.service.connectors.iris.dto.*; @Component @@ -71,28 +68,39 @@ public void reset() throws Exception { } /** - * Mocks response call for the pyris call + * Mocks a message response from the call to pyris */ public void mockMessageResponse(String responseMessage) throws JsonProcessingException { if (responseMessage == null) { mockServer.expect(ExpectedCount.once(), requestTo(messagesApiURL.toString())).andExpect(method(HttpMethod.POST)).andRespond(withSuccess()); return; } - var irisMessage = new IrisMessage(); - var irisMessageContent = new IrisMessageContent(); - irisMessageContent.setTextContent(responseMessage); - irisMessage.setContent(Collections.singletonList(irisMessageContent)); - irisMessage.setSender(IrisMessageSender.LLM); - irisMessage.setSentAt(ZonedDateTime.now()); - - var response = new IrisMessageResponseDTO(null, irisMessage); + + var content = Map.of("response", responseMessage); + mockResponse(content); + } + + /** + * Mocks an arbitrary response from the call to pyris + * @param content The content of the response + * @throws JsonProcessingException If the content cannot be serialized to JSON + */ + public void mockResponse(Map content) throws JsonProcessingException { + if (content == null) { + mockServer.expect(ExpectedCount.once(), requestTo(messagesApiURL.toString())).andExpect(method(HttpMethod.POST)).andRespond(withSuccess()); + return; + } + + var response = new IrisMessageResponseDTO("gpt-3.5-turbo", ZonedDateTime.now(), mapper.valueToTree(content)); + var json = mapper.writeValueAsString(response); - - mockServer.expect(ExpectedCount.once(), requestTo(messagesApiURL.toString())).andExpect(method(HttpMethod.POST)).andRespond(withSuccess(json, MediaType.APPLICATION_JSON)); + + mockCustomJsonResponse(json); } - public void mockCustomJsonResponse(String responseMessage) throws JsonProcessingException { - mockServer.expect(ExpectedCount.once(), requestTo(messagesApiURL.toString())).andExpect(method(HttpMethod.POST)) + public void mockCustomJsonResponse(String responseMessage) { + mockServer.expect(ExpectedCount.once(), requestTo(messagesApiURL.toString())) + .andExpect(method(HttpMethod.POST)) .andRespond(withSuccess(responseMessage, MediaType.APPLICATION_JSON)); } diff --git a/src/test/java/de/tum/in/www1/artemis/iris/IrisConnectorServiceTest.java b/src/test/java/de/tum/in/www1/artemis/iris/IrisConnectorServiceTest.java index 45aca2fa24af..4f4f834c3330 100644 --- a/src/test/java/de/tum/in/www1/artemis/iris/IrisConnectorServiceTest.java +++ b/src/test/java/de/tum/in/www1/artemis/iris/IrisConnectorServiceTest.java @@ -46,9 +46,11 @@ void testException(int httpStatus, Class exceptionClass) throws Exception { void testParseException() throws Exception { var template = new IrisTemplate("Dummy"); - irisRequestMockProvider.mockCustomJsonResponse("{\"message\": \"invalid\"}"); + irisRequestMockProvider.mockCustomJsonResponse("{\"invalid\": \"invalid\"}"); irisConnectorService.sendRequest(template, "TEST_MODEL", Collections.emptyMap()).handle((response, throwable) -> { + assertThat(response).isNull(); + assertThat(throwable).isNotNull(); assertThat(throwable.getCause()).isNotNull().isInstanceOf(IrisParseResponseException.class); return null; }).get(); From b1d6f18e90ef39835126bff6ed6b8b40d0dcf106 Mon Sep 17 00:00:00 2001 From: monst Date: Fri, 13 Oct 2023 16:47:57 +0200 Subject: [PATCH 2/3] Restore v1 pyris communication behavior, add v2 alternative DTO --- .../connectors/iris/IrisConnectorService.java | 29 ++++++++++---- .../iris/dto/IrisMessageResponseDTO.java | 25 +----------- .../iris/dto/IrisMessageResponseV2DTO.java | 25 ++++++++++++ .../iris/session/IrisChatSessionService.java | 15 +------ .../session/IrisHestiaSessionService.java | 21 ++-------- .../connector/IrisRequestMockProvider.java | 40 ++++++++----------- .../iris/IrisConnectorServiceTest.java | 2 +- 7 files changed, 71 insertions(+), 86 deletions(-) create mode 100644 src/main/java/de/tum/in/www1/artemis/service/connectors/iris/dto/IrisMessageResponseV2DTO.java diff --git a/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/IrisConnectorService.java b/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/IrisConnectorService.java index c13551e658c4..da926a44485b 100644 --- a/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/IrisConnectorService.java +++ b/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/IrisConnectorService.java @@ -18,10 +18,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import de.tum.in.www1.artemis.domain.iris.IrisTemplate; -import de.tum.in.www1.artemis.service.connectors.iris.dto.IrisErrorResponseDTO; -import de.tum.in.www1.artemis.service.connectors.iris.dto.IrisMessageResponseDTO; -import de.tum.in.www1.artemis.service.connectors.iris.dto.IrisModelDTO; -import de.tum.in.www1.artemis.service.connectors.iris.dto.IrisRequestDTO; +import de.tum.in.www1.artemis.service.connectors.iris.dto.*; import de.tum.in.www1.artemis.service.iris.exception.*; /** @@ -57,7 +54,23 @@ public IrisConnectorService(@Qualifier("irisRestTemplate") RestTemplate restTemp @Async public CompletableFuture sendRequest(IrisTemplate template, String preferredModel, Map parameters) { var request = new IrisRequestDTO(template, preferredModel, parameters); - return sendRequest(request); + return sendRequest(request, "v1", IrisMessageResponseDTO.class); + } + + /** + * Requests a response from an LLM using the V2 API + * + * @param template The template that should be used with the respective parameters (e.g., for initial system message) + * @param preferredModel The LLM model to be used (e.g., GPT3.5-turbo). Note: The used model might not be the preferred model (e.g., if an error occurs or the preferredModel is + * not reachable) + * @param parameters A map of parameters to be included in the template through handlebars (if they are specified + * in the template) + * @return The message response to the request which includes the {@link de.tum.in.www1.artemis.domain.iris.IrisMessage} and the used IrisModel + */ + @Async + public CompletableFuture sendRequestV2(IrisTemplate template, String preferredModel, Map parameters) { + var request = new IrisRequestDTO(template, preferredModel, parameters); + return sendRequest(request, "v2", IrisMessageResponseV2DTO.class); } /** @@ -78,14 +91,14 @@ public List getOfferedModels() throws IrisConnectorException { } } - private CompletableFuture sendRequest(IrisRequestDTO request) { + private CompletableFuture sendRequest(IrisRequestDTO request, String version, Class responseType) { try { try { - var response = restTemplate.postForEntity(irisUrl + "/api/v1/messages", objectMapper.valueToTree(request), JsonNode.class); + var response = restTemplate.postForEntity(irisUrl + "/api/" + version + "/messages", objectMapper.valueToTree(request), JsonNode.class); if (!response.hasBody()) { return CompletableFuture.failedFuture(new IrisNoResponseException()); } - return CompletableFuture.completedFuture(parseResponse(response.getBody(), IrisMessageResponseDTO.class)); + return CompletableFuture.completedFuture(parseResponse(response.getBody(), responseType)); } catch (HttpStatusCodeException e) { switch (e.getStatusCode()) { diff --git a/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/dto/IrisMessageResponseDTO.java b/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/dto/IrisMessageResponseDTO.java index 979134d653a0..cac9b0849bad 100644 --- a/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/dto/IrisMessageResponseDTO.java +++ b/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/dto/IrisMessageResponseDTO.java @@ -1,27 +1,6 @@ package de.tum.in.www1.artemis.service.connectors.iris.dto; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.JsonNode; +import de.tum.in.www1.artemis.domain.iris.IrisMessage; -import java.time.ZonedDateTime; - -public record IrisMessageResponseDTO(String usedModel, ZonedDateTime sentAt, JsonNode content) { - - /** - * Create a new IrisMessageResponseDTO. Jackson uses this constructor to create the object. - * This is necessary because Jackson was not throwing an exception when the response from Iris did not contain - * the expected fields, which resulted in a NullPointerException when trying to access the fields. - * Not sure if this is a bug in Jackson or if it is intended behavior, either way this is a workaround. - */ - @JsonCreator - public IrisMessageResponseDTO( - @JsonProperty(value = "used_model", required = true) String usedModel, - @JsonProperty(value = "sent_at", required = true) ZonedDateTime sentAt, - @JsonProperty(value = "content", required = true) JsonNode content) { - this.usedModel = usedModel; - this.sentAt = sentAt; - this.content = content; - } - +public record IrisMessageResponseDTO(String usedModel, IrisMessage message) { } diff --git a/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/dto/IrisMessageResponseV2DTO.java b/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/dto/IrisMessageResponseV2DTO.java new file mode 100644 index 000000000000..462fe713417c --- /dev/null +++ b/src/main/java/de/tum/in/www1/artemis/service/connectors/iris/dto/IrisMessageResponseV2DTO.java @@ -0,0 +1,25 @@ +package de.tum.in.www1.artemis.service.connectors.iris.dto; + +import java.time.ZonedDateTime; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; + +public record IrisMessageResponseV2DTO(String usedModel, ZonedDateTime sentAt, JsonNode content) { + + /** + * Create a new IrisMessageResponseDTO. Jackson uses this constructor to create the object. + * This is necessary because Jackson was not throwing an exception when the response from Iris did not contain + * the expected fields, which resulted in a NullPointerException when trying to access the fields. + * Not sure if this is a bug in Jackson or if it is intended behavior, either way this is a workaround. + */ + @JsonCreator + public IrisMessageResponseV2DTO(@JsonProperty(value = "used_model", required = true) String usedModel, @JsonProperty(value = "sent_at", required = true) ZonedDateTime sentAt, + @JsonProperty(value = "content", required = true) JsonNode content) { + this.usedModel = usedModel; + this.sentAt = sentAt; + this.content = content; + } + +} diff --git a/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisChatSessionService.java b/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisChatSessionService.java index 218c7ff8be27..c6aa7251c23c 100644 --- a/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisChatSessionService.java +++ b/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisChatSessionService.java @@ -6,9 +6,6 @@ import javax.ws.rs.BadRequestException; -import de.tum.in.www1.artemis.domain.iris.IrisMessage; -import de.tum.in.www1.artemis.domain.iris.IrisMessageContent; -import de.tum.in.www1.artemis.service.connectors.iris.dto.IrisMessageResponseDTO; import org.eclipse.jgit.api.Git; import org.eclipse.jgit.api.errors.GitAPIException; import org.eclipse.jgit.treewalk.FileTreeIterator; @@ -156,7 +153,7 @@ public void requestAndHandleResponse(IrisSession session) { irisWebsocketService.sendException(fullSession, throwable.getCause()); } else if (response != null) { - var irisMessageSaved = irisMessageService.saveMessage(toIrisMessage(response), fullSession, IrisMessageSender.LLM); + var irisMessageSaved = irisMessageService.saveMessage(response.message(), fullSession, IrisMessageSender.LLM); irisWebsocketService.sendMessage(irisMessageSaved); } else { @@ -166,16 +163,6 @@ else if (response != null) { return null; }); } - - private static IrisMessage toIrisMessage(IrisMessageResponseDTO dto) { - var message = new IrisMessage(); - message.setSentAt(dto.sentAt()); - message.setSender(IrisMessageSender.LLM); - var irisMessageContent = new IrisMessageContent(); - irisMessageContent.setTextContent(dto.content().get("response").asText()); - message.setContent(List.of(irisMessageContent)); - return message; - } private void addDiffAndTemplatesForStudentAndExerciseIfPossible(User student, ProgrammingExercise exercise, Map parameters) { parameters.put("gitDiff", ""); diff --git a/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisHestiaSessionService.java b/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisHestiaSessionService.java index e6ed6d15614b..1ea8cc8ac9c3 100644 --- a/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisHestiaSessionService.java +++ b/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisHestiaSessionService.java @@ -8,7 +8,6 @@ import javax.ws.rs.BadRequestException; -import de.tum.in.www1.artemis.service.connectors.iris.dto.IrisMessageResponseDTO; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.context.annotation.Profile; @@ -92,16 +91,14 @@ public CodeHint generateDescription(CodeHint codeHint) { try { var response1 = irisConnectorService .sendRequest(irisSettings.getIrisHestiaSettings().getTemplate(), irisSettings.getIrisHestiaSettings().getPreferredModel(), parameters).get(); - var irisMessage1 = toIrisMessage(response1); - irisMessageService.saveMessage(irisMessage1, irisSession, IrisMessageSender.LLM); + irisMessageService.saveMessage(response1.message(), irisSession, IrisMessageSender.LLM); irisSession = (IrisHestiaSession) irisSessionRepository.findByIdWithMessagesAndContents(irisSession.getId()); var response2 = irisConnectorService .sendRequest(irisSettings.getIrisHestiaSettings().getTemplate(), irisSettings.getIrisHestiaSettings().getPreferredModel(), parameters).get(); - var irisMessage2 = toIrisMessage(response2); - irisMessageService.saveMessage(irisMessage2, irisSession, IrisMessageSender.LLM); + irisMessageService.saveMessage(response2.message(), irisSession, IrisMessageSender.LLM); - codeHint.setContent(irisMessage1.getContent().stream().map(IrisMessageContent::getTextContent).collect(Collectors.joining("\n"))); - codeHint.setDescription(irisMessage2.getContent().stream().map(IrisMessageContent::getTextContent).collect(Collectors.joining("\n"))); + codeHint.setContent(response1.message().getContent().stream().map(IrisMessageContent::getTextContent).collect(Collectors.joining("\n"))); + codeHint.setDescription(response2.message().getContent().stream().map(IrisMessageContent::getTextContent).collect(Collectors.joining("\n"))); return codeHint; } catch (InterruptedException | ExecutionException e) { @@ -109,16 +106,6 @@ public CodeHint generateDescription(CodeHint codeHint) { throw new InternalServerErrorException("Unable to generate description: " + e.getMessage()); } } - - private static IrisMessage toIrisMessage(IrisMessageResponseDTO dto) { - var message = new IrisMessage(); - message.setSentAt(dto.sentAt()); - message.setSender(IrisMessageSender.LLM); - var irisMessageContent = new IrisMessageContent(); - irisMessageContent.setTextContent(dto.content().get("response").asText()); - message.setContent(List.of(irisMessageContent)); - return message; - } private IrisMessage generateSystemMessage() { var irisMessage = new IrisMessage(); diff --git a/src/test/java/de/tum/in/www1/artemis/connector/IrisRequestMockProvider.java b/src/test/java/de/tum/in/www1/artemis/connector/IrisRequestMockProvider.java index 615faa55b504..4946f5128316 100644 --- a/src/test/java/de/tum/in/www1/artemis/connector/IrisRequestMockProvider.java +++ b/src/test/java/de/tum/in/www1/artemis/connector/IrisRequestMockProvider.java @@ -6,9 +6,9 @@ import java.net.URL; import java.time.ZonedDateTime; +import java.util.Collections; import java.util.Map; -import org.codehaus.jackson.node.ObjectNode; import org.mockito.MockitoAnnotations; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; @@ -24,6 +24,9 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import de.tum.in.www1.artemis.domain.iris.IrisMessage; +import de.tum.in.www1.artemis.domain.iris.IrisMessageContent; +import de.tum.in.www1.artemis.domain.iris.IrisMessageSender; import de.tum.in.www1.artemis.service.connectors.iris.dto.*; @Component @@ -75,32 +78,23 @@ public void mockMessageResponse(String responseMessage) throws JsonProcessingExc mockServer.expect(ExpectedCount.once(), requestTo(messagesApiURL.toString())).andExpect(method(HttpMethod.POST)).andRespond(withSuccess()); return; } - - var content = Map.of("response", responseMessage); - mockResponse(content); - } - - /** - * Mocks an arbitrary response from the call to pyris - * @param content The content of the response - * @throws JsonProcessingException If the content cannot be serialized to JSON - */ - public void mockResponse(Map content) throws JsonProcessingException { - if (content == null) { - mockServer.expect(ExpectedCount.once(), requestTo(messagesApiURL.toString())).andExpect(method(HttpMethod.POST)).andRespond(withSuccess()); - return; - } - - var response = new IrisMessageResponseDTO("gpt-3.5-turbo", ZonedDateTime.now(), mapper.valueToTree(content)); - + + var irisMessage = new IrisMessage(); + var irisMessageContent = new IrisMessageContent(); + irisMessageContent.setTextContent(responseMessage); + irisMessage.setContent(Collections.singletonList(irisMessageContent)); + irisMessage.setSender(IrisMessageSender.LLM); + irisMessage.setSentAt(ZonedDateTime.now()); + + var response = new IrisMessageResponseDTO(null, irisMessage); + var json = mapper.writeValueAsString(response); - - mockCustomJsonResponse(json); + + mockServer.expect(ExpectedCount.once(), requestTo(messagesApiURL.toString())).andExpect(method(HttpMethod.POST)).andRespond(withSuccess(json, MediaType.APPLICATION_JSON)); } public void mockCustomJsonResponse(String responseMessage) { - mockServer.expect(ExpectedCount.once(), requestTo(messagesApiURL.toString())) - .andExpect(method(HttpMethod.POST)) + mockServer.expect(ExpectedCount.once(), requestTo(messagesApiURL.toString())).andExpect(method(HttpMethod.POST)) .andRespond(withSuccess(responseMessage, MediaType.APPLICATION_JSON)); } diff --git a/src/test/java/de/tum/in/www1/artemis/iris/IrisConnectorServiceTest.java b/src/test/java/de/tum/in/www1/artemis/iris/IrisConnectorServiceTest.java index 4f4f834c3330..c0eed8a2ea88 100644 --- a/src/test/java/de/tum/in/www1/artemis/iris/IrisConnectorServiceTest.java +++ b/src/test/java/de/tum/in/www1/artemis/iris/IrisConnectorServiceTest.java @@ -46,7 +46,7 @@ void testException(int httpStatus, Class exceptionClass) throws Exception { void testParseException() throws Exception { var template = new IrisTemplate("Dummy"); - irisRequestMockProvider.mockCustomJsonResponse("{\"invalid\": \"invalid\"}"); + irisRequestMockProvider.mockCustomJsonResponse("{\"message\": \"invalid\"}"); irisConnectorService.sendRequest(template, "TEST_MODEL", Collections.emptyMap()).handle((response, throwable) -> { assertThat(response).isNull(); From c694734c28ef4000824364810809fefb5b1feaa0 Mon Sep 17 00:00:00 2001 From: monst Date: Fri, 13 Oct 2023 16:51:50 +0200 Subject: [PATCH 3/3] Revert variable names --- .../service/iris/session/IrisChatSessionService.java | 6 +++--- .../iris/session/IrisHestiaSessionService.java | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisChatSessionService.java b/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisChatSessionService.java index c6aa7251c23c..508e8c5b7afa 100644 --- a/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisChatSessionService.java +++ b/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisChatSessionService.java @@ -147,13 +147,13 @@ public void requestAndHandleResponse(IrisSession session) { var irisSettings = irisSettingsService.getCombinedIrisSettings(exercise, false); irisConnectorService.sendRequest(irisSettings.getIrisChatSettings().getTemplate(), irisSettings.getIrisChatSettings().getPreferredModel(), parameters) - .handleAsync((response, throwable) -> { + .handleAsync((irisMessage, throwable) -> { if (throwable != null) { log.error("Error while getting response from Iris model", throwable); irisWebsocketService.sendException(fullSession, throwable.getCause()); } - else if (response != null) { - var irisMessageSaved = irisMessageService.saveMessage(response.message(), fullSession, IrisMessageSender.LLM); + else if (irisMessage != null) { + var irisMessageSaved = irisMessageService.saveMessage(irisMessage.message(), fullSession, IrisMessageSender.LLM); irisWebsocketService.sendMessage(irisMessageSaved); } else { diff --git a/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisHestiaSessionService.java b/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisHestiaSessionService.java index 1ea8cc8ac9c3..5aeea216407b 100644 --- a/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisHestiaSessionService.java +++ b/src/main/java/de/tum/in/www1/artemis/service/iris/session/IrisHestiaSessionService.java @@ -89,16 +89,16 @@ public CodeHint generateDescription(CodeHint codeHint) { Map parameters = Map.of("codeHint", irisSession.getCodeHint()); var irisSettings = irisSettingsService.getCombinedIrisSettings(irisSession.getCodeHint().getExercise(), false); try { - var response1 = irisConnectorService + var irisMessage1 = irisConnectorService .sendRequest(irisSettings.getIrisHestiaSettings().getTemplate(), irisSettings.getIrisHestiaSettings().getPreferredModel(), parameters).get(); - irisMessageService.saveMessage(response1.message(), irisSession, IrisMessageSender.LLM); + irisMessageService.saveMessage(irisMessage1.message(), irisSession, IrisMessageSender.LLM); irisSession = (IrisHestiaSession) irisSessionRepository.findByIdWithMessagesAndContents(irisSession.getId()); - var response2 = irisConnectorService + var irisMessage2 = irisConnectorService .sendRequest(irisSettings.getIrisHestiaSettings().getTemplate(), irisSettings.getIrisHestiaSettings().getPreferredModel(), parameters).get(); - irisMessageService.saveMessage(response2.message(), irisSession, IrisMessageSender.LLM); + irisMessageService.saveMessage(irisMessage2.message(), irisSession, IrisMessageSender.LLM); - codeHint.setContent(response1.message().getContent().stream().map(IrisMessageContent::getTextContent).collect(Collectors.joining("\n"))); - codeHint.setDescription(response2.message().getContent().stream().map(IrisMessageContent::getTextContent).collect(Collectors.joining("\n"))); + codeHint.setContent(irisMessage1.message().getContent().stream().map(IrisMessageContent::getTextContent).collect(Collectors.joining("\n"))); + codeHint.setDescription(irisMessage2.message().getContent().stream().map(IrisMessageContent::getTextContent).collect(Collectors.joining("\n"))); return codeHint; } catch (InterruptedException | ExecutionException e) {