Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sourcery refactored relax branch #1

Open
wants to merge 1 commit into
base: relax
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions apps/android_camera/models/prepare_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ def del_dir(target: Union[Path, str], only_if_empty: bool = False):
p.chmod(0o666)
if p.is_dir():
p.rmdir()
elif only_if_empty:
raise RuntimeError(f"{p.parent} is not empty!")
else:
if only_if_empty:
raise RuntimeError(f"{p.parent} is not empty!")
Comment on lines +44 to -46
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function del_dir refactored with the following changes:

p.unlink()
target.rmdir()

Expand Down Expand Up @@ -100,12 +100,12 @@ def main(model_str, output_path):
with tvm.transform.PassContext(opt_level=3):
graph, lib, params = relay.build(net, tvm.target.Target(target, target_host), params=params)
print("dumping lib...")
lib.export_library(output_path_str + "/" + "deploy_lib_cpu.so", ndk.create_shared)
lib.export_library(f"{output_path_str}/deploy_lib_cpu.so", ndk.create_shared)
print("dumping graph...")
with open(output_path_str + "/" + "deploy_graph.json", "w") as f:
with open(f"{output_path_str}/deploy_graph.json", "w") as f:
f.write(graph)
print("dumping params...")
with open(output_path_str + "/" + "deploy_param.params", "wb") as f:
with open(f"{output_path_str}/deploy_param.params", "wb") as f:
Comment on lines -103 to +108
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function main refactored with the following changes:

f.write(tvm.runtime.save_param_dict(params))
print("dumping labels...")
synset_url = "".join(
Expand All @@ -116,11 +116,11 @@ def main(model_str, output_path):
"imagenet1000_clsid_to_human.txt",
]
)
synset_path = output_path_str + "/image_net_labels"
download(synset_url, output_path_str + "/image_net_labels")
synset_path = f"{output_path_str}/image_net_labels"
download(synset_url, f"{output_path_str}/image_net_labels")
with open(synset_path) as fi:
synset = eval(fi.read())
with open(output_path_str + "/image_net_labels.json", "w") as fo:
with open(f"{output_path_str}/image_net_labels.json", "w") as fo:
json.dump(synset, fo, indent=4)
os.remove(synset_path)

Expand Down
3 changes: 2 additions & 1 deletion apps/android_rpc/tests/android_rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Use "android" as the key if you wish to avoid modifying this script.
"""

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 39-39 refactored with the following changes:


import tvm
from tvm import te
import os
Expand All @@ -36,7 +37,7 @@
# Change target configuration.
# Run `adb shell cat /proc/cpuinfo` to find the arch.
arch = "arm64"
target = "llvm -mtriple=%s-linux-android" % arch
target = f"llvm -mtriple={arch}-linux-android"

# whether enable to execute test on OpenCL target
test_opencl = False
Expand Down
4 changes: 2 additions & 2 deletions apps/benchmark/arm_cpu_imagenet_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def evaluate_network(network, target, target_host, repeat):
if "android" in str(target):
from tvm.contrib import ndk

filename = "%s.so" % network
filename = f"{network}.so"
lib.export_library(tmp.relpath(filename), ndk.create_shared)
else:
filename = "%s.tar" % network
filename = f"{network}.tar"
Comment on lines -49 to +52
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function evaluate_network refactored with the following changes:

lib.export_library(tmp.relpath(filename))

# upload library and params
Expand Down
7 changes: 5 additions & 2 deletions apps/benchmark/gpu_imagenet_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ def benchmark(network, target):
else:
networks = [args.network]

target = tvm.target.Target("%s -device=%s -model=%s" % (args.target, args.device, args.model))
target = tvm.target.Target(
f"{args.target} -device={args.device} -model={args.model}"
)

Comment on lines -106 to +109
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 106-115 refactored with the following changes:


print("--------------------------------------------------")
print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)"))
Expand All @@ -112,7 +115,7 @@ def benchmark(network, target):
if args.thread == 1:
benchmark(network, target)
else:
threads = list()
threads = []
for n in range(args.thread):
thread = threading.Thread(
target=benchmark, args=([network, target]), name="thread%d" % n
Expand Down
4 changes: 2 additions & 2 deletions apps/benchmark/mobile_gpu_imagenet_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def evaluate_network(network, target, target_host, dtype, repeat):
if "android" in str(target) or "android" in str(target_host):
from tvm.contrib import ndk

filename = "%s.so" % network
filename = f"{network}.so"
lib.export_library(tmp.relpath(filename), ndk.create_shared)
else:
filename = "%s.tar" % network
filename = f"{network}.tar"
Comment on lines -49 to +52
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function evaluate_network refactored with the following changes:

lib.export_library(tmp.relpath(filename))

# upload library and params
Expand Down
2 changes: 1 addition & 1 deletion apps/benchmark/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def get_network(name, batch_size, dtype="float32"):
)
net = tvm.IRModule.from_expr(net)
else:
raise ValueError("Unsupported network: " + name)
raise ValueError(f"Unsupported network: {name}")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_network refactored with the following changes:


return net, params, input_shape, output_shape

Expand Down
6 changes: 3 additions & 3 deletions apps/extension/python/tvm_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
def load_lib():
"""Load library, the functions will be registered into TVM"""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
# load in as global so the global extern symbol is visible to other dll.
lib = ctypes.CDLL(os.path.join(curr_path, "../../lib/libtvm_ext.so"), ctypes.RTLD_GLOBAL)
return lib
return ctypes.CDLL(
os.path.join(curr_path, "../../lib/libtvm_ext.so"), ctypes.RTLD_GLOBAL
)
Comment on lines -31 to +33
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function load_lib refactored with the following changes:

This removes the following comments ( why? ):

# load in as global so the global extern symbol is visible to other dll.



_LIB = load_lib()
Expand Down
11 changes: 4 additions & 7 deletions apps/ios_rpc/init_proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,9 @@
team_id = args.team_id
tvm_build_dir = args.tvm_build_dir

fi = open("tvmrpc.xcodeproj/project.pbxproj")
proj_config = fi.read()
fi.close()

with open("tvmrpc.xcodeproj/project.pbxproj") as fi:
proj_config = fi.read()
proj_config = proj_config.replace(default_team_id, team_id)
proj_config = proj_config.replace(default_tvm_build_dir, tvm_build_dir)
fo = open("tvmrpc.xcodeproj/project.pbxproj", "w")
fo.write(proj_config)
fo.close()
with open("tvmrpc.xcodeproj/project.pbxproj", "w") as fo:
fo.write(proj_config)
Comment on lines -51 to +56
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 51-59 refactored with the following changes:

12 changes: 4 additions & 8 deletions apps/ios_rpc/tests/ios_rpc_mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
# sdk = "iphonesimulator"
arch = "arm64"
sdk = "iphoneos"
target_host = "llvm -mtriple=%s-apple-darwin" % arch
target_host = f"llvm -mtriple={arch}-apple-darwin"
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 42-42 refactored with the following changes:


MODES = {"proxy": rpc.connect, "tracker": rpc.connect_tracker, "standalone": rpc.connect}

Expand Down Expand Up @@ -106,10 +106,7 @@ def run(mod, target):
remote = MODES[mode](host, port, key=key)
remote.upload(path_dso)

if target == "metal":
dev = remote.metal(0)
else:
dev = remote.cpu(0)
dev = remote.metal(0) if target == "metal" else remote.cpu(0)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function test_mobilenet refactored with the following changes:

lib = remote.load_module("deploy.dylib")
m = graph_executor.GraphModule(lib["default"](dev))

Expand Down Expand Up @@ -174,11 +171,10 @@ def annotate(func, compiler):
"--mode",
type=str,
default="tracker",
help="type of RPC connection (default: tracker), possible values: {}".format(
", ".join(MODES.keys())
),
help=f'type of RPC connection (default: tracker), possible values: {", ".join(MODES.keys())}',
)

Comment on lines -177 to 176
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 177-179 refactored with the following changes:


args = parser.parse_args()
assert args.mode in MODES.keys()
test_mobilenet(args.host, args.port, args.key, args.mode)
8 changes: 4 additions & 4 deletions apps/ios_rpc/tests/ios_rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
And configure the proxy host field as commented.
"""

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 36-36 refactored with the following changes:


import tvm
from tvm import te
import os
Expand All @@ -33,7 +34,7 @@
# Change target configuration, this is setting for iphone6s
arch = "arm64"
sdk = "iphoneos"
target = "llvm -mtriple=%s-apple-darwin" % arch
target = f"llvm -mtriple={arch}-apple-darwin"

MODES = {"proxy": rpc.connect, "tracker": rpc.connect_tracker, "standalone": rpc.connect}

Expand Down Expand Up @@ -105,11 +106,10 @@ def test_rpc_module(host, port, key, mode):
"--mode",
type=str,
default="tracker",
help="type of RPC connection (default: tracker), possible values: {}".format(
", ".join(MODES.keys())
),
help=f'type of RPC connection (default: tracker), possible values: {", ".join(MODES.keys())}',
)

Comment on lines -108 to 111
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 108-110 refactored with the following changes:


args = parser.parse_args()
assert args.mode in MODES.keys()
test_rpc_module(args.host, args.port, args.key, args.mode)
39 changes: 18 additions & 21 deletions apps/microtvm/arduino/template_project/microtvm_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@

class BoardAutodetectFailed(Exception):
"""Raised when no attached hardware is found matching the requested board"""


Comment on lines -60 to -61
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 60-80 refactored with the following changes:

PROJECT_TYPES = ["example_project", "host_driven"]

PROJECT_OPTIONS = [
Expand All @@ -71,14 +69,14 @@ class BoardAutodetectFailed(Exception):
),
server.ProjectOption(
"arduino_cli_cmd",
required=(
required=None
if ARDUINO_CLI_CMD
else ["generate_project", "build", "flash", "open_transport"],
optional=(
["generate_project", "build", "flash", "open_transport"]
if not ARDUINO_CLI_CMD
if ARDUINO_CLI_CMD
else None
),
optional=(
["generate_project", "build", "flash", "open_transport"] if ARDUINO_CLI_CMD else None
),
default=ARDUINO_CLI_CMD,
type="str",
help="Path to the arduino-cli tool.",
Expand Down Expand Up @@ -208,19 +206,18 @@ def _template_model_header(self, source_dir, metadata):
with open(source_dir / "model.h", "r") as f:
model_h_template = Template(f.read())

all_module_names = []
for name in metadata["modules"].keys():
all_module_names.append(name)

all_module_names = list(metadata["modules"].keys())
assert all(
metadata["modules"][mod_name]["style"] == "full-model" for mod_name in all_module_names
), "when generating AOT, expect only full-model Model Library Format"

workspace_size_bytes = 0
for mod_name in all_module_names:
workspace_size_bytes += metadata["modules"][mod_name]["memory"]["functions"]["main"][0][
workspace_size_bytes = sum(
metadata["modules"][mod_name]["memory"]["functions"]["main"][0][
"workspace_size_bytes"
]
for mod_name in all_module_names
)

Comment on lines -211 to +220
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function Handler._template_model_header refactored with the following changes:

template_values = {
"workspace_size_bytes": workspace_size_bytes,
}
Expand Down Expand Up @@ -261,16 +258,16 @@ def _convert_includes(self, project_dir, source_dir):
with filename.open("wb") as dst_file:
for line in lines:
line_str = str(line, "utf-8")
# Check if line has an include
result = re.search(r"#include\s*[<\"]([^>]*)[>\"]", line_str)
if not result:
dst_file.write(line)
else:
if result := re.search(
r"#include\s*[<\"]([^>]*)[>\"]", line_str
):
new_include = self._find_modified_include_path(
project_dir, filename, result.groups()[0]
)
updated_line = f'#include "{new_include}"\n'
dst_file.write(updated_line.encode("utf-8"))
else:
dst_file.write(line)
Comment on lines -264 to +270
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function Handler._convert_includes refactored with the following changes:

This removes the following comments ( why? ):

# Check if line has an include


# Most of the files we used to be able to point to directly are under "src/standalone_crt/include/".
# Howver, crt_config.h lives under "src/standalone_crt/crt_config/", and more exceptions might
Expand Down Expand Up @@ -360,7 +357,7 @@ def _get_platform_version(self, arduino_cli_path: str) -> float:
version_output = subprocess.run(
[arduino_cli_path, "version"], check=True, stdout=subprocess.PIPE
).stdout.decode("utf-8")
str_version = re.search(r"Version: ([\.0-9]*)", version_output).group(1)
str_version = re.search(r"Version: ([\.0-9]*)", version_output)[1]
Comment on lines -363 to +360
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function Handler._get_platform_version refactored with the following changes:


# Using too low a version should raise an error. Note that naively
# comparing floats will fail here: 0.7 > 0.21, but 0.21 is a higher
Expand Down Expand Up @@ -428,7 +425,7 @@ def _parse_connected_boards(self, tabular_str):
column_regex = r"\s*|".join(self.POSSIBLE_BOARD_LIST_HEADERS) + r"\s*"
str_rows = tabular_str.split("\n")
column_headers = list(re.finditer(column_regex, str_rows[0]))
assert len(column_headers) > 0
assert column_headers
Comment on lines -431 to +428
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function Handler._parse_connected_boards refactored with the following changes:


for str_row in str_rows[1:]:
if not str_row.strip():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_flash(self, mock_run):
# Test we checked version then called upload
assert mock_run.call_count == 2
assert mock_run.call_args_list[0][0] == (["arduino-cli", "version"],)
assert mock_run.call_args_list[1][0][0][0:2] == ["arduino-cli", "upload"]
assert mock_run.call_args_list[1][0][0][:2] == ["arduino-cli", "upload"]
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function TestGenerateProject.test_flash refactored with the following changes:

mock_run.reset_mock()

# Test exception raised when `arduino-cli upload` returns error code
Expand All @@ -188,4 +188,4 @@ def test_flash(self, mock_run):

# Version information should be cached and not checked again
mock_run.assert_called_once()
assert mock_run.call_args[0][0][0:2] == ["arduino-cli", "upload"]
assert mock_run.call_args[0][0][:2] == ["arduino-cli", "upload"]
2 changes: 1 addition & 1 deletion apps/microtvm/cmsisnn/convert_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def create_header_file(name, tensor_name, tensor_data, output_path):
"""
This function generates a header file containing the data from the numpy array provided.
"""
file_path = pathlib.Path(f"{output_path}/" + name).resolve()
file_path = pathlib.Path(f"{output_path}/{name}").resolve()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function create_header_file refactored with the following changes:

# Create header file with npy_data as a C array
raw_path = file_path.with_suffix(".h").resolve()
with open(raw_path, "w") as header_file:
Expand Down
2 changes: 1 addition & 1 deletion apps/microtvm/ethosu/convert_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def create_header_file(name, section, tensor_name, tensor_data, output_path):
"""
This function generates a header file containing the data from the numpy array provided.
"""
file_path = pathlib.Path(f"{output_path}/" + name).resolve()
file_path = pathlib.Path(f"{output_path}/{name}").resolve()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function create_header_file refactored with the following changes:

# Create header file with npy_data as a C array
raw_path = file_path.with_suffix(".h").resolve()
with open(raw_path, "w") as header_file:
Expand Down
2 changes: 1 addition & 1 deletion apps/microtvm/ethosu/convert_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def create_labels_header(labels_file, section, output_path):
with open(file_path, "w") as header_file:
header_file.write(f'char* labels[] __attribute__((section("{section}"), aligned(16))) = {{')

for _, label in enumerate(labels):
for label in labels:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function create_labels_header refactored with the following changes:

header_file.write(f'"{label.rstrip()}",')

header_file.write("};\n")
Expand Down
Loading