From 402358e5a0938f7bffcd671098fccd5571ce5ffa Mon Sep 17 00:00:00 2001 From: Ludvig Strigeus Date: Tue, 20 Nov 2018 20:25:27 +0100 Subject: [PATCH] Add hooks into more parts of the wg code --- tunsafe_config.h | 7 ++ wireguard.cpp | 11 +- wireguard_config.cpp | 6 +- wireguard_proto.cpp | 278 +++++++++++++++++++++++-------------------- wireguard_proto.h | 78 +++++++----- 5 files changed, 220 insertions(+), 160 deletions(-) diff --git a/tunsafe_config.h b/tunsafe_config.h index 70a3701..37214c2 100644 --- a/tunsafe_config.h +++ b/tunsafe_config.h @@ -6,10 +6,17 @@ #define TUNSAFE_VERSION_STRING_LONG "TunSafe 1.5-rc1" #define WITH_HANDSHAKE_EXT 0 +#define WITH_CIPHER_SUITES 0 +#define WITH_BOOLEAN_FEATURES 0 +#define WITH_PACKET_COMPRESSION 0 + #define WITH_SHORT_HEADERS 0 #define WITH_HEADER_OBFUSCATION 0 #define WITH_AVX512_OPTIMIZATIONS 0 #define WITH_BENCHMARK 0 + + + // Use bytell hashmap instead. Only works in 64-bit builds #define WITH_BYTELL_HASHMAP 0 diff --git a/wireguard.cpp b/wireguard.cpp index 40c816d..7831cb1 100644 --- a/wireguard.cpp +++ b/wireguard.cpp @@ -436,7 +436,7 @@ void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Pac peer->OnDataSent(); // Attempt to compress the packet headers - if (WITH_HANDSHAKE_EXT && keypair->compress_handler_) { + if (WITH_PACKET_COMPRESSION && keypair->compress_handler_) { WgCompressHandler::CompressState st = keypair->compress_handler_->Compress(packet); if (st == WgCompressHandler::COMPRESS_FAIL) goto getout_discard; @@ -626,10 +626,13 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) { if (!peer->CheckHandshakeRateLimit() || peer->endpoint_.sin.sin_family == 0) return; - stats_.handshakes_out++; Packet *packet = AllocPacket(); if (packet) { - peer->CreateMessageHandshakeInitiation(packet); + if (!peer->CreateMessageHandshakeInitiation(packet)) { + FreePacket(packet); + return; + } + stats_.handshakes_out++; WG_ACQUIRE_LOCK(peer->mutex_); int attempts = ++peer->total_handshake_attempts_; if (procdel_) @@ -888,7 +891,7 @@ void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *key WG_RELEASE_LOCK(peer->mutex_); // Unpack the packet headers? - if (WITH_HANDSHAKE_EXT && keypair->compress_handler_) { + if (WITH_PACKET_COMPRESSION && keypair->compress_handler_) { WgCompressHandler::CompressState st = keypair->compress_handler_->Decompress(packet); if (st == WgCompressHandler::COMPRESS_FAIL) goto getout; diff --git a/wireguard_config.cpp b/wireguard_config.cpp index 50c11c5..1c13e60 100644 --- a/wireguard_config.cpp +++ b/wireguard_config.cpp @@ -185,7 +185,8 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) { wg_->AddExcludedIp(addr); } } else { - goto err; + if (!wg_->dev().plugin() || !wg_->dev().plugin()->OnUnknownInterfaceSetting(key, value)) + goto err; } } else if (strcmp(group, "[Peer]") == 0) { if (key == NULL) { @@ -251,7 +252,8 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) { return false; } } else { - goto err; + if (!wg_->dev().plugin() || !wg_->dev().plugin()->OnUnknownPeerSetting(peer_, key, value)) + goto err; } } else { err: diff --git a/wireguard_proto.cpp b/wireguard_proto.cpp index f24a1e2..d2e8437 100644 --- a/wireguard_proto.cpp +++ b/wireguard_proto.cpp @@ -54,7 +54,7 @@ bool ReplayDetector::CheckReplay(uint64 seq_nr) { WgDevice::WgDevice() { peers_ = NULL; last_peer_ptr_ = &peers_; - delegate_ = NULL; + plugin_ = NULL; header_obfuscation_ = false; is_private_key_initialized_ = false; next_rng_slot_ = 0; @@ -349,6 +349,7 @@ WgPeer::WgPeer(WgDevice *dev) { endpoint_.sin.sin_family = 0; endpoint_protocol_ = 0; next_peer_ = NULL; + peer_extra_data_ = NULL; curr_keypair_ = next_keypair_ = prev_keypair_ = NULL; expect_cookie_reply_ = false; has_mac2_cookie_ = false; @@ -391,6 +392,7 @@ WgPeer::~WgPeer() { assert(curr_keypair_ == NULL && next_keypair_ == NULL && prev_keypair_ == NULL); assert(local_key_id_during_hs_ == 0); assert(first_queued_packet_ == NULL); + delete peer_extra_data_; } void WgPeer::DelayedDelete(void *x) { @@ -518,12 +520,23 @@ void WgPeer::SetPresharedKey(const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]) { } // run on the client -void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) { +bool WgPeer::CreateMessageHandshakeInitiation(Packet *packet) { assert(dev_->IsMainThread()); uint8 k[WG_SYMMETRIC_KEY_LEN]; MessageHandshakeInitiation *dst = (MessageHandshakeInitiation *)packet->data; + int extfield_size = 0; + if (WITH_HANDSHAKE_EXT && supports_handshake_extensions_) + extfield_size = WriteHandshakeExtension(dst->timestamp_enc + WG_TIMESTAMP_LEN, NULL); + + if (dev_->plugin_) { + uint32 rv = dev_->plugin_->OnHandshake0(this, dst->timestamp_enc + WG_TIMESTAMP_LEN + extfield_size, MAX_SIZE_OF_HANDSHAKE_EXTENSION - extfield_size); + if (rv & WgPlugin::kHandshakeResponseFail) + return false; + extfield_size += rv; + } + // Ci := HASH(CONSTRUCTION) memcpy(hs_.ci, kWgInitChainingKey, sizeof(hs_.ci)); // Hi := HASH(Ci || IDENTIFIER) @@ -550,11 +563,6 @@ void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) { // TAI64N OsGetTimestampTAI64N(dst->timestamp_enc); - size_t extfield_size = 0; -#if WITH_HANDSHAKE_EXT - if (supports_handshake_extensions_) - extfield_size = WriteHandshakeExtension(dst->timestamp_enc + WG_TIMESTAMP_LEN, NULL); -#endif // WITH_HANDSHAKE_EXT // msg.timestamp := AEAD(K, 0, timestamp, hi) chacha20poly1305_encrypt(dst->timestamp_enc, dst->timestamp_enc, extfield_size + WG_TIMESTAMP_LEN, hs_.hi, sizeof(hs_.hi), 0, k); // Hi := HASH(Hi || msg.timestamp) @@ -566,6 +574,7 @@ void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) { dst->type = MESSAGE_HANDSHAKE_INITIATION; memzero_crypto(k, sizeof(k)); WriteMacToPacket((uint8*)dst, (MessageMacs*)((uint8*)&dst->mac + extfield_size)); + return true; } // Parsed by server @@ -591,7 +600,7 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { uint8 extbuf[MAX_SIZE_OF_HANDSHAKE_EXTENSION + WG_TIMESTAMP_LEN]; MessageHandshakeInitiation *src = (MessageHandshakeInitiation *)packet->data; MessageHandshakeResponse *dst; - size_t extfield_size; + int extfield_size; // Ci := HASH(CONSTRUCTION) memcpy(ci, kWgInitChainingKey, sizeof(ci)); @@ -612,7 +621,7 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { BlakeMix(hi, src->static_enc, sizeof(src->static_enc)); // Lookup the peer with this ID while ((peer = dev->GetPeerFromPublicKey(spubi)) == NULL) { - if (dev->delegate_ == NULL || !dev->delegate_->HandleUnknownPeerId(spubi.bytes, packet)) + if (dev->plugin_ == NULL || !dev->plugin_->HandleUnknownPeerId(spubi.bytes, packet)) goto getout; } // (Ci, K) := KDF2(Ci, DH(sprivr, spubi)) @@ -620,7 +629,7 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { // Hi2 := Hi memcpy(hi2, hi, sizeof(hi2)); extfield_size = packet->size - sizeof(MessageHandshakeInitiation); - if (extfield_size > MAX_SIZE_OF_HANDSHAKE_EXTENSION || (extfield_size && !peer->supports_handshake_extensions_)) + if ((uint32)extfield_size > MAX_SIZE_OF_HANDSHAKE_EXTENSION || (extfield_size && !peer->supports_handshake_extensions_)) goto getout; // Hi := HASH(Hi || msg.timestamp) BlakeMix(hi, src->timestamp_enc, extfield_size + WG_TIMESTAMP_LEN + WG_MAC_LEN); @@ -641,9 +650,9 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { memcpy(e_remote, src->ephemeral, sizeof(e_remote)); remote_key_id = src->sender_key_id; - + dst = (MessageHandshakeResponse *)src; - + dst->receiver_key_id = remote_key_id; // (Epriv_r, Epub_r) := DH-GENERATE() // msg.ephemeral = Epub_r OsGetRandomBytes(e_priv, sizeof(e_priv)); @@ -662,28 +671,40 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { // Hr := HASH(Hr || T) BlakeMix(hi, t, sizeof(t)); - dst->receiver_key_id = remote_key_id; - keypair = WgPeer::CreateNewKeypair(false, ci, remote_key_id, extbuf + WG_TIMESTAMP_LEN, extfield_size); + keypair = WgPeer::CreateNewKeypair(false, ci, remote_key_id); if (keypair) { + if (WITH_HANDSHAKE_EXT && !peer->ParseExtendedHandshake(keypair, extbuf + WG_TIMESTAMP_LEN, extfield_size)) + goto getout; - dst->sender_key_id = dev->InsertInKeyIdLookup(peer, keypair); - - size_t extfield_out_size = 0; -#if WITH_HANDSHAKE_EXT - if (extfield_size) + int extfield_out_size = 0; + if (WITH_HANDSHAKE_EXT && extfield_size) extfield_out_size = peer->WriteHandshakeExtension(dst->empty_enc, keypair); -#endif // WITH_HANDSHAKE_EXT - uint32 orig_packet_size = packet->size; - packet->size = (unsigned)(sizeof(MessageHandshakeResponse) + extfield_out_size); + // Allow plugin to determine what to do with the packet, + // it can append new headers to the response, and decide what to do. + if (dev->plugin_) { + uint32 rv = dev->plugin_->OnHandshake1(peer, extbuf + WG_TIMESTAMP_LEN, extfield_size, + dst->empty_enc + extfield_out_size, MAX_SIZE_OF_HANDSHAKE_EXTENSION - extfield_out_size); + if (rv == WgPlugin::kHandshakeResponseDrop) + goto getout; + if (rv & WgPlugin::kHandshakeResponseFail) + delete exch_null(keypair); + extfield_out_size += rv & ~WgPlugin::kHandshakeResponseFail; + } + + dst->sender_key_id = keypair ? dev->InsertInKeyIdLookup(peer, keypair) : 0; WG_ACQUIRE_LOCK(peer->mutex_); - peer->rx_bytes_ += orig_packet_size; + peer->rx_bytes_ += packet->size; + if (keypair != NULL) { + peer->InsertKeypairInPeer_Locked(keypair); + peer->OnHandshakeAuthComplete(); + } + packet->size = (unsigned)(sizeof(MessageHandshakeResponse) + extfield_out_size); peer->tx_bytes_ += packet->size; - peer->InsertKeypairInPeer_Locked(keypair); - peer->OnHandshakeAuthComplete(); WG_RELEASE_LOCK(peer->mutex_); + // msg.empty := AEAD(K, 0, "", Hr) chacha20poly1305_encrypt(dst->empty_enc, dst->empty_enc, extfield_out_size, hi, sizeof(hi), 0, k); // Hr := HASH(Hr || "") @@ -693,6 +714,7 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { peer->WriteMacToPacket((uint8*)dst, (MessageMacs*)((uint8*)&dst->mac + extfield_out_size)); } else { getout: + delete keypair; peer = NULL; } memzero_crypto(hi, sizeof(hi)); @@ -728,18 +750,33 @@ WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packe // Hr := HASH(Hr || T) BlakeMix(hs.hi, t, sizeof(t)); - size_t extfield_size = packet->size - sizeof(MessageHandshakeResponse); - if (extfield_size > MAX_SIZE_OF_HANDSHAKE_EXTENSION) + int extfield_size = packet->size - sizeof(MessageHandshakeResponse); + if ((uint32)extfield_size > MAX_SIZE_OF_HANDSHAKE_EXTENSION) goto getout; // "" := AEAD_DEC(K, 0, msg.empty, Hr) if (!chacha20poly1305_decrypt(src->empty_enc, src->empty_enc, extfield_size + sizeof(src->empty_enc), hs.hi, sizeof(hs.hi), 0, k)) goto getout; - keypair = WgPeer::CreateNewKeypair(true, hs.ci, src->sender_key_id, src->empty_enc, extfield_size); + keypair = WgPeer::CreateNewKeypair(true, hs.ci, src->sender_key_id); if (!keypair) goto getout; + if (WITH_HANDSHAKE_EXT && !peer->ParseExtendedHandshake(keypair, src->empty_enc, extfield_size)) { + delete keypair; + goto getout; + } + + // Allow plugin to determine what to do with the packet, + // it can append new headers to the response, and decide what to do. + if (dev->plugin_) { + uint32 rv = dev->plugin_->OnHandshake2(peer, src->empty_enc, extfield_size); + if (rv & WgPlugin::kHandshakeResponseFail) { + delete keypair; + goto getout; + } + } + // Re-map the entry in the id table so it points at this keypair instead. keypair->local_key_id = peer->local_key_id_during_hs_; peer->local_key_id_during_hs_ = 0; @@ -787,42 +824,43 @@ void WgPeer::ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCo memcpy(peer->mac2_cookie_, cookie, sizeof(peer->mac2_cookie_)); } -#if WITH_HANDSHAKE_EXT +int WgPeer::WriteHandshakeExtension(uint8 *dst, WgKeypair *keypair) { + uint8 *dst_org = dst, *dst_end = dst + MAX_SIZE_OF_HANDSHAKE_EXTENSION; -size_t WgPeer::WriteHandshakeExtension(uint8 *dst, WgKeypair *keypair) { - uint8 *dst_end = dst + MAX_SIZE_OF_HANDSHAKE_EXTENSION; - uint8 *dst_org = dst, value = 0; - // Include the supported features extension - if (!IsOnlyZeros(features_, sizeof(features_))) { - *dst++ = EXT_BOOLEAN_FEATURES; - *dst++ = (WG_FEATURES_COUNT + 3) >> 2; - for (size_t i = 0; i != WG_FEATURES_COUNT; i++) { - if ((i & 3) == 0) - value = 0; - dst[i >> 2] = (value += (features_[i] << ((i * 2) & 7))); + if (WITH_HANDSHAKE_EXT) { + if (WITH_BOOLEAN_FEATURES) { + uint8 value = 0; + // Include the supported features extension + if (!IsOnlyZeros(features_, sizeof(features_))) { + *dst++ = EXT_BOOLEAN_FEATURES; + *dst++ = (WG_FEATURES_COUNT + 3) >> 2; + for (size_t i = 0; i != WG_FEATURES_COUNT; i++) { + if ((i & 3) == 0) + value = 0; + dst[i >> 2] = (value += (features_[i] << ((i * 2) & 7))); + } + // swap WG_FEATURE_ID_SKIP_KEYID_IN and WG_FEATURE_ID_SKIP_KEYID_OUT + dst[1] = (dst[1] & 0xF0) + ((dst[1] >> 2) & 0x03) + ((dst[1] << 2) & 0x0C); + dst += (WG_FEATURES_COUNT + 3) >> 2; + } } - // swap WG_FEATURE_ID_SKIP_KEYID_IN and WG_FEATURE_ID_SKIP_KEYID_OUT - dst[1] = (dst[1] & 0xF0) + ((dst[1] >> 2) & 0x03) + ((dst[1] << 2) & 0x0C); - dst += (WG_FEATURES_COUNT + 3) >> 2; - } - // Ordered list of cipher suites - size_t ciphers = num_ciphers_; - if (ciphers) { - *dst++ = EXT_CIPHER_SUITES + cipher_prio_; - if (keypair) { - *dst++ = 1; - *dst++ = keypair->cipher_suite; - } else { - *dst++ = (uint8)ciphers; - memcpy(dst, ciphers_, ciphers); - dst += ciphers; + if (WITH_CIPHER_SUITES) { + // Ordered list of cipher suites + size_t ciphers = num_ciphers_; + if (ciphers) { + *dst++ = EXT_CIPHER_SUITES + cipher_prio_; + if (keypair) { + *dst++ = 1; + *dst++ = keypair->cipher_suite; + } else { + *dst++ = (uint8)ciphers; + memcpy(dst, ciphers_, ciphers); + dst += ciphers; + } + } } } - // Packet compression extension? - if (features_[WG_FEATURE_ID_IPZIP] && dev_->delegate_) - dst += dev_->delegate_->WritePacketCompressionExtension(dst, dst_end - dst); - - return dst - dst_org; + return (int)(dst - dst_org); } static bool ResolveBooleanFeatureValue(uint8 other, uint8 self, bool *result) { @@ -853,48 +891,6 @@ static uint32 ResolveCipherSuite(int tie, const uint8 *a, size_t a_size, const u (tie == 0 && cipher_strengths[found_a] > cipher_strengths[found_b])) ? found_a : found_b; } -bool WgPeer::ParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size_t data_size) { - while (data_size >= 2) { - uint8 type = data[0], size = data[1]; - data += 2, data_size -= 2; - if (size > data_size) - return false; - switch (type) { - case EXT_CIPHER_SUITES_PRIO: - case EXT_CIPHER_SUITES: - keypair->cipher_suite = ResolveCipherSuite(keypair->peer->cipher_prio_ - (type - EXT_CIPHER_SUITES), - keypair->peer->ciphers_, keypair->peer->num_ciphers_, - data, size); - break; - case EXT_BOOLEAN_FEATURES: - for (size_t i = 0, j = std::max(WG_FEATURES_COUNT, size * 4); i != j; i++) { - uint8 value = (i < size * 4) ? (data[i >> 2] >> ((i * 2) & 7)) & 3 : 0; - if (i >= WG_FEATURES_COUNT ? (value == WG_BOOLEAN_FEATURE_ENFORCES) : - !ResolveBooleanFeatureValue(value, keypair->peer->features_[i], &keypair->enabled_features[i])) - return false; - } - break; - case EXT_PACKET_COMPRESSION: - if (keypair->enabled_features[WG_FEATURE_ID_IPZIP] && !keypair->compress_handler_ && keypair->peer->dev_->delegate_) - keypair->compress_handler_ = keypair->peer->dev_->delegate_->ParsePacketCompressionExtension(keypair, data, size); - break; - } - data += size, data_size -= size; - } - if (data_size != 0) - return false; - - if (!keypair->compress_handler_) - keypair->enabled_features[WG_FEATURE_ID_IPZIP] = false; - keypair->auth_tag_length = (keypair->enabled_features[WG_FEATURE_ID_SHORT_MAC] ? 8 : CHACHA20POLY1305_AUTHTAGLEN); - -// RINFO("Cipher Suite = %d", keypair->cipher_suite); - - return true; -} - -#endif // WITH_HANDSHAKE_EXT - static void WgKeypairDelayedDelete(void *x) { WgKeypair *t = (WgKeypair*)x; if (t->aes_gcm128_context_) @@ -922,7 +918,60 @@ void WgPeer::DeleteKeypair(WgKeypair **kp) { } } -WgKeypair *WgPeer::CreateNewKeypair(bool is_initiator, const uint8 chaining_key[WG_HASH_LEN], uint32 remote_key_id, const uint8 *extfield, size_t extfield_size) { +bool WgPeer::ParseExtendedHandshake(WgKeypair *kp, const uint8 *data, size_t data_size) { + assert(WITH_HANDSHAKE_EXT); + + while (data_size >= 2) { + uint8 type = data[0], size = data[1]; + data += 2, data_size -= 2; + if (size > data_size) + return false; + switch (type) { + case EXT_CIPHER_SUITES_PRIO: + case EXT_CIPHER_SUITES: + if (WITH_CIPHER_SUITES) { + kp->cipher_suite = ResolveCipherSuite(cipher_prio_ - (type - EXT_CIPHER_SUITES), + ciphers_, num_ciphers_, data, size); + } + break; + + case EXT_BOOLEAN_FEATURES: + if (WITH_BOOLEAN_FEATURES) { + for (size_t i = 0, j = std::max(WG_FEATURES_COUNT, size * 4); i != j; i++) { + uint8 value = (i < size * 4) ? (data[i >> 2] >> ((i * 2) & 7)) & 3 : 0; + if (i >= WG_FEATURES_COUNT ? (value == WG_BOOLEAN_FEATURE_ENFORCES) : + !ResolveBooleanFeatureValue(value, features_[i], &kp->enabled_features[i])) + return false; + } + } + break; + } + data += size, data_size -= size; + } + if (data_size != 0) + return false; + + if (WITH_BOOLEAN_FEATURES) + kp->auth_tag_length = (kp->enabled_features[WG_FEATURE_ID_SHORT_MAC] ? 8 : CHACHA20POLY1305_AUTHTAGLEN); + return true; + + + if (WITH_CIPHER_SUITES && kp->cipher_suite >= EXT_CIPHER_SUITE_AES128_GCM && kp->cipher_suite <= EXT_CIPHER_SUITE_AES256_GCM) { +#if WITH_AESGCM + kp->aes_gcm128_context_ = (AesGcm128StaticContext *)malloc(sizeof(*kp->aes_gcm128_context_) * 2); + if (!kp->aes_gcm128_context_) + return false; + int key_size = (kp->cipher_suite == EXT_CIPHER_SUITE_AES128_GCM) ? 128 : 256; + CRYPTO_gcm128_init(&kp->aes_gcm128_context_[0], kp->send_key, key_size); + CRYPTO_gcm128_init(&kp->aes_gcm128_context_[1], kp->recv_key, key_size); +#else // WITH_AESGCM + return false; +#endif // WITH_AESGCM + } + +} + +WgKeypair *WgPeer::CreateNewKeypair(bool is_initiator, const uint8 chaining_key[WG_HASH_LEN], uint32 remote_key_id) { WgKeypair *kp = new WgKeypair; uint8 *first_key, *second_key; if (!kp) @@ -932,14 +981,6 @@ WgKeypair *WgPeer::CreateNewKeypair(bool is_initiator, const uint8 chaining_key[ kp->remote_key_id = remote_key_id; kp->auth_tag_length = CHACHA20POLY1305_AUTHTAGLEN; -#if WITH_HANDSHAKE_EXT - if (!ParseExtendedHandshake(kp, extfield, extfield_size)) { -fail: - delete kp; - return NULL; - } -#endif // WITH_HANDSHAKE_EXT - first_key = kp->send_key, second_key = kp->recv_key; if (!is_initiator) std::swap(first_key, second_key); @@ -952,21 +993,6 @@ fail: std::swap(kp->compress_mac_keys[0][1], kp->compress_mac_keys[1][1]); } -#if WITH_HANDSHAKE_EXT - if (kp->cipher_suite >= EXT_CIPHER_SUITE_AES128_GCM && kp->cipher_suite <= EXT_CIPHER_SUITE_AES256_GCM) { -#if WITH_AESGCM - kp->aes_gcm128_context_ = (AesGcm128StaticContext *)malloc(sizeof(*kp->aes_gcm128_context_) * 2); - if (!kp->aes_gcm128_context_) - goto fail; - int key_size = (kp->cipher_suite == EXT_CIPHER_SUITE_AES128_GCM) ? 128 : 256; - CRYPTO_gcm128_init(&kp->aes_gcm128_context_[0], kp->send_key, key_size); - CRYPTO_gcm128_init(&kp->aes_gcm128_context_[1], kp->recv_key, key_size); -#else // WITH_AESGCM - goto fail; -#endif // WITH_AESGCM - } -#endif // WITH_HANDSHAKE_EXT - kp->send_key_state = kp->recv_key_state = WgKeypair::KEY_VALID; kp->key_timestamp = OsGetMilliseconds(); return kp; diff --git a/wireguard_proto.h b/wireguard_proto.h index 735f017..101e2b4 100644 --- a/wireguard_proto.h +++ b/wireguard_proto.h @@ -135,8 +135,6 @@ STATIC_ASSERT(sizeof(MessageHandshakeInitiation) == 148, MessageHandshakeInitiat // 1 byte length // - - struct MessageHandshakeResponse { uint32 type; uint32 sender_key_id; @@ -163,9 +161,6 @@ struct MessageData { STATIC_ASSERT(sizeof(MessageData) == 16, MessageData_wrong_size); enum { - EXT_PACKET_COMPRESSION = 0x15, - EXT_PACKET_COMPRESSION_VER = 0x01, - EXT_BOOLEAN_FEATURES = 0x16, EXT_CIPHER_SUITES = 0x18, @@ -307,30 +302,43 @@ public: virtual CompressState Compress(Packet *packet); virtual CompressState Decompress(Packet *packet); - }; +// Can be used to customize the behavior of the wireguard impl +class WgPlugin { +public: + virtual ~WgPlugin() {} + + // This is called from the main thread whenever a public key was not found in the WgDevice, + // return true to try again or false to fail. The packet can be copied and saved + // to resume a handshake later on. + virtual bool HandleUnknownPeerId(uint8 public_key[WG_PUBLIC_KEY_LEN], Packet *packet) = 0; + + // For handling unknown settings during config parsing + virtual bool OnUnknownInterfaceSetting(const char *key, const char *value) = 0; + virtual bool OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) = 0; + + enum { + kHandshakeResponseDrop = 0xffffffff, + kHandshakeResponseFail = 0x80000000 + }; + + // Called right before handshake initiation is sent out. Can be dropped. + virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size) = 0; + // Called after handshake initiation is parsed, but before handshake response is sent. + // Packet can be dropped or keypair failed. + virtual uint32 OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, uint8 *extout, uint32 extout_size) = 0; + // Called when handshake response is parsed + virtual uint32 OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size) = 0; +}; + + class WgDevice { friend class WgPeer; friend class WireguardProcessor; friend class WgConfig; public: - // Can be used to customize the behavior of WgDevice - class Delegate { - public: - // This is called from the main thread whenever a public key was not found in the WgDevice, - // return true to try again or false to fail. The packet can be copied and saved - // to resume a handshake later on. - virtual bool HandleUnknownPeerId(uint8 public_key[WG_PUBLIC_KEY_LEN], Packet *packet) = 0; - - // Write out the compression header - virtual size_t WritePacketCompressionExtension(uint8 *data, size_t data_size) = 0; - - // Parse the packet compression extension - virtual WgCompressHandler *ParsePacketCompressionExtension(WgKeypair *keypair, const uint8 *data, size_t data_size) = 0; - }; - WgDevice(); ~WgDevice(); @@ -364,8 +372,9 @@ public: bool IsMainThread() { return CurrentThreadIdEquals(main_thread_id_); } bool IsMainOrDataThread() { return CurrentThreadIdEquals(main_thread_id_) || WG_IF_LOCKS_ENABLED_ELSE(delayed_delete_.enabled(), false); } - void SetDelegate(Delegate *del) { delegate_ = del; } - + void SetPlugin(WgPlugin *del) { plugin_ = del; } + WgPlugin *plugin() { return plugin_; } + private: std::pair *LookupPeerInKeyIdLookup(uint32 key_id); WgKeypair *LookupKeypairByKeyId(uint32 key_id); @@ -393,7 +402,7 @@ private: WgPeer *peers_, **last_peer_ptr_; // For hooking - Delegate *delegate_; + WgPlugin *plugin_; // Keypair IDs are generated randomly by us so no point in wasting cycles on @@ -452,6 +461,12 @@ private: MultithreadedDelayedDelete delayed_delete_; }; +// Allows associating extradata with peers that can be used by plugins etc. +class WgPeerExtraData { +public: + virtual ~WgPeerExtraData() {} +}; + // State for peer class WgPeer { friend class WgDevice; @@ -476,7 +491,7 @@ public: static WgPeer *ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet); static WgPeer *ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packet); static void ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCookie *src); - void CreateMessageHandshakeInitiation(Packet *packet); + bool CreateMessageHandshakeInitiation(Packet *packet); bool CheckSwitchToNextKey_Locked(WgKeypair *keypair); void RemovePeer(); bool CheckHandshakeRateLimit(); @@ -503,19 +518,24 @@ public: uint8 endpoint_protocol() const { return endpoint_protocol_; } WgPeer *next_peer() { return next_peer_; } + WgPeerExtraData *extradata() { return peer_extra_data_; } + void SetExtradata(WgPeerExtraData *ex) { peer_extra_data_ = ex; } + WgDevice *dev() { return dev_; } + private: - static bool ParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size_t data_size); - static WgKeypair *CreateNewKeypair(bool is_initiator, const uint8 key[WG_HASH_LEN], uint32 send_key_id, const uint8 *extfield, size_t extfield_size); + bool ParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size_t data_size); + static WgKeypair *CreateNewKeypair(bool is_initiator, const uint8 key[WG_HASH_LEN], uint32 send_key_id); void WriteMacToPacket(const uint8 *data, MessageMacs *mac); void CheckAndUpdateTimeOfNextKeyEvent(uint64 now); static void DeleteKeypair(WgKeypair **kp); static void DelayedDelete(void *x); - size_t WriteHandshakeExtension(uint8 *dst, WgKeypair *keypair); + int WriteHandshakeExtension(uint8 *dst, WgKeypair *keypair); void InsertKeypairInPeer_Locked(WgKeypair *keypair); void ClearKeys_Locked(); void ClearHandshake_Locked(); void ClearPacketQueue_Locked(); void ScheduleNewHandshake(); + WgDevice *dev_; WgPeer *next_peer_; @@ -604,6 +624,8 @@ private: uint64 rx_bytes_; uint64 tx_bytes_; + WgPeerExtraData *peer_extra_data_; + // Handshake state that gets setup in |CreateMessageHandshakeInitiation| and used in // the response. struct HandshakeState {