diff --git a/netapi.h b/netapi.h index 43178e0..268ce69 100644 --- a/netapi.h +++ b/netapi.h @@ -41,12 +41,11 @@ struct QueuedItem { struct Packet : QueuedItem { int sin_size; unsigned int size; + byte *data; - IpAddr addr; // Optionally set to target/source of the packet - - uint8 post_target; uint8 userdata; - + IpAddr addr; // Optionally set to target/source of the packet + byte data_pre[4]; byte data_buf[0]; enum { diff --git a/network_win32.cpp b/network_win32.cpp index 3df6146..c2c0a7f 100644 --- a/network_win32.cpp +++ b/network_win32.cpp @@ -618,7 +618,7 @@ restart_read_udp: void UdpSocketWin32::DoMoreWrites() { // Push all the finished reads to the packet handler if (finished_reads_ != NULL) { - packet_handler_->Post(finished_reads_, finished_reads_end_, finished_reads_count_); + packet_handler_->PostPackets(finished_reads_, finished_reads_end_, finished_reads_count_); finished_reads_ = NULL; finished_reads_end_ = &finished_reads_; finished_reads_count_ = 0; @@ -695,8 +695,8 @@ void UdpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) { // Remember all the finished packets and queue them up to the next thread once we've // collected them all. p->size = (int)p->overlapped.InternalHigh; - p->post_target = PacketProcessor::TARGET_PROCESSOR_UDP; - Packet_NEXT(p) = NULL; + p->queue_cb = packet_handler_->udp_queue(); + p->queue_next = NULL; *finished_reads_end_ = p; finished_reads_end_ = &Packet_NEXT(p); finished_reads_count_++; @@ -844,14 +844,8 @@ void CALLBACK PacketProcessor::ThreadPoolTimerCallback(PTP_CALLBACK_INSTANCE iTi th->mutex_.Release(); } -struct ConfigPacket { - std::string message; - uint32 ident; - Packet packet; -}; - void PacketProcessor::Reset() { - Packet *packet; + QueuedItem *packet; packet = first_; first_ = NULL; @@ -860,13 +854,8 @@ void PacketProcessor::Reset() { timer_interrupt_ = false; while (packet) { - Packet *next = Packet_NEXT(packet); - if (packet->post_target == TARGET_CONFIG_PROTOCOL) { - ConfigPacket *config = (ConfigPacket*)((uint8*)packet - offsetof(ConfigPacket, packet)); - delete config; - } else { - FreePacket(packet); - } + QueuedItem *next = packet->queue_next; + packet->queue_cb->OnQueuedItemDelete(packet); packet = next; } } @@ -875,8 +864,9 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) { int free_packets_ctr = 0; int overload = 0; int exit_code; - Packet *packet; + QueuedItem *packet; PTP_TIMER threadpool_timer; + QueueContext queue_context = {wg, backend}; threadpool_timer = CreateThreadpoolTimer(&ThreadPoolTimerCallback, this, NULL); static const int64 duetime = -10000000; // the unit is 100ns @@ -925,18 +915,11 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) { tpq_last_qsize = packets_in_queue; if (packets_in_queue >= 1024) overload = 2; - bool is_overload = (overload != 0); + queue_context.overload = (overload != 0); do { - Packet *next = Packet_NEXT(packet); - if (packet->post_target == TARGET_PROCESSOR_UDP) { - wg->HandleUdpPacket(packet, is_overload); - } else if (packet->post_target == TARGET_PROCESSOR_TUN) { - wg->HandleTunPacket(packet); - } else { - assert(packet->post_target == TARGET_CONFIG_PROTOCOL); - HandleConfigurationProtocolPacket(wg, backend, packet); - } + QueuedItem *next = packet->queue_next; + packet->queue_cb->OnQueuedItemEvent(packet, (uintptr_t)&queue_context); packet = next; } while (packet); } @@ -953,12 +936,22 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) { return exit_code; } -void PacketProcessor::HandleConfigurationProtocolPacket(WireguardProcessor *wg, TunsafeBackendWin32 *backend, Packet *packet) { - ConfigPacket *config = (ConfigPacket*)((uint8*)packet - offsetof(ConfigPacket, packet)); - std::string reply; - WgConfig::HandleConfigurationProtocolMessage(wg, std::move(config->message), &reply); - backend->delegate_->OnConfigurationProtocolReply(config->ident, std::move(reply)); +void PacketProcessorTunCb::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) { + PacketProcessor::QueueContext *context = (PacketProcessor::QueueContext *)extra; + context->wg->HandleTunPacket(static_cast(qi)); +} +void PacketProcessorTunCb::OnQueuedItemDelete(QueuedItem *qi) { + FreePacket(static_cast(qi)); +} + +void PacketProcessorUdpCb::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) { + PacketProcessor::QueueContext *context = (PacketProcessor::QueueContext *)extra; + context->wg->HandleUdpPacket(static_cast(qi), context->overload); +} + +void PacketProcessorUdpCb::OnQueuedItemDelete(QueuedItem *qi) { + FreePacket(static_cast(qi)); } void PacketProcessor::PostExit(int exit_code) { @@ -970,18 +963,18 @@ void PacketProcessor::PostExit(int exit_code) { SetEvent(event_); } -void PacketProcessor::Post(Packet *packet, Packet **end, int count) { +void PacketProcessor::PostPackets(Packet *first, Packet **end, int count) { mutex_.Acquire(); if (packets_in_queue_ >= HARD_MAXIMUM_QUEUE_SIZE) { mutex_.Release(); - FreePackets(packet, end, count); + FreePackets(first, end, count); return; } - assert(packet != NULL); + assert(first != NULL); assert(first_ || last_ptr_ == &first_); packets_in_queue_ += count; - *last_ptr_ = packet; - last_ptr_ = end; + *last_ptr_ = first; + last_ptr_ = (QueuedItem**)end; assert(first_ || last_ptr_ == &first_); if (need_notify_) { need_notify_ = 0; @@ -992,12 +985,12 @@ void PacketProcessor::Post(Packet *packet, Packet **end, int count) { mutex_.Release(); } -void PacketProcessor::ForcePost(Packet *packet) { +void PacketProcessor::ForcePost(QueuedItem *item) { mutex_.Acquire(); - Packet_NEXT(packet) = NULL; + item->queue_next = NULL; packets_in_queue_ += 1; - *last_ptr_ = packet; - last_ptr_ = &Packet_NEXT(packet); + *last_ptr_ = item; + last_ptr_ = &item->queue_next; if (need_notify_) { need_notify_ = 0; mutex_.Release(); @@ -1745,7 +1738,7 @@ void TunWin32Iocp::ThreadMain() { if (!AllocPacketFrom(&freed_packets, &freed_packets_count, &exit_thread_, &p)) break; ClearOverlapped(&p->overlapped); - p->post_target = PacketProcessor::TARGET_PROCESSOR_TUN; + p->userdata = 0; if (!ReadFile(adapter_.handle(), p->data, kPacketCapacity, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { FreePacket(p); @@ -1786,7 +1779,7 @@ void TunWin32Iocp::ThreadMain() { if (!entries[i].lpOverlapped) continue; // This is the dummy entry from |PostQueuedCompletionStatus| Packet *p = (Packet*)((byte*)entries[i].lpOverlapped - offsetof(Packet, overlapped)); - if (p->post_target == PacketProcessor::TARGET_PROCESSOR_TUN) { + if (p->userdata == 0) { num_reads--; if ((int)p->overlapped.Internal != 0) { RERROR("TunWin32::ReadComplete error 0x%X", (int)p->overlapped.Internal); @@ -1794,7 +1787,7 @@ void TunWin32Iocp::ThreadMain() { continue; } p->size = (int)p->overlapped.InternalHigh; - + p->queue_cb = packet_handler_->tun_queue(); *finished_reads_end = p; finished_reads_end = &Packet_NEXT(p); finished_reads_count++; @@ -1814,7 +1807,7 @@ void TunWin32Iocp::ThreadMain() { *freed_packets_end = NULL; if (finished_reads != NULL) - packet_handler_->Post(finished_reads, finished_reads_end, finished_reads_count); + packet_handler_->PostPackets(finished_reads, finished_reads_end, finished_reads_count); // Initiate more writes from |wqueue_| while (num_writes < kConcurrentWriteTap) { @@ -1835,7 +1828,7 @@ void TunWin32Iocp::ThreadMain() { Packet *p = pending_writes; pending_writes = Packet_NEXT(p); ClearOverlapped(&p->overlapped); - p->post_target = PacketProcessor::TARGET_TUN_DEVICE; + p->userdata = 1; if (!WriteFile(adapter_.handle(), p->data, p->size, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { RERROR("TunWin32: WriteFile failed 0x%X", err); FreePacket(p); @@ -1858,7 +1851,7 @@ EXIT: if (!entries[0].lpOverlapped) continue; // This is the dummy entry from |PostQueuedCompletionStatus| Packet *p = (Packet*)((byte*)entries[0].lpOverlapped - offsetof(Packet, overlapped)); - if (p->post_target == PacketProcessor::TARGET_PROCESSOR_TUN) { + if (p->userdata == 0) { num_reads--; } else { num_writes--; @@ -1968,7 +1961,6 @@ void TunWin32Overlapped::ThreadMain() { Packet *p = AllocPacket(); ClearOverlapped(&p->overlapped); p->overlapped.hEvent = read_event_; - p->post_target = PacketProcessor::TARGET_PROCESSOR_TUN; if (!ReadFile(adapter_.handle(), p->data, kPacketCapacity, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { FreePacket(p); RERROR("TunWin32: ReadFile failed 0x%X", err); @@ -1991,7 +1983,7 @@ void TunWin32Overlapped::ThreadMain() { if (hx == read_event_) { read_packet->size = (int)read_packet->overlapped.InternalHigh; Packet_NEXT(read_packet) = NULL; - packet_handler_->Post(read_packet, &Packet_NEXT(read_packet), 1); + packet_handler_->PostPackets(read_packet, &Packet_NEXT(read_packet), 1); read_packet = NULL; } else if (hx == write_event_) { FreePacket(write_packet); @@ -2015,7 +2007,6 @@ void TunWin32Overlapped::ThreadMain() { pending_writes = Packet_NEXT(p); memset(&p->overlapped, 0, sizeof(p->overlapped)); p->overlapped.hEvent = write_event_; - p->post_target = PacketProcessor::TARGET_TUN_DEVICE; if (!WriteFile(adapter_.handle(), p->data, p->size, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { RERROR("TunWin32: WriteFile failed 0x%X", err); FreePacket(p); @@ -2185,7 +2176,6 @@ void TunsafeBackendWin32::SetStatus(StatusCode status) { bool TunsafeBackendWin32::Configure() { // it's always initialized - return true; } @@ -2339,12 +2329,32 @@ std::string TunsafeBackendWin32::GetConfigFileName() { return std::string(); } +struct ConfigQueueItem : QueuedItem, QueuedItemCallback { + virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override; + virtual void OnQueuedItemDelete(QueuedItem *ow) override; + + std::string message; + uint32 ident; +}; + +void ConfigQueueItem::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) { + PacketProcessor::QueueContext *context = (PacketProcessor::QueueContext *)extra; + std::string reply; + WgConfig::HandleConfigurationProtocolMessage(context->wg, std::move(message), &reply); + context->backend->delegate_->OnConfigurationProtocolReply(ident, std::move(reply)); + delete this; +} + +void ConfigQueueItem::OnQueuedItemDelete(QueuedItem *ow) { + delete this; +} + void TunsafeBackendWin32::SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) { - ConfigPacket *config_packet = new ConfigPacket; - config_packet->ident = identifier; - config_packet->message = std::move(message); - config_packet->packet.post_target = PacketProcessor::TARGET_CONFIG_PROTOCOL; - packet_processor_.ForcePost(&config_packet->packet); + ConfigQueueItem *queue_item = new ConfigQueueItem; + queue_item->ident = identifier; + queue_item->message = std::move(message); + queue_item->queue_cb = queue_item; + packet_processor_.ForcePost(queue_item); } void TunsafeBackendWin32::OnConnected() { diff --git a/network_win32.h b/network_win32.h index e024cca..6b91ed0 100644 --- a/network_win32.h +++ b/network_win32.h @@ -16,37 +16,46 @@ enum { ADAPTER_GUID_SIZE = 40, }; -struct Packet; class WireguardProcessor; class TunsafeBackendWin32; +struct PacketProcessorTunCb : QueuedItemCallback { + virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override; + virtual void OnQueuedItemDelete(QueuedItem *ow) override; +}; + +struct PacketProcessorUdpCb : QueuedItemCallback { + virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override; + virtual void OnQueuedItemDelete(QueuedItem *ow) override; +}; + class PacketProcessor { public: explicit PacketProcessor(); ~PacketProcessor(); - enum { - TARGET_PROCESSOR_UDP = 0, - TARGET_PROCESSOR_TUN = 1, - TARGET_UDP_DEVICE = 2, - TARGET_TUN_DEVICE = 3, - TARGET_CONFIG_PROTOCOL = 4, - }; - void Reset(); int Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend); - void Post(Packet *packet, Packet **end, int count); - void ForcePost(Packet *packet); + void PostPackets(Packet *first, Packet **end, int count); + void ForcePost(QueuedItem *item); void PostExit(int exit_code); const uint32 *posted_exit_code() { return &exit_code_; } + QueuedItemCallback *tun_queue() { return &tun_cb_; } + QueuedItemCallback *udp_queue() { return &udp_cb_; } + + struct QueueContext { + WireguardProcessor *wg; + TunsafeBackendWin32 *backend; + bool overload; + }; + private: static void CALLBACK ThreadPoolTimerCallback(PTP_CALLBACK_INSTANCE iTimerInstance, PVOID pContext, PTP_TIMER); - void HandleConfigurationProtocolPacket(WireguardProcessor *wg, TunsafeBackendWin32 *backend, Packet *packet); - Packet *first_; - Packet **last_ptr_; + QueuedItem *first_; + QueuedItem **last_ptr_; uint32 packets_in_queue_; uint32 need_notify_; Mutex mutex_; @@ -54,6 +63,9 @@ private: uint32 exit_code_; bool timer_interrupt_; + + PacketProcessorTunCb tun_cb_; + PacketProcessorUdpCb udp_cb_; }; class NetworkWin32; @@ -272,6 +284,7 @@ class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate { friend class TunWin32Iocp; friend class TunWin32Overlapped; friend class TunWin32Adapter; + friend struct ConfigQueueItem; public: TunsafeBackendWin32(Delegate *delegate); ~TunsafeBackendWin32();