Add another packet processing API that doesn't forward
This commit is contained in:
parent
1e414a700e
commit
d04afa1cdb
2 changed files with 116 additions and 74 deletions
154
wireguard.cpp
154
wireguard.cpp
|
@ -152,7 +152,6 @@ bool WireguardProcessor::ConfigureTun() {
|
|||
RINFO("TAP is not compatible CIDR /31 or /32. Changing to /24");
|
||||
it->cidr = 24;
|
||||
}
|
||||
|
||||
// Packets to this IP will not be sent out.
|
||||
if (ipv4_broadcast_addr == 0xffffffff) {
|
||||
uint32 netmask = it->cidr == 32 ? 0xffffffff : 0xffffffff << (32 - it->cidr);
|
||||
|
@ -332,8 +331,32 @@ static inline bool IsIpv6Multicast(const uint8 dst[16]) {
|
|||
return dst[0] == 0xff;
|
||||
}
|
||||
|
||||
// On incoming packet to the tun interface.
|
||||
void WireguardProcessor::HandleTunPacket(Packet *packet) {
|
||||
STATIC_ASSERT(kPacketResult_ForwardUdp == 1 && kPacketResult_Free == 3, kPacketResult_wrong_values);
|
||||
PacketResult result = HandleTunPacket2(packet);
|
||||
if (result == kPacketResult_ForwardUdp) {
|
||||
udp_->WriteUdpPacket(packet);
|
||||
} else if (result == kPacketResult_Free) {
|
||||
FreePacket(packet);
|
||||
}
|
||||
}
|
||||
|
||||
void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) {
|
||||
PacketResult result = HandleUdpPacket2(packet, overload);
|
||||
if (result == kPacketResult_ForwardTun) {
|
||||
//stats_.tun_bytes_out += size_from_header;
|
||||
//stats_.tun_packets_out++;
|
||||
tun_->WriteTunPacket(packet);
|
||||
} else if (result == kPacketResult_ForwardUdp) {
|
||||
udp_->WriteUdpPacket(packet);
|
||||
} else if (result == kPacketResult_Free) {
|
||||
FreePacket(packet);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// On incoming packet to the tun interface.
|
||||
WireguardProcessor::PacketResult WireguardProcessor::HandleTunPacket2(Packet *packet) {
|
||||
uint8 *data = packet->data;
|
||||
size_t data_size = packet->size;
|
||||
unsigned ip_version, size_from_header;
|
||||
|
@ -385,16 +408,14 @@ void WireguardProcessor::HandleTunPacket(Packet *packet) {
|
|||
|
||||
// WriteAndEncryptPacketToUdp needs a held lock
|
||||
WG_ACQUIRE_LOCK(peer->mutex_);
|
||||
WriteAndEncryptPacketToUdp_WillUnlock(peer, packet);
|
||||
return;
|
||||
return WriteAndEncryptPacketToUdp_WillUnlock(peer, packet);
|
||||
|
||||
getout:
|
||||
// send ICMP?
|
||||
FreePacket(packet);
|
||||
return kPacketResult_Free;
|
||||
}
|
||||
|
||||
// This function must be called with the peer lock held. It will remove the lock
|
||||
void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet) {
|
||||
WireguardProcessor::PacketResult WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet) {
|
||||
assert(peer->IsPeerLocked());
|
||||
uint8 *data = packet->data, *ad;
|
||||
size_t size = packet->size, ad_len, orig_size = size;
|
||||
|
@ -417,7 +438,7 @@ void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Pac
|
|||
peer->AddPacketToPeerQueue_Locked(packet);
|
||||
WG_RELEASE_LOCK(peer->mutex_);
|
||||
peer->ScheduleNewHandshake();
|
||||
return;
|
||||
return kPacketResult_InUse;
|
||||
}
|
||||
assert(!peer->marked_for_delete_);
|
||||
|
||||
|
@ -538,15 +559,17 @@ need_big_packet:
|
|||
|
||||
WgKeypairEncryptPayload(data, size, ad, ad_len, send_ctr, keypair);
|
||||
|
||||
DoWriteUdpPacket(packet);
|
||||
if (want_handshake)
|
||||
peer->ScheduleNewHandshake();
|
||||
return;
|
||||
|
||||
stats_.udp_packets_out++;
|
||||
stats_.udp_bytes_out += packet->size;
|
||||
|
||||
return kPacketResult_ForwardUdp;
|
||||
|
||||
getout_discard:
|
||||
WG_RELEASE_LOCK(peer->mutex_);
|
||||
FreePacket(packet);
|
||||
return;
|
||||
return kPacketResult_Free;
|
||||
}
|
||||
|
||||
// This scrambles the initial 16 bytes of the packet with the
|
||||
|
@ -576,20 +599,15 @@ static void ScrambleUnscramblePacket(Packet *packet, ScramblerSiphashKeys *keys)
|
|||
}
|
||||
}
|
||||
|
||||
static NOINLINE void ScrambleUnscrambleAndWrite(Packet *packet, ScramblerSiphashKeys *keys, UdpInterface *udp) {
|
||||
#if WITH_HEADER_OBFUSCATION
|
||||
ScrambleUnscramblePacket(packet, keys);
|
||||
udp->WriteUdpPacket(packet);
|
||||
#endif // WITH_HEADER_OBFUSCATION
|
||||
}
|
||||
void WireguardProcessor::PrepareOutgoingHandshakePacket(WgPeer *peer, Packet *packet) {
|
||||
assert(dev_.IsMainThread());
|
||||
|
||||
void WireguardProcessor::DoWriteUdpPacket(Packet *packet) {
|
||||
stats_.udp_packets_out++;
|
||||
stats_.udp_bytes_out += packet->size;
|
||||
if (!dev_.header_obfuscation_)
|
||||
udp_->WriteUdpPacket(packet);
|
||||
else
|
||||
ScrambleUnscrambleAndWrite(packet, &dev_.header_obfuscation_key_, udp_);
|
||||
#if WITH_HEADER_OBFUSCATION
|
||||
if (dev_.header_obfuscation_)
|
||||
ScrambleUnscramblePacket(packet, &dev_.header_obfuscation_key_);
|
||||
#endif // WITH_HEADER_OBFUSCATION
|
||||
}
|
||||
|
||||
void WireguardProcessor::RunAllMainThreadScheduled() {
|
||||
|
@ -657,7 +675,8 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) {
|
|||
}
|
||||
|
||||
WG_RELEASE_LOCK(peer->mutex_);
|
||||
DoWriteUdpPacket(packet);
|
||||
PrepareOutgoingHandshakePacket(peer, packet);
|
||||
udp_->WriteUdpPacket(packet);
|
||||
if (attempts > 1 && attempts <= 20)
|
||||
RINFO("Retrying handshake, attempt %d...%s", attempts, (attempts == 20) ? " (last notice)" : "");
|
||||
}
|
||||
|
@ -669,7 +688,7 @@ bool WireguardProcessor::IsMainThreadPacket(Packet *packet) {
|
|||
}
|
||||
|
||||
// Handles an incoming WireGuard packet from the UDP side, decrypt etc.
|
||||
void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) {
|
||||
WireguardProcessor::PacketResult WireguardProcessor::HandleUdpPacket2(Packet *packet, bool overload) {
|
||||
uint32 type;
|
||||
assert(packet->protocol != 0xCD && (uint16)packet->addr.sin.sin_family != 0xCDCD); // catch msvc uninit mem
|
||||
|
||||
|
@ -688,40 +707,44 @@ void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) {
|
|||
if (type == MESSAGE_DATA) {
|
||||
if (packet->size < sizeof(MessageData))
|
||||
goto invalid_size;
|
||||
HandleDataPacket(packet);
|
||||
return HandleDataPacket(packet);
|
||||
#if WITH_SHORT_HEADERS
|
||||
} else if (type & WG_SHORT_HEADER_BIT) {
|
||||
HandleShortHeaderFormatPacket(type, packet);
|
||||
return HandleShortHeaderFormatPacket(type, packet);
|
||||
#endif // WITH_SHORT_HEADERS
|
||||
} else if (type == MESSAGE_HANDSHAKE_COOKIE) {
|
||||
assert(dev_.IsMainThread());
|
||||
if (packet->size != sizeof(MessageHandshakeCookie) || !dev_.is_private_key_initialized())
|
||||
goto invalid_size;
|
||||
HandleHandshakeCookiePacket(packet);
|
||||
return HandleHandshakeCookiePacket(packet);
|
||||
} else if (type == MESSAGE_HANDSHAKE_INITIATION) {
|
||||
assert(dev_.IsMainThread());
|
||||
if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeInitiation)) : (packet->size != sizeof(MessageHandshakeInitiation)) ||
|
||||
!dev_.is_private_key_initialized())
|
||||
goto invalid_size;
|
||||
stats_.handshakes_in++;
|
||||
if (CheckIncomingHandshakeRateLimit(packet, overload))
|
||||
HandleHandshakeInitiationPacket(packet);
|
||||
PacketResult result = CheckIncomingHandshakeRateLimit(packet, overload);
|
||||
if (result != kPacketResult_InUse)
|
||||
return result;
|
||||
return HandleHandshakeInitiationPacket(packet);
|
||||
} else if (type == MESSAGE_HANDSHAKE_RESPONSE) {
|
||||
assert(dev_.IsMainThread());
|
||||
if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeResponse)) : (packet->size != sizeof(MessageHandshakeResponse)) ||
|
||||
!dev_.is_private_key_initialized())
|
||||
goto invalid_size;
|
||||
if (CheckIncomingHandshakeRateLimit(packet, overload))
|
||||
HandleHandshakeResponsePacket(packet);
|
||||
PacketResult result = CheckIncomingHandshakeRateLimit(packet, overload);
|
||||
if (result != kPacketResult_InUse)
|
||||
return result;
|
||||
return HandleHandshakeResponsePacket(packet);
|
||||
} else {
|
||||
// unknown packet
|
||||
invalid_size:
|
||||
FreePacket(packet);
|
||||
return kPacketResult_Free;
|
||||
}
|
||||
}
|
||||
|
||||
#if WITH_SHORT_HEADERS
|
||||
void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packet) {
|
||||
WireguardProcessor::PacketResultd WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packet) {
|
||||
assert(dev_.IsMainOrDataThread());
|
||||
|
||||
uint8 *data = packet->data + 1;
|
||||
|
@ -829,13 +852,11 @@ void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packe
|
|||
|
||||
packet->data = data;
|
||||
packet->size = bytes_left - keypair->auth_tag_length;
|
||||
HandleAuthenticatedDataPacket_WillUnlock(keypair, packet);
|
||||
return;
|
||||
return HandleAuthenticatedDataPacket_WillUnlock(keypair, packet);
|
||||
getout_unlock:
|
||||
WG_RELEASE_LOCK(keypair->peer->mutex_);
|
||||
getout:
|
||||
FreePacket(packet);
|
||||
return;
|
||||
return kPacketResult_Free;
|
||||
}
|
||||
#endif // WITH_SHORT_HEADERS
|
||||
|
||||
|
@ -851,7 +872,7 @@ void WireguardProcessor::NotifyHandshakeComplete() {
|
|||
procdel_->OnConnected();
|
||||
}
|
||||
|
||||
void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet) {
|
||||
WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet) {
|
||||
WgPeer *peer = keypair->peer;
|
||||
assert(peer->IsPeerLocked());
|
||||
assert(packet->addr.sin.sin_family != 0);
|
||||
|
@ -939,20 +960,16 @@ void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *key
|
|||
goto getout_error_header;
|
||||
|
||||
packet->size = size_from_header;
|
||||
|
||||
stats_.tun_bytes_out += size_from_header;
|
||||
stats_.tun_packets_out++;
|
||||
|
||||
tun_->WriteTunPacket(packet);
|
||||
return;
|
||||
|
||||
return kPacketResult_ForwardTun;
|
||||
|
||||
getout_error_header:
|
||||
stats_.error_header++;
|
||||
getout:
|
||||
FreePacket(packet);
|
||||
return kPacketResult_Free;
|
||||
}
|
||||
|
||||
void WireguardProcessor::HandleDataPacket(Packet *packet) {
|
||||
WireguardProcessor::PacketResult WireguardProcessor::HandleDataPacket(Packet *packet) {
|
||||
assert(dev_.IsMainOrDataThread());
|
||||
|
||||
uint8 *data = packet->data;
|
||||
|
@ -963,8 +980,7 @@ void WireguardProcessor::HandleDataPacket(Packet *packet) {
|
|||
if (keypair == NULL || counter >= REJECT_AFTER_MESSAGES) {
|
||||
stats_.error_key_id++;
|
||||
getout:
|
||||
FreePacket(packet);
|
||||
return;
|
||||
return kPacketResult_Free;
|
||||
}
|
||||
|
||||
packet->data = data + sizeof(MessageData);
|
||||
|
@ -988,7 +1004,7 @@ getout:
|
|||
goto getout;
|
||||
} else {
|
||||
assert(!keypair->peer->marked_for_delete_);
|
||||
HandleAuthenticatedDataPacket_WillUnlock(keypair, packet);
|
||||
return HandleAuthenticatedDataPacket_WillUnlock(keypair, packet);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1000,36 +1016,38 @@ static uint64 GetIpForRateLimit(Packet *packet) {
|
|||
}
|
||||
}
|
||||
|
||||
bool WireguardProcessor::CheckIncomingHandshakeRateLimit(Packet *packet, bool overload) {
|
||||
WireguardProcessor::PacketResult WireguardProcessor::CheckIncomingHandshakeRateLimit(Packet *packet, bool overload) {
|
||||
assert(dev_.IsMainThread());
|
||||
WgRateLimit::RateLimitResult rr = dev_.rate_limiter()->CheckRateLimit(GetIpForRateLimit(packet));
|
||||
if ((overload && rr.is_rate_limited()) || !dev_.CheckCookieMac1(packet)) {
|
||||
FreePacket(packet);
|
||||
return false;
|
||||
}
|
||||
if ((overload && rr.is_rate_limited()) || !dev_.CheckCookieMac1(packet))
|
||||
return kPacketResult_Free;
|
||||
|
||||
dev_.rate_limiter()->CommitResult(rr);
|
||||
if (overload && !rr.is_first_ip() && !dev_.CheckCookieMac2(packet)) {
|
||||
dev_.CreateCookieMessage((MessageHandshakeCookie*)packet->data, packet, ((MessageHandshakeInitiation*)packet->data)->sender_key_id);
|
||||
packet->size = sizeof(MessageHandshakeCookie);
|
||||
DoWriteUdpPacket(packet);
|
||||
return false;
|
||||
PrepareOutgoingHandshakePacket(NULL, packet);
|
||||
return kPacketResult_ForwardUdp;
|
||||
}
|
||||
return true;
|
||||
|
||||
// This function returns InUse when everything went well
|
||||
return kPacketResult_InUse;
|
||||
}
|
||||
|
||||
// server receives this when client wants to setup a session
|
||||
void WireguardProcessor::HandleHandshakeInitiationPacket(Packet *packet) {
|
||||
WireguardProcessor::PacketResult WireguardProcessor::HandleHandshakeInitiationPacket(Packet *packet) {
|
||||
assert(dev_.IsMainThread());
|
||||
WgPeer *peer = WgPeer::ParseMessageHandshakeInitiation(&dev_, packet);
|
||||
if (peer) {
|
||||
DoWriteUdpPacket(packet);
|
||||
PrepareOutgoingHandshakePacket(peer, packet);
|
||||
return kPacketResult_ForwardUdp;
|
||||
} else {
|
||||
FreePacket(packet);
|
||||
return kPacketResult_Free;
|
||||
}
|
||||
}
|
||||
|
||||
// client receives this after session is established
|
||||
void WireguardProcessor::HandleHandshakeResponsePacket(Packet *packet) {
|
||||
WireguardProcessor::PacketResult WireguardProcessor::HandleHandshakeResponsePacket(Packet *packet) {
|
||||
assert(dev_.IsMainThread());
|
||||
WgPeer *peer = WgPeer::ParseMessageHandshakeResponse(&dev_, packet);
|
||||
if (peer) {
|
||||
|
@ -1040,7 +1058,7 @@ void WireguardProcessor::HandleHandshakeResponsePacket(Packet *packet) {
|
|||
NotifyHandshakeComplete();
|
||||
SendKeepalive_Locked(peer);
|
||||
}
|
||||
FreePacket(packet);
|
||||
return kPacketResult_Free;
|
||||
}
|
||||
|
||||
void WireguardProcessor::SendKeepalive_Locked(WgPeer *peer) {
|
||||
|
@ -1069,15 +1087,21 @@ void WireguardProcessor::SendQueuedPackets_Locked(WgPeer *peer) {
|
|||
peer->num_queued_packets_ = 0;
|
||||
while (packet != NULL) {
|
||||
Packet *next = Packet_NEXT(packet);
|
||||
WriteAndEncryptPacketToUdp_WillUnlock(peer, packet);
|
||||
PacketResult result = WriteAndEncryptPacketToUdp_WillUnlock(peer, packet);
|
||||
if (result == kPacketResult_ForwardUdp) {
|
||||
udp_->WriteUdpPacket(packet);
|
||||
} else if (result == kPacketResult_Free) {
|
||||
FreePacket(packet);
|
||||
}
|
||||
packet = next;
|
||||
WG_ACQUIRE_LOCK(peer->mutex_); // WriteAndEncryptPacketToUdp_WillUnlock releases the lock
|
||||
}
|
||||
}
|
||||
|
||||
void WireguardProcessor::HandleHandshakeCookiePacket(Packet *packet) {
|
||||
WireguardProcessor::PacketResult WireguardProcessor::HandleHandshakeCookiePacket(Packet *packet) {
|
||||
assert(dev_.IsMainThread());
|
||||
WgPeer::ParseMessageHandshakeCookie(&dev_, (MessageHandshakeCookie *)packet->data);
|
||||
return kPacketResult_Free;
|
||||
}
|
||||
|
||||
// Only one thread may run the second loop
|
||||
|
|
36
wireguard.h
36
wireguard.h
|
@ -95,6 +95,24 @@ public:
|
|||
|
||||
void HandleTunPacket(Packet *packet);
|
||||
void HandleUdpPacket(Packet *packet, bool overload);
|
||||
|
||||
// These are the same as above, but instead return the processed packet
|
||||
// instead of forwarding it to udp/tun.
|
||||
enum PacketResult {
|
||||
// The packet is now in use by the wireguard impl and must not be touched
|
||||
kPacketResult_InUse = 0,
|
||||
|
||||
// The wireguard impl has processed the packet and it's to be forwarded
|
||||
// to the udp / tun layer
|
||||
kPacketResult_ForwardUdp = 1,
|
||||
kPacketResult_ForwardTun = 2,
|
||||
|
||||
// The wireguard impl does not need the packet and it should be freed
|
||||
kPacketResult_Free = 3
|
||||
};
|
||||
PacketResult HandleTunPacket2(Packet *packet);
|
||||
PacketResult HandleUdpPacket2(Packet *packet, bool overload);
|
||||
|
||||
static bool IsMainThreadPacket(Packet *packet);
|
||||
|
||||
void SecondLoop();
|
||||
|
@ -115,20 +133,20 @@ public:
|
|||
void ForceSendHandshakeInitiation(WgPeer *peer);
|
||||
|
||||
private:
|
||||
void DoWriteUdpPacket(Packet *packet);
|
||||
void WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet);
|
||||
inline void PrepareOutgoingHandshakePacket(WgPeer *peer, Packet *packet);
|
||||
PacketResult WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet);
|
||||
void SendHandshakeInitiation(WgPeer *peer);
|
||||
void SendKeepalive_Locked(WgPeer *peer);
|
||||
void SendQueuedPackets_Locked(WgPeer *peer);
|
||||
|
||||
void HandleHandshakeInitiationPacket(Packet *packet);
|
||||
void HandleHandshakeResponsePacket(Packet *packet);
|
||||
void HandleHandshakeCookiePacket(Packet *packet);
|
||||
void HandleDataPacket(Packet *packet);
|
||||
PacketResult HandleHandshakeInitiationPacket(Packet *packet);
|
||||
PacketResult HandleHandshakeResponsePacket(Packet *packet);
|
||||
PacketResult HandleHandshakeCookiePacket(Packet *packet);
|
||||
PacketResult HandleDataPacket(Packet *packet);
|
||||
|
||||
void HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet);
|
||||
void HandleShortHeaderFormatPacket(uint32 tag, Packet *packet);
|
||||
bool CheckIncomingHandshakeRateLimit(Packet *packet, bool overload);
|
||||
PacketResult HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet);
|
||||
PacketResult HandleShortHeaderFormatPacket(uint32 tag, Packet *packet);
|
||||
PacketResult CheckIncomingHandshakeRateLimit(Packet *packet, bool overload);
|
||||
bool HandleIcmpv6NeighborSolicitation(const byte *data, size_t data_size);
|
||||
void NotifyHandshakeComplete();
|
||||
|
||||
|
|
Loading…
Reference in a new issue