Skip to content

Commit

Permalink
Simple AI mentor chat logic (#133)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Felix T.J. Dietrich <[email protected]>
  • Loading branch information
3 people authored Nov 10, 2024
1 parent 0fb30c7 commit ee0d907
Show file tree
Hide file tree
Showing 13 changed files with 405 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/generate-application-server-client.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,4 @@ jobs:
run: |
echo "Removing the autocommit-openapi label..."
curl --silent --fail-with-body -X DELETE -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/labels/autocommit-openapi
https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/labels/autocommit-openapi
4 changes: 2 additions & 2 deletions .github/workflows/generate-intelligence-service-client.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ jobs:

- name: Install Python dependencies
working-directory: server/intelligence-service
run: poetry install --no-interaction --no-root
run: poetry lock --no-update && poetry install --no-interaction --no-root

- name: Generate API client for the application server
working-directory: server/intelligence-service
Expand Down Expand Up @@ -97,4 +97,4 @@ jobs:
run: |
echo "Removing the autocommit-openapi label..."
curl --silent --fail-with-body -X DELETE -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/labels/autocommit-openapi
https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/labels/autocommit-openapi
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
name: Lint
name: Intelligence Service QA

on: [pull_request]
on:
pull_request:
paths:
- "server/intelligence-service/**"
push:
paths:
- "server/intelligence-service/**"
branches: [develop]

jobs:
lint:
name: Code Quality Checks
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand Down
6 changes: 4 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
"generate:api:application-server": "npm run generate:api:application-server-specs && npm run generate:api:application-server:clean && npm run generate:api:application-server-client",
"generate:api:intelligence-service": "npm run generate:api:intelligence-service:clean && npm run generate:api:intelligence-service-specs && npm run generate:api:intelligence-service-client",
"generate:api": "npm run generate:api:intelligence-service && npm run generate:api:application-server",
"prettier:java:check": "prettier --check server/application-server/src/**/*.java",
"prettier:java:write": "prettier --write server/application-server/src/**/*.java",
"format:java:check": "prettier --check server/application-server/src/**/*.java",
"format:java:write": "prettier --write server/application-server/src/**/*.java",
"format:python:check": "cd server/intelligence-service/ && poetry run black --check .",
"format:python:write": "cd server/intelligence-service/ && poetry run black .",
"db:changelog:diff": "cd server/application-server && docker compose up -d postgres && mvn liquibase:diff && docker compose down postgres"
},
"devDependencies": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public DefaultApi(ApiClient apiClient) {
}

/**
* Get a response from an LLM to a chat message.
* Start and continue a chat session with an LLM.
*
* <p><b>200</b> - Successful Response
* <p><b>422</b> - Validation Error
Expand All @@ -55,7 +55,7 @@ public ChatResponse chatChatPost(ChatRequest chatRequest) throws RestClientExcep
}

/**
* Get a response from an LLM to a chat message.
* Start and continue a chat session with an LLM.
*
* <p><b>200</b> - Successful Response
* <p><b>422</b> - Validation Error
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.fasterxml.jackson.annotation.JsonValue;
import org.openapitools.jackson.nullable.JsonNullable;
import com.fasterxml.jackson.annotation.JsonIgnore;
import org.openapitools.jackson.nullable.JsonNullable;
import java.util.NoSuchElementException;
import com.fasterxml.jackson.annotation.JsonPropertyOrder;
import com.fasterxml.jackson.annotation.JsonTypeName;
import org.hibernate.validator.constraints.*;
Expand All @@ -28,13 +32,17 @@
* ChatRequest
*/
@JsonPropertyOrder({
ChatRequest.JSON_PROPERTY_MESSAGE
ChatRequest.JSON_PROPERTY_MESSAGE,
ChatRequest.JSON_PROPERTY_THREAD_ID
})
@jakarta.annotation.Generated(value = "org.openapitools.codegen.languages.JavaClientCodegen", comments = "Generator version: 7.7.0")
public class ChatRequest {
public static final String JSON_PROPERTY_MESSAGE = "message";
private String message;

public static final String JSON_PROPERTY_THREAD_ID = "thread_id";
private JsonNullable<String> threadId = JsonNullable.<String>undefined();

public ChatRequest() {
}

Expand Down Expand Up @@ -63,6 +71,39 @@ public void setMessage(String message) {
this.message = message;
}

public ChatRequest threadId(String threadId) {
this.threadId = JsonNullable.<String>of(threadId);

return this;
}

/**
* Get threadId
* @return threadId
*/
@jakarta.annotation.Nullable
@JsonIgnore

public String getThreadId() {
return threadId.orElse(null);
}

@JsonProperty(JSON_PROPERTY_THREAD_ID)
@JsonInclude(value = JsonInclude.Include.USE_DEFAULTS)

public JsonNullable<String> getThreadId_JsonNullable() {
return threadId;
}

@JsonProperty(JSON_PROPERTY_THREAD_ID)
public void setThreadId_JsonNullable(JsonNullable<String> threadId) {
this.threadId = threadId;
}

public void setThreadId(String threadId) {
this.threadId = JsonNullable.<String>of(threadId);
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand All @@ -72,19 +113,32 @@ public boolean equals(Object o) {
return false;
}
ChatRequest chatRequest = (ChatRequest) o;
return Objects.equals(this.message, chatRequest.message);
return Objects.equals(this.message, chatRequest.message) &&
equalsNullable(this.threadId, chatRequest.threadId);
}

private static <T> boolean equalsNullable(JsonNullable<T> a, JsonNullable<T> b) {
return a == b || (a != null && b != null && a.isPresent() && b.isPresent() && Objects.deepEquals(a.get(), b.get()));
}

@Override
public int hashCode() {
return Objects.hash(message);
return Objects.hash(message, hashCodeNullable(threadId));
}

private static <T> int hashCodeNullable(JsonNullable<T> a) {
if (a == null) {
return 1;
}
return a.isPresent() ? Arrays.deepHashCode(new Object[]{a.get()}) : 31;
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("class ChatRequest {\n");
sb.append(" message: ").append(toIndentedString(message)).append("\n");
sb.append(" threadId: ").append(toIndentedString(threadId)).append("\n");
sb.append("}");
return sb.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.fasterxml.jackson.annotation.JsonValue;
import org.openapitools.jackson.nullable.JsonNullable;
import com.fasterxml.jackson.annotation.JsonIgnore;
import org.openapitools.jackson.nullable.JsonNullable;
import java.util.NoSuchElementException;
import com.fasterxml.jackson.annotation.JsonPropertyOrder;
import com.fasterxml.jackson.annotation.JsonTypeName;
import org.hibernate.validator.constraints.*;
Expand All @@ -28,13 +32,17 @@
* ChatResponse
*/
@JsonPropertyOrder({
ChatResponse.JSON_PROPERTY_RESPONSE
ChatResponse.JSON_PROPERTY_RESPONSE,
ChatResponse.JSON_PROPERTY_THREAD_ID
})
@jakarta.annotation.Generated(value = "org.openapitools.codegen.languages.JavaClientCodegen", comments = "Generator version: 7.7.0")
public class ChatResponse {
public static final String JSON_PROPERTY_RESPONSE = "response";
private String response;

public static final String JSON_PROPERTY_THREAD_ID = "thread_id";
private JsonNullable<String> threadId = JsonNullable.<String>undefined();

public ChatResponse() {
}

Expand Down Expand Up @@ -63,6 +71,39 @@ public void setResponse(String response) {
this.response = response;
}

public ChatResponse threadId(String threadId) {
this.threadId = JsonNullable.<String>of(threadId);

return this;
}

/**
* Get threadId
* @return threadId
*/
@jakarta.annotation.Nullable
@JsonIgnore

public String getThreadId() {
return threadId.orElse(null);
}

@JsonProperty(JSON_PROPERTY_THREAD_ID)
@JsonInclude(value = JsonInclude.Include.USE_DEFAULTS)

public JsonNullable<String> getThreadId_JsonNullable() {
return threadId;
}

@JsonProperty(JSON_PROPERTY_THREAD_ID)
public void setThreadId_JsonNullable(JsonNullable<String> threadId) {
this.threadId = threadId;
}

public void setThreadId(String threadId) {
this.threadId = JsonNullable.<String>of(threadId);
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand All @@ -72,19 +113,32 @@ public boolean equals(Object o) {
return false;
}
ChatResponse chatResponse = (ChatResponse) o;
return Objects.equals(this.response, chatResponse.response);
return Objects.equals(this.response, chatResponse.response) &&
equalsNullable(this.threadId, chatResponse.threadId);
}

private static <T> boolean equalsNullable(JsonNullable<T> a, JsonNullable<T> b) {
return a == b || (a != null && b != null && a.isPresent() && b.isPresent() && Objects.deepEquals(a.get(), b.get()));
}

@Override
public int hashCode() {
return Objects.hash(response);
return Objects.hash(response, hashCodeNullable(threadId));
}

private static <T> int hashCodeNullable(JsonNullable<T> a) {
if (a == null) {
return 1;
}
return a.isPresent() ? Arrays.deepHashCode(new Object[]{a.get()}) : 31;
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("class ChatResponse {\n");
sb.append(" response: ").append(toIndentedString(response)).append("\n");
sb.append(" threadId: ").append(toIndentedString(threadId)).append("\n");
sb.append("}");
return sb.toString();
}
Expand Down
7 changes: 5 additions & 2 deletions server/intelligence-service/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ def is_openai_available(self):

@property
def is_azure_openai_available(self):
return bool(self.AZURE_OPENAI_API_KEY) and bool(self.AZURE_OPENAI_ENDPOINT) and bool(
self.AZURE_OPENAI_API_VERSION)
return (
bool(self.AZURE_OPENAI_API_KEY)
and bool(self.AZURE_OPENAI_ENDPOINT)
and bool(self.AZURE_OPENAI_API_VERSION)
)


settings = Settings()
40 changes: 32 additions & 8 deletions server/intelligence-service/app/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from .model import model

from .model import start_chat as start_chat_function, chat as chat_function
from typing import Dict, Optional

app = FastAPI(
title="Hephaestus Intelligence Service API",
Expand All @@ -10,19 +10,43 @@
contact={"name": "Felix T.J. Dietrich", "email": "[email protected]"},
)

# Global dictionary to store conversation states
conversations: Dict[str, dict] = {}


class ChatRequest(BaseModel):
message: str
thread_id: Optional[str] = None


class ChatResponse(BaseModel):
response: str
thread_id: Optional[str] = None


@app.post("/chat", response_model=ChatResponse, summary="Get a response from an LLM to a chat message.")
@app.post(
"/chat",
response_model=ChatResponse,
summary="Start and continue a chat session with an LLM.",
)
async def chat(request: ChatRequest):
try:
response = model.invoke(request.message)
return {"response": response.content}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if request.thread_id is None:
# Start a new chat session
result = start_chat_function(request.message)
thread_id = result["thread_id"]
state = result["state"]
response_message = result["response"]["messages"][-1].content
conversations[thread_id] = state
return ChatResponse(thread_id=thread_id, response=response_message)
else:
thread_id = request.thread_id
# Check if the thread_id exists
if thread_id not in conversations:
raise HTTPException(status_code=404, detail="Thread ID not found")
state = conversations[thread_id]
user_input = request.message
result = chat_function(thread_id, user_input, state)
state = result["state"]
response_message = result["response"]["messages"][-1].content
conversations[thread_id] = state
return ChatResponse(response=response_message)
Loading

0 comments on commit ee0d907

Please sign in to comment.