From 27b75b83de6d4510e2b590eb6ea46ef204f2c9c7 Mon Sep 17 00:00:00 2001 From: Ludvig Strigeus Date: Sun, 16 Dec 2018 16:02:50 +0100 Subject: [PATCH] Lots of new features - Hybrid TCP mode, uses both TCP and UDP - Simplified TCP protocol - Modified obfuscator to support padding - Obfuscation over TCP - Refactor parts of Win32 code to be more similar to BSD --- docs/WireGuard TCP.txt | 38 +- netapi.h | 1 + network_bsd.cpp | 87 ++++- network_bsd.h | 11 +- network_common.cpp | 770 +++++++++++++++++++++++++++++++++++------ network_common.h | 81 ++++- network_win32.cpp | 395 +++++++++++---------- network_win32.h | 108 ++++-- network_win32_api.h | 1 - network_win32_tcp.cpp | 90 ++--- network_win32_tcp.h | 20 +- tunsafe_bsd.cpp | 44 +-- tunsafe_win32.cpp | 1 - wireguard.cpp | 28 +- wireguard_config.cpp | 22 +- wireguard_proto.cpp | 107 ++++-- wireguard_proto.h | 37 +- 17 files changed, 1313 insertions(+), 528 deletions(-) diff --git a/docs/WireGuard TCP.txt b/docs/WireGuard TCP.txt index dabb5c3..1933349 100644 --- a/docs/WireGuard TCP.txt +++ b/docs/WireGuard TCP.txt @@ -65,40 +65,12 @@ TT LLLLLL LLLLLLLL [Payload LL bytes] The packet types (TT) currently defined are: TT = 00 = Normal The payload is a normal unmodified WireGuard packet including the regular WireGuard header. - 01 = Reserved 10 = Data A WireGuard data packet (type 04) without the 16 byte - header. The predicted header is prefixed to the payload. - 11 = Control A TCP control packet. Currently this is used only to setup - the header prediction. See below. + header. + ?1 = Reserved -There's only one defined Control packet, type 00 (SetKeyAndCounter): - - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - |1 1| Length is 13 (14 bits) | 00 (8 bits) | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Key ID (32 bits) | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Counter (64 bits) ... - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - -This sets up the Key ID and Counter used for the Data packets. Then Counter -is incremented by 1 for every such packet. - -For every Data packet, the predicted Key ID and Counter is expanded to a -regular WireGuard data (type 04) header, which is prefixed to the payload: - - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | 04 (8 bits) | Reserved (24 bits) | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Key ID (32 bits) | - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Counter (64 bits) ... - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - | Data Payload (LL * 8 bits) ... - +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +When parsing an incoming Data (TT=10) packet, the Key ID and Counter from +the most recently parsed WireGuard data packet (type 04) is prepended to +the payload, with Counter incremented by 1. This happens independently in each of the two TCP directions. diff --git a/netapi.h b/netapi.h index 4614aad..ca19ead 100644 --- a/netapi.h +++ b/netapi.h @@ -58,6 +58,7 @@ struct Packet : QueuedItem { byte *data; uint8 userdata; uint8 protocol; // which protocol is this packet for/from + bool prepared; IpAddr addr; // Optionally set to target/source of the packet enum { diff --git a/network_bsd.cpp b/network_bsd.cpp index 64c1cd1..bcc7b90 100644 --- a/network_bsd.cpp +++ b/network_bsd.cpp @@ -527,6 +527,10 @@ bool UdpSocketBsd::DoWrite() { void UdpSocketBsd::WritePacket(Packet *packet) { assert(fd_ >= 0); + + if (processor_->dev().packet_obfuscator().enabled()) + processor_->dev().packet_obfuscator().ObfuscatePacket(packet); + Packet *queue_is_used = udp_queue_; *udp_queue_end_ = packet; udp_queue_end_ = &Packet_NEXT(packet); @@ -824,7 +828,8 @@ void TcpSocketListenerBsd::HandleEvents(int revents) { int new_fd = accept(fd_, (sockaddr*)&addr, &len); if (new_fd >= 0) { RINFO("Created new tcp socket"); - TcpSocketBsd *channel = new TcpSocketBsd(network_, processor_); + + TcpSocketBsd *channel = new TcpSocketBsd(network_, processor_, true); if (channel) channel->InitializeIncoming(new_fd, addr); else @@ -840,18 +845,61 @@ void TcpSocketListenerBsd::Periodic() { } ////////////////////////////////////////////////////////////////////////////////////////////// -TcpSocketBsd::TcpSocketBsd(NetworkBsd *net, WireguardProcessor *processor) +void TcpSocketBsd::WriteTcpPacket(NetworkBsd *network, WireguardProcessor *processor, Packet *packet) { + bool is_handshake = ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION; + + // Check if we have a tcp connection for the endpoint, otherwise create one. + for (TcpSocketBsd *tcp = network->tcp_sockets(); tcp; tcp = tcp->next()) { + // After we send 3 handshakes on a tcp socket in a row, then close and reopen the socket because it seems defunct. + if (CompareIpAddr(&tcp->endpoint(), &packet->addr) == 0 && tcp->endpoint_protocol() == packet->protocol) { + if (is_handshake) { + uint32 now = (uint32)OsGetMilliseconds(); + uint32 secs = (now - tcp->handshake_timestamp_) >> 10; + tcp->handshake_timestamp_ += secs * 1024; + int calc = (secs > (uint32)tcp->handshake_attempts_ + 25) ? 0 : tcp->handshake_attempts_ + 25 - secs; + tcp->handshake_attempts_ = calc; + if (calc >= 60) { + RINFO("Making new Tcp socket due to too many handshake failures"); + delete tcp; + break; + } + } + tcp->WritePacket(packet); + return; + } + } + // Drop tcp packet that's for an incoming connection, or packets that are + // not a handshake. + if ((packet->protocol & kPacketProtocolIncomingConnection) || !is_handshake) { + FreePacket(packet); + return; + } + // Initialize a new tcp socket and connect to the endpoint + TcpSocketBsd *tcp = new TcpSocketBsd(network, processor, false); + if (!tcp || !tcp->InitializeOutgoing(packet->addr)) { + delete tcp; + FreePacket(packet); + return; + } + tcp->WritePacket(packet); +} + + +////////////////////////////////////////////////////////////////////////////////////////////// + +TcpSocketBsd::TcpSocketBsd(NetworkBsd *net, WireguardProcessor *processor, bool is_incoming) : BaseSocketBsd(net), readable_(false), writable_(true), endpoint_protocol_(0), age(0), - handshake_attempts(0), + handshake_attempts_(0), + handshake_timestamp_(0), wqueue_(NULL), wqueue_end_(&wqueue_), - wqueue_bytes_(0), + wqueue_packets_(0), processor_(processor), - tcp_packet_handler_(&net->packet_pool_) { + tcp_packet_handler_(&net->packet_pool_, &processor->dev().packet_obfuscator(), is_incoming) { // insert in network's linked list next_ = net->tcp_sockets_; net->tcp_sockets_ = this; @@ -908,19 +956,22 @@ bool TcpSocketBsd::InitializeOutgoing(const IpAddr &addr) { void TcpSocketBsd::WritePacket(Packet *packet) { assert(fd_ >= 0); - tcp_packet_handler_.AddHeaderToOutgoingPacket(packet); Packet *old_value = wqueue_; *wqueue_end_ = packet; wqueue_end_ = &Packet_NEXT(packet); packet->queue_next = NULL; + packet->prepared = false; AddToEndLoop(); - wqueue_bytes_ += packet->size; + // Note: Cannot use bytes here, because the TCP packet + // headers have not been added yet, and then the + // accounting doesn't work + wqueue_packets_++; - // When many bytes have been queued, perform the write. - if (writable_ && wqueue_bytes_ >= 32768) + // When enough packets have been queued up, perform the write. + if (writable_ && wqueue_packets_ >= 16) DoWrite(); } @@ -982,10 +1033,19 @@ void TcpSocketBsd::DoWrite() { struct iovec vecs[kMaxIoWrite]; Packet *p = wqueue_; size_t nvec = 0; - for (; p && nvec < kMaxIoWrite; nvec++, p = Packet_NEXT(p)) { - vecs[nvec].iov_base = p->data; - vecs[nvec].iov_len = p->size; + for (; p && nvec < kMaxIoWrite; p = Packet_NEXT(p)) { + if (!p->prepared) + tcp_packet_handler_.PrepareOutgoingPackets(p); + + if (p->size != 0) { + vecs[nvec].iov_base = p->data; + vecs[nvec].iov_len = p->size; + nvec++; + } } + if (nvec == 0) + return; + ssize_t n = writev(fd_, vecs, nvec); if (n < 0) { @@ -998,9 +1058,7 @@ void TcpSocketBsd::DoWrite() { } return; } - wqueue_bytes_ -= n; // discard those initial n bytes worth of packets - size_t i = 0; p = wqueue_; while (n) { if (n < p->size) { @@ -1009,6 +1067,7 @@ void TcpSocketBsd::DoWrite() { } n -= p->size; FreePacket(exch(p, Packet_NEXT(p))); + wqueue_packets_--; } if (!(wqueue_ = p)) wqueue_end_ = &wqueue_; diff --git a/network_bsd.h b/network_bsd.h index 04a023b..45e82d3 100644 --- a/network_bsd.h +++ b/network_bsd.h @@ -244,7 +244,7 @@ private: class TcpSocketBsd : public BaseSocketBsd { public: - explicit TcpSocketBsd(NetworkBsd *bsd, WireguardProcessor *processor); + explicit TcpSocketBsd(NetworkBsd *bsd, WireguardProcessor *processor, bool is_incoming); virtual ~TcpSocketBsd(); void InitializeIncoming(int fd, const IpAddr &addr); @@ -259,9 +259,10 @@ public: uint8 endpoint_protocol() { return endpoint_protocol_; } const IpAddr &endpoint() { return endpoint_; } + static void WriteTcpPacket(NetworkBsd *network, WireguardProcessor *processor, Packet *packet); + public: uint8 age; - uint8 handshake_attempts; private: void DoRead(); void DoWrite(); @@ -271,8 +272,10 @@ private: bool got_eof_; uint8 endpoint_protocol_; bool want_connect_; - - uint32 wqueue_bytes_; + uint8 handshake_attempts_; + uint32 handshake_timestamp_; + + uint wqueue_packets_; Packet *wqueue_, **wqueue_end_; TcpSocketBsd *next_; WireguardProcessor *processor_; diff --git a/network_common.cpp b/network_common.cpp index 919e7fd..fb305af 100644 --- a/network_common.cpp +++ b/network_common.cpp @@ -5,81 +5,34 @@ #include #include #include "util.h" +#include "crypto/chacha20poly1305.h" +#include "crypto/blake2s/blake2s.h" +#include "wireguard_proto.h" -TcpPacketHandler::TcpPacketHandler(SimplePacketPool *packet_pool) { - packet_pool_ = packet_pool; - rqueue_bytes_ = 0; - error_flag_ = false; - rqueue_ = NULL; - rqueue_end_ = &rqueue_; - predicted_key_in_ = predicted_key_out_ = 0; - predicted_serial_in_ = predicted_serial_out_ = 0; -} +enum { + CRYPTO_HEADER_SIZE = 64, +}; -TcpPacketHandler::~TcpPacketHandler() { +enum { + READ_CRYPTO_HEADER = 0, + READ_PACKET_HEADER = 1, + READ_PACKET_DATA = 2, +}; + +TcpPacketQueue::~TcpPacketQueue() { FreePacketList(rqueue_); } -enum { - kTcpPacketType_Normal = 0, - kTcpPacketType_Reserved = 1, - kTcpPacketType_Data = 2, - kTcpPacketType_Control = 3, - kTcpPacketControlType_SetKeyAndCounter = 0, -}; - -void TcpPacketHandler::AddHeaderToOutgoingPacket(Packet *p) { - unsigned int size = p->size; - uint8 *data = p->data; - if (size >= 16 && ReadLE32(data) == 4) { - uint32 key = Read32(data + 4); - uint64 serial = ReadLE64(data + 8); - WriteBE16(data + 14, size - 16 + (kTcpPacketType_Data << 14)); - data += 14, size -= 14; - // Insert a 15 byte control packet right before to set the new key/serial? - if ((predicted_key_out_ ^ key) | (predicted_serial_out_ ^ serial)) { - predicted_key_out_ = key; - WriteLE64(data - 8, serial); - Write32(data - 12, key); - data[-13] = kTcpPacketControlType_SetKeyAndCounter; - WriteBE16(data - 15, 13 + (kTcpPacketType_Control << 14)); - data -= 15, size += 15; - } - // Increase the serial by 1 for next packet. - predicted_serial_out_ = serial + 1; - } else { - WriteBE16(data - 2, size); - data -= 2, size += 2; - } - p->size = size; - p->data = data; -} - -void TcpPacketHandler::QueueIncomingPacket(Packet *p) { - rqueue_bytes_ += p->size; - p->queue_next = NULL; - *rqueue_end_ = p; - rqueue_end_ = &Packet_NEXT(p); -} - -// Either the packet fits in one buf or not. -static uint32 ReadPacketHeader(Packet *p) { - if (p->size >= 2) - return ReadBE16(p->data); - else - return (p->data[0] << 8) + (Packet_NEXT(p)->data[0]); -} - -// Move data around to ensure that exactly the first |num| bytes are stored -// in the first packet, and the rest of the data in subsequent packets. -Packet *TcpPacketHandler::ReadNextPacket(uint32 num) { +Packet *TcpPacketQueue::Read(uint num) { + // Move data around to ensure that exactly the first |num| bytes are stored + // in the first packet, and the rest of the data in subsequent packets. Packet *p = rqueue_; assert(num <= kPacketCapacity); if (p->size < num) { // There's not enough data in the current packet, copy data from the next packet // into this packet. - if ((uint32)(&p->data_buf[kPacketCapacity] - p->data) < num) { + if ((uint)(&p->data_buf[kPacketCapacity] - p->data) < num) { // Move data up front to make space. memmove(p->data_buf, p->data, p->size); p->data = p->data_buf; @@ -87,17 +40,17 @@ Packet *TcpPacketHandler::ReadNextPacket(uint32 num) { // Copy data from future packets into p, and delete them should they become empty. do { Packet *n = Packet_NEXT(p); - uint32 bytes_to_copy = std::min(n->size, num - p->size); - uint32 nsize = (n->size -= bytes_to_copy); + uint bytes_to_copy = std::min(n->size, num - p->size); + uint nsize = (n->size -= bytes_to_copy); memcpy(p->data + postinc(p->size, bytes_to_copy), postinc(n->data, bytes_to_copy), bytes_to_copy); if (nsize == 0) { p->queue_next = n->queue_next; - packet_pool_->FreePacketToPool(n); + pool_->FreePacketToPool(n); } } while (num - p->size); } else if (p->size > num) { // The packet has too much data. Split the packet into two packets. - Packet *n = packet_pool_->AllocPacketFromPool(); + Packet *n = pool_->AllocPacketFromPool(); if (!n) return NULL; // unable to allocate a packet....? if (num * 2 <= p->size) { @@ -108,7 +61,7 @@ Packet *TcpPacketHandler::ReadNextPacket(uint32 num) { memcpy(n->data, postinc(p->data, num), num); return n; } else { - uint32 overflow = p->size - num; + uint overflow = p->size - num; // There's a lot of leading data: PPPPPP NN. Move NN n->size = overflow; p->size = num; @@ -126,49 +79,666 @@ Packet *TcpPacketHandler::ReadNextPacket(uint32 num) { return p; } -Packet *TcpPacketHandler::GetNextWireguardPacket() { - while (rqueue_bytes_ >= 2) { - uint32 packet_header = ReadPacketHeader(rqueue_); - uint32 packet_size = packet_header & 0x3FFF; - uint32 packet_type = packet_header >> 14; - if (packet_size + 2 > rqueue_bytes_) +Packet *TcpPacketQueue::ReadUpTo(uint num) { + assert(rqueue_bytes_ != 0); + Packet *p = rqueue_; + if (num < p->size) + return Read(num); + rqueue_bytes_ -= p->size; + if ((rqueue_ = Packet_NEXT(p)) == NULL) + rqueue_end_ = &rqueue_; + return p; +} + +void TcpPacketQueue::Add(Packet *p) { + assert(p->size != 0); + rqueue_bytes_ += p->size; + p->queue_next = NULL; + *rqueue_end_ = p; + rqueue_end_ = &Packet_NEXT(p); +} + +void TcpPacketQueue::Read(uint8 *dst, uint size) { + assert(size <= rqueue_bytes_); + rqueue_bytes_ -= size; + while (size) { + Packet *packet = rqueue_; + uint n = std::min(packet->size, size); + uint8 *src = packet->data; + for (uint i = 0; i != n; i++) + *dst++ = *src++; + packet->data = src; + size -= n; + if ((packet->size -= n) == 0) { + if ((rqueue_ = Packet_NEXT(packet)) == NULL) + rqueue_end_ = &rqueue_; + pool_->FreePacketToPool(packet); + } + } +} + +uint TcpPacketQueue::PeekUint16() { + return (rqueue_->size >= 2) ? ReadBE16(rqueue_->data) : + (rqueue_->data[0] << 8) + Packet_NEXT(rqueue_)->data[0]; +} + +TcpPacketHandler::TcpPacketHandler(SimplePacketPool *packet_pool, WgPacketObfuscator *obfuscator, bool is_incoming) + : queue_(packet_pool), + tls_queue_(packet_pool), + write_state_(is_incoming), + obfuscation_mode_(kObfuscationMode_None) { + + if (obfuscator->enabled() && obfuscator->obfuscate_tcp() != TcpPacketHandler::kObfuscationMode_None) { + memcpy(encryptor_.buf, obfuscator->key(), CHACHA20POLY1305_KEYLEN); + memcpy(decryptor_.buf, obfuscator->key(), CHACHA20POLY1305_KEYLEN); + obfuscation_mode_ = obfuscator->obfuscate_tcp() != TcpPacketHandler::kObfuscationMode_Unspecified ? obfuscator->obfuscate_tcp() : + (is_incoming ? TcpPacketHandler::kObfuscationMode_Autodetect : TcpPacketHandler::kObfuscationMode_Encrypted); + read_state_ = (obfuscation_mode_ == kObfuscationMode_Encrypted) ? READ_CRYPTO_HEADER : READ_PACKET_HEADER; + } else if (!obfuscator->enabled() && obfuscator->obfuscate_tcp() > TcpPacketHandler::kObfuscationMode_None) { + RERROR("No ObfuscateKey specified. Disabling TCP obfuscation."); + } + tls_read_state_ = 0; + error_flag_ = false; + decryptor_initialized_ = false; + predicted_key_in_ = predicted_key_out_ = 0; + predicted_serial_in_ = predicted_serial_out_ = 0; +} + +TcpPacketHandler::~TcpPacketHandler() { +} + +enum { + kTcpPacketType_Normal = 0, + kTcpPacketType_Reserved = 1, + kTcpPacketType_Data = 2, + kTcpPacketType_Control = 3, + kTcpPacketControlType_SetKeyAndCounter = 0, +}; + +static void SetChachaStreamingKey(chacha20_streaming *chacha, const uint8 *key, size_t key_len) { + blake2s(chacha->buf, CHACHA20POLY1305_KEYLEN, key, key_len, chacha->buf, CHACHA20POLY1305_KEYLEN); + chacha20_streaming_init(chacha, chacha->buf); +} + +size_t TcpPacketHandler::CreateTls13ClientHello(uint8 *dst) { + uint8 *dst_org = dst; + // handshake, tls 1.0 + *dst++ = 0x16; + *dst++ = 0x03; + *dst++ = 0x01; + uint8 *handshake_length = postinc(dst, 2); + // handshake client hello + *dst++ = 0x01; + *dst++ = 0x00; + uint8 *handshake_inner_length = postinc(dst, 2); + // version = tls 1.2 + *dst++ = 0x03; + *dst++ = 0x03; + // 32 byte random + OsGetRandomBytes(postinc(dst, 32), 32); + *dst++ = 0x20; // Session length = 32 + // 32 byte session id + OsGetRandomBytes(postinc(dst, 32), 32); + + bool firefox = (obfuscation_mode_ == kObfuscationMode_TlsFirefox); + + if (firefox) { + static const uint8 tls_header1[] = { + // 18 cipher suites + 0x00, 0x24, + 0x13, 0x01, 0x13, 0x03, 0x13, 0x02, 0xc0, 0x2b, 0xc0, 0x2f, 0xcc, 0xa9, 0xcc, 0xa8, 0xc0, 0x2c, 0xc0, 0x30, + 0xc0, 0x0a, 0xc0, 0x09, 0xc0, 0x13, 0xc0, 0x14, 0x00, 0x33, 0x00, 0x39, 0x00, 0x2f, 0x00, 0x35, 0x00, 0x0a, + // compression method = null + 0x01, 0x00, + }; + memcpy(postinc(dst, sizeof(tls_header1)), tls_header1, sizeof(tls_header1)); + } else { + static const uint8 tls_header1_chrome[] = { + // 17 cipher suites + 0x00, 0x22, + 0xda, 0xda, 0x13, 0x01, 0x13, 0x02, 0x13, 0x03, 0xc0, 0x2b, 0xc0, 0x2f, 0xc0, 0x2c, 0xc0, 0x30, 0xcc, 0xa9, + 0xcc, 0xa8, 0xc0, 0x13, 0xc0, 0x14, 0x00, 0x9c, 0x00, 0x9d, 0x00, 0x2f, 0x00, 0x35, 0x00, 0x0a, + // compression method = null + 0x01, 0x00, + }; + memcpy(postinc(dst, sizeof(tls_header1_chrome)), tls_header1_chrome, sizeof(tls_header1_chrome)); + + } + uint8 *extensions_length = postinc(dst, 2); + + if (!firefox) { + static const uint8 tls_header_grease[] = { 0xaa, 0xaa, 0x00, 0x00 }; + memcpy(postinc(dst, sizeof(tls_header_grease)), tls_header_grease, sizeof(tls_header_grease)); + } + + static const uint8 tls_header2[] = { + // extension server name + 0x00, 0x00, 0x00, 0x16, 0x00, 0x14, 0x00, 0x00, 0x11, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x2e, 0x74, 0x6c, 0x73, 0x31, 0x33, 0x2e, 0x63, 0x6f, 0x6d, + // extension master secret + 0x00, 0x17, 0x00, 0x00, + // extension renegotiation info + 0xff, 0x01, 0x00, 0x01, 0x00, + }; + memcpy(postinc(dst, sizeof(tls_header2)), tls_header2, sizeof(tls_header2)); + + if (firefox) { + static const uint8 tls_header_groups_ff[] = { + // extension supported groups + 0x00, 0x0a, 0x00, 0x0e, 0x00, 0x0c, + 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x01, 0x00, 0x01, 0x01, + // extension ec_point_formats + 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, + // extension application_layer_protocol_negotiation + 0x00, 0x10, 0x00, 0x0e, 0x00, 0x0c, 0x02, 0x68, 0x32, 0x08, 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31, + // extension status request + 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, + // extension key share + 0x00, 0x33, 0x00, 0x6b, 0x00, 0x69, + // key share x25519 + 0x00, 0x1d, 0x00, 0x20, + }; + memcpy(postinc(dst, sizeof(tls_header_groups_ff)), tls_header_groups_ff, sizeof(tls_header_groups_ff)); + // Firefox has a secp251p1 key while chrome does not + OsGetRandomBytes(postinc(dst, 32), 32); + dst[-1] &= 0x7f; // clear top bit of x25519 key + static const uint8 tls_header3[] = { + // key share secp256p1 + 0x00, 0x17, 0x00, 0x41, + 0x04, + }; + memcpy(postinc(dst, sizeof(tls_header3)), tls_header3, sizeof(tls_header3)); + // todo: validate the secp256p1 key + OsGetRandomBytes(postinc(dst, 64), 64); + + static const uint8 tls_header4[] = { + // extension early data (seems to be sent only in resume) + 0x00, 0x2a, 0x00, 0x00, + // extension supported versions + 0x00, 0x2b, 0x00, 0x09, 0x08, 0x03, 0x04, 0x03, 0x03, 0x03, 0x02, 0x03, 0x01, + // extension signature_algorithms + 0x00, 0x0d, 0x00, 0x18, 0x00, 0x16, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x02, 0x03, 0x02, 0x01, + // extension psk_key_exchange_modes + 0x00, 0x2d, 0x00, 0x02, 0x01, 0x01, + // extension unknown type 28 + 0x00, 0x1c, 0x00, 0x02, 0x40, 0x01, + // extension pre shared key length=235 + 0x00, 0x29, 0x00, 0xeb, + // identities length=198, psk identity length = 192 + 0x00, 0xc6, 0x00, 0xc0, + }; + memcpy(postinc(dst, sizeof(tls_header4)), tls_header4, sizeof(tls_header4)); + + } else { + static const uint8 tls_header_groups_chrome[] = { + // extension supported groups + 0x00, 0x0a, 0x00, 0x0a, 0x00, 0x08, 0x2a, 0x2a, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, + // extension ec_point_formats + 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, + // extension sessionticket tls + 0x00, 0x23, 0x00, 0x00, + // extension application_layer_protocol_negotiation + 0x00, 0x10, 0x00, 0x0e, 0x00, 0x0c, 0x02, 0x68, 0x32, 0x08, 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31, + // extension status request + 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, + // extension signature_algorithms + 0x00, 0x0d, 0x00, 0x14, 0x00, 0x12, 0x04, 0x03, 0x08, 0x04, 0x04, 0x01, 0x05, 0x03, 0x08, 0x05, 0x05, 0x01, 0x08, 0x06, 0x06, 0x01, 0x02, 0x01, + // extension signed_certificate_timestamp + 0x00, 0x12, 0x00, 0x00, + // extension key_share + 0x00, 0x33, 0x00, 0x2b, 0x00, 0x29, + 0x2a, 0x2a, 0x00, 0x01, 0x00, + 0x00, 0x1d, 0x00, 0x20, + }; + memcpy(postinc(dst, sizeof(tls_header_groups_chrome)), tls_header_groups_chrome, sizeof(tls_header_groups_chrome)); + OsGetRandomBytes(postinc(dst, 32), 32); + dst[-1] &= 0x7f; // clear top bit of x25519 key + + static const uint8 tls_header4_chrome[] = { + // extension psk_key_exchange_modes + 0x00, 0x2d, 0x00, 0x02, 0x01, 0x01, + // extension supported versions + 0x00, 0x2b, 0x00, 0x0b, 0x0a, 0x1a, 0x1a, 0x03, 0x04, 0x03, 0x03, 0x03, 0x02, 0x03, 0x01, + // extension unknown type 27 + 0x00, 0x1b, 0x00, 0x03, 0x02, 0x00, 0x02, + // extension reserved (grease) + 0xea, 0xea, 0x00, 0x01, 0x00, + // extension pre shared key length=235 + 0x00, 0x29, 0x00, 0xeb, + // identities length=198, psk identity length = 192 + 0x00, 0xc6, 0x00, 0xc0, + }; + memcpy(postinc(dst, sizeof(tls_header4_chrome)), tls_header4_chrome, sizeof(tls_header4_chrome)); + } + + OsGetRandomBytes(postinc(dst, 192 + 4), 192 + 4); + static const uint8 tls_header5[] = { + // psk binders length + 0x00, 0x21, + }; + memcpy(postinc(dst, sizeof(tls_header5)), tls_header5, sizeof(tls_header5)); + OsGetRandomBytes(postinc(dst, 33), 33); + + // Fixup lengths + WriteBE16(handshake_length, (uint)(dst - dst_org - 5)); + WriteBE16(handshake_inner_length, (uint)(dst - dst_org - 9)); + WriteBE16(extensions_length, (uint)(dst - extensions_length - 2)); + + // Setup the key generator for outgoing packets. It will be the blake2s hash of + // the full message excluding the tls header. + SetChachaStreamingKey(&encryptor_, dst_org + 5, dst - dst_org - 5); + + static const uint8 tls_header6[] = { + // change cipher spec + 0x14, 0x03, 0x03, 0x00, 0x01, 0x01 + }; + memcpy(postinc(dst, sizeof(tls_header6)), tls_header6, sizeof(tls_header6)); + + return dst - dst_org; +} + +size_t TcpPacketHandler::CreateTls13ServerHello(uint8 *dst) { + if (!decryptor_initialized_) + return ~(size_t)0; + + uint8 *dst_org = dst; + // handshake, tls 1.0 + *dst++ = 0x16; + *dst++ = 0x03; + *dst++ = 0x03; + uint8 *handshake_length = postinc(dst, 2); + // handshake client hello + *dst++ = 0x02; + *dst++ = 0x00; + uint8 *handshake_inner_length = postinc(dst, 2); + // version = tls 1.2 + *dst++ = 0x03; + *dst++ = 0x03; + // 32 byte random + OsGetRandomBytes(postinc(dst, 32), 32); + *dst++ = 0x20; // Session length = 32 + // 32 byte session id taken from client hello. + memcpy(postinc(dst, 32), tls_session_id_, 32); + // cipher suite + *dst++ = 0x13; + *dst++ = 0x01; + // compression method + *dst++ = 0x00; + + uint8 *extensions_length = postinc(dst, 2); + static const uint8 tls_s_header0[] = { + // extension pre_shared_key + 0x00, 0x29, 0x00, 0x02, 0x00, 0x00, + // extension key share with x25519 key + 0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, + }; + memcpy(postinc(dst, sizeof(tls_s_header0)), tls_s_header0, sizeof(tls_s_header0)); + OsGetRandomBytes(postinc(dst, 32), 32); + dst[-1] &= 0x7f; // clear top bit of x25519 key + + static const uint8 tls_s_header1[] = { + // extension supported version tls1.3 + 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04, + }; + memcpy(postinc(dst, sizeof(tls_s_header1)), tls_s_header1, sizeof(tls_s_header1)); + + WriteBE16(handshake_length, (uint)(dst - dst_org - 5)); + WriteBE16(handshake_inner_length, (uint)(dst - dst_org - 9)); + WriteBE16(extensions_length, (uint)(dst - extensions_length - 2)); + + // Setup the key generator for outgoing packets. It will be the blake2s hash of + // the full message excluding the tls header. + SetChachaStreamingKey(&encryptor_, dst_org + 5, dst - dst_org - 5); + + static const uint8 tls_header6[] = { + // change cipher spec + 0x14, 0x03, 0x03, 0x00, 0x01, 0x01 + }; + memcpy(postinc(dst, sizeof(tls_header6)), tls_header6, sizeof(tls_header6)); + + return dst - dst_org; +} + +// Normal packet without obfuscation +void TcpPacketHandler::PrepareOutgoingPacketsNormal(Packet *p) { + uint8 *data = p->data; + uint data_size = p->size, packet_type = ReadLE32(data); + p->prepared = true; + if (packet_type == 4) { + assert(data_size >= 16); + uint32 key = Read32(data + 4); + uint64 serial = ReadLE64(data + 8); + if (((predicted_key_out_ ^ key) | (exch(predicted_serial_out_, serial) ^ (serial - 1))) == 0) { + p->data = data + 14; + p->size = data_size - 14; + WriteBE16(p->data, 0x8000 + data_size - 16); + return; + } + predicted_key_out_ = key; + } + p->size = data_size + 2; + p->data = data - 2; + WriteBE16(p->data, data_size); +} + +// Obfuscated stream that looks totally random +void TcpPacketHandler::PrepareOutgoingPacketsObfuscate(Packet *p) { + uint8 *data = p->data; + uint data_size = p->size, packet_type = ReadLE32(data); + p->prepared = true; + // When obfuscation is enabled, inject random shit into packets. + if ((packet_type == 4 && data_size <= 32) || packet_type < 4) { + if (packet_type != 4) { + assert(data_size >= 48); + // The 39:th (for handshake init) and 43:rd byte (for handshake response) + // have zero MSB because of curve25519 pubkey, so xor it with random. + if (packet_type < 4) + data[35 + packet_type * 4] ^= data[15]; + } else { + predicted_key_out_ = Read32(data + 4); + predicted_serial_out_ = ReadLE64(data + 8); + } + data_size = (uint)WgPacketObfuscator::InsertRandomBytesIntoPacket(data, data_size); + } else if (packet_type == 4) { + assert(data_size >= 16); + uint32 key = Read32(data + 4); + uint64 serial = ReadLE64(data + 8); + if (((exch(predicted_key_out_, key) ^ key) | (exch(predicted_serial_out_, serial) ^ (serial - 1))) == 0) { + p->data = data + 14; + p->size = data_size - 14; + WriteBE16(p->data, 0x8000 + data_size - 16); + chacha20_streaming_crypt(&encryptor_, p->data, 2); + return; + } + } + p->data = data - 2; + p->size = data_size + 2; + WriteBE16(p->data, data_size); + chacha20_streaming_crypt(&encryptor_, p->data, 18); +} + +static void PrependTlsApplicationData(Packet *p, uint data_size) { + p->size += 5; + p->data -= 5; + p->data[0] = 0x17; + p->data[1] = 0x03; + p->data[2] = 0x03; + p->data[4] = (uint8)data_size; + p->data[3] = (uint8)(data_size >> 8); +} + +void TcpPacketHandler::PrepareOutgoingPacketsTLS13(Packet *p) { + // Collect a number of packets, but add just a single TLS header + uint total_size = 0; + Packet *cur = p; + do { + PrepareOutgoingPacketsObfuscate(cur); + total_size += cur->size; + } while (total_size < 12000 && (cur = Packet_NEXT(cur))); + PrependTlsApplicationData(p, total_size); +} + +Packet *TcpPacketHandler::GetNextWireguardPacketObfuscate(TcpPacketQueue *queue) { + if (read_state_ == READ_CRYPTO_HEADER) { + // Wait for the 64 bytes of crypto header, they will + // be used to seed the decryptor. + if (queue->size() < CRYPTO_HEADER_SIZE) return NULL; - if (packet_size + 2 > kPacketCapacity) { - RERROR("Oversized packet?"); + Packet *packet = queue->Read(CRYPTO_HEADER_SIZE); + if (!packet) + return NULL; + SetChachaStreamingKey(&decryptor_, packet->data, CRYPTO_HEADER_SIZE); + queue->pool()->FreePacketToPool(packet); + read_state_ = READ_PACKET_HEADER; + } else if (read_state_ == READ_PACKET_DATA) { + goto case_READ_PACKET_DATA; + } + + while (queue->size() >= 2) { + // Peek and decrypt the packet header + queue->Read(packet_header_, 2); + chacha20_streaming_crypt(&decryptor_, packet_header_, 2); +case_READ_PACKET_DATA: + uint32 packet_header = ReadBE16(packet_header_); + uint32 packet_size = packet_header & 0x7FFF; + if (packet_size > kPacketCapacity) { +error: error_flag_ = true; return NULL; } - Packet *packet = ReadNextPacket(packet_size + 2); - if (packet) { -// RINFO("Packet of type %d, size %d", packet_type, packet->size - 2); + if (packet_size > queue->size()) { + read_state_ = READ_PACKET_DATA; + return NULL; + } + read_state_ = READ_PACKET_HEADER; + Packet *packet = queue->Read(packet_size); + if (!packet) + goto error; + // RINFO("Packet of type %d, size %d", packet_type, packet->size - 2); + if (!(packet_header & 0x8000)) { + unsigned int size = packet->size; + // decrypt the initial 16 bytes of the packet + if (size < 16) + goto error; + chacha20_streaming_crypt(&decryptor_, packet->data, 16); + // Discard any extra junk bytes appended at the end. + if (packet->data[0] <= 4) { + if (packet->data[3] > size) + goto error; + packet->size = (size -= packet->data[3]); + packet->data[3] = 0; + // The 39:th (for handshake init) and 43:rd byte (for handshake response) + // have zero MSB because of curve25519 pubkey, so xor it with random. + if (packet->data[0] < 4 && size >= 48) + packet->data[35 + packet->data[0] * 4] ^= packet->data[15]; + } + if (packet->data[0] == 4) { + predicted_key_in_ = Read32(packet->data + 4); + predicted_serial_in_ = ReadLE64(packet->data + 8); + } + return packet; + } else { + // Optimization when the 16 first bytes are known and prefixed to the packet + assert(packet->data >= packet->data_buf); + packet->data -= 16, packet->size += 16; + predicted_serial_in_++; + WriteLE32(packet->data, 4); + Write32(packet->data + 4, predicted_key_in_); + WriteLE64(packet->data + 8, predicted_serial_in_); + return packet; + } + } + return NULL; +} + +Packet *TcpPacketHandler::GetNextWireguardPacketNormal() { + while (queue_.size() >= 2) { + uint32 packet_header = queue_.PeekUint16(); + uint32 packet_size = packet_header & 0x7FFF; + if (packet_size + 2 > kPacketCapacity) { +error: + error_flag_ = true; + return NULL; + } + if (packet_size + 2 > queue_.size()) + return NULL; + Packet *packet = queue_.Read(packet_size + 2); + if (!packet) + goto error; + if (!(packet_header & 0x8000)) { packet->data += 2, packet->size -= 2; - if (packet_type == kTcpPacketType_Normal) { + if (packet->data[0] == 4 && packet->size >= 16) { + predicted_key_in_ = Read32(packet->data + 4); + predicted_serial_in_ = ReadLE64(packet->data + 8); + } + } else { + // Optimization when the 16 first bytes are known and prefixed to the packet + assert(packet->data >= packet->data_buf); + packet->data -= 14, packet->size += 14; + predicted_serial_in_++; + WriteLE32(packet->data, 4); + Write32(packet->data + 4, predicted_key_in_); + WriteLE64(packet->data + 8, predicted_serial_in_); + } + return packet; + } + return NULL; +} - return packet; - } else if (packet_type == kTcpPacketType_Data) { - // Optimization when the 16 first bytes are known and prefixed to the packet - assert(packet->data >= packet->data_buf); +#define TLS_ASYNC_BEGIN() switch (tls_read_state_) { +#define TLS_ASYNC_RESUMEPOINT(label) tls_read_state_ = (label); case label: +#define TLS_ASYNC_WAIT(expr, label) case label: if (!(expr)) { tls_read_state_ = (label); return NULL; } +#define TLS_ASYNC_END() } - packet->data -= 16, packet->size += 16; - WriteLE32(packet->data, 4); - Write32(packet->data + 4, predicted_key_in_); - WriteLE64(packet->data + 8, predicted_serial_in_); - predicted_serial_in_++; - return packet; - } else if (packet_type == kTcpPacketType_Control) { - // Unknown control packets are silently ignored - if (packet->size == 13 && packet->data[0] == kTcpPacketControlType_SetKeyAndCounter) { - // Control packet to setup the predicted key/sequence nr - predicted_key_in_ = Read32(packet->data + 1); - predicted_serial_in_ = ReadLE64(packet->data + 5); +// Unwrap the TLS framing +Packet *TcpPacketHandler::GetNextWireguardPacketTLS13() { + uint8 header[5]; + Packet *packet; + + enum { + TLS_STATE_INIT = 0, + TLS_WAIT_HANDSHAKE = 1, + TLS_WAIT_DATA = 2, + TLS_READ_PACKETS = 3, + TLS_WAIT_JUNK = 4, + TLS_ERROR = 5, + }; + TLS_ASYNC_BEGIN(); + for(;;) { + TLS_ASYNC_WAIT(queue_.size() >= 5, TLS_STATE_INIT); + queue_.Read(header, 5); + tls_bytes_left_ = ReadBE16(header + 3); + if (header[0] == 23) { + if (!decryptor_initialized_) + goto error; // no key yet + // Read the next |tls_bytes_left_| bytes and push them to the tls_queue_. + while (tls_bytes_left_ != 0) { + TLS_ASYNC_WAIT(queue_.size() != 0, TLS_WAIT_DATA); + if (!(packet = queue_.ReadUpTo(tls_bytes_left_))) goto error; + tls_bytes_left_ -= packet->size; + tls_queue_.Add(packet); + TLS_ASYNC_RESUMEPOINT(TLS_READ_PACKETS); + if ((packet = GetNextWireguardPacketObfuscate(&tls_queue_)) != NULL) + return packet; + } + } else { + if (tls_bytes_left_ > kPacketCapacity) + goto error; // too large packet? + if (header[0] == 22) { + TLS_ASYNC_WAIT(tls_bytes_left_ <= queue_.size(), TLS_WAIT_HANDSHAKE); + if (!(packet = queue_.Read(tls_bytes_left_))) + goto error; // eom + // Initialize decryptor + if (!decryptor_initialized_ && packet->size >= 39 + 32) { + // Store the session ID, so we can include it in server hello. + memcpy(tls_session_id_, packet->data + 39, 32); + // Initialize chacha decryptor + SetChachaStreamingKey(&decryptor_, packet->data, packet->size); + decryptor_initialized_ = true; } - packet_pool_->FreePacketToPool(packet); + FreePacket(packet); + } else if (header[0] == 20) { + TLS_ASYNC_WAIT(tls_bytes_left_ <= queue_.size(), TLS_WAIT_JUNK); + if (!(packet = queue_.Read(tls_bytes_left_))) + goto error; // eom + FreePacket(packet); } else { - packet_pool_->FreePacketToPool(packet); +error: + TLS_ASYNC_RESUMEPOINT(TLS_ERROR); error_flag_ = true; return NULL; } } } + TLS_ASYNC_END(); return NULL; -} \ No newline at end of file +} + +void TcpPacketHandler::PrepareOutgoingPacketsWithHeader(Packet *p) { + uint8 buf[1024]; + size_t hello_size; + + if (obfuscation_mode_ == kObfuscationMode_Encrypted) { + // Ensure it doesn't look like a tls or a regular packet. + do { + OsGetRandomBytes(buf, CRYPTO_HEADER_SIZE); + } while (ReadBE16(buf) == 0x1603 || ReadBE16(buf) <= 1500); + + SetChachaStreamingKey(&encryptor_, buf, CRYPTO_HEADER_SIZE); + hello_size = CRYPTO_HEADER_SIZE; + } else { + hello_size = (write_state_ == 0) ? CreateTls13ClientHello(buf) : CreateTls13ServerHello(buf); + // This could fail if the server tries to send a packet before the client sent hello. + if (hello_size == ~(size_t)0) { + RERROR("Trying to send server message before client hello"); + p->size = 0; + return; + } + } + write_state_ = 2; + PrepareOutgoingPackets(p); + if (hello_size + p->size > kPacketCapacity) { + RERROR("Outgoing TCP packet too big."); + return; + } + memmove(p->data_buf + hello_size, exch(p->data, p->data_buf), postinc(p->size, (uint)hello_size)); + memcpy(p->data_buf, buf, hello_size); +} + + +void TcpPacketHandler::PrepareOutgoingPackets(Packet *p) { + if (obfuscation_mode_ == kObfuscationMode_None) { + PrepareOutgoingPacketsNormal(p); + } else { + if (write_state_ != 2) { + PrepareOutgoingPacketsWithHeader(p); + return; + } + if (obfuscation_mode_ == kObfuscationMode_Encrypted) + PrepareOutgoingPacketsObfuscate(p); + else + PrepareOutgoingPacketsTLS13(p); + } +} + +Packet *TcpPacketHandler::GetNextWireguardPacket() { + // If this is an incoming connection, try to guess what type of obfuscation + // we're using, if any. + for (;;) { + if (obfuscation_mode_ == kObfuscationMode_None) + return GetNextWireguardPacketNormal(); + else if (obfuscation_mode_ == kObfuscationMode_Encrypted) + return GetNextWireguardPacketObfuscate(&queue_); + else if (obfuscation_mode_ != kObfuscationMode_Autodetect) + return GetNextWireguardPacketTLS13(); + + // Try and autodetect based on the first 2 bytes. + if (queue_.size() < 2) + return NULL; + + uint16 header = queue_.PeekUint16(); + if (header == 0x1603) { + // This is a SSL client hello, but don't know if it's + // chrome or ff, so use ff. + obfuscation_mode_ = kObfuscationMode_TlsFirefox; + } else if (header <= 1500) { + // Unobfuscated wireguard headers always start with a low value. + obfuscation_mode_ = kObfuscationMode_None; + } else { + read_state_ = READ_CRYPTO_HEADER; + obfuscation_mode_ = kObfuscationMode_Encrypted; + } + } +} + + +#if defined(OS_WIN) || defined(USE_MULTITHREADED_NETWORKING) +void SimplePacketPool::FreeSomePacketsInner() { + int n = freed_packets_count_ - 24; + Packet **p = &freed_packets_; + for (; n; n--) + p = &Packet_NEXT(*p); + FreePackets(exch(freed_packets_, *p), p, exch(freed_packets_count_, 24) - 24); +} +#endif + + diff --git a/network_common.h b/network_common.h index 8814cd4..b7c8891 100644 --- a/network_common.h +++ b/network_common.h @@ -2,8 +2,10 @@ #define TUNSAFE_NETWORK_COMMON_H_ #include "netapi.h" +#include "crypto/chacha20poly1305.h" class PacketProcessor; +class WgPacketObfuscator; // A simple singlethreaded pool of packets used on windows where // FreePacket / AllocPacket are multithreded and thus slightly slower @@ -54,19 +56,55 @@ public: #endif +class TcpPacketQueue { +public: + explicit TcpPacketQueue(SimplePacketPool *pool) : rqueue_bytes_(0), rqueue_(NULL), rqueue_end_(&rqueue_), pool_(pool) {} + ~TcpPacketQueue(); + + Packet *Read(uint num); + Packet *ReadUpTo(uint num); + void Read(uint8 *dst, uint num); + uint PeekUint16(); + + void Add(Packet *packet); + + uint32 size() const { return rqueue_bytes_; } + + SimplePacketPool *pool() { return pool_; } +private: + // Total # of bytes queued + uint rqueue_bytes_; + + // Buffered data + Packet *rqueue_, **rqueue_end_; + + SimplePacketPool *pool_; +}; + // Aids with prefixing and parsing incoming and outgoing // packets with the tcp protocol header. class TcpPacketHandler { public: - explicit TcpPacketHandler(SimplePacketPool *packet_pool); + enum { + kObfuscationMode_Unspecified = -1, + kObfuscationMode_None = 0, + kObfuscationMode_Encrypted = 1, + kObfuscationMode_TlsFirefox = 2, + kObfuscationMode_TlsChrome = 3, + kObfuscationMode_Autodetect = 4, + }; + + explicit TcpPacketHandler(SimplePacketPool *packet_pool, WgPacketObfuscator *obfuscator, bool is_incoming); ~TcpPacketHandler(); // Adds a tcp header to a data packet so it can be transmitted on the wire - void AddHeaderToOutgoingPacket(Packet *p); + void PrepareOutgoingPackets(Packet *p); // Add a new chunk of incoming data to the packet list - void QueueIncomingPacket(Packet *p); + void QueueIncomingPacket(Packet *p) { + queue_.Add(p); + } // Attempt to extract the next packet, returns NULL when complete. Packet *GetNextWireguardPacket(); @@ -74,22 +112,43 @@ public: bool error() const { return error_flag_; } private: - // Internal function to read a packet - Packet *ReadNextPacket(uint32 num); - - SimplePacketPool *packet_pool_; + void PrepareOutgoingPacketsNormal(Packet *p); + void PrepareOutgoingPacketsObfuscate(Packet *p); + void PrepareOutgoingPacketsTLS13(Packet *p); + void PrepareOutgoingPacketsWithHeader(Packet *p); - // Total # of bytes queued - uint32 rqueue_bytes_; + Packet *GetNextWireguardPacketNormal(); + Packet *GetNextWireguardPacketObfuscate(TcpPacketQueue *queue); + Packet *GetNextWireguardPacketTLS13(); + + size_t CreateTls13ClientHello(uint8 *dst); + size_t CreateTls13ServerHello(uint8 *dst); + // Set if there's a fatal error bool error_flag_; + uint8 obfuscation_mode_; + uint8 read_state_, write_state_, tls_read_state_; + bool decryptor_initialized_; - // These hold the incoming packets before they're parsed - Packet *rqueue_, **rqueue_end_; + uint8 packet_header_[2]; + + // Number of data bytes left + uint tls_bytes_left_; + + TcpPacketQueue queue_; + + // There's a separate queue for tls since it unwraps stuff + TcpPacketQueue tls_queue_; uint32 predicted_key_in_, predicted_key_out_; uint64 predicted_serial_in_, predicted_serial_out_; + + // For obfuscating + chacha20_streaming encryptor_, decryptor_; + + uint8 tls_session_id_[32]; + }; #endif // TUNSAFE_NETWORK_COMMON_H_ \ No newline at end of file diff --git a/network_win32.cpp b/network_win32.cpp index 620764f..389fc38 100644 --- a/network_win32.cpp +++ b/network_win32.cpp @@ -90,13 +90,6 @@ void FreeAllPackets() { } } -void SimplePacketPool::FreeSomePacketsInner() { - int n = freed_packets_count_ - 24; - Packet **p = &freed_packets_; - for (; n; n--) - p = &Packet_NEXT(*p); - FreePackets(exch(freed_packets_, *p), p, exch(freed_packets_count_, 24) - 24); -} void InitPacketMutexes() { static bool mutex_inited; @@ -169,7 +162,7 @@ static bool RunNetsh(const char *cmdline) { // Open the TAP adapter, either a random one or a specific one // On return, the adapter is locked in |TunAdaptersInUse|. -static HANDLE OpenTunAdapter(char guid[ADAPTER_GUID_SIZE], TunsafeBackendWin32 *backend, DWORD open_flags) { +static HANDLE OpenTunAdapter(char guid[ADAPTER_GUID_SIZE], TunsafeRunner *runner, DWORD open_flags) { char path[128]; HANDLE h; int retries = 0; @@ -196,7 +189,7 @@ RETRY: int error_code = 0; for (GuidAndDevName &x : adapters) { snprintf(path, sizeof(path), "\\\\.\\Global\\%s.tap", x.guid); - if (tun_adapters_in_use->Acquire(x.guid, static_cast(backend))) { + if (tun_adapters_in_use->Acquire(x.guid, static_cast(runner->backend()))) { h = CreateFile(path, GENERIC_READ | GENERIC_WRITE, 0, 0, OPEN_EXISTING, FILE_ATTRIBUTE_SYSTEM | open_flags, 0); if (h != INVALID_HANDLE_VALUE) { memcpy(guid, x.guid, ADAPTER_GUID_SIZE); @@ -204,7 +197,7 @@ RETRY: } did_try_adapter = true; error_code = GetLastError(); - tun_adapters_in_use->Release(static_cast(backend)); + tun_adapters_in_use->Release(static_cast(runner->backend())); } } if (!did_try_adapter) { @@ -214,7 +207,7 @@ RETRY: // Sometimes if you close the device right before, it will fail to open with errorcode 31. // When resuming from sleep in my VM, the error code is ERROR_FILE_NOT_FOUND - if ((error_code == ERROR_FILE_NOT_FOUND || error_code == ERROR_GEN_FAILURE) && !backend->exit_code()) { + if ((error_code == ERROR_FILE_NOT_FOUND || error_code == ERROR_GEN_FAILURE) && !runner->exit_code()) { if (retries <= 10) { RERROR("OpenTapAdapter: CreateFile failed: 0x%X... retrying%s", error_code, retries == 10 ? " (last notice)" : ""); if (retries == 10) { @@ -223,12 +216,12 @@ RETRY: } else if (error_code == ERROR_GEN_FAILURE) { RERROR(" Please ensure that the TAP device is not in use."); } - backend->SetStatus(TunsafeBackend::kStatusTunRetrying); + runner->backend()->SetStatus(TunsafeBackend::kStatusTunRetrying); } } int sleep_amount = 250 * std::min(++retries, 40); for (;;) { - if (backend->exit_code()) + if (runner->exit_code()) return NULL; if (sleep_amount == 0) break; @@ -699,7 +692,7 @@ void UdpSocketWin32::DoIO() { } //////////////////////////////////////////////////////////////////////////////////////////////////////// -NetworkWin32::NetworkWin32() : udp_socket_(this), tcp_socket_queue_(this) { +NetworkWin32::NetworkWin32() : udp_socket_(this) { exit_thread_ = false; thread_ = NULL; completion_port_handle_ = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, NULL, 0); @@ -803,21 +796,6 @@ void NetworkWin32::PostQueuedItem(QueuedItem *item) { PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, &item->overlapped); } -bool NetworkWin32::Configure(int listen_port, int listen_port_tcp) { - if (listen_port_tcp) - RERROR("ListenPortTCP not supported in this version"); - return udp_socket_.Configure(listen_port); -} - -// Called from tunsafe thread -void NetworkWin32::WriteUdpPacket(Packet *packet) { - if (packet->protocol & kPacketProtocolUdp) { - udp_socket_.WriteUdpPacket(packet); - } else { - tcp_socket_queue_.WritePacket(packet); - } -} - ///////////////////////////////////////////////////////////////////////// PacketProcessor::PacketProcessor() { @@ -829,6 +807,7 @@ PacketProcessor::PacketProcessor() { timer_interrupt_ = false; packets_in_queue_ = 0; need_notify_ = 0; + udp_cb_maybe_deobfuscate_ = &udp_cb_; } PacketProcessor::~PacketProcessor() { @@ -866,13 +845,13 @@ void PacketProcessor::Reset() { } } -int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) { +int PacketProcessor::Run(WireguardProcessor *wg, TunsafeRunner *runner) { int free_packets_ctr = 0; int overload = 0; int exit_code; QueuedItem *packet; PTP_TIMER threadpool_timer; - QueueContext queue_context = {wg, backend}; + QueueContext queue_context = {wg, runner}; threadpool_timer = CreateThreadpoolTimer(&ThreadPoolTimerCallback, this, NULL); static const int64 duetime = -10000000; // the unit is 100ns @@ -880,25 +859,13 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) { mutex_.Acquire(); while (!(exit_code = exit_code_)) { - FreeAllPackets(); - if (timer_interrupt_) { timer_interrupt_ = false; need_notify_ = 0; mutex_.Release(); wg->SecondLoop(); - backend->stats_mutex_.Acquire(); - backend->stats_ = wg->GetStats(); - float data[2] = { - // unit is megabits/second - backend->stats_.tun_bytes_in_per_second * (1.0f / 125000), - backend->stats_.tun_bytes_out_per_second * (1.0f / 125000), - }; - backend->stats_collector_.AddSamples(data); - backend->stats_mutex_.Release(); - backend->delegate_->OnGraphAvailable(); - backend->PushStats(); + runner->CollectStats(); // Conserve memory every 10s if (free_packets_ctr++ == 10) { @@ -933,7 +900,6 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) { wg->RunAllMainThreadScheduled(); mutex_.Acquire(); } - exit_code_ = 0; mutex_.Release(); SetThreadpoolTimer(threadpool_timer, nullptr, 0, 0); @@ -970,9 +936,7 @@ void PacketProcessorDeobfuscateUdpCb::OnQueuedItemEvent(QueuedItem *qi, uintptr_ void PacketProcessor::PostExit(int exit_code) { mutex_.Acquire(); - // Avoid race condition where mode_tun_failed is set during thread exit. - if (exit_code_ != TunsafeBackendWin32::MODE_RESTART && exit_code_ != TunsafeBackendWin32::MODE_EXIT) - exit_code_ = exit_code; + exit_code_ = exit_code; mutex_.Release(); SetEvent(event_); } @@ -1229,12 +1193,12 @@ TunWin32Adapter::~TunWin32Adapter() { } -bool TunWin32Adapter::OpenAdapter(TunsafeBackendWin32 *backend, DWORD open_flags) { +bool TunWin32Adapter::OpenAdapter(TunsafeRunner *runner, DWORD open_flags) { ULONG info[3]; DWORD len; assert(handle_ == NULL); - backend_ = backend; - handle_ = OpenTunAdapter(guid_, backend, open_flags); + backend_ = runner->backend(); + handle_ = OpenTunAdapter(guid_, runner, open_flags); if (handle_ != NULL) { memset(info, 0, sizeof(info)); if (DeviceIoControl(handle_, TAP_IOCTL_GET_VERSION, &info, sizeof(info), @@ -1664,7 +1628,7 @@ bool TunWin32Adapter::RunPrePostCommand(const std::vector &vec) { ////////////////////////////////////////////////////////////////////////////// -TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) { +TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeRunner *runner) : adapter_(blocker, runner->backend()->guid_), runner_(runner) { wqueue_end_ = &wqueue_; wqueue_ = NULL; wqueue_size_ = 0; @@ -1680,7 +1644,6 @@ TunWin32Iocp::~TunWin32Iocp() { //assert(num_reads_ == 0 && num_writes_ == 0); assert(thread_ == NULL); CloseTun(false); - FreePacketList(wqueue_); } bool TunWin32Iocp::Configure(const TunConfig &&config, TunConfigOut *out) { @@ -1693,7 +1656,7 @@ bool TunWin32Iocp::Configure(const TunConfig &&config, TunConfigOut *out) { return rv; } CloseTun(true); - if (adapter_.OpenAdapter(backend_, FILE_FLAG_OVERLAPPED)) { + if (adapter_.OpenAdapter(runner_, FILE_FLAG_OVERLAPPED)) { completion_port_handle_ = CreateIoCompletionPort(adapter_.handle(), NULL, NULL, 0); if (completion_port_handle_ != NULL) { if (adapter_.ConfigureAdapter(std::move(config), out)) @@ -1707,10 +1670,9 @@ bool TunWin32Iocp::Configure(const TunConfig &&config, TunConfigOut *out) { void TunWin32Iocp::CloseTun(bool is_restart) { assert(thread_ == NULL); adapter_.CloseAdapter(is_restart); - if (completion_port_handle_) { - CloseHandle(completion_port_handle_); - completion_port_handle_ = NULL; - } + if (completion_port_handle_) + CloseHandle(exch_null(completion_port_handle_)); + FreePacketList(wqueue_); } enum { @@ -1758,7 +1720,7 @@ void TunWin32Iocp::ThreadMain() { if (err == ERROR_OPERATION_ABORTED || err == ERROR_FILE_NOT_FOUND) { RERROR("TAP driver stopped communicating. Attempting to restart.", err); // This can happen if we reinstall the TAP driver while there's an active connection. - backend_->PostExit(TunsafeBackendWin32::MODE_TUN_FAILED); + runner_->PostTunRestart(); goto EXIT; } } else { @@ -1921,6 +1883,139 @@ void TunWin32Iocp::WriteTunPacket(Packet *packet) { ////////////////////////////////////////////////////////////////////////////// +TunsafeRunner::TunsafeRunner(TunsafeBackendWin32 *backend) + : backend_(backend), + tun_(&backend->dns_blocker_, this), + wg_proc_(this, &tun_, this), + plugin_(CreateTunsafePlugin(this, &wg_proc_)), + tcp_socket_queue_(&net_, &wg_proc_.dev().packet_obfuscator()) { + + wg_proc_.dev().SetPlugin(plugin_); + + net_.udp().SetPacketHandler(&packet_processor_); + tcp_socket_queue_.SetPacketHandler(&packet_processor_); + tun_.SetPacketHandler(&packet_processor_); +} + +TunsafeRunner::~TunsafeRunner() { + wg_proc_.dev().SetCurrentThreadAsMainThread(); + delete plugin_; +} + +bool TunsafeRunner::Configure(int listen_port, int listen_port_tcp) { + if (listen_port_tcp) + RERROR("ListenPortTCP not supported in this version"); + return net_.udp().Configure(listen_port); +} + +void TunsafeRunner::WriteUdpPacket(Packet *packet) { + if (packet->protocol & kPacketProtocolUdp) { + if (wg_proc_.dev().packet_obfuscator().enabled()) + wg_proc_.dev().packet_obfuscator().ObfuscatePacket(packet); + net_.udp().WriteUdpPacket(packet); + } else { + tcp_socket_queue_.WritePacket(packet); + } +} + +void TunsafeRunner::OnConnected() { + TunsafeBackendWin32 *backend = backend_; + if (backend->status() != TunsafeBackend::kStatusConnected) { + const WgCidrAddr *ipv4_addr = NULL; + for (const WgCidrAddr &x : wg_proc_.addr()) { + if (x.size == 32) { + ipv4_addr = &x; + break; + } + } + backend->ipv4_ip_ = ipv4_addr ? ReadBE32(ipv4_addr->addr) : 0; + if (backend->status() != TunsafeBackend::kStatusReconnecting) { + char buf[kSizeOfAddress]; + RINFO("Connection established. IP %s", ipv4_addr ? print_ip_prefix(buf, AF_INET, ipv4_addr->addr, -1) : "(none)"); + } + backend->SetStatus(TunsafeBackend::kStatusConnected); + } +} + +void TunsafeRunner::OnConnectionRetry(uint32 attempts) { + TunsafeBackendWin32 *backend = backend_; + if (backend->status() == TunsafeBackend::kStatusInitializing) + backend->SetStatus(TunsafeBackend::kStatusConnecting); + else if (attempts >= 3 && backend->status() == TunsafeBackend::kStatusConnected) + backend->SetStatus(TunsafeBackend::kStatusReconnecting); +} + + +bool TunsafeRunner::Start() { + wg_proc_.dev().SetCurrentThreadAsMainThread(); + + if (config_file_.size()) { + if (config_file_is_text_format_) { + if (!ParseWireGuardConfigString(&wg_proc_, config_file_.c_str(), config_file_.size(), &backend_->dns_resolver_)) + return false; + } else { + if (!ParseWireGuardConfigFile(&wg_proc_, config_file_.c_str(), &backend_->dns_resolver_)) + return false; + } + } + if (wg_proc_.dev().packet_obfuscator().enabled()) + packet_processor_.EnableDeobfuscation(); + + if (!wg_proc_.Start()) + return false; + + backend_->SetPublicKey(wg_proc_.dev().public_key()); + + net_.StartThread(); + tun_.StartThread(); + int stop_mode = packet_processor_.Run(&wg_proc_, this); + net_.StopThread(); + tun_.StopThread(); + + if (stop_mode != TunsafeBackendWin32::MODE_EXIT) + tun_.adapter().DisassociateDnsBlocker(); + else + backend_->dns_resolver_.ClearCache(); + + return true; +} + +void TunsafeRunner::PostTunRestart() { + QueuedItem *qi = new QueuedItem; + qi->queue_cb = this; + packet_processor_.ForcePost(qi); +} + +void TunsafeRunner::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) { + backend_->SetStatus(TunsafeBackend::kStatusTunRetrying); + RINFO("Restarting TUN adapter"); + Sleep(1000); + wg_proc_.ConfigureTun(); + delete ow; +} + +void TunsafeRunner::OnQueuedItemDelete(QueuedItem *ow) { + + delete ow; +} + + +void TunsafeRunner::OnRequestToken(WgPeer *peer, uint32 type) { + backend_->OnRequestToken(peer, type); +} + +void TunsafeRunner::CollectStats() { + backend_->CollectStats(); +} + +void TunsafeRunner::SetConfigFile(const char *file, bool is_text_format) { + config_file_is_text_format_ = is_text_format; + config_file_ = file; +} + +////////////////////////////////////////////////////////////////////////////// + + TunsafeBackend::TunsafeBackend() { is_started_ = false; is_remote_ = false; @@ -1946,8 +2041,8 @@ static void RemoveKillSwitchRoute() { TunsafeBackendWin32::TunsafeBackendWin32(Delegate *delegate) : delegate_(delegate), dns_resolver_(&dns_blocker_) { memset(&stats_, 0, sizeof(stats_)); - wg_processor_ = NULL; token_request_ = 0; + runner_ = NULL; InitPacketMutexes(); worker_thread_ = NULL; last_tun_adapter_failed_ = 0; @@ -1972,79 +2067,28 @@ void TunsafeBackendWin32::SetPublicKey(const uint8 key[32]) { delegate_->OnStateChanged(); } -struct PluginHolder { - PluginHolder(PluginDelegate *del) : plugin(CreateTunsafePlugin(del)) {} - ~PluginHolder() { delete plugin; } - TunsafePlugin *plugin; -}; +void TunsafeBackendWin32::CollectStats() { + stats_mutex_.Acquire(); + stats_ = runner_->wg_proc_.GetStats(); + float data[2] = { + // unit is megabits/second + stats_.tun_bytes_in_per_second * (1.0f / 125000), + stats_.tun_bytes_out_per_second * (1.0f / 125000), + }; + stats_collector_.AddSamples(data); + stats_mutex_.Release(); + + delegate_->OnGraphAvailable(); + PushStats(); +} DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) { TunsafeBackendWin32 *backend = (TunsafeBackendWin32*)bk; - int stop_mode; - int fast_retry_ctr = 0; - - for (;;) { - TunWin32Iocp tun(&backend->dns_blocker_, backend); - NetworkWin32 net; - PluginHolder plugin(backend); - WireguardProcessor wg_proc(&net, &tun, backend); - wg_proc.dev().SetPlugin(plugin.plugin); - plugin.plugin->Initialize(&wg_proc); - - net.udp().SetPacketHandler(&backend->packet_processor_); - net.tcp_socket_queue().SetPacketHandler(&backend->packet_processor_); - - tun.SetPacketHandler(&backend->packet_processor_); - - if (backend->config_file_[0] && - !ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->dns_resolver_)) - goto getout_fail; - - if (!wg_proc.Start()) - goto getout_fail; - - backend->SetPublicKey(wg_proc.dev().public_key()); - - backend->wg_processor_ = &wg_proc; - backend->tunsafe_wg_plugin_ = plugin.plugin; - - net.StartThread(); - tun.StartThread(); - stop_mode = backend->packet_processor_.Run(&wg_proc, backend); - net.StopThread(); - tun.StopThread(); - - backend->wg_processor_ = NULL; - backend->tunsafe_wg_plugin_ = NULL; - - // Keep DNS alive - if (stop_mode != MODE_EXIT) - tun.adapter().DisassociateDnsBlocker(); - else - backend->dns_resolver_.ClearCache(); - - FreeAllPackets(); - - if (stop_mode != MODE_TUN_FAILED) - return 0; - - uint32 last_fail = GetTickCount(); - fast_retry_ctr = (last_fail - backend->last_tun_adapter_failed_ < 5000) ? fast_retry_ctr + 1 : 0; - backend->last_tun_adapter_failed_ = last_fail; - - backend->SetStatus((fast_retry_ctr >= 3) ? TunsafeBackend::kErrorTunPermanent : TunsafeBackend::kStatusTunRetrying); - - if (backend->status_ == TunsafeBackend::kErrorTunPermanent) { - RERROR("Too many automatic restarts..."); - goto getout_fail_noseterr; - } - Sleep(1000); + + if (!backend->runner_->Start()) { + backend->SetStatus(TunsafeBackend::kErrorInitialize); + backend->dns_blocker_.RestoreDns(); } -getout_fail: - backend->status_ = TunsafeBackend::kErrorInitialize; - backend->delegate_->OnStatusCode(TunsafeBackend::kErrorInitialize); -getout_fail_noseterr: - backend->dns_blocker_.RestoreDns(); return 0; } @@ -2104,29 +2148,37 @@ void TunsafeBackendWin32::Start(const char *config_file) { SetStatus(kStatusInitializing); delegate_->OnClearLog(); DWORD thread_id; - config_file_ = _strdup(config_file); + + runner_ = new TunsafeRunner(this); + + // Connect to a server given by an ID. + if (strncmp(config_file, ":srv:", 5) == 0) { +// config_file_is_text_format_ = true; +// auto server = GetServerById(config_file + 5, NULL); +// config_file_ = GetServerConfigFile(server); + } else { + runner_->SetConfigFile(config_file, false); + } + worker_thread_ = CreateThread(NULL, 0, &WorkerThread, this, 0, &thread_id); SetThreadPriority(worker_thread_, THREAD_PRIORITY_ABOVE_NORMAL); delegate_->OnStateChanged(); } -void TunsafeBackendWin32::PostExit(int exit_code) { - packet_processor_.PostExit(exit_code); -} void TunsafeBackendWin32::StopInner(bool is_restart) { - if (worker_thread_) { + if (runner_) { ipv4_ip_ = 0; dns_resolver_.Cancel(); - PostExit(is_restart ? MODE_RESTART : MODE_EXIT); + runner_->packet_processor_.PostExit(is_restart ? MODE_RESTART : MODE_EXIT); WaitForSingleObject(worker_thread_, INFINITE); - CloseHandle(worker_thread_); - worker_thread_ = NULL; - free(config_file_); - config_file_ = NULL; + CloseHandle(exch_null(worker_thread_)); is_started_ = false; status_ = kStatusStopped; - packet_processor_.Reset(); + delete runner_; + runner_ = NULL; + + FreeAllPackets(); uint8 wanted_ibs = (g_killswitch_currconn == kBlockInternet_Default) ? g_killswitch_want : g_killswitch_currconn; if (!is_restart && !(wanted_ibs & kBlockInternet_BlockOnDisconnect)) @@ -2228,9 +2280,9 @@ void ConfigQueueItem::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) { if (type == SendConfigurationProtocolPacket) { std::string reply; WgConfig::HandleConfigurationProtocolMessage(context->wg, std::move(message), &reply); - context->backend->delegate_->OnConfigurationProtocolReply(ident, std::move(reply)); + context->runner->backend()->delegate_->OnConfigurationProtocolReply(ident, std::move(reply)); } else { - context->backend->tunsafe_wg_plugin_->SubmitToken((const uint8*)message.data(), message.size()); + context->runner->plugin()->SubmitToken((const uint8*)message.data(), message.size()); } delete this; } @@ -2240,24 +2292,27 @@ void ConfigQueueItem::OnQueuedItemDelete(QueuedItem *ow) { } void TunsafeBackendWin32::SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) { - ConfigQueueItem *queue_item = new ConfigQueueItem; - queue_item->type = ConfigQueueItem::SendConfigurationProtocolPacket; - queue_item->ident = identifier; - queue_item->message = std::move(message); - queue_item->queue_cb = queue_item; - packet_processor_.ForcePost(queue_item); + if (runner_) { + ConfigQueueItem *queue_item = new ConfigQueueItem; + queue_item->type = ConfigQueueItem::SendConfigurationProtocolPacket; + queue_item->ident = identifier; + queue_item->message = std::move(message); + queue_item->queue_cb = queue_item; + runner_->packet_processor_.ForcePost(queue_item); + } } void TunsafeBackendWin32::SubmitToken(const std::string &&message) { - // Clear out the old token request so GetTokenRequest returns zero. - token_request_ = 0; - - ConfigQueueItem *queue_item = new ConfigQueueItem; - queue_item->type = ConfigQueueItem::SubmitToken; - queue_item->message = std::move(message); - queue_item->queue_cb = queue_item; - packet_processor_.ForcePost(queue_item); - + if (runner_) { + // Clear out the old token request so GetTokenRequest returns zero. + token_request_ = 0; + + ConfigQueueItem *queue_item = new ConfigQueueItem; + queue_item->type = ConfigQueueItem::SubmitToken; + queue_item->message = std::move(message); + queue_item->queue_cb = queue_item; + runner_->packet_processor_.ForcePost(queue_item); + } } uint32 TunsafeBackendWin32::GetTokenRequest() { @@ -2271,32 +2326,6 @@ void TunsafeBackendWin32::OnRequestToken(WgPeer *peer, uint32 type) { delegate_->OnStateChanged(); } - -void TunsafeBackendWin32::OnConnected() { - if (status_ != TunsafeBackend::kStatusConnected) { - const WgCidrAddr *ipv4_addr = NULL; - for (const WgCidrAddr &x : wg_processor_->addr()) { - if (x.size == 32) { - ipv4_addr = &x; - break; - } - } - ipv4_ip_ = ipv4_addr ? ReadBE32(ipv4_addr->addr) : 0; - if (status_ != TunsafeBackend::kStatusReconnecting) { - char buf[kSizeOfAddress]; - RINFO("Connection established. IP %s", ipv4_addr ? print_ip_prefix(buf, AF_INET, ipv4_addr->addr, -1) : "(none)"); - } - SetStatus(TunsafeBackend::kStatusConnected); - } -} - -void TunsafeBackendWin32::OnConnectionRetry(uint32 attempts) { - if (status_ == TunsafeBackend::kStatusInitializing) - SetStatus(TunsafeBackend::kStatusConnecting); - else if (attempts >= 3 && status_ == TunsafeBackend::kStatusConnected) - SetStatus(TunsafeBackend::kStatusReconnecting); -} - void TunsafeBackend::Delegate::DoWork() { // implemented by subclasses } diff --git a/network_win32.h b/network_win32.h index 693beb3..b6a487e 100644 --- a/network_win32.h +++ b/network_win32.h @@ -18,6 +18,7 @@ enum { class WireguardProcessor; class TunsafeBackendWin32; +class TunsafeRunner; class DnsBlocker; struct PacketProcessorTunCb : QueuedItemCallback { @@ -41,7 +42,7 @@ public: void Reset(); - int Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend); + int Run(WireguardProcessor *wg, TunsafeRunner *runner); void PostPackets(Packet *first, Packet **end, int count); void ForcePost(QueuedItem *item); void PostExit(int exit_code); @@ -62,7 +63,7 @@ public: struct QueueContext { WireguardProcessor *wg; - TunsafeBackendWin32 *backend; + TunsafeRunner *runner; bool overload; }; @@ -142,32 +143,27 @@ private: Packet *finished_reads_, **finished_reads_end_; int finished_reads_count_; - __declspec(align(64)) uint32 qsize1_; - __declspec(align(64)) uint32 qsize2_; + uint32 qsize1_; + uint8 align[64-4]; + uint32 qsize2_; }; // Holds the thread for network communications -class NetworkWin32 : public UdpInterface { +class NetworkWin32 { friend class UdpSocketWin32; friend class TcpSocketWin32; friend class TcpSocketQueue; public: explicit NetworkWin32(); ~NetworkWin32(); - + void StartThread(); void StopThread(); UdpSocketWin32 &udp() { return udp_socket_; } SimplePacketPool &packet_pool() { return packet_pool_; } - TcpSocketQueue &tcp_socket_queue() { return tcp_socket_queue_; } void WakeUp(); void PostQueuedItem(QueuedItem *item); - - // -- from UdpInterface - virtual bool Configure(int listen_port_udp, int listen_port_tcp) override; - virtual void WriteUdpPacket(Packet *packet) override; - private: void ThreadMain(); static DWORD WINAPI NetworkThread(void *x); @@ -190,8 +186,6 @@ private: TcpSocketWin32 *tcp_socket_; SimplePacketPool packet_pool_; - - TcpSocketQueue tcp_socket_queue_; }; class TunWin32Adapter { @@ -199,7 +193,7 @@ public: TunWin32Adapter(DnsBlocker *dns_blocker, const char guid[ADAPTER_GUID_SIZE]); ~TunWin32Adapter(); - bool OpenAdapter(TunsafeBackendWin32 *backend, DWORD open_flags); + bool OpenAdapter(TunsafeRunner *backend, DWORD open_flags); bool ConfigureAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out); void CloseAdapter(bool is_restart); @@ -233,7 +227,7 @@ private: // Implementation of TUN interface handling using IO Completion Ports class TunWin32Iocp : public TunInterface { public: - explicit TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend); + explicit TunWin32Iocp(DnsBlocker *blocker, TunsafeRunner *backend); ~TunWin32Iocp(); void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; } @@ -266,11 +260,61 @@ private: // All packets queued for writing Packet *wqueue_, **wqueue_end_; - TunsafeBackendWin32 *backend_; + TunsafeRunner *runner_; TunWin32Adapter adapter_; }; -class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate, public PluginDelegate { +// This class is the actual TunSafe thing and runs inside of a thread. +class TunsafeRunner : public UdpInterface, public ProcessorDelegate, public PluginDelegate, public QueuedItemCallback { + friend class TunsafeBackendWin32; +public: + TunsafeRunner(TunsafeBackendWin32 *backend); + ~TunsafeRunner(); + + void SetConfigFile(const char *file, bool is_text_format); + + TunsafeBackendWin32 *backend() { return backend_; } + + // -- from UdpInterface + virtual bool Configure(int listen_port_udp, int listen_port_tcp) override; + virtual void WriteUdpPacket(Packet *packet) override; + + virtual void OnConnected() override; + virtual void OnConnectionRetry(uint32 attempts) override; + + // -- from PluginDelegate + virtual void OnRequestToken(WgPeer *peer, uint32 type) override; + + bool Start(); + + // Called by the tun thing if tun stops working and a reset is needed. + void PostTunRestart(); + + uint32 exit_code() { return *packet_processor_.posted_exit_code(); } + + TunsafePlugin *plugin() { return plugin_; } + + void CollectStats(); + +private: + // From OverlappedCallbacks + virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override; + virtual void OnQueuedItemDelete(QueuedItem *ow) override; + + TunsafeBackendWin32 *backend_; + TunsafePlugin *plugin_; + bool config_file_is_text_format_; + std::string config_file_; + TunWin32Iocp tun_; + NetworkWin32 net_; + TcpSocketQueue tcp_socket_queue_; + WireguardProcessor wg_proc_; + PacketProcessor packet_processor_; +}; + + +class TunsafeBackendWin32 : public TunsafeBackend { + friend class TunsafeRunner; friend class PacketProcessor; friend class TunWin32Iocp; friend class TunWin32Overlapped; @@ -297,52 +341,44 @@ public: virtual uint32 GetTokenRequest() override; virtual void SubmitToken(const std::string &&message) override; - // -- from ProcessorDelegate - virtual void OnConnected() override; - virtual void OnConnectionRetry(uint32 attempts) override; - - // -- from PluginDelegate - virtual void OnRequestToken(WgPeer *peer, uint32 type) override; + void OnRequestToken(WgPeer *peer, uint32 type); void SetPublicKey(const uint8 key[32]); - void PostExit(int exit_code); + + StatusCode status() { return status_; } + void SetStatus(StatusCode status); + + void CollectStats(); + +private: + enum { MODE_NONE = 0, MODE_EXIT = 1, MODE_RESTART = 2, - MODE_TUN_FAILED = 3, }; - uint32 exit_code() { return *packet_processor_.posted_exit_code(); } - - void SetStatus(StatusCode status); -private: void StopInner(bool is_restart); static DWORD WINAPI WorkerThread(void *x); void PushStats(); + TunsafeRunner *runner_; HANDLE worker_thread_; bool want_periodic_stats_; Delegate *delegate_; - char *config_file_; std::atomic token_request_; DnsBlocker dns_blocker_; DnsResolver dns_resolver_; - WireguardProcessor *wg_processor_; - TunsafePlugin *tunsafe_wg_plugin_; - uint32 last_tun_adapter_failed_; StatsCollector stats_collector_; Mutex stats_mutex_; WgProcessorStats stats_; - PacketProcessor packet_processor_; - char guid_[ADAPTER_GUID_SIZE]; }; diff --git a/network_win32_api.h b/network_win32_api.h index 6ae91a2..fbf2bfa 100644 --- a/network_win32_api.h +++ b/network_win32_api.h @@ -54,7 +54,6 @@ public: kStatusTunRetrying = 10, kErrorInitialize = -1, - kErrorTunPermanent = -2, kErrorServiceLost = -3, }; diff --git a/network_win32_tcp.cpp b/network_win32_tcp.cpp index 060c88f..e17df95 100644 --- a/network_win32_tcp.cpp +++ b/network_win32_tcp.cpp @@ -9,18 +9,18 @@ //////////////////////////////////////////////////////////////////////////////////////////////////////// -TcpSocketWin32::TcpSocketWin32(NetworkWin32 *network) - : tcp_packet_handler_(&network->packet_pool()) { +TcpSocketWin32::TcpSocketWin32(NetworkWin32 *network, PacketProcessor *packet_handler, WgPacketObfuscator *obfuscator, bool is_incoming) + : packet_processor_(packet_handler), tcp_packet_handler_(&network->packet_pool(), obfuscator, is_incoming) { network_ = network; reads_active_ = 0; writes_active_ = 0; handshake_attempts = 0; + handshake_timestamp_ = 0; state_ = STATE_NONE; wqueue_ = NULL; wqueue_end_ = &wqueue_; socket_ = INVALID_SOCKET; next_ = NULL; - packet_processor_ = NULL; // insert in network's linked list next_ = network->tcp_socket_; network->tcp_socket_ = this; @@ -45,6 +45,7 @@ void TcpSocketWin32::CloseSocket() { } void TcpSocketWin32::WritePacket(Packet *packet) { + packet->prepared = false; packet->queue_next = NULL; *wqueue_end_ = packet; wqueue_end_ = &Packet_NEXT(packet); @@ -145,7 +146,9 @@ void TcpSocketWin32::DoMoreWrites() { return; do { - tcp_packet_handler_.AddHeaderToOutgoingPacket(p); + if (!p->prepared) + tcp_packet_handler_.PrepareOutgoingPackets(p); + wsabuf[num_wsabuf].buf = (char*)p->data; wsabuf[num_wsabuf].len = (ULONG)p->size; packets_in_write_io_[num_wsabuf] = p; @@ -179,8 +182,7 @@ void TcpSocketWin32::DoIO() { while (Packet *p = tcp_packet_handler_.GetNextWireguardPacket()) { p->protocol = endpoint_protocol_; p->addr = endpoint_; - - p->queue_cb = packet_processor_->udp_queue(); + p->queue_cb = packet_processor_->tcp_queue(); packet_processor_->ForcePost(p); } if (tcp_packet_handler_.error()) { @@ -269,63 +271,73 @@ void TcpSocketWin32::OnQueuedItemDelete(QueuedItem *qi) { ///////////////////////////////////////////////////////////////////////// -TcpSocketQueue::TcpSocketQueue(NetworkWin32 *network) { +TcpSocketQueue::TcpSocketQueue(NetworkWin32 *network, WgPacketObfuscator *obfuscator) { network_ = network; wqueue_ = NULL; wqueue_end_ = &wqueue_; queued_item_.queue_cb = this; packet_handler_ = NULL; + obfuscator_ = obfuscator; } TcpSocketQueue::~TcpSocketQueue() { FreePacketList(wqueue_); } -void TcpSocketQueue::TransmitOnePacket(Packet *packet) { - // Check if we have a tcp connection for the endpoint, otherwise create one. - for (TcpSocketWin32 *tcp = network_->tcp_socket_; tcp; tcp = tcp->next_) { - // After we send 3 handshakes on a tcp socket in a row, then close and reopen the socket because it seems defunct. - if (CompareIpAddr(&tcp->endpoint_, &packet->addr) == 0 && tcp->endpoint_protocol_ == packet->protocol) { - if (ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION) { - if (tcp->handshake_attempts == 2) { - RINFO("Making new Tcp socket due to too many handshake failures"); - tcp->CloseSocket(); - break; +void TcpSocketQueue::TransmitPackets(Packet *packet) { +AGAIN: + while (packet) { + bool is_handshake = ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION; + + // Check if we have a tcp connection for the endpoint, otherwise create one. + for (TcpSocketWin32 *tcp = network_->tcp_socket_; tcp; tcp = tcp->next_) { + // After we send 3 handshakes on a tcp socket in a row within a minute, + // then close and reopen the socket because it seems defunct. + if (CompareIpAddr(&tcp->endpoint_, &packet->addr) == 0 && tcp->endpoint_protocol_ == packet->protocol) { + if (is_handshake) { + uint32 now = (uint32)OsGetMilliseconds(); + uint32 secs = (now - tcp->handshake_timestamp_) >> 10; + tcp->handshake_timestamp_ += secs * 1024; + int calc = (secs > (uint32)tcp->handshake_attempts + 25) ? 0 : tcp->handshake_attempts + 25 - secs; + tcp->handshake_attempts = calc; + if (calc >= 60) { + RINFO("Making new Tcp socket due to too many handshake failures"); + tcp->CloseSocket(); + break; + } } - tcp->handshake_attempts++; - } else { - tcp->handshake_attempts = -1; + tcp->WritePacket(exch(packet, Packet_NEXT(packet))); + goto AGAIN; } - tcp->WritePacket(packet); - return; } + + // Drop tcp packet that's for an incoming connection, or packets that are + // not a handshake. + if ((packet->protocol & kPacketProtocolIncomingConnection) || !is_handshake) { + FreePacket(exch(packet, Packet_NEXT(packet))); + continue; + } + + // Initialize a new tcp socket and connect to the endpoint + TcpSocketWin32 *tcp = new TcpSocketWin32(network_, packet_handler_, obfuscator_, false); + tcp->state_ = TcpSocketWin32::STATE_WANT_CONNECT; + tcp->endpoint_ = packet->addr; + tcp->endpoint_protocol_ = kPacketProtocolTcp; + tcp->handshake_timestamp_ = (uint32)OsGetMilliseconds(); + tcp->WritePacket(exch(packet, Packet_NEXT(packet))); } - // Drop tcp packet that's for an incoming connection, or packets that are - // not a handshake. - if ((packet->protocol & kPacketProtocolIncomingConnection) || - packet->size < 4 || ReadLE32(packet->data) != MESSAGE_HANDSHAKE_INITIATION) { - FreePacket(packet); - return; - } - - // Initialize a new tcp socket and connect to the endpoint - TcpSocketWin32 *tcp = new TcpSocketWin32(network_); - tcp->state_ = TcpSocketWin32::STATE_WANT_CONNECT; - tcp->endpoint_ = packet->addr; - tcp->endpoint_protocol_ = kPacketProtocolTcp; - tcp->SetPacketHandler(packet_handler_); - tcp->WritePacket(packet); } void TcpSocketQueue::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) { + // Runs on the network thread wqueue_mutex_.Acquire(); Packet *packet = wqueue_; wqueue_ = NULL; wqueue_end_ = &wqueue_; wqueue_mutex_.Release(); - while (packet) - TransmitOnePacket(exch(packet, Packet_NEXT(packet))); + + TransmitPackets(packet); } void TcpSocketQueue::OnQueuedItemDelete(QueuedItem *ow) { diff --git a/network_win32_tcp.h b/network_win32_tcp.h index 0808f98..605f177 100644 --- a/network_win32_tcp.h +++ b/network_win32_tcp.h @@ -9,19 +9,16 @@ class NetworkWin32; class PacketProcessor; +class WgPacketObfuscator; class TcpSocketWin32 : public QueuedItemCallback { friend class NetworkWin32; friend class TcpSocketQueue; public: - explicit TcpSocketWin32(NetworkWin32 *network); + explicit TcpSocketWin32(NetworkWin32 *network, PacketProcessor *packet_handler, WgPacketObfuscator *obfuscator, bool is_incoming); ~TcpSocketWin32(); - void SetPacketHandler(PacketProcessor *packet_handler) { packet_processor_ = packet_handler; } - - // Write a packet to the TCP socket. This may be called only from the - // wireguard thread. Will append to a buffer and schedule it to be written - // from the network thread. + // Write a packet to the TCP socket. void WritePacket(Packet *packet); // Call from IO completion thread to cancel all outstanding IO @@ -37,7 +34,6 @@ private: void DoMoreReads(); void DoMoreWrites(); void DoConnect(); - void CloseSocket(); // From OverlappedCallbacks @@ -64,6 +60,8 @@ private: public: uint8 handshake_attempts; + uint8 endpoint_protocol_; + uint32 handshake_timestamp_; private: // The handle to the socket @@ -86,7 +84,6 @@ private: QueuedItem connect_overlapped_; IpAddr endpoint_; - uint8 endpoint_protocol_; // Packets currently involved in the wsabuf writing enum { kMaxWsaBuf = 32 }; @@ -95,7 +92,7 @@ private: class TcpSocketQueue : public QueuedItemCallback { public: - explicit TcpSocketQueue(NetworkWin32 *network); + explicit TcpSocketQueue(NetworkWin32 *network, WgPacketObfuscator *obfusctor); ~TcpSocketQueue(); void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; } @@ -106,7 +103,7 @@ public: void WritePacket(Packet *packet); private: - void TransmitOnePacket(Packet *packet); + void TransmitPackets(Packet *packet); NetworkWin32 *network_; // All packets queued for writing on the network thread. Locked by |wqueue_mutex_| @@ -114,11 +111,12 @@ private: PacketProcessor *packet_handler_; + WgPacketObfuscator *obfuscator_; + // Protects wqueue_ Mutex wqueue_mutex_; // Used for queueing things on the network instance QueuedItem queued_item_; - }; diff --git a/tunsafe_bsd.cpp b/tunsafe_bsd.cpp index 34cbf93..17f855c 100644 --- a/tunsafe_bsd.cpp +++ b/tunsafe_bsd.cpp @@ -678,8 +678,6 @@ public: WireguardProcessor *processor() { return &processor_; } private: - void WriteTcpPacket(Packet *packet); - // Close all TCP connections that are not pointed to by any of the peer endpoint. void CloseOrphanTcpConnections(); @@ -697,7 +695,7 @@ private: TunsafeBackendBsdImpl::TunsafeBackendBsdImpl() : is_connected_(false), close_orphan_counter_(0), - plugin_(CreateTunsafePlugin(this)), + plugin_(CreateTunsafePlugin(this, &processor_)), processor_(this, this, this), network_(this, 1000), tun_(&network_, &processor_), @@ -732,49 +730,11 @@ bool TunsafeBackendBsdImpl::Configure(int listen_port, int listen_port_tcp) { (listen_port_tcp == 0 || tcp_socket_listener_.Initialize(listen_port_tcp)); } -void TunsafeBackendBsdImpl::WriteTcpPacket(Packet *packet) { - // Check if we have a tcp connection for the endpoint, otherwise create one. - for (TcpSocketBsd *tcp = network_.tcp_sockets(); tcp; tcp = tcp->next()) { - // After we send 3 handshakes on a tcp socket in a row, then close and reopen the socket because it seems defunct. - if (CompareIpAddr(&tcp->endpoint(), &packet->addr) == 0 && tcp->endpoint_protocol() == packet->protocol) { - if (ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION) { - if (tcp->handshake_attempts == 2) { - RINFO("Making new Tcp socket due to too many handshake failures"); - delete tcp; - break; - } - tcp->handshake_attempts++; - } else { - tcp->handshake_attempts = -1; - } - tcp->WritePacket(packet); - return; - } - } - // Drop tcp packet that's for an incoming connection, or packets that are - // not a handshake. - if ((packet->protocol & kPacketProtocolIncomingConnection) || - ReadLE32(packet->data) != MESSAGE_HANDSHAKE_INITIATION) { - FreePacket(packet); - return; - } - // Initialize a new tcp socket and connect to the endpoint - TcpSocketBsd *tcp = new TcpSocketBsd(&network_, &processor_); - if (!tcp || !tcp->InitializeOutgoing(packet->addr)) { - delete tcp; - FreePacket(packet); - return; - } - tcp->WritePacket(packet); -} - void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) { assert((packet->protocol & 0x7F) <= 2); if (packet->protocol & kPacketProtocolTcp) { - WriteTcpPacket(packet); + TcpSocketBsd::WriteTcpPacket(&network_, &processor_, packet); } else { - if (processor_.dev().packet_obfuscator().enabled()) - processor_.dev().packet_obfuscator().ObfuscatePacket(packet); udp_.WritePacket(packet); } } diff --git a/tunsafe_win32.cpp b/tunsafe_win32.cpp index 7d37292..a34db5a 100644 --- a/tunsafe_win32.cpp +++ b/tunsafe_win32.cpp @@ -1187,7 +1187,6 @@ static HFONT CreateFontHelper(int size, byte flags, const char *face, int angle static const char *StatusCodeToString(TunsafeBackend::StatusCode code) { switch (code) { case TunsafeBackend::kErrorInitialize: return "Configuration Error"; - case TunsafeBackend::kErrorTunPermanent: return "TUN Adapter Error"; case TunsafeBackend::kErrorServiceLost: return "Service Lost"; case TunsafeBackend::kStatusStopped: return "Disconnected"; case TunsafeBackend::kStatusInitializing: return "Initializing"; diff --git a/wireguard.cpp b/wireguard.cpp index 60574bd..2a8581f 100644 --- a/wireguard.cpp +++ b/wireguard.cpp @@ -207,6 +207,10 @@ bool WireguardProcessor::ConfigureTun() { for (WgPeer *peer = dev_.first_peer(); peer; peer = peer->next_peer_) { peer->ipv4_broadcast_addr_ = ipv4_broadcast_addr; + + if (peer->endpoint_protocol_ == kPacketProtocolTcp) + peer->allow_endpoint_change_ = false; + if (peer->endpoint_.sin.sin_family != 0) { RINFO("Sending handshake..."); SendHandshakeInitiation(peer); @@ -419,7 +423,7 @@ WireguardProcessor::PacketResult WireguardProcessor::WriteAndEncryptPacketToUdp_ uint64 send_ctr; // Ensure packet will fit including the biggest padding - if (peer->endpoint_.sin.sin_family == 0 || + if (peer->data_endpoint_.sin.sin_family == 0 || size > kPacketCapacity - 15 - CHACHA20POLY1305_AUTHTAGLEN) goto getout_discard; @@ -443,8 +447,8 @@ WireguardProcessor::PacketResult WireguardProcessor::WriteAndEncryptPacketToUdp_ want_handshake = (send_ctr >= REKEY_AFTER_MESSAGES || keypair->send_key_state == WgKeypair::KEY_WANT_REFRESH); keypair->send_ctr = send_ctr + 1; - packet->addr = peer->endpoint_; - packet->protocol = peer->endpoint_protocol_; + packet->addr = peer->data_endpoint_; + packet->protocol = peer->data_endpoint_protocol_; WG_EXTENSION_HOOKS::OnPeerOutgoingUdp(peer, packet); @@ -639,7 +643,9 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) { if (attempts >= 3 && peer->allow_endpoint_change_ && (peer->endpoint_protocol_ & kPacketProtocolIncomingConnection)) { peer->endpoint_protocol_ = 0; + peer->data_endpoint_protocol_ = 0; peer->endpoint_.sin.sin_family = 0; + peer->data_endpoint_.sin.sin_family = 0; } WG_RELEASE_LOCK(peer->mutex_); @@ -840,15 +846,23 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack assert(packet->addr.sin.sin_family != 0); // Remember the endpoint of the peer - if (peer->allow_endpoint_change_ && - (CompareIpAddr(&peer->endpoint_, &packet->addr) | (peer->endpoint_protocol_ ^ packet->protocol)) != 0) { + if (peer->allow_endpoint_change_ && + (CompareIpAddr(&peer->data_endpoint_, &packet->addr) | (peer->data_endpoint_protocol_ ^ packet->protocol)) != 0) { + #if WITH_SHORT_HEADERS // When the endpoint changes, forget about using the short key. keypair->broadcast_short_key = 0; keypair->can_use_short_key_for_outgoing = false; #endif // WITH_SHORT_HEADERS - peer->endpoint_ = packet->addr; - peer->endpoint_protocol_ = packet->protocol; + + peer->data_endpoint_ = packet->addr; + peer->data_endpoint_protocol_ = packet->protocol; + + // In the hybrid tcp mode, only the data endpoint gets overwritten on incoming data packets. + if (!keypair->enabled_features[WG_FEATURE_HYBRID_TCP]) { + peer->endpoint_ = packet->addr; + peer->endpoint_protocol_ = packet->protocol; + } } WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet); diff --git a/wireguard_config.cpp b/wireguard_config.cpp index b1f8226..c968f0b 100644 --- a/wireguard_config.cpp +++ b/wireguard_config.cpp @@ -56,12 +56,14 @@ static int ParseFeature(const char *str) { } if (len == 5 && memcmp(str, "mac64", 5) == 0) return what + WG_FEATURE_ID_SHORT_MAC * 16; - if (len == 12 && memcmp(str, "short_header", 12) == 0) - return what + WG_FEATURE_ID_SHORT_HEADER * 16; if (len == 5 && memcmp(str, "ipzip", 5) == 0) return what + WG_FEATURE_ID_IPZIP * 16; + if (len == 10 && memcmp(str, "hybrid_tcp", 10) == 0) + return what + WG_FEATURE_HYBRID_TCP * 16; if (len == 10 && memcmp(str, "skip_keyid", 10) == 0) return what + WG_FEATURE_ID_SKIP_KEYID_IN * 16 + 1 * 4; + if (len == 12 && memcmp(str, "short_header", 12) == 0) + return what + WG_FEATURE_ID_SHORT_HEADER * 16; if (len == 13 && memcmp(str, "skip_keyid_in", 13) == 0) return what + WG_FEATURE_ID_SKIP_KEYID_IN * 16; if (len == 14 && memcmp(str, "skip_keyid_out", 14) == 0) @@ -169,8 +171,22 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) { } wg_->SetInternetBlocking((InternetBlockState)v); - } else if (strcmp(key, "HeaderObfuscation") == 0) { + } else if (strcmp(key, "ObfuscateKey") == 0) { wg_->dev().packet_obfuscator().SetKey((uint8*)value, strlen(value)); + } else if (strcmp(key, "ObfuscateTCP") == 0) { + bool flag; + int v = 1; + if (ParseBoolean(value, &flag)) { + v = flag; + } else if (strcmp(value, "tls-firefox") == 0) { + v = 2; + } else if (strcmp(value, "tls-chrome") == 0) { + v = 3; + } else if (*value != 0) { + RERROR("Unknown mode in ObfuscateTCP: %s", value); + } + wg_->dev().packet_obfuscator().set_obfuscate_tcp(v); + } else if (strcmp(key, "PostUp") == 0) { wg_->prepost().post_up.emplace_back(value); } else if (strcmp(key, "PostDown") == 0) { diff --git a/wireguard_proto.cpp b/wireguard_proto.cpp index ca053df..51b37b7 100644 --- a/wireguard_proto.cpp +++ b/wireguard_proto.cpp @@ -334,7 +334,9 @@ WgPeer::WgPeer(WgDevice *dev) { assert(dev->IsMainThread()); dev_ = dev; endpoint_.sin.sin_family = 0; + data_endpoint_.sin.sin_family = 0; endpoint_protocol_ = 0; + data_endpoint_protocol_ = 0; next_peer_ = NULL; peer_extra_data_ = NULL; curr_keypair_ = next_keypair_ = prev_keypair_ = NULL; @@ -685,6 +687,11 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { WG_ACQUIRE_LOCK(peer->mutex_); peer->rx_bytes_ += packet->size; if (keypair != NULL) { + // The server side needs to remember the endpoint on incoming handshakes. + if (peer->allow_endpoint_change_ && keypair->enabled_features[WG_FEATURE_HYBRID_TCP]) { + peer->endpoint_ = packet->addr; + peer->endpoint_protocol_ = packet->protocol; + } peer->InsertKeypairInPeer_Locked(keypair); peer->OnHandshakeAuthComplete(); } @@ -772,8 +779,19 @@ WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packe WG_ACQUIRE_LOCK(peer->mutex_); if (peer->allow_endpoint_change_) { - peer->endpoint_ = packet->addr; + // TODO: Why is this needed, if we are able to get a response for the handshake init + // packet then we already know its endpoint? peer->endpoint_protocol_ = packet->protocol; + peer->endpoint_ = packet->addr; + if (!keypair->enabled_features[WG_FEATURE_HYBRID_TCP] || !peer->IsTransientDataEndpointActive()) { + peer->data_endpoint_protocol_ = peer->endpoint_protocol_; + peer->data_endpoint_ = peer->endpoint_; + } + // If hybrid tcp mode was enabled for the connection, switch + // the data endpoint to the udp endpoint. + } else if (peer->endpoint_protocol_ == kPacketProtocolTcp) { + peer->data_endpoint_protocol_ = keypair->enabled_features[WG_FEATURE_HYBRID_TCP] ? kPacketProtocolUdp : kPacketProtocolTcp; + peer->data_endpoint_ = peer->endpoint_; } WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet); @@ -1065,11 +1083,15 @@ enum { TIMER_ZERO_KEYS = 3, // Timer for sending a keepalive packet every PERSISTENT_KEEPALIVE_MS TIMER_PERSISTENT_KEEPALIVE = 4, + // Timer for removing the transient UDP endpoint in hybrid TCP mode after 10 seconds + TIMER_HYBRID_TCP = 5, + + TIMERS_COUNT = 6, }; -#define WgClearTimer(x) (timers_ &= ~(33 << x)) -#define WgIsTimerActive(x) (timers_ & (33 << x)) -#define WgSetTimer(x) (timers_ |= (32 << (x))) +#define WgClearTimer(x) (timers_ &= ~(((1<> 5); - t &= 0x1F; + if (t & (((1 << TIMERS_COUNT) - 1) << TIMERS_COUNT)) { + if (t & (1 << (TIMERS_COUNT+0))) timer_value_[0] = now32; + if (t & (1 << (TIMERS_COUNT+1))) timer_value_[1] = now32; + if (t & (1 << (TIMERS_COUNT+2))) timer_value_[2] = now32; + if (t & (1 << (TIMERS_COUNT+3))) timer_value_[3] = now32; + if (t & (1 << (TIMERS_COUNT+4))) timer_value_[4] = now32; + if (t & (1 << (TIMERS_COUNT+5))) timer_value_[5] = now32; + t |= (t >> TIMERS_COUNT); + t &= (1 << TIMERS_COUNT) - 1; } // Got any expired timers? - if (t & 0x1F) { + if (t & ((1 << TIMERS_COUNT) - 1)) { if ((t & (1 << TIMER_RETRANSMIT_HANDSHAKE)) && (now32 - timer_value_[TIMER_RETRANSMIT_HANDSHAKE]) >= REKEY_TIMEOUT_MS) { t ^= (1 << TIMER_RETRANSMIT_HANDSHAKE); if (handshake_attempts_ > MAX_HANDSHAKE_ATTEMPTS || endpoint_.sin.sin_family == 0) { @@ -1212,6 +1237,16 @@ uint32 WgPeer::CheckTimeouts_Locked(uint64 now) { ClearKeys_Locked(); ClearHandshake_Locked(); } + + if ((t & (1 << TIMER_HYBRID_TCP)) && (now32 - timer_value_[TIMER_HYBRID_TCP]) >= HYBRID_TCP_TIMEOUT_MS) { + t &= ~(1 << TIMER_HYBRID_TCP); + // Forget about the data endpoint and switch to using the regular endpoint after 15 seconds. + if (allow_endpoint_change_) { + data_endpoint_protocol_ = endpoint_protocol_; + data_endpoint_ = endpoint_; + } + } + } timers_ = t; return rv; @@ -1261,9 +1296,15 @@ void WgPeer::CheckAndUpdateTimeOfNextKeyEvent(uint64 now) { time_of_next_key_event_ = next_time; } +bool WgPeer::IsTransientDataEndpointActive() { + return WgIsTimerActive(TIMER_HYBRID_TCP) != 0; +} + void WgPeer::SetEndpoint(int endpoint_proto, const IpAddr &sin) { endpoint_protocol_ = endpoint_proto; + data_endpoint_protocol_ = endpoint_proto; endpoint_ = sin; + data_endpoint_ = sin; } bool WgPeer::SetPersistentKeepalive(int persistent_keepalive_secs) { @@ -1484,18 +1525,19 @@ size_t WgPublicKeyHasher::operator()(const WgPublicKey&a) const { // This scrambles the initial 16 bytes of the packet with the // last 8 bytes of the packet as a seed. void WgPacketObfuscator::ScrambleUnscramble(uint8 *data, size_t data_size) { + assert(data_size >= 16); + uint64 last_uint64 = ReadLE64(data + data_size - 8); uint64 a = siphash_u64_u32(last_uint64, (uint32)data_size, (siphash_key_t*)&key_[0]); uint64 b = siphash_u64_u32(last_uint64, (uint32)data_size, (siphash_key_t*)&key_[2]); - a = ToLE64(a); + ((uint64*)data)[0] ^= ToLE64(a); b = ToLE64(b); if (data_size >= 24) { - ((uint64*)data)[0] ^= a; ((uint64*)data)[1] ^= b; } else { - uint64 d[2] = { a, b }; - for (size_t i = 0; i < data_size - 8; i++) - data[i] ^= ((uint8*)d)[i]; + uint64 d[1] = { b }; + for (size_t i = 0; i < data_size - 16; i++) + data[i + 8] ^= ((uint8*)d)[i]; } } @@ -1524,12 +1566,11 @@ void WgPacketObfuscator::ObfuscatePacket(Packet *packet) { // in the 3:rd byte of the packet. uint32 packet_type = ReadLE32(data); if ((packet_type == 4 && data_size <= 32) || packet_type < 4) { - if (packet_type != 4) { - // The 39:th and 43:rd bytes often have zero MSB because of curve25519 pubkey, - // so xor them with something in the header. - assert(data_size >= 44); - data[39] ^= data[12]; - data[43] ^= data[12]; + // The 39:th (for handshake init) and 43:rd byte (for handshake response) + // have zero MSB because of curve25519 pubkey, so xor it with random. + if (packet_type < 4) { + assert(data_size >= 48); + data[35 + packet_type * 4] ^= data[15]; } packet->size = data_size = InsertRandomBytesIntoPacket(data, data_size); } @@ -1552,16 +1593,14 @@ void WgPacketObfuscator::DeobfuscatePacket(Packet *packet) { // Check whether the packet type field says that we have // extra bytes appended at the end. if (data[0] <= 4) { - if (data[0] < 4 && data_size >= 44) { - // The 39:th and 43:rd bytes often have zero MSB because of curve25519 pubkey, - // so xor them with something in the header. - data[39] ^= data[12]; - data[43] ^= data[12]; - } - if (data[3] <= data_size) { - packet->size = (uint32)(data_size - data[3]); - data[3] = 0; - } + if (data[3] > data_size) + return; // invalid + packet->size = (uint32)(data_size -= data[3]); + data[3] = 0; + // The 39:th (for handshake init) and 43:rd byte (for handshake response) + // have zero MSB because of curve25519 pubkey, so xor it with random. + if (data[0] < 4 && data_size >= 48) + data[35 + data[0] * 4] ^= data[15]; } } diff --git a/wireguard_proto.h b/wireguard_proto.h index 4ffbac9..5cdd9e1 100644 --- a/wireguard_proto.h +++ b/wireguard_proto.h @@ -62,7 +62,12 @@ enum ProtocolTimeouts { REJECT_AFTER_TIME_MS = 180000, MIN_HANDSHAKE_INTERVAL_MS = 20, - MAX_SIZE_OF_HANDSHAKE_EXTENSION = 1024, + HYBRID_TCP_TIMEOUT_MS = 15000, + + // Chosen so that 1500 - 28 - sizeof(handshakeresponse) which means + // we can use this to probe mtu. + MAX_SIZE_OF_HANDSHAKE_EXTENSION = 1380, + }; enum ProtocolLimits { @@ -179,12 +184,13 @@ enum { }; enum { - WG_FEATURES_COUNT = 6, + WG_FEATURES_COUNT = 7, WG_FEATURE_ID_SHORT_HEADER = 0, // Supports short headers WG_FEATURE_ID_SHORT_MAC = 1, // Supports 8-byte MAC WG_FEATURE_ID_IPZIP = 2, // Using ipzip WG_FEATURE_ID_SKIP_KEYID_IN = 4, // Skip keyid for incoming packets WG_FEATURE_ID_SKIP_KEYID_OUT = 5, // Skip keyid for outgoing packets + WG_FEATURE_HYBRID_TCP = 6, // Use hybrid-tcp mode }; enum { @@ -340,7 +346,7 @@ public: // including adding random bytes at the end of the non-data packets. class WgPacketObfuscator { public: - WgPacketObfuscator() : enabled_(false) {} + WgPacketObfuscator() : enabled_(false), obfuscate_tcp_(-1) {} bool enabled() { return enabled_; } void ObfuscatePacket(Packet *packet); @@ -350,6 +356,9 @@ public: const uint8 *key() { return (uint8*)key_; } + int obfuscate_tcp() { return obfuscate_tcp_; } + void set_obfuscate_tcp(int v) { obfuscate_tcp_ = v; } + static size_t InsertRandomBytesIntoPacket(uint8 *data, size_t data_size); private: @@ -358,6 +367,9 @@ private: // Whether packet obfuscation is enabled bool enabled_; + // Type of obfuscation for tcp + int obfuscate_tcp_; + // Siphash keys for packet scrambling uint64 key_[4]; }; @@ -395,6 +407,8 @@ public: WgRateLimit *rate_limiter() { return &rate_limiter_; } bool is_private_key_initialized() { return is_private_key_initialized_; } + void SetCurrentThreadAsMainThread() { main_thread_id_ = GetCurrentThreadId(); } + bool IsMainThread() { return CurrentThreadIdEquals(main_thread_id_); } bool IsMainOrDataThread() { return CurrentThreadIdEquals(main_thread_id_) || WG_IF_LOCKS_ENABLED_ELSE(delayed_delete_.enabled(), false); } @@ -565,7 +579,7 @@ private: void ClearHandshake_Locked(); void ClearPacketQueue_Locked(); void ScheduleNewHandshake(); - + bool IsTransientDataEndpointActive(); WgDevice *dev_; WgPeer *next_peer_; @@ -582,7 +596,6 @@ private: // For timer management uint32 timers_; - uint32 timer_value_[5]; // Holds the entry into the key id table during handshake - mt only. uint32 local_key_id_during_hs_; @@ -623,7 +636,7 @@ private: uint8 handshake_attempts_; // What's the protocol of the currently configured endpoint - uint8 endpoint_protocol_; + uint8 endpoint_protocol_, data_endpoint_protocol_; // Which features are enabled for this peer? uint8 features_[WG_FEATURES_COUNT]; @@ -632,9 +645,6 @@ private: uint8 num_queued_packets_; Packet *first_queued_packet_, **last_queued_packet_ptr_; - // Address of peer - IpAddr endpoint_; - // For statistics uint64 last_handshake_init_timestamp_; uint64 last_complete_handskake_timestamp_; @@ -642,6 +652,13 @@ private: // Timestamp to detect flooding of handshakes uint64 last_handshake_init_recv_timestamp_; // main thread only + // Address of peer + IpAddr endpoint_; + + // Alternative endpoint. This is used in hybrid tcp mode to hold the + // udp endpoint. + IpAddr data_endpoint_; + // Number of handshake attempts since last successful handshake uint32 total_handshake_attempts_; @@ -653,6 +670,8 @@ private: uint32 keepalive_timeout_ms_; // Set to KEEPALIVE_TIMEOUT_MS + uint32 timer_value_[6]; + uint64 rx_bytes_; uint64 tx_bytes_;