From d04afa1cdb60c611031fba90b977c2d34c30ae72 Mon Sep 17 00:00:00 2001 From: Ludvig Strigeus Date: Fri, 23 Nov 2018 22:25:53 +0100 Subject: [PATCH] Add another packet processing API that doesn't forward --- wireguard.cpp | 154 +++++++++++++++++++++++++++++--------------------- wireguard.h | 36 +++++++++--- 2 files changed, 116 insertions(+), 74 deletions(-) diff --git a/wireguard.cpp b/wireguard.cpp index 6c218dc..8b087bb 100644 --- a/wireguard.cpp +++ b/wireguard.cpp @@ -152,7 +152,6 @@ bool WireguardProcessor::ConfigureTun() { RINFO("TAP is not compatible CIDR /31 or /32. Changing to /24"); it->cidr = 24; } - // Packets to this IP will not be sent out. if (ipv4_broadcast_addr == 0xffffffff) { uint32 netmask = it->cidr == 32 ? 0xffffffff : 0xffffffff << (32 - it->cidr); @@ -332,8 +331,32 @@ static inline bool IsIpv6Multicast(const uint8 dst[16]) { return dst[0] == 0xff; } -// On incoming packet to the tun interface. void WireguardProcessor::HandleTunPacket(Packet *packet) { + STATIC_ASSERT(kPacketResult_ForwardUdp == 1 && kPacketResult_Free == 3, kPacketResult_wrong_values); + PacketResult result = HandleTunPacket2(packet); + if (result == kPacketResult_ForwardUdp) { + udp_->WriteUdpPacket(packet); + } else if (result == kPacketResult_Free) { + FreePacket(packet); + } +} + +void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) { + PacketResult result = HandleUdpPacket2(packet, overload); + if (result == kPacketResult_ForwardTun) { + //stats_.tun_bytes_out += size_from_header; + //stats_.tun_packets_out++; + tun_->WriteTunPacket(packet); + } else if (result == kPacketResult_ForwardUdp) { + udp_->WriteUdpPacket(packet); + } else if (result == kPacketResult_Free) { + FreePacket(packet); + } +} + + +// On incoming packet to the tun interface. +WireguardProcessor::PacketResult WireguardProcessor::HandleTunPacket2(Packet *packet) { uint8 *data = packet->data; size_t data_size = packet->size; unsigned ip_version, size_from_header; @@ -385,16 +408,14 @@ void WireguardProcessor::HandleTunPacket(Packet *packet) { // WriteAndEncryptPacketToUdp needs a held lock WG_ACQUIRE_LOCK(peer->mutex_); - WriteAndEncryptPacketToUdp_WillUnlock(peer, packet); - return; + return WriteAndEncryptPacketToUdp_WillUnlock(peer, packet); getout: - // send ICMP? - FreePacket(packet); + return kPacketResult_Free; } // This function must be called with the peer lock held. It will remove the lock -void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet) { +WireguardProcessor::PacketResult WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet) { assert(peer->IsPeerLocked()); uint8 *data = packet->data, *ad; size_t size = packet->size, ad_len, orig_size = size; @@ -417,7 +438,7 @@ void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Pac peer->AddPacketToPeerQueue_Locked(packet); WG_RELEASE_LOCK(peer->mutex_); peer->ScheduleNewHandshake(); - return; + return kPacketResult_InUse; } assert(!peer->marked_for_delete_); @@ -538,15 +559,17 @@ need_big_packet: WgKeypairEncryptPayload(data, size, ad, ad_len, send_ctr, keypair); - DoWriteUdpPacket(packet); if (want_handshake) peer->ScheduleNewHandshake(); - return; + + stats_.udp_packets_out++; + stats_.udp_bytes_out += packet->size; + + return kPacketResult_ForwardUdp; getout_discard: WG_RELEASE_LOCK(peer->mutex_); - FreePacket(packet); - return; + return kPacketResult_Free; } // This scrambles the initial 16 bytes of the packet with the @@ -576,20 +599,15 @@ static void ScrambleUnscramblePacket(Packet *packet, ScramblerSiphashKeys *keys) } } -static NOINLINE void ScrambleUnscrambleAndWrite(Packet *packet, ScramblerSiphashKeys *keys, UdpInterface *udp) { -#if WITH_HEADER_OBFUSCATION - ScrambleUnscramblePacket(packet, keys); - udp->WriteUdpPacket(packet); -#endif // WITH_HEADER_OBFUSCATION -} +void WireguardProcessor::PrepareOutgoingHandshakePacket(WgPeer *peer, Packet *packet) { + assert(dev_.IsMainThread()); -void WireguardProcessor::DoWriteUdpPacket(Packet *packet) { stats_.udp_packets_out++; stats_.udp_bytes_out += packet->size; - if (!dev_.header_obfuscation_) - udp_->WriteUdpPacket(packet); - else - ScrambleUnscrambleAndWrite(packet, &dev_.header_obfuscation_key_, udp_); +#if WITH_HEADER_OBFUSCATION + if (dev_.header_obfuscation_) + ScrambleUnscramblePacket(packet, &dev_.header_obfuscation_key_); +#endif // WITH_HEADER_OBFUSCATION } void WireguardProcessor::RunAllMainThreadScheduled() { @@ -657,7 +675,8 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) { } WG_RELEASE_LOCK(peer->mutex_); - DoWriteUdpPacket(packet); + PrepareOutgoingHandshakePacket(peer, packet); + udp_->WriteUdpPacket(packet); if (attempts > 1 && attempts <= 20) RINFO("Retrying handshake, attempt %d...%s", attempts, (attempts == 20) ? " (last notice)" : ""); } @@ -669,7 +688,7 @@ bool WireguardProcessor::IsMainThreadPacket(Packet *packet) { } // Handles an incoming WireGuard packet from the UDP side, decrypt etc. -void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) { +WireguardProcessor::PacketResult WireguardProcessor::HandleUdpPacket2(Packet *packet, bool overload) { uint32 type; assert(packet->protocol != 0xCD && (uint16)packet->addr.sin.sin_family != 0xCDCD); // catch msvc uninit mem @@ -688,40 +707,44 @@ void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) { if (type == MESSAGE_DATA) { if (packet->size < sizeof(MessageData)) goto invalid_size; - HandleDataPacket(packet); + return HandleDataPacket(packet); #if WITH_SHORT_HEADERS } else if (type & WG_SHORT_HEADER_BIT) { - HandleShortHeaderFormatPacket(type, packet); + return HandleShortHeaderFormatPacket(type, packet); #endif // WITH_SHORT_HEADERS } else if (type == MESSAGE_HANDSHAKE_COOKIE) { assert(dev_.IsMainThread()); if (packet->size != sizeof(MessageHandshakeCookie) || !dev_.is_private_key_initialized()) goto invalid_size; - HandleHandshakeCookiePacket(packet); + return HandleHandshakeCookiePacket(packet); } else if (type == MESSAGE_HANDSHAKE_INITIATION) { assert(dev_.IsMainThread()); if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeInitiation)) : (packet->size != sizeof(MessageHandshakeInitiation)) || !dev_.is_private_key_initialized()) goto invalid_size; stats_.handshakes_in++; - if (CheckIncomingHandshakeRateLimit(packet, overload)) - HandleHandshakeInitiationPacket(packet); + PacketResult result = CheckIncomingHandshakeRateLimit(packet, overload); + if (result != kPacketResult_InUse) + return result; + return HandleHandshakeInitiationPacket(packet); } else if (type == MESSAGE_HANDSHAKE_RESPONSE) { assert(dev_.IsMainThread()); if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeResponse)) : (packet->size != sizeof(MessageHandshakeResponse)) || !dev_.is_private_key_initialized()) goto invalid_size; - if (CheckIncomingHandshakeRateLimit(packet, overload)) - HandleHandshakeResponsePacket(packet); + PacketResult result = CheckIncomingHandshakeRateLimit(packet, overload); + if (result != kPacketResult_InUse) + return result; + return HandleHandshakeResponsePacket(packet); } else { // unknown packet invalid_size: - FreePacket(packet); + return kPacketResult_Free; } } #if WITH_SHORT_HEADERS -void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packet) { +WireguardProcessor::PacketResultd WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packet) { assert(dev_.IsMainOrDataThread()); uint8 *data = packet->data + 1; @@ -829,13 +852,11 @@ void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packe packet->data = data; packet->size = bytes_left - keypair->auth_tag_length; - HandleAuthenticatedDataPacket_WillUnlock(keypair, packet); - return; + return HandleAuthenticatedDataPacket_WillUnlock(keypair, packet); getout_unlock: WG_RELEASE_LOCK(keypair->peer->mutex_); getout: - FreePacket(packet); - return; + return kPacketResult_Free; } #endif // WITH_SHORT_HEADERS @@ -851,7 +872,7 @@ void WireguardProcessor::NotifyHandshakeComplete() { procdel_->OnConnected(); } -void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet) { +WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet) { WgPeer *peer = keypair->peer; assert(peer->IsPeerLocked()); assert(packet->addr.sin.sin_family != 0); @@ -939,20 +960,16 @@ void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *key goto getout_error_header; packet->size = size_from_header; - - stats_.tun_bytes_out += size_from_header; - stats_.tun_packets_out++; - - tun_->WriteTunPacket(packet); - return; + + return kPacketResult_ForwardTun; getout_error_header: stats_.error_header++; getout: - FreePacket(packet); + return kPacketResult_Free; } -void WireguardProcessor::HandleDataPacket(Packet *packet) { +WireguardProcessor::PacketResult WireguardProcessor::HandleDataPacket(Packet *packet) { assert(dev_.IsMainOrDataThread()); uint8 *data = packet->data; @@ -963,8 +980,7 @@ void WireguardProcessor::HandleDataPacket(Packet *packet) { if (keypair == NULL || counter >= REJECT_AFTER_MESSAGES) { stats_.error_key_id++; getout: - FreePacket(packet); - return; + return kPacketResult_Free; } packet->data = data + sizeof(MessageData); @@ -988,7 +1004,7 @@ getout: goto getout; } else { assert(!keypair->peer->marked_for_delete_); - HandleAuthenticatedDataPacket_WillUnlock(keypair, packet); + return HandleAuthenticatedDataPacket_WillUnlock(keypair, packet); } } @@ -1000,36 +1016,38 @@ static uint64 GetIpForRateLimit(Packet *packet) { } } -bool WireguardProcessor::CheckIncomingHandshakeRateLimit(Packet *packet, bool overload) { +WireguardProcessor::PacketResult WireguardProcessor::CheckIncomingHandshakeRateLimit(Packet *packet, bool overload) { assert(dev_.IsMainThread()); WgRateLimit::RateLimitResult rr = dev_.rate_limiter()->CheckRateLimit(GetIpForRateLimit(packet)); - if ((overload && rr.is_rate_limited()) || !dev_.CheckCookieMac1(packet)) { - FreePacket(packet); - return false; - } + if ((overload && rr.is_rate_limited()) || !dev_.CheckCookieMac1(packet)) + return kPacketResult_Free; + dev_.rate_limiter()->CommitResult(rr); if (overload && !rr.is_first_ip() && !dev_.CheckCookieMac2(packet)) { dev_.CreateCookieMessage((MessageHandshakeCookie*)packet->data, packet, ((MessageHandshakeInitiation*)packet->data)->sender_key_id); packet->size = sizeof(MessageHandshakeCookie); - DoWriteUdpPacket(packet); - return false; + PrepareOutgoingHandshakePacket(NULL, packet); + return kPacketResult_ForwardUdp; } - return true; + + // This function returns InUse when everything went well + return kPacketResult_InUse; } // server receives this when client wants to setup a session -void WireguardProcessor::HandleHandshakeInitiationPacket(Packet *packet) { +WireguardProcessor::PacketResult WireguardProcessor::HandleHandshakeInitiationPacket(Packet *packet) { assert(dev_.IsMainThread()); WgPeer *peer = WgPeer::ParseMessageHandshakeInitiation(&dev_, packet); if (peer) { - DoWriteUdpPacket(packet); + PrepareOutgoingHandshakePacket(peer, packet); + return kPacketResult_ForwardUdp; } else { - FreePacket(packet); + return kPacketResult_Free; } } // client receives this after session is established -void WireguardProcessor::HandleHandshakeResponsePacket(Packet *packet) { +WireguardProcessor::PacketResult WireguardProcessor::HandleHandshakeResponsePacket(Packet *packet) { assert(dev_.IsMainThread()); WgPeer *peer = WgPeer::ParseMessageHandshakeResponse(&dev_, packet); if (peer) { @@ -1040,7 +1058,7 @@ void WireguardProcessor::HandleHandshakeResponsePacket(Packet *packet) { NotifyHandshakeComplete(); SendKeepalive_Locked(peer); } - FreePacket(packet); + return kPacketResult_Free; } void WireguardProcessor::SendKeepalive_Locked(WgPeer *peer) { @@ -1069,15 +1087,21 @@ void WireguardProcessor::SendQueuedPackets_Locked(WgPeer *peer) { peer->num_queued_packets_ = 0; while (packet != NULL) { Packet *next = Packet_NEXT(packet); - WriteAndEncryptPacketToUdp_WillUnlock(peer, packet); + PacketResult result = WriteAndEncryptPacketToUdp_WillUnlock(peer, packet); + if (result == kPacketResult_ForwardUdp) { + udp_->WriteUdpPacket(packet); + } else if (result == kPacketResult_Free) { + FreePacket(packet); + } packet = next; WG_ACQUIRE_LOCK(peer->mutex_); // WriteAndEncryptPacketToUdp_WillUnlock releases the lock } } -void WireguardProcessor::HandleHandshakeCookiePacket(Packet *packet) { +WireguardProcessor::PacketResult WireguardProcessor::HandleHandshakeCookiePacket(Packet *packet) { assert(dev_.IsMainThread()); WgPeer::ParseMessageHandshakeCookie(&dev_, (MessageHandshakeCookie *)packet->data); + return kPacketResult_Free; } // Only one thread may run the second loop diff --git a/wireguard.h b/wireguard.h index c4c7c89..1197215 100644 --- a/wireguard.h +++ b/wireguard.h @@ -95,6 +95,24 @@ public: void HandleTunPacket(Packet *packet); void HandleUdpPacket(Packet *packet, bool overload); + + // These are the same as above, but instead return the processed packet + // instead of forwarding it to udp/tun. + enum PacketResult { + // The packet is now in use by the wireguard impl and must not be touched + kPacketResult_InUse = 0, + + // The wireguard impl has processed the packet and it's to be forwarded + // to the udp / tun layer + kPacketResult_ForwardUdp = 1, + kPacketResult_ForwardTun = 2, + + // The wireguard impl does not need the packet and it should be freed + kPacketResult_Free = 3 + }; + PacketResult HandleTunPacket2(Packet *packet); + PacketResult HandleUdpPacket2(Packet *packet, bool overload); + static bool IsMainThreadPacket(Packet *packet); void SecondLoop(); @@ -115,20 +133,20 @@ public: void ForceSendHandshakeInitiation(WgPeer *peer); private: - void DoWriteUdpPacket(Packet *packet); - void WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet); + inline void PrepareOutgoingHandshakePacket(WgPeer *peer, Packet *packet); + PacketResult WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet); void SendHandshakeInitiation(WgPeer *peer); void SendKeepalive_Locked(WgPeer *peer); void SendQueuedPackets_Locked(WgPeer *peer); - void HandleHandshakeInitiationPacket(Packet *packet); - void HandleHandshakeResponsePacket(Packet *packet); - void HandleHandshakeCookiePacket(Packet *packet); - void HandleDataPacket(Packet *packet); + PacketResult HandleHandshakeInitiationPacket(Packet *packet); + PacketResult HandleHandshakeResponsePacket(Packet *packet); + PacketResult HandleHandshakeCookiePacket(Packet *packet); + PacketResult HandleDataPacket(Packet *packet); - void HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet); - void HandleShortHeaderFormatPacket(uint32 tag, Packet *packet); - bool CheckIncomingHandshakeRateLimit(Packet *packet, bool overload); + PacketResult HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet); + PacketResult HandleShortHeaderFormatPacket(uint32 tag, Packet *packet); + PacketResult CheckIncomingHandshakeRateLimit(Packet *packet, bool overload); bool HandleIcmpv6NeighborSolicitation(const byte *data, size_t data_size); void NotifyHandshakeComplete();