forked from FederatedAI/FATE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathftl_param.py
197 lines (173 loc) · 8.72 KB
/
ftl_param.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections
import collections.abc
import copy
from types import SimpleNamespace

from federatedml.param.intersect_param import IntersectParam
from federatedml.param.base_param import BaseParam, deprecated_param
from federatedml.util import consts
from federatedml.param.encrypt_param import EncryptParam
from federatedml.param.encrypted_mode_calculation_param import EncryptedModeCalculatorParam
from federatedml.param.predict_param import PredictParam
from federatedml.param.callback_param import CallbackParam
# FTLParam options that are deprecated in favour of the corresponding callback_param
# options; registered via @deprecated_param below and inspected again in FTLParam.check().
deprecated_param_list = ["validation_freqs", "metrics"]


@deprecated_param(*deprecated_param_list)
class FTLParam(BaseParam):
    """Parameter object for the FTL model; see ``__init__`` for the individual options."""

    def __init__(self, alpha=1, tol=0.000001,
                 n_iter_no_change=False, validation_freqs=None, optimizer={'optimizer': 'Adam', 'learning_rate': 0.01},
                 nn_define={}, epochs=1, intersect_param=IntersectParam(consts.RSA), config_type='keras', batch_size=-1,
                 encrypte_param=EncryptParam(),
                 encrypted_mode_calculator_param=EncryptedModeCalculatorParam(mode="confusion_opt"),
                 predict_param=PredictParam(), mode='plain', communication_efficient=False,
                 local_round=5, callback_param=CallbackParam()):
        """
        Parameters
        ----------
        alpha : float
            a loss coefficient defined in paper, it defines the importance of alignment loss
        tol : float
            loss tolerance
        n_iter_no_change : bool
            check loss convergence or not
        validation_freqs : None or positive integer or container object in python
            Deprecated; use callback_param's 'validation_freqs' instead.
            Do validation in training process or Not.
            if equals None, will not do validation in train process;
            if equals positive integer, will validate data every validation_freqs epochs passes;
            if container object in python, will validate data if epochs belong to this container.
            e.g. validation_freqs = [10, 15], will validate data when epoch equals to 10 and 15.
            The default value is None, 1 is suggested. You can set it to a number larger than 1 in order to
            speed up training by skipping validation rounds. When it is larger than 1, a number which is
            divisible by "epochs" is recommended, otherwise, you will miss the validation scores
            of last training epoch.
        optimizer : str or dict
            optimizer method, accept following types:
            1. a string, one of "Adadelta", "Adagrad", "Adam", "Adamax", "Nadam", "RMSprop", "SGD"
            2. a dict, with a required key-value pair keyed by "optimizer",
                with optional key-value pairs such as learning rate.
            defaults to {'optimizer': 'Adam', 'learning_rate': 0.01}.
            check() normalizes it to SimpleNamespace(optimizer=<name>, kwargs=<dict>).
        nn_define : dict
            a dict represents the structure of neural network, it can be output by tf-keras
        epochs : int
            epochs num
        intersect_param
            define the intersect method
        config_type : {'keras'}
            config type; 'keras' is the only value accepted by check()
        batch_size : int
            batch size when computing transformed feature embedding, -1 use full data.
        encrypte_param
            encrypted param (stored as ``self.encrypt_param``)
        encrypted_mode_calculator_param
            encrypted mode calculator param
        predict_param
            predict param
        mode: {"plain", "encrypted"}
            plain: will not use any encrypt algorithms, data exchanged in plaintext
            encrypted: use paillier to encrypt gradients
        communication_efficient: bool
            will use communication efficient or not. when communication efficient is enabled, FTL model will
            update gradients by several local rounds using intermediate data
        local_round: int
            local update round when using communication efficient
        callback_param
            callback settings; supersedes the deprecated validation_freqs / metrics
        """
        super(FTLParam, self).__init__()
        self.alpha = alpha
        self.tol = tol
        self.n_iter_no_change = n_iter_no_change
        self.validation_freqs = validation_freqs
        self.optimizer = optimizer
        self.nn_define = nn_define
        self.epochs = epochs
        # Param-object arguments are deep-copied so that instances never share
        # (and mutate) the default objects created once in the signature.
        self.intersect_param = copy.deepcopy(intersect_param)
        self.config_type = config_type
        self.batch_size = batch_size
        self.encrypted_mode_calculator_param = copy.deepcopy(encrypted_mode_calculator_param)
        self.encrypt_param = copy.deepcopy(encrypte_param)
        self.predict_param = copy.deepcopy(predict_param)
        self.mode = mode
        self.communication_efficient = communication_efficient
        self.local_round = local_round
        self.callback_param = copy.deepcopy(callback_param)

    def check(self):
        """Validate every option, raising on the first invalid one.

        Side effects:
            * ``self.optimizer`` is replaced by ``SimpleNamespace(optimizer=..., kwargs=...)``.
              An already-normalized value is accepted, so calling check() again is safe.
            * If a deprecated option (``validation_freqs`` / ``metrics``) was set and
              ``callback_param`` was not fed by the user, ``callback_param.callbacks`` is
              set to ``["PerformanceEvaluate"]``, and the deprecated values are copied
              into ``callback_param`` when ``_warn_to_deprecate_param`` reports them.

        Raises:
            ValueError: invalid optimizer, config_type, tol, epochs, nn_define,
                batch_size or validation_freqs, or a deprecated option combined with
                an explicitly fed callback_param.
            AssertionError: communication_efficient is not a bool, or mode is not
                'encrypted' / 'plain'.
        """
        self.intersect_param.check()
        self.encrypt_param.check()
        self.encrypted_mode_calculator_param.check()

        self.optimizer = self._parse_optimizer(self.optimizer)

        supported_config_type = ["keras"]
        if self.config_type not in supported_config_type:
            raise ValueError(f"config_type should be one of {supported_config_type}")

        if not isinstance(self.tol, (int, float)):
            raise ValueError("tol should be numeric")

        if not isinstance(self.epochs, int) or self.epochs <= 0:
            raise ValueError("epochs should be a positive integer")

        if self.nn_define and not isinstance(self.nn_define, dict):
            raise ValueError("nn_define should be a dict defining the structure of neural network")

        if self.batch_size != -1:
            if not isinstance(self.batch_size, int) \
                    or self.batch_size < consts.MIN_BATCH_SIZE:
                raise ValueError(
                    " {} not supported, should be larger than 10 or -1 represent for all data".format(self.batch_size))

        # A deprecated option implies performance evaluation through callbacks;
        # it must not be combined with a user-supplied callback_param.
        for p in deprecated_param_list:
            if self._deprecated_params_set.get(p):
                if "callback_param" in self.get_user_feeded():
                    raise ValueError(f"{p} and callback param should not be set simultaneously,"
                                     f"{self._deprecated_params_set}, {self.get_user_feeded()}")
                else:
                    self.callback_param.callbacks = ["PerformanceEvaluate"]
                break

        descr = "ftl's"

        if self._warn_to_deprecate_param("validation_freqs", descr, "callback_param's 'validation_freqs'"):
            self.callback_param.validation_freqs = self.validation_freqs

        if self._warn_to_deprecate_param("metrics", descr, "callback_param's 'metrics'"):
            # NOTE(review): self.metrics is not set in __init__; presumably BaseParam
            # assigns it when the user feeds the deprecated 'metrics' option — confirm.
            self.callback_param.metrics = self.metrics

        if self.validation_freqs is None:
            pass
        elif isinstance(self.validation_freqs, int):
            if self.validation_freqs < 1:
                raise ValueError("validation_freqs should be larger than 0 when it's integer")
        # collections.Container (the old alias) was removed in Python 3.10.
        elif not isinstance(self.validation_freqs, collections.abc.Container):
            raise ValueError("validation_freqs should be None or positive integer or container")

        assert isinstance(self.communication_efficient, bool), 'communication efficient must be a boolean'
        assert self.mode in [
            'encrypted', 'plain'], 'mode options: encrypted or plain, but {} is offered'.format(
            self.mode)

        self.check_positive_integer(self.epochs, 'epochs')
        self.check_positive_number(self.alpha, 'alpha')
        self.check_positive_integer(self.local_round, 'local round')

    @staticmethod
    def _parse_optimizer(opt):
        """
        Normalize an optimizer config into ``SimpleNamespace(optimizer=<name>, kwargs=<dict>)``.

        Examples:
            1. "optimizer": "SGD"
            2. "optimizer": {
                    "optimizer": "SGD",
                    "learning_rate": 0.05
                }

        A value that is already a normalized SimpleNamespace (from an earlier
        check() call) is returned unchanged.

        Raises:
            ValueError: a dict without a truthy "optimizer" entry, or an unsupported type.
        """
        if isinstance(opt, SimpleNamespace) and hasattr(opt, "optimizer") and hasattr(opt, "kwargs"):
            return opt
        if isinstance(opt, str):
            return SimpleNamespace(optimizer=opt, kwargs={})
        if isinstance(opt, dict):
            optimizer = opt.get("optimizer")
            if not optimizer:
                raise ValueError(f"optimizer config: {opt} invalid")
            # every entry other than "optimizer" is forwarded as an optimizer keyword argument
            kwargs = {k: v for k, v in opt.items() if k != "optimizer"}
            return SimpleNamespace(optimizer=optimizer, kwargs=kwargs)
        raise ValueError(f"invalid type for optimize: {type(opt)}")