diff --git a/TunSafe.vcxproj b/TunSafe.vcxproj index 42a9fbb..d0b93cf 100644 --- a/TunSafe.vcxproj +++ b/TunSafe.vcxproj @@ -185,10 +185,24 @@ + + true + true + true + true + + + + + true + true + true + true + @@ -215,8 +229,28 @@ + + true + true + true + true + + + true + true + true + true + + + + + true + true + true + true + diff --git a/TunSafe.vcxproj.filters b/TunSafe.vcxproj.filters index 3134cb4..10ced29 100644 --- a/TunSafe.vcxproj.filters +++ b/TunSafe.vcxproj.filters @@ -23,6 +23,12 @@ {1ca37c7b-e91e-4648-9584-7d0c73d8e416} + + {4b2f2fd9-780e-45db-8fe1-f03079439723} + + + {0f45e1a0-f33e-4c6e-88ae-eb4639f12041} + @@ -92,7 +98,6 @@ Source Files\Win32 - crypto\blake2s @@ -123,6 +128,21 @@ Source Files + + crypto\siphash + + + Source Files\BSD + + + Source Files\BSD + + + Source Files + + + Source Files\Win32 + @@ -173,9 +193,6 @@ Source Files\Win32 - - Source Files - crypto\blake2s @@ -188,6 +205,24 @@ Source Files + + Source Files\BSD + + + Source Files\BSD + + + crypto\siphash + + + Source Files\BSD + + + Source Files + + + Source Files\Win32 + diff --git a/docs/WireGuard TCP.txt b/docs/WireGuard TCP.txt new file mode 100644 index 0000000..dabb5c3 --- /dev/null +++ b/docs/WireGuard TCP.txt @@ -0,0 +1,104 @@ +WireGuard over TCP +------------------ + +We hate running one TCP implementation on top of another TCP implementation. +There are problems with cascading retransmissions and head-of-line blocking, +and performance is always much worse than a UDP-based tunnel. + +However, we also recognize that several users need to run WireGuard over TCP. +One reason is that UDP packets are sometimes blocked by the network in +corporate scenarios or in other types of firewalls. Also, in misconfigured +networks outside of the user's control, TCP may be more reliable than UDP. + +Additionally, we want TunSafe to be a drop-in replacement for OpenVPN, which +also supports TCP-based tunneling. 
The feature could also be used to run +WireGuard tunnels over ssh tunnels, or through socks/https proxies. + +The TunSafe project therefore takes the pragmatic approach of supporting +WireGuard over TCP, while discouraging its use. We absolutely don't want +people to start using TCP by default. It's meant to be used only in the +extreme cases when nothing else is working. + +We've added experimental support for TCP in the latest TunSafe master, +which means you can try this out on Windows, OSX, Linux, and FreeBSD. +On the server side, to listen on a TCP port, use ListenPortTCP=1234. (Not +working on Windows yet.) On the clients, use Endpoint=tcp://5.5.5.5:1234. +The code is still very experimental and untested, and is not recommended +for general use. Once the code is better tested, we'll also release +support for connecting to WireGuard over TCP in our Android and iOS clients. + +To keep the impact on our WireGuard protocol handling as small as possible, +and to minimize the risk of security-related issues, the TCP feature has been +designed to be as self-contained as possible. When a packet comes in over +TCP, it's sent over to the WireGuard protocol handler and treated as if it +were a UDP packet, and vice versa. This means TCP support can also be added +to existing WireGuard deployments by using a separate process that converts +TCP connections into UDP packets sent to the WireGuard Linux kernel module. + +Each packet over TCP is prefixed by a 2-byte big endian number, which contains +the packet type (top 2 bits) and the length of the packet's payload (low 14 bits). The payload is then the actual WireGuard +UDP packet. + +TCP has larger overhead than UDP, and we want to support the usual WireGuard +MTU of 1420 without introducing extra packet "fragmenting". So we implemented +an optimization to skip sending the 16-byte WireGuard header for every packet. +Since TCP is a reliable connection, we know that sequence numbers are always +monotonically increasing, so we can predict the contents of this header. 
+ +Here's an example: +A 1420-byte packet sent over a WireGuard link will have 2 bytes of +TCP payload length, 16 bytes of WireGuard headers, 16 bytes of WireGuard MAC, +20 bytes of TCP headers, and 40 bytes of IPv6 headers. +This is a total of 1420 + 2 + 16 + 16 + 20 + 40 = 1514 bytes, exceeding +the usual 1500-byte Ethernet MTU by 14 bytes. This means that a single +full-sized packet over WireGuard will result in 2 TCP packets. With our +optimization, we reduce this to 1498 bytes, so it fits in one TCP packet. + +Protocol specification +---------------------- + +TT LLLLLL LLLLLLLL [Payload LL bytes] +| | +| \-- Payload length, high byte first. +\----- Packet type + +The packet types (TT) currently defined are: +TT = 00 = Normal The payload is a normal unmodified WireGuard packet + including the regular WireGuard header. + 01 = Reserved + 10 = Data A WireGuard data packet (type 04) without the 16 byte + header. The predicted header is prefixed to the payload. + 11 = Control A TCP control packet. Currently this is used only to set up + the header prediction. See below. + +There's only one defined Control packet, type 00 (SetKeyAndCounter): + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + |1 1| Length is 13 (14 bits) | 00 (8 bits) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Key ID (32 bits) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Counter (64 bits) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +This sets up the Key ID and Counter used for the Data packets. The Counter +is then incremented by 1 for every Data packet. 
+ +For every Data packet, the predicted Key ID and Counter is expanded to a +regular WireGuard data (type 04) header, which is prefixed to the payload: + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 04 (8 bits) | Reserved (24 bits) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Key ID (32 bits) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Counter (64 bits) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Data Payload (LL * 8 bits) ... + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + +This happens independently in each of the two TCP directions. diff --git a/netapi.h b/netapi.h index 8fee819..f1b9318 100644 --- a/netapi.h +++ b/netapi.h @@ -59,22 +59,30 @@ struct Packet : QueuedItem { uint8 userdata; uint8 protocol; // which protocol is this packet for/from IpAddr addr; // Optionally set to target/source of the packet - - byte data_pre[4]; - byte data_buf[0]; + enum { - // there's always this much data before data_ptr + // there's always this much data before data_buf, to allow for header expansion + // in front. 
HEADROOM_BEFORE = 64, }; + + byte data_pre[HEADROOM_BEFORE]; + byte data_buf[0]; + + void Reset() { + data = data_buf; + size = 0; + } }; enum { kPacketAllocSize = 2048 - 16, - kPacketCapacity = kPacketAllocSize - sizeof(Packet) - Packet::HEADROOM_BEFORE, + kPacketCapacity = kPacketAllocSize - sizeof(Packet), }; void FreePacket(Packet *packet); void FreePackets(Packet *packet, Packet **end, int count); +void FreePacketList(Packet *packet); Packet *AllocPacket(); void FreeAllPackets(); @@ -123,7 +131,7 @@ public: class UdpInterface { public: - virtual bool Configure(int listen_port) = 0; + virtual bool Configure(int listen_port_udp, int listen_port_tcp) = 0; virtual void WriteUdpPacket(Packet *packet) = 0; }; diff --git a/network_bsd.cpp b/network_bsd.cpp index 547e3db..62d5e5f 100644 --- a/network_bsd.cpp +++ b/network_bsd.cpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: AGPL-1.0-only // Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. -#include "network_bsd_common.h" +#include "network_bsd.h" +#include "network_common.h" #include "tunsafe_endian.h" #include "util.h" @@ -14,23 +15,52 @@ #include #include #include +#include #include #include #include #include #include +#if defined(OS_LINUX) +#include +#include +#include +#endif + +#include + +#include "wireguard.h" +#include "wireguard_config.h" + +#if defined(OS_MACOSX) || defined(OS_FREEBSD) +#define TUN_PREFIX_BYTES 4 +#elif defined(OS_LINUX) || defined(OS_ANDROID) +#define TUN_PREFIX_BYTES 0 +#endif + static Packet *freelist; +void tunsafe_die(const char *msg) { + fprintf(stderr, "%s\n", msg); + exit(1); +} + +void SetThreadName(const char *name) { +#if defined(OS_LINUX) + prctl(PR_SET_NAME, name, 0, 0, 0); +#endif // defined(OS_LINUX) +} + void FreePacket(Packet *packet) { - packet->next = freelist; + packet->queue_next = freelist; freelist = packet; } Packet *AllocPacket() { Packet *p = freelist; if (p) { - freelist = p->next; + freelist = Packet_NEXT(p); } else { p = (Packet*)malloc(kPacketAllocSize); 
if (p == NULL) { @@ -38,192 +68,291 @@ Packet *AllocPacket() { abort(); } } - p->data = p->data_buf + Packet::HEADROOM_BEFORE; - p->size = 0; + p->Reset(); return p; } -void FreePackets() { - Packet *p; - while ( (p = freelist ) != NULL) { - freelist = p->next; - free(p); - } +void FreePacketList(Packet *packet) { + while (packet) + free(exch(packet, Packet_NEXT(packet))); } - -class TunsafeBackendBsdImpl : public TunsafeBackendBsd { -public: - TunsafeBackendBsdImpl(); - virtual ~TunsafeBackendBsdImpl(); - - virtual void RunLoopInner() override; - virtual bool InitializeTun(char devname[16]) override; - - // -- from TunInterface - virtual void WriteTunPacket(Packet *packet) override; - - // -- from UdpInterface - virtual bool Configure(int listen_port) override; - virtual void WriteUdpPacket(Packet *packet) override; - - virtual void HandleSigAlrm() override { got_sig_alarm_ = true; } - virtual void HandleExit() override { exit_ = true; } - -private: - bool ReadFromUdp(bool overload); - bool ReadFromTun(); - bool WriteToUdp(); - bool WriteToTun(); - bool InitializeUnixDomainSocket(const char *devname); - - // Exists for the unix domain sockets - struct SockInfo { - bool is_listener; - - std::string inbuf, outbuf; - }; - bool HandleSpecialPollfd(struct pollfd *pollfd, struct SockInfo *sockinfo); - void CloseSpecialPollfd(size_t i); - void SetUdpFd(int fd); - void SetTunFd(int fd); - - bool got_sig_alarm_; - bool exit_; - - bool tun_readable_, tun_writable_; - bool udp_readable_, udp_writable_; - - Packet *tun_queue_, **tun_queue_end_; - Packet *udp_queue_, **udp_queue_end_; - - Packet *read_packet_; - - enum { - kMaxPollFd = 5, - kPollFdTun = 0, - kPollFdUdp = 1, - kPollFdUnix = 2, - }; - - unsigned int pollfd_num_; - struct pollfd pollfd_[kMaxPollFd]; - - struct SockInfo sockinfo_[kMaxPollFd - 2]; - - struct sockaddr_un un_addr_; - - UnixSocketDeletionWatcher un_deletion_watcher_; -}; - -TunsafeBackendBsdImpl::TunsafeBackendBsdImpl() - : tun_readable_(false), - 
tun_writable_(false), - udp_readable_(false), - udp_writable_(false), - got_sig_alarm_(false), - exit_(false), - tun_queue_(NULL), - tun_queue_end_(&tun_queue_), - udp_queue_(NULL), - udp_queue_end_(&udp_queue_), - read_packet_(NULL) { - read_packet_ = AllocPacket(); - for(size_t i = 0; i < kMaxPollFd; i++) - pollfd_[i].fd = -1; - pollfd_num_ = 3; - sockinfo_[0].is_listener = true; - memset(&un_addr_, 0, sizeof(un_addr_)); +void FreeAllPackets() { + FreePacketList(exch_null(freelist)); } -TunsafeBackendBsdImpl::~TunsafeBackendBsdImpl() { - if (un_addr_.sun_path[0]) - unlink(un_addr_.sun_path); +////////////////////////////////////////////////////////////////////////////////////////////// + +NetworkBsd::NetworkBsd(NetworkBsdDelegate *delegate, int max_sockets) + : exit_(false), + overload_(false), + sigalarm_flag_(false), + num_roundrobin_(0), + num_sock_(0), + num_endloop_(0), + read_packet_(NULL), + tcp_sockets_(NULL), + delegate_(delegate), + max_sockets_(max_sockets) { + if (max_sockets < 5 || max_sockets > 1000) + tunsafe_die("invalid value for max_sockets"); + + pollfd_ = new struct pollfd[max_sockets]; + sockets_ = new BaseSocketBsd*[max_sockets]; + roundrobin_ = new BaseSocketBsd*[max_sockets]; + endloop_ = new BaseSocketBsd*[max_sockets]; + if (!pollfd_ || !sockets_ || !roundrobin_ || !endloop_) + tunsafe_die("no memory"); + + memset(iov_packets_, 0, sizeof(iov_packets_)); +} + +NetworkBsd::~NetworkBsd() { + assert(tcp_sockets_ == NULL); + assert(num_sock_ == 0); if (read_packet_) FreePacket(read_packet_); - for(size_t i = 0; i < pollfd_num_; i++) - close(pollfd_[i].fd); + for (size_t i = 0; i < kMaxIovec; i++) + if (iov_packets_[i]) + FreePacket(iov_packets_[i]); + + delete [] pollfd_; + delete [] sockets_; + delete [] roundrobin_; + delete [] endloop_; } -void TunsafeBackendBsdImpl::SetUdpFd(int fd) { - pollfd_[kPollFdUdp].fd = fd; - pollfd_[kPollFdUdp].events = POLLIN; - udp_writable_ = true; -} +void NetworkBsd::RunLoop(const sigset_t *sigmask) { + int 
free_packet_interval = 10; + int overload_ctr = 0; + uint64 last_second_loop = 0; + uint64 now = 0; -void TunsafeBackendBsdImpl::SetTunFd(int fd) { - pollfd_[kPollFdTun].fd = fd; - pollfd_[kPollFdTun].events = POLLIN; - tun_writable_ = true; -} + if (!WithSigalarmSupport) + last_second_loop = OsGetMilliseconds(); + + while (!exit_) { + int n; + bool new_second = false; -bool TunsafeBackendBsdImpl::ReadFromUdp(bool overload) { - socklen_t sin_len; - sin_len = sizeof(read_packet_->addr.sin); - int r = recvfrom(pollfd_[kPollFdUdp].fd, read_packet_->data, kPacketCapacity, 0, - (sockaddr*)&read_packet_->addr.sin, &sin_len); - if (r >= 0) { -// printf("Read %d bytes from UDP\n", r); - read_packet_->sin_size = sin_len; - read_packet_->size = r; - if (processor_) { - processor_->HandleUdpPacket(read_packet_, overload); - read_packet_ = AllocPacket(); + if (WithSigalarmSupport) { + if (sigalarm_flag_) { + sigalarm_flag_ = false; + new_second = true; + } + } else { + now = OsGetMilliseconds(); + if ((now - last_second_loop) >= 1000) { + // Avoid falling behind too much + last_second_loop = (now - last_second_loop) >= 2000 ? now : last_second_loop + 1000; + new_second = true; + } } - return true; - } else { - if (errno != EAGAIN) { - fprintf(stderr, "Read from UDP failed\n"); + + if (new_second) { + delegate_->OnSecondLoop(now); + + struct BaseSocketBsd **socks = sockets_; + for (int i = 0; i < num_sock_; i++) + socks[i]->Periodic(); + + if (free_packet_interval == 0) { + FreeAllPackets(); + free_packet_interval = 10; + } + free_packet_interval--; + + overload_ctr -= (overload_ctr != 0); } - udp_readable_ = false; + +#if defined(OS_LINUX) || defined(OS_FREEBSD) + n = ppoll(pollfd_, num_sock_, NULL, sigmask); +#else + n = poll(pollfd_, num_sock_, WithSigalarmSupport ? 
-1 : std::max((int)(last_second_loop - now) + 1000, 0)); +#endif + if (n == -1) { + if (errno != EINTR) { + RERROR("poll failed"); + break; + } + } else { + // Iterate backwards to support deleting elements + struct pollfd *pfd = pollfd_; + struct BaseSocketBsd **socks = sockets_; + for (int i = num_sock_ - 1; i >= 0; i--) { + if (pfd[i].revents) + socks[i]->HandleEvents(pfd[i].revents); + } + } + + overload_ = (overload_ctr != 0); + for (int loop = 0; ; loop++) { + // Whenever we don't finish set overload ctr. + if (loop == 256) { + overload_ctr = 4; + break; + } + int i = num_roundrobin_ - 1; + struct BaseSocketBsd **rrlist = roundrobin_; + if (i < 0) + break; + do { + if (!rrlist[i]->DoRoundRobin()) + RemoveFromRoundRobin(i); + } while (i--); + } + + struct BaseSocketBsd **endloop = endloop_; + for (int j = num_endloop_ - 1; j >= 0; j--) { + endloop[j]->endloop_slot_ = -1; + endloop[j]->DoEndloop(); + } + num_endloop_ = 0; + + delegate_->RunAllMainThreadScheduled(); + } +} + +void NetworkBsd::RemoveFromRoundRobin(int i) { + BaseSocketBsd *cur = roundrobin_[i], *last = roundrobin_[num_roundrobin_-- - 1]; + assert(cur->roundrobin_slot_ == i); + roundrobin_[i] = last; + last->roundrobin_slot_ = i; + cur->roundrobin_slot_ = -1; +} + +void NetworkBsd::ReallocateIov(size_t j) { + Packet *p = AllocPacket(); + iov_packets_[j] = p; + iov_[j].iov_base = p->data; + iov_[j].iov_len = kPacketCapacity; +} + +void NetworkBsd::EnsureIovAllocated() { + if (iov_packets_[0] == NULL) { + for (size_t i = 0; i < kMaxIovec; i++) + ReallocateIov(i); + } +} + +////////////////////////////////////////////////////////////////////////////////////////////// + +BaseSocketBsd::~BaseSocketBsd() { + CloseSocket(); +} + +void BaseSocketBsd::CloseSocket() { + if (fd_ != -1) + close(fd_); + if (roundrobin_slot_ >= 0) + network_->RemoveFromRoundRobin(roundrobin_slot_); + if (endloop_slot_ >= 0) { + BaseSocketBsd *last = network_->endloop_[network_->num_endloop_-- - 1]; + 
network_->endloop_[endloop_slot_] = last; + last->endloop_slot_ = endloop_slot_; + } + if (pollfd_slot_ >= 0) { + unsigned int cur = pollfd_slot_, last = network_->num_sock_-- - 1; + BaseSocketBsd *lastsock = network_->sockets_[last]; + network_->sockets_[cur] = lastsock; + lastsock->pollfd_slot_ = cur; + network_->pollfd_[cur] = network_->pollfd_[last]; + } + fd_ = -1; + endloop_slot_ = pollfd_slot_ = roundrobin_slot_ = -1; +} + +void BaseSocketBsd::InitPollSlot(int fd, int events) { + assert(network_->num_sock_ != network_->max_sockets_); + assert(fd_ == -1); + fd_ = fd; + unsigned int slot = pollfd_slot_; + if (pollfd_slot_ < 0) + pollfd_slot_ = slot = network_->num_sock_++; + network_->sockets_[slot] = this; + struct pollfd *pfd = &network_->pollfd_[slot]; + pfd->fd = fd; + pfd->events = events; + pfd->revents = 0; +} + +void BaseSocketBsd::AddToRoundRobin() { + if (roundrobin_slot_ < 0) + network_->roundrobin_[roundrobin_slot_ = network_->num_roundrobin_++] = this; +} + +void BaseSocketBsd::AddToEndLoop() { + if (endloop_slot_ < 0) + network_->endloop_[endloop_slot_ = network_->num_endloop_++] = this; +} + +////////////////////////////////////////////////////////////////////////////////////////////// + +TunSocketBsd::TunSocketBsd(NetworkBsd *network, WireguardProcessor *processor) + : BaseSocketBsd(network), + tun_readable_(false), + tun_writable_(false), + tun_interface_gone_(false), + tun_queue_(NULL), + tun_queue_end_(&tun_queue_), + processor_(processor) { +} + +TunSocketBsd::~TunSocketBsd() { +} + +bool TunSocketBsd::Initialize(int fd) { + if (!HasFreePollSlot()) return false; - } -} - -bool TunsafeBackendBsdImpl::WriteToUdp() { - assert(udp_writable_); -// RINFO("Send %d bytes to %s", (int)udp_queue_->size, inet_ntoa(udp_queue_->sin.sin_addr)); - int r = sendto(pollfd_[kPollFdUdp].fd, udp_queue_->data, udp_queue_->size, 0, - (sockaddr*)&udp_queue_->addr.sin, sizeof(udp_queue_->addr.sin)); - if (r < 0) { - if (errno == EAGAIN) { - udp_writable_ = false; - 
pollfd_[kPollFdUdp].events = POLLIN | POLLOUT; - return false; - } - perror("Write to UDP failed"); - } else { - if (r != udp_queue_->size) - perror("Write to udp incomplete!"); -// else -// RINFO("Wrote %d bytes to UDP", r); - } - Packet *next = udp_queue_->next; - FreePacket(udp_queue_); - if ((udp_queue_ = next) != NULL) return true; - udp_queue_end_ = &udp_queue_; - return false; + fcntl(fd, F_SETFD, FD_CLOEXEC); + fcntl(fd, F_SETFL, O_NONBLOCK); + InitPollSlot(fd, POLLIN); + tun_writable_ = true; + return true; } static inline bool IsCompatibleProto(uint32 v) { return v == AF_INET || v == AF_INET6; } -bool TunsafeBackendBsdImpl::ReadFromTun() { +void TunSocketBsd::HandleEvents(int revents) { + if (revents & (POLLERR | POLLHUP | POLLNVAL)) { + if (revents & POLLERR) { + tun_interface_gone_ = true; + RERROR("Tun interface gone, closing."); + } else { + RERROR("Tun interface error %d, closing.", revents); + } + tun_readable_ = tun_writable_ = false; + network_->PostExit(); + } else { + tun_readable_ = (revents & POLLIN) != 0; + if (revents & POLLOUT) { + SetPollFlags(POLLIN); + tun_writable_ = true; + } + } + AddToRoundRobin(); +} + +bool TunSocketBsd::DoRead() { assert(tun_readable_); - Packet *packet = read_packet_; - int r = read(pollfd_[kPollFdTun].fd, packet->data - TUN_PREFIX_BYTES, kPacketCapacity + TUN_PREFIX_BYTES); + Packet *packet = network_->read_packet_; + if (!packet) + network_->read_packet_ = packet = AllocPacket(); + + int r = read(fd_, packet->data - TUN_PREFIX_BYTES, kPacketCapacity + TUN_PREFIX_BYTES); if (r >= 0) { -// printf("Read %d bytes from TUN\n", r); +// printf("Read %d bytes from TUN\n", r); packet->size = r - TUN_PREFIX_BYTES; - if (r >= TUN_PREFIX_BYTES && (!TUN_PREFIX_BYTES || IsCompatibleProto(ReadBE32(packet->data - TUN_PREFIX_BYTES))) && processor_) { -// printf("%X %X %X %X %X %X %X %X\n", -// read_packet_->data[0], read_packet_->data[1], read_packet_->data[2], read_packet_->data[3], -// read_packet_->data[4], 
read_packet_->data[5], read_packet_->data[6], read_packet_->data[7]); - read_packet_ = AllocPacket(); + if (r >= TUN_PREFIX_BYTES && (!TUN_PREFIX_BYTES || IsCompatibleProto(ReadBE32(packet->data - TUN_PREFIX_BYTES)))) { + // printf("%X %X %X %X %X %X %X %X\n", + // read_packet_->data[0], read_packet_->data[1], read_packet_->data[2], read_packet_->data[3], + // read_packet_->data[4], read_packet_->data[5], read_packet_->data[6], read_packet_->data[7]); + network_->read_packet_ = NULL; processor_->HandleTunPacket(packet); } - return true; + return true; } else { if (errno != EAGAIN) { fprintf(stderr, "Read from tun failed\n"); @@ -237,16 +366,16 @@ static uint32 GetProtoFromPacket(const uint8 *data, size_t size) { return size < 1 || (data[0] >> 4) != 6 ? AF_INET : AF_INET6; } -bool TunsafeBackendBsdImpl::WriteToTun() { +bool TunSocketBsd::DoWrite() { assert(tun_writable_); if (TUN_PREFIX_BYTES) { WriteBE32(tun_queue_->data - TUN_PREFIX_BYTES, GetProtoFromPacket(tun_queue_->data, tun_queue_->size)); } - int r = write(pollfd_[kPollFdTun].fd, tun_queue_->data - TUN_PREFIX_BYTES, tun_queue_->size + TUN_PREFIX_BYTES); + int r = write(fd_, tun_queue_->data - TUN_PREFIX_BYTES, tun_queue_->size + TUN_PREFIX_BYTES); if (r < 0) { if (errno == EAGAIN) { tun_writable_ = false; - pollfd_[kPollFdTun].events = POLLIN | POLLOUT; + SetPollFlags(POLLIN | POLLOUT); return false; } RERROR("Write to tun failed"); @@ -254,58 +383,262 @@ bool TunsafeBackendBsdImpl::WriteToTun() { r -= TUN_PREFIX_BYTES; if (r != tun_queue_->size) RERROR("Write to tun incomplete!"); -// else -// RINFO("Wrote %d bytes to TUN", r); - } - Packet *next = tun_queue_->next; + // else + // RINFO("Wrote %d bytes to TUN", r); + } + Packet *next = Packet_NEXT(tun_queue_); FreePacket(tun_queue_); if ((tun_queue_ = next) != NULL) return true; tun_queue_end_ = &tun_queue_; return false; } -bool TunsafeBackendBsdImpl::InitializeTun(char devname[16]) { - int tun_fd = open_tun(devname, 16); - if (tun_fd < 0) { RERROR("Error 
opening tun device"); return false; } - fcntl(tun_fd, F_SETFD, FD_CLOEXEC); - fcntl(tun_fd, F_SETFL, O_NONBLOCK); - SetTunFd(tun_fd); - - InitializeUnixDomainSocket(devname); - return true; -} - -void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) override { - assert(pollfd_[kPollFdTun].fd >= 0); +void TunSocketBsd::WritePacket(Packet *packet) { + assert(fd_ >= 0); Packet *queue_is_used = tun_queue_; *tun_queue_end_ = packet; - tun_queue_end_ = &packet->next; - packet->next = NULL; + tun_queue_end_ = &Packet_NEXT(packet); + packet->queue_next = NULL; if (!queue_is_used) - WriteToTun(); + DoWrite(); } -// Called to initialize udp -bool TunsafeBackendBsdImpl::Configure(int listen_port) override { - int udp_fd = open_udp(listen_port); - if (udp_fd < 0) { RERROR("Error opening udp"); return false; } +bool TunSocketBsd::DoRoundRobin() { + bool more_work = false; + if (tun_queue_ && tun_writable_) + more_work = DoWrite(); + if (tun_readable_) + more_work |= DoRead(); + return more_work; +} + +////////////////////////////////////////////////////////////////////////////////////////////// + +UdpSocketBsd::UdpSocketBsd(NetworkBsd *network, WireguardProcessor *processor) + : BaseSocketBsd(network), + udp_readable_(false), + udp_writable_(false), + udp_queue_(NULL), + udp_queue_end_(&udp_queue_), + processor_(processor) { +} + +UdpSocketBsd::~UdpSocketBsd() { +} + +bool UdpSocketBsd::Initialize(int listen_port) { + if (!HasFreePollSlot()) { + RERROR("No free internal sockets"); + return false; + } + int udp_fd = socket(AF_INET, SOCK_DGRAM, 0); + if (udp_fd < 0) { + RERROR("socket(SOCK_DGRAM) failed"); + return false; + } + sockaddr_in sin = {0}; + sin.sin_family = AF_INET; + sin.sin_port = htons(listen_port); + if (bind(udp_fd, (struct sockaddr*)&sin, sizeof(sin)) != 0) { + close(udp_fd); + RERROR("bind on udp socket port %d failed", listen_port); + return false; + } fcntl(udp_fd, F_SETFD, FD_CLOEXEC); fcntl(udp_fd, F_SETFL, O_NONBLOCK); - SetUdpFd(udp_fd); + 
InitPollSlot(udp_fd, POLLIN); + udp_writable_ = true; return true; } -void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) override { - assert(pollfd_[kPollFdUdp].fd >= 0); - Packet *queue_is_used = udp_queue_; - *udp_queue_end_ = packet; - udp_queue_end_ = &packet->next; - packet->next = NULL; - if (!queue_is_used) - WriteToUdp(); +void UdpSocketBsd::HandleEvents(int revents) { + if (revents & (POLLERR | POLLHUP | POLLNVAL)) { + RERROR("UDP error %d, closing.", revents); + network_->PostExit(); + } else { + udp_readable_ = (revents & POLLIN) != 0; + if (revents & POLLOUT) { + SetPollFlags(POLLIN); + udp_writable_ = true; + } + } + AddToRoundRobin(); } -bool TunsafeBackendBsdImpl::InitializeUnixDomainSocket(const char *devname) { +bool UdpSocketBsd::DoRead() { + socklen_t sin_len; + Packet *read_packet = network_->read_packet_; + if (read_packet == NULL) + network_->read_packet_ = read_packet = AllocPacket(); + + sin_len = sizeof(read_packet->addr.sin); + int r = recvfrom(fd_, read_packet->data, kPacketCapacity, 0, + (sockaddr*)&read_packet->addr.sin, &sin_len); + if (r >= 0) { + // printf("Read %d bytes from UDP\n", r); + read_packet->sin_size = sin_len; + read_packet->size = r; + read_packet->protocol = kPacketProtocolUdp; + network_->read_packet_ = NULL; + processor_->HandleUdpPacket(read_packet, network_->overload_); + return true; + } else { + if (errno != EAGAIN) { + fprintf(stderr, "Read from UDP failed\n"); + } + udp_readable_ = false; + return false; + } +} + +bool UdpSocketBsd::DoWrite() { + assert(udp_writable_); + // RINFO("Send %d bytes to %s", (int)udp_queue_->size, inet_ntoa(udp_queue_->sin.sin_addr)); + int r = sendto(fd_, udp_queue_->data, udp_queue_->size, 0, + (sockaddr*)&udp_queue_->addr.sin, sizeof(udp_queue_->addr.sin)); + if (r < 0) { + if (errno == EAGAIN) { + udp_writable_ = false; + SetPollFlags(POLLIN | POLLOUT); + return false; + } + perror("Write to UDP failed"); + } else { + if (r != udp_queue_->size) + perror("Write to udp 
incomplete!"); + // else + // RINFO("Wrote %d bytes to UDP", r); + } + Packet *next = Packet_NEXT(udp_queue_); + FreePacket(udp_queue_); + if ((udp_queue_ = next) != NULL) return true; + udp_queue_end_ = &udp_queue_; + return false; +} + +void UdpSocketBsd::WritePacket(Packet *packet) { + assert(fd_ >= 0); + Packet *queue_is_used = udp_queue_; + *udp_queue_end_ = packet; + udp_queue_end_ = &Packet_NEXT(packet); + packet->queue_next = NULL; + if (!queue_is_used) + DoWrite(); +} + +bool UdpSocketBsd::DoRoundRobin() { + bool did_work = false; + if (udp_queue_ && udp_writable_) + did_work = DoWrite(); + if (udp_readable_) + did_work |= DoRead(); + return did_work; +} + +////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(OS_LINUX) +UnixSocketDeletionWatcher::UnixSocketDeletionWatcher() + : inotify_fd_(-1) { + pipes_[0] = -1; + pipes_[0] = -1; +} + +UnixSocketDeletionWatcher::~UnixSocketDeletionWatcher() { + close(inotify_fd_); + close(pipes_[0]); + close(pipes_[1]); +} + +bool UnixSocketDeletionWatcher::Start(const char *path, bool *flag_to_set) { + assert(inotify_fd_ == -1); + path_ = path; + flag_to_set_ = flag_to_set; + pid_ = getpid(); + inotify_fd_ = inotify_init1(IN_CLOEXEC | IN_NONBLOCK); + if (inotify_fd_ == -1) { + perror("inotify_init1() failed"); + return false; + } + if (inotify_add_watch(inotify_fd_, "/var/run/wireguard", IN_DELETE | IN_DELETE_SELF) == -1) { + perror("inotify_add_watch failed"); + return false; + } + if (pipe(pipes_) == -1) { + perror("pipe() failed"); + return false; + } + return pthread_create(&thread_, NULL, &UnixSocketDeletionWatcher::RunThread, this) == 0; +} + +void UnixSocketDeletionWatcher::Stop() { + RINFO("Stopping.."); + void *retval; + write(pipes_[1], "", 1); + pthread_join(thread_, &retval); +} + +void *UnixSocketDeletionWatcher::RunThread(void *arg) { + UnixSocketDeletionWatcher *self = (UnixSocketDeletionWatcher*)arg; + return self->RunThreadInner(); +} + +void 
*UnixSocketDeletionWatcher::RunThreadInner() { + char buf[sizeof(struct inotify_event) + NAME_MAX + 1] + __attribute__ ((aligned(__alignof__(struct inotify_event)))); + fd_set fdset; + struct stat st; + for(;;) { + if (lstat(path_, &st) == -1 && errno == ENOENT) { + RINFO("Unix socket %s deleted.", path_); + *flag_to_set_ = true; + kill(pid_, SIGALRM); + break; + } + FD_ZERO(&fdset); + FD_SET(inotify_fd_, &fdset); + FD_SET(pipes_[0], &fdset); + int n = select(std::max(inotify_fd_, pipes_[0]) + 1, &fdset, NULL, NULL, NULL); + if (n == -1) { + perror("select"); + break; + } + if (FD_ISSET(inotify_fd_, &fdset)) { + ssize_t len = read(inotify_fd_, buf, sizeof(buf)); + if (len == -1) { + perror("read"); + break; + } + } + if (FD_ISSET(pipes_[0], &fdset)) + break; + } + return NULL; +} + +#else // !defined(OS_LINUX) + +bool UnixSocketDeletionWatcher::Poll(const char *path) { + struct stat st; + return lstat(path, &st) == -1 && errno == ENOENT; +} + +#endif // !defined(OS_LINUX) + +UnixDomainSocketListenerBsd::UnixDomainSocketListenerBsd(NetworkBsd *network, WireguardProcessor *processor) + : BaseSocketBsd(network), + processor_(processor) { + memset(&un_addr_, 0, sizeof(un_addr_)); +} + +UnixDomainSocketListenerBsd::~UnixDomainSocketListenerBsd() { + if (un_addr_.sun_path[0]) + unlink(un_addr_.sun_path); +} + +bool UnixDomainSocketListenerBsd::Initialize(const char *devname) { + if (!HasFreePollSlot()) + return false; int fd = socket(AF_UNIX, SOCK_STREAM, 0); if (fd == -1) { RERROR("Error creating unix domain socket"); @@ -329,194 +662,408 @@ bool TunsafeBackendBsdImpl::InitializeUnixDomainSocket(const char *devname) { close(fd); return false; } - - pollfd_[kPollFdUnix].fd = fd; - pollfd_[kPollFdUnix].events = POLLIN; - + InitPollSlot(fd, POLLIN); return true; } +void UnixDomainSocketListenerBsd::HandleEvents(int revents) { + if (revents & POLLIN) { + // wait if we can't allocate more pollfd + if (!HasFreePollSlot()) { + SetPollFlags(0); + return; + } + int new_fd = 
accept(fd_, NULL, NULL); + if (new_fd >= 0) { + UnixDomainSocketChannelBsd *channel = new UnixDomainSocketChannelBsd(network_, processor_, new_fd); + } else { + RERROR("Unix domain socket accept failed"); + } + } + if (revents & ~POLLIN) { + RERROR("Unix domain socket got an error code"); + } +} + +void UnixDomainSocketListenerBsd::Periodic() { + if (un_deletion_watcher_.Poll(un_addr_.sun_path)) { + RINFO("Unix socket %s deleted.", un_addr_.sun_path); + network_->PostExit(); + } else { + // try again + SetPollFlags(POLLIN); + } +} + +////////////////////////////////////////////////////////////////////////////////////////////// + +UnixDomainSocketChannelBsd::UnixDomainSocketChannelBsd(NetworkBsd *network, WireguardProcessor *processor, int fd) + : BaseSocketBsd(network), + processor_(processor) { + assert(HasFreePollSlot()); + InitPollSlot(fd, POLLIN); +} + +UnixDomainSocketChannelBsd::~UnixDomainSocketChannelBsd() { +} + static const char *FindMessageEnd(const char *start, size_t size) { if (size <= 1) return NULL; const char *start_end = start + size - 1; - for(;(start = (const char*)memchr(start, '\n', start_end - start)) != NULL; start++) { + for (; (start = (const char*)memchr(start, '\n', start_end - start)) != NULL; start++) { if (start[1] == '\n') return start + 2; } return NULL; } -bool TunsafeBackendBsdImpl::HandleSpecialPollfd(struct pollfd *pfd, struct SockInfo *sockinfo) { - // handle domain socket thing - if (sockinfo->is_listener) { - if (pfd->revents & POLLIN) { - // wait if we can't allocate more pollfd - if (pollfd_num_ == kMaxPollFd) { - pfd->events = 0; - return true; - } - int fd = accept(pfd->fd, NULL, NULL); - if (fd >= 0) { - size_t slot = pollfd_num_++; - pollfd_[slot].fd = fd; - pollfd_[slot].events = POLLIN; - pollfd_[slot].revents = 0; - sockinfo_[slot - 2].is_listener = false; - } else { - RERROR("Unix domain socket accept failed"); - } - } - if (pfd->revents & ~POLLIN) { - RERROR("Unix domain socket got an error code"); - return false; 
- } - return true; - } - if (pfd->revents & POLLIN) { +bool UnixDomainSocketChannelBsd::HandleEventsInner(int revents) { + if (revents & POLLIN) { char buf[4096]; // read as much data as we can until we see \n\n - ssize_t n = recv(pfd->fd, buf, sizeof(buf), 0); + ssize_t n = recv(fd_, buf, sizeof(buf), 0); if (n <= 0) return (n == -1 && errno == EAGAIN); // premature eof or error - sockinfo->inbuf.append(buf, n); - const char *message_end = FindMessageEnd(&sockinfo->inbuf[0], sockinfo->inbuf.size()); + inbuf_.append(buf, n); + const char *message_end = FindMessageEnd(&inbuf_[0], inbuf_.size()); if (message_end) { - if (message_end != &sockinfo->inbuf[sockinfo->inbuf.size()]) + if (message_end != &inbuf_[inbuf_.size()]) return false; // trailing data? - WgConfig::HandleConfigurationProtocolMessage(processor_, std::move(sockinfo->inbuf), &sockinfo->outbuf); - if (!sockinfo->outbuf.size()) + WgConfig::HandleConfigurationProtocolMessage(processor_, std::move(inbuf_), &outbuf_); + if (!outbuf_.size()) return false; - pfd->revents = pfd->events = POLLOUT; + SetPollFlags(POLLOUT); + revents |= POLLOUT; } } - if (pfd->revents & POLLOUT) { - size_t n = send(pfd->fd, sockinfo->outbuf.data(), sockinfo->outbuf.size(), 0); + if (revents & POLLOUT) { + size_t n = send(fd_, outbuf_.data(), outbuf_.size(), 0); if (n <= 0) return (n == -1 && errno == EAGAIN); // premature eof or error - sockinfo->outbuf.erase(0, n); - if (!sockinfo->outbuf.size()) + outbuf_.erase(0, n); + if (!outbuf_.size()) return false; } - - if (pfd->revents & ~(POLLIN | POLLOUT)) { + if (revents & ~(POLLIN | POLLOUT)) { RERROR("Unix domain socket got an error code"); return false; } return true; } -void TunsafeBackendBsdImpl::CloseSpecialPollfd(size_t i) { - close(pollfd_[i].fd); - pollfd_[i].fd = -1; - sockinfo_[i - 2].inbuf.clear(); - sockinfo_[i - 2].outbuf.clear(); - pollfd_[i] = pollfd_[(size_t)pollfd_num_ - 1]; - std::swap(sockinfo_[i - 2], sockinfo_[(size_t)pollfd_num_ - 1 - 2]); - - // Can now allow 
more sockets? - if (pollfd_num_-- == kMaxPollFd && sockinfo_[kPollFdUnix - 2].is_listener) - pollfd_[kPollFdUnix].events = POLLIN; +void UnixDomainSocketChannelBsd::HandleEvents(int revents) { + if (!HandleEventsInner(revents)) + delete this; } -void TunsafeBackendBsdImpl::RunLoopInner() { - int free_packet_interval = 10; - int overload_ctr = 0; +////////////////////////////////////////////////////////////////////////////////////////////// - if (!un_deletion_watcher_.Start(un_addr_.sun_path, &exit_)) - return; +TcpSocketListenerBsd::TcpSocketListenerBsd(NetworkBsd *bsd, WireguardProcessor *processor) + : BaseSocketBsd(bsd), + processor_(processor) { - while (!exit_) { - int n = -1; +} - if (got_sig_alarm_) { - got_sig_alarm_ = false; +TcpSocketListenerBsd::~TcpSocketListenerBsd() { - if (un_deletion_watcher_.Poll(un_addr_.sun_path)) { - RINFO("Unix socket %s deleted.", un_addr_.sun_path); - break; - } - processor_->SecondLoop(); +} - if (free_packet_interval == 0) { - FreePackets(); - free_packet_interval = 10; - } - free_packet_interval--; +bool TcpSocketListenerBsd::Initialize(int port) { + if (!HasFreePollSlot()) + return false; - overload_ctr -= (overload_ctr != 0); + CloseSocket(); + + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { RERROR("Error listen socket"); return false; } + fcntl(fd, F_SETFD, FD_CLOEXEC); + fcntl(fd, F_SETFL, O_NONBLOCK); + + int optval = 1; + // setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &optval, sizeof(optval)); + setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)); + + sockaddr_in sin = {0}; + sin.sin_family = AF_INET; + sin.sin_port = htons(port); + sin.sin_addr.s_addr = INADDR_ANY; + if (bind(fd, (sockaddr*)&sin, sizeof(sin))) { + RERROR("Error binding socket on port %d", port); + close(fd); + return false; + } + if (listen(fd, 5)) { + RERROR("Error listen socket"); + close(fd); + return false; + } + RINFO("Started TCP listening socket on port %d", port); + InitPollSlot(fd, POLLIN); + return true; +} + +void 
TcpSocketListenerBsd::HandleEvents(int revents) { + if (revents & POLLIN) { + // wait if we can't allocate more pollfd + if (!HasFreePollSlot()) { + SetPollFlags(0); + return; } - -#if defined(OS_LINUX) || defined(OS_FREEBSD) - n = ppoll(pollfd_, pollfd_num_, NULL, &orig_signal_mask_); -#else - n = poll(pollfd_, pollfd_num_, -1); -#endif - if (n == -1) { - if (errno != EINTR) { - RERROR("poll failed"); - break; - } + IpAddr addr; + socklen_t len = sizeof(addr); + int new_fd = accept(fd_, (sockaddr*)&addr, &len); + if (new_fd >= 0) { + RINFO("Created new tcp socket"); + TcpSocketBsd *channel = new TcpSocketBsd(network_, processor_); + if (channel) + channel->InitializeIncoming(new_fd, addr); + else + close(new_fd); } else { - - if (pollfd_[kPollFdTun].revents & (POLLERR | POLLHUP | POLLNVAL)) { - if (pollfd_[kPollFdTun].revents & POLLERR) { - tun_interface_gone_ = true; - RERROR("Tun interface gone, closing."); - } else { - RERROR("Tun interface error %d, closing.", pollfd_[kPollFdTun].revents); - } - break; - } - tun_readable_ = (pollfd_[kPollFdTun].revents & POLLIN) != 0; - if (pollfd_[kPollFdTun].revents & POLLOUT) { - pollfd_[kPollFdTun].events = POLLIN; - tun_writable_ = true; - } - - if (pollfd_[kPollFdUdp].revents & (POLLERR | POLLHUP | POLLNVAL)) { - RERROR("UDP error %d, closing.", pollfd_[kPollFdUdp].revents); - break; - } - - udp_readable_ = (pollfd_[kPollFdUdp].revents & POLLIN) != 0; - if (pollfd_[kPollFdUdp].revents & POLLOUT) { - pollfd_[kPollFdUdp].events = POLLIN; - udp_writable_ = true; - } - - for(size_t i = 2; i < pollfd_num_; i++) { - if (pollfd_[i].revents && !HandleSpecialPollfd(&pollfd_[i], &sockinfo_[i - 2])) { - // Close the fd / discard the sockinfo - CloseSpecialPollfd(i); - i--; - } - } + RERROR("Unix domain socket accept failed"); } - - bool overload = (overload_ctr != 0); - - for(int loop = 0; ; loop++) { - // Whenever we don't finish set overload ctr. 
- if (loop == 256) { - overload_ctr = 4; - break; - } - bool more_work = false; - if (tun_queue_ != NULL && tun_writable_) more_work |= WriteToTun(); - if (udp_queue_ != NULL && udp_writable_) more_work |= WriteToUdp(); - if (tun_readable_) more_work |= ReadFromTun(); - if (udp_readable_) more_work |= ReadFromUdp(overload); - if (!more_work) - break; - } - - processor_->RunAllMainThreadScheduled(); - } - - un_deletion_watcher_.Stop(); + } } -TunsafeBackendBsd *CreateTunsafeBackendBsd() { - return new TunsafeBackendBsdImpl; +void TcpSocketListenerBsd::Periodic() { + SetPollFlags(POLLIN); +} +////////////////////////////////////////////////////////////////////////////////////////////// + +TcpSocketBsd::TcpSocketBsd(NetworkBsd *net, WireguardProcessor *processor) + : BaseSocketBsd(net), + readable_(false), + writable_(true), + endpoint_protocol_(0), + age(0), + handshake_attempts(0), + wqueue_(NULL), + wqueue_end_(&wqueue_), + wqueue_bytes_(0), + processor_(processor), + tcp_packet_handler_(&net->packet_pool_) { + // insert in network's linked list + next_ = net->tcp_sockets_; + net->tcp_sockets_ = this; + + network_->EnsureIovAllocated(); +} + +TcpSocketBsd::~TcpSocketBsd() { + // Unlink myself from the network's linked list. 
+ TcpSocketBsd **p = &network_->tcp_sockets_; + while (*p != this) p = &(*p)->next_; + *p = next_; + + RINFO("Destroyed tcp socket"); +} + +void TcpSocketBsd::InitializeIncoming(int fd, const IpAddr &addr) { + assert(fd_ == -1); + endpoint_protocol_ = kPacketProtocolTcp | kPacketProtocolIncomingConnection; + endpoint_ = addr; + InitPollSlot(fd, POLLIN); +} + +bool TcpSocketBsd::InitializeOutgoing(const IpAddr &addr) { + assert(fd_ == -1); + if (!HasFreePollSlot() || addr.sin.sin_family == 0) + return false; + + endpoint_protocol_ = kPacketProtocolTcp; + endpoint_ = addr; + writable_ = false; + + int fd = socket(addr.sin.sin_family, SOCK_STREAM, 0); + if (fd < 0) { perror("socket: outgoing tcp"); return false; } + fcntl(fd, F_SETFD, FD_CLOEXEC); + fcntl(fd, F_SETFL, O_NONBLOCK); + + char buf[kSizeOfAddress]; + RINFO("Connecting to tcp://%s:%d...", PrintIpAddr(endpoint_, buf), ReadBE16(&endpoint_.sin.sin_port)); + + if (connect(fd, (sockaddr*)&endpoint_.sin, + endpoint_.sin.sin_family == AF_INET ? sizeof(sockaddr_in) : sizeof(sockaddr_in6))) { + if (errno != EINPROGRESS) { + perror("connect: outgoing tcp"); + close(fd); + return false; + } + } + + InitPollSlot(fd, POLLOUT | POLLIN); + return true; +} + +void TcpSocketBsd::WritePacket(Packet *packet) { + assert(fd_ >= 0); + + tcp_packet_handler_.AddHeaderToOutgoingPacket(packet); + + Packet *old_value = wqueue_; + *wqueue_end_ = packet; + wqueue_end_ = &Packet_NEXT(packet); + packet->queue_next = NULL; + + AddToEndLoop(); + + wqueue_bytes_ += packet->size; + + // When many bytes have been queued, perform the write. 
+ if (writable_ && wqueue_bytes_ >= 32768) + DoWrite(); +} + +void TcpSocketBsd::HandleEvents(int revents) { + if (revents & (POLLERR | POLLHUP | POLLNVAL)) { + RINFO("TcpSocket error"); + CloseSocketAndDestroy(); + return; + } + + if (revents & POLLOUT) { + SetPollFlags(POLLIN); + AddToEndLoop(); + writable_ = true; + } + + if (revents & POLLIN) + DoRead(); +} + +void TcpSocketBsd::DoEndloop() { + if (writable_ && wqueue_) + DoWrite(); +} + +void TcpSocketBsd::DoRead() { + ssize_t bytes_read = readv(fd_, network_->iov_, NetworkBsd::kMaxIovec); + ssize_t bytes_read_org = bytes_read; + if (bytes_read < 0) { + if (errno != EAGAIN) { + RERROR("tcp readv says error code: %d", errno); + CloseSocketAndDestroy(); + } + return; + } + // Go through and read the packet structures that are ready and queue them up + NetworkBsd *net = network_; + for (size_t j = 0; bytes_read; j++) { + size_t m = std::min(bytes_read, net->iov_[j].iov_len); + Packet *p = net->iov_packets_[j]; + p->size = (int)m; + bytes_read -= m; + tcp_packet_handler_.QueueIncomingPacket(p); + net->ReallocateIov(j); + } + // Parse it all + while (Packet *p = tcp_packet_handler_.GetNextWireguardPacket()) { + p->protocol = endpoint_protocol_; + p->addr = endpoint_; + processor_->HandleUdpPacket(p, network_->overload_); + } + + if (tcp_packet_handler_.error() || bytes_read_org == 0) + CloseSocketAndDestroy(); +} + +void TcpSocketBsd::DoWrite() { + enum { kMaxIoWrite = 16 }; + struct iovec vecs[kMaxIoWrite]; + Packet *p = wqueue_; + size_t nvec = 0; + for (; p && nvec < kMaxIoWrite; nvec++, p = Packet_NEXT(p)) { + vecs[nvec].iov_base = p->data; + vecs[nvec].iov_len = p->size; + } + ssize_t n = writev(fd_, vecs, nvec); + + if (n < 0) { + if (errno != EAGAIN) { + RERROR("tcp writev says error code: %d", errno); + CloseSocketAndDestroy(); + } else { + writable_ = false; + SetPollFlags(POLLIN | POLLOUT); + } + return; + } + wqueue_bytes_ -= n; + // discard those initial n bytes worth of packets + size_t i = 0; + p = 
wqueue_; + while (n) { + if (n < p->size) { + p->data += n, p->size -= n; + break; + } + n -= p->size; + FreePacket(exch(p, Packet_NEXT(p))); + } + if (!(wqueue_ = p)) + wqueue_end_ = &wqueue_; +} + +void TcpSocketBsd::CloseSocketAndDestroy() { + delete this; +} + +////////////////////////////////////////////////////////////////////////////////////////////// +NotificationPipeBsd::NotificationPipeBsd(NetworkBsd *network) + : BaseSocketBsd(network), + injected_cb_(NULL) { + + if (!HasFreePollSlot()) + tunsafe_die("no free poll slots"); + +#if !defined(OS_MACOSX) + if (pipe2(pipe_fds_, O_CLOEXEC | O_NONBLOCK)) + tunsafe_die("pipe2 failed"); +#else + if (pipe(pipe_fds_)) + tunsafe_die("pipe failed"); + for (int i = 0; i < 2; i++) { + fcntl(pipe_fds_[i], F_SETFD, FD_CLOEXEC); + fcntl(pipe_fds_[i], F_SETFL, O_NONBLOCK); + } +#endif + + + InitPollSlot(pipe_fds_[0], POLLIN); +} + +NotificationPipeBsd::~NotificationPipeBsd() { +} + +void NotificationPipeBsd::InjectCallback(CallbackFunc *func, void *param) { + CallbackState *st = new CallbackState; + st->func = func; + st->param = param; + // todo: support multiple writers? + st->next = injected_cb_.exchange(NULL); + injected_cb_.exchange(st); + write(pipe_fds_[1], "", 1); +} + +void NotificationPipeBsd::Wakeup() { + write(pipe_fds_[1], "", 1); +} + +void NotificationPipeBsd::HandleEvents(int revents) { + if (revents & (POLLERR | POLLHUP | POLLNVAL)) { + RERROR("Error with pipe() polling"); + CloseSocket(); + } else if (revents & POLLIN) { + char tmp[64]; + read(fd_, tmp, sizeof(tmp)); + if (CallbackState *cb = injected_cb_.exchange(NULL)) { + do { + CallbackState *next = cb->next; + cb->func(cb->param); + cb = next; + } while (cb); + } + } + } diff --git a/network_bsd.h b/network_bsd.h new file mode 100644 index 0000000..6f2c711 --- /dev/null +++ b/network_bsd.h @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. 
+#ifndef TUNSAFE_NETWORK_BSD_H_ +#define TUNSAFE_NETWORK_BSD_H_ + +#include +#include +#include +#include +#include "network_common.h" + +class BaseSocketBsd; +class TcpSocketBsd; +class WireguardProcessor; +class Packet; + +class NetworkBsd { + friend class BaseSocketBsd; + friend class TcpSocketBsd; + friend class UdpSocketBsd; + friend class TunSocketBsd; +public: + enum { +#if defined(OS_ANDROID) + WithSigalarmSupport = 0, +#else + WithSigalarmSupport = 1 +#endif + }; + + class NetworkBsdDelegate { + public: + virtual void OnSecondLoop(uint64 now) {} + virtual void RunAllMainThreadScheduled() {} + }; + + explicit NetworkBsd(NetworkBsdDelegate *delegate, int max_sockets); + ~NetworkBsd(); + + void RunLoop(const sigset_t *sigmask); + void PostExit() { exit_ = true; } + + bool *exit_flag() { return &exit_; } + bool *sigalarm_flag() { return &sigalarm_flag_; } + + TcpSocketBsd *tcp_sockets() { return tcp_sockets_; } + bool overload() { return overload_; } +private: + void RemoveFromRoundRobin(int slot); + + void ReallocateIov(size_t i); + void EnsureIovAllocated(); + + Packet *read_packet_; + bool exit_; + bool overload_; + bool sigalarm_flag_; + + enum { + // This controls the max # of sockets we can support + kMaxIovec = 16, + }; + int num_sock_; + int num_roundrobin_; + int num_endloop_; + int max_sockets_; + + SimplePacketPool packet_pool_; + NetworkBsdDelegate *delegate_; + + struct pollfd *pollfd_; + BaseSocketBsd **sockets_; + BaseSocketBsd **roundrobin_; + BaseSocketBsd **endloop_; + + // Linked list of all tcp sockets + TcpSocketBsd *tcp_sockets_; + + struct iovec iov_[kMaxIovec]; + Packet *iov_packets_[kMaxIovec]; + +}; + +class BaseSocketBsd { + friend class NetworkBsd; +public: + BaseSocketBsd(NetworkBsd *network) : pollfd_slot_(-1), roundrobin_slot_(-1), endloop_slot_(-1), fd_(-1), network_(network) {} + virtual ~BaseSocketBsd(); + + virtual void HandleEvents(int revents) = 0; + + // Return |false| to remove socket from roundrobin list. 
+ virtual bool DoRoundRobin() { return false; } + virtual void DoEndloop() {} + virtual void Periodic() {} + + // Make sure this socket gets called during each round robin step. + void AddToRoundRobin(); + + // Make sure this sockets get called at the end of the loop + void AddToEndLoop(); + + int GetFd() { return fd_; } + +protected: + void SetPollFlags(int events) { + network_->pollfd_[pollfd_slot_].events = events; + } + void InitPollSlot(int fd, int events); + bool HasFreePollSlot() { return network_->num_sock_ != network_->max_sockets_; } + void CloseSocket(); + + NetworkBsd *network_; + int pollfd_slot_; + int roundrobin_slot_; + int endloop_slot_; + int fd_; +}; + +class TunSocketBsd : public BaseSocketBsd { +public: + explicit TunSocketBsd(NetworkBsd *network, WireguardProcessor *processor); + virtual ~TunSocketBsd(); + + bool Initialize(int fd); + + virtual void HandleEvents(int revents) override; + virtual bool DoRoundRobin() override; + + void WritePacket(Packet *packet); + + bool tun_interface_gone() const { return tun_interface_gone_; } + +private: + bool DoRead(); + bool DoWrite(); + + bool tun_readable_, tun_writable_; + bool tun_interface_gone_; + Packet *tun_queue_, **tun_queue_end_; + WireguardProcessor *processor_; +}; + +class UdpSocketBsd : public BaseSocketBsd { +public: + explicit UdpSocketBsd(NetworkBsd *network, WireguardProcessor *processor); + virtual ~UdpSocketBsd(); + + bool Initialize(int listen_port); + + virtual void HandleEvents(int revents) override; + virtual bool DoRoundRobin() override; + + bool DoRead(); + bool DoWrite(); + + void WritePacket(Packet *packet); + +private: + bool udp_readable_, udp_writable_; + Packet *udp_queue_, **udp_queue_end_; + WireguardProcessor *processor_; +}; + +#if defined(OS_LINUX) +// Keeps track of when the unix socket gets deleted +class UnixSocketDeletionWatcher { +public: + UnixSocketDeletionWatcher(); + ~UnixSocketDeletionWatcher(); + bool Start(const char *path, bool *flag_to_set); + void 
Stop(); + bool Poll(const char *path) { return false; } + +private: + static void *RunThread(void *arg); + void *RunThreadInner(); + const char *path_; + int inotify_fd_; + int pid_; + int pipes_[2]; + pthread_t thread_; + bool *flag_to_set_; +}; +#else // !defined(OS_LINUX) +// all other platforms that lack inotify +class UnixSocketDeletionWatcher { +public: + UnixSocketDeletionWatcher() {} + ~UnixSocketDeletionWatcher() {} + bool Start(const char *path, bool *flag_to_set) { return true; } + void Stop() {} + bool Poll(const char *path); +}; +#endif // !defined(OS_LINUX) + +class UnixDomainSocketListenerBsd : public BaseSocketBsd { +public: + explicit UnixDomainSocketListenerBsd(NetworkBsd *network, WireguardProcessor *processor); + virtual ~UnixDomainSocketListenerBsd(); + + bool Initialize(const char *devname); + + bool Start(bool *exit_flag) { + return un_deletion_watcher_.Start(un_addr_.sun_path, exit_flag); + } + void Stop() { un_deletion_watcher_.Stop(); } + + virtual void HandleEvents(int revents) override; + virtual void Periodic() override; +private: + struct sockaddr_un un_addr_; + WireguardProcessor *processor_; + UnixSocketDeletionWatcher un_deletion_watcher_; +}; + +class UnixDomainSocketChannelBsd : public BaseSocketBsd { +public: + explicit UnixDomainSocketChannelBsd(NetworkBsd *network, WireguardProcessor *processor, int fd); + virtual ~UnixDomainSocketChannelBsd(); + + virtual void HandleEvents(int revents) override; + +private: + bool HandleEventsInner(int revents); + WireguardProcessor *processor_; + std::string inbuf_, outbuf_; +}; + +class TcpSocketListenerBsd : public BaseSocketBsd { +public: + explicit TcpSocketListenerBsd(NetworkBsd *bsd, WireguardProcessor *processor); + virtual ~TcpSocketListenerBsd(); + + bool Initialize(int port); + + virtual void HandleEvents(int revents) override; + virtual void Periodic() override; + +private: + WireguardProcessor *processor_; +}; + +class TcpSocketBsd : public BaseSocketBsd { +public: + explicit 
TcpSocketBsd(NetworkBsd *bsd, WireguardProcessor *processor); + virtual ~TcpSocketBsd(); + + void InitializeIncoming(int fd, const IpAddr &addr); + bool InitializeOutgoing(const IpAddr &addr); + + void WritePacket(Packet *packet); + + virtual void HandleEvents(int revents) override; + virtual void DoEndloop() override; + + TcpSocketBsd *next() { return next_; } + uint8 endpoint_protocol() { return endpoint_protocol_; } + const IpAddr &endpoint() { return endpoint_; } + +public: + uint8 age; + uint8 handshake_attempts; +private: + void DoRead(); + void DoWrite(); + void CloseSocketAndDestroy(); + + bool readable_, writable_; + bool got_eof_; + uint8 endpoint_protocol_; + bool want_connect_; + + uint32 wqueue_bytes_; + Packet *wqueue_, **wqueue_end_; + TcpSocketBsd *next_; + WireguardProcessor *processor_; + TcpPacketHandler tcp_packet_handler_; + IpAddr endpoint_; +}; + +class NotificationPipeBsd : public BaseSocketBsd { +public: + NotificationPipeBsd(NetworkBsd *network); + ~NotificationPipeBsd(); + + typedef void CallbackFunc(void *x); + void InjectCallback(CallbackFunc *func, void *param); + void Wakeup(); + + virtual void HandleEvents(int revents) override; + +private: + struct CallbackState { + CallbackFunc *func; + void *param; + CallbackState *next; + }; + int pipe_fds_[2]; + std::atomic injected_cb_; +}; + + +#endif // TUNSAFE_NETWORK_BSD_H_ \ No newline at end of file diff --git a/network_bsd_common.h b/network_bsd_common.h deleted file mode 100644 index c4ef5ad..0000000 --- a/network_bsd_common.h +++ /dev/null @@ -1,101 +0,0 @@ -// SPDX-License-Identifier: AGPL-1.0-only -// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. 
-#ifndef TUNSAFE_NETWORK_BSD_COMMON_H_ -#define TUNSAFE_NETWORK_BSD_COMMON_H_ - -#include "netapi.h" -#include "wireguard.h" -#include "wireguard_config.h" -#include -#include - -struct RouteInfo { - uint8 family; - uint8 cidr; - uint8 ip[16]; - uint8 gw[16]; - std::string dev; -}; - -#if defined(OS_LINUX) -// Keeps track of when the unix socket gets deleted -class UnixSocketDeletionWatcher { -public: - UnixSocketDeletionWatcher(); - ~UnixSocketDeletionWatcher(); - bool Start(const char *path, bool *flag_to_set); - void Stop(); - bool Poll(const char *path) { return false; } - -private: - static void *RunThread(void *arg); - void *RunThreadInner(); - const char *path_; - int inotify_fd_; - int pid_; - int pipes_[2]; - pthread_t thread_; - bool *flag_to_set_; -}; -#else // !defined(OS_LINUX) -// all other platforms that lack inotify -class UnixSocketDeletionWatcher { -public: - UnixSocketDeletionWatcher() {} - ~UnixSocketDeletionWatcher() {} - bool Start(const char *path, bool *flag_to_set) { return true; } - void Stop() {} - bool Poll(const char *path); -}; -#endif // !defined(OS_LINUX) - - -class TunsafeBackendBsd : public TunInterface, public UdpInterface { -public: - TunsafeBackendBsd(); - virtual ~TunsafeBackendBsd(); - - void RunLoop(); - void CleanupRoutes(); - - void SetTunDeviceName(const char *name); - - void SetProcessor(WireguardProcessor *wg) { processor_ = wg; } - - // -- from TunInterface - virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override; - - virtual void HandleSigAlrm() = 0; - virtual void HandleExit() = 0; - -protected: - virtual bool InitializeTun(char devname[16]) = 0; - virtual void RunLoopInner() = 0; - - void AddRoute(uint32 ip, uint32 cidr, uint32 gw, const char *dev); - void DelRoute(const RouteInfo &cd); - bool AddRoute(int family, const void *dest, int dest_prefix, const void *gateway, const char *dev); - bool RunPrePostCommand(const std::vector &vec); - - WireguardProcessor *processor_; - std::vector 
cleanup_commands_; - std::vector pre_down_, post_down_; - std::vector addresses_to_remove_; - sigset_t orig_signal_mask_; - char devname_[16]; - bool tun_interface_gone_; -}; - -#if defined(OS_MACOSX) || defined(OS_FREEBSD) -#define TUN_PREFIX_BYTES 4 -#elif defined(OS_LINUX) -#define TUN_PREFIX_BYTES 0 -#endif - -int open_tun(char *devname, size_t devname_size); -int open_udp(int listen_on_port); - -void SetThreadName(const char *name); -TunsafeBackendBsd *CreateTunsafeBackendBsd(); - -#endif // TUNSAFE_NETWORK_BSD_COMMON_H_ \ No newline at end of file diff --git a/network_common.cpp b/network_common.cpp new file mode 100644 index 0000000..919e7fd --- /dev/null +++ b/network_common.cpp @@ -0,0 +1,174 @@ +#include "stdafx.h" +#include "network_common.h" +#include "netapi.h" +#include "tunsafe_endian.h" +#include +#include +#include "util.h" + +TcpPacketHandler::TcpPacketHandler(SimplePacketPool *packet_pool) { + packet_pool_ = packet_pool; + rqueue_bytes_ = 0; + error_flag_ = false; + rqueue_ = NULL; + rqueue_end_ = &rqueue_; + predicted_key_in_ = predicted_key_out_ = 0; + predicted_serial_in_ = predicted_serial_out_ = 0; +} + +TcpPacketHandler::~TcpPacketHandler() { + FreePacketList(rqueue_); +} + +enum { + kTcpPacketType_Normal = 0, + kTcpPacketType_Reserved = 1, + kTcpPacketType_Data = 2, + kTcpPacketType_Control = 3, + kTcpPacketControlType_SetKeyAndCounter = 0, +}; + +void TcpPacketHandler::AddHeaderToOutgoingPacket(Packet *p) { + unsigned int size = p->size; + uint8 *data = p->data; + if (size >= 16 && ReadLE32(data) == 4) { + uint32 key = Read32(data + 4); + uint64 serial = ReadLE64(data + 8); + WriteBE16(data + 14, size - 16 + (kTcpPacketType_Data << 14)); + data += 14, size -= 14; + // Insert a 15 byte control packet right before to set the new key/serial? 
+ if ((predicted_key_out_ ^ key) | (predicted_serial_out_ ^ serial)) { + predicted_key_out_ = key; + WriteLE64(data - 8, serial); + Write32(data - 12, key); + data[-13] = kTcpPacketControlType_SetKeyAndCounter; + WriteBE16(data - 15, 13 + (kTcpPacketType_Control << 14)); + data -= 15, size += 15; + } + // Increase the serial by 1 for next packet. + predicted_serial_out_ = serial + 1; + } else { + WriteBE16(data - 2, size); + data -= 2, size += 2; + } + p->size = size; + p->data = data; +} + +void TcpPacketHandler::QueueIncomingPacket(Packet *p) { + rqueue_bytes_ += p->size; + p->queue_next = NULL; + *rqueue_end_ = p; + rqueue_end_ = &Packet_NEXT(p); +} + +// Either the packet fits in one buf or not. +static uint32 ReadPacketHeader(Packet *p) { + if (p->size >= 2) + return ReadBE16(p->data); + else + return (p->data[0] << 8) + (Packet_NEXT(p)->data[0]); +} + +// Move data around to ensure that exactly the first |num| bytes are stored +// in the first packet, and the rest of the data in subsequent packets. +Packet *TcpPacketHandler::ReadNextPacket(uint32 num) { + Packet *p = rqueue_; + + assert(num <= kPacketCapacity); + if (p->size < num) { + // There's not enough data in the current packet, copy data from the next packet + // into this packet. + if ((uint32)(&p->data_buf[kPacketCapacity] - p->data) < num) { + // Move data up front to make space. + memmove(p->data_buf, p->data, p->size); + p->data = p->data_buf; + } + // Copy data from future packets into p, and delete them should they become empty. + do { + Packet *n = Packet_NEXT(p); + uint32 bytes_to_copy = std::min(n->size, num - p->size); + uint32 nsize = (n->size -= bytes_to_copy); + memcpy(p->data + postinc(p->size, bytes_to_copy), postinc(n->data, bytes_to_copy), bytes_to_copy); + if (nsize == 0) { + p->queue_next = n->queue_next; + packet_pool_->FreePacketToPool(n); + } + } while (num - p->size); + } else if (p->size > num) { + // The packet has too much data. Split the packet into two packets. 
+ Packet *n = packet_pool_->AllocPacketFromPool(); + if (!n) + return NULL; // unable to allocate a packet....? + if (num * 2 <= p->size) { + // There's a lot of trailing data: PP NNNNNN. Move PP. + n->size = num; + p->size -= num; + rqueue_bytes_ -= num; + memcpy(n->data, postinc(p->data, num), num); + return n; + } else { + uint32 overflow = p->size - num; + // There's a lot of leading data: PPPPPP NN. Move NN + n->size = overflow; + p->size = num; + rqueue_ = n; + if (!(n->queue_next = p->queue_next)) + rqueue_end_ = &Packet_NEXT(n); + rqueue_bytes_ -= num; + memcpy(n->data, p->data + num, overflow); + return p; + } + } + if ((rqueue_ = Packet_NEXT(p)) == NULL) + rqueue_end_ = &rqueue_; + rqueue_bytes_ -= num; + return p; +} + +Packet *TcpPacketHandler::GetNextWireguardPacket() { + while (rqueue_bytes_ >= 2) { + uint32 packet_header = ReadPacketHeader(rqueue_); + uint32 packet_size = packet_header & 0x3FFF; + uint32 packet_type = packet_header >> 14; + if (packet_size + 2 > rqueue_bytes_) + return NULL; + if (packet_size + 2 > kPacketCapacity) { + RERROR("Oversized packet?"); + error_flag_ = true; + return NULL; + } + Packet *packet = ReadNextPacket(packet_size + 2); + if (packet) { +// RINFO("Packet of type %d, size %d", packet_type, packet->size - 2); + packet->data += 2, packet->size -= 2; + if (packet_type == kTcpPacketType_Normal) { + + return packet; + } else if (packet_type == kTcpPacketType_Data) { + // Optimization when the 16 first bytes are known and prefixed to the packet + assert(packet->data >= packet->data_buf); + + packet->data -= 16, packet->size += 16; + WriteLE32(packet->data, 4); + Write32(packet->data + 4, predicted_key_in_); + WriteLE64(packet->data + 8, predicted_serial_in_); + predicted_serial_in_++; + return packet; + } else if (packet_type == kTcpPacketType_Control) { + // Unknown control packets are silently ignored + if (packet->size == 13 && packet->data[0] == kTcpPacketControlType_SetKeyAndCounter) { + // Control packet to setup the 
predicted key/sequence nr + predicted_key_in_ = Read32(packet->data + 1); + predicted_serial_in_ = ReadLE64(packet->data + 5); + } + packet_pool_->FreePacketToPool(packet); + } else { + packet_pool_->FreePacketToPool(packet); + error_flag_ = true; + return NULL; + } + } + } + return NULL; +} \ No newline at end of file diff --git a/network_common.h b/network_common.h new file mode 100644 index 0000000..8814cd4 --- /dev/null +++ b/network_common.h @@ -0,0 +1,95 @@ +#ifndef TUNSAFE_NETWORK_COMMON_H_ +#define TUNSAFE_NETWORK_COMMON_H_ + +#include "netapi.h" + +class PacketProcessor; + +// A simple singlethreaded pool of packets used on windows where +// FreePacket / AllocPacket are multithreded and thus slightly slower +#if defined(OS_WIN) +class SimplePacketPool { +public: + explicit SimplePacketPool() { + freed_packets_ = NULL; + freed_packets_count_ = 0; + } + ~SimplePacketPool() { + FreePacketList(freed_packets_); + } + Packet *AllocPacketFromPool() { + if (Packet *p = freed_packets_) { + freed_packets_ = Packet_NEXT(p); + freed_packets_count_--; + p->Reset(); + return p; + } + return AllocPacket(); + } + void FreePacketToPool(Packet *p) { + Packet_NEXT(p) = freed_packets_; + freed_packets_ = p; + freed_packets_count_++; + } + void FreeSomePackets() { + if (freed_packets_count_ > 32) + FreeSomePacketsInner(); + } + void FreeSomePacketsInner(); + + + int freed_packets_count_; + Packet *freed_packets_; +}; +#else +class SimplePacketPool { +public: + Packet *AllocPacketFromPool() { + return AllocPacket(); + } + void FreePacketToPool(Packet *packet) { + return FreePacket(packet); + } +}; +#endif + + + +// Aids with prefixing and parsing incoming and outgoing +// packets with the tcp protocol header. 
+class TcpPacketHandler { +public: + explicit TcpPacketHandler(SimplePacketPool *packet_pool); + ~TcpPacketHandler(); + + // Adds a tcp header to a data packet so it can be transmitted on the wire + void AddHeaderToOutgoingPacket(Packet *p); + + // Add a new chunk of incoming data to the packet list + void QueueIncomingPacket(Packet *p); + + // Attempt to extract the next packet, returns NULL when complete. + Packet *GetNextWireguardPacket(); + + bool error() const { return error_flag_; } + +private: + // Internal function to read a packet + Packet *ReadNextPacket(uint32 num); + + SimplePacketPool *packet_pool_; + + // Total # of bytes queued + uint32 rqueue_bytes_; + + // Set if there's a fatal error + bool error_flag_; + + // These hold the incoming packets before they're parsed + Packet *rqueue_, **rqueue_end_; + + uint32 predicted_key_in_, predicted_key_out_; + uint64 predicted_serial_in_, predicted_serial_out_; +}; + +#endif // TUNSAFE_NETWORK_COMMON_H_ \ No newline at end of file diff --git a/network_win32.cpp b/network_win32.cpp index 2a4bfe7..a946bc3 100644 --- a/network_win32.cpp +++ b/network_win32.cpp @@ -5,6 +5,8 @@ #include "wireguard_config.h" #include "netapi.h" #include +#include +#include #include #include #include @@ -12,7 +14,6 @@ #include #include #include -#include #include #include #include "tunsafe_endian.h" @@ -42,15 +43,20 @@ static HKEY g_hklm_reg_key; static uint8 g_killswitch_curr, g_killswitch_want, g_killswitch_currconn; bool g_allow_pre_post; +static volatile bool g_fail_malloc_flag; static void DeactivateKillSwitch(uint32 want); Packet *AllocPacket() { Packet *packet = (Packet*)InterlockedPopEntrySList(&freelist_head); - if (packet == NULL) - packet = (Packet *)_aligned_malloc(kPacketAllocSize, 16); - packet->data = packet->data_buf + Packet::HEADROOM_BEFORE; - packet->size = 0; + if (packet == NULL) { + while ((packet = (Packet *)_aligned_malloc(kPacketAllocSize, 16)) == NULL) { + if (g_fail_malloc_flag) + return NULL; + 
Sleep(1000); + } + } + packet->Reset(); return packet; } @@ -83,6 +89,14 @@ void FreeAllPackets() { } } +void SimplePacketPool::FreeSomePacketsInner() { + int n = freed_packets_count_ - 24; + Packet **p = &freed_packets_; + for (; n; n--) + p = &Packet_NEXT(*p); + FreePackets(exch(freed_packets_, *p), p, exch(freed_packets_count_, 24) - 24); +} + void InitPacketMutexes() { static bool mutex_inited; if (!mutex_inited) { @@ -91,17 +105,6 @@ void InitPacketMutexes() { } } -int tpq_last_qsize; -int g_tun_reads, g_tun_writes; - -struct { - uint32 pad1[3]; - uint32 udp_qsize1; - uint32 pad2[3]; - uint32 udp_qsize2; -} qs; - - #define kConcurrentReadTap 16 #define kConcurrentWriteTap 16 @@ -399,47 +402,13 @@ static bool GetDefaultRouteAndDeleteOldRoutes(int family, const NET_LUID *Interf return (rv == 0); } -static inline bool NoMoreAllocationRetry(volatile bool *exit_flag) { - if (*exit_flag) - return true; - Sleep(1000); - return *exit_flag; -} - -static inline bool AllocPacketFrom(Packet **list, int *counter, bool *exit_flag, Packet **res) { - Packet *p; - if (p = *list) { - *list = Packet_NEXT(p); - (*counter)--; - p->data = p->data_buf + Packet::HEADROOM_BEFORE; - } else { - while ((p = AllocPacket()) == NULL) { - if (NoMoreAllocationRetry(exit_flag)) - return false; - } - } - *res = p; - return true; -} - -static void FreePacketList(Packet *pp) { +void FreePacketList(Packet *pp) { while (Packet *p = pp) { pp = Packet_NEXT(p); FreePacket(p); } } -inline void NetworkWin32::FreePacketToPool(Packet *p) { - Packet_NEXT(p) = NULL; - *freed_packets_end_ = p; - freed_packets_end_ = &Packet_NEXT(p); - freed_packets_count_++; -} - -inline bool NetworkWin32::AllocPacketFromPool(Packet **p) { - return AllocPacketFrom(&freed_packets_, &freed_packets_count_, &exit_thread_, p); -} - UdpSocketWin32::UdpSocketWin32(NetworkWin32 *network_win32) { network_ = network_win32; wqueue_end_ = &wqueue_; @@ -455,6 +424,9 @@ UdpSocketWin32::UdpSocketWin32(NetworkWin32 *network_win32) { 
num_reads_[0] = num_reads_[1] = 0; num_writes_ = 0; pending_writes_ = NULL; + + qsize1_ = 0; + qsize2_ = 0; } UdpSocketWin32::~UdpSocketWin32() { @@ -529,12 +501,12 @@ fail: // Called on another thread to queue up a udp packet void UdpSocketWin32::WriteUdpPacket(Packet *packet) { - if (qs.udp_qsize2 - qs.udp_qsize1 >= (unsigned)(packet->size < 576 ? MAX_BYTES_IN_UDP_OUT_QUEUE_SMALL : MAX_BYTES_IN_UDP_OUT_QUEUE)) { + if (qsize2_ - qsize1_ >= (unsigned)(packet->size < 576 ? MAX_BYTES_IN_UDP_OUT_QUEUE_SMALL : MAX_BYTES_IN_UDP_OUT_QUEUE)) { FreePacket(packet); return; } - Packet_NEXT(packet) = NULL; - qs.udp_qsize2 += packet->size; + packet->queue_next = NULL; + qsize2_ += packet->size; mutex_.Acquire(); Packet *was_empty = wqueue_; @@ -542,20 +514,14 @@ void UdpSocketWin32::WriteUdpPacket(Packet *packet) { wqueue_end_ = &Packet_NEXT(packet); mutex_.Release(); - if (was_empty == NULL) { - // Notify the worker thread that it should attempt more writes - PostQueuedCompletionStatus(network_->completion_port_handle_, NULL, NULL, NULL); - } + if (was_empty == NULL) + network_->WakeUp(); } enum { kUdpGetQueuedCompletionStatusSize = kConcurrentWriteTap + kConcurrentReadTap + 1 }; -static inline void ClearOverlapped(OVERLAPPED *o) { - memset(o, 0, sizeof(*o)); -} - #ifndef STATUS_PORT_UNREACHABLE #define STATUS_PORT_UNREACHABLE 0xC000023F #endif @@ -567,8 +533,8 @@ static inline bool IsIgnoredUdpError(DWORD err) { void UdpSocketWin32::DoMoreReads() { // Listen with multiple ipv6 packets only if we ever sent an ipv6 packet. for (int i = num_reads_[IPV6]; i < max_read_ipv6_; i++) { - Packet *p; - if (!network_->AllocPacketFromPool(&p)) + Packet *p = network_->packet_pool().AllocPacketFromPool(); + if (!p) break; restart_read_udp6: ClearOverlapped(&p->overlapped); @@ -590,32 +556,35 @@ restart_read_udp6: num_reads_[IPV6]++; } // Initiate more reads, reusing the Packet structures in |finished_writes|. 
- for (int i = num_reads_[IPV4]; i < kConcurrentReadUdp; i++) { - Packet *p; - if (!network_->AllocPacketFromPool(&p)) - break; -restart_read_udp: - ClearOverlapped(&p->overlapped); - WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data}; - DWORD flags = 0; - p->userdata = IPV4; - p->sin_size = sizeof(p->addr.sin); - p->queue_cb = this; - if (WSARecvFrom(socket_, &wsabuf, 1, NULL, &flags, (struct sockaddr*)&p->addr, &p->sin_size, &p->overlapped, NULL) != 0) { - DWORD err = WSAGetLastError(); - if (err != WSA_IO_PENDING) { - if (err == WSAEMSGSIZE || err == WSAECONNRESET || err == WSAENETRESET) - goto restart_read_udp; - RERROR("UdpSocketWin32:WSARecvFrom failed 0x%X", err); - FreePacket(p); + + if (socket_ != INVALID_SOCKET) { + for (int i = num_reads_[IPV4]; i < kConcurrentReadUdp; i++) { + Packet *p = network_->packet_pool().AllocPacketFromPool(); + if (!p) break; +restart_read_udp: + ClearOverlapped(&p->overlapped); + WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data}; + DWORD flags = 0; + p->userdata = IPV4; + p->sin_size = sizeof(p->addr.sin); + p->queue_cb = this; + if (WSARecvFrom(socket_, &wsabuf, 1, NULL, &flags, (struct sockaddr*)&p->addr, &p->sin_size, &p->overlapped, NULL) != 0) { + DWORD err = WSAGetLastError(); + if (err != WSA_IO_PENDING) { + if (err == WSAEMSGSIZE || err == WSAECONNRESET || err == WSAENETRESET) + goto restart_read_udp; + RERROR("UdpSocketWin32:WSARecvFrom failed 0x%X", err); + FreePacket(p); + break; + } } + num_reads_[IPV4]++; } - num_reads_[IPV4]++; } } -void UdpSocketWin32::DoMoreWrites() { +void UdpSocketWin32::ProcessPackets() { // Push all the finished reads to the packet handler if (finished_reads_ != NULL) { packet_handler_->PostPackets(finished_reads_, finished_reads_end_, finished_reads_count_); @@ -623,7 +592,9 @@ void UdpSocketWin32::DoMoreWrites() { finished_reads_end_ = &finished_reads_; finished_reads_count_ = 0; } - +} + +void UdpSocketWin32::DoMoreWrites() { Packet *pending_writes = pending_writes_; // 
Initiate more writes from |wqueue_| while (num_writes_ < kConcurrentWriteUdp) { @@ -639,7 +610,7 @@ void UdpSocketWin32::DoMoreWrites() { if (!pending_writes) break; } - qs.udp_qsize1 += pending_writes->size; + qsize1_ += pending_writes->size; // Then issue writes Packet *p = pending_writes; @@ -688,13 +659,14 @@ void UdpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) { if (p->userdata < 2) { num_reads_[p->userdata]--; if ((DWORD)p->overlapped.Internal != 0) { - network_->FreePacketToPool(p); if (!IsIgnoredUdpError((DWORD)p->overlapped.Internal)) RERROR("UdpSocketWin32::Read error 0x%X", (DWORD)p->overlapped.Internal); + network_->packet_pool().FreePacketToPool(p); } else { // Remember all the finished packets and queue them up to the next thread once we've // collected them all. p->size = (int)p->overlapped.InternalHigh; + p->protocol = kPacketProtocolUdp; p->queue_cb = packet_handler_->udp_queue(); p->queue_next = NULL; *finished_reads_end_ = p; @@ -703,9 +675,9 @@ void UdpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) { } } else { num_writes_--; - network_->FreePacketToPool(p); if ((DWORD)p->overlapped.Internal != 0) RERROR("UdpSocketWin32::Write error 0x%X", (DWORD)p->overlapped.Internal); + network_->packet_pool().FreePacketToPool(p); } } @@ -716,28 +688,30 @@ void UdpSocketWin32::OnQueuedItemDelete(QueuedItem *qi) { } else { num_writes_--; } - network_->FreePacketToPool(p); + network_->packet_pool().FreePacketToPool(p); } void UdpSocketWin32::DoIO() { DoMoreWrites(); + ProcessPackets(); DoMoreReads(); } //////////////////////////////////////////////////////////////////////////////////////////////////////// -NetworkWin32::NetworkWin32() : udp_socket_(this) { +NetworkWin32::NetworkWin32() : udp_socket_(this), tcp_socket_queue_(this) { exit_thread_ = false; thread_ = NULL; - freed_packets_ = NULL; - freed_packets_end_ = &freed_packets_; - freed_packets_count_ = 0; completion_port_handle_ = 
CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, NULL, 0); + tcp_socket_ = NULL; } NetworkWin32::~NetworkWin32() { assert(thread_ == NULL); + + for (TcpSocketWin32 *socket = tcp_socket_; socket; ) + delete exch(socket, socket->next_); + CloseHandle(completion_port_handle_); - FreePacketList(freed_packets_); } DWORD WINAPI NetworkWin32::NetworkThread(void *x) { @@ -750,20 +724,13 @@ void NetworkWin32::ThreadMain() { OVERLAPPED_ENTRY entries[kUdpGetQueuedCompletionStatusSize]; while (!exit_thread_) { - // Run IO on all sockets queued for IO + // TODO: In the future, don't process every socket here, only + // those sockets that requested it. udp_socket_.DoIO(); + for (TcpSocketWin32 *tcp = tcp_socket_; tcp;) + exch(tcp, tcp->next_)->DoIO(); - // Free some packets - assert(freed_packets_count_ >= 0); - if (freed_packets_count_ >= 32) { - FreePackets(freed_packets_, freed_packets_end_, freed_packets_count_); - freed_packets_count_ = 0; - freed_packets_ = NULL; - freed_packets_end_ = &freed_packets_; - } else if (freed_packets_ == NULL) { - assert(freed_packets_count_ == 0); - freed_packets_end_ = &freed_packets_; - } + packet_pool_.FreeSomePackets(); ULONG num_entries = 0; if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, kUdpGetQueuedCompletionStatusSize, &num_entries, INFINITE, FALSE)) { @@ -779,20 +746,33 @@ void NetworkWin32::ThreadMain() { } udp_socket_.CancelAllIO(); + for (TcpSocketWin32 *tcp = tcp_socket_; tcp; tcp = tcp->next_) + tcp->CancelAllIO(); - while (udp_socket_.HasOutstandingIO()) { + while (HasOutstandingIO()) { ULONG num_entries = 0; - if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, 1, &num_entries, INFINITE, FALSE)) { + if (!GetQueuedCompletionStatusEx(completion_port_handle_, entries, kUdpGetQueuedCompletionStatusSize, &num_entries, INFINITE, FALSE)) { RINFO("GetQueuedCompletionStatusEx failed."); break; } - if (entries[0].lpOverlapped) { - QueuedItem *w = (QueuedItem*)((byte*)entries[0].lpOverlapped - 
offsetof(QueuedItem, overlapped)); - w->queue_cb->OnQueuedItemDelete(w); + for (ULONG i = 0; i < num_entries; i++) { + if (entries[i].lpOverlapped) { + QueuedItem *w = (QueuedItem*)((byte*)entries[i].lpOverlapped - offsetof(QueuedItem, overlapped)); + w->queue_cb->OnQueuedItemDelete(w); + } } } } +bool NetworkWin32::HasOutstandingIO() { + if (udp_socket_.HasOutstandingIO()) + return true; + for (TcpSocketWin32 *tcp = tcp_socket_; tcp; tcp = tcp->next_) + if (tcp->HasOutstandingIO()) + return true; + return false; +} + void NetworkWin32::StartThread() { assert(completion_port_handle_); @@ -804,11 +784,36 @@ void NetworkWin32::StartThread() { void NetworkWin32::StopThread() { if (thread_ != NULL) { exit_thread_ = true; + g_fail_malloc_flag = true; PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL); WaitForSingleObject(thread_, INFINITE); CloseHandle(thread_); thread_ = NULL; exit_thread_ = false; + g_fail_malloc_flag = false; + } +} + +void NetworkWin32::WakeUp() { + PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL); +} + +void NetworkWin32::PostQueuedItem(QueuedItem *item) { + PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, &item->overlapped); +} + +bool NetworkWin32::Configure(int listen_port, int listen_port_tcp) { + if (listen_port_tcp) + RERROR("ListenPortTCP not supported in this version"); + return udp_socket_.Configure(listen_port); +} + +// Called from tunsafe thread +void NetworkWin32::WriteUdpPacket(Packet *packet) { + if (packet->protocol & kPacketProtocolUdp) { + udp_socket_.WriteUdpPacket(packet); + } else { + tcp_socket_queue_.WritePacket(packet); } } @@ -874,6 +879,8 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) { mutex_.Acquire(); while (!(exit_code = exit_code_)) { + FreeAllPackets(); + if (timer_interrupt_) { timer_interrupt_ = false; need_notify_ = 0; @@ -912,7 +919,6 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) { 
need_notify_ = 0; mutex_.Release(); - tpq_last_qsize = packets_in_queue; if (packets_in_queue >= 1024) overload = 2; queue_context.overload = (overload != 0); @@ -986,8 +992,8 @@ void PacketProcessor::PostPackets(Packet *first, Packet **end, int count) { } void PacketProcessor::ForcePost(QueuedItem *item) { - mutex_.Acquire(); item->queue_next = NULL; + mutex_.Acquire(); packets_in_queue_ += 1; *last_ptr_ = item; last_ptr_ = &item->queue_next; @@ -1648,7 +1654,6 @@ bool TunWin32Adapter::RunPrePostCommand(const std::vector &vec) { return success; } - ////////////////////////////////////////////////////////////////////////////// TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) { @@ -1704,6 +1709,20 @@ enum { kTunGetQueuedCompletionStatusSize = kConcurrentWriteTap + kConcurrentReadTap + 1 }; +static inline bool AllocPacketFrom(Packet **list, int *counter, bool *exit_flag, Packet **res) { + Packet *p; + if (p = *list) { + *list = Packet_NEXT(p); + (*counter)--; + p->data = p->data_buf; + } else { + if (!(p = AllocPacket())) + return false; + } + *res = p; + return true; +} + void TunWin32Iocp::ThreadMain() { OVERLAPPED_ENTRY entries[kTunGetQueuedCompletionStatusSize]; Packet *pending_writes = NULL; @@ -1738,7 +1757,6 @@ void TunWin32Iocp::ThreadMain() { num_reads++; } } - g_tun_reads = num_reads; assert(freed_packets_count >= 0); if (freed_packets_count >= 32) { @@ -1820,7 +1838,6 @@ void TunWin32Iocp::ThreadMain() { num_writes++; } } - g_tun_writes = num_writes; } EXIT: @@ -1896,217 +1913,6 @@ void TunWin32Iocp::WriteTunPacket(Packet *packet) { ////////////////////////////////////////////////////////////////////////////// -TunWin32Overlapped::TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) { - wqueue_end_ = &wqueue_; - wqueue_ = NULL; - - thread_ = NULL; - - read_event_ = CreateEvent(NULL, TRUE, FALSE, NULL); 
- write_event_ = CreateEvent(NULL, TRUE, FALSE, NULL); - wake_event_ = CreateEvent(NULL, FALSE, FALSE, NULL); - - packet_handler_ = NULL; - exit_thread_ = false; -} - -TunWin32Overlapped::~TunWin32Overlapped() { - CloseTun(); - CloseHandle(read_event_); - CloseHandle(write_event_); - CloseHandle(wake_event_); -} - -bool TunWin32Overlapped::Configure(const TunConfig &&config, TunConfigOut *out) { - CloseTun(); - if (adapter_.OpenAdapter(backend_, FILE_FLAG_OVERLAPPED) && - adapter_.ConfigureAdapter(std::move(config), out)) - return true; - CloseTun(); - return false; -} - -void TunWin32Overlapped::CloseTun() { - assert(thread_ == NULL); - adapter_.CloseAdapter(false); - FreePacketList(wqueue_); - wqueue_ = NULL; - wqueue_end_ = &wqueue_; -} - -void TunWin32Overlapped::ThreadMain() { - Packet *pending_writes = NULL; - DWORD err; - Packet *read_packet = NULL, *write_packet = NULL; - - HANDLE h[3]; - while (!exit_thread_) { - if (read_packet == NULL) { - Packet *p = AllocPacket(); - ClearOverlapped(&p->overlapped); - p->overlapped.hEvent = read_event_; - if (!ReadFile(adapter_.handle(), p->data, kPacketCapacity, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { - FreePacket(p); - RERROR("TunWin32: ReadFile failed 0x%X", err); - } else { - read_packet = p; - } - } - - int n = 0; - if (write_packet) - h[n++] = write_event_; - if (read_packet != NULL) - h[n++] = read_event_; - h[n++] = wake_event_; - - DWORD res = WaitForMultipleObjects(n, h, FALSE, INFINITE); - - if (res >= WAIT_OBJECT_0 && res <= WAIT_OBJECT_0 + 2) { - HANDLE hx = h[res - WAIT_OBJECT_0]; - if (hx == read_event_) { - read_packet->size = (int)read_packet->overlapped.InternalHigh; - Packet_NEXT(read_packet) = NULL; - packet_handler_->PostPackets(read_packet, &Packet_NEXT(read_packet), 1); - read_packet = NULL; - } else if (hx == write_event_) { - FreePacket(write_packet); - write_packet = NULL; - } - } else { - RERROR("Wait said %d", res); - } - - if (write_packet == NULL) { - if 
(!pending_writes) { - mutex_.Acquire(); - pending_writes = wqueue_; - wqueue_end_ = &wqueue_; - wqueue_ = NULL; - mutex_.Release(); - } - if (pending_writes) { - // Then issue writes - Packet *p = pending_writes; - pending_writes = Packet_NEXT(p); - memset(&p->overlapped, 0, sizeof(p->overlapped)); - p->overlapped.hEvent = write_event_; - if (!WriteFile(adapter_.handle(), p->data, p->size, NULL, &p->overlapped) && (err = GetLastError()) != ERROR_IO_PENDING) { - RERROR("TunWin32: WriteFile failed 0x%X", err); - FreePacket(p); - } else { - write_packet = p; - } - } - } - } - - // TODO: Free memory - CancelIo(adapter_.handle()); - FreePacketList(pending_writes); -} - -DWORD WINAPI TunWin32Overlapped::TunThread(void *x) { - TunWin32Overlapped *xx = (TunWin32Overlapped *)x; - xx->ThreadMain(); - return 0; -} - -void TunWin32Overlapped::StartThread() { - DWORD thread_id; - thread_ = CreateThread(NULL, 0, &TunThread, this, 0, &thread_id); - SetThreadPriority(thread_, ABOVE_NORMAL_PRIORITY_CLASS); -} - -void TunWin32Overlapped::StopThread() { - exit_thread_ = true; - SetEvent(wake_event_); - WaitForSingleObject(thread_, INFINITE); - CloseHandle(thread_); - thread_ = NULL; -} - -void TunWin32Overlapped::WriteTunPacket(Packet *packet) { - Packet_NEXT(packet) = NULL; - mutex_.Acquire(); - Packet *was_empty = wqueue_; - *wqueue_end_ = packet; - wqueue_end_ = &Packet_NEXT(packet); - mutex_.Release(); - if (was_empty == NULL) - SetEvent(wake_event_); -} - -void TunsafeBackendWin32::SetPublicKey(const uint8 key[32]) { - memcpy(public_key_, key, 32); - delegate_->OnStateChanged(); -} - -DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) { - TunsafeBackendWin32 *backend = (TunsafeBackendWin32*)bk; - int stop_mode; - int fast_retry_ctr = 0; - - for(;;) { - TunWin32Iocp tun(&backend->dns_blocker_, backend); - NetworkWin32 net; - WireguardProcessor wg_proc(&net.udp(), &tun, backend); - - qs.udp_qsize1 = qs.udp_qsize2 = 0; - - 
net.udp().SetPacketHandler(&backend->packet_processor_); - tun.SetPacketHandler(&backend->packet_processor_); - - if (backend->config_file_[0] && - !ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->dns_resolver_)) - goto getout_fail; - - if (!wg_proc.Start()) - goto getout_fail; - - backend->SetPublicKey(wg_proc.dev().public_key()); - - backend->wg_processor_ = &wg_proc; - - net.StartThread(); - tun.StartThread(); - stop_mode = backend->packet_processor_.Run(&wg_proc, backend); - net.StopThread(); - tun.StopThread(); - - backend->wg_processor_ = NULL; - - // Keep DNS alive - if (stop_mode != MODE_EXIT) - tun.adapter().DisassociateDnsBlocker(); - else - backend->dns_resolver_.ClearCache(); - - FreeAllPackets(); - - if (stop_mode != MODE_TUN_FAILED) - return 0; - - uint32 last_fail = GetTickCount(); - fast_retry_ctr = (last_fail - backend->last_tun_adapter_failed_ < 5000) ? fast_retry_ctr + 1 : 0; - backend->last_tun_adapter_failed_ = last_fail; - - backend->SetStatus((fast_retry_ctr >= 3) ? 
TunsafeBackend::kErrorTunPermanent : TunsafeBackend::kStatusTunRetrying); - - if (backend->status_ == TunsafeBackend::kErrorTunPermanent) { - RERROR("Too many automatic restarts..."); - goto getout_fail_noseterr; - } - Sleep(1000); - } -getout_fail: - backend->status_ = TunsafeBackend::kErrorInitialize; - backend->delegate_->OnStatusCode(TunsafeBackend::kErrorInitialize); -getout_fail_noseterr: - backend->dns_blocker_.RestoreDns(); - return 0; -} - TunsafeBackend::TunsafeBackend() { is_started_ = false; is_remote_ = false; @@ -2152,6 +1958,75 @@ TunsafeBackendWin32::~TunsafeBackendWin32() { TunAdaptersInUse::GetInstance()->Release(this); } +void TunsafeBackendWin32::SetPublicKey(const uint8 key[32]) { + memcpy(public_key_, key, 32); + delegate_->OnStateChanged(); +} + +DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) { + TunsafeBackendWin32 *backend = (TunsafeBackendWin32*)bk; + int stop_mode; + int fast_retry_ctr = 0; + + for (;;) { + TunWin32Iocp tun(&backend->dns_blocker_, backend); + NetworkWin32 net; + WireguardProcessor wg_proc(&net, &tun, backend); + + net.udp().SetPacketHandler(&backend->packet_processor_); + net.tcp_socket_queue().SetPacketHandler(&backend->packet_processor_); + + tun.SetPacketHandler(&backend->packet_processor_); + + if (backend->config_file_[0] && + !ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->dns_resolver_)) + goto getout_fail; + + if (!wg_proc.Start()) + goto getout_fail; + + backend->SetPublicKey(wg_proc.dev().public_key()); + + backend->wg_processor_ = &wg_proc; + + net.StartThread(); + tun.StartThread(); + stop_mode = backend->packet_processor_.Run(&wg_proc, backend); + net.StopThread(); + tun.StopThread(); + + backend->wg_processor_ = NULL; + + // Keep DNS alive + if (stop_mode != MODE_EXIT) + tun.adapter().DisassociateDnsBlocker(); + else + backend->dns_resolver_.ClearCache(); + + FreeAllPackets(); + + if (stop_mode != MODE_TUN_FAILED) + return 0; + + uint32 last_fail = GetTickCount(); + 
fast_retry_ctr = (last_fail - backend->last_tun_adapter_failed_ < 5000) ? fast_retry_ctr + 1 : 0; + backend->last_tun_adapter_failed_ = last_fail; + + backend->SetStatus((fast_retry_ctr >= 3) ? TunsafeBackend::kErrorTunPermanent : TunsafeBackend::kStatusTunRetrying); + + if (backend->status_ == TunsafeBackend::kErrorTunPermanent) { + RERROR("Too many automatic restarts..."); + goto getout_fail_noseterr; + } + Sleep(1000); + } +getout_fail: + backend->status_ = TunsafeBackend::kErrorInitialize; + backend->delegate_->OnStatusCode(TunsafeBackend::kErrorInitialize); +getout_fail_noseterr: + backend->dns_blocker_.RestoreDns(); + return 0; +} void TunsafeBackendWin32::SetStatus(StatusCode status) { status_ = status; diff --git a/network_win32.h b/network_win32.h index 6b91ed0..5d79849 100644 --- a/network_win32.h +++ b/network_win32.h @@ -2,15 +2,14 @@ // Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. #pragma once -#include "stdafx.h" -#include "tunsafe_types.h" #include "netapi.h" #include "network_win32_api.h" #include "network_win32_dnsblock.h" #include "wireguard_config.h" #include "tunsafe_threading.h" #include "tunsafe_dnsresolve.h" -#include +#include "network_common.h" +#include "network_win32_tcp.h" enum { ADAPTER_GUID_SIZE = 40, @@ -18,6 +17,7 @@ enum { class WireguardProcessor; class TunsafeBackendWin32; +class DnsBlocker; struct PacketProcessorTunCb : QueuedItemCallback { virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override; @@ -73,16 +73,15 @@ class PacketAllocPool; // Encapsulates a UDP socket pair (ipv4 / ipv6), optionally listening for incoming packets // on a specific port. 
-class UdpSocketWin32 : public UdpInterface, QueuedItemCallback { +class UdpSocketWin32 : public QueuedItemCallback { public: explicit UdpSocketWin32(NetworkWin32 *network_win32); ~UdpSocketWin32(); void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; } - // -- from UdpInterface - virtual bool Configure(int listen_on_port) override; - virtual void WriteUdpPacket(Packet *packet) override; + bool Configure(int listen_on_port); + inline void WriteUdpPacket(Packet *packet); void DoIO(); void CancelAllIO(); @@ -94,9 +93,9 @@ public: }; private: - void DoMoreReads(); void DoMoreWrites(); + void ProcessPackets(); // From OverlappedCallbacks virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override; @@ -125,11 +124,16 @@ private: Packet *finished_reads_, **finished_reads_end_; int finished_reads_count_; + + __declspec(align(64)) uint32 qsize1_; + __declspec(align(64)) uint32 qsize2_; }; // Holds the thread for network communications -class NetworkWin32 { +class NetworkWin32 : public UdpInterface { friend class UdpSocketWin32; + friend class TcpSocketWin32; + friend class TcpSocketQueue; public: explicit NetworkWin32(); ~NetworkWin32(); @@ -138,13 +142,20 @@ public: void StopThread(); UdpSocketWin32 &udp() { return udp_socket_; } + SimplePacketPool &packet_pool() { return packet_pool_; } + TcpSocketQueue &tcp_socket_queue() { return tcp_socket_queue_; } + void WakeUp(); + void PostQueuedItem(QueuedItem *item); + + // -- from UdpInterface + virtual bool Configure(int listen_port_udp, int listen_port_tcp) override; + virtual void WriteUdpPacket(Packet *packet) override; private: void ThreadMain(); static DWORD WINAPI NetworkThread(void *x); - void FreePacketToPool(Packet *p); - bool AllocPacketFromPool(Packet **p); + bool HasOutstandingIO(); // The network thread handle HANDLE thread_; @@ -155,18 +166,17 @@ private: // The handle to the completion port HANDLE completion_port_handle_; - Packet *freed_packets_, 
**freed_packets_end_; - int freed_packets_count_; - // Right now there's always one udp socket only UdpSocketWin32 udp_socket_; + + // A linked list of all tcp sockets + TcpSocketWin32 *tcp_socket_; + + SimplePacketPool packet_pool_; + + TcpSocketQueue tcp_socket_queue_; }; - - - -class DnsBlocker; - class TunWin32Adapter { public: TunWin32Adapter(DnsBlocker *dns_blocker, const char guid[ADAPTER_GUID_SIZE]); @@ -243,42 +253,6 @@ private: TunWin32Adapter adapter_; }; -// Implementation of TUN interface handling using Overlapped IO -class TunWin32Overlapped : public TunInterface { -public: - explicit TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend); - ~TunWin32Overlapped(); - - void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; } - - void StartThread(); - void StopThread(); - - // -- from TunInterface - virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override; - virtual void WriteTunPacket(Packet *packet) override; - -private: - void CloseTun(); - void ThreadMain(); - static DWORD WINAPI TunThread(void *x); - - PacketProcessor *packet_handler_; - HANDLE thread_; - - Mutex mutex_; - - HANDLE read_event_, write_event_, wake_event_; - - bool exit_thread_; - - Packet *wqueue_, **wqueue_end_; - - TunWin32Adapter adapter_; - - TunsafeBackendWin32 *backend_; -}; - class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate { friend class PacketProcessor; friend class TunWin32Iocp; @@ -427,3 +401,8 @@ private: uint8 num_inuse_; Entry entry_[kMaxAdaptersInUse]; }; + +static inline void ClearOverlapped(OVERLAPPED *o) { + memset(o, 0, sizeof(*o)); +} + diff --git a/network_win32_api.h b/network_win32_api.h index aaff158..fb4b964 100644 --- a/network_win32_api.h +++ b/network_win32_api.h @@ -125,6 +125,3 @@ protected: TunsafeBackend *CreateNativeTunsafeBackend(TunsafeBackend::Delegate *delegate); TunsafeBackend::Delegate *CreateTunsafeBackendDelegateThreaded(TunsafeBackend::Delegate 
*delegate, const std::function &callback); - -extern int tpq_last_qsize; -extern int g_tun_reads, g_tun_writes; diff --git a/network_win32_tcp.cpp b/network_win32_tcp.cpp new file mode 100644 index 0000000..060c88f --- /dev/null +++ b/network_win32_tcp.cpp @@ -0,0 +1,344 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. +#include +#include "network_win32_tcp.h" +#include "network_win32.h" +#include +#include +#include "util.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////////// + +TcpSocketWin32::TcpSocketWin32(NetworkWin32 *network) + : tcp_packet_handler_(&network->packet_pool()) { + network_ = network; + reads_active_ = 0; + writes_active_ = 0; + handshake_attempts = 0; + state_ = STATE_NONE; + wqueue_ = NULL; + wqueue_end_ = &wqueue_; + socket_ = INVALID_SOCKET; + next_ = NULL; + packet_processor_ = NULL; + // insert in network's linked list + next_ = network->tcp_socket_; + network->tcp_socket_ = this; +} + +TcpSocketWin32::~TcpSocketWin32() { + // Unlink myself from the network's linked list. 
+ TcpSocketWin32 **p = &network_->tcp_socket_; + while (*p != this) p = &(*p)->next_; + *p = next_; + + FreePacketList(wqueue_); + if (socket_ != INVALID_SOCKET) + closesocket(socket_); +} + +void TcpSocketWin32::CloseSocket() { + if (socket_ != INVALID_SOCKET) + CancelIo((HANDLE)socket_); + state_ = STATE_ERROR; + endpoint_protocol_ = 0; +} + +void TcpSocketWin32::WritePacket(Packet *packet) { + packet->queue_next = NULL; + *wqueue_end_ = packet; + wqueue_end_ = &Packet_NEXT(packet); +} + +void TcpSocketWin32::CancelAllIO() { + if (socket_ != INVALID_SOCKET) + CancelIo((HANDLE)socket_); +} + +static const GUID WsaConnectExGUID = WSAID_CONNECTEX; + +void TcpSocketWin32::DoConnect() { + LPFN_CONNECTEX ConnectEx; + assert(socket_ == INVALID_SOCKET); + + socket_ = WSASocket(endpoint_.sin.sin_family, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED); + if (socket_ == INVALID_SOCKET) { + RERROR("socket() failed"); + CloseSocket(); + return; + } + + if (!CreateIoCompletionPort((HANDLE)socket_, network_->completion_port_handle_, 0, 0)) { + RERROR("TcpSocketWin32::DoConnect CreateIoCompletionPort failed"); + CloseSocket(); + return; + } + + int nodelay = 1; + setsockopt(socket_, IPPROTO_TCP, TCP_NODELAY, (char*)&nodelay, 1); + + DWORD dwBytes = sizeof(ConnectEx); + DWORD rc = WSAIoctl(socket_, SIO_GET_EXTENSION_FUNCTION_POINTER, (uint8*)&WsaConnectExGUID, sizeof(WsaConnectExGUID), &ConnectEx, sizeof(ConnectEx), &dwBytes, NULL, NULL); + assert(rc == 0); + + // ConnectEx requires the socket to be bound + sockaddr_in sin = {0}; + sin.sin_family = AF_INET; + sin.sin_addr.s_addr = INADDR_ANY; + sin.sin_port = 0; + if (bind(socket_, (sockaddr*)&sin, sizeof(sin))) { + RERROR("TcpSocketWin32::DoConnect bind failed: %d", WSAGetLastError()); + CloseSocket(); + return; + } + + char buf[kSizeOfAddress]; + RINFO("Connecting to tcp://%s...", PrintIpAddr(endpoint_, buf)); + + state_ = STATE_CONNECTING; + ClearOverlapped(&connect_overlapped_.overlapped); + connect_overlapped_.queue_cb = this; 
+ if (!ConnectEx(socket_, (const sockaddr*)&endpoint_.sin, sizeof(endpoint_.sin), NULL, 0, NULL, &connect_overlapped_.overlapped)) { + int err = WSAGetLastError(); + if (err != ERROR_IO_PENDING) { + RERROR("ConnectEx failed: %d", err); + CloseSocket(); + return; + } + } + reads_active_ = 1; +} + +void TcpSocketWin32::DoMoreReads() { + assert(state_ != STATE_ERROR); + if (reads_active_ == 0) { + // Initiate a new read, we always read into 4 buffers. + Packet *p = network_->packet_pool().AllocPacketFromPool(); + if (!p) + return; + + ClearOverlapped(&p->overlapped); + p->userdata = 0; + p->queue_cb = this; + DWORD flags = 0; + WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data}; + if (WSARecv(socket_, &wsabuf, 1, NULL, &flags, &p->overlapped, NULL) != 0) { + DWORD err = WSAGetLastError(); + if (err != ERROR_IO_PENDING) { + RERROR("TcpSocketWin32:WSARecv failed 0x%X", err); + FreePacket(p); + return; + } + } + reads_active_ = 1; + } +} + +void TcpSocketWin32::DoMoreWrites() { + assert(state_ != STATE_ERROR); + if (writes_active_ == 0) { + WSABUF wsabuf[kMaxWsaBuf]; + uint32 num_wsabuf = 0; + + Packet *p = wqueue_; + if (p == NULL) + return; + + do { + tcp_packet_handler_.AddHeaderToOutgoingPacket(p); + wsabuf[num_wsabuf].buf = (char*)p->data; + wsabuf[num_wsabuf].len = (ULONG)p->size; + packets_in_write_io_[num_wsabuf] = p; + p = Packet_NEXT(p); + } while (++num_wsabuf < kMaxWsaBuf && p != NULL); + if (!(wqueue_ = p)) + wqueue_end_ = &wqueue_; + num_wsabuf_ = (uint8)num_wsabuf; + + p = packets_in_write_io_[0]; + ClearOverlapped(&p->overlapped); + p->userdata = 1; + p->queue_cb = this; + + if (WSASend(socket_, wsabuf, num_wsabuf, NULL, 0, &p->overlapped, NULL) != 0) { + DWORD err = WSAGetLastError(); + if (err != ERROR_IO_PENDING) { + RERROR("TcpSocketWin32: WSASend failed 0x%X", err); + FreePacket(p); + CloseSocket(); + return; + } + } + writes_active_ = 1; + } +} + +void TcpSocketWin32::DoIO() { + if (state_ == STATE_CONNECTED) { + DoMoreReads(); + while (Packet 
*p = tcp_packet_handler_.GetNextWireguardPacket()) { + p->protocol = endpoint_protocol_; + p->addr = endpoint_; + + p->queue_cb = packet_processor_->udp_queue(); + packet_processor_->ForcePost(p); + } + if (tcp_packet_handler_.error()) { + CloseSocket(); + DoIO(); + return; + } + DoMoreWrites(); + } else if (state_ == STATE_WANT_CONNECT) { + DoConnect(); + } else if (state_ == STATE_ERROR && !HasOutstandingIO()) { + delete this; + } +} + +bool TcpSocketWin32::HasOutstandingIO() { + return writes_active_ + reads_active_ != 0; +} + +void TcpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) { + if (qi == &connect_overlapped_) { + assert(state_ == STATE_CONNECTING); + reads_active_ = 0; + if ((DWORD)qi->overlapped.Internal != 0) { + if (state_ != STATE_ERROR) { + RERROR("TcpSocketWin32::Connect error 0x%X", (DWORD)qi->overlapped.Internal); + CloseSocket(); + } + } else { + state_ = STATE_CONNECTED; + } + return; + } + Packet *p = static_cast(qi); + if (p->userdata == 0) { + // Read operation complete + if ((DWORD)p->overlapped.Internal != 0) { + if (state_ != STATE_ERROR) { + RERROR("TcpSocketWin32::Read error 0x%X", (DWORD)p->overlapped.Internal); + CloseSocket(); + } + network_->packet_pool().FreePacketToPool(p); + // What to do? 
+ } else if ((int)p->overlapped.InternalHigh == 0) { + // Socket closed successfully + CloseSocket(); + network_->packet_pool().FreePacketToPool(p); + } else { + // Queue it up to rqueue + p->size = (int)p->overlapped.InternalHigh; + tcp_packet_handler_.QueueIncomingPacket(p); + } + reads_active_--; + } else { + assert(writes_active_); + assert(packets_in_write_io_[0] == p); + + if ((DWORD)p->overlapped.Internal != 0) { + if (state_ != STATE_ERROR) { + RERROR("TcpSocketWin32::Write error 0x%X", (DWORD)p->overlapped.Internal); + CloseSocket(); + } + } + // free all the packets involved in the write + for (size_t i = 0; i < num_wsabuf_; i++) + network_->packet_pool().FreePacketToPool(packets_in_write_io_[i]); + writes_active_--; + } +} + +void TcpSocketWin32::OnQueuedItemDelete(QueuedItem *qi) { + if (qi == &connect_overlapped_) { + reads_active_ = 0; + return; + } + Packet *p = static_cast(qi); + if (p->userdata == 0) { + FreePacket(p); + reads_active_--; + } else { + for (size_t i = 0; i < num_wsabuf_; i++) + network_->packet_pool().FreePacketToPool(packets_in_write_io_[i]); + writes_active_--; + } +} + +///////////////////////////////////////////////////////////////////////// + +TcpSocketQueue::TcpSocketQueue(NetworkWin32 *network) { + network_ = network; + wqueue_ = NULL; + wqueue_end_ = &wqueue_; + queued_item_.queue_cb = this; + packet_handler_ = NULL; +} + +TcpSocketQueue::~TcpSocketQueue() { + FreePacketList(wqueue_); +} + +void TcpSocketQueue::TransmitOnePacket(Packet *packet) { + // Check if we have a tcp connection for the endpoint, otherwise create one. + for (TcpSocketWin32 *tcp = network_->tcp_socket_; tcp; tcp = tcp->next_) { + // After we send 3 handshakes on a tcp socket in a row, then close and reopen the socket because it seems defunct. 
+ if (CompareIpAddr(&tcp->endpoint_, &packet->addr) == 0 && tcp->endpoint_protocol_ == packet->protocol) { + if (ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION) { + if (tcp->handshake_attempts == 2) { + RINFO("Making new Tcp socket due to too many handshake failures"); + tcp->CloseSocket(); + break; + } + tcp->handshake_attempts++; + } else { + tcp->handshake_attempts = -1; + } + tcp->WritePacket(packet); + return; + } + } + + // Drop tcp packet that's for an incoming connection, or packets that are + // not a handshake. + if ((packet->protocol & kPacketProtocolIncomingConnection) || + packet->size < 4 || ReadLE32(packet->data) != MESSAGE_HANDSHAKE_INITIATION) { + FreePacket(packet); + return; + } + + // Initialize a new tcp socket and connect to the endpoint + TcpSocketWin32 *tcp = new TcpSocketWin32(network_); + tcp->state_ = TcpSocketWin32::STATE_WANT_CONNECT; + tcp->endpoint_ = packet->addr; + tcp->endpoint_protocol_ = kPacketProtocolTcp; + tcp->SetPacketHandler(packet_handler_); + tcp->WritePacket(packet); +} + +void TcpSocketQueue::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) { + wqueue_mutex_.Acquire(); + Packet *packet = wqueue_; + wqueue_ = NULL; + wqueue_end_ = &wqueue_; + wqueue_mutex_.Release(); + while (packet) + TransmitOnePacket(exch(packet, Packet_NEXT(packet))); +} + +void TcpSocketQueue::OnQueuedItemDelete(QueuedItem *ow) { + +} + +void TcpSocketQueue::WritePacket(Packet *packet) { + packet->queue_next = NULL; + wqueue_mutex_.Acquire(); + Packet *was_empty = wqueue_; + *wqueue_end_ = packet; + wqueue_end_ = &Packet_NEXT(packet); + wqueue_mutex_.Release(); + if (was_empty == NULL) + network_->PostQueuedItem(&queued_item_); +} diff --git a/network_win32_tcp.h b/network_win32_tcp.h new file mode 100644 index 0000000..0808f98 --- /dev/null +++ b/network_win32_tcp.h @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. 
+ +#pragma once + +#include "netapi.h" +#include "network_common.h" +#include "tunsafe_threading.h" + +class NetworkWin32; +class PacketProcessor; + +class TcpSocketWin32 : public QueuedItemCallback { + friend class NetworkWin32; + friend class TcpSocketQueue; +public: + explicit TcpSocketWin32(NetworkWin32 *network); + ~TcpSocketWin32(); + + void SetPacketHandler(PacketProcessor *packet_handler) { packet_processor_ = packet_handler; } + + // Write a packet to the TCP socket. This may be called only from the + // wireguard thread. Will append to a buffer and schedule it to be written + // from the network thread. + void WritePacket(Packet *packet); + + // Call from IO completion thread to cancel all outstanding IO + void CancelAllIO(); + + // Call from IO completion thread to run more IO + void DoIO(); + + // Returns true if there's IO still left to run + bool HasOutstandingIO(); + +private: + void DoMoreReads(); + void DoMoreWrites(); + void DoConnect(); + + void CloseSocket(); + + // From OverlappedCallbacks + virtual void OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) override; + virtual void OnQueuedItemDelete(QueuedItem *qi) override; + + // Network subsystem + NetworkWin32 *network_; + + PacketProcessor *packet_processor_; + + enum { + STATE_NONE = 0, + STATE_ERROR = 1, + STATE_CONNECTING = 2, + STATE_CONNECTED = 3, + STATE_WANT_CONNECT = 4, + }; + + uint8 reads_active_; + uint8 writes_active_; + uint8 state_; + uint8 num_wsabuf_; + +public: + uint8 handshake_attempts; +private: + + // The handle to the socket + SOCKET socket_; + + // Packets taken over by the network thread waiting to be written, + // when these are written we'll start eating from wqueue_ + Packet *pending_writes_; + + // All packets queued for writing on the network thread. 
+ Packet *wqueue_, **wqueue_end_; + + // Linked list of all TcpSocketWin32 wsockets + TcpSocketWin32 *next_; + + // Handles packet parsing + TcpPacketHandler tcp_packet_handler_; + + // An overlapped instance used for the initial Connect() call. + QueuedItem connect_overlapped_; + + IpAddr endpoint_; + uint8 endpoint_protocol_; + + // Packets currently involved in the wsabuf writing + enum { kMaxWsaBuf = 32 }; + Packet *packets_in_write_io_[kMaxWsaBuf]; +}; + +class TcpSocketQueue : public QueuedItemCallback { +public: + explicit TcpSocketQueue(NetworkWin32 *network); + ~TcpSocketQueue(); + + void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; } + + virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override; + virtual void OnQueuedItemDelete(QueuedItem *ow) override; + + void WritePacket(Packet *packet); + +private: + void TransmitOnePacket(Packet *packet); + NetworkWin32 *network_; + + // All packets queued for writing on the network thread. Locked by |wqueue_mutex_| + Packet *wqueue_, **wqueue_end_; + + PacketProcessor *packet_handler_; + + // Protects wqueue_ + Mutex wqueue_mutex_; + + // Used for queueing things on the network instance + QueuedItem queued_item_; + +}; + diff --git a/tunsafe_amalgam.cpp b/tunsafe_amalgam.cpp index 9d0c94e..9a28774 100644 --- a/tunsafe_amalgam.cpp +++ b/tunsafe_amalgam.cpp @@ -23,7 +23,7 @@ #if defined(WITH_NETWORK_BSD) #include "network_bsd.cpp" -#include "network_bsd_common.cpp" +#include "tunsafe_bsd.cpp" #include "ts.cpp" #include "benchmark.cpp" #endif diff --git a/network_bsd_common.cpp b/tunsafe_bsd.cpp similarity index 72% rename from network_bsd_common.cpp rename to tunsafe_bsd.cpp index b2bf710..ded6515 100644 --- a/network_bsd_common.cpp +++ b/tunsafe_bsd.cpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: AGPL-1.0-only // Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. 
-#include "network_bsd_common.h" +#include "tunsafe_bsd.h" #include "tunsafe_endian.h" #include "util.h" @@ -43,17 +43,6 @@ #include #endif -void tunsafe_die(const char *msg) { - fprintf(stderr, "%s\n", msg); - exit(1); -} - -void SetThreadName(const char *name) { -#if defined(OS_LINUX) - prctl(PR_SET_NAME, name, 0, 0, 0); -#endif // defined(OS_LINUX) -} - #if defined(OS_MACOSX) || defined(OS_FREEBSD) struct MyRouteMsg { struct rt_msghdr hdr; @@ -346,21 +335,7 @@ int open_tun(char *devname, size_t devname_size) { } #endif -int open_udp(int listen_on_port) { - int udp_fd = socket(AF_INET, SOCK_DGRAM, 0); - if (udp_fd < 0) return udp_fd; - sockaddr_in sin = {0}; - sin.sin_family = AF_INET; - sin.sin_port = htons(listen_on_port); - if (bind(udp_fd, (struct sockaddr*)&sin, sizeof(sin)) != 0) { - close(udp_fd); - return -1; - } - return udp_fd; -} - -TunsafeBackendBsd::TunsafeBackendBsd() - : processor_(NULL) { +TunsafeBackendBsd::TunsafeBackendBsd() { devname_[0] = 0; tun_interface_gone_ = false; } @@ -579,122 +554,33 @@ bool TunsafeBackendBsd::RunPrePostCommand(const std::vector &vec) { return success; } -#if defined(OS_LINUX) -UnixSocketDeletionWatcher::UnixSocketDeletionWatcher() - : inotify_fd_(-1) { - pipes_[0] = -1; - pipes_[0] = -1; -} - -UnixSocketDeletionWatcher::~UnixSocketDeletionWatcher() { - close(inotify_fd_); - close(pipes_[0]); - close(pipes_[1]); -} - -bool UnixSocketDeletionWatcher::Start(const char *path, bool *flag_to_set) { - assert(inotify_fd_ == -1); - path_ = path; - flag_to_set_ = flag_to_set; - pid_ = getpid(); - inotify_fd_ = inotify_init1(IN_CLOEXEC | IN_NONBLOCK); - if (inotify_fd_ == -1) { - perror("inotify_init1() failed"); - return false; - } - if (inotify_add_watch(inotify_fd_, "/var/run/wireguard", IN_DELETE | IN_DELETE_SELF) == -1) { - perror("inotify_add_watch failed"); - return false; - } - if (pipe(pipes_) == -1) { - perror("pipe() failed"); - return false; - } - return pthread_create(&thread_, NULL, 
&UnixSocketDeletionWatcher::RunThread, this) == 0; -} - -void UnixSocketDeletionWatcher::Stop() { - RINFO("Stopping.."); - void *retval; - write(pipes_[1], "", 1); - pthread_join(thread_, &retval); -} - -void *UnixSocketDeletionWatcher::RunThread(void *arg) { - UnixSocketDeletionWatcher *self = (UnixSocketDeletionWatcher*)arg; - return self->RunThreadInner(); -} - -void *UnixSocketDeletionWatcher::RunThreadInner() { - char buf[sizeof(struct inotify_event) + NAME_MAX + 1] - __attribute__ ((aligned(__alignof__(struct inotify_event)))); - fd_set fdset; - struct stat st; - for(;;) { - if (lstat(path_, &st) == -1 && errno == ENOENT) { - RINFO("Unix socket %s deleted.", path_); - *flag_to_set_ = true; - kill(pid_, SIGALRM); - break; - } - FD_ZERO(&fdset); - FD_SET(inotify_fd_, &fdset); - FD_SET(pipes_[0], &fdset); - int n = select(std::max(inotify_fd_, pipes_[0]) + 1, &fdset, NULL, NULL, NULL); - if (n == -1) { - perror("select"); - break; - } - if (FD_ISSET(inotify_fd_, &fdset)) { - ssize_t len = read(inotify_fd_, buf, sizeof(buf)); - if (len == -1) { - perror("read"); - break; - } - } - if (FD_ISSET(pipes_[0], &fdset)) - break; - } - return NULL; -} - -#else // !defined(OS_LINUX) - -bool UnixSocketDeletionWatcher::Poll(const char *path) { - struct stat st; - return lstat(path, &st) == -1 && errno == ENOENT; -} - -#endif // !defined(OS_LINUX) - -static TunsafeBackendBsd *g_tunsafe_backend_bsd; - -static void SigAlrm(int sig) { - if (g_tunsafe_backend_bsd) - g_tunsafe_backend_bsd->HandleSigAlrm(); -} +static SignalCatcher *g_signal_catcher; static bool did_ctrlc; -void SigInt(int sig) { +void SignalCatcher::SigAlrm(int sig) { + if (g_signal_catcher) + *g_signal_catcher->sigalarm_flag_ = true; +} + +void SignalCatcher::SigInt(int sig) { if (did_ctrlc) exit(1); did_ctrlc = true; - write(1, "Ctrl-C detected. Exiting. Press again to force quit.\n", sizeof("Ctrl-C detected. Exiting. Press again to force quit.\n")-1); - + write(1, "Ctrl-C detected. Exiting. 
Press again to force quit.\n", sizeof("Ctrl-C detected. Exiting. Press again to force quit.\n") - 1); // todo: fix signal safety? - if (g_tunsafe_backend_bsd) - g_tunsafe_backend_bsd->HandleExit(); + if (g_signal_catcher) + *g_signal_catcher->exit_flag_ = true; } -void TunsafeBackendBsd::RunLoop() { - assert(!g_tunsafe_backend_bsd); - assert(processor_); +SignalCatcher::SignalCatcher(bool *exit_flag, bool *sigalarm_flag) { + assert(g_signal_catcher == NULL); + exit_flag_ = exit_flag; + sigalarm_flag_ = sigalarm_flag; + g_signal_catcher = this; sigset_t mask; - g_tunsafe_backend_bsd = this; - // We want an alarm signal every second. { struct sigaction act = {0}; @@ -713,7 +599,6 @@ void TunsafeBackendBsd::RunLoop() { return; } } - #if defined(OS_LINUX) || defined(OS_FREEBSD) sigemptyset(&mask); sigaddset(&mask, SIGALRM); @@ -737,7 +622,7 @@ void TunsafeBackendBsd::RunLoop() { if (timer_create(CLOCK_MONOTONIC, &sev, &timer_id) < 0) { RERROR("timer_create failed"); return; - } + } if (timer_settime(timer_id, 0, &tv, NULL) < 0) { RERROR("timer_settime failed"); @@ -747,51 +632,209 @@ void TunsafeBackendBsd::RunLoop() { #elif defined(OS_MACOSX) ualarm(1000000, 1000000); #endif +} - RunLoopInner(); - - g_tunsafe_backend_bsd = NULL; +SignalCatcher::~SignalCatcher() { + g_signal_catcher = NULL; } void InitCpuFeatures(); void Benchmark(); - const char *print_ip(char buf[kSizeOfAddress], in_addr_t ip) { snprintf(buf, kSizeOfAddress, "%d.%d.%d.%d", (ip >> 24) & 0xff, (ip >> 16) & 0xff, (ip >> 8) & 0xff, (ip >> 0) & 0xff); return buf; } -class MyProcessorDelegate : public ProcessorDelegate { +class TunsafeBackendBsdImpl : public TunsafeBackendBsd, public NetworkBsd::NetworkBsdDelegate, public ProcessorDelegate { public: - MyProcessorDelegate() { - wg_processor_ = NULL; - is_connected_ = false; - } + TunsafeBackendBsdImpl(); + virtual ~TunsafeBackendBsdImpl(); - virtual void OnConnected() override { - if (!is_connected_) { - const WgCidrAddr *ipv4_addr = NULL; - for (const 
WgCidrAddr &x : wg_processor_->addr()) { - if (x.size == 32) { ipv4_addr = &x; break; } - } - uint32 ipv4_ip = ipv4_addr ? ReadBE32(ipv4_addr->addr) : 0; - char buf[kSizeOfAddress]; - RINFO("Connection established. IP %s", ipv4_ip ? print_ip(buf, ipv4_ip) : "(none)"); - is_connected_ = true; - } - } - virtual void OnConnectionRetry(uint32 attempts) override { - if (is_connected_ && attempts >= 3) { - is_connected_ = false; - RINFO("Reconnecting..."); - } - } + void RunLoop(); + virtual bool InitializeTun(char devname[16]) override; + + // -- from TunInterface + virtual void WriteTunPacket(Packet *packet) override; + + // -- from UdpInterface + virtual bool Configure(int listen_port_udp, int listen_port_tcp) override; + virtual void WriteUdpPacket(Packet *packet) override; + + // -- from NetworkBsdDelegate + virtual void OnSecondLoop(uint64 now) override; + virtual void RunAllMainThreadScheduled() override; + + // -- from ProcessorDelegate + virtual void OnConnected() override; + virtual void OnConnectionRetry(uint32 attempts) override; + + WireguardProcessor *processor() { return &processor_; } + +private: + void WriteTcpPacket(Packet *packet); + + // Close all TCP connections that are not pointed to by any of the peer endpoint. 
+ void CloseOrphanTcpConnections(); - WireguardProcessor *wg_processor_; bool is_connected_; + uint8 close_orphan_counter_; + WireguardProcessor processor_; + NetworkBsd network_; + TunSocketBsd tun_; + UdpSocketBsd udp_; + UnixDomainSocketListenerBsd unix_socket_listener_; + TcpSocketListenerBsd tcp_socket_listener_; }; +TunsafeBackendBsdImpl::TunsafeBackendBsdImpl() + : is_connected_(false), + close_orphan_counter_(0), + processor_(this, this, this), + network_(this, 1000), + tun_(&network_, &processor_), + udp_(&network_, &processor_), + unix_socket_listener_(&network_, &processor_), + tcp_socket_listener_(&network_, &processor_) { +} + +TunsafeBackendBsdImpl::~TunsafeBackendBsdImpl() { +} + +bool TunsafeBackendBsdImpl::InitializeTun(char devname[16]) { + int tun_fd = open_tun(devname, 16); + if (tun_fd < 0) { RERROR("Error opening tun device"); return false; } + if (!tun_.Initialize(tun_fd)) { + close(tun_fd); + return false; + } + unix_socket_listener_.Initialize(devname); + return true; +} + +void TunsafeBackendBsdImpl::WriteTunPacket(Packet *packet) { + tun_.WritePacket(packet); +} + +// Called to initialize udp +bool TunsafeBackendBsdImpl::Configure(int listen_port, int listen_port_tcp) { + return udp_.Initialize(listen_port) && + (listen_port_tcp == 0 || tcp_socket_listener_.Initialize(listen_port_tcp)); +} + +void TunsafeBackendBsdImpl::WriteTcpPacket(Packet *packet) { + // Check if we have a tcp connection for the endpoint, otherwise create one. + for (TcpSocketBsd *tcp = network_.tcp_sockets(); tcp; tcp = tcp->next()) { + // After we send 3 handshakes on a tcp socket in a row, then close and reopen the socket because it seems defunct. 
+ if (CompareIpAddr(&tcp->endpoint(), &packet->addr) == 0 && tcp->endpoint_protocol() == packet->protocol) { + if (ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION) { + if (tcp->handshake_attempts == 2) { + RINFO("Making new Tcp socket due to too many handshake failures"); + delete tcp; + break; + } + tcp->handshake_attempts++; + } else { + tcp->handshake_attempts = -1; + } + tcp->WritePacket(packet); + return; + } + } + // Drop tcp packet that's for an incoming connection, or packets that are + // not a handshake. + if ((packet->protocol & kPacketProtocolIncomingConnection) || + ReadLE32(packet->data) != MESSAGE_HANDSHAKE_INITIATION) { + FreePacket(packet); + return; + } + // Initialize a new tcp socket and connect to the endpoint + TcpSocketBsd *tcp = new TcpSocketBsd(&network_, &processor_); + if (!tcp || !tcp->InitializeOutgoing(packet->addr)) { + delete tcp; + FreePacket(packet); + return; + } + tcp->WritePacket(packet); +} + +void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) { + assert((packet->protocol & 0x7F) <= 2); + if (packet->protocol & kPacketProtocolTcp) { + WriteTcpPacket(packet); + } else { + udp_.WritePacket(packet); + } +} + +void TunsafeBackendBsdImpl::RunLoop() { + if (!unix_socket_listener_.Start(network_.exit_flag())) + return; + + SignalCatcher signal_catcher(network_.exit_flag(), network_.sigalarm_flag()); + network_.RunLoop(&signal_catcher.orig_signal_mask_); + unix_socket_listener_.Stop(); + + tun_interface_gone_ = tun_.tun_interface_gone(); +} + +void TunsafeBackendBsdImpl::OnSecondLoop(uint64 now) { + if (!(close_orphan_counter_++ & 0xF)) + CloseOrphanTcpConnections(); + processor_.SecondLoop(); +} + +void TunsafeBackendBsdImpl::RunAllMainThreadScheduled() { + processor_.RunAllMainThreadScheduled(); +} + +void TunsafeBackendBsdImpl::OnConnected() { + if (!is_connected_) { + const WgCidrAddr *ipv4_addr = NULL; + for (const WgCidrAddr &x : processor_.addr()) { + if (x.size == 32) { ipv4_addr = &x; break; } + } + uint32 
ipv4_ip = ipv4_addr ? ReadBE32(ipv4_addr->addr) : 0; + char buf[kSizeOfAddress]; + RINFO("Connection established. IP %s", ipv4_ip ? print_ip(buf, ipv4_ip) : "(none)"); + is_connected_ = true; + } +} + +void TunsafeBackendBsdImpl::OnConnectionRetry(uint32 attempts) { + if (is_connected_ && attempts >= 3) { + is_connected_ = false; + RINFO("Reconnecting..."); + } +} + +void TunsafeBackendBsdImpl::CloseOrphanTcpConnections() { + // Add all incoming tcp connections into a lookup table + WG_HASHTABLE_IMPL lookup; + for (TcpSocketBsd *tcp = network_.tcp_sockets(); tcp; tcp = tcp->next()) { + if (tcp->endpoint_protocol() == (kPacketProtocolTcp | kPacketProtocolIncomingConnection)) { + // Avoid deleting tcp sockets that were just born. + if (tcp->age == 0) { + tcp->age = 1; + } else { + lookup[ConvertIpAddrToAddrX(tcp->endpoint())] = tcp; + } + } + } + if (lookup.empty()) + return; + // For each peer, check if it has an endpoint that matches + // an entry in the lookup table, and delete it from the lookup + // table. 
+ for(WgPeer *peer = processor_.dev().first_peer(); peer; peer = peer->next_peer()) { + if (peer->endpoint_protocol() == (kPacketProtocolTcp | kPacketProtocolIncomingConnection)) + lookup.erase(ConvertIpAddrToAddrX(peer->endpoint())); + } + // The tcp connections that are still in the hashtable can be deleted + for(const auto &it : lookup) + delete (TcpSocketBsd *)it.second; +} + int main(int argc, char **argv) { CommandLineOutput cmd = {0}; @@ -812,20 +855,15 @@ int main(int argc, char **argv) { SetThreadName("tunsafe-m"); - MyProcessorDelegate my_procdel; - TunsafeBackendBsd *backend = CreateTunsafeBackendBsd(); + TunsafeBackendBsdImpl backend; if (cmd.interface_name) - backend->SetTunDeviceName(cmd.interface_name); - - WireguardProcessor wg(backend, backend, &my_procdel); - - my_procdel.wg_processor_ = &wg; - backend->SetProcessor(&wg); + backend.SetTunDeviceName(cmd.interface_name); DnsResolver dns_resolver(NULL); - if (*cmd.filename_to_load && !ParseWireGuardConfigFile(&wg, cmd.filename_to_load, &dns_resolver)) + if (*cmd.filename_to_load && !ParseWireGuardConfigFile(backend.processor(), cmd.filename_to_load, &dns_resolver)) + return 1; + if (!backend.processor()->Start()) return 1; - if (!wg.Start()) return 1; if (cmd.daemon) { fprintf(stderr, "Switching to daemon mode...\n"); @@ -833,9 +871,8 @@ int main(int argc, char **argv) { perror("daemon() failed"); } - backend->RunLoop(); - backend->CleanupRoutes(); - delete backend; + backend.RunLoop(); + backend.CleanupRoutes(); return 0; } diff --git a/tunsafe_bsd.h b/tunsafe_bsd.h new file mode 100644 index 0000000..11bfb82 --- /dev/null +++ b/tunsafe_bsd.h @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. 
+#ifndef TUNSAFE_NETWORK_BSD_COMMON_H_ +#define TUNSAFE_NETWORK_BSD_COMMON_H_ + +#include "netapi.h" +#include "wireguard.h" +#include "wireguard_config.h" +#include +#include + +struct RouteInfo { + uint8 family; + uint8 cidr; + uint8 ip[16]; + uint8 gw[16]; + std::string dev; +}; + +class SignalCatcher { +public: + SignalCatcher(bool *exit_flag, bool *sigalarm_flag); + ~SignalCatcher(); + + sigset_t orig_signal_mask_; +private: + static void SigAlrm(int sig); + static void SigInt(int sig); + bool *exit_flag_; + bool *sigalarm_flag_; +}; + +class TunsafeBackendBsd : public TunInterface, public UdpInterface { +public: + TunsafeBackendBsd(); + virtual ~TunsafeBackendBsd(); + + void CleanupRoutes(); + + void SetTunDeviceName(const char *name); + + // -- from TunInterface + virtual bool Configure(const TunConfig &&config, TunConfigOut *out) override; + +protected: + virtual bool InitializeTun(char devname[16]) = 0; + + void AddRoute(uint32 ip, uint32 cidr, uint32 gw, const char *dev); + void DelRoute(const RouteInfo &cd); + bool AddRoute(int family, const void *dest, int dest_prefix, const void *gateway, const char *dev); + bool RunPrePostCommand(const std::vector &vec); + + std::vector cleanup_commands_; + std::vector pre_down_, post_down_; + std::vector addresses_to_remove_; + char devname_[16]; + bool tun_interface_gone_; +}; + +#endif // TUNSAFE_NETWORK_BSD_COMMON_H_ diff --git a/wireguard.cpp b/wireguard.cpp index 14859a3..8fce37f 100644 --- a/wireguard.cpp +++ b/wireguard.cpp @@ -28,6 +28,7 @@ WireguardProcessor::WireguardProcessor(UdpInterface *udp, TunInterface *tun, Pro mtu_ = 1420; memset(&stats_, 0, sizeof(stats_)); listen_port_ = 0; + listen_port_tcp_ = 0; network_discovery_spoofing_ = false; add_routes_mode_ = true; dns_blocking_ = true; @@ -50,6 +51,16 @@ void WireguardProcessor::SetListenPort(int listen_port) { } } +void WireguardProcessor::SetListenPortTcp(int listen_port) { + if (listen_port_tcp_ != listen_port) { + listen_port_tcp_ = listen_port; + 
if (is_started_ && !ConfigureUdp()) { + RINFO("ConfigureUdp failed"); + } + } +} + + void WireguardProcessor::AddDnsServer(const IpAddr &sin) { dns_addr_.push_back(sin); } @@ -126,7 +137,7 @@ bool WireguardProcessor::Start() { bool WireguardProcessor::ConfigureUdp() { assert(dev_.IsMainThread()); - return udp_->Configure(listen_port_); + return udp_->Configure(listen_port_, listen_port_tcp_); } bool WireguardProcessor::ConfigureTun() { diff --git a/wireguard.h b/wireguard.h index 7919bd1..869b3c0 100644 --- a/wireguard.h +++ b/wireguard.h @@ -81,6 +81,8 @@ public: ~WireguardProcessor(); void SetListenPort(int listen_port); + void SetListenPortTcp(int listen_port); + void AddDnsServer(const IpAddr &sin); bool SetTunAddress(const WgCidrAddr &addr); void ClearTunAddress(); @@ -132,6 +134,7 @@ private: UdpInterface *udp_; uint16 listen_port_; + uint16 listen_port_tcp_; uint16 mtu_; bool dns_blocking_; diff --git a/wireguard_config.cpp b/wireguard_config.cpp index 007517a..50c11c5 100644 --- a/wireguard_config.cpp +++ b/wireguard_config.cpp @@ -104,6 +104,8 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) { wg_->dev().SetPrivateKey(binkey); } else if (strcmp(key, "ListenPort") == 0) { wg_->SetListenPort(atoi(value)); + } else if (strcmp(key, "ListenPortTCP") == 0) { + wg_->SetListenPortTcp(atoi(value)); } else if (strcmp(key, "Address") == 0) { SplitString(value, ',', &ss); for (size_t i = 0; i < ss.size(); i++) {