diff --git a/.gitignore b/.gitignore
index 71ce91a..e1c4422 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,18 +1,10 @@
-/Debug/
-/Release/
-/ipzip2/Debug/
-/Build
-/Win32/
/TunSafe.aps
/*.sdf
-/*vcxproj.user
+*vcxproj.user
/*.opensdf
/*.suo
/.vs/
-/x64/
-/Azire.conf
+/build/
/*.psess
/*.vspx
/installer/*.zip
-/config/
-/tunsafe.com/
diff --git a/TunSafe.sln b/TunSafe.sln
index 40906e6..ff47cf6 100644
--- a/TunSafe.sln
+++ b/TunSafe.sln
@@ -5,6 +5,8 @@ VisualStudioVersion = 15.0.26403.7
MinimumVisualStudioVersion = 10.0.40219.1
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "TunSafe", "TunSafe.vcxproj", "{626FBC16-64C6-407D-BC2B-6C087794E0D0}"
EndProject
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ts", "ts.vcxproj", "{443E105E-8D7C-401F-BD41-D3F56C76104B}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Win32 = Debug|Win32
@@ -21,8 +23,19 @@ Global
{626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|Win32.Build.0 = Release|Win32
{626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|x64.ActiveCfg = Release|x64
{626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|x64.Build.0 = Release|x64
+ {443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|Win32.ActiveCfg = Debug|Win32
+ {443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|Win32.Build.0 = Debug|Win32
+ {443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|x64.ActiveCfg = Debug|x64
+ {443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|x64.Build.0 = Debug|x64
+ {443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|Win32.ActiveCfg = Release|Win32
+ {443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|Win32.Build.0 = Release|Win32
+ {443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|x64.ActiveCfg = Release|x64
+ {443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|x64.Build.0 = Release|x64
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
EndGlobalSection
+ GlobalSection(ExtensibilityGlobals) = postSolution
+ SolutionGuid = {F2DD9ED8-DDEA-4B40-9208-41726750D33D}
+ EndGlobalSection
EndGlobal
diff --git a/TunSafe.vcxproj b/TunSafe.vcxproj
index 2645b39..a85d424 100644
--- a/TunSafe.vcxproj
+++ b/TunSafe.vcxproj
@@ -22,7 +22,7 @@
{626FBC16-64C6-407D-BC2B-6C087794E0D0}
Win32Proj
TunSafe
- 10.0.15063.0
+ 10.0.17134.0
TunSafe
@@ -72,24 +72,28 @@
true
TunSafe
- $(SolutionDir)$(Platform)\$(Configuration)\
- $(Platform)\$(Configuration)\
+ $(SolutionDir)build\$(Platform)_$(Configuration)\
+ $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\
true
$(VC_ExecutablePath_x64);$(WindowsSDK_ExecutablePath);$(VS_ExecutablePath);$(MSBuild_ExecutablePath);$(FxCopDir);$(PATH);C:\Bin\Dev\nasm
TunSafe
+ $(SolutionDir)build\$(Platform)_$(Configuration)\
+ $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\
false
TunSafe
- $(SolutionDir)$(Platform)\$(Configuration)\
- $(Platform)\$(Configuration)\
+ $(SolutionDir)build\$(Platform)_$(Configuration)\
+ $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\
false
$(VC_ExecutablePath_x64);$(WindowsSDK_ExecutablePath);$(VS_ExecutablePath);$(MSBuild_ExecutablePath);$(FxCopDir);$(PATH);C:\Bin\Dev\nasm
TunSafe
+ $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\
+ $(SolutionDir)build\$(Platform)_$(Configuration)\
@@ -98,6 +102,7 @@
Disabled
WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions);_CRT_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS
.
+ false
Windows
@@ -114,6 +119,7 @@
.
+ false
Windows
@@ -133,6 +139,8 @@
WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions);_CRT_SECURE_NO_WARNINGS
MultiThreaded
.
+ false
+ false
Windows
@@ -157,6 +165,8 @@
AnySuitable
true
.
+ false
+ false
Windows
@@ -169,8 +179,10 @@
+
+
@@ -196,6 +208,7 @@
+
diff --git a/TunSafe.vcxproj.filters b/TunSafe.vcxproj.filters
index 2a444bb..01e1ff8 100644
--- a/TunSafe.vcxproj.filters
+++ b/TunSafe.vcxproj.filters
@@ -89,6 +89,12 @@
Source Files\Win32
+
+ Source Files\Win32
+
+
+ Source Files\Win32
+
@@ -154,6 +160,9 @@
Source Files
+
+ Source Files\Win32
+
diff --git a/build_freebsd.sh b/build_freebsd.sh
index b546216..2f86855 100755
--- a/build_freebsd.sh
+++ b/build_freebsd.sh
@@ -1,4 +1,4 @@
g++7 -I . -O2 -DNDEBUG -static -mssse3 -o tunsafe benchmark.cpp tunsafe_cpu.cpp wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \
-wireguard.cpp wireguard_proto.cpp util.cpp network_bsd.cpp network_bsd_common.cpp \
+wireguard.cpp wireguard_proto.cpp ts.cpp util.cpp network_bsd.cpp network_bsd_common.cpp \
crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp crypto/curve25519-donna.cpp \
crypto/siphash.cpp crypto/chacha20_x64_gas.s crypto/poly1305_x64_gas.s ipzip2/ipzip2.cpp -lrt -pthread
diff --git a/build_linux.sh b/build_linux.sh
index 026955e..850dcdc 100755
--- a/build_linux.sh
+++ b/build_linux.sh
@@ -1,6 +1,6 @@
#!/bin/sh
clang++-6.0 -c -march=skylake-avx512 crypto/poly1305_x64_gas.s crypto/chacha20_x64_gas.s
-clang++-6.0 -I . -O3 -DNDEBUG -mssse3 -pthread -lrt -o tunsafe util.cpp wireguard_config.cpp wireguard.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \
+clang++-6.0 -I . -O3 -DNDEBUG -mssse3 -pthread -lrt -o tunsafe util.cpp wireguard_config.cpp wireguard.cpp ts.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \
wireguard_proto.cpp network_bsd.cpp network_bsd_common.cpp tunsafe_cpu.cpp benchmark.cpp crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp \
crypto/curve25519-donna.cpp crypto/siphash.cpp chacha20_x64_gas.o crypto/aesgcm/aesni_gcm_x64_gas.s \
crypto/aesgcm/aesni_x64_gas.s crypto/aesgcm/aesgcm.cpp poly1305_x64_gas.o ipzip2/ipzip2.cpp \
diff --git a/build_osx.sh b/build_osx.sh
index 29a02c0..3e905f9 100755
--- a/build_osx.sh
+++ b/build_osx.sh
@@ -3,8 +3,8 @@ set -e
clang++ -c -mavx512f -mavx512vl crypto/poly1305_x64_gas_macosx.s crypto/chacha20_x64_gas_macosx.s
-clang++ -g -O3 -I . -std=c++11 -DNDEBUG=1 -fno-exceptions -fno-rtti -ffunction-sections -o tunsafe \
-wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp wireguard.cpp wireguard_proto.cpp util.cpp network_bsd.cpp network_bsd_common.cpp benchmark.cpp tunsafe_cpu.cpp \
+clang++ -g -O3 -I . -std=c++11 -DNDEBUG=1 -Wno-deprecated-declarations -fno-exceptions -fno-rtti -ffunction-sections -o tunsafe \
+wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp wireguard.cpp wireguard_proto.cpp ts.cpp util.cpp network_bsd.cpp network_bsd_common.cpp benchmark.cpp tunsafe_cpu.cpp \
crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp crypto/curve25519-donna.cpp \
crypto/siphash.cpp crypto/aesgcm/aesgcm.cpp ipzip2/ipzip2.cpp \
crypto/aesgcm/aesni_gcm_x64_gas_macosx.s crypto/aesgcm/aesni_x64_gas_macosx.s crypto/aesgcm/ghash_x64_gas_macosx.s \
diff --git a/installer/tunsafe.nsi b/installer/tunsafe.nsi
index 044f485..8550365 100644
--- a/installer/tunsafe.nsi
+++ b/installer/tunsafe.nsi
@@ -57,10 +57,12 @@ Section "TunSafe Client" SecTunSafe
DetailPrint "Installing 64-bit version of TunSafe."
SetOutPath "$INSTDIR"
File "x64\TunSafe.exe"
+ File "x64\ts.exe"
${Else}
DetailPrint "Installing 32-bit version of TunSafe."
SetOutPath "$INSTDIR"
File "x86\TunSafe.exe"
+ File "x86\ts.exe"
${EndIf}
File "License.txt"
File "ChangeLog.txt"
@@ -205,6 +207,7 @@ Section "Uninstall"
Delete "$INSTDIR\TunSafe.exe"
+ Delete "$INSTDIR\ts.exe"
Delete "$INSTDIR\License.txt"
Delete "$INSTDIR\ChangeLog.txt"
Delete "$INSTDIR\Config\TunSafe.conf"
diff --git a/ip_to_peer_map.cpp b/ip_to_peer_map.cpp
index 93b8dd4..48539bb 100644
--- a/ip_to_peer_map.cpp
+++ b/ip_to_peer_map.cpp
@@ -5,6 +5,8 @@
#include "bit_ops.h"
#include
#include
+#include
+#include "util.h"
IpToPeerMap::IpToPeerMap() {
@@ -13,18 +15,22 @@ IpToPeerMap::IpToPeerMap() {
IpToPeerMap::~IpToPeerMap() {
}
-bool IpToPeerMap::InsertV4(uint32 ip, int cidr, void *peer) {
- ipv4_.Insert(ip, cidr, peer);
- return true;
+void *IpToPeerMap::InsertV4(uint32 ip, int cidr, void *peer) {
+ ipv4_.Insert(ip, cidr, &peer);
+ return peer;
}
-bool IpToPeerMap::InsertV6(const void *addr, int cidr, void *peer) {
+void *IpToPeerMap::InsertV6(const void *addr, int cidr, void *peer) {
Entry6 e;
+ for (auto it = ipv6_.begin(); it != ipv6_.end(); ++it) {
+ if (it->cidr_len == cidr && memcmp(it->ip, addr, 16) == 0)
+ return exch(it->peer, peer);
+ }
e.cidr_len = cidr;
e.peer = peer;
memcpy(e.ip, addr, 16);
ipv6_.push_back(e);
- return true;
+ return NULL;
}
void *IpToPeerMap::LookupV4(uint32 ip) {
@@ -43,6 +49,19 @@ void *IpToPeerMap::LookupV6DefaultPeer() {
return NULL;
}
+void IpToPeerMap::RemoveV4(uint32 ip, int cidr) {
+ ipv4_.Delete(ip, cidr);
+}
+
+void IpToPeerMap::RemoveV6(const void *addr, int cidr) {
+ for (auto it = ipv6_.begin(); it != ipv6_.end(); ++it) {
+ if (it->cidr_len == cidr && memcmp(it->ip, addr, 16) == 0) {
+ ipv6_.erase(it);
+ return;
+ }
+ }
+}
+
static int CalculateIPv6CommonPrefix(const uint8 *a, const uint8 *b) {
uint64 x = ToBE64(*(uint64*)&a[0] ^ *(uint64*)&b[0]);
uint64 y = ToBE64(*(uint64*)&a[8] ^ *(uint64*)&b[8]);
@@ -62,20 +81,6 @@ void *IpToPeerMap::LookupV6(const void *addr) {
return best_peer;
}
-void IpToPeerMap::RemovePeer(void *peer) {
- assert(0);
- // todo: remove peer also from ipv4_
- {
- size_t n = ipv6_.size();
- Entry6 *r = &ipv6_[0], *w = r;
- for (size_t i = 0; i != n; i++, r++) {
- if (r->peer != peer)
- *w++ = *r;
- }
- ipv6_.resize(w - &ipv6_[0]);
- }
-}
-
#pragma warning (disable: 4200) // warning C4200: nonstandard extension used: zero-sized array in struct/union
struct RoutingTrie32::Node {
uint32 key;
@@ -175,7 +180,7 @@ RoutingTrie32::~RoutingTrie32() {
RoutingTrie32::Value RoutingTrie32::Lookup(uint32 ip) {
uint32 key = ip;
Node *n = root_, *pn = n, *ppn;
- int cindex = 0;
+ uint32 cindex = 0;
if (!n)
return NULL;
// Find the longest prefix match
@@ -232,7 +237,7 @@ backtrace:
}
// strip lsb of cindex and find child
cindex &= cindex - 1;
- assert(cindex < (1 << pn->bits));
+ assert(cindex < (1U << pn->bits));
n = pn->child[cindex];
if (!NODE_IS_NULL_OR_OLEAF(n))
break;
@@ -246,7 +251,7 @@ backtrace:
}
}
-bool RoutingTrie32::InsertLeafInto(Node **nn, uint8 leaf_pos, Value value) {
+bool RoutingTrie32::InsertLeafInto(Node **nn, uint8 leaf_pos, Value *valuep) {
// put higher cidr higher up
Node *n = *nn;
assert(IS_LEAF(n));
@@ -255,12 +260,12 @@ bool RoutingTrie32::InsertLeafInto(Node **nn, uint8 leaf_pos, Value value) {
if (leaf_pos < n->pos)
break;
if (leaf_pos == n->pos) {
- n->leaf_value = value;
+ std::swap(n->leaf_value, *valuep);
return true;
}
nn = &n->leaf_next;
} while ((n = *nn) != NULL);
- Node *leaf = NewLeaf(key, leaf_pos, value);
+ Node *leaf = NewLeaf(key, leaf_pos, *valuep);
if (leaf == NULL)
return false;
leaf->leaf_next = *nn;
@@ -283,14 +288,14 @@ void RoutingTrie32::PutChild(Node *pn, uint32 i, Node *n) {
assert(pn->full_children < 0x80000000);
}
-bool RoutingTrie32::Insert(uint32 ip, int cidr, Value value) {
+bool RoutingTrie32::Insert(uint32 ip, int cidr, Value *valuep) {
uint32 key = ip;
Node **nn = &root_, *n = root_, *pn = NULL, *leaf, *tn = NULL, *leaf_to_free = NULL;
uint8 leaf_pos = 32 - cidr;
-
+
if (n == NULL) {
- root_ = NewLeaf(key, leaf_pos, value);
- return (root_ != NULL);
+ root_ = NewLeaf(key, leaf_pos, exch_null(*valuep));
+ return false;
}
assert(!NODE_IS_OLEAF(n));
@@ -316,7 +321,7 @@ force_add:
if (IS_LEAF(n)) {
if (key != n->key)
goto force_add;
- return InsertLeafInto(nn, leaf_pos, value);
+ return InsertLeafInto(nn, leaf_pos, valuep);
}
pn = n;
nn = &n->child[index];
@@ -330,6 +335,7 @@ force_add:
*nn = n;
}
}
+ Value value = *valuep;
// Create either leaf or oleaf
if (tn->pos == leaf_pos) {
leaf = VALUE_TO_OLEAF(value);
@@ -338,8 +344,8 @@ force_add:
FreeNode(tn);
return false;
}
-
// -- Start making irreversible changes here
+ *valuep = NULL;
if (leaf_to_free)
FreeNode(leaf_to_free);
diff --git a/ip_to_peer_map.h b/ip_to_peer_map.h
index eef7a36..2db5d5e 100644
--- a/ip_to_peer_map.h
+++ b/ip_to_peer_map.h
@@ -15,7 +15,7 @@ public:
~RoutingTrie32();
NOINLINE Value Lookup(uint32 ip);
NOINLINE Value LookupExact(uint32 ip, int cidr);
- bool Insert(uint32 ip, int cidr, Value value);
+ bool Insert(uint32 ip, int cidr, Value *value);
bool Delete(uint32 ip, int cidr);
private:
@@ -31,7 +31,7 @@ private:
static void PutChild(Node *pn, uint32 i, Node *n);
static void ReplaceChild(Node **pnp, Node *n);
static Node *ConvertOleafToLeaf(Node *pn, uint32 i, Node *n);
- static bool InsertLeafInto(Node **n, uint8 leaf_pos, RoutingTrie32::Value value);
+ static bool InsertLeafInto(Node **n, uint8 leaf_pos, RoutingTrie32::Value *value);
};
@@ -43,8 +43,8 @@ public:
~IpToPeerMap();
// Inserts an IP address of a given CIDR length into the lookup table, pointing to peer.
- bool InsertV4(uint32 ip, int cidr, void *peer);
- bool InsertV6(const void *addr, int cidr, void *peer);
+ void *InsertV4(uint32 ip, int cidr, void *peer);
+ void *InsertV6(const void *addr, int cidr, void *peer);
// Lookup the peer matching the IP Address
void *LookupV4(uint32 ip);
@@ -53,8 +53,8 @@ public:
void *LookupV4DefaultPeer();
void *LookupV6DefaultPeer();
- // Remove a peer from the table
- void RemovePeer(void *peer);
+ void RemoveV4(uint32 ip, int cidr);
+ void RemoveV6(const void *addr, int cidr);
private:
struct Entry6 {
uint8 ip[16];
diff --git a/netapi.h b/netapi.h
index dac0b0a..2bb455e 100644
--- a/netapi.h
+++ b/netapi.h
@@ -18,7 +18,6 @@
#pragma warning (disable: 4200)
-void OsGetRandomBytes(uint8 *dst, size_t dst_size);
uint64 OsGetMilliseconds();
void OsGetTimestampTAI64N(uint8 dst[12]);
void OsInterruptibleSleep(int millis);
@@ -127,13 +126,13 @@ public:
uint8 neighbor_discovery_spoofing_mac[6];
};
- virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) = 0;
+ virtual bool Configure(const TunConfig &&config, TunConfigOut *out) = 0;
virtual void WriteTunPacket(Packet *packet) = 0;
};
class UdpInterface {
public:
- virtual bool Initialize(int listen_port) = 0;
+ virtual bool Configure(int listen_port) = 0;
virtual void WriteUdpPacket(Packet *packet) = 0;
};
diff --git a/network_bsd.cpp b/network_bsd.cpp
index 40e987c..547e3db 100644
--- a/network_bsd.cpp
+++ b/network_bsd.cpp
@@ -13,10 +13,12 @@
#include
#include
#include
+#include
#include
#include
#include
#include
+#include
static Packet *freelist;
@@ -49,6 +51,7 @@ void FreePackets() {
}
}
+
class TunsafeBackendBsdImpl : public TunsafeBackendBsd {
public:
TunsafeBackendBsdImpl();
@@ -61,7 +64,7 @@ public:
virtual void WriteTunPacket(Packet *packet) override;
// -- from UdpInterface
- virtual bool Initialize(int listen_port) override;
+ virtual bool Configure(int listen_port) override;
virtual void WriteUdpPacket(Packet *packet) override;
virtual void HandleSigAlrm() override { got_sig_alarm_ = true; }
@@ -72,12 +75,19 @@ private:
bool ReadFromTun();
bool WriteToUdp();
bool WriteToTun();
+ bool InitializeUnixDomainSocket(const char *devname);
+ // Exists for the unix domain sockets
+ struct SockInfo {
+ bool is_listener;
+
+ std::string inbuf, outbuf;
+ };
+ bool HandleSpecialPollfd(struct pollfd *pollfd, struct SockInfo *sockinfo);
+ void CloseSpecialPollfd(size_t i);
void SetUdpFd(int fd);
void SetTunFd(int fd);
- inline void RecomputeMaxFd() { max_fd_ = ((tun_fd_>udp_fd_) ? tun_fd_ : udp_fd_) + 1; }
- int tun_fd_, udp_fd_, max_fd_;
bool got_sig_alarm_;
bool exit_;
@@ -89,13 +99,25 @@ private:
Packet *read_packet_;
- fd_set readfds_, writefds_;
+ enum {
+ kMaxPollFd = 5,
+ kPollFdTun = 0,
+ kPollFdUdp = 1,
+ kPollFdUnix = 2,
+ };
+
+ unsigned int pollfd_num_;
+ struct pollfd pollfd_[kMaxPollFd];
+
+ struct SockInfo sockinfo_[kMaxPollFd - 2];
+
+ struct sockaddr_un un_addr_;
+
+ UnixSocketDeletionWatcher un_deletion_watcher_;
};
TunsafeBackendBsdImpl::TunsafeBackendBsdImpl()
- : tun_fd_(-1),
- udp_fd_(-1),
- tun_readable_(false),
+ : tun_readable_(false),
tun_writable_(false),
udp_readable_(false),
udp_writable_(false),
@@ -106,35 +128,39 @@ TunsafeBackendBsdImpl::TunsafeBackendBsdImpl()
udp_queue_(NULL),
udp_queue_end_(&udp_queue_),
read_packet_(NULL) {
- RecomputeMaxFd();
-
- FD_ZERO(&readfds_);
- FD_ZERO(&writefds_);
read_packet_ = AllocPacket();
+ for(size_t i = 0; i < kMaxPollFd; i++)
+ pollfd_[i].fd = -1;
+ pollfd_num_ = 3;
+ sockinfo_[0].is_listener = true;
+ memset(&un_addr_, 0, sizeof(un_addr_));
}
TunsafeBackendBsdImpl::~TunsafeBackendBsdImpl() {
+ if (un_addr_.sun_path[0])
+ unlink(un_addr_.sun_path);
if (read_packet_)
FreePacket(read_packet_);
+ for(size_t i = 0; i < pollfd_num_; i++)
+ close(pollfd_[i].fd);
}
void TunsafeBackendBsdImpl::SetUdpFd(int fd) {
- udp_fd_ = fd;
- RecomputeMaxFd();
+ pollfd_[kPollFdUdp].fd = fd;
+ pollfd_[kPollFdUdp].events = POLLIN;
udp_writable_ = true;
}
void TunsafeBackendBsdImpl::SetTunFd(int fd) {
- tun_fd_ = fd;
- RecomputeMaxFd();
+ pollfd_[kPollFdTun].fd = fd;
+ pollfd_[kPollFdTun].events = POLLIN;
tun_writable_ = true;
}
-
bool TunsafeBackendBsdImpl::ReadFromUdp(bool overload) {
socklen_t sin_len;
sin_len = sizeof(read_packet_->addr.sin);
- int r = recvfrom(udp_fd_, read_packet_->data, kPacketCapacity, 0,
+ int r = recvfrom(pollfd_[kPollFdUdp].fd, read_packet_->data, kPacketCapacity, 0,
(sockaddr*)&read_packet_->addr.sin, &sin_len);
if (r >= 0) {
// printf("Read %d bytes from UDP\n", r);
@@ -157,11 +183,12 @@ bool TunsafeBackendBsdImpl::ReadFromUdp(bool overload) {
bool TunsafeBackendBsdImpl::WriteToUdp() {
assert(udp_writable_);
// RINFO("Send %d bytes to %s", (int)udp_queue_->size, inet_ntoa(udp_queue_->sin.sin_addr));
- int r = sendto(udp_fd_, udp_queue_->data, udp_queue_->size, 0,
+ int r = sendto(pollfd_[kPollFdUdp].fd, udp_queue_->data, udp_queue_->size, 0,
(sockaddr*)&udp_queue_->addr.sin, sizeof(udp_queue_->addr.sin));
if (r < 0) {
if (errno == EAGAIN) {
udp_writable_ = false;
+ pollfd_[kPollFdUdp].events = POLLIN | POLLOUT;
return false;
}
perror("Write to UDP failed");
@@ -185,7 +212,7 @@ static inline bool IsCompatibleProto(uint32 v) {
bool TunsafeBackendBsdImpl::ReadFromTun() {
assert(tun_readable_);
Packet *packet = read_packet_;
- int r = read(tun_fd_, packet->data - TUN_PREFIX_BYTES, kPacketCapacity + TUN_PREFIX_BYTES);
+ int r = read(pollfd_[kPollFdTun].fd, packet->data - TUN_PREFIX_BYTES, kPacketCapacity + TUN_PREFIX_BYTES);
if (r >= 0) {
// printf("Read %d bytes from TUN\n", r);
packet->size = r - TUN_PREFIX_BYTES;
@@ -215,10 +242,11 @@ bool TunsafeBackendBsdImpl::WriteToTun() {
if (TUN_PREFIX_BYTES) {
WriteBE32(tun_queue_->data - TUN_PREFIX_BYTES, GetProtoFromPacket(tun_queue_->data, tun_queue_->size));
}
- int r = write(tun_fd_, tun_queue_->data - TUN_PREFIX_BYTES, tun_queue_->size + TUN_PREFIX_BYTES);
+ int r = write(pollfd_[kPollFdTun].fd, tun_queue_->data - TUN_PREFIX_BYTES, tun_queue_->size + TUN_PREFIX_BYTES);
if (r < 0) {
if (errno == EAGAIN) {
tun_writable_ = false;
+ pollfd_[kPollFdTun].events = POLLIN | POLLOUT;
return false;
}
RERROR("Write to tun failed");
@@ -242,11 +270,13 @@ bool TunsafeBackendBsdImpl::InitializeTun(char devname[16]) {
fcntl(tun_fd, F_SETFD, FD_CLOEXEC);
fcntl(tun_fd, F_SETFL, O_NONBLOCK);
SetTunFd(tun_fd);
+
+ InitializeUnixDomainSocket(devname);
return true;
}
void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) override {
- assert(tun_fd_ >= 0);
+ assert(pollfd_[kPollFdTun].fd >= 0);
Packet *queue_is_used = tun_queue_;
*tun_queue_end_ = packet;
tun_queue_end_ = &packet->next;
@@ -256,7 +286,7 @@ void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) override {
}
// Called to initialize udp
-bool TunsafeBackendBsdImpl::Initialize(int listen_port) override {
+bool TunsafeBackendBsdImpl::Configure(int listen_port) override {
int udp_fd = open_udp(listen_port);
if (udp_fd < 0) { RERROR("Error opening udp"); return false; }
fcntl(udp_fd, F_SETFD, FD_CLOEXEC);
@@ -266,7 +296,7 @@ bool TunsafeBackendBsdImpl::Initialize(int listen_port) override {
}
void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) override {
- assert(udp_fd_ >= 0);
+ assert(pollfd_[kPollFdUdp].fd >= 0);
Packet *queue_is_used = udp_queue_;
*udp_queue_end_ = packet;
udp_queue_end_ = &packet->next;
@@ -275,16 +305,137 @@ void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) override {
WriteToUdp();
}
+bool TunsafeBackendBsdImpl::InitializeUnixDomainSocket(const char *devname) {
+ int fd = socket(AF_UNIX, SOCK_STREAM, 0);
+ if (fd == -1) {
+ RERROR("Error creating unix domain socket");
+ return false;
+ }
+
+ fcntl(fd, F_SETFD, FD_CLOEXEC);
+ fcntl(fd, F_SETFL, O_NONBLOCK);
+
+ mkdir("/var/run/wireguard", 0755);
+ un_addr_.sun_family = AF_UNIX;
+ snprintf(un_addr_.sun_path, sizeof(un_addr_.sun_path), "/var/run/wireguard/%s.sock", devname);
+ unlink(un_addr_.sun_path);
+ if (bind(fd, (struct sockaddr*)&un_addr_, sizeof(un_addr_)) == -1) {
+ RERROR("Error binding unix domain socket");
+ close(fd);
+ return false;
+ }
+ if (listen(fd, 5) == -1) {
+ RERROR("Error listening on unix domain socket");
+ close(fd);
+ return false;
+ }
+
+ pollfd_[kPollFdUnix].fd = fd;
+ pollfd_[kPollFdUnix].events = POLLIN;
+
+ return true;
+}
+
+static const char *FindMessageEnd(const char *start, size_t size) {
+ if (size <= 1)
+ return NULL;
+ const char *start_end = start + size - 1;
+ for(;(start = (const char*)memchr(start, '\n', start_end - start)) != NULL; start++) {
+ if (start[1] == '\n')
+ return start + 2;
+ }
+ return NULL;
+}
+
+bool TunsafeBackendBsdImpl::HandleSpecialPollfd(struct pollfd *pfd, struct SockInfo *sockinfo) {
+ // handle domain socket thing
+ if (sockinfo->is_listener) {
+ if (pfd->revents & POLLIN) {
+ // wait if we can't allocate more pollfd
+ if (pollfd_num_ == kMaxPollFd) {
+ pfd->events = 0;
+ return true;
+ }
+ int fd = accept(pfd->fd, NULL, NULL);
+ if (fd >= 0) {
+ size_t slot = pollfd_num_++;
+ pollfd_[slot].fd = fd;
+ pollfd_[slot].events = POLLIN;
+ pollfd_[slot].revents = 0;
+ sockinfo_[slot - 2].is_listener = false;
+ } else {
+ RERROR("Unix domain socket accept failed");
+ }
+ }
+ if (pfd->revents & ~POLLIN) {
+ RERROR("Unix domain socket got an error code");
+ return false;
+ }
+ return true;
+ }
+ if (pfd->revents & POLLIN) {
+ char buf[4096];
+ // read as much data as we can until we see \n\n
+ ssize_t n = recv(pfd->fd, buf, sizeof(buf), 0);
+ if (n <= 0)
+ return (n == -1 && errno == EAGAIN); // premature eof or error
+ sockinfo->inbuf.append(buf, n);
+ const char *message_end = FindMessageEnd(&sockinfo->inbuf[0], sockinfo->inbuf.size());
+ if (message_end) {
+ if (message_end != &sockinfo->inbuf[sockinfo->inbuf.size()])
+ return false; // trailing data?
+ WgConfig::HandleConfigurationProtocolMessage(processor_, std::move(sockinfo->inbuf), &sockinfo->outbuf);
+ if (!sockinfo->outbuf.size())
+ return false;
+ pfd->revents = pfd->events = POLLOUT;
+ }
+ }
+ if (pfd->revents & POLLOUT) {
+ size_t n = send(pfd->fd, sockinfo->outbuf.data(), sockinfo->outbuf.size(), 0);
+ if (n <= 0)
+ return (n == -1 && errno == EAGAIN); // premature eof or error
+ sockinfo->outbuf.erase(0, n);
+ if (!sockinfo->outbuf.size())
+ return false;
+ }
+
+ if (pfd->revents & ~(POLLIN | POLLOUT)) {
+ RERROR("Unix domain socket got an error code");
+ return false;
+ }
+ return true;
+}
+
+void TunsafeBackendBsdImpl::CloseSpecialPollfd(size_t i) {
+ close(pollfd_[i].fd);
+ pollfd_[i].fd = -1;
+ sockinfo_[i - 2].inbuf.clear();
+ sockinfo_[i - 2].outbuf.clear();
+ pollfd_[i] = pollfd_[(size_t)pollfd_num_ - 1];
+ std::swap(sockinfo_[i - 2], sockinfo_[(size_t)pollfd_num_ - 1 - 2]);
+
+ // Can now allow more sockets?
+ if (pollfd_num_-- == kMaxPollFd && sockinfo_[kPollFdUnix - 2].is_listener)
+ pollfd_[kPollFdUnix].events = POLLIN;
+}
+
void TunsafeBackendBsdImpl::RunLoopInner() {
int free_packet_interval = 10;
int overload_ctr = 0;
+ if (!un_deletion_watcher_.Start(un_addr_.sun_path, &exit_))
+ return;
+
while (!exit_) {
int n = -1;
- // This is not fully signal safe.
if (got_sig_alarm_) {
got_sig_alarm_ = false;
+
+ if (un_deletion_watcher_.Poll(un_addr_.sun_path)) {
+ RINFO("Unix socket %s deleted.", un_addr_.sun_path);
+ break;
+ }
processor_->SecondLoop();
if (free_packet_interval == 0) {
@@ -296,33 +447,53 @@ void TunsafeBackendBsdImpl::RunLoopInner() {
overload_ctr -= (overload_ctr != 0);
}
- if (tun_fd_ >= 0) {
- FD_SET(tun_fd_, &readfds_);
- if (tun_writable_) FD_CLR(tun_fd_, &writefds_); else FD_SET(tun_fd_, &writefds_);
- }
-
- if (udp_fd_ >= 0) {
- FD_SET(udp_fd_, &readfds_);
- if (udp_writable_) FD_CLR(udp_fd_, &writefds_); else FD_SET(udp_fd_, &writefds_);
- }
-
- n = select(max_fd_, &readfds_, &writefds_, NULL, NULL);
+#if defined(OS_LINUX) || defined(OS_FREEBSD)
+ n = ppoll(pollfd_, pollfd_num_, NULL, &orig_signal_mask_);
+#else
+ n = poll(pollfd_, pollfd_num_, -1);
+#endif
if (n == -1) {
if (errno != EINTR) {
- fprintf(stderr, "select failed\n");
+ RERROR("poll failed");
break;
}
} else {
- if (tun_fd_ >= 0) {
- tun_readable_ = (FD_ISSET(tun_fd_, &readfds_) != 0);
- tun_writable_ |= (FD_ISSET(tun_fd_, &writefds_) != 0);
+
+ if (pollfd_[kPollFdTun].revents & (POLLERR | POLLHUP | POLLNVAL)) {
+ if (pollfd_[kPollFdTun].revents & POLLERR) {
+ tun_interface_gone_ = true;
+ RERROR("Tun interface gone, closing.");
+ } else {
+ RERROR("Tun interface error %d, closing.", pollfd_[kPollFdTun].revents);
+ }
+ break;
}
- if (udp_fd_ >= 0) {
- udp_readable_ = (FD_ISSET(udp_fd_, &readfds_) != 0);
- udp_writable_ |= (FD_ISSET(udp_fd_, &writefds_) != 0);
+ tun_readable_ = (pollfd_[kPollFdTun].revents & POLLIN) != 0;
+ if (pollfd_[kPollFdTun].revents & POLLOUT) {
+ pollfd_[kPollFdTun].events = POLLIN;
+ tun_writable_ = true;
+ }
+
+ if (pollfd_[kPollFdUdp].revents & (POLLERR | POLLHUP | POLLNVAL)) {
+ RERROR("UDP error %d, closing.", pollfd_[kPollFdUdp].revents);
+ break;
+ }
+
+ udp_readable_ = (pollfd_[kPollFdUdp].revents & POLLIN) != 0;
+ if (pollfd_[kPollFdUdp].revents & POLLOUT) {
+ pollfd_[kPollFdUdp].events = POLLIN;
+ udp_writable_ = true;
+ }
+
+ for(size_t i = 2; i < pollfd_num_; i++) {
+ if (pollfd_[i].revents && !HandleSpecialPollfd(&pollfd_[i], &sockinfo_[i - 2])) {
+ // Close the fd / discard the sockinfo
+ CloseSpecialPollfd(i);
+ i--;
+ }
}
}
-
+
bool overload = (overload_ctr != 0);
for(int loop = 0; ; loop++) {
@@ -342,6 +513,8 @@ void TunsafeBackendBsdImpl::RunLoopInner() {
processor_->RunAllMainThreadScheduled();
}
+
+ un_deletion_watcher_.Stop();
}
TunsafeBackendBsd *CreateTunsafeBackendBsd() {
diff --git a/network_bsd_common.cpp b/network_bsd_common.cpp
index 161f70f..f0a59fa 100644
--- a/network_bsd_common.cpp
+++ b/network_bsd_common.cpp
@@ -39,6 +39,8 @@
#include
#include
#include
+#include
+#include
#endif
void tunsafe_die(const char *msg) {
@@ -286,15 +288,6 @@ void OsGetTimestampTAI64N(uint8 dst[12]) {
WriteBE32(dst + 8, nanos);
}
-void OsGetRandomBytes(uint8 *data, size_t data_size) {
- int fd = open("/dev/urandom", O_RDONLY);
- int r = read(fd, data, data_size);
- if (r < 0) r = 0;
- close(fd);
- for (; r < data_size; r++)
- data[r] = rand() >> 6;
-}
-
void OsInterruptibleSleep(int millis) {
usleep((useconds_t)millis * 1000);
}
@@ -387,11 +380,12 @@ int open_tun(char *devname, size_t devname_size) {
memset(&ifr, 0, sizeof(ifr));
ifr.ifr_flags = IFF_TUN | IFF_NO_PI;
+ my_strlcpy(ifr.ifr_name, sizeof(ifr.ifr_name), devname);
if ((err = ioctl(fd, TUNSETIFF, (void *) &ifr)) < 0) {
close(fd);
return err;
}
- strcpy(devname, ifr.ifr_name);
+ my_strlcpy(devname, devname_size, ifr.ifr_name);
return fd;
}
#endif
@@ -411,6 +405,8 @@ int open_udp(int listen_on_port) {
TunsafeBackendBsd::TunsafeBackendBsd()
: processor_(NULL) {
+ devname_[0] = 0;
+ tun_interface_gone_ = false;
}
TunsafeBackendBsd::~TunsafeBackendBsd() {
@@ -495,10 +491,10 @@ void TunsafeBackendBsd::DelRoute(const RouteInfo &cd) {
static bool IsIpv6AddressSet(const void *p) {
return (ReadLE64(p) | ReadLE64((char*)p + 8)) != 0;
}
-
+
// Called to initialize tun
-bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out) override {
- char devname[16];
+bool TunsafeBackendBsd::Configure(const TunConfig &&config, TunConfigOut *out) override {
+ char buf[kSizeOfAddress];
if (!RunPrePostCommand(config.pre_post_commands.pre_up)) {
RERROR("Pre command failed!");
@@ -507,17 +503,35 @@ bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out)
out->enable_neighbor_discovery_spoofing = false;
- if (!InitializeTun(devname))
+ if (!InitializeTun(devname_))
return false;
- if (config.ipv6_cidr)
- RERROR("IPv6 not supported");
-
uint32 netmask = CidrToNetmaskV4(config.cidr);
uint32 default_route_v4 = ComputeIpv4DefaultRoute(config.ip, netmask);
-
- RunCommand("/sbin/ifconfig %s %A mtu %d %A netmask %A up", devname, config.ip, config.mtu, config.ip, netmask);
- AddRoute(config.ip & netmask, config.cidr, config.ip, devname);
+
+
+#if defined(OS_LINUX)
+ if (config.ip) {
+ char ip[4];
+ WriteBE32(ip, config.ip);
+ RunCommand("/sbin/ip address add dev %s %s", devname_, print_ip_prefix(buf, AF_INET, ip, config.cidr));
+ }
+ if (config.ipv6_cidr) {
+ RunCommand("/sbin/ip address add dev %s %s", devname_, print_ip_prefix(buf, AF_INET6, config.ipv6_address, config.ipv6_cidr));
+ }
+ RunCommand("/sbin/ip link set dev %s mtu %d up", devname_, config.mtu);
+#else // !defined(OS_LINUX)
+ if (config.ip) {
+ RunCommand("/sbin/ifconfig %s %A mtu %d %A netmask %A up", devname_, config.ip, config.mtu, config.ip, netmask);
+ }
+ if (config.ipv6_cidr) {
+ RunCommand("/sbin/ifconfig %s inet6 add %s", devname_, print_ip_prefix(buf, AF_INET6, config.ipv6_address, config.ipv6_cidr));
+ }
+#endif // !defined(OS_LINUX)
+
+ if (config.ip) {
+ AddRoute(config.ip & netmask, config.cidr, config.ip, devname_);
+ }
if (config.use_ipv4_default_route) {
if (config.default_route_endpoint_v4) {
@@ -533,35 +547,30 @@ bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out)
AddRoute(ReadBE32(it->addr), it->cidr, ipv4_default_gw, default_iface);
}
}
- AddRoute(0x00000000, 1, default_route_v4, devname);
- AddRoute(0x80000000, 1, default_route_v4, devname);
+ AddRoute(0x00000000, 1, default_route_v4, devname_);
+ AddRoute(0x80000000, 1, default_route_v4, devname_);
}
uint8 default_route_v6[16];
if (config.ipv6_cidr) {
static const uint8 matchall_1_route[17] = {0x80, 0, 0, 0};
- char buf[kSizeOfAddress];
-
ComputeIpv6DefaultRoute(config.ipv6_address, config.ipv6_cidr, default_route_v6);
-
- RunCommand("/sbin/ifconfig %s inet6 add %s", devname, print_ip_prefix(buf, AF_INET6, config.ipv6_address, config.ipv6_cidr));
-
if (config.use_ipv6_default_route) {
if (IsIpv6AddressSet(config.default_route_endpoint_v6)) {
RERROR("default_route_endpoint_v6 not supported");
}
- AddRoute(AF_INET6, matchall_1_route + 1, 1, default_route_v6, devname);
- AddRoute(AF_INET6, matchall_1_route + 0, 1, default_route_v6, devname);
+ AddRoute(AF_INET6, matchall_1_route + 1, 1, default_route_v6, devname_);
+ AddRoute(AF_INET6, matchall_1_route + 0, 1, default_route_v6, devname_);
}
}
// Add all the extra routes
for (auto it = config.extra_routes.begin(); it != config.extra_routes.end(); ++it) {
if (it->size == 32) {
- AddRoute(ReadBE32(it->addr), it->cidr, default_route_v4, devname);
+ AddRoute(ReadBE32(it->addr), it->cidr, default_route_v4, devname_);
} else if (it->size == 128 && config.ipv6_cidr) {
- AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6, devname);
+ AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6, devname_);
}
}
@@ -576,8 +585,10 @@ bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out)
void TunsafeBackendBsd::CleanupRoutes() {
RunPrePostCommand(pre_down_);
- for(auto it = cleanup_commands_.begin(); it != cleanup_commands_.end(); ++it)
- DelRoute(*it);
+ for(auto it = cleanup_commands_.begin(); it != cleanup_commands_.end(); ++it) {
+ if (!tun_interface_gone_ || strcmp(it->dev.c_str(), devname_) != 0)
+ DelRoute(*it);
+ }
cleanup_commands_.clear();
RunPrePostCommand(post_down_);
@@ -586,6 +597,10 @@ void TunsafeBackendBsd::CleanupRoutes() {
post_down_.clear();
}
+void TunsafeBackendBsd::SetTunDeviceName(const char *name) {
+ my_strlcpy(devname_, sizeof(devname_), name);
+}
+
static bool RunOneCommand(const std::string &cmd) {
RINFO("Run: %s", cmd.c_str());
int exit_code = system(cmd.c_str());
@@ -604,6 +619,94 @@ bool TunsafeBackendBsd::RunPrePostCommand(const std::vector &vec) {
return success;
}
+#if defined(OS_LINUX)
+UnixSocketDeletionWatcher::UnixSocketDeletionWatcher()
+ : inotify_fd_(-1) {
+ pipes_[0] = -1;
+ pipes_[0] = -1;
+}
+
+UnixSocketDeletionWatcher::~UnixSocketDeletionWatcher() {
+ close(inotify_fd_);
+ close(pipes_[0]);
+ close(pipes_[1]);
+}
+
+bool UnixSocketDeletionWatcher::Start(const char *path, bool *flag_to_set) {
+ assert(inotify_fd_ == -1);
+ path_ = path;
+ flag_to_set_ = flag_to_set;
+ pid_ = getpid();
+ inotify_fd_ = inotify_init1(IN_CLOEXEC | IN_NONBLOCK);
+ if (inotify_fd_ == -1) {
+ perror("inotify_init1() failed");
+ return false;
+ }
+ if (inotify_add_watch(inotify_fd_, "/var/run/wireguard", IN_DELETE | IN_DELETE_SELF) == -1) {
+ perror("inotify_add_watch failed");
+ return false;
+ }
+ if (pipe(pipes_) == -1) {
+ perror("pipe() failed");
+ return false;
+ }
+ return pthread_create(&thread_, NULL, &UnixSocketDeletionWatcher::RunThread, this) == 0;
+}
+
+void UnixSocketDeletionWatcher::Stop() {
+ RINFO("Stopping..");
+ void *retval;
+ write(pipes_[1], "", 1);
+ pthread_join(thread_, &retval);
+}
+
+void *UnixSocketDeletionWatcher::RunThread(void *arg) {
+ UnixSocketDeletionWatcher *self = (UnixSocketDeletionWatcher*)arg;
+ return self->RunThreadInner();
+}
+
+void *UnixSocketDeletionWatcher::RunThreadInner() {
+ char buf[sizeof(struct inotify_event) + NAME_MAX + 1]
+ __attribute__ ((aligned(__alignof__(struct inotify_event))));
+ fd_set fdset;
+ struct stat st;
+ for(;;) {
+ if (lstat(path_, &st) == -1 && errno == ENOENT) {
+ RINFO("Unix socket %s deleted.", path_);
+ *flag_to_set_ = true;
+ kill(pid_, SIGALRM);
+ break;
+ }
+ FD_ZERO(&fdset);
+ FD_SET(inotify_fd_, &fdset);
+ FD_SET(pipes_[0], &fdset);
+ int n = select(std::max(inotify_fd_, pipes_[0]) + 1, &fdset, NULL, NULL, NULL);
+ if (n == -1) {
+ perror("select");
+ break;
+ }
+ if (FD_ISSET(inotify_fd_, &fdset)) {
+ ssize_t len = read(inotify_fd_, buf, sizeof(buf));
+ if (len == -1) {
+ perror("read");
+ break;
+ }
+ }
+ if (FD_ISSET(pipes_[0], &fdset))
+ break;
+ }
+ return NULL;
+}
+
+#else // !defined(OS_LINUX)
+
+bool UnixSocketDeletionWatcher::Poll(const char *path) {
+ struct stat st;
+ return lstat(path, &st) == -1 && errno == ENOENT;
+}
+
+#endif // !defined(OS_LINUX)
+
static TunsafeBackendBsd *g_tunsafe_backend_bsd;
static void SigAlrm(int sig) {
@@ -611,10 +714,6 @@ static void SigAlrm(int sig) {
g_tunsafe_backend_bsd->HandleSigAlrm();
}
-static void SigUsr1(int sig) {
-
-}
-
static bool did_ctrlc;
void SigInt(int sig) {
@@ -623,6 +722,7 @@ void SigInt(int sig) {
did_ctrlc = true;
write(1, "Ctrl-C detected. Exiting. Press again to force quit.\n", sizeof("Ctrl-C detected. Exiting. Press again to force quit.\n")-1);
+ // todo: fix signal safety?
if (g_tunsafe_backend_bsd)
g_tunsafe_backend_bsd->HandleExit();
}
@@ -631,7 +731,10 @@ void TunsafeBackendBsd::RunLoop() {
assert(!g_tunsafe_backend_bsd);
assert(processor_);
+ sigset_t mask;
+
g_tunsafe_backend_bsd = this;
+
// We want an alarm signal every second.
{
struct sigaction act = {0};
@@ -651,16 +754,14 @@ void TunsafeBackendBsd::RunLoop() {
}
}
- {
- struct sigaction act = {0};
- act.sa_handler = SigUsr1;
- if (sigaction(SIGUSR1, &act, NULL) < 0) {
- RERROR("Unable to install SIGUSR1 handler.");
- return;
- }
+#if defined(OS_LINUX) || defined(OS_FREEBSD)
+ sigemptyset(&mask);
+ sigaddset(&mask, SIGALRM);
+ if (sigprocmask(SIG_BLOCK, &mask, &orig_signal_mask_) < 0) {
+ perror("sigprocmask");
+ return;
}
-#if defined(OS_LINUX) || defined(OS_FREEBSD)
{
struct itimerspec tv = {0};
struct sigevent sev;
@@ -727,7 +828,17 @@ public:
bool is_connected_;
};
+struct CommandLineOutput {
+ const char *filename_to_load;
+ const char *interface_name;
+ bool daemon;
+};
+
+int HandleCommandLine(int argc, char **argv, CommandLineOutput *output);
+
int main(int argc, char **argv) {
+ CommandLineOutput cmd = {0};
+
InitCpuFeatures();
if (argc == 2 && strcmp(argv[1], "--benchmark") == 0) {
@@ -735,12 +846,9 @@ int main(int argc, char **argv) {
return 0;
}
- fprintf(stderr, "%s\n", TUNSAFE_VERSION_STRING);
-
- if (argc < 2) {
- fprintf(stderr, "Syntax: tunsafe file.conf\n");
- return 1;
- }
+ int rv = HandleCommandLine(argc, argv, &cmd);
+ if (!cmd.filename_to_load)
+ return rv;
#if defined(OS_MACOSX)
InitOsxGetMilliseconds();
@@ -749,19 +857,29 @@ int main(int argc, char **argv) {
SetThreadName("tunsafe-m");
MyProcessorDelegate my_procdel;
- TunsafeBackendBsd *socket_loop = CreateTunsafeBackendBsd();
- WireguardProcessor wg(socket_loop, socket_loop, &my_procdel);
+ TunsafeBackendBsd *backend = CreateTunsafeBackendBsd();
+ if (cmd.interface_name)
+ backend->SetTunDeviceName(cmd.interface_name);
+
+ WireguardProcessor wg(backend, backend, &my_procdel);
my_procdel.wg_processor_ = &wg;
- socket_loop->SetProcessor(&wg);
+ backend->SetProcessor(&wg);
DnsResolver dns_resolver(NULL);
- if (!ParseWireGuardConfigFile(&wg, argv[1], &dns_resolver)) return 1;
+ if (*cmd.filename_to_load && !ParseWireGuardConfigFile(&wg, cmd.filename_to_load, &dns_resolver))
+ return 1;
if (!wg.Start()) return 1;
- socket_loop->RunLoop();
- socket_loop->CleanupRoutes();
- delete socket_loop;
+ if (cmd.daemon) {
+ fprintf(stderr, "Switching to daemon mode...\n");
+ if (daemon(0, 0) == -1)
+ perror("daemon() failed");
+ }
+
+ backend->RunLoop();
+ backend->CleanupRoutes();
+ delete backend;
return 0;
}
diff --git a/network_bsd_common.h b/network_bsd_common.h
index cc14bc6..36d566c 100644
--- a/network_bsd_common.h
+++ b/network_bsd_common.h
@@ -7,6 +7,7 @@
#include "wireguard.h"
#include "wireguard_config.h"
#include
+#include
struct RouteInfo {
uint8 family;
@@ -16,6 +17,39 @@ struct RouteInfo {
std::string dev;
};
+#if defined(OS_LINUX)
+// Keeps track of when the unix socket gets deleted
+class UnixSocketDeletionWatcher {
+public:
+ UnixSocketDeletionWatcher();
+ ~UnixSocketDeletionWatcher();
+ bool Start(const char *path, bool *flag_to_set);
+ void Stop();
+ bool Poll(const char *path) { return false; }
+
+private:
+ static void *RunThread(void *arg);
+ void *RunThreadInner();
+ const char *path_;
+ int inotify_fd_;
+ int pid_;
+ int pipes_[2];
+ pthread_t thread_;
+ bool *flag_to_set_;
+};
+#else // !defined(OS_LINUX)
+// all other platforms that lack inotify
+class UnixSocketDeletionWatcher {
+public:
+ UnixSocketDeletionWatcher() {}
+ ~UnixSocketDeletionWatcher() {}
+ bool Start(const char *path, bool *flag_to_set) { return true; }
+ void Stop() {}
+ bool Poll(const char *path);
+};
+#endif // !defined(OS_LINUX)
+
+
class TunsafeBackendBsd : public TunInterface, public UdpInterface {
public:
TunsafeBackendBsd();
@@ -24,10 +58,12 @@ public:
void RunLoop();
void CleanupRoutes();
+ void SetTunDeviceName(const char *name);
+
void SetProcessor(WireguardProcessor *wg) { processor_ = wg; }
// -- from TunInterface
- virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) override;
+ virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
virtual void HandleSigAlrm() = 0;
virtual void HandleExit() = 0;
@@ -44,6 +80,9 @@ protected:
WireguardProcessor *processor_;
std::vector cleanup_commands_;
std::vector pre_down_, post_down_;
+ sigset_t orig_signal_mask_;
+ char devname_[16];
+ bool tun_interface_gone_;
};
#if defined(OS_MACOSX) || defined(OS_FREEBSD)
diff --git a/network_win32.cpp b/network_win32.cpp
index 2b9c680..5bc7f01 100644
--- a/network_win32.cpp
+++ b/network_win32.cpp
@@ -60,20 +60,6 @@ static bool IsIpv6AddressSet(const void *p) {
return (ReadLE64(p) | ReadLE64((char*)p + 8)) != 0;
}
-void OsGetRandomBytes(uint8 *data, size_t data_size) {
- static BOOLEAN(APIENTRY *pfn)(void*, ULONG);
- static bool resolved;
- if (!resolved) {
- pfn = (BOOLEAN(APIENTRY *)(void*, ULONG))GetProcAddress(LoadLibrary("ADVAPI32.DLL"), "SystemFunction036");
- resolved = true;
- }
- if (pfn && pfn(data, (ULONG)data_size))
- return;
- size_t r = 0;
- for (; r < data_size; r++)
- data[r] = rand() >> 6;
-}
-
void OsInterruptibleSleep(int millis) {
SleepEx(millis, TRUE);
}
@@ -140,6 +126,7 @@ struct {
#define kConcurrentWriteTap 16
#define kAdapterKeyName "SYSTEM\\CurrentControlSet\\Control\\Class\\{4D36E972-E325-11CE-BFC1-08002BE10318}"
+#define kNetworkConnectionsKeyName "SYSTEM\\CurrentControlSet\\Control\\Network\\{4D36E972-E325-11CE-BFC1-08002BE10318}"
#define kTapComponentId "tap0901"
#define TAP_CONTROL_CODE(request,method) \
@@ -196,91 +183,79 @@ static bool RunNetsh(const char *cmdline) {
return result;
}
-// Retrieve the device path to the TAP adapter.
-static bool GetTapAdapterGuid(char guid[64]) {
- LONG err;
- HKEY adapter_key, device_key;
- bool retval = false;
- err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, kAdapterKeyName, 0, KEY_READ, &adapter_key);
- if (err != ERROR_SUCCESS) {
- RERROR("GetTapAdapterName: RegOpenKeyEx failed: 0x%X", GetLastError());
- return false;
- }
- for (int i = 0; !retval; i++) {
- char keyname[64 + sizeof(kAdapterKeyName) + 1];
- char value[64];
- DWORD len = sizeof(value), type;
- err = RegEnumKeyEx(adapter_key, i, value, &len, NULL, NULL, NULL, NULL);
- if (err == ERROR_NO_MORE_ITEMS)
- break;
- if (err != ERROR_SUCCESS) {
- RERROR("GetTapAdapterName: RegEnumKeyEx failed: 0x%X", GetLastError());
- return false;
- }
- snprintf(keyname, sizeof(keyname), "%s\\%s", kAdapterKeyName, value);
- err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, keyname, 0, KEY_READ, &device_key);
- if (err == ERROR_SUCCESS) {
- len = sizeof(value);
- err = RegQueryValueEx(device_key, "ComponentId", NULL, &type, (LPBYTE)value, &len);
- if (err == ERROR_SUCCESS && type == REG_SZ && !memcmp(value, kTapComponentId, sizeof(kTapComponentId))) {
- len = 64;
- err = RegQueryValueEx(device_key, "NetCfgInstanceId", NULL, &type, (LPBYTE)guid, &len);
- if (err == ERROR_SUCCESS && type == REG_SZ) {
- guid[63] = 0;
- retval = true;
- }
- }
- RegCloseKey(device_key);
- }
- }
- RegCloseKey(adapter_key);
- return retval;
-}
-
-// Open the TAP adapter
-static HANDLE OpenTunAdapter(char guid[64], int retry_count, uint32 *exit_thread, DWORD open_flags) {
+// Open the TAP adapter, either a random one or a specific one
+// On return, the adapter is locked in |TunAdaptersInUse|.
+static HANDLE OpenTunAdapter(char guid[ADAPTER_GUID_SIZE], TunsafeBackendWin32 *backend, DWORD open_flags) {
char path[128];
HANDLE h;
int retries = 0;
- if (!GetTapAdapterGuid(guid)) {
- RERROR("Unable to find ID of TAP adapter");
- RERROR(" Please ensure that TunSafe-TAP is properly installed.");
- return NULL;
- }
- snprintf(path, sizeof(path), "\\\\.\\Global\\%s.tap", guid);
-RETRY:
- h = CreateFile(path, GENERIC_READ | GENERIC_WRITE, 0, 0, OPEN_EXISTING,
- FILE_ATTRIBUTE_SYSTEM | open_flags, 0);
- if (h == INVALID_HANDLE_VALUE) {
- int error_code = GetLastError();
-
- // Sometimes if you close the device right before, it will fail to open with errorcode 31.
- // When resuming from sleep in my VM, the error code is ERROR_FILE_NOT_FOUND
- if ((error_code == ERROR_FILE_NOT_FOUND || error_code == ERROR_GEN_FAILURE) && retry_count != 0 && !*exit_thread) {
- RERROR("OpenTapAdapter: CreateFile failed: 0x%X... retrying", error_code);
- retry_count--;
-
- int sleep_amount = 250 * ++retries;
- for(;;) {
- if (*exit_thread)
- return NULL;
- if (sleep_amount == 0)
- break;
- Sleep(50);
- sleep_amount -= 50;
- }
- goto RETRY;
- }
-
- RERROR("OpenTapAdapter: CreateFile failed: 0x%X", error_code);
- if (error_code == ERROR_FILE_NOT_FOUND) {
+ std::vector adapters;
+
+ // When guid is empty, we try all adapters, otherwise
+ // just try the specific adapter
+ if (guid[0] == 0) {
+ GetTapAdapterInfo(&adapters);
+ if (adapters.empty()) {
+ RERROR("Unable to find any TAP adapters");
RERROR(" Please ensure that TunSafe-TAP is properly installed.");
- } else if (error_code == 0x1f) {
- RERROR(" Please ensure that the TAP device is not in use.");
+ return NULL;
}
+ } else {
+ adapters.emplace_back();
+ memcpy(adapters.back().guid, guid, ADAPTER_GUID_SIZE);
+ adapters.back().name[0] = 0;
+ }
+ TunAdaptersInUse *tun_adapters_in_use = TunAdaptersInUse::GetInstance();
+
+RETRY:
+ bool did_try_adapter = false;
+ int error_code = 0;
+ for (GuidAndDevName &x : adapters) {
+ snprintf(path, sizeof(path), "\\\\.\\Global\\%s.tap", x.guid);
+ if (tun_adapters_in_use->Acquire(x.guid, static_cast(backend))) {
+ h = CreateFile(path, GENERIC_READ | GENERIC_WRITE, 0, 0, OPEN_EXISTING, FILE_ATTRIBUTE_SYSTEM | open_flags, 0);
+ if (h != INVALID_HANDLE_VALUE) {
+ memcpy(guid, x.guid, ADAPTER_GUID_SIZE);
+ return h;
+ }
+ did_try_adapter = true;
+ error_code = GetLastError();
+ tun_adapters_in_use->Release(static_cast(backend));
+ }
+ }
+ if (!did_try_adapter) {
+ RERROR("All TAP adapters are currently in use");
return NULL;
}
- return h;
+
+ // Sometimes if you close the device right before, it will fail to open with errorcode 31.
+ // When resuming from sleep in my VM, the error code is ERROR_FILE_NOT_FOUND
+ if ((error_code == ERROR_FILE_NOT_FOUND || error_code == ERROR_GEN_FAILURE) && !backend->exit_code()) {
+ if (retries <= 10) {
+ RERROR("OpenTapAdapter: CreateFile failed: 0x%X... retrying%s", error_code, retries == 10 ? " (last notice)" : "");
+ if (retries == 10) {
+ if (error_code == ERROR_FILE_NOT_FOUND) {
+ RERROR(" Please ensure that TunSafe-TAP is properly installed.");
+ } else if (error_code == ERROR_GEN_FAILURE) {
+ RERROR(" Please ensure that the TAP device is not in use.");
+ }
+ backend->SetStatus(TunsafeBackend::kStatusTunRetrying);
+ }
+ }
+ int sleep_amount = 250 * std::min(++retries, 40);
+ for (;;) {
+ if (backend->exit_code())
+ return NULL;
+ if (sleep_amount == 0)
+ break;
+ Sleep(125);
+ sleep_amount -= 125;
+ }
+ goto RETRY;
+ }
+
+ RERROR("OpenTapAdapter: CreateFile failed: 0x%X", error_code);
+ return NULL;
}
static bool AddRoute(int family,
@@ -466,60 +441,85 @@ UdpSocketWin32::UdpSocketWin32() {
wqueue_end_ = &wqueue_;
wqueue_ = NULL;
exit_thread_ = false;
- socket_ = INVALID_SOCKET;
thread_ = NULL;
+ socket_ = INVALID_SOCKET;
socket_ipv6_ = INVALID_SOCKET;
completion_port_handle_ = NULL;
}
UdpSocketWin32::~UdpSocketWin32() {
assert(thread_ == NULL);
+ CloseHandle(completion_port_handle_);
closesocket(socket_);
closesocket(socket_ipv6_);
- CloseHandle(completion_port_handle_);
FreePacketList(wqueue_);
}
-bool UdpSocketWin32::Initialize(int listen_on_port) {
- SOCKET s = WSASocket(AF_INET, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
- if (s == INVALID_SOCKET) {
- RERROR("UdpSocketWin32::Initialize WSASocket failed");
- return false;
+bool UdpSocketWin32::Configure(int listen_on_port) {
+ // If attempting to initialize when the thread is already started, then stop
+ // the thread, reinitialize, and start the thread.
+ if (thread_ != NULL) {
+ StopThread();
+ bool retcode = Configure(listen_on_port);
+ StartThread();
+ return retcode;
}
- completion_port_handle_ = CreateIoCompletionPort((HANDLE)s, NULL, NULL, 0);
- if (!completion_port_handle_) {
- closesocket(s);
- return false;
- }
- socket_ = s;
- sockaddr_in sin = {0};
- sin.sin_family = AF_INET;
- sin.sin_port = htons(listen_on_port);
- if (bind(s, (struct sockaddr*)&sin, sizeof(sin)) != 0) {
- RERROR("UdpSocketWin32::Initialize bind failed");
- return false;
+ bool retval = false;
+ HANDLE completion_port = NULL;
+ SOCKET socket_ipv4 = INVALID_SOCKET, socket_ipv6 = INVALID_SOCKET;
+
+ socket_ipv4 = WSASocket(AF_INET, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
+ if (socket_ipv4 == INVALID_SOCKET) {
+ RERROR("UdpSocketWin32::Initialize WSASocket failed");
+ goto fail;
+ }
+ completion_port = CreateIoCompletionPort((HANDLE)socket_ipv4, NULL, NULL, 0);
+ if (!completion_port) {
+ RERROR("UdpSocketWin32::Initialize CreateIoCompletionPort failed");
+ goto fail;
+ }
+
+ {
+ sockaddr_in sin = {0};
+ sin.sin_family = AF_INET;
+ sin.sin_port = htons(listen_on_port);
+ if (bind(socket_ipv4, (struct sockaddr*)&sin, sizeof(sin)) != 0) {
+ RERROR("UdpSocketWin32::Initialize bind failed");
+ goto fail;
+ }
}
// Also open up a socket for ipv6
- s = WSASocket(AF_INET6, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
- if (s != INVALID_SOCKET) {
- if (!CreateIoCompletionPort((HANDLE)s, completion_port_handle_, 1, 0)) {
+ socket_ipv6 = WSASocket(AF_INET6, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
+ if (socket_ipv6 != INVALID_SOCKET) {
+ if (!CreateIoCompletionPort((HANDLE)socket_ipv6, completion_port, 1, 0)) {
RERROR("IPv6 Socket completion port failed.");
- closesocket(s);
+ closesocket(socket_ipv6);
+ socket_ipv6 = INVALID_SOCKET;
} else {
- socket_ipv6_ = s;
sockaddr_in6 sin6 = {0};
sin6.sin6_family = AF_INET6;
sin6.sin6_port = htons(listen_on_port);
- if (bind(s, (struct sockaddr*)&sin6, sizeof(sin6)) != 0) {
+ if (bind(socket_ipv6, (struct sockaddr*)&sin6, sizeof(sin6)) != 0) {
RERROR("UdpSocketWin32::Initialize bind failed IPv6");
}
}
} else {
RERROR("IPv6 Socket creation failed.");
}
- return true;
+ std::swap(socket_ipv6_, socket_ipv6);
+ std::swap(socket_, socket_ipv4);
+ std::swap(completion_port_handle_, completion_port);
+ retval = true;
+fail:
+ if (completion_port)
+ CloseHandle(completion_port);
+ if (socket_ipv4 != INVALID_SOCKET)
+ closesocket(socket_ipv4);
+ if (socket_ipv6 != INVALID_SOCKET)
+ closesocket(socket_ipv6);
+ return retval;
}
enum {
@@ -556,7 +556,7 @@ void UdpSocketWin32::ThreadMain() {
break;
restart_read_udp6:
ClearOverlapped(&p->overlapped);
- p->post_target = ThreadedPacketQueue::TARGET_PROCESSOR_UDP;
+ p->post_target = PacketProcessor::TARGET_PROCESSOR_UDP;
WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data};
DWORD flags = 0;
p->sin_size = sizeof(p->addr.sin6);
@@ -580,7 +580,7 @@ restart_read_udp6:
break;
restart_read_udp:
ClearOverlapped(&p->overlapped);
- p->post_target = ThreadedPacketQueue::TARGET_PROCESSOR_UDP;
+ p->post_target = PacketProcessor::TARGET_PROCESSOR_UDP;
WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data};
DWORD flags = 0;
p->sin_size = sizeof(p->addr.sin);
@@ -620,7 +620,7 @@ restart_read_udp:
if (!entries[i].lpOverlapped)
continue; // This is the dummy entry from |PostQueuedCompletionStatus|
Packet *p = (Packet*)((byte*)entries[i].lpOverlapped - offsetof(Packet, overlapped));
- if (p->post_target == ThreadedPacketQueue::TARGET_PROCESSOR_UDP) {
+ if (p->post_target == PacketProcessor::TARGET_PROCESSOR_UDP) {
num_reads[entries[i].lpCompletionKey]--;
if ((DWORD)p->overlapped.Internal != 0) {
if (!IsIgnoredUdpError((DWORD)p->overlapped.Internal))
@@ -673,7 +673,7 @@ restart_read_udp:
Packet *p = pending_writes;
pending_writes = p->next;
ClearOverlapped(&p->overlapped);
- p->post_target = ThreadedPacketQueue::TARGET_UDP_DEVICE;
+ p->post_target = PacketProcessor::TARGET_UDP_DEVICE;
WSABUF wsabuf = {(ULONG)p->size, (char*)p->data};
int rv;
@@ -715,7 +715,7 @@ restart_read_udp:
if (!entries[0].lpOverlapped)
continue; // This is the dummy entry from |PostQueuedCompletionStatus|
Packet *p = (Packet*)((byte*)entries[0].lpOverlapped - offsetof(Packet, overlapped));
- if (p->post_target == ThreadedPacketQueue::TARGET_PROCESSOR_UDP) {
+ if (p->post_target == PacketProcessor::TARGET_PROCESSOR_UDP) {
num_reads[entries[0].lpCompletionKey]--;
} else {
num_writes--;
@@ -724,8 +724,6 @@ restart_read_udp:
}
}
-
-
// Called on another thread to queue up a udp packet
void UdpSocketWin32::WriteUdpPacket(Packet *packet) {
if (qs.udp_qsize2 - qs.udp_qsize1 >= (unsigned)(packet->size < 576 ? MAX_BYTES_IN_UDP_OUT_QUEUE_SMALL : MAX_BYTES_IN_UDP_OUT_QUEUE)) {
@@ -754,73 +752,111 @@ DWORD WINAPI UdpSocketWin32::UdpThread(void *x) {
}
void UdpSocketWin32::StartThread() {
+ assert(completion_port_handle_);
+
DWORD thread_id;
thread_ = CreateThread(NULL, 0, &UdpThread, this, 0, &thread_id);
SetThreadPriority(thread_, ABOVE_NORMAL_PRIORITY_CLASS);
}
void UdpSocketWin32::StopThread() {
- exit_thread_ = true;
- PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL);
- WaitForSingleObject(thread_, INFINITE);
- CloseHandle(thread_);
- thread_ = NULL;
+ if (thread_ != NULL) {
+ exit_thread_ = true;
+ PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL);
+ WaitForSingleObject(thread_, INFINITE);
+ CloseHandle(thread_);
+ thread_ = NULL;
+ exit_thread_ = false;
+ }
}
-ThreadedPacketQueue::ThreadedPacketQueue(WireguardProcessor *wg, TunsafeBackendWin32 *backend) {
- wg_ = wg;
- backend_ = backend;
+PacketProcessor::PacketProcessor() {
event_ = CreateEvent(NULL, FALSE, FALSE, NULL);
last_ptr_ = &first_;
first_ = NULL;
- handle_ = NULL;
- timer_handle_ = NULL;
- exit_flag_ = false;
+ exit_code_ = 0;
timer_interrupt_ = false;
packets_in_queue_ = 0;
need_notify_ = 0;
}
-ThreadedPacketQueue::~ThreadedPacketQueue() {
- assert(handle_ == NULL);
- assert(timer_handle_ == NULL);
+PacketProcessor::~PacketProcessor() {
first_ = NULL;
last_ptr_ = &first_;
CloseHandle(event_);
}
-DWORD WINAPI ThreadedPacketQueue::ThreadedPacketQueueLauncher(VOID *x) {
- ThreadedPacketQueue *pq = (ThreadedPacketQueue *)x;
- return pq->ThreadMain();
+void CALLBACK PacketProcessor::ThreadPoolTimerCallback(PTP_CALLBACK_INSTANCE iTimerInstance, PVOID pContext, PTP_TIMER) {
+ PacketProcessor *th = (PacketProcessor *)pContext;
+ th->mutex_.Acquire();
+ th->timer_interrupt_ = true;
+ if (th->need_notify_) {
+ th->need_notify_ = 0;
+ th->mutex_.Release();
+ SetEvent(th->event_);
+ return;
+ }
+ th->mutex_.Release();
}
-DWORD ThreadedPacketQueue::ThreadMain() {
- int free_packets_ctr = 0;
- int overload = 0;
+struct ConfigPacket {
+ std::string message;
+ uint32 ident;
+ Packet packet;
+};
+
+void PacketProcessor::Reset() {
Packet *packet;
- wg_->dev().SetCurrentThreadAsMainThread();
+ packet = first_;
+ first_ = NULL;
+ exit_code_ = 0;
+ last_ptr_ = &first_;
+ timer_interrupt_ = false;
+
+ while (packet) {
+ Packet *next = packet->next;
+ if (packet->post_target == TARGET_CONFIG_PROTOCOL) {
+ ConfigPacket *config = (ConfigPacket*)((uint8*)packet - offsetof(ConfigPacket, packet));
+ delete config;
+ } else {
+ FreePacket(packet);
+ }
+ packet = next;
+ }
+}
+
+int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) {
+ int free_packets_ctr = 0;
+ int overload = 0;
+ int exit_code;
+ Packet *packet;
+ PTP_TIMER threadpool_timer;
+
+ threadpool_timer = CreateThreadpoolTimer(&ThreadPoolTimerCallback, this, NULL);
+ static const int64 duetime = -10000000; // the unit is 100ns
+ SetThreadpoolTimer(threadpool_timer, (FILETIME*)&duetime, 1000, 1000);
mutex_.Acquire();
- while (!exit_flag_) {
+ while (!(exit_code = exit_code_)) {
if (timer_interrupt_) {
timer_interrupt_ = false;
need_notify_ = 0;
mutex_.Release();
- wg_->SecondLoop();
- backend_->stats_mutex_.Acquire();
- backend_->stats_ = wg_->GetStats();
+ wg->SecondLoop();
+ backend->stats_mutex_.Acquire();
+ backend->stats_ = wg->GetStats();
float data[2] = {
// unit is megabits/second
- backend_->stats_.tun_bytes_in_per_second * (1.0f / 125000),
- backend_->stats_.tun_bytes_out_per_second * (1.0f / 125000),
+ backend->stats_.tun_bytes_in_per_second * (1.0f / 125000),
+ backend->stats_.tun_bytes_out_per_second * (1.0f / 125000),
};
- backend_->stats_collector_.AddSamples(data);
- backend_->stats_mutex_.Release();
+ backend->stats_collector_.AddSamples(data);
+ backend->stats_mutex_.Release();
- backend_->delegate_->OnGraphAvailable();
- backend_->PushStats();
+ backend->delegate_->OnGraphAvailable();
+ backend->PushStats();
// Conserve memory every 10s
if (free_packets_ctr++ == 10) {
@@ -847,64 +883,50 @@ DWORD ThreadedPacketQueue::ThreadMain() {
overload = 2;
bool is_overload = (overload != 0);
- WireguardProcessor *procint = wg_;
do {
Packet *next = packet->next;
- if (packet->post_target == TARGET_PROCESSOR_UDP)
- procint->HandleUdpPacket(packet, is_overload);
- else
- procint->HandleTunPacket(packet);
+ if (packet->post_target == TARGET_PROCESSOR_UDP) {
+ wg->HandleUdpPacket(packet, is_overload);
+ } else if (packet->post_target == TARGET_PROCESSOR_TUN) {
+ wg->HandleTunPacket(packet);
+ } else {
+ assert(packet->post_target == TARGET_CONFIG_PROTOCOL);
+ HandleConfigurationProtocolPacket(wg, backend, packet);
+ }
packet = next;
} while (packet);
}
- wg_->RunAllMainThreadScheduled();
+ wg->RunAllMainThreadScheduled();
mutex_.Acquire();
}
+ exit_code_ = 0;
mutex_.Release();
- return 0;
+
+ SetThreadpoolTimer(threadpool_timer, nullptr, 0, 0);
+ WaitForThreadpoolTimerCallbacks(threadpool_timer, true);
+ CloseThreadpoolTimer(threadpool_timer);
+
+ return exit_code;
}
-void ThreadedPacketQueue::Start() {
- if (handle_ == NULL) {
- exit_flag_ = false;
- DWORD thread_id;
- handle_ = CreateThread(NULL, 0, &ThreadedPacketQueueLauncher, this, 0, &thread_id);
- }
+void PacketProcessor::HandleConfigurationProtocolPacket(WireguardProcessor *wg, TunsafeBackendWin32 *backend, Packet *packet) {
+ ConfigPacket *config = (ConfigPacket*)((uint8*)packet - offsetof(ConfigPacket, packet));
+ std::string reply;
+ WgConfig::HandleConfigurationProtocolMessage(wg, std::move(config->message), &reply);
+ backend->delegate_->OnConfigurationProtocolReply(config->ident, std::move(reply));
- assert(timer_handle_ == NULL);
- timer_handle_ = CreateWaitableTimer(NULL, FALSE, NULL);
- long long due_time = 10000000;
- SetWaitableTimer(timer_handle_, (LARGE_INTEGER*)&due_time, 1000, &TimerRoutine, this, FALSE);
}
-void ThreadedPacketQueue::Stop() {
+void PacketProcessor::PostExit(int exit_code) {
mutex_.Acquire();
- exit_flag_ = true;
+ // Avoid race condition where mode_tun_failed is set during thread exit.
+ if (exit_code_ != TunsafeBackendWin32::MODE_RESTART && exit_code_ != TunsafeBackendWin32::MODE_EXIT)
+ exit_code_ = exit_code;
mutex_.Release();
-
SetEvent(event_);
-
- if (timer_handle_ != NULL) {
- // Not sure if just CloseHandle will close any outstanding APCs
- CancelWaitableTimer(timer_handle_);
- CloseHandle(timer_handle_);
- timer_handle_ = NULL;
- }
-
- if (handle_ != NULL) {
- WaitForSingleObject(handle_, INFINITE);
- CloseHandle(handle_);
- handle_ = NULL;
- }
}
-void ThreadedPacketQueue::AbortingDriver() {
- mutex_.Acquire();
- exit_flag_ = true;
- mutex_.Release();
-}
-
-void ThreadedPacketQueue::Post(Packet *packet, Packet **end, int count) {
+void PacketProcessor::Post(Packet *packet, Packet **end, int count) {
mutex_.Acquire();
if (packets_in_queue_ >= HARD_MAXIMUM_QUEUE_SIZE) {
mutex_.Release();
@@ -912,15 +934,11 @@ void ThreadedPacketQueue::Post(Packet *packet, Packet **end, int count) {
return;
}
assert(packet != NULL);
- if (!first_) {
- assert(last_ptr_ == &first_);
- }
+ assert(first_ || last_ptr_ == &first_);
packets_in_queue_ += count;
*last_ptr_ = packet;
last_ptr_ = end;
- if (!first_) {
- assert(last_ptr_ == &first_);
- }
+ assert(first_ || last_ptr_ == &first_);
if (need_notify_) {
need_notify_ = 0;
mutex_.Release();
@@ -930,13 +948,12 @@ void ThreadedPacketQueue::Post(Packet *packet, Packet **end, int count) {
mutex_.Release();
}
-void CALLBACK ThreadedPacketQueue::TimerRoutine(LPVOID lpArgToCompletionRoutine, DWORD dwTimerLowValue, DWORD dwTimerHighValue) {
- ((ThreadedPacketQueue*)lpArgToCompletionRoutine)->PostTimerInterrupt();
-}
-
-void ThreadedPacketQueue::PostTimerInterrupt() {
+void PacketProcessor::ForcePost(Packet *packet) {
mutex_.Acquire();
- timer_interrupt_ = true;
+ packet->next = NULL;
+ packets_in_queue_ += 1;
+ *last_ptr_ = packet;
+ last_ptr_ = &packet->next;
if (need_notify_) {
need_notify_ = 0;
mutex_.Release();
@@ -1125,29 +1142,47 @@ static bool AddMultipleCatchallRoutes(int inet, int bits, const uint8 *target, c
return success;
}
-TunWin32Adapter::TunWin32Adapter(DnsBlocker *dns_blocker) {
+TunWin32Adapter::TunWin32Adapter(DnsBlocker *dns_blocker, const char guid[ADAPTER_GUID_SIZE]) {
handle_ = NULL;
dns_blocker_ = dns_blocker;
old_ipv6_address_.size = 0;
old_ipv6_metric_ = kMetricNone;
old_ipv4_metric_ = kMetricNone;
has_dns6_setting_ = false;
+ guid_[0] = 0;
+ if (guid)
+ memcpy(guid_, guid, ADAPTER_GUID_SIZE);
}
TunWin32Adapter::~TunWin32Adapter() {
}
-bool TunWin32Adapter::OpenAdapter(uint32 *exit_thread, DWORD open_flags) {
+bool TunWin32Adapter::OpenAdapter(TunsafeBackendWin32 *backend, DWORD open_flags) {
+ ULONG info[3];
+ DWORD len;
assert(handle_ == NULL);
- int retry_count = 20;
- handle_ = OpenTunAdapter(guid_, retry_count, exit_thread, open_flags);
+ backend_ = backend;
+ handle_ = OpenTunAdapter(guid_, backend, open_flags);
+ if (handle_ != NULL) {
+ memset(info, 0, sizeof(info));
+ if (DeviceIoControl(handle_, TAP_IOCTL_GET_VERSION, &info, sizeof(info),
+ &info, sizeof(info), &len, NULL)) {
+ RINFO("TAP Driver Version %d.%d %s", (int)info[0], (int)info[1], (info[2] ? "(DEBUG)" : ""));
+ }
+
+ if (info[0] < 9 || info[0] == 9 && info[1] <= 8) {
+ RERROR("TAP is too old. Go to https://tunsafe.com/download to upgrade the driver");
+ CloseHandle(handle_);
+ handle_ = NULL;
+ }
+ }
return (handle_ != NULL);
}
-bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out) {
- ULONG info[3];
- DWORD len;
+bool TunWin32Adapter::ConfigureAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out) {
+ DWORD len, err;
+
out->enable_neighbor_discovery_spoofing = false;
if (!RunPrePostCommand(config.pre_post_commands.pre_up)) {
@@ -1158,28 +1193,11 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt
pre_down_ = std::move(config.pre_post_commands.pre_down);
post_down_ = std::move(config.pre_post_commands.post_down);
- memset(info, 0, sizeof(info));
- if (DeviceIoControl(handle_, TAP_IOCTL_GET_VERSION, &info, sizeof(info),
- &info, sizeof(info), &len, NULL)) {
- RINFO("TAP Driver Version %d.%d %s", (int)info[0], (int)info[1], (info[2] ? "(DEBUG)" : ""));
- }
-
- if (info[0] < 9 || info[0] == 9 && info[1] <= 8) {
- RERROR("TAP is too old. Go to https://tunsafe.com/download to upgrade the driver");
- return false;
- }
-
- // ULONG mtu = 0;
- // if (DeviceIoControl(handle_, TAP_IOCTL_GET_MTU, &mtu, sizeof(mtu), &mtu, sizeof(mtu), &len, NULL))
- // RINFO("TAP-Win32 MTU=%d", (int)mtu);
- // mtu_ = mtu;
-
uint32 netmask = CidrToNetmaskV4(config.cidr);
// Set TAP-Windows TUN subnet mode
if (1) {
uint32 v[3];
-
v[0] = htonl(config.ip);
v[1] = htonl(config.ip & netmask);
v[2] = htonl(netmask);
@@ -1237,14 +1255,11 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt
}
bool has_interface_luid = GetNetLuidFromGuid(guid_, &interface_luid_);
-
if (!has_interface_luid) {
RERROR("Unable to determine interface luid for %s.", guid_);
return false;
}
- DWORD err;
-
if (config.mtu) {
err = SetMtuOnNetworkAdapter(&interface_luid_, AF_INET, config.mtu);
if (err)
@@ -1419,11 +1434,10 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt
}
RunPrePostCommand(config.pre_post_commands.post_up);
-
return true;
}
-void TunWin32Adapter::CloseAdapter() {
+void TunWin32Adapter::CloseAdapter(bool is_restart) {
RunPrePostCommand(pre_down_);
if (handle_ != NULL) {
@@ -1433,6 +1447,8 @@ void TunWin32Adapter::CloseAdapter() {
&status, sizeof(status), &len, NULL);
CloseHandle(handle_);
handle_ = NULL;
+
+ TunAdaptersInUse::GetInstance()->Release(backend_);
}
if (old_ipv6_address_.size != 0)
@@ -1441,23 +1457,24 @@ void TunWin32Adapter::CloseAdapter() {
SetMetricOnNetworkAdapter(&interface_luid_, AF_INET, old_ipv4_metric_, NULL);
if (old_ipv6_metric_ != kMetricNone)
SetMetricOnNetworkAdapter(&interface_luid_, AF_INET6, old_ipv6_metric_, NULL);
+ if (has_dns6_setting_)
+ SetIPV6DnsOnInterface(&interface_luid_, NULL, 0);
old_ipv4_metric_ = old_ipv6_metric_ = -1;
old_ipv6_address_.size = 0;
-
- if (has_dns6_setting_) {
- has_dns6_setting_ = false;
- SetIPV6DnsOnInterface(&interface_luid_, NULL, 0);
- }
+ has_dns6_setting_ = false;
for (auto it = routes_to_undo_.begin(); it != routes_to_undo_.end(); ++it)
DeleteRoute(&*it);
routes_to_undo_.clear();
- if (dns_blocker_)
+ if (!is_restart && dns_blocker_)
dns_blocker_->RestoreDns();
RunPrePostCommand(post_down_);
+
+ pre_down_.clear();
+ post_down_.clear();
}
static bool RunOneCommand(const std::string &cmd) {
@@ -1564,7 +1581,7 @@ bool TunWin32Adapter::RunPrePostCommand(const std::vector &vec) {
//////////////////////////////////////////////////////////////////////////////
-TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker), backend_(backend) {
+TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) {
wqueue_end_ = &wqueue_;
wqueue_ = NULL;
@@ -1577,36 +1594,38 @@ TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend) :
TunWin32Iocp::~TunWin32Iocp() {
//assert(num_reads_ == 0 && num_writes_ == 0);
assert(thread_ == NULL);
- CloseTun();
+ CloseTun(false);
+ FreePacketList(wqueue_);
}
-bool TunWin32Iocp::Initialize(const TunConfig &&config, TunConfigOut *out) {
- assert(thread_ == NULL);
-
- if (adapter_.OpenAdapter(&backend_->stop_mode_, FILE_FLAG_OVERLAPPED)) {
+bool TunWin32Iocp::Configure(const TunConfig &&config, TunConfigOut *out) {
+ // Reconfigure while started?
+ if (thread_ != NULL) {
+ assert(completion_port_handle_);
+ StopThread();
+ bool rv = Configure(std::move(config), out);
+ StartThread();
+ return rv;
+ }
+ CloseTun(true);
+ if (adapter_.OpenAdapter(backend_, FILE_FLAG_OVERLAPPED)) {
completion_port_handle_ = CreateIoCompletionPort(adapter_.handle(), NULL, NULL, 0);
if (completion_port_handle_ != NULL) {
- if (adapter_.InitAdapter(std::move(config), out))
+ if (adapter_.ConfigureAdapter(std::move(config), out))
return true;
}
}
- CloseTun();
+ CloseTun(false);
return false;
}
-void TunWin32Iocp::CloseTun() {
+void TunWin32Iocp::CloseTun(bool is_restart) {
assert(thread_ == NULL);
-
- adapter_.CloseAdapter();
-
+ adapter_.CloseAdapter(is_restart);
if (completion_port_handle_) {
CloseHandle(completion_port_handle_);
completion_port_handle_ = NULL;
}
-
- FreePacketList(wqueue_);
- wqueue_ = NULL;
- wqueue_end_ = &wqueue_;
}
enum {
@@ -1621,6 +1640,8 @@ void TunWin32Iocp::ThreadMain() {
Packet *freed_packets = NULL, **freed_packets_end;
int freed_packets_count = 0;
DWORD err;
+ if (!completion_port_handle_)
+ return;
while (!exit_thread_) {
// Initiate more reads, reusing the Packet structures in |finished_writes|.
@@ -1629,19 +1650,16 @@ void TunWin32Iocp::ThreadMain() {
if (!AllocPacketFrom(&freed_packets, &freed_packets_count, &exit_thread_, &p))
break;
memset(&p->overlapped, 0, sizeof(p->overlapped));
- p->post_target = ThreadedPacketQueue::TARGET_PROCESSOR_TUN;
+ p->post_target = PacketProcessor::TARGET_PROCESSOR_TUN;
if (!ReadFile(adapter_.handle(), p->data, kPacketCapacity, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) {
FreePacket(p);
RERROR("TunWin32: ReadFile failed 0x%X", err);
if (err == ERROR_OPERATION_ABORTED || err == ERROR_FILE_NOT_FOUND) {
- packet_handler_->AbortingDriver();
RERROR("TAP driver stopped communicating. Attempting to restart.", err);
- // This can happen if we reinstall the TAP driver while there's an active connection. Wait a bit, then attempt to
- // restart.
- Sleep(1000);
- backend_->TunAdapterFailed();
+ // This can happen if we reinstall the TAP driver while there's an active connection.
+ backend_->PostExit(TunsafeBackendWin32::MODE_TUN_FAILED);
goto EXIT;
}
} else {
@@ -1673,7 +1691,7 @@ void TunWin32Iocp::ThreadMain() {
if (!entries[i].lpOverlapped)
continue; // This is the dummy entry from |PostQueuedCompletionStatus|
Packet *p = (Packet*)((byte*)entries[i].lpOverlapped - offsetof(Packet, overlapped));
- if (p->post_target == ThreadedPacketQueue::TARGET_PROCESSOR_TUN) {
+ if (p->post_target == PacketProcessor::TARGET_PROCESSOR_TUN) {
num_reads--;
if ((int)p->overlapped.Internal != 0) {
RERROR("TunWin32::ReadComplete error 0x%X", (int)p->overlapped.Internal);
@@ -1721,7 +1739,7 @@ void TunWin32Iocp::ThreadMain() {
Packet *p = pending_writes;
pending_writes = p->next;
memset(&p->overlapped, 0, sizeof(p->overlapped));
- p->post_target = ThreadedPacketQueue::TARGET_TUN_DEVICE;
+ p->post_target = PacketProcessor::TARGET_TUN_DEVICE;
if (!WriteFile(adapter_.handle(), p->data, p->size, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) {
RERROR("TunWin32: WriteFile failed 0x%X", err);
FreePacket(p);
@@ -1744,7 +1762,7 @@ EXIT:
if (!entries[0].lpOverlapped)
continue; // This is the dummy entry from |PostQueuedCompletionStatus|
Packet *p = (Packet*)((byte*)entries[0].lpOverlapped - offsetof(Packet, overlapped));
- if (p->post_target == ThreadedPacketQueue::TARGET_PROCESSOR_TUN) {
+ if (p->post_target == PacketProcessor::TARGET_PROCESSOR_TUN) {
num_reads--;
} else {
num_writes--;
@@ -1764,6 +1782,8 @@ DWORD WINAPI TunWin32Iocp::TunThread(void *x) {
void TunWin32Iocp::StartThread() {
DWORD thread_id;
+ assert(thread_ == NULL);
+ assert(completion_port_handle_ != NULL);
thread_ = CreateThread(NULL, 0, &TunThread, this, 0, &thread_id);
SetThreadPriority(thread_, ABOVE_NORMAL_PRIORITY_CLASS);
}
@@ -1774,6 +1794,7 @@ void TunWin32Iocp::StopThread() {
WaitForSingleObject(thread_, INFINITE);
CloseHandle(thread_);
thread_ = NULL;
+ exit_thread_ = false;
}
void TunWin32Iocp::WriteTunPacket(Packet *packet) {
@@ -1789,11 +1810,9 @@ void TunWin32Iocp::WriteTunPacket(Packet *packet) {
}
}
-
-
//////////////////////////////////////////////////////////////////////////////
-TunWin32Overlapped::TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker), backend_(backend) {
+TunWin32Overlapped::TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) {
wqueue_end_ = &wqueue_;
wqueue_ = NULL;
@@ -1814,10 +1833,10 @@ TunWin32Overlapped::~TunWin32Overlapped() {
CloseHandle(wake_event_);
}
-bool TunWin32Overlapped::Initialize(const TunConfig &&config, TunConfigOut *out) {
+bool TunWin32Overlapped::Configure(const TunConfig &&config, TunConfigOut *out) {
CloseTun();
- if (adapter_.OpenAdapter(&backend_->stop_mode_, FILE_FLAG_OVERLAPPED) &&
- adapter_.InitAdapter(std::move(config), out))
+ if (adapter_.OpenAdapter(backend_, FILE_FLAG_OVERLAPPED) &&
+ adapter_.ConfigureAdapter(std::move(config), out))
return true;
CloseTun();
return false;
@@ -1825,7 +1844,7 @@ bool TunWin32Overlapped::Initialize(const TunConfig &&config, TunConfigOut *out)
void TunWin32Overlapped::CloseTun() {
assert(thread_ == NULL);
- adapter_.CloseAdapter();
+ adapter_.CloseAdapter(false);
FreePacketList(wqueue_);
wqueue_ = NULL;
wqueue_end_ = &wqueue_;
@@ -1842,7 +1861,7 @@ void TunWin32Overlapped::ThreadMain() {
Packet *p = AllocPacket();
memset(&p->overlapped, 0, sizeof(p->overlapped));
p->overlapped.hEvent = read_event_;
- p->post_target = ThreadedPacketQueue::TARGET_PROCESSOR_TUN;
+ p->post_target = PacketProcessor::TARGET_PROCESSOR_TUN;
if (!ReadFile(adapter_.handle(), p->data, kPacketCapacity, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) {
FreePacket(p);
RERROR("TunWin32: ReadFile failed 0x%X", err);
@@ -1889,7 +1908,7 @@ void TunWin32Overlapped::ThreadMain() {
pending_writes = p->next;
memset(&p->overlapped, 0, sizeof(p->overlapped));
p->overlapped.hEvent = write_event_;
- p->post_target = ThreadedPacketQueue::TARGET_TUN_DEVICE;
+ p->post_target = PacketProcessor::TARGET_TUN_DEVICE;
if (!WriteFile(adapter_.handle(), p->data, p->size, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) {
RERROR("TunWin32: WriteFile failed 0x%X", err);
FreePacket(p);
@@ -1944,39 +1963,37 @@ void TunsafeBackendWin32::SetPublicKey(const uint8 key[32]) {
DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) {
TunsafeBackendWin32 *backend = (TunsafeBackendWin32*)bk;
int stop_mode;
+ int fast_retry_ctr = 0;
for(;;) {
TunWin32Iocp tun(&backend->dns_blocker_, backend);
UdpSocketWin32 udp;
WireguardProcessor wg_proc(&udp, &tun, backend);
- ThreadedPacketQueue queues_for_processor(&wg_proc, backend);
-
qs.udp_qsize1 = qs.udp_qsize2 = 0;
- udp.SetPacketHandler(&queues_for_processor);
- tun.SetPacketHandler(&queues_for_processor);
+ udp.SetPacketHandler(&backend->packet_processor_);
+ tun.SetPacketHandler(&backend->packet_processor_);
- wg_proc.dev().SetCurrentThreadAsMainThread();
-
- if (!ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->dns_resolver_))
+ if (backend->config_file_[0] &&
+ !ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->dns_resolver_))
goto getout_fail;
if (!wg_proc.Start())
goto getout_fail;
- // only for use in callbacks from wg
- backend->wg_processor_ = &wg_proc;
-
- queues_for_processor.Start();
- udp.StartThread();
- tun.StartThread();
-
backend->SetPublicKey(wg_proc.dev().public_key());
- while ((stop_mode = InterlockedExchange(&backend->stop_mode_, MODE_NONE)) == MODE_NONE) {
- SleepEx(INFINITE, TRUE);
- }
+ backend->wg_processor_ = &wg_proc;
+
+ udp.StartThread();
+ tun.StartThread();
+ stop_mode = backend->packet_processor_.Run(&wg_proc, backend);
+ udp.StopThread();
+ tun.StopThread();
+
+
+ backend->wg_processor_ = NULL;
// Keep DNS alive
if (stop_mode != MODE_EXIT)
@@ -1984,39 +2001,31 @@ DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) {
else
backend->dns_resolver_.ClearCache();
- udp.StopThread();
- tun.StopThread();
- queues_for_processor.Stop();
-
- backend->wg_processor_ = NULL;
-
FreeAllPackets();
if (stop_mode != MODE_TUN_FAILED)
return 0;
uint32 last_fail = GetTickCount();
- bool permanent_fail = (last_fail - backend->last_tun_adapter_failed_) < 5000;
+ fast_retry_ctr = (last_fail - backend->last_tun_adapter_failed_ < 5000) ? fast_retry_ctr + 1 : 0;
backend->last_tun_adapter_failed_ = last_fail;
- backend->status_ = permanent_fail ? TunsafeBackend::kErrorTunPermanent : TunsafeBackend::kStatusTunRetrying;
- backend->delegate_->OnStatusCode(backend->status_);
+ backend->SetStatus((fast_retry_ctr >= 3) ? TunsafeBackend::kErrorTunPermanent : TunsafeBackend::kStatusTunRetrying);
- if (permanent_fail) {
+ if (backend->status_ == TunsafeBackend::kErrorTunPermanent) {
RERROR("Too many automatic restarts...");
- goto getout_fail;
+ goto getout_fail_noseterr;
}
+ Sleep(1000);
}
getout_fail:
- backend->dns_blocker_.RestoreDns();
backend->status_ = TunsafeBackend::kErrorInitialize;
backend->delegate_->OnStatusCode(TunsafeBackend::kErrorInitialize);
+getout_fail_noseterr:
+ backend->dns_blocker_.RestoreDns();
return 0;
}
-static void WINAPI ExitServiceAPC(ULONG_PTR a) {
-}
-
TunsafeBackend::TunsafeBackend() {
is_started_ = false;
is_remote_ = false;
@@ -2029,29 +2038,33 @@ TunsafeBackend::~TunsafeBackend() {
}
-
TunsafeBackendWin32::TunsafeBackendWin32(Delegate *delegate) : delegate_(delegate), dns_resolver_(&dns_blocker_) {
memset(&stats_, 0, sizeof(stats_));
wg_processor_ = NULL;
InitPacketMutexes();
worker_thread_ = NULL;
- stop_mode_ = MODE_NONE;
last_tun_adapter_failed_ = 0;
want_periodic_stats_ = false;
-
+ guid_[0] = 0;
if (g_hklm_reg_key == NULL) {
RegCreateKeyEx(HKEY_LOCAL_MACHINE, "Software\\TunSafe", NULL, NULL, 0, KEY_ALL_ACCESS, NULL, &g_hklm_reg_key, NULL);
g_killswitch_want = RegReadInt(g_hklm_reg_key, "KillSwitch", 0);
}
-
delegate_->OnStateChanged();
}
TunsafeBackendWin32::~TunsafeBackendWin32() {
StopInner(false);
+ TunAdaptersInUse::GetInstance()->Release(this);
}
-bool TunsafeBackendWin32::Initialize() {
+
+void TunsafeBackendWin32::SetStatus(StatusCode status) {
+ status_ = status;
+ delegate_->OnStatusCode(status);
+}
+
+bool TunsafeBackendWin32::Configure() {
// it's always initialized
return true;
@@ -2061,6 +2074,17 @@ void TunsafeBackendWin32::Teardown() {
}
+bool TunsafeBackendWin32::SetTunAdapterName(const char *name) {
+ assert(worker_thread_ == NULL);
+ size_t len = strlen(name);
+ if (len >= sizeof(guid_) || guid_[0])
+ return false;
+ if (!TunAdaptersInUse::GetInstance()->Acquire(name, this))
+ return false;
+ memcpy(guid_, name, len + 1);
+ return true;
+}
+
void TunsafeBackendWin32::RequestStats(bool enable) {
want_periodic_stats_ = enable;
@@ -2084,12 +2108,10 @@ void TunsafeBackendWin32::Stop() {
void TunsafeBackendWin32::Start(const char *config_file) {
StopInner(true);
- stop_mode_ = MODE_NONE; // this needs to be here cause it's not reset on config file errors
dns_resolver_.SetAbortFlag(false);
is_started_ = true;
memset(public_key_, 0, sizeof(public_key_));
- status_ = kStatusInitializing;
- delegate_->OnStatusCode(kStatusInitializing);
+ SetStatus(kStatusInitializing);
delegate_->OnClearLog();
DWORD thread_id;
config_file_ = _strdup(config_file);
@@ -2098,17 +2120,15 @@ void TunsafeBackendWin32::Start(const char *config_file) {
delegate_->OnStateChanged();
}
-void TunsafeBackendWin32::TunAdapterFailed() {
- InterlockedExchange(&stop_mode_, MODE_TUN_FAILED);
- QueueUserAPC(&ExitServiceAPC, worker_thread_, NULL);
+void TunsafeBackendWin32::PostExit(int exit_code) {
+ packet_processor_.PostExit(exit_code);
}
void TunsafeBackendWin32::StopInner(bool is_restart) {
if (worker_thread_) {
ipv4_ip_ = 0;
dns_resolver_.SetAbortFlag(true);
- InterlockedExchange(&stop_mode_, is_restart ? MODE_RESTART : MODE_EXIT);
- QueueUserAPC(&ExitServiceAPC, worker_thread_, NULL);
+ PostExit(is_restart ? MODE_RESTART : MODE_EXIT);
WaitForSingleObject(worker_thread_, INFINITE);
CloseHandle(worker_thread_);
worker_thread_ = NULL;
@@ -2116,6 +2136,7 @@ void TunsafeBackendWin32::StopInner(bool is_restart) {
config_file_ = NULL;
is_started_ = false;
status_ = kStatusStopped;
+ packet_processor_.Reset();
if (!is_restart && !(g_killswitch_want & kBlockInternet_BlockOnDisconnect))
DeactivateKillSwitch(kBlockInternet_Off);
@@ -2217,6 +2238,14 @@ std::string TunsafeBackendWin32::GetConfigFileName() {
return std::string();
}
+void TunsafeBackendWin32::SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) {
+ ConfigPacket *config_packet = new ConfigPacket;
+ config_packet->ident = identifier;
+ config_packet->message = std::move(message);
+ config_packet->packet.post_target = PacketProcessor::TARGET_CONFIG_PROTOCOL;
+ packet_processor_.ForcePost(&config_packet->packet);
+}
+
void TunsafeBackendWin32::OnConnected() {
if (status_ != TunsafeBackend::kStatusConnected) {
ipv4_ip_ = ReadBE32(wg_processor_->tun_addr().addr);
@@ -2224,19 +2253,15 @@ void TunsafeBackendWin32::OnConnected() {
char buf[kSizeOfAddress];
RINFO("Connection established. IP %s", print_ip_prefix(buf, AF_INET, wg_processor_->tun_addr().addr, -1));
}
- status_ = TunsafeBackend::kStatusConnected;
- delegate_->OnStatusCode(TunsafeBackend::kStatusConnected);
+ SetStatus(TunsafeBackend::kStatusConnected);
}
}
void TunsafeBackendWin32::OnConnectionRetry(uint32 attempts) {
- if (status_ == TunsafeBackend::kStatusInitializing) {
- status_ = TunsafeBackend::kStatusConnecting;
- delegate_->OnStatusCode(TunsafeBackend::kStatusConnecting);
- } else if (attempts >= 3 && status_ == TunsafeBackend::kStatusConnected) {
- status_ = TunsafeBackend::kStatusReconnecting;
- delegate_->OnStatusCode(TunsafeBackend::kStatusReconnecting);
- }
+ if (status_ == TunsafeBackend::kStatusInitializing)
+ SetStatus(TunsafeBackend::kStatusConnecting);
+ else if (attempts >= 3 && status_ == TunsafeBackend::kStatusConnected)
+ SetStatus(TunsafeBackend::kStatusReconnecting);
}
void TunsafeBackend::Delegate::DoWork() {
@@ -2255,7 +2280,10 @@ TunsafeBackendDelegateThreaded::~TunsafeBackendDelegateThreaded() {
void TunsafeBackendDelegateThreaded::FreeEntry(Entry *e) {
if (e->lparam) {
- free((void*)e->lparam);
+ if (e->which == Id_OnConfigurationProtocolReply)
+ delete (std::string*)e->lparam;
+ else
+ free((void*)e->lparam);
e->lparam = NULL;
}
}
@@ -2273,6 +2301,7 @@ void TunsafeBackendDelegateThreaded::DoWork() {
case Id_OnStatusCode: delegate->OnStatusCode((TunsafeBackend::StatusCode)it->wparam); break;
case Id_OnClearLog: delegate->OnClearLog(); break;
case Id_OnGraphAvailable: delegate->OnGraphAvailable(); break;
+ case Id_OnConfigurationProtocolReply: delegate->OnConfigurationProtocolReply(it->wparam, std::move(*(std::string*)it->lparam)); break;
}
FreeEntry(&*it);
}
@@ -2314,6 +2343,10 @@ void TunsafeBackendDelegateThreaded::OnClearLog() {
AddEntry(Id_OnClearLog);
}
+void TunsafeBackendDelegateThreaded::OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) {
+ AddEntry(Id_OnConfigurationProtocolReply, (intptr_t)new std::string(std::move(reply)), ident);
+}
+
TunsafeBackend::Delegate::~Delegate() {
}
@@ -2331,7 +2364,7 @@ void StatsCollector::Init() {
Accumulator *acc = &accum_[0][0];
static const int kAccMax[TIMEVALS] = {5, 6, 10, 0};
- // Initialize all stats channels
+ // Configure all stats channels
for (uint32 channel = 0; channel != CHANNELS; channel++) {
for (uint32 timeval = 0; timeval != TIMEVALS; timeval++, acc++) {
acc->acc = 0;
@@ -2370,3 +2403,83 @@ void StatsCollector::AddSamples(float data[CHANNELS]) {
AddToAccumulators(&accum_[i][0], data[i]);
}
+
+TunAdaptersInUse::TunAdaptersInUse() {
+ num_inuse_ = 0;
+}
+
+bool TunAdaptersInUse::Acquire(const char guid[ADAPTER_GUID_SIZE], void *context) {
+ size_t len = strlen(guid);
+ if (len >= ADAPTER_GUID_SIZE)
+ return false;
+ ScopedLock scoped_lock(&mutex_);
+ Entry *e = entry_;
+ for (uint32 n = num_inuse_; ; n--, e++) {
+ if (n == 0) {
+ if (num_inuse_ == kMaxAdaptersInUse)
+ return false;
+ num_inuse_++;
+ e->context = context;
+ e->count = 0;
+ memcpy(e->guid, guid, len + 1);
+ return true;
+ }
+ if (!strcmp(e->guid, guid)) {
+ if (e->context != context)
+ return false;
+ e->count++;
+ return true;
+ }
+ }
+}
+
+void TunAdaptersInUse::Release(void *context) {
+ ScopedLock scoped_lock(&mutex_);
+ Entry *e = entry_;
+ for (uint32 n = num_inuse_; n; n--, e++) {
+ if (e->context == context) {
+ if (e->count-- == 0)
+ *e = entry_[num_inuse_-- - 1];
+ break;
+ }
+ }
+}
+
+void *TunAdaptersInUse::LookupContextFromGuid(const char guid[ADAPTER_GUID_SIZE]) {
+ ScopedLock scoped_lock(&mutex_);
+ Entry *e = entry_;
+ for (uint32 n = num_inuse_; n; n--, e++) {
+ if (!strcmp(e->guid, guid))
+ return e->context;
+ }
+ return NULL;
+}
+
+char *TunAdaptersInUse::GetAllGuid() {
+ ScopedLock scoped_lock(&mutex_);
+ char *rv = (char*)malloc(ADAPTER_GUID_SIZE * num_inuse_ + 1), *p = rv;
+ if (rv) {
+ Entry *e = entry_;
+ for (uint32 n = num_inuse_; n; n--, e++) {
+ size_t len = strlen(e->guid);
+ p[len] = '\n';
+ memcpy(p, e->guid, len);
+ p += len + 1;
+ }
+ *p = 0;
+ }
+ return rv;
+}
+
+static TunAdaptersInUse g_tun_adapters_in_use;
+TunAdaptersInUse *TunAdaptersInUse::GetInstance() {
+ return &g_tun_adapters_in_use;
+}
+
+TunsafeBackend *TunsafeBackend::FindBackendByTunGuid(const char *guid) {
+ return (TunsafeBackend*)TunAdaptersInUse::GetInstance()->LookupContextFromGuid(guid);
+}
+
+char *TunsafeBackend::GetAllGuid() {
+ return TunAdaptersInUse::GetInstance()->GetAllGuid();
+}
diff --git a/network_win32.h b/network_win32.h
index 9e865e8..bcd8adf 100644
--- a/network_win32.h
+++ b/network_win32.h
@@ -11,34 +11,39 @@
#include "tunsafe_threading.h"
#include
+enum {
+ ADAPTER_GUID_SIZE = 40,
+};
+
struct Packet;
class WireguardProcessor;
class TunsafeBackendWin32;
-class ThreadedPacketQueue {
+class PacketProcessor {
public:
- explicit ThreadedPacketQueue(WireguardProcessor *wg, TunsafeBackendWin32 *backend);
- ~ThreadedPacketQueue();
+ explicit PacketProcessor();
+ ~PacketProcessor();
enum {
TARGET_PROCESSOR_UDP = 0,
TARGET_PROCESSOR_TUN = 1,
TARGET_UDP_DEVICE = 2,
TARGET_TUN_DEVICE = 3,
+ TARGET_CONFIG_PROTOCOL = 4,
};
- void Start();
- void Stop();
+ void Reset();
+ int Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend);
void Post(Packet *packet, Packet **end, int count);
- void AbortingDriver();
+ void ForcePost(Packet *packet);
+ void PostExit(int exit_code);
+
+ const uint32 *posted_exit_code() { return &exit_code_; }
private:
- void PostTimerInterrupt();
- static void CALLBACK TimerRoutine(LPVOID lpArgToCompletionRoutine, DWORD dwTimerLowValue, DWORD dwTimerHighValue);
-
- DWORD ThreadMain();
- static DWORD WINAPI ThreadedPacketQueueLauncher(VOID *x);
+ static void CALLBACK ThreadPoolTimerCallback(PTP_CALLBACK_INSTANCE iTimerInstance, PVOID pContext, PTP_TIMER);
+ void HandleConfigurationProtocolPacket(WireguardProcessor *wg, TunsafeBackendWin32 *backend, Packet *packet);
Packet *first_;
Packet **last_ptr_;
uint32 packets_in_queue_;
@@ -46,12 +51,8 @@ private:
Mutex mutex_;
HANDLE event_;
- HANDLE timer_handle_;
- HANDLE handle_;
- WireguardProcessor *wg_;
- bool exit_flag_;
+ uint32 exit_code_;
bool timer_interrupt_;
- TunsafeBackendWin32 *backend_;
};
// Encapsulates a UDP socket, optionally listening for incoming packets
@@ -61,17 +62,16 @@ public:
explicit UdpSocketWin32();
~UdpSocketWin32();
- void SetPacketHandler(ThreadedPacketQueue *packet_handler) { packet_handler_ = packet_handler; }
+ void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
void StartThread();
void StopThread();
// -- from UdpInterface
- virtual bool Initialize(int listen_on_port) override;
+ virtual bool Configure(int listen_on_port) override;
virtual void WriteUdpPacket(Packet *packet) override;
private:
-
void ThreadMain();
static DWORD WINAPI UdpThread(void *x);
@@ -80,7 +80,7 @@ private:
Mutex mutex_;
- ThreadedPacketQueue *packet_handler_;
+ PacketProcessor *packet_handler_;
SOCKET socket_;
SOCKET socket_ipv6_;
HANDLE completion_port_handle_;
@@ -93,12 +93,12 @@ class DnsBlocker;
class TunWin32Adapter {
public:
- TunWin32Adapter(DnsBlocker *dns_blocker);
+ TunWin32Adapter(DnsBlocker *dns_blocker, const char guid[ADAPTER_GUID_SIZE]);
~TunWin32Adapter();
- bool OpenAdapter(unsigned int *exit_thread, DWORD open_flags);
- bool InitAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out);
- void CloseAdapter();
+ bool OpenAdapter(TunsafeBackendWin32 *backend, DWORD open_flags);
+ bool ConfigureAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out);
+ void CloseAdapter(bool is_restart);
HANDLE handle() { return handle_; }
@@ -121,8 +121,10 @@ private:
NET_LUID interface_luid_;
+ void *backend_;
+
std::vector pre_down_, post_down_;
- char guid_[64];
+ char guid_[ADAPTER_GUID_SIZE];
};
// Implementation of TUN interface handling using IO Completion Ports
@@ -131,23 +133,23 @@ public:
explicit TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend);
~TunWin32Iocp();
- void SetPacketHandler(ThreadedPacketQueue *packet_handler) { packet_handler_ = packet_handler; }
+ void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
void StartThread();
void StopThread();
// -- from TunInterface
- virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) override;
+ virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
virtual void WriteTunPacket(Packet *packet) override;
TunWin32Adapter &adapter() { return adapter_; }
private:
- void CloseTun();
+ void CloseTun(bool is_restart);
void ThreadMain();
static DWORD WINAPI TunThread(void *x);
- ThreadedPacketQueue *packet_handler_;
+ PacketProcessor *packet_handler_;
HANDLE completion_port_handle_;
HANDLE thread_;
@@ -168,13 +170,13 @@ public:
explicit TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend);
~TunWin32Overlapped();
- void SetPacketHandler(ThreadedPacketQueue *packet_handler) { packet_handler_ = packet_handler; }
+ void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
void StartThread();
void StopThread();
// -- from TunInterface
- virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) override;
+ virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
virtual void WriteTunPacket(Packet *packet) override;
private:
@@ -182,7 +184,7 @@ private:
void ThreadMain();
static DWORD WINAPI TunThread(void *x);
- ThreadedPacketQueue *packet_handler_;
+ PacketProcessor *packet_handler_;
HANDLE thread_;
Mutex mutex_;
@@ -199,16 +201,18 @@ private:
};
class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate {
- friend class ThreadedPacketQueue;
+ friend class PacketProcessor;
friend class TunWin32Iocp;
friend class TunWin32Overlapped;
+ friend class TunWin32Adapter;
public:
TunsafeBackendWin32(Delegate *delegate);
~TunsafeBackendWin32();
// -- from TunsafeBackend
- virtual bool Initialize() override;
+ virtual bool Configure() override;
virtual void Teardown() override;
+ virtual bool SetTunAdapterName(const char *name) override;
virtual void Start(const char *config_file) override;
virtual void Stop() override;
virtual void RequestStats(bool enable) override;
@@ -218,13 +222,23 @@ public:
virtual void SetServiceStartupFlags(uint32 flags) override;
virtual LinearizedGraph *GetGraph(int type) override;
virtual std::string GetConfigFileName() override;
+ virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) override;
// -- from ProcessorDelegate
virtual void OnConnected() override;
virtual void OnConnectionRetry(uint32 attempts) override;
void SetPublicKey(const uint8 key[32]);
- void TunAdapterFailed();
+ void PostExit(int exit_code);
+ enum {
+ MODE_NONE = 0,
+ MODE_EXIT = 1,
+ MODE_RESTART = 2,
+ MODE_TUN_FAILED = 3,
+ };
+ uint32 exit_code() { return *packet_processor_.posted_exit_code(); }
+
+ void SetStatus(StatusCode status);
private:
void StopInner(bool is_restart);
@@ -232,16 +246,7 @@ private:
void PushStats();
HANDLE worker_thread_;
-
- enum {
- MODE_NONE = 0,
- MODE_EXIT = 1,
- MODE_RESTART = 2,
- MODE_TUN_FAILED = 3,
- };
-
bool want_periodic_stats_;
- unsigned int stop_mode_;
Delegate *delegate_;
char *config_file_;
@@ -256,6 +261,10 @@ private:
Mutex stats_mutex_;
WgProcessorStats stats_;
+
+ PacketProcessor packet_processor_;
+
+ char guid_[ADAPTER_GUID_SIZE];
};
// This class ensures that all callbacks get rescheduled to another thread
@@ -265,13 +274,14 @@ public:
~TunsafeBackendDelegateThreaded();
private:
- virtual void OnGetStats(const WgProcessorStats &stats);
- virtual void OnGraphAvailable();
- virtual void OnStateChanged();
- virtual void OnClearLog();
- virtual void OnLogLine(const char **s);
- virtual void OnStatusCode(TunsafeBackend::StatusCode status);
- virtual void DoWork();
+ virtual void OnGetStats(const WgProcessorStats &stats) override;
+ virtual void OnGraphAvailable() override;
+ virtual void OnStateChanged() override;
+ virtual void OnClearLog() override;
+ virtual void OnLogLine(const char **s) override;
+ virtual void OnStatusCode(TunsafeBackend::StatusCode status) override;
+ virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) override;
+ virtual void DoWork() override;
enum Which {
Id_OnGetStats,
@@ -281,6 +291,7 @@ private:
Id_OnUpdateUI,
Id_OnStatusCode,
Id_OnGraphAvailable,
+ Id_OnConfigurationProtocolReply,
};
void AddEntry(Which which, intptr_t lparam = 0, uint32 wparam = 0);
@@ -302,3 +313,37 @@ private:
std::vector processing_entry_;
};
+// For each adapter, remembers whether the adapter is in use
+class TunAdaptersInUse {
+public:
+ TunAdaptersInUse();
+
+ // attempt to acquire the adapter, so it can't be acquired by anyone else
+ bool Acquire(const char guid[ADAPTER_GUID_SIZE], void *context);
+
+ // mark as free
+ void Release(void *context);
+
+ // Lookup a context from a guid
+ void *LookupContextFromGuid(const char guid[ADAPTER_GUID_SIZE]);
+
+ // Lookup a guid from a context
+ bool LookupGuidFromContext(void *context, char guid[ADAPTER_GUID_SIZE]);
+
+ char *GetAllGuid();
+
+ static TunAdaptersInUse *GetInstance();
+
+private:
+ enum {
+ kMaxAdaptersInUse = 16,
+ };
+ struct Entry {
+ char guid[ADAPTER_GUID_SIZE];
+ void *context;
+ int count;
+ };
+ Mutex mutex_;
+ uint8 num_inuse_;
+ Entry entry_[kMaxAdaptersInUse];
+};
diff --git a/network_win32_api.h b/network_win32_api.h
index bf5cf88..b280997 100644
--- a/network_win32_api.h
+++ b/network_win32_api.h
@@ -5,7 +5,6 @@
#include "stdafx.h"
#include "tunsafe_types.h"
#include "wireguard.h"
-
#include
struct StatsCollector {
@@ -72,6 +71,7 @@ public:
virtual void OnClearLog() = 0;
virtual void OnLogLine(const char **s) = 0;
virtual void OnStatusCode(TunsafeBackend::StatusCode status) = 0;
+ virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) = 0;
// This function is needed for CreateTunsafeBackendDelegateThreaded,
// It's expected to be called on the main thread and then all callbacks will arrive
// on the right thread.
@@ -82,9 +82,16 @@ public:
virtual ~TunsafeBackend();
// Setup/teardown the connection to the local service (if any)
- virtual bool Initialize() = 0;
+ virtual bool Configure() = 0;
virtual void Teardown() = 0;
+ // Set the name of the tun adapter that we want to use.
+ // On Windows this is the guid of the adapter.
+ // After having called this, this tun name cannot be used by any other instances.
+ // Returns false if the name can't be exclusively reserved to this adapter.
+ virtual bool SetTunAdapterName(const char *name) = 0;
+
+
virtual void Start(const char *config_file) = 0;
virtual void Stop() = 0;
virtual void RequestStats(bool enable) = 0;
@@ -93,10 +100,9 @@ public:
virtual InternetBlockState GetInternetBlockState(bool *is_activated) = 0;
virtual void SetInternetBlockState(InternetBlockState s) = 0;
virtual void SetServiceStartupFlags(uint32 flags) = 0;
-
virtual std::string GetConfigFileName() = 0;
-
virtual LinearizedGraph *GetGraph(int type) = 0;
+ virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) = 0;
bool is_started() { return is_started_; }
bool is_remote() { return is_remote_; }
@@ -105,6 +111,9 @@ public:
StatusCode status() { return status_; }
uint32 GetIP() { return ipv4_ip_; }
+ static TunsafeBackend *FindBackendByTunGuid(const char *guid);
+ static char *GetAllGuid();
+
protected:
bool is_started_;
bool is_remote_;
diff --git a/service_pipe_win32.cpp b/service_pipe_win32.cpp
new file mode 100644
index 0000000..bb178c5
--- /dev/null
+++ b/service_pipe_win32.cpp
@@ -0,0 +1,374 @@
+// SPDX-License-Identifier: AGPL-1.0-only
+// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved.
+#include "stdafx.h"
+#include "service_pipe_win32.h"
+#include "util.h"
+#include "service_win32_constants.h"
+
+///////////////////////////////////////////////////////////////////////////////////////
+// PipeManager
+///////////////////////////////////////////////////////////////////////////////////////
+
+PipeManager::PipeManager(const char *pipe_name, bool is_server_pipe, Delegate *delegate) {
+ pipe_name_ = _strdup(pipe_name);
+ is_server_pipe_ = is_server_pipe;
+ for (size_t i = 0; i < kMaxConnections * 2 + 1; i++)
+ events_[i] = CreateEvent(NULL, i != 0, FALSE, NULL); // For Exit
+ delegate_ = delegate;
+ thread_ = NULL;
+ exit_thread_ = false;
+ thread_id_ = 0;
+ for (size_t i = 0; i != kMaxConnections; i++)
+ connections_[i].Configure(this, (int)i);
+ connections_[0].state_ = PipeConnection::kStateStarting;
+}
+
+PipeManager::~PipeManager() {
+ StopThread();
+ for (size_t i = 0; i < kMaxConnections * 2 + 1; i++)
+ CloseHandle(events_[i]);
+ free(pipe_name_);
+}
+
+bool PipeManager::StartThread() {
+ assert(thread_ == NULL);
+ thread_ = CreateThread(NULL, 0, &StaticThreadMain, this, 0, &thread_id_);
+ return thread_ != NULL;
+}
+
+void PipeManager::StopThread() {
+ if (thread_ != NULL) {
+ exit_thread_ = true;
+ SetEvent(events_[0]);
+ WaitForSingleObject(thread_, INFINITE);
+ CloseHandle(thread_);
+ thread_ = NULL;
+ }
+}
+
+bool PipeManager::VerifyThread() {
+ return thread_id_ == GetCurrentThreadId();
+}
+
+void PipeManager::TryStartNewListener() {
+ assert(VerifyThread());
+ assert(is_server_pipe_);
+ // Check if any thread is in the listener state, if not, start
+ PipeConnection *found_conn = NULL;
+ for (size_t i = 0; i < kMaxConnections; i++) {
+ PipeConnection *conn = &connections_[i];
+ if (conn->connection_established_)
+ continue;
+ if (conn->state_ == PipeConnection::kStateWaitConnect)
+ return;
+ if (conn->state_ == PipeConnection::kStateNone && found_conn == NULL)
+ found_conn = conn;
+ }
+ if (found_conn) {
+ found_conn->state_ = PipeConnection::kStateStarting;
+ found_conn->AdvanceStateMachine();
+ }
+}
+
+DWORD WINAPI PipeManager::StaticThreadMain(void *x) {
+ return ((PipeManager*)x)->ThreadMain();
+}
+
+DWORD PipeManager::ThreadMain() {
+ assert(VerifyThread());
+
+ for (size_t i = 0; i < kMaxConnections; i++)
+ connections_[i].AdvanceStateMachine();
+
+ for (;;) {
+ DWORD rv = WaitForMultipleObjects(1 + kMaxConnections * 2, events_, FALSE, INFINITE);
+
+ // notify?
+ if (rv == WAIT_OBJECT_0) {
+ if (exit_thread_)
+ break;
+
+ delegate_->HandleNotify();
+ // The notification event is set when there might be new messages to send,
+ // so try to send them.
+ for (size_t i = 0; i != kMaxConnections; i++)
+ connections_[i].TrySendNextQueuedWrite();
+ } else if (rv >= WAIT_OBJECT_0 + 1 && rv < WAIT_OBJECT_0 + 1 + kMaxConnections * 2) {
+ PipeConnection *conn = &connections_[(rv - 1) >> 1];
+ if (rv & 1) {
+ // read finished
+ conn->AdvanceStateMachine();
+ } else {
+ // is the write event
+ conn->HandleWriteComplete();
+ }
+ } else {
+ assert(0);
+ }
+ }
+ return 0;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////
+// PipeConnection
+///////////////////////////////////////////////////////////////////////////////////////
+
+static void ClearPipeOverlapped(OVERLAPPED *ov) {
+ ov->Internal = 0;
+ ov->InternalHigh = 0;
+ ov->Offset = 0;
+ ov->OffsetHigh = 0;
+}
+
+PipeConnection::PipeConnection() {
+ pipe_ = INVALID_HANDLE_VALUE;
+ packets_ = NULL;
+ packets_end_ = &packets_;
+ write_overlapped_active_ = false;
+ connection_established_ = false;
+ state_ = kStateNone;
+ tmp_packet_buf_ = NULL;
+ tmp_packet_size_ = 0;
+ manager_ = NULL;
+ delegate_ = NULL;
+}
+
+PipeConnection::~PipeConnection() {
+}
+
+void PipeConnection::Configure(PipeManager *manager, int slot) {
+ manager_ = manager;
+ read_overlapped_.hEvent = manager->events_[1 + slot * 2];
+ write_overlapped_.hEvent = manager->events_[1 + slot * 2 + 1];
+}
+
+int PipeConnection::InitializeServerPipeAndConnect() {
+ int BUFSIZE = 8192;
+ SECURITY_ATTRIBUTES saPipeSecurity = {0};
+ uint8 buf[SECURITY_DESCRIPTOR_MIN_LENGTH];
+ PSECURITY_DESCRIPTOR pPipeSD = (PSECURITY_DESCRIPTOR)buf;
+
+ if (!InitializeSecurityDescriptor(pPipeSD, SECURITY_DESCRIPTOR_REVISION))
+ return -1;
+
+ // set NULL DACL on the SD
+ if (!SetSecurityDescriptorDacl(pPipeSD, TRUE, (PACL)NULL, FALSE))
+ return -1;
+
+ // now set up the security attributes
+ saPipeSecurity.nLength = sizeof(SECURITY_ATTRIBUTES);
+ saPipeSecurity.bInheritHandle = TRUE;
+ saPipeSecurity.lpSecurityDescriptor = pPipeSD;
+
+ pipe_ = CreateNamedPipe(manager_->pipe_name_,
+ PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED,
+ PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_REJECT_REMOTE_CLIENTS | PIPE_WAIT,
+ PIPE_UNLIMITED_INSTANCES,
+ BUFSIZE, BUFSIZE, 0, &saPipeSecurity);
+ if (pipe_ == INVALID_HANDLE_VALUE)
+ return -1;
+
+ ClearPipeOverlapped(&read_overlapped_);
+ // It seems like ConnectNamedPipe never sets the event object if it completes
+ // right away.
+ if (!ConnectNamedPipe(pipe_, &read_overlapped_)) {
+ DWORD rv = GetLastError();
+ return (rv == ERROR_IO_PENDING) ? 0 : (rv == ERROR_PIPE_CONNECTED) ? 1 : -1;
+ } else {
+ return 1;
+ }
+}
+
+bool PipeConnection::InitializeClientPipe() {
+ assert(pipe_ == INVALID_HANDLE_VALUE);
+ pipe_ = CreateFile(manager_->pipe_name_, GENERIC_READ | GENERIC_WRITE, 0, NULL,
+ OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL);
+ return (pipe_ != INVALID_HANDLE_VALUE);
+}
+
+void PipeConnection::ClosePipe() {
+ if (pipe_ != INVALID_HANDLE_VALUE) {
+ CancelIo(pipe_);
+ CloseHandle(pipe_);
+ pipe_ = INVALID_HANDLE_VALUE;
+ }
+ connection_established_ = false;
+ write_overlapped_active_ = false;
+ state_ = kStateNone;
+
+ free(tmp_packet_buf_);
+ tmp_packet_buf_ = NULL;
+ tmp_packet_size_ = 0;
+
+ ResetEvent(read_overlapped_.hEvent);
+ ResetEvent(write_overlapped_.hEvent);
+
+ packets_mutex_.Acquire();
+ OutgoingPacket *packets = packets_;
+ packets_ = NULL;
+ packets_end_ = &packets_;
+ packets_mutex_.Release();
+ while (packets) {
+ OutgoingPacket *p = packets;
+ packets = p->next;
+ free(p);
+ }
+}
+
+void PipeConnection::HandleWriteComplete() {
+ assert(write_overlapped_active_);
+
+ write_overlapped_active_ = false;
+
+ // Remove the packet from the front of the queue, now that it was sent.
+ packets_mutex_.Acquire();
+ OutgoingPacket *p = packets_;
+ if ((packets_ = p->next) == NULL)
+ packets_end_ = &packets_;
+ packets_mutex_.Release();
+ free(p);
+
+ if (packets_ == NULL && state_ == kStateWaitTimeout)
+ AdvanceStateMachine();
+ else
+ TrySendNextQueuedWrite();
+}
+
+bool PipeConnection::WritePacket(int type, const uint8 *data, size_t data_size) {
+ OutgoingPacket *packet = (OutgoingPacket *)malloc(offsetof(OutgoingPacket, data[data_size + 1]));
+ if (packet) {
+ packet->size = (uint32)(data_size + 1);
+ packet->data[0] = type;
+ memcpy(packet->data + 1, data, data_size);
+ packet->next = NULL;
+
+ packets_mutex_.Acquire();
+ OutgoingPacket *was_empty = packets_;
+ // login messages are always queued up front
+ if (type == TS_SERVICE_REQ_LOGIN) {
+ packet->next = packets_;
+ if (packet->next == NULL)
+ packets_end_ = &packet->next;
+ packets_ = packet;
+ } else {
+ *packets_end_ = packet;
+ packets_end_ = &packet->next;
+ }
+ packets_mutex_.Release();
+
+ if (was_empty == NULL) {
+ // Only allow the pipe thread to invoke the send
+ if (GetCurrentThreadId() == manager_->thread_id_) {
+ TrySendNextQueuedWrite();
+ } else {
+ SetEvent(manager_->notify_handle());
+ }
+ }
+ }
+ return true;
+}
+
+bool PipeConnection::VerifyThread() {
+ return manager_->VerifyThread();
+}
+
+void PipeConnection::TrySendNextQueuedWrite() {
+ assert(manager_->VerifyThread());
+ if (!write_overlapped_active_) {
+ OutgoingPacket *p = packets_;
+ if (p && connection_established_) {
+ ClearPipeOverlapped(&write_overlapped_);
+ if (WriteFile(pipe_, &p->size, p->size + 4, NULL, &write_overlapped_) || GetLastError() == ERROR_IO_PENDING)
+ write_overlapped_active_ = true;
+ } else {
+ ResetEvent(write_overlapped_.hEvent);
+ }
+ }
+}
+
+#define TS_WAIT_BEGIN(t) switch(state_) { case t:
+#define TS_WAIT_POINT(t) state_ = (t); return; case t:
+#define TS_WAIT_END() }
+
+void PipeConnection::AdvanceStateMachine() {
+ DWORD rv;
+ int srv;
+
+ TS_WAIT_BEGIN(kStateStarting)
+ // Create a named pipe and wait for connections from the UI process
+ if (manager_->is_server_pipe_) {
+ srv = InitializeServerPipeAndConnect();
+ if (srv < 0) {
+ if (!manager_->exit_thread_)
+ ExitProcess(1);
+ ClosePipe();
+ return;
+ }
+ if (srv == 0) {
+ TS_WAIT_POINT(kStateWaitConnect);
+ }
+ } else {
+ if (!InitializeClientPipe()) {
+ RINFO("Unable to connect to the TunSafe Service. Please make sure it's running.");
+ ClosePipe();
+ return;
+ }
+ }
+ connection_established_ = true;
+ delegate_ = manager_->delegate_->HandleNewConnection(this);
+ TrySendNextQueuedWrite();
+
+ for (;;) {
+ // Read the packet length
+ read_pos_ = 0;
+ do {
+ ClearPipeOverlapped(&read_overlapped_);
+ if (!ReadFile(pipe_, (uint8*)&packet_size_ + read_pos_, 4 - read_pos_, NULL, &read_overlapped_)) {
+ if ((rv = GetLastError()) != ERROR_IO_PENDING)
+ goto fail;
+ TS_WAIT_POINT(kStateWaitReadLength);
+ }
+ if ((uint32)read_overlapped_.InternalHigh == 0)
+ goto fail;
+ read_pos_ += (uint32)read_overlapped_.InternalHigh;
+ } while (read_pos_ != 4);
+ assert(packet_size_ != 0 && packet_size_ < 0x1000000);
+ if (packet_size_ == 0 || packet_size_ >= 0x1000000)
+ break;
+ free(tmp_packet_buf_);
+ tmp_packet_buf_ = (uint8*)malloc(packet_size_);
+ if (!tmp_packet_buf_)
+ break;
+ // Read the packet payload
+ read_pos_ = 0;
+ do {
+ ClearPipeOverlapped(&read_overlapped_);
+ if (!ReadFile(pipe_, tmp_packet_buf_ + read_pos_, packet_size_ - read_pos_, NULL, &read_overlapped_)) {
+ if ((rv = GetLastError()) != ERROR_IO_PENDING)
+ goto fail;
+ TS_WAIT_POINT(kStateWaitReadPayload);
+ }
+ if ((uint32)read_overlapped_.InternalHigh == 0)
+ goto fail;
+ read_pos_ += (uint32)read_overlapped_.InternalHigh;
+ } while (read_pos_ != packet_size_);
+ if (!delegate_->HandleMessage(tmp_packet_buf_[0], tmp_packet_buf_ + 1, packet_size_ - 1)) {
+ ResetEvent(read_overlapped_.hEvent);
+ if (packets_ != NULL) {
+ TS_WAIT_POINT(kStateWaitTimeout);
+ }
+ break;
+ }
+ }
+fail:
+ ClosePipe();
+ if (!manager_->exit_thread_) {
+ delegate_->HandleDisconnect();
+ if (manager_->is_server_pipe_)
+ manager_->TryStartNewListener();
+ }
+ TS_WAIT_END()
+
+
+}
+
diff --git a/service_pipe_win32.h b/service_pipe_win32.h
new file mode 100644
index 0000000..e350755
--- /dev/null
+++ b/service_pipe_win32.h
@@ -0,0 +1,111 @@
+// SPDX-License-Identifier: AGPL-1.0-only
+// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved.
+#pragma once
+#include "tunsafe_threading.h"
+#include "network_win32_api.h"
+
+class PipeManager;
+
+// Once a pipe connects, this object is used to facilitate the connection
+class PipeConnection {
+ friend class PipeManager;
+public:
+ class Delegate {
+ public:
+ virtual bool HandleMessage(int type, uint8 *data, size_t size) = 0;
+ virtual void HandleDisconnect() = 0;
+ };
+
+ PipeConnection();
+ ~PipeConnection();
+
+ void Configure(PipeManager *manager, int slot);
+ bool WritePacket(int type, const uint8 *data, size_t data_size);
+ HANDLE pipe_handle() { return pipe_; }
+ bool is_connected() { return connection_established_; }
+ bool VerifyThread();
+
+private:
+ // -1 = fail, 0 = wait, 1 = conn
+ int InitializeServerPipeAndConnect();
+ bool InitializeClientPipe();
+ void AdvanceStateMachine();
+ void ClosePipe();
+ void TrySendNextQueuedWrite();
+ void HandleWriteComplete();
+
+ Delegate *delegate_;
+ PipeManager *manager_;
+
+ HANDLE pipe_;
+ bool write_overlapped_active_;
+ bool connection_established_;
+
+ enum State {
+ kStateNone,
+ kStateStarting,
+ kStateWaitConnect,
+ kStateWaitReadLength,
+ kStateWaitReadPayload,
+ kStateWaitTimeout,
+ };
+
+ uint8 state_;
+
+ uint32 packet_size_;
+
+ struct OutgoingPacket {
+ OutgoingPacket *next;
+ uint32 size;
+ uint8 data[0];
+ };
+ OutgoingPacket *packets_, **packets_end_;
+ uint8 *tmp_packet_buf_;
+ DWORD tmp_packet_size_;
+ DWORD read_pos_;
+ OVERLAPPED write_overlapped_, read_overlapped_;
+ Mutex packets_mutex_;
+};
+
+// This class supports multiple PipeConnections and calls HandleNewConnection
+// when a new pipe connection is established.
+class PipeManager {
+ friend class PipeConnection;
+public:
+ class Delegate {
+ public:
+ // Called when a new connection is established
+ virtual PipeConnection::Delegate *HandleNewConnection(PipeConnection *handler) = 0;
+
+ // Called when a notification event was pushed
+ virtual void HandleNotify() = 0;
+ };
+
+ PipeManager(const char *pipe_name, bool is_server_pipe, Delegate *delegate);
+ ~PipeManager();
+
+ bool StartThread();
+ void StopThread();
+ bool VerifyThread();
+
+ HANDLE notify_handle() { return events_[0]; }
+ PipeConnection *GetClientConnection() { return &connections_[0]; }
+
+ void TryStartNewListener();
+
+private:
+ DWORD ThreadMain();
+ static DWORD WINAPI StaticThreadMain(void *x);
+
+ Delegate *delegate_;
+ HANDLE thread_;
+ char *pipe_name_;
+ DWORD thread_id_;
+ bool is_server_pipe_;
+ bool exit_thread_;
+
+ enum { kMaxConnections = 2 };
+ HANDLE events_[1 + kMaxConnections * 2];
+ PipeConnection connections_[kMaxConnections];
+};
+
diff --git a/service_win32.cpp b/service_win32.cpp
index 48866ff..b4e3c1f 100644
--- a/service_win32.cpp
+++ b/service_win32.cpp
@@ -9,40 +9,19 @@
#include
#include
#include "util_win32.h"
-
-static const uint64 kTunsafeServiceProtocolVersion = 20180809001;
+#include "service_win32_constants.h"
static SERVICE_STATUS_HANDLE m_statusHandle;
-static TunsafeServiceImpl *g_service;
+static TunsafeServiceManager *g_service;
+
+#define SERVICE_DEBUGGING 0
#define SERVICE_NAME L"TunSafeService"
#define SERVICE_NAMEA "TunSafeService"
#define SERVICE_START_TYPE SERVICE_AUTO_START
#define SERVICE_DEPENDENCIES L"tap0901\0dhcp\0"
#define SERVICE_ACCOUNT NULL
-//L"NT AUTHORITY\\LocalService"
#define SERVICE_PASSWORD NULL
-#define PIPE_NAME "\\\\.\\pipe\\TunSafe\\ServiceControl"
-
-
-enum {
- SERVICE_REQ_LOGIN = 0,
- SERVICE_REQ_START = 1,
- SERVICE_REQ_STOP = 2,
- SERVICE_REQ_GETSTATS = 4,
- SERVICE_REQ_SET_INTERNET_BLOCKSTATE = 5,
- SERVICE_REQ_RESETSTATS = 6,
- SERVICE_REQ_SET_STARTUP_FLAGS = 7,
-
- SERVICE_MSG_STATE = 8,
- SERVICE_MSG_LOGLINE = 9,
- SERVICE_MSG_STATS = 11,
- SERVICE_MSG_CLEARLOG = 12,
- SERVICE_MSG_STATUS_CODE = 14,
-
- SERVICE_REQ_GET_GRAPH = 15,
- SERVICE_MSG_GRAPH = 16,
-};
struct ServiceHandles {
SC_HANDLE manager;
@@ -61,7 +40,6 @@ struct ServiceHandles {
bool StartService();
};
-
static DWORD InstallService(PWSTR pszServiceName,
PWSTR pszDisplayName,
DWORD dwStartType,
@@ -191,6 +169,18 @@ getout:
return result;
}
+static wchar_t *RegReadStrW(HKEY hkey, const wchar_t *key, const wchar_t *def) {
+ wchar_t buf[1024];
+ DWORD n = sizeof(buf) - 2;
+ DWORD type = 0;
+ if (RegQueryValueExW(hkey, key, NULL, &type, (BYTE*)buf, &n) != ERROR_SUCCESS || type != REG_SZ)
+ return def ? _wcsdup(def) : NULL;
+ n >>= 1;
+ if (n && buf[n - 1] == 0)
+ n--;
+ buf[n] = 0;
+ return _wcsdup(buf);
+}
static DWORD GetNonTransientServiceStatus(SC_HANDLE service) {
SERVICE_STATUS ssSvcStatus = {};
@@ -208,7 +198,6 @@ static DWORD GetNonTransientServiceStatus(SC_HANDLE service) {
}
}
-
bool ServiceHandles::StartService() {
DWORD state = GetNonTransientServiceStatus(service);
if (state == 0 || state == SERVICE_RUNNING)
@@ -221,7 +210,6 @@ bool ServiceHandles::StartService() {
return GetNonTransientServiceStatus(service) == SERVICE_RUNNING;
}
-
static bool StartTunsafeService() {
ServiceHandles handles;
@@ -232,14 +220,14 @@ static bool StartTunsafeService() {
bool IsTunsafeServiceRunning() {
ServiceHandles handles;
-
+#if SERVICE_DEBUGGING
+ return true;
+#endif
if (!handles.Open(SERVICE_NAME, SC_MANAGER_CONNECT, SERVICE_QUERY_STATUS))
return false;
-
return GetNonTransientServiceStatus(handles.service) == SERVICE_RUNNING;
}
-
void StopTunsafeService() {
ServiceHandles handles;
if (!handles.Open(SERVICE_NAME, SC_MANAGER_CONNECT,
@@ -296,7 +284,6 @@ bool IsTunSafeServiceInstalled() {
return handles.Open(SERVICE_NAME, SC_MANAGER_CONNECT, SERVICE_QUERY_STATUS);
}
-
static void WriteServiceLog(const char *pszFunction, WORD dwError) {
char szMessage[260];
snprintf(szMessage, ARRAYSIZE(szMessage), "%s failed w/err 0x%08lx", pszFunction, dwError);
@@ -321,8 +308,7 @@ static void WriteServiceLog(const char *pszFunction, WORD dwError) {
}
}
-static void SetServiceStatus(DWORD dwCurrentState,
- DWORD dwWin32ExitCode = 0,
+static void SetServiceStatus(DWORD dwCurrentState, DWORD dwWin32ExitCode = 0,
DWORD dwWaitHint = 0) {
static DWORD dwCheckPoint = 1;
@@ -333,10 +319,8 @@ static void SetServiceStatus(DWORD dwCurrentState,
m_status.dwCurrentState = dwCurrentState;
m_status.dwWin32ExitCode = dwWin32ExitCode;
m_status.dwWaitHint = dwWaitHint;
- m_status.dwCheckPoint =
- ((dwCurrentState == SERVICE_RUNNING) ||
- (dwCurrentState == SERVICE_STOPPED)) ?
- 0 : dwCheckPoint++;
+ m_status.dwCheckPoint = ((dwCurrentState == SERVICE_RUNNING) ||
+ (dwCurrentState == SERVICE_STOPPED)) ? 0 : dwCheckPoint++;
// Report the status of the service to the SCM.
::SetServiceStatus(m_statusHandle, &m_status);
}
@@ -389,495 +373,147 @@ static const SERVICE_TABLE_ENTRYW serviceTable[] = {
{NULL, NULL}
};
-PipeMessageHandler::PipeMessageHandler(const char *pipe_name, bool is_server_pipe, Delegate *delegate) {
- pipe_name_ = _strdup(pipe_name);
- is_server_pipe_ = is_server_pipe;
- delegate_ = delegate;
- pipe_ = INVALID_HANDLE_VALUE;
- wait_handles_[0] = CreateEvent(NULL, TRUE, FALSE, NULL); // for ReadFile
- wait_handles_[1] = CreateEvent(NULL, FALSE, FALSE, NULL); // For Exit
- wait_handles_[2] = CreateEvent(NULL, TRUE, FALSE, NULL); // for WriteFile
- packets_ = NULL;
- thread_ = NULL;
- packets_end_ = &packets_;
- write_overlapped_active_ = false;
- exit_thread_ = false;
- connection_established_ = false;
- thread_id_ = 0;
- state_ = kStateNone;
- tmp_packet_buf_ = NULL;
-}
+///////////////////////////////////////////////////////////////////////////////////////
+// TunsafeServiceManager
+///////////////////////////////////////////////////////////////////////////////////////
-PipeMessageHandler::~PipeMessageHandler() {
- StopThread();
- CloseHandle(wait_handles_[0]);
- CloseHandle(wait_handles_[1]);
- CloseHandle(wait_handles_[2]);
- free(pipe_name_);
-}
-
-bool PipeMessageHandler::InitializeServerPipeAndWait() {
- int BUFSIZE = 2048;
- SECURITY_ATTRIBUTES saPipeSecurity = {0};
- uint8 buf[SECURITY_DESCRIPTOR_MIN_LENGTH];
- PSECURITY_DESCRIPTOR pPipeSD = (PSECURITY_DESCRIPTOR)buf;
-
- if (!InitializeSecurityDescriptor(pPipeSD, SECURITY_DESCRIPTOR_REVISION))
- return false;
-
- // set NULL DACL on the SD
- if (!SetSecurityDescriptorDacl(pPipeSD, TRUE, (PACL)NULL, FALSE))
- return false;
-
- // now set up the security attributes
- saPipeSecurity.nLength = sizeof(SECURITY_ATTRIBUTES);
- saPipeSecurity.bInheritHandle = TRUE;
- saPipeSecurity.lpSecurityDescriptor = pPipeSD;
-
- pipe_ = CreateNamedPipeW(L"\\\\.\\pipe\\TunSafe\\ServiceControl",
- PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED,
- PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS | PIPE_WAIT,
- PIPE_UNLIMITED_INSTANCES,
- BUFSIZE, BUFSIZE, 0, &saPipeSecurity);
- if (pipe_ == INVALID_HANDLE_VALUE)
- return false;
+TunsafeServiceManager::TunsafeServiceManager()
+ : pipe_manager_(TUNSAFE_PIPE_NAME, true, this) {
+ server_unique_id_ = 0;
- memset(&read_overlapped_, 0, sizeof(read_overlapped_));
- read_overlapped_.hEvent = wait_handles_[0];
- if (!ConnectNamedPipe(pipe_, &read_overlapped_)) {
- DWORD rv = GetLastError();
- if (rv != ERROR_PIPE_CONNECTED && rv != ERROR_IO_PENDING)
- return false;
- }
- return true;
-}
-
-bool PipeMessageHandler::InitializeClientPipe() {
- assert(pipe_ == INVALID_HANDLE_VALUE);
- pipe_ = CreateFile(pipe_name_, GENERIC_READ | GENERIC_WRITE, 0, NULL,
- OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL);
- if (pipe_ == INVALID_HANDLE_VALUE)
- return false;
- DWORD mode = PIPE_READMODE_MESSAGE;
- SetNamedPipeHandleState(pipe_, &mode, NULL, NULL);
- return true;
-}
-
-void PipeMessageHandler::ClosePipe() {
- if (pipe_ != INVALID_HANDLE_VALUE) {
- CancelIo(pipe_);
- CloseHandle(pipe_);
- pipe_ = INVALID_HANDLE_VALUE;
- }
- connection_established_ = false;
- write_overlapped_active_ = false;
-
- free(tmp_packet_buf_);
- tmp_packet_buf_ = NULL;
-
- ResetEvent(wait_handles_[0]);
- ResetEvent(wait_handles_[2]);
-
- packets_mutex_.Acquire();
- OutgoingPacket *packets = packets_;
- packets_ = NULL;
- packets_end_ = &packets_;
- packets_mutex_.Release();
- while (packets) {
- OutgoingPacket *p = packets;
- packets = p->next;
- free(p);
- }
-}
-
-bool PipeMessageHandler::WritePacket(int type, const uint8 *data, size_t data_size) {
- OutgoingPacket *packet = (OutgoingPacket *)malloc(offsetof(OutgoingPacket, data[data_size + 1]));
- if (packet) {
- packet->size = (uint32)(data_size + 1);
- packet->data[0] = type;
- memcpy(packet->data + 1, data, data_size);
- packet->next = NULL;
-
- packets_mutex_.Acquire();
- OutgoingPacket *was_empty = packets_;
- // login messages are always queued up front
- if (type == SERVICE_REQ_LOGIN) {
- packet->next = packets_;
- if (packet->next == NULL)
- packets_end_ = &packet->next;
- packets_ = packet;
- } else {
- *packets_end_ = packet;
- packets_end_ = &packet->next;
- }
- packets_mutex_.Release();
-
- if (was_empty == NULL) {
- // Only allow the pipe thread to invoke the send
- if (GetCurrentThreadId() == thread_id_) {
- SendNextQueuedWrite();
- } else {
- SetEvent(wait_handles_[1]);
- }
- }
- }
- return true;
-}
-
-void PipeMessageHandler::SendNextQueuedWrite() {
- assert(thread_id_ == GetCurrentThreadId());
- if (!write_overlapped_active_) {
- OutgoingPacket *p = packets_;
- if (p && connection_established_) {
- memset(&write_overlapped_, 0, sizeof(write_overlapped_));
- write_overlapped_.hEvent = wait_handles_[2];
- if (WriteFile(pipe_, p->data, p->size, NULL, &write_overlapped_) || GetLastError() == ERROR_IO_PENDING)
- write_overlapped_active_ = true;
- } else {
- ResetEvent(wait_handles_[2]);
- }
- }
-}
-
-#define TS_WAIT_BEGIN(t) switch(state_) { case t:
-#define TS_WAIT_POINT(t) state_ = (t); return; case t:
-#define TS_WAIT_END() }
-
-void PipeMessageHandler::AdvanceStateMachine() {
- DWORD rv, bytes_read;
-
- TS_WAIT_BEGIN(kStateNone)
- for(;;) {
- // Create a named pipe and wait for connections from the UI process
- if (is_server_pipe_) {
- if (!InitializeServerPipeAndWait()) {
- if (!exit_thread_)
- ExitProcess(1);
- break;
- }
- TS_WAIT_POINT(kStateWaitConnect);
- } else {
- if (!InitializeClientPipe()) {
- RINFO("Unable to connect to the TunSafe Service. Please make sure it's running.");
- break;
- }
- }
- connection_established_ = true;
- delegate_->HandleNewConnection();
- SendNextQueuedWrite();
-
- for (;;) {
- memset(&read_overlapped_, 0, sizeof(read_overlapped_));
- read_overlapped_.hEvent = wait_handles_[0];
- if (!ReadFile(pipe_, NULL, 0, NULL, &read_overlapped_)) {
- rv = GetLastError();
- if (rv != ERROR_IO_PENDING && rv != ERROR_MORE_DATA)
- break;
- }
- TS_WAIT_POINT(kStateWaitReadLength);
- PeekNamedPipe(pipe_, NULL, 0, NULL, &tmp_packet_size_, NULL);
- if (tmp_packet_size_ == 0)
- break;
-
- free(tmp_packet_buf_);
- tmp_packet_buf_ = (uint8*)malloc(tmp_packet_size_);
- if (!tmp_packet_buf_)
- break;
-
- memset(&read_overlapped_, 0, sizeof(read_overlapped_));
- read_overlapped_.hEvent = wait_handles_[0];
- if (!ReadFile(pipe_, tmp_packet_buf_, tmp_packet_size_, NULL, &read_overlapped_)) {
- rv = GetLastError();
- if (rv != ERROR_IO_PENDING)
- break;
- }
- TS_WAIT_POINT(kStateWaitReadPayload);
- bytes_read = (uint32)read_overlapped_.InternalHigh;
- if (bytes_read == 0)
- break;
- if (!delegate_->HandleMessage(tmp_packet_buf_[0], tmp_packet_buf_ + 1, bytes_read - 1)) {
- ResetEvent(wait_handles_[0]);
- TS_WAIT_POINT(kStateWaitTimeout);
- break;
- }
- }
- if (exit_thread_)
- break;
- delegate_->HandleDisconnect();
- if (!is_server_pipe_)
- break;
- ClosePipe();
- }
- TS_WAIT_END()
- ClosePipe();
-}
-
-DWORD WINAPI PipeMessageHandler::StaticThreadMain(void *x) {
- return ((PipeMessageHandler*)x)->ThreadMain();
-}
-
-bool PipeMessageHandler::VerifyThread() {
- return thread_id_ == GetCurrentThreadId();
-}
-
-DWORD PipeMessageHandler::ThreadMain() {
- assert((thread_id_ = GetCurrentThreadId()) != 0);
- assert(state_ == kStateNone);
-
- AdvanceStateMachine();
-
- for(;;) {
- DWORD rv = WaitForMultipleObjects(3, wait_handles_, FALSE, (state_ == kStateWaitTimeout) ? 1000 : INFINITE);
-
- // packet write finished?
- if (rv == WAIT_OBJECT_0 + 2) {
- assert(write_overlapped_active_);
-
- write_overlapped_active_ = false;
-
- // Remove the packet from the front of the queue, now that it was sent.
- packets_mutex_.Acquire();
- OutgoingPacket *p = packets_;
- if ((packets_ = p->next) == NULL)
- packets_end_ = &packets_;
- packets_mutex_.Release();
- free(p);
- SendNextQueuedWrite();
-
- // notification
- } else if (rv == WAIT_OBJECT_0 + 1) {
- if (exit_thread_ || !delegate_->HandleNotify())
- break;
- // The notification event is set when there might be new messages to send,
- // so try to send them.
- SendNextQueuedWrite();
-
- // read finished?
- } else if (rv == WAIT_OBJECT_0) {
- AdvanceStateMachine();
- } else if (rv == WAIT_TIMEOUT) {
- if (state_ == kStateWaitTimeout)
- AdvanceStateMachine();
- } else {
- assert(0);
- }
- }
- return 0;
-}
-
-bool PipeMessageHandler::StartThread() {
- DWORD thread_id;
- assert(thread_ == NULL);
- thread_ = CreateThread(NULL, 0, &StaticThreadMain, this, 0, &thread_id);
- return thread_ != NULL;
-}
-
-void PipeMessageHandler::StopThread() {
- if (thread_ != NULL) {
- exit_thread_ = true;
- SetEvent(wait_handles_[1]);
- WaitForSingleObject(thread_, INFINITE);
- CloseHandle(thread_);
- thread_ = NULL;
- }
- ClosePipe();
-}
-
-TunsafeServiceImpl::TunsafeServiceImpl()
- : message_handler_(PIPE_NAME, true, this) {
- thread_delegate_ = CreateTunsafeBackendDelegateThreaded(this, [=] {
- SetEvent(message_handler_.notify_handle());
- });
-
- backend_ = CreateNativeTunsafeBackend(thread_delegate_);
- historical_log_lines_count_ = historical_log_lines_pos_ = 0;
- last_line_sent_ = 0;
- did_send_getstate_ = false;
- memset(historical_log_lines_, 0, sizeof(historical_log_lines_));
hkey_ = NULL;
- want_graph_type_ = 0xffffffff;
RegCreateKeyEx(HKEY_LOCAL_MACHINE, "Software\\TunSafe", NULL, NULL, 0, KEY_ALL_ACCESS, NULL, &hkey_, NULL);
+
+ main_backend_ = new TunsafeServiceBackend(this);
+ backends_.push_back(main_backend_);
}
-TunsafeServiceImpl::~TunsafeServiceImpl() {
+TunsafeServiceManager::~TunsafeServiceManager() {
+ for (TunsafeServiceBackend *backend : backends_)
+ delete backend;
RegCloseKey(hkey_);
}
-static wchar_t *RegReadStrW(HKEY hkey, const wchar_t *key, const wchar_t *def) {
- wchar_t buf[1024];
- DWORD n = sizeof(buf) - 2;
- DWORD type = 0;
- if (RegQueryValueExW(hkey, key, NULL, &type, (BYTE*)buf, &n) != ERROR_SUCCESS || type != REG_SZ)
- return def ? _wcsdup(def) : NULL;
- n >>= 1;
- if (n && buf[n - 1] == 0)
- n--;
- buf[n] = 0;
- return _wcsdup(buf);
+void TunsafeServiceManager::HandleNotify() {
+ for (TunsafeServiceBackend *backend : backends_)
+ backend->HandleNotify();
}
-unsigned TunsafeServiceImpl::OnStart(int argc, wchar_t **argv) {
+PipeConnection::Delegate *TunsafeServiceManager::HandleNewConnection(PipeConnection *connection) {
+ TunsafeServiceServer *server = new TunsafeServiceServer(connection, main_backend_, server_unique_id_++);
+ main_backend_->AddPipeServer(server);
+ pipe_manager_.TryStartNewListener();
+ return server;
+}
+
+unsigned TunsafeServiceManager::OnStart(int argc, wchar_t **argv) {
uint32 service_flags = RegReadInt(hkey_, "ServiceStartupFlags", 0);
- if ( (service_flags & kStartupFlag_BackgroundService) && (service_flags & kStartupFlag_ConnectWhenWindowsStarts) ) {
+ if ((service_flags & kStartupFlag_BackgroundService) && (service_flags & kStartupFlag_ConnectWhenWindowsStarts)) {
char *conf = RegReadStr(hkey_, "LastUsedConfigFile", "");
- if (conf && *conf) {
- current_filename_ = (char*)conf;
- backend_->Start((char*)conf);
- }
+ if (conf && *conf)
+ main_backend_->Start(conf);
free(conf);
}
-
- message_handler_.StartThread();
+ pipe_manager_.StartThread();
return 0;
}
-bool TunsafeServiceImpl::AuthenticateUser() {
- did_authenticate_user_ = true;
-
- if (!ImpersonateNamedPipeClient(message_handler_.pipe_handle()))
- return false;
- wchar_t *user = GetUsernameOfCurrentUser(true);
- RevertToSelf();
- if (!user)
- return false;
- wchar_t *valid_user = RegReadStrW(hkey_, L"AllowedUsername", L"");
- bool rv = valid_user && wcscmp(user, valid_user) == 0;
-
- free(user);
- free(valid_user);
- return rv;
+void TunsafeServiceManager::OnStop() {
+ pipe_manager_.StopThread();
+ for (TunsafeServiceBackend *backend : backends_)
+ backend->Stop();
}
-bool TunsafeServiceImpl::HandleMessage(int type, uint8 *data, size_t size) {
- if (!did_authenticate_user_) {
- if (type != SERVICE_REQ_LOGIN || size < 8 || *(uint64*)data != kTunsafeServiceProtocolVersion) {
- const char *s = "Versioning Problem: The TunSafe service is a different version than the UI.";
- message_handler_.WritePacket(SERVICE_MSG_LOGLINE, (uint8*)s, strlen(s));
- return false;
- }
- if (!AuthenticateUser()) {
- const char *s = "Permission Problem: Your Windows account is different from the account\r\nthat installed the TunSafe Service. Please reinstall it.\r\n";
- message_handler_.WritePacket(SERVICE_MSG_LOGLINE, (uint8*)s, strlen(s));
- return false;
+void TunsafeServiceManager::OnShutdown() {
+
+}
+
+TunsafeServiceBackend *TunsafeServiceManager::CreateBackend(const char *guid) {
+ TunsafeServiceBackend *service_backend = new TunsafeServiceBackend(this);
+
+ // If we're unable to assign the name, maybe it's already in use
+ if (!service_backend->backend()->SetTunAdapterName(guid)) {
+ delete service_backend;
+ return NULL;
+ }
+ backends_.push_back(service_backend);
+ return service_backend;
+}
+
+void TunsafeServiceManager::DestroyBackend(TunsafeServiceBackend *service_backend) {
+ assert(service_backend != main_backend_);
+
+ // Erase from the list
+ auto it = std::find(backends_.begin(), backends_.end(), service_backend);
+ if (it != backends_.end())
+ backends_.erase(it);
+
+ delete service_backend;
+}
+
+bool TunsafeServiceManager::SwitchInterface(TunsafeServiceServer *server, const char *interfac, bool want_create) {
+ // Find a backend by name
+ TunsafeBackend *backend = TunsafeBackend::FindBackendByTunGuid(interfac);
+ TunsafeServiceBackend *service_backend = NULL;
+ if (backend) {
+ for (TunsafeServiceBackend *sb : backends_) {
+ if (sb->backend() == backend) {
+ service_backend = sb;
+ break;
+ }
}
}
-
- switch (type) {
- case SERVICE_REQ_START:
- if (data[size - 1] != 0)
+ if (!service_backend) {
+ if (!want_create)
return false;
-
- // Don't allow reading arbitrary files on disk
- if (!EnsureValidConfigPath((char*)data)) {
- char buf[MAX_PATH];
- GetConfigPath(buf, sizeof(buf));
- char *s = str_cat_alloc("Permission Problem: The Config file is in an unsafe location.\r\n Must be in:", buf, "\r\n");
- message_handler_.WritePacket(SERVICE_MSG_LOGLINE, (uint8*)s, strlen(s));
- free(s);
+ service_backend = CreateBackend(interfac);
+ if (!service_backend)
return false;
- }
-
- g_allow_pre_post = RegReadInt(hkey_, "AllowPrePost", 0) != 0;
-
- current_filename_ = (char*)data;
- backend_->Start((char*)data);
- RegWriteStr(hkey_, "LastUsedConfigFile", (char*)data);
-
- break;
-
- case SERVICE_REQ_STOP:
- backend_->Stop();
- RegWriteStr(hkey_, "LastUsedConfigFile", "");
- OnStateChanged();
- break;
-
- case SERVICE_REQ_LOGIN:
- did_send_getstate_ = true;
- OnStatusCode(backend_->status());
- OnStateChanged();
- SendQueuedLogLines();
- break;
-
- case SERVICE_REQ_GETSTATS:
- if (size < 1) return false;
- backend_->RequestStats(data[0] != 0);
- break;
-
- case SERVICE_REQ_SET_INTERNET_BLOCKSTATE:
- if (size < 1)
- return false;
- backend_->SetInternetBlockState((InternetBlockState)data[0]);
- OnStateChanged();
- break;
-
- case SERVICE_REQ_RESETSTATS:
- backend_->ResetStats();
- break;
-
- case SERVICE_REQ_GET_GRAPH:
- if (size < 4) return false;
- want_graph_type_ = *(int*)data;
- TunsafeServiceImpl::OnGraphAvailable();
- break;
-
- case SERVICE_REQ_SET_STARTUP_FLAGS:
- if (size < 4)
- return false;
- RegSetValueEx(hkey_, "ServiceStartupFlags", NULL, REG_DWORD, (BYTE*)data, 4);
- break;
-
- default:
- return false;
+ }
+ if (server->service_backend() != service_backend) {
+ server->service_backend()->RemovePipeServer(server);
+ service_backend->AddPipeServer(server);
+ server->set_service_backend(service_backend);
}
return true;
}
-bool TunsafeServiceImpl::HandleNotify() {
- thread_delegate_->DoWork();
- return true;
+
+///////////////////////////////////////////////////////////////////////////////////////
+// TunsafeServiceBackend
+///////////////////////////////////////////////////////////////////////////////////////
+
+TunsafeServiceBackend::TunsafeServiceBackend(TunsafeServiceManager *manager) {
+ manager_ = manager;
+ historical_log_lines_count_ = historical_log_lines_pos_ = 0;
+ memset(historical_log_lines_, 0, sizeof(historical_log_lines_));
+ HANDLE event = manager_->pipe_manager_.notify_handle();
+ thread_delegate_ = CreateTunsafeBackendDelegateThreaded(this, [=] { SetEvent(event); });
+ backend_ = CreateNativeTunsafeBackend(thread_delegate_);
}
-void TunsafeServiceImpl::HandleNewConnection() {
- did_send_getstate_ = false;
- did_authenticate_user_ = false;
- last_line_sent_ = 0;
+TunsafeServiceBackend::~TunsafeServiceBackend() {
+ assert(pipe_servers_.empty());
+ delete backend_;
+ delete thread_delegate_;
}
-void TunsafeServiceImpl::HandleDisconnect() {
- want_graph_type_ = 0xffffffff;
- backend_->RequestStats(false);
- uint32 service_flags = RegReadInt(hkey_, "ServiceStartupFlags", 0);
- if (!(service_flags & kStartupFlag_BackgroundService))
- backend_->Stop();
+void TunsafeServiceBackend::OnGetStats(const WgProcessorStats &stats) {
+ for (TunsafeServiceServer *pipe_server : pipe_servers_)
+ if (pipe_server->want_stats())
+ pipe_server->WritePacket(TS_SERVICE_MSG_STATS, (uint8*)&stats, sizeof(stats));
}
-void TunsafeServiceImpl::OnGraphAvailable() {
- if (want_graph_type_ != 0xffffffff) {
- LinearizedGraph *graph = backend_->GetGraph(want_graph_type_);
- if (graph)
- message_handler_.WritePacket(SERVICE_MSG_GRAPH, (uint8*)graph, graph->total_size);
- }
-}
-
-void TunsafeServiceImpl::SendQueuedLogLines() {
- assert(message_handler_.VerifyThread());
- uint32 maxi = std::min(historical_log_lines_count_, historical_log_lines_pos_ - last_line_sent_);
- last_line_sent_ = historical_log_lines_pos_;
- for (uint32 i = 0; i < maxi; i++) {
- const char *s = historical_log_lines_[(historical_log_lines_pos_ - maxi + i) & (LOGLINE_COUNT - 1)];
- if (s)
- message_handler_.WritePacket(SERVICE_MSG_LOGLINE, (uint8*)s, strlen(s));
- }
-}
-
-void TunsafeServiceImpl::OnClearLog() {
+void TunsafeServiceBackend::OnClearLog() {
historical_log_lines_pos_ = 0;
historical_log_lines_count_ = 0;
- message_handler_.WritePacket(SERVICE_MSG_CLEARLOG, NULL, 0);
+ for (TunsafeServiceServer *pipe_server : pipe_servers_)
+ if (pipe_server->want_state_updates())
+ pipe_server->WritePacket(TS_SERVICE_MSG_CLEARLOG, NULL, 0);
}
-void TunsafeServiceImpl::OnLogLine(const char **s) {
- assert(message_handler_.VerifyThread());
+void TunsafeServiceBackend::OnLogLine(const char **s) {
+ assert(manager_->pipe_manager_.VerifyThread());
char *ss = (char*)*s;
*s = NULL;
char *&x = historical_log_lines_[historical_log_lines_pos_++ & (LOGLINE_COUNT - 1)];
@@ -885,46 +521,321 @@ void TunsafeServiceImpl::OnLogLine(const char **s) {
if (historical_log_lines_count_ < LOGLINE_COUNT)
historical_log_lines_count_++;
free(ss);
- if (did_send_getstate_)
- SendQueuedLogLines();
+
+ for (TunsafeServiceServer *pipe_server : pipe_servers_)
+ pipe_server->SendQueuedLogLines();
}
-void TunsafeServiceImpl::OnGetStats(const WgProcessorStats &stats) {
- message_handler_.WritePacket(SERVICE_MSG_STATS, (uint8*)&stats, sizeof(stats));
+void TunsafeServiceBackend::OnStateChanged() {
+ SendStateUpdate(NULL); // Send to all
}
-void TunsafeServiceImpl::OnStateChanged() {
+void TunsafeServiceBackend::OnStatusCode(TunsafeBackend::StatusCode status) {
+ if (status == TunsafeBackend::kStatusConnected)
+ OnStateChanged(); // ensure we know the ip first
+ uint32 v32 = (uint32)status;
+ for (TunsafeServiceServer *pipe_server : pipe_servers_)
+ if (pipe_server->want_state_updates())
+ pipe_server->WritePacket(TS_SERVICE_MSG_STATUS_CODE, (uint8*)&v32, 4);
+}
+
+void TunsafeServiceBackend::OnGraphAvailable() {
+ for (TunsafeServiceServer *pipe_server : pipe_servers_)
+ pipe_server->OnGraphAvailable();
+}
+
+void TunsafeServiceBackend::OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) {
+ for (TunsafeServiceServer *pipe_server : pipe_servers_)
+ if (pipe_server->unique_id() == ident)
+ pipe_server->WritePacket(TS_SERVICE_REQ_TEXT_PROTOCOL_REPLY, (uint8*)reply.data(), reply.size());
+}
+
+void TunsafeServiceBackend::Start(const char *filename) {
+ g_allow_pre_post = RegReadInt(manager_->hkey_, "AllowPrePost", 0) != 0;
+ current_filename_ = filename;
+ backend_->Start(filename);
+}
+
+void TunsafeServiceBackend::RememberLastUsedConfigFile(const char *filename) {
+ if (manager_->main_backend() == this)
+ RegWriteStr(manager_->hkey_, "LastUsedConfigFile", filename);
+}
+
+void TunsafeServiceBackend::Stop() {
+ if (manager_->main_backend() == this)
+ RegWriteStr(manager_->hkey_, "LastUsedConfigFile", "");
+
+ backend_->Stop();
+ OnStateChanged();
+}
+
+void TunsafeServiceBackend::UpdateRequestStats() {
+ bool want = false;
+ for (auto it = pipe_servers_.begin(); it != pipe_servers_.end(); ++it) {
+ if ((*it)->want_stats()) {
+ want = true;
+ break;
+ }
+ }
+ backend_->RequestStats(want);
+}
+
+void TunsafeServiceBackend::HandleNotify() {
+ thread_delegate_->DoWork();
+}
+
+void TunsafeServiceBackend::SendStateUpdate(TunsafeServiceServer *filter) {
+ if (pipe_servers_.empty())
+ return;
+
uint8 *temp = new uint8[current_filename_.size() + 1 + sizeof(ServiceState)];
bool is_activated;
memset(temp, 0, sizeof(ServiceState));
-
ServiceState *ss = (ServiceState *)temp;
ss->is_started = backend_->is_started();
ss->internet_block_state = backend_->GetInternetBlockState(&is_activated);
ss->internet_block_state_active = is_activated;
ss->ipv4_ip = backend_->GetIP();
memcpy(ss->public_key, backend_->public_key(), 32);
-
memcpy(temp + sizeof(ServiceState), current_filename_.c_str(), current_filename_.size() + 1);
- message_handler_.WritePacket(SERVICE_MSG_STATE, temp, current_filename_.size() + 1 + sizeof(ServiceState));
+ for (TunsafeServiceServer *pipe_server : pipe_servers_) {
+ if (filter != NULL && pipe_server != filter)
+ continue;
+ if (pipe_server->want_state_updates())
+ pipe_server->WritePacket(TS_SERVICE_MSG_STATE, temp, current_filename_.size() + 1 + sizeof(ServiceState));
+ }
+
delete[] temp;
}
-void TunsafeServiceImpl::OnStatusCode(TunsafeBackend::StatusCode status) {
- if (status == TunsafeBackend::kStatusConnected)
- OnStateChanged(); // ensure we know the ip first
- uint32 v32 = (uint32)status;
- message_handler_.WritePacket(SERVICE_MSG_STATUS_CODE, (uint8*)&v32, 4);
+void TunsafeServiceBackend::RemovePipeServer(TunsafeServiceServer *pipe_server) {
+ auto it = std::find(pipe_servers_.begin(), pipe_servers_.end(), pipe_server);
+ if (it != pipe_servers_.end())
+ pipe_servers_.erase(it);
+
+ UpdateRequestStats();
+
+ // Stop the main backend, or destroy a disconnetced backend, when the last client disconnects.
+ if (pipe_servers_.empty()) {
+ if (this == manager_->main_backend_) {
+ uint32 service_flags = RegReadInt(manager_->hkey_, "ServiceStartupFlags", 0);
+ if (!(service_flags & kStartupFlag_BackgroundService))
+ backend_->Stop();
+ } else {
+ if (!backend_->is_started())
+ manager_->DestroyBackend(this);
+ }
+ }
}
-void TunsafeServiceImpl::OnStop() {
- message_handler_.StopThread();
- backend_->Stop();
+void TunsafeServiceBackend::AddPipeServer(TunsafeServiceServer *pipe_server) {
+ pipe_servers_.push_back(pipe_server);
}
-void TunsafeServiceImpl::OnShutdown() {
+///////////////////////////////////////////////////////////////////////////////////////
+// TunsafeServiceServer
+///////////////////////////////////////////////////////////////////////////////////////
+TunsafeServiceServer::TunsafeServiceServer(PipeConnection *pipe, TunsafeServiceBackend *backend, uint32 unique_id) {
+ unique_id_ = unique_id;
+ connection_ = pipe;
+ service_backend_ = backend;
+ last_line_sent_ = 0;
+ want_state_updates_ = false;
+ did_authenticate_user_ = false;
+ want_stats_ = false;
+ want_graph_type_ = 0xffffffff;
+}
+
+TunsafeServiceServer::~TunsafeServiceServer() {
+}
+
+void TunsafeServiceServer::WritePacket(int type, const uint8 *data, size_t data_size) {
+ connection_->WritePacket(type, data, data_size);
+}
+
+struct ServiceLoginMessage {
+ uint64 version;
+ char interfac[kTsMaxDevnameSize];
+ bool want_state_updates;
+ bool want_create_interface;
+};
+
+bool TunsafeServiceServer::HandleMessage(int type, uint8 *data, size_t size) {
+ if (!did_authenticate_user_) {
+ if (type != TS_SERVICE_REQ_LOGIN ||
+ size < sizeof(ServiceLoginMessage) ||
+ ((ServiceLoginMessage*)data)->version != TUNSAFE_SERVICE_PROTOCOL_VERSION) {
+ const char *s = "Versioning Problem: The TunSafe service is a different version than the UI.";
+ connection_->WritePacket(TS_SERVICE_MSG_ERROR_REPLY, (uint8*)s, strlen(s));
+ return false;
+ }
+ if (!AuthenticateUser()) {
+ const char *s = "Permission Problem: Your Windows account is different from the account\r\nthat installed the TunSafe Service. Please reinstall it.";
+ connection_->WritePacket(TS_SERVICE_MSG_ERROR_REPLY, (uint8*)s, strlen(s));
+ return false;
+ }
+ }
+
+ switch (type) {
+ case TS_SERVICE_REQ_START: {
+ if (size == 0 || data[size - 1] != 0)
+ return false;
+
+ for (size_t i = 0; i < size; i++) {
+ if (data[i] == '/')
+ data[i] = '\\';
+ }
+
+ char buf[MAX_PATH];
+ buf[0] = 0;
+
+ if (data[0]) {
+ if (!ExpandConfigPath((char*)data, buf, sizeof(buf)) || GetFileAttributesA(buf) == INVALID_FILE_ATTRIBUTES) {
+ char *s = str_cat_alloc("File '", (char*)data, "' not found");
+ connection_->WritePacket(TS_SERVICE_MSG_ERROR_REPLY, (uint8*)s, strlen(s));
+ free(s);
+ return false;
+ }
+ // Don't allow reading arbitrary files on disk
+ if (!EnsureValidConfigPath(buf)) {
+ GetConfigPath(buf, sizeof(buf));
+ char *s = str_cat_alloc("Permission Problem: The Config file is in an unsafe location.\r\n Must be in: ", buf, "");
+ connection_->WritePacket(TS_SERVICE_MSG_ERROR_REPLY, (uint8*)s, strlen(s));
+ free(s);
+ return false;
+ }
+ }
+ service_backend_->Start(buf);
+ service_backend_->RememberLastUsedConfigFile(buf);
+
+ // Ensure we reply with something
+ if (!want_state_updates_) {
+ uint32 v32 = (uint32)service_backend_->backend_->status();
+ connection_->WritePacket(TS_SERVICE_MSG_STATUS_CODE, (uint8*)&v32, 4);
+ }
+ break;
+ }
+
+ case TS_SERVICE_REQ_STOP:
+ service_backend_->Stop();
+ if (!want_state_updates_) {
+ uint32 v32 = (uint32)service_backend_->backend_->status();
+ connection_->WritePacket(TS_SERVICE_MSG_STATUS_CODE, (uint8*)&v32, 4);
+ }
+ break;
+
+ case TS_SERVICE_REQ_LOGIN: {
+ if (((ServiceLoginMessage*)data)->interfac[kTsMaxDevnameSize - 1])
+ return false; // sanity check
+
+ if (((ServiceLoginMessage*)data)->interfac[0] != 0) {
+ if (!service_backend_->manager_->SwitchInterface(this, ((ServiceLoginMessage*)data)->interfac, ((ServiceLoginMessage*)data)->want_create_interface)) {
+ const char *s = ((ServiceLoginMessage*)data)->want_create_interface ? "Unable to add the interface" : "Interface is not started";
+ connection_->WritePacket(TS_SERVICE_MSG_ERROR_REPLY, (uint8*)s, strlen(s));
+ return false;
+ }
+ }
+ want_state_updates_ = ((ServiceLoginMessage*)data)->want_state_updates;
+ if (want_state_updates_) {
+ SendQueuedLogLines();
+ service_backend_->SendStateUpdate(this);
+ uint32 v32 = (uint32)service_backend_->backend_->status();
+ connection_->WritePacket(TS_SERVICE_MSG_STATUS_CODE, (uint8*)&v32, 4);
+ }
+
+ break;
+ }
+
+ // return a list of all running interfaces
+ case TS_SERVICE_REQ_GETINTERFACES: {
+ char *s = TunsafeBackend::GetAllGuid();
+ connection_->WritePacket(TS_SERVICE_REQ_GETINTERFACES_REPLY, (uint8*)s, s ? strlen(s) : 0);
+ free(s);
+ break;
+ }
+
+ case TS_SERVICE_REQ_GETSTATS:
+ if (size < 1) return false;
+ want_stats_ = (data[0] != 0);
+ service_backend_->UpdateRequestStats();
+ break;
+
+ case TS_SERVICE_REQ_SET_INTERNET_BLOCKSTATE:
+ if (size < 1)
+ return false;
+ service_backend_->backend_->SetInternetBlockState((InternetBlockState)data[0]);
+ service_backend_->OnStateChanged();
+ break;
+
+ case TS_SERVICE_REQ_RESETSTATS:
+ service_backend_->backend_->ResetStats();
+ break;
+
+ case TS_SERVICE_REQ_GET_GRAPH:
+ if (size < 4) return false;
+ want_graph_type_ = *(int*)data;
+ TunsafeServiceServer::OnGraphAvailable();
+ break;
+
+ case TS_SERVICE_REQ_SET_STARTUP_FLAGS:
+ if (size < 4)
+ return false;
+ RegSetValueEx(service_backend_->manager_->hkey_, "ServiceStartupFlags", NULL, REG_DWORD, (BYTE*)data, 4);
+ break;
+
+ case TS_SERVICE_REQ_TEXT_PROTOCOL:
+ if (!service_backend_->backend_->is_started())
+ service_backend_->Start("");
+ service_backend_->backend_->SendConfigurationProtocolPacket(unique_id_, std::string((char*)data, size));
+ break;
+
+ default:
+ return false;
+ }
+ return true;
+}
+
+void TunsafeServiceServer::HandleDisconnect() {
+ service_backend_->RemovePipeServer(this);
+ delete this;
+}
+
+void TunsafeServiceServer::OnGraphAvailable() {
+ if (want_graph_type_ != 0xffffffff) {
+ LinearizedGraph *graph = service_backend_->backend_->GetGraph(want_graph_type_);
+ if (graph)
+ connection_->WritePacket(TS_SERVICE_MSG_GRAPH, (uint8*)graph, graph->total_size);
+ }
+}
+
+void TunsafeServiceServer::SendQueuedLogLines() {
+ if (!want_state_updates_)
+ return;
+ assert(connection_->VerifyThread());
+ uint32 maxi = std::min(service_backend_->historical_log_lines_count_, service_backend_->historical_log_lines_pos_ - last_line_sent_);
+ last_line_sent_ = service_backend_->historical_log_lines_pos_;
+ for (uint32 i = 0; i < maxi; i++) {
+ const char *s = service_backend_->historical_log_lines_[(service_backend_->historical_log_lines_pos_ - maxi + i) & (TunsafeServiceBackend::LOGLINE_COUNT - 1)];
+ if (s)
+ connection_->WritePacket(TS_SERVICE_MSG_LOGLINE, (uint8*)s, strlen(s));
+ }
+}
+
+bool TunsafeServiceServer::AuthenticateUser() {
+ if (!ImpersonateNamedPipeClient(connection_->pipe_handle()))
+ return false;
+ wchar_t *user = GetUsernameOfCurrentUser(true);
+ RevertToSelf();
+ if (!user)
+ return false;
+ wchar_t *valid_user = RegReadStrW(service_backend_->manager_->hkey_, L"AllowedUsername", L"");
+ bool rv = valid_user && wcscmp(user, valid_user) == 0;
+ did_authenticate_user_ = rv;
+ free(user);
+ free(valid_user);
+ return rv;
}
static void PushServiceLine(const char *s) {
@@ -937,13 +848,11 @@ static void PushServiceLine(const char *s) {
snprintf(buf, sizeof(buf), "[%.2d:%.2d:%.2d] ", t.wHour, t.wMinute, t.wSecond);
size_t tl = strlen(buf);
- char *x = (char*) malloc(tl + l + 3);
+ char *x = (char*) malloc(tl + l + 1);
memcpy(x, buf, tl);
memcpy(x + tl, s, l);
- x[l + tl] = '\r';
- x[l + tl + 1] = '\n';
- x[l + tl + 2] = '\0';
- g_service->delegate()->OnLogLine((const char**)&x);
+ x[l + tl] = '\0';
+ g_service->main_backend()->delegate()->OnLogLine((const char**)&x);
free(x);
} else {
size_t l = strlen(s);
@@ -965,14 +874,14 @@ static void PushServiceLine(const char *s) {
}
BOOL RunProcessAsTunsafeServiceProcess() {
- g_service = new TunsafeServiceImpl;
+ g_service = new TunsafeServiceManager;
g_logger = &PushServiceLine;
-
- //g_service->OnStart(NULL, 0);
- //MessageBoxA(0, "Service running", "Service running", 0);
- //return TRUE;
-// while (true)Sleep(1000);
+#if SERVICE_DEBUGGING
+ g_service->OnStart(NULL, 0);
+ while (true)
+ Sleep(1000);
+#endif
// Connects the main thread of a service process to the service control
// manager, which causes the thread to be the service control dispatcher
@@ -980,42 +889,47 @@ BOOL RunProcessAsTunsafeServiceProcess() {
// stopped. The process should simply terminate when the call returns.
return StartServiceCtrlDispatcherW(serviceTable);
}
-TunsafeServiceClient::TunsafeServiceClient(TunsafeBackend::Delegate *delegate)
- : message_handler_(PIPE_NAME, false, this) {
+
+
+///////////////////////////////////////////////////////////////////////////////////////
+// TunsafeServiceClient
+///////////////////////////////////////////////////////////////////////////////////////
+
+TunsafeServiceClient::TunsafeServiceClient(TunsafeBackend::Delegate *delegate)
+ : pipe_manager_(TUNSAFE_PIPE_NAME, false, this) {
is_remote_ = true;
got_state_from_control_ = false;
delegate_ = delegate;
cached_graph_ = 0;
last_graph_type_ = 0xffffffff;
memset(&service_state_, 0, sizeof(service_state_));
+ connection_ = pipe_manager_.GetClientConnection();
}
TunsafeServiceClient::~TunsafeServiceClient() {
- message_handler_.StopThread();
+ pipe_manager_.StopThread();
}
-bool TunsafeServiceClient::Initialize() {
- // Wait for the service to start
- last_graph_type_ = 0xffffffff;
- return message_handler_.StartThread();
+bool TunsafeServiceClient::Configure() {
+ return pipe_manager_.StartThread();
}
void TunsafeServiceClient::Start(const char *config_file) {
- message_handler_.WritePacket(SERVICE_REQ_START, (uint8*)config_file, strlen(config_file) + 1);
+ connection_->WritePacket(TS_SERVICE_REQ_START, (uint8*)config_file, strlen(config_file) + 1);
}
void TunsafeServiceClient::Stop() {
- message_handler_.WritePacket(SERVICE_REQ_STOP, NULL, 0);
+ connection_->WritePacket(TS_SERVICE_REQ_STOP, NULL, 0);
}
void TunsafeServiceClient::RequestStats(bool enable) {
want_stats_ = enable;
- if (message_handler_.is_connected())
- message_handler_.WritePacket(SERVICE_REQ_GETSTATS, &want_stats_, 1);
+ if (connection_->is_connected())
+ connection_->WritePacket(TS_SERVICE_REQ_GETSTATS, &want_stats_, 1);
}
void TunsafeServiceClient::ResetStats() {
- message_handler_.WritePacket(SERVICE_REQ_RESETSTATS, NULL, 0);
+ connection_->WritePacket(TS_SERVICE_REQ_RESETSTATS, NULL, 0);
}
InternetBlockState TunsafeServiceClient::GetInternetBlockState(bool *is_activated) {
@@ -1026,17 +940,17 @@ InternetBlockState TunsafeServiceClient::GetInternetBlockState(bool *is_activate
void TunsafeServiceClient::SetInternetBlockState(InternetBlockState s) {
uint8 v = (uint8)s;
- message_handler_.WritePacket(SERVICE_REQ_SET_INTERNET_BLOCKSTATE, &v, 1);
+ connection_->WritePacket(TS_SERVICE_REQ_SET_INTERNET_BLOCKSTATE, &v, 1);
}
void TunsafeServiceClient::SetServiceStartupFlags(uint32 flags) {
- message_handler_.WritePacket(SERVICE_REQ_SET_STARTUP_FLAGS, (uint8*)&flags, 4);
+ connection_->WritePacket(TS_SERVICE_REQ_SET_STARTUP_FLAGS, (uint8*)&flags, 4);
}
LinearizedGraph *TunsafeServiceClient::GetGraph(int type) {
if (type != last_graph_type_) {
last_graph_type_ = type;
- message_handler_.WritePacket(SERVICE_REQ_GET_GRAPH, (uint8*)&type, 4);
+ connection_->WritePacket(TS_SERVICE_REQ_GET_GRAPH, (uint8*)&type, 4);
}
mutex_.Acquire();
LinearizedGraph *graph = cached_graph_;
@@ -1045,6 +959,8 @@ LinearizedGraph *TunsafeServiceClient::GetGraph(int type) {
return new_graph;
}
+void TunsafeServiceClient::SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) {
+}
std::string TunsafeServiceClient::GetConfigFileName() {
mutex_.Acquire();
@@ -1054,8 +970,8 @@ std::string TunsafeServiceClient::GetConfigFileName() {
}
bool TunsafeServiceClient::HandleMessage(int type, uint8 *data, size_t data_size) {
- switch(type) {
- case SERVICE_MSG_STATE:
+ switch (type) {
+ case TS_SERVICE_MSG_STATE:
if (data_size <= sizeof(service_state_) || data[data_size - 1])
return false;
got_state_from_control_ = true;
@@ -1069,15 +985,18 @@ bool TunsafeServiceClient::HandleMessage(int type, uint8 *data, size_t data_size
mutex_.Release();
delegate_->OnStateChanged();
return true;
- case SERVICE_MSG_LOGLINE: {
+ case TS_SERVICE_MSG_LOGLINE:
+ case TS_SERVICE_MSG_ERROR_REPLY: {
if (data_size == 0)
return false;
char *s = my_strndup((char*)data, data_size);
- delegate_->OnLogLine((const char **)&s);
- free(s);
+ if (s) {
+ delegate_->OnLogLine((const char **)&s);
+ free(s);
+ }
return true;
}
- case SERVICE_MSG_STATS: {
+ case TS_SERVICE_MSG_STATS: {
WgProcessorStats stats;
if (data_size != sizeof(WgProcessorStats))
return false;
@@ -1085,25 +1004,24 @@ bool TunsafeServiceClient::HandleMessage(int type, uint8 *data, size_t data_size
delegate_->OnGetStats(stats);
return true;
}
- case SERVICE_MSG_CLEARLOG:
+ case TS_SERVICE_MSG_CLEARLOG:
delegate_->OnClearLog();
return true;
- case SERVICE_MSG_STATUS_CODE:
+ case TS_SERVICE_MSG_STATUS_CODE:
if (data_size < 4)
return false;
status_ = (StatusCode)*(uint32*)data;
delegate_->OnStatusCode(status_);
return true;
- case SERVICE_MSG_GRAPH:
- if (data_size < 4 || data_size != *(uint32*)data)
+ case TS_SERVICE_MSG_GRAPH:
+ if (data_size < sizeof(LinearizedGraph) || data_size != *(uint32*)data)
return false;
-
LinearizedGraph *graph = (LinearizedGraph*)memdup(data, data_size);
mutex_.Acquire();
std::swap(graph, cached_graph_);
- mutex_.Release();
+ mutex_.Release();
free(graph);
delegate_->OnGraphAvailable();
return true;
@@ -1112,15 +1030,18 @@ bool TunsafeServiceClient::HandleMessage(int type, uint8 *data, size_t data_size
return false;
}
-bool TunsafeServiceClient::HandleNotify() {
- return true;
+void TunsafeServiceClient::HandleNotify() {
}
-
-void TunsafeServiceClient::HandleNewConnection() {
- message_handler_.WritePacket(SERVICE_REQ_LOGIN, (uint8*)&kTunsafeServiceProtocolVersion, 8);
+PipeConnection::Delegate *TunsafeServiceClient::HandleNewConnection(PipeConnection *connection) {
+ assert(connection == connection_);
+ ServiceLoginMessage msg = {0};
+ msg.want_state_updates = true;
+ msg.version = TUNSAFE_SERVICE_PROTOCOL_VERSION;
+ connection_->WritePacket(TS_SERVICE_REQ_LOGIN, (uint8*)&msg, sizeof(msg));
if (want_stats_)
- message_handler_.WritePacket(SERVICE_REQ_GETSTATS, &want_stats_, 1);
+ connection_->WritePacket(TS_SERVICE_REQ_GETSTATS, &want_stats_, 1);
+ return this;
}
void TunsafeServiceClient::HandleDisconnect() {
@@ -1129,16 +1050,19 @@ void TunsafeServiceClient::HandleDisconnect() {
}
void TunsafeServiceClient::Teardown() {
- message_handler_.StopThread();
+ pipe_manager_.StopThread();
+}
+
+bool TunsafeServiceClient::SetTunAdapterName(const char *name) {
+ // override which tun adapter we want to start
+ return false;
}
TunsafeBackend *CreateTunsafeServiceClient(TunsafeBackend::Delegate *delegate) {
TunsafeServiceClient *client = new TunsafeServiceClient(delegate);
- if (client && !client->Initialize()) {
+ if (client && !client->Configure()) {
delete client;
client = NULL;
}
return client;
}
-
-
diff --git a/service_win32.h b/service_win32.h
index e19766e..6af6e2d 100644
--- a/service_win32.h
+++ b/service_win32.h
@@ -3,154 +3,183 @@
#pragma once
#include "service_win32_api.h"
-#include
-#include "util.h"
+#include "service_pipe_win32.h"
#include "network_win32_api.h"
#include "tunsafe_threading.h"
-#include
-#include
-#include
+
+// Takes care of multiple TunsafeServiceBackend
+class TunsafeServiceManager : public PipeManager::Delegate {
+ friend class TunsafeServiceBackend;
+ friend class TunsafeServiceServer;
+public:
+ TunsafeServiceManager();
+ virtual ~TunsafeServiceManager();
+
+ // -- from PipeManager::Delegate
+ virtual void HandleNotify() override;
+ virtual PipeConnection::Delegate *HandleNewConnection(PipeConnection *connection) override;
+
+ // Called by the service control code to bring the service up or down
+ unsigned OnStart(int argc, wchar_t **argv);
+ void OnStop();
+ void OnShutdown();
+
+ TunsafeServiceBackend *main_backend() { return main_backend_; }
+
+ TunsafeServiceBackend *CreateBackend(const char *guid);
+ void DestroyBackend(TunsafeServiceBackend *backend);
+
+ bool SwitchInterface(TunsafeServiceServer *server, const char *interfac, bool want_create);
+
+private:
+ // Points at the Tunsafe hklm reg key
+ HKEY hkey_;
+ uint32 server_unique_id_;
+
+ PipeManager pipe_manager_;
+
+ TunsafeServiceBackend *main_backend_;
+ std::vector backends_;
+};
+
+// One of these exist for each TunsafeBackend
+class TunsafeServiceBackend : public TunsafeBackend::Delegate {
+ friend class TunsafeServiceServer;
+public:
+ explicit TunsafeServiceBackend(TunsafeServiceManager *manager);
+ virtual ~TunsafeServiceBackend();
+
+ // -- from TunsafeBackend::Delegate
+ virtual void OnGetStats(const WgProcessorStats &stats) override;
+ virtual void OnClearLog() override;
+ virtual void OnLogLine(const char **s) override;
+ virtual void OnStateChanged() override;
+ virtual void OnStatusCode(TunsafeBackend::StatusCode status) override;
+ virtual void OnGraphAvailable() override;
+ virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) override;
+
+ TunsafeBackend *backend() { return backend_; }
+ TunsafeBackend::Delegate *delegate() { return thread_delegate_; }
+
+ void Start(const char *filename);
+ void RememberLastUsedConfigFile(const char *filename);
+
+ void Stop();
+
+ // Trigger backend stats updates whenever a connected pipe client needs it
+ void UpdateRequestStats();
+
+ // Called by TunsafeServiceManager::HandleNotify to process events
+ // on each backend.
+ void HandleNotify();
+
+ // Send a state update to all connected pipes unless filter is set, then it
+ // sends only to that.
+ void SendStateUpdate(TunsafeServiceServer *filter);
+
+ // Called whenever a pipe server disconnects
+ void RemovePipeServer(TunsafeServiceServer *pipe_server);
+
+ // Called to register a pipe server with this backend
+ void AddPipeServer(TunsafeServiceServer *pipe_server);
+private:
+ // Points at the service manager
+ TunsafeServiceManager *manager_;
+
+ // Points at the actual TunsafeBackend
+ TunsafeBackend *backend_;
+
+ // Points at all |TunsafeServiceServer| currently associated with this
+ // backend.
+ std::vector pipe_servers_;
+
+ // Points at the thing that transmits TunsafeBackend events to
+ // the main thread
+ TunsafeBackend::Delegate *thread_delegate_;
+
+ // The config filename that is loaded
+ std::string current_filename_;
+
+ // Positions into |historical_log_lines_|
+ uint32 historical_log_lines_pos_;
+ uint32 historical_log_lines_count_;
+
+ enum { LOGLINE_COUNT = 256 };
+ char *historical_log_lines_[LOGLINE_COUNT];
+};
+
+// The server side of the client<->server pipe connection
+class TunsafeServiceServer : public PipeConnection::Delegate {
+
+public:
+ TunsafeServiceServer(PipeConnection *pipe, TunsafeServiceBackend *backend, uint32 unique_id);
+ virtual ~TunsafeServiceServer();
+
+ void WritePacket(int type, const uint8 *data, size_t data_size);
+
+ // -- from PipeConnection::Delegate
+ virtual bool HandleMessage(int type, uint8 *data, size_t size) override;
+ virtual void HandleDisconnect() override;
+
+ // Called by TunsafeServiceBackend to push a graph to the client
+ void OnGraphAvailable();
+
+ // Called by TunsafeServiceBackend to push more log lines to the client
+ void SendQueuedLogLines();
+
+ bool want_stats() const { return want_stats_; }
+ bool want_state_updates() const { return want_state_updates_; }
+ uint32 unique_id() const { return unique_id_; }
+ TunsafeServiceBackend *service_backend() { return service_backend_; }
+ void set_service_backend(TunsafeServiceBackend *sb) { service_backend_ = sb; }
+private:
+ bool AuthenticateUser();
+
+ // Whether the client wants state updates
+ bool want_state_updates_;
+
+ // Whether the client has authenticated
+ bool did_authenticate_user_;
+
+ // Whether we want stats
+ bool want_stats_;
+
+ // Whether the currently connected user wants a graph
+ uint32 want_graph_type_;
+
+ // The last log line sent to the currently connected user
+ uint32 last_line_sent_;
+
+ uint32 unique_id_;
+
+ // The pipe used to communicate
+ PipeConnection *connection_;
+
+ // The backend we're currently associated with
+ TunsafeServiceBackend *service_backend_;
+};
+
struct ServiceState {
uint8 is_started : 1;
uint8 internet_block_state_active : 1;
uint8 internet_block_state;
- uint8 reserved[26+64];
+ uint8 reserved[26 + 64];
uint32 ipv4_ip;
uint8 public_key[32];
};
STATIC_ASSERT(sizeof(ServiceState) == 128, ServiceState_wrong_size);
-class PipeMessageHandler {
-public:
- class Delegate {
- public:
- virtual bool HandleMessage(int type, uint8 *data, size_t size) = 0;
- virtual bool HandleNotify() = 0;
- virtual void HandleNewConnection() = 0;
- virtual void HandleDisconnect() = 0;
- };
-
- PipeMessageHandler(const char *pipe_name, bool is_server_pipe, Delegate *delegate);
- ~PipeMessageHandler();
-
- bool StartThread();
- void StopThread();
-
- bool WritePacket(int type, const uint8 *data, size_t data_size);
-
- HANDLE notify_handle() { return wait_handles_[1]; }
- HANDLE pipe_handle() { return pipe_; }
-
- bool VerifyThread();
-
- bool is_connected() { return connection_established_; }
-private:
- bool InitializeServerPipeAndWait();
- bool InitializeClientPipe();
- void AdvanceStateMachine();
- void ClosePipe();
- DWORD ThreadMain();
- void SendNextQueuedWrite();
- static DWORD WINAPI StaticThreadMain(void *x);
-
- Delegate *delegate_;
-
- HANDLE pipe_;
- HANDLE thread_;
- HANDLE wait_handles_[3];
- bool write_overlapped_active_;
- bool exit_thread_;
- bool is_server_pipe_;
- bool connection_established_;
- char *pipe_name_;
-
- enum State {
- kStateNone,
- kStateWaitConnect,
- kStateWaitReadLength,
- kStateWaitReadPayload,
- kStateWaitTimeout,
- };
-
- int state_;
-
- struct OutgoingPacket {
- OutgoingPacket *next;
- uint32 size;
- uint8 data[0];
- };
- OutgoingPacket *packets_, **packets_end_;
- uint8 *tmp_packet_buf_;
- DWORD tmp_packet_size_;
-
- OVERLAPPED write_overlapped_, read_overlapped_;
-
- Mutex packets_mutex_;
-
- DWORD thread_id_;
-};
-
-class TunsafeServiceImpl : public TunsafeBackend::Delegate, public PipeMessageHandler::Delegate {
-public:
- TunsafeServiceImpl();
- virtual ~TunsafeServiceImpl();
-
- // -- from TunsafeBackend::Delegate
- virtual void OnGetStats(const WgProcessorStats &stats);
- virtual void OnClearLog();
- virtual void OnLogLine(const char **s);
- virtual void OnStateChanged();
- virtual void OnStatusCode(TunsafeBackend::StatusCode status);
- virtual void OnGraphAvailable();
-
- // -- from PipeMessageHandler::Delegate
- virtual bool HandleMessage(int type, uint8 *data, size_t size);
- virtual bool HandleNotify();
- virtual void HandleNewConnection();
- virtual void HandleDisconnect();
-
- // virtual methods
- virtual unsigned OnStart(int argc, wchar_t **argv);
- virtual void OnStop();
- virtual void OnShutdown();
-
- TunsafeBackend::Delegate *delegate() { return thread_delegate_; }
-
-private:
- void SendQueuedLogLines();
- bool AuthenticateUser();
-
- bool did_send_getstate_;
-
- bool did_authenticate_user_;
- uint32 want_graph_type_;
-
- HKEY hkey_;
-
- TunsafeBackend *backend_;
- TunsafeBackend::Delegate *thread_delegate_;
-
- PipeMessageHandler message_handler_;
-
- uint32 historical_log_lines_pos_;
- uint32 historical_log_lines_count_;
- uint32 last_line_sent_;
- std::string current_filename_;
-
- enum {
- LOGLINE_COUNT = 256
- };
- char *historical_log_lines_[LOGLINE_COUNT];
-};
-
-class TunsafeServiceClient : public TunsafeBackend, public PipeMessageHandler::Delegate {
+class TunsafeServiceClient : public TunsafeBackend, public PipeConnection::Delegate, public PipeManager::Delegate {
public:
TunsafeServiceClient(TunsafeBackend::Delegate *delegate);
virtual ~TunsafeServiceClient();
- virtual bool Initialize();
+
+ // -- from TunsafeBackend
+ virtual bool Configure();
virtual void Teardown();
+ virtual bool SetTunAdapterName(const char *name);
virtual void Start(const char *config_file);
virtual void Stop();
virtual void RequestStats(bool enable);
@@ -160,12 +189,16 @@ public:
virtual std::string GetConfigFileName();
virtual void SetServiceStartupFlags(uint32 flags);
virtual LinearizedGraph *GetGraph(int type);
+ virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) override;
+
+ // -- from PipeConnection::Delegate
+ virtual bool HandleMessage(int type, uint8 *data, size_t size) override;
+ virtual void HandleDisconnect() override;
+
+ // -- from PipeManager::Delegate
+ virtual void HandleNotify() override;
+ virtual PipeConnection::Delegate *HandleNewConnection(PipeConnection *connection) override;
- // -- from PipeMessageHandler::Delegate
- virtual bool HandleMessage(int type, uint8 *data, size_t size);
- virtual bool HandleNotify();
- virtual void HandleNewConnection();
- virtual void HandleDisconnect();
protected:
TunsafeBackend::Delegate *delegate_;
@@ -173,8 +206,10 @@ protected:
bool got_state_from_control_;
ServiceState service_state_;
std::string config_file_;
- PipeMessageHandler message_handler_;
+ PipeManager pipe_manager_;
+ PipeConnection *connection_;
LinearizedGraph *cached_graph_;
uint32 last_graph_type_;
Mutex mutex_;
};
+
diff --git a/service_win32_constants.h b/service_win32_constants.h
new file mode 100644
index 0000000..f1c5bd1
--- /dev/null
+++ b/service_win32_constants.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#define TUNSAFE_PIPE_NAME "\\\\.\\pipe\\TunSafe\\ServiceControl"
+#define TUNSAFE_SERVICE_PROTOCOL_VERSION 20180916001
+
+enum {
+ TS_SERVICE_REQ_LOGIN = 0,
+ TS_SERVICE_REQ_START = 1,
+ TS_SERVICE_REQ_STOP = 2,
+
+ TS_SERVICE_REQ_GETSTATS = 4,
+ TS_SERVICE_REQ_SET_INTERNET_BLOCKSTATE = 5,
+ TS_SERVICE_REQ_RESETSTATS = 6,
+ TS_SERVICE_REQ_SET_STARTUP_FLAGS = 7,
+
+ TS_SERVICE_MSG_STATE = 8,
+ TS_SERVICE_MSG_LOGLINE = 9,
+ TS_SERVICE_MSG_ERROR_REPLY = 10,
+ TS_SERVICE_MSG_STATS = 11,
+ TS_SERVICE_MSG_CLEARLOG = 12,
+ TS_SERVICE_MSG_STATUS_CODE = 14,
+
+ TS_SERVICE_REQ_GET_GRAPH = 15,
+ TS_SERVICE_MSG_GRAPH = 16,
+
+
+ TS_SERVICE_REQ_TEXT_PROTOCOL = 17,
+ TS_SERVICE_REQ_TEXT_PROTOCOL_REPLY = 18,
+
+ TS_SERVICE_REQ_GETINTERFACES = 19,
+ TS_SERVICE_REQ_GETINTERFACES_REPLY = 20,
+};
+
+enum {
+ kTsMaxDevnameSize = 40
+};
\ No newline at end of file
diff --git a/stdafx.h b/stdafx.h
index 61625f4..a729e22 100644
--- a/stdafx.h
+++ b/stdafx.h
@@ -13,10 +13,14 @@
#if defined(OS_WIN)
#define _WINSOCK_DEPRECATED_NO_WARNINGS 1
+#define _HAS_EXCEPTIONS 0
+#define _CRT_SECURE_NO_WARNINGS 1
+
//#include
#include
#include
+#undef max
//#include
#include
#include
diff --git a/ts.cpp b/ts.cpp
new file mode 100644
index 0000000..2cd5d8c
--- /dev/null
+++ b/ts.cpp
@@ -0,0 +1,883 @@
+#include "stdafx.h"
+#include "tunsafe_types.h"
+#include "netapi.h"
+#include "crypto/curve25519-donna.h"
+#include "util.h"
+#include "wireguard_proto.h"
+#include
+#include
+
+#if defined(OS_WIN)
+#include "util_win32.h"
+#include "service_pipe_win32.h"
+#include "service_win32_constants.h"
+#endif // defined(OS_WIN)
+
+#if defined(OS_POSIX)
+#include
+#include
+#include
+#include
+#include
+#endif // defined(OS_WIN)
+
+
+#pragma comment(lib, "ws2_32.lib")
+
+#define ANSI_RESET "\x1b[0m"
+#define ANSI_BOLD "\x1b[1m"
+#define ANSI_FG_BLACK "\x1b[30m"
+#define ANSI_FG_RED "\x1b[31m"
+#define ANSI_FG_GREEN "\x1b[32m"
+#define ANSI_FG_YELLOW "\x1b[33m"
+#define ANSI_FG_BLUE "\x1b[34m"
+#define ANSI_FG_MAGENTA "\x1b[35m"
+#define ANSI_FG_CYAN "\x1b[36m"
+#define ANSI_FG_WHITE "\x1b[37m"
+
+static const uint8 kCurve25519Basepoint[32] = {9};
+
+#if defined(OS_WIN)
+#define EXENAME "ts"
+
+static bool SendMessageToService(HANDLE pipe, int message, const void *data, size_t data_size) {
+ uint8 *temp = new uint8[data_size + 5];
+ *(uint32*)temp = (uint32)(data_size + 1);
+ temp[4] = (uint8)message;
+ memcpy(temp + 5, data, data_size);
+ // Write the whole thing
+ DWORD pos = 0, bytes_to_write = (DWORD)(data_size + 5), bytes_written;
+ do {
+ if (!WriteFile(pipe, temp + pos, bytes_to_write, &bytes_written, NULL)) {
+ fprintf(stderr, "Error writing to service pipe, error = %d\n", GetLastError());
+ break;
+ }
+ pos += bytes_written;
+ bytes_to_write -= bytes_written;
+ } while (bytes_to_write != 0);
+ delete[] temp;
+ return (bytes_to_write == 0);
+}
+
+static bool ReadExactBytesFromPipe(HANDLE pipe, const void *data, DWORD bytes_to_read) {
+ DWORD pos = 0, n;
+ do {
+ if (!ReadFile(pipe, (uint8*)data + pos, bytes_to_read, &n, NULL))
+ return false;
+ if (n == 0)
+ return false; // premature eof..
+ pos += n;
+ bytes_to_read -= n;
+ } while (bytes_to_read != 0);
+ return true;
+}
+
+static bool ReadMessageFromService(HANDLE pipe, int *message, std::string *data) {
+ uint8 header[5];
+ uint32 message_size;
+
+ if (!ReadExactBytesFromPipe(pipe, header, 5) || (message_size = *(uint32*)header) == 0) {
+ fprintf(stderr, "Error reading from service pipe, error = %d\n", GetLastError());
+ return false;
+ }
+ *message = header[4];
+ data->resize(message_size - 1);
+ if (message_size - 1 != 0 && !ReadExactBytesFromPipe(pipe, data->data(), message_size - 1)) {
+ fprintf(stderr, "Error reading from service pipe, error = %d\n", GetLastError());
+ return false;
+ }
+ return true;
+}
+
+struct ServiceLoginMessage {
+ uint64 version;
+ char interfac[kTsMaxDevnameSize];
+ bool want_state_updates;
+ bool want_create_interface;
+};
+
+static std::vector g_tap_adapters;
+static bool g_did_get_adapters;
+static const std::vector &GetTapAdapterInfo() {
+ if (!g_did_get_adapters) {
+ g_did_get_adapters = true;
+ GetTapAdapterInfo(&g_tap_adapters);
+ }
+ return g_tap_adapters;
+}
+
+static const char *GetGuidFromInterfaceName(const char *name) {
+ for (const GuidAndDevName &e : GetTapAdapterInfo())
+ if (strcmp(e.name, name) == 0)
+ return e.guid;
+ return NULL;
+}
+
+static const char *GetInterfaceNameFromGuid(const char *guid) {
+ for (const GuidAndDevName &e : GetTapAdapterInfo())
+ if (strcmp(e.guid, guid) == 0)
+ return e.name;
+ return NULL;
+}
+
+
+static HANDLE ConnectToService(const char *devname, bool want_updates, bool want_create = false) {
+ ServiceLoginMessage msg = {0};
+ msg.version = TUNSAFE_SERVICE_PROTOCOL_VERSION;
+ msg.want_state_updates = want_updates;
+ msg.want_create_interface = want_create;
+
+ // Rename devname to a guid
+ if (devname) {
+ const char *guid = (devname[0] == '{' || devname[0] == 0) ? devname : GetGuidFromInterfaceName(devname);
+ if (!guid) {
+ fprintf(stderr, "Interface '%s' not found\n", devname);
+ return NULL;
+ }
+ my_strlcpy(msg.interfac, sizeof(msg.interfac), guid);
+ }
+
+ for (;;) {
+ HANDLE pipe = CreateFile(TUNSAFE_PIPE_NAME, GENERIC_READ | GENERIC_WRITE, 0, NULL,
+ OPEN_EXISTING, 0, NULL);
+ if (pipe != INVALID_HANDLE_VALUE) {
+ if (!SendMessageToService(pipe, TS_SERVICE_REQ_LOGIN, &msg, sizeof(msg))) {
+ CloseHandle(pipe);
+ pipe = NULL;
+ }
+ return pipe;
+ }
+ DWORD error = GetLastError();
+ if (error != ERROR_PIPE_BUSY) {
+ fprintf(stderr, "Error connecting to TunSafe service, error = %d\n", error);
+ if (error == ERROR_FILE_NOT_FOUND)
+ fprintf(stderr, "Please check that the TunSafe service is started\n");
+ return NULL;
+ }
+ if (!WaitNamedPipe(TUNSAFE_PIPE_NAME, 10000)) {
+ fprintf(stderr, "Error connecting to TunSafe service, timed out.\n");
+ return NULL;
+ }
+ }
+}
+
+static bool CommunicateWithService(const char *devname, const std::string &query, std::string *reply) {
+ HANDLE pipe = ConnectToService(devname, false);
+ int message_code;
+ bool rv = false;
+
+ if (pipe != NULL &&
+ SendMessageToService(pipe, TS_SERVICE_REQ_TEXT_PROTOCOL, query.data(), query.size()) &&
+ ReadMessageFromService(pipe, &message_code, reply)) {
+ if (message_code == TS_SERVICE_REQ_TEXT_PROTOCOL_REPLY) {
+ rv = true;
+ } else {
+ if (message_code == TS_SERVICE_MSG_ERROR_REPLY) {
+ fprintf(stderr, "Error: %s\n", reply->c_str());
+ } else {
+ fprintf(stderr, "Unknown reply (%d) from TunSafe service.\n", message_code);
+ }
+ }
+ }
+ CloseHandle(pipe);
+ return rv;
+}
+
+static bool GetInterfaceList(std::string *result) {
+ HANDLE pipe = ConnectToService(NULL, false);
+ int message_code;
+ bool rv = false;
+
+ if (pipe != NULL &&
+ SendMessageToService(pipe, TS_SERVICE_REQ_GETINTERFACES, NULL, 0) &&
+ ReadMessageFromService(pipe, &message_code, result)) {
+ if (message_code == TS_SERVICE_REQ_GETINTERFACES_REPLY) {
+ rv = true;
+ } else {
+ fprintf(stderr, "GetInterfaceList: bad reply\n");
+ }
+ }
+ CloseHandle(pipe);
+ return rv;
+}
+#endif // defined(OS_WIN)
+
+#if defined(OS_POSIX)
+#define EXENAME "tunsafe"
+
+static const char *GetGuidFromInterfaceName(const char *name) {
+ return name;
+}
+
+static const char *GetInterfaceNameFromGuid(const char *guid) {
+ return guid;
+}
+
+
+static int OpenUserspaceInterface(const char *iface) {
+ struct stat st;
+ struct sockaddr_un un = { 0 };
+ int fd = -1, rv;
+
+ if (strchr(iface, '/') != NULL) {
+ fprintf(stderr, "Unable to open usermode socket: No such device\n");
+ goto getout;
+ }
+
+ snprintf(un.sun_path, sizeof(un.sun_path), "/var/run/wireguard/%s.sock", iface);
+ if (stat(un.sun_path, &st) < 0) {
+ perror("Unable to open usermode socket");
+ goto getout;
+ }
+
+ if (!S_ISSOCK(st.st_mode)) {
+ fprintf(stderr, "Unable to open usermode socket: No such device\n");
+ goto getout;
+ }
+
+ fd = socket(AF_UNIX, SOCK_STREAM, 0);
+ if (fd < 0)
+ goto getout;
+
+ un.sun_family = AF_UNIX;
+ if (connect(fd, (struct sockaddr *)&un, sizeof(un)) < 0) {
+ if (errno == ECONNREFUSED)
+ unlink(un.sun_path);
+ else
+ perror("Error opening wireguard usermode interface socket");
+ goto getout;
+ }
+ return fd;
+
+getout:
+ if (fd >= 0)
+ close(fd);
+ return -1;
+}
+
+
+static bool GetInterfaceList(std::string *result) {
+ struct dirent *dent;
+
+ DIR *dir = opendir("/var/run/wireguard/");
+ if (!dir)
+ return errno == ENOENT;
+
+ while ((dent = readdir(dir)) != NULL) {
+ size_t len = strlen(dent->d_name);
+ static const char kSuffix[6] = ".sock";
+ if (len >= sizeof(kSuffix) - 1 &&
+ memcmp(&dent->d_name[len - (sizeof(kSuffix) - 1)], kSuffix, sizeof(kSuffix) - 1) == 0) {
+ dent->d_name[len - (sizeof(kSuffix) - 1)] = '\n';
+ result->append(dent->d_name, len - (sizeof(kSuffix) - 1) + 1);
+ }
+ }
+ closedir(dir);
+ return true;
+}
+
+static bool CommunicateWithService(const char *devname, const std::string &query, std::string *reply) {
+ ssize_t n;
+ char buf[4096];
+ bool rv = false;
+
+ reply->clear();
+
+ int fd = OpenUserspaceInterface(devname);
+ if (fd == -1)
+ return false;
+
+ for(size_t pos = 0; query.size() - pos; pos += n) {
+ n = write(fd, query.data() + pos, query.size() - pos);
+ if (n <= 0) {
+ perror("Error writing to service pipe");
+ goto getout;
+ }
+ }
+
+ for(;;) {
+ n = read(fd, buf, sizeof(buf));
+ if (n <= 0) {
+ if (n == 0) {
+ // ensure that it ends with \n\n
+ if (reply->size() >= 2 && (*reply)[reply->size() - 1] == '\n' && (*reply)[reply->size() - 2] == '\n') {
+ rv = true;
+ } else {
+ fprintf(stderr, "Bad reply from service pipe\n");
+ }
+ } else {
+ perror("Error reading from service pipe");
+ }
+ break;
+ }
+ reply->append(buf, n);
+ }
+
+getout:
+ close(fd);
+ return rv;
+}
+
+static int HandleStopCommand(int argc, char **argv) {
+ if (argc != 1) {
+ fprintf(stderr, "Usage: " EXENAME " stop \n");
+ return 1;
+ }
+ struct sockaddr_un un;
+ struct stat st;
+ const char *iface = argv[0];
+ if (strchr(iface, '/')) {
+ fprintf(stderr, "No such interface\n");
+ return 1;
+ }
+ snprintf(un.sun_path, sizeof(un.sun_path), "/var/run/wireguard/%s.sock", iface);
+ if (unlink(un.sun_path) == -1) {
+ perror("unlink");
+ return 1;
+ }
+ return 0;
+}
+
+#endif // defined(OS_POSIX)
+
+void ShowHelp() {
+ fprintf(stderr,
+ "Usage: " EXENAME " []\n\n"
+#if defined(OS_POSIX)
+ " " EXENAME " filename.conf\n\n"
+#endif // defined(OS_POSIX)
+ "Available subcommands:\n"
+ " show: Shows the configuration and status of the interfaces\n"
+ " set: Change the configuration or the peer list\n"
+ " start: Start TunSafe on an interface\n"
+ " stop: Stop TunSafe on an interface\n"
+#if defined(OS_WIN)
+ " log: Display recent log entries\n"
+#endif // defined(OS_WIN)
+ " genkey: Writes a new private key to stdout\n"
+ " genpsk: Writes a new preshared key to stdout\n"
+ " pubkey: Reads a private key from stdin and writes its public key to stdout\n"
+ "To see more help about a subcommand, pass --help to it\n");
+}
+
+
+
+static bool ParseHexKeyToBase64(const char *key, char base64key[WG_PUBLIC_KEY_LEN_BASE64 + 1]) {
+ uint8 keybuf[32];
+ if (!ParseHexString(key, keybuf, 32))
+ return false;
+ return base64_encode(keybuf, 32, base64key, WG_PUBLIC_KEY_LEN_BASE64 + 1, NULL) != NULL;
+}
+
+static char *FormatTransferPart(char *buf, size_t bufsize, uint64 n) {
+ if (n < 1024)
+ snprintf(buf, bufsize, "%u " ANSI_FG_CYAN "B" ANSI_RESET, (unsigned)n);
+ else if (n < 1024 * 1024)
+ snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "KiB" ANSI_RESET, (double)n * (1.0 / 1024));
+ else if (n < 1024 * 1024 * 1024)
+ snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "MiB" ANSI_RESET, (double)n * (1.0 / 1024 / 1024));
+ else if (n < 1024ull * 1024 * 1024 * 1024)
+ snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "GiB" ANSI_RESET, (double)n * (1.0 / 1024 / 1024 / 1024));
+ else
+ snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "TiB" ANSI_RESET, (double)n * (1.0 / 1024 / 1024 / 1024 / 1024));
+ return buf;
+}
+
+static size_t PrintTime(char *buf, size_t bufsize, uint64 n) {
+ size_t pos = 0;
+ uint64 years = n / (365 * 24 * 60 * 60);
+ uint32 n32 = n % (365 * 24 * 60 * 60);
+ if (years)
+ pos += snprintf(buf + pos, bufsize - pos, "%llu " ANSI_FG_CYAN "year%s" ANSI_RESET ", ", (unsigned long long)years, (years == 1) ? "" : "s");
+ uint32 days = n32 / (24 * 60 * 60);
+ n32 %= (24 * 60 * 60);
+ if (days)
+ pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "day%s" ANSI_RESET ", ", days, (days == 1) ? "" : "s");
+ uint32 hours = n32 / (60 * 60);
+ n32 %= (60 * 60);
+ if (hours)
+ pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "hour%s" ANSI_RESET ", ", hours, (hours == 1) ? "" : "s");
+ uint32 minutes = n32 / 60;
+ if (minutes)
+ pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "minute%s" ANSI_RESET ", ", minutes, (minutes == 1) ? "" : "s");
+ uint32 seconds = n32 % 60;
+ if (seconds)
+ pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "second%s" ANSI_RESET ", ", seconds, (seconds == 1) ? "" : "s");
+ if (pos)
+ buf[pos -= 2] = '\0';
+ return pos;
+}
+
+static char *PrintHandshake(char *buf, size_t bufsize, uint64 secs) {
+ time_t now = time(NULL);
+ if (now == secs) {
+ snprintf(buf, bufsize, "Now");
+ } else if (now < (int64)secs) {
+ snprintf(buf, bufsize, ANSI_FG_RED "System clock going backwards" ANSI_RESET);
+ } else {
+ size_t pos = PrintTime(buf, bufsize - 4, now - secs);
+ memcpy(buf + pos, " ago", 5);
+ }
+ return buf;
+}
+
+static void AppendIpToString(const char *value, std::string *result) {
+ if (!result->empty())
+ (*result) += ", ";
+ const char *slash = strchr(value, '/');
+ if (slash) {
+ result->append(value, slash - value);
+ result->append(ANSI_FG_CYAN "/" ANSI_RESET);
+ result->append(slash + 1);
+ } else {
+ result->append(value);
+ }
+}
+
+static int ShowUserFriendlyForDevice(char *devname) {
+ std::string reply;
+ std::vector> kv;
+ std::string ips;
+
+ if (!CommunicateWithService(devname, "get=1\n\n", &reply))
+ return 1;
+
+ if (!ParseConfigKeyValue(&reply[0], &kv)) {
+getout_fail:
+ fprintf(stderr, "Unable to parse response");
+ return 1;
+ }
+
+ size_t i = 0;
+ char base64key[WG_PUBLIC_KEY_LEN_BASE64 + 1];
+ char base64psk[WG_PUBLIC_KEY_LEN_BASE64 + 1];
+ int listen_port = 0;
+ base64key[0] = 0;
+
+ // Parse all interface level keys
+ for (; i < kv.size(); i++) {
+ char *key = kv[i].first, *value = kv[i].second;
+ if (strcmp(key, "private_key") == 0) {
+ uint8 binkey[32];
+ if (!ParseHexString(value, binkey, sizeof(binkey)))
+ goto getout_fail;
+ if (!IsOnlyZeros(binkey, 32)) {
+ curve25519_donna(binkey, binkey, kCurve25519Basepoint);
+ base64_encode(binkey, sizeof(binkey), base64key, sizeof(base64key), NULL);
+ }
+ } else if (strcmp(key, "address") == 0) {
+ AppendIpToString(value, &ips);
+ } else if (strcmp(key, "listen_port") == 0) {
+ listen_port = atoi(value);
+ } else if (strcmp(key, "public_key") == 0) {
+ break;
+ }
+ }
+
+ const char *interfacename = (devname[0] == '{') ? GetInterfaceNameFromGuid(devname) : devname;
+
+ printf(ANSI_RESET ANSI_FG_GREEN ANSI_BOLD "interface" ANSI_RESET ": " ANSI_FG_GREEN "%s" ANSI_RESET "\n",
+ interfacename);
+ if (base64key[0]) {
+ printf(" " ANSI_BOLD "public key" ANSI_RESET ": %s\n"
+ " " ANSI_BOLD "private key" ANSI_RESET ": (hidden)\n", base64key);
+ }
+ if (listen_port)
+ printf(" " ANSI_BOLD "listening port" ANSI_RESET ": %d\n", listen_port);
+ if (ips.size())
+ printf(" " ANSI_BOLD "address" ANSI_RESET ": %s\n", ips.c_str());
+
+ const char *endpoint = NULL;
+ uint64 rx_bytes, tx_bytes, last_handshake_time_sec;
+ int persistent_keepalive;
+ char text[256];
+ bool clear_state = true;
+
+ // Parse peer level keys
+ for (; i < kv.size(); i++) {
+ char *key = kv[i].first, *value = kv[i].second;
+
+ if (clear_state) {
+ base64key[0] = base64psk[0] = 0;
+ endpoint = NULL;
+ ips.clear();
+ persistent_keepalive = 0;
+ last_handshake_time_sec = tx_bytes = rx_bytes = 0;
+ clear_state = false;
+ }
+ if (strcmp(key, "public_key") == 0) {
+ if (!ParseHexKeyToBase64(value, base64key))
+ goto getout_fail;
+ } else if (strcmp(key, "preshared_key") == 0) {
+ if (!ParseHexKeyToBase64(value, base64psk))
+ goto getout_fail;
+ } else if (strcmp(key, "tx_bytes") == 0) {
+ tx_bytes = strtoull(value, NULL, 0);
+ } else if (strcmp(key, "rx_bytes") == 0) {
+ rx_bytes = strtoull(value, NULL, 0);
+ } else if (strcmp(key, "allowed_ip") == 0) {
+ AppendIpToString(value, &ips);
+ } else if (strcmp(key, "persistent_keepalive_interval") == 0) {
+ persistent_keepalive = atoi(value);
+ } else if (strcmp(key, "endpoint") == 0) {
+ endpoint = value;
+ } else if (strcmp(key, "last_handshake_time_sec") == 0) {
+ last_handshake_time_sec = strtoull(value, NULL, 0);
+ }
+ if (i == kv.size() - 1 || strcmp(kv[i + 1].first, "public_key") == 0) {
+ if (!base64key[0])
+ goto getout_fail;
+ printf("\n" ANSI_FG_YELLOW ANSI_BOLD "peer" ANSI_RESET ": " ANSI_FG_YELLOW "%s" ANSI_RESET "\n", base64key);
+ if (base64psk[0])
+ printf(" " ANSI_BOLD "preshared key" ANSI_RESET ": (hidden)\n");
+ if (endpoint)
+ printf(" " ANSI_BOLD "endpoint" ANSI_RESET ": %s\n", endpoint);
+ printf(" " ANSI_BOLD "allowed ips" ANSI_RESET ": %s\n", ips.size() ? ips.c_str() : "(none)");
+ if (last_handshake_time_sec)
+ printf(" " ANSI_BOLD "latest handshake" ANSI_RESET ": %s\n", PrintHandshake(text, sizeof(text), last_handshake_time_sec));
+ if (tx_bytes | rx_bytes) {
+ printf(" " ANSI_BOLD "transfer" ANSI_RESET ": %s received, ", FormatTransferPart(text, sizeof(text), rx_bytes));
+ printf("%s sent\n", FormatTransferPart(text, sizeof(text), tx_bytes));
+ }
+ if (persistent_keepalive) {
+ PrintTime(text, sizeof(text), persistent_keepalive);
+ printf(" " ANSI_BOLD "persistent keepalive" ANSI_RESET ": every %s\n", text);
+ }
+ clear_state = true;
+ }
+ }
+ return 0;
+}
+
+static int HandleShowCommand(int argc, char **argv) {
+ if (argc != 0 && strcmp(argv[0], "--help") == 0) {
+ fprintf(stderr, "Usage: ts show { | all | interfaces }\n");
+ return 0;
+ }
+
+ std::vector interfaces;
+ std::string interfaces_str;
+
+ if (argc == 0 || strcmp(argv[0], "all") == 0) {
+ if (!GetInterfaceList(&interfaces_str))
+ return 1;
+ SplitString(&interfaces_str[0], '\n', &interfaces);
+
+ bool want_newline = false;
+ for (char *interfac : interfaces) {
+ if (want_newline)
+ printf("\n");
+ want_newline = true;
+ if (ShowUserFriendlyForDevice(interfac))
+ return 1;
+ }
+ } else if (strcmp(argv[0], "interfaces") == 0) {
+ if (!GetInterfaceList(&interfaces_str))
+ return 1;
+ SplitString(&interfaces_str[0], '\n', &interfaces);
+
+ for (char *interfac : interfaces) {
+ const char *name = GetInterfaceNameFromGuid(interfac);
+ if (name)
+ printf("%s\n", name);
+ }
+ } else {
+ return ShowUserFriendlyForDevice(argv[0]);
+ }
+ return 0;
+}
+
+static void AppendCommand(std::string *result, const char *tag, const char *value) {
+ result->append(tag);
+ result->append("=");
+ result->append(value);
+ result->append("\n");
+}
+
+static bool ConvertBase64KeyToHex(const char *s, char key[65]) {
+ uint8 tmp[32];
+ size_t size = 32;
+ if (!base64_decode((uint8*)s, strlen(s), tmp, &size) || size != 32)
+ return false;
+ PrintHexString(tmp, 32, key);
+ return true;
+}
+
+static int HandleSetCommand(int argc, char **argv) {
+ std::string command, reply;
+ std::vector ss;
+ char hexkey[65];
+
+ if (argc == 0) {
+ fprintf(stderr, "Usage: ts set [address ] [listen-port ] [private-key ] "
+ "[peer [remove] [preshared-key ] [endpoint :] "
+ "[persistent-keepalive ] [allowed-ips /[,/]] ]");
+ return 1;
+ }
+ char **argv_end = argv + argc;
+ const char *interfc = *argv++;
+
+ command = "set=1\n";
+
+ bool in_interface_section = true;
+ bool in_peer_section = false;
+ bool did_clear_allowed_ips = false;
+
+ while (argv != argv_end) {
+ const char *key = *argv++;
+
+ if (argv != argv_end) {
+ if (in_interface_section) {
+ if (strcmp(key, "listen-port") == 0) {
+ AppendCommand(&command, "listen_port", *argv++);
+ continue;
+ } else if (strcmp(key, "address") == 0) {
+ AppendCommand(&command, "address", *argv++);
+ continue;
+ } else if (strcmp(key, "private-key") == 0) {
+ if (!ConvertBase64KeyToHex(*argv++, hexkey))
+ goto invalid_key_format;
+ AppendCommand(&command, "private_key", hexkey);
+ continue;
+ }
+ }
+ if (strcmp(key, "peer") == 0) {
+ in_interface_section = false;
+ in_peer_section = true;
+ did_clear_allowed_ips = false;
+ if (!ConvertBase64KeyToHex(*argv++, hexkey))
+ goto invalid_key_format;
+ AppendCommand(&command, "public_key", hexkey);
+
+ continue;
+ }
+ if (in_peer_section) {
+ if (strcmp(key, "preshared-key") == 0) {
+ if (!ConvertBase64KeyToHex(*argv++, hexkey))
+ goto invalid_key_format;
+ AppendCommand(&command, "preshared_key", hexkey);
+ continue;
+ } else if (strcmp(key, "endpoint") == 0) {
+ AppendCommand(&command, "endpoint", *argv++);
+ continue;
+ } else if (strcmp(key, "persistent-keepalive") == 0) {
+ AppendCommand(&command, "persistent_keepalive_interval", *argv++);
+ continue;
+ } else if (strcmp(key, "allowed-ips") == 0) {
+ if (!did_clear_allowed_ips) {
+ AppendCommand(&command, "replace_allowed_ips", "true");
+ did_clear_allowed_ips = true;
+ }
+ SplitString(*argv++, ',', &ss);
+ for (char *x : ss)
+ AppendCommand(&command, "allowed_ip", x);
+ continue;
+ }
+ }
+ }
+ if (in_peer_section) {
+ if (strcmp(key, "remove") == 0) {
+ in_peer_section = false;
+ AppendCommand(&command, "remove", "true");
+ continue;
+ }
+ }
+
+ fprintf(stderr, "Invalid argument: %s\n", key);
+ return 1;
+
+invalid_key_format:
+ fprintf(stderr, "Key is not in the correct format: '%s'\n", argv[-1]);
+ return 1;
+ }
+
+ command.append("\n");
+
+ if (!CommunicateWithService(interfc, command, &reply))
+ return 1;
+
+ return 0;
+}
+
+#if defined(OS_WIN)
+static int HandleLogCommand() {
+ HANDLE pipe = ConnectToService(NULL, true);
+
+ int message_code;
+ std::string reply;
+
+ while (pipe != NULL && ReadMessageFromService(pipe, &message_code, &reply) && message_code == TS_SERVICE_MSG_LOGLINE)
+ printf("%s\n", reply.c_str());
+
+ CloseHandle(pipe);
+ return 0;
+}
+
+static int HandleStartCommand(int argc, char **argv) {
+ if (argc < 1 || argc > 2 || strcmp(argv[0], "--help") == 0) {
+ fprintf(stderr, "Usage: " EXENAME " start []\n");
+ return 1;
+ }
+
+ const char *devname = argv[0];
+ HANDLE pipe = ConnectToService(devname, false, true);
+ int message_code;
+ std::string reply;
+
+ const char *path = (argc == 1) ? "" : argv[1];
+
+ // Tell the server to startup a new interface
+ if (pipe == NULL ||
+ !SendMessageToService(pipe, TS_SERVICE_REQ_START, path, strlen(path) + 1) ||
+ !ReadMessageFromService(pipe, &message_code, &reply))
+ return 1;
+
+ if (message_code == TS_SERVICE_MSG_ERROR_REPLY) {
+ fprintf(stderr, "%s\n", reply.c_str());
+ return 1;
+ }
+
+ return 0;
+}
+
+static int HandleStopCommand(int argc, char **argv) {
+ if (argc != 1) {
+ fprintf(stderr, "Usage: " EXENAME " stop \n");
+ return 1;
+ }
+
+ const char *devname = argv[0];
+ HANDLE pipe = ConnectToService(devname, false);
+
+ // Tell the server to stop the interface
+ if (pipe == NULL ||
+ !SendMessageToService(pipe, TS_SERVICE_REQ_STOP, NULL, 0))
+ return 1;
+ return 0;
+}
+#endif // defined(OS_WIN)
+
+
+struct CommandLineOutput {
+ const char *filename_to_load;
+ const char *interface_name;
+ bool daemon;
+};
+
+// Returns -1 on invalid subcommand
+int HandleCommandLine(int argc, char **argv, CommandLineOutput *output) {
+ uint8 key[32];
+ char base64buf[WG_PUBLIC_KEY_LEN_BASE64 + 1];
+
+ if (argc == 1) {
+ ShowHelp();
+ return 1;
+ }
+
+ const char *subcommand = argv[1];
+ argv += 2;
+ argc -= 2;
+
+ if (!strcmp(subcommand, "show")) {
+ return HandleShowCommand(argc, argv);
+
+ } else if (!strcmp(subcommand, "set")) {
+ return HandleSetCommand(argc, argv);
+
+#if defined(OS_WIN)
+ } else if (!strcmp(subcommand, "log")) {
+ if (argc != 0) {
+ fprintf(stderr, "Usage: " EXENAME " log\n");
+ return 1;
+ }
+ return HandleLogCommand();
+
+ } else if (!strcmp(subcommand, "start")) {
+ return HandleStartCommand(argc, argv);
+
+#else
+ } else if (!strcmp(subcommand, "start") && output) {
+ if (argc != 0 && !strcmp(argv[0], "--help")) {
+start_usage:
+ fprintf(stderr, "Usage: " EXENAME " start [-d/--daemon] [-n ] []\n");
+ return 0;
+ }
+ for (; argc; argc--, argv++) {
+ char *arg = argv[0];
+ if (strcmp(arg, "-d") == 0 || strcmp(arg, "--daemon") == 0) {
+ output->daemon = true;
+ continue;
+ }
+ if (strcmp(arg, "-n") == 0) {
+ if (argc < 2) goto start_usage;
+ output->interface_name = argv[1];
+ argc--,argv++;
+ continue;
+ }
+ break;
+ }
+ if (argc > 1) goto start_usage;
+ output->filename_to_load = (argc == 0) ? "" : argv[0];
+ return 0;
+#endif // defined(OS_WIN)
+ } else if (!strcmp(subcommand, "stop")) {
+ return HandleStopCommand(argc, argv);
+ } else if(!strcmp(subcommand, "genkey")) {
+ if (argc != 0) {
+ fprintf(stderr, "Usage: " EXENAME " genkey\n");
+ return 1;
+ }
+ OsGetRandomBytes(key, 32);
+ curve25519_normalize(key);
+ printf("%s\n", base64_encode(key, 32, base64buf, sizeof(base64buf), NULL));
+
+ } else if (!strcmp(subcommand, "genpsk")) {
+ if (argc != 0) {
+ fprintf(stderr, "Usage: " EXENAME " genpsk\n");
+ return 1;
+ }
+ OsGetRandomBytes(key, 32);
+ printf("%s\n", base64_encode(key, 32, base64buf, sizeof(base64buf), NULL));
+ } else if (!strcmp(subcommand, "pubkey")) {
+ char base64[WG_PUBLIC_KEY_LEN_BASE64 + 2];
+ size_t n = fread(base64, 1, sizeof(base64), stdin);
+ if (n < sizeof(base64) - 2 || n >= sizeof(base64) ||
+ (n == sizeof(base64) - 1 && (base64[WG_PUBLIC_KEY_LEN_BASE64] != ' ' && base64[WG_PUBLIC_KEY_LEN_BASE64] != '\n'))) {
+ fprintf(stderr, EXENAME ": Incorrect key format\n");
+ return 1;
+ }
+ size_t size = 32;
+ if (!base64_decode((uint8*)base64, n, key, &size) || size != 32) {
+ fprintf(stderr, EXENAME ": Incorrect key format\n");
+ return 1;
+ }
+ curve25519_donna(key, key, kCurve25519Basepoint);
+ printf("%s\n", base64_encode(key, 32, base64buf, sizeof(base64buf), NULL));
+ } else if (!strcmp(subcommand, "--help")) {
+ ShowHelp();
+ } else if (!strcmp(subcommand, "--version")) {
+ printf("%s\n", TUNSAFE_VERSION_STRING);
+ } else {
+ if (argc == 0) {
+ if (output)
+ output->filename_to_load = subcommand;
+ } else {
+ ShowHelp();
+ }
+ return -1;
+ }
+ return 0;
+}
+
+#if defined(OS_WIN)
+// This is integrated into the main tunsafe binary on posix systems
+int main(int argc, char **argv) {
+ int rv = HandleCommandLine(argc, argv, NULL);
+ if (rv == -1) {
+ fprintf(stderr, "Invalid subcommand '%s'\n", argv[1]);
+ ShowHelp();
+ return 1;
+ }
+ return rv;
+}
+#endif // defined(OS_WIN)
diff --git a/ts.vcxproj b/ts.vcxproj
new file mode 100644
index 0000000..0774b3d
--- /dev/null
+++ b/ts.vcxproj
@@ -0,0 +1,199 @@
+
+
+
+
+ Debug
+ Win32
+
+
+ Release
+ Win32
+
+
+ Debug
+ x64
+
+
+ Release
+ x64
+
+
+
+ 15.0
+ {443E105E-8D7C-401F-BD41-D3F56C76104B}
+ Win32Proj
+ ts
+ 10.0.17134.0
+
+
+
+ Application
+ true
+ v141
+ MultiByte
+
+
+ Application
+ false
+ v141
+ true
+ MultiByte
+
+
+ Application
+ true
+ v141
+ MultiByte
+
+
+ Application
+ false
+ v141
+ true
+ MultiByte
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ true
+ $(SolutionDir)build\$(Platform)_$(Configuration)\
+ $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\
+
+
+ true
+ $(SolutionDir)build\$(Platform)_$(Configuration)\
+ $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\
+
+
+ false
+ $(SolutionDir)build\$(Platform)_$(Configuration)\
+ $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\
+
+
+ false
+ $(SolutionDir)build\$(Platform)_$(Configuration)\
+ $(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\
+
+
+
+ Use
+ Level3
+ Disabled
+ true
+ _DEBUG;_CONSOLE;%(PreprocessorDefinitions)
+ true
+ false
+
+
+ Console
+ true
+
+
+
+
+ Use
+ Level3
+ Disabled
+ true
+ WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)
+ true
+ false
+
+
+ Console
+ true
+
+
+
+
+ Use
+ Level3
+ MinSpace
+ true
+ true
+ true
+ WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)
+ true
+ MultiThreaded
+ true
+ false
+
+
+ Console
+ true
+ true
+ true
+
+
+
+
+ Use
+ Level3
+ MinSpace
+ true
+ true
+ true
+ NDEBUG;_CONSOLE;%(PreprocessorDefinitions)
+ true
+ MultiThreaded
+ true
+ false
+
+
+ Console
+ true
+ true
+ true
+
+
+
+
+
+
+
+
+
+
+ NotUsing
+ NotUsing
+ NotUsing
+ NotUsing
+
+
+ Create
+ Create
+ Create
+ Create
+
+
+
+
+
+
+
+ true
+ true
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/ts.vcxproj.filters b/ts.vcxproj.filters
new file mode 100644
index 0000000..4d6247d
--- /dev/null
+++ b/ts.vcxproj.filters
@@ -0,0 +1,53 @@
+
+
+
+
+ {4FC737F1-C7A5-4376-A066-2A32D752A2FF}
+ cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx
+
+
+ {93995380-89BD-4b04-88EB-625FBE52EBFB}
+ h;hh;hpp;hxx;hm;inl;inc;ipp;xsd
+
+
+ {67DA6AB6-F800-4c08-8B7A-83BB121AAD01}
+ rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms
+
+
+
+
+ Header Files
+
+
+ Source Files
+
+
+ Source Files
+
+
+ Source Files
+
+
+
+
+ Source Files
+
+
+ Source Files
+
+
+ Source Files
+
+
+ Source Files
+
+
+ Source Files
+
+
+
+
+ Source Files
+
+
+
\ No newline at end of file
diff --git a/tunsafe_threading.cpp b/tunsafe_threading.cpp
index af21db3..f7c64b7 100644
--- a/tunsafe_threading.cpp
+++ b/tunsafe_threading.cpp
@@ -10,10 +10,14 @@ MultithreadedDelayedDelete::MultithreadedDelayedDelete() {
}
MultithreadedDelayedDelete::~MultithreadedDelayedDelete() {
+ assert(curr_.size() == 0);
+ assert(next_.size() == 0);
+ assert(to_delete_.size() == 0);
free(table_);
}
-void MultithreadedDelayedDelete::Initialize(uint32 num_threads) {
+void MultithreadedDelayedDelete::Configure(uint32 num_threads) {
+ assert(table_ == NULL);
num_threads_ = num_threads;
table_ = (CheckpointData*)calloc(sizeof(CheckpointData), num_threads);
}
diff --git a/tunsafe_threading.h b/tunsafe_threading.h
index 1362678..d595104 100644
--- a/tunsafe_threading.h
+++ b/tunsafe_threading.h
@@ -150,12 +150,14 @@ public:
typedef void DoDeleteFunc(void *x);
void Add(DoDeleteFunc *func, void *param);
- void Initialize(uint32 num_threads);
+ void Configure(uint32 num_threads);
void Checkpoint(uint32 thread_id);
void MainCheckpoint();
+ bool enabled() const { return num_threads_ != 0; }
+
private:
struct Entry {
DoDeleteFunc *func;
diff --git a/tunsafe_win32.cpp b/tunsafe_win32.cpp
index 96dba0e..daea673 100644
--- a/tunsafe_win32.cpp
+++ b/tunsafe_win32.cpp
@@ -107,22 +107,6 @@ static void SetUiVisibility(bool visible) {
UpdateGraphReq();
}
-static bool GetConfigFullName(const char *basename, char *fullname, size_t fullname_size) {
- size_t len = strlen(basename);
-
- if (FindFilenameComponent(basename)[0]) {
- if (len >= fullname_size)
- return false;
- memcpy(fullname, basename, len + 1);
- return true;
- }
- size_t clen = GetConfigPath(fullname, fullname_size);
- if (clen == 0 || clen + len >= fullname_size)
- return false;
- memcpy(fullname + clen, basename, (len + 1) * sizeof(fullname[0]));
- return true;
-}
-
void StopTunsafeBackend(UpdateIconWhy why) {
if (g_backend->is_started()) {
g_backend->Stop();
@@ -154,6 +138,7 @@ void StartTunsafeBackend(UpdateIconWhy reason) {
}
g_notified_connected_server = false;
g_is_connected_to_server = false;
+ memset(&g_processor_stats, 0, sizeof(g_processor_stats));
g_backend->Start(g_current_filename);
RegWriteInt(g_reg_key, "IsConnected", 1);
}
@@ -189,12 +174,21 @@ public:
}
virtual void OnLogLine(const char **s) {
+ const char *line = *s;
+ size_t len = strlen(line);
+ char *tmp = (char*)alloca(len + 3);
+
+ tmp[len + 0] = '\r';
+ tmp[len + 1] = '\n';
+ tmp[len + 2] = 0;
+ memcpy(tmp, line, len);
+
CHARRANGE cr;
cr.cpMin = -1;
cr.cpMax = -1;
// hwnd = rich edit hwnd
SendMessage(hwndEdit, EM_EXSETSEL, 0, (LPARAM)&cr);
- SendMessage(hwndEdit, EM_REPLACESEL, 0, (LPARAM)*s);
+ SendMessage(hwndEdit, EM_REPLACESEL, 0, (LPARAM)tmp);
}
virtual void OnStateChanged() {
@@ -204,13 +198,13 @@ public:
const char *filename = g_cmdline_filename;
if (filename) {
- if (GetConfigFullName(filename, fullname, sizeof(fullname)))
+ if (ExpandConfigPath(filename, fullname, sizeof(fullname)))
SetCurrentConfigFilename(fullname);
} else {
std::string currconfig = g_backend->GetConfigFileName();
if (currconfig.empty()) {
char *conf = RegReadStr(g_reg_key, "ConfigFile", "TunSafe.conf");
- if (GetConfigFullName(conf, fullname, sizeof(fullname)))
+ if (ExpandConfigPath(conf, fullname, sizeof(fullname)))
SetCurrentConfigFilename(fullname);
free(conf);
} else {
@@ -233,10 +227,12 @@ public:
}
virtual void OnStatusCode(TunsafeBackend::StatusCode status) override {
+ if (status != g_status_code)
+ InvalidatePaintbox();
+
g_status_code = status;
if (TunsafeBackend::IsPermanentError(status)) {
UpdateIcon(g_is_connected_to_server ? UIW_STOPPED_WORKING_FAIL : UIW_NONE);
- InvalidatePaintbox();
return;
}
bool is_connected = (status == TunsafeBackend::kStatusConnected);
@@ -254,13 +250,15 @@ public:
if (is_connected > not_first && (g_startup_flags & kStartupFlag_BackgroundService))
g_notified_connected_server = true;
UpdateIcon(UIW_NONE);
- InvalidatePaintbox();
}
}
virtual void OnClearLog() override {
SetWindowText(hwndEdit, "");
}
+
+ virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) override {
+ }
};
static MyBackendDelegate my_procdel;
@@ -411,7 +409,7 @@ public:
ConfigMenuBuilder::ConfigMenuBuilder()
: nfiles_(0), depth_(0) {
- if (!GetConfigFullName("", buf_, sizeof(buf_)))
+ if (!ExpandConfigPath("", buf_, sizeof(buf_)))
bufpos_ = sizeof(buf_);
else
bufpos_ = strlen(buf_);
@@ -556,7 +554,7 @@ static void OpenEditor() {
static void BrowseFiles() {
char buf[MAX_PATH];
- if (GetConfigFullName("", buf, ARRAYSIZE(buf))) {
+ if (ExpandConfigPath("", buf, ARRAYSIZE(buf))) {
size_t l = strlen(buf);
buf[l - 1] = 0;
ShellExecuteFromExplorer(buf, NULL, NULL, "explore");
@@ -572,7 +570,7 @@ bool ImportFile(const char *s, bool silent = false) {
bool rv = false;
int filerv;
- if (!*last || !GetConfigFullName(last, buf, ARRAYSIZE(buf)) || _stricmp(buf, s) == 0)
+ if (!*last || !ExpandConfigPath(last, buf, ARRAYSIZE(buf)) || _stricmp(buf, s) == 0)
goto out;
filedata = LoadFileSane(s, &filesize);
@@ -657,9 +655,8 @@ void BrowseFile(HWND wnd) {
static const uint8 kCurve25519Basepoint[32] = {9};
static void SetKeyBox(HWND wnd, int ctr, uint8 buf[32]) {
- uint8 *privs = base64_encode(buf, 32, NULL);
- SetDlgItemText(wnd, ctr, (char*)privs);
- free(privs);
+ char base64[WG_PUBLIC_KEY_LEN_BASE64 + 1];
+ SetDlgItemText(wnd, ctr, base64_encode(buf, 32, base64, sizeof(base64), NULL));
}
static INT_PTR WINAPI KeyPairDlgProc(HWND hWnd, UINT message, WPARAM wParam,
@@ -1075,20 +1072,18 @@ void PushLine(const char *s) {
snprintf(buf, sizeof(buf), "[%.2d:%.2d:%.2d] ", t.wHour, t.wMinute, t.wSecond);
size_t tl = strlen(buf);
- char *x = (char*)malloc(tl + l + 3);
+ char *x = (char*)malloc(tl + l + 1);
if (!x) return;
memcpy(x, buf, tl);
memcpy(x + tl, s, l);
- x[l + tl] = '\r';
- x[l + tl + 1] = '\n';
- x[l + tl + 2] = '\0';
+ x[l + tl] = '\0';
g_backend_delegate->OnLogLine((const char**)&x);
free(x);
}
void EnsureConfigDirCreated() {
char fullname[1024];
- if (GetConfigFullName("", fullname, sizeof(fullname)))
+ if (ExpandConfigPath("", fullname, sizeof(fullname)))
CreateDirectory(fullname, NULL);
}
@@ -1358,7 +1353,7 @@ static void DrawGraph(HDC dc, const RECT *rr, StatsCollector::TimeSeries **sourc
for (size_t j = 0; j != num_source; j++) {
const StatsCollector::TimeSeries *src = sources[j];
for (size_t i = 0; i != src->size; i++)
- mx = max(mx, src->data[i]);
+ mx = std::max(mx, src->data[i]);
}
int topval = (int)(mx + 0.5f);
// round it appropriately
@@ -1432,11 +1427,13 @@ static void DrawInGraphBox(HDC hdc, int w, int h) {
for (int i = 0; i < graph->num_charts; i++) {
time_series_ptr[i] = &time_series[i];
time_series[i].shift = 0;
- time_series[i].size = *(uint32*)ptr;
+
+ uint32 size = *(uint32*)ptr;
+ time_series[i].size = size;
time_series[i].data = (float*)(ptr + 4);
- ptr += 4 + *(uint32*)ptr * 4;
- if (ptr - (uint8*)graph > graph->total_size)
+ if ((ptr - (uint8*)graph) + 4 + (uint64)size * 4 > graph->total_size)
break;
+ ptr += 4 + size * 4;
}
num_charts = graph->num_charts;
}
@@ -1517,9 +1514,7 @@ static const char *GetAdvancedInfoValue(char buffer[256], int i) {
case 0: {
if (IsOnlyZeros(g_backend->public_key(), 32))
return "";
- char *str = (char*)base64_encode(g_backend->public_key(), 32, NULL);
- snprintf(buffer, 256, "%s", str);
- free(str);
+ base64_encode(g_backend->public_key(), 32, buffer, 256, NULL);
return buffer;
}
case 1: {
@@ -1764,7 +1759,7 @@ int WINAPI WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance, LPSTR lpCmdLine
// Check if the app is already running.
g_runonce_mutex = CreateMutexA(0, FALSE, "TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90");
- if (GetLastError() == ERROR_ALREADY_EXISTS) {
+ if (GetLastError() == ERROR_ALREADY_EXISTS&&0) {
HWND window = FindWindow("TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90", NULL);
DWORD_PTR result;
if (!window || !SendMessageTimeout(window, WM_USER + 10, 0, 0, SMTO_BLOCK, 3000, &result) || result != 31337) {
diff --git a/util.cpp b/util.cpp
index 2269a3f..2d7fe8e 100644
--- a/util.cpp
+++ b/util.cpp
@@ -17,46 +17,55 @@
#include
#endif
+#include
#include
#include "tunsafe_types.h"
-static char base64_alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+static const char kBase64Alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
-uint8 *base64_encode(const uint8 *input, size_t length, size_t *out_length) {
- uint32 a;
- size_t size;
- uint8 *result, *r;
+char *base64_encode(const uint8 *input, size_t length, char *output, size_t output_size, size_t *out_length) {
+ char *result, *r;
const uint8 *end;
- size = length * 4 / 3 + 4 + 1;
- r = result = (byte*)malloc(size);
+ size_t size = (length + 2) / 3 * 4 + 1;
+ if (output != NULL) {
+ result = output;
+ assert(output_size >= size);
+ if (output_size < size) {
+ *result = 0;
+ return NULL;
+ }
+ } else {
+ result = (char*)malloc(size);
+ if (!result)
+ return NULL;
+ }
+ r = result;
end = input + length - 3;
-
// Encode full blocks
while (input <= end) {
- a = (input[0] << 16) + (input[1] << 8) + input[2];
+ uint32 a = (input[0] << 16) + (input[1] << 8) + input[2];
input += 3;
- r[0] = base64_alphabet[(a >> 18)/* & 0x3F*/];
- r[1] = base64_alphabet[(a >> 12) & 0x3F];
- r[2] = base64_alphabet[(a >> 6) & 0x3F];
- r[3] = base64_alphabet[(a) & 0x3F];
+ r[0] = kBase64Alphabet[(a >> 18)/* & 0x3F*/];
+ r[1] = kBase64Alphabet[(a >> 12) & 0x3F];
+ r[2] = kBase64Alphabet[(a >> 6) & 0x3F];
+ r[3] = kBase64Alphabet[(a) & 0x3F];
r += 4;
}
-
if (input == end + 2) {
- a = input[0] << 4;
- r[0] = base64_alphabet[(a >> 6) /*& 0x3F*/];
- r[1] = base64_alphabet[(a) & 0x3F];
+ uint32 a = input[0] << 4;
+ r[0] = kBase64Alphabet[(a >> 6) /*& 0x3F*/];
+ r[1] = kBase64Alphabet[(a) & 0x3F];
r[2] = '=';
r[3] = '=';
r += 4;
} else if (input == end + 1) {
- a = (input[0] << 10) + (input[1] << 2);
- r[0] = base64_alphabet[(a >> 12) /*& 0x3F*/];
- r[1] = base64_alphabet[(a >> 6) & 0x3F];
- r[2] = base64_alphabet[(a) & 0x3F];
+ uint32 a = (input[0] << 10) + (input[1] << 2);
+ r[0] = kBase64Alphabet[(a >> 12) /*& 0x3F*/];
+ r[1] = kBase64Alphabet[(a >> 6) & 0x3F];
+ r[2] = kBase64Alphabet[(a) & 0x3F];
r[3] = '=';
r += 4;
}
@@ -250,13 +259,6 @@ void RERROR(const char *msg, ...) {
}
}
-void rinfo(const char *msg, ...) {
- printf("muu");
-}
-
-void rinfo2(const char *msg) {
- printf("muu2");
-}
void RINFO(const char *msg, ...) {
va_list va;
@@ -296,4 +298,115 @@ size_t my_strlcpy(char *dst, size_t dstsize, const char *src) {
memcpy(dst, src, lenx);
}
return len;
-}
\ No newline at end of file
+}
+
+void OsGetRandomBytes(uint8 *data, size_t data_size) {
+#if defined(OS_WIN)
+ static BOOLEAN(APIENTRY *pfn)(void*, ULONG);
+ if (!pfn) {
+ pfn = (BOOLEAN(APIENTRY *)(void*, ULONG))GetProcAddress(LoadLibrary("ADVAPI32.DLL"), "SystemFunction036");
+ if (!pfn)
+ ExitProcess(1);
+ }
+ if (!pfn(data, (ULONG)data_size)) {
+ ExitProcess(1);
+ return;
+ }
+#elif defined(OS_POSIX)
+ int fd = open("/dev/urandom", O_RDONLY);
+ if (fd < 0) {
+ fprintf(stderr, "/dev/urandom failed\n");
+ exit(1);
+ }
+ int r = read(fd, data, data_size);
+ if (r != data_size) {
+ fprintf(stderr, "/dev/urandom failed\n");
+ exit(1);
+ }
+ close(fd);
+#else
+#error
+#endif
+}
+
+bool ParseConfigKeyValue(char *m, std::vector> *result) {
+ for (;;) {
+ char *nl = strchr(m, '\n');
+ if (nl)
+ *nl = 0;
+ if (*m != '\0') {
+ char *value = strchr(m, '=');
+ if (value == NULL)
+ return false;
+ *value++ = '\0';
+ result->emplace_back(m, value);
+ }
+ if (!nl)
+ return true;
+ m = nl + 1;
+ }
+}
+
+bool ParseHexString(const char *text, void *data, size_t data_size) {
+ size_t len = strlen(text);
+ if (len != data_size * 2)
+ return false;
+ for (size_t i = 0; i < data_size; i++) {
+ uint32 c = text[i * 2 + 0];
+ if (c >= '0' && c <= '9') {
+ c -= '0';
+ } else if ((c |= 32) >= 'a' && c <= 'f') {
+ c -= 'a' - 10;
+ } else {
+ return false;
+ }
+ uint32 d = text[i * 2 + 1];
+ if (d >= '0' && d <= '9') {
+ d -= '0';
+ } else if ((d |= 32) >= 'a' && d <= 'f') {
+ d -= 'a' - 10;
+ } else {
+ return false;
+ }
+ ((uint8*)data)[i] = c * 16 + d;
+ }
+ return true;
+}
+
+bool is_space(uint8_t c) {
+ return c == ' ' || c == '\r' || c == '\n' || c == '\t';
+}
+
+void SplitString(char *s, int separator, std::vector *components) {
+ components->clear();
+ for (;;) {
+ while (is_space(*s)) s++;
+ char *d = strchr(s, separator);
+ if (d == NULL) {
+ if (*s)
+ components->push_back(s);
+ return;
+ }
+ *d = 0;
+ char *e = d;
+ while (e > s && is_space(e[-1]))
+ *--e = 0;
+ components->push_back(s);
+ s = d + 1;
+ }
+}
+
+void PrintHexString(const void *data, size_t data_size, char *result) {
+ for (size_t i = 0; i < data_size; i++) {
+ uint8 c = ((uint8*)data)[i];
+ *result++ = "0123456789abcdef"[c >> 4];
+ *result++ = "0123456789abcdef"[c & 0xF];
+ }
+ *result++ = 0;
+}
+
+bool ParseBase64Key(const char *s, uint8 key[32]) {
+ size_t size = 32;
+ return base64_decode((uint8*)s, strlen(s), key, &size) && size == 32;
+}
+
diff --git a/util.h b/util.h
index e0846f3..128d376 100644
--- a/util.h
+++ b/util.h
@@ -3,7 +3,7 @@
#pragma once
#include "tunsafe_types.h"
-uint8 *base64_encode(const uint8 *input, size_t length, size_t *out_length);
+char *base64_encode(const uint8 *input, size_t length, char *output, size_t output_size, size_t *actual_size);
bool base64_decode(uint8 *in, size_t inLen, uint8 *out, size_t *outLen);
bool IsOnlyZeros(const uint8 *data, size_t data_size);
@@ -17,9 +17,28 @@ char *my_strndup(const char *p, size_t size);
size_t my_strlcpy(char *dst, size_t dstsize, const char *src);
-
template static inline T postinc(T&x, U v) {
T t = x;
x += v;
return t;
}
+
+template static inline T exch(T&x, U v) {
+ T t = x;
+ x = v;
+ return t;
+}
+
+template static inline T exch_null(T&x) {
+ T t = x;
+ x = NULL;
+ return t;
+}
+
+bool is_space(uint8_t c);
+void OsGetRandomBytes(uint8 *dst, size_t dst_size);
+bool ParseConfigKeyValue(char *m, std::vector> *result);
+bool ParseHexString(const char *text, void *data, size_t data_size);
+void PrintHexString(const void *data, size_t data_size, char *result);
+void SplitString(char *s, int separator, std::vector *components);
+bool ParseBase64Key(const char *s, uint8 key[32]);
\ No newline at end of file
diff --git a/util_win32.cpp b/util_win32.cpp
index 1dec101..821372c 100644
--- a/util_win32.cpp
+++ b/util_win32.cpp
@@ -297,6 +297,23 @@ size_t GetConfigPath(char *path, size_t path_size) {
return last + 7 - path;
}
+bool ExpandConfigPath(const char *basename, char *fullname, size_t fullname_size) {
+ size_t len = strlen(basename);
+
+ if (FindFilenameComponent(basename)[0]) {
+ if (len >= fullname_size)
+ return false;
+ memcpy(fullname, basename, len + 1);
+ return true;
+ }
+ size_t clen = GetConfigPath(fullname, fullname_size);
+ if (clen == 0 || clen + len >= fullname_size)
+ return false;
+ memcpy(fullname + clen, basename, (len + 1) * sizeof(fullname[0]));
+ return true;
+}
+
+
static bool ContainsDotDot(const char *path) {
for (uint8 last = 0, cur; (cur = path[0]) != '\0'; last = cur, path++)
if (cur == '.' && last == cur)
@@ -308,7 +325,7 @@ bool EnsureValidConfigPath(const char *path) {
char buf[1024];
size_t len = GetConfigPath(buf, sizeof(buf));
- return (len != 0) && (strlen(path) > len && memcmp(path, buf, len) == 0 && !ContainsDotDot(path + len));
+ return (len != 0) && (strlen(path) > len && _strnicmp(path, buf, len) == 0 && !ContainsDotDot(path + len));
}
bool RunProcessAsAdminWithArgs(const char *args, bool wait_for_exit) {
@@ -376,3 +393,65 @@ RECT MakeRect(int l, int t, int r, int b) {
RECT rr = { l, t, r, b };
return rr;
}
+
+// Retrieve the device path to the TAP adapter.
+
+#define kAdapterKeyName "SYSTEM\\CurrentControlSet\\Control\\Class\\{4D36E972-E325-11CE-BFC1-08002BE10318}"
+#define kNetworkConnectionsKeyName "SYSTEM\\CurrentControlSet\\Control\\Network\\{4D36E972-E325-11CE-BFC1-08002BE10318}"
+#define kTapComponentId "tap0901"
+
+
+void GetTapAdapterInfo(std::vector *result) {
+ LONG err;
+ HKEY adapter_key, device_key, network_connections_key;
+ bool retval = false;
+ GuidAndDevName gn;
+
+ err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, kAdapterKeyName, 0, KEY_READ, &adapter_key);
+ if (err != ERROR_SUCCESS) {
+ RERROR("GetTapAdapterName: RegOpenKeyEx failed: 0x%X", GetLastError());
+ return;
+ }
+ for (int i = 0; !retval; i++) {
+ char keyname[64 + sizeof(kAdapterKeyName) + 1 + 32 /* some margin */];
+ char value[64];
+ DWORD len = sizeof(value), type;
+ err = RegEnumKeyEx(adapter_key, i, value, &len, NULL, NULL, NULL, NULL);
+ if (err == ERROR_NO_MORE_ITEMS)
+ break;
+ if (err != ERROR_SUCCESS) {
+ RERROR("GetTapAdapterName: RegEnumKeyEx failed: 0x%X", GetLastError());
+ break;
+ }
+ snprintf(keyname, sizeof(keyname), "%s\\%s", kAdapterKeyName, value);
+ err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, keyname, 0, KEY_READ, &device_key);
+ if (err == ERROR_SUCCESS) {
+ len = sizeof(value);
+ err = RegQueryValueEx(device_key, "ComponentId", NULL, &type, (LPBYTE)value, &len);
+ if (err == ERROR_SUCCESS && type == REG_SZ && !memcmp(value, kTapComponentId, sizeof(kTapComponentId))) {
+ len = sizeof(gn.guid);
+ err = RegQueryValueEx(device_key, "NetCfgInstanceId", NULL, &type, (LPBYTE)gn.guid, &len);
+ if (err == ERROR_SUCCESS && type == REG_SZ) {
+ gn.guid[sizeof(gn.guid) - 1] = 0;
+ gn.name[0] = 0;
+
+ snprintf(keyname, sizeof(keyname), "%s\\%s\\Connection", kNetworkConnectionsKeyName, gn.guid);
+ err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, keyname, 0, KEY_READ, &network_connections_key);
+ if (err == ERROR_SUCCESS) {
+ len = sizeof(gn.guid);
+ err = RegQueryValueEx(network_connections_key, "Name", NULL, &type, (LPBYTE)gn.name, &len);
+ if (err == ERROR_SUCCESS && type == REG_SZ) {
+ gn.name[sizeof(gn.guid) - 1] = 0;
+ } else {
+ gn.name[0] = 0;
+ }
+ RegCloseKey(network_connections_key);
+ }
+ result->push_back(gn);
+ }
+ }
+ RegCloseKey(device_key);
+ }
+ }
+ RegCloseKey(adapter_key);
+}
diff --git a/util_win32.h b/util_win32.h
index 8497903..0ae16ec 100644
--- a/util_win32.h
+++ b/util_win32.h
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved.
#include "tunsafe_types.h"
+#include
#pragma once
const char *FindFilenameComponent(const char *s);
@@ -47,6 +48,7 @@ void ShellExecuteFromExplorer(
int nShowCmd = SW_SHOWNORMAL);
size_t GetConfigPath(char *path, size_t path_size);
+bool ExpandConfigPath(const char *basename, char *fullname, size_t fullname_size);
bool EnsureValidConfigPath(const char *path);
bool RunProcessAsAdminWithArgs(const char *args, bool wait_for_exit);
@@ -54,3 +56,8 @@ bool RestartProcessAsAdministrator();
bool SetClipboardString(const char *string);
RECT GetParentRect(HWND wnd);
RECT MakeRect(int l, int t, int r, int b);
+struct GuidAndDevName {
+ char guid[40];
+ char name[64];
+};
+void GetTapAdapterInfo(std::vector *result);
diff --git a/wireguard.cpp b/wireguard.cpp
index ceeb068..fb616e3 100644
--- a/wireguard.cpp
+++ b/wireguard.cpp
@@ -15,8 +15,7 @@
#include "ipzip2/ipzip2.h"
#include "wireguard.h"
#include "wireguard_config.h"
-
-uint64 OsGetMilliseconds();
+#include "util.h"
enum {
IPV4_HEADER_SIZE = 20,
@@ -36,23 +35,24 @@ WireguardProcessor::WireguardProcessor(UdpInterface *udp, TunInterface *tun, Pro
add_routes_mode_ = true;
dns_blocking_ = true;
internet_blocking_ = kBlockInternet_Default;
-
+ is_started_ = false;
stats_last_bytes_in_ = 0;
stats_last_bytes_out_ = 0;
stats_last_ts_ = OsGetMilliseconds();
-
- main_thread_scheduled_ = NULL;
- main_thread_scheduled_last_ = &main_thread_scheduled_;
}
WireguardProcessor::~WireguardProcessor() {
}
void WireguardProcessor::SetListenPort(int listen_port) {
- listen_port_ = listen_port;
+ if (listen_port_ != listen_port) {
+ listen_port_ = listen_port;
+ if (is_started_ && !ConfigureUdp()) {
+ RINFO("ConfigureUdp failed");
+ }
+ }
}
-
void WireguardProcessor::AddDnsServer(const IpAddr &sin) {
std::vector *target = (sin.sin.sin_family == AF_INET6) ? &dns6_addr_ : &dns_addr_;
target->push_back(sin);
@@ -66,6 +66,11 @@ bool WireguardProcessor::SetTunAddress(const WgCidrAddr &addr) {
return true;
}
+void WireguardProcessor::ClearTunAddress() {
+ tun_addr_.size = 0;
+ tun6_addr_.size = 0;
+}
+
void WireguardProcessor::AddExcludedIp(const WgCidrAddr &cidr_addr) {
excluded_ips_.push_back(cidr_addr);
}
@@ -129,23 +134,29 @@ static bool IsWgCidrAddrSubsetOf(const WgCidrAddr &inner, const WgCidrAddr &oute
}
bool WireguardProcessor::Start() {
+ return ConfigureUdp() && ConfigureTun();
+}
+
+bool WireguardProcessor::ConfigureUdp() {
assert(dev_.IsMainThread());
- if (!udp_->Initialize(listen_port_))
- return false;
+ return udp_->Configure(listen_port_);
+}
- if (tun_addr_.size != 32) {
- RERROR("No IPv4 address configured");
- return false;
- }
-
- if (tun_addr_.cidr >= 31) {
- RERROR("TAP is not compatible CIDR /31 or /32. Changing to /24");
- tun_addr_.cidr = 24;
- }
+bool WireguardProcessor::ConfigureTun() {
+ assert(dev_.IsMainThread());
TunInterface::TunConfig config = {0};
- config.ip = ReadBE32(tun_addr_.addr);
- config.cidr = tun_addr_.cidr;
+ if (tun_addr_.size == 32) {
+ if (tun_addr_.cidr >= 31) {
+ RERROR("TAP is not compatible CIDR /31 or /32. Changing to /24");
+ tun_addr_.cidr = 24;
+ }
+ config.ip = ReadBE32(tun_addr_.addr);
+ config.cidr = tun_addr_.cidr;
+ } else {
+ RERROR("No IPv4 address configured");
+ }
+
config.mtu = mtu_;
config.pre_post_commands = pre_post_;
config.excluded_ips = excluded_ips_;
@@ -205,7 +216,7 @@ bool WireguardProcessor::Start() {
config.ipv6_dns = dns6_addr_;
TunInterface::TunConfigOut config_out;
- if (!tun_->Initialize(std::move(config), &config_out))
+ if (!tun_->Configure(std::move(config), &config_out))
return false;
SetupCompressionHeader(dev_.compression_header());
@@ -221,6 +232,7 @@ bool WireguardProcessor::Start() {
}
}
+ is_started_ = true;
return true;
}
@@ -395,22 +407,6 @@ getout:
FreePacket(packet);
}
-void WgPeer::AddPacketToPeerQueue(Packet *packet) {
- assert(IsPeerLocked());
- // Keep only the first MAX_QUEUED_PACKETS packets.
- while (num_queued_packets_ >= MAX_QUEUED_PACKETS_PER_PEER) {
- Packet *packet = first_queued_packet_;
- first_queued_packet_ = packet->next;
- num_queued_packets_--;
- FreePacket(packet);
- }
- // Add the packet to the out queue that will get sent once handshake completes
- *last_queued_packet_ptr_ = packet;
- last_queued_packet_ptr_ = &packet->next;
- packet->next = NULL;
- num_queued_packets_++;
-}
-
// This function must be called with the peer lock held. It will remove the lock
void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet) {
assert(peer->IsPeerLocked());
@@ -427,11 +423,17 @@ void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Pac
if ((keypair = peer->curr_keypair_) == NULL ||
(send_ctr = keypair->send_ctr) >= REJECT_AFTER_MESSAGES) {
- peer->AddPacketToPeerQueue(packet);
+ // If RemovePeer has been called then discard any packets currently being written to it.
+ // curr_keypair_ is NULL when RemovePeer has been called so it's safe to do this here.
+ if (peer->marked_for_delete_)
+ goto getout_discard;
+
+ peer->AddPacketToPeerQueue_Locked(packet);
WG_RELEASE_LOCK(peer->mutex_);
- ScheduleNewHandshake(peer);
+ peer->ScheduleNewHandshake();
return;
}
+ assert(!peer->marked_for_delete_);
stats_.tun_bytes_in += size;
stats_.tun_packets_in++;
@@ -524,13 +526,14 @@ add_padding:
WriteLE32(write -= 4, keypair->remote_key_id);
*--write = tag;
- // Not using any fields from now on
- WG_RELEASE_LOCK(peer->mutex_);
-
header_size = data - write;
+ packet->size = (int)(size + header_size + keypair->auth_tag_length);
+ peer->tx_bytes_ += packet->size;
stats_.compression_wg_saved_out += (int64)16 - header_size;
packet->data = data - header_size;
- packet->size = (int)(size + header_size + keypair->auth_tag_length);
+
+ // Not using any fields from now on
+ WG_RELEASE_LOCK(peer->mutex_);
// todo: figure out what to actually use as ad.
ad = write_after_ack_header;
@@ -540,6 +543,9 @@ need_big_packet:
#else
{
#endif // #if WITH_SHORT_HEADERS
+ packet->size = (int)(size + sizeof(MessageData) + keypair->auth_tag_length);
+ peer->tx_bytes_ += packet->size;
+
// Not using any fields from now on
WG_RELEASE_LOCK(peer->mutex_);
@@ -547,7 +553,6 @@ need_big_packet:
((MessageData*)data)[-1].receiver_id = keypair->remote_key_id;
((MessageData*)data)[-1].counter = ToLE64(send_ctr);
packet->data = data - sizeof(MessageData);
- packet->size = (int)(size + sizeof(MessageData) + keypair->auth_tag_length);
ad = NULL;
ad_len = 0;
}
@@ -556,7 +561,7 @@ need_big_packet:
DoWriteUdpPacket(packet);
if (want_handshake)
- ScheduleNewHandshake(peer);
+ peer->ScheduleNewHandshake();
return;
getout_discard:
@@ -608,38 +613,32 @@ void WireguardProcessor::DoWriteUdpPacket(Packet *packet) {
ScrambleUnscrambleAndWrite(packet, &dev_.header_obfuscation_key_, udp_);
}
-void WireguardProcessor::ScheduleNewHandshake(WgPeer *peer) {
- if (peer->main_thread_scheduled_.fetch_or(WgPeer::kMainThreadScheduled_ScheduleHandshake) == 0) {
- peer->main_thread_scheduled_next_ = NULL;
- WG_ACQUIRE_LOCK(main_thread_scheduled_lock_);
- *main_thread_scheduled_last_ = peer;
- main_thread_scheduled_last_ = &peer->main_thread_scheduled_next_;
- WG_RELEASE_LOCK(main_thread_scheduled_lock_);
- // todo: in multithreaded impl need to trigger |RunAllMainThreadScheduled| to get called
- }
-}
-
void WireguardProcessor::RunAllMainThreadScheduled() {
+ WgPeer *peer, *next;
assert(dev_.IsMainThread());
- if (main_thread_scheduled_ == NULL)
+ if (dev_.main_thread_scheduled_ == NULL)
return;
- WG_ACQUIRE_LOCK(main_thread_scheduled_lock_);
- WgPeer *peer = main_thread_scheduled_;
- main_thread_scheduled_ = NULL;
- main_thread_scheduled_last_ = &main_thread_scheduled_;
- WG_RELEASE_LOCK(main_thread_scheduled_lock_);
+ WG_ACQUIRE_LOCK(dev_.main_thread_scheduled_lock_);
+ peer = dev_.main_thread_scheduled_;
+ dev_.main_thread_scheduled_ = NULL;
+ dev_.main_thread_scheduled_last_ = &dev_.main_thread_scheduled_;
+ WG_RELEASE_LOCK(dev_.main_thread_scheduled_lock_);
+
+ for (; peer; peer = next) {
+ // todo: for the multithreaded use case figure out whether to use atomic_thread_fence here,
+ // because we need to read this next value before any other thread sees the 0 we write
+ // to peer->main_thread_scheduled_.
+ next = peer->main_thread_scheduled_next_;
+ if (peer->marked_for_delete_)
+ continue;
- while (peer) {
- // todo: for the multithreaded use case figure out whether to use atomic_thread_fence here.
- WgPeer *next = peer->main_thread_scheduled_next_;
uint32 ev = peer->main_thread_scheduled_.exchange(0);
if (ev & WgPeer::kMainThreadScheduled_ScheduleHandshake) {
peer->handshake_attempts_ = 0;
SendHandshakeInitiation(peer);
}
- peer = next;
}
}
@@ -658,6 +657,7 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) {
procdel_->OnConnectionRetry(attempts);
peer->OnHandshakeInitSent();
packet->addr = peer->endpoint_;
+ peer->tx_bytes_ += packet->size;
WG_RELEASE_LOCK(peer->mutex_);
DoWriteUdpPacket(packet);
if (attempts > 1 && attempts <= 20)
@@ -696,19 +696,21 @@ void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) {
#endif // WITH_SHORT_HEADERS
} else if (type == MESSAGE_HANDSHAKE_COOKIE) {
assert(dev_.IsMainThread());
- if (packet->size != sizeof(MessageHandshakeCookie))
+ if (packet->size != sizeof(MessageHandshakeCookie) || !dev_.is_private_key_initialized())
goto invalid_size;
HandleHandshakeCookiePacket(packet);
} else if (type == MESSAGE_HANDSHAKE_INITIATION) {
assert(dev_.IsMainThread());
- if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeInitiation)) : (packet->size != sizeof(MessageHandshakeInitiation)))
+ if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeInitiation)) : (packet->size != sizeof(MessageHandshakeInitiation)) ||
+ !dev_.is_private_key_initialized())
goto invalid_size;
stats_.handshakes_in++;
if (CheckIncomingHandshakeRateLimit(packet, overload))
HandleHandshakeInitiationPacket(packet);
} else if (type == MESSAGE_HANDSHAKE_RESPONSE) {
assert(dev_.IsMainThread());
- if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeResponse)) : (packet->size != sizeof(MessageHandshakeResponse)))
+ if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeResponse)) : (packet->size != sizeof(MessageHandshakeResponse)) ||
+ !dev_.is_private_key_initialized())
goto invalid_size;
if (CheckIncomingHandshakeRateLimit(packet, overload))
HandleHandshakeResponsePacket(packet);
@@ -749,6 +751,8 @@ void WgPeer::CopyEndpointToPeer_Locked(WgKeypair *keypair, const IpAddr *addr) {
#if WITH_SHORT_HEADERS
void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packet) {
+ assert(dev_.IsMainOrDataThread());
+
uint8 *data = packet->data + 1;
size_t bytes_left = packet->size - 1;
WgKeypair *keypair;
@@ -832,6 +836,8 @@ void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packe
WG_ACQUIRE_LOCK(keypair->peer->mutex_);
+ keypair->peer->rx_bytes_ += packet->size;
+
if (keypair->recv_key_state == WgKeypair::KEY_INVALID)
goto getout_unlock;
@@ -896,7 +902,7 @@ void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *key
WgKeypair *curr_keypair = peer->curr_keypair_;
if (curr_keypair && curr_keypair->recv_key_state == WgKeypair::KEY_WANT_REFRESH) {
curr_keypair->recv_key_state = WgKeypair::KEY_DID_REFRESH;
- ScheduleNewHandshake(peer);
+ peer->ScheduleNewHandshake();
}
if (data_size == 0) {
@@ -965,6 +971,8 @@ getout:
}
void WireguardProcessor::HandleDataPacket(Packet *packet) {
+ assert(dev_.IsMainOrDataThread());
+
uint8 *data = packet->data;
size_t data_size = packet->size;
uint32 key_id = ((MessageData*)data)->receiver_id;
@@ -984,6 +992,7 @@ getout:
}
WG_ACQUIRE_LOCK(keypair->peer->mutex_);
+ keypair->peer->rx_bytes_ += data_size;
if (keypair->recv_key_state == WgKeypair::KEY_INVALID) {
stats_.error_key_id++;
WG_RELEASE_LOCK(keypair->peer->mutex_);
@@ -993,6 +1002,8 @@ getout:
WG_RELEASE_LOCK(keypair->peer->mutex_);
goto getout;
} else {
+ assert(!keypair->peer->marked_for_delete_);
+
WgPeer::CopyEndpointToPeer_Locked(keypair, &packet->addr);
HandleAuthenticatedDataPacket_WillUnlock(keypair, packet, data + sizeof(MessageData), data_size - sizeof(MessageData) - keypair->auth_tag_length);
}
@@ -1119,7 +1130,7 @@ void WireguardProcessor::SecondLoop() {
uint32 mask;
{
WG_SCOPED_LOCK(peer->mutex_);
- mask = peer->CheckTimeouts(now);
+ mask = peer->CheckTimeouts_Locked(now);
if (mask == 0)
continue;
if (mask & WgPeer::ACTION_SEND_KEEPALIVE)
diff --git a/wireguard.h b/wireguard.h
index d357d41..4e4bfc1 100644
--- a/wireguard.h
+++ b/wireguard.h
@@ -66,6 +66,7 @@ enum InternetBlockState {
};
class WireguardProcessor {
+ friend class WgConfig;
public:
WireguardProcessor(UdpInterface *udp, TunInterface *tun, ProcessorDelegate *procdel);
~WireguardProcessor();
@@ -73,13 +74,14 @@ public:
void SetListenPort(int listen_port);
void AddDnsServer(const IpAddr &sin);
bool SetTunAddress(const WgCidrAddr &addr);
+ void ClearTunAddress();
void AddExcludedIp(const WgCidrAddr &cidr_addr);
void SetMtu(int mtu);
void SetAddRoutesMode(bool mode);
void SetDnsBlocking(bool dns_blocking);
void SetInternetBlocking(InternetBlockState internet_blocking);
void SetHeaderObfuscation(const char *key);
-
+
void HandleTunPacket(Packet *packet);
void HandleUdpPacket(Packet *packet, bool overload);
static bool IsMainThreadPacket(Packet *packet);
@@ -91,6 +93,9 @@ public:
bool Start();
+ bool ConfigureUdp();
+ bool ConfigureTun();
+
WgDevice &dev() { return dev_; }
TunInterface::PrePostCommands &prepost() { return pre_post_; }
const WgCidrAddr &tun_addr() { return tun_addr_; }
@@ -100,7 +105,6 @@ private:
void DoWriteUdpPacket(Packet *packet);
void WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet);
void SendHandshakeInitiation(WgPeer *peer);
- void ScheduleNewHandshake(WgPeer *peer);
void SendKeepalive_Locked(WgPeer *peer);
void SendQueuedPackets_Locked(WgPeer *peer);
@@ -110,29 +114,25 @@ private:
void HandleDataPacket(Packet *packet);
void HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet, uint8 *data, size_t data_size);
-
void HandleShortHeaderFormatPacket(uint32 tag, Packet *packet);
-
bool CheckIncomingHandshakeRateLimit(Packet *packet, bool overload);
-
bool HandleIcmpv6NeighborSolicitation(const byte *data, size_t data_size);
-
void SetupCompressionHeader(WgPacketCompressionVer01 *c);
void NotifyHandshakeComplete();
- int listen_port_;
-
ProcessorDelegate *procdel_;
TunInterface *tun_;
UdpInterface *udp_;
- int mtu_;
- WgProcessorStats stats_;
+
+ uint16 listen_port_;
+ uint16 mtu_;
bool dns_blocking_;
uint8 internet_blocking_;
bool add_routes_mode_;
bool network_discovery_spoofing_;
bool did_have_first_handshake_;
+ bool is_started_;
uint8 network_discovery_mac_[6];
WgDevice dev_;
@@ -140,14 +140,12 @@ private:
WgCidrAddr tun_addr_;
WgCidrAddr tun6_addr_;
+ WgProcessorStats stats_;
+
std::vector dns_addr_, dns6_addr_;
TunInterface::PrePostCommands pre_post_;
- // Queue of things scheduled to run on the main thread.
- WG_DECLARE_LOCK(main_thread_scheduled_lock_);
- WgPeer *main_thread_scheduled_, **main_thread_scheduled_last_;
-
uint64 stats_last_bytes_in_, stats_last_bytes_out_;
uint64 stats_last_ts_;
diff --git a/wireguard_config.cpp b/wireguard_config.cpp
index 6f34d5b..2bb59da 100644
--- a/wireguard_config.cpp
+++ b/wireguard_config.cpp
@@ -45,12 +45,26 @@ char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]) {
return buf;
}
+
+char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]) {
+ if (addr.size == 32) {
+ print_ip_prefix(buf, AF_INET, addr.addr, addr.cidr);
+ } else if (addr.size == 128) {
+ print_ip_prefix(buf, AF_INET6, addr.addr, addr.cidr);
+ } else {
+ buf[0] = 0;
+ }
+ return buf;
+}
+
+
+
struct Addr {
byte addr[4];
uint8 cidr;
};
-static bool ParseCidrAddr(char *s, WgCidrAddr *out) {
+bool ParseCidrAddr(char *s, WgCidrAddr *out) {
char *slash = strchr(s, '/');
if (!slash)
return false;
@@ -92,15 +106,6 @@ bool DnsResolver::Resolve(const char *hostname, IpAddr *result) {
char buf[kSizeOfAddress];
memset(result, 0, sizeof(IpAddr));
- if (inet_pton(AF_INET6, hostname, &result->sin6.sin6_addr) == 1) {
- result->sin.sin_family = AF_INET6;
- return true;
- }
-
- if (inet_pton(AF_INET, hostname, &result->sin.sin_addr) == 1) {
- result->sin.sin_family = AF_INET;
- return true;
- }
// First check cache
for (auto it = cache_.begin(); it != cache_.end(); ++it) {
@@ -145,10 +150,7 @@ bool DnsResolver::Resolve(const char *hostname, IpAddr *result) {
}
}
-
-
-
-static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver) {
+bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver) {
memset(sin, 0, sizeof(IpAddr));
if (*s == '[') {
char *end = strchr(s, ']');
@@ -168,7 +170,11 @@ static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver)
if (!x) return false;
*x = 0;
- if (!resolver->Resolve(s, sin)) {
+ if (inet_pton(AF_INET, s, &sin->sin.sin_addr) == 1) {
+ sin->sin.sin_family = AF_INET;
+ } else if (!resolver) {
+ return false;
+ } else if (!resolver->Resolve(s, sin)) {
RERROR("Unable to resolve %s", s);
return false;
}
@@ -177,18 +183,19 @@ static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver)
}
static bool ParseSockaddrInWithoutPort(char *s, IpAddr *sin, DnsResolver *resolver) {
- if (!resolver->Resolve(s, sin)) {
+ if (inet_pton(AF_INET6, s, &sin->sin6.sin6_addr) == 1) {
+ sin->sin.sin_family = AF_INET6;
+ return true;
+ } else if (inet_pton(AF_INET, s, &sin->sin.sin_addr) == 1) {
+ sin->sin.sin_family = AF_INET;
+ return true;
+ } else if (!resolver->Resolve(s, sin)) {
RERROR("Unable to resolve %s", s);
return false;
}
return true;
}
-static bool ParseBase64Key(const char *s, uint8 key[32]) {
- size_t size = 32;
- return base64_decode((uint8*)s, strlen(s), key, &size) && size == 32;
-}
-
class WgFileParser {
public:
WgFileParser(WireguardProcessor *wg, DnsResolver *resolver) : wg_(wg), dns_resolver_(resolver) {}
@@ -197,7 +204,7 @@ public:
void FinishGroup();
struct Peer {
- uint8 pub[32];
+ WgPublicKey pub;
uint8 psk[32];
};
Peer pi_;
@@ -206,29 +213,6 @@ public:
bool had_interface_ = false;
};
-bool is_space(uint8_t c) {
- return c == ' ' || c == '\r' || c == '\n' || c == '\t';
-}
-
-
-void SplitString(char *s, int separator, std::vector *components) {
- for (;;) {
- while (is_space(*s)) s++;
- char *d = strchr(s, separator);
- if (d == NULL) {
- if (*s)
- components->push_back(s);
- return;
- }
- *d = 0;
- char *e = d;
- while (e > s && is_space(e[-1]))
- *--e = 0;
- components->push_back(s);
- s = d + 1;
- }
-}
-
static bool ParseBoolean(const char *str, bool *value) {
if (_stricmp(str, "true") == 0 ||
_stricmp(str, "yes") == 0 ||
@@ -285,7 +269,7 @@ static int ParseCipherSuite(const char *cipher) {
void WgFileParser::FinishGroup() {
if (peer_) {
- peer_->Initialize(pi_.pub, pi_.psk);
+ peer_->SetPublicKey(pi_.pub);
peer_ = NULL;
}
}
@@ -303,7 +287,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
if (!ParseBase64Key(value, binkey))
return false;
had_interface_ = true;
- wg_->dev().Initialize(binkey);
+ wg_->dev().SetPrivateKey(binkey);
} else if (strcmp(key, "ListenPort") == 0) {
wg_->SetListenPort(atoi(value));
} else if (strcmp(key, "Address") == 0) {
@@ -394,11 +378,12 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
return true;
}
if (strcmp(key, "PublicKey") == 0) {
- if (!ParseBase64Key(value, pi_.pub))
+ if (!ParseBase64Key(value, pi_.pub.bytes))
return false;
} else if (strcmp(key, "PresharedKey") == 0) {
if (!ParseBase64Key(value, pi_.psk))
return false;
+ peer_->SetPresharedKey(pi_.psk);
} else if (strcmp(key, "AllowedIPs") == 0) {
SplitString(value, ',', &ss);
for (size_t i = 0; i < ss.size(); i++) {
@@ -412,7 +397,8 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
return false;
peer_->SetEndpoint(sin);
} else if (strcmp(key, "PersistentKeepalive") == 0) {
- peer_->SetPersistentKeepalive(atoi(value));
+ if (!peer_->SetPersistentKeepalive(atoi(value)))
+ return false;
} else if (strcmp(key, "AllowMulticast") == 0) {
bool b;
if (!ParseBoolean(value, &b))
@@ -524,3 +510,154 @@ bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, DnsR
fclose(f);
return true;
}
+
+
+static void CmsgAppendFmt(std::string *result, const char *fmt, ...) {
+ va_list va;
+ char buf[256];
+ va_start(va, fmt);
+ vsnprintf(buf, sizeof(buf), fmt, va);
+ (*result) += buf;
+ (*result) += '\n';
+ va_end(va);
+}
+
+static void CmsgAppendHex(std::string *result, const char *key, const void *data, size_t data_size) {
+ char *tmp = (char*)alloca(data_size * 2 + 2);
+ PrintHexString(data, data_size, tmp + 1);
+ tmp[0] = '=';
+ tmp[data_size * 2 + 1] = '\n';
+ (*result) += key;
+ result->append(tmp, data_size * 2 + 2);
+}
+
+void WgConfig::HandleConfigurationProtocolGet(WireguardProcessor *proc, std::string *result) {
+ char buf[kSizeOfAddress];
+
+ CmsgAppendHex(result, "private_key", proc->dev_.s_priv_, sizeof(proc->dev_.s_priv_));
+ if (proc->listen_port_)
+ CmsgAppendFmt(result, "listen_port=%d", proc->listen_port_);
+ if (proc->tun_addr_.size == 32)
+ CmsgAppendFmt(result, "address=%s", PrintWgCidrAddr(proc->tun_addr_, buf));
+ if (proc->tun6_addr_.size == 128)
+ CmsgAppendFmt(result, "address=%s", PrintWgCidrAddr(proc->tun6_addr_, buf));
+
+ for (WgPeer *peer = proc->dev_.peers_; peer; peer = peer->next_peer_) {
+ WG_SCOPED_LOCK(peer->lock_);
+
+ CmsgAppendHex(result, "public_key", peer->s_remote_.bytes, sizeof(peer->s_remote_));
+ if (!IsOnlyZeros(peer->preshared_key_, sizeof(peer->preshared_key_)))
+ CmsgAppendHex(result, "preshared_key", peer->preshared_key_, sizeof(peer->preshared_key_));
+ if (peer->tx_bytes_ | peer->rx_bytes_)
+ CmsgAppendFmt(result, "tx_bytes=%lld\nrx_bytes=%lld", peer->tx_bytes_, peer->rx_bytes_);
+ for (auto it = peer->allowed_ips_.begin(); it != peer->allowed_ips_.end(); ++it)
+ CmsgAppendFmt(result, "allowed_ip=%s", PrintWgCidrAddr(*it, buf));
+ if (peer->persistent_keepalive_ms_)
+ CmsgAppendFmt(result, "persistent_keepalive_interval=%d", peer->persistent_keepalive_ms_ / 1000);
+ if (peer->endpoint_.sin.sin_family == AF_INET)
+ CmsgAppendFmt(result, "endpoint=%s:%d", PrintIpAddr(peer->endpoint_, buf), htons(peer->endpoint_.sin.sin_port));
+ else if (peer->endpoint_.sin.sin_family == AF_INET6)
+ CmsgAppendFmt(result, "endpoint=[%s]:%d", PrintIpAddr(peer->endpoint_, buf), htons(peer->endpoint_.sin6.sin6_port));
+
+ if (peer->last_complete_handskake_timestamp_) {
+ uint64 millis_since = OsGetMilliseconds() - peer->last_complete_handskake_timestamp_;
+ uint64 when = time(NULL) - millis_since / 1000;
+ CmsgAppendFmt(result, "last_handshake_time_sec=%lld", when);
+ }
+ }
+ CmsgAppendFmt(result, "protocol_version=1");
+}
+
+bool WgConfig::HandleConfigurationProtocolMessage(WireguardProcessor *proc, const std::string &&message, std::string *result) {
+ std::string message_copy(std::move(message));
+ std::vector> kv;
+ bool is_set = false;
+ bool did_set_address = false;
+ WgPeer *peer = NULL;
+ WgCidrAddr cidr_addr;
+ IpAddr sin;
+ uint8 buf32[32];
+ assert(proc->dev().IsMainThread());
+
+ result->clear();
+
+ if (!ParseConfigKeyValue(&message_copy[0], &kv))
+ return false;
+
+ for (auto it : kv) {
+ char *key = it.first, *value = it.second;
+ if (strcmp(key, "get") == 0) {
+ if (strcmp(value, "1") != 0)
+ goto getout_fail;
+ HandleConfigurationProtocolGet(proc, result);
+ break;
+ } else if (strcmp(key, "set") == 0) {
+ if (strcmp(value, "1") != 0)
+ goto getout_fail;
+ is_set = true;
+ } else if (is_set) {
+ if (strcmp(key, "private_key") == 0) {
+ if (!ParseHexString(value, buf32, 32)) goto getout_fail;
+ proc->dev_.SetPrivateKey(buf32);
+ } else if (strcmp(key, "listen_port") == 0) {
+ int new_port = atoi(value);
+ proc->SetListenPort(new_port);
+ } else if (strcmp(key, "replace_peers") == 0) {
+ if (strcmp(value, "true") != 0) goto getout_fail;
+ proc->dev_.RemoveAllPeers();
+ } else if (strcmp(key, "address") == 0) {
+ if (!ParseCidrAddr(value, &cidr_addr)) goto getout_fail;
+ if (!did_set_address) {
+ did_set_address = true;
+ proc->ClearTunAddress();
+ }
+ if (!proc->SetTunAddress(cidr_addr)) goto getout_fail;
+ } else if (strcmp(key, "public_key") == 0) {
+ WgPublicKey pubkey;
+ if (!ParseHexString(value, pubkey.bytes, 32)) goto getout_fail;
+ peer = proc->dev_.GetPeerFromPublicKey(pubkey);
+ if (!peer) {
+ peer = proc->dev_.AddPeer();
+ peer->SetPublicKey(pubkey);
+ }
+ } else if (peer != NULL) {
+ if (strcmp(key, "remove") == 0) {
+ if (strcmp(value, "true") != 0) goto getout_fail;
+ peer->RemovePeer();
+ peer = NULL;
+ } else if (strcmp(key, "preshared_key") == 0) {
+ if (!ParseHexString(value, buf32, 32)) goto getout_fail;
+ peer->SetPresharedKey(buf32);
+ } else if (strcmp(key, "endpoint") == 0) {
+ if (!ParseSockaddrInWithPort(value, &sin, NULL)) goto getout_fail;
+ peer->SetEndpoint(sin);
+ } else if (strcmp(key, "persistent_keepalive_interval") == 0) {
+ if (!peer->SetPersistentKeepalive(atoi(value)))
+ goto getout_fail;
+ } else if (strcmp(key, "replace_allowed_ips") == 0) {
+ if (strcmp(value, "true") != 0) goto getout_fail;
+ peer->RemoveAllIps();
+ } else if (strcmp(key, "allowed_ip") == 0) {
+ if (!ParseCidrAddr(value, &cidr_addr)) goto getout_fail;
+ peer->AddIp(cidr_addr);
+ }
+ }
+ } else {
+ goto getout_fail;
+ }
+ }
+
+ // reconfigure the tun interface?
+ if (did_set_address) {
+ proc->ConfigureTun();
+ }
+
+ result->append("errno=0\n\n");
+ return true;
+
+getout_fail:
+ (*result) = "errno=1\n\n";
+ return false;
+}
+
+
diff --git a/wireguard_config.h b/wireguard_config.h
index 01d9678..791925d 100644
--- a/wireguard_config.h
+++ b/wireguard_config.h
@@ -30,11 +30,19 @@ private:
};
+class WgConfig {
+public:
+ static bool HandleConfigurationProtocolMessage(WireguardProcessor *proc, const std::string &&message, std::string *result);
+private:
+ static void HandleConfigurationProtocolGet(WireguardProcessor *proc, std::string *result);
+};
+
bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, DnsResolver *dns_resolver);
#define kSizeOfAddress 64
const char *print_ip_prefix(char buf[kSizeOfAddress], int family, const void *ip, int prefixlen);
char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]);
-
+char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]);
+bool ParseCidrAddr(char *s, WgCidrAddr *out);
#endif // TINYVPN_TINYVPN_H_
diff --git a/wireguard_proto.cpp b/wireguard_proto.cpp
index c4a3b50..a7a5567 100644
--- a/wireguard_proto.cpp
+++ b/wireguard_proto.cpp
@@ -54,18 +54,27 @@ bool ReplayDetector::CheckReplay(uint64 seq_nr) {
WgDevice::WgDevice() {
peers_ = NULL;
+ last_peer_ptr_ = &peers_;
delegate_ = NULL;
header_obfuscation_ = false;
+ is_private_key_initialized_ = false;
next_rng_slot_ = 0;
+ main_thread_scheduled_ = NULL;
+ main_thread_scheduled_last_ = &main_thread_scheduled_;
memset(&compression_header_, 0, sizeof(compression_header_));
low_resolution_timestamp_ = cookie_secret_timestamp_ = OsGetMilliseconds();
OsGetRandomBytes(cookie_secret_, sizeof(cookie_secret_));
OsGetRandomBytes((uint8*)random_number_input_, sizeof(random_number_input_));
- SetCurrentThreadAsMainThread();
+ main_thread_id_ = GetCurrentThreadId();
+
+ memset(s_priv_, 0, sizeof(s_priv_));
+ memset(s_pub_, 0, sizeof(s_pub_));
}
WgDevice::~WgDevice() {
+ assert(IsMainThread());
+ RemoveAllPeers();
}
void WgDevice::SecondLoop(uint64 now) {
@@ -151,7 +160,8 @@ static inline void ComputeHKDF2DH(uint8 ci[WG_HASH_LEN], uint8 k[WG_SYMMETRIC_KE
memzero_crypto(dh, sizeof(dh));
}
-void WgDevice::Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]) {
+void WgDevice::SetPrivateKey(const uint8 private_key[WG_PUBLIC_KEY_LEN]) {
+ assert(IsMainThread());
// Derive the public key from the private key.
memcpy(s_priv_, private_key, sizeof(s_priv_));
curve25519_donna(s_pub_, s_priv_, kCurve25519Basepoint);
@@ -162,26 +172,31 @@ void WgDevice::Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]) {
kLabelCookie, sizeof(kLabelCookie), s_pub_, sizeof(s_pub_));
BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_),
kLabelMac1, sizeof(kLabelMac1), s_pub_, sizeof(s_pub_));
+
+ is_private_key_initialized_ = true;
+
+ // Recompute peer data because it depends on my privkey
+ for (WgPeer *peer = peers_; peer; peer = peer->next_peer_)
+ peer->SetPublicKey(peer->s_remote_);
}
WgPeer *WgDevice::AddPeer() {
assert(IsMainThread());
WgPeer *peer = new WgPeer(this);
- WgPeer **pp = &peers_;
- while (*pp)
- pp = &(*pp)->next_peer_;
- *pp = peer;
return peer;
}
-WgPeer *WgDevice::GetPeerFromPublicKey(uint8 public_key[WG_PUBLIC_KEY_LEN]) {
+void WgDevice::RemoveAllPeers() {
assert(IsMainThread());
- // todo: add O(1) lookup
- for (WgPeer *peer = peers_; peer; peer = peer->next_peer_) {
- if (memcmp(peer->s_remote_, public_key, WG_PUBLIC_KEY_LEN) == 0)
- return peer;
- }
- return NULL;
+ while (peers_)
+ peers_->RemovePeer();
+}
+
+WgPeer *WgDevice::GetPeerFromPublicKey(const WgPublicKey &pubkey) {
+ assert(IsMainThread());
+
+ auto it = peer_id_lookup_.find(pubkey);
+ return (it != peer_id_lookup_.end()) ? it->second : NULL;
}
bool WgDevice::CheckCookieMac1(Packet *packet) {
@@ -230,6 +245,7 @@ void WgDevice::CreateCookieMessage(MessageHandshakeCookie *dst, Packet *packet,
}
void WgDevice::EraseKeypairAddrEntry_Locked(WgKeypair *kp) {
+ // todo: figure out how to make this multithread safe.
WgAddrEntry *ae = kp->addr_entry;
assert(ae->ref_count >= 1);
@@ -313,7 +329,6 @@ void WgDevice::SetHeaderObfuscation(const char *key) {
#endif // WITH_HEADER_OBFUSCATION
}
-
WgPeer::WgPeer(WgDevice *dev) {
assert(dev->IsMainThread());
dev_ = dev;
@@ -323,6 +338,7 @@ WgPeer::WgPeer(WgDevice *dev) {
expect_cookie_reply_ = false;
has_mac2_cookie_ = false;
pending_keepalive_ = false;
+ marked_for_delete_ = false;
allow_multicast_through_peer_ = false;
allow_endpoint_change_ = true;
supports_handshake_extensions_ = true;
@@ -331,6 +347,8 @@ WgPeer::WgPeer(WgDevice *dev) {
last_handshake_init_recv_timestamp_ = 0;
last_complete_handskake_timestamp_ = 0;
persistent_keepalive_ms_ = 0;
+ rx_bytes_ = 0;
+ tx_bytes_ = 0;
timers_ = 0;
first_queued_packet_ = NULL;
last_queued_packet_ptr_ = &first_queued_packet_;
@@ -343,15 +361,66 @@ WgPeer::WgPeer(WgDevice *dev) {
memset(last_timestamp_, 0, sizeof(last_timestamp_));
ipv4_broadcast_addr_ = 0xffffffff;
memset(features_, 0, sizeof(features_));
+ memset(preshared_key_, 0, sizeof(preshared_key_));
+ memset(&s_remote_, 0, sizeof(s_remote_));
+
+ // Insert into the parent's linked list
+ *dev_->last_peer_ptr_ = this;
+ dev_->last_peer_ptr_ = &next_peer_;
}
WgPeer::~WgPeer() {
+ // do not delete this directly, instead call RemovePeer
+ assert(marked_for_delete_);
assert(dev_->IsMainThread());
+ assert(curr_keypair_ == NULL && next_keypair_ == NULL && prev_keypair_ == NULL);
+ assert(local_key_id_during_hs_ == 0);
+ assert(first_queued_packet_ == NULL);
+}
+
+void WgPeer::DelayedDelete(void *x) {
+ WgPeer *peer = (WgPeer*)x;
+ assert(peer->dev_->IsMainThread());
+
+ if (peer->main_thread_scheduled_ != 0) {
+ WG_ACQUIRE_LOCK(peer->dev_->main_thread_scheduled_lock_);
+ // Unlink myself from the main thread scheduled list
+ for (WgPeer **pp = &peer->dev_->main_thread_scheduled_; *pp; pp = &(*pp)->main_thread_scheduled_next_) {
+ if (*pp == peer) {
+ *pp = peer->main_thread_scheduled_next_;
+ break;
+ }
+ }
+ WG_RELEASE_LOCK(peer->dev_->main_thread_scheduled_lock_);
+ }
+ delete peer;
+}
+
+void WgPeer::RemovePeer() {
+ assert(dev_->IsMainThread());
+ assert(!marked_for_delete_);
+
+ // Find and unlink the peer from the parent's peer list
+ WgPeer **pp = &dev_->peers_;
+ while (*pp != this)
+ pp = &(*pp)->next_peer_;
+ if ((*pp = next_peer_) == NULL)
+ dev_->last_peer_ptr_ = pp;
+
+ RemoveAllIps();
+ dev_->peer_id_lookup_.erase(s_remote_);
+
WG_ACQUIRE_LOCK(mutex_);
+ marked_for_delete_ = true;
ClearKeys_Locked();
ClearHandshake_Locked();
ClearPacketQueue_Locked();
WG_RELEASE_LOCK(mutex_);
+
+ // The WgPeer instance may still be accessible from
+ // worker threads that already started processing a packet,
+ // so defer the actual delete of it.
+ dev_->delayed_delete_.Add(&WgPeer::DelayedDelete, this);
}
void WgPeer::ClearKeys_Locked() {
@@ -382,21 +451,55 @@ void WgPeer::ClearPacketQueue_Locked() {
num_queued_packets_ = 0;
}
-void WgPeer::Initialize(const uint8 spub[WG_PUBLIC_KEY_LEN], const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]) {
- // Optionally use a preshared key, it defaults to all zeros.
+void WgPeer::AddPacketToPeerQueue_Locked(Packet *packet) {
+ assert(IsPeerLocked());
+ assert(!marked_for_delete_);
+ // Keep only the first MAX_QUEUED_PACKETS packets.
+ while (num_queued_packets_ >= MAX_QUEUED_PACKETS_PER_PEER) {
+ Packet *packet = first_queued_packet_;
+ first_queued_packet_ = packet->next;
+ num_queued_packets_--;
+ FreePacket(packet);
+ }
+ // Add the packet to the out queue that will get sent once handshake completes
+ *last_queued_packet_ptr_ = packet;
+ last_queued_packet_ptr_ = &packet->next;
+ packet->next = NULL;
+ num_queued_packets_++;
+}
+
+void WgPeer::SetPublicKey(const WgPublicKey &spub) {
+ assert(dev_->IsMainThread());
+ assert(IsOnlyZeros(s_remote_.bytes, sizeof(s_remote_.bytes)) ||
+ memcmp(s_remote_.bytes, spub.bytes, sizeof(s_remote_.bytes)) == 0);
+
+ s_remote_ = spub;
+ dev_->peer_id_lookup_[s_remote_] = this;
+
+ if (!dev_->is_private_key_initialized_)
+ return;
+
+ // Precompute: s_priv_pub_ := DH(sprivr, spubi)
+ curve25519_donna(s_priv_pub_, dev_->s_priv_, s_remote_.bytes);
+ // Precompute: precomputed_cookie_key_ := HASH(LABEL-COOKIE || Spub_m)
+ // precomputed_mac1_key_ := HASH(MAC1-COOKIE || Spub_m)
+ BlakeX2(precomputed_cookie_key_, sizeof(precomputed_cookie_key_),
+ kLabelCookie, sizeof(kLabelCookie), s_remote_.bytes, WG_PUBLIC_KEY_LEN);
+ BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_),
+ kLabelMac1, sizeof(kLabelMac1), s_remote_.bytes, WG_PUBLIC_KEY_LEN);
+
+ // Remove the peer's keys
+ WG_ACQUIRE_LOCK(mutex_);
+ ClearKeys_Locked();
+ ClearHandshake_Locked();
+ WG_RELEASE_LOCK(mutex_);
+}
+
+void WgPeer::SetPresharedKey(const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]) {
if (preshared_key)
memcpy(preshared_key_, preshared_key, sizeof(preshared_key_));
else
memset(preshared_key_, 0, sizeof(preshared_key_));
- // Precompute: s_priv_pub_ := DH(sprivr, spubi)
- memcpy(s_remote_, spub, sizeof(s_remote_));
- curve25519_donna(s_priv_pub_, dev_->s_priv_, s_remote_);
- // Precompute: precomputed_cookie_key_ := HASH(LABEL-COOKIE || Spub_m)
- // precomputed_mac1_key_ := HASH(MAC1-COOKIE || Spub_m)
- BlakeX2(precomputed_cookie_key_, sizeof(precomputed_cookie_key_),
- kLabelCookie, sizeof(kLabelCookie), spub, WG_PUBLIC_KEY_LEN);
- BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_),
- kLabelMac1, sizeof(kLabelMac1), spub, WG_PUBLIC_KEY_LEN);
}
// run on the client
@@ -411,7 +514,7 @@ void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) {
// Hi := HASH(Ci || IDENTIFIER)
memcpy(hs_.hi, kWgInitHash, sizeof(hs_.hi));
// Hi := HASH(Hi || Spub_r)
- BlakeMix(hs_.hi, s_remote_, sizeof(s_remote_));
+ BlakeMix(hs_.hi, s_remote_.bytes, sizeof(s_remote_));
// (Epriv_r, Epub_r) := DH-GENERATE()
// msg.ephemeral = Epub_r
OsGetRandomBytes(hs_.e_priv, sizeof(hs_.e_priv));
@@ -422,7 +525,7 @@ void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) {
// Hi := HASH(Hi || msg.ephemeral)
BlakeMix(hs_.hi, dst->ephemeral, sizeof(dst->ephemeral));
// (Ci, K) := KDF2(Ci, DH(epriv, spub_r))
- ComputeHKDF2DH(hs_.ci, k, hs_.e_priv, s_remote_);
+ ComputeHKDF2DH(hs_.ci, k, hs_.e_priv, s_remote_.bytes);
// msg.static = AEAD(K, 0, Spub_i, Hi)
chacha20poly1305_encrypt(dst->static_enc, dev_->s_pub_, sizeof(dev_->s_pub_), hs_.hi, sizeof(hs_.hi), 0, k);
// Hi := HASH(Hi || msg.static)
@@ -461,7 +564,7 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
uint8 e_priv[WG_PUBLIC_KEY_LEN];
};
union {
- uint8 spubi[WG_PUBLIC_KEY_LEN];
+ WgPublicKey spubi;
uint8 e_remote[WG_PUBLIC_KEY_LEN];
uint8 hi2[WG_HASH_LEN];
};
@@ -488,13 +591,13 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
// (Ci, K) := KDF2(Ci, DH(spriv, msg.ephemeral))
ComputeHKDF2DH(ci, k, dev->s_priv_, src->ephemeral);
// Spub_i = AEAD_DEC(K, 0, msg.static, Hi)
- if (!chacha20poly1305_decrypt(spubi, src->static_enc, sizeof(src->static_enc), hi, sizeof(hi), 0, k))
+ if (!chacha20poly1305_decrypt(spubi.bytes, src->static_enc, sizeof(src->static_enc), hi, sizeof(hi), 0, k))
goto getout;
// Hi := HASH(Hi || msg.static)
BlakeMix(hi, src->static_enc, sizeof(src->static_enc));
// Lookup the peer with this ID
while ((peer = dev->GetPeerFromPublicKey(spubi)) == NULL) {
- if (dev->delegate_ == NULL || !dev->delegate_->HandleUnknownPeerId(spubi, packet))
+ if (dev->delegate_ == NULL || !dev->delegate_->HandleUnknownPeerId(spubi.bytes, packet))
goto getout;
}
// (Ci, K) := KDF2(Ci, DH(sprivr, spubi))
@@ -538,7 +641,7 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
// Ci : = KDF2(Ci, DH(epriv, epub))
ComputeHKDF2DH(ci, NULL, e_priv, e_remote);
// Ci : = KDF2(Ci, DH(epriv, spub))
- ComputeHKDF2DH(ci, NULL, e_priv, peer->s_remote_);
+ ComputeHKDF2DH(ci, NULL, e_priv, peer->s_remote_.bytes);
// (Ci, T, K) := KDF3(Ci, Q)
blake2s_hkdf(ci, sizeof(ci), t, sizeof(t), k, sizeof(k), peer->preshared_key_, sizeof(preshared_key_), ci, WG_HASH_LEN);
// Hr := HASH(Hr || T)
@@ -548,11 +651,6 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
keypair = WgPeer::CreateNewKeypair(false, ci, remote_key_id, extbuf + WG_TIMESTAMP_LEN, extfield_size);
if (keypair) {
- WG_ACQUIRE_LOCK(peer->mutex_);
- peer->InsertKeypairInPeer_Locked(keypair);
- peer->OnHandshakeAuthComplete();
- WG_RELEASE_LOCK(peer->mutex_);
-
dst->sender_key_id = dev->InsertInKeyIdLookup(peer, keypair);
size_t extfield_out_size = 0;
@@ -560,8 +658,17 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
if (extfield_size)
extfield_out_size = peer->WriteHandshakeExtension(dst->empty_enc, keypair);
#endif // WITH_HANDSHAKE_EXT
+
+ uint32 orig_packet_size = packet->size;
packet->size = (unsigned)(sizeof(MessageHandshakeResponse) + extfield_out_size);
+ WG_ACQUIRE_LOCK(peer->mutex_);
+ peer->rx_bytes_ += orig_packet_size;
+ peer->tx_bytes_ += packet->size;
+ peer->InsertKeypairInPeer_Locked(keypair);
+ peer->OnHandshakeAuthComplete();
+ WG_RELEASE_LOCK(peer->mutex_);
+
// msg.empty := AEAD(K, 0, "", Hr)
chacha20poly1305_encrypt(dst->empty_enc, dst->empty_enc, extfield_out_size, hi, sizeof(hi), 0, k);
// Hr := HASH(Hr || "")
@@ -624,6 +731,7 @@ WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packe
peer_and_keypair->second = keypair;
WG_ACQUIRE_LOCK(peer->mutex_);
+ peer->rx_bytes_ += packet->size;
peer->InsertKeypairInPeer_Locked(keypair);
WG_RELEASE_LOCK(peer->mutex_);
@@ -651,6 +759,9 @@ void WgPeer::ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCo
if (!xchacha20poly1305_decrypt(cookie, src->cookie_enc, sizeof(src->cookie_enc),
peer->sent_mac1_, sizeof(peer->sent_mac1_), src->nonce, peer->precomputed_cookie_key_))
return;
+ WG_ACQUIRE_LOCK(peer->mutex_);
+ peer->rx_bytes_ += sizeof(MessageHandshakeCookie);
+ WG_RELEASE_LOCK(peer->mutex_);
peer->expect_cookie_reply_ = false;
peer->has_mac2_cookie_ = true;
peer->mac2_cookie_timestamp_ = OsGetMilliseconds();
@@ -796,7 +907,7 @@ bool WgKeypairParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size
#endif // WITH_HANDSHAKE_EXT
-static void ActualFreeKeypair(void *x) {
+static void WgKeypairDelayedDelete(void *x) {
WgKeypair *t = (WgKeypair*)x;
if (t->aes_gcm128_context_)
free(t->aes_gcm128_context_);
@@ -808,17 +919,18 @@ void WgPeer::DeleteKeypair(WgKeypair **kp) {
*kp = NULL;
if (t) {
assert(t->peer->IsPeerLocked());
+ WgDevice *dev = t->peer->dev_;
if (t->addr_entry) {
- WG_SCOPED_RWLOCK_EXCLUSIVE(dev_->addr_entry_lookup_lock_);
- dev_->EraseKeypairAddrEntry_Locked(t);
+ WG_SCOPED_RWLOCK_EXCLUSIVE(dev->addr_entry_lookup_lock_);
+ dev->EraseKeypairAddrEntry_Locked(t);
}
if (t->local_key_id) {
- WG_SCOPED_RWLOCK_EXCLUSIVE(dev_->key_id_lookup_lock_);
- dev_->key_id_lookup_.erase(t->local_key_id);
+ WG_SCOPED_RWLOCK_EXCLUSIVE(dev->key_id_lookup_lock_);
+ dev->key_id_lookup_.erase(t->local_key_id);
t->local_key_id = 0;
}
t->recv_key_state = WgKeypair::KEY_INVALID;
- dev_->delayed_delete_.Add(&ActualFreeKeypair, t);
+ dev->delayed_delete_.Add(&WgKeypairDelayedDelete, t);
}
}
@@ -1029,8 +1141,8 @@ void WgPeer::OnHandshakeFullyComplete() {
}
// Check if any of the timeouts have expired
-uint32 WgPeer::CheckTimeouts(uint64 now) {
- assert(IsPeerLocked());
+uint32 WgPeer::CheckTimeouts_Locked(uint64 now) {
+ assert(dev_->IsMainThread() && IsPeerLocked());
uint32 t, rv = 0;
@@ -1096,7 +1208,7 @@ uint32 WgPeer::CheckTimeouts(uint64 now) {
// Check all key stuff here to avoid calling possibly expensive timestamp routines in the packet handler
void WgPeer::CheckAndUpdateTimeOfNextKeyEvent(uint64 now) {
- assert(IsPeerLocked());
+ assert(dev_->IsMainThread() && IsPeerLocked());
uint64 next_time = UINT64_MAX;
uint32 rv = 0;
@@ -1142,34 +1254,60 @@ void WgPeer::SetEndpoint(const IpAddr &sin) {
endpoint_ = sin;
}
-void WgPeer::SetPersistentKeepalive(int persistent_keepalive_secs) {
- if (persistent_keepalive_secs < 10 || persistent_keepalive_secs > 10000)
- return;
+bool WgPeer::SetPersistentKeepalive(int persistent_keepalive_secs) {
+ if (persistent_keepalive_secs < 0 || persistent_keepalive_secs > 65535)
+ return false;
persistent_keepalive_ms_ = persistent_keepalive_secs * 1000;
+ return true;
+}
+
+bool WgCidrAddrEquals(const WgCidrAddr &a, const WgCidrAddr &b) {
+ return (a.size == b.size && a.cidr == b.cidr && memcmp(a.addr, b.addr, a.size >> 3) == 0);
}
bool WgPeer::AddIp(const WgCidrAddr &cidr_addr) {
+ WgPeer *old_peer;
assert(dev_->IsMainThread());
if (cidr_addr.size == 32) {
if (cidr_addr.cidr > 32)
return false;
WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
- dev_->ip_to_peer_map_.InsertV4(ReadBE32(cidr_addr.addr), cidr_addr.cidr, this);
+ old_peer = (WgPeer*)dev_->ip_to_peer_map_.InsertV4(ReadBE32(cidr_addr.addr), cidr_addr.cidr, this);
WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
- allowed_ips_.push_back(cidr_addr);
- return true;
} else if (cidr_addr.size == 128) {
if (cidr_addr.cidr > 128)
return false;
WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
- dev_->ip_to_peer_map_.InsertV6(cidr_addr.addr, cidr_addr.cidr, this);
+ old_peer = (WgPeer*)dev_->ip_to_peer_map_.InsertV6(cidr_addr.addr, cidr_addr.cidr, this);
WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
- allowed_ips_.push_back(cidr_addr);
- return true;
} else {
return false;
}
+ if (old_peer) {
+ for (auto it = old_peer->allowed_ips_.begin(); it != old_peer->allowed_ips_.end(); ++it) {
+ if (WgCidrAddrEquals(*it, cidr_addr)) {
+ old_peer->allowed_ips_.erase(it);
+ break;
+ }
+ }
+ }
+ allowed_ips_.push_back(cidr_addr);
+ return true;
+}
+
+void WgPeer::RemoveAllIps() {
+ assert(dev_->IsMainThread());
+ WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
+ for (auto it = allowed_ips_.begin(); it != allowed_ips_.end(); ++it) {
+ if (it->size == 32) {
+ dev_->ip_to_peer_map_.RemoveV4(ReadBE32(it->addr), it->cidr);
+ } else if (it->size == 128) {
+ dev_->ip_to_peer_map_.RemoveV6(it->addr, it->cidr);
+ }
+ }
+ WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
+ allowed_ips_.clear();
}
void WgPeer::SetAllowMulticast(bool allow) {
@@ -1196,6 +1334,18 @@ bool WgPeer::AddCipher(int cipher) {
return true;
}
+void WgPeer::ScheduleNewHandshake() {
+ // Note, it's possible that the peer has already been marked for delete
+ if (main_thread_scheduled_.fetch_or(WgPeer::kMainThreadScheduled_ScheduleHandshake) == 0) {
+ main_thread_scheduled_next_ = NULL;
+ WG_ACQUIRE_LOCK(dev_->main_thread_scheduled_lock_);
+ *dev_->main_thread_scheduled_last_ = this;
+ dev_->main_thread_scheduled_last_ = &main_thread_scheduled_next_;
+ WG_RELEASE_LOCK(dev_->main_thread_scheduled_lock_);
+ // todo: in multithreaded impl need to trigger |RunAllMainThreadScheduled| to get called
+ }
+}
+
WgRateLimit::WgRateLimit() {
key1_[0] = key1_[1] = 1;
key2_[0] = key2_[1] = 1;
diff --git a/wireguard_proto.h b/wireguard_proto.h
index 36f7b51..15d1432 100644
--- a/wireguard_proto.h
+++ b/wireguard_proto.h
@@ -11,7 +11,7 @@
#include
#include
#include
-
+#include
// Threading macros that enable locks only in MT builds
#if WITH_WG_THREADING
#define WG_SCOPED_LOCK(name) ScopedLock scoped_lock(&name)
@@ -25,6 +25,7 @@
#define WG_RELEASE_RWLOCK_EXCLUSIVE(name) name.ReleaseExclusive()
#define WG_SCOPED_RWLOCK_SHARED(name) ScopedLockShared scoped_lock(&name)
#define WG_SCOPED_RWLOCK_EXCLUSIVE(name) ScopedLockExclusive scoped_lock(&name)
+#define WG_IF_LOCKS_ENABLED_ELSE(expr, def) (expr)
#else // WITH_WG_THREADING
#define WG_SCOPED_LOCK(name)
#define WG_ACQUIRE_LOCK(name)
@@ -37,6 +38,7 @@
#define WG_RELEASE_RWLOCK_EXCLUSIVE(name)
#define WG_SCOPED_RWLOCK_SHARED(name)
#define WG_SCOPED_RWLOCK_EXCLUSIVE(name)
+#define WG_IF_LOCKS_ENABLED_ELSE(expr, def) (def)
#endif // WITH_WG_THREADING
enum ProtocolTimeouts {
@@ -77,6 +79,7 @@ enum MessageFieldSizes {
WG_MAC_LEN = 16,
WG_TIMESTAMP_LEN = 12,
WG_SIPHASH_KEY_LEN = 16,
+ WG_PUBLIC_KEY_LEN_BASE64 = 44,
};
enum {
@@ -194,11 +197,9 @@ struct WgPacketCompressionVer01 {
};
STATIC_ASSERT(sizeof(WgPacketCompressionVer01) == 24, WgPacketCompressionVer01_wrong_size);
-
struct WgKeypair;
class WgPeer;
-
class WgRateLimit {
public:
WgRateLimit();
@@ -260,10 +261,26 @@ struct WgAddrEntry {
struct ScramblerSiphashKeys {
uint64 keys[4];
};
-
+
+union WgPublicKey {
+ uint8 bytes[WG_PUBLIC_KEY_LEN];
+ uint64 u64[WG_PUBLIC_KEY_LEN / 8];
+ friend bool operator==(const WgPublicKey &a, const WgPublicKey &b) {
+ return memcmp(a.bytes, b.bytes, WG_PUBLIC_KEY_LEN) == 0;
+ }
+};
+
+struct WgPublicKeyHasher {
+ size_t operator()(const WgPublicKey&a) const {
+ uint64 rv = a.u64[0] ^ a.u64[1] ^ a.u64[2] ^ a.u64[3];
+ return (size_t)(rv ^ (rv >> 32));
+ }
+};
+
class WgDevice {
friend class WgPeer;
friend class WireguardProcessor;
+ friend class WgConfig;
public:
// Can be used to customize the behavior of WgDevice
@@ -278,12 +295,15 @@ public:
WgDevice();
~WgDevice();
- // Initialize with the private key, precompute all internal keys etc.
- void Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]);
+ // Configure with the private key, precompute all internal keys etc.
+ void SetPrivateKey(const uint8 private_key[WG_PUBLIC_KEY_LEN]);
// Create a new peer
WgPeer *AddPeer();
+ // Remove all peers
+ void RemoveAllPeers();
+
// Setup header obfuscation
void SetHeaderObfuscation(const char *key);
@@ -303,17 +323,19 @@ public:
WgRateLimit *rate_limiter() { return &rate_limiter_; }
std::unordered_map &addr_entry_map() { return addr_entry_lookup_; }
WgPacketCompressionVer01 *compression_header() { return &compression_header_; }
+ bool is_private_key_initialized() { return is_private_key_initialized_; }
bool IsMainThread() { return CurrentThreadIdEquals(main_thread_id_); }
- void SetCurrentThreadAsMainThread() { main_thread_id_ = GetCurrentThreadId(); }
+ bool IsMainOrDataThread() { return CurrentThreadIdEquals(main_thread_id_) || WG_IF_LOCKS_ENABLED_ELSE(delayed_delete_.enabled(), false); }
void SetDelegate(Delegate *del) { delegate_ = del; }
+
private:
std::pair *LookupPeerInKeyIdLookup(uint32 key_id);
WgKeypair *LookupKeypairByKeyId(uint32 key_id);
WgKeypair *LookupKeypairInAddrEntryMap(uint64 addr, uint32 slot);
// Return the peer matching the |public_key| or NULL
- WgPeer *GetPeerFromPublicKey(uint8 public_key[WG_PUBLIC_KEY_LEN]);
+ WgPeer *GetPeerFromPublicKey(const WgPublicKey &pubkey);
// Create a cookie by inspecting the source address of the |packet|
void MakeCookie(uint8 cookie[WG_COOKIE_LEN], Packet *packet);
// Insert a new entry in |key_id_lookup_|
@@ -330,7 +352,7 @@ private:
WG_DECLARE_RWLOCK(ip_to_peer_map_lock_);
// For enumerating all peers
- WgPeer *peers_;
+ WgPeer *peers_, **last_peer_ptr_;
// For hooking
Delegate *delegate_;
@@ -346,12 +368,22 @@ private:
std::unordered_map addr_entry_lookup_;
WG_DECLARE_RWLOCK(addr_entry_lookup_lock_);
+ // Mapping from peer id to peer. This may be accessed only from MT.
+ std::unordered_map peer_id_lookup_;
+
+ // Queue of things scheduled to run on the main thread.
+ WG_DECLARE_LOCK(main_thread_scheduled_lock_);
+ WgPeer *main_thread_scheduled_, **main_thread_scheduled_last_;
+
// Counter for generating new indices in |keypair_lookup_|
uint8 next_rng_slot_;
// Whether packet obfuscation is enabled
bool header_obfuscation_;
+ // Whether a private key has been setup for the device
+ bool is_private_key_initialized_;
+
ThreadId main_thread_id_;
uint64 low_resolution_timestamp_;
@@ -382,15 +414,16 @@ private:
class WgPeer {
friend class WgDevice;
friend class WireguardProcessor;
+ friend class WgConfig;
friend bool WgKeypairParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size_t data_size);
friend void WgKeypairSetupCompressionExtension(WgKeypair *keypair, const WgPacketCompressionVer01 *remotec);
public:
explicit WgPeer(WgDevice *dev);
~WgPeer();
- void Initialize(const uint8 spub[WG_PUBLIC_KEY_LEN], const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]);
-
- void SetPersistentKeepalive(int persistent_keepalive_secs);
+ void SetPublicKey(const WgPublicKey &spub);
+ void SetPresharedKey(const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]);
+ bool SetPersistentKeepalive(int persistent_keepalive_secs);
void SetEndpoint(const IpAddr &sin);
void SetAllowMulticast(bool allow);
@@ -398,15 +431,14 @@ public:
bool AddCipher(int cipher);
void SetCipherPrio(bool prio) { cipher_prio_ = prio; }
bool AddIp(const WgCidrAddr &cidr_addr);
+ void RemoveAllIps();
static WgPeer *ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet);
static WgPeer *ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packet);
static void ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCookie *src);
void CreateMessageHandshakeInitiation(Packet *packet);
bool CheckSwitchToNextKey_Locked(WgKeypair *keypair);
- void ClearKeys_Locked();
- void ClearHandshake_Locked();
- void ClearPacketQueue_Locked();
+ void RemovePeer();
bool CheckHandshakeRateLimit();
// Timer notifications
@@ -422,25 +454,25 @@ public:
ACTION_SEND_KEEPALIVE = 1,
ACTION_SEND_HANDSHAKE = 2,
};
- uint32 CheckTimeouts(uint64 now);
+ uint32 CheckTimeouts_Locked(uint64 now);
- void AddPacketToPeerQueue(Packet *packet);
-
-#if WITH_WG_THREADING
- bool IsPeerLocked() { return mutex_.IsLocked(); }
-#else // WITH_WG_THREADING
- bool IsPeerLocked() { return true; }
-#endif // WITH_WG_THREADING
+ void AddPacketToPeerQueue_Locked(Packet *packet);
+ bool IsPeerLocked() { return WG_IF_LOCKS_ENABLED_ELSE(mutex_.IsLocked(), true); }
private:
static WgKeypair *CreateNewKeypair(bool is_initiator, const uint8 key[WG_HASH_LEN], uint32 send_key_id, const uint8 *extfield, size_t extfield_size);
void WriteMacToPacket(const uint8 *data, MessageMacs *mac);
- void DeleteKeypair(WgKeypair **kp);
void CheckAndUpdateTimeOfNextKeyEvent(uint64 now);
+ static void DeleteKeypair(WgKeypair **kp);
static void CopyEndpointToPeer_Locked(WgKeypair *keypair, const IpAddr *addr);
+ static void DelayedDelete(void *x);
size_t WriteHandshakeExtension(uint8 *dst, WgKeypair *keypair);
void InsertKeypairInPeer_Locked(WgKeypair *keypair);
-
+ void ClearKeys_Locked();
+ void ClearHandshake_Locked();
+ void ClearPacketQueue_Locked();
+ void ScheduleNewHandshake();
+
WgDevice *dev_;
WgPeer *next_peer_;
@@ -492,6 +524,10 @@ private:
// Whether |mac2_cookie_| is valid.
bool has_mac2_cookie_;
+ // Whether the WgPeer has been deleted (i.e. RemovePeer has been called),
+ // and will be deleted as soon as the threads sync.
+ bool marked_for_delete_;
+
// Number of handshakes made so far, when this gets too high we stop connecting.
uint8 handshake_attempts_;
@@ -517,7 +553,10 @@ private:
uint8 cipher_prio_;
uint8 num_ciphers_;
uint8 ciphers_[MAX_CIPHERS];
-
+
+ uint64 rx_bytes_;
+ uint64 tx_bytes_;
+
// Handshake state that gets setup in |CreateMessageHandshakeInitiation| and used in
// the response.
struct HandshakeState {
@@ -530,7 +569,7 @@ private:
};
HandshakeState hs_;
// Remote's static public key - init only.
- uint8 s_remote_[WG_PUBLIC_KEY_LEN];
+ WgPublicKey s_remote_;
// Remote's preshared key - init only.
uint8 preshared_key_[WG_SYMMETRIC_KEY_LEN];
// Precomputed DH(spriv_local, spub_remote) - init only.