-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatabase.py
223 lines (191 loc) · 6.87 KB
/
database.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import random
from enum import IntEnum
import logging
# Module logger; Database.set_parameter logs every parameter write at DEBUG.
logger = logging.getLogger(__name__)
# Name of the table holding logged chat messages. It is interpolated with
# str.format into the SQL of Database.initialize and Database.add_message, so
# it must remain a fixed, trusted identifier (never derived from user input).
MESSAGE_TABLE = "messages2"
class FileType(IntEnum):
    """Kind of file attached to a stored message.

    The integer values are written to the ``filetype`` column of the message
    table (Database.add_message currently writes only ``Sticker``), so the
    numbers must stay stable: renumbering would change the meaning of rows
    that are already stored.
    """
    Audio = 1
    Document = 2
    Photo = 3
    Sticker = 4
    Video = 5
    Voice = 6
class BoundParameter:
    """Handle to a single named parameter stored in a database object.

    Reads go through ``database.get_parameter`` (so they benefit from the
    database's cache) and are optionally converted by ``cast_fn``. Writes go
    through ``database.set_parameter`` and store the value exactly as given
    (``cast_fn`` is not applied on write).
    """

    def __init__(self, database, name, default=None, cast_fn=None):
        """Bind parameter *name* of *database*.

        Args:
            database: object providing ``get_parameter(name, default)`` and
                ``set_parameter(name, value)``.
            name: the parameter key.
            default: value used when the parameter does not exist yet.
            cast_fn: optional callable applied to the raw stored value on
                every :meth:`get`; ``None`` returns the raw value.
        """
        self.database = database
        self.name = name
        self.default = default
        self.cast_fn = cast_fn
        # Prefetch the parameter at bind time. With Database.get_parameter
        # this also persists ``default`` when the key is missing, so the
        # parameter exists from the moment it is bound.
        database.get_parameter(name, default)

    def get(self):
        """Return the current value, passed through ``cast_fn`` if set."""
        value = self.database.get_parameter(self.name, self.default)
        if self.cast_fn is not None:
            value = self.cast_fn(value)
        return value

    def set(self, value):
        """Store *value* (uncast) as the new parameter value."""
        self.database.set_parameter(self.name, value)

    def __repr__(self):
        # Represent the handle by its current (cast) value. The previous
        # implementation called a non-existent ``self.__get__`` and raised
        # AttributeError.
        return repr(self.get())

    def __str__(self):
        return str(self.get())
class Database:
    """Storage layer backed by an SQLite-style connection.

    Manages five tables:

    * ``chains``: (source, response) pairs with an occurrence count.
    * the message log named by ``MESSAGE_TABLE``.
    * ``params``: key/value settings, mirrored in an in-memory cache.
    * ``chat_aliases``.
    * ``chat_states``.

    No method commits on its own; call :meth:`commit` to persist changes.
    """

    def __init__(self, conn):
        """Wrap *conn*.

        Args:
            conn: connection exposing ``execute(sql, params)`` and
                ``commit()``. Presumably an ``sqlite3.Connection``, given
                the SQLite-specific SQL (``INSERT OR REPLACE``) used below.
        """
        self.conn = conn
        # key -> value mirror of the params table; filled lazily by
        # get_parameter and eagerly by set_parameter.
        self._param_cache = {}

    def initialize(self):
        """Create all tables and indexes if they do not already exist."""
        statements = (
            "CREATE TABLE IF NOT EXISTS chains "
            "(source text, response text, count int)",
            "CREATE UNIQUE INDEX IF NOT EXISTS i_chains ON chains "
            "(source, response)",
            """
            CREATE TABLE IF NOT EXISTS {} (
                "chat_id" INTEGER NOT NULL,
                "message_id" INTEGER NOT NULL,
                "user_id" INTEGER,
                "message" TEXT,
                "file_id" TEXT,
                "filetype" INTEGER,
                "reply_to" INTEGER,
                "sent" DATETIME,
                PRIMARY KEY (chat_id, message_id)
            )
            """.format(MESSAGE_TABLE),
            """
            CREATE TABLE IF NOT EXISTS params (
                key text NOT NULL,
                value
            )
            """,
            "CREATE UNIQUE INDEX IF NOT EXISTS i_params ON params "
            "(key)",
            """
            CREATE TABLE IF NOT EXISTS chat_aliases (
                name text NOT NULL,
                chat_id NOT NULL
            )
            """,
            # Both columns are unique, so an alias maps one name to one
            # chat and each chat has at most one alias.
            "CREATE UNIQUE INDEX IF NOT EXISTS i_name "
            "ON chat_aliases (name)",
            "CREATE UNIQUE INDEX IF NOT EXISTS i_chat_id "
            "ON chat_aliases (chat_id)",
            """
            CREATE TABLE IF NOT EXISTS chat_states (
                chat_id int NOT NULL,
                messages_since_reply int NOT NULL,
                stickers_since_reply int NOT NULL,
                chain_length int NOT NULL
            )
            """,
            "CREATE UNIQUE INDEX IF NOT EXISTS i_chat_states "
            "ON chat_states (chat_id)",
        )
        for statement in statements:
            self.conn.execute(statement)

    def commit(self):
        """Commit the connection's current transaction."""
        return self.conn.commit()

    # -- params -----------------------------------------------------------

    def set_parameter(self, key, value):
        """Store *value* under *key*, replacing any existing row.

        The cache is updated only after the INSERT succeeds. Previously the
        cache was written first, so a failing INSERT (for example before
        :meth:`initialize` ran) left a value in the cache that the database
        never received.
        """
        logger.debug('setting %s = %s', key, value)
        query = "INSERT OR REPLACE INTO params VALUES (?, ?)"
        self.conn.execute(query, (key, value))
        self._param_cache[key] = value

    def get_parameter(self, key, default=None):
        """Return the value of parameter *key*.

        Served from the cache when possible; otherwise read from the params
        table and cached. If no row exists, *default* is stored via
        :meth:`set_parameter` and returned, so later reads see it too.
        """
        try:
            return self._param_cache[key]
        except KeyError:
            pass
        query = "SELECT value FROM params WHERE key = ?"
        row = self.conn.execute(query, (key,)).fetchone()
        if not row:
            self.set_parameter(key, default)
            return default
        self._param_cache[key] = row[0]
        return row[0]

    def get_parameters(self):
        """Return a cursor over all ``(key, value)`` rows of params.

        Rows are read directly from the table; the cache is not consulted.
        """
        query = "SELECT key, value FROM params"
        return self.conn.execute(query)

    def bound_parameter(self, key, default=None, cast_fn=None):
        """Return a BoundParameter handle for *key* on this database."""
        return BoundParameter(self, key, default, cast_fn)

    # -- chat aliases -------------------------------------------------------

    def set_chat_alias(self, name, value):
        """Map alias *name* to chat id *value*.

        Because both ``name`` and ``chat_id`` are uniquely indexed, REPLACE
        also removes any other alias that pointed at the same chat.
        """
        query = "INSERT OR REPLACE INTO chat_aliases VALUES (?, ?)"
        self.conn.execute(query, (name, value))

    def delete_chat_alias(self, name):
        """Remove alias *name*; removing a missing alias is a no-op."""
        query = "DELETE FROM chat_aliases WHERE name = ?"
        self.conn.execute(query, (name,))

    def get_chat_alias(self, name):
        """Return the chat id for alias *name*, or None if it is unknown."""
        query = "SELECT chat_id FROM chat_aliases WHERE name = ?"
        row = self.conn.execute(query, (name,)).fetchone()
        return row[0] if row else None

    def get_all_chat_aliases(self):
        """Return a list of all ``(name, chat_id)`` rows."""
        query = "SELECT name, chat_id FROM chat_aliases"
        return self.conn.execute(query).fetchall()

    # -- chat state ---------------------------------------------------------

    def get_chat_state(self, chat_id):
        """Return the state row for *chat_id*, or None if none is stored.

        The row is ``(messages_since_reply, stickers_since_reply,
        chain_length)``.
        """
        query = """
            SELECT
                messages_since_reply,
                stickers_since_reply,
                chain_length
            FROM chat_states WHERE
                chat_id = ?
        """
        return self.conn.execute(query, (chat_id, )).fetchone()

    def set_chat_state(self, chat_id,
                       messages_since_reply,
                       stickers_since_reply,
                       chain_length):
        """Insert or overwrite the state row for *chat_id*."""
        query = """
            INSERT OR REPLACE INTO chat_states
                (chat_id,
                 messages_since_reply,
                 stickers_since_reply,
                 chain_length)
            VALUES (?,?,?,?)
        """
        values = (chat_id,
                  messages_since_reply,
                  stickers_since_reply,
                  chain_length)
        self.conn.execute(query, values)

    # -- messages and chains ------------------------------------------------

    def add_message(self, message):
        """Insert one row into the message table for *message*.

        Only attributes that are present (truthy) on *message* become
        columns. Attributes read: ``chat.id``, ``message_id``, ``date``,
        ``from_user``, ``sticker``, ``reply_to_message`` and ``text``. The
        object shape is presumably a chat-API message; verify against
        callers. Raises the connection's integrity error if the
        (chat_id, message_id) pair already exists.
        """
        kwargs = {
            'chat_id': message.chat.id,
            'message_id': message.message_id,
            # Stored as a POSIX timestamp (float seconds).
            'sent': message.date.timestamp()
        }
        if message.from_user:
            kwargs['user_id'] = message.from_user.id
        if message.sticker:
            kwargs['file_id'] = message.sticker.file_id
            kwargs['filetype'] = FileType.Sticker
        if message.reply_to_message:
            kwargs['reply_to'] = message.reply_to_message.message_id
        if message.text:
            kwargs['message'] = message.text
        # Column names come only from the fixed keys above, so formatting
        # them into the SQL is safe; the values are bound as named params.
        columns = list(kwargs)
        query = "INSERT INTO {} ({}) VALUES ({})".format(
            MESSAGE_TABLE,
            ', '.join(columns),
            ', '.join(':' + column for column in columns))
        self.conn.execute(query, kwargs)

    def add_link(self, source, response):
        """Increment the count of the (source, response) pair.

        A missing pair is inserted with count 1. REPLACE works through the
        unique index ``i_chains``; the VALUES list relies on the column
        order (source, response, count).
        """
        query = """
            INSERT OR REPLACE INTO chains
            VALUES (:source, :response,
                COALESCE(
                    (SELECT count FROM chains
                     WHERE source=:source AND response=:response),
                    0) + 1);
        """
        self.conn.execute(query, {'source': source, 'response': response})

    def get_response_rows(self, source):
        """Return a cursor over ``(response, count)`` rows for *source*."""
        query = "SELECT response, count FROM chains WHERE source=?"
        return self.conn.execute(query, (source, ))