Experimental support for WireGuard over TCP

Ludvig Strigeus 2018-11-17 19:14:05 +01:00
parent 9a8acb7091
commit a03980e74b
20 changed files with 2648 additions and 1016 deletions

View file

@ -185,10 +185,24 @@
<ClInclude Include="crypto\blake2s\blake2s-sse-impl.h" />
<ClInclude Include="crypto\curve25519\curve25519-donna.h" />
<ClInclude Include="ip_to_peer_map.h" />
<ClInclude Include="network_bsd.h">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</ClInclude>
<ClInclude Include="network_common.h" />
<ClInclude Include="network_win32_tcp.h" />
<ClInclude Include="service_pipe_win32.h" />
<ClInclude Include="service_win32.h" />
<ClInclude Include="service_win32_api.h" />
<ClInclude Include="service_win32_constants.h" />
<ClInclude Include="tunsafe_bsd.h">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</ClInclude>
<ClInclude Include="tunsafe_config.h" />
<ClInclude Include="tunsafe_cpu.h" />
<ClInclude Include="crypto\aesgcm\aes.h" />
@ -215,8 +229,28 @@
<ItemGroup>
<ClCompile Include="benchmark.cpp" />
<ClCompile Include="ip_to_peer_map.cpp" />
<ClCompile Include="network_bsd.cpp">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</ClCompile>
<ClCompile Include="network_bsd_mt.cpp">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</ClCompile>
<ClCompile Include="network_common.cpp" />
<ClCompile Include="network_win32_tcp.cpp" />
<ClCompile Include="service_pipe_win32.cpp" />
<ClCompile Include="service_win32.cpp" />
<ClCompile Include="tunsafe_bsd.cpp">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</ClCompile>
<ClCompile Include="tunsafe_cpu.cpp" />
<ClCompile Include="crypto\aesgcm\aesgcm.cpp" />
<ClCompile Include="crypto\siphash\siphash.cpp" />

View file

@ -23,6 +23,12 @@
<Filter Include="crypto\chacha20poly1305">
<UniqueIdentifier>{1ca37c7b-e91e-4648-9584-7d0c73d8e416}</UniqueIdentifier>
</Filter>
<Filter Include="Source Files\BSD">
<UniqueIdentifier>{4b2f2fd9-780e-45db-8fe1-f03079439723}</UniqueIdentifier>
</Filter>
<Filter Include="crypto\siphash">
<UniqueIdentifier>{0f45e1a0-f33e-4c6e-88ae-eb4639f12041}</UniqueIdentifier>
</Filter>
</ItemGroup>
<ItemGroup>
<ClInclude Include="stdafx.h">
@ -92,7 +98,6 @@
<ClInclude Include="service_win32_constants.h">
<Filter>Source Files\Win32</Filter>
</ClInclude>
<ClInclude Include="crypto\siphash\siphash.h" />
<ClInclude Include="crypto\blake2s\blake2s.h">
<Filter>crypto\blake2s</Filter>
</ClInclude>
@ -123,6 +128,21 @@
<ClInclude Include="tunsafe_dnsresolve.h">
<Filter>Source Files</Filter>
</ClInclude>
<ClInclude Include="crypto\siphash\siphash.h">
<Filter>crypto\siphash</Filter>
</ClInclude>
<ClInclude Include="network_bsd.h">
<Filter>Source Files\BSD</Filter>
</ClInclude>
<ClInclude Include="tunsafe_bsd.h">
<Filter>Source Files\BSD</Filter>
</ClInclude>
<ClInclude Include="network_common.h">
<Filter>Source Files</Filter>
</ClInclude>
<ClInclude Include="network_win32_tcp.h">
<Filter>Source Files\Win32</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="stdafx.cpp">
@ -173,9 +193,6 @@
<ClCompile Include="service_pipe_win32.cpp">
<Filter>Source Files\Win32</Filter>
</ClCompile>
<ClCompile Include="crypto\siphash\siphash.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="crypto\blake2s\blake2s.cpp">
<Filter>crypto\blake2s</Filter>
</ClCompile>
@ -188,6 +205,24 @@
<ClCompile Include="tunsafe_ipaddr.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="network_bsd.cpp">
<Filter>Source Files\BSD</Filter>
</ClCompile>
<ClCompile Include="network_bsd_mt.cpp">
<Filter>Source Files\BSD</Filter>
</ClCompile>
<ClCompile Include="crypto\siphash\siphash.cpp">
<Filter>crypto\siphash</Filter>
</ClCompile>
<ClCompile Include="tunsafe_bsd.cpp">
<Filter>Source Files\BSD</Filter>
</ClCompile>
<ClCompile Include="network_common.cpp">
<Filter>Source Files</Filter>
</ClCompile>
<ClCompile Include="network_win32_tcp.cpp">
<Filter>Source Files\Win32</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
<ResourceCompile Include="TunSafe.rc" />

docs/WireGuard TCP.txt (new file, 104 lines)
View file

@ -0,0 +1,104 @@
WireGuard over TCP
------------------
We hate running one TCP implementation on top of another TCP implementation.
There are problems with cascading retransmissions and head-of-line blocking,
and performance is always much worse than with a UDP-based tunnel.

However, we also recognize that several users need to run WireGuard over TCP.
One reason is that UDP packets are sometimes blocked by corporate networks or
by other kinds of firewalls. Also, in misconfigured networks outside of the
user's control, TCP may be more reliable than UDP. Additionally, we want
TunSafe to be a drop-in replacement for OpenVPN, which also supports TCP-based
tunneling. The feature could also be used to run WireGuard tunnels over SSH
tunnels, or through SOCKS/HTTPS proxies.

The TunSafe project therefore takes the pragmatic approach of supporting
WireGuard over TCP while discouraging its use. We absolutely don't want
people to start using TCP by default. It's meant to be used only in the
rare cases where nothing else works.

We've added experimental support for TCP in the latest TunSafe master,
which means you can try this out on Windows, OSX, Linux, and FreeBSD.
On the server side, use ListenPortTCP=1234 to listen on a TCP port (not
working on Windows yet). On the clients, use Endpoint=tcp://5.5.5.5:1234.
The code is still very experimental and untested, and is not recommended
for general use. Once the code is better tested, we'll also release
support for connecting to WireGuard over TCP in our Android and iOS clients.

To keep the impact on our WireGuard protocol handling as small as possible,
and to minimize the risk of security-related issues, the TCP feature has been
designed to be as self-contained as possible. When a packet comes in over
TCP, it's handed to the WireGuard protocol handler and treated as if it
were a UDP packet, and vice versa. This means TCP support can also be added
to existing WireGuard deployments by using a separate process that converts
TCP connections into UDP packets sent to the WireGuard Linux kernel module.
Each packet over TCP is prefixed by a 2-byte big-endian number, which contains
the length of the packet's payload. The payload is then the actual WireGuard
UDP packet.
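
As an aside (this code is not part of the commit), the framing for the
simple case can be sketched in a few lines of C++. The helper names are
invented for the example; the top two bits of the prefix carry the packet
type described under "Protocol specification" below, and are zero here.

#include <cstdint>
#include <cstring>
#include <vector>

// Frame one WireGuard UDP datagram for the TCP stream: a 2-byte big-endian
// prefix (length in the low 14 bits, packet type 00 = Normal in the top 2
// bits) followed by the unmodified datagram. Caller must keep len <= 0x3fff.
static std::vector<uint8_t> FrameForTcp(const uint8_t *udp, size_t len) {
  std::vector<uint8_t> out(2 + len);
  out[0] = (uint8_t)(len >> 8);    // high byte first
  out[1] = (uint8_t)(len & 0xff);
  memcpy(out.data() + 2, udp, len);
  return out;
}

// Extract one datagram from the receive buffer, if a full frame is present.
static bool UnframeFromTcp(std::vector<uint8_t> &stream, std::vector<uint8_t> *udp) {
  if (stream.size() < 2) return false;
  size_t len = ((size_t)(stream[0] & 0x3f) << 8) | stream[1];  // low 14 bits
  if (stream.size() < 2 + len) return false;
  udp->assign(stream.begin() + 2, stream.begin() + 2 + len);
  stream.erase(stream.begin(), stream.begin() + 2 + len);
  return true;
}
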
TCP has higher overhead than UDP, and we want to support the usual WireGuard
MTU of 1420 without introducing extra packet "fragmenting". So we implemented
an optimization that skips sending the 16-byte WireGuard header for every
packet. Because TCP is a reliable, in-order transport, we know that the
counters in consecutive data packets are monotonically increasing, so we can
predict the contents of this header.

Here's an example:

A full-sized 1420-byte packet sent over a WireGuard link will have 2 bytes of
TCP payload length, 16 bytes of WireGuard headers, 16 bytes of WireGuard MAC,
20 bytes of TCP headers, and 40 bytes of IPv6 headers.
This is a total of 1420 + 2 + 16 + 16 + 20 + 40 = 1514 bytes, exceeding
the usual 1500-byte Ethernet MTU by 14 bytes, so a single full-sized packet
over WireGuard results in two TCP packets. With our optimization we skip the
16-byte header and get 1514 - 16 = 1498 bytes, which fits in one TCP packet.

Protocol specification
----------------------
TT LLLLLL LLLLLLLL  [Payload LL bytes]
|  |
|  \-- Payload length, high byte first.
\----- Packet type
The packet types (TT) currently defined are:

TT = 00 = Normal    The payload is a normal, unmodified WireGuard packet,
                    including the regular WireGuard header.
     01 = Reserved
     10 = Data      A WireGuard data packet (type 04) without the 16-byte
                    header. The predicted header is prefixed to the payload.
     11 = Control   A TCP control packet. Currently this is used only to set
                    up the header prediction. See below.

There's only one defined Control packet, type 00 (SetKeyAndCounter):
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|1 1| Length is 13 (14 bits) | 00 (8 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Key ID (32 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Counter (64 bits) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
This sets up the Key ID and Counter used for the Data packets. The Counter
is then incremented by 1 for every such packet.

For every Data packet, the predicted Key ID and Counter are expanded to a
regular WireGuard data (type 04) header, which is prefixed to the payload:
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| 04 (8 bits) | Reserved (24 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Key ID (32 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Counter (64 bits) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Data Payload (LL * 8 bits) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
This happens independently in each of the two TCP directions.
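
To make the expansion concrete, here is a minimal, non-normative sketch of
the receiving side. It mirrors what network_common.cpp in this commit does,
but the types and helper names below are made up for illustration.

#include <cstdint>
#include <cstring>

static void PutLE32(uint8_t *p, uint32_t v) {
  for (int i = 0; i < 4; i++) p[i] = (uint8_t)(v >> (8 * i));
}
static void PutLE64(uint8_t *p, uint64_t v) {
  for (int i = 0; i < 8; i++) p[i] = (uint8_t)(v >> (8 * i));
}
static uint64_t GetLE64(const uint8_t *p) {
  uint64_t v = 0;
  for (int i = 0; i < 8; i++) v |= (uint64_t)p[i] << (8 * i);
  return v;
}

struct PredictedState {
  uint8_t key_id[4] = {0, 0, 0, 0};  // key id, copied verbatim off the wire
  uint64_t counter = 0;              // counter expected in the next Data packet
};

// SetKeyAndCounter control payload: type(1) | key id(4) | counter(8, LE).
static void HandleSetKeyAndCounter(PredictedState *st, const uint8_t *payload) {
  memcpy(st->key_id, payload + 1, 4);
  st->counter = GetLE64(payload + 5);
}

// Data packet: rebuild the 16-byte WireGuard type-4 header in front of the
// payload. The packet buffer must have at least 16 bytes of headroom.
static uint8_t *ExpandDataPacket(PredictedState *st, uint8_t *payload) {
  uint8_t *hdr = payload - 16;
  PutLE32(hdr + 0, 4);               // message type 4 = transport data
  memcpy(hdr + 4, st->key_id, 4);    // predicted key id
  PutLE64(hdr + 8, st->counter++);   // predicted counter, then advance it
  return hdr;
}
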

View file

@ -59,22 +59,30 @@ struct Packet : QueuedItem {
uint8 userdata;
uint8 protocol; // which protocol is this packet for/from
IpAddr addr; // Optionally set to target/source of the packet
byte data_pre[4];
byte data_buf[0];
enum {
// there's always this much data before data_ptr
// there's always this much data before data_buf, to allow for header expansion
// in front.
HEADROOM_BEFORE = 64,
};
byte data_pre[HEADROOM_BEFORE];
byte data_buf[0];
void Reset() {
data = data_buf;
size = 0;
}
};
enum {
kPacketAllocSize = 2048 - 16,
kPacketCapacity = kPacketAllocSize - sizeof(Packet) - Packet::HEADROOM_BEFORE,
kPacketCapacity = kPacketAllocSize - sizeof(Packet),
};
void FreePacket(Packet *packet);
void FreePackets(Packet *packet, Packet **end, int count);
void FreePacketList(Packet *packet);
Packet *AllocPacket();
void FreeAllPackets();
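
The extra headroom added to Packet above is what lets the TCP code prepend
its small framing headers in front of an already-built packet without
copying the payload. A generic sketch of the idea (not TunSafe's actual
Packet type):

#include <cstddef>
#include <cstdint>

struct Buf {
  enum { kHeadroom = 64, kCapacity = 1500 };
  uint8_t headroom[kHeadroom];  // laid out immediately before payload[]
  uint8_t payload[kCapacity];
  uint8_t *data = payload;      // current start of valid data
  size_t size = 0;              // number of valid bytes at |data|
};

// Prepend a 2-byte big-endian length field in place by stepping |data| back
// into the headroom, the same trick Packet uses with data_pre/data_buf.
static void PrependLengthHeader(Buf *b) {
  size_t payload_len = b->size;
  b->data -= 2;
  b->size += 2;
  b->data[0] = (uint8_t)(payload_len >> 8);
  b->data[1] = (uint8_t)(payload_len & 0xff);
}
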
@ -123,7 +131,7 @@ public:
class UdpInterface {
public:
virtual bool Configure(int listen_port) = 0;
virtual bool Configure(int listen_port_udp, int listen_port_tcp) = 0;
virtual void WriteUdpPacket(Packet *packet) = 0;
};

File diff suppressed because it is too large.

network_bsd.h (new file, 304 lines)
View file

@ -0,0 +1,304 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#ifndef TUNSAFE_NETWORK_BSD_H_
#define TUNSAFE_NETWORK_BSD_H_
#include <poll.h>
#include <sys/un.h>
#include <sys/uio.h>
#include <string>
#include "network_common.h"
class BaseSocketBsd;
class TcpSocketBsd;
class WireguardProcessor;
class Packet;
class NetworkBsd {
friend class BaseSocketBsd;
friend class TcpSocketBsd;
friend class UdpSocketBsd;
friend class TunSocketBsd;
public:
enum {
#if defined(OS_ANDROID)
WithSigalarmSupport = 0,
#else
WithSigalarmSupport = 1
#endif
};
class NetworkBsdDelegate {
public:
virtual void OnSecondLoop(uint64 now) {}
virtual void RunAllMainThreadScheduled() {}
};
explicit NetworkBsd(NetworkBsdDelegate *delegate, int max_sockets);
~NetworkBsd();
void RunLoop(const sigset_t *sigmask);
void PostExit() { exit_ = true; }
bool *exit_flag() { return &exit_; }
bool *sigalarm_flag() { return &sigalarm_flag_; }
TcpSocketBsd *tcp_sockets() { return tcp_sockets_; }
bool overload() { return overload_; }
private:
void RemoveFromRoundRobin(int slot);
void ReallocateIov(size_t i);
void EnsureIovAllocated();
Packet *read_packet_;
bool exit_;
bool overload_;
bool sigalarm_flag_;
enum {
// This controls the max # of sockets we can support
kMaxIovec = 16,
};
int num_sock_;
int num_roundrobin_;
int num_endloop_;
int max_sockets_;
SimplePacketPool packet_pool_;
NetworkBsdDelegate *delegate_;
struct pollfd *pollfd_;
BaseSocketBsd **sockets_;
BaseSocketBsd **roundrobin_;
BaseSocketBsd **endloop_;
// Linked list of all tcp sockets
TcpSocketBsd *tcp_sockets_;
struct iovec iov_[kMaxIovec];
Packet *iov_packets_[kMaxIovec];
};
class BaseSocketBsd {
friend class NetworkBsd;
public:
BaseSocketBsd(NetworkBsd *network) : pollfd_slot_(-1), roundrobin_slot_(-1), endloop_slot_(-1), fd_(-1), network_(network) {}
virtual ~BaseSocketBsd();
virtual void HandleEvents(int revents) = 0;
// Return |false| to remove socket from roundrobin list.
virtual bool DoRoundRobin() { return false; }
virtual void DoEndloop() {}
virtual void Periodic() {}
// Make sure this socket gets called during each round robin step.
void AddToRoundRobin();
// Make sure this socket gets called at the end of the loop
void AddToEndLoop();
int GetFd() { return fd_; }
protected:
void SetPollFlags(int events) {
network_->pollfd_[pollfd_slot_].events = events;
}
void InitPollSlot(int fd, int events);
bool HasFreePollSlot() { return network_->num_sock_ != network_->max_sockets_; }
void CloseSocket();
NetworkBsd *network_;
int pollfd_slot_;
int roundrobin_slot_;
int endloop_slot_;
int fd_;
};
class TunSocketBsd : public BaseSocketBsd {
public:
explicit TunSocketBsd(NetworkBsd *network, WireguardProcessor *processor);
virtual ~TunSocketBsd();
bool Initialize(int fd);
virtual void HandleEvents(int revents) override;
virtual bool DoRoundRobin() override;
void WritePacket(Packet *packet);
bool tun_interface_gone() const { return tun_interface_gone_; }
private:
bool DoRead();
bool DoWrite();
bool tun_readable_, tun_writable_;
bool tun_interface_gone_;
Packet *tun_queue_, **tun_queue_end_;
WireguardProcessor *processor_;
};
class UdpSocketBsd : public BaseSocketBsd {
public:
explicit UdpSocketBsd(NetworkBsd *network, WireguardProcessor *processor);
virtual ~UdpSocketBsd();
bool Initialize(int listen_port);
virtual void HandleEvents(int revents) override;
virtual bool DoRoundRobin() override;
bool DoRead();
bool DoWrite();
void WritePacket(Packet *packet);
private:
bool udp_readable_, udp_writable_;
Packet *udp_queue_, **udp_queue_end_;
WireguardProcessor *processor_;
};
#if defined(OS_LINUX)
// Keeps track of when the unix socket gets deleted
class UnixSocketDeletionWatcher {
public:
UnixSocketDeletionWatcher();
~UnixSocketDeletionWatcher();
bool Start(const char *path, bool *flag_to_set);
void Stop();
bool Poll(const char *path) { return false; }
private:
static void *RunThread(void *arg);
void *RunThreadInner();
const char *path_;
int inotify_fd_;
int pid_;
int pipes_[2];
pthread_t thread_;
bool *flag_to_set_;
};
#else // !defined(OS_LINUX)
// all other platforms that lack inotify
class UnixSocketDeletionWatcher {
public:
UnixSocketDeletionWatcher() {}
~UnixSocketDeletionWatcher() {}
bool Start(const char *path, bool *flag_to_set) { return true; }
void Stop() {}
bool Poll(const char *path);
};
#endif // !defined(OS_LINUX)
class UnixDomainSocketListenerBsd : public BaseSocketBsd {
public:
explicit UnixDomainSocketListenerBsd(NetworkBsd *network, WireguardProcessor *processor);
virtual ~UnixDomainSocketListenerBsd();
bool Initialize(const char *devname);
bool Start(bool *exit_flag) {
return un_deletion_watcher_.Start(un_addr_.sun_path, exit_flag);
}
void Stop() { un_deletion_watcher_.Stop(); }
virtual void HandleEvents(int revents) override;
virtual void Periodic() override;
private:
struct sockaddr_un un_addr_;
WireguardProcessor *processor_;
UnixSocketDeletionWatcher un_deletion_watcher_;
};
class UnixDomainSocketChannelBsd : public BaseSocketBsd {
public:
explicit UnixDomainSocketChannelBsd(NetworkBsd *network, WireguardProcessor *processor, int fd);
virtual ~UnixDomainSocketChannelBsd();
virtual void HandleEvents(int revents) override;
private:
bool HandleEventsInner(int revents);
WireguardProcessor *processor_;
std::string inbuf_, outbuf_;
};
class TcpSocketListenerBsd : public BaseSocketBsd {
public:
explicit TcpSocketListenerBsd(NetworkBsd *bsd, WireguardProcessor *processor);
virtual ~TcpSocketListenerBsd();
bool Initialize(int port);
virtual void HandleEvents(int revents) override;
virtual void Periodic() override;
private:
WireguardProcessor *processor_;
};
class TcpSocketBsd : public BaseSocketBsd {
public:
explicit TcpSocketBsd(NetworkBsd *bsd, WireguardProcessor *processor);
virtual ~TcpSocketBsd();
void InitializeIncoming(int fd, const IpAddr &addr);
bool InitializeOutgoing(const IpAddr &addr);
void WritePacket(Packet *packet);
virtual void HandleEvents(int revents) override;
virtual void DoEndloop() override;
TcpSocketBsd *next() { return next_; }
uint8 endpoint_protocol() { return endpoint_protocol_; }
const IpAddr &endpoint() { return endpoint_; }
public:
uint8 age;
uint8 handshake_attempts;
private:
void DoRead();
void DoWrite();
void CloseSocketAndDestroy();
bool readable_, writable_;
bool got_eof_;
uint8 endpoint_protocol_;
bool want_connect_;
uint32 wqueue_bytes_;
Packet *wqueue_, **wqueue_end_;
TcpSocketBsd *next_;
WireguardProcessor *processor_;
TcpPacketHandler tcp_packet_handler_;
IpAddr endpoint_;
};
class NotificationPipeBsd : public BaseSocketBsd {
public:
NotificationPipeBsd(NetworkBsd *network);
~NotificationPipeBsd();
typedef void CallbackFunc(void *x);
void InjectCallback(CallbackFunc *func, void *param);
void Wakeup();
virtual void HandleEvents(int revents) override;
private:
struct CallbackState {
CallbackFunc *func;
void *param;
CallbackState *next;
};
int pipe_fds_[2];
std::atomic<CallbackState*> injected_cb_;
};
#endif // TUNSAFE_NETWORK_BSD_H_
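
NetworkBsd above drives every socket from one single-threaded poll() loop.
As a rough, generic sketch of that pattern (this is not the actual
implementation, just the shape of it):

#include <poll.h>
#include <vector>

struct Handler {
  int fd;
  void (*on_readable)(int fd);  // called when |fd| has data to read
};

// Wait for events on all registered fds, dispatch them, and repeat until
// asked to exit. The real loop also tracks write interest, a round-robin
// work list and once-a-second periodic work.
static void RunPollLoop(std::vector<Handler> &handlers, const bool *exit_flag) {
  std::vector<pollfd> pfds(handlers.size());
  for (size_t i = 0; i < handlers.size(); i++)
    pfds[i] = pollfd{handlers[i].fd, POLLIN, 0};
  while (!*exit_flag) {
    int n = poll(pfds.data(), (nfds_t)pfds.size(), 1000 /* ms */);
    if (n < 0) break;  // interrupted or failed
    for (size_t i = 0; i < pfds.size(); i++)
      if (pfds[i].revents & (POLLIN | POLLERR | POLLHUP))
        handlers[i].on_readable(handlers[i].fd);
  }
}
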

View file

@ -1,101 +0,0 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#ifndef TUNSAFE_NETWORK_BSD_COMMON_H_
#define TUNSAFE_NETWORK_BSD_COMMON_H_
#include "netapi.h"
#include "wireguard.h"
#include "wireguard_config.h"
#include <string>
#include <signal.h>
struct RouteInfo {
uint8 family;
uint8 cidr;
uint8 ip[16];
uint8 gw[16];
std::string dev;
};
#if defined(OS_LINUX)
// Keeps track of when the unix socket gets deleted
class UnixSocketDeletionWatcher {
public:
UnixSocketDeletionWatcher();
~UnixSocketDeletionWatcher();
bool Start(const char *path, bool *flag_to_set);
void Stop();
bool Poll(const char *path) { return false; }
private:
static void *RunThread(void *arg);
void *RunThreadInner();
const char *path_;
int inotify_fd_;
int pid_;
int pipes_[2];
pthread_t thread_;
bool *flag_to_set_;
};
#else // !defined(OS_LINUX)
// all other platforms that lack inotify
class UnixSocketDeletionWatcher {
public:
UnixSocketDeletionWatcher() {}
~UnixSocketDeletionWatcher() {}
bool Start(const char *path, bool *flag_to_set) { return true; }
void Stop() {}
bool Poll(const char *path);
};
#endif // !defined(OS_LINUX)
class TunsafeBackendBsd : public TunInterface, public UdpInterface {
public:
TunsafeBackendBsd();
virtual ~TunsafeBackendBsd();
void RunLoop();
void CleanupRoutes();
void SetTunDeviceName(const char *name);
void SetProcessor(WireguardProcessor *wg) { processor_ = wg; }
// -- from TunInterface
virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
virtual void HandleSigAlrm() = 0;
virtual void HandleExit() = 0;
protected:
virtual bool InitializeTun(char devname[16]) = 0;
virtual void RunLoopInner() = 0;
void AddRoute(uint32 ip, uint32 cidr, uint32 gw, const char *dev);
void DelRoute(const RouteInfo &cd);
bool AddRoute(int family, const void *dest, int dest_prefix, const void *gateway, const char *dev);
bool RunPrePostCommand(const std::vector<std::string> &vec);
WireguardProcessor *processor_;
std::vector<RouteInfo> cleanup_commands_;
std::vector<std::string> pre_down_, post_down_;
std::vector<WgCidrAddr> addresses_to_remove_;
sigset_t orig_signal_mask_;
char devname_[16];
bool tun_interface_gone_;
};
#if defined(OS_MACOSX) || defined(OS_FREEBSD)
#define TUN_PREFIX_BYTES 4
#elif defined(OS_LINUX)
#define TUN_PREFIX_BYTES 0
#endif
int open_tun(char *devname, size_t devname_size);
int open_udp(int listen_on_port);
void SetThreadName(const char *name);
TunsafeBackendBsd *CreateTunsafeBackendBsd();
#endif // TUNSAFE_NETWORK_BSD_COMMON_H_

network_common.cpp (new file, 174 lines)
View file

@ -0,0 +1,174 @@
#include "stdafx.h"
#include "network_common.h"
#include "netapi.h"
#include "tunsafe_endian.h"
#include <assert.h>
#include <algorithm>
#include "util.h"
TcpPacketHandler::TcpPacketHandler(SimplePacketPool *packet_pool) {
packet_pool_ = packet_pool;
rqueue_bytes_ = 0;
error_flag_ = false;
rqueue_ = NULL;
rqueue_end_ = &rqueue_;
predicted_key_in_ = predicted_key_out_ = 0;
predicted_serial_in_ = predicted_serial_out_ = 0;
}
TcpPacketHandler::~TcpPacketHandler() {
FreePacketList(rqueue_);
}
enum {
kTcpPacketType_Normal = 0,
kTcpPacketType_Reserved = 1,
kTcpPacketType_Data = 2,
kTcpPacketType_Control = 3,
kTcpPacketControlType_SetKeyAndCounter = 0,
};
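// Rewrites an outgoing WireGuard packet in place for the TCP stream: data
// packets (type 4) get their 16-byte header replaced by the 2-byte framing
// header (type Data), preceded by a 15-byte SetKeyAndCounter control packet
// whenever the key/counter can't be predicted from the previous packet.
// All other packets simply get a 2-byte length prefix (type Normal).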
void TcpPacketHandler::AddHeaderToOutgoingPacket(Packet *p) {
unsigned int size = p->size;
uint8 *data = p->data;
if (size >= 16 && ReadLE32(data) == 4) {
uint32 key = Read32(data + 4);
uint64 serial = ReadLE64(data + 8);
WriteBE16(data + 14, size - 16 + (kTcpPacketType_Data << 14));
data += 14, size -= 14;
// Insert a 15-byte control packet right before, to set the new key/serial.
if ((predicted_key_out_ ^ key) | (predicted_serial_out_ ^ serial)) {
predicted_key_out_ = key;
WriteLE64(data - 8, serial);
Write32(data - 12, key);
data[-13] = kTcpPacketControlType_SetKeyAndCounter;
WriteBE16(data - 15, 13 + (kTcpPacketType_Control << 14));
data -= 15, size += 15;
}
// Increase the serial by 1 for next packet.
predicted_serial_out_ = serial + 1;
} else {
WriteBE16(data - 2, size);
data -= 2, size += 2;
}
p->size = size;
p->data = data;
}
void TcpPacketHandler::QueueIncomingPacket(Packet *p) {
rqueue_bytes_ += p->size;
p->queue_next = NULL;
*rqueue_end_ = p;
rqueue_end_ = &Packet_NEXT(p);
}
// Reads the 2-byte packet header, which either fits entirely in the first
// buffer or straddles into the next one.
static uint32 ReadPacketHeader(Packet *p) {
if (p->size >= 2)
return ReadBE16(p->data);
else
return (p->data[0] << 8) + (Packet_NEXT(p)->data[0]);
}
// Move data around to ensure that exactly the first |num| bytes are stored
// in the first packet, and the rest of the data in subsequent packets.
Packet *TcpPacketHandler::ReadNextPacket(uint32 num) {
Packet *p = rqueue_;
assert(num <= kPacketCapacity);
if (p->size < num) {
// There's not enough data in the current packet, copy data from the next packet
// into this packet.
if ((uint32)(&p->data_buf[kPacketCapacity] - p->data) < num) {
// Move data up front to make space.
memmove(p->data_buf, p->data, p->size);
p->data = p->data_buf;
}
// Copy data from future packets into p, and delete them should they become empty.
do {
Packet *n = Packet_NEXT(p);
uint32 bytes_to_copy = std::min(n->size, num - p->size);
uint32 nsize = (n->size -= bytes_to_copy);
memcpy(p->data + postinc(p->size, bytes_to_copy), postinc(n->data, bytes_to_copy), bytes_to_copy);
if (nsize == 0) {
p->queue_next = n->queue_next;
packet_pool_->FreePacketToPool(n);
}
} while (num - p->size);
} else if (p->size > num) {
// The packet has too much data. Split the packet into two packets.
Packet *n = packet_pool_->AllocPacketFromPool();
if (!n)
return NULL; // unable to allocate a packet....?
if (num * 2 <= p->size) {
// There's a lot of trailing data: PP NNNNNN. Move PP.
n->size = num;
p->size -= num;
rqueue_bytes_ -= num;
memcpy(n->data, postinc(p->data, num), num);
return n;
} else {
uint32 overflow = p->size - num;
// There's a lot of leading data: PPPPPP NN. Move NN
n->size = overflow;
p->size = num;
rqueue_ = n;
if (!(n->queue_next = p->queue_next))
rqueue_end_ = &Packet_NEXT(n);
rqueue_bytes_ -= num;
memcpy(n->data, p->data + num, overflow);
return p;
}
}
if ((rqueue_ = Packet_NEXT(p)) == NULL)
rqueue_end_ = &rqueue_;
rqueue_bytes_ -= num;
return p;
}
Packet *TcpPacketHandler::GetNextWireguardPacket() {
while (rqueue_bytes_ >= 2) {
uint32 packet_header = ReadPacketHeader(rqueue_);
uint32 packet_size = packet_header & 0x3FFF;
uint32 packet_type = packet_header >> 14;
if (packet_size + 2 > rqueue_bytes_)
return NULL;
if (packet_size + 2 > kPacketCapacity) {
RERROR("Oversized packet?");
error_flag_ = true;
return NULL;
}
Packet *packet = ReadNextPacket(packet_size + 2);
if (packet) {
// RINFO("Packet of type %d, size %d", packet_type, packet->size - 2);
packet->data += 2, packet->size -= 2;
if (packet_type == kTcpPacketType_Normal) {
return packet;
} else if (packet_type == kTcpPacketType_Data) {
// Optimization when the 16 first bytes are known and prefixed to the packet
assert(packet->data >= packet->data_buf);
packet->data -= 16, packet->size += 16;
WriteLE32(packet->data, 4);
Write32(packet->data + 4, predicted_key_in_);
WriteLE64(packet->data + 8, predicted_serial_in_);
predicted_serial_in_++;
return packet;
} else if (packet_type == kTcpPacketType_Control) {
// Unknown control packets are silently ignored
if (packet->size == 13 && packet->data[0] == kTcpPacketControlType_SetKeyAndCounter) {
// Control packet to setup the predicted key/sequence nr
predicted_key_in_ = Read32(packet->data + 1);
predicted_serial_in_ = ReadLE64(packet->data + 5);
}
packet_pool_->FreePacketToPool(packet);
} else {
packet_pool_->FreePacketToPool(packet);
error_flag_ = true;
return NULL;
}
}
}
return NULL;
}

network_common.h (new file, 95 lines)
View file

@ -0,0 +1,95 @@
#ifndef TUNSAFE_NETWORK_COMMON_H_
#define TUNSAFE_NETWORK_COMMON_H_
#include "netapi.h"
class PacketProcessor;
// A simple single-threaded pool of packets, used on Windows where
// FreePacket / AllocPacket are multithreaded and thus slightly slower
#if defined(OS_WIN)
class SimplePacketPool {
public:
explicit SimplePacketPool() {
freed_packets_ = NULL;
freed_packets_count_ = 0;
}
~SimplePacketPool() {
FreePacketList(freed_packets_);
}
Packet *AllocPacketFromPool() {
if (Packet *p = freed_packets_) {
freed_packets_ = Packet_NEXT(p);
freed_packets_count_--;
p->Reset();
return p;
}
return AllocPacket();
}
void FreePacketToPool(Packet *p) {
Packet_NEXT(p) = freed_packets_;
freed_packets_ = p;
freed_packets_count_++;
}
void FreeSomePackets() {
if (freed_packets_count_ > 32)
FreeSomePacketsInner();
}
void FreeSomePacketsInner();
int freed_packets_count_;
Packet *freed_packets_;
};
#else
class SimplePacketPool {
public:
Packet *AllocPacketFromPool() {
return AllocPacket();
}
void FreePacketToPool(Packet *packet) {
return FreePacket(packet);
}
};
#endif
// Helps with adding the TCP protocol header to outgoing packets and
// parsing it off incoming packets.
class TcpPacketHandler {
public:
explicit TcpPacketHandler(SimplePacketPool *packet_pool);
~TcpPacketHandler();
// Adds a tcp header to a data packet so it can be transmitted on the wire
void AddHeaderToOutgoingPacket(Packet *p);
// Add a new chunk of incoming data to the packet list
void QueueIncomingPacket(Packet *p);
// Attempts to extract the next complete packet; returns NULL when no full
// packet is available yet (or on error).
Packet *GetNextWireguardPacket();
bool error() const { return error_flag_; }
private:
// Internal function to read a packet
Packet *ReadNextPacket(uint32 num);
SimplePacketPool *packet_pool_;
// Total # of bytes queued
uint32 rqueue_bytes_;
// Set if there's a fatal error
bool error_flag_;
// These hold the incoming packets before they're parsed
Packet *rqueue_, **rqueue_end_;
uint32 predicted_key_in_, predicted_key_out_;
uint64 predicted_serial_in_, predicted_serial_out_;
};
#endif // TUNSAFE_NETWORK_COMMON_H_
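
For orientation, here is a simplified sketch of how this class appears to be
driven by the TCP socket code elsewhere in this commit. The surrounding
function is invented for illustration; the real code does the socket I/O,
ownership handoff and error handling properly.

#include <string.h>
#include "netapi.h"           // Packet, kPacketCapacity
#include "network_common.h"   // TcpPacketHandler, SimplePacketPool

// Feed one chunk of bytes read from the TCP socket into the handler, then
// drain every complete WireGuard packet that became available.
static void OnTcpBytesReceived(TcpPacketHandler *h, SimplePacketPool *pool,
                               const uint8 *buf, size_t len) {
  if (len == 0 || len > kPacketCapacity)
    return;                                  // simplified error handling
  Packet *p = pool->AllocPacketFromPool();
  if (!p)
    return;
  memcpy(p->data, buf, len);
  p->size = (int)len;
  h->QueueIncomingPacket(p);                 // handler takes ownership of |p|
  while (Packet *wg = h->GetNextWireguardPacket()) {
    // |wg| now looks like a received UDP datagram; hand it to the WireGuard
    // processor in real code. Here we just release it again.
    pool->FreePacketToPool(wg);
  }
  if (h->error()) {
    // A malformed stream was detected; the real code closes the connection.
  }
}
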

View file

@ -5,6 +5,8 @@
#include "wireguard_config.h"
#include "netapi.h"
#include <Iphlpapi.h>
#include <Mswsock.h>
#include <ws2ipdef.h>
#include <stdlib.h>
#include <assert.h>
#include <malloc.h>
@ -12,7 +14,6 @@
#include <string.h>
#include <vector>
#include <Iphlpapi.h>
#include <ws2ipdef.h>
#include <assert.h>
#include <exdisp.h>
#include "tunsafe_endian.h"
@ -42,15 +43,20 @@ static HKEY g_hklm_reg_key;
static uint8 g_killswitch_curr, g_killswitch_want, g_killswitch_currconn;
bool g_allow_pre_post;
static volatile bool g_fail_malloc_flag;
static void DeactivateKillSwitch(uint32 want);
Packet *AllocPacket() {
Packet *packet = (Packet*)InterlockedPopEntrySList(&freelist_head);
if (packet == NULL)
packet = (Packet *)_aligned_malloc(kPacketAllocSize, 16);
packet->data = packet->data_buf + Packet::HEADROOM_BEFORE;
packet->size = 0;
if (packet == NULL) {
while ((packet = (Packet *)_aligned_malloc(kPacketAllocSize, 16)) == NULL) {
if (g_fail_malloc_flag)
return NULL;
Sleep(1000);
}
}
packet->Reset();
return packet;
}
@ -83,6 +89,14 @@ void FreeAllPackets() {
}
}
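// Trim the pool down to 24 cached packets, returning the excess to the
// global packet freelist.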
void SimplePacketPool::FreeSomePacketsInner() {
int n = freed_packets_count_ - 24;
Packet **p = &freed_packets_;
for (; n; n--)
p = &Packet_NEXT(*p);
FreePackets(exch(freed_packets_, *p), p, exch(freed_packets_count_, 24) - 24);
}
void InitPacketMutexes() {
static bool mutex_inited;
if (!mutex_inited) {
@ -91,17 +105,6 @@ void InitPacketMutexes() {
}
}
int tpq_last_qsize;
int g_tun_reads, g_tun_writes;
struct {
uint32 pad1[3];
uint32 udp_qsize1;
uint32 pad2[3];
uint32 udp_qsize2;
} qs;
#define kConcurrentReadTap 16
#define kConcurrentWriteTap 16
@ -399,47 +402,13 @@ static bool GetDefaultRouteAndDeleteOldRoutes(int family, const NET_LUID *Interf
return (rv == 0);
}
static inline bool NoMoreAllocationRetry(volatile bool *exit_flag) {
if (*exit_flag)
return true;
Sleep(1000);
return *exit_flag;
}
static inline bool AllocPacketFrom(Packet **list, int *counter, bool *exit_flag, Packet **res) {
Packet *p;
if (p = *list) {
*list = Packet_NEXT(p);
(*counter)--;
p->data = p->data_buf + Packet::HEADROOM_BEFORE;
} else {
while ((p = AllocPacket()) == NULL) {
if (NoMoreAllocationRetry(exit_flag))
return false;
}
}
*res = p;
return true;
}
static void FreePacketList(Packet *pp) {
void FreePacketList(Packet *pp) {
while (Packet *p = pp) {
pp = Packet_NEXT(p);
FreePacket(p);
}
}
inline void NetworkWin32::FreePacketToPool(Packet *p) {
Packet_NEXT(p) = NULL;
*freed_packets_end_ = p;
freed_packets_end_ = &Packet_NEXT(p);
freed_packets_count_++;
}
inline bool NetworkWin32::AllocPacketFromPool(Packet **p) {
return AllocPacketFrom(&freed_packets_, &freed_packets_count_, &exit_thread_, p);
}
UdpSocketWin32::UdpSocketWin32(NetworkWin32 *network_win32) {
network_ = network_win32;
wqueue_end_ = &wqueue_;
@ -455,6 +424,9 @@ UdpSocketWin32::UdpSocketWin32(NetworkWin32 *network_win32) {
num_reads_[0] = num_reads_[1] = 0;
num_writes_ = 0;
pending_writes_ = NULL;
qsize1_ = 0;
qsize2_ = 0;
}
UdpSocketWin32::~UdpSocketWin32() {
@ -529,12 +501,12 @@ fail:
// Called on another thread to queue up a udp packet
void UdpSocketWin32::WriteUdpPacket(Packet *packet) {
if (qs.udp_qsize2 - qs.udp_qsize1 >= (unsigned)(packet->size < 576 ? MAX_BYTES_IN_UDP_OUT_QUEUE_SMALL : MAX_BYTES_IN_UDP_OUT_QUEUE)) {
if (qsize2_ - qsize1_ >= (unsigned)(packet->size < 576 ? MAX_BYTES_IN_UDP_OUT_QUEUE_SMALL : MAX_BYTES_IN_UDP_OUT_QUEUE)) {
FreePacket(packet);
return;
}
Packet_NEXT(packet) = NULL;
qs.udp_qsize2 += packet->size;
packet->queue_next = NULL;
qsize2_ += packet->size;
mutex_.Acquire();
Packet *was_empty = wqueue_;
@ -542,20 +514,14 @@ void UdpSocketWin32::WriteUdpPacket(Packet *packet) {
wqueue_end_ = &Packet_NEXT(packet);
mutex_.Release();
if (was_empty == NULL) {
// Notify the worker thread that it should attempt more writes
PostQueuedCompletionStatus(network_->completion_port_handle_, NULL, NULL, NULL);
}
if (was_empty == NULL)
network_->WakeUp();
}
enum {
kUdpGetQueuedCompletionStatusSize = kConcurrentWriteTap + kConcurrentReadTap + 1
};
static inline void ClearOverlapped(OVERLAPPED *o) {
memset(o, 0, sizeof(*o));
}
#ifndef STATUS_PORT_UNREACHABLE
#define STATUS_PORT_UNREACHABLE 0xC000023F
#endif
@ -567,8 +533,8 @@ static inline bool IsIgnoredUdpError(DWORD err) {
void UdpSocketWin32::DoMoreReads() {
// Listen with multiple ipv6 packets only if we ever sent an ipv6 packet.
for (int i = num_reads_[IPV6]; i < max_read_ipv6_; i++) {
Packet *p;
if (!network_->AllocPacketFromPool(&p))
Packet *p = network_->packet_pool().AllocPacketFromPool();
if (!p)
break;
restart_read_udp6:
ClearOverlapped(&p->overlapped);
@ -590,32 +556,35 @@ restart_read_udp6:
num_reads_[IPV6]++;
}
// Initiate more reads, reusing the Packet structures in |finished_writes|.
for (int i = num_reads_[IPV4]; i < kConcurrentReadUdp; i++) {
Packet *p;
if (!network_->AllocPacketFromPool(&p))
break;
restart_read_udp:
ClearOverlapped(&p->overlapped);
WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data};
DWORD flags = 0;
p->userdata = IPV4;
p->sin_size = sizeof(p->addr.sin);
p->queue_cb = this;
if (WSARecvFrom(socket_, &wsabuf, 1, NULL, &flags, (struct sockaddr*)&p->addr, &p->sin_size, &p->overlapped, NULL) != 0) {
DWORD err = WSAGetLastError();
if (err != WSA_IO_PENDING) {
if (err == WSAEMSGSIZE || err == WSAECONNRESET || err == WSAENETRESET)
goto restart_read_udp;
RERROR("UdpSocketWin32:WSARecvFrom failed 0x%X", err);
FreePacket(p);
if (socket_ != INVALID_SOCKET) {
for (int i = num_reads_[IPV4]; i < kConcurrentReadUdp; i++) {
Packet *p = network_->packet_pool().AllocPacketFromPool();
if (!p)
break;
restart_read_udp:
ClearOverlapped(&p->overlapped);
WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data};
DWORD flags = 0;
p->userdata = IPV4;
p->sin_size = sizeof(p->addr.sin);
p->queue_cb = this;
if (WSARecvFrom(socket_, &wsabuf, 1, NULL, &flags, (struct sockaddr*)&p->addr, &p->sin_size, &p->overlapped, NULL) != 0) {
DWORD err = WSAGetLastError();
if (err != WSA_IO_PENDING) {
if (err == WSAEMSGSIZE || err == WSAECONNRESET || err == WSAENETRESET)
goto restart_read_udp;
RERROR("UdpSocketWin32:WSARecvFrom failed 0x%X", err);
FreePacket(p);
break;
}
}
num_reads_[IPV4]++;
}
num_reads_[IPV4]++;
}
}
void UdpSocketWin32::DoMoreWrites() {
void UdpSocketWin32::ProcessPackets() {
// Push all the finished reads to the packet handler
if (finished_reads_ != NULL) {
packet_handler_->PostPackets(finished_reads_, finished_reads_end_, finished_reads_count_);
@ -623,7 +592,9 @@ void UdpSocketWin32::DoMoreWrites() {
finished_reads_end_ = &finished_reads_;
finished_reads_count_ = 0;
}
}
void UdpSocketWin32::DoMoreWrites() {
Packet *pending_writes = pending_writes_;
// Initiate more writes from |wqueue_|
while (num_writes_ < kConcurrentWriteUdp) {
@ -639,7 +610,7 @@ void UdpSocketWin32::DoMoreWrites() {
if (!pending_writes)
break;
}
qs.udp_qsize1 += pending_writes->size;
qsize1_ += pending_writes->size;
// Then issue writes
Packet *p = pending_writes;
@ -688,13 +659,14 @@ void UdpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) {
if (p->userdata < 2) {
num_reads_[p->userdata]--;
if ((DWORD)p->overlapped.Internal != 0) {
network_->FreePacketToPool(p);
if (!IsIgnoredUdpError((DWORD)p->overlapped.Internal))
RERROR("UdpSocketWin32::Read error 0x%X", (DWORD)p->overlapped.Internal);
network_->packet_pool().FreePacketToPool(p);
} else {
// Remember all the finished packets and queue them up to the next thread once we've
// collected them all.
p->size = (int)p->overlapped.InternalHigh;
p->protocol = kPacketProtocolUdp;
p->queue_cb = packet_handler_->udp_queue();
p->queue_next = NULL;
*finished_reads_end_ = p;
@ -703,9 +675,9 @@ void UdpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) {
}
} else {
num_writes_--;
network_->FreePacketToPool(p);
if ((DWORD)p->overlapped.Internal != 0)
RERROR("UdpSocketWin32::Write error 0x%X", (DWORD)p->overlapped.Internal);
network_->packet_pool().FreePacketToPool(p);
}
}
@ -716,28 +688,30 @@ void UdpSocketWin32::OnQueuedItemDelete(QueuedItem *qi) {
} else {
num_writes_--;
}
network_->FreePacketToPool(p);
network_->packet_pool().FreePacketToPool(p);
}
void UdpSocketWin32::DoIO() {
DoMoreWrites();
ProcessPackets();
DoMoreReads();
}
////////////////////////////////////////////////////////////////////////////////////////////////////////
NetworkWin32::NetworkWin32() : udp_socket_(this) {
NetworkWin32::NetworkWin32() : udp_socket_(this), tcp_socket_queue_(this) {
exit_thread_ = false;
thread_ = NULL;
freed_packets_ = NULL;
freed_packets_end_ = &freed_packets_;
freed_packets_count_ = 0;
completion_port_handle_ = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, NULL, 0);
tcp_socket_ = NULL;
}
NetworkWin32::~NetworkWin32() {
assert(thread_ == NULL);
for (TcpSocketWin32 *socket = tcp_socket_; socket; )
delete exch(socket, socket->next_);
CloseHandle(completion_port_handle_);
FreePacketList(freed_packets_);
}
DWORD WINAPI NetworkWin32::NetworkThread(void *x) {
@ -750,20 +724,13 @@ void NetworkWin32::ThreadMain() {
OVERLAPPED_ENTRY entries[kUdpGetQueuedCompletionStatusSize];
while (!exit_thread_) {
// Run IO on all sockets queued for IO
// TODO: In the future, don't process every socket here, only
// those sockets that requested it.
udp_socket_.DoIO();
for (TcpSocketWin32 *tcp = tcp_socket_; tcp;)
exch(tcp, tcp->next_)->DoIO();
// Free some packets
assert(freed_packets_count_ >= 0);
if (freed_packets_count_ >= 32) {
FreePackets(freed_packets_, freed_packets_end_, freed_packets_count_);
freed_packets_count_ = 0;
freed_packets_ = NULL;
freed_packets_end_ = &freed_packets_;
} else if (freed_packets_ == NULL) {
assert(freed_packets_count_ == 0);
freed_packets_end_ = &freed_packets_;
}
packet_pool_.FreeSomePackets();
ULONG num_entries = 0;
if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, kUdpGetQueuedCompletionStatusSize, &num_entries, INFINITE, FALSE)) {
@ -779,20 +746,33 @@ void NetworkWin32::ThreadMain() {
}
udp_socket_.CancelAllIO();
for (TcpSocketWin32 *tcp = tcp_socket_; tcp; tcp = tcp->next_)
tcp->CancelAllIO();
while (udp_socket_.HasOutstandingIO()) {
while (HasOutstandingIO()) {
ULONG num_entries = 0;
if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, 1, &num_entries, INFINITE, FALSE)) {
if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, kUdpGetQueuedCompletionStatusSize, &num_entries, INFINITE, FALSE)) {
RINFO("GetQueuedCompletionStatusEx failed.");
break;
}
if (entries[0].lpOverlapped) {
QueuedItem *w = (QueuedItem*)((byte*)entries[0].lpOverlapped - offsetof(QueuedItem, overlapped));
w->queue_cb->OnQueuedItemDelete(w);
for (ULONG i = 0; i < num_entries; i++) {
if (entries[i].lpOverlapped) {
QueuedItem *w = (QueuedItem*)((byte*)entries[i].lpOverlapped - offsetof(QueuedItem, overlapped));
w->queue_cb->OnQueuedItemDelete(w);
}
}
}
}
bool NetworkWin32::HasOutstandingIO() {
if (udp_socket_.HasOutstandingIO())
return true;
for (TcpSocketWin32 *tcp = tcp_socket_; tcp; tcp = tcp->next_)
if (tcp->HasOutstandingIO())
return true;
return false;
}
void NetworkWin32::StartThread() {
assert(completion_port_handle_);
@ -804,11 +784,36 @@ void NetworkWin32::StartThread() {
void NetworkWin32::StopThread() {
if (thread_ != NULL) {
exit_thread_ = true;
g_fail_malloc_flag = true;
PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL);
WaitForSingleObject(thread_, INFINITE);
CloseHandle(thread_);
thread_ = NULL;
exit_thread_ = false;
g_fail_malloc_flag = false;
}
}
void NetworkWin32::WakeUp() {
PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL);
}
void NetworkWin32::PostQueuedItem(QueuedItem *item) {
PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, &item->overlapped);
}
bool NetworkWin32::Configure(int listen_port, int listen_port_tcp) {
if (listen_port_tcp)
RERROR("ListenPortTCP not supported in this version");
return udp_socket_.Configure(listen_port);
}
// Called from tunsafe thread
void NetworkWin32::WriteUdpPacket(Packet *packet) {
if (packet->protocol & kPacketProtocolUdp) {
udp_socket_.WriteUdpPacket(packet);
} else {
tcp_socket_queue_.WritePacket(packet);
}
}
@ -874,6 +879,8 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) {
mutex_.Acquire();
while (!(exit_code = exit_code_)) {
FreeAllPackets();
if (timer_interrupt_) {
timer_interrupt_ = false;
need_notify_ = 0;
@ -912,7 +919,6 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) {
need_notify_ = 0;
mutex_.Release();
tpq_last_qsize = packets_in_queue;
if (packets_in_queue >= 1024)
overload = 2;
queue_context.overload = (overload != 0);
@ -986,8 +992,8 @@ void PacketProcessor::PostPackets(Packet *first, Packet **end, int count) {
}
void PacketProcessor::ForcePost(QueuedItem *item) {
mutex_.Acquire();
item->queue_next = NULL;
mutex_.Acquire();
packets_in_queue_ += 1;
*last_ptr_ = item;
last_ptr_ = &item->queue_next;
@ -1648,7 +1654,6 @@ bool TunWin32Adapter::RunPrePostCommand(const std::vector<std::string> &vec) {
return success;
}
//////////////////////////////////////////////////////////////////////////////
TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) {
@ -1704,6 +1709,20 @@ enum {
kTunGetQueuedCompletionStatusSize = kConcurrentWriteTap + kConcurrentReadTap + 1
};
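// Pop a packet from the thread-local freelist if one is available, otherwise
// fall back to the global allocator. Returns false only if allocation fails.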
static inline bool AllocPacketFrom(Packet **list, int *counter, bool *exit_flag, Packet **res) {
Packet *p;
if (p = *list) {
*list = Packet_NEXT(p);
(*counter)--;
p->data = p->data_buf;
} else {
if (!(p = AllocPacket()))
return false;
}
*res = p;
return true;
}
void TunWin32Iocp::ThreadMain() {
OVERLAPPED_ENTRY entries[kTunGetQueuedCompletionStatusSize];
Packet *pending_writes = NULL;
@ -1738,7 +1757,6 @@ void TunWin32Iocp::ThreadMain() {
num_reads++;
}
}
g_tun_reads = num_reads;
assert(freed_packets_count >= 0);
if (freed_packets_count >= 32) {
@ -1820,7 +1838,6 @@ void TunWin32Iocp::ThreadMain() {
num_writes++;
}
}
g_tun_writes = num_writes;
}
EXIT:
@ -1896,217 +1913,6 @@ void TunWin32Iocp::WriteTunPacket(Packet *packet) {
//////////////////////////////////////////////////////////////////////////////
TunWin32Overlapped::TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) {
wqueue_end_ = &wqueue_;
wqueue_ = NULL;
thread_ = NULL;
read_event_ = CreateEvent(NULL, TRUE, FALSE, NULL);
write_event_ = CreateEvent(NULL, TRUE, FALSE, NULL);
wake_event_ = CreateEvent(NULL, FALSE, FALSE, NULL);
packet_handler_ = NULL;
exit_thread_ = false;
}
TunWin32Overlapped::~TunWin32Overlapped() {
CloseTun();
CloseHandle(read_event_);
CloseHandle(write_event_);
CloseHandle(wake_event_);
}
bool TunWin32Overlapped::Configure(const TunConfig &&config, TunConfigOut *out) {
CloseTun();
if (adapter_.OpenAdapter(backend_, FILE_FLAG_OVERLAPPED) &&
adapter_.ConfigureAdapter(std::move(config), out))
return true;
CloseTun();
return false;
}
void TunWin32Overlapped::CloseTun() {
assert(thread_ == NULL);
adapter_.CloseAdapter(false);
FreePacketList(wqueue_);
wqueue_ = NULL;
wqueue_end_ = &wqueue_;
}
void TunWin32Overlapped::ThreadMain() {
Packet *pending_writes = NULL;
DWORD err;
Packet *read_packet = NULL, *write_packet = NULL;
HANDLE h[3];
while (!exit_thread_) {
if (read_packet == NULL) {
Packet *p = AllocPacket();
ClearOverlapped(&p->overlapped);
p->overlapped.hEvent = read_event_;
if (!ReadFile(adapter_.handle(), p->data, kPacketCapacity, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) {
FreePacket(p);
RERROR("TunWin32: ReadFile failed 0x%X", err);
} else {
read_packet = p;
}
}
int n = 0;
if (write_packet)
h[n++] = write_event_;
if (read_packet != NULL)
h[n++] = read_event_;
h[n++] = wake_event_;
DWORD res = WaitForMultipleObjects(n, h, FALSE, INFINITE);
if (res >= WAIT_OBJECT_0 && res <= WAIT_OBJECT_0 + 2) {
HANDLE hx = h[res - WAIT_OBJECT_0];
if (hx == read_event_) {
read_packet->size = (int)read_packet->overlapped.InternalHigh;
Packet_NEXT(read_packet) = NULL;
packet_handler_->PostPackets(read_packet, &Packet_NEXT(read_packet), 1);
read_packet = NULL;
} else if (hx == write_event_) {
FreePacket(write_packet);
write_packet = NULL;
}
} else {
RERROR("Wait said %d", res);
}
if (write_packet == NULL) {
if (!pending_writes) {
mutex_.Acquire();
pending_writes = wqueue_;
wqueue_end_ = &wqueue_;
wqueue_ = NULL;
mutex_.Release();
}
if (pending_writes) {
// Then issue writes
Packet *p = pending_writes;
pending_writes = Packet_NEXT(p);
memset(&p->overlapped, 0, sizeof(p->overlapped));
p->overlapped.hEvent = write_event_;
if (!WriteFile(adapter_.handle(), p->data, p->size, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) {
RERROR("TunWin32: WriteFile failed 0x%X", err);
FreePacket(p);
} else {
write_packet = p;
}
}
}
}
// TODO: Free memory
CancelIo(adapter_.handle());
FreePacketList(pending_writes);
}
DWORD WINAPI TunWin32Overlapped::TunThread(void *x) {
TunWin32Overlapped *xx = (TunWin32Overlapped *)x;
xx->ThreadMain();
return 0;
}
void TunWin32Overlapped::StartThread() {
DWORD thread_id;
thread_ = CreateThread(NULL, 0, &TunThread, this, 0, &thread_id);
SetThreadPriority(thread_, ABOVE_NORMAL_PRIORITY_CLASS);
}
void TunWin32Overlapped::StopThread() {
exit_thread_ = true;
SetEvent(wake_event_);
WaitForSingleObject(thread_, INFINITE);
CloseHandle(thread_);
thread_ = NULL;
}
void TunWin32Overlapped::WriteTunPacket(Packet *packet) {
Packet_NEXT(packet) = NULL;
mutex_.Acquire();
Packet *was_empty = wqueue_;
*wqueue_end_ = packet;
wqueue_end_ = &Packet_NEXT(packet);
mutex_.Release();
if (was_empty == NULL)
SetEvent(wake_event_);
}
void TunsafeBackendWin32::SetPublicKey(const uint8 key[32]) {
memcpy(public_key_, key, 32);
delegate_->OnStateChanged();
}
DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) {
TunsafeBackendWin32 *backend = (TunsafeBackendWin32*)bk;
int stop_mode;
int fast_retry_ctr = 0;
for(;;) {
TunWin32Iocp tun(&backend->dns_blocker_, backend);
NetworkWin32 net;
WireguardProcessor wg_proc(&net.udp(), &tun, backend);
qs.udp_qsize1 = qs.udp_qsize2 = 0;
net.udp().SetPacketHandler(&backend->packet_processor_);
tun.SetPacketHandler(&backend->packet_processor_);
if (backend->config_file_[0] &&
!ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->dns_resolver_))
goto getout_fail;
if (!wg_proc.Start())
goto getout_fail;
backend->SetPublicKey(wg_proc.dev().public_key());
backend->wg_processor_ = &wg_proc;
net.StartThread();
tun.StartThread();
stop_mode = backend->packet_processor_.Run(&wg_proc, backend);
net.StopThread();
tun.StopThread();
backend->wg_processor_ = NULL;
// Keep DNS alive
if (stop_mode != MODE_EXIT)
tun.adapter().DisassociateDnsBlocker();
else
backend->dns_resolver_.ClearCache();
FreeAllPackets();
if (stop_mode != MODE_TUN_FAILED)
return 0;
uint32 last_fail = GetTickCount();
fast_retry_ctr = (last_fail - backend->last_tun_adapter_failed_ < 5000) ? fast_retry_ctr + 1 : 0;
backend->last_tun_adapter_failed_ = last_fail;
backend->SetStatus((fast_retry_ctr >= 3) ? TunsafeBackend::kErrorTunPermanent : TunsafeBackend::kStatusTunRetrying);
if (backend->status_ == TunsafeBackend::kErrorTunPermanent) {
RERROR("Too many automatic restarts...");
goto getout_fail_noseterr;
}
Sleep(1000);
}
getout_fail:
backend->status_ = TunsafeBackend::kErrorInitialize;
backend->delegate_->OnStatusCode(TunsafeBackend::kErrorInitialize);
getout_fail_noseterr:
backend->dns_blocker_.RestoreDns();
return 0;
}
TunsafeBackend::TunsafeBackend() {
is_started_ = false;
is_remote_ = false;
@ -2152,6 +1958,75 @@ TunsafeBackendWin32::~TunsafeBackendWin32() {
TunAdaptersInUse::GetInstance()->Release(this);
}
void TunsafeBackendWin32::SetPublicKey(const uint8 key[32]) {
memcpy(public_key_, key, 32);
delegate_->OnStateChanged();
}
DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) {
TunsafeBackendWin32 *backend = (TunsafeBackendWin32*)bk;
int stop_mode;
int fast_retry_ctr = 0;
for (;;) {
TunWin32Iocp tun(&backend->dns_blocker_, backend);
NetworkWin32 net;
WireguardProcessor wg_proc(&net, &tun, backend);
net.udp().SetPacketHandler(&backend->packet_processor_);
net.tcp_socket_queue().SetPacketHandler(&backend->packet_processor_);
tun.SetPacketHandler(&backend->packet_processor_);
if (backend->config_file_[0] &&
!ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->dns_resolver_))
goto getout_fail;
if (!wg_proc.Start())
goto getout_fail;
backend->SetPublicKey(wg_proc.dev().public_key());
backend->wg_processor_ = &wg_proc;
net.StartThread();
tun.StartThread();
stop_mode = backend->packet_processor_.Run(&wg_proc, backend);
net.StopThread();
tun.StopThread();
backend->wg_processor_ = NULL;
// Keep DNS alive
if (stop_mode != MODE_EXIT)
tun.adapter().DisassociateDnsBlocker();
else
backend->dns_resolver_.ClearCache();
FreeAllPackets();
if (stop_mode != MODE_TUN_FAILED)
return 0;
uint32 last_fail = GetTickCount();
fast_retry_ctr = (last_fail - backend->last_tun_adapter_failed_ < 5000) ? fast_retry_ctr + 1 : 0;
backend->last_tun_adapter_failed_ = last_fail;
backend->SetStatus((fast_retry_ctr >= 3) ? TunsafeBackend::kErrorTunPermanent : TunsafeBackend::kStatusTunRetrying);
if (backend->status_ == TunsafeBackend::kErrorTunPermanent) {
RERROR("Too many automatic restarts...");
goto getout_fail_noseterr;
}
Sleep(1000);
}
getout_fail:
backend->status_ = TunsafeBackend::kErrorInitialize;
backend->delegate_->OnStatusCode(TunsafeBackend::kErrorInitialize);
getout_fail_noseterr:
backend->dns_blocker_.RestoreDns();
return 0;
}
void TunsafeBackendWin32::SetStatus(StatusCode status) {
status_ = status;

View file

@ -2,15 +2,14 @@
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#pragma once
#include "stdafx.h"
#include "tunsafe_types.h"
#include "netapi.h"
#include "network_win32_api.h"
#include "network_win32_dnsblock.h"
#include "wireguard_config.h"
#include "tunsafe_threading.h"
#include "tunsafe_dnsresolve.h"
#include <functional>
#include "network_common.h"
#include "network_win32_tcp.h"
enum {
ADAPTER_GUID_SIZE = 40,
@ -18,6 +17,7 @@ enum {
class WireguardProcessor;
class TunsafeBackendWin32;
class DnsBlocker;
struct PacketProcessorTunCb : QueuedItemCallback {
virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override;
@ -73,16 +73,15 @@ class PacketAllocPool;
// Encapsulates a UDP socket pair (ipv4 / ipv6), optionally listening for incoming packets
// on a specific port.
class UdpSocketWin32 : public UdpInterface, QueuedItemCallback {
class UdpSocketWin32 : public QueuedItemCallback {
public:
explicit UdpSocketWin32(NetworkWin32 *network_win32);
~UdpSocketWin32();
void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
// -- from UdpInterface
virtual bool Configure(int listen_on_port) override;
virtual void WriteUdpPacket(Packet *packet) override;
bool Configure(int listen_on_port);
inline void WriteUdpPacket(Packet *packet);
void DoIO();
void CancelAllIO();
@ -94,9 +93,9 @@ public:
};
private:
void DoMoreReads();
void DoMoreWrites();
void ProcessPackets();
// From OverlappedCallbacks
virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override;
@ -125,11 +124,16 @@ private:
Packet *finished_reads_, **finished_reads_end_;
int finished_reads_count_;
__declspec(align(64)) uint32 qsize1_;
__declspec(align(64)) uint32 qsize2_;
};
// Holds the thread for network communications
class NetworkWin32 {
class NetworkWin32 : public UdpInterface {
friend class UdpSocketWin32;
friend class TcpSocketWin32;
friend class TcpSocketQueue;
public:
explicit NetworkWin32();
~NetworkWin32();
@ -138,13 +142,20 @@ public:
void StopThread();
UdpSocketWin32 &udp() { return udp_socket_; }
SimplePacketPool &packet_pool() { return packet_pool_; }
TcpSocketQueue &tcp_socket_queue() { return tcp_socket_queue_; }
void WakeUp();
void PostQueuedItem(QueuedItem *item);
// -- from UdpInterface
virtual bool Configure(int listen_port_udp, int listen_port_tcp) override;
virtual void WriteUdpPacket(Packet *packet) override;
private:
void ThreadMain();
static DWORD WINAPI NetworkThread(void *x);
void FreePacketToPool(Packet *p);
bool AllocPacketFromPool(Packet **p);
bool HasOutstandingIO();
// The network thread handle
HANDLE thread_;
@ -155,18 +166,17 @@ private:
// The handle to the completion port
HANDLE completion_port_handle_;
Packet *freed_packets_, **freed_packets_end_;
int freed_packets_count_;
// Right now there's always one udp socket only
UdpSocketWin32 udp_socket_;
// A linked list of all tcp sockets
TcpSocketWin32 *tcp_socket_;
SimplePacketPool packet_pool_;
TcpSocketQueue tcp_socket_queue_;
};
class DnsBlocker;
class TunWin32Adapter {
public:
TunWin32Adapter(DnsBlocker *dns_blocker, const char guid[ADAPTER_GUID_SIZE]);
@ -243,42 +253,6 @@ private:
TunWin32Adapter adapter_;
};
// Implementation of TUN interface handling using Overlapped IO
class TunWin32Overlapped : public TunInterface {
public:
explicit TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend);
~TunWin32Overlapped();
void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
void StartThread();
void StopThread();
// -- from TunInterface
virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
virtual void WriteTunPacket(Packet *packet) override;
private:
void CloseTun();
void ThreadMain();
static DWORD WINAPI TunThread(void *x);
PacketProcessor *packet_handler_;
HANDLE thread_;
Mutex mutex_;
HANDLE read_event_, write_event_, wake_event_;
bool exit_thread_;
Packet *wqueue_, **wqueue_end_;
TunWin32Adapter adapter_;
TunsafeBackendWin32 *backend_;
};
class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate {
friend class PacketProcessor;
friend class TunWin32Iocp;
@ -427,3 +401,8 @@ private:
uint8 num_inuse_;
Entry entry_[kMaxAdaptersInUse];
};
static inline void ClearOverlapped(OVERLAPPED *o) {
memset(o, 0, sizeof(*o));
}

View file

@ -125,6 +125,3 @@ protected:
TunsafeBackend *CreateNativeTunsafeBackend(TunsafeBackend::Delegate *delegate);
TunsafeBackend::Delegate *CreateTunsafeBackendDelegateThreaded(TunsafeBackend::Delegate *delegate, const std::function<void(void)> &callback);
extern int tpq_last_qsize;
extern int g_tun_reads, g_tun_writes;

network_win32_tcp.cpp (new file, 344 lines)
View file

@ -0,0 +1,344 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#include <stdafx.h>
#include "network_win32_tcp.h"
#include "network_win32.h"
#include <Mswsock.h>
#include <ws2ipdef.h>
#include "util.h"
////////////////////////////////////////////////////////////////////////////////////////////////////////
TcpSocketWin32::TcpSocketWin32(NetworkWin32 *network)
: tcp_packet_handler_(&network->packet_pool()) {
network_ = network;
reads_active_ = 0;
writes_active_ = 0;
handshake_attempts = 0;
state_ = STATE_NONE;
wqueue_ = NULL;
wqueue_end_ = &wqueue_;
socket_ = INVALID_SOCKET;
next_ = NULL;
packet_processor_ = NULL;
// insert in network's linked list
next_ = network->tcp_socket_;
network->tcp_socket_ = this;
}
TcpSocketWin32::~TcpSocketWin32() {
// Unlink myself from the network's linked list.
TcpSocketWin32 **p = &network_->tcp_socket_;
while (*p != this) p = &(*p)->next_;
*p = next_;
FreePacketList(wqueue_);
if (socket_ != INVALID_SOCKET)
closesocket(socket_);
}
void TcpSocketWin32::CloseSocket() {
if (socket_ != INVALID_SOCKET)
CancelIo((HANDLE)socket_);
state_ = STATE_ERROR;
endpoint_protocol_ = 0;
}
void TcpSocketWin32::WritePacket(Packet *packet) {
packet->queue_next = NULL;
*wqueue_end_ = packet;
wqueue_end_ = &Packet_NEXT(packet);
}
void TcpSocketWin32::CancelAllIO() {
if (socket_ != INVALID_SOCKET)
CancelIo((HANDLE)socket_);
}
static const GUID WsaConnectExGUID = WSAID_CONNECTEX;
void TcpSocketWin32::DoConnect() {
LPFN_CONNECTEX ConnectEx;
assert(socket_ == INVALID_SOCKET);
socket_ = WSASocket(endpoint_.sin.sin_family, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
if (socket_ == INVALID_SOCKET) {
RERROR("socket() failed");
CloseSocket();
return;
}
if (!CreateIoCompletionPort((HANDLE)socket_, network_->completion_port_handle_, 0, 0)) {
RERROR("TcpSocketWin32::DoConnect CreateIoCompletionPort failed");
CloseSocket();
return;
}
int nodelay = 1;
setsockopt(socket_, IPPROTO_TCP, TCP_NODELAY, (char*)&nodelay, sizeof(nodelay));
DWORD dwBytes = sizeof(ConnectEx);
DWORD rc = WSAIoctl(socket_, SIO_GET_EXTENSION_FUNCTION_POINTER, (uint8*)&WsaConnectExGUID, sizeof(WsaConnectExGUID), &ConnectEx, sizeof(ConnectEx), &dwBytes, NULL, NULL);
assert(rc == 0);
// ConnectEx requires the socket to be bound
sockaddr_in sin = {0};
sin.sin_family = AF_INET;
sin.sin_addr.s_addr = INADDR_ANY;
sin.sin_port = 0;
if (bind(socket_, (sockaddr*)&sin, sizeof(sin))) {
RERROR("TcpSocketWin32::DoConnect bind failed: %d", WSAGetLastError());
CloseSocket();
return;
}
char buf[kSizeOfAddress];
RINFO("Connecting to tcp://%s...", PrintIpAddr(endpoint_, buf));
state_ = STATE_CONNECTING;
ClearOverlapped(&connect_overlapped_.overlapped);
connect_overlapped_.queue_cb = this;
if (!ConnectEx(socket_, (const sockaddr*)&endpoint_.sin, sizeof(endpoint_.sin), NULL, 0, NULL, &connect_overlapped_.overlapped)) {
int err = WSAGetLastError();
if (err != ERROR_IO_PENDING) {
RERROR("ConnectEx failed: %d", err);
CloseSocket();
return;
}
}
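// Count the pending ConnectEx as an active read so HasOutstandingIO() stays
// true until OnQueuedItemEvent() handles connect_overlapped_.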
reads_active_ = 1;
}
void TcpSocketWin32::DoMoreReads() {
assert(state_ != STATE_ERROR);
if (reads_active_ == 0) {
// Initiate a new read into a single packet buffer from the pool.
Packet *p = network_->packet_pool().AllocPacketFromPool();
if (!p)
return;
ClearOverlapped(&p->overlapped);
p->userdata = 0;
p->queue_cb = this;
DWORD flags = 0;
WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data};
if (WSARecv(socket_, &wsabuf, 1, NULL, &flags, &p->overlapped, NULL) != 0) {
DWORD err = WSAGetLastError();
if (err != ERROR_IO_PENDING) {
RERROR("TcpSocketWin32:WSARecv failed 0x%X", err);
FreePacket(p);
return;
}
}
reads_active_ = 1;
}
}
void TcpSocketWin32::DoMoreWrites() {
assert(state_ != STATE_ERROR);
if (writes_active_ == 0) {
WSABUF wsabuf[kMaxWsaBuf];
uint32 num_wsabuf = 0;
Packet *p = wqueue_;
if (p == NULL)
return;
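// Gather up to kMaxWsaBuf queued packets into a single scatter/gather WSASend;
// the first packet's OVERLAPPED tracks completion for the whole batch.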
do {
tcp_packet_handler_.AddHeaderToOutgoingPacket(p);
wsabuf[num_wsabuf].buf = (char*)p->data;
wsabuf[num_wsabuf].len = (ULONG)p->size;
packets_in_write_io_[num_wsabuf] = p;
p = Packet_NEXT(p);
} while (++num_wsabuf < kMaxWsaBuf && p != NULL);
if (!(wqueue_ = p))
wqueue_end_ = &wqueue_;
num_wsabuf_ = (uint8)num_wsabuf;
p = packets_in_write_io_[0];
ClearOverlapped(&p->overlapped);
p->userdata = 1;
p->queue_cb = this;
if (WSASend(socket_, wsabuf, num_wsabuf, NULL, 0, &p->overlapped, NULL) != 0) {
DWORD err = WSAGetLastError();
if (err != ERROR_IO_PENDING) {
RERROR("TcpSocketWin32: WSASend failed 0x%X", err);
FreePacket(p);
CloseSocket();
return;
}
}
writes_active_ = 1;
}
}
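// Per-socket state machine, run on the network thread: while connected, pump
// reads, forward each reassembled WireGuard packet to the packet processor and
// flush queued writes; start the connect when requested; once in STATE_ERROR,
// delete the socket as soon as no IO is outstanding.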
void TcpSocketWin32::DoIO() {
if (state_ == STATE_CONNECTED) {
DoMoreReads();
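// TcpPacketHandler (see network_common.h) reassembles the TCP byte stream into
// discrete WireGuard datagrams; hand each one to the packet processor as if it
// had arrived over UDP.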
while (Packet *p = tcp_packet_handler_.GetNextWireguardPacket()) {
p->protocol = endpoint_protocol_;
p->addr = endpoint_;
p->queue_cb = packet_processor_->udp_queue();
packet_processor_->ForcePost(p);
}
if (tcp_packet_handler_.error()) {
CloseSocket();
DoIO();
return;
}
DoMoreWrites();
} else if (state_ == STATE_WANT_CONNECT) {
DoConnect();
} else if (state_ == STATE_ERROR && !HasOutstandingIO()) {
delete this;
}
}
bool TcpSocketWin32::HasOutstandingIO() {
return writes_active_ + reads_active_ != 0;
}
void TcpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) {
if (qi == &connect_overlapped_) {
assert(state_ == STATE_CONNECTING);
reads_active_ = 0;
if ((DWORD)qi->overlapped.Internal != 0) {
if (state_ != STATE_ERROR) {
RERROR("TcpSocketWin32::Connect error 0x%X", (DWORD)qi->overlapped.Internal);
CloseSocket();
}
} else {
state_ = STATE_CONNECTED;
}
return;
}
Packet *p = static_cast<Packet*>(qi);
if (p->userdata == 0) {
// Read operation complete
if ((DWORD)p->overlapped.Internal != 0) {
if (state_ != STATE_ERROR) {
RERROR("TcpSocketWin32::Read error 0x%X", (DWORD)p->overlapped.Internal);
CloseSocket();
}
network_->packet_pool().FreePacketToPool(p);
// What to do?
} else if ((int)p->overlapped.InternalHigh == 0) {
// Socket closed successfully
CloseSocket();
network_->packet_pool().FreePacketToPool(p);
} else {
// Queue it up to rqueue
p->size = (int)p->overlapped.InternalHigh;
tcp_packet_handler_.QueueIncomingPacket(p);
}
reads_active_--;
} else {
assert(writes_active_);
assert(packets_in_write_io_[0] == p);
if ((DWORD)p->overlapped.Internal != 0) {
if (state_ != STATE_ERROR) {
RERROR("TcpSocketWin32::Write error 0x%X", (DWORD)p->overlapped.Internal);
CloseSocket();
}
}
// free all the packets involved in the write
for (size_t i = 0; i < num_wsabuf_; i++)
network_->packet_pool().FreePacketToPool(packets_in_write_io_[i]);
writes_active_--;
}
}
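// Called when a queued item is discarded without being processed (for example
// while the network thread is shutting down); release resources and fix the
// outstanding-IO counters.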
void TcpSocketWin32::OnQueuedItemDelete(QueuedItem *qi) {
if (qi == &connect_overlapped_) {
reads_active_ = 0;
return;
}
Packet *p = static_cast<Packet*>(qi);
if (p->userdata == 0) {
FreePacket(p);
reads_active_--;
} else {
for (size_t i = 0; i < num_wsabuf_; i++)
network_->packet_pool().FreePacketToPool(packets_in_write_io_[i]);
writes_active_--;
}
}
/////////////////////////////////////////////////////////////////////////
TcpSocketQueue::TcpSocketQueue(NetworkWin32 *network) {
network_ = network;
wqueue_ = NULL;
wqueue_end_ = &wqueue_;
queued_item_.queue_cb = this;
packet_handler_ = NULL;
}
TcpSocketQueue::~TcpSocketQueue() {
FreePacketList(wqueue_);
}
void TcpSocketQueue::TransmitOnePacket(Packet *packet) {
// Check if we have a tcp connection for the endpoint, otherwise create one.
for (TcpSocketWin32 *tcp = network_->tcp_socket_; tcp; tcp = tcp->next_) {
// If this would be the third consecutive handshake initiation sent on the same tcp socket, close it and open a new one because the connection appears defunct.
if (CompareIpAddr(&tcp->endpoint_, &packet->addr) == 0 && tcp->endpoint_protocol_ == packet->protocol) {
if (ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION) {
if (tcp->handshake_attempts == 2) {
RINFO("Making new Tcp socket due to too many handshake failures");
tcp->CloseSocket();
break;
}
tcp->handshake_attempts++;
} else {
tcp->handshake_attempts = -1;
}
tcp->WritePacket(packet);
return;
}
}
// Drop packets addressed to an incoming connection that no longer exists, and
// packets that are not handshake initiations; only a handshake initiation may
// trigger a new outgoing connection.
if ((packet->protocol & kPacketProtocolIncomingConnection) ||
packet->size < 4 || ReadLE32(packet->data) != MESSAGE_HANDSHAKE_INITIATION) {
FreePacket(packet);
return;
}
// Initialize a new tcp socket and connect to the endpoint
TcpSocketWin32 *tcp = new TcpSocketWin32(network_);
tcp->state_ = TcpSocketWin32::STATE_WANT_CONNECT;
tcp->endpoint_ = packet->addr;
tcp->endpoint_protocol_ = kPacketProtocolTcp;
tcp->SetPacketHandler(packet_handler_);
tcp->WritePacket(packet);
}
void TcpSocketQueue::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) {
wqueue_mutex_.Acquire();
Packet *packet = wqueue_;
wqueue_ = NULL;
wqueue_end_ = &wqueue_;
wqueue_mutex_.Release();
while (packet)
TransmitOnePacket(exch(packet, Packet_NEXT(packet)));
}
void TcpSocketQueue::OnQueuedItemDelete(QueuedItem *ow) {
}
void TcpSocketQueue::WritePacket(Packet *packet) {
packet->queue_next = NULL;
wqueue_mutex_.Acquire();
Packet *was_empty = wqueue_;
*wqueue_end_ = packet;
wqueue_end_ = &Packet_NEXT(packet);
wqueue_mutex_.Release();
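// Wake the network thread only on the empty -> non-empty transition; later
// writers piggyback on the already posted queued_item_.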
if (was_empty == NULL)
network_->PostQueuedItem(&queued_item_);
}

124
network_win32_tcp.h Normal file
View file

@ -0,0 +1,124 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#pragma once
#include "netapi.h"
#include "network_common.h"
#include "tunsafe_threading.h"
class NetworkWin32;
class PacketProcessor;
class TcpSocketWin32 : public QueuedItemCallback {
friend class NetworkWin32;
friend class TcpSocketQueue;
public:
explicit TcpSocketWin32(NetworkWin32 *network);
~TcpSocketWin32();
void SetPacketHandler(PacketProcessor *packet_handler) { packet_processor_ = packet_handler; }
// Queue a packet for writing to the TCP socket. Called on the network thread;
// packets originating on the wireguard thread are handed over through
// TcpSocketQueue::WritePacket and end up here via TransmitOnePacket.
void WritePacket(Packet *packet);
// Call from IO completion thread to cancel all outstanding IO
void CancelAllIO();
// Call from IO completion thread to run more IO
void DoIO();
// Returns true if there's IO still left to run
bool HasOutstandingIO();
private:
void DoMoreReads();
void DoMoreWrites();
void DoConnect();
void CloseSocket();
// From QueuedItemCallback
virtual void OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) override;
virtual void OnQueuedItemDelete(QueuedItem *qi) override;
// Network subsystem
NetworkWin32 *network_;
PacketProcessor *packet_processor_;
enum {
STATE_NONE = 0,
STATE_ERROR = 1,
STATE_CONNECTING = 2,
STATE_CONNECTED = 3,
STATE_WANT_CONNECT = 4,
};
uint8 reads_active_;
uint8 writes_active_;
uint8 state_;
uint8 num_wsabuf_;
public:
uint8 handshake_attempts;
private:
// The handle to the socket
SOCKET socket_;
// Packets taken over by the network thread and waiting to be written;
// once they complete we start draining wqueue_ again.
Packet *pending_writes_;
// All packets queued for writing on the network thread.
Packet *wqueue_, **wqueue_end_;
// Linked list of all TcpSocketWin32 sockets
TcpSocketWin32 *next_;
// Handles packet parsing
TcpPacketHandler tcp_packet_handler_;
// An overlapped instance used for the initial Connect() call.
QueuedItem connect_overlapped_;
IpAddr endpoint_;
uint8 endpoint_protocol_;
// Packets currently involved in the wsabuf writing
enum { kMaxWsaBuf = 32 };
Packet *packets_in_write_io_[kMaxWsaBuf];
};
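// Illustrative sketch only (not part of this header): one way the owning
// NetworkWin32 could drive these sockets from its completion loop, assuming a
// hypothetical ProcessTcpSockets() helper; the real driver lives in
// network_win32.cpp.
//
//   void NetworkWin32::ProcessTcpSockets() {
//     for (TcpSocketWin32 *tcp = tcp_socket_, *next; tcp; tcp = next) {
//       next = tcp->next_;  // DoIO() may delete the socket
//       tcp->DoIO();        // pump reads/writes, finish connects, or clean up
//     }
//   }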
class TcpSocketQueue : public QueuedItemCallback {
public:
explicit TcpSocketQueue(NetworkWin32 *network);
~TcpSocketQueue();
void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override;
virtual void OnQueuedItemDelete(QueuedItem *ow) override;
void WritePacket(Packet *packet);
private:
void TransmitOnePacket(Packet *packet);
NetworkWin32 *network_;
// All packets queued for writing on the network thread. Locked by |wqueue_mutex_|
Packet *wqueue_, **wqueue_end_;
PacketProcessor *packet_handler_;
// Protects wqueue_
Mutex wqueue_mutex_;
// Used for queueing things on the network instance
QueuedItem queued_item_;
};

View file

@ -23,7 +23,7 @@
#if defined(WITH_NETWORK_BSD)
#include "network_bsd.cpp"
#include "network_bsd_common.cpp"
#include "tunsafe_bsd.cpp"
#include "ts.cpp"
#include "benchmark.cpp"
#endif

View file

@ -1,6 +1,6 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#include "network_bsd_common.h"
#include "tunsafe_bsd.h"
#include "tunsafe_endian.h"
#include "util.h"
@ -43,17 +43,6 @@
#include <limits.h>
#endif
void tunsafe_die(const char *msg) {
fprintf(stderr, "%s\n", msg);
exit(1);
}
void SetThreadName(const char *name) {
#if defined(OS_LINUX)
prctl(PR_SET_NAME, name, 0, 0, 0);
#endif // defined(OS_LINUX)
}
#if defined(OS_MACOSX) || defined(OS_FREEBSD)
struct MyRouteMsg {
struct rt_msghdr hdr;
@ -346,21 +335,7 @@ int open_tun(char *devname, size_t devname_size) {
}
#endif
int open_udp(int listen_on_port) {
int udp_fd = socket(AF_INET, SOCK_DGRAM, 0);
if (udp_fd < 0) return udp_fd;
sockaddr_in sin = {0};
sin.sin_family = AF_INET;
sin.sin_port = htons(listen_on_port);
if (bind(udp_fd, (struct sockaddr*)&sin, sizeof(sin)) != 0) {
close(udp_fd);
return -1;
}
return udp_fd;
}
TunsafeBackendBsd::TunsafeBackendBsd()
: processor_(NULL) {
TunsafeBackendBsd::TunsafeBackendBsd() {
devname_[0] = 0;
tun_interface_gone_ = false;
}
@ -579,122 +554,33 @@ bool TunsafeBackendBsd::RunPrePostCommand(const std::vector<std::string> &vec) {
return success;
}
#if defined(OS_LINUX)
UnixSocketDeletionWatcher::UnixSocketDeletionWatcher()
: inotify_fd_(-1) {
pipes_[0] = -1;
pipes_[1] = -1;
}
UnixSocketDeletionWatcher::~UnixSocketDeletionWatcher() {
close(inotify_fd_);
close(pipes_[0]);
close(pipes_[1]);
}
bool UnixSocketDeletionWatcher::Start(const char *path, bool *flag_to_set) {
assert(inotify_fd_ == -1);
path_ = path;
flag_to_set_ = flag_to_set;
pid_ = getpid();
inotify_fd_ = inotify_init1(IN_CLOEXEC | IN_NONBLOCK);
if (inotify_fd_ == -1) {
perror("inotify_init1() failed");
return false;
}
if (inotify_add_watch(inotify_fd_, "/var/run/wireguard", IN_DELETE | IN_DELETE_SELF) == -1) {
perror("inotify_add_watch failed");
return false;
}
if (pipe(pipes_) == -1) {
perror("pipe() failed");
return false;
}
return pthread_create(&thread_, NULL, &UnixSocketDeletionWatcher::RunThread, this) == 0;
}
void UnixSocketDeletionWatcher::Stop() {
RINFO("Stopping..");
void *retval;
write(pipes_[1], "", 1);
pthread_join(thread_, &retval);
}
void *UnixSocketDeletionWatcher::RunThread(void *arg) {
UnixSocketDeletionWatcher *self = (UnixSocketDeletionWatcher*)arg;
return self->RunThreadInner();
}
void *UnixSocketDeletionWatcher::RunThreadInner() {
char buf[sizeof(struct inotify_event) + NAME_MAX + 1]
__attribute__ ((aligned(__alignof__(struct inotify_event))));
fd_set fdset;
struct stat st;
for(;;) {
if (lstat(path_, &st) == -1 && errno == ENOENT) {
RINFO("Unix socket %s deleted.", path_);
*flag_to_set_ = true;
kill(pid_, SIGALRM);
break;
}
FD_ZERO(&fdset);
FD_SET(inotify_fd_, &fdset);
FD_SET(pipes_[0], &fdset);
int n = select(std::max(inotify_fd_, pipes_[0]) + 1, &fdset, NULL, NULL, NULL);
if (n == -1) {
perror("select");
break;
}
if (FD_ISSET(inotify_fd_, &fdset)) {
ssize_t len = read(inotify_fd_, buf, sizeof(buf));
if (len == -1) {
perror("read");
break;
}
}
if (FD_ISSET(pipes_[0], &fdset))
break;
}
return NULL;
}
#else // !defined(OS_LINUX)
bool UnixSocketDeletionWatcher::Poll(const char *path) {
struct stat st;
return lstat(path, &st) == -1 && errno == ENOENT;
}
#endif // !defined(OS_LINUX)
static TunsafeBackendBsd *g_tunsafe_backend_bsd;
static void SigAlrm(int sig) {
if (g_tunsafe_backend_bsd)
g_tunsafe_backend_bsd->HandleSigAlrm();
}
static SignalCatcher *g_signal_catcher;
static bool did_ctrlc;
void SigInt(int sig) {
void SignalCatcher::SigAlrm(int sig) {
if (g_signal_catcher)
*g_signal_catcher->sigalarm_flag_ = true;
}
void SignalCatcher::SigInt(int sig) {
if (did_ctrlc)
exit(1);
did_ctrlc = true;
write(1, "Ctrl-C detected. Exiting. Press again to force quit.\n", sizeof("Ctrl-C detected. Exiting. Press again to force quit.\n")-1);
write(1, "Ctrl-C detected. Exiting. Press again to force quit.\n", sizeof("Ctrl-C detected. Exiting. Press again to force quit.\n") - 1);
// todo: fix signal safety?
if (g_tunsafe_backend_bsd)
g_tunsafe_backend_bsd->HandleExit();
if (g_signal_catcher)
*g_signal_catcher->exit_flag_ = true;
}
void TunsafeBackendBsd::RunLoop() {
assert(!g_tunsafe_backend_bsd);
assert(processor_);
SignalCatcher::SignalCatcher(bool *exit_flag, bool *sigalarm_flag) {
assert(g_signal_catcher == NULL);
exit_flag_ = exit_flag;
sigalarm_flag_ = sigalarm_flag;
g_signal_catcher = this;
sigset_t mask;
g_tunsafe_backend_bsd = this;
// We want an alarm signal every second.
{
struct sigaction act = {0};
@ -713,7 +599,6 @@ void TunsafeBackendBsd::RunLoop() {
return;
}
}
#if defined(OS_LINUX) || defined(OS_FREEBSD)
sigemptyset(&mask);
sigaddset(&mask, SIGALRM);
@ -737,7 +622,7 @@ void TunsafeBackendBsd::RunLoop() {
if (timer_create(CLOCK_MONOTONIC, &sev, &timer_id) < 0) {
RERROR("timer_create failed");
return;
}
}
if (timer_settime(timer_id, 0, &tv, NULL) < 0) {
RERROR("timer_settime failed");
@ -747,51 +632,209 @@ void TunsafeBackendBsd::RunLoop() {
#elif defined(OS_MACOSX)
ualarm(1000000, 1000000);
#endif
}
RunLoopInner();
g_tunsafe_backend_bsd = NULL;
SignalCatcher::~SignalCatcher() {
g_signal_catcher = NULL;
}
void InitCpuFeatures();
void Benchmark();
const char *print_ip(char buf[kSizeOfAddress], in_addr_t ip) {
snprintf(buf, kSizeOfAddress, "%d.%d.%d.%d", (ip >> 24) & 0xff, (ip >> 16) & 0xff, (ip >> 8) & 0xff, (ip >> 0) & 0xff);
return buf;
}
class MyProcessorDelegate : public ProcessorDelegate {
class TunsafeBackendBsdImpl : public TunsafeBackendBsd, public NetworkBsd::NetworkBsdDelegate, public ProcessorDelegate {
public:
MyProcessorDelegate() {
wg_processor_ = NULL;
is_connected_ = false;
}
TunsafeBackendBsdImpl();
virtual ~TunsafeBackendBsdImpl();
virtual void OnConnected() override {
if (!is_connected_) {
const WgCidrAddr *ipv4_addr = NULL;
for (const WgCidrAddr &x : wg_processor_->addr()) {
if (x.size == 32) { ipv4_addr = &x; break; }
}
uint32 ipv4_ip = ipv4_addr ? ReadBE32(ipv4_addr->addr) : 0;
char buf[kSizeOfAddress];
RINFO("Connection established. IP %s", ipv4_ip ? print_ip(buf, ipv4_ip) : "(none)");
is_connected_ = true;
}
}
virtual void OnConnectionRetry(uint32 attempts) override {
if (is_connected_ && attempts >= 3) {
is_connected_ = false;
RINFO("Reconnecting...");
}
}
void RunLoop();
virtual bool InitializeTun(char devname[16]) override;
// -- from TunInterface
virtual void WriteTunPacket(Packet *packet) override;
// -- from UdpInterface
virtual bool Configure(int listen_port_udp, int listen_port_tcp) override;
virtual void WriteUdpPacket(Packet *packet) override;
// -- from NetworkBsdDelegate
virtual void OnSecondLoop(uint64 now) override;
virtual void RunAllMainThreadScheduled() override;
// -- from ProcessorDelegate
virtual void OnConnected() override;
virtual void OnConnectionRetry(uint32 attempts) override;
WireguardProcessor *processor() { return &processor_; }
private:
void WriteTcpPacket(Packet *packet);
// Close all incoming TCP connections that are no longer pointed to by any peer's endpoint.
void CloseOrphanTcpConnections();
WireguardProcessor *wg_processor_;
bool is_connected_;
uint8 close_orphan_counter_;
WireguardProcessor processor_;
NetworkBsd network_;
TunSocketBsd tun_;
UdpSocketBsd udp_;
UnixDomainSocketListenerBsd unix_socket_listener_;
TcpSocketListenerBsd tcp_socket_listener_;
};
TunsafeBackendBsdImpl::TunsafeBackendBsdImpl()
: is_connected_(false),
close_orphan_counter_(0),
processor_(this, this, this),
network_(this, 1000),
tun_(&network_, &processor_),
udp_(&network_, &processor_),
unix_socket_listener_(&network_, &processor_),
tcp_socket_listener_(&network_, &processor_) {
}
TunsafeBackendBsdImpl::~TunsafeBackendBsdImpl() {
}
bool TunsafeBackendBsdImpl::InitializeTun(char devname[16]) {
int tun_fd = open_tun(devname, 16);
if (tun_fd < 0) { RERROR("Error opening tun device"); return false; }
if (!tun_.Initialize(tun_fd)) {
close(tun_fd);
return false;
}
unix_socket_listener_.Initialize(devname);
return true;
}
void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) {
tun_.WritePacket(packet);
}
// Called to initialize the udp socket and, when a tcp listen port is configured, the tcp listener.
bool TunsafeBackendBsdImpl::Configure(int listen_port, int listen_port_tcp) {
return udp_.Initialize(listen_port) &&
(listen_port_tcp == 0 || tcp_socket_listener_.Initialize(listen_port_tcp));
}
void TunsafeBackendBsdImpl::WriteTcpPacket(Packet *packet) {
// Check if we have a tcp connection for the endpoint, otherwise create one.
for (TcpSocketBsd *tcp = network_.tcp_sockets(); tcp; tcp = tcp->next()) {
// If this would be the third consecutive handshake initiation sent on the same tcp socket, close it and open a new one because the connection appears defunct.
if (CompareIpAddr(&tcp->endpoint(), &packet->addr) == 0 && tcp->endpoint_protocol() == packet->protocol) {
if (ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION) {
if (tcp->handshake_attempts == 2) {
RINFO("Making new Tcp socket due to too many handshake failures");
delete tcp;
break;
}
tcp->handshake_attempts++;
} else {
tcp->handshake_attempts = -1;
}
tcp->WritePacket(packet);
return;
}
}
// Drop packets addressed to an incoming connection that no longer exists, and
// packets that are not handshake initiations; only a handshake initiation may
// trigger a new outgoing connection.
if ((packet->protocol & kPacketProtocolIncomingConnection) ||
packet->size < 4 || ReadLE32(packet->data) != MESSAGE_HANDSHAKE_INITIATION) {
FreePacket(packet);
return;
}
// Initialize a new tcp socket and connect to the endpoint
TcpSocketBsd *tcp = new TcpSocketBsd(&network_, &processor_);
if (!tcp || !tcp->InitializeOutgoing(packet->addr)) {
delete tcp;
FreePacket(packet);
return;
}
tcp->WritePacket(packet);
}
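// Outgoing transport packets are routed by protocol bit: packets tagged
// kPacketProtocolTcp take the TCP path above, everything else goes out the UDP socket.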
void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) {
assert((packet->protocol & 0x7F) <= 2);
if (packet->protocol & kPacketProtocolTcp) {
WriteTcpPacket(packet);
} else {
udp_.WritePacket(packet);
}
}
void TunsafeBackendBsdImpl::RunLoop() {
if (!unix_socket_listener_.Start(network_.exit_flag()))
return;
SignalCatcher signal_catcher(network_.exit_flag(), network_.sigalarm_flag());
network_.RunLoop(&signal_catcher.orig_signal_mask_);
unix_socket_listener_.Stop();
tun_interface_gone_ = tun_.tun_interface_gone();
}
void TunsafeBackendBsdImpl::OnSecondLoop(uint64 now) {
if (!(close_orphan_counter_++ & 0xF))
CloseOrphanTcpConnections();
processor_.SecondLoop();
}
void TunsafeBackendBsdImpl::RunAllMainThreadScheduled() {
processor_.RunAllMainThreadScheduled();
}
void TunsafeBackendBsdImpl::OnConnected() {
if (!is_connected_) {
const WgCidrAddr *ipv4_addr = NULL;
for (const WgCidrAddr &x : processor_.addr()) {
if (x.size == 32) { ipv4_addr = &x; break; }
}
uint32 ipv4_ip = ipv4_addr ? ReadBE32(ipv4_addr->addr) : 0;
char buf[kSizeOfAddress];
RINFO("Connection established. IP %s", ipv4_ip ? print_ip(buf, ipv4_ip) : "(none)");
is_connected_ = true;
}
}
void TunsafeBackendBsdImpl::OnConnectionRetry(uint32 attempts) {
if (is_connected_ && attempts >= 3) {
is_connected_ = false;
RINFO("Reconnecting...");
}
}
void TunsafeBackendBsdImpl::CloseOrphanTcpConnections() {
// Add all incoming tcp connections into a lookup table
WG_HASHTABLE_IMPL<WgAddrEntry::IpPort, void*, WgAddrEntry::IpPortHasher> lookup;
for (TcpSocketBsd *tcp = network_.tcp_sockets(); tcp; tcp = tcp->next()) {
if (tcp->endpoint_protocol() == (kPacketProtocolTcp | kPacketProtocolIncomingConnection)) {
// Avoid deleting tcp sockets that were just born.
if (tcp->age == 0) {
tcp->age = 1;
} else {
lookup[ConvertIpAddrToAddrX(tcp->endpoint())] = tcp;
}
}
}
if (lookup.empty())
return;
// For each peer, check whether its endpoint matches an entry in the
// lookup table; if so, remove that entry since the connection is still in use.
for(WgPeer *peer = processor_.dev().first_peer(); peer; peer = peer->next_peer()) {
if (peer->endpoint_protocol() == (kPacketProtocolTcp | kPacketProtocolIncomingConnection))
lookup.erase(ConvertIpAddrToAddrX(peer->endpoint()));
}
// The tcp connections that are still in the hashtable can be deleted
for(const auto &it : lookup)
delete (TcpSocketBsd *)it.second;
}
int main(int argc, char **argv) {
CommandLineOutput cmd = {0};
@ -812,20 +855,15 @@ int main(int argc, char **argv) {
SetThreadName("tunsafe-m");
MyProcessorDelegate my_procdel;
TunsafeBackendBsd *backend = CreateTunsafeBackendBsd();
TunsafeBackendBsdImpl backend;
if (cmd.interface_name)
backend->SetTunDeviceName(cmd.interface_name);
WireguardProcessor wg(backend, backend, &my_procdel);
my_procdel.wg_processor_ = &wg;
backend->SetProcessor(&wg);
backend.SetTunDeviceName(cmd.interface_name);
DnsResolver dns_resolver(NULL);
if (*cmd.filename_to_load && !ParseWireGuardConfigFile(&wg, cmd.filename_to_load, &dns_resolver))
if (*cmd.filename_to_load && !ParseWireGuardConfigFile(backend.processor(), cmd.filename_to_load, &dns_resolver))
return 1;
if (!backend.processor()->Start())
return 1;
if (!wg.Start()) return 1;
if (cmd.daemon) {
fprintf(stderr, "Switching to daemon mode...\n");
@ -833,9 +871,8 @@ int main(int argc, char **argv) {
perror("daemon() failed");
}
backend->RunLoop();
backend->CleanupRoutes();
delete backend;
backend.RunLoop();
backend.CleanupRoutes();
return 0;
}

60
tunsafe_bsd.h Normal file
View file

@ -0,0 +1,60 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#ifndef TUNSAFE_NETWORK_BSD_COMMON_H_
#define TUNSAFE_NETWORK_BSD_COMMON_H_
#include "netapi.h"
#include "wireguard.h"
#include "wireguard_config.h"
#include <string>
#include <signal.h>
struct RouteInfo {
uint8 family;
uint8 cidr;
uint8 ip[16];
uint8 gw[16];
std::string dev;
};
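// Installs SIGALRM/SIGINT handlers for its lifetime; the handlers mainly set
// the supplied flags so the backend's run loop can react on its next wakeup.
// orig_signal_mask_ is handed to NetworkBsd::RunLoop() (see tunsafe_bsd.cpp).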
class SignalCatcher {
public:
SignalCatcher(bool *exit_flag, bool *sigalarm_flag);
~SignalCatcher();
sigset_t orig_signal_mask_;
private:
static void SigAlrm(int sig);
static void SigInt(int sig);
bool *exit_flag_;
bool *sigalarm_flag_;
};
class TunsafeBackendBsd : public TunInterface, public UdpInterface {
public:
TunsafeBackendBsd();
virtual ~TunsafeBackendBsd();
void CleanupRoutes();
void SetTunDeviceName(const char *name);
// -- from TunInterface
virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
protected:
virtual bool InitializeTun(char devname[16]) = 0;
void AddRoute(uint32 ip, uint32 cidr, uint32 gw, const char *dev);
void DelRoute(const RouteInfo &cd);
bool AddRoute(int family, const void *dest, int dest_prefix, const void *gateway, const char *dev);
bool RunPrePostCommand(const std::vector<std::string> &vec);
std::vector<RouteInfo> cleanup_commands_;
std::vector<std::string> pre_down_, post_down_;
std::vector<WgCidrAddr> addresses_to_remove_;
char devname_[16];
bool tun_interface_gone_;
};
#endif // TUNSAFE_NETWORK_BSD_COMMON_H_

View file

@ -28,6 +28,7 @@ WireguardProcessor::WireguardProcessor(UdpInterface *udp, TunInterface *tun, Pro
mtu_ = 1420;
memset(&stats_, 0, sizeof(stats_));
listen_port_ = 0;
listen_port_tcp_ = 0;
network_discovery_spoofing_ = false;
add_routes_mode_ = true;
dns_blocking_ = true;
@ -50,6 +51,16 @@ void WireguardProcessor::SetListenPort(int listen_port) {
}
}
void WireguardProcessor::SetListenPortTcp(int listen_port) {
if (listen_port_tcp_ != listen_port) {
listen_port_tcp_ = listen_port;
if (is_started_ && !ConfigureUdp()) {
RINFO("ConfigureUdp failed");
}
}
}
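// Example (illustrative, not part of this change): with the new ListenPortTCP
// key parsed in wireguard_config.cpp, a config file could enable the
// experimental TCP transport alongside UDP:
//
//   [Interface]
//   PrivateKey = ...
//   ListenPort = 51820
//   ListenPortTCP = 51820
//
// Leaving ListenPortTCP at its default of 0 keeps the TCP listener disabled.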
void WireguardProcessor::AddDnsServer(const IpAddr &sin) {
dns_addr_.push_back(sin);
}
@ -126,7 +137,7 @@ bool WireguardProcessor::Start() {
bool WireguardProcessor::ConfigureUdp() {
assert(dev_.IsMainThread());
return udp_->Configure(listen_port_);
return udp_->Configure(listen_port_, listen_port_tcp_);
}
bool WireguardProcessor::ConfigureTun() {

View file

@ -81,6 +81,8 @@ public:
~WireguardProcessor();
void SetListenPort(int listen_port);
void SetListenPortTcp(int listen_port);
void AddDnsServer(const IpAddr &sin);
bool SetTunAddress(const WgCidrAddr &addr);
void ClearTunAddress();
@ -132,6 +134,7 @@ private:
UdpInterface *udp_;
uint16 listen_port_;
uint16 listen_port_tcp_;
uint16 mtu_;
bool dns_blocking_;

View file

@ -104,6 +104,8 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
wg_->dev().SetPrivateKey(binkey);
} else if (strcmp(key, "ListenPort") == 0) {
wg_->SetListenPort(atoi(value));
} else if (strcmp(key, "ListenPortTCP") == 0) {
wg_->SetListenPortTcp(atoi(value));
} else if (strcmp(key, "Address") == 0) {
SplitString(value, ',', &ss);
for (size_t i = 0; i < ss.size(); i++) {