diff --git a/TunSafe.vcxproj b/TunSafe.vcxproj
index 42a9fbb..d0b93cf 100644
--- a/TunSafe.vcxproj
+++ b/TunSafe.vcxproj
@@ -185,10 +185,24 @@
+
+ true
+ true
+ true
+ true
+
+
+
+
+ true
+ true
+ true
+ true
+
@@ -215,8 +229,28 @@
+
+ true
+ true
+ true
+ true
+
+
+ true
+ true
+ true
+ true
+
+
+
+
+ true
+ true
+ true
+ true
+
diff --git a/TunSafe.vcxproj.filters b/TunSafe.vcxproj.filters
index 3134cb4..10ced29 100644
--- a/TunSafe.vcxproj.filters
+++ b/TunSafe.vcxproj.filters
@@ -23,6 +23,12 @@
{1ca37c7b-e91e-4648-9584-7d0c73d8e416}
+
+ {4b2f2fd9-780e-45db-8fe1-f03079439723}
+
+
+ {0f45e1a0-f33e-4c6e-88ae-eb4639f12041}
+
@@ -92,7 +98,6 @@
Source Files\Win32
-
crypto\blake2s
@@ -123,6 +128,21 @@
Source Files
+
+ crypto\siphash
+
+
+ Source Files\BSD
+
+
+ Source Files\BSD
+
+
+ Source Files
+
+
+ Source Files\Win32
+
@@ -173,9 +193,6 @@
Source Files\Win32
-
- Source Files
-
crypto\blake2s
@@ -188,6 +205,24 @@
Source Files
+
+ Source Files\BSD
+
+
+ Source Files\BSD
+
+
+ crypto\siphash
+
+
+ Source Files\BSD
+
+
+ Source Files
+
+
+ Source Files\Win32
+
diff --git a/docs/WireGuard TCP.txt b/docs/WireGuard TCP.txt
new file mode 100644
index 0000000..dabb5c3
--- /dev/null
+++ b/docs/WireGuard TCP.txt
@@ -0,0 +1,104 @@
+WireGuard over TCP
+------------------
+
+We hate running one TCP implementation on top of another TCP implementation.
+There are problems with cascading retransmissions and head-of-line blocking,
+and performance is always much worse than a UDP based tunnel.
+
+However, we also recognize that several users need to run WireGuard over TCP.
+One reason is that UDP packets are sometimes blocked in corporate networks
+or by other types of firewalls. Also, in misconfigured networks outside of
+the user's control, TCP may be more reliable than UDP.
+
+Additionally, we want TunSafe to be a drop-in replacement for OpenVPN, which
+also supports TCP based tunneling. The feature could also be used to run
+WireGuard tunnels over ssh tunnels, or through socks/https proxies.
+
+The TunSafe project therefore takes the pragmatic approach of supporting
+WireGuard over TCP, while discouraging its use. We absolutely don't want
+people to start using TCP by default. It's meant to be used only in the
+extreme cases when nothing else is working.
+
+We've added experimental support for TCP in the latest TunSafe master,
+which means you can try this out on Windows, OSX, Linux, and FreeBSD.
+On the server side, to listen on a TCP port, use ListenPortTCP=1234. (Not
+working on Windows yet). On the clients, use Endpoint=tcp://5.5.5.5:1234.
+The code is still very experimental and untested, and is not recommended
+for general use. Once the code is more well tested, we'll also release
+support for connecting to WireGuard over TCP in our Android and iOS clients.
+
+To make the impact as small as possible to our WireGuard protocol handling,
+and to minimize the risk of security related issues, the TCP feature has been
+designed to be as self-contained as possible. When a packet comes in over
+TCP, it's sent over to the WireGuard protocol handler and treated as if it
+was a UDP packet, and vice versa. This means TCP support can also be added
+to existing WireGuard deployments by using a separate process that converts
+TCP connections into UDP packets sent to the WireGuard Linux kernel module.
+
+Each packet over TCP is prefixed by a 2-byte big endian number, which contains
+the length of the packet's payload. The payload is then the actual WireGuard
+UDP packet.
+
+TCP has larger overhead than UDP, and we want to support the usual WireGuard
+MTU of 1420 without introducing extra packet "fragmenting". So we implemented
+an optimization to skip sending the 16-byte WireGuard header for every packet.
+Because TCP is a reliable connection, we know that sequence numbers are
+always monotonically increasing, so we can predict the contents of this header.
+
+Here's an example:
+A full-sized 1420-byte packet sent over a WireGuard link will have 2 bytes of
+TCP payload length, 16 bytes of WireGuard headers, 16 bytes of WireGuard MAC,
+20 bytes of TCP headers, and 40 bytes of IPv6 headers.
+This is a total of 1420 + 2 + 16 + 16 + 20 + 40 = 1514 bytes, exceeding
+the usual 1500 byte Ethernet MTU by 14 bytes. This means that a single full
+sized packet over WireGuard will result in 2 TCP packets. With our
+optimization, we reduce this to 1498 bytes, so it fits in one TCP packet.
+
+Protocol specification
+----------------------
+
+TT LLLLLL LLLLLLLL [Payload LL bytes]
+| |
+| \-- Payload length, high byte first.
+\----- Packet type
+
+The packet types (TT) currently defined are:
+TT = 00 = Normal The payload is a normal unmodified WireGuard packet
+ including the regular WireGuard header.
+ 01 = Reserved
+ 10 = Data A WireGuard data packet (type 04) without the 16 byte
+ header. The predicted header is prefixed to the payload.
+ 11 = Control A TCP control packet. Currently this is used only to setup
+ the header prediction. See below.
+
+There's only one defined Control packet, type 00 (SetKeyAndCounter):
+
+ 0 1 2 3
+ 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ |1 1| Length is 13 (14 bits) | 00 (8 bits) |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Key ID (32 bits) |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Counter (64 bits) ...
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+
+This sets up the Key ID and Counter used for the Data packets. The Counter
+is then incremented by 1 for every such packet.
+
+For every Data packet, the predicted Key ID and Counter is expanded to a
+regular WireGuard data (type 04) header, which is prefixed to the payload:
+
+ 0 1 2 3
+ 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | 04 (8 bits) | Reserved (24 bits) |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Key ID (32 bits) |
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Counter (64 bits) ...
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ | Data Payload (LL * 8 bits) ...
+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+
+This happens independently in each of the two TCP directions.
diff --git a/netapi.h b/netapi.h
index 8fee819..f1b9318 100644
--- a/netapi.h
+++ b/netapi.h
@@ -59,22 +59,30 @@ struct Packet : QueuedItem {
uint8 userdata;
uint8 protocol; // which protocol is this packet for/from
IpAddr addr; // Optionally set to target/source of the packet
-
- byte data_pre[4];
- byte data_buf[0];
+
enum {
- // there's always this much data before data_ptr
+ // there's always this much data before data_buf, to allow for header expansion
+ // in front.
HEADROOM_BEFORE = 64,
};
+
+ byte data_pre[HEADROOM_BEFORE];
+ byte data_buf[0];
+
+ void Reset() {
+ data = data_buf;
+ size = 0;
+ }
};
enum {
kPacketAllocSize = 2048 - 16,
- kPacketCapacity = kPacketAllocSize - sizeof(Packet) - Packet::HEADROOM_BEFORE,
+ kPacketCapacity = kPacketAllocSize - sizeof(Packet),
};
void FreePacket(Packet *packet);
void FreePackets(Packet *packet, Packet **end, int count);
+void FreePacketList(Packet *packet);
Packet *AllocPacket();
void FreeAllPackets();
@@ -123,7 +131,7 @@ public:
class UdpInterface {
public:
- virtual bool Configure(int listen_port) = 0;
+ virtual bool Configure(int listen_port_udp, int listen_port_tcp) = 0;
virtual void WriteUdpPacket(Packet *packet) = 0;
};
diff --git a/network_bsd.cpp b/network_bsd.cpp
index 547e3db..62d5e5f 100644
--- a/network_bsd.cpp
+++ b/network_bsd.cpp
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved.
-#include "network_bsd_common.h"
+#include "network_bsd.h"
+#include "network_common.h"
#include "tunsafe_endian.h"
#include "util.h"
@@ -14,23 +15,52 @@
#include
#include
#include
+#include
#include
#include
#include
#include
#include
+#if defined(OS_LINUX)
+#include
+#include
+#include
+#endif
+
+#include
+
+#include "wireguard.h"
+#include "wireguard_config.h"
+
+#if defined(OS_MACOSX) || defined(OS_FREEBSD)
+#define TUN_PREFIX_BYTES 4
+#elif defined(OS_LINUX) || defined(OS_ANDROID)
+#define TUN_PREFIX_BYTES 0
+#endif
+
static Packet *freelist;
+void tunsafe_die(const char *msg) {
+ fprintf(stderr, "%s\n", msg);
+ exit(1);
+}
+
+void SetThreadName(const char *name) {
+#if defined(OS_LINUX)
+ prctl(PR_SET_NAME, name, 0, 0, 0);
+#endif // defined(OS_LINUX)
+}
+
void FreePacket(Packet *packet) {
- packet->next = freelist;
+ packet->queue_next = freelist;
freelist = packet;
}
Packet *AllocPacket() {
Packet *p = freelist;
if (p) {
- freelist = p->next;
+ freelist = Packet_NEXT(p);
} else {
p = (Packet*)malloc(kPacketAllocSize);
if (p == NULL) {
@@ -38,192 +68,291 @@ Packet *AllocPacket() {
abort();
}
}
- p->data = p->data_buf + Packet::HEADROOM_BEFORE;
- p->size = 0;
+ p->Reset();
return p;
}
-void FreePackets() {
- Packet *p;
- while ( (p = freelist ) != NULL) {
- freelist = p->next;
- free(p);
- }
+void FreePacketList(Packet *packet) {
+ while (packet)
+ free(exch(packet, Packet_NEXT(packet)));
}
-
-class TunsafeBackendBsdImpl : public TunsafeBackendBsd {
-public:
- TunsafeBackendBsdImpl();
- virtual ~TunsafeBackendBsdImpl();
-
- virtual void RunLoopInner() override;
- virtual bool InitializeTun(char devname[16]) override;
-
- // -- from TunInterface
- virtual void WriteTunPacket(Packet *packet) override;
-
- // -- from UdpInterface
- virtual bool Configure(int listen_port) override;
- virtual void WriteUdpPacket(Packet *packet) override;
-
- virtual void HandleSigAlrm() override { got_sig_alarm_ = true; }
- virtual void HandleExit() override { exit_ = true; }
-
-private:
- bool ReadFromUdp(bool overload);
- bool ReadFromTun();
- bool WriteToUdp();
- bool WriteToTun();
- bool InitializeUnixDomainSocket(const char *devname);
-
- // Exists for the unix domain sockets
- struct SockInfo {
- bool is_listener;
-
- std::string inbuf, outbuf;
- };
- bool HandleSpecialPollfd(struct pollfd *pollfd, struct SockInfo *sockinfo);
- void CloseSpecialPollfd(size_t i);
- void SetUdpFd(int fd);
- void SetTunFd(int fd);
-
- bool got_sig_alarm_;
- bool exit_;
-
- bool tun_readable_, tun_writable_;
- bool udp_readable_, udp_writable_;
-
- Packet *tun_queue_, **tun_queue_end_;
- Packet *udp_queue_, **udp_queue_end_;
-
- Packet *read_packet_;
-
- enum {
- kMaxPollFd = 5,
- kPollFdTun = 0,
- kPollFdUdp = 1,
- kPollFdUnix = 2,
- };
-
- unsigned int pollfd_num_;
- struct pollfd pollfd_[kMaxPollFd];
-
- struct SockInfo sockinfo_[kMaxPollFd - 2];
-
- struct sockaddr_un un_addr_;
-
- UnixSocketDeletionWatcher un_deletion_watcher_;
-};
-
-TunsafeBackendBsdImpl::TunsafeBackendBsdImpl()
- : tun_readable_(false),
- tun_writable_(false),
- udp_readable_(false),
- udp_writable_(false),
- got_sig_alarm_(false),
- exit_(false),
- tun_queue_(NULL),
- tun_queue_end_(&tun_queue_),
- udp_queue_(NULL),
- udp_queue_end_(&udp_queue_),
- read_packet_(NULL) {
- read_packet_ = AllocPacket();
- for(size_t i = 0; i < kMaxPollFd; i++)
- pollfd_[i].fd = -1;
- pollfd_num_ = 3;
- sockinfo_[0].is_listener = true;
- memset(&un_addr_, 0, sizeof(un_addr_));
+void FreeAllPackets() {
+ FreePacketList(exch_null(freelist));
}
-TunsafeBackendBsdImpl::~TunsafeBackendBsdImpl() {
- if (un_addr_.sun_path[0])
- unlink(un_addr_.sun_path);
+//////////////////////////////////////////////////////////////////////////////////////////////
+
+NetworkBsd::NetworkBsd(NetworkBsdDelegate *delegate, int max_sockets)
+ : exit_(false),
+ overload_(false),
+ sigalarm_flag_(false),
+ num_roundrobin_(0),
+ num_sock_(0),
+ num_endloop_(0),
+ read_packet_(NULL),
+ tcp_sockets_(NULL),
+ delegate_(delegate),
+ max_sockets_(max_sockets) {
+ if (max_sockets < 5 || max_sockets > 1000)
+ tunsafe_die("invalid value for max_sockets");
+
+ pollfd_ = new struct pollfd[max_sockets];
+ sockets_ = new BaseSocketBsd*[max_sockets];
+ roundrobin_ = new BaseSocketBsd*[max_sockets];
+ endloop_ = new BaseSocketBsd*[max_sockets];
+ if (!pollfd_ || !sockets_ || !roundrobin_ || !endloop_)
+ tunsafe_die("no memory");
+
+ memset(iov_packets_, 0, sizeof(iov_packets_));
+}
+
+NetworkBsd::~NetworkBsd() {
+ assert(tcp_sockets_ == NULL);
+ assert(num_sock_ == 0);
if (read_packet_)
FreePacket(read_packet_);
- for(size_t i = 0; i < pollfd_num_; i++)
- close(pollfd_[i].fd);
+ for (size_t i = 0; i < kMaxIovec; i++)
+ if (iov_packets_[i])
+ FreePacket(iov_packets_[i]);
+
+ delete [] pollfd_;
+ delete [] sockets_;
+ delete [] roundrobin_;
+ delete [] endloop_;
}
-void TunsafeBackendBsdImpl::SetUdpFd(int fd) {
- pollfd_[kPollFdUdp].fd = fd;
- pollfd_[kPollFdUdp].events = POLLIN;
- udp_writable_ = true;
-}
+void NetworkBsd::RunLoop(const sigset_t *sigmask) {
+ int free_packet_interval = 10;
+ int overload_ctr = 0;
+ uint64 last_second_loop = 0;
+ uint64 now = 0;
-void TunsafeBackendBsdImpl::SetTunFd(int fd) {
- pollfd_[kPollFdTun].fd = fd;
- pollfd_[kPollFdTun].events = POLLIN;
- tun_writable_ = true;
-}
+ if (!WithSigalarmSupport)
+ last_second_loop = OsGetMilliseconds();
+
+ while (!exit_) {
+ int n;
+ bool new_second = false;
-bool TunsafeBackendBsdImpl::ReadFromUdp(bool overload) {
- socklen_t sin_len;
- sin_len = sizeof(read_packet_->addr.sin);
- int r = recvfrom(pollfd_[kPollFdUdp].fd, read_packet_->data, kPacketCapacity, 0,
- (sockaddr*)&read_packet_->addr.sin, &sin_len);
- if (r >= 0) {
-// printf("Read %d bytes from UDP\n", r);
- read_packet_->sin_size = sin_len;
- read_packet_->size = r;
- if (processor_) {
- processor_->HandleUdpPacket(read_packet_, overload);
- read_packet_ = AllocPacket();
+ if (WithSigalarmSupport) {
+ if (sigalarm_flag_) {
+ sigalarm_flag_ = false;
+ new_second = true;
+ }
+ } else {
+ now = OsGetMilliseconds();
+ if ((now - last_second_loop) >= 1000) {
+ // Avoid falling behind too much
+ last_second_loop = (now - last_second_loop) >= 2000 ? now : last_second_loop + 1000;
+ new_second = true;
+ }
}
- return true;
- } else {
- if (errno != EAGAIN) {
- fprintf(stderr, "Read from UDP failed\n");
+
+ if (new_second) {
+ delegate_->OnSecondLoop(now);
+
+ struct BaseSocketBsd **socks = sockets_;
+ for (int i = 0; i < num_sock_; i++)
+ socks[i]->Periodic();
+
+ if (free_packet_interval == 0) {
+ FreeAllPackets();
+ free_packet_interval = 10;
+ }
+ free_packet_interval--;
+
+ overload_ctr -= (overload_ctr != 0);
}
- udp_readable_ = false;
+
+#if defined(OS_LINUX) || defined(OS_FREEBSD)
+ n = ppoll(pollfd_, num_sock_, NULL, sigmask);
+#else
+ n = poll(pollfd_, num_sock_, WithSigalarmSupport ? -1 : std::max((int)(last_second_loop - now) + 1000, 0));
+#endif
+ if (n == -1) {
+ if (errno != EINTR) {
+ RERROR("poll failed");
+ break;
+ }
+ } else {
+ // Iterate backwards to support deleting elements
+ struct pollfd *pfd = pollfd_;
+ struct BaseSocketBsd **socks = sockets_;
+ for (int i = num_sock_ - 1; i >= 0; i--) {
+ if (pfd[i].revents)
+ socks[i]->HandleEvents(pfd[i].revents);
+ }
+ }
+
+ overload_ = (overload_ctr != 0);
+ for (int loop = 0; ; loop++) {
+ // Whenever we don't finish set overload ctr.
+ if (loop == 256) {
+ overload_ctr = 4;
+ break;
+ }
+ int i = num_roundrobin_ - 1;
+ struct BaseSocketBsd **rrlist = roundrobin_;
+ if (i < 0)
+ break;
+ do {
+ if (!rrlist[i]->DoRoundRobin())
+ RemoveFromRoundRobin(i);
+ } while (i--);
+ }
+
+ struct BaseSocketBsd **endloop = endloop_;
+ for (int j = num_endloop_ - 1; j >= 0; j--) {
+ endloop[j]->endloop_slot_ = -1;
+ endloop[j]->DoEndloop();
+ }
+ num_endloop_ = 0;
+
+ delegate_->RunAllMainThreadScheduled();
+ }
+}
+
+void NetworkBsd::RemoveFromRoundRobin(int i) {
+ BaseSocketBsd *cur = roundrobin_[i], *last = roundrobin_[num_roundrobin_-- - 1];
+ assert(cur->roundrobin_slot_ == i);
+ roundrobin_[i] = last;
+ last->roundrobin_slot_ = i;
+ cur->roundrobin_slot_ = -1;
+}
+
+void NetworkBsd::ReallocateIov(size_t j) {
+ Packet *p = AllocPacket();
+ iov_packets_[j] = p;
+ iov_[j].iov_base = p->data;
+ iov_[j].iov_len = kPacketCapacity;
+}
+
+void NetworkBsd::EnsureIovAllocated() {
+ if (iov_packets_[0] == NULL) {
+ for (size_t i = 0; i < kMaxIovec; i++)
+ ReallocateIov(i);
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////////////////////////
+
+BaseSocketBsd::~BaseSocketBsd() {
+ CloseSocket();
+}
+
+void BaseSocketBsd::CloseSocket() {
+ if (fd_ != -1)
+ close(fd_);
+ if (roundrobin_slot_ >= 0)
+ network_->RemoveFromRoundRobin(roundrobin_slot_);
+ if (endloop_slot_ >= 0) {
+ BaseSocketBsd *last = network_->endloop_[network_->num_endloop_-- - 1];
+ network_->endloop_[endloop_slot_] = last;
+ last->endloop_slot_ = endloop_slot_;
+ }
+ if (pollfd_slot_ >= 0) {
+ unsigned int cur = pollfd_slot_, last = network_->num_sock_-- - 1;
+ BaseSocketBsd *lastsock = network_->sockets_[last];
+ network_->sockets_[cur] = lastsock;
+ lastsock->pollfd_slot_ = cur;
+ network_->pollfd_[cur] = network_->pollfd_[last];
+ }
+ fd_ = -1;
+ endloop_slot_ = pollfd_slot_ = roundrobin_slot_ = -1;
+}
+
+void BaseSocketBsd::InitPollSlot(int fd, int events) {
+ assert(network_->num_sock_ != network_->max_sockets_);
+ assert(fd_ == -1);
+ fd_ = fd;
+ unsigned int slot = pollfd_slot_;
+ if (pollfd_slot_ < 0)
+ pollfd_slot_ = slot = network_->num_sock_++;
+ network_->sockets_[slot] = this;
+ struct pollfd *pfd = &network_->pollfd_[slot];
+ pfd->fd = fd;
+ pfd->events = events;
+ pfd->revents = 0;
+}
+
+void BaseSocketBsd::AddToRoundRobin() {
+ if (roundrobin_slot_ < 0)
+ network_->roundrobin_[roundrobin_slot_ = network_->num_roundrobin_++] = this;
+}
+
+void BaseSocketBsd::AddToEndLoop() {
+ if (endloop_slot_ < 0)
+ network_->endloop_[endloop_slot_ = network_->num_endloop_++] = this;
+}
+
+//////////////////////////////////////////////////////////////////////////////////////////////
+
+TunSocketBsd::TunSocketBsd(NetworkBsd *network, WireguardProcessor *processor)
+ : BaseSocketBsd(network),
+ tun_readable_(false),
+ tun_writable_(false),
+ tun_interface_gone_(false),
+ tun_queue_(NULL),
+ tun_queue_end_(&tun_queue_),
+ processor_(processor) {
+}
+
+TunSocketBsd::~TunSocketBsd() {
+}
+
+bool TunSocketBsd::Initialize(int fd) {
+ if (!HasFreePollSlot())
return false;
- }
-}
-
-bool TunsafeBackendBsdImpl::WriteToUdp() {
- assert(udp_writable_);
-// RINFO("Send %d bytes to %s", (int)udp_queue_->size, inet_ntoa(udp_queue_->sin.sin_addr));
- int r = sendto(pollfd_[kPollFdUdp].fd, udp_queue_->data, udp_queue_->size, 0,
- (sockaddr*)&udp_queue_->addr.sin, sizeof(udp_queue_->addr.sin));
- if (r < 0) {
- if (errno == EAGAIN) {
- udp_writable_ = false;
- pollfd_[kPollFdUdp].events = POLLIN | POLLOUT;
- return false;
- }
- perror("Write to UDP failed");
- } else {
- if (r != udp_queue_->size)
- perror("Write to udp incomplete!");
-// else
-// RINFO("Wrote %d bytes to UDP", r);
- }
- Packet *next = udp_queue_->next;
- FreePacket(udp_queue_);
- if ((udp_queue_ = next) != NULL) return true;
- udp_queue_end_ = &udp_queue_;
- return false;
+ fcntl(fd, F_SETFD, FD_CLOEXEC);
+ fcntl(fd, F_SETFL, O_NONBLOCK);
+ InitPollSlot(fd, POLLIN);
+ tun_writable_ = true;
+ return true;
}
static inline bool IsCompatibleProto(uint32 v) {
return v == AF_INET || v == AF_INET6;
}
-bool TunsafeBackendBsdImpl::ReadFromTun() {
+void TunSocketBsd::HandleEvents(int revents) {
+ if (revents & (POLLERR | POLLHUP | POLLNVAL)) {
+ if (revents & POLLERR) {
+ tun_interface_gone_ = true;
+ RERROR("Tun interface gone, closing.");
+ } else {
+ RERROR("Tun interface error %d, closing.", revents);
+ }
+ tun_readable_ = tun_writable_ = false;
+ network_->PostExit();
+ } else {
+ tun_readable_ = (revents & POLLIN) != 0;
+ if (revents & POLLOUT) {
+ SetPollFlags(POLLIN);
+ tun_writable_ = true;
+ }
+ }
+ AddToRoundRobin();
+}
+
+bool TunSocketBsd::DoRead() {
assert(tun_readable_);
- Packet *packet = read_packet_;
- int r = read(pollfd_[kPollFdTun].fd, packet->data - TUN_PREFIX_BYTES, kPacketCapacity + TUN_PREFIX_BYTES);
+ Packet *packet = network_->read_packet_;
+ if (!packet)
+ network_->read_packet_ = packet = AllocPacket();
+
+ int r = read(fd_, packet->data - TUN_PREFIX_BYTES, kPacketCapacity + TUN_PREFIX_BYTES);
if (r >= 0) {
-// printf("Read %d bytes from TUN\n", r);
+// printf("Read %d bytes from TUN\n", r);
packet->size = r - TUN_PREFIX_BYTES;
- if (r >= TUN_PREFIX_BYTES && (!TUN_PREFIX_BYTES || IsCompatibleProto(ReadBE32(packet->data - TUN_PREFIX_BYTES))) && processor_) {
-// printf("%X %X %X %X %X %X %X %X\n",
-// read_packet_->data[0], read_packet_->data[1], read_packet_->data[2], read_packet_->data[3],
-// read_packet_->data[4], read_packet_->data[5], read_packet_->data[6], read_packet_->data[7]);
- read_packet_ = AllocPacket();
+ if (r >= TUN_PREFIX_BYTES && (!TUN_PREFIX_BYTES || IsCompatibleProto(ReadBE32(packet->data - TUN_PREFIX_BYTES)))) {
+ // printf("%X %X %X %X %X %X %X %X\n",
+ // read_packet_->data[0], read_packet_->data[1], read_packet_->data[2], read_packet_->data[3],
+ // read_packet_->data[4], read_packet_->data[5], read_packet_->data[6], read_packet_->data[7]);
+ network_->read_packet_ = NULL;
processor_->HandleTunPacket(packet);
}
- return true;
+ return true;
} else {
if (errno != EAGAIN) {
fprintf(stderr, "Read from tun failed\n");
@@ -237,16 +366,16 @@ static uint32 GetProtoFromPacket(const uint8 *data, size_t size) {
return size < 1 || (data[0] >> 4) != 6 ? AF_INET : AF_INET6;
}
-bool TunsafeBackendBsdImpl::WriteToTun() {
+bool TunSocketBsd::DoWrite() {
assert(tun_writable_);
if (TUN_PREFIX_BYTES) {
WriteBE32(tun_queue_->data - TUN_PREFIX_BYTES, GetProtoFromPacket(tun_queue_->data, tun_queue_->size));
}
- int r = write(pollfd_[kPollFdTun].fd, tun_queue_->data - TUN_PREFIX_BYTES, tun_queue_->size + TUN_PREFIX_BYTES);
+ int r = write(fd_, tun_queue_->data - TUN_PREFIX_BYTES, tun_queue_->size + TUN_PREFIX_BYTES);
if (r < 0) {
if (errno == EAGAIN) {
tun_writable_ = false;
- pollfd_[kPollFdTun].events = POLLIN | POLLOUT;
+ SetPollFlags(POLLIN | POLLOUT);
return false;
}
RERROR("Write to tun failed");
@@ -254,58 +383,262 @@ bool TunsafeBackendBsdImpl::WriteToTun() {
r -= TUN_PREFIX_BYTES;
if (r != tun_queue_->size)
RERROR("Write to tun incomplete!");
-// else
-// RINFO("Wrote %d bytes to TUN", r);
- }
- Packet *next = tun_queue_->next;
+ // else
+ // RINFO("Wrote %d bytes to TUN", r);
+ }
+ Packet *next = Packet_NEXT(tun_queue_);
FreePacket(tun_queue_);
if ((tun_queue_ = next) != NULL) return true;
tun_queue_end_ = &tun_queue_;
return false;
}
-bool TunsafeBackendBsdImpl::InitializeTun(char devname[16]) {
- int tun_fd = open_tun(devname, 16);
- if (tun_fd < 0) { RERROR("Error opening tun device"); return false; }
- fcntl(tun_fd, F_SETFD, FD_CLOEXEC);
- fcntl(tun_fd, F_SETFL, O_NONBLOCK);
- SetTunFd(tun_fd);
-
- InitializeUnixDomainSocket(devname);
- return true;
-}
-
-void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) override {
- assert(pollfd_[kPollFdTun].fd >= 0);
+void TunSocketBsd::WritePacket(Packet *packet) {
+ assert(fd_ >= 0);
Packet *queue_is_used = tun_queue_;
*tun_queue_end_ = packet;
- tun_queue_end_ = &packet->next;
- packet->next = NULL;
+ tun_queue_end_ = &Packet_NEXT(packet);
+ packet->queue_next = NULL;
if (!queue_is_used)
- WriteToTun();
+ DoWrite();
}
-// Called to initialize udp
-bool TunsafeBackendBsdImpl::Configure(int listen_port) override {
- int udp_fd = open_udp(listen_port);
- if (udp_fd < 0) { RERROR("Error opening udp"); return false; }
+bool TunSocketBsd::DoRoundRobin() {
+ bool more_work = false;
+ if (tun_queue_ && tun_writable_)
+ more_work = DoWrite();
+ if (tun_readable_)
+ more_work |= DoRead();
+ return more_work;
+}
+
+//////////////////////////////////////////////////////////////////////////////////////////////
+
+UdpSocketBsd::UdpSocketBsd(NetworkBsd *network, WireguardProcessor *processor)
+ : BaseSocketBsd(network),
+ udp_readable_(false),
+ udp_writable_(false),
+ udp_queue_(NULL),
+ udp_queue_end_(&udp_queue_),
+ processor_(processor) {
+}
+
+UdpSocketBsd::~UdpSocketBsd() {
+}
+
+bool UdpSocketBsd::Initialize(int listen_port) {
+ if (!HasFreePollSlot()) {
+ RERROR("No free internal sockets");
+ return false;
+ }
+ int udp_fd = socket(AF_INET, SOCK_DGRAM, 0);
+ if (udp_fd < 0) {
+ RERROR("socket(SOCK_DGRAM) failed");
+ return false;
+ }
+ sockaddr_in sin = {0};
+ sin.sin_family = AF_INET;
+ sin.sin_port = htons(listen_port);
+ if (bind(udp_fd, (struct sockaddr*)&sin, sizeof(sin)) != 0) {
+ close(udp_fd);
+ RERROR("bind on udp socket port %d failed", listen_port);
+ return false;
+ }
fcntl(udp_fd, F_SETFD, FD_CLOEXEC);
fcntl(udp_fd, F_SETFL, O_NONBLOCK);
- SetUdpFd(udp_fd);
+ InitPollSlot(udp_fd, POLLIN);
+ udp_writable_ = true;
return true;
}
-void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) override {
- assert(pollfd_[kPollFdUdp].fd >= 0);
- Packet *queue_is_used = udp_queue_;
- *udp_queue_end_ = packet;
- udp_queue_end_ = &packet->next;
- packet->next = NULL;
- if (!queue_is_used)
- WriteToUdp();
+void UdpSocketBsd::HandleEvents(int revents) {
+ if (revents & (POLLERR | POLLHUP | POLLNVAL)) {
+ RERROR("UDP error %d, closing.", revents);
+ network_->PostExit();
+ } else {
+ udp_readable_ = (revents & POLLIN) != 0;
+ if (revents & POLLOUT) {
+ SetPollFlags(POLLIN);
+ udp_writable_ = true;
+ }
+ }
+ AddToRoundRobin();
}
-bool TunsafeBackendBsdImpl::InitializeUnixDomainSocket(const char *devname) {
+bool UdpSocketBsd::DoRead() {
+ socklen_t sin_len;
+ Packet *read_packet = network_->read_packet_;
+ if (read_packet == NULL)
+ network_->read_packet_ = read_packet = AllocPacket();
+
+ sin_len = sizeof(read_packet->addr.sin);
+ int r = recvfrom(fd_, read_packet->data, kPacketCapacity, 0,
+ (sockaddr*)&read_packet->addr.sin, &sin_len);
+ if (r >= 0) {
+ // printf("Read %d bytes from UDP\n", r);
+ read_packet->sin_size = sin_len;
+ read_packet->size = r;
+ read_packet->protocol = kPacketProtocolUdp;
+ network_->read_packet_ = NULL;
+ processor_->HandleUdpPacket(read_packet, network_->overload_);
+ return true;
+ } else {
+ if (errno != EAGAIN) {
+ fprintf(stderr, "Read from UDP failed\n");
+ }
+ udp_readable_ = false;
+ return false;
+ }
+}
+
+bool UdpSocketBsd::DoWrite() {
+ assert(udp_writable_);
+ // RINFO("Send %d bytes to %s", (int)udp_queue_->size, inet_ntoa(udp_queue_->sin.sin_addr));
+ int r = sendto(fd_, udp_queue_->data, udp_queue_->size, 0,
+ (sockaddr*)&udp_queue_->addr.sin, sizeof(udp_queue_->addr.sin));
+ if (r < 0) {
+ if (errno == EAGAIN) {
+ udp_writable_ = false;
+ SetPollFlags(POLLIN | POLLOUT);
+ return false;
+ }
+ perror("Write to UDP failed");
+ } else {
+ if (r != udp_queue_->size)
+ perror("Write to udp incomplete!");
+ // else
+ // RINFO("Wrote %d bytes to UDP", r);
+ }
+ Packet *next = Packet_NEXT(udp_queue_);
+ FreePacket(udp_queue_);
+ if ((udp_queue_ = next) != NULL) return true;
+ udp_queue_end_ = &udp_queue_;
+ return false;
+}
+
+void UdpSocketBsd::WritePacket(Packet *packet) {
+ assert(fd_ >= 0);
+ Packet *queue_is_used = udp_queue_;
+ *udp_queue_end_ = packet;
+ udp_queue_end_ = &Packet_NEXT(packet);
+ packet->queue_next = NULL;
+ if (!queue_is_used)
+ DoWrite();
+}
+
+bool UdpSocketBsd::DoRoundRobin() {
+ bool did_work = false;
+ if (udp_queue_ && udp_writable_)
+ did_work = DoWrite();
+ if (udp_readable_)
+ did_work |= DoRead();
+ return did_work;
+}
+
+//////////////////////////////////////////////////////////////////////////////////////////////
+
+#if defined(OS_LINUX)
+UnixSocketDeletionWatcher::UnixSocketDeletionWatcher()
+ : inotify_fd_(-1) {
+ pipes_[0] = -1;
+ pipes_[0] = -1;
+}
+
+UnixSocketDeletionWatcher::~UnixSocketDeletionWatcher() {
+ close(inotify_fd_);
+ close(pipes_[0]);
+ close(pipes_[1]);
+}
+
+bool UnixSocketDeletionWatcher::Start(const char *path, bool *flag_to_set) {
+ assert(inotify_fd_ == -1);
+ path_ = path;
+ flag_to_set_ = flag_to_set;
+ pid_ = getpid();
+ inotify_fd_ = inotify_init1(IN_CLOEXEC | IN_NONBLOCK);
+ if (inotify_fd_ == -1) {
+ perror("inotify_init1() failed");
+ return false;
+ }
+ if (inotify_add_watch(inotify_fd_, "/var/run/wireguard", IN_DELETE | IN_DELETE_SELF) == -1) {
+ perror("inotify_add_watch failed");
+ return false;
+ }
+ if (pipe(pipes_) == -1) {
+ perror("pipe() failed");
+ return false;
+ }
+ return pthread_create(&thread_, NULL, &UnixSocketDeletionWatcher::RunThread, this) == 0;
+}
+
+void UnixSocketDeletionWatcher::Stop() {
+ RINFO("Stopping..");
+ void *retval;
+ write(pipes_[1], "", 1);
+ pthread_join(thread_, &retval);
+}
+
+void *UnixSocketDeletionWatcher::RunThread(void *arg) {
+ UnixSocketDeletionWatcher *self = (UnixSocketDeletionWatcher*)arg;
+ return self->RunThreadInner();
+}
+
+void *UnixSocketDeletionWatcher::RunThreadInner() {
+ char buf[sizeof(struct inotify_event) + NAME_MAX + 1]
+ __attribute__ ((aligned(__alignof__(struct inotify_event))));
+ fd_set fdset;
+ struct stat st;
+ for(;;) {
+ if (lstat(path_, &st) == -1 && errno == ENOENT) {
+ RINFO("Unix socket %s deleted.", path_);
+ *flag_to_set_ = true;
+ kill(pid_, SIGALRM);
+ break;
+ }
+ FD_ZERO(&fdset);
+ FD_SET(inotify_fd_, &fdset);
+ FD_SET(pipes_[0], &fdset);
+ int n = select(std::max(inotify_fd_, pipes_[0]) + 1, &fdset, NULL, NULL, NULL);
+ if (n == -1) {
+ perror("select");
+ break;
+ }
+ if (FD_ISSET(inotify_fd_, &fdset)) {
+ ssize_t len = read(inotify_fd_, buf, sizeof(buf));
+ if (len == -1) {
+ perror("read");
+ break;
+ }
+ }
+ if (FD_ISSET(pipes_[0], &fdset))
+ break;
+ }
+ return NULL;
+}
+
+#else // !defined(OS_LINUX)
+
+bool UnixSocketDeletionWatcher::Poll(const char *path) {
+ struct stat st;
+ return lstat(path, &st) == -1 && errno == ENOENT;
+}
+
+#endif // !defined(OS_LINUX)
+
+UnixDomainSocketListenerBsd::UnixDomainSocketListenerBsd(NetworkBsd *network, WireguardProcessor *processor)
+ : BaseSocketBsd(network),
+ processor_(processor) {
+ memset(&un_addr_, 0, sizeof(un_addr_));
+}
+
+UnixDomainSocketListenerBsd::~UnixDomainSocketListenerBsd() {
+ if (un_addr_.sun_path[0])
+ unlink(un_addr_.sun_path);
+}
+
+bool UnixDomainSocketListenerBsd::Initialize(const char *devname) {
+ if (!HasFreePollSlot())
+ return false;
int fd = socket(AF_UNIX, SOCK_STREAM, 0);
if (fd == -1) {
RERROR("Error creating unix domain socket");
@@ -329,194 +662,408 @@ bool TunsafeBackendBsdImpl::InitializeUnixDomainSocket(const char *devname) {
close(fd);
return false;
}
-
- pollfd_[kPollFdUnix].fd = fd;
- pollfd_[kPollFdUnix].events = POLLIN;
-
+ InitPollSlot(fd, POLLIN);
return true;
}
+void UnixDomainSocketListenerBsd::HandleEvents(int revents) {
+ if (revents & POLLIN) {
+ // wait if we can't allocate more pollfd
+ if (!HasFreePollSlot()) {
+ SetPollFlags(0);
+ return;
+ }
+ int new_fd = accept(fd_, NULL, NULL);
+ if (new_fd >= 0) {
+ UnixDomainSocketChannelBsd *channel = new UnixDomainSocketChannelBsd(network_, processor_, new_fd);
+ } else {
+ RERROR("Unix domain socket accept failed");
+ }
+ }
+ if (revents & ~POLLIN) {
+ RERROR("Unix domain socket got an error code");
+ }
+}
+
+void UnixDomainSocketListenerBsd::Periodic() {
+ if (un_deletion_watcher_.Poll(un_addr_.sun_path)) {
+ RINFO("Unix socket %s deleted.", un_addr_.sun_path);
+ network_->PostExit();
+ } else {
+ // try again
+ SetPollFlags(POLLIN);
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////////////////////////
+
+UnixDomainSocketChannelBsd::UnixDomainSocketChannelBsd(NetworkBsd *network, WireguardProcessor *processor, int fd)
+ : BaseSocketBsd(network),
+ processor_(processor) {
+ assert(HasFreePollSlot());
+ InitPollSlot(fd, POLLIN);
+}
+
+UnixDomainSocketChannelBsd::~UnixDomainSocketChannelBsd() {
+}
+
static const char *FindMessageEnd(const char *start, size_t size) {
if (size <= 1)
return NULL;
const char *start_end = start + size - 1;
- for(;(start = (const char*)memchr(start, '\n', start_end - start)) != NULL; start++) {
+ for (; (start = (const char*)memchr(start, '\n', start_end - start)) != NULL; start++) {
if (start[1] == '\n')
return start + 2;
}
return NULL;
}
-bool TunsafeBackendBsdImpl::HandleSpecialPollfd(struct pollfd *pfd, struct SockInfo *sockinfo) {
- // handle domain socket thing
- if (sockinfo->is_listener) {
- if (pfd->revents & POLLIN) {
- // wait if we can't allocate more pollfd
- if (pollfd_num_ == kMaxPollFd) {
- pfd->events = 0;
- return true;
- }
- int fd = accept(pfd->fd, NULL, NULL);
- if (fd >= 0) {
- size_t slot = pollfd_num_++;
- pollfd_[slot].fd = fd;
- pollfd_[slot].events = POLLIN;
- pollfd_[slot].revents = 0;
- sockinfo_[slot - 2].is_listener = false;
- } else {
- RERROR("Unix domain socket accept failed");
- }
- }
- if (pfd->revents & ~POLLIN) {
- RERROR("Unix domain socket got an error code");
- return false;
- }
- return true;
- }
- if (pfd->revents & POLLIN) {
+bool UnixDomainSocketChannelBsd::HandleEventsInner(int revents) {
+ if (revents & POLLIN) {
char buf[4096];
// read as much data as we can until we see \n\n
- ssize_t n = recv(pfd->fd, buf, sizeof(buf), 0);
+ ssize_t n = recv(fd_, buf, sizeof(buf), 0);
if (n <= 0)
return (n == -1 && errno == EAGAIN); // premature eof or error
- sockinfo->inbuf.append(buf, n);
- const char *message_end = FindMessageEnd(&sockinfo->inbuf[0], sockinfo->inbuf.size());
+ inbuf_.append(buf, n);
+ const char *message_end = FindMessageEnd(&inbuf_[0], inbuf_.size());
if (message_end) {
- if (message_end != &sockinfo->inbuf[sockinfo->inbuf.size()])
+ if (message_end != &inbuf_[inbuf_.size()])
return false; // trailing data?
- WgConfig::HandleConfigurationProtocolMessage(processor_, std::move(sockinfo->inbuf), &sockinfo->outbuf);
- if (!sockinfo->outbuf.size())
+ WgConfig::HandleConfigurationProtocolMessage(processor_, std::move(inbuf_), &outbuf_);
+ if (!outbuf_.size())
return false;
- pfd->revents = pfd->events = POLLOUT;
+ SetPollFlags(POLLOUT);
+ revents |= POLLOUT;
}
}
- if (pfd->revents & POLLOUT) {
- size_t n = send(pfd->fd, sockinfo->outbuf.data(), sockinfo->outbuf.size(), 0);
+ if (revents & POLLOUT) {
+ size_t n = send(fd_, outbuf_.data(), outbuf_.size(), 0);
if (n <= 0)
return (n == -1 && errno == EAGAIN); // premature eof or error
- sockinfo->outbuf.erase(0, n);
- if (!sockinfo->outbuf.size())
+ outbuf_.erase(0, n);
+ if (!outbuf_.size())
return false;
}
-
- if (pfd->revents & ~(POLLIN | POLLOUT)) {
+ if (revents & ~(POLLIN | POLLOUT)) {
RERROR("Unix domain socket got an error code");
return false;
}
return true;
}
-void TunsafeBackendBsdImpl::CloseSpecialPollfd(size_t i) {
- close(pollfd_[i].fd);
- pollfd_[i].fd = -1;
- sockinfo_[i - 2].inbuf.clear();
- sockinfo_[i - 2].outbuf.clear();
- pollfd_[i] = pollfd_[(size_t)pollfd_num_ - 1];
- std::swap(sockinfo_[i - 2], sockinfo_[(size_t)pollfd_num_ - 1 - 2]);
-
- // Can now allow more sockets?
- if (pollfd_num_-- == kMaxPollFd && sockinfo_[kPollFdUnix - 2].is_listener)
- pollfd_[kPollFdUnix].events = POLLIN;
+void UnixDomainSocketChannelBsd::HandleEvents(int revents) {
+ if (!HandleEventsInner(revents))
+ delete this;
}
-void TunsafeBackendBsdImpl::RunLoopInner() {
- int free_packet_interval = 10;
- int overload_ctr = 0;
+//////////////////////////////////////////////////////////////////////////////////////////////
- if (!un_deletion_watcher_.Start(un_addr_.sun_path, &exit_))
- return;
+TcpSocketListenerBsd::TcpSocketListenerBsd(NetworkBsd *bsd, WireguardProcessor *processor)
+ : BaseSocketBsd(bsd),
+ processor_(processor) {
- while (!exit_) {
- int n = -1;
+}
- if (got_sig_alarm_) {
- got_sig_alarm_ = false;
+TcpSocketListenerBsd::~TcpSocketListenerBsd() {
- if (un_deletion_watcher_.Poll(un_addr_.sun_path)) {
- RINFO("Unix socket %s deleted.", un_addr_.sun_path);
- break;
- }
- processor_->SecondLoop();
+}
- if (free_packet_interval == 0) {
- FreePackets();
- free_packet_interval = 10;
- }
- free_packet_interval--;
+bool TcpSocketListenerBsd::Initialize(int port) {
+ if (!HasFreePollSlot())
+ return false;
- overload_ctr -= (overload_ctr != 0);
+ CloseSocket();
+
+ int fd = socket(AF_INET, SOCK_STREAM, 0);
+ if (fd < 0) { RERROR("Error creating listen socket"); return false; }
+ fcntl(fd, F_SETFD, FD_CLOEXEC);
+ fcntl(fd, F_SETFL, O_NONBLOCK);
+
+ int optval = 1;
+ // setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval));
+ setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval));
+
+ sockaddr_in sin = {0};
+ sin.sin_family = AF_INET;
+ sin.sin_port = htons(port);
+ sin.sin_addr.s_addr = INADDR_ANY;
+ if (bind(fd, (sockaddr*)&sin, sizeof(sin))) {
+ RERROR("Error binding socket on port %d", port);
+ close(fd);
+ return false;
+ }
+ if (listen(fd, 5)) {
+ RERROR("Error listen socket");
+ close(fd);
+ return false;
+ }
+ RINFO("Started TCP listening socket on port %d", port);
+ InitPollSlot(fd, POLLIN);
+ return true;
+}
+
+void TcpSocketListenerBsd::HandleEvents(int revents) {
+ if (revents & POLLIN) {
+ // wait if we can't allocate more pollfd
+ if (!HasFreePollSlot()) {
+ SetPollFlags(0);
+ return;
}
-
-#if defined(OS_LINUX) || defined(OS_FREEBSD)
- n = ppoll(pollfd_, pollfd_num_, NULL, &orig_signal_mask_);
-#else
- n = poll(pollfd_, pollfd_num_, -1);
-#endif
- if (n == -1) {
- if (errno != EINTR) {
- RERROR("poll failed");
- break;
- }
+ IpAddr addr;
+ socklen_t len = sizeof(addr);
+ int new_fd = accept(fd_, (sockaddr*)&addr, &len);
+ if (new_fd >= 0) {
+ RINFO("Created new tcp socket");
+ TcpSocketBsd *channel = new TcpSocketBsd(network_, processor_);
+ if (channel)
+ channel->InitializeIncoming(new_fd, addr);
+ else
+ close(new_fd);
} else {
-
- if (pollfd_[kPollFdTun].revents & (POLLERR | POLLHUP | POLLNVAL)) {
- if (pollfd_[kPollFdTun].revents & POLLERR) {
- tun_interface_gone_ = true;
- RERROR("Tun interface gone, closing.");
- } else {
- RERROR("Tun interface error %d, closing.", pollfd_[kPollFdTun].revents);
- }
- break;
- }
- tun_readable_ = (pollfd_[kPollFdTun].revents & POLLIN) != 0;
- if (pollfd_[kPollFdTun].revents & POLLOUT) {
- pollfd_[kPollFdTun].events = POLLIN;
- tun_writable_ = true;
- }
-
- if (pollfd_[kPollFdUdp].revents & (POLLERR | POLLHUP | POLLNVAL)) {
- RERROR("UDP error %d, closing.", pollfd_[kPollFdUdp].revents);
- break;
- }
-
- udp_readable_ = (pollfd_[kPollFdUdp].revents & POLLIN) != 0;
- if (pollfd_[kPollFdUdp].revents & POLLOUT) {
- pollfd_[kPollFdUdp].events = POLLIN;
- udp_writable_ = true;
- }
-
- for(size_t i = 2; i < pollfd_num_; i++) {
- if (pollfd_[i].revents && !HandleSpecialPollfd(&pollfd_[i], &sockinfo_[i - 2])) {
- // Close the fd / discard the sockinfo
- CloseSpecialPollfd(i);
- i--;
- }
- }
+ RERROR("Tcp listener socket accept failed");
}
-
- bool overload = (overload_ctr != 0);
-
- for(int loop = 0; ; loop++) {
- // Whenever we don't finish set overload ctr.
- if (loop == 256) {
- overload_ctr = 4;
- break;
- }
- bool more_work = false;
- if (tun_queue_ != NULL && tun_writable_) more_work |= WriteToTun();
- if (udp_queue_ != NULL && udp_writable_) more_work |= WriteToUdp();
- if (tun_readable_) more_work |= ReadFromTun();
- if (udp_readable_) more_work |= ReadFromUdp(overload);
- if (!more_work)
- break;
- }
-
- processor_->RunAllMainThreadScheduled();
- }
-
- un_deletion_watcher_.Stop();
+ }
}
-TunsafeBackendBsd *CreateTunsafeBackendBsd() {
- return new TunsafeBackendBsdImpl;
+void TcpSocketListenerBsd::Periodic() {
+ SetPollFlags(POLLIN);
+}
+//////////////////////////////////////////////////////////////////////////////////////////////
+
+TcpSocketBsd::TcpSocketBsd(NetworkBsd *net, WireguardProcessor *processor)
+ : BaseSocketBsd(net),
+ readable_(false),
+ writable_(true),
+ endpoint_protocol_(0),
+ age(0),
+ handshake_attempts(0),
+ wqueue_(NULL),
+ wqueue_end_(&wqueue_),
+ wqueue_bytes_(0),
+ processor_(processor),
+ tcp_packet_handler_(&net->packet_pool_) {
+ // insert in network's linked list
+ next_ = net->tcp_sockets_;
+ net->tcp_sockets_ = this;
+
+ network_->EnsureIovAllocated();
+}
+
+TcpSocketBsd::~TcpSocketBsd() {
+ // Unlink myself from the network's linked list.
+ TcpSocketBsd **p = &network_->tcp_sockets_;
+ while (*p != this) p = &(*p)->next_;
+ *p = next_;
+
+ RINFO("Destroyed tcp socket");
+}
+
+void TcpSocketBsd::InitializeIncoming(int fd, const IpAddr &addr) {
+ assert(fd_ == -1);
+ endpoint_protocol_ = kPacketProtocolTcp | kPacketProtocolIncomingConnection;
+ endpoint_ = addr;
+ InitPollSlot(fd, POLLIN);
+}
+
+bool TcpSocketBsd::InitializeOutgoing(const IpAddr &addr) {
+ assert(fd_ == -1);
+ if (!HasFreePollSlot() || addr.sin.sin_family == 0)
+ return false;
+
+ endpoint_protocol_ = kPacketProtocolTcp;
+ endpoint_ = addr;
+ writable_ = false;
+
+ int fd = socket(addr.sin.sin_family, SOCK_STREAM, 0);
+ if (fd < 0) { perror("socket: outgoing tcp"); return false; }
+ fcntl(fd, F_SETFD, FD_CLOEXEC);
+ fcntl(fd, F_SETFL, O_NONBLOCK);
+
+ char buf[kSizeOfAddress];
+ RINFO("Connecting to tcp://%s:%d...", PrintIpAddr(endpoint_, buf), ReadBE16(&endpoint_.sin.sin_port));
+
+ if (connect(fd, (sockaddr*)&endpoint_.sin,
+ endpoint_.sin.sin_family == AF_INET ? sizeof(sockaddr_in) : sizeof(sockaddr_in6))) {
+ if (errno != EINPROGRESS) {
+ perror("connect: outgoing tcp");
+ close(fd);
+ return false;
+ }
+ }
+
+ InitPollSlot(fd, POLLOUT | POLLIN);
+ return true;
+}
+
+void TcpSocketBsd::WritePacket(Packet *packet) {
+ assert(fd_ >= 0);
+
+ tcp_packet_handler_.AddHeaderToOutgoingPacket(packet);
+
+ Packet *old_value = wqueue_;
+ *wqueue_end_ = packet;
+ wqueue_end_ = &Packet_NEXT(packet);
+ packet->queue_next = NULL;
+
+ AddToEndLoop();
+
+ wqueue_bytes_ += packet->size;
+
+ // When many bytes have been queued, perform the write.
+ if (writable_ && wqueue_bytes_ >= 32768)
+ DoWrite();
+}
+
+void TcpSocketBsd::HandleEvents(int revents) {
+ if (revents & (POLLERR | POLLHUP | POLLNVAL)) {
+ RINFO("TcpSocket error");
+ CloseSocketAndDestroy();
+ return;
+ }
+
+ if (revents & POLLOUT) {
+ SetPollFlags(POLLIN);
+ AddToEndLoop();
+ writable_ = true;
+ }
+
+ if (revents & POLLIN)
+ DoRead();
+}
+
+void TcpSocketBsd::DoEndloop() {
+ if (writable_ && wqueue_)
+ DoWrite();
+}
+
+void TcpSocketBsd::DoRead() {
+ ssize_t bytes_read = readv(fd_, network_->iov_, NetworkBsd::kMaxIovec);
+ ssize_t bytes_read_org = bytes_read;
+ if (bytes_read < 0) {
+ if (errno != EAGAIN) {
+ RERROR("tcp readv says error code: %d", errno);
+ CloseSocketAndDestroy();
+ }
+ return;
+ }
+ // Go through and read the packet structures that are ready and queue them up
+ NetworkBsd *net = network_;
+ for (size_t j = 0; bytes_read; j++) {
+ size_t m = std::min<size_t>(bytes_read, net->iov_[j].iov_len);
+ Packet *p = net->iov_packets_[j];
+ p->size = (int)m;
+ bytes_read -= m;
+ tcp_packet_handler_.QueueIncomingPacket(p);
+ net->ReallocateIov(j);
+ }
+ // Parse it all
+ while (Packet *p = tcp_packet_handler_.GetNextWireguardPacket()) {
+ p->protocol = endpoint_protocol_;
+ p->addr = endpoint_;
+ processor_->HandleUdpPacket(p, network_->overload_);
+ }
+
+ if (tcp_packet_handler_.error() || bytes_read_org == 0)
+ CloseSocketAndDestroy();
+}
+
+void TcpSocketBsd::DoWrite() {
+ enum { kMaxIoWrite = 16 };
+ struct iovec vecs[kMaxIoWrite];
+ Packet *p = wqueue_;
+ size_t nvec = 0;
+ for (; p && nvec < kMaxIoWrite; nvec++, p = Packet_NEXT(p)) {
+ vecs[nvec].iov_base = p->data;
+ vecs[nvec].iov_len = p->size;
+ }
+ ssize_t n = writev(fd_, vecs, nvec);
+
+ if (n < 0) {
+ if (errno != EAGAIN) {
+ RERROR("tcp writev says error code: %d", errno);
+ CloseSocketAndDestroy();
+ } else {
+ writable_ = false;
+ SetPollFlags(POLLIN | POLLOUT);
+ }
+ return;
+ }
+ wqueue_bytes_ -= n;
+ // discard those initial n bytes worth of packets
+ size_t i = 0;
+ p = wqueue_;
+ while (n) {
+ if (n < p->size) {
+ p->data += n, p->size -= n;
+ break;
+ }
+ n -= p->size;
+ FreePacket(exch(p, Packet_NEXT(p)));
+ }
+ if (!(wqueue_ = p))
+ wqueue_end_ = &wqueue_;
+}
+
+void TcpSocketBsd::CloseSocketAndDestroy() {
+ delete this;
+}
+
+//////////////////////////////////////////////////////////////////////////////////////////////
+NotificationPipeBsd::NotificationPipeBsd(NetworkBsd *network)
+ : BaseSocketBsd(network),
+ injected_cb_(NULL) {
+
+ if (!HasFreePollSlot())
+ tunsafe_die("no free poll slots");
+
+#if !defined(OS_MACOSX)
+ if (pipe2(pipe_fds_, O_CLOEXEC | O_NONBLOCK))
+ tunsafe_die("pipe2 failed");
+#else
+ if (pipe(pipe_fds_))
+ tunsafe_die("pipe failed");
+ for (int i = 0; i < 2; i++) {
+ fcntl(pipe_fds_[i], F_SETFD, FD_CLOEXEC);
+ fcntl(pipe_fds_[i], F_SETFL, O_NONBLOCK);
+ }
+#endif
+
+
+ InitPollSlot(pipe_fds_[0], POLLIN);
+}
+
+NotificationPipeBsd::~NotificationPipeBsd() {
+}
+
+void NotificationPipeBsd::InjectCallback(CallbackFunc *func, void *param) {
+ CallbackState *st = new CallbackState;
+ st->func = func;
+ st->param = param;
+ // todo: support multiple writers?
+ st->next = injected_cb_.exchange(NULL);
+ injected_cb_.exchange(st);
+ write(pipe_fds_[1], "", 1);
+}
+
+void NotificationPipeBsd::Wakeup() {
+ write(pipe_fds_[1], "", 1);
+}
+
+void NotificationPipeBsd::HandleEvents(int revents) {
+ if (revents & (POLLERR | POLLHUP | POLLNVAL)) {
+ RERROR("Error with pipe() polling");
+ CloseSocket();
+ } else if (revents & POLLIN) {
+ char tmp[64];
+ read(fd_, tmp, sizeof(tmp));
+ if (CallbackState *cb = injected_cb_.exchange(NULL)) {
+ do {
+ CallbackState *next = cb->next;
+ cb->func(cb->param);
+ cb = next;
+ } while (cb);
+ }
+ }
+
}
diff --git a/network_bsd.h b/network_bsd.h
new file mode 100644
index 0000000..6f2c711
--- /dev/null
+++ b/network_bsd.h
@@ -0,0 +1,304 @@
+// SPDX-License-Identifier: AGPL-1.0-only
+// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
+#ifndef TUNSAFE_NETWORK_BSD_H_
+#define TUNSAFE_NETWORK_BSD_H_
+
+#include <signal.h>
+#include <sys/un.h>
+#include <pthread.h>
+#include <atomic>
+#include "network_common.h"
+
+class BaseSocketBsd;
+class TcpSocketBsd;
+class WireguardProcessor;
+class Packet;
+
+class NetworkBsd {
+ friend class BaseSocketBsd;
+ friend class TcpSocketBsd;
+ friend class UdpSocketBsd;
+ friend class TunSocketBsd;
+public:
+ enum {
+#if defined(OS_ANDROID)
+ WithSigalarmSupport = 0,
+#else
+ WithSigalarmSupport = 1
+#endif
+ };
+
+ class NetworkBsdDelegate {
+ public:
+ virtual void OnSecondLoop(uint64 now) {}
+ virtual void RunAllMainThreadScheduled() {}
+ };
+
+ explicit NetworkBsd(NetworkBsdDelegate *delegate, int max_sockets);
+ ~NetworkBsd();
+
+ void RunLoop(const sigset_t *sigmask);
+ void PostExit() { exit_ = true; }
+
+ bool *exit_flag() { return &exit_; }
+ bool *sigalarm_flag() { return &sigalarm_flag_; }
+
+ TcpSocketBsd *tcp_sockets() { return tcp_sockets_; }
+ bool overload() { return overload_; }
+private:
+ void RemoveFromRoundRobin(int slot);
+
+ void ReallocateIov(size_t i);
+ void EnsureIovAllocated();
+
+ Packet *read_packet_;
+ bool exit_;
+ bool overload_;
+ bool sigalarm_flag_;
+
+ enum {
+ // This controls the max # of sockets we can support
+ kMaxIovec = 16,
+ };
+ int num_sock_;
+ int num_roundrobin_;
+ int num_endloop_;
+ int max_sockets_;
+
+ SimplePacketPool packet_pool_;
+ NetworkBsdDelegate *delegate_;
+
+ struct pollfd *pollfd_;
+ BaseSocketBsd **sockets_;
+ BaseSocketBsd **roundrobin_;
+ BaseSocketBsd **endloop_;
+
+ // Linked list of all tcp sockets
+ TcpSocketBsd *tcp_sockets_;
+
+ struct iovec iov_[kMaxIovec];
+ Packet *iov_packets_[kMaxIovec];
+
+};
+
+class BaseSocketBsd {
+ friend class NetworkBsd;
+public:
+ BaseSocketBsd(NetworkBsd *network) : pollfd_slot_(-1), roundrobin_slot_(-1), endloop_slot_(-1), fd_(-1), network_(network) {}
+ virtual ~BaseSocketBsd();
+
+ virtual void HandleEvents(int revents) = 0;
+
+ // Return |false| to remove socket from roundrobin list.
+ virtual bool DoRoundRobin() { return false; }
+ virtual void DoEndloop() {}
+ virtual void Periodic() {}
+
+ // Make sure this socket gets called during each round robin step.
+ void AddToRoundRobin();
+
+ // Make sure this sockets get called at the end of the loop
+ void AddToEndLoop();
+
+ int GetFd() { return fd_; }
+
+protected:
+ void SetPollFlags(int events) {
+ network_->pollfd_[pollfd_slot_].events = events;
+ }
+ void InitPollSlot(int fd, int events);
+ bool HasFreePollSlot() { return network_->num_sock_ != network_->max_sockets_; }
+ void CloseSocket();
+
+ NetworkBsd *network_;
+ int pollfd_slot_;
+ int roundrobin_slot_;
+ int endloop_slot_;
+ int fd_;
+};
+
+class TunSocketBsd : public BaseSocketBsd {
+public:
+ explicit TunSocketBsd(NetworkBsd *network, WireguardProcessor *processor);
+ virtual ~TunSocketBsd();
+
+ bool Initialize(int fd);
+
+ virtual void HandleEvents(int revents) override;
+ virtual bool DoRoundRobin() override;
+
+ void WritePacket(Packet *packet);
+
+ bool tun_interface_gone() const { return tun_interface_gone_; }
+
+private:
+ bool DoRead();
+ bool DoWrite();
+
+ bool tun_readable_, tun_writable_;
+ bool tun_interface_gone_;
+ Packet *tun_queue_, **tun_queue_end_;
+ WireguardProcessor *processor_;
+};
+
+class UdpSocketBsd : public BaseSocketBsd {
+public:
+ explicit UdpSocketBsd(NetworkBsd *network, WireguardProcessor *processor);
+ virtual ~UdpSocketBsd();
+
+ bool Initialize(int listen_port);
+
+ virtual void HandleEvents(int revents) override;
+ virtual bool DoRoundRobin() override;
+
+ bool DoRead();
+ bool DoWrite();
+
+ void WritePacket(Packet *packet);
+
+private:
+ bool udp_readable_, udp_writable_;
+ Packet *udp_queue_, **udp_queue_end_;
+ WireguardProcessor *processor_;
+};
+
+#if defined(OS_LINUX)
+// Keeps track of when the unix socket gets deleted
+class UnixSocketDeletionWatcher {
+public:
+ UnixSocketDeletionWatcher();
+ ~UnixSocketDeletionWatcher();
+ bool Start(const char *path, bool *flag_to_set);
+ void Stop();
+ bool Poll(const char *path) { return false; }
+
+private:
+ static void *RunThread(void *arg);
+ void *RunThreadInner();
+ const char *path_;
+ int inotify_fd_;
+ int pid_;
+ int pipes_[2];
+ pthread_t thread_;
+ bool *flag_to_set_;
+};
+#else // !defined(OS_LINUX)
+// all other platforms that lack inotify
+class UnixSocketDeletionWatcher {
+public:
+ UnixSocketDeletionWatcher() {}
+ ~UnixSocketDeletionWatcher() {}
+ bool Start(const char *path, bool *flag_to_set) { return true; }
+ void Stop() {}
+ bool Poll(const char *path);
+};
+#endif // !defined(OS_LINUX)
+
+class UnixDomainSocketListenerBsd : public BaseSocketBsd {
+public:
+ explicit UnixDomainSocketListenerBsd(NetworkBsd *network, WireguardProcessor *processor);
+ virtual ~UnixDomainSocketListenerBsd();
+
+ bool Initialize(const char *devname);
+
+ bool Start(bool *exit_flag) {
+ return un_deletion_watcher_.Start(un_addr_.sun_path, exit_flag);
+ }
+ void Stop() { un_deletion_watcher_.Stop(); }
+
+ virtual void HandleEvents(int revents) override;
+ virtual void Periodic() override;
+private:
+ struct sockaddr_un un_addr_;
+ WireguardProcessor *processor_;
+ UnixSocketDeletionWatcher un_deletion_watcher_;
+};
+
+class UnixDomainSocketChannelBsd : public BaseSocketBsd {
+public:
+ explicit UnixDomainSocketChannelBsd(NetworkBsd *network, WireguardProcessor *processor, int fd);
+ virtual ~UnixDomainSocketChannelBsd();
+
+ virtual void HandleEvents(int revents) override;
+
+private:
+ bool HandleEventsInner(int revents);
+ WireguardProcessor *processor_;
+ std::string inbuf_, outbuf_;
+};
+
+class TcpSocketListenerBsd : public BaseSocketBsd {
+public:
+ explicit TcpSocketListenerBsd(NetworkBsd *bsd, WireguardProcessor *processor);
+ virtual ~TcpSocketListenerBsd();
+
+ bool Initialize(int port);
+
+ virtual void HandleEvents(int revents) override;
+ virtual void Periodic() override;
+
+private:
+ WireguardProcessor *processor_;
+};
+
+class TcpSocketBsd : public BaseSocketBsd {
+public:
+ explicit TcpSocketBsd(NetworkBsd *bsd, WireguardProcessor *processor);
+ virtual ~TcpSocketBsd();
+
+ void InitializeIncoming(int fd, const IpAddr &addr);
+ bool InitializeOutgoing(const IpAddr &addr);
+
+ void WritePacket(Packet *packet);
+
+ virtual void HandleEvents(int revents) override;
+ virtual void DoEndloop() override;
+
+ TcpSocketBsd *next() { return next_; }
+ uint8 endpoint_protocol() { return endpoint_protocol_; }
+ const IpAddr &endpoint() { return endpoint_; }
+
+public:
+ uint8 age;
+ uint8 handshake_attempts;
+private:
+ void DoRead();
+ void DoWrite();
+ void CloseSocketAndDestroy();
+
+ bool readable_, writable_;
+ bool got_eof_;
+ uint8 endpoint_protocol_;
+ bool want_connect_;
+
+ uint32 wqueue_bytes_;
+ Packet *wqueue_, **wqueue_end_;
+ TcpSocketBsd *next_;
+ WireguardProcessor *processor_;
+ TcpPacketHandler tcp_packet_handler_;
+ IpAddr endpoint_;
+};
+
+class NotificationPipeBsd : public BaseSocketBsd {
+public:
+ NotificationPipeBsd(NetworkBsd *network);
+ ~NotificationPipeBsd();
+
+ typedef void CallbackFunc(void *x);
+ void InjectCallback(CallbackFunc *func, void *param);
+ void Wakeup();
+
+ virtual void HandleEvents(int revents) override;
+
+private:
+ struct CallbackState {
+ CallbackFunc *func;
+ void *param;
+ CallbackState *next;
+ };
+ int pipe_fds_[2];
+ std::atomic<CallbackState*> injected_cb_;
+};
+
+
+#endif // TUNSAFE_NETWORK_BSD_H_
\ No newline at end of file
diff --git a/network_bsd_common.h b/network_bsd_common.h
deleted file mode 100644
index c4ef5ad..0000000
--- a/network_bsd_common.h
+++ /dev/null
@@ -1,101 +0,0 @@
-// SPDX-License-Identifier: AGPL-1.0-only
-// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
-#ifndef TUNSAFE_NETWORK_BSD_COMMON_H_
-#define TUNSAFE_NETWORK_BSD_COMMON_H_
-
-#include "netapi.h"
-#include "wireguard.h"
-#include "wireguard_config.h"
-#include <signal.h>
-#include <pthread.h>
-
-struct RouteInfo {
- uint8 family;
- uint8 cidr;
- uint8 ip[16];
- uint8 gw[16];
- std::string dev;
-};
-
-#if defined(OS_LINUX)
-// Keeps track of when the unix socket gets deleted
-class UnixSocketDeletionWatcher {
-public:
- UnixSocketDeletionWatcher();
- ~UnixSocketDeletionWatcher();
- bool Start(const char *path, bool *flag_to_set);
- void Stop();
- bool Poll(const char *path) { return false; }
-
-private:
- static void *RunThread(void *arg);
- void *RunThreadInner();
- const char *path_;
- int inotify_fd_;
- int pid_;
- int pipes_[2];
- pthread_t thread_;
- bool *flag_to_set_;
-};
-#else // !defined(OS_LINUX)
-// all other platforms that lack inotify
-class UnixSocketDeletionWatcher {
-public:
- UnixSocketDeletionWatcher() {}
- ~UnixSocketDeletionWatcher() {}
- bool Start(const char *path, bool *flag_to_set) { return true; }
- void Stop() {}
- bool Poll(const char *path);
-};
-#endif // !defined(OS_LINUX)
-
-
-class TunsafeBackendBsd : public TunInterface, public UdpInterface {
-public:
- TunsafeBackendBsd();
- virtual ~TunsafeBackendBsd();
-
- void RunLoop();
- void CleanupRoutes();
-
- void SetTunDeviceName(const char *name);
-
- void SetProcessor(WireguardProcessor *wg) { processor_ = wg; }
-
- // -- from TunInterface
- virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
-
- virtual void HandleSigAlrm() = 0;
- virtual void HandleExit() = 0;
-
-protected:
- virtual bool InitializeTun(char devname[16]) = 0;
- virtual void RunLoopInner() = 0;
-
- void AddRoute(uint32 ip, uint32 cidr, uint32 gw, const char *dev);
- void DelRoute(const RouteInfo &cd);
- bool AddRoute(int family, const void *dest, int dest_prefix, const void *gateway, const char *dev);
- bool RunPrePostCommand(const std::vector<std::string> &vec);
-
- WireguardProcessor *processor_;
- std::vector<RouteInfo> cleanup_commands_;
- std::vector<std::string> pre_down_, post_down_;
- std::vector<WgCidrAddr> addresses_to_remove_;
- sigset_t orig_signal_mask_;
- char devname_[16];
- bool tun_interface_gone_;
-};
-
-#if defined(OS_MACOSX) || defined(OS_FREEBSD)
-#define TUN_PREFIX_BYTES 4
-#elif defined(OS_LINUX)
-#define TUN_PREFIX_BYTES 0
-#endif
-
-int open_tun(char *devname, size_t devname_size);
-int open_udp(int listen_on_port);
-
-void SetThreadName(const char *name);
-TunsafeBackendBsd *CreateTunsafeBackendBsd();
-
-#endif // TUNSAFE_NETWORK_BSD_COMMON_H_
\ No newline at end of file
diff --git a/network_common.cpp b/network_common.cpp
new file mode 100644
index 0000000..919e7fd
--- /dev/null
+++ b/network_common.cpp
@@ -0,0 +1,174 @@
+#include "stdafx.h"
+#include "network_common.h"
+#include "netapi.h"
+#include "tunsafe_endian.h"
+#include <string.h>
+#include <assert.h>
+#include "util.h"
+
+TcpPacketHandler::TcpPacketHandler(SimplePacketPool *packet_pool) {
+ packet_pool_ = packet_pool;
+ rqueue_bytes_ = 0;
+ error_flag_ = false;
+ rqueue_ = NULL;
+ rqueue_end_ = &rqueue_;
+ predicted_key_in_ = predicted_key_out_ = 0;
+ predicted_serial_in_ = predicted_serial_out_ = 0;
+}
+
+TcpPacketHandler::~TcpPacketHandler() {
+ FreePacketList(rqueue_);
+}
+
+enum {
+ kTcpPacketType_Normal = 0,
+ kTcpPacketType_Reserved = 1,
+ kTcpPacketType_Data = 2,
+ kTcpPacketType_Control = 3,
+ kTcpPacketControlType_SetKeyAndCounter = 0,
+};
+
+void TcpPacketHandler::AddHeaderToOutgoingPacket(Packet *p) {
+ unsigned int size = p->size;
+ uint8 *data = p->data;
+ if (size >= 16 && ReadLE32(data) == 4) {
+ uint32 key = Read32(data + 4);
+ uint64 serial = ReadLE64(data + 8);
+ WriteBE16(data + 14, size - 16 + (kTcpPacketType_Data << 14));
+ data += 14, size -= 14;
+ // Insert a 15 byte control packet right before to set the new key/serial?
+ if ((predicted_key_out_ ^ key) | (predicted_serial_out_ ^ serial)) {
+ predicted_key_out_ = key;
+ WriteLE64(data - 8, serial);
+ Write32(data - 12, key);
+ data[-13] = kTcpPacketControlType_SetKeyAndCounter;
+ WriteBE16(data - 15, 13 + (kTcpPacketType_Control << 14));
+ data -= 15, size += 15;
+ }
+ // Increase the serial by 1 for next packet.
+ predicted_serial_out_ = serial + 1;
+ } else {
+ WriteBE16(data - 2, size);
+ data -= 2, size += 2;
+ }
+ p->size = size;
+ p->data = data;
+}
+
+void TcpPacketHandler::QueueIncomingPacket(Packet *p) {
+ rqueue_bytes_ += p->size;
+ p->queue_next = NULL;
+ *rqueue_end_ = p;
+ rqueue_end_ = &Packet_NEXT(p);
+}
+
+// Either the packet fits in one buf or not.
+static uint32 ReadPacketHeader(Packet *p) {
+ if (p->size >= 2)
+ return ReadBE16(p->data);
+ else
+ return (p->data[0] << 8) + (Packet_NEXT(p)->data[0]);
+}
+
+// Move data around to ensure that exactly the first |num| bytes are stored
+// in the first packet, and the rest of the data in subsequent packets.
+Packet *TcpPacketHandler::ReadNextPacket(uint32 num) {
+ Packet *p = rqueue_;
+
+ assert(num <= kPacketCapacity);
+ if (p->size < num) {
+ // There's not enough data in the current packet, copy data from the next packet
+ // into this packet.
+ if ((uint32)(&p->data_buf[kPacketCapacity] - p->data) < num) {
+ // Move data up front to make space.
+ memmove(p->data_buf, p->data, p->size);
+ p->data = p->data_buf;
+ }
+ // Copy data from future packets into p, and delete them should they become empty.
+ do {
+ Packet *n = Packet_NEXT(p);
+ uint32 bytes_to_copy = std::min(n->size, num - p->size);
+ uint32 nsize = (n->size -= bytes_to_copy);
+ memcpy(p->data + postinc(p->size, bytes_to_copy), postinc(n->data, bytes_to_copy), bytes_to_copy);
+ if (nsize == 0) {
+ p->queue_next = n->queue_next;
+ packet_pool_->FreePacketToPool(n);
+ }
+ } while (num - p->size);
+ } else if (p->size > num) {
+ // The packet has too much data. Split the packet into two packets.
+ Packet *n = packet_pool_->AllocPacketFromPool();
+ if (!n)
+ return NULL; // unable to allocate a packet....?
+ if (num * 2 <= p->size) {
+ // There's a lot of trailing data: PP NNNNNN. Move PP.
+ n->size = num;
+ p->size -= num;
+ rqueue_bytes_ -= num;
+ memcpy(n->data, postinc(p->data, num), num);
+ return n;
+ } else {
+ uint32 overflow = p->size - num;
+ // There's a lot of leading data: PPPPPP NN. Move NN
+ n->size = overflow;
+ p->size = num;
+ rqueue_ = n;
+ if (!(n->queue_next = p->queue_next))
+ rqueue_end_ = &Packet_NEXT(n);
+ rqueue_bytes_ -= num;
+ memcpy(n->data, p->data + num, overflow);
+ return p;
+ }
+ }
+ if ((rqueue_ = Packet_NEXT(p)) == NULL)
+ rqueue_end_ = &rqueue_;
+ rqueue_bytes_ -= num;
+ return p;
+}
+
+Packet *TcpPacketHandler::GetNextWireguardPacket() {
+ while (rqueue_bytes_ >= 2) {
+ uint32 packet_header = ReadPacketHeader(rqueue_);
+ uint32 packet_size = packet_header & 0x3FFF;
+ uint32 packet_type = packet_header >> 14;
+ if (packet_size + 2 > rqueue_bytes_)
+ return NULL;
+ if (packet_size + 2 > kPacketCapacity) {
+ RERROR("Oversized packet?");
+ error_flag_ = true;
+ return NULL;
+ }
+ Packet *packet = ReadNextPacket(packet_size + 2);
+ if (packet) {
+// RINFO("Packet of type %d, size %d", packet_type, packet->size - 2);
+ packet->data += 2, packet->size -= 2;
+ if (packet_type == kTcpPacketType_Normal) {
+
+ return packet;
+ } else if (packet_type == kTcpPacketType_Data) {
+ // Optimization when the 16 first bytes are known and prefixed to the packet
+ assert(packet->data >= packet->data_buf);
+
+ packet->data -= 16, packet->size += 16;
+ WriteLE32(packet->data, 4);
+ Write32(packet->data + 4, predicted_key_in_);
+ WriteLE64(packet->data + 8, predicted_serial_in_);
+ predicted_serial_in_++;
+ return packet;
+ } else if (packet_type == kTcpPacketType_Control) {
+ // Unknown control packets are silently ignored
+ if (packet->size == 13 && packet->data[0] == kTcpPacketControlType_SetKeyAndCounter) {
+ // Control packet to setup the predicted key/sequence nr
+ predicted_key_in_ = Read32(packet->data + 1);
+ predicted_serial_in_ = ReadLE64(packet->data + 5);
+ }
+ packet_pool_->FreePacketToPool(packet);
+ } else {
+ packet_pool_->FreePacketToPool(packet);
+ error_flag_ = true;
+ return NULL;
+ }
+ }
+ }
+ return NULL;
+}
\ No newline at end of file
diff --git a/network_common.h b/network_common.h
new file mode 100644
index 0000000..8814cd4
--- /dev/null
+++ b/network_common.h
@@ -0,0 +1,95 @@
+#ifndef TUNSAFE_NETWORK_COMMON_H_
+#define TUNSAFE_NETWORK_COMMON_H_
+
+#include "netapi.h"
+
+class PacketProcessor;
+
+// A simple singlethreaded pool of packets used on windows where
+// FreePacket / AllocPacket are multithreded and thus slightly slower
+#if defined(OS_WIN)
+class SimplePacketPool {
+public:
+ explicit SimplePacketPool() {
+ freed_packets_ = NULL;
+ freed_packets_count_ = 0;
+ }
+ ~SimplePacketPool() {
+ FreePacketList(freed_packets_);
+ }
+ Packet *AllocPacketFromPool() {
+ if (Packet *p = freed_packets_) {
+ freed_packets_ = Packet_NEXT(p);
+ freed_packets_count_--;
+ p->Reset();
+ return p;
+ }
+ return AllocPacket();
+ }
+ void FreePacketToPool(Packet *p) {
+ Packet_NEXT(p) = freed_packets_;
+ freed_packets_ = p;
+ freed_packets_count_++;
+ }
+ void FreeSomePackets() {
+ if (freed_packets_count_ > 32)
+ FreeSomePacketsInner();
+ }
+ void FreeSomePacketsInner();
+
+
+ int freed_packets_count_;
+ Packet *freed_packets_;
+};
+#else
+class SimplePacketPool {
+public:
+ Packet *AllocPacketFromPool() {
+ return AllocPacket();
+ }
+ void FreePacketToPool(Packet *packet) {
+ return FreePacket(packet);
+ }
+};
+#endif
+
+
+
+// Aids with prefixing and parsing incoming and outgoing
+// packets with the tcp protocol header.
+class TcpPacketHandler {
+public:
+ explicit TcpPacketHandler(SimplePacketPool *packet_pool);
+ ~TcpPacketHandler();
+
+ // Adds a tcp header to a data packet so it can be transmitted on the wire
+ void AddHeaderToOutgoingPacket(Packet *p);
+
+ // Add a new chunk of incoming data to the packet list
+ void QueueIncomingPacket(Packet *p);
+
+ // Attempt to extract the next packet, returns NULL when complete.
+ Packet *GetNextWireguardPacket();
+
+ bool error() const { return error_flag_; }
+
+private:
+ // Internal function to read a packet
+ Packet *ReadNextPacket(uint32 num);
+
+ SimplePacketPool *packet_pool_;
+
+ // Total # of bytes queued
+ uint32 rqueue_bytes_;
+
+ // Set if there's a fatal error
+ bool error_flag_;
+
+ // These hold the incoming packets before they're parsed
+ Packet *rqueue_, **rqueue_end_;
+
+ uint32 predicted_key_in_, predicted_key_out_;
+ uint64 predicted_serial_in_, predicted_serial_out_;
+};
+
+#endif // TUNSAFE_NETWORK_COMMON_H_
\ No newline at end of file
diff --git a/network_win32.cpp b/network_win32.cpp
index 2a4bfe7..a946bc3 100644
--- a/network_win32.cpp
+++ b/network_win32.cpp
@@ -5,6 +5,8 @@
#include "wireguard_config.h"
#include "netapi.h"
#include
+#include
+#include
#include
#include
#include
@@ -12,7 +14,6 @@
#include
#include
#include
-#include
#include
#include
#include "tunsafe_endian.h"
@@ -42,15 +43,20 @@ static HKEY g_hklm_reg_key;
static uint8 g_killswitch_curr, g_killswitch_want, g_killswitch_currconn;
bool g_allow_pre_post;
+static volatile bool g_fail_malloc_flag;
static void DeactivateKillSwitch(uint32 want);
Packet *AllocPacket() {
Packet *packet = (Packet*)InterlockedPopEntrySList(&freelist_head);
- if (packet == NULL)
- packet = (Packet *)_aligned_malloc(kPacketAllocSize, 16);
- packet->data = packet->data_buf + Packet::HEADROOM_BEFORE;
- packet->size = 0;
+ if (packet == NULL) {
+ while ((packet = (Packet *)_aligned_malloc(kPacketAllocSize, 16)) == NULL) {
+ if (g_fail_malloc_flag)
+ return NULL;
+ Sleep(1000);
+ }
+ }
+ packet->Reset();
return packet;
}
@@ -83,6 +89,14 @@ void FreeAllPackets() {
}
}
+void SimplePacketPool::FreeSomePacketsInner() {
+ int n = freed_packets_count_ - 24;
+ Packet **p = &freed_packets_;
+ for (; n; n--)
+ p = &Packet_NEXT(*p);
+ FreePackets(exch(freed_packets_, *p), p, exch(freed_packets_count_, 24) - 24);
+}
+
void InitPacketMutexes() {
static bool mutex_inited;
if (!mutex_inited) {
@@ -91,17 +105,6 @@ void InitPacketMutexes() {
}
}
-int tpq_last_qsize;
-int g_tun_reads, g_tun_writes;
-
-struct {
- uint32 pad1[3];
- uint32 udp_qsize1;
- uint32 pad2[3];
- uint32 udp_qsize2;
-} qs;
-
-
#define kConcurrentReadTap 16
#define kConcurrentWriteTap 16
@@ -399,47 +402,13 @@ static bool GetDefaultRouteAndDeleteOldRoutes(int family, const NET_LUID *Interf
return (rv == 0);
}
-static inline bool NoMoreAllocationRetry(volatile bool *exit_flag) {
- if (*exit_flag)
- return true;
- Sleep(1000);
- return *exit_flag;
-}
-
-static inline bool AllocPacketFrom(Packet **list, int *counter, bool *exit_flag, Packet **res) {
- Packet *p;
- if (p = *list) {
- *list = Packet_NEXT(p);
- (*counter)--;
- p->data = p->data_buf + Packet::HEADROOM_BEFORE;
- } else {
- while ((p = AllocPacket()) == NULL) {
- if (NoMoreAllocationRetry(exit_flag))
- return false;
- }
- }
- *res = p;
- return true;
-}
-
-static void FreePacketList(Packet *pp) {
+void FreePacketList(Packet *pp) {
while (Packet *p = pp) {
pp = Packet_NEXT(p);
FreePacket(p);
}
}
-inline void NetworkWin32::FreePacketToPool(Packet *p) {
- Packet_NEXT(p) = NULL;
- *freed_packets_end_ = p;
- freed_packets_end_ = &Packet_NEXT(p);
- freed_packets_count_++;
-}
-
-inline bool NetworkWin32::AllocPacketFromPool(Packet **p) {
- return AllocPacketFrom(&freed_packets_, &freed_packets_count_, &exit_thread_, p);
-}
-
UdpSocketWin32::UdpSocketWin32(NetworkWin32 *network_win32) {
network_ = network_win32;
wqueue_end_ = &wqueue_;
@@ -455,6 +424,9 @@ UdpSocketWin32::UdpSocketWin32(NetworkWin32 *network_win32) {
num_reads_[0] = num_reads_[1] = 0;
num_writes_ = 0;
pending_writes_ = NULL;
+
+ qsize1_ = 0;
+ qsize2_ = 0;
}
UdpSocketWin32::~UdpSocketWin32() {
@@ -529,12 +501,12 @@ fail:
// Called on another thread to queue up a udp packet
void UdpSocketWin32::WriteUdpPacket(Packet *packet) {
- if (qs.udp_qsize2 - qs.udp_qsize1 >= (unsigned)(packet->size < 576 ? MAX_BYTES_IN_UDP_OUT_QUEUE_SMALL : MAX_BYTES_IN_UDP_OUT_QUEUE)) {
+ if (qsize2_ - qsize1_ >= (unsigned)(packet->size < 576 ? MAX_BYTES_IN_UDP_OUT_QUEUE_SMALL : MAX_BYTES_IN_UDP_OUT_QUEUE)) {
FreePacket(packet);
return;
}
- Packet_NEXT(packet) = NULL;
- qs.udp_qsize2 += packet->size;
+ packet->queue_next = NULL;
+ qsize2_ += packet->size;
mutex_.Acquire();
Packet *was_empty = wqueue_;
@@ -542,20 +514,14 @@ void UdpSocketWin32::WriteUdpPacket(Packet *packet) {
wqueue_end_ = &Packet_NEXT(packet);
mutex_.Release();
- if (was_empty == NULL) {
- // Notify the worker thread that it should attempt more writes
- PostQueuedCompletionStatus(network_->completion_port_handle_, NULL, NULL, NULL);
- }
+ if (was_empty == NULL)
+ network_->WakeUp();
}
enum {
kUdpGetQueuedCompletionStatusSize = kConcurrentWriteTap + kConcurrentReadTap + 1
};
-static inline void ClearOverlapped(OVERLAPPED *o) {
- memset(o, 0, sizeof(*o));
-}
-
#ifndef STATUS_PORT_UNREACHABLE
#define STATUS_PORT_UNREACHABLE 0xC000023F
#endif
@@ -567,8 +533,8 @@ static inline bool IsIgnoredUdpError(DWORD err) {
void UdpSocketWin32::DoMoreReads() {
// Listen with multiple ipv6 packets only if we ever sent an ipv6 packet.
for (int i = num_reads_[IPV6]; i < max_read_ipv6_; i++) {
- Packet *p;
- if (!network_->AllocPacketFromPool(&p))
+ Packet *p = network_->packet_pool().AllocPacketFromPool();
+ if (!p)
break;
restart_read_udp6:
ClearOverlapped(&p->overlapped);
@@ -590,32 +556,35 @@ restart_read_udp6:
num_reads_[IPV6]++;
}
// Initiate more reads, reusing the Packet structures in |finished_writes|.
- for (int i = num_reads_[IPV4]; i < kConcurrentReadUdp; i++) {
- Packet *p;
- if (!network_->AllocPacketFromPool(&p))
- break;
-restart_read_udp:
- ClearOverlapped(&p->overlapped);
- WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data};
- DWORD flags = 0;
- p->userdata = IPV4;
- p->sin_size = sizeof(p->addr.sin);
- p->queue_cb = this;
- if (WSARecvFrom(socket_, &wsabuf, 1, NULL, &flags, (struct sockaddr*)&p->addr, &p->sin_size, &p->overlapped, NULL) != 0) {
- DWORD err = WSAGetLastError();
- if (err != WSA_IO_PENDING) {
- if (err == WSAEMSGSIZE || err == WSAECONNRESET || err == WSAENETRESET)
- goto restart_read_udp;
- RERROR("UdpSocketWin32:WSARecvFrom failed 0x%X", err);
- FreePacket(p);
+
+ if (socket_ != INVALID_SOCKET) {
+ for (int i = num_reads_[IPV4]; i < kConcurrentReadUdp; i++) {
+ Packet *p = network_->packet_pool().AllocPacketFromPool();
+ if (!p)
break;
+restart_read_udp:
+ ClearOverlapped(&p->overlapped);
+ WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data};
+ DWORD flags = 0;
+ p->userdata = IPV4;
+ p->sin_size = sizeof(p->addr.sin);
+ p->queue_cb = this;
+ if (WSARecvFrom(socket_, &wsabuf, 1, NULL, &flags, (struct sockaddr*)&p->addr, &p->sin_size, &p->overlapped, NULL) != 0) {
+ DWORD err = WSAGetLastError();
+ if (err != WSA_IO_PENDING) {
+ if (err == WSAEMSGSIZE || err == WSAECONNRESET || err == WSAENETRESET)
+ goto restart_read_udp;
+ RERROR("UdpSocketWin32:WSARecvFrom failed 0x%X", err);
+ FreePacket(p);
+ break;
+ }
}
+ num_reads_[IPV4]++;
}
- num_reads_[IPV4]++;
}
}
-void UdpSocketWin32::DoMoreWrites() {
+void UdpSocketWin32::ProcessPackets() {
// Push all the finished reads to the packet handler
if (finished_reads_ != NULL) {
packet_handler_->PostPackets(finished_reads_, finished_reads_end_, finished_reads_count_);
@@ -623,7 +592,9 @@ void UdpSocketWin32::DoMoreWrites() {
finished_reads_end_ = &finished_reads_;
finished_reads_count_ = 0;
}
-
+}
+
+void UdpSocketWin32::DoMoreWrites() {
Packet *pending_writes = pending_writes_;
// Initiate more writes from |wqueue_|
while (num_writes_ < kConcurrentWriteUdp) {
@@ -639,7 +610,7 @@ void UdpSocketWin32::DoMoreWrites() {
if (!pending_writes)
break;
}
- qs.udp_qsize1 += pending_writes->size;
+ qsize1_ += pending_writes->size;
// Then issue writes
Packet *p = pending_writes;
@@ -688,13 +659,14 @@ void UdpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) {
if (p->userdata < 2) {
num_reads_[p->userdata]--;
if ((DWORD)p->overlapped.Internal != 0) {
- network_->FreePacketToPool(p);
if (!IsIgnoredUdpError((DWORD)p->overlapped.Internal))
RERROR("UdpSocketWin32::Read error 0x%X", (DWORD)p->overlapped.Internal);
+ network_->packet_pool().FreePacketToPool(p);
} else {
// Remember all the finished packets and queue them up to the next thread once we've
// collected them all.
p->size = (int)p->overlapped.InternalHigh;
+ p->protocol = kPacketProtocolUdp;
p->queue_cb = packet_handler_->udp_queue();
p->queue_next = NULL;
*finished_reads_end_ = p;
@@ -703,9 +675,9 @@ void UdpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) {
}
} else {
num_writes_--;
- network_->FreePacketToPool(p);
if ((DWORD)p->overlapped.Internal != 0)
RERROR("UdpSocketWin32::Write error 0x%X", (DWORD)p->overlapped.Internal);
+ network_->packet_pool().FreePacketToPool(p);
}
}
@@ -716,28 +688,30 @@ void UdpSocketWin32::OnQueuedItemDelete(QueuedItem *qi) {
} else {
num_writes_--;
}
- network_->FreePacketToPool(p);
+ network_->packet_pool().FreePacketToPool(p);
}
void UdpSocketWin32::DoIO() {
DoMoreWrites();
+ ProcessPackets();
DoMoreReads();
}
////////////////////////////////////////////////////////////////////////////////////////////////////////
-NetworkWin32::NetworkWin32() : udp_socket_(this) {
+NetworkWin32::NetworkWin32() : udp_socket_(this), tcp_socket_queue_(this) {
exit_thread_ = false;
thread_ = NULL;
- freed_packets_ = NULL;
- freed_packets_end_ = &freed_packets_;
- freed_packets_count_ = 0;
completion_port_handle_ = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, NULL, 0);
+ tcp_socket_ = NULL;
}
NetworkWin32::~NetworkWin32() {
assert(thread_ == NULL);
+
+ for (TcpSocketWin32 *socket = tcp_socket_; socket; )
+ delete exch(socket, socket->next_);
+
CloseHandle(completion_port_handle_);
- FreePacketList(freed_packets_);
}
DWORD WINAPI NetworkWin32::NetworkThread(void *x) {
@@ -750,20 +724,13 @@ void NetworkWin32::ThreadMain() {
OVERLAPPED_ENTRY entries[kUdpGetQueuedCompletionStatusSize];
while (!exit_thread_) {
- // Run IO on all sockets queued for IO
+ // TODO: In the future, don't process every socket here, only
+ // those sockets that requested it.
udp_socket_.DoIO();
+ for (TcpSocketWin32 *tcp = tcp_socket_; tcp;)
+ exch(tcp, tcp->next_)->DoIO();
- // Free some packets
- assert(freed_packets_count_ >= 0);
- if (freed_packets_count_ >= 32) {
- FreePackets(freed_packets_, freed_packets_end_, freed_packets_count_);
- freed_packets_count_ = 0;
- freed_packets_ = NULL;
- freed_packets_end_ = &freed_packets_;
- } else if (freed_packets_ == NULL) {
- assert(freed_packets_count_ == 0);
- freed_packets_end_ = &freed_packets_;
- }
+ packet_pool_.FreeSomePackets();
ULONG num_entries = 0;
if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, kUdpGetQueuedCompletionStatusSize, &num_entries, INFINITE, FALSE)) {
@@ -779,20 +746,33 @@ void NetworkWin32::ThreadMain() {
}
udp_socket_.CancelAllIO();
+ for (TcpSocketWin32 *tcp = tcp_socket_; tcp; tcp = tcp->next_)
+ tcp->CancelAllIO();
- while (udp_socket_.HasOutstandingIO()) {
+ while (HasOutstandingIO()) {
ULONG num_entries = 0;
- if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, 1, &num_entries, INFINITE, FALSE)) {
+ if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, kUdpGetQueuedCompletionStatusSize, &num_entries, INFINITE, FALSE)) {
RINFO("GetQueuedCompletionStatusEx failed.");
break;
}
- if (entries[0].lpOverlapped) {
- QueuedItem *w = (QueuedItem*)((byte*)entries[0].lpOverlapped - offsetof(QueuedItem, overlapped));
- w->queue_cb->OnQueuedItemDelete(w);
+ for (ULONG i = 0; i < num_entries; i++) {
+ if (entries[i].lpOverlapped) {
+ QueuedItem *w = (QueuedItem*)((byte*)entries[i].lpOverlapped - offsetof(QueuedItem, overlapped));
+ w->queue_cb->OnQueuedItemDelete(w);
+ }
}
}
}
+bool NetworkWin32::HasOutstandingIO() {
+ if (udp_socket_.HasOutstandingIO())
+ return true;
+ for (TcpSocketWin32 *tcp = tcp_socket_; tcp; tcp = tcp->next_)
+ if (tcp->HasOutstandingIO())
+ return true;
+ return false;
+}
+
void NetworkWin32::StartThread() {
assert(completion_port_handle_);
@@ -804,11 +784,36 @@ void NetworkWin32::StartThread() {
void NetworkWin32::StopThread() {
if (thread_ != NULL) {
exit_thread_ = true;
+ g_fail_malloc_flag = true;
PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL);
WaitForSingleObject(thread_, INFINITE);
CloseHandle(thread_);
thread_ = NULL;
exit_thread_ = false;
+ g_fail_malloc_flag = false;
+ }
+}
+
+void NetworkWin32::WakeUp() {
+ PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL);
+}
+
+void NetworkWin32::PostQueuedItem(QueuedItem *item) {
+ PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, &item->overlapped);
+}
+
+bool NetworkWin32::Configure(int listen_port, int listen_port_tcp) {
+ if (listen_port_tcp)
+ RERROR("ListenPortTCP not supported in this version");
+ return udp_socket_.Configure(listen_port);
+}
+
+// Called from tunsafe thread
+void NetworkWin32::WriteUdpPacket(Packet *packet) {
+ if (packet->protocol & kPacketProtocolUdp) {
+ udp_socket_.WriteUdpPacket(packet);
+ } else {
+ tcp_socket_queue_.WritePacket(packet);
}
}
@@ -874,6 +879,8 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) {
mutex_.Acquire();
while (!(exit_code = exit_code_)) {
+ FreeAllPackets();
+
if (timer_interrupt_) {
timer_interrupt_ = false;
need_notify_ = 0;
@@ -912,7 +919,6 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) {
need_notify_ = 0;
mutex_.Release();
- tpq_last_qsize = packets_in_queue;
if (packets_in_queue >= 1024)
overload = 2;
queue_context.overload = (overload != 0);
@@ -986,8 +992,8 @@ void PacketProcessor::PostPackets(Packet *first, Packet **end, int count) {
}
void PacketProcessor::ForcePost(QueuedItem *item) {
- mutex_.Acquire();
item->queue_next = NULL;
+ mutex_.Acquire();
packets_in_queue_ += 1;
*last_ptr_ = item;
last_ptr_ = &item->queue_next;
@@ -1648,7 +1654,6 @@ bool TunWin32Adapter::RunPrePostCommand(const std::vector &vec) {
return success;
}
-
//////////////////////////////////////////////////////////////////////////////
TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) {
@@ -1704,6 +1709,20 @@ enum {
kTunGetQueuedCompletionStatusSize = kConcurrentWriteTap + kConcurrentReadTap + 1
};
+static inline bool AllocPacketFrom(Packet **list, int *counter, bool *exit_flag, Packet **res) {
+ Packet *p;
+ if (p = *list) {
+ *list = Packet_NEXT(p);
+ (*counter)--;
+ p->data = p->data_buf;
+ } else {
+ if (!(p = AllocPacket()))
+ return false;
+ }
+ *res = p;
+ return true;
+}
+
void TunWin32Iocp::ThreadMain() {
OVERLAPPED_ENTRY entries[kTunGetQueuedCompletionStatusSize];
Packet *pending_writes = NULL;
@@ -1738,7 +1757,6 @@ void TunWin32Iocp::ThreadMain() {
num_reads++;
}
}
- g_tun_reads = num_reads;
assert(freed_packets_count >= 0);
if (freed_packets_count >= 32) {
@@ -1820,7 +1838,6 @@ void TunWin32Iocp::ThreadMain() {
num_writes++;
}
}
- g_tun_writes = num_writes;
}
EXIT:
@@ -1896,217 +1913,6 @@ void TunWin32Iocp::WriteTunPacket(Packet *packet) {
//////////////////////////////////////////////////////////////////////////////
-TunWin32Overlapped::TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) {
- wqueue_end_ = &wqueue_;
- wqueue_ = NULL;
-
- thread_ = NULL;
-
- read_event_ = CreateEvent(NULL, TRUE, FALSE, NULL);
- write_event_ = CreateEvent(NULL, TRUE, FALSE, NULL);
- wake_event_ = CreateEvent(NULL, FALSE, FALSE, NULL);
-
- packet_handler_ = NULL;
- exit_thread_ = false;
-}
-
-TunWin32Overlapped::~TunWin32Overlapped() {
- CloseTun();
- CloseHandle(read_event_);
- CloseHandle(write_event_);
- CloseHandle(wake_event_);
-}
-
-bool TunWin32Overlapped::Configure(const TunConfig &&config, TunConfigOut *out) {
- CloseTun();
- if (adapter_.OpenAdapter(backend_, FILE_FLAG_OVERLAPPED) &&
- adapter_.ConfigureAdapter(std::move(config), out))
- return true;
- CloseTun();
- return false;
-}
-
-void TunWin32Overlapped::CloseTun() {
- assert(thread_ == NULL);
- adapter_.CloseAdapter(false);
- FreePacketList(wqueue_);
- wqueue_ = NULL;
- wqueue_end_ = &wqueue_;
-}
-
-void TunWin32Overlapped::ThreadMain() {
- Packet *pending_writes = NULL;
- DWORD err;
- Packet *read_packet = NULL, *write_packet = NULL;
-
- HANDLE h[3];
- while (!exit_thread_) {
- if (read_packet == NULL) {
- Packet *p = AllocPacket();
- ClearOverlapped(&p->overlapped);
- p->overlapped.hEvent = read_event_;
- if (!ReadFile(adapter_.handle(), p->data, kPacketCapacity, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) {
- FreePacket(p);
- RERROR("TunWin32: ReadFile failed 0x%X", err);
- } else {
- read_packet = p;
- }
- }
-
- int n = 0;
- if (write_packet)
- h[n++] = write_event_;
- if (read_packet != NULL)
- h[n++] = read_event_;
- h[n++] = wake_event_;
-
- DWORD res = WaitForMultipleObjects(n, h, FALSE, INFINITE);
-
- if (res >= WAIT_OBJECT_0 && res <= WAIT_OBJECT_0 + 2) {
- HANDLE hx = h[res - WAIT_OBJECT_0];
- if (hx == read_event_) {
- read_packet->size = (int)read_packet->overlapped.InternalHigh;
- Packet_NEXT(read_packet) = NULL;
- packet_handler_->PostPackets(read_packet, &Packet_NEXT(read_packet), 1);
- read_packet = NULL;
- } else if (hx == write_event_) {
- FreePacket(write_packet);
- write_packet = NULL;
- }
- } else {
- RERROR("Wait said %d", res);
- }
-
- if (write_packet == NULL) {
- if (!pending_writes) {
- mutex_.Acquire();
- pending_writes = wqueue_;
- wqueue_end_ = &wqueue_;
- wqueue_ = NULL;
- mutex_.Release();
- }
- if (pending_writes) {
- // Then issue writes
- Packet *p = pending_writes;
- pending_writes = Packet_NEXT(p);
- memset(&p->overlapped, 0, sizeof(p->overlapped));
- p->overlapped.hEvent = write_event_;
- if (!WriteFile(adapter_.handle(), p->data, p->size, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) {
- RERROR("TunWin32: WriteFile failed 0x%X", err);
- FreePacket(p);
- } else {
- write_packet = p;
- }
- }
- }
- }
-
- // TODO: Free memory
- CancelIo(adapter_.handle());
- FreePacketList(pending_writes);
-}
-
-DWORD WINAPI TunWin32Overlapped::TunThread(void *x) {
- TunWin32Overlapped *xx = (TunWin32Overlapped *)x;
- xx->ThreadMain();
- return 0;
-}
-
-void TunWin32Overlapped::StartThread() {
- DWORD thread_id;
- thread_ = CreateThread(NULL, 0, &TunThread, this, 0, &thread_id);
- SetThreadPriority(thread_, ABOVE_NORMAL_PRIORITY_CLASS);
-}
-
-void TunWin32Overlapped::StopThread() {
- exit_thread_ = true;
- SetEvent(wake_event_);
- WaitForSingleObject(thread_, INFINITE);
- CloseHandle(thread_);
- thread_ = NULL;
-}
-
-void TunWin32Overlapped::WriteTunPacket(Packet *packet) {
- Packet_NEXT(packet) = NULL;
- mutex_.Acquire();
- Packet *was_empty = wqueue_;
- *wqueue_end_ = packet;
- wqueue_end_ = &Packet_NEXT(packet);
- mutex_.Release();
- if (was_empty == NULL)
- SetEvent(wake_event_);
-}
-
-void TunsafeBackendWin32::SetPublicKey(const uint8 key[32]) {
- memcpy(public_key_, key, 32);
- delegate_->OnStateChanged();
-}
-
-DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) {
- TunsafeBackendWin32 *backend = (TunsafeBackendWin32*)bk;
- int stop_mode;
- int fast_retry_ctr = 0;
-
- for(;;) {
- TunWin32Iocp tun(&backend->dns_blocker_, backend);
- NetworkWin32 net;
- WireguardProcessor wg_proc(&net.udp(), &tun, backend);
-
- qs.udp_qsize1 = qs.udp_qsize2 = 0;
-
- net.udp().SetPacketHandler(&backend->packet_processor_);
- tun.SetPacketHandler(&backend->packet_processor_);
-
- if (backend->config_file_[0] &&
- !ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->dns_resolver_))
- goto getout_fail;
-
- if (!wg_proc.Start())
- goto getout_fail;
-
- backend->SetPublicKey(wg_proc.dev().public_key());
-
- backend->wg_processor_ = &wg_proc;
-
- net.StartThread();
- tun.StartThread();
- stop_mode = backend->packet_processor_.Run(&wg_proc, backend);
- net.StopThread();
- tun.StopThread();
-
- backend->wg_processor_ = NULL;
-
- // Keep DNS alive
- if (stop_mode != MODE_EXIT)
- tun.adapter().DisassociateDnsBlocker();
- else
- backend->dns_resolver_.ClearCache();
-
- FreeAllPackets();
-
- if (stop_mode != MODE_TUN_FAILED)
- return 0;
-
- uint32 last_fail = GetTickCount();
- fast_retry_ctr = (last_fail - backend->last_tun_adapter_failed_ < 5000) ? fast_retry_ctr + 1 : 0;
- backend->last_tun_adapter_failed_ = last_fail;
-
- backend->SetStatus((fast_retry_ctr >= 3) ? TunsafeBackend::kErrorTunPermanent : TunsafeBackend::kStatusTunRetrying);
-
- if (backend->status_ == TunsafeBackend::kErrorTunPermanent) {
- RERROR("Too many automatic restarts...");
- goto getout_fail_noseterr;
- }
- Sleep(1000);
- }
-getout_fail:
- backend->status_ = TunsafeBackend::kErrorInitialize;
- backend->delegate_->OnStatusCode(TunsafeBackend::kErrorInitialize);
-getout_fail_noseterr:
- backend->dns_blocker_.RestoreDns();
- return 0;
-}
-
TunsafeBackend::TunsafeBackend() {
is_started_ = false;
is_remote_ = false;
@@ -2152,6 +1958,75 @@ TunsafeBackendWin32::~TunsafeBackendWin32() {
TunAdaptersInUse::GetInstance()->Release(this);
}
+void TunsafeBackendWin32::SetPublicKey(const uint8 key[32]) {
+ memcpy(public_key_, key, 32);
+ delegate_->OnStateChanged();
+}
+
+DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) {
+ TunsafeBackendWin32 *backend = (TunsafeBackendWin32*)bk;
+ int stop_mode;
+ int fast_retry_ctr = 0;
+
+ for (;;) {
+ TunWin32Iocp tun(&backend->dns_blocker_, backend);
+ NetworkWin32 net;
+ WireguardProcessor wg_proc(&net, &tun, backend);
+
+ net.udp().SetPacketHandler(&backend->packet_processor_);
+ net.tcp_socket_queue().SetPacketHandler(&backend->packet_processor_);
+
+ tun.SetPacketHandler(&backend->packet_processor_);
+
+ if (backend->config_file_[0] &&
+ !ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->dns_resolver_))
+ goto getout_fail;
+
+ if (!wg_proc.Start())
+ goto getout_fail;
+
+ backend->SetPublicKey(wg_proc.dev().public_key());
+
+ backend->wg_processor_ = &wg_proc;
+
+ net.StartThread();
+ tun.StartThread();
+ stop_mode = backend->packet_processor_.Run(&wg_proc, backend);
+ net.StopThread();
+ tun.StopThread();
+
+ backend->wg_processor_ = NULL;
+
+ // Keep DNS alive
+ if (stop_mode != MODE_EXIT)
+ tun.adapter().DisassociateDnsBlocker();
+ else
+ backend->dns_resolver_.ClearCache();
+
+ FreeAllPackets();
+
+ if (stop_mode != MODE_TUN_FAILED)
+ return 0;
+
+ uint32 last_fail = GetTickCount();
+ fast_retry_ctr = (last_fail - backend->last_tun_adapter_failed_ < 5000) ? fast_retry_ctr + 1 : 0;
+ backend->last_tun_adapter_failed_ = last_fail;
+
+ backend->SetStatus((fast_retry_ctr >= 3) ? TunsafeBackend::kErrorTunPermanent : TunsafeBackend::kStatusTunRetrying);
+
+ if (backend->status_ == TunsafeBackend::kErrorTunPermanent) {
+ RERROR("Too many automatic restarts...");
+ goto getout_fail_noseterr;
+ }
+ Sleep(1000);
+ }
+getout_fail:
+ backend->status_ = TunsafeBackend::kErrorInitialize;
+ backend->delegate_->OnStatusCode(TunsafeBackend::kErrorInitialize);
+getout_fail_noseterr:
+ backend->dns_blocker_.RestoreDns();
+ return 0;
+}
void TunsafeBackendWin32::SetStatus(StatusCode status) {
status_ = status;
diff --git a/network_win32.h b/network_win32.h
index 6b91ed0..5d79849 100644
--- a/network_win32.h
+++ b/network_win32.h
@@ -2,15 +2,14 @@
// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved.
#pragma once
-#include "stdafx.h"
-#include "tunsafe_types.h"
#include "netapi.h"
#include "network_win32_api.h"
#include "network_win32_dnsblock.h"
#include "wireguard_config.h"
#include "tunsafe_threading.h"
#include "tunsafe_dnsresolve.h"
-#include
+#include "network_common.h"
+#include "network_win32_tcp.h"
enum {
ADAPTER_GUID_SIZE = 40,
@@ -18,6 +17,7 @@ enum {
class WireguardProcessor;
class TunsafeBackendWin32;
+class DnsBlocker;
struct PacketProcessorTunCb : QueuedItemCallback {
virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override;
@@ -73,16 +73,15 @@ class PacketAllocPool;
// Encapsulates a UDP socket pair (ipv4 / ipv6), optionally listening for incoming packets
// on a specific port.
-class UdpSocketWin32 : public UdpInterface, QueuedItemCallback {
+class UdpSocketWin32 : public QueuedItemCallback {
public:
explicit UdpSocketWin32(NetworkWin32 *network_win32);
~UdpSocketWin32();
void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
- // -- from UdpInterface
- virtual bool Configure(int listen_on_port) override;
- virtual void WriteUdpPacket(Packet *packet) override;
+ bool Configure(int listen_on_port);
+ inline void WriteUdpPacket(Packet *packet);
void DoIO();
void CancelAllIO();
@@ -94,9 +93,9 @@ public:
};
private:
-
void DoMoreReads();
void DoMoreWrites();
+ void ProcessPackets();
// From OverlappedCallbacks
virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override;
@@ -125,11 +124,16 @@ private:
Packet *finished_reads_, **finished_reads_end_;
int finished_reads_count_;
+
+ __declspec(align(64)) uint32 qsize1_;
+ __declspec(align(64)) uint32 qsize2_;
};
// Holds the thread for network communications
-class NetworkWin32 {
+class NetworkWin32 : public UdpInterface {
friend class UdpSocketWin32;
+ friend class TcpSocketWin32;
+ friend class TcpSocketQueue;
public:
explicit NetworkWin32();
~NetworkWin32();
@@ -138,13 +142,20 @@ public:
void StopThread();
UdpSocketWin32 &udp() { return udp_socket_; }
+ SimplePacketPool &packet_pool() { return packet_pool_; }
+ TcpSocketQueue &tcp_socket_queue() { return tcp_socket_queue_; }
+ void WakeUp();
+ void PostQueuedItem(QueuedItem *item);
+
+ // -- from UdpInterface
+ virtual bool Configure(int listen_port_udp, int listen_port_tcp) override;
+ virtual void WriteUdpPacket(Packet *packet) override;
private:
void ThreadMain();
static DWORD WINAPI NetworkThread(void *x);
- void FreePacketToPool(Packet *p);
- bool AllocPacketFromPool(Packet **p);
+ bool HasOutstandingIO();
// The network thread handle
HANDLE thread_;
@@ -155,18 +166,17 @@ private:
// The handle to the completion port
HANDLE completion_port_handle_;
- Packet *freed_packets_, **freed_packets_end_;
- int freed_packets_count_;
-
// Right now there's always one udp socket only
UdpSocketWin32 udp_socket_;
+
+ // A linked list of all tcp sockets
+ TcpSocketWin32 *tcp_socket_;
+
+ SimplePacketPool packet_pool_;
+
+ TcpSocketQueue tcp_socket_queue_;
};
-
-
-
-class DnsBlocker;
-
class TunWin32Adapter {
public:
TunWin32Adapter(DnsBlocker *dns_blocker, const char guid[ADAPTER_GUID_SIZE]);
@@ -243,42 +253,6 @@ private:
TunWin32Adapter adapter_;
};
-// Implementation of TUN interface handling using Overlapped IO
-class TunWin32Overlapped : public TunInterface {
-public:
- explicit TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend);
- ~TunWin32Overlapped();
-
- void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
-
- void StartThread();
- void StopThread();
-
- // -- from TunInterface
- virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;
- virtual void WriteTunPacket(Packet *packet) override;
-
-private:
- void CloseTun();
- void ThreadMain();
- static DWORD WINAPI TunThread(void *x);
-
- PacketProcessor *packet_handler_;
- HANDLE thread_;
-
- Mutex mutex_;
-
- HANDLE read_event_, write_event_, wake_event_;
-
- bool exit_thread_;
-
- Packet *wqueue_, **wqueue_end_;
-
- TunWin32Adapter adapter_;
-
- TunsafeBackendWin32 *backend_;
-};
-
class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate {
friend class PacketProcessor;
friend class TunWin32Iocp;
@@ -427,3 +401,8 @@ private:
uint8 num_inuse_;
Entry entry_[kMaxAdaptersInUse];
};
+
+static inline void ClearOverlapped(OVERLAPPED *o) {
+ memset(o, 0, sizeof(*o));
+}
+
diff --git a/network_win32_api.h b/network_win32_api.h
index aaff158..fb4b964 100644
--- a/network_win32_api.h
+++ b/network_win32_api.h
@@ -125,6 +125,3 @@ protected:
TunsafeBackend *CreateNativeTunsafeBackend(TunsafeBackend::Delegate *delegate);
TunsafeBackend::Delegate *CreateTunsafeBackendDelegateThreaded(TunsafeBackend::Delegate *delegate, const std::function &callback);
-
-extern int tpq_last_qsize;
-extern int g_tun_reads, g_tun_writes;
diff --git a/network_win32_tcp.cpp b/network_win32_tcp.cpp
new file mode 100644
index 0000000..060c88f
--- /dev/null
+++ b/network_win32_tcp.cpp
@@ -0,0 +1,344 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#include "stdafx.h"
#include "network_win32_tcp.h"
#include "network_win32.h"
#include <assert.h>
#include <mswsock.h>
#include "util.h"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+TcpSocketWin32::TcpSocketWin32(NetworkWin32 *network)
+ : tcp_packet_handler_(&network->packet_pool()) {
+ network_ = network;
+ reads_active_ = 0;
+ writes_active_ = 0;
+ handshake_attempts = 0;
+ state_ = STATE_NONE;
+ wqueue_ = NULL;
+ wqueue_end_ = &wqueue_;
+ socket_ = INVALID_SOCKET;
+ next_ = NULL;
+ packet_processor_ = NULL;
+ // insert in network's linked list
+ next_ = network->tcp_socket_;
+ network->tcp_socket_ = this;
+}
+
+TcpSocketWin32::~TcpSocketWin32() {
+ // Unlink myself from the network's linked list.
+ TcpSocketWin32 **p = &network_->tcp_socket_;
+ while (*p != this) p = &(*p)->next_;
+ *p = next_;
+
+ FreePacketList(wqueue_);
+ if (socket_ != INVALID_SOCKET)
+ closesocket(socket_);
+}
+
// Abort all pending overlapped IO and flag the connection as dead.
// The object itself is deleted later, from DoIO(), once every
// outstanding operation has drained (see HasOutstandingIO()).
void TcpSocketWin32::CloseSocket() {
  if (socket_ != INVALID_SOCKET)
    CancelIo((HANDLE)socket_);
  state_ = STATE_ERROR;
  endpoint_protocol_ = 0;
}
+
// Append |packet| to the tail of this socket's write queue. The queue
// is drained by DoMoreWrites() on the network thread. No lock is taken
// here, so this must only be called from the network thread (e.g. via
// TcpSocketQueue::TransmitOnePacket).
void TcpSocketWin32::WritePacket(Packet *packet) {
  packet->queue_next = NULL;
  *wqueue_end_ = packet;
  wqueue_end_ = &Packet_NEXT(packet);
}
+
// Cancel all outstanding overlapped IO on the socket without changing
// state. Called from the IO completion thread, typically at shutdown.
void TcpSocketWin32::CancelAllIO() {
  if (socket_ != INVALID_SOCKET)
    CancelIo((HANDLE)socket_);
}
+
+static const GUID WsaConnectExGUID = WSAID_CONNECTEX;
+
+void TcpSocketWin32::DoConnect() {
+ LPFN_CONNECTEX ConnectEx;
+ assert(socket_ == INVALID_SOCKET);
+
+ socket_ = WSASocket(endpoint_.sin.sin_family, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
+ if (socket_ == INVALID_SOCKET) {
+ RERROR("socket() failed");
+ CloseSocket();
+ return;
+ }
+
+ if (!CreateIoCompletionPort((HANDLE)socket_, network_->completion_port_handle_, 0, 0)) {
+ RERROR("TcpSocketWin32::DoConnect CreateIoCompletionPort failed");
+ CloseSocket();
+ return;
+ }
+
+ int nodelay = 1;
+ setsockopt(socket_, IPPROTO_TCP, TCP_NODELAY, (char*)&nodelay, 1);
+
+ DWORD dwBytes = sizeof(ConnectEx);
+ DWORD rc = WSAIoctl(socket_, SIO_GET_EXTENSION_FUNCTION_POINTER, (uint8*)&WsaConnectExGUID, sizeof(WsaConnectExGUID), &ConnectEx, sizeof(ConnectEx), &dwBytes, NULL, NULL);
+ assert(rc == 0);
+
+ // ConnectEx requires the socket to be bound
+ sockaddr_in sin = {0};
+ sin.sin_family = AF_INET;
+ sin.sin_addr.s_addr = INADDR_ANY;
+ sin.sin_port = 0;
+ if (bind(socket_, (sockaddr*)&sin, sizeof(sin))) {
+ RERROR("TcpSocketWin32::DoConnect bind failed: %d", WSAGetLastError());
+ CloseSocket();
+ return;
+ }
+
+ char buf[kSizeOfAddress];
+ RINFO("Connecting to tcp://%s...", PrintIpAddr(endpoint_, buf));
+
+ state_ = STATE_CONNECTING;
+ ClearOverlapped(&connect_overlapped_.overlapped);
+ connect_overlapped_.queue_cb = this;
+ if (!ConnectEx(socket_, (const sockaddr*)&endpoint_.sin, sizeof(endpoint_.sin), NULL, 0, NULL, &connect_overlapped_.overlapped)) {
+ int err = WSAGetLastError();
+ if (err != ERROR_IO_PENDING) {
+ RERROR("ConnectEx failed: %d", err);
+ CloseSocket();
+ return;
+ }
+ }
+ reads_active_ = 1;
+}
+
+void TcpSocketWin32::DoMoreReads() {
+ assert(state_ != STATE_ERROR);
+ if (reads_active_ == 0) {
+ // Initiate a new read, we always read into 4 buffers.
+ Packet *p = network_->packet_pool().AllocPacketFromPool();
+ if (!p)
+ return;
+
+ ClearOverlapped(&p->overlapped);
+ p->userdata = 0;
+ p->queue_cb = this;
+ DWORD flags = 0;
+ WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data};
+ if (WSARecv(socket_, &wsabuf, 1, NULL, &flags, &p->overlapped, NULL) != 0) {
+ DWORD err = WSAGetLastError();
+ if (err != ERROR_IO_PENDING) {
+ RERROR("TcpSocketWin32:WSARecv failed 0x%X", err);
+ FreePacket(p);
+ return;
+ }
+ }
+ reads_active_ = 1;
+ }
+}
+
+void TcpSocketWin32::DoMoreWrites() {
+ assert(state_ != STATE_ERROR);
+ if (writes_active_ == 0) {
+ WSABUF wsabuf[kMaxWsaBuf];
+ uint32 num_wsabuf = 0;
+
+ Packet *p = wqueue_;
+ if (p == NULL)
+ return;
+
+ do {
+ tcp_packet_handler_.AddHeaderToOutgoingPacket(p);
+ wsabuf[num_wsabuf].buf = (char*)p->data;
+ wsabuf[num_wsabuf].len = (ULONG)p->size;
+ packets_in_write_io_[num_wsabuf] = p;
+ p = Packet_NEXT(p);
+ } while (++num_wsabuf < kMaxWsaBuf && p != NULL);
+ if (!(wqueue_ = p))
+ wqueue_end_ = &wqueue_;
+ num_wsabuf_ = (uint8)num_wsabuf;
+
+ p = packets_in_write_io_[0];
+ ClearOverlapped(&p->overlapped);
+ p->userdata = 1;
+ p->queue_cb = this;
+
+ if (WSASend(socket_, wsabuf, num_wsabuf, NULL, 0, &p->overlapped, NULL) != 0) {
+ DWORD err = WSAGetLastError();
+ if (err != ERROR_IO_PENDING) {
+ RERROR("TcpSocketWin32: WSASend failed 0x%X", err);
+ FreePacket(p);
+ CloseSocket();
+ return;
+ }
+ }
+ writes_active_ = 1;
+ }
+}
+
// Per-socket driver, called from the network thread's event loop.
// Pumps reads/writes while connected, kicks off the connect when
// requested, and deletes the socket once it has errored out and all
// overlapped IO has drained.
void TcpSocketWin32::DoIO() {
  if (state_ == STATE_CONNECTED) {
    DoMoreReads();
    // Forward every complete wireguard packet the framing layer has
    // reassembled to the packet processor, on its udp queue.
    while (Packet *p = tcp_packet_handler_.GetNextWireguardPacket()) {
      p->protocol = endpoint_protocol_;
      p->addr = endpoint_;

      p->queue_cb = packet_processor_->udp_queue();
      packet_processor_->ForcePost(p);
    }
    if (tcp_packet_handler_.error()) {
      // Framing error: tear down, then re-enter to run the STATE_ERROR
      // branch below (which may delete |this|).
      CloseSocket();
      DoIO();
      return;
    }
    DoMoreWrites();
  } else if (state_ == STATE_WANT_CONNECT) {
    DoConnect();
  } else if (state_ == STATE_ERROR && !HasOutstandingIO()) {
    delete this;
  }
}
+
+bool TcpSocketWin32::HasOutstandingIO() {
+ return writes_active_ + reads_active_ != 0;
+}
+
+void TcpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) {
+ if (qi == &connect_overlapped_) {
+ assert(state_ == STATE_CONNECTING);
+ reads_active_ = 0;
+ if ((DWORD)qi->overlapped.Internal != 0) {
+ if (state_ != STATE_ERROR) {
+ RERROR("TcpSocketWin32::Connect error 0x%X", (DWORD)qi->overlapped.Internal);
+ CloseSocket();
+ }
+ } else {
+ state_ = STATE_CONNECTED;
+ }
+ return;
+ }
+ Packet *p = static_cast(qi);
+ if (p->userdata == 0) {
+ // Read operation complete
+ if ((DWORD)p->overlapped.Internal != 0) {
+ if (state_ != STATE_ERROR) {
+ RERROR("TcpSocketWin32::Read error 0x%X", (DWORD)p->overlapped.Internal);
+ CloseSocket();
+ }
+ network_->packet_pool().FreePacketToPool(p);
+ // What to do?
+ } else if ((int)p->overlapped.InternalHigh == 0) {
+ // Socket closed successfully
+ CloseSocket();
+ network_->packet_pool().FreePacketToPool(p);
+ } else {
+ // Queue it up to rqueue
+ p->size = (int)p->overlapped.InternalHigh;
+ tcp_packet_handler_.QueueIncomingPacket(p);
+ }
+ reads_active_--;
+ } else {
+ assert(writes_active_);
+ assert(packets_in_write_io_[0] == p);
+
+ if ((DWORD)p->overlapped.Internal != 0) {
+ if (state_ != STATE_ERROR) {
+ RERROR("TcpSocketWin32::Write error 0x%X", (DWORD)p->overlapped.Internal);
+ CloseSocket();
+ }
+ }
+ // free all the packets involved in the write
+ for (size_t i = 0; i < num_wsabuf_; i++)
+ network_->packet_pool().FreePacketToPool(packets_in_write_io_[i]);
+ writes_active_--;
+ }
+}
+
+void TcpSocketWin32::OnQueuedItemDelete(QueuedItem *qi) {
+ if (qi == &connect_overlapped_) {
+ reads_active_ = 0;
+ return;
+ }
+ Packet *p = static_cast(qi);
+ if (p->userdata == 0) {
+ FreePacket(p);
+ reads_active_--;
+ } else {
+ for (size_t i = 0; i < num_wsabuf_; i++)
+ network_->packet_pool().FreePacketToPool(packets_in_write_io_[i]);
+ writes_active_--;
+ }
+}
+
+/////////////////////////////////////////////////////////////////////////
+
// Cross-thread mailbox: the wireguard thread appends packets under a
// mutex (WritePacket), then posts |queued_item_| so the network thread
// drains them in OnQueuedItemEvent().
TcpSocketQueue::TcpSocketQueue(NetworkWin32 *network) {
  network_ = network;
  wqueue_ = NULL;
  wqueue_end_ = &wqueue_;
  queued_item_.queue_cb = this;
  packet_handler_ = NULL;
}
+
// Free any packets that were queued but never handed to a socket.
TcpSocketQueue::~TcpSocketQueue() {
  FreePacketList(wqueue_);
}
+
// Route one outgoing packet to the TCP socket matching its endpoint,
// creating a new outgoing connection for handshake packets when no
// socket exists. Runs on the network thread.
void TcpSocketQueue::TransmitOnePacket(Packet *packet) {
  // Check if we have a tcp connection for the endpoint, otherwise create one.
  for (TcpSocketWin32 *tcp = network_->tcp_socket_; tcp; tcp = tcp->next_) {
    // After we send 3 handshakes on a tcp socket in a row, then close and reopen the socket because it seems defunct.
    if (CompareIpAddr(&tcp->endpoint_, &packet->addr) == 0 && tcp->endpoint_protocol_ == packet->protocol) {
      if (ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION) {
        if (tcp->handshake_attempts == 2) {
          RINFO("Making new Tcp socket due to too many handshake failures");
          // Mark the old socket dead and fall through to create a fresh one;
          // the dead socket is deleted once its IO drains.
          tcp->CloseSocket();
          break;
        }
        tcp->handshake_attempts++;
      } else {
        // Non-handshake traffic proves the link works. handshake_attempts
        // is a uint8, so -1 wraps to 255 and the next handshake increments
        // it back to 0 — i.e. the counter restarts.
        tcp->handshake_attempts = -1;
      }
      tcp->WritePacket(packet);
      return;
    }
  }

  // Drop tcp packet that's for an incoming connection, or packets that are
  // not a handshake.
  if ((packet->protocol & kPacketProtocolIncomingConnection) ||
      packet->size < 4 || ReadLE32(packet->data) != MESSAGE_HANDSHAKE_INITIATION) {
    FreePacket(packet);
    return;
  }

  // Initialize a new tcp socket and connect to the endpoint
  TcpSocketWin32 *tcp = new TcpSocketWin32(network_);
  tcp->state_ = TcpSocketWin32::STATE_WANT_CONNECT;
  tcp->endpoint_ = packet->addr;
  tcp->endpoint_protocol_ = kPacketProtocolTcp;
  tcp->SetPacketHandler(packet_handler_);
  tcp->WritePacket(packet);
}
+
+void TcpSocketQueue::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) {
+ wqueue_mutex_.Acquire();
+ Packet *packet = wqueue_;
+ wqueue_ = NULL;
+ wqueue_end_ = &wqueue_;
+ wqueue_mutex_.Release();
+ while (packet)
+ TransmitOnePacket(exch(packet, Packet_NEXT(packet)));
+}
+
void TcpSocketQueue::OnQueuedItemDelete(QueuedItem *ow) {
  // Intentionally empty: |queued_item_| is a member of this object and
  // any unrouted packets are freed in the destructor.
}
+
+void TcpSocketQueue::WritePacket(Packet *packet) {
+ packet->queue_next = NULL;
+ wqueue_mutex_.Acquire();
+ Packet *was_empty = wqueue_;
+ *wqueue_end_ = packet;
+ wqueue_end_ = &Packet_NEXT(packet);
+ wqueue_mutex_.Release();
+ if (was_empty == NULL)
+ network_->PostQueuedItem(&queued_item_);
+}
diff --git a/network_win32_tcp.h b/network_win32_tcp.h
new file mode 100644
index 0000000..0808f98
--- /dev/null
+++ b/network_win32_tcp.h
@@ -0,0 +1,124 @@
+// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
+
+#pragma once
+
+#include "netapi.h"
+#include "network_common.h"
+#include "tunsafe_threading.h"
+
+class NetworkWin32;
+class PacketProcessor;
+
// One outgoing (or accepted) TCP transport for wireguard-over-TCP on
// Windows, driven entirely by overlapped IO on the network thread.
class TcpSocketWin32 : public QueuedItemCallback {
  friend class NetworkWin32;
  friend class TcpSocketQueue;
public:
  explicit TcpSocketWin32(NetworkWin32 *network);
  ~TcpSocketWin32();

  void SetPacketHandler(PacketProcessor *packet_handler) { packet_processor_ = packet_handler; }

  // Write a packet to the TCP socket. This may be called only from the
  // wireguard thread. Will append to a buffer and schedule it to be written
  // from the network thread.
  void WritePacket(Packet *packet);

  // Call from IO completion thread to cancel all outstanding IO
  void CancelAllIO();

  // Call from IO completion thread to run more IO
  void DoIO();

  // Returns true if there's IO still left to run
  bool HasOutstandingIO();

private:
  void DoMoreReads();
  void DoMoreWrites();
  void DoConnect();

  // Cancels IO and moves the socket to STATE_ERROR; deletion happens
  // later in DoIO() once all overlapped operations drain.
  void CloseSocket();

  // From OverlappedCallbacks
  virtual void OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) override;
  virtual void OnQueuedItemDelete(QueuedItem *qi) override;

  // Network subsystem
  NetworkWin32 *network_;

  PacketProcessor *packet_processor_;

  enum {
    STATE_NONE = 0,
    STATE_ERROR = 1,
    STATE_CONNECTING = 2,
    STATE_CONNECTED = 3,
    STATE_WANT_CONNECT = 4,
  };

  // 0/1 counters of outstanding overlapped operations; a pending
  // ConnectEx is accounted for in reads_active_.
  uint8 reads_active_;
  uint8 writes_active_;
  uint8 state_;
  uint8 num_wsabuf_;

public:
  // Consecutive handshakes sent without a reply; wraps via 255 when
  // reset with -1. At 3 the socket is considered defunct and reopened.
  uint8 handshake_attempts;
private:

  // The handle to the socket
  SOCKET socket_;

  // Packets taken over by the network thread waiting to be written,
  // when these are written we'll start eating from wqueue_
  // NOTE(review): this member appears unused in network_win32_tcp.cpp —
  // confirm whether a friend class uses it or it can be removed.
  Packet *pending_writes_;

  // All packets queued for writing on the network thread.
  Packet *wqueue_, **wqueue_end_;

  // Linked list of all TcpSocketWin32 sockets, head in NetworkWin32.
  TcpSocketWin32 *next_;

  // Handles packet parsing (length framing of wireguard datagrams).
  TcpPacketHandler tcp_packet_handler_;

  // An overlapped instance used for the initial Connect() call.
  QueuedItem connect_overlapped_;

  IpAddr endpoint_;
  uint8 endpoint_protocol_;

  // Packets currently involved in the wsabuf writing
  enum { kMaxWsaBuf = 32 };
  Packet *packets_in_write_io_[kMaxWsaBuf];
};
+
// Thread-safe handoff point: the wireguard thread enqueues outgoing
// packets here; the network thread drains them and routes each one to
// the right TcpSocketWin32 (creating sockets on demand).
class TcpSocketQueue : public QueuedItemCallback {
public:
  explicit TcpSocketQueue(NetworkWin32 *network);
  ~TcpSocketQueue();

  void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }

  // Runs on the network thread when |queued_item_| is dispatched/deleted.
  virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override;
  virtual void OnQueuedItemDelete(QueuedItem *ow) override;

  // Callable from any thread; wakes the network thread when needed.
  void WritePacket(Packet *packet);

private:
  void TransmitOnePacket(Packet *packet);
  NetworkWin32 *network_;

  // All packets queued for writing on the network thread. Locked by |wqueue_mutex_|
  Packet *wqueue_, **wqueue_end_;

  PacketProcessor *packet_handler_;

  // Protects wqueue_
  Mutex wqueue_mutex_;

  // Used for queueing things on the network instance
  QueuedItem queued_item_;

};
+
diff --git a/tunsafe_amalgam.cpp b/tunsafe_amalgam.cpp
index 9d0c94e..9a28774 100644
--- a/tunsafe_amalgam.cpp
+++ b/tunsafe_amalgam.cpp
@@ -23,7 +23,7 @@
#if defined(WITH_NETWORK_BSD)
#include "network_bsd.cpp"
-#include "network_bsd_common.cpp"
+#include "tunsafe_bsd.cpp"
#include "ts.cpp"
#include "benchmark.cpp"
#endif
diff --git a/network_bsd_common.cpp b/tunsafe_bsd.cpp
similarity index 72%
rename from network_bsd_common.cpp
rename to tunsafe_bsd.cpp
index b2bf710..ded6515 100644
--- a/network_bsd_common.cpp
+++ b/tunsafe_bsd.cpp
@@ -1,6 +1,6 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
-#include "network_bsd_common.h"
+#include "tunsafe_bsd.h"
#include "tunsafe_endian.h"
#include "util.h"
@@ -43,17 +43,6 @@
#include
#endif
-void tunsafe_die(const char *msg) {
- fprintf(stderr, "%s\n", msg);
- exit(1);
-}
-
-void SetThreadName(const char *name) {
-#if defined(OS_LINUX)
- prctl(PR_SET_NAME, name, 0, 0, 0);
-#endif // defined(OS_LINUX)
-}
-
#if defined(OS_MACOSX) || defined(OS_FREEBSD)
struct MyRouteMsg {
struct rt_msghdr hdr;
@@ -346,21 +335,7 @@ int open_tun(char *devname, size_t devname_size) {
}
#endif
-int open_udp(int listen_on_port) {
- int udp_fd = socket(AF_INET, SOCK_DGRAM, 0);
- if (udp_fd < 0) return udp_fd;
- sockaddr_in sin = {0};
- sin.sin_family = AF_INET;
- sin.sin_port = htons(listen_on_port);
- if (bind(udp_fd, (struct sockaddr*)&sin, sizeof(sin)) != 0) {
- close(udp_fd);
- return -1;
- }
- return udp_fd;
-}
-
-TunsafeBackendBsd::TunsafeBackendBsd()
- : processor_(NULL) {
+TunsafeBackendBsd::TunsafeBackendBsd() {
devname_[0] = 0;
tun_interface_gone_ = false;
}
@@ -579,122 +554,33 @@ bool TunsafeBackendBsd::RunPrePostCommand(const std::vector &vec) {
return success;
}
-#if defined(OS_LINUX)
-UnixSocketDeletionWatcher::UnixSocketDeletionWatcher()
- : inotify_fd_(-1) {
- pipes_[0] = -1;
- pipes_[0] = -1;
-}
-
-UnixSocketDeletionWatcher::~UnixSocketDeletionWatcher() {
- close(inotify_fd_);
- close(pipes_[0]);
- close(pipes_[1]);
-}
-
-bool UnixSocketDeletionWatcher::Start(const char *path, bool *flag_to_set) {
- assert(inotify_fd_ == -1);
- path_ = path;
- flag_to_set_ = flag_to_set;
- pid_ = getpid();
- inotify_fd_ = inotify_init1(IN_CLOEXEC | IN_NONBLOCK);
- if (inotify_fd_ == -1) {
- perror("inotify_init1() failed");
- return false;
- }
- if (inotify_add_watch(inotify_fd_, "/var/run/wireguard", IN_DELETE | IN_DELETE_SELF) == -1) {
- perror("inotify_add_watch failed");
- return false;
- }
- if (pipe(pipes_) == -1) {
- perror("pipe() failed");
- return false;
- }
- return pthread_create(&thread_, NULL, &UnixSocketDeletionWatcher::RunThread, this) == 0;
-}
-
-void UnixSocketDeletionWatcher::Stop() {
- RINFO("Stopping..");
- void *retval;
- write(pipes_[1], "", 1);
- pthread_join(thread_, &retval);
-}
-
-void *UnixSocketDeletionWatcher::RunThread(void *arg) {
- UnixSocketDeletionWatcher *self = (UnixSocketDeletionWatcher*)arg;
- return self->RunThreadInner();
-}
-
-void *UnixSocketDeletionWatcher::RunThreadInner() {
- char buf[sizeof(struct inotify_event) + NAME_MAX + 1]
- __attribute__ ((aligned(__alignof__(struct inotify_event))));
- fd_set fdset;
- struct stat st;
- for(;;) {
- if (lstat(path_, &st) == -1 && errno == ENOENT) {
- RINFO("Unix socket %s deleted.", path_);
- *flag_to_set_ = true;
- kill(pid_, SIGALRM);
- break;
- }
- FD_ZERO(&fdset);
- FD_SET(inotify_fd_, &fdset);
- FD_SET(pipes_[0], &fdset);
- int n = select(std::max(inotify_fd_, pipes_[0]) + 1, &fdset, NULL, NULL, NULL);
- if (n == -1) {
- perror("select");
- break;
- }
- if (FD_ISSET(inotify_fd_, &fdset)) {
- ssize_t len = read(inotify_fd_, buf, sizeof(buf));
- if (len == -1) {
- perror("read");
- break;
- }
- }
- if (FD_ISSET(pipes_[0], &fdset))
- break;
- }
- return NULL;
-}
-
-#else // !defined(OS_LINUX)
-
-bool UnixSocketDeletionWatcher::Poll(const char *path) {
- struct stat st;
- return lstat(path, &st) == -1 && errno == ENOENT;
-}
-
-#endif // !defined(OS_LINUX)
-
-static TunsafeBackendBsd *g_tunsafe_backend_bsd;
-
-static void SigAlrm(int sig) {
- if (g_tunsafe_backend_bsd)
- g_tunsafe_backend_bsd->HandleSigAlrm();
-}
+static SignalCatcher *g_signal_catcher;
static bool did_ctrlc;
-void SigInt(int sig) {
+void SignalCatcher::SigAlrm(int sig) {
+ if (g_signal_catcher)
+ *g_signal_catcher->sigalarm_flag_ = true;
+}
+
+void SignalCatcher::SigInt(int sig) {
if (did_ctrlc)
exit(1);
did_ctrlc = true;
- write(1, "Ctrl-C detected. Exiting. Press again to force quit.\n", sizeof("Ctrl-C detected. Exiting. Press again to force quit.\n")-1);
-
+ write(1, "Ctrl-C detected. Exiting. Press again to force quit.\n", sizeof("Ctrl-C detected. Exiting. Press again to force quit.\n") - 1);
// todo: fix signal safety?
- if (g_tunsafe_backend_bsd)
- g_tunsafe_backend_bsd->HandleExit();
+ if (g_signal_catcher)
+ *g_signal_catcher->exit_flag_ = true;
}
-void TunsafeBackendBsd::RunLoop() {
- assert(!g_tunsafe_backend_bsd);
- assert(processor_);
+SignalCatcher::SignalCatcher(bool *exit_flag, bool *sigalarm_flag) {
+ assert(g_signal_catcher == NULL);
+ exit_flag_ = exit_flag;
+ sigalarm_flag_ = sigalarm_flag;
+ g_signal_catcher = this;
sigset_t mask;
- g_tunsafe_backend_bsd = this;
-
// We want an alarm signal every second.
{
struct sigaction act = {0};
@@ -713,7 +599,6 @@ void TunsafeBackendBsd::RunLoop() {
return;
}
}
-
#if defined(OS_LINUX) || defined(OS_FREEBSD)
sigemptyset(&mask);
sigaddset(&mask, SIGALRM);
@@ -737,7 +622,7 @@ void TunsafeBackendBsd::RunLoop() {
if (timer_create(CLOCK_MONOTONIC, &sev, &timer_id) < 0) {
RERROR("timer_create failed");
return;
- }
+ }
if (timer_settime(timer_id, 0, &tv, NULL) < 0) {
RERROR("timer_settime failed");
@@ -747,51 +632,209 @@ void TunsafeBackendBsd::RunLoop() {
#elif defined(OS_MACOSX)
ualarm(1000000, 1000000);
#endif
+}
- RunLoopInner();
-
- g_tunsafe_backend_bsd = NULL;
+SignalCatcher::~SignalCatcher() {
+ g_signal_catcher = NULL;
}
void InitCpuFeatures();
void Benchmark();
-
const char *print_ip(char buf[kSizeOfAddress], in_addr_t ip) {
snprintf(buf, kSizeOfAddress, "%d.%d.%d.%d", (ip >> 24) & 0xff, (ip >> 16) & 0xff, (ip >> 8) & 0xff, (ip >> 0) & 0xff);
return buf;
}
-class MyProcessorDelegate : public ProcessorDelegate {
+class TunsafeBackendBsdImpl : public TunsafeBackendBsd, public NetworkBsd::NetworkBsdDelegate, public ProcessorDelegate {
public:
- MyProcessorDelegate() {
- wg_processor_ = NULL;
- is_connected_ = false;
- }
+ TunsafeBackendBsdImpl();
+ virtual ~TunsafeBackendBsdImpl();
- virtual void OnConnected() override {
- if (!is_connected_) {
- const WgCidrAddr *ipv4_addr = NULL;
- for (const WgCidrAddr &x : wg_processor_->addr()) {
- if (x.size == 32) { ipv4_addr = &x; break; }
- }
- uint32 ipv4_ip = ipv4_addr ? ReadBE32(ipv4_addr->addr) : 0;
- char buf[kSizeOfAddress];
- RINFO("Connection established. IP %s", ipv4_ip ? print_ip(buf, ipv4_ip) : "(none)");
- is_connected_ = true;
- }
- }
- virtual void OnConnectionRetry(uint32 attempts) override {
- if (is_connected_ && attempts >= 3) {
- is_connected_ = false;
- RINFO("Reconnecting...");
- }
- }
+ void RunLoop();
+ virtual bool InitializeTun(char devname[16]) override;
+
+ // -- from TunInterface
+ virtual void WriteTunPacket(Packet *packet) override;
+
+ // -- from UdpInterface
+ virtual bool Configure(int listen_port_udp, int listen_port_tcp) override;
+ virtual void WriteUdpPacket(Packet *packet) override;
+
+ // -- from NetworkBsdDelegate
+ virtual void OnSecondLoop(uint64 now) override;
+ virtual void RunAllMainThreadScheduled() override;
+
+ // -- from ProcessorDelegate
+ virtual void OnConnected() override;
+ virtual void OnConnectionRetry(uint32 attempts) override;
+
+ WireguardProcessor *processor() { return &processor_; }
+
+private:
+ void WriteTcpPacket(Packet *packet);
+
+ // Close all TCP connections that are not pointed to by any of the peer endpoint.
+ void CloseOrphanTcpConnections();
- WireguardProcessor *wg_processor_;
bool is_connected_;
+ uint8 close_orphan_counter_;
+ WireguardProcessor processor_;
+ NetworkBsd network_;
+ TunSocketBsd tun_;
+ UdpSocketBsd udp_;
+ UnixDomainSocketListenerBsd unix_socket_listener_;
+ TcpSocketListenerBsd tcp_socket_listener_;
};
// Wire the processor, network loop and the tun/udp/unix/tcp sockets
// together; |this| serves as tun, udp, event and processor delegate.
TunsafeBackendBsdImpl::TunsafeBackendBsdImpl()
    : is_connected_(false),
      close_orphan_counter_(0),
      processor_(this, this, this),
      network_(this, 1000),
      tun_(&network_, &processor_),
      udp_(&network_, &processor_),
      unix_socket_listener_(&network_, &processor_),
      tcp_socket_listener_(&network_, &processor_) {
}

TunsafeBackendBsdImpl::~TunsafeBackendBsdImpl() {
}
+
+bool TunsafeBackendBsdImpl::InitializeTun(char devname[16]) {
+ int tun_fd = open_tun(devname, 16);
+ if (tun_fd < 0) { RERROR("Error opening tun device"); return false; }
+ if (!tun_.Initialize(tun_fd)) {
+ close(tun_fd);
+ return false;
+ }
+ unix_socket_listener_.Initialize(devname);
+ return true;
+}
+
// Forward an outgoing plaintext packet to the tun device's write queue.
void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) {
  tun_.WritePacket(packet);
}
+
// Bind the UDP socket and, if a TCP listen port is configured, the TCP
// listener as well. A listen_port_tcp of 0 disables TCP listening.
bool TunsafeBackendBsdImpl::Configure(int listen_port, int listen_port_tcp) {
  return udp_.Initialize(listen_port) &&
         (listen_port_tcp == 0 || tcp_socket_listener_.Initialize(listen_port_tcp));
}
+
+void TunsafeBackendBsdImpl::WriteTcpPacket(Packet *packet) {
+ // Check if we have a tcp connection for the endpoint, otherwise create one.
+ for (TcpSocketBsd *tcp = network_.tcp_sockets(); tcp; tcp = tcp->next()) {
+ // After we send 3 handshakes on a tcp socket in a row, then close and reopen the socket because it seems defunct.
+ if (CompareIpAddr(&tcp->endpoint(), &packet->addr) == 0 && tcp->endpoint_protocol() == packet->protocol) {
+ if (ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION) {
+ if (tcp->handshake_attempts == 2) {
+ RINFO("Making new Tcp socket due to too many handshake failures");
+ delete tcp;
+ break;
+ }
+ tcp->handshake_attempts++;
+ } else {
+ tcp->handshake_attempts = -1;
+ }
+ tcp->WritePacket(packet);
+ return;
+ }
+ }
+ // Drop tcp packet that's for an incoming connection, or packets that are
+ // not a handshake.
+ if ((packet->protocol & kPacketProtocolIncomingConnection) ||
+ ReadLE32(packet->data) != MESSAGE_HANDSHAKE_INITIATION) {
+ FreePacket(packet);
+ return;
+ }
+ // Initialize a new tcp socket and connect to the endpoint
+ TcpSocketBsd *tcp = new TcpSocketBsd(&network_, &processor_);
+ if (!tcp || !tcp->InitializeOutgoing(packet->addr)) {
+ delete tcp;
+ FreePacket(packet);
+ return;
+ }
+ tcp->WritePacket(packet);
+}
+
+void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) {
+ assert((packet->protocol & 0x7F) <= 2);
+ if (packet->protocol & kPacketProtocolTcp) {
+ WriteTcpPacket(packet);
+ } else {
+ udp_.WritePacket(packet);
+ }
+}
+
// Run the network event loop until exit is requested, with signal
// handlers (SIGALRM ticker, Ctrl-C) installed for the duration.
void TunsafeBackendBsdImpl::RunLoop() {
  if (!unix_socket_listener_.Start(network_.exit_flag()))
    return;

  SignalCatcher signal_catcher(network_.exit_flag(), network_.sigalarm_flag());
  network_.RunLoop(&signal_catcher.orig_signal_mask_);
  unix_socket_listener_.Stop();

  // Remember whether the loop stopped because the tun device vanished,
  // so the caller can decide whether to restart.
  tun_interface_gone_ = tun_.tun_interface_gone();
}
+
+void TunsafeBackendBsdImpl::OnSecondLoop(uint64 now) {
+ if (!(close_orphan_counter_++ & 0xF))
+ CloseOrphanTcpConnections();
+ processor_.SecondLoop();
+}
+
// Drain work other threads scheduled to run on the main thread.
void TunsafeBackendBsdImpl::RunAllMainThreadScheduled() {
  processor_.RunAllMainThreadScheduled();
}
+
+void TunsafeBackendBsdImpl::OnConnected() {
+ if (!is_connected_) {
+ const WgCidrAddr *ipv4_addr = NULL;
+ for (const WgCidrAddr &x : processor_.addr()) {
+ if (x.size == 32) { ipv4_addr = &x; break; }
+ }
+ uint32 ipv4_ip = ipv4_addr ? ReadBE32(ipv4_addr->addr) : 0;
+ char buf[kSizeOfAddress];
+ RINFO("Connection established. IP %s", ipv4_ip ? print_ip(buf, ipv4_ip) : "(none)");
+ is_connected_ = true;
+ }
+}
+
+void TunsafeBackendBsdImpl::OnConnectionRetry(uint32 attempts) {
+ if (is_connected_ && attempts >= 3) {
+ is_connected_ = false;
+ RINFO("Reconnecting...");
+ }
+}
+
// Delete incoming TCP connections that no peer endpoint references any
// more, giving brand-new connections one grace period via |age|.
void TunsafeBackendBsdImpl::CloseOrphanTcpConnections() {
  // Add all incoming tcp connections into a lookup table.
  // NOTE(review): the template arguments of WG_HASHTABLE_IMPL were lost
  // in the text under review — confirm against the original header
  // (key is the ConvertIpAddrToAddrX result, value a pointer).
  WG_HASHTABLE_IMPL lookup;
  for (TcpSocketBsd *tcp = network_.tcp_sockets(); tcp; tcp = tcp->next()) {
    if (tcp->endpoint_protocol() == (kPacketProtocolTcp | kPacketProtocolIncomingConnection)) {
      // Avoid deleting tcp sockets that were just born.
      if (tcp->age == 0) {
        tcp->age = 1;
      } else {
        lookup[ConvertIpAddrToAddrX(tcp->endpoint())] = tcp;
      }
    }
  }
  if (lookup.empty())
    return;
  // For each peer, check if it has an endpoint that matches
  // an entry in the lookup table, and delete it from the lookup
  // table.
  for(WgPeer *peer = processor_.dev().first_peer(); peer; peer = peer->next_peer()) {
    if (peer->endpoint_protocol() == (kPacketProtocolTcp | kPacketProtocolIncomingConnection))
      lookup.erase(ConvertIpAddrToAddrX(peer->endpoint()));
  }
  // The tcp connections that are still in the hashtable can be deleted
  for(const auto &it : lookup)
    delete (TcpSocketBsd *)it.second;
}
+
int main(int argc, char **argv) {
CommandLineOutput cmd = {0};
@@ -812,20 +855,15 @@ int main(int argc, char **argv) {
SetThreadName("tunsafe-m");
- MyProcessorDelegate my_procdel;
- TunsafeBackendBsd *backend = CreateTunsafeBackendBsd();
+ TunsafeBackendBsdImpl backend;
if (cmd.interface_name)
- backend->SetTunDeviceName(cmd.interface_name);
-
- WireguardProcessor wg(backend, backend, &my_procdel);
-
- my_procdel.wg_processor_ = &wg;
- backend->SetProcessor(&wg);
+ backend.SetTunDeviceName(cmd.interface_name);
DnsResolver dns_resolver(NULL);
- if (*cmd.filename_to_load && !ParseWireGuardConfigFile(&wg, cmd.filename_to_load, &dns_resolver))
+ if (*cmd.filename_to_load && !ParseWireGuardConfigFile(backend.processor(), cmd.filename_to_load, &dns_resolver))
+ return 1;
+ if (!backend.processor()->Start())
return 1;
- if (!wg.Start()) return 1;
if (cmd.daemon) {
fprintf(stderr, "Switching to daemon mode...\n");
@@ -833,9 +871,8 @@ int main(int argc, char **argv) {
perror("daemon() failed");
}
- backend->RunLoop();
- backend->CleanupRoutes();
- delete backend;
+ backend.RunLoop();
+ backend.CleanupRoutes();
return 0;
}
diff --git a/tunsafe_bsd.h b/tunsafe_bsd.h
new file mode 100644
index 0000000..11bfb82
--- /dev/null
+++ b/tunsafe_bsd.h
@@ -0,0 +1,60 @@
// SPDX-License-Identifier: AGPL-1.0-only
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
#ifndef TUNSAFE_NETWORK_BSD_COMMON_H_
#define TUNSAFE_NETWORK_BSD_COMMON_H_

#include "netapi.h"
#include "wireguard.h"
#include "wireguard_config.h"
#include <signal.h>
#include <string>
+
// One routing-table entry recorded at setup time so CleanupRoutes()
// can undo it on shutdown.
struct RouteInfo {
  uint8 family;     // AF_INET or AF_INET6
  uint8 cidr;       // prefix length of |ip|
  uint8 ip[16];     // destination (IPv4 uses the first 4 bytes)
  uint8 gw[16];     // gateway address, same layout as |ip|
  std::string dev;  // interface the route was bound to, if any
};
+
// RAII installer for the process signal handlers: SIGALRM sets
// |sigalarm_flag_| (periodic tick), Ctrl-C sets |exit_flag_|. Only one
// instance may exist at a time (handlers go through a global pointer).
class SignalCatcher {
public:
  SignalCatcher(bool *exit_flag, bool *sigalarm_flag);
  ~SignalCatcher();

  // Signal mask in effect before construction; passed to the event
  // loop (e.g. for pselect-style waiting).
  sigset_t orig_signal_mask_;
private:
  static void SigAlrm(int sig);
  static void SigInt(int sig);
  bool *exit_flag_;
  bool *sigalarm_flag_;
};
+
// Shared base for the BSD/Linux/macOS backends: owns tun device naming,
// route setup/teardown and pre/post shell commands. The concrete IO
// loop lives in a subclass (see TunsafeBackendBsdImpl).
class TunsafeBackendBsd : public TunInterface, public UdpInterface {
public:
  TunsafeBackendBsd();
  virtual ~TunsafeBackendBsd();

  // Undo every route added while configuring the tunnel.
  void CleanupRoutes();

  void SetTunDeviceName(const char *name);

  // -- from TunInterface
  virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override;

protected:
  // Open the tun device; implemented by the platform subclass.
  virtual bool InitializeTun(char devname[16]) = 0;

  void AddRoute(uint32 ip, uint32 cidr, uint32 gw, const char *dev);
  void DelRoute(const RouteInfo &cd);
  bool AddRoute(int family, const void *dest, int dest_prefix, const void *gateway, const char *dev);
  // NOTE(review): the template arguments of the std::vector declarations
  // below were lost in the text under review — restore them from the
  // original header (likely RouteInfo / std::string / WgCidrAddr).
  bool RunPrePostCommand(const std::vector &vec);

  std::vector cleanup_commands_;
  std::vector pre_down_, post_down_;
  std::vector addresses_to_remove_;
  char devname_[16];
  bool tun_interface_gone_;
};
+
+#endif // TUNSAFE_NETWORK_BSD_COMMON_H_
diff --git a/wireguard.cpp b/wireguard.cpp
index 14859a3..8fce37f 100644
--- a/wireguard.cpp
+++ b/wireguard.cpp
@@ -28,6 +28,7 @@ WireguardProcessor::WireguardProcessor(UdpInterface *udp, TunInterface *tun, Pro
mtu_ = 1420;
memset(&stats_, 0, sizeof(stats_));
listen_port_ = 0;
+ listen_port_tcp_ = 0;
network_discovery_spoofing_ = false;
add_routes_mode_ = true;
dns_blocking_ = true;
@@ -50,6 +51,16 @@ void WireguardProcessor::SetListenPort(int listen_port) {
}
}
+void WireguardProcessor::SetListenPortTcp(int listen_port) {
+ if (listen_port_tcp_ != listen_port) {
+ listen_port_tcp_ = listen_port;
+ if (is_started_ && !ConfigureUdp()) {
+ RINFO("ConfigureUdp failed");
+ }
+ }
+}
+
+
void WireguardProcessor::AddDnsServer(const IpAddr &sin) {
dns_addr_.push_back(sin);
}
@@ -126,7 +137,7 @@ bool WireguardProcessor::Start() {
bool WireguardProcessor::ConfigureUdp() {
assert(dev_.IsMainThread());
- return udp_->Configure(listen_port_);
+ return udp_->Configure(listen_port_, listen_port_tcp_);
}
bool WireguardProcessor::ConfigureTun() {
diff --git a/wireguard.h b/wireguard.h
index 7919bd1..869b3c0 100644
--- a/wireguard.h
+++ b/wireguard.h
@@ -81,6 +81,8 @@ public:
~WireguardProcessor();
void SetListenPort(int listen_port);
+ void SetListenPortTcp(int listen_port);
+
void AddDnsServer(const IpAddr &sin);
bool SetTunAddress(const WgCidrAddr &addr);
void ClearTunAddress();
@@ -132,6 +134,7 @@ private:
UdpInterface *udp_;
uint16 listen_port_;
+ uint16 listen_port_tcp_;
uint16 mtu_;
bool dns_blocking_;
diff --git a/wireguard_config.cpp b/wireguard_config.cpp
index 007517a..50c11c5 100644
--- a/wireguard_config.cpp
+++ b/wireguard_config.cpp
@@ -104,6 +104,8 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
wg_->dev().SetPrivateKey(binkey);
} else if (strcmp(key, "ListenPort") == 0) {
wg_->SetListenPort(atoi(value));
+ } else if (strcmp(key, "ListenPortTCP") == 0) {
+ wg_->SetListenPortTcp(atoi(value));
} else if (strcmp(key, "Address") == 0) {
SplitString(value, ',', &ss);
for (size_t i = 0; i < ss.size(); i++) {