Add more extension points in wireguard code
This commit is contained in:
parent
b1ffd5738e
commit
13158f9d90
6
netapi.h
6
netapi.h
|
@ -66,6 +66,12 @@ struct Packet : QueuedItem {
|
||||||
HEADROOM_BEFORE = 64,
|
HEADROOM_BEFORE = 64,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef PACKET_EXTENSION_FIELDS
|
||||||
|
PACKET_EXTENSION_FIELDS
|
||||||
|
#endif // PACKET_EXTENSION_FIELDS
|
||||||
|
|
||||||
|
|
||||||
byte data_pre[HEADROOM_BEFORE];
|
byte data_pre[HEADROOM_BEFORE];
|
||||||
byte data_buf[0];
|
byte data_buf[0];
|
||||||
|
|
||||||
|
|
|
@ -450,6 +450,8 @@ WireguardProcessor::PacketResult WireguardProcessor::WriteAndEncryptPacketToUdp_
|
||||||
packet->addr = peer->endpoint_;
|
packet->addr = peer->endpoint_;
|
||||||
packet->protocol = peer->endpoint_protocol_;
|
packet->protocol = peer->endpoint_protocol_;
|
||||||
|
|
||||||
|
WG_EXTENSION_HOOKS::OnPeerOutgoingUdp(peer, packet);
|
||||||
|
|
||||||
if (size == 0) {
|
if (size == 0) {
|
||||||
peer->OnKeepaliveSent();
|
peer->OnKeepaliveSent();
|
||||||
} else {
|
} else {
|
||||||
|
@ -663,6 +665,7 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) {
|
||||||
peer->OnHandshakeInitSent();
|
peer->OnHandshakeInitSent();
|
||||||
packet->addr = peer->endpoint_;
|
packet->addr = peer->endpoint_;
|
||||||
packet->protocol = peer->endpoint_protocol_;
|
packet->protocol = peer->endpoint_protocol_;
|
||||||
|
WG_EXTENSION_HOOKS::OnPeerOutgoingUdp(peer, packet);
|
||||||
peer->tx_bytes_ += packet->size;
|
peer->tx_bytes_ += packet->size;
|
||||||
|
|
||||||
// If this is an incoming oneway connection (such as tcp), forget the
|
// If this is an incoming oneway connection (such as tcp), forget the
|
||||||
|
@ -888,6 +891,8 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack
|
||||||
peer->endpoint_protocol_ = packet->protocol;
|
peer->endpoint_protocol_ = packet->protocol;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet);
|
||||||
|
|
||||||
// Remember how many incoming packets we've seen so we can approximate loss
|
// Remember how many incoming packets we've seen so we can approximate loss
|
||||||
keypair->incoming_packet_count++;
|
keypair->incoming_packet_count++;
|
||||||
|
|
||||||
|
@ -936,9 +941,13 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack
|
||||||
if (ip_version == 4) {
|
if (ip_version == 4) {
|
||||||
if (data_size < IPV4_HEADER_SIZE)
|
if (data_size < IPV4_HEADER_SIZE)
|
||||||
goto getout_error_header;
|
goto getout_error_header;
|
||||||
WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
|
if (!WG_EXTENSION_HOOKS::DisableSourceAddressVerification(peer)) {
|
||||||
peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV4(ReadBE32(data + 12));
|
WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
|
||||||
WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
|
peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV4(ReadBE32(data + 12));
|
||||||
|
WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
|
||||||
|
if (peer_from_header != peer)
|
||||||
|
goto getout_error_header;
|
||||||
|
}
|
||||||
size_from_header = ReadBE16(data + 2);
|
size_from_header = ReadBE16(data + 2);
|
||||||
if (size_from_header < IPV4_HEADER_SIZE) {
|
if (size_from_header < IPV4_HEADER_SIZE) {
|
||||||
// too small packet?
|
// too small packet?
|
||||||
|
@ -950,12 +959,14 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack
|
||||||
WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
|
WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
|
||||||
peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV6(data + 8);
|
peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV6(data + 8);
|
||||||
WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
|
WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
|
||||||
|
if (peer_from_header != peer)
|
||||||
|
goto getout_error_header;
|
||||||
size_from_header = IPV6_HEADER_SIZE + ReadBE16(data + 4);
|
size_from_header = IPV6_HEADER_SIZE + ReadBE16(data + 4);
|
||||||
} else {
|
} else {
|
||||||
// invalid ip version
|
// invalid ip version
|
||||||
goto getout_error_header;
|
goto getout_error_header;
|
||||||
}
|
}
|
||||||
if (peer_from_header != peer || size_from_header > data_size)
|
if (size_from_header > data_size)
|
||||||
goto getout_error_header;
|
goto getout_error_header;
|
||||||
|
|
||||||
packet->size = size_from_header;
|
packet->size = size_from_header;
|
||||||
|
|
|
@ -97,7 +97,9 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
|
||||||
|
|
||||||
if (strcmp(group, "[Interface]") == 0) {
|
if (strcmp(group, "[Interface]") == 0) {
|
||||||
if (key == NULL) return true;
|
if (key == NULL) return true;
|
||||||
if (strcmp(key, "PrivateKey") == 0) {
|
if (wg_->dev().plugin() && wg_->dev().plugin()->OnUnknownInterfaceSetting(key, value)) {
|
||||||
|
// nothing here
|
||||||
|
} else if (strcmp(key, "PrivateKey") == 0) {
|
||||||
if (!ParseBase64Key(value, binkey))
|
if (!ParseBase64Key(value, binkey))
|
||||||
return false;
|
return false;
|
||||||
had_interface_ = true;
|
had_interface_ = true;
|
||||||
|
@ -185,8 +187,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
|
||||||
wg_->AddExcludedIp(addr);
|
wg_->AddExcludedIp(addr);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!wg_->dev().plugin() || !wg_->dev().plugin()->OnUnknownInterfaceSetting(key, value))
|
goto err;
|
||||||
goto err;
|
|
||||||
}
|
}
|
||||||
} else if (strcmp(group, "[Peer]") == 0) {
|
} else if (strcmp(group, "[Peer]") == 0) {
|
||||||
if (key == NULL) {
|
if (key == NULL) {
|
||||||
|
@ -199,7 +200,10 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
|
||||||
memset(&pi_, 0, sizeof(pi_));
|
memset(&pi_, 0, sizeof(pi_));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (strcmp(key, "PublicKey") == 0) {
|
|
||||||
|
if (wg_->dev().plugin() && wg_->dev().plugin()->OnUnknownPeerSetting(peer_, key, value)) {
|
||||||
|
// nothing here
|
||||||
|
} else if (strcmp(key, "PublicKey") == 0) {
|
||||||
if (!ParseBase64Key(value, pi_.pub.bytes))
|
if (!ParseBase64Key(value, pi_.pub.bytes))
|
||||||
return false;
|
return false;
|
||||||
} else if (strcmp(key, "PresharedKey") == 0) {
|
} else if (strcmp(key, "PresharedKey") == 0) {
|
||||||
|
@ -252,8 +256,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!wg_->dev().plugin() || !wg_->dev().plugin()->OnUnknownPeerSetting(peer_, key, value))
|
goto err;
|
||||||
goto err;
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
err:
|
err:
|
||||||
|
@ -331,6 +334,11 @@ bool ParseWireGuardConfigString(WireguardProcessor *wg, const char *bufin, size_
|
||||||
buf = nl + 1;
|
buf = nl + 1;
|
||||||
}
|
}
|
||||||
file_parser.FinishGroup();
|
file_parser.FinishGroup();
|
||||||
|
|
||||||
|
// Let plugin do any final processing
|
||||||
|
if (wg->dev().plugin() && !wg->dev().plugin()->OnAfterSettingsParsed())
|
||||||
|
return false;
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -787,6 +787,9 @@ WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packe
|
||||||
peer->endpoint_ = packet->addr;
|
peer->endpoint_ = packet->addr;
|
||||||
peer->endpoint_protocol_ = packet->protocol;
|
peer->endpoint_protocol_ = packet->protocol;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet);
|
||||||
|
|
||||||
peer->rx_bytes_ += packet->size;
|
peer->rx_bytes_ += packet->size;
|
||||||
peer->InsertKeypairInPeer_Locked(keypair);
|
peer->InsertKeypairInPeer_Locked(keypair);
|
||||||
WG_RELEASE_LOCK(peer->mutex_);
|
WG_RELEASE_LOCK(peer->mutex_);
|
||||||
|
|
|
@ -318,22 +318,26 @@ public:
|
||||||
virtual bool OnUnknownInterfaceSetting(const char *key, const char *value) = 0;
|
virtual bool OnUnknownInterfaceSetting(const char *key, const char *value) = 0;
|
||||||
virtual bool OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) = 0;
|
virtual bool OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) = 0;
|
||||||
|
|
||||||
|
// Called after settings have been completely parsed, the plugin may modify the state
|
||||||
|
virtual bool OnAfterSettingsParsed() = 0;
|
||||||
|
|
||||||
// Returns true if we want to perform a handshake for this peer.
|
// Returns true if we want to perform a handshake for this peer.
|
||||||
virtual bool WantHandshake(WgPeer *peer) = 0;
|
virtual bool WantHandshake(WgPeer *peer) = 0;
|
||||||
|
|
||||||
// Called right before handshake initiation is sent out. Can't drop packets.
|
|
||||||
virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0;
|
|
||||||
// Called after handshake initiation is parsed, but before handshake response is sent.
|
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
kHandshakeResponseDrop = 0xffffffff,
|
kHandshakeResponseDrop = 0xffffffff,
|
||||||
kHandshakeResponseFail = 0x80000000
|
kHandshakeResponseFail = 0x80000000
|
||||||
};
|
};
|
||||||
|
// Called before handshake initiation is sent out. Can write extra headers. Can't drop packets.
|
||||||
|
virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0;
|
||||||
|
// Called after handshake initiation is parsed, but before handshake response is sent.
|
||||||
// Packet can be dropped or keypair failed.
|
// Packet can be dropped or keypair failed.
|
||||||
virtual uint32 OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt_in[WG_PUBLIC_KEY_LEN], uint8 *extout, uint32 extout_size, const uint8 salt_out[WG_PUBLIC_KEY_LEN]) = 0;
|
virtual uint32 OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt_in[WG_PUBLIC_KEY_LEN], uint8 *extout, uint32 extout_size, const uint8 salt_out[WG_PUBLIC_KEY_LEN]) = 0;
|
||||||
|
|
||||||
// Called when handshake response is parsed
|
// Called when handshake response is parsed
|
||||||
virtual uint32 OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0;
|
virtual uint32 OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0;
|
||||||
|
|
||||||
|
// Called right before an outgoing non-data packet is sent out, but before it's scrambled.
|
||||||
|
virtual void OnOutgoingHandshakePacket(WgPeer *peer, Packet *packet) = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -378,7 +382,9 @@ public:
|
||||||
|
|
||||||
void SetPlugin(WgPlugin *del) { plugin_ = del; }
|
void SetPlugin(WgPlugin *del) { plugin_ = del; }
|
||||||
WgPlugin *plugin() { return plugin_; }
|
WgPlugin *plugin() { return plugin_; }
|
||||||
|
|
||||||
|
MultithreadedDelayedDelete *GetDelayedDelete() { return &delayed_delete_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::pair<WgPeer*, WgKeypair*> *LookupPeerInKeyIdLookup(uint32 key_id);
|
std::pair<WgPeer*, WgKeypair*> *LookupPeerInKeyIdLookup(uint32 key_id);
|
||||||
WgKeypair *LookupKeypairByKeyId(uint32 key_id);
|
WgKeypair *LookupKeypairByKeyId(uint32 key_id);
|
||||||
|
@ -773,6 +779,9 @@ bool WgKeypairDecryptPayload(uint8 *dst, const size_t src_len,
|
||||||
|
|
||||||
struct WgExtensionHooksDefault {
|
struct WgExtensionHooksDefault {
|
||||||
static uint32 GetIpv4Target(Packet *packet, uint8 *data) { return ReadBE32(data + 16); }
|
static uint32 GetIpv4Target(Packet *packet, uint8 *data) { return ReadBE32(data + 16); }
|
||||||
|
static void OnPeerIncomingUdp(WgPeer *peer, const Packet *packet) { }
|
||||||
|
static void OnPeerOutgoingUdp(WgPeer *peer, Packet *packet) { }
|
||||||
|
static bool DisableSourceAddressVerification(WgPeer *peer) { return false; }
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifndef WG_EXTENSION_HOOKS
|
#ifndef WG_EXTENSION_HOOKS
|
||||||
|
|
Loading…
Reference in a new issue