Lots of new features

- Hybrid TCP mode, uses both TCP and UDP
 - Simplified TCP protocol
 - Modified obfuscator to support padding
 - Obfuscation over TCP
 - Refactor parts of Win32 code to be more similar to BSD
This commit is contained in:
Ludvig Strigeus 2018-12-16 16:02:50 +01:00
parent f7b09c43fd
commit 27b75b83de
17 changed files with 1313 additions and 528 deletions

View file

@ -65,40 +65,12 @@ TT LLLLLL LLLLLLLL [Payload LL bytes]
The packet types (TT) currently defined are: The packet types (TT) currently defined are:
TT = 00 = Normal The payload is a normal unmodified WireGuard packet TT = 00 = Normal The payload is a normal unmodified WireGuard packet
including the regular WireGuard header. including the regular WireGuard header.
01 = Reserved
10 = Data A WireGuard data packet (type 04) without the 16 byte 10 = Data A WireGuard data packet (type 04) without the 16 byte
header. The predicted header is prefixed to the payload. header.
11 = Control A TCP control packet. Currently this is used only to setup ?1 = Reserved
the header prediction. See below.
There's only one defined Control packet, type 00 (SetKeyAndCounter): When parsing an incoming Data (TT=10) packet, the Key ID and Counter from
the most recently parsed WireGuard data packet (type 04) is prepended to
0 1 2 3 the payload, with Counter incremented by 1.
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|1 1| Length is 13 (14 bits) | 00 (8 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Key ID (32 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Counter (64 bits) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
This sets up the Key ID and Counter used for the Data packets. Then Counter
is incremented by 1 for every such packet.
For every Data packet, the predicted Key ID and Counter is expanded to a
regular WireGuard data (type 04) header, which is prefixed to the payload:
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| 04 (8 bits) | Reserved (24 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Key ID (32 bits) |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Counter (64 bits) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Data Payload (LL * 8 bits) ...
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
This happens independently in each of the two TCP directions. This happens independently in each of the two TCP directions.

View file

@ -58,6 +58,7 @@ struct Packet : QueuedItem {
byte *data; byte *data;
uint8 userdata; uint8 userdata;
uint8 protocol; // which protocol is this packet for/from uint8 protocol; // which protocol is this packet for/from
bool prepared;
IpAddr addr; // Optionally set to target/source of the packet IpAddr addr; // Optionally set to target/source of the packet
enum { enum {

View file

@ -527,6 +527,10 @@ bool UdpSocketBsd::DoWrite() {
void UdpSocketBsd::WritePacket(Packet *packet) { void UdpSocketBsd::WritePacket(Packet *packet) {
assert(fd_ >= 0); assert(fd_ >= 0);
if (processor_->dev().packet_obfuscator().enabled())
processor_->dev().packet_obfuscator().ObfuscatePacket(packet);
Packet *queue_is_used = udp_queue_; Packet *queue_is_used = udp_queue_;
*udp_queue_end_ = packet; *udp_queue_end_ = packet;
udp_queue_end_ = &Packet_NEXT(packet); udp_queue_end_ = &Packet_NEXT(packet);
@ -824,7 +828,8 @@ void TcpSocketListenerBsd::HandleEvents(int revents) {
int new_fd = accept(fd_, (sockaddr*)&addr, &len); int new_fd = accept(fd_, (sockaddr*)&addr, &len);
if (new_fd >= 0) { if (new_fd >= 0) {
RINFO("Created new tcp socket"); RINFO("Created new tcp socket");
TcpSocketBsd *channel = new TcpSocketBsd(network_, processor_);
TcpSocketBsd *channel = new TcpSocketBsd(network_, processor_, true);
if (channel) if (channel)
channel->InitializeIncoming(new_fd, addr); channel->InitializeIncoming(new_fd, addr);
else else
@ -840,18 +845,61 @@ void TcpSocketListenerBsd::Periodic() {
} }
////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////
TcpSocketBsd::TcpSocketBsd(NetworkBsd *net, WireguardProcessor *processor) void TcpSocketBsd::WriteTcpPacket(NetworkBsd *network, WireguardProcessor *processor, Packet *packet) {
bool is_handshake = ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION;
// Check if we have a tcp connection for the endpoint, otherwise create one.
for (TcpSocketBsd *tcp = network->tcp_sockets(); tcp; tcp = tcp->next()) {
// After we send 3 handshakes on a tcp socket in a row, then close and reopen the socket because it seems defunct.
if (CompareIpAddr(&tcp->endpoint(), &packet->addr) == 0 && tcp->endpoint_protocol() == packet->protocol) {
if (is_handshake) {
uint32 now = (uint32)OsGetMilliseconds();
uint32 secs = (now - tcp->handshake_timestamp_) >> 10;
tcp->handshake_timestamp_ += secs * 1024;
int calc = (secs > (uint32)tcp->handshake_attempts_ + 25) ? 0 : tcp->handshake_attempts_ + 25 - secs;
tcp->handshake_attempts_ = calc;
if (calc >= 60) {
RINFO("Making new Tcp socket due to too many handshake failures");
delete tcp;
break;
}
}
tcp->WritePacket(packet);
return;
}
}
// Drop tcp packet that's for an incoming connection, or packets that are
// not a handshake.
if ((packet->protocol & kPacketProtocolIncomingConnection) || !is_handshake) {
FreePacket(packet);
return;
}
// Initialize a new tcp socket and connect to the endpoint
TcpSocketBsd *tcp = new TcpSocketBsd(network, processor, false);
if (!tcp || !tcp->InitializeOutgoing(packet->addr)) {
delete tcp;
FreePacket(packet);
return;
}
tcp->WritePacket(packet);
}
//////////////////////////////////////////////////////////////////////////////////////////////
TcpSocketBsd::TcpSocketBsd(NetworkBsd *net, WireguardProcessor *processor, bool is_incoming)
: BaseSocketBsd(net), : BaseSocketBsd(net),
readable_(false), readable_(false),
writable_(true), writable_(true),
endpoint_protocol_(0), endpoint_protocol_(0),
age(0), age(0),
handshake_attempts(0), handshake_attempts_(0),
handshake_timestamp_(0),
wqueue_(NULL), wqueue_(NULL),
wqueue_end_(&wqueue_), wqueue_end_(&wqueue_),
wqueue_bytes_(0), wqueue_packets_(0),
processor_(processor), processor_(processor),
tcp_packet_handler_(&net->packet_pool_) { tcp_packet_handler_(&net->packet_pool_, &processor->dev().packet_obfuscator(), is_incoming) {
// insert in network's linked list // insert in network's linked list
next_ = net->tcp_sockets_; next_ = net->tcp_sockets_;
net->tcp_sockets_ = this; net->tcp_sockets_ = this;
@ -908,19 +956,22 @@ bool TcpSocketBsd::InitializeOutgoing(const IpAddr &addr) {
void TcpSocketBsd::WritePacket(Packet *packet) { void TcpSocketBsd::WritePacket(Packet *packet) {
assert(fd_ >= 0); assert(fd_ >= 0);
tcp_packet_handler_.AddHeaderToOutgoingPacket(packet);
Packet *old_value = wqueue_; Packet *old_value = wqueue_;
*wqueue_end_ = packet; *wqueue_end_ = packet;
wqueue_end_ = &Packet_NEXT(packet); wqueue_end_ = &Packet_NEXT(packet);
packet->queue_next = NULL; packet->queue_next = NULL;
packet->prepared = false;
AddToEndLoop(); AddToEndLoop();
wqueue_bytes_ += packet->size; // Note: Cannot use bytes here, because the TCP packet
// headers have not been added yet, and then the
// accounting doesn't work
wqueue_packets_++;
// When many bytes have been queued, perform the write. // When enough packets have been queued up, perform the write.
if (writable_ && wqueue_bytes_ >= 32768) if (writable_ && wqueue_packets_ >= 16)
DoWrite(); DoWrite();
} }
@ -982,10 +1033,19 @@ void TcpSocketBsd::DoWrite() {
struct iovec vecs[kMaxIoWrite]; struct iovec vecs[kMaxIoWrite];
Packet *p = wqueue_; Packet *p = wqueue_;
size_t nvec = 0; size_t nvec = 0;
for (; p && nvec < kMaxIoWrite; nvec++, p = Packet_NEXT(p)) { for (; p && nvec < kMaxIoWrite; p = Packet_NEXT(p)) {
if (!p->prepared)
tcp_packet_handler_.PrepareOutgoingPackets(p);
if (p->size != 0) {
vecs[nvec].iov_base = p->data; vecs[nvec].iov_base = p->data;
vecs[nvec].iov_len = p->size; vecs[nvec].iov_len = p->size;
nvec++;
} }
}
if (nvec == 0)
return;
ssize_t n = writev(fd_, vecs, nvec); ssize_t n = writev(fd_, vecs, nvec);
if (n < 0) { if (n < 0) {
@ -998,9 +1058,7 @@ void TcpSocketBsd::DoWrite() {
} }
return; return;
} }
wqueue_bytes_ -= n;
// discard those initial n bytes worth of packets // discard those initial n bytes worth of packets
size_t i = 0;
p = wqueue_; p = wqueue_;
while (n) { while (n) {
if (n < p->size) { if (n < p->size) {
@ -1009,6 +1067,7 @@ void TcpSocketBsd::DoWrite() {
} }
n -= p->size; n -= p->size;
FreePacket(exch(p, Packet_NEXT(p))); FreePacket(exch(p, Packet_NEXT(p)));
wqueue_packets_--;
} }
if (!(wqueue_ = p)) if (!(wqueue_ = p))
wqueue_end_ = &wqueue_; wqueue_end_ = &wqueue_;

View file

@ -244,7 +244,7 @@ private:
class TcpSocketBsd : public BaseSocketBsd { class TcpSocketBsd : public BaseSocketBsd {
public: public:
explicit TcpSocketBsd(NetworkBsd *bsd, WireguardProcessor *processor); explicit TcpSocketBsd(NetworkBsd *bsd, WireguardProcessor *processor, bool is_incoming);
virtual ~TcpSocketBsd(); virtual ~TcpSocketBsd();
void InitializeIncoming(int fd, const IpAddr &addr); void InitializeIncoming(int fd, const IpAddr &addr);
@ -259,9 +259,10 @@ public:
uint8 endpoint_protocol() { return endpoint_protocol_; } uint8 endpoint_protocol() { return endpoint_protocol_; }
const IpAddr &endpoint() { return endpoint_; } const IpAddr &endpoint() { return endpoint_; }
static void WriteTcpPacket(NetworkBsd *network, WireguardProcessor *processor, Packet *packet);
public: public:
uint8 age; uint8 age;
uint8 handshake_attempts;
private: private:
void DoRead(); void DoRead();
void DoWrite(); void DoWrite();
@ -271,8 +272,10 @@ private:
bool got_eof_; bool got_eof_;
uint8 endpoint_protocol_; uint8 endpoint_protocol_;
bool want_connect_; bool want_connect_;
uint8 handshake_attempts_;
uint32 handshake_timestamp_;
uint32 wqueue_bytes_; uint wqueue_packets_;
Packet *wqueue_, **wqueue_end_; Packet *wqueue_, **wqueue_end_;
TcpSocketBsd *next_; TcpSocketBsd *next_;
WireguardProcessor *processor_; WireguardProcessor *processor_;

View file

@ -5,81 +5,34 @@
#include <assert.h> #include <assert.h>
#include <algorithm> #include <algorithm>
#include "util.h" #include "util.h"
#include "crypto/chacha20poly1305.h"
#include "crypto/blake2s/blake2s.h"
#include "wireguard_proto.h"
TcpPacketHandler::TcpPacketHandler(SimplePacketPool *packet_pool) { enum {
packet_pool_ = packet_pool; CRYPTO_HEADER_SIZE = 64,
rqueue_bytes_ = 0; };
error_flag_ = false;
rqueue_ = NULL;
rqueue_end_ = &rqueue_;
predicted_key_in_ = predicted_key_out_ = 0;
predicted_serial_in_ = predicted_serial_out_ = 0;
}
TcpPacketHandler::~TcpPacketHandler() { enum {
READ_CRYPTO_HEADER = 0,
READ_PACKET_HEADER = 1,
READ_PACKET_DATA = 2,
};
TcpPacketQueue::~TcpPacketQueue() {
FreePacketList(rqueue_); FreePacketList(rqueue_);
} }
enum { Packet *TcpPacketQueue::Read(uint num) {
kTcpPacketType_Normal = 0, // Move data around to ensure that exactly the first |num| bytes are stored
kTcpPacketType_Reserved = 1, // in the first packet, and the rest of the data in subsequent packets.
kTcpPacketType_Data = 2,
kTcpPacketType_Control = 3,
kTcpPacketControlType_SetKeyAndCounter = 0,
};
void TcpPacketHandler::AddHeaderToOutgoingPacket(Packet *p) {
unsigned int size = p->size;
uint8 *data = p->data;
if (size >= 16 && ReadLE32(data) == 4) {
uint32 key = Read32(data + 4);
uint64 serial = ReadLE64(data + 8);
WriteBE16(data + 14, size - 16 + (kTcpPacketType_Data << 14));
data += 14, size -= 14;
// Insert a 15 byte control packet right before to set the new key/serial?
if ((predicted_key_out_ ^ key) | (predicted_serial_out_ ^ serial)) {
predicted_key_out_ = key;
WriteLE64(data - 8, serial);
Write32(data - 12, key);
data[-13] = kTcpPacketControlType_SetKeyAndCounter;
WriteBE16(data - 15, 13 + (kTcpPacketType_Control << 14));
data -= 15, size += 15;
}
// Increase the serial by 1 for next packet.
predicted_serial_out_ = serial + 1;
} else {
WriteBE16(data - 2, size);
data -= 2, size += 2;
}
p->size = size;
p->data = data;
}
void TcpPacketHandler::QueueIncomingPacket(Packet *p) {
rqueue_bytes_ += p->size;
p->queue_next = NULL;
*rqueue_end_ = p;
rqueue_end_ = &Packet_NEXT(p);
}
// Either the packet fits in one buf or not.
static uint32 ReadPacketHeader(Packet *p) {
if (p->size >= 2)
return ReadBE16(p->data);
else
return (p->data[0] << 8) + (Packet_NEXT(p)->data[0]);
}
// Move data around to ensure that exactly the first |num| bytes are stored
// in the first packet, and the rest of the data in subsequent packets.
Packet *TcpPacketHandler::ReadNextPacket(uint32 num) {
Packet *p = rqueue_; Packet *p = rqueue_;
assert(num <= kPacketCapacity); assert(num <= kPacketCapacity);
if (p->size < num) { if (p->size < num) {
// There's not enough data in the current packet, copy data from the next packet // There's not enough data in the current packet, copy data from the next packet
// into this packet. // into this packet.
if ((uint32)(&p->data_buf[kPacketCapacity] - p->data) < num) { if ((uint)(&p->data_buf[kPacketCapacity] - p->data) < num) {
// Move data up front to make space. // Move data up front to make space.
memmove(p->data_buf, p->data, p->size); memmove(p->data_buf, p->data, p->size);
p->data = p->data_buf; p->data = p->data_buf;
@ -87,17 +40,17 @@ Packet *TcpPacketHandler::ReadNextPacket(uint32 num) {
// Copy data from future packets into p, and delete them should they become empty. // Copy data from future packets into p, and delete them should they become empty.
do { do {
Packet *n = Packet_NEXT(p); Packet *n = Packet_NEXT(p);
uint32 bytes_to_copy = std::min(n->size, num - p->size); uint bytes_to_copy = std::min(n->size, num - p->size);
uint32 nsize = (n->size -= bytes_to_copy); uint nsize = (n->size -= bytes_to_copy);
memcpy(p->data + postinc(p->size, bytes_to_copy), postinc(n->data, bytes_to_copy), bytes_to_copy); memcpy(p->data + postinc(p->size, bytes_to_copy), postinc(n->data, bytes_to_copy), bytes_to_copy);
if (nsize == 0) { if (nsize == 0) {
p->queue_next = n->queue_next; p->queue_next = n->queue_next;
packet_pool_->FreePacketToPool(n); pool_->FreePacketToPool(n);
} }
} while (num - p->size); } while (num - p->size);
} else if (p->size > num) { } else if (p->size > num) {
// The packet has too much data. Split the packet into two packets. // The packet has too much data. Split the packet into two packets.
Packet *n = packet_pool_->AllocPacketFromPool(); Packet *n = pool_->AllocPacketFromPool();
if (!n) if (!n)
return NULL; // unable to allocate a packet....? return NULL; // unable to allocate a packet....?
if (num * 2 <= p->size) { if (num * 2 <= p->size) {
@ -108,7 +61,7 @@ Packet *TcpPacketHandler::ReadNextPacket(uint32 num) {
memcpy(n->data, postinc(p->data, num), num); memcpy(n->data, postinc(p->data, num), num);
return n; return n;
} else { } else {
uint32 overflow = p->size - num; uint overflow = p->size - num;
// There's a lot of leading data: PPPPPP NN. Move NN // There's a lot of leading data: PPPPPP NN. Move NN
n->size = overflow; n->size = overflow;
p->size = num; p->size = num;
@ -126,49 +79,666 @@ Packet *TcpPacketHandler::ReadNextPacket(uint32 num) {
return p; return p;
} }
Packet *TcpPacketHandler::GetNextWireguardPacket() { Packet *TcpPacketQueue::ReadUpTo(uint num) {
while (rqueue_bytes_ >= 2) { assert(rqueue_bytes_ != 0);
uint32 packet_header = ReadPacketHeader(rqueue_); Packet *p = rqueue_;
uint32 packet_size = packet_header & 0x3FFF; if (num < p->size)
uint32 packet_type = packet_header >> 14; return Read(num);
if (packet_size + 2 > rqueue_bytes_) rqueue_bytes_ -= p->size;
if ((rqueue_ = Packet_NEXT(p)) == NULL)
rqueue_end_ = &rqueue_;
return p;
}
void TcpPacketQueue::Add(Packet *p) {
assert(p->size != 0);
rqueue_bytes_ += p->size;
p->queue_next = NULL;
*rqueue_end_ = p;
rqueue_end_ = &Packet_NEXT(p);
}
void TcpPacketQueue::Read(uint8 *dst, uint size) {
assert(size <= rqueue_bytes_);
rqueue_bytes_ -= size;
while (size) {
Packet *packet = rqueue_;
uint n = std::min(packet->size, size);
uint8 *src = packet->data;
for (uint i = 0; i != n; i++)
*dst++ = *src++;
packet->data = src;
size -= n;
if ((packet->size -= n) == 0) {
if ((rqueue_ = Packet_NEXT(packet)) == NULL)
rqueue_end_ = &rqueue_;
pool_->FreePacketToPool(packet);
}
}
}
uint TcpPacketQueue::PeekUint16() {
return (rqueue_->size >= 2) ? ReadBE16(rqueue_->data) :
(rqueue_->data[0] << 8) + Packet_NEXT(rqueue_)->data[0];
}
TcpPacketHandler::TcpPacketHandler(SimplePacketPool *packet_pool, WgPacketObfuscator *obfuscator, bool is_incoming)
: queue_(packet_pool),
tls_queue_(packet_pool),
write_state_(is_incoming),
obfuscation_mode_(kObfuscationMode_None) {
if (obfuscator->enabled() && obfuscator->obfuscate_tcp() != TcpPacketHandler::kObfuscationMode_None) {
memcpy(encryptor_.buf, obfuscator->key(), CHACHA20POLY1305_KEYLEN);
memcpy(decryptor_.buf, obfuscator->key(), CHACHA20POLY1305_KEYLEN);
obfuscation_mode_ = obfuscator->obfuscate_tcp() != TcpPacketHandler::kObfuscationMode_Unspecified ? obfuscator->obfuscate_tcp() :
(is_incoming ? TcpPacketHandler::kObfuscationMode_Autodetect : TcpPacketHandler::kObfuscationMode_Encrypted);
read_state_ = (obfuscation_mode_ == kObfuscationMode_Encrypted) ? READ_CRYPTO_HEADER : READ_PACKET_HEADER;
} else if (!obfuscator->enabled() && obfuscator->obfuscate_tcp() > TcpPacketHandler::kObfuscationMode_None) {
RERROR("No ObfuscateKey specified. Disabling TCP obfuscation.");
}
tls_read_state_ = 0;
error_flag_ = false;
decryptor_initialized_ = false;
predicted_key_in_ = predicted_key_out_ = 0;
predicted_serial_in_ = predicted_serial_out_ = 0;
}
TcpPacketHandler::~TcpPacketHandler() {
}
enum {
kTcpPacketType_Normal = 0,
kTcpPacketType_Reserved = 1,
kTcpPacketType_Data = 2,
kTcpPacketType_Control = 3,
kTcpPacketControlType_SetKeyAndCounter = 0,
};
static void SetChachaStreamingKey(chacha20_streaming *chacha, const uint8 *key, size_t key_len) {
blake2s(chacha->buf, CHACHA20POLY1305_KEYLEN, key, key_len, chacha->buf, CHACHA20POLY1305_KEYLEN);
chacha20_streaming_init(chacha, chacha->buf);
}
size_t TcpPacketHandler::CreateTls13ClientHello(uint8 *dst) {
uint8 *dst_org = dst;
// handshake, tls 1.0
*dst++ = 0x16;
*dst++ = 0x03;
*dst++ = 0x01;
uint8 *handshake_length = postinc(dst, 2);
// handshake client hello
*dst++ = 0x01;
*dst++ = 0x00;
uint8 *handshake_inner_length = postinc(dst, 2);
// version = tls 1.2
*dst++ = 0x03;
*dst++ = 0x03;
// 32 byte random
OsGetRandomBytes(postinc(dst, 32), 32);
*dst++ = 0x20; // Session length = 32
// 32 byte session id
OsGetRandomBytes(postinc(dst, 32), 32);
bool firefox = (obfuscation_mode_ == kObfuscationMode_TlsFirefox);
if (firefox) {
static const uint8 tls_header1[] = {
// 18 cipher suites
0x00, 0x24,
0x13, 0x01, 0x13, 0x03, 0x13, 0x02, 0xc0, 0x2b, 0xc0, 0x2f, 0xcc, 0xa9, 0xcc, 0xa8, 0xc0, 0x2c, 0xc0, 0x30,
0xc0, 0x0a, 0xc0, 0x09, 0xc0, 0x13, 0xc0, 0x14, 0x00, 0x33, 0x00, 0x39, 0x00, 0x2f, 0x00, 0x35, 0x00, 0x0a,
// compression method = null
0x01, 0x00,
};
memcpy(postinc(dst, sizeof(tls_header1)), tls_header1, sizeof(tls_header1));
} else {
static const uint8 tls_header1_chrome[] = {
// 17 cipher suites
0x00, 0x22,
0xda, 0xda, 0x13, 0x01, 0x13, 0x02, 0x13, 0x03, 0xc0, 0x2b, 0xc0, 0x2f, 0xc0, 0x2c, 0xc0, 0x30, 0xcc, 0xa9,
0xcc, 0xa8, 0xc0, 0x13, 0xc0, 0x14, 0x00, 0x9c, 0x00, 0x9d, 0x00, 0x2f, 0x00, 0x35, 0x00, 0x0a,
// compression method = null
0x01, 0x00,
};
memcpy(postinc(dst, sizeof(tls_header1_chrome)), tls_header1_chrome, sizeof(tls_header1_chrome));
}
uint8 *extensions_length = postinc(dst, 2);
if (!firefox) {
static const uint8 tls_header_grease[] = { 0xaa, 0xaa, 0x00, 0x00 };
memcpy(postinc(dst, sizeof(tls_header_grease)), tls_header_grease, sizeof(tls_header_grease));
}
static const uint8 tls_header2[] = {
// extension server name
0x00, 0x00, 0x00, 0x16, 0x00, 0x14, 0x00, 0x00, 0x11, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x2e, 0x74, 0x6c, 0x73, 0x31, 0x33, 0x2e, 0x63, 0x6f, 0x6d,
// extension master secret
0x00, 0x17, 0x00, 0x00,
// extension renegotiation info
0xff, 0x01, 0x00, 0x01, 0x00,
};
memcpy(postinc(dst, sizeof(tls_header2)), tls_header2, sizeof(tls_header2));
if (firefox) {
static const uint8 tls_header_groups_ff[] = {
// extension supported groups
0x00, 0x0a, 0x00, 0x0e, 0x00, 0x0c,
0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x01, 0x00, 0x01, 0x01,
// extension ec_point_formats
0x00, 0x0b, 0x00, 0x02, 0x01, 0x00,
// extension application_layer_protocol_negotiation
0x00, 0x10, 0x00, 0x0e, 0x00, 0x0c, 0x02, 0x68, 0x32, 0x08, 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31,
// extension status request
0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00,
// extension key share
0x00, 0x33, 0x00, 0x6b, 0x00, 0x69,
// key share x25519
0x00, 0x1d, 0x00, 0x20,
};
memcpy(postinc(dst, sizeof(tls_header_groups_ff)), tls_header_groups_ff, sizeof(tls_header_groups_ff));
// Firefox has a secp251p1 key while chrome does not
OsGetRandomBytes(postinc(dst, 32), 32);
dst[-1] &= 0x7f; // clear top bit of x25519 key
static const uint8 tls_header3[] = {
// key share secp256p1
0x00, 0x17, 0x00, 0x41,
0x04,
};
memcpy(postinc(dst, sizeof(tls_header3)), tls_header3, sizeof(tls_header3));
// todo: validate the secp256p1 key
OsGetRandomBytes(postinc(dst, 64), 64);
static const uint8 tls_header4[] = {
// extension early data (seems to be sent only in resume)
0x00, 0x2a, 0x00, 0x00,
// extension supported versions
0x00, 0x2b, 0x00, 0x09, 0x08, 0x03, 0x04, 0x03, 0x03, 0x03, 0x02, 0x03, 0x01,
// extension signature_algorithms
0x00, 0x0d, 0x00, 0x18, 0x00, 0x16, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x02, 0x03, 0x02, 0x01,
// extension psk_key_exchange_modes
0x00, 0x2d, 0x00, 0x02, 0x01, 0x01,
// extension unknown type 28
0x00, 0x1c, 0x00, 0x02, 0x40, 0x01,
// extension pre shared key length=235
0x00, 0x29, 0x00, 0xeb,
// identities length=198, psk identity length = 192
0x00, 0xc6, 0x00, 0xc0,
};
memcpy(postinc(dst, sizeof(tls_header4)), tls_header4, sizeof(tls_header4));
} else {
static const uint8 tls_header_groups_chrome[] = {
// extension supported groups
0x00, 0x0a, 0x00, 0x0a, 0x00, 0x08, 0x2a, 0x2a, 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18,
// extension ec_point_formats
0x00, 0x0b, 0x00, 0x02, 0x01, 0x00,
// extension sessionticket tls
0x00, 0x23, 0x00, 0x00,
// extension application_layer_protocol_negotiation
0x00, 0x10, 0x00, 0x0e, 0x00, 0x0c, 0x02, 0x68, 0x32, 0x08, 0x68, 0x74, 0x74, 0x70, 0x2f, 0x31, 0x2e, 0x31,
// extension status request
0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00,
// extension signature_algorithms
0x00, 0x0d, 0x00, 0x14, 0x00, 0x12, 0x04, 0x03, 0x08, 0x04, 0x04, 0x01, 0x05, 0x03, 0x08, 0x05, 0x05, 0x01, 0x08, 0x06, 0x06, 0x01, 0x02, 0x01,
// extension signed_certificate_timestamp
0x00, 0x12, 0x00, 0x00,
// extension key_share
0x00, 0x33, 0x00, 0x2b, 0x00, 0x29,
0x2a, 0x2a, 0x00, 0x01, 0x00,
0x00, 0x1d, 0x00, 0x20,
};
memcpy(postinc(dst, sizeof(tls_header_groups_chrome)), tls_header_groups_chrome, sizeof(tls_header_groups_chrome));
OsGetRandomBytes(postinc(dst, 32), 32);
dst[-1] &= 0x7f; // clear top bit of x25519 key
static const uint8 tls_header4_chrome[] = {
// extension psk_key_exchange_modes
0x00, 0x2d, 0x00, 0x02, 0x01, 0x01,
// extension supported versions
0x00, 0x2b, 0x00, 0x0b, 0x0a, 0x1a, 0x1a, 0x03, 0x04, 0x03, 0x03, 0x03, 0x02, 0x03, 0x01,
// extension unknown type 27
0x00, 0x1b, 0x00, 0x03, 0x02, 0x00, 0x02,
// extension reserved (grease)
0xea, 0xea, 0x00, 0x01, 0x00,
// extension pre shared key length=235
0x00, 0x29, 0x00, 0xeb,
// identities length=198, psk identity length = 192
0x00, 0xc6, 0x00, 0xc0,
};
memcpy(postinc(dst, sizeof(tls_header4_chrome)), tls_header4_chrome, sizeof(tls_header4_chrome));
}
OsGetRandomBytes(postinc(dst, 192 + 4), 192 + 4);
static const uint8 tls_header5[] = {
// psk binders length
0x00, 0x21,
};
memcpy(postinc(dst, sizeof(tls_header5)), tls_header5, sizeof(tls_header5));
OsGetRandomBytes(postinc(dst, 33), 33);
// Fixup lengths
WriteBE16(handshake_length, (uint)(dst - dst_org - 5));
WriteBE16(handshake_inner_length, (uint)(dst - dst_org - 9));
WriteBE16(extensions_length, (uint)(dst - extensions_length - 2));
// Setup the key generator for outgoing packets. It will be the blake2s hash of
// the full message excluding the tls header.
SetChachaStreamingKey(&encryptor_, dst_org + 5, dst - dst_org - 5);
static const uint8 tls_header6[] = {
// change cipher spec
0x14, 0x03, 0x03, 0x00, 0x01, 0x01
};
memcpy(postinc(dst, sizeof(tls_header6)), tls_header6, sizeof(tls_header6));
return dst - dst_org;
}
size_t TcpPacketHandler::CreateTls13ServerHello(uint8 *dst) {
if (!decryptor_initialized_)
return ~(size_t)0;
uint8 *dst_org = dst;
// handshake, tls 1.0
*dst++ = 0x16;
*dst++ = 0x03;
*dst++ = 0x03;
uint8 *handshake_length = postinc(dst, 2);
// handshake client hello
*dst++ = 0x02;
*dst++ = 0x00;
uint8 *handshake_inner_length = postinc(dst, 2);
// version = tls 1.2
*dst++ = 0x03;
*dst++ = 0x03;
// 32 byte random
OsGetRandomBytes(postinc(dst, 32), 32);
*dst++ = 0x20; // Session length = 32
// 32 byte session id taken from client hello.
memcpy(postinc(dst, 32), tls_session_id_, 32);
// cipher suite
*dst++ = 0x13;
*dst++ = 0x01;
// compression method
*dst++ = 0x00;
uint8 *extensions_length = postinc(dst, 2);
static const uint8 tls_s_header0[] = {
// extension pre_shared_key
0x00, 0x29, 0x00, 0x02, 0x00, 0x00,
// extension key share with x25519 key
0x00, 0x33, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20,
};
memcpy(postinc(dst, sizeof(tls_s_header0)), tls_s_header0, sizeof(tls_s_header0));
OsGetRandomBytes(postinc(dst, 32), 32);
dst[-1] &= 0x7f; // clear top bit of x25519 key
static const uint8 tls_s_header1[] = {
// extension supported version tls1.3
0x00, 0x2b, 0x00, 0x02, 0x03, 0x04,
};
memcpy(postinc(dst, sizeof(tls_s_header1)), tls_s_header1, sizeof(tls_s_header1));
WriteBE16(handshake_length, (uint)(dst - dst_org - 5));
WriteBE16(handshake_inner_length, (uint)(dst - dst_org - 9));
WriteBE16(extensions_length, (uint)(dst - extensions_length - 2));
// Setup the key generator for outgoing packets. It will be the blake2s hash of
// the full message excluding the tls header.
SetChachaStreamingKey(&encryptor_, dst_org + 5, dst - dst_org - 5);
static const uint8 tls_header6[] = {
// change cipher spec
0x14, 0x03, 0x03, 0x00, 0x01, 0x01
};
memcpy(postinc(dst, sizeof(tls_header6)), tls_header6, sizeof(tls_header6));
return dst - dst_org;
}
// Normal packet without obfuscation
void TcpPacketHandler::PrepareOutgoingPacketsNormal(Packet *p) {
uint8 *data = p->data;
uint data_size = p->size, packet_type = ReadLE32(data);
p->prepared = true;
if (packet_type == 4) {
assert(data_size >= 16);
uint32 key = Read32(data + 4);
uint64 serial = ReadLE64(data + 8);
if (((predicted_key_out_ ^ key) | (exch(predicted_serial_out_, serial) ^ (serial - 1))) == 0) {
p->data = data + 14;
p->size = data_size - 14;
WriteBE16(p->data, 0x8000 + data_size - 16);
return;
}
predicted_key_out_ = key;
}
p->size = data_size + 2;
p->data = data - 2;
WriteBE16(p->data, data_size);
}
// Obfuscated stream that looks totally random
void TcpPacketHandler::PrepareOutgoingPacketsObfuscate(Packet *p) {
uint8 *data = p->data;
uint data_size = p->size, packet_type = ReadLE32(data);
p->prepared = true;
// When obfuscation is enabled, inject random shit into packets.
if ((packet_type == 4 && data_size <= 32) || packet_type < 4) {
if (packet_type != 4) {
assert(data_size >= 48);
// The 39:th (for handshake init) and 43:rd byte (for handshake response)
// have zero MSB because of curve25519 pubkey, so xor it with random.
if (packet_type < 4)
data[35 + packet_type * 4] ^= data[15];
} else {
predicted_key_out_ = Read32(data + 4);
predicted_serial_out_ = ReadLE64(data + 8);
}
data_size = (uint)WgPacketObfuscator::InsertRandomBytesIntoPacket(data, data_size);
} else if (packet_type == 4) {
assert(data_size >= 16);
uint32 key = Read32(data + 4);
uint64 serial = ReadLE64(data + 8);
if (((exch(predicted_key_out_, key) ^ key) | (exch(predicted_serial_out_, serial) ^ (serial - 1))) == 0) {
p->data = data + 14;
p->size = data_size - 14;
WriteBE16(p->data, 0x8000 + data_size - 16);
chacha20_streaming_crypt(&encryptor_, p->data, 2);
return;
}
}
p->data = data - 2;
p->size = data_size + 2;
WriteBE16(p->data, data_size);
chacha20_streaming_crypt(&encryptor_, p->data, 18);
}
static void PrependTlsApplicationData(Packet *p, uint data_size) {
p->size += 5;
p->data -= 5;
p->data[0] = 0x17;
p->data[1] = 0x03;
p->data[2] = 0x03;
p->data[4] = (uint8)data_size;
p->data[3] = (uint8)(data_size >> 8);
}
void TcpPacketHandler::PrepareOutgoingPacketsTLS13(Packet *p) {
// Collect a number of packets, but add just a single TLS header
uint total_size = 0;
Packet *cur = p;
do {
PrepareOutgoingPacketsObfuscate(cur);
total_size += cur->size;
} while (total_size < 12000 && (cur = Packet_NEXT(cur)));
PrependTlsApplicationData(p, total_size);
}
Packet *TcpPacketHandler::GetNextWireguardPacketObfuscate(TcpPacketQueue *queue) {
if (read_state_ == READ_CRYPTO_HEADER) {
// Wait for the 64 bytes of crypto header, they will
// be used to seed the decryptor.
if (queue->size() < CRYPTO_HEADER_SIZE)
return NULL; return NULL;
if (packet_size + 2 > kPacketCapacity) { Packet *packet = queue->Read(CRYPTO_HEADER_SIZE);
RERROR("Oversized packet?"); if (!packet)
return NULL;
SetChachaStreamingKey(&decryptor_, packet->data, CRYPTO_HEADER_SIZE);
queue->pool()->FreePacketToPool(packet);
read_state_ = READ_PACKET_HEADER;
} else if (read_state_ == READ_PACKET_DATA) {
goto case_READ_PACKET_DATA;
}
while (queue->size() >= 2) {
// Peek and decrypt the packet header
queue->Read(packet_header_, 2);
chacha20_streaming_crypt(&decryptor_, packet_header_, 2);
case_READ_PACKET_DATA:
uint32 packet_header = ReadBE16(packet_header_);
uint32 packet_size = packet_header & 0x7FFF;
if (packet_size > kPacketCapacity) {
error:
error_flag_ = true; error_flag_ = true;
return NULL; return NULL;
} }
Packet *packet = ReadNextPacket(packet_size + 2); if (packet_size > queue->size()) {
if (packet) { read_state_ = READ_PACKET_DATA;
// RINFO("Packet of type %d, size %d", packet_type, packet->size - 2); return NULL;
packet->data += 2, packet->size -= 2; }
if (packet_type == kTcpPacketType_Normal) { read_state_ = READ_PACKET_HEADER;
Packet *packet = queue->Read(packet_size);
if (!packet)
goto error;
// RINFO("Packet of type %d, size %d", packet_type, packet->size - 2);
if (!(packet_header & 0x8000)) {
unsigned int size = packet->size;
// decrypt the initial 16 bytes of the packet
if (size < 16)
goto error;
chacha20_streaming_crypt(&decryptor_, packet->data, 16);
// Discard any extra junk bytes appended at the end.
if (packet->data[0] <= 4) {
if (packet->data[3] > size)
goto error;
packet->size = (size -= packet->data[3]);
packet->data[3] = 0;
// The 39:th (for handshake init) and 43:rd byte (for handshake response)
// have zero MSB because of curve25519 pubkey, so xor it with random.
if (packet->data[0] < 4 && size >= 48)
packet->data[35 + packet->data[0] * 4] ^= packet->data[15];
}
if (packet->data[0] == 4) {
predicted_key_in_ = Read32(packet->data + 4);
predicted_serial_in_ = ReadLE64(packet->data + 8);
}
return packet; return packet;
} else if (packet_type == kTcpPacketType_Data) { } else {
// Optimization when the 16 first bytes are known and prefixed to the packet // Optimization when the 16 first bytes are known and prefixed to the packet
assert(packet->data >= packet->data_buf); assert(packet->data >= packet->data_buf);
packet->data -= 16, packet->size += 16; packet->data -= 16, packet->size += 16;
predicted_serial_in_++;
WriteLE32(packet->data, 4); WriteLE32(packet->data, 4);
Write32(packet->data + 4, predicted_key_in_); Write32(packet->data + 4, predicted_key_in_);
WriteLE64(packet->data + 8, predicted_serial_in_); WriteLE64(packet->data + 8, predicted_serial_in_);
predicted_serial_in_++;
return packet; return packet;
} else if (packet_type == kTcpPacketType_Control) {
// Unknown control packets are silently ignored
if (packet->size == 13 && packet->data[0] == kTcpPacketControlType_SetKeyAndCounter) {
// Control packet to setup the predicted key/sequence nr
predicted_key_in_ = Read32(packet->data + 1);
predicted_serial_in_ = ReadLE64(packet->data + 5);
}
packet_pool_->FreePacketToPool(packet);
} else {
packet_pool_->FreePacketToPool(packet);
error_flag_ = true;
return NULL;
}
} }
} }
return NULL; return NULL;
} }
Packet *TcpPacketHandler::GetNextWireguardPacketNormal() {
while (queue_.size() >= 2) {
uint32 packet_header = queue_.PeekUint16();
uint32 packet_size = packet_header & 0x7FFF;
if (packet_size + 2 > kPacketCapacity) {
error:
error_flag_ = true;
return NULL;
}
if (packet_size + 2 > queue_.size())
return NULL;
Packet *packet = queue_.Read(packet_size + 2);
if (!packet)
goto error;
if (!(packet_header & 0x8000)) {
packet->data += 2, packet->size -= 2;
if (packet->data[0] == 4 && packet->size >= 16) {
predicted_key_in_ = Read32(packet->data + 4);
predicted_serial_in_ = ReadLE64(packet->data + 8);
}
} else {
// Optimization when the 16 first bytes are known and prefixed to the packet
assert(packet->data >= packet->data_buf);
packet->data -= 14, packet->size += 14;
predicted_serial_in_++;
WriteLE32(packet->data, 4);
Write32(packet->data + 4, predicted_key_in_);
WriteLE64(packet->data + 8, predicted_serial_in_);
}
return packet;
}
return NULL;
}
#define TLS_ASYNC_BEGIN() switch (tls_read_state_) {
#define TLS_ASYNC_RESUMEPOINT(label) tls_read_state_ = (label); case label:
#define TLS_ASYNC_WAIT(expr, label) case label: if (!(expr)) { tls_read_state_ = (label); return NULL; }
#define TLS_ASYNC_END() }
// Unwrap the TLS framing
Packet *TcpPacketHandler::GetNextWireguardPacketTLS13() {
uint8 header[5];
Packet *packet;
enum {
TLS_STATE_INIT = 0,
TLS_WAIT_HANDSHAKE = 1,
TLS_WAIT_DATA = 2,
TLS_READ_PACKETS = 3,
TLS_WAIT_JUNK = 4,
TLS_ERROR = 5,
};
TLS_ASYNC_BEGIN();
for(;;) {
TLS_ASYNC_WAIT(queue_.size() >= 5, TLS_STATE_INIT);
queue_.Read(header, 5);
tls_bytes_left_ = ReadBE16(header + 3);
if (header[0] == 23) {
if (!decryptor_initialized_)
goto error; // no key yet
// Read the next |tls_bytes_left_| bytes and push them to the tls_queue_.
while (tls_bytes_left_ != 0) {
TLS_ASYNC_WAIT(queue_.size() != 0, TLS_WAIT_DATA);
if (!(packet = queue_.ReadUpTo(tls_bytes_left_))) goto error;
tls_bytes_left_ -= packet->size;
tls_queue_.Add(packet);
TLS_ASYNC_RESUMEPOINT(TLS_READ_PACKETS);
if ((packet = GetNextWireguardPacketObfuscate(&tls_queue_)) != NULL)
return packet;
}
} else {
if (tls_bytes_left_ > kPacketCapacity)
goto error; // too large packet?
if (header[0] == 22) {
TLS_ASYNC_WAIT(tls_bytes_left_ <= queue_.size(), TLS_WAIT_HANDSHAKE);
if (!(packet = queue_.Read(tls_bytes_left_)))
goto error; // eom
// Initialize decryptor
if (!decryptor_initialized_ && packet->size >= 39 + 32) {
// Store the session ID, so we can include it in server hello.
memcpy(tls_session_id_, packet->data + 39, 32);
// Initialize chacha decryptor
SetChachaStreamingKey(&decryptor_, packet->data, packet->size);
decryptor_initialized_ = true;
}
FreePacket(packet);
} else if (header[0] == 20) {
TLS_ASYNC_WAIT(tls_bytes_left_ <= queue_.size(), TLS_WAIT_JUNK);
if (!(packet = queue_.Read(tls_bytes_left_)))
goto error; // eom
FreePacket(packet);
} else {
error:
TLS_ASYNC_RESUMEPOINT(TLS_ERROR);
error_flag_ = true;
return NULL;
}
}
}
TLS_ASYNC_END();
return NULL;
}
void TcpPacketHandler::PrepareOutgoingPacketsWithHeader(Packet *p) {
uint8 buf[1024];
size_t hello_size;
if (obfuscation_mode_ == kObfuscationMode_Encrypted) {
// Ensure it doesn't look like a tls or a regular packet.
do {
OsGetRandomBytes(buf, CRYPTO_HEADER_SIZE);
} while (ReadBE16(buf) == 0x1603 || ReadBE16(buf) <= 1500);
SetChachaStreamingKey(&encryptor_, buf, CRYPTO_HEADER_SIZE);
hello_size = CRYPTO_HEADER_SIZE;
} else {
hello_size = (write_state_ == 0) ? CreateTls13ClientHello(buf) : CreateTls13ServerHello(buf);
// This could fail if the server tries to send a packet before the client sent hello.
if (hello_size == ~(size_t)0) {
RERROR("Trying to send server message before client hello");
p->size = 0;
return;
}
}
write_state_ = 2;
PrepareOutgoingPackets(p);
if (hello_size + p->size > kPacketCapacity) {
RERROR("Outgoing TCP packet too big.");
return;
}
memmove(p->data_buf + hello_size, exch(p->data, p->data_buf), postinc(p->size, (uint)hello_size));
memcpy(p->data_buf, buf, hello_size);
}
void TcpPacketHandler::PrepareOutgoingPackets(Packet *p) {
if (obfuscation_mode_ == kObfuscationMode_None) {
PrepareOutgoingPacketsNormal(p);
} else {
if (write_state_ != 2) {
PrepareOutgoingPacketsWithHeader(p);
return;
}
if (obfuscation_mode_ == kObfuscationMode_Encrypted)
PrepareOutgoingPacketsObfuscate(p);
else
PrepareOutgoingPacketsTLS13(p);
}
}
Packet *TcpPacketHandler::GetNextWireguardPacket() {
// If this is an incoming connection, try to guess what type of obfuscation
// we're using, if any.
for (;;) {
if (obfuscation_mode_ == kObfuscationMode_None)
return GetNextWireguardPacketNormal();
else if (obfuscation_mode_ == kObfuscationMode_Encrypted)
return GetNextWireguardPacketObfuscate(&queue_);
else if (obfuscation_mode_ != kObfuscationMode_Autodetect)
return GetNextWireguardPacketTLS13();
// Try and autodetect based on the first 2 bytes.
if (queue_.size() < 2)
return NULL;
uint16 header = queue_.PeekUint16();
if (header == 0x1603) {
// This is a SSL client hello, but don't know if it's
// chrome or ff, so use ff.
obfuscation_mode_ = kObfuscationMode_TlsFirefox;
} else if (header <= 1500) {
// Unobfuscated wireguard headers always start with a low value.
obfuscation_mode_ = kObfuscationMode_None;
} else {
read_state_ = READ_CRYPTO_HEADER;
obfuscation_mode_ = kObfuscationMode_Encrypted;
}
}
}
#if defined(OS_WIN) || defined(USE_MULTITHREADED_NETWORKING)
void SimplePacketPool::FreeSomePacketsInner() {
int n = freed_packets_count_ - 24;
Packet **p = &freed_packets_;
for (; n; n--)
p = &Packet_NEXT(*p);
FreePackets(exch(freed_packets_, *p), p, exch(freed_packets_count_, 24) - 24);
}
#endif

View file

@ -2,8 +2,10 @@
#define TUNSAFE_NETWORK_COMMON_H_ #define TUNSAFE_NETWORK_COMMON_H_
#include "netapi.h" #include "netapi.h"
#include "crypto/chacha20poly1305.h"
class PacketProcessor; class PacketProcessor;
class WgPacketObfuscator;
// A simple singlethreaded pool of packets used on windows where // A simple singlethreaded pool of packets used on windows where
// FreePacket / AllocPacket are multithreded and thus slightly slower // FreePacket / AllocPacket are multithreded and thus slightly slower
@ -54,19 +56,55 @@ public:
#endif #endif
class TcpPacketQueue {
public:
explicit TcpPacketQueue(SimplePacketPool *pool) : rqueue_bytes_(0), rqueue_(NULL), rqueue_end_(&rqueue_), pool_(pool) {}
~TcpPacketQueue();
Packet *Read(uint num);
Packet *ReadUpTo(uint num);
void Read(uint8 *dst, uint num);
uint PeekUint16();
void Add(Packet *packet);
uint32 size() const { return rqueue_bytes_; }
SimplePacketPool *pool() { return pool_; }
private:
// Total # of bytes queued
uint rqueue_bytes_;
// Buffered data
Packet *rqueue_, **rqueue_end_;
SimplePacketPool *pool_;
};
// Aids with prefixing and parsing incoming and outgoing // Aids with prefixing and parsing incoming and outgoing
// packets with the tcp protocol header. // packets with the tcp protocol header.
class TcpPacketHandler { class TcpPacketHandler {
public: public:
explicit TcpPacketHandler(SimplePacketPool *packet_pool); enum {
kObfuscationMode_Unspecified = -1,
kObfuscationMode_None = 0,
kObfuscationMode_Encrypted = 1,
kObfuscationMode_TlsFirefox = 2,
kObfuscationMode_TlsChrome = 3,
kObfuscationMode_Autodetect = 4,
};
explicit TcpPacketHandler(SimplePacketPool *packet_pool, WgPacketObfuscator *obfuscator, bool is_incoming);
~TcpPacketHandler(); ~TcpPacketHandler();
// Adds a tcp header to a data packet so it can be transmitted on the wire // Adds a tcp header to a data packet so it can be transmitted on the wire
void AddHeaderToOutgoingPacket(Packet *p); void PrepareOutgoingPackets(Packet *p);
// Add a new chunk of incoming data to the packet list // Add a new chunk of incoming data to the packet list
void QueueIncomingPacket(Packet *p); void QueueIncomingPacket(Packet *p) {
queue_.Add(p);
}
// Attempt to extract the next packet, returns NULL when complete. // Attempt to extract the next packet, returns NULL when complete.
Packet *GetNextWireguardPacket(); Packet *GetNextWireguardPacket();
@ -74,22 +112,43 @@ public:
bool error() const { return error_flag_; } bool error() const { return error_flag_; }
private: private:
// Internal function to read a packet void PrepareOutgoingPacketsNormal(Packet *p);
Packet *ReadNextPacket(uint32 num); void PrepareOutgoingPacketsObfuscate(Packet *p);
void PrepareOutgoingPacketsTLS13(Packet *p);
void PrepareOutgoingPacketsWithHeader(Packet *p);
SimplePacketPool *packet_pool_; Packet *GetNextWireguardPacketNormal();
Packet *GetNextWireguardPacketObfuscate(TcpPacketQueue *queue);
Packet *GetNextWireguardPacketTLS13();
size_t CreateTls13ClientHello(uint8 *dst);
size_t CreateTls13ServerHello(uint8 *dst);
// Total # of bytes queued
uint32 rqueue_bytes_;
// Set if there's a fatal error // Set if there's a fatal error
bool error_flag_; bool error_flag_;
uint8 obfuscation_mode_;
uint8 read_state_, write_state_, tls_read_state_;
bool decryptor_initialized_;
// These hold the incoming packets before they're parsed uint8 packet_header_[2];
Packet *rqueue_, **rqueue_end_;
// Number of data bytes left
uint tls_bytes_left_;
TcpPacketQueue queue_;
// There's a separate queue for tls since it unwraps stuff
TcpPacketQueue tls_queue_;
uint32 predicted_key_in_, predicted_key_out_; uint32 predicted_key_in_, predicted_key_out_;
uint64 predicted_serial_in_, predicted_serial_out_; uint64 predicted_serial_in_, predicted_serial_out_;
// For obfuscating
chacha20_streaming encryptor_, decryptor_;
uint8 tls_session_id_[32];
}; };
#endif // TUNSAFE_NETWORK_COMMON_H_ #endif // TUNSAFE_NETWORK_COMMON_H_

View file

@ -90,13 +90,6 @@ void FreeAllPackets() {
} }
} }
void SimplePacketPool::FreeSomePacketsInner() {
int n = freed_packets_count_ - 24;
Packet **p = &freed_packets_;
for (; n; n--)
p = &Packet_NEXT(*p);
FreePackets(exch(freed_packets_, *p), p, exch(freed_packets_count_, 24) - 24);
}
void InitPacketMutexes() { void InitPacketMutexes() {
static bool mutex_inited; static bool mutex_inited;
@ -169,7 +162,7 @@ static bool RunNetsh(const char *cmdline) {
// Open the TAP adapter, either a random one or a specific one // Open the TAP adapter, either a random one or a specific one
// On return, the adapter is locked in |TunAdaptersInUse|. // On return, the adapter is locked in |TunAdaptersInUse|.
static HANDLE OpenTunAdapter(char guid[ADAPTER_GUID_SIZE], TunsafeBackendWin32 *backend, DWORD open_flags) { static HANDLE OpenTunAdapter(char guid[ADAPTER_GUID_SIZE], TunsafeRunner *runner, DWORD open_flags) {
char path[128]; char path[128];
HANDLE h; HANDLE h;
int retries = 0; int retries = 0;
@ -196,7 +189,7 @@ RETRY:
int error_code = 0; int error_code = 0;
for (GuidAndDevName &x : adapters) { for (GuidAndDevName &x : adapters) {
snprintf(path, sizeof(path), "\\\\.\\Global\\%s.tap", x.guid); snprintf(path, sizeof(path), "\\\\.\\Global\\%s.tap", x.guid);
if (tun_adapters_in_use->Acquire(x.guid, static_cast<TunsafeBackend*>(backend))) { if (tun_adapters_in_use->Acquire(x.guid, static_cast<TunsafeBackend*>(runner->backend()))) {
h = CreateFile(path, GENERIC_READ | GENERIC_WRITE, 0, 0, OPEN_EXISTING, FILE_ATTRIBUTE_SYSTEM | open_flags, 0); h = CreateFile(path, GENERIC_READ | GENERIC_WRITE, 0, 0, OPEN_EXISTING, FILE_ATTRIBUTE_SYSTEM | open_flags, 0);
if (h != INVALID_HANDLE_VALUE) { if (h != INVALID_HANDLE_VALUE) {
memcpy(guid, x.guid, ADAPTER_GUID_SIZE); memcpy(guid, x.guid, ADAPTER_GUID_SIZE);
@ -204,7 +197,7 @@ RETRY:
} }
did_try_adapter = true; did_try_adapter = true;
error_code = GetLastError(); error_code = GetLastError();
tun_adapters_in_use->Release(static_cast<TunsafeBackend*>(backend)); tun_adapters_in_use->Release(static_cast<TunsafeBackend*>(runner->backend()));
} }
} }
if (!did_try_adapter) { if (!did_try_adapter) {
@ -214,7 +207,7 @@ RETRY:
// Sometimes if you close the device right before, it will fail to open with errorcode 31. // Sometimes if you close the device right before, it will fail to open with errorcode 31.
// When resuming from sleep in my VM, the error code is ERROR_FILE_NOT_FOUND // When resuming from sleep in my VM, the error code is ERROR_FILE_NOT_FOUND
if ((error_code == ERROR_FILE_NOT_FOUND || error_code == ERROR_GEN_FAILURE) && !backend->exit_code()) { if ((error_code == ERROR_FILE_NOT_FOUND || error_code == ERROR_GEN_FAILURE) && !runner->exit_code()) {
if (retries <= 10) { if (retries <= 10) {
RERROR("OpenTapAdapter: CreateFile failed: 0x%X... retrying%s", error_code, retries == 10 ? " (last notice)" : ""); RERROR("OpenTapAdapter: CreateFile failed: 0x%X... retrying%s", error_code, retries == 10 ? " (last notice)" : "");
if (retries == 10) { if (retries == 10) {
@ -223,12 +216,12 @@ RETRY:
} else if (error_code == ERROR_GEN_FAILURE) { } else if (error_code == ERROR_GEN_FAILURE) {
RERROR(" Please ensure that the TAP device is not in use."); RERROR(" Please ensure that the TAP device is not in use.");
} }
backend->SetStatus(TunsafeBackend::kStatusTunRetrying); runner->backend()->SetStatus(TunsafeBackend::kStatusTunRetrying);
} }
} }
int sleep_amount = 250 * std::min(++retries, 40); int sleep_amount = 250 * std::min(++retries, 40);
for (;;) { for (;;) {
if (backend->exit_code()) if (runner->exit_code())
return NULL; return NULL;
if (sleep_amount == 0) if (sleep_amount == 0)
break; break;
@ -699,7 +692,7 @@ void UdpSocketWin32::DoIO() {
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////
NetworkWin32::NetworkWin32() : udp_socket_(this), tcp_socket_queue_(this) { NetworkWin32::NetworkWin32() : udp_socket_(this) {
exit_thread_ = false; exit_thread_ = false;
thread_ = NULL; thread_ = NULL;
completion_port_handle_ = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, NULL, 0); completion_port_handle_ = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, NULL, 0);
@ -803,21 +796,6 @@ void NetworkWin32::PostQueuedItem(QueuedItem *item) {
PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, &item->overlapped); PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, &item->overlapped);
} }
bool NetworkWin32::Configure(int listen_port, int listen_port_tcp) {
if (listen_port_tcp)
RERROR("ListenPortTCP not supported in this version");
return udp_socket_.Configure(listen_port);
}
// Called from tunsafe thread
void NetworkWin32::WriteUdpPacket(Packet *packet) {
if (packet->protocol & kPacketProtocolUdp) {
udp_socket_.WriteUdpPacket(packet);
} else {
tcp_socket_queue_.WritePacket(packet);
}
}
///////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////
PacketProcessor::PacketProcessor() { PacketProcessor::PacketProcessor() {
@ -829,6 +807,7 @@ PacketProcessor::PacketProcessor() {
timer_interrupt_ = false; timer_interrupt_ = false;
packets_in_queue_ = 0; packets_in_queue_ = 0;
need_notify_ = 0; need_notify_ = 0;
udp_cb_maybe_deobfuscate_ = &udp_cb_;
} }
PacketProcessor::~PacketProcessor() { PacketProcessor::~PacketProcessor() {
@ -866,13 +845,13 @@ void PacketProcessor::Reset() {
} }
} }
int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) { int PacketProcessor::Run(WireguardProcessor *wg, TunsafeRunner *runner) {
int free_packets_ctr = 0; int free_packets_ctr = 0;
int overload = 0; int overload = 0;
int exit_code; int exit_code;
QueuedItem *packet; QueuedItem *packet;
PTP_TIMER threadpool_timer; PTP_TIMER threadpool_timer;
QueueContext queue_context = {wg, backend}; QueueContext queue_context = {wg, runner};
threadpool_timer = CreateThreadpoolTimer(&ThreadPoolTimerCallback, this, NULL); threadpool_timer = CreateThreadpoolTimer(&ThreadPoolTimerCallback, this, NULL);
static const int64 duetime = -10000000; // the unit is 100ns static const int64 duetime = -10000000; // the unit is 100ns
@ -880,25 +859,13 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) {
mutex_.Acquire(); mutex_.Acquire();
while (!(exit_code = exit_code_)) { while (!(exit_code = exit_code_)) {
FreeAllPackets();
if (timer_interrupt_) { if (timer_interrupt_) {
timer_interrupt_ = false; timer_interrupt_ = false;
need_notify_ = 0; need_notify_ = 0;
mutex_.Release(); mutex_.Release();
wg->SecondLoop(); wg->SecondLoop();
backend->stats_mutex_.Acquire();
backend->stats_ = wg->GetStats();
float data[2] = {
// unit is megabits/second
backend->stats_.tun_bytes_in_per_second * (1.0f / 125000),
backend->stats_.tun_bytes_out_per_second * (1.0f / 125000),
};
backend->stats_collector_.AddSamples(data);
backend->stats_mutex_.Release();
backend->delegate_->OnGraphAvailable(); runner->CollectStats();
backend->PushStats();
// Conserve memory every 10s // Conserve memory every 10s
if (free_packets_ctr++ == 10) { if (free_packets_ctr++ == 10) {
@ -933,7 +900,6 @@ int PacketProcessor::Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend) {
wg->RunAllMainThreadScheduled(); wg->RunAllMainThreadScheduled();
mutex_.Acquire(); mutex_.Acquire();
} }
exit_code_ = 0;
mutex_.Release(); mutex_.Release();
SetThreadpoolTimer(threadpool_timer, nullptr, 0, 0); SetThreadpoolTimer(threadpool_timer, nullptr, 0, 0);
@ -970,8 +936,6 @@ void PacketProcessorDeobfuscateUdpCb::OnQueuedItemEvent(QueuedItem *qi, uintptr_
void PacketProcessor::PostExit(int exit_code) { void PacketProcessor::PostExit(int exit_code) {
mutex_.Acquire(); mutex_.Acquire();
// Avoid race condition where mode_tun_failed is set during thread exit.
if (exit_code_ != TunsafeBackendWin32::MODE_RESTART && exit_code_ != TunsafeBackendWin32::MODE_EXIT)
exit_code_ = exit_code; exit_code_ = exit_code;
mutex_.Release(); mutex_.Release();
SetEvent(event_); SetEvent(event_);
@ -1229,12 +1193,12 @@ TunWin32Adapter::~TunWin32Adapter() {
} }
bool TunWin32Adapter::OpenAdapter(TunsafeBackendWin32 *backend, DWORD open_flags) { bool TunWin32Adapter::OpenAdapter(TunsafeRunner *runner, DWORD open_flags) {
ULONG info[3]; ULONG info[3];
DWORD len; DWORD len;
assert(handle_ == NULL); assert(handle_ == NULL);
backend_ = backend; backend_ = runner->backend();
handle_ = OpenTunAdapter(guid_, backend, open_flags); handle_ = OpenTunAdapter(guid_, runner, open_flags);
if (handle_ != NULL) { if (handle_ != NULL) {
memset(info, 0, sizeof(info)); memset(info, 0, sizeof(info));
if (DeviceIoControl(handle_, TAP_IOCTL_GET_VERSION, &info, sizeof(info), if (DeviceIoControl(handle_, TAP_IOCTL_GET_VERSION, &info, sizeof(info),
@ -1664,7 +1628,7 @@ bool TunWin32Adapter::RunPrePostCommand(const std::vector<std::string> &vec) {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker, backend->guid_), backend_(backend) { TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeRunner *runner) : adapter_(blocker, runner->backend()->guid_), runner_(runner) {
wqueue_end_ = &wqueue_; wqueue_end_ = &wqueue_;
wqueue_ = NULL; wqueue_ = NULL;
wqueue_size_ = 0; wqueue_size_ = 0;
@ -1680,7 +1644,6 @@ TunWin32Iocp::~TunWin32Iocp() {
//assert(num_reads_ == 0 && num_writes_ == 0); //assert(num_reads_ == 0 && num_writes_ == 0);
assert(thread_ == NULL); assert(thread_ == NULL);
CloseTun(false); CloseTun(false);
FreePacketList(wqueue_);
} }
bool TunWin32Iocp::Configure(const TunConfig &&config, TunConfigOut *out) { bool TunWin32Iocp::Configure(const TunConfig &&config, TunConfigOut *out) {
@ -1693,7 +1656,7 @@ bool TunWin32Iocp::Configure(const TunConfig &&config, TunConfigOut *out) {
return rv; return rv;
} }
CloseTun(true); CloseTun(true);
if (adapter_.OpenAdapter(backend_, FILE_FLAG_OVERLAPPED)) { if (adapter_.OpenAdapter(runner_, FILE_FLAG_OVERLAPPED)) {
completion_port_handle_ = CreateIoCompletionPort(adapter_.handle(), NULL, NULL, 0); completion_port_handle_ = CreateIoCompletionPort(adapter_.handle(), NULL, NULL, 0);
if (completion_port_handle_ != NULL) { if (completion_port_handle_ != NULL) {
if (adapter_.ConfigureAdapter(std::move(config), out)) if (adapter_.ConfigureAdapter(std::move(config), out))
@ -1707,10 +1670,9 @@ bool TunWin32Iocp::Configure(const TunConfig &&config, TunConfigOut *out) {
void TunWin32Iocp::CloseTun(bool is_restart) { void TunWin32Iocp::CloseTun(bool is_restart) {
assert(thread_ == NULL); assert(thread_ == NULL);
adapter_.CloseAdapter(is_restart); adapter_.CloseAdapter(is_restart);
if (completion_port_handle_) { if (completion_port_handle_)
CloseHandle(completion_port_handle_); CloseHandle(exch_null(completion_port_handle_));
completion_port_handle_ = NULL; FreePacketList(wqueue_);
}
} }
enum { enum {
@ -1758,7 +1720,7 @@ void TunWin32Iocp::ThreadMain() {
if (err == ERROR_OPERATION_ABORTED || err == ERROR_FILE_NOT_FOUND) { if (err == ERROR_OPERATION_ABORTED || err == ERROR_FILE_NOT_FOUND) {
RERROR("TAP driver stopped communicating. Attempting to restart.", err); RERROR("TAP driver stopped communicating. Attempting to restart.", err);
// This can happen if we reinstall the TAP driver while there's an active connection. // This can happen if we reinstall the TAP driver while there's an active connection.
backend_->PostExit(TunsafeBackendWin32::MODE_TUN_FAILED); runner_->PostTunRestart();
goto EXIT; goto EXIT;
} }
} else { } else {
@ -1921,6 +1883,139 @@ void TunWin32Iocp::WriteTunPacket(Packet *packet) {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
TunsafeRunner::TunsafeRunner(TunsafeBackendWin32 *backend)
: backend_(backend),
tun_(&backend->dns_blocker_, this),
wg_proc_(this, &tun_, this),
plugin_(CreateTunsafePlugin(this, &wg_proc_)),
tcp_socket_queue_(&net_, &wg_proc_.dev().packet_obfuscator()) {
wg_proc_.dev().SetPlugin(plugin_);
net_.udp().SetPacketHandler(&packet_processor_);
tcp_socket_queue_.SetPacketHandler(&packet_processor_);
tun_.SetPacketHandler(&packet_processor_);
}
TunsafeRunner::~TunsafeRunner() {
wg_proc_.dev().SetCurrentThreadAsMainThread();
delete plugin_;
}
bool TunsafeRunner::Configure(int listen_port, int listen_port_tcp) {
if (listen_port_tcp)
RERROR("ListenPortTCP not supported in this version");
return net_.udp().Configure(listen_port);
}
void TunsafeRunner::WriteUdpPacket(Packet *packet) {
if (packet->protocol & kPacketProtocolUdp) {
if (wg_proc_.dev().packet_obfuscator().enabled())
wg_proc_.dev().packet_obfuscator().ObfuscatePacket(packet);
net_.udp().WriteUdpPacket(packet);
} else {
tcp_socket_queue_.WritePacket(packet);
}
}
void TunsafeRunner::OnConnected() {
TunsafeBackendWin32 *backend = backend_;
if (backend->status() != TunsafeBackend::kStatusConnected) {
const WgCidrAddr *ipv4_addr = NULL;
for (const WgCidrAddr &x : wg_proc_.addr()) {
if (x.size == 32) {
ipv4_addr = &x;
break;
}
}
backend->ipv4_ip_ = ipv4_addr ? ReadBE32(ipv4_addr->addr) : 0;
if (backend->status() != TunsafeBackend::kStatusReconnecting) {
char buf[kSizeOfAddress];
RINFO("Connection established. IP %s", ipv4_addr ? print_ip_prefix(buf, AF_INET, ipv4_addr->addr, -1) : "(none)");
}
backend->SetStatus(TunsafeBackend::kStatusConnected);
}
}
void TunsafeRunner::OnConnectionRetry(uint32 attempts) {
TunsafeBackendWin32 *backend = backend_;
if (backend->status() == TunsafeBackend::kStatusInitializing)
backend->SetStatus(TunsafeBackend::kStatusConnecting);
else if (attempts >= 3 && backend->status() == TunsafeBackend::kStatusConnected)
backend->SetStatus(TunsafeBackend::kStatusReconnecting);
}
bool TunsafeRunner::Start() {
wg_proc_.dev().SetCurrentThreadAsMainThread();
if (config_file_.size()) {
if (config_file_is_text_format_) {
if (!ParseWireGuardConfigString(&wg_proc_, config_file_.c_str(), config_file_.size(), &backend_->dns_resolver_))
return false;
} else {
if (!ParseWireGuardConfigFile(&wg_proc_, config_file_.c_str(), &backend_->dns_resolver_))
return false;
}
}
if (wg_proc_.dev().packet_obfuscator().enabled())
packet_processor_.EnableDeobfuscation();
if (!wg_proc_.Start())
return false;
backend_->SetPublicKey(wg_proc_.dev().public_key());
net_.StartThread();
tun_.StartThread();
int stop_mode = packet_processor_.Run(&wg_proc_, this);
net_.StopThread();
tun_.StopThread();
if (stop_mode != TunsafeBackendWin32::MODE_EXIT)
tun_.adapter().DisassociateDnsBlocker();
else
backend_->dns_resolver_.ClearCache();
return true;
}
void TunsafeRunner::PostTunRestart() {
QueuedItem *qi = new QueuedItem;
qi->queue_cb = this;
packet_processor_.ForcePost(qi);
}
void TunsafeRunner::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) {
backend_->SetStatus(TunsafeBackend::kStatusTunRetrying);
RINFO("Restarting TUN adapter");
Sleep(1000);
wg_proc_.ConfigureTun();
delete ow;
}
void TunsafeRunner::OnQueuedItemDelete(QueuedItem *ow) {
delete ow;
}
void TunsafeRunner::OnRequestToken(WgPeer *peer, uint32 type) {
backend_->OnRequestToken(peer, type);
}
void TunsafeRunner::CollectStats() {
backend_->CollectStats();
}
void TunsafeRunner::SetConfigFile(const char *file, bool is_text_format) {
config_file_is_text_format_ = is_text_format;
config_file_ = file;
}
//////////////////////////////////////////////////////////////////////////////
TunsafeBackend::TunsafeBackend() { TunsafeBackend::TunsafeBackend() {
is_started_ = false; is_started_ = false;
is_remote_ = false; is_remote_ = false;
@ -1946,8 +2041,8 @@ static void RemoveKillSwitchRoute() {
TunsafeBackendWin32::TunsafeBackendWin32(Delegate *delegate) : delegate_(delegate), dns_resolver_(&dns_blocker_) { TunsafeBackendWin32::TunsafeBackendWin32(Delegate *delegate) : delegate_(delegate), dns_resolver_(&dns_blocker_) {
memset(&stats_, 0, sizeof(stats_)); memset(&stats_, 0, sizeof(stats_));
wg_processor_ = NULL;
token_request_ = 0; token_request_ = 0;
runner_ = NULL;
InitPacketMutexes(); InitPacketMutexes();
worker_thread_ = NULL; worker_thread_ = NULL;
last_tun_adapter_failed_ = 0; last_tun_adapter_failed_ = 0;
@ -1972,79 +2067,28 @@ void TunsafeBackendWin32::SetPublicKey(const uint8 key[32]) {
delegate_->OnStateChanged(); delegate_->OnStateChanged();
} }
struct PluginHolder { void TunsafeBackendWin32::CollectStats() {
PluginHolder(PluginDelegate *del) : plugin(CreateTunsafePlugin(del)) {} stats_mutex_.Acquire();
~PluginHolder() { delete plugin; } stats_ = runner_->wg_proc_.GetStats();
TunsafePlugin *plugin; float data[2] = {
}; // unit is megabits/second
stats_.tun_bytes_in_per_second * (1.0f / 125000),
stats_.tun_bytes_out_per_second * (1.0f / 125000),
};
stats_collector_.AddSamples(data);
stats_mutex_.Release();
delegate_->OnGraphAvailable();
PushStats();
}
DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) { DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) {
TunsafeBackendWin32 *backend = (TunsafeBackendWin32*)bk; TunsafeBackendWin32 *backend = (TunsafeBackendWin32*)bk;
int stop_mode;
int fast_retry_ctr = 0;
for (;;) { if (!backend->runner_->Start()) {
TunWin32Iocp tun(&backend->dns_blocker_, backend); backend->SetStatus(TunsafeBackend::kErrorInitialize);
NetworkWin32 net;
PluginHolder plugin(backend);
WireguardProcessor wg_proc(&net, &tun, backend);
wg_proc.dev().SetPlugin(plugin.plugin);
plugin.plugin->Initialize(&wg_proc);
net.udp().SetPacketHandler(&backend->packet_processor_);
net.tcp_socket_queue().SetPacketHandler(&backend->packet_processor_);
tun.SetPacketHandler(&backend->packet_processor_);
if (backend->config_file_[0] &&
!ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->dns_resolver_))
goto getout_fail;
if (!wg_proc.Start())
goto getout_fail;
backend->SetPublicKey(wg_proc.dev().public_key());
backend->wg_processor_ = &wg_proc;
backend->tunsafe_wg_plugin_ = plugin.plugin;
net.StartThread();
tun.StartThread();
stop_mode = backend->packet_processor_.Run(&wg_proc, backend);
net.StopThread();
tun.StopThread();
backend->wg_processor_ = NULL;
backend->tunsafe_wg_plugin_ = NULL;
// Keep DNS alive
if (stop_mode != MODE_EXIT)
tun.adapter().DisassociateDnsBlocker();
else
backend->dns_resolver_.ClearCache();
FreeAllPackets();
if (stop_mode != MODE_TUN_FAILED)
return 0;
uint32 last_fail = GetTickCount();
fast_retry_ctr = (last_fail - backend->last_tun_adapter_failed_ < 5000) ? fast_retry_ctr + 1 : 0;
backend->last_tun_adapter_failed_ = last_fail;
backend->SetStatus((fast_retry_ctr >= 3) ? TunsafeBackend::kErrorTunPermanent : TunsafeBackend::kStatusTunRetrying);
if (backend->status_ == TunsafeBackend::kErrorTunPermanent) {
RERROR("Too many automatic restarts...");
goto getout_fail_noseterr;
}
Sleep(1000);
}
getout_fail:
backend->status_ = TunsafeBackend::kErrorInitialize;
backend->delegate_->OnStatusCode(TunsafeBackend::kErrorInitialize);
getout_fail_noseterr:
backend->dns_blocker_.RestoreDns(); backend->dns_blocker_.RestoreDns();
}
return 0; return 0;
} }
@ -2104,29 +2148,37 @@ void TunsafeBackendWin32::Start(const char *config_file) {
SetStatus(kStatusInitializing); SetStatus(kStatusInitializing);
delegate_->OnClearLog(); delegate_->OnClearLog();
DWORD thread_id; DWORD thread_id;
config_file_ = _strdup(config_file);
runner_ = new TunsafeRunner(this);
// Connect to a server given by an ID.
if (strncmp(config_file, ":srv:", 5) == 0) {
// config_file_is_text_format_ = true;
// auto server = GetServerById(config_file + 5, NULL);
// config_file_ = GetServerConfigFile(server);
} else {
runner_->SetConfigFile(config_file, false);
}
worker_thread_ = CreateThread(NULL, 0, &WorkerThread, this, 0, &thread_id); worker_thread_ = CreateThread(NULL, 0, &WorkerThread, this, 0, &thread_id);
SetThreadPriority(worker_thread_, THREAD_PRIORITY_ABOVE_NORMAL); SetThreadPriority(worker_thread_, THREAD_PRIORITY_ABOVE_NORMAL);
delegate_->OnStateChanged(); delegate_->OnStateChanged();
} }
void TunsafeBackendWin32::PostExit(int exit_code) {
packet_processor_.PostExit(exit_code);
}
void TunsafeBackendWin32::StopInner(bool is_restart) { void TunsafeBackendWin32::StopInner(bool is_restart) {
if (worker_thread_) { if (runner_) {
ipv4_ip_ = 0; ipv4_ip_ = 0;
dns_resolver_.Cancel(); dns_resolver_.Cancel();
PostExit(is_restart ? MODE_RESTART : MODE_EXIT); runner_->packet_processor_.PostExit(is_restart ? MODE_RESTART : MODE_EXIT);
WaitForSingleObject(worker_thread_, INFINITE); WaitForSingleObject(worker_thread_, INFINITE);
CloseHandle(worker_thread_); CloseHandle(exch_null(worker_thread_));
worker_thread_ = NULL;
free(config_file_);
config_file_ = NULL;
is_started_ = false; is_started_ = false;
status_ = kStatusStopped; status_ = kStatusStopped;
packet_processor_.Reset(); delete runner_;
runner_ = NULL;
FreeAllPackets();
uint8 wanted_ibs = (g_killswitch_currconn == kBlockInternet_Default) ? g_killswitch_want : g_killswitch_currconn; uint8 wanted_ibs = (g_killswitch_currconn == kBlockInternet_Default) ? g_killswitch_want : g_killswitch_currconn;
if (!is_restart && !(wanted_ibs & kBlockInternet_BlockOnDisconnect)) if (!is_restart && !(wanted_ibs & kBlockInternet_BlockOnDisconnect))
@ -2228,9 +2280,9 @@ void ConfigQueueItem::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) {
if (type == SendConfigurationProtocolPacket) { if (type == SendConfigurationProtocolPacket) {
std::string reply; std::string reply;
WgConfig::HandleConfigurationProtocolMessage(context->wg, std::move(message), &reply); WgConfig::HandleConfigurationProtocolMessage(context->wg, std::move(message), &reply);
context->backend->delegate_->OnConfigurationProtocolReply(ident, std::move(reply)); context->runner->backend()->delegate_->OnConfigurationProtocolReply(ident, std::move(reply));
} else { } else {
context->backend->tunsafe_wg_plugin_->SubmitToken((const uint8*)message.data(), message.size()); context->runner->plugin()->SubmitToken((const uint8*)message.data(), message.size());
} }
delete this; delete this;
} }
@ -2240,15 +2292,18 @@ void ConfigQueueItem::OnQueuedItemDelete(QueuedItem *ow) {
} }
void TunsafeBackendWin32::SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) { void TunsafeBackendWin32::SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) {
if (runner_) {
ConfigQueueItem *queue_item = new ConfigQueueItem; ConfigQueueItem *queue_item = new ConfigQueueItem;
queue_item->type = ConfigQueueItem::SendConfigurationProtocolPacket; queue_item->type = ConfigQueueItem::SendConfigurationProtocolPacket;
queue_item->ident = identifier; queue_item->ident = identifier;
queue_item->message = std::move(message); queue_item->message = std::move(message);
queue_item->queue_cb = queue_item; queue_item->queue_cb = queue_item;
packet_processor_.ForcePost(queue_item); runner_->packet_processor_.ForcePost(queue_item);
}
} }
void TunsafeBackendWin32::SubmitToken(const std::string &&message) { void TunsafeBackendWin32::SubmitToken(const std::string &&message) {
if (runner_) {
// Clear out the old token request so GetTokenRequest returns zero. // Clear out the old token request so GetTokenRequest returns zero.
token_request_ = 0; token_request_ = 0;
@ -2256,8 +2311,8 @@ void TunsafeBackendWin32::SubmitToken(const std::string &&message) {
queue_item->type = ConfigQueueItem::SubmitToken; queue_item->type = ConfigQueueItem::SubmitToken;
queue_item->message = std::move(message); queue_item->message = std::move(message);
queue_item->queue_cb = queue_item; queue_item->queue_cb = queue_item;
packet_processor_.ForcePost(queue_item); runner_->packet_processor_.ForcePost(queue_item);
}
} }
uint32 TunsafeBackendWin32::GetTokenRequest() { uint32 TunsafeBackendWin32::GetTokenRequest() {
@ -2271,32 +2326,6 @@ void TunsafeBackendWin32::OnRequestToken(WgPeer *peer, uint32 type) {
delegate_->OnStateChanged(); delegate_->OnStateChanged();
} }
void TunsafeBackendWin32::OnConnected() {
if (status_ != TunsafeBackend::kStatusConnected) {
const WgCidrAddr *ipv4_addr = NULL;
for (const WgCidrAddr &x : wg_processor_->addr()) {
if (x.size == 32) {
ipv4_addr = &x;
break;
}
}
ipv4_ip_ = ipv4_addr ? ReadBE32(ipv4_addr->addr) : 0;
if (status_ != TunsafeBackend::kStatusReconnecting) {
char buf[kSizeOfAddress];
RINFO("Connection established. IP %s", ipv4_addr ? print_ip_prefix(buf, AF_INET, ipv4_addr->addr, -1) : "(none)");
}
SetStatus(TunsafeBackend::kStatusConnected);
}
}
void TunsafeBackendWin32::OnConnectionRetry(uint32 attempts) {
if (status_ == TunsafeBackend::kStatusInitializing)
SetStatus(TunsafeBackend::kStatusConnecting);
else if (attempts >= 3 && status_ == TunsafeBackend::kStatusConnected)
SetStatus(TunsafeBackend::kStatusReconnecting);
}
void TunsafeBackend::Delegate::DoWork() { void TunsafeBackend::Delegate::DoWork() {
// implemented by subclasses // implemented by subclasses
} }

View file

@ -18,6 +18,7 @@ enum {
class WireguardProcessor; class WireguardProcessor;
class TunsafeBackendWin32; class TunsafeBackendWin32;
class TunsafeRunner;
class DnsBlocker; class DnsBlocker;
struct PacketProcessorTunCb : QueuedItemCallback { struct PacketProcessorTunCb : QueuedItemCallback {
@ -41,7 +42,7 @@ public:
void Reset(); void Reset();
int Run(WireguardProcessor *wg, TunsafeBackendWin32 *backend); int Run(WireguardProcessor *wg, TunsafeRunner *runner);
void PostPackets(Packet *first, Packet **end, int count); void PostPackets(Packet *first, Packet **end, int count);
void ForcePost(QueuedItem *item); void ForcePost(QueuedItem *item);
void PostExit(int exit_code); void PostExit(int exit_code);
@ -62,7 +63,7 @@ public:
struct QueueContext { struct QueueContext {
WireguardProcessor *wg; WireguardProcessor *wg;
TunsafeBackendWin32 *backend; TunsafeRunner *runner;
bool overload; bool overload;
}; };
@ -142,12 +143,13 @@ private:
Packet *finished_reads_, **finished_reads_end_; Packet *finished_reads_, **finished_reads_end_;
int finished_reads_count_; int finished_reads_count_;
__declspec(align(64)) uint32 qsize1_; uint32 qsize1_;
__declspec(align(64)) uint32 qsize2_; uint8 align[64-4];
uint32 qsize2_;
}; };
// Holds the thread for network communications // Holds the thread for network communications
class NetworkWin32 : public UdpInterface { class NetworkWin32 {
friend class UdpSocketWin32; friend class UdpSocketWin32;
friend class TcpSocketWin32; friend class TcpSocketWin32;
friend class TcpSocketQueue; friend class TcpSocketQueue;
@ -160,14 +162,8 @@ public:
UdpSocketWin32 &udp() { return udp_socket_; } UdpSocketWin32 &udp() { return udp_socket_; }
SimplePacketPool &packet_pool() { return packet_pool_; } SimplePacketPool &packet_pool() { return packet_pool_; }
TcpSocketQueue &tcp_socket_queue() { return tcp_socket_queue_; }
void WakeUp(); void WakeUp();
void PostQueuedItem(QueuedItem *item); void PostQueuedItem(QueuedItem *item);
// -- from UdpInterface
virtual bool Configure(int listen_port_udp, int listen_port_tcp) override;
virtual void WriteUdpPacket(Packet *packet) override;
private: private:
void ThreadMain(); void ThreadMain();
static DWORD WINAPI NetworkThread(void *x); static DWORD WINAPI NetworkThread(void *x);
@ -190,8 +186,6 @@ private:
TcpSocketWin32 *tcp_socket_; TcpSocketWin32 *tcp_socket_;
SimplePacketPool packet_pool_; SimplePacketPool packet_pool_;
TcpSocketQueue tcp_socket_queue_;
}; };
class TunWin32Adapter { class TunWin32Adapter {
@ -199,7 +193,7 @@ public:
TunWin32Adapter(DnsBlocker *dns_blocker, const char guid[ADAPTER_GUID_SIZE]); TunWin32Adapter(DnsBlocker *dns_blocker, const char guid[ADAPTER_GUID_SIZE]);
~TunWin32Adapter(); ~TunWin32Adapter();
bool OpenAdapter(TunsafeBackendWin32 *backend, DWORD open_flags); bool OpenAdapter(TunsafeRunner *backend, DWORD open_flags);
bool ConfigureAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out); bool ConfigureAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out);
void CloseAdapter(bool is_restart); void CloseAdapter(bool is_restart);
@ -233,7 +227,7 @@ private:
// Implementation of TUN interface handling using IO Completion Ports // Implementation of TUN interface handling using IO Completion Ports
class TunWin32Iocp : public TunInterface { class TunWin32Iocp : public TunInterface {
public: public:
explicit TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend); explicit TunWin32Iocp(DnsBlocker *blocker, TunsafeRunner *backend);
~TunWin32Iocp(); ~TunWin32Iocp();
void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; } void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
@ -266,11 +260,61 @@ private:
// All packets queued for writing // All packets queued for writing
Packet *wqueue_, **wqueue_end_; Packet *wqueue_, **wqueue_end_;
TunsafeBackendWin32 *backend_; TunsafeRunner *runner_;
TunWin32Adapter adapter_; TunWin32Adapter adapter_;
}; };
class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate, public PluginDelegate { // This class is the actual TunSafe thing and runs inside of a thread.
class TunsafeRunner : public UdpInterface, public ProcessorDelegate, public PluginDelegate, public QueuedItemCallback {
friend class TunsafeBackendWin32;
public:
TunsafeRunner(TunsafeBackendWin32 *backend);
~TunsafeRunner();
void SetConfigFile(const char *file, bool is_text_format);
TunsafeBackendWin32 *backend() { return backend_; }
// -- from UdpInterface
virtual bool Configure(int listen_port_udp, int listen_port_tcp) override;
virtual void WriteUdpPacket(Packet *packet) override;
virtual void OnConnected() override;
virtual void OnConnectionRetry(uint32 attempts) override;
// -- from PluginDelegate
virtual void OnRequestToken(WgPeer *peer, uint32 type) override;
bool Start();
// Called by the tun thing if tun stops working and a reset is needed.
void PostTunRestart();
uint32 exit_code() { return *packet_processor_.posted_exit_code(); }
TunsafePlugin *plugin() { return plugin_; }
void CollectStats();
private:
// From OverlappedCallbacks
virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override;
virtual void OnQueuedItemDelete(QueuedItem *ow) override;
TunsafeBackendWin32 *backend_;
TunsafePlugin *plugin_;
bool config_file_is_text_format_;
std::string config_file_;
TunWin32Iocp tun_;
NetworkWin32 net_;
TcpSocketQueue tcp_socket_queue_;
WireguardProcessor wg_proc_;
PacketProcessor packet_processor_;
};
class TunsafeBackendWin32 : public TunsafeBackend {
friend class TunsafeRunner;
friend class PacketProcessor; friend class PacketProcessor;
friend class TunWin32Iocp; friend class TunWin32Iocp;
friend class TunWin32Overlapped; friend class TunWin32Overlapped;
@ -297,52 +341,44 @@ public:
virtual uint32 GetTokenRequest() override; virtual uint32 GetTokenRequest() override;
virtual void SubmitToken(const std::string &&message) override; virtual void SubmitToken(const std::string &&message) override;
// -- from ProcessorDelegate void OnRequestToken(WgPeer *peer, uint32 type);
virtual void OnConnected() override;
virtual void OnConnectionRetry(uint32 attempts) override;
// -- from PluginDelegate
virtual void OnRequestToken(WgPeer *peer, uint32 type) override;
void SetPublicKey(const uint8 key[32]); void SetPublicKey(const uint8 key[32]);
void PostExit(int exit_code);
StatusCode status() { return status_; }
void SetStatus(StatusCode status);
void CollectStats();
private:
enum { enum {
MODE_NONE = 0, MODE_NONE = 0,
MODE_EXIT = 1, MODE_EXIT = 1,
MODE_RESTART = 2, MODE_RESTART = 2,
MODE_TUN_FAILED = 3,
}; };
uint32 exit_code() { return *packet_processor_.posted_exit_code(); }
void SetStatus(StatusCode status);
private:
void StopInner(bool is_restart); void StopInner(bool is_restart);
static DWORD WINAPI WorkerThread(void *x); static DWORD WINAPI WorkerThread(void *x);
void PushStats(); void PushStats();
TunsafeRunner *runner_;
HANDLE worker_thread_; HANDLE worker_thread_;
bool want_periodic_stats_; bool want_periodic_stats_;
Delegate *delegate_; Delegate *delegate_;
char *config_file_;
std::atomic<uint32> token_request_; std::atomic<uint32> token_request_;
DnsBlocker dns_blocker_; DnsBlocker dns_blocker_;
DnsResolver dns_resolver_; DnsResolver dns_resolver_;
WireguardProcessor *wg_processor_;
TunsafePlugin *tunsafe_wg_plugin_;
uint32 last_tun_adapter_failed_; uint32 last_tun_adapter_failed_;
StatsCollector stats_collector_; StatsCollector stats_collector_;
Mutex stats_mutex_; Mutex stats_mutex_;
WgProcessorStats stats_; WgProcessorStats stats_;
PacketProcessor packet_processor_;
char guid_[ADAPTER_GUID_SIZE]; char guid_[ADAPTER_GUID_SIZE];
}; };

View file

@ -54,7 +54,6 @@ public:
kStatusTunRetrying = 10, kStatusTunRetrying = 10,
kErrorInitialize = -1, kErrorInitialize = -1,
kErrorTunPermanent = -2,
kErrorServiceLost = -3, kErrorServiceLost = -3,
}; };

View file

@ -9,18 +9,18 @@
//////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////
TcpSocketWin32::TcpSocketWin32(NetworkWin32 *network) TcpSocketWin32::TcpSocketWin32(NetworkWin32 *network, PacketProcessor *packet_handler, WgPacketObfuscator *obfuscator, bool is_incoming)
: tcp_packet_handler_(&network->packet_pool()) { : packet_processor_(packet_handler), tcp_packet_handler_(&network->packet_pool(), obfuscator, is_incoming) {
network_ = network; network_ = network;
reads_active_ = 0; reads_active_ = 0;
writes_active_ = 0; writes_active_ = 0;
handshake_attempts = 0; handshake_attempts = 0;
handshake_timestamp_ = 0;
state_ = STATE_NONE; state_ = STATE_NONE;
wqueue_ = NULL; wqueue_ = NULL;
wqueue_end_ = &wqueue_; wqueue_end_ = &wqueue_;
socket_ = INVALID_SOCKET; socket_ = INVALID_SOCKET;
next_ = NULL; next_ = NULL;
packet_processor_ = NULL;
// insert in network's linked list // insert in network's linked list
next_ = network->tcp_socket_; next_ = network->tcp_socket_;
network->tcp_socket_ = this; network->tcp_socket_ = this;
@ -45,6 +45,7 @@ void TcpSocketWin32::CloseSocket() {
} }
void TcpSocketWin32::WritePacket(Packet *packet) { void TcpSocketWin32::WritePacket(Packet *packet) {
packet->prepared = false;
packet->queue_next = NULL; packet->queue_next = NULL;
*wqueue_end_ = packet; *wqueue_end_ = packet;
wqueue_end_ = &Packet_NEXT(packet); wqueue_end_ = &Packet_NEXT(packet);
@ -145,7 +146,9 @@ void TcpSocketWin32::DoMoreWrites() {
return; return;
do { do {
tcp_packet_handler_.AddHeaderToOutgoingPacket(p); if (!p->prepared)
tcp_packet_handler_.PrepareOutgoingPackets(p);
wsabuf[num_wsabuf].buf = (char*)p->data; wsabuf[num_wsabuf].buf = (char*)p->data;
wsabuf[num_wsabuf].len = (ULONG)p->size; wsabuf[num_wsabuf].len = (ULONG)p->size;
packets_in_write_io_[num_wsabuf] = p; packets_in_write_io_[num_wsabuf] = p;
@ -179,8 +182,7 @@ void TcpSocketWin32::DoIO() {
while (Packet *p = tcp_packet_handler_.GetNextWireguardPacket()) { while (Packet *p = tcp_packet_handler_.GetNextWireguardPacket()) {
p->protocol = endpoint_protocol_; p->protocol = endpoint_protocol_;
p->addr = endpoint_; p->addr = endpoint_;
p->queue_cb = packet_processor_->tcp_queue();
p->queue_cb = packet_processor_->udp_queue();
packet_processor_->ForcePost(p); packet_processor_->ForcePost(p);
} }
if (tcp_packet_handler_.error()) { if (tcp_packet_handler_.error()) {
@ -269,63 +271,73 @@ void TcpSocketWin32::OnQueuedItemDelete(QueuedItem *qi) {
///////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////
TcpSocketQueue::TcpSocketQueue(NetworkWin32 *network) { TcpSocketQueue::TcpSocketQueue(NetworkWin32 *network, WgPacketObfuscator *obfuscator) {
network_ = network; network_ = network;
wqueue_ = NULL; wqueue_ = NULL;
wqueue_end_ = &wqueue_; wqueue_end_ = &wqueue_;
queued_item_.queue_cb = this; queued_item_.queue_cb = this;
packet_handler_ = NULL; packet_handler_ = NULL;
obfuscator_ = obfuscator;
} }
TcpSocketQueue::~TcpSocketQueue() { TcpSocketQueue::~TcpSocketQueue() {
FreePacketList(wqueue_); FreePacketList(wqueue_);
} }
void TcpSocketQueue::TransmitOnePacket(Packet *packet) { void TcpSocketQueue::TransmitPackets(Packet *packet) {
AGAIN:
while (packet) {
bool is_handshake = ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION;
// Check if we have a tcp connection for the endpoint, otherwise create one. // Check if we have a tcp connection for the endpoint, otherwise create one.
for (TcpSocketWin32 *tcp = network_->tcp_socket_; tcp; tcp = tcp->next_) { for (TcpSocketWin32 *tcp = network_->tcp_socket_; tcp; tcp = tcp->next_) {
// After we send 3 handshakes on a tcp socket in a row, then close and reopen the socket because it seems defunct. // After we send 3 handshakes on a tcp socket in a row within a minute,
// then close and reopen the socket because it seems defunct.
if (CompareIpAddr(&tcp->endpoint_, &packet->addr) == 0 && tcp->endpoint_protocol_ == packet->protocol) { if (CompareIpAddr(&tcp->endpoint_, &packet->addr) == 0 && tcp->endpoint_protocol_ == packet->protocol) {
if (ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION) { if (is_handshake) {
if (tcp->handshake_attempts == 2) { uint32 now = (uint32)OsGetMilliseconds();
uint32 secs = (now - tcp->handshake_timestamp_) >> 10;
tcp->handshake_timestamp_ += secs * 1024;
int calc = (secs > (uint32)tcp->handshake_attempts + 25) ? 0 : tcp->handshake_attempts + 25 - secs;
tcp->handshake_attempts = calc;
if (calc >= 60) {
RINFO("Making new Tcp socket due to too many handshake failures"); RINFO("Making new Tcp socket due to too many handshake failures");
tcp->CloseSocket(); tcp->CloseSocket();
break; break;
} }
tcp->handshake_attempts++;
} else {
tcp->handshake_attempts = -1;
} }
tcp->WritePacket(packet); tcp->WritePacket(exch(packet, Packet_NEXT(packet)));
return; goto AGAIN;
} }
} }
// Drop tcp packet that's for an incoming connection, or packets that are // Drop tcp packet that's for an incoming connection, or packets that are
// not a handshake. // not a handshake.
if ((packet->protocol & kPacketProtocolIncomingConnection) || if ((packet->protocol & kPacketProtocolIncomingConnection) || !is_handshake) {
packet->size < 4 || ReadLE32(packet->data) != MESSAGE_HANDSHAKE_INITIATION) { FreePacket(exch(packet, Packet_NEXT(packet)));
FreePacket(packet); continue;
return;
} }
// Initialize a new tcp socket and connect to the endpoint // Initialize a new tcp socket and connect to the endpoint
TcpSocketWin32 *tcp = new TcpSocketWin32(network_); TcpSocketWin32 *tcp = new TcpSocketWin32(network_, packet_handler_, obfuscator_, false);
tcp->state_ = TcpSocketWin32::STATE_WANT_CONNECT; tcp->state_ = TcpSocketWin32::STATE_WANT_CONNECT;
tcp->endpoint_ = packet->addr; tcp->endpoint_ = packet->addr;
tcp->endpoint_protocol_ = kPacketProtocolTcp; tcp->endpoint_protocol_ = kPacketProtocolTcp;
tcp->SetPacketHandler(packet_handler_); tcp->handshake_timestamp_ = (uint32)OsGetMilliseconds();
tcp->WritePacket(packet); tcp->WritePacket(exch(packet, Packet_NEXT(packet)));
}
} }
void TcpSocketQueue::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) { void TcpSocketQueue::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) {
// Runs on the network thread
wqueue_mutex_.Acquire(); wqueue_mutex_.Acquire();
Packet *packet = wqueue_; Packet *packet = wqueue_;
wqueue_ = NULL; wqueue_ = NULL;
wqueue_end_ = &wqueue_; wqueue_end_ = &wqueue_;
wqueue_mutex_.Release(); wqueue_mutex_.Release();
while (packet)
TransmitOnePacket(exch(packet, Packet_NEXT(packet))); TransmitPackets(packet);
} }
void TcpSocketQueue::OnQueuedItemDelete(QueuedItem *ow) { void TcpSocketQueue::OnQueuedItemDelete(QueuedItem *ow) {

View file

@ -9,19 +9,16 @@
class NetworkWin32; class NetworkWin32;
class PacketProcessor; class PacketProcessor;
class WgPacketObfuscator;
class TcpSocketWin32 : public QueuedItemCallback { class TcpSocketWin32 : public QueuedItemCallback {
friend class NetworkWin32; friend class NetworkWin32;
friend class TcpSocketQueue; friend class TcpSocketQueue;
public: public:
explicit TcpSocketWin32(NetworkWin32 *network); explicit TcpSocketWin32(NetworkWin32 *network, PacketProcessor *packet_handler, WgPacketObfuscator *obfuscator, bool is_incoming);
~TcpSocketWin32(); ~TcpSocketWin32();
void SetPacketHandler(PacketProcessor *packet_handler) { packet_processor_ = packet_handler; } // Write a packet to the TCP socket.
// Write a packet to the TCP socket. This may be called only from the
// wireguard thread. Will append to a buffer and schedule it to be written
// from the network thread.
void WritePacket(Packet *packet); void WritePacket(Packet *packet);
// Call from IO completion thread to cancel all outstanding IO // Call from IO completion thread to cancel all outstanding IO
@ -37,7 +34,6 @@ private:
void DoMoreReads(); void DoMoreReads();
void DoMoreWrites(); void DoMoreWrites();
void DoConnect(); void DoConnect();
void CloseSocket(); void CloseSocket();
// From OverlappedCallbacks // From OverlappedCallbacks
@ -64,6 +60,8 @@ private:
public: public:
uint8 handshake_attempts; uint8 handshake_attempts;
uint8 endpoint_protocol_;
uint32 handshake_timestamp_;
private: private:
// The handle to the socket // The handle to the socket
@ -86,7 +84,6 @@ private:
QueuedItem connect_overlapped_; QueuedItem connect_overlapped_;
IpAddr endpoint_; IpAddr endpoint_;
uint8 endpoint_protocol_;
// Packets currently involved in the wsabuf writing // Packets currently involved in the wsabuf writing
enum { kMaxWsaBuf = 32 }; enum { kMaxWsaBuf = 32 };
@ -95,7 +92,7 @@ private:
class TcpSocketQueue : public QueuedItemCallback { class TcpSocketQueue : public QueuedItemCallback {
public: public:
explicit TcpSocketQueue(NetworkWin32 *network); explicit TcpSocketQueue(NetworkWin32 *network, WgPacketObfuscator *obfusctor);
~TcpSocketQueue(); ~TcpSocketQueue();
void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; } void SetPacketHandler(PacketProcessor *packet_handler) { packet_handler_ = packet_handler; }
@ -106,7 +103,7 @@ public:
void WritePacket(Packet *packet); void WritePacket(Packet *packet);
private: private:
void TransmitOnePacket(Packet *packet); void TransmitPackets(Packet *packet);
NetworkWin32 *network_; NetworkWin32 *network_;
// All packets queued for writing on the network thread. Locked by |wqueue_mutex_| // All packets queued for writing on the network thread. Locked by |wqueue_mutex_|
@ -114,11 +111,12 @@ private:
PacketProcessor *packet_handler_; PacketProcessor *packet_handler_;
WgPacketObfuscator *obfuscator_;
// Protects wqueue_ // Protects wqueue_
Mutex wqueue_mutex_; Mutex wqueue_mutex_;
// Used for queueing things on the network instance // Used for queueing things on the network instance
QueuedItem queued_item_; QueuedItem queued_item_;
}; };

View file

@ -678,8 +678,6 @@ public:
WireguardProcessor *processor() { return &processor_; } WireguardProcessor *processor() { return &processor_; }
private: private:
void WriteTcpPacket(Packet *packet);
// Close all TCP connections that are not pointed to by any of the peer endpoint. // Close all TCP connections that are not pointed to by any of the peer endpoint.
void CloseOrphanTcpConnections(); void CloseOrphanTcpConnections();
@ -697,7 +695,7 @@ private:
TunsafeBackendBsdImpl::TunsafeBackendBsdImpl() TunsafeBackendBsdImpl::TunsafeBackendBsdImpl()
: is_connected_(false), : is_connected_(false),
close_orphan_counter_(0), close_orphan_counter_(0),
plugin_(CreateTunsafePlugin(this)), plugin_(CreateTunsafePlugin(this, &processor_)),
processor_(this, this, this), processor_(this, this, this),
network_(this, 1000), network_(this, 1000),
tun_(&network_, &processor_), tun_(&network_, &processor_),
@ -732,49 +730,11 @@ bool TunsafeBackendBsdImpl::Configure(int listen_port, int listen_port_tcp) {
(listen_port_tcp == 0 || tcp_socket_listener_.Initialize(listen_port_tcp)); (listen_port_tcp == 0 || tcp_socket_listener_.Initialize(listen_port_tcp));
} }
void TunsafeBackendBsdImpl::WriteTcpPacket(Packet *packet) {
// Check if we have a tcp connection for the endpoint, otherwise create one.
for (TcpSocketBsd *tcp = network_.tcp_sockets(); tcp; tcp = tcp->next()) {
// After we send 3 handshakes on a tcp socket in a row, then close and reopen the socket because it seems defunct.
if (CompareIpAddr(&tcp->endpoint(), &packet->addr) == 0 && tcp->endpoint_protocol() == packet->protocol) {
if (ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION) {
if (tcp->handshake_attempts == 2) {
RINFO("Making new Tcp socket due to too many handshake failures");
delete tcp;
break;
}
tcp->handshake_attempts++;
} else {
tcp->handshake_attempts = -1;
}
tcp->WritePacket(packet);
return;
}
}
// Drop tcp packet that's for an incoming connection, or packets that are
// not a handshake.
if ((packet->protocol & kPacketProtocolIncomingConnection) ||
ReadLE32(packet->data) != MESSAGE_HANDSHAKE_INITIATION) {
FreePacket(packet);
return;
}
// Initialize a new tcp socket and connect to the endpoint
TcpSocketBsd *tcp = new TcpSocketBsd(&network_, &processor_);
if (!tcp || !tcp->InitializeOutgoing(packet->addr)) {
delete tcp;
FreePacket(packet);
return;
}
tcp->WritePacket(packet);
}
void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) { void TunsafeBackendBsdImpl::WriteUdpPacket(Packet *packet) {
assert((packet->protocol & 0x7F) <= 2); assert((packet->protocol & 0x7F) <= 2);
if (packet->protocol & kPacketProtocolTcp) { if (packet->protocol & kPacketProtocolTcp) {
WriteTcpPacket(packet); TcpSocketBsd::WriteTcpPacket(&network_, &processor_, packet);
} else { } else {
if (processor_.dev().packet_obfuscator().enabled())
processor_.dev().packet_obfuscator().ObfuscatePacket(packet);
udp_.WritePacket(packet); udp_.WritePacket(packet);
} }
} }

View file

@ -1187,7 +1187,6 @@ static HFONT CreateFontHelper(int size, byte flags, const char *face, int angle
static const char *StatusCodeToString(TunsafeBackend::StatusCode code) { static const char *StatusCodeToString(TunsafeBackend::StatusCode code) {
switch (code) { switch (code) {
case TunsafeBackend::kErrorInitialize: return "Configuration Error"; case TunsafeBackend::kErrorInitialize: return "Configuration Error";
case TunsafeBackend::kErrorTunPermanent: return "TUN Adapter Error";
case TunsafeBackend::kErrorServiceLost: return "Service Lost"; case TunsafeBackend::kErrorServiceLost: return "Service Lost";
case TunsafeBackend::kStatusStopped: return "Disconnected"; case TunsafeBackend::kStatusStopped: return "Disconnected";
case TunsafeBackend::kStatusInitializing: return "Initializing"; case TunsafeBackend::kStatusInitializing: return "Initializing";

View file

@ -207,6 +207,10 @@ bool WireguardProcessor::ConfigureTun() {
for (WgPeer *peer = dev_.first_peer(); peer; peer = peer->next_peer_) { for (WgPeer *peer = dev_.first_peer(); peer; peer = peer->next_peer_) {
peer->ipv4_broadcast_addr_ = ipv4_broadcast_addr; peer->ipv4_broadcast_addr_ = ipv4_broadcast_addr;
if (peer->endpoint_protocol_ == kPacketProtocolTcp)
peer->allow_endpoint_change_ = false;
if (peer->endpoint_.sin.sin_family != 0) { if (peer->endpoint_.sin.sin_family != 0) {
RINFO("Sending handshake..."); RINFO("Sending handshake...");
SendHandshakeInitiation(peer); SendHandshakeInitiation(peer);
@ -419,7 +423,7 @@ WireguardProcessor::PacketResult WireguardProcessor::WriteAndEncryptPacketToUdp_
uint64 send_ctr; uint64 send_ctr;
// Ensure packet will fit including the biggest padding // Ensure packet will fit including the biggest padding
if (peer->endpoint_.sin.sin_family == 0 || if (peer->data_endpoint_.sin.sin_family == 0 ||
size > kPacketCapacity - 15 - CHACHA20POLY1305_AUTHTAGLEN) size > kPacketCapacity - 15 - CHACHA20POLY1305_AUTHTAGLEN)
goto getout_discard; goto getout_discard;
@ -443,8 +447,8 @@ WireguardProcessor::PacketResult WireguardProcessor::WriteAndEncryptPacketToUdp_
want_handshake = (send_ctr >= REKEY_AFTER_MESSAGES || want_handshake = (send_ctr >= REKEY_AFTER_MESSAGES ||
keypair->send_key_state == WgKeypair::KEY_WANT_REFRESH); keypair->send_key_state == WgKeypair::KEY_WANT_REFRESH);
keypair->send_ctr = send_ctr + 1; keypair->send_ctr = send_ctr + 1;
packet->addr = peer->endpoint_; packet->addr = peer->data_endpoint_;
packet->protocol = peer->endpoint_protocol_; packet->protocol = peer->data_endpoint_protocol_;
WG_EXTENSION_HOOKS::OnPeerOutgoingUdp(peer, packet); WG_EXTENSION_HOOKS::OnPeerOutgoingUdp(peer, packet);
@ -639,7 +643,9 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) {
if (attempts >= 3 && peer->allow_endpoint_change_ && if (attempts >= 3 && peer->allow_endpoint_change_ &&
(peer->endpoint_protocol_ & kPacketProtocolIncomingConnection)) { (peer->endpoint_protocol_ & kPacketProtocolIncomingConnection)) {
peer->endpoint_protocol_ = 0; peer->endpoint_protocol_ = 0;
peer->data_endpoint_protocol_ = 0;
peer->endpoint_.sin.sin_family = 0; peer->endpoint_.sin.sin_family = 0;
peer->data_endpoint_.sin.sin_family = 0;
} }
WG_RELEASE_LOCK(peer->mutex_); WG_RELEASE_LOCK(peer->mutex_);
@ -841,15 +847,23 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack
// Remember the endpoint of the peer // Remember the endpoint of the peer
if (peer->allow_endpoint_change_ && if (peer->allow_endpoint_change_ &&
(CompareIpAddr(&peer->endpoint_, &packet->addr) | (peer->endpoint_protocol_ ^ packet->protocol)) != 0) { (CompareIpAddr(&peer->data_endpoint_, &packet->addr) | (peer->data_endpoint_protocol_ ^ packet->protocol)) != 0) {
#if WITH_SHORT_HEADERS #if WITH_SHORT_HEADERS
// When the endpoint changes, forget about using the short key. // When the endpoint changes, forget about using the short key.
keypair->broadcast_short_key = 0; keypair->broadcast_short_key = 0;
keypair->can_use_short_key_for_outgoing = false; keypair->can_use_short_key_for_outgoing = false;
#endif // WITH_SHORT_HEADERS #endif // WITH_SHORT_HEADERS
peer->data_endpoint_ = packet->addr;
peer->data_endpoint_protocol_ = packet->protocol;
// In the hybrid tcp mode, only the data endpoint gets overwritten on incoming data packets.
if (!keypair->enabled_features[WG_FEATURE_HYBRID_TCP]) {
peer->endpoint_ = packet->addr; peer->endpoint_ = packet->addr;
peer->endpoint_protocol_ = packet->protocol; peer->endpoint_protocol_ = packet->protocol;
} }
}
WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet); WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet);

View file

@ -56,12 +56,14 @@ static int ParseFeature(const char *str) {
} }
if (len == 5 && memcmp(str, "mac64", 5) == 0) if (len == 5 && memcmp(str, "mac64", 5) == 0)
return what + WG_FEATURE_ID_SHORT_MAC * 16; return what + WG_FEATURE_ID_SHORT_MAC * 16;
if (len == 12 && memcmp(str, "short_header", 12) == 0)
return what + WG_FEATURE_ID_SHORT_HEADER * 16;
if (len == 5 && memcmp(str, "ipzip", 5) == 0) if (len == 5 && memcmp(str, "ipzip", 5) == 0)
return what + WG_FEATURE_ID_IPZIP * 16; return what + WG_FEATURE_ID_IPZIP * 16;
if (len == 10 && memcmp(str, "hybrid_tcp", 10) == 0)
return what + WG_FEATURE_HYBRID_TCP * 16;
if (len == 10 && memcmp(str, "skip_keyid", 10) == 0) if (len == 10 && memcmp(str, "skip_keyid", 10) == 0)
return what + WG_FEATURE_ID_SKIP_KEYID_IN * 16 + 1 * 4; return what + WG_FEATURE_ID_SKIP_KEYID_IN * 16 + 1 * 4;
if (len == 12 && memcmp(str, "short_header", 12) == 0)
return what + WG_FEATURE_ID_SHORT_HEADER * 16;
if (len == 13 && memcmp(str, "skip_keyid_in", 13) == 0) if (len == 13 && memcmp(str, "skip_keyid_in", 13) == 0)
return what + WG_FEATURE_ID_SKIP_KEYID_IN * 16; return what + WG_FEATURE_ID_SKIP_KEYID_IN * 16;
if (len == 14 && memcmp(str, "skip_keyid_out", 14) == 0) if (len == 14 && memcmp(str, "skip_keyid_out", 14) == 0)
@ -169,8 +171,22 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
} }
wg_->SetInternetBlocking((InternetBlockState)v); wg_->SetInternetBlocking((InternetBlockState)v);
} else if (strcmp(key, "HeaderObfuscation") == 0) { } else if (strcmp(key, "ObfuscateKey") == 0) {
wg_->dev().packet_obfuscator().SetKey((uint8*)value, strlen(value)); wg_->dev().packet_obfuscator().SetKey((uint8*)value, strlen(value));
} else if (strcmp(key, "ObfuscateTCP") == 0) {
bool flag;
int v = 1;
if (ParseBoolean(value, &flag)) {
v = flag;
} else if (strcmp(value, "tls-firefox") == 0) {
v = 2;
} else if (strcmp(value, "tls-chrome") == 0) {
v = 3;
} else if (*value != 0) {
RERROR("Unknown mode in ObfuscateTCP: %s", value);
}
wg_->dev().packet_obfuscator().set_obfuscate_tcp(v);
} else if (strcmp(key, "PostUp") == 0) { } else if (strcmp(key, "PostUp") == 0) {
wg_->prepost().post_up.emplace_back(value); wg_->prepost().post_up.emplace_back(value);
} else if (strcmp(key, "PostDown") == 0) { } else if (strcmp(key, "PostDown") == 0) {

View file

@ -334,7 +334,9 @@ WgPeer::WgPeer(WgDevice *dev) {
assert(dev->IsMainThread()); assert(dev->IsMainThread());
dev_ = dev; dev_ = dev;
endpoint_.sin.sin_family = 0; endpoint_.sin.sin_family = 0;
data_endpoint_.sin.sin_family = 0;
endpoint_protocol_ = 0; endpoint_protocol_ = 0;
data_endpoint_protocol_ = 0;
next_peer_ = NULL; next_peer_ = NULL;
peer_extra_data_ = NULL; peer_extra_data_ = NULL;
curr_keypair_ = next_keypair_ = prev_keypair_ = NULL; curr_keypair_ = next_keypair_ = prev_keypair_ = NULL;
@ -685,6 +687,11 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) {
WG_ACQUIRE_LOCK(peer->mutex_); WG_ACQUIRE_LOCK(peer->mutex_);
peer->rx_bytes_ += packet->size; peer->rx_bytes_ += packet->size;
if (keypair != NULL) { if (keypair != NULL) {
// The server side needs to remember the endpoint on incoming handshakes.
if (peer->allow_endpoint_change_ && keypair->enabled_features[WG_FEATURE_HYBRID_TCP]) {
peer->endpoint_ = packet->addr;
peer->endpoint_protocol_ = packet->protocol;
}
peer->InsertKeypairInPeer_Locked(keypair); peer->InsertKeypairInPeer_Locked(keypair);
peer->OnHandshakeAuthComplete(); peer->OnHandshakeAuthComplete();
} }
@ -772,8 +779,19 @@ WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packe
WG_ACQUIRE_LOCK(peer->mutex_); WG_ACQUIRE_LOCK(peer->mutex_);
if (peer->allow_endpoint_change_) { if (peer->allow_endpoint_change_) {
peer->endpoint_ = packet->addr; // TODO: Why is this needed, if we are able to get a response for the handshake init
// packet then we already know its endpoint?
peer->endpoint_protocol_ = packet->protocol; peer->endpoint_protocol_ = packet->protocol;
peer->endpoint_ = packet->addr;
if (!keypair->enabled_features[WG_FEATURE_HYBRID_TCP] || !peer->IsTransientDataEndpointActive()) {
peer->data_endpoint_protocol_ = peer->endpoint_protocol_;
peer->data_endpoint_ = peer->endpoint_;
}
// If hybrid tcp mode was enabled for the connection, switch
// the data endpoint to the udp endpoint.
} else if (peer->endpoint_protocol_ == kPacketProtocolTcp) {
peer->data_endpoint_protocol_ = keypair->enabled_features[WG_FEATURE_HYBRID_TCP] ? kPacketProtocolUdp : kPacketProtocolTcp;
peer->data_endpoint_ = peer->endpoint_;
} }
WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet); WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet);
@ -1065,11 +1083,15 @@ enum {
TIMER_ZERO_KEYS = 3, TIMER_ZERO_KEYS = 3,
// Timer for sending a keepalive packet every PERSISTENT_KEEPALIVE_MS // Timer for sending a keepalive packet every PERSISTENT_KEEPALIVE_MS
TIMER_PERSISTENT_KEEPALIVE = 4, TIMER_PERSISTENT_KEEPALIVE = 4,
// Timer for removing the transient UDP endpoint in hybrid TCP mode after 10 seconds
TIMER_HYBRID_TCP = 5,
TIMERS_COUNT = 6,
}; };
#define WgClearTimer(x) (timers_ &= ~(33 << x)) #define WgClearTimer(x) (timers_ &= ~(((1<<TIMERS_COUNT)+1) << x))
#define WgIsTimerActive(x) (timers_ & (33 << x)) #define WgIsTimerActive(x) (timers_ & (((1<<TIMERS_COUNT)+1) << x))
#define WgSetTimer(x) (timers_ |= (32 << (x))) #define WgSetTimer(x) (timers_ |= (((1<<TIMERS_COUNT)) << (x)))
void WgPeer::OnDataSent() { void WgPeer::OnDataSent() {
assert(IsPeerLocked()); assert(IsPeerLocked());
@ -1092,12 +1114,14 @@ void WgPeer::OnDataReceived() {
else else
pending_keepalive_ = true; pending_keepalive_ = true;
WgSetTimer(TIMER_PERSISTENT_KEEPALIVE); WgSetTimer(TIMER_PERSISTENT_KEEPALIVE);
WgSetTimer(TIMER_HYBRID_TCP);
} }
void WgPeer::OnKeepaliveReceived() { void WgPeer::OnKeepaliveReceived() {
assert(IsPeerLocked()); assert(IsPeerLocked());
WgClearTimer(TIMER_NEW_HANDSHAKE); WgClearTimer(TIMER_NEW_HANDSHAKE);
WgSetTimer(TIMER_PERSISTENT_KEEPALIVE); WgSetTimer(TIMER_PERSISTENT_KEEPALIVE);
WgSetTimer(TIMER_HYBRID_TCP);
} }
void WgPeer::OnHandshakeInitSent() { void WgPeer::OnHandshakeInitSent() {
@ -1158,17 +1182,18 @@ uint32 WgPeer::CheckTimeouts_Locked(uint64 now) {
return 0; return 0;
uint32 now32 = (uint32)now; uint32 now32 = (uint32)now;
// Got any new timers? // Got any new timers?
if (t & (0x1f << 5)) { if (t & (((1 << TIMERS_COUNT) - 1) << TIMERS_COUNT)) {
if (t & (1 << (5+0))) timer_value_[0] = now32; if (t & (1 << (TIMERS_COUNT+0))) timer_value_[0] = now32;
if (t & (1 << (5+1))) timer_value_[1] = now32; if (t & (1 << (TIMERS_COUNT+1))) timer_value_[1] = now32;
if (t & (1 << (5+2))) timer_value_[2] = now32; if (t & (1 << (TIMERS_COUNT+2))) timer_value_[2] = now32;
if (t & (1 << (5+3))) timer_value_[3] = now32; if (t & (1 << (TIMERS_COUNT+3))) timer_value_[3] = now32;
if (t & (1 << (5+4))) timer_value_[4] = now32; if (t & (1 << (TIMERS_COUNT+4))) timer_value_[4] = now32;
t |= (t >> 5); if (t & (1 << (TIMERS_COUNT+5))) timer_value_[5] = now32;
t &= 0x1F; t |= (t >> TIMERS_COUNT);
t &= (1 << TIMERS_COUNT) - 1;
} }
// Got any expired timers? // Got any expired timers?
if (t & 0x1F) { if (t & ((1 << TIMERS_COUNT) - 1)) {
if ((t & (1 << TIMER_RETRANSMIT_HANDSHAKE)) && (now32 - timer_value_[TIMER_RETRANSMIT_HANDSHAKE]) >= REKEY_TIMEOUT_MS) { if ((t & (1 << TIMER_RETRANSMIT_HANDSHAKE)) && (now32 - timer_value_[TIMER_RETRANSMIT_HANDSHAKE]) >= REKEY_TIMEOUT_MS) {
t ^= (1 << TIMER_RETRANSMIT_HANDSHAKE); t ^= (1 << TIMER_RETRANSMIT_HANDSHAKE);
if (handshake_attempts_ > MAX_HANDSHAKE_ATTEMPTS || endpoint_.sin.sin_family == 0) { if (handshake_attempts_ > MAX_HANDSHAKE_ATTEMPTS || endpoint_.sin.sin_family == 0) {
@ -1212,6 +1237,16 @@ uint32 WgPeer::CheckTimeouts_Locked(uint64 now) {
ClearKeys_Locked(); ClearKeys_Locked();
ClearHandshake_Locked(); ClearHandshake_Locked();
} }
if ((t & (1 << TIMER_HYBRID_TCP)) && (now32 - timer_value_[TIMER_HYBRID_TCP]) >= HYBRID_TCP_TIMEOUT_MS) {
t &= ~(1 << TIMER_HYBRID_TCP);
// Forget about the data endpoint and switch to using the regular endpoint after 15 seconds.
if (allow_endpoint_change_) {
data_endpoint_protocol_ = endpoint_protocol_;
data_endpoint_ = endpoint_;
}
}
} }
timers_ = t; timers_ = t;
return rv; return rv;
@ -1261,9 +1296,15 @@ void WgPeer::CheckAndUpdateTimeOfNextKeyEvent(uint64 now) {
time_of_next_key_event_ = next_time; time_of_next_key_event_ = next_time;
} }
bool WgPeer::IsTransientDataEndpointActive() {
return WgIsTimerActive(TIMER_HYBRID_TCP) != 0;
}
void WgPeer::SetEndpoint(int endpoint_proto, const IpAddr &sin) { void WgPeer::SetEndpoint(int endpoint_proto, const IpAddr &sin) {
endpoint_protocol_ = endpoint_proto; endpoint_protocol_ = endpoint_proto;
data_endpoint_protocol_ = endpoint_proto;
endpoint_ = sin; endpoint_ = sin;
data_endpoint_ = sin;
} }
bool WgPeer::SetPersistentKeepalive(int persistent_keepalive_secs) { bool WgPeer::SetPersistentKeepalive(int persistent_keepalive_secs) {
@ -1484,18 +1525,19 @@ size_t WgPublicKeyHasher::operator()(const WgPublicKey&a) const {
// This scrambles the initial 16 bytes of the packet with the // This scrambles the initial 16 bytes of the packet with the
// last 8 bytes of the packet as a seed. // last 8 bytes of the packet as a seed.
void WgPacketObfuscator::ScrambleUnscramble(uint8 *data, size_t data_size) { void WgPacketObfuscator::ScrambleUnscramble(uint8 *data, size_t data_size) {
assert(data_size >= 16);
uint64 last_uint64 = ReadLE64(data + data_size - 8); uint64 last_uint64 = ReadLE64(data + data_size - 8);
uint64 a = siphash_u64_u32(last_uint64, (uint32)data_size, (siphash_key_t*)&key_[0]); uint64 a = siphash_u64_u32(last_uint64, (uint32)data_size, (siphash_key_t*)&key_[0]);
uint64 b = siphash_u64_u32(last_uint64, (uint32)data_size, (siphash_key_t*)&key_[2]); uint64 b = siphash_u64_u32(last_uint64, (uint32)data_size, (siphash_key_t*)&key_[2]);
a = ToLE64(a); ((uint64*)data)[0] ^= ToLE64(a);
b = ToLE64(b); b = ToLE64(b);
if (data_size >= 24) { if (data_size >= 24) {
((uint64*)data)[0] ^= a;
((uint64*)data)[1] ^= b; ((uint64*)data)[1] ^= b;
} else { } else {
uint64 d[2] = { a, b }; uint64 d[1] = { b };
for (size_t i = 0; i < data_size - 8; i++) for (size_t i = 0; i < data_size - 16; i++)
data[i] ^= ((uint8*)d)[i]; data[i + 8] ^= ((uint8*)d)[i];
} }
} }
@ -1524,12 +1566,11 @@ void WgPacketObfuscator::ObfuscatePacket(Packet *packet) {
// in the 3:rd byte of the packet. // in the 3:rd byte of the packet.
uint32 packet_type = ReadLE32(data); uint32 packet_type = ReadLE32(data);
if ((packet_type == 4 && data_size <= 32) || packet_type < 4) { if ((packet_type == 4 && data_size <= 32) || packet_type < 4) {
if (packet_type != 4) { // The 39:th (for handshake init) and 43:rd byte (for handshake response)
// The 39:th and 43:rd bytes often have zero MSB because of curve25519 pubkey, // have zero MSB because of curve25519 pubkey, so xor it with random.
// so xor them with something in the header. if (packet_type < 4) {
assert(data_size >= 44); assert(data_size >= 48);
data[39] ^= data[12]; data[35 + packet_type * 4] ^= data[15];
data[43] ^= data[12];
} }
packet->size = data_size = InsertRandomBytesIntoPacket(data, data_size); packet->size = data_size = InsertRandomBytesIntoPacket(data, data_size);
} }
@ -1552,16 +1593,14 @@ void WgPacketObfuscator::DeobfuscatePacket(Packet *packet) {
// Check whether the packet type field says that we have // Check whether the packet type field says that we have
// extra bytes appended at the end. // extra bytes appended at the end.
if (data[0] <= 4) { if (data[0] <= 4) {
if (data[0] < 4 && data_size >= 44) { if (data[3] > data_size)
// The 39:th and 43:rd bytes often have zero MSB because of curve25519 pubkey, return; // invalid
// so xor them with something in the header. packet->size = (uint32)(data_size -= data[3]);
data[39] ^= data[12];
data[43] ^= data[12];
}
if (data[3] <= data_size) {
packet->size = (uint32)(data_size - data[3]);
data[3] = 0; data[3] = 0;
} // The 39:th (for handshake init) and 43:rd byte (for handshake response)
// have zero MSB because of curve25519 pubkey, so xor it with random.
if (data[0] < 4 && data_size >= 48)
data[35 + data[0] * 4] ^= data[15];
} }
} }

View file

@ -62,7 +62,12 @@ enum ProtocolTimeouts {
REJECT_AFTER_TIME_MS = 180000, REJECT_AFTER_TIME_MS = 180000,
MIN_HANDSHAKE_INTERVAL_MS = 20, MIN_HANDSHAKE_INTERVAL_MS = 20,
MAX_SIZE_OF_HANDSHAKE_EXTENSION = 1024, HYBRID_TCP_TIMEOUT_MS = 15000,
// Chosen so that 1500 - 28 - sizeof(handshakeresponse) which means
// we can use this to probe mtu.
MAX_SIZE_OF_HANDSHAKE_EXTENSION = 1380,
}; };
enum ProtocolLimits { enum ProtocolLimits {
@ -179,12 +184,13 @@ enum {
}; };
enum { enum {
WG_FEATURES_COUNT = 6, WG_FEATURES_COUNT = 7,
WG_FEATURE_ID_SHORT_HEADER = 0, // Supports short headers WG_FEATURE_ID_SHORT_HEADER = 0, // Supports short headers
WG_FEATURE_ID_SHORT_MAC = 1, // Supports 8-byte MAC WG_FEATURE_ID_SHORT_MAC = 1, // Supports 8-byte MAC
WG_FEATURE_ID_IPZIP = 2, // Using ipzip WG_FEATURE_ID_IPZIP = 2, // Using ipzip
WG_FEATURE_ID_SKIP_KEYID_IN = 4, // Skip keyid for incoming packets WG_FEATURE_ID_SKIP_KEYID_IN = 4, // Skip keyid for incoming packets
WG_FEATURE_ID_SKIP_KEYID_OUT = 5, // Skip keyid for outgoing packets WG_FEATURE_ID_SKIP_KEYID_OUT = 5, // Skip keyid for outgoing packets
WG_FEATURE_HYBRID_TCP = 6, // Use hybrid-tcp mode
}; };
enum { enum {
@ -340,7 +346,7 @@ public:
// including adding random bytes at the end of the non-data packets. // including adding random bytes at the end of the non-data packets.
class WgPacketObfuscator { class WgPacketObfuscator {
public: public:
WgPacketObfuscator() : enabled_(false) {} WgPacketObfuscator() : enabled_(false), obfuscate_tcp_(-1) {}
bool enabled() { return enabled_; } bool enabled() { return enabled_; }
void ObfuscatePacket(Packet *packet); void ObfuscatePacket(Packet *packet);
@ -350,6 +356,9 @@ public:
const uint8 *key() { return (uint8*)key_; } const uint8 *key() { return (uint8*)key_; }
int obfuscate_tcp() { return obfuscate_tcp_; }
void set_obfuscate_tcp(int v) { obfuscate_tcp_ = v; }
static size_t InsertRandomBytesIntoPacket(uint8 *data, size_t data_size); static size_t InsertRandomBytesIntoPacket(uint8 *data, size_t data_size);
private: private:
@ -358,6 +367,9 @@ private:
// Whether packet obfuscation is enabled // Whether packet obfuscation is enabled
bool enabled_; bool enabled_;
// Type of obfuscation for tcp
int obfuscate_tcp_;
// Siphash keys for packet scrambling // Siphash keys for packet scrambling
uint64 key_[4]; uint64 key_[4];
}; };
@ -395,6 +407,8 @@ public:
WgRateLimit *rate_limiter() { return &rate_limiter_; } WgRateLimit *rate_limiter() { return &rate_limiter_; }
bool is_private_key_initialized() { return is_private_key_initialized_; } bool is_private_key_initialized() { return is_private_key_initialized_; }
void SetCurrentThreadAsMainThread() { main_thread_id_ = GetCurrentThreadId(); }
bool IsMainThread() { return CurrentThreadIdEquals(main_thread_id_); } bool IsMainThread() { return CurrentThreadIdEquals(main_thread_id_); }
bool IsMainOrDataThread() { return CurrentThreadIdEquals(main_thread_id_) || WG_IF_LOCKS_ENABLED_ELSE(delayed_delete_.enabled(), false); } bool IsMainOrDataThread() { return CurrentThreadIdEquals(main_thread_id_) || WG_IF_LOCKS_ENABLED_ELSE(delayed_delete_.enabled(), false); }
@ -565,7 +579,7 @@ private:
void ClearHandshake_Locked(); void ClearHandshake_Locked();
void ClearPacketQueue_Locked(); void ClearPacketQueue_Locked();
void ScheduleNewHandshake(); void ScheduleNewHandshake();
bool IsTransientDataEndpointActive();
WgDevice *dev_; WgDevice *dev_;
WgPeer *next_peer_; WgPeer *next_peer_;
@ -582,7 +596,6 @@ private:
// For timer management // For timer management
uint32 timers_; uint32 timers_;
uint32 timer_value_[5];
// Holds the entry into the key id table during handshake - mt only. // Holds the entry into the key id table during handshake - mt only.
uint32 local_key_id_during_hs_; uint32 local_key_id_during_hs_;
@ -623,7 +636,7 @@ private:
uint8 handshake_attempts_; uint8 handshake_attempts_;
// What's the protocol of the currently configured endpoint // What's the protocol of the currently configured endpoint
uint8 endpoint_protocol_; uint8 endpoint_protocol_, data_endpoint_protocol_;
// Which features are enabled for this peer? // Which features are enabled for this peer?
uint8 features_[WG_FEATURES_COUNT]; uint8 features_[WG_FEATURES_COUNT];
@ -632,9 +645,6 @@ private:
uint8 num_queued_packets_; uint8 num_queued_packets_;
Packet *first_queued_packet_, **last_queued_packet_ptr_; Packet *first_queued_packet_, **last_queued_packet_ptr_;
// Address of peer
IpAddr endpoint_;
// For statistics // For statistics
uint64 last_handshake_init_timestamp_; uint64 last_handshake_init_timestamp_;
uint64 last_complete_handskake_timestamp_; uint64 last_complete_handskake_timestamp_;
@ -642,6 +652,13 @@ private:
// Timestamp to detect flooding of handshakes // Timestamp to detect flooding of handshakes
uint64 last_handshake_init_recv_timestamp_; // main thread only uint64 last_handshake_init_recv_timestamp_; // main thread only
// Address of peer
IpAddr endpoint_;
// Alternative endpoint. This is used in hybrid tcp mode to hold the
// udp endpoint.
IpAddr data_endpoint_;
// Number of handshake attempts since last successful handshake // Number of handshake attempts since last successful handshake
uint32 total_handshake_attempts_; uint32 total_handshake_attempts_;
@ -653,6 +670,8 @@ private:
uint32 keepalive_timeout_ms_; // Set to KEEPALIVE_TIMEOUT_MS uint32 keepalive_timeout_ms_; // Set to KEEPALIVE_TIMEOUT_MS
uint32 timer_value_[6];
uint64 rx_bytes_; uint64 rx_bytes_;
uint64 tx_bytes_; uint64 tx_bytes_;