From 3228a73d009685d4d85a80cfc4c8aa99952f02d5 Mon Sep 17 00:00:00 2001 From: Ludvig Strigeus Date: Tue, 30 Oct 2018 00:31:33 +0100 Subject: [PATCH] Split up UDP processing code into more functions --- netapi.h | 31 ++- network_win32.cpp | 521 ++++++++++++++++++++++++-------------------- network_win32.h | 85 +++++++- wireguard.cpp | 4 +- wireguard_proto.cpp | 8 +- 5 files changed, 386 insertions(+), 263 deletions(-) diff --git a/netapi.h b/netapi.h index c4a1903..43178e0 100644 --- a/netapi.h +++ b/netapi.h @@ -17,26 +17,38 @@ #pragma warning (disable: 4200) -struct Packet { +struct QueuedItem; + +struct QueuedItemCallback { + virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) = 0; + virtual void OnQueuedItemDelete(QueuedItem *ow) = 0; +}; + +struct QueuedItem { union { - Packet *next; #if defined(OS_WIN) + // NOTE: This must be at offset 0 for SLIST to work SLIST_ENTRY list_entry; + OVERLAPPED overlapped; #endif + QueuedItem *queue_next; }; - unsigned int post_target, size; - byte *data; + QueuedItemCallback *queue_cb; +}; -#if defined(OS_WIN) - OVERLAPPED overlapped; // For Windows overlapped IO -#endif +#define Packet_NEXT(p) (*(Packet**)&(p)->queue_next) - IpAddr addr; // Optionally set to target/source of the packet +struct Packet : QueuedItem { int sin_size; + unsigned int size; + byte *data; + IpAddr addr; // Optionally set to target/source of the packet + + uint8 post_target; + uint8 userdata; byte data_pre[4]; byte data_buf[0]; - enum { // there's always this much data before data_ptr HEADROOM_BEFORE = 64, @@ -68,7 +80,6 @@ public: bool block_dns_on_adapters; - // Set mtu int mtu; diff --git a/network_win32.cpp b/network_win32.cpp index 7d701e3..3df6146 100644 --- a/network_win32.cpp +++ b/network_win32.cpp @@ -78,7 +78,7 @@ void FreeAllPackets() { Packet *p; p = (Packet*)InterlockedFlushSList(&freelist_head); while (Packet *r = p) { - p = p->next; + p = Packet_NEXT(p); _aligned_free(r); } } @@ -102,8 +102,6 @@ struct { } qs; -#define kConcurrentReadUdp 16 -#define kConcurrentWriteUdp 16 #define kConcurrentReadTap 16 #define kConcurrentWriteTap 16 @@ -411,7 +409,7 @@ static inline bool NoMoreAllocationRetry(volatile bool *exit_flag) { static inline bool AllocPacketFrom(Packet **list, int *counter, bool *exit_flag, Packet **res) { Packet *p; if (p = *list) { - *list = p->next; + *list = Packet_NEXT(p); (*counter)--; p->data = p->data_buf + Packet::HEADROOM_BEFORE; } else { @@ -426,24 +424,40 @@ static inline bool AllocPacketFrom(Packet **list, int *counter, bool *exit_flag, static void FreePacketList(Packet *pp) { while (Packet *p = pp) { - pp = p->next; + pp = Packet_NEXT(p); FreePacket(p); } } -UdpSocketWin32::UdpSocketWin32() { +inline void NetworkWin32::FreePacketToPool(Packet *p) { + Packet_NEXT(p) = NULL; + *freed_packets_end_ = p; + freed_packets_end_ = &Packet_NEXT(p); + freed_packets_count_++; +} + +inline bool NetworkWin32::AllocPacketFromPool(Packet **p) { + return AllocPacketFrom(&freed_packets_, &freed_packets_count_, &exit_thread_, p); +} + +UdpSocketWin32::UdpSocketWin32(NetworkWin32 *network_win32) { + network_ = network_win32; wqueue_end_ = &wqueue_; wqueue_ = NULL; - exit_thread_ = false; - thread_ = NULL; socket_ = INVALID_SOCKET; socket_ipv6_ = INVALID_SOCKET; - completion_port_handle_ = NULL; + + finished_reads_ = NULL; + finished_reads_end_ = &finished_reads_; + finished_reads_count_ = 0; + + max_read_ipv6_ = 0; + num_reads_[0] = num_reads_[1] = 0; + num_writes_ = 0; + pending_writes_ = NULL; } UdpSocketWin32::~UdpSocketWin32() { - assert(thread_ == NULL); - CloseHandle(completion_port_handle_); closesocket(socket_); closesocket(socket_ipv6_); FreePacketList(wqueue_); @@ -452,15 +466,14 @@ UdpSocketWin32::~UdpSocketWin32() { bool UdpSocketWin32::Configure(int listen_on_port) { // If attempting to initialize when the thread is already started, then stop // the thread, reinitialize, and start the thread. - if (thread_ != NULL) { - StopThread(); + if (network_->thread_ != NULL) { + network_->StopThread(); bool retcode = Configure(listen_on_port); - StartThread(); + network_->StartThread(); return retcode; } bool retval = false; - HANDLE completion_port = NULL; SOCKET socket_ipv4 = INVALID_SOCKET, socket_ipv6 = INVALID_SOCKET; socket_ipv4 = WSASocket(AF_INET, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED); @@ -468,8 +481,7 @@ bool UdpSocketWin32::Configure(int listen_on_port) { RERROR("UdpSocketWin32::Initialize WSASocket failed"); goto fail; } - completion_port = CreateIoCompletionPort((HANDLE)socket_ipv4, NULL, NULL, 0); - if (!completion_port) { + if (!CreateIoCompletionPort((HANDLE)socket_ipv4, network_->completion_port_handle_, 0, 0)) { RERROR("UdpSocketWin32::Initialize CreateIoCompletionPort failed"); goto fail; } @@ -483,11 +495,10 @@ bool UdpSocketWin32::Configure(int listen_on_port) { goto fail; } } - // Also open up a socket for ipv6 socket_ipv6 = WSASocket(AF_INET6, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED); if (socket_ipv6 != INVALID_SOCKET) { - if (!CreateIoCompletionPort((HANDLE)socket_ipv6, completion_port, 1, 0)) { + if (!CreateIoCompletionPort((HANDLE)socket_ipv6, network_->completion_port_handle_, 0, 0)) { RERROR("IPv6 Socket completion port failed."); closesocket(socket_ipv6); socket_ipv6 = INVALID_SOCKET; @@ -504,11 +515,11 @@ bool UdpSocketWin32::Configure(int listen_on_port) { } std::swap(socket_ipv6_, socket_ipv6); std::swap(socket_, socket_ipv4); - std::swap(completion_port_handle_, completion_port); retval = true; + + max_read_ipv6_ = socket_ipv6 != INVALID_SOCKET ? 1 : 0; + fail: - if (completion_port) - CloseHandle(completion_port); if (socket_ipv4 != INVALID_SOCKET) closesocket(socket_ipv4); if (socket_ipv6 != INVALID_SOCKET) @@ -516,6 +527,27 @@ fail: return retval; } +// Called on another thread to queue up a udp packet +void UdpSocketWin32::WriteUdpPacket(Packet *packet) { + if (qs.udp_qsize2 - qs.udp_qsize1 >= (unsigned)(packet->size < 576 ? MAX_BYTES_IN_UDP_OUT_QUEUE_SMALL : MAX_BYTES_IN_UDP_OUT_QUEUE)) { + FreePacket(packet); + return; + } + Packet_NEXT(packet) = NULL; + qs.udp_qsize2 += packet->size; + + mutex_.Acquire(); + Packet *was_empty = wqueue_; + *wqueue_end_ = packet; + wqueue_end_ = &Packet_NEXT(packet); + mutex_.Release(); + + if (was_empty == NULL) { + // Notify the worker thread that it should attempt more writes + PostQueuedCompletionStatus(network_->completion_port_handle_, NULL, NULL, NULL); + } +} + enum { kUdpGetQueuedCompletionStatusSize = kConcurrentWriteTap + kConcurrentReadTap + 1 }; @@ -532,73 +564,205 @@ static inline bool IsIgnoredUdpError(DWORD err) { return err == WSAEMSGSIZE || err == WSAECONNRESET || err == WSAENETRESET || err == STATUS_PORT_UNREACHABLE; } -void UdpSocketWin32::ThreadMain() { +void UdpSocketWin32::DoMoreReads() { + // Listen with multiple ipv6 packets only if we ever sent an ipv6 packet. + for (int i = num_reads_[IPV6]; i < max_read_ipv6_; i++) { + Packet *p; + if (!network_->AllocPacketFromPool(&p)) + break; +restart_read_udp6: + ClearOverlapped(&p->overlapped); + WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data}; + DWORD flags = 0; + p->userdata = IPV6; + p->sin_size = sizeof(p->addr.sin6); + p->queue_cb = this; + if (WSARecvFrom(socket_ipv6_, &wsabuf, 1, NULL, &flags, (struct sockaddr*)&p->addr, &p->sin_size, &p->overlapped, NULL) != 0) { + DWORD err = WSAGetLastError(); + if (err != WSA_IO_PENDING) { + if (err == WSAEMSGSIZE || err == WSAECONNRESET || err == WSAENETRESET) + goto restart_read_udp6; + RERROR("UdpSocketWin32:WSARecvFrom failed 0x%X", err); + FreePacket(p); + break; + } + } + num_reads_[IPV6]++; + } + // Initiate more reads, reusing the Packet structures in |finished_writes|. + for (int i = num_reads_[IPV4]; i < kConcurrentReadUdp; i++) { + Packet *p; + if (!network_->AllocPacketFromPool(&p)) + break; +restart_read_udp: + ClearOverlapped(&p->overlapped); + WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data}; + DWORD flags = 0; + p->userdata = IPV4; + p->sin_size = sizeof(p->addr.sin); + p->queue_cb = this; + if (WSARecvFrom(socket_, &wsabuf, 1, NULL, &flags, (struct sockaddr*)&p->addr, &p->sin_size, &p->overlapped, NULL) != 0) { + DWORD err = WSAGetLastError(); + if (err != WSA_IO_PENDING) { + if (err == WSAEMSGSIZE || err == WSAECONNRESET || err == WSAENETRESET) + goto restart_read_udp; + RERROR("UdpSocketWin32:WSARecvFrom failed 0x%X", err); + FreePacket(p); + break; + } + } + num_reads_[IPV4]++; + } +} + +void UdpSocketWin32::DoMoreWrites() { + // Push all the finished reads to the packet handler + if (finished_reads_ != NULL) { + packet_handler_->Post(finished_reads_, finished_reads_end_, finished_reads_count_); + finished_reads_ = NULL; + finished_reads_end_ = &finished_reads_; + finished_reads_count_ = 0; + } + + Packet *pending_writes = pending_writes_; + // Initiate more writes from |wqueue_| + while (num_writes_ < kConcurrentWriteUdp) { + // Refill from queue if empty, avoid taking the mutex if it looks empty + if (!pending_writes) { + if (!wqueue_) + break; + mutex_.Acquire(); + pending_writes = wqueue_; + wqueue_end_ = &wqueue_; + wqueue_ = NULL; + mutex_.Release(); + if (!pending_writes) + break; + } + qs.udp_qsize1 += pending_writes->size; + + // Then issue writes + Packet *p = pending_writes; + pending_writes = Packet_NEXT(p); + ClearOverlapped(&p->overlapped); + p->userdata = 2; + p->queue_cb = this; + WSABUF wsabuf = {(ULONG)p->size, (char*)p->data}; + int rv; + if (p->addr.sin.sin_family == AF_INET) { + rv = WSASendTo(socket_, &wsabuf, 1, NULL, 0, (struct sockaddr*)&p->addr.sin, sizeof(p->addr.sin), &p->overlapped, NULL); + } else { + if (socket_ipv6_ == INVALID_SOCKET) { + RERROR("UdpSocketWin32: unavailable ipv6 socket"); + FreePacket(p); + continue; + } + max_read_ipv6_ = kConcurrentReadUdp; + rv = WSASendTo(socket_ipv6_, &wsabuf, 1, NULL, 0, (struct sockaddr*)&p->addr.sin6, sizeof(p->addr.sin6), &p->overlapped, NULL); + } + if (rv != 0) { + DWORD err = WSAGetLastError(); + if (err != ERROR_IO_PENDING) { + RERROR("UdpSocketWin32: WSASendTo failed 0x%X", err); + FreePacket(p); + continue; + } + } + num_writes_++; + } + pending_writes_ = pending_writes; +} + +void UdpSocketWin32::CancelAllIO() { + CancelIo((HANDLE)socket_); + CancelIo((HANDLE)socket_ipv6_); + FreePacketList(pending_writes_); +} + +bool UdpSocketWin32::HasOutstandingIO() { + return (num_reads_[IPV4] + num_reads_[IPV6] + num_writes_) != 0; +} + +void UdpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) { + Packet *p = static_cast(qi); + if (p->userdata < 2) { + num_reads_[p->userdata]--; + if ((DWORD)p->overlapped.Internal != 0) { + network_->FreePacketToPool(p); + if (!IsIgnoredUdpError((DWORD)p->overlapped.Internal)) + RERROR("UdpSocketWin32::Read error 0x%X", (DWORD)p->overlapped.Internal); + } else { + // Remember all the finished packets and queue them up to the next thread once we've + // collected them all. + p->size = (int)p->overlapped.InternalHigh; + p->post_target = PacketProcessor::TARGET_PROCESSOR_UDP; + Packet_NEXT(p) = NULL; + *finished_reads_end_ = p; + finished_reads_end_ = &Packet_NEXT(p); + finished_reads_count_++; + } + } else { + num_writes_--; + network_->FreePacketToPool(p); + if ((DWORD)p->overlapped.Internal != 0) + RERROR("UdpSocketWin32::Write error 0x%X", (DWORD)p->overlapped.Internal); + } +} + +void UdpSocketWin32::OnQueuedItemDelete(QueuedItem *qi) { + Packet *p = static_cast(qi); + if (p->userdata < 2) { + num_reads_[p->userdata]--; + } else { + num_writes_--; + } + network_->FreePacketToPool(p); +} + +void UdpSocketWin32::DoIO() { + DoMoreWrites(); + DoMoreReads(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////// +NetworkWin32::NetworkWin32() : udp_socket_(this) { + exit_thread_ = false; + thread_ = NULL; + freed_packets_ = NULL; + freed_packets_end_ = &freed_packets_; + freed_packets_count_ = 0; + completion_port_handle_ = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, NULL, 0); +} + +NetworkWin32::~NetworkWin32() { + assert(thread_ == NULL); + CloseHandle(completion_port_handle_); + FreePacketList(freed_packets_); +} + +DWORD WINAPI NetworkWin32::NetworkThread(void *x) { + NetworkWin32 *net = (NetworkWin32 *)x; + net->ThreadMain(); + return 0; +} + +void NetworkWin32::ThreadMain() { OVERLAPPED_ENTRY entries[kUdpGetQueuedCompletionStatusSize]; - Packet *pending_writes = NULL; - int num_reads[2] = {0,0}, num_writes = 0; - enum { IPV4, IPV6 }; - Packet *finished_reads = NULL, **finished_reads_end = &finished_reads; - Packet *freed_packets = NULL, **freed_packets_end = &freed_packets; - int freed_packets_count = 0; - int max_read_ipv6 = socket_ipv6_ != INVALID_SOCKET ? 1 : 0; while (!exit_thread_) { - // Listen with multiple ipv6 packets only if we ever sent an ipv6 packet. - for (int i = num_reads[IPV6]; i < max_read_ipv6; i++) { - Packet *p; - if (!AllocPacketFrom(&freed_packets, &freed_packets_count, &exit_thread_, &p)) - break; -restart_read_udp6: - ClearOverlapped(&p->overlapped); - p->post_target = PacketProcessor::TARGET_PROCESSOR_UDP; - WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data}; - DWORD flags = 0; - p->sin_size = sizeof(p->addr.sin6); - if (WSARecvFrom(socket_ipv6_, &wsabuf, 1, NULL, &flags, (struct sockaddr*)&p->addr, &p->sin_size, &p->overlapped, NULL) != 0) { - DWORD err = WSAGetLastError(); - if (err != WSA_IO_PENDING) { - if (err == WSAEMSGSIZE || err == WSAECONNRESET || err == WSAENETRESET) - goto restart_read_udp6; - RERROR("UdpSocketWin32:WSARecvFrom failed 0x%X", err); - FreePacket(p); - break; - } - } - num_reads[IPV6]++; - } + // Run IO on all sockets queued for IO + udp_socket_.DoIO(); - // Initiate more reads, reusing the Packet structures in |finished_writes|. - for (int i = num_reads[IPV4]; i < kConcurrentReadTap; i++) { - Packet *p; - if (!AllocPacketFrom(&freed_packets, &freed_packets_count, &exit_thread_, &p)) - break; -restart_read_udp: - ClearOverlapped(&p->overlapped); - p->post_target = PacketProcessor::TARGET_PROCESSOR_UDP; - WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data}; - DWORD flags = 0; - p->sin_size = sizeof(p->addr.sin); - if (WSARecvFrom(socket_, &wsabuf, 1, NULL, &flags, (struct sockaddr*)&p->addr, &p->sin_size, &p->overlapped, NULL) != 0) { - DWORD err = WSAGetLastError(); - if (err != WSA_IO_PENDING) { - if (err == WSAEMSGSIZE || err == WSAECONNRESET || err == WSAENETRESET) - goto restart_read_udp; - RERROR("UdpSocketWin32:WSARecvFrom failed 0x%X", err); - FreePacket(p); - break; - } - } - num_reads[IPV4]++; - } - - assert(freed_packets_count >= 0); - if (freed_packets_count >= 32) { - FreePackets(freed_packets, freed_packets_end, freed_packets_count); - freed_packets_count = 0; - freed_packets_end = &freed_packets; - } else if (freed_packets == NULL) { - assert(freed_packets_count == 0); - freed_packets_end = &freed_packets; + // Free some packets + assert(freed_packets_count_ >= 0); + if (freed_packets_count_ >= 32) { + FreePackets(freed_packets_, freed_packets_end_, freed_packets_count_); + freed_packets_count_ = 0; + freed_packets_ = NULL; + freed_packets_end_ = &freed_packets_; + } else if (freed_packets_ == NULL) { + assert(freed_packets_count_ == 0); + freed_packets_end_ = &freed_packets_; } ULONG num_entries = 0; @@ -606,154 +770,38 @@ restart_read_udp: RINFO("GetQueuedCompletionStatusEx failed."); break; } - finished_reads_end = &finished_reads; - - int finished_reads_count = 0; - // Go through the finished entries and determine which ones are reads, and which ones are writes. for (ULONG i = 0; i < num_entries; i++) { - if (!entries[i].lpOverlapped) - continue; // This is the dummy entry from |PostQueuedCompletionStatus| - Packet *p = (Packet*)((byte*)entries[i].lpOverlapped - offsetof(Packet, overlapped)); - if (p->post_target == PacketProcessor::TARGET_PROCESSOR_UDP) { - num_reads[entries[i].lpCompletionKey]--; - if ((DWORD)p->overlapped.Internal != 0) { - if (!IsIgnoredUdpError((DWORD)p->overlapped.Internal)) - RERROR("UdpSocketWin32::Read error 0x%X", (DWORD)p->overlapped.Internal); - FreePacket(p); - continue; - } - p->size = (int)p->overlapped.InternalHigh; - *finished_reads_end = p; - finished_reads_end = &p->next; - finished_reads_count++; - } else { - num_writes--; - if ((DWORD)p->overlapped.Internal != 0) { - RERROR("UdpSocketWin32::Write error 0x%X", (DWORD)p->overlapped.Internal); - FreePacket(p); - continue; - } - *freed_packets_end = p; - freed_packets_end = &p->next; - freed_packets_count++; + if (entries[i].lpOverlapped) { + QueuedItem *w = (QueuedItem*)((byte*)entries[i].lpOverlapped - offsetof(QueuedItem, overlapped)); + w->queue_cb->OnQueuedItemEvent(w, 0); } } - *finished_reads_end = NULL; - *freed_packets_end = NULL; - assert(num_writes >= 0); - - // Push all the finished reads to the packet handler - if (finished_reads != NULL) { - packet_handler_->Post(finished_reads, finished_reads_end, finished_reads_count); - } - // Initiate more writes from |wqueue_| - while (num_writes < kConcurrentWriteTap) { - // Refill from queue if empty, avoid taking the mutex if it looks empty - if (!pending_writes) { - if (!wqueue_) - break; - mutex_.Acquire(); - pending_writes = wqueue_; - wqueue_end_ = &wqueue_; - wqueue_ = NULL; - mutex_.Release(); - if (!pending_writes) - break; - } - - qs.udp_qsize1+= pending_writes->size; - - // Then issue writes - Packet *p = pending_writes; - pending_writes = p->next; - ClearOverlapped(&p->overlapped); - p->post_target = PacketProcessor::TARGET_UDP_DEVICE; - WSABUF wsabuf = {(ULONG)p->size, (char*)p->data}; - - int rv; - if (p->addr.sin.sin_family == AF_INET) { - rv = WSASendTo(socket_, &wsabuf, 1, NULL, 0, (struct sockaddr*)&p->addr.sin, sizeof(p->addr.sin), &p->overlapped, NULL); - } else { - if (socket_ipv6_ == INVALID_SOCKET) { - RERROR("UdpSocketWin32: unavailable ipv6 socket"); - FreePacket(p); - continue; - } - max_read_ipv6 = kConcurrentReadTap; - rv = WSASendTo(socket_ipv6_, &wsabuf, 1, NULL, 0, (struct sockaddr*)&p->addr.sin6, sizeof(p->addr.sin6), &p->overlapped, NULL); - } - if (rv != 0) { - DWORD err = WSAGetLastError(); - if (err != ERROR_IO_PENDING) { - RERROR("UdpSocketWin32: WSASendTo failed 0x%X", err); - FreePacket(p); - continue; - } - } - num_writes++; - } } - FreePacketList(freed_packets); - FreePacketList(pending_writes); - // Cancel all IO and wait for all completions - CancelIo((HANDLE)socket_); - CancelIo((HANDLE)socket_ipv6_); + udp_socket_.CancelAllIO(); - while (num_reads[IPV4] + num_reads[IPV6] + num_writes) { + while (udp_socket_.HasOutstandingIO()) { ULONG num_entries = 0; if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, 1, &num_entries, INFINITE, FALSE)) { RINFO("GetQueuedCompletionStatusEx failed."); break; } - if (!entries[0].lpOverlapped) - continue; // This is the dummy entry from |PostQueuedCompletionStatus| - Packet *p = (Packet*)((byte*)entries[0].lpOverlapped - offsetof(Packet, overlapped)); - if (p->post_target == PacketProcessor::TARGET_PROCESSOR_UDP) { - num_reads[entries[0].lpCompletionKey]--; - } else { - num_writes--; + if (entries[0].lpOverlapped) { + QueuedItem *w = (QueuedItem*)((byte*)entries[0].lpOverlapped - offsetof(QueuedItem, overlapped)); + w->queue_cb->OnQueuedItemDelete(w); } - FreePacket(p); } } -// Called on another thread to queue up a udp packet -void UdpSocketWin32::WriteUdpPacket(Packet *packet) { - if (qs.udp_qsize2 - qs.udp_qsize1 >= (unsigned)(packet->size < 576 ? MAX_BYTES_IN_UDP_OUT_QUEUE_SMALL : MAX_BYTES_IN_UDP_OUT_QUEUE)) { - FreePacket(packet); - return; - } - packet->next = NULL; - qs.udp_qsize2 += packet->size; - - mutex_.Acquire(); - Packet *was_empty = wqueue_; - *wqueue_end_ = packet; - wqueue_end_ = &packet->next; - mutex_.Release(); - - if (was_empty == NULL) { - // Notify the worker thread that it should attempt more writes - PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL); - } -} - -DWORD WINAPI UdpSocketWin32::UdpThread(void *x) { - UdpSocketWin32 *udp = (UdpSocketWin32 *)x; - udp->ThreadMain(); - return 0; -} - -void UdpSocketWin32::StartThread() { +void NetworkWin32::StartThread() { assert(completion_port_handle_); DWORD thread_id; - thread_ = CreateThread(NULL, 0, &UdpThread, this, 0, &thread_id); + thread_ = CreateThread(NULL, 0, &NetworkThread, this, 0, &thread_id); SetThreadPriority(thread_, ABOVE_NORMAL_PRIORITY_CLASS); } -void UdpSocketWin32::StopThread() { +void NetworkWin32::StopThread() { if (thread_ != NULL) { exit_thread_ = true; PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL); @@ -764,6 +812,8 @@ void UdpSocketWin32::StopThread() { } } +///////////////////////////////////////////////////////////////////////// + PacketProcessor::PacketProcessor() { event_ = CreateEvent(NULL, FALSE, FALSE, NULL); @@ -810,7 +860,7 @@ void PacketProcessor::Reset() { timer_interrupt_ = false; while (packet) { - Packet *next = packet->next; + Packet *next = Packet_NEXT(packet); if (packet->post_target == TARGET_CONFIG_PROTOCOL) { ConfigPacket *config = (ConfigPacket*)((uint8*)packet - offsetof(ConfigPacket, packet)); delete config; @@ -878,7 +928,7 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) { bool is_overload = (overload != 0); do { - Packet *next = packet->next; + Packet *next = Packet_NEXT(packet); if (packet->post_target == TARGET_PROCESSOR_UDP) { wg->HandleUdpPacket(packet, is_overload); } else if (packet->post_target == TARGET_PROCESSOR_TUN) { @@ -944,10 +994,10 @@ void PacketProcessor::Post(Packet *packet, Packet **end, int count) { void PacketProcessor::ForcePost(Packet *packet) { mutex_.Acquire(); - packet->next = NULL; + Packet_NEXT(packet) = NULL; packets_in_queue_ += 1; *last_ptr_ = packet; - last_ptr_ = &packet->next; + last_ptr_ = &Packet_NEXT(packet); if (need_notify_) { need_notify_ = 0; mutex_.Release(); @@ -1694,7 +1744,7 @@ void TunWin32Iocp::ThreadMain() { Packet *p; if (!AllocPacketFrom(&freed_packets, &freed_packets_count, &exit_thread_, &p)) break; - memset(&p->overlapped, 0, sizeof(p->overlapped)); + ClearOverlapped(&p->overlapped); p->post_target = PacketProcessor::TARGET_PROCESSOR_TUN; if (!ReadFile(adapter_.handle(), p->data, kPacketCapacity, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { FreePacket(p); @@ -1746,7 +1796,7 @@ void TunWin32Iocp::ThreadMain() { p->size = (int)p->overlapped.InternalHigh; *finished_reads_end = p; - finished_reads_end = &p->next; + finished_reads_end = &Packet_NEXT(p); finished_reads_count++; } else { num_writes--; @@ -1757,7 +1807,7 @@ void TunWin32Iocp::ThreadMain() { } freed_packets_count++; *freed_packets_end = p; - freed_packets_end = &p->next; + freed_packets_end = &Packet_NEXT(p); } } *finished_reads_end = NULL; @@ -1783,8 +1833,8 @@ void TunWin32Iocp::ThreadMain() { } // Then issue writes Packet *p = pending_writes; - pending_writes = p->next; - memset(&p->overlapped, 0, sizeof(p->overlapped)); + pending_writes = Packet_NEXT(p); + ClearOverlapped(&p->overlapped); p->post_target = PacketProcessor::TARGET_TUN_DEVICE; if (!WriteFile(adapter_.handle(), p->data, p->size, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { RERROR("TunWin32: WriteFile failed 0x%X", err); @@ -1844,7 +1894,7 @@ void TunWin32Iocp::StopThread() { } void TunWin32Iocp::WriteTunPacket(Packet *packet) { - packet->next = NULL; + Packet_NEXT(packet) = NULL; mutex_.Acquire(); if (wqueue_size_ >= HARD_MAXIMUM_TUN_QUEUE_SIZE) { mutex_.Release(); @@ -1859,7 +1909,7 @@ void TunWin32Iocp::WriteTunPacket(Packet *packet) { Packet *was_empty = wqueue_; *wqueue_end_ = packet; - wqueue_end_ = &packet->next; + wqueue_end_ = &Packet_NEXT(packet); mutex_.Release(); if (was_empty == NULL) { // Notify the worker thread that it should attempt more writes @@ -1916,7 +1966,7 @@ void TunWin32Overlapped::ThreadMain() { while (!exit_thread_) { if (read_packet == NULL) { Packet *p = AllocPacket(); - memset(&p->overlapped, 0, sizeof(p->overlapped)); + ClearOverlapped(&p->overlapped); p->overlapped.hEvent = read_event_; p->post_target = PacketProcessor::TARGET_PROCESSOR_TUN; if (!ReadFile(adapter_.handle(), p->data, kPacketCapacity, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { @@ -1940,8 +1990,8 @@ void TunWin32Overlapped::ThreadMain() { HANDLE hx = h[res - WAIT_OBJECT_0]; if (hx == read_event_) { read_packet->size = (int)read_packet->overlapped.InternalHigh; - read_packet->next = NULL; - packet_handler_->Post(read_packet, &read_packet->next, 1); + Packet_NEXT(read_packet) = NULL; + packet_handler_->Post(read_packet, &Packet_NEXT(read_packet), 1); read_packet = NULL; } else if (hx == write_event_) { FreePacket(write_packet); @@ -1962,7 +2012,7 @@ void TunWin32Overlapped::ThreadMain() { if (pending_writes) { // Then issue writes Packet *p = pending_writes; - pending_writes = p->next; + pending_writes = Packet_NEXT(p); memset(&p->overlapped, 0, sizeof(p->overlapped)); p->overlapped.hEvent = write_event_; p->post_target = PacketProcessor::TARGET_TUN_DEVICE; @@ -2002,11 +2052,11 @@ void TunWin32Overlapped::StopThread() { } void TunWin32Overlapped::WriteTunPacket(Packet *packet) { - packet->next = NULL; + Packet_NEXT(packet) = NULL; mutex_.Acquire(); Packet *was_empty = wqueue_; *wqueue_end_ = packet; - wqueue_end_ = &packet->next; + wqueue_end_ = &Packet_NEXT(packet); mutex_.Release(); if (was_empty == NULL) SetEvent(wake_event_); @@ -2024,12 +2074,12 @@ DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) { for(;;) { TunWin32Iocp tun(&backend->dns_blocker_, backend); - UdpSocketWin32 udp; - WireguardProcessor wg_proc(&udp, &tun, backend); + NetworkWin32 net; + WireguardProcessor wg_proc(&net.udp(), &tun, backend); qs.udp_qsize1 = qs.udp_qsize2 = 0; - udp.SetPacketHandler(&backend->packet_processor_); + net.udp().SetPacketHandler(&backend->packet_processor_); tun.SetPacketHandler(&backend->packet_processor_); if (backend->config_file_[0] && @@ -2043,13 +2093,12 @@ DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) { backend->wg_processor_ = &wg_proc; - udp.StartThread(); + net.StartThread(); tun.StartThread(); stop_mode = backend->packet_processor_.Run(&wg_proc, backend); - udp.StopThread(); + net.StopThread(); tun.StopThread(); - - + backend->wg_processor_ = NULL; // Keep DNS alive diff --git a/network_win32.h b/network_win32.h index 6f87f02..e024cca 100644 --- a/network_win32.h +++ b/network_win32.h @@ -56,40 +56,103 @@ private: bool timer_interrupt_; }; -// Encapsulates a UDP socket, optionally listening for incoming packets +class NetworkWin32; +class PacketAllocPool; + +// Encapsulates a UDP socket pair (ipv4 / ipv6), optionally listening for incoming packets // on a specific port. -class UdpSocketWin32 : public UdpInterface { +class UdpSocketWin32 : public UdpInterface, QueuedItemCallback { public: - explicit UdpSocketWin32(); + explicit UdpSocketWin32(NetworkWin32 *network_win32); ~UdpSocketWin32(); void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; } - void StartThread(); - void StopThread(); - // -- from UdpInterface virtual bool Configure(int listen_on_port) override; virtual void WriteUdpPacket(Packet *packet) override; + void DoIO(); + void CancelAllIO(); + bool HasOutstandingIO(); + + enum { + kConcurrentReadUdp = 16, + kConcurrentWriteUdp = 16 + }; + private: - void ThreadMain(); - static DWORD WINAPI UdpThread(void *x); + + void DoMoreReads(); + void DoMoreWrites(); + + // From OverlappedCallbacks + virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override; + virtual void OnQueuedItemDelete(QueuedItem *ow) override; + + NetworkWin32 *network_; // All packets queued for writing. Locked by |mutex_| + // Both ipv6 and ipv4 are supported Packet *wqueue_, **wqueue_end_; + // Protects wqueue Mutex mutex_; + // This is where packets end up PacketProcessor *packet_handler_; - SOCKET socket_; - SOCKET socket_ipv6_; - HANDLE completion_port_handle_; + + // The two socket handles, since we support both ipv4 and ipv6 + SOCKET socket_, socket_ipv6_; + + enum { IPV4, IPV6 }; + int max_read_ipv6_; + int num_reads_[2]; + int num_writes_; + Packet *pending_writes_; + + Packet *finished_reads_, **finished_reads_end_; + int finished_reads_count_; +}; + +// Holds the thread for network communications +class NetworkWin32 { + friend class UdpSocketWin32; +public: + explicit NetworkWin32(); + ~NetworkWin32(); + + void StartThread(); + void StopThread(); + + UdpSocketWin32 &udp() { return udp_socket_; } + +private: + void ThreadMain(); + static DWORD WINAPI NetworkThread(void *x); + + void FreePacketToPool(Packet *p); + bool AllocPacketFromPool(Packet **p); + + // The network thread handle HANDLE thread_; + // Whether we're exiting the thread bool exit_thread_; + + // The handle to the completion port + HANDLE completion_port_handle_; + + Packet *freed_packets_, **freed_packets_end_; + int freed_packets_count_; + + // Right now there's always one udp socket only + UdpSocketWin32 udp_socket_; }; + + + class DnsBlocker; class TunWin32Adapter { diff --git a/wireguard.cpp b/wireguard.cpp index 99d025e..d565dfe 100644 --- a/wireguard.cpp +++ b/wireguard.cpp @@ -1053,7 +1053,7 @@ void WireguardProcessor::SendKeepalive_Locked(WgPeer *peer) { if (!packet) return; packet->size = 0; - packet->next = NULL; + Packet_NEXT(packet) = NULL; peer->first_queued_packet_ = packet; } SendQueuedPackets_Locked(peer); @@ -1067,7 +1067,7 @@ void WireguardProcessor::SendQueuedPackets_Locked(WgPeer *peer) { peer->last_queued_packet_ptr_ = &peer->first_queued_packet_; peer->num_queued_packets_ = 0; while (packet != NULL) { - Packet *next = packet->next; + Packet *next = Packet_NEXT(packet); WriteAndEncryptPacketToUdp_WillUnlock(peer, packet); packet = next; WG_ACQUIRE_LOCK(peer->mutex_); // WriteAndEncryptPacketToUdp_WillUnlock releases the lock diff --git a/wireguard_proto.cpp b/wireguard_proto.cpp index 60acdae..c2bc048 100644 --- a/wireguard_proto.cpp +++ b/wireguard_proto.cpp @@ -459,7 +459,7 @@ void WgPeer::ClearPacketQueue_Locked() { assert(dev_->IsMainThread() && IsPeerLocked()); Packet *packet; while ((packet = first_queued_packet_) != NULL) { - first_queued_packet_ = packet->next; + first_queued_packet_ = Packet_NEXT(packet); FreePacket(packet); } last_queued_packet_ptr_ = &first_queued_packet_; @@ -472,14 +472,14 @@ void WgPeer::AddPacketToPeerQueue_Locked(Packet *packet) { // Keep only the first MAX_QUEUED_PACKETS packets. while (num_queued_packets_ >= MAX_QUEUED_PACKETS_PER_PEER) { Packet *packet = first_queued_packet_; - first_queued_packet_ = packet->next; + first_queued_packet_ = Packet_NEXT(packet); num_queued_packets_--; FreePacket(packet); } // Add the packet to the out queue that will get sent once handshake completes *last_queued_packet_ptr_ = packet; - last_queued_packet_ptr_ = &packet->next; - packet->next = NULL; + last_queued_packet_ptr_ = &Packet_NEXT(packet); + Packet_NEXT(packet) = NULL; num_queued_packets_++; }