From 13158f9d904aae7c55ca05bbc57dc38f98a2cc6b Mon Sep 17 00:00:00 2001 From: Ludvig Strigeus Date: Mon, 10 Dec 2018 23:12:57 +0100 Subject: [PATCH] Add more extension points in wireguard code --- netapi.h | 6 ++++++ wireguard.cpp | 19 +++++++++++++++---- wireguard_config.cpp | 20 ++++++++++++++------ wireguard_proto.cpp | 3 +++ wireguard_proto.h | 21 +++++++++++++++------ 5 files changed, 53 insertions(+), 16 deletions(-) diff --git a/netapi.h b/netapi.h index f1b9318..4614aad 100644 --- a/netapi.h +++ b/netapi.h @@ -66,6 +66,12 @@ struct Packet : QueuedItem { HEADROOM_BEFORE = 64, }; + +#ifdef PACKET_EXTENSION_FIELDS + PACKET_EXTENSION_FIELDS +#endif // PACKET_EXTENSION_FIELDS + + byte data_pre[HEADROOM_BEFORE]; byte data_buf[0]; diff --git a/wireguard.cpp b/wireguard.cpp index f15de27..39be614 100644 --- a/wireguard.cpp +++ b/wireguard.cpp @@ -450,6 +450,8 @@ WireguardProcessor::PacketResult WireguardProcessor::WriteAndEncryptPacketToUdp_ packet->addr = peer->endpoint_; packet->protocol = peer->endpoint_protocol_; + WG_EXTENSION_HOOKS::OnPeerOutgoingUdp(peer, packet); + if (size == 0) { peer->OnKeepaliveSent(); } else { @@ -663,6 +665,7 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) { peer->OnHandshakeInitSent(); packet->addr = peer->endpoint_; packet->protocol = peer->endpoint_protocol_; + WG_EXTENSION_HOOKS::OnPeerOutgoingUdp(peer, packet); peer->tx_bytes_ += packet->size; // If this is an incoming oneway connection (such as tcp), forget the @@ -888,6 +891,8 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack peer->endpoint_protocol_ = packet->protocol; } + WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet); + // Remember how many incoming packets we've seen so we can approximate loss keypair->incoming_packet_count++; @@ -936,9 +941,13 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack if (ip_version == 4) { if (data_size < IPV4_HEADER_SIZE) goto getout_error_header; - WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); - peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV4(ReadBE32(data + 12)); - WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); + if (!WG_EXTENSION_HOOKS::DisableSourceAddressVerification(peer)) { + WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); + peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV4(ReadBE32(data + 12)); + WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); + if (peer_from_header != peer) + goto getout_error_header; + } size_from_header = ReadBE16(data + 2); if (size_from_header < IPV4_HEADER_SIZE) { // too small packet? @@ -950,12 +959,14 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV6(data + 8); WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); + if (peer_from_header != peer) + goto getout_error_header; size_from_header = IPV6_HEADER_SIZE + ReadBE16(data + 4); } else { // invalid ip version goto getout_error_header; } - if (peer_from_header != peer || size_from_header > data_size) + if (size_from_header > data_size) goto getout_error_header; packet->size = size_from_header; diff --git a/wireguard_config.cpp b/wireguard_config.cpp index ecf17fc..c82396f 100644 --- a/wireguard_config.cpp +++ b/wireguard_config.cpp @@ -97,7 +97,9 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) { if (strcmp(group, "[Interface]") == 0) { if (key == NULL) return true; - if (strcmp(key, "PrivateKey") == 0) { + if (wg_->dev().plugin() && wg_->dev().plugin()->OnUnknownInterfaceSetting(key, value)) { + // nothing here + } else if (strcmp(key, "PrivateKey") == 0) { if (!ParseBase64Key(value, binkey)) return false; had_interface_ = true; @@ -185,8 +187,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) { wg_->AddExcludedIp(addr); } } else { - if (!wg_->dev().plugin() || !wg_->dev().plugin()->OnUnknownInterfaceSetting(key, value)) - goto err; + goto err; } } else if (strcmp(group, "[Peer]") == 0) { if (key == NULL) { @@ -199,7 +200,10 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) { memset(&pi_, 0, sizeof(pi_)); return true; } - if (strcmp(key, "PublicKey") == 0) { + + if (wg_->dev().plugin() && wg_->dev().plugin()->OnUnknownPeerSetting(peer_, key, value)) { + // nothing here + } else if (strcmp(key, "PublicKey") == 0) { if (!ParseBase64Key(value, pi_.pub.bytes)) return false; } else if (strcmp(key, "PresharedKey") == 0) { @@ -252,8 +256,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) { return false; } } else { - if (!wg_->dev().plugin() || !wg_->dev().plugin()->OnUnknownPeerSetting(peer_, key, value)) - goto err; + goto err; } } else { err: @@ -331,6 +334,11 @@ bool ParseWireGuardConfigString(WireguardProcessor *wg, const char *bufin, size_ buf = nl + 1; } file_parser.FinishGroup(); + + // Let plugin do any final processing + if (wg->dev().plugin() && !wg->dev().plugin()->OnAfterSettingsParsed()) + return false; + return true; } diff --git a/wireguard_proto.cpp b/wireguard_proto.cpp index 2b97441..f214148 100644 --- a/wireguard_proto.cpp +++ b/wireguard_proto.cpp @@ -787,6 +787,9 @@ WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packe peer->endpoint_ = packet->addr; peer->endpoint_protocol_ = packet->protocol; } + + WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet); + peer->rx_bytes_ += packet->size; peer->InsertKeypairInPeer_Locked(keypair); WG_RELEASE_LOCK(peer->mutex_); diff --git a/wireguard_proto.h b/wireguard_proto.h index 5003d40..d602919 100644 --- a/wireguard_proto.h +++ b/wireguard_proto.h @@ -318,22 +318,26 @@ public: virtual bool OnUnknownInterfaceSetting(const char *key, const char *value) = 0; virtual bool OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) = 0; + // Called after settings have been completely parsed, the plugin may modify the state + virtual bool OnAfterSettingsParsed() = 0; + // Returns true if we want to perform a handshake for this peer. virtual bool WantHandshake(WgPeer *peer) = 0; - // Called right before handshake initiation is sent out. Can't drop packets. - virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0; - // Called after handshake initiation is parsed, but before handshake response is sent. - enum { kHandshakeResponseDrop = 0xffffffff, kHandshakeResponseFail = 0x80000000 }; + // Called before handshake initiation is sent out. Can write extra headers. Can't drop packets. + virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0; + // Called after handshake initiation is parsed, but before handshake response is sent. // Packet can be dropped or keypair failed. virtual uint32 OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt_in[WG_PUBLIC_KEY_LEN], uint8 *extout, uint32 extout_size, const uint8 salt_out[WG_PUBLIC_KEY_LEN]) = 0; - // Called when handshake response is parsed virtual uint32 OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0; + + // Called right before an outgoing non-data packet is sent out, but before it's scrambled. + virtual void OnOutgoingHandshakePacket(WgPeer *peer, Packet *packet) = 0; }; @@ -378,7 +382,9 @@ public: void SetPlugin(WgPlugin *del) { plugin_ = del; } WgPlugin *plugin() { return plugin_; } - + + MultithreadedDelayedDelete *GetDelayedDelete() { return &delayed_delete_; } + private: std::pair *LookupPeerInKeyIdLookup(uint32 key_id); WgKeypair *LookupKeypairByKeyId(uint32 key_id); @@ -773,6 +779,9 @@ bool WgKeypairDecryptPayload(uint8 *dst, const size_t src_len, struct WgExtensionHooksDefault { static uint32 GetIpv4Target(Packet *packet, uint8 *data) { return ReadBE32(data + 16); } + static void OnPeerIncomingUdp(WgPeer *peer, const Packet *packet) { } + static void OnPeerOutgoingUdp(WgPeer *peer, Packet *packet) { } + static bool DisableSourceAddressVerification(WgPeer *peer) { return false; } }; #ifndef WG_EXTENSION_HOOKS