diff --git a/wireguard.cpp b/wireguard.cpp index 42346ae..14859a3 100644 --- a/wireguard.cpp +++ b/wireguard.cpp @@ -651,6 +651,7 @@ bool WireguardProcessor::IsMainThreadPacket(Packet *packet) { // Handles an incoming WireGuard packet from the UDP side, decrypt etc. void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) { uint32 type; + assert(packet->protocol != 0xCD && (uint16)packet->addr.sin.sin_family != 0xCDCD); // catch msvc uninit mem // Unscramble incoming packets #if WITH_HEADER_OBFUSCATION @@ -699,19 +700,6 @@ invalid_size: } } -void WgPeer::CopyEndpointToPeer_Locked(WgKeypair *keypair, const IpAddr *addr) { - // Remember how to send packets to this peer - if (keypair->peer->allow_endpoint_change_ && - CompareIpAddr(&keypair->peer->endpoint_, addr) && addr->sin.sin_family != 0) { -#if WITH_SHORT_HEADERS - // When the endpoint changes, forget about using the short key. - keypair->broadcast_short_key = 0; - keypair->can_use_short_key_for_outgoing = false; -#endif // WITH_SHORT_HEADERS - keypair->peer->endpoint_ = *addr; - } -} - #if WITH_SHORT_HEADERS void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packet) { assert(dev_.IsMainOrDataThread()); @@ -809,8 +797,6 @@ void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packe keypair->send_ctr_acked = std::max(keypair->send_ctr_acked, acked_counter); - WgPeer::CopyEndpointToPeer_Locked(keypair, &packet->addr); - // Periodically broadcast out the short key if ((tag & WG_SHORT_HEADER_KEY_ID_MASK) == 0x00 && !keypair->did_attempt_remember_ip_port) { keypair->did_attempt_remember_ip_port = true; @@ -848,6 +834,19 @@ void WireguardProcessor::NotifyHandshakeComplete() { void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet) { WgPeer *peer = keypair->peer; assert(peer->IsPeerLocked()); + assert(packet->addr.sin.sin_family != 0); + + // Remember the endpoint of the peer + if (peer->allow_endpoint_change_ && + (CompareIpAddr(&peer->endpoint_, &packet->addr) | (peer->endpoint_protocol_ ^ packet->protocol)) != 0) { +#if WITH_SHORT_HEADERS + // When the endpoint changes, forget about using the short key. + keypair->broadcast_short_key = 0; + keypair->can_use_short_key_for_outgoing = false; +#endif // WITH_SHORT_HEADERS + peer->endpoint_ = packet->addr; + peer->endpoint_protocol_ = packet->protocol; + } // Remember how many incoming packets we've seen so we can approximate loss keypair->incoming_packet_count++; @@ -969,9 +968,6 @@ getout: goto getout; } else { assert(!keypair->peer->marked_for_delete_); - - WgPeer::CopyEndpointToPeer_Locked(keypair, &packet->addr); - HandleAuthenticatedDataPacket_WillUnlock(keypair, packet); } } diff --git a/wireguard_proto.h b/wireguard_proto.h index ec8afc3..735f017 100644 --- a/wireguard_proto.h +++ b/wireguard_proto.h @@ -509,7 +509,6 @@ private: void WriteMacToPacket(const uint8 *data, MessageMacs *mac); void CheckAndUpdateTimeOfNextKeyEvent(uint64 now); static void DeleteKeypair(WgKeypair **kp); - static void CopyEndpointToPeer_Locked(WgKeypair *keypair, const IpAddr *addr); static void DelayedDelete(void *x); size_t WriteHandshakeExtension(uint8 *dst, WgKeypair *keypair); void InsertKeypairInPeer_Locked(WgKeypair *keypair);