-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sourcery refactored relax branch #1
base: relax
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,9 +41,9 @@ def del_dir(target: Union[Path, str], only_if_empty: bool = False): | |
p.chmod(0o666) | ||
if p.is_dir(): | ||
p.rmdir() | ||
elif only_if_empty: | ||
raise RuntimeError(f"{p.parent} is not empty!") | ||
else: | ||
if only_if_empty: | ||
raise RuntimeError(f"{p.parent} is not empty!") | ||
p.unlink() | ||
target.rmdir() | ||
|
||
|
@@ -100,12 +100,12 @@ def main(model_str, output_path): | |
with tvm.transform.PassContext(opt_level=3): | ||
graph, lib, params = relay.build(net, tvm.target.Target(target, target_host), params=params) | ||
print("dumping lib...") | ||
lib.export_library(output_path_str + "/" + "deploy_lib_cpu.so", ndk.create_shared) | ||
lib.export_library(f"{output_path_str}/deploy_lib_cpu.so", ndk.create_shared) | ||
print("dumping graph...") | ||
with open(output_path_str + "/" + "deploy_graph.json", "w") as f: | ||
with open(f"{output_path_str}/deploy_graph.json", "w") as f: | ||
f.write(graph) | ||
print("dumping params...") | ||
with open(output_path_str + "/" + "deploy_param.params", "wb") as f: | ||
with open(f"{output_path_str}/deploy_param.params", "wb") as f: | ||
Comment on lines
-103
to
+108
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
f.write(tvm.runtime.save_param_dict(params)) | ||
print("dumping labels...") | ||
synset_url = "".join( | ||
|
@@ -116,11 +116,11 @@ def main(model_str, output_path): | |
"imagenet1000_clsid_to_human.txt", | ||
] | ||
) | ||
synset_path = output_path_str + "/image_net_labels" | ||
download(synset_url, output_path_str + "/image_net_labels") | ||
synset_path = f"{output_path_str}/image_net_labels" | ||
download(synset_url, f"{output_path_str}/image_net_labels") | ||
with open(synset_path) as fi: | ||
synset = eval(fi.read()) | ||
with open(output_path_str + "/image_net_labels.json", "w") as fo: | ||
with open(f"{output_path_str}/image_net_labels.json", "w") as fo: | ||
json.dump(synset, fo, indent=4) | ||
os.remove(synset_path) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ | |
Use "android" as the key if you wish to avoid modifying this script. | ||
""" | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
|
||
import tvm | ||
from tvm import te | ||
import os | ||
|
@@ -36,7 +37,7 @@ | |
# Change target configuration. | ||
# Run `adb shell cat /proc/cpuinfo` to find the arch. | ||
arch = "arm64" | ||
target = "llvm -mtriple=%s-linux-android" % arch | ||
target = f"llvm -mtriple={arch}-linux-android" | ||
|
||
# whether enable to execute test on OpenCL target | ||
test_opencl = False | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,10 +46,10 @@ def evaluate_network(network, target, target_host, repeat): | |
if "android" in str(target): | ||
from tvm.contrib import ndk | ||
|
||
filename = "%s.so" % network | ||
filename = f"{network}.so" | ||
lib.export_library(tmp.relpath(filename), ndk.create_shared) | ||
else: | ||
filename = "%s.tar" % network | ||
filename = f"{network}.tar" | ||
Comment on lines
-49
to
+52
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
lib.export_library(tmp.relpath(filename)) | ||
|
||
# upload library and params | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -103,7 +103,10 @@ def benchmark(network, target): | |
else: | ||
networks = [args.network] | ||
|
||
target = tvm.target.Target("%s -device=%s -model=%s" % (args.target, args.device, args.model)) | ||
target = tvm.target.Target( | ||
f"{args.target} -device={args.device} -model={args.model}" | ||
) | ||
|
||
Comment on lines
-106
to
+109
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
|
||
print("--------------------------------------------------") | ||
print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)")) | ||
|
@@ -112,7 +115,7 @@ def benchmark(network, target): | |
if args.thread == 1: | ||
benchmark(network, target) | ||
else: | ||
threads = list() | ||
threads = [] | ||
for n in range(args.thread): | ||
thread = threading.Thread( | ||
target=benchmark, args=([network, target]), name="thread%d" % n | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,10 +46,10 @@ def evaluate_network(network, target, target_host, dtype, repeat): | |
if "android" in str(target) or "android" in str(target_host): | ||
from tvm.contrib import ndk | ||
|
||
filename = "%s.so" % network | ||
filename = f"{network}.so" | ||
lib.export_library(tmp.relpath(filename), ndk.create_shared) | ||
else: | ||
filename = "%s.tar" % network | ||
filename = f"{network}.tar" | ||
Comment on lines
-49
to
+52
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
lib.export_library(tmp.relpath(filename)) | ||
|
||
# upload library and params | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -84,7 +84,7 @@ def get_network(name, batch_size, dtype="float32"): | |
) | ||
net = tvm.IRModule.from_expr(net) | ||
else: | ||
raise ValueError("Unsupported network: " + name) | ||
raise ValueError(f"Unsupported network: {name}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
return net, params, input_shape, output_shape | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,9 +28,9 @@ | |
def load_lib(): | ||
"""Load library, the functions will be registered into TVM""" | ||
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) | ||
# load in as global so the global extern symbol is visible to other dll. | ||
lib = ctypes.CDLL(os.path.join(curr_path, "../../lib/libtvm_ext.so"), ctypes.RTLD_GLOBAL) | ||
return lib | ||
return ctypes.CDLL( | ||
os.path.join(curr_path, "../../lib/libtvm_ext.so"), ctypes.RTLD_GLOBAL | ||
) | ||
Comment on lines
-31
to
+33
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
This removes the following comments ( why? ):
|
||
|
||
|
||
_LIB = load_lib() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -48,12 +48,9 @@ | |
team_id = args.team_id | ||
tvm_build_dir = args.tvm_build_dir | ||
|
||
fi = open("tvmrpc.xcodeproj/project.pbxproj") | ||
proj_config = fi.read() | ||
fi.close() | ||
|
||
with open("tvmrpc.xcodeproj/project.pbxproj") as fi: | ||
proj_config = fi.read() | ||
proj_config = proj_config.replace(default_team_id, team_id) | ||
proj_config = proj_config.replace(default_tvm_build_dir, tvm_build_dir) | ||
fo = open("tvmrpc.xcodeproj/project.pbxproj", "w") | ||
fo.write(proj_config) | ||
fo.close() | ||
with open("tvmrpc.xcodeproj/project.pbxproj", "w") as fo: | ||
fo.write(proj_config) | ||
Comment on lines
-51
to
+56
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,7 +39,7 @@ | |
# sdk = "iphonesimulator" | ||
arch = "arm64" | ||
sdk = "iphoneos" | ||
target_host = "llvm -mtriple=%s-apple-darwin" % arch | ||
target_host = f"llvm -mtriple={arch}-apple-darwin" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
|
||
MODES = {"proxy": rpc.connect, "tracker": rpc.connect_tracker, "standalone": rpc.connect} | ||
|
||
|
@@ -106,10 +106,7 @@ def run(mod, target): | |
remote = MODES[mode](host, port, key=key) | ||
remote.upload(path_dso) | ||
|
||
if target == "metal": | ||
dev = remote.metal(0) | ||
else: | ||
dev = remote.cpu(0) | ||
dev = remote.metal(0) if target == "metal" else remote.cpu(0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
lib = remote.load_module("deploy.dylib") | ||
m = graph_executor.GraphModule(lib["default"](dev)) | ||
|
||
|
@@ -174,11 +171,10 @@ def annotate(func, compiler): | |
"--mode", | ||
type=str, | ||
default="tracker", | ||
help="type of RPC connection (default: tracker), possible values: {}".format( | ||
", ".join(MODES.keys()) | ||
), | ||
help=f'type of RPC connection (default: tracker), possible values: {", ".join(MODES.keys())}', | ||
) | ||
|
||
Comment on lines
-177
to
176
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
|
||
args = parser.parse_args() | ||
assert args.mode in MODES.keys() | ||
test_mobilenet(args.host, args.port, args.key, args.mode) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
And configure the proxy host field as commented. | ||
""" | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
|
||
import tvm | ||
from tvm import te | ||
import os | ||
|
@@ -33,7 +34,7 @@ | |
# Change target configuration, this is setting for iphone6s | ||
arch = "arm64" | ||
sdk = "iphoneos" | ||
target = "llvm -mtriple=%s-apple-darwin" % arch | ||
target = f"llvm -mtriple={arch}-apple-darwin" | ||
|
||
MODES = {"proxy": rpc.connect, "tracker": rpc.connect_tracker, "standalone": rpc.connect} | ||
|
||
|
@@ -105,11 +106,10 @@ def test_rpc_module(host, port, key, mode): | |
"--mode", | ||
type=str, | ||
default="tracker", | ||
help="type of RPC connection (default: tracker), possible values: {}".format( | ||
", ".join(MODES.keys()) | ||
), | ||
help=f'type of RPC connection (default: tracker), possible values: {", ".join(MODES.keys())}', | ||
) | ||
|
||
Comment on lines
-108
to
111
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
|
||
args = parser.parse_args() | ||
assert args.mode in MODES.keys() | ||
test_rpc_module(args.host, args.port, args.key, args.mode) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,8 +57,6 @@ | |
|
||
class BoardAutodetectFailed(Exception): | ||
"""Raised when no attached hardware is found matching the requested board""" | ||
|
||
|
||
Comment on lines
-60
to
-61
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
PROJECT_TYPES = ["example_project", "host_driven"] | ||
|
||
PROJECT_OPTIONS = [ | ||
|
@@ -71,14 +69,14 @@ class BoardAutodetectFailed(Exception): | |
), | ||
server.ProjectOption( | ||
"arduino_cli_cmd", | ||
required=( | ||
required=None | ||
if ARDUINO_CLI_CMD | ||
else ["generate_project", "build", "flash", "open_transport"], | ||
optional=( | ||
["generate_project", "build", "flash", "open_transport"] | ||
if not ARDUINO_CLI_CMD | ||
if ARDUINO_CLI_CMD | ||
else None | ||
), | ||
optional=( | ||
["generate_project", "build", "flash", "open_transport"] if ARDUINO_CLI_CMD else None | ||
), | ||
default=ARDUINO_CLI_CMD, | ||
type="str", | ||
help="Path to the arduino-cli tool.", | ||
|
@@ -208,19 +206,18 @@ def _template_model_header(self, source_dir, metadata): | |
with open(source_dir / "model.h", "r") as f: | ||
model_h_template = Template(f.read()) | ||
|
||
all_module_names = [] | ||
for name in metadata["modules"].keys(): | ||
all_module_names.append(name) | ||
|
||
all_module_names = list(metadata["modules"].keys()) | ||
assert all( | ||
metadata["modules"][mod_name]["style"] == "full-model" for mod_name in all_module_names | ||
), "when generating AOT, expect only full-model Model Library Format" | ||
|
||
workspace_size_bytes = 0 | ||
for mod_name in all_module_names: | ||
workspace_size_bytes += metadata["modules"][mod_name]["memory"]["functions"]["main"][0][ | ||
workspace_size_bytes = sum( | ||
metadata["modules"][mod_name]["memory"]["functions"]["main"][0][ | ||
"workspace_size_bytes" | ||
] | ||
for mod_name in all_module_names | ||
) | ||
|
||
Comment on lines
-211
to
+220
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
template_values = { | ||
"workspace_size_bytes": workspace_size_bytes, | ||
} | ||
|
@@ -261,16 +258,16 @@ def _convert_includes(self, project_dir, source_dir): | |
with filename.open("wb") as dst_file: | ||
for line in lines: | ||
line_str = str(line, "utf-8") | ||
# Check if line has an include | ||
result = re.search(r"#include\s*[<\"]([^>]*)[>\"]", line_str) | ||
if not result: | ||
dst_file.write(line) | ||
else: | ||
if result := re.search( | ||
r"#include\s*[<\"]([^>]*)[>\"]", line_str | ||
): | ||
new_include = self._find_modified_include_path( | ||
project_dir, filename, result.groups()[0] | ||
) | ||
updated_line = f'#include "{new_include}"\n' | ||
dst_file.write(updated_line.encode("utf-8")) | ||
else: | ||
dst_file.write(line) | ||
Comment on lines
-264
to
+270
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
This removes the following comments ( why? ):
|
||
|
||
# Most of the files we used to be able to point to directly are under "src/standalone_crt/include/". | ||
# Howver, crt_config.h lives under "src/standalone_crt/crt_config/", and more exceptions might | ||
|
@@ -360,7 +357,7 @@ def _get_platform_version(self, arduino_cli_path: str) -> float: | |
version_output = subprocess.run( | ||
[arduino_cli_path, "version"], check=True, stdout=subprocess.PIPE | ||
).stdout.decode("utf-8") | ||
str_version = re.search(r"Version: ([\.0-9]*)", version_output).group(1) | ||
str_version = re.search(r"Version: ([\.0-9]*)", version_output)[1] | ||
Comment on lines
-363
to
+360
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
# Using too low a version should raise an error. Note that naively | ||
# comparing floats will fail here: 0.7 > 0.21, but 0.21 is a higher | ||
|
@@ -428,7 +425,7 @@ def _parse_connected_boards(self, tabular_str): | |
column_regex = r"\s*|".join(self.POSSIBLE_BOARD_LIST_HEADERS) + r"\s*" | ||
str_rows = tabular_str.split("\n") | ||
column_headers = list(re.finditer(column_regex, str_rows[0])) | ||
assert len(column_headers) > 0 | ||
assert column_headers | ||
Comment on lines
-431
to
+428
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
for str_row in str_rows[1:]: | ||
if not str_row.strip(): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -178,7 +178,7 @@ def test_flash(self, mock_run): | |
# Test we checked version then called upload | ||
assert mock_run.call_count == 2 | ||
assert mock_run.call_args_list[0][0] == (["arduino-cli", "version"],) | ||
assert mock_run.call_args_list[1][0][0][0:2] == ["arduino-cli", "upload"] | ||
assert mock_run.call_args_list[1][0][0][:2] == ["arduino-cli", "upload"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
mock_run.reset_mock() | ||
|
||
# Test exception raised when `arduino-cli upload` returns error code | ||
|
@@ -188,4 +188,4 @@ def test_flash(self, mock_run): | |
|
||
# Version information should be cached and not checked again | ||
mock_run.assert_called_once() | ||
assert mock_run.call_args[0][0][0:2] == ["arduino-cli", "upload"] | ||
assert mock_run.call_args[0][0][:2] == ["arduino-cli", "upload"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,7 @@ def create_header_file(name, tensor_name, tensor_data, output_path): | |
""" | ||
This function generates a header file containing the data from the numpy array provided. | ||
""" | ||
file_path = pathlib.Path(f"{output_path}/" + name).resolve() | ||
file_path = pathlib.Path(f"{output_path}/{name}").resolve() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
# Create header file with npy_data as a C array | ||
raw_path = file_path.with_suffix(".h").resolve() | ||
with open(raw_path, "w") as header_file: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,7 @@ def create_header_file(name, section, tensor_name, tensor_data, output_path): | |
""" | ||
This function generates a header file containing the data from the numpy array provided. | ||
""" | ||
file_path = pathlib.Path(f"{output_path}/" + name).resolve() | ||
file_path = pathlib.Path(f"{output_path}/{name}").resolve() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
# Create header file with npy_data as a C array | ||
raw_path = file_path.with_suffix(".h").resolve() | ||
with open(raw_path, "w") as header_file: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,7 @@ def create_labels_header(labels_file, section, output_path): | |
with open(file_path, "w") as header_file: | ||
header_file.write(f'char* labels[] __attribute__((section("{section}"), aligned(16))) = {{') | ||
|
||
for _, label in enumerate(labels): | ||
for label in labels: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
header_file.write(f'"{label.rstrip()}",') | ||
|
||
header_file.write("};\n") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function
del_dir
refactored with the following changes:merge-else-if-into-elif
)reintroduce-else
)