diff --git a/scripts/generate_model.py b/scripts/generate_model.py index aa0d249..f0e518d 100644 --- a/scripts/generate_model.py +++ b/scripts/generate_model.py @@ -938,6 +938,15 @@ def generate_runtime(llvm_instr,addr_list,pes): """ print(includes) + def_mutex = """ +std::mutex log_mutex; +std::mutex stall_mutex;\n""" + + for pe in pes: + def_mutex += "std::mutex " + pe.name + "_mutex;\n" + + print(def_mutex) + def_maps = """static std::map trace_bools; static std::map stall_counts;\n""" @@ -1014,30 +1023,25 @@ def generate_runtime(llvm_instr,addr_list,pes): #print(init_memtest) register = "inline void __register_entry(unsigned id, std::string e) {\n" + register += " log_mutex.lock();\n" register += " (*traceFiles.at(id)) << e;\n" register += " traceFiles.at(id)->flush();\n" + register += " log_mutex.unlock();\n" register += "}" print(register) write_memtest = """void write_hetsim() {\n""" + write_memtest += " log_mutex.lock();\n" write_memtest += " int size = traceFiles.size();\n" write_memtest += " for (int i = 0; i < size; ++i) {\n" write_memtest += " __register_entry(i, \"EOF\");\n""" write_memtest += " traceFiles.at(i)->flush();\n" write_memtest += " traceFiles.at(i)->close();\n" write_memtest += " }\n" + write_memtest += " log_mutex.unlock();\n" write_memtest += "}\n" print(write_memtest) - def_mutex = """ -std::mutex log_mutex; -std::mutex stall_mutex;\n""" - - for pe in pes: - def_mutex += "std::mutex " + pe.name + "_mutex;\n" - - print(def_mutex) - print(""" std::string __LD(void *addr, std::vector deps, @@ -1136,7 +1140,12 @@ def generate_runtime(llvm_instr,addr_list,pes): print(stall_func) increment_stalls = ''' -int increment_stalls(int num_stalls, int num_deps = 0, void *dep1 = NULL, void *dep2 = NULL, void *dep3 = NULL) { +int increment_stalls(int num_stalls, int num_deps, void *dep1, void *dep2, void *dep3) { + + stall_mutex.lock(); + bool found_stall_counts = stall_counts.find(std::this_thread::get_id()) != stall_counts.end(); + stall_mutex.unlock(); + if (num_deps > 0) { if (is_log_open()) { emit_stall(); @@ -1148,8 +1157,10 @@ def generate_runtime(llvm_instr,addr_list,pes): stall_counts[std::this_thread::get_id()] += num_stalls; stall_mutex.unlock(); } - else if (stall_counts.find(std::this_thread::get_id()) != stall_counts.end()) { + else if (found_stall_counts) { + stall_mutex.lock(); stall_counts[std::this_thread::get_id()] = 0; + stall_mutex.unlock(); } return 0; }''' @@ -1158,6 +1169,11 @@ def generate_runtime(llvm_instr,addr_list,pes): emit_stall = ''' int emit_stall() { std::thread::id thread_id = std::this_thread::get_id(); + + stall_mutex.lock(); + bool found_stall_counts = stall_counts.find(std::this_thread::get_id()) != stall_counts.end(); + stall_mutex.unlock(); + if (is_log_open()) { stall_mutex.lock(); int stalls = stall_counts[thread_id]; @@ -1175,8 +1191,10 @@ def generate_runtime(llvm_instr,addr_list,pes): emit_stall += ''' } } - else if (stall_counts.find(std::this_thread::get_id()) != stall_counts.end()) { + else if (found_stall_counts) { + stall_mutex.lock(); stall_counts[std::this_thread::get_id()] = 0; + stall_mutex.unlock(); } return 0; }\n''' @@ -1188,14 +1206,12 @@ def generate_runtime(llvm_instr,addr_list,pes): func += """ std::string fname = "./traces/pe_" + std::to_string(tid) + ".trace"; log_mutex.lock(); traceFiles.insert(std::make_pair(tid, std::make_shared(fname, std::ios::out))); - log_mutex.unlock(); if (!traceFiles.at(tid)) { printf("[ERROR] couldn't open trace file to write: %s\\n", fname.c_str()); exit(1); }\n """ - func +=""" log_mutex.lock(); - trace_bools[std::this_thread::get_id()] = true; + func +=""" trace_bools[std::this_thread::get_id()] = true; log_mutex.unlock();\n""" func += """ stall_mutex.lock(); @@ -1211,11 +1227,11 @@ def generate_runtime(llvm_instr,addr_list,pes): if (trace_bools.find(std::this_thread::get_id()) == trace_bools.end()) { trace_bools[std::this_thread::get_id()] = false; } - log_mutex.unlock(); + log_mutex.unlock(); __register_entry(tid, "EOF\\n"); - traceFiles.at(tid)->close(); log_mutex.lock(); + traceFiles.at(tid)->close(); traceFiles.erase(tid); log_mutex.unlock(); }""") @@ -1295,7 +1311,7 @@ def generate_runtime(llvm_instr,addr_list,pes): print(emit_store); increment_stalls = """ -int increment_stalls(int num_stalls, int num_deps = 0, void *dep1 = NULL, void *dep2 = NULL, void *dep3 = NULL) { +int increment_stalls(int num_stalls, int num_deps, void *dep1 = NULL, void *dep2 = NULL, void *dep3 = NULL) { if (num_deps > 0) { if (is_log_open()) { @@ -1317,8 +1333,13 @@ def generate_runtime(llvm_instr,addr_list,pes): """ emit_stall_with_deps = """ -int emit_stall_with_deps(int num_stalls, int num_deps = 0, void *dep1, void *dep2, void *dep3) { +int emit_stall_with_deps(int num_stalls, int num_deps, void *dep1, void *dep2, void *dep3) { std::thread::id thread_id = std::this_thread::get_id(); + + stall_mutex.lock(); + bool found_stall_counts = stall_counts.find(std::this_thread::get_id()) != stall_counts.end(); + stall_mutex.unlock(); + if (is_log_open()) { stall_mutex.lock(); int stalls = stall_counts[thread_id]; @@ -1354,8 +1375,10 @@ def generate_runtime(llvm_instr,addr_list,pes): emit_stall_with_deps += """ } } - else if (stall_counts.find(std::this_thread::get_id()) != stall_counts.end()) { + else if (found_stall_counts) { + stall_mutex.lock(); stall_counts[std::this_thread::get_id()] = 0; + stall_mutex.unlock(); } return 0; }"""