Skip to content

Commit

Permalink
Bugfix.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 478591776
  • Loading branch information
tensorflower-gardener committed Oct 3, 2022
1 parent 3f6d0ac commit 0738d6f
Showing 1 changed file with 15 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def _run_trained_attack(attack_input: AttackInputData,
left_out_indices = prepared_attacker_data.left_out_indices
features = prepared_attacker_data.features_all
labels = prepared_attacker_data.labels_all
sample_weights = prepared_attacker_data.sample_weights_all

# We are going to train multiple models on disjoint subsets of the data
# (`features`, `labels`), so we can get the membership scores of all samples,
Expand All @@ -85,8 +86,21 @@ def _run_trained_attack(attack_input: AttackInputData,
# Make sure one sample only got score predicted once
assert np.all(np.isnan(scores[test_indices]))

# Setup sample weights if provided.
if sample_weights is not None:
# If sample weights are provided, only the weights at the training indices
# are used for training. The weights at the test indices are not used
# during prediction. Not that 'train' and 'test' refer to the data for the
# attack models, not the data for the original models.
sample_weights_train = np.squeeze(sample_weights[train_indices])
else:
sample_weights_train = None

attacker = models.create_attacker(attack_type, backend=backend)
attacker.train_model(features[train_indices], labels[train_indices])
attacker.train_model(
features[train_indices],
labels[train_indices],
sample_weight=sample_weights_train)
predictions = attacker.predict(features[test_indices])
scores[test_indices] = predictions

Expand Down

0 comments on commit 0738d6f

Please sign in to comment.