Add more extension points in wireguard code

This commit is contained in:
Ludvig Strigeus 2018-12-10 23:12:57 +01:00
parent b1ffd5738e
commit 13158f9d90
5 changed files with 53 additions and 16 deletions

View file

@ -66,6 +66,12 @@ struct Packet : QueuedItem {
HEADROOM_BEFORE = 64, HEADROOM_BEFORE = 64,
}; };
#ifdef PACKET_EXTENSION_FIELDS
PACKET_EXTENSION_FIELDS
#endif // PACKET_EXTENSION_FIELDS
byte data_pre[HEADROOM_BEFORE]; byte data_pre[HEADROOM_BEFORE];
byte data_buf[0]; byte data_buf[0];

View file

@ -450,6 +450,8 @@ WireguardProcessor::PacketResult WireguardProcessor::WriteAndEncryptPacketToUdp_
packet->addr = peer->endpoint_; packet->addr = peer->endpoint_;
packet->protocol = peer->endpoint_protocol_; packet->protocol = peer->endpoint_protocol_;
WG_EXTENSION_HOOKS::OnPeerOutgoingUdp(peer, packet);
if (size == 0) { if (size == 0) {
peer->OnKeepaliveSent(); peer->OnKeepaliveSent();
} else { } else {
@ -663,6 +665,7 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) {
peer->OnHandshakeInitSent(); peer->OnHandshakeInitSent();
packet->addr = peer->endpoint_; packet->addr = peer->endpoint_;
packet->protocol = peer->endpoint_protocol_; packet->protocol = peer->endpoint_protocol_;
WG_EXTENSION_HOOKS::OnPeerOutgoingUdp(peer, packet);
peer->tx_bytes_ += packet->size; peer->tx_bytes_ += packet->size;
// If this is an incoming oneway connection (such as tcp), forget the // If this is an incoming oneway connection (such as tcp), forget the
@ -888,6 +891,8 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack
peer->endpoint_protocol_ = packet->protocol; peer->endpoint_protocol_ = packet->protocol;
} }
WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet);
// Remember how many incoming packets we've seen so we can approximate loss // Remember how many incoming packets we've seen so we can approximate loss
keypair->incoming_packet_count++; keypair->incoming_packet_count++;
@ -936,9 +941,13 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack
if (ip_version == 4) { if (ip_version == 4) {
if (data_size < IPV4_HEADER_SIZE) if (data_size < IPV4_HEADER_SIZE)
goto getout_error_header; goto getout_error_header;
if (!WG_EXTENSION_HOOKS::DisableSourceAddressVerification(peer)) {
WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV4(ReadBE32(data + 12)); peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV4(ReadBE32(data + 12));
WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
if (peer_from_header != peer)
goto getout_error_header;
}
size_from_header = ReadBE16(data + 2); size_from_header = ReadBE16(data + 2);
if (size_from_header < IPV4_HEADER_SIZE) { if (size_from_header < IPV4_HEADER_SIZE) {
// too small packet? // too small packet?
@ -950,12 +959,14 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack
WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV6(data + 8); peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV6(data + 8);
WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
if (peer_from_header != peer)
goto getout_error_header;
size_from_header = IPV6_HEADER_SIZE + ReadBE16(data + 4); size_from_header = IPV6_HEADER_SIZE + ReadBE16(data + 4);
} else { } else {
// invalid ip version // invalid ip version
goto getout_error_header; goto getout_error_header;
} }
if (peer_from_header != peer || size_from_header > data_size) if (size_from_header > data_size)
goto getout_error_header; goto getout_error_header;
packet->size = size_from_header; packet->size = size_from_header;

View file

@ -97,7 +97,9 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
if (strcmp(group, "[Interface]") == 0) { if (strcmp(group, "[Interface]") == 0) {
if (key == NULL) return true; if (key == NULL) return true;
if (strcmp(key, "PrivateKey") == 0) { if (wg_->dev().plugin() && wg_->dev().plugin()->OnUnknownInterfaceSetting(key, value)) {
// nothing here
} else if (strcmp(key, "PrivateKey") == 0) {
if (!ParseBase64Key(value, binkey)) if (!ParseBase64Key(value, binkey))
return false; return false;
had_interface_ = true; had_interface_ = true;
@ -185,7 +187,6 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
wg_->AddExcludedIp(addr); wg_->AddExcludedIp(addr);
} }
} else { } else {
if (!wg_->dev().plugin() || !wg_->dev().plugin()->OnUnknownInterfaceSetting(key, value))
goto err; goto err;
} }
} else if (strcmp(group, "[Peer]") == 0) { } else if (strcmp(group, "[Peer]") == 0) {
@ -199,7 +200,10 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
memset(&pi_, 0, sizeof(pi_)); memset(&pi_, 0, sizeof(pi_));
return true; return true;
} }
if (strcmp(key, "PublicKey") == 0) {
if (wg_->dev().plugin() && wg_->dev().plugin()->OnUnknownPeerSetting(peer_, key, value)) {
// nothing here
} else if (strcmp(key, "PublicKey") == 0) {
if (!ParseBase64Key(value, pi_.pub.bytes)) if (!ParseBase64Key(value, pi_.pub.bytes))
return false; return false;
} else if (strcmp(key, "PresharedKey") == 0) { } else if (strcmp(key, "PresharedKey") == 0) {
@ -252,7 +256,6 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
return false; return false;
} }
} else { } else {
if (!wg_->dev().plugin() || !wg_->dev().plugin()->OnUnknownPeerSetting(peer_, key, value))
goto err; goto err;
} }
} else { } else {
@ -331,6 +334,11 @@ bool ParseWireGuardConfigString(WireguardProcessor *wg, const char *bufin, size_
buf = nl + 1; buf = nl + 1;
} }
file_parser.FinishGroup(); file_parser.FinishGroup();
// Let plugin do any final processing
if (wg->dev().plugin() && !wg->dev().plugin()->OnAfterSettingsParsed())
return false;
return true; return true;
} }

View file

@ -787,6 +787,9 @@ WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packe
peer->endpoint_ = packet->addr; peer->endpoint_ = packet->addr;
peer->endpoint_protocol_ = packet->protocol; peer->endpoint_protocol_ = packet->protocol;
} }
WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet);
peer->rx_bytes_ += packet->size; peer->rx_bytes_ += packet->size;
peer->InsertKeypairInPeer_Locked(keypair); peer->InsertKeypairInPeer_Locked(keypair);
WG_RELEASE_LOCK(peer->mutex_); WG_RELEASE_LOCK(peer->mutex_);

View file

@ -318,22 +318,26 @@ public:
virtual bool OnUnknownInterfaceSetting(const char *key, const char *value) = 0; virtual bool OnUnknownInterfaceSetting(const char *key, const char *value) = 0;
virtual bool OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) = 0; virtual bool OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) = 0;
// Called after settings have been completely parsed, the plugin may modify the state
virtual bool OnAfterSettingsParsed() = 0;
// Returns true if we want to perform a handshake for this peer. // Returns true if we want to perform a handshake for this peer.
virtual bool WantHandshake(WgPeer *peer) = 0; virtual bool WantHandshake(WgPeer *peer) = 0;
// Called right before handshake initiation is sent out. Can't drop packets.
virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0;
// Called after handshake initiation is parsed, but before handshake response is sent.
enum { enum {
kHandshakeResponseDrop = 0xffffffff, kHandshakeResponseDrop = 0xffffffff,
kHandshakeResponseFail = 0x80000000 kHandshakeResponseFail = 0x80000000
}; };
// Called before handshake initiation is sent out. Can write extra headers. Can't drop packets.
virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0;
// Called after handshake initiation is parsed, but before handshake response is sent.
// Packet can be dropped or keypair failed. // Packet can be dropped or keypair failed.
virtual uint32 OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt_in[WG_PUBLIC_KEY_LEN], uint8 *extout, uint32 extout_size, const uint8 salt_out[WG_PUBLIC_KEY_LEN]) = 0; virtual uint32 OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt_in[WG_PUBLIC_KEY_LEN], uint8 *extout, uint32 extout_size, const uint8 salt_out[WG_PUBLIC_KEY_LEN]) = 0;
// Called when handshake response is parsed // Called when handshake response is parsed
virtual uint32 OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0; virtual uint32 OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0;
// Called right before an outgoing non-data packet is sent out, but before it's scrambled.
virtual void OnOutgoingHandshakePacket(WgPeer *peer, Packet *packet) = 0;
}; };
@ -379,6 +383,8 @@ public:
void SetPlugin(WgPlugin *del) { plugin_ = del; } void SetPlugin(WgPlugin *del) { plugin_ = del; }
WgPlugin *plugin() { return plugin_; } WgPlugin *plugin() { return plugin_; }
MultithreadedDelayedDelete *GetDelayedDelete() { return &delayed_delete_; }
private: private:
std::pair<WgPeer*, WgKeypair*> *LookupPeerInKeyIdLookup(uint32 key_id); std::pair<WgPeer*, WgKeypair*> *LookupPeerInKeyIdLookup(uint32 key_id);
WgKeypair *LookupKeypairByKeyId(uint32 key_id); WgKeypair *LookupKeypairByKeyId(uint32 key_id);
@ -773,6 +779,9 @@ bool WgKeypairDecryptPayload(uint8 *dst, const size_t src_len,
struct WgExtensionHooksDefault { struct WgExtensionHooksDefault {
static uint32 GetIpv4Target(Packet *packet, uint8 *data) { return ReadBE32(data + 16); } static uint32 GetIpv4Target(Packet *packet, uint8 *data) { return ReadBE32(data + 16); }
static void OnPeerIncomingUdp(WgPeer *peer, const Packet *packet) { }
static void OnPeerOutgoingUdp(WgPeer *peer, Packet *packet) { }
static bool DisableSourceAddressVerification(WgPeer *peer) { return false; }
}; };
#ifndef WG_EXTENSION_HOOKS #ifndef WG_EXTENSION_HOOKS