tunsafe-clang15/wireguard_proto.cpp

1308 lines
45 KiB
C++
Raw Normal View History

// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#include "stdafx.h"
#include "wireguard_proto.h"
#include "crypto/chacha20poly1305.h"
#include "crypto/blake2s.h"
#include "crypto/curve25519-donna.h"
#include "crypto/aesgcm/aes.h"
#include "crypto/siphash.h"
#include "tunsafe_endian.h"
#include "util.h"
#include "crypto_ops.h"
#include "bit_ops.h"
#include "tunsafe_cpu.h"
#include <algorithm>
#include <assert.h>
#include <stdlib.h>
#include <string.h>
// 8-byte labels that are hashed together with a static public key (see the
// BlakeX2 calls in WgDevice::Initialize and WgPeer::Initialize) to derive the
// cookie-encryption key and the mac1 key respectively.
static const uint8 kLabelCookie[] = {'c', 'o', 'o', 'k', 'i', 'e', '-', '-'};
static const uint8 kLabelMac1[] = {'m', 'a', 'c', '1', '-', '-', '-', '-'};
// Precomputed initial handshake hash; copied into `hi` at the start of every handshake.
static const uint8 kWgInitHash[WG_HASH_LEN] = {0x22,0x11,0xb3,0x61,0x08,0x1a,0xc5,0x66,0x69,0x12,0x43,0xdb,0x45,0x8a,0xd5,0x32,0x2d,0x9c,0x6c,0x66,0x22,0x93,0xe8,0xb7,0x0e,0xe1,0x9c,0x65,0xba,0x07,0x9e,0xf3};
// Precomputed initial chaining key; copied into `ci` at the start of every handshake.
static const uint8 kWgInitChainingKey[WG_HASH_LEN] = {0x60,0xe2,0x6d,0xae,0xf3,0x27,0xef,0xc0,0x2e,0xc3,0x35,0xe2,0xa0,0x25,0xd2,0xd0,0x16,0xeb,0x42,0x06,0xf8,0x72,0x77,0xf5,0x2d,0x38,0xd1,0x98,0x8b,0x78,0xcd,0x36};
// Curve25519 base point (first byte 9, rest zero); passed to curve25519_donna
// to turn a private key into the matching public key.
static const uint8 kCurve25519Basepoint[32] = {9};
// Both route tables start out empty; the vectors' own constructors do the work.
IpToPeerMap::IpToPeerMap() {}

// Nothing to release explicitly: the map only stores non-owning peer pointers.
IpToPeerMap::~IpToPeerMap() {}
// Registers the IPv4 network |addr|/|cidr| as being routed to |peer|.
//
// addr - pointer to the 4 address bytes (read with ReadBE32).
// cidr - prefix length, 0..32.
// peer - opaque peer pointer returned by later lookups.
//
// Returns true when the entry was stored, false when |cidr| is out of range.
bool IpToPeerMap::InsertV4(const void *addr, int cidr, void *peer) {
  // Shifting a 32-bit value by a negative amount or by 32 or more bits is
  // undefined behavior, so refuse prefix lengths outside of 0..32.
  if (cidr < 0 || cidr > 32)
    return false;
  // cidr == 32 is special-cased because it would otherwise need a shift by 32.
  uint32 mask = (cidr == 32) ? 0xffffffff : ~(0xffffffff >> cidr);
  // Store the network address with the host bits already cleared.
  Entry4 e = {ReadBE32(addr) & mask, mask, peer};
  ipv4_.push_back(e);
  return true;
}
// Registers the IPv6 network |addr|/|cidr| as being routed to |peer|.
//
// addr - pointer to the 16 address bytes; they are stored as given
//        (host bits are not masked off).
// cidr - prefix length, 0..128.
// peer - opaque peer pointer returned by later lookups.
//
// Returns true when the entry was stored, false when |cidr| is out of range.
bool IpToPeerMap::InsertV6(const void *addr, int cidr, void *peer) {
  // A prefix length outside 0..128 can never describe an IPv6 network.
  if (cidr < 0 || cidr > 128)
    return false;
  Entry6 e;
  e.cidr_len = cidr;
  e.peer = peer;
  memcpy(e.ip, addr, 16);
  ipv6_.push_back(e);
  return true;
}
// Returns the peer of the most specific IPv4 entry that contains |ip|
// (host byte order), or NULL when no entry matches. Among entries with
// the same mask the one inserted last wins.
void *IpToPeerMap::LookupV4(uint32 ip) {
  void *winner = NULL;
  uint32 winner_mask = 0;
  for (const auto &entry : ipv4_) {
    // Skip networks that do not contain the address.
    if ((ip & entry.mask) != entry.ip)
      continue;
    // Skip networks that are less specific than the best one so far.
    if (entry.mask < winner_mask)
      continue;
    winner_mask = entry.mask;
    winner = entry.peer;
  }
  return winner;
}
// Returns the peer of the first IPv4 entry whose mask is zero (a /0 route),
// or NULL when there is none.
void *IpToPeerMap::LookupV4DefaultPeer() {
  for (const auto &entry : ipv4_) {
    if (entry.mask == 0)
      return entry.peer;
  }
  return NULL;
}
// Returns the peer of the first IPv6 entry with prefix length zero (a /0 route),
// or NULL when there is none.
void *IpToPeerMap::LookupV6DefaultPeer() {
  for (const auto &entry : ipv6_) {
    if (entry.cidr_len == 0)
      return entry.peer;
  }
  return NULL;
}
static int CalculateIPv6CommonPrefix(const uint8 *a, const uint8 *b) {
uint64 x = ToBE64(*(uint64*)&a[0] ^ *(uint64*)&b[0]);
uint64 y = ToBE64(*(uint64*)&a[8] ^ *(uint64*)&b[8]);
return x ? 64 - FindHighestSetBit64(x) : 128 - FindHighestSetBit64(y);
}
// Returns the peer of the most specific IPv6 entry that contains |addr|
// (pointer to 16 address bytes), or NULL when no entry matches.
//
// An entry matches when the address shares at least cidr_len leading bits
// with the entry's address. Among the matching entries the one with the
// longest prefix length wins, mirroring what LookupV4 does with masks;
// entries with equal prefix length are resolved in favor of the one
// inserted last. Ranking by the number of shared bits instead would let a
// short prefix whose stored address happens to resemble |addr| beat a more
// specific route.
void *IpToPeerMap::LookupV6(const void *addr) {
  int best_cidr = 0;
  void *best_peer = NULL;
  for (auto it = ipv6_.begin(); it != ipv6_.end(); ++it) {
    int common = CalculateIPv6CommonPrefix((const uint8*)addr, it->ip);
    if (common >= it->cidr_len && it->cidr_len >= best_cidr) {
      best_cidr = it->cidr_len;
      best_peer = it->peer;
    }
  }
  return best_peer;
}
void IpToPeerMap::RemovePeer(void *peer) {
{
size_t n = ipv4_.size();
Entry4 *r = &ipv4_[0], *w = r;
for (size_t i = 0; i != n; i++, r++) {
if (r->peer != peer)
*w++ = *r;
}
ipv4_.resize(w - &ipv4_[0]);
}
{
size_t n = ipv6_.size();
Entry6 *r = &ipv6_[0], *w = r;
for (size_t i = 0; i != n; i++, r++) {
if (r->peer != peer)
*w++ = *r;
}
ipv6_.resize(w - &ipv6_[0]);
}
}
// Starts with an all-clear bitmap and sequence number 0 as the next expected one.
ReplayDetector::ReplayDetector() {
  memset(bitmap_, 0, sizeof(bitmap_));
  expected_seq_nr_ = 0;
}

// Nothing to release.
ReplayDetector::~ReplayDetector() {}
// Checks |seq_nr| against the replay window and records it as seen.
//
// Returns true when the sequence number is inside the window and has not
// been recorded before; returns false when it is too old (WINDOW_SIZE or
// more behind the next expected number) or its bit is already set.
bool ReplayDetector::CheckReplay(uint64 seq_nr) {
// Index of the bitmap word holding the bit for seq_nr (before ring masking).
uint64 slot = seq_nr / BITS_PER_ENTRY;
if (seq_nr >= expected_seq_nr_) {
// The window moves forward. prev_slot is the word index of the highest
// sequence number accepted so far; it wraps to all-ones while
// expected_seq_nr_ is still 0, which makes n below equal slot + 1.
uint64 prev_slot = (expected_seq_nr_ + BITS_PER_ENTRY - 1) / BITS_PER_ENTRY - 1, n;
if ((n = slot - prev_slot) != 0) {
// Clear each word that is newly entered, capped at one full pass over the ring.
size_t nn = (size_t)std::min<uint64>(n, BITMAP_SIZE);
do {
bitmap_[(prev_slot + nn) & BITMAP_MASK] = 0;
} while (--nn);
}
expected_seq_nr_ = seq_nr + 1;
} else if (seq_nr + WINDOW_SIZE <= expected_seq_nr_) {
// Too far behind the newest accepted sequence number.
return false;
}
// Set the bit for seq_nr and report whether it was clear before.
uint32 mask = 1 << (seq_nr & (BITS_PER_ENTRY - 1)), prev;
prev = bitmap_[slot & BITMAP_MASK];
bitmap_[slot & BITMAP_MASK] = prev | mask;
return (prev & mask) == 0;
}
// Sets up an empty device: no peers, obfuscation off, fresh cookie secret and
// a randomly seeded internal number generator.
WgDevice::WgDevice() {
  // Plain state.
  peers_ = NULL;
  header_obfuscation_ = false;
  last_complete_handskake_timestamp_ = 0;
  memset(&compression_header_, 0, sizeof(compression_header_));

  // Clock snapshots; the cookie secret is considered created "now".
  low_resolution_timestamp_ = cookie_secret_timestamp_ = OsGetMilliseconds();

  // Randomness: the cookie secret and the seed of GetRandomNumber().
  OsGetRandomBytes(cookie_secret_, sizeof(cookie_secret_));
  OsGetRandomBytes((uint8*)random_number_input_, sizeof(random_number_input_));
  // Slot 0 forces GetRandomNumber() to generate a batch on its first call.
  next_rng_slot_ = 0;
}

// Nothing is released here.
WgDevice::~WgDevice() {}
// Periodic housekeeping: stores |now| as the low resolution timestamp and,
// when the rate limiter is in use, feeds it five fresh random numbers.
void WgDevice::SecondLoop(uint64 now) {
  low_resolution_timestamp_ = now;
  if (!rate_limiter_.is_used())
    return;
  uint32 seeds[5];
  for (uint32 &seed : seeds)
    seed = GetRandomNumber();
  rate_limiter_.Periodic(seeds);
}
// Allocates an unused, non-zero random key id and maps it to (peer, kp).
//
// The id is stored in kp->local_key_id, or in peer->local_key_id_during_hs_
// when |kp| is NULL. If that field already held an id, the old id is removed
// from the lookup table. Returns the new id.
uint32 WgDevice::InsertInKeyIdLookup(WgPeer *peer, WgKeypair *kp) {
  assert(peer);
  uint32 candidate;
  std::pair<WgPeer*, WgKeypair*> *entry;
  do {
    // Zero means "no id", so never hand it out.
    do {
      candidate = GetRandomNumber();
    } while (candidate == 0);
    // operator[] creates an empty entry when the id is not present yet.
    entry = &key_id_lookup_[candidate];
  } while (entry->first != NULL);

  *entry = std::make_pair(peer, kp);

  uint32 &id_field = (kp ? kp->local_key_id : peer->local_key_id_during_hs_);
  const uint32 previous = id_field;
  id_field = candidate;
  if (previous)
    key_id_lookup_.erase(previous);
  return candidate;
}
// Returns the next 32-bit value from the device's internal generator.
//
// Values are handed out from random_number_output_, last slot first. When
// the batch is used up a new one is produced by hashing
// random_number_input_ with blake2s and then bumping its first word.
uint32 WgDevice::GetRandomNumber() {
  size_t remaining = next_rng_slot_;
  if (remaining == 0) {
    blake2s(random_number_output_, sizeof(random_number_output_), random_number_input_, sizeof(random_number_input_), NULL, 0);
    random_number_input_[0]++;
    remaining = BLAKE2S_OUTBYTES / 4;
  }
  remaining--;
  next_rng_slot_ = (uint8)remaining;
  return random_number_output_[remaining];
}
// Writes the unkeyed BLAKE2s hash of the concatenation a || b to |dst|,
// producing |dst_size| bytes.
static void BlakeX2(uint8 *dst, size_t dst_size, const uint8 *a, size_t a_size, const uint8 *b, size_t b_size) {
  blake2s_state state;
  blake2s_init(&state, dst_size);
  // Feeding the parts one after another hashes them as a single message.
  blake2s_update(&state, a, a_size);
  blake2s_update(&state, b, b_size);
  blake2s_final(&state, dst, dst_size);
}
// In-place hash update: dst := HASH(dst || a), where dst is WG_HASH_LEN bytes.
static inline void BlakeMix(uint8 dst[WG_HASH_LEN], const uint8 *a, size_t a_size) {
  const uint8 *previous = dst;
  BlakeX2(dst, WG_HASH_LEN, previous, WG_HASH_LEN, a, a_size);
}
// Computes the Curve25519 shared secret of |priv| and |pub| and mixes it into
// the chaining key: (ci, k) := KDF(ci, DH(priv, pub)).
//
// ci - chaining key, used as HKDF key and overwritten with the first output.
// k  - receives the second output; callers in this file also pass NULL here.
//
// The shared secret is wiped from the stack before returning.
static inline void ComputeHKDF2DH(uint8 ci[WG_HASH_LEN], uint8 k[WG_SYMMETRIC_KEY_LEN], const uint8 priv[WG_PUBLIC_KEY_LEN], const uint8 pub[WG_PUBLIC_KEY_LEN]) {
  uint8 shared_secret[WG_PUBLIC_KEY_LEN];
  curve25519_donna(shared_secret, priv, pub);
  blake2s_hkdf(ci, WG_HASH_LEN,
               k, WG_SYMMETRIC_KEY_LEN,
               NULL, 32,
               shared_secret, sizeof(shared_secret),
               ci, WG_HASH_LEN);
  memzero_crypto(shared_secret, sizeof(shared_secret));
}
// Installs the device's static private key and derives everything that
// depends on it: the public key and the two keys used for cookie handling.
void WgDevice::Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]) {
  // s_pub_ := private key * base point
  memcpy(s_priv_, private_key, sizeof(s_priv_));
  curve25519_donna(s_pub_, s_priv_, kCurve25519Basepoint);

  // precomputed_mac1_key_ := HASH(LABEL-MAC1 || s_pub_)
  BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_),
          kLabelMac1, sizeof(kLabelMac1), s_pub_, sizeof(s_pub_));

  // precomputed_cookie_key_ := HASH(LABEL-COOKIE || s_pub_)
  BlakeX2(precomputed_cookie_key_, sizeof(precomputed_cookie_key_),
          kLabelCookie, sizeof(kLabelCookie), s_pub_, sizeof(s_pub_));
}
// Creates a new peer bound to this device, appends it to the end of the
// singly linked peer list and returns it.
WgPeer *WgDevice::AddPeer() {
  WgPeer *created = new WgPeer(this);
  if (peers_ == NULL) {
    peers_ = created;
  } else {
    // Walk to the current tail and hook the new peer in behind it.
    WgPeer *tail = peers_;
    while (tail->next_peer_ != NULL)
      tail = tail->next_peer_;
    tail->next_peer_ = created;
  }
  return created;
}
// Returns the first peer whose static public key equals |public_key|,
// or NULL when no peer has that key.
WgPeer *WgDevice::GetPeerFromPublicKey(uint8 public_key[WG_PUBLIC_KEY_LEN]) {
  WgPeer *candidate = peers_;
  while (candidate != NULL &&
         memcmp(candidate->s_remote_, public_key, WG_PUBLIC_KEY_LEN) != 0) {
    candidate = candidate->next_peer_;
  }
  return candidate;
}
// Verifies mac1 of a received handshake packet.
//
// The expected value is the keyed BLAKE2s (key: precomputed_mac1_key_) of
// everything preceding the two trailing MAC fields; it is compared against
// the first of those fields with memcmp_crypto. Returns true on a match.
bool WgDevice::CheckCookieMac1(Packet *packet) {
  const uint8 *msg = packet->data;
  size_t msg_size = packet->size;
  uint8 expected[WG_COOKIE_LEN];
  blake2s(expected, sizeof(expected),
          msg, msg_size - WG_COOKIE_LEN * 2,
          precomputed_mac1_key_, sizeof(precomputed_mac1_key_));
  const uint8 *received = msg + msg_size - WG_COOKIE_LEN * 2;
  return memcmp_crypto(expected, received, WG_COOKIE_LEN) == 0;
}
// Computes the cookie for the sender of |packet|: a keyed BLAKE2s (key:
// cookie_secret_) over the source IP address (for AF_INET / AF_INET6) and
// the two port bytes read through addr.sin6.sin6_port.
//
// The secret is replaced with fresh random bytes once it is at least
// COOKIE_SECRET_MAX_AGE_MS old.
void WgDevice::MakeCookie(uint8 cookie[WG_COOKIE_LEN], Packet *packet) {
  uint64 now = OsGetMilliseconds();
  if (now - cookie_secret_timestamp_ >= COOKIE_SECRET_MAX_AGE_MS) {
    cookie_secret_timestamp_ = now;
    OsGetRandomBytes(cookie_secret_, sizeof(cookie_secret_));
  }

  blake2s_state state;
  blake2s_init_key(&state, WG_COOKIE_LEN, cookie_secret_, sizeof(cookie_secret_));
  switch (packet->addr.sin.sin_family) {
  case AF_INET:
    blake2s_update(&state, &packet->addr.sin.sin_addr, 4);
    break;
  case AF_INET6:
    blake2s_update(&state, &packet->addr.sin6.sin6_addr, sizeof(packet->addr.sin6.sin6_addr));
    break;
  default:
    // Unknown family: only the port bytes are hashed.
    break;
  }
  // The port is hashed for every address family.
  blake2s_update(&state, &packet->addr.sin6.sin6_port, 2);
  blake2s_final(&state, cookie, WG_COOKIE_LEN);
}
// Verifies mac2 of a received handshake packet.
//
// The expected value is the BLAKE2s of everything preceding the last MAC
// field, keyed with the cookie computed for the packet's sender. Returns
// true when it equals the trailing WG_COOKIE_LEN bytes of the packet.
bool WgDevice::CheckCookieMac2(Packet *packet) {
  uint8 sender_cookie[WG_COOKIE_LEN];
  MakeCookie(sender_cookie, packet);

  uint8 expected[WG_COOKIE_LEN];
  blake2s(expected, sizeof(expected),
          packet->data, packet->size - WG_COOKIE_LEN,
          sender_cookie, sizeof(sender_cookie));
  return memcmp_crypto(expected, packet->data + packet->size - WG_COOKIE_LEN, WG_COOKIE_LEN) == 0;
}
// Fills |dst| with a cookie reply for the handshake message in |packet|.
//
// The cookie for the packet's sender is encrypted in place with
// XChaCha20-Poly1305 under precomputed_cookie_key_ and a random nonce; the
// mac1 of the received message is the additional authenticated data.
void WgDevice::CreateCookieMessage(MessageHandshakeCookie *dst, Packet *packet, uint32 remote_key_id) {
  dst->type = MESSAGE_HANDSHAKE_COOKIE;
  dst->receiver_key_id = remote_key_id;

  // Plaintext cookie first, then a fresh nonce.
  MakeCookie(dst->cookie_enc, packet);
  OsGetRandomBytes(dst->nonce, sizeof(dst->nonce));

  // The MAC fields are the last sizeof(MessageMacs) bytes of the received message.
  MessageMacs *received_macs = (MessageMacs *)(packet->data + packet->size - sizeof(MessageMacs));
  xchacha20poly1305_encrypt(dst->cookie_enc, dst->cookie_enc, WG_COOKIE_LEN,
                            received_macs->mac1, WG_COOKIE_LEN,
                            dst->nonce, precomputed_cookie_key_);
}
// Detaches |kp| from the address entry it is registered in. When it was the
// last keypair in that entry, the entry is removed from addr_entry_lookup_
// and deleted.
//
// Precondition: kp->addr_entry is non-NULL (callers in this file check it).
void WgDevice::EraseKeypairAddrEntry(WgKeypair *kp) {
  WgAddrEntry *entry = kp->addr_entry;
  assert(entry->ref_count >= 1);
  assert(entry->ref_count == !!entry->keys[0] + !!entry->keys[1] + !!entry->keys[2]);
  assert(entry->keys[kp->addr_entry_slot - 1] == kp);

  // addr_entry_slot is stored 1-based; 0 means "not in any entry".
  entry->keys[kp->addr_entry_slot - 1] = NULL;
  kp->addr_entry_slot = 0;
  kp->addr_entry = NULL;

  if (--entry->ref_count == 0) {
    addr_entry_lookup_.erase(entry->addr_entry_id);
    delete entry;
  }
}
// Associates |keypair| with the address entry identified by |addr_id|.
//
// - Already registered under that id: only broadcast_short_key is set.
// - Registered under another id: it is detached from the old entry first.
// - Each entry holds up to three keypairs in a ring; the keypair that
//   occupied the reused slot is detached.
// - An existing entry accepts at most one insertion per 60 seconds
//   (measured with low_resolution_timestamp_); otherwise the call returns
//   without registering the keypair.
void WgDevice::UpdateKeypairAddrEntry(uint64 addr_id, WgKeypair *keypair) {
  if (keypair->addr_entry != NULL) {
    if (keypair->addr_entry->addr_entry_id == addr_id) {
      keypair->broadcast_short_key = 1;
      return;
    }
    EraseKeypairAddrEntry(keypair);
  }

  // operator[] yields a NULL pointer slot when the id is new.
  WgAddrEntry *&bucket = addr_entry_lookup_[addr_id];
  WgAddrEntry *ae = bucket;
  if (ae != NULL) {
    // Ensure we don't insert new things in this addr entry too often.
    if (ae->time_of_last_insertion + 1000 * 60 > low_resolution_timestamp_)
      return;
  } else {
    bucket = ae = new WgAddrEntry(addr_id);
  }
  ae->time_of_last_insertion = low_resolution_timestamp_;

  // Take the next slot of the three-entry ring and advance the cursor.
  const uint32 slot = ae->next_slot;
  ae->next_slot = (slot == 2) ? 0 : slot + 1;

  WgKeypair *evicted = ae->keys[slot];
  ae->keys[slot] = keypair;
  keypair->addr_entry = ae;
  keypair->addr_entry_slot = slot + 1;

  if (evicted == NULL) {
    // A previously empty slot got filled.
    ae->ref_count++;
  } else {
    // The slot changed owner; the count of occupied slots is unchanged.
    evicted->addr_entry = NULL;
    evicted->addr_entry_slot = 0;
  }
  assert(ae->ref_count == !!ae->keys[0] + !!ae->keys[1] + !!ae->keys[2]);
  keypair->broadcast_short_key = 1;
}
// Fixed key for the blake2s_hmac call in WgDevice::SetHeaderObfuscation.
// The bytes are the SHA-256 digest of a fixed string:
//   >>> hashlib.sha256('TunSafe Header Obfuscation Key').hexdigest()
//   '2444423e33eb5bb875961224c6441f54c5dea95a3a4e1139509ffa6992bdb278'
static const uint8 kHeaderObfuscationKey[32] = {36, 68, 66, 62, 51, 235, 91, 184, 117, 150, 18, 36, 198, 68, 31, 84, 197, 222, 169, 90, 58, 78, 17, 57, 80, 159, 250, 105, 146, 189, 178, 120};
// Enables header obfuscation with a key derived from the passphrase |key|,
// or disables it when |key| is NULL. Does nothing when the build has
// WITH_HEADER_OBFUSCATION turned off.
void WgDevice::SetHeaderObfuscation(const char *key) {
#if WITH_HEADER_OBFUSCATION
  if (key == NULL) {
    header_obfuscation_ = false;
  } else {
    header_obfuscation_ = true;
    // header_obfuscation_key_ := HMAC-BLAKE2s(passphrase, kHeaderObfuscationKey)
    blake2s_hmac((uint8*)&header_obfuscation_key_, sizeof(header_obfuscation_key_), (uint8*)key, strlen(key), kHeaderObfuscationKey, sizeof(kHeaderObfuscationKey));
  }
#endif  // WITH_HEADER_OBFUSCATION
}
// Creates a peer that belongs to |dev| with no endpoint, no keys, no queued
// packets and all optional features switched off.
WgPeer::WgPeer(WgDevice *dev) {
  // Ownership and list linkage.
  dev_ = dev;
  next_peer_ = NULL;

  // No endpoint known yet (address family 0).
  endpoint_.sin.sin_family = 0;
  ipv4_broadcast_addr_ = 0xffffffff;
  allow_multicast_through_peer_ = false;

  // Key material and handshake bookkeeping.
  curr_keypair_ = next_keypair_ = prev_keypair_ = NULL;
  local_key_id_during_hs_ = 0;
  handshake_attempts_ = 0;
  expect_cookie_reply_ = false;
  has_mac2_cookie_ = false;
  supports_handshake_extensions_ = true;
  memset(last_timestamp_, 0, sizeof(last_timestamp_));

  // Timestamps. The negative value lets the first CheckHandshakeRateLimit()
  // call succeed right away.
  last_handshake_init_timestamp_ = -1000000ll;
  last_handshake_init_recv_timestamp_ = 0;
  last_complete_handskake_timestamp_ = 0;

  // Timers.
  persistent_keepalive_ms_ = 0;
  timers_ = 0;

  // Empty outgoing packet queue; the tail pointer refers to the head slot.
  first_queued_packet_ = NULL;
  last_queued_packet_ptr_ = &first_queued_packet_;
  num_queued_packets_ = 0;

  // Negotiable extensions.
  num_ciphers_ = 0;
  cipher_prio_ = 0;
  memset(features_, 0, sizeof(features_));
}
// Releases everything the peer holds: all keypairs, the key id reserved for
// an in-flight handshake, and any packets still waiting in the queue.
WgPeer::~WgPeer() {
  ClearKeys();
  ClearHandshake();
  ClearPacketQueue();
}
void WgPeer::ClearPacketQueue() {
Packet *packet;
while ((packet = first_queued_packet_) != NULL) {
first_queued_packet_ = packet->next;
FreePacket(packet);
}
last_queued_packet_ptr_ = &first_queued_packet_;
num_queued_packets_ = 0;
}
// Installs the peer's static public key |spub| and optional |preshared_key|
// and precomputes everything derived from them.
//
// preshared_key may be NULL, in which case an all-zero key is used.
void WgPeer::Initialize(const uint8 spub[WG_PUBLIC_KEY_LEN], const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]) {
  if (preshared_key == NULL)
    memset(preshared_key_, 0, sizeof(preshared_key_));
  else
    memcpy(preshared_key_, preshared_key, sizeof(preshared_key_));

  // s_priv_pub_ := DH(our static private key, peer's static public key)
  memcpy(s_remote_, spub, sizeof(s_remote_));
  curve25519_donna(s_priv_pub_, dev_->s_priv_, s_remote_);

  // precomputed_cookie_key_ := HASH(LABEL-COOKIE || spub)
  BlakeX2(precomputed_cookie_key_, sizeof(precomputed_cookie_key_),
          kLabelCookie, sizeof(kLabelCookie), spub, WG_PUBLIC_KEY_LEN);

  // precomputed_mac1_key_ := HASH(LABEL-MAC1 || spub)
  BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_),
          kLabelMac1, sizeof(kLabelMac1), spub, WG_PUBLIC_KEY_LEN);
}
// Run on the client (initiator).
//
// Builds a handshake initiation message in packet->data and sets
// packet->size. The running handshake state (chaining key, hash and the
// ephemeral private key) is kept in hs_ so that the response can be
// processed later. A key id for the pending handshake is registered in the
// device's lookup table, and the trailing MAC fields are filled in by
// WriteMacToPacket.
void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) {
uint8 k[WG_SYMMETRIC_KEY_LEN];
MessageHandshakeInitiation *dst = (MessageHandshakeInitiation *)packet->data;
// Ci := HASH(CONSTRUCTION)
memcpy(hs_.ci, kWgInitChainingKey, sizeof(hs_.ci));
// Hi := HASH(Ci || IDENTIFIER)
memcpy(hs_.hi, kWgInitHash, sizeof(hs_.hi));
// Hi := HASH(Hi || Spub_r)
BlakeMix(hs_.hi, s_remote_, sizeof(s_remote_));
// (Epriv_i, Epub_i) := DH-GENERATE()
// msg.ephemeral = Epub_i
OsGetRandomBytes(hs_.e_priv, sizeof(hs_.e_priv));
curve25519_normalize(hs_.e_priv);
curve25519_donna(dst->ephemeral, hs_.e_priv, kCurve25519Basepoint);
// Ci := KDF_1(Ci, msg.ephemeral)
blake2s_hkdf(hs_.ci, sizeof(hs_.ci), NULL, 32, NULL, 32, dst->ephemeral, sizeof(dst->ephemeral), hs_.ci, WG_HASH_LEN);
// Hi := HASH(Hi || msg.ephemeral)
BlakeMix(hs_.hi, dst->ephemeral, sizeof(dst->ephemeral));
// (Ci, K) := KDF2(Ci, DH(epriv, spub_r))
ComputeHKDF2DH(hs_.ci, k, hs_.e_priv, s_remote_);
// msg.static = AEAD(K, 0, Spub_i, Hi)
chacha20poly1305_encrypt(dst->static_enc, dev_->s_pub_, sizeof(dev_->s_pub_), hs_.hi, sizeof(hs_.hi), 0, k);
// Hi := HASH(Hi || msg.static)
BlakeMix(hs_.hi, dst->static_enc, sizeof(dst->static_enc));
// (Ci, K) := KDF2(Ci, DH(spriv_i, spub_r)), using the DH result precomputed in s_priv_pub_
blake2s_hkdf(hs_.ci, sizeof(hs_.ci), k, sizeof(k), NULL, 32, s_priv_pub_, sizeof(s_priv_pub_), hs_.ci, WG_HASH_LEN);
// TAI64N timestamp, written as plaintext first and encrypted in place below
OsGetTimestampTAI64N(dst->timestamp_enc);
size_t extfield_size = 0;
#if WITH_HANDSHAKE_EXT
// Optional extension fields are appended directly after the timestamp so
// that both are covered by the same AEAD operation.
if (supports_handshake_extensions_)
extfield_size = WriteHandshakeExtension(dst->timestamp_enc + WG_TIMESTAMP_LEN, NULL);
#endif // WITH_HANDSHAKE_EXT
// msg.timestamp := AEAD(K, 0, timestamp, hi)
chacha20poly1305_encrypt(dst->timestamp_enc, dst->timestamp_enc, extfield_size + WG_TIMESTAMP_LEN, hs_.hi, sizeof(hs_.hi), 0, k);
// Hi := HASH(Hi || msg.timestamp)
BlakeMix(hs_.hi, dst->timestamp_enc, extfield_size + WG_TIMESTAMP_LEN + WG_MAC_LEN);
packet->size = (unsigned)(sizeof(MessageHandshakeInitiation) + extfield_size);
// Register this peer (without a keypair) under a fresh key id so that the
// response can be routed back to it.
dst->sender_key_id = dev_->InsertInKeyIdLookup(this, NULL);
dst->type = MESSAGE_HANDSHAKE_INITIATION;
memzero_crypto(k, sizeof(k));
// The MAC fields move back by extfield_size bytes when extensions are present.
WriteMacToPacket((uint8*)dst, (MessageMacs*)((uint8*)&dst->mac + extfield_size));
}
// Parsed by server (responder).
//
// Validates the handshake initiation in |packet| and, on success, overwrites
// the same packet buffer with the handshake response and updates
// packet->size.
//
// Returns the peer that sent the initiation, or NULL when the message could
// not be decrypted, the sender is unknown, the extension size is not
// acceptable, the timestamp is not newer than the last accepted one, the
// peer sent an initiation less than MIN_HANDSHAKE_INTERVAL_MS ago, or no
// keypair could be created.
//
// NOTE(review): the fixed-size fields of the message are read before any
// size check in this function - presumably the caller has validated
// packet->size; confirm.
WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { // const MessageHandshakeInitiation *src, MessageHandshakeResponse *dst) {
// Copy values into handshake once we've validated it all.
uint8 ci[WG_HASH_LEN];
uint8 hi[WG_HASH_LEN];
// k and e_priv share storage: k holds the initiation's AEAD key until the
// timestamp has been decrypted, the slot is then reused for our ephemeral
// private key, and finally KDF3 writes the response key K into it again.
union {
uint8 k[WG_SYMMETRIC_KEY_LEN];
uint8 e_priv[WG_PUBLIC_KEY_LEN];
};
// Likewise shared: spubi is only needed for the peer lookup, hi2 only for
// decrypting the timestamp, and e_remote from then on.
union {
uint8 spubi[WG_PUBLIC_KEY_LEN];
uint8 e_remote[WG_PUBLIC_KEY_LEN];
uint8 hi2[WG_HASH_LEN];
};
uint8 t[WG_HASH_LEN];
WgPeer *peer;
WgKeypair *keypair;
uint32 remote_key_id;
uint64 now;
uint8 extbuf[MAX_SIZE_OF_HANDSHAKE_EXTENSION + WG_TIMESTAMP_LEN];
MessageHandshakeInitiation *src = (MessageHandshakeInitiation *)packet->data;
MessageHandshakeResponse *dst;
size_t extfield_size;
// Ci := HASH(CONSTRUCTION)
memcpy(ci, kWgInitChainingKey, sizeof(ci));
// Hi := HASH(Ci || IDENTIFIER)
memcpy(hi, kWgInitHash, sizeof(hi));
// Hi := HASH(Hi || Spub_r)
BlakeMix(hi, dev->s_pub_, sizeof(dev->s_pub_));
// Ci := KDF_1(Ci, msg.ephemeral)
blake2s_hkdf(ci, sizeof(ci), NULL, 32, NULL, 32, src->ephemeral, sizeof(src->ephemeral), ci, WG_HASH_LEN);
// Hi := HASH(Hi || msg.ephemeral)
BlakeMix(hi, src->ephemeral, sizeof(src->ephemeral));
// (Ci, K) := KDF2(Ci, DH(spriv, msg.ephemeral))
ComputeHKDF2DH(ci, k, dev->s_priv_, src->ephemeral);
// Spub_i = AEAD_DEC(K, 0, msg.static, Hi)
if (!chacha20poly1305_decrypt(spubi, src->static_enc, sizeof(src->static_enc), hi, sizeof(hi), 0, k))
goto getout;
// Hi := HASH(Hi || msg.static)
BlakeMix(hi, src->static_enc, sizeof(src->static_enc));
// Lookup the peer with this ID
if (!(peer = dev->GetPeerFromPublicKey(spubi)))
goto getout;
// (Ci, K) := KDF2(Ci, DH(sprivr, spubi))
blake2s_hkdf(ci, sizeof(ci), k, sizeof(k), NULL, 32, peer->s_priv_pub_, sizeof(peer->s_priv_pub_), ci, WG_HASH_LEN);
// Hi2 := Hi (the hash before msg.timestamp is mixed in; it is the AAD of the timestamp)
memcpy(hi2, hi, sizeof(hi2));
// Everything beyond the fixed-size message is extension data. The
// subtraction is unsigned, so a packet shorter than the fixed message
// yields a huge value that the MAX_SIZE_OF_HANDSHAKE_EXTENSION check rejects.
extfield_size = packet->size - sizeof(MessageHandshakeInitiation);
if (extfield_size > MAX_SIZE_OF_HANDSHAKE_EXTENSION || (extfield_size && !peer->supports_handshake_extensions_))
goto getout;
// Hi := HASH(Hi || msg.timestamp)
BlakeMix(hi, src->timestamp_enc, extfield_size + WG_TIMESTAMP_LEN + WG_MAC_LEN);
// TIMESTAMP := AEAD_DEC(K, 0, msg.timestamp, hi2)
if (!chacha20poly1305_decrypt(extbuf, src->timestamp_enc, extfield_size + WG_TIMESTAMP_LEN + WG_MAC_LEN, hi2, sizeof(hi2), 0, k))
goto getout;
// Replay attack? The timestamp must compare strictly greater (bytewise) than the last accepted one.
if (memcmp(extbuf, peer->last_timestamp_, WG_TIMESTAMP_LEN) <= 0)
goto getout;
// Flood attack?
now = OsGetMilliseconds();
if (now < peer->last_handshake_init_recv_timestamp_ + MIN_HANDSHAKE_INTERVAL_MS)
goto getout;
// Remember all the information we need to produce a response cause we cannot touch src again
peer->last_handshake_init_recv_timestamp_ = now;
memcpy(peer->last_timestamp_, extbuf, sizeof(peer->last_timestamp_));
memcpy(e_remote, src->ephemeral, sizeof(e_remote));
remote_key_id = src->sender_key_id;
// From here on the response is built in place over the received message.
dst = (MessageHandshakeResponse *)src;
// (Epriv_r, Epub_r) := DH-GENERATE()
// msg.ephemeral = Epub_r
OsGetRandomBytes(e_priv, sizeof(e_priv));
curve25519_normalize(e_priv);
curve25519_donna(dst->ephemeral, e_priv, kCurve25519Basepoint);
// Hr := HASH(Hr || msg.ephemeral)
BlakeMix(hi, dst->ephemeral, sizeof(dst->ephemeral));
// Ci := KDF_1(Ci, msg.ephemeral)
blake2s_hkdf(ci, sizeof(ci), NULL, 32, NULL, 32, dst->ephemeral, sizeof(dst->ephemeral), ci, WG_HASH_LEN);
// Ci : = KDF2(Ci, DH(epriv, epub))
ComputeHKDF2DH(ci, NULL, e_priv, e_remote);
// Ci : = KDF2(Ci, DH(epriv, spub))
ComputeHKDF2DH(ci, NULL, e_priv, peer->s_remote_);
// (Ci, T, K) := KDF3(Ci, Q)
blake2s_hkdf(ci, sizeof(ci), t, sizeof(t), k, sizeof(k), peer->preshared_key_, sizeof(preshared_key_), ci, WG_HASH_LEN);
// Hr := HASH(Hr || T)
BlakeMix(hi, t, sizeof(t));
dst->receiver_key_id = remote_key_id;
// The decrypted extension data follows the timestamp inside extbuf.
keypair = peer->CreateNewKeypair(false, ci, remote_key_id, extbuf + WG_TIMESTAMP_LEN, extfield_size);
if (keypair) {
peer->InsertKeypairInPeer(keypair);
dst->sender_key_id = dev->InsertInKeyIdLookup(peer, keypair);
size_t extfield_out_size = 0;
#if WITH_HANDSHAKE_EXT
// Only answer with extensions when the initiator sent some.
if (extfield_size)
extfield_out_size = peer->WriteHandshakeExtension(dst->empty_enc, keypair);
#endif // WITH_HANDSHAKE_EXT
packet->size = (unsigned)(sizeof(MessageHandshakeResponse) + extfield_out_size);
// msg.empty := AEAD(K, 0, "", Hr)
chacha20poly1305_encrypt(dst->empty_enc, dst->empty_enc, extfield_out_size, hi, sizeof(hi), 0, k);
// Hr := HASH(Hr || "")
//BlakeMix(hi, dst->empty_enc, extfield_out_size);
dst->type = MESSAGE_HANDSHAKE_RESPONSE;
peer->WriteMacToPacket((uint8*)dst, (MessageMacs*)((uint8*)&dst->mac + extfield_out_size));
} else {
// Shared failure exit: every earlier `goto getout` lands here.
getout:
peer = NULL;
}
// Wipe all secret intermediate values, on success and on failure alike.
memzero_crypto(hi, sizeof(hi));
memzero_crypto(ci, sizeof(ci));
memzero_crypto(k, sizeof(k));
memzero_crypto(t, sizeof(t));
return peer;
}
// Run on the initiator: processes the handshake response in |packet|.
//
// The peer is found through the receiver key id, which must refer to a
// pending handshake (an entry without a keypair). The computation is done on
// a copy of the peer's handshake state, so a message that fails validation
// leaves peer->hs_ untouched. On success a new keypair is created and
// installed, and the key id entry is re-pointed at it.
//
// Returns the peer on success, NULL otherwise.
WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packet) {
MessageHandshakeResponse *src = (MessageHandshakeResponse *)packet->data;
uint8 t[WG_HASH_LEN];
uint8 k[WG_SYMMETRIC_KEY_LEN];
WgKeypair *keypair;
// Only entries whose keypair pointer is still NULL belong to a pending handshake.
auto it = dev->key_id_lookup().find(src->receiver_key_id);
if (it == dev->key_id_lookup().end() || it->second.second != NULL)
return NULL;
WgPeer *peer = it->second.first;
assert(src->receiver_key_id == peer->local_key_id_during_hs_);
// Work on a copy of the handshake state.
HandshakeState hs = peer->hs_;
// Hr := HASH(Hr || msg.ephemeral)
BlakeMix(hs.hi, src->ephemeral, sizeof(src->ephemeral));
// Ci := KDF_1(Ci, msg.ephemeral)
blake2s_hkdf(hs.ci, sizeof(hs.ci), NULL, 32, NULL, 32, src->ephemeral, sizeof(src->ephemeral), hs.ci, sizeof(hs.ci));
// Ci : = KDF2(Ci, DH(epriv, epub))
ComputeHKDF2DH(hs.ci, NULL, hs.e_priv, src->ephemeral);
// Ci : = KDF2(Ci, DH(spriv, epub))
ComputeHKDF2DH(hs.ci, NULL, peer->dev_->s_priv_, src->ephemeral);
// (Ci, T, K) := KDF3(Ci, Q)
blake2s_hkdf(hs.ci, sizeof(hs.ci), t, sizeof(t), k, sizeof(k), peer->preshared_key_, sizeof(peer->preshared_key_), hs.ci, sizeof(hs.ci));
// Hr := HASH(Hr || T)
BlakeMix(hs.hi, t, sizeof(t));
// Everything beyond the fixed-size message is extension data. The
// subtraction is unsigned, so a too-short packet yields a huge value that
// the check below rejects.
size_t extfield_size = packet->size - sizeof(MessageHandshakeResponse);
if (extfield_size > MAX_SIZE_OF_HANDSHAKE_EXTENSION)
goto getout;
// "" := AEAD_DEC(K, 0, msg.empty, Hr)
// Decrypts in place; afterwards src->empty_enc holds the extension plaintext.
if (!chacha20poly1305_decrypt(src->empty_enc, src->empty_enc, extfield_size + sizeof(src->empty_enc), hs.hi, sizeof(hs.hi), 0, k))
goto getout;
keypair = peer->CreateNewKeypair(true, hs.ci, src->sender_key_id, src->empty_enc, extfield_size);
if (!keypair)
goto getout;
peer->InsertKeypairInPeer(keypair);
// Re-map the entry in the id table so it points at this keypair instead.
keypair->local_key_id = peer->local_key_id_during_hs_;
peer->local_key_id_during_hs_ = 0;
it->second.second = keypair;
// `if (0)` makes the failure label reachable only through goto.
if (0) {
getout:
peer = NULL;
}
// Wipe the secret intermediate values, including the handshake state copy.
memzero_crypto(t, sizeof(t));
memzero_crypto(k, sizeof(k));
memzero_crypto(&hs, sizeof(hs));
return peer;
}
// Parsed by the initiator when the responder asks it to re-send the
// handshake message with a valid mac2.
//
// The message is ignored unless it refers to a pending handshake of a peer
// that is expecting a cookie reply and the encrypted cookie authenticates
// against the mac1 we sent. On success the cookie and its arrival time are
// stored for use by WriteMacToPacket.
void WgPeer::ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCookie *src) {
  auto found = dev->key_id_lookup().find(src->receiver_key_id);
  if (found == dev->key_id_lookup().end())
    return;
  // An entry that already carries a keypair is not a pending handshake.
  if (found->second.second != NULL)
    return;

  WgPeer *peer = found->second.first;
  if (!peer->expect_cookie_reply_)
    return;

  uint8 plain_cookie[WG_COOKIE_LEN];
  const bool authentic = xchacha20poly1305_decrypt(
      plain_cookie, src->cookie_enc, sizeof(src->cookie_enc),
      peer->sent_mac1_, sizeof(peer->sent_mac1_), src->nonce, peer->precomputed_cookie_key_);
  if (!authentic)
    return;

  peer->expect_cookie_reply_ = false;
  peer->has_mac2_cookie_ = true;
  peer->mac2_cookie_timestamp_ = OsGetMilliseconds();
  memcpy(peer->mac2_cookie_, plain_cookie, sizeof(peer->mac2_cookie_));
}
#if WITH_HANDSHAKE_EXT
// Serializes this peer's handshake extensions to |dst| as a sequence of
// (type byte, size byte, payload) records and returns the number of bytes
// written.
//
// keypair - NULL when writing an initiation (the full cipher list is sent);
//           non-NULL when writing a response (only the cipher suite chosen
//           for that keypair is sent).
//
// NOTE(review): no output size is passed in - presumably the caller
// guarantees room for MAX_SIZE_OF_HANDSHAKE_EXTENSION bytes; confirm.
size_t WgPeer::WriteHandshakeExtension(uint8 *dst, WgKeypair *keypair) {
uint8 *dst_org = dst, value = 0;
// Include the supported features extension
if (!IsOnlyZeros(features_, sizeof(features_))) {
*dst++ = EXT_BOOLEAN_FEATURES;
*dst++ = (WG_FEATURES_COUNT + 3) >> 2;
// Pack the features two bits each, four per byte. `value` accumulates the
// current byte and is reset whenever a new byte starts.
for (size_t i = 0; i != WG_FEATURES_COUNT; i++) {
if ((i & 3) == 0)
value = 0;
dst[i >> 2] = (value += (features_[i] << ((i * 2) & 7)));
}
// swap WG_FEATURE_ID_SKIP_KEYID_IN and WG_FEATURE_ID_SKIP_KEYID_OUT
// (exchanges the two lowest 2-bit fields of the second byte; the upper
// nibble is left as is)
dst[1] = (dst[1] & 0xF0) + ((dst[1] >> 2) & 0x03) + ((dst[1] << 2) & 0x0C);
dst += (WG_FEATURES_COUNT + 3) >> 2;
}
// Ordered list of cipher suites
size_t ciphers = num_ciphers_;
if (ciphers) {
// cipher_prio_ is added to the record type.
*dst++ = EXT_CIPHER_SUITES + cipher_prio_;
if (keypair) {
*dst++ = 1;
*dst++ = keypair->cipher_suite;
} else {
*dst++ = (uint8)ciphers;
memcpy(dst, ciphers_, ciphers);
dst += ciphers;
}
}
if (features_[WG_FEATURE_ID_IPZIP]) {
// Include the packet compression extension
*dst++ = EXT_PACKET_COMPRESSION;
*dst++ = sizeof(WgPacketCompressionVer01);
memcpy(dst, &dev_->compression_header_, sizeof(WgPacketCompressionVer01));
dst += sizeof(WgPacketCompressionVer01);
}
return dst - dst_org;
}
// Combines the remote (|other|) and local (|self|) 2-bit settings of a
// boolean feature using two 16-entry truth tables indexed by other * 4 + self.
//
// *result receives whether the feature ends up enabled (table 0xfec0).
// The return value tells whether the combination is acceptable (table 0xeff7);
// false means the handshake must be rejected.
static bool ResolveBooleanFeatureValue(uint8 other, uint8 self, bool *result) {
  const uint8 index = other * 4 + self;
  const int enabled = (0xfec0 >> index) & 1;
  const int acceptable = (0xeff7 >> index) & 1;
  *result = enabled;
  return acceptable;
}
// Relative strength of each cipher suite id; only consulted when neither
// side has priority.
static const uint8 cipher_strengths[EXT_CIPHER_SUITE_COUNT] = {4,2,3,1};

// Picks the cipher suite to use from two ordered preference lists.
//
// a / b - the two lists; the caller in this file passes the local list as
//         |a| and the remote list as |b|.
// tie   - > 0: take the first entry of |a| that |b| also contains;
//         < 0: take the first entry of |b| that |a| also contains;
//         == 0: take whichever of those two has the greater strength,
//               preferring b's candidate when the strengths are equal.
//
// When the lists have nothing in common both candidates are 0.
static uint32 ResolveCipherSuite(int tie, const uint8 *a, size_t a_size, const uint8 *b, size_t b_size) {
  // 256-bit membership sets, one bit per possible cipher id.
  uint32 in_a[8] = {0}, in_b[8] = {0};
  for (size_t i = 0; i < a_size; i++)
    in_a[a[i] >> 5] |= 1 << (a[i] & 31);
  for (size_t i = 0; i < b_size; i++)
    in_b[b[i] >> 5] |= 1 << (b[i] & 31);

  // First entry of a that is also in b.
  uint32 choice_a = 0;
  for (size_t i = 0; i < a_size; i++) {
    if (in_b[a[i] >> 5] & (1 << (a[i] & 31))) {
      choice_a = a[i];
      break;
    }
  }

  // First entry of b that is also in a.
  uint32 choice_b = 0;
  for (size_t i = 0; i < b_size; i++) {
    if (in_a[b[i] >> 5] & (1 << (b[i] & 31))) {
      choice_b = b[i];
      break;
    }
  }

  if (tie > 0)
    return choice_a;
  if (tie < 0)
    return choice_b;
  return cipher_strengths[choice_a] > cipher_strengths[choice_b] ? choice_a : choice_b;
}
// Initializes keypair->ipzip_state_ from the local compression header and
// the one received from the remote side (|remotec|).
//
// One header fills the "client" fields of the state and the other the
// "server" fields. By default the local header is the client one; the roles
// are swapped (and flags_xor set to 1) when the comparison of the two flag
// values, with is_initiator as tie-breaker, says so.
void WgKeypairSetupCompressionExtension(WgKeypair *keypair, const WgPacketCompressionVer01 *remotec) {
  const WgPacketCompressionVer01 *client = keypair->peer->dev_->compression_header();
  const WgPacketCompressionVer01 *server = remotec;
  IpzipState *state = &keypair->ipzip_state_;

  // Use is_initiator as tie-breaker on who's going to be the client side.
  int flags_xor = 0;
  if ((client->flags & ~3) + 2 * keypair->is_initiator - 1 <= (server->flags & ~3)) {
    std::swap(client, server);
    flags_xor = 1;
  }
  state->flags_xor = flags_xor;

  // Client side.
  memcpy(state->client_addr_v4, client->ipv4_addr, 4);
  memcpy(state->client_addr_v6, client->ipv6_addr, 16);
  state->guess_ttl[0] = client->ttl;
  // The low two flag bits give the number of subnet bytes.
  state->client_addr_v4_subnet_bytes = (client->flags & 3);
  WriteLE32(&state->client_addr_v4_netmask, 0xffffffff >> ((client->flags & 3) * 8));

  // Server side.
  memcpy(state->server_addr_v4, server->ipv4_addr, 4);
  memcpy(state->server_addr_v6, server->ipv6_addr, 16);
  state->guess_ttl[1] = server->ttl;
  state->server_addr_v4_subnet_bytes = (server->flags & 3);
  WriteLE32(&state->server_addr_v4_netmask, 0xffffffff >> ((server->flags & 3) * 8));
}
// Parses the handshake extension records in data[0..data_size) and applies
// them to |keypair|.
//
// Each record is a type byte, a size byte and |size| payload bytes. Records
// of unknown type are skipped. Returns false when a record overruns the
// buffer, when a trailing byte is left over, or when a boolean feature
// cannot be reconciled with the local setting; returns true otherwise.
//
// On success enabled_features[WG_FEATURE_ID_IPZIP] stays set only if a
// compression record of the supported version was seen, and auth_tag_length
// is 8 when WG_FEATURE_ID_SHORT_MAC is enabled, CHACHA20POLY1305_AUTHTAGLEN
// otherwise.
bool WgKeypairParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size_t data_size) {
  bool did_setup_compression = false;
  while (data_size >= 2) {
    uint8 type = data[0], size = data[1];
    data += 2, data_size -= 2;
    if (size > data_size)
      return false;
    switch (type) {
    case EXT_CIPHER_SUITES_PRIO:
    case EXT_CIPHER_SUITES:
      // Only the |size| payload bytes of this record are the remote cipher
      // list. Passing the whole remaining buffer would let the bytes of any
      // following record be taken for offered cipher suites.
      // NOTE(review): (type - EXT_CIPHER_SUITES) is presumably the remote
      // side's priority flag (0 or 1) - confirm against the enum.
      keypair->cipher_suite = ResolveCipherSuite(keypair->peer->cipher_prio_ - (type - EXT_CIPHER_SUITES),
                                                 keypair->peer->ciphers_, keypair->peer->num_ciphers_,
                                                 data, size);
      break;
    case EXT_BOOLEAN_FEATURES:
      // Walk over all features known locally or sent by the remote side,
      // two bits per feature. Features that the remote side did not send
      // count as value 0.
      for (size_t i = 0, j = std::max<uint32>(WG_FEATURES_COUNT, size * 4); i != j; i++) {
        uint8 value = (i < size * 4) ? (data[i >> 2] >> ((i * 2) & 7)) & 3 : 0;
        // A feature unknown to us must not be enforced by the remote side;
        // a known one has to be compatible with our own setting.
        if (i >= WG_FEATURES_COUNT ? (value == WG_BOOLEAN_FEATURE_ENFORCES) :
            !ResolveBooleanFeatureValue(value, keypair->peer->features_[i], &keypair->enabled_features[i]))
          return false;
      }
      break;
    case EXT_PACKET_COMPRESSION:
      // Records with an unexpected size or version are ignored.
      if (size == sizeof(WgPacketCompressionVer01)) {
        WgPacketCompressionVer01 *c = (WgPacketCompressionVer01*)data;
        if (ReadLE16(&c->version) == EXT_PACKET_COMPRESSION_VER) {
          WgKeypairSetupCompressionExtension(keypair, c);
          did_setup_compression = true;
        }
      }
      break;
    }
    data += size, data_size -= size;
  }
  // A single leftover byte cannot form a record header.
  if (data_size != 0)
    return false;
  keypair->enabled_features[WG_FEATURE_ID_IPZIP] &= did_setup_compression;
  keypair->auth_tag_length = (keypair->enabled_features[WG_FEATURE_ID_SHORT_MAC] ? 8 : CHACHA20POLY1305_AUTHTAGLEN);
  // RINFO("Cipher Suite = %d", keypair->cipher_suite);
  return true;
}
#endif // WITH_HANDSHAKE_EXT
void WgPeer::ClearKeys() {
DeleteKeypair(&curr_keypair_);
DeleteKeypair(&next_keypair_);
DeleteKeypair(&prev_keypair_);
}
void WgPeer::ClearHandshake() {
uint32 v = local_key_id_during_hs_;
if (v != 0) {
local_key_id_during_hs_ = 0;
dev_->key_id_lookup_.erase(v);
}
}
// Destroys the keypair that *kp points to (if any) and sets *kp to NULL.
//
// The keypair is first unregistered from its address entry and from the key
// id lookup table; its AES-GCM context, if allocated, is freed.
void WgPeer::DeleteKeypair(WgKeypair **kp) {
  WgKeypair *doomed = *kp;
  *kp = NULL;
  if (doomed == NULL)
    return;

  if (doomed->addr_entry)
    dev_->EraseKeypairAddrEntry(doomed);
  if (doomed->local_key_id)
    dev_->key_id_lookup_.erase(doomed->local_key_id);
  // free() accepts NULL, so no separate check is needed.
  free(doomed->aes_gcm128_context_);
  delete doomed;
}
// Creates the transport keypair that results from a completed handshake.
//
// is_initiator  - true when this side sent the handshake initiation.
// chaining_key  - final chaining key of the handshake; the transport keys
//                 are derived from it.
// remote_key_id - key id chosen by the other side.
// extfield /
// extfield_size - decrypted handshake extension data (may be empty).
//
// Returns the new keypair, which the caller owns, or NULL when the extension
// data is rejected, the AES-GCM context cannot be allocated, or an AES-GCM
// suite was negotiated in a build without WITH_AESGCM. Also resets
// time_of_next_key_event_ on success.
WgKeypair *WgPeer::CreateNewKeypair(bool is_initiator, const uint8 chaining_key[WG_HASH_LEN], uint32 remote_key_id, const uint8 *extfield, size_t extfield_size) {
WgKeypair *kp = new WgKeypair;
uint8 *first_key, *second_key;
if (!kp)
return NULL;
// Zero every member that precedes replay_detector in the struct layout.
memset(kp, 0, offsetof(WgKeypair, replay_detector));
kp->peer = this;
kp->is_initiator = is_initiator;
kp->remote_key_id = remote_key_id;
kp->auth_tag_length = CHACHA20POLY1305_AUTHTAGLEN;
#if WITH_HANDSHAKE_EXT
// May change cipher_suite, enabled_features and auth_tag_length.
if (!WgKeypairParseExtendedHandshake(kp, extfield, extfield_size))
goto fail;
#endif // WITH_HANDSHAKE_EXT
// The first HKDF output is the initiator's sending key and the second its
// receiving key; the responder uses them the other way around.
first_key = kp->send_key, second_key = kp->recv_key;
if (!is_initiator)
std::swap(first_key, second_key);
// A third output (the compressed MAC keys) is only requested when a
// non-default auth tag length was negotiated.
blake2s_hkdf(first_key, sizeof(kp->send_key), second_key, sizeof(kp->recv_key),
kp->auth_tag_length != CHACHA20POLY1305_AUTHTAGLEN ? (uint8*)kp->compress_mac_keys : NULL, 32, NULL, 0, chaining_key, WG_HASH_LEN);
// Exchange the two rows of compress_mac_keys on the responder, matching the
// send/recv swap above.
if (!is_initiator) {
std::swap(kp->compress_mac_keys[0][0], kp->compress_mac_keys[1][0]);
std::swap(kp->compress_mac_keys[0][1], kp->compress_mac_keys[1][1]);
}
#if WITH_HANDSHAKE_EXT
if (kp->cipher_suite >= EXT_CIPHER_SUITE_AES128_GCM && kp->cipher_suite <= EXT_CIPHER_SUITE_AES256_GCM) {
#if WITH_AESGCM
// Two contexts in one allocation: [0] for sending, [1] for receiving.
kp->aes_gcm128_context_ = (AesGcm128StaticContext *)malloc(sizeof(*kp->aes_gcm128_context_) * 2);
if (!kp->aes_gcm128_context_)
goto fail;
int key_size = (kp->cipher_suite == EXT_CIPHER_SUITE_AES128_GCM) ? 128 : 256;
CRYPTO_gcm128_init(&kp->aes_gcm128_context_[0], kp->send_key, key_size);
CRYPTO_gcm128_init(&kp->aes_gcm128_context_[1], kp->recv_key, key_size);
#else
// AES-GCM was negotiated but is not compiled in.
goto fail;
#endif
}
#endif // WITH_HANDSHAKE_EXT
kp->send_key_state = kp->recv_key_state = WgKeypair::KEY_VALID;
time_of_next_key_event_ = 0;
kp->key_timestamp = OsGetMilliseconds();
return kp;
fail:
delete kp;
return NULL;
}
// Installs a freshly created keypair in the peer's prev/curr/next slots.
//
// The old previous keypair is always destroyed. A keypair created as
// responder is parked in next_keypair_ (replacing whatever was there). A
// keypair created as initiator becomes current at once; the keypair that
// moves to prev_keypair_ is the pending next one if there is one (the old
// current one is then destroyed), otherwise the old current one.
void WgPeer::InsertKeypairInPeer(WgKeypair *kp) {
  assert(kp->peer == this);
  DeleteKeypair(&prev_keypair_);

  if (!kp->is_initiator) {
    // The keypair will be moved to curr when we get the first data packet.
    DeleteKeypair(&next_keypair_);
    next_keypair_ = kp;
    return;
  }

  // When we're the initiator then we got the handshake and we can
  // use the keypair right away.
  if (next_keypair_ != NULL) {
    prev_keypair_ = next_keypair_;
    next_keypair_ = NULL;
    DeleteKeypair(&curr_keypair_);
  } else {
    prev_keypair_ = curr_keypair_;
  }
  curr_keypair_ = kp;
}
// Promotes next_keypair_ to current if |keypair| is that next keypair.
//
// On promotion the old previous keypair is destroyed, the old current one
// becomes previous and time_of_next_key_event_ is reset. Returns true when
// the switch happened, false when |keypair| is not the pending next keypair.
bool WgPeer::CheckSwitchToNextKey(WgKeypair *keypair) {
  const bool is_pending_next = (keypair == next_keypair_);
  if (is_pending_next) {
    DeleteKeypair(&prev_keypair_);
    prev_keypair_ = curr_keypair_;
    curr_keypair_ = next_keypair_;
    next_keypair_ = NULL;
    time_of_next_key_event_ = 0;
  }
  return is_pending_next;
}
// Returns true when at least REKEY_TIMEOUT_MS have passed since the last
// time this function returned true, and records the current time in that
// case. Returns false otherwise, leaving the recorded time alone.
bool WgPeer::CheckHandshakeRateLimit() {
  uint64 now = OsGetMilliseconds();
  if (now - last_handshake_init_timestamp_ >= REKEY_TIMEOUT_MS) {
    last_handshake_init_timestamp_ = now;
    return true;
  }
  return false;
}
void WgPeer::WriteMacToPacket(const uint8 *data, MessageMacs *dst) {
expect_cookie_reply_ = true;
blake2s(dst->mac1, sizeof(dst->mac1), data, (uint8*)dst->mac1 - data, precomputed_mac1_key_, sizeof(precomputed_mac1_key_));
memcpy(sent_mac1_, dst->mac1, sizeof(sent_mac1_));
if (has_mac2_cookie_ && OsGetMilliseconds() - mac2_cookie_timestamp_ < COOKIE_SECRET_MAX_AGE_MS - COOKIE_SECRET_LATENCY_MS) {
blake2s(dst->mac2, sizeof(dst->mac2), data, (uint8*)dst->mac2 - data, mac2_cookie_, sizeof(mac2_cookie_));
} else {
has_mac2_cookie_ = false;
if (dev_->header_obfuscation_) {
// when obfuscation is enabled just make the top bits random
for (size_t i = 0; i < 4; i++)
((uint32*)dst->mac2)[i] = dev_->GetRandomNumber();
} else {
memset(dst->mac2, 0, sizeof(dst->mac2));
}
}
}
// Identifiers of the per-peer timers. Each value is used as a bit position
// in WgPeer::timers_ by the WgSetTimer / WgClearTimer / WgIsTimerActive macros.
enum {
// Timer for retransmitting the handshake if we don't hear back after REKEY_TIMEOUT_MS
TIMER_RETRANSMIT_HANDSHAKE = 0,
// Timer for sending a keepalive if we received a packet but don't send anything else for KEEPALIVE_TIMEOUT_MS
TIMER_SEND_KEEPALIVE = 1,
// Timer for initiating a new handshake if we have sent a packet but have not received one for KEEPALIVE_TIMEOUT_MS + REKEY_TIMEOUT_MS
TIMER_NEW_HANDSHAKE = 2,
// Timer for zeroing out all keys and handshake state after (REJECT_AFTER_TIME_MS * 3) if no new keys have been received
TIMER_ZERO_KEYS = 3,
// Timer for sending a keepalive packet every PERSISTENT_KEEPALIVE_MS
TIMER_PERSISTENT_KEEPALIVE = 4,
};
#define WgClearTimer(x) (timers_ &= ~(33 << x))
#define WgIsTimerActive(x) (timers_ & (33 << x))
#define WgSetTimer(x) (timers_ |= (32 << (x)))
// Timer bookkeeping for an outgoing data packet.
void WgPeer::OnDataSent() {
  // We just sent something, so a pending keepalive is no longer needed.
  WgClearTimer(TIMER_SEND_KEEPALIVE);
  // Start waiting for a reply unless that wait is already in progress; an
  // existing wait must not be restarted.
  if (WgIsTimerActive(TIMER_NEW_HANDSHAKE) == 0) {
    WgSetTimer(TIMER_NEW_HANDSHAKE);
  }
  // Any outgoing traffic restarts the persistent keepalive interval.
  WgSetTimer(TIMER_PERSISTENT_KEEPALIVE);
}
// Timer bookkeeping for an outgoing keepalive packet: only the persistent
// keepalive interval is restarted.
void WgPeer::OnKeepaliveSent() {
  WgSetTimer(TIMER_PERSISTENT_KEEPALIVE);
}
// Timer bookkeeping for an incoming data packet.
void WgPeer::OnDataReceived() {
  // The peer is alive, so stop waiting to start a new handshake.
  WgClearTimer(TIMER_NEW_HANDSHAKE);
  if (WgIsTimerActive(TIMER_SEND_KEEPALIVE)) {
    // A keepalive wait is already running; remember that another one has to
    // be scheduled once it fires (see pending_keepalive_ in CheckTimeouts).
    pending_keepalive_ = true;
  } else {
    WgSetTimer(TIMER_SEND_KEEPALIVE);
  }
  WgSetTimer(TIMER_PERSISTENT_KEEPALIVE);
}
// Timer bookkeeping for an incoming keepalive packet: the new-handshake wait
// is cancelled and the persistent keepalive interval is restarted.
void WgPeer::OnKeepaliveReceived() {
  WgClearTimer(TIMER_NEW_HANDSHAKE);
  WgSetTimer(TIMER_PERSISTENT_KEEPALIVE);
}
// Timer bookkeeping for an outgoing handshake initiation: cancels a pending
// keepalive and (re)starts the handshake retransmit timer.
void WgPeer::OnHandshakeInitSent() {
  WgClearTimer(TIMER_SEND_KEEPALIVE);
  WgSetTimer(TIMER_RETRANSMIT_HANDSHAKE);
}
// Timer bookkeeping once a handshake has been authenticated.
void WgPeer::OnHandshakeAuthComplete() {
  // No need to start another handshake for now.
  WgClearTimer(TIMER_NEW_HANDSHAKE);
  // Restart the countdown after which all key material is wiped.
  WgSetTimer(TIMER_ZERO_KEYS);
  WgSetTimer(TIMER_PERSISTENT_KEEPALIVE);
}
// Display names of the cipher suites, indexed by WgKeypair::cipher_suite when
// the negotiated suite is logged.
static const char * const kCipherSuites[] = {"chacha20-poly1305", "aes128-gcm", "aes256-gcm", "none"};
void WgPeer::OnHandshakeFullyComplete() {
WgClearTimer(TIMER_RETRANSMIT_HANDSHAKE);
handshake_attempts_ = 0;
if (last_complete_handskake_timestamp_ == 0) {
bool any_feature = false;
for(size_t i = 0; i < WG_FEATURES_COUNT; i++)
any_feature |= curr_keypair_->enabled_features[i];
if (curr_keypair_->cipher_suite != 0 || any_feature) {
RINFO("Using %s, %s %s %s %s %s", kCipherSuites[curr_keypair_->cipher_suite],
curr_keypair_->enabled_features[0] ? "short_header" : "",
curr_keypair_->enabled_features[1] ? "mac64" : "",
curr_keypair_->enabled_features[2] ? "ipzip" : "",
curr_keypair_->enabled_features[4] ? "skip_keyid_in" : "",
curr_keypair_->enabled_features[5] ? "skip_keyid_out" : "");
}
}
last_complete_handskake_timestamp_ = OsGetMilliseconds();
dev_->last_complete_handskake_timestamp_ = last_complete_handskake_timestamp_;
// RINFO("Connection established.");
}
// Evaluates this peer's timers at time |now| and returns a bitmask of ACTION_*
// flags (ACTION_SEND_HANDSHAKE and/or ACTION_SEND_KEEPALIVE) for the caller to
// act on; 0 means there is nothing to do.
//
// |now| is compared against *_MS constants, so it is presumably a millisecond
// clock value (the callers are not visible here).
//
// timers_ layout: bits 0..4 mark running timers (indexed by the TIMER_* enum),
// bits 5..9 mark timers that were (re)armed since the previous call and still
// need their start time recorded.
uint32 WgPeer::CheckTimeouts(uint64 now) {
uint32 t, rv = 0;
// Re-evaluate key lifetimes only once the cached deadline has been reached.
if (now >= time_of_next_key_event_)
CheckAndUpdateTimeOfNextKeyEvent(now);
// Fast path: no timer is running and none has been armed.
if ((t = timers_) == 0)
return 0;
// The timer arithmetic below only uses the low 32 bits of the clock.
uint32 now32 = (uint32)now;
// Got any new timers?
// For every "(re)armed" request bit, record now as the start time, then fold
// the request bits into the running bits and drop the request bits.
if (t & (0x1f << 5)) {
if (t & (1 << (5+0))) timer_value_[0] = now32;
if (t & (1 << (5+1))) timer_value_[1] = now32;
if (t & (1 << (5+2))) timer_value_[2] = now32;
if (t & (1 << (5+3))) timer_value_[3] = now32;
if (t & (1 << (5+4))) timer_value_[4] = now32;
t |= (t >> 5);
t &= 0x1F;
}
// Got any expired timers?
// NOTE: the order of the checks matters; earlier branches clear bits (in
// particular TIMER_SEND_KEEPALIVE) that later branches test.
if (t & 0x1F) {
if ((t & (1 << TIMER_RETRANSMIT_HANDSHAKE)) && (now32 - timer_value_[TIMER_RETRANSMIT_HANDSHAKE]) >= REKEY_TIMEOUT_MS) {
// The bit is known to be set here, so the XOR clears it.
t ^= (1 << TIMER_RETRANSMIT_HANDSHAKE);
if (handshake_attempts_ > MAX_HANDSHAKE_ATTEMPTS) {
// Give up: also cancel the keepalive timer and drop the queued packets.
RINFO("Too many handshake attempts. Stopping.");
t &= ~(1 << TIMER_SEND_KEEPALIVE);
ClearPacketQueue();
} else {
// Logged before the increment, so the first retry (counter == 0) is
// reported as attempt 2.
RINFO("Retrying handshake, attempt %d...", handshake_attempts_ + 2);
handshake_attempts_++;
rv |= ACTION_SEND_HANDSHAKE;
}
}
if ((t & (1 << TIMER_SEND_KEEPALIVE)) && (now32 - timer_value_[TIMER_SEND_KEEPALIVE]) >= KEEPALIVE_TIMEOUT_MS) {
t &= ~(1 << TIMER_SEND_KEEPALIVE);
rv |= ACTION_SEND_KEEPALIVE;
// More data arrived while the timer was running (see OnDataReceived), so
// restart the timer from now instead of leaving it disarmed.
if (pending_keepalive_) {
pending_keepalive_ = false;
timer_value_[TIMER_SEND_KEEPALIVE] = now32;
t |= (1 << TIMER_SEND_KEEPALIVE);
}
}
if ((t & (1 << TIMER_PERSISTENT_KEEPALIVE)) && (now32 - timer_value_[TIMER_PERSISTENT_KEEPALIVE]) >= (uint32)persistent_keepalive_ms_) {
// The timer is disarmed on expiry in any case; a keepalive is only
// requested when a persistent keepalive interval is configured (non-zero).
t &= ~(1 << TIMER_PERSISTENT_KEEPALIVE);
if (persistent_keepalive_ms_) {
// A keepalive is requested right now, so the regular keepalive timer is
// cancelled as well.
t &= ~(1 << TIMER_SEND_KEEPALIVE);
rv |= ACTION_SEND_KEEPALIVE;
}
}
if ((t & (1 << TIMER_NEW_HANDSHAKE)) && (now32 - timer_value_[TIMER_NEW_HANDSHAKE]) >= KEEPALIVE_TIMEOUT_MS + REKEY_TIMEOUT_MS) {
// Nothing was received in time after sending: start over with a fresh
// handshake and a reset attempt counter.
t &= ~(1 << TIMER_NEW_HANDSHAKE);
handshake_attempts_ = 0;
rv |= ACTION_SEND_HANDSHAKE;
RINFO("Retrying handshake with peer");
}
if ((t & (1 << TIMER_ZERO_KEYS)) && (now32 - timer_value_[TIMER_ZERO_KEYS]) >= REJECT_AFTER_TIME_MS * 3) {
// No new keys for a long time: wipe all key and handshake state.
RINFO("Expiring all keys for peer");
t &= ~(1 << TIMER_ZERO_KEYS);
ClearKeys();
ClearHandshake();
}
}
// Publish the updated running-timer bits (request bits are now cleared).
timers_ = t;
return rv;
}
// Re-evaluates the lifetime of the prev/curr/next keypairs at time |now|:
// expired keypairs are deleted, the current keypair is flagged for refresh
// when due, and the time of the next lifetime transition is cached in
// time_of_next_key_event_ (UINT64_MAX when no keypair is left).
//
// All key age checks are done here to avoid calling possibly expensive
// timestamp routines in the packet handler.
void WgPeer::CheckAndUpdateTimeOfNextKeyEvent(uint64 now) {
  uint64 next_time = UINT64_MAX;

  if (curr_keypair_ != NULL) {
    if (now >= curr_keypair_->key_timestamp + REJECT_AFTER_TIME_MS) {
      DeleteKeypair(&curr_keypair_);
    } else if (curr_keypair_->is_initiator) {
      // If a peer is the initiator of a current secure session, WireGuard will send a handshake initiation
      // message to begin a new secure session if, after transmitting a transport data message, the current
      // secure session is REKEY_AFTER_TIME_MS old, or if after receiving a transport data message, the current
      // secure session is (REJECT_AFTER_TIME_MS - KEEPALIVE_TIMEOUT_MS - REKEY_TIMEOUT_MS) milliseconds old
      // and it has not yet acted upon this event.
      //
      // The thresholds are tested latest-first, so each call handles exactly
      // one stage and schedules the following one.
      if (now >= curr_keypair_->key_timestamp + (REJECT_AFTER_TIME_MS - KEEPALIVE_TIMEOUT_MS - REKEY_TIMEOUT_MS)) {
        next_time = curr_keypair_->key_timestamp + REJECT_AFTER_TIME_MS;
        if (curr_keypair_->recv_key_state == WgKeypair::KEY_VALID)
          curr_keypair_->recv_key_state = WgKeypair::KEY_WANT_REFRESH;
      } else if (now >= curr_keypair_->key_timestamp + REKEY_AFTER_TIME_MS) {
        next_time = curr_keypair_->key_timestamp + (REJECT_AFTER_TIME_MS - KEEPALIVE_TIMEOUT_MS - REKEY_TIMEOUT_MS);
        if (curr_keypair_->send_key_state == WgKeypair::KEY_VALID)
          curr_keypair_->send_key_state = WgKeypair::KEY_WANT_REFRESH;
      } else {
        next_time = curr_keypair_->key_timestamp + REKEY_AFTER_TIME_MS;
      }
    } else {
      // As responder the only event is the final expiry of the keypair.
      next_time = curr_keypair_->key_timestamp + REJECT_AFTER_TIME_MS;
    }
  }

  // The previous and the pending keypair only ever expire.
  if (prev_keypair_ != NULL) {
    if (now >= prev_keypair_->key_timestamp + REJECT_AFTER_TIME_MS)
      DeleteKeypair(&prev_keypair_);
    else
      next_time = std::min<uint64>(next_time, prev_keypair_->key_timestamp + REJECT_AFTER_TIME_MS);
  }
  if (next_keypair_ != NULL) {
    if (now >= next_keypair_->key_timestamp + REJECT_AFTER_TIME_MS)
      DeleteKeypair(&next_keypair_);
    else
      next_time = std::min<uint64>(next_time, next_keypair_->key_timestamp + REJECT_AFTER_TIME_MS);
  }
  time_of_next_key_event_ = next_time;
}
void WgPeer::SetEndpoint(const IpAddr &sin) {
endpoint_ = sin;
}
// Sets the persistent keepalive interval from a value given in seconds.
// Values outside the range [10, 10000] are ignored and leave the current
// setting unchanged.
void WgPeer::SetPersistentKeepalive(int persistent_keepalive_secs) {
  const bool in_range = persistent_keepalive_secs >= 10 && persistent_keepalive_secs <= 10000;
  if (in_range)
    persistent_keepalive_ms_ = persistent_keepalive_secs * 1000;
}
// Registers |cidr_addr| as an allowed IP range of this peer, both in the
// device's IP-to-peer map and in the peer's own allowed_ips_ list.
//
// cidr_addr.size selects the family (32 for IPv4, 128 for IPv6). Returns false
// for any other size or when the prefix length exceeds the address size.
bool WgPeer::AddIp(const WgCidrAddr &cidr_addr) {
  const bool is_v4 = (cidr_addr.size == 32);
  const bool is_v6 = !is_v4 && (cidr_addr.size == 128);
  if (!is_v4 && !is_v6)
    return false;

  const int max_prefix_len = is_v4 ? 32 : 128;
  if (cidr_addr.cidr > max_prefix_len)
    return false;

  if (is_v4)
    dev_->ip_to_peer_map_.InsertV4(cidr_addr.addr, cidr_addr.cidr, this);
  else
    dev_->ip_to_peer_map_.InsertV6(cidr_addr.addr, cidr_addr.cidr, this);
  allowed_ips_.push_back(cidr_addr);
  return true;
}
void WgPeer::SetAllowMulticast(bool allow) {
allow_multicast_through_peer_ = allow;
}
void WgPeer::SetFeature(int feature, uint8 value) {
features_[feature] = value;
}
// Appends |cipher| to this peer's list of ciphers.
//
// Returns false only when the list is already full. An AES-GCM suite that
// cannot be used (built without WITH_AESGCM, or X86_PCAP_AES is false) is
// silently skipped and true is returned.
bool WgPeer::AddCipher(int cipher) {
  if (num_ciphers_ == MAX_CIPHERS)
    return false;

  const bool is_aes_gcm = (cipher == EXT_CIPHER_SUITE_AES128_GCM) || (cipher == EXT_CIPHER_SUITE_AES256_GCM);
  if (is_aes_gcm) {
#if WITH_AESGCM
    const bool have_hw_aes = X86_PCAP_AES;
    if (!have_hw_aes)
      return true;
#else  // WITH_AESGCM
    return true;
#endif  // WITH_AESGCM
  }

  ciphers_[num_ciphers_] = cipher;
  num_ciphers_++;
  return true;
}
// Creates a rate limiter with zeroed bins, both hash keys set to {1, 1}, and
// the default per-source rate of PACKETS_PER_SEC.
WgRateLimit::WgRateLimit() {
  // Counter storage: two bins, all counters start at zero.
  memset(bins_, 0, sizeof(bins_));
  bin1_ = bins_[0];
  bin2_ = bins_[1];

  // Hash keys; these get replaced with random material by Periodic().
  key1_[0] = 1;
  key1_[1] = 1;
  key2_[0] = 1;
  key2_[1] = 1;

  rand_ = 0;
  rand_xor_ = 0;
  used_rate_limit_ = 0;
  packets_per_sec_ = PACKETS_PER_SEC;
}
// Periodic maintenance of the rate limiter (the calling interval is not
// visible here).
//
// Recomputes the per-source packet budget from the load seen since the last
// call, then rotates the hash keys and bins: the current key/bin become the
// old ones, and a fresh key (taken from the first sizeof(key1_) bytes of |s|)
// with a cleared bin becomes current. s[4] becomes the new rand_xor_.
void WgRateLimit::Periodic(uint32 s[5]) {
  unsigned int new_rate = PACKETS_PER_SEC;
  if (used_rate_limit_ >= TOTAL_PACKETS_PER_SEC) {
    // Over the global budget: scale the per-source budget down
    // proportionally, but never below one packet.
    new_rate = PACKETS_PER_SEC * TOTAL_PACKETS_PER_SEC / used_rate_limit_;
    if (new_rate < 1)
      new_rate = 1;
  }
  // Raise the budget only gradually: move halfway (rounded up) towards the
  // higher target. Decreases take effect immediately.
  if ((unsigned)new_rate > packets_per_sec_)
    new_rate = (new_rate + packets_per_sec_ + 1) >> 1;
  packets_per_sec_ = new_rate;
  used_rate_limit_ = 0;

  rand_xor_ = s[4];

  // Rotate keys: key2_ takes over the old key1_, key1_ gets new material.
  for (size_t i = 0; i < 2; ++i)
    key2_[i] = key1_[i];
  memcpy(key1_, s, sizeof(key1_));

  // Rotate bins: the old bin1_ becomes bin2_, the new bin1_ starts empty.
  std::swap(bin1_, bin2_);
  memset(bin1_, 0, BINSIZE);
}
// Keyed hash of |ip| used to pick a rate limit bin. |key| must point at two
// 64-bit key words. The result is a 32-bit value widened to size_t; callers
// mask it down to the bin range.
static inline size_t hashit(uint64 ip, const uint64 *key) {
  // Mix both halves of the input with both key words.
  const uint64 mixed = ip * key[0] + rol64(ip, 32) * key[1];
  // Fold the 64-bit product down to 32 bits, then diffuse a little more.
  const uint64 upper = mixed >> 32;
  uint32 folded = (uint32)(mixed + upper * 0x85ebca6b);
  folded -= folded >> 16;
  folded ^= folded >> 4;
  return folded;
}
// Looks up the counters for source |ip| and decides how this packet is to be
// accounted.
//
// The result holds a pointer to the counter in the current bin, the counter
// value to store (old level plus the increment) and the increment itself (1 or
// 0; presumably 1 means the packet is admitted - confirm against the callers).
// Nothing is written to the bins here.
WgRateLimit::RateLimitResult WgRateLimit::CheckRateLimit(uint64 ip) {
  const size_t slot_cur = hashit(ip, key1_) & (BINSIZE - 1);
  const size_t slot_old = hashit(ip, key2_) & (BINSIZE - 1);
  uint8 *counter_cur = &bin1_[slot_cur];
  uint8 *counter_old = &bin2_[slot_old];

  // Effective level: the current counter, or the old counter reduced by one
  // period's budget, whichever is larger.
  unsigned int level = std::max<int>(*counter_cur, *counter_old - packets_per_sec_);
  unsigned int increment = 0;
  if (level < PACKET_ACCUM / 2) {
    // Lower half of the range: always count the packet.
    increment = 1;
  } else if (level < PACKET_ACCUM) {
    // Upper half: count it with a probability that shrinks as the level
    // approaches PACKET_ACCUM, using the internal pseudo random generator.
    increment = level < ((uint64)rand_ * ((PACKET_ACCUM / 2) + 1) >> 32) + (PACKET_ACCUM / 2);
    rand_ = (rand_ * 0x1b873593 + 5) + rand_xor_;
  }
  // At or above PACKET_ACCUM the increment stays 0.

  RateLimitResult result = {counter_cur, (uint8)(level + increment), (uint8)increment};
  return result;
}
// Encrypts |src_len| bytes at |dst| in place with the send key of |keypair|
// and writes the authentication tag directly behind the payload, at
// dst + src_len. |ad|/|ad_len| is the additional authenticated data and
// |nonce| the per-packet nonce.
//
// For a suite that is neither ChaCha20-Poly1305 nor AES-GCM only
// poly1305_get_mac() is called, which presumably leaves the payload
// unencrypted. In a build without WITH_AESGCM the AES-GCM branch does nothing.
void WgKeypairEncryptPayload(uint8 *dst, const size_t src_len,
                             const uint8 *ad, const size_t ad_len,
                             const uint64 nonce, WgKeypair *keypair) {
  const auto suite = keypair->cipher_suite;
  const bool is_aes_gcm = suite >= EXT_CIPHER_SUITE_AES128_GCM && suite <= EXT_CIPHER_SUITE_AES256_GCM;

  if (suite == EXT_CIPHER_SUITE_CHACHA20POLY1305) {
    chacha20poly1305_encrypt(dst, dst, src_len, ad, ad_len, nonce, keypair->send_key);
  } else if (is_aes_gcm) {
#if WITH_AESGCM
    aesgcm_encrypt(dst, dst, src_len, ad, ad_len, nonce, &keypair->aes_gcm128_context_[0]);
#endif  // WITH_AESGCM
  } else {
    poly1305_get_mac(dst, src_len, ad, ad_len, nonce, keypair->send_key, dst + src_len);
  }

  // A full length tag is sent as is.
  if (keypair->auth_tag_length == WG_MAC_LEN)
    return;

  // Otherwise compress the 16 byte tag to 8 bytes with a keyed SipHash and
  // store the result over the first 8 tag bytes.
  uint8 *tag = dst + src_len;
  const uint64 short_tag = siphash_2u64(ReadLE64(tag), ReadLE64(tag + 8), (siphash_key_t*)keypair->compress_mac_keys[0]);
  WriteLE64(tag, short_tag);
}
// Decrypts a payload in place with the receive key of |keypair| and verifies
// its authentication tag.
//
// |src_len| is the length of the payload including the trailing tag of
// keypair->auth_tag_length bytes. Returns true when the tag matches. Returns
// false when the input is shorter than the tag, when the tag does not match,
// or when an AES-GCM suite is used in a build without WITH_AESGCM.
//
// NOTE: the payload at |dst| is transformed before the tag is compared, so on
// a false return its contents must not be used.
bool WgKeypairDecryptPayload(uint8 *dst, size_t src_len,
                             const uint8 *ad, size_t ad_len,
                             const uint64 nonce, WgKeypair *keypair) {
  uint8 mac[16];

  // Split the input into payload and trailing tag.
  if (src_len < keypair->auth_tag_length)
    return false;
  src_len -= keypair->auth_tag_length;
  uint8 *received_tag = dst + src_len;

  const auto suite = keypair->cipher_suite;
  const bool is_aes_gcm = suite >= EXT_CIPHER_SUITE_AES128_GCM && suite <= EXT_CIPHER_SUITE_AES256_GCM;

  // Every branch leaves the locally computed 16 byte tag in |mac|.
  if (suite == EXT_CIPHER_SUITE_CHACHA20POLY1305) {
    chacha20poly1305_decrypt_get_mac(dst, dst, src_len, ad, ad_len, nonce, keypair->recv_key, mac);
  } else if (is_aes_gcm) {
#if WITH_AESGCM
    aesgcm_decrypt_get_mac(dst, dst, src_len, ad, ad_len, nonce, &keypair->aes_gcm128_context_[1], mac);
#else  // WITH_AESGCM
    return false;
#endif  // WITH_AESGCM
  } else {
    poly1305_get_mac(dst, src_len, ad, ad_len, nonce, keypair->recv_key, mac);
  }

  if (keypair->auth_tag_length != WG_MAC_LEN) {
    // Shortened tag: compress the computed tag the same way the sender did
    // and compare only auth_tag_length bytes.
    const uint64 short_tag = siphash_2u64(ReadLE64(mac), ReadLE64(mac + 8), (siphash_key_t*)keypair->compress_mac_keys[1]);
    WriteLE64(mac, short_tag);
    return memcmp_crypto(mac, received_tag, keypair->auth_tag_length) == 0;
  }
  return memcmp_crypto(mac, received_tag, WG_MAC_LEN) == 0;
}