344 lines
9.8 KiB
C++
344 lines
9.8 KiB
C++
// SPDX-License-Identifier: AGPL-1.0-only
|
|
// Copyright (C) 2018 Ludvig Strigeus <info@tunsafe.com>. All Rights Reserved.
|
|
#include <stdafx.h>
|
|
#include "network_win32_tcp.h"
|
|
#include "network_win32.h"
|
|
#include <Mswsock.h>
|
|
#include <ws2ipdef.h>
|
|
#include "util.h"
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Creates an idle TCP socket object (no OS socket yet) bound to |network| and
// registers it at the head of the network's linked list of TCP sockets.
TcpSocketWin32::TcpSocketWin32(NetworkWin32 *network)
    : tcp_packet_handler_(&network->packet_pool()) {
  // Nothing is connected and no I/O is in flight yet.
  socket_ = INVALID_SOCKET;
  state_ = STATE_NONE;
  reads_active_ = 0;
  writes_active_ = 0;
  handshake_attempts = 0;
  packet_processor_ = NULL;

  // Empty outgoing queue: the tail pointer refers to the head slot itself.
  wqueue_ = NULL;
  wqueue_end_ = &wqueue_;

  // Push this socket onto the front of the owning network's socket list.
  network_ = network;
  next_ = network->tcp_socket_;
  network->tcp_socket_ = this;
}
|
|
|
|
// Removes this socket from the network's list, frees any packets still waiting
// in the write queue and closes the OS socket if one was created.
TcpSocketWin32::~TcpSocketWin32() {
  // Find the link (head pointer or some node's next_) that points at this
  // object and splice this object out of the list.
  TcpSocketWin32 **link = &network_->tcp_socket_;
  for (; *link != this; link = &(*link)->next_) {
  }
  *link = next_;

  FreePacketList(wqueue_);

  if (socket_ != INVALID_SOCKET)
    closesocket(socket_);
}
|
|
|
|
void TcpSocketWin32::CloseSocket() {
|
|
if (socket_ != INVALID_SOCKET)
|
|
CancelIo((HANDLE)socket_);
|
|
state_ = STATE_ERROR;
|
|
endpoint_protocol_ = 0;
|
|
}
|
|
|
|
// Appends |packet| to the tail of this socket's outgoing queue. Nothing is
// sent here; DoMoreWrites() drains the queue.
void TcpSocketWin32::WritePacket(Packet *packet) {
  Packet **tail = wqueue_end_;
  packet->queue_next = NULL;
  *tail = packet;
  // The new tail slot is the next-pointer of the packet just appended.
  wqueue_end_ = &Packet_NEXT(packet);
}
|
|
|
|
void TcpSocketWin32::CancelAllIO() {
|
|
if (socket_ != INVALID_SOCKET)
|
|
CancelIo((HANDLE)socket_);
|
|
}
|
|
|
|
static const GUID WsaConnectExGUID = WSAID_CONNECTEX;
|
|
|
|
void TcpSocketWin32::DoConnect() {
|
|
LPFN_CONNECTEX ConnectEx;
|
|
assert(socket_ == INVALID_SOCKET);
|
|
|
|
socket_ = WSASocket(endpoint_.sin.sin_family, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
|
|
if (socket_ == INVALID_SOCKET) {
|
|
RERROR("socket() failed");
|
|
CloseSocket();
|
|
return;
|
|
}
|
|
|
|
if (!CreateIoCompletionPort((HANDLE)socket_, network_->completion_port_handle_, 0, 0)) {
|
|
RERROR("TcpSocketWin32::DoConnect CreateIoCompletionPort failed");
|
|
CloseSocket();
|
|
return;
|
|
}
|
|
|
|
int nodelay = 1;
|
|
setsockopt(socket_, IPPROTO_TCP, TCP_NODELAY, (char*)&nodelay, 1);
|
|
|
|
DWORD dwBytes = sizeof(ConnectEx);
|
|
DWORD rc = WSAIoctl(socket_, SIO_GET_EXTENSION_FUNCTION_POINTER, (uint8*)&WsaConnectExGUID, sizeof(WsaConnectExGUID), &ConnectEx, sizeof(ConnectEx), &dwBytes, NULL, NULL);
|
|
assert(rc == 0);
|
|
|
|
// ConnectEx requires the socket to be bound
|
|
sockaddr_in sin = {0};
|
|
sin.sin_family = AF_INET;
|
|
sin.sin_addr.s_addr = INADDR_ANY;
|
|
sin.sin_port = 0;
|
|
if (bind(socket_, (sockaddr*)&sin, sizeof(sin))) {
|
|
RERROR("TcpSocketWin32::DoConnect bind failed: %d", WSAGetLastError());
|
|
CloseSocket();
|
|
return;
|
|
}
|
|
|
|
char buf[kSizeOfAddress];
|
|
RINFO("Connecting to tcp://%s...", PrintIpAddr(endpoint_, buf));
|
|
|
|
state_ = STATE_CONNECTING;
|
|
ClearOverlapped(&connect_overlapped_.overlapped);
|
|
connect_overlapped_.queue_cb = this;
|
|
if (!ConnectEx(socket_, (const sockaddr*)&endpoint_.sin, sizeof(endpoint_.sin), NULL, 0, NULL, &connect_overlapped_.overlapped)) {
|
|
int err = WSAGetLastError();
|
|
if (err != ERROR_IO_PENDING) {
|
|
RERROR("ConnectEx failed: %d", err);
|
|
CloseSocket();
|
|
return;
|
|
}
|
|
}
|
|
reads_active_ = 1;
|
|
}
|
|
|
|
void TcpSocketWin32::DoMoreReads() {
|
|
assert(state_ != STATE_ERROR);
|
|
if (reads_active_ == 0) {
|
|
// Initiate a new read, we always read into 4 buffers.
|
|
Packet *p = network_->packet_pool().AllocPacketFromPool();
|
|
if (!p)
|
|
return;
|
|
|
|
ClearOverlapped(&p->overlapped);
|
|
p->userdata = 0;
|
|
p->queue_cb = this;
|
|
DWORD flags = 0;
|
|
WSABUF wsabuf = {(ULONG)kPacketCapacity, (char*)p->data};
|
|
if (WSARecv(socket_, &wsabuf, 1, NULL, &flags, &p->overlapped, NULL) != 0) {
|
|
DWORD err = WSAGetLastError();
|
|
if (err != ERROR_IO_PENDING) {
|
|
RERROR("TcpSocketWin32:WSARecv failed 0x%X", err);
|
|
FreePacket(p);
|
|
return;
|
|
}
|
|
}
|
|
reads_active_ = 1;
|
|
}
|
|
}
|
|
|
|
void TcpSocketWin32::DoMoreWrites() {
|
|
assert(state_ != STATE_ERROR);
|
|
if (writes_active_ == 0) {
|
|
WSABUF wsabuf[kMaxWsaBuf];
|
|
uint32 num_wsabuf = 0;
|
|
|
|
Packet *p = wqueue_;
|
|
if (p == NULL)
|
|
return;
|
|
|
|
do {
|
|
tcp_packet_handler_.AddHeaderToOutgoingPacket(p);
|
|
wsabuf[num_wsabuf].buf = (char*)p->data;
|
|
wsabuf[num_wsabuf].len = (ULONG)p->size;
|
|
packets_in_write_io_[num_wsabuf] = p;
|
|
p = Packet_NEXT(p);
|
|
} while (++num_wsabuf < kMaxWsaBuf && p != NULL);
|
|
if (!(wqueue_ = p))
|
|
wqueue_end_ = &wqueue_;
|
|
num_wsabuf_ = (uint8)num_wsabuf;
|
|
|
|
p = packets_in_write_io_[0];
|
|
ClearOverlapped(&p->overlapped);
|
|
p->userdata = 1;
|
|
p->queue_cb = this;
|
|
|
|
if (WSASend(socket_, wsabuf, num_wsabuf, NULL, 0, &p->overlapped, NULL) != 0) {
|
|
DWORD err = WSAGetLastError();
|
|
if (err != ERROR_IO_PENDING) {
|
|
RERROR("TcpSocketWin32: WSASend failed 0x%X", err);
|
|
FreePacket(p);
|
|
CloseSocket();
|
|
return;
|
|
}
|
|
}
|
|
writes_active_ = 1;
|
|
}
|
|
}
|
|
|
|
void TcpSocketWin32::DoIO() {
|
|
if (state_ == STATE_CONNECTED) {
|
|
DoMoreReads();
|
|
while (Packet *p = tcp_packet_handler_.GetNextWireguardPacket()) {
|
|
p->protocol = endpoint_protocol_;
|
|
p->addr = endpoint_;
|
|
|
|
p->queue_cb = packet_processor_->udp_queue();
|
|
packet_processor_->ForcePost(p);
|
|
}
|
|
if (tcp_packet_handler_.error()) {
|
|
CloseSocket();
|
|
DoIO();
|
|
return;
|
|
}
|
|
DoMoreWrites();
|
|
} else if (state_ == STATE_WANT_CONNECT) {
|
|
DoConnect();
|
|
} else if (state_ == STATE_ERROR && !HasOutstandingIO()) {
|
|
delete this;
|
|
}
|
|
}
|
|
|
|
bool TcpSocketWin32::HasOutstandingIO() {
|
|
return writes_active_ + reads_active_ != 0;
|
|
}
|
|
|
|
// Completion callback for the overlapped operations issued by this socket.
//
// |qi| is either connect_overlapped_ (ConnectEx completion) or a Packet whose
// userdata is 0 for a read and non-zero for a write. |extra| is unused.
// overlapped.Internal is treated as the completion status (0 == success) and
// overlapped.InternalHigh as the number of bytes transferred.
void TcpSocketWin32::OnQueuedItemEvent(QueuedItem *qi, uintptr_t extra) {
  if (qi == &connect_overlapped_) {
    // CloseSocket() may have run while the connect was still in flight, so
    // STATE_ERROR is a legitimate state here in addition to STATE_CONNECTING.
    assert(state_ == STATE_CONNECTING || state_ == STATE_ERROR);
    reads_active_ = 0;

    // A socket that was already closed must stay closed, even if the connect
    // happened to succeed before the cancellation took effect.
    if (state_ == STATE_ERROR)
      return;

    const DWORD connect_status = (DWORD)qi->overlapped.Internal;
    if (connect_status != 0) {
      RERROR("TcpSocketWin32::Connect error 0x%X", connect_status);
      CloseSocket();
    } else {
      state_ = STATE_CONNECTED;
    }
    return;
  }

  Packet *p = static_cast<Packet*>(qi);
  const DWORD status = (DWORD)p->overlapped.Internal;

  if (p->userdata == 0) {
    // Read operation complete
    if (status != 0) {
      // Only report the first failure; later ones are fallout of the close.
      if (state_ != STATE_ERROR) {
        RERROR("TcpSocketWin32::Read error 0x%X", status);
        CloseSocket();
      }
      network_->packet_pool().FreePacketToPool(p);
    } else if ((int)p->overlapped.InternalHigh == 0) {
      // A successful zero-byte read is handled as the connection being closed.
      CloseSocket();
      network_->packet_pool().FreePacketToPool(p);
    } else {
      // Hand the received bytes to the stream parser, which now owns |p|.
      p->size = (int)p->overlapped.InternalHigh;
      tcp_packet_handler_.QueueIncomingPacket(p);
    }
    reads_active_--;
  } else {
    // Write operation complete
    assert(writes_active_);
    assert(packets_in_write_io_[0] == p);

    if (status != 0) {
      if (state_ != STATE_ERROR) {
        RERROR("TcpSocketWin32::Write error 0x%X", status);
        CloseSocket();
      }
    }
    // free all the packets involved in the write
    for (size_t i = 0; i < num_wsabuf_; i++)
      network_->packet_pool().FreePacketToPool(packets_in_write_io_[i]);
    writes_active_--;
  }
}
|
|
|
|
// Called when a queued item owned by this socket is discarded instead of
// being delivered: releases the packets involved and updates the in-flight
// counters, without touching the socket state.
void TcpSocketWin32::OnQueuedItemDelete(QueuedItem *qi) {
  // The connect operation has no packet attached; just clear its marker.
  if (qi == &connect_overlapped_) {
    reads_active_ = 0;
    return;
  }

  Packet *packet = static_cast<Packet*>(qi);
  if (packet->userdata != 0) {
    // Write: give back every packet that was part of the batch.
    for (size_t idx = 0; idx < num_wsabuf_; ++idx)
      network_->packet_pool().FreePacketToPool(packets_in_write_io_[idx]);
    writes_active_--;
    return;
  }

  // Read: only the receive buffer itself needs to be released.
  FreePacket(packet);
  reads_active_--;
}
|
|
|
|
/////////////////////////////////////////////////////////////////////////
|
|
|
|
// Creates an empty outgoing TCP packet queue for |network|. No packet handler
// is attached yet.
TcpSocketQueue::TcpSocketQueue(NetworkWin32 *network) {
  network_ = network;
  packet_handler_ = NULL;

  // Completions for queued_item_ are routed back to this object.
  queued_item_.queue_cb = this;

  // Empty queue: the tail pointer refers to the head slot itself.
  wqueue_ = NULL;
  wqueue_end_ = &wqueue_;
}
|
|
|
|
// Frees any packets that are still waiting in the queue.
TcpSocketQueue::~TcpSocketQueue() {
  Packet *pending = wqueue_;
  FreePacketList(pending);
}
|
|
|
|
// Routes one outgoing packet to a TCP socket, taking ownership of |packet|.
//
// If a socket with the same endpoint address and protocol exists, the packet
// is queued on it. Handshake initiations are counted per socket: when one
// arrives for a socket that already has handshake_attempts == 2, that socket
// is closed and a replacement is created. Any other packet type sets the
// counter to -1.
//
// Without a usable socket, a new outgoing connection is created only for a
// handshake initiation that is not marked as belonging to an incoming
// connection; every other packet is dropped.
void TcpSocketQueue::TransmitOnePacket(Packet *packet) {
  // Classify the packet once. The size check keeps the message type from
  // being read out of a packet that is too short to contain one.
  const bool is_handshake_initiation =
      packet->size >= 4 && ReadLE32(packet->data) == MESSAGE_HANDSHAKE_INITIATION;

  // Check if we have a tcp connection for the endpoint, otherwise create one.
  for (TcpSocketWin32 *tcp = network_->tcp_socket_; tcp; tcp = tcp->next_) {
    if (CompareIpAddr(&tcp->endpoint_, &packet->addr) != 0 || tcp->endpoint_protocol_ != packet->protocol)
      continue;

    if (is_handshake_initiation) {
      // After we send 3 handshakes on a tcp socket in a row, then close and
      // reopen the socket because it seems defunct.
      if (tcp->handshake_attempts == 2) {
        RINFO("Making new Tcp socket due to too many handshake failures");
        tcp->CloseSocket();
        break;
      }
      tcp->handshake_attempts++;
    } else {
      tcp->handshake_attempts = -1;
    }
    tcp->WritePacket(packet);
    return;
  }

  // Drop tcp packet that's for an incoming connection, or packets that are
  // not a handshake.
  if ((packet->protocol & kPacketProtocolIncomingConnection) || !is_handshake_initiation) {
    FreePacket(packet);
    return;
  }

  // Initialize a new tcp socket and connect to the endpoint
  TcpSocketWin32 *tcp = new TcpSocketWin32(network_);
  tcp->state_ = TcpSocketWin32::STATE_WANT_CONNECT;
  tcp->endpoint_ = packet->addr;
  tcp->endpoint_protocol_ = kPacketProtocolTcp;
  tcp->SetPacketHandler(packet_handler_);
  tcp->WritePacket(packet);
}
|
|
|
|
// Detaches the whole pending list under the mutex, then transmits the packets
// one by one without holding it.
void TcpSocketQueue::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) {
  // Take the entire list and leave an empty queue behind.
  wqueue_mutex_.Acquire();
  Packet *current = wqueue_;
  wqueue_ = NULL;
  wqueue_end_ = &wqueue_;
  wqueue_mutex_.Release();

  while (current) {
    // Read the successor first; transmitting re-links or frees the packet.
    Packet *following = Packet_NEXT(current);
    TransmitOnePacket(current);
    current = following;
  }
}
|
|
|
|
// Intentionally does nothing when the queued item is discarded.
void TcpSocketQueue::OnQueuedItemDelete(QueuedItem *ow) {
  (void)ow;
}
|
|
|
|
// Appends |packet| to the pending list under the mutex. If the list was empty
// beforehand, queued_item_ is posted to the network; OnQueuedItemEvent() is
// the callback that drains the list.
void TcpSocketQueue::WritePacket(Packet *packet) {
  packet->queue_next = NULL;

  wqueue_mutex_.Acquire();
  const bool queue_was_empty = (wqueue_ == NULL);
  *wqueue_end_ = packet;
  wqueue_end_ = &Packet_NEXT(packet);
  wqueue_mutex_.Release();

  // Post only on the empty -> non-empty transition, outside the lock.
  if (queue_was_empty)
    network_->PostQueuedItem(&queued_item_);
}
|