Support multiple interfaces and the 'ts' command line tool

This commit is contained in:
Ludvig Strigeus 2018-09-15 18:22:05 +02:00
parent 7b7fb6126b
commit 6d916e9aaa
39 changed files with 4243 additions and 1530 deletions

12
.gitignore vendored
View file

@ -1,18 +1,10 @@
/Debug/
/Release/
/ipzip2/Debug/
/Build
/Win32/
/TunSafe.aps
/*.sdf
/*vcxproj.user
*vcxproj.user
/*.opensdf
/*.suo
/.vs/
/x64/
/Azire.conf
/build/
/*.psess
/*.vspx
/installer/*.zip
/config/
/tunsafe.com/

View file

@ -5,6 +5,8 @@ VisualStudioVersion = 15.0.26403.7
MinimumVisualStudioVersion = 10.0.40219.1
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "TunSafe", "TunSafe.vcxproj", "{626FBC16-64C6-407D-BC2B-6C087794E0D0}"
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ts", "ts.vcxproj", "{443E105E-8D7C-401F-BD41-D3F56C76104B}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Win32 = Debug|Win32
@ -21,8 +23,19 @@ Global
{626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|Win32.Build.0 = Release|Win32
{626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|x64.ActiveCfg = Release|x64
{626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|x64.Build.0 = Release|x64
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|Win32.ActiveCfg = Debug|Win32
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|Win32.Build.0 = Debug|Win32
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|x64.ActiveCfg = Debug|x64
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|x64.Build.0 = Debug|x64
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|Win32.ActiveCfg = Release|Win32
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|Win32.Build.0 = Release|Win32
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|x64.ActiveCfg = Release|x64
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|x64.Build.0 = Release|x64
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {F2DD9ED8-DDEA-4B40-9208-41726750D33D}
EndGlobalSection
EndGlobal

View file

@ -22,7 +22,7 @@
<ProjectGuid>{626FBC16-64C6-407D-BC2B-6C087794E0D0}</ProjectGuid>
<Keyword>Win32Proj</Keyword>
<RootNamespace>TunSafe</RootNamespace>
<WindowsTargetPlatformVersion>10.0.15063.0</WindowsTargetPlatformVersion>
<WindowsTargetPlatformVersion>10.0.17134.0</WindowsTargetPlatformVersion>
<ProjectName>TunSafe</ProjectName>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
@ -72,24 +72,28 @@
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<LinkIncremental>true</LinkIncremental>
<TargetName>TunSafe</TargetName>
<OutDir>$(SolutionDir)$(Platform)\$(Configuration)\</OutDir>
<IntDir>$(Platform)\$(Configuration)\</IntDir>
<OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
<IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental>
<ExecutablePath>$(VC_ExecutablePath_x64);$(WindowsSDK_ExecutablePath);$(VS_ExecutablePath);$(MSBuild_ExecutablePath);$(FxCopDir);$(PATH);C:\Bin\Dev\nasm</ExecutablePath>
<TargetName>TunSafe</TargetName>
<OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
<IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<LinkIncremental>false</LinkIncremental>
<TargetName>TunSafe</TargetName>
<OutDir>$(SolutionDir)$(Platform)\$(Configuration)\</OutDir>
<IntDir>$(Platform)\$(Configuration)\</IntDir>
<OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
<IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
<ExecutablePath>$(VC_ExecutablePath_x64);$(WindowsSDK_ExecutablePath);$(VS_ExecutablePath);$(MSBuild_ExecutablePath);$(FxCopDir);$(PATH);C:\Bin\Dev\nasm</ExecutablePath>
<TargetName>TunSafe</TargetName>
<IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
<OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<ClCompile>
@ -98,6 +102,7 @@
<Optimization>Disabled</Optimization>
<PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions);_CRT_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS</PreprocessorDefinitions>
<AdditionalIncludeDirectories>.</AdditionalIncludeDirectories>
<ExceptionHandling>false</ExceptionHandling>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
@ -114,6 +119,7 @@
<ForcedIncludeFiles>
</ForcedIncludeFiles>
<AdditionalIncludeDirectories>.</AdditionalIncludeDirectories>
<ExceptionHandling>false</ExceptionHandling>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
@ -133,6 +139,8 @@
<PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions);_CRT_SECURE_NO_WARNINGS</PreprocessorDefinitions>
<RuntimeLibrary>MultiThreaded</RuntimeLibrary>
<AdditionalIncludeDirectories>.</AdditionalIncludeDirectories>
<ExceptionHandling>false</ExceptionHandling>
<RuntimeTypeInfo>false</RuntimeTypeInfo>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
@ -157,6 +165,8 @@
<InlineFunctionExpansion>AnySuitable</InlineFunctionExpansion>
<OmitFramePointers>true</OmitFramePointers>
<AdditionalIncludeDirectories>.</AdditionalIncludeDirectories>
<ExceptionHandling>false</ExceptionHandling>
<RuntimeTypeInfo>false</RuntimeTypeInfo>
</ClCompile>
<Link>
<SubSystem>Windows</SubSystem>
@ -169,8 +179,10 @@
<ItemGroup>
<ClInclude Include="bit_ops.h" />
<ClInclude Include="ip_to_peer_map.h" />
<ClInclude Include="service_pipe_win32.h" />
<ClInclude Include="service_win32.h" />
<ClInclude Include="service_win32_api.h" />
<ClInclude Include="service_win32_constants.h" />
<ClInclude Include="tunsafe_config.h" />
<ClInclude Include="tunsafe_cpu.h" />
<ClInclude Include="crypto\aesgcm\aes.h" />
@ -196,6 +208,7 @@
<ItemGroup>
<ClCompile Include="benchmark.cpp" />
<ClCompile Include="ip_to_peer_map.cpp" />
<ClCompile Include="service_pipe_win32.cpp" />
<ClCompile Include="service_win32.cpp" />
<ClCompile Include="tunsafe_cpu.cpp" />
<ClCompile Include="crypto\aesgcm\aesgcm.cpp" />

View file

@ -89,6 +89,12 @@
<ClInclude Include="util_win32.h">
<Filter>Source Files\Win32</Filter>
</ClInclude>
<ClInclude Include="service_pipe_win32.h">
<Filter>Source Files\Win32</Filter>
</ClInclude>
<ClInclude Include="service_win32_constants.h">
<Filter>Source Files\Win32</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="stdafx.cpp">
@ -154,6 +160,9 @@
<ClCompile Include="ip_to_peer_map.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="service_pipe_win32.cpp">
<Filter>Source Files\Win32</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<ResourceCompile Include="TunSafe.rc" />

View file

@ -1,4 +1,4 @@
g++7 -I . -O2 -DNDEBUG -static -mssse3 -o tunsafe benchmark.cpp tunsafe_cpu.cpp wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \
wireguard.cpp wireguard_proto.cpp util.cpp network_bsd.cpp network_bsd_common.cpp \
wireguard.cpp wireguard_proto.cpp ts.cpp util.cpp network_bsd.cpp network_bsd_common.cpp \
crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp crypto/curve25519-donna.cpp \
crypto/siphash.cpp crypto/chacha20_x64_gas.s crypto/poly1305_x64_gas.s ipzip2/ipzip2.cpp -lrt -pthread

View file

@ -1,6 +1,6 @@
#!/bin/sh
clang++-6.0 -c -march=skylake-avx512 crypto/poly1305_x64_gas.s crypto/chacha20_x64_gas.s
clang++-6.0 -I . -O3 -DNDEBUG -mssse3 -pthread -lrt -o tunsafe util.cpp wireguard_config.cpp wireguard.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \
clang++-6.0 -I . -O3 -DNDEBUG -mssse3 -pthread -lrt -o tunsafe util.cpp wireguard_config.cpp wireguard.cpp ts.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \
wireguard_proto.cpp network_bsd.cpp network_bsd_common.cpp tunsafe_cpu.cpp benchmark.cpp crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp \
crypto/curve25519-donna.cpp crypto/siphash.cpp chacha20_x64_gas.o crypto/aesgcm/aesni_gcm_x64_gas.s \
crypto/aesgcm/aesni_x64_gas.s crypto/aesgcm/aesgcm.cpp poly1305_x64_gas.o ipzip2/ipzip2.cpp \

View file

@ -3,8 +3,8 @@ set -e
clang++ -c -mavx512f -mavx512vl crypto/poly1305_x64_gas_macosx.s crypto/chacha20_x64_gas_macosx.s
clang++ -g -O3 -I . -std=c++11 -DNDEBUG=1 -fno-exceptions -fno-rtti -ffunction-sections -o tunsafe \
wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp wireguard.cpp wireguard_proto.cpp util.cpp network_bsd.cpp network_bsd_common.cpp benchmark.cpp tunsafe_cpu.cpp \
clang++ -g -O3 -I . -std=c++11 -DNDEBUG=1 -Wno-deprecated-declarations -fno-exceptions -fno-rtti -ffunction-sections -o tunsafe \
wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp wireguard.cpp wireguard_proto.cpp ts.cpp util.cpp network_bsd.cpp network_bsd_common.cpp benchmark.cpp tunsafe_cpu.cpp \
crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp crypto/curve25519-donna.cpp \
crypto/siphash.cpp crypto/aesgcm/aesgcm.cpp ipzip2/ipzip2.cpp \
crypto/aesgcm/aesni_gcm_x64_gas_macosx.s crypto/aesgcm/aesni_x64_gas_macosx.s crypto/aesgcm/ghash_x64_gas_macosx.s \

View file

@ -57,10 +57,12 @@ Section "TunSafe Client" SecTunSafe
DetailPrint "Installing 64-bit version of TunSafe."
SetOutPath "$INSTDIR"
File "x64\TunSafe.exe"
File "x64\ts.exe"
${Else}
DetailPrint "Installing 32-bit version of TunSafe."
SetOutPath "$INSTDIR"
File "x86\TunSafe.exe"
File "x86\ts.exe"
${EndIf}
File "License.txt"
File "ChangeLog.txt"
@ -205,6 +207,7 @@ Section "Uninstall"
Delete "$INSTDIR\TunSafe.exe"
Delete "$INSTDIR\ts.exe"
Delete "$INSTDIR\License.txt"
Delete "$INSTDIR\ChangeLog.txt"
Delete "$INSTDIR\Config\TunSafe.conf"

View file

@ -5,6 +5,8 @@
#include "bit_ops.h"
#include <string.h>
#include <assert.h>
#include <stdlib.h>
#include "util.h"
IpToPeerMap::IpToPeerMap() {
@ -13,18 +15,22 @@ IpToPeerMap::IpToPeerMap() {
IpToPeerMap::~IpToPeerMap() {
}
bool IpToPeerMap::InsertV4(uint32 ip, int cidr, void *peer) {
ipv4_.Insert(ip, cidr, peer);
return true;
void *IpToPeerMap::InsertV4(uint32 ip, int cidr, void *peer) {
ipv4_.Insert(ip, cidr, &peer);
return peer;
}
bool IpToPeerMap::InsertV6(const void *addr, int cidr, void *peer) {
void *IpToPeerMap::InsertV6(const void *addr, int cidr, void *peer) {
Entry6 e;
for (auto it = ipv6_.begin(); it != ipv6_.end(); ++it) {
if (it->cidr_len == cidr && memcmp(it->ip, addr, 16) == 0)
return exch(it->peer, peer);
}
e.cidr_len = cidr;
e.peer = peer;
memcpy(e.ip, addr, 16);
ipv6_.push_back(e);
return true;
return NULL;
}
void *IpToPeerMap::LookupV4(uint32 ip) {
@ -43,6 +49,19 @@ void *IpToPeerMap::LookupV6DefaultPeer() {
return NULL;
}
void IpToPeerMap::RemoveV4(uint32 ip, int cidr) {
ipv4_.Delete(ip, cidr);
}
void IpToPeerMap::RemoveV6(const void *addr, int cidr) {
for (auto it = ipv6_.begin(); it != ipv6_.end(); ++it) {
if (it->cidr_len == cidr && memcmp(it->ip, addr, 16) == 0) {
ipv6_.erase(it);
return;
}
}
}
static int CalculateIPv6CommonPrefix(const uint8 *a, const uint8 *b) {
uint64 x = ToBE64(*(uint64*)&a[0] ^ *(uint64*)&b[0]);
uint64 y = ToBE64(*(uint64*)&a[8] ^ *(uint64*)&b[8]);
@ -62,20 +81,6 @@ void *IpToPeerMap::LookupV6(const void *addr) {
return best_peer;
}
void IpToPeerMap::RemovePeer(void *peer) {
assert(0);
// todo: remove peer also from ipv4_
{
size_t n = ipv6_.size();
Entry6 *r = &ipv6_[0], *w = r;
for (size_t i = 0; i != n; i++, r++) {
if (r->peer != peer)
*w++ = *r;
}
ipv6_.resize(w - &ipv6_[0]);
}
}
#pragma warning (disable: 4200) // warning C4200: nonstandard extension used: zero-sized array in struct/union
struct RoutingTrie32::Node {
uint32 key;
@ -175,7 +180,7 @@ RoutingTrie32::~RoutingTrie32() {
RoutingTrie32::Value RoutingTrie32::Lookup(uint32 ip) {
uint32 key = ip;
Node *n = root_, *pn = n, *ppn;
int cindex = 0;
uint32 cindex = 0;
if (!n)
return NULL;
// Find the longest prefix match
@ -232,7 +237,7 @@ backtrace:
}
// strip lsb of cindex and find child
cindex &= cindex - 1;
assert(cindex < (1 << pn->bits));
assert(cindex < (1U << pn->bits));
n = pn->child[cindex];
if (!NODE_IS_NULL_OR_OLEAF(n))
break;
@ -246,7 +251,7 @@ backtrace:
}
}
bool RoutingTrie32::InsertLeafInto(Node **nn, uint8 leaf_pos, Value value) {
bool RoutingTrie32::InsertLeafInto(Node **nn, uint8 leaf_pos, Value *valuep) {
// put higher cidr higher up
Node *n = *nn;
assert(IS_LEAF(n));
@ -255,12 +260,12 @@ bool RoutingTrie32::InsertLeafInto(Node **nn, uint8 leaf_pos, Value value) {
if (leaf_pos < n->pos)
break;
if (leaf_pos == n->pos) {
n->leaf_value = value;
std::swap(n->leaf_value, *valuep);
return true;
}
nn = &n->leaf_next;
} while ((n = *nn) != NULL);
Node *leaf = NewLeaf(key, leaf_pos, value);
Node *leaf = NewLeaf(key, leaf_pos, *valuep);
if (leaf == NULL)
return false;
leaf->leaf_next = *nn;
@ -283,14 +288,14 @@ void RoutingTrie32::PutChild(Node *pn, uint32 i, Node *n) {
assert(pn->full_children < 0x80000000);
}
bool RoutingTrie32::Insert(uint32 ip, int cidr, Value value) {
bool RoutingTrie32::Insert(uint32 ip, int cidr, Value *valuep) {
uint32 key = ip;
Node **nn = &root_, *n = root_, *pn = NULL, *leaf, *tn = NULL, *leaf_to_free = NULL;
uint8 leaf_pos = 32 - cidr;
if (n == NULL) {
root_ = NewLeaf(key, leaf_pos, value);
return (root_ != NULL);
root_ = NewLeaf(key, leaf_pos, exch_null(*valuep));
return false;
}
assert(!NODE_IS_OLEAF(n));
@ -316,7 +321,7 @@ force_add:
if (IS_LEAF(n)) {
if (key != n->key)
goto force_add;
return InsertLeafInto(nn, leaf_pos, value);
return InsertLeafInto(nn, leaf_pos, valuep);
}
pn = n;
nn = &n->child[index];
@ -330,6 +335,7 @@ force_add:
*nn = n;
}
}
Value value = *valuep;
// Create either leaf or oleaf
if (tn->pos == leaf_pos) {
leaf = VALUE_TO_OLEAF(value);
@ -338,8 +344,8 @@ force_add:
FreeNode(tn);
return false;
}
// -- Start making irreversible changes here
*valuep = NULL;
if (leaf_to_free)
FreeNode(leaf_to_free);

View file

@ -15,7 +15,7 @@ public:
~RoutingTrie32();
NOINLINE Value Lookup(uint32 ip);
NOINLINE Value LookupExact(uint32 ip, int cidr);
bool Insert(uint32 ip, int cidr, Value value);
bool Insert(uint32 ip, int cidr, Value *value);
bool Delete(uint32 ip, int cidr);
private:
@ -31,7 +31,7 @@ private:
static void PutChild(Node *pn, uint32 i, Node *n);
static void ReplaceChild(Node **pnp, Node *n);
static Node *ConvertOleafToLeaf(Node *pn, uint32 i, Node *n);
static bool InsertLeafInto(Node **n, uint8 leaf_pos, RoutingTrie32::Value value);
static bool InsertLeafInto(Node **n, uint8 leaf_pos, RoutingTrie32::Value *value);
};
@ -43,8 +43,8 @@ public:
~IpToPeerMap();
// Inserts an IP address of a given CIDR length into the lookup table, pointing to peer.
bool InsertV4(uint32 ip, int cidr, void *peer);
bool InsertV6(const void *addr, int cidr, void *peer);
void *InsertV4(uint32 ip, int cidr, void *peer);
void *InsertV6(const void *addr, int cidr, void *peer);
// Lookup the peer matching the IP Address
void *LookupV4(uint32 ip);
@ -53,8 +53,8 @@ public:
void *LookupV4DefaultPeer();
void *LookupV6DefaultPeer();
// Remove a peer from the table
void RemovePeer(void *peer);
void RemoveV4(uint32 ip, int cidr);
void RemoveV6(const void *addr, int cidr);
private:
struct Entry6 {
uint8 ip[16];

View file

@ -18,7 +18,6 @@
#pragma warning (disable: 4200)
void OsGetRandomBytes(uint8 *dst, size_t dst_size);
uint64 OsGetMilliseconds();
void OsGetTimestampTAI64N(uint8 dst[12]);
void OsInterruptibleSleep(int millis);
@ -127,13 +126,13 @@ public:
uint8 neighbor_discovery_spoofing_mac[6];
};
virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) = 0;
virtual bool Configure(const TunConfig &&config, TunConfigOut *out) = 0;
virtual void WriteTunPacket(Packet *packet) = 0;
};
class UdpInterface {
public:
virtual bool Initialize(int listen_port) = 0;
virtual bool Configure(int listen_port) = 0;
virtual void WriteUdpPacket(Packet *packet) = 0;
};

View file

@ -13,10 +13,12 @@
#include <string.h>
#include <arpa/inet.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <stdlib.h>
#include <errno.h>
#include <assert.h>
#include <signal.h>
#include <poll.h>
static Packet *freelist;
@ -49,6 +51,7 @@ void FreePackets() {
}
}
class TunsafeBackendBsdImpl : public TunsafeBackendBsd {
public:
TunsafeBackendBsdImpl();
@ -61,7 +64,7 @@ public:
virtual void WriteTunPacket(Packet *packet) override;
// -- from UdpInterface
virtual bool Initialize(int listen_port) override;
virtual bool Configure(int listen_port) override;
virtual void WriteUdpPacket(Packet *packet) override;
virtual void HandleSigAlrm() override { got_sig_alarm_ = true; }
@ -72,12 +75,19 @@ private:
bool ReadFromTun();
bool WriteToUdp();
bool WriteToTun();
bool InitializeUnixDomainSocket(const char *devname);
// Exists for the unix domain sockets
struct SockInfo {
bool is_listener;
std::string inbuf, outbuf;
};
bool HandleSpecialPollfd(struct pollfd *pollfd, struct SockInfo *sockinfo);
void CloseSpecialPollfd(size_t i);
void SetUdpFd(int fd);
void SetTunFd(int fd);
inline void RecomputeMaxFd() { max_fd_ = ((tun_fd_>udp_fd_) ? tun_fd_ : udp_fd_) + 1; }
int tun_fd_, udp_fd_, max_fd_;
bool got_sig_alarm_;
bool exit_;
@ -89,13 +99,25 @@ private:
Packet *read_packet_;
fd_set readfds_, writefds_;
enum {
kMaxPollFd = 5,
kPollFdTun = 0,
kPollFdUdp = 1,
kPollFdUnix = 2,
};
unsigned int pollfd_num_;
struct pollfd pollfd_[kMaxPollFd];
struct SockInfo sockinfo_[kMaxPollFd - 2];
struct sockaddr_un un_addr_;
UnixSocketDeletionWatcher un_deletion_watcher_;
};
TunsafeBackendBsdImpl::TunsafeBackendBsdImpl()
: tun_fd_(-1),
udp_fd_(-1),
tun_readable_(false),
: tun_readable_(false),
tun_writable_(false),
udp_readable_(false),
udp_writable_(false),
@ -106,35 +128,39 @@ TunsafeBackendBsdImpl::TunsafeBackendBsdImpl()
udp_queue_(NULL),
udp_queue_end_(&udp_queue_),
read_packet_(NULL) {
RecomputeMaxFd();
FD_ZERO(&readfds_);
FD_ZERO(&writefds_);
read_packet_ = AllocPacket();
for(size_t i = 0; i < kMaxPollFd; i++)
pollfd_[i].fd = -1;
pollfd_num_ = 3;
sockinfo_[0].is_listener = true;
memset(&un_addr_, 0, sizeof(un_addr_));
}
TunsafeBackendBsdImpl::~TunsafeBackendBsdImpl() {
if (un_addr_.sun_path[0])
unlink(un_addr_.sun_path);
if (read_packet_)
FreePacket(read_packet_);
for(size_t i = 0; i < pollfd_num_; i++)
close(pollfd_[i].fd);
}
void TunsafeBackendBsdImpl::SetUdpFd(int fd) {
udp_fd_ = fd;
RecomputeMaxFd();
pollfd_[kPollFdUdp].fd = fd;
pollfd_[kPollFdUdp].events = POLLIN;
udp_writable_ = true;
}
void TunsafeBackendBsdImpl::SetTunFd(int fd) {
tun_fd_ = fd;
RecomputeMaxFd();
pollfd_[kPollFdTun].fd = fd;
pollfd_[kPollFdTun].events = POLLIN;
tun_writable_ = true;
}
bool TunsafeBackendBsdImpl::ReadFromUdp(bool overload) {
socklen_t sin_len;
sin_len = sizeof(read_packet_->addr.sin);
int r = recvfrom(udp_fd_, read_packet_->data, kPacketCapacity, 0,
int r = recvfrom(pollfd_[kPollFdUdp].fd, read_packet_->data, kPacketCapacity, 0,
(sockaddr*)&read_packet_->addr.sin, &sin_len);
if (r >= 0) {
// printf("Read %d bytes from UDP\n", r);
@ -157,11 +183,12 @@ bool TunsafeBackendBsdImpl::ReadFromUdp(bool overload) {
bool TunsafeBackendBsdImpl::WriteToUdp() {
assert(udp_writable_);
// RINFO("Send %d bytes to %s", (int)udp_queue_->size, inet_ntoa(udp_queue_->sin.sin_addr));
int r = sendto(udp_fd_, udp_queue_->data, udp_queue_->size, 0,
int r = sendto(pollfd_[kPollFdUdp].fd, udp_queue_->data, udp_queue_->size, 0,
(sockaddr*)&udp_queue_->addr.sin, sizeof(udp_queue_->addr.sin));
if (r < 0) {
if (errno == EAGAIN) {
udp_writable_ = false;
pollfd_[kPollFdUdp].events = POLLIN | POLLOUT;
return false;
}
perror("Write to UDP failed");
@ -185,7 +212,7 @@ static inline bool IsCompatibleProto(uint32 v) {
bool TunsafeBackendBsdImpl::ReadFromTun() {
assert(tun_readable_);
Packet *packet = read_packet_;
int r = read(tun_fd_, packet->data - TUN_PREFIX_BYTES, kPacketCapacity + TUN_PREFIX_BYTES);
int r = read(pollfd_[kPollFdTun].fd, packet->data - TUN_PREFIX_BYTES, kPacketCapacity + TUN_PREFIX_BYTES);
if (r >= 0) {
// printf("Read %d bytes from TUN\n", r);
packet->size = r - TUN_PREFIX_BYTES;
@ -215,10 +242,11 @@ bool TunsafeBackendBsdImpl::WriteToTun() {
if (TUN_PREFIX_BYTES) {
WriteBE32(tun_queue_->data - TUN_PREFIX_BYTES, GetProtoFromPacket(tun_queue_->data, tun_queue_->size));
}
int r = write(tun_fd_, tun_queue_->data - TUN_PREFIX_BYTES, tun_queue_->size + TUN_PREFIX_BYTES);
int r = write(pollfd_[kPollFdTun].fd, tun_queue_->data - TUN_PREFIX_BYTES, tun_queue_->size + TUN_PREFIX_BYTES);
if (r < 0) {
if (errno == EAGAIN) {
tun_writable_ = false;
pollfd_[kPollFdTun].events = POLLIN | POLLOUT;
return false;
}
RERROR("Write to tun failed");
@ -242,11 +270,13 @@ bool TunsafeBackendBsdImpl::InitializeTun(char devname[16]) {
fcntl(tun_fd, F_SETFD, FD_CLOEXEC);
fcntl(tun_fd, F_SETFL, O_NONBLOCK);
SetTunFd(tun_fd);
InitializeUnixDomainSocket(devname);
return true;
}
void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) override {
assert(tun_fd_ >= 0);
assert(pollfd_[kPollFdTun].fd >= 0);
Packet *queue_is_used = tun_queue_;
*tun_queue_end_ = packet;
tun_queue_end_ = &packet->next;
@ -256,7 +286,7 @@ void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) override {
}
// Called to initialize udp
bool TunsafeBackendBsdImpl::Initialize(int listen_port) override {
bool TunsafeBackendBsdImpl::Configure(int listen_port) override {
int udp_fd = open_udp(listen_port);
if (udp_fd < 0) { RERROR("Error opening udp"); return false; }
fcntl(udp_fd, F_SETFD, FD_CLOEXEC);
@ -266,7 +296,7 @@ bool TunsafeBackendBsdImpl::Initialize(int listen_port) override {
}
void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) override {
assert(udp_fd_ >= 0);
assert(pollfd_[kPollFdUdp].fd >= 0);
Packet *queue_is_used = udp_queue_;
*udp_queue_end_ = packet;
udp_queue_end_ = &packet->next;
@ -275,16 +305,137 @@ void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) override {
WriteToUdp();
}
bool TunsafeBackendBsdImpl::InitializeUnixDomainSocket(const char *devname) {
int fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (fd == -1) {
RERROR("Error creating unix domain socket");
return false;
}
fcntl(fd, F_SETFD, FD_CLOEXEC);
fcntl(fd, F_SETFL, O_NONBLOCK);
mkdir("/var/run/wireguard", 0755);
un_addr_.sun_family = AF_UNIX;
snprintf(un_addr_.sun_path, sizeof(un_addr_.sun_path), "/var/run/wireguard/%s.sock", devname);
unlink(un_addr_.sun_path);
if (bind(fd, (struct sockaddr*)&un_addr_, sizeof(un_addr_)) == -1) {
RERROR("Error binding unix domain socket");
close(fd);
return false;
}
if (listen(fd, 5) == -1) {
RERROR("Error listening on unix domain socket");
close(fd);
return false;
}
pollfd_[kPollFdUnix].fd = fd;
pollfd_[kPollFdUnix].events = POLLIN;
return true;
}
static const char *FindMessageEnd(const char *start, size_t size) {
if (size <= 1)
return NULL;
const char *start_end = start + size - 1;
for(;(start = (const char*)memchr(start, '\n', start_end - start)) != NULL; start++) {
if (start[1] == '\n')
return start + 2;
}
return NULL;
}
bool TunsafeBackendBsdImpl::HandleSpecialPollfd(struct pollfd *pfd, struct SockInfo *sockinfo) {
// handle domain socket thing
if (sockinfo->is_listener) {
if (pfd->revents & POLLIN) {
// wait if we can't allocate more pollfd
if (pollfd_num_ == kMaxPollFd) {
pfd->events = 0;
return true;
}
int fd = accept(pfd->fd, NULL, NULL);
if (fd >= 0) {
size_t slot = pollfd_num_++;
pollfd_[slot].fd = fd;
pollfd_[slot].events = POLLIN;
pollfd_[slot].revents = 0;
sockinfo_[slot - 2].is_listener = false;
} else {
RERROR("Unix domain socket accept failed");
}
}
if (pfd->revents & ~POLLIN) {
RERROR("Unix domain socket got an error code");
return false;
}
return true;
}
if (pfd->revents & POLLIN) {
char buf[4096];
// read as much data as we can until we see \n\n
ssize_t n = recv(pfd->fd, buf, sizeof(buf), 0);
if (n <= 0)
return (n == -1 && errno == EAGAIN); // premature eof or error
sockinfo->inbuf.append(buf, n);
const char *message_end = FindMessageEnd(&sockinfo->inbuf[0], sockinfo->inbuf.size());
if (message_end) {
if (message_end != &sockinfo->inbuf[sockinfo->inbuf.size()])
return false; // trailing data?
WgConfig::HandleConfigurationProtocolMessage(processor_, std::move(sockinfo->inbuf), &sockinfo->outbuf);
if (!sockinfo->outbuf.size())
return false;
pfd->revents = pfd->events = POLLOUT;
}
}
if (pfd->revents & POLLOUT) {
size_t n = send(pfd->fd, sockinfo->outbuf.data(), sockinfo->outbuf.size(), 0);
if (n <= 0)
return (n == -1 && errno == EAGAIN); // premature eof or error
sockinfo->outbuf.erase(0, n);
if (!sockinfo->outbuf.size())
return false;
}
if (pfd->revents & ~(POLLIN | POLLOUT)) {
RERROR("Unix domain socket got an error code");
return false;
}
return true;
}
void TunsafeBackendBsdImpl::CloseSpecialPollfd(size_t i) {
close(pollfd_[i].fd);
pollfd_[i].fd = -1;
sockinfo_[i - 2].inbuf.clear();
sockinfo_[i - 2].outbuf.clear();
pollfd_[i] = pollfd_[(size_t)pollfd_num_ - 1];
std::swap(sockinfo_[i - 2], sockinfo_[(size_t)pollfd_num_ - 1 - 2]);
// Can now allow more sockets?
if (pollfd_num_-- == kMaxPollFd && sockinfo_[kPollFdUnix - 2].is_listener)
pollfd_[kPollFdUnix].events = POLLIN;
}
void TunsafeBackendBsdImpl::RunLoopInner() {
int free_packet_interval = 10;
int overload_ctr = 0;
if (!un_deletion_watcher_.Start(un_addr_.sun_path, &exit_))
return;
while (!exit_) {
int n = -1;
// This is not fully signal safe.
if (got_sig_alarm_) {
got_sig_alarm_ = false;
if (un_deletion_watcher_.Poll(un_addr_.sun_path)) {
RINFO("Unix socket %s deleted.", un_addr_.sun_path);
break;
}
processor_->SecondLoop();
if (free_packet_interval == 0) {
@ -296,33 +447,53 @@ void TunsafeBackendBsdImpl::RunLoopInner() {
overload_ctr -= (overload_ctr != 0);
}
if (tun_fd_ >= 0) {
FD_SET(tun_fd_, &readfds_);
if (tun_writable_) FD_CLR(tun_fd_, &writefds_); else FD_SET(tun_fd_, &writefds_);
}
if (udp_fd_ >= 0) {
FD_SET(udp_fd_, &readfds_);
if (udp_writable_) FD_CLR(udp_fd_, &writefds_); else FD_SET(udp_fd_, &writefds_);
}
n = select(max_fd_, &readfds_, &writefds_, NULL, NULL);
#if defined(OS_LINUX) || defined(OS_FREEBSD)
n = ppoll(pollfd_, pollfd_num_, NULL, &orig_signal_mask_);
#else
n = poll(pollfd_, pollfd_num_, -1);
#endif
if (n == -1) {
if (errno != EINTR) {
fprintf(stderr, "select failed\n");
RERROR("poll failed");
break;
}
} else {
if (tun_fd_ >= 0) {
tun_readable_ = (FD_ISSET(tun_fd_, &readfds_) != 0);
tun_writable_ |= (FD_ISSET(tun_fd_, &writefds_) != 0);
if (pollfd_[kPollFdTun].revents & (POLLERR | POLLHUP | POLLNVAL)) {
if (pollfd_[kPollFdTun].revents & POLLERR) {
tun_interface_gone_ = true;
RERROR("Tun interface gone, closing.");
} else {
RERROR("Tun interface error %d, closing.", pollfd_[kPollFdTun].revents);
}
break;
}
if (udp_fd_ >= 0) {
udp_readable_ = (FD_ISSET(udp_fd_, &readfds_) != 0);
udp_writable_ |= (FD_ISSET(udp_fd_, &writefds_) != 0);
tun_readable_ = (pollfd_[kPollFdTun].revents & POLLIN) != 0;
if (pollfd_[kPollFdTun].revents & POLLOUT) {
pollfd_[kPollFdTun].events = POLLIN;
tun_writable_ = true;
}
if (pollfd_[kPollFdUdp].revents & (POLLERR | POLLHUP | POLLNVAL)) {
RERROR("UDP error %d, closing.", pollfd_[kPollFdUdp].revents);
break;
}
udp_readable_ = (pollfd_[kPollFdUdp].revents & POLLIN) != 0;
if (pollfd_[kPollFdUdp].revents & POLLOUT) {
pollfd_[kPollFdUdp].events = POLLIN;
udp_writable_ = true;
}
for(size_t i = 2; i < pollfd_num_; i++) {
if (pollfd_[i].revents && !HandleSpecialPollfd(&pollfd_[i], &sockinfo_[i - 2])) {
// Close the fd / discard the sockinfo
CloseSpecialPollfd(i);
i--;
}
}
}
bool overload = (overload_ctr != 0);
for(int loop = 0; ; loop++) {
@ -342,6 +513,8 @@ void TunsafeBackendBsdImpl::RunLoopInner() {
processor_->RunAllMainThreadScheduled();
}
un_deletion_watcher_.Stop();
}
TunsafeBackendBsd *CreateTunsafeBackendBsd() {

View file

@ -39,6 +39,8 @@
#include <linux/if_tun.h>
#include <sys/prctl.h>
#include <linux/rtnetlink.h>
#include <sys/inotify.h>
#include <limits.h>
#endif
void tunsafe_die(const char *msg) {
@ -286,15 +288,6 @@ void OsGetTimestampTAI64N(uint8 dst[12]) {
WriteBE32(dst + 8, nanos);
}
void OsGetRandomBytes(uint8 *data, size_t data_size) {
int fd = open("/dev/urandom", O_RDONLY);
int r = read(fd, data, data_size);
if (r < 0) r = 0;
close(fd);
for (; r < data_size; r++)
data[r] = rand() >> 6;
}
void OsInterruptibleSleep(int millis) {
usleep((useconds_t)millis * 1000);
}
@ -387,11 +380,12 @@ int open_tun(char *devname, size_t devname_size) {
memset(&ifr, 0, sizeof(ifr));
ifr.ifr_flags = IFF_TUN | IFF_NO_PI;
my_strlcpy(ifr.ifr_name, sizeof(ifr.ifr_name), devname);
if ((err = ioctl(fd, TUNSETIFF, (void *) &ifr)) < 0) {
close(fd);
return err;
}
strcpy(devname, ifr.ifr_name);
my_strlcpy(devname, devname_size, ifr.ifr_name);
return fd;
}
#endif
@ -411,6 +405,8 @@ int open_udp(int listen_on_port) {
TunsafeBackendBsd::TunsafeBackendBsd()
: processor_(NULL) {
devname_[0] = 0;
tun_interface_gone_ = false;
}
TunsafeBackendBsd::~TunsafeBackendBsd() {
@ -495,10 +491,10 @@ void TunsafeBackendBsd::DelRoute(const RouteInfo &cd) {
static bool IsIpv6AddressSet(const void *p) {
return (ReadLE64(p) | ReadLE64((char*)p + 8)) != 0;
}
// Called to initialize tun
bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out) override {
char devname[16];
bool TunsafeBackendBsd::Configure(const TunConfig &&config, TunConfigOut *out) override {
char buf[kSizeOfAddress];
if (!RunPrePostCommand(config.pre_post_commands.pre_up)) {
RERROR("Pre command failed!");
@ -507,17 +503,35 @@ bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out)
out->enable_neighbor_discovery_spoofing = false;
if (!InitializeTun(devname))
if (!InitializeTun(devname_))
return false;
if (config.ipv6_cidr)
RERROR("IPv6 not supported");
uint32 netmask = CidrToNetmaskV4(config.cidr);
uint32 default_route_v4 = ComputeIpv4DefaultRoute(config.ip, netmask);
RunCommand("/sbin/ifconfig %s %A mtu %d %A netmask %A up", devname, config.ip, config.mtu, config.ip, netmask);
AddRoute(config.ip & netmask, config.cidr, config.ip, devname);
#if defined(OS_LINUX)
if (config.ip) {
char ip[4];
WriteBE32(ip, config.ip);
RunCommand("/sbin/ip address add dev %s %s", devname_, print_ip_prefix(buf, AF_INET, ip, config.cidr));
}
if (config.ipv6_cidr) {
RunCommand("/sbin/ip address add dev %s %s", devname_, print_ip_prefix(buf, AF_INET6, config.ipv6_address, config.ipv6_cidr));
}
RunCommand("/sbin/ip link set dev %s mtu %d up", devname_, config.mtu);
#else // !defined(OS_LINUX)
if (config.ip) {
RunCommand("/sbin/ifconfig %s %A mtu %d %A netmask %A up", devname_, config.ip, config.mtu, config.ip, netmask);
}
if (config.ipv6_cidr) {
RunCommand("/sbin/ifconfig %s inet6 add %s", devname_, print_ip_prefix(buf, AF_INET6, config.ipv6_address, config.ipv6_cidr));
}
#endif // !defined(OS_LINUX)
if (config.ip) {
AddRoute(config.ip & netmask, config.cidr, config.ip, devname_);
}
if (config.use_ipv4_default_route) {
if (config.default_route_endpoint_v4) {
@ -533,35 +547,30 @@ bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out)
AddRoute(ReadBE32(it->addr), it->cidr, ipv4_default_gw, default_iface);
}
}
AddRoute(0x00000000, 1, default_route_v4, devname);
AddRoute(0x80000000, 1, default_route_v4, devname);
AddRoute(0x00000000, 1, default_route_v4, devname_);
AddRoute(0x80000000, 1, default_route_v4, devname_);
}
uint8 default_route_v6[16];
if (config.ipv6_cidr) {
static const uint8 matchall_1_route[17] = {0x80, 0, 0, 0};
char buf[kSizeOfAddress];
ComputeIpv6DefaultRoute(config.ipv6_address, config.ipv6_cidr, default_route_v6);
RunCommand("/sbin/ifconfig %s inet6 add %s", devname, print_ip_prefix(buf, AF_INET6, config.ipv6_address, config.ipv6_cidr));
if (config.use_ipv6_default_route) {
if (IsIpv6AddressSet(config.default_route_endpoint_v6)) {
RERROR("default_route_endpoint_v6 not supported");
}
AddRoute(AF_INET6, matchall_1_route + 1, 1, default_route_v6, devname);
AddRoute(AF_INET6, matchall_1_route + 0, 1, default_route_v6, devname);
AddRoute(AF_INET6, matchall_1_route + 1, 1, default_route_v6, devname_);
AddRoute(AF_INET6, matchall_1_route + 0, 1, default_route_v6, devname_);
}
}
// Add all the extra routes
for (auto it = config.extra_routes.begin(); it != config.extra_routes.end(); ++it) {
if (it->size == 32) {
AddRoute(ReadBE32(it->addr), it->cidr, default_route_v4, devname);
AddRoute(ReadBE32(it->addr), it->cidr, default_route_v4, devname_);
} else if (it->size == 128 && config.ipv6_cidr) {
AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6, devname);
AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6, devname_);
}
}
@ -576,8 +585,10 @@ bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out)
void TunsafeBackendBsd::CleanupRoutes() {
RunPrePostCommand(pre_down_);
for(auto it = cleanup_commands_.begin(); it != cleanup_commands_.end(); ++it)
DelRoute(*it);
for(auto it = cleanup_commands_.begin(); it != cleanup_commands_.end(); ++it) {
if (!tun_interface_gone_ || strcmp(it->dev.c_str(), devname_) != 0)
DelRoute(*it);
}
cleanup_commands_.clear();
RunPrePostCommand(post_down_);
@ -586,6 +597,10 @@ void TunsafeBackendBsd::CleanupRoutes() {
post_down_.clear();
}
void TunsafeBackendBsd::SetTunDeviceName(const char *name) {
my_strlcpy(devname_, sizeof(devname_), name);
}
static bool RunOneCommand(const std::string &cmd) {
RINFO("Run: %s", cmd.c_str());
int exit_code = system(cmd.c_str());
@ -604,6 +619,94 @@ bool TunsafeBackendBsd::RunPrePostCommand(const std::vector<std::string> &vec) {
return success;
}
#if defined(OS_LINUX)
UnixSocketDeletionWatcher::UnixSocketDeletionWatcher()
: inotify_fd_(-1) {
pipes_[0] = -1;
pipes_[0] = -1;
}
UnixSocketDeletionWatcher::~UnixSocketDeletionWatcher() {
close(inotify_fd_);
close(pipes_[0]);
close(pipes_[1]);
}
bool UnixSocketDeletionWatcher::Start(const char *path, bool *flag_to_set) {
assert(inotify_fd_ == -1);
path_ = path;
flag_to_set_ = flag_to_set;
pid_ = getpid();
inotify_fd_ = inotify_init1(IN_CLOEXEC | IN_NONBLOCK);
if (inotify_fd_ == -1) {
perror("inotify_init1() failed");
return false;
}
if (inotify_add_watch(inotify_fd_, "/var/run/wireguard", IN_DELETE | IN_DELETE_SELF) == -1) {
perror("inotify_add_watch failed");
return false;
}
if (pipe(pipes_) == -1) {
perror("pipe() failed");
return false;
}
return pthread_create(&thread_, NULL, &UnixSocketDeletionWatcher::RunThread, this) == 0;
}
void UnixSocketDeletionWatcher::Stop() {
RINFO("Stopping..");
void *retval;
write(pipes_[1], "", 1);
pthread_join(thread_, &retval);
}
void *UnixSocketDeletionWatcher::RunThread(void *arg) {
UnixSocketDeletionWatcher *self = (UnixSocketDeletionWatcher*)arg;
return self->RunThreadInner();
}
void *UnixSocketDeletionWatcher::RunThreadInner() {
char buf[sizeof(struct inotify_event) + NAME_MAX + 1]
__attribute__ ((aligned(__alignof__(struct inotify_event))));
fd_set fdset;
struct stat st;
for(;;) {
if (lstat(path_, &st) == -1 && errno == ENOENT) {
RINFO("Unix socket %s deleted.", path_);
*flag_to_set_ = true;
kill(pid_, SIGALRM);
break;
}
FD_ZERO(&fdset);
FD_SET(inotify_fd_, &fdset);
FD_SET(pipes_[0], &fdset);
int n = select(std::max(inotify_fd_, pipes_[0]) + 1, &fdset, NULL, NULL, NULL);
if (n == -1) {
perror("select");
break;
}
if (FD_ISSET(inotify_fd_, &fdset)) {
ssize_t len = read(inotify_fd_, buf, sizeof(buf));
if (len == -1) {
perror("read");
break;
}
}
if (FD_ISSET(pipes_[0], &fdset))
break;
}
return NULL;
}
#else // !defined(OS_LINUX)
bool UnixSocketDeletionWatcher::Poll(const char *path) {
struct stat st;
return lstat(path, &st) == -1 && errno == ENOENT;
}
#endif // !defined(OS_LINUX)
static TunsafeBackendBsd *g_tunsafe_backend_bsd;
static void SigAlrm(int sig) {
@ -611,10 +714,6 @@ static void SigAlrm(int sig) {
g_tunsafe_backend_bsd->HandleSigAlrm();
}
static void SigUsr1(int sig) {
}
static bool did_ctrlc;
void SigInt(int sig) {
@ -623,6 +722,7 @@ void SigInt(int sig) {
did_ctrlc = true;
write(1, "Ctrl-C detected. Exiting. Press again to force quit.\n", sizeof("Ctrl-C detected. Exiting. Press again to force quit.\n")-1);
// todo: fix signal safety?
if (g_tunsafe_backend_bsd)
g_tunsafe_backend_bsd->HandleExit();
}
@ -631,7 +731,10 @@ void TunsafeBackendBsd::RunLoop() {
assert(!g_tunsafe_backend_bsd);
assert(processor_);
sigset_t mask;
g_tunsafe_backend_bsd = this;
// We want an alarm signal every second.
{
struct sigaction act = {0};
@ -651,16 +754,14 @@ void TunsafeBackendBsd::RunLoop() {
}
}
{
struct sigaction act = {0};
act.sa_handler = SigUsr1;
if (sigaction(SIGUSR1, &act, NULL) < 0) {
RERROR("Unable to install SIGUSR1 handler.");
return;
}
#if defined(OS_LINUX) || defined(OS_FREEBSD)
sigemptyset(&mask);
sigaddset(&mask, SIGALRM);
if (sigprocmask(SIG_BLOCK, &mask, &orig_signal_mask_) < 0) {
perror("sigprocmask");
return;
}
#if defined(OS_LINUX) || defined(OS_FREEBSD)
{
struct itimerspec tv = {0};
struct sigevent sev;
@ -727,7 +828,17 @@ public:
bool is_connected_;
};
struct CommandLineOutput {
const char *filename_to_load;
const char *interface_name;
bool daemon;
};
int HandleCommandLine(int argc, char **argv, CommandLineOutput *output);
int main(int argc, char **argv) {
CommandLineOutput cmd = {0};
InitCpuFeatures();
if (argc == 2 && strcmp(argv[1], "--benchmark") == 0) {
@ -735,12 +846,9 @@ int main(int argc, char **argv) {
return 0;
}
fprintf(stderr, "%s\n", TUNSAFE_VERSION_STRING);
if (argc < 2) {
fprintf(stderr, "Syntax: tunsafe file.conf\n");
return 1;
}
int rv = HandleCommandLine(argc, argv, &cmd);
if (!cmd.filename_to_load)
return rv;
#if defined(OS_MACOSX)
InitOsxGetMilliseconds();
@ -749,19 +857,29 @@ int main(int argc, char **argv) {
SetThreadName("tunsafe-m");
MyProcessorDelegate my_procdel;
TunsafeBackendBsd *socket_loop = CreateTunsafeBackendBsd();
WireguardProcessor wg(socket_loop, socket_loop, &my_procdel);
TunsafeBackendBsd *backend = CreateTunsafeBackendBsd();
if (cmd.interface_name)
backend->SetTunDeviceName(cmd.interface_name);
WireguardProcessor wg(backend, backend, &my_procdel);
my_procdel.wg_processor_ = &wg;
socket_loop->SetProcessor(&wg);
backend->SetProcessor(&wg);
DnsResolver dns_resolver(NULL);
if (!ParseWireGuardConfigFile(&wg, argv[1], &dns_resolver)) return 1;
if (*cmd.filename_to_load && !ParseWireGuardConfigFile(&wg, cmd.filename_to_load, &dns_resolver))
return 1;
if (!wg.Start()) return 1;
socket_loop->RunLoop();
socket_loop->CleanupRoutes();
delete socket_loop;
if (cmd.daemon) {
fprintf(stderr, "Switching to daemon mode...\n");
if (daemon(0, 0) == -1)
perror("daemon() failed");
}
backend->RunLoop();
backend->CleanupRoutes();
delete backend;
return 0;
}

View file

@ -7,6 +7,7 @@
#include "wireguard.h"
#include "wireguard_config.h"
#include <string>
#include <signal.h>
struct RouteInfo {
uint8 family;
@ -16,6 +17,39 @@ struct RouteInfo {
std::string dev;
};
#if defined(OS_LINUX)
// Keeps track of when the unix socket gets deleted
class UnixSocketDeletionWatcher {
public:
UnixSocketDeletionWatcher();
~UnixSocketDeletionWatcher();
bool Start(const char *path, bool *flag_to_set);
void Stop();
bool Poll(const char *path) { return false; }
private:
static void *RunThread(void *arg);
void *RunThreadInner();
const char *path_;
int inotify_fd_;
int pid_;
int pipes_[2];
pthread_t thread_;
bool *flag_to_set_;
};
#else // !defined(OS_LINUX)
// all other platforms that lack inotify
class UnixSocketDeletionWatcher {
public:
UnixSocketDeletionWatcher() {}
~UnixSocketDeletionWatcher() {}
bool Start(const char *path, bool *flag_to_set) { return true; }
void Stop() {}
bool Poll(const char *path);
};
#endif // !defined(OS_LINUX)
class TunsafeBackendBsd : public TunInterface, public UdpInterface {
public:
TunsafeBackendBsd();
@ -24,10 +58,12 @@ public:
void RunLoop();
void CleanupRoutes();
void SetTunDeviceName(const char *name);
void SetProcessor(WireguardProcessor *wg) { processor_ = wg; }
// -- from TunInterface
virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) override;
virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
virtual void HandleSigAlrm() = 0;
virtual void HandleExit() = 0;
@ -44,6 +80,9 @@ protected:
WireguardProcessor *processor_;
std::vector<RouteInfo> cleanup_commands_;
std::vector<std::string> pre_down_, post_down_;
sigset_t orig_signal_mask_;
char devname_[16];
bool tun_interface_gone_;
};
#if defined(OS_MACOSX) || defined(OS_FREEBSD)

File diff suppressed because it is too large Load diff

View file

@ -11,34 +11,39 @@
#include "tunsafe_threading.h"
#include <functional>
enum {
ADAPTER_GUID_SIZE = 40,
};
struct Packet;
class WireguardProcessor;
class TunsafeBackendWin32;
class ThreadedPacketQueue {
class PacketProcessor {
public:
explicit ThreadedPacketQueue(WireguardProcessor *wg, TunsafeBackendWin32 *backend);
~ThreadedPacketQueue();
explicit PacketProcessor();
~PacketProcessor();
enum {
TARGET_PROCESSOR_UDP = 0,
TARGET_PROCESSOR_TUN = 1,
TARGET_UDP_DEVICE = 2,
TARGET_TUN_DEVICE = 3,
TARGET_CONFIG_PROTOCOL = 4,
};
void Start();
void Stop();
void Reset();
int Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend);
void Post(Packet *packet, Packet **end, int count);
void AbortingDriver();
void ForcePost(Packet *packet);
void PostExit(int exit_code);
const uint32 *posted_exit_code() { return &exit_code_; }
private:
void PostTimerInterrupt();
static void CALLBACK TimerRoutine(LPVOID lpArgToCompletionRoutine, DWORD dwTimerLowValue, DWORD dwTimerHighValue);
DWORD ThreadMain();
static DWORD WINAPI ThreadedPacketQueueLauncher(VOID *x);
static void CALLBACK ThreadPoolTimerCallback(PTP_CALLBACK_INSTANCE iTimerInstance, PVOID pContext, PTP_TIMER);
void HandleConfigurationProtocolPacket(WireguardProcessor *wg, TunsafeBackendWin32 *backend, Packet *packet);
Packet *first_;
Packet **last_ptr_;
uint32 packets_in_queue_;
@ -46,12 +51,8 @@ private:
Mutex mutex_;
HANDLE event_;
HANDLE timer_handle_;
HANDLE handle_;
WireguardProcessor *wg_;
bool exit_flag_;
uint32 exit_code_;
bool timer_interrupt_;
TunsafeBackendWin32 *backend_;
};
// Encapsulates a UDP socket, optionally listening for incoming packets
@ -61,17 +62,16 @@ public:
explicit UdpSocketWin32();
~UdpSocketWin32();
void SetPacketHandler(ThreadedPacketQueue *packet_handler) { packet_handler_ = packet_handler; }
void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
void StartThread();
void StopThread();
// -- from UdpInterface
virtual bool Initialize(int listen_on_port) override;
virtual bool Configure(int listen_on_port) override;
virtual void WriteUdpPacket(Packet *packet) override;
private:
void ThreadMain();
static DWORD WINAPI UdpThread(void *x);
@ -80,7 +80,7 @@ private:
Mutex mutex_;
ThreadedPacketQueue *packet_handler_;
PacketProcessor *packet_handler_;
SOCKET socket_;
SOCKET socket_ipv6_;
HANDLE completion_port_handle_;
@ -93,12 +93,12 @@ class DnsBlocker;
class TunWin32Adapter {
public:
TunWin32Adapter(DnsBlocker *dns_blocker);
TunWin32Adapter(DnsBlocker *dns_blocker, const char guid[ADAPTER_GUID_SIZE]);
~TunWin32Adapter();
bool OpenAdapter(unsigned int *exit_thread, DWORD open_flags);
bool InitAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out);
void CloseAdapter();
bool OpenAdapter(TunsafeBackendWin32 *backend, DWORD open_flags);
bool ConfigureAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out);
void CloseAdapter(bool is_restart);
HANDLE handle() { return handle_; }
@ -121,8 +121,10 @@ private:
NET_LUID interface_luid_;
void *backend_;
std::vector<std::string> pre_down_, post_down_;
char guid_[64];
char guid_[ADAPTER_GUID_SIZE];
};
// Implementation of TUN interface handling using IO Completion Ports
@ -131,23 +133,23 @@ public:
explicit TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend);
~TunWin32Iocp();
void SetPacketHandler(ThreadedPacketQueue *packet_handler) { packet_handler_ = packet_handler; }
void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
void StartThread();
void StopThread();
// -- from TunInterface
virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) override;
virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
virtual void WriteTunPacket(Packet *packet) override;
TunWin32Adapter &adapter() { return adapter_; }
private:
void CloseTun();
void CloseTun(bool is_restart);
void ThreadMain();
static DWORD WINAPI TunThread(void *x);
ThreadedPacketQueue *packet_handler_;
PacketProcessor *packet_handler_;
HANDLE completion_port_handle_;
HANDLE thread_;
@ -168,13 +170,13 @@ public:
explicit TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend);
~TunWin32Overlapped();
void SetPacketHandler(ThreadedPacketQueue *packet_handler) { packet_handler_ = packet_handler; }
void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
void StartThread();
void StopThread();
// -- from TunInterface
virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) override;
virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
virtual void WriteTunPacket(Packet *packet) override;
private:
@ -182,7 +184,7 @@ private:
void ThreadMain();
static DWORD WINAPI TunThread(void *x);
ThreadedPacketQueue *packet_handler_;
PacketProcessor *packet_handler_;
HANDLE thread_;
Mutex mutex_;
@ -199,16 +201,18 @@ private:
};
class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate {
friend class ThreadedPacketQueue;
friend class PacketProcessor;
friend class TunWin32Iocp;
friend class TunWin32Overlapped;
friend class TunWin32Adapter;
public:
TunsafeBackendWin32(Delegate *delegate);
~TunsafeBackendWin32();
// -- from TunsafeBackend
virtual bool Initialize() override;
virtual bool Configure() override;
virtual void Teardown() override;
virtual bool SetTunAdapterName(const char *name) override;
virtual void Start(const char *config_file) override;
virtual void Stop() override;
virtual void RequestStats(bool enable) override;
@ -218,13 +222,23 @@ public:
virtual void SetServiceStartupFlags(uint32 flags) override;
virtual LinearizedGraph *GetGraph(int type) override;
virtual std::string GetConfigFileName() override;
virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) override;
// -- from ProcessorDelegate
virtual void OnConnected() override;
virtual void OnConnectionRetry(uint32 attempts) override;
void SetPublicKey(const uint8 key[32]);
void TunAdapterFailed();
void PostExit(int exit_code);
enum {
MODE_NONE = 0,
MODE_EXIT = 1,
MODE_RESTART = 2,
MODE_TUN_FAILED = 3,
};
uint32 exit_code() { return *packet_processor_.posted_exit_code(); }
void SetStatus(StatusCode status);
private:
void StopInner(bool is_restart);
@ -232,16 +246,7 @@ private:
void PushStats();
HANDLE worker_thread_;
enum {
MODE_NONE = 0,
MODE_EXIT = 1,
MODE_RESTART = 2,
MODE_TUN_FAILED = 3,
};
bool want_periodic_stats_;
unsigned int stop_mode_;
Delegate *delegate_;
char *config_file_;
@ -256,6 +261,10 @@ private:
Mutex stats_mutex_;
WgProcessorStats stats_;
PacketProcessor packet_processor_;
char guid_[ADAPTER_GUID_SIZE];
};
// This class ensures that all callbacks get rescheduled to another thread
@ -265,13 +274,14 @@ public:
~TunsafeBackendDelegateThreaded();
private:
virtual void OnGetStats(const WgProcessorStats &stats);
virtual void OnGraphAvailable();
virtual void OnStateChanged();
virtual void OnClearLog();
virtual void OnLogLine(const char **s);
virtual void OnStatusCode(TunsafeBackend::StatusCode status);
virtual void DoWork();
virtual void OnGetStats(const WgProcessorStats &stats) override;
virtual void OnGraphAvailable() override;
virtual void OnStateChanged() override;
virtual void OnClearLog() override;
virtual void OnLogLine(const char **s) override;
virtual void OnStatusCode(TunsafeBackend::StatusCode status) override;
virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) override;
virtual void DoWork() override;
enum Which {
Id_OnGetStats,
@ -281,6 +291,7 @@ private:
Id_OnUpdateUI,
Id_OnStatusCode,
Id_OnGraphAvailable,
Id_OnConfigurationProtocolReply,
};
void AddEntry(Which which, intptr_t lparam = 0, uint32 wparam = 0);
@ -302,3 +313,37 @@ private:
std::vector<Entry> processing_entry_;
};
// For each adapter, remembers whether the adapter is in use
class TunAdaptersInUse {
public:
TunAdaptersInUse();
// attempt to acquire the adapter, so it can't be acquired by anyone else
bool Acquire(const char guid[ADAPTER_GUID_SIZE], void *context);
// mark as free
void Release(void *context);
// Lookup a context from a guid
void *LookupContextFromGuid(const char guid[ADAPTER_GUID_SIZE]);
// Lookup a guid from a context
bool LookupGuidFromContext(void *context, char guid[ADAPTER_GUID_SIZE]);
char *GetAllGuid();
static TunAdaptersInUse *GetInstance();
private:
enum {
kMaxAdaptersInUse = 16,
};
struct Entry {
char guid[ADAPTER_GUID_SIZE];
void *context;
int count;
};
Mutex mutex_;
uint8 num_inuse_;
Entry entry_[kMaxAdaptersInUse];
};

View file

@ -5,7 +5,6 @@
#include "stdafx.h"
#include "tunsafe_types.h"
#include "wireguard.h"
#include <functional>
struct StatsCollector {
@ -72,6 +71,7 @@ public:
virtual void OnClearLog() = 0;
virtual void OnLogLine(const char **s) = 0;
virtual void OnStatusCode(TunsafeBackend::StatusCode status) = 0;
virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) = 0;
// This function is needed for CreateTunsafeBackendDelegateThreaded,
// It's expected to be called on the main thread and then all callbacks will arrive
// on the right thread.
@ -82,9 +82,16 @@ public:
virtual ~TunsafeBackend();
// Setup/teardown the connection to the local service (if any)
virtual bool Initialize() = 0;
virtual bool Configure() = 0;
virtual void Teardown() = 0;
// Set the name of the tun adapter that we want to use.
// On Windows this is the guid of the adapter.
// After having called this, this tun name cannot be used by any other instances.
// Returns false if the name can't be exclusively reserved to this adapter.
virtual bool SetTunAdapterName(const char *name) = 0;
virtual void Start(const char *config_file) = 0;
virtual void Stop() = 0;
virtual void RequestStats(bool enable) = 0;
@ -93,10 +100,9 @@ public:
virtual InternetBlockState GetInternetBlockState(bool *is_activated) = 0;
virtual void SetInternetBlockState(InternetBlockState s) = 0;
virtual void SetServiceStartupFlags(uint32 flags) = 0;
virtual std::string GetConfigFileName() = 0;
virtual LinearizedGraph *GetGraph(int type) = 0;
virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) = 0;
bool is_started() { return is_started_; }
bool is_remote() { return is_remote_; }
@ -105,6 +111,9 @@ public:
StatusCode status() { return status_; }
uint32 GetIP() { return ipv4_ip_; }
static TunsafeBackend *FindBackendByTunGuid(const char *guid);
static char *GetAllGuid();
protected:
bool is_started_;
bool is_remote_;

374
service_pipe_win32.cpp Normal file
View file

@ -0,0 +1,374 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#include "stdafx.h"
#include "service_pipe_win32.h"
#include "util.h"
#include "service_win32_constants.h"
///////////////////////////////////////////////////////////////////////////////////////
// PipeManager
///////////////////////////////////////////////////////////////////////////////////////
PipeManager::PipeManager(const char *pipe_name, bool is_server_pipe, Delegate *delegate) {
pipe_name_ = _strdup(pipe_name);
is_server_pipe_ = is_server_pipe;
for (size_t i = 0; i < kMaxConnections * 2 + 1; i++)
events_[i] = CreateEvent(NULL, i != 0, FALSE, NULL); // For Exit
delegate_ = delegate;
thread_ = NULL;
exit_thread_ = false;
thread_id_ = 0;
for (size_t i = 0; i != kMaxConnections; i++)
connections_[i].Configure(this, (int)i);
connections_[0].state_ = PipeConnection::kStateStarting;
}
PipeManager::~PipeManager() {
StopThread();
for (size_t i = 0; i < kMaxConnections * 2 + 1; i++)
CloseHandle(events_[i]);
free(pipe_name_);
}
bool PipeManager::StartThread() {
assert(thread_ == NULL);
thread_ = CreateThread(NULL, 0, &StaticThreadMain, this, 0, &thread_id_);
return thread_ != NULL;
}
void PipeManager::StopThread() {
if (thread_ != NULL) {
exit_thread_ = true;
SetEvent(events_[0]);
WaitForSingleObject(thread_, INFINITE);
CloseHandle(thread_);
thread_ = NULL;
}
}
bool PipeManager::VerifyThread() {
return thread_id_ == GetCurrentThreadId();
}
void PipeManager::TryStartNewListener() {
assert(VerifyThread());
assert(is_server_pipe_);
// Check if any thread is in the listener state, if not, start
PipeConnection *found_conn = NULL;
for (size_t i = 0; i < kMaxConnections; i++) {
PipeConnection *conn = &connections_[i];
if (conn->connection_established_)
continue;
if (conn->state_ == PipeConnection::kStateWaitConnect)
return;
if (conn->state_ == PipeConnection::kStateNone && found_conn == NULL)
found_conn = conn;
}
if (found_conn) {
found_conn->state_ = PipeConnection::kStateStarting;
found_conn->AdvanceStateMachine();
}
}
DWORD WINAPI PipeManager::StaticThreadMain(void *x) {
return ((PipeManager*)x)->ThreadMain();
}
DWORD PipeManager::ThreadMain() {
assert(VerifyThread());
for (size_t i = 0; i < kMaxConnections; i++)
connections_[i].AdvanceStateMachine();
for (;;) {
DWORD rv = WaitForMultipleObjects(1 + kMaxConnections * 2, events_, FALSE, INFINITE);
// notify?
if (rv == WAIT_OBJECT_0) {
if (exit_thread_)
break;
delegate_->HandleNotify();
// The notification event is set when there might be new messages to send,
// so try to send them.
for (size_t i = 0; i != kMaxConnections; i++)
connections_[i].TrySendNextQueuedWrite();
} else if (rv >= WAIT_OBJECT_0 + 1 && rv < WAIT_OBJECT_0 + 1 + kMaxConnections * 2) {
PipeConnection *conn = &connections_[(rv - 1) >> 1];
if (rv & 1) {
// read finished
conn->AdvanceStateMachine();
} else {
// is the write event
conn->HandleWriteComplete();
}
} else {
assert(0);
}
}
return 0;
}
///////////////////////////////////////////////////////////////////////////////////////
// PipeConnection
///////////////////////////////////////////////////////////////////////////////////////
static void ClearPipeOverlapped(OVERLAPPED *ov) {
ov->Internal = 0;
ov->InternalHigh = 0;
ov->Offset = 0;
ov->OffsetHigh = 0;
}
PipeConnection::PipeConnection() {
pipe_ = INVALID_HANDLE_VALUE;
packets_ = NULL;
packets_end_ = &packets_;
write_overlapped_active_ = false;
connection_established_ = false;
state_ = kStateNone;
tmp_packet_buf_ = NULL;
tmp_packet_size_ = 0;
manager_ = NULL;
delegate_ = NULL;
}
PipeConnection::~PipeConnection() {
}
void PipeConnection::Configure(PipeManager *manager, int slot) {
manager_ = manager;
read_overlapped_.hEvent = manager->events_[1 + slot * 2];
write_overlapped_.hEvent = manager->events_[1 + slot * 2 + 1];
}
int PipeConnection::InitializeServerPipeAndConnect() {
int BUFSIZE = 8192;
SECURITY_ATTRIBUTES saPipeSecurity = {0};
uint8 buf[SECURITY_DESCRIPTOR_MIN_LENGTH];
PSECURITY_DESCRIPTOR pPipeSD = (PSECURITY_DESCRIPTOR)buf;
if (!InitializeSecurityDescriptor(pPipeSD, SECURITY_DESCRIPTOR_REVISION))
return -1;
// set NULL DACL on the SD
if (!SetSecurityDescriptorDacl(pPipeSD, TRUE, (PACL)NULL, FALSE))
return -1;
// now set up the security attributes
saPipeSecurity.nLength = sizeof(SECURITY_ATTRIBUTES);
saPipeSecurity.bInheritHandle = TRUE;
saPipeSecurity.lpSecurityDescriptor = pPipeSD;
pipe_ = CreateNamedPipe(manager_->pipe_name_,
PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED,
PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_REJECT_REMOTE_CLIENTS | PIPE_WAIT,
PIPE_UNLIMITED_INSTANCES,
BUFSIZE, BUFSIZE, 0, &saPipeSecurity);
if (pipe_ == INVALID_HANDLE_VALUE)
return -1;
ClearPipeOverlapped(&read_overlapped_);
// It seems like ConnectNamedPipe never sets the event object if it completes
// right away.
if (!ConnectNamedPipe(pipe_, &read_overlapped_)) {
DWORD rv = GetLastError();
return (rv == ERROR_IO_PENDING) ? 0 : (rv == ERROR_PIPE_CONNECTED) ? 1 : -1;
} else {
return 1;
}
}
bool PipeConnection::InitializeClientPipe() {
assert(pipe_ == INVALID_HANDLE_VALUE);
pipe_ = CreateFile(manager_->pipe_name_, GENERIC_READ | GENERIC_WRITE, 0, NULL,
OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL);
return (pipe_ != INVALID_HANDLE_VALUE);
}
void PipeConnection::ClosePipe() {
if (pipe_ != INVALID_HANDLE_VALUE) {
CancelIo(pipe_);
CloseHandle(pipe_);
pipe_ = INVALID_HANDLE_VALUE;
}
connection_established_ = false;
write_overlapped_active_ = false;
state_ = kStateNone;
free(tmp_packet_buf_);
tmp_packet_buf_ = NULL;
tmp_packet_size_ = 0;
ResetEvent(read_overlapped_.hEvent);
ResetEvent(write_overlapped_.hEvent);
packets_mutex_.Acquire();
OutgoingPacket *packets = packets_;
packets_ = NULL;
packets_end_ = &packets_;
packets_mutex_.Release();
while (packets) {
OutgoingPacket *p = packets;
packets = p->next;
free(p);
}
}
void PipeConnection::HandleWriteComplete() {
assert(write_overlapped_active_);
write_overlapped_active_ = false;
// Remove the packet from the front of the queue, now that it was sent.
packets_mutex_.Acquire();
OutgoingPacket *p = packets_;
if ((packets_ = p->next) == NULL)
packets_end_ = &packets_;
packets_mutex_.Release();
free(p);
if (packets_ == NULL && state_ == kStateWaitTimeout)
AdvanceStateMachine();
else
TrySendNextQueuedWrite();
}
bool PipeConnection::WritePacket(int type, const uint8 *data, size_t data_size) {
OutgoingPacket *packet = (OutgoingPacket *)malloc(offsetof(OutgoingPacket, data[data_size + 1]));
if (packet) {
packet->size = (uint32)(data_size + 1);
packet->data[0] = type;
memcpy(packet->data + 1, data, data_size);
packet->next = NULL;
packets_mutex_.Acquire();
OutgoingPacket *was_empty = packets_;
// login messages are always queued up front
if (type == TS_SERVICE_REQ_LOGIN) {
packet->next = packets_;
if (packet->next == NULL)
packets_end_ = &packet->next;
packets_ = packet;
} else {
*packets_end_ = packet;
packets_end_ = &packet->next;
}
packets_mutex_.Release();
if (was_empty == NULL) {
// Only allow the pipe thread to invoke the send
if (GetCurrentThreadId() == manager_->thread_id_) {
TrySendNextQueuedWrite();
} else {
SetEvent(manager_->notify_handle());
}
}
}
return true;
}
bool PipeConnection::VerifyThread() {
return manager_->VerifyThread();
}
void PipeConnection::TrySendNextQueuedWrite() {
assert(manager_->VerifyThread());
if (!write_overlapped_active_) {
OutgoingPacket *p = packets_;
if (p && connection_established_) {
ClearPipeOverlapped(&write_overlapped_);
if (WriteFile(pipe_, &p->size, p->size + 4, NULL, &write_overlapped_) || GetLastError() == ERROR_IO_PENDING)
write_overlapped_active_ = true;
} else {
ResetEvent(write_overlapped_.hEvent);
}
}
}
#define TS_WAIT_BEGIN(t) switch(state_) { case t:
#define TS_WAIT_POINT(t) state_ = (t); return; case t:
#define TS_WAIT_END() }
void PipeConnection::AdvanceStateMachine() {
DWORD rv;
int srv;
TS_WAIT_BEGIN(kStateStarting)
// Create a named pipe and wait for connections from the UI process
if (manager_->is_server_pipe_) {
srv = InitializeServerPipeAndConnect();
if (srv < 0) {
if (!manager_->exit_thread_)
ExitProcess(1);
ClosePipe();
return;
}
if (srv == 0) {
TS_WAIT_POINT(kStateWaitConnect);
}
} else {
if (!InitializeClientPipe()) {
RINFO("Unable to connect to the TunSafe Service. Please make sure it's running.");
ClosePipe();
return;
}
}
connection_established_ = true;
delegate_ = manager_->delegate_->HandleNewConnection(this);
TrySendNextQueuedWrite();
for (;;) {
// Read the packet length
read_pos_ = 0;
do {
ClearPipeOverlapped(&read_overlapped_);
if (!ReadFile(pipe_, (uint8*)&packet_size_ + read_pos_, 4 - read_pos_, NULL, &read_overlapped_)) {
if ((rv = GetLastError()) != ERROR_IO_PENDING)
goto fail;
TS_WAIT_POINT(kStateWaitReadLength);
}
if ((uint32)read_overlapped_.InternalHigh == 0)
goto fail;
read_pos_ += (uint32)read_overlapped_.InternalHigh;
} while (read_pos_ != 4);
assert(packet_size_ != 0 && packet_size_ < 0x1000000);
if (packet_size_ == 0 || packet_size_ >= 0x1000000)
break;
free(tmp_packet_buf_);
tmp_packet_buf_ = (uint8*)malloc(packet_size_);
if (!tmp_packet_buf_)
break;
// Read the packet payload
read_pos_ = 0;
do {
ClearPipeOverlapped(&read_overlapped_);
if (!ReadFile(pipe_, tmp_packet_buf_ + read_pos_, packet_size_ - read_pos_, NULL, &read_overlapped_)) {
if ((rv = GetLastError()) != ERROR_IO_PENDING)
goto fail;
TS_WAIT_POINT(kStateWaitReadPayload);
}
if ((uint32)read_overlapped_.InternalHigh == 0)
goto fail;
read_pos_ += (uint32)read_overlapped_.InternalHigh;
} while (read_pos_ != packet_size_);
if (!delegate_->HandleMessage(tmp_packet_buf_[0], tmp_packet_buf_ + 1, packet_size_ - 1)) {
ResetEvent(read_overlapped_.hEvent);
if (packets_ != NULL) {
TS_WAIT_POINT(kStateWaitTimeout);
}
break;
}
}
fail:
ClosePipe();
if (!manager_->exit_thread_) {
delegate_->HandleDisconnect();
if (manager_->is_server_pipe_)
manager_->TryStartNewListener();
}
TS_WAIT_END()
}

111
service_pipe_win32.h Normal file
View file

@ -0,0 +1,111 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#pragma once
#include "tunsafe_threading.h"
#include "network_win32_api.h"
class PipeManager;
// Once a pipe connects, this object is used to facilitate the connection
class PipeConnection {
friend class PipeManager;
public:
class Delegate {
public:
virtual bool HandleMessage(int type, uint8 *data, size_t size) = 0;
virtual void HandleDisconnect() = 0;
};
PipeConnection();
~PipeConnection();
void Configure(PipeManager *manager, int slot);
bool WritePacket(int type, const uint8 *data, size_t data_size);
HANDLE pipe_handle() { return pipe_; }
bool is_connected() { return connection_established_; }
bool VerifyThread();
private:
// -1 = fail, 0 = wait, 1 = conn
int InitializeServerPipeAndConnect();
bool InitializeClientPipe();
void AdvanceStateMachine();
void ClosePipe();
void TrySendNextQueuedWrite();
void HandleWriteComplete();
Delegate *delegate_;
PipeManager *manager_;
HANDLE pipe_;
bool write_overlapped_active_;
bool connection_established_;
enum State {
kStateNone,
kStateStarting,
kStateWaitConnect,
kStateWaitReadLength,
kStateWaitReadPayload,
kStateWaitTimeout,
};
uint8 state_;
uint32 packet_size_;
struct OutgoingPacket {
OutgoingPacket *next;
uint32 size;
uint8 data[0];
};
OutgoingPacket *packets_, **packets_end_;
uint8 *tmp_packet_buf_;
DWORD tmp_packet_size_;
DWORD read_pos_;
OVERLAPPED write_overlapped_, read_overlapped_;
Mutex packets_mutex_;
};
// This class supports multiple PipeConnections and calls HandleNewConnection
// when a new pipe connection is established.
class PipeManager {
friend class PipeConnection;
public:
class Delegate {
public:
// Called when a new connection is established
virtual PipeConnection::Delegate *HandleNewConnection(PipeConnection *handler) = 0;
// Called when a notification event was pushed
virtual void HandleNotify() = 0;
};
PipeManager(const char *pipe_name, bool is_server_pipe, Delegate *delegate);
~PipeManager();
bool StartThread();
void StopThread();
bool VerifyThread();
HANDLE notify_handle() { return events_[0]; }
PipeConnection *GetClientConnection() { return &connections_[0]; }
void TryStartNewListener();
private:
DWORD ThreadMain();
static DWORD WINAPI StaticThreadMain(void *x);
Delegate *delegate_;
HANDLE thread_;
char *pipe_name_;
DWORD thread_id_;
bool is_server_pipe_;
bool exit_thread_;
enum { kMaxConnections = 2 };
HANDLE events_[1 + kMaxConnections * 2];
PipeConnection connections_[kMaxConnections];
};

File diff suppressed because it is too large Load diff

View file

@ -3,154 +3,183 @@
#pragma once
#include "service_win32_api.h"
#include <strsafe.h>
#include "util.h"
#include "service_pipe_win32.h"
#include "network_win32_api.h"
#include "tunsafe_threading.h"
#include <algorithm>
#include <string>
#include <assert.h>
// Takes care of multiple TunsafeServiceBackend
class TunsafeServiceManager : public PipeManager::Delegate {
friend class TunsafeServiceBackend;
friend class TunsafeServiceServer;
public:
TunsafeServiceManager();
virtual ~TunsafeServiceManager();
// -- from PipeManager::Delegate
virtual void HandleNotify() override;
virtual PipeConnection::Delegate *HandleNewConnection(PipeConnection *connection) override;
// Called by the service control code to bring the service up or down
unsigned OnStart(int argc, wchar_t **argv);
void OnStop();
void OnShutdown();
TunsafeServiceBackend *main_backend() { return main_backend_; }
TunsafeServiceBackend *CreateBackend(const char *guid);
void DestroyBackend(TunsafeServiceBackend *backend);
bool SwitchInterface(TunsafeServiceServer *server, const char *interfac, bool want_create);
private:
// Points at the Tunsafe hklm reg key
HKEY hkey_;
uint32 server_unique_id_;
PipeManager pipe_manager_;
TunsafeServiceBackend *main_backend_;
std::vector<TunsafeServiceBackend *> backends_;
};
// One of these exist for each TunsafeBackend
class TunsafeServiceBackend : public TunsafeBackend::Delegate {
friend class TunsafeServiceServer;
public:
explicit TunsafeServiceBackend(TunsafeServiceManager *manager);
virtual ~TunsafeServiceBackend();
// -- from TunsafeBackend::Delegate
virtual void OnGetStats(const WgProcessorStats &stats) override;
virtual void OnClearLog() override;
virtual void OnLogLine(const char **s) override;
virtual void OnStateChanged() override;
virtual void OnStatusCode(TunsafeBackend::StatusCode status) override;
virtual void OnGraphAvailable() override;
virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) override;
TunsafeBackend *backend() { return backend_; }
TunsafeBackend::Delegate *delegate() { return thread_delegate_; }
void Start(const char *filename);
void RememberLastUsedConfigFile(const char *filename);
void Stop();
// Trigger backend stats updates whenever a connected pipe client needs it
void UpdateRequestStats();
// Called by TunsafeServiceManager::HandleNotify to process events
// on each backend.
void HandleNotify();
// Send a state update to all connected pipes unless filter is set, then it
// sends only to that.
void SendStateUpdate(TunsafeServiceServer *filter);
// Called whenever a pipe server disconnects
void RemovePipeServer(TunsafeServiceServer *pipe_server);
// Called to register a pipe server with this backend
void AddPipeServer(TunsafeServiceServer *pipe_server);
private:
// Points at the service manager
TunsafeServiceManager *manager_;
// Points at the actual TunsafeBackend
TunsafeBackend *backend_;
// Points at all |TunsafeServiceServer| currently associated with this
// backend.
std::vector<TunsafeServiceServer*> pipe_servers_;
// Points at the thing that transmits TunsafeBackend events to
// the main thread
TunsafeBackend::Delegate *thread_delegate_;
// The config filename that is loaded
std::string current_filename_;
// Positions into |historical_log_lines_|
uint32 historical_log_lines_pos_;
uint32 historical_log_lines_count_;
enum { LOGLINE_COUNT = 256 };
char *historical_log_lines_[LOGLINE_COUNT];
};
// The server side of the client<->server pipe connection
class TunsafeServiceServer : public PipeConnection::Delegate {
public:
TunsafeServiceServer(PipeConnection *pipe, TunsafeServiceBackend *backend, uint32 unique_id);
virtual ~TunsafeServiceServer();
void WritePacket(int type, const uint8 *data, size_t data_size);
// -- from PipeConnection::Delegate
virtual bool HandleMessage(int type, uint8 *data, size_t size) override;
virtual void HandleDisconnect() override;
// Called by TunsafeServiceBackend to push a graph to the client
void OnGraphAvailable();
// Called by TunsafeServiceBackend to push more log lines to the client
void SendQueuedLogLines();
bool want_stats() const { return want_stats_; }
bool want_state_updates() const { return want_state_updates_; }
uint32 unique_id() const { return unique_id_; }
TunsafeServiceBackend *service_backend() { return service_backend_; }
void set_service_backend(TunsafeServiceBackend *sb) { service_backend_ = sb; }
private:
bool AuthenticateUser();
// Whether the client wants state updates
bool want_state_updates_;
// Whether the client has authenticated
bool did_authenticate_user_;
// Whether we want stats
bool want_stats_;
// Whether the currently connected user wants a graph
uint32 want_graph_type_;
// The last log line sent to the currently connected user
uint32 last_line_sent_;
uint32 unique_id_;
// The pipe used to communicate
PipeConnection *connection_;
// The backend we're currently associated with
TunsafeServiceBackend *service_backend_;
};
struct ServiceState {
uint8 is_started : 1;
uint8 internet_block_state_active : 1;
uint8 internet_block_state;
uint8 reserved[26+64];
uint8 reserved[26 + 64];
uint32 ipv4_ip;
uint8 public_key[32];
};
STATIC_ASSERT(sizeof(ServiceState) == 128, ServiceState_wrong_size);
class PipeMessageHandler {
public:
class Delegate {
public:
virtual bool HandleMessage(int type, uint8 *data, size_t size) = 0;
virtual bool HandleNotify() = 0;
virtual void HandleNewConnection() = 0;
virtual void HandleDisconnect() = 0;
};
PipeMessageHandler(const char *pipe_name, bool is_server_pipe, Delegate *delegate);
~PipeMessageHandler();
bool StartThread();
void StopThread();
bool WritePacket(int type, const uint8 *data, size_t data_size);
HANDLE notify_handle() { return wait_handles_[1]; }
HANDLE pipe_handle() { return pipe_; }
bool VerifyThread();
bool is_connected() { return connection_established_; }
private:
bool InitializeServerPipeAndWait();
bool InitializeClientPipe();
void AdvanceStateMachine();
void ClosePipe();
DWORD ThreadMain();
void SendNextQueuedWrite();
static DWORD WINAPI StaticThreadMain(void *x);
Delegate *delegate_;
HANDLE pipe_;
HANDLE thread_;
HANDLE wait_handles_[3];
bool write_overlapped_active_;
bool exit_thread_;
bool is_server_pipe_;
bool connection_established_;
char *pipe_name_;
enum State {
kStateNone,
kStateWaitConnect,
kStateWaitReadLength,
kStateWaitReadPayload,
kStateWaitTimeout,
};
int state_;
struct OutgoingPacket {
OutgoingPacket *next;
uint32 size;
uint8 data[0];
};
OutgoingPacket *packets_, **packets_end_;
uint8 *tmp_packet_buf_;
DWORD tmp_packet_size_;
OVERLAPPED write_overlapped_, read_overlapped_;
Mutex packets_mutex_;
DWORD thread_id_;
};
class TunsafeServiceImpl : public TunsafeBackend::Delegate, public PipeMessageHandler::Delegate {
public:
TunsafeServiceImpl();
virtual ~TunsafeServiceImpl();
// -- from TunsafeBackend::Delegate
virtual void OnGetStats(const WgProcessorStats &stats);
virtual void OnClearLog();
virtual void OnLogLine(const char **s);
virtual void OnStateChanged();
virtual void OnStatusCode(TunsafeBackend::StatusCode status);
virtual void OnGraphAvailable();
// -- from PipeMessageHandler::Delegate
virtual bool HandleMessage(int type, uint8 *data, size_t size);
virtual bool HandleNotify();
virtual void HandleNewConnection();
virtual void HandleDisconnect();
// virtual methods
virtual unsigned OnStart(int argc, wchar_t **argv);
virtual void OnStop();
virtual void OnShutdown();
TunsafeBackend::Delegate *delegate() { return thread_delegate_; }
private:
void SendQueuedLogLines();
bool AuthenticateUser();
bool did_send_getstate_;
bool did_authenticate_user_;
uint32 want_graph_type_;
HKEY hkey_;
TunsafeBackend *backend_;
TunsafeBackend::Delegate *thread_delegate_;
PipeMessageHandler message_handler_;
uint32 historical_log_lines_pos_;
uint32 historical_log_lines_count_;
uint32 last_line_sent_;
std::string current_filename_;
enum {
LOGLINE_COUNT = 256
};
char *historical_log_lines_[LOGLINE_COUNT];
};
class TunsafeServiceClient : public TunsafeBackend, public PipeMessageHandler::Delegate {
class TunsafeServiceClient : public TunsafeBackend, public PipeConnection::Delegate, public PipeManager::Delegate {
public:
TunsafeServiceClient(TunsafeBackend::Delegate *delegate);
virtual ~TunsafeServiceClient();
virtual bool Initialize();
// -- from TunsafeBackend
virtual bool Configure();
virtual void Teardown();
virtual bool SetTunAdapterName(const char *name);
virtual void Start(const char *config_file);
virtual void Stop();
virtual void RequestStats(bool enable);
@ -160,12 +189,16 @@ public:
virtual std::string GetConfigFileName();
virtual void SetServiceStartupFlags(uint32 flags);
virtual LinearizedGraph *GetGraph(int type);
virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) override;
// -- from PipeConnection::Delegate
virtual bool HandleMessage(int type, uint8 *data, size_t size) override;
virtual void HandleDisconnect() override;
// -- from PipeManager::Delegate
virtual void HandleNotify() override;
virtual PipeConnection::Delegate *HandleNewConnection(PipeConnection *connection) override;
// -- from PipeMessageHandler::Delegate
virtual bool HandleMessage(int type, uint8 *data, size_t size);
virtual bool HandleNotify();
virtual void HandleNewConnection();
virtual void HandleDisconnect();
protected:
TunsafeBackend::Delegate *delegate_;
@ -173,8 +206,10 @@ protected:
bool got_state_from_control_;
ServiceState service_state_;
std::string config_file_;
PipeMessageHandler message_handler_;
PipeManager pipe_manager_;
PipeConnection *connection_;
LinearizedGraph *cached_graph_;
uint32 last_graph_type_;
Mutex mutex_;
};

36
service_win32_constants.h Normal file
View file

@ -0,0 +1,36 @@
#pragma once
#define TUNSAFE_PIPE_NAME "\\\\.\\pipe\\TunSafe\\ServiceControl"
#define TUNSAFE_SERVICE_PROTOCOL_VERSION 20180916001
enum {
TS_SERVICE_REQ_LOGIN = 0,
TS_SERVICE_REQ_START = 1,
TS_SERVICE_REQ_STOP = 2,
TS_SERVICE_REQ_GETSTATS = 4,
TS_SERVICE_REQ_SET_INTERNET_BLOCKSTATE = 5,
TS_SERVICE_REQ_RESETSTATS = 6,
TS_SERVICE_REQ_SET_STARTUP_FLAGS = 7,
TS_SERVICE_MSG_STATE = 8,
TS_SERVICE_MSG_LOGLINE = 9,
TS_SERVICE_MSG_ERROR_REPLY = 10,
TS_SERVICE_MSG_STATS = 11,
TS_SERVICE_MSG_CLEARLOG = 12,
TS_SERVICE_MSG_STATUS_CODE = 14,
TS_SERVICE_REQ_GET_GRAPH = 15,
TS_SERVICE_MSG_GRAPH = 16,
TS_SERVICE_REQ_TEXT_PROTOCOL = 17,
TS_SERVICE_REQ_TEXT_PROTOCOL_REPLY = 18,
TS_SERVICE_REQ_GETINTERFACES = 19,
TS_SERVICE_REQ_GETINTERFACES_REPLY = 20,
};
enum {
kTsMaxDevnameSize = 40
};

View file

@ -13,10 +13,14 @@
#if defined(OS_WIN)
#define _WINSOCK_DEPRECATED_NO_WARNINGS 1
#define _HAS_EXCEPTIONS 0
#define _CRT_SECURE_NO_WARNINGS 1
//#include <Winsock2.h>
#include <Ws2tcpip.h>
#include <Windows.h>
#undef max
//#include <winsock2.h>
#include <ws2ipdef.h>
#include <iphlpapi.h>

883
ts.cpp Normal file
View file

@ -0,0 +1,883 @@
#include "stdafx.h"
#include "tunsafe_types.h"
#include "netapi.h"
#include "crypto/curve25519-donna.h"
#include "util.h"
#include "wireguard_proto.h"
#include <string.h>
#include <algorithm>
#if defined(OS_WIN)
#include "util_win32.h"
#include "service_pipe_win32.h"
#include "service_win32_constants.h"
#endif // defined(OS_WIN)
#if defined(OS_POSIX)
#include <sys/stat.h>
#include <sys/un.h>
#include <unistd.h>
#include <dirent.h>
#include <errno.h>
#endif // defined(OS_WIN)
#pragma comment(lib, "ws2_32.lib")
#define ANSI_RESET "\x1b[0m"
#define ANSI_BOLD "\x1b[1m"
#define ANSI_FG_BLACK "\x1b[30m"
#define ANSI_FG_RED "\x1b[31m"
#define ANSI_FG_GREEN "\x1b[32m"
#define ANSI_FG_YELLOW "\x1b[33m"
#define ANSI_FG_BLUE "\x1b[34m"
#define ANSI_FG_MAGENTA "\x1b[35m"
#define ANSI_FG_CYAN "\x1b[36m"
#define ANSI_FG_WHITE "\x1b[37m"
static const uint8 kCurve25519Basepoint[32] = {9};
#if defined(OS_WIN)
#define EXENAME "ts"
static bool SendMessageToService(HANDLE pipe, int message, const void *data, size_t data_size) {
uint8 *temp = new uint8[data_size + 5];
*(uint32*)temp = (uint32)(data_size + 1);
temp[4] = (uint8)message;
memcpy(temp + 5, data, data_size);
// Write the whole thing
DWORD pos = 0, bytes_to_write = (DWORD)(data_size + 5), bytes_written;
do {
if (!WriteFile(pipe, temp + pos, bytes_to_write, &bytes_written, NULL)) {
fprintf(stderr, "Error writing to service pipe, error = %d\n", GetLastError());
break;
}
pos += bytes_written;
bytes_to_write -= bytes_written;
} while (bytes_to_write != 0);
delete[] temp;
return (bytes_to_write == 0);
}
static bool ReadExactBytesFromPipe(HANDLE pipe, const void *data, DWORD bytes_to_read) {
DWORD pos = 0, n;
do {
if (!ReadFile(pipe, (uint8*)data + pos, bytes_to_read, &n, NULL))
return false;
if (n == 0)
return false; // premature eof..
pos += n;
bytes_to_read -= n;
} while (bytes_to_read != 0);
return true;
}
static bool ReadMessageFromService(HANDLE pipe, int *message, std::string *data) {
uint8 header[5];
uint32 message_size;
if (!ReadExactBytesFromPipe(pipe, header, 5) || (message_size = *(uint32*)header) == 0) {
fprintf(stderr, "Error reading from service pipe, error = %d\n", GetLastError());
return false;
}
*message = header[4];
data->resize(message_size - 1);
if (message_size - 1 != 0 && !ReadExactBytesFromPipe(pipe, data->data(), message_size - 1)) {
fprintf(stderr, "Error reading from service pipe, error = %d\n", GetLastError());
return false;
}
return true;
}
struct ServiceLoginMessage {
uint64 version;
char interfac[kTsMaxDevnameSize];
bool want_state_updates;
bool want_create_interface;
};
static std::vector<GuidAndDevName> g_tap_adapters;
static bool g_did_get_adapters;
static const std::vector<GuidAndDevName> &GetTapAdapterInfo() {
if (!g_did_get_adapters) {
g_did_get_adapters = true;
GetTapAdapterInfo(&g_tap_adapters);
}
return g_tap_adapters;
}
static const char *GetGuidFromInterfaceName(const char *name) {
for (const GuidAndDevName &e : GetTapAdapterInfo())
if (strcmp(e.name, name) == 0)
return e.guid;
return NULL;
}
static const char *GetInterfaceNameFromGuid(const char *guid) {
for (const GuidAndDevName &e : GetTapAdapterInfo())
if (strcmp(e.guid, guid) == 0)
return e.name;
return NULL;
}
static HANDLE ConnectToService(const char *devname, bool want_updates, bool want_create = false) {
ServiceLoginMessage msg = {0};
msg.version = TUNSAFE_SERVICE_PROTOCOL_VERSION;
msg.want_state_updates = want_updates;
msg.want_create_interface = want_create;
// Rename devname to a guid
if (devname) {
const char *guid = (devname[0] == '{' || devname[0] == 0) ? devname : GetGuidFromInterfaceName(devname);
if (!guid) {
fprintf(stderr, "Interface '%s' not found\n", devname);
return NULL;
}
my_strlcpy(msg.interfac, sizeof(msg.interfac), guid);
}
for (;;) {
HANDLE pipe = CreateFile(TUNSAFE_PIPE_NAME, GENERIC_READ | GENERIC_WRITE, 0, NULL,
OPEN_EXISTING, 0, NULL);
if (pipe != INVALID_HANDLE_VALUE) {
if (!SendMessageToService(pipe, TS_SERVICE_REQ_LOGIN, &msg, sizeof(msg))) {
CloseHandle(pipe);
pipe = NULL;
}
return pipe;
}
DWORD error = GetLastError();
if (error != ERROR_PIPE_BUSY) {
fprintf(stderr, "Error connecting to TunSafe service, error = %d\n", error);
if (error == ERROR_FILE_NOT_FOUND)
fprintf(stderr, "Please check that the TunSafe service is started\n");
return NULL;
}
if (!WaitNamedPipe(TUNSAFE_PIPE_NAME, 10000)) {
fprintf(stderr, "Error connecting to TunSafe service, timed out.\n");
return NULL;
}
}
}
static bool CommunicateWithService(const char *devname, const std::string &query, std::string *reply) {
HANDLE pipe = ConnectToService(devname, false);
int message_code;
bool rv = false;
if (pipe != NULL &&
SendMessageToService(pipe, TS_SERVICE_REQ_TEXT_PROTOCOL, query.data(), query.size()) &&
ReadMessageFromService(pipe, &message_code, reply)) {
if (message_code == TS_SERVICE_REQ_TEXT_PROTOCOL_REPLY) {
rv = true;
} else {
if (message_code == TS_SERVICE_MSG_ERROR_REPLY) {
fprintf(stderr, "Error: %s\n", reply->c_str());
} else {
fprintf(stderr, "Unknown reply (%d) from TunSafe service.\n", message_code);
}
}
}
CloseHandle(pipe);
return rv;
}
static bool GetInterfaceList(std::string *result) {
HANDLE pipe = ConnectToService(NULL, false);
int message_code;
bool rv = false;
if (pipe != NULL &&
SendMessageToService(pipe, TS_SERVICE_REQ_GETINTERFACES, NULL, 0) &&
ReadMessageFromService(pipe, &message_code, result)) {
if (message_code == TS_SERVICE_REQ_GETINTERFACES_REPLY) {
rv = true;
} else {
fprintf(stderr, "GetInterfaceList: bad reply\n");
}
}
CloseHandle(pipe);
return rv;
}
#endif // defined(OS_WIN)
#if defined(OS_POSIX)
#define EXENAME "tunsafe"
static const char *GetGuidFromInterfaceName(const char *name) {
return name;
}
static const char *GetInterfaceNameFromGuid(const char *guid) {
return guid;
}
static int OpenUserspaceInterface(const char *iface) {
struct stat st;
struct sockaddr_un un = { 0 };
int fd = -1, rv;
if (strchr(iface, '/') != NULL) {
fprintf(stderr, "Unable to open usermode socket: No such device\n");
goto getout;
}
snprintf(un.sun_path, sizeof(un.sun_path), "/var/run/wireguard/%s.sock", iface);
if (stat(un.sun_path, &st) < 0) {
perror("Unable to open usermode socket");
goto getout;
}
if (!S_ISSOCK(st.st_mode)) {
fprintf(stderr, "Unable to open usermode socket: No such device\n");
goto getout;
}
fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (fd < 0)
goto getout;
un.sun_family = AF_UNIX;
if (connect(fd, (struct sockaddr *)&un, sizeof(un)) < 0) {
if (errno == ECONNREFUSED)
unlink(un.sun_path);
else
perror("Error opening wireguard usermode interface socket");
goto getout;
}
return fd;
getout:
if (fd >= 0)
close(fd);
return -1;
}
static bool GetInterfaceList(std::string *result) {
struct dirent *dent;
DIR *dir = opendir("/var/run/wireguard/");
if (!dir)
return errno == ENOENT;
while ((dent = readdir(dir)) != NULL) {
size_t len = strlen(dent->d_name);
static const char kSuffix[6] = ".sock";
if (len >= sizeof(kSuffix) - 1 &&
memcmp(&dent->d_name[len - (sizeof(kSuffix) - 1)], kSuffix, sizeof(kSuffix) - 1) == 0) {
dent->d_name[len - (sizeof(kSuffix) - 1)] = '\n';
result->append(dent->d_name, len - (sizeof(kSuffix) - 1) + 1);
}
}
closedir(dir);
return true;
}
static bool CommunicateWithService(const char *devname, const std::string &query, std::string *reply) {
ssize_t n;
char buf[4096];
bool rv = false;
reply->clear();
int fd = OpenUserspaceInterface(devname);
if (fd == -1)
return false;
for(size_t pos = 0; query.size() - pos; pos += n) {
n = write(fd, query.data() + pos, query.size() - pos);
if (n <= 0) {
perror("Error writing to service pipe");
goto getout;
}
}
for(;;) {
n = read(fd, buf, sizeof(buf));
if (n <= 0) {
if (n == 0) {
// ensure that it ends with \n\n
if (reply->size() >= 2 && (*reply)[reply->size() - 1] == '\n' && (*reply)[reply->size() - 2] == '\n') {
rv = true;
} else {
fprintf(stderr, "Bad reply from service pipe\n");
}
} else {
perror("Error reading from service pipe");
}
break;
}
reply->append(buf, n);
}
getout:
close(fd);
return rv;
}
static int HandleStopCommand(int argc, char **argv) {
if (argc != 1) {
fprintf(stderr, "Usage: " EXENAME " stop <interface>\n");
return 1;
}
struct sockaddr_un un;
struct stat st;
const char *iface = argv[0];
if (strchr(iface, '/')) {
fprintf(stderr, "No such interface\n");
return 1;
}
snprintf(un.sun_path, sizeof(un.sun_path), "/var/run/wireguard/%s.sock", iface);
if (unlink(un.sun_path) == -1) {
perror("unlink");
return 1;
}
return 0;
}
#endif // defined(OS_POSIX)
void ShowHelp() {
fprintf(stderr,
"Usage: " EXENAME " <cmd> [<args>]\n\n"
#if defined(OS_POSIX)
" " EXENAME " filename.conf\n\n"
#endif // defined(OS_POSIX)
"Available subcommands:\n"
" show: Shows the configuration and status of the interfaces\n"
" set: Change the configuration or the peer list\n"
" start: Start TunSafe on an interface\n"
" stop: Stop TunSafe on an interface\n"
#if defined(OS_WIN)
" log: Display recent log entries\n"
#endif // defined(OS_WIN)
" genkey: Writes a new private key to stdout\n"
" genpsk: Writes a new preshared key to stdout\n"
" pubkey: Reads a private key from stdin and writes its public key to stdout\n"
"To see more help about a subcommand, pass --help to it\n");
}
static bool ParseHexKeyToBase64(const char *key, char base64key[WG_PUBLIC_KEY_LEN_BASE64 + 1]) {
uint8 keybuf[32];
if (!ParseHexString(key, keybuf, 32))
return false;
return base64_encode(keybuf, 32, base64key, WG_PUBLIC_KEY_LEN_BASE64 + 1, NULL) != NULL;
}
static char *FormatTransferPart(char *buf, size_t bufsize, uint64 n) {
if (n < 1024)
snprintf(buf, bufsize, "%u " ANSI_FG_CYAN "B" ANSI_RESET, (unsigned)n);
else if (n < 1024 * 1024)
snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "KiB" ANSI_RESET, (double)n * (1.0 / 1024));
else if (n < 1024 * 1024 * 1024)
snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "MiB" ANSI_RESET, (double)n * (1.0 / 1024 / 1024));
else if (n < 1024ull * 1024 * 1024 * 1024)
snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "GiB" ANSI_RESET, (double)n * (1.0 / 1024 / 1024 / 1024));
else
snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "TiB" ANSI_RESET, (double)n * (1.0 / 1024 / 1024 / 1024 / 1024));
return buf;
}
static size_t PrintTime(char *buf, size_t bufsize, uint64 n) {
size_t pos = 0;
uint64 years = n / (365 * 24 * 60 * 60);
uint32 n32 = n % (365 * 24 * 60 * 60);
if (years)
pos += snprintf(buf + pos, bufsize - pos, "%llu " ANSI_FG_CYAN "year%s" ANSI_RESET ", ", (unsigned long long)years, (years == 1) ? "" : "s");
uint32 days = n32 / (24 * 60 * 60);
n32 %= (24 * 60 * 60);
if (days)
pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "day%s" ANSI_RESET ", ", days, (days == 1) ? "" : "s");
uint32 hours = n32 / (60 * 60);
n32 %= (60 * 60);
if (hours)
pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "hour%s" ANSI_RESET ", ", hours, (hours == 1) ? "" : "s");
uint32 minutes = n32 / 60;
if (minutes)
pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "minute%s" ANSI_RESET ", ", minutes, (minutes == 1) ? "" : "s");
uint32 seconds = n32 % 60;
if (seconds)
pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "second%s" ANSI_RESET ", ", seconds, (seconds == 1) ? "" : "s");
if (pos)
buf[pos -= 2] = '\0';
return pos;
}
static char *PrintHandshake(char *buf, size_t bufsize, uint64 secs) {
time_t now = time(NULL);
if (now == secs) {
snprintf(buf, bufsize, "Now");
} else if (now < (int64)secs) {
snprintf(buf, bufsize, ANSI_FG_RED "System clock going backwards" ANSI_RESET);
} else {
size_t pos = PrintTime(buf, bufsize - 4, now - secs);
memcpy(buf + pos, " ago", 5);
}
return buf;
}
static void AppendIpToString(const char *value, std::string *result) {
if (!result->empty())
(*result) += ", ";
const char *slash = strchr(value, '/');
if (slash) {
result->append(value, slash - value);
result->append(ANSI_FG_CYAN "/" ANSI_RESET);
result->append(slash + 1);
} else {
result->append(value);
}
}
static int ShowUserFriendlyForDevice(char *devname) {
std::string reply;
std::vector<std::pair<char*, char*>> kv;
std::string ips;
if (!CommunicateWithService(devname, "get=1\n\n", &reply))
return 1;
if (!ParseConfigKeyValue(&reply[0], &kv)) {
getout_fail:
fprintf(stderr, "Unable to parse response");
return 1;
}
size_t i = 0;
char base64key[WG_PUBLIC_KEY_LEN_BASE64 + 1];
char base64psk[WG_PUBLIC_KEY_LEN_BASE64 + 1];
int listen_port = 0;
base64key[0] = 0;
// Parse all interface level keys
for (; i < kv.size(); i++) {
char *key = kv[i].first, *value = kv[i].second;
if (strcmp(key, "private_key") == 0) {
uint8 binkey[32];
if (!ParseHexString(value, binkey, sizeof(binkey)))
goto getout_fail;
if (!IsOnlyZeros(binkey, 32)) {
curve25519_donna(binkey, binkey, kCurve25519Basepoint);
base64_encode(binkey, sizeof(binkey), base64key, sizeof(base64key), NULL);
}
} else if (strcmp(key, "address") == 0) {
AppendIpToString(value, &ips);
} else if (strcmp(key, "listen_port") == 0) {
listen_port = atoi(value);
} else if (strcmp(key, "public_key") == 0) {
break;
}
}
const char *interfacename = (devname[0] == '{') ? GetInterfaceNameFromGuid(devname) : devname;
printf(ANSI_RESET ANSI_FG_GREEN ANSI_BOLD "interface" ANSI_RESET ": " ANSI_FG_GREEN "%s" ANSI_RESET "\n",
interfacename);
if (base64key[0]) {
printf(" " ANSI_BOLD "public key" ANSI_RESET ": %s\n"
" " ANSI_BOLD "private key" ANSI_RESET ": (hidden)\n", base64key);
}
if (listen_port)
printf(" " ANSI_BOLD "listening port" ANSI_RESET ": %d\n", listen_port);
if (ips.size())
printf(" " ANSI_BOLD "address" ANSI_RESET ": %s\n", ips.c_str());
const char *endpoint = NULL;
uint64 rx_bytes, tx_bytes, last_handshake_time_sec;
int persistent_keepalive;
char text[256];
bool clear_state = true;
// Parse peer level keys
for (; i < kv.size(); i++) {
char *key = kv[i].first, *value = kv[i].second;
if (clear_state) {
base64key[0] = base64psk[0] = 0;
endpoint = NULL;
ips.clear();
persistent_keepalive = 0;
last_handshake_time_sec = tx_bytes = rx_bytes = 0;
clear_state = false;
}
if (strcmp(key, "public_key") == 0) {
if (!ParseHexKeyToBase64(value, base64key))
goto getout_fail;
} else if (strcmp(key, "preshared_key") == 0) {
if (!ParseHexKeyToBase64(value, base64psk))
goto getout_fail;
} else if (strcmp(key, "tx_bytes") == 0) {
tx_bytes = strtoull(value, NULL, 0);
} else if (strcmp(key, "rx_bytes") == 0) {
rx_bytes = strtoull(value, NULL, 0);
} else if (strcmp(key, "allowed_ip") == 0) {
AppendIpToString(value, &ips);
} else if (strcmp(key, "persistent_keepalive_interval") == 0) {
persistent_keepalive = atoi(value);
} else if (strcmp(key, "endpoint") == 0) {
endpoint = value;
} else if (strcmp(key, "last_handshake_time_sec") == 0) {
last_handshake_time_sec = strtoull(value, NULL, 0);
}
if (i == kv.size() - 1 || strcmp(kv[i + 1].first, "public_key") == 0) {
if (!base64key[0])
goto getout_fail;
printf("\n" ANSI_FG_YELLOW ANSI_BOLD "peer" ANSI_RESET ": " ANSI_FG_YELLOW "%s" ANSI_RESET "\n", base64key);
if (base64psk[0])
printf(" " ANSI_BOLD "preshared key" ANSI_RESET ": (hidden)\n");
if (endpoint)
printf(" " ANSI_BOLD "endpoint" ANSI_RESET ": %s\n", endpoint);
printf(" " ANSI_BOLD "allowed ips" ANSI_RESET ": %s\n", ips.size() ? ips.c_str() : "(none)");
if (last_handshake_time_sec)
printf(" " ANSI_BOLD "latest handshake" ANSI_RESET ": %s\n", PrintHandshake(text, sizeof(text), last_handshake_time_sec));
if (tx_bytes | rx_bytes) {
printf(" " ANSI_BOLD "transfer" ANSI_RESET ": %s received, ", FormatTransferPart(text, sizeof(text), rx_bytes));
printf("%s sent\n", FormatTransferPart(text, sizeof(text), tx_bytes));
}
if (persistent_keepalive) {
PrintTime(text, sizeof(text), persistent_keepalive);
printf(" " ANSI_BOLD "persistent keepalive" ANSI_RESET ": every %s\n", text);
}
clear_state = true;
}
}
return 0;
}
static int HandleShowCommand(int argc, char **argv) {
if (argc != 0 && strcmp(argv[0], "--help") == 0) {
fprintf(stderr, "Usage: ts show { <interface> | all | interfaces }\n");
return 0;
}
std::vector<char*> interfaces;
std::string interfaces_str;
if (argc == 0 || strcmp(argv[0], "all") == 0) {
if (!GetInterfaceList(&interfaces_str))
return 1;
SplitString(&interfaces_str[0], '\n', &interfaces);
bool want_newline = false;
for (char *interfac : interfaces) {
if (want_newline)
printf("\n");
want_newline = true;
if (ShowUserFriendlyForDevice(interfac))
return 1;
}
} else if (strcmp(argv[0], "interfaces") == 0) {
if (!GetInterfaceList(&interfaces_str))
return 1;
SplitString(&interfaces_str[0], '\n', &interfaces);
for (char *interfac : interfaces) {
const char *name = GetInterfaceNameFromGuid(interfac);
if (name)
printf("%s\n", name);
}
} else {
return ShowUserFriendlyForDevice(argv[0]);
}
return 0;
}
static void AppendCommand(std::string *result, const char *tag, const char *value) {
result->append(tag);
result->append("=");
result->append(value);
result->append("\n");
}
static bool ConvertBase64KeyToHex(const char *s, char key[65]) {
uint8 tmp[32];
size_t size = 32;
if (!base64_decode((uint8*)s, strlen(s), tmp, &size) || size != 32)
return false;
PrintHexString(tmp, 32, key);
return true;
}
static int HandleSetCommand(int argc, char **argv) {
std::string command, reply;
std::vector<char*> ss;
char hexkey[65];
if (argc == 0) {
fprintf(stderr, "Usage: ts set <interface> [address <address>] [listen-port <port>] [private-key <file path>] "
"[peer <base64 public key> [remove] [preshared-key <file path>] [endpoint <ip>:<port>] "
"[persistent-keepalive <interval seconds>] [allowed-ips <ip1>/<cidr1>[,<ip2>/<cidr2>]] ]");
return 1;
}
char **argv_end = argv + argc;
const char *interfc = *argv++;
command = "set=1\n";
bool in_interface_section = true;
bool in_peer_section = false;
bool did_clear_allowed_ips = false;
while (argv != argv_end) {
const char *key = *argv++;
if (argv != argv_end) {
if (in_interface_section) {
if (strcmp(key, "listen-port") == 0) {
AppendCommand(&command, "listen_port", *argv++);
continue;
} else if (strcmp(key, "address") == 0) {
AppendCommand(&command, "address", *argv++);
continue;
} else if (strcmp(key, "private-key") == 0) {
if (!ConvertBase64KeyToHex(*argv++, hexkey))
goto invalid_key_format;
AppendCommand(&command, "private_key", hexkey);
continue;
}
}
if (strcmp(key, "peer") == 0) {
in_interface_section = false;
in_peer_section = true;
did_clear_allowed_ips = false;
if (!ConvertBase64KeyToHex(*argv++, hexkey))
goto invalid_key_format;
AppendCommand(&command, "public_key", hexkey);
continue;
}
if (in_peer_section) {
if (strcmp(key, "preshared-key") == 0) {
if (!ConvertBase64KeyToHex(*argv++, hexkey))
goto invalid_key_format;
AppendCommand(&command, "preshared_key", hexkey);
continue;
} else if (strcmp(key, "endpoint") == 0) {
AppendCommand(&command, "endpoint", *argv++);
continue;
} else if (strcmp(key, "persistent-keepalive") == 0) {
AppendCommand(&command, "persistent_keepalive_interval", *argv++);
continue;
} else if (strcmp(key, "allowed-ips") == 0) {
if (!did_clear_allowed_ips) {
AppendCommand(&command, "replace_allowed_ips", "true");
did_clear_allowed_ips = true;
}
SplitString(*argv++, ',', &ss);
for (char *x : ss)
AppendCommand(&command, "allowed_ip", x);
continue;
}
}
}
if (in_peer_section) {
if (strcmp(key, "remove") == 0) {
in_peer_section = false;
AppendCommand(&command, "remove", "true");
continue;
}
}
fprintf(stderr, "Invalid argument: %s\n", key);
return 1;
invalid_key_format:
fprintf(stderr, "Key is not in the correct format: '%s'\n", argv[-1]);
return 1;
}
command.append("\n");
if (!CommunicateWithService(interfc, command, &reply))
return 1;
return 0;
}
#if defined(OS_WIN)
static int HandleLogCommand() {
HANDLE pipe = ConnectToService(NULL, true);
int message_code;
std::string reply;
while (pipe != NULL && ReadMessageFromService(pipe, &message_code, &reply) && message_code == TS_SERVICE_MSG_LOGLINE)
printf("%s\n", reply.c_str());
CloseHandle(pipe);
return 0;
}
static int HandleStartCommand(int argc, char **argv) {
if (argc < 1 || argc > 2 || strcmp(argv[0], "--help") == 0) {
fprintf(stderr, "Usage: " EXENAME " start <interface> [<filename>]\n");
return 1;
}
const char *devname = argv[0];
HANDLE pipe = ConnectToService(devname, false, true);
int message_code;
std::string reply;
const char *path = (argc == 1) ? "" : argv[1];
// Tell the server to startup a new interface
if (pipe == NULL ||
!SendMessageToService(pipe, TS_SERVICE_REQ_START, path, strlen(path) + 1) ||
!ReadMessageFromService(pipe, &message_code, &reply))
return 1;
if (message_code == TS_SERVICE_MSG_ERROR_REPLY) {
fprintf(stderr, "%s\n", reply.c_str());
return 1;
}
return 0;
}
static int HandleStopCommand(int argc, char **argv) {
if (argc != 1) {
fprintf(stderr, "Usage: " EXENAME " stop <interface>\n");
return 1;
}
const char *devname = argv[0];
HANDLE pipe = ConnectToService(devname, false);
// Tell the server to stop the interface
if (pipe == NULL ||
!SendMessageToService(pipe, TS_SERVICE_REQ_STOP, NULL, 0))
return 1;
return 0;
}
#endif // defined(OS_WIN)
struct CommandLineOutput {
const char *filename_to_load;
const char *interface_name;
bool daemon;
};
// Returns -1 on invalid subcommand
int HandleCommandLine(int argc, char **argv, CommandLineOutput *output) {
uint8 key[32];
char base64buf[WG_PUBLIC_KEY_LEN_BASE64 + 1];
if (argc == 1) {
ShowHelp();
return 1;
}
const char *subcommand = argv[1];
argv += 2;
argc -= 2;
if (!strcmp(subcommand, "show")) {
return HandleShowCommand(argc, argv);
} else if (!strcmp(subcommand, "set")) {
return HandleSetCommand(argc, argv);
#if defined(OS_WIN)
} else if (!strcmp(subcommand, "log")) {
if (argc != 0) {
fprintf(stderr, "Usage: " EXENAME " log\n");
return 1;
}
return HandleLogCommand();
} else if (!strcmp(subcommand, "start")) {
return HandleStartCommand(argc, argv);
#else
} else if (!strcmp(subcommand, "start") && output) {
if (argc != 0 && !strcmp(argv[0], "--help")) {
start_usage:
fprintf(stderr, "Usage: " EXENAME " start [-d/--daemon] [-n <interface-name>] [<filename>]\n");
return 0;
}
for (; argc; argc--, argv++) {
char *arg = argv[0];
if (strcmp(arg, "-d") == 0 || strcmp(arg, "--daemon") == 0) {
output->daemon = true;
continue;
}
if (strcmp(arg, "-n") == 0) {
if (argc < 2) goto start_usage;
output->interface_name = argv[1];
argc--,argv++;
continue;
}
break;
}
if (argc > 1) goto start_usage;
output->filename_to_load = (argc == 0) ? "" : argv[0];
return 0;
#endif // defined(OS_WIN)
} else if (!strcmp(subcommand, "stop")) {
return HandleStopCommand(argc, argv);
} else if(!strcmp(subcommand, "genkey")) {
if (argc != 0) {
fprintf(stderr, "Usage: " EXENAME " genkey\n");
return 1;
}
OsGetRandomBytes(key, 32);
curve25519_normalize(key);
printf("%s\n", base64_encode(key, 32, base64buf, sizeof(base64buf), NULL));
} else if (!strcmp(subcommand, "genpsk")) {
if (argc != 0) {
fprintf(stderr, "Usage: " EXENAME " genpsk\n");
return 1;
}
OsGetRandomBytes(key, 32);
printf("%s\n", base64_encode(key, 32, base64buf, sizeof(base64buf), NULL));
} else if (!strcmp(subcommand, "pubkey")) {
char base64[WG_PUBLIC_KEY_LEN_BASE64 + 2];
size_t n = fread(base64, 1, sizeof(base64), stdin);
if (n < sizeof(base64) - 2 || n >= sizeof(base64) ||
(n == sizeof(base64) - 1 && (base64[WG_PUBLIC_KEY_LEN_BASE64] != ' ' && base64[WG_PUBLIC_KEY_LEN_BASE64] != '\n'))) {
fprintf(stderr, EXENAME ": Incorrect key format\n");
return 1;
}
size_t size = 32;
if (!base64_decode((uint8*)base64, n, key, &size) || size != 32) {
fprintf(stderr, EXENAME ": Incorrect key format\n");
return 1;
}
curve25519_donna(key, key, kCurve25519Basepoint);
printf("%s\n", base64_encode(key, 32, base64buf, sizeof(base64buf), NULL));
} else if (!strcmp(subcommand, "--help")) {
ShowHelp();
} else if (!strcmp(subcommand, "--version")) {
printf("%s\n", TUNSAFE_VERSION_STRING);
} else {
if (argc == 0) {
if (output)
output->filename_to_load = subcommand;
} else {
ShowHelp();
}
return -1;
}
return 0;
}
#if defined(OS_WIN)
// This is integrated into the main tunsafe binary on posix systems
int main(int argc, char **argv) {
int rv = HandleCommandLine(argc, argv, NULL);
if (rv == -1) {
fprintf(stderr, "Invalid subcommand '%s'\n", argv[1]);
ShowHelp();
return 1;
}
return rv;
}
#endif // defined(OS_WIN)

199
ts.vcxproj Normal file
View file

@ -0,0 +1,199 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="15.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|Win32">
<Configuration>Debug</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|Win32">
<Configuration>Release</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Debug|x64">
<Configuration>Debug</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|x64">
<Configuration>Release</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<PropertyGroup Label="Globals">
<VCProjectVersion>15.0</VCProjectVersion>
<ProjectGuid>{443E105E-8D7C-401F-BD41-D3F56C76104B}</ProjectGuid>
<Keyword>Win32Proj</Keyword>
<RootNamespace>ts</RootNamespace>
<WindowsTargetPlatformVersion>10.0.17134.0</WindowsTargetPlatformVersion>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v141</PlatformToolset>
<CharacterSet>MultiByte</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v141</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>MultiByte</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v141</PlatformToolset>
<CharacterSet>MultiByte</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v141</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>MultiByte</CharacterSet>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
<Import Project="crypto\nasm.props" />
</ImportGroup>
<ImportGroup Label="Shared">
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental>
<OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
<IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<LinkIncremental>true</LinkIncremental>
<OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
<IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<LinkIncremental>false</LinkIncremental>
<OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
<IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
<OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
<IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<PrecompiledHeader>Use</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<ConformanceMode>true</ConformanceMode>
<ExceptionHandling>false</ExceptionHandling>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<ClCompile>
<PrecompiledHeader>Use</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<ConformanceMode>true</ConformanceMode>
<ExceptionHandling>false</ExceptionHandling>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<ClCompile>
<PrecompiledHeader>Use</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>MinSpace</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<ConformanceMode>true</ConformanceMode>
<RuntimeLibrary>MultiThreaded</RuntimeLibrary>
<OmitFramePointers>true</OmitFramePointers>
<ExceptionHandling>false</ExceptionHandling>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
<PrecompiledHeader>Use</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>MinSpace</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<ConformanceMode>true</ConformanceMode>
<RuntimeLibrary>MultiThreaded</RuntimeLibrary>
<OmitFramePointers>true</OmitFramePointers>
<ExceptionHandling>false</ExceptionHandling>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="crypto\curve25519-donna.h" />
<ClInclude Include="stdafx.h" />
<ClInclude Include="util.h" />
<ClInclude Include="util_win32.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="crypto\curve25519-donna.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="stdafx.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">Create</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">Create</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
</ClCompile>
<ClCompile Include="ts.cpp" />
<ClCompile Include="util.cpp" />
<ClCompile Include="util_win32.cpp" />
</ItemGroup>
<ItemGroup>
<NASM Include="crypto\curve25519_x64_nasm.asm">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">true</ExcludedFromBuild>
</NASM>
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
<Import Project="crypto\nasm.targets" />
</ImportGroup>
</Project>

53
ts.vcxproj.filters Normal file
View file

@ -0,0 +1,53 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<Filter Include="Source Files">
<UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
<Extensions>cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
</Filter>
<Filter Include="Header Files">
<UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
<Extensions>h;hh;hpp;hxx;hm;inl;inc;ipp;xsd</Extensions>
</Filter>
<Filter Include="Resource Files">
<UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
<Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions>
</Filter>
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="util.h">
<Filter>Source Files</Filter>
</ClInclude>
<ClInclude Include="crypto\curve25519-donna.h">
<Filter>Source Files</Filter>
</ClInclude>
<ClInclude Include="util_win32.h">
<Filter>Source Files</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="stdafx.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="ts.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="util.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="crypto\curve25519-donna.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="util_win32.cpp">
<Filter>Source Files</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<NASM Include="crypto\curve25519_x64_nasm.asm">
<Filter>Source Files</Filter>
</NASM>
</ItemGroup>
</Project>

View file

@ -10,10 +10,14 @@ MultithreadedDelayedDelete::MultithreadedDelayedDelete() {
}
MultithreadedDelayedDelete::~MultithreadedDelayedDelete() {
assert(curr_.size() == 0);
assert(next_.size() == 0);
assert(to_delete_.size() == 0);
free(table_);
}
void MultithreadedDelayedDelete::Initialize(uint32 num_threads) {
void MultithreadedDelayedDelete::Configure(uint32 num_threads) {
assert(table_ == NULL);
num_threads_ = num_threads;
table_ = (CheckpointData*)calloc(sizeof(CheckpointData), num_threads);
}

View file

@ -150,12 +150,14 @@ public:
typedef void DoDeleteFunc(void *x);
void Add(DoDeleteFunc *func, void *param);
void Initialize(uint32 num_threads);
void Configure(uint32 num_threads);
void Checkpoint(uint32 thread_id);
void MainCheckpoint();
bool enabled() const { return num_threads_ != 0; }
private:
struct Entry {
DoDeleteFunc *func;

View file

@ -107,22 +107,6 @@ static void SetUiVisibility(bool visible) {
UpdateGraphReq();
}
static bool GetConfigFullName(const char *basename, char *fullname, size_t fullname_size) {
size_t len = strlen(basename);
if (FindFilenameComponent(basename)[0]) {
if (len >= fullname_size)
return false;
memcpy(fullname, basename, len + 1);
return true;
}
size_t clen = GetConfigPath(fullname, fullname_size);
if (clen == 0 || clen + len >= fullname_size)
return false;
memcpy(fullname + clen, basename, (len + 1) * sizeof(fullname[0]));
return true;
}
void StopTunsafeBackend(UpdateIconWhy why) {
if (g_backend->is_started()) {
g_backend->Stop();
@ -154,6 +138,7 @@ void StartTunsafeBackend(UpdateIconWhy reason) {
}
g_notified_connected_server = false;
g_is_connected_to_server = false;
memset(&g_processor_stats, 0, sizeof(g_processor_stats));
g_backend->Start(g_current_filename);
RegWriteInt(g_reg_key, "IsConnected", 1);
}
@ -189,12 +174,21 @@ public:
}
virtual void OnLogLine(const char **s) {
const char *line = *s;
size_t len = strlen(line);
char *tmp = (char*)alloca(len + 3);
tmp[len + 0] = '\r';
tmp[len + 1] = '\n';
tmp[len + 2] = 0;
memcpy(tmp, line, len);
CHARRANGE cr;
cr.cpMin = -1;
cr.cpMax = -1;
// hwnd = rich edit hwnd
SendMessage(hwndEdit, EM_EXSETSEL, 0, (LPARAM)&cr);
SendMessage(hwndEdit, EM_REPLACESEL, 0, (LPARAM)*s);
SendMessage(hwndEdit, EM_REPLACESEL, 0, (LPARAM)tmp);
}
virtual void OnStateChanged() {
@ -204,13 +198,13 @@ public:
const char *filename = g_cmdline_filename;
if (filename) {
if (GetConfigFullName(filename, fullname, sizeof(fullname)))
if (ExpandConfigPath(filename, fullname, sizeof(fullname)))
SetCurrentConfigFilename(fullname);
} else {
std::string currconfig = g_backend->GetConfigFileName();
if (currconfig.empty()) {
char *conf = RegReadStr(g_reg_key, "ConfigFile", "TunSafe.conf");
if (GetConfigFullName(conf, fullname, sizeof(fullname)))
if (ExpandConfigPath(conf, fullname, sizeof(fullname)))
SetCurrentConfigFilename(fullname);
free(conf);
} else {
@ -233,10 +227,12 @@ public:
}
virtual void OnStatusCode(TunsafeBackend::StatusCode status) override {
if (status != g_status_code)
InvalidatePaintbox();
g_status_code = status;
if (TunsafeBackend::IsPermanentError(status)) {
UpdateIcon(g_is_connected_to_server ? UIW_STOPPED_WORKING_FAIL : UIW_NONE);
InvalidatePaintbox();
return;
}
bool is_connected = (status == TunsafeBackend::kStatusConnected);
@ -254,13 +250,15 @@ public:
if (is_connected > not_first && (g_startup_flags & kStartupFlag_BackgroundService))
g_notified_connected_server = true;
UpdateIcon(UIW_NONE);
InvalidatePaintbox();
}
}
virtual void OnClearLog() override {
SetWindowText(hwndEdit, "");
}
virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) override {
}
};
static MyBackendDelegate my_procdel;
@ -411,7 +409,7 @@ public:
ConfigMenuBuilder::ConfigMenuBuilder()
: nfiles_(0), depth_(0) {
if (!GetConfigFullName("", buf_, sizeof(buf_)))
if (!ExpandConfigPath("", buf_, sizeof(buf_)))
bufpos_ = sizeof(buf_);
else
bufpos_ = strlen(buf_);
@ -556,7 +554,7 @@ static void OpenEditor() {
static void BrowseFiles() {
char buf[MAX_PATH];
if (GetConfigFullName("", buf, ARRAYSIZE(buf))) {
if (ExpandConfigPath("", buf, ARRAYSIZE(buf))) {
size_t l = strlen(buf);
buf[l - 1] = 0;
ShellExecuteFromExplorer(buf, NULL, NULL, "explore");
@ -572,7 +570,7 @@ bool ImportFile(const char *s, bool silent = false) {
bool rv = false;
int filerv;
if (!*last || !GetConfigFullName(last, buf, ARRAYSIZE(buf)) || _stricmp(buf, s) == 0)
if (!*last || !ExpandConfigPath(last, buf, ARRAYSIZE(buf)) || _stricmp(buf, s) == 0)
goto out;
filedata = LoadFileSane(s, &filesize);
@ -657,9 +655,8 @@ void BrowseFile(HWND wnd) {
static const uint8 kCurve25519Basepoint[32] = {9};
static void SetKeyBox(HWND wnd, int ctr, uint8 buf[32]) {
uint8 *privs = base64_encode(buf, 32, NULL);
SetDlgItemText(wnd, ctr, (char*)privs);
free(privs);
char base64[WG_PUBLIC_KEY_LEN_BASE64 + 1];
SetDlgItemText(wnd, ctr, base64_encode(buf, 32, base64, sizeof(base64), NULL));
}
static INT_PTR WINAPI KeyPairDlgProc(HWND hWnd, UINT message, WPARAM wParam,
@ -1075,20 +1072,18 @@ void PushLine(const char *s) {
snprintf(buf, sizeof(buf), "[%.2d:%.2d:%.2d] ", t.wHour, t.wMinute, t.wSecond);
size_t tl = strlen(buf);
char *x = (char*)malloc(tl + l + 3);
char *x = (char*)malloc(tl + l + 1);
if (!x) return;
memcpy(x, buf, tl);
memcpy(x + tl, s, l);
x[l + tl] = '\r';
x[l + tl + 1] = '\n';
x[l + tl + 2] = '\0';
x[l + tl] = '\0';
g_backend_delegate->OnLogLine((const char**)&x);
free(x);
}
void EnsureConfigDirCreated() {
char fullname[1024];
if (GetConfigFullName("", fullname, sizeof(fullname)))
if (ExpandConfigPath("", fullname, sizeof(fullname)))
CreateDirectory(fullname, NULL);
}
@ -1358,7 +1353,7 @@ static void DrawGraph(HDC dc, const RECT *rr, StatsCollector::TimeSeries **sourc
for (size_t j = 0; j != num_source; j++) {
const StatsCollector::TimeSeries *src = sources[j];
for (size_t i = 0; i != src->size; i++)
mx = max(mx, src->data[i]);
mx = std::max(mx, src->data[i]);
}
int topval = (int)(mx + 0.5f);
// round it appropriately
@ -1432,11 +1427,13 @@ static void DrawInGraphBox(HDC hdc, int w, int h) {
for (int i = 0; i < graph->num_charts; i++) {
time_series_ptr[i] = &time_series[i];
time_series[i].shift = 0;
time_series[i].size = *(uint32*)ptr;
uint32 size = *(uint32*)ptr;
time_series[i].size = size;
time_series[i].data = (float*)(ptr + 4);
ptr += 4 + *(uint32*)ptr * 4;
if (ptr - (uint8*)graph > graph->total_size)
if ((ptr - (uint8*)graph) + 4 + (uint64)size * 4 > graph->total_size)
break;
ptr += 4 + size * 4;
}
num_charts = graph->num_charts;
}
@ -1517,9 +1514,7 @@ static const char *GetAdvancedInfoValue(char buffer[256], int i) {
case 0: {
if (IsOnlyZeros(g_backend->public_key(), 32))
return "";
char *str = (char*)base64_encode(g_backend->public_key(), 32, NULL);
snprintf(buffer, 256, "%s", str);
free(str);
base64_encode(g_backend->public_key(), 32, buffer, 256, NULL);
return buffer;
}
case 1: {
@ -1764,7 +1759,7 @@ int WINAPI WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance, LPSTR lpCmdLine
// Check if the app is already running.
g_runonce_mutex = CreateMutexA(0, FALSE, "TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90");
if (GetLastError() == ERROR_ALREADY_EXISTS) {
if (GetLastError() == ERROR_ALREADY_EXISTS&&0) {
HWND window = FindWindow("TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90", NULL);
DWORD_PTR result;
if (!window || !SendMessageTimeout(window, WM_USER + 10, 0, 0, SMTO_BLOCK, 3000, &result) || result != 31337) {

171
util.cpp
View file

@ -17,46 +17,55 @@
#include <arpa/inet.h>
#endif
#include <vector>
#include <algorithm>
#include "tunsafe_types.h"
static char base64_alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
static const char kBase64Alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
uint8 *base64_encode(const uint8 *input, size_t length, size_t *out_length) {
uint32 a;
size_t size;
uint8 *result, *r;
char *base64_encode(const uint8 *input, size_t length, char *output, size_t output_size, size_t *out_length) {
char *result, *r;
const uint8 *end;
size = length * 4 / 3 + 4 + 1;
r = result = (byte*)malloc(size);
size_t size = (length + 2) / 3 * 4 + 1;
if (output != NULL) {
result = output;
assert(output_size >= size);
if (output_size < size) {
*result = 0;
return NULL;
}
} else {
result = (char*)malloc(size);
if (!result)
return NULL;
}
r = result;
end = input + length - 3;
// Encode full blocks
while (input <= end) {
a = (input[0] << 16) + (input[1] << 8) + input[2];
uint32 a = (input[0] << 16) + (input[1] << 8) + input[2];
input += 3;
r[0] = base64_alphabet[(a >> 18)/* & 0x3F*/];
r[1] = base64_alphabet[(a >> 12) & 0x3F];
r[2] = base64_alphabet[(a >> 6) & 0x3F];
r[3] = base64_alphabet[(a) & 0x3F];
r[0] = kBase64Alphabet[(a >> 18)/* & 0x3F*/];
r[1] = kBase64Alphabet[(a >> 12) & 0x3F];
r[2] = kBase64Alphabet[(a >> 6) & 0x3F];
r[3] = kBase64Alphabet[(a) & 0x3F];
r += 4;
}
if (input == end + 2) {
a = input[0] << 4;
r[0] = base64_alphabet[(a >> 6) /*& 0x3F*/];
r[1] = base64_alphabet[(a) & 0x3F];
uint32 a = input[0] << 4;
r[0] = kBase64Alphabet[(a >> 6) /*& 0x3F*/];
r[1] = kBase64Alphabet[(a) & 0x3F];
r[2] = '=';
r[3] = '=';
r += 4;
} else if (input == end + 1) {
a = (input[0] << 10) + (input[1] << 2);
r[0] = base64_alphabet[(a >> 12) /*& 0x3F*/];
r[1] = base64_alphabet[(a >> 6) & 0x3F];
r[2] = base64_alphabet[(a) & 0x3F];
uint32 a = (input[0] << 10) + (input[1] << 2);
r[0] = kBase64Alphabet[(a >> 12) /*& 0x3F*/];
r[1] = kBase64Alphabet[(a >> 6) & 0x3F];
r[2] = kBase64Alphabet[(a) & 0x3F];
r[3] = '=';
r += 4;
}
@ -250,13 +259,6 @@ void RERROR(const char *msg, ...) {
}
}
void rinfo(const char *msg, ...) {
printf("muu");
}
void rinfo2(const char *msg) {
printf("muu2");
}
void RINFO(const char *msg, ...) {
va_list va;
@ -296,4 +298,115 @@ size_t my_strlcpy(char *dst, size_t dstsize, const char *src) {
memcpy(dst, src, lenx);
}
return len;
}
}
void OsGetRandomBytes(uint8 *data, size_t data_size) {
#if defined(OS_WIN)
static BOOLEAN(APIENTRY *pfn)(void*, ULONG);
if (!pfn) {
pfn = (BOOLEAN(APIENTRY *)(void*, ULONG))GetProcAddress(LoadLibrary("ADVAPI32.DLL"), "SystemFunction036");
if (!pfn)
ExitProcess(1);
}
if (!pfn(data, (ULONG)data_size)) {
ExitProcess(1);
return;
}
#elif defined(OS_POSIX)
int fd = open("/dev/urandom", O_RDONLY);
if (fd < 0) {
fprintf(stderr, "/dev/urandom failed\n");
exit(1);
}
int r = read(fd, data, data_size);
if (r != data_size) {
fprintf(stderr, "/dev/urandom failed\n");
exit(1);
}
close(fd);
#else
#error
#endif
}
bool ParseConfigKeyValue(char *m, std::vector<std::pair<char *, char*>> *result) {
for (;;) {
char *nl = strchr(m, '\n');
if (nl)
*nl = 0;
if (*m != '\0') {
char *value = strchr(m, '=');
if (value == NULL)
return false;
*value++ = '\0';
result->emplace_back(m, value);
}
if (!nl)
return true;
m = nl + 1;
}
}
bool ParseHexString(const char *text, void *data, size_t data_size) {
size_t len = strlen(text);
if (len != data_size * 2)
return false;
for (size_t i = 0; i < data_size; i++) {
uint32 c = text[i * 2 + 0];
if (c >= '0' && c <= '9') {
c -= '0';
} else if ((c |= 32) >= 'a' && c <= 'f') {
c -= 'a' - 10;
} else {
return false;
}
uint32 d = text[i * 2 + 1];
if (d >= '0' && d <= '9') {
d -= '0';
} else if ((d |= 32) >= 'a' && d <= 'f') {
d -= 'a' - 10;
} else {
return false;
}
((uint8*)data)[i] = c * 16 + d;
}
return true;
}
bool is_space(uint8_t c) {
return c == ' ' || c == '\r' || c == '\n' || c == '\t';
}
void SplitString(char *s, int separator, std::vector<char*> *components) {
components->clear();
for (;;) {
while (is_space(*s)) s++;
char *d = strchr(s, separator);
if (d == NULL) {
if (*s)
components->push_back(s);
return;
}
*d = 0;
char *e = d;
while (e > s && is_space(e[-1]))
*--e = 0;
components->push_back(s);
s = d + 1;
}
}
void PrintHexString(const void *data, size_t data_size, char *result) {
for (size_t i = 0; i < data_size; i++) {
uint8 c = ((uint8*)data)[i];
*result++ = "0123456789abcdef"[c >> 4];
*result++ = "0123456789abcdef"[c & 0xF];
}
*result++ = 0;
}
bool ParseBase64Key(const char *s, uint8 key[32]) {
size_t size = 32;
return base64_decode((uint8*)s, strlen(s), key, &size) && size == 32;
}

23
util.h
View file

@ -3,7 +3,7 @@
#pragma once
#include "tunsafe_types.h"
uint8 *base64_encode(const uint8 *input, size_t length, size_t *out_length);
char *base64_encode(const uint8 *input, size_t length, char *output, size_t output_size, size_t *actual_size);
bool base64_decode(uint8 *in, size_t inLen, uint8 *out, size_t *outLen);
bool IsOnlyZeros(const uint8 *data, size_t data_size);
@ -17,9 +17,28 @@ char *my_strndup(const char *p, size_t size);
size_t my_strlcpy(char *dst, size_t dstsize, const char *src);
template<typename T, typename U> static inline T postinc(T&x, U v) {
T t = x;
x += v;
return t;
}
template<typename T, typename U> static inline T exch(T&x, U v) {
T t = x;
x = v;
return t;
}
template<typename T> static inline T exch_null(T&x) {
T t = x;
x = NULL;
return t;
}
bool is_space(uint8_t c);
void OsGetRandomBytes(uint8 *dst, size_t dst_size);
bool ParseConfigKeyValue(char *m, std::vector<std::pair<char *, char*>> *result);
bool ParseHexString(const char *text, void *data, size_t data_size);
void PrintHexString(const void *data, size_t data_size, char *result);
void SplitString(char *s, int separator, std::vector<char*> *components);
bool ParseBase64Key(const char *s, uint8 key[32]);

View file

@ -297,6 +297,23 @@ size_t GetConfigPath(char *path, size_t path_size) {
return last + 7 - path;
}
bool ExpandConfigPath(const char *basename, char *fullname, size_t fullname_size) {
size_t len = strlen(basename);
if (FindFilenameComponent(basename)[0]) {
if (len >= fullname_size)
return false;
memcpy(fullname, basename, len + 1);
return true;
}
size_t clen = GetConfigPath(fullname, fullname_size);
if (clen == 0 || clen + len >= fullname_size)
return false;
memcpy(fullname + clen, basename, (len + 1) * sizeof(fullname[0]));
return true;
}
static bool ContainsDotDot(const char *path) {
for (uint8 last = 0, cur; (cur = path[0]) != '\0'; last = cur, path++)
if (cur == '.' && last == cur)
@ -308,7 +325,7 @@ bool EnsureValidConfigPath(const char *path) {
char buf[1024];
size_t len = GetConfigPath(buf, sizeof(buf));
return (len != 0) && (strlen(path) > len && memcmp(path, buf, len) == 0 && !ContainsDotDot(path + len));
return (len != 0) && (strlen(path) > len && _strnicmp(path, buf, len) == 0 && !ContainsDotDot(path + len));
}
bool RunProcessAsAdminWithArgs(const char *args, bool wait_for_exit) {
@ -376,3 +393,65 @@ RECT MakeRect(int l, int t, int r, int b) {
RECT rr = { l, t, r, b };
return rr;
}
// Retrieve the device path to the TAP adapter.
#define kAdapterKeyName "SYSTEM\\CurrentControlSet\\Control\\Class\\{4D36E972-E325-11CE-BFC1-08002BE10318}"
#define kNetworkConnectionsKeyName "SYSTEM\\CurrentControlSet\\Control\\Network\\{4D36E972-E325-11CE-BFC1-08002BE10318}"
#define kTapComponentId "tap0901"
void GetTapAdapterInfo(std::vector<GuidAndDevName> *result) {
LONG err;
HKEY adapter_key, device_key, network_connections_key;
bool retval = false;
GuidAndDevName gn;
err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, kAdapterKeyName, 0, KEY_READ, &adapter_key);
if (err != ERROR_SUCCESS) {
RERROR("GetTapAdapterName: RegOpenKeyEx failed: 0x%X", GetLastError());
return;
}
for (int i = 0; !retval; i++) {
char keyname[64 + sizeof(kAdapterKeyName) + 1 + 32 /* some margin */];
char value[64];
DWORD len = sizeof(value), type;
err = RegEnumKeyEx(adapter_key, i, value, &len, NULL, NULL, NULL, NULL);
if (err == ERROR_NO_MORE_ITEMS)
break;
if (err != ERROR_SUCCESS) {
RERROR("GetTapAdapterName: RegEnumKeyEx failed: 0x%X", GetLastError());
break;
}
snprintf(keyname, sizeof(keyname), "%s\\%s", kAdapterKeyName, value);
err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, keyname, 0, KEY_READ, &device_key);
if (err == ERROR_SUCCESS) {
len = sizeof(value);
err = RegQueryValueEx(device_key, "ComponentId", NULL, &type, (LPBYTE)value, &len);
if (err == ERROR_SUCCESS && type == REG_SZ && !memcmp(value, kTapComponentId, sizeof(kTapComponentId))) {
len = sizeof(gn.guid);
err = RegQueryValueEx(device_key, "NetCfgInstanceId", NULL, &type, (LPBYTE)gn.guid, &len);
if (err == ERROR_SUCCESS && type == REG_SZ) {
gn.guid[sizeof(gn.guid) - 1] = 0;
gn.name[0] = 0;
snprintf(keyname, sizeof(keyname), "%s\\%s\\Connection", kNetworkConnectionsKeyName, gn.guid);
err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, keyname, 0, KEY_READ, &network_connections_key);
if (err == ERROR_SUCCESS) {
len = sizeof(gn.guid);
err = RegQueryValueEx(network_connections_key, "Name", NULL, &type, (LPBYTE)gn.name, &len);
if (err == ERROR_SUCCESS && type == REG_SZ) {
gn.name[sizeof(gn.guid) - 1] = 0;
} else {
gn.name[0] = 0;
}
RegCloseKey(network_connections_key);
}
result->push_back(gn);
}
}
RegCloseKey(device_key);
}
}
RegCloseKey(adapter_key);
}

View file

@ -1,6 +1,7 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#include "tunsafe_types.h"
#include <vector>
#pragma once
const char *FindFilenameComponent(const char *s);
@ -47,6 +48,7 @@ void ShellExecuteFromExplorer(
int nShowCmd = SW_SHOWNORMAL);
size_t GetConfigPath(char *path, size_t path_size);
bool ExpandConfigPath(const char *basename, char *fullname, size_t fullname_size);
bool EnsureValidConfigPath(const char *path);
bool RunProcessAsAdminWithArgs(const char *args, bool wait_for_exit);
@ -54,3 +56,8 @@ bool RestartProcessAsAdministrator();
bool SetClipboardString(const char *string);
RECT GetParentRect(HWND wnd);
RECT MakeRect(int l, int t, int r, int b);
struct GuidAndDevName {
char guid[40];
char name[64];
};
void GetTapAdapterInfo(std::vector<GuidAndDevName> *result);

View file

@ -15,8 +15,7 @@
#include "ipzip2/ipzip2.h"
#include "wireguard.h"
#include "wireguard_config.h"
uint64 OsGetMilliseconds();
#include "util.h"
enum {
IPV4_HEADER_SIZE = 20,
@ -36,23 +35,24 @@ WireguardProcessor::WireguardProcessor(UdpInterface *udp, TunInterface *tun, Pro
add_routes_mode_ = true;
dns_blocking_ = true;
internet_blocking_ = kBlockInternet_Default;
is_started_ = false;
stats_last_bytes_in_ = 0;
stats_last_bytes_out_ = 0;
stats_last_ts_ = OsGetMilliseconds();
main_thread_scheduled_ = NULL;
main_thread_scheduled_last_ = &main_thread_scheduled_;
}
WireguardProcessor::~WireguardProcessor() {
}
void WireguardProcessor::SetListenPort(int listen_port) {
listen_port_ = listen_port;
if (listen_port_ != listen_port) {
listen_port_ = listen_port;
if (is_started_ && !ConfigureUdp()) {
RINFO("ConfigureUdp failed");
}
}
}
void WireguardProcessor::AddDnsServer(const IpAddr &sin) {
std::vector<IpAddr> *target = (sin.sin.sin_family == AF_INET6) ? &dns6_addr_ : &dns_addr_;
target->push_back(sin);
@ -66,6 +66,11 @@ bool WireguardProcessor::SetTunAddress(const WgCidrAddr &addr) {
return true;
}
void WireguardProcessor::ClearTunAddress() {
tun_addr_.size = 0;
tun6_addr_.size = 0;
}
void WireguardProcessor::AddExcludedIp(const WgCidrAddr &cidr_addr) {
excluded_ips_.push_back(cidr_addr);
}
@ -129,23 +134,29 @@ static bool IsWgCidrAddrSubsetOf(const WgCidrAddr &inner, const WgCidrAddr &oute
}
bool WireguardProcessor::Start() {
return ConfigureUdp() && ConfigureTun();
}
bool WireguardProcessor::ConfigureUdp() {
assert(dev_.IsMainThread());
if (!udp_->Initialize(listen_port_))
return false;
return udp_->Configure(listen_port_);
}
if (tun_addr_.size != 32) {
RERROR("No IPv4 address configured");
return false;
}
if (tun_addr_.cidr >= 31) {
RERROR("TAP is not compatible CIDR /31 or /32. Changing to /24");
tun_addr_.cidr = 24;
}
bool WireguardProcessor::ConfigureTun() {
assert(dev_.IsMainThread());
TunInterface::TunConfig config = {0};
config.ip = ReadBE32(tun_addr_.addr);
config.cidr = tun_addr_.cidr;
if (tun_addr_.size == 32) {
if (tun_addr_.cidr >= 31) {
RERROR("TAP is not compatible CIDR /31 or /32. Changing to /24");
tun_addr_.cidr = 24;
}
config.ip = ReadBE32(tun_addr_.addr);
config.cidr = tun_addr_.cidr;
} else {
RERROR("No IPv4 address configured");
}
config.mtu = mtu_;
config.pre_post_commands = pre_post_;
config.excluded_ips = excluded_ips_;
@ -205,7 +216,7 @@ bool WireguardProcessor::Start() {
config.ipv6_dns = dns6_addr_;
TunInterface::TunConfigOut config_out;
if (!tun_->Initialize(std::move(config), &config_out))
if (!tun_->Configure(std::move(config), &config_out))
return false;
SetupCompressionHeader(dev_.compression_header());
@ -221,6 +232,7 @@ bool WireguardProcessor::Start() {
}
}
is_started_ = true;
return true;
}
@ -395,22 +407,6 @@ getout:
FreePacket(packet);
}
void WgPeer::AddPacketToPeerQueue(Packet *packet) {
assert(IsPeerLocked());
// Keep only the first MAX_QUEUED_PACKETS packets.
while (num_queued_packets_ >= MAX_QUEUED_PACKETS_PER_PEER) {
Packet *packet = first_queued_packet_;
first_queued_packet_ = packet->next;
num_queued_packets_--;
FreePacket(packet);
}
// Add the packet to the out queue that will get sent once handshake completes
*last_queued_packet_ptr_ = packet;
last_queued_packet_ptr_ = &packet->next;
packet->next = NULL;
num_queued_packets_++;
}
// This function must be called with the peer lock held. It will remove the lock
void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet) {
assert(peer->IsPeerLocked());
@ -427,11 +423,17 @@ void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Pac
if ((keypair = peer->curr_keypair_) == NULL ||
(send_ctr = keypair->send_ctr) >= REJECT_AFTER_MESSAGES) {
peer->AddPacketToPeerQueue(packet);
// If RemovePeer has been called then discard any packets currently being written to it.
// curr_keypair_ is NULL when RemovePeer has been called so it's safe to do this here.
if (peer->marked_for_delete_)
goto getout_discard;
peer->AddPacketToPeerQueue_Locked(packet);
WG_RELEASE_LOCK(peer->mutex_);
ScheduleNewHandshake(peer);
peer->ScheduleNewHandshake();
return;
}
assert(!peer->marked_for_delete_);
stats_.tun_bytes_in += size;
stats_.tun_packets_in++;
@ -524,13 +526,14 @@ add_padding:
WriteLE32(write -= 4, keypair->remote_key_id);
*--write = tag;
// Not using any fields from now on
WG_RELEASE_LOCK(peer->mutex_);
header_size = data - write;
packet->size = (int)(size + header_size + keypair->auth_tag_length);
peer->tx_bytes_ += packet->size;
stats_.compression_wg_saved_out += (int64)16 - header_size;
packet->data = data - header_size;
packet->size = (int)(size + header_size + keypair->auth_tag_length);
// Not using any fields from now on
WG_RELEASE_LOCK(peer->mutex_);
// todo: figure out what to actually use as ad.
ad = write_after_ack_header;
@ -540,6 +543,9 @@ need_big_packet:
#else
{
#endif // #if WITH_SHORT_HEADERS
packet->size = (int)(size + sizeof(MessageData) + keypair->auth_tag_length);
peer->tx_bytes_ += packet->size;
// Not using any fields from now on
WG_RELEASE_LOCK(peer->mutex_);
@ -547,7 +553,6 @@ need_big_packet:
((MessageData*)data)[-1].receiver_id = keypair->remote_key_id;
((MessageData*)data)[-1].counter = ToLE64(send_ctr);
packet->data = data - sizeof(MessageData);
packet->size = (int)(size + sizeof(MessageData) + keypair->auth_tag_length);
ad = NULL;
ad_len = 0;
}
@ -556,7 +561,7 @@ need_big_packet:
DoWriteUdpPacket(packet);
if (want_handshake)
ScheduleNewHandshake(peer);
peer->ScheduleNewHandshake();
return;
getout_discard:
@ -608,38 +613,32 @@ void WireguardProcessor::DoWriteUdpPacket(Packet *packet) {
ScrambleUnscrambleAndWrite(packet, &dev_.header_obfuscation_key_, udp_);
}
void WireguardProcessor::ScheduleNewHandshake(WgPeer *peer) {
if (peer->main_thread_scheduled_.fetch_or(WgPeer::kMainThreadScheduled_ScheduleHandshake) == 0) {
peer->main_thread_scheduled_next_ = NULL;
WG_ACQUIRE_LOCK(main_thread_scheduled_lock_);
*main_thread_scheduled_last_ = peer;
main_thread_scheduled_last_ = &peer->main_thread_scheduled_next_;
WG_RELEASE_LOCK(main_thread_scheduled_lock_);
// todo: in multithreaded impl need to trigger |RunAllMainThreadScheduled| to get called
}
}
void WireguardProcessor::RunAllMainThreadScheduled() {
WgPeer *peer, *next;
assert(dev_.IsMainThread());
if (main_thread_scheduled_ == NULL)
if (dev_.main_thread_scheduled_ == NULL)
return;
WG_ACQUIRE_LOCK(main_thread_scheduled_lock_);
WgPeer *peer = main_thread_scheduled_;
main_thread_scheduled_ = NULL;
main_thread_scheduled_last_ = &main_thread_scheduled_;
WG_RELEASE_LOCK(main_thread_scheduled_lock_);
WG_ACQUIRE_LOCK(dev_.main_thread_scheduled_lock_);
peer = dev_.main_thread_scheduled_;
dev_.main_thread_scheduled_ = NULL;
dev_.main_thread_scheduled_last_ = &dev_.main_thread_scheduled_;
WG_RELEASE_LOCK(dev_.main_thread_scheduled_lock_);
for (; peer; peer = next) {
// todo: for the multithreaded use case figure out whether to use atomic_thread_fence here,
// because we need to read this next value before any other thread sees the 0 we write
// to peer->main_thread_scheduled_.
next = peer->main_thread_scheduled_next_;
if (peer->marked_for_delete_)
continue;
while (peer) {
// todo: for the multithreaded use case figure out whether to use atomic_thread_fence here.
WgPeer *next = peer->main_thread_scheduled_next_;
uint32 ev = peer->main_thread_scheduled_.exchange(0);
if (ev & WgPeer::kMainThreadScheduled_ScheduleHandshake) {
peer->handshake_attempts_ = 0;
SendHandshakeInitiation(peer);
}
peer = next;
}
}
@ -658,6 +657,7 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) {
procdel_->OnConnectionRetry(attempts);
peer->OnHandshakeInitSent();
packet->addr = peer->endpoint_;
peer->tx_bytes_ += packet->size;
WG_RELEASE_LOCK(peer->mutex_);
DoWriteUdpPacket(packet);
if (attempts > 1 && attempts <= 20)
@ -696,19 +696,21 @@ void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) {
#endif // WITH_SHORT_HEADERS
} else if (type == MESSAGE_HANDSHAKE_COOKIE) {
assert(dev_.IsMainThread());
if (packet->size != sizeof(MessageHandshakeCookie))
if (packet->size != sizeof(MessageHandshakeCookie) || !dev_.is_private_key_initialized())
goto invalid_size;
HandleHandshakeCookiePacket(packet);
} else if (type == MESSAGE_HANDSHAKE_INITIATION) {
assert(dev_.IsMainThread());
if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeInitiation)) : (packet->size != sizeof(MessageHandshakeInitiation)))
if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeInitiation)) : (packet->size != sizeof(MessageHandshakeInitiation)) ||
!dev_.is_private_key_initialized())
goto invalid_size;
stats_.handshakes_in++;
if (CheckIncomingHandshakeRateLimit(packet, overload))
HandleHandshakeInitiationPacket(packet);
} else if (type == MESSAGE_HANDSHAKE_RESPONSE) {
assert(dev_.IsMainThread());
if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeResponse)) : (packet->size != sizeof(MessageHandshakeResponse)))
if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeResponse)) : (packet->size != sizeof(MessageHandshakeResponse)) ||
!dev_.is_private_key_initialized())
goto invalid_size;
if (CheckIncomingHandshakeRateLimit(packet, overload))
HandleHandshakeResponsePacket(packet);
@ -749,6 +751,8 @@ void WgPeer::CopyEndpointToPeer_Locked(WgKeypair *keypair, const IpAddr *addr) {
#if WITH_SHORT_HEADERS
void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packet) {
assert(dev_.IsMainOrDataThread());
uint8 *data = packet->data + 1;
size_t bytes_left = packet->size - 1;
WgKeypair *keypair;
@ -832,6 +836,8 @@ void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packe
WG_ACQUIRE_LOCK(keypair->peer->mutex_);
keypair->peer->rx_bytes_ += packet->size;
if (keypair->recv_key_state == WgKeypair::KEY_INVALID)
goto getout_unlock;
@ -896,7 +902,7 @@ void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *key
WgKeypair *curr_keypair = peer->curr_keypair_;
if (curr_keypair && curr_keypair->recv_key_state == WgKeypair::KEY_WANT_REFRESH) {
curr_keypair->recv_key_state = WgKeypair::KEY_DID_REFRESH;
ScheduleNewHandshake(peer);
peer->ScheduleNewHandshake();
}
if (data_size == 0) {
@ -965,6 +971,8 @@ getout:
}
void WireguardProcessor::HandleDataPacket(Packet *packet) {
assert(dev_.IsMainOrDataThread());
uint8 *data = packet->data;
size_t data_size = packet->size;
uint32 key_id = ((MessageData*)data)->receiver_id;
@ -984,6 +992,7 @@ getout:
}
WG_ACQUIRE_LOCK(keypair->peer->mutex_);
keypair->peer->rx_bytes_ += data_size;
if (keypair->recv_key_state == WgKeypair::KEY_INVALID) {
stats_.error_key_id++;
WG_RELEASE_LOCK(keypair->peer->mutex_);
@ -993,6 +1002,8 @@ getout:
WG_RELEASE_LOCK(keypair->peer->mutex_);
goto getout;
} else {
assert(!keypair->peer->marked_for_delete_);
WgPeer::CopyEndpointToPeer_Locked(keypair, &packet->addr);
HandleAuthenticatedDataPacket_WillUnlock(keypair, packet, data + sizeof(MessageData), data_size - sizeof(MessageData) - keypair->auth_tag_length);
}
@ -1119,7 +1130,7 @@ void WireguardProcessor::SecondLoop() {
uint32 mask;
{
WG_SCOPED_LOCK(peer->mutex_);
mask = peer->CheckTimeouts(now);
mask = peer->CheckTimeouts_Locked(now);
if (mask == 0)
continue;
if (mask & WgPeer::ACTION_SEND_KEEPALIVE)

View file

@ -66,6 +66,7 @@ enum InternetBlockState {
};
class WireguardProcessor {
friend class WgConfig;
public:
WireguardProcessor(UdpInterface *udp, TunInterface *tun, ProcessorDelegate *procdel);
~WireguardProcessor();
@ -73,13 +74,14 @@ public:
void SetListenPort(int listen_port);
void AddDnsServer(const IpAddr &sin);
bool SetTunAddress(const WgCidrAddr &addr);
void ClearTunAddress();
void AddExcludedIp(const WgCidrAddr &cidr_addr);
void SetMtu(int mtu);
void SetAddRoutesMode(bool mode);
void SetDnsBlocking(bool dns_blocking);
void SetInternetBlocking(InternetBlockState internet_blocking);
void SetHeaderObfuscation(const char *key);
void HandleTunPacket(Packet *packet);
void HandleUdpPacket(Packet *packet, bool overload);
static bool IsMainThreadPacket(Packet *packet);
@ -91,6 +93,9 @@ public:
bool Start();
bool ConfigureUdp();
bool ConfigureTun();
WgDevice &dev() { return dev_; }
TunInterface::PrePostCommands &prepost() { return pre_post_; }
const WgCidrAddr &tun_addr() { return tun_addr_; }
@ -100,7 +105,6 @@ private:
void DoWriteUdpPacket(Packet *packet);
void WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet);
void SendHandshakeInitiation(WgPeer *peer);
void ScheduleNewHandshake(WgPeer *peer);
void SendKeepalive_Locked(WgPeer *peer);
void SendQueuedPackets_Locked(WgPeer *peer);
@ -110,29 +114,25 @@ private:
void HandleDataPacket(Packet *packet);
void HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet, uint8 *data, size_t data_size);
void HandleShortHeaderFormatPacket(uint32 tag, Packet *packet);
bool CheckIncomingHandshakeRateLimit(Packet *packet, bool overload);
bool HandleIcmpv6NeighborSolicitation(const byte *data, size_t data_size);
void SetupCompressionHeader(WgPacketCompressionVer01 *c);
void NotifyHandshakeComplete();
int listen_port_;
ProcessorDelegate *procdel_;
TunInterface *tun_;
UdpInterface *udp_;
int mtu_;
WgProcessorStats stats_;
uint16 listen_port_;
uint16 mtu_;
bool dns_blocking_;
uint8 internet_blocking_;
bool add_routes_mode_;
bool network_discovery_spoofing_;
bool did_have_first_handshake_;
bool is_started_;
uint8 network_discovery_mac_[6];
WgDevice dev_;
@ -140,14 +140,12 @@ private:
WgCidrAddr tun_addr_;
WgCidrAddr tun6_addr_;
WgProcessorStats stats_;
std::vector<IpAddr> dns_addr_, dns6_addr_;
TunInterface::PrePostCommands pre_post_;
// Queue of things scheduled to run on the main thread.
WG_DECLARE_LOCK(main_thread_scheduled_lock_);
WgPeer *main_thread_scheduled_, **main_thread_scheduled_last_;
uint64 stats_last_bytes_in_, stats_last_bytes_out_;
uint64 stats_last_ts_;

View file

@ -45,12 +45,26 @@ char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]) {
return buf;
}
char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]) {
if (addr.size == 32) {
print_ip_prefix(buf, AF_INET, addr.addr, addr.cidr);
} else if (addr.size == 128) {
print_ip_prefix(buf, AF_INET6, addr.addr, addr.cidr);
} else {
buf[0] = 0;
}
return buf;
}
struct Addr {
byte addr[4];
uint8 cidr;
};
static bool ParseCidrAddr(char *s, WgCidrAddr *out) {
bool ParseCidrAddr(char *s, WgCidrAddr *out) {
char *slash = strchr(s, '/');
if (!slash)
return false;
@ -92,15 +106,6 @@ bool DnsResolver::Resolve(const char *hostname, IpAddr *result) {
char buf[kSizeOfAddress];
memset(result, 0, sizeof(IpAddr));
if (inet_pton(AF_INET6, hostname, &result->sin6.sin6_addr) == 1) {
result->sin.sin_family = AF_INET6;
return true;
}
if (inet_pton(AF_INET, hostname, &result->sin.sin_addr) == 1) {
result->sin.sin_family = AF_INET;
return true;
}
// First check cache
for (auto it = cache_.begin(); it != cache_.end(); ++it) {
@ -145,10 +150,7 @@ bool DnsResolver::Resolve(const char *hostname, IpAddr *result) {
}
}
static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver) {
bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver) {
memset(sin, 0, sizeof(IpAddr));
if (*s == '[') {
char *end = strchr(s, ']');
@ -168,7 +170,11 @@ static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver)
if (!x) return false;
*x = 0;
if (!resolver->Resolve(s, sin)) {
if (inet_pton(AF_INET, s, &sin->sin.sin_addr) == 1) {
sin->sin.sin_family = AF_INET;
} else if (!resolver) {
return false;
} else if (!resolver->Resolve(s, sin)) {
RERROR("Unable to resolve %s", s);
return false;
}
@ -177,18 +183,19 @@ static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver)
}
static bool ParseSockaddrInWithoutPort(char *s, IpAddr *sin, DnsResolver *resolver) {
if (!resolver->Resolve(s, sin)) {
if (inet_pton(AF_INET6, s, &sin->sin6.sin6_addr) == 1) {
sin->sin.sin_family = AF_INET6;
return true;
} else if (inet_pton(AF_INET, s, &sin->sin.sin_addr) == 1) {
sin->sin.sin_family = AF_INET;
return true;
} else if (!resolver->Resolve(s, sin)) {
RERROR("Unable to resolve %s", s);
return false;
}
return true;
}
static bool ParseBase64Key(const char *s, uint8 key[32]) {
size_t size = 32;
return base64_decode((uint8*)s, strlen(s), key, &size) && size == 32;
}
class WgFileParser {
public:
WgFileParser(WireguardProcessor *wg, DnsResolver *resolver) : wg_(wg), dns_resolver_(resolver) {}
@ -197,7 +204,7 @@ public:
void FinishGroup();
struct Peer {
uint8 pub[32];
WgPublicKey pub;
uint8 psk[32];
};
Peer pi_;
@ -206,29 +213,6 @@ public:
bool had_interface_ = false;
};
bool is_space(uint8_t c) {
return c == ' ' || c == '\r' || c == '\n' || c == '\t';
}
void SplitString(char *s, int separator, std::vector<char*> *components) {
for (;;) {
while (is_space(*s)) s++;
char *d = strchr(s, separator);
if (d == NULL) {
if (*s)
components->push_back(s);
return;
}
*d = 0;
char *e = d;
while (e > s && is_space(e[-1]))
*--e = 0;
components->push_back(s);
s = d + 1;
}
}
static bool ParseBoolean(const char *str, bool *value) {
if (_stricmp(str, "true") == 0 ||
_stricmp(str, "yes") == 0 ||
@ -285,7 +269,7 @@ static int ParseCipherSuite(const char *cipher) {
void WgFileParser::FinishGroup() {
if (peer_) {
peer_->Initialize(pi_.pub, pi_.psk);
peer_->SetPublicKey(pi_.pub);
peer_ = NULL;
}
}
@ -303,7 +287,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
if (!ParseBase64Key(value, binkey))
return false;
had_interface_ = true;
wg_->dev().Initialize(binkey);
wg_->dev().SetPrivateKey(binkey);
} else if (strcmp(key, "ListenPort") == 0) {
wg_->SetListenPort(atoi(value));
} else if (strcmp(key, "Address") == 0) {
@ -394,11 +378,12 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
return true;
}
if (strcmp(key, "PublicKey") == 0) {
if (!ParseBase64Key(value, pi_.pub))
if (!ParseBase64Key(value, pi_.pub.bytes))
return false;
} else if (strcmp(key, "PresharedKey") == 0) {
if (!ParseBase64Key(value, pi_.psk))
return false;
peer_->SetPresharedKey(pi_.psk);
} else if (strcmp(key, "AllowedIPs") == 0) {
SplitString(value, ',', &ss);
for (size_t i = 0; i < ss.size(); i++) {
@ -412,7 +397,8 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
return false;
peer_->SetEndpoint(sin);
} else if (strcmp(key, "PersistentKeepalive") == 0) {
peer_->SetPersistentKeepalive(atoi(value));
if (!peer_->SetPersistentKeepalive(atoi(value)))
return false;
} else if (strcmp(key, "AllowMulticast") == 0) {
bool b;
if (!ParseBoolean(value, &b))
@ -524,3 +510,154 @@ bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, DnsR
fclose(f);
return true;
}
static void CmsgAppendFmt(std::string *result, const char *fmt, ...) {
va_list va;
char buf[256];
va_start(va, fmt);
vsnprintf(buf, sizeof(buf), fmt, va);
(*result) += buf;
(*result) += '\n';
va_end(va);
}
static void CmsgAppendHex(std::string *result, const char *key, const void *data, size_t data_size) {
char *tmp = (char*)alloca(data_size * 2 + 2);
PrintHexString(data, data_size, tmp + 1);
tmp[0] = '=';
tmp[data_size * 2 + 1] = '\n';
(*result) += key;
result->append(tmp, data_size * 2 + 2);
}
void WgConfig::HandleConfigurationProtocolGet(WireguardProcessor *proc, std::string *result) {
char buf[kSizeOfAddress];
CmsgAppendHex(result, "private_key", proc->dev_.s_priv_, sizeof(proc->dev_.s_priv_));
if (proc->listen_port_)
CmsgAppendFmt(result, "listen_port=%d", proc->listen_port_);
if (proc->tun_addr_.size == 32)
CmsgAppendFmt(result, "address=%s", PrintWgCidrAddr(proc->tun_addr_, buf));
if (proc->tun6_addr_.size == 128)
CmsgAppendFmt(result, "address=%s", PrintWgCidrAddr(proc->tun6_addr_, buf));
for (WgPeer *peer = proc->dev_.peers_; peer; peer = peer->next_peer_) {
WG_SCOPED_LOCK(peer->lock_);
CmsgAppendHex(result, "public_key", peer->s_remote_.bytes, sizeof(peer->s_remote_));
if (!IsOnlyZeros(peer->preshared_key_, sizeof(peer->preshared_key_)))
CmsgAppendHex(result, "preshared_key", peer->preshared_key_, sizeof(peer->preshared_key_));
if (peer->tx_bytes_ | peer->rx_bytes_)
CmsgAppendFmt(result, "tx_bytes=%lld\nrx_bytes=%lld", peer->tx_bytes_, peer->rx_bytes_);
for (auto it = peer->allowed_ips_.begin(); it != peer->allowed_ips_.end(); ++it)
CmsgAppendFmt(result, "allowed_ip=%s", PrintWgCidrAddr(*it, buf));
if (peer->persistent_keepalive_ms_)
CmsgAppendFmt(result, "persistent_keepalive_interval=%d", peer->persistent_keepalive_ms_ / 1000);
if (peer->endpoint_.sin.sin_family == AF_INET)
CmsgAppendFmt(result, "endpoint=%s:%d", PrintIpAddr(peer->endpoint_, buf), htons(peer->endpoint_.sin.sin_port));
else if (peer->endpoint_.sin.sin_family == AF_INET6)
CmsgAppendFmt(result, "endpoint=[%s]:%d", PrintIpAddr(peer->endpoint_, buf), htons(peer->endpoint_.sin6.sin6_port));
if (peer->last_complete_handskake_timestamp_) {
uint64 millis_since = OsGetMilliseconds() - peer->last_complete_handskake_timestamp_;
uint64 when = time(NULL) - millis_since / 1000;
CmsgAppendFmt(result, "last_handshake_time_sec=%lld", when);
}
}
CmsgAppendFmt(result, "protocol_version=1");
}
bool WgConfig::HandleConfigurationProtocolMessage(WireguardProcessor *proc, const std::string &&message, std::string *result) {
std::string message_copy(std::move(message));
std::vector<std::pair<char *, char*>> kv;
bool is_set = false;
bool did_set_address = false;
WgPeer *peer = NULL;
WgCidrAddr cidr_addr;
IpAddr sin;
uint8 buf32[32];
assert(proc->dev().IsMainThread());
result->clear();
if (!ParseConfigKeyValue(&message_copy[0], &kv))
return false;
for (auto it : kv) {
char *key = it.first, *value = it.second;
if (strcmp(key, "get") == 0) {
if (strcmp(value, "1") != 0)
goto getout_fail;
HandleConfigurationProtocolGet(proc, result);
break;
} else if (strcmp(key, "set") == 0) {
if (strcmp(value, "1") != 0)
goto getout_fail;
is_set = true;
} else if (is_set) {
if (strcmp(key, "private_key") == 0) {
if (!ParseHexString(value, buf32, 32)) goto getout_fail;
proc->dev_.SetPrivateKey(buf32);
} else if (strcmp(key, "listen_port") == 0) {
int new_port = atoi(value);
proc->SetListenPort(new_port);
} else if (strcmp(key, "replace_peers") == 0) {
if (strcmp(value, "true") != 0) goto getout_fail;
proc->dev_.RemoveAllPeers();
} else if (strcmp(key, "address") == 0) {
if (!ParseCidrAddr(value, &cidr_addr)) goto getout_fail;
if (!did_set_address) {
did_set_address = true;
proc->ClearTunAddress();
}
if (!proc->SetTunAddress(cidr_addr)) goto getout_fail;
} else if (strcmp(key, "public_key") == 0) {
WgPublicKey pubkey;
if (!ParseHexString(value, pubkey.bytes, 32)) goto getout_fail;
peer = proc->dev_.GetPeerFromPublicKey(pubkey);
if (!peer) {
peer = proc->dev_.AddPeer();
peer->SetPublicKey(pubkey);
}
} else if (peer != NULL) {
if (strcmp(key, "remove") == 0) {
if (strcmp(value, "true") != 0) goto getout_fail;
peer->RemovePeer();
peer = NULL;
} else if (strcmp(key, "preshared_key") == 0) {
if (!ParseHexString(value, buf32, 32)) goto getout_fail;
peer->SetPresharedKey(buf32);
} else if (strcmp(key, "endpoint") == 0) {
if (!ParseSockaddrInWithPort(value, &sin, NULL)) goto getout_fail;
peer->SetEndpoint(sin);
} else if (strcmp(key, "persistent_keepalive_interval") == 0) {
if (!peer->SetPersistentKeepalive(atoi(value)))
goto getout_fail;
} else if (strcmp(key, "replace_allowed_ips") == 0) {
if (strcmp(value, "true") != 0) goto getout_fail;
peer->RemoveAllIps();
} else if (strcmp(key, "allowed_ip") == 0) {
if (!ParseCidrAddr(value, &cidr_addr)) goto getout_fail;
peer->AddIp(cidr_addr);
}
}
} else {
goto getout_fail;
}
}
// reconfigure the tun interface?
if (did_set_address) {
proc->ConfigureTun();
}
result->append("errno=0\n\n");
return true;
getout_fail:
(*result) = "errno=1\n\n";
return false;
}

View file

@ -30,11 +30,19 @@ private:
};
class WgConfig {
public:
static bool HandleConfigurationProtocolMessage(WireguardProcessor *proc, const std::string &&message, std::string *result);
private:
static void HandleConfigurationProtocolGet(WireguardProcessor *proc, std::string *result);
};
bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, DnsResolver *dns_resolver);
#define kSizeOfAddress 64
const char *print_ip_prefix(char buf[kSizeOfAddress], int family, const void *ip, int prefixlen);
char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]);
char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]);
bool ParseCidrAddr(char *s, WgCidrAddr *out);
#endif // TINYVPN_TINYVPN_H_

View file

@ -54,18 +54,27 @@ bool ReplayDetector::CheckReplay(uint64 seq_nr) {
WgDevice::WgDevice() {
peers_ = NULL;
last_peer_ptr_ = &peers_;
delegate_ = NULL;
header_obfuscation_ = false;
is_private_key_initialized_ = false;
next_rng_slot_ = 0;
main_thread_scheduled_ = NULL;
main_thread_scheduled_last_ = &main_thread_scheduled_;
memset(&compression_header_, 0, sizeof(compression_header_));
low_resolution_timestamp_ = cookie_secret_timestamp_ = OsGetMilliseconds();
OsGetRandomBytes(cookie_secret_, sizeof(cookie_secret_));
OsGetRandomBytes((uint8*)random_number_input_, sizeof(random_number_input_));
SetCurrentThreadAsMainThread();
main_thread_id_ = GetCurrentThreadId();
memset(s_priv_, 0, sizeof(s_priv_));
memset(s_pub_, 0, sizeof(s_pub_));
}
WgDevice::~WgDevice() {
assert(IsMainThread());
RemoveAllPeers();
}
void WgDevice::SecondLoop(uint64 now) {
@ -151,7 +160,8 @@ static inline void ComputeHKDF2DH(uint8 ci[WG_HASH_LEN], uint8 k[WG_SYMMETRIC_KE
memzero_crypto(dh, sizeof(dh));
}
void WgDevice::Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]) {
void WgDevice::SetPrivateKey(const uint8 private_key[WG_PUBLIC_KEY_LEN]) {
assert(IsMainThread());
// Derive the public key from the private key.
memcpy(s_priv_, private_key, sizeof(s_priv_));
curve25519_donna(s_pub_, s_priv_, kCurve25519Basepoint);
@ -162,26 +172,31 @@ void WgDevice::Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]) {
kLabelCookie, sizeof(kLabelCookie), s_pub_, sizeof(s_pub_));
BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_),
kLabelMac1, sizeof(kLabelMac1), s_pub_, sizeof(s_pub_));
is_private_key_initialized_ = true;
// Recompute peer data because it depends on my privkey
for (WgPeer *peer = peers_; peer; peer = peer->next_peer_)
peer->SetPublicKey(peer->s_remote_);
}
WgPeer *WgDevice::AddPeer() {
assert(IsMainThread());
WgPeer *peer = new WgPeer(this);
WgPeer **pp = &peers_;
while (*pp)
pp = &(*pp)->next_peer_;
*pp = peer;
return peer;
}
WgPeer *WgDevice::GetPeerFromPublicKey(uint8 public_key[WG_PUBLIC_KEY_LEN]) {
void WgDevice::RemoveAllPeers() {
assert(IsMainThread());
// todo: add O(1) lookup
for (WgPeer *peer = peers_; peer; peer = peer->next_peer_) {
if (memcmp(peer->s_remote_, public_key, WG_PUBLIC_KEY_LEN) == 0)
return peer;
}
return NULL;
while (peers_)
peers_->RemovePeer();
}
WgPeer *WgDevice::GetPeerFromPublicKey(const WgPublicKey &pubkey) {
assert(IsMainThread());
auto it = peer_id_lookup_.find(pubkey);
return (it != peer_id_lookup_.end()) ? it->second : NULL;
}
bool WgDevice::CheckCookieMac1(Packet *packet) {
@ -230,6 +245,7 @@ void WgDevice::CreateCookieMessage(MessageHandshakeCookie *dst, Packet *packet,
}
void WgDevice::EraseKeypairAddrEntry_Locked(WgKeypair *kp) {
// todo: figure out how to make this multithread safe.
WgAddrEntry *ae = kp->addr_entry;
assert(ae->ref_count >= 1);
@ -313,7 +329,6 @@ void WgDevice::SetHeaderObfuscation(const char *key) {
#endif // WITH_HEADER_OBFUSCATION
}
WgPeer::WgPeer(WgDevice *dev) {
assert(dev->IsMainThread());
dev_ = dev;
@ -323,6 +338,7 @@ WgPeer::WgPeer(WgDevice *dev) {
expect_cookie_reply_ = false;
has_mac2_cookie_ = false;
pending_keepalive_ = false;
marked_for_delete_ = false;
allow_multicast_through_peer_ = false;
allow_endpoint_change_ = true;
supports_handshake_extensions_ = true;
@ -331,6 +347,8 @@ WgPeer::WgPeer(WgDevice *dev) {
last_handshake_init_recv_timestamp_ = 0;
last_complete_handskake_timestamp_ = 0;
persistent_keepalive_ms_ = 0;
rx_bytes_ = 0;
tx_bytes_ = 0;
timers_ = 0;
first_queued_packet_ = NULL;
last_queued_packet_ptr_ = &first_queued_packet_;
@ -343,15 +361,66 @@ WgPeer::WgPeer(WgDevice *dev) {
memset(last_timestamp_, 0, sizeof(last_timestamp_));
ipv4_broadcast_addr_ = 0xffffffff;
memset(features_, 0, sizeof(features_));
memset(preshared_key_, 0, sizeof(preshared_key_));
memset(&s_remote_, 0, sizeof(s_remote_));
// Insert into the parent's linked list
*dev_->last_peer_ptr_ = this;
dev_->last_peer_ptr_ = &next_peer_;
}
WgPeer::~WgPeer() {
// do not delete this directly, instead call RemovePeer
assert(marked_for_delete_);
assert(dev_->IsMainThread());
assert(curr_keypair_ == NULL && next_keypair_ == NULL && prev_keypair_ == NULL);
assert(local_key_id_during_hs_ == 0);
assert(first_queued_packet_ == NULL);
}
void WgPeer::DelayedDelete(void *x) {
WgPeer *peer = (WgPeer*)x;
assert(peer->dev_->IsMainThread());
if (peer->main_thread_scheduled_ != 0) {
WG_ACQUIRE_LOCK(peer->dev_->main_thread_scheduled_lock_);
// Unlink myself from the main thread scheduled list
for (WgPeer **pp = &peer->dev_->main_thread_scheduled_; *pp; pp = &(*pp)->main_thread_scheduled_next_) {
if (*pp == peer) {
*pp = peer->main_thread_scheduled_next_;
break;
}
}
WG_RELEASE_LOCK(peer->dev_->main_thread_scheduled_lock_);
}
delete peer;
}
void WgPeer::RemovePeer() {
assert(dev_->IsMainThread());
assert(!marked_for_delete_);
// Find and unlink the peer from the parent's peer list
WgPeer **pp = &dev_->peers_;
while (*pp != this)
pp = &(*pp)->next_peer_;
if ((*pp = next_peer_) == NULL)
dev_->last_peer_ptr_ = pp;
RemoveAllIps();
dev_->peer_id_lookup_.erase(s_remote_);
WG_ACQUIRE_LOCK(mutex_);
marked_for_delete_ = true;
ClearKeys_Locked();
ClearHandshake_Locked();
ClearPacketQueue_Locked();
WG_RELEASE_LOCK(mutex_);
// The WgPeer instance may still be accessible from
// worker threads that already started processing a packet,
// so defer the actual delete of it.
dev_->delayed_delete_.Add(&WgPeer::DelayedDelete, this);
}
void WgPeer::ClearKeys_Locked() {
@ -382,21 +451,55 @@ void WgPeer::ClearPacketQueue_Locked() {
num_queued_packets_ = 0;
}
void WgPeer::Initialize(const uint8 spub[WG_PUBLIC_KEY_LEN], const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]) {
// Optionally use a preshared key, it defaults to all zeros.
void WgPeer::AddPacketToPeerQueue_Locked(Packet *packet) {
assert(IsPeerLocked());
assert(!marked_for_delete_);
// Keep only the first MAX_QUEUED_PACKETS packets.
while (num_queued_packets_ >= MAX_QUEUED_PACKETS_PER_PEER) {
Packet *packet = first_queued_packet_;
first_queued_packet_ = packet->next;
num_queued_packets_--;
FreePacket(packet);
}
// Add the packet to the out queue that will get sent once handshake completes
*last_queued_packet_ptr_ = packet;
last_queued_packet_ptr_ = &packet->next;
packet->next = NULL;
num_queued_packets_++;
}
void WgPeer::SetPublicKey(const WgPublicKey &spub) {
assert(dev_->IsMainThread());
assert(IsOnlyZeros(s_remote_.bytes, sizeof(s_remote_.bytes)) ||
memcmp(s_remote_.bytes, spub.bytes, sizeof(s_remote_.bytes)) == 0);
s_remote_ = spub;
dev_->peer_id_lookup_[s_remote_] = this;
if (!dev_->is_private_key_initialized_)
return;
// Precompute: s_priv_pub_ := DH(sprivr, spubi)
curve25519_donna(s_priv_pub_, dev_->s_priv_, s_remote_.bytes);
// Precompute: precomputed_cookie_key_ := HASH(LABEL-COOKIE || Spub_m)
// precomputed_mac1_key_ := HASH(MAC1-COOKIE || Spub_m)
BlakeX2(precomputed_cookie_key_, sizeof(precomputed_cookie_key_),
kLabelCookie, sizeof(kLabelCookie), s_remote_.bytes, WG_PUBLIC_KEY_LEN);
BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_),
kLabelMac1, sizeof(kLabelMac1), s_remote_.bytes, WG_PUBLIC_KEY_LEN);
// Remove the peer's keys
WG_ACQUIRE_LOCK(mutex_);
ClearKeys_Locked();
ClearHandshake_Locked();
WG_RELEASE_LOCK(mutex_);
}
void WgPeer::SetPresharedKey(const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]) {
if (preshared_key)
memcpy(preshared_key_, preshared_key, sizeof(preshared_key_));
else
memset(preshared_key_, 0, sizeof(preshared_key_));
// Precompute: s_priv_pub_ := DH(sprivr, spubi)
memcpy(s_remote_, spub, sizeof(s_remote_));
curve25519_donna(s_priv_pub_, dev_->s_priv_, s_remote_);
// Precompute: precomputed_cookie_key_ := HASH(LABEL-COOKIE || Spub_m)
// precomputed_mac1_key_ := HASH(MAC1-COOKIE || Spub_m)
BlakeX2(precomputed_cookie_key_, sizeof(precomputed_cookie_key_),
kLabelCookie, sizeof(kLabelCookie), spub, WG_PUBLIC_KEY_LEN);
BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_),
kLabelMac1, sizeof(kLabelMac1), spub, WG_PUBLIC_KEY_LEN);
}
// run on the client
@ -411,7 +514,7 @@ void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) {
// Hi := HASH(Ci || IDENTIFIER)
memcpy(hs_.hi, kWgInitHash, sizeof(hs_.hi));
// Hi := HASH(Hi || Spub_r)
BlakeMix(hs_.hi, s_remote_, sizeof(s_remote_));
BlakeMix(hs_.hi, s_remote_.bytes, sizeof(s_remote_));
// (Epriv_r, Epub_r) := DH-GENERATE()
// msg.ephemeral = Epub_r
OsGetRandomBytes(hs_.e_priv, sizeof(hs_.e_priv));
@ -422,7 +525,7 @@ void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) {
// Hi := HASH(Hi || msg.ephemeral)
BlakeMix(hs_.hi, dst->ephemeral, sizeof(dst->ephemeral));
// (Ci, K) := KDF2(Ci, DH(epriv, spub_r))
ComputeHKDF2DH(hs_.ci, k, hs_.e_priv, s_remote_);
ComputeHKDF2DH(hs_.ci, k, hs_.e_priv, s_remote_.bytes);
// msg.static = AEAD(K, 0, Spub_i, Hi)
chacha20poly1305_encrypt(dst->static_enc, dev_->s_pub_, sizeof(dev_->s_pub_), hs_.hi, sizeof(hs_.hi), 0, k);
// Hi := HASH(Hi || msg.static)
@ -461,7 +564,7 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
uint8 e_priv[WG_PUBLIC_KEY_LEN];
};
union {
uint8 spubi[WG_PUBLIC_KEY_LEN];
WgPublicKey spubi;
uint8 e_remote[WG_PUBLIC_KEY_LEN];
uint8 hi2[WG_HASH_LEN];
};
@ -488,13 +591,13 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
// (Ci, K) := KDF2(Ci, DH(spriv, msg.ephemeral))
ComputeHKDF2DH(ci, k, dev->s_priv_, src->ephemeral);
// Spub_i = AEAD_DEC(K, 0, msg.static, Hi)
if (!chacha20poly1305_decrypt(spubi, src->static_enc, sizeof(src->static_enc), hi, sizeof(hi), 0, k))
if (!chacha20poly1305_decrypt(spubi.bytes, src->static_enc, sizeof(src->static_enc), hi, sizeof(hi), 0, k))
goto getout;
// Hi := HASH(Hi || msg.static)
BlakeMix(hi, src->static_enc, sizeof(src->static_enc));
// Lookup the peer with this ID
while ((peer = dev->GetPeerFromPublicKey(spubi)) == NULL) {
if (dev->delegate_ == NULL || !dev->delegate_->HandleUnknownPeerId(spubi, packet))
if (dev->delegate_ == NULL || !dev->delegate_->HandleUnknownPeerId(spubi.bytes, packet))
goto getout;
}
// (Ci, K) := KDF2(Ci, DH(sprivr, spubi))
@ -538,7 +641,7 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
// Ci : = KDF2(Ci, DH(epriv, epub))
ComputeHKDF2DH(ci, NULL, e_priv, e_remote);
// Ci : = KDF2(Ci, DH(epriv, spub))
ComputeHKDF2DH(ci, NULL, e_priv, peer->s_remote_);
ComputeHKDF2DH(ci, NULL, e_priv, peer->s_remote_.bytes);
// (Ci, T, K) := KDF3(Ci, Q)
blake2s_hkdf(ci, sizeof(ci), t, sizeof(t), k, sizeof(k), peer->preshared_key_, sizeof(preshared_key_), ci, WG_HASH_LEN);
// Hr := HASH(Hr || T)
@ -548,11 +651,6 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
keypair = WgPeer::CreateNewKeypair(false, ci, remote_key_id, extbuf + WG_TIMESTAMP_LEN, extfield_size);
if (keypair) {
WG_ACQUIRE_LOCK(peer->mutex_);
peer->InsertKeypairInPeer_Locked(keypair);
peer->OnHandshakeAuthComplete();
WG_RELEASE_LOCK(peer->mutex_);
dst->sender_key_id = dev->InsertInKeyIdLookup(peer, keypair);
size_t extfield_out_size = 0;
@ -560,8 +658,17 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
if (extfield_size)
extfield_out_size = peer->WriteHandshakeExtension(dst->empty_enc, keypair);
#endif // WITH_HANDSHAKE_EXT
uint32 orig_packet_size = packet->size;
packet->size = (unsigned)(sizeof(MessageHandshakeResponse) + extfield_out_size);
WG_ACQUIRE_LOCK(peer->mutex_);
peer->rx_bytes_ += orig_packet_size;
peer->tx_bytes_ += packet->size;
peer->InsertKeypairInPeer_Locked(keypair);
peer->OnHandshakeAuthComplete();
WG_RELEASE_LOCK(peer->mutex_);
// msg.empty := AEAD(K, 0, "", Hr)
chacha20poly1305_encrypt(dst->empty_enc, dst->empty_enc, extfield_out_size, hi, sizeof(hi), 0, k);
// Hr := HASH(Hr || "")
@ -624,6 +731,7 @@ WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packe
peer_and_keypair->second = keypair;
WG_ACQUIRE_LOCK(peer->mutex_);
peer->rx_bytes_ += packet->size;
peer->InsertKeypairInPeer_Locked(keypair);
WG_RELEASE_LOCK(peer->mutex_);
@ -651,6 +759,9 @@ void WgPeer::ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCo
if (!xchacha20poly1305_decrypt(cookie, src->cookie_enc, sizeof(src->cookie_enc),
peer->sent_mac1_, sizeof(peer->sent_mac1_), src->nonce, peer->precomputed_cookie_key_))
return;
WG_ACQUIRE_LOCK(peer->mutex_);
peer->rx_bytes_ += sizeof(MessageHandshakeCookie);
WG_RELEASE_LOCK(peer->mutex_);
peer->expect_cookie_reply_ = false;
peer->has_mac2_cookie_ = true;
peer->mac2_cookie_timestamp_ = OsGetMilliseconds();
@ -796,7 +907,7 @@ bool WgKeypairParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size
#endif // WITH_HANDSHAKE_EXT
static void ActualFreeKeypair(void *x) {
static void WgKeypairDelayedDelete(void *x) {
WgKeypair *t = (WgKeypair*)x;
if (t->aes_gcm128_context_)
free(t->aes_gcm128_context_);
@ -808,17 +919,18 @@ void WgPeer::DeleteKeypair(WgKeypair **kp) {
*kp = NULL;
if (t) {
assert(t->peer->IsPeerLocked());
WgDevice *dev = t->peer->dev_;
if (t->addr_entry) {
WG_SCOPED_RWLOCK_EXCLUSIVE(dev_->addr_entry_lookup_lock_);
dev_->EraseKeypairAddrEntry_Locked(t);
WG_SCOPED_RWLOCK_EXCLUSIVE(dev->addr_entry_lookup_lock_);
dev->EraseKeypairAddrEntry_Locked(t);
}
if (t->local_key_id) {
WG_SCOPED_RWLOCK_EXCLUSIVE(dev_->key_id_lookup_lock_);
dev_->key_id_lookup_.erase(t->local_key_id);
WG_SCOPED_RWLOCK_EXCLUSIVE(dev->key_id_lookup_lock_);
dev->key_id_lookup_.erase(t->local_key_id);
t->local_key_id = 0;
}
t->recv_key_state = WgKeypair::KEY_INVALID;
dev_->delayed_delete_.Add(&ActualFreeKeypair, t);
dev->delayed_delete_.Add(&WgKeypairDelayedDelete, t);
}
}
@ -1029,8 +1141,8 @@ void WgPeer::OnHandshakeFullyComplete() {
}
// Check if any of the timeouts have expired
uint32 WgPeer::CheckTimeouts(uint64 now) {
assert(IsPeerLocked());
uint32 WgPeer::CheckTimeouts_Locked(uint64 now) {
assert(dev_->IsMainThread() && IsPeerLocked());
uint32 t, rv = 0;
@ -1096,7 +1208,7 @@ uint32 WgPeer::CheckTimeouts(uint64 now) {
// Check all key stuff here to avoid calling possibly expensive timestamp routines in the packet handler
void WgPeer::CheckAndUpdateTimeOfNextKeyEvent(uint64 now) {
assert(IsPeerLocked());
assert(dev_->IsMainThread() && IsPeerLocked());
uint64 next_time = UINT64_MAX;
uint32 rv = 0;
@ -1142,34 +1254,60 @@ void WgPeer::SetEndpoint(const IpAddr &sin) {
endpoint_ = sin;
}
void WgPeer::SetPersistentKeepalive(int persistent_keepalive_secs) {
if (persistent_keepalive_secs < 10 || persistent_keepalive_secs > 10000)
return;
bool WgPeer::SetPersistentKeepalive(int persistent_keepalive_secs) {
if (persistent_keepalive_secs < 0 || persistent_keepalive_secs > 65535)
return false;
persistent_keepalive_ms_ = persistent_keepalive_secs * 1000;
return true;
}
bool WgCidrAddrEquals(const WgCidrAddr &a, const WgCidrAddr &b) {
return (a.size == b.size && a.cidr == b.cidr && memcmp(a.addr, b.addr, a.size >> 3) == 0);
}
bool WgPeer::AddIp(const WgCidrAddr &cidr_addr) {
WgPeer *old_peer;
assert(dev_->IsMainThread());
if (cidr_addr.size == 32) {
if (cidr_addr.cidr > 32)
return false;
WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
dev_->ip_to_peer_map_.InsertV4(ReadBE32(cidr_addr.addr), cidr_addr.cidr, this);
old_peer = (WgPeer*)dev_->ip_to_peer_map_.InsertV4(ReadBE32(cidr_addr.addr), cidr_addr.cidr, this);
WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
allowed_ips_.push_back(cidr_addr);
return true;
} else if (cidr_addr.size == 128) {
if (cidr_addr.cidr > 128)
return false;
WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
dev_->ip_to_peer_map_.InsertV6(cidr_addr.addr, cidr_addr.cidr, this);
old_peer = (WgPeer*)dev_->ip_to_peer_map_.InsertV6(cidr_addr.addr, cidr_addr.cidr, this);
WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
allowed_ips_.push_back(cidr_addr);
return true;
} else {
return false;
}
if (old_peer) {
for (auto it = old_peer->allowed_ips_.begin(); it != old_peer->allowed_ips_.end(); ++it) {
if (WgCidrAddrEquals(*it, cidr_addr)) {
old_peer->allowed_ips_.erase(it);
break;
}
}
}
allowed_ips_.push_back(cidr_addr);
return true;
}
void WgPeer::RemoveAllIps() {
assert(dev_->IsMainThread());
WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
for (auto it = allowed_ips_.begin(); it != allowed_ips_.end(); ++it) {
if (it->size == 32) {
dev_->ip_to_peer_map_.RemoveV4(ReadBE32(it->addr), it->cidr);
} else if (it->size == 128) {
dev_->ip_to_peer_map_.RemoveV6(it->addr, it->cidr);
}
}
WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
allowed_ips_.clear();
}
void WgPeer::SetAllowMulticast(bool allow) {
@ -1196,6 +1334,18 @@ bool WgPeer::AddCipher(int cipher) {
return true;
}
void WgPeer::ScheduleNewHandshake() {
// Note, it's possible that the peer has already been marked for delete
if (main_thread_scheduled_.fetch_or(WgPeer::kMainThreadScheduled_ScheduleHandshake) == 0) {
main_thread_scheduled_next_ = NULL;
WG_ACQUIRE_LOCK(dev_->main_thread_scheduled_lock_);
*dev_->main_thread_scheduled_last_ = this;
dev_->main_thread_scheduled_last_ = &main_thread_scheduled_next_;
WG_RELEASE_LOCK(dev_->main_thread_scheduled_lock_);
// todo: in multithreaded impl need to trigger |RunAllMainThreadScheduled| to get called
}
}
WgRateLimit::WgRateLimit() {
key1_[0] = key1_[1] = 1;
key2_[0] = key2_[1] = 1;

View file

@ -11,7 +11,7 @@
#include <vector>
#include <unordered_map>
#include <atomic>
#include <string.h>
// Threading macros that enable locks only in MT builds
#if WITH_WG_THREADING
#define WG_SCOPED_LOCK(name) ScopedLock scoped_lock(&name)
@ -25,6 +25,7 @@
#define WG_RELEASE_RWLOCK_EXCLUSIVE(name) name.ReleaseExclusive()
#define WG_SCOPED_RWLOCK_SHARED(name) ScopedLockShared scoped_lock(&name)
#define WG_SCOPED_RWLOCK_EXCLUSIVE(name) ScopedLockExclusive scoped_lock(&name)
#define WG_IF_LOCKS_ENABLED_ELSE(expr, def) (expr)
#else // WITH_WG_THREADING
#define WG_SCOPED_LOCK(name)
#define WG_ACQUIRE_LOCK(name)
@ -37,6 +38,7 @@
#define WG_RELEASE_RWLOCK_EXCLUSIVE(name)
#define WG_SCOPED_RWLOCK_SHARED(name)
#define WG_SCOPED_RWLOCK_EXCLUSIVE(name)
#define WG_IF_LOCKS_ENABLED_ELSE(expr, def) (def)
#endif // WITH_WG_THREADING
enum ProtocolTimeouts {
@ -77,6 +79,7 @@ enum MessageFieldSizes {
WG_MAC_LEN = 16,
WG_TIMESTAMP_LEN = 12,
WG_SIPHASH_KEY_LEN = 16,
WG_PUBLIC_KEY_LEN_BASE64 = 44,
};
enum {
@ -194,11 +197,9 @@ struct WgPacketCompressionVer01 {
};
STATIC_ASSERT(sizeof(WgPacketCompressionVer01) == 24, WgPacketCompressionVer01_wrong_size);
struct WgKeypair;
class WgPeer;
class WgRateLimit {
public:
WgRateLimit();
@ -260,10 +261,26 @@ struct WgAddrEntry {
struct ScramblerSiphashKeys {
uint64 keys[4];
};
union WgPublicKey {
uint8 bytes[WG_PUBLIC_KEY_LEN];
uint64 u64[WG_PUBLIC_KEY_LEN / 8];
friend bool operator==(const WgPublicKey &a, const WgPublicKey &b) {
return memcmp(a.bytes, b.bytes, WG_PUBLIC_KEY_LEN) == 0;
}
};
struct WgPublicKeyHasher {
size_t operator()(const WgPublicKey&a) const {
uint64 rv = a.u64[0] ^ a.u64[1] ^ a.u64[2] ^ a.u64[3];
return (size_t)(rv ^ (rv >> 32));
}
};
class WgDevice {
friend class WgPeer;
friend class WireguardProcessor;
friend class WgConfig;
public:
// Can be used to customize the behavior of WgDevice
@ -278,12 +295,15 @@ public:
WgDevice();
~WgDevice();
// Initialize with the private key, precompute all internal keys etc.
void Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]);
// Configure with the private key, precompute all internal keys etc.
void SetPrivateKey(const uint8 private_key[WG_PUBLIC_KEY_LEN]);
// Create a new peer
WgPeer *AddPeer();
// Remove all peers
void RemoveAllPeers();
// Setup header obfuscation
void SetHeaderObfuscation(const char *key);
@ -303,17 +323,19 @@ public:
WgRateLimit *rate_limiter() { return &rate_limiter_; }
std::unordered_map<uint64, WgAddrEntry*> &addr_entry_map() { return addr_entry_lookup_; }
WgPacketCompressionVer01 *compression_header() { return &compression_header_; }
bool is_private_key_initialized() { return is_private_key_initialized_; }
bool IsMainThread() { return CurrentThreadIdEquals(main_thread_id_); }
void SetCurrentThreadAsMainThread() { main_thread_id_ = GetCurrentThreadId(); }
bool IsMainOrDataThread() { return CurrentThreadIdEquals(main_thread_id_) || WG_IF_LOCKS_ENABLED_ELSE(delayed_delete_.enabled(), false); }
void SetDelegate(Delegate *del) { delegate_ = del; }
private:
std::pair<WgPeer*, WgKeypair*> *LookupPeerInKeyIdLookup(uint32 key_id);
WgKeypair *LookupKeypairByKeyId(uint32 key_id);
WgKeypair *LookupKeypairInAddrEntryMap(uint64 addr, uint32 slot);
// Return the peer matching the |public_key| or NULL
WgPeer *GetPeerFromPublicKey(uint8 public_key[WG_PUBLIC_KEY_LEN]);
WgPeer *GetPeerFromPublicKey(const WgPublicKey &pubkey);
// Create a cookie by inspecting the source address of the |packet|
void MakeCookie(uint8 cookie[WG_COOKIE_LEN], Packet *packet);
// Insert a new entry in |key_id_lookup_|
@ -330,7 +352,7 @@ private:
WG_DECLARE_RWLOCK(ip_to_peer_map_lock_);
// For enumerating all peers
WgPeer *peers_;
WgPeer *peers_, **last_peer_ptr_;
// For hooking
Delegate *delegate_;
@ -346,12 +368,22 @@ private:
std::unordered_map<uint64, WgAddrEntry*> addr_entry_lookup_;
WG_DECLARE_RWLOCK(addr_entry_lookup_lock_);
// Mapping from peer id to peer. This may be accessed only from MT.
std::unordered_map<WgPublicKey, WgPeer*, WgPublicKeyHasher> peer_id_lookup_;
// Queue of things scheduled to run on the main thread.
WG_DECLARE_LOCK(main_thread_scheduled_lock_);
WgPeer *main_thread_scheduled_, **main_thread_scheduled_last_;
// Counter for generating new indices in |keypair_lookup_|
uint8 next_rng_slot_;
// Whether packet obfuscation is enabled
bool header_obfuscation_;
// Whether a private key has been setup for the device
bool is_private_key_initialized_;
ThreadId main_thread_id_;
uint64 low_resolution_timestamp_;
@ -382,15 +414,16 @@ private:
class WgPeer {
friend class WgDevice;
friend class WireguardProcessor;
friend class WgConfig;
friend bool WgKeypairParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size_t data_size);
friend void WgKeypairSetupCompressionExtension(WgKeypair *keypair, const WgPacketCompressionVer01 *remotec);
public:
explicit WgPeer(WgDevice *dev);
~WgPeer();
void Initialize(const uint8 spub[WG_PUBLIC_KEY_LEN], const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]);
void SetPersistentKeepalive(int persistent_keepalive_secs);
void SetPublicKey(const WgPublicKey &spub);
void SetPresharedKey(const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]);
bool SetPersistentKeepalive(int persistent_keepalive_secs);
void SetEndpoint(const IpAddr &sin);
void SetAllowMulticast(bool allow);
@ -398,15 +431,14 @@ public:
bool AddCipher(int cipher);
void SetCipherPrio(bool prio) { cipher_prio_ = prio; }
bool AddIp(const WgCidrAddr &cidr_addr);
void RemoveAllIps();
static WgPeer *ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet);
static WgPeer *ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packet);
static void ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCookie *src);
void CreateMessageHandshakeInitiation(Packet *packet);
bool CheckSwitchToNextKey_Locked(WgKeypair *keypair);
void ClearKeys_Locked();
void ClearHandshake_Locked();
void ClearPacketQueue_Locked();
void RemovePeer();
bool CheckHandshakeRateLimit();
// Timer notifications
@ -422,25 +454,25 @@ public:
ACTION_SEND_KEEPALIVE = 1,
ACTION_SEND_HANDSHAKE = 2,
};
uint32 CheckTimeouts(uint64 now);
uint32 CheckTimeouts_Locked(uint64 now);
void AddPacketToPeerQueue(Packet *packet);
#if WITH_WG_THREADING
bool IsPeerLocked() { return mutex_.IsLocked(); }
#else // WITH_WG_THREADING
bool IsPeerLocked() { return true; }
#endif // WITH_WG_THREADING
void AddPacketToPeerQueue_Locked(Packet *packet);
bool IsPeerLocked() { return WG_IF_LOCKS_ENABLED_ELSE(mutex_.IsLocked(), true); }
private:
static WgKeypair *CreateNewKeypair(bool is_initiator, const uint8 key[WG_HASH_LEN], uint32 send_key_id, const uint8 *extfield, size_t extfield_size);
void WriteMacToPacket(const uint8 *data, MessageMacs *mac);
void DeleteKeypair(WgKeypair **kp);
void CheckAndUpdateTimeOfNextKeyEvent(uint64 now);
static void DeleteKeypair(WgKeypair **kp);
static void CopyEndpointToPeer_Locked(WgKeypair *keypair, const IpAddr *addr);
static void DelayedDelete(void *x);
size_t WriteHandshakeExtension(uint8 *dst, WgKeypair *keypair);
void InsertKeypairInPeer_Locked(WgKeypair *keypair);
void ClearKeys_Locked();
void ClearHandshake_Locked();
void ClearPacketQueue_Locked();
void ScheduleNewHandshake();
WgDevice *dev_;
WgPeer *next_peer_;
@ -492,6 +524,10 @@ private:
// Whether |mac2_cookie_| is valid.
bool has_mac2_cookie_;
// Whether the WgPeer has been deleted (i.e. RemovePeer has been called),
// and will be deleted as soon as the threads sync.
bool marked_for_delete_;
// Number of handshakes made so far, when this gets too high we stop connecting.
uint8 handshake_attempts_;
@ -517,7 +553,10 @@ private:
uint8 cipher_prio_;
uint8 num_ciphers_;
uint8 ciphers_[MAX_CIPHERS];
uint64 rx_bytes_;
uint64 tx_bytes_;
// Handshake state that gets setup in |CreateMessageHandshakeInitiation| and used in
// the response.
struct HandshakeState {
@ -530,7 +569,7 @@ private:
};
HandshakeState hs_;
// Remote's static public key - init only.
uint8 s_remote_[WG_PUBLIC_KEY_LEN];
WgPublicKey s_remote_;
// Remote's preshared key - init only.
uint8 preshared_key_[WG_SYMMETRIC_KEY_LEN];
// Precomputed DH(spriv_local, spub_remote) - init only.