Merge pull request #101 from GGBond8488/add_cinn_autodiff
add cinn_diff module
Showing 15 changed files with 1,223 additions and 0 deletions.
@@ -0,0 +1,2 @@
*.json
*.log
@@ -0,0 +1,130 @@
## Quick Start

### Install dependencies

`python -m pip install -r requirments.txt`

### One-click mode

**example**
```python
import os
from padiff import cinn_diff


def run(run_script, base_env, cinn_env):
    run_env = cinn_diff.Env(run_script, base_env, cinn_env)
    run_env.run_base_model()
    run_env.run_cinn_model()
    cinn_diff.auto_diff(run_env.base_path, run_env.cinn_path, rtol=1e-3, atol=1e-3)


if __name__ == '__main__':
    run_script = "/root/workspace/PaddleNLP/model_zoo/bert/run_bert.sh"
    run(run_script, None, None)
```
**run_script**
The model run script. Pass the path to this script; dynamic-to-static conversion must already be implemented inside the model.

**base_env**
Environment variables for the baseline model run.
The default configuration is:
```python
{
    "CUDA_VISIBLE_DEVICES" : "0",
    "NVIDIA_TF32_OVERRIDE" : "1",
    "CUDA_LAUNCH_BLOCKING" : "1",
    "FLAGS_save_static_runtime_data" : "1",
    "FLAGS_static_runtime_data_save_path" : "./",
    "FLAGS_cudnn_deterministc" : "1",
    "FLAGS_cinn_cudnn_deterministc" : "1",
    "FLAGS_prim_all" : "true"
}
```

**cinn_env**
Environment variables for running the model with the CINN compiler enabled.
The default configuration is:
```python
{
    "FLAGS_use_cinn" : "1",
    "FLAGS_deny_cinn_ops" : "",
    "FLAGS_use_reduce_split_pass" : "1",
    "FLAGS_nvrtc_compile_to_cubin" : "0",
    "FLAGS_cinn_use_op_fusion" : "1",
    "FLAGS_cinn_parallel_compile_size" : "8",
    "FLAGS_cinn_pass_visualize_dir": "",
}
```
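
To customize either configuration, build your own dictionaries and pass them to `run()` instead of `None`. The sketch below is illustrative only; it assumes `cinn_diff.Env` accepts dictionaries in the same format as the defaults above, and it simply reproduces those defaults with `FLAGS_deny_cinn_ops` and `FLAGS_cinn_pass_visualize_dir` overridden (mirroring the shell example further below).
```python
# Illustrative sketch: override the default environments passed to run().
base_env = {
    "CUDA_VISIBLE_DEVICES": "0",
    "NVIDIA_TF32_OVERRIDE": "1",
    "CUDA_LAUNCH_BLOCKING": "1",
    "FLAGS_save_static_runtime_data": "1",
    "FLAGS_static_runtime_data_save_path": "./",
    "FLAGS_cudnn_deterministc": "1",
    "FLAGS_cinn_cudnn_deterministc": "1",
    "FLAGS_prim_all": "true",
}
cinn_env = {
    "FLAGS_use_cinn": "1",
    "FLAGS_deny_cinn_ops": "reduce_sum",  # e.g. keep reduce_sum out of CINN
    "FLAGS_use_reduce_split_pass": "1",
    "FLAGS_nvrtc_compile_to_cubin": "0",
    "FLAGS_cinn_use_op_fusion": "1",
    "FLAGS_cinn_parallel_compile_size": "8",
    "FLAGS_cinn_pass_visualize_dir": "./cinn/cinn_pass",
}
run(run_script, base_env, cinn_env)
```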

### Manual mode

step1: Prepare the model run script and make sure it runs end to end with dynamic-to-static + primitive operators + the compiler.

step2: Manually run the baseline model (dynamic-to-static + primitive operators).

The baseline run requires the following environment variables:
```
"FLAGS_save_static_runtime_data" : "1",
"FLAGS_static_runtime_data_save_path" : "./base",
```
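
As a minimal sketch, these exports can sit at the top of the model run script, before the training command (the `run_pretrain.py` invocation below is just the hypothetical example reused from the full script at the end of this section):
```shell
#!/bin/bash
# Baseline run: dynamic-to-static + primitive operators, dumping runtime tensors to ./base
export FLAGS_prim_all=true
export FLAGS_save_static_runtime_data=1
export FLAGS_static_runtime_data_save_path="./base"

# Hypothetical training command; replace with your own model's entry point.
python run_pretrain.py --model_type "llama"
```
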
step3: Manually run the model with dynamic-to-static + primitive operators + the compiler.

The compiler-enabled run requires the following environment variables:
```
"FLAGS_save_static_runtime_data" : "1",
"FLAGS_static_runtime_data_save_path" : "./cinn",
"FLAGS_cinn_pass_visualize_dir": "./cinn/cinn_pass",
```
step4: Run the model accuracy alignment script.

```python
from analyze import auto_diff

base_path = "/root/dev/PaddleClas/base"
compare_path = "/root/dev/PaddleClas/cinn"
auto_diff(base_path, compare_path, atol=0, rtol=0)
```
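
`auto_diff` returns the comparator's diff record: a list of dicts, each with a `"type"` and an `"event"` field (see `Comparator.record_diff` below). A small sketch of how you might inspect it, continuing the call above:
```python
records = auto_diff(base_path, compare_path, atol=0, rtol=0)

# Keep only group-level output mismatches and print where their visualizations live.
group_diffs = [r for r in records if r["type"] == "group_output_diff"]
for r in group_diffs:
    print(r["event"]["group_graphviz_path"])
```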

Example environment variable configuration for the model run script:
```shell
#!/bin/bash
export CUDA_VISIBLE_DEVICES=5
export NVIDIA_TF32_OVERRIDE=1
export CUDA_LAUNCH_BLOCKING=1
export FLAGS_save_static_runtime_data=true
export FLAGS_cudnn_deterministc=1
export FLAGS_cinn_cudnn_deterministc=1
# export FLAGS_check_nan_inf=1
rm -rf ./cinn/*
export FLAGS_static_runtime_data_save_path="./cinn/"

# Enable the following line when running dynamic-to-static + primitive operators
export FLAGS_prim_all=true
# Enable the following 6 lines when running dynamic-to-static + primitive operators + CINN
export FLAGS_use_cinn=1
export FLAGS_deny_cinn_ops="reduce_sum"
export FLAGS_use_reduce_split_pass=1
export FLAGS_nvrtc_compile_to_cubin=0
export FLAGS_cinn_use_op_fusion=1
export FLAGS_cinn_parallel_compile_size=8


# before and after cinn program and graph pass (including group opfusion pass) in each sub-graph
rm -rf ./cinn_pass/*
export FLAGS_cinn_pass_visualize_dir="./cinn/cinn_pass/"

task_name_or_path="llama_output"
python run_pretrain.py \
    --model_type "llama" \
    ...
```

## Results
![Run results](./img/run_ret.png)

More features are under development...
@@ -0,0 +1,19 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .graph import *
from .compare_utils import *
from .utils import *
from .env import *
from .analyze import *
@@ -0,0 +1,122 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .read_file import read_all
from .compare_utils import Comparator
from .logs import logger


# TODO(GGBond8488): Needs optimization in the future
def back_track_group(base, compare, cluster, cmp, graph, node):
    inputs = graph.inputs()
    all_inputs_equal = True
    paddle_output_name = ""
    cur_cluster_cinn2paddle = {v: k for k, v in cluster.varmaps.items()}
    for input in inputs:
        tmp = input
        paddle_name = cur_cluster_cinn2paddle.get(tmp.name, "")
        if not paddle_name:
            logger.info(f"can't find {node.name}'s paddle name")
            diff_ret = {
                "cluster": cluster.idx,
                "group": cluster.cinn_group,
                "output": paddle_output_name,
                "output_cinn_var": node.name,
                "subgraph_id": node.graph_id if node else None,
            }
            return diff_ret
        paddle_output_name = paddle_name
        base_var_path = base.all_vars_paths[paddle_name]
        compare_var_path = compare.all_vars_paths[paddle_name]
        ret = cmp.allclose(base_var_path, compare_var_path)
        if not ret:
            all_inputs_equal = False
            group = cluster.cinn_group
            for graph in group.subgraphs:
                node = graph.find(tmp.name)
                if node and not graph.is_input(node):
                    return back_track_group(base, compare, cluster, cmp, graph, node)
    if all_inputs_equal:
        diff_ret = {
            "cluster": cluster.idx,
            "group": cluster.cinn_group,
            "output": paddle_output_name,
            "output_cinn_var": node.name,
            "subgraph_id": node.graph_id if node else None,
        }
        return diff_ret


def auto_diff(base_path, compare_path, rtol=1e-6, atol=1e-6):
    base = read_all(base_path)
    compare = read_all(compare_path)
    cmp = Comparator(rtol=rtol, atol=atol)

    # step1: Confirm whether the input and output of the cluster are aligned
    for cluster in compare.all_clusters:
        input_equals_flag = True
        output_equals_flag = True
        for input in cluster.inputs:
            base_var_path = base.all_vars_paths[input]
            compare_var_path = compare.all_vars_paths[input]
            ret = cmp.allclose(base_var_path, compare_var_path)
            if not ret:
                input_equals_flag = False
                cmp.record_input_diff(cluster.idx, input)
                continue

        if input_equals_flag:
            # step2: Find the misaligned output of the cluster
            for output in cluster.outputs:
                base_var_path = base.all_vars_paths[output]
                compare_var_path = compare.all_vars_paths[output]
                ret = cmp.allclose(base_var_path, compare_var_path)
                if not ret:
                    output_equals_flag = False
                    # step3: Find the group corresponding to the misaligned output
                    output_cinn_var = cluster.varmaps.get(output, "")
                    if not output_cinn_var:
                        logger.info("can't find var " + output + " corresponding cinn var name")
                    else:
                        find_diff_group_flag = False
                        # step4: Starting from the misaligned output, find the group where the output misalignment
                        # occurs for the first time (the input can be aligned, but the output cannot be aligned)
                        group = cluster.cinn_group
                        for graph in group.subgraphs:
                            node = graph.find(output_cinn_var)
                            if node and not graph.is_input(node):
                                # Find the first misaligned output and start backtracking
                                diff_ret = back_track_group(base, compare, cluster, cmp, graph, node)
                                if diff_ret:  # Input can be aligned, but output cannot be aligned
                                    diff_ret["output"] = output
                                    cmp.record_group_output_diff(diff_ret)
                                    find_diff_group_flag = True
                                    break

                        if not find_diff_group_flag:
                            cmp.record_output_diff(cluster.idx, output, cluster.varmaps.get(output, ""))
                            logger.info("can't find diff group in cluster_" + cluster.idx + " but diff exists")

        if output_equals_flag:
            logger.info("cluster_" + cluster.idx + " has no diff")

    for diff in cmp.record:
        logger.info(diff)
    return cmp.record


if __name__ == "__main__":
    # test code; you can simply use cinn_diff.auto_diff this way
    base_path = "/root/dev/PaddleClas/base"
    compare_path = "/root/dev/PaddleClas/cinn"
    auto_diff(base_path, compare_path, atol=0, rtol=0)
@@ -0,0 +1,83 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numpy as np


class Comparator:
    def __init__(self, rtol=0, atol=0) -> None:
        self.cluster_ret = {}
        self.graph_ret = {}
        self.record = []
        self.rtol = rtol
        self.atol = atol

    @classmethod
    def load_var(cls, path):
        return paddle.Tensor(paddle.core.load_dense_tensor(path))

    def allclose(self, base_path, compare_path):
        base_var = self.load_var(base_path)
        compare_var = self.load_var(compare_path)
        ret = np.allclose(base_var, compare_var, rtol=self.rtol, atol=self.atol)
        return ret

    def assert_allclose(self, base_path, compare_path):
        base_var = self.load_var(base_path)
        compare_var = self.load_var(compare_path)
        ret = np.testing.assert_allclose(base_var, compare_var, rtol=self.rtol, atol=self.atol)
        return ret

    def record_diff(self, diff, type):
        diff = {
            "type": type,
            "event": diff,
        }
        self.record.append(diff)

    def record_input_diff(self, cluster_idx, input):
        self.record_diff({"cluster_idx": cluster_idx, "cluster_input_diff_name": input}, "cluster_input_diff")

    def record_output_diff(self, cluster_idx, output, output_cinn_name):
        self.record_diff(
            {
                "cluster_idx": cluster_idx,
                "cluster_output_diff_paddle_name": output,
                "cluster_output_diff_cinn_name": output_cinn_name,
            },
            "cluster_output_diff",
        )

    def record_group_output_diff(self, diff_ret):
        self.record_diff(
            {
                "cluster_idx": diff_ret["cluster"],
                "cluster_output_diff_paddle_name": diff_ret["output"],
                "group_idx": diff_ret["group"].group_id,
                "group_output_diff_cinn_name": diff_ret["output_cinn_var"],
                "group_graphviz_path": diff_ret["group"].dot_path,
                "group_test_py_code_path": diff_ret["group"].txt_path,
                "group_diff_subgraph_id": diff_ret["subgraph_id"],
            },
            "group_output_diff",
        )


if __name__ == "__main__":
    # test code, intermediate variables can be read this way
    cmp = Comparator()
    base = cmp.load_var("/root/dev/PaddleClas/base/saved_tensors/batch_norm_grad-input-batch_norm_0.tmp_3@GRAD")
    cinn = cmp.load_var("/root/dev/PaddleClas/cinn/saved_tensors/batch_norm_grad-input-batch_norm_0.tmp_3@GRAD")
    np.testing.assert_allclose(base, cinn, rtol=0, atol=0)