from textwrap import dedent
from typing import Any

import streamlit as st
from langchain import PromptTemplate, ConversationChain
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationTokenBufferMemory
from langchain.output_parsers import PydanticOutputParser
from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.schema import LLMResult
from pydantic import BaseModel, Field
from redlines import Redlines

MODEL_TOKEN_LIMIT = 4000
MODEL_NAME = "gpt-4"


class StreamingStreamlitCallbackHandler(BaseCallbackHandler):
    """Callback handler for streaming. Only works with LLMs that support streaming."""

    def __init__(
        self,
        message_placeholder: st.delta_generator.DeltaGenerator,
        message_contents: str = "",
    ):
        """Initialize the callback handler.

        Parameters
        ----------
        message_placeholder : st.delta_generator.DeltaGenerator
            The placeholder that messages are streamed to. Typically an
            ``st.empty()`` object.
        message_contents : str, optional
            Text already shown in the placeholder; new tokens are appended
            after it.
        """
        self.message_placeholder = message_placeholder
        self.message_contents = message_contents

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        self.message_contents += token
        self.message_placeholder.markdown(self.message_contents + "▌")

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.message_placeholder.markdown(self.message_contents)
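
# A minimal usage sketch (not part of the original module): hook the handler
# up to a streaming chat model so tokens render live in a Streamlit app.
# Assumes a running Streamlit session; the prompt text is illustrative.
#
#     placeholder = st.empty()
#     llm = ChatOpenAI(
#         model_name=MODEL_NAME,
#         streaming=True,
#         callbacks=[StreamingStreamlitCallbackHandler(placeholder)],
#     )
#     llm.predict("Hallo, wie geht es dir?")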


def classify_text_level(prompt, message_placeholder) -> str:
    """Classify the prompt based on the Common European Framework of Reference.

    The prompt is assumed to be text in a foreign language that the user wants
    help with.
    """
    llm = ChatOpenAI(
        model_name=MODEL_NAME,
        temperature=0,
        streaming=True,
        callbacks=[StreamingStreamlitCallbackHandler(message_placeholder)],
    )
    shorter_template = """Classify the text based on the Common European Framework of Reference
for Languages (CEFR). Use a maximum of 50 words for your answer.

Text: {text}

Format the output as markdown like this:

## CEFR Level: <level>
<reason>
"""
    prompt_template_reason_level = ChatPromptTemplate(
        messages=[HumanMessagePromptTemplate.from_template(shorter_template)],
        input_variables=["text"],
    )
    chain_reason_level = LLMChain(
        llm=llm, prompt=prompt_template_reason_level, output_key="reason_level"
    )
    response = chain_reason_level({"text": prompt})

    # Add the CEFR explanation link to the bottom, streaming it character by
    # character so it matches the look of the streamed LLM output.
    cefr_text = (
        "\n\nSee [Common European Framework of Reference for Languages]"
        "(https://en.wikipedia.org/wiki/Common_European_Framework_of_Reference_for_Languages)"
        " for more information on language levels."
    )
    for letter in cefr_text:
        response["reason_level"] += letter
        message_placeholder.markdown(response["reason_level"] + "▌")
    message_placeholder.markdown(response["reason_level"])
    return response["reason_level"]
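
# Example call (sketch; assumes an active Streamlit session and an
# OPENAI_API_KEY in the environment):
#
#     level_md = classify_text_level("Ich habe einen Hund.", st.empty())
#     # level_md starts with "## CEFR Level: ..." followed by a short reason.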


def correct_text(prompt, message_placeholder, message_contents="") -> str:
    """Correct the text and explain each change, streaming the response."""
    llm = ChatOpenAI(
        model_name=MODEL_NAME,
        temperature=0,
        streaming=True,
        callbacks=[
            StreamingStreamlitCallbackHandler(
                message_placeholder, message_contents=message_contents
            )
        ],
    )
    correction_template = """The following is a friendly conversation between a human and an AI. The
AI is helping the human improve their foreign language writing skills. The human provides texts
written in a foreign language and the AI corrects the spelling and grammar of the texts
and provides detailed reasons for each correction.

The AI keeps in mind spelling, grammar, naturalness (how much it sounds like a native
speaker), correct capitalisation, correct placement of commas and other punctuation, and
anything else necessary for correct writing.

The AI only provides corrections for words/phrases that have changed. If the original
text is the same as the corrected text, then the AI does not provide a correction.

The AI knows that each sentence may contain multiple errors and provides corrections for
all errors in the sentence. It also knows that some sentences will not contain any errors
and does not provide corrections for those sentences.

The AI does not give answers like "changed X to Y because this is how it is done in German".
Instead, it explains the reason for the change, e.g. "changed X to Y because Z".

The AI only gives one explanation for each change. It does not repeat the same explanation
multiple times.

The AI counts all changes such as "changed X to Y" as one change. If there are multiple reasons
for the change, they are listed in the same bullet point. For example, "changed X to Y because
Z and W".

If the AI does not know the answer to a question, it truthfully says it does not know.

Current conversation:
{history}
Human: {input}
AI: Let's think step by step"""
    correction_prompt = PromptTemplate(
        input_variables=["history", "input"], template=correction_template
    )

    # Prime the conversation with few-shot examples of the expected
    # corrected-text/reasons format.
    memory = ConversationTokenBufferMemory(llm=llm, max_token_limit=MODEL_TOKEN_LIMIT)
    input_1 = "Hallo, ich heisse Adam. Ich habe 25 Jahre alt."
    output_1 = dedent(
        """
        Let's think step by step

        ## Corrected Text
        Ich heiße Adam. Ich bin 25 Jahre alt.

        ## Reasons
        1. Corrected spelling of 'heisse' to 'heiße' because 'ss' can be combined to form 'ß' in German.
        2. Corrected 'habe' to 'bin' because 'bin' is the correct verb to use when stating one's age in German."""
    )
    memory.save_context({"input": input_1}, {"output": output_1})

    input_2 = "Ich bin 25 Jahre alt"
    output_2 = dedent(
        """
        Let's think step by step

        ## Corrected Text
        Ich bin 25 Jahre alt.

        ## Reasons
        1. Added a full stop to the end of the sentence because it is a complete sentence."""
    )
    memory.save_context({"input": input_2}, {"output": output_2})

    input_3 = "Ich habe eine Katze. Sie ist schwarz und klein."
    output_3 = dedent(
        """
        Let's think step by step

        ## Corrected Text
        Ich habe eine Katze. Sie ist schwarz und klein.

        ## Reasons
        1. No corrections needed. The text is grammatically correct and natural."""
    )
    memory.save_context({"input": input_3}, {"output": output_3})

    input_4 = "Ich wohne auf England fuer 15 Jahren."
    output_4 = dedent(
        """
        Let's think step by step

        ## Corrected Text
        Ich wohne in England seit 15 Jahren.

        ## Reasons
        1. Corrected 'auf' to 'in' because 'in' is the correct preposition to use when talking about living in a country.
        2. Corrected 'fuer' to 'seit' because 'seit' is the correct preposition to use when talking about the duration of time.
        """
    )
    memory.save_context({"input": input_4}, {"output": output_4})

    conversation = ConversationChain(
        llm=llm,
        memory=memory,
        verbose=False,
        prompt=correction_prompt,
    )
    response = conversation.predict(input=prompt)
    return response
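
# Illustrative call (sketch; assumes an active Streamlit session and an
# OPENAI_API_KEY in the environment; the German sentence is made up):
#
#     placeholder = st.empty()
#     markdown_answer = correct_text("Ich gehe zu Schule.", placeholder)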


def parse_corrections(correction_and_reasons):
    """Extract the corrections/reasons from the input and store them in a Pydantic object."""
    llm = ChatOpenAI(
        model_name=MODEL_NAME,
        temperature=0,
    )
    template = """Extract the corrections and reasons for them from the text.

Text: ####{text}####

{format_instructions}
"""

    class Output(BaseModel):
        corrected_text: str = Field(description="The corrected text (no heading)")
        reasons: list[str] = Field(description="The list of reasons.")

    parser = PydanticOutputParser(pydantic_object=Output)
    prompt_template = ChatPromptTemplate(
        messages=[HumanMessagePromptTemplate.from_template(template)],
        input_variables=["text"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )
    chain = LLMChain(llm=llm, prompt=prompt_template, output_key="output")
    output = chain({"text": correction_and_reasons})
    results = parser.parse(output["output"])
    return results
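
# Sketch of the structured result (field values are illustrative, not real
# model output):
#
#     result = parse_corrections(markdown_answer)
#     result.corrected_text  # e.g. "Ich gehe zur Schule."
#     result.reasons         # e.g. ["Corrected 'zu' to 'zur' because ..."]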


def main(prompt):
    """Classify, correct and explain the text."""
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)

    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        # Classify
        text_class = classify_text_level(prompt, message_placeholder)
        # Correct + parse
        text_correct = correct_text(
            prompt, message_placeholder, message_contents=text_class + "\n\n"
        )
        text_correct = parse_corrections(text_correct)
        # Compare with input and create nice redline formatting of changes
        comparison = Redlines(prompt, text_correct.corrected_text)
        comparison = comparison.output_markdown
        # Combine all results into one string and display
        final_response = f"{text_class}\n\n"
        final_response += "## Corrected Text\n\n"
        final_response += f"{comparison}\n\n"
        final_response += "## Reasons\n\n"
        for reason in text_correct.reasons:
            final_response += f"1. {reason}\n"
        message_placeholder.markdown(final_response, unsafe_allow_html=True)
    st.session_state.messages.append({"role": "assistant", "content": final_response})
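
# How an app entry point might drive ``main`` (a sketch; the real app script is
# not part of this module, so the wiring below is an assumption). ``main``
# expects ``st.session_state.messages`` to already exist, so it is initialised
# first; ``unsafe_allow_html=True`` is needed because Redlines output contains
# HTML tags.
#
#     st.title("Writing coach")
#     if "messages" not in st.session_state:
#         st.session_state.messages = []
#     for message in st.session_state.messages:
#         with st.chat_message(message["role"]):
#             st.markdown(message["content"], unsafe_allow_html=True)
#     if prompt := st.chat_input("Enter text in a foreign language"):
#         main(prompt)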