diff --git a/.gitignore b/.gitignore index 71ce91a..e1c4422 100644 --- a/.gitignore +++ b/.gitignore @@ -1,18 +1,10 @@ -/Debug/ -/Release/ -/ipzip2/Debug/ -/Build -/Win32/ /TunSafe.aps /*.sdf -/*vcxproj.user +*vcxproj.user /*.opensdf /*.suo /.vs/ -/x64/ -/Azire.conf +/build/ /*.psess /*.vspx /installer/*.zip -/config/ -/tunsafe.com/ diff --git a/TunSafe.sln b/TunSafe.sln index 40906e6..ff47cf6 100644 --- a/TunSafe.sln +++ b/TunSafe.sln @@ -5,6 +5,8 @@ VisualStudioVersion = 15.0.26403.7 MinimumVisualStudioVersion = 10.0.40219.1 Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "TunSafe", "TunSafe.vcxproj", "{626FBC16-64C6-407D-BC2B-6C087794E0D0}" EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ts", "ts.vcxproj", "{443E105E-8D7C-401F-BD41-D3F56C76104B}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Win32 = Debug|Win32 @@ -21,8 +23,19 @@ Global {626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|Win32.Build.0 = Release|Win32 {626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|x64.ActiveCfg = Release|x64 {626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|x64.Build.0 = Release|x64 + {443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|Win32.ActiveCfg = Debug|Win32 + {443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|Win32.Build.0 = Debug|Win32 + {443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|x64.ActiveCfg = Debug|x64 + {443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|x64.Build.0 = Debug|x64 + {443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|Win32.ActiveCfg = Release|Win32 + {443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|Win32.Build.0 = Release|Win32 + {443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|x64.ActiveCfg = Release|x64 + {443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|x64.Build.0 = Release|x64 EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {F2DD9ED8-DDEA-4B40-9208-41726750D33D} + EndGlobalSection EndGlobal diff --git a/TunSafe.vcxproj b/TunSafe.vcxproj index 2645b39..a85d424 100644 --- a/TunSafe.vcxproj +++ b/TunSafe.vcxproj @@ -22,7 +22,7 @@ {626FBC16-64C6-407D-BC2B-6C087794E0D0} Win32Proj TunSafe - 10.0.15063.0 + 10.0.17134.0 TunSafe @@ -72,24 +72,28 @@ true TunSafe - $(SolutionDir)$(Platform)\$(Configuration)\ - $(Platform)\$(Configuration)\ + $(SolutionDir)build\$(Platform)_$(Configuration)\ + $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\ true $(VC_ExecutablePath_x64);$(WindowsSDK_ExecutablePath);$(VS_ExecutablePath);$(MSBuild_ExecutablePath);$(FxCopDir);$(PATH);C:\Bin\Dev\nasm TunSafe + $(SolutionDir)build\$(Platform)_$(Configuration)\ + $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\ false TunSafe - $(SolutionDir)$(Platform)\$(Configuration)\ - $(Platform)\$(Configuration)\ + $(SolutionDir)build\$(Platform)_$(Configuration)\ + $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\ false $(VC_ExecutablePath_x64);$(WindowsSDK_ExecutablePath);$(VS_ExecutablePath);$(MSBuild_ExecutablePath);$(FxCopDir);$(PATH);C:\Bin\Dev\nasm TunSafe + $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\ + $(SolutionDir)build\$(Platform)_$(Configuration)\ @@ -98,6 +102,7 @@ Disabled WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions);_CRT_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS . + false Windows @@ -114,6 +119,7 @@ . + false Windows @@ -133,6 +139,8 @@ WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions);_CRT_SECURE_NO_WARNINGS MultiThreaded . + false + false Windows @@ -157,6 +165,8 @@ AnySuitable true . + false + false Windows @@ -169,8 +179,10 @@ + + @@ -196,6 +208,7 @@ + diff --git a/TunSafe.vcxproj.filters b/TunSafe.vcxproj.filters index 2a444bb..01e1ff8 100644 --- a/TunSafe.vcxproj.filters +++ b/TunSafe.vcxproj.filters @@ -89,6 +89,12 @@ Source Files\Win32 + + Source Files\Win32 + + + Source Files\Win32 + @@ -154,6 +160,9 @@ Source Files + + Source Files\Win32 + diff --git a/build_freebsd.sh b/build_freebsd.sh index b546216..2f86855 100755 --- a/build_freebsd.sh +++ b/build_freebsd.sh @@ -1,4 +1,4 @@ g++7 -I . -O2 -DNDEBUG -static -mssse3 -o tunsafe benchmark.cpp tunsafe_cpu.cpp wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \ -wireguard.cpp wireguard_proto.cpp util.cpp network_bsd.cpp network_bsd_common.cpp \ +wireguard.cpp wireguard_proto.cpp ts.cpp util.cpp network_bsd.cpp network_bsd_common.cpp \ crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp crypto/curve25519-donna.cpp \ crypto/siphash.cpp crypto/chacha20_x64_gas.s crypto/poly1305_x64_gas.s ipzip2/ipzip2.cpp -lrt -pthread diff --git a/build_linux.sh b/build_linux.sh index 026955e..850dcdc 100755 --- a/build_linux.sh +++ b/build_linux.sh @@ -1,6 +1,6 @@ #!/bin/sh clang++-6.0 -c -march=skylake-avx512 crypto/poly1305_x64_gas.s crypto/chacha20_x64_gas.s -clang++-6.0 -I . -O3 -DNDEBUG -mssse3 -pthread -lrt -o tunsafe util.cpp wireguard_config.cpp wireguard.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \ +clang++-6.0 -I . -O3 -DNDEBUG -mssse3 -pthread -lrt -o tunsafe util.cpp wireguard_config.cpp wireguard.cpp ts.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \ wireguard_proto.cpp network_bsd.cpp network_bsd_common.cpp tunsafe_cpu.cpp benchmark.cpp crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp \ crypto/curve25519-donna.cpp crypto/siphash.cpp chacha20_x64_gas.o crypto/aesgcm/aesni_gcm_x64_gas.s \ crypto/aesgcm/aesni_x64_gas.s crypto/aesgcm/aesgcm.cpp poly1305_x64_gas.o ipzip2/ipzip2.cpp \ diff --git a/build_osx.sh b/build_osx.sh index 29a02c0..3e905f9 100755 --- a/build_osx.sh +++ b/build_osx.sh @@ -3,8 +3,8 @@ set -e clang++ -c -mavx512f -mavx512vl crypto/poly1305_x64_gas_macosx.s crypto/chacha20_x64_gas_macosx.s -clang++ -g -O3 -I . -std=c++11 -DNDEBUG=1 -fno-exceptions -fno-rtti -ffunction-sections -o tunsafe \ -wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp wireguard.cpp wireguard_proto.cpp util.cpp network_bsd.cpp network_bsd_common.cpp benchmark.cpp tunsafe_cpu.cpp \ +clang++ -g -O3 -I . -std=c++11 -DNDEBUG=1 -Wno-deprecated-declarations -fno-exceptions -fno-rtti -ffunction-sections -o tunsafe \ +wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp wireguard.cpp wireguard_proto.cpp ts.cpp util.cpp network_bsd.cpp network_bsd_common.cpp benchmark.cpp tunsafe_cpu.cpp \ crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp crypto/curve25519-donna.cpp \ crypto/siphash.cpp crypto/aesgcm/aesgcm.cpp ipzip2/ipzip2.cpp \ crypto/aesgcm/aesni_gcm_x64_gas_macosx.s crypto/aesgcm/aesni_x64_gas_macosx.s crypto/aesgcm/ghash_x64_gas_macosx.s \ diff --git a/installer/tunsafe.nsi b/installer/tunsafe.nsi index 044f485..8550365 100644 --- a/installer/tunsafe.nsi +++ b/installer/tunsafe.nsi @@ -57,10 +57,12 @@ Section "TunSafe Client" SecTunSafe DetailPrint "Installing 64-bit version of TunSafe." SetOutPath "$INSTDIR" File "x64\TunSafe.exe" + File "x64\ts.exe" ${Else} DetailPrint "Installing 32-bit version of TunSafe." SetOutPath "$INSTDIR" File "x86\TunSafe.exe" + File "x86\ts.exe" ${EndIf} File "License.txt" File "ChangeLog.txt" @@ -205,6 +207,7 @@ Section "Uninstall" Delete "$INSTDIR\TunSafe.exe" + Delete "$INSTDIR\ts.exe" Delete "$INSTDIR\License.txt" Delete "$INSTDIR\ChangeLog.txt" Delete "$INSTDIR\Config\TunSafe.conf" diff --git a/ip_to_peer_map.cpp b/ip_to_peer_map.cpp index 93b8dd4..48539bb 100644 --- a/ip_to_peer_map.cpp +++ b/ip_to_peer_map.cpp @@ -5,6 +5,8 @@ #include "bit_ops.h" #include #include +#include +#include "util.h" IpToPeerMap::IpToPeerMap() { @@ -13,18 +15,22 @@ IpToPeerMap::IpToPeerMap() { IpToPeerMap::~IpToPeerMap() { } -bool IpToPeerMap::InsertV4(uint32 ip, int cidr, void *peer) { - ipv4_.Insert(ip, cidr, peer); - return true; +void *IpToPeerMap::InsertV4(uint32 ip, int cidr, void *peer) { + ipv4_.Insert(ip, cidr, &peer); + return peer; } -bool IpToPeerMap::InsertV6(const void *addr, int cidr, void *peer) { +void *IpToPeerMap::InsertV6(const void *addr, int cidr, void *peer) { Entry6 e; + for (auto it = ipv6_.begin(); it != ipv6_.end(); ++it) { + if (it->cidr_len == cidr && memcmp(it->ip, addr, 16) == 0) + return exch(it->peer, peer); + } e.cidr_len = cidr; e.peer = peer; memcpy(e.ip, addr, 16); ipv6_.push_back(e); - return true; + return NULL; } void *IpToPeerMap::LookupV4(uint32 ip) { @@ -43,6 +49,19 @@ void *IpToPeerMap::LookupV6DefaultPeer() { return NULL; } +void IpToPeerMap::RemoveV4(uint32 ip, int cidr) { + ipv4_.Delete(ip, cidr); +} + +void IpToPeerMap::RemoveV6(const void *addr, int cidr) { + for (auto it = ipv6_.begin(); it != ipv6_.end(); ++it) { + if (it->cidr_len == cidr && memcmp(it->ip, addr, 16) == 0) { + ipv6_.erase(it); + return; + } + } +} + static int CalculateIPv6CommonPrefix(const uint8 *a, const uint8 *b) { uint64 x = ToBE64(*(uint64*)&a[0] ^ *(uint64*)&b[0]); uint64 y = ToBE64(*(uint64*)&a[8] ^ *(uint64*)&b[8]); @@ -62,20 +81,6 @@ void *IpToPeerMap::LookupV6(const void *addr) { return best_peer; } -void IpToPeerMap::RemovePeer(void *peer) { - assert(0); - // todo: remove peer also from ipv4_ - { - size_t n = ipv6_.size(); - Entry6 *r = &ipv6_[0], *w = r; - for (size_t i = 0; i != n; i++, r++) { - if (r->peer != peer) - *w++ = *r; - } - ipv6_.resize(w - &ipv6_[0]); - } -} - #pragma warning (disable: 4200) // warning C4200: nonstandard extension used: zero-sized array in struct/union struct RoutingTrie32::Node { uint32 key; @@ -175,7 +180,7 @@ RoutingTrie32::~RoutingTrie32() { RoutingTrie32::Value RoutingTrie32::Lookup(uint32 ip) { uint32 key = ip; Node *n = root_, *pn = n, *ppn; - int cindex = 0; + uint32 cindex = 0; if (!n) return NULL; // Find the longest prefix match @@ -232,7 +237,7 @@ backtrace: } // strip lsb of cindex and find child cindex &= cindex - 1; - assert(cindex < (1 << pn->bits)); + assert(cindex < (1U << pn->bits)); n = pn->child[cindex]; if (!NODE_IS_NULL_OR_OLEAF(n)) break; @@ -246,7 +251,7 @@ backtrace: } } -bool RoutingTrie32::InsertLeafInto(Node **nn, uint8 leaf_pos, Value value) { +bool RoutingTrie32::InsertLeafInto(Node **nn, uint8 leaf_pos, Value *valuep) { // put higher cidr higher up Node *n = *nn; assert(IS_LEAF(n)); @@ -255,12 +260,12 @@ bool RoutingTrie32::InsertLeafInto(Node **nn, uint8 leaf_pos, Value value) { if (leaf_pos < n->pos) break; if (leaf_pos == n->pos) { - n->leaf_value = value; + std::swap(n->leaf_value, *valuep); return true; } nn = &n->leaf_next; } while ((n = *nn) != NULL); - Node *leaf = NewLeaf(key, leaf_pos, value); + Node *leaf = NewLeaf(key, leaf_pos, *valuep); if (leaf == NULL) return false; leaf->leaf_next = *nn; @@ -283,14 +288,14 @@ void RoutingTrie32::PutChild(Node *pn, uint32 i, Node *n) { assert(pn->full_children < 0x80000000); } -bool RoutingTrie32::Insert(uint32 ip, int cidr, Value value) { +bool RoutingTrie32::Insert(uint32 ip, int cidr, Value *valuep) { uint32 key = ip; Node **nn = &root_, *n = root_, *pn = NULL, *leaf, *tn = NULL, *leaf_to_free = NULL; uint8 leaf_pos = 32 - cidr; - + if (n == NULL) { - root_ = NewLeaf(key, leaf_pos, value); - return (root_ != NULL); + root_ = NewLeaf(key, leaf_pos, exch_null(*valuep)); + return false; } assert(!NODE_IS_OLEAF(n)); @@ -316,7 +321,7 @@ force_add: if (IS_LEAF(n)) { if (key != n->key) goto force_add; - return InsertLeafInto(nn, leaf_pos, value); + return InsertLeafInto(nn, leaf_pos, valuep); } pn = n; nn = &n->child[index]; @@ -330,6 +335,7 @@ force_add: *nn = n; } } + Value value = *valuep; // Create either leaf or oleaf if (tn->pos == leaf_pos) { leaf = VALUE_TO_OLEAF(value); @@ -338,8 +344,8 @@ force_add: FreeNode(tn); return false; } - // -- Start making irreversible changes here + *valuep = NULL; if (leaf_to_free) FreeNode(leaf_to_free); diff --git a/ip_to_peer_map.h b/ip_to_peer_map.h index eef7a36..2db5d5e 100644 --- a/ip_to_peer_map.h +++ b/ip_to_peer_map.h @@ -15,7 +15,7 @@ public: ~RoutingTrie32(); NOINLINE Value Lookup(uint32 ip); NOINLINE Value LookupExact(uint32 ip, int cidr); - bool Insert(uint32 ip, int cidr, Value value); + bool Insert(uint32 ip, int cidr, Value *value); bool Delete(uint32 ip, int cidr); private: @@ -31,7 +31,7 @@ private: static void PutChild(Node *pn, uint32 i, Node *n); static void ReplaceChild(Node **pnp, Node *n); static Node *ConvertOleafToLeaf(Node *pn, uint32 i, Node *n); - static bool InsertLeafInto(Node **n, uint8 leaf_pos, RoutingTrie32::Value value); + static bool InsertLeafInto(Node **n, uint8 leaf_pos, RoutingTrie32::Value *value); }; @@ -43,8 +43,8 @@ public: ~IpToPeerMap(); // Inserts an IP address of a given CIDR length into the lookup table, pointing to peer. - bool InsertV4(uint32 ip, int cidr, void *peer); - bool InsertV6(const void *addr, int cidr, void *peer); + void *InsertV4(uint32 ip, int cidr, void *peer); + void *InsertV6(const void *addr, int cidr, void *peer); // Lookup the peer matching the IP Address void *LookupV4(uint32 ip); @@ -53,8 +53,8 @@ public: void *LookupV4DefaultPeer(); void *LookupV6DefaultPeer(); - // Remove a peer from the table - void RemovePeer(void *peer); + void RemoveV4(uint32 ip, int cidr); + void RemoveV6(const void *addr, int cidr); private: struct Entry6 { uint8 ip[16]; diff --git a/netapi.h b/netapi.h index dac0b0a..2bb455e 100644 --- a/netapi.h +++ b/netapi.h @@ -18,7 +18,6 @@ #pragma warning (disable: 4200) -void OsGetRandomBytes(uint8 *dst, size_t dst_size); uint64 OsGetMilliseconds(); void OsGetTimestampTAI64N(uint8 dst[12]); void OsInterruptibleSleep(int millis); @@ -127,13 +126,13 @@ public: uint8 neighbor_discovery_spoofing_mac[6]; }; - virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) = 0; + virtual bool Configure(const TunConfig &&config, TunConfigOut *out) = 0; virtual void WriteTunPacket(Packet *packet) = 0; }; class UdpInterface { public: - virtual bool Initialize(int listen_port) = 0; + virtual bool Configure(int listen_port) = 0; virtual void WriteUdpPacket(Packet *packet) = 0; }; diff --git a/network_bsd.cpp b/network_bsd.cpp index 40e987c..547e3db 100644 --- a/network_bsd.cpp +++ b/network_bsd.cpp @@ -13,10 +13,12 @@ #include #include #include +#include #include #include #include #include +#include static Packet *freelist; @@ -49,6 +51,7 @@ void FreePackets() { } } + class TunsafeBackendBsdImpl : public TunsafeBackendBsd { public: TunsafeBackendBsdImpl(); @@ -61,7 +64,7 @@ public: virtual void WriteTunPacket(Packet *packet) override; // -- from UdpInterface - virtual bool Initialize(int listen_port) override; + virtual bool Configure(int listen_port) override; virtual void WriteUdpPacket(Packet *packet) override; virtual void HandleSigAlrm() override { got_sig_alarm_ = true; } @@ -72,12 +75,19 @@ private: bool ReadFromTun(); bool WriteToUdp(); bool WriteToTun(); + bool InitializeUnixDomainSocket(const char *devname); + // Exists for the unix domain sockets + struct SockInfo { + bool is_listener; + + std::string inbuf, outbuf; + }; + bool HandleSpecialPollfd(struct pollfd *pollfd, struct SockInfo *sockinfo); + void CloseSpecialPollfd(size_t i); void SetUdpFd(int fd); void SetTunFd(int fd); - inline void RecomputeMaxFd() { max_fd_ = ((tun_fd_>udp_fd_) ? tun_fd_ : udp_fd_) + 1; } - int tun_fd_, udp_fd_, max_fd_; bool got_sig_alarm_; bool exit_; @@ -89,13 +99,25 @@ private: Packet *read_packet_; - fd_set readfds_, writefds_; + enum { + kMaxPollFd = 5, + kPollFdTun = 0, + kPollFdUdp = 1, + kPollFdUnix = 2, + }; + + unsigned int pollfd_num_; + struct pollfd pollfd_[kMaxPollFd]; + + struct SockInfo sockinfo_[kMaxPollFd - 2]; + + struct sockaddr_un un_addr_; + + UnixSocketDeletionWatcher un_deletion_watcher_; }; TunsafeBackendBsdImpl::TunsafeBackendBsdImpl() - : tun_fd_(-1), - udp_fd_(-1), - tun_readable_(false), + : tun_readable_(false), tun_writable_(false), udp_readable_(false), udp_writable_(false), @@ -106,35 +128,39 @@ TunsafeBackendBsdImpl::TunsafeBackendBsdImpl() udp_queue_(NULL), udp_queue_end_(&udp_queue_), read_packet_(NULL) { - RecomputeMaxFd(); - - FD_ZERO(&readfds_); - FD_ZERO(&writefds_); read_packet_ = AllocPacket(); + for(size_t i = 0; i < kMaxPollFd; i++) + pollfd_[i].fd = -1; + pollfd_num_ = 3; + sockinfo_[0].is_listener = true; + memset(&un_addr_, 0, sizeof(un_addr_)); } TunsafeBackendBsdImpl::~TunsafeBackendBsdImpl() { + if (un_addr_.sun_path[0]) + unlink(un_addr_.sun_path); if (read_packet_) FreePacket(read_packet_); + for(size_t i = 0; i < pollfd_num_; i++) + close(pollfd_[i].fd); } void TunsafeBackendBsdImpl::SetUdpFd(int fd) { - udp_fd_ = fd; - RecomputeMaxFd(); + pollfd_[kPollFdUdp].fd = fd; + pollfd_[kPollFdUdp].events = POLLIN; udp_writable_ = true; } void TunsafeBackendBsdImpl::SetTunFd(int fd) { - tun_fd_ = fd; - RecomputeMaxFd(); + pollfd_[kPollFdTun].fd = fd; + pollfd_[kPollFdTun].events = POLLIN; tun_writable_ = true; } - bool TunsafeBackendBsdImpl::ReadFromUdp(bool overload) { socklen_t sin_len; sin_len = sizeof(read_packet_->addr.sin); - int r = recvfrom(udp_fd_, read_packet_->data, kPacketCapacity, 0, + int r = recvfrom(pollfd_[kPollFdUdp].fd, read_packet_->data, kPacketCapacity, 0, (sockaddr*)&read_packet_->addr.sin, &sin_len); if (r >= 0) { // printf("Read %d bytes from UDP\n", r); @@ -157,11 +183,12 @@ bool TunsafeBackendBsdImpl::ReadFromUdp(bool overload) { bool TunsafeBackendBsdImpl::WriteToUdp() { assert(udp_writable_); // RINFO("Send %d bytes to %s", (int)udp_queue_->size, inet_ntoa(udp_queue_->sin.sin_addr)); - int r = sendto(udp_fd_, udp_queue_->data, udp_queue_->size, 0, + int r = sendto(pollfd_[kPollFdUdp].fd, udp_queue_->data, udp_queue_->size, 0, (sockaddr*)&udp_queue_->addr.sin, sizeof(udp_queue_->addr.sin)); if (r < 0) { if (errno == EAGAIN) { udp_writable_ = false; + pollfd_[kPollFdUdp].events = POLLIN | POLLOUT; return false; } perror("Write to UDP failed"); @@ -185,7 +212,7 @@ static inline bool IsCompatibleProto(uint32 v) { bool TunsafeBackendBsdImpl::ReadFromTun() { assert(tun_readable_); Packet *packet = read_packet_; - int r = read(tun_fd_, packet->data - TUN_PREFIX_BYTES, kPacketCapacity + TUN_PREFIX_BYTES); + int r = read(pollfd_[kPollFdTun].fd, packet->data - TUN_PREFIX_BYTES, kPacketCapacity + TUN_PREFIX_BYTES); if (r >= 0) { // printf("Read %d bytes from TUN\n", r); packet->size = r - TUN_PREFIX_BYTES; @@ -215,10 +242,11 @@ bool TunsafeBackendBsdImpl::WriteToTun() { if (TUN_PREFIX_BYTES) { WriteBE32(tun_queue_->data - TUN_PREFIX_BYTES, GetProtoFromPacket(tun_queue_->data, tun_queue_->size)); } - int r = write(tun_fd_, tun_queue_->data - TUN_PREFIX_BYTES, tun_queue_->size + TUN_PREFIX_BYTES); + int r = write(pollfd_[kPollFdTun].fd, tun_queue_->data - TUN_PREFIX_BYTES, tun_queue_->size + TUN_PREFIX_BYTES); if (r < 0) { if (errno == EAGAIN) { tun_writable_ = false; + pollfd_[kPollFdTun].events = POLLIN | POLLOUT; return false; } RERROR("Write to tun failed"); @@ -242,11 +270,13 @@ bool TunsafeBackendBsdImpl::InitializeTun(char devname[16]) { fcntl(tun_fd, F_SETFD, FD_CLOEXEC); fcntl(tun_fd, F_SETFL, O_NONBLOCK); SetTunFd(tun_fd); + + InitializeUnixDomainSocket(devname); return true; } void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) override { - assert(tun_fd_ >= 0); + assert(pollfd_[kPollFdTun].fd >= 0); Packet *queue_is_used = tun_queue_; *tun_queue_end_ = packet; tun_queue_end_ = &packet->next; @@ -256,7 +286,7 @@ void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) override { } // Called to initialize udp -bool TunsafeBackendBsdImpl::Initialize(int listen_port) override { +bool TunsafeBackendBsdImpl::Configure(int listen_port) override { int udp_fd = open_udp(listen_port); if (udp_fd < 0) { RERROR("Error opening udp"); return false; } fcntl(udp_fd, F_SETFD, FD_CLOEXEC); @@ -266,7 +296,7 @@ bool TunsafeBackendBsdImpl::Initialize(int listen_port) override { } void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) override { - assert(udp_fd_ >= 0); + assert(pollfd_[kPollFdUdp].fd >= 0); Packet *queue_is_used = udp_queue_; *udp_queue_end_ = packet; udp_queue_end_ = &packet->next; @@ -275,16 +305,137 @@ void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) override { WriteToUdp(); } +bool TunsafeBackendBsdImpl::InitializeUnixDomainSocket(const char *devname) { + int fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (fd == -1) { + RERROR("Error creating unix domain socket"); + return false; + } + + fcntl(fd, F_SETFD, FD_CLOEXEC); + fcntl(fd, F_SETFL, O_NONBLOCK); + + mkdir("/var/run/wireguard", 0755); + un_addr_.sun_family = AF_UNIX; + snprintf(un_addr_.sun_path, sizeof(un_addr_.sun_path), "/var/run/wireguard/%s.sock", devname); + unlink(un_addr_.sun_path); + if (bind(fd, (struct sockaddr*)&un_addr_, sizeof(un_addr_)) == -1) { + RERROR("Error binding unix domain socket"); + close(fd); + return false; + } + if (listen(fd, 5) == -1) { + RERROR("Error listening on unix domain socket"); + close(fd); + return false; + } + + pollfd_[kPollFdUnix].fd = fd; + pollfd_[kPollFdUnix].events = POLLIN; + + return true; +} + +static const char *FindMessageEnd(const char *start, size_t size) { + if (size <= 1) + return NULL; + const char *start_end = start + size - 1; + for(;(start = (const char*)memchr(start, '\n', start_end - start)) != NULL; start++) { + if (start[1] == '\n') + return start + 2; + } + return NULL; +} + +bool TunsafeBackendBsdImpl::HandleSpecialPollfd(struct pollfd *pfd, struct SockInfo *sockinfo) { + // handle domain socket thing + if (sockinfo->is_listener) { + if (pfd->revents & POLLIN) { + // wait if we can't allocate more pollfd + if (pollfd_num_ == kMaxPollFd) { + pfd->events = 0; + return true; + } + int fd = accept(pfd->fd, NULL, NULL); + if (fd >= 0) { + size_t slot = pollfd_num_++; + pollfd_[slot].fd = fd; + pollfd_[slot].events = POLLIN; + pollfd_[slot].revents = 0; + sockinfo_[slot - 2].is_listener = false; + } else { + RERROR("Unix domain socket accept failed"); + } + } + if (pfd->revents & ~POLLIN) { + RERROR("Unix domain socket got an error code"); + return false; + } + return true; + } + if (pfd->revents & POLLIN) { + char buf[4096]; + // read as much data as we can until we see \n\n + ssize_t n = recv(pfd->fd, buf, sizeof(buf), 0); + if (n <= 0) + return (n == -1 && errno == EAGAIN); // premature eof or error + sockinfo->inbuf.append(buf, n); + const char *message_end = FindMessageEnd(&sockinfo->inbuf[0], sockinfo->inbuf.size()); + if (message_end) { + if (message_end != &sockinfo->inbuf[sockinfo->inbuf.size()]) + return false; // trailing data? + WgConfig::HandleConfigurationProtocolMessage(processor_, std::move(sockinfo->inbuf), &sockinfo->outbuf); + if (!sockinfo->outbuf.size()) + return false; + pfd->revents = pfd->events = POLLOUT; + } + } + if (pfd->revents & POLLOUT) { + size_t n = send(pfd->fd, sockinfo->outbuf.data(), sockinfo->outbuf.size(), 0); + if (n <= 0) + return (n == -1 && errno == EAGAIN); // premature eof or error + sockinfo->outbuf.erase(0, n); + if (!sockinfo->outbuf.size()) + return false; + } + + if (pfd->revents & ~(POLLIN | POLLOUT)) { + RERROR("Unix domain socket got an error code"); + return false; + } + return true; +} + +void TunsafeBackendBsdImpl::CloseSpecialPollfd(size_t i) { + close(pollfd_[i].fd); + pollfd_[i].fd = -1; + sockinfo_[i - 2].inbuf.clear(); + sockinfo_[i - 2].outbuf.clear(); + pollfd_[i] = pollfd_[(size_t)pollfd_num_ - 1]; + std::swap(sockinfo_[i - 2], sockinfo_[(size_t)pollfd_num_ - 1 - 2]); + + // Can now allow more sockets? + if (pollfd_num_-- == kMaxPollFd && sockinfo_[kPollFdUnix - 2].is_listener) + pollfd_[kPollFdUnix].events = POLLIN; +} + void TunsafeBackendBsdImpl::RunLoopInner() { int free_packet_interval = 10; int overload_ctr = 0; + if (!un_deletion_watcher_.Start(un_addr_.sun_path, &exit_)) + return; + while (!exit_) { int n = -1; - // This is not fully signal safe. if (got_sig_alarm_) { got_sig_alarm_ = false; + + if (un_deletion_watcher_.Poll(un_addr_.sun_path)) { + RINFO("Unix socket %s deleted.", un_addr_.sun_path); + break; + } processor_->SecondLoop(); if (free_packet_interval == 0) { @@ -296,33 +447,53 @@ void TunsafeBackendBsdImpl::RunLoopInner() { overload_ctr -= (overload_ctr != 0); } - if (tun_fd_ >= 0) { - FD_SET(tun_fd_, &readfds_); - if (tun_writable_) FD_CLR(tun_fd_, &writefds_); else FD_SET(tun_fd_, &writefds_); - } - - if (udp_fd_ >= 0) { - FD_SET(udp_fd_, &readfds_); - if (udp_writable_) FD_CLR(udp_fd_, &writefds_); else FD_SET(udp_fd_, &writefds_); - } - - n = select(max_fd_, &readfds_, &writefds_, NULL, NULL); +#if defined(OS_LINUX) || defined(OS_FREEBSD) + n = ppoll(pollfd_, pollfd_num_, NULL, &orig_signal_mask_); +#else + n = poll(pollfd_, pollfd_num_, -1); +#endif if (n == -1) { if (errno != EINTR) { - fprintf(stderr, "select failed\n"); + RERROR("poll failed"); break; } } else { - if (tun_fd_ >= 0) { - tun_readable_ = (FD_ISSET(tun_fd_, &readfds_) != 0); - tun_writable_ |= (FD_ISSET(tun_fd_, &writefds_) != 0); + + if (pollfd_[kPollFdTun].revents & (POLLERR | POLLHUP | POLLNVAL)) { + if (pollfd_[kPollFdTun].revents & POLLERR) { + tun_interface_gone_ = true; + RERROR("Tun interface gone, closing."); + } else { + RERROR("Tun interface error %d, closing.", pollfd_[kPollFdTun].revents); + } + break; } - if (udp_fd_ >= 0) { - udp_readable_ = (FD_ISSET(udp_fd_, &readfds_) != 0); - udp_writable_ |= (FD_ISSET(udp_fd_, &writefds_) != 0); + tun_readable_ = (pollfd_[kPollFdTun].revents & POLLIN) != 0; + if (pollfd_[kPollFdTun].revents & POLLOUT) { + pollfd_[kPollFdTun].events = POLLIN; + tun_writable_ = true; + } + + if (pollfd_[kPollFdUdp].revents & (POLLERR | POLLHUP | POLLNVAL)) { + RERROR("UDP error %d, closing.", pollfd_[kPollFdUdp].revents); + break; + } + + udp_readable_ = (pollfd_[kPollFdUdp].revents & POLLIN) != 0; + if (pollfd_[kPollFdUdp].revents & POLLOUT) { + pollfd_[kPollFdUdp].events = POLLIN; + udp_writable_ = true; + } + + for(size_t i = 2; i < pollfd_num_; i++) { + if (pollfd_[i].revents && !HandleSpecialPollfd(&pollfd_[i], &sockinfo_[i - 2])) { + // Close the fd / discard the sockinfo + CloseSpecialPollfd(i); + i--; + } } } - + bool overload = (overload_ctr != 0); for(int loop = 0; ; loop++) { @@ -342,6 +513,8 @@ void TunsafeBackendBsdImpl::RunLoopInner() { processor_->RunAllMainThreadScheduled(); } + + un_deletion_watcher_.Stop(); } TunsafeBackendBsd *CreateTunsafeBackendBsd() { diff --git a/network_bsd_common.cpp b/network_bsd_common.cpp index 161f70f..f0a59fa 100644 --- a/network_bsd_common.cpp +++ b/network_bsd_common.cpp @@ -39,6 +39,8 @@ #include #include #include +#include +#include #endif void tunsafe_die(const char *msg) { @@ -286,15 +288,6 @@ void OsGetTimestampTAI64N(uint8 dst[12]) { WriteBE32(dst + 8, nanos); } -void OsGetRandomBytes(uint8 *data, size_t data_size) { - int fd = open("/dev/urandom", O_RDONLY); - int r = read(fd, data, data_size); - if (r < 0) r = 0; - close(fd); - for (; r < data_size; r++) - data[r] = rand() >> 6; -} - void OsInterruptibleSleep(int millis) { usleep((useconds_t)millis * 1000); } @@ -387,11 +380,12 @@ int open_tun(char *devname, size_t devname_size) { memset(&ifr, 0, sizeof(ifr)); ifr.ifr_flags = IFF_TUN | IFF_NO_PI; + my_strlcpy(ifr.ifr_name, sizeof(ifr.ifr_name), devname); if ((err = ioctl(fd, TUNSETIFF, (void *) &ifr)) < 0) { close(fd); return err; } - strcpy(devname, ifr.ifr_name); + my_strlcpy(devname, devname_size, ifr.ifr_name); return fd; } #endif @@ -411,6 +405,8 @@ int open_udp(int listen_on_port) { TunsafeBackendBsd::TunsafeBackendBsd() : processor_(NULL) { + devname_[0] = 0; + tun_interface_gone_ = false; } TunsafeBackendBsd::~TunsafeBackendBsd() { @@ -495,10 +491,10 @@ void TunsafeBackendBsd::DelRoute(const RouteInfo &cd) { static bool IsIpv6AddressSet(const void *p) { return (ReadLE64(p) | ReadLE64((char*)p + 8)) != 0; } - + // Called to initialize tun -bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out) override { - char devname[16]; +bool TunsafeBackendBsd::Configure(const TunConfig &&config, TunConfigOut *out) override { + char buf[kSizeOfAddress]; if (!RunPrePostCommand(config.pre_post_commands.pre_up)) { RERROR("Pre command failed!"); @@ -507,17 +503,35 @@ bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out) out->enable_neighbor_discovery_spoofing = false; - if (!InitializeTun(devname)) + if (!InitializeTun(devname_)) return false; - if (config.ipv6_cidr) - RERROR("IPv6 not supported"); - uint32 netmask = CidrToNetmaskV4(config.cidr); uint32 default_route_v4 = ComputeIpv4DefaultRoute(config.ip, netmask); - - RunCommand("/sbin/ifconfig %s %A mtu %d %A netmask %A up", devname, config.ip, config.mtu, config.ip, netmask); - AddRoute(config.ip & netmask, config.cidr, config.ip, devname); + + +#if defined(OS_LINUX) + if (config.ip) { + char ip[4]; + WriteBE32(ip, config.ip); + RunCommand("/sbin/ip address add dev %s %s", devname_, print_ip_prefix(buf, AF_INET, ip, config.cidr)); + } + if (config.ipv6_cidr) { + RunCommand("/sbin/ip address add dev %s %s", devname_, print_ip_prefix(buf, AF_INET6, config.ipv6_address, config.ipv6_cidr)); + } + RunCommand("/sbin/ip link set dev %s mtu %d up", devname_, config.mtu); +#else // !defined(OS_LINUX) + if (config.ip) { + RunCommand("/sbin/ifconfig %s %A mtu %d %A netmask %A up", devname_, config.ip, config.mtu, config.ip, netmask); + } + if (config.ipv6_cidr) { + RunCommand("/sbin/ifconfig %s inet6 add %s", devname_, print_ip_prefix(buf, AF_INET6, config.ipv6_address, config.ipv6_cidr)); + } +#endif // !defined(OS_LINUX) + + if (config.ip) { + AddRoute(config.ip & netmask, config.cidr, config.ip, devname_); + } if (config.use_ipv4_default_route) { if (config.default_route_endpoint_v4) { @@ -533,35 +547,30 @@ bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out) AddRoute(ReadBE32(it->addr), it->cidr, ipv4_default_gw, default_iface); } } - AddRoute(0x00000000, 1, default_route_v4, devname); - AddRoute(0x80000000, 1, default_route_v4, devname); + AddRoute(0x00000000, 1, default_route_v4, devname_); + AddRoute(0x80000000, 1, default_route_v4, devname_); } uint8 default_route_v6[16]; if (config.ipv6_cidr) { static const uint8 matchall_1_route[17] = {0x80, 0, 0, 0}; - char buf[kSizeOfAddress]; - ComputeIpv6DefaultRoute(config.ipv6_address, config.ipv6_cidr, default_route_v6); - - RunCommand("/sbin/ifconfig %s inet6 add %s", devname, print_ip_prefix(buf, AF_INET6, config.ipv6_address, config.ipv6_cidr)); - if (config.use_ipv6_default_route) { if (IsIpv6AddressSet(config.default_route_endpoint_v6)) { RERROR("default_route_endpoint_v6 not supported"); } - AddRoute(AF_INET6, matchall_1_route + 1, 1, default_route_v6, devname); - AddRoute(AF_INET6, matchall_1_route + 0, 1, default_route_v6, devname); + AddRoute(AF_INET6, matchall_1_route + 1, 1, default_route_v6, devname_); + AddRoute(AF_INET6, matchall_1_route + 0, 1, default_route_v6, devname_); } } // Add all the extra routes for (auto it = config.extra_routes.begin(); it != config.extra_routes.end(); ++it) { if (it->size == 32) { - AddRoute(ReadBE32(it->addr), it->cidr, default_route_v4, devname); + AddRoute(ReadBE32(it->addr), it->cidr, default_route_v4, devname_); } else if (it->size == 128 && config.ipv6_cidr) { - AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6, devname); + AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6, devname_); } } @@ -576,8 +585,10 @@ bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out) void TunsafeBackendBsd::CleanupRoutes() { RunPrePostCommand(pre_down_); - for(auto it = cleanup_commands_.begin(); it != cleanup_commands_.end(); ++it) - DelRoute(*it); + for(auto it = cleanup_commands_.begin(); it != cleanup_commands_.end(); ++it) { + if (!tun_interface_gone_ || strcmp(it->dev.c_str(), devname_) != 0) + DelRoute(*it); + } cleanup_commands_.clear(); RunPrePostCommand(post_down_); @@ -586,6 +597,10 @@ void TunsafeBackendBsd::CleanupRoutes() { post_down_.clear(); } +void TunsafeBackendBsd::SetTunDeviceName(const char *name) { + my_strlcpy(devname_, sizeof(devname_), name); +} + static bool RunOneCommand(const std::string &cmd) { RINFO("Run: %s", cmd.c_str()); int exit_code = system(cmd.c_str()); @@ -604,6 +619,94 @@ bool TunsafeBackendBsd::RunPrePostCommand(const std::vector &vec) { return success; } +#if defined(OS_LINUX) +UnixSocketDeletionWatcher::UnixSocketDeletionWatcher() + : inotify_fd_(-1) { + pipes_[0] = -1; + pipes_[0] = -1; +} + +UnixSocketDeletionWatcher::~UnixSocketDeletionWatcher() { + close(inotify_fd_); + close(pipes_[0]); + close(pipes_[1]); +} + +bool UnixSocketDeletionWatcher::Start(const char *path, bool *flag_to_set) { + assert(inotify_fd_ == -1); + path_ = path; + flag_to_set_ = flag_to_set; + pid_ = getpid(); + inotify_fd_ = inotify_init1(IN_CLOEXEC | IN_NONBLOCK); + if (inotify_fd_ == -1) { + perror("inotify_init1() failed"); + return false; + } + if (inotify_add_watch(inotify_fd_, "/var/run/wireguard", IN_DELETE | IN_DELETE_SELF) == -1) { + perror("inotify_add_watch failed"); + return false; + } + if (pipe(pipes_) == -1) { + perror("pipe() failed"); + return false; + } + return pthread_create(&thread_, NULL, &UnixSocketDeletionWatcher::RunThread, this) == 0; +} + +void UnixSocketDeletionWatcher::Stop() { + RINFO("Stopping.."); + void *retval; + write(pipes_[1], "", 1); + pthread_join(thread_, &retval); +} + +void *UnixSocketDeletionWatcher::RunThread(void *arg) { + UnixSocketDeletionWatcher *self = (UnixSocketDeletionWatcher*)arg; + return self->RunThreadInner(); +} + +void *UnixSocketDeletionWatcher::RunThreadInner() { + char buf[sizeof(struct inotify_event) + NAME_MAX + 1] + __attribute__ ((aligned(__alignof__(struct inotify_event)))); + fd_set fdset; + struct stat st; + for(;;) { + if (lstat(path_, &st) == -1 && errno == ENOENT) { + RINFO("Unix socket %s deleted.", path_); + *flag_to_set_ = true; + kill(pid_, SIGALRM); + break; + } + FD_ZERO(&fdset); + FD_SET(inotify_fd_, &fdset); + FD_SET(pipes_[0], &fdset); + int n = select(std::max(inotify_fd_, pipes_[0]) + 1, &fdset, NULL, NULL, NULL); + if (n == -1) { + perror("select"); + break; + } + if (FD_ISSET(inotify_fd_, &fdset)) { + ssize_t len = read(inotify_fd_, buf, sizeof(buf)); + if (len == -1) { + perror("read"); + break; + } + } + if (FD_ISSET(pipes_[0], &fdset)) + break; + } + return NULL; +} + +#else // !defined(OS_LINUX) + +bool UnixSocketDeletionWatcher::Poll(const char *path) { + struct stat st; + return lstat(path, &st) == -1 && errno == ENOENT; +} + +#endif // !defined(OS_LINUX) + static TunsafeBackendBsd *g_tunsafe_backend_bsd; static void SigAlrm(int sig) { @@ -611,10 +714,6 @@ static void SigAlrm(int sig) { g_tunsafe_backend_bsd->HandleSigAlrm(); } -static void SigUsr1(int sig) { - -} - static bool did_ctrlc; void SigInt(int sig) { @@ -623,6 +722,7 @@ void SigInt(int sig) { did_ctrlc = true; write(1, "Ctrl-C detected. Exiting. Press again to force quit.\n", sizeof("Ctrl-C detected. Exiting. Press again to force quit.\n")-1); + // todo: fix signal safety? if (g_tunsafe_backend_bsd) g_tunsafe_backend_bsd->HandleExit(); } @@ -631,7 +731,10 @@ void TunsafeBackendBsd::RunLoop() { assert(!g_tunsafe_backend_bsd); assert(processor_); + sigset_t mask; + g_tunsafe_backend_bsd = this; + // We want an alarm signal every second. { struct sigaction act = {0}; @@ -651,16 +754,14 @@ void TunsafeBackendBsd::RunLoop() { } } - { - struct sigaction act = {0}; - act.sa_handler = SigUsr1; - if (sigaction(SIGUSR1, &act, NULL) < 0) { - RERROR("Unable to install SIGUSR1 handler."); - return; - } +#if defined(OS_LINUX) || defined(OS_FREEBSD) + sigemptyset(&mask); + sigaddset(&mask, SIGALRM); + if (sigprocmask(SIG_BLOCK, &mask, &orig_signal_mask_) < 0) { + perror("sigprocmask"); + return; } -#if defined(OS_LINUX) || defined(OS_FREEBSD) { struct itimerspec tv = {0}; struct sigevent sev; @@ -727,7 +828,17 @@ public: bool is_connected_; }; +struct CommandLineOutput { + const char *filename_to_load; + const char *interface_name; + bool daemon; +}; + +int HandleCommandLine(int argc, char **argv, CommandLineOutput *output); + int main(int argc, char **argv) { + CommandLineOutput cmd = {0}; + InitCpuFeatures(); if (argc == 2 && strcmp(argv[1], "--benchmark") == 0) { @@ -735,12 +846,9 @@ int main(int argc, char **argv) { return 0; } - fprintf(stderr, "%s\n", TUNSAFE_VERSION_STRING); - - if (argc < 2) { - fprintf(stderr, "Syntax: tunsafe file.conf\n"); - return 1; - } + int rv = HandleCommandLine(argc, argv, &cmd); + if (!cmd.filename_to_load) + return rv; #if defined(OS_MACOSX) InitOsxGetMilliseconds(); @@ -749,19 +857,29 @@ int main(int argc, char **argv) { SetThreadName("tunsafe-m"); MyProcessorDelegate my_procdel; - TunsafeBackendBsd *socket_loop = CreateTunsafeBackendBsd(); - WireguardProcessor wg(socket_loop, socket_loop, &my_procdel); + TunsafeBackendBsd *backend = CreateTunsafeBackendBsd(); + if (cmd.interface_name) + backend->SetTunDeviceName(cmd.interface_name); + + WireguardProcessor wg(backend, backend, &my_procdel); my_procdel.wg_processor_ = &wg; - socket_loop->SetProcessor(&wg); + backend->SetProcessor(&wg); DnsResolver dns_resolver(NULL); - if (!ParseWireGuardConfigFile(&wg, argv[1], &dns_resolver)) return 1; + if (*cmd.filename_to_load && !ParseWireGuardConfigFile(&wg, cmd.filename_to_load, &dns_resolver)) + return 1; if (!wg.Start()) return 1; - socket_loop->RunLoop(); - socket_loop->CleanupRoutes(); - delete socket_loop; + if (cmd.daemon) { + fprintf(stderr, "Switching to daemon mode...\n"); + if (daemon(0, 0) == -1) + perror("daemon() failed"); + } + + backend->RunLoop(); + backend->CleanupRoutes(); + delete backend; return 0; } diff --git a/network_bsd_common.h b/network_bsd_common.h index cc14bc6..36d566c 100644 --- a/network_bsd_common.h +++ b/network_bsd_common.h @@ -7,6 +7,7 @@ #include "wireguard.h" #include "wireguard_config.h" #include +#include struct RouteInfo { uint8 family; @@ -16,6 +17,39 @@ struct RouteInfo { std::string dev; }; +#if defined(OS_LINUX) +// Keeps track of when the unix socket gets deleted +class UnixSocketDeletionWatcher { +public: + UnixSocketDeletionWatcher(); + ~UnixSocketDeletionWatcher(); + bool Start(const char *path, bool *flag_to_set); + void Stop(); + bool Poll(const char *path) { return false; } + +private: + static void *RunThread(void *arg); + void *RunThreadInner(); + const char *path_; + int inotify_fd_; + int pid_; + int pipes_[2]; + pthread_t thread_; + bool *flag_to_set_; +}; +#else // !defined(OS_LINUX) +// all other platforms that lack inotify +class UnixSocketDeletionWatcher { +public: + UnixSocketDeletionWatcher() {} + ~UnixSocketDeletionWatcher() {} + bool Start(const char *path, bool *flag_to_set) { return true; } + void Stop() {} + bool Poll(const char *path); +}; +#endif // !defined(OS_LINUX) + + class TunsafeBackendBsd : public TunInterface, public UdpInterface { public: TunsafeBackendBsd(); @@ -24,10 +58,12 @@ public: void RunLoop(); void CleanupRoutes(); + void SetTunDeviceName(const char *name); + void SetProcessor(WireguardProcessor *wg) { processor_ = wg; } // -- from TunInterface - virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) override; + virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override; virtual void HandleSigAlrm() = 0; virtual void HandleExit() = 0; @@ -44,6 +80,9 @@ protected: WireguardProcessor *processor_; std::vector cleanup_commands_; std::vector pre_down_, post_down_; + sigset_t orig_signal_mask_; + char devname_[16]; + bool tun_interface_gone_; }; #if defined(OS_MACOSX) || defined(OS_FREEBSD) diff --git a/network_win32.cpp b/network_win32.cpp index 2b9c680..5bc7f01 100644 --- a/network_win32.cpp +++ b/network_win32.cpp @@ -60,20 +60,6 @@ static bool IsIpv6AddressSet(const void *p) { return (ReadLE64(p) | ReadLE64((char*)p + 8)) != 0; } -void OsGetRandomBytes(uint8 *data, size_t data_size) { - static BOOLEAN(APIENTRY *pfn)(void*, ULONG); - static bool resolved; - if (!resolved) { - pfn = (BOOLEAN(APIENTRY *)(void*, ULONG))GetProcAddress(LoadLibrary("ADVAPI32.DLL"), "SystemFunction036"); - resolved = true; - } - if (pfn && pfn(data, (ULONG)data_size)) - return; - size_t r = 0; - for (; r < data_size; r++) - data[r] = rand() >> 6; -} - void OsInterruptibleSleep(int millis) { SleepEx(millis, TRUE); } @@ -140,6 +126,7 @@ struct { #define kConcurrentWriteTap 16 #define kAdapterKeyName "SYSTEM\\CurrentControlSet\\Control\\Class\\{4D36E972-E325-11CE-BFC1-08002BE10318}" +#define kNetworkConnectionsKeyName "SYSTEM\\CurrentControlSet\\Control\\Network\\{4D36E972-E325-11CE-BFC1-08002BE10318}" #define kTapComponentId "tap0901" #define TAP_CONTROL_CODE(request,method) \ @@ -196,91 +183,79 @@ static bool RunNetsh(const char *cmdline) { return result; } -// Retrieve the device path to the TAP adapter. -static bool GetTapAdapterGuid(char guid[64]) { - LONG err; - HKEY adapter_key, device_key; - bool retval = false; - err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, kAdapterKeyName, 0, KEY_READ, &adapter_key); - if (err != ERROR_SUCCESS) { - RERROR("GetTapAdapterName: RegOpenKeyEx failed: 0x%X", GetLastError()); - return false; - } - for (int i = 0; !retval; i++) { - char keyname[64 + sizeof(kAdapterKeyName) + 1]; - char value[64]; - DWORD len = sizeof(value), type; - err = RegEnumKeyEx(adapter_key, i, value, &len, NULL, NULL, NULL, NULL); - if (err == ERROR_NO_MORE_ITEMS) - break; - if (err != ERROR_SUCCESS) { - RERROR("GetTapAdapterName: RegEnumKeyEx failed: 0x%X", GetLastError()); - return false; - } - snprintf(keyname, sizeof(keyname), "%s\\%s", kAdapterKeyName, value); - err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, keyname, 0, KEY_READ, &device_key); - if (err == ERROR_SUCCESS) { - len = sizeof(value); - err = RegQueryValueEx(device_key, "ComponentId", NULL, &type, (LPBYTE)value, &len); - if (err == ERROR_SUCCESS && type == REG_SZ && !memcmp(value, kTapComponentId, sizeof(kTapComponentId))) { - len = 64; - err = RegQueryValueEx(device_key, "NetCfgInstanceId", NULL, &type, (LPBYTE)guid, &len); - if (err == ERROR_SUCCESS && type == REG_SZ) { - guid[63] = 0; - retval = true; - } - } - RegCloseKey(device_key); - } - } - RegCloseKey(adapter_key); - return retval; -} - -// Open the TAP adapter -static HANDLE OpenTunAdapter(char guid[64], int retry_count, uint32 *exit_thread, DWORD open_flags) { +// Open the TAP adapter, either a random one or a specific one +// On return, the adapter is locked in |TunAdaptersInUse|. +static HANDLE OpenTunAdapter(char guid[ADAPTER_GUID_SIZE], TunsafeBackendWin32 *backend, DWORD open_flags) { char path[128]; HANDLE h; int retries = 0; - if (!GetTapAdapterGuid(guid)) { - RERROR("Unable to find ID of TAP adapter"); - RERROR(" Please ensure that TunSafe-TAP is properly installed."); - return NULL; - } - snprintf(path, sizeof(path), "\\\\.\\Global\\%s.tap", guid); -RETRY: - h = CreateFile(path, GENERIC_READ | GENERIC_WRITE, 0, 0, OPEN_EXISTING, - FILE_ATTRIBUTE_SYSTEM | open_flags, 0); - if (h == INVALID_HANDLE_VALUE) { - int error_code = GetLastError(); - - // Sometimes if you close the device right before, it will fail to open with errorcode 31. - // When resuming from sleep in my VM, the error code is ERROR_FILE_NOT_FOUND - if ((error_code == ERROR_FILE_NOT_FOUND || error_code == ERROR_GEN_FAILURE) && retry_count != 0 && !*exit_thread) { - RERROR("OpenTapAdapter: CreateFile failed: 0x%X... retrying", error_code); - retry_count--; - - int sleep_amount = 250 * ++retries; - for(;;) { - if (*exit_thread) - return NULL; - if (sleep_amount == 0) - break; - Sleep(50); - sleep_amount -= 50; - } - goto RETRY; - } - - RERROR("OpenTapAdapter: CreateFile failed: 0x%X", error_code); - if (error_code == ERROR_FILE_NOT_FOUND) { + std::vector adapters; + + // When guid is empty, we try all adapters, otherwise + // just try the specific adapter + if (guid[0] == 0) { + GetTapAdapterInfo(&adapters); + if (adapters.empty()) { + RERROR("Unable to find any TAP adapters"); RERROR(" Please ensure that TunSafe-TAP is properly installed."); - } else if (error_code == 0x1f) { - RERROR(" Please ensure that the TAP device is not in use."); + return NULL; } + } else { + adapters.emplace_back(); + memcpy(adapters.back().guid, guid, ADAPTER_GUID_SIZE); + adapters.back().name[0] = 0; + } + TunAdaptersInUse *tun_adapters_in_use = TunAdaptersInUse::GetInstance(); + +RETRY: + bool did_try_adapter = false; + int error_code = 0; + for (GuidAndDevName &x : adapters) { + snprintf(path, sizeof(path), "\\\\.\\Global\\%s.tap", x.guid); + if (tun_adapters_in_use->Acquire(x.guid, static_cast(backend))) { + h = CreateFile(path, GENERIC_READ | GENERIC_WRITE, 0, 0, OPEN_EXISTING, FILE_ATTRIBUTE_SYSTEM | open_flags, 0); + if (h != INVALID_HANDLE_VALUE) { + memcpy(guid, x.guid, ADAPTER_GUID_SIZE); + return h; + } + did_try_adapter = true; + error_code = GetLastError(); + tun_adapters_in_use->Release(static_cast(backend)); + } + } + if (!did_try_adapter) { + RERROR("All TAP adapters are currently in use"); return NULL; } - return h; + + // Sometimes if you close the device right before, it will fail to open with errorcode 31. + // When resuming from sleep in my VM, the error code is ERROR_FILE_NOT_FOUND + if ((error_code == ERROR_FILE_NOT_FOUND || error_code == ERROR_GEN_FAILURE) && !backend->exit_code()) { + if (retries <= 10) { + RERROR("OpenTapAdapter: CreateFile failed: 0x%X... retrying%s", error_code, retries == 10 ? " (last notice)" : ""); + if (retries == 10) { + if (error_code == ERROR_FILE_NOT_FOUND) { + RERROR(" Please ensure that TunSafe-TAP is properly installed."); + } else if (error_code == ERROR_GEN_FAILURE) { + RERROR(" Please ensure that the TAP device is not in use."); + } + backend->SetStatus(TunsafeBackend::kStatusTunRetrying); + } + } + int sleep_amount = 250 * std::min(++retries, 40); + for (;;) { + if (backend->exit_code()) + return NULL; + if (sleep_amount == 0) + break; + Sleep(125); + sleep_amount -= 125; + } + goto RETRY; + } + + RERROR("OpenTapAdapter: CreateFile failed: 0x%X", error_code); + return NULL; } static bool AddRoute(int family, @@ -466,60 +441,85 @@ UdpSocketWin32::UdpSocketWin32() { wqueue_end_ = &wqueue_; wqueue_ = NULL; exit_thread_ = false; - socket_ = INVALID_SOCKET; thread_ = NULL; + socket_ = INVALID_SOCKET; socket_ipv6_ = INVALID_SOCKET; completion_port_handle_ = NULL; } UdpSocketWin32::~UdpSocketWin32() { assert(thread_ == NULL); + CloseHandle(completion_port_handle_); closesocket(socket_); closesocket(socket_ipv6_); - CloseHandle(completion_port_handle_); FreePacketList(wqueue_); } -bool UdpSocketWin32::Initialize(int listen_on_port) { - SOCKET s = WSASocket(AF_INET, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED); - if (s == INVALID_SOCKET) { - RERROR("UdpSocketWin32::Initialize WSASocket failed"); - return false; +bool UdpSocketWin32::Configure(int listen_on_port) { + // If attempting to initialize when the thread is already started, then stop + // the thread, reinitialize, and start the thread. + if (thread_ != NULL) { + StopThread(); + bool retcode = Configure(listen_on_port); + StartThread(); + return retcode; } - completion_port_handle_ = CreateIoCompletionPort((HANDLE)s, NULL, NULL, 0); - if (!completion_port_handle_) { - closesocket(s); - return false; - } - socket_ = s; - sockaddr_in sin = {0}; - sin.sin_family = AF_INET; - sin.sin_port = htons(listen_on_port); - if (bind(s, (struct sockaddr*)&sin, sizeof(sin)) != 0) { - RERROR("UdpSocketWin32::Initialize bind failed"); - return false; + bool retval = false; + HANDLE completion_port = NULL; + SOCKET socket_ipv4 = INVALID_SOCKET, socket_ipv6 = INVALID_SOCKET; + + socket_ipv4 = WSASocket(AF_INET, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED); + if (socket_ipv4 == INVALID_SOCKET) { + RERROR("UdpSocketWin32::Initialize WSASocket failed"); + goto fail; + } + completion_port = CreateIoCompletionPort((HANDLE)socket_ipv4, NULL, NULL, 0); + if (!completion_port) { + RERROR("UdpSocketWin32::Initialize CreateIoCompletionPort failed"); + goto fail; + } + + { + sockaddr_in sin = {0}; + sin.sin_family = AF_INET; + sin.sin_port = htons(listen_on_port); + if (bind(socket_ipv4, (struct sockaddr*)&sin, sizeof(sin)) != 0) { + RERROR("UdpSocketWin32::Initialize bind failed"); + goto fail; + } } // Also open up a socket for ipv6 - s = WSASocket(AF_INET6, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED); - if (s != INVALID_SOCKET) { - if (!CreateIoCompletionPort((HANDLE)s, completion_port_handle_, 1, 0)) { + socket_ipv6 = WSASocket(AF_INET6, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED); + if (socket_ipv6 != INVALID_SOCKET) { + if (!CreateIoCompletionPort((HANDLE)socket_ipv6, completion_port, 1, 0)) { RERROR("IPv6 Socket completion port failed."); - closesocket(s); + closesocket(socket_ipv6); + socket_ipv6 = INVALID_SOCKET; } else { - socket_ipv6_ = s; sockaddr_in6 sin6 = {0}; sin6.sin6_family = AF_INET6; sin6.sin6_port = htons(listen_on_port); - if (bind(s, (struct sockaddr*)&sin6, sizeof(sin6)) != 0) { + if (bind(socket_ipv6, (struct sockaddr*)&sin6, sizeof(sin6)) != 0) { RERROR("UdpSocketWin32::Initialize bind failed IPv6"); } } } else { RERROR("IPv6 Socket creation failed."); } - return true; + std::swap(socket_ipv6_, socket_ipv6); + std::swap(socket_, socket_ipv4); + std::swap(completion_port_handle_, completion_port); + retval = true; +fail: + if (completion_port) + CloseHandle(completion_port); + if (socket_ipv4 != INVALID_SOCKET) + closesocket(socket_ipv4); + if (socket_ipv6 != INVALID_SOCKET) + closesocket(socket_ipv6); + return retval; } enum { @@ -556,7 +556,7 @@ void UdpSocketWin32::ThreadMain() { break; restart_read_udp6: ClearOverlapped(&p->overlapped); - p->post_target = ThreadedPacketQueue::TARGET_PROCESSOR_UDP; + p->post_target = PacketProcessor::TARGET_PROCESSOR_UDP; WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data}; DWORD flags = 0; p->sin_size = sizeof(p->addr.sin6); @@ -580,7 +580,7 @@ restart_read_udp6: break; restart_read_udp: ClearOverlapped(&p->overlapped); - p->post_target = ThreadedPacketQueue::TARGET_PROCESSOR_UDP; + p->post_target = PacketProcessor::TARGET_PROCESSOR_UDP; WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data}; DWORD flags = 0; p->sin_size = sizeof(p->addr.sin); @@ -620,7 +620,7 @@ restart_read_udp: if (!entries[i].lpOverlapped) continue; // This is the dummy entry from |PostQueuedCompletionStatus| Packet *p = (Packet*)((byte*)entries[i].lpOverlapped - offsetof(Packet, overlapped)); - if (p->post_target == ThreadedPacketQueue::TARGET_PROCESSOR_UDP) { + if (p->post_target == PacketProcessor::TARGET_PROCESSOR_UDP) { num_reads[entries[i].lpCompletionKey]--; if ((DWORD)p->overlapped.Internal != 0) { if (!IsIgnoredUdpError((DWORD)p->overlapped.Internal)) @@ -673,7 +673,7 @@ restart_read_udp: Packet *p = pending_writes; pending_writes = p->next; ClearOverlapped(&p->overlapped); - p->post_target = ThreadedPacketQueue::TARGET_UDP_DEVICE; + p->post_target = PacketProcessor::TARGET_UDP_DEVICE; WSABUF wsabuf = {(ULONG)p->size, (char*)p->data}; int rv; @@ -715,7 +715,7 @@ restart_read_udp: if (!entries[0].lpOverlapped) continue; // This is the dummy entry from |PostQueuedCompletionStatus| Packet *p = (Packet*)((byte*)entries[0].lpOverlapped - offsetof(Packet, overlapped)); - if (p->post_target == ThreadedPacketQueue::TARGET_PROCESSOR_UDP) { + if (p->post_target == PacketProcessor::TARGET_PROCESSOR_UDP) { num_reads[entries[0].lpCompletionKey]--; } else { num_writes--; @@ -724,8 +724,6 @@ restart_read_udp: } } - - // Called on another thread to queue up a udp packet void UdpSocketWin32::WriteUdpPacket(Packet *packet) { if (qs.udp_qsize2 - qs.udp_qsize1 >= (unsigned)(packet->size < 576 ? MAX_BYTES_IN_UDP_OUT_QUEUE_SMALL : MAX_BYTES_IN_UDP_OUT_QUEUE)) { @@ -754,73 +752,111 @@ DWORD WINAPI UdpSocketWin32::UdpThread(void *x) { } void UdpSocketWin32::StartThread() { + assert(completion_port_handle_); + DWORD thread_id; thread_ = CreateThread(NULL, 0, &UdpThread, this, 0, &thread_id); SetThreadPriority(thread_, ABOVE_NORMAL_PRIORITY_CLASS); } void UdpSocketWin32::StopThread() { - exit_thread_ = true; - PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL); - WaitForSingleObject(thread_, INFINITE); - CloseHandle(thread_); - thread_ = NULL; + if (thread_ != NULL) { + exit_thread_ = true; + PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL); + WaitForSingleObject(thread_, INFINITE); + CloseHandle(thread_); + thread_ = NULL; + exit_thread_ = false; + } } -ThreadedPacketQueue::ThreadedPacketQueue(WireguardProcessor *wg, TunsafeBackendWin32 *backend) { - wg_ = wg; - backend_ = backend; +PacketProcessor::PacketProcessor() { event_ = CreateEvent(NULL, FALSE, FALSE, NULL); last_ptr_ = &first_; first_ = NULL; - handle_ = NULL; - timer_handle_ = NULL; - exit_flag_ = false; + exit_code_ = 0; timer_interrupt_ = false; packets_in_queue_ = 0; need_notify_ = 0; } -ThreadedPacketQueue::~ThreadedPacketQueue() { - assert(handle_ == NULL); - assert(timer_handle_ == NULL); +PacketProcessor::~PacketProcessor() { first_ = NULL; last_ptr_ = &first_; CloseHandle(event_); } -DWORD WINAPI ThreadedPacketQueue::ThreadedPacketQueueLauncher(VOID *x) { - ThreadedPacketQueue *pq = (ThreadedPacketQueue *)x; - return pq->ThreadMain(); +void CALLBACK PacketProcessor::ThreadPoolTimerCallback(PTP_CALLBACK_INSTANCE iTimerInstance, PVOID pContext, PTP_TIMER) { + PacketProcessor *th = (PacketProcessor *)pContext; + th->mutex_.Acquire(); + th->timer_interrupt_ = true; + if (th->need_notify_) { + th->need_notify_ = 0; + th->mutex_.Release(); + SetEvent(th->event_); + return; + } + th->mutex_.Release(); } -DWORD ThreadedPacketQueue::ThreadMain() { - int free_packets_ctr = 0; - int overload = 0; +struct ConfigPacket { + std::string message; + uint32 ident; + Packet packet; +}; + +void PacketProcessor::Reset() { Packet *packet; - wg_->dev().SetCurrentThreadAsMainThread(); + packet = first_; + first_ = NULL; + exit_code_ = 0; + last_ptr_ = &first_; + timer_interrupt_ = false; + + while (packet) { + Packet *next = packet->next; + if (packet->post_target == TARGET_CONFIG_PROTOCOL) { + ConfigPacket *config = (ConfigPacket*)((uint8*)packet - offsetof(ConfigPacket, packet)); + delete config; + } else { + FreePacket(packet); + } + packet = next; + } +} + +int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) { + int free_packets_ctr = 0; + int overload = 0; + int exit_code; + Packet *packet; + PTP_TIMER threadpool_timer; + + threadpool_timer = CreateThreadpoolTimer(&ThreadPoolTimerCallback, this, NULL); + static const int64 duetime = -10000000; // the unit is 100ns + SetThreadpoolTimer(threadpool_timer, (FILETIME*)&duetime, 1000, 1000); mutex_.Acquire(); - while (!exit_flag_) { + while (!(exit_code = exit_code_)) { if (timer_interrupt_) { timer_interrupt_ = false; need_notify_ = 0; mutex_.Release(); - wg_->SecondLoop(); - backend_->stats_mutex_.Acquire(); - backend_->stats_ = wg_->GetStats(); + wg->SecondLoop(); + backend->stats_mutex_.Acquire(); + backend->stats_ = wg->GetStats(); float data[2] = { // unit is megabits/second - backend_->stats_.tun_bytes_in_per_second * (1.0f / 125000), - backend_->stats_.tun_bytes_out_per_second * (1.0f / 125000), + backend->stats_.tun_bytes_in_per_second * (1.0f / 125000), + backend->stats_.tun_bytes_out_per_second * (1.0f / 125000), }; - backend_->stats_collector_.AddSamples(data); - backend_->stats_mutex_.Release(); + backend->stats_collector_.AddSamples(data); + backend->stats_mutex_.Release(); - backend_->delegate_->OnGraphAvailable(); - backend_->PushStats(); + backend->delegate_->OnGraphAvailable(); + backend->PushStats(); // Conserve memory every 10s if (free_packets_ctr++ == 10) { @@ -847,64 +883,50 @@ DWORD ThreadedPacketQueue::ThreadMain() { overload = 2; bool is_overload = (overload != 0); - WireguardProcessor *procint = wg_; do { Packet *next = packet->next; - if (packet->post_target == TARGET_PROCESSOR_UDP) - procint->HandleUdpPacket(packet, is_overload); - else - procint->HandleTunPacket(packet); + if (packet->post_target == TARGET_PROCESSOR_UDP) { + wg->HandleUdpPacket(packet, is_overload); + } else if (packet->post_target == TARGET_PROCESSOR_TUN) { + wg->HandleTunPacket(packet); + } else { + assert(packet->post_target == TARGET_CONFIG_PROTOCOL); + HandleConfigurationProtocolPacket(wg, backend, packet); + } packet = next; } while (packet); } - wg_->RunAllMainThreadScheduled(); + wg->RunAllMainThreadScheduled(); mutex_.Acquire(); } + exit_code_ = 0; mutex_.Release(); - return 0; + + SetThreadpoolTimer(threadpool_timer, nullptr, 0, 0); + WaitForThreadpoolTimerCallbacks(threadpool_timer, true); + CloseThreadpoolTimer(threadpool_timer); + + return exit_code; } -void ThreadedPacketQueue::Start() { - if (handle_ == NULL) { - exit_flag_ = false; - DWORD thread_id; - handle_ = CreateThread(NULL, 0, &ThreadedPacketQueueLauncher, this, 0, &thread_id); - } +void PacketProcessor::HandleConfigurationProtocolPacket(WireguardProcessor *wg, TunsafeBackendWin32 *backend, Packet *packet) { + ConfigPacket *config = (ConfigPacket*)((uint8*)packet - offsetof(ConfigPacket, packet)); + std::string reply; + WgConfig::HandleConfigurationProtocolMessage(wg, std::move(config->message), &reply); + backend->delegate_->OnConfigurationProtocolReply(config->ident, std::move(reply)); - assert(timer_handle_ == NULL); - timer_handle_ = CreateWaitableTimer(NULL, FALSE, NULL); - long long due_time = 10000000; - SetWaitableTimer(timer_handle_, (LARGE_INTEGER*)&due_time, 1000, &TimerRoutine, this, FALSE); } -void ThreadedPacketQueue::Stop() { +void PacketProcessor::PostExit(int exit_code) { mutex_.Acquire(); - exit_flag_ = true; + // Avoid race condition where mode_tun_failed is set during thread exit. + if (exit_code_ != TunsafeBackendWin32::MODE_RESTART && exit_code_ != TunsafeBackendWin32::MODE_EXIT) + exit_code_ = exit_code; mutex_.Release(); - SetEvent(event_); - - if (timer_handle_ != NULL) { - // Not sure if just CloseHandle will close any outstanding APCs - CancelWaitableTimer(timer_handle_); - CloseHandle(timer_handle_); - timer_handle_ = NULL; - } - - if (handle_ != NULL) { - WaitForSingleObject(handle_, INFINITE); - CloseHandle(handle_); - handle_ = NULL; - } } -void ThreadedPacketQueue::AbortingDriver() { - mutex_.Acquire(); - exit_flag_ = true; - mutex_.Release(); -} - -void ThreadedPacketQueue::Post(Packet *packet, Packet **end, int count) { +void PacketProcessor::Post(Packet *packet, Packet **end, int count) { mutex_.Acquire(); if (packets_in_queue_ >= HARD_MAXIMUM_QUEUE_SIZE) { mutex_.Release(); @@ -912,15 +934,11 @@ void ThreadedPacketQueue::Post(Packet *packet, Packet **end, int count) { return; } assert(packet != NULL); - if (!first_) { - assert(last_ptr_ == &first_); - } + assert(first_ || last_ptr_ == &first_); packets_in_queue_ += count; *last_ptr_ = packet; last_ptr_ = end; - if (!first_) { - assert(last_ptr_ == &first_); - } + assert(first_ || last_ptr_ == &first_); if (need_notify_) { need_notify_ = 0; mutex_.Release(); @@ -930,13 +948,12 @@ void ThreadedPacketQueue::Post(Packet *packet, Packet **end, int count) { mutex_.Release(); } -void CALLBACK ThreadedPacketQueue::TimerRoutine(LPVOID lpArgToCompletionRoutine, DWORD dwTimerLowValue, DWORD dwTimerHighValue) { - ((ThreadedPacketQueue*)lpArgToCompletionRoutine)->PostTimerInterrupt(); -} - -void ThreadedPacketQueue::PostTimerInterrupt() { +void PacketProcessor::ForcePost(Packet *packet) { mutex_.Acquire(); - timer_interrupt_ = true; + packet->next = NULL; + packets_in_queue_ += 1; + *last_ptr_ = packet; + last_ptr_ = &packet->next; if (need_notify_) { need_notify_ = 0; mutex_.Release(); @@ -1125,29 +1142,47 @@ static bool AddMultipleCatchallRoutes(int inet, int bits, const uint8 *target, c return success; } -TunWin32Adapter::TunWin32Adapter(DnsBlocker *dns_blocker) { +TunWin32Adapter::TunWin32Adapter(DnsBlocker *dns_blocker, const char guid[ADAPTER_GUID_SIZE]) { handle_ = NULL; dns_blocker_ = dns_blocker; old_ipv6_address_.size = 0; old_ipv6_metric_ = kMetricNone; old_ipv4_metric_ = kMetricNone; has_dns6_setting_ = false; + guid_[0] = 0; + if (guid) + memcpy(guid_, guid, ADAPTER_GUID_SIZE); } TunWin32Adapter::~TunWin32Adapter() { } -bool TunWin32Adapter::OpenAdapter(uint32 *exit_thread, DWORD open_flags) { +bool TunWin32Adapter::OpenAdapter(TunsafeBackendWin32 *backend, DWORD open_flags) { + ULONG info[3]; + DWORD len; assert(handle_ == NULL); - int retry_count = 20; - handle_ = OpenTunAdapter(guid_, retry_count, exit_thread, open_flags); + backend_ = backend; + handle_ = OpenTunAdapter(guid_, backend, open_flags); + if (handle_ != NULL) { + memset(info, 0, sizeof(info)); + if (DeviceIoControl(handle_, TAP_IOCTL_GET_VERSION, &info, sizeof(info), + &info, sizeof(info), &len, NULL)) { + RINFO("TAP Driver Version %d.%d %s", (int)info[0], (int)info[1], (info[2] ? "(DEBUG)" : "")); + } + + if (info[0] < 9 || info[0] == 9 && info[1] <= 8) { + RERROR("TAP is too old. Go to https://tunsafe.com/download to upgrade the driver"); + CloseHandle(handle_); + handle_ = NULL; + } + } return (handle_ != NULL); } -bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out) { - ULONG info[3]; - DWORD len; +bool TunWin32Adapter::ConfigureAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out) { + DWORD len, err; + out->enable_neighbor_discovery_spoofing = false; if (!RunPrePostCommand(config.pre_post_commands.pre_up)) { @@ -1158,28 +1193,11 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt pre_down_ = std::move(config.pre_post_commands.pre_down); post_down_ = std::move(config.pre_post_commands.post_down); - memset(info, 0, sizeof(info)); - if (DeviceIoControl(handle_, TAP_IOCTL_GET_VERSION, &info, sizeof(info), - &info, sizeof(info), &len, NULL)) { - RINFO("TAP Driver Version %d.%d %s", (int)info[0], (int)info[1], (info[2] ? "(DEBUG)" : "")); - } - - if (info[0] < 9 || info[0] == 9 && info[1] <= 8) { - RERROR("TAP is too old. Go to https://tunsafe.com/download to upgrade the driver"); - return false; - } - - // ULONG mtu = 0; - // if (DeviceIoControl(handle_, TAP_IOCTL_GET_MTU, &mtu, sizeof(mtu), &mtu, sizeof(mtu), &len, NULL)) - // RINFO("TAP-Win32 MTU=%d", (int)mtu); - // mtu_ = mtu; - uint32 netmask = CidrToNetmaskV4(config.cidr); // Set TAP-Windows TUN subnet mode if (1) { uint32 v[3]; - v[0] = htonl(config.ip); v[1] = htonl(config.ip & netmask); v[2] = htonl(netmask); @@ -1237,14 +1255,11 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt } bool has_interface_luid = GetNetLuidFromGuid(guid_, &interface_luid_); - if (!has_interface_luid) { RERROR("Unable to determine interface luid for %s.", guid_); return false; } - DWORD err; - if (config.mtu) { err = SetMtuOnNetworkAdapter(&interface_luid_, AF_INET, config.mtu); if (err) @@ -1419,11 +1434,10 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt } RunPrePostCommand(config.pre_post_commands.post_up); - return true; } -void TunWin32Adapter::CloseAdapter() { +void TunWin32Adapter::CloseAdapter(bool is_restart) { RunPrePostCommand(pre_down_); if (handle_ != NULL) { @@ -1433,6 +1447,8 @@ void TunWin32Adapter::CloseAdapter() { &status, sizeof(status), &len, NULL); CloseHandle(handle_); handle_ = NULL; + + TunAdaptersInUse::GetInstance()->Release(backend_); } if (old_ipv6_address_.size != 0) @@ -1441,23 +1457,24 @@ void TunWin32Adapter::CloseAdapter() { SetMetricOnNetworkAdapter(&interface_luid_, AF_INET, old_ipv4_metric_, NULL); if (old_ipv6_metric_ != kMetricNone) SetMetricOnNetworkAdapter(&interface_luid_, AF_INET6, old_ipv6_metric_, NULL); + if (has_dns6_setting_) + SetIPV6DnsOnInterface(&interface_luid_, NULL, 0); old_ipv4_metric_ = old_ipv6_metric_ = -1; old_ipv6_address_.size = 0; - - if (has_dns6_setting_) { - has_dns6_setting_ = false; - SetIPV6DnsOnInterface(&interface_luid_, NULL, 0); - } + has_dns6_setting_ = false; for (auto it = routes_to_undo_.begin(); it != routes_to_undo_.end(); ++it) DeleteRoute(&*it); routes_to_undo_.clear(); - if (dns_blocker_) + if (!is_restart && dns_blocker_) dns_blocker_->RestoreDns(); RunPrePostCommand(post_down_); + + pre_down_.clear(); + post_down_.clear(); } static bool RunOneCommand(const std::string &cmd) { @@ -1564,7 +1581,7 @@ bool TunWin32Adapter::RunPrePostCommand(const std::vector &vec) { ////////////////////////////////////////////////////////////////////////////// -TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker), backend_(backend) { +TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) { wqueue_end_ = &wqueue_; wqueue_ = NULL; @@ -1577,36 +1594,38 @@ TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : TunWin32Iocp::~TunWin32Iocp() { //assert(num_reads_ == 0 && num_writes_ == 0); assert(thread_ == NULL); - CloseTun(); + CloseTun(false); + FreePacketList(wqueue_); } -bool TunWin32Iocp::Initialize(const TunConfig &&config, TunConfigOut *out) { - assert(thread_ == NULL); - - if (adapter_.OpenAdapter(&backend_->stop_mode_, FILE_FLAG_OVERLAPPED)) { +bool TunWin32Iocp::Configure(const TunConfig &&config, TunConfigOut *out) { + // Reconfigure while started? + if (thread_ != NULL) { + assert(completion_port_handle_); + StopThread(); + bool rv = Configure(std::move(config), out); + StartThread(); + return rv; + } + CloseTun(true); + if (adapter_.OpenAdapter(backend_, FILE_FLAG_OVERLAPPED)) { completion_port_handle_ = CreateIoCompletionPort(adapter_.handle(), NULL, NULL, 0); if (completion_port_handle_ != NULL) { - if (adapter_.InitAdapter(std::move(config), out)) + if (adapter_.ConfigureAdapter(std::move(config), out)) return true; } } - CloseTun(); + CloseTun(false); return false; } -void TunWin32Iocp::CloseTun() { +void TunWin32Iocp::CloseTun(bool is_restart) { assert(thread_ == NULL); - - adapter_.CloseAdapter(); - + adapter_.CloseAdapter(is_restart); if (completion_port_handle_) { CloseHandle(completion_port_handle_); completion_port_handle_ = NULL; } - - FreePacketList(wqueue_); - wqueue_ = NULL; - wqueue_end_ = &wqueue_; } enum { @@ -1621,6 +1640,8 @@ void TunWin32Iocp::ThreadMain() { Packet *freed_packets = NULL, **freed_packets_end; int freed_packets_count = 0; DWORD err; + if (!completion_port_handle_) + return; while (!exit_thread_) { // Initiate more reads, reusing the Packet structures in |finished_writes|. @@ -1629,19 +1650,16 @@ void TunWin32Iocp::ThreadMain() { if (!AllocPacketFrom(&freed_packets, &freed_packets_count, &exit_thread_, &p)) break; memset(&p->overlapped, 0, sizeof(p->overlapped)); - p->post_target = ThreadedPacketQueue::TARGET_PROCESSOR_TUN; + p->post_target = PacketProcessor::TARGET_PROCESSOR_TUN; if (!ReadFile(adapter_.handle(), p->data, kPacketCapacity, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { FreePacket(p); RERROR("TunWin32: ReadFile failed 0x%X", err); if (err == ERROR_OPERATION_ABORTED || err == ERROR_FILE_NOT_FOUND) { - packet_handler_->AbortingDriver(); RERROR("TAP driver stopped communicating. Attempting to restart.", err); - // This can happen if we reinstall the TAP driver while there's an active connection. Wait a bit, then attempt to - // restart. - Sleep(1000); - backend_->TunAdapterFailed(); + // This can happen if we reinstall the TAP driver while there's an active connection. + backend_->PostExit(TunsafeBackendWin32::MODE_TUN_FAILED); goto EXIT; } } else { @@ -1673,7 +1691,7 @@ void TunWin32Iocp::ThreadMain() { if (!entries[i].lpOverlapped) continue; // This is the dummy entry from |PostQueuedCompletionStatus| Packet *p = (Packet*)((byte*)entries[i].lpOverlapped - offsetof(Packet, overlapped)); - if (p->post_target == ThreadedPacketQueue::TARGET_PROCESSOR_TUN) { + if (p->post_target == PacketProcessor::TARGET_PROCESSOR_TUN) { num_reads--; if ((int)p->overlapped.Internal != 0) { RERROR("TunWin32::ReadComplete error 0x%X", (int)p->overlapped.Internal); @@ -1721,7 +1739,7 @@ void TunWin32Iocp::ThreadMain() { Packet *p = pending_writes; pending_writes = p->next; memset(&p->overlapped, 0, sizeof(p->overlapped)); - p->post_target = ThreadedPacketQueue::TARGET_TUN_DEVICE; + p->post_target = PacketProcessor::TARGET_TUN_DEVICE; if (!WriteFile(adapter_.handle(), p->data, p->size, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { RERROR("TunWin32: WriteFile failed 0x%X", err); FreePacket(p); @@ -1744,7 +1762,7 @@ EXIT: if (!entries[0].lpOverlapped) continue; // This is the dummy entry from |PostQueuedCompletionStatus| Packet *p = (Packet*)((byte*)entries[0].lpOverlapped - offsetof(Packet, overlapped)); - if (p->post_target == ThreadedPacketQueue::TARGET_PROCESSOR_TUN) { + if (p->post_target == PacketProcessor::TARGET_PROCESSOR_TUN) { num_reads--; } else { num_writes--; @@ -1764,6 +1782,8 @@ DWORD WINAPI TunWin32Iocp::TunThread(void *x) { void TunWin32Iocp::StartThread() { DWORD thread_id; + assert(thread_ == NULL); + assert(completion_port_handle_ != NULL); thread_ = CreateThread(NULL, 0, &TunThread, this, 0, &thread_id); SetThreadPriority(thread_, ABOVE_NORMAL_PRIORITY_CLASS); } @@ -1774,6 +1794,7 @@ void TunWin32Iocp::StopThread() { WaitForSingleObject(thread_, INFINITE); CloseHandle(thread_); thread_ = NULL; + exit_thread_ = false; } void TunWin32Iocp::WriteTunPacket(Packet *packet) { @@ -1789,11 +1810,9 @@ void TunWin32Iocp::WriteTunPacket(Packet *packet) { } } - - ////////////////////////////////////////////////////////////////////////////// -TunWin32Overlapped::TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker), backend_(backend) { +TunWin32Overlapped::TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) { wqueue_end_ = &wqueue_; wqueue_ = NULL; @@ -1814,10 +1833,10 @@ TunWin32Overlapped::~TunWin32Overlapped() { CloseHandle(wake_event_); } -bool TunWin32Overlapped::Initialize(const TunConfig &&config, TunConfigOut *out) { +bool TunWin32Overlapped::Configure(const TunConfig &&config, TunConfigOut *out) { CloseTun(); - if (adapter_.OpenAdapter(&backend_->stop_mode_, FILE_FLAG_OVERLAPPED) && - adapter_.InitAdapter(std::move(config), out)) + if (adapter_.OpenAdapter(backend_, FILE_FLAG_OVERLAPPED) && + adapter_.ConfigureAdapter(std::move(config), out)) return true; CloseTun(); return false; @@ -1825,7 +1844,7 @@ bool TunWin32Overlapped::Initialize(const TunConfig &&config, TunConfigOut *out) void TunWin32Overlapped::CloseTun() { assert(thread_ == NULL); - adapter_.CloseAdapter(); + adapter_.CloseAdapter(false); FreePacketList(wqueue_); wqueue_ = NULL; wqueue_end_ = &wqueue_; @@ -1842,7 +1861,7 @@ void TunWin32Overlapped::ThreadMain() { Packet *p = AllocPacket(); memset(&p->overlapped, 0, sizeof(p->overlapped)); p->overlapped.hEvent = read_event_; - p->post_target = ThreadedPacketQueue::TARGET_PROCESSOR_TUN; + p->post_target = PacketProcessor::TARGET_PROCESSOR_TUN; if (!ReadFile(adapter_.handle(), p->data, kPacketCapacity, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { FreePacket(p); RERROR("TunWin32: ReadFile failed 0x%X", err); @@ -1889,7 +1908,7 @@ void TunWin32Overlapped::ThreadMain() { pending_writes = p->next; memset(&p->overlapped, 0, sizeof(p->overlapped)); p->overlapped.hEvent = write_event_; - p->post_target = ThreadedPacketQueue::TARGET_TUN_DEVICE; + p->post_target = PacketProcessor::TARGET_TUN_DEVICE; if (!WriteFile(adapter_.handle(), p->data, p->size, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { RERROR("TunWin32: WriteFile failed 0x%X", err); FreePacket(p); @@ -1944,39 +1963,37 @@ void TunsafeBackendWin32::SetPublicKey(const uint8 key[32]) { DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) { TunsafeBackendWin32 *backend = (TunsafeBackendWin32*)bk; int stop_mode; + int fast_retry_ctr = 0; for(;;) { TunWin32Iocp tun(&backend->dns_blocker_, backend); UdpSocketWin32 udp; WireguardProcessor wg_proc(&udp, &tun, backend); - ThreadedPacketQueue queues_for_processor(&wg_proc, backend); - qs.udp_qsize1 = qs.udp_qsize2 = 0; - udp.SetPacketHandler(&queues_for_processor); - tun.SetPacketHandler(&queues_for_processor); + udp.SetPacketHandler(&backend->packet_processor_); + tun.SetPacketHandler(&backend->packet_processor_); - wg_proc.dev().SetCurrentThreadAsMainThread(); - - if (!ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->dns_resolver_)) + if (backend->config_file_[0] && + !ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->dns_resolver_)) goto getout_fail; if (!wg_proc.Start()) goto getout_fail; - // only for use in callbacks from wg - backend->wg_processor_ = &wg_proc; - - queues_for_processor.Start(); - udp.StartThread(); - tun.StartThread(); - backend->SetPublicKey(wg_proc.dev().public_key()); - while ((stop_mode = InterlockedExchange(&backend->stop_mode_, MODE_NONE)) == MODE_NONE) { - SleepEx(INFINITE, TRUE); - } + backend->wg_processor_ = &wg_proc; + + udp.StartThread(); + tun.StartThread(); + stop_mode = backend->packet_processor_.Run(&wg_proc, backend); + udp.StopThread(); + tun.StopThread(); + + + backend->wg_processor_ = NULL; // Keep DNS alive if (stop_mode != MODE_EXIT) @@ -1984,39 +2001,31 @@ DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) { else backend->dns_resolver_.ClearCache(); - udp.StopThread(); - tun.StopThread(); - queues_for_processor.Stop(); - - backend->wg_processor_ = NULL; - FreeAllPackets(); if (stop_mode != MODE_TUN_FAILED) return 0; uint32 last_fail = GetTickCount(); - bool permanent_fail = (last_fail - backend->last_tun_adapter_failed_) < 5000; + fast_retry_ctr = (last_fail - backend->last_tun_adapter_failed_ < 5000) ? fast_retry_ctr + 1 : 0; backend->last_tun_adapter_failed_ = last_fail; - backend->status_ = permanent_fail ? TunsafeBackend::kErrorTunPermanent : TunsafeBackend::kStatusTunRetrying; - backend->delegate_->OnStatusCode(backend->status_); + backend->SetStatus((fast_retry_ctr >= 3) ? TunsafeBackend::kErrorTunPermanent : TunsafeBackend::kStatusTunRetrying); - if (permanent_fail) { + if (backend->status_ == TunsafeBackend::kErrorTunPermanent) { RERROR("Too many automatic restarts..."); - goto getout_fail; + goto getout_fail_noseterr; } + Sleep(1000); } getout_fail: - backend->dns_blocker_.RestoreDns(); backend->status_ = TunsafeBackend::kErrorInitialize; backend->delegate_->OnStatusCode(TunsafeBackend::kErrorInitialize); +getout_fail_noseterr: + backend->dns_blocker_.RestoreDns(); return 0; } -static void WINAPI ExitServiceAPC(ULONG_PTR a) { -} - TunsafeBackend::TunsafeBackend() { is_started_ = false; is_remote_ = false; @@ -2029,29 +2038,33 @@ TunsafeBackend::~TunsafeBackend() { } - TunsafeBackendWin32::TunsafeBackendWin32(Delegate *delegate) : delegate_(delegate), dns_resolver_(&dns_blocker_) { memset(&stats_, 0, sizeof(stats_)); wg_processor_ = NULL; InitPacketMutexes(); worker_thread_ = NULL; - stop_mode_ = MODE_NONE; last_tun_adapter_failed_ = 0; want_periodic_stats_ = false; - + guid_[0] = 0; if (g_hklm_reg_key == NULL) { RegCreateKeyEx(HKEY_LOCAL_MACHINE, "Software\\TunSafe", NULL, NULL, 0, KEY_ALL_ACCESS, NULL, &g_hklm_reg_key, NULL); g_killswitch_want = RegReadInt(g_hklm_reg_key, "KillSwitch", 0); } - delegate_->OnStateChanged(); } TunsafeBackendWin32::~TunsafeBackendWin32() { StopInner(false); + TunAdaptersInUse::GetInstance()->Release(this); } -bool TunsafeBackendWin32::Initialize() { + +void TunsafeBackendWin32::SetStatus(StatusCode status) { + status_ = status; + delegate_->OnStatusCode(status); +} + +bool TunsafeBackendWin32::Configure() { // it's always initialized return true; @@ -2061,6 +2074,17 @@ void TunsafeBackendWin32::Teardown() { } +bool TunsafeBackendWin32::SetTunAdapterName(const char *name) { + assert(worker_thread_ == NULL); + size_t len = strlen(name); + if (len >= sizeof(guid_) || guid_[0]) + return false; + if (!TunAdaptersInUse::GetInstance()->Acquire(name, this)) + return false; + memcpy(guid_, name, len + 1); + return true; +} + void TunsafeBackendWin32::RequestStats(bool enable) { want_periodic_stats_ = enable; @@ -2084,12 +2108,10 @@ void TunsafeBackendWin32::Stop() { void TunsafeBackendWin32::Start(const char *config_file) { StopInner(true); - stop_mode_ = MODE_NONE; // this needs to be here cause it's not reset on config file errors dns_resolver_.SetAbortFlag(false); is_started_ = true; memset(public_key_, 0, sizeof(public_key_)); - status_ = kStatusInitializing; - delegate_->OnStatusCode(kStatusInitializing); + SetStatus(kStatusInitializing); delegate_->OnClearLog(); DWORD thread_id; config_file_ = _strdup(config_file); @@ -2098,17 +2120,15 @@ void TunsafeBackendWin32::Start(const char *config_file) { delegate_->OnStateChanged(); } -void TunsafeBackendWin32::TunAdapterFailed() { - InterlockedExchange(&stop_mode_, MODE_TUN_FAILED); - QueueUserAPC(&ExitServiceAPC, worker_thread_, NULL); +void TunsafeBackendWin32::PostExit(int exit_code) { + packet_processor_.PostExit(exit_code); } void TunsafeBackendWin32::StopInner(bool is_restart) { if (worker_thread_) { ipv4_ip_ = 0; dns_resolver_.SetAbortFlag(true); - InterlockedExchange(&stop_mode_, is_restart ? MODE_RESTART : MODE_EXIT); - QueueUserAPC(&ExitServiceAPC, worker_thread_, NULL); + PostExit(is_restart ? MODE_RESTART : MODE_EXIT); WaitForSingleObject(worker_thread_, INFINITE); CloseHandle(worker_thread_); worker_thread_ = NULL; @@ -2116,6 +2136,7 @@ void TunsafeBackendWin32::StopInner(bool is_restart) { config_file_ = NULL; is_started_ = false; status_ = kStatusStopped; + packet_processor_.Reset(); if (!is_restart && !(g_killswitch_want & kBlockInternet_BlockOnDisconnect)) DeactivateKillSwitch(kBlockInternet_Off); @@ -2217,6 +2238,14 @@ std::string TunsafeBackendWin32::GetConfigFileName() { return std::string(); } +void TunsafeBackendWin32::SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) { + ConfigPacket *config_packet = new ConfigPacket; + config_packet->ident = identifier; + config_packet->message = std::move(message); + config_packet->packet.post_target = PacketProcessor::TARGET_CONFIG_PROTOCOL; + packet_processor_.ForcePost(&config_packet->packet); +} + void TunsafeBackendWin32::OnConnected() { if (status_ != TunsafeBackend::kStatusConnected) { ipv4_ip_ = ReadBE32(wg_processor_->tun_addr().addr); @@ -2224,19 +2253,15 @@ void TunsafeBackendWin32::OnConnected() { char buf[kSizeOfAddress]; RINFO("Connection established. IP %s", print_ip_prefix(buf, AF_INET, wg_processor_->tun_addr().addr, -1)); } - status_ = TunsafeBackend::kStatusConnected; - delegate_->OnStatusCode(TunsafeBackend::kStatusConnected); + SetStatus(TunsafeBackend::kStatusConnected); } } void TunsafeBackendWin32::OnConnectionRetry(uint32 attempts) { - if (status_ == TunsafeBackend::kStatusInitializing) { - status_ = TunsafeBackend::kStatusConnecting; - delegate_->OnStatusCode(TunsafeBackend::kStatusConnecting); - } else if (attempts >= 3 && status_ == TunsafeBackend::kStatusConnected) { - status_ = TunsafeBackend::kStatusReconnecting; - delegate_->OnStatusCode(TunsafeBackend::kStatusReconnecting); - } + if (status_ == TunsafeBackend::kStatusInitializing) + SetStatus(TunsafeBackend::kStatusConnecting); + else if (attempts >= 3 && status_ == TunsafeBackend::kStatusConnected) + SetStatus(TunsafeBackend::kStatusReconnecting); } void TunsafeBackend::Delegate::DoWork() { @@ -2255,7 +2280,10 @@ TunsafeBackendDelegateThreaded::~TunsafeBackendDelegateThreaded() { void TunsafeBackendDelegateThreaded::FreeEntry(Entry *e) { if (e->lparam) { - free((void*)e->lparam); + if (e->which == Id_OnConfigurationProtocolReply) + delete (std::string*)e->lparam; + else + free((void*)e->lparam); e->lparam = NULL; } } @@ -2273,6 +2301,7 @@ void TunsafeBackendDelegateThreaded::DoWork() { case Id_OnStatusCode: delegate->OnStatusCode((TunsafeBackend::StatusCode)it->wparam); break; case Id_OnClearLog: delegate->OnClearLog(); break; case Id_OnGraphAvailable: delegate->OnGraphAvailable(); break; + case Id_OnConfigurationProtocolReply: delegate->OnConfigurationProtocolReply(it->wparam, std::move(*(std::string*)it->lparam)); break; } FreeEntry(&*it); } @@ -2314,6 +2343,10 @@ void TunsafeBackendDelegateThreaded::OnClearLog() { AddEntry(Id_OnClearLog); } +void TunsafeBackendDelegateThreaded::OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) { + AddEntry(Id_OnConfigurationProtocolReply, (intptr_t)new std::string(std::move(reply)), ident); +} + TunsafeBackend::Delegate::~Delegate() { } @@ -2331,7 +2364,7 @@ void StatsCollector::Init() { Accumulator *acc = &accum_[0][0]; static const int kAccMax[TIMEVALS] = {5, 6, 10, 0}; - // Initialize all stats channels + // Configure all stats channels for (uint32 channel = 0; channel != CHANNELS; channel++) { for (uint32 timeval = 0; timeval != TIMEVALS; timeval++, acc++) { acc->acc = 0; @@ -2370,3 +2403,83 @@ void StatsCollector::AddSamples(float data[CHANNELS]) { AddToAccumulators(&accum_[i][0], data[i]); } + +TunAdaptersInUse::TunAdaptersInUse() { + num_inuse_ = 0; +} + +bool TunAdaptersInUse::Acquire(const char guid[ADAPTER_GUID_SIZE], void *context) { + size_t len = strlen(guid); + if (len >= ADAPTER_GUID_SIZE) + return false; + ScopedLock scoped_lock(&mutex_); + Entry *e = entry_; + for (uint32 n = num_inuse_; ; n--, e++) { + if (n == 0) { + if (num_inuse_ == kMaxAdaptersInUse) + return false; + num_inuse_++; + e->context = context; + e->count = 0; + memcpy(e->guid, guid, len + 1); + return true; + } + if (!strcmp(e->guid, guid)) { + if (e->context != context) + return false; + e->count++; + return true; + } + } +} + +void TunAdaptersInUse::Release(void *context) { + ScopedLock scoped_lock(&mutex_); + Entry *e = entry_; + for (uint32 n = num_inuse_; n; n--, e++) { + if (e->context == context) { + if (e->count-- == 0) + *e = entry_[num_inuse_-- - 1]; + break; + } + } +} + +void *TunAdaptersInUse::LookupContextFromGuid(const char guid[ADAPTER_GUID_SIZE]) { + ScopedLock scoped_lock(&mutex_); + Entry *e = entry_; + for (uint32 n = num_inuse_; n; n--, e++) { + if (!strcmp(e->guid, guid)) + return e->context; + } + return NULL; +} + +char *TunAdaptersInUse::GetAllGuid() { + ScopedLock scoped_lock(&mutex_); + char *rv = (char*)malloc(ADAPTER_GUID_SIZE * num_inuse_ + 1), *p = rv; + if (rv) { + Entry *e = entry_; + for (uint32 n = num_inuse_; n; n--, e++) { + size_t len = strlen(e->guid); + p[len] = '\n'; + memcpy(p, e->guid, len); + p += len + 1; + } + *p = 0; + } + return rv; +} + +static TunAdaptersInUse g_tun_adapters_in_use; +TunAdaptersInUse *TunAdaptersInUse::GetInstance() { + return &g_tun_adapters_in_use; +} + +TunsafeBackend *TunsafeBackend::FindBackendByTunGuid(const char *guid) { + return (TunsafeBackend*)TunAdaptersInUse::GetInstance()->LookupContextFromGuid(guid); +} + +char *TunsafeBackend::GetAllGuid() { + return TunAdaptersInUse::GetInstance()->GetAllGuid(); +} diff --git a/network_win32.h b/network_win32.h index 9e865e8..bcd8adf 100644 --- a/network_win32.h +++ b/network_win32.h @@ -11,34 +11,39 @@ #include "tunsafe_threading.h" #include +enum { + ADAPTER_GUID_SIZE = 40, +}; + struct Packet; class WireguardProcessor; class TunsafeBackendWin32; -class ThreadedPacketQueue { +class PacketProcessor { public: - explicit ThreadedPacketQueue(WireguardProcessor *wg, TunsafeBackendWin32 *backend); - ~ThreadedPacketQueue(); + explicit PacketProcessor(); + ~PacketProcessor(); enum { TARGET_PROCESSOR_UDP = 0, TARGET_PROCESSOR_TUN = 1, TARGET_UDP_DEVICE = 2, TARGET_TUN_DEVICE = 3, + TARGET_CONFIG_PROTOCOL = 4, }; - void Start(); - void Stop(); + void Reset(); + int Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend); void Post(Packet *packet, Packet **end, int count); - void AbortingDriver(); + void ForcePost(Packet *packet); + void PostExit(int exit_code); + + const uint32 *posted_exit_code() { return &exit_code_; } private: - void PostTimerInterrupt(); - static void CALLBACK TimerRoutine(LPVOID lpArgToCompletionRoutine, DWORD dwTimerLowValue, DWORD dwTimerHighValue); - - DWORD ThreadMain(); - static DWORD WINAPI ThreadedPacketQueueLauncher(VOID *x); + static void CALLBACK ThreadPoolTimerCallback(PTP_CALLBACK_INSTANCE iTimerInstance, PVOID pContext, PTP_TIMER); + void HandleConfigurationProtocolPacket(WireguardProcessor *wg, TunsafeBackendWin32 *backend, Packet *packet); Packet *first_; Packet **last_ptr_; uint32 packets_in_queue_; @@ -46,12 +51,8 @@ private: Mutex mutex_; HANDLE event_; - HANDLE timer_handle_; - HANDLE handle_; - WireguardProcessor *wg_; - bool exit_flag_; + uint32 exit_code_; bool timer_interrupt_; - TunsafeBackendWin32 *backend_; }; // Encapsulates a UDP socket, optionally listening for incoming packets @@ -61,17 +62,16 @@ public: explicit UdpSocketWin32(); ~UdpSocketWin32(); - void SetPacketHandler(ThreadedPacketQueue *packet_handler) { packet_handler_ = packet_handler; } + void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; } void StartThread(); void StopThread(); // -- from UdpInterface - virtual bool Initialize(int listen_on_port) override; + virtual bool Configure(int listen_on_port) override; virtual void WriteUdpPacket(Packet *packet) override; private: - void ThreadMain(); static DWORD WINAPI UdpThread(void *x); @@ -80,7 +80,7 @@ private: Mutex mutex_; - ThreadedPacketQueue *packet_handler_; + PacketProcessor *packet_handler_; SOCKET socket_; SOCKET socket_ipv6_; HANDLE completion_port_handle_; @@ -93,12 +93,12 @@ class DnsBlocker; class TunWin32Adapter { public: - TunWin32Adapter(DnsBlocker *dns_blocker); + TunWin32Adapter(DnsBlocker *dns_blocker, const char guid[ADAPTER_GUID_SIZE]); ~TunWin32Adapter(); - bool OpenAdapter(unsigned int *exit_thread, DWORD open_flags); - bool InitAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out); - void CloseAdapter(); + bool OpenAdapter(TunsafeBackendWin32 *backend, DWORD open_flags); + bool ConfigureAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out); + void CloseAdapter(bool is_restart); HANDLE handle() { return handle_; } @@ -121,8 +121,10 @@ private: NET_LUID interface_luid_; + void *backend_; + std::vector pre_down_, post_down_; - char guid_[64]; + char guid_[ADAPTER_GUID_SIZE]; }; // Implementation of TUN interface handling using IO Completion Ports @@ -131,23 +133,23 @@ public: explicit TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend); ~TunWin32Iocp(); - void SetPacketHandler(ThreadedPacketQueue *packet_handler) { packet_handler_ = packet_handler; } + void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; } void StartThread(); void StopThread(); // -- from TunInterface - virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) override; + virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override; virtual void WriteTunPacket(Packet *packet) override; TunWin32Adapter &adapter() { return adapter_; } private: - void CloseTun(); + void CloseTun(bool is_restart); void ThreadMain(); static DWORD WINAPI TunThread(void *x); - ThreadedPacketQueue *packet_handler_; + PacketProcessor *packet_handler_; HANDLE completion_port_handle_; HANDLE thread_; @@ -168,13 +170,13 @@ public: explicit TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend); ~TunWin32Overlapped(); - void SetPacketHandler(ThreadedPacketQueue *packet_handler) { packet_handler_ = packet_handler; } + void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; } void StartThread(); void StopThread(); // -- from TunInterface - virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) override; + virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override; virtual void WriteTunPacket(Packet *packet) override; private: @@ -182,7 +184,7 @@ private: void ThreadMain(); static DWORD WINAPI TunThread(void *x); - ThreadedPacketQueue *packet_handler_; + PacketProcessor *packet_handler_; HANDLE thread_; Mutex mutex_; @@ -199,16 +201,18 @@ private: }; class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate { - friend class ThreadedPacketQueue; + friend class PacketProcessor; friend class TunWin32Iocp; friend class TunWin32Overlapped; + friend class TunWin32Adapter; public: TunsafeBackendWin32(Delegate *delegate); ~TunsafeBackendWin32(); // -- from TunsafeBackend - virtual bool Initialize() override; + virtual bool Configure() override; virtual void Teardown() override; + virtual bool SetTunAdapterName(const char *name) override; virtual void Start(const char *config_file) override; virtual void Stop() override; virtual void RequestStats(bool enable) override; @@ -218,13 +222,23 @@ public: virtual void SetServiceStartupFlags(uint32 flags) override; virtual LinearizedGraph *GetGraph(int type) override; virtual std::string GetConfigFileName() override; + virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) override; // -- from ProcessorDelegate virtual void OnConnected() override; virtual void OnConnectionRetry(uint32 attempts) override; void SetPublicKey(const uint8 key[32]); - void TunAdapterFailed(); + void PostExit(int exit_code); + enum { + MODE_NONE = 0, + MODE_EXIT = 1, + MODE_RESTART = 2, + MODE_TUN_FAILED = 3, + }; + uint32 exit_code() { return *packet_processor_.posted_exit_code(); } + + void SetStatus(StatusCode status); private: void StopInner(bool is_restart); @@ -232,16 +246,7 @@ private: void PushStats(); HANDLE worker_thread_; - - enum { - MODE_NONE = 0, - MODE_EXIT = 1, - MODE_RESTART = 2, - MODE_TUN_FAILED = 3, - }; - bool want_periodic_stats_; - unsigned int stop_mode_; Delegate *delegate_; char *config_file_; @@ -256,6 +261,10 @@ private: Mutex stats_mutex_; WgProcessorStats stats_; + + PacketProcessor packet_processor_; + + char guid_[ADAPTER_GUID_SIZE]; }; // This class ensures that all callbacks get rescheduled to another thread @@ -265,13 +274,14 @@ public: ~TunsafeBackendDelegateThreaded(); private: - virtual void OnGetStats(const WgProcessorStats &stats); - virtual void OnGraphAvailable(); - virtual void OnStateChanged(); - virtual void OnClearLog(); - virtual void OnLogLine(const char **s); - virtual void OnStatusCode(TunsafeBackend::StatusCode status); - virtual void DoWork(); + virtual void OnGetStats(const WgProcessorStats &stats) override; + virtual void OnGraphAvailable() override; + virtual void OnStateChanged() override; + virtual void OnClearLog() override; + virtual void OnLogLine(const char **s) override; + virtual void OnStatusCode(TunsafeBackend::StatusCode status) override; + virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) override; + virtual void DoWork() override; enum Which { Id_OnGetStats, @@ -281,6 +291,7 @@ private: Id_OnUpdateUI, Id_OnStatusCode, Id_OnGraphAvailable, + Id_OnConfigurationProtocolReply, }; void AddEntry(Which which, intptr_t lparam = 0, uint32 wparam = 0); @@ -302,3 +313,37 @@ private: std::vector processing_entry_; }; +// For each adapter, remembers whether the adapter is in use +class TunAdaptersInUse { +public: + TunAdaptersInUse(); + + // attempt to acquire the adapter, so it can't be acquired by anyone else + bool Acquire(const char guid[ADAPTER_GUID_SIZE], void *context); + + // mark as free + void Release(void *context); + + // Lookup a context from a guid + void *LookupContextFromGuid(const char guid[ADAPTER_GUID_SIZE]); + + // Lookup a guid from a context + bool LookupGuidFromContext(void *context, char guid[ADAPTER_GUID_SIZE]); + + char *GetAllGuid(); + + static TunAdaptersInUse *GetInstance(); + +private: + enum { + kMaxAdaptersInUse = 16, + }; + struct Entry { + char guid[ADAPTER_GUID_SIZE]; + void *context; + int count; + }; + Mutex mutex_; + uint8 num_inuse_; + Entry entry_[kMaxAdaptersInUse]; +}; diff --git a/network_win32_api.h b/network_win32_api.h index bf5cf88..b280997 100644 --- a/network_win32_api.h +++ b/network_win32_api.h @@ -5,7 +5,6 @@ #include "stdafx.h" #include "tunsafe_types.h" #include "wireguard.h" - #include struct StatsCollector { @@ -72,6 +71,7 @@ public: virtual void OnClearLog() = 0; virtual void OnLogLine(const char **s) = 0; virtual void OnStatusCode(TunsafeBackend::StatusCode status) = 0; + virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) = 0; // This function is needed for CreateTunsafeBackendDelegateThreaded, // It's expected to be called on the main thread and then all callbacks will arrive // on the right thread. @@ -82,9 +82,16 @@ public: virtual ~TunsafeBackend(); // Setup/teardown the connection to the local service (if any) - virtual bool Initialize() = 0; + virtual bool Configure() = 0; virtual void Teardown() = 0; + // Set the name of the tun adapter that we want to use. + // On Windows this is the guid of the adapter. + // After having called this, this tun name cannot be used by any other instances. + // Returns false if the name can't be exclusively reserved to this adapter. + virtual bool SetTunAdapterName(const char *name) = 0; + + virtual void Start(const char *config_file) = 0; virtual void Stop() = 0; virtual void RequestStats(bool enable) = 0; @@ -93,10 +100,9 @@ public: virtual InternetBlockState GetInternetBlockState(bool *is_activated) = 0; virtual void SetInternetBlockState(InternetBlockState s) = 0; virtual void SetServiceStartupFlags(uint32 flags) = 0; - virtual std::string GetConfigFileName() = 0; - virtual LinearizedGraph *GetGraph(int type) = 0; + virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) = 0; bool is_started() { return is_started_; } bool is_remote() { return is_remote_; } @@ -105,6 +111,9 @@ public: StatusCode status() { return status_; } uint32 GetIP() { return ipv4_ip_; } + static TunsafeBackend *FindBackendByTunGuid(const char *guid); + static char *GetAllGuid(); + protected: bool is_started_; bool is_remote_; diff --git a/service_pipe_win32.cpp b/service_pipe_win32.cpp new file mode 100644 index 0000000..bb178c5 --- /dev/null +++ b/service_pipe_win32.cpp @@ -0,0 +1,374 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. +#include "stdafx.h" +#include "service_pipe_win32.h" +#include "util.h" +#include "service_win32_constants.h" + +/////////////////////////////////////////////////////////////////////////////////////// +// PipeManager +/////////////////////////////////////////////////////////////////////////////////////// + +PipeManager::PipeManager(const char *pipe_name, bool is_server_pipe, Delegate *delegate) { + pipe_name_ = _strdup(pipe_name); + is_server_pipe_ = is_server_pipe; + for (size_t i = 0; i < kMaxConnections * 2 + 1; i++) + events_[i] = CreateEvent(NULL, i != 0, FALSE, NULL); // For Exit + delegate_ = delegate; + thread_ = NULL; + exit_thread_ = false; + thread_id_ = 0; + for (size_t i = 0; i != kMaxConnections; i++) + connections_[i].Configure(this, (int)i); + connections_[0].state_ = PipeConnection::kStateStarting; +} + +PipeManager::~PipeManager() { + StopThread(); + for (size_t i = 0; i < kMaxConnections * 2 + 1; i++) + CloseHandle(events_[i]); + free(pipe_name_); +} + +bool PipeManager::StartThread() { + assert(thread_ == NULL); + thread_ = CreateThread(NULL, 0, &StaticThreadMain, this, 0, &thread_id_); + return thread_ != NULL; +} + +void PipeManager::StopThread() { + if (thread_ != NULL) { + exit_thread_ = true; + SetEvent(events_[0]); + WaitForSingleObject(thread_, INFINITE); + CloseHandle(thread_); + thread_ = NULL; + } +} + +bool PipeManager::VerifyThread() { + return thread_id_ == GetCurrentThreadId(); +} + +void PipeManager::TryStartNewListener() { + assert(VerifyThread()); + assert(is_server_pipe_); + // Check if any thread is in the listener state, if not, start + PipeConnection *found_conn = NULL; + for (size_t i = 0; i < kMaxConnections; i++) { + PipeConnection *conn = &connections_[i]; + if (conn->connection_established_) + continue; + if (conn->state_ == PipeConnection::kStateWaitConnect) + return; + if (conn->state_ == PipeConnection::kStateNone && found_conn == NULL) + found_conn = conn; + } + if (found_conn) { + found_conn->state_ = PipeConnection::kStateStarting; + found_conn->AdvanceStateMachine(); + } +} + +DWORD WINAPI PipeManager::StaticThreadMain(void *x) { + return ((PipeManager*)x)->ThreadMain(); +} + +DWORD PipeManager::ThreadMain() { + assert(VerifyThread()); + + for (size_t i = 0; i < kMaxConnections; i++) + connections_[i].AdvanceStateMachine(); + + for (;;) { + DWORD rv = WaitForMultipleObjects(1 + kMaxConnections * 2, events_, FALSE, INFINITE); + + // notify? + if (rv == WAIT_OBJECT_0) { + if (exit_thread_) + break; + + delegate_->HandleNotify(); + // The notification event is set when there might be new messages to send, + // so try to send them. + for (size_t i = 0; i != kMaxConnections; i++) + connections_[i].TrySendNextQueuedWrite(); + } else if (rv >= WAIT_OBJECT_0 + 1 && rv < WAIT_OBJECT_0 + 1 + kMaxConnections * 2) { + PipeConnection *conn = &connections_[(rv - 1) >> 1]; + if (rv & 1) { + // read finished + conn->AdvanceStateMachine(); + } else { + // is the write event + conn->HandleWriteComplete(); + } + } else { + assert(0); + } + } + return 0; +} + +/////////////////////////////////////////////////////////////////////////////////////// +// PipeConnection +/////////////////////////////////////////////////////////////////////////////////////// + +static void ClearPipeOverlapped(OVERLAPPED *ov) { + ov->Internal = 0; + ov->InternalHigh = 0; + ov->Offset = 0; + ov->OffsetHigh = 0; +} + +PipeConnection::PipeConnection() { + pipe_ = INVALID_HANDLE_VALUE; + packets_ = NULL; + packets_end_ = &packets_; + write_overlapped_active_ = false; + connection_established_ = false; + state_ = kStateNone; + tmp_packet_buf_ = NULL; + tmp_packet_size_ = 0; + manager_ = NULL; + delegate_ = NULL; +} + +PipeConnection::~PipeConnection() { +} + +void PipeConnection::Configure(PipeManager *manager, int slot) { + manager_ = manager; + read_overlapped_.hEvent = manager->events_[1 + slot * 2]; + write_overlapped_.hEvent = manager->events_[1 + slot * 2 + 1]; +} + +int PipeConnection::InitializeServerPipeAndConnect() { + int BUFSIZE = 8192; + SECURITY_ATTRIBUTES saPipeSecurity = {0}; + uint8 buf[SECURITY_DESCRIPTOR_MIN_LENGTH]; + PSECURITY_DESCRIPTOR pPipeSD = (PSECURITY_DESCRIPTOR)buf; + + if (!InitializeSecurityDescriptor(pPipeSD, SECURITY_DESCRIPTOR_REVISION)) + return -1; + + // set NULL DACL on the SD + if (!SetSecurityDescriptorDacl(pPipeSD, TRUE, (PACL)NULL, FALSE)) + return -1; + + // now set up the security attributes + saPipeSecurity.nLength = sizeof(SECURITY_ATTRIBUTES); + saPipeSecurity.bInheritHandle = TRUE; + saPipeSecurity.lpSecurityDescriptor = pPipeSD; + + pipe_ = CreateNamedPipe(manager_->pipe_name_, + PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED, + PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_REJECT_REMOTE_CLIENTS | PIPE_WAIT, + PIPE_UNLIMITED_INSTANCES, + BUFSIZE, BUFSIZE, 0, &saPipeSecurity); + if (pipe_ == INVALID_HANDLE_VALUE) + return -1; + + ClearPipeOverlapped(&read_overlapped_); + // It seems like ConnectNamedPipe never sets the event object if it completes + // right away. + if (!ConnectNamedPipe(pipe_, &read_overlapped_)) { + DWORD rv = GetLastError(); + return (rv == ERROR_IO_PENDING) ? 0 : (rv == ERROR_PIPE_CONNECTED) ? 1 : -1; + } else { + return 1; + } +} + +bool PipeConnection::InitializeClientPipe() { + assert(pipe_ == INVALID_HANDLE_VALUE); + pipe_ = CreateFile(manager_->pipe_name_, GENERIC_READ | GENERIC_WRITE, 0, NULL, + OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL); + return (pipe_ != INVALID_HANDLE_VALUE); +} + +void PipeConnection::ClosePipe() { + if (pipe_ != INVALID_HANDLE_VALUE) { + CancelIo(pipe_); + CloseHandle(pipe_); + pipe_ = INVALID_HANDLE_VALUE; + } + connection_established_ = false; + write_overlapped_active_ = false; + state_ = kStateNone; + + free(tmp_packet_buf_); + tmp_packet_buf_ = NULL; + tmp_packet_size_ = 0; + + ResetEvent(read_overlapped_.hEvent); + ResetEvent(write_overlapped_.hEvent); + + packets_mutex_.Acquire(); + OutgoingPacket *packets = packets_; + packets_ = NULL; + packets_end_ = &packets_; + packets_mutex_.Release(); + while (packets) { + OutgoingPacket *p = packets; + packets = p->next; + free(p); + } +} + +void PipeConnection::HandleWriteComplete() { + assert(write_overlapped_active_); + + write_overlapped_active_ = false; + + // Remove the packet from the front of the queue, now that it was sent. + packets_mutex_.Acquire(); + OutgoingPacket *p = packets_; + if ((packets_ = p->next) == NULL) + packets_end_ = &packets_; + packets_mutex_.Release(); + free(p); + + if (packets_ == NULL && state_ == kStateWaitTimeout) + AdvanceStateMachine(); + else + TrySendNextQueuedWrite(); +} + +bool PipeConnection::WritePacket(int type, const uint8 *data, size_t data_size) { + OutgoingPacket *packet = (OutgoingPacket *)malloc(offsetof(OutgoingPacket, data[data_size + 1])); + if (packet) { + packet->size = (uint32)(data_size + 1); + packet->data[0] = type; + memcpy(packet->data + 1, data, data_size); + packet->next = NULL; + + packets_mutex_.Acquire(); + OutgoingPacket *was_empty = packets_; + // login messages are always queued up front + if (type == TS_SERVICE_REQ_LOGIN) { + packet->next = packets_; + if (packet->next == NULL) + packets_end_ = &packet->next; + packets_ = packet; + } else { + *packets_end_ = packet; + packets_end_ = &packet->next; + } + packets_mutex_.Release(); + + if (was_empty == NULL) { + // Only allow the pipe thread to invoke the send + if (GetCurrentThreadId() == manager_->thread_id_) { + TrySendNextQueuedWrite(); + } else { + SetEvent(manager_->notify_handle()); + } + } + } + return true; +} + +bool PipeConnection::VerifyThread() { + return manager_->VerifyThread(); +} + +void PipeConnection::TrySendNextQueuedWrite() { + assert(manager_->VerifyThread()); + if (!write_overlapped_active_) { + OutgoingPacket *p = packets_; + if (p && connection_established_) { + ClearPipeOverlapped(&write_overlapped_); + if (WriteFile(pipe_, &p->size, p->size + 4, NULL, &write_overlapped_) || GetLastError() == ERROR_IO_PENDING) + write_overlapped_active_ = true; + } else { + ResetEvent(write_overlapped_.hEvent); + } + } +} + +#define TS_WAIT_BEGIN(t) switch(state_) { case t: +#define TS_WAIT_POINT(t) state_ = (t); return; case t: +#define TS_WAIT_END() } + +void PipeConnection::AdvanceStateMachine() { + DWORD rv; + int srv; + + TS_WAIT_BEGIN(kStateStarting) + // Create a named pipe and wait for connections from the UI process + if (manager_->is_server_pipe_) { + srv = InitializeServerPipeAndConnect(); + if (srv < 0) { + if (!manager_->exit_thread_) + ExitProcess(1); + ClosePipe(); + return; + } + if (srv == 0) { + TS_WAIT_POINT(kStateWaitConnect); + } + } else { + if (!InitializeClientPipe()) { + RINFO("Unable to connect to the TunSafe Service. Please make sure it's running."); + ClosePipe(); + return; + } + } + connection_established_ = true; + delegate_ = manager_->delegate_->HandleNewConnection(this); + TrySendNextQueuedWrite(); + + for (;;) { + // Read the packet length + read_pos_ = 0; + do { + ClearPipeOverlapped(&read_overlapped_); + if (!ReadFile(pipe_, (uint8*)&packet_size_ + read_pos_, 4 - read_pos_, NULL, &read_overlapped_)) { + if ((rv = GetLastError()) != ERROR_IO_PENDING) + goto fail; + TS_WAIT_POINT(kStateWaitReadLength); + } + if ((uint32)read_overlapped_.InternalHigh == 0) + goto fail; + read_pos_ += (uint32)read_overlapped_.InternalHigh; + } while (read_pos_ != 4); + assert(packet_size_ != 0 && packet_size_ < 0x1000000); + if (packet_size_ == 0 || packet_size_ >= 0x1000000) + break; + free(tmp_packet_buf_); + tmp_packet_buf_ = (uint8*)malloc(packet_size_); + if (!tmp_packet_buf_) + break; + // Read the packet payload + read_pos_ = 0; + do { + ClearPipeOverlapped(&read_overlapped_); + if (!ReadFile(pipe_, tmp_packet_buf_ + read_pos_, packet_size_ - read_pos_, NULL, &read_overlapped_)) { + if ((rv = GetLastError()) != ERROR_IO_PENDING) + goto fail; + TS_WAIT_POINT(kStateWaitReadPayload); + } + if ((uint32)read_overlapped_.InternalHigh == 0) + goto fail; + read_pos_ += (uint32)read_overlapped_.InternalHigh; + } while (read_pos_ != packet_size_); + if (!delegate_->HandleMessage(tmp_packet_buf_[0], tmp_packet_buf_ + 1, packet_size_ - 1)) { + ResetEvent(read_overlapped_.hEvent); + if (packets_ != NULL) { + TS_WAIT_POINT(kStateWaitTimeout); + } + break; + } + } +fail: + ClosePipe(); + if (!manager_->exit_thread_) { + delegate_->HandleDisconnect(); + if (manager_->is_server_pipe_) + manager_->TryStartNewListener(); + } + TS_WAIT_END() + + +} + diff --git a/service_pipe_win32.h b/service_pipe_win32.h new file mode 100644 index 0000000..e350755 --- /dev/null +++ b/service_pipe_win32.h @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. +#pragma once +#include "tunsafe_threading.h" +#include "network_win32_api.h" + +class PipeManager; + +// Once a pipe connects, this object is used to facilitate the connection +class PipeConnection { + friend class PipeManager; +public: + class Delegate { + public: + virtual bool HandleMessage(int type, uint8 *data, size_t size) = 0; + virtual void HandleDisconnect() = 0; + }; + + PipeConnection(); + ~PipeConnection(); + + void Configure(PipeManager *manager, int slot); + bool WritePacket(int type, const uint8 *data, size_t data_size); + HANDLE pipe_handle() { return pipe_; } + bool is_connected() { return connection_established_; } + bool VerifyThread(); + +private: + // -1 = fail, 0 = wait, 1 = conn + int InitializeServerPipeAndConnect(); + bool InitializeClientPipe(); + void AdvanceStateMachine(); + void ClosePipe(); + void TrySendNextQueuedWrite(); + void HandleWriteComplete(); + + Delegate *delegate_; + PipeManager *manager_; + + HANDLE pipe_; + bool write_overlapped_active_; + bool connection_established_; + + enum State { + kStateNone, + kStateStarting, + kStateWaitConnect, + kStateWaitReadLength, + kStateWaitReadPayload, + kStateWaitTimeout, + }; + + uint8 state_; + + uint32 packet_size_; + + struct OutgoingPacket { + OutgoingPacket *next; + uint32 size; + uint8 data[0]; + }; + OutgoingPacket *packets_, **packets_end_; + uint8 *tmp_packet_buf_; + DWORD tmp_packet_size_; + DWORD read_pos_; + OVERLAPPED write_overlapped_, read_overlapped_; + Mutex packets_mutex_; +}; + +// This class supports multiple PipeConnections and calls HandleNewConnection +// when a new pipe connection is established. +class PipeManager { + friend class PipeConnection; +public: + class Delegate { + public: + // Called when a new connection is established + virtual PipeConnection::Delegate *HandleNewConnection(PipeConnection *handler) = 0; + + // Called when a notification event was pushed + virtual void HandleNotify() = 0; + }; + + PipeManager(const char *pipe_name, bool is_server_pipe, Delegate *delegate); + ~PipeManager(); + + bool StartThread(); + void StopThread(); + bool VerifyThread(); + + HANDLE notify_handle() { return events_[0]; } + PipeConnection *GetClientConnection() { return &connections_[0]; } + + void TryStartNewListener(); + +private: + DWORD ThreadMain(); + static DWORD WINAPI StaticThreadMain(void *x); + + Delegate *delegate_; + HANDLE thread_; + char *pipe_name_; + DWORD thread_id_; + bool is_server_pipe_; + bool exit_thread_; + + enum { kMaxConnections = 2 }; + HANDLE events_[1 + kMaxConnections * 2]; + PipeConnection connections_[kMaxConnections]; +}; + diff --git a/service_win32.cpp b/service_win32.cpp index 48866ff..b4e3c1f 100644 --- a/service_win32.cpp +++ b/service_win32.cpp @@ -9,40 +9,19 @@ #include #include #include "util_win32.h" - -static const uint64 kTunsafeServiceProtocolVersion = 20180809001; +#include "service_win32_constants.h" static SERVICE_STATUS_HANDLE m_statusHandle; -static TunsafeServiceImpl *g_service; +static TunsafeServiceManager *g_service; + +#define SERVICE_DEBUGGING 0 #define SERVICE_NAME L"TunSafeService" #define SERVICE_NAMEA "TunSafeService" #define SERVICE_START_TYPE SERVICE_AUTO_START #define SERVICE_DEPENDENCIES L"tap0901\0dhcp\0" #define SERVICE_ACCOUNT NULL -//L"NT AUTHORITY\\LocalService" #define SERVICE_PASSWORD NULL -#define PIPE_NAME "\\\\.\\pipe\\TunSafe\\ServiceControl" - - -enum { - SERVICE_REQ_LOGIN = 0, - SERVICE_REQ_START = 1, - SERVICE_REQ_STOP = 2, - SERVICE_REQ_GETSTATS = 4, - SERVICE_REQ_SET_INTERNET_BLOCKSTATE = 5, - SERVICE_REQ_RESETSTATS = 6, - SERVICE_REQ_SET_STARTUP_FLAGS = 7, - - SERVICE_MSG_STATE = 8, - SERVICE_MSG_LOGLINE = 9, - SERVICE_MSG_STATS = 11, - SERVICE_MSG_CLEARLOG = 12, - SERVICE_MSG_STATUS_CODE = 14, - - SERVICE_REQ_GET_GRAPH = 15, - SERVICE_MSG_GRAPH = 16, -}; struct ServiceHandles { SC_HANDLE manager; @@ -61,7 +40,6 @@ struct ServiceHandles { bool StartService(); }; - static DWORD InstallService(PWSTR pszServiceName, PWSTR pszDisplayName, DWORD dwStartType, @@ -191,6 +169,18 @@ getout: return result; } +static wchar_t *RegReadStrW(HKEY hkey, const wchar_t *key, const wchar_t *def) { + wchar_t buf[1024]; + DWORD n = sizeof(buf) - 2; + DWORD type = 0; + if (RegQueryValueExW(hkey, key, NULL, &type, (BYTE*)buf, &n) != ERROR_SUCCESS || type != REG_SZ) + return def ? _wcsdup(def) : NULL; + n >>= 1; + if (n && buf[n - 1] == 0) + n--; + buf[n] = 0; + return _wcsdup(buf); +} static DWORD GetNonTransientServiceStatus(SC_HANDLE service) { SERVICE_STATUS ssSvcStatus = {}; @@ -208,7 +198,6 @@ static DWORD GetNonTransientServiceStatus(SC_HANDLE service) { } } - bool ServiceHandles::StartService() { DWORD state = GetNonTransientServiceStatus(service); if (state == 0 || state == SERVICE_RUNNING) @@ -221,7 +210,6 @@ bool ServiceHandles::StartService() { return GetNonTransientServiceStatus(service) == SERVICE_RUNNING; } - static bool StartTunsafeService() { ServiceHandles handles; @@ -232,14 +220,14 @@ static bool StartTunsafeService() { bool IsTunsafeServiceRunning() { ServiceHandles handles; - +#if SERVICE_DEBUGGING + return true; +#endif if (!handles.Open(SERVICE_NAME, SC_MANAGER_CONNECT, SERVICE_QUERY_STATUS)) return false; - return GetNonTransientServiceStatus(handles.service) == SERVICE_RUNNING; } - void StopTunsafeService() { ServiceHandles handles; if (!handles.Open(SERVICE_NAME, SC_MANAGER_CONNECT, @@ -296,7 +284,6 @@ bool IsTunSafeServiceInstalled() { return handles.Open(SERVICE_NAME, SC_MANAGER_CONNECT, SERVICE_QUERY_STATUS); } - static void WriteServiceLog(const char *pszFunction, WORD dwError) { char szMessage[260]; snprintf(szMessage, ARRAYSIZE(szMessage), "%s failed w/err 0x%08lx", pszFunction, dwError); @@ -321,8 +308,7 @@ static void WriteServiceLog(const char *pszFunction, WORD dwError) { } } -static void SetServiceStatus(DWORD dwCurrentState, - DWORD dwWin32ExitCode = 0, +static void SetServiceStatus(DWORD dwCurrentState, DWORD dwWin32ExitCode = 0, DWORD dwWaitHint = 0) { static DWORD dwCheckPoint = 1; @@ -333,10 +319,8 @@ static void SetServiceStatus(DWORD dwCurrentState, m_status.dwCurrentState = dwCurrentState; m_status.dwWin32ExitCode = dwWin32ExitCode; m_status.dwWaitHint = dwWaitHint; - m_status.dwCheckPoint = - ((dwCurrentState == SERVICE_RUNNING) || - (dwCurrentState == SERVICE_STOPPED)) ? - 0 : dwCheckPoint++; + m_status.dwCheckPoint = ((dwCurrentState == SERVICE_RUNNING) || + (dwCurrentState == SERVICE_STOPPED)) ? 0 : dwCheckPoint++; // Report the status of the service to the SCM. ::SetServiceStatus(m_statusHandle, &m_status); } @@ -389,495 +373,147 @@ static const SERVICE_TABLE_ENTRYW serviceTable[] = { {NULL, NULL} }; -PipeMessageHandler::PipeMessageHandler(const char *pipe_name, bool is_server_pipe, Delegate *delegate) { - pipe_name_ = _strdup(pipe_name); - is_server_pipe_ = is_server_pipe; - delegate_ = delegate; - pipe_ = INVALID_HANDLE_VALUE; - wait_handles_[0] = CreateEvent(NULL, TRUE, FALSE, NULL); // for ReadFile - wait_handles_[1] = CreateEvent(NULL, FALSE, FALSE, NULL); // For Exit - wait_handles_[2] = CreateEvent(NULL, TRUE, FALSE, NULL); // for WriteFile - packets_ = NULL; - thread_ = NULL; - packets_end_ = &packets_; - write_overlapped_active_ = false; - exit_thread_ = false; - connection_established_ = false; - thread_id_ = 0; - state_ = kStateNone; - tmp_packet_buf_ = NULL; -} +/////////////////////////////////////////////////////////////////////////////////////// +// TunsafeServiceManager +/////////////////////////////////////////////////////////////////////////////////////// -PipeMessageHandler::~PipeMessageHandler() { - StopThread(); - CloseHandle(wait_handles_[0]); - CloseHandle(wait_handles_[1]); - CloseHandle(wait_handles_[2]); - free(pipe_name_); -} - -bool PipeMessageHandler::InitializeServerPipeAndWait() { - int BUFSIZE = 2048; - SECURITY_ATTRIBUTES saPipeSecurity = {0}; - uint8 buf[SECURITY_DESCRIPTOR_MIN_LENGTH]; - PSECURITY_DESCRIPTOR pPipeSD = (PSECURITY_DESCRIPTOR)buf; - - if (!InitializeSecurityDescriptor(pPipeSD, SECURITY_DESCRIPTOR_REVISION)) - return false; - - // set NULL DACL on the SD - if (!SetSecurityDescriptorDacl(pPipeSD, TRUE, (PACL)NULL, FALSE)) - return false; - - // now set up the security attributes - saPipeSecurity.nLength = sizeof(SECURITY_ATTRIBUTES); - saPipeSecurity.bInheritHandle = TRUE; - saPipeSecurity.lpSecurityDescriptor = pPipeSD; - - pipe_ = CreateNamedPipeW(L"\\\\.\\pipe\\TunSafe\\ServiceControl", - PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED, - PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS | PIPE_WAIT, - PIPE_UNLIMITED_INSTANCES, - BUFSIZE, BUFSIZE, 0, &saPipeSecurity); - if (pipe_ == INVALID_HANDLE_VALUE) - return false; +TunsafeServiceManager::TunsafeServiceManager() + : pipe_manager_(TUNSAFE_PIPE_NAME, true, this) { + server_unique_id_ = 0; - memset(&read_overlapped_, 0, sizeof(read_overlapped_)); - read_overlapped_.hEvent = wait_handles_[0]; - if (!ConnectNamedPipe(pipe_, &read_overlapped_)) { - DWORD rv = GetLastError(); - if (rv != ERROR_PIPE_CONNECTED && rv != ERROR_IO_PENDING) - return false; - } - return true; -} - -bool PipeMessageHandler::InitializeClientPipe() { - assert(pipe_ == INVALID_HANDLE_VALUE); - pipe_ = CreateFile(pipe_name_, GENERIC_READ | GENERIC_WRITE, 0, NULL, - OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL); - if (pipe_ == INVALID_HANDLE_VALUE) - return false; - DWORD mode = PIPE_READMODE_MESSAGE; - SetNamedPipeHandleState(pipe_, &mode, NULL, NULL); - return true; -} - -void PipeMessageHandler::ClosePipe() { - if (pipe_ != INVALID_HANDLE_VALUE) { - CancelIo(pipe_); - CloseHandle(pipe_); - pipe_ = INVALID_HANDLE_VALUE; - } - connection_established_ = false; - write_overlapped_active_ = false; - - free(tmp_packet_buf_); - tmp_packet_buf_ = NULL; - - ResetEvent(wait_handles_[0]); - ResetEvent(wait_handles_[2]); - - packets_mutex_.Acquire(); - OutgoingPacket *packets = packets_; - packets_ = NULL; - packets_end_ = &packets_; - packets_mutex_.Release(); - while (packets) { - OutgoingPacket *p = packets; - packets = p->next; - free(p); - } -} - -bool PipeMessageHandler::WritePacket(int type, const uint8 *data, size_t data_size) { - OutgoingPacket *packet = (OutgoingPacket *)malloc(offsetof(OutgoingPacket, data[data_size + 1])); - if (packet) { - packet->size = (uint32)(data_size + 1); - packet->data[0] = type; - memcpy(packet->data + 1, data, data_size); - packet->next = NULL; - - packets_mutex_.Acquire(); - OutgoingPacket *was_empty = packets_; - // login messages are always queued up front - if (type == SERVICE_REQ_LOGIN) { - packet->next = packets_; - if (packet->next == NULL) - packets_end_ = &packet->next; - packets_ = packet; - } else { - *packets_end_ = packet; - packets_end_ = &packet->next; - } - packets_mutex_.Release(); - - if (was_empty == NULL) { - // Only allow the pipe thread to invoke the send - if (GetCurrentThreadId() == thread_id_) { - SendNextQueuedWrite(); - } else { - SetEvent(wait_handles_[1]); - } - } - } - return true; -} - -void PipeMessageHandler::SendNextQueuedWrite() { - assert(thread_id_ == GetCurrentThreadId()); - if (!write_overlapped_active_) { - OutgoingPacket *p = packets_; - if (p && connection_established_) { - memset(&write_overlapped_, 0, sizeof(write_overlapped_)); - write_overlapped_.hEvent = wait_handles_[2]; - if (WriteFile(pipe_, p->data, p->size, NULL, &write_overlapped_) || GetLastError() == ERROR_IO_PENDING) - write_overlapped_active_ = true; - } else { - ResetEvent(wait_handles_[2]); - } - } -} - -#define TS_WAIT_BEGIN(t) switch(state_) { case t: -#define TS_WAIT_POINT(t) state_ = (t); return; case t: -#define TS_WAIT_END() } - -void PipeMessageHandler::AdvanceStateMachine() { - DWORD rv, bytes_read; - - TS_WAIT_BEGIN(kStateNone) - for(;;) { - // Create a named pipe and wait for connections from the UI process - if (is_server_pipe_) { - if (!InitializeServerPipeAndWait()) { - if (!exit_thread_) - ExitProcess(1); - break; - } - TS_WAIT_POINT(kStateWaitConnect); - } else { - if (!InitializeClientPipe()) { - RINFO("Unable to connect to the TunSafe Service. Please make sure it's running."); - break; - } - } - connection_established_ = true; - delegate_->HandleNewConnection(); - SendNextQueuedWrite(); - - for (;;) { - memset(&read_overlapped_, 0, sizeof(read_overlapped_)); - read_overlapped_.hEvent = wait_handles_[0]; - if (!ReadFile(pipe_, NULL, 0, NULL, &read_overlapped_)) { - rv = GetLastError(); - if (rv != ERROR_IO_PENDING && rv != ERROR_MORE_DATA) - break; - } - TS_WAIT_POINT(kStateWaitReadLength); - PeekNamedPipe(pipe_, NULL, 0, NULL, &tmp_packet_size_, NULL); - if (tmp_packet_size_ == 0) - break; - - free(tmp_packet_buf_); - tmp_packet_buf_ = (uint8*)malloc(tmp_packet_size_); - if (!tmp_packet_buf_) - break; - - memset(&read_overlapped_, 0, sizeof(read_overlapped_)); - read_overlapped_.hEvent = wait_handles_[0]; - if (!ReadFile(pipe_, tmp_packet_buf_, tmp_packet_size_, NULL, &read_overlapped_)) { - rv = GetLastError(); - if (rv != ERROR_IO_PENDING) - break; - } - TS_WAIT_POINT(kStateWaitReadPayload); - bytes_read = (uint32)read_overlapped_.InternalHigh; - if (bytes_read == 0) - break; - if (!delegate_->HandleMessage(tmp_packet_buf_[0], tmp_packet_buf_ + 1, bytes_read - 1)) { - ResetEvent(wait_handles_[0]); - TS_WAIT_POINT(kStateWaitTimeout); - break; - } - } - if (exit_thread_) - break; - delegate_->HandleDisconnect(); - if (!is_server_pipe_) - break; - ClosePipe(); - } - TS_WAIT_END() - ClosePipe(); -} - -DWORD WINAPI PipeMessageHandler::StaticThreadMain(void *x) { - return ((PipeMessageHandler*)x)->ThreadMain(); -} - -bool PipeMessageHandler::VerifyThread() { - return thread_id_ == GetCurrentThreadId(); -} - -DWORD PipeMessageHandler::ThreadMain() { - assert((thread_id_ = GetCurrentThreadId()) != 0); - assert(state_ == kStateNone); - - AdvanceStateMachine(); - - for(;;) { - DWORD rv = WaitForMultipleObjects(3, wait_handles_, FALSE, (state_ == kStateWaitTimeout) ? 1000 : INFINITE); - - // packet write finished? - if (rv == WAIT_OBJECT_0 + 2) { - assert(write_overlapped_active_); - - write_overlapped_active_ = false; - - // Remove the packet from the front of the queue, now that it was sent. - packets_mutex_.Acquire(); - OutgoingPacket *p = packets_; - if ((packets_ = p->next) == NULL) - packets_end_ = &packets_; - packets_mutex_.Release(); - free(p); - SendNextQueuedWrite(); - - // notification - } else if (rv == WAIT_OBJECT_0 + 1) { - if (exit_thread_ || !delegate_->HandleNotify()) - break; - // The notification event is set when there might be new messages to send, - // so try to send them. - SendNextQueuedWrite(); - - // read finished? - } else if (rv == WAIT_OBJECT_0) { - AdvanceStateMachine(); - } else if (rv == WAIT_TIMEOUT) { - if (state_ == kStateWaitTimeout) - AdvanceStateMachine(); - } else { - assert(0); - } - } - return 0; -} - -bool PipeMessageHandler::StartThread() { - DWORD thread_id; - assert(thread_ == NULL); - thread_ = CreateThread(NULL, 0, &StaticThreadMain, this, 0, &thread_id); - return thread_ != NULL; -} - -void PipeMessageHandler::StopThread() { - if (thread_ != NULL) { - exit_thread_ = true; - SetEvent(wait_handles_[1]); - WaitForSingleObject(thread_, INFINITE); - CloseHandle(thread_); - thread_ = NULL; - } - ClosePipe(); -} - -TunsafeServiceImpl::TunsafeServiceImpl() - : message_handler_(PIPE_NAME, true, this) { - thread_delegate_ = CreateTunsafeBackendDelegateThreaded(this, [=] { - SetEvent(message_handler_.notify_handle()); - }); - - backend_ = CreateNativeTunsafeBackend(thread_delegate_); - historical_log_lines_count_ = historical_log_lines_pos_ = 0; - last_line_sent_ = 0; - did_send_getstate_ = false; - memset(historical_log_lines_, 0, sizeof(historical_log_lines_)); hkey_ = NULL; - want_graph_type_ = 0xffffffff; RegCreateKeyEx(HKEY_LOCAL_MACHINE, "Software\\TunSafe", NULL, NULL, 0, KEY_ALL_ACCESS, NULL, &hkey_, NULL); + + main_backend_ = new TunsafeServiceBackend(this); + backends_.push_back(main_backend_); } -TunsafeServiceImpl::~TunsafeServiceImpl() { +TunsafeServiceManager::~TunsafeServiceManager() { + for (TunsafeServiceBackend *backend : backends_) + delete backend; RegCloseKey(hkey_); } -static wchar_t *RegReadStrW(HKEY hkey, const wchar_t *key, const wchar_t *def) { - wchar_t buf[1024]; - DWORD n = sizeof(buf) - 2; - DWORD type = 0; - if (RegQueryValueExW(hkey, key, NULL, &type, (BYTE*)buf, &n) != ERROR_SUCCESS || type != REG_SZ) - return def ? _wcsdup(def) : NULL; - n >>= 1; - if (n && buf[n - 1] == 0) - n--; - buf[n] = 0; - return _wcsdup(buf); +void TunsafeServiceManager::HandleNotify() { + for (TunsafeServiceBackend *backend : backends_) + backend->HandleNotify(); } -unsigned TunsafeServiceImpl::OnStart(int argc, wchar_t **argv) { +PipeConnection::Delegate *TunsafeServiceManager::HandleNewConnection(PipeConnection *connection) { + TunsafeServiceServer *server = new TunsafeServiceServer(connection, main_backend_, server_unique_id_++); + main_backend_->AddPipeServer(server); + pipe_manager_.TryStartNewListener(); + return server; +} + +unsigned TunsafeServiceManager::OnStart(int argc, wchar_t **argv) { uint32 service_flags = RegReadInt(hkey_, "ServiceStartupFlags", 0); - if ( (service_flags & kStartupFlag_BackgroundService) && (service_flags & kStartupFlag_ConnectWhenWindowsStarts) ) { + if ((service_flags & kStartupFlag_BackgroundService) && (service_flags & kStartupFlag_ConnectWhenWindowsStarts)) { char *conf = RegReadStr(hkey_, "LastUsedConfigFile", ""); - if (conf && *conf) { - current_filename_ = (char*)conf; - backend_->Start((char*)conf); - } + if (conf && *conf) + main_backend_->Start(conf); free(conf); } - - message_handler_.StartThread(); + pipe_manager_.StartThread(); return 0; } -bool TunsafeServiceImpl::AuthenticateUser() { - did_authenticate_user_ = true; - - if (!ImpersonateNamedPipeClient(message_handler_.pipe_handle())) - return false; - wchar_t *user = GetUsernameOfCurrentUser(true); - RevertToSelf(); - if (!user) - return false; - wchar_t *valid_user = RegReadStrW(hkey_, L"AllowedUsername", L""); - bool rv = valid_user && wcscmp(user, valid_user) == 0; - - free(user); - free(valid_user); - return rv; +void TunsafeServiceManager::OnStop() { + pipe_manager_.StopThread(); + for (TunsafeServiceBackend *backend : backends_) + backend->Stop(); } -bool TunsafeServiceImpl::HandleMessage(int type, uint8 *data, size_t size) { - if (!did_authenticate_user_) { - if (type != SERVICE_REQ_LOGIN || size < 8 || *(uint64*)data != kTunsafeServiceProtocolVersion) { - const char *s = "Versioning Problem: The TunSafe service is a different version than the UI."; - message_handler_.WritePacket(SERVICE_MSG_LOGLINE, (uint8*)s, strlen(s)); - return false; - } - if (!AuthenticateUser()) { - const char *s = "Permission Problem: Your Windows account is different from the account\r\nthat installed the TunSafe Service. Please reinstall it.\r\n"; - message_handler_.WritePacket(SERVICE_MSG_LOGLINE, (uint8*)s, strlen(s)); - return false; +void TunsafeServiceManager::OnShutdown() { + +} + +TunsafeServiceBackend *TunsafeServiceManager::CreateBackend(const char *guid) { + TunsafeServiceBackend *service_backend = new TunsafeServiceBackend(this); + + // If we're unable to assign the name, maybe it's already in use + if (!service_backend->backend()->SetTunAdapterName(guid)) { + delete service_backend; + return NULL; + } + backends_.push_back(service_backend); + return service_backend; +} + +void TunsafeServiceManager::DestroyBackend(TunsafeServiceBackend *service_backend) { + assert(service_backend != main_backend_); + + // Erase from the list + auto it = std::find(backends_.begin(), backends_.end(), service_backend); + if (it != backends_.end()) + backends_.erase(it); + + delete service_backend; +} + +bool TunsafeServiceManager::SwitchInterface(TunsafeServiceServer *server, const char *interfac, bool want_create) { + // Find a backend by name + TunsafeBackend *backend = TunsafeBackend::FindBackendByTunGuid(interfac); + TunsafeServiceBackend *service_backend = NULL; + if (backend) { + for (TunsafeServiceBackend *sb : backends_) { + if (sb->backend() == backend) { + service_backend = sb; + break; + } } } - - switch (type) { - case SERVICE_REQ_START: - if (data[size - 1] != 0) + if (!service_backend) { + if (!want_create) return false; - - // Don't allow reading arbitrary files on disk - if (!EnsureValidConfigPath((char*)data)) { - char buf[MAX_PATH]; - GetConfigPath(buf, sizeof(buf)); - char *s = str_cat_alloc("Permission Problem: The Config file is in an unsafe location.\r\n Must be in:", buf, "\r\n"); - message_handler_.WritePacket(SERVICE_MSG_LOGLINE, (uint8*)s, strlen(s)); - free(s); + service_backend = CreateBackend(interfac); + if (!service_backend) return false; - } - - g_allow_pre_post = RegReadInt(hkey_, "AllowPrePost", 0) != 0; - - current_filename_ = (char*)data; - backend_->Start((char*)data); - RegWriteStr(hkey_, "LastUsedConfigFile", (char*)data); - - break; - - case SERVICE_REQ_STOP: - backend_->Stop(); - RegWriteStr(hkey_, "LastUsedConfigFile", ""); - OnStateChanged(); - break; - - case SERVICE_REQ_LOGIN: - did_send_getstate_ = true; - OnStatusCode(backend_->status()); - OnStateChanged(); - SendQueuedLogLines(); - break; - - case SERVICE_REQ_GETSTATS: - if (size < 1) return false; - backend_->RequestStats(data[0] != 0); - break; - - case SERVICE_REQ_SET_INTERNET_BLOCKSTATE: - if (size < 1) - return false; - backend_->SetInternetBlockState((InternetBlockState)data[0]); - OnStateChanged(); - break; - - case SERVICE_REQ_RESETSTATS: - backend_->ResetStats(); - break; - - case SERVICE_REQ_GET_GRAPH: - if (size < 4) return false; - want_graph_type_ = *(int*)data; - TunsafeServiceImpl::OnGraphAvailable(); - break; - - case SERVICE_REQ_SET_STARTUP_FLAGS: - if (size < 4) - return false; - RegSetValueEx(hkey_, "ServiceStartupFlags", NULL, REG_DWORD, (BYTE*)data, 4); - break; - - default: - return false; + } + if (server->service_backend() != service_backend) { + server->service_backend()->RemovePipeServer(server); + service_backend->AddPipeServer(server); + server->set_service_backend(service_backend); } return true; } -bool TunsafeServiceImpl::HandleNotify() { - thread_delegate_->DoWork(); - return true; + +/////////////////////////////////////////////////////////////////////////////////////// +// TunsafeServiceBackend +/////////////////////////////////////////////////////////////////////////////////////// + +TunsafeServiceBackend::TunsafeServiceBackend(TunsafeServiceManager *manager) { + manager_ = manager; + historical_log_lines_count_ = historical_log_lines_pos_ = 0; + memset(historical_log_lines_, 0, sizeof(historical_log_lines_)); + HANDLE event = manager_->pipe_manager_.notify_handle(); + thread_delegate_ = CreateTunsafeBackendDelegateThreaded(this, [=] { SetEvent(event); }); + backend_ = CreateNativeTunsafeBackend(thread_delegate_); } -void TunsafeServiceImpl::HandleNewConnection() { - did_send_getstate_ = false; - did_authenticate_user_ = false; - last_line_sent_ = 0; +TunsafeServiceBackend::~TunsafeServiceBackend() { + assert(pipe_servers_.empty()); + delete backend_; + delete thread_delegate_; } -void TunsafeServiceImpl::HandleDisconnect() { - want_graph_type_ = 0xffffffff; - backend_->RequestStats(false); - uint32 service_flags = RegReadInt(hkey_, "ServiceStartupFlags", 0); - if (!(service_flags & kStartupFlag_BackgroundService)) - backend_->Stop(); +void TunsafeServiceBackend::OnGetStats(const WgProcessorStats &stats) { + for (TunsafeServiceServer *pipe_server : pipe_servers_) + if (pipe_server->want_stats()) + pipe_server->WritePacket(TS_SERVICE_MSG_STATS, (uint8*)&stats, sizeof(stats)); } -void TunsafeServiceImpl::OnGraphAvailable() { - if (want_graph_type_ != 0xffffffff) { - LinearizedGraph *graph = backend_->GetGraph(want_graph_type_); - if (graph) - message_handler_.WritePacket(SERVICE_MSG_GRAPH, (uint8*)graph, graph->total_size); - } -} - -void TunsafeServiceImpl::SendQueuedLogLines() { - assert(message_handler_.VerifyThread()); - uint32 maxi = std::min(historical_log_lines_count_, historical_log_lines_pos_ - last_line_sent_); - last_line_sent_ = historical_log_lines_pos_; - for (uint32 i = 0; i < maxi; i++) { - const char *s = historical_log_lines_[(historical_log_lines_pos_ - maxi + i) & (LOGLINE_COUNT - 1)]; - if (s) - message_handler_.WritePacket(SERVICE_MSG_LOGLINE, (uint8*)s, strlen(s)); - } -} - -void TunsafeServiceImpl::OnClearLog() { +void TunsafeServiceBackend::OnClearLog() { historical_log_lines_pos_ = 0; historical_log_lines_count_ = 0; - message_handler_.WritePacket(SERVICE_MSG_CLEARLOG, NULL, 0); + for (TunsafeServiceServer *pipe_server : pipe_servers_) + if (pipe_server->want_state_updates()) + pipe_server->WritePacket(TS_SERVICE_MSG_CLEARLOG, NULL, 0); } -void TunsafeServiceImpl::OnLogLine(const char **s) { - assert(message_handler_.VerifyThread()); +void TunsafeServiceBackend::OnLogLine(const char **s) { + assert(manager_->pipe_manager_.VerifyThread()); char *ss = (char*)*s; *s = NULL; char *&x = historical_log_lines_[historical_log_lines_pos_++ & (LOGLINE_COUNT - 1)]; @@ -885,46 +521,321 @@ void TunsafeServiceImpl::OnLogLine(const char **s) { if (historical_log_lines_count_ < LOGLINE_COUNT) historical_log_lines_count_++; free(ss); - if (did_send_getstate_) - SendQueuedLogLines(); + + for (TunsafeServiceServer *pipe_server : pipe_servers_) + pipe_server->SendQueuedLogLines(); } -void TunsafeServiceImpl::OnGetStats(const WgProcessorStats &stats) { - message_handler_.WritePacket(SERVICE_MSG_STATS, (uint8*)&stats, sizeof(stats)); +void TunsafeServiceBackend::OnStateChanged() { + SendStateUpdate(NULL); // Send to all } -void TunsafeServiceImpl::OnStateChanged() { +void TunsafeServiceBackend::OnStatusCode(TunsafeBackend::StatusCode status) { + if (status == TunsafeBackend::kStatusConnected) + OnStateChanged(); // ensure we know the ip first + uint32 v32 = (uint32)status; + for (TunsafeServiceServer *pipe_server : pipe_servers_) + if (pipe_server->want_state_updates()) + pipe_server->WritePacket(TS_SERVICE_MSG_STATUS_CODE, (uint8*)&v32, 4); +} + +void TunsafeServiceBackend::OnGraphAvailable() { + for (TunsafeServiceServer *pipe_server : pipe_servers_) + pipe_server->OnGraphAvailable(); +} + +void TunsafeServiceBackend::OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) { + for (TunsafeServiceServer *pipe_server : pipe_servers_) + if (pipe_server->unique_id() == ident) + pipe_server->WritePacket(TS_SERVICE_REQ_TEXT_PROTOCOL_REPLY, (uint8*)reply.data(), reply.size()); +} + +void TunsafeServiceBackend::Start(const char *filename) { + g_allow_pre_post = RegReadInt(manager_->hkey_, "AllowPrePost", 0) != 0; + current_filename_ = filename; + backend_->Start(filename); +} + +void TunsafeServiceBackend::RememberLastUsedConfigFile(const char *filename) { + if (manager_->main_backend() == this) + RegWriteStr(manager_->hkey_, "LastUsedConfigFile", filename); +} + +void TunsafeServiceBackend::Stop() { + if (manager_->main_backend() == this) + RegWriteStr(manager_->hkey_, "LastUsedConfigFile", ""); + + backend_->Stop(); + OnStateChanged(); +} + +void TunsafeServiceBackend::UpdateRequestStats() { + bool want = false; + for (auto it = pipe_servers_.begin(); it != pipe_servers_.end(); ++it) { + if ((*it)->want_stats()) { + want = true; + break; + } + } + backend_->RequestStats(want); +} + +void TunsafeServiceBackend::HandleNotify() { + thread_delegate_->DoWork(); +} + +void TunsafeServiceBackend::SendStateUpdate(TunsafeServiceServer *filter) { + if (pipe_servers_.empty()) + return; + uint8 *temp = new uint8[current_filename_.size() + 1 + sizeof(ServiceState)]; bool is_activated; memset(temp, 0, sizeof(ServiceState)); - ServiceState *ss = (ServiceState *)temp; ss->is_started = backend_->is_started(); ss->internet_block_state = backend_->GetInternetBlockState(&is_activated); ss->internet_block_state_active = is_activated; ss->ipv4_ip = backend_->GetIP(); memcpy(ss->public_key, backend_->public_key(), 32); - memcpy(temp + sizeof(ServiceState), current_filename_.c_str(), current_filename_.size() + 1); - message_handler_.WritePacket(SERVICE_MSG_STATE, temp, current_filename_.size() + 1 + sizeof(ServiceState)); + for (TunsafeServiceServer *pipe_server : pipe_servers_) { + if (filter != NULL && pipe_server != filter) + continue; + if (pipe_server->want_state_updates()) + pipe_server->WritePacket(TS_SERVICE_MSG_STATE, temp, current_filename_.size() + 1 + sizeof(ServiceState)); + } + delete[] temp; } -void TunsafeServiceImpl::OnStatusCode(TunsafeBackend::StatusCode status) { - if (status == TunsafeBackend::kStatusConnected) - OnStateChanged(); // ensure we know the ip first - uint32 v32 = (uint32)status; - message_handler_.WritePacket(SERVICE_MSG_STATUS_CODE, (uint8*)&v32, 4); +void TunsafeServiceBackend::RemovePipeServer(TunsafeServiceServer *pipe_server) { + auto it = std::find(pipe_servers_.begin(), pipe_servers_.end(), pipe_server); + if (it != pipe_servers_.end()) + pipe_servers_.erase(it); + + UpdateRequestStats(); + + // Stop the main backend, or destroy a disconnetced backend, when the last client disconnects. + if (pipe_servers_.empty()) { + if (this == manager_->main_backend_) { + uint32 service_flags = RegReadInt(manager_->hkey_, "ServiceStartupFlags", 0); + if (!(service_flags & kStartupFlag_BackgroundService)) + backend_->Stop(); + } else { + if (!backend_->is_started()) + manager_->DestroyBackend(this); + } + } } -void TunsafeServiceImpl::OnStop() { - message_handler_.StopThread(); - backend_->Stop(); +void TunsafeServiceBackend::AddPipeServer(TunsafeServiceServer *pipe_server) { + pipe_servers_.push_back(pipe_server); } -void TunsafeServiceImpl::OnShutdown() { +/////////////////////////////////////////////////////////////////////////////////////// +// TunsafeServiceServer +/////////////////////////////////////////////////////////////////////////////////////// +TunsafeServiceServer::TunsafeServiceServer(PipeConnection *pipe, TunsafeServiceBackend *backend, uint32 unique_id) { + unique_id_ = unique_id; + connection_ = pipe; + service_backend_ = backend; + last_line_sent_ = 0; + want_state_updates_ = false; + did_authenticate_user_ = false; + want_stats_ = false; + want_graph_type_ = 0xffffffff; +} + +TunsafeServiceServer::~TunsafeServiceServer() { +} + +void TunsafeServiceServer::WritePacket(int type, const uint8 *data, size_t data_size) { + connection_->WritePacket(type, data, data_size); +} + +struct ServiceLoginMessage { + uint64 version; + char interfac[kTsMaxDevnameSize]; + bool want_state_updates; + bool want_create_interface; +}; + +bool TunsafeServiceServer::HandleMessage(int type, uint8 *data, size_t size) { + if (!did_authenticate_user_) { + if (type != TS_SERVICE_REQ_LOGIN || + size < sizeof(ServiceLoginMessage) || + ((ServiceLoginMessage*)data)->version != TUNSAFE_SERVICE_PROTOCOL_VERSION) { + const char *s = "Versioning Problem: The TunSafe service is a different version than the UI."; + connection_->WritePacket(TS_SERVICE_MSG_ERROR_REPLY, (uint8*)s, strlen(s)); + return false; + } + if (!AuthenticateUser()) { + const char *s = "Permission Problem: Your Windows account is different from the account\r\nthat installed the TunSafe Service. Please reinstall it."; + connection_->WritePacket(TS_SERVICE_MSG_ERROR_REPLY, (uint8*)s, strlen(s)); + return false; + } + } + + switch (type) { + case TS_SERVICE_REQ_START: { + if (size == 0 || data[size - 1] != 0) + return false; + + for (size_t i = 0; i < size; i++) { + if (data[i] == '/') + data[i] = '\\'; + } + + char buf[MAX_PATH]; + buf[0] = 0; + + if (data[0]) { + if (!ExpandConfigPath((char*)data, buf, sizeof(buf)) || GetFileAttributesA(buf) == INVALID_FILE_ATTRIBUTES) { + char *s = str_cat_alloc("File '", (char*)data, "' not found"); + connection_->WritePacket(TS_SERVICE_MSG_ERROR_REPLY, (uint8*)s, strlen(s)); + free(s); + return false; + } + // Don't allow reading arbitrary files on disk + if (!EnsureValidConfigPath(buf)) { + GetConfigPath(buf, sizeof(buf)); + char *s = str_cat_alloc("Permission Problem: The Config file is in an unsafe location.\r\n Must be in: ", buf, ""); + connection_->WritePacket(TS_SERVICE_MSG_ERROR_REPLY, (uint8*)s, strlen(s)); + free(s); + return false; + } + } + service_backend_->Start(buf); + service_backend_->RememberLastUsedConfigFile(buf); + + // Ensure we reply with something + if (!want_state_updates_) { + uint32 v32 = (uint32)service_backend_->backend_->status(); + connection_->WritePacket(TS_SERVICE_MSG_STATUS_CODE, (uint8*)&v32, 4); + } + break; + } + + case TS_SERVICE_REQ_STOP: + service_backend_->Stop(); + if (!want_state_updates_) { + uint32 v32 = (uint32)service_backend_->backend_->status(); + connection_->WritePacket(TS_SERVICE_MSG_STATUS_CODE, (uint8*)&v32, 4); + } + break; + + case TS_SERVICE_REQ_LOGIN: { + if (((ServiceLoginMessage*)data)->interfac[kTsMaxDevnameSize - 1]) + return false; // sanity check + + if (((ServiceLoginMessage*)data)->interfac[0] != 0) { + if (!service_backend_->manager_->SwitchInterface(this, ((ServiceLoginMessage*)data)->interfac, ((ServiceLoginMessage*)data)->want_create_interface)) { + const char *s = ((ServiceLoginMessage*)data)->want_create_interface ? "Unable to add the interface" : "Interface is not started"; + connection_->WritePacket(TS_SERVICE_MSG_ERROR_REPLY, (uint8*)s, strlen(s)); + return false; + } + } + want_state_updates_ = ((ServiceLoginMessage*)data)->want_state_updates; + if (want_state_updates_) { + SendQueuedLogLines(); + service_backend_->SendStateUpdate(this); + uint32 v32 = (uint32)service_backend_->backend_->status(); + connection_->WritePacket(TS_SERVICE_MSG_STATUS_CODE, (uint8*)&v32, 4); + } + + break; + } + + // return a list of all running interfaces + case TS_SERVICE_REQ_GETINTERFACES: { + char *s = TunsafeBackend::GetAllGuid(); + connection_->WritePacket(TS_SERVICE_REQ_GETINTERFACES_REPLY, (uint8*)s, s ? strlen(s) : 0); + free(s); + break; + } + + case TS_SERVICE_REQ_GETSTATS: + if (size < 1) return false; + want_stats_ = (data[0] != 0); + service_backend_->UpdateRequestStats(); + break; + + case TS_SERVICE_REQ_SET_INTERNET_BLOCKSTATE: + if (size < 1) + return false; + service_backend_->backend_->SetInternetBlockState((InternetBlockState)data[0]); + service_backend_->OnStateChanged(); + break; + + case TS_SERVICE_REQ_RESETSTATS: + service_backend_->backend_->ResetStats(); + break; + + case TS_SERVICE_REQ_GET_GRAPH: + if (size < 4) return false; + want_graph_type_ = *(int*)data; + TunsafeServiceServer::OnGraphAvailable(); + break; + + case TS_SERVICE_REQ_SET_STARTUP_FLAGS: + if (size < 4) + return false; + RegSetValueEx(service_backend_->manager_->hkey_, "ServiceStartupFlags", NULL, REG_DWORD, (BYTE*)data, 4); + break; + + case TS_SERVICE_REQ_TEXT_PROTOCOL: + if (!service_backend_->backend_->is_started()) + service_backend_->Start(""); + service_backend_->backend_->SendConfigurationProtocolPacket(unique_id_, std::string((char*)data, size)); + break; + + default: + return false; + } + return true; +} + +void TunsafeServiceServer::HandleDisconnect() { + service_backend_->RemovePipeServer(this); + delete this; +} + +void TunsafeServiceServer::OnGraphAvailable() { + if (want_graph_type_ != 0xffffffff) { + LinearizedGraph *graph = service_backend_->backend_->GetGraph(want_graph_type_); + if (graph) + connection_->WritePacket(TS_SERVICE_MSG_GRAPH, (uint8*)graph, graph->total_size); + } +} + +void TunsafeServiceServer::SendQueuedLogLines() { + if (!want_state_updates_) + return; + assert(connection_->VerifyThread()); + uint32 maxi = std::min(service_backend_->historical_log_lines_count_, service_backend_->historical_log_lines_pos_ - last_line_sent_); + last_line_sent_ = service_backend_->historical_log_lines_pos_; + for (uint32 i = 0; i < maxi; i++) { + const char *s = service_backend_->historical_log_lines_[(service_backend_->historical_log_lines_pos_ - maxi + i) & (TunsafeServiceBackend::LOGLINE_COUNT - 1)]; + if (s) + connection_->WritePacket(TS_SERVICE_MSG_LOGLINE, (uint8*)s, strlen(s)); + } +} + +bool TunsafeServiceServer::AuthenticateUser() { + if (!ImpersonateNamedPipeClient(connection_->pipe_handle())) + return false; + wchar_t *user = GetUsernameOfCurrentUser(true); + RevertToSelf(); + if (!user) + return false; + wchar_t *valid_user = RegReadStrW(service_backend_->manager_->hkey_, L"AllowedUsername", L""); + bool rv = valid_user && wcscmp(user, valid_user) == 0; + did_authenticate_user_ = rv; + free(user); + free(valid_user); + return rv; } static void PushServiceLine(const char *s) { @@ -937,13 +848,11 @@ static void PushServiceLine(const char *s) { snprintf(buf, sizeof(buf), "[%.2d:%.2d:%.2d] ", t.wHour, t.wMinute, t.wSecond); size_t tl = strlen(buf); - char *x = (char*) malloc(tl + l + 3); + char *x = (char*) malloc(tl + l + 1); memcpy(x, buf, tl); memcpy(x + tl, s, l); - x[l + tl] = '\r'; - x[l + tl + 1] = '\n'; - x[l + tl + 2] = '\0'; - g_service->delegate()->OnLogLine((const char**)&x); + x[l + tl] = '\0'; + g_service->main_backend()->delegate()->OnLogLine((const char**)&x); free(x); } else { size_t l = strlen(s); @@ -965,14 +874,14 @@ static void PushServiceLine(const char *s) { } BOOL RunProcessAsTunsafeServiceProcess() { - g_service = new TunsafeServiceImpl; + g_service = new TunsafeServiceManager; g_logger = &PushServiceLine; - - //g_service->OnStart(NULL, 0); - //MessageBoxA(0, "Service running", "Service running", 0); - //return TRUE; -// while (true)Sleep(1000); +#if SERVICE_DEBUGGING + g_service->OnStart(NULL, 0); + while (true) + Sleep(1000); +#endif // Connects the main thread of a service process to the service control // manager, which causes the thread to be the service control dispatcher @@ -980,42 +889,47 @@ BOOL RunProcessAsTunsafeServiceProcess() { // stopped. The process should simply terminate when the call returns. return StartServiceCtrlDispatcherW(serviceTable); } -TunsafeServiceClient::TunsafeServiceClient(TunsafeBackend::Delegate *delegate) - : message_handler_(PIPE_NAME, false, this) { + + +/////////////////////////////////////////////////////////////////////////////////////// +// TunsafeServiceClient +/////////////////////////////////////////////////////////////////////////////////////// + +TunsafeServiceClient::TunsafeServiceClient(TunsafeBackend::Delegate *delegate) + : pipe_manager_(TUNSAFE_PIPE_NAME, false, this) { is_remote_ = true; got_state_from_control_ = false; delegate_ = delegate; cached_graph_ = 0; last_graph_type_ = 0xffffffff; memset(&service_state_, 0, sizeof(service_state_)); + connection_ = pipe_manager_.GetClientConnection(); } TunsafeServiceClient::~TunsafeServiceClient() { - message_handler_.StopThread(); + pipe_manager_.StopThread(); } -bool TunsafeServiceClient::Initialize() { - // Wait for the service to start - last_graph_type_ = 0xffffffff; - return message_handler_.StartThread(); +bool TunsafeServiceClient::Configure() { + return pipe_manager_.StartThread(); } void TunsafeServiceClient::Start(const char *config_file) { - message_handler_.WritePacket(SERVICE_REQ_START, (uint8*)config_file, strlen(config_file) + 1); + connection_->WritePacket(TS_SERVICE_REQ_START, (uint8*)config_file, strlen(config_file) + 1); } void TunsafeServiceClient::Stop() { - message_handler_.WritePacket(SERVICE_REQ_STOP, NULL, 0); + connection_->WritePacket(TS_SERVICE_REQ_STOP, NULL, 0); } void TunsafeServiceClient::RequestStats(bool enable) { want_stats_ = enable; - if (message_handler_.is_connected()) - message_handler_.WritePacket(SERVICE_REQ_GETSTATS, &want_stats_, 1); + if (connection_->is_connected()) + connection_->WritePacket(TS_SERVICE_REQ_GETSTATS, &want_stats_, 1); } void TunsafeServiceClient::ResetStats() { - message_handler_.WritePacket(SERVICE_REQ_RESETSTATS, NULL, 0); + connection_->WritePacket(TS_SERVICE_REQ_RESETSTATS, NULL, 0); } InternetBlockState TunsafeServiceClient::GetInternetBlockState(bool *is_activated) { @@ -1026,17 +940,17 @@ InternetBlockState TunsafeServiceClient::GetInternetBlockState(bool *is_activate void TunsafeServiceClient::SetInternetBlockState(InternetBlockState s) { uint8 v = (uint8)s; - message_handler_.WritePacket(SERVICE_REQ_SET_INTERNET_BLOCKSTATE, &v, 1); + connection_->WritePacket(TS_SERVICE_REQ_SET_INTERNET_BLOCKSTATE, &v, 1); } void TunsafeServiceClient::SetServiceStartupFlags(uint32 flags) { - message_handler_.WritePacket(SERVICE_REQ_SET_STARTUP_FLAGS, (uint8*)&flags, 4); + connection_->WritePacket(TS_SERVICE_REQ_SET_STARTUP_FLAGS, (uint8*)&flags, 4); } LinearizedGraph *TunsafeServiceClient::GetGraph(int type) { if (type != last_graph_type_) { last_graph_type_ = type; - message_handler_.WritePacket(SERVICE_REQ_GET_GRAPH, (uint8*)&type, 4); + connection_->WritePacket(TS_SERVICE_REQ_GET_GRAPH, (uint8*)&type, 4); } mutex_.Acquire(); LinearizedGraph *graph = cached_graph_; @@ -1045,6 +959,8 @@ LinearizedGraph *TunsafeServiceClient::GetGraph(int type) { return new_graph; } +void TunsafeServiceClient::SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) { +} std::string TunsafeServiceClient::GetConfigFileName() { mutex_.Acquire(); @@ -1054,8 +970,8 @@ std::string TunsafeServiceClient::GetConfigFileName() { } bool TunsafeServiceClient::HandleMessage(int type, uint8 *data, size_t data_size) { - switch(type) { - case SERVICE_MSG_STATE: + switch (type) { + case TS_SERVICE_MSG_STATE: if (data_size <= sizeof(service_state_) || data[data_size - 1]) return false; got_state_from_control_ = true; @@ -1069,15 +985,18 @@ bool TunsafeServiceClient::HandleMessage(int type, uint8 *data, size_t data_size mutex_.Release(); delegate_->OnStateChanged(); return true; - case SERVICE_MSG_LOGLINE: { + case TS_SERVICE_MSG_LOGLINE: + case TS_SERVICE_MSG_ERROR_REPLY: { if (data_size == 0) return false; char *s = my_strndup((char*)data, data_size); - delegate_->OnLogLine((const char **)&s); - free(s); + if (s) { + delegate_->OnLogLine((const char **)&s); + free(s); + } return true; } - case SERVICE_MSG_STATS: { + case TS_SERVICE_MSG_STATS: { WgProcessorStats stats; if (data_size != sizeof(WgProcessorStats)) return false; @@ -1085,25 +1004,24 @@ bool TunsafeServiceClient::HandleMessage(int type, uint8 *data, size_t data_size delegate_->OnGetStats(stats); return true; } - case SERVICE_MSG_CLEARLOG: + case TS_SERVICE_MSG_CLEARLOG: delegate_->OnClearLog(); return true; - case SERVICE_MSG_STATUS_CODE: + case TS_SERVICE_MSG_STATUS_CODE: if (data_size < 4) return false; status_ = (StatusCode)*(uint32*)data; delegate_->OnStatusCode(status_); return true; - case SERVICE_MSG_GRAPH: - if (data_size < 4 || data_size != *(uint32*)data) + case TS_SERVICE_MSG_GRAPH: + if (data_size < sizeof(LinearizedGraph) || data_size != *(uint32*)data) return false; - LinearizedGraph *graph = (LinearizedGraph*)memdup(data, data_size); mutex_.Acquire(); std::swap(graph, cached_graph_); - mutex_.Release(); + mutex_.Release(); free(graph); delegate_->OnGraphAvailable(); return true; @@ -1112,15 +1030,18 @@ bool TunsafeServiceClient::HandleMessage(int type, uint8 *data, size_t data_size return false; } -bool TunsafeServiceClient::HandleNotify() { - return true; +void TunsafeServiceClient::HandleNotify() { } - -void TunsafeServiceClient::HandleNewConnection() { - message_handler_.WritePacket(SERVICE_REQ_LOGIN, (uint8*)&kTunsafeServiceProtocolVersion, 8); +PipeConnection::Delegate *TunsafeServiceClient::HandleNewConnection(PipeConnection *connection) { + assert(connection == connection_); + ServiceLoginMessage msg = {0}; + msg.want_state_updates = true; + msg.version = TUNSAFE_SERVICE_PROTOCOL_VERSION; + connection_->WritePacket(TS_SERVICE_REQ_LOGIN, (uint8*)&msg, sizeof(msg)); if (want_stats_) - message_handler_.WritePacket(SERVICE_REQ_GETSTATS, &want_stats_, 1); + connection_->WritePacket(TS_SERVICE_REQ_GETSTATS, &want_stats_, 1); + return this; } void TunsafeServiceClient::HandleDisconnect() { @@ -1129,16 +1050,19 @@ void TunsafeServiceClient::HandleDisconnect() { } void TunsafeServiceClient::Teardown() { - message_handler_.StopThread(); + pipe_manager_.StopThread(); +} + +bool TunsafeServiceClient::SetTunAdapterName(const char *name) { + // override which tun adapter we want to start + return false; } TunsafeBackend *CreateTunsafeServiceClient(TunsafeBackend::Delegate *delegate) { TunsafeServiceClient *client = new TunsafeServiceClient(delegate); - if (client && !client->Initialize()) { + if (client && !client->Configure()) { delete client; client = NULL; } return client; } - - diff --git a/service_win32.h b/service_win32.h index e19766e..6af6e2d 100644 --- a/service_win32.h +++ b/service_win32.h @@ -3,154 +3,183 @@ #pragma once #include "service_win32_api.h" -#include -#include "util.h" +#include "service_pipe_win32.h" #include "network_win32_api.h" #include "tunsafe_threading.h" -#include -#include -#include + +// Takes care of multiple TunsafeServiceBackend +class TunsafeServiceManager : public PipeManager::Delegate { + friend class TunsafeServiceBackend; + friend class TunsafeServiceServer; +public: + TunsafeServiceManager(); + virtual ~TunsafeServiceManager(); + + // -- from PipeManager::Delegate + virtual void HandleNotify() override; + virtual PipeConnection::Delegate *HandleNewConnection(PipeConnection *connection) override; + + // Called by the service control code to bring the service up or down + unsigned OnStart(int argc, wchar_t **argv); + void OnStop(); + void OnShutdown(); + + TunsafeServiceBackend *main_backend() { return main_backend_; } + + TunsafeServiceBackend *CreateBackend(const char *guid); + void DestroyBackend(TunsafeServiceBackend *backend); + + bool SwitchInterface(TunsafeServiceServer *server, const char *interfac, bool want_create); + +private: + // Points at the Tunsafe hklm reg key + HKEY hkey_; + uint32 server_unique_id_; + + PipeManager pipe_manager_; + + TunsafeServiceBackend *main_backend_; + std::vector backends_; +}; + +// One of these exist for each TunsafeBackend +class TunsafeServiceBackend : public TunsafeBackend::Delegate { + friend class TunsafeServiceServer; +public: + explicit TunsafeServiceBackend(TunsafeServiceManager *manager); + virtual ~TunsafeServiceBackend(); + + // -- from TunsafeBackend::Delegate + virtual void OnGetStats(const WgProcessorStats &stats) override; + virtual void OnClearLog() override; + virtual void OnLogLine(const char **s) override; + virtual void OnStateChanged() override; + virtual void OnStatusCode(TunsafeBackend::StatusCode status) override; + virtual void OnGraphAvailable() override; + virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) override; + + TunsafeBackend *backend() { return backend_; } + TunsafeBackend::Delegate *delegate() { return thread_delegate_; } + + void Start(const char *filename); + void RememberLastUsedConfigFile(const char *filename); + + void Stop(); + + // Trigger backend stats updates whenever a connected pipe client needs it + void UpdateRequestStats(); + + // Called by TunsafeServiceManager::HandleNotify to process events + // on each backend. + void HandleNotify(); + + // Send a state update to all connected pipes unless filter is set, then it + // sends only to that. + void SendStateUpdate(TunsafeServiceServer *filter); + + // Called whenever a pipe server disconnects + void RemovePipeServer(TunsafeServiceServer *pipe_server); + + // Called to register a pipe server with this backend + void AddPipeServer(TunsafeServiceServer *pipe_server); +private: + // Points at the service manager + TunsafeServiceManager *manager_; + + // Points at the actual TunsafeBackend + TunsafeBackend *backend_; + + // Points at all |TunsafeServiceServer| currently associated with this + // backend. + std::vector pipe_servers_; + + // Points at the thing that transmits TunsafeBackend events to + // the main thread + TunsafeBackend::Delegate *thread_delegate_; + + // The config filename that is loaded + std::string current_filename_; + + // Positions into |historical_log_lines_| + uint32 historical_log_lines_pos_; + uint32 historical_log_lines_count_; + + enum { LOGLINE_COUNT = 256 }; + char *historical_log_lines_[LOGLINE_COUNT]; +}; + +// The server side of the client<->server pipe connection +class TunsafeServiceServer : public PipeConnection::Delegate { + +public: + TunsafeServiceServer(PipeConnection *pipe, TunsafeServiceBackend *backend, uint32 unique_id); + virtual ~TunsafeServiceServer(); + + void WritePacket(int type, const uint8 *data, size_t data_size); + + // -- from PipeConnection::Delegate + virtual bool HandleMessage(int type, uint8 *data, size_t size) override; + virtual void HandleDisconnect() override; + + // Called by TunsafeServiceBackend to push a graph to the client + void OnGraphAvailable(); + + // Called by TunsafeServiceBackend to push more log lines to the client + void SendQueuedLogLines(); + + bool want_stats() const { return want_stats_; } + bool want_state_updates() const { return want_state_updates_; } + uint32 unique_id() const { return unique_id_; } + TunsafeServiceBackend *service_backend() { return service_backend_; } + void set_service_backend(TunsafeServiceBackend *sb) { service_backend_ = sb; } +private: + bool AuthenticateUser(); + + // Whether the client wants state updates + bool want_state_updates_; + + // Whether the client has authenticated + bool did_authenticate_user_; + + // Whether we want stats + bool want_stats_; + + // Whether the currently connected user wants a graph + uint32 want_graph_type_; + + // The last log line sent to the currently connected user + uint32 last_line_sent_; + + uint32 unique_id_; + + // The pipe used to communicate + PipeConnection *connection_; + + // The backend we're currently associated with + TunsafeServiceBackend *service_backend_; +}; + struct ServiceState { uint8 is_started : 1; uint8 internet_block_state_active : 1; uint8 internet_block_state; - uint8 reserved[26+64]; + uint8 reserved[26 + 64]; uint32 ipv4_ip; uint8 public_key[32]; }; STATIC_ASSERT(sizeof(ServiceState) == 128, ServiceState_wrong_size); -class PipeMessageHandler { -public: - class Delegate { - public: - virtual bool HandleMessage(int type, uint8 *data, size_t size) = 0; - virtual bool HandleNotify() = 0; - virtual void HandleNewConnection() = 0; - virtual void HandleDisconnect() = 0; - }; - - PipeMessageHandler(const char *pipe_name, bool is_server_pipe, Delegate *delegate); - ~PipeMessageHandler(); - - bool StartThread(); - void StopThread(); - - bool WritePacket(int type, const uint8 *data, size_t data_size); - - HANDLE notify_handle() { return wait_handles_[1]; } - HANDLE pipe_handle() { return pipe_; } - - bool VerifyThread(); - - bool is_connected() { return connection_established_; } -private: - bool InitializeServerPipeAndWait(); - bool InitializeClientPipe(); - void AdvanceStateMachine(); - void ClosePipe(); - DWORD ThreadMain(); - void SendNextQueuedWrite(); - static DWORD WINAPI StaticThreadMain(void *x); - - Delegate *delegate_; - - HANDLE pipe_; - HANDLE thread_; - HANDLE wait_handles_[3]; - bool write_overlapped_active_; - bool exit_thread_; - bool is_server_pipe_; - bool connection_established_; - char *pipe_name_; - - enum State { - kStateNone, - kStateWaitConnect, - kStateWaitReadLength, - kStateWaitReadPayload, - kStateWaitTimeout, - }; - - int state_; - - struct OutgoingPacket { - OutgoingPacket *next; - uint32 size; - uint8 data[0]; - }; - OutgoingPacket *packets_, **packets_end_; - uint8 *tmp_packet_buf_; - DWORD tmp_packet_size_; - - OVERLAPPED write_overlapped_, read_overlapped_; - - Mutex packets_mutex_; - - DWORD thread_id_; -}; - -class TunsafeServiceImpl : public TunsafeBackend::Delegate, public PipeMessageHandler::Delegate { -public: - TunsafeServiceImpl(); - virtual ~TunsafeServiceImpl(); - - // -- from TunsafeBackend::Delegate - virtual void OnGetStats(const WgProcessorStats &stats); - virtual void OnClearLog(); - virtual void OnLogLine(const char **s); - virtual void OnStateChanged(); - virtual void OnStatusCode(TunsafeBackend::StatusCode status); - virtual void OnGraphAvailable(); - - // -- from PipeMessageHandler::Delegate - virtual bool HandleMessage(int type, uint8 *data, size_t size); - virtual bool HandleNotify(); - virtual void HandleNewConnection(); - virtual void HandleDisconnect(); - - // virtual methods - virtual unsigned OnStart(int argc, wchar_t **argv); - virtual void OnStop(); - virtual void OnShutdown(); - - TunsafeBackend::Delegate *delegate() { return thread_delegate_; } - -private: - void SendQueuedLogLines(); - bool AuthenticateUser(); - - bool did_send_getstate_; - - bool did_authenticate_user_; - uint32 want_graph_type_; - - HKEY hkey_; - - TunsafeBackend *backend_; - TunsafeBackend::Delegate *thread_delegate_; - - PipeMessageHandler message_handler_; - - uint32 historical_log_lines_pos_; - uint32 historical_log_lines_count_; - uint32 last_line_sent_; - std::string current_filename_; - - enum { - LOGLINE_COUNT = 256 - }; - char *historical_log_lines_[LOGLINE_COUNT]; -}; - -class TunsafeServiceClient : public TunsafeBackend, public PipeMessageHandler::Delegate { +class TunsafeServiceClient : public TunsafeBackend, public PipeConnection::Delegate, public PipeManager::Delegate { public: TunsafeServiceClient(TunsafeBackend::Delegate *delegate); virtual ~TunsafeServiceClient(); - virtual bool Initialize(); + + // -- from TunsafeBackend + virtual bool Configure(); virtual void Teardown(); + virtual bool SetTunAdapterName(const char *name); virtual void Start(const char *config_file); virtual void Stop(); virtual void RequestStats(bool enable); @@ -160,12 +189,16 @@ public: virtual std::string GetConfigFileName(); virtual void SetServiceStartupFlags(uint32 flags); virtual LinearizedGraph *GetGraph(int type); + virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) override; + + // -- from PipeConnection::Delegate + virtual bool HandleMessage(int type, uint8 *data, size_t size) override; + virtual void HandleDisconnect() override; + + // -- from PipeManager::Delegate + virtual void HandleNotify() override; + virtual PipeConnection::Delegate *HandleNewConnection(PipeConnection *connection) override; - // -- from PipeMessageHandler::Delegate - virtual bool HandleMessage(int type, uint8 *data, size_t size); - virtual bool HandleNotify(); - virtual void HandleNewConnection(); - virtual void HandleDisconnect(); protected: TunsafeBackend::Delegate *delegate_; @@ -173,8 +206,10 @@ protected: bool got_state_from_control_; ServiceState service_state_; std::string config_file_; - PipeMessageHandler message_handler_; + PipeManager pipe_manager_; + PipeConnection *connection_; LinearizedGraph *cached_graph_; uint32 last_graph_type_; Mutex mutex_; }; + diff --git a/service_win32_constants.h b/service_win32_constants.h new file mode 100644 index 0000000..f1c5bd1 --- /dev/null +++ b/service_win32_constants.h @@ -0,0 +1,36 @@ +#pragma once + +#define TUNSAFE_PIPE_NAME "\\\\.\\pipe\\TunSafe\\ServiceControl" +#define TUNSAFE_SERVICE_PROTOCOL_VERSION 20180916001 + +enum { + TS_SERVICE_REQ_LOGIN = 0, + TS_SERVICE_REQ_START = 1, + TS_SERVICE_REQ_STOP = 2, + + TS_SERVICE_REQ_GETSTATS = 4, + TS_SERVICE_REQ_SET_INTERNET_BLOCKSTATE = 5, + TS_SERVICE_REQ_RESETSTATS = 6, + TS_SERVICE_REQ_SET_STARTUP_FLAGS = 7, + + TS_SERVICE_MSG_STATE = 8, + TS_SERVICE_MSG_LOGLINE = 9, + TS_SERVICE_MSG_ERROR_REPLY = 10, + TS_SERVICE_MSG_STATS = 11, + TS_SERVICE_MSG_CLEARLOG = 12, + TS_SERVICE_MSG_STATUS_CODE = 14, + + TS_SERVICE_REQ_GET_GRAPH = 15, + TS_SERVICE_MSG_GRAPH = 16, + + + TS_SERVICE_REQ_TEXT_PROTOCOL = 17, + TS_SERVICE_REQ_TEXT_PROTOCOL_REPLY = 18, + + TS_SERVICE_REQ_GETINTERFACES = 19, + TS_SERVICE_REQ_GETINTERFACES_REPLY = 20, +}; + +enum { + kTsMaxDevnameSize = 40 +}; \ No newline at end of file diff --git a/stdafx.h b/stdafx.h index 61625f4..a729e22 100644 --- a/stdafx.h +++ b/stdafx.h @@ -13,10 +13,14 @@ #if defined(OS_WIN) #define _WINSOCK_DEPRECATED_NO_WARNINGS 1 +#define _HAS_EXCEPTIONS 0 +#define _CRT_SECURE_NO_WARNINGS 1 + //#include #include #include +#undef max //#include #include #include diff --git a/ts.cpp b/ts.cpp new file mode 100644 index 0000000..2cd5d8c --- /dev/null +++ b/ts.cpp @@ -0,0 +1,883 @@ +#include "stdafx.h" +#include "tunsafe_types.h" +#include "netapi.h" +#include "crypto/curve25519-donna.h" +#include "util.h" +#include "wireguard_proto.h" +#include +#include + +#if defined(OS_WIN) +#include "util_win32.h" +#include "service_pipe_win32.h" +#include "service_win32_constants.h" +#endif // defined(OS_WIN) + +#if defined(OS_POSIX) +#include +#include +#include +#include +#include +#endif // defined(OS_WIN) + + +#pragma comment(lib, "ws2_32.lib") + +#define ANSI_RESET "\x1b[0m" +#define ANSI_BOLD "\x1b[1m" +#define ANSI_FG_BLACK "\x1b[30m" +#define ANSI_FG_RED "\x1b[31m" +#define ANSI_FG_GREEN "\x1b[32m" +#define ANSI_FG_YELLOW "\x1b[33m" +#define ANSI_FG_BLUE "\x1b[34m" +#define ANSI_FG_MAGENTA "\x1b[35m" +#define ANSI_FG_CYAN "\x1b[36m" +#define ANSI_FG_WHITE "\x1b[37m" + +static const uint8 kCurve25519Basepoint[32] = {9}; + +#if defined(OS_WIN) +#define EXENAME "ts" + +static bool SendMessageToService(HANDLE pipe, int message, const void *data, size_t data_size) { + uint8 *temp = new uint8[data_size + 5]; + *(uint32*)temp = (uint32)(data_size + 1); + temp[4] = (uint8)message; + memcpy(temp + 5, data, data_size); + // Write the whole thing + DWORD pos = 0, bytes_to_write = (DWORD)(data_size + 5), bytes_written; + do { + if (!WriteFile(pipe, temp + pos, bytes_to_write, &bytes_written, NULL)) { + fprintf(stderr, "Error writing to service pipe, error = %d\n", GetLastError()); + break; + } + pos += bytes_written; + bytes_to_write -= bytes_written; + } while (bytes_to_write != 0); + delete[] temp; + return (bytes_to_write == 0); +} + +static bool ReadExactBytesFromPipe(HANDLE pipe, const void *data, DWORD bytes_to_read) { + DWORD pos = 0, n; + do { + if (!ReadFile(pipe, (uint8*)data + pos, bytes_to_read, &n, NULL)) + return false; + if (n == 0) + return false; // premature eof.. + pos += n; + bytes_to_read -= n; + } while (bytes_to_read != 0); + return true; +} + +static bool ReadMessageFromService(HANDLE pipe, int *message, std::string *data) { + uint8 header[5]; + uint32 message_size; + + if (!ReadExactBytesFromPipe(pipe, header, 5) || (message_size = *(uint32*)header) == 0) { + fprintf(stderr, "Error reading from service pipe, error = %d\n", GetLastError()); + return false; + } + *message = header[4]; + data->resize(message_size - 1); + if (message_size - 1 != 0 && !ReadExactBytesFromPipe(pipe, data->data(), message_size - 1)) { + fprintf(stderr, "Error reading from service pipe, error = %d\n", GetLastError()); + return false; + } + return true; +} + +struct ServiceLoginMessage { + uint64 version; + char interfac[kTsMaxDevnameSize]; + bool want_state_updates; + bool want_create_interface; +}; + +static std::vector g_tap_adapters; +static bool g_did_get_adapters; +static const std::vector &GetTapAdapterInfo() { + if (!g_did_get_adapters) { + g_did_get_adapters = true; + GetTapAdapterInfo(&g_tap_adapters); + } + return g_tap_adapters; +} + +static const char *GetGuidFromInterfaceName(const char *name) { + for (const GuidAndDevName &e : GetTapAdapterInfo()) + if (strcmp(e.name, name) == 0) + return e.guid; + return NULL; +} + +static const char *GetInterfaceNameFromGuid(const char *guid) { + for (const GuidAndDevName &e : GetTapAdapterInfo()) + if (strcmp(e.guid, guid) == 0) + return e.name; + return NULL; +} + + +static HANDLE ConnectToService(const char *devname, bool want_updates, bool want_create = false) { + ServiceLoginMessage msg = {0}; + msg.version = TUNSAFE_SERVICE_PROTOCOL_VERSION; + msg.want_state_updates = want_updates; + msg.want_create_interface = want_create; + + // Rename devname to a guid + if (devname) { + const char *guid = (devname[0] == '{' || devname[0] == 0) ? devname : GetGuidFromInterfaceName(devname); + if (!guid) { + fprintf(stderr, "Interface '%s' not found\n", devname); + return NULL; + } + my_strlcpy(msg.interfac, sizeof(msg.interfac), guid); + } + + for (;;) { + HANDLE pipe = CreateFile(TUNSAFE_PIPE_NAME, GENERIC_READ | GENERIC_WRITE, 0, NULL, + OPEN_EXISTING, 0, NULL); + if (pipe != INVALID_HANDLE_VALUE) { + if (!SendMessageToService(pipe, TS_SERVICE_REQ_LOGIN, &msg, sizeof(msg))) { + CloseHandle(pipe); + pipe = NULL; + } + return pipe; + } + DWORD error = GetLastError(); + if (error != ERROR_PIPE_BUSY) { + fprintf(stderr, "Error connecting to TunSafe service, error = %d\n", error); + if (error == ERROR_FILE_NOT_FOUND) + fprintf(stderr, "Please check that the TunSafe service is started\n"); + return NULL; + } + if (!WaitNamedPipe(TUNSAFE_PIPE_NAME, 10000)) { + fprintf(stderr, "Error connecting to TunSafe service, timed out.\n"); + return NULL; + } + } +} + +static bool CommunicateWithService(const char *devname, const std::string &query, std::string *reply) { + HANDLE pipe = ConnectToService(devname, false); + int message_code; + bool rv = false; + + if (pipe != NULL && + SendMessageToService(pipe, TS_SERVICE_REQ_TEXT_PROTOCOL, query.data(), query.size()) && + ReadMessageFromService(pipe, &message_code, reply)) { + if (message_code == TS_SERVICE_REQ_TEXT_PROTOCOL_REPLY) { + rv = true; + } else { + if (message_code == TS_SERVICE_MSG_ERROR_REPLY) { + fprintf(stderr, "Error: %s\n", reply->c_str()); + } else { + fprintf(stderr, "Unknown reply (%d) from TunSafe service.\n", message_code); + } + } + } + CloseHandle(pipe); + return rv; +} + +static bool GetInterfaceList(std::string *result) { + HANDLE pipe = ConnectToService(NULL, false); + int message_code; + bool rv = false; + + if (pipe != NULL && + SendMessageToService(pipe, TS_SERVICE_REQ_GETINTERFACES, NULL, 0) && + ReadMessageFromService(pipe, &message_code, result)) { + if (message_code == TS_SERVICE_REQ_GETINTERFACES_REPLY) { + rv = true; + } else { + fprintf(stderr, "GetInterfaceList: bad reply\n"); + } + } + CloseHandle(pipe); + return rv; +} +#endif // defined(OS_WIN) + +#if defined(OS_POSIX) +#define EXENAME "tunsafe" + +static const char *GetGuidFromInterfaceName(const char *name) { + return name; +} + +static const char *GetInterfaceNameFromGuid(const char *guid) { + return guid; +} + + +static int OpenUserspaceInterface(const char *iface) { + struct stat st; + struct sockaddr_un un = { 0 }; + int fd = -1, rv; + + if (strchr(iface, '/') != NULL) { + fprintf(stderr, "Unable to open usermode socket: No such device\n"); + goto getout; + } + + snprintf(un.sun_path, sizeof(un.sun_path), "/var/run/wireguard/%s.sock", iface); + if (stat(un.sun_path, &st) < 0) { + perror("Unable to open usermode socket"); + goto getout; + } + + if (!S_ISSOCK(st.st_mode)) { + fprintf(stderr, "Unable to open usermode socket: No such device\n"); + goto getout; + } + + fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (fd < 0) + goto getout; + + un.sun_family = AF_UNIX; + if (connect(fd, (struct sockaddr *)&un, sizeof(un)) < 0) { + if (errno == ECONNREFUSED) + unlink(un.sun_path); + else + perror("Error opening wireguard usermode interface socket"); + goto getout; + } + return fd; + +getout: + if (fd >= 0) + close(fd); + return -1; +} + + +static bool GetInterfaceList(std::string *result) { + struct dirent *dent; + + DIR *dir = opendir("/var/run/wireguard/"); + if (!dir) + return errno == ENOENT; + + while ((dent = readdir(dir)) != NULL) { + size_t len = strlen(dent->d_name); + static const char kSuffix[6] = ".sock"; + if (len >= sizeof(kSuffix) - 1 && + memcmp(&dent->d_name[len - (sizeof(kSuffix) - 1)], kSuffix, sizeof(kSuffix) - 1) == 0) { + dent->d_name[len - (sizeof(kSuffix) - 1)] = '\n'; + result->append(dent->d_name, len - (sizeof(kSuffix) - 1) + 1); + } + } + closedir(dir); + return true; +} + +static bool CommunicateWithService(const char *devname, const std::string &query, std::string *reply) { + ssize_t n; + char buf[4096]; + bool rv = false; + + reply->clear(); + + int fd = OpenUserspaceInterface(devname); + if (fd == -1) + return false; + + for(size_t pos = 0; query.size() - pos; pos += n) { + n = write(fd, query.data() + pos, query.size() - pos); + if (n <= 0) { + perror("Error writing to service pipe"); + goto getout; + } + } + + for(;;) { + n = read(fd, buf, sizeof(buf)); + if (n <= 0) { + if (n == 0) { + // ensure that it ends with \n\n + if (reply->size() >= 2 && (*reply)[reply->size() - 1] == '\n' && (*reply)[reply->size() - 2] == '\n') { + rv = true; + } else { + fprintf(stderr, "Bad reply from service pipe\n"); + } + } else { + perror("Error reading from service pipe"); + } + break; + } + reply->append(buf, n); + } + +getout: + close(fd); + return rv; +} + +static int HandleStopCommand(int argc, char **argv) { + if (argc != 1) { + fprintf(stderr, "Usage: " EXENAME " stop \n"); + return 1; + } + struct sockaddr_un un; + struct stat st; + const char *iface = argv[0]; + if (strchr(iface, '/')) { + fprintf(stderr, "No such interface\n"); + return 1; + } + snprintf(un.sun_path, sizeof(un.sun_path), "/var/run/wireguard/%s.sock", iface); + if (unlink(un.sun_path) == -1) { + perror("unlink"); + return 1; + } + return 0; +} + +#endif // defined(OS_POSIX) + +void ShowHelp() { + fprintf(stderr, + "Usage: " EXENAME " []\n\n" +#if defined(OS_POSIX) + " " EXENAME " filename.conf\n\n" +#endif // defined(OS_POSIX) + "Available subcommands:\n" + " show: Shows the configuration and status of the interfaces\n" + " set: Change the configuration or the peer list\n" + " start: Start TunSafe on an interface\n" + " stop: Stop TunSafe on an interface\n" +#if defined(OS_WIN) + " log: Display recent log entries\n" +#endif // defined(OS_WIN) + " genkey: Writes a new private key to stdout\n" + " genpsk: Writes a new preshared key to stdout\n" + " pubkey: Reads a private key from stdin and writes its public key to stdout\n" + "To see more help about a subcommand, pass --help to it\n"); +} + + + +static bool ParseHexKeyToBase64(const char *key, char base64key[WG_PUBLIC_KEY_LEN_BASE64 + 1]) { + uint8 keybuf[32]; + if (!ParseHexString(key, keybuf, 32)) + return false; + return base64_encode(keybuf, 32, base64key, WG_PUBLIC_KEY_LEN_BASE64 + 1, NULL) != NULL; +} + +static char *FormatTransferPart(char *buf, size_t bufsize, uint64 n) { + if (n < 1024) + snprintf(buf, bufsize, "%u " ANSI_FG_CYAN "B" ANSI_RESET, (unsigned)n); + else if (n < 1024 * 1024) + snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "KiB" ANSI_RESET, (double)n * (1.0 / 1024)); + else if (n < 1024 * 1024 * 1024) + snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "MiB" ANSI_RESET, (double)n * (1.0 / 1024 / 1024)); + else if (n < 1024ull * 1024 * 1024 * 1024) + snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "GiB" ANSI_RESET, (double)n * (1.0 / 1024 / 1024 / 1024)); + else + snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "TiB" ANSI_RESET, (double)n * (1.0 / 1024 / 1024 / 1024 / 1024)); + return buf; +} + +static size_t PrintTime(char *buf, size_t bufsize, uint64 n) { + size_t pos = 0; + uint64 years = n / (365 * 24 * 60 * 60); + uint32 n32 = n % (365 * 24 * 60 * 60); + if (years) + pos += snprintf(buf + pos, bufsize - pos, "%llu " ANSI_FG_CYAN "year%s" ANSI_RESET ", ", (unsigned long long)years, (years == 1) ? "" : "s"); + uint32 days = n32 / (24 * 60 * 60); + n32 %= (24 * 60 * 60); + if (days) + pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "day%s" ANSI_RESET ", ", days, (days == 1) ? "" : "s"); + uint32 hours = n32 / (60 * 60); + n32 %= (60 * 60); + if (hours) + pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "hour%s" ANSI_RESET ", ", hours, (hours == 1) ? "" : "s"); + uint32 minutes = n32 / 60; + if (minutes) + pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "minute%s" ANSI_RESET ", ", minutes, (minutes == 1) ? "" : "s"); + uint32 seconds = n32 % 60; + if (seconds) + pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "second%s" ANSI_RESET ", ", seconds, (seconds == 1) ? "" : "s"); + if (pos) + buf[pos -= 2] = '\0'; + return pos; +} + +static char *PrintHandshake(char *buf, size_t bufsize, uint64 secs) { + time_t now = time(NULL); + if (now == secs) { + snprintf(buf, bufsize, "Now"); + } else if (now < (int64)secs) { + snprintf(buf, bufsize, ANSI_FG_RED "System clock going backwards" ANSI_RESET); + } else { + size_t pos = PrintTime(buf, bufsize - 4, now - secs); + memcpy(buf + pos, " ago", 5); + } + return buf; +} + +static void AppendIpToString(const char *value, std::string *result) { + if (!result->empty()) + (*result) += ", "; + const char *slash = strchr(value, '/'); + if (slash) { + result->append(value, slash - value); + result->append(ANSI_FG_CYAN "/" ANSI_RESET); + result->append(slash + 1); + } else { + result->append(value); + } +} + +static int ShowUserFriendlyForDevice(char *devname) { + std::string reply; + std::vector> kv; + std::string ips; + + if (!CommunicateWithService(devname, "get=1\n\n", &reply)) + return 1; + + if (!ParseConfigKeyValue(&reply[0], &kv)) { +getout_fail: + fprintf(stderr, "Unable to parse response"); + return 1; + } + + size_t i = 0; + char base64key[WG_PUBLIC_KEY_LEN_BASE64 + 1]; + char base64psk[WG_PUBLIC_KEY_LEN_BASE64 + 1]; + int listen_port = 0; + base64key[0] = 0; + + // Parse all interface level keys + for (; i < kv.size(); i++) { + char *key = kv[i].first, *value = kv[i].second; + if (strcmp(key, "private_key") == 0) { + uint8 binkey[32]; + if (!ParseHexString(value, binkey, sizeof(binkey))) + goto getout_fail; + if (!IsOnlyZeros(binkey, 32)) { + curve25519_donna(binkey, binkey, kCurve25519Basepoint); + base64_encode(binkey, sizeof(binkey), base64key, sizeof(base64key), NULL); + } + } else if (strcmp(key, "address") == 0) { + AppendIpToString(value, &ips); + } else if (strcmp(key, "listen_port") == 0) { + listen_port = atoi(value); + } else if (strcmp(key, "public_key") == 0) { + break; + } + } + + const char *interfacename = (devname[0] == '{') ? GetInterfaceNameFromGuid(devname) : devname; + + printf(ANSI_RESET ANSI_FG_GREEN ANSI_BOLD "interface" ANSI_RESET ": " ANSI_FG_GREEN "%s" ANSI_RESET "\n", + interfacename); + if (base64key[0]) { + printf(" " ANSI_BOLD "public key" ANSI_RESET ": %s\n" + " " ANSI_BOLD "private key" ANSI_RESET ": (hidden)\n", base64key); + } + if (listen_port) + printf(" " ANSI_BOLD "listening port" ANSI_RESET ": %d\n", listen_port); + if (ips.size()) + printf(" " ANSI_BOLD "address" ANSI_RESET ": %s\n", ips.c_str()); + + const char *endpoint = NULL; + uint64 rx_bytes, tx_bytes, last_handshake_time_sec; + int persistent_keepalive; + char text[256]; + bool clear_state = true; + + // Parse peer level keys + for (; i < kv.size(); i++) { + char *key = kv[i].first, *value = kv[i].second; + + if (clear_state) { + base64key[0] = base64psk[0] = 0; + endpoint = NULL; + ips.clear(); + persistent_keepalive = 0; + last_handshake_time_sec = tx_bytes = rx_bytes = 0; + clear_state = false; + } + if (strcmp(key, "public_key") == 0) { + if (!ParseHexKeyToBase64(value, base64key)) + goto getout_fail; + } else if (strcmp(key, "preshared_key") == 0) { + if (!ParseHexKeyToBase64(value, base64psk)) + goto getout_fail; + } else if (strcmp(key, "tx_bytes") == 0) { + tx_bytes = strtoull(value, NULL, 0); + } else if (strcmp(key, "rx_bytes") == 0) { + rx_bytes = strtoull(value, NULL, 0); + } else if (strcmp(key, "allowed_ip") == 0) { + AppendIpToString(value, &ips); + } else if (strcmp(key, "persistent_keepalive_interval") == 0) { + persistent_keepalive = atoi(value); + } else if (strcmp(key, "endpoint") == 0) { + endpoint = value; + } else if (strcmp(key, "last_handshake_time_sec") == 0) { + last_handshake_time_sec = strtoull(value, NULL, 0); + } + if (i == kv.size() - 1 || strcmp(kv[i + 1].first, "public_key") == 0) { + if (!base64key[0]) + goto getout_fail; + printf("\n" ANSI_FG_YELLOW ANSI_BOLD "peer" ANSI_RESET ": " ANSI_FG_YELLOW "%s" ANSI_RESET "\n", base64key); + if (base64psk[0]) + printf(" " ANSI_BOLD "preshared key" ANSI_RESET ": (hidden)\n"); + if (endpoint) + printf(" " ANSI_BOLD "endpoint" ANSI_RESET ": %s\n", endpoint); + printf(" " ANSI_BOLD "allowed ips" ANSI_RESET ": %s\n", ips.size() ? ips.c_str() : "(none)"); + if (last_handshake_time_sec) + printf(" " ANSI_BOLD "latest handshake" ANSI_RESET ": %s\n", PrintHandshake(text, sizeof(text), last_handshake_time_sec)); + if (tx_bytes | rx_bytes) { + printf(" " ANSI_BOLD "transfer" ANSI_RESET ": %s received, ", FormatTransferPart(text, sizeof(text), rx_bytes)); + printf("%s sent\n", FormatTransferPart(text, sizeof(text), tx_bytes)); + } + if (persistent_keepalive) { + PrintTime(text, sizeof(text), persistent_keepalive); + printf(" " ANSI_BOLD "persistent keepalive" ANSI_RESET ": every %s\n", text); + } + clear_state = true; + } + } + return 0; +} + +static int HandleShowCommand(int argc, char **argv) { + if (argc != 0 && strcmp(argv[0], "--help") == 0) { + fprintf(stderr, "Usage: ts show { | all | interfaces }\n"); + return 0; + } + + std::vector interfaces; + std::string interfaces_str; + + if (argc == 0 || strcmp(argv[0], "all") == 0) { + if (!GetInterfaceList(&interfaces_str)) + return 1; + SplitString(&interfaces_str[0], '\n', &interfaces); + + bool want_newline = false; + for (char *interfac : interfaces) { + if (want_newline) + printf("\n"); + want_newline = true; + if (ShowUserFriendlyForDevice(interfac)) + return 1; + } + } else if (strcmp(argv[0], "interfaces") == 0) { + if (!GetInterfaceList(&interfaces_str)) + return 1; + SplitString(&interfaces_str[0], '\n', &interfaces); + + for (char *interfac : interfaces) { + const char *name = GetInterfaceNameFromGuid(interfac); + if (name) + printf("%s\n", name); + } + } else { + return ShowUserFriendlyForDevice(argv[0]); + } + return 0; +} + +static void AppendCommand(std::string *result, const char *tag, const char *value) { + result->append(tag); + result->append("="); + result->append(value); + result->append("\n"); +} + +static bool ConvertBase64KeyToHex(const char *s, char key[65]) { + uint8 tmp[32]; + size_t size = 32; + if (!base64_decode((uint8*)s, strlen(s), tmp, &size) || size != 32) + return false; + PrintHexString(tmp, 32, key); + return true; +} + +static int HandleSetCommand(int argc, char **argv) { + std::string command, reply; + std::vector ss; + char hexkey[65]; + + if (argc == 0) { + fprintf(stderr, "Usage: ts set [address
] [listen-port ] [private-key ] " + "[peer [remove] [preshared-key ] [endpoint :] " + "[persistent-keepalive ] [allowed-ips /[,/]] ]"); + return 1; + } + char **argv_end = argv + argc; + const char *interfc = *argv++; + + command = "set=1\n"; + + bool in_interface_section = true; + bool in_peer_section = false; + bool did_clear_allowed_ips = false; + + while (argv != argv_end) { + const char *key = *argv++; + + if (argv != argv_end) { + if (in_interface_section) { + if (strcmp(key, "listen-port") == 0) { + AppendCommand(&command, "listen_port", *argv++); + continue; + } else if (strcmp(key, "address") == 0) { + AppendCommand(&command, "address", *argv++); + continue; + } else if (strcmp(key, "private-key") == 0) { + if (!ConvertBase64KeyToHex(*argv++, hexkey)) + goto invalid_key_format; + AppendCommand(&command, "private_key", hexkey); + continue; + } + } + if (strcmp(key, "peer") == 0) { + in_interface_section = false; + in_peer_section = true; + did_clear_allowed_ips = false; + if (!ConvertBase64KeyToHex(*argv++, hexkey)) + goto invalid_key_format; + AppendCommand(&command, "public_key", hexkey); + + continue; + } + if (in_peer_section) { + if (strcmp(key, "preshared-key") == 0) { + if (!ConvertBase64KeyToHex(*argv++, hexkey)) + goto invalid_key_format; + AppendCommand(&command, "preshared_key", hexkey); + continue; + } else if (strcmp(key, "endpoint") == 0) { + AppendCommand(&command, "endpoint", *argv++); + continue; + } else if (strcmp(key, "persistent-keepalive") == 0) { + AppendCommand(&command, "persistent_keepalive_interval", *argv++); + continue; + } else if (strcmp(key, "allowed-ips") == 0) { + if (!did_clear_allowed_ips) { + AppendCommand(&command, "replace_allowed_ips", "true"); + did_clear_allowed_ips = true; + } + SplitString(*argv++, ',', &ss); + for (char *x : ss) + AppendCommand(&command, "allowed_ip", x); + continue; + } + } + } + if (in_peer_section) { + if (strcmp(key, "remove") == 0) { + in_peer_section = false; + AppendCommand(&command, "remove", "true"); + continue; + } + } + + fprintf(stderr, "Invalid argument: %s\n", key); + return 1; + +invalid_key_format: + fprintf(stderr, "Key is not in the correct format: '%s'\n", argv[-1]); + return 1; + } + + command.append("\n"); + + if (!CommunicateWithService(interfc, command, &reply)) + return 1; + + return 0; +} + +#if defined(OS_WIN) +static int HandleLogCommand() { + HANDLE pipe = ConnectToService(NULL, true); + + int message_code; + std::string reply; + + while (pipe != NULL && ReadMessageFromService(pipe, &message_code, &reply) && message_code == TS_SERVICE_MSG_LOGLINE) + printf("%s\n", reply.c_str()); + + CloseHandle(pipe); + return 0; +} + +static int HandleStartCommand(int argc, char **argv) { + if (argc < 1 || argc > 2 || strcmp(argv[0], "--help") == 0) { + fprintf(stderr, "Usage: " EXENAME " start []\n"); + return 1; + } + + const char *devname = argv[0]; + HANDLE pipe = ConnectToService(devname, false, true); + int message_code; + std::string reply; + + const char *path = (argc == 1) ? "" : argv[1]; + + // Tell the server to startup a new interface + if (pipe == NULL || + !SendMessageToService(pipe, TS_SERVICE_REQ_START, path, strlen(path) + 1) || + !ReadMessageFromService(pipe, &message_code, &reply)) + return 1; + + if (message_code == TS_SERVICE_MSG_ERROR_REPLY) { + fprintf(stderr, "%s\n", reply.c_str()); + return 1; + } + + return 0; +} + +static int HandleStopCommand(int argc, char **argv) { + if (argc != 1) { + fprintf(stderr, "Usage: " EXENAME " stop \n"); + return 1; + } + + const char *devname = argv[0]; + HANDLE pipe = ConnectToService(devname, false); + + // Tell the server to stop the interface + if (pipe == NULL || + !SendMessageToService(pipe, TS_SERVICE_REQ_STOP, NULL, 0)) + return 1; + return 0; +} +#endif // defined(OS_WIN) + + +struct CommandLineOutput { + const char *filename_to_load; + const char *interface_name; + bool daemon; +}; + +// Returns -1 on invalid subcommand +int HandleCommandLine(int argc, char **argv, CommandLineOutput *output) { + uint8 key[32]; + char base64buf[WG_PUBLIC_KEY_LEN_BASE64 + 1]; + + if (argc == 1) { + ShowHelp(); + return 1; + } + + const char *subcommand = argv[1]; + argv += 2; + argc -= 2; + + if (!strcmp(subcommand, "show")) { + return HandleShowCommand(argc, argv); + + } else if (!strcmp(subcommand, "set")) { + return HandleSetCommand(argc, argv); + +#if defined(OS_WIN) + } else if (!strcmp(subcommand, "log")) { + if (argc != 0) { + fprintf(stderr, "Usage: " EXENAME " log\n"); + return 1; + } + return HandleLogCommand(); + + } else if (!strcmp(subcommand, "start")) { + return HandleStartCommand(argc, argv); + +#else + } else if (!strcmp(subcommand, "start") && output) { + if (argc != 0 && !strcmp(argv[0], "--help")) { +start_usage: + fprintf(stderr, "Usage: " EXENAME " start [-d/--daemon] [-n ] []\n"); + return 0; + } + for (; argc; argc--, argv++) { + char *arg = argv[0]; + if (strcmp(arg, "-d") == 0 || strcmp(arg, "--daemon") == 0) { + output->daemon = true; + continue; + } + if (strcmp(arg, "-n") == 0) { + if (argc < 2) goto start_usage; + output->interface_name = argv[1]; + argc--,argv++; + continue; + } + break; + } + if (argc > 1) goto start_usage; + output->filename_to_load = (argc == 0) ? "" : argv[0]; + return 0; +#endif // defined(OS_WIN) + } else if (!strcmp(subcommand, "stop")) { + return HandleStopCommand(argc, argv); + } else if(!strcmp(subcommand, "genkey")) { + if (argc != 0) { + fprintf(stderr, "Usage: " EXENAME " genkey\n"); + return 1; + } + OsGetRandomBytes(key, 32); + curve25519_normalize(key); + printf("%s\n", base64_encode(key, 32, base64buf, sizeof(base64buf), NULL)); + + } else if (!strcmp(subcommand, "genpsk")) { + if (argc != 0) { + fprintf(stderr, "Usage: " EXENAME " genpsk\n"); + return 1; + } + OsGetRandomBytes(key, 32); + printf("%s\n", base64_encode(key, 32, base64buf, sizeof(base64buf), NULL)); + } else if (!strcmp(subcommand, "pubkey")) { + char base64[WG_PUBLIC_KEY_LEN_BASE64 + 2]; + size_t n = fread(base64, 1, sizeof(base64), stdin); + if (n < sizeof(base64) - 2 || n >= sizeof(base64) || + (n == sizeof(base64) - 1 && (base64[WG_PUBLIC_KEY_LEN_BASE64] != ' ' && base64[WG_PUBLIC_KEY_LEN_BASE64] != '\n'))) { + fprintf(stderr, EXENAME ": Incorrect key format\n"); + return 1; + } + size_t size = 32; + if (!base64_decode((uint8*)base64, n, key, &size) || size != 32) { + fprintf(stderr, EXENAME ": Incorrect key format\n"); + return 1; + } + curve25519_donna(key, key, kCurve25519Basepoint); + printf("%s\n", base64_encode(key, 32, base64buf, sizeof(base64buf), NULL)); + } else if (!strcmp(subcommand, "--help")) { + ShowHelp(); + } else if (!strcmp(subcommand, "--version")) { + printf("%s\n", TUNSAFE_VERSION_STRING); + } else { + if (argc == 0) { + if (output) + output->filename_to_load = subcommand; + } else { + ShowHelp(); + } + return -1; + } + return 0; +} + +#if defined(OS_WIN) +// This is integrated into the main tunsafe binary on posix systems +int main(int argc, char **argv) { + int rv = HandleCommandLine(argc, argv, NULL); + if (rv == -1) { + fprintf(stderr, "Invalid subcommand '%s'\n", argv[1]); + ShowHelp(); + return 1; + } + return rv; +} +#endif // defined(OS_WIN) diff --git a/ts.vcxproj b/ts.vcxproj new file mode 100644 index 0000000..0774b3d --- /dev/null +++ b/ts.vcxproj @@ -0,0 +1,199 @@ + + + + + Debug + Win32 + + + Release + Win32 + + + Debug + x64 + + + Release + x64 + + + + 15.0 + {443E105E-8D7C-401F-BD41-D3F56C76104B} + Win32Proj + ts + 10.0.17134.0 + + + + Application + true + v141 + MultiByte + + + Application + false + v141 + true + MultiByte + + + Application + true + v141 + MultiByte + + + Application + false + v141 + true + MultiByte + + + + + + + + + + + + + + + + + + + + + + true + $(SolutionDir)build\$(Platform)_$(Configuration)\ + $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\ + + + true + $(SolutionDir)build\$(Platform)_$(Configuration)\ + $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\ + + + false + $(SolutionDir)build\$(Platform)_$(Configuration)\ + $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\ + + + false + $(SolutionDir)build\$(Platform)_$(Configuration)\ + $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\ + + + + Use + Level3 + Disabled + true + _DEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + false + + + Console + true + + + + + Use + Level3 + Disabled + true + WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + false + + + Console + true + + + + + Use + Level3 + MinSpace + true + true + true + WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + MultiThreaded + true + false + + + Console + true + true + true + + + + + Use + Level3 + MinSpace + true + true + true + NDEBUG;_CONSOLE;%(PreprocessorDefinitions) + true + MultiThreaded + true + false + + + Console + true + true + true + + + + + + + + + + + NotUsing + NotUsing + NotUsing + NotUsing + + + Create + Create + Create + Create + + + + + + + + true + true + + + + + + + \ No newline at end of file diff --git a/ts.vcxproj.filters b/ts.vcxproj.filters new file mode 100644 index 0000000..4d6247d --- /dev/null +++ b/ts.vcxproj.filters @@ -0,0 +1,53 @@ + + + + + {4FC737F1-C7A5-4376-A066-2A32D752A2FF} + cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx + + + {93995380-89BD-4b04-88EB-625FBE52EBFB} + h;hh;hpp;hxx;hm;inl;inc;ipp;xsd + + + {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} + rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms + + + + + Header Files + + + Source Files + + + Source Files + + + Source Files + + + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + Source Files + + + + + Source Files + + + \ No newline at end of file diff --git a/tunsafe_threading.cpp b/tunsafe_threading.cpp index af21db3..f7c64b7 100644 --- a/tunsafe_threading.cpp +++ b/tunsafe_threading.cpp @@ -10,10 +10,14 @@ MultithreadedDelayedDelete::MultithreadedDelayedDelete() { } MultithreadedDelayedDelete::~MultithreadedDelayedDelete() { + assert(curr_.size() == 0); + assert(next_.size() == 0); + assert(to_delete_.size() == 0); free(table_); } -void MultithreadedDelayedDelete::Initialize(uint32 num_threads) { +void MultithreadedDelayedDelete::Configure(uint32 num_threads) { + assert(table_ == NULL); num_threads_ = num_threads; table_ = (CheckpointData*)calloc(sizeof(CheckpointData), num_threads); } diff --git a/tunsafe_threading.h b/tunsafe_threading.h index 1362678..d595104 100644 --- a/tunsafe_threading.h +++ b/tunsafe_threading.h @@ -150,12 +150,14 @@ public: typedef void DoDeleteFunc(void *x); void Add(DoDeleteFunc *func, void *param); - void Initialize(uint32 num_threads); + void Configure(uint32 num_threads); void Checkpoint(uint32 thread_id); void MainCheckpoint(); + bool enabled() const { return num_threads_ != 0; } + private: struct Entry { DoDeleteFunc *func; diff --git a/tunsafe_win32.cpp b/tunsafe_win32.cpp index 96dba0e..daea673 100644 --- a/tunsafe_win32.cpp +++ b/tunsafe_win32.cpp @@ -107,22 +107,6 @@ static void SetUiVisibility(bool visible) { UpdateGraphReq(); } -static bool GetConfigFullName(const char *basename, char *fullname, size_t fullname_size) { - size_t len = strlen(basename); - - if (FindFilenameComponent(basename)[0]) { - if (len >= fullname_size) - return false; - memcpy(fullname, basename, len + 1); - return true; - } - size_t clen = GetConfigPath(fullname, fullname_size); - if (clen == 0 || clen + len >= fullname_size) - return false; - memcpy(fullname + clen, basename, (len + 1) * sizeof(fullname[0])); - return true; -} - void StopTunsafeBackend(UpdateIconWhy why) { if (g_backend->is_started()) { g_backend->Stop(); @@ -154,6 +138,7 @@ void StartTunsafeBackend(UpdateIconWhy reason) { } g_notified_connected_server = false; g_is_connected_to_server = false; + memset(&g_processor_stats, 0, sizeof(g_processor_stats)); g_backend->Start(g_current_filename); RegWriteInt(g_reg_key, "IsConnected", 1); } @@ -189,12 +174,21 @@ public: } virtual void OnLogLine(const char **s) { + const char *line = *s; + size_t len = strlen(line); + char *tmp = (char*)alloca(len + 3); + + tmp[len + 0] = '\r'; + tmp[len + 1] = '\n'; + tmp[len + 2] = 0; + memcpy(tmp, line, len); + CHARRANGE cr; cr.cpMin = -1; cr.cpMax = -1; // hwnd = rich edit hwnd SendMessage(hwndEdit, EM_EXSETSEL, 0, (LPARAM)&cr); - SendMessage(hwndEdit, EM_REPLACESEL, 0, (LPARAM)*s); + SendMessage(hwndEdit, EM_REPLACESEL, 0, (LPARAM)tmp); } virtual void OnStateChanged() { @@ -204,13 +198,13 @@ public: const char *filename = g_cmdline_filename; if (filename) { - if (GetConfigFullName(filename, fullname, sizeof(fullname))) + if (ExpandConfigPath(filename, fullname, sizeof(fullname))) SetCurrentConfigFilename(fullname); } else { std::string currconfig = g_backend->GetConfigFileName(); if (currconfig.empty()) { char *conf = RegReadStr(g_reg_key, "ConfigFile", "TunSafe.conf"); - if (GetConfigFullName(conf, fullname, sizeof(fullname))) + if (ExpandConfigPath(conf, fullname, sizeof(fullname))) SetCurrentConfigFilename(fullname); free(conf); } else { @@ -233,10 +227,12 @@ public: } virtual void OnStatusCode(TunsafeBackend::StatusCode status) override { + if (status != g_status_code) + InvalidatePaintbox(); + g_status_code = status; if (TunsafeBackend::IsPermanentError(status)) { UpdateIcon(g_is_connected_to_server ? UIW_STOPPED_WORKING_FAIL : UIW_NONE); - InvalidatePaintbox(); return; } bool is_connected = (status == TunsafeBackend::kStatusConnected); @@ -254,13 +250,15 @@ public: if (is_connected > not_first && (g_startup_flags & kStartupFlag_BackgroundService)) g_notified_connected_server = true; UpdateIcon(UIW_NONE); - InvalidatePaintbox(); } } virtual void OnClearLog() override { SetWindowText(hwndEdit, ""); } + + virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) override { + } }; static MyBackendDelegate my_procdel; @@ -411,7 +409,7 @@ public: ConfigMenuBuilder::ConfigMenuBuilder() : nfiles_(0), depth_(0) { - if (!GetConfigFullName("", buf_, sizeof(buf_))) + if (!ExpandConfigPath("", buf_, sizeof(buf_))) bufpos_ = sizeof(buf_); else bufpos_ = strlen(buf_); @@ -556,7 +554,7 @@ static void OpenEditor() { static void BrowseFiles() { char buf[MAX_PATH]; - if (GetConfigFullName("", buf, ARRAYSIZE(buf))) { + if (ExpandConfigPath("", buf, ARRAYSIZE(buf))) { size_t l = strlen(buf); buf[l - 1] = 0; ShellExecuteFromExplorer(buf, NULL, NULL, "explore"); @@ -572,7 +570,7 @@ bool ImportFile(const char *s, bool silent = false) { bool rv = false; int filerv; - if (!*last || !GetConfigFullName(last, buf, ARRAYSIZE(buf)) || _stricmp(buf, s) == 0) + if (!*last || !ExpandConfigPath(last, buf, ARRAYSIZE(buf)) || _stricmp(buf, s) == 0) goto out; filedata = LoadFileSane(s, &filesize); @@ -657,9 +655,8 @@ void BrowseFile(HWND wnd) { static const uint8 kCurve25519Basepoint[32] = {9}; static void SetKeyBox(HWND wnd, int ctr, uint8 buf[32]) { - uint8 *privs = base64_encode(buf, 32, NULL); - SetDlgItemText(wnd, ctr, (char*)privs); - free(privs); + char base64[WG_PUBLIC_KEY_LEN_BASE64 + 1]; + SetDlgItemText(wnd, ctr, base64_encode(buf, 32, base64, sizeof(base64), NULL)); } static INT_PTR WINAPI KeyPairDlgProc(HWND hWnd, UINT message, WPARAM wParam, @@ -1075,20 +1072,18 @@ void PushLine(const char *s) { snprintf(buf, sizeof(buf), "[%.2d:%.2d:%.2d] ", t.wHour, t.wMinute, t.wSecond); size_t tl = strlen(buf); - char *x = (char*)malloc(tl + l + 3); + char *x = (char*)malloc(tl + l + 1); if (!x) return; memcpy(x, buf, tl); memcpy(x + tl, s, l); - x[l + tl] = '\r'; - x[l + tl + 1] = '\n'; - x[l + tl + 2] = '\0'; + x[l + tl] = '\0'; g_backend_delegate->OnLogLine((const char**)&x); free(x); } void EnsureConfigDirCreated() { char fullname[1024]; - if (GetConfigFullName("", fullname, sizeof(fullname))) + if (ExpandConfigPath("", fullname, sizeof(fullname))) CreateDirectory(fullname, NULL); } @@ -1358,7 +1353,7 @@ static void DrawGraph(HDC dc, const RECT *rr, StatsCollector::TimeSeries **sourc for (size_t j = 0; j != num_source; j++) { const StatsCollector::TimeSeries *src = sources[j]; for (size_t i = 0; i != src->size; i++) - mx = max(mx, src->data[i]); + mx = std::max(mx, src->data[i]); } int topval = (int)(mx + 0.5f); // round it appropriately @@ -1432,11 +1427,13 @@ static void DrawInGraphBox(HDC hdc, int w, int h) { for (int i = 0; i < graph->num_charts; i++) { time_series_ptr[i] = &time_series[i]; time_series[i].shift = 0; - time_series[i].size = *(uint32*)ptr; + + uint32 size = *(uint32*)ptr; + time_series[i].size = size; time_series[i].data = (float*)(ptr + 4); - ptr += 4 + *(uint32*)ptr * 4; - if (ptr - (uint8*)graph > graph->total_size) + if ((ptr - (uint8*)graph) + 4 + (uint64)size * 4 > graph->total_size) break; + ptr += 4 + size * 4; } num_charts = graph->num_charts; } @@ -1517,9 +1514,7 @@ static const char *GetAdvancedInfoValue(char buffer[256], int i) { case 0: { if (IsOnlyZeros(g_backend->public_key(), 32)) return ""; - char *str = (char*)base64_encode(g_backend->public_key(), 32, NULL); - snprintf(buffer, 256, "%s", str); - free(str); + base64_encode(g_backend->public_key(), 32, buffer, 256, NULL); return buffer; } case 1: { @@ -1764,7 +1759,7 @@ int WINAPI WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance, LPSTR lpCmdLine // Check if the app is already running. g_runonce_mutex = CreateMutexA(0, FALSE, "TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90"); - if (GetLastError() == ERROR_ALREADY_EXISTS) { + if (GetLastError() == ERROR_ALREADY_EXISTS&&0) { HWND window = FindWindow("TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90", NULL); DWORD_PTR result; if (!window || !SendMessageTimeout(window, WM_USER + 10, 0, 0, SMTO_BLOCK, 3000, &result) || result != 31337) { diff --git a/util.cpp b/util.cpp index 2269a3f..2d7fe8e 100644 --- a/util.cpp +++ b/util.cpp @@ -17,46 +17,55 @@ #include #endif +#include #include #include "tunsafe_types.h" -static char base64_alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +static const char kBase64Alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; -uint8 *base64_encode(const uint8 *input, size_t length, size_t *out_length) { - uint32 a; - size_t size; - uint8 *result, *r; +char *base64_encode(const uint8 *input, size_t length, char *output, size_t output_size, size_t *out_length) { + char *result, *r; const uint8 *end; - size = length * 4 / 3 + 4 + 1; - r = result = (byte*)malloc(size); + size_t size = (length + 2) / 3 * 4 + 1; + if (output != NULL) { + result = output; + assert(output_size >= size); + if (output_size < size) { + *result = 0; + return NULL; + } + } else { + result = (char*)malloc(size); + if (!result) + return NULL; + } + r = result; end = input + length - 3; - // Encode full blocks while (input <= end) { - a = (input[0] << 16) + (input[1] << 8) + input[2]; + uint32 a = (input[0] << 16) + (input[1] << 8) + input[2]; input += 3; - r[0] = base64_alphabet[(a >> 18)/* & 0x3F*/]; - r[1] = base64_alphabet[(a >> 12) & 0x3F]; - r[2] = base64_alphabet[(a >> 6) & 0x3F]; - r[3] = base64_alphabet[(a) & 0x3F]; + r[0] = kBase64Alphabet[(a >> 18)/* & 0x3F*/]; + r[1] = kBase64Alphabet[(a >> 12) & 0x3F]; + r[2] = kBase64Alphabet[(a >> 6) & 0x3F]; + r[3] = kBase64Alphabet[(a) & 0x3F]; r += 4; } - if (input == end + 2) { - a = input[0] << 4; - r[0] = base64_alphabet[(a >> 6) /*& 0x3F*/]; - r[1] = base64_alphabet[(a) & 0x3F]; + uint32 a = input[0] << 4; + r[0] = kBase64Alphabet[(a >> 6) /*& 0x3F*/]; + r[1] = kBase64Alphabet[(a) & 0x3F]; r[2] = '='; r[3] = '='; r += 4; } else if (input == end + 1) { - a = (input[0] << 10) + (input[1] << 2); - r[0] = base64_alphabet[(a >> 12) /*& 0x3F*/]; - r[1] = base64_alphabet[(a >> 6) & 0x3F]; - r[2] = base64_alphabet[(a) & 0x3F]; + uint32 a = (input[0] << 10) + (input[1] << 2); + r[0] = kBase64Alphabet[(a >> 12) /*& 0x3F*/]; + r[1] = kBase64Alphabet[(a >> 6) & 0x3F]; + r[2] = kBase64Alphabet[(a) & 0x3F]; r[3] = '='; r += 4; } @@ -250,13 +259,6 @@ void RERROR(const char *msg, ...) { } } -void rinfo(const char *msg, ...) { - printf("muu"); -} - -void rinfo2(const char *msg) { - printf("muu2"); -} void RINFO(const char *msg, ...) { va_list va; @@ -296,4 +298,115 @@ size_t my_strlcpy(char *dst, size_t dstsize, const char *src) { memcpy(dst, src, lenx); } return len; -} \ No newline at end of file +} + +void OsGetRandomBytes(uint8 *data, size_t data_size) { +#if defined(OS_WIN) + static BOOLEAN(APIENTRY *pfn)(void*, ULONG); + if (!pfn) { + pfn = (BOOLEAN(APIENTRY *)(void*, ULONG))GetProcAddress(LoadLibrary("ADVAPI32.DLL"), "SystemFunction036"); + if (!pfn) + ExitProcess(1); + } + if (!pfn(data, (ULONG)data_size)) { + ExitProcess(1); + return; + } +#elif defined(OS_POSIX) + int fd = open("/dev/urandom", O_RDONLY); + if (fd < 0) { + fprintf(stderr, "/dev/urandom failed\n"); + exit(1); + } + int r = read(fd, data, data_size); + if (r != data_size) { + fprintf(stderr, "/dev/urandom failed\n"); + exit(1); + } + close(fd); +#else +#error +#endif +} + +bool ParseConfigKeyValue(char *m, std::vector> *result) { + for (;;) { + char *nl = strchr(m, '\n'); + if (nl) + *nl = 0; + if (*m != '\0') { + char *value = strchr(m, '='); + if (value == NULL) + return false; + *value++ = '\0'; + result->emplace_back(m, value); + } + if (!nl) + return true; + m = nl + 1; + } +} + +bool ParseHexString(const char *text, void *data, size_t data_size) { + size_t len = strlen(text); + if (len != data_size * 2) + return false; + for (size_t i = 0; i < data_size; i++) { + uint32 c = text[i * 2 + 0]; + if (c >= '0' && c <= '9') { + c -= '0'; + } else if ((c |= 32) >= 'a' && c <= 'f') { + c -= 'a' - 10; + } else { + return false; + } + uint32 d = text[i * 2 + 1]; + if (d >= '0' && d <= '9') { + d -= '0'; + } else if ((d |= 32) >= 'a' && d <= 'f') { + d -= 'a' - 10; + } else { + return false; + } + ((uint8*)data)[i] = c * 16 + d; + } + return true; +} + +bool is_space(uint8_t c) { + return c == ' ' || c == '\r' || c == '\n' || c == '\t'; +} + +void SplitString(char *s, int separator, std::vector *components) { + components->clear(); + for (;;) { + while (is_space(*s)) s++; + char *d = strchr(s, separator); + if (d == NULL) { + if (*s) + components->push_back(s); + return; + } + *d = 0; + char *e = d; + while (e > s && is_space(e[-1])) + *--e = 0; + components->push_back(s); + s = d + 1; + } +} + +void PrintHexString(const void *data, size_t data_size, char *result) { + for (size_t i = 0; i < data_size; i++) { + uint8 c = ((uint8*)data)[i]; + *result++ = "0123456789abcdef"[c >> 4]; + *result++ = "0123456789abcdef"[c & 0xF]; + } + *result++ = 0; +} + +bool ParseBase64Key(const char *s, uint8 key[32]) { + size_t size = 32; + return base64_decode((uint8*)s, strlen(s), key, &size) && size == 32; +} + diff --git a/util.h b/util.h index e0846f3..128d376 100644 --- a/util.h +++ b/util.h @@ -3,7 +3,7 @@ #pragma once #include "tunsafe_types.h" -uint8 *base64_encode(const uint8 *input, size_t length, size_t *out_length); +char *base64_encode(const uint8 *input, size_t length, char *output, size_t output_size, size_t *actual_size); bool base64_decode(uint8 *in, size_t inLen, uint8 *out, size_t *outLen); bool IsOnlyZeros(const uint8 *data, size_t data_size); @@ -17,9 +17,28 @@ char *my_strndup(const char *p, size_t size); size_t my_strlcpy(char *dst, size_t dstsize, const char *src); - template static inline T postinc(T&x, U v) { T t = x; x += v; return t; } + +template static inline T exch(T&x, U v) { + T t = x; + x = v; + return t; +} + +template static inline T exch_null(T&x) { + T t = x; + x = NULL; + return t; +} + +bool is_space(uint8_t c); +void OsGetRandomBytes(uint8 *dst, size_t dst_size); +bool ParseConfigKeyValue(char *m, std::vector> *result); +bool ParseHexString(const char *text, void *data, size_t data_size); +void PrintHexString(const void *data, size_t data_size, char *result); +void SplitString(char *s, int separator, std::vector *components); +bool ParseBase64Key(const char *s, uint8 key[32]); \ No newline at end of file diff --git a/util_win32.cpp b/util_win32.cpp index 1dec101..821372c 100644 --- a/util_win32.cpp +++ b/util_win32.cpp @@ -297,6 +297,23 @@ size_t GetConfigPath(char *path, size_t path_size) { return last + 7 - path; } +bool ExpandConfigPath(const char *basename, char *fullname, size_t fullname_size) { + size_t len = strlen(basename); + + if (FindFilenameComponent(basename)[0]) { + if (len >= fullname_size) + return false; + memcpy(fullname, basename, len + 1); + return true; + } + size_t clen = GetConfigPath(fullname, fullname_size); + if (clen == 0 || clen + len >= fullname_size) + return false; + memcpy(fullname + clen, basename, (len + 1) * sizeof(fullname[0])); + return true; +} + + static bool ContainsDotDot(const char *path) { for (uint8 last = 0, cur; (cur = path[0]) != '\0'; last = cur, path++) if (cur == '.' && last == cur) @@ -308,7 +325,7 @@ bool EnsureValidConfigPath(const char *path) { char buf[1024]; size_t len = GetConfigPath(buf, sizeof(buf)); - return (len != 0) && (strlen(path) > len && memcmp(path, buf, len) == 0 && !ContainsDotDot(path + len)); + return (len != 0) && (strlen(path) > len && _strnicmp(path, buf, len) == 0 && !ContainsDotDot(path + len)); } bool RunProcessAsAdminWithArgs(const char *args, bool wait_for_exit) { @@ -376,3 +393,65 @@ RECT MakeRect(int l, int t, int r, int b) { RECT rr = { l, t, r, b }; return rr; } + +// Retrieve the device path to the TAP adapter. + +#define kAdapterKeyName "SYSTEM\\CurrentControlSet\\Control\\Class\\{4D36E972-E325-11CE-BFC1-08002BE10318}" +#define kNetworkConnectionsKeyName "SYSTEM\\CurrentControlSet\\Control\\Network\\{4D36E972-E325-11CE-BFC1-08002BE10318}" +#define kTapComponentId "tap0901" + + +void GetTapAdapterInfo(std::vector *result) { + LONG err; + HKEY adapter_key, device_key, network_connections_key; + bool retval = false; + GuidAndDevName gn; + + err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, kAdapterKeyName, 0, KEY_READ, &adapter_key); + if (err != ERROR_SUCCESS) { + RERROR("GetTapAdapterName: RegOpenKeyEx failed: 0x%X", GetLastError()); + return; + } + for (int i = 0; !retval; i++) { + char keyname[64 + sizeof(kAdapterKeyName) + 1 + 32 /* some margin */]; + char value[64]; + DWORD len = sizeof(value), type; + err = RegEnumKeyEx(adapter_key, i, value, &len, NULL, NULL, NULL, NULL); + if (err == ERROR_NO_MORE_ITEMS) + break; + if (err != ERROR_SUCCESS) { + RERROR("GetTapAdapterName: RegEnumKeyEx failed: 0x%X", GetLastError()); + break; + } + snprintf(keyname, sizeof(keyname), "%s\\%s", kAdapterKeyName, value); + err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, keyname, 0, KEY_READ, &device_key); + if (err == ERROR_SUCCESS) { + len = sizeof(value); + err = RegQueryValueEx(device_key, "ComponentId", NULL, &type, (LPBYTE)value, &len); + if (err == ERROR_SUCCESS && type == REG_SZ && !memcmp(value, kTapComponentId, sizeof(kTapComponentId))) { + len = sizeof(gn.guid); + err = RegQueryValueEx(device_key, "NetCfgInstanceId", NULL, &type, (LPBYTE)gn.guid, &len); + if (err == ERROR_SUCCESS && type == REG_SZ) { + gn.guid[sizeof(gn.guid) - 1] = 0; + gn.name[0] = 0; + + snprintf(keyname, sizeof(keyname), "%s\\%s\\Connection", kNetworkConnectionsKeyName, gn.guid); + err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, keyname, 0, KEY_READ, &network_connections_key); + if (err == ERROR_SUCCESS) { + len = sizeof(gn.guid); + err = RegQueryValueEx(network_connections_key, "Name", NULL, &type, (LPBYTE)gn.name, &len); + if (err == ERROR_SUCCESS && type == REG_SZ) { + gn.name[sizeof(gn.guid) - 1] = 0; + } else { + gn.name[0] = 0; + } + RegCloseKey(network_connections_key); + } + result->push_back(gn); + } + } + RegCloseKey(device_key); + } + } + RegCloseKey(adapter_key); +} diff --git a/util_win32.h b/util_win32.h index 8497903..0ae16ec 100644 --- a/util_win32.h +++ b/util_win32.h @@ -1,6 +1,7 @@ // SPDX-License-Identifier: AGPL-1.0-only // Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. #include "tunsafe_types.h" +#include #pragma once const char *FindFilenameComponent(const char *s); @@ -47,6 +48,7 @@ void ShellExecuteFromExplorer( int nShowCmd = SW_SHOWNORMAL); size_t GetConfigPath(char *path, size_t path_size); +bool ExpandConfigPath(const char *basename, char *fullname, size_t fullname_size); bool EnsureValidConfigPath(const char *path); bool RunProcessAsAdminWithArgs(const char *args, bool wait_for_exit); @@ -54,3 +56,8 @@ bool RestartProcessAsAdministrator(); bool SetClipboardString(const char *string); RECT GetParentRect(HWND wnd); RECT MakeRect(int l, int t, int r, int b); +struct GuidAndDevName { + char guid[40]; + char name[64]; +}; +void GetTapAdapterInfo(std::vector *result); diff --git a/wireguard.cpp b/wireguard.cpp index ceeb068..fb616e3 100644 --- a/wireguard.cpp +++ b/wireguard.cpp @@ -15,8 +15,7 @@ #include "ipzip2/ipzip2.h" #include "wireguard.h" #include "wireguard_config.h" - -uint64 OsGetMilliseconds(); +#include "util.h" enum { IPV4_HEADER_SIZE = 20, @@ -36,23 +35,24 @@ WireguardProcessor::WireguardProcessor(UdpInterface *udp, TunInterface *tun, Pro add_routes_mode_ = true; dns_blocking_ = true; internet_blocking_ = kBlockInternet_Default; - + is_started_ = false; stats_last_bytes_in_ = 0; stats_last_bytes_out_ = 0; stats_last_ts_ = OsGetMilliseconds(); - - main_thread_scheduled_ = NULL; - main_thread_scheduled_last_ = &main_thread_scheduled_; } WireguardProcessor::~WireguardProcessor() { } void WireguardProcessor::SetListenPort(int listen_port) { - listen_port_ = listen_port; + if (listen_port_ != listen_port) { + listen_port_ = listen_port; + if (is_started_ && !ConfigureUdp()) { + RINFO("ConfigureUdp failed"); + } + } } - void WireguardProcessor::AddDnsServer(const IpAddr &sin) { std::vector *target = (sin.sin.sin_family == AF_INET6) ? &dns6_addr_ : &dns_addr_; target->push_back(sin); @@ -66,6 +66,11 @@ bool WireguardProcessor::SetTunAddress(const WgCidrAddr &addr) { return true; } +void WireguardProcessor::ClearTunAddress() { + tun_addr_.size = 0; + tun6_addr_.size = 0; +} + void WireguardProcessor::AddExcludedIp(const WgCidrAddr &cidr_addr) { excluded_ips_.push_back(cidr_addr); } @@ -129,23 +134,29 @@ static bool IsWgCidrAddrSubsetOf(const WgCidrAddr &inner, const WgCidrAddr &oute } bool WireguardProcessor::Start() { + return ConfigureUdp() && ConfigureTun(); +} + +bool WireguardProcessor::ConfigureUdp() { assert(dev_.IsMainThread()); - if (!udp_->Initialize(listen_port_)) - return false; + return udp_->Configure(listen_port_); +} - if (tun_addr_.size != 32) { - RERROR("No IPv4 address configured"); - return false; - } - - if (tun_addr_.cidr >= 31) { - RERROR("TAP is not compatible CIDR /31 or /32. Changing to /24"); - tun_addr_.cidr = 24; - } +bool WireguardProcessor::ConfigureTun() { + assert(dev_.IsMainThread()); TunInterface::TunConfig config = {0}; - config.ip = ReadBE32(tun_addr_.addr); - config.cidr = tun_addr_.cidr; + if (tun_addr_.size == 32) { + if (tun_addr_.cidr >= 31) { + RERROR("TAP is not compatible CIDR /31 or /32. Changing to /24"); + tun_addr_.cidr = 24; + } + config.ip = ReadBE32(tun_addr_.addr); + config.cidr = tun_addr_.cidr; + } else { + RERROR("No IPv4 address configured"); + } + config.mtu = mtu_; config.pre_post_commands = pre_post_; config.excluded_ips = excluded_ips_; @@ -205,7 +216,7 @@ bool WireguardProcessor::Start() { config.ipv6_dns = dns6_addr_; TunInterface::TunConfigOut config_out; - if (!tun_->Initialize(std::move(config), &config_out)) + if (!tun_->Configure(std::move(config), &config_out)) return false; SetupCompressionHeader(dev_.compression_header()); @@ -221,6 +232,7 @@ bool WireguardProcessor::Start() { } } + is_started_ = true; return true; } @@ -395,22 +407,6 @@ getout: FreePacket(packet); } -void WgPeer::AddPacketToPeerQueue(Packet *packet) { - assert(IsPeerLocked()); - // Keep only the first MAX_QUEUED_PACKETS packets. - while (num_queued_packets_ >= MAX_QUEUED_PACKETS_PER_PEER) { - Packet *packet = first_queued_packet_; - first_queued_packet_ = packet->next; - num_queued_packets_--; - FreePacket(packet); - } - // Add the packet to the out queue that will get sent once handshake completes - *last_queued_packet_ptr_ = packet; - last_queued_packet_ptr_ = &packet->next; - packet->next = NULL; - num_queued_packets_++; -} - // This function must be called with the peer lock held. It will remove the lock void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet) { assert(peer->IsPeerLocked()); @@ -427,11 +423,17 @@ void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Pac if ((keypair = peer->curr_keypair_) == NULL || (send_ctr = keypair->send_ctr) >= REJECT_AFTER_MESSAGES) { - peer->AddPacketToPeerQueue(packet); + // If RemovePeer has been called then discard any packets currently being written to it. + // curr_keypair_ is NULL when RemovePeer has been called so it's safe to do this here. + if (peer->marked_for_delete_) + goto getout_discard; + + peer->AddPacketToPeerQueue_Locked(packet); WG_RELEASE_LOCK(peer->mutex_); - ScheduleNewHandshake(peer); + peer->ScheduleNewHandshake(); return; } + assert(!peer->marked_for_delete_); stats_.tun_bytes_in += size; stats_.tun_packets_in++; @@ -524,13 +526,14 @@ add_padding: WriteLE32(write -= 4, keypair->remote_key_id); *--write = tag; - // Not using any fields from now on - WG_RELEASE_LOCK(peer->mutex_); - header_size = data - write; + packet->size = (int)(size + header_size + keypair->auth_tag_length); + peer->tx_bytes_ += packet->size; stats_.compression_wg_saved_out += (int64)16 - header_size; packet->data = data - header_size; - packet->size = (int)(size + header_size + keypair->auth_tag_length); + + // Not using any fields from now on + WG_RELEASE_LOCK(peer->mutex_); // todo: figure out what to actually use as ad. ad = write_after_ack_header; @@ -540,6 +543,9 @@ need_big_packet: #else { #endif // #if WITH_SHORT_HEADERS + packet->size = (int)(size + sizeof(MessageData) + keypair->auth_tag_length); + peer->tx_bytes_ += packet->size; + // Not using any fields from now on WG_RELEASE_LOCK(peer->mutex_); @@ -547,7 +553,6 @@ need_big_packet: ((MessageData*)data)[-1].receiver_id = keypair->remote_key_id; ((MessageData*)data)[-1].counter = ToLE64(send_ctr); packet->data = data - sizeof(MessageData); - packet->size = (int)(size + sizeof(MessageData) + keypair->auth_tag_length); ad = NULL; ad_len = 0; } @@ -556,7 +561,7 @@ need_big_packet: DoWriteUdpPacket(packet); if (want_handshake) - ScheduleNewHandshake(peer); + peer->ScheduleNewHandshake(); return; getout_discard: @@ -608,38 +613,32 @@ void WireguardProcessor::DoWriteUdpPacket(Packet *packet) { ScrambleUnscrambleAndWrite(packet, &dev_.header_obfuscation_key_, udp_); } -void WireguardProcessor::ScheduleNewHandshake(WgPeer *peer) { - if (peer->main_thread_scheduled_.fetch_or(WgPeer::kMainThreadScheduled_ScheduleHandshake) == 0) { - peer->main_thread_scheduled_next_ = NULL; - WG_ACQUIRE_LOCK(main_thread_scheduled_lock_); - *main_thread_scheduled_last_ = peer; - main_thread_scheduled_last_ = &peer->main_thread_scheduled_next_; - WG_RELEASE_LOCK(main_thread_scheduled_lock_); - // todo: in multithreaded impl need to trigger |RunAllMainThreadScheduled| to get called - } -} - void WireguardProcessor::RunAllMainThreadScheduled() { + WgPeer *peer, *next; assert(dev_.IsMainThread()); - if (main_thread_scheduled_ == NULL) + if (dev_.main_thread_scheduled_ == NULL) return; - WG_ACQUIRE_LOCK(main_thread_scheduled_lock_); - WgPeer *peer = main_thread_scheduled_; - main_thread_scheduled_ = NULL; - main_thread_scheduled_last_ = &main_thread_scheduled_; - WG_RELEASE_LOCK(main_thread_scheduled_lock_); + WG_ACQUIRE_LOCK(dev_.main_thread_scheduled_lock_); + peer = dev_.main_thread_scheduled_; + dev_.main_thread_scheduled_ = NULL; + dev_.main_thread_scheduled_last_ = &dev_.main_thread_scheduled_; + WG_RELEASE_LOCK(dev_.main_thread_scheduled_lock_); + + for (; peer; peer = next) { + // todo: for the multithreaded use case figure out whether to use atomic_thread_fence here, + // because we need to read this next value before any other thread sees the 0 we write + // to peer->main_thread_scheduled_. + next = peer->main_thread_scheduled_next_; + if (peer->marked_for_delete_) + continue; - while (peer) { - // todo: for the multithreaded use case figure out whether to use atomic_thread_fence here. - WgPeer *next = peer->main_thread_scheduled_next_; uint32 ev = peer->main_thread_scheduled_.exchange(0); if (ev & WgPeer::kMainThreadScheduled_ScheduleHandshake) { peer->handshake_attempts_ = 0; SendHandshakeInitiation(peer); } - peer = next; } } @@ -658,6 +657,7 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) { procdel_->OnConnectionRetry(attempts); peer->OnHandshakeInitSent(); packet->addr = peer->endpoint_; + peer->tx_bytes_ += packet->size; WG_RELEASE_LOCK(peer->mutex_); DoWriteUdpPacket(packet); if (attempts > 1 && attempts <= 20) @@ -696,19 +696,21 @@ void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) { #endif // WITH_SHORT_HEADERS } else if (type == MESSAGE_HANDSHAKE_COOKIE) { assert(dev_.IsMainThread()); - if (packet->size != sizeof(MessageHandshakeCookie)) + if (packet->size != sizeof(MessageHandshakeCookie) || !dev_.is_private_key_initialized()) goto invalid_size; HandleHandshakeCookiePacket(packet); } else if (type == MESSAGE_HANDSHAKE_INITIATION) { assert(dev_.IsMainThread()); - if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeInitiation)) : (packet->size != sizeof(MessageHandshakeInitiation))) + if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeInitiation)) : (packet->size != sizeof(MessageHandshakeInitiation)) || + !dev_.is_private_key_initialized()) goto invalid_size; stats_.handshakes_in++; if (CheckIncomingHandshakeRateLimit(packet, overload)) HandleHandshakeInitiationPacket(packet); } else if (type == MESSAGE_HANDSHAKE_RESPONSE) { assert(dev_.IsMainThread()); - if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeResponse)) : (packet->size != sizeof(MessageHandshakeResponse))) + if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeResponse)) : (packet->size != sizeof(MessageHandshakeResponse)) || + !dev_.is_private_key_initialized()) goto invalid_size; if (CheckIncomingHandshakeRateLimit(packet, overload)) HandleHandshakeResponsePacket(packet); @@ -749,6 +751,8 @@ void WgPeer::CopyEndpointToPeer_Locked(WgKeypair *keypair, const IpAddr *addr) { #if WITH_SHORT_HEADERS void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packet) { + assert(dev_.IsMainOrDataThread()); + uint8 *data = packet->data + 1; size_t bytes_left = packet->size - 1; WgKeypair *keypair; @@ -832,6 +836,8 @@ void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packe WG_ACQUIRE_LOCK(keypair->peer->mutex_); + keypair->peer->rx_bytes_ += packet->size; + if (keypair->recv_key_state == WgKeypair::KEY_INVALID) goto getout_unlock; @@ -896,7 +902,7 @@ void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *key WgKeypair *curr_keypair = peer->curr_keypair_; if (curr_keypair && curr_keypair->recv_key_state == WgKeypair::KEY_WANT_REFRESH) { curr_keypair->recv_key_state = WgKeypair::KEY_DID_REFRESH; - ScheduleNewHandshake(peer); + peer->ScheduleNewHandshake(); } if (data_size == 0) { @@ -965,6 +971,8 @@ getout: } void WireguardProcessor::HandleDataPacket(Packet *packet) { + assert(dev_.IsMainOrDataThread()); + uint8 *data = packet->data; size_t data_size = packet->size; uint32 key_id = ((MessageData*)data)->receiver_id; @@ -984,6 +992,7 @@ getout: } WG_ACQUIRE_LOCK(keypair->peer->mutex_); + keypair->peer->rx_bytes_ += data_size; if (keypair->recv_key_state == WgKeypair::KEY_INVALID) { stats_.error_key_id++; WG_RELEASE_LOCK(keypair->peer->mutex_); @@ -993,6 +1002,8 @@ getout: WG_RELEASE_LOCK(keypair->peer->mutex_); goto getout; } else { + assert(!keypair->peer->marked_for_delete_); + WgPeer::CopyEndpointToPeer_Locked(keypair, &packet->addr); HandleAuthenticatedDataPacket_WillUnlock(keypair, packet, data + sizeof(MessageData), data_size - sizeof(MessageData) - keypair->auth_tag_length); } @@ -1119,7 +1130,7 @@ void WireguardProcessor::SecondLoop() { uint32 mask; { WG_SCOPED_LOCK(peer->mutex_); - mask = peer->CheckTimeouts(now); + mask = peer->CheckTimeouts_Locked(now); if (mask == 0) continue; if (mask & WgPeer::ACTION_SEND_KEEPALIVE) diff --git a/wireguard.h b/wireguard.h index d357d41..4e4bfc1 100644 --- a/wireguard.h +++ b/wireguard.h @@ -66,6 +66,7 @@ enum InternetBlockState { }; class WireguardProcessor { + friend class WgConfig; public: WireguardProcessor(UdpInterface *udp, TunInterface *tun, ProcessorDelegate *procdel); ~WireguardProcessor(); @@ -73,13 +74,14 @@ public: void SetListenPort(int listen_port); void AddDnsServer(const IpAddr &sin); bool SetTunAddress(const WgCidrAddr &addr); + void ClearTunAddress(); void AddExcludedIp(const WgCidrAddr &cidr_addr); void SetMtu(int mtu); void SetAddRoutesMode(bool mode); void SetDnsBlocking(bool dns_blocking); void SetInternetBlocking(InternetBlockState internet_blocking); void SetHeaderObfuscation(const char *key); - + void HandleTunPacket(Packet *packet); void HandleUdpPacket(Packet *packet, bool overload); static bool IsMainThreadPacket(Packet *packet); @@ -91,6 +93,9 @@ public: bool Start(); + bool ConfigureUdp(); + bool ConfigureTun(); + WgDevice &dev() { return dev_; } TunInterface::PrePostCommands &prepost() { return pre_post_; } const WgCidrAddr &tun_addr() { return tun_addr_; } @@ -100,7 +105,6 @@ private: void DoWriteUdpPacket(Packet *packet); void WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet); void SendHandshakeInitiation(WgPeer *peer); - void ScheduleNewHandshake(WgPeer *peer); void SendKeepalive_Locked(WgPeer *peer); void SendQueuedPackets_Locked(WgPeer *peer); @@ -110,29 +114,25 @@ private: void HandleDataPacket(Packet *packet); void HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet, uint8 *data, size_t data_size); - void HandleShortHeaderFormatPacket(uint32 tag, Packet *packet); - bool CheckIncomingHandshakeRateLimit(Packet *packet, bool overload); - bool HandleIcmpv6NeighborSolicitation(const byte *data, size_t data_size); - void SetupCompressionHeader(WgPacketCompressionVer01 *c); void NotifyHandshakeComplete(); - int listen_port_; - ProcessorDelegate *procdel_; TunInterface *tun_; UdpInterface *udp_; - int mtu_; - WgProcessorStats stats_; + + uint16 listen_port_; + uint16 mtu_; bool dns_blocking_; uint8 internet_blocking_; bool add_routes_mode_; bool network_discovery_spoofing_; bool did_have_first_handshake_; + bool is_started_; uint8 network_discovery_mac_[6]; WgDevice dev_; @@ -140,14 +140,12 @@ private: WgCidrAddr tun_addr_; WgCidrAddr tun6_addr_; + WgProcessorStats stats_; + std::vector dns_addr_, dns6_addr_; TunInterface::PrePostCommands pre_post_; - // Queue of things scheduled to run on the main thread. - WG_DECLARE_LOCK(main_thread_scheduled_lock_); - WgPeer *main_thread_scheduled_, **main_thread_scheduled_last_; - uint64 stats_last_bytes_in_, stats_last_bytes_out_; uint64 stats_last_ts_; diff --git a/wireguard_config.cpp b/wireguard_config.cpp index 6f34d5b..2bb59da 100644 --- a/wireguard_config.cpp +++ b/wireguard_config.cpp @@ -45,12 +45,26 @@ char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]) { return buf; } + +char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]) { + if (addr.size == 32) { + print_ip_prefix(buf, AF_INET, addr.addr, addr.cidr); + } else if (addr.size == 128) { + print_ip_prefix(buf, AF_INET6, addr.addr, addr.cidr); + } else { + buf[0] = 0; + } + return buf; +} + + + struct Addr { byte addr[4]; uint8 cidr; }; -static bool ParseCidrAddr(char *s, WgCidrAddr *out) { +bool ParseCidrAddr(char *s, WgCidrAddr *out) { char *slash = strchr(s, '/'); if (!slash) return false; @@ -92,15 +106,6 @@ bool DnsResolver::Resolve(const char *hostname, IpAddr *result) { char buf[kSizeOfAddress]; memset(result, 0, sizeof(IpAddr)); - if (inet_pton(AF_INET6, hostname, &result->sin6.sin6_addr) == 1) { - result->sin.sin_family = AF_INET6; - return true; - } - - if (inet_pton(AF_INET, hostname, &result->sin.sin_addr) == 1) { - result->sin.sin_family = AF_INET; - return true; - } // First check cache for (auto it = cache_.begin(); it != cache_.end(); ++it) { @@ -145,10 +150,7 @@ bool DnsResolver::Resolve(const char *hostname, IpAddr *result) { } } - - - -static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver) { +bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver) { memset(sin, 0, sizeof(IpAddr)); if (*s == '[') { char *end = strchr(s, ']'); @@ -168,7 +170,11 @@ static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver) if (!x) return false; *x = 0; - if (!resolver->Resolve(s, sin)) { + if (inet_pton(AF_INET, s, &sin->sin.sin_addr) == 1) { + sin->sin.sin_family = AF_INET; + } else if (!resolver) { + return false; + } else if (!resolver->Resolve(s, sin)) { RERROR("Unable to resolve %s", s); return false; } @@ -177,18 +183,19 @@ static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver) } static bool ParseSockaddrInWithoutPort(char *s, IpAddr *sin, DnsResolver *resolver) { - if (!resolver->Resolve(s, sin)) { + if (inet_pton(AF_INET6, s, &sin->sin6.sin6_addr) == 1) { + sin->sin.sin_family = AF_INET6; + return true; + } else if (inet_pton(AF_INET, s, &sin->sin.sin_addr) == 1) { + sin->sin.sin_family = AF_INET; + return true; + } else if (!resolver->Resolve(s, sin)) { RERROR("Unable to resolve %s", s); return false; } return true; } -static bool ParseBase64Key(const char *s, uint8 key[32]) { - size_t size = 32; - return base64_decode((uint8*)s, strlen(s), key, &size) && size == 32; -} - class WgFileParser { public: WgFileParser(WireguardProcessor *wg, DnsResolver *resolver) : wg_(wg), dns_resolver_(resolver) {} @@ -197,7 +204,7 @@ public: void FinishGroup(); struct Peer { - uint8 pub[32]; + WgPublicKey pub; uint8 psk[32]; }; Peer pi_; @@ -206,29 +213,6 @@ public: bool had_interface_ = false; }; -bool is_space(uint8_t c) { - return c == ' ' || c == '\r' || c == '\n' || c == '\t'; -} - - -void SplitString(char *s, int separator, std::vector *components) { - for (;;) { - while (is_space(*s)) s++; - char *d = strchr(s, separator); - if (d == NULL) { - if (*s) - components->push_back(s); - return; - } - *d = 0; - char *e = d; - while (e > s && is_space(e[-1])) - *--e = 0; - components->push_back(s); - s = d + 1; - } -} - static bool ParseBoolean(const char *str, bool *value) { if (_stricmp(str, "true") == 0 || _stricmp(str, "yes") == 0 || @@ -285,7 +269,7 @@ static int ParseCipherSuite(const char *cipher) { void WgFileParser::FinishGroup() { if (peer_) { - peer_->Initialize(pi_.pub, pi_.psk); + peer_->SetPublicKey(pi_.pub); peer_ = NULL; } } @@ -303,7 +287,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) { if (!ParseBase64Key(value, binkey)) return false; had_interface_ = true; - wg_->dev().Initialize(binkey); + wg_->dev().SetPrivateKey(binkey); } else if (strcmp(key, "ListenPort") == 0) { wg_->SetListenPort(atoi(value)); } else if (strcmp(key, "Address") == 0) { @@ -394,11 +378,12 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) { return true; } if (strcmp(key, "PublicKey") == 0) { - if (!ParseBase64Key(value, pi_.pub)) + if (!ParseBase64Key(value, pi_.pub.bytes)) return false; } else if (strcmp(key, "PresharedKey") == 0) { if (!ParseBase64Key(value, pi_.psk)) return false; + peer_->SetPresharedKey(pi_.psk); } else if (strcmp(key, "AllowedIPs") == 0) { SplitString(value, ',', &ss); for (size_t i = 0; i < ss.size(); i++) { @@ -412,7 +397,8 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) { return false; peer_->SetEndpoint(sin); } else if (strcmp(key, "PersistentKeepalive") == 0) { - peer_->SetPersistentKeepalive(atoi(value)); + if (!peer_->SetPersistentKeepalive(atoi(value))) + return false; } else if (strcmp(key, "AllowMulticast") == 0) { bool b; if (!ParseBoolean(value, &b)) @@ -524,3 +510,154 @@ bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, DnsR fclose(f); return true; } + + +static void CmsgAppendFmt(std::string *result, const char *fmt, ...) { + va_list va; + char buf[256]; + va_start(va, fmt); + vsnprintf(buf, sizeof(buf), fmt, va); + (*result) += buf; + (*result) += '\n'; + va_end(va); +} + +static void CmsgAppendHex(std::string *result, const char *key, const void *data, size_t data_size) { + char *tmp = (char*)alloca(data_size * 2 + 2); + PrintHexString(data, data_size, tmp + 1); + tmp[0] = '='; + tmp[data_size * 2 + 1] = '\n'; + (*result) += key; + result->append(tmp, data_size * 2 + 2); +} + +void WgConfig::HandleConfigurationProtocolGet(WireguardProcessor *proc, std::string *result) { + char buf[kSizeOfAddress]; + + CmsgAppendHex(result, "private_key", proc->dev_.s_priv_, sizeof(proc->dev_.s_priv_)); + if (proc->listen_port_) + CmsgAppendFmt(result, "listen_port=%d", proc->listen_port_); + if (proc->tun_addr_.size == 32) + CmsgAppendFmt(result, "address=%s", PrintWgCidrAddr(proc->tun_addr_, buf)); + if (proc->tun6_addr_.size == 128) + CmsgAppendFmt(result, "address=%s", PrintWgCidrAddr(proc->tun6_addr_, buf)); + + for (WgPeer *peer = proc->dev_.peers_; peer; peer = peer->next_peer_) { + WG_SCOPED_LOCK(peer->lock_); + + CmsgAppendHex(result, "public_key", peer->s_remote_.bytes, sizeof(peer->s_remote_)); + if (!IsOnlyZeros(peer->preshared_key_, sizeof(peer->preshared_key_))) + CmsgAppendHex(result, "preshared_key", peer->preshared_key_, sizeof(peer->preshared_key_)); + if (peer->tx_bytes_ | peer->rx_bytes_) + CmsgAppendFmt(result, "tx_bytes=%lld\nrx_bytes=%lld", peer->tx_bytes_, peer->rx_bytes_); + for (auto it = peer->allowed_ips_.begin(); it != peer->allowed_ips_.end(); ++it) + CmsgAppendFmt(result, "allowed_ip=%s", PrintWgCidrAddr(*it, buf)); + if (peer->persistent_keepalive_ms_) + CmsgAppendFmt(result, "persistent_keepalive_interval=%d", peer->persistent_keepalive_ms_ / 1000); + if (peer->endpoint_.sin.sin_family == AF_INET) + CmsgAppendFmt(result, "endpoint=%s:%d", PrintIpAddr(peer->endpoint_, buf), htons(peer->endpoint_.sin.sin_port)); + else if (peer->endpoint_.sin.sin_family == AF_INET6) + CmsgAppendFmt(result, "endpoint=[%s]:%d", PrintIpAddr(peer->endpoint_, buf), htons(peer->endpoint_.sin6.sin6_port)); + + if (peer->last_complete_handskake_timestamp_) { + uint64 millis_since = OsGetMilliseconds() - peer->last_complete_handskake_timestamp_; + uint64 when = time(NULL) - millis_since / 1000; + CmsgAppendFmt(result, "last_handshake_time_sec=%lld", when); + } + } + CmsgAppendFmt(result, "protocol_version=1"); +} + +bool WgConfig::HandleConfigurationProtocolMessage(WireguardProcessor *proc, const std::string &&message, std::string *result) { + std::string message_copy(std::move(message)); + std::vector> kv; + bool is_set = false; + bool did_set_address = false; + WgPeer *peer = NULL; + WgCidrAddr cidr_addr; + IpAddr sin; + uint8 buf32[32]; + assert(proc->dev().IsMainThread()); + + result->clear(); + + if (!ParseConfigKeyValue(&message_copy[0], &kv)) + return false; + + for (auto it : kv) { + char *key = it.first, *value = it.second; + if (strcmp(key, "get") == 0) { + if (strcmp(value, "1") != 0) + goto getout_fail; + HandleConfigurationProtocolGet(proc, result); + break; + } else if (strcmp(key, "set") == 0) { + if (strcmp(value, "1") != 0) + goto getout_fail; + is_set = true; + } else if (is_set) { + if (strcmp(key, "private_key") == 0) { + if (!ParseHexString(value, buf32, 32)) goto getout_fail; + proc->dev_.SetPrivateKey(buf32); + } else if (strcmp(key, "listen_port") == 0) { + int new_port = atoi(value); + proc->SetListenPort(new_port); + } else if (strcmp(key, "replace_peers") == 0) { + if (strcmp(value, "true") != 0) goto getout_fail; + proc->dev_.RemoveAllPeers(); + } else if (strcmp(key, "address") == 0) { + if (!ParseCidrAddr(value, &cidr_addr)) goto getout_fail; + if (!did_set_address) { + did_set_address = true; + proc->ClearTunAddress(); + } + if (!proc->SetTunAddress(cidr_addr)) goto getout_fail; + } else if (strcmp(key, "public_key") == 0) { + WgPublicKey pubkey; + if (!ParseHexString(value, pubkey.bytes, 32)) goto getout_fail; + peer = proc->dev_.GetPeerFromPublicKey(pubkey); + if (!peer) { + peer = proc->dev_.AddPeer(); + peer->SetPublicKey(pubkey); + } + } else if (peer != NULL) { + if (strcmp(key, "remove") == 0) { + if (strcmp(value, "true") != 0) goto getout_fail; + peer->RemovePeer(); + peer = NULL; + } else if (strcmp(key, "preshared_key") == 0) { + if (!ParseHexString(value, buf32, 32)) goto getout_fail; + peer->SetPresharedKey(buf32); + } else if (strcmp(key, "endpoint") == 0) { + if (!ParseSockaddrInWithPort(value, &sin, NULL)) goto getout_fail; + peer->SetEndpoint(sin); + } else if (strcmp(key, "persistent_keepalive_interval") == 0) { + if (!peer->SetPersistentKeepalive(atoi(value))) + goto getout_fail; + } else if (strcmp(key, "replace_allowed_ips") == 0) { + if (strcmp(value, "true") != 0) goto getout_fail; + peer->RemoveAllIps(); + } else if (strcmp(key, "allowed_ip") == 0) { + if (!ParseCidrAddr(value, &cidr_addr)) goto getout_fail; + peer->AddIp(cidr_addr); + } + } + } else { + goto getout_fail; + } + } + + // reconfigure the tun interface? + if (did_set_address) { + proc->ConfigureTun(); + } + + result->append("errno=0\n\n"); + return true; + +getout_fail: + (*result) = "errno=1\n\n"; + return false; +} + + diff --git a/wireguard_config.h b/wireguard_config.h index 01d9678..791925d 100644 --- a/wireguard_config.h +++ b/wireguard_config.h @@ -30,11 +30,19 @@ private: }; +class WgConfig { +public: + static bool HandleConfigurationProtocolMessage(WireguardProcessor *proc, const std::string &&message, std::string *result); +private: + static void HandleConfigurationProtocolGet(WireguardProcessor *proc, std::string *result); +}; + bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, DnsResolver *dns_resolver); #define kSizeOfAddress 64 const char *print_ip_prefix(char buf[kSizeOfAddress], int family, const void *ip, int prefixlen); char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]); - +char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]); +bool ParseCidrAddr(char *s, WgCidrAddr *out); #endif // TINYVPN_TINYVPN_H_ diff --git a/wireguard_proto.cpp b/wireguard_proto.cpp index c4a3b50..a7a5567 100644 --- a/wireguard_proto.cpp +++ b/wireguard_proto.cpp @@ -54,18 +54,27 @@ bool ReplayDetector::CheckReplay(uint64 seq_nr) { WgDevice::WgDevice() { peers_ = NULL; + last_peer_ptr_ = &peers_; delegate_ = NULL; header_obfuscation_ = false; + is_private_key_initialized_ = false; next_rng_slot_ = 0; + main_thread_scheduled_ = NULL; + main_thread_scheduled_last_ = &main_thread_scheduled_; memset(&compression_header_, 0, sizeof(compression_header_)); low_resolution_timestamp_ = cookie_secret_timestamp_ = OsGetMilliseconds(); OsGetRandomBytes(cookie_secret_, sizeof(cookie_secret_)); OsGetRandomBytes((uint8*)random_number_input_, sizeof(random_number_input_)); - SetCurrentThreadAsMainThread(); + main_thread_id_ = GetCurrentThreadId(); + + memset(s_priv_, 0, sizeof(s_priv_)); + memset(s_pub_, 0, sizeof(s_pub_)); } WgDevice::~WgDevice() { + assert(IsMainThread()); + RemoveAllPeers(); } void WgDevice::SecondLoop(uint64 now) { @@ -151,7 +160,8 @@ static inline void ComputeHKDF2DH(uint8 ci[WG_HASH_LEN], uint8 k[WG_SYMMETRIC_KE memzero_crypto(dh, sizeof(dh)); } -void WgDevice::Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]) { +void WgDevice::SetPrivateKey(const uint8 private_key[WG_PUBLIC_KEY_LEN]) { + assert(IsMainThread()); // Derive the public key from the private key. memcpy(s_priv_, private_key, sizeof(s_priv_)); curve25519_donna(s_pub_, s_priv_, kCurve25519Basepoint); @@ -162,26 +172,31 @@ void WgDevice::Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]) { kLabelCookie, sizeof(kLabelCookie), s_pub_, sizeof(s_pub_)); BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_), kLabelMac1, sizeof(kLabelMac1), s_pub_, sizeof(s_pub_)); + + is_private_key_initialized_ = true; + + // Recompute peer data because it depends on my privkey + for (WgPeer *peer = peers_; peer; peer = peer->next_peer_) + peer->SetPublicKey(peer->s_remote_); } WgPeer *WgDevice::AddPeer() { assert(IsMainThread()); WgPeer *peer = new WgPeer(this); - WgPeer **pp = &peers_; - while (*pp) - pp = &(*pp)->next_peer_; - *pp = peer; return peer; } -WgPeer *WgDevice::GetPeerFromPublicKey(uint8 public_key[WG_PUBLIC_KEY_LEN]) { +void WgDevice::RemoveAllPeers() { assert(IsMainThread()); - // todo: add O(1) lookup - for (WgPeer *peer = peers_; peer; peer = peer->next_peer_) { - if (memcmp(peer->s_remote_, public_key, WG_PUBLIC_KEY_LEN) == 0) - return peer; - } - return NULL; + while (peers_) + peers_->RemovePeer(); +} + +WgPeer *WgDevice::GetPeerFromPublicKey(const WgPublicKey &pubkey) { + assert(IsMainThread()); + + auto it = peer_id_lookup_.find(pubkey); + return (it != peer_id_lookup_.end()) ? it->second : NULL; } bool WgDevice::CheckCookieMac1(Packet *packet) { @@ -230,6 +245,7 @@ void WgDevice::CreateCookieMessage(MessageHandshakeCookie *dst, Packet *packet, } void WgDevice::EraseKeypairAddrEntry_Locked(WgKeypair *kp) { + // todo: figure out how to make this multithread safe. WgAddrEntry *ae = kp->addr_entry; assert(ae->ref_count >= 1); @@ -313,7 +329,6 @@ void WgDevice::SetHeaderObfuscation(const char *key) { #endif // WITH_HEADER_OBFUSCATION } - WgPeer::WgPeer(WgDevice *dev) { assert(dev->IsMainThread()); dev_ = dev; @@ -323,6 +338,7 @@ WgPeer::WgPeer(WgDevice *dev) { expect_cookie_reply_ = false; has_mac2_cookie_ = false; pending_keepalive_ = false; + marked_for_delete_ = false; allow_multicast_through_peer_ = false; allow_endpoint_change_ = true; supports_handshake_extensions_ = true; @@ -331,6 +347,8 @@ WgPeer::WgPeer(WgDevice *dev) { last_handshake_init_recv_timestamp_ = 0; last_complete_handskake_timestamp_ = 0; persistent_keepalive_ms_ = 0; + rx_bytes_ = 0; + tx_bytes_ = 0; timers_ = 0; first_queued_packet_ = NULL; last_queued_packet_ptr_ = &first_queued_packet_; @@ -343,15 +361,66 @@ WgPeer::WgPeer(WgDevice *dev) { memset(last_timestamp_, 0, sizeof(last_timestamp_)); ipv4_broadcast_addr_ = 0xffffffff; memset(features_, 0, sizeof(features_)); + memset(preshared_key_, 0, sizeof(preshared_key_)); + memset(&s_remote_, 0, sizeof(s_remote_)); + + // Insert into the parent's linked list + *dev_->last_peer_ptr_ = this; + dev_->last_peer_ptr_ = &next_peer_; } WgPeer::~WgPeer() { + // do not delete this directly, instead call RemovePeer + assert(marked_for_delete_); assert(dev_->IsMainThread()); + assert(curr_keypair_ == NULL && next_keypair_ == NULL && prev_keypair_ == NULL); + assert(local_key_id_during_hs_ == 0); + assert(first_queued_packet_ == NULL); +} + +void WgPeer::DelayedDelete(void *x) { + WgPeer *peer = (WgPeer*)x; + assert(peer->dev_->IsMainThread()); + + if (peer->main_thread_scheduled_ != 0) { + WG_ACQUIRE_LOCK(peer->dev_->main_thread_scheduled_lock_); + // Unlink myself from the main thread scheduled list + for (WgPeer **pp = &peer->dev_->main_thread_scheduled_; *pp; pp = &(*pp)->main_thread_scheduled_next_) { + if (*pp == peer) { + *pp = peer->main_thread_scheduled_next_; + break; + } + } + WG_RELEASE_LOCK(peer->dev_->main_thread_scheduled_lock_); + } + delete peer; +} + +void WgPeer::RemovePeer() { + assert(dev_->IsMainThread()); + assert(!marked_for_delete_); + + // Find and unlink the peer from the parent's peer list + WgPeer **pp = &dev_->peers_; + while (*pp != this) + pp = &(*pp)->next_peer_; + if ((*pp = next_peer_) == NULL) + dev_->last_peer_ptr_ = pp; + + RemoveAllIps(); + dev_->peer_id_lookup_.erase(s_remote_); + WG_ACQUIRE_LOCK(mutex_); + marked_for_delete_ = true; ClearKeys_Locked(); ClearHandshake_Locked(); ClearPacketQueue_Locked(); WG_RELEASE_LOCK(mutex_); + + // The WgPeer instance may still be accessible from + // worker threads that already started processing a packet, + // so defer the actual delete of it. + dev_->delayed_delete_.Add(&WgPeer::DelayedDelete, this); } void WgPeer::ClearKeys_Locked() { @@ -382,21 +451,55 @@ void WgPeer::ClearPacketQueue_Locked() { num_queued_packets_ = 0; } -void WgPeer::Initialize(const uint8 spub[WG_PUBLIC_KEY_LEN], const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]) { - // Optionally use a preshared key, it defaults to all zeros. +void WgPeer::AddPacketToPeerQueue_Locked(Packet *packet) { + assert(IsPeerLocked()); + assert(!marked_for_delete_); + // Keep only the first MAX_QUEUED_PACKETS packets. + while (num_queued_packets_ >= MAX_QUEUED_PACKETS_PER_PEER) { + Packet *packet = first_queued_packet_; + first_queued_packet_ = packet->next; + num_queued_packets_--; + FreePacket(packet); + } + // Add the packet to the out queue that will get sent once handshake completes + *last_queued_packet_ptr_ = packet; + last_queued_packet_ptr_ = &packet->next; + packet->next = NULL; + num_queued_packets_++; +} + +void WgPeer::SetPublicKey(const WgPublicKey &spub) { + assert(dev_->IsMainThread()); + assert(IsOnlyZeros(s_remote_.bytes, sizeof(s_remote_.bytes)) || + memcmp(s_remote_.bytes, spub.bytes, sizeof(s_remote_.bytes)) == 0); + + s_remote_ = spub; + dev_->peer_id_lookup_[s_remote_] = this; + + if (!dev_->is_private_key_initialized_) + return; + + // Precompute: s_priv_pub_ := DH(sprivr, spubi) + curve25519_donna(s_priv_pub_, dev_->s_priv_, s_remote_.bytes); + // Precompute: precomputed_cookie_key_ := HASH(LABEL-COOKIE || Spub_m) + // precomputed_mac1_key_ := HASH(MAC1-COOKIE || Spub_m) + BlakeX2(precomputed_cookie_key_, sizeof(precomputed_cookie_key_), + kLabelCookie, sizeof(kLabelCookie), s_remote_.bytes, WG_PUBLIC_KEY_LEN); + BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_), + kLabelMac1, sizeof(kLabelMac1), s_remote_.bytes, WG_PUBLIC_KEY_LEN); + + // Remove the peer's keys + WG_ACQUIRE_LOCK(mutex_); + ClearKeys_Locked(); + ClearHandshake_Locked(); + WG_RELEASE_LOCK(mutex_); +} + +void WgPeer::SetPresharedKey(const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]) { if (preshared_key) memcpy(preshared_key_, preshared_key, sizeof(preshared_key_)); else memset(preshared_key_, 0, sizeof(preshared_key_)); - // Precompute: s_priv_pub_ := DH(sprivr, spubi) - memcpy(s_remote_, spub, sizeof(s_remote_)); - curve25519_donna(s_priv_pub_, dev_->s_priv_, s_remote_); - // Precompute: precomputed_cookie_key_ := HASH(LABEL-COOKIE || Spub_m) - // precomputed_mac1_key_ := HASH(MAC1-COOKIE || Spub_m) - BlakeX2(precomputed_cookie_key_, sizeof(precomputed_cookie_key_), - kLabelCookie, sizeof(kLabelCookie), spub, WG_PUBLIC_KEY_LEN); - BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_), - kLabelMac1, sizeof(kLabelMac1), spub, WG_PUBLIC_KEY_LEN); } // run on the client @@ -411,7 +514,7 @@ void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) { // Hi := HASH(Ci || IDENTIFIER) memcpy(hs_.hi, kWgInitHash, sizeof(hs_.hi)); // Hi := HASH(Hi || Spub_r) - BlakeMix(hs_.hi, s_remote_, sizeof(s_remote_)); + BlakeMix(hs_.hi, s_remote_.bytes, sizeof(s_remote_)); // (Epriv_r, Epub_r) := DH-GENERATE() // msg.ephemeral = Epub_r OsGetRandomBytes(hs_.e_priv, sizeof(hs_.e_priv)); @@ -422,7 +525,7 @@ void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) { // Hi := HASH(Hi || msg.ephemeral) BlakeMix(hs_.hi, dst->ephemeral, sizeof(dst->ephemeral)); // (Ci, K) := KDF2(Ci, DH(epriv, spub_r)) - ComputeHKDF2DH(hs_.ci, k, hs_.e_priv, s_remote_); + ComputeHKDF2DH(hs_.ci, k, hs_.e_priv, s_remote_.bytes); // msg.static = AEAD(K, 0, Spub_i, Hi) chacha20poly1305_encrypt(dst->static_enc, dev_->s_pub_, sizeof(dev_->s_pub_), hs_.hi, sizeof(hs_.hi), 0, k); // Hi := HASH(Hi || msg.static) @@ -461,7 +564,7 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { uint8 e_priv[WG_PUBLIC_KEY_LEN]; }; union { - uint8 spubi[WG_PUBLIC_KEY_LEN]; + WgPublicKey spubi; uint8 e_remote[WG_PUBLIC_KEY_LEN]; uint8 hi2[WG_HASH_LEN]; }; @@ -488,13 +591,13 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { // (Ci, K) := KDF2(Ci, DH(spriv, msg.ephemeral)) ComputeHKDF2DH(ci, k, dev->s_priv_, src->ephemeral); // Spub_i = AEAD_DEC(K, 0, msg.static, Hi) - if (!chacha20poly1305_decrypt(spubi, src->static_enc, sizeof(src->static_enc), hi, sizeof(hi), 0, k)) + if (!chacha20poly1305_decrypt(spubi.bytes, src->static_enc, sizeof(src->static_enc), hi, sizeof(hi), 0, k)) goto getout; // Hi := HASH(Hi || msg.static) BlakeMix(hi, src->static_enc, sizeof(src->static_enc)); // Lookup the peer with this ID while ((peer = dev->GetPeerFromPublicKey(spubi)) == NULL) { - if (dev->delegate_ == NULL || !dev->delegate_->HandleUnknownPeerId(spubi, packet)) + if (dev->delegate_ == NULL || !dev->delegate_->HandleUnknownPeerId(spubi.bytes, packet)) goto getout; } // (Ci, K) := KDF2(Ci, DH(sprivr, spubi)) @@ -538,7 +641,7 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { // Ci : = KDF2(Ci, DH(epriv, epub)) ComputeHKDF2DH(ci, NULL, e_priv, e_remote); // Ci : = KDF2(Ci, DH(epriv, spub)) - ComputeHKDF2DH(ci, NULL, e_priv, peer->s_remote_); + ComputeHKDF2DH(ci, NULL, e_priv, peer->s_remote_.bytes); // (Ci, T, K) := KDF3(Ci, Q) blake2s_hkdf(ci, sizeof(ci), t, sizeof(t), k, sizeof(k), peer->preshared_key_, sizeof(preshared_key_), ci, WG_HASH_LEN); // Hr := HASH(Hr || T) @@ -548,11 +651,6 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { keypair = WgPeer::CreateNewKeypair(false, ci, remote_key_id, extbuf + WG_TIMESTAMP_LEN, extfield_size); if (keypair) { - WG_ACQUIRE_LOCK(peer->mutex_); - peer->InsertKeypairInPeer_Locked(keypair); - peer->OnHandshakeAuthComplete(); - WG_RELEASE_LOCK(peer->mutex_); - dst->sender_key_id = dev->InsertInKeyIdLookup(peer, keypair); size_t extfield_out_size = 0; @@ -560,8 +658,17 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { if (extfield_size) extfield_out_size = peer->WriteHandshakeExtension(dst->empty_enc, keypair); #endif // WITH_HANDSHAKE_EXT + + uint32 orig_packet_size = packet->size; packet->size = (unsigned)(sizeof(MessageHandshakeResponse) + extfield_out_size); + WG_ACQUIRE_LOCK(peer->mutex_); + peer->rx_bytes_ += orig_packet_size; + peer->tx_bytes_ += packet->size; + peer->InsertKeypairInPeer_Locked(keypair); + peer->OnHandshakeAuthComplete(); + WG_RELEASE_LOCK(peer->mutex_); + // msg.empty := AEAD(K, 0, "", Hr) chacha20poly1305_encrypt(dst->empty_enc, dst->empty_enc, extfield_out_size, hi, sizeof(hi), 0, k); // Hr := HASH(Hr || "") @@ -624,6 +731,7 @@ WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packe peer_and_keypair->second = keypair; WG_ACQUIRE_LOCK(peer->mutex_); + peer->rx_bytes_ += packet->size; peer->InsertKeypairInPeer_Locked(keypair); WG_RELEASE_LOCK(peer->mutex_); @@ -651,6 +759,9 @@ void WgPeer::ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCo if (!xchacha20poly1305_decrypt(cookie, src->cookie_enc, sizeof(src->cookie_enc), peer->sent_mac1_, sizeof(peer->sent_mac1_), src->nonce, peer->precomputed_cookie_key_)) return; + WG_ACQUIRE_LOCK(peer->mutex_); + peer->rx_bytes_ += sizeof(MessageHandshakeCookie); + WG_RELEASE_LOCK(peer->mutex_); peer->expect_cookie_reply_ = false; peer->has_mac2_cookie_ = true; peer->mac2_cookie_timestamp_ = OsGetMilliseconds(); @@ -796,7 +907,7 @@ bool WgKeypairParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size #endif // WITH_HANDSHAKE_EXT -static void ActualFreeKeypair(void *x) { +static void WgKeypairDelayedDelete(void *x) { WgKeypair *t = (WgKeypair*)x; if (t->aes_gcm128_context_) free(t->aes_gcm128_context_); @@ -808,17 +919,18 @@ void WgPeer::DeleteKeypair(WgKeypair **kp) { *kp = NULL; if (t) { assert(t->peer->IsPeerLocked()); + WgDevice *dev = t->peer->dev_; if (t->addr_entry) { - WG_SCOPED_RWLOCK_EXCLUSIVE(dev_->addr_entry_lookup_lock_); - dev_->EraseKeypairAddrEntry_Locked(t); + WG_SCOPED_RWLOCK_EXCLUSIVE(dev->addr_entry_lookup_lock_); + dev->EraseKeypairAddrEntry_Locked(t); } if (t->local_key_id) { - WG_SCOPED_RWLOCK_EXCLUSIVE(dev_->key_id_lookup_lock_); - dev_->key_id_lookup_.erase(t->local_key_id); + WG_SCOPED_RWLOCK_EXCLUSIVE(dev->key_id_lookup_lock_); + dev->key_id_lookup_.erase(t->local_key_id); t->local_key_id = 0; } t->recv_key_state = WgKeypair::KEY_INVALID; - dev_->delayed_delete_.Add(&ActualFreeKeypair, t); + dev->delayed_delete_.Add(&WgKeypairDelayedDelete, t); } } @@ -1029,8 +1141,8 @@ void WgPeer::OnHandshakeFullyComplete() { } // Check if any of the timeouts have expired -uint32 WgPeer::CheckTimeouts(uint64 now) { - assert(IsPeerLocked()); +uint32 WgPeer::CheckTimeouts_Locked(uint64 now) { + assert(dev_->IsMainThread() && IsPeerLocked()); uint32 t, rv = 0; @@ -1096,7 +1208,7 @@ uint32 WgPeer::CheckTimeouts(uint64 now) { // Check all key stuff here to avoid calling possibly expensive timestamp routines in the packet handler void WgPeer::CheckAndUpdateTimeOfNextKeyEvent(uint64 now) { - assert(IsPeerLocked()); + assert(dev_->IsMainThread() && IsPeerLocked()); uint64 next_time = UINT64_MAX; uint32 rv = 0; @@ -1142,34 +1254,60 @@ void WgPeer::SetEndpoint(const IpAddr &sin) { endpoint_ = sin; } -void WgPeer::SetPersistentKeepalive(int persistent_keepalive_secs) { - if (persistent_keepalive_secs < 10 || persistent_keepalive_secs > 10000) - return; +bool WgPeer::SetPersistentKeepalive(int persistent_keepalive_secs) { + if (persistent_keepalive_secs < 0 || persistent_keepalive_secs > 65535) + return false; persistent_keepalive_ms_ = persistent_keepalive_secs * 1000; + return true; +} + +bool WgCidrAddrEquals(const WgCidrAddr &a, const WgCidrAddr &b) { + return (a.size == b.size && a.cidr == b.cidr && memcmp(a.addr, b.addr, a.size >> 3) == 0); } bool WgPeer::AddIp(const WgCidrAddr &cidr_addr) { + WgPeer *old_peer; assert(dev_->IsMainThread()); if (cidr_addr.size == 32) { if (cidr_addr.cidr > 32) return false; WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_); - dev_->ip_to_peer_map_.InsertV4(ReadBE32(cidr_addr.addr), cidr_addr.cidr, this); + old_peer = (WgPeer*)dev_->ip_to_peer_map_.InsertV4(ReadBE32(cidr_addr.addr), cidr_addr.cidr, this); WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_); - allowed_ips_.push_back(cidr_addr); - return true; } else if (cidr_addr.size == 128) { if (cidr_addr.cidr > 128) return false; WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_); - dev_->ip_to_peer_map_.InsertV6(cidr_addr.addr, cidr_addr.cidr, this); + old_peer = (WgPeer*)dev_->ip_to_peer_map_.InsertV6(cidr_addr.addr, cidr_addr.cidr, this); WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_); - allowed_ips_.push_back(cidr_addr); - return true; } else { return false; } + if (old_peer) { + for (auto it = old_peer->allowed_ips_.begin(); it != old_peer->allowed_ips_.end(); ++it) { + if (WgCidrAddrEquals(*it, cidr_addr)) { + old_peer->allowed_ips_.erase(it); + break; + } + } + } + allowed_ips_.push_back(cidr_addr); + return true; +} + +void WgPeer::RemoveAllIps() { + assert(dev_->IsMainThread()); + WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_); + for (auto it = allowed_ips_.begin(); it != allowed_ips_.end(); ++it) { + if (it->size == 32) { + dev_->ip_to_peer_map_.RemoveV4(ReadBE32(it->addr), it->cidr); + } else if (it->size == 128) { + dev_->ip_to_peer_map_.RemoveV6(it->addr, it->cidr); + } + } + WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_); + allowed_ips_.clear(); } void WgPeer::SetAllowMulticast(bool allow) { @@ -1196,6 +1334,18 @@ bool WgPeer::AddCipher(int cipher) { return true; } +void WgPeer::ScheduleNewHandshake() { + // Note, it's possible that the peer has already been marked for delete + if (main_thread_scheduled_.fetch_or(WgPeer::kMainThreadScheduled_ScheduleHandshake) == 0) { + main_thread_scheduled_next_ = NULL; + WG_ACQUIRE_LOCK(dev_->main_thread_scheduled_lock_); + *dev_->main_thread_scheduled_last_ = this; + dev_->main_thread_scheduled_last_ = &main_thread_scheduled_next_; + WG_RELEASE_LOCK(dev_->main_thread_scheduled_lock_); + // todo: in multithreaded impl need to trigger |RunAllMainThreadScheduled| to get called + } +} + WgRateLimit::WgRateLimit() { key1_[0] = key1_[1] = 1; key2_[0] = key2_[1] = 1; diff --git a/wireguard_proto.h b/wireguard_proto.h index 36f7b51..15d1432 100644 --- a/wireguard_proto.h +++ b/wireguard_proto.h @@ -11,7 +11,7 @@ #include #include #include - +#include // Threading macros that enable locks only in MT builds #if WITH_WG_THREADING #define WG_SCOPED_LOCK(name) ScopedLock scoped_lock(&name) @@ -25,6 +25,7 @@ #define WG_RELEASE_RWLOCK_EXCLUSIVE(name) name.ReleaseExclusive() #define WG_SCOPED_RWLOCK_SHARED(name) ScopedLockShared scoped_lock(&name) #define WG_SCOPED_RWLOCK_EXCLUSIVE(name) ScopedLockExclusive scoped_lock(&name) +#define WG_IF_LOCKS_ENABLED_ELSE(expr, def) (expr) #else // WITH_WG_THREADING #define WG_SCOPED_LOCK(name) #define WG_ACQUIRE_LOCK(name) @@ -37,6 +38,7 @@ #define WG_RELEASE_RWLOCK_EXCLUSIVE(name) #define WG_SCOPED_RWLOCK_SHARED(name) #define WG_SCOPED_RWLOCK_EXCLUSIVE(name) +#define WG_IF_LOCKS_ENABLED_ELSE(expr, def) (def) #endif // WITH_WG_THREADING enum ProtocolTimeouts { @@ -77,6 +79,7 @@ enum MessageFieldSizes { WG_MAC_LEN = 16, WG_TIMESTAMP_LEN = 12, WG_SIPHASH_KEY_LEN = 16, + WG_PUBLIC_KEY_LEN_BASE64 = 44, }; enum { @@ -194,11 +197,9 @@ struct WgPacketCompressionVer01 { }; STATIC_ASSERT(sizeof(WgPacketCompressionVer01) == 24, WgPacketCompressionVer01_wrong_size); - struct WgKeypair; class WgPeer; - class WgRateLimit { public: WgRateLimit(); @@ -260,10 +261,26 @@ struct WgAddrEntry { struct ScramblerSiphashKeys { uint64 keys[4]; }; - + +union WgPublicKey { + uint8 bytes[WG_PUBLIC_KEY_LEN]; + uint64 u64[WG_PUBLIC_KEY_LEN / 8]; + friend bool operator==(const WgPublicKey &a, const WgPublicKey &b) { + return memcmp(a.bytes, b.bytes, WG_PUBLIC_KEY_LEN) == 0; + } +}; + +struct WgPublicKeyHasher { + size_t operator()(const WgPublicKey&a) const { + uint64 rv = a.u64[0] ^ a.u64[1] ^ a.u64[2] ^ a.u64[3]; + return (size_t)(rv ^ (rv >> 32)); + } +}; + class WgDevice { friend class WgPeer; friend class WireguardProcessor; + friend class WgConfig; public: // Can be used to customize the behavior of WgDevice @@ -278,12 +295,15 @@ public: WgDevice(); ~WgDevice(); - // Initialize with the private key, precompute all internal keys etc. - void Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]); + // Configure with the private key, precompute all internal keys etc. + void SetPrivateKey(const uint8 private_key[WG_PUBLIC_KEY_LEN]); // Create a new peer WgPeer *AddPeer(); + // Remove all peers + void RemoveAllPeers(); + // Setup header obfuscation void SetHeaderObfuscation(const char *key); @@ -303,17 +323,19 @@ public: WgRateLimit *rate_limiter() { return &rate_limiter_; } std::unordered_map &addr_entry_map() { return addr_entry_lookup_; } WgPacketCompressionVer01 *compression_header() { return &compression_header_; } + bool is_private_key_initialized() { return is_private_key_initialized_; } bool IsMainThread() { return CurrentThreadIdEquals(main_thread_id_); } - void SetCurrentThreadAsMainThread() { main_thread_id_ = GetCurrentThreadId(); } + bool IsMainOrDataThread() { return CurrentThreadIdEquals(main_thread_id_) || WG_IF_LOCKS_ENABLED_ELSE(delayed_delete_.enabled(), false); } void SetDelegate(Delegate *del) { delegate_ = del; } + private: std::pair *LookupPeerInKeyIdLookup(uint32 key_id); WgKeypair *LookupKeypairByKeyId(uint32 key_id); WgKeypair *LookupKeypairInAddrEntryMap(uint64 addr, uint32 slot); // Return the peer matching the |public_key| or NULL - WgPeer *GetPeerFromPublicKey(uint8 public_key[WG_PUBLIC_KEY_LEN]); + WgPeer *GetPeerFromPublicKey(const WgPublicKey &pubkey); // Create a cookie by inspecting the source address of the |packet| void MakeCookie(uint8 cookie[WG_COOKIE_LEN], Packet *packet); // Insert a new entry in |key_id_lookup_| @@ -330,7 +352,7 @@ private: WG_DECLARE_RWLOCK(ip_to_peer_map_lock_); // For enumerating all peers - WgPeer *peers_; + WgPeer *peers_, **last_peer_ptr_; // For hooking Delegate *delegate_; @@ -346,12 +368,22 @@ private: std::unordered_map addr_entry_lookup_; WG_DECLARE_RWLOCK(addr_entry_lookup_lock_); + // Mapping from peer id to peer. This may be accessed only from MT. + std::unordered_map peer_id_lookup_; + + // Queue of things scheduled to run on the main thread. + WG_DECLARE_LOCK(main_thread_scheduled_lock_); + WgPeer *main_thread_scheduled_, **main_thread_scheduled_last_; + // Counter for generating new indices in |keypair_lookup_| uint8 next_rng_slot_; // Whether packet obfuscation is enabled bool header_obfuscation_; + // Whether a private key has been setup for the device + bool is_private_key_initialized_; + ThreadId main_thread_id_; uint64 low_resolution_timestamp_; @@ -382,15 +414,16 @@ private: class WgPeer { friend class WgDevice; friend class WireguardProcessor; + friend class WgConfig; friend bool WgKeypairParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size_t data_size); friend void WgKeypairSetupCompressionExtension(WgKeypair *keypair, const WgPacketCompressionVer01 *remotec); public: explicit WgPeer(WgDevice *dev); ~WgPeer(); - void Initialize(const uint8 spub[WG_PUBLIC_KEY_LEN], const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]); - - void SetPersistentKeepalive(int persistent_keepalive_secs); + void SetPublicKey(const WgPublicKey &spub); + void SetPresharedKey(const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]); + bool SetPersistentKeepalive(int persistent_keepalive_secs); void SetEndpoint(const IpAddr &sin); void SetAllowMulticast(bool allow); @@ -398,15 +431,14 @@ public: bool AddCipher(int cipher); void SetCipherPrio(bool prio) { cipher_prio_ = prio; } bool AddIp(const WgCidrAddr &cidr_addr); + void RemoveAllIps(); static WgPeer *ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet); static WgPeer *ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packet); static void ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCookie *src); void CreateMessageHandshakeInitiation(Packet *packet); bool CheckSwitchToNextKey_Locked(WgKeypair *keypair); - void ClearKeys_Locked(); - void ClearHandshake_Locked(); - void ClearPacketQueue_Locked(); + void RemovePeer(); bool CheckHandshakeRateLimit(); // Timer notifications @@ -422,25 +454,25 @@ public: ACTION_SEND_KEEPALIVE = 1, ACTION_SEND_HANDSHAKE = 2, }; - uint32 CheckTimeouts(uint64 now); + uint32 CheckTimeouts_Locked(uint64 now); - void AddPacketToPeerQueue(Packet *packet); - -#if WITH_WG_THREADING - bool IsPeerLocked() { return mutex_.IsLocked(); } -#else // WITH_WG_THREADING - bool IsPeerLocked() { return true; } -#endif // WITH_WG_THREADING + void AddPacketToPeerQueue_Locked(Packet *packet); + bool IsPeerLocked() { return WG_IF_LOCKS_ENABLED_ELSE(mutex_.IsLocked(), true); } private: static WgKeypair *CreateNewKeypair(bool is_initiator, const uint8 key[WG_HASH_LEN], uint32 send_key_id, const uint8 *extfield, size_t extfield_size); void WriteMacToPacket(const uint8 *data, MessageMacs *mac); - void DeleteKeypair(WgKeypair **kp); void CheckAndUpdateTimeOfNextKeyEvent(uint64 now); + static void DeleteKeypair(WgKeypair **kp); static void CopyEndpointToPeer_Locked(WgKeypair *keypair, const IpAddr *addr); + static void DelayedDelete(void *x); size_t WriteHandshakeExtension(uint8 *dst, WgKeypair *keypair); void InsertKeypairInPeer_Locked(WgKeypair *keypair); - + void ClearKeys_Locked(); + void ClearHandshake_Locked(); + void ClearPacketQueue_Locked(); + void ScheduleNewHandshake(); + WgDevice *dev_; WgPeer *next_peer_; @@ -492,6 +524,10 @@ private: // Whether |mac2_cookie_| is valid. bool has_mac2_cookie_; + // Whether the WgPeer has been deleted (i.e. RemovePeer has been called), + // and will be deleted as soon as the threads sync. + bool marked_for_delete_; + // Number of handshakes made so far, when this gets too high we stop connecting. uint8 handshake_attempts_; @@ -517,7 +553,10 @@ private: uint8 cipher_prio_; uint8 num_ciphers_; uint8 ciphers_[MAX_CIPHERS]; - + + uint64 rx_bytes_; + uint64 tx_bytes_; + // Handshake state that gets setup in |CreateMessageHandshakeInitiation| and used in // the response. struct HandshakeState { @@ -530,7 +569,7 @@ private: }; HandshakeState hs_; // Remote's static public key - init only. - uint8 s_remote_[WG_PUBLIC_KEY_LEN]; + WgPublicKey s_remote_; // Remote's preshared key - init only. uint8 preshared_key_[WG_SYMMETRIC_KEY_LEN]; // Precomputed DH(spriv_local, spub_remote) - init only.