-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathengine.py
167 lines (133 loc) · 4.55 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import List, Optional, TypeVar, Generic
from core.messages import Message
from core.profile import profile_generation
from core.types import (
Schema,
TokenUsage,
GenerationOutput,
)
@dataclass
class EngineConfig:
    """Base configuration for an engine.

    Declares no fields of its own. It is the upper bound of the type
    variable used to parameterize engines, so concrete engine configs are
    expected to derive from it.
    """
# Type variable for an engine's configuration type; bound to EngineConfig so
# type checkers only accept EngineConfig or its subclasses.
T = TypeVar("T", bound=EngineConfig)
class Engine(ABC, Generic[T]):
    """Abstract interface that should be implemented by all engines.

    Engines are assumed to take a schema and generate a JSON object that
    matches the schema. Subclasses must implement `_generate` and
    `max_context_length`; the tokenizer helpers (`encode`, `decode`) and
    `adapt_schema` / `close` have no-op defaults that can be overridden.
    """

    # Declared for subclasses to set; the base class assigns no value.
    name: str

    def __init__(self, config: T):
        """Stores the configuration and starts an empty usage counter.

        :param config: EngineConfig
            Configuration for the engine. This config is passed to the
            engine constructor and is used to configure the engine.
        """
        self.config = config
        # Accumulates the token usage of every `generate` call.
        self.total_usage = TokenUsage()

    @profile_generation
    def generate(
        self,
        task: str,
        messages: List[Message],
        schema: Schema,
    ) -> GenerationOutput:
        """Generates a JSON object that matches the schema.

        This method is a wrapper around the `_generate` method: it adapts
        the schema, builds the output container, delegates to `_generate`
        and adds the resulting token usage to `total_usage`.

        :param task: str
            The task to generate the JSON object for.
        :param messages: List[Message]
            The messages to generate the JSON object for.
        :param schema: Schema
            The schema to generate the JSON object for.
        :return: GenerationOutput
            The generation output.
        """
        schema = self.adapt_schema(schema)
        output = GenerationOutput(
            task=task, messages=messages, generation="", schema=schema
        )
        # `_generate` fills `output` in place (generation, token usage, ...).
        self._generate(output)
        self.total_usage += output.token_usage
        return output

    @abstractmethod
    def _generate(
        self,
        output: GenerationOutput,
    ) -> None:
        """The method that should be implemented by all engines. It takes
        a generation output and modifies it in place.

        :param output: GenerationOutput
            The generation output.
        :return: None
            The generation output is modified in place.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def max_context_length(self) -> int:
        """The maximum context length of the engine.

        :return: int
            The maximum context length.
        """
        raise NotImplementedError

    def adapt_schema(self, schema: Schema) -> Schema:
        """Adapts the schema to the engine. This should be implemented if the
        engine needs to modify the schema in some way before generating.

        The base implementation returns the schema unchanged.

        :param schema: Schema
            The schema to adapt.
        :return: Schema
            The adapted schema.
        """
        return schema

    def encode(self, text: str) -> Optional[List[int]]:
        """Encodes a text string into a list of tokens.

        The base implementation returns None.

        :param text: str
            The text to encode.
        :return: Optional[List[int]]
            The encoded tokens.
        """
        return None

    def decode(self, ids: List[int]) -> Optional[str]:
        """Decodes a list of tokens into a text string.

        The base implementation returns None.

        :param ids: List[int]
            The tokens to decode.
        :return: Optional[str]
            The decoded text.
        """
        return None

    def convert_token_to_id(self, token: str) -> Optional[int]:
        """Converts a token to an id.

        :param token: str
            The token to convert.
        :return: Optional[int]
            The first id produced by `encode`, or None when `encode`
            returns None or an empty list.
        """
        res = self.encode(token)
        return res[0] if res else None

    def convert_id_to_token(self, id: int) -> Optional[str]:
        """Converts an id to a token.

        :param id: int
            The id to convert.
        :return: Optional[str]
            The full text decoded from the single id, or None when `decode`
            returns None or an empty string.
        """
        res = self.decode([id])
        # `decode` returns a string, so the whole string is the token;
        # indexing it would keep only the first character.
        return res if res else None

    def count_tokens(self, text: str) -> int:
        """Counts the number of tokens in a text string. This can be
        implemented by the engines if they don't provide a tokenizer.

        :param text: str
            The text to count the tokens in.
        :return: int
            The number of tokens in the text; 0 when `encode` returns None
            or an empty list.
        """
        res = self.encode(text)
        return len(res) if res else 0

    def close(self) -> None:
        """Closes the engine. This method can be implemented by engines that
        need to close the model or the sampler.

        The base implementation does nothing.

        :return: None
            The engine is closed.
        """
        pass