Support multiple interfaces and the 'ts' command line tool

This commit is contained in:
Ludvig Strigeus 2018-09-15 18:22:05 +02:00
parent 7b7fb6126b
commit 6d916e9aaa
39 changed files with 4243 additions and 1530 deletions

12
.gitignore vendored
View file

@ -1,18 +1,10 @@
/Debug/
/Release/
/ipzip2/Debug/
/Build
/Win32/
/TunSafe.aps /TunSafe.aps
/*.sdf /*.sdf
/*vcxproj.user *vcxproj.user
/*.opensdf /*.opensdf
/*.suo /*.suo
/.vs/ /.vs/
/x64/ /build/
/Azire.conf
/*.psess /*.psess
/*.vspx /*.vspx
/installer/*.zip /installer/*.zip
/config/
/tunsafe.com/

View file

@ -5,6 +5,8 @@ VisualStudioVersion = 15.0.26403.7
MinimumVisualStudioVersion = 10.0.40219.1 MinimumVisualStudioVersion = 10.0.40219.1
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "TunSafe", "TunSafe.vcxproj", "{626FBC16-64C6-407D-BC2B-6C087794E0D0}" Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "TunSafe", "TunSafe.vcxproj", "{626FBC16-64C6-407D-BC2B-6C087794E0D0}"
EndProject EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ts", "ts.vcxproj", "{443E105E-8D7C-401F-BD41-D3F56C76104B}"
EndProject
Global Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Win32 = Debug|Win32 Debug|Win32 = Debug|Win32
@ -21,8 +23,19 @@ Global
{626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|Win32.Build.0 = Release|Win32 {626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|Win32.Build.0 = Release|Win32
{626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|x64.ActiveCfg = Release|x64 {626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|x64.ActiveCfg = Release|x64
{626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|x64.Build.0 = Release|x64 {626FBC16-64C6-407D-BC2B-6C087794E0D0}.Release|x64.Build.0 = Release|x64
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|Win32.ActiveCfg = Debug|Win32
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|Win32.Build.0 = Debug|Win32
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|x64.ActiveCfg = Debug|x64
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Debug|x64.Build.0 = Debug|x64
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|Win32.ActiveCfg = Release|Win32
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|Win32.Build.0 = Release|Win32
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|x64.ActiveCfg = Release|x64
{443E105E-8D7C-401F-BD41-D3F56C76104B}.Release|x64.Build.0 = Release|x64
EndGlobalSection EndGlobalSection
GlobalSection(SolutionProperties) = preSolution GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE HideSolutionNode = FALSE
EndGlobalSection EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {F2DD9ED8-DDEA-4B40-9208-41726750D33D}
EndGlobalSection
EndGlobal EndGlobal

View file

@ -22,7 +22,7 @@
<ProjectGuid>{626FBC16-64C6-407D-BC2B-6C087794E0D0}</ProjectGuid> <ProjectGuid>{626FBC16-64C6-407D-BC2B-6C087794E0D0}</ProjectGuid>
<Keyword>Win32Proj</Keyword> <Keyword>Win32Proj</Keyword>
<RootNamespace>TunSafe</RootNamespace> <RootNamespace>TunSafe</RootNamespace>
<WindowsTargetPlatformVersion>10.0.15063.0</WindowsTargetPlatformVersion> <WindowsTargetPlatformVersion>10.0.17134.0</WindowsTargetPlatformVersion>
<ProjectName>TunSafe</ProjectName> <ProjectName>TunSafe</ProjectName>
</PropertyGroup> </PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" /> <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
@ -72,24 +72,28 @@
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'"> <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<LinkIncremental>true</LinkIncremental> <LinkIncremental>true</LinkIncremental>
<TargetName>TunSafe</TargetName> <TargetName>TunSafe</TargetName>
<OutDir>$(SolutionDir)$(Platform)\$(Configuration)\</OutDir> <OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
<IntDir>$(Platform)\$(Configuration)\</IntDir> <IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
</PropertyGroup> </PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental> <LinkIncremental>true</LinkIncremental>
<ExecutablePath>$(VC_ExecutablePath_x64);$(WindowsSDK_ExecutablePath);$(VS_ExecutablePath);$(MSBuild_ExecutablePath);$(FxCopDir);$(PATH);C:\Bin\Dev\nasm</ExecutablePath> <ExecutablePath>$(VC_ExecutablePath_x64);$(WindowsSDK_ExecutablePath);$(VS_ExecutablePath);$(MSBuild_ExecutablePath);$(FxCopDir);$(PATH);C:\Bin\Dev\nasm</ExecutablePath>
<TargetName>TunSafe</TargetName> <TargetName>TunSafe</TargetName>
<OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
<IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
</PropertyGroup> </PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'"> <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<LinkIncremental>false</LinkIncremental> <LinkIncremental>false</LinkIncremental>
<TargetName>TunSafe</TargetName> <TargetName>TunSafe</TargetName>
<OutDir>$(SolutionDir)$(Platform)\$(Configuration)\</OutDir> <OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
<IntDir>$(Platform)\$(Configuration)\</IntDir> <IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
</PropertyGroup> </PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental> <LinkIncremental>false</LinkIncremental>
<ExecutablePath>$(VC_ExecutablePath_x64);$(WindowsSDK_ExecutablePath);$(VS_ExecutablePath);$(MSBuild_ExecutablePath);$(FxCopDir);$(PATH);C:\Bin\Dev\nasm</ExecutablePath> <ExecutablePath>$(VC_ExecutablePath_x64);$(WindowsSDK_ExecutablePath);$(VS_ExecutablePath);$(MSBuild_ExecutablePath);$(FxCopDir);$(PATH);C:\Bin\Dev\nasm</ExecutablePath>
<TargetName>TunSafe</TargetName> <TargetName>TunSafe</TargetName>
<IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
<OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
</PropertyGroup> </PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'"> <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<ClCompile> <ClCompile>
@ -98,6 +102,7 @@
<Optimization>Disabled</Optimization> <Optimization>Disabled</Optimization>
<PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions);_CRT_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS</PreprocessorDefinitions> <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions);_CRT_SECURE_NO_WARNINGS;_CRT_SECURE_NO_WARNINGS</PreprocessorDefinitions>
<AdditionalIncludeDirectories>.</AdditionalIncludeDirectories> <AdditionalIncludeDirectories>.</AdditionalIncludeDirectories>
<ExceptionHandling>false</ExceptionHandling>
</ClCompile> </ClCompile>
<Link> <Link>
<SubSystem>Windows</SubSystem> <SubSystem>Windows</SubSystem>
@ -114,6 +119,7 @@
<ForcedIncludeFiles> <ForcedIncludeFiles>
</ForcedIncludeFiles> </ForcedIncludeFiles>
<AdditionalIncludeDirectories>.</AdditionalIncludeDirectories> <AdditionalIncludeDirectories>.</AdditionalIncludeDirectories>
<ExceptionHandling>false</ExceptionHandling>
</ClCompile> </ClCompile>
<Link> <Link>
<SubSystem>Windows</SubSystem> <SubSystem>Windows</SubSystem>
@ -133,6 +139,8 @@
<PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions);_CRT_SECURE_NO_WARNINGS</PreprocessorDefinitions> <PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions);_CRT_SECURE_NO_WARNINGS</PreprocessorDefinitions>
<RuntimeLibrary>MultiThreaded</RuntimeLibrary> <RuntimeLibrary>MultiThreaded</RuntimeLibrary>
<AdditionalIncludeDirectories>.</AdditionalIncludeDirectories> <AdditionalIncludeDirectories>.</AdditionalIncludeDirectories>
<ExceptionHandling>false</ExceptionHandling>
<RuntimeTypeInfo>false</RuntimeTypeInfo>
</ClCompile> </ClCompile>
<Link> <Link>
<SubSystem>Windows</SubSystem> <SubSystem>Windows</SubSystem>
@ -157,6 +165,8 @@
<InlineFunctionExpansion>AnySuitable</InlineFunctionExpansion> <InlineFunctionExpansion>AnySuitable</InlineFunctionExpansion>
<OmitFramePointers>true</OmitFramePointers> <OmitFramePointers>true</OmitFramePointers>
<AdditionalIncludeDirectories>.</AdditionalIncludeDirectories> <AdditionalIncludeDirectories>.</AdditionalIncludeDirectories>
<ExceptionHandling>false</ExceptionHandling>
<RuntimeTypeInfo>false</RuntimeTypeInfo>
</ClCompile> </ClCompile>
<Link> <Link>
<SubSystem>Windows</SubSystem> <SubSystem>Windows</SubSystem>
@ -169,8 +179,10 @@
<ItemGroup> <ItemGroup>
<ClInclude Include="bit_ops.h" /> <ClInclude Include="bit_ops.h" />
<ClInclude Include="ip_to_peer_map.h" /> <ClInclude Include="ip_to_peer_map.h" />
<ClInclude Include="service_pipe_win32.h" />
<ClInclude Include="service_win32.h" /> <ClInclude Include="service_win32.h" />
<ClInclude Include="service_win32_api.h" /> <ClInclude Include="service_win32_api.h" />
<ClInclude Include="service_win32_constants.h" />
<ClInclude Include="tunsafe_config.h" /> <ClInclude Include="tunsafe_config.h" />
<ClInclude Include="tunsafe_cpu.h" /> <ClInclude Include="tunsafe_cpu.h" />
<ClInclude Include="crypto\aesgcm\aes.h" /> <ClInclude Include="crypto\aesgcm\aes.h" />
@ -196,6 +208,7 @@
<ItemGroup> <ItemGroup>
<ClCompile Include="benchmark.cpp" /> <ClCompile Include="benchmark.cpp" />
<ClCompile Include="ip_to_peer_map.cpp" /> <ClCompile Include="ip_to_peer_map.cpp" />
<ClCompile Include="service_pipe_win32.cpp" />
<ClCompile Include="service_win32.cpp" /> <ClCompile Include="service_win32.cpp" />
<ClCompile Include="tunsafe_cpu.cpp" /> <ClCompile Include="tunsafe_cpu.cpp" />
<ClCompile Include="crypto\aesgcm\aesgcm.cpp" /> <ClCompile Include="crypto\aesgcm\aesgcm.cpp" />

View file

@ -89,6 +89,12 @@
<ClInclude Include="util_win32.h"> <ClInclude Include="util_win32.h">
<Filter>Source Files\Win32</Filter> <Filter>Source Files\Win32</Filter>
</ClInclude> </ClInclude>
<ClInclude Include="service_pipe_win32.h">
<Filter>Source Files\Win32</Filter>
</ClInclude>
<ClInclude Include="service_win32_constants.h">
<Filter>Source Files\Win32</Filter>
</ClInclude>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ClCompile Include="stdafx.cpp"> <ClCompile Include="stdafx.cpp">
@ -154,6 +160,9 @@
<ClCompile Include="ip_to_peer_map.cpp"> <ClCompile Include="ip_to_peer_map.cpp">
<Filter>Source Files</Filter> <Filter>Source Files</Filter>
</ClCompile> </ClCompile>
<ClCompile Include="service_pipe_win32.cpp">
<Filter>Source Files\Win32</Filter>
</ClCompile>
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<ResourceCompile Include="TunSafe.rc" /> <ResourceCompile Include="TunSafe.rc" />

View file

@ -1,4 +1,4 @@
g++7 -I . -O2 -DNDEBUG -static -mssse3 -o tunsafe benchmark.cpp tunsafe_cpu.cpp wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \ g++7 -I . -O2 -DNDEBUG -static -mssse3 -o tunsafe benchmark.cpp tunsafe_cpu.cpp wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \
wireguard.cpp wireguard_proto.cpp util.cpp network_bsd.cpp network_bsd_common.cpp \ wireguard.cpp wireguard_proto.cpp ts.cpp util.cpp network_bsd.cpp network_bsd_common.cpp \
crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp crypto/curve25519-donna.cpp \ crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp crypto/curve25519-donna.cpp \
crypto/siphash.cpp crypto/chacha20_x64_gas.s crypto/poly1305_x64_gas.s ipzip2/ipzip2.cpp -lrt -pthread crypto/siphash.cpp crypto/chacha20_x64_gas.s crypto/poly1305_x64_gas.s ipzip2/ipzip2.cpp -lrt -pthread

View file

@ -1,6 +1,6 @@
#!/bin/sh #!/bin/sh
clang++-6.0 -c -march=skylake-avx512 crypto/poly1305_x64_gas.s crypto/chacha20_x64_gas.s clang++-6.0 -c -march=skylake-avx512 crypto/poly1305_x64_gas.s crypto/chacha20_x64_gas.s
clang++-6.0 -I . -O3 -DNDEBUG -mssse3 -pthread -lrt -o tunsafe util.cpp wireguard_config.cpp wireguard.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \ clang++-6.0 -I . -O3 -DNDEBUG -mssse3 -pthread -lrt -o tunsafe util.cpp wireguard_config.cpp wireguard.cpp ts.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \
wireguard_proto.cpp network_bsd.cpp network_bsd_common.cpp tunsafe_cpu.cpp benchmark.cpp crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp \ wireguard_proto.cpp network_bsd.cpp network_bsd_common.cpp tunsafe_cpu.cpp benchmark.cpp crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp \
crypto/curve25519-donna.cpp crypto/siphash.cpp chacha20_x64_gas.o crypto/aesgcm/aesni_gcm_x64_gas.s \ crypto/curve25519-donna.cpp crypto/siphash.cpp chacha20_x64_gas.o crypto/aesgcm/aesni_gcm_x64_gas.s \
crypto/aesgcm/aesni_x64_gas.s crypto/aesgcm/aesgcm.cpp poly1305_x64_gas.o ipzip2/ipzip2.cpp \ crypto/aesgcm/aesni_x64_gas.s crypto/aesgcm/aesgcm.cpp poly1305_x64_gas.o ipzip2/ipzip2.cpp \

View file

@ -3,8 +3,8 @@ set -e
clang++ -c -mavx512f -mavx512vl crypto/poly1305_x64_gas_macosx.s crypto/chacha20_x64_gas_macosx.s clang++ -c -mavx512f -mavx512vl crypto/poly1305_x64_gas_macosx.s crypto/chacha20_x64_gas_macosx.s
clang++ -g -O3 -I . -std=c++11 -DNDEBUG=1 -fno-exceptions -fno-rtti -ffunction-sections -o tunsafe \ clang++ -g -O3 -I . -std=c++11 -DNDEBUG=1 -Wno-deprecated-declarations -fno-exceptions -fno-rtti -ffunction-sections -o tunsafe \
wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp wireguard.cpp wireguard_proto.cpp util.cpp network_bsd.cpp network_bsd_common.cpp benchmark.cpp tunsafe_cpu.cpp \ wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp wireguard.cpp wireguard_proto.cpp ts.cpp util.cpp network_bsd.cpp network_bsd_common.cpp benchmark.cpp tunsafe_cpu.cpp \
crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp crypto/curve25519-donna.cpp \ crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp crypto/curve25519-donna.cpp \
crypto/siphash.cpp crypto/aesgcm/aesgcm.cpp ipzip2/ipzip2.cpp \ crypto/siphash.cpp crypto/aesgcm/aesgcm.cpp ipzip2/ipzip2.cpp \
crypto/aesgcm/aesni_gcm_x64_gas_macosx.s crypto/aesgcm/aesni_x64_gas_macosx.s crypto/aesgcm/ghash_x64_gas_macosx.s \ crypto/aesgcm/aesni_gcm_x64_gas_macosx.s crypto/aesgcm/aesni_x64_gas_macosx.s crypto/aesgcm/ghash_x64_gas_macosx.s \

View file

@ -57,10 +57,12 @@ Section "TunSafe Client" SecTunSafe
DetailPrint "Installing 64-bit version of TunSafe." DetailPrint "Installing 64-bit version of TunSafe."
SetOutPath "$INSTDIR" SetOutPath "$INSTDIR"
File "x64\TunSafe.exe" File "x64\TunSafe.exe"
File "x64\ts.exe"
${Else} ${Else}
DetailPrint "Installing 32-bit version of TunSafe." DetailPrint "Installing 32-bit version of TunSafe."
SetOutPath "$INSTDIR" SetOutPath "$INSTDIR"
File "x86\TunSafe.exe" File "x86\TunSafe.exe"
File "x86\ts.exe"
${EndIf} ${EndIf}
File "License.txt" File "License.txt"
File "ChangeLog.txt" File "ChangeLog.txt"
@ -205,6 +207,7 @@ Section "Uninstall"
Delete "$INSTDIR\TunSafe.exe" Delete "$INSTDIR\TunSafe.exe"
Delete "$INSTDIR\ts.exe"
Delete "$INSTDIR\License.txt" Delete "$INSTDIR\License.txt"
Delete "$INSTDIR\ChangeLog.txt" Delete "$INSTDIR\ChangeLog.txt"
Delete "$INSTDIR\Config\TunSafe.conf" Delete "$INSTDIR\Config\TunSafe.conf"

View file

@ -5,6 +5,8 @@
#include "bit_ops.h" #include "bit_ops.h"
#include <string.h> #include <string.h>
#include <assert.h> #include <assert.h>
#include <stdlib.h>
#include "util.h"
IpToPeerMap::IpToPeerMap() { IpToPeerMap::IpToPeerMap() {
@ -13,18 +15,22 @@ IpToPeerMap::IpToPeerMap() {
IpToPeerMap::~IpToPeerMap() { IpToPeerMap::~IpToPeerMap() {
} }
bool IpToPeerMap::InsertV4(uint32 ip, int cidr, void *peer) { void *IpToPeerMap::InsertV4(uint32 ip, int cidr, void *peer) {
ipv4_.Insert(ip, cidr, peer); ipv4_.Insert(ip, cidr, &peer);
return true; return peer;
} }
bool IpToPeerMap::InsertV6(const void *addr, int cidr, void *peer) { void *IpToPeerMap::InsertV6(const void *addr, int cidr, void *peer) {
Entry6 e; Entry6 e;
for (auto it = ipv6_.begin(); it != ipv6_.end(); ++it) {
if (it->cidr_len == cidr && memcmp(it->ip, addr, 16) == 0)
return exch(it->peer, peer);
}
e.cidr_len = cidr; e.cidr_len = cidr;
e.peer = peer; e.peer = peer;
memcpy(e.ip, addr, 16); memcpy(e.ip, addr, 16);
ipv6_.push_back(e); ipv6_.push_back(e);
return true; return NULL;
} }
void *IpToPeerMap::LookupV4(uint32 ip) { void *IpToPeerMap::LookupV4(uint32 ip) {
@ -43,6 +49,19 @@ void *IpToPeerMap::LookupV6DefaultPeer() {
return NULL; return NULL;
} }
void IpToPeerMap::RemoveV4(uint32 ip, int cidr) {
ipv4_.Delete(ip, cidr);
}
void IpToPeerMap::RemoveV6(const void *addr, int cidr) {
for (auto it = ipv6_.begin(); it != ipv6_.end(); ++it) {
if (it->cidr_len == cidr && memcmp(it->ip, addr, 16) == 0) {
ipv6_.erase(it);
return;
}
}
}
static int CalculateIPv6CommonPrefix(const uint8 *a, const uint8 *b) { static int CalculateIPv6CommonPrefix(const uint8 *a, const uint8 *b) {
uint64 x = ToBE64(*(uint64*)&a[0] ^ *(uint64*)&b[0]); uint64 x = ToBE64(*(uint64*)&a[0] ^ *(uint64*)&b[0]);
uint64 y = ToBE64(*(uint64*)&a[8] ^ *(uint64*)&b[8]); uint64 y = ToBE64(*(uint64*)&a[8] ^ *(uint64*)&b[8]);
@ -62,20 +81,6 @@ void *IpToPeerMap::LookupV6(const void *addr) {
return best_peer; return best_peer;
} }
void IpToPeerMap::RemovePeer(void *peer) {
assert(0);
// todo: remove peer also from ipv4_
{
size_t n = ipv6_.size();
Entry6 *r = &ipv6_[0], *w = r;
for (size_t i = 0; i != n; i++, r++) {
if (r->peer != peer)
*w++ = *r;
}
ipv6_.resize(w - &ipv6_[0]);
}
}
#pragma warning (disable: 4200) // warning C4200: nonstandard extension used: zero-sized array in struct/union #pragma warning (disable: 4200) // warning C4200: nonstandard extension used: zero-sized array in struct/union
struct RoutingTrie32::Node { struct RoutingTrie32::Node {
uint32 key; uint32 key;
@ -175,7 +180,7 @@ RoutingTrie32::~RoutingTrie32() {
RoutingTrie32::Value RoutingTrie32::Lookup(uint32 ip) { RoutingTrie32::Value RoutingTrie32::Lookup(uint32 ip) {
uint32 key = ip; uint32 key = ip;
Node *n = root_, *pn = n, *ppn; Node *n = root_, *pn = n, *ppn;
int cindex = 0; uint32 cindex = 0;
if (!n) if (!n)
return NULL; return NULL;
// Find the longest prefix match // Find the longest prefix match
@ -232,7 +237,7 @@ backtrace:
} }
// strip lsb of cindex and find child // strip lsb of cindex and find child
cindex &= cindex - 1; cindex &= cindex - 1;
assert(cindex < (1 << pn->bits)); assert(cindex < (1U << pn->bits));
n = pn->child[cindex]; n = pn->child[cindex];
if (!NODE_IS_NULL_OR_OLEAF(n)) if (!NODE_IS_NULL_OR_OLEAF(n))
break; break;
@ -246,7 +251,7 @@ backtrace:
} }
} }
bool RoutingTrie32::InsertLeafInto(Node **nn, uint8 leaf_pos, Value value) { bool RoutingTrie32::InsertLeafInto(Node **nn, uint8 leaf_pos, Value *valuep) {
// put higher cidr higher up // put higher cidr higher up
Node *n = *nn; Node *n = *nn;
assert(IS_LEAF(n)); assert(IS_LEAF(n));
@ -255,12 +260,12 @@ bool RoutingTrie32::InsertLeafInto(Node **nn, uint8 leaf_pos, Value value) {
if (leaf_pos < n->pos) if (leaf_pos < n->pos)
break; break;
if (leaf_pos == n->pos) { if (leaf_pos == n->pos) {
n->leaf_value = value; std::swap(n->leaf_value, *valuep);
return true; return true;
} }
nn = &n->leaf_next; nn = &n->leaf_next;
} while ((n = *nn) != NULL); } while ((n = *nn) != NULL);
Node *leaf = NewLeaf(key, leaf_pos, value); Node *leaf = NewLeaf(key, leaf_pos, *valuep);
if (leaf == NULL) if (leaf == NULL)
return false; return false;
leaf->leaf_next = *nn; leaf->leaf_next = *nn;
@ -283,14 +288,14 @@ void RoutingTrie32::PutChild(Node *pn, uint32 i, Node *n) {
assert(pn->full_children < 0x80000000); assert(pn->full_children < 0x80000000);
} }
bool RoutingTrie32::Insert(uint32 ip, int cidr, Value value) { bool RoutingTrie32::Insert(uint32 ip, int cidr, Value *valuep) {
uint32 key = ip; uint32 key = ip;
Node **nn = &root_, *n = root_, *pn = NULL, *leaf, *tn = NULL, *leaf_to_free = NULL; Node **nn = &root_, *n = root_, *pn = NULL, *leaf, *tn = NULL, *leaf_to_free = NULL;
uint8 leaf_pos = 32 - cidr; uint8 leaf_pos = 32 - cidr;
if (n == NULL) { if (n == NULL) {
root_ = NewLeaf(key, leaf_pos, value); root_ = NewLeaf(key, leaf_pos, exch_null(*valuep));
return (root_ != NULL); return false;
} }
assert(!NODE_IS_OLEAF(n)); assert(!NODE_IS_OLEAF(n));
@ -316,7 +321,7 @@ force_add:
if (IS_LEAF(n)) { if (IS_LEAF(n)) {
if (key != n->key) if (key != n->key)
goto force_add; goto force_add;
return InsertLeafInto(nn, leaf_pos, value); return InsertLeafInto(nn, leaf_pos, valuep);
} }
pn = n; pn = n;
nn = &n->child[index]; nn = &n->child[index];
@ -330,6 +335,7 @@ force_add:
*nn = n; *nn = n;
} }
} }
Value value = *valuep;
// Create either leaf or oleaf // Create either leaf or oleaf
if (tn->pos == leaf_pos) { if (tn->pos == leaf_pos) {
leaf = VALUE_TO_OLEAF(value); leaf = VALUE_TO_OLEAF(value);
@ -338,8 +344,8 @@ force_add:
FreeNode(tn); FreeNode(tn);
return false; return false;
} }
// -- Start making irreversible changes here // -- Start making irreversible changes here
*valuep = NULL;
if (leaf_to_free) if (leaf_to_free)
FreeNode(leaf_to_free); FreeNode(leaf_to_free);

View file

@ -15,7 +15,7 @@ public:
~RoutingTrie32(); ~RoutingTrie32();
NOINLINE Value Lookup(uint32 ip); NOINLINE Value Lookup(uint32 ip);
NOINLINE Value LookupExact(uint32 ip, int cidr); NOINLINE Value LookupExact(uint32 ip, int cidr);
bool Insert(uint32 ip, int cidr, Value value); bool Insert(uint32 ip, int cidr, Value *value);
bool Delete(uint32 ip, int cidr); bool Delete(uint32 ip, int cidr);
private: private:
@ -31,7 +31,7 @@ private:
static void PutChild(Node *pn, uint32 i, Node *n); static void PutChild(Node *pn, uint32 i, Node *n);
static void ReplaceChild(Node **pnp, Node *n); static void ReplaceChild(Node **pnp, Node *n);
static Node *ConvertOleafToLeaf(Node *pn, uint32 i, Node *n); static Node *ConvertOleafToLeaf(Node *pn, uint32 i, Node *n);
static bool InsertLeafInto(Node **n, uint8 leaf_pos, RoutingTrie32::Value value); static bool InsertLeafInto(Node **n, uint8 leaf_pos, RoutingTrie32::Value *value);
}; };
@ -43,8 +43,8 @@ public:
~IpToPeerMap(); ~IpToPeerMap();
// Inserts an IP address of a given CIDR length into the lookup table, pointing to peer. // Inserts an IP address of a given CIDR length into the lookup table, pointing to peer.
bool InsertV4(uint32 ip, int cidr, void *peer); void *InsertV4(uint32 ip, int cidr, void *peer);
bool InsertV6(const void *addr, int cidr, void *peer); void *InsertV6(const void *addr, int cidr, void *peer);
// Lookup the peer matching the IP Address // Lookup the peer matching the IP Address
void *LookupV4(uint32 ip); void *LookupV4(uint32 ip);
@ -53,8 +53,8 @@ public:
void *LookupV4DefaultPeer(); void *LookupV4DefaultPeer();
void *LookupV6DefaultPeer(); void *LookupV6DefaultPeer();
// Remove a peer from the table void RemoveV4(uint32 ip, int cidr);
void RemovePeer(void *peer); void RemoveV6(const void *addr, int cidr);
private: private:
struct Entry6 { struct Entry6 {
uint8 ip[16]; uint8 ip[16];

View file

@ -18,7 +18,6 @@
#pragma warning (disable: 4200) #pragma warning (disable: 4200)
void OsGetRandomBytes(uint8 *dst, size_t dst_size);
uint64 OsGetMilliseconds(); uint64 OsGetMilliseconds();
void OsGetTimestampTAI64N(uint8 dst[12]); void OsGetTimestampTAI64N(uint8 dst[12]);
void OsInterruptibleSleep(int millis); void OsInterruptibleSleep(int millis);
@ -127,13 +126,13 @@ public:
uint8 neighbor_discovery_spoofing_mac[6]; uint8 neighbor_discovery_spoofing_mac[6];
}; };
virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) = 0; virtual bool Configure(const TunConfig &&config, TunConfigOut *out) = 0;
virtual void WriteTunPacket(Packet *packet) = 0; virtual void WriteTunPacket(Packet *packet) = 0;
}; };
class UdpInterface { class UdpInterface {
public: public:
virtual bool Initialize(int listen_port) = 0; virtual bool Configure(int listen_port) = 0;
virtual void WriteUdpPacket(Packet *packet) = 0; virtual void WriteUdpPacket(Packet *packet) = 0;
}; };

View file

@ -13,10 +13,12 @@
#include <string.h> #include <string.h>
#include <arpa/inet.h> #include <arpa/inet.h>
#include <sys/stat.h> #include <sys/stat.h>
#include <sys/un.h>
#include <stdlib.h> #include <stdlib.h>
#include <errno.h> #include <errno.h>
#include <assert.h> #include <assert.h>
#include <signal.h> #include <signal.h>
#include <poll.h>
static Packet *freelist; static Packet *freelist;
@ -49,6 +51,7 @@ void FreePackets() {
} }
} }
class TunsafeBackendBsdImpl : public TunsafeBackendBsd { class TunsafeBackendBsdImpl : public TunsafeBackendBsd {
public: public:
TunsafeBackendBsdImpl(); TunsafeBackendBsdImpl();
@ -61,7 +64,7 @@ public:
virtual void WriteTunPacket(Packet *packet) override; virtual void WriteTunPacket(Packet *packet) override;
// -- from UdpInterface // -- from UdpInterface
virtual bool Initialize(int listen_port) override; virtual bool Configure(int listen_port) override;
virtual void WriteUdpPacket(Packet *packet) override; virtual void WriteUdpPacket(Packet *packet) override;
virtual void HandleSigAlrm() override { got_sig_alarm_ = true; } virtual void HandleSigAlrm() override { got_sig_alarm_ = true; }
@ -72,12 +75,19 @@ private:
bool ReadFromTun(); bool ReadFromTun();
bool WriteToUdp(); bool WriteToUdp();
bool WriteToTun(); bool WriteToTun();
bool InitializeUnixDomainSocket(const char *devname);
// Exists for the unix domain sockets
struct SockInfo {
bool is_listener;
std::string inbuf, outbuf;
};
bool HandleSpecialPollfd(struct pollfd *pollfd, struct SockInfo *sockinfo);
void CloseSpecialPollfd(size_t i);
void SetUdpFd(int fd); void SetUdpFd(int fd);
void SetTunFd(int fd); void SetTunFd(int fd);
inline void RecomputeMaxFd() { max_fd_ = ((tun_fd_>udp_fd_) ? tun_fd_ : udp_fd_) + 1; }
int tun_fd_, udp_fd_, max_fd_;
bool got_sig_alarm_; bool got_sig_alarm_;
bool exit_; bool exit_;
@ -89,13 +99,25 @@ private:
Packet *read_packet_; Packet *read_packet_;
fd_set readfds_, writefds_; enum {
kMaxPollFd = 5,
kPollFdTun = 0,
kPollFdUdp = 1,
kPollFdUnix = 2,
};
unsigned int pollfd_num_;
struct pollfd pollfd_[kMaxPollFd];
struct SockInfo sockinfo_[kMaxPollFd - 2];
struct sockaddr_un un_addr_;
UnixSocketDeletionWatcher un_deletion_watcher_;
}; };
TunsafeBackendBsdImpl::TunsafeBackendBsdImpl() TunsafeBackendBsdImpl::TunsafeBackendBsdImpl()
: tun_fd_(-1), : tun_readable_(false),
udp_fd_(-1),
tun_readable_(false),
tun_writable_(false), tun_writable_(false),
udp_readable_(false), udp_readable_(false),
udp_writable_(false), udp_writable_(false),
@ -106,35 +128,39 @@ TunsafeBackendBsdImpl::TunsafeBackendBsdImpl()
udp_queue_(NULL), udp_queue_(NULL),
udp_queue_end_(&udp_queue_), udp_queue_end_(&udp_queue_),
read_packet_(NULL) { read_packet_(NULL) {
RecomputeMaxFd();
FD_ZERO(&readfds_);
FD_ZERO(&writefds_);
read_packet_ = AllocPacket(); read_packet_ = AllocPacket();
for(size_t i = 0; i < kMaxPollFd; i++)
pollfd_[i].fd = -1;
pollfd_num_ = 3;
sockinfo_[0].is_listener = true;
memset(&un_addr_, 0, sizeof(un_addr_));
} }
TunsafeBackendBsdImpl::~TunsafeBackendBsdImpl() { TunsafeBackendBsdImpl::~TunsafeBackendBsdImpl() {
if (un_addr_.sun_path[0])
unlink(un_addr_.sun_path);
if (read_packet_) if (read_packet_)
FreePacket(read_packet_); FreePacket(read_packet_);
for(size_t i = 0; i < pollfd_num_; i++)
close(pollfd_[i].fd);
} }
void TunsafeBackendBsdImpl::SetUdpFd(int fd) { void TunsafeBackendBsdImpl::SetUdpFd(int fd) {
udp_fd_ = fd; pollfd_[kPollFdUdp].fd = fd;
RecomputeMaxFd(); pollfd_[kPollFdUdp].events = POLLIN;
udp_writable_ = true; udp_writable_ = true;
} }
void TunsafeBackendBsdImpl::SetTunFd(int fd) { void TunsafeBackendBsdImpl::SetTunFd(int fd) {
tun_fd_ = fd; pollfd_[kPollFdTun].fd = fd;
RecomputeMaxFd(); pollfd_[kPollFdTun].events = POLLIN;
tun_writable_ = true; tun_writable_ = true;
} }
bool TunsafeBackendBsdImpl::ReadFromUdp(bool overload) { bool TunsafeBackendBsdImpl::ReadFromUdp(bool overload) {
socklen_t sin_len; socklen_t sin_len;
sin_len = sizeof(read_packet_->addr.sin); sin_len = sizeof(read_packet_->addr.sin);
int r = recvfrom(udp_fd_, read_packet_->data, kPacketCapacity, 0, int r = recvfrom(pollfd_[kPollFdUdp].fd, read_packet_->data, kPacketCapacity, 0,
(sockaddr*)&read_packet_->addr.sin, &sin_len); (sockaddr*)&read_packet_->addr.sin, &sin_len);
if (r >= 0) { if (r >= 0) {
// printf("Read %d bytes from UDP\n", r); // printf("Read %d bytes from UDP\n", r);
@ -157,11 +183,12 @@ bool TunsafeBackendBsdImpl::ReadFromUdp(bool overload) {
bool TunsafeBackendBsdImpl::WriteToUdp() { bool TunsafeBackendBsdImpl::WriteToUdp() {
assert(udp_writable_); assert(udp_writable_);
// RINFO("Send %d bytes to %s", (int)udp_queue_->size, inet_ntoa(udp_queue_->sin.sin_addr)); // RINFO("Send %d bytes to %s", (int)udp_queue_->size, inet_ntoa(udp_queue_->sin.sin_addr));
int r = sendto(udp_fd_, udp_queue_->data, udp_queue_->size, 0, int r = sendto(pollfd_[kPollFdUdp].fd, udp_queue_->data, udp_queue_->size, 0,
(sockaddr*)&udp_queue_->addr.sin, sizeof(udp_queue_->addr.sin)); (sockaddr*)&udp_queue_->addr.sin, sizeof(udp_queue_->addr.sin));
if (r < 0) { if (r < 0) {
if (errno == EAGAIN) { if (errno == EAGAIN) {
udp_writable_ = false; udp_writable_ = false;
pollfd_[kPollFdUdp].events = POLLIN | POLLOUT;
return false; return false;
} }
perror("Write to UDP failed"); perror("Write to UDP failed");
@ -185,7 +212,7 @@ static inline bool IsCompatibleProto(uint32 v) {
bool TunsafeBackendBsdImpl::ReadFromTun() { bool TunsafeBackendBsdImpl::ReadFromTun() {
assert(tun_readable_); assert(tun_readable_);
Packet *packet = read_packet_; Packet *packet = read_packet_;
int r = read(tun_fd_, packet->data - TUN_PREFIX_BYTES, kPacketCapacity + TUN_PREFIX_BYTES); int r = read(pollfd_[kPollFdTun].fd, packet->data - TUN_PREFIX_BYTES, kPacketCapacity + TUN_PREFIX_BYTES);
if (r >= 0) { if (r >= 0) {
// printf("Read %d bytes from TUN\n", r); // printf("Read %d bytes from TUN\n", r);
packet->size = r - TUN_PREFIX_BYTES; packet->size = r - TUN_PREFIX_BYTES;
@ -215,10 +242,11 @@ bool TunsafeBackendBsdImpl::WriteToTun() {
if (TUN_PREFIX_BYTES) { if (TUN_PREFIX_BYTES) {
WriteBE32(tun_queue_->data - TUN_PREFIX_BYTES, GetProtoFromPacket(tun_queue_->data, tun_queue_->size)); WriteBE32(tun_queue_->data - TUN_PREFIX_BYTES, GetProtoFromPacket(tun_queue_->data, tun_queue_->size));
} }
int r = write(tun_fd_, tun_queue_->data - TUN_PREFIX_BYTES, tun_queue_->size + TUN_PREFIX_BYTES); int r = write(pollfd_[kPollFdTun].fd, tun_queue_->data - TUN_PREFIX_BYTES, tun_queue_->size + TUN_PREFIX_BYTES);
if (r < 0) { if (r < 0) {
if (errno == EAGAIN) { if (errno == EAGAIN) {
tun_writable_ = false; tun_writable_ = false;
pollfd_[kPollFdTun].events = POLLIN | POLLOUT;
return false; return false;
} }
RERROR("Write to tun failed"); RERROR("Write to tun failed");
@ -242,11 +270,13 @@ bool TunsafeBackendBsdImpl::InitializeTun(char devname[16]) {
fcntl(tun_fd, F_SETFD, FD_CLOEXEC); fcntl(tun_fd, F_SETFD, FD_CLOEXEC);
fcntl(tun_fd, F_SETFL, O_NONBLOCK); fcntl(tun_fd, F_SETFL, O_NONBLOCK);
SetTunFd(tun_fd); SetTunFd(tun_fd);
InitializeUnixDomainSocket(devname);
return true; return true;
} }
void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) override { void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) override {
assert(tun_fd_ >= 0); assert(pollfd_[kPollFdTun].fd >= 0);
Packet *queue_is_used = tun_queue_; Packet *queue_is_used = tun_queue_;
*tun_queue_end_ = packet; *tun_queue_end_ = packet;
tun_queue_end_ = &packet->next; tun_queue_end_ = &packet->next;
@ -256,7 +286,7 @@ void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) override {
} }
// Called to initialize udp // Called to initialize udp
bool TunsafeBackendBsdImpl::Initialize(int listen_port) override { bool TunsafeBackendBsdImpl::Configure(int listen_port) override {
int udp_fd = open_udp(listen_port); int udp_fd = open_udp(listen_port);
if (udp_fd < 0) { RERROR("Error opening udp"); return false; } if (udp_fd < 0) { RERROR("Error opening udp"); return false; }
fcntl(udp_fd, F_SETFD, FD_CLOEXEC); fcntl(udp_fd, F_SETFD, FD_CLOEXEC);
@ -266,7 +296,7 @@ bool TunsafeBackendBsdImpl::Initialize(int listen_port) override {
} }
void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) override { void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) override {
assert(udp_fd_ >= 0); assert(pollfd_[kPollFdUdp].fd >= 0);
Packet *queue_is_used = udp_queue_; Packet *queue_is_used = udp_queue_;
*udp_queue_end_ = packet; *udp_queue_end_ = packet;
udp_queue_end_ = &packet->next; udp_queue_end_ = &packet->next;
@ -275,16 +305,137 @@ void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) override {
WriteToUdp(); WriteToUdp();
} }
bool TunsafeBackendBsdImpl::InitializeUnixDomainSocket(const char *devname) {
int fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (fd == -1) {
RERROR("Error creating unix domain socket");
return false;
}
fcntl(fd, F_SETFD, FD_CLOEXEC);
fcntl(fd, F_SETFL, O_NONBLOCK);
mkdir("/var/run/wireguard", 0755);
un_addr_.sun_family = AF_UNIX;
snprintf(un_addr_.sun_path, sizeof(un_addr_.sun_path), "/var/run/wireguard/%s.sock", devname);
unlink(un_addr_.sun_path);
if (bind(fd, (struct sockaddr*)&un_addr_, sizeof(un_addr_)) == -1) {
RERROR("Error binding unix domain socket");
close(fd);
return false;
}
if (listen(fd, 5) == -1) {
RERROR("Error listening on unix domain socket");
close(fd);
return false;
}
pollfd_[kPollFdUnix].fd = fd;
pollfd_[kPollFdUnix].events = POLLIN;
return true;
}
static const char *FindMessageEnd(const char *start, size_t size) {
if (size <= 1)
return NULL;
const char *start_end = start + size - 1;
for(;(start = (const char*)memchr(start, '\n', start_end - start)) != NULL; start++) {
if (start[1] == '\n')
return start + 2;
}
return NULL;
}
bool TunsafeBackendBsdImpl::HandleSpecialPollfd(struct pollfd *pfd, struct SockInfo *sockinfo) {
// handle domain socket thing
if (sockinfo->is_listener) {
if (pfd->revents & POLLIN) {
// wait if we can't allocate more pollfd
if (pollfd_num_ == kMaxPollFd) {
pfd->events = 0;
return true;
}
int fd = accept(pfd->fd, NULL, NULL);
if (fd >= 0) {
size_t slot = pollfd_num_++;
pollfd_[slot].fd = fd;
pollfd_[slot].events = POLLIN;
pollfd_[slot].revents = 0;
sockinfo_[slot - 2].is_listener = false;
} else {
RERROR("Unix domain socket accept failed");
}
}
if (pfd->revents & ~POLLIN) {
RERROR("Unix domain socket got an error code");
return false;
}
return true;
}
if (pfd->revents & POLLIN) {
char buf[4096];
// read as much data as we can until we see \n\n
ssize_t n = recv(pfd->fd, buf, sizeof(buf), 0);
if (n <= 0)
return (n == -1 && errno == EAGAIN); // premature eof or error
sockinfo->inbuf.append(buf, n);
const char *message_end = FindMessageEnd(&sockinfo->inbuf[0], sockinfo->inbuf.size());
if (message_end) {
if (message_end != &sockinfo->inbuf[sockinfo->inbuf.size()])
return false; // trailing data?
WgConfig::HandleConfigurationProtocolMessage(processor_, std::move(sockinfo->inbuf), &sockinfo->outbuf);
if (!sockinfo->outbuf.size())
return false;
pfd->revents = pfd->events = POLLOUT;
}
}
if (pfd->revents & POLLOUT) {
size_t n = send(pfd->fd, sockinfo->outbuf.data(), sockinfo->outbuf.size(), 0);
if (n <= 0)
return (n == -1 && errno == EAGAIN); // premature eof or error
sockinfo->outbuf.erase(0, n);
if (!sockinfo->outbuf.size())
return false;
}
if (pfd->revents & ~(POLLIN | POLLOUT)) {
RERROR("Unix domain socket got an error code");
return false;
}
return true;
}
void TunsafeBackendBsdImpl::CloseSpecialPollfd(size_t i) {
close(pollfd_[i].fd);
pollfd_[i].fd = -1;
sockinfo_[i - 2].inbuf.clear();
sockinfo_[i - 2].outbuf.clear();
pollfd_[i] = pollfd_[(size_t)pollfd_num_ - 1];
std::swap(sockinfo_[i - 2], sockinfo_[(size_t)pollfd_num_ - 1 - 2]);
// Can now allow more sockets?
if (pollfd_num_-- == kMaxPollFd && sockinfo_[kPollFdUnix - 2].is_listener)
pollfd_[kPollFdUnix].events = POLLIN;
}
void TunsafeBackendBsdImpl::RunLoopInner() { void TunsafeBackendBsdImpl::RunLoopInner() {
int free_packet_interval = 10; int free_packet_interval = 10;
int overload_ctr = 0; int overload_ctr = 0;
if (!un_deletion_watcher_.Start(un_addr_.sun_path, &exit_))
return;
while (!exit_) { while (!exit_) {
int n = -1; int n = -1;
// This is not fully signal safe.
if (got_sig_alarm_) { if (got_sig_alarm_) {
got_sig_alarm_ = false; got_sig_alarm_ = false;
if (un_deletion_watcher_.Poll(un_addr_.sun_path)) {
RINFO("Unix socket %s deleted.", un_addr_.sun_path);
break;
}
processor_->SecondLoop(); processor_->SecondLoop();
if (free_packet_interval == 0) { if (free_packet_interval == 0) {
@ -296,30 +447,50 @@ void TunsafeBackendBsdImpl::RunLoopInner() {
overload_ctr -= (overload_ctr != 0); overload_ctr -= (overload_ctr != 0);
} }
if (tun_fd_ >= 0) { #if defined(OS_LINUX) || defined(OS_FREEBSD)
FD_SET(tun_fd_, &readfds_); n = ppoll(pollfd_, pollfd_num_, NULL, &orig_signal_mask_);
if (tun_writable_) FD_CLR(tun_fd_, &writefds_); else FD_SET(tun_fd_, &writefds_); #else
} n = poll(pollfd_, pollfd_num_, -1);
#endif
if (udp_fd_ >= 0) {
FD_SET(udp_fd_, &readfds_);
if (udp_writable_) FD_CLR(udp_fd_, &writefds_); else FD_SET(udp_fd_, &writefds_);
}
n = select(max_fd_, &readfds_, &writefds_, NULL, NULL);
if (n == -1) { if (n == -1) {
if (errno != EINTR) { if (errno != EINTR) {
fprintf(stderr, "select failed\n"); RERROR("poll failed");
break; break;
} }
} else { } else {
if (tun_fd_ >= 0) {
tun_readable_ = (FD_ISSET(tun_fd_, &readfds_) != 0); if (pollfd_[kPollFdTun].revents & (POLLERR | POLLHUP | POLLNVAL)) {
tun_writable_ |= (FD_ISSET(tun_fd_, &writefds_) != 0); if (pollfd_[kPollFdTun].revents & POLLERR) {
tun_interface_gone_ = true;
RERROR("Tun interface gone, closing.");
} else {
RERROR("Tun interface error %d, closing.", pollfd_[kPollFdTun].revents);
}
break;
}
tun_readable_ = (pollfd_[kPollFdTun].revents & POLLIN) != 0;
if (pollfd_[kPollFdTun].revents & POLLOUT) {
pollfd_[kPollFdTun].events = POLLIN;
tun_writable_ = true;
}
if (pollfd_[kPollFdUdp].revents & (POLLERR | POLLHUP | POLLNVAL)) {
RERROR("UDP error %d, closing.", pollfd_[kPollFdUdp].revents);
break;
}
udp_readable_ = (pollfd_[kPollFdUdp].revents & POLLIN) != 0;
if (pollfd_[kPollFdUdp].revents & POLLOUT) {
pollfd_[kPollFdUdp].events = POLLIN;
udp_writable_ = true;
}
for(size_t i = 2; i < pollfd_num_; i++) {
if (pollfd_[i].revents && !HandleSpecialPollfd(&pollfd_[i], &sockinfo_[i - 2])) {
// Close the fd / discard the sockinfo
CloseSpecialPollfd(i);
i--;
} }
if (udp_fd_ >= 0) {
udp_readable_ = (FD_ISSET(udp_fd_, &readfds_) != 0);
udp_writable_ |= (FD_ISSET(udp_fd_, &writefds_) != 0);
} }
} }
@ -342,6 +513,8 @@ void TunsafeBackendBsdImpl::RunLoopInner() {
processor_->RunAllMainThreadScheduled(); processor_->RunAllMainThreadScheduled();
} }
un_deletion_watcher_.Stop();
} }
TunsafeBackendBsd *CreateTunsafeBackendBsd() { TunsafeBackendBsd *CreateTunsafeBackendBsd() {

View file

@ -39,6 +39,8 @@
#include <linux/if_tun.h> #include <linux/if_tun.h>
#include <sys/prctl.h> #include <sys/prctl.h>
#include <linux/rtnetlink.h> #include <linux/rtnetlink.h>
#include <sys/inotify.h>
#include <limits.h>
#endif #endif
void tunsafe_die(const char *msg) { void tunsafe_die(const char *msg) {
@ -286,15 +288,6 @@ void OsGetTimestampTAI64N(uint8 dst[12]) {
WriteBE32(dst + 8, nanos); WriteBE32(dst + 8, nanos);
} }
void OsGetRandomBytes(uint8 *data, size_t data_size) {
int fd = open("/dev/urandom", O_RDONLY);
int r = read(fd, data, data_size);
if (r < 0) r = 0;
close(fd);
for (; r < data_size; r++)
data[r] = rand() >> 6;
}
void OsInterruptibleSleep(int millis) { void OsInterruptibleSleep(int millis) {
usleep((useconds_t)millis * 1000); usleep((useconds_t)millis * 1000);
} }
@ -387,11 +380,12 @@ int open_tun(char *devname, size_t devname_size) {
memset(&ifr, 0, sizeof(ifr)); memset(&ifr, 0, sizeof(ifr));
ifr.ifr_flags = IFF_TUN | IFF_NO_PI; ifr.ifr_flags = IFF_TUN | IFF_NO_PI;
my_strlcpy(ifr.ifr_name, sizeof(ifr.ifr_name), devname);
if ((err = ioctl(fd, TUNSETIFF, (void *) &ifr)) < 0) { if ((err = ioctl(fd, TUNSETIFF, (void *) &ifr)) < 0) {
close(fd); close(fd);
return err; return err;
} }
strcpy(devname, ifr.ifr_name); my_strlcpy(devname, devname_size, ifr.ifr_name);
return fd; return fd;
} }
#endif #endif
@ -411,6 +405,8 @@ int open_udp(int listen_on_port) {
TunsafeBackendBsd::TunsafeBackendBsd() TunsafeBackendBsd::TunsafeBackendBsd()
: processor_(NULL) { : processor_(NULL) {
devname_[0] = 0;
tun_interface_gone_ = false;
} }
TunsafeBackendBsd::~TunsafeBackendBsd() { TunsafeBackendBsd::~TunsafeBackendBsd() {
@ -497,8 +493,8 @@ static bool IsIpv6AddressSet(const void *p) {
} }
// Called to initialize tun // Called to initialize tun
bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out) override { bool TunsafeBackendBsd::Configure(const TunConfig &&config, TunConfigOut *out) override {
char devname[16]; char buf[kSizeOfAddress];
if (!RunPrePostCommand(config.pre_post_commands.pre_up)) { if (!RunPrePostCommand(config.pre_post_commands.pre_up)) {
RERROR("Pre command failed!"); RERROR("Pre command failed!");
@ -507,17 +503,35 @@ bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out)
out->enable_neighbor_discovery_spoofing = false; out->enable_neighbor_discovery_spoofing = false;
if (!InitializeTun(devname)) if (!InitializeTun(devname_))
return false; return false;
if (config.ipv6_cidr)
RERROR("IPv6 not supported");
uint32 netmask = CidrToNetmaskV4(config.cidr); uint32 netmask = CidrToNetmaskV4(config.cidr);
uint32 default_route_v4 = ComputeIpv4DefaultRoute(config.ip, netmask); uint32 default_route_v4 = ComputeIpv4DefaultRoute(config.ip, netmask);
RunCommand("/sbin/ifconfig %s %A mtu %d %A netmask %A up", devname, config.ip, config.mtu, config.ip, netmask);
AddRoute(config.ip & netmask, config.cidr, config.ip, devname); #if defined(OS_LINUX)
if (config.ip) {
char ip[4];
WriteBE32(ip, config.ip);
RunCommand("/sbin/ip address add dev %s %s", devname_, print_ip_prefix(buf, AF_INET, ip, config.cidr));
}
if (config.ipv6_cidr) {
RunCommand("/sbin/ip address add dev %s %s", devname_, print_ip_prefix(buf, AF_INET6, config.ipv6_address, config.ipv6_cidr));
}
RunCommand("/sbin/ip link set dev %s mtu %d up", devname_, config.mtu);
#else // !defined(OS_LINUX)
if (config.ip) {
RunCommand("/sbin/ifconfig %s %A mtu %d %A netmask %A up", devname_, config.ip, config.mtu, config.ip, netmask);
}
if (config.ipv6_cidr) {
RunCommand("/sbin/ifconfig %s inet6 add %s", devname_, print_ip_prefix(buf, AF_INET6, config.ipv6_address, config.ipv6_cidr));
}
#endif // !defined(OS_LINUX)
if (config.ip) {
AddRoute(config.ip & netmask, config.cidr, config.ip, devname_);
}
if (config.use_ipv4_default_route) { if (config.use_ipv4_default_route) {
if (config.default_route_endpoint_v4) { if (config.default_route_endpoint_v4) {
@ -533,35 +547,30 @@ bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out)
AddRoute(ReadBE32(it->addr), it->cidr, ipv4_default_gw, default_iface); AddRoute(ReadBE32(it->addr), it->cidr, ipv4_default_gw, default_iface);
} }
} }
AddRoute(0x00000000, 1, default_route_v4, devname); AddRoute(0x00000000, 1, default_route_v4, devname_);
AddRoute(0x80000000, 1, default_route_v4, devname); AddRoute(0x80000000, 1, default_route_v4, devname_);
} }
uint8 default_route_v6[16]; uint8 default_route_v6[16];
if (config.ipv6_cidr) { if (config.ipv6_cidr) {
static const uint8 matchall_1_route[17] = {0x80, 0, 0, 0}; static const uint8 matchall_1_route[17] = {0x80, 0, 0, 0};
char buf[kSizeOfAddress];
ComputeIpv6DefaultRoute(config.ipv6_address, config.ipv6_cidr, default_route_v6); ComputeIpv6DefaultRoute(config.ipv6_address, config.ipv6_cidr, default_route_v6);
RunCommand("/sbin/ifconfig %s inet6 add %s", devname, print_ip_prefix(buf, AF_INET6, config.ipv6_address, config.ipv6_cidr));
if (config.use_ipv6_default_route) { if (config.use_ipv6_default_route) {
if (IsIpv6AddressSet(config.default_route_endpoint_v6)) { if (IsIpv6AddressSet(config.default_route_endpoint_v6)) {
RERROR("default_route_endpoint_v6 not supported"); RERROR("default_route_endpoint_v6 not supported");
} }
AddRoute(AF_INET6, matchall_1_route + 1, 1, default_route_v6, devname); AddRoute(AF_INET6, matchall_1_route + 1, 1, default_route_v6, devname_);
AddRoute(AF_INET6, matchall_1_route + 0, 1, default_route_v6, devname); AddRoute(AF_INET6, matchall_1_route + 0, 1, default_route_v6, devname_);
} }
} }
// Add all the extra routes // Add all the extra routes
for (auto it = config.extra_routes.begin(); it != config.extra_routes.end(); ++it) { for (auto it = config.extra_routes.begin(); it != config.extra_routes.end(); ++it) {
if (it->size == 32) { if (it->size == 32) {
AddRoute(ReadBE32(it->addr), it->cidr, default_route_v4, devname); AddRoute(ReadBE32(it->addr), it->cidr, default_route_v4, devname_);
} else if (it->size == 128 && config.ipv6_cidr) { } else if (it->size == 128 && config.ipv6_cidr) {
AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6, devname); AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6, devname_);
} }
} }
@ -576,8 +585,10 @@ bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out)
void TunsafeBackendBsd::CleanupRoutes() { void TunsafeBackendBsd::CleanupRoutes() {
RunPrePostCommand(pre_down_); RunPrePostCommand(pre_down_);
for(auto it = cleanup_commands_.begin(); it != cleanup_commands_.end(); ++it) for(auto it = cleanup_commands_.begin(); it != cleanup_commands_.end(); ++it) {
if (!tun_interface_gone_ || strcmp(it->dev.c_str(), devname_) != 0)
DelRoute(*it); DelRoute(*it);
}
cleanup_commands_.clear(); cleanup_commands_.clear();
RunPrePostCommand(post_down_); RunPrePostCommand(post_down_);
@ -586,6 +597,10 @@ void TunsafeBackendBsd::CleanupRoutes() {
post_down_.clear(); post_down_.clear();
} }
void TunsafeBackendBsd::SetTunDeviceName(const char *name) {
my_strlcpy(devname_, sizeof(devname_), name);
}
static bool RunOneCommand(const std::string &cmd) { static bool RunOneCommand(const std::string &cmd) {
RINFO("Run: %s", cmd.c_str()); RINFO("Run: %s", cmd.c_str());
int exit_code = system(cmd.c_str()); int exit_code = system(cmd.c_str());
@ -604,6 +619,94 @@ bool TunsafeBackendBsd::RunPrePostCommand(const std::vector<std::string> &vec) {
return success; return success;
} }
#if defined(OS_LINUX)
UnixSocketDeletionWatcher::UnixSocketDeletionWatcher()
: inotify_fd_(-1) {
pipes_[0] = -1;
pipes_[0] = -1;
}
UnixSocketDeletionWatcher::~UnixSocketDeletionWatcher() {
close(inotify_fd_);
close(pipes_[0]);
close(pipes_[1]);
}
bool UnixSocketDeletionWatcher::Start(const char *path, bool *flag_to_set) {
assert(inotify_fd_ == -1);
path_ = path;
flag_to_set_ = flag_to_set;
pid_ = getpid();
inotify_fd_ = inotify_init1(IN_CLOEXEC | IN_NONBLOCK);
if (inotify_fd_ == -1) {
perror("inotify_init1() failed");
return false;
}
if (inotify_add_watch(inotify_fd_, "/var/run/wireguard", IN_DELETE | IN_DELETE_SELF) == -1) {
perror("inotify_add_watch failed");
return false;
}
if (pipe(pipes_) == -1) {
perror("pipe() failed");
return false;
}
return pthread_create(&thread_, NULL, &UnixSocketDeletionWatcher::RunThread, this) == 0;
}
void UnixSocketDeletionWatcher::Stop() {
RINFO("Stopping..");
void *retval;
write(pipes_[1], "", 1);
pthread_join(thread_, &retval);
}
void *UnixSocketDeletionWatcher::RunThread(void *arg) {
UnixSocketDeletionWatcher *self = (UnixSocketDeletionWatcher*)arg;
return self->RunThreadInner();
}
void *UnixSocketDeletionWatcher::RunThreadInner() {
char buf[sizeof(struct inotify_event) + NAME_MAX + 1]
__attribute__ ((aligned(__alignof__(struct inotify_event))));
fd_set fdset;
struct stat st;
for(;;) {
if (lstat(path_, &st) == -1 && errno == ENOENT) {
RINFO("Unix socket %s deleted.", path_);
*flag_to_set_ = true;
kill(pid_, SIGALRM);
break;
}
FD_ZERO(&fdset);
FD_SET(inotify_fd_, &fdset);
FD_SET(pipes_[0], &fdset);
int n = select(std::max(inotify_fd_, pipes_[0]) + 1, &fdset, NULL, NULL, NULL);
if (n == -1) {
perror("select");
break;
}
if (FD_ISSET(inotify_fd_, &fdset)) {
ssize_t len = read(inotify_fd_, buf, sizeof(buf));
if (len == -1) {
perror("read");
break;
}
}
if (FD_ISSET(pipes_[0], &fdset))
break;
}
return NULL;
}
#else // !defined(OS_LINUX)
bool UnixSocketDeletionWatcher::Poll(const char *path) {
struct stat st;
return lstat(path, &st) == -1 && errno == ENOENT;
}
#endif // !defined(OS_LINUX)
static TunsafeBackendBsd *g_tunsafe_backend_bsd; static TunsafeBackendBsd *g_tunsafe_backend_bsd;
static void SigAlrm(int sig) { static void SigAlrm(int sig) {
@ -611,10 +714,6 @@ static void SigAlrm(int sig) {
g_tunsafe_backend_bsd->HandleSigAlrm(); g_tunsafe_backend_bsd->HandleSigAlrm();
} }
static void SigUsr1(int sig) {
}
static bool did_ctrlc; static bool did_ctrlc;
void SigInt(int sig) { void SigInt(int sig) {
@ -623,6 +722,7 @@ void SigInt(int sig) {
did_ctrlc = true; did_ctrlc = true;
write(1, "Ctrl-C detected. Exiting. Press again to force quit.\n", sizeof("Ctrl-C detected. Exiting. Press again to force quit.\n")-1); write(1, "Ctrl-C detected. Exiting. Press again to force quit.\n", sizeof("Ctrl-C detected. Exiting. Press again to force quit.\n")-1);
// todo: fix signal safety?
if (g_tunsafe_backend_bsd) if (g_tunsafe_backend_bsd)
g_tunsafe_backend_bsd->HandleExit(); g_tunsafe_backend_bsd->HandleExit();
} }
@ -631,7 +731,10 @@ void TunsafeBackendBsd::RunLoop() {
assert(!g_tunsafe_backend_bsd); assert(!g_tunsafe_backend_bsd);
assert(processor_); assert(processor_);
sigset_t mask;
g_tunsafe_backend_bsd = this; g_tunsafe_backend_bsd = this;
// We want an alarm signal every second. // We want an alarm signal every second.
{ {
struct sigaction act = {0}; struct sigaction act = {0};
@ -651,16 +754,14 @@ void TunsafeBackendBsd::RunLoop() {
} }
} }
{ #if defined(OS_LINUX) || defined(OS_FREEBSD)
struct sigaction act = {0}; sigemptyset(&mask);
act.sa_handler = SigUsr1; sigaddset(&mask, SIGALRM);
if (sigaction(SIGUSR1, &act, NULL) < 0) { if (sigprocmask(SIG_BLOCK, &mask, &orig_signal_mask_) < 0) {
RERROR("Unable to install SIGUSR1 handler."); perror("sigprocmask");
return; return;
} }
}
#if defined(OS_LINUX) || defined(OS_FREEBSD)
{ {
struct itimerspec tv = {0}; struct itimerspec tv = {0};
struct sigevent sev; struct sigevent sev;
@ -727,7 +828,17 @@ public:
bool is_connected_; bool is_connected_;
}; };
struct CommandLineOutput {
const char *filename_to_load;
const char *interface_name;
bool daemon;
};
int HandleCommandLine(int argc, char **argv, CommandLineOutput *output);
int main(int argc, char **argv) { int main(int argc, char **argv) {
CommandLineOutput cmd = {0};
InitCpuFeatures(); InitCpuFeatures();
if (argc == 2 && strcmp(argv[1], "--benchmark") == 0) { if (argc == 2 && strcmp(argv[1], "--benchmark") == 0) {
@ -735,12 +846,9 @@ int main(int argc, char **argv) {
return 0; return 0;
} }
fprintf(stderr, "%s\n", TUNSAFE_VERSION_STRING); int rv = HandleCommandLine(argc, argv, &cmd);
if (!cmd.filename_to_load)
if (argc < 2) { return rv;
fprintf(stderr, "Syntax: tunsafe file.conf\n");
return 1;
}
#if defined(OS_MACOSX) #if defined(OS_MACOSX)
InitOsxGetMilliseconds(); InitOsxGetMilliseconds();
@ -749,19 +857,29 @@ int main(int argc, char **argv) {
SetThreadName("tunsafe-m"); SetThreadName("tunsafe-m");
MyProcessorDelegate my_procdel; MyProcessorDelegate my_procdel;
TunsafeBackendBsd *socket_loop = CreateTunsafeBackendBsd(); TunsafeBackendBsd *backend = CreateTunsafeBackendBsd();
WireguardProcessor wg(socket_loop, socket_loop, &my_procdel); if (cmd.interface_name)
backend->SetTunDeviceName(cmd.interface_name);
WireguardProcessor wg(backend, backend, &my_procdel);
my_procdel.wg_processor_ = &wg; my_procdel.wg_processor_ = &wg;
socket_loop->SetProcessor(&wg); backend->SetProcessor(&wg);
DnsResolver dns_resolver(NULL); DnsResolver dns_resolver(NULL);
if (!ParseWireGuardConfigFile(&wg, argv[1], &dns_resolver)) return 1; if (*cmd.filename_to_load && !ParseWireGuardConfigFile(&wg, cmd.filename_to_load, &dns_resolver))
return 1;
if (!wg.Start()) return 1; if (!wg.Start()) return 1;
socket_loop->RunLoop(); if (cmd.daemon) {
socket_loop->CleanupRoutes(); fprintf(stderr, "Switching to daemon mode...\n");
delete socket_loop; if (daemon(0, 0) == -1)
perror("daemon() failed");
}
backend->RunLoop();
backend->CleanupRoutes();
delete backend;
return 0; return 0;
} }

View file

@ -7,6 +7,7 @@
#include "wireguard.h" #include "wireguard.h"
#include "wireguard_config.h" #include "wireguard_config.h"
#include <string> #include <string>
#include <signal.h>
struct RouteInfo { struct RouteInfo {
uint8 family; uint8 family;
@ -16,6 +17,39 @@ struct RouteInfo {
std::string dev; std::string dev;
}; };
#if defined(OS_LINUX)
// Keeps track of when the unix socket gets deleted
class UnixSocketDeletionWatcher {
public:
UnixSocketDeletionWatcher();
~UnixSocketDeletionWatcher();
bool Start(const char *path, bool *flag_to_set);
void Stop();
bool Poll(const char *path) { return false; }
private:
static void *RunThread(void *arg);
void *RunThreadInner();
const char *path_;
int inotify_fd_;
int pid_;
int pipes_[2];
pthread_t thread_;
bool *flag_to_set_;
};
#else // !defined(OS_LINUX)
// all other platforms that lack inotify
class UnixSocketDeletionWatcher {
public:
UnixSocketDeletionWatcher() {}
~UnixSocketDeletionWatcher() {}
bool Start(const char *path, bool *flag_to_set) { return true; }
void Stop() {}
bool Poll(const char *path);
};
#endif // !defined(OS_LINUX)
class TunsafeBackendBsd : public TunInterface, public UdpInterface { class TunsafeBackendBsd : public TunInterface, public UdpInterface {
public: public:
TunsafeBackendBsd(); TunsafeBackendBsd();
@ -24,10 +58,12 @@ public:
void RunLoop(); void RunLoop();
void CleanupRoutes(); void CleanupRoutes();
void SetTunDeviceName(const char *name);
void SetProcessor(WireguardProcessor *wg) { processor_ = wg; } void SetProcessor(WireguardProcessor *wg) { processor_ = wg; }
// -- from TunInterface // -- from TunInterface
virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) override; virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
virtual void HandleSigAlrm() = 0; virtual void HandleSigAlrm() = 0;
virtual void HandleExit() = 0; virtual void HandleExit() = 0;
@ -44,6 +80,9 @@ protected:
WireguardProcessor *processor_; WireguardProcessor *processor_;
std::vector<RouteInfo> cleanup_commands_; std::vector<RouteInfo> cleanup_commands_;
std::vector<std::string> pre_down_, post_down_; std::vector<std::string> pre_down_, post_down_;
sigset_t orig_signal_mask_;
char devname_[16];
bool tun_interface_gone_;
}; };
#if defined(OS_MACOSX) || defined(OS_FREEBSD) #if defined(OS_MACOSX) || defined(OS_FREEBSD)

File diff suppressed because it is too large Load diff

View file

@ -11,34 +11,39 @@
#include "tunsafe_threading.h" #include "tunsafe_threading.h"
#include <functional> #include <functional>
enum {
ADAPTER_GUID_SIZE = 40,
};
struct Packet; struct Packet;
class WireguardProcessor; class WireguardProcessor;
class TunsafeBackendWin32; class TunsafeBackendWin32;
class ThreadedPacketQueue { class PacketProcessor {
public: public:
explicit ThreadedPacketQueue(WireguardProcessor *wg, TunsafeBackendWin32 *backend); explicit PacketProcessor();
~ThreadedPacketQueue(); ~PacketProcessor();
enum { enum {
TARGET_PROCESSOR_UDP = 0, TARGET_PROCESSOR_UDP = 0,
TARGET_PROCESSOR_TUN = 1, TARGET_PROCESSOR_TUN = 1,
TARGET_UDP_DEVICE = 2, TARGET_UDP_DEVICE = 2,
TARGET_TUN_DEVICE = 3, TARGET_TUN_DEVICE = 3,
TARGET_CONFIG_PROTOCOL = 4,
}; };
void Start(); void Reset();
void Stop();
int Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend);
void Post(Packet *packet, Packet **end, int count); void Post(Packet *packet, Packet **end, int count);
void AbortingDriver(); void ForcePost(Packet *packet);
void PostExit(int exit_code);
const uint32 *posted_exit_code() { return &exit_code_; }
private: private:
void PostTimerInterrupt(); static void CALLBACK ThreadPoolTimerCallback(PTP_CALLBACK_INSTANCE iTimerInstance, PVOID pContext, PTP_TIMER);
static void CALLBACK TimerRoutine(LPVOID lpArgToCompletionRoutine, DWORD dwTimerLowValue, DWORD dwTimerHighValue); void HandleConfigurationProtocolPacket(WireguardProcessor *wg, TunsafeBackendWin32 *backend, Packet *packet);
DWORD ThreadMain();
static DWORD WINAPI ThreadedPacketQueueLauncher(VOID *x);
Packet *first_; Packet *first_;
Packet **last_ptr_; Packet **last_ptr_;
uint32 packets_in_queue_; uint32 packets_in_queue_;
@ -46,12 +51,8 @@ private:
Mutex mutex_; Mutex mutex_;
HANDLE event_; HANDLE event_;
HANDLE timer_handle_; uint32 exit_code_;
HANDLE handle_;
WireguardProcessor *wg_;
bool exit_flag_;
bool timer_interrupt_; bool timer_interrupt_;
TunsafeBackendWin32 *backend_;
}; };
// Encapsulates a UDP socket, optionally listening for incoming packets // Encapsulates a UDP socket, optionally listening for incoming packets
@ -61,17 +62,16 @@ public:
explicit UdpSocketWin32(); explicit UdpSocketWin32();
~UdpSocketWin32(); ~UdpSocketWin32();
void SetPacketHandler(ThreadedPacketQueue *packet_handler) { packet_handler_ = packet_handler; } void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
void StartThread(); void StartThread();
void StopThread(); void StopThread();
// -- from UdpInterface // -- from UdpInterface
virtual bool Initialize(int listen_on_port) override; virtual bool Configure(int listen_on_port) override;
virtual void WriteUdpPacket(Packet *packet) override; virtual void WriteUdpPacket(Packet *packet) override;
private: private:
void ThreadMain(); void ThreadMain();
static DWORD WINAPI UdpThread(void *x); static DWORD WINAPI UdpThread(void *x);
@ -80,7 +80,7 @@ private:
Mutex mutex_; Mutex mutex_;
ThreadedPacketQueue *packet_handler_; PacketProcessor *packet_handler_;
SOCKET socket_; SOCKET socket_;
SOCKET socket_ipv6_; SOCKET socket_ipv6_;
HANDLE completion_port_handle_; HANDLE completion_port_handle_;
@ -93,12 +93,12 @@ class DnsBlocker;
class TunWin32Adapter { class TunWin32Adapter {
public: public:
TunWin32Adapter(DnsBlocker *dns_blocker); TunWin32Adapter(DnsBlocker *dns_blocker, const char guid[ADAPTER_GUID_SIZE]);
~TunWin32Adapter(); ~TunWin32Adapter();
bool OpenAdapter(unsigned int *exit_thread, DWORD open_flags); bool OpenAdapter(TunsafeBackendWin32 *backend, DWORD open_flags);
bool InitAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out); bool ConfigureAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out);
void CloseAdapter(); void CloseAdapter(bool is_restart);
HANDLE handle() { return handle_; } HANDLE handle() { return handle_; }
@ -121,8 +121,10 @@ private:
NET_LUID interface_luid_; NET_LUID interface_luid_;
void *backend_;
std::vector<std::string> pre_down_, post_down_; std::vector<std::string> pre_down_, post_down_;
char guid_[64]; char guid_[ADAPTER_GUID_SIZE];
}; };
// Implementation of TUN interface handling using IO Completion Ports // Implementation of TUN interface handling using IO Completion Ports
@ -131,23 +133,23 @@ public:
explicit TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend); explicit TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend);
~TunWin32Iocp(); ~TunWin32Iocp();
void SetPacketHandler(ThreadedPacketQueue *packet_handler) { packet_handler_ = packet_handler; } void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
void StartThread(); void StartThread();
void StopThread(); void StopThread();
// -- from TunInterface // -- from TunInterface
virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) override; virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
virtual void WriteTunPacket(Packet *packet) override; virtual void WriteTunPacket(Packet *packet) override;
TunWin32Adapter &adapter() { return adapter_; } TunWin32Adapter &adapter() { return adapter_; }
private: private:
void CloseTun(); void CloseTun(bool is_restart);
void ThreadMain(); void ThreadMain();
static DWORD WINAPI TunThread(void *x); static DWORD WINAPI TunThread(void *x);
ThreadedPacketQueue *packet_handler_; PacketProcessor *packet_handler_;
HANDLE completion_port_handle_; HANDLE completion_port_handle_;
HANDLE thread_; HANDLE thread_;
@ -168,13 +170,13 @@ public:
explicit TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend); explicit TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend);
~TunWin32Overlapped(); ~TunWin32Overlapped();
void SetPacketHandler(ThreadedPacketQueue *packet_handler) { packet_handler_ = packet_handler; } void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
void StartThread(); void StartThread();
void StopThread(); void StopThread();
// -- from TunInterface // -- from TunInterface
virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) override; virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
virtual void WriteTunPacket(Packet *packet) override; virtual void WriteTunPacket(Packet *packet) override;
private: private:
@ -182,7 +184,7 @@ private:
void ThreadMain(); void ThreadMain();
static DWORD WINAPI TunThread(void *x); static DWORD WINAPI TunThread(void *x);
ThreadedPacketQueue *packet_handler_; PacketProcessor *packet_handler_;
HANDLE thread_; HANDLE thread_;
Mutex mutex_; Mutex mutex_;
@ -199,16 +201,18 @@ private:
}; };
class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate { class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate {
friend class ThreadedPacketQueue; friend class PacketProcessor;
friend class TunWin32Iocp; friend class TunWin32Iocp;
friend class TunWin32Overlapped; friend class TunWin32Overlapped;
friend class TunWin32Adapter;
public: public:
TunsafeBackendWin32(Delegate *delegate); TunsafeBackendWin32(Delegate *delegate);
~TunsafeBackendWin32(); ~TunsafeBackendWin32();
// -- from TunsafeBackend // -- from TunsafeBackend
virtual bool Initialize() override; virtual bool Configure() override;
virtual void Teardown() override; virtual void Teardown() override;
virtual bool SetTunAdapterName(const char *name) override;
virtual void Start(const char *config_file) override; virtual void Start(const char *config_file) override;
virtual void Stop() override; virtual void Stop() override;
virtual void RequestStats(bool enable) override; virtual void RequestStats(bool enable) override;
@ -218,13 +222,23 @@ public:
virtual void SetServiceStartupFlags(uint32 flags) override; virtual void SetServiceStartupFlags(uint32 flags) override;
virtual LinearizedGraph *GetGraph(int type) override; virtual LinearizedGraph *GetGraph(int type) override;
virtual std::string GetConfigFileName() override; virtual std::string GetConfigFileName() override;
virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) override;
// -- from ProcessorDelegate // -- from ProcessorDelegate
virtual void OnConnected() override; virtual void OnConnected() override;
virtual void OnConnectionRetry(uint32 attempts) override; virtual void OnConnectionRetry(uint32 attempts) override;
void SetPublicKey(const uint8 key[32]); void SetPublicKey(const uint8 key[32]);
void TunAdapterFailed(); void PostExit(int exit_code);
enum {
MODE_NONE = 0,
MODE_EXIT = 1,
MODE_RESTART = 2,
MODE_TUN_FAILED = 3,
};
uint32 exit_code() { return *packet_processor_.posted_exit_code(); }
void SetStatus(StatusCode status);
private: private:
void StopInner(bool is_restart); void StopInner(bool is_restart);
@ -232,16 +246,7 @@ private:
void PushStats(); void PushStats();
HANDLE worker_thread_; HANDLE worker_thread_;
enum {
MODE_NONE = 0,
MODE_EXIT = 1,
MODE_RESTART = 2,
MODE_TUN_FAILED = 3,
};
bool want_periodic_stats_; bool want_periodic_stats_;
unsigned int stop_mode_;
Delegate *delegate_; Delegate *delegate_;
char *config_file_; char *config_file_;
@ -256,6 +261,10 @@ private:
Mutex stats_mutex_; Mutex stats_mutex_;
WgProcessorStats stats_; WgProcessorStats stats_;
PacketProcessor packet_processor_;
char guid_[ADAPTER_GUID_SIZE];
}; };
// This class ensures that all callbacks get rescheduled to another thread // This class ensures that all callbacks get rescheduled to another thread
@ -265,13 +274,14 @@ public:
~TunsafeBackendDelegateThreaded(); ~TunsafeBackendDelegateThreaded();
private: private:
virtual void OnGetStats(const WgProcessorStats &stats); virtual void OnGetStats(const WgProcessorStats &stats) override;
virtual void OnGraphAvailable(); virtual void OnGraphAvailable() override;
virtual void OnStateChanged(); virtual void OnStateChanged() override;
virtual void OnClearLog(); virtual void OnClearLog() override;
virtual void OnLogLine(const char **s); virtual void OnLogLine(const char **s) override;
virtual void OnStatusCode(TunsafeBackend::StatusCode status); virtual void OnStatusCode(TunsafeBackend::StatusCode status) override;
virtual void DoWork(); virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) override;
virtual void DoWork() override;
enum Which { enum Which {
Id_OnGetStats, Id_OnGetStats,
@ -281,6 +291,7 @@ private:
Id_OnUpdateUI, Id_OnUpdateUI,
Id_OnStatusCode, Id_OnStatusCode,
Id_OnGraphAvailable, Id_OnGraphAvailable,
Id_OnConfigurationProtocolReply,
}; };
void AddEntry(Which which, intptr_t lparam = 0, uint32 wparam = 0); void AddEntry(Which which, intptr_t lparam = 0, uint32 wparam = 0);
@ -302,3 +313,37 @@ private:
std::vector<Entry> processing_entry_; std::vector<Entry> processing_entry_;
}; };
// For each adapter, remembers whether the adapter is in use
class TunAdaptersInUse {
public:
TunAdaptersInUse();
// attempt to acquire the adapter, so it can't be acquired by anyone else
bool Acquire(const char guid[ADAPTER_GUID_SIZE], void *context);
// mark as free
void Release(void *context);
// Lookup a context from a guid
void *LookupContextFromGuid(const char guid[ADAPTER_GUID_SIZE]);
// Lookup a guid from a context
bool LookupGuidFromContext(void *context, char guid[ADAPTER_GUID_SIZE]);
char *GetAllGuid();
static TunAdaptersInUse *GetInstance();
private:
enum {
kMaxAdaptersInUse = 16,
};
struct Entry {
char guid[ADAPTER_GUID_SIZE];
void *context;
int count;
};
Mutex mutex_;
uint8 num_inuse_;
Entry entry_[kMaxAdaptersInUse];
};

View file

@ -5,7 +5,6 @@
#include "stdafx.h" #include "stdafx.h"
#include "tunsafe_types.h" #include "tunsafe_types.h"
#include "wireguard.h" #include "wireguard.h"
#include <functional> #include <functional>
struct StatsCollector { struct StatsCollector {
@ -72,6 +71,7 @@ public:
virtual void OnClearLog() = 0; virtual void OnClearLog() = 0;
virtual void OnLogLine(const char **s) = 0; virtual void OnLogLine(const char **s) = 0;
virtual void OnStatusCode(TunsafeBackend::StatusCode status) = 0; virtual void OnStatusCode(TunsafeBackend::StatusCode status) = 0;
virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) = 0;
// This function is needed for CreateTunsafeBackendDelegateThreaded, // This function is needed for CreateTunsafeBackendDelegateThreaded,
// It's expected to be called on the main thread and then all callbacks will arrive // It's expected to be called on the main thread and then all callbacks will arrive
// on the right thread. // on the right thread.
@ -82,9 +82,16 @@ public:
virtual ~TunsafeBackend(); virtual ~TunsafeBackend();
// Setup/teardown the connection to the local service (if any) // Setup/teardown the connection to the local service (if any)
virtual bool Initialize() = 0; virtual bool Configure() = 0;
virtual void Teardown() = 0; virtual void Teardown() = 0;
// Set the name of the tun adapter that we want to use.
// On Windows this is the guid of the adapter.
// After having called this, this tun name cannot be used by any other instances.
// Returns false if the name can't be exclusively reserved to this adapter.
virtual bool SetTunAdapterName(const char *name) = 0;
virtual void Start(const char *config_file) = 0; virtual void Start(const char *config_file) = 0;
virtual void Stop() = 0; virtual void Stop() = 0;
virtual void RequestStats(bool enable) = 0; virtual void RequestStats(bool enable) = 0;
@ -93,10 +100,9 @@ public:
virtual InternetBlockState GetInternetBlockState(bool *is_activated) = 0; virtual InternetBlockState GetInternetBlockState(bool *is_activated) = 0;
virtual void SetInternetBlockState(InternetBlockState s) = 0; virtual void SetInternetBlockState(InternetBlockState s) = 0;
virtual void SetServiceStartupFlags(uint32 flags) = 0; virtual void SetServiceStartupFlags(uint32 flags) = 0;
virtual std::string GetConfigFileName() = 0; virtual std::string GetConfigFileName() = 0;
virtual LinearizedGraph *GetGraph(int type) = 0; virtual LinearizedGraph *GetGraph(int type) = 0;
virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) = 0;
bool is_started() { return is_started_; } bool is_started() { return is_started_; }
bool is_remote() { return is_remote_; } bool is_remote() { return is_remote_; }
@ -105,6 +111,9 @@ public:
StatusCode status() { return status_; } StatusCode status() { return status_; }
uint32 GetIP() { return ipv4_ip_; } uint32 GetIP() { return ipv4_ip_; }
static TunsafeBackend *FindBackendByTunGuid(const char *guid);
static char *GetAllGuid();
protected: protected:
bool is_started_; bool is_started_;
bool is_remote_; bool is_remote_;

374
service_pipe_win32.cpp Normal file
View file

@ -0,0 +1,374 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#include "stdafx.h"
#include "service_pipe_win32.h"
#include "util.h"
#include "service_win32_constants.h"
///////////////////////////////////////////////////////////////////////////////////////
// PipeManager
///////////////////////////////////////////////////////////////////////////////////////
PipeManager::PipeManager(const char *pipe_name, bool is_server_pipe, Delegate *delegate) {
pipe_name_ = _strdup(pipe_name);
is_server_pipe_ = is_server_pipe;
for (size_t i = 0; i < kMaxConnections * 2 + 1; i++)
events_[i] = CreateEvent(NULL, i != 0, FALSE, NULL); // For Exit
delegate_ = delegate;
thread_ = NULL;
exit_thread_ = false;
thread_id_ = 0;
for (size_t i = 0; i != kMaxConnections; i++)
connections_[i].Configure(this, (int)i);
connections_[0].state_ = PipeConnection::kStateStarting;
}
PipeManager::~PipeManager() {
StopThread();
for (size_t i = 0; i < kMaxConnections * 2 + 1; i++)
CloseHandle(events_[i]);
free(pipe_name_);
}
bool PipeManager::StartThread() {
assert(thread_ == NULL);
thread_ = CreateThread(NULL, 0, &StaticThreadMain, this, 0, &thread_id_);
return thread_ != NULL;
}
void PipeManager::StopThread() {
if (thread_ != NULL) {
exit_thread_ = true;
SetEvent(events_[0]);
WaitForSingleObject(thread_, INFINITE);
CloseHandle(thread_);
thread_ = NULL;
}
}
bool PipeManager::VerifyThread() {
return thread_id_ == GetCurrentThreadId();
}
void PipeManager::TryStartNewListener() {
assert(VerifyThread());
assert(is_server_pipe_);
// Check if any thread is in the listener state, if not, start
PipeConnection *found_conn = NULL;
for (size_t i = 0; i < kMaxConnections; i++) {
PipeConnection *conn = &connections_[i];
if (conn->connection_established_)
continue;
if (conn->state_ == PipeConnection::kStateWaitConnect)
return;
if (conn->state_ == PipeConnection::kStateNone && found_conn == NULL)
found_conn = conn;
}
if (found_conn) {
found_conn->state_ = PipeConnection::kStateStarting;
found_conn->AdvanceStateMachine();
}
}
DWORD WINAPI PipeManager::StaticThreadMain(void *x) {
return ((PipeManager*)x)->ThreadMain();
}
DWORD PipeManager::ThreadMain() {
assert(VerifyThread());
for (size_t i = 0; i < kMaxConnections; i++)
connections_[i].AdvanceStateMachine();
for (;;) {
DWORD rv = WaitForMultipleObjects(1 + kMaxConnections * 2, events_, FALSE, INFINITE);
// notify?
if (rv == WAIT_OBJECT_0) {
if (exit_thread_)
break;
delegate_->HandleNotify();
// The notification event is set when there might be new messages to send,
// so try to send them.
for (size_t i = 0; i != kMaxConnections; i++)
connections_[i].TrySendNextQueuedWrite();
} else if (rv >= WAIT_OBJECT_0 + 1 && rv < WAIT_OBJECT_0 + 1 + kMaxConnections * 2) {
PipeConnection *conn = &connections_[(rv - 1) >> 1];
if (rv & 1) {
// read finished
conn->AdvanceStateMachine();
} else {
// is the write event
conn->HandleWriteComplete();
}
} else {
assert(0);
}
}
return 0;
}
///////////////////////////////////////////////////////////////////////////////////////
// PipeConnection
///////////////////////////////////////////////////////////////////////////////////////
static void ClearPipeOverlapped(OVERLAPPED *ov) {
ov->Internal = 0;
ov->InternalHigh = 0;
ov->Offset = 0;
ov->OffsetHigh = 0;
}
PipeConnection::PipeConnection() {
pipe_ = INVALID_HANDLE_VALUE;
packets_ = NULL;
packets_end_ = &packets_;
write_overlapped_active_ = false;
connection_established_ = false;
state_ = kStateNone;
tmp_packet_buf_ = NULL;
tmp_packet_size_ = 0;
manager_ = NULL;
delegate_ = NULL;
}
PipeConnection::~PipeConnection() {
}
void PipeConnection::Configure(PipeManager *manager, int slot) {
manager_ = manager;
read_overlapped_.hEvent = manager->events_[1 + slot * 2];
write_overlapped_.hEvent = manager->events_[1 + slot * 2 + 1];
}
int PipeConnection::InitializeServerPipeAndConnect() {
int BUFSIZE = 8192;
SECURITY_ATTRIBUTES saPipeSecurity = {0};
uint8 buf[SECURITY_DESCRIPTOR_MIN_LENGTH];
PSECURITY_DESCRIPTOR pPipeSD = (PSECURITY_DESCRIPTOR)buf;
if (!InitializeSecurityDescriptor(pPipeSD, SECURITY_DESCRIPTOR_REVISION))
return -1;
// set NULL DACL on the SD
if (!SetSecurityDescriptorDacl(pPipeSD, TRUE, (PACL)NULL, FALSE))
return -1;
// now set up the security attributes
saPipeSecurity.nLength = sizeof(SECURITY_ATTRIBUTES);
saPipeSecurity.bInheritHandle = TRUE;
saPipeSecurity.lpSecurityDescriptor = pPipeSD;
pipe_ = CreateNamedPipe(manager_->pipe_name_,
PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED,
PIPE_TYPE_BYTE | PIPE_READMODE_BYTE | PIPE_REJECT_REMOTE_CLIENTS | PIPE_WAIT,
PIPE_UNLIMITED_INSTANCES,
BUFSIZE, BUFSIZE, 0, &saPipeSecurity);
if (pipe_ == INVALID_HANDLE_VALUE)
return -1;
ClearPipeOverlapped(&read_overlapped_);
// It seems like ConnectNamedPipe never sets the event object if it completes
// right away.
if (!ConnectNamedPipe(pipe_, &read_overlapped_)) {
DWORD rv = GetLastError();
return (rv == ERROR_IO_PENDING) ? 0 : (rv == ERROR_PIPE_CONNECTED) ? 1 : -1;
} else {
return 1;
}
}
bool PipeConnection::InitializeClientPipe() {
assert(pipe_ == INVALID_HANDLE_VALUE);
pipe_ = CreateFile(manager_->pipe_name_, GENERIC_READ | GENERIC_WRITE, 0, NULL,
OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL);
return (pipe_ != INVALID_HANDLE_VALUE);
}
void PipeConnection::ClosePipe() {
if (pipe_ != INVALID_HANDLE_VALUE) {
CancelIo(pipe_);
CloseHandle(pipe_);
pipe_ = INVALID_HANDLE_VALUE;
}
connection_established_ = false;
write_overlapped_active_ = false;
state_ = kStateNone;
free(tmp_packet_buf_);
tmp_packet_buf_ = NULL;
tmp_packet_size_ = 0;
ResetEvent(read_overlapped_.hEvent);
ResetEvent(write_overlapped_.hEvent);
packets_mutex_.Acquire();
OutgoingPacket *packets = packets_;
packets_ = NULL;
packets_end_ = &packets_;
packets_mutex_.Release();
while (packets) {
OutgoingPacket *p = packets;
packets = p->next;
free(p);
}
}
void PipeConnection::HandleWriteComplete() {
assert(write_overlapped_active_);
write_overlapped_active_ = false;
// Remove the packet from the front of the queue, now that it was sent.
packets_mutex_.Acquire();
OutgoingPacket *p = packets_;
if ((packets_ = p->next) == NULL)
packets_end_ = &packets_;
packets_mutex_.Release();
free(p);
if (packets_ == NULL && state_ == kStateWaitTimeout)
AdvanceStateMachine();
else
TrySendNextQueuedWrite();
}
bool PipeConnection::WritePacket(int type, const uint8 *data, size_t data_size) {
OutgoingPacket *packet = (OutgoingPacket *)malloc(offsetof(OutgoingPacket, data[data_size + 1]));
if (packet) {
packet->size = (uint32)(data_size + 1);
packet->data[0] = type;
memcpy(packet->data + 1, data, data_size);
packet->next = NULL;
packets_mutex_.Acquire();
OutgoingPacket *was_empty = packets_;
// login messages are always queued up front
if (type == TS_SERVICE_REQ_LOGIN) {
packet->next = packets_;
if (packet->next == NULL)
packets_end_ = &packet->next;
packets_ = packet;
} else {
*packets_end_ = packet;
packets_end_ = &packet->next;
}
packets_mutex_.Release();
if (was_empty == NULL) {
// Only allow the pipe thread to invoke the send
if (GetCurrentThreadId() == manager_->thread_id_) {
TrySendNextQueuedWrite();
} else {
SetEvent(manager_->notify_handle());
}
}
}
return true;
}
bool PipeConnection::VerifyThread() {
return manager_->VerifyThread();
}
void PipeConnection::TrySendNextQueuedWrite() {
assert(manager_->VerifyThread());
if (!write_overlapped_active_) {
OutgoingPacket *p = packets_;
if (p && connection_established_) {
ClearPipeOverlapped(&write_overlapped_);
if (WriteFile(pipe_, &p->size, p->size + 4, NULL, &write_overlapped_) || GetLastError() == ERROR_IO_PENDING)
write_overlapped_active_ = true;
} else {
ResetEvent(write_overlapped_.hEvent);
}
}
}
#define TS_WAIT_BEGIN(t) switch(state_) { case t:
#define TS_WAIT_POINT(t) state_ = (t); return; case t:
#define TS_WAIT_END() }
void PipeConnection::AdvanceStateMachine() {
DWORD rv;
int srv;
TS_WAIT_BEGIN(kStateStarting)
// Create a named pipe and wait for connections from the UI process
if (manager_->is_server_pipe_) {
srv = InitializeServerPipeAndConnect();
if (srv < 0) {
if (!manager_->exit_thread_)
ExitProcess(1);
ClosePipe();
return;
}
if (srv == 0) {
TS_WAIT_POINT(kStateWaitConnect);
}
} else {
if (!InitializeClientPipe()) {
RINFO("Unable to connect to the TunSafe Service. Please make sure it's running.");
ClosePipe();
return;
}
}
connection_established_ = true;
delegate_ = manager_->delegate_->HandleNewConnection(this);
TrySendNextQueuedWrite();
for (;;) {
// Read the packet length
read_pos_ = 0;
do {
ClearPipeOverlapped(&read_overlapped_);
if (!ReadFile(pipe_, (uint8*)&packet_size_ + read_pos_, 4 - read_pos_, NULL, &read_overlapped_)) {
if ((rv = GetLastError()) != ERROR_IO_PENDING)
goto fail;
TS_WAIT_POINT(kStateWaitReadLength);
}
if ((uint32)read_overlapped_.InternalHigh == 0)
goto fail;
read_pos_ += (uint32)read_overlapped_.InternalHigh;
} while (read_pos_ != 4);
assert(packet_size_ != 0 && packet_size_ < 0x1000000);
if (packet_size_ == 0 || packet_size_ >= 0x1000000)
break;
free(tmp_packet_buf_);
tmp_packet_buf_ = (uint8*)malloc(packet_size_);
if (!tmp_packet_buf_)
break;
// Read the packet payload
read_pos_ = 0;
do {
ClearPipeOverlapped(&read_overlapped_);
if (!ReadFile(pipe_, tmp_packet_buf_ + read_pos_, packet_size_ - read_pos_, NULL, &read_overlapped_)) {
if ((rv = GetLastError()) != ERROR_IO_PENDING)
goto fail;
TS_WAIT_POINT(kStateWaitReadPayload);
}
if ((uint32)read_overlapped_.InternalHigh == 0)
goto fail;
read_pos_ += (uint32)read_overlapped_.InternalHigh;
} while (read_pos_ != packet_size_);
if (!delegate_->HandleMessage(tmp_packet_buf_[0], tmp_packet_buf_ + 1, packet_size_ - 1)) {
ResetEvent(read_overlapped_.hEvent);
if (packets_ != NULL) {
TS_WAIT_POINT(kStateWaitTimeout);
}
break;
}
}
fail:
ClosePipe();
if (!manager_->exit_thread_) {
delegate_->HandleDisconnect();
if (manager_->is_server_pipe_)
manager_->TryStartNewListener();
}
TS_WAIT_END()
}

111
service_pipe_win32.h Normal file
View file

@ -0,0 +1,111 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#pragma once
#include "tunsafe_threading.h"
#include "network_win32_api.h"
class PipeManager;
// Once a pipe connects, this object is used to facilitate the connection
class PipeConnection {
friend class PipeManager;
public:
class Delegate {
public:
virtual bool HandleMessage(int type, uint8 *data, size_t size) = 0;
virtual void HandleDisconnect() = 0;
};
PipeConnection();
~PipeConnection();
void Configure(PipeManager *manager, int slot);
bool WritePacket(int type, const uint8 *data, size_t data_size);
HANDLE pipe_handle() { return pipe_; }
bool is_connected() { return connection_established_; }
bool VerifyThread();
private:
// -1 = fail, 0 = wait, 1 = conn
int InitializeServerPipeAndConnect();
bool InitializeClientPipe();
void AdvanceStateMachine();
void ClosePipe();
void TrySendNextQueuedWrite();
void HandleWriteComplete();
Delegate *delegate_;
PipeManager *manager_;
HANDLE pipe_;
bool write_overlapped_active_;
bool connection_established_;
enum State {
kStateNone,
kStateStarting,
kStateWaitConnect,
kStateWaitReadLength,
kStateWaitReadPayload,
kStateWaitTimeout,
};
uint8 state_;
uint32 packet_size_;
struct OutgoingPacket {
OutgoingPacket *next;
uint32 size;
uint8 data[0];
};
OutgoingPacket *packets_, **packets_end_;
uint8 *tmp_packet_buf_;
DWORD tmp_packet_size_;
DWORD read_pos_;
OVERLAPPED write_overlapped_, read_overlapped_;
Mutex packets_mutex_;
};
// This class supports multiple PipeConnections and calls HandleNewConnection
// when a new pipe connection is established.
class PipeManager {
friend class PipeConnection;
public:
class Delegate {
public:
// Called when a new connection is established
virtual PipeConnection::Delegate *HandleNewConnection(PipeConnection *handler) = 0;
// Called when a notification event was pushed
virtual void HandleNotify() = 0;
};
PipeManager(const char *pipe_name, bool is_server_pipe, Delegate *delegate);
~PipeManager();
bool StartThread();
void StopThread();
bool VerifyThread();
HANDLE notify_handle() { return events_[0]; }
PipeConnection *GetClientConnection() { return &connections_[0]; }
void TryStartNewListener();
private:
DWORD ThreadMain();
static DWORD WINAPI StaticThreadMain(void *x);
Delegate *delegate_;
HANDLE thread_;
char *pipe_name_;
DWORD thread_id_;
bool is_server_pipe_;
bool exit_thread_;
enum { kMaxConnections = 2 };
HANDLE events_[1 + kMaxConnections * 2];
PipeConnection connections_[kMaxConnections];
};

File diff suppressed because it is too large Load diff

View file

@ -3,13 +3,162 @@
#pragma once #pragma once
#include "service_win32_api.h" #include "service_win32_api.h"
#include <strsafe.h> #include "service_pipe_win32.h"
#include "util.h"
#include "network_win32_api.h" #include "network_win32_api.h"
#include "tunsafe_threading.h" #include "tunsafe_threading.h"
#include <algorithm>
#include <string> // Takes care of multiple TunsafeServiceBackend
#include <assert.h> class TunsafeServiceManager : public PipeManager::Delegate {
friend class TunsafeServiceBackend;
friend class TunsafeServiceServer;
public:
TunsafeServiceManager();
virtual ~TunsafeServiceManager();
// -- from PipeManager::Delegate
virtual void HandleNotify() override;
virtual PipeConnection::Delegate *HandleNewConnection(PipeConnection *connection) override;
// Called by the service control code to bring the service up or down
unsigned OnStart(int argc, wchar_t **argv);
void OnStop();
void OnShutdown();
TunsafeServiceBackend *main_backend() { return main_backend_; }
TunsafeServiceBackend *CreateBackend(const char *guid);
void DestroyBackend(TunsafeServiceBackend *backend);
bool SwitchInterface(TunsafeServiceServer *server, const char *interfac, bool want_create);
private:
// Points at the Tunsafe hklm reg key
HKEY hkey_;
uint32 server_unique_id_;
PipeManager pipe_manager_;
TunsafeServiceBackend *main_backend_;
std::vector<TunsafeServiceBackend *> backends_;
};
// One of these exist for each TunsafeBackend
class TunsafeServiceBackend : public TunsafeBackend::Delegate {
friend class TunsafeServiceServer;
public:
explicit TunsafeServiceBackend(TunsafeServiceManager *manager);
virtual ~TunsafeServiceBackend();
// -- from TunsafeBackend::Delegate
virtual void OnGetStats(const WgProcessorStats &stats) override;
virtual void OnClearLog() override;
virtual void OnLogLine(const char **s) override;
virtual void OnStateChanged() override;
virtual void OnStatusCode(TunsafeBackend::StatusCode status) override;
virtual void OnGraphAvailable() override;
virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) override;
TunsafeBackend *backend() { return backend_; }
TunsafeBackend::Delegate *delegate() { return thread_delegate_; }
void Start(const char *filename);
void RememberLastUsedConfigFile(const char *filename);
void Stop();
// Trigger backend stats updates whenever a connected pipe client needs it
void UpdateRequestStats();
// Called by TunsafeServiceManager::HandleNotify to process events
// on each backend.
void HandleNotify();
// Send a state update to all connected pipes unless filter is set, then it
// sends only to that.
void SendStateUpdate(TunsafeServiceServer *filter);
// Called whenever a pipe server disconnects
void RemovePipeServer(TunsafeServiceServer *pipe_server);
// Called to register a pipe server with this backend
void AddPipeServer(TunsafeServiceServer *pipe_server);
private:
// Points at the service manager
TunsafeServiceManager *manager_;
// Points at the actual TunsafeBackend
TunsafeBackend *backend_;
// Points at all |TunsafeServiceServer| currently associated with this
// backend.
std::vector<TunsafeServiceServer*> pipe_servers_;
// Points at the thing that transmits TunsafeBackend events to
// the main thread
TunsafeBackend::Delegate *thread_delegate_;
// The config filename that is loaded
std::string current_filename_;
// Positions into |historical_log_lines_|
uint32 historical_log_lines_pos_;
uint32 historical_log_lines_count_;
enum { LOGLINE_COUNT = 256 };
char *historical_log_lines_[LOGLINE_COUNT];
};
// The server side of the client<->server pipe connection
class TunsafeServiceServer : public PipeConnection::Delegate {
public:
TunsafeServiceServer(PipeConnection *pipe, TunsafeServiceBackend *backend, uint32 unique_id);
virtual ~TunsafeServiceServer();
void WritePacket(int type, const uint8 *data, size_t data_size);
// -- from PipeConnection::Delegate
virtual bool HandleMessage(int type, uint8 *data, size_t size) override;
virtual void HandleDisconnect() override;
// Called by TunsafeServiceBackend to push a graph to the client
void OnGraphAvailable();
// Called by TunsafeServiceBackend to push more log lines to the client
void SendQueuedLogLines();
bool want_stats() const { return want_stats_; }
bool want_state_updates() const { return want_state_updates_; }
uint32 unique_id() const { return unique_id_; }
TunsafeServiceBackend *service_backend() { return service_backend_; }
void set_service_backend(TunsafeServiceBackend *sb) { service_backend_ = sb; }
private:
bool AuthenticateUser();
// Whether the client wants state updates
bool want_state_updates_;
// Whether the client has authenticated
bool did_authenticate_user_;
// Whether we want stats
bool want_stats_;
// Whether the currently connected user wants a graph
uint32 want_graph_type_;
// The last log line sent to the currently connected user
uint32 last_line_sent_;
uint32 unique_id_;
// The pipe used to communicate
PipeConnection *connection_;
// The backend we're currently associated with
TunsafeServiceBackend *service_backend_;
};
struct ServiceState { struct ServiceState {
uint8 is_started : 1; uint8 is_started : 1;
@ -22,135 +171,15 @@ struct ServiceState {
STATIC_ASSERT(sizeof(ServiceState) == 128, ServiceState_wrong_size); STATIC_ASSERT(sizeof(ServiceState) == 128, ServiceState_wrong_size);
class PipeMessageHandler { class TunsafeServiceClient : public TunsafeBackend, public PipeConnection::Delegate, public PipeManager::Delegate {
public:
class Delegate {
public:
virtual bool HandleMessage(int type, uint8 *data, size_t size) = 0;
virtual bool HandleNotify() = 0;
virtual void HandleNewConnection() = 0;
virtual void HandleDisconnect() = 0;
};
PipeMessageHandler(const char *pipe_name, bool is_server_pipe, Delegate *delegate);
~PipeMessageHandler();
bool StartThread();
void StopThread();
bool WritePacket(int type, const uint8 *data, size_t data_size);
HANDLE notify_handle() { return wait_handles_[1]; }
HANDLE pipe_handle() { return pipe_; }
bool VerifyThread();
bool is_connected() { return connection_established_; }
private:
bool InitializeServerPipeAndWait();
bool InitializeClientPipe();
void AdvanceStateMachine();
void ClosePipe();
DWORD ThreadMain();
void SendNextQueuedWrite();
static DWORD WINAPI StaticThreadMain(void *x);
Delegate *delegate_;
HANDLE pipe_;
HANDLE thread_;
HANDLE wait_handles_[3];
bool write_overlapped_active_;
bool exit_thread_;
bool is_server_pipe_;
bool connection_established_;
char *pipe_name_;
enum State {
kStateNone,
kStateWaitConnect,
kStateWaitReadLength,
kStateWaitReadPayload,
kStateWaitTimeout,
};
int state_;
struct OutgoingPacket {
OutgoingPacket *next;
uint32 size;
uint8 data[0];
};
OutgoingPacket *packets_, **packets_end_;
uint8 *tmp_packet_buf_;
DWORD tmp_packet_size_;
OVERLAPPED write_overlapped_, read_overlapped_;
Mutex packets_mutex_;
DWORD thread_id_;
};
class TunsafeServiceImpl : public TunsafeBackend::Delegate, public PipeMessageHandler::Delegate {
public:
TunsafeServiceImpl();
virtual ~TunsafeServiceImpl();
// -- from TunsafeBackend::Delegate
virtual void OnGetStats(const WgProcessorStats &stats);
virtual void OnClearLog();
virtual void OnLogLine(const char **s);
virtual void OnStateChanged();
virtual void OnStatusCode(TunsafeBackend::StatusCode status);
virtual void OnGraphAvailable();
// -- from PipeMessageHandler::Delegate
virtual bool HandleMessage(int type, uint8 *data, size_t size);
virtual bool HandleNotify();
virtual void HandleNewConnection();
virtual void HandleDisconnect();
// virtual methods
virtual unsigned OnStart(int argc, wchar_t **argv);
virtual void OnStop();
virtual void OnShutdown();
TunsafeBackend::Delegate *delegate() { return thread_delegate_; }
private:
void SendQueuedLogLines();
bool AuthenticateUser();
bool did_send_getstate_;
bool did_authenticate_user_;
uint32 want_graph_type_;
HKEY hkey_;
TunsafeBackend *backend_;
TunsafeBackend::Delegate *thread_delegate_;
PipeMessageHandler message_handler_;
uint32 historical_log_lines_pos_;
uint32 historical_log_lines_count_;
uint32 last_line_sent_;
std::string current_filename_;
enum {
LOGLINE_COUNT = 256
};
char *historical_log_lines_[LOGLINE_COUNT];
};
class TunsafeServiceClient : public TunsafeBackend, public PipeMessageHandler::Delegate {
public: public:
TunsafeServiceClient(TunsafeBackend::Delegate *delegate); TunsafeServiceClient(TunsafeBackend::Delegate *delegate);
virtual ~TunsafeServiceClient(); virtual ~TunsafeServiceClient();
virtual bool Initialize();
// -- from TunsafeBackend
virtual bool Configure();
virtual void Teardown(); virtual void Teardown();
virtual bool SetTunAdapterName(const char *name);
virtual void Start(const char *config_file); virtual void Start(const char *config_file);
virtual void Stop(); virtual void Stop();
virtual void RequestStats(bool enable); virtual void RequestStats(bool enable);
@ -160,12 +189,16 @@ public:
virtual std::string GetConfigFileName(); virtual std::string GetConfigFileName();
virtual void SetServiceStartupFlags(uint32 flags); virtual void SetServiceStartupFlags(uint32 flags);
virtual LinearizedGraph *GetGraph(int type); virtual LinearizedGraph *GetGraph(int type);
virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) override;
// -- from PipeConnection::Delegate
virtual bool HandleMessage(int type, uint8 *data, size_t size) override;
virtual void HandleDisconnect() override;
// -- from PipeManager::Delegate
virtual void HandleNotify() override;
virtual PipeConnection::Delegate *HandleNewConnection(PipeConnection *connection) override;
// -- from PipeMessageHandler::Delegate
virtual bool HandleMessage(int type, uint8 *data, size_t size);
virtual bool HandleNotify();
virtual void HandleNewConnection();
virtual void HandleDisconnect();
protected: protected:
TunsafeBackend::Delegate *delegate_; TunsafeBackend::Delegate *delegate_;
@ -173,8 +206,10 @@ protected:
bool got_state_from_control_; bool got_state_from_control_;
ServiceState service_state_; ServiceState service_state_;
std::string config_file_; std::string config_file_;
PipeMessageHandler message_handler_; PipeManager pipe_manager_;
PipeConnection *connection_;
LinearizedGraph *cached_graph_; LinearizedGraph *cached_graph_;
uint32 last_graph_type_; uint32 last_graph_type_;
Mutex mutex_; Mutex mutex_;
}; };

36
service_win32_constants.h Normal file
View file

@ -0,0 +1,36 @@
#pragma once
#define TUNSAFE_PIPE_NAME "\\\\.\\pipe\\TunSafe\\ServiceControl"
#define TUNSAFE_SERVICE_PROTOCOL_VERSION 20180916001
enum {
TS_SERVICE_REQ_LOGIN = 0,
TS_SERVICE_REQ_START = 1,
TS_SERVICE_REQ_STOP = 2,
TS_SERVICE_REQ_GETSTATS = 4,
TS_SERVICE_REQ_SET_INTERNET_BLOCKSTATE = 5,
TS_SERVICE_REQ_RESETSTATS = 6,
TS_SERVICE_REQ_SET_STARTUP_FLAGS = 7,
TS_SERVICE_MSG_STATE = 8,
TS_SERVICE_MSG_LOGLINE = 9,
TS_SERVICE_MSG_ERROR_REPLY = 10,
TS_SERVICE_MSG_STATS = 11,
TS_SERVICE_MSG_CLEARLOG = 12,
TS_SERVICE_MSG_STATUS_CODE = 14,
TS_SERVICE_REQ_GET_GRAPH = 15,
TS_SERVICE_MSG_GRAPH = 16,
TS_SERVICE_REQ_TEXT_PROTOCOL = 17,
TS_SERVICE_REQ_TEXT_PROTOCOL_REPLY = 18,
TS_SERVICE_REQ_GETINTERFACES = 19,
TS_SERVICE_REQ_GETINTERFACES_REPLY = 20,
};
enum {
kTsMaxDevnameSize = 40
};

View file

@ -13,10 +13,14 @@
#if defined(OS_WIN) #if defined(OS_WIN)
#define _WINSOCK_DEPRECATED_NO_WARNINGS 1 #define _WINSOCK_DEPRECATED_NO_WARNINGS 1
#define _HAS_EXCEPTIONS 0
#define _CRT_SECURE_NO_WARNINGS 1
//#include <Winsock2.h> //#include <Winsock2.h>
#include <Ws2tcpip.h> #include <Ws2tcpip.h>
#include <Windows.h> #include <Windows.h>
#undef max
//#include <winsock2.h> //#include <winsock2.h>
#include <ws2ipdef.h> #include <ws2ipdef.h>
#include <iphlpapi.h> #include <iphlpapi.h>

883
ts.cpp Normal file
View file

@ -0,0 +1,883 @@
#include "stdafx.h"
#include "tunsafe_types.h"
#include "netapi.h"
#include "crypto/curve25519-donna.h"
#include "util.h"
#include "wireguard_proto.h"
#include <string.h>
#include <algorithm>
#if defined(OS_WIN)
#include "util_win32.h"
#include "service_pipe_win32.h"
#include "service_win32_constants.h"
#endif // defined(OS_WIN)
#if defined(OS_POSIX)
#include <sys/stat.h>
#include <sys/un.h>
#include <unistd.h>
#include <dirent.h>
#include <errno.h>
#endif // defined(OS_WIN)
#pragma comment(lib, "ws2_32.lib")
#define ANSI_RESET "\x1b[0m"
#define ANSI_BOLD "\x1b[1m"
#define ANSI_FG_BLACK "\x1b[30m"
#define ANSI_FG_RED "\x1b[31m"
#define ANSI_FG_GREEN "\x1b[32m"
#define ANSI_FG_YELLOW "\x1b[33m"
#define ANSI_FG_BLUE "\x1b[34m"
#define ANSI_FG_MAGENTA "\x1b[35m"
#define ANSI_FG_CYAN "\x1b[36m"
#define ANSI_FG_WHITE "\x1b[37m"
static const uint8 kCurve25519Basepoint[32] = {9};
#if defined(OS_WIN)
#define EXENAME "ts"
static bool SendMessageToService(HANDLE pipe, int message, const void *data, size_t data_size) {
uint8 *temp = new uint8[data_size + 5];
*(uint32*)temp = (uint32)(data_size + 1);
temp[4] = (uint8)message;
memcpy(temp + 5, data, data_size);
// Write the whole thing
DWORD pos = 0, bytes_to_write = (DWORD)(data_size + 5), bytes_written;
do {
if (!WriteFile(pipe, temp + pos, bytes_to_write, &bytes_written, NULL)) {
fprintf(stderr, "Error writing to service pipe, error = %d\n", GetLastError());
break;
}
pos += bytes_written;
bytes_to_write -= bytes_written;
} while (bytes_to_write != 0);
delete[] temp;
return (bytes_to_write == 0);
}
static bool ReadExactBytesFromPipe(HANDLE pipe, const void *data, DWORD bytes_to_read) {
DWORD pos = 0, n;
do {
if (!ReadFile(pipe, (uint8*)data + pos, bytes_to_read, &n, NULL))
return false;
if (n == 0)
return false; // premature eof..
pos += n;
bytes_to_read -= n;
} while (bytes_to_read != 0);
return true;
}
static bool ReadMessageFromService(HANDLE pipe, int *message, std::string *data) {
uint8 header[5];
uint32 message_size;
if (!ReadExactBytesFromPipe(pipe, header, 5) || (message_size = *(uint32*)header) == 0) {
fprintf(stderr, "Error reading from service pipe, error = %d\n", GetLastError());
return false;
}
*message = header[4];
data->resize(message_size - 1);
if (message_size - 1 != 0 && !ReadExactBytesFromPipe(pipe, data->data(), message_size - 1)) {
fprintf(stderr, "Error reading from service pipe, error = %d\n", GetLastError());
return false;
}
return true;
}
struct ServiceLoginMessage {
uint64 version;
char interfac[kTsMaxDevnameSize];
bool want_state_updates;
bool want_create_interface;
};
static std::vector<GuidAndDevName> g_tap_adapters;
static bool g_did_get_adapters;
static const std::vector<GuidAndDevName> &GetTapAdapterInfo() {
if (!g_did_get_adapters) {
g_did_get_adapters = true;
GetTapAdapterInfo(&g_tap_adapters);
}
return g_tap_adapters;
}
static const char *GetGuidFromInterfaceName(const char *name) {
for (const GuidAndDevName &e : GetTapAdapterInfo())
if (strcmp(e.name, name) == 0)
return e.guid;
return NULL;
}
static const char *GetInterfaceNameFromGuid(const char *guid) {
for (const GuidAndDevName &e : GetTapAdapterInfo())
if (strcmp(e.guid, guid) == 0)
return e.name;
return NULL;
}
static HANDLE ConnectToService(const char *devname, bool want_updates, bool want_create = false) {
ServiceLoginMessage msg = {0};
msg.version = TUNSAFE_SERVICE_PROTOCOL_VERSION;
msg.want_state_updates = want_updates;
msg.want_create_interface = want_create;
// Rename devname to a guid
if (devname) {
const char *guid = (devname[0] == '{' || devname[0] == 0) ? devname : GetGuidFromInterfaceName(devname);
if (!guid) {
fprintf(stderr, "Interface '%s' not found\n", devname);
return NULL;
}
my_strlcpy(msg.interfac, sizeof(msg.interfac), guid);
}
for (;;) {
HANDLE pipe = CreateFile(TUNSAFE_PIPE_NAME, GENERIC_READ | GENERIC_WRITE, 0, NULL,
OPEN_EXISTING, 0, NULL);
if (pipe != INVALID_HANDLE_VALUE) {
if (!SendMessageToService(pipe, TS_SERVICE_REQ_LOGIN, &msg, sizeof(msg))) {
CloseHandle(pipe);
pipe = NULL;
}
return pipe;
}
DWORD error = GetLastError();
if (error != ERROR_PIPE_BUSY) {
fprintf(stderr, "Error connecting to TunSafe service, error = %d\n", error);
if (error == ERROR_FILE_NOT_FOUND)
fprintf(stderr, "Please check that the TunSafe service is started\n");
return NULL;
}
if (!WaitNamedPipe(TUNSAFE_PIPE_NAME, 10000)) {
fprintf(stderr, "Error connecting to TunSafe service, timed out.\n");
return NULL;
}
}
}
static bool CommunicateWithService(const char *devname, const std::string &query, std::string *reply) {
HANDLE pipe = ConnectToService(devname, false);
int message_code;
bool rv = false;
if (pipe != NULL &&
SendMessageToService(pipe, TS_SERVICE_REQ_TEXT_PROTOCOL, query.data(), query.size()) &&
ReadMessageFromService(pipe, &message_code, reply)) {
if (message_code == TS_SERVICE_REQ_TEXT_PROTOCOL_REPLY) {
rv = true;
} else {
if (message_code == TS_SERVICE_MSG_ERROR_REPLY) {
fprintf(stderr, "Error: %s\n", reply->c_str());
} else {
fprintf(stderr, "Unknown reply (%d) from TunSafe service.\n", message_code);
}
}
}
CloseHandle(pipe);
return rv;
}
static bool GetInterfaceList(std::string *result) {
HANDLE pipe = ConnectToService(NULL, false);
int message_code;
bool rv = false;
if (pipe != NULL &&
SendMessageToService(pipe, TS_SERVICE_REQ_GETINTERFACES, NULL, 0) &&
ReadMessageFromService(pipe, &message_code, result)) {
if (message_code == TS_SERVICE_REQ_GETINTERFACES_REPLY) {
rv = true;
} else {
fprintf(stderr, "GetInterfaceList: bad reply\n");
}
}
CloseHandle(pipe);
return rv;
}
#endif // defined(OS_WIN)
#if defined(OS_POSIX)
#define EXENAME "tunsafe"
static const char *GetGuidFromInterfaceName(const char *name) {
return name;
}
static const char *GetInterfaceNameFromGuid(const char *guid) {
return guid;
}
static int OpenUserspaceInterface(const char *iface) {
struct stat st;
struct sockaddr_un un = { 0 };
int fd = -1, rv;
if (strchr(iface, '/') != NULL) {
fprintf(stderr, "Unable to open usermode socket: No such device\n");
goto getout;
}
snprintf(un.sun_path, sizeof(un.sun_path), "/var/run/wireguard/%s.sock", iface);
if (stat(un.sun_path, &st) < 0) {
perror("Unable to open usermode socket");
goto getout;
}
if (!S_ISSOCK(st.st_mode)) {
fprintf(stderr, "Unable to open usermode socket: No such device\n");
goto getout;
}
fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (fd < 0)
goto getout;
un.sun_family = AF_UNIX;
if (connect(fd, (struct sockaddr *)&un, sizeof(un)) < 0) {
if (errno == ECONNREFUSED)
unlink(un.sun_path);
else
perror("Error opening wireguard usermode interface socket");
goto getout;
}
return fd;
getout:
if (fd >= 0)
close(fd);
return -1;
}
static bool GetInterfaceList(std::string *result) {
struct dirent *dent;
DIR *dir = opendir("/var/run/wireguard/");
if (!dir)
return errno == ENOENT;
while ((dent = readdir(dir)) != NULL) {
size_t len = strlen(dent->d_name);
static const char kSuffix[6] = ".sock";
if (len >= sizeof(kSuffix) - 1 &&
memcmp(&dent->d_name[len - (sizeof(kSuffix) - 1)], kSuffix, sizeof(kSuffix) - 1) == 0) {
dent->d_name[len - (sizeof(kSuffix) - 1)] = '\n';
result->append(dent->d_name, len - (sizeof(kSuffix) - 1) + 1);
}
}
closedir(dir);
return true;
}
static bool CommunicateWithService(const char *devname, const std::string &query, std::string *reply) {
ssize_t n;
char buf[4096];
bool rv = false;
reply->clear();
int fd = OpenUserspaceInterface(devname);
if (fd == -1)
return false;
for(size_t pos = 0; query.size() - pos; pos += n) {
n = write(fd, query.data() + pos, query.size() - pos);
if (n <= 0) {
perror("Error writing to service pipe");
goto getout;
}
}
for(;;) {
n = read(fd, buf, sizeof(buf));
if (n <= 0) {
if (n == 0) {
// ensure that it ends with \n\n
if (reply->size() >= 2 && (*reply)[reply->size() - 1] == '\n' && (*reply)[reply->size() - 2] == '\n') {
rv = true;
} else {
fprintf(stderr, "Bad reply from service pipe\n");
}
} else {
perror("Error reading from service pipe");
}
break;
}
reply->append(buf, n);
}
getout:
close(fd);
return rv;
}
static int HandleStopCommand(int argc, char **argv) {
if (argc != 1) {
fprintf(stderr, "Usage: " EXENAME " stop <interface>\n");
return 1;
}
struct sockaddr_un un;
struct stat st;
const char *iface = argv[0];
if (strchr(iface, '/')) {
fprintf(stderr, "No such interface\n");
return 1;
}
snprintf(un.sun_path, sizeof(un.sun_path), "/var/run/wireguard/%s.sock", iface);
if (unlink(un.sun_path) == -1) {
perror("unlink");
return 1;
}
return 0;
}
#endif // defined(OS_POSIX)
void ShowHelp() {
fprintf(stderr,
"Usage: " EXENAME " <cmd> [<args>]\n\n"
#if defined(OS_POSIX)
" " EXENAME " filename.conf\n\n"
#endif // defined(OS_POSIX)
"Available subcommands:\n"
" show: Shows the configuration and status of the interfaces\n"
" set: Change the configuration or the peer list\n"
" start: Start TunSafe on an interface\n"
" stop: Stop TunSafe on an interface\n"
#if defined(OS_WIN)
" log: Display recent log entries\n"
#endif // defined(OS_WIN)
" genkey: Writes a new private key to stdout\n"
" genpsk: Writes a new preshared key to stdout\n"
" pubkey: Reads a private key from stdin and writes its public key to stdout\n"
"To see more help about a subcommand, pass --help to it\n");
}
static bool ParseHexKeyToBase64(const char *key, char base64key[WG_PUBLIC_KEY_LEN_BASE64 + 1]) {
uint8 keybuf[32];
if (!ParseHexString(key, keybuf, 32))
return false;
return base64_encode(keybuf, 32, base64key, WG_PUBLIC_KEY_LEN_BASE64 + 1, NULL) != NULL;
}
static char *FormatTransferPart(char *buf, size_t bufsize, uint64 n) {
if (n < 1024)
snprintf(buf, bufsize, "%u " ANSI_FG_CYAN "B" ANSI_RESET, (unsigned)n);
else if (n < 1024 * 1024)
snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "KiB" ANSI_RESET, (double)n * (1.0 / 1024));
else if (n < 1024 * 1024 * 1024)
snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "MiB" ANSI_RESET, (double)n * (1.0 / 1024 / 1024));
else if (n < 1024ull * 1024 * 1024 * 1024)
snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "GiB" ANSI_RESET, (double)n * (1.0 / 1024 / 1024 / 1024));
else
snprintf(buf, bufsize, "%.2f " ANSI_FG_CYAN "TiB" ANSI_RESET, (double)n * (1.0 / 1024 / 1024 / 1024 / 1024));
return buf;
}
static size_t PrintTime(char *buf, size_t bufsize, uint64 n) {
size_t pos = 0;
uint64 years = n / (365 * 24 * 60 * 60);
uint32 n32 = n % (365 * 24 * 60 * 60);
if (years)
pos += snprintf(buf + pos, bufsize - pos, "%llu " ANSI_FG_CYAN "year%s" ANSI_RESET ", ", (unsigned long long)years, (years == 1) ? "" : "s");
uint32 days = n32 / (24 * 60 * 60);
n32 %= (24 * 60 * 60);
if (days)
pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "day%s" ANSI_RESET ", ", days, (days == 1) ? "" : "s");
uint32 hours = n32 / (60 * 60);
n32 %= (60 * 60);
if (hours)
pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "hour%s" ANSI_RESET ", ", hours, (hours == 1) ? "" : "s");
uint32 minutes = n32 / 60;
if (minutes)
pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "minute%s" ANSI_RESET ", ", minutes, (minutes == 1) ? "" : "s");
uint32 seconds = n32 % 60;
if (seconds)
pos += snprintf(buf + pos, bufsize - pos, "%u " ANSI_FG_CYAN "second%s" ANSI_RESET ", ", seconds, (seconds == 1) ? "" : "s");
if (pos)
buf[pos -= 2] = '\0';
return pos;
}
static char *PrintHandshake(char *buf, size_t bufsize, uint64 secs) {
time_t now = time(NULL);
if (now == secs) {
snprintf(buf, bufsize, "Now");
} else if (now < (int64)secs) {
snprintf(buf, bufsize, ANSI_FG_RED "System clock going backwards" ANSI_RESET);
} else {
size_t pos = PrintTime(buf, bufsize - 4, now - secs);
memcpy(buf + pos, " ago", 5);
}
return buf;
}
static void AppendIpToString(const char *value, std::string *result) {
if (!result->empty())
(*result) += ", ";
const char *slash = strchr(value, '/');
if (slash) {
result->append(value, slash - value);
result->append(ANSI_FG_CYAN "/" ANSI_RESET);
result->append(slash + 1);
} else {
result->append(value);
}
}
static int ShowUserFriendlyForDevice(char *devname) {
std::string reply;
std::vector<std::pair<char*, char*>> kv;
std::string ips;
if (!CommunicateWithService(devname, "get=1\n\n", &reply))
return 1;
if (!ParseConfigKeyValue(&reply[0], &kv)) {
getout_fail:
fprintf(stderr, "Unable to parse response");
return 1;
}
size_t i = 0;
char base64key[WG_PUBLIC_KEY_LEN_BASE64 + 1];
char base64psk[WG_PUBLIC_KEY_LEN_BASE64 + 1];
int listen_port = 0;
base64key[0] = 0;
// Parse all interface level keys
for (; i < kv.size(); i++) {
char *key = kv[i].first, *value = kv[i].second;
if (strcmp(key, "private_key") == 0) {
uint8 binkey[32];
if (!ParseHexString(value, binkey, sizeof(binkey)))
goto getout_fail;
if (!IsOnlyZeros(binkey, 32)) {
curve25519_donna(binkey, binkey, kCurve25519Basepoint);
base64_encode(binkey, sizeof(binkey), base64key, sizeof(base64key), NULL);
}
} else if (strcmp(key, "address") == 0) {
AppendIpToString(value, &ips);
} else if (strcmp(key, "listen_port") == 0) {
listen_port = atoi(value);
} else if (strcmp(key, "public_key") == 0) {
break;
}
}
const char *interfacename = (devname[0] == '{') ? GetInterfaceNameFromGuid(devname) : devname;
printf(ANSI_RESET ANSI_FG_GREEN ANSI_BOLD "interface" ANSI_RESET ": " ANSI_FG_GREEN "%s" ANSI_RESET "\n",
interfacename);
if (base64key[0]) {
printf(" " ANSI_BOLD "public key" ANSI_RESET ": %s\n"
" " ANSI_BOLD "private key" ANSI_RESET ": (hidden)\n", base64key);
}
if (listen_port)
printf(" " ANSI_BOLD "listening port" ANSI_RESET ": %d\n", listen_port);
if (ips.size())
printf(" " ANSI_BOLD "address" ANSI_RESET ": %s\n", ips.c_str());
const char *endpoint = NULL;
uint64 rx_bytes, tx_bytes, last_handshake_time_sec;
int persistent_keepalive;
char text[256];
bool clear_state = true;
// Parse peer level keys
for (; i < kv.size(); i++) {
char *key = kv[i].first, *value = kv[i].second;
if (clear_state) {
base64key[0] = base64psk[0] = 0;
endpoint = NULL;
ips.clear();
persistent_keepalive = 0;
last_handshake_time_sec = tx_bytes = rx_bytes = 0;
clear_state = false;
}
if (strcmp(key, "public_key") == 0) {
if (!ParseHexKeyToBase64(value, base64key))
goto getout_fail;
} else if (strcmp(key, "preshared_key") == 0) {
if (!ParseHexKeyToBase64(value, base64psk))
goto getout_fail;
} else if (strcmp(key, "tx_bytes") == 0) {
tx_bytes = strtoull(value, NULL, 0);
} else if (strcmp(key, "rx_bytes") == 0) {
rx_bytes = strtoull(value, NULL, 0);
} else if (strcmp(key, "allowed_ip") == 0) {
AppendIpToString(value, &ips);
} else if (strcmp(key, "persistent_keepalive_interval") == 0) {
persistent_keepalive = atoi(value);
} else if (strcmp(key, "endpoint") == 0) {
endpoint = value;
} else if (strcmp(key, "last_handshake_time_sec") == 0) {
last_handshake_time_sec = strtoull(value, NULL, 0);
}
if (i == kv.size() - 1 || strcmp(kv[i + 1].first, "public_key") == 0) {
if (!base64key[0])
goto getout_fail;
printf("\n" ANSI_FG_YELLOW ANSI_BOLD "peer" ANSI_RESET ": " ANSI_FG_YELLOW "%s" ANSI_RESET "\n", base64key);
if (base64psk[0])
printf(" " ANSI_BOLD "preshared key" ANSI_RESET ": (hidden)\n");
if (endpoint)
printf(" " ANSI_BOLD "endpoint" ANSI_RESET ": %s\n", endpoint);
printf(" " ANSI_BOLD "allowed ips" ANSI_RESET ": %s\n", ips.size() ? ips.c_str() : "(none)");
if (last_handshake_time_sec)
printf(" " ANSI_BOLD "latest handshake" ANSI_RESET ": %s\n", PrintHandshake(text, sizeof(text), last_handshake_time_sec));
if (tx_bytes | rx_bytes) {
printf(" " ANSI_BOLD "transfer" ANSI_RESET ": %s received, ", FormatTransferPart(text, sizeof(text), rx_bytes));
printf("%s sent\n", FormatTransferPart(text, sizeof(text), tx_bytes));
}
if (persistent_keepalive) {
PrintTime(text, sizeof(text), persistent_keepalive);
printf(" " ANSI_BOLD "persistent keepalive" ANSI_RESET ": every %s\n", text);
}
clear_state = true;
}
}
return 0;
}
static int HandleShowCommand(int argc, char **argv) {
if (argc != 0 && strcmp(argv[0], "--help") == 0) {
fprintf(stderr, "Usage: ts show { <interface> | all | interfaces }\n");
return 0;
}
std::vector<char*> interfaces;
std::string interfaces_str;
if (argc == 0 || strcmp(argv[0], "all") == 0) {
if (!GetInterfaceList(&interfaces_str))
return 1;
SplitString(&interfaces_str[0], '\n', &interfaces);
bool want_newline = false;
for (char *interfac : interfaces) {
if (want_newline)
printf("\n");
want_newline = true;
if (ShowUserFriendlyForDevice(interfac))
return 1;
}
} else if (strcmp(argv[0], "interfaces") == 0) {
if (!GetInterfaceList(&interfaces_str))
return 1;
SplitString(&interfaces_str[0], '\n', &interfaces);
for (char *interfac : interfaces) {
const char *name = GetInterfaceNameFromGuid(interfac);
if (name)
printf("%s\n", name);
}
} else {
return ShowUserFriendlyForDevice(argv[0]);
}
return 0;
}
static void AppendCommand(std::string *result, const char *tag, const char *value) {
result->append(tag);
result->append("=");
result->append(value);
result->append("\n");
}
static bool ConvertBase64KeyToHex(const char *s, char key[65]) {
uint8 tmp[32];
size_t size = 32;
if (!base64_decode((uint8*)s, strlen(s), tmp, &size) || size != 32)
return false;
PrintHexString(tmp, 32, key);
return true;
}
static int HandleSetCommand(int argc, char **argv) {
std::string command, reply;
std::vector<char*> ss;
char hexkey[65];
if (argc == 0) {
fprintf(stderr, "Usage: ts set <interface> [address <address>] [listen-port <port>] [private-key <file path>] "
"[peer <base64 public key> [remove] [preshared-key <file path>] [endpoint <ip>:<port>] "
"[persistent-keepalive <interval seconds>] [allowed-ips <ip1>/<cidr1>[,<ip2>/<cidr2>]] ]");
return 1;
}
char **argv_end = argv + argc;
const char *interfc = *argv++;
command = "set=1\n";
bool in_interface_section = true;
bool in_peer_section = false;
bool did_clear_allowed_ips = false;
while (argv != argv_end) {
const char *key = *argv++;
if (argv != argv_end) {
if (in_interface_section) {
if (strcmp(key, "listen-port") == 0) {
AppendCommand(&command, "listen_port", *argv++);
continue;
} else if (strcmp(key, "address") == 0) {
AppendCommand(&command, "address", *argv++);
continue;
} else if (strcmp(key, "private-key") == 0) {
if (!ConvertBase64KeyToHex(*argv++, hexkey))
goto invalid_key_format;
AppendCommand(&command, "private_key", hexkey);
continue;
}
}
if (strcmp(key, "peer") == 0) {
in_interface_section = false;
in_peer_section = true;
did_clear_allowed_ips = false;
if (!ConvertBase64KeyToHex(*argv++, hexkey))
goto invalid_key_format;
AppendCommand(&command, "public_key", hexkey);
continue;
}
if (in_peer_section) {
if (strcmp(key, "preshared-key") == 0) {
if (!ConvertBase64KeyToHex(*argv++, hexkey))
goto invalid_key_format;
AppendCommand(&command, "preshared_key", hexkey);
continue;
} else if (strcmp(key, "endpoint") == 0) {
AppendCommand(&command, "endpoint", *argv++);
continue;
} else if (strcmp(key, "persistent-keepalive") == 0) {
AppendCommand(&command, "persistent_keepalive_interval", *argv++);
continue;
} else if (strcmp(key, "allowed-ips") == 0) {
if (!did_clear_allowed_ips) {
AppendCommand(&command, "replace_allowed_ips", "true");
did_clear_allowed_ips = true;
}
SplitString(*argv++, ',', &ss);
for (char *x : ss)
AppendCommand(&command, "allowed_ip", x);
continue;
}
}
}
if (in_peer_section) {
if (strcmp(key, "remove") == 0) {
in_peer_section = false;
AppendCommand(&command, "remove", "true");
continue;
}
}
fprintf(stderr, "Invalid argument: %s\n", key);
return 1;
invalid_key_format:
fprintf(stderr, "Key is not in the correct format: '%s'\n", argv[-1]);
return 1;
}
command.append("\n");
if (!CommunicateWithService(interfc, command, &reply))
return 1;
return 0;
}
#if defined(OS_WIN)
static int HandleLogCommand() {
HANDLE pipe = ConnectToService(NULL, true);
int message_code;
std::string reply;
while (pipe != NULL && ReadMessageFromService(pipe, &message_code, &reply) && message_code == TS_SERVICE_MSG_LOGLINE)
printf("%s\n", reply.c_str());
CloseHandle(pipe);
return 0;
}
static int HandleStartCommand(int argc, char **argv) {
if (argc < 1 || argc > 2 || strcmp(argv[0], "--help") == 0) {
fprintf(stderr, "Usage: " EXENAME " start <interface> [<filename>]\n");
return 1;
}
const char *devname = argv[0];
HANDLE pipe = ConnectToService(devname, false, true);
int message_code;
std::string reply;
const char *path = (argc == 1) ? "" : argv[1];
// Tell the server to startup a new interface
if (pipe == NULL ||
!SendMessageToService(pipe, TS_SERVICE_REQ_START, path, strlen(path) + 1) ||
!ReadMessageFromService(pipe, &message_code, &reply))
return 1;
if (message_code == TS_SERVICE_MSG_ERROR_REPLY) {
fprintf(stderr, "%s\n", reply.c_str());
return 1;
}
return 0;
}
static int HandleStopCommand(int argc, char **argv) {
if (argc != 1) {
fprintf(stderr, "Usage: " EXENAME " stop <interface>\n");
return 1;
}
const char *devname = argv[0];
HANDLE pipe = ConnectToService(devname, false);
// Tell the server to stop the interface
if (pipe == NULL ||
!SendMessageToService(pipe, TS_SERVICE_REQ_STOP, NULL, 0))
return 1;
return 0;
}
#endif // defined(OS_WIN)
struct CommandLineOutput {
const char *filename_to_load;
const char *interface_name;
bool daemon;
};
// Returns -1 on invalid subcommand
int HandleCommandLine(int argc, char **argv, CommandLineOutput *output) {
uint8 key[32];
char base64buf[WG_PUBLIC_KEY_LEN_BASE64 + 1];
if (argc == 1) {
ShowHelp();
return 1;
}
const char *subcommand = argv[1];
argv += 2;
argc -= 2;
if (!strcmp(subcommand, "show")) {
return HandleShowCommand(argc, argv);
} else if (!strcmp(subcommand, "set")) {
return HandleSetCommand(argc, argv);
#if defined(OS_WIN)
} else if (!strcmp(subcommand, "log")) {
if (argc != 0) {
fprintf(stderr, "Usage: " EXENAME " log\n");
return 1;
}
return HandleLogCommand();
} else if (!strcmp(subcommand, "start")) {
return HandleStartCommand(argc, argv);
#else
} else if (!strcmp(subcommand, "start") && output) {
if (argc != 0 && !strcmp(argv[0], "--help")) {
start_usage:
fprintf(stderr, "Usage: " EXENAME " start [-d/--daemon] [-n <interface-name>] [<filename>]\n");
return 0;
}
for (; argc; argc--, argv++) {
char *arg = argv[0];
if (strcmp(arg, "-d") == 0 || strcmp(arg, "--daemon") == 0) {
output->daemon = true;
continue;
}
if (strcmp(arg, "-n") == 0) {
if (argc < 2) goto start_usage;
output->interface_name = argv[1];
argc--,argv++;
continue;
}
break;
}
if (argc > 1) goto start_usage;
output->filename_to_load = (argc == 0) ? "" : argv[0];
return 0;
#endif // defined(OS_WIN)
} else if (!strcmp(subcommand, "stop")) {
return HandleStopCommand(argc, argv);
} else if(!strcmp(subcommand, "genkey")) {
if (argc != 0) {
fprintf(stderr, "Usage: " EXENAME " genkey\n");
return 1;
}
OsGetRandomBytes(key, 32);
curve25519_normalize(key);
printf("%s\n", base64_encode(key, 32, base64buf, sizeof(base64buf), NULL));
} else if (!strcmp(subcommand, "genpsk")) {
if (argc != 0) {
fprintf(stderr, "Usage: " EXENAME " genpsk\n");
return 1;
}
OsGetRandomBytes(key, 32);
printf("%s\n", base64_encode(key, 32, base64buf, sizeof(base64buf), NULL));
} else if (!strcmp(subcommand, "pubkey")) {
char base64[WG_PUBLIC_KEY_LEN_BASE64 + 2];
size_t n = fread(base64, 1, sizeof(base64), stdin);
if (n < sizeof(base64) - 2 || n >= sizeof(base64) ||
(n == sizeof(base64) - 1 && (base64[WG_PUBLIC_KEY_LEN_BASE64] != ' ' && base64[WG_PUBLIC_KEY_LEN_BASE64] != '\n'))) {
fprintf(stderr, EXENAME ": Incorrect key format\n");
return 1;
}
size_t size = 32;
if (!base64_decode((uint8*)base64, n, key, &size) || size != 32) {
fprintf(stderr, EXENAME ": Incorrect key format\n");
return 1;
}
curve25519_donna(key, key, kCurve25519Basepoint);
printf("%s\n", base64_encode(key, 32, base64buf, sizeof(base64buf), NULL));
} else if (!strcmp(subcommand, "--help")) {
ShowHelp();
} else if (!strcmp(subcommand, "--version")) {
printf("%s\n", TUNSAFE_VERSION_STRING);
} else {
if (argc == 0) {
if (output)
output->filename_to_load = subcommand;
} else {
ShowHelp();
}
return -1;
}
return 0;
}
#if defined(OS_WIN)
// This is integrated into the main tunsafe binary on posix systems
int main(int argc, char **argv) {
int rv = HandleCommandLine(argc, argv, NULL);
if (rv == -1) {
fprintf(stderr, "Invalid subcommand '%s'\n", argv[1]);
ShowHelp();
return 1;
}
return rv;
}
#endif // defined(OS_WIN)

199
ts.vcxproj Normal file
View file

@ -0,0 +1,199 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" ToolsVersion="15.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Label="ProjectConfigurations">
<ProjectConfiguration Include="Debug|Win32">
<Configuration>Debug</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|Win32">
<Configuration>Release</Configuration>
<Platform>Win32</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Debug|x64">
<Configuration>Debug</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
<ProjectConfiguration Include="Release|x64">
<Configuration>Release</Configuration>
<Platform>x64</Platform>
</ProjectConfiguration>
</ItemGroup>
<PropertyGroup Label="Globals">
<VCProjectVersion>15.0</VCProjectVersion>
<ProjectGuid>{443E105E-8D7C-401F-BD41-D3F56C76104B}</ProjectGuid>
<Keyword>Win32Proj</Keyword>
<RootNamespace>ts</RootNamespace>
<WindowsTargetPlatformVersion>10.0.17134.0</WindowsTargetPlatformVersion>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v141</PlatformToolset>
<CharacterSet>MultiByte</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v141</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>MultiByte</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
<PlatformToolset>v141</PlatformToolset>
<CharacterSet>MultiByte</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
<PlatformToolset>v141</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>MultiByte</CharacterSet>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
<Import Project="crypto\nasm.props" />
</ImportGroup>
<ImportGroup Label="Shared">
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental>
<OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
<IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<LinkIncremental>true</LinkIncremental>
<OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
<IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<LinkIncremental>false</LinkIncremental>
<OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
<IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
<OutDir>$(SolutionDir)build\$(Platform)_$(Configuration)\</OutDir>
<IntDir>$(SolutionDir)build\$(Platform)_$(Configuration)\obj\$(ProjectName)\</IntDir>
</PropertyGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
<PrecompiledHeader>Use</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<ConformanceMode>true</ConformanceMode>
<ExceptionHandling>false</ExceptionHandling>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<ClCompile>
<PrecompiledHeader>Use</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>Disabled</Optimization>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<ConformanceMode>true</ConformanceMode>
<ExceptionHandling>false</ExceptionHandling>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
<ClCompile>
<PrecompiledHeader>Use</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>MinSpace</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<ConformanceMode>true</ConformanceMode>
<RuntimeLibrary>MultiThreaded</RuntimeLibrary>
<OmitFramePointers>true</OmitFramePointers>
<ExceptionHandling>false</ExceptionHandling>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
<PrecompiledHeader>Use</PrecompiledHeader>
<WarningLevel>Level3</WarningLevel>
<Optimization>MinSpace</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
<SDLCheck>true</SDLCheck>
<PreprocessorDefinitions>NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<ConformanceMode>true</ConformanceMode>
<RuntimeLibrary>MultiThreaded</RuntimeLibrary>
<OmitFramePointers>true</OmitFramePointers>
<ExceptionHandling>false</ExceptionHandling>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="crypto\curve25519-donna.h" />
<ClInclude Include="stdafx.h" />
<ClInclude Include="util.h" />
<ClInclude Include="util_win32.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="crypto\curve25519-donna.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="stdafx.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">Create</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">Create</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">Create</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|x64'">Create</PrecompiledHeader>
</ClCompile>
<ClCompile Include="ts.cpp" />
<ClCompile Include="util.cpp" />
<ClCompile Include="util_win32.cpp" />
</ItemGroup>
<ItemGroup>
<NASM Include="crypto\curve25519_x64_nasm.asm">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">true</ExcludedFromBuild>
</NASM>
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
<Import Project="crypto\nasm.targets" />
</ImportGroup>
</Project>

53
ts.vcxproj.filters Normal file
View file

@ -0,0 +1,53 @@
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup>
<Filter Include="Source Files">
<UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
<Extensions>cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
</Filter>
<Filter Include="Header Files">
<UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
<Extensions>h;hh;hpp;hxx;hm;inl;inc;ipp;xsd</Extensions>
</Filter>
<Filter Include="Resource Files">
<UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
<Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions>
</Filter>
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="util.h">
<Filter>Source Files</Filter>
</ClInclude>
<ClInclude Include="crypto\curve25519-donna.h">
<Filter>Source Files</Filter>
</ClInclude>
<ClInclude Include="util_win32.h">
<Filter>Source Files</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="stdafx.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="ts.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="util.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="crypto\curve25519-donna.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="util_win32.cpp">
<Filter>Source Files</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<NASM Include="crypto\curve25519_x64_nasm.asm">
<Filter>Source Files</Filter>
</NASM>
</ItemGroup>
</Project>

View file

@ -10,10 +10,14 @@ MultithreadedDelayedDelete::MultithreadedDelayedDelete() {
} }
MultithreadedDelayedDelete::~MultithreadedDelayedDelete() { MultithreadedDelayedDelete::~MultithreadedDelayedDelete() {
assert(curr_.size() == 0);
assert(next_.size() == 0);
assert(to_delete_.size() == 0);
free(table_); free(table_);
} }
void MultithreadedDelayedDelete::Initialize(uint32 num_threads) { void MultithreadedDelayedDelete::Configure(uint32 num_threads) {
assert(table_ == NULL);
num_threads_ = num_threads; num_threads_ = num_threads;
table_ = (CheckpointData*)calloc(sizeof(CheckpointData), num_threads); table_ = (CheckpointData*)calloc(sizeof(CheckpointData), num_threads);
} }

View file

@ -150,12 +150,14 @@ public:
typedef void DoDeleteFunc(void *x); typedef void DoDeleteFunc(void *x);
void Add(DoDeleteFunc *func, void *param); void Add(DoDeleteFunc *func, void *param);
void Initialize(uint32 num_threads); void Configure(uint32 num_threads);
void Checkpoint(uint32 thread_id); void Checkpoint(uint32 thread_id);
void MainCheckpoint(); void MainCheckpoint();
bool enabled() const { return num_threads_ != 0; }
private: private:
struct Entry { struct Entry {
DoDeleteFunc *func; DoDeleteFunc *func;

View file

@ -107,22 +107,6 @@ static void SetUiVisibility(bool visible) {
UpdateGraphReq(); UpdateGraphReq();
} }
static bool GetConfigFullName(const char *basename, char *fullname, size_t fullname_size) {
size_t len = strlen(basename);
if (FindFilenameComponent(basename)[0]) {
if (len >= fullname_size)
return false;
memcpy(fullname, basename, len + 1);
return true;
}
size_t clen = GetConfigPath(fullname, fullname_size);
if (clen == 0 || clen + len >= fullname_size)
return false;
memcpy(fullname + clen, basename, (len + 1) * sizeof(fullname[0]));
return true;
}
void StopTunsafeBackend(UpdateIconWhy why) { void StopTunsafeBackend(UpdateIconWhy why) {
if (g_backend->is_started()) { if (g_backend->is_started()) {
g_backend->Stop(); g_backend->Stop();
@ -154,6 +138,7 @@ void StartTunsafeBackend(UpdateIconWhy reason) {
} }
g_notified_connected_server = false; g_notified_connected_server = false;
g_is_connected_to_server = false; g_is_connected_to_server = false;
memset(&g_processor_stats, 0, sizeof(g_processor_stats));
g_backend->Start(g_current_filename); g_backend->Start(g_current_filename);
RegWriteInt(g_reg_key, "IsConnected", 1); RegWriteInt(g_reg_key, "IsConnected", 1);
} }
@ -189,12 +174,21 @@ public:
} }
virtual void OnLogLine(const char **s) { virtual void OnLogLine(const char **s) {
const char *line = *s;
size_t len = strlen(line);
char *tmp = (char*)alloca(len + 3);
tmp[len + 0] = '\r';
tmp[len + 1] = '\n';
tmp[len + 2] = 0;
memcpy(tmp, line, len);
CHARRANGE cr; CHARRANGE cr;
cr.cpMin = -1; cr.cpMin = -1;
cr.cpMax = -1; cr.cpMax = -1;
// hwnd = rich edit hwnd // hwnd = rich edit hwnd
SendMessage(hwndEdit, EM_EXSETSEL, 0, (LPARAM)&cr); SendMessage(hwndEdit, EM_EXSETSEL, 0, (LPARAM)&cr);
SendMessage(hwndEdit, EM_REPLACESEL, 0, (LPARAM)*s); SendMessage(hwndEdit, EM_REPLACESEL, 0, (LPARAM)tmp);
} }
virtual void OnStateChanged() { virtual void OnStateChanged() {
@ -204,13 +198,13 @@ public:
const char *filename = g_cmdline_filename; const char *filename = g_cmdline_filename;
if (filename) { if (filename) {
if (GetConfigFullName(filename, fullname, sizeof(fullname))) if (ExpandConfigPath(filename, fullname, sizeof(fullname)))
SetCurrentConfigFilename(fullname); SetCurrentConfigFilename(fullname);
} else { } else {
std::string currconfig = g_backend->GetConfigFileName(); std::string currconfig = g_backend->GetConfigFileName();
if (currconfig.empty()) { if (currconfig.empty()) {
char *conf = RegReadStr(g_reg_key, "ConfigFile", "TunSafe.conf"); char *conf = RegReadStr(g_reg_key, "ConfigFile", "TunSafe.conf");
if (GetConfigFullName(conf, fullname, sizeof(fullname))) if (ExpandConfigPath(conf, fullname, sizeof(fullname)))
SetCurrentConfigFilename(fullname); SetCurrentConfigFilename(fullname);
free(conf); free(conf);
} else { } else {
@ -233,10 +227,12 @@ public:
} }
virtual void OnStatusCode(TunsafeBackend::StatusCode status) override { virtual void OnStatusCode(TunsafeBackend::StatusCode status) override {
if (status != g_status_code)
InvalidatePaintbox();
g_status_code = status; g_status_code = status;
if (TunsafeBackend::IsPermanentError(status)) { if (TunsafeBackend::IsPermanentError(status)) {
UpdateIcon(g_is_connected_to_server ? UIW_STOPPED_WORKING_FAIL : UIW_NONE); UpdateIcon(g_is_connected_to_server ? UIW_STOPPED_WORKING_FAIL : UIW_NONE);
InvalidatePaintbox();
return; return;
} }
bool is_connected = (status == TunsafeBackend::kStatusConnected); bool is_connected = (status == TunsafeBackend::kStatusConnected);
@ -254,13 +250,15 @@ public:
if (is_connected > not_first && (g_startup_flags & kStartupFlag_BackgroundService)) if (is_connected > not_first && (g_startup_flags & kStartupFlag_BackgroundService))
g_notified_connected_server = true; g_notified_connected_server = true;
UpdateIcon(UIW_NONE); UpdateIcon(UIW_NONE);
InvalidatePaintbox();
} }
} }
virtual void OnClearLog() override { virtual void OnClearLog() override {
SetWindowText(hwndEdit, ""); SetWindowText(hwndEdit, "");
} }
virtual void OnConfigurationProtocolReply(uint32 ident, const std::string &&reply) override {
}
}; };
static MyBackendDelegate my_procdel; static MyBackendDelegate my_procdel;
@ -411,7 +409,7 @@ public:
ConfigMenuBuilder::ConfigMenuBuilder() ConfigMenuBuilder::ConfigMenuBuilder()
: nfiles_(0), depth_(0) { : nfiles_(0), depth_(0) {
if (!GetConfigFullName("", buf_, sizeof(buf_))) if (!ExpandConfigPath("", buf_, sizeof(buf_)))
bufpos_ = sizeof(buf_); bufpos_ = sizeof(buf_);
else else
bufpos_ = strlen(buf_); bufpos_ = strlen(buf_);
@ -556,7 +554,7 @@ static void OpenEditor() {
static void BrowseFiles() { static void BrowseFiles() {
char buf[MAX_PATH]; char buf[MAX_PATH];
if (GetConfigFullName("", buf, ARRAYSIZE(buf))) { if (ExpandConfigPath("", buf, ARRAYSIZE(buf))) {
size_t l = strlen(buf); size_t l = strlen(buf);
buf[l - 1] = 0; buf[l - 1] = 0;
ShellExecuteFromExplorer(buf, NULL, NULL, "explore"); ShellExecuteFromExplorer(buf, NULL, NULL, "explore");
@ -572,7 +570,7 @@ bool ImportFile(const char *s, bool silent = false) {
bool rv = false; bool rv = false;
int filerv; int filerv;
if (!*last || !GetConfigFullName(last, buf, ARRAYSIZE(buf)) || _stricmp(buf, s) == 0) if (!*last || !ExpandConfigPath(last, buf, ARRAYSIZE(buf)) || _stricmp(buf, s) == 0)
goto out; goto out;
filedata = LoadFileSane(s, &filesize); filedata = LoadFileSane(s, &filesize);
@ -657,9 +655,8 @@ void BrowseFile(HWND wnd) {
static const uint8 kCurve25519Basepoint[32] = {9}; static const uint8 kCurve25519Basepoint[32] = {9};
static void SetKeyBox(HWND wnd, int ctr, uint8 buf[32]) { static void SetKeyBox(HWND wnd, int ctr, uint8 buf[32]) {
uint8 *privs = base64_encode(buf, 32, NULL); char base64[WG_PUBLIC_KEY_LEN_BASE64 + 1];
SetDlgItemText(wnd, ctr, (char*)privs); SetDlgItemText(wnd, ctr, base64_encode(buf, 32, base64, sizeof(base64), NULL));
free(privs);
} }
static INT_PTR WINAPI KeyPairDlgProc(HWND hWnd, UINT message, WPARAM wParam, static INT_PTR WINAPI KeyPairDlgProc(HWND hWnd, UINT message, WPARAM wParam,
@ -1075,20 +1072,18 @@ void PushLine(const char *s) {
snprintf(buf, sizeof(buf), "[%.2d:%.2d:%.2d] ", t.wHour, t.wMinute, t.wSecond); snprintf(buf, sizeof(buf), "[%.2d:%.2d:%.2d] ", t.wHour, t.wMinute, t.wSecond);
size_t tl = strlen(buf); size_t tl = strlen(buf);
char *x = (char*)malloc(tl + l + 3); char *x = (char*)malloc(tl + l + 1);
if (!x) return; if (!x) return;
memcpy(x, buf, tl); memcpy(x, buf, tl);
memcpy(x + tl, s, l); memcpy(x + tl, s, l);
x[l + tl] = '\r'; x[l + tl] = '\0';
x[l + tl + 1] = '\n';
x[l + tl + 2] = '\0';
g_backend_delegate->OnLogLine((const char**)&x); g_backend_delegate->OnLogLine((const char**)&x);
free(x); free(x);
} }
void EnsureConfigDirCreated() { void EnsureConfigDirCreated() {
char fullname[1024]; char fullname[1024];
if (GetConfigFullName("", fullname, sizeof(fullname))) if (ExpandConfigPath("", fullname, sizeof(fullname)))
CreateDirectory(fullname, NULL); CreateDirectory(fullname, NULL);
} }
@ -1358,7 +1353,7 @@ static void DrawGraph(HDC dc, const RECT *rr, StatsCollector::TimeSeries **sourc
for (size_t j = 0; j != num_source; j++) { for (size_t j = 0; j != num_source; j++) {
const StatsCollector::TimeSeries *src = sources[j]; const StatsCollector::TimeSeries *src = sources[j];
for (size_t i = 0; i != src->size; i++) for (size_t i = 0; i != src->size; i++)
mx = max(mx, src->data[i]); mx = std::max(mx, src->data[i]);
} }
int topval = (int)(mx + 0.5f); int topval = (int)(mx + 0.5f);
// round it appropriately // round it appropriately
@ -1432,11 +1427,13 @@ static void DrawInGraphBox(HDC hdc, int w, int h) {
for (int i = 0; i < graph->num_charts; i++) { for (int i = 0; i < graph->num_charts; i++) {
time_series_ptr[i] = &time_series[i]; time_series_ptr[i] = &time_series[i];
time_series[i].shift = 0; time_series[i].shift = 0;
time_series[i].size = *(uint32*)ptr;
uint32 size = *(uint32*)ptr;
time_series[i].size = size;
time_series[i].data = (float*)(ptr + 4); time_series[i].data = (float*)(ptr + 4);
ptr += 4 + *(uint32*)ptr * 4; if ((ptr - (uint8*)graph) + 4 + (uint64)size * 4 > graph->total_size)
if (ptr - (uint8*)graph > graph->total_size)
break; break;
ptr += 4 + size * 4;
} }
num_charts = graph->num_charts; num_charts = graph->num_charts;
} }
@ -1517,9 +1514,7 @@ static const char *GetAdvancedInfoValue(char buffer[256], int i) {
case 0: { case 0: {
if (IsOnlyZeros(g_backend->public_key(), 32)) if (IsOnlyZeros(g_backend->public_key(), 32))
return ""; return "";
char *str = (char*)base64_encode(g_backend->public_key(), 32, NULL); base64_encode(g_backend->public_key(), 32, buffer, 256, NULL);
snprintf(buffer, 256, "%s", str);
free(str);
return buffer; return buffer;
} }
case 1: { case 1: {
@ -1764,7 +1759,7 @@ int WINAPI WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance, LPSTR lpCmdLine
// Check if the app is already running. // Check if the app is already running.
g_runonce_mutex = CreateMutexA(0, FALSE, "TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90"); g_runonce_mutex = CreateMutexA(0, FALSE, "TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90");
if (GetLastError() == ERROR_ALREADY_EXISTS) { if (GetLastError() == ERROR_ALREADY_EXISTS&&0) {
HWND window = FindWindow("TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90", NULL); HWND window = FindWindow("TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90", NULL);
DWORD_PTR result; DWORD_PTR result;
if (!window || !SendMessageTimeout(window, WM_USER + 10, 0, 0, SMTO_BLOCK, 3000, &result) || result != 31337) { if (!window || !SendMessageTimeout(window, WM_USER + 10, 0, 0, SMTO_BLOCK, 3000, &result) || result != 31337) {

169
util.cpp
View file

@ -17,46 +17,55 @@
#include <arpa/inet.h> #include <arpa/inet.h>
#endif #endif
#include <vector>
#include <algorithm> #include <algorithm>
#include "tunsafe_types.h" #include "tunsafe_types.h"
static char base64_alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; static const char kBase64Alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
uint8 *base64_encode(const uint8 *input, size_t length, size_t *out_length) { char *base64_encode(const uint8 *input, size_t length, char *output, size_t output_size, size_t *out_length) {
uint32 a; char *result, *r;
size_t size;
uint8 *result, *r;
const uint8 *end; const uint8 *end;
size = length * 4 / 3 + 4 + 1; size_t size = (length + 2) / 3 * 4 + 1;
r = result = (byte*)malloc(size);
if (output != NULL) {
result = output;
assert(output_size >= size);
if (output_size < size) {
*result = 0;
return NULL;
}
} else {
result = (char*)malloc(size);
if (!result)
return NULL;
}
r = result;
end = input + length - 3; end = input + length - 3;
// Encode full blocks // Encode full blocks
while (input <= end) { while (input <= end) {
a = (input[0] << 16) + (input[1] << 8) + input[2]; uint32 a = (input[0] << 16) + (input[1] << 8) + input[2];
input += 3; input += 3;
r[0] = base64_alphabet[(a >> 18)/* & 0x3F*/]; r[0] = kBase64Alphabet[(a >> 18)/* & 0x3F*/];
r[1] = base64_alphabet[(a >> 12) & 0x3F]; r[1] = kBase64Alphabet[(a >> 12) & 0x3F];
r[2] = base64_alphabet[(a >> 6) & 0x3F]; r[2] = kBase64Alphabet[(a >> 6) & 0x3F];
r[3] = base64_alphabet[(a) & 0x3F]; r[3] = kBase64Alphabet[(a) & 0x3F];
r += 4; r += 4;
} }
if (input == end + 2) { if (input == end + 2) {
a = input[0] << 4; uint32 a = input[0] << 4;
r[0] = base64_alphabet[(a >> 6) /*& 0x3F*/]; r[0] = kBase64Alphabet[(a >> 6) /*& 0x3F*/];
r[1] = base64_alphabet[(a) & 0x3F]; r[1] = kBase64Alphabet[(a) & 0x3F];
r[2] = '='; r[2] = '=';
r[3] = '='; r[3] = '=';
r += 4; r += 4;
} else if (input == end + 1) { } else if (input == end + 1) {
a = (input[0] << 10) + (input[1] << 2); uint32 a = (input[0] << 10) + (input[1] << 2);
r[0] = base64_alphabet[(a >> 12) /*& 0x3F*/]; r[0] = kBase64Alphabet[(a >> 12) /*& 0x3F*/];
r[1] = base64_alphabet[(a >> 6) & 0x3F]; r[1] = kBase64Alphabet[(a >> 6) & 0x3F];
r[2] = base64_alphabet[(a) & 0x3F]; r[2] = kBase64Alphabet[(a) & 0x3F];
r[3] = '='; r[3] = '=';
r += 4; r += 4;
} }
@ -250,13 +259,6 @@ void RERROR(const char *msg, ...) {
} }
} }
void rinfo(const char *msg, ...) {
printf("muu");
}
void rinfo2(const char *msg) {
printf("muu2");
}
void RINFO(const char *msg, ...) { void RINFO(const char *msg, ...) {
va_list va; va_list va;
@ -297,3 +299,114 @@ size_t my_strlcpy(char *dst, size_t dstsize, const char *src) {
} }
return len; return len;
} }
void OsGetRandomBytes(uint8 *data, size_t data_size) {
#if defined(OS_WIN)
static BOOLEAN(APIENTRY *pfn)(void*, ULONG);
if (!pfn) {
pfn = (BOOLEAN(APIENTRY *)(void*, ULONG))GetProcAddress(LoadLibrary("ADVAPI32.DLL"), "SystemFunction036");
if (!pfn)
ExitProcess(1);
}
if (!pfn(data, (ULONG)data_size)) {
ExitProcess(1);
return;
}
#elif defined(OS_POSIX)
int fd = open("/dev/urandom", O_RDONLY);
if (fd < 0) {
fprintf(stderr, "/dev/urandom failed\n");
exit(1);
}
int r = read(fd, data, data_size);
if (r != data_size) {
fprintf(stderr, "/dev/urandom failed\n");
exit(1);
}
close(fd);
#else
#error
#endif
}
bool ParseConfigKeyValue(char *m, std::vector<std::pair<char *, char*>> *result) {
for (;;) {
char *nl = strchr(m, '\n');
if (nl)
*nl = 0;
if (*m != '\0') {
char *value = strchr(m, '=');
if (value == NULL)
return false;
*value++ = '\0';
result->emplace_back(m, value);
}
if (!nl)
return true;
m = nl + 1;
}
}
bool ParseHexString(const char *text, void *data, size_t data_size) {
size_t len = strlen(text);
if (len != data_size * 2)
return false;
for (size_t i = 0; i < data_size; i++) {
uint32 c = text[i * 2 + 0];
if (c >= '0' && c <= '9') {
c -= '0';
} else if ((c |= 32) >= 'a' && c <= 'f') {
c -= 'a' - 10;
} else {
return false;
}
uint32 d = text[i * 2 + 1];
if (d >= '0' && d <= '9') {
d -= '0';
} else if ((d |= 32) >= 'a' && d <= 'f') {
d -= 'a' - 10;
} else {
return false;
}
((uint8*)data)[i] = c * 16 + d;
}
return true;
}
bool is_space(uint8_t c) {
return c == ' ' || c == '\r' || c == '\n' || c == '\t';
}
void SplitString(char *s, int separator, std::vector<char*> *components) {
components->clear();
for (;;) {
while (is_space(*s)) s++;
char *d = strchr(s, separator);
if (d == NULL) {
if (*s)
components->push_back(s);
return;
}
*d = 0;
char *e = d;
while (e > s && is_space(e[-1]))
*--e = 0;
components->push_back(s);
s = d + 1;
}
}
void PrintHexString(const void *data, size_t data_size, char *result) {
for (size_t i = 0; i < data_size; i++) {
uint8 c = ((uint8*)data)[i];
*result++ = "0123456789abcdef"[c >> 4];
*result++ = "0123456789abcdef"[c & 0xF];
}
*result++ = 0;
}
bool ParseBase64Key(const char *s, uint8 key[32]) {
size_t size = 32;
return base64_decode((uint8*)s, strlen(s), key, &size) && size == 32;
}

23
util.h
View file

@ -3,7 +3,7 @@
#pragma once #pragma once
#include "tunsafe_types.h" #include "tunsafe_types.h"
uint8 *base64_encode(const uint8 *input, size_t length, size_t *out_length); char *base64_encode(const uint8 *input, size_t length, char *output, size_t output_size, size_t *actual_size);
bool base64_decode(uint8 *in, size_t inLen, uint8 *out, size_t *outLen); bool base64_decode(uint8 *in, size_t inLen, uint8 *out, size_t *outLen);
bool IsOnlyZeros(const uint8 *data, size_t data_size); bool IsOnlyZeros(const uint8 *data, size_t data_size);
@ -17,9 +17,28 @@ char *my_strndup(const char *p, size_t size);
size_t my_strlcpy(char *dst, size_t dstsize, const char *src); size_t my_strlcpy(char *dst, size_t dstsize, const char *src);
template<typename T, typename U> static inline T postinc(T&x, U v) { template<typename T, typename U> static inline T postinc(T&x, U v) {
T t = x; T t = x;
x += v; x += v;
return t; return t;
} }
template<typename T, typename U> static inline T exch(T&x, U v) {
T t = x;
x = v;
return t;
}
template<typename T> static inline T exch_null(T&x) {
T t = x;
x = NULL;
return t;
}
bool is_space(uint8_t c);
void OsGetRandomBytes(uint8 *dst, size_t dst_size);
bool ParseConfigKeyValue(char *m, std::vector<std::pair<char *, char*>> *result);
bool ParseHexString(const char *text, void *data, size_t data_size);
void PrintHexString(const void *data, size_t data_size, char *result);
void SplitString(char *s, int separator, std::vector<char*> *components);
bool ParseBase64Key(const char *s, uint8 key[32]);

View file

@ -297,6 +297,23 @@ size_t GetConfigPath(char *path, size_t path_size) {
return last + 7 - path; return last + 7 - path;
} }
bool ExpandConfigPath(const char *basename, char *fullname, size_t fullname_size) {
size_t len = strlen(basename);
if (FindFilenameComponent(basename)[0]) {
if (len >= fullname_size)
return false;
memcpy(fullname, basename, len + 1);
return true;
}
size_t clen = GetConfigPath(fullname, fullname_size);
if (clen == 0 || clen + len >= fullname_size)
return false;
memcpy(fullname + clen, basename, (len + 1) * sizeof(fullname[0]));
return true;
}
static bool ContainsDotDot(const char *path) { static bool ContainsDotDot(const char *path) {
for (uint8 last = 0, cur; (cur = path[0]) != '\0'; last = cur, path++) for (uint8 last = 0, cur; (cur = path[0]) != '\0'; last = cur, path++)
if (cur == '.' && last == cur) if (cur == '.' && last == cur)
@ -308,7 +325,7 @@ bool EnsureValidConfigPath(const char *path) {
char buf[1024]; char buf[1024];
size_t len = GetConfigPath(buf, sizeof(buf)); size_t len = GetConfigPath(buf, sizeof(buf));
return (len != 0) && (strlen(path) > len && memcmp(path, buf, len) == 0 && !ContainsDotDot(path + len)); return (len != 0) && (strlen(path) > len && _strnicmp(path, buf, len) == 0 && !ContainsDotDot(path + len));
} }
bool RunProcessAsAdminWithArgs(const char *args, bool wait_for_exit) { bool RunProcessAsAdminWithArgs(const char *args, bool wait_for_exit) {
@ -376,3 +393,65 @@ RECT MakeRect(int l, int t, int r, int b) {
RECT rr = { l, t, r, b }; RECT rr = { l, t, r, b };
return rr; return rr;
} }
// Retrieve the device path to the TAP adapter.
#define kAdapterKeyName "SYSTEM\\CurrentControlSet\\Control\\Class\\{4D36E972-E325-11CE-BFC1-08002BE10318}"
#define kNetworkConnectionsKeyName "SYSTEM\\CurrentControlSet\\Control\\Network\\{4D36E972-E325-11CE-BFC1-08002BE10318}"
#define kTapComponentId "tap0901"
void GetTapAdapterInfo(std::vector<GuidAndDevName> *result) {
LONG err;
HKEY adapter_key, device_key, network_connections_key;
bool retval = false;
GuidAndDevName gn;
err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, kAdapterKeyName, 0, KEY_READ, &adapter_key);
if (err != ERROR_SUCCESS) {
RERROR("GetTapAdapterName: RegOpenKeyEx failed: 0x%X", GetLastError());
return;
}
for (int i = 0; !retval; i++) {
char keyname[64 + sizeof(kAdapterKeyName) + 1 + 32 /* some margin */];
char value[64];
DWORD len = sizeof(value), type;
err = RegEnumKeyEx(adapter_key, i, value, &len, NULL, NULL, NULL, NULL);
if (err == ERROR_NO_MORE_ITEMS)
break;
if (err != ERROR_SUCCESS) {
RERROR("GetTapAdapterName: RegEnumKeyEx failed: 0x%X", GetLastError());
break;
}
snprintf(keyname, sizeof(keyname), "%s\\%s", kAdapterKeyName, value);
err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, keyname, 0, KEY_READ, &device_key);
if (err == ERROR_SUCCESS) {
len = sizeof(value);
err = RegQueryValueEx(device_key, "ComponentId", NULL, &type, (LPBYTE)value, &len);
if (err == ERROR_SUCCESS && type == REG_SZ && !memcmp(value, kTapComponentId, sizeof(kTapComponentId))) {
len = sizeof(gn.guid);
err = RegQueryValueEx(device_key, "NetCfgInstanceId", NULL, &type, (LPBYTE)gn.guid, &len);
if (err == ERROR_SUCCESS && type == REG_SZ) {
gn.guid[sizeof(gn.guid) - 1] = 0;
gn.name[0] = 0;
snprintf(keyname, sizeof(keyname), "%s\\%s\\Connection", kNetworkConnectionsKeyName, gn.guid);
err = RegOpenKeyEx(HKEY_LOCAL_MACHINE, keyname, 0, KEY_READ, &network_connections_key);
if (err == ERROR_SUCCESS) {
len = sizeof(gn.guid);
err = RegQueryValueEx(network_connections_key, "Name", NULL, &type, (LPBYTE)gn.name, &len);
if (err == ERROR_SUCCESS && type == REG_SZ) {
gn.name[sizeof(gn.guid) - 1] = 0;
} else {
gn.name[0] = 0;
}
RegCloseKey(network_connections_key);
}
result->push_back(gn);
}
}
RegCloseKey(device_key);
}
}
RegCloseKey(adapter_key);
}

View file

@ -1,6 +1,7 @@
// SPDX-License-Identifier: AGPL-1.0-only // SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved. // Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#include "tunsafe_types.h" #include "tunsafe_types.h"
#include <vector>
#pragma once #pragma once
const char *FindFilenameComponent(const char *s); const char *FindFilenameComponent(const char *s);
@ -47,6 +48,7 @@ void ShellExecuteFromExplorer(
int nShowCmd = SW_SHOWNORMAL); int nShowCmd = SW_SHOWNORMAL);
size_t GetConfigPath(char *path, size_t path_size); size_t GetConfigPath(char *path, size_t path_size);
bool ExpandConfigPath(const char *basename, char *fullname, size_t fullname_size);
bool EnsureValidConfigPath(const char *path); bool EnsureValidConfigPath(const char *path);
bool RunProcessAsAdminWithArgs(const char *args, bool wait_for_exit); bool RunProcessAsAdminWithArgs(const char *args, bool wait_for_exit);
@ -54,3 +56,8 @@ bool RestartProcessAsAdministrator();
bool SetClipboardString(const char *string); bool SetClipboardString(const char *string);
RECT GetParentRect(HWND wnd); RECT GetParentRect(HWND wnd);
RECT MakeRect(int l, int t, int r, int b); RECT MakeRect(int l, int t, int r, int b);
struct GuidAndDevName {
char guid[40];
char name[64];
};
void GetTapAdapterInfo(std::vector<GuidAndDevName> *result);

View file

@ -15,8 +15,7 @@
#include "ipzip2/ipzip2.h" #include "ipzip2/ipzip2.h"
#include "wireguard.h" #include "wireguard.h"
#include "wireguard_config.h" #include "wireguard_config.h"
#include "util.h"
uint64 OsGetMilliseconds();
enum { enum {
IPV4_HEADER_SIZE = 20, IPV4_HEADER_SIZE = 20,
@ -36,22 +35,23 @@ WireguardProcessor::WireguardProcessor(UdpInterface *udp, TunInterface *tun, Pro
add_routes_mode_ = true; add_routes_mode_ = true;
dns_blocking_ = true; dns_blocking_ = true;
internet_blocking_ = kBlockInternet_Default; internet_blocking_ = kBlockInternet_Default;
is_started_ = false;
stats_last_bytes_in_ = 0; stats_last_bytes_in_ = 0;
stats_last_bytes_out_ = 0; stats_last_bytes_out_ = 0;
stats_last_ts_ = OsGetMilliseconds(); stats_last_ts_ = OsGetMilliseconds();
main_thread_scheduled_ = NULL;
main_thread_scheduled_last_ = &main_thread_scheduled_;
} }
WireguardProcessor::~WireguardProcessor() { WireguardProcessor::~WireguardProcessor() {
} }
void WireguardProcessor::SetListenPort(int listen_port) { void WireguardProcessor::SetListenPort(int listen_port) {
if (listen_port_ != listen_port) {
listen_port_ = listen_port; listen_port_ = listen_port;
if (is_started_ && !ConfigureUdp()) {
RINFO("ConfigureUdp failed");
}
}
} }
void WireguardProcessor::AddDnsServer(const IpAddr &sin) { void WireguardProcessor::AddDnsServer(const IpAddr &sin) {
std::vector<IpAddr> *target = (sin.sin.sin_family == AF_INET6) ? &dns6_addr_ : &dns_addr_; std::vector<IpAddr> *target = (sin.sin.sin_family == AF_INET6) ? &dns6_addr_ : &dns_addr_;
@ -66,6 +66,11 @@ bool WireguardProcessor::SetTunAddress(const WgCidrAddr &addr) {
return true; return true;
} }
void WireguardProcessor::ClearTunAddress() {
tun_addr_.size = 0;
tun6_addr_.size = 0;
}
void WireguardProcessor::AddExcludedIp(const WgCidrAddr &cidr_addr) { void WireguardProcessor::AddExcludedIp(const WgCidrAddr &cidr_addr) {
excluded_ips_.push_back(cidr_addr); excluded_ips_.push_back(cidr_addr);
} }
@ -129,23 +134,29 @@ static bool IsWgCidrAddrSubsetOf(const WgCidrAddr &inner, const WgCidrAddr &oute
} }
bool WireguardProcessor::Start() { bool WireguardProcessor::Start() {
assert(dev_.IsMainThread()); return ConfigureUdp() && ConfigureTun();
if (!udp_->Initialize(listen_port_))
return false;
if (tun_addr_.size != 32) {
RERROR("No IPv4 address configured");
return false;
} }
bool WireguardProcessor::ConfigureUdp() {
assert(dev_.IsMainThread());
return udp_->Configure(listen_port_);
}
bool WireguardProcessor::ConfigureTun() {
assert(dev_.IsMainThread());
TunInterface::TunConfig config = {0};
if (tun_addr_.size == 32) {
if (tun_addr_.cidr >= 31) { if (tun_addr_.cidr >= 31) {
RERROR("TAP is not compatible CIDR /31 or /32. Changing to /24"); RERROR("TAP is not compatible CIDR /31 or /32. Changing to /24");
tun_addr_.cidr = 24; tun_addr_.cidr = 24;
} }
TunInterface::TunConfig config = {0};
config.ip = ReadBE32(tun_addr_.addr); config.ip = ReadBE32(tun_addr_.addr);
config.cidr = tun_addr_.cidr; config.cidr = tun_addr_.cidr;
} else {
RERROR("No IPv4 address configured");
}
config.mtu = mtu_; config.mtu = mtu_;
config.pre_post_commands = pre_post_; config.pre_post_commands = pre_post_;
config.excluded_ips = excluded_ips_; config.excluded_ips = excluded_ips_;
@ -205,7 +216,7 @@ bool WireguardProcessor::Start() {
config.ipv6_dns = dns6_addr_; config.ipv6_dns = dns6_addr_;
TunInterface::TunConfigOut config_out; TunInterface::TunConfigOut config_out;
if (!tun_->Initialize(std::move(config), &config_out)) if (!tun_->Configure(std::move(config), &config_out))
return false; return false;
SetupCompressionHeader(dev_.compression_header()); SetupCompressionHeader(dev_.compression_header());
@ -221,6 +232,7 @@ bool WireguardProcessor::Start() {
} }
} }
is_started_ = true;
return true; return true;
} }
@ -395,22 +407,6 @@ getout:
FreePacket(packet); FreePacket(packet);
} }
void WgPeer::AddPacketToPeerQueue(Packet *packet) {
assert(IsPeerLocked());
// Keep only the first MAX_QUEUED_PACKETS packets.
while (num_queued_packets_ >= MAX_QUEUED_PACKETS_PER_PEER) {
Packet *packet = first_queued_packet_;
first_queued_packet_ = packet->next;
num_queued_packets_--;
FreePacket(packet);
}
// Add the packet to the out queue that will get sent once handshake completes
*last_queued_packet_ptr_ = packet;
last_queued_packet_ptr_ = &packet->next;
packet->next = NULL;
num_queued_packets_++;
}
// This function must be called with the peer lock held. It will remove the lock // This function must be called with the peer lock held. It will remove the lock
void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet) { void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet) {
assert(peer->IsPeerLocked()); assert(peer->IsPeerLocked());
@ -427,11 +423,17 @@ void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Pac
if ((keypair = peer->curr_keypair_) == NULL || if ((keypair = peer->curr_keypair_) == NULL ||
(send_ctr = keypair->send_ctr) >= REJECT_AFTER_MESSAGES) { (send_ctr = keypair->send_ctr) >= REJECT_AFTER_MESSAGES) {
peer->AddPacketToPeerQueue(packet); // If RemovePeer has been called then discard any packets currently being written to it.
// curr_keypair_ is NULL when RemovePeer has been called so it's safe to do this here.
if (peer->marked_for_delete_)
goto getout_discard;
peer->AddPacketToPeerQueue_Locked(packet);
WG_RELEASE_LOCK(peer->mutex_); WG_RELEASE_LOCK(peer->mutex_);
ScheduleNewHandshake(peer); peer->ScheduleNewHandshake();
return; return;
} }
assert(!peer->marked_for_delete_);
stats_.tun_bytes_in += size; stats_.tun_bytes_in += size;
stats_.tun_packets_in++; stats_.tun_packets_in++;
@ -524,13 +526,14 @@ add_padding:
WriteLE32(write -= 4, keypair->remote_key_id); WriteLE32(write -= 4, keypair->remote_key_id);
*--write = tag; *--write = tag;
// Not using any fields from now on
WG_RELEASE_LOCK(peer->mutex_);
header_size = data - write; header_size = data - write;
packet->size = (int)(size + header_size + keypair->auth_tag_length);
peer->tx_bytes_ += packet->size;
stats_.compression_wg_saved_out += (int64)16 - header_size; stats_.compression_wg_saved_out += (int64)16 - header_size;
packet->data = data - header_size; packet->data = data - header_size;
packet->size = (int)(size + header_size + keypair->auth_tag_length);
// Not using any fields from now on
WG_RELEASE_LOCK(peer->mutex_);
// todo: figure out what to actually use as ad. // todo: figure out what to actually use as ad.
ad = write_after_ack_header; ad = write_after_ack_header;
@ -540,6 +543,9 @@ need_big_packet:
#else #else
{ {
#endif // #if WITH_SHORT_HEADERS #endif // #if WITH_SHORT_HEADERS
packet->size = (int)(size + sizeof(MessageData) + keypair->auth_tag_length);
peer->tx_bytes_ += packet->size;
// Not using any fields from now on // Not using any fields from now on
WG_RELEASE_LOCK(peer->mutex_); WG_RELEASE_LOCK(peer->mutex_);
@ -547,7 +553,6 @@ need_big_packet:
((MessageData*)data)[-1].receiver_id = keypair->remote_key_id; ((MessageData*)data)[-1].receiver_id = keypair->remote_key_id;
((MessageData*)data)[-1].counter = ToLE64(send_ctr); ((MessageData*)data)[-1].counter = ToLE64(send_ctr);
packet->data = data - sizeof(MessageData); packet->data = data - sizeof(MessageData);
packet->size = (int)(size + sizeof(MessageData) + keypair->auth_tag_length);
ad = NULL; ad = NULL;
ad_len = 0; ad_len = 0;
} }
@ -556,7 +561,7 @@ need_big_packet:
DoWriteUdpPacket(packet); DoWriteUdpPacket(packet);
if (want_handshake) if (want_handshake)
ScheduleNewHandshake(peer); peer->ScheduleNewHandshake();
return; return;
getout_discard: getout_discard:
@ -608,38 +613,32 @@ void WireguardProcessor::DoWriteUdpPacket(Packet *packet) {
ScrambleUnscrambleAndWrite(packet, &dev_.header_obfuscation_key_, udp_); ScrambleUnscrambleAndWrite(packet, &dev_.header_obfuscation_key_, udp_);
} }
void WireguardProcessor::ScheduleNewHandshake(WgPeer *peer) {
if (peer->main_thread_scheduled_.fetch_or(WgPeer::kMainThreadScheduled_ScheduleHandshake) == 0) {
peer->main_thread_scheduled_next_ = NULL;
WG_ACQUIRE_LOCK(main_thread_scheduled_lock_);
*main_thread_scheduled_last_ = peer;
main_thread_scheduled_last_ = &peer->main_thread_scheduled_next_;
WG_RELEASE_LOCK(main_thread_scheduled_lock_);
// todo: in multithreaded impl need to trigger |RunAllMainThreadScheduled| to get called
}
}
void WireguardProcessor::RunAllMainThreadScheduled() { void WireguardProcessor::RunAllMainThreadScheduled() {
WgPeer *peer, *next;
assert(dev_.IsMainThread()); assert(dev_.IsMainThread());
if (main_thread_scheduled_ == NULL) if (dev_.main_thread_scheduled_ == NULL)
return; return;
WG_ACQUIRE_LOCK(main_thread_scheduled_lock_); WG_ACQUIRE_LOCK(dev_.main_thread_scheduled_lock_);
WgPeer *peer = main_thread_scheduled_; peer = dev_.main_thread_scheduled_;
main_thread_scheduled_ = NULL; dev_.main_thread_scheduled_ = NULL;
main_thread_scheduled_last_ = &main_thread_scheduled_; dev_.main_thread_scheduled_last_ = &dev_.main_thread_scheduled_;
WG_RELEASE_LOCK(main_thread_scheduled_lock_); WG_RELEASE_LOCK(dev_.main_thread_scheduled_lock_);
for (; peer; peer = next) {
// todo: for the multithreaded use case figure out whether to use atomic_thread_fence here,
// because we need to read this next value before any other thread sees the 0 we write
// to peer->main_thread_scheduled_.
next = peer->main_thread_scheduled_next_;
if (peer->marked_for_delete_)
continue;
while (peer) {
// todo: for the multithreaded use case figure out whether to use atomic_thread_fence here.
WgPeer *next = peer->main_thread_scheduled_next_;
uint32 ev = peer->main_thread_scheduled_.exchange(0); uint32 ev = peer->main_thread_scheduled_.exchange(0);
if (ev & WgPeer::kMainThreadScheduled_ScheduleHandshake) { if (ev & WgPeer::kMainThreadScheduled_ScheduleHandshake) {
peer->handshake_attempts_ = 0; peer->handshake_attempts_ = 0;
SendHandshakeInitiation(peer); SendHandshakeInitiation(peer);
} }
peer = next;
} }
} }
@ -658,6 +657,7 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) {
procdel_->OnConnectionRetry(attempts); procdel_->OnConnectionRetry(attempts);
peer->OnHandshakeInitSent(); peer->OnHandshakeInitSent();
packet->addr = peer->endpoint_; packet->addr = peer->endpoint_;
peer->tx_bytes_ += packet->size;
WG_RELEASE_LOCK(peer->mutex_); WG_RELEASE_LOCK(peer->mutex_);
DoWriteUdpPacket(packet); DoWriteUdpPacket(packet);
if (attempts > 1 && attempts <= 20) if (attempts > 1 && attempts <= 20)
@ -696,19 +696,21 @@ void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) {
#endif // WITH_SHORT_HEADERS #endif // WITH_SHORT_HEADERS
} else if (type == MESSAGE_HANDSHAKE_COOKIE) { } else if (type == MESSAGE_HANDSHAKE_COOKIE) {
assert(dev_.IsMainThread()); assert(dev_.IsMainThread());
if (packet->size != sizeof(MessageHandshakeCookie)) if (packet->size != sizeof(MessageHandshakeCookie) || !dev_.is_private_key_initialized())
goto invalid_size; goto invalid_size;
HandleHandshakeCookiePacket(packet); HandleHandshakeCookiePacket(packet);
} else if (type == MESSAGE_HANDSHAKE_INITIATION) { } else if (type == MESSAGE_HANDSHAKE_INITIATION) {
assert(dev_.IsMainThread()); assert(dev_.IsMainThread());
if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeInitiation)) : (packet->size != sizeof(MessageHandshakeInitiation))) if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeInitiation)) : (packet->size != sizeof(MessageHandshakeInitiation)) ||
!dev_.is_private_key_initialized())
goto invalid_size; goto invalid_size;
stats_.handshakes_in++; stats_.handshakes_in++;
if (CheckIncomingHandshakeRateLimit(packet, overload)) if (CheckIncomingHandshakeRateLimit(packet, overload))
HandleHandshakeInitiationPacket(packet); HandleHandshakeInitiationPacket(packet);
} else if (type == MESSAGE_HANDSHAKE_RESPONSE) { } else if (type == MESSAGE_HANDSHAKE_RESPONSE) {
assert(dev_.IsMainThread()); assert(dev_.IsMainThread());
if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeResponse)) : (packet->size != sizeof(MessageHandshakeResponse))) if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeResponse)) : (packet->size != sizeof(MessageHandshakeResponse)) ||
!dev_.is_private_key_initialized())
goto invalid_size; goto invalid_size;
if (CheckIncomingHandshakeRateLimit(packet, overload)) if (CheckIncomingHandshakeRateLimit(packet, overload))
HandleHandshakeResponsePacket(packet); HandleHandshakeResponsePacket(packet);
@ -749,6 +751,8 @@ void WgPeer::CopyEndpointToPeer_Locked(WgKeypair *keypair, const IpAddr *addr) {
#if WITH_SHORT_HEADERS #if WITH_SHORT_HEADERS
void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packet) { void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packet) {
assert(dev_.IsMainOrDataThread());
uint8 *data = packet->data + 1; uint8 *data = packet->data + 1;
size_t bytes_left = packet->size - 1; size_t bytes_left = packet->size - 1;
WgKeypair *keypair; WgKeypair *keypair;
@ -832,6 +836,8 @@ void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packe
WG_ACQUIRE_LOCK(keypair->peer->mutex_); WG_ACQUIRE_LOCK(keypair->peer->mutex_);
keypair->peer->rx_bytes_ += packet->size;
if (keypair->recv_key_state == WgKeypair::KEY_INVALID) if (keypair->recv_key_state == WgKeypair::KEY_INVALID)
goto getout_unlock; goto getout_unlock;
@ -896,7 +902,7 @@ void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *key
WgKeypair *curr_keypair = peer->curr_keypair_; WgKeypair *curr_keypair = peer->curr_keypair_;
if (curr_keypair && curr_keypair->recv_key_state == WgKeypair::KEY_WANT_REFRESH) { if (curr_keypair && curr_keypair->recv_key_state == WgKeypair::KEY_WANT_REFRESH) {
curr_keypair->recv_key_state = WgKeypair::KEY_DID_REFRESH; curr_keypair->recv_key_state = WgKeypair::KEY_DID_REFRESH;
ScheduleNewHandshake(peer); peer->ScheduleNewHandshake();
} }
if (data_size == 0) { if (data_size == 0) {
@ -965,6 +971,8 @@ getout:
} }
void WireguardProcessor::HandleDataPacket(Packet *packet) { void WireguardProcessor::HandleDataPacket(Packet *packet) {
assert(dev_.IsMainOrDataThread());
uint8 *data = packet->data; uint8 *data = packet->data;
size_t data_size = packet->size; size_t data_size = packet->size;
uint32 key_id = ((MessageData*)data)->receiver_id; uint32 key_id = ((MessageData*)data)->receiver_id;
@ -984,6 +992,7 @@ getout:
} }
WG_ACQUIRE_LOCK(keypair->peer->mutex_); WG_ACQUIRE_LOCK(keypair->peer->mutex_);
keypair->peer->rx_bytes_ += data_size;
if (keypair->recv_key_state == WgKeypair::KEY_INVALID) { if (keypair->recv_key_state == WgKeypair::KEY_INVALID) {
stats_.error_key_id++; stats_.error_key_id++;
WG_RELEASE_LOCK(keypair->peer->mutex_); WG_RELEASE_LOCK(keypair->peer->mutex_);
@ -993,6 +1002,8 @@ getout:
WG_RELEASE_LOCK(keypair->peer->mutex_); WG_RELEASE_LOCK(keypair->peer->mutex_);
goto getout; goto getout;
} else { } else {
assert(!keypair->peer->marked_for_delete_);
WgPeer::CopyEndpointToPeer_Locked(keypair, &packet->addr); WgPeer::CopyEndpointToPeer_Locked(keypair, &packet->addr);
HandleAuthenticatedDataPacket_WillUnlock(keypair, packet, data + sizeof(MessageData), data_size - sizeof(MessageData) - keypair->auth_tag_length); HandleAuthenticatedDataPacket_WillUnlock(keypair, packet, data + sizeof(MessageData), data_size - sizeof(MessageData) - keypair->auth_tag_length);
} }
@ -1119,7 +1130,7 @@ void WireguardProcessor::SecondLoop() {
uint32 mask; uint32 mask;
{ {
WG_SCOPED_LOCK(peer->mutex_); WG_SCOPED_LOCK(peer->mutex_);
mask = peer->CheckTimeouts(now); mask = peer->CheckTimeouts_Locked(now);
if (mask == 0) if (mask == 0)
continue; continue;
if (mask & WgPeer::ACTION_SEND_KEEPALIVE) if (mask & WgPeer::ACTION_SEND_KEEPALIVE)

View file

@ -66,6 +66,7 @@ enum InternetBlockState {
}; };
class WireguardProcessor { class WireguardProcessor {
friend class WgConfig;
public: public:
WireguardProcessor(UdpInterface *udp, TunInterface *tun, ProcessorDelegate *procdel); WireguardProcessor(UdpInterface *udp, TunInterface *tun, ProcessorDelegate *procdel);
~WireguardProcessor(); ~WireguardProcessor();
@ -73,6 +74,7 @@ public:
void SetListenPort(int listen_port); void SetListenPort(int listen_port);
void AddDnsServer(const IpAddr &sin); void AddDnsServer(const IpAddr &sin);
bool SetTunAddress(const WgCidrAddr &addr); bool SetTunAddress(const WgCidrAddr &addr);
void ClearTunAddress();
void AddExcludedIp(const WgCidrAddr &cidr_addr); void AddExcludedIp(const WgCidrAddr &cidr_addr);
void SetMtu(int mtu); void SetMtu(int mtu);
void SetAddRoutesMode(bool mode); void SetAddRoutesMode(bool mode);
@ -91,6 +93,9 @@ public:
bool Start(); bool Start();
bool ConfigureUdp();
bool ConfigureTun();
WgDevice &dev() { return dev_; } WgDevice &dev() { return dev_; }
TunInterface::PrePostCommands &prepost() { return pre_post_; } TunInterface::PrePostCommands &prepost() { return pre_post_; }
const WgCidrAddr &tun_addr() { return tun_addr_; } const WgCidrAddr &tun_addr() { return tun_addr_; }
@ -100,7 +105,6 @@ private:
void DoWriteUdpPacket(Packet *packet); void DoWriteUdpPacket(Packet *packet);
void WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet); void WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet);
void SendHandshakeInitiation(WgPeer *peer); void SendHandshakeInitiation(WgPeer *peer);
void ScheduleNewHandshake(WgPeer *peer);
void SendKeepalive_Locked(WgPeer *peer); void SendKeepalive_Locked(WgPeer *peer);
void SendQueuedPackets_Locked(WgPeer *peer); void SendQueuedPackets_Locked(WgPeer *peer);
@ -110,29 +114,25 @@ private:
void HandleDataPacket(Packet *packet); void HandleDataPacket(Packet *packet);
void HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet, uint8 *data, size_t data_size); void HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet, uint8 *data, size_t data_size);
void HandleShortHeaderFormatPacket(uint32 tag, Packet *packet); void HandleShortHeaderFormatPacket(uint32 tag, Packet *packet);
bool CheckIncomingHandshakeRateLimit(Packet *packet, bool overload); bool CheckIncomingHandshakeRateLimit(Packet *packet, bool overload);
bool HandleIcmpv6NeighborSolicitation(const byte *data, size_t data_size); bool HandleIcmpv6NeighborSolicitation(const byte *data, size_t data_size);
void SetupCompressionHeader(WgPacketCompressionVer01 *c); void SetupCompressionHeader(WgPacketCompressionVer01 *c);
void NotifyHandshakeComplete(); void NotifyHandshakeComplete();
int listen_port_;
ProcessorDelegate *procdel_; ProcessorDelegate *procdel_;
TunInterface *tun_; TunInterface *tun_;
UdpInterface *udp_; UdpInterface *udp_;
int mtu_;
WgProcessorStats stats_; uint16 listen_port_;
uint16 mtu_;
bool dns_blocking_; bool dns_blocking_;
uint8 internet_blocking_; uint8 internet_blocking_;
bool add_routes_mode_; bool add_routes_mode_;
bool network_discovery_spoofing_; bool network_discovery_spoofing_;
bool did_have_first_handshake_; bool did_have_first_handshake_;
bool is_started_;
uint8 network_discovery_mac_[6]; uint8 network_discovery_mac_[6];
WgDevice dev_; WgDevice dev_;
@ -140,14 +140,12 @@ private:
WgCidrAddr tun_addr_; WgCidrAddr tun_addr_;
WgCidrAddr tun6_addr_; WgCidrAddr tun6_addr_;
WgProcessorStats stats_;
std::vector<IpAddr> dns_addr_, dns6_addr_; std::vector<IpAddr> dns_addr_, dns6_addr_;
TunInterface::PrePostCommands pre_post_; TunInterface::PrePostCommands pre_post_;
// Queue of things scheduled to run on the main thread.
WG_DECLARE_LOCK(main_thread_scheduled_lock_);
WgPeer *main_thread_scheduled_, **main_thread_scheduled_last_;
uint64 stats_last_bytes_in_, stats_last_bytes_out_; uint64 stats_last_bytes_in_, stats_last_bytes_out_;
uint64 stats_last_ts_; uint64 stats_last_ts_;

View file

@ -45,12 +45,26 @@ char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]) {
return buf; return buf;
} }
char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]) {
if (addr.size == 32) {
print_ip_prefix(buf, AF_INET, addr.addr, addr.cidr);
} else if (addr.size == 128) {
print_ip_prefix(buf, AF_INET6, addr.addr, addr.cidr);
} else {
buf[0] = 0;
}
return buf;
}
struct Addr { struct Addr {
byte addr[4]; byte addr[4];
uint8 cidr; uint8 cidr;
}; };
static bool ParseCidrAddr(char *s, WgCidrAddr *out) { bool ParseCidrAddr(char *s, WgCidrAddr *out) {
char *slash = strchr(s, '/'); char *slash = strchr(s, '/');
if (!slash) if (!slash)
return false; return false;
@ -92,15 +106,6 @@ bool DnsResolver::Resolve(const char *hostname, IpAddr *result) {
char buf[kSizeOfAddress]; char buf[kSizeOfAddress];
memset(result, 0, sizeof(IpAddr)); memset(result, 0, sizeof(IpAddr));
if (inet_pton(AF_INET6, hostname, &result->sin6.sin6_addr) == 1) {
result->sin.sin_family = AF_INET6;
return true;
}
if (inet_pton(AF_INET, hostname, &result->sin.sin_addr) == 1) {
result->sin.sin_family = AF_INET;
return true;
}
// First check cache // First check cache
for (auto it = cache_.begin(); it != cache_.end(); ++it) { for (auto it = cache_.begin(); it != cache_.end(); ++it) {
@ -145,10 +150,7 @@ bool DnsResolver::Resolve(const char *hostname, IpAddr *result) {
} }
} }
bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver) {
static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver) {
memset(sin, 0, sizeof(IpAddr)); memset(sin, 0, sizeof(IpAddr));
if (*s == '[') { if (*s == '[') {
char *end = strchr(s, ']'); char *end = strchr(s, ']');
@ -168,7 +170,11 @@ static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver)
if (!x) return false; if (!x) return false;
*x = 0; *x = 0;
if (!resolver->Resolve(s, sin)) { if (inet_pton(AF_INET, s, &sin->sin.sin_addr) == 1) {
sin->sin.sin_family = AF_INET;
} else if (!resolver) {
return false;
} else if (!resolver->Resolve(s, sin)) {
RERROR("Unable to resolve %s", s); RERROR("Unable to resolve %s", s);
return false; return false;
} }
@ -177,18 +183,19 @@ static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver)
} }
static bool ParseSockaddrInWithoutPort(char *s, IpAddr *sin, DnsResolver *resolver) { static bool ParseSockaddrInWithoutPort(char *s, IpAddr *sin, DnsResolver *resolver) {
if (!resolver->Resolve(s, sin)) { if (inet_pton(AF_INET6, s, &sin->sin6.sin6_addr) == 1) {
sin->sin.sin_family = AF_INET6;
return true;
} else if (inet_pton(AF_INET, s, &sin->sin.sin_addr) == 1) {
sin->sin.sin_family = AF_INET;
return true;
} else if (!resolver->Resolve(s, sin)) {
RERROR("Unable to resolve %s", s); RERROR("Unable to resolve %s", s);
return false; return false;
} }
return true; return true;
} }
static bool ParseBase64Key(const char *s, uint8 key[32]) {
size_t size = 32;
return base64_decode((uint8*)s, strlen(s), key, &size) && size == 32;
}
class WgFileParser { class WgFileParser {
public: public:
WgFileParser(WireguardProcessor *wg, DnsResolver *resolver) : wg_(wg), dns_resolver_(resolver) {} WgFileParser(WireguardProcessor *wg, DnsResolver *resolver) : wg_(wg), dns_resolver_(resolver) {}
@ -197,7 +204,7 @@ public:
void FinishGroup(); void FinishGroup();
struct Peer { struct Peer {
uint8 pub[32]; WgPublicKey pub;
uint8 psk[32]; uint8 psk[32];
}; };
Peer pi_; Peer pi_;
@ -206,29 +213,6 @@ public:
bool had_interface_ = false; bool had_interface_ = false;
}; };
bool is_space(uint8_t c) {
return c == ' ' || c == '\r' || c == '\n' || c == '\t';
}
void SplitString(char *s, int separator, std::vector<char*> *components) {
for (;;) {
while (is_space(*s)) s++;
char *d = strchr(s, separator);
if (d == NULL) {
if (*s)
components->push_back(s);
return;
}
*d = 0;
char *e = d;
while (e > s && is_space(e[-1]))
*--e = 0;
components->push_back(s);
s = d + 1;
}
}
static bool ParseBoolean(const char *str, bool *value) { static bool ParseBoolean(const char *str, bool *value) {
if (_stricmp(str, "true") == 0 || if (_stricmp(str, "true") == 0 ||
_stricmp(str, "yes") == 0 || _stricmp(str, "yes") == 0 ||
@ -285,7 +269,7 @@ static int ParseCipherSuite(const char *cipher) {
void WgFileParser::FinishGroup() { void WgFileParser::FinishGroup() {
if (peer_) { if (peer_) {
peer_->Initialize(pi_.pub, pi_.psk); peer_->SetPublicKey(pi_.pub);
peer_ = NULL; peer_ = NULL;
} }
} }
@ -303,7 +287,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
if (!ParseBase64Key(value, binkey)) if (!ParseBase64Key(value, binkey))
return false; return false;
had_interface_ = true; had_interface_ = true;
wg_->dev().Initialize(binkey); wg_->dev().SetPrivateKey(binkey);
} else if (strcmp(key, "ListenPort") == 0) { } else if (strcmp(key, "ListenPort") == 0) {
wg_->SetListenPort(atoi(value)); wg_->SetListenPort(atoi(value));
} else if (strcmp(key, "Address") == 0) { } else if (strcmp(key, "Address") == 0) {
@ -394,11 +378,12 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
return true; return true;
} }
if (strcmp(key, "PublicKey") == 0) { if (strcmp(key, "PublicKey") == 0) {
if (!ParseBase64Key(value, pi_.pub)) if (!ParseBase64Key(value, pi_.pub.bytes))
return false; return false;
} else if (strcmp(key, "PresharedKey") == 0) { } else if (strcmp(key, "PresharedKey") == 0) {
if (!ParseBase64Key(value, pi_.psk)) if (!ParseBase64Key(value, pi_.psk))
return false; return false;
peer_->SetPresharedKey(pi_.psk);
} else if (strcmp(key, "AllowedIPs") == 0) { } else if (strcmp(key, "AllowedIPs") == 0) {
SplitString(value, ',', &ss); SplitString(value, ',', &ss);
for (size_t i = 0; i < ss.size(); i++) { for (size_t i = 0; i < ss.size(); i++) {
@ -412,7 +397,8 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
return false; return false;
peer_->SetEndpoint(sin); peer_->SetEndpoint(sin);
} else if (strcmp(key, "PersistentKeepalive") == 0) { } else if (strcmp(key, "PersistentKeepalive") == 0) {
peer_->SetPersistentKeepalive(atoi(value)); if (!peer_->SetPersistentKeepalive(atoi(value)))
return false;
} else if (strcmp(key, "AllowMulticast") == 0) { } else if (strcmp(key, "AllowMulticast") == 0) {
bool b; bool b;
if (!ParseBoolean(value, &b)) if (!ParseBoolean(value, &b))
@ -524,3 +510,154 @@ bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, DnsR
fclose(f); fclose(f);
return true; return true;
} }
static void CmsgAppendFmt(std::string *result, const char *fmt, ...) {
va_list va;
char buf[256];
va_start(va, fmt);
vsnprintf(buf, sizeof(buf), fmt, va);
(*result) += buf;
(*result) += '\n';
va_end(va);
}
static void CmsgAppendHex(std::string *result, const char *key, const void *data, size_t data_size) {
char *tmp = (char*)alloca(data_size * 2 + 2);
PrintHexString(data, data_size, tmp + 1);
tmp[0] = '=';
tmp[data_size * 2 + 1] = '\n';
(*result) += key;
result->append(tmp, data_size * 2 + 2);
}
void WgConfig::HandleConfigurationProtocolGet(WireguardProcessor *proc, std::string *result) {
char buf[kSizeOfAddress];
CmsgAppendHex(result, "private_key", proc->dev_.s_priv_, sizeof(proc->dev_.s_priv_));
if (proc->listen_port_)
CmsgAppendFmt(result, "listen_port=%d", proc->listen_port_);
if (proc->tun_addr_.size == 32)
CmsgAppendFmt(result, "address=%s", PrintWgCidrAddr(proc->tun_addr_, buf));
if (proc->tun6_addr_.size == 128)
CmsgAppendFmt(result, "address=%s", PrintWgCidrAddr(proc->tun6_addr_, buf));
for (WgPeer *peer = proc->dev_.peers_; peer; peer = peer->next_peer_) {
WG_SCOPED_LOCK(peer->lock_);
CmsgAppendHex(result, "public_key", peer->s_remote_.bytes, sizeof(peer->s_remote_));
if (!IsOnlyZeros(peer->preshared_key_, sizeof(peer->preshared_key_)))
CmsgAppendHex(result, "preshared_key", peer->preshared_key_, sizeof(peer->preshared_key_));
if (peer->tx_bytes_ | peer->rx_bytes_)
CmsgAppendFmt(result, "tx_bytes=%lld\nrx_bytes=%lld", peer->tx_bytes_, peer->rx_bytes_);
for (auto it = peer->allowed_ips_.begin(); it != peer->allowed_ips_.end(); ++it)
CmsgAppendFmt(result, "allowed_ip=%s", PrintWgCidrAddr(*it, buf));
if (peer->persistent_keepalive_ms_)
CmsgAppendFmt(result, "persistent_keepalive_interval=%d", peer->persistent_keepalive_ms_ / 1000);
if (peer->endpoint_.sin.sin_family == AF_INET)
CmsgAppendFmt(result, "endpoint=%s:%d", PrintIpAddr(peer->endpoint_, buf), htons(peer->endpoint_.sin.sin_port));
else if (peer->endpoint_.sin.sin_family == AF_INET6)
CmsgAppendFmt(result, "endpoint=[%s]:%d", PrintIpAddr(peer->endpoint_, buf), htons(peer->endpoint_.sin6.sin6_port));
if (peer->last_complete_handskake_timestamp_) {
uint64 millis_since = OsGetMilliseconds() - peer->last_complete_handskake_timestamp_;
uint64 when = time(NULL) - millis_since / 1000;
CmsgAppendFmt(result, "last_handshake_time_sec=%lld", when);
}
}
CmsgAppendFmt(result, "protocol_version=1");
}
bool WgConfig::HandleConfigurationProtocolMessage(WireguardProcessor *proc, const std::string &&message, std::string *result) {
std::string message_copy(std::move(message));
std::vector<std::pair<char *, char*>> kv;
bool is_set = false;
bool did_set_address = false;
WgPeer *peer = NULL;
WgCidrAddr cidr_addr;
IpAddr sin;
uint8 buf32[32];
assert(proc->dev().IsMainThread());
result->clear();
if (!ParseConfigKeyValue(&message_copy[0], &kv))
return false;
for (auto it : kv) {
char *key = it.first, *value = it.second;
if (strcmp(key, "get") == 0) {
if (strcmp(value, "1") != 0)
goto getout_fail;
HandleConfigurationProtocolGet(proc, result);
break;
} else if (strcmp(key, "set") == 0) {
if (strcmp(value, "1") != 0)
goto getout_fail;
is_set = true;
} else if (is_set) {
if (strcmp(key, "private_key") == 0) {
if (!ParseHexString(value, buf32, 32)) goto getout_fail;
proc->dev_.SetPrivateKey(buf32);
} else if (strcmp(key, "listen_port") == 0) {
int new_port = atoi(value);
proc->SetListenPort(new_port);
} else if (strcmp(key, "replace_peers") == 0) {
if (strcmp(value, "true") != 0) goto getout_fail;
proc->dev_.RemoveAllPeers();
} else if (strcmp(key, "address") == 0) {
if (!ParseCidrAddr(value, &cidr_addr)) goto getout_fail;
if (!did_set_address) {
did_set_address = true;
proc->ClearTunAddress();
}
if (!proc->SetTunAddress(cidr_addr)) goto getout_fail;
} else if (strcmp(key, "public_key") == 0) {
WgPublicKey pubkey;
if (!ParseHexString(value, pubkey.bytes, 32)) goto getout_fail;
peer = proc->dev_.GetPeerFromPublicKey(pubkey);
if (!peer) {
peer = proc->dev_.AddPeer();
peer->SetPublicKey(pubkey);
}
} else if (peer != NULL) {
if (strcmp(key, "remove") == 0) {
if (strcmp(value, "true") != 0) goto getout_fail;
peer->RemovePeer();
peer = NULL;
} else if (strcmp(key, "preshared_key") == 0) {
if (!ParseHexString(value, buf32, 32)) goto getout_fail;
peer->SetPresharedKey(buf32);
} else if (strcmp(key, "endpoint") == 0) {
if (!ParseSockaddrInWithPort(value, &sin, NULL)) goto getout_fail;
peer->SetEndpoint(sin);
} else if (strcmp(key, "persistent_keepalive_interval") == 0) {
if (!peer->SetPersistentKeepalive(atoi(value)))
goto getout_fail;
} else if (strcmp(key, "replace_allowed_ips") == 0) {
if (strcmp(value, "true") != 0) goto getout_fail;
peer->RemoveAllIps();
} else if (strcmp(key, "allowed_ip") == 0) {
if (!ParseCidrAddr(value, &cidr_addr)) goto getout_fail;
peer->AddIp(cidr_addr);
}
}
} else {
goto getout_fail;
}
}
// reconfigure the tun interface?
if (did_set_address) {
proc->ConfigureTun();
}
result->append("errno=0\n\n");
return true;
getout_fail:
(*result) = "errno=1\n\n";
return false;
}

View file

@ -30,11 +30,19 @@ private:
}; };
class WgConfig {
public:
static bool HandleConfigurationProtocolMessage(WireguardProcessor *proc, const std::string &&message, std::string *result);
private:
static void HandleConfigurationProtocolGet(WireguardProcessor *proc, std::string *result);
};
bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, DnsResolver *dns_resolver); bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, DnsResolver *dns_resolver);
#define kSizeOfAddress 64 #define kSizeOfAddress 64
const char *print_ip_prefix(char buf[kSizeOfAddress], int family, const void *ip, int prefixlen); const char *print_ip_prefix(char buf[kSizeOfAddress], int family, const void *ip, int prefixlen);
char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]); char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]);
char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]);
bool ParseCidrAddr(char *s, WgCidrAddr *out);
#endif // TINYVPN_TINYVPN_H_ #endif // TINYVPN_TINYVPN_H_

View file

@ -54,18 +54,27 @@ bool ReplayDetector::CheckReplay(uint64 seq_nr) {
WgDevice::WgDevice() { WgDevice::WgDevice() {
peers_ = NULL; peers_ = NULL;
last_peer_ptr_ = &peers_;
delegate_ = NULL; delegate_ = NULL;
header_obfuscation_ = false; header_obfuscation_ = false;
is_private_key_initialized_ = false;
next_rng_slot_ = 0; next_rng_slot_ = 0;
main_thread_scheduled_ = NULL;
main_thread_scheduled_last_ = &main_thread_scheduled_;
memset(&compression_header_, 0, sizeof(compression_header_)); memset(&compression_header_, 0, sizeof(compression_header_));
low_resolution_timestamp_ = cookie_secret_timestamp_ = OsGetMilliseconds(); low_resolution_timestamp_ = cookie_secret_timestamp_ = OsGetMilliseconds();
OsGetRandomBytes(cookie_secret_, sizeof(cookie_secret_)); OsGetRandomBytes(cookie_secret_, sizeof(cookie_secret_));
OsGetRandomBytes((uint8*)random_number_input_, sizeof(random_number_input_)); OsGetRandomBytes((uint8*)random_number_input_, sizeof(random_number_input_));
SetCurrentThreadAsMainThread(); main_thread_id_ = GetCurrentThreadId();
memset(s_priv_, 0, sizeof(s_priv_));
memset(s_pub_, 0, sizeof(s_pub_));
} }
WgDevice::~WgDevice() { WgDevice::~WgDevice() {
assert(IsMainThread());
RemoveAllPeers();
} }
void WgDevice::SecondLoop(uint64 now) { void WgDevice::SecondLoop(uint64 now) {
@ -151,7 +160,8 @@ static inline void ComputeHKDF2DH(uint8 ci[WG_HASH_LEN], uint8 k[WG_SYMMETRIC_KE
memzero_crypto(dh, sizeof(dh)); memzero_crypto(dh, sizeof(dh));
} }
void WgDevice::Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]) { void WgDevice::SetPrivateKey(const uint8 private_key[WG_PUBLIC_KEY_LEN]) {
assert(IsMainThread());
// Derive the public key from the private key. // Derive the public key from the private key.
memcpy(s_priv_, private_key, sizeof(s_priv_)); memcpy(s_priv_, private_key, sizeof(s_priv_));
curve25519_donna(s_pub_, s_priv_, kCurve25519Basepoint); curve25519_donna(s_pub_, s_priv_, kCurve25519Basepoint);
@ -162,26 +172,31 @@ void WgDevice::Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]) {
kLabelCookie, sizeof(kLabelCookie), s_pub_, sizeof(s_pub_)); kLabelCookie, sizeof(kLabelCookie), s_pub_, sizeof(s_pub_));
BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_), BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_),
kLabelMac1, sizeof(kLabelMac1), s_pub_, sizeof(s_pub_)); kLabelMac1, sizeof(kLabelMac1), s_pub_, sizeof(s_pub_));
is_private_key_initialized_ = true;
// Recompute peer data because it depends on my privkey
for (WgPeer *peer = peers_; peer; peer = peer->next_peer_)
peer->SetPublicKey(peer->s_remote_);
} }
WgPeer *WgDevice::AddPeer() { WgPeer *WgDevice::AddPeer() {
assert(IsMainThread()); assert(IsMainThread());
WgPeer *peer = new WgPeer(this); WgPeer *peer = new WgPeer(this);
WgPeer **pp = &peers_;
while (*pp)
pp = &(*pp)->next_peer_;
*pp = peer;
return peer; return peer;
} }
WgPeer *WgDevice::GetPeerFromPublicKey(uint8 public_key[WG_PUBLIC_KEY_LEN]) { void WgDevice::RemoveAllPeers() {
assert(IsMainThread()); assert(IsMainThread());
// todo: add O(1) lookup while (peers_)
for (WgPeer *peer = peers_; peer; peer = peer->next_peer_) { peers_->RemovePeer();
if (memcmp(peer->s_remote_, public_key, WG_PUBLIC_KEY_LEN) == 0)
return peer;
} }
return NULL;
WgPeer *WgDevice::GetPeerFromPublicKey(const WgPublicKey &pubkey) {
assert(IsMainThread());
auto it = peer_id_lookup_.find(pubkey);
return (it != peer_id_lookup_.end()) ? it->second : NULL;
} }
bool WgDevice::CheckCookieMac1(Packet *packet) { bool WgDevice::CheckCookieMac1(Packet *packet) {
@ -230,6 +245,7 @@ void WgDevice::CreateCookieMessage(MessageHandshakeCookie *dst, Packet *packet,
} }
void WgDevice::EraseKeypairAddrEntry_Locked(WgKeypair *kp) { void WgDevice::EraseKeypairAddrEntry_Locked(WgKeypair *kp) {
// todo: figure out how to make this multithread safe.
WgAddrEntry *ae = kp->addr_entry; WgAddrEntry *ae = kp->addr_entry;
assert(ae->ref_count >= 1); assert(ae->ref_count >= 1);
@ -313,7 +329,6 @@ void WgDevice::SetHeaderObfuscation(const char *key) {
#endif // WITH_HEADER_OBFUSCATION #endif // WITH_HEADER_OBFUSCATION
} }
WgPeer::WgPeer(WgDevice *dev) { WgPeer::WgPeer(WgDevice *dev) {
assert(dev->IsMainThread()); assert(dev->IsMainThread());
dev_ = dev; dev_ = dev;
@ -323,6 +338,7 @@ WgPeer::WgPeer(WgDevice *dev) {
expect_cookie_reply_ = false; expect_cookie_reply_ = false;
has_mac2_cookie_ = false; has_mac2_cookie_ = false;
pending_keepalive_ = false; pending_keepalive_ = false;
marked_for_delete_ = false;
allow_multicast_through_peer_ = false; allow_multicast_through_peer_ = false;
allow_endpoint_change_ = true; allow_endpoint_change_ = true;
supports_handshake_extensions_ = true; supports_handshake_extensions_ = true;
@ -331,6 +347,8 @@ WgPeer::WgPeer(WgDevice *dev) {
last_handshake_init_recv_timestamp_ = 0; last_handshake_init_recv_timestamp_ = 0;
last_complete_handskake_timestamp_ = 0; last_complete_handskake_timestamp_ = 0;
persistent_keepalive_ms_ = 0; persistent_keepalive_ms_ = 0;
rx_bytes_ = 0;
tx_bytes_ = 0;
timers_ = 0; timers_ = 0;
first_queued_packet_ = NULL; first_queued_packet_ = NULL;
last_queued_packet_ptr_ = &first_queued_packet_; last_queued_packet_ptr_ = &first_queued_packet_;
@ -343,15 +361,66 @@ WgPeer::WgPeer(WgDevice *dev) {
memset(last_timestamp_, 0, sizeof(last_timestamp_)); memset(last_timestamp_, 0, sizeof(last_timestamp_));
ipv4_broadcast_addr_ = 0xffffffff; ipv4_broadcast_addr_ = 0xffffffff;
memset(features_, 0, sizeof(features_)); memset(features_, 0, sizeof(features_));
memset(preshared_key_, 0, sizeof(preshared_key_));
memset(&s_remote_, 0, sizeof(s_remote_));
// Insert into the parent's linked list
*dev_->last_peer_ptr_ = this;
dev_->last_peer_ptr_ = &next_peer_;
} }
WgPeer::~WgPeer() { WgPeer::~WgPeer() {
// do not delete this directly, instead call RemovePeer
assert(marked_for_delete_);
assert(dev_->IsMainThread()); assert(dev_->IsMainThread());
assert(curr_keypair_ == NULL && next_keypair_ == NULL && prev_keypair_ == NULL);
assert(local_key_id_during_hs_ == 0);
assert(first_queued_packet_ == NULL);
}
void WgPeer::DelayedDelete(void *x) {
WgPeer *peer = (WgPeer*)x;
assert(peer->dev_->IsMainThread());
if (peer->main_thread_scheduled_ != 0) {
WG_ACQUIRE_LOCK(peer->dev_->main_thread_scheduled_lock_);
// Unlink myself from the main thread scheduled list
for (WgPeer **pp = &peer->dev_->main_thread_scheduled_; *pp; pp = &(*pp)->main_thread_scheduled_next_) {
if (*pp == peer) {
*pp = peer->main_thread_scheduled_next_;
break;
}
}
WG_RELEASE_LOCK(peer->dev_->main_thread_scheduled_lock_);
}
delete peer;
}
void WgPeer::RemovePeer() {
assert(dev_->IsMainThread());
assert(!marked_for_delete_);
// Find and unlink the peer from the parent's peer list
WgPeer **pp = &dev_->peers_;
while (*pp != this)
pp = &(*pp)->next_peer_;
if ((*pp = next_peer_) == NULL)
dev_->last_peer_ptr_ = pp;
RemoveAllIps();
dev_->peer_id_lookup_.erase(s_remote_);
WG_ACQUIRE_LOCK(mutex_); WG_ACQUIRE_LOCK(mutex_);
marked_for_delete_ = true;
ClearKeys_Locked(); ClearKeys_Locked();
ClearHandshake_Locked(); ClearHandshake_Locked();
ClearPacketQueue_Locked(); ClearPacketQueue_Locked();
WG_RELEASE_LOCK(mutex_); WG_RELEASE_LOCK(mutex_);
// The WgPeer instance may still be accessible from
// worker threads that already started processing a packet,
// so defer the actual delete of it.
dev_->delayed_delete_.Add(&WgPeer::DelayedDelete, this);
} }
void WgPeer::ClearKeys_Locked() { void WgPeer::ClearKeys_Locked() {
@ -382,21 +451,55 @@ void WgPeer::ClearPacketQueue_Locked() {
num_queued_packets_ = 0; num_queued_packets_ = 0;
} }
void WgPeer::Initialize(const uint8 spub[WG_PUBLIC_KEY_LEN], const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]) { void WgPeer::AddPacketToPeerQueue_Locked(Packet *packet) {
// Optionally use a preshared key, it defaults to all zeros. assert(IsPeerLocked());
assert(!marked_for_delete_);
// Keep only the first MAX_QUEUED_PACKETS packets.
while (num_queued_packets_ >= MAX_QUEUED_PACKETS_PER_PEER) {
Packet *packet = first_queued_packet_;
first_queued_packet_ = packet->next;
num_queued_packets_--;
FreePacket(packet);
}
// Add the packet to the out queue that will get sent once handshake completes
*last_queued_packet_ptr_ = packet;
last_queued_packet_ptr_ = &packet->next;
packet->next = NULL;
num_queued_packets_++;
}
void WgPeer::SetPublicKey(const WgPublicKey &spub) {
assert(dev_->IsMainThread());
assert(IsOnlyZeros(s_remote_.bytes, sizeof(s_remote_.bytes)) ||
memcmp(s_remote_.bytes, spub.bytes, sizeof(s_remote_.bytes)) == 0);
s_remote_ = spub;
dev_->peer_id_lookup_[s_remote_] = this;
if (!dev_->is_private_key_initialized_)
return;
// Precompute: s_priv_pub_ := DH(sprivr, spubi)
curve25519_donna(s_priv_pub_, dev_->s_priv_, s_remote_.bytes);
// Precompute: precomputed_cookie_key_ := HASH(LABEL-COOKIE || Spub_m)
// precomputed_mac1_key_ := HASH(MAC1-COOKIE || Spub_m)
BlakeX2(precomputed_cookie_key_, sizeof(precomputed_cookie_key_),
kLabelCookie, sizeof(kLabelCookie), s_remote_.bytes, WG_PUBLIC_KEY_LEN);
BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_),
kLabelMac1, sizeof(kLabelMac1), s_remote_.bytes, WG_PUBLIC_KEY_LEN);
// Remove the peer's keys
WG_ACQUIRE_LOCK(mutex_);
ClearKeys_Locked();
ClearHandshake_Locked();
WG_RELEASE_LOCK(mutex_);
}
void WgPeer::SetPresharedKey(const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]) {
if (preshared_key) if (preshared_key)
memcpy(preshared_key_, preshared_key, sizeof(preshared_key_)); memcpy(preshared_key_, preshared_key, sizeof(preshared_key_));
else else
memset(preshared_key_, 0, sizeof(preshared_key_)); memset(preshared_key_, 0, sizeof(preshared_key_));
// Precompute: s_priv_pub_ := DH(sprivr, spubi)
memcpy(s_remote_, spub, sizeof(s_remote_));
curve25519_donna(s_priv_pub_, dev_->s_priv_, s_remote_);
// Precompute: precomputed_cookie_key_ := HASH(LABEL-COOKIE || Spub_m)
// precomputed_mac1_key_ := HASH(MAC1-COOKIE || Spub_m)
BlakeX2(precomputed_cookie_key_, sizeof(precomputed_cookie_key_),
kLabelCookie, sizeof(kLabelCookie), spub, WG_PUBLIC_KEY_LEN);
BlakeX2(precomputed_mac1_key_, sizeof(precomputed_mac1_key_),
kLabelMac1, sizeof(kLabelMac1), spub, WG_PUBLIC_KEY_LEN);
} }
// run on the client // run on the client
@ -411,7 +514,7 @@ void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) {
// Hi := HASH(Ci || IDENTIFIER) // Hi := HASH(Ci || IDENTIFIER)
memcpy(hs_.hi, kWgInitHash, sizeof(hs_.hi)); memcpy(hs_.hi, kWgInitHash, sizeof(hs_.hi));
// Hi := HASH(Hi || Spub_r) // Hi := HASH(Hi || Spub_r)
BlakeMix(hs_.hi, s_remote_, sizeof(s_remote_)); BlakeMix(hs_.hi, s_remote_.bytes, sizeof(s_remote_));
// (Epriv_r, Epub_r) := DH-GENERATE() // (Epriv_r, Epub_r) := DH-GENERATE()
// msg.ephemeral = Epub_r // msg.ephemeral = Epub_r
OsGetRandomBytes(hs_.e_priv, sizeof(hs_.e_priv)); OsGetRandomBytes(hs_.e_priv, sizeof(hs_.e_priv));
@ -422,7 +525,7 @@ void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) {
// Hi := HASH(Hi || msg.ephemeral) // Hi := HASH(Hi || msg.ephemeral)
BlakeMix(hs_.hi, dst->ephemeral, sizeof(dst->ephemeral)); BlakeMix(hs_.hi, dst->ephemeral, sizeof(dst->ephemeral));
// (Ci, K) := KDF2(Ci, DH(epriv, spub_r)) // (Ci, K) := KDF2(Ci, DH(epriv, spub_r))
ComputeHKDF2DH(hs_.ci, k, hs_.e_priv, s_remote_); ComputeHKDF2DH(hs_.ci, k, hs_.e_priv, s_remote_.bytes);
// msg.static = AEAD(K, 0, Spub_i, Hi) // msg.static = AEAD(K, 0, Spub_i, Hi)
chacha20poly1305_encrypt(dst->static_enc, dev_->s_pub_, sizeof(dev_->s_pub_), hs_.hi, sizeof(hs_.hi), 0, k); chacha20poly1305_encrypt(dst->static_enc, dev_->s_pub_, sizeof(dev_->s_pub_), hs_.hi, sizeof(hs_.hi), 0, k);
// Hi := HASH(Hi || msg.static) // Hi := HASH(Hi || msg.static)
@ -461,7 +564,7 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
uint8 e_priv[WG_PUBLIC_KEY_LEN]; uint8 e_priv[WG_PUBLIC_KEY_LEN];
}; };
union { union {
uint8 spubi[WG_PUBLIC_KEY_LEN]; WgPublicKey spubi;
uint8 e_remote[WG_PUBLIC_KEY_LEN]; uint8 e_remote[WG_PUBLIC_KEY_LEN];
uint8 hi2[WG_HASH_LEN]; uint8 hi2[WG_HASH_LEN];
}; };
@ -488,13 +591,13 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
// (Ci, K) := KDF2(Ci, DH(spriv, msg.ephemeral)) // (Ci, K) := KDF2(Ci, DH(spriv, msg.ephemeral))
ComputeHKDF2DH(ci, k, dev->s_priv_, src->ephemeral); ComputeHKDF2DH(ci, k, dev->s_priv_, src->ephemeral);
// Spub_i = AEAD_DEC(K, 0, msg.static, Hi) // Spub_i = AEAD_DEC(K, 0, msg.static, Hi)
if (!chacha20poly1305_decrypt(spubi, src->static_enc, sizeof(src->static_enc), hi, sizeof(hi), 0, k)) if (!chacha20poly1305_decrypt(spubi.bytes, src->static_enc, sizeof(src->static_enc), hi, sizeof(hi), 0, k))
goto getout; goto getout;
// Hi := HASH(Hi || msg.static) // Hi := HASH(Hi || msg.static)
BlakeMix(hi, src->static_enc, sizeof(src->static_enc)); BlakeMix(hi, src->static_enc, sizeof(src->static_enc));
// Lookup the peer with this ID // Lookup the peer with this ID
while ((peer = dev->GetPeerFromPublicKey(spubi)) == NULL) { while ((peer = dev->GetPeerFromPublicKey(spubi)) == NULL) {
if (dev->delegate_ == NULL || !dev->delegate_->HandleUnknownPeerId(spubi, packet)) if (dev->delegate_ == NULL || !dev->delegate_->HandleUnknownPeerId(spubi.bytes, packet))
goto getout; goto getout;
} }
// (Ci, K) := KDF2(Ci, DH(sprivr, spubi)) // (Ci, K) := KDF2(Ci, DH(sprivr, spubi))
@ -538,7 +641,7 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
// Ci : = KDF2(Ci, DH(epriv, epub)) // Ci : = KDF2(Ci, DH(epriv, epub))
ComputeHKDF2DH(ci, NULL, e_priv, e_remote); ComputeHKDF2DH(ci, NULL, e_priv, e_remote);
// Ci : = KDF2(Ci, DH(epriv, spub)) // Ci : = KDF2(Ci, DH(epriv, spub))
ComputeHKDF2DH(ci, NULL, e_priv, peer->s_remote_); ComputeHKDF2DH(ci, NULL, e_priv, peer->s_remote_.bytes);
// (Ci, T, K) := KDF3(Ci, Q) // (Ci, T, K) := KDF3(Ci, Q)
blake2s_hkdf(ci, sizeof(ci), t, sizeof(t), k, sizeof(k), peer->preshared_key_, sizeof(preshared_key_), ci, WG_HASH_LEN); blake2s_hkdf(ci, sizeof(ci), t, sizeof(t), k, sizeof(k), peer->preshared_key_, sizeof(preshared_key_), ci, WG_HASH_LEN);
// Hr := HASH(Hr || T) // Hr := HASH(Hr || T)
@ -548,11 +651,6 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
keypair = WgPeer::CreateNewKeypair(false, ci, remote_key_id, extbuf + WG_TIMESTAMP_LEN, extfield_size); keypair = WgPeer::CreateNewKeypair(false, ci, remote_key_id, extbuf + WG_TIMESTAMP_LEN, extfield_size);
if (keypair) { if (keypair) {
WG_ACQUIRE_LOCK(peer->mutex_);
peer->InsertKeypairInPeer_Locked(keypair);
peer->OnHandshakeAuthComplete();
WG_RELEASE_LOCK(peer->mutex_);
dst->sender_key_id = dev->InsertInKeyIdLookup(peer, keypair); dst->sender_key_id = dev->InsertInKeyIdLookup(peer, keypair);
size_t extfield_out_size = 0; size_t extfield_out_size = 0;
@ -560,8 +658,17 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
if (extfield_size) if (extfield_size)
extfield_out_size = peer->WriteHandshakeExtension(dst->empty_enc, keypair); extfield_out_size = peer->WriteHandshakeExtension(dst->empty_enc, keypair);
#endif // WITH_HANDSHAKE_EXT #endif // WITH_HANDSHAKE_EXT
uint32 orig_packet_size = packet->size;
packet->size = (unsigned)(sizeof(MessageHandshakeResponse) + extfield_out_size); packet->size = (unsigned)(sizeof(MessageHandshakeResponse) + extfield_out_size);
WG_ACQUIRE_LOCK(peer->mutex_);
peer->rx_bytes_ += orig_packet_size;
peer->tx_bytes_ += packet->size;
peer->InsertKeypairInPeer_Locked(keypair);
peer->OnHandshakeAuthComplete();
WG_RELEASE_LOCK(peer->mutex_);
// msg.empty := AEAD(K, 0, "", Hr) // msg.empty := AEAD(K, 0, "", Hr)
chacha20poly1305_encrypt(dst->empty_enc, dst->empty_enc, extfield_out_size, hi, sizeof(hi), 0, k); chacha20poly1305_encrypt(dst->empty_enc, dst->empty_enc, extfield_out_size, hi, sizeof(hi), 0, k);
// Hr := HASH(Hr || "") // Hr := HASH(Hr || "")
@ -624,6 +731,7 @@ WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packe
peer_and_keypair->second = keypair; peer_and_keypair->second = keypair;
WG_ACQUIRE_LOCK(peer->mutex_); WG_ACQUIRE_LOCK(peer->mutex_);
peer->rx_bytes_ += packet->size;
peer->InsertKeypairInPeer_Locked(keypair); peer->InsertKeypairInPeer_Locked(keypair);
WG_RELEASE_LOCK(peer->mutex_); WG_RELEASE_LOCK(peer->mutex_);
@ -651,6 +759,9 @@ void WgPeer::ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCo
if (!xchacha20poly1305_decrypt(cookie, src->cookie_enc, sizeof(src->cookie_enc), if (!xchacha20poly1305_decrypt(cookie, src->cookie_enc, sizeof(src->cookie_enc),
peer->sent_mac1_, sizeof(peer->sent_mac1_), src->nonce, peer->precomputed_cookie_key_)) peer->sent_mac1_, sizeof(peer->sent_mac1_), src->nonce, peer->precomputed_cookie_key_))
return; return;
WG_ACQUIRE_LOCK(peer->mutex_);
peer->rx_bytes_ += sizeof(MessageHandshakeCookie);
WG_RELEASE_LOCK(peer->mutex_);
peer->expect_cookie_reply_ = false; peer->expect_cookie_reply_ = false;
peer->has_mac2_cookie_ = true; peer->has_mac2_cookie_ = true;
peer->mac2_cookie_timestamp_ = OsGetMilliseconds(); peer->mac2_cookie_timestamp_ = OsGetMilliseconds();
@ -796,7 +907,7 @@ bool WgKeypairParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size
#endif // WITH_HANDSHAKE_EXT #endif // WITH_HANDSHAKE_EXT
static void ActualFreeKeypair(void *x) { static void WgKeypairDelayedDelete(void *x) {
WgKeypair *t = (WgKeypair*)x; WgKeypair *t = (WgKeypair*)x;
if (t->aes_gcm128_context_) if (t->aes_gcm128_context_)
free(t->aes_gcm128_context_); free(t->aes_gcm128_context_);
@ -808,17 +919,18 @@ void WgPeer::DeleteKeypair(WgKeypair **kp) {
*kp = NULL; *kp = NULL;
if (t) { if (t) {
assert(t->peer->IsPeerLocked()); assert(t->peer->IsPeerLocked());
WgDevice *dev = t->peer->dev_;
if (t->addr_entry) { if (t->addr_entry) {
WG_SCOPED_RWLOCK_EXCLUSIVE(dev_->addr_entry_lookup_lock_); WG_SCOPED_RWLOCK_EXCLUSIVE(dev->addr_entry_lookup_lock_);
dev_->EraseKeypairAddrEntry_Locked(t); dev->EraseKeypairAddrEntry_Locked(t);
} }
if (t->local_key_id) { if (t->local_key_id) {
WG_SCOPED_RWLOCK_EXCLUSIVE(dev_->key_id_lookup_lock_); WG_SCOPED_RWLOCK_EXCLUSIVE(dev->key_id_lookup_lock_);
dev_->key_id_lookup_.erase(t->local_key_id); dev->key_id_lookup_.erase(t->local_key_id);
t->local_key_id = 0; t->local_key_id = 0;
} }
t->recv_key_state = WgKeypair::KEY_INVALID; t->recv_key_state = WgKeypair::KEY_INVALID;
dev_->delayed_delete_.Add(&ActualFreeKeypair, t); dev->delayed_delete_.Add(&WgKeypairDelayedDelete, t);
} }
} }
@ -1029,8 +1141,8 @@ void WgPeer::OnHandshakeFullyComplete() {
} }
// Check if any of the timeouts have expired // Check if any of the timeouts have expired
uint32 WgPeer::CheckTimeouts(uint64 now) { uint32 WgPeer::CheckTimeouts_Locked(uint64 now) {
assert(IsPeerLocked()); assert(dev_->IsMainThread() && IsPeerLocked());
uint32 t, rv = 0; uint32 t, rv = 0;
@ -1096,7 +1208,7 @@ uint32 WgPeer::CheckTimeouts(uint64 now) {
// Check all key stuff here to avoid calling possibly expensive timestamp routines in the packet handler // Check all key stuff here to avoid calling possibly expensive timestamp routines in the packet handler
void WgPeer::CheckAndUpdateTimeOfNextKeyEvent(uint64 now) { void WgPeer::CheckAndUpdateTimeOfNextKeyEvent(uint64 now) {
assert(IsPeerLocked()); assert(dev_->IsMainThread() && IsPeerLocked());
uint64 next_time = UINT64_MAX; uint64 next_time = UINT64_MAX;
uint32 rv = 0; uint32 rv = 0;
@ -1142,34 +1254,60 @@ void WgPeer::SetEndpoint(const IpAddr &sin) {
endpoint_ = sin; endpoint_ = sin;
} }
void WgPeer::SetPersistentKeepalive(int persistent_keepalive_secs) { bool WgPeer::SetPersistentKeepalive(int persistent_keepalive_secs) {
if (persistent_keepalive_secs < 10 || persistent_keepalive_secs > 10000) if (persistent_keepalive_secs < 0 || persistent_keepalive_secs > 65535)
return; return false;
persistent_keepalive_ms_ = persistent_keepalive_secs * 1000; persistent_keepalive_ms_ = persistent_keepalive_secs * 1000;
return true;
}
bool WgCidrAddrEquals(const WgCidrAddr &a, const WgCidrAddr &b) {
return (a.size == b.size && a.cidr == b.cidr && memcmp(a.addr, b.addr, a.size >> 3) == 0);
} }
bool WgPeer::AddIp(const WgCidrAddr &cidr_addr) { bool WgPeer::AddIp(const WgCidrAddr &cidr_addr) {
WgPeer *old_peer;
assert(dev_->IsMainThread()); assert(dev_->IsMainThread());
if (cidr_addr.size == 32) { if (cidr_addr.size == 32) {
if (cidr_addr.cidr > 32) if (cidr_addr.cidr > 32)
return false; return false;
WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_); WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
dev_->ip_to_peer_map_.InsertV4(ReadBE32(cidr_addr.addr), cidr_addr.cidr, this); old_peer = (WgPeer*)dev_->ip_to_peer_map_.InsertV4(ReadBE32(cidr_addr.addr), cidr_addr.cidr, this);
WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_); WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
allowed_ips_.push_back(cidr_addr);
return true;
} else if (cidr_addr.size == 128) { } else if (cidr_addr.size == 128) {
if (cidr_addr.cidr > 128) if (cidr_addr.cidr > 128)
return false; return false;
WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_); WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
dev_->ip_to_peer_map_.InsertV6(cidr_addr.addr, cidr_addr.cidr, this); old_peer = (WgPeer*)dev_->ip_to_peer_map_.InsertV6(cidr_addr.addr, cidr_addr.cidr, this);
WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_); WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
allowed_ips_.push_back(cidr_addr);
return true;
} else { } else {
return false; return false;
} }
if (old_peer) {
for (auto it = old_peer->allowed_ips_.begin(); it != old_peer->allowed_ips_.end(); ++it) {
if (WgCidrAddrEquals(*it, cidr_addr)) {
old_peer->allowed_ips_.erase(it);
break;
}
}
}
allowed_ips_.push_back(cidr_addr);
return true;
}
void WgPeer::RemoveAllIps() {
assert(dev_->IsMainThread());
WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
for (auto it = allowed_ips_.begin(); it != allowed_ips_.end(); ++it) {
if (it->size == 32) {
dev_->ip_to_peer_map_.RemoveV4(ReadBE32(it->addr), it->cidr);
} else if (it->size == 128) {
dev_->ip_to_peer_map_.RemoveV6(it->addr, it->cidr);
}
}
WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_);
allowed_ips_.clear();
} }
void WgPeer::SetAllowMulticast(bool allow) { void WgPeer::SetAllowMulticast(bool allow) {
@ -1196,6 +1334,18 @@ bool WgPeer::AddCipher(int cipher) {
return true; return true;
} }
void WgPeer::ScheduleNewHandshake() {
// Note, it's possible that the peer has already been marked for delete
if (main_thread_scheduled_.fetch_or(WgPeer::kMainThreadScheduled_ScheduleHandshake) == 0) {
main_thread_scheduled_next_ = NULL;
WG_ACQUIRE_LOCK(dev_->main_thread_scheduled_lock_);
*dev_->main_thread_scheduled_last_ = this;
dev_->main_thread_scheduled_last_ = &main_thread_scheduled_next_;
WG_RELEASE_LOCK(dev_->main_thread_scheduled_lock_);
// todo: in multithreaded impl need to trigger |RunAllMainThreadScheduled| to get called
}
}
WgRateLimit::WgRateLimit() { WgRateLimit::WgRateLimit() {
key1_[0] = key1_[1] = 1; key1_[0] = key1_[1] = 1;
key2_[0] = key2_[1] = 1; key2_[0] = key2_[1] = 1;

View file

@ -11,7 +11,7 @@
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <atomic> #include <atomic>
#include <string.h>
// Threading macros that enable locks only in MT builds // Threading macros that enable locks only in MT builds
#if WITH_WG_THREADING #if WITH_WG_THREADING
#define WG_SCOPED_LOCK(name) ScopedLock scoped_lock(&name) #define WG_SCOPED_LOCK(name) ScopedLock scoped_lock(&name)
@ -25,6 +25,7 @@
#define WG_RELEASE_RWLOCK_EXCLUSIVE(name) name.ReleaseExclusive() #define WG_RELEASE_RWLOCK_EXCLUSIVE(name) name.ReleaseExclusive()
#define WG_SCOPED_RWLOCK_SHARED(name) ScopedLockShared scoped_lock(&name) #define WG_SCOPED_RWLOCK_SHARED(name) ScopedLockShared scoped_lock(&name)
#define WG_SCOPED_RWLOCK_EXCLUSIVE(name) ScopedLockExclusive scoped_lock(&name) #define WG_SCOPED_RWLOCK_EXCLUSIVE(name) ScopedLockExclusive scoped_lock(&name)
#define WG_IF_LOCKS_ENABLED_ELSE(expr, def) (expr)
#else // WITH_WG_THREADING #else // WITH_WG_THREADING
#define WG_SCOPED_LOCK(name) #define WG_SCOPED_LOCK(name)
#define WG_ACQUIRE_LOCK(name) #define WG_ACQUIRE_LOCK(name)
@ -37,6 +38,7 @@
#define WG_RELEASE_RWLOCK_EXCLUSIVE(name) #define WG_RELEASE_RWLOCK_EXCLUSIVE(name)
#define WG_SCOPED_RWLOCK_SHARED(name) #define WG_SCOPED_RWLOCK_SHARED(name)
#define WG_SCOPED_RWLOCK_EXCLUSIVE(name) #define WG_SCOPED_RWLOCK_EXCLUSIVE(name)
#define WG_IF_LOCKS_ENABLED_ELSE(expr, def) (def)
#endif // WITH_WG_THREADING #endif // WITH_WG_THREADING
enum ProtocolTimeouts { enum ProtocolTimeouts {
@ -77,6 +79,7 @@ enum MessageFieldSizes {
WG_MAC_LEN = 16, WG_MAC_LEN = 16,
WG_TIMESTAMP_LEN = 12, WG_TIMESTAMP_LEN = 12,
WG_SIPHASH_KEY_LEN = 16, WG_SIPHASH_KEY_LEN = 16,
WG_PUBLIC_KEY_LEN_BASE64 = 44,
}; };
enum { enum {
@ -194,11 +197,9 @@ struct WgPacketCompressionVer01 {
}; };
STATIC_ASSERT(sizeof(WgPacketCompressionVer01) == 24, WgPacketCompressionVer01_wrong_size); STATIC_ASSERT(sizeof(WgPacketCompressionVer01) == 24, WgPacketCompressionVer01_wrong_size);
struct WgKeypair; struct WgKeypair;
class WgPeer; class WgPeer;
class WgRateLimit { class WgRateLimit {
public: public:
WgRateLimit(); WgRateLimit();
@ -261,9 +262,25 @@ struct ScramblerSiphashKeys {
uint64 keys[4]; uint64 keys[4];
}; };
union WgPublicKey {
uint8 bytes[WG_PUBLIC_KEY_LEN];
uint64 u64[WG_PUBLIC_KEY_LEN / 8];
friend bool operator==(const WgPublicKey &a, const WgPublicKey &b) {
return memcmp(a.bytes, b.bytes, WG_PUBLIC_KEY_LEN) == 0;
}
};
struct WgPublicKeyHasher {
size_t operator()(const WgPublicKey&a) const {
uint64 rv = a.u64[0] ^ a.u64[1] ^ a.u64[2] ^ a.u64[3];
return (size_t)(rv ^ (rv >> 32));
}
};
class WgDevice { class WgDevice {
friend class WgPeer; friend class WgPeer;
friend class WireguardProcessor; friend class WireguardProcessor;
friend class WgConfig;
public: public:
// Can be used to customize the behavior of WgDevice // Can be used to customize the behavior of WgDevice
@ -278,12 +295,15 @@ public:
WgDevice(); WgDevice();
~WgDevice(); ~WgDevice();
// Initialize with the private key, precompute all internal keys etc. // Configure with the private key, precompute all internal keys etc.
void Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]); void SetPrivateKey(const uint8 private_key[WG_PUBLIC_KEY_LEN]);
// Create a new peer // Create a new peer
WgPeer *AddPeer(); WgPeer *AddPeer();
// Remove all peers
void RemoveAllPeers();
// Setup header obfuscation // Setup header obfuscation
void SetHeaderObfuscation(const char *key); void SetHeaderObfuscation(const char *key);
@ -303,17 +323,19 @@ public:
WgRateLimit *rate_limiter() { return &rate_limiter_; } WgRateLimit *rate_limiter() { return &rate_limiter_; }
std::unordered_map<uint64, WgAddrEntry*> &addr_entry_map() { return addr_entry_lookup_; } std::unordered_map<uint64, WgAddrEntry*> &addr_entry_map() { return addr_entry_lookup_; }
WgPacketCompressionVer01 *compression_header() { return &compression_header_; } WgPacketCompressionVer01 *compression_header() { return &compression_header_; }
bool is_private_key_initialized() { return is_private_key_initialized_; }
bool IsMainThread() { return CurrentThreadIdEquals(main_thread_id_); } bool IsMainThread() { return CurrentThreadIdEquals(main_thread_id_); }
void SetCurrentThreadAsMainThread() { main_thread_id_ = GetCurrentThreadId(); } bool IsMainOrDataThread() { return CurrentThreadIdEquals(main_thread_id_) || WG_IF_LOCKS_ENABLED_ELSE(delayed_delete_.enabled(), false); }
void SetDelegate(Delegate *del) { delegate_ = del; } void SetDelegate(Delegate *del) { delegate_ = del; }
private: private:
std::pair<WgPeer*, WgKeypair*> *LookupPeerInKeyIdLookup(uint32 key_id); std::pair<WgPeer*, WgKeypair*> *LookupPeerInKeyIdLookup(uint32 key_id);
WgKeypair *LookupKeypairByKeyId(uint32 key_id); WgKeypair *LookupKeypairByKeyId(uint32 key_id);
WgKeypair *LookupKeypairInAddrEntryMap(uint64 addr, uint32 slot); WgKeypair *LookupKeypairInAddrEntryMap(uint64 addr, uint32 slot);
// Return the peer matching the |public_key| or NULL // Return the peer matching the |public_key| or NULL
WgPeer *GetPeerFromPublicKey(uint8 public_key[WG_PUBLIC_KEY_LEN]); WgPeer *GetPeerFromPublicKey(const WgPublicKey &pubkey);
// Create a cookie by inspecting the source address of the |packet| // Create a cookie by inspecting the source address of the |packet|
void MakeCookie(uint8 cookie[WG_COOKIE_LEN], Packet *packet); void MakeCookie(uint8 cookie[WG_COOKIE_LEN], Packet *packet);
// Insert a new entry in |key_id_lookup_| // Insert a new entry in |key_id_lookup_|
@ -330,7 +352,7 @@ private:
WG_DECLARE_RWLOCK(ip_to_peer_map_lock_); WG_DECLARE_RWLOCK(ip_to_peer_map_lock_);
// For enumerating all peers // For enumerating all peers
WgPeer *peers_; WgPeer *peers_, **last_peer_ptr_;
// For hooking // For hooking
Delegate *delegate_; Delegate *delegate_;
@ -346,12 +368,22 @@ private:
std::unordered_map<uint64, WgAddrEntry*> addr_entry_lookup_; std::unordered_map<uint64, WgAddrEntry*> addr_entry_lookup_;
WG_DECLARE_RWLOCK(addr_entry_lookup_lock_); WG_DECLARE_RWLOCK(addr_entry_lookup_lock_);
// Mapping from peer id to peer. This may be accessed only from MT.
std::unordered_map<WgPublicKey, WgPeer*, WgPublicKeyHasher> peer_id_lookup_;
// Queue of things scheduled to run on the main thread.
WG_DECLARE_LOCK(main_thread_scheduled_lock_);
WgPeer *main_thread_scheduled_, **main_thread_scheduled_last_;
// Counter for generating new indices in |keypair_lookup_| // Counter for generating new indices in |keypair_lookup_|
uint8 next_rng_slot_; uint8 next_rng_slot_;
// Whether packet obfuscation is enabled // Whether packet obfuscation is enabled
bool header_obfuscation_; bool header_obfuscation_;
// Whether a private key has been setup for the device
bool is_private_key_initialized_;
ThreadId main_thread_id_; ThreadId main_thread_id_;
uint64 low_resolution_timestamp_; uint64 low_resolution_timestamp_;
@ -382,15 +414,16 @@ private:
class WgPeer { class WgPeer {
friend class WgDevice; friend class WgDevice;
friend class WireguardProcessor; friend class WireguardProcessor;
friend class WgConfig;
friend bool WgKeypairParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size_t data_size); friend bool WgKeypairParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size_t data_size);
friend void WgKeypairSetupCompressionExtension(WgKeypair *keypair, const WgPacketCompressionVer01 *remotec); friend void WgKeypairSetupCompressionExtension(WgKeypair *keypair, const WgPacketCompressionVer01 *remotec);
public: public:
explicit WgPeer(WgDevice *dev); explicit WgPeer(WgDevice *dev);
~WgPeer(); ~WgPeer();
void Initialize(const uint8 spub[WG_PUBLIC_KEY_LEN], const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]); void SetPublicKey(const WgPublicKey &spub);
void SetPresharedKey(const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]);
void SetPersistentKeepalive(int persistent_keepalive_secs); bool SetPersistentKeepalive(int persistent_keepalive_secs);
void SetEndpoint(const IpAddr &sin); void SetEndpoint(const IpAddr &sin);
void SetAllowMulticast(bool allow); void SetAllowMulticast(bool allow);
@ -398,15 +431,14 @@ public:
bool AddCipher(int cipher); bool AddCipher(int cipher);
void SetCipherPrio(bool prio) { cipher_prio_ = prio; } void SetCipherPrio(bool prio) { cipher_prio_ = prio; }
bool AddIp(const WgCidrAddr &cidr_addr); bool AddIp(const WgCidrAddr &cidr_addr);
void RemoveAllIps();
static WgPeer *ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet); static WgPeer *ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet);
static WgPeer *ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packet); static WgPeer *ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packet);
static void ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCookie *src); static void ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCookie *src);
void CreateMessageHandshakeInitiation(Packet *packet); void CreateMessageHandshakeInitiation(Packet *packet);
bool CheckSwitchToNextKey_Locked(WgKeypair *keypair); bool CheckSwitchToNextKey_Locked(WgKeypair *keypair);
void ClearKeys_Locked(); void RemovePeer();
void ClearHandshake_Locked();
void ClearPacketQueue_Locked();
bool CheckHandshakeRateLimit(); bool CheckHandshakeRateLimit();
// Timer notifications // Timer notifications
@ -422,24 +454,24 @@ public:
ACTION_SEND_KEEPALIVE = 1, ACTION_SEND_KEEPALIVE = 1,
ACTION_SEND_HANDSHAKE = 2, ACTION_SEND_HANDSHAKE = 2,
}; };
uint32 CheckTimeouts(uint64 now); uint32 CheckTimeouts_Locked(uint64 now);
void AddPacketToPeerQueue(Packet *packet); void AddPacketToPeerQueue_Locked(Packet *packet);
bool IsPeerLocked() { return WG_IF_LOCKS_ENABLED_ELSE(mutex_.IsLocked(), true); }
#if WITH_WG_THREADING
bool IsPeerLocked() { return mutex_.IsLocked(); }
#else // WITH_WG_THREADING
bool IsPeerLocked() { return true; }
#endif // WITH_WG_THREADING
private: private:
static WgKeypair *CreateNewKeypair(bool is_initiator, const uint8 key[WG_HASH_LEN], uint32 send_key_id, const uint8 *extfield, size_t extfield_size); static WgKeypair *CreateNewKeypair(bool is_initiator, const uint8 key[WG_HASH_LEN], uint32 send_key_id, const uint8 *extfield, size_t extfield_size);
void WriteMacToPacket(const uint8 *data, MessageMacs *mac); void WriteMacToPacket(const uint8 *data, MessageMacs *mac);
void DeleteKeypair(WgKeypair **kp);
void CheckAndUpdateTimeOfNextKeyEvent(uint64 now); void CheckAndUpdateTimeOfNextKeyEvent(uint64 now);
static void DeleteKeypair(WgKeypair **kp);
static void CopyEndpointToPeer_Locked(WgKeypair *keypair, const IpAddr *addr); static void CopyEndpointToPeer_Locked(WgKeypair *keypair, const IpAddr *addr);
static void DelayedDelete(void *x);
size_t WriteHandshakeExtension(uint8 *dst, WgKeypair *keypair); size_t WriteHandshakeExtension(uint8 *dst, WgKeypair *keypair);
void InsertKeypairInPeer_Locked(WgKeypair *keypair); void InsertKeypairInPeer_Locked(WgKeypair *keypair);
void ClearKeys_Locked();
void ClearHandshake_Locked();
void ClearPacketQueue_Locked();
void ScheduleNewHandshake();
WgDevice *dev_; WgDevice *dev_;
WgPeer *next_peer_; WgPeer *next_peer_;
@ -492,6 +524,10 @@ private:
// Whether |mac2_cookie_| is valid. // Whether |mac2_cookie_| is valid.
bool has_mac2_cookie_; bool has_mac2_cookie_;
// Whether the WgPeer has been deleted (i.e. RemovePeer has been called),
// and will be deleted as soon as the threads sync.
bool marked_for_delete_;
// Number of handshakes made so far, when this gets too high we stop connecting. // Number of handshakes made so far, when this gets too high we stop connecting.
uint8 handshake_attempts_; uint8 handshake_attempts_;
@ -518,6 +554,9 @@ private:
uint8 num_ciphers_; uint8 num_ciphers_;
uint8 ciphers_[MAX_CIPHERS]; uint8 ciphers_[MAX_CIPHERS];
uint64 rx_bytes_;
uint64 tx_bytes_;
// Handshake state that gets setup in |CreateMessageHandshakeInitiation| and used in // Handshake state that gets setup in |CreateMessageHandshakeInitiation| and used in
// the response. // the response.
struct HandshakeState { struct HandshakeState {
@ -530,7 +569,7 @@ private:
}; };
HandshakeState hs_; HandshakeState hs_;
// Remote's static public key - init only. // Remote's static public key - init only.
uint8 s_remote_[WG_PUBLIC_KEY_LEN]; WgPublicKey s_remote_;
// Remote's preshared key - init only. // Remote's preshared key - init only.
uint8 preshared_key_[WG_SYMMETRIC_KEY_LEN]; uint8 preshared_key_[WG_SYMMETRIC_KEY_LEN];
// Precomputed DH(spriv_local, spub_remote) - init only. // Precomputed DH(spriv_local, spub_remote) - init only.