-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
245 lines (201 loc) · 6.7 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
from typing import List, Dict
from hashlib import md5
import random
import os
import pathlib
import torch
import matplotlib.pyplot as plt
from flask import request, flash
from flask_login import current_user
from sqlalchemy import desc, func
from werkzeug.utils import secure_filename
from aflat.nn import Generator
from aflat.main import db
from aflat.models import User, Comment, Post, PostScore
# Locate the ``aflat`` package directory from the working directory the app was
# launched from; static files and the GAN checkpoint are resolved beneath it.
# NOTE(review): this is a plain substring test on the whole cwd, so any
# ancestor directory whose name merely contains "aflat" also satisfies it and
# skips the join; resolving relative to ``__file__`` would be more robust.
absolute_path = str(pathlib.Path.cwd())
model_path = "./gan/generator.pth"
print(absolute_path)  # startup diagnostic: shows which base directory is used
if "aflat" not in absolute_path:
    # Presumably launched from the repository root rather than from inside
    # the package directory, so descend into it.
    absolute_path = os.path.join(absolute_path, "aflat")
    model_path = os.path.join(absolute_path, "gan/generator.pth")

# Build the generator once at import time and load its weights onto the CPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
gen = Generator(512, 512)
_checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
gen.load_state_dict(_checkpoint["state_dict"])
del _checkpoint  # don't keep the loaded checkpoint referenced at module level
# Titles picked at random for GAN-generated posts (used by publish_painting).
TITLES = [
    "Painting in case of a zombie apocalypse",
    "Paintings technique that changed my life forever",
    "Paintings frying around",
    "Painting you’ll encounter during your next trip",
    "Painting from a dog’s perspective",
    "Never trust paintings",
    "Most footballers have paintings for breakfast",
    "Painting that will make you a painting",
    "Paintings’s adventure",
    "Demystifying paintings",
    "Keep calm and think about paintings",
    "Painting fail",
    "The prehistoric painting",
    "Zombie painting is better than sleeping",
]
# Image file extensions accepted for uploads (lowercase, without the dot).
EXTENSIONS = ["png", "jpg", "jpeg", "gif", "avif"]
# Directory (with trailing slash) where generated and uploaded images are stored.
POSTS_PATH = f"{absolute_path}/static/generated/"
def comments_data(id_: int) -> List[Dict]:
    """Return the comments of post ``id_`` as JSON-ready dicts.

    Comments are returned in the reverse of the order the query yields them.
    Each dict has ``username``, ``date`` and ``content`` keys. A comment whose
    author can no longer be found is attributed to ``"[deleted]"`` (the same
    fallback ``users_comments_data`` uses) instead of raising AttributeError.

    :param id_: id of the post whose comments are wanted.
    """
    comments = Comment.query.filter_by(post_id=id_).all()
    usernames = {}  # user_id -> username, so each author is queried only once
    comments_json = []
    for com in reversed(comments):
        if com.user_id not in usernames:
            author = User.query.filter_by(id=com.user_id).first()
            usernames[com.user_id] = author.username if author else "[deleted]"
        comments_json.append(
            {
                "username": usernames[com.user_id],
                "date": com.date,
                "content": com.content,
            }
        )
    return comments_json
def users_comments_data():
    """Return every user and every comment as JSON-ready lists.

    :returns: a two-element list ``[users_json, comments_json]`` where
        ``users_json`` holds ``{"id", "username"}`` dicts and ``comments_json``
        holds ``{"date", "post_id", "username", "content"}`` dicts in the
        reverse of the order the comment query yields them. Comments whose
        author no longer exists get the username ``"[deleted]"``.
    """
    users = User.query.all()
    users_json = [{"id": user.id, "username": user.username} for user in users]
    # Resolve comment authors from the users already loaded instead of
    # issuing one extra query per comment.
    usernames = {user.id: user.username for user in users}
    comments_json = [
        {
            "date": comment.date,
            "post_id": comment.post_id,
            "username": usernames.get(comment.user_id, "[deleted]"),
            "content": comment.content,
        }
        for comment in reversed(Comment.query.all())
    ]
    return [users_json, comments_json]
def paintings_data(page=0, per_page=5) -> List[Dict]:
    """Return one page of posts, highest post id first.

    :param page: zero-based page index.
    :param per_page: number of posts per page (default 5).
    :returns: a list of dicts with ``id``, ``title``, ``filename`` and
        ``comments_num`` (the number of comments on that post).
    """
    start = page * per_page
    posts = Post.query.order_by(desc(Post.id))[start : start + per_page]
    return [
        {
            "id": post.id,
            "title": post.title,
            "filename": post.picture_filename,
            "comments_num": Comment.query.filter_by(post_id=post.id).count(),
        }
        for post in posts
    ]
def paintings_count_data():
    """Return the total number of posts stored in the database."""
    total = Post.query.count()
    return total
def painting_data(id_) -> Dict:
    """Return ``id``, ``title`` and ``filename`` of the post with primary key ``id_``.

    NOTE(review): ``Query.get`` yields None for an unknown id, so a missing
    post surfaces as AttributeError here rather than a 404.
    """
    found = Post.query.get(id_)
    return dict(id=found.id, title=found.title, filename=found.picture_filename)
def new_painting():
    """Generate one random painting with the GAN and save it as ``tmp.jpg``.

    The image is written to ``POSTS_PATH/tmp.jpg``; ``publish_painting`` later
    renames that file into a post.
    """
    noise = torch.randn(1, 512, 1, 1, dtype=torch.float32)
    with torch.no_grad():
        # NOTE(review): the meaning of the extra (1, 6) arguments is defined by
        # Generator.forward — document it there.
        # ``* 0.5 + 0.5`` presumably maps a [-1, 1] output into [0, 1]; clamp so
        # any overshoot cannot make plt.imsave reject the float RGB data.
        img = (gen.forward(noise, 1, 6) * 0.5 + 0.5).clamp(0.0, 1.0)
    # imsave does not create missing directories, so make sure it exists.
    generated_directory()
    # Assumes img[0] is (C, H, W); matplotlib expects channels last (H, W, C).
    plt.imsave(os.path.join(POSTS_PATH, "tmp.jpg"), img[0].permute(1, 2, 0).numpy())
def publish_painting():
    """Publish the last generated ``tmp.jpg`` as a new post by the current user.

    The image in ``POSTS_PATH`` is renamed to ``<next_id>.jpg``, where
    ``next_id`` is the highest existing post id plus one (``1`` when there are
    no posts). A post with a random title from ``TITLES`` is then committed.
    """
    last_post = Post.query.order_by(Post.id.desc()).first()
    user = User.query.filter_by(username=current_user.username).first()
    # NOTE(review): "highest id + 1" is only used to name the file; it need not
    # equal the id the database assigns (deleted rows, concurrent publishes).
    next_id = last_post.id + 1 if last_post else 1
    os.rename(
        os.path.join(POSTS_PATH, "tmp.jpg"),
        os.path.join(POSTS_PATH, f"{next_id}.jpg"),
    )
    new_post = Post(
        title=random.choice(TITLES),
        # NOTE(review): stored with a leading slash, while publish_post stores
        # "generated/<name>" without one — confirm which form templates expect.
        picture_filename=f"/generated/{next_id}.jpg",
        user_post=user,
    )
    db.session.add(new_post)
    db.session.commit()
def check_extension(fn, extensions=None):
    """Return True if file name ``fn`` has an allowed image extension.

    The extension is the text after the last ``.`` and is compared
    case-insensitively, so ``photo.PNG`` is accepted. A name containing no
    ``.`` has no extension and is rejected (e.g. a file literally named
    ``png``).

    :param fn: file name as supplied by the client.
    :param extensions: allowed lowercase extensions without the dot;
        defaults to the module-level ``EXTENSIONS``.
    """
    if extensions is None:
        extensions = EXTENSIONS
    _, dot, ext = fn.rpartition(".")
    if not dot:
        return False
    return ext.lower() in extensions
def new_filename(fn, file):
    """Build a content-addressed name: MD5 hex digest of ``file`` + ``fn``'s extension.

    The extension is whatever follows the last ``.`` in ``fn``; when ``fn`` has
    no ``.`` the whole name is used as the extension.

    :param fn: original (sanitised) file name.
    :param file: raw file contents as bytes.
    """
    suffix = fn.rpartition(".")[2]
    digest = md5(file).hexdigest()
    return f"{digest}.{suffix}"
def generated_directory(path=None):
    """Ensure the directory for generated/uploaded images exists.

    :param path: directory to ensure; defaults to ``POSTS_PATH``.
    :returns: True if the directory exists or was created, False if it could
        not be created (e.g. permissions, or a file already occupies the path).
    """
    if path is None:
        path = POSTS_PATH
    try:
        # makedirs also creates missing parents (e.g. ``static/``), and
        # exist_ok avoids a race between checking and creating.
        os.makedirs(path, exist_ok=True)
    except OSError:
        return False
    return True
def publish_post():
    """Create a post from the submitted ``title`` form field and ``image`` upload.

    The upload is stored in ``POSTS_PATH`` under the MD5 of its contents plus
    its original extension, and a post owned by the current user is committed.

    :returns: True on success; False after flashing an error message when the
        title or image is missing, the extension is not allowed, or the upload
        directory cannot be created.
    """
    title = request.form.get("title")
    # A missing field yields None; treat it and whitespace-only titles as empty.
    if not title or not title.strip():
        flash("No title!")
        return False
    if "image" not in request.files:
        flash("No image")
        return False
    file = request.files["image"]
    if not file.filename:
        flash("No image")
        return False
    if not check_extension(file.filename):
        flash("Wrong file!")
        return False
    if not generated_directory():
        flash("Hmmm error or smth....")
        return False
    filename = new_filename(secure_filename(file.filename), file.stream.read())
    file.stream.seek(0)  # rewind: hashing above consumed the stream
    file.save(os.path.join(POSTS_PATH, filename))
    user = User.query.filter_by(username=current_user.username).first()
    # NOTE(review): stored without a leading slash, while publish_painting
    # stores "/generated/<n>.jpg" — confirm which form templates expect.
    new_post = Post(
        title=title, picture_filename="generated/" + filename, user_post=user
    )
    db.session.add(new_post)
    db.session.commit()
    return True
def stonks_db(id_, username):
    """Toggle the current user's score ("stonk") on post ``id_``.

    Adds a PostScore row if the user has not scored the post yet, otherwise
    deletes the existing row, then commits.

    :param id_: id of the post being scored.
    :param username: ignored — the user is always taken from ``current_user``;
        kept so existing callers keep working.
    """
    user = User.query.filter_by(username=current_user.username).first()
    # Look the score up by the user_id column, exactly as get_stonk does.
    stonk = PostScore.query.filter_by(user_id=user.id, post_id=id_).first()
    if stonk:
        db.session.delete(stonk)
    else:
        db.session.add(PostScore(user_stonk=user, post_id=id_))
    db.session.commit()
def get_stonk(id_, username):
    """Return whether user ``username`` has scored post ``id_``.

    Returns False when no user has that username.
    """
    owner = User.query.filter_by(username=username).first()
    if owner:
        match = PostScore.query.filter_by(user_id=owner.id, post_id=id_).first()
        return bool(match)
    return False
def popular_data():
    """Return up to ten posts with the most scores, most-scored first.

    Each entry is a dict with ``id``, ``title`` and ``filename``. Score rows
    that refer to a post which no longer exists are skipped (instead of raising
    AttributeError), and the next most-scored post takes their place.
    """
    score_counts = (
        PostScore.query.with_entities(PostScore.post_id, func.count(PostScore.post_id))
        .group_by(PostScore.post_id)
        .all()
    )
    popular_json = []
    for post_id, _count in sorted(score_counts, key=lambda row: row[1], reverse=True):
        post = Post.query.filter_by(id=post_id).first()
        if post is None:
            continue
        popular_json.append(
            {"id": post_id, "title": post.title, "filename": post.picture_filename}
        )
        if len(popular_json) == 10:
            break
    return popular_json