-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.mpc
77 lines (53 loc) · 2.64 KB
/
inference.mpc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from Compiler.script_utils import output_utils
from Compiler.script_utils.data import data
from Compiler.script_utils.data import AbstractInputLoader
from Compiler import ml
from Compiler import library
from Compiler.script_utils.audit import shap
from Compiler.script_utils import config, timers, input_consistency
class InferenceConfig(config.BaseAuditModel):
    """Configuration for the inference script, parsed from program args.

    Extends the project's BaseAuditModel (which supplies the shared audit
    options such as dataset, emulate, consistency_check, etc. — see
    Compiler.script_utils.config).
    """
    # Number of test samples to run inference on; -1 means "all samples".
    n_samples: int = 1  # -1 = all
    # Mini-batch size used both for data loading and model.eval.
    batch_size: int = 1
# ---- Program setup: parse CLI args into a typed config and apply options ----

# Let MP-SPDZ consume compiler options passed on the command line, then
# build the typed InferenceConfig from the remaining program args.
program.options_from_args()
cfg = config.from_program_args(program.args, InferenceConfig)

# Apply the numeric/runtime options from the config.
# (Further MP-SPDZ knobs — edabits, split, bit length, dabit — can be set
# here if an experiment needs them; they are off by default.)
sfix.round_nearest = cfg.round_nearest   # rounding mode for fixed-point ops
program.use_trunc_pr = cfg.trunc_pr      # probabilistic truncation
ml.set_n_threads(cfg.n_threads)          # parallelism for the ML layers
# Load model + datasets inside a dedicated timer so data-loading cost is
# reported separately from inference.
library.start_timer(timer_id=timers.TIMER_LOAD_DATA)
loader_kwargs = dict(
    dataset=cfg.dataset,
    audit_trigger_idx=cfg.audit_trigger_idx,
    batch_size=cfg.batch_size,
    debug=cfg.debug,
    emulate=cfg.emulate,
    consistency_check=cfg.consistency_check,
    sha3_approx_factor=cfg.sha3_approx_factor,
    n_target_test_samples=cfg.n_samples,
)
input_loader: AbstractInputLoader = data.get_inference_input_loader(**loader_kwargs)
library.stop_timer(timer_id=timers.TIMER_LOAD_DATA)
library.start_timer(timer_id=timers.TIMER_INFERENCE)
# eval here
# Pull the evaluation set from the loader; per the original author's note
# this falls back to the train dataset when no test dataset exists.
inf_samples, inf_labels = input_loader.test_dataset() # train dataset in case we dont have test dataset

# Restrict to the first n_samples entries.
# NOTE(review): when cfg.n_samples == -1 ("all" per InferenceConfig),
# get_part(0, -1) and min(batch_size, -1) below look broken — verify
# MP-SPDZ get_part semantics and whether -1 is ever used in practice.
inf_samples = inf_samples.get_part(0, cfg.n_samples)
inf_labels = inf_labels.get_part(0, cfg.n_samples)

model = input_loader.model()
model.summary()

# Disable loss computation on the final layer: we only need predictions,
# not training loss.
model.layers[-1].compute_loss = False

# Batch size must not exceed the number of samples being evaluated.
prediction_results = model.eval(inf_samples, batch_size=min(cfg.batch_size, cfg.n_samples))
# n_correct, avg_loss = model.reveal_correctness(data=inf_samples, truth=inf_labels, batch_size=input_loader.batch_size(), running=True)
# print_ln("  n_correct=%s  n_samples=%s  avg_loss=%s", n_correct, len(inf_samples), avg_loss)
library.stop_timer(timer_id=timers.TIMER_INFERENCE)
# ---- Commit the (input, prediction) pair for the consistency check ----
library.start_timer(timer_id=timers.TIMER_OUTPUT_COMMIT)

first_prediction = prediction_results[0]
if isinstance(first_prediction, sfix):
    # A scalar prediction is wrapped into a length-1 Array so the output
    # object always carries an array-valued y.
    wrapped = Array(1, sfix)
    wrapped[0] = first_prediction
    first_prediction = wrapped

output_object = input_consistency.InputObject(
    x=[inf_samples[0]],
    y=[first_prediction],
)
input_consistency.output(
    output_object,
    cfg.consistency_check,
    cfg.n_threads,
    cfg.sha3_approx_factor,
    cfg.cerebro_output_approx_factor,
)

library.stop_timer(timer_id=timers.TIMER_OUTPUT_COMMIT)
# Debug reveal of predictions/labels intentionally left disabled.