Skip to content

Add persistent kernel cache on disk #532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions include/taco/codegen/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <string>
#include <utility>
#include <random>
#include <unordered_map>

#include "taco/target.h"
#include "taco/ir/ir.h"
Expand All @@ -16,10 +17,18 @@ namespace ir {
class Module {
public:
/// Create a module for some target
Module(Target target=getTargetFromEnvironment())
: lib_handle(nullptr), moduleFromUserSource(false), target(target) {
Module(const std::string& cacheStr = "", Target target=getTargetFromEnvironment())
: cacheStr(cacheStr), lib_handle(nullptr), moduleFromUserSource(false), target(target) {
setJITLibname();
setJITTmpdir();

if (cacheStr != "") {
std::hash<std::string> hasher;
size_t hash = hasher(cacheStr);
std::ostringstream s;
s << std::hex << hash;
cacheStrHashed = s.str();
}
}

/// Compile the source into a library, returning its full path
Expand Down Expand Up @@ -67,6 +76,8 @@ class Module {
void setSource(std::string source);

private:
std::string cacheStr;
std::string cacheStrHashed;
std::stringstream source;
std::stringstream header;
std::string libname;
Expand Down
6 changes: 6 additions & 0 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,9 @@ class IndexStmt : public util::IntrusivePtr<const IndexStmtNode> {
SubType as() {
return to<SubType>(*this);
}

/// Get string to use for this index statement in cache.
std::string getCacheString() const;
};

/// Check if two index statements are isomorphic.
Expand Down Expand Up @@ -1232,6 +1235,9 @@ class TensorVar : public util::Comparable<TensorVar> {
friend bool operator==(const TensorVar&, const TensorVar&);
friend bool operator<(const TensorVar&, const TensorVar&);

/// Get string to use for this tensor in cache.
std::string getCacheString() const;

private:
struct Content;
std::shared_ptr<Content> content;
Expand Down
63 changes: 50 additions & 13 deletions src/codegen/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,15 @@ void Module::compileToSource(string path, string prefix) {
ofstream source_file;
string file_ending = should_use_CUDA_codegen() ? ".cu" : ".c";
source_file.open(path+prefix+file_ending);
taco_uassert((bool)source_file) << "Could not open file '"
<< path+prefix+file_ending << "' for writing";
source_file << source.str();
source_file.close();

ofstream header_file;
header_file.open(path+prefix+".h");
taco_uassert((bool)header_file) << "Could not open file '"
<< path+prefix+".h" << "' for writing";
header_file << header.str();
header_file.close();
}
Expand All @@ -97,9 +101,13 @@ void writeShims(vector<Stmt> funcs, string path, string prefix) {
ofstream shims_file;
if (should_use_CUDA_codegen()) {
shims_file.open(path+prefix+"_shims.cpp");
taco_uassert((bool)shims_file) << "Could not open file '"
<< path+prefix+"_shims.cpp" << "' for writing";
}
else {
shims_file.open(path+prefix+".c", ios::app);
taco_uassert((bool)shims_file) << "Could not open file '"
<< path+prefix+".c" << "' for writing";
}
shims_file << "#include \"" << path << prefix << ".h\"\n";
shims_file << shims.str();
Expand All @@ -109,7 +117,14 @@ void writeShims(vector<Stmt> funcs, string path, string prefix) {
} // anonymous namespace

string Module::compile() {
string prefix = tmpdir+libname;
string cache_dir = util::getFromEnv("TACO_CACHE_DIR", "");
bool use_cache = cache_dir != "" && cacheStr != "";
if (use_cache && cache_dir.back() != '/') {
cache_dir += '/';
}
string dir = use_cache ? cache_dir : tmpdir;
string name = use_cache ? cacheStrHashed : libname;
string prefix = dir+name;
string fullpath = prefix + ".so";

string cc;
Expand Down Expand Up @@ -140,21 +155,43 @@ string Module::compile() {
file_ending = ".c";
shims_file = "";
}

string source_file;
string object_file;
if (use_cache) {
source_file = cache_dir + name + file_ending;
object_file = cache_dir + name + ".so";
} else {
source_file = prefix + file_ending;
object_file = fullpath;
}

string cmd = cc + " " + cflags + " " +
prefix + file_ending + " " + shims_file + " " +
"-o " + fullpath + " -lm";
source_file + " " + shims_file + " " +
"-o " + object_file + " -lm";

// open the output file & write out the source
compileToSource(tmpdir, libname);

// write out the shims
writeShims(funcs, tmpdir, libname);

// now compile it
int err = system(cmd.data());
taco_uassert(err == 0) << "Compilation command failed:\n" << cmd
<< "\nreturned " << err;
bool cached = false;
if (use_cache) {
// first check if this file already exists in cache
ifstream cached_source_file(source_file);
ifstream cached_header_file(cache_dir + name + ".h");
ifstream cached_object_file(object_file);
// only run codegen if the files don't already exist
cached = cached_source_file.good() && cached_header_file.good() && cached_object_file.good();
}

if (!cached) {
// open the output file & write out the source
compileToSource(dir, name);

// write out the shims
writeShims(funcs, dir, name);

// now compile it
int err = system(cmd.data());
taco_uassert(err == 0) << "Compilation command failed:\n" << cmd
<< "\nreturned " << err;
}

// use dlsym() to open the compiled library
if (lib_handle) {
Expand Down
16 changes: 16 additions & 0 deletions src/index_notation/index_notation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2074,6 +2074,15 @@ IndexStmt IndexStmt::wsaccel(TensorVar& ws, bool shouldAccel, const std::vector<
return *this;
}

std::string IndexStmt::getCacheString() const {
std::ostringstream s;
for (auto& var : getTensorVars(*this)) {
s << var.getCacheString() << "|";
}
s << *this;
return s.str();
}

std::ostream& operator<<(std::ostream& os, const IndexStmt& expr) {
if (!expr.defined()) return os << "IndexStmt()";
IndexNotationPrinter printer(os);
Expand Down Expand Up @@ -2694,6 +2703,13 @@ std::ostream& operator<<(std::ostream& os, const TensorVar& var) {
return os << var.getName() << " : " << var.getType();
}

string TensorVar::getCacheString() const {
ostringstream s;
s << getName() << ":";
s << getType().getDataType() << ";";
s << getFormat();
return s.str();
}

static bool isValid(Assignment assignment, string* reason) {
if (reason == nullptr) {
Expand Down
2 changes: 1 addition & 1 deletion src/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ void TensorBase::compile(taco::IndexStmt stmt, bool assembleWhileCompute) {
// If we have to recompile the kernel, we need to create a new Module. Since
// the module we are holding on to could have been retrieved from the cache,
// we can't modify it.
content->module = make_shared<Module>();
content->module = make_shared<Module>(stmtToCompile.getCacheString());
content->module->addFunction(content->assembleFunc);
content->module->addFunction(content->computeFunc);
content->module->compile();
Expand Down