Add another packet processing API that doesn't forward

This commit is contained in:
Ludvig Strigeus 2018-11-23 22:25:53 +01:00
parent 1e414a700e
commit d04afa1cdb
2 changed files with 116 additions and 74 deletions

View file

@ -152,7 +152,6 @@ bool WireguardProcessor::ConfigureTun() {
RINFO("TAP is not compatible CIDR /31 or /32. Changing to /24");
it->cidr = 24;
}
// Packets to this IP will not be sent out.
if (ipv4_broadcast_addr == 0xffffffff) {
uint32 netmask = it->cidr == 32 ? 0xffffffff : 0xffffffff << (32 - it->cidr);
@ -332,8 +331,32 @@ static inline bool IsIpv6Multicast(const uint8 dst[16]) {
return dst[0] == 0xff;
}
// On incoming packet to the tun interface.
void WireguardProcessor::HandleTunPacket(Packet *packet) {
STATIC_ASSERT(kPacketResult_ForwardUdp == 1 && kPacketResult_Free == 3, kPacketResult_wrong_values);
PacketResult result = HandleTunPacket2(packet);
if (result == kPacketResult_ForwardUdp) {
udp_->WriteUdpPacket(packet);
} else if (result == kPacketResult_Free) {
FreePacket(packet);
}
}
void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) {
PacketResult result = HandleUdpPacket2(packet, overload);
if (result == kPacketResult_ForwardTun) {
//stats_.tun_bytes_out += size_from_header;
//stats_.tun_packets_out++;
tun_->WriteTunPacket(packet);
} else if (result == kPacketResult_ForwardUdp) {
udp_->WriteUdpPacket(packet);
} else if (result == kPacketResult_Free) {
FreePacket(packet);
}
}
// On incoming packet to the tun interface.
WireguardProcessor::PacketResult WireguardProcessor::HandleTunPacket2(Packet *packet) {
uint8 *data = packet->data;
size_t data_size = packet->size;
unsigned ip_version, size_from_header;
@ -385,16 +408,14 @@ void WireguardProcessor::HandleTunPacket(Packet *packet) {
// WriteAndEncryptPacketToUdp needs a held lock
WG_ACQUIRE_LOCK(peer->mutex_);
WriteAndEncryptPacketToUdp_WillUnlock(peer, packet);
return;
return WriteAndEncryptPacketToUdp_WillUnlock(peer, packet);
getout:
// send ICMP?
FreePacket(packet);
return kPacketResult_Free;
}
// This function must be called with the peer lock held. It will remove the lock
void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet) {
WireguardProcessor::PacketResult WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet) {
assert(peer->IsPeerLocked());
uint8 *data = packet->data, *ad;
size_t size = packet->size, ad_len, orig_size = size;
@ -417,7 +438,7 @@ void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Pac
peer->AddPacketToPeerQueue_Locked(packet);
WG_RELEASE_LOCK(peer->mutex_);
peer->ScheduleNewHandshake();
return;
return kPacketResult_InUse;
}
assert(!peer->marked_for_delete_);
@ -538,15 +559,17 @@ need_big_packet:
WgKeypairEncryptPayload(data, size, ad, ad_len, send_ctr, keypair);
DoWriteUdpPacket(packet);
if (want_handshake)
peer->ScheduleNewHandshake();
return;
stats_.udp_packets_out++;
stats_.udp_bytes_out += packet->size;
return kPacketResult_ForwardUdp;
getout_discard:
WG_RELEASE_LOCK(peer->mutex_);
FreePacket(packet);
return;
return kPacketResult_Free;
}
// This scrambles the initial 16 bytes of the packet with the
@ -576,20 +599,15 @@ static void ScrambleUnscramblePacket(Packet *packet, ScramblerSiphashKeys *keys)
}
}
static NOINLINE void ScrambleUnscrambleAndWrite(Packet *packet, ScramblerSiphashKeys *keys, UdpInterface *udp) {
#if WITH_HEADER_OBFUSCATION
ScrambleUnscramblePacket(packet, keys);
udp->WriteUdpPacket(packet);
#endif // WITH_HEADER_OBFUSCATION
}
void WireguardProcessor::PrepareOutgoingHandshakePacket(WgPeer *peer, Packet *packet) {
assert(dev_.IsMainThread());
void WireguardProcessor::DoWriteUdpPacket(Packet *packet) {
stats_.udp_packets_out++;
stats_.udp_bytes_out += packet->size;
if (!dev_.header_obfuscation_)
udp_->WriteUdpPacket(packet);
else
ScrambleUnscrambleAndWrite(packet, &dev_.header_obfuscation_key_, udp_);
#if WITH_HEADER_OBFUSCATION
if (dev_.header_obfuscation_)
ScrambleUnscramblePacket(packet, &dev_.header_obfuscation_key_);
#endif // WITH_HEADER_OBFUSCATION
}
void WireguardProcessor::RunAllMainThreadScheduled() {
@ -657,7 +675,8 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) {
}
WG_RELEASE_LOCK(peer->mutex_);
DoWriteUdpPacket(packet);
PrepareOutgoingHandshakePacket(peer, packet);
udp_->WriteUdpPacket(packet);
if (attempts > 1 && attempts <= 20)
RINFO("Retrying handshake, attempt %d...%s", attempts, (attempts == 20) ? " (last notice)" : "");
}
@ -669,7 +688,7 @@ bool WireguardProcessor::IsMainThreadPacket(Packet *packet) {
}
// Handles an incoming WireGuard packet from the UDP side, decrypt etc.
void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) {
WireguardProcessor::PacketResult WireguardProcessor::HandleUdpPacket2(Packet *packet, bool overload) {
uint32 type;
assert(packet->protocol != 0xCD && (uint16)packet->addr.sin.sin_family != 0xCDCD); // catch msvc uninit mem
@ -688,40 +707,44 @@ void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) {
if (type == MESSAGE_DATA) {
if (packet->size < sizeof(MessageData))
goto invalid_size;
HandleDataPacket(packet);
return HandleDataPacket(packet);
#if WITH_SHORT_HEADERS
} else if (type & WG_SHORT_HEADER_BIT) {
HandleShortHeaderFormatPacket(type, packet);
return HandleShortHeaderFormatPacket(type, packet);
#endif // WITH_SHORT_HEADERS
} else if (type == MESSAGE_HANDSHAKE_COOKIE) {
assert(dev_.IsMainThread());
if (packet->size != sizeof(MessageHandshakeCookie) || !dev_.is_private_key_initialized())
goto invalid_size;
HandleHandshakeCookiePacket(packet);
return HandleHandshakeCookiePacket(packet);
} else if (type == MESSAGE_HANDSHAKE_INITIATION) {
assert(dev_.IsMainThread());
if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeInitiation)) : (packet->size != sizeof(MessageHandshakeInitiation)) ||
!dev_.is_private_key_initialized())
goto invalid_size;
stats_.handshakes_in++;
if (CheckIncomingHandshakeRateLimit(packet, overload))
HandleHandshakeInitiationPacket(packet);
PacketResult result = CheckIncomingHandshakeRateLimit(packet, overload);
if (result != kPacketResult_InUse)
return result;
return HandleHandshakeInitiationPacket(packet);
} else if (type == MESSAGE_HANDSHAKE_RESPONSE) {
assert(dev_.IsMainThread());
if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeResponse)) : (packet->size != sizeof(MessageHandshakeResponse)) ||
!dev_.is_private_key_initialized())
goto invalid_size;
if (CheckIncomingHandshakeRateLimit(packet, overload))
HandleHandshakeResponsePacket(packet);
PacketResult result = CheckIncomingHandshakeRateLimit(packet, overload);
if (result != kPacketResult_InUse)
return result;
return HandleHandshakeResponsePacket(packet);
} else {
// unknown packet
invalid_size:
FreePacket(packet);
return kPacketResult_Free;
}
}
#if WITH_SHORT_HEADERS
void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packet) {
WireguardProcessor::PacketResultd WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packet) {
assert(dev_.IsMainOrDataThread());
uint8 *data = packet->data + 1;
@ -829,13 +852,11 @@ void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packe
packet->data = data;
packet->size = bytes_left - keypair->auth_tag_length;
HandleAuthenticatedDataPacket_WillUnlock(keypair, packet);
return;
return HandleAuthenticatedDataPacket_WillUnlock(keypair, packet);
getout_unlock:
WG_RELEASE_LOCK(keypair->peer->mutex_);
getout:
FreePacket(packet);
return;
return kPacketResult_Free;
}
#endif // WITH_SHORT_HEADERS
@ -851,7 +872,7 @@ void WireguardProcessor::NotifyHandshakeComplete() {
procdel_->OnConnected();
}
void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet) {
WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet) {
WgPeer *peer = keypair->peer;
assert(peer->IsPeerLocked());
assert(packet->addr.sin.sin_family != 0);
@ -939,20 +960,16 @@ void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *key
goto getout_error_header;
packet->size = size_from_header;
stats_.tun_bytes_out += size_from_header;
stats_.tun_packets_out++;
tun_->WriteTunPacket(packet);
return;
return kPacketResult_ForwardTun;
getout_error_header:
stats_.error_header++;
getout:
FreePacket(packet);
return kPacketResult_Free;
}
void WireguardProcessor::HandleDataPacket(Packet *packet) {
WireguardProcessor::PacketResult WireguardProcessor::HandleDataPacket(Packet *packet) {
assert(dev_.IsMainOrDataThread());
uint8 *data = packet->data;
@ -963,8 +980,7 @@ void WireguardProcessor::HandleDataPacket(Packet *packet) {
if (keypair == NULL || counter >= REJECT_AFTER_MESSAGES) {
stats_.error_key_id++;
getout:
FreePacket(packet);
return;
return kPacketResult_Free;
}
packet->data = data + sizeof(MessageData);
@ -988,7 +1004,7 @@ getout:
goto getout;
} else {
assert(!keypair->peer->marked_for_delete_);
HandleAuthenticatedDataPacket_WillUnlock(keypair, packet);
return HandleAuthenticatedDataPacket_WillUnlock(keypair, packet);
}
}
@ -1000,36 +1016,38 @@ static uint64 GetIpForRateLimit(Packet *packet) {
}
}
bool WireguardProcessor::CheckIncomingHandshakeRateLimit(Packet *packet, bool overload) {
WireguardProcessor::PacketResult WireguardProcessor::CheckIncomingHandshakeRateLimit(Packet *packet, bool overload) {
assert(dev_.IsMainThread());
WgRateLimit::RateLimitResult rr = dev_.rate_limiter()->CheckRateLimit(GetIpForRateLimit(packet));
if ((overload && rr.is_rate_limited()) || !dev_.CheckCookieMac1(packet)) {
FreePacket(packet);
return false;
}
if ((overload && rr.is_rate_limited()) || !dev_.CheckCookieMac1(packet))
return kPacketResult_Free;
dev_.rate_limiter()->CommitResult(rr);
if (overload && !rr.is_first_ip() && !dev_.CheckCookieMac2(packet)) {
dev_.CreateCookieMessage((MessageHandshakeCookie*)packet->data, packet, ((MessageHandshakeInitiation*)packet->data)->sender_key_id);
packet->size = sizeof(MessageHandshakeCookie);
DoWriteUdpPacket(packet);
return false;
PrepareOutgoingHandshakePacket(NULL, packet);
return kPacketResult_ForwardUdp;
}
return true;
// This function returns InUse when everything went well
return kPacketResult_InUse;
}
// server receives this when client wants to setup a session
void WireguardProcessor::HandleHandshakeInitiationPacket(Packet *packet) {
WireguardProcessor::PacketResult WireguardProcessor::HandleHandshakeInitiationPacket(Packet *packet) {
assert(dev_.IsMainThread());
WgPeer *peer = WgPeer::ParseMessageHandshakeInitiation(&dev_, packet);
if (peer) {
DoWriteUdpPacket(packet);
PrepareOutgoingHandshakePacket(peer, packet);
return kPacketResult_ForwardUdp;
} else {
FreePacket(packet);
return kPacketResult_Free;
}
}
// client receives this after session is established
void WireguardProcessor::HandleHandshakeResponsePacket(Packet *packet) {
WireguardProcessor::PacketResult WireguardProcessor::HandleHandshakeResponsePacket(Packet *packet) {
assert(dev_.IsMainThread());
WgPeer *peer = WgPeer::ParseMessageHandshakeResponse(&dev_, packet);
if (peer) {
@ -1040,7 +1058,7 @@ void WireguardProcessor::HandleHandshakeResponsePacket(Packet *packet) {
NotifyHandshakeComplete();
SendKeepalive_Locked(peer);
}
FreePacket(packet);
return kPacketResult_Free;
}
void WireguardProcessor::SendKeepalive_Locked(WgPeer *peer) {
@ -1069,15 +1087,21 @@ void WireguardProcessor::SendQueuedPackets_Locked(WgPeer *peer) {
peer->num_queued_packets_ = 0;
while (packet != NULL) {
Packet *next = Packet_NEXT(packet);
WriteAndEncryptPacketToUdp_WillUnlock(peer, packet);
PacketResult result = WriteAndEncryptPacketToUdp_WillUnlock(peer, packet);
if (result == kPacketResult_ForwardUdp) {
udp_->WriteUdpPacket(packet);
} else if (result == kPacketResult_Free) {
FreePacket(packet);
}
packet = next;
WG_ACQUIRE_LOCK(peer->mutex_); // WriteAndEncryptPacketToUdp_WillUnlock releases the lock
}
}
void WireguardProcessor::HandleHandshakeCookiePacket(Packet *packet) {
WireguardProcessor::PacketResult WireguardProcessor::HandleHandshakeCookiePacket(Packet *packet) {
assert(dev_.IsMainThread());
WgPeer::ParseMessageHandshakeCookie(&dev_, (MessageHandshakeCookie *)packet->data);
return kPacketResult_Free;
}
// Only one thread may run the second loop

View file

@ -95,6 +95,24 @@ public:
void HandleTunPacket(Packet *packet);
void HandleUdpPacket(Packet *packet, bool overload);
// These are the same as above, but instead return the processed packet
// instead of forwarding it to udp/tun.
enum PacketResult {
// The packet is now in use by the wireguard impl and must not be touched
kPacketResult_InUse = 0,
// The wireguard impl has processed the packet and it's to be forwarded
// to the udp / tun layer
kPacketResult_ForwardUdp = 1,
kPacketResult_ForwardTun = 2,
// The wireguard impl does not need the packet and it should be freed
kPacketResult_Free = 3
};
PacketResult HandleTunPacket2(Packet *packet);
PacketResult HandleUdpPacket2(Packet *packet, bool overload);
static bool IsMainThreadPacket(Packet *packet);
void SecondLoop();
@ -115,20 +133,20 @@ public:
void ForceSendHandshakeInitiation(WgPeer *peer);
private:
void DoWriteUdpPacket(Packet *packet);
void WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet);
inline void PrepareOutgoingHandshakePacket(WgPeer *peer, Packet *packet);
PacketResult WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet);
void SendHandshakeInitiation(WgPeer *peer);
void SendKeepalive_Locked(WgPeer *peer);
void SendQueuedPackets_Locked(WgPeer *peer);
void HandleHandshakeInitiationPacket(Packet *packet);
void HandleHandshakeResponsePacket(Packet *packet);
void HandleHandshakeCookiePacket(Packet *packet);
void HandleDataPacket(Packet *packet);
PacketResult HandleHandshakeInitiationPacket(Packet *packet);
PacketResult HandleHandshakeResponsePacket(Packet *packet);
PacketResult HandleHandshakeCookiePacket(Packet *packet);
PacketResult HandleDataPacket(Packet *packet);
void HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet);
void HandleShortHeaderFormatPacket(uint32 tag, Packet *packet);
bool CheckIncomingHandshakeRateLimit(Packet *packet, bool overload);
PacketResult HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet);
PacketResult HandleShortHeaderFormatPacket(uint32 tag, Packet *packet);
PacketResult CheckIncomingHandshakeRateLimit(Packet *packet, bool overload);
bool HandleIcmpv6NeighborSolicitation(const byte *data, size_t data_size);
void NotifyHandshakeComplete();