forked from Firesuiry/autoLOL
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathpicProcessor.py
224 lines (194 loc) · 6.46 KB
/
picProcessor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# -*- coding: utf-8 -*-
from paramsExtract.paramsExtracter import paramExtract
import numpy as np
import cv2
import matplotlib.pyplot as plt
from setting import *
from modelManger import model_manager
# Template images used by picProcessor.loading_complete / game_end to detect
# the game state via cv2.matchTemplate.
_RUNNING_TEMPLATE_PATH = PROJECT_ADDRESS + 'resource/GAME_STATE_CHECK_RUNNING.bmp'
_ENDING_TEMPLATE_PATH = PROJECT_ADDRESS + 'resource/GAME_STATE_CHECK_RUNNING2.png'
game_state_check_running_img = cv2.imread(_RUNNING_TEMPLATE_PATH)
game_state_check_ending_img = cv2.imread(_ENDING_TEMPLATE_PATH)
# cv2.imread does not raise on a missing or unreadable file -- it returns None,
# which would otherwise only surface later as an obscure error inside
# cv2.matchTemplate (or `None.shape`). Warn at import time, but do not make the
# import itself fail so that code not using the templates keeps working.
if game_state_check_running_img is None:
    print('WARNING: could not load template image {}'.format(_RUNNING_TEMPLATE_PATH))
if game_state_check_ending_img is None:
    print('WARNING: could not load template image {}'.format(_ENDING_TEMPLATE_PATH))
class picProcessor:
    """Turn 1280x720 BGR game screenshots into features for the agent.

    The instance keeps:

    * ``positionData`` -- fixed HUD screen regions as [x0, y0, x1, y1];
    * ``nodesPostions`` / ``bottomNodeKeys`` -- hand-picked full-screen pixel
      positions (all inside the ``'MAP'`` minimap region) of waypoints along
      the bottom lane, plus the derived ``bottimMiddlePoint`` and
      ``bottomNodeList`` lists;
    * ``currentPic`` -- the last screenshot processed, used to skip
      duplicate frames;
    * ``oldobs`` -- the previous single-frame observation, which
      ``obs_params_extract`` stacks after the current one.
    """

    # Length of one single-frame observation vector (flattened mat + HP +
    # encoder code). NOTE(review): hard-coded -- confirm it matches what
    # paramExtract and the encoder model actually produce.
    OBS_SIZE = 4689

    def __init__(self, test=False):
        self.test = test
        # Previous single-frame observation; all zeros until the first frame.
        self.oldobs = np.zeros((self.OBS_SIZE,))
        # Last screenshot handed to param_extract / obs_params_extract.
        self.currentPic = None
        # HUD regions in full-screen pixels, each as [x0, y0, x1, y1].
        self.positionData = {
            'HP': [454, 686, 730, 697],
            'MP': [454, 699, 730, 710],
            'MONEY': [798, 696, 860, 712],
            'MAP': [1100, 538, 1279, 719],
            'EXP': [410, 622, 448, 710],
        }
        # Node naming scheme -- a name is built from these parts in order
        # ("(end)" marks a part that terminates the name):
        #   1st: L = bottom-left side, R = top-right side
        #   2nd: Base = base (end), T = top lane, M = mid lane, B = bottom lane
        #   3rd: Node = summoning node (end), T = tower, R = river edge
        #   4th: 1 = high-ground tower (end), 2 = second tower (end),
        #        3 = outer tower (end), 0 = nexus ("incisor") tower
        #   5th: 0 = left nexus tower (end), 1 = right nexus tower (end)
        self.nodesPostions = {
            'LSpring': [1105, 707],
            'LBase': [1119, 694],
            'LBNode': [1142, 699],
            'LBT1': [1151, 700],
            'LBT2': [1182, 697],
            'LBT3': [1223, 700],
            'LBR': [1243, 692],
            'RBR': [1254, 680],
            'RBT3': [1262, 633],
            'RBT2': [1256, 619],
            'RBT1': [1260, 592],
            'RBNode': [1259, 583],
            'RT01': [1253, 568],
            'RT00': [1247, 563],
            'RBase': [1255, 560],
        }
        # Waypoints in order along the bottom lane, from L spring to R base.
        self.bottomNodeKeys = ['LSpring', 'LBase', 'LBNode', 'LBT1', 'LBT2', 'LBT3', 'LBR', 'RBR', 'RBT3', 'RBT2',
                               'RBT1', 'RBNode', 'RT01', 'RT00', 'RBase']

        lane_points = [np.array(self.nodesPostions[key]) for key in self.bottomNodeKeys]
        assert len(lane_points) != 0
        # Midpoint of every pair of consecutive waypoints (len(keys) - 1 items).
        # Pairing with zip replaces an old "(0, 0) means no previous point"
        # sentinel that would silently drop a waypoint located at the origin.
        self.bottimMiddlePoint = [0.5 * (prev + cur) for prev, cur in zip(lane_points, lane_points[1:])]

        # Waypoints densified by inserting the midpoint between neighbours:
        # [p0, mid01, p1, mid12, p2, ...]  (2 * len(keys) - 1 items).
        self.newBottomNodeList = []
        for node in lane_points:
            if self.newBottomNodeList:
                # [-1] is the previous waypoint at this point of the loop.
                self.newBottomNodeList.append((self.newBottomNodeList[-1] + node) * 0.5)
            self.newBottomNodeList.append(node)
        self.bottomNodeList = self.newBottomNodeList.copy()
        assert len(self.bottomNodeList) != 0

    def init_obs(self):
        """Reset the stored previous observation to all zeros."""
        self.oldobs = np.zeros((self.OBS_SIZE,))

    def element_extract(self, element_name, ori_pic):
        """Crop one HUD element (minimap, HP bar, MP bar, ...) from a screenshot.

        :param element_name: a key of ``self.positionData``
            ('HP', 'MP', 'MONEY', 'MAP' or 'EXP').
        :param ori_pic: full screenshot as a cv2/numpy image (rows, cols[, ch]).
        :return: the cropped region -- a numpy view into ``ori_pic``, not a copy.
        """
        target_area = self.positionData.get(element_name, None)
        assert target_area is not None, 'unknown element: {}'.format(element_name)
        x0, y0, x1, y1 = target_area
        # numpy images are indexed [row (y), column (x)].
        return ori_pic[y0:y1, x0:x1]

    @staticmethod
    def pic_display(pic, save_name='', pil=False, no_dis=False):
        """Show and/or save an image for debugging.

        Images whose maximum is <= 1 are treated as normalized and scaled by
        255 first (on a copy; the caller's array is not modified).

        :param pic: image array.
        :param save_name: if non-empty, also write ``ans/<save_name>.png``
            (and use it as the matplotlib figure name).
        :param pil: display with matplotlib instead of a cv2 window.
        :param no_dis: skip displaying entirely (only save, if requested).
        """
        if np.max(pic) <= 1:
            pic = pic * 255
        if not no_dis:
            if not pil:
                # Blocks until a key is pressed in the cv2 window.
                cv2.imshow('pic', pic)
                cv2.waitKey(0)
                cv2.destroyAllWindows()
            else:
                plt.figure(save_name)
                plt.imshow(pic)
                plt.show()
        if save_name != '':
            cv2.imwrite('ans/' + save_name + '.png', pic)

    def point_transform(self, point_in, map2all=False):
        """Shift a point between minimap-local and full-screen coordinates.

        :param point_in: indexable pair (list or numpy array); element 0 is
            shifted by the minimap's x0 and element 1 by its y0, i.e. the
            point is treated as [X, Y].
        :param map2all: True to convert minimap-local -> full-screen
            (offsets added); False for full-screen -> minimap-local
            (offsets subtracted).
        :return: a shifted copy of ``point_in`` in the same [X, Y] order,
            or -1 if ``point_in`` is None.
        """
        if point_in is None:
            return -1
        point_out = point_in.copy()
        map_x0, map_y0 = self.positionData['MAP'][:2]
        sign = 1 if map2all else -1
        point_out[0] += sign * map_x0
        point_out[1] += sign * map_y0
        return point_out

    def param_extract(self, img, **args):
        """Run ``paramExtract`` on a new screenshot, skipping exact duplicates.

        :param img: screenshot of shape (720, 1280, 3).
        :param args: keyword flags forwarded unchanged to ``paramExtract``.
        :return: whatever ``paramExtract`` returns, or None when ``img`` is
            pixel-identical to the previously processed screenshot.
        """
        assert (img.shape == (720, 1280, 3))
        # array_equal is simply False when nothing has been processed yet
        # (currentPic is None) or the shapes differ.
        same = np.array_equal(img, self.currentPic)
        print('获取图片 重复判断结果:{}'.format(same))
        if same:
            return
        self.currentPic = img
        return paramExtract(self, **args)

    def obs_params_extract(self, img, igone_same_check=False):
        """Build the stacked observation vector for one screenshot.

        :param img: screenshot of shape (720, 1280, 3).
        :param igone_same_check: if True, process ``img`` even when it is
            identical to the previously processed screenshot.
        :return: ``(params, obs)`` where ``params`` is the dict returned by
            ``paramExtract`` and ``obs`` is the current single-frame
            observation (flattened ``mat`` + ``HP`` + encoder code) followed
            by the previous one; None when the frame is a skipped duplicate.
        """
        assert (img.shape == (720, 1280, 3))
        same = np.array_equal(img, self.currentPic)
        if same and not igone_same_check:
            print('获取图片 判断为重复 退出:{}'.format(same))
            return
        self.currentPic = img
        params = paramExtract(self, position=False, money=False, exp=False, target=False, tower=False, img=img)
        frame_obs = np.hstack((params['mat'].reshape(-1), params['HP'].reshape(-1), self.encoder(img)))
        print('single obs shape:', frame_obs.shape)
        if self.oldobs.shape != frame_obs.shape:
            # OBS_SIZE is hard-coded; if it disagrees with the real frame size
            # the first stacked vector would differ in length from every later
            # one. Pair the frame with zeros of the correct size instead.
            self.oldobs = np.zeros(frame_obs.shape)
        obs = np.hstack((frame_obs, self.oldobs))
        self.oldobs = frame_obs
        print('obs.shape:', obs.shape)
        return params, obs

    def encoder(self, img):
        """Compress a screenshot into a feature vector with the encoder model.

        The BGR frame is converted to grayscale, resized to 448x256
        (width x height), divided by 255 and fed as a (1, 256, 448, 1) batch to
        ``model_manager.useModel('encoder.h5', ...)``.

        :return: the model output flattened to 1-D.
        """
        print(img.shape)
        small_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        small_img = cv2.resize(small_img, (448, 256))  # cv2 size is (width, height)
        print(small_img.shape)
        small_img = small_img / 255
        small_img = small_img.reshape(1, 256, 448, 1)
        code = model_manager.useModel('encoder.h5', small_img).reshape(-1)
        return code

    def action_params_augment(self, action, params):
        """Re-run ``paramExtract`` with extra features for movement/targeting actions.

        Action ids:
            0 -- go home (recall)
            1 -- move forward
            2 -- retreat
            3 -- attack-move in place
            4 -- move to the centre of allied minions
            5 -- attack the nearest enemy minion

        For actions 1, 2, 4 and 5 ``params`` is passed back through
        ``paramExtract``. NOTE(review): no image is passed, so paramExtract
        presumably reads ``self.currentPic`` -- confirm.

        :return: the (possibly augmented) params.
        """
        if action in (1, 2, 4, 5):
            params = paramExtract(self, params=params, money=False, exp=False, target_mat=False, tower=False)
        return params

    @staticmethod
    def loading_complete(img):
        """Return True if the "game running" HUD template is found in ``img``.

        Uses normalized cross-correlation (TM_CCOEFF_NORMED); a best match
        above 0.9 counts as found.
        """
        res = cv2.matchTemplate(img, game_state_check_running_img, cv2.TM_CCOEFF_NORMED)
        _, max_val, _, _ = cv2.minMaxLoc(res)
        print(max_val)
        return max_val > 0.9

    def game_running(self, img):
        """Alias for :meth:`loading_complete`."""
        return self.loading_complete(img)

    @staticmethod
    def game_end(img):
        """Return True when the in-game template is NOT found (best match < 0.9)."""
        print('game_end:', game_state_check_ending_img.shape)
        res = cv2.matchTemplate(img, game_state_check_ending_img, cv2.TM_CCOEFF_NORMED)
        _, max_val, _, _ = cv2.minMaxLoc(res)
        print(max_val)
        return max_val < 0.9
if __name__ == "__main__":
    # Manual smoke test: feed two consecutive screenshots through the
    # observation pipeline. Paths can be overridden on the command line:
    #     python picProcessor.py FIRST.bmp [SECOND.bmp]
    import sys

    default_paths = [r'D:\develop\autoLOL\dm\ans\screen295.bmp',
                     r'D:\develop\autoLOL\dm\ans\screen296.bmp']
    cli_paths = sys.argv[1:3]
    paths = cli_paths + default_paths[len(cli_paths):]

    frames = []
    for path in paths:
        frame = cv2.imread(path)
        # cv2.imread returns None (it does not raise) on a missing/unreadable file.
        if frame is None:
            sys.exit('could not read image: {}'.format(path))
        frames.append(frame)

    p = picProcessor()
    for frame in frames:
        p.obs_params_extract(frame)