-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
318 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
{ | ||
// Use IntelliSense to learn about possible attributes. | ||
// Hover to view descriptions of existing attributes. | ||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 | ||
"version": "0.2.0", | ||
"configurations": [ | ||
{ | ||
"name": "(gdb) 启动", | ||
"type": "cppdbg", | ||
"request": "launch", | ||
"program": "${workspaceFolder}/example/example.out", | ||
"args": [], | ||
"stopAtEntry": false, | ||
"cwd": "${workspaceFolder}/example", | ||
"environment": [{ | ||
"name": "PATH", | ||
"value": "/opt/conda/envs/py38/bin" | ||
}], | ||
"externalConsole": false, | ||
"MIMode": "gdb", | ||
"setupCommands": [ | ||
{ | ||
"description": "为 gdb 启用整齐打印", | ||
"text": "-enable-pretty-printing", | ||
"ignoreFailures": true | ||
}, | ||
{ | ||
"description": "将反汇编风格设置为 Intel", | ||
"text": "-gdb-set disassembly-flavor intel", | ||
"ignoreFailures": true | ||
} | ||
] | ||
}, | ||
{ | ||
"name": "Python: Current File", | ||
"type": "python", | ||
"request": "launch", | ||
"program": "${file}", | ||
"console": "integratedTerminal" | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,69 @@ | ||
import os | ||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | ||
import logging | ||
logging.basicConfig(level=logging.DEBUG) | ||
|
||
from pathlib import Path | ||
import sys | ||
import numpy as np | ||
sys.path.append(str(Path(__file__).parent/".."/"src"/"python")) | ||
|
||
from PyCXpress import debug_array | ||
from PyCXpress import InputDataSet, OutputDataSet | ||
from PyCXpress import TensorMeta, ModelAnnotationCreator, ModelAnnotationType, ModelRuntimeType | ||
from PyCXpress import convert_to_spec_tuple | ||
from contextlib import nullcontext | ||
|
||
def show(a): | ||
debug_array(a) | ||
def show(a: np.array): | ||
logging.info(f"array data type: {a.dtype}") | ||
logging.info(f"array data shape: {a.shape}") | ||
logging.info(f"array data: ") | ||
logging.info(a) | ||
|
||
InputFields = dict( | ||
data_to_be_reshaped=TensorMeta(dtype=np.float_, | ||
shape=(100,), | ||
), | ||
new_2d_shape=TensorMeta(dtype=np.uint8, | ||
shape=-2,) | ||
) | ||
|
||
|
||
class InputDataSet(metaclass=ModelAnnotationCreator, fields=InputFields, type=ModelAnnotationType.Input, mode=ModelRuntimeType.EagerExecution): | ||
pass | ||
|
||
|
||
OutputFields = dict( | ||
output_a=TensorMeta(dtype=np.float_, | ||
shape=(10, 10),), | ||
) | ||
|
||
|
||
class OutputDataSet(metaclass=ModelAnnotationCreator, fields=OutputFields, type=ModelAnnotationType.Output, mode=ModelRuntimeType.EagerExecution): | ||
pass | ||
|
||
|
||
def init(): | ||
return InputDataSet(), OutputDataSet() | ||
return InputDataSet(), OutputDataSet(), tuple((*convert_to_spec_tuple(InputFields.values()), *convert_to_spec_tuple(OutputFields.values()))), tuple(OutputFields.keys()) | ||
|
||
def model(input: InputDataSet, output: OutputDataSet): | ||
with nullcontext(): | ||
output.output_a = input.input_a + input.input_b | ||
# print(input.data_to_be_reshaped) | ||
# print(input.new_2d_shape) | ||
output.output_a = input.data_to_be_reshaped.reshape(input.new_2d_shape) | ||
# print(output.output_a) | ||
|
||
def main(): | ||
input_data, output_data, spec = init() | ||
print(spec) | ||
|
||
input_data.set_buffer_value("data_to_be_reshaped", np.arange(12, dtype=np.float_)) | ||
print(input_data.data_to_be_reshaped) | ||
input_data.set_buffer_value("new_2d_shape", np.array([3, 4]).astype(np.uint8)) | ||
print(input_data.new_2d_shape) | ||
output_data.set_buffer_value("output_a", np.arange(12)*0) | ||
|
||
model(input_data, output_data) | ||
print(output_data.output_a) | ||
print(output_data.get_buffer_shape("output_a")) | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.