diff --git a/src/netlink.c b/src/netlink.c index 9bf2f84edda8..4392f7ba57d1 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -567,10 +567,8 @@ static int wg_set_device(struct sk_buff *skb, struct genl_info *info) private_key); list_for_each_entry_safe(peer, temp, &wg->peer_list, peer_list) { - if (wg_noise_precompute_static_static(peer)) - wg_noise_expire_current_peer_keypairs(peer); - else - wg_peer_remove(peer); + BUG_ON(!wg_noise_precompute_static_static(peer)); + wg_noise_expire_current_peer_keypairs(peer); } wg_cookie_checker_precompute_device_keys(&wg->cookie_checker); up_write(&wg->static_identity.lock); diff --git a/src/noise.c b/src/noise.c index 269b69faf11a..d0b897f40a13 100644 --- a/src/noise.c +++ b/src/noise.c @@ -46,17 +46,21 @@ void __init wg_noise_init(void) /* Must hold peer->handshake.static_identity->lock */ bool wg_noise_precompute_static_static(struct wg_peer *peer) { - bool ret = true; + bool ret; down_write(&peer->handshake.lock); - if (peer->handshake.static_identity->has_identity) + if (peer->handshake.static_identity->has_identity) { ret = curve25519( peer->handshake.precomputed_static_static, peer->handshake.static_identity->static_private, peer->handshake.remote_static); - else + } else { + u8 empty[NOISE_PUBLIC_KEY_LEN] = { 0 }; + + ret = curve25519(empty, empty, peer->handshake.remote_static); memset(peer->handshake.precomputed_static_static, 0, NOISE_PUBLIC_KEY_LEN); + } up_write(&peer->handshake.lock); return ret; }