// SPDX-License-Identifier: AGPL-1.0-only // Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. #include "stdafx.h" #include "network_win32.h" #include "wireguard_config.h" #include "netapi.h" #include #include #include #include #include #include #include #include #include #include #include #include #include "tunsafe_endian.h" #include "wireguard.h" #include "util.h" #include #include "network_win32_dnsblock.h" #include "util_win32.h" #include "tunsafe_wg_plugin.h" enum { HARD_MAXIMUM_QUEUE_SIZE = 102400, MAX_BYTES_IN_UDP_OUT_QUEUE = 256 * 1024, MAX_BYTES_IN_UDP_OUT_QUEUE_SMALL = (256 + 64) * 1024, // On Windows 7 with NDIS6 sometimes the tun queue blows up. HARD_MAXIMUM_TUN_QUEUE_SIZE = 16384, }; enum { kMetricNone = -1, kMetricAutomatic = 0, }; static uint8 internet_route_blocking_state; static SLIST_HEADER freelist_head; static HKEY g_hklm_reg_key; static uint8 g_killswitch_curr, g_killswitch_want, g_killswitch_currconn; bool g_allow_pre_post; static volatile bool g_fail_malloc_flag; static void DeactivateKillSwitch(uint32 want); Packet *AllocPacket() { Packet *packet = (Packet*)InterlockedPopEntrySList(&freelist_head); if (packet == NULL) { while ((packet = (Packet *)_aligned_malloc(kPacketAllocSize, 16)) == NULL) { if (g_fail_malloc_flag) return NULL; Sleep(1000); } } packet->Reset(); return packet; } void FreePacket(Packet *packet) { InterlockedPushEntrySList(&freelist_head, &packet->list_entry); } static bool IsIpv6AddressSet(const void *p) { return (ReadLE64(p) | ReadLE64((char*)p + 8)) != 0; } extern "C" PSLIST_ENTRY __fastcall InterlockedPushListSList( IN PSLIST_HEADER ListHead, IN PSLIST_ENTRY List, IN PSLIST_ENTRY ListEnd, IN ULONG Count ); void FreePackets(Packet *packet, Packet **end, int count) { InterlockedPushListSList(&freelist_head, &packet->list_entry, (PSLIST_ENTRY)end, count); } void FreeAllPackets() { Packet *p; p = (Packet*)InterlockedFlushSList(&freelist_head); while (Packet *r = p) { p = Packet_NEXT(p); _aligned_free(r); } } void SimplePacketPool::FreeSomePacketsInner() { int n = freed_packets_count_ - 24; Packet **p = &freed_packets_; for (; n; n--) p = &Packet_NEXT(*p); FreePackets(exch(freed_packets_, *p), p, exch(freed_packets_count_, 24) - 24); } void InitPacketMutexes() { static bool mutex_inited; if (!mutex_inited) { mutex_inited = true; InitializeSListHead(&freelist_head); } } #define kConcurrentReadTap 16 #define kConcurrentWriteTap 16 #define kAdapterKeyName "SYSTEM\\CurrentControlSet\\Control\\Class\\{4D36E972-E325-11CE-BFC1-08002BE10318}" #define kNetworkConnectionsKeyName "SYSTEM\\CurrentControlSet\\Control\\Network\\{4D36E972-E325-11CE-BFC1-08002BE10318}" #define kTapComponentId "tap0901" #define TAP_CONTROL_CODE(request,method) \ CTL_CODE (FILE_DEVICE_UNKNOWN, request, method, FILE_ANY_ACCESS) #define TAP_IOCTL_GET_MAC TAP_CONTROL_CODE(1, METHOD_BUFFERED) #define TAP_IOCTL_GET_VERSION TAP_CONTROL_CODE(2, METHOD_BUFFERED) #define TAP_IOCTL_GET_MTU TAP_CONTROL_CODE(3, METHOD_BUFFERED) #define TAP_IOCTL_GET_INFO TAP_CONTROL_CODE(4, METHOD_BUFFERED) #define TAP_IOCTL_CONFIG_POINT_TO_POINT TAP_CONTROL_CODE(5, METHOD_BUFFERED) #define TAP_IOCTL_SET_MEDIA_STATUS TAP_CONTROL_CODE(6, METHOD_BUFFERED) #define TAP_IOCTL_CONFIG_DHCP_MASQ TAP_CONTROL_CODE(7, METHOD_BUFFERED) #define TAP_IOCTL_GET_LOG_LINE TAP_CONTROL_CODE(8, METHOD_BUFFERED) #define TAP_IOCTL_CONFIG_DHCP_SET_OPT TAP_CONTROL_CODE(9, METHOD_BUFFERED) #define TAP_IOCTL_CONFIG_TUN TAP_CONTROL_CODE(10, METHOD_BUFFERED) static bool RunNetsh(const char *cmdline) { wchar_t path[MAX_PATH + 20]; size_t size = GetSystemDirectoryW(path, MAX_PATH); bool result = false; if (!size) { RERROR("GetSystemDirectory failed"); return false; } memcpy(path + size, L"\\netsh.exe", 11 * sizeof(path[0])); size_t cmdline_size = strlen(cmdline); wchar_t *cmdlinew = new wchar_t[cmdline_size + 1]; for (size_t i = 0; i <= cmdline_size; i++) cmdlinew[i] = cmdline[i]; STARTUPINFOW si = {0}; PROCESS_INFORMATION pi = {0}; GetStartupInfoW(&si); si.dwFlags = STARTF_USESHOWWINDOW; si.wShowWindow = SW_HIDE; if (CreateProcessW(path, cmdlinew, NULL, NULL, FALSE, CREATE_NO_WINDOW, NULL, NULL, &si, &pi)) { DWORD exit_code = -1; WaitForSingleObject(pi.hProcess, INFINITE); GetExitCodeProcess(pi.hProcess, &exit_code); if (exit_code != 0) RERROR("Netsh failed (%d) : %s", exit_code, cmdline); else { RINFO("Run: %s", cmdline); result = true; } CloseHandle(pi.hThread); CloseHandle(pi.hProcess); } else { RERROR("CreateProcess failed: %s", cmdline); } delete[]cmdlinew; return result; } // Open the TAP adapter, either a random one or a specific one // On return, the adapter is locked in |TunAdaptersInUse|. static HANDLE OpenTunAdapter(char guid[ADAPTER_GUID_SIZE], TunsafeBackendWin32 *backend, DWORD open_flags) { char path[128]; HANDLE h; int retries = 0; std::vector adapters; // When guid is empty, we try all adapters, otherwise // just try the specific adapter if (guid[0] == 0) { GetTapAdapterInfo(&adapters); if (adapters.empty()) { RERROR("Unable to find any TAP adapters"); RERROR(" Please ensure that TunSafe-TAP is properly installed."); return NULL; } } else { adapters.emplace_back(); memcpy(adapters.back().guid, guid, ADAPTER_GUID_SIZE); adapters.back().name[0] = 0; } TunAdaptersInUse *tun_adapters_in_use = TunAdaptersInUse::GetInstance(); RETRY: bool did_try_adapter = false; int error_code = 0; for (GuidAndDevName &x : adapters) { snprintf(path, sizeof(path), "\\\\.\\Global\\%s.tap", x.guid); if (tun_adapters_in_use->Acquire(x.guid, static_cast(backend))) { h = CreateFile(path, GENERIC_READ | GENERIC_WRITE, 0, 0, OPEN_EXISTING, FILE_ATTRIBUTE_SYSTEM | open_flags, 0); if (h != INVALID_HANDLE_VALUE) { memcpy(guid, x.guid, ADAPTER_GUID_SIZE); return h; } did_try_adapter = true; error_code = GetLastError(); tun_adapters_in_use->Release(static_cast(backend)); } } if (!did_try_adapter) { RERROR("All TAP adapters are currently in use"); return NULL; } // Sometimes if you close the device right before, it will fail to open with errorcode 31. // When resuming from sleep in my VM, the error code is ERROR_FILE_NOT_FOUND if ((error_code == ERROR_FILE_NOT_FOUND || error_code == ERROR_GEN_FAILURE) && !backend->exit_code()) { if (retries <= 10) { RERROR("OpenTapAdapter: CreateFile failed: 0x%X... retrying%s", error_code, retries == 10 ? " (last notice)" : ""); if (retries == 10) { if (error_code == ERROR_FILE_NOT_FOUND) { RERROR(" Please ensure that TunSafe-TAP is properly installed."); } else if (error_code == ERROR_GEN_FAILURE) { RERROR(" Please ensure that the TAP device is not in use."); } backend->SetStatus(TunsafeBackend::kStatusTunRetrying); } } int sleep_amount = 250 * std::min(++retries, 40); for (;;) { if (backend->exit_code()) return NULL; if (sleep_amount == 0) break; Sleep(125); sleep_amount -= 125; } goto RETRY; } RERROR("OpenTapAdapter: CreateFile failed: 0x%X", error_code); return NULL; } static bool AddRoute(int family, const void *dest, int dest_prefix, const void *gateway, const NET_LUID *interface_luid, std::vector *undo_array) { MIB_IPFORWARD_ROW2 row = {0}; char buf1[kSizeOfAddress], buf2[kSizeOfAddress]; row.InterfaceLuid = *interface_luid; row.DestinationPrefix.PrefixLength = dest_prefix; row.DestinationPrefix.Prefix.si_family = family; row.NextHop.si_family = family; if (family == AF_INET) { memcpy(&row.DestinationPrefix.Prefix.Ipv4.sin_addr, dest, 4); memcpy(&row.NextHop.Ipv4.sin_addr, gateway, 4); } else if (family == AF_INET6) { memcpy(&row.DestinationPrefix.Prefix.Ipv6.sin6_addr, dest, 16); memcpy(&row.NextHop.Ipv6.sin6_addr, gateway, 16); } else { return false; } row.ValidLifetime = 0xffffffff; row.PreferredLifetime = 0xffffffff; row.Metric = 100; row.Protocol = MIB_IPPROTO_NETMGMT; DWORD error = CreateIpForwardEntry2(&row); if (error == NO_ERROR || error == ERROR_OBJECT_ALREADY_EXISTS) { if (undo_array) undo_array->push_back(row); RINFO("Added Route %s => %s%s", print_ip_prefix(buf1, family, dest, dest_prefix), print_ip_prefix(buf2, family, gateway, -1), (error == ERROR_OBJECT_ALREADY_EXISTS) ? " (already exists)" : ""); return true; } RINFO("AddRoute failed (%d) %s => %s", error, print_ip_prefix(buf1, family, dest, dest_prefix), print_ip_prefix(buf2, family, gateway, -1)); return false; } static bool DeleteRoute(MIB_IPFORWARD_ROW2 *row) { char buf1[kSizeOfAddress], buf2[kSizeOfAddress]; DWORD error = DeleteIpForwardEntry2(row); print_ip_prefix(buf1, row->DestinationPrefix.Prefix.si_family, (row->DestinationPrefix.Prefix.si_family == AF_INET) ? (uint8*) &row->DestinationPrefix.Prefix.Ipv4.sin_addr : (uint8*) &row->DestinationPrefix.Prefix.Ipv6.sin6_addr, row->DestinationPrefix.PrefixLength); print_ip_prefix(buf2, row->NextHop.si_family, (row->NextHop.si_family == AF_INET) ? (uint8*)&row->NextHop.Ipv4.sin_addr : (uint8*)&row->NextHop.Ipv6.sin6_addr, -1); if (error == NO_ERROR) { RINFO("Deleted Route %s => %s", buf1, buf2); return true; } RINFO("DeleteRoute failed (%d) %s => %s", error, buf1, buf2); return false; } static uint32 CidrToNetmaskV4(int cidr) { return cidr == 32 ? 0xffffffff : 0xffffffff << (32 - cidr); } struct RouteInfo { uint8 default_gw[16]; NET_LUID default_adapter; bool found_default_adapter; uint8 found_null_routes; }; static bool IsRouteOriginatingFromNullRoute(MIB_IPFORWARD_ROW2 *row) { if (!(row->InterfaceLuid.Info.IfType == 24 && row->Protocol == MIB_IPPROTO_NETMGMT && row->DestinationPrefix.PrefixLength == 1)) return false; if (row->NextHop.si_family == AF_INET) { return (row->NextHop.Ipv4.sin_addr.S_un.S_addr == 0); } else if (row->NextHop.si_family == AF_INET6) { static const uint32 nulladdr[4]; return memcmp(&row->NextHop.Ipv6.sin6_addr, nulladdr, 16) == 0; } return false; } static bool IsDestinationRouteEqualTo(MIB_IPFORWARD_ROW2 *row, const WgCidrAddr *addr) { if (addr->size == 32) { return row->DestinationPrefix.Prefix.si_family == AF_INET && row->DestinationPrefix.PrefixLength == addr->cidr && memcmp(&row->DestinationPrefix.Prefix.Ipv4.sin_addr, addr->addr, 4) == 0; } else if (addr->size == 128) { return row->DestinationPrefix.Prefix.si_family == AF_INET6 && row->DestinationPrefix.PrefixLength == addr->cidr && memcmp(&row->DestinationPrefix.Prefix.Ipv6.sin6_addr, addr->addr, 16) == 0; } else { return false; } } static bool IsDestinationRouteEqualToAny(MIB_IPFORWARD_ROW2 *row, const std::vector &addr) { for (const WgCidrAddr &x : addr) if (IsDestinationRouteEqualTo(row, &x)) return true; return false; } static void DeleteRouteOrPrintErr(MIB_IPFORWARD_ROW2 *row) { char buf1[kSizeOfAddress]; UINT32 r = DeleteIpForwardEntry2(row); if (r) RERROR("Unable to delete old route (%d): %s", r, print_ip_prefix(buf1, row->DestinationPrefix.Prefix.si_family, row->DestinationPrefix.Prefix.si_family == AF_INET ? (void*)&row->DestinationPrefix.Prefix.Ipv4.sin_addr : (void*)&row->DestinationPrefix.Prefix.Ipv6.sin6_addr, row->DestinationPrefix.PrefixLength)); } static bool GetDefaultRouteAndDeleteOldRoutes(int family, const NET_LUID *InterfaceLuid, bool keep_null_routes, const std::vector *old_endpoint_to_delete, RouteInfo *ri) { MIB_IPFORWARD_TABLE2 *table = NULL; assert(family == AF_INET || family == AF_INET6); ri->found_default_adapter = false; ri->found_null_routes = 0; if (GetIpForwardTable2(family, &table)) return false; DWORD rv = 0; DWORD gw_metric = 0xffffffff; for (unsigned i = 0; i < table->NumEntries; i++) { MIB_IPFORWARD_ROW2 *row = &table->Table[i]; if (InterfaceLuid && memcmp(&row->InterfaceLuid, InterfaceLuid, sizeof(NET_LUID)) == 0) { if (row->Protocol == MIB_IPPROTO_NETMGMT && !row->AutoconfigureAddress) DeleteRouteOrPrintErr(row); } else if (IsRouteOriginatingFromNullRoute(row)) { ri->found_null_routes++; if (!keep_null_routes) DeleteRouteOrPrintErr(row); } else if (row->DestinationPrefix.PrefixLength == 0 && row->Metric < gw_metric) { gw_metric = row->Metric; if (family == AF_INET) { memcpy(&ri->default_gw, &row->NextHop.Ipv4.sin_addr, 4); } else { memcpy(&ri->default_gw, &row->NextHop.Ipv6.sin6_addr, 16); } ri->default_adapter = row->InterfaceLuid; ri->found_default_adapter = true; } } if (old_endpoint_to_delete && ri->found_default_adapter) { for (unsigned i = 0; i < table->NumEntries; i++) { MIB_IPFORWARD_ROW2 *row = &table->Table[i]; if (row->Protocol == MIB_IPPROTO_NETMGMT && memcmp(&row->InterfaceLuid, &ri->default_adapter, sizeof(NET_LUID)) == 0) { if (IsDestinationRouteEqualToAny(row, *old_endpoint_to_delete)) DeleteRouteOrPrintErr(row); } } } FreeMibTable(table); return (rv == 0); } void FreePacketList(Packet *pp) { while (Packet *p = pp) { pp = Packet_NEXT(p); FreePacket(p); } } UdpSocketWin32::UdpSocketWin32(NetworkWin32 *network_win32) { network_ = network_win32; wqueue_end_ = &wqueue_; wqueue_ = NULL; socket_ = INVALID_SOCKET; socket_ipv6_ = INVALID_SOCKET; finished_reads_ = NULL; finished_reads_end_ = &finished_reads_; finished_reads_count_ = 0; max_read_ipv6_ = 0; num_reads_[0] = num_reads_[1] = 0; num_writes_ = 0; pending_writes_ = NULL; qsize1_ = 0; qsize2_ = 0; } UdpSocketWin32::~UdpSocketWin32() { closesocket(socket_); closesocket(socket_ipv6_); FreePacketList(wqueue_); } bool UdpSocketWin32::Configure(int listen_on_port) { // If attempting to initialize when the thread is already started, then stop // the thread, reinitialize, and start the thread. if (network_->thread_ != NULL) { network_->StopThread(); bool retcode = Configure(listen_on_port); network_->StartThread(); return retcode; } bool retval = false; SOCKET socket_ipv4 = INVALID_SOCKET, socket_ipv6 = INVALID_SOCKET; socket_ipv4 = WSASocket(AF_INET, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED); if (socket_ipv4 == INVALID_SOCKET) { RERROR("UdpSocketWin32::Initialize WSASocket failed"); goto fail; } if (!CreateIoCompletionPort((HANDLE)socket_ipv4, network_->completion_port_handle_, 0, 0)) { RERROR("UdpSocketWin32::Initialize CreateIoCompletionPort failed"); goto fail; } { sockaddr_in sin = {0}; sin.sin_family = AF_INET; sin.sin_port = htons(listen_on_port); if (bind(socket_ipv4, (struct sockaddr*)&sin, sizeof(sin)) != 0) { RERROR("UdpSocketWin32::Initialize bind failed"); goto fail; } } // Also open up a socket for ipv6 socket_ipv6 = WSASocket(AF_INET6, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED); if (socket_ipv6 != INVALID_SOCKET) { if (!CreateIoCompletionPort((HANDLE)socket_ipv6, network_->completion_port_handle_, 0, 0)) { RERROR("IPv6 Socket completion port failed."); closesocket(socket_ipv6); socket_ipv6 = INVALID_SOCKET; } else { sockaddr_in6 sin6 = {0}; sin6.sin6_family = AF_INET6; sin6.sin6_port = htons(listen_on_port); if (bind(socket_ipv6, (struct sockaddr*)&sin6, sizeof(sin6)) != 0) { RERROR("UdpSocketWin32::Initialize bind failed IPv6"); } } } else { RERROR("IPv6 Socket creation failed."); } std::swap(socket_ipv6_, socket_ipv6); std::swap(socket_, socket_ipv4); retval = true; max_read_ipv6_ = socket_ipv6 != INVALID_SOCKET ? 1 : 0; fail: if (socket_ipv4 != INVALID_SOCKET) closesocket(socket_ipv4); if (socket_ipv6 != INVALID_SOCKET) closesocket(socket_ipv6); return retval; } // Called on another thread to queue up a udp packet void UdpSocketWin32::WriteUdpPacket(Packet *packet) { if (qsize2_ - qsize1_ >= (unsigned)(packet->size < 576 ? MAX_BYTES_IN_UDP_OUT_QUEUE_SMALL : MAX_BYTES_IN_UDP_OUT_QUEUE)) { FreePacket(packet); return; } packet->queue_next = NULL; qsize2_ += packet->size; mutex_.Acquire(); Packet *was_empty = wqueue_; *wqueue_end_ = packet; wqueue_end_ = &Packet_NEXT(packet); mutex_.Release(); if (was_empty == NULL) network_->WakeUp(); } enum { kUdpGetQueuedCompletionStatusSize = kConcurrentWriteTap + kConcurrentReadTap + 1 }; #ifndef STATUS_PORT_UNREACHABLE #define STATUS_PORT_UNREACHABLE 0xC000023F #endif static inline bool IsIgnoredUdpError(DWORD err) { return err == WSAEMSGSIZE || err == WSAECONNRESET || err == WSAENETRESET || err == STATUS_PORT_UNREACHABLE; } void UdpSocketWin32::DoMoreReads() { // Listen with multiple ipv6 packets only if we ever sent an ipv6 packet. for (int i = num_reads_[IPV6]; i < max_read_ipv6_; i++) { Packet *p = network_->packet_pool().AllocPacketFromPool(); if (!p) break; restart_read_udp6: ClearOverlapped(&p->overlapped); WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data}; DWORD flags = 0; p->userdata = IPV6; p->sin_size = sizeof(p->addr.sin6); p->queue_cb = this; if (WSARecvFrom(socket_ipv6_, &wsabuf, 1, NULL, &flags, (struct sockaddr*)&p->addr, &p->sin_size, &p->overlapped, NULL) != 0) { DWORD err = WSAGetLastError(); if (err != WSA_IO_PENDING) { if (err == WSAEMSGSIZE || err == WSAECONNRESET || err == WSAENETRESET) goto restart_read_udp6; RERROR("UdpSocketWin32:WSARecvFrom failed 0x%X", err); FreePacket(p); break; } } num_reads_[IPV6]++; } // Initiate more reads, reusing the Packet structures in |finished_writes|. if (socket_ != INVALID_SOCKET) { for (int i = num_reads_[IPV4]; i < kConcurrentReadUdp; i++) { Packet *p = network_->packet_pool().AllocPacketFromPool(); if (!p) break; restart_read_udp: ClearOverlapped(&p->overlapped); WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data}; DWORD flags = 0; p->userdata = IPV4; p->sin_size = sizeof(p->addr.sin); p->queue_cb = this; if (WSARecvFrom(socket_, &wsabuf, 1, NULL, &flags, (struct sockaddr*)&p->addr, &p->sin_size, &p->overlapped, NULL) != 0) { DWORD err = WSAGetLastError(); if (err != WSA_IO_PENDING) { if (err == WSAEMSGSIZE || err == WSAECONNRESET || err == WSAENETRESET) goto restart_read_udp; RERROR("UdpSocketWin32:WSARecvFrom failed 0x%X", err); FreePacket(p); break; } } num_reads_[IPV4]++; } } } void UdpSocketWin32::ProcessPackets() { // Push all the finished reads to the packet handler if (finished_reads_ != NULL) { packet_handler_->PostPackets(finished_reads_, finished_reads_end_, finished_reads_count_); finished_reads_ = NULL; finished_reads_end_ = &finished_reads_; finished_reads_count_ = 0; } } void UdpSocketWin32::DoMoreWrites() { Packet *pending_writes = pending_writes_; // Initiate more writes from |wqueue_| while (num_writes_ < kConcurrentWriteUdp) { // Refill from queue if empty, avoid taking the mutex if it looks empty if (!pending_writes) { if (!wqueue_) break; mutex_.Acquire(); pending_writes = wqueue_; wqueue_end_ = &wqueue_; wqueue_ = NULL; mutex_.Release(); if (!pending_writes) break; } qsize1_ += pending_writes->size; // Then issue writes Packet *p = pending_writes; pending_writes = Packet_NEXT(p); ClearOverlapped(&p->overlapped); p->userdata = 2; p->queue_cb = this; WSABUF wsabuf = {(ULONG)p->size, (char*)p->data}; int rv; if (p->addr.sin.sin_family == AF_INET) { rv = WSASendTo(socket_, &wsabuf, 1, NULL, 0, (struct sockaddr*)&p->addr.sin, sizeof(p->addr.sin), &p->overlapped, NULL); } else { if (socket_ipv6_ == INVALID_SOCKET) { RERROR("UdpSocketWin32: unavailable ipv6 socket"); FreePacket(p); continue; } max_read_ipv6_ = kConcurrentReadUdp; rv = WSASendTo(socket_ipv6_, &wsabuf, 1, NULL, 0, (struct sockaddr*)&p->addr.sin6, sizeof(p->addr.sin6), &p->overlapped, NULL); } if (rv != 0) { DWORD err = WSAGetLastError(); if (err != ERROR_IO_PENDING) { RERROR("UdpSocketWin32: WSASendTo failed 0x%X", err); FreePacket(p); continue; } } num_writes_++; } pending_writes_ = pending_writes; } void UdpSocketWin32::CancelAllIO() { CancelIo((HANDLE)socket_); CancelIo((HANDLE)socket_ipv6_); FreePacketList(pending_writes_); } bool UdpSocketWin32::HasOutstandingIO() { return (num_reads_[IPV4] + num_reads_[IPV6] + num_writes_) != 0; } void UdpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) { Packet *p = static_cast(qi); if (p->userdata < 2) { num_reads_[p->userdata]--; if ((DWORD)p->overlapped.Internal != 0) { if (!IsIgnoredUdpError((DWORD)p->overlapped.Internal)) RERROR("UdpSocketWin32::Read error 0x%X", (DWORD)p->overlapped.Internal); network_->packet_pool().FreePacketToPool(p); } else { // Remember all the finished packets and queue them up to the next thread once we've // collected them all. p->size = (int)p->overlapped.InternalHigh; p->protocol = kPacketProtocolUdp; p->queue_cb = packet_handler_->udp_queue(); p->queue_next = NULL; *finished_reads_end_ = p; finished_reads_end_ = &Packet_NEXT(p); finished_reads_count_++; } } else { num_writes_--; if ((DWORD)p->overlapped.Internal != 0) RERROR("UdpSocketWin32::Write error 0x%X", (DWORD)p->overlapped.Internal); network_->packet_pool().FreePacketToPool(p); } } void UdpSocketWin32::OnQueuedItemDelete(QueuedItem *qi) { Packet *p = static_cast(qi); if (p->userdata < 2) { num_reads_[p->userdata]--; } else { num_writes_--; } network_->packet_pool().FreePacketToPool(p); } void UdpSocketWin32::DoIO() { DoMoreWrites(); ProcessPackets(); DoMoreReads(); } //////////////////////////////////////////////////////////////////////////////////////////////////////// NetworkWin32::NetworkWin32() : udp_socket_(this), tcp_socket_queue_(this) { exit_thread_ = false; thread_ = NULL; completion_port_handle_ = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, NULL, 0); tcp_socket_ = NULL; } NetworkWin32::~NetworkWin32() { assert(thread_ == NULL); for (TcpSocketWin32 *socket = tcp_socket_; socket; ) delete exch(socket, socket->next_); CloseHandle(completion_port_handle_); } DWORD WINAPI NetworkWin32::NetworkThread(void *x) { NetworkWin32 *net = (NetworkWin32 *)x; net->ThreadMain(); return 0; } void NetworkWin32::ThreadMain() { OVERLAPPED_ENTRY entries[kUdpGetQueuedCompletionStatusSize]; while (!exit_thread_) { // TODO: In the future, don't process every socket here, only // those sockets that requested it. udp_socket_.DoIO(); for (TcpSocketWin32 *tcp = tcp_socket_; tcp;) exch(tcp, tcp->next_)->DoIO(); packet_pool_.FreeSomePackets(); ULONG num_entries = 0; if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, kUdpGetQueuedCompletionStatusSize, &num_entries, INFINITE, FALSE)) { RINFO("GetQueuedCompletionStatusEx failed."); break; } for (ULONG i = 0; i < num_entries; i++) { if (entries[i].lpOverlapped) { QueuedItem *w = (QueuedItem*)((byte*)entries[i].lpOverlapped - offsetof(QueuedItem, overlapped)); w->queue_cb->OnQueuedItemEvent(w, 0); } } } udp_socket_.CancelAllIO(); for (TcpSocketWin32 *tcp = tcp_socket_; tcp; tcp = tcp->next_) tcp->CancelAllIO(); while (HasOutstandingIO()) { ULONG num_entries = 0; if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, kUdpGetQueuedCompletionStatusSize, &num_entries, INFINITE, FALSE)) { RINFO("GetQueuedCompletionStatusEx failed."); break; } for (ULONG i = 0; i < num_entries; i++) { if (entries[i].lpOverlapped) { QueuedItem *w = (QueuedItem*)((byte*)entries[i].lpOverlapped - offsetof(QueuedItem, overlapped)); w->queue_cb->OnQueuedItemDelete(w); } } } } bool NetworkWin32::HasOutstandingIO() { if (udp_socket_.HasOutstandingIO()) return true; for (TcpSocketWin32 *tcp = tcp_socket_; tcp; tcp = tcp->next_) if (tcp->HasOutstandingIO()) return true; return false; } void NetworkWin32::StartThread() { assert(completion_port_handle_); DWORD thread_id; thread_ = CreateThread(NULL, 0, &NetworkThread, this, 0, &thread_id); SetThreadPriority(thread_, ABOVE_NORMAL_PRIORITY_CLASS); } void NetworkWin32::StopThread() { if (thread_ != NULL) { exit_thread_ = true; g_fail_malloc_flag = true; PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL); WaitForSingleObject(thread_, INFINITE); CloseHandle(thread_); thread_ = NULL; exit_thread_ = false; g_fail_malloc_flag = false; } } void NetworkWin32::WakeUp() { PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL); } void NetworkWin32::PostQueuedItem(QueuedItem *item) { PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, &item->overlapped); } bool NetworkWin32::Configure(int listen_port, int listen_port_tcp) { if (listen_port_tcp) RERROR("ListenPortTCP not supported in this version"); return udp_socket_.Configure(listen_port); } // Called from tunsafe thread void NetworkWin32::WriteUdpPacket(Packet *packet) { if (packet->protocol & kPacketProtocolUdp) { udp_socket_.WriteUdpPacket(packet); } else { tcp_socket_queue_.WritePacket(packet); } } ///////////////////////////////////////////////////////////////////////// PacketProcessor::PacketProcessor() { event_ = CreateEvent(NULL, FALSE, FALSE, NULL); last_ptr_ = &first_; first_ = NULL; exit_code_ = 0; timer_interrupt_ = false; packets_in_queue_ = 0; need_notify_ = 0; } PacketProcessor::~PacketProcessor() { first_ = NULL; last_ptr_ = &first_; CloseHandle(event_); } void CALLBACK PacketProcessor::ThreadPoolTimerCallback(PTP_CALLBACK_INSTANCE iTimerInstance, PVOID pContext, PTP_TIMER) { PacketProcessor *th = (PacketProcessor *)pContext; th->mutex_.Acquire(); th->timer_interrupt_ = true; if (th->need_notify_) { th->need_notify_ = 0; th->mutex_.Release(); SetEvent(th->event_); return; } th->mutex_.Release(); } void PacketProcessor::Reset() { QueuedItem *packet; packet = first_; first_ = NULL; exit_code_ = 0; last_ptr_ = &first_; timer_interrupt_ = false; while (packet) { QueuedItem *next = packet->queue_next; packet->queue_cb->OnQueuedItemDelete(packet); packet = next; } } int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) { int free_packets_ctr = 0; int overload = 0; int exit_code; QueuedItem *packet; PTP_TIMER threadpool_timer; QueueContext queue_context = {wg, backend}; threadpool_timer = CreateThreadpoolTimer(&ThreadPoolTimerCallback, this, NULL); static const int64 duetime = -10000000; // the unit is 100ns SetThreadpoolTimer(threadpool_timer, (FILETIME*)&duetime, 1000, 1000); mutex_.Acquire(); while (!(exit_code = exit_code_)) { FreeAllPackets(); if (timer_interrupt_) { timer_interrupt_ = false; need_notify_ = 0; mutex_.Release(); wg->SecondLoop(); backend->stats_mutex_.Acquire(); backend->stats_ = wg->GetStats(); float data[2] = { // unit is megabits/second backend->stats_.tun_bytes_in_per_second * (1.0f / 125000), backend->stats_.tun_bytes_out_per_second * (1.0f / 125000), }; backend->stats_collector_.AddSamples(data); backend->stats_mutex_.Release(); backend->delegate_->OnGraphAvailable(); backend->PushStats(); // Conserve memory every 10s if (free_packets_ctr++ == 10) { free_packets_ctr = 0; FreeAllPackets(); } if (overload) overload -= 1; } else if ((packet = first_) == NULL) { need_notify_ = 1; mutex_.Release(); WaitForSingleObject(event_, INFINITE); } else { // Steal the whole work queue first_ = NULL; last_ptr_ = &first_; int packets_in_queue = packets_in_queue_; packets_in_queue_ = 0; need_notify_ = 0; mutex_.Release(); if (packets_in_queue >= 1024) overload = 2; queue_context.overload = (overload != 0); do { QueuedItem *next = packet->queue_next; packet->queue_cb->OnQueuedItemEvent(packet, (uintptr_t)&queue_context); packet = next; } while (packet); } wg->RunAllMainThreadScheduled(); mutex_.Acquire(); } exit_code_ = 0; mutex_.Release(); SetThreadpoolTimer(threadpool_timer, nullptr, 0, 0); WaitForThreadpoolTimerCallbacks(threadpool_timer, true); CloseThreadpoolTimer(threadpool_timer); return exit_code; } void PacketProcessorTunCb::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) { PacketProcessor::QueueContext *context = (PacketProcessor::QueueContext *)extra; context->wg->HandleTunPacket(static_cast(qi)); } void PacketProcessorTunCb::OnQueuedItemDelete(QueuedItem *qi) { FreePacket(static_cast(qi)); } void PacketProcessorUdpCb::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) { PacketProcessor::QueueContext *context = (PacketProcessor::QueueContext *)extra; context->wg->HandleUdpPacket(static_cast(qi), context->overload); } void PacketProcessorUdpCb::OnQueuedItemDelete(QueuedItem *qi) { FreePacket(static_cast(qi)); } void PacketProcessor::PostExit(int exit_code) { mutex_.Acquire(); // Avoid race condition where mode_tun_failed is set during thread exit. if (exit_code_ != TunsafeBackendWin32::MODE_RESTART && exit_code_ != TunsafeBackendWin32::MODE_EXIT) exit_code_ = exit_code; mutex_.Release(); SetEvent(event_); } void PacketProcessor::PostPackets(Packet *first, Packet **end, int count) { mutex_.Acquire(); if (packets_in_queue_ >= HARD_MAXIMUM_QUEUE_SIZE) { mutex_.Release(); FreePackets(first, end, count); return; } assert(first != NULL); assert(first_ || last_ptr_ == &first_); packets_in_queue_ += count; *last_ptr_ = first; last_ptr_ = (QueuedItem**)end; assert(first_ || last_ptr_ == &first_); if (need_notify_) { need_notify_ = 0; mutex_.Release(); SetEvent(event_); return; } mutex_.Release(); } void PacketProcessor::ForcePost(QueuedItem *item) { item->queue_next = NULL; mutex_.Acquire(); packets_in_queue_ += 1; *last_ptr_ = item; last_ptr_ = &item->queue_next; if (need_notify_) { need_notify_ = 0; mutex_.Release(); SetEvent(event_); return; } mutex_.Release(); } bool GetNetLuidFromGuid(const char *adapter_guid, NET_LUID *luid) { char buffer[64]; UUID uuid; size_t len = strlen(adapter_guid); if (adapter_guid[0] != '{' || adapter_guid[len - 1] != '}' || len >= 64) return false; buffer[len - 2] = 0; memcpy(buffer, adapter_guid + 1, len - 2); RPC_STATUS status = UuidFromStringA((RPC_CSTR)buffer, &uuid); if (status != 0) return false; return ConvertInterfaceGuidToLuid((GUID*)&uuid, luid) == 0; } DWORD SetMtuOnNetworkAdapter(NET_LUID *InterfaceLuid, ADDRESS_FAMILY family, int new_mtu) { MIB_IPINTERFACE_ROW row; DWORD err; InitializeIpInterfaceEntry(&row); row.Family = family; row.InterfaceLuid = *InterfaceLuid; if ((err = GetIpInterfaceEntry(&row)) == 0) { row.NlMtu = new_mtu; if (row.Family == AF_INET) row.SitePrefixLength = 0; err = SetIpInterfaceEntry(&row); } return err; } DWORD SetMetricOnNetworkAdapter(NET_LUID *InterfaceLuid, ADDRESS_FAMILY family, int new_metric, int *old_metric) { MIB_IPINTERFACE_ROW row; DWORD err; if (old_metric) *old_metric = kMetricNone; InitializeIpInterfaceEntry(&row); row.Family = family; row.InterfaceLuid = *InterfaceLuid; if ((err = GetIpInterfaceEntry(&row)) == 0) { if (old_metric) *old_metric = row.UseAutomaticMetric ? kMetricAutomatic : row.Metric; row.Metric = new_metric; row.UseAutomaticMetric = (new_metric == kMetricAutomatic); if (row.Family == AF_INET) row.SitePrefixLength = 0; err = SetIpInterfaceEntry(&row); } return err; } static const char *PrintIPV6(const uint8 new_address[16]) { sockaddr_in6 sin6 = {0}; static char buf[100]; // cast to void* to work on VS2015 if (!inet_ntop(PF_INET6, (void*)new_address, buf, 100)) memcpy(buf, "unknown", 8); return buf; } static void AssignIpv6Address(const void *new_address, int new_cidr, WgCidrAddr *target) { target->size = 128; target->cidr = new_cidr; memcpy(target->addr, new_address, 16); } static int IsIpv6AddressInList(const std::vector &addresses, const void *ipv6_addr, int cidr) { int i = 0; for (auto it = addresses.begin(); it != addresses.end(); ++it, i++) { if (it->size == 128 && it->cidr == cidr && memcmp(it->addr, ipv6_addr, 16) == 0) return i; } return -1; } // Set new_cidr to 0 to clear it. static bool SetIPV6AddressOnInterface(NET_LUID *InterfaceLuid, const std::vector &addresses, std::vector *old_address) { NETIO_STATUS Status; PMIB_UNICASTIPADDRESS_TABLE table = NULL; if (old_address) old_address->clear(); Status = GetUnicastIpAddressTable(AF_INET6, &table); if (Status != 0) { RERROR("GetUnicastAddressTable Failed. Error %d\n", Status); return false; } uint64 matching_addr = 0; for (int i = 0; i < (int)table->NumEntries; i++) { MIB_UNICASTIPADDRESS_ROW *row = &table->Table[i]; if (!memcmp(&row->InterfaceLuid, InterfaceLuid, sizeof(NET_LUID))) { if (row->PrefixOrigin == 1 && row->SuffixOrigin == 1) { if (old_address) { WgCidrAddr tmp; AssignIpv6Address(&row->Address.Ipv6.sin6_addr, row->OnLinkPrefixLength, &tmp); old_address->push_back(tmp); } int idx = IsIpv6AddressInList(addresses, &row->Address.Ipv6.sin6_addr, row->OnLinkPrefixLength); if (idx >= 0 && idx < 64) { matching_addr |= (uint64)1 << idx; RINFO("Using IPv6 address: %s/%d", PrintIPV6((uint8*)&row->Address.Ipv6.sin6_addr), row->OnLinkPrefixLength); continue; } Status = DeleteUnicastIpAddressEntry(row); if (Status) RERROR("Error %d deleting IPv6 address: %s/%d", Status, PrintIPV6((uint8*)&row->Address.Ipv6.sin6_addr), row->OnLinkPrefixLength); else RINFO("Deleted IPv6 address: %s/%d", PrintIPV6((uint8*)&row->Address.Ipv6.sin6_addr), row->OnLinkPrefixLength); } } } FreeMibTable(table); // Add all ipv6 addresses that were not already set. bool success = true; int i = 0; for (auto it = addresses.begin(); it != addresses.end(); ++it, i++) { // skip it because of wrong type or already set? if (it->size != 128 || i < 64 && (matching_addr & ((uint64)1 << i))) continue; MIB_UNICASTIPADDRESS_ROW Row; InitializeUnicastIpAddressEntry(&Row); Row.OnLinkPrefixLength = it->cidr; Row.Address.si_family = AF_INET6; memcpy(&Row.Address.Ipv6.sin6_addr, it->addr, 16); Row.InterfaceLuid = *InterfaceLuid; Status = CreateUnicastIpAddressEntry(&Row); if (Status != 0) { RERROR("Error %d setting IPv6 address: %s/%d", Status, PrintIPV6(it->addr), it->cidr); success = false; } else { RINFO("Added IPV6 Address: %s/%d", PrintIPV6(it->addr), it->cidr); } } return success; } static bool SetIPV6DnsOnInterface(NET_LUID *InterfaceLuid, const IpAddr *new_address, size_t new_address_size) { char buf[128]; char ipv6[128]; bool isfirst = true; NET_IFINDEX InterfaceIndex; if (ConvertInterfaceLuidToIndex(InterfaceLuid, &InterfaceIndex)) return false; if (new_address_size) { for (size_t i = 0; i < new_address_size; i++) { if (new_address[i].sin.sin_family != AF_INET6) continue; if (!inet_ntop(AF_INET6, (void*)&new_address[i].sin6.sin6_addr, ipv6, sizeof(ipv6))) return false; if (isfirst) { isfirst = false; snprintf(buf, sizeof(buf), "netsh interface ipv6 set dnsservers name=%d static %s validate=no", InterfaceIndex, ipv6); } else { snprintf(buf, sizeof(buf), "netsh interface ipv6 add dnsservers name=%d %s validate=no", InterfaceIndex, ipv6); } if (!RunNetsh(buf)) return false; } return true; } else { snprintf(buf, sizeof(buf), "netsh interface ipv6 delete dns name=%d all", InterfaceIndex); return RunNetsh(buf); } } static uint32 ComputeIpv4DefaultRoute(uint32 ip, uint32 netmask) { uint32 default_route_v4 = (ip & netmask) | 1; if (default_route_v4 == ip) default_route_v4++; return default_route_v4; } static void ComputeIpv6DefaultRoute(const uint8 *ipv6_address, uint8 ipv6_cidr, uint8 *default_route_v6) { memcpy(default_route_v6, ipv6_address, 16); // clear the last bits of the ipv6 address to match the cidr. size_t n = (ipv6_cidr + 7) >> 3; memset(&default_route_v6[n], 0, 16 - n); if (n == 0) return; // adjust the final byte default_route_v6[n - 1] &= ~(0xff >> (ipv6_cidr & 7)); // set the very last byte to something default_route_v6[15] |= 1; // ensure it doesn't collide if (memcmp(default_route_v6, ipv6_address, 16) == 0) default_route_v6[15] ^= 3; } static bool AddMultipleCatchallRoutes(int inet, int bits, const uint8 *target, const NET_LUID &luid, std::vector *undo_array) { uint8 tmp[16] = {0}; bool success = true; for (int i = 0; i < (1 << bits); i++) { tmp[0] = i << (8 - bits); success &= AddRoute(inet, tmp, bits, target, &luid, undo_array); } return success; } TunWin32Adapter::TunWin32Adapter(DnsBlocker *dns_blocker, const char guid[ADAPTER_GUID_SIZE]) { handle_ = NULL; dns_blocker_ = dns_blocker; old_ipv6_metric_ = kMetricNone; old_ipv4_metric_ = kMetricNone; has_dns6_setting_ = false; guid_[0] = 0; if (guid) memcpy(guid_, guid, ADAPTER_GUID_SIZE); } TunWin32Adapter::~TunWin32Adapter() { } bool TunWin32Adapter::OpenAdapter(TunsafeBackendWin32 *backend, DWORD open_flags) { ULONG info[3]; DWORD len; assert(handle_ == NULL); backend_ = backend; handle_ = OpenTunAdapter(guid_, backend, open_flags); if (handle_ != NULL) { memset(info, 0, sizeof(info)); if (DeviceIoControl(handle_, TAP_IOCTL_GET_VERSION, &info, sizeof(info), &info, sizeof(info), &len, NULL)) { RINFO("TAP Driver Version %d.%d %s", (int)info[0], (int)info[1], (info[2] ? "(DEBUG)" : "")); } if (info[0] < 9 || info[0] == 9 && info[1] <= 8) { RERROR("TAP is too old. Go to https://tunsafe.com/download to upgrade the driver"); CloseHandle(handle_); handle_ = NULL; } } return (handle_ != NULL); } bool TunWin32Adapter::ConfigureAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out) { DWORD len, err; out->enable_neighbor_discovery_spoofing = false; if (!RunPrePostCommand(config.pre_post_commands.pre_up)) { RERROR("Pre command failed!"); return false; } pre_down_ = std::move(config.pre_post_commands.pre_down); post_down_ = std::move(config.pre_post_commands.post_down); const WgCidrAddr *ipv4_addr = NULL; const WgCidrAddr *ipv6_addr = NULL; for (auto it = config.addresses.begin(); it != config.addresses.end(); ++it) { if (it->size == 32 && ipv4_addr == NULL) ipv4_addr = &*it; else if (it->size == 128 && ipv6_addr == NULL) ipv6_addr = &*it; } if (ipv4_addr == NULL) { RERROR("The TUN adapter on Windows requires an IPv4 address"); return false; } uint32 ipv4_netmask = CidrToNetmaskV4(ipv4_addr->cidr); uint32 ipv4_ip = ReadBE32(ipv4_addr->addr); // Set TAP-Windows TUN subnet mode if (1) { uint32 v[3]; v[0] = htonl(ipv4_ip); v[1] = htonl(ipv4_ip & ipv4_netmask); v[2] = htonl(ipv4_netmask); if (!DeviceIoControl(handle_, TAP_IOCTL_CONFIG_TUN, v, sizeof(v), v, sizeof(v), &len, NULL)) { RERROR("DeviceIoControl(TAP_IOCTL_CONFIG_TUN) failed"); return false; } } // Set DHCP IP/netmask { uint32 v[4]; v[0] = htonl(ipv4_ip); v[1] = htonl(ipv4_netmask); v[2] = htonl((ipv4_ip | ~ipv4_netmask) - 1); // x.x.x.254 v[3] = 31536000; // One year if (!DeviceIoControl(handle_, TAP_IOCTL_CONFIG_DHCP_MASQ, v, sizeof(v), v, sizeof(v), &len, NULL)) { RERROR("DeviceIoControl(TAP_IOCTL_CONFIG_DHCP_MASQ) failed"); return false; } } // Extract and set IPv4 DNS servers through DHCP { enum { kMaxDnsServers = 4 }; uint8 dhcp_options[2 + kMaxDnsServers * 4]; // max 4 dns servers uint32 num_dns = 0; for (auto it = config.dns.begin(); it != config.dns.end(); ++it) { if (it->sin.sin_family != AF_INET) continue; memcpy(&dhcp_options[2 + num_dns * 4], &it->sin.sin_addr, 4); if (++num_dns == kMaxDnsServers) break; } if (num_dns != 0) { dhcp_options[0] = 6; dhcp_options[1] = (uint8)(num_dns * 4); DWORD dhcp_options_size = (DWORD)(num_dns * 4 + 2); byte output[10]; if (!DeviceIoControl(handle_, TAP_IOCTL_CONFIG_DHCP_SET_OPT, (void*)dhcp_options, dhcp_options_size, output, sizeof(output), &len, NULL)) { RERROR("DeviceIoControl(TAP_IOCTL_CONFIG_DHCP_SET_OPT) failed"); return false; } } } // Get device MAC address if (!DeviceIoControl(handle_, TAP_IOCTL_GET_MAC, mac_adress_, 6, mac_adress_, sizeof(mac_adress_), &len, NULL)) { RERROR("DeviceIoControl(TAP_IOCTL_GET_MAC) failed"); } else { out->enable_neighbor_discovery_spoofing = true; memcpy(out->neighbor_discovery_spoofing_mac, mac_adress_, sizeof(out->neighbor_discovery_spoofing_mac)); } // Set driver media status to 'connected' ULONG status = TRUE; if (!DeviceIoControl(handle_, TAP_IOCTL_SET_MEDIA_STATUS, &status, sizeof(status), &status, sizeof(status), &len, NULL)) { RERROR("DeviceIoControl(TAP_IOCTL_SET_MEDIA_STATUS) failed"); return false; } bool has_interface_luid = GetNetLuidFromGuid(guid_, &interface_luid_); if (!has_interface_luid) { RERROR("Unable to determine interface luid for %s.", guid_); return false; } if (config.mtu) { err = SetMtuOnNetworkAdapter(&interface_luid_, AF_INET, config.mtu); if (err) RERROR("SetMtuOnNetworkAdapter IPv4 failed: %d", err); if (ipv6_addr) { err = SetMtuOnNetworkAdapter(&interface_luid_, AF_INET6, config.mtu); if (err) RERROR("SetMtuOnNetworkAdapter IPv6 failed: %d", err); } } has_dns6_setting_ = false; if (ipv6_addr) { SetIPV6AddressOnInterface(&interface_luid_, config.addresses, &old_ipv6_address_); // Check if we have at least one ipv6 setting for (auto it = config.dns.begin(); it != config.dns.end(); ++it) { if (it->sin.sin_family == AF_INET6) { has_dns6_setting_ = true; break; } } if (has_dns6_setting_ && !SetIPV6DnsOnInterface(&interface_luid_, config.dns.data(), config.dns.size())) { RERROR("SetIPV6DnsOnInterface: failed"); } } if (config.dns.size() && config.block_dns_on_adapters) { RINFO("Blocking standard DNS on all adapters"); dns_blocker_->BlockDnsExceptOnAdapter(interface_luid_, has_dns6_setting_); err = SetMetricOnNetworkAdapter(&interface_luid_, AF_INET, 2, &old_ipv4_metric_); if (err) RERROR("SetMetricOnNetworkAdapter IPv4 failed: %d", err); if (ipv6_addr) { err = SetMetricOnNetworkAdapter(&interface_luid_, AF_INET6, 2, &old_ipv6_metric_); if (err) RERROR("SetMetricOnNetworkAdapter IPv6 failed: %d", err); } } else { dns_blocker_->RestoreDns(); } g_killswitch_currconn = config.internet_blocking; uint8 ibs = (g_killswitch_currconn == kBlockInternet_Default) ? g_killswitch_want : g_killswitch_currconn; bool block_all_traffic_route = (ibs & kBlockInternet_Route) != 0; RouteInfo ri, ri6; // Delete any current /1 default routes and read some stuff from the routing table. if (!GetDefaultRouteAndDeleteOldRoutes(AF_INET, &interface_luid_, block_all_traffic_route, &config.excluded_routes, &ri)) { RERROR("Unable to read old default gateway and delete old default routes."); return false; } // Delete any current /1 default routes and read some stuff from the routing table. if (!GetDefaultRouteAndDeleteOldRoutes(AF_INET6, &interface_luid_, block_all_traffic_route, &config.excluded_routes, &ri6)) { RERROR("Unable to read old default gateway and delete old default routes for IPv6."); } if (block_all_traffic_route) { RINFO("Blocking all regular Internet traffic using routing rules"); NET_LUID localhost_luid; if (ConvertInterfaceIndexToLuid(1, &localhost_luid) || localhost_luid.Info.IfType != 24) { RERROR("Unable to get localhost luid - while adding route based blocking."); } else { g_killswitch_curr |= kBlockInternet_Route; uint32 dst[4] = {0}; if (!AddMultipleCatchallRoutes(AF_INET, 1, (uint8*)&dst, localhost_luid, NULL)) { RERROR("Unable to add routes for route based blocking."); DeactivateKillSwitch(0); return false; } if (!AddMultipleCatchallRoutes(AF_INET6, 1, (uint8*)&dst, localhost_luid, NULL)) { RERROR("Unable to add IPv6 routes for route based blocking."); DeactivateKillSwitch(0); return false; } } } if (ibs & kBlockInternet_Firewall) { RINFO("Blocking all regular Internet traffic using firewall rules"); g_killswitch_curr |= kBlockInternet_Firewall; if (!AddKillSwitchFirewall(interface_luid_, true, (ibs & kBlockInternet_AllowLocalNetworks) != 0)) { RERROR("Unable to activate firewall based kill switch"); DeactivateKillSwitch(0); return false; } } DeactivateKillSwitch(ibs); uint8 default_route_v4[4]; uint8 default_route_v6[16]; WriteBE32(default_route_v4, ComputeIpv4DefaultRoute(ipv4_ip, ipv4_netmask)); if (ipv6_addr) ComputeIpv6DefaultRoute(ipv6_addr->addr, ipv6_addr->cidr, default_route_v6); // Add all the routes that should go through the VPN for (auto it = config.included_routes.begin(); it != config.included_routes.end(); ++it) { if (it->cidr == 0) { // /0 gets changed to two /1 routes, to avoid overwriting the system's default route if (it->size == 32) { if (!AddMultipleCatchallRoutes(AF_INET, block_all_traffic_route ? 2 : 1, default_route_v4, interface_luid_, &routes_to_undo_)) RERROR("Unable to add new default ipv4 route."); } else if (it->size == 128 && ipv6_addr) { if (!AddMultipleCatchallRoutes(AF_INET6, block_all_traffic_route ? 2 : 1, default_route_v6, interface_luid_, &routes_to_undo_)) RERROR("Unable to add new default ipv6 route."); } continue; } // Avoid adding a route if it's a subset of the address if (IsWgCidrAddrSubsetOfAny(*it, config.addresses)) continue; if (it->size == 32) { AddRoute(AF_INET, it->addr, it->cidr, default_route_v4, &interface_luid_, &routes_to_undo_); } else if (it->size == 128 && ipv6_addr) { AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6, &interface_luid_, &routes_to_undo_); } } // Add all the routes that should bypass vpn int warned = 0; for (auto it = config.excluded_routes.begin(); it != config.excluded_routes.end(); ++it) { if (it->size == 32) { if (ri.found_default_adapter) { AddRoute(AF_INET, it->addr, it->cidr, ri.default_gw, &ri.default_adapter, &routes_to_undo_); } else if (!(warned & 1)) { warned |= 1; RERROR("Unable to read old ipv4 default gateway"); } } else if (it->size == 128) { if (ri6.found_default_adapter) { AddRoute(AF_INET6, it->addr, it->cidr, ri6.default_gw, &ri6.default_adapter, &routes_to_undo_); } else if (!(warned & 2)) { warned |= 2; RERROR("Unable to read old ipv6 default gateway"); } } } NET_IFINDEX InterfaceIndex; if (ConvertInterfaceLuidToIndex(&interface_luid_, &InterfaceIndex)) { RERROR("Unable to get index of adapter"); return false; } if ((err = FlushIpNetTable2(AF_INET, InterfaceIndex)) != NO_ERROR) { RERROR("FlushIpNetTable failed: 0x%X", err); return false; } if (ipv6_addr != NULL) { if ((err = FlushIpNetTable2(AF_INET6, InterfaceIndex)) != NO_ERROR) { RERROR("FlushIpNetTable failed: 0x%X", err); return false; } } RunPrePostCommand(config.pre_post_commands.post_up); return true; } void TunWin32Adapter::CloseAdapter(bool is_restart) { RunPrePostCommand(pre_down_); if (handle_ != NULL) { ULONG status = FALSE; DWORD len; DeviceIoControl(handle_, TAP_IOCTL_SET_MEDIA_STATUS, &status, sizeof(status), &status, sizeof(status), &len, NULL); CloseHandle(handle_); handle_ = NULL; TunAdaptersInUse::GetInstance()->Release(backend_); } if (old_ipv6_address_.size()) SetIPV6AddressOnInterface(&interface_luid_, old_ipv6_address_, NULL); if (old_ipv4_metric_ != kMetricNone) SetMetricOnNetworkAdapter(&interface_luid_, AF_INET, old_ipv4_metric_, NULL); if (old_ipv6_metric_ != kMetricNone) SetMetricOnNetworkAdapter(&interface_luid_, AF_INET6, old_ipv6_metric_, NULL); if (has_dns6_setting_) SetIPV6DnsOnInterface(&interface_luid_, NULL, 0); old_ipv4_metric_ = old_ipv6_metric_ = -1; old_ipv6_address_.clear(); has_dns6_setting_ = false; for (auto it = routes_to_undo_.begin(); it != routes_to_undo_.end(); ++it) DeleteRoute(&*it); routes_to_undo_.clear(); if (!is_restart && dns_blocker_) dns_blocker_->RestoreDns(); RunPrePostCommand(post_down_); pre_down_.clear(); post_down_.clear(); } static bool RunOneCommand(const std::string &cmd) { std::string command = "cmd.exe /C " + cmd; STARTUPINFOA si = {0}; PROCESS_INFORMATION pi = {0}; HANDLE hstdout_wr = NULL, hstdout_rd = NULL; HANDLE hstdin_wr = NULL, hstdin_rd = NULL; bool result = false; SECURITY_ATTRIBUTES saAttr; saAttr.nLength = sizeof(SECURITY_ATTRIBUTES); saAttr.bInheritHandle = TRUE; saAttr.lpSecurityDescriptor = NULL; if (!CreatePipe(&hstdout_rd, &hstdout_wr, &saAttr, 0) || !CreatePipe(&hstdin_rd, &hstdin_wr, &saAttr, 0) || !SetHandleInformation(hstdout_rd, HANDLE_FLAG_INHERIT, 0) || !SetHandleInformation(hstdin_wr, HANDLE_FLAG_INHERIT, 0)) { goto out; } CloseHandle(hstdin_wr); hstdin_wr = NULL; si.cb = sizeof(si); si.dwFlags = STARTF_USESTDHANDLES; si.hStdError = hstdout_wr; si.hStdOutput = hstdout_wr; si.hStdInput = hstdin_rd; RINFO("Run: %s", cmd.c_str()); if (CreateProcessA(NULL, &command[0], NULL, NULL, TRUE, CREATE_NO_WINDOW, NULL, NULL, &si, &pi)) { DWORD exit_code = -1; char buf[1024]; DWORD bufend = 0, bufstart = 0; CloseHandle(hstdout_wr); hstdout_wr = NULL; for (;;) { DWORD bytes_read = 0; bool foundeof = (!ReadFile(hstdout_rd, buf + bufend, sizeof(buf) - bufend, &bytes_read, NULL) || bytes_read == 0); bufend += bytes_read; for(;;) { char *nl = (char*)memchr(buf + bufstart, '\n', bufend - bufstart); if (!nl) break; char *st = buf + bufstart; char *nl2 = nl; if (nl != buf + bufstart && nl[-1] == '\r') nl--; bufstart = (DWORD)(nl2 - buf + 1); RINFO("%.*s", nl - st, st); } if (bufend - bufstart == sizeof(buf) || foundeof) { if (bufend - bufstart) RINFO("%.*s", buf + bufstart, bufend - bufstart); bufstart = bufend = 0; } if (foundeof) break; if (bufstart) { bufend -= bufstart; memmove(buf, buf + bufstart, bufend); bufstart = 0; } } WaitForSingleObject(pi.hProcess, INFINITE); GetExitCodeProcess(pi.hProcess, &exit_code); CloseHandle(pi.hThread); CloseHandle(pi.hProcess); if (exit_code != 0) { RERROR("Command line failed (%d) : %s", exit_code, cmd.c_str()); } else { result = true; } } else { RERROR("CreateProcess failed: %s", cmd.c_str()); } CloseHandle(hstdout_rd); CloseHandle(hstdout_wr); CloseHandle(hstdin_rd); CloseHandle(hstdin_wr); out: return result; } bool TunWin32Adapter::RunPrePostCommand(const std::vector &vec) { bool success = true; for (auto it = vec.begin(); it != vec.end(); ++it) { if (!g_allow_pre_post) { RERROR("Pre/Post commands are disabled. Ignoring: %s", it->c_str()); } else { success &= RunOneCommand(*it); } } return success; } ////////////////////////////////////////////////////////////////////////////// TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) { wqueue_end_ = &wqueue_; wqueue_ = NULL; wqueue_size_ = 0; thread_ = NULL; completion_port_handle_ = NULL; packet_handler_ = NULL; exit_thread_ = false; did_show_tun_queue_warning_ = false; } TunWin32Iocp::~TunWin32Iocp() { //assert(num_reads_ == 0 && num_writes_ == 0); assert(thread_ == NULL); CloseTun(false); FreePacketList(wqueue_); } bool TunWin32Iocp::Configure(const TunConfig &&config, TunConfigOut *out) { // Reconfigure while started? if (thread_ != NULL) { assert(completion_port_handle_); StopThread(); bool rv = Configure(std::move(config), out); StartThread(); return rv; } CloseTun(true); if (adapter_.OpenAdapter(backend_, FILE_FLAG_OVERLAPPED)) { completion_port_handle_ = CreateIoCompletionPort(adapter_.handle(), NULL, NULL, 0); if (completion_port_handle_ != NULL) { if (adapter_.ConfigureAdapter(std::move(config), out)) return true; } } CloseTun(false); return false; } void TunWin32Iocp::CloseTun(bool is_restart) { assert(thread_ == NULL); adapter_.CloseAdapter(is_restart); if (completion_port_handle_) { CloseHandle(completion_port_handle_); completion_port_handle_ = NULL; } } enum { kTunGetQueuedCompletionStatusSize = kConcurrentWriteTap + kConcurrentReadTap + 1 }; static inline bool AllocPacketFrom(Packet **list, int *counter, bool *exit_flag, Packet **res) { Packet *p; if (p = *list) { *list = Packet_NEXT(p); (*counter)--; p->data = p->data_buf; } else { if (!(p = AllocPacket())) return false; } *res = p; return true; } void TunWin32Iocp::ThreadMain() { OVERLAPPED_ENTRY entries[kTunGetQueuedCompletionStatusSize]; Packet *pending_writes = NULL; int num_reads = 0, num_writes = 0; Packet *finished_reads = NULL, **finished_reads_end; Packet *freed_packets = NULL, **freed_packets_end; int freed_packets_count = 0; DWORD err; if (!completion_port_handle_) return; while (!exit_thread_) { // Initiate more reads, reusing the Packet structures in |finished_writes|. for (int i = num_reads; i < kConcurrentReadTap; i++) { Packet *p; if (!AllocPacketFrom(&freed_packets, &freed_packets_count, &exit_thread_, &p)) break; ClearOverlapped(&p->overlapped); p->userdata = 0; if (!ReadFile(adapter_.handle(), p->data, kPacketCapacity, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { FreePacket(p); RERROR("TunWin32: ReadFile failed 0x%X", err); if (err == ERROR_OPERATION_ABORTED || err == ERROR_FILE_NOT_FOUND) { RERROR("TAP driver stopped communicating. Attempting to restart.", err); // This can happen if we reinstall the TAP driver while there's an active connection. backend_->PostExit(TunsafeBackendWin32::MODE_TUN_FAILED); goto EXIT; } } else { num_reads++; } } assert(freed_packets_count >= 0); if (freed_packets_count >= 32) { FreePackets(freed_packets, freed_packets_end, freed_packets_count); freed_packets_count = 0; freed_packets_end = &freed_packets; } else if (freed_packets == NULL) { assert(freed_packets_count == 0); freed_packets_end = &freed_packets; } ULONG num_entries = 0; if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, kTunGetQueuedCompletionStatusSize, &num_entries, INFINITE, FALSE)) { RINFO("GetQueuedCompletionStatusEx failed."); break; } finished_reads_end = &finished_reads; int finished_reads_count = 0; // Go through the finished entries and determine which ones are reads, and which ones are writes. for (ULONG i = 0; i < num_entries; i++) { if (!entries[i].lpOverlapped) continue; // This is the dummy entry from |PostQueuedCompletionStatus| Packet *p = (Packet*)((byte*)entries[i].lpOverlapped - offsetof(Packet, overlapped)); if (p->userdata == 0) { num_reads--; if ((int)p->overlapped.Internal != 0) { RERROR("TunWin32::ReadComplete error 0x%X", (int)p->overlapped.Internal); FreePacket(p); continue; } p->size = (int)p->overlapped.InternalHigh; p->queue_cb = packet_handler_->tun_queue(); *finished_reads_end = p; finished_reads_end = &Packet_NEXT(p); finished_reads_count++; } else { num_writes--; if ((int)p->overlapped.Internal != 0) { RERROR("TunWin32::WriteComplete error 0x%X", (int)p->overlapped.Internal); FreePacket(p); continue; } freed_packets_count++; *freed_packets_end = p; freed_packets_end = &Packet_NEXT(p); } } *finished_reads_end = NULL; *freed_packets_end = NULL; if (finished_reads != NULL) packet_handler_->PostPackets(finished_reads, finished_reads_end, finished_reads_count); // Initiate more writes from |wqueue_| while (num_writes < kConcurrentWriteTap) { // Refill from queue if empty, avoid taking the mutex if it looks empty if (!pending_writes) { if (!wqueue_) break; mutex_.Acquire(); pending_writes = wqueue_; wqueue_end_ = &wqueue_; wqueue_ = NULL; wqueue_size_ = 0; mutex_.Release(); if (!pending_writes) break; } // Then issue writes Packet *p = pending_writes; pending_writes = Packet_NEXT(p); ClearOverlapped(&p->overlapped); p->userdata = 1; if (!WriteFile(adapter_.handle(), p->data, p->size, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { RERROR("TunWin32: WriteFile failed 0x%X", err); FreePacket(p); } else { num_writes++; } } } EXIT: // Cancel all IO and wait for all completions CancelIo(adapter_.handle()); while (num_reads + num_writes) { ULONG num_entries = 0; if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, 1, &num_entries, INFINITE, FALSE)) { RINFO("GetQueuedCompletionStatusEx failed."); break; } if (!entries[0].lpOverlapped) continue; // This is the dummy entry from |PostQueuedCompletionStatus| Packet *p = (Packet*)((byte*)entries[0].lpOverlapped - offsetof(Packet, overlapped)); if (p->userdata == 0) { num_reads--; } else { num_writes--; } FreePacket(p); } FreePacketList(freed_packets); FreePacketList(pending_writes); } DWORD WINAPI TunWin32Iocp::TunThread(void *x) { TunWin32Iocp *xx = (TunWin32Iocp *)x; xx->ThreadMain(); return 0; } void TunWin32Iocp::StartThread() { DWORD thread_id; assert(thread_ == NULL); assert(completion_port_handle_ != NULL); thread_ = CreateThread(NULL, 0, &TunThread, this, 0, &thread_id); SetThreadPriority(thread_, ABOVE_NORMAL_PRIORITY_CLASS); } void TunWin32Iocp::StopThread() { exit_thread_ = true; PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL); WaitForSingleObject(thread_, INFINITE); CloseHandle(thread_); thread_ = NULL; exit_thread_ = false; } void TunWin32Iocp::WriteTunPacket(Packet *packet) { Packet_NEXT(packet) = NULL; mutex_.Acquire(); if (wqueue_size_ >= HARD_MAXIMUM_TUN_QUEUE_SIZE) { mutex_.Release(); FreePacket(packet); if (!did_show_tun_queue_warning_) { did_show_tun_queue_warning_ = true; RERROR("TUN Queue Overload! This might happen if you use the NDIS6 driver on Windows 7."); } return; } wqueue_size_++; Packet *was_empty = wqueue_; *wqueue_end_ = packet; wqueue_end_ = &Packet_NEXT(packet); mutex_.Release(); if (was_empty == NULL) { // Notify the worker thread that it should attempt more writes PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL); } } ////////////////////////////////////////////////////////////////////////////// TunsafeBackend::TunsafeBackend() { is_started_ = false; is_remote_ = false; ipv4_ip_ = 0; status_ = kStatusStopped; memset(public_key_, 0, sizeof(public_key_)); } TunsafeBackend::~TunsafeBackend() { } static bool GetKillSwitchRouteActive() { RouteInfo ri; return (GetDefaultRouteAndDeleteOldRoutes(AF_INET, NULL, TRUE, NULL, &ri) && ri.found_null_routes == 2); } static void RemoveKillSwitchRoute() { RouteInfo ri; GetDefaultRouteAndDeleteOldRoutes(AF_INET, NULL, FALSE, NULL, &ri); GetDefaultRouteAndDeleteOldRoutes(AF_INET6, NULL, FALSE, NULL, &ri); } TunsafeBackendWin32::TunsafeBackendWin32(Delegate *delegate) : delegate_(delegate), dns_resolver_(&dns_blocker_) { memset(&stats_, 0, sizeof(stats_)); wg_processor_ = NULL; token_request_ = 0; InitPacketMutexes(); worker_thread_ = NULL; last_tun_adapter_failed_ = 0; want_periodic_stats_ = false; guid_[0] = 0; if (g_hklm_reg_key == NULL) { RegCreateKeyEx(HKEY_LOCAL_MACHINE, "Software\\TunSafe", NULL, NULL, 0, KEY_ALL_ACCESS, NULL, &g_hklm_reg_key, NULL); g_killswitch_want = RegReadInt(g_hklm_reg_key, "KillSwitch", 0); g_killswitch_curr = GetKillSwitchRouteActive() * kBlockInternet_Route + GetKillSwitchFirewallActive() * kBlockInternet_Firewall; } delegate_->OnStateChanged(); } TunsafeBackendWin32::~TunsafeBackendWin32() { StopInner(false); TunAdaptersInUse::GetInstance()->Release(this); } void TunsafeBackendWin32::SetPublicKey(const uint8 key[32]) { memcpy(public_key_, key, 32); delegate_->OnStateChanged(); } struct PluginHolder { PluginHolder(PluginDelegate *del) : plugin(CreateTunsafePlugin(del)) {} ~PluginHolder() { delete plugin; } TunsafePlugin *plugin; }; DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) { TunsafeBackendWin32 *backend = (TunsafeBackendWin32*)bk; int stop_mode; int fast_retry_ctr = 0; for (;;) { TunWin32Iocp tun(&backend->dns_blocker_, backend); NetworkWin32 net; PluginHolder plugin(backend); WireguardProcessor wg_proc(&net, &tun, backend); wg_proc.dev().SetPlugin(plugin.plugin); plugin.plugin->Initialize(&wg_proc); net.udp().SetPacketHandler(&backend->packet_processor_); net.tcp_socket_queue().SetPacketHandler(&backend->packet_processor_); tun.SetPacketHandler(&backend->packet_processor_); if (backend->config_file_[0] && !ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->dns_resolver_)) goto getout_fail; if (!wg_proc.Start()) goto getout_fail; backend->SetPublicKey(wg_proc.dev().public_key()); backend->wg_processor_ = &wg_proc; backend->tunsafe_wg_plugin_ = plugin.plugin; net.StartThread(); tun.StartThread(); stop_mode = backend->packet_processor_.Run(&wg_proc, backend); net.StopThread(); tun.StopThread(); backend->wg_processor_ = NULL; backend->tunsafe_wg_plugin_ = NULL; // Keep DNS alive if (stop_mode != MODE_EXIT) tun.adapter().DisassociateDnsBlocker(); else backend->dns_resolver_.ClearCache(); FreeAllPackets(); if (stop_mode != MODE_TUN_FAILED) return 0; uint32 last_fail = GetTickCount(); fast_retry_ctr = (last_fail - backend->last_tun_adapter_failed_ < 5000) ? fast_retry_ctr + 1 : 0; backend->last_tun_adapter_failed_ = last_fail; backend->SetStatus((fast_retry_ctr >= 3) ? TunsafeBackend::kErrorTunPermanent : TunsafeBackend::kStatusTunRetrying); if (backend->status_ == TunsafeBackend::kErrorTunPermanent) { RERROR("Too many automatic restarts..."); goto getout_fail_noseterr; } Sleep(1000); } getout_fail: backend->status_ = TunsafeBackend::kErrorInitialize; backend->delegate_->OnStatusCode(TunsafeBackend::kErrorInitialize); getout_fail_noseterr: backend->dns_blocker_.RestoreDns(); return 0; } void TunsafeBackendWin32::SetStatus(StatusCode status) { status_ = status; delegate_->OnStatusCode(status); } bool TunsafeBackendWin32::Configure() { // it's always initialized return true; } void TunsafeBackendWin32::Teardown() { } bool TunsafeBackendWin32::SetTunAdapterName(const char *name) { assert(worker_thread_ == NULL); size_t len = strlen(name); if (len >= sizeof(guid_) || guid_[0]) return false; if (!TunAdaptersInUse::GetInstance()->Acquire(name, this)) return false; memcpy(guid_, name, len + 1); return true; } void TunsafeBackendWin32::RequestStats(bool enable) { want_periodic_stats_ = enable; PushStats(); } void TunsafeBackendWin32::PushStats() { if (want_periodic_stats_) { stats_mutex_.Acquire(); WgProcessorStats stats = stats_; stats_mutex_.Release(); delegate_->OnGetStats(stats); } } void TunsafeBackendWin32::Stop() { StopInner(false); delegate_->OnStatusCode(status_); delegate_->OnStateChanged(); } void TunsafeBackendWin32::Start(const char *config_file) { StopInner(true); dns_resolver_.ResetCancel(); g_killswitch_currconn = kBlockInternet_Default; is_started_ = true; token_request_ = 0; memset(public_key_, 0, sizeof(public_key_)); SetStatus(kStatusInitializing); delegate_->OnClearLog(); DWORD thread_id; config_file_ = _strdup(config_file); worker_thread_ = CreateThread(NULL, 0, &WorkerThread, this, 0, &thread_id); SetThreadPriority(worker_thread_, THREAD_PRIORITY_ABOVE_NORMAL); delegate_->OnStateChanged(); } void TunsafeBackendWin32::PostExit(int exit_code) { packet_processor_.PostExit(exit_code); } void TunsafeBackendWin32::StopInner(bool is_restart) { if (worker_thread_) { ipv4_ip_ = 0; dns_resolver_.Cancel(); PostExit(is_restart ? MODE_RESTART : MODE_EXIT); WaitForSingleObject(worker_thread_, INFINITE); CloseHandle(worker_thread_); worker_thread_ = NULL; free(config_file_); config_file_ = NULL; is_started_ = false; status_ = kStatusStopped; packet_processor_.Reset(); uint8 wanted_ibs = (g_killswitch_currconn == kBlockInternet_Default) ? g_killswitch_want : g_killswitch_currconn; if (!is_restart && !(wanted_ibs & kBlockInternet_BlockOnDisconnect)) DeactivateKillSwitch(kBlockInternet_Off); } } void TunsafeBackendWin32::ResetStats() { } LinearizedGraph *TunsafeBackendWin32::GetGraph(int type) { if (type < 0 || type >= 4) return NULL; size_t size = sizeof(LinearizedGraph) + 2 * (sizeof(uint32) + sizeof(float) * 120); LinearizedGraph *graph = (LinearizedGraph *)malloc(size); if (graph) { graph->total_size = (uint32)size; graph->num_charts = 2; graph->graph_type = type; memset(graph->reserved, 0, sizeof(graph->reserved)); stats_mutex_.Acquire(); uint8 *ptr = (uint8*)(graph + 1); for (size_t i = 0; i < 2; i++) { *(uint32*)ptr = 120; ptr += 4; const StatsCollector::TimeSeries *series = stats_collector_.GetTimeSeries((int)i, type); memcpy(postinc(ptr, (series->size - series->shift) * sizeof(float)), series->data + series->shift, (series->size - series->shift) * sizeof(float)); memcpy(postinc(ptr, series->shift * sizeof(float)), series->data, series->shift * sizeof(float)); } stats_mutex_.Release(); } return graph; } InternetBlockState TunsafeBackendWin32::GetInternetBlockState() { return (InternetBlockState)(g_killswitch_want | (g_killswitch_curr ? kBlockInternet_Active : 0)); } static void DeactivateKillSwitch(uint32 want) { // Disable blocking without reconnecting uint32 maybeon = g_killswitch_curr; if ((maybeon & kBlockInternet_Route) > (want & kBlockInternet_Route)) { if (g_killswitch_curr & kBlockInternet_Route) { g_killswitch_curr &= ~kBlockInternet_Route; RINFO("Removing the routing rule internet block"); } RemoveKillSwitchRoute(); } if ((maybeon & kBlockInternet_Firewall) > (want & kBlockInternet_Firewall)) { if (g_killswitch_curr & kBlockInternet_Firewall) { g_killswitch_curr &= ~kBlockInternet_Firewall; RINFO("Removing the firewall internet block"); } RemoveKillSwitchFirewall(); } } void TunsafeBackendWin32::SetInternetBlockState(InternetBlockState want) { if (worker_thread_ == NULL && !(want & kBlockInternet_BlockOnDisconnect) || !(want & kBlockInternet_Active)) DeactivateKillSwitch(kBlockInternet_Off); else DeactivateKillSwitch(want); int value = want & 0xff; g_killswitch_want = value; RegWriteInt(g_hklm_reg_key, "KillSwitch", (int)value); delegate_->OnStateChanged(); } void TunsafeBackendWin32::SetServiceStartupFlags(uint32 flags) { // not used } std::string TunsafeBackendWin32::GetConfigFileName() { return std::string(); } struct ConfigQueueItem : QueuedItem, QueuedItemCallback { virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override; virtual void OnQueuedItemDelete(QueuedItem *ow) override; enum Type { SendConfigurationProtocolPacket, SubmitToken }; Type type; std::string message; uint32 ident; }; void ConfigQueueItem::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) { PacketProcessor::QueueContext *context = (PacketProcessor::QueueContext *)extra; if (type == SendConfigurationProtocolPacket) { std::string reply; WgConfig::HandleConfigurationProtocolMessage(context->wg, std::move(message), &reply); context->backend->delegate_->OnConfigurationProtocolReply(ident, std::move(reply)); } else { context->backend->tunsafe_wg_plugin_->SubmitToken((const uint8*)message.data(), message.size()); } delete this; } void ConfigQueueItem::OnQueuedItemDelete(QueuedItem *ow) { delete this; } void TunsafeBackendWin32::SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) { ConfigQueueItem *queue_item = new ConfigQueueItem; queue_item->type = ConfigQueueItem::SendConfigurationProtocolPacket; queue_item->ident = identifier; queue_item->message = std::move(message); queue_item->queue_cb = queue_item; packet_processor_.ForcePost(queue_item); } void TunsafeBackendWin32::SubmitToken(const std::string &&message) { // Clear out the old token request so GetTokenRequest returns zero. token_request_ = 0; ConfigQueueItem *queue_item = new ConfigQueueItem; queue_item->type = ConfigQueueItem::SubmitToken; queue_item->message = std::move(message); queue_item->queue_cb = queue_item; packet_processor_.ForcePost(queue_item); } uint32 TunsafeBackendWin32::GetTokenRequest() { return token_request_; } // This is called on the wireguard thread whenever it needs a token, // it should reschedule void TunsafeBackendWin32::OnRequestToken(WgPeer *peer, uint32 type) { token_request_ = type; delegate_->OnStateChanged(); } void TunsafeBackendWin32::OnConnected() { if (status_ != TunsafeBackend::kStatusConnected) { const WgCidrAddr *ipv4_addr = NULL; for (const WgCidrAddr &x : wg_processor_->addr()) { if (x.size == 32) { ipv4_addr = &x; break; } } ipv4_ip_ = ipv4_addr ? ReadBE32(ipv4_addr->addr) : 0; if (status_ != TunsafeBackend::kStatusReconnecting) { char buf[kSizeOfAddress]; RINFO("Connection established. IP %s", ipv4_addr ? print_ip_prefix(buf, AF_INET, ipv4_addr->addr, -1) : "(none)"); } SetStatus(TunsafeBackend::kStatusConnected); } } void TunsafeBackendWin32::OnConnectionRetry(uint32 attempts) { if (status_ == TunsafeBackend::kStatusInitializing) SetStatus(TunsafeBackend::kStatusConnecting); else if (attempts >= 3 && status_ == TunsafeBackend::kStatusConnected) SetStatus(TunsafeBackend::kStatusReconnecting); } void TunsafeBackend::Delegate::DoWork() { // implemented by subclasses } TunsafeBackendDelegateThreaded::TunsafeBackendDelegateThreaded(TunsafeBackend::Delegate *delegate, const std::function &callback) { callback_ = callback; delegate_ = delegate; } TunsafeBackendDelegateThreaded::~TunsafeBackendDelegateThreaded() { for (auto it = incoming_entry_.begin(); it != incoming_entry_.end(); ++it) FreeEntry(&*it); } void TunsafeBackendDelegateThreaded::FreeEntry(Entry *e) { if (e->lparam) { if (e->which == Id_OnConfigurationProtocolReply) delete (std::string*)e->lparam; else free((void*)e->lparam); e->lparam = NULL; } } void TunsafeBackendDelegateThreaded::DoWork() { mutex_.Acquire(); std::swap(incoming_entry_, processing_entry_); mutex_.Release(); TunsafeBackend::Delegate *delegate = delegate_; for (auto it = processing_entry_.begin(); it != processing_entry_.end(); ++it) { switch (it->which) { case Id_OnGetStats: delegate->OnGetStats(*(WgProcessorStats*)it->lparam); break; case Id_OnStateChanged: delegate->OnStateChanged(); break; case Id_OnLogLine: delegate->OnLogLine((const char**)&it->lparam); break; case Id_OnStatusCode: delegate->OnStatusCode((TunsafeBackend::StatusCode)it->wparam); break; case Id_OnClearLog: delegate->OnClearLog(); break; case Id_OnGraphAvailable: delegate->OnGraphAvailable(); break; case Id_OnConfigurationProtocolReply: delegate->OnConfigurationProtocolReply(it->wparam, std::move(*(std::string*)it->lparam)); break; } FreeEntry(&*it); } processing_entry_.clear(); } void TunsafeBackendDelegateThreaded::AddEntry(Which which, intptr_t lparam, uint32 wparam) { mutex_.Acquire(); bool was_empty = incoming_entry_.empty(); incoming_entry_.emplace_back(which, wparam, lparam); mutex_.Release(); if (was_empty) callback_(); } void TunsafeBackendDelegateThreaded::OnGetStats(const WgProcessorStats &stats) { AddEntry(Id_OnGetStats, (intptr_t)memdup(&stats, sizeof(stats))); } void TunsafeBackendDelegateThreaded::OnGraphAvailable() { AddEntry(Id_OnGraphAvailable); } void TunsafeBackendDelegateThreaded::OnStateChanged() { AddEntry(Id_OnStateChanged); } void TunsafeBackendDelegateThreaded::OnLogLine(const char **s) { const char *ss = *s; *s = NULL; AddEntry(Id_OnLogLine, (intptr_t)ss); } void TunsafeBackendDelegateThreaded::OnStatusCode(TunsafeBackend::StatusCode status) { AddEntry(Id_OnStatusCode, 0, status); } void TunsafeBackendDelegateThreaded::OnClearLog() { AddEntry(Id_OnClearLog); } void TunsafeBackendDelegateThreaded::OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) { AddEntry(Id_OnConfigurationProtocolReply, (intptr_t)new std::string(std::move(reply)), ident); } TunsafeBackend::Delegate::~Delegate() { } TunsafeBackend *CreateNativeTunsafeBackend(TunsafeBackend::Delegate *delegate) { return new TunsafeBackendWin32(delegate); } TunsafeBackend::Delegate *CreateTunsafeBackendDelegateThreaded(TunsafeBackend::Delegate *delegate, const std::function &callback) { return new TunsafeBackendDelegateThreaded(delegate, callback); } /////////////////////////////////////////////////// void StatsCollector::Init() { Accumulator *acc = &accum_[0][0]; static const int kAccMax[TIMEVALS] = {5, 6, 10, 0}; // Configure all stats channels for (uint32 channel = 0; channel != CHANNELS; channel++) { for (uint32 timeval = 0; timeval != TIMEVALS; timeval++, acc++) { acc->acc = 0; acc->dirty = false; acc->acc_count = 0; acc->acc_max = kAccMax[timeval]; acc->data.size = 120; acc->data.data = (float*)calloc(sizeof(float), acc->data.size); acc->data.shift = 0; } } } void StatsCollector::AddToGraphDataSource(StatsCollector::TimeSeries *ts, float value) { ts->data[ts->shift] = value; if (++ts->shift == ts->size) ts->shift = 0; } void StatsCollector::AddToAccumulators(StatsCollector::Accumulator *acc, float rval) { for (;;) { AddToGraphDataSource(&acc->data, rval); acc->dirty = true; acc->acc += rval; if (acc->acc_max == 0 || ++acc->acc_count < acc->acc_max) break; rval = acc->acc / (float)acc->acc_count; acc->acc_count = 0; acc->acc = 0.0f; acc++; } } void StatsCollector::AddSamples(float data[CHANNELS]) { for (size_t i = 0; i < CHANNELS; i++) AddToAccumulators(&accum_[i][0], data[i]); } TunAdaptersInUse::TunAdaptersInUse() { num_inuse_ = 0; } bool TunAdaptersInUse::Acquire(const char guid[ADAPTER_GUID_SIZE], void *context) { size_t len = strlen(guid); if (len >= ADAPTER_GUID_SIZE) return false; ScopedLock scoped_lock(&mutex_); Entry *e = entry_; for (uint32 n = num_inuse_; ; n--, e++) { if (n == 0) { if (num_inuse_ == kMaxAdaptersInUse) return false; num_inuse_++; e->context = context; e->count = 0; memcpy(e->guid, guid, len + 1); return true; } if (!strcmp(e->guid, guid)) { if (e->context != context) return false; e->count++; return true; } } } void TunAdaptersInUse::Release(void *context) { ScopedLock scoped_lock(&mutex_); Entry *e = entry_; for (uint32 n = num_inuse_; n; n--, e++) { if (e->context == context) { if (e->count-- == 0) *e = entry_[num_inuse_-- - 1]; break; } } } void *TunAdaptersInUse::LookupContextFromGuid(const char guid[ADAPTER_GUID_SIZE]) { ScopedLock scoped_lock(&mutex_); Entry *e = entry_; for (uint32 n = num_inuse_; n; n--, e++) { if (!strcmp(e->guid, guid)) return e->context; } return NULL; } char *TunAdaptersInUse::GetAllGuid() { ScopedLock scoped_lock(&mutex_); char *rv = (char*)malloc(ADAPTER_GUID_SIZE * num_inuse_ + 1), *p = rv; if (rv) { Entry *e = entry_; for (uint32 n = num_inuse_; n; n--, e++) { size_t len = strlen(e->guid); p[len] = '\n'; memcpy(p, e->guid, len); p += len + 1; } *p = 0; } return rv; } static TunAdaptersInUse g_tun_adapters_in_use; TunAdaptersInUse *TunAdaptersInUse::GetInstance() { return &g_tun_adapters_in_use; } TunsafeBackend *TunsafeBackend::FindBackendByTunGuid(const char *guid) { return (TunsafeBackend*)TunAdaptersInUse::GetInstance()->LookupContextFromGuid(guid); } char *TunsafeBackend::GetAllGuid() { return TunAdaptersInUse::GetInstance()->GetAllGuid(); }