diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100755 index 00000000..c7cd746d --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,91 @@ +{ + "files.associations": { + "*.ipp": "cpp", + "memory": "cpp", + "regex": "cpp", + "utility": "cpp", + "__bit_reference": "cpp", + "__config": "cpp", + "__debug": "cpp", + "__errc": "cpp", + "__functional_base": "cpp", + "__hash_table": "cpp", + "__locale": "cpp", + "__mutex_base": "cpp", + "__node_handle": "cpp", + "__nullptr": "cpp", + "__split_buffer": "cpp", + "__string": "cpp", + "__threading_support": "cpp", + "__tree": "cpp", + "__tuple": "cpp", + "algorithm": "cpp", + "array": "cpp", + "atomic": "cpp", + "bit": "cpp", + "bitset": "cpp", + "cctype": "cpp", + "chrono": "cpp", + "clocale": "cpp", + "cmath": "cpp", + "codecvt": "cpp", + "complex": "cpp", + "condition_variable": "cpp", + "csignal": "cpp", + "cstdarg": "cpp", + "cstddef": "cpp", + "cstdint": "cpp", + "cstdio": "cpp", + "cstdlib": "cpp", + "cstring": "cpp", + "ctime": "cpp", + "cwchar": "cpp", + "cwctype": "cpp", + "deque": "cpp", + "exception": "cpp", + "forward_list": "cpp", + "fstream": "cpp", + "functional": "cpp", + "future": "cpp", + "initializer_list": "cpp", + "iomanip": "cpp", + "ios": "cpp", + "iosfwd": "cpp", + "iostream": "cpp", + "istream": "cpp", + "iterator": "cpp", + "limits": "cpp", + "list": "cpp", + "locale": "cpp", + "map": "cpp", + "mutex": "cpp", + "new": "cpp", + "numeric": "cpp", + "optional": "cpp", + "ostream": "cpp", + "queue": "cpp", + "random": "cpp", + "ratio": "cpp", + "set": "cpp", + "sstream": "cpp", + "stack": "cpp", + "stdexcept": "cpp", + "streambuf": "cpp", + "string": "cpp", + "string_view": "cpp", + "system_error": "cpp", + "thread": "cpp", + "tuple": "cpp", + "type_traits": "cpp", + "typeindex": "cpp", + "typeinfo": "cpp", + "unordered_map": "cpp", + "unordered_set": "cpp", + "valarray": "cpp", + "variant": "cpp", + "vector": "cpp", + "__functional_03": "cpp", + "filesystem": "cpp" + }, + "python.pythonPath": "/usr/local/Caskroom/miniconda/base/envs/moRL/bin/python" +} \ No newline at end of file diff --git a/README.md b/README.md old mode 100644 new mode 100755 diff --git a/elf/ai.h b/elf/ai.h old mode 100644 new mode 100755 index a2658460..8d8b37d2 --- a/elf/ai.h +++ b/elf/ai.h @@ -17,7 +17,7 @@ namespace elf { using namespace std; -template +template // RTSState RTSMCAction class AI_T { public: using Action = A; @@ -43,6 +43,8 @@ class AI_T { virtual ~AI_T() { } + void Print(){std::cout<<"AIName: _name: "<<_name<<" id: "<<_id< +template //带有AICommT的AI class AIWithCommT : public AI_T { public: using AI = AI_T; diff --git a/elf/comm_template.h b/elf/comm_template.h old mode 100644 new mode 100755 index 26a2a110..3224898f --- a/elf/comm_template.h +++ b/elf/comm_template.h @@ -200,14 +200,16 @@ class CommT { _signal->use_queue_per_group(_groups.size()); } + //std::cout<<"-------CommT CollectorsReady---------------"<MainLoop(); }); + _pool.push([p, this](int) { p->MainLoop(); }); // 16 个 Batch Collector } } @@ -343,13 +345,15 @@ class ContextT { const Options &options() const { return _options; } void Start(GameStartFunc game_start_func) { - _comm.CollectorsReady(); - + std::cout<<"--------ContextT Start---------------"<Init(); + void MainLoop(const std::atomic_bool *done = nullptr,bool isPrint = false) { + if(isPrint) + std::cout<<"-------MainLoop----------"<Init(isPrint); // 初始化游戏 + // if(isPrint){ + // std::cout<<"--------Start PleyerInfo"<env().PrintPlayerInfo()<load()) break; @@ -87,6 +93,10 @@ class GameBaseT { // Send message to AIs. _act(false, done); _game_end(); + // if(isPrint){ + // std::cout<<"--------End PleyerInfo"<env().PrintPlayerInfo()<Finalize(); } @@ -122,6 +132,7 @@ class GameBaseT { bot.ai->GameEnd(); } if (_spectator != nullptr) { + std::cout<<"_spectator"<GameEnd(); } } diff --git a/elf/python_options_utils_cpp.h b/elf/python_options_utils_cpp.h index 7063182a..20cfd76e 100644 --- a/elf/python_options_utils_cpp.h +++ b/elf/python_options_utils_cpp.h @@ -19,13 +19,13 @@ struct ContextOptions { // How many simulation threads we are running. - int num_games = 1; + int num_games = 1; // 1024 // The maximum number of threads per game - int max_num_threads = 0; + int max_num_threads = 0; //0 // History length. How long we should keep the history. - int T = 1; + int T = 1; // 20 // verbose options. bool verbose_comm = false; @@ -34,7 +34,7 @@ struct ContextOptions { // Whether we wait for each group or we wait jointly. bool wait_per_group = false; - int num_collectors = 1; + int num_collectors = 1; // 0 mcts::TSOptions mcts_options; diff --git a/elf/tree_search_options.h b/elf/tree_search_options.h index 4813786e..a1bf8b0f 100644 --- a/elf/tree_search_options.h +++ b/elf/tree_search_options.h @@ -20,12 +20,12 @@ using namespace std; struct TSOptions { int max_num_moves = 0; - int num_threads = 16; - int num_rollout_per_thread = 100; - bool verbose = false; + int num_threads = 16; // 0 + int num_rollout_per_thread = 100; // 1 + bool verbose = false; bool verbose_time = false; - string save_tree_filename; + string save_tree_filename; // "" bool persistent_tree = false; // [TODO] Not a good design. diff --git a/elf/utils_elf.py b/elf/utils_elf.py old mode 100644 new mode 100755 index d3e30027..8ed9200c --- a/elf/utils_elf.py +++ b/elf/utils_elf.py @@ -223,12 +223,14 @@ def __init__(self, GC, co, descriptions, use_numpy=False, gpu=None, params=dict( gpu(int): gpu to use. params(dict): additional parameters ''' - + + #self.isPrint = False self._init_collectors(GC, co, descriptions, use_gpu=gpu is not None, use_numpy=use_numpy) self.gpu = gpu self.inputs_gpu = [ self.inputs[gids[0]].cpu2gpu(gpu=gpu) for gids in self.gpu2gid ] if gpu is not None else None self.params = params self._cb = { } + def _init_collectors(self, GC, co, descriptions, use_gpu=True, use_numpy=False): num_games = co.num_games @@ -236,7 +238,7 @@ def _init_collectors(self, GC, co, descriptions, use_gpu=True, use_numpy=False): total_batchsize = 0 for key, v in descriptions.items(): total_batchsize += v["batchsize"] - + if co.num_collectors > 0: num_recv_thread = co.num_collectors else: @@ -269,11 +271,11 @@ def _init_collectors(self, GC, co, descriptions, use_gpu=True, use_numpy=False): for i in range(num_recv_thread): group_id = GC.AddCollectors(batchsize, len(gpu2gid) - 1, timeout_usec, gstat) - input_batch = Batch.load(GC, "input", input, group_id, use_gpu=use_gpu, use_numpy=use_numpy) + input_batch = Batch.load(GC, "input", input, group_id, use_gpu=use_gpu, use_numpy=use_numpy) # 加载输入Batch input_batch.batchsize = batchsize inputs.append(input_batch) if reply is not None: - reply_batch = Batch.load(GC, "reply", reply, group_id, use_gpu=use_gpu, use_numpy=use_numpy) + reply_batch = Batch.load(GC, "reply", reply, group_id, use_gpu=use_gpu, use_numpy=use_numpy) # 加载回复Batch reply_batch.batchsize= batchsize replies.append(reply_batch) else: @@ -298,6 +300,14 @@ def _init_collectors(self, GC, co, descriptions, use_gpu=True, use_numpy=False): self.name2idx = name2idx self.gid2gpu = gid2gpu self.gpu2gid = gpu2gid + # if not self.isPrint: + # print("idx2name",self.idx2name) + # print("name2idx",self.name2idx) + # print("gid2gpu",self.gid2gpu) + # print("gpu2gid",self.gpu2gid) + # print("num_collectors: ",co.num_collectors) + # self.isPrint = True + def reg_has_callback(self, key): return key in self.name2idx @@ -311,6 +321,7 @@ def reg_callback_if_exists(self, key, cb): def reg_callback(self, key, cb): '''Set callback function for key + 注册回调函数,有符合要求和数量的Batch到来时,调用对应的函数 Parameters: key(str): the key used to register the callback function. @@ -332,7 +343,7 @@ def _call(self, infos): raise ValueError("info.gid[%d] is not in callback functions" % infos.gid) if self._cb[infos.gid] is None: - return; + return batchsize = len(infos.s) diff --git a/rlpytorch/sampler/sample_methods.py b/rlpytorch/sampler/sample_methods.py old mode 100644 new mode 100755 index ec657de7..b5b98f38 --- a/rlpytorch/sampler/sample_methods.py +++ b/rlpytorch/sampler/sample_methods.py @@ -33,10 +33,10 @@ def sample_with_check(probs, greedy=False): ''' num_action = probs.size(1) if greedy: - _, actions = probs.max(1) + _, actions = probs.max(1) # 贪婪算法,每次取概率最大的动作 return actions while True: - actions = probs.multinomial(1)[:,0] + actions = probs.multinomial(1)[:,0] # 按照概率选择一个动作 cond1 = (actions < 0).sum() cond2 = (actions >= num_action).sum() if cond1 == 0 and cond2 == 0: @@ -74,8 +74,9 @@ def sample_eps_with_check(probs, epsilon, greedy=False): rej_p = probs.new().resize_(2) rej_p[0] = 1 - epsilon rej_p[1] = epsilon + # rej 按照概率取 0 或 1(batchsize次),取到1时(epsilon)表示此次不选择该动作并随机取样 rej = rej_p.multinomial(batchsize, replacement=True).byte() - + # 随机取样 uniform_p = probs.new().resize_(num_action).fill_(1.0 / num_action) uniform_sampling = uniform_p.multinomial(batchsize, replacement=True) actions[rej] = uniform_sampling[rej] @@ -110,7 +111,7 @@ def sample_multinomial(state_curr, args, node="pi", greedy=False): return actions else: probs = state_curr[node].data - return sample_eps_with_check(probs, args.epsilon, greedy=greedy) + return sample_eps_with_check(probs, args.epsilon, greedy=greedy) # probs 0 False def epsilon_greedy(state_curr, args, node="pi"): ''' epsilon greedy sampling diff --git a/rlpytorch/sampler/sampler.py b/rlpytorch/sampler/sampler.py old mode 100644 new mode 100755 diff --git a/rlpytorch/trainer/trainer.py b/rlpytorch/trainer/trainer.py old mode 100644 new mode 100755 index 8f2b4241..66bbbe15 --- a/rlpytorch/trainer/trainer.py +++ b/rlpytorch/trainer/trainer.py @@ -40,6 +40,7 @@ def __init__(self, name="eval", stats=True, verbose=False, actor_name="actor"): on_get_args = self._on_get_args, child_providers = child_providers ) + self.isPrint = False def _on_get_args(self, _): if self.stats is not None and not self.stats.is_valid(): @@ -75,6 +76,8 @@ def actor(self, batch): if self.sampler is not None: reply_msg = self.sampler.sample(state_curr) + # if not self.isPrint: + # print("sampler reply: ",reply_msg) else: reply_msg = dict(pi=state_curr["pi"].data) @@ -88,6 +91,11 @@ def actor(self, batch): reply_msg["V"] = state_curr["V"].data self.actor_count += 1 + # if not self.isPrint: + # print("batch: ",batch) + # print("state_curr",state_curr) + # print("reply_msg",reply_msg) + # self.isPrint = True return reply_msg def episode_summary(self, i): diff --git a/rts/backend/main_loop.cc b/rts/backend/main_loop.cc old mode 100644 new mode 100755 index 139c8402..ef62ac15 --- a/rts/backend/main_loop.cc +++ b/rts/backend/main_loop.cc @@ -435,7 +435,9 @@ int main(int argc, char *argv[]) { p.stop(true); std::cout << gstats.PrintInfo() << std::endl; } else { + std::cout<<"======RTSStateExtend===="< duration = chrono::system_clock::now() - time_start; cout << "Total time spent = " << duration.count() << "s" << endl; + std::cout<<"======MainLoop======="<(new RTSMap()); + // std::cout<<"-------GameEnv-----------"<ClearMap(); _next_unit_id = 0; _winner_id = INVALID; @@ -121,6 +124,7 @@ bool GameEnv::RemoveUnit(const UnitId &id) { return true; } +// 找到最近的基地 UnitId GameEnv::FindClosestBase(PlayerId player_id) const { // Find closest base. [TODO]: Not efficient here. for (auto it = _units.begin(); it != _units.end(); ++it) { @@ -324,3 +328,15 @@ string GameEnv::PrintDebugInfo() const { ss << _map->PrintDebugInfo() << endl; return ss.str(); } + +string GameEnv::PrintPlayerInfo() const { + stringstream ss; + for (const auto& player : _players) { + ss << "Player " << player.GetId() << endl; + ss << player.PrintInfo()<< endl; + } + + ss << _map->Draw() << endl; + ss << _map->PrintDebugInfo() << endl; + return ss.str(); +} diff --git a/rts/engine/game_env.h b/rts/engine/game_env.h old mode 100644 new mode 100755 index cfdb9584..60e8bc5a --- a/rts/engine/game_env.h +++ b/rts/engine/game_env.h @@ -89,6 +89,7 @@ class GameEnv { // Initialize different units for this game. void InitGameDef() { + //std::cout<<"----------InitGameDef-------"< snapshots; - // Initial seed. If seed = 0, then we use time(NULL) + // Initial seed. If seed = 0, then we use time(NULL) 初始化种子 // When seed != 0, the game should be deterministic (If not then there is a bug somewhere). int seed = 0; diff --git a/rts/engine/game_state.cc b/rts/engine/game_state.cc old mode 100644 new mode 100755 index 52c7c45c..a680448e --- a/rts/engine/game_state.cc +++ b/rts/engine/game_state.cc @@ -12,6 +12,7 @@ using namespace std; using namespace std::chrono; RTSState::RTSState() { + //std::cout<<"-----------RTSState Constructor --------"<PrintInfo()<PrintInfo()<<" "< &cds, const vector &l, UnitAttr attr) { @@ -21,7 +30,7 @@ UnitTemplate _C(int cost, int hp, int defense, float speed, int att, int att_r, p._attr = attr; p._att_r = att_r; p._vis_r = vis_r; - for (int i = 0; i < NUM_COOLDOWN; ++i) { + for (int i = 0; i < NUM_COOLDOWN; ++i) { //设置CD p._cds[i].Set(cds[i]); } diff --git a/rts/engine/gamedef.h b/rts/engine/gamedef.h old mode 100644 new mode 100755 index a2140cd7..6cbc88e7 --- a/rts/engine/gamedef.h +++ b/rts/engine/gamedef.h @@ -21,8 +21,8 @@ struct UnitProperty { int _att_r; float _speed; - // Visualization range. - int _vis_r; + // Visualization range. 可视距离 + int _vis_r; int _changed_hp; UnitId _damage_from; diff --git a/rts/engine/map.cc b/rts/engine/map.cc old mode 100644 new mode 100755 index e11659f6..4826e217 --- a/rts/engine/map.cc +++ b/rts/engine/map.cc @@ -11,11 +11,14 @@ // Constructor RTSMap::RTSMap() { + //std::cout<<"RTSMAP"<& f, int *x1, int *y1, int *x2, int *y2, int i) const { + // std::cout<<"find_two_nearby_empty_slots"<& f, } bool RTSMap::GenerateImpassable(const std::function& f, int nImpassable) { - _map.assign(_m * _n * _level, MapSlot()); - for (int i = 0; i < nImpassable; ++i) { + _map.assign(_m * _n * _level, MapSlot()); //初始化地图格子 类型为NORMAL + + + for (int i = 0; i < nImpassable; ++i) { //随机选一些格子,设为IMPOSSIBLE const int x = f(_m); const int y = f(_n); _map[GetLoc(Coord(x, y))].type = IMPASSABLE; @@ -115,9 +120,10 @@ bool RTSMap::GenerateTDMaze(const std::function& f) { bool RTSMap::GenerateMap(const std::function& f, int nImpassable, int num_player, int init_resource) { // load a map for now simple format. bool success; + // std::cout<<"-------GenerateMap nImpassable = "< RTSMap::GetSight(const Loc& loc, int range) const { const int xmax = std::min(_m - 1, c.x + range); for (int x = xmin; x <= xmax; ++x) { - const int yrange = range - std::abs(c.x - x); + //const int yrange = range - std::abs(c.x - x); + const int x_1 = std::abs(c.x - x); + const int yrange = std::sqrt(range*range - x_1*x_1); const int ymin = std::max(0, c.y - yrange); const int ymax = std::min(_n - 1, c.y + yrange); for (int y = ymin; y <= ymax; ++y) { diff --git a/rts/engine/map.h b/rts/engine/map.h old mode 100644 new mode 100755 index 9ea78932..5d2d4f2a --- a/rts/engine/map.h +++ b/rts/engine/map.h @@ -16,6 +16,7 @@ struct MapSlot { // three layers, terrian, ground and air. + // 只有两层 ground air Terrain type; int height; diff --git a/rts/engine/player.cc b/rts/engine/player.cc old mode 100644 new mode 100755 index 06316b30..8cac3772 --- a/rts/engine/player.cc +++ b/rts/engine/player.cc @@ -58,6 +58,7 @@ void Player::ComputeFOW(const Units &units) { // or loop on the units. // [TODO]: We could do better with LocalitySearch. // Clear fogs. + //std::cout<<"------ComputFOW--------"<second.get(); if (ExtractPlayerId(u->GetId()) != _player_id) { - Loc l = _filter_with_fow(*u); + Loc l = _filter_with_fow(*u); //判断该点的位置是否在视野范围内 // Add the unit info to the loc. - if (l != -1) _fogs[l].SaveUnit(*u); + if (l != -1) _fogs[l].SaveUnit(*u); //在迷雾格中存储该单位 } } } @@ -106,7 +107,9 @@ string Player::PrintInfo() const { stringstream ss; ss << "Map ptr = " << _map << endl; ss << "Player id = " << _player_id << endl; + ss << "Player name = "<< _name << endl; ss << "Resource = " << _resource << endl; + ss << "PlayerPrivilege = "<<_privilege< _prev_seen_units; + vector _prev_seen_units; // 该点可见单位 - void MakeInvisible() { _fog = 100; } + void MakeInvisible() { _fog = 100; } // 让该点不可见 void SetClear() { _fog = 0; _prev_seen_units.clear(); } bool CanSeeTerrain() const { return _fog < 50; } bool CanSeeUnit() const { return _fog < 30; } diff --git a/rts/engine/python_common_options.h b/rts/engine/python_common_options.h old mode 100644 new mode 100755 index e548c7e0..0f0da064 --- a/rts/engine/python_common_options.h +++ b/rts/engine/python_common_options.h @@ -71,21 +71,21 @@ struct PythonOptions { int map_size_x, map_size_y; // Maximum unit command you could send per action. - int max_unit_cmd; + int max_unit_cmd; // 1 // Max tick. - int max_tick; + int max_tick; // 30000 // Random seed to use. seed = 0 mean uses time(NULL). // If seed != 0, then each simulation thread will use a seed which is a deterministic function // of PythonOption.seed and the thread id. - int seed; + int seed; // 0 - bool shuffle_player; + bool shuffle_player; // false int game_name; // [TODO] put handicap to TD. - int handicap_level; + int handicap_level; // 0 PythonOptions() : simulation_type(ST_NORMAL), map_size_x(20), map_size_y(20), max_unit_cmd(10), max_tick(30000), seed(0), shuffle_player(false), game_name(0), handicap_level(0) { diff --git a/rts/engine/rule_actor.h b/rts/engine/rule_actor.h index 474563cc..eed23247 100644 --- a/rts/engine/rule_actor.h +++ b/rts/engine/rule_actor.h @@ -16,7 +16,7 @@ custom_enum(AIState, STATE_START = 0, STATE_BUILD_WORKER, STATE_BUILD_BARRACK, STATE_BUILD_MELEE_TROOP, STATE_BUILD_RANGE_TROOP, STATE_ATTACK, - STATE_ATTACK_IN_RANGE, STATE_HIT_AND_RUN, STATE_DEFEND, NUM_AISTATE); + STATE_ATTACK_IN_RANGE, STATE_HIT_AND_RUN, STATE_DEFEND, NUM_AISTATE); // Action Space custom_enum(FlagState, FLAGSTATE_START = 0, FLAGSTATE_GET_FLAG, FLAGSTATE_ATTACK_FLAG, FLAGSTATE_ESCORT_FLAG, FLAGSTATE_PROTECT_FLAG, //FLAGSTATE_ATTACK, FLAGSTATE_MOVE, diff --git a/rts/engine/wrapper_template.h b/rts/engine/wrapper_template.h old mode 100644 new mode 100755 index eca6d382..9a5726ee --- a/rts/engine/wrapper_template.h +++ b/rts/engine/wrapper_template.h @@ -30,7 +30,11 @@ class WrapperT { const PythonOptions &options, const elf::Signal &signal, const std::map *more_params, Comm *comm) { const string& replay_prefix = options.save_replay_prefix; - + + bool isPrint = false; + if (game_idx==1) + isPrint = true; + // Create a game. RTSGameOptions op; op.seed = (options.seed == 0 ? 0 : options.seed + game_idx); @@ -41,30 +45,47 @@ class WrapperT { op.output_file = options.output_filename; op.cmd_dumper_prefix = options.cmd_dumper_prefix; - // std::cout << "before running wrapper" << std::endl; - - WrapperCB wrapper(game_idx, context_options, options, comm); - wrapper.OnGameOptions(&op); - - // std::cout << "before initializing the game" << std::endl; + if(isPrint){ + string ss = op.PrintInfo(); + std::cout<<"----------RTSGameOptions----"< 20 || xy0[1] > 20) return; + // console.log(xy0); + if (xy0[0] >700 || xy0[1] > 700) return; x_down = e.pageX; y_down = e.pageY; } }, false); + canvas.addEventListener("mouseup", function (e) { var xy0 = convert_xy_back(e.pageX, e.pageY); - if (xy0[0] > 20 || xy0[1] > 20) return; + if (xy0[0] > 700 || xy0[1] > 700) return; if (e.button === 0) { var xy = convert_xy_back(x_down, y_down); if (dragging && x_down && y_down) { @@ -158,6 +170,7 @@ canvas.addEventListener("mouseup", function (e) { } }, false); + canvas.addEventListener("mousemove", function (e) { if (x_down && y_down) { x_curr = e.pageX; @@ -168,21 +181,21 @@ canvas.addEventListener("mousemove", function (e) { } }, false); - +// game.rts_map var onMap = function(m) { var counter = 0; for (y = 0; y < m.height; y++) { for (x = 0; x < m.width; x++){ - var color = cell_colors[m.slots[counter]]; - var x1 = x * cell_size; - var y1 = y * cell_size; + var color = cell_colors[m.slots[counter]]; // counter = 0 m.slots = #404040 + var x1 = x * cell_size; // 一个格子的长是 50 + var y1 = y * cell_size; // 宽 50 ctx.beginPath(); ctx.fillStyle = color; - ctx.lineWidth = 1; - ctx.rect(x1, y1, rect_size, rect_size); - ctx.strokeStyle = 'black'; - ctx.stroke(); - ctx.fillRect(x1, y1, rect_size, rect_size); + ctx.lineWidth = 1; // 格子之间的线宽 1 + ctx.rect(x1, y1, rect_size, rect_size); + ctx.strokeStyle = 'black'; // 矩形 用 black 填充 + ctx.stroke(); + ctx.fillRect(x1, y1, rect_size, rect_size); ctx.closePath(); counter += 1; } @@ -200,7 +213,9 @@ var draw_hp = function(bbox, states, font_color, player_color){ var margin = 2; ctx.fillStyle = 'black'; ctx.lineWidth = margin; + // 开始一条路径 ctx.beginPath(); + // 绘制一个矩阵 ctx.rect(x1, y1, x2 - x1, y2 - y1); ctx.fillRect(x1, y1, x2 - x1, y2 - y1); ctx.strokeStyle = player_color; @@ -211,6 +226,8 @@ var draw_hp = function(bbox, states, font_color, player_color){ if (hp_ratio <= 0.2) color = 'red'; ctx.fillStyle = color; ctx.fillRect(x1, y1, Math.floor((x2 - x1) * hp_ratio + 0.5), y2 - y1); + + // 如果单位有名字,则显示名字 if (state_str){ ctx.beginPath(); ctx.fillStyle = font_color; @@ -220,14 +237,15 @@ var draw_hp = function(bbox, states, font_color, player_color){ } } +// 单位的绘制实现 var onUnit = function(u, isSelected) { - var player_color = player_colors[u.player_id]; - var p = u.p; - var last_p = u.last_p; - var diffx = p.x - last_p.x; + var player_color = player_colors[u.player_id]; // unit 内部的 player_id,用于区分单位是哪一方的 + var p = u.p; // 单位此刻的位置 + var last_p = u.last_p; // 上一刻的位置 + var diffx = p.x - last_p.x; var diffy = p.y - last_p.y; - var ori = "down"; - if (Math.abs(diffx) > Math.abs(diffy)) { + var ori = "down"; + if (Math.abs(diffx) > Math.abs(diffy)) { if (diffx >= 0) { ori = "right"; } else { @@ -242,12 +260,13 @@ var onUnit = function(u, isSelected) { } var xy = convert_xy(p.x, p.y); - draw_sprites(sprites[unit_names_minirts[u.unit_type]], xy[0], xy[1], ori); + draw_sprites(sprites[unit_names_minirts[u.player_id][u.unit_type]], xy[0], xy[1], ori); - var hp_ratio = u.hp / u.max_hp; + var hp_ratio = u.hp / u.max_hp; // 掉血的比例 var state_str; if ("cmd" in u) { if (u.cmd.cmd[0] != 'I') { + // 农民的名字 G1 G0 state_str = u.cmd.cmd[0] + u.cmd.state; } } @@ -255,6 +274,7 @@ var onUnit = function(u, isSelected) { var y1 = xy[1] - 27; var x2 = x1 + unit_size; var y2 = y1 + 5; + // 绘制血条。后两个参数:字体颜色、血条边框颜色 draw_hp([x1, y1, x2, y2], [hp_ratio, state_str], 'white', player_color); if (isSelected) { ctx.lineWidth = 2; @@ -266,40 +286,44 @@ var onUnit = function(u, isSelected) { } }; +// 绘制子弹 var onBullet = function(bullet) { var xy = convert_xy(bullet.p.x, bullet.p.y); + // 加载子弹图片 draw_sprites(bullets, xy[0], xy[1], bullet.state); } -var onPlayerStats = function(player) { - if (player.player_id == 2) { - unit_names_minirts = unit_names_flag; - } - var x1 = left_frame_width + 10; - var y1 = (player.player_id + 1) * 50; - var label = ["PlayerId", player.player_id, "Resource", player.resource].join(" "); - ctx.beginPath() - ctx.fillStyle = "Black"; - ctx.font = "15px Arial"; - ctx.fillText(label, x1, y1); - ctx.closePath(); -} +// 显示 玩家资源的 文本信息 +// var onPlayerStats = function(player) { +// if (player.player_id == 2) { +// unit_names_minirts = unit_names_flag; +// } +// var x1 = left_frame_width + 10; +// var y1 = (player.player_id + 1) * 50; +// var label = ["PlayerId", player.player_id, "Resource", player.resource].join(" "); +// ctx.beginPath() +// ctx.fillStyle = "Black"; +// ctx.font = "15px Arial"; +// ctx.fillText(label, x1, y1); +// ctx.closePath(); +// } + -// Draw units that have been seen. var onPlayerSeenUnits = function(m) { if ("units" in m) { + // 未执行 var oldAlpha = ctx.globalAlpha; ctx.globalAlpha = 0.3; for (var i in m.units) { onUnit(m.units[i], false); } - // console.log(m.units.length) ctx.globalAlpha = oldAlpha; } } +// 空白出显示游戏单位状态 var draw_state = function(u) { var x1 = left_frame_width + 10; var y1 = 150; @@ -339,7 +363,7 @@ var convert_xy_back = function(x, y){ }; var load_sprites = function(spec) { - // Default behavior. + // 默认行为 var specReady = false; var specImage = new Image(); specImage.onload = function () { @@ -350,22 +374,25 @@ var load_sprites = function(spec) { return spec; }; -var draw_sprites = function(spec, px, py, ori) { +// 绘制图片函数 +var draw_sprites = function(spec, px, py, ori) { // 图片 、 var image = spec["image"] var width = image.width; var height = image.height; - if (!("_sizes" in spec)) { - ctx.drawImage(image, px - width / 2, py - height / 2); + if (!("_sizes" in spec)) { + ctx.drawImage(image, px - width / 2, py - height / 2); // 图片,在画布上放的位置的x,y坐标 } else { - var sw = spec["_sizes"][0]; - var sh = spec["_sizes"][1]; + + var sw = spec["_sizes"][0]; + var sh = spec["_sizes"][1]; var nw = Math.floor(width / sw); var nh = Math.floor(height/ sh); var xidx = spec[ori][0]; - var yidx = spec[ori][1]; - var cx = xidx[Math.floor(tick / 3) % xidx.length] * sw; - var cy = yidx[Math.floor(tick / 3) % yidx.length] * sh; - ctx.drawImage(image, cx, cy, sw, sh, px - sw / 2, py - sh / 2, sw, sh); + var yidx = spec[ori][1]; // + var cx = xidx[Math.floor(tick / 3) % xidx.length] * sw; + var cy = yidx[Math.floor(tick / 3) % yidx.length] * sh; + // px - sw / 2, py - sh / 2 代表图像在画布上的 x y 坐标 + ctx.drawImage(image, cx, cy, sw, sh, px - sw / 2, py - sh / 2, sw, sh); // 剪切图像,并在画布上定位被剪切的部分 } }; @@ -374,10 +401,44 @@ var myrange = function (j, k){ return Array.from(new Array(n), (x,i) => i + j); }; -// load pics + var sprites = {}; -sprites["RANGE_ATTACKER"] = load_sprites({ +// sprites["RANGE_ATTACKER"] = load_sprites({ +// "up" : [myrange(15, 22), [0]], +// "down": [myrange(15, 22), [1]], +// "left": [[16], myrange(2, 9)], +// "right": [[15], myrange(2, 9)], +// "_file" : "imgs/tiles.png", +// "_sizes" : [32, 32] +// }); + +// sprites["MELEE_ATTACKER"] = load_sprites({ +// "up" : [myrange(15, 22), [9]], +// "down": [myrange(15, 22), [10]], +// "left": [[20], myrange(2, 9)], +// "right": [[21], myrange(2, 9)], +// "_file" : "imgs/tiles.png", +// "_sizes" : [32, 32] +// }); +sprites["RANGE_ATTACKER1"] = load_sprites({ + "up" : [myrange(15, 22), [0]], + "down": [myrange(15, 22), [1]], + "left": [[16], myrange(2, 9)], + "right": [[15], myrange(2, 9)], + "_file" : "imgs/tile11.png", + "_sizes" : [32, 32] +}); + +sprites["MELEE_ATTACKER1"] = load_sprites({ + "up" : [myrange(15, 22), [9]], + "down": [myrange(15, 22), [10]], + "left": [[20], myrange(2, 9)], + "right": [[21], myrange(2, 9)], + "_file" : "imgs/tile11.png", + "_sizes" : [32, 32] +}); +sprites["RANGE_ATTACKER2"] = load_sprites({ "up" : [myrange(15, 22), [0]], "down": [myrange(15, 22), [1]], "left": [[16], myrange(2, 9)], @@ -386,7 +447,7 @@ sprites["RANGE_ATTACKER"] = load_sprites({ "_sizes" : [32, 32] }); -sprites["MELEE_ATTACKER"] = load_sprites({ +sprites["MELEE_ATTACKER2"] = load_sprites({ "up" : [myrange(15, 22), [9]], "down": [myrange(15, 22), [10]], "left": [[20], myrange(2, 9)], @@ -395,18 +456,20 @@ sprites["MELEE_ATTACKER"] = load_sprites({ "_sizes" : [32, 32] }); +// 资源 sprites["RESOURCE"] = load_sprites({ "_file" : "imgs/mineral1.png", }); - +// 基地 sprites["BASE"] = load_sprites({ "_file" : "imgs/base.png" }); - +// 兵营 sprites["BARRACKS"] = load_sprites({ "_file" : "imgs/barracks.png", }); +// 未执行 var targets = load_sprites({ "attack" : [[11], [6]], "move" : [[14], [6]], @@ -423,9 +486,10 @@ sprites["WORKER"] = load_sprites({ "_sizes" : [32, 32] }); +// 子弹 var bullets = load_sprites({ - "BULLET_READY" : [[7], [0]], - "BULLET_EXPLODE1" : [[0], [0]], + "BULLET_READY" : [[7], [0]], // 子弹类型 + "BULLET_EXPLODE1" : [[0], [0]], // 三种状态 "BULLET_EXPLODE2" : [[1], [0]], "BULLET_EXPLODE3": [[2], [0]], "_file" : "imgs/tiles.png", @@ -451,36 +515,45 @@ sprites["FLAG_BASE"] = load_sprites({ }); + var render = function (game) { tick = game.tick; - ctx.beginPath() - ctx.fillStyle = "Black"; - ctx.font = "15px Arial"; - var label = "Tick: " + tick; - ctx.fillText(label, left_frame_width + 10, 20); - ctx.closePath(); + // Tick + // ctx.beginPath() + // ctx.fillStyle = "Black"; + // ctx.font = "15px Arial"; + // var label = "Tick: " + tick; + // // tick 位置 + // ctx.fillText(label, left_frame_width + 10, 20); + // ctx.closePath(); + + // 加载地图 onMap(game.rts_map); - if (! game.spectator) { + if (! game.spectator) { // game.spectator = true + // 切换成玩家视角 + // rts_map 信息变为某个玩家传来的,下一帧画面生效 onPlayerSeenUnits(game.rts_map); } var all_units = {}; var selected = {}; - for (var i in game.players) { - onPlayerStats(game.players[i]); - } - for (var i in game.units) { + // for (var i in game.players) { // 两个字典,{player_id: 0; resource: 0} + // onPlayerStats(game.players[i]); // {player_id: 1; resource: 0} {player_id: 0; resource: 0} + // } + for (var i in game.units) { var unit = game.units[i]; all_units[unit.id] = unit; - var s_units = game.selected_units; - var isSelected = (s_units && s_units.indexOf(unit.id) >= 0); + var s_units = game.selected_units; // 选中后执行 + var isSelected = (s_units && s_units.indexOf(unit.id) >= 0); if (isSelected) { selected[unit.id] = unit; } + onUnit(unit, isSelected); } + if (dragging && x_down && y_down) { ctx.lineWidth = 2; ctx.beginPath(); @@ -489,46 +562,61 @@ var render = function (game) { ctx.stroke(); ctx.closePath(); } + + // 显示子弹 for (var i in game.bullets) { + // console.log(game.bullets[i]); onBullet(game.bullets[i]); } - var len = Object.keys(selected).length; - if (len == 1) { - var idx = Object.keys(selected)[0]; - var unit = selected[idx]; - draw_state(unit); - } - ctx.beginPath(); - ctx.fillStyle = "Black"; - ctx.font = "15px Arial"; - if (len > 1) { - var label = len + " units"; - ctx.fillText(label ,left_frame_width + 50, 200); - } - var label = "Current FPS is " + Math.floor(50 * Math.pow(1.3, speed)); - ctx.fillText(label, left_frame_width + 50, 570); - if (game.replay_length) { - range1.value = 100 * game.tick / game.replay_length; - } - - var label = "Current progress_percent is " + range1.value; - ctx.fillText(label, left_frame_width + 50, 670); - ctx.closePath(); + var len = Object.keys(selected).length; // 选中几个单位 + // 选中某单位时,空白处显示当前选择单位的状态 + // if (len == 1) { + // var idx = Object.keys(selected)[0]; // 选中的单位的编号 + // var unit = selected[idx]; + // // console.log(unit); // 获取选中的单位的所有属性 + // draw_state(unit); // 在右边空白出显示具体信息 + // } + // ctx.beginPath(); + // ctx.fillStyle = "Black"; + // ctx.font = "15px Arial"; + // // 选中多个单位时,空白处显示 len + " units" + // if (len > 1) { + // var label = len + " units"; + // ctx.fillText(label ,left_frame_width + 50, 200); + // } + // var label = "Current FPS is " + Math.floor(50 * Math.pow(1.3, speed)); + // ctx.fillText(label, left_frame_width + 50, 570); + // if (game.replay_length) { + // range1.value = 100 * game.tick / game.replay_length; + // } + + // var label = "Current progress_percent is " + range1.value; + // ctx.fillText(label, left_frame_width + 50, 670); + // ctx.closePath(); }; + var main = function () { + // 建立连接 dealer = new WebSocket('ws://localhost:8000'); + // 连接建立成功调用的方法 dealer.onopen = function(event) { + // 连接成功后输出 console.log("WS Opened."); } + dealer.onmessage = function (message) { - var s = message.data; + + var s = message.data; + // 将 s 转换为 json 存储在 game 中 var game = JSON.parse(s); + //console.log(game); ctx.clearRect(0, 0, canvas.width, canvas.height); render(game); }; }; var then = Date.now(); +// 开始 main(); diff --git a/rts/frontend/imgs/tile.png b/rts/frontend/imgs/tile.png new file mode 100644 index 00000000..d00885f6 Binary files /dev/null and b/rts/frontend/imgs/tile.png differ diff --git a/rts/frontend/imgs/tile11.png b/rts/frontend/imgs/tile11.png new file mode 100755 index 00000000..58e23116 Binary files /dev/null and b/rts/frontend/imgs/tile11.png differ diff --git a/rts/frontend/imgs/tiles.png b/rts/frontend/imgs/tiles.png index d00885f6..48540e6f 100644 Binary files a/rts/frontend/imgs/tiles.png and b/rts/frontend/imgs/tiles.png differ diff --git a/rts/game_MC/cmd_specific.cc b/rts/game_MC/cmd_specific.cc old mode 100644 new mode 100755 index 406480bc..3a7b8382 --- a/rts/game_MC/cmd_specific.cc +++ b/rts/game_MC/cmd_specific.cc @@ -16,6 +16,7 @@ #include "cmd_specific.gen.h" bool CmdGenerateMap::run(GameEnv *env, CmdReceiver*) { + // std::cout<<"CmdGenerateMap"<GenerateMap(_num_obstacles, _init_resource) ? true : false; } @@ -24,6 +25,7 @@ bool CmdGenerateMap::run(GameEnv *env, CmdReceiver*) { bool CmdGameStartSpecific::run(GameEnv*, CmdReceiver* receiver) { + // std::cout<<"CmdGameStartSpecific"<GetRandomFunc(); - int lr_seed = f(2); - int ud_seed = f(2); - bool shuffle_lr = (lr_seed == 0); - bool shuffle_ud = (ud_seed == 0); - auto shuffle_loc = [&] (PointF p, bool b1, bool b2) -> PointF { - int x = b1 ? 19 - p.x : p.x; - int y = b2 ? 19 - p.y : p.y; - return PointF(x, y); - }; + //std::cout<<"CmdGenerateTDUnit"<GetGameStats().PickBase(lr_seed * 2 + ud_seed); - for (const auto &info : env->GetMap().GetPlayerMapInfo()) { - PlayerId id = info.player_id; - _CREATE(BASE, shuffle_loc(PointF(info.base_coord), shuffle_lr, shuffle_ud), id); - _CREATE(RESOURCE, shuffle_loc(PointF(info.resource_coord), shuffle_lr, shuffle_ud), id); - _CHANGE_RES(id, info.initial_resource); - //base_locs[id] = PointF(info.base_coord); - } - // for (size_t i = 0; i < base_locs.size(); ++i) { - // std::cout << "[" << i << "] Baseloc: " << base_locs[i].x << ", " << base_locs[i].y << std::endl; - //} - auto gen_loc = [&] (int player_id) -> PointF { - // Note that we could not write - // PointF( f(8) + ..., f(8) + ...) - // since the result will depend on which f is evaluated first, and will yield different results on - // different platform/compiler (e.g., clang and gcc yields different results). - // The following implementation is uniquely determined. - int x = f(6) + player_id * 10 + 2; - int y = f(6) + player_id * 10 + 2; - return PointF(x, y); - }; - for (PlayerId player_id = 0; player_id < 2; player_id++) { - PlayerId id = player_id; - // Generate workers (up to three). - for (int i = 0; i < 3; i++) { - if (f(10) >= 5) { - _CREATE(WORKER, shuffle_loc(gen_loc(player_id), shuffle_lr, shuffle_ud), id); - } - } - if (f(10) >= 8) - _CREATE(BARRACKS, shuffle_loc(gen_loc(player_id), shuffle_lr, shuffle_ud), id); - if (f(10) >= 5) - _CREATE(MELEE_ATTACKER, shuffle_loc(gen_loc(player_id), shuffle_lr, shuffle_ud), id); - if (f(10) >= 5) - _CREATE(RANGE_ATTACKER, shuffle_loc(gen_loc(player_id), shuffle_lr, shuffle_ud), id); - } + // enemy + _CREATE(BASE,PointF(1, 1),enemy_id); return true; + } +//-----------------Test-------------------- +// bool CmdGenerateUnit::run(GameEnv *env, CmdReceiver *receiver) { +// // std::cout<<"CmdGenerateUnit"<GetRandomFunc(); +// int lr_seed = f(2); +// int ud_seed = f(2); +// bool shuffle_lr = (lr_seed == 0); +// bool shuffle_ud = (ud_seed == 0); +// auto shuffle_loc = [&] (PointF p, bool b1, bool b2) -> PointF { +// int x = b1 ? 19 - p.x : p.x; +// int y = b2 ? 19 - p.y : p.y; +// return PointF(x, y); +// }; + +// receiver->GetGameStats().PickBase(lr_seed * 2 + ud_seed); +// for (const auto &info : env->GetMap().GetPlayerMapInfo()) { +// PlayerId id = info.player_id; +// _CREATE(BASE, shuffle_loc(PointF(info.base_coord), shuffle_lr, shuffle_ud), id); +// _CREATE(RESOURCE, shuffle_loc(PointF(info.resource_coord), shuffle_lr, shuffle_ud), id); +// _CHANGE_RES(id, info.initial_resource); +// //base_locs[id] = PointF(info.base_coord); +// } +// // for (size_t i = 0; i < base_locs.size(); ++i) { +// // std::cout << "[" << i << "] Baseloc: " << base_locs[i].x << ", " << base_locs[i].y << std::endl; +// //} +// auto gen_loc = [&] (int player_id) -> PointF { +// // Note that we could not write +// // PointF( f(8) + ..., f(8) + ...) +// // since the result will depend on which f is evaluated first, and will yield different results on +// // different platform/compiler (e.g., clang and gcc yields different results). +// // The following implementation is uniquely determined. +// int x = f(6) + player_id * 10 + 2; +// int y = f(6) + player_id * 10 + 2; +// return PointF(x, y); +// }; +// for (PlayerId player_id = 0; player_id < 2; player_id++) { +// PlayerId id = player_id; +// // Generate workers (up to three). +// for (int i = 0; i < 3; i++) { +// if (f(10) >= 5) { +// _CREATE(WORKER, shuffle_loc(gen_loc(player_id), shuffle_lr, shuffle_ud), id); +// } +// } +// if (f(10) >= 8) +// _CREATE(BARRACKS, shuffle_loc(gen_loc(player_id), shuffle_lr, shuffle_ud), id); +// if (f(10) >= 5) +// _CREATE(MELEE_ATTACKER, shuffle_loc(gen_loc(player_id), shuffle_lr, shuffle_ud), id); +// if (f(10) >= 5) +// _CREATE(RANGE_ATTACKER, shuffle_loc(gen_loc(player_id), shuffle_lr, shuffle_ud), id); +// } +// return true; +// } + #undef _CHANGE_RES #undef _CREATE diff --git a/rts/game_MC/game.py b/rts/game_MC/game.py old mode 100644 new mode 100755 index 40c36437..d2cf20fb --- a/rts/game_MC/game.py +++ b/rts/game_MC/game.py @@ -20,11 +20,11 @@ def __init__(self): def _define_args(self): return [ - ("use_unit_action", dict(action="store_true")), - ("disable_time_decay", dict(action="store_true")), - ("use_prev_units", dict(action="store_true")), - ("attach_complete_info", dict(action="store_true")), - ("feature_type", "ORIGINAL") + ("use_unit_action", dict(action="store_true")), # false + ("disable_time_decay", dict(action="store_true")), # false + ("use_prev_units", dict(action="store_true")), # false + ("attach_complete_info", dict(action="store_true")),# false + ("feature_type", "ORIGINAL") # "ORIGINAL" ] def _on_gc(self, GC): @@ -32,11 +32,11 @@ def _on_gc(self, GC): opt.use_time_decay = not self.args.disable_time_decay opt.save_prev_seen_units = self.args.use_prev_units opt.attach_complete_info = self.args.attach_complete_info - GC.ApplyExtractorParams(opt) + GC.ApplyExtractorParams(opt) # 设置 MCExtractor usage = minirts.MCExtractorUsageOptions() usage.Set(self.args.feature_type) - GC.ApplyExtractorUsage(usage) + GC.ApplyExtractorUsage(usage) # 设置 ExtractorUsage def _unit_action_keys(self): if self.args.use_unit_action: @@ -44,12 +44,12 @@ def _unit_action_keys(self): else: return [] - def _get_actor_spec(self): + def _get_actor_spec(self): #d=定义用于actor的batch字典 reply_keys = ["V", "pi", "a"] return dict( batchsize=self.args.batchsize, - input=dict(T=1, keys=set(["s", "last_r", "terminal"])), + input=dict(T=1, keys=set(["s", "last_r", "terminal"])), # 期待收到 s last_r terminal reply=dict(T=1, keys=set(reply_keys + self._unit_action_keys())), ) @@ -58,7 +58,7 @@ def _get_train_spec(self): return dict( batchsize=self.args.batchsize, input=dict(T=self.args.T, keys=set(keys + self._unit_action_keys())), - reply=None + reply=None # train 不需要回复 ) def _get_reduced_predict(self): diff --git a/rts/game_MC/game_action.h b/rts/game_MC/game_action.h old mode 100644 new mode 100755 diff --git a/rts/game_MC/gamedef.cc b/rts/game_MC/gamedef.cc old mode 100644 new mode 100755 index 1bbde036..cb199805 --- a/rts/game_MC/gamedef.cc +++ b/rts/game_MC/gamedef.cc @@ -35,7 +35,7 @@ bool GameDef::CheckAddUnit(RTSMap *_map, UnitType, const PointF& p) const{ } void GameDef::GlobalInit() { - reg_engine(); + reg_engine(); // reg_engine_specific(); reg_minirts_specific(); @@ -55,21 +55,47 @@ void GameDef::GlobalInit() { void GameDef::Init() { _units.assign(GetNumUnitType(), UnitTemplate()); + /** + * int cost, int hp, int defense, float speed, int att, int att_r, int vis_r, + const vector &cds, const vector &l, UnitAttr attr**/ + + // _units[RESOURCE] = _C(0, 1000, 1000, 0, 0, 0, 0, vector{0, 0, 0, 0}, vector{}, ATTR_INVULNERABLE); + // _units[WORKER] = _C(50, 50, 0, 0.1, 2, 1, 3, vector{0, 10, 40, 40}, vector{MOVE, ATTACK, BUILD, GATHER}); + // _units[MELEE_ATTACKER] = _C(100, 100, 1, 0.1, 15, 1, 3, vector{0, 15, 0, 0}, vector{MOVE, ATTACK}); + // _units[RANGE_ATTACKER] = _C(100, 50, 0, 0.2, 10, 5, 5, vector{0, 10, 0, 0}, vector{MOVE, ATTACK}); + // _units[BARRACKS] = _C(200, 200, 1, 0.0, 0, 0, 5, vector{0, 0, 0, 50}, vector{BUILD}); + // _units[BASE] = _C(500, 500, 2, 0.0, 0, 0, 5, {0, 0, 0, 50}, vector{BUILD}); + + /** + * 定义单位属性 + * cost 造价(不需要) + * hp 血量 + * defence防御力 + * speed 速度 + * att 攻击力 + * att_r 攻击距离 + * vis_r 可视距离 + * * cost hp def sp att att_r vis_r + * */ + _units[RESOURCE] = _C(0, 1000, 1000, 0, 0, 0, 0, vector{0, 0, 0, 0}, vector{}, ATTR_INVULNERABLE); - _units[WORKER] = _C(50, 50, 0, 0.1, 2, 1, 3, vector{0, 10, 40, 40}, vector{MOVE, ATTACK, BUILD, GATHER}); - _units[MELEE_ATTACKER] = _C(100, 100, 1, 0.1, 15, 1, 3, vector{0, 15, 0, 0}, vector{MOVE, ATTACK}); - _units[RANGE_ATTACKER] = _C(100, 50, 0, 0.2, 10, 5, 5, vector{0, 10, 0, 0}, vector{MOVE, ATTACK}); - _units[BARRACKS] = _C(200, 200, 1, 0.0, 0, 0, 5, vector{0, 0, 0, 50}, vector{BUILD}); - _units[BASE] = _C(500, 500, 2, 0.0, 0, 0, 5, {0, 0, 0, 50}, vector{BUILD}); + _units[WORKER] = _C(50, 50, 0, 0.1, 2, 1, 0, vector{0, 10, 40, 40}, vector{MOVE, ATTACK, BUILD, GATHER}); + _units[MELEE_ATTACKER] = _C(50, 100, 1, 0.1, 15, 10, 0, vector{0, 15, 0, 0}, vector{MOVE,ATTACK}); + _units[RANGE_ATTACKER] = _C(100, 50, 0, 0.2, 0, 0, 15, vector{0, 0, 0, 0}, vector{}); + _units[BARRACKS] = _C(200, 200, 1, 0.0, 0, 0, 0, vector{0, 0, 0, 50}, vector{BUILD}); + _units[BASE] = _C(500, 500, 2, 0.0, 0, 0, 0, {0, 0, 0, 50}, vector{BUILD}); + + } vector > GameDef::GetInitCmds(const RTSGameOptions&) const{ vector > init_cmds; - init_cmds.push_back(make_pair(CmdBPtr(new CmdGenerateMap(INVALID, 0, 200)), 1)); - init_cmds.push_back(make_pair(CmdBPtr(new CmdGenerateUnit(INVALID)), 2)); + init_cmds.push_back(make_pair(CmdBPtr(new CmdGenerateMap(INVALID, 0, 200)), 1)); // 障碍 资源 + init_cmds.push_back(make_pair(CmdBPtr(new CmdGenerateUnit(INVALID)), 2)); return init_cmds; } +// 通过判断最后一个拥有基地的玩家来确定胜利者 PlayerId GameDef::CheckWinner(const GameEnv& env, bool /*exceeds_max_tick*/) const { return env.CheckBase(BASE); } diff --git a/rts/game_MC/model.py b/rts/game_MC/model.py old mode 100644 new mode 100755 index 04bf5f61..f1848f3c --- a/rts/game_MC/model.py +++ b/rts/game_MC/model.py @@ -18,12 +18,13 @@ class Model_ActorCritic(Model): def __init__(self, args): super(Model_ActorCritic, self).__init__(args) self._init(args) + #self.isPrint = False def _init(self, args): params = args.params assert isinstance(params["num_action"], int), "num_action has to be a number. action = " + str(params["num_action"]) self.params = params - self.net = MiniRTSNet(args) + self.net = MiniRTSNet(args) # 卷积神经网络处理输入数据 last_num_channel = self.net.num_channels[-1] if self.params.get("model_no_spatial", False): @@ -31,9 +32,11 @@ def _init(self, args): linear_in_dim = last_num_channel else: linear_in_dim = last_num_channel * 25 + + - self.linear_policy = nn.Linear(linear_in_dim, params["num_action"]) - self.linear_value = nn.Linear(linear_in_dim, 1) + self.linear_policy = nn.Linear(linear_in_dim, params["num_action"]) # 策略函数 + self.linear_value = nn.Linear(linear_in_dim, 1) # 价值函数 self.relu = nn.LeakyReLU(0.1) @@ -49,13 +52,21 @@ def get_define_args(): def forward(self, x): if self.params.get("model_no_spatial", False): # Replace a complicated network with a simple retraction. - # Input: batchsize, channel, height, width + # Input: batchsize, channel, height, width Batch Object xreduced = x["s"].sum(2).sum(3).squeeze() xreduced[:, self.num_unit:] /= 20 * 20 output = self._var(xreduced) else: output = self.net(self._var(x["s"])) - + + #decide = self.decision(output) + #if not self.isPrint: + #print("x: ",x.batch) + # print("output: ",output) + # print("decision: ",decide) + # print("net: ",self) + #self.isPrint = True + #return decide return self.decision(output) def decision(self, h): diff --git a/rts/game_MC/python_options.h b/rts/game_MC/python_options.h old mode 100644 new mode 100755 index 52629a02..1eb52308 --- a/rts/game_MC/python_options.h +++ b/rts/game_MC/python_options.h @@ -22,6 +22,12 @@ struct GameState { using State = GameState; using Data = GameState; + /** + * 测试获取基地位置信息 + * */ + //float base_x, base_y; + + int32_t id; int32_t seq; int32_t game_counter; @@ -52,7 +58,7 @@ struct GameState { std::vector uloc, tloc; std::vector bt, ct; - // Also we need to save distributions. + // Also we need to save distributions. 概率? std::vector uloc_prob, tloc_prob; std::vector bt_prob, ct_prob; @@ -133,6 +139,7 @@ struct GameState { // These fields are used to exchange with Python side using tensor interface. DECLARE_FIELD(GameState, id, a, V, pi, last_r, s, rv, terminal, seq, game_counter, last_terminal, uloc, tloc, bt, ct, uloc_prob, tloc_prob, bt_prob, ct_prob, reduced_s, reduced_next_s); + //DECLARE_FIELD(GameState, id, a, V, pi, last_r, s, rv, terminal, seq, game_counter, last_terminal, uloc, tloc, bt, ct, uloc_prob, tloc_prob, bt_prob, ct_prob, reduced_s, reduced_next_s,base_x,base_y); REGISTER_PYBIND_FIELDS(id); }; diff --git a/rts/game_MC/python_wrapper.cc b/rts/game_MC/python_wrapper.cc old mode 100644 new mode 100755 index 3788527c..013453f5 --- a/rts/game_MC/python_wrapper.cc +++ b/rts/game_MC/python_wrapper.cc @@ -33,32 +33,36 @@ class GameContext { GameDef::GlobalInit(); _context.reset(new GC{context_options, options}); - _num_frames_in_state = 1; + _num_frames_in_state = 1; for (const AIOptions& opt : options.ai_options) { _num_frames_in_state = max(_num_frames_in_state, opt.num_frames_in_state); } } void Start() { + std::cout<<"--------------GameContext Start-----------"<Start( [this](int game_idx, const ContextOptions &context_options, const PythonOptions &options, const elf::Signal &signal, Comm *comm) { auto params = this->GetParams(); + if(game_idx == 1){ + std::cout<<"game_"<_wrapper.thread_main(game_idx, context_options, options, signal, ¶ms, comm); }); } std::map GetParams() const { return std::map{ - { "num_action", GameDef::GetNumAction() }, - { "num_unit_type", GameDef::GetNumUnitType() }, - { "num_planes_per_time_stamp", MCExtractor::Size() }, - { "num_planes", MCExtractor::Size() * _num_frames_in_state }, - { "resource_dim", 2 * NUM_RES_SLOT }, - { "max_unit_cmd", _context->options().max_unit_cmd }, - { "map_x", _context->options().map_size_x }, - { "map_y", _context->options().map_size_y }, - { "num_cmd_type", CmdInput::CI_NUM_CMDS }, - { "reduced_dim", MCExtractor::Size() * 5 * 5 } + { "num_action", GameDef::GetNumAction() }, // 9 + { "num_unit_type", GameDef::GetNumUnitType() }, // 6 + { "num_planes_per_time_stamp", MCExtractor::Size() }, // 22 每一个时间戳中的 planes数? + { "num_planes", MCExtractor::Size() * _num_frames_in_state }, // 22 每一个状态包含一帧的数据 + { "resource_dim", 2 * NUM_RES_SLOT }, // 10 + { "max_unit_cmd", _context->options().max_unit_cmd }, // 1 + { "map_x", _context->options().map_size_x }, // 20 + { "map_y", _context->options().map_size_y }, // 20 + { "num_cmd_type", CmdInput::CI_NUM_CMDS }, // 4 + { "reduced_dim", MCExtractor::Size() * 5 * 5 }// 22*5*5 }; } @@ -88,7 +92,9 @@ class GameContext { else if (key == "ct_prob") return EntryInfo(key, type_name, { max_unit_cmd, CmdInput::CI_NUM_CMDS }); else if (key == "reduced_s") return EntryInfo(key, type_name, { reduced_size }); else if (key == "reduced_next_s") return EntryInfo(key, type_name, { reduced_size }); + else if (key == "base_x" || key == "base_y") return EntryInfo(key, type_name); + return EntryInfo(); } diff --git a/rts/game_MC/state_feature.cc b/rts/game_MC/state_feature.cc old mode 100644 new mode 100755 index 5ff62a8a..3e822adc --- a/rts/game_MC/state_feature.cc +++ b/rts/game_MC/state_feature.cc @@ -28,6 +28,14 @@ void MCExtractor::SaveInfo(const RTSState &s, PlayerId player_id, GameState *gs) gs->terminal = s.env().GetTermination() ? 1 : 0; gs->last_r = 0.0; + + // 测试获取基地位置 + // UnitId baseId = s.env().FindClosestBase(player_id); + //const Unit* base = s.env().GetUnit(s.env().FindClosestBase(player_id)); + //gs->base_x = s.env().GetUnit(s.env().FindClosestBase(player_id))->GetPointF().x; + //gs->base_y = s.env().GetUnit(s.env().FindClosestBase(player_id))->GetPointF().y; + // 测试获取基地位置 + int winner = s.env().GetWinnerId(); if (winner != INVALID) { if (winner == player_id) gs->last_r = 1.0; diff --git a/rts/game_MC/state_feature.h b/rts/game_MC/state_feature.h index f627a2ab..ea81b6a1 100644 --- a/rts/game_MC/state_feature.h +++ b/rts/game_MC/state_feature.h @@ -22,6 +22,7 @@ using namespace std; #define NUM_RES_SLOT 5 +// MCExtractor设置 struct MCExtractorOptions { bool use_time_decay = true; bool save_prev_seen_units = false; @@ -122,15 +123,15 @@ class MCExtractorInfo { void Reset(const MCExtractorOptions &opt) { extractors_.clear(); - int num_unit_type = GameDef::GetNumUnitType(); + int num_unit_type = GameDef::GetNumUnitType(); total_dim_ = 0; - total_dim_ += _add_extractor("UnitType", new ExtractorSpan(total_dim_, num_unit_type)); - total_dim_ += _add_extractor("Feature", new ExtractorSpan(total_dim_, NUM_FEATURE)); + total_dim_ += _add_extractor("UnitType", new ExtractorSpan(total_dim_, num_unit_type)); // 6 + total_dim_ += _add_extractor("Feature", new ExtractorSpan(total_dim_, NUM_FEATURE)); // 4 std::initializer_list ticks = { 200, 500, 1000, 2000, 5000, 10000 }; if (opt.use_time_decay) { - total_dim_ += _add_extractor("HistBin", new ExtractorSeq(total_dim_, ticks)); + total_dim_ += _add_extractor("HistBin", new ExtractorSeq(total_dim_, ticks)); // 6+1 } if (opt.save_prev_seen_units) { @@ -138,7 +139,7 @@ class MCExtractorInfo { total_dim_ += _add_extractor("HistBinPrevSeen", new ExtractorSeq(total_dim_, ticks)); } - total_dim_ += _add_extractor("Resource", new ExtractorSpan(total_dim_, NUM_RES_SLOT, 50)); + total_dim_ += _add_extractor("Resource", new ExtractorSpan(total_dim_, NUM_RES_SLOT, 50)); // 5 } int size() const { return total_dim_; } @@ -150,7 +151,7 @@ class MCExtractorInfo { private: int total_dim_; - map> extractors_; + map> extractors_; // int _add_extractor(const std::string &name, Extractor *e) { extractors_[name].reset(e); diff --git a/rts/game_MC/trunk.py b/rts/game_MC/trunk.py old mode 100644 new mode 100755 index e685481a..a226026d --- a/rts/game_MC/trunk.py +++ b/rts/game_MC/trunk.py @@ -16,6 +16,7 @@ def __init__(self, args, output1d=True): super(MiniRTSNet, self).__init__(args) self._init(args) self.output1d = output1d + #self.isPrint = False def _init(self, args): self.m = args.params.get("num_planes_per_time_stamp", 13) @@ -64,7 +65,13 @@ def get_define_args(): def forward(self, input): # BN and LeakyReLU are from Wendy's code. x = input.view(input.size(0), self.input_channel, self.mapy, self.mapx) - + # if not self.isPrint: + # print("input size: ",input.size()) + # print("input: ",input) + # print("x size:",x.size()) + # print("x: ",x) + # print("Net",self) + # self.isPrint = True counts = Counter() for i in range(len(self.arch)): if self.arch[i] == "c": @@ -78,5 +85,10 @@ def forward(self, input): if self.output1d: x = x.view(x.size(0), -1) - - return x + + # if not self.isPrint: + # print("x",x) + # print("x.size ",x.size()) + # self.isPrint = True + + return x # 64 x 550 diff --git a/rts/game_MC/wrapper_callback.cc b/rts/game_MC/wrapper_callback.cc old mode 100644 new mode 100755 index 40a06ab4..80195b71 --- a/rts/game_MC/wrapper_callback.cc +++ b/rts/game_MC/wrapper_callback.cc @@ -65,11 +65,11 @@ void WrapperCallbacks::OnGameOptions(RTSGameOptions *rts_options) { rts_options->handicap_level = _options.handicap_level; } -void WrapperCallbacks::OnGameInit(RTSGame *game, const std::map *more_params) { +void WrapperCallbacks::OnGameInit(RTSGame *game, const std::map *more_params,bool isPrint) { // std::cout << "Initialize opponent" << std::endl; std::vector ais; for (const AIOptions &ai_opt : _options.ai_options) { - Context::AIComm *ai_comm = new Context::AIComm(_game_idx, _comm); + Context::AIComm *ai_comm = new Context::AIComm(_game_idx, _comm); //设置 AI 和 Main_Loop通信的工具 _ai_comms.emplace_back(ai_comm); initialize_ai_comm(*ai_comm, more_params); ais.push_back(get_ai(_game_idx, _context_options.mcts_options, ai_opt, ai_comm)); @@ -85,12 +85,17 @@ void WrapperCallbacks::OnGameInit(RTSGame *game, const std::mapAddBot(ais[idx], _options.ai_options[idx].fs); game->GetState().AppendPlayer("player " + std::to_string(idx)); } + // 输出玩家信息 + if(isPrint){ + std::cout<<"ais size = "<GetState().env().PrintPlayerInfo()< *more_params); + void OnGameInit(RTSGame *game, const std::map *more_params,bool isPrint = false); void OnEpisodeStart(int k, std::mt19937 *rng, RTSGame *game); }; diff --git a/train.py b/train.py old mode 100644 new mode 100755 index 8b30c168..17fcba0d --- a/train.py +++ b/train.py @@ -16,8 +16,26 @@ if __name__ == '__main__': trainer = Trainer() runner = SingleProcessRun() - env, all_args = load_env(os.environ, trainer=trainer, runner=runner) - + ''' + 设置环境参数 + env : dice [ + game = game game.py + sampler = sampler + model_loaders = 存ModelLoader类的字典,每一个ModelLoader中有一个model + mi=mi + ] + + all_args: 所有定义的参数 + ''' + env, all_args = load_env(os.environ, trainer=trainer, runner=runner) + + ''' + GC = GCWrapper(GC, co, desc, gpu=args.gpu, use_numpy=False, params=params) + GC /gameMC/python_wrapper.cc GameContext + co /elf/python_options_utils_cpp.h ContextOption + desc actor 和 critic 的 Batch定义 + {'actor': {'batchsize': 64, 'input': {'T': 1, 'keys': {'s', 'terminal', 'last_r'}}, 'reply': {'T': 1, 'keys': {'a', 'pi', 'V'}}}} + ''' GC = env["game"].initialize() model = env["model_loaders"][0].load_model(GC.params) @@ -27,7 +45,15 @@ trainer.setup(sampler=env["sampler"], mi=env["mi"], rl_method=env["method"]) GC.reg_callback("train", trainer.train) + # def train(batch): + # print(batch) + # import pdb + # pdb.set_trace() + # return trainer.train(batch) + + # GC.reg_callback("train", train) GC.reg_callback("actor", trainer.actor) + runner.setup(GC, episode_summary=trainer.episode_summary, episode_start=trainer.episode_start) diff --git a/train_minirts.sh b/train_minirts.sh index 1fa620da..fee80a44 100755 --- a/train_minirts.sh +++ b/train_minirts.sh @@ -1,3 +1,3 @@ #!/bin/bash -game=./rts/game_MC/game model=actor_critic model_file=./rts/game_MC/model python3 train.py --batchsize 128 --freq_update 1 --players "type=AI_NN,fs=50,args=backup/AI_SIMPLE|start/500|decay/0.99;type=AI_SIMPLE,fs=20" --num_games 1024 --tqdm --T 20 --additional_labels id,last_terminal --trainer_stats winrate --keys_in_reply V "$@" +game=./rts/game_MC/game model=actor_critic model_file=./rts/game_MC/model python3 train.py --batchsize 32 --freq_update 1 --players "type=AI_NN,fs=50,args=backup/AI_SIMPLE|start/500|decay/0.99;type=AI_SIMPLE,fs=20" --num_games 1024 --tqdm --T 20 --additional_labels id,last_terminal --trainer_stats winrate --keys_in_reply V "$@"