Experimental support for WireGuard over TCP
This commit is contained in:
parent
9a8acb7091
commit
a03980e74b
|
@ -185,10 +185,24 @@
|
|||
<ClInclude Include="crypto\blake2s\blake2s-sse-impl.h" />
|
||||
<ClInclude Include="crypto\curve25519\curve25519-donna.h" />
|
||||
<ClInclude Include="ip_to_peer_map.h" />
|
||||
<ClInclude Include="network_bsd.h">
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
|
||||
</ClInclude>
|
||||
<ClInclude Include="network_common.h" />
|
||||
<ClInclude Include="network_win32_tcp.h" />
|
||||
<ClInclude Include="service_pipe_win32.h" />
|
||||
<ClInclude Include="service_win32.h" />
|
||||
<ClInclude Include="service_win32_api.h" />
|
||||
<ClInclude Include="service_win32_constants.h" />
|
||||
<ClInclude Include="tunsafe_bsd.h">
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
|
||||
</ClInclude>
|
||||
<ClInclude Include="tunsafe_config.h" />
|
||||
<ClInclude Include="tunsafe_cpu.h" />
|
||||
<ClInclude Include="crypto\aesgcm\aes.h" />
|
||||
|
@ -215,8 +229,28 @@
|
|||
<ItemGroup>
|
||||
<ClCompile Include="benchmark.cpp" />
|
||||
<ClCompile Include="ip_to_peer_map.cpp" />
|
||||
<ClCompile Include="network_bsd.cpp">
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
|
||||
</ClCompile>
|
||||
<ClCompile Include="network_bsd_mt.cpp">
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
|
||||
</ClCompile>
|
||||
<ClCompile Include="network_common.cpp" />
|
||||
<ClCompile Include="network_win32_tcp.cpp" />
|
||||
<ClCompile Include="service_pipe_win32.cpp" />
|
||||
<ClCompile Include="service_win32.cpp" />
|
||||
<ClCompile Include="tunsafe_bsd.cpp">
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
|
||||
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
|
||||
</ClCompile>
|
||||
<ClCompile Include="tunsafe_cpu.cpp" />
|
||||
<ClCompile Include="crypto\aesgcm\aesgcm.cpp" />
|
||||
<ClCompile Include="crypto\siphash\siphash.cpp" />
|
||||
|
|
|
@ -23,6 +23,12 @@
|
|||
<Filter Include="crypto\chacha20poly1305">
|
||||
<UniqueIdentifier>{1ca37c7b-e91e-4648-9584-7d0c73d8e416}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="Source Files\BSD">
|
||||
<UniqueIdentifier>{4b2f2fd9-780e-45db-8fe1-f03079439723}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="crypto\siphash">
|
||||
<UniqueIdentifier>{0f45e1a0-f33e-4c6e-88ae-eb4639f12041}</UniqueIdentifier>
|
||||
</Filter>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="stdafx.h">
|
||||
|
@ -92,7 +98,6 @@
|
|||
<ClInclude Include="service_win32_constants.h">
|
||||
<Filter>Source Files\Win32</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="crypto\siphash\siphash.h" />
|
||||
<ClInclude Include="crypto\blake2s\blake2s.h">
|
||||
<Filter>crypto\blake2s</Filter>
|
||||
</ClInclude>
|
||||
|
@ -123,6 +128,21 @@
|
|||
<ClInclude Include="tunsafe_dnsresolve.h">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="crypto\siphash\siphash.h">
|
||||
<Filter>crypto\siphash</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="network_bsd.h">
|
||||
<Filter>Source Files\BSD</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="tunsafe_bsd.h">
|
||||
<Filter>Source Files\BSD</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="network_common.h">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="network_win32_tcp.h">
|
||||
<Filter>Source Files\Win32</Filter>
|
||||
</ClInclude>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="stdafx.cpp">
|
||||
|
@ -173,9 +193,6 @@
|
|||
<ClCompile Include="service_pipe_win32.cpp">
|
||||
<Filter>Source Files\Win32</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="crypto\siphash\siphash.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="crypto\blake2s\blake2s.cpp">
|
||||
<Filter>crypto\blake2s</Filter>
|
||||
</ClCompile>
|
||||
|
@ -188,6 +205,24 @@
|
|||
<ClCompile Include="tunsafe_ipaddr.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="network_bsd.cpp">
|
||||
<Filter>Source Files\BSD</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="network_bsd_mt.cpp">
|
||||
<Filter>Source Files\BSD</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="crypto\siphash\siphash.cpp">
|
||||
<Filter>crypto\siphash</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="tunsafe_bsd.cpp">
|
||||
<Filter>Source Files\BSD</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="network_common.cpp">
|
||||
<Filter>Source Files</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="network_win32_tcp.cpp">
|
||||
<Filter>Source Files\Win32</Filter>
|
||||
</ClCompile>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ResourceCompile Include="TunSafe.rc" />
|
||||
|
|
104
docs/WireGuard TCP.txt
Normal file
104
docs/WireGuard TCP.txt
Normal file
|
@ -0,0 +1,104 @@
|
|||
WireGuard over TCP
|
||||
------------------
|
||||
|
||||
We hate running one TCP implementation on top of another TCP implementation.
|
||||
There's problems with cascading retransmissions and head of line blocking,
|
||||
and performance is always much worse than a UDP based tunnel.
|
||||
|
||||
However, we also recognize that several users need to run WireGuard over TCP.
|
||||
One reason is that UDP packets are sometimes blocked by the network in
|
||||
corporate scenarios or in other types of firewalls. Also, in misconfigured
|
||||
networks outside of the user's control, TCP may be more reliable than UDP.
|
||||
|
||||
Additionally, we want TunSafe to be a drop-in replacement for OpenVPN, which
|
||||
also supports TCP based tunneling. The feature could also be used to run
|
||||
WireGuard tunnels over ssh tunnels, or through socks/https proxies.
|
||||
|
||||
The TunSafe project therefore takes the pragmatic approach of supporting
|
||||
WireGuard over TCP, while discouraging its use. We absolutely don't want
|
||||
people to start using TCP by default. It's meant to be used only in the
|
||||
extreme cases when nothing else is working.
|
||||
|
||||
We've added experimental support for TCP in the latest TunSafe master,
|
||||
which means you can try this out on Windows, OSX, Linux, and FreeBSD.
|
||||
On the server side, to listen on a TCP port, use ListenPortTCP=1234. (Not
|
||||
working on Windows yet). On the clients, use Endpoint=tcp://5.5.5.5:1234.
|
||||
The code is still very experimental and untested, and is not recommended
|
||||
for general use. Once the code is more well tested, we'll also release
|
||||
support for connecting to WireGuard over TCP in our Android and iOS clients.
|
||||
|
||||
To make the impact as small as possible to our WireGuard protocol handling,
|
||||
and to minimize the risk of security related issues, the TCP feature has been
|
||||
designed to be as self-contained as possible. When a packet comes in over
|
||||
TCP, it's sent over to the WireGuard protocol handler and treated as if it
|
||||
was a UDP packet, and vice versa. This means TCP support can also be supported
|
||||
in existing WireGuard deployments by using a separate process that converts
|
||||
TCP connections into UDP packets sent to the WireGuard Linux kernel module.
|
||||
|
||||
Each packet over TCP is prefixed by a 2-byte big endian number, which contains
|
||||
the length of the packet's payload. The payload is then the actual WireGuard
|
||||
UDP packet.
|
||||
|
||||
TCP has larger overhead than UDP, and we want to support the usual WireGuard
|
||||
MTU of 1420 without introducing extra packet "fragmenting". So we implemented
|
||||
an optimization to skip sending the 16-byte WireGuard header for every packet.
|
||||
TCP is a reliable connection, we know that sequence numbers are always
|
||||
monotonically increasing, so we can predict the contents of this header.
|
||||
|
||||
Here's an example:
|
||||
A 1420 byte big packet sent over a WireGuard link will have 2 bytes of
|
||||
TCP payload length, 16 bytes of WireGuard headers, 16 bytes of WireGuard MAC,
|
||||
20 bytes of TCP headers, and 40 bytes of IPv6 headers.
|
||||
This is a total of 1420 + 2 + 16 + 16 + 20 + 40 = 1514 bytes, exceeding
|
||||
the usual 1500 byte Ethernet MTU by 14 bytes. This means that a single full
|
||||
sized packet over WireGuard will result in 2 TCP packets. With our
|
||||
optimization, we reduce this to 1498 bytes, so it fits in one TCP packet.
|
||||
|
||||
Protocol specification
|
||||
----------------------
|
||||
|
||||
TT LLLLLL LLLLLLLL [Payload LL bytes]
|
||||
| |
|
||||
| \-- Payload length, high byte first.
|
||||
\----- Packet type
|
||||
|
||||
The packet types (TT) currently defined are:
|
||||
TT = 00 = Normal The payload is a normal unmodified WireGuard packet
|
||||
including the regular WireGuard header.
|
||||
01 = Reserved
|
||||
10 = Data A WireGuard data packet (type 04) without the 16 byte
|
||||
header. The predicted header is prefixed to the payload.
|
||||
11 = Control A TCP control packet. Currently this is used only to setup
|
||||
the header prediction. See below.
|
||||
|
||||
There's only one defined Control packet, type 00 (SetKeyAndCounter):
|
||||
|
||||
0 1 2 3
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
|1 1| Length is 13 (14 bits) | 00 (8 bits) |
|
||||
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
| Key ID (32 bits) |
|
||||
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
| Counter (64 bits) ...
|
||||
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
|
||||
This sets up the Key ID and Counter used for the Data packets. Then Counter
|
||||
is incremented by 1 for every such packet.
|
||||
|
||||
For every Data packet, the predicted Key ID and Counter is expanded to a
|
||||
regular WireGuard data (type 04) header, which is prefixed to the payload:
|
||||
|
||||
0 1 2 3
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
| 04 (8 bits) | Reserved (24 bits) |
|
||||
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
| Key ID (32 bits) |
|
||||
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
| Counter (64 bits) ...
|
||||
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
| Data Payload (LL * 8 bits) ...
|
||||
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
|
||||
This happens independently in each of the two TCP directions.
|
18
netapi.h
18
netapi.h
|
@ -60,21 +60,29 @@ struct Packet : QueuedItem {
|
|||
uint8 protocol; // which protocol is this packet for/from
|
||||
IpAddr addr; // Optionally set to target/source of the packet
|
||||
|
||||
byte data_pre[4];
|
||||
byte data_buf[0];
|
||||
enum {
|
||||
// there's always this much data before data_ptr
|
||||
// there's always this much data before data_buf, to allow for header expansion
|
||||
// in front.
|
||||
HEADROOM_BEFORE = 64,
|
||||
};
|
||||
|
||||
byte data_pre[HEADROOM_BEFORE];
|
||||
byte data_buf[0];
|
||||
|
||||
void Reset() {
|
||||
data = data_buf;
|
||||
size = 0;
|
||||
}
|
||||
};
|
||||
|
||||
enum {
|
||||
kPacketAllocSize = 2048 - 16,
|
||||
kPacketCapacity = kPacketAllocSize - sizeof(Packet) - Packet::HEADROOM_BEFORE,
|
||||
kPacketCapacity = kPacketAllocSize - sizeof(Packet),
|
||||
};
|
||||
|
||||
void FreePacket(Packet *packet);
|
||||
void FreePackets(Packet *packet, Packet **end, int count);
|
||||
void FreePacketList(Packet *packet);
|
||||
Packet *AllocPacket();
|
||||
void FreeAllPackets();
|
||||
|
||||
|
@ -123,7 +131,7 @@ public:
|
|||
|
||||
class UdpInterface {
|
||||
public:
|
||||
virtual bool Configure(int listen_port) = 0;
|
||||
virtual bool Configure(int listen_port_udp, int listen_port_tcp) = 0;
|
||||
virtual void WriteUdpPacket(Packet *packet) = 0;
|
||||
};
|
||||
|
||||
|
|
1235
network_bsd.cpp
1235
network_bsd.cpp
File diff suppressed because it is too large
Load diff
304
network_bsd.h
Normal file
304
network_bsd.h
Normal file
|
@ -0,0 +1,304 @@
|
|||
// SPDX-License-Identifier: AGPL-1.0-only
|
||||
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
|
||||
#ifndef TUNSAFE_NETWORK_BSD_H_
|
||||
#define TUNSAFE_NETWORK_BSD_H_
|
||||
|
||||
#include <poll.h>
|
||||
#include <sys/un.h>
|
||||
#include <sys/uio.h>
|
||||
#include <string>
|
||||
#include "network_common.h"
|
||||
|
||||
class BaseSocketBsd;
|
||||
class TcpSocketBsd;
|
||||
class WireguardProcessor;
|
||||
class Packet;
|
||||
|
||||
class NetworkBsd {
|
||||
friend class BaseSocketBsd;
|
||||
friend class TcpSocketBsd;
|
||||
friend class UdpSocketBsd;
|
||||
friend class TunSocketBsd;
|
||||
public:
|
||||
enum {
|
||||
#if defined(OS_ANDROID)
|
||||
WithSigalarmSupport = 0,
|
||||
#else
|
||||
WithSigalarmSupport = 1
|
||||
#endif
|
||||
};
|
||||
|
||||
class NetworkBsdDelegate {
|
||||
public:
|
||||
virtual void OnSecondLoop(uint64 now) {}
|
||||
virtual void RunAllMainThreadScheduled() {}
|
||||
};
|
||||
|
||||
explicit NetworkBsd(NetworkBsdDelegate *delegate, int max_sockets);
|
||||
~NetworkBsd();
|
||||
|
||||
void RunLoop(const sigset_t *sigmask);
|
||||
void PostExit() { exit_ = true; }
|
||||
|
||||
bool *exit_flag() { return &exit_; }
|
||||
bool *sigalarm_flag() { return &sigalarm_flag_; }
|
||||
|
||||
TcpSocketBsd *tcp_sockets() { return tcp_sockets_; }
|
||||
bool overload() { return overload_; }
|
||||
private:
|
||||
void RemoveFromRoundRobin(int slot);
|
||||
|
||||
void ReallocateIov(size_t i);
|
||||
void EnsureIovAllocated();
|
||||
|
||||
Packet *read_packet_;
|
||||
bool exit_;
|
||||
bool overload_;
|
||||
bool sigalarm_flag_;
|
||||
|
||||
enum {
|
||||
// This controls the max # of sockets we can support
|
||||
kMaxIovec = 16,
|
||||
};
|
||||
int num_sock_;
|
||||
int num_roundrobin_;
|
||||
int num_endloop_;
|
||||
int max_sockets_;
|
||||
|
||||
SimplePacketPool packet_pool_;
|
||||
NetworkBsdDelegate *delegate_;
|
||||
|
||||
struct pollfd *pollfd_;
|
||||
BaseSocketBsd **sockets_;
|
||||
BaseSocketBsd **roundrobin_;
|
||||
BaseSocketBsd **endloop_;
|
||||
|
||||
// Linked list of all tcp sockets
|
||||
TcpSocketBsd *tcp_sockets_;
|
||||
|
||||
struct iovec iov_[kMaxIovec];
|
||||
Packet *iov_packets_[kMaxIovec];
|
||||
|
||||
};
|
||||
|
||||
class BaseSocketBsd {
|
||||
friend class NetworkBsd;
|
||||
public:
|
||||
BaseSocketBsd(NetworkBsd *network) : pollfd_slot_(-1), roundrobin_slot_(-1), endloop_slot_(-1), fd_(-1), network_(network) {}
|
||||
virtual ~BaseSocketBsd();
|
||||
|
||||
virtual void HandleEvents(int revents) = 0;
|
||||
|
||||
// Return |false| to remove socket from roundrobin list.
|
||||
virtual bool DoRoundRobin() { return false; }
|
||||
virtual void DoEndloop() {}
|
||||
virtual void Periodic() {}
|
||||
|
||||
// Make sure this socket gets called during each round robin step.
|
||||
void AddToRoundRobin();
|
||||
|
||||
// Make sure this sockets get called at the end of the loop
|
||||
void AddToEndLoop();
|
||||
|
||||
int GetFd() { return fd_; }
|
||||
|
||||
protected:
|
||||
void SetPollFlags(int events) {
|
||||
network_->pollfd_[pollfd_slot_].events = events;
|
||||
}
|
||||
void InitPollSlot(int fd, int events);
|
||||
bool HasFreePollSlot() { return network_->num_sock_ != network_->max_sockets_; }
|
||||
void CloseSocket();
|
||||
|
||||
NetworkBsd *network_;
|
||||
int pollfd_slot_;
|
||||
int roundrobin_slot_;
|
||||
int endloop_slot_;
|
||||
int fd_;
|
||||
};
|
||||
|
||||
class TunSocketBsd : public BaseSocketBsd {
|
||||
public:
|
||||
explicit TunSocketBsd(NetworkBsd *network, WireguardProcessor *processor);
|
||||
virtual ~TunSocketBsd();
|
||||
|
||||
bool Initialize(int fd);
|
||||
|
||||
virtual void HandleEvents(int revents) override;
|
||||
virtual bool DoRoundRobin() override;
|
||||
|
||||
void WritePacket(Packet *packet);
|
||||
|
||||
bool tun_interface_gone() const { return tun_interface_gone_; }
|
||||
|
||||
private:
|
||||
bool DoRead();
|
||||
bool DoWrite();
|
||||
|
||||
bool tun_readable_, tun_writable_;
|
||||
bool tun_interface_gone_;
|
||||
Packet *tun_queue_, **tun_queue_end_;
|
||||
WireguardProcessor *processor_;
|
||||
};
|
||||
|
||||
class UdpSocketBsd : public BaseSocketBsd {
|
||||
public:
|
||||
explicit UdpSocketBsd(NetworkBsd *network, WireguardProcessor *processor);
|
||||
virtual ~UdpSocketBsd();
|
||||
|
||||
bool Initialize(int listen_port);
|
||||
|
||||
virtual void HandleEvents(int revents) override;
|
||||
virtual bool DoRoundRobin() override;
|
||||
|
||||
bool DoRead();
|
||||
bool DoWrite();
|
||||
|
||||
void WritePacket(Packet *packet);
|
||||
|
||||
private:
|
||||
bool udp_readable_, udp_writable_;
|
||||
Packet *udp_queue_, **udp_queue_end_;
|
||||
WireguardProcessor *processor_;
|
||||
};
|
||||
|
||||
#if defined(OS_LINUX)
|
||||
// Keeps track of when the unix socket gets deleted
|
||||
class UnixSocketDeletionWatcher {
|
||||
public:
|
||||
UnixSocketDeletionWatcher();
|
||||
~UnixSocketDeletionWatcher();
|
||||
bool Start(const char *path, bool *flag_to_set);
|
||||
void Stop();
|
||||
bool Poll(const char *path) { return false; }
|
||||
|
||||
private:
|
||||
static void *RunThread(void *arg);
|
||||
void *RunThreadInner();
|
||||
const char *path_;
|
||||
int inotify_fd_;
|
||||
int pid_;
|
||||
int pipes_[2];
|
||||
pthread_t thread_;
|
||||
bool *flag_to_set_;
|
||||
};
|
||||
#else // !defined(OS_LINUX)
|
||||
// all other platforms that lack inotify
|
||||
class UnixSocketDeletionWatcher {
|
||||
public:
|
||||
UnixSocketDeletionWatcher() {}
|
||||
~UnixSocketDeletionWatcher() {}
|
||||
bool Start(const char *path, bool *flag_to_set) { return true; }
|
||||
void Stop() {}
|
||||
bool Poll(const char *path);
|
||||
};
|
||||
#endif // !defined(OS_LINUX)
|
||||
|
||||
class UnixDomainSocketListenerBsd : public BaseSocketBsd {
|
||||
public:
|
||||
explicit UnixDomainSocketListenerBsd(NetworkBsd *network, WireguardProcessor *processor);
|
||||
virtual ~UnixDomainSocketListenerBsd();
|
||||
|
||||
bool Initialize(const char *devname);
|
||||
|
||||
bool Start(bool *exit_flag) {
|
||||
return un_deletion_watcher_.Start(un_addr_.sun_path, exit_flag);
|
||||
}
|
||||
void Stop() { un_deletion_watcher_.Stop(); }
|
||||
|
||||
virtual void HandleEvents(int revents) override;
|
||||
virtual void Periodic() override;
|
||||
private:
|
||||
struct sockaddr_un un_addr_;
|
||||
WireguardProcessor *processor_;
|
||||
UnixSocketDeletionWatcher un_deletion_watcher_;
|
||||
};
|
||||
|
||||
class UnixDomainSocketChannelBsd : public BaseSocketBsd {
|
||||
public:
|
||||
explicit UnixDomainSocketChannelBsd(NetworkBsd *network, WireguardProcessor *processor, int fd);
|
||||
virtual ~UnixDomainSocketChannelBsd();
|
||||
|
||||
virtual void HandleEvents(int revents) override;
|
||||
|
||||
private:
|
||||
bool HandleEventsInner(int revents);
|
||||
WireguardProcessor *processor_;
|
||||
std::string inbuf_, outbuf_;
|
||||
};
|
||||
|
||||
class TcpSocketListenerBsd : public BaseSocketBsd {
|
||||
public:
|
||||
explicit TcpSocketListenerBsd(NetworkBsd *bsd, WireguardProcessor *processor);
|
||||
virtual ~TcpSocketListenerBsd();
|
||||
|
||||
bool Initialize(int port);
|
||||
|
||||
virtual void HandleEvents(int revents) override;
|
||||
virtual void Periodic() override;
|
||||
|
||||
private:
|
||||
WireguardProcessor *processor_;
|
||||
};
|
||||
|
||||
class TcpSocketBsd : public BaseSocketBsd {
|
||||
public:
|
||||
explicit TcpSocketBsd(NetworkBsd *bsd, WireguardProcessor *processor);
|
||||
virtual ~TcpSocketBsd();
|
||||
|
||||
void InitializeIncoming(int fd, const IpAddr &addr);
|
||||
bool InitializeOutgoing(const IpAddr &addr);
|
||||
|
||||
void WritePacket(Packet *packet);
|
||||
|
||||
virtual void HandleEvents(int revents) override;
|
||||
virtual void DoEndloop() override;
|
||||
|
||||
TcpSocketBsd *next() { return next_; }
|
||||
uint8 endpoint_protocol() { return endpoint_protocol_; }
|
||||
const IpAddr &endpoint() { return endpoint_; }
|
||||
|
||||
public:
|
||||
uint8 age;
|
||||
uint8 handshake_attempts;
|
||||
private:
|
||||
void DoRead();
|
||||
void DoWrite();
|
||||
void CloseSocketAndDestroy();
|
||||
|
||||
bool readable_, writable_;
|
||||
bool got_eof_;
|
||||
uint8 endpoint_protocol_;
|
||||
bool want_connect_;
|
||||
|
||||
uint32 wqueue_bytes_;
|
||||
Packet *wqueue_, **wqueue_end_;
|
||||
TcpSocketBsd *next_;
|
||||
WireguardProcessor *processor_;
|
||||
TcpPacketHandler tcp_packet_handler_;
|
||||
IpAddr endpoint_;
|
||||
};
|
||||
|
||||
class NotificationPipeBsd : public BaseSocketBsd {
|
||||
public:
|
||||
NotificationPipeBsd(NetworkBsd *network);
|
||||
~NotificationPipeBsd();
|
||||
|
||||
typedef void CallbackFunc(void *x);
|
||||
void InjectCallback(CallbackFunc *func, void *param);
|
||||
void Wakeup();
|
||||
|
||||
virtual void HandleEvents(int revents) override;
|
||||
|
||||
private:
|
||||
struct CallbackState {
|
||||
CallbackFunc *func;
|
||||
void *param;
|
||||
CallbackState *next;
|
||||
};
|
||||
int pipe_fds_[2];
|
||||
std::atomic<CallbackState*> injected_cb_;
|
||||
};
|
||||
|
||||
|
||||
#endif // TUNSAFE_NETWORK_BSD_H_
|
|
@ -1,101 +0,0 @@
|
|||
// SPDX-License-Identifier: AGPL-1.0-only
|
||||
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
|
||||
#ifndef TUNSAFE_NETWORK_BSD_COMMON_H_
|
||||
#define TUNSAFE_NETWORK_BSD_COMMON_H_
|
||||
|
||||
#include "netapi.h"
|
||||
#include "wireguard.h"
|
||||
#include "wireguard_config.h"
|
||||
#include <string>
|
||||
#include <signal.h>
|
||||
|
||||
struct RouteInfo {
|
||||
uint8 family;
|
||||
uint8 cidr;
|
||||
uint8 ip[16];
|
||||
uint8 gw[16];
|
||||
std::string dev;
|
||||
};
|
||||
|
||||
#if defined(OS_LINUX)
|
||||
// Keeps track of when the unix socket gets deleted
|
||||
class UnixSocketDeletionWatcher {
|
||||
public:
|
||||
UnixSocketDeletionWatcher();
|
||||
~UnixSocketDeletionWatcher();
|
||||
bool Start(const char *path, bool *flag_to_set);
|
||||
void Stop();
|
||||
bool Poll(const char *path) { return false; }
|
||||
|
||||
private:
|
||||
static void *RunThread(void *arg);
|
||||
void *RunThreadInner();
|
||||
const char *path_;
|
||||
int inotify_fd_;
|
||||
int pid_;
|
||||
int pipes_[2];
|
||||
pthread_t thread_;
|
||||
bool *flag_to_set_;
|
||||
};
|
||||
#else // !defined(OS_LINUX)
|
||||
// all other platforms that lack inotify
|
||||
class UnixSocketDeletionWatcher {
|
||||
public:
|
||||
UnixSocketDeletionWatcher() {}
|
||||
~UnixSocketDeletionWatcher() {}
|
||||
bool Start(const char *path, bool *flag_to_set) { return true; }
|
||||
void Stop() {}
|
||||
bool Poll(const char *path);
|
||||
};
|
||||
#endif // !defined(OS_LINUX)
|
||||
|
||||
|
||||
class TunsafeBackendBsd : public TunInterface, public UdpInterface {
|
||||
public:
|
||||
TunsafeBackendBsd();
|
||||
virtual ~TunsafeBackendBsd();
|
||||
|
||||
void RunLoop();
|
||||
void CleanupRoutes();
|
||||
|
||||
void SetTunDeviceName(const char *name);
|
||||
|
||||
void SetProcessor(WireguardProcessor *wg) { processor_ = wg; }
|
||||
|
||||
// -- from TunInterface
|
||||
virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
|
||||
|
||||
virtual void HandleSigAlrm() = 0;
|
||||
virtual void HandleExit() = 0;
|
||||
|
||||
protected:
|
||||
virtual bool InitializeTun(char devname[16]) = 0;
|
||||
virtual void RunLoopInner() = 0;
|
||||
|
||||
void AddRoute(uint32 ip, uint32 cidr, uint32 gw, const char *dev);
|
||||
void DelRoute(const RouteInfo &cd);
|
||||
bool AddRoute(int family, const void *dest, int dest_prefix, const void *gateway, const char *dev);
|
||||
bool RunPrePostCommand(const std::vector<std::string> &vec);
|
||||
|
||||
WireguardProcessor *processor_;
|
||||
std::vector<RouteInfo> cleanup_commands_;
|
||||
std::vector<std::string> pre_down_, post_down_;
|
||||
std::vector<WgCidrAddr> addresses_to_remove_;
|
||||
sigset_t orig_signal_mask_;
|
||||
char devname_[16];
|
||||
bool tun_interface_gone_;
|
||||
};
|
||||
|
||||
#if defined(OS_MACOSX) || defined(OS_FREEBSD)
|
||||
#define TUN_PREFIX_BYTES 4
|
||||
#elif defined(OS_LINUX)
|
||||
#define TUN_PREFIX_BYTES 0
|
||||
#endif
|
||||
|
||||
int open_tun(char *devname, size_t devname_size);
|
||||
int open_udp(int listen_on_port);
|
||||
|
||||
void SetThreadName(const char *name);
|
||||
TunsafeBackendBsd *CreateTunsafeBackendBsd();
|
||||
|
||||
#endif // TUNSAFE_NETWORK_BSD_COMMON_H_
|
174
network_common.cpp
Normal file
174
network_common.cpp
Normal file
|
@ -0,0 +1,174 @@
|
|||
#include "stdafx.h"
|
||||
#include "network_common.h"
|
||||
#include "netapi.h"
|
||||
#include "tunsafe_endian.h"
|
||||
#include <assert.h>
|
||||
#include <algorithm>
|
||||
#include "util.h"
|
||||
|
||||
TcpPacketHandler::TcpPacketHandler(SimplePacketPool *packet_pool) {
|
||||
packet_pool_ = packet_pool;
|
||||
rqueue_bytes_ = 0;
|
||||
error_flag_ = false;
|
||||
rqueue_ = NULL;
|
||||
rqueue_end_ = &rqueue_;
|
||||
predicted_key_in_ = predicted_key_out_ = 0;
|
||||
predicted_serial_in_ = predicted_serial_out_ = 0;
|
||||
}
|
||||
|
||||
TcpPacketHandler::~TcpPacketHandler() {
|
||||
FreePacketList(rqueue_);
|
||||
}
|
||||
|
||||
enum {
|
||||
kTcpPacketType_Normal = 0,
|
||||
kTcpPacketType_Reserved = 1,
|
||||
kTcpPacketType_Data = 2,
|
||||
kTcpPacketType_Control = 3,
|
||||
kTcpPacketControlType_SetKeyAndCounter = 0,
|
||||
};
|
||||
|
||||
void TcpPacketHandler::AddHeaderToOutgoingPacket(Packet *p) {
|
||||
unsigned int size = p->size;
|
||||
uint8 *data = p->data;
|
||||
if (size >= 16 && ReadLE32(data) == 4) {
|
||||
uint32 key = Read32(data + 4);
|
||||
uint64 serial = ReadLE64(data + 8);
|
||||
WriteBE16(data + 14, size - 16 + (kTcpPacketType_Data << 14));
|
||||
data += 14, size -= 14;
|
||||
// Insert a 15 byte control packet right before to set the new key/serial?
|
||||
if ((predicted_key_out_ ^ key) | (predicted_serial_out_ ^ serial)) {
|
||||
predicted_key_out_ = key;
|
||||
WriteLE64(data - 8, serial);
|
||||
Write32(data - 12, key);
|
||||
data[-13] = kTcpPacketControlType_SetKeyAndCounter;
|
||||
WriteBE16(data - 15, 13 + (kTcpPacketType_Control << 14));
|
||||
data -= 15, size += 15;
|
||||
}
|
||||
// Increase the serial by 1 for next packet.
|
||||
predicted_serial_out_ = serial + 1;
|
||||
} else {
|
||||
WriteBE16(data - 2, size);
|
||||
data -= 2, size += 2;
|
||||
}
|
||||
p->size = size;
|
||||
p->data = data;
|
||||
}
|
||||
|
||||
void TcpPacketHandler::QueueIncomingPacket(Packet *p) {
|
||||
rqueue_bytes_ += p->size;
|
||||
p->queue_next = NULL;
|
||||
*rqueue_end_ = p;
|
||||
rqueue_end_ = &Packet_NEXT(p);
|
||||
}
|
||||
|
||||
// Either the packet fits in one buf or not.
|
||||
static uint32 ReadPacketHeader(Packet *p) {
|
||||
if (p->size >= 2)
|
||||
return ReadBE16(p->data);
|
||||
else
|
||||
return (p->data[0] << 8) + (Packet_NEXT(p)->data[0]);
|
||||
}
|
||||
|
||||
// Move data around to ensure that exactly the first |num| bytes are stored
|
||||
// in the first packet, and the rest of the data in subsequent packets.
|
||||
Packet *TcpPacketHandler::ReadNextPacket(uint32 num) {
|
||||
Packet *p = rqueue_;
|
||||
|
||||
assert(num <= kPacketCapacity);
|
||||
if (p->size < num) {
|
||||
// There's not enough data in the current packet, copy data from the next packet
|
||||
// into this packet.
|
||||
if ((uint32)(&p->data_buf[kPacketCapacity] - p->data) < num) {
|
||||
// Move data up front to make space.
|
||||
memmove(p->data_buf, p->data, p->size);
|
||||
p->data = p->data_buf;
|
||||
}
|
||||
// Copy data from future packets into p, and delete them should they become empty.
|
||||
do {
|
||||
Packet *n = Packet_NEXT(p);
|
||||
uint32 bytes_to_copy = std::min(n->size, num - p->size);
|
||||
uint32 nsize = (n->size -= bytes_to_copy);
|
||||
memcpy(p->data + postinc(p->size, bytes_to_copy), postinc(n->data, bytes_to_copy), bytes_to_copy);
|
||||
if (nsize == 0) {
|
||||
p->queue_next = n->queue_next;
|
||||
packet_pool_->FreePacketToPool(n);
|
||||
}
|
||||
} while (num - p->size);
|
||||
} else if (p->size > num) {
|
||||
// The packet has too much data. Split the packet into two packets.
|
||||
Packet *n = packet_pool_->AllocPacketFromPool();
|
||||
if (!n)
|
||||
return NULL; // unable to allocate a packet....?
|
||||
if (num * 2 <= p->size) {
|
||||
// There's a lot of trailing data: PP NNNNNN. Move PP.
|
||||
n->size = num;
|
||||
p->size -= num;
|
||||
rqueue_bytes_ -= num;
|
||||
memcpy(n->data, postinc(p->data, num), num);
|
||||
return n;
|
||||
} else {
|
||||
uint32 overflow = p->size - num;
|
||||
// There's a lot of leading data: PPPPPP NN. Move NN
|
||||
n->size = overflow;
|
||||
p->size = num;
|
||||
rqueue_ = n;
|
||||
if (!(n->queue_next = p->queue_next))
|
||||
rqueue_end_ = &Packet_NEXT(n);
|
||||
rqueue_bytes_ -= num;
|
||||
memcpy(n->data, p->data + num, overflow);
|
||||
return p;
|
||||
}
|
||||
}
|
||||
if ((rqueue_ = Packet_NEXT(p)) == NULL)
|
||||
rqueue_end_ = &rqueue_;
|
||||
rqueue_bytes_ -= num;
|
||||
return p;
|
||||
}
|
||||
|
||||
Packet *TcpPacketHandler::GetNextWireguardPacket() {
|
||||
while (rqueue_bytes_ >= 2) {
|
||||
uint32 packet_header = ReadPacketHeader(rqueue_);
|
||||
uint32 packet_size = packet_header & 0x3FFF;
|
||||
uint32 packet_type = packet_header >> 14;
|
||||
if (packet_size + 2 > rqueue_bytes_)
|
||||
return NULL;
|
||||
if (packet_size + 2 > kPacketCapacity) {
|
||||
RERROR("Oversized packet?");
|
||||
error_flag_ = true;
|
||||
return NULL;
|
||||
}
|
||||
Packet *packet = ReadNextPacket(packet_size + 2);
|
||||
if (packet) {
|
||||
// RINFO("Packet of type %d, size %d", packet_type, packet->size - 2);
|
||||
packet->data += 2, packet->size -= 2;
|
||||
if (packet_type == kTcpPacketType_Normal) {
|
||||
|
||||
return packet;
|
||||
} else if (packet_type == kTcpPacketType_Data) {
|
||||
// Optimization when the 16 first bytes are known and prefixed to the packet
|
||||
assert(packet->data >= packet->data_buf);
|
||||
|
||||
packet->data -= 16, packet->size += 16;
|
||||
WriteLE32(packet->data, 4);
|
||||
Write32(packet->data + 4, predicted_key_in_);
|
||||
WriteLE64(packet->data + 8, predicted_serial_in_);
|
||||
predicted_serial_in_++;
|
||||
return packet;
|
||||
} else if (packet_type == kTcpPacketType_Control) {
|
||||
// Unknown control packets are silently ignored
|
||||
if (packet->size == 13 && packet->data[0] == kTcpPacketControlType_SetKeyAndCounter) {
|
||||
// Control packet to setup the predicted key/sequence nr
|
||||
predicted_key_in_ = Read32(packet->data + 1);
|
||||
predicted_serial_in_ = ReadLE64(packet->data + 5);
|
||||
}
|
||||
packet_pool_->FreePacketToPool(packet);
|
||||
} else {
|
||||
packet_pool_->FreePacketToPool(packet);
|
||||
error_flag_ = true;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
}
|
||||
return NULL;
|
||||
}
|
95
network_common.h
Normal file
95
network_common.h
Normal file
|
@ -0,0 +1,95 @@
|
|||
#ifndef TUNSAFE_NETWORK_COMMON_H_
|
||||
#define TUNSAFE_NETWORK_COMMON_H_
|
||||
|
||||
#include "netapi.h"
|
||||
|
||||
class PacketProcessor;
|
||||
|
||||
// A simple singlethreaded pool of packets used on windows where
|
||||
// FreePacket / AllocPacket are multithreded and thus slightly slower
|
||||
#if defined(OS_WIN)
|
||||
class SimplePacketPool {
|
||||
public:
|
||||
explicit SimplePacketPool() {
|
||||
freed_packets_ = NULL;
|
||||
freed_packets_count_ = 0;
|
||||
}
|
||||
~SimplePacketPool() {
|
||||
FreePacketList(freed_packets_);
|
||||
}
|
||||
Packet *AllocPacketFromPool() {
|
||||
if (Packet *p = freed_packets_) {
|
||||
freed_packets_ = Packet_NEXT(p);
|
||||
freed_packets_count_--;
|
||||
p->Reset();
|
||||
return p;
|
||||
}
|
||||
return AllocPacket();
|
||||
}
|
||||
void FreePacketToPool(Packet *p) {
|
||||
Packet_NEXT(p) = freed_packets_;
|
||||
freed_packets_ = p;
|
||||
freed_packets_count_++;
|
||||
}
|
||||
void FreeSomePackets() {
|
||||
if (freed_packets_count_ > 32)
|
||||
FreeSomePacketsInner();
|
||||
}
|
||||
void FreeSomePacketsInner();
|
||||
|
||||
|
||||
int freed_packets_count_;
|
||||
Packet *freed_packets_;
|
||||
};
|
||||
#else
|
||||
class SimplePacketPool {
|
||||
public:
|
||||
Packet *AllocPacketFromPool() {
|
||||
return AllocPacket();
|
||||
}
|
||||
void FreePacketToPool(Packet *packet) {
|
||||
return FreePacket(packet);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
// Aids with prefixing and parsing incoming and outgoing
|
||||
// packets with the tcp protocol header.
|
||||
class TcpPacketHandler {
|
||||
public:
|
||||
explicit TcpPacketHandler(SimplePacketPool *packet_pool);
|
||||
~TcpPacketHandler();
|
||||
|
||||
// Adds a tcp header to a data packet so it can be transmitted on the wire
|
||||
void AddHeaderToOutgoingPacket(Packet *p);
|
||||
|
||||
// Add a new chunk of incoming data to the packet list
|
||||
void QueueIncomingPacket(Packet *p);
|
||||
|
||||
// Attempt to extract the next packet, returns NULL when complete.
|
||||
Packet *GetNextWireguardPacket();
|
||||
|
||||
bool error() const { return error_flag_; }
|
||||
|
||||
private:
|
||||
// Internal function to read a packet
|
||||
Packet *ReadNextPacket(uint32 num);
|
||||
|
||||
SimplePacketPool *packet_pool_;
|
||||
|
||||
// Total # of bytes queued
|
||||
uint32 rqueue_bytes_;
|
||||
|
||||
// Set if there's a fatal error
|
||||
bool error_flag_;
|
||||
|
||||
// These hold the incoming packets before they're parsed
|
||||
Packet *rqueue_, **rqueue_end_;
|
||||
|
||||
uint32 predicted_key_in_, predicted_key_out_;
|
||||
uint64 predicted_serial_in_, predicted_serial_out_;
|
||||
};
|
||||
|
||||
#endif // TUNSAFE_NETWORK_COMMON_H_
|
|
@ -5,6 +5,8 @@
|
|||
#include "wireguard_config.h"
|
||||
#include "netapi.h"
|
||||
#include <Iphlpapi.h>
|
||||
#include <Mswsock.h>
|
||||
#include <ws2ipdef.h>
|
||||
#include <stdlib.h>
|
||||
#include <assert.h>
|
||||
#include <malloc.h>
|
||||
|
@ -12,7 +14,6 @@
|
|||
#include <string.h>
|
||||
#include <vector>
|
||||
#include <Iphlpapi.h>
|
||||
#include <ws2ipdef.h>
|
||||
#include <assert.h>
|
||||
#include <exdisp.h>
|
||||
#include "tunsafe_endian.h"
|
||||
|
@ -42,15 +43,20 @@ static HKEY g_hklm_reg_key;
|
|||
static uint8 g_killswitch_curr, g_killswitch_want, g_killswitch_currconn;
|
||||
|
||||
bool g_allow_pre_post;
|
||||
static volatile bool g_fail_malloc_flag;
|
||||
|
||||
static void DeactivateKillSwitch(uint32 want);
|
||||
|
||||
Packet *AllocPacket() {
|
||||
Packet *packet = (Packet*)InterlockedPopEntrySList(&freelist_head);
|
||||
if (packet == NULL)
|
||||
packet = (Packet *)_aligned_malloc(kPacketAllocSize, 16);
|
||||
packet->data = packet->data_buf + Packet::HEADROOM_BEFORE;
|
||||
packet->size = 0;
|
||||
if (packet == NULL) {
|
||||
while ((packet = (Packet *)_aligned_malloc(kPacketAllocSize, 16)) == NULL) {
|
||||
if (g_fail_malloc_flag)
|
||||
return NULL;
|
||||
Sleep(1000);
|
||||
}
|
||||
}
|
||||
packet->Reset();
|
||||
return packet;
|
||||
}
|
||||
|
||||
|
@ -83,6 +89,14 @@ void FreeAllPackets() {
|
|||
}
|
||||
}
|
||||
|
||||
void SimplePacketPool::FreeSomePacketsInner() {
|
||||
int n = freed_packets_count_ - 24;
|
||||
Packet **p = &freed_packets_;
|
||||
for (; n; n--)
|
||||
p = &Packet_NEXT(*p);
|
||||
FreePackets(exch(freed_packets_, *p), p, exch(freed_packets_count_, 24) - 24);
|
||||
}
|
||||
|
||||
void InitPacketMutexes() {
|
||||
static bool mutex_inited;
|
||||
if (!mutex_inited) {
|
||||
|
@ -91,17 +105,6 @@ void InitPacketMutexes() {
|
|||
}
|
||||
}
|
||||
|
||||
int tpq_last_qsize;
|
||||
int g_tun_reads, g_tun_writes;
|
||||
|
||||
struct {
|
||||
uint32 pad1[3];
|
||||
uint32 udp_qsize1;
|
||||
uint32 pad2[3];
|
||||
uint32 udp_qsize2;
|
||||
} qs;
|
||||
|
||||
|
||||
#define kConcurrentReadTap 16
|
||||
#define kConcurrentWriteTap 16
|
||||
|
||||
|
@ -399,47 +402,13 @@ static bool GetDefaultRouteAndDeleteOldRoutes(int family, const NET_LUID *Interf
|
|||
return (rv == 0);
|
||||
}
|
||||
|
||||
static inline bool NoMoreAllocationRetry(volatile bool *exit_flag) {
|
||||
if (*exit_flag)
|
||||
return true;
|
||||
Sleep(1000);
|
||||
return *exit_flag;
|
||||
}
|
||||
|
||||
static inline bool AllocPacketFrom(Packet **list, int *counter, bool *exit_flag, Packet **res) {
|
||||
Packet *p;
|
||||
if (p = *list) {
|
||||
*list = Packet_NEXT(p);
|
||||
(*counter)--;
|
||||
p->data = p->data_buf + Packet::HEADROOM_BEFORE;
|
||||
} else {
|
||||
while ((p = AllocPacket()) == NULL) {
|
||||
if (NoMoreAllocationRetry(exit_flag))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
*res = p;
|
||||
return true;
|
||||
}
|
||||
|
||||
static void FreePacketList(Packet *pp) {
|
||||
void FreePacketList(Packet *pp) {
|
||||
while (Packet *p = pp) {
|
||||
pp = Packet_NEXT(p);
|
||||
FreePacket(p);
|
||||
}
|
||||
}
|
||||
|
||||
inline void NetworkWin32::FreePacketToPool(Packet *p) {
|
||||
Packet_NEXT(p) = NULL;
|
||||
*freed_packets_end_ = p;
|
||||
freed_packets_end_ = &Packet_NEXT(p);
|
||||
freed_packets_count_++;
|
||||
}
|
||||
|
||||
inline bool NetworkWin32::AllocPacketFromPool(Packet **p) {
|
||||
return AllocPacketFrom(&freed_packets_, &freed_packets_count_, &exit_thread_, p);
|
||||
}
|
||||
|
||||
UdpSocketWin32::UdpSocketWin32(NetworkWin32 *network_win32) {
|
||||
network_ = network_win32;
|
||||
wqueue_end_ = &wqueue_;
|
||||
|
@ -455,6 +424,9 @@ UdpSocketWin32::UdpSocketWin32(NetworkWin32 *network_win32) {
|
|||
num_reads_[0] = num_reads_[1] = 0;
|
||||
num_writes_ = 0;
|
||||
pending_writes_ = NULL;
|
||||
|
||||
qsize1_ = 0;
|
||||
qsize2_ = 0;
|
||||
}
|
||||
|
||||
UdpSocketWin32::~UdpSocketWin32() {
|
||||
|
@ -529,12 +501,12 @@ fail:
|
|||
|
||||
// Called on another thread to queue up a udp packet
|
||||
void UdpSocketWin32::WriteUdpPacket(Packet *packet) {
|
||||
if (qs.udp_qsize2 - qs.udp_qsize1 >= (unsigned)(packet->size < 576 ? MAX_BYTES_IN_UDP_OUT_QUEUE_SMALL : MAX_BYTES_IN_UDP_OUT_QUEUE)) {
|
||||
if (qsize2_ - qsize1_ >= (unsigned)(packet->size < 576 ? MAX_BYTES_IN_UDP_OUT_QUEUE_SMALL : MAX_BYTES_IN_UDP_OUT_QUEUE)) {
|
||||
FreePacket(packet);
|
||||
return;
|
||||
}
|
||||
Packet_NEXT(packet) = NULL;
|
||||
qs.udp_qsize2 += packet->size;
|
||||
packet->queue_next = NULL;
|
||||
qsize2_ += packet->size;
|
||||
|
||||
mutex_.Acquire();
|
||||
Packet *was_empty = wqueue_;
|
||||
|
@ -542,20 +514,14 @@ void UdpSocketWin32::WriteUdpPacket(Packet *packet) {
|
|||
wqueue_end_ = &Packet_NEXT(packet);
|
||||
mutex_.Release();
|
||||
|
||||
if (was_empty == NULL) {
|
||||
// Notify the worker thread that it should attempt more writes
|
||||
PostQueuedCompletionStatus(network_->completion_port_handle_, NULL, NULL, NULL);
|
||||
}
|
||||
if (was_empty == NULL)
|
||||
network_->WakeUp();
|
||||
}
|
||||
|
||||
enum {
|
||||
kUdpGetQueuedCompletionStatusSize = kConcurrentWriteTap + kConcurrentReadTap + 1
|
||||
};
|
||||
|
||||
static inline void ClearOverlapped(OVERLAPPED *o) {
|
||||
memset(o, 0, sizeof(*o));
|
||||
}
|
||||
|
||||
#ifndef STATUS_PORT_UNREACHABLE
|
||||
#define STATUS_PORT_UNREACHABLE 0xC000023F
|
||||
#endif
|
||||
|
@ -567,8 +533,8 @@ static inline bool IsIgnoredUdpError(DWORD err) {
|
|||
void UdpSocketWin32::DoMoreReads() {
|
||||
// Listen with multiple ipv6 packets only if we ever sent an ipv6 packet.
|
||||
for (int i = num_reads_[IPV6]; i < max_read_ipv6_; i++) {
|
||||
Packet *p;
|
||||
if (!network_->AllocPacketFromPool(&p))
|
||||
Packet *p = network_->packet_pool().AllocPacketFromPool();
|
||||
if (!p)
|
||||
break;
|
||||
restart_read_udp6:
|
||||
ClearOverlapped(&p->overlapped);
|
||||
|
@ -590,9 +556,11 @@ restart_read_udp6:
|
|||
num_reads_[IPV6]++;
|
||||
}
|
||||
// Initiate more reads, reusing the Packet structures in |finished_writes|.
|
||||
|
||||
if (socket_ != INVALID_SOCKET) {
|
||||
for (int i = num_reads_[IPV4]; i < kConcurrentReadUdp; i++) {
|
||||
Packet *p;
|
||||
if (!network_->AllocPacketFromPool(&p))
|
||||
Packet *p = network_->packet_pool().AllocPacketFromPool();
|
||||
if (!p)
|
||||
break;
|
||||
restart_read_udp:
|
||||
ClearOverlapped(&p->overlapped);
|
||||
|
@ -613,9 +581,10 @@ restart_read_udp:
|
|||
}
|
||||
num_reads_[IPV4]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void UdpSocketWin32::DoMoreWrites() {
|
||||
void UdpSocketWin32::ProcessPackets() {
|
||||
// Push all the finished reads to the packet handler
|
||||
if (finished_reads_ != NULL) {
|
||||
packet_handler_->PostPackets(finished_reads_, finished_reads_end_, finished_reads_count_);
|
||||
|
@ -623,7 +592,9 @@ void UdpSocketWin32::DoMoreWrites() {
|
|||
finished_reads_end_ = &finished_reads_;
|
||||
finished_reads_count_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void UdpSocketWin32::DoMoreWrites() {
|
||||
Packet *pending_writes = pending_writes_;
|
||||
// Initiate more writes from |wqueue_|
|
||||
while (num_writes_ < kConcurrentWriteUdp) {
|
||||
|
@ -639,7 +610,7 @@ void UdpSocketWin32::DoMoreWrites() {
|
|||
if (!pending_writes)
|
||||
break;
|
||||
}
|
||||
qs.udp_qsize1 += pending_writes->size;
|
||||
qsize1_ += pending_writes->size;
|
||||
|
||||
// Then issue writes
|
||||
Packet *p = pending_writes;
|
||||
|
@ -688,13 +659,14 @@ void UdpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) {
|
|||
if (p->userdata < 2) {
|
||||
num_reads_[p->userdata]--;
|
||||
if ((DWORD)p->overlapped.Internal != 0) {
|
||||
network_->FreePacketToPool(p);
|
||||
if (!IsIgnoredUdpError((DWORD)p->overlapped.Internal))
|
||||
RERROR("UdpSocketWin32::Read error 0x%X", (DWORD)p->overlapped.Internal);
|
||||
network_->packet_pool().FreePacketToPool(p);
|
||||
} else {
|
||||
// Remember all the finished packets and queue them up to the next thread once we've
|
||||
// collected them all.
|
||||
p->size = (int)p->overlapped.InternalHigh;
|
||||
p->protocol = kPacketProtocolUdp;
|
||||
p->queue_cb = packet_handler_->udp_queue();
|
||||
p->queue_next = NULL;
|
||||
*finished_reads_end_ = p;
|
||||
|
@ -703,9 +675,9 @@ void UdpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) {
|
|||
}
|
||||
} else {
|
||||
num_writes_--;
|
||||
network_->FreePacketToPool(p);
|
||||
if ((DWORD)p->overlapped.Internal != 0)
|
||||
RERROR("UdpSocketWin32::Write error 0x%X", (DWORD)p->overlapped.Internal);
|
||||
network_->packet_pool().FreePacketToPool(p);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -716,28 +688,30 @@ void UdpSocketWin32::OnQueuedItemDelete(QueuedItem *qi) {
|
|||
} else {
|
||||
num_writes_--;
|
||||
}
|
||||
network_->FreePacketToPool(p);
|
||||
network_->packet_pool().FreePacketToPool(p);
|
||||
}
|
||||
|
||||
void UdpSocketWin32::DoIO() {
|
||||
DoMoreWrites();
|
||||
ProcessPackets();
|
||||
DoMoreReads();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
NetworkWin32::NetworkWin32() : udp_socket_(this) {
|
||||
NetworkWin32::NetworkWin32() : udp_socket_(this), tcp_socket_queue_(this) {
|
||||
exit_thread_ = false;
|
||||
thread_ = NULL;
|
||||
freed_packets_ = NULL;
|
||||
freed_packets_end_ = &freed_packets_;
|
||||
freed_packets_count_ = 0;
|
||||
completion_port_handle_ = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, NULL, 0);
|
||||
tcp_socket_ = NULL;
|
||||
}
|
||||
|
||||
NetworkWin32::~NetworkWin32() {
|
||||
assert(thread_ == NULL);
|
||||
|
||||
for (TcpSocketWin32 *socket = tcp_socket_; socket; )
|
||||
delete exch(socket, socket->next_);
|
||||
|
||||
CloseHandle(completion_port_handle_);
|
||||
FreePacketList(freed_packets_);
|
||||
}
|
||||
|
||||
DWORD WINAPI NetworkWin32::NetworkThread(void *x) {
|
||||
|
@ -750,20 +724,13 @@ void NetworkWin32::ThreadMain() {
|
|||
OVERLAPPED_ENTRY entries[kUdpGetQueuedCompletionStatusSize];
|
||||
|
||||
while (!exit_thread_) {
|
||||
// Run IO on all sockets queued for IO
|
||||
// TODO: In the future, don't process every socket here, only
|
||||
// those sockets that requested it.
|
||||
udp_socket_.DoIO();
|
||||
for (TcpSocketWin32 *tcp = tcp_socket_; tcp;)
|
||||
exch(tcp, tcp->next_)->DoIO();
|
||||
|
||||
// Free some packets
|
||||
assert(freed_packets_count_ >= 0);
|
||||
if (freed_packets_count_ >= 32) {
|
||||
FreePackets(freed_packets_, freed_packets_end_, freed_packets_count_);
|
||||
freed_packets_count_ = 0;
|
||||
freed_packets_ = NULL;
|
||||
freed_packets_end_ = &freed_packets_;
|
||||
} else if (freed_packets_ == NULL) {
|
||||
assert(freed_packets_count_ == 0);
|
||||
freed_packets_end_ = &freed_packets_;
|
||||
}
|
||||
packet_pool_.FreeSomePackets();
|
||||
|
||||
ULONG num_entries = 0;
|
||||
if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, kUdpGetQueuedCompletionStatusSize, &num_entries, INFINITE, FALSE)) {
|
||||
|
@ -779,18 +746,31 @@ void NetworkWin32::ThreadMain() {
|
|||
}
|
||||
|
||||
udp_socket_.CancelAllIO();
|
||||
for (TcpSocketWin32 *tcp = tcp_socket_; tcp; tcp = tcp->next_)
|
||||
tcp->CancelAllIO();
|
||||
|
||||
while (udp_socket_.HasOutstandingIO()) {
|
||||
while (HasOutstandingIO()) {
|
||||
ULONG num_entries = 0;
|
||||
if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, 1, &num_entries, INFINITE, FALSE)) {
|
||||
if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, kUdpGetQueuedCompletionStatusSize, &num_entries, INFINITE, FALSE)) {
|
||||
RINFO("GetQueuedCompletionStatusEx failed.");
|
||||
break;
|
||||
}
|
||||
if (entries[0].lpOverlapped) {
|
||||
QueuedItem *w = (QueuedItem*)((byte*)entries[0].lpOverlapped - offsetof(QueuedItem, overlapped));
|
||||
for (ULONG i = 0; i < num_entries; i++) {
|
||||
if (entries[i].lpOverlapped) {
|
||||
QueuedItem *w = (QueuedItem*)((byte*)entries[i].lpOverlapped - offsetof(QueuedItem, overlapped));
|
||||
w->queue_cb->OnQueuedItemDelete(w);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool NetworkWin32::HasOutstandingIO() {
|
||||
if (udp_socket_.HasOutstandingIO())
|
||||
return true;
|
||||
for (TcpSocketWin32 *tcp = tcp_socket_; tcp; tcp = tcp->next_)
|
||||
if (tcp->HasOutstandingIO())
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
void NetworkWin32::StartThread() {
|
||||
|
@ -804,11 +784,36 @@ void NetworkWin32::StartThread() {
|
|||
void NetworkWin32::StopThread() {
|
||||
if (thread_ != NULL) {
|
||||
exit_thread_ = true;
|
||||
g_fail_malloc_flag = true;
|
||||
PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL);
|
||||
WaitForSingleObject(thread_, INFINITE);
|
||||
CloseHandle(thread_);
|
||||
thread_ = NULL;
|
||||
exit_thread_ = false;
|
||||
g_fail_malloc_flag = false;
|
||||
}
|
||||
}
|
||||
|
||||
void NetworkWin32::WakeUp() {
|
||||
PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL);
|
||||
}
|
||||
|
||||
void NetworkWin32::PostQueuedItem(QueuedItem *item) {
|
||||
PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, &item->overlapped);
|
||||
}
|
||||
|
||||
bool NetworkWin32::Configure(int listen_port, int listen_port_tcp) {
|
||||
if (listen_port_tcp)
|
||||
RERROR("ListenPortTCP not supported in this version");
|
||||
return udp_socket_.Configure(listen_port);
|
||||
}
|
||||
|
||||
// Called from tunsafe thread
|
||||
void NetworkWin32::WriteUdpPacket(Packet *packet) {
|
||||
if (packet->protocol & kPacketProtocolUdp) {
|
||||
udp_socket_.WriteUdpPacket(packet);
|
||||
} else {
|
||||
tcp_socket_queue_.WritePacket(packet);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -874,6 +879,8 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) {
|
|||
|
||||
mutex_.Acquire();
|
||||
while (!(exit_code = exit_code_)) {
|
||||
FreeAllPackets();
|
||||
|
||||
if (timer_interrupt_) {
|
||||
timer_interrupt_ = false;
|
||||
need_notify_ = 0;
|
||||
|
@ -912,7 +919,6 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) {
|
|||
need_notify_ = 0;
|
||||
mutex_.Release();
|
||||
|
||||
tpq_last_qsize = packets_in_queue;
|
||||
if (packets_in_queue >= 1024)
|
||||
overload = 2;
|
||||
queue_context.overload = (overload != 0);
|
||||
|
@ -986,8 +992,8 @@ void PacketProcessor::PostPackets(Packet *first, Packet **end, int count) {
|
|||
}
|
||||
|
||||
void PacketProcessor::ForcePost(QueuedItem *item) {
|
||||
mutex_.Acquire();
|
||||
item->queue_next = NULL;
|
||||
mutex_.Acquire();
|
||||
packets_in_queue_ += 1;
|
||||
*last_ptr_ = item;
|
||||
last_ptr_ = &item->queue_next;
|
||||
|
@ -1648,7 +1654,6 @@ bool TunWin32Adapter::RunPrePostCommand(const std::vector<std::string> &vec) {
|
|||
return success;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) {
|
||||
|
@ -1704,6 +1709,20 @@ enum {
|
|||
kTunGetQueuedCompletionStatusSize = kConcurrentWriteTap + kConcurrentReadTap + 1
|
||||
};
|
||||
|
||||
static inline bool AllocPacketFrom(Packet **list, int *counter, bool *exit_flag, Packet **res) {
|
||||
Packet *p;
|
||||
if (p = *list) {
|
||||
*list = Packet_NEXT(p);
|
||||
(*counter)--;
|
||||
p->data = p->data_buf;
|
||||
} else {
|
||||
if (!(p = AllocPacket()))
|
||||
return false;
|
||||
}
|
||||
*res = p;
|
||||
return true;
|
||||
}
|
||||
|
||||
void TunWin32Iocp::ThreadMain() {
|
||||
OVERLAPPED_ENTRY entries[kTunGetQueuedCompletionStatusSize];
|
||||
Packet *pending_writes = NULL;
|
||||
|
@ -1738,7 +1757,6 @@ void TunWin32Iocp::ThreadMain() {
|
|||
num_reads++;
|
||||
}
|
||||
}
|
||||
g_tun_reads = num_reads;
|
||||
|
||||
assert(freed_packets_count >= 0);
|
||||
if (freed_packets_count >= 32) {
|
||||
|
@ -1820,7 +1838,6 @@ void TunWin32Iocp::ThreadMain() {
|
|||
num_writes++;
|
||||
}
|
||||
}
|
||||
g_tun_writes = num_writes;
|
||||
}
|
||||
|
||||
EXIT:
|
||||
|
@ -1896,145 +1913,49 @@ void TunWin32Iocp::WriteTunPacket(Packet *packet) {
|
|||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TunWin32Overlapped::TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) {
|
||||
wqueue_end_ = &wqueue_;
|
||||
wqueue_ = NULL;
|
||||
|
||||
thread_ = NULL;
|
||||
|
||||
read_event_ = CreateEvent(NULL, TRUE, FALSE, NULL);
|
||||
write_event_ = CreateEvent(NULL, TRUE, FALSE, NULL);
|
||||
wake_event_ = CreateEvent(NULL, FALSE, FALSE, NULL);
|
||||
|
||||
packet_handler_ = NULL;
|
||||
exit_thread_ = false;
|
||||
TunsafeBackend::TunsafeBackend() {
|
||||
is_started_ = false;
|
||||
is_remote_ = false;
|
||||
ipv4_ip_ = 0;
|
||||
status_ = kStatusStopped;
|
||||
memset(public_key_, 0, sizeof(public_key_));
|
||||
}
|
||||
|
||||
TunWin32Overlapped::~TunWin32Overlapped() {
|
||||
CloseTun();
|
||||
CloseHandle(read_event_);
|
||||
CloseHandle(write_event_);
|
||||
CloseHandle(wake_event_);
|
||||
TunsafeBackend::~TunsafeBackend() {
|
||||
|
||||
}
|
||||
|
||||
bool TunWin32Overlapped::Configure(const TunConfig &&config, TunConfigOut *out) {
|
||||
CloseTun();
|
||||
if (adapter_.OpenAdapter(backend_, FILE_FLAG_OVERLAPPED) &&
|
||||
adapter_.ConfigureAdapter(std::move(config), out))
|
||||
return true;
|
||||
CloseTun();
|
||||
return false;
|
||||
static bool GetKillSwitchRouteActive() {
|
||||
RouteInfo ri;
|
||||
return (GetDefaultRouteAndDeleteOldRoutes(AF_INET, NULL, TRUE, NULL, &ri) && ri.found_null_routes == 2);
|
||||
}
|
||||
|
||||
void TunWin32Overlapped::CloseTun() {
|
||||
assert(thread_ == NULL);
|
||||
adapter_.CloseAdapter(false);
|
||||
FreePacketList(wqueue_);
|
||||
wqueue_ = NULL;
|
||||
wqueue_end_ = &wqueue_;
|
||||
static void RemoveKillSwitchRoute() {
|
||||
RouteInfo ri;
|
||||
GetDefaultRouteAndDeleteOldRoutes(AF_INET, NULL, FALSE, NULL, &ri);
|
||||
GetDefaultRouteAndDeleteOldRoutes(AF_INET6, NULL, FALSE, NULL, &ri);
|
||||
}
|
||||
|
||||
void TunWin32Overlapped::ThreadMain() {
|
||||
Packet *pending_writes = NULL;
|
||||
DWORD err;
|
||||
Packet *read_packet = NULL, *write_packet = NULL;
|
||||
|
||||
HANDLE h[3];
|
||||
while (!exit_thread_) {
|
||||
if (read_packet == NULL) {
|
||||
Packet *p = AllocPacket();
|
||||
ClearOverlapped(&p->overlapped);
|
||||
p->overlapped.hEvent = read_event_;
|
||||
if (!ReadFile(adapter_.handle(), p->data, kPacketCapacity, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) {
|
||||
FreePacket(p);
|
||||
RERROR("TunWin32: ReadFile failed 0x%X", err);
|
||||
} else {
|
||||
read_packet = p;
|
||||
TunsafeBackendWin32::TunsafeBackendWin32(Delegate *delegate) : delegate_(delegate), dns_resolver_(&dns_blocker_) {
|
||||
memset(&stats_, 0, sizeof(stats_));
|
||||
wg_processor_ = NULL;
|
||||
InitPacketMutexes();
|
||||
worker_thread_ = NULL;
|
||||
last_tun_adapter_failed_ = 0;
|
||||
want_periodic_stats_ = false;
|
||||
guid_[0] = 0;
|
||||
if (g_hklm_reg_key == NULL) {
|
||||
RegCreateKeyEx(HKEY_LOCAL_MACHINE, "Software\\TunSafe", NULL, NULL, 0, KEY_ALL_ACCESS, NULL, &g_hklm_reg_key, NULL);
|
||||
g_killswitch_want = RegReadInt(g_hklm_reg_key, "KillSwitch", 0);
|
||||
g_killswitch_curr = GetKillSwitchRouteActive() * kBlockInternet_Route +
|
||||
GetKillSwitchFirewallActive() * kBlockInternet_Firewall;
|
||||
}
|
||||
}
|
||||
|
||||
int n = 0;
|
||||
if (write_packet)
|
||||
h[n++] = write_event_;
|
||||
if (read_packet != NULL)
|
||||
h[n++] = read_event_;
|
||||
h[n++] = wake_event_;
|
||||
|
||||
DWORD res = WaitForMultipleObjects(n, h, FALSE, INFINITE);
|
||||
|
||||
if (res >= WAIT_OBJECT_0 && res <= WAIT_OBJECT_0 + 2) {
|
||||
HANDLE hx = h[res - WAIT_OBJECT_0];
|
||||
if (hx == read_event_) {
|
||||
read_packet->size = (int)read_packet->overlapped.InternalHigh;
|
||||
Packet_NEXT(read_packet) = NULL;
|
||||
packet_handler_->PostPackets(read_packet, &Packet_NEXT(read_packet), 1);
|
||||
read_packet = NULL;
|
||||
} else if (hx == write_event_) {
|
||||
FreePacket(write_packet);
|
||||
write_packet = NULL;
|
||||
}
|
||||
} else {
|
||||
RERROR("Wait said %d", res);
|
||||
}
|
||||
|
||||
if (write_packet == NULL) {
|
||||
if (!pending_writes) {
|
||||
mutex_.Acquire();
|
||||
pending_writes = wqueue_;
|
||||
wqueue_end_ = &wqueue_;
|
||||
wqueue_ = NULL;
|
||||
mutex_.Release();
|
||||
}
|
||||
if (pending_writes) {
|
||||
// Then issue writes
|
||||
Packet *p = pending_writes;
|
||||
pending_writes = Packet_NEXT(p);
|
||||
memset(&p->overlapped, 0, sizeof(p->overlapped));
|
||||
p->overlapped.hEvent = write_event_;
|
||||
if (!WriteFile(adapter_.handle(), p->data, p->size, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) {
|
||||
RERROR("TunWin32: WriteFile failed 0x%X", err);
|
||||
FreePacket(p);
|
||||
} else {
|
||||
write_packet = p;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Free memory
|
||||
CancelIo(adapter_.handle());
|
||||
FreePacketList(pending_writes);
|
||||
delegate_->OnStateChanged();
|
||||
}
|
||||
|
||||
DWORD WINAPI TunWin32Overlapped::TunThread(void *x) {
|
||||
TunWin32Overlapped *xx = (TunWin32Overlapped *)x;
|
||||
xx->ThreadMain();
|
||||
return 0;
|
||||
}
|
||||
|
||||
void TunWin32Overlapped::StartThread() {
|
||||
DWORD thread_id;
|
||||
thread_ = CreateThread(NULL, 0, &TunThread, this, 0, &thread_id);
|
||||
SetThreadPriority(thread_, ABOVE_NORMAL_PRIORITY_CLASS);
|
||||
}
|
||||
|
||||
void TunWin32Overlapped::StopThread() {
|
||||
exit_thread_ = true;
|
||||
SetEvent(wake_event_);
|
||||
WaitForSingleObject(thread_, INFINITE);
|
||||
CloseHandle(thread_);
|
||||
thread_ = NULL;
|
||||
}
|
||||
|
||||
void TunWin32Overlapped::WriteTunPacket(Packet *packet) {
|
||||
Packet_NEXT(packet) = NULL;
|
||||
mutex_.Acquire();
|
||||
Packet *was_empty = wqueue_;
|
||||
*wqueue_end_ = packet;
|
||||
wqueue_end_ = &Packet_NEXT(packet);
|
||||
mutex_.Release();
|
||||
if (was_empty == NULL)
|
||||
SetEvent(wake_event_);
|
||||
TunsafeBackendWin32::~TunsafeBackendWin32() {
|
||||
StopInner(false);
|
||||
TunAdaptersInUse::GetInstance()->Release(this);
|
||||
}
|
||||
|
||||
void TunsafeBackendWin32::SetPublicKey(const uint8 key[32]) {
|
||||
|
@ -2047,14 +1968,14 @@ DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) {
|
|||
int stop_mode;
|
||||
int fast_retry_ctr = 0;
|
||||
|
||||
for(;;) {
|
||||
for (;;) {
|
||||
TunWin32Iocp tun(&backend->dns_blocker_, backend);
|
||||
NetworkWin32 net;
|
||||
WireguardProcessor wg_proc(&net.udp(), &tun, backend);
|
||||
|
||||
qs.udp_qsize1 = qs.udp_qsize2 = 0;
|
||||
WireguardProcessor wg_proc(&net, &tun, backend);
|
||||
|
||||
net.udp().SetPacketHandler(&backend->packet_processor_);
|
||||
net.tcp_socket_queue().SetPacketHandler(&backend->packet_processor_);
|
||||
|
||||
tun.SetPacketHandler(&backend->packet_processor_);
|
||||
|
||||
if (backend->config_file_[0] &&
|
||||
|
@ -2107,52 +2028,6 @@ getout_fail_noseterr:
|
|||
return 0;
|
||||
}
|
||||
|
||||
TunsafeBackend::TunsafeBackend() {
|
||||
is_started_ = false;
|
||||
is_remote_ = false;
|
||||
ipv4_ip_ = 0;
|
||||
status_ = kStatusStopped;
|
||||
memset(public_key_, 0, sizeof(public_key_));
|
||||
}
|
||||
|
||||
TunsafeBackend::~TunsafeBackend() {
|
||||
|
||||
}
|
||||
|
||||
static bool GetKillSwitchRouteActive() {
|
||||
RouteInfo ri;
|
||||
return (GetDefaultRouteAndDeleteOldRoutes(AF_INET, NULL, TRUE, NULL, &ri) && ri.found_null_routes == 2);
|
||||
}
|
||||
|
||||
static void RemoveKillSwitchRoute() {
|
||||
RouteInfo ri;
|
||||
GetDefaultRouteAndDeleteOldRoutes(AF_INET, NULL, FALSE, NULL, &ri);
|
||||
GetDefaultRouteAndDeleteOldRoutes(AF_INET6, NULL, FALSE, NULL, &ri);
|
||||
}
|
||||
|
||||
TunsafeBackendWin32::TunsafeBackendWin32(Delegate *delegate) : delegate_(delegate), dns_resolver_(&dns_blocker_) {
|
||||
memset(&stats_, 0, sizeof(stats_));
|
||||
wg_processor_ = NULL;
|
||||
InitPacketMutexes();
|
||||
worker_thread_ = NULL;
|
||||
last_tun_adapter_failed_ = 0;
|
||||
want_periodic_stats_ = false;
|
||||
guid_[0] = 0;
|
||||
if (g_hklm_reg_key == NULL) {
|
||||
RegCreateKeyEx(HKEY_LOCAL_MACHINE, "Software\\TunSafe", NULL, NULL, 0, KEY_ALL_ACCESS, NULL, &g_hklm_reg_key, NULL);
|
||||
g_killswitch_want = RegReadInt(g_hklm_reg_key, "KillSwitch", 0);
|
||||
g_killswitch_curr = GetKillSwitchRouteActive() * kBlockInternet_Route +
|
||||
GetKillSwitchFirewallActive() * kBlockInternet_Firewall;
|
||||
}
|
||||
delegate_->OnStateChanged();
|
||||
}
|
||||
|
||||
TunsafeBackendWin32::~TunsafeBackendWin32() {
|
||||
StopInner(false);
|
||||
TunAdaptersInUse::GetInstance()->Release(this);
|
||||
}
|
||||
|
||||
|
||||
void TunsafeBackendWin32::SetStatus(StatusCode status) {
|
||||
status_ = status;
|
||||
delegate_->OnStatusCode(status);
|
||||
|
|
|
@ -2,15 +2,14 @@
|
|||
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
|
||||
#pragma once
|
||||
|
||||
#include "stdafx.h"
|
||||
#include "tunsafe_types.h"
|
||||
#include "netapi.h"
|
||||
#include "network_win32_api.h"
|
||||
#include "network_win32_dnsblock.h"
|
||||
#include "wireguard_config.h"
|
||||
#include "tunsafe_threading.h"
|
||||
#include "tunsafe_dnsresolve.h"
|
||||
#include <functional>
|
||||
#include "network_common.h"
|
||||
#include "network_win32_tcp.h"
|
||||
|
||||
enum {
|
||||
ADAPTER_GUID_SIZE = 40,
|
||||
|
@ -18,6 +17,7 @@ enum {
|
|||
|
||||
class WireguardProcessor;
|
||||
class TunsafeBackendWin32;
|
||||
class DnsBlocker;
|
||||
|
||||
struct PacketProcessorTunCb : QueuedItemCallback {
|
||||
virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override;
|
||||
|
@ -73,16 +73,15 @@ class PacketAllocPool;
|
|||
|
||||
// Encapsulates a UDP socket pair (ipv4 / ipv6), optionally listening for incoming packets
|
||||
// on a specific port.
|
||||
class UdpSocketWin32 : public UdpInterface, QueuedItemCallback {
|
||||
class UdpSocketWin32 : public QueuedItemCallback {
|
||||
public:
|
||||
explicit UdpSocketWin32(NetworkWin32 *network_win32);
|
||||
~UdpSocketWin32();
|
||||
|
||||
void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
|
||||
|
||||
// -- from UdpInterface
|
||||
virtual bool Configure(int listen_on_port) override;
|
||||
virtual void WriteUdpPacket(Packet *packet) override;
|
||||
bool Configure(int listen_on_port);
|
||||
inline void WriteUdpPacket(Packet *packet);
|
||||
|
||||
void DoIO();
|
||||
void CancelAllIO();
|
||||
|
@ -94,9 +93,9 @@ public:
|
|||
};
|
||||
|
||||
private:
|
||||
|
||||
void DoMoreReads();
|
||||
void DoMoreWrites();
|
||||
void ProcessPackets();
|
||||
|
||||
// From OverlappedCallbacks
|
||||
virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override;
|
||||
|
@ -125,11 +124,16 @@ private:
|
|||
|
||||
Packet *finished_reads_, **finished_reads_end_;
|
||||
int finished_reads_count_;
|
||||
|
||||
__declspec(align(64)) uint32 qsize1_;
|
||||
__declspec(align(64)) uint32 qsize2_;
|
||||
};
|
||||
|
||||
// Holds the thread for network communications
|
||||
class NetworkWin32 {
|
||||
class NetworkWin32 : public UdpInterface {
|
||||
friend class UdpSocketWin32;
|
||||
friend class TcpSocketWin32;
|
||||
friend class TcpSocketQueue;
|
||||
public:
|
||||
explicit NetworkWin32();
|
||||
~NetworkWin32();
|
||||
|
@ -138,13 +142,20 @@ public:
|
|||
void StopThread();
|
||||
|
||||
UdpSocketWin32 &udp() { return udp_socket_; }
|
||||
SimplePacketPool &packet_pool() { return packet_pool_; }
|
||||
TcpSocketQueue &tcp_socket_queue() { return tcp_socket_queue_; }
|
||||
void WakeUp();
|
||||
void PostQueuedItem(QueuedItem *item);
|
||||
|
||||
// -- from UdpInterface
|
||||
virtual bool Configure(int listen_port_udp, int listen_port_tcp) override;
|
||||
virtual void WriteUdpPacket(Packet *packet) override;
|
||||
|
||||
private:
|
||||
void ThreadMain();
|
||||
static DWORD WINAPI NetworkThread(void *x);
|
||||
|
||||
void FreePacketToPool(Packet *p);
|
||||
bool AllocPacketFromPool(Packet **p);
|
||||
bool HasOutstandingIO();
|
||||
|
||||
// The network thread handle
|
||||
HANDLE thread_;
|
||||
|
@ -155,18 +166,17 @@ private:
|
|||
// The handle to the completion port
|
||||
HANDLE completion_port_handle_;
|
||||
|
||||
Packet *freed_packets_, **freed_packets_end_;
|
||||
int freed_packets_count_;
|
||||
|
||||
// Right now there's always one udp socket only
|
||||
UdpSocketWin32 udp_socket_;
|
||||
|
||||
// A linked list of all tcp sockets
|
||||
TcpSocketWin32 *tcp_socket_;
|
||||
|
||||
SimplePacketPool packet_pool_;
|
||||
|
||||
TcpSocketQueue tcp_socket_queue_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
class DnsBlocker;
|
||||
|
||||
class TunWin32Adapter {
|
||||
public:
|
||||
TunWin32Adapter(DnsBlocker *dns_blocker, const char guid[ADAPTER_GUID_SIZE]);
|
||||
|
@ -243,42 +253,6 @@ private:
|
|||
TunWin32Adapter adapter_;
|
||||
};
|
||||
|
||||
// Implementation of TUN interface handling using Overlapped IO
|
||||
class TunWin32Overlapped : public TunInterface {
|
||||
public:
|
||||
explicit TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend);
|
||||
~TunWin32Overlapped();
|
||||
|
||||
void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
|
||||
|
||||
void StartThread();
|
||||
void StopThread();
|
||||
|
||||
// -- from TunInterface
|
||||
virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
|
||||
virtual void WriteTunPacket(Packet *packet) override;
|
||||
|
||||
private:
|
||||
void CloseTun();
|
||||
void ThreadMain();
|
||||
static DWORD WINAPI TunThread(void *x);
|
||||
|
||||
PacketProcessor *packet_handler_;
|
||||
HANDLE thread_;
|
||||
|
||||
Mutex mutex_;
|
||||
|
||||
HANDLE read_event_, write_event_, wake_event_;
|
||||
|
||||
bool exit_thread_;
|
||||
|
||||
Packet *wqueue_, **wqueue_end_;
|
||||
|
||||
TunWin32Adapter adapter_;
|
||||
|
||||
TunsafeBackendWin32 *backend_;
|
||||
};
|
||||
|
||||
class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate {
|
||||
friend class PacketProcessor;
|
||||
friend class TunWin32Iocp;
|
||||
|
@ -427,3 +401,8 @@ private:
|
|||
uint8 num_inuse_;
|
||||
Entry entry_[kMaxAdaptersInUse];
|
||||
};
|
||||
|
||||
static inline void ClearOverlapped(OVERLAPPED *o) {
|
||||
memset(o, 0, sizeof(*o));
|
||||
}
|
||||
|
||||
|
|
|
@ -125,6 +125,3 @@ protected:
|
|||
|
||||
TunsafeBackend *CreateNativeTunsafeBackend(TunsafeBackend::Delegate *delegate);
|
||||
TunsafeBackend::Delegate *CreateTunsafeBackendDelegateThreaded(TunsafeBackend::Delegate *delegate, const std::function<void(void)> &callback);
|
||||
|
||||
extern int tpq_last_qsize;
|
||||
extern int g_tun_reads, g_tun_writes;
|
||||
|
|
344
network_win32_tcp.cpp
Normal file
344
network_win32_tcp.cpp
Normal file
|
@ -0,0 +1,344 @@
|
|||
// SPDX-License-Identifier: AGPL-1.0-only
|
||||
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
|
||||
#include <stdafx.h>
|
||||
#include "network_win32_tcp.h"
|
||||
#include "network_win32.h"
|
||||
#include <Mswsock.h>
|
||||
#include <ws2ipdef.h>
|
||||
#include "util.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TcpSocketWin32::TcpSocketWin32(NetworkWin32 *network)
|
||||
: tcp_packet_handler_(&network->packet_pool()) {
|
||||
network_ = network;
|
||||
reads_active_ = 0;
|
||||
writes_active_ = 0;
|
||||
handshake_attempts = 0;
|
||||
state_ = STATE_NONE;
|
||||
wqueue_ = NULL;
|
||||
wqueue_end_ = &wqueue_;
|
||||
socket_ = INVALID_SOCKET;
|
||||
next_ = NULL;
|
||||
packet_processor_ = NULL;
|
||||
// insert in network's linked list
|
||||
next_ = network->tcp_socket_;
|
||||
network->tcp_socket_ = this;
|
||||
}
|
||||
|
||||
TcpSocketWin32::~TcpSocketWin32() {
|
||||
// Unlink myself from the network's linked list.
|
||||
TcpSocketWin32 **p = &network_->tcp_socket_;
|
||||
while (*p != this) p = &(*p)->next_;
|
||||
*p = next_;
|
||||
|
||||
FreePacketList(wqueue_);
|
||||
if (socket_ != INVALID_SOCKET)
|
||||
closesocket(socket_);
|
||||
}
|
||||
|
||||
void TcpSocketWin32::CloseSocket() {
|
||||
if (socket_ != INVALID_SOCKET)
|
||||
CancelIo((HANDLE)socket_);
|
||||
state_ = STATE_ERROR;
|
||||
endpoint_protocol_ = 0;
|
||||
}
|
||||
|
||||
void TcpSocketWin32::WritePacket(Packet *packet) {
|
||||
packet->queue_next = NULL;
|
||||
*wqueue_end_ = packet;
|
||||
wqueue_end_ = &Packet_NEXT(packet);
|
||||
}
|
||||
|
||||
void TcpSocketWin32::CancelAllIO() {
|
||||
if (socket_ != INVALID_SOCKET)
|
||||
CancelIo((HANDLE)socket_);
|
||||
}
|
||||
|
||||
static const GUID WsaConnectExGUID = WSAID_CONNECTEX;
|
||||
|
||||
void TcpSocketWin32::DoConnect() {
|
||||
LPFN_CONNECTEX ConnectEx;
|
||||
assert(socket_ == INVALID_SOCKET);
|
||||
|
||||
socket_ = WSASocket(endpoint_.sin.sin_family, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
|
||||
if (socket_ == INVALID_SOCKET) {
|
||||
RERROR("socket() failed");
|
||||
CloseSocket();
|
||||
return;
|
||||
}
|
||||
|
||||
if (!CreateIoCompletionPort((HANDLE)socket_, network_->completion_port_handle_, 0, 0)) {
|
||||
RERROR("TcpSocketWin32::DoConnect CreateIoCompletionPort failed");
|
||||
CloseSocket();
|
||||
return;
|
||||
}
|
||||
|
||||
int nodelay = 1;
|
||||
setsockopt(socket_, IPPROTO_TCP, TCP_NODELAY, (char*)&nodelay, 1);
|
||||
|
||||
DWORD dwBytes = sizeof(ConnectEx);
|
||||
DWORD rc = WSAIoctl(socket_, SIO_GET_EXTENSION_FUNCTION_POINTER, (uint8*)&WsaConnectExGUID, sizeof(WsaConnectExGUID), &ConnectEx, sizeof(ConnectEx), &dwBytes, NULL, NULL);
|
||||
assert(rc == 0);
|
||||
|
||||
// ConnectEx requires the socket to be bound
|
||||
sockaddr_in sin = {0};
|
||||
sin.sin_family = AF_INET;
|
||||
sin.sin_addr.s_addr = INADDR_ANY;
|
||||
sin.sin_port = 0;
|
||||
if (bind(socket_, (sockaddr*)&sin, sizeof(sin))) {
|
||||
RERROR("TcpSocketWin32::DoConnect bind failed: %d", WSAGetLastError());
|
||||
CloseSocket();
|
||||
return;
|
||||
}
|
||||
|
||||
char buf[kSizeOfAddress];
|
||||
RINFO("Connecting to tcp://%s...", PrintIpAddr(endpoint_, buf));
|
||||
|
||||
state_ = STATE_CONNECTING;
|
||||
ClearOverlapped(&connect_overlapped_.overlapped);
|
||||
connect_overlapped_.queue_cb = this;
|
||||
if (!ConnectEx(socket_, (const sockaddr*)&endpoint_.sin, sizeof(endpoint_.sin), NULL, 0, NULL, &connect_overlapped_.overlapped)) {
|
||||
int err = WSAGetLastError();
|
||||
if (err != ERROR_IO_PENDING) {
|
||||
RERROR("ConnectEx failed: %d", err);
|
||||
CloseSocket();
|
||||
return;
|
||||
}
|
||||
}
|
||||
reads_active_ = 1;
|
||||
}
|
||||
|
||||
void TcpSocketWin32::DoMoreReads() {
|
||||
assert(state_ != STATE_ERROR);
|
||||
if (reads_active_ == 0) {
|
||||
// Initiate a new read, we always read into 4 buffers.
|
||||
Packet *p = network_->packet_pool().AllocPacketFromPool();
|
||||
if (!p)
|
||||
return;
|
||||
|
||||
ClearOverlapped(&p->overlapped);
|
||||
p->userdata = 0;
|
||||
p->queue_cb = this;
|
||||
DWORD flags = 0;
|
||||
WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data};
|
||||
if (WSARecv(socket_, &wsabuf, 1, NULL, &flags, &p->overlapped, NULL) != 0) {
|
||||
DWORD err = WSAGetLastError();
|
||||
if (err != ERROR_IO_PENDING) {
|
||||
RERROR("TcpSocketWin32:WSARecv failed 0x%X", err);
|
||||
FreePacket(p);
|
||||
return;
|
||||
}
|
||||
}
|
||||
reads_active_ = 1;
|
||||
}
|
||||
}
|
||||
|
||||
void TcpSocketWin32::DoMoreWrites() {
|
||||
assert(state_ != STATE_ERROR);
|
||||
if (writes_active_ == 0) {
|
||||
WSABUF wsabuf[kMaxWsaBuf];
|
||||
uint32 num_wsabuf = 0;
|
||||
|
||||
Packet *p = wqueue_;
|
||||
if (p == NULL)
|
||||
return;
|
||||
|
||||
do {
|
||||
tcp_packet_handler_.AddHeaderToOutgoingPacket(p);
|
||||
wsabuf[num_wsabuf].buf = (char*)p->data;
|
||||
wsabuf[num_wsabuf].len = (ULONG)p->size;
|
||||
packets_in_write_io_[num_wsabuf] = p;
|
||||
p = Packet_NEXT(p);
|
||||
} while (++num_wsabuf < kMaxWsaBuf && p != NULL);
|
||||
if (!(wqueue_ = p))
|
||||
wqueue_end_ = &wqueue_;
|
||||
num_wsabuf_ = (uint8)num_wsabuf;
|
||||
|
||||
p = packets_in_write_io_[0];
|
||||
ClearOverlapped(&p->overlapped);
|
||||
p->userdata = 1;
|
||||
p->queue_cb = this;
|
||||
|
||||
if (WSASend(socket_, wsabuf, num_wsabuf, NULL, 0, &p->overlapped, NULL) != 0) {
|
||||
DWORD err = WSAGetLastError();
|
||||
if (err != ERROR_IO_PENDING) {
|
||||
RERROR("TcpSocketWin32: WSASend failed 0x%X", err);
|
||||
FreePacket(p);
|
||||
CloseSocket();
|
||||
return;
|
||||
}
|
||||
}
|
||||
writes_active_ = 1;
|
||||
}
|
||||
}
|
||||
|
||||
void TcpSocketWin32::DoIO() {
|
||||
if (state_ == STATE_CONNECTED) {
|
||||
DoMoreReads();
|
||||
while (Packet *p = tcp_packet_handler_.GetNextWireguardPacket()) {
|
||||
p->protocol = endpoint_protocol_;
|
||||
p->addr = endpoint_;
|
||||
|
||||
p->queue_cb = packet_processor_->udp_queue();
|
||||
packet_processor_->ForcePost(p);
|
||||
}
|
||||
if (tcp_packet_handler_.error()) {
|
||||
CloseSocket();
|
||||
DoIO();
|
||||
return;
|
||||
}
|
||||
DoMoreWrites();
|
||||
} else if (state_ == STATE_WANT_CONNECT) {
|
||||
DoConnect();
|
||||
} else if (state_ == STATE_ERROR && !HasOutstandingIO()) {
|
||||
delete this;
|
||||
}
|
||||
}
|
||||
|
||||
bool TcpSocketWin32::HasOutstandingIO() {
|
||||
return writes_active_ + reads_active_ != 0;
|
||||
}
|
||||
|
||||
void TcpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) {
|
||||
if (qi == &connect_overlapped_) {
|
||||
assert(state_ == STATE_CONNECTING);
|
||||
reads_active_ = 0;
|
||||
if ((DWORD)qi->overlapped.Internal != 0) {
|
||||
if (state_ != STATE_ERROR) {
|
||||
RERROR("TcpSocketWin32::Connect error 0x%X", (DWORD)qi->overlapped.Internal);
|
||||
CloseSocket();
|
||||
}
|
||||
} else {
|
||||
state_ = STATE_CONNECTED;
|
||||
}
|
||||
return;
|
||||
}
|
||||
Packet *p = static_cast<Packet*>(qi);
|
||||
if (p->userdata == 0) {
|
||||
// Read operation complete
|
||||
if ((DWORD)p->overlapped.Internal != 0) {
|
||||
if (state_ != STATE_ERROR) {
|
||||
RERROR("TcpSocketWin32::Read error 0x%X", (DWORD)p->overlapped.Internal);
|
||||
CloseSocket();
|
||||
}
|
||||
network_->packet_pool().FreePacketToPool(p);
|
||||
// What to do?
|
||||
} else if ((int)p->overlapped.InternalHigh == 0) {
|
||||
// Socket closed successfully
|
||||
CloseSocket();
|
||||
network_->packet_pool().FreePacketToPool(p);
|
||||
} else {
|
||||
// Queue it up to rqueue
|
||||
p->size = (int)p->overlapped.InternalHigh;
|
||||
tcp_packet_handler_.QueueIncomingPacket(p);
|
||||
}
|
||||
reads_active_--;
|
||||
} else {
|
||||
assert(writes_active_);
|
||||
assert(packets_in_write_io_[0] == p);
|
||||
|
||||
if ((DWORD)p->overlapped.Internal != 0) {
|
||||
if (state_ != STATE_ERROR) {
|
||||
RERROR("TcpSocketWin32::Write error 0x%X", (DWORD)p->overlapped.Internal);
|
||||
CloseSocket();
|
||||
}
|
||||
}
|
||||
// free all the packets involved in the write
|
||||
for (size_t i = 0; i < num_wsabuf_; i++)
|
||||
network_->packet_pool().FreePacketToPool(packets_in_write_io_[i]);
|
||||
writes_active_--;
|
||||
}
|
||||
}
|
||||
|
||||
void TcpSocketWin32::OnQueuedItemDelete(QueuedItem *qi) {
|
||||
if (qi == &connect_overlapped_) {
|
||||
reads_active_ = 0;
|
||||
return;
|
||||
}
|
||||
Packet *p = static_cast<Packet*>(qi);
|
||||
if (p->userdata == 0) {
|
||||
FreePacket(p);
|
||||
reads_active_--;
|
||||
} else {
|
||||
for (size_t i = 0; i < num_wsabuf_; i++)
|
||||
network_->packet_pool().FreePacketToPool(packets_in_write_io_[i]);
|
||||
writes_active_--;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TcpSocketQueue::TcpSocketQueue(NetworkWin32 *network) {
|
||||
network_ = network;
|
||||
wqueue_ = NULL;
|
||||
wqueue_end_ = &wqueue_;
|
||||
queued_item_.queue_cb = this;
|
||||
packet_handler_ = NULL;
|
||||
}
|
||||
|
||||
TcpSocketQueue::~TcpSocketQueue() {
|
||||
FreePacketList(wqueue_);
|
||||
}
|
||||
|
||||
void TcpSocketQueue::TransmitOnePacket(Packet *packet) {
|
||||
// Check if we have a tcp connection for the endpoint, otherwise create one.
|
||||
for (TcpSocketWin32 *tcp = network_->tcp_socket_; tcp; tcp = tcp->next_) {
|
||||
// After we send 3 handshakes on a tcp socket in a row, then close and reopen the socket because it seems defunct.
|
||||
if (CompareIpAddr(&tcp->endpoint_, &packet->addr) == 0 && tcp->endpoint_protocol_ == packet->protocol) {
|
||||
if (ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION) {
|
||||
if (tcp->handshake_attempts == 2) {
|
||||
RINFO("Making new Tcp socket due to too many handshake failures");
|
||||
tcp->CloseSocket();
|
||||
break;
|
||||
}
|
||||
tcp->handshake_attempts++;
|
||||
} else {
|
||||
tcp->handshake_attempts = -1;
|
||||
}
|
||||
tcp->WritePacket(packet);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Drop tcp packet that's for an incoming connection, or packets that are
|
||||
// not a handshake.
|
||||
if ((packet->protocol & kPacketProtocolIncomingConnection) ||
|
||||
packet->size < 4 || ReadLE32(packet->data) != MESSAGE_HANDSHAKE_INITIATION) {
|
||||
FreePacket(packet);
|
||||
return;
|
||||
}
|
||||
|
||||
// Initialize a new tcp socket and connect to the endpoint
|
||||
TcpSocketWin32 *tcp = new TcpSocketWin32(network_);
|
||||
tcp->state_ = TcpSocketWin32::STATE_WANT_CONNECT;
|
||||
tcp->endpoint_ = packet->addr;
|
||||
tcp->endpoint_protocol_ = kPacketProtocolTcp;
|
||||
tcp->SetPacketHandler(packet_handler_);
|
||||
tcp->WritePacket(packet);
|
||||
}
|
||||
|
||||
void TcpSocketQueue::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) {
|
||||
wqueue_mutex_.Acquire();
|
||||
Packet *packet = wqueue_;
|
||||
wqueue_ = NULL;
|
||||
wqueue_end_ = &wqueue_;
|
||||
wqueue_mutex_.Release();
|
||||
while (packet)
|
||||
TransmitOnePacket(exch(packet, Packet_NEXT(packet)));
|
||||
}
|
||||
|
||||
void TcpSocketQueue::OnQueuedItemDelete(QueuedItem *ow) {
|
||||
|
||||
}
|
||||
|
||||
void TcpSocketQueue::WritePacket(Packet *packet) {
|
||||
packet->queue_next = NULL;
|
||||
wqueue_mutex_.Acquire();
|
||||
Packet *was_empty = wqueue_;
|
||||
*wqueue_end_ = packet;
|
||||
wqueue_end_ = &Packet_NEXT(packet);
|
||||
wqueue_mutex_.Release();
|
||||
if (was_empty == NULL)
|
||||
network_->PostQueuedItem(&queued_item_);
|
||||
}
|
124
network_win32_tcp.h
Normal file
124
network_win32_tcp.h
Normal file
|
@ -0,0 +1,124 @@
|
|||
// SPDX-License-Identifier: AGPL-1.0-only
|
||||
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "netapi.h"
|
||||
#include "network_common.h"
|
||||
#include "tunsafe_threading.h"
|
||||
|
||||
class NetworkWin32;
|
||||
class PacketProcessor;
|
||||
|
||||
class TcpSocketWin32 : public QueuedItemCallback {
|
||||
friend class NetworkWin32;
|
||||
friend class TcpSocketQueue;
|
||||
public:
|
||||
explicit TcpSocketWin32(NetworkWin32 *network);
|
||||
~TcpSocketWin32();
|
||||
|
||||
void SetPacketHandler(PacketProcessor *packet_handler) { packet_processor_ = packet_handler; }
|
||||
|
||||
// Write a packet to the TCP socket. This may be called only from the
|
||||
// wireguard thread. Will append to a buffer and schedule it to be written
|
||||
// from the network thread.
|
||||
void WritePacket(Packet *packet);
|
||||
|
||||
// Call from IO completion thread to cancel all outstanding IO
|
||||
void CancelAllIO();
|
||||
|
||||
// Call from IO completion thread to run more IO
|
||||
void DoIO();
|
||||
|
||||
// Returns true if there's IO still left to run
|
||||
bool HasOutstandingIO();
|
||||
|
||||
private:
|
||||
void DoMoreReads();
|
||||
void DoMoreWrites();
|
||||
void DoConnect();
|
||||
|
||||
void CloseSocket();
|
||||
|
||||
// From OverlappedCallbacks
|
||||
virtual void OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) override;
|
||||
virtual void OnQueuedItemDelete(QueuedItem *qi) override;
|
||||
|
||||
// Network subsystem
|
||||
NetworkWin32 *network_;
|
||||
|
||||
PacketProcessor *packet_processor_;
|
||||
|
||||
enum {
|
||||
STATE_NONE = 0,
|
||||
STATE_ERROR = 1,
|
||||
STATE_CONNECTING = 2,
|
||||
STATE_CONNECTED = 3,
|
||||
STATE_WANT_CONNECT = 4,
|
||||
};
|
||||
|
||||
uint8 reads_active_;
|
||||
uint8 writes_active_;
|
||||
uint8 state_;
|
||||
uint8 num_wsabuf_;
|
||||
|
||||
public:
|
||||
uint8 handshake_attempts;
|
||||
private:
|
||||
|
||||
// The handle to the socket
|
||||
SOCKET socket_;
|
||||
|
||||
// Packets taken over by the network thread waiting to be written,
|
||||
// when these are written we'll start eating from wqueue_
|
||||
Packet *pending_writes_;
|
||||
|
||||
// All packets queued for writing on the network thread.
|
||||
Packet *wqueue_, **wqueue_end_;
|
||||
|
||||
// Linked list of all TcpSocketWin32 wsockets
|
||||
TcpSocketWin32 *next_;
|
||||
|
||||
// Handles packet parsing
|
||||
TcpPacketHandler tcp_packet_handler_;
|
||||
|
||||
// An overlapped instance used for the initial Connect() call.
|
||||
QueuedItem connect_overlapped_;
|
||||
|
||||
IpAddr endpoint_;
|
||||
uint8 endpoint_protocol_;
|
||||
|
||||
// Packets currently involved in the wsabuf writing
|
||||
enum { kMaxWsaBuf = 32 };
|
||||
Packet *packets_in_write_io_[kMaxWsaBuf];
|
||||
};
|
||||
|
||||
class TcpSocketQueue : public QueuedItemCallback {
|
||||
public:
|
||||
explicit TcpSocketQueue(NetworkWin32 *network);
|
||||
~TcpSocketQueue();
|
||||
|
||||
void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
|
||||
|
||||
virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override;
|
||||
virtual void OnQueuedItemDelete(QueuedItem *ow) override;
|
||||
|
||||
void WritePacket(Packet *packet);
|
||||
|
||||
private:
|
||||
void TransmitOnePacket(Packet *packet);
|
||||
NetworkWin32 *network_;
|
||||
|
||||
// All packets queued for writing on the network thread. Locked by |wqueue_mutex_|
|
||||
Packet *wqueue_, **wqueue_end_;
|
||||
|
||||
PacketProcessor *packet_handler_;
|
||||
|
||||
// Protects wqueue_
|
||||
Mutex wqueue_mutex_;
|
||||
|
||||
// Used for queueing things on the network instance
|
||||
QueuedItem queued_item_;
|
||||
|
||||
};
|
||||
|
|
@ -23,7 +23,7 @@
|
|||
|
||||
#if defined(WITH_NETWORK_BSD)
|
||||
#include "network_bsd.cpp"
|
||||
#include "network_bsd_common.cpp"
|
||||
#include "tunsafe_bsd.cpp"
|
||||
#include "ts.cpp"
|
||||
#include "benchmark.cpp"
|
||||
#endif
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// SPDX-License-Identifier: AGPL-1.0-only
|
||||
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
|
||||
#include "network_bsd_common.h"
|
||||
#include "tunsafe_bsd.h"
|
||||
#include "tunsafe_endian.h"
|
||||
#include "util.h"
|
||||
|
||||
|
@ -43,17 +43,6 @@
|
|||
#include <limits.h>
|
||||
#endif
|
||||
|
||||
void tunsafe_die(const char *msg) {
|
||||
fprintf(stderr, "%s\n", msg);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
void SetThreadName(const char *name) {
|
||||
#if defined(OS_LINUX)
|
||||
prctl(PR_SET_NAME, name, 0, 0, 0);
|
||||
#endif // defined(OS_LINUX)
|
||||
}
|
||||
|
||||
#if defined(OS_MACOSX) || defined(OS_FREEBSD)
|
||||
struct MyRouteMsg {
|
||||
struct rt_msghdr hdr;
|
||||
|
@ -346,21 +335,7 @@ int open_tun(char *devname, size_t devname_size) {
|
|||
}
|
||||
#endif
|
||||
|
||||
int open_udp(int listen_on_port) {
|
||||
int udp_fd = socket(AF_INET, SOCK_DGRAM, 0);
|
||||
if (udp_fd < 0) return udp_fd;
|
||||
sockaddr_in sin = {0};
|
||||
sin.sin_family = AF_INET;
|
||||
sin.sin_port = htons(listen_on_port);
|
||||
if (bind(udp_fd, (struct sockaddr*)&sin, sizeof(sin)) != 0) {
|
||||
close(udp_fd);
|
||||
return -1;
|
||||
}
|
||||
return udp_fd;
|
||||
}
|
||||
|
||||
TunsafeBackendBsd::TunsafeBackendBsd()
|
||||
: processor_(NULL) {
|
||||
TunsafeBackendBsd::TunsafeBackendBsd() {
|
||||
devname_[0] = 0;
|
||||
tun_interface_gone_ = false;
|
||||
}
|
||||
|
@ -579,122 +554,33 @@ bool TunsafeBackendBsd::RunPrePostCommand(const std::vector<std::string> &vec) {
|
|||
return success;
|
||||
}
|
||||
|
||||
#if defined(OS_LINUX)
|
||||
UnixSocketDeletionWatcher::UnixSocketDeletionWatcher()
|
||||
: inotify_fd_(-1) {
|
||||
pipes_[0] = -1;
|
||||
pipes_[0] = -1;
|
||||
}
|
||||
|
||||
UnixSocketDeletionWatcher::~UnixSocketDeletionWatcher() {
|
||||
close(inotify_fd_);
|
||||
close(pipes_[0]);
|
||||
close(pipes_[1]);
|
||||
}
|
||||
|
||||
bool UnixSocketDeletionWatcher::Start(const char *path, bool *flag_to_set) {
|
||||
assert(inotify_fd_ == -1);
|
||||
path_ = path;
|
||||
flag_to_set_ = flag_to_set;
|
||||
pid_ = getpid();
|
||||
inotify_fd_ = inotify_init1(IN_CLOEXEC | IN_NONBLOCK);
|
||||
if (inotify_fd_ == -1) {
|
||||
perror("inotify_init1() failed");
|
||||
return false;
|
||||
}
|
||||
if (inotify_add_watch(inotify_fd_, "/var/run/wireguard", IN_DELETE | IN_DELETE_SELF) == -1) {
|
||||
perror("inotify_add_watch failed");
|
||||
return false;
|
||||
}
|
||||
if (pipe(pipes_) == -1) {
|
||||
perror("pipe() failed");
|
||||
return false;
|
||||
}
|
||||
return pthread_create(&thread_, NULL, &UnixSocketDeletionWatcher::RunThread, this) == 0;
|
||||
}
|
||||
|
||||
void UnixSocketDeletionWatcher::Stop() {
|
||||
RINFO("Stopping..");
|
||||
void *retval;
|
||||
write(pipes_[1], "", 1);
|
||||
pthread_join(thread_, &retval);
|
||||
}
|
||||
|
||||
void *UnixSocketDeletionWatcher::RunThread(void *arg) {
|
||||
UnixSocketDeletionWatcher *self = (UnixSocketDeletionWatcher*)arg;
|
||||
return self->RunThreadInner();
|
||||
}
|
||||
|
||||
void *UnixSocketDeletionWatcher::RunThreadInner() {
|
||||
char buf[sizeof(struct inotify_event) + NAME_MAX + 1]
|
||||
__attribute__ ((aligned(__alignof__(struct inotify_event))));
|
||||
fd_set fdset;
|
||||
struct stat st;
|
||||
for(;;) {
|
||||
if (lstat(path_, &st) == -1 && errno == ENOENT) {
|
||||
RINFO("Unix socket %s deleted.", path_);
|
||||
*flag_to_set_ = true;
|
||||
kill(pid_, SIGALRM);
|
||||
break;
|
||||
}
|
||||
FD_ZERO(&fdset);
|
||||
FD_SET(inotify_fd_, &fdset);
|
||||
FD_SET(pipes_[0], &fdset);
|
||||
int n = select(std::max(inotify_fd_, pipes_[0]) + 1, &fdset, NULL, NULL, NULL);
|
||||
if (n == -1) {
|
||||
perror("select");
|
||||
break;
|
||||
}
|
||||
if (FD_ISSET(inotify_fd_, &fdset)) {
|
||||
ssize_t len = read(inotify_fd_, buf, sizeof(buf));
|
||||
if (len == -1) {
|
||||
perror("read");
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (FD_ISSET(pipes_[0], &fdset))
|
||||
break;
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
#else // !defined(OS_LINUX)
|
||||
|
||||
bool UnixSocketDeletionWatcher::Poll(const char *path) {
|
||||
struct stat st;
|
||||
return lstat(path, &st) == -1 && errno == ENOENT;
|
||||
}
|
||||
|
||||
#endif // !defined(OS_LINUX)
|
||||
|
||||
static TunsafeBackendBsd *g_tunsafe_backend_bsd;
|
||||
|
||||
static void SigAlrm(int sig) {
|
||||
if (g_tunsafe_backend_bsd)
|
||||
g_tunsafe_backend_bsd->HandleSigAlrm();
|
||||
}
|
||||
|
||||
static SignalCatcher *g_signal_catcher;
|
||||
static bool did_ctrlc;
|
||||
|
||||
void SigInt(int sig) {
|
||||
void SignalCatcher::SigAlrm(int sig) {
|
||||
if (g_signal_catcher)
|
||||
*g_signal_catcher->sigalarm_flag_ = true;
|
||||
}
|
||||
|
||||
void SignalCatcher::SigInt(int sig) {
|
||||
if (did_ctrlc)
|
||||
exit(1);
|
||||
did_ctrlc = true;
|
||||
write(1, "Ctrl-C detected. Exiting. Press again to force quit.\n", sizeof("Ctrl-C detected. Exiting. Press again to force quit.\n")-1);
|
||||
|
||||
write(1, "Ctrl-C detected. Exiting. Press again to force quit.\n", sizeof("Ctrl-C detected. Exiting. Press again to force quit.\n") - 1);
|
||||
// todo: fix signal safety?
|
||||
if (g_tunsafe_backend_bsd)
|
||||
g_tunsafe_backend_bsd->HandleExit();
|
||||
if (g_signal_catcher)
|
||||
*g_signal_catcher->exit_flag_ = true;
|
||||
}
|
||||
|
||||
void TunsafeBackendBsd::RunLoop() {
|
||||
assert(!g_tunsafe_backend_bsd);
|
||||
assert(processor_);
|
||||
SignalCatcher::SignalCatcher(bool *exit_flag, bool *sigalarm_flag) {
|
||||
assert(g_signal_catcher == NULL);
|
||||
exit_flag_ = exit_flag;
|
||||
sigalarm_flag_ = sigalarm_flag;
|
||||
g_signal_catcher = this;
|
||||
|
||||
sigset_t mask;
|
||||
|
||||
g_tunsafe_backend_bsd = this;
|
||||
|
||||
// We want an alarm signal every second.
|
||||
{
|
||||
struct sigaction act = {0};
|
||||
|
@ -713,7 +599,6 @@ void TunsafeBackendBsd::RunLoop() {
|
|||
return;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(OS_LINUX) || defined(OS_FREEBSD)
|
||||
sigemptyset(&mask);
|
||||
sigaddset(&mask, SIGALRM);
|
||||
|
@ -747,32 +632,166 @@ void TunsafeBackendBsd::RunLoop() {
|
|||
#elif defined(OS_MACOSX)
|
||||
ualarm(1000000, 1000000);
|
||||
#endif
|
||||
}
|
||||
|
||||
RunLoopInner();
|
||||
|
||||
g_tunsafe_backend_bsd = NULL;
|
||||
SignalCatcher::~SignalCatcher() {
|
||||
g_signal_catcher = NULL;
|
||||
}
|
||||
|
||||
void InitCpuFeatures();
|
||||
void Benchmark();
|
||||
|
||||
|
||||
const char *print_ip(char buf[kSizeOfAddress], in_addr_t ip) {
|
||||
snprintf(buf, kSizeOfAddress, "%d.%d.%d.%d", (ip >> 24) & 0xff, (ip >> 16) & 0xff, (ip >> 8) & 0xff, (ip >> 0) & 0xff);
|
||||
return buf;
|
||||
}
|
||||
|
||||
class MyProcessorDelegate : public ProcessorDelegate {
|
||||
class TunsafeBackendBsdImpl : public TunsafeBackendBsd, public NetworkBsd::NetworkBsdDelegate, public ProcessorDelegate {
|
||||
public:
|
||||
MyProcessorDelegate() {
|
||||
wg_processor_ = NULL;
|
||||
is_connected_ = false;
|
||||
}
|
||||
TunsafeBackendBsdImpl();
|
||||
virtual ~TunsafeBackendBsdImpl();
|
||||
|
||||
virtual void OnConnected() override {
|
||||
void RunLoop();
|
||||
virtual bool InitializeTun(char devname[16]) override;
|
||||
|
||||
// -- from TunInterface
|
||||
virtual void WriteTunPacket(Packet *packet) override;
|
||||
|
||||
// -- from UdpInterface
|
||||
virtual bool Configure(int listen_port_udp, int listen_port_tcp) override;
|
||||
virtual void WriteUdpPacket(Packet *packet) override;
|
||||
|
||||
// -- from NetworkBsdDelegate
|
||||
virtual void OnSecondLoop(uint64 now) override;
|
||||
virtual void RunAllMainThreadScheduled() override;
|
||||
|
||||
// -- from ProcessorDelegate
|
||||
virtual void OnConnected() override;
|
||||
virtual void OnConnectionRetry(uint32 attempts) override;
|
||||
|
||||
WireguardProcessor *processor() { return &processor_; }
|
||||
|
||||
private:
|
||||
void WriteTcpPacket(Packet *packet);
|
||||
|
||||
// Close all TCP connections that are not pointed to by any of the peer endpoint.
|
||||
void CloseOrphanTcpConnections();
|
||||
|
||||
bool is_connected_;
|
||||
uint8 close_orphan_counter_;
|
||||
WireguardProcessor processor_;
|
||||
NetworkBsd network_;
|
||||
TunSocketBsd tun_;
|
||||
UdpSocketBsd udp_;
|
||||
UnixDomainSocketListenerBsd unix_socket_listener_;
|
||||
TcpSocketListenerBsd tcp_socket_listener_;
|
||||
};
|
||||
|
||||
TunsafeBackendBsdImpl::TunsafeBackendBsdImpl()
|
||||
: is_connected_(false),
|
||||
close_orphan_counter_(0),
|
||||
processor_(this, this, this),
|
||||
network_(this, 1000),
|
||||
tun_(&network_, &processor_),
|
||||
udp_(&network_, &processor_),
|
||||
unix_socket_listener_(&network_, &processor_),
|
||||
tcp_socket_listener_(&network_, &processor_) {
|
||||
}
|
||||
|
||||
TunsafeBackendBsdImpl::~TunsafeBackendBsdImpl() {
|
||||
}
|
||||
|
||||
bool TunsafeBackendBsdImpl::InitializeTun(char devname[16]) {
|
||||
int tun_fd = open_tun(devname, 16);
|
||||
if (tun_fd < 0) { RERROR("Error opening tun device"); return false; }
|
||||
if (!tun_.Initialize(tun_fd)) {
|
||||
close(tun_fd);
|
||||
return false;
|
||||
}
|
||||
unix_socket_listener_.Initialize(devname);
|
||||
return true;
|
||||
}
|
||||
|
||||
void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) {
|
||||
tun_.WritePacket(packet);
|
||||
}
|
||||
|
||||
// Called to initialize udp
|
||||
bool TunsafeBackendBsdImpl::Configure(int listen_port, int listen_port_tcp) {
|
||||
return udp_.Initialize(listen_port) &&
|
||||
(listen_port_tcp == 0 || tcp_socket_listener_.Initialize(listen_port_tcp));
|
||||
}
|
||||
|
||||
void TunsafeBackendBsdImpl::WriteTcpPacket(Packet *packet) {
|
||||
// Check if we have a tcp connection for the endpoint, otherwise create one.
|
||||
for (TcpSocketBsd *tcp = network_.tcp_sockets(); tcp; tcp = tcp->next()) {
|
||||
// After we send 3 handshakes on a tcp socket in a row, then close and reopen the socket because it seems defunct.
|
||||
if (CompareIpAddr(&tcp->endpoint(), &packet->addr) == 0 && tcp->endpoint_protocol() == packet->protocol) {
|
||||
if (ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION) {
|
||||
if (tcp->handshake_attempts == 2) {
|
||||
RINFO("Making new Tcp socket due to too many handshake failures");
|
||||
delete tcp;
|
||||
break;
|
||||
}
|
||||
tcp->handshake_attempts++;
|
||||
} else {
|
||||
tcp->handshake_attempts = -1;
|
||||
}
|
||||
tcp->WritePacket(packet);
|
||||
return;
|
||||
}
|
||||
}
|
||||
// Drop tcp packet that's for an incoming connection, or packets that are
|
||||
// not a handshake.
|
||||
if ((packet->protocol & kPacketProtocolIncomingConnection) ||
|
||||
ReadLE32(packet->data) != MESSAGE_HANDSHAKE_INITIATION) {
|
||||
FreePacket(packet);
|
||||
return;
|
||||
}
|
||||
// Initialize a new tcp socket and connect to the endpoint
|
||||
TcpSocketBsd *tcp = new TcpSocketBsd(&network_, &processor_);
|
||||
if (!tcp || !tcp->InitializeOutgoing(packet->addr)) {
|
||||
delete tcp;
|
||||
FreePacket(packet);
|
||||
return;
|
||||
}
|
||||
tcp->WritePacket(packet);
|
||||
}
|
||||
|
||||
void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) {
|
||||
assert((packet->protocol & 0x7F) <= 2);
|
||||
if (packet->protocol & kPacketProtocolTcp) {
|
||||
WriteTcpPacket(packet);
|
||||
} else {
|
||||
udp_.WritePacket(packet);
|
||||
}
|
||||
}
|
||||
|
||||
void TunsafeBackendBsdImpl::RunLoop() {
|
||||
if (!unix_socket_listener_.Start(network_.exit_flag()))
|
||||
return;
|
||||
|
||||
SignalCatcher signal_catcher(network_.exit_flag(), network_.sigalarm_flag());
|
||||
network_.RunLoop(&signal_catcher.orig_signal_mask_);
|
||||
unix_socket_listener_.Stop();
|
||||
|
||||
tun_interface_gone_ = tun_.tun_interface_gone();
|
||||
}
|
||||
|
||||
void TunsafeBackendBsdImpl::OnSecondLoop(uint64 now) {
|
||||
if (!(close_orphan_counter_++ & 0xF))
|
||||
CloseOrphanTcpConnections();
|
||||
processor_.SecondLoop();
|
||||
}
|
||||
|
||||
void TunsafeBackendBsdImpl::RunAllMainThreadScheduled() {
|
||||
processor_.RunAllMainThreadScheduled();
|
||||
}
|
||||
|
||||
void TunsafeBackendBsdImpl::OnConnected() {
|
||||
if (!is_connected_) {
|
||||
const WgCidrAddr *ipv4_addr = NULL;
|
||||
for (const WgCidrAddr &x : wg_processor_->addr()) {
|
||||
for (const WgCidrAddr &x : processor_.addr()) {
|
||||
if (x.size == 32) { ipv4_addr = &x; break; }
|
||||
}
|
||||
uint32 ipv4_ip = ipv4_addr ? ReadBE32(ipv4_addr->addr) : 0;
|
||||
|
@ -780,17 +799,41 @@ public:
|
|||
RINFO("Connection established. IP %s", ipv4_ip ? print_ip(buf, ipv4_ip) : "(none)");
|
||||
is_connected_ = true;
|
||||
}
|
||||
}
|
||||
virtual void OnConnectionRetry(uint32 attempts) override {
|
||||
}
|
||||
|
||||
void TunsafeBackendBsdImpl::OnConnectionRetry(uint32 attempts) {
|
||||
if (is_connected_ && attempts >= 3) {
|
||||
is_connected_ = false;
|
||||
RINFO("Reconnecting...");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
WireguardProcessor *wg_processor_;
|
||||
bool is_connected_;
|
||||
};
|
||||
void TunsafeBackendBsdImpl::CloseOrphanTcpConnections() {
|
||||
// Add all incoming tcp connections into a lookup table
|
||||
WG_HASHTABLE_IMPL<WgAddrEntry::IpPort, void*, WgAddrEntry::IpPortHasher> lookup;
|
||||
for (TcpSocketBsd *tcp = network_.tcp_sockets(); tcp; tcp = tcp->next()) {
|
||||
if (tcp->endpoint_protocol() == (kPacketProtocolTcp | kPacketProtocolIncomingConnection)) {
|
||||
// Avoid deleting tcp sockets that were just born.
|
||||
if (tcp->age == 0) {
|
||||
tcp->age = 1;
|
||||
} else {
|
||||
lookup[ConvertIpAddrToAddrX(tcp->endpoint())] = tcp;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (lookup.empty())
|
||||
return;
|
||||
// For each peer, check if it has an endpoint that matches
|
||||
// an entry in the lookup table, and delete it from the lookup
|
||||
// table.
|
||||
for(WgPeer *peer = processor_.dev().first_peer(); peer; peer = peer->next_peer()) {
|
||||
if (peer->endpoint_protocol() == (kPacketProtocolTcp | kPacketProtocolIncomingConnection))
|
||||
lookup.erase(ConvertIpAddrToAddrX(peer->endpoint()));
|
||||
}
|
||||
// The tcp connections that are still in the hashtable can be deleted
|
||||
for(const auto &it : lookup)
|
||||
delete (TcpSocketBsd *)it.second;
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
CommandLineOutput cmd = {0};
|
||||
|
@ -812,20 +855,15 @@ int main(int argc, char **argv) {
|
|||
|
||||
SetThreadName("tunsafe-m");
|
||||
|
||||
MyProcessorDelegate my_procdel;
|
||||
TunsafeBackendBsd *backend = CreateTunsafeBackendBsd();
|
||||
TunsafeBackendBsdImpl backend;
|
||||
if (cmd.interface_name)
|
||||
backend->SetTunDeviceName(cmd.interface_name);
|
||||
|
||||
WireguardProcessor wg(backend, backend, &my_procdel);
|
||||
|
||||
my_procdel.wg_processor_ = &wg;
|
||||
backend->SetProcessor(&wg);
|
||||
backend.SetTunDeviceName(cmd.interface_name);
|
||||
|
||||
DnsResolver dns_resolver(NULL);
|
||||
if (*cmd.filename_to_load && !ParseWireGuardConfigFile(&wg, cmd.filename_to_load, &dns_resolver))
|
||||
if (*cmd.filename_to_load && !ParseWireGuardConfigFile(backend.processor(), cmd.filename_to_load, &dns_resolver))
|
||||
return 1;
|
||||
if (!backend.processor()->Start())
|
||||
return 1;
|
||||
if (!wg.Start()) return 1;
|
||||
|
||||
if (cmd.daemon) {
|
||||
fprintf(stderr, "Switching to daemon mode...\n");
|
||||
|
@ -833,9 +871,8 @@ int main(int argc, char **argv) {
|
|||
perror("daemon() failed");
|
||||
}
|
||||
|
||||
backend->RunLoop();
|
||||
backend->CleanupRoutes();
|
||||
delete backend;
|
||||
backend.RunLoop();
|
||||
backend.CleanupRoutes();
|
||||
|
||||
return 0;
|
||||
}
|
60
tunsafe_bsd.h
Normal file
60
tunsafe_bsd.h
Normal file
|
@ -0,0 +1,60 @@
|
|||
// SPDX-License-Identifier: AGPL-1.0-only
|
||||
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
|
||||
#ifndef TUNSAFE_NETWORK_BSD_COMMON_H_
|
||||
#define TUNSAFE_NETWORK_BSD_COMMON_H_
|
||||
|
||||
#include "netapi.h"
|
||||
#include "wireguard.h"
|
||||
#include "wireguard_config.h"
|
||||
#include <string>
|
||||
#include <signal.h>
|
||||
|
||||
struct RouteInfo {
|
||||
uint8 family;
|
||||
uint8 cidr;
|
||||
uint8 ip[16];
|
||||
uint8 gw[16];
|
||||
std::string dev;
|
||||
};
|
||||
|
||||
class SignalCatcher {
|
||||
public:
|
||||
SignalCatcher(bool *exit_flag, bool *sigalarm_flag);
|
||||
~SignalCatcher();
|
||||
|
||||
sigset_t orig_signal_mask_;
|
||||
private:
|
||||
static void SigAlrm(int sig);
|
||||
static void SigInt(int sig);
|
||||
bool *exit_flag_;
|
||||
bool *sigalarm_flag_;
|
||||
};
|
||||
|
||||
class TunsafeBackendBsd : public TunInterface, public UdpInterface {
|
||||
public:
|
||||
TunsafeBackendBsd();
|
||||
virtual ~TunsafeBackendBsd();
|
||||
|
||||
void CleanupRoutes();
|
||||
|
||||
void SetTunDeviceName(const char *name);
|
||||
|
||||
// -- from TunInterface
|
||||
virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
|
||||
|
||||
protected:
|
||||
virtual bool InitializeTun(char devname[16]) = 0;
|
||||
|
||||
void AddRoute(uint32 ip, uint32 cidr, uint32 gw, const char *dev);
|
||||
void DelRoute(const RouteInfo &cd);
|
||||
bool AddRoute(int family, const void *dest, int dest_prefix, const void *gateway, const char *dev);
|
||||
bool RunPrePostCommand(const std::vector<std::string> &vec);
|
||||
|
||||
std::vector<RouteInfo> cleanup_commands_;
|
||||
std::vector<std::string> pre_down_, post_down_;
|
||||
std::vector<WgCidrAddr> addresses_to_remove_;
|
||||
char devname_[16];
|
||||
bool tun_interface_gone_;
|
||||
};
|
||||
|
||||
#endif // TUNSAFE_NETWORK_BSD_COMMON_H_
|
|
@ -28,6 +28,7 @@ WireguardProcessor::WireguardProcessor(UdpInterface *udp, TunInterface *tun, Pro
|
|||
mtu_ = 1420;
|
||||
memset(&stats_, 0, sizeof(stats_));
|
||||
listen_port_ = 0;
|
||||
listen_port_tcp_ = 0;
|
||||
network_discovery_spoofing_ = false;
|
||||
add_routes_mode_ = true;
|
||||
dns_blocking_ = true;
|
||||
|
@ -50,6 +51,16 @@ void WireguardProcessor::SetListenPort(int listen_port) {
|
|||
}
|
||||
}
|
||||
|
||||
void WireguardProcessor::SetListenPortTcp(int listen_port) {
|
||||
if (listen_port_tcp_ != listen_port) {
|
||||
listen_port_tcp_ = listen_port;
|
||||
if (is_started_ && !ConfigureUdp()) {
|
||||
RINFO("ConfigureUdp failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void WireguardProcessor::AddDnsServer(const IpAddr &sin) {
|
||||
dns_addr_.push_back(sin);
|
||||
}
|
||||
|
@ -126,7 +137,7 @@ bool WireguardProcessor::Start() {
|
|||
|
||||
bool WireguardProcessor::ConfigureUdp() {
|
||||
assert(dev_.IsMainThread());
|
||||
return udp_->Configure(listen_port_);
|
||||
return udp_->Configure(listen_port_, listen_port_tcp_);
|
||||
}
|
||||
|
||||
bool WireguardProcessor::ConfigureTun() {
|
||||
|
|
|
@ -81,6 +81,8 @@ public:
|
|||
~WireguardProcessor();
|
||||
|
||||
void SetListenPort(int listen_port);
|
||||
void SetListenPortTcp(int listen_port);
|
||||
|
||||
void AddDnsServer(const IpAddr &sin);
|
||||
bool SetTunAddress(const WgCidrAddr &addr);
|
||||
void ClearTunAddress();
|
||||
|
@ -132,6 +134,7 @@ private:
|
|||
UdpInterface *udp_;
|
||||
|
||||
uint16 listen_port_;
|
||||
uint16 listen_port_tcp_;
|
||||
uint16 mtu_;
|
||||
|
||||
bool dns_blocking_;
|
||||
|
|
|
@ -104,6 +104,8 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
|
|||
wg_->dev().SetPrivateKey(binkey);
|
||||
} else if (strcmp(key, "ListenPort") == 0) {
|
||||
wg_->SetListenPort(atoi(value));
|
||||
} else if (strcmp(key, "ListenPortTCP") == 0) {
|
||||
wg_->SetListenPortTcp(atoi(value));
|
||||
} else if (strcmp(key, "Address") == 0) {
|
||||
SplitString(value, ',', &ss);
|
||||
for (size_t i = 0; i < ss.size(); i++) {
|
||||
|
|
Loading…
Reference in a new issue