Add another packet processing API that doesn't forward

This commit is contained in:
Ludvig Strigeus 2018-11-23 22:25:53 +01:00
parent 1e414a700e
commit d04afa1cdb
2 changed files with 116 additions and 74 deletions

View file

@ -152,7 +152,6 @@ bool WireguardProcessor::ConfigureTun() {
RINFO("TAP is not compatible CIDR /31 or /32. Changing to /24"); RINFO("TAP is not compatible CIDR /31 or /32. Changing to /24");
it->cidr = 24; it->cidr = 24;
} }
// Packets to this IP will not be sent out. // Packets to this IP will not be sent out.
if (ipv4_broadcast_addr == 0xffffffff) { if (ipv4_broadcast_addr == 0xffffffff) {
uint32 netmask = it->cidr == 32 ? 0xffffffff : 0xffffffff << (32 - it->cidr); uint32 netmask = it->cidr == 32 ? 0xffffffff : 0xffffffff << (32 - it->cidr);
@ -332,8 +331,32 @@ static inline bool IsIpv6Multicast(const uint8 dst[16]) {
return dst[0] == 0xff; return dst[0] == 0xff;
} }
// On incoming packet to the tun interface.
void WireguardProcessor::HandleTunPacket(Packet *packet) { void WireguardProcessor::HandleTunPacket(Packet *packet) {
STATIC_ASSERT(kPacketResult_ForwardUdp == 1 && kPacketResult_Free == 3, kPacketResult_wrong_values);
PacketResult result = HandleTunPacket2(packet);
if (result == kPacketResult_ForwardUdp) {
udp_->WriteUdpPacket(packet);
} else if (result == kPacketResult_Free) {
FreePacket(packet);
}
}
void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) {
PacketResult result = HandleUdpPacket2(packet, overload);
if (result == kPacketResult_ForwardTun) {
//stats_.tun_bytes_out += size_from_header;
//stats_.tun_packets_out++;
tun_->WriteTunPacket(packet);
} else if (result == kPacketResult_ForwardUdp) {
udp_->WriteUdpPacket(packet);
} else if (result == kPacketResult_Free) {
FreePacket(packet);
}
}
// On incoming packet to the tun interface.
WireguardProcessor::PacketResult WireguardProcessor::HandleTunPacket2(Packet *packet) {
uint8 *data = packet->data; uint8 *data = packet->data;
size_t data_size = packet->size; size_t data_size = packet->size;
unsigned ip_version, size_from_header; unsigned ip_version, size_from_header;
@ -385,16 +408,14 @@ void WireguardProcessor::HandleTunPacket(Packet *packet) {
// WriteAndEncryptPacketToUdp needs a held lock // WriteAndEncryptPacketToUdp needs a held lock
WG_ACQUIRE_LOCK(peer->mutex_); WG_ACQUIRE_LOCK(peer->mutex_);
WriteAndEncryptPacketToUdp_WillUnlock(peer, packet); return WriteAndEncryptPacketToUdp_WillUnlock(peer, packet);
return;
getout: getout:
// send ICMP? return kPacketResult_Free;
FreePacket(packet);
} }
// This function must be called with the peer lock held. It will remove the lock // This function must be called with the peer lock held. It will remove the lock
void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet) { WireguardProcessor::PacketResult WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet) {
assert(peer->IsPeerLocked()); assert(peer->IsPeerLocked());
uint8 *data = packet->data, *ad; uint8 *data = packet->data, *ad;
size_t size = packet->size, ad_len, orig_size = size; size_t size = packet->size, ad_len, orig_size = size;
@ -417,7 +438,7 @@ void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Pac
peer->AddPacketToPeerQueue_Locked(packet); peer->AddPacketToPeerQueue_Locked(packet);
WG_RELEASE_LOCK(peer->mutex_); WG_RELEASE_LOCK(peer->mutex_);
peer->ScheduleNewHandshake(); peer->ScheduleNewHandshake();
return; return kPacketResult_InUse;
} }
assert(!peer->marked_for_delete_); assert(!peer->marked_for_delete_);
@ -538,15 +559,17 @@ need_big_packet:
WgKeypairEncryptPayload(data, size, ad, ad_len, send_ctr, keypair); WgKeypairEncryptPayload(data, size, ad, ad_len, send_ctr, keypair);
DoWriteUdpPacket(packet);
if (want_handshake) if (want_handshake)
peer->ScheduleNewHandshake(); peer->ScheduleNewHandshake();
return;
stats_.udp_packets_out++;
stats_.udp_bytes_out += packet->size;
return kPacketResult_ForwardUdp;
getout_discard: getout_discard:
WG_RELEASE_LOCK(peer->mutex_); WG_RELEASE_LOCK(peer->mutex_);
FreePacket(packet); return kPacketResult_Free;
return;
} }
// This scrambles the initial 16 bytes of the packet with the // This scrambles the initial 16 bytes of the packet with the
@ -576,20 +599,15 @@ static void ScrambleUnscramblePacket(Packet *packet, ScramblerSiphashKeys *keys)
} }
} }
static NOINLINE void ScrambleUnscrambleAndWrite(Packet *packet, ScramblerSiphashKeys *keys, UdpInterface *udp) { void WireguardProcessor::PrepareOutgoingHandshakePacket(WgPeer *peer, Packet *packet) {
#if WITH_HEADER_OBFUSCATION assert(dev_.IsMainThread());
ScrambleUnscramblePacket(packet, keys);
udp->WriteUdpPacket(packet);
#endif // WITH_HEADER_OBFUSCATION
}
void WireguardProcessor::DoWriteUdpPacket(Packet *packet) {
stats_.udp_packets_out++; stats_.udp_packets_out++;
stats_.udp_bytes_out += packet->size; stats_.udp_bytes_out += packet->size;
if (!dev_.header_obfuscation_) #if WITH_HEADER_OBFUSCATION
udp_->WriteUdpPacket(packet); if (dev_.header_obfuscation_)
else ScrambleUnscramblePacket(packet, &dev_.header_obfuscation_key_);
ScrambleUnscrambleAndWrite(packet, &dev_.header_obfuscation_key_, udp_); #endif // WITH_HEADER_OBFUSCATION
} }
void WireguardProcessor::RunAllMainThreadScheduled() { void WireguardProcessor::RunAllMainThreadScheduled() {
@ -657,7 +675,8 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) {
} }
WG_RELEASE_LOCK(peer->mutex_); WG_RELEASE_LOCK(peer->mutex_);
DoWriteUdpPacket(packet); PrepareOutgoingHandshakePacket(peer, packet);
udp_->WriteUdpPacket(packet);
if (attempts > 1 && attempts <= 20) if (attempts > 1 && attempts <= 20)
RINFO("Retrying handshake, attempt %d...%s", attempts, (attempts == 20) ? " (last notice)" : ""); RINFO("Retrying handshake, attempt %d...%s", attempts, (attempts == 20) ? " (last notice)" : "");
} }
@ -669,7 +688,7 @@ bool WireguardProcessor::IsMainThreadPacket(Packet *packet) {
} }
// Handles an incoming WireGuard packet from the UDP side, decrypt etc. // Handles an incoming WireGuard packet from the UDP side, decrypt etc.
void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) { WireguardProcessor::PacketResult WireguardProcessor::HandleUdpPacket2(Packet *packet, bool overload) {
uint32 type; uint32 type;
assert(packet->protocol != 0xCD && (uint16)packet->addr.sin.sin_family != 0xCDCD); // catch msvc uninit mem assert(packet->protocol != 0xCD && (uint16)packet->addr.sin.sin_family != 0xCDCD); // catch msvc uninit mem
@ -688,40 +707,44 @@ void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) {
if (type == MESSAGE_DATA) { if (type == MESSAGE_DATA) {
if (packet->size < sizeof(MessageData)) if (packet->size < sizeof(MessageData))
goto invalid_size; goto invalid_size;
HandleDataPacket(packet); return HandleDataPacket(packet);
#if WITH_SHORT_HEADERS #if WITH_SHORT_HEADERS
} else if (type & WG_SHORT_HEADER_BIT) { } else if (type & WG_SHORT_HEADER_BIT) {
HandleShortHeaderFormatPacket(type, packet); return HandleShortHeaderFormatPacket(type, packet);
#endif // WITH_SHORT_HEADERS #endif // WITH_SHORT_HEADERS
} else if (type == MESSAGE_HANDSHAKE_COOKIE) { } else if (type == MESSAGE_HANDSHAKE_COOKIE) {
assert(dev_.IsMainThread()); assert(dev_.IsMainThread());
if (packet->size != sizeof(MessageHandshakeCookie) || !dev_.is_private_key_initialized()) if (packet->size != sizeof(MessageHandshakeCookie) || !dev_.is_private_key_initialized())
goto invalid_size; goto invalid_size;
HandleHandshakeCookiePacket(packet); return HandleHandshakeCookiePacket(packet);
} else if (type == MESSAGE_HANDSHAKE_INITIATION) { } else if (type == MESSAGE_HANDSHAKE_INITIATION) {
assert(dev_.IsMainThread()); assert(dev_.IsMainThread());
if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeInitiation)) : (packet->size != sizeof(MessageHandshakeInitiation)) || if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeInitiation)) : (packet->size != sizeof(MessageHandshakeInitiation)) ||
!dev_.is_private_key_initialized()) !dev_.is_private_key_initialized())
goto invalid_size; goto invalid_size;
stats_.handshakes_in++; stats_.handshakes_in++;
if (CheckIncomingHandshakeRateLimit(packet, overload)) PacketResult result = CheckIncomingHandshakeRateLimit(packet, overload);
HandleHandshakeInitiationPacket(packet); if (result != kPacketResult_InUse)
return result;
return HandleHandshakeInitiationPacket(packet);
} else if (type == MESSAGE_HANDSHAKE_RESPONSE) { } else if (type == MESSAGE_HANDSHAKE_RESPONSE) {
assert(dev_.IsMainThread()); assert(dev_.IsMainThread());
if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeResponse)) : (packet->size != sizeof(MessageHandshakeResponse)) || if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeResponse)) : (packet->size != sizeof(MessageHandshakeResponse)) ||
!dev_.is_private_key_initialized()) !dev_.is_private_key_initialized())
goto invalid_size; goto invalid_size;
if (CheckIncomingHandshakeRateLimit(packet, overload)) PacketResult result = CheckIncomingHandshakeRateLimit(packet, overload);
HandleHandshakeResponsePacket(packet); if (result != kPacketResult_InUse)
return result;
return HandleHandshakeResponsePacket(packet);
} else { } else {
// unknown packet // unknown packet
invalid_size: invalid_size:
FreePacket(packet); return kPacketResult_Free;
} }
} }
#if WITH_SHORT_HEADERS #if WITH_SHORT_HEADERS
void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packet) { WireguardProcessor::PacketResultd WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packet) {
assert(dev_.IsMainOrDataThread()); assert(dev_.IsMainOrDataThread());
uint8 *data = packet->data + 1; uint8 *data = packet->data + 1;
@ -829,13 +852,11 @@ void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packe
packet->data = data; packet->data = data;
packet->size = bytes_left - keypair->auth_tag_length; packet->size = bytes_left - keypair->auth_tag_length;
HandleAuthenticatedDataPacket_WillUnlock(keypair, packet); return HandleAuthenticatedDataPacket_WillUnlock(keypair, packet);
return;
getout_unlock: getout_unlock:
WG_RELEASE_LOCK(keypair->peer->mutex_); WG_RELEASE_LOCK(keypair->peer->mutex_);
getout: getout:
FreePacket(packet); return kPacketResult_Free;
return;
} }
#endif // WITH_SHORT_HEADERS #endif // WITH_SHORT_HEADERS
@ -851,7 +872,7 @@ void WireguardProcessor::NotifyHandshakeComplete() {
procdel_->OnConnected(); procdel_->OnConnected();
} }
void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet) { WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet) {
WgPeer *peer = keypair->peer; WgPeer *peer = keypair->peer;
assert(peer->IsPeerLocked()); assert(peer->IsPeerLocked());
assert(packet->addr.sin.sin_family != 0); assert(packet->addr.sin.sin_family != 0);
@ -940,19 +961,15 @@ void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *key
packet->size = size_from_header; packet->size = size_from_header;
stats_.tun_bytes_out += size_from_header; return kPacketResult_ForwardTun;
stats_.tun_packets_out++;
tun_->WriteTunPacket(packet);
return;
getout_error_header: getout_error_header:
stats_.error_header++; stats_.error_header++;
getout: getout:
FreePacket(packet); return kPacketResult_Free;
} }
void WireguardProcessor::HandleDataPacket(Packet *packet) { WireguardProcessor::PacketResult WireguardProcessor::HandleDataPacket(Packet *packet) {
assert(dev_.IsMainOrDataThread()); assert(dev_.IsMainOrDataThread());
uint8 *data = packet->data; uint8 *data = packet->data;
@ -963,8 +980,7 @@ void WireguardProcessor::HandleDataPacket(Packet *packet) {
if (keypair == NULL || counter >= REJECT_AFTER_MESSAGES) { if (keypair == NULL || counter >= REJECT_AFTER_MESSAGES) {
stats_.error_key_id++; stats_.error_key_id++;
getout: getout:
FreePacket(packet); return kPacketResult_Free;
return;
} }
packet->data = data + sizeof(MessageData); packet->data = data + sizeof(MessageData);
@ -988,7 +1004,7 @@ getout:
goto getout; goto getout;
} else { } else {
assert(!keypair->peer->marked_for_delete_); assert(!keypair->peer->marked_for_delete_);
HandleAuthenticatedDataPacket_WillUnlock(keypair, packet); return HandleAuthenticatedDataPacket_WillUnlock(keypair, packet);
} }
} }
@ -1000,36 +1016,38 @@ static uint64 GetIpForRateLimit(Packet *packet) {
} }
} }
bool WireguardProcessor::CheckIncomingHandshakeRateLimit(Packet *packet, bool overload) { WireguardProcessor::PacketResult WireguardProcessor::CheckIncomingHandshakeRateLimit(Packet *packet, bool overload) {
assert(dev_.IsMainThread()); assert(dev_.IsMainThread());
WgRateLimit::RateLimitResult rr = dev_.rate_limiter()->CheckRateLimit(GetIpForRateLimit(packet)); WgRateLimit::RateLimitResult rr = dev_.rate_limiter()->CheckRateLimit(GetIpForRateLimit(packet));
if ((overload && rr.is_rate_limited()) || !dev_.CheckCookieMac1(packet)) { if ((overload && rr.is_rate_limited()) || !dev_.CheckCookieMac1(packet))
FreePacket(packet); return kPacketResult_Free;
return false;
}
dev_.rate_limiter()->CommitResult(rr); dev_.rate_limiter()->CommitResult(rr);
if (overload && !rr.is_first_ip() && !dev_.CheckCookieMac2(packet)) { if (overload && !rr.is_first_ip() && !dev_.CheckCookieMac2(packet)) {
dev_.CreateCookieMessage((MessageHandshakeCookie*)packet->data, packet, ((MessageHandshakeInitiation*)packet->data)->sender_key_id); dev_.CreateCookieMessage((MessageHandshakeCookie*)packet->data, packet, ((MessageHandshakeInitiation*)packet->data)->sender_key_id);
packet->size = sizeof(MessageHandshakeCookie); packet->size = sizeof(MessageHandshakeCookie);
DoWriteUdpPacket(packet); PrepareOutgoingHandshakePacket(NULL, packet);
return false; return kPacketResult_ForwardUdp;
} }
return true;
// This function returns InUse when everything went well
return kPacketResult_InUse;
} }
// server receives this when client wants to setup a session // server receives this when client wants to setup a session
void WireguardProcessor::HandleHandshakeInitiationPacket(Packet *packet) { WireguardProcessor::PacketResult WireguardProcessor::HandleHandshakeInitiationPacket(Packet *packet) {
assert(dev_.IsMainThread()); assert(dev_.IsMainThread());
WgPeer *peer = WgPeer::ParseMessageHandshakeInitiation(&dev_, packet); WgPeer *peer = WgPeer::ParseMessageHandshakeInitiation(&dev_, packet);
if (peer) { if (peer) {
DoWriteUdpPacket(packet); PrepareOutgoingHandshakePacket(peer, packet);
return kPacketResult_ForwardUdp;
} else { } else {
FreePacket(packet); return kPacketResult_Free;
} }
} }
// client receives this after session is established // client receives this after session is established
void WireguardProcessor::HandleHandshakeResponsePacket(Packet *packet) { WireguardProcessor::PacketResult WireguardProcessor::HandleHandshakeResponsePacket(Packet *packet) {
assert(dev_.IsMainThread()); assert(dev_.IsMainThread());
WgPeer *peer = WgPeer::ParseMessageHandshakeResponse(&dev_, packet); WgPeer *peer = WgPeer::ParseMessageHandshakeResponse(&dev_, packet);
if (peer) { if (peer) {
@ -1040,7 +1058,7 @@ void WireguardProcessor::HandleHandshakeResponsePacket(Packet *packet) {
NotifyHandshakeComplete(); NotifyHandshakeComplete();
SendKeepalive_Locked(peer); SendKeepalive_Locked(peer);
} }
FreePacket(packet); return kPacketResult_Free;
} }
void WireguardProcessor::SendKeepalive_Locked(WgPeer *peer) { void WireguardProcessor::SendKeepalive_Locked(WgPeer *peer) {
@ -1069,15 +1087,21 @@ void WireguardProcessor::SendQueuedPackets_Locked(WgPeer *peer) {
peer->num_queued_packets_ = 0; peer->num_queued_packets_ = 0;
while (packet != NULL) { while (packet != NULL) {
Packet *next = Packet_NEXT(packet); Packet *next = Packet_NEXT(packet);
WriteAndEncryptPacketToUdp_WillUnlock(peer, packet); PacketResult result = WriteAndEncryptPacketToUdp_WillUnlock(peer, packet);
if (result == kPacketResult_ForwardUdp) {
udp_->WriteUdpPacket(packet);
} else if (result == kPacketResult_Free) {
FreePacket(packet);
}
packet = next; packet = next;
WG_ACQUIRE_LOCK(peer->mutex_); // WriteAndEncryptPacketToUdp_WillUnlock releases the lock WG_ACQUIRE_LOCK(peer->mutex_); // WriteAndEncryptPacketToUdp_WillUnlock releases the lock
} }
} }
void WireguardProcessor::HandleHandshakeCookiePacket(Packet *packet) { WireguardProcessor::PacketResult WireguardProcessor::HandleHandshakeCookiePacket(Packet *packet) {
assert(dev_.IsMainThread()); assert(dev_.IsMainThread());
WgPeer::ParseMessageHandshakeCookie(&dev_, (MessageHandshakeCookie *)packet->data); WgPeer::ParseMessageHandshakeCookie(&dev_, (MessageHandshakeCookie *)packet->data);
return kPacketResult_Free;
} }
// Only one thread may run the second loop // Only one thread may run the second loop

View file

@ -95,6 +95,24 @@ public:
void HandleTunPacket(Packet *packet); void HandleTunPacket(Packet *packet);
void HandleUdpPacket(Packet *packet, bool overload); void HandleUdpPacket(Packet *packet, bool overload);
// These are the same as above, but instead return the processed packet
// instead of forwarding it to udp/tun.
enum PacketResult {
// The packet is now in use by the wireguard impl and must not be touched
kPacketResult_InUse = 0,
// The wireguard impl has processed the packet and it's to be forwarded
// to the udp / tun layer
kPacketResult_ForwardUdp = 1,
kPacketResult_ForwardTun = 2,
// The wireguard impl does not need the packet and it should be freed
kPacketResult_Free = 3
};
PacketResult HandleTunPacket2(Packet *packet);
PacketResult HandleUdpPacket2(Packet *packet, bool overload);
static bool IsMainThreadPacket(Packet *packet); static bool IsMainThreadPacket(Packet *packet);
void SecondLoop(); void SecondLoop();
@ -115,20 +133,20 @@ public:
void ForceSendHandshakeInitiation(WgPeer *peer); void ForceSendHandshakeInitiation(WgPeer *peer);
private: private:
void DoWriteUdpPacket(Packet *packet); inline void PrepareOutgoingHandshakePacket(WgPeer *peer, Packet *packet);
void WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet); PacketResult WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet);
void SendHandshakeInitiation(WgPeer *peer); void SendHandshakeInitiation(WgPeer *peer);
void SendKeepalive_Locked(WgPeer *peer); void SendKeepalive_Locked(WgPeer *peer);
void SendQueuedPackets_Locked(WgPeer *peer); void SendQueuedPackets_Locked(WgPeer *peer);
void HandleHandshakeInitiationPacket(Packet *packet); PacketResult HandleHandshakeInitiationPacket(Packet *packet);
void HandleHandshakeResponsePacket(Packet *packet); PacketResult HandleHandshakeResponsePacket(Packet *packet);
void HandleHandshakeCookiePacket(Packet *packet); PacketResult HandleHandshakeCookiePacket(Packet *packet);
void HandleDataPacket(Packet *packet); PacketResult HandleDataPacket(Packet *packet);
void HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet); PacketResult HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet);
void HandleShortHeaderFormatPacket(uint32 tag, Packet *packet); PacketResult HandleShortHeaderFormatPacket(uint32 tag, Packet *packet);
bool CheckIncomingHandshakeRateLimit(Packet *packet, bool overload); PacketResult CheckIncomingHandshakeRateLimit(Packet *packet, bool overload);
bool HandleIcmpv6NeighborSolicitation(const byte *data, size_t data_size); bool HandleIcmpv6NeighborSolicitation(const byte *data, size_t data_size);
void NotifyHandshakeComplete(); void NotifyHandshakeComplete();