diff --git a/README.md b/README.md index 7b9455c..5b5f1fe 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,25 @@ for i in range(6): assert check_params(f"./torch/step_{i}", f"./paddle/step_{i}") == True ``` +### 框架与编译器对齐 +使用文档 [CINN](padiff/cinn_diff/README.md) +```python +import os +from padiff import cinn_diff + + +def run(run_script, base_env, cinn_env): + run_env = cinn_diff.Env(run_script, base_env, cinn_env) + run_env.run_base_model() #可以注释掉选择不运行base model + run_env.run_cinn_model() #也可以注释掉选择不运行cinn model + cinn_diff.auto_diff(run_env.base_path, run_env.cinn_path, rtol=1e-3, atol=1e-3) + + +if __name__ == '__main__': + run_script = "/root/workspace/PaddleNLP/model_zoo/bert/run_bert.sh" + run(run_script, None, None) +``` ## 已支持 `Special Init` 的组件 diff --git a/padiff/__init__.py b/padiff/__init__.py index c7be604..4ead7bd 100644 --- a/padiff/__init__.py +++ b/padiff/__init__.py @@ -27,6 +27,7 @@ from .report.hooks import info_hook from .datas import global_json_laoder as jsons +from . 
import cinn_diff def module_filter(name): diff --git a/padiff/cinn_diff/.gitignore b/padiff/cinn_diff/.gitignore new file mode 100644 index 0000000..056dc05 --- /dev/null +++ b/padiff/cinn_diff/.gitignore @@ -0,0 +1,2 @@ +*.json +*.log \ No newline at end of file diff --git a/padiff/cinn_diff/README.md b/padiff/cinn_diff/README.md new file mode 100644 index 0000000..576b3a7 --- /dev/null +++ b/padiff/cinn_diff/README.md @@ -0,0 +1,130 @@ +## 快速开始 + +### 安装依赖 + +`python -m pip install -r requirments.txt` + +### 一键运行模式 + +**example** +```python +import os +from padiff import cinn_diff + + +def run(run_script, base_env, cinn_env): + run_env = cinn_diff.Env(run_script, base_env, cinn_env) + run_env.run_base_model() + run_env.run_cinn_model() + cinn_diff.auto_diff(run_env.base_path, run_env.cinn_path, rtol=1e-3, atol=1e-3) + + +if __name__ == '__main__': + run_script = "/root/workspace/PaddleNLP/model_zoo/bert/run_bert.sh" + run(run_script, None, None) +``` +**run_script** +模型运行脚本,使用时提供脚本路径,需在模型内部实现好动转静 + +**base_env** +模型基线运行的环境变量 +初始配置为 +```python +{ + "CUDA_VISIBLE_DEVICES" : "0", + "NVIDIA_TF32_OVERRIDE" : "1", + "CUDA_LAUNCH_BLOCKING" : "1", + "FLAGS_save_static_runtime_data" : "1", + "FLAGS_static_runtime_data_save_path" : "./", + "FLAGS_cudnn_deterministc" : "1", + "FLAGS_cinn_cudnn_deterministc" : "1", + "FLAGS_prim_all" : "true" +} +``` + +**cinn_env** +模型接入编译器运行的环境变量 +初始配置为 +```python +{ + "FLAGS_use_cinn" : "1", + "FLAGS_deny_cinn_ops" :"", + "FLAGS_use_reduce_split_pass" : "1", + "FLAGS_nvrtc_compile_to_cubin" : "0", + "FLAGS_cinn_use_op_fusion" : "1", + "FLAGS_cinn_parallel_compile_size" : "8", + "FLAGS_cinn_pass_visualize_dir": "", +} +``` + +### 手动运行模式 + +step1: 准备模型运行脚本,跑通动转静+组合算子+编译器 + +step2: 手动运行动转静+组合算子的基线模型 + +基线模型运行是需要配置如下环境变量 +``` +"FLAGS_save_static_runtime_data" : "1", + +"FLAGS_static_runtime_data_save_path" : "./base", +``` +step3: 手动运行动转静+组合算子+编译器的模型 + +接入编译器的模型运行需要配置如下环境变量 +``` +"FLAGS_save_static_runtime_data" : "1", + 
+"FLAGS_static_runtime_data_save_path" : "./cinn", + +"FLAGS_cinn_pass_visualize_dir": "./cinn/cinn_pass", +``` +step4: 运行模型精度对齐脚本 + +```python +from analyze import auto_diff + +base_path = "/root/dev/PaddleClas/base" +compare_path = "/root/dev/PaddleClas/cinn" +auto_diff(base_path, compare_path, atol=0, rtol=0) +``` + +模型运行脚本环境变量配置例子 +``` shell +#!/bin/bash +export CUDA_VISIBLE_DEVICES=5 +export NVIDIA_TF32_OVERRIDE=1 +export CUDA_LAUNCH_BLOCKING=1 +export FLAGS_save_static_runtime_data=true +export FLAGS_cudnn_deterministc=1 +export FLAGS_cinn_cudnn_deterministc=1 +# export FLAGS_check_nan_inf=1 +rm -rf ./cinn/* +export FLAGS_static_runtime_data_save_path="./cinn/" + +# 跑 动转静 + 组合算子时打开下面这1行 +export FLAGS_prim_all=true +# 跑 动转静 + 组合算子 + CINN时打开下面这6行 +export FLAGS_use_cinn=1 +export FLAGS_deny_cinn_ops="reduce_sum" +export FLAGS_use_reduce_split_pass=1 +export FLAGS_nvrtc_compile_to_cubin=0 +export FLAGS_cinn_use_op_fusion=1 +export FLAGS_cinn_parallel_compile_size=8 + + +# # before and after cinn program and graph pass(including group opfusion pass) in each sub-graph +rm -rf ./cinn_pass/* +export FLAGS_cinn_pass_visualize_dir="./cinn/cinn_pass/" + +task_name_or_path="llama_output" +python run_pretrain.py \ + --model_type "llama" \ + ... +``` + +## 运行结果 +![运行结果图](./img/run_ret.png) + + +更多功能正在研发中... diff --git a/padiff/cinn_diff/__init__.py b/padiff/cinn_diff/__init__.py new file mode 100644 index 0000000..0511ee5 --- /dev/null +++ b/padiff/cinn_diff/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .graph import * +from .compare_utils import * +from .utils import * +from .env import * +from .analyze import * diff --git a/padiff/cinn_diff/analyze.py b/padiff/cinn_diff/analyze.py new file mode 100644 index 0000000..513223f --- /dev/null +++ b/padiff/cinn_diff/analyze.py @@ -0,0 +1,122 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from .read_file import read_all
from .compare_utils import Comparator
from .logs import logger


def _group_diff_record(cluster, node, paddle_output_name):
    """Build the dict describing a group output reported as misaligned.

    The keys are the ones ``Comparator.record_group_output_diff`` reads.
    """
    return {
        "cluster": cluster.idx,
        "group": cluster.cinn_group,
        "output": paddle_output_name,
        "output_cinn_var": node.name,
        "subgraph_id": node.graph_id if node else None,
    }


# TODO(GGBond8488): Needs optimization in the future
def back_track_group(base, compare, cluster, cmp, graph, node):
    """Walk backwards from ``node`` to the subgraph where the diff first appears.

    Every input of ``graph`` is mapped back to its paddle name through the
    inverted ``cluster.varmaps`` and compared between ``base`` and ``compare``.
    When an input differs, the search recurses into the subgraph of
    ``cluster.cinn_group`` that contains that variable as a non-input node.

    Args:
        base: Result of ``read_all`` for the baseline run.
        compare: Result of ``read_all`` for the CINN run.
        cluster: Cluster being analysed (provides ``varmaps``/``cinn_group``).
        cmp: Comparator exposing ``allclose(base_path, compare_path)``.
        graph: Subgraph whose inputs are checked.
        node: Node of ``graph`` whose value is known to be misaligned.

    Returns:
        A diff record (see ``_group_diff_record``) when all inputs of
        ``graph`` are aligned or an input has no paddle name; the result of
        the recursive call when a producing subgraph is found; ``None`` when
        an input differs but no producing subgraph could be located.
    """
    cinn2paddle = {v: k for k, v in cluster.varmaps.items()}
    all_inputs_equal = True
    paddle_output_name = ""
    for graph_input in graph.inputs():
        paddle_name = cinn2paddle.get(graph_input.name, "")
        if not paddle_name:
            logger.info(f"can't find {node.name}'s paddle name")
            return _group_diff_record(cluster, node, paddle_output_name)
        paddle_output_name = paddle_name
        base_var_path = base.all_vars_paths[paddle_name]
        compare_var_path = compare.all_vars_paths[paddle_name]
        if cmp.allclose(base_var_path, compare_var_path):
            continue
        all_inputs_equal = False
        # Use dedicated names here: rebinding ``graph``/``node`` would corrupt
        # the remaining iterations of the outer loop (``node`` could become
        # None and break the record construction above).
        for subgraph in cluster.cinn_group.subgraphs:
            producer = subgraph.find(graph_input.name)
            if producer and not subgraph.is_input(producer):
                return back_track_group(base, compare, cluster, cmp, subgraph, producer)
    if all_inputs_equal:
        return _group_diff_record(cluster, node, paddle_output_name)
    return None


def _record_first_diff_group(base, compare, cluster, cmp, output, output_cinn_var):
    """Record the first group whose inputs align but whose output does not.

    Returns True when a group diff was recorded on ``cmp``, False otherwise
    (including when the cluster has no associated CINN group).
    """
    group = cluster.cinn_group
    if group is None:
        return False
    for graph in group.subgraphs:
        node = graph.find(output_cinn_var)
        if node and not graph.is_input(node):
            # Find the first misaligned output and start backtracking
            diff_ret = back_track_group(base, compare, cluster, cmp, graph, node)
            if diff_ret:  # Input can be aligned, but output cannot be aligned
                diff_ret["output"] = output
                cmp.record_group_output_diff(diff_ret)
                return True
    return False


def auto_diff(base_path, compare_path, rtol=1e-6, atol=1e-6):
    """Compare a baseline dump with a CINN dump and locate misaligned groups.

    Args:
        base_path: Directory read with ``read_all`` for the baseline run.
        compare_path: Directory read with ``read_all`` for the CINN run.
        rtol: Relative tolerance forwarded to ``Comparator``.
        atol: Absolute tolerance forwarded to ``Comparator``.

    Returns:
        ``cmp.record``, the list of diff events collected by the comparator.
    """
    base = read_all(base_path)
    compare = read_all(compare_path)
    cmp = Comparator(rtol=rtol, atol=atol)

    for cluster in compare.all_clusters:
        # step1: Confirm whether the input and output of the cluster are aligned
        inputs_aligned = True
        for input_name in cluster.inputs:
            base_var_path = base.all_vars_paths[input_name]
            compare_var_path = compare.all_vars_paths[input_name]
            if not cmp.allclose(base_var_path, compare_var_path):
                inputs_aligned = False
                cmp.record_input_diff(cluster.idx, input_name)

        # Outputs are only meaningful to inspect when every input is aligned.
        if not inputs_aligned:
            continue

        # step2: Find the misaligned output of the cluster
        outputs_aligned = True
        for output in cluster.outputs:
            base_var_path = base.all_vars_paths[output]
            compare_var_path = compare.all_vars_paths[output]
            if cmp.allclose(base_var_path, compare_var_path):
                continue
            outputs_aligned = False
            # step3: Find the group corresponding to the misaligned output
            output_cinn_var = cluster.varmaps.get(output, "")
            if not output_cinn_var:
                logger.info("can't find var " + output + " corresponding cinn var name")
                continue
            # step4: Starting from the misaligned output, find the group where the
            # misalignment occurs for the first time (inputs aligned, output not).
            if not _record_first_diff_group(base, compare, cluster, cmp, output, output_cinn_var):
                cmp.record_output_diff(cluster.idx, output, cluster.varmaps.get(output, ""))
                logger.info(f"can't find diff group in cluster_{cluster.idx} but diff exsits")

        if outputs_aligned:
            logger.info(f"cluster_{cluster.idx} has no diff")

    for diff in cmp.record:
        logger.info(diff)
    return cmp.record


if __name__ == "__main__":
    # test code, you can simple use cinn_diff.auto_diff this way
    base_path = "/root/dev/PaddleClas/base"
    compare_path = "/root/dev/PaddleClas/cinn"
    auto_diff(base_path, compare_path, atol=0, rtol=0)
import paddle
import numpy as np


class Comparator:
    """Compare saved tensors of two runs and collect the differences found.

    ``rtol``/``atol`` are forwarded to the numpy closeness checks; ``record``
    accumulates one dict per reported difference.
    """

    def __init__(self, rtol=0, atol=0) -> None:
        self.cluster_ret = {}
        self.graph_ret = {}
        self.record = []
        self.rtol = rtol
        self.atol = atol

    @classmethod
    def load_var(cls, path):
        """Load the dense tensor stored at ``path`` and wrap it as ``paddle.Tensor``."""
        dense_tensor = paddle.core.load_dense_tensor(path)
        return paddle.Tensor(dense_tensor)

    def allclose(self, base_path, compare_path):
        """Return ``np.allclose`` of the two saved tensors under the configured tolerances."""
        expected = self.load_var(base_path)
        actual = self.load_var(compare_path)
        return np.allclose(expected, actual, rtol=self.rtol, atol=self.atol)

    def assert_allclose(self, base_path, compare_path):
        """Run ``np.testing.assert_allclose`` on the two saved tensors and return its result."""
        expected = self.load_var(base_path)
        actual = self.load_var(compare_path)
        return np.testing.assert_allclose(expected, actual, rtol=self.rtol, atol=self.atol)

    def record_diff(self, diff, type):
        """Append ``diff`` to ``record`` tagged with its event ``type``."""
        self.record.append({"type": type, "event": diff})

    def record_input_diff(self, cluster_idx, input):
        """Record that input ``input`` of cluster ``cluster_idx`` is misaligned."""
        event = {"cluster_idx": cluster_idx, "cluster_input_diff_name": input}
        self.record_diff(event, "cluster_input_diff")

    def record_output_diff(self, cluster_idx, output, output_cinn_name):
        """Record a misaligned cluster output that could not be traced to a group."""
        event = {
            "cluster_idx": cluster_idx,
            "cluster_output_diff_paddle_name": output,
            "cluster_output_diff_cinn_name": output_cinn_name,
        }
        self.record_diff(event, "cluster_output_diff")

    def record_group_output_diff(self, diff_ret):
        """Record a misaligned output traced to a group, flattening ``diff_ret``."""
        group = diff_ret["group"]
        event = {
            "cluster_idx": diff_ret["cluster"],
            "cluster_output_diff_paddle_name": diff_ret["output"],
            "group_idx": group.group_id,
            "group_output_diff_cinn_name": diff_ret["output_cinn_var"],
            "group_graphviz_path": group.dot_path,
            "group_test_py_code_path": group.txt_path,
            "group_diff_subgraph_id": diff_ret["subgraph_id"],
        }
        self.record_diff(event, "group_output_diff")


if __name__ == "__main__":
    # test code, intermediate variables can be read this way
    cmp = Comparator()
    base = cmp.load_var("/root/dev/PaddleClas/base/saved_tensors/batch_norm_grad-input-batch_norm_0.tmp_3@GRAD")
    cinn = cmp.load_var("/root/dev/PaddleClas/cinn/saved_tensors/batch_norm_grad-input-batch_norm_0.tmp_3@GRAD")
    np.testing.assert_allclose(base, cinn, rtol=0, atol=0)
import os
import shutil
import subprocess
from .logs import logger


class Env:
    """Run a model script twice (baseline and CINN) with the proper env flags.

    Args:
        script: Path of the shell script that launches the model. It is run
            with ``sh`` from its own directory.
        base_env: Environment variables for the baseline run; the built-in
            defaults are used when empty or None.
        cinn_env: Additional environment variables for the CINN run; the
            built-in defaults are used when empty or None.

    Attributes:
        base_path / cinn_path: Directories (next to the script) that are
            passed as ``FLAGS_static_runtime_data_save_path`` for each run.
    """

    base_dir_name = "base"
    cinn_dir_name = "cinn"
    cinn_pass_dir = "cinn_pass"
    cinn_graph_dir = "cinn_graph"

    def __init__(
        self,
        script=None,
        base_env=None,
        cinn_env=None,
    ):
        self._base_env = {
            "CUDA_VISIBLE_DEVICES": "7",
            "NVIDIA_TF32_OVERRIDE": "1",
            "CUDA_LAUNCH_BLOCKING": "1",
            "FLAGS_save_static_runtime_data": "1",
            "FLAGS_static_runtime_data_save_path": "./",
            # Spelling fixed ("deterministc" -> "deterministic"); the misspelled
            # names are presumably ignored by the framework - confirm against
            # the Paddle flag list.
            "FLAGS_cudnn_deterministic": "1",
            "FLAGS_cinn_cudnn_deterministic": "1",
            "FLAGS_prim_all": "true",
        }
        self._cinn_env = {
            "FLAGS_use_cinn": "1",
            "FLAGS_deny_cinn_ops": "",
            "FLAGS_use_reduce_split_pass": "1",
            "FLAGS_nvrtc_compile_to_cubin": "0",
            "FLAGS_cinn_use_op_fusion": "1",
            "FLAGS_cinn_parallel_compile_size": "8",
            "FLAGS_cinn_pass_visualize_dir": "",
        }
        # Work on copies: init_base_env/init_cinn_env write into these dicts
        # and must not modify the caller's dicts or the pristine defaults.
        self.base_env = dict(base_env) if base_env else dict(self._base_env)
        self.cinn_env = dict(cinn_env) if cinn_env else dict(self._cinn_env)
        self._set_script_paths(script)
        self.os_env = dict(os.environ)

    def _set_script_paths(self, script):
        """Store ``script`` and derive the directories/names that depend on it."""
        # Absolute paths stay valid after the chdir into the script directory.
        script = os.path.abspath(script)
        self.script = script
        self.script_path = os.path.dirname(script)
        self.script_name = os.path.basename(script)
        self.base_path = os.path.join(self.script_path, self.base_dir_name)
        self.cinn_path = os.path.join(self.script_path, self.cinn_dir_name)

    @staticmethod
    def _remove_path(path):
        """Delete ``path`` (directory tree or single file) without using a shell."""
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path, ignore_errors=True)
        else:
            try:
                os.remove(path)
            except OSError:
                # Best effort, like ``rm -rf``.
                pass

    def init_base_env(self):
        """Clear the baseline output dir and point the save-path flag at it."""
        if os.path.exists(self.base_path):
            logger.info("base path exists, remove it")
            self._remove_path(self.base_path)
        self.base_env["FLAGS_static_runtime_data_save_path"] = self.base_path
        self.base_env["FLAGS_save_static_runtime_data"] = "1"

    def set_base_env(self, env):
        """Replace the baseline environment variables."""
        self.base_env = env

    def init_cinn_env(self):
        """Clear the CINN output dir and point the save/visualize flags at it."""
        # The CINN run inherits base_env, so the save path is redirected there.
        self.base_env["FLAGS_static_runtime_data_save_path"] = self.cinn_path
        if os.path.exists(self.cinn_path):
            logger.info("cinn path exists, remove it")
            self._remove_path(self.cinn_path)
        self.cinn_env["FLAGS_cinn_pass_visualize_dir"] = os.path.join(self.cinn_path, self.cinn_pass_dir)
        self.cinn_env["FLAGS_cinn_subgraph_graphviz_dir"] = os.path.join(self.cinn_path, self.cinn_graph_dir)

    def set_cinn_env(self, env):
        """Replace the CINN-specific environment variables."""
        self.cinn_env = env

    def set_script(self, name):
        """Change the script to run, keeping the derived paths in sync."""
        self._set_script_paths(name)

    def run_model(self, run_env, log):
        """Run the script with ``sh`` using ``run_env``; stdout/stderr go to ``log``."""
        logger.info(self.script)
        ret = subprocess.run(["sh", self.script_name], env=run_env, stdout=log, stderr=log)
        logger.info(ret)

    def _run_in_script_dir(self, run_env, log_name):
        """Run the model from the script directory, writing output to ``log_name``.

        The previous working directory is restored and the log file is closed
        even when the run raises.
        """
        root_path = os.getcwd()
        os.chdir(self.script_path)
        try:
            with open(log_name, "w") as log:
                self.run_model(run_env, log)
        finally:
            os.chdir(root_path)

    def run_base_model(self):
        """Run the baseline model; output is written to ``base.log``."""
        self.init_base_env()
        run_env = self.base_env.copy()
        logger.info(run_env)
        # NOTE(review): variables already present in the process environment
        # take precedence over the configured ones - confirm this is intended.
        run_env.update(self.os_env)
        self._run_in_script_dir(run_env, "base.log")

    def run_cinn_model(self):
        """Run the CINN model (cinn_env plus inherited base_env); output goes to ``cinn.log``."""
        self.init_cinn_env()
        run_env = self.cinn_env.copy()
        # cinn_env wins over base_env for keys defined in both.
        for key, value in self.base_env.items():
            run_env.setdefault(key, value)
        logger.info(run_env)
        # NOTE(review): same precedence as in run_base_model.
        run_env.update(self.os_env)
        self._run_in_script_dir(run_env, "cinn.log")
import graphviz
import pygraphviz as pgv
from .utils import retry
from .logs import logger


@retry(max_times=1)
def get_graph(dot_path):
    """Load the Graphviz file at ``dot_path`` as a ``pygraphviz.AGraph``."""
    return pgv.AGraph(dot_path)


def construct_graph_by_dot(dot_path, sep="\\n"):
    """Build ``Node``/``Graph`` objects from a Graphviz dot file.

    Args:
        dot_path: Path of the dot file.
        sep: Separator inside a node label; the text before the first
            occurrence is used as the node name.

    Returns:
        ``(all_nodes, subgraphs)``: every node found, and one ``Graph`` per
        dot subgraph (or a single one for the whole file when it declares no
        subgraph).
    """
    graph_source = get_graph(dot_path)
    all_nodes = []
    idx2nodes = {}
    ret_subgraphs = []
    subgraphs = graph_source.subgraphs() or [graph_source]
    for subgraph in subgraphs:
        subgraph_id = subgraph.get_name().split("_")[-1]
        subgraph_nodes = []
        for dot_node in subgraph.nodes():
            # Available attributes: ['color', 'label', 'style']
            name = dot_node.attr["label"].split(sep)[0]
            idx = dot_node.get_name()
            cls_node = Node(name=name, idx=idx, type="unknown", graph_id=subgraph_id)
            idx2nodes[idx] = cls_node
            all_nodes.append(cls_node)
            subgraph_nodes.append(cls_node)
        ret_subgraphs.append(Graph(nodes=subgraph_nodes, name=f"group_{subgraph_id}"))

    for edge in graph_source.edges():
        start, end = edge
        if start not in idx2nodes or end not in idx2nodes:
            continue
        # Output edge
        idx2nodes[start].add_output(idx2nodes[end])
        # Input edge
        idx2nodes[end].add_input(idx2nodes[start])

    return all_nodes, ret_subgraphs


class Graph:
    """A named collection of ``Node`` objects with input/output queries."""

    def __init__(self, nodes, name) -> None:
        self.nodes = nodes
        self.name = str(name)

    def add_node(self, node):
        """Add ``node`` unless already present; raise ValueError for non-Node values."""
        if not isinstance(node, Node):
            raise ValueError(" param type must be Node")
        if node not in self.nodes:
            self.nodes.append(node)

    # just for cinn graph
    def graph_inputs(self):
        """Return the names of the first output of every ``feed`` node."""
        return {node.outputs[0].name for node in self.nodes if node.name == "feed"}

    def inputs(self):
        """Return the nodes treated as inputs of this graph.

        That is: the first output of each ``feed`` node, every node without
        inputs, and every producer that lives outside this graph.
        """
        inputs = []
        for node in self.nodes:
            if node.name == "feed":
                inputs.append(node.outputs[0])
            if not node.inputs:
                inputs.append(node)
            # Inputs in another subgraph are also counted as inputs to the current subgraph.
            for producer in node.inputs:
                if producer not in self.nodes:
                    inputs.append(producer)
        return inputs

    # just for cinn graph
    def graph_outputs(self):
        """Return the names feeding ``fetch`` nodes plus names of nodes without outputs."""
        outputs = set()
        for node in self.nodes:
            if node.name == "fetch":
                outputs.add(node.inputs[0].name)
            if not node.outputs:
                outputs.add(node.name)
        return outputs

    def outputs(self):
        """Return the nodes treated as outputs of this graph.

        That is: the first input of each ``fetch`` node, every node without
        outputs, and every consumer that lives outside this graph.
        """
        outputs = []
        for node in self.nodes:
            if node.name == "fetch":
                outputs.append(node.inputs[0])
            if not node.outputs:
                outputs.append(node)
            # The output is in another subgraph and is also counted as the output of the current subgraph.
            for consumer in node.outputs:
                if consumer not in self.nodes:
                    outputs.append(consumer)
        return outputs

    def find(self, cinn_var_name):
        """Return the first node named ``cinn_var_name``, or None."""
        for node in self.nodes:
            if node.name == cinn_var_name:
                return node
        return None

    def is_input(self, node):
        """Tell whether ``node`` is one of ``self.inputs()``."""
        return node in self.inputs()

    def is_output(self, node):
        """Tell whether ``node`` is one of ``self.outputs()``."""
        return node in self.outputs()

    def export_dot(self):
        """Return a ``graphviz.Digraph`` with one vertex per node and one edge per output link."""
        dot = graphviz.Digraph(comment=self.name)
        for item in self.nodes:
            dot.node(item.idx, item.idx + "\n" + item.name + ":" + item.node_type)
            for successor in item.outputs:
                dot.edge(item.idx, successor.idx)
        return dot

    def __str__(self) -> str:
        return "graph_" + str(self.name)

    def __repr__(self) -> str:
        return "graph_" + str(self.name)


class Pass:
    """Files dumped before/after one compiler pass (txt and dot variants)."""

    def __init__(self, id, pass_name=None, before_txt=None, after_txt=None, before_dot=None, after_dot=None) -> None:
        self.pass_id = id
        self.pass_name = pass_name
        self.before_txt = before_txt
        self.after_txt = after_txt
        self.before_dot = before_dot
        self.after_dot = after_dot

    def set_pass_name(self, pass_name):
        self.pass_name = pass_name

    def set_before_txt(self, before_txt):
        self.before_txt = before_txt

    def set_after_txt(self, after_txt):
        self.after_txt = after_txt

    def set_before_dot(self, before_dot):
        self.before_dot = before_dot

    def set_after_dot(self, after_dot):
        self.after_dot = after_dot

    def __str__(self) -> str:
        # f-string so that an unset (None) pass_name does not raise TypeError.
        return f"pass_{self.pass_id}_{self.pass_name}"

    def __repr__(self) -> str:
        return str(self)


class Group:
    """A fusion group, built from the dot/txt files of its last pass."""

    def __init__(self, group_id, all_passes, last_pass_id) -> None:
        self.group_id = group_id
        self.passes = all_passes
        self.dot_path = all_passes[last_pass_id].after_dot
        self.txt_path = all_passes[last_pass_id].after_txt
        self.all_nodes, self.subgraphs = construct_graph_by_dot(self.dot_path)
        self.fetch = None
        self.feed = None

    def export_graph(self):
        """Render all nodes of the group to a PNG named after the group."""
        # Pass the string, not the bound ``__str__`` method, as the graph name.
        self.graph = Graph(self.all_nodes, str(self))
        dot = self.graph.export_dot()
        dot.render(str(self), format="png", cleanup=True)

    def export_dot(self):
        """Render the group's dot source to a PNG named after the group."""
        dot = graphviz.Source(self.dot_path)
        dot.render(str(self), format="png", cleanup=True)

    def __str__(self) -> str:
        return "fusion_group_" + str(self.group_id)

    def __repr__(self) -> str:
        return "fusion_group_" + str(self.group_id)


class Node:
    """A graph vertex with links to its input and output nodes."""

    def __init__(self, name, type, idx, graph_id=None) -> None:
        self.name = name  # var name, like arg_1
        self.node_type = type if type else "unknown"
        self.idx = idx  # node name like node1
        self.inputs = []
        self.outputs = []
        self.cinn_name = ""
        self.graph_id = graph_id

    def is_op(self):
        return self.node_type == "op"

    def is_var(self):
        return self.node_type == "var"

    def is_leaf(self):
        """True when the node has no outputs or its first output is ``fetch``."""
        # Names are strings; comparing against the list ["fetch"] never matched.
        return not self.outputs or self.outputs[0].name == "fetch"

    def is_root(self):
        """True when the node has no inputs or its first input is ``feed``."""
        return not self.inputs or self.inputs[0].name == "feed"

    def set_outputs(self, outputs):
        self.outputs = outputs

    def set_inputs(self, inputs):
        self.inputs = inputs

    def add_input(self, node):
        """Link ``node`` as an input once; raise ValueError for non-Node values."""
        if not isinstance(node, Node):
            raise ValueError("Node input must be Node")
        if node not in self.inputs:
            self.inputs.append(node)

    def add_output(self, node):
        """Link ``node`` as an output once; raise ValueError for non-Node values."""
        if not isinstance(node, Node):
            raise ValueError("Node output must be Node")
        if node not in self.outputs:
            self.outputs.append(node)

    def __str__(self) -> str:
        return self.name + "_" + self.idx + " : " + self.node_type

    def __repr__(self) -> str:
        return self.name + "_" + self.idx + " : " + self.node_type


class Cluster:
    """A cluster read from disk: its graph, ops, inputs/outputs and var mapping."""

    def __init__(self, idx, graph, ops, inputs, outputs, graph_key, varmaps=None) -> None:
        self.idx = idx
        self.graph = graph
        self.ops = ops
        self.inputs = inputs
        self.outputs = outputs
        self.graph_key = graph_key
        self.varmaps = varmaps
        self.cinn_group = None
        # Filled by set_associate_groups; must exist before it is extended.
        self.associate_groups = []

    def set_varmaps(self, varmaps: dict):
        self.varmaps = varmaps

    def set_associate_groups(self, group):
        """Add one group name, or every item of a list/set/tuple of them.

        Raises:
            ValueError: if ``group`` is neither a str nor a list/set/tuple.
        """
        if isinstance(group, (list, set, tuple)):
            self.associate_groups.extend(list(group))
        elif isinstance(group, str):
            self.associate_groups.append(group)
        else:
            raise ValueError(f"group must be str or sequence type, but got {type(group)}")

    def __str__(self) -> str:
        return "Cluster_" + str(self.idx)

    def __repr__(self) -> str:
        return "Cluster_" + str(self.idx)

    def print_varmaps(self):
        """Log every paddle-name -> cinn-name pair of this cluster."""
        for paddle_name, cinn_name in self.varmaps.items():
            logger.info({"graph_key": self.graph_key, "paddle_name": paddle_name, "cinn_name": cinn_name})
import os
import sys
import logging
from logging import handlers


class Logger(object):
    """Configure a named ``logging`` logger that writes to a rotating file.

    Args:
        filename: Log file path; also used as the logger name.
        level: Logging level set on the logger.
        when: Rotation unit passed to ``TimedRotatingFileHandler``.
        backCount: Number of rotated files to keep.
        fmt: Format string for log records.
    """

    def __init__(
        self, filename, level=logging.INFO, when="D", backCount=3, fmt="%(asctime)s" "-%(levelname)s: %(message)s"
    ):

        self.logger = logging.getLogger(filename)
        self.logger.setLevel(level)
        self.format_str = logging.Formatter(fmt)

        # ``logging.getLogger`` returns the same object for the same name, so
        # adding a handler on every construction would duplicate each record.
        target = os.path.abspath(os.fspath(filename))
        has_file_handler = any(
            isinstance(handler, handlers.TimedRotatingFileHandler) and handler.baseFilename == target
            for handler in self.logger.handlers
        )
        if not has_file_handler:
            time_file_handler = handlers.TimedRotatingFileHandler(
                filename=filename,
                when=when,
                interval=1,
                backupCount=backCount,
                encoding="utf-8",
            )
            time_file_handler.setFormatter(self.format_str)
            self.logger.addHandler(time_file_handler)

    def loggerImp(self):
        """Return the underlying ``logging.Logger``."""
        return self.logger


logger = Logger(os.path.join(sys.path[0], "cinn_diff.log")).loggerImp()


def log_init(file_name):
    """Rebind the module-level ``logger`` to a log file named ``file_name``.

    The file is placed under ``sys.path[0]``. Modules that already did
    ``from .logs import logger`` keep their reference to the previous logger.
    """
    global logger
    if file_name:
        logger = Logger(os.path.join(sys.path[0], file_name)).loggerImp()
import os

import collections

from .graph import Graph, Node, Cluster, Group, Pass, construct_graph_by_dot
from .logs import logger


INPUTS_NAME = "cluster_inputs.txt"
OUTPUTS_NAME = "cluster_outputs.txt"
OPS_NAME = "cluster_ops.txt"
GRAPH_NAME = "subgraph.txt"
PADDLE2CINN_VARMAP = "paddle2cinn_varmap.txt"
GRAPH_COMPLATION_KEY = "graph_compilation_key.txt"
VARMAPS_KEY_NAME = "graph_compilation_key"


def read_varmaps(varmaps_file):
    """Parse the paddle->cinn variable map file found in directory ``varmaps_file``.

    A line ``graph_compilation_key:<key>`` opens a section; the following
    ``<paddle_name>:<cinn_name>`` lines belong to it.

    Returns:
        dict mapping each graph key to its ``{paddle_name: cinn_name}`` dict.
    """
    graph2varmaps = {}
    varmaps_file = os.path.join(varmaps_file, PADDLE2CINN_VARMAP)
    cur_graph_key = None
    with open(varmaps_file) as f:
        for raw_line in f:
            var_map = raw_line.strip("\n").split(":")
            if len(var_map) < 2:
                # Blank or malformed line: nothing to record.
                continue
            if var_map[0] == VARMAPS_KEY_NAME:
                cur_graph_key = var_map[1]
                graph2varmaps[cur_graph_key] = {}
                continue
            if not cur_graph_key:
                # Mapping that precedes any key; skip it instead of returning
                # None, which the caller cannot merge.
                continue
            graph2varmaps[cur_graph_key][var_map[0]] = var_map[1]

    return graph2varmaps


def read_tensors(tensors_path):
    """Map variable names to the saved tensor files inside ``tensors_path``.

    The variable name is the part of the file name after the last ``-``;
    files whose name contains ``share_buffer`` are ignored.
    """
    assert os.path.isdir(tensors_path)
    tensors_map = {}
    for file in os.listdir(tensors_path):
        if "share_buffer" in file:
            continue
        # ['matmul_v2_grad', 'input', 'scale_0.tmp_0']
        var_name = file.split("-")[-1]
        # bind var_name to var tensor file
        tensors_map[var_name] = os.path.join(tensors_path, file)
    return tensors_map


def read_graph(graph_file, idx):
    """Parse a ``nodes:`` / ``edges:`` text dump into a ``Graph`` named ``idx``.

    Node lines look like ``<id> : (<name>, <type>)`` and edge lines like
    ``<id> -> <id>``; ``feed``/``fetch`` nodes are dropped, and edges touching
    unknown nodes are ignored.
    """
    assert os.path.isfile(graph_file)
    nodes = {}
    # A list of pairs: a dict keyed by the source would keep only one
    # out-edge per node.
    edges = []

    section = "nodes"
    with open(graph_file) as f:
        for raw_line in f:
            line = raw_line.strip("\n")
            if line.startswith("nodes:"):  # start to record node
                section = "nodes"
                continue
            if line.startswith("edges:"):  # start to record edge
                section = "edges"
                continue
            if not line.strip():
                continue
            if section == "nodes":
                parts = line.split(" : ")
                node_id, node_desc = parts[0], parts[1]
                name, node_type = node_desc[1:-1].split(", ")
                if name in ["feed", "fetch"]:
                    continue
                nodes[node_id] = Node(name, node_type, node_id)
            else:
                parts = line.split(" -> ")
                edges.append((parts[0], parts[1]))

    for src, dst in edges:
        if src not in nodes or dst not in nodes:
            continue
        nodes[src].add_output(nodes[dst])
        nodes[dst].add_input(nodes[src])
    # A real list so that Graph.add_node can append to it.
    return Graph(list(nodes.values()), idx)


def read_strings(string_file):
    """Read the first line of ``string_file`` as a ``", "``-separated list.

    The first and last characters of the line are removed and the final
    split element is dropped.
    """
    assert os.path.isfile(string_file)
    with open(string_file) as f:
        line = f.readline()
    return line[1:-1].split(", ")[:-1]


def read_string(string_file):
    """Return the first line of ``string_file`` without its trailing newline."""
    assert os.path.isfile(string_file)
    with open(string_file) as f:
        line = f.readline()
    # Keys parsed by read_varmaps carry no newline, so strip it here as well.
    return line.rstrip("\n")


def read_cluster(path, idx):
    """Build the ``Cluster`` numbered ``idx`` from the files in directory ``path``."""
    assert os.path.isdir(path), f"{path} must be dir"
    inputs = read_strings(os.path.join(path, INPUTS_NAME))
    outputs = read_strings(os.path.join(path, OUTPUTS_NAME))
    ops = read_strings(os.path.join(path, OPS_NAME))
    graph = read_graph(os.path.join(path, GRAPH_NAME), idx)
    graph_key = read_string(os.path.join(path, GRAPH_COMPLATION_KEY))
    return Cluster(idx, graph, ops, inputs, outputs, graph_key)


def read_cinn_pass(path):
    """Read every group directory under ``path`` into ``{group_idx: Group}``.

    File names are split on ``_``: field 1 is the pass index, field 2 the
    pass name and field 3 one of ``before.txt``/``after.txt``/``before.dot``/
    ``after.dot``.

    Raises:
        ValueError: for a file whose last field is none of the four kinds.
    """
    all_groups = {}
    setters = {
        "after.txt": Pass.set_after_txt,
        "before.txt": Pass.set_before_txt,
        "after.dot": Pass.set_after_dot,
        "before.dot": Pass.set_before_dot,
    }

    def read_graphviz_dot(group_path):
        idx = group_path.split("_")[-1]
        all_passes = {}
        for pass_file in os.listdir(group_path):
            fields = pass_file.split("_")
            pass_idx = int(fields[1])
            if pass_idx not in all_passes:
                all_passes[pass_idx] = Pass(pass_idx)
            all_passes[pass_idx].set_pass_name(fields[2])
            kind = fields[3]
            if kind not in setters:
                raise ValueError(kind + "not support")
            setters[kind](all_passes[pass_idx], os.path.join(group_path, pass_file))
        max_pass_id = max(all_passes.keys())
        all_groups[idx] = Group(idx, all_passes, max_pass_id)

    for file_name in os.listdir(path):
        read_graphviz_dot(os.path.join(path, file_name))

    return all_groups


def read_cinn_graph(path):
    """Read the first dot file of every directory under ``path`` into ``{idx: Graph}``."""
    all_cinn_graphs = {}
    for file_name in os.listdir(path):
        idx = file_name.split("_")[-1]
        graph_dir = os.path.join(path, file_name)
        dot_file = os.path.join(graph_dir, os.listdir(graph_dir)[0])
        nodes, _ = construct_graph_by_dot(dot_file, sep="\n")
        all_cinn_graphs[idx] = Graph(nodes=nodes, name=str("cinn_graph_" + idx))
    return all_cinn_graphs


def set_node_cinn_name(all_clusters):
    """Fill ``cinn_name`` of every var node from its cluster's varmaps."""
    for cluster in all_clusters:
        for node in cluster.graph.nodes:
            if node.is_var():
                node.cinn_name = cluster.varmaps.get(node.name, "")


def set_cluster_varmaps(clusters, varmaps):
    """Attach to each cluster the varmap stored under its graph key.

    Raises:
        KeyError: when a cluster's graph key has no (non-empty) varmap.
    """
    for cluster in clusters:
        tmp_varmaps = varmaps.get(cluster.graph_key, "")
        if not tmp_varmaps:
            raise KeyError(f"can't find graph key {cluster.graph_key} in graph2varmaps")
        cluster.set_varmaps(tmp_varmaps)


def set_clusters_group(clusters, groups, cinn_graphs):
    """Associate each cluster with the group whose cinn graph covers its inputs and outputs.

    Raises:
        ValueError: when a cinn graph has neither inputs nor outputs.
    """
    for cluster in clusters:
        inputs = cluster.inputs
        outputs = cluster.outputs
        for idx, graph in cinn_graphs.items():
            graph_inputs = graph.graph_inputs()
            graph_outputs = graph.graph_outputs()
            if not graph_inputs and not graph_outputs:
                raise ValueError(f"{graph} does not have inputs or outputs")

            if not set(inputs).difference(graph_inputs) and not set(outputs).difference(graph_outputs):
                logger.info(f"group_{idx} belongs to Cluster_{cluster.idx}")
                cluster.cinn_group = groups[idx]


def read_all(root_path="", type="cinn"):
    """Read every artifact saved under ``root_path``.

    Args:
        root_path: Directory holding the ``cluster*``, ``paddle2cinn_varmap``,
            ``saved_tensors``, ``cinn_pass`` and ``cinn_graph`` entries.
        type: Only with ``"cinn"`` are the cluster/varmap/pass/graph entries
            read; ``saved_tensors`` is always read.

    Returns:
        ``allinone`` namedtuple with fields ``all_clusters``, ``all_varmaps``,
        ``all_vars_paths``, ``all_cinn_groups`` and ``cinn2paddle``.
    """
    assert root_path, f"{root_path} can't be None"
    all_clusters = []
    # paddle2cinn_varmaps
    graph2varmaps = {}
    all_vars_paths = {}
    all_cinn_groups = {}
    all_cinn_graphs = {}

    cinn2paddle_varmaps = {}

    allinone = collections.namedtuple(
        "allinone", ["all_clusters", "all_varmaps", "all_vars_paths", "all_cinn_groups", "cinn2paddle"]
    )

    is_cinn = type == "cinn"
    for path in os.listdir(root_path):
        file_path = os.path.join(root_path, path)
        assert os.path.isfile(file_path) or os.path.isdir(file_path), f"{file_path} must be path or dir"
        if is_cinn and path.startswith("cluster"):
            idx = path.split("_")[-1]
            all_clusters.append(read_cluster(file_path, idx))

        if is_cinn and path == "paddle2cinn_varmap":
            graph2varmaps.update(read_varmaps(file_path))

        if path == "saved_tensors":
            all_vars_paths.update(read_tensors(file_path))

        if is_cinn and path == "cinn_pass":
            all_cinn_groups = read_cinn_pass(file_path)

        if is_cinn and path == "cinn_graph":
            all_cinn_graphs = read_cinn_graph(file_path)

    set_cluster_varmaps(all_clusters, graph2varmaps)

    if is_cinn:
        set_node_cinn_name(all_clusters)
        set_clusters_group(all_clusters, all_cinn_groups, all_cinn_graphs)

    return allinone(all_clusters, graph2varmaps, all_vars_paths, all_cinn_groups, cinn2paddle_varmaps)


if __name__ == "__main__":
    read_all()
def run(run_script, base_env, cinn_env, *, rtol=0, atol=0, result_path="./cmp_ret.json"):
    """Run the baseline and CINN models, compare them and dump the result as JSON.

    Args:
        run_script: Path of the model launch script handed to ``cinn_diff.Env``.
        base_env: Environment variables for the baseline run.
        cinn_env: Environment variables for the CINN run.
        rtol: Relative tolerance forwarded to ``cinn_diff.auto_diff``.
        atol: Absolute tolerance forwarded to ``cinn_diff.auto_diff``.
        result_path: File that receives the comparison result.

    The value returned by ``auto_diff`` is written with ``json.dump``, so it
    has to be JSON-serialisable.
    """
    run_env = cinn_diff.Env(run_script, base_env, cinn_env)
    run_env.run_base_model()
    run_env.run_cinn_model()
    ret = cinn_diff.auto_diff(run_env.base_path, run_env.cinn_path, rtol=rtol, atol=atol)
    with open(result_path, "w") as jsonf:
        json.dump(ret, jsonf, indent=4)


if __name__ == "__main__":
    # NOTE(review): the device index and script path below are machine-specific.
    _base_env = {
        "CUDA_VISIBLE_DEVICES": "7",
        "NVIDIA_TF32_OVERRIDE": "1",
        "CUDA_LAUNCH_BLOCKING": "1",
        # The two keys below were spelled "...deterministc", which does not
        # match the flag names, so deterministic kernels were never requested
        # even though the comparison uses zero tolerance.
        "FLAGS_cudnn_deterministic": "1",
        "FLAGS_cinn_cudnn_deterministic": "1",
        "FLAGS_prim_all": "true",
    }
    _cinn_env = {
        "FLAGS_use_cinn": "1",
        "FLAGS_deny_cinn_ops": "reduce_sum",
        "FLAGS_use_reduce_split_pass": "1",
        "FLAGS_nvrtc_compile_to_cubin": "0",
        "FLAGS_cinn_use_op_fusion": "1",
        "FLAGS_cinn_parallel_compile_size": "8",
        "FLAGS_cinn_pass_visualize_dir": "",
    }
    run_script = "/root/dev/PaddleNLP/model_zoo/bert/run_bert.sh"
    run(run_script, _base_env, _cinn_env)
@contextmanager
def suppress_stdout():
    """Temporarily point ``sys.stdout``, ``sys.stderr`` and ``sys.stdin`` at ``os.devnull``.

    The three original stream objects are put back when the ``with`` body
    finishes, whether it returns normally or raises. The null device is opened
    write-only, so reading from ``sys.stdin`` inside the body is not supported.
    """
    saved_streams = (sys.stdout, sys.stderr, sys.stdin)
    with open(os.devnull, "w") as devnull:
        sys.stdout = sys.stderr = sys.stdin = devnull
        try:
            yield
        finally:
            sys.stdout, sys.stderr, sys.stdin = saved_streams


def retry(max_times=1):
    """Build a decorator that re-invokes a function after it raises.

    The wrapped function is called up to ``max_times + 1`` times. The first
    successful return value is passed through. Every failure is logged as a
    warning instead of being dropped silently. If every attempt raises, the
    wrapper returns ``None`` (unchanged best-effort behaviour; the exception
    is not re-raised).

    Args:
        max_times: Number of extra attempts after the first failure.

    Returns:
        A decorator preserving the wrapped function's metadata.
    """
    import logging

    log = logging.getLogger(__name__)

    def retry_decorator(func):
        @wraps(func)
        def inner(*args, **kwargs):
            failures = 0
            while failures <= max_times:
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    failures += 1
                    log.warning(
                        "%s raised %r (failure %d, extra attempts allowed: %s)",
                        getattr(func, "__name__", repr(func)),
                        exc,
                        failures,
                        max_times,
                    )
            # All attempts failed: keep the historical contract of returning None.
            return None

        return inner

    return retry_decorator


# ANSI escape format: \033[<display mode>;<foreground>;<background>m ... \033[0m
# Display mode: 0 (default), 1 (bold), 22 (not bold), 4 (underline), 24 (no underline),
#               5 (blink), 25 (no blink), 7 (reverse video), 27 (no reverse video)
# Foreground: 30 (black), 31 (red), 32 (green), 33 (yellow), 34 (blue), 35 (magenta),
#             36 (cyan), 37 (white)
# Background: 40 (black), 41 (red), 42 (green), 43 (yellow), 44 (blue), 45 (magenta),
#             46 (cyan), 47 (white)


class console:
    """Helpers that wrap text in ANSI colour escape sequences.

    Every method returns the text surrounded by a foreground-colour code and
    the reset code ``\\033[0m``. Previously each method was a bare ``return``
    that discarded its argument.

    NOTE(review): the colour chosen for ``info`` (green), ``warning`` (yellow)
    and ``error`` (red) follows common convention - confirm against the
    intended palette.
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def _paint(code, text):
        # Wrap ``text`` in the given foreground colour code and reset afterwards.
        return f"\033[{code}m{text}\033[0m"

    # The parameter name ``str`` shadows the builtin but is kept so existing
    # keyword calls continue to work.
    @classmethod
    def red(cls, str):
        """Return ``str`` coloured red."""
        return cls._paint(31, str)

    @classmethod
    def info(cls, str):
        """Return ``str`` coloured green."""
        return cls._paint(32, str)

    @classmethod
    def error(cls, str):
        """Return ``str`` coloured red."""
        return cls._paint(31, str)

    @classmethod
    def warning(cls, str):
        """Return ``str`` coloured yellow."""
        return cls._paint(33, str)