forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
init_flatbuffer_module.cpp
131 lines (123 loc) · 4.69 KB
/
init_flatbuffer_module.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include <torch/csrc/python_headers.h>
#include <libshm.h>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <new>
#include <string>
#include <pybind11/detail/common.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/csrc/utils/pybind.h>
#include <Python.h> // NOLINT
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/python/module_python.h>
#include <torch/csrc/jit/python/python_ivalue.h>
#include <torch/csrc/jit/python/python_sugared_value.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/import.h>
// Short alias for the pybind11 namespace, used by the bindings in this file.
namespace py = pybind11;
// Alignment (in bytes) that copyStr() uses for the buffers it allocates before
// handing serialized data to the flatbuffer parsers.
using torch::jit::kFlatbufferDataAlignmentBytes;
/// Copies `bytes` into a newly allocated heap buffer whose start address is
/// aligned to kFlatbufferDataAlignmentBytes.
///
/// The allocation size is bytes.size() rounded up to the next multiple of the
/// alignment that is strictly greater than bytes.size(). It is therefore never
/// zero and is always a multiple of the alignment, which aligned_alloc
/// requires. Only the first bytes.size() bytes carry payload; the padding that
/// follows is zero-filled so the whole buffer has a defined value.
///
/// @param bytes  Data to copy; may be empty.
/// @return Shared owner of the aligned buffer. Its deleter matches the
///         platform allocator that produced the memory.
/// @throws std::bad_alloc if the aligned allocation fails.
static std::shared_ptr<char> copyStr(const std::string& bytes) {
  const size_t size = (bytes.size() / kFlatbufferDataAlignmentBytes + 1) *
      kFlatbufferDataAlignmentBytes;

  // Obtain the raw aligned block first so that a failed allocation can be
  // detected in one place, before anything writes through the pointer.
#ifdef _WIN32
  void* raw = _aligned_malloc(size, kFlatbufferDataAlignmentBytes);
  const auto release = [](char* ptr) { _aligned_free(ptr); };
#elif defined(__APPLE__)
  void* raw = nullptr;
  // posix_memalign reports failure through its return code; the output
  // pointer is not guaranteed to be meaningful in that case.
  if (::posix_memalign(&raw, kFlatbufferDataAlignmentBytes, size) != 0) {
    raw = nullptr;
  }
  const auto release = [](char* ptr) { free(ptr); };
#else
  void* raw = aligned_alloc(kFlatbufferDataAlignmentBytes, size);
  const auto release = [](char* ptr) { free(ptr); };
#endif

  if (raw == nullptr) {
    throw std::bad_alloc();
  }

  std::shared_ptr<char> bytes_copy(static_cast<char*>(raw), release);
  std::memcpy(bytes_copy.get(), bytes.data(), bytes.size());
  // size > bytes.size() always holds, so there is at least one padding byte.
  std::memset(bytes_copy.get() + bytes.size(), 0, size - bytes.size());
  return bytes_copy;
}
/// Creates the `torch._C_flatbuffer` Python module and registers the
/// flatbuffer load/save bindings on it.
///
/// Registered functions:
///   _load_mobile_module_from_file / _load_mobile_module_from_bytes
///   _load_jit_module_from_file    / _load_jit_module_from_bytes
///   _save_mobile_module           / _save_jit_module
///   _save_mobile_module_to_bytes  / _save_jit_module_to_bytes
///   _get_module_info_from_flatbuffer
///
/// @return The new module object, or nullptr if PyModule_Create failed.
extern "C"
#ifdef _WIN32
    __declspec(dllexport)
#endif
    PyObject* initModuleFlatbuffer() {
  using namespace torch::jit;

  // The module definition is static and stores a pointer to the method table,
  // so the table needs static storage too. A stack-local array would leave
  // that pointer dangling once this function returns.
  static PyMethodDef m[] = {{nullptr, nullptr, 0, nullptr}}; // NOLINT
  static struct PyModuleDef torchmodule = {
      PyModuleDef_HEAD_INIT,
      "torch._C_flatbuffer",
      nullptr,
      -1,
      m,
  }; // NOLINT

  PyObject* module = PyModule_Create(&torchmodule);
  if (module == nullptr) {
    // Module creation failed; hand the failure back to the caller instead of
    // wrapping and dereferencing a null handle below.
    return nullptr;
  }
  auto pym = py::handle(module).cast<py::module>();

  // ---- Loading -------------------------------------------------------------
  pym.def("_load_mobile_module_from_file", [](const std::string& filename) {
    return torch::jit::load_mobile_module_from_file(filename);
  });
  pym.def("_load_mobile_module_from_bytes", [](const std::string& bytes) {
    // The parser is given an aligned copy of the data rather than the
    // std::string's own storage.
    auto bytes_copy = copyStr(bytes);
    return torch::jit::parse_and_initialize_mobile_module(
        bytes_copy, bytes.size());
  });
  pym.def("_load_jit_module_from_file", [](const std::string& filename) {
    ExtraFilesMap extra_files = ExtraFilesMap();
    return torch::jit::load_jit_module_from_file(filename, extra_files);
  });
  pym.def("_load_jit_module_from_bytes", [](const std::string& bytes) {
    auto bytes_copy = copyStr(bytes);
    ExtraFilesMap extra_files = ExtraFilesMap();
    return torch::jit::parse_and_initialize_jit_module(
        bytes_copy, bytes.size(), extra_files);
  });

  // ---- Saving --------------------------------------------------------------
  // NOTE: the C++ default values on `_extra_files` below are not exposed to
  // Python, because no py::arg default is registered for them. Python callers
  // have to pass every argument explicitly.
  pym.def(
      "_save_mobile_module",
      [](const torch::jit::mobile::Module& module,
         const std::string& filename,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        return torch::jit::save_mobile_module(module, filename, _extra_files);
      });
  pym.def(
      "_save_jit_module",
      [](const torch::jit::Module& module,
         const std::string& filename,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        return torch::jit::save_jit_module(module, filename, _extra_files);
      });
  pym.def(
      "_save_mobile_module_to_bytes",
      [](const torch::jit::mobile::Module& module,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        auto detached_buffer =
            torch::jit::save_mobile_module_to_bytes(module, _extra_files);
        // py::bytes copies the serialized data into a Python bytes object.
        return py::bytes(
            reinterpret_cast<char*>(detached_buffer->data()),
            detached_buffer->size());
      });
  pym.def(
      "_save_jit_module_to_bytes",
      [](const torch::jit::Module& module,
         const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
        auto detached_buffer =
            torch::jit::save_jit_module_to_bytes(module, _extra_files);
        return py::bytes(
            reinterpret_cast<char*>(detached_buffer->data()),
            detached_buffer->size());
      });

  // ---- Introspection -------------------------------------------------------
  pym.def(
      "_get_module_info_from_flatbuffer", [](std::string flatbuffer_content) {
        // The argument is taken by value, so data() yields a mutable char*.
        // NOTE(review): unlike the *_from_bytes loaders, this passes the
        // string's own storage rather than an aligned copy from copyStr();
        // confirm the callee does not need kFlatbufferDataAlignmentBytes
        // alignment.
        py::gil_scoped_acquire acquire;
        py::dict result;
        mobile::ModuleInfo minfo = torch::jit::get_module_info_from_flatbuffer(
            flatbuffer_content.data());
        result["bytecode_version"] = minfo.bytecode_version;
        result["operator_version"] = minfo.operator_version;
        result["function_names"] = minfo.function_names;
        result["type_names"] = minfo.type_names;
        result["opname_to_num_args"] = minfo.opname_to_num_args;
        return result;
      });

  return module;
}