Skip to content

Commit

Permalink
Changes from internal repo
Browse files Browse the repository at this point in the history
  • Loading branch information
subhankarpal committed Sep 30, 2020
1 parent e9d32a2 commit 27bebb2
Showing 1 changed file with 43 additions and 20 deletions.
63 changes: 43 additions & 20 deletions scripts/generate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,15 @@ def generate_runtime(llvm_instr,addr_list,pes):
"""
print(includes)

def_mutex = """
std::mutex log_mutex;
std::mutex stall_mutex;\n"""

for pe in pes:
def_mutex += "std::mutex " + pe.name + "_mutex;\n"

print(def_mutex)

def_maps = """static std::map<std::thread::id, bool> trace_bools;
static std::map<std::thread::id, int> stall_counts;\n"""

Expand Down Expand Up @@ -1014,30 +1023,25 @@ def generate_runtime(llvm_instr,addr_list,pes):
#print(init_memtest)

register = "inline void __register_entry(unsigned id, std::string e) {\n"
register += " log_mutex.lock();\n"
register += " (*traceFiles.at(id)) << e;\n"
register += " traceFiles.at(id)->flush();\n"
register += " log_mutex.unlock();\n"
register += "}"
print(register)

write_memtest = """void write_hetsim() {\n"""
write_memtest += " log_mutex.lock();\n"
write_memtest += " int size = traceFiles.size();\n"
write_memtest += " for (int i = 0; i < size; ++i) {\n"
write_memtest += " __register_entry(i, \"EOF\");\n"""
write_memtest += " traceFiles.at(i)->flush();\n"
write_memtest += " traceFiles.at(i)->close();\n"
write_memtest += " }\n"
write_memtest += " log_mutex.unlock();\n"
write_memtest += "}\n"
print(write_memtest)

def_mutex = """
std::mutex log_mutex;
std::mutex stall_mutex;\n"""

for pe in pes:
def_mutex += "std::mutex " + pe.name + "_mutex;\n"

print(def_mutex)

print("""
std::string __LD(void *addr, std::vector<void *> deps,
Expand Down Expand Up @@ -1136,7 +1140,12 @@ def generate_runtime(llvm_instr,addr_list,pes):
print(stall_func)

increment_stalls = '''
int increment_stalls(int num_stalls, int num_deps = 0, void *dep1 = NULL, void *dep2 = NULL, void *dep3 = NULL) {
int increment_stalls(int num_stalls, int num_deps, void *dep1, void *dep2, void *dep3) {
stall_mutex.lock();
bool found_stall_counts = stall_counts.find(std::this_thread::get_id()) != stall_counts.end();
stall_mutex.unlock();
if (num_deps > 0) {
if (is_log_open()) {
emit_stall();
Expand All @@ -1148,8 +1157,10 @@ def generate_runtime(llvm_instr,addr_list,pes):
stall_counts[std::this_thread::get_id()] += num_stalls;
stall_mutex.unlock();
}
else if (stall_counts.find(std::this_thread::get_id()) != stall_counts.end()) {
else if (found_stall_counts) {
stall_mutex.lock();
stall_counts[std::this_thread::get_id()] = 0;
stall_mutex.unlock();
}
return 0;
}'''
Expand All @@ -1158,6 +1169,11 @@ def generate_runtime(llvm_instr,addr_list,pes):
emit_stall = '''
int emit_stall() {
std::thread::id thread_id = std::this_thread::get_id();
stall_mutex.lock();
bool found_stall_counts = stall_counts.find(std::this_thread::get_id()) != stall_counts.end();
stall_mutex.unlock();
if (is_log_open()) {
stall_mutex.lock();
int stalls = stall_counts[thread_id];
Expand All @@ -1175,8 +1191,10 @@ def generate_runtime(llvm_instr,addr_list,pes):
emit_stall += '''
}
}
else if (stall_counts.find(std::this_thread::get_id()) != stall_counts.end()) {
else if (found_stall_counts) {
stall_mutex.lock();
stall_counts[std::this_thread::get_id()] = 0;
stall_mutex.unlock();
}
return 0;
}\n'''
Expand All @@ -1188,14 +1206,12 @@ def generate_runtime(llvm_instr,addr_list,pes):
func += """ std::string fname = "./traces/pe_" + std::to_string(tid) + ".trace";
log_mutex.lock();
traceFiles.insert(std::make_pair(tid, std::make_shared<std::ofstream>(fname, std::ios::out)));
log_mutex.unlock();
if (!traceFiles.at(tid)) {
printf("[ERROR] couldn't open trace file to write: %s\\n", fname.c_str());
exit(1);
}\n
"""
func +=""" log_mutex.lock();
trace_bools[std::this_thread::get_id()] = true;
func +=""" trace_bools[std::this_thread::get_id()] = true;
log_mutex.unlock();\n"""

func += """ stall_mutex.lock();
Expand All @@ -1211,11 +1227,11 @@ def generate_runtime(llvm_instr,addr_list,pes):
if (trace_bools.find(std::this_thread::get_id()) == trace_bools.end()) {
trace_bools[std::this_thread::get_id()] = false;
}
log_mutex.unlock();
log_mutex.unlock();
__register_entry(tid, "EOF\\n");
traceFiles.at(tid)->close();
log_mutex.lock();
traceFiles.at(tid)->close();
traceFiles.erase(tid);
log_mutex.unlock();
}""")
Expand Down Expand Up @@ -1295,7 +1311,7 @@ def generate_runtime(llvm_instr,addr_list,pes):
print(emit_store);

increment_stalls = """
int increment_stalls(int num_stalls, int num_deps = 0, void *dep1 = NULL, void *dep2 = NULL, void *dep3 = NULL) {
int increment_stalls(int num_stalls, int num_deps, void *dep1 = NULL, void *dep2 = NULL, void *dep3 = NULL) {
if (num_deps > 0) {
if (is_log_open()) {
Expand All @@ -1317,8 +1333,13 @@ def generate_runtime(llvm_instr,addr_list,pes):
"""
emit_stall_with_deps = """
int emit_stall_with_deps(int num_stalls, int num_deps = 0, void *dep1, void *dep2, void *dep3) {
int emit_stall_with_deps(int num_stalls, int num_deps, void *dep1, void *dep2, void *dep3) {
std::thread::id thread_id = std::this_thread::get_id();
stall_mutex.lock();
bool found_stall_counts = stall_counts.find(std::this_thread::get_id()) != stall_counts.end();
stall_mutex.unlock();
if (is_log_open()) {
stall_mutex.lock();
int stalls = stall_counts[thread_id];
Expand Down Expand Up @@ -1354,8 +1375,10 @@ def generate_runtime(llvm_instr,addr_list,pes):
emit_stall_with_deps += """
}
}
else if (stall_counts.find(std::this_thread::get_id()) != stall_counts.end()) {
else if (found_stall_counts) {
stall_mutex.lock();
stall_counts[std::this_thread::get_id()] = 0;
stall_mutex.unlock();
}
return 0;
}"""
Expand Down

0 comments on commit 27bebb2

Please sign in to comment.