Add more extension points in wireguard code

This commit is contained in:
Ludvig Strigeus 2018-12-10 23:12:57 +01:00
parent b1ffd5738e
commit 13158f9d90
5 changed files with 53 additions and 16 deletions

View file

@ -66,6 +66,12 @@ struct Packet : QueuedItem {
HEADROOM_BEFORE = 64,
};
#ifdef PACKET_EXTENSION_FIELDS
PACKET_EXTENSION_FIELDS
#endif // PACKET_EXTENSION_FIELDS
byte data_pre[HEADROOM_BEFORE];
byte data_buf[0];

View file

@ -450,6 +450,8 @@ WireguardProcessor::PacketResult WireguardProcessor::WriteAndEncryptPacketToUdp_
packet->addr = peer->endpoint_;
packet->protocol = peer->endpoint_protocol_;
WG_EXTENSION_HOOKS::OnPeerOutgoingUdp(peer, packet);
if (size == 0) {
peer->OnKeepaliveSent();
} else {
@ -663,6 +665,7 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) {
peer->OnHandshakeInitSent();
packet->addr = peer->endpoint_;
packet->protocol = peer->endpoint_protocol_;
WG_EXTENSION_HOOKS::OnPeerOutgoingUdp(peer, packet);
peer->tx_bytes_ += packet->size;
// If this is an incoming oneway connection (such as tcp), forget the
@ -888,6 +891,8 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack
peer->endpoint_protocol_ = packet->protocol;
}
WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet);
// Remember how many incoming packets we've seen so we can approximate loss
keypair->incoming_packet_count++;
@ -936,9 +941,13 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack
if (ip_version == 4) {
if (data_size < IPV4_HEADER_SIZE)
goto getout_error_header;
WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV4(ReadBE32(data + 12));
WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
if (!WG_EXTENSION_HOOKS::DisableSourceAddressVerification(peer)) {
WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV4(ReadBE32(data + 12));
WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
if (peer_from_header != peer)
goto getout_error_header;
}
size_from_header = ReadBE16(data + 2);
if (size_from_header < IPV4_HEADER_SIZE) {
// too small packet?
@ -950,12 +959,14 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack
WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV6(data + 8);
WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
if (peer_from_header != peer)
goto getout_error_header;
size_from_header = IPV6_HEADER_SIZE + ReadBE16(data + 4);
} else {
// invalid ip version
goto getout_error_header;
}
if (peer_from_header != peer || size_from_header > data_size)
if (size_from_header > data_size)
goto getout_error_header;
packet->size = size_from_header;

View file

@ -97,7 +97,9 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
if (strcmp(group, "[Interface]") == 0) {
if (key == NULL) return true;
if (strcmp(key, "PrivateKey") == 0) {
if (wg_->dev().plugin() && wg_->dev().plugin()->OnUnknownInterfaceSetting(key, value)) {
// nothing here
} else if (strcmp(key, "PrivateKey") == 0) {
if (!ParseBase64Key(value, binkey))
return false;
had_interface_ = true;
@ -185,8 +187,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
wg_->AddExcludedIp(addr);
}
} else {
if (!wg_->dev().plugin() || !wg_->dev().plugin()->OnUnknownInterfaceSetting(key, value))
goto err;
goto err;
}
} else if (strcmp(group, "[Peer]") == 0) {
if (key == NULL) {
@ -199,7 +200,10 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
memset(&pi_, 0, sizeof(pi_));
return true;
}
if (strcmp(key, "PublicKey") == 0) {
if (wg_->dev().plugin() && wg_->dev().plugin()->OnUnknownPeerSetting(peer_, key, value)) {
// nothing here
} else if (strcmp(key, "PublicKey") == 0) {
if (!ParseBase64Key(value, pi_.pub.bytes))
return false;
} else if (strcmp(key, "PresharedKey") == 0) {
@ -252,8 +256,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
return false;
}
} else {
if (!wg_->dev().plugin() || !wg_->dev().plugin()->OnUnknownPeerSetting(peer_, key, value))
goto err;
goto err;
}
} else {
err:
@ -331,6 +334,11 @@ bool ParseWireGuardConfigString(WireguardProcessor *wg, const char *bufin, size_
buf = nl + 1;
}
file_parser.FinishGroup();
// Let plugin do any final processing
if (wg->dev().plugin() && !wg->dev().plugin()->OnAfterSettingsParsed())
return false;
return true;
}

View file

@ -787,6 +787,9 @@ WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packe
peer->endpoint_ = packet->addr;
peer->endpoint_protocol_ = packet->protocol;
}
WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet);
peer->rx_bytes_ += packet->size;
peer->InsertKeypairInPeer_Locked(keypair);
WG_RELEASE_LOCK(peer->mutex_);

View file

@ -318,22 +318,26 @@ public:
virtual bool OnUnknownInterfaceSetting(const char *key, const char *value) = 0;
virtual bool OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) = 0;
// Called after settings have been completely parsed, the plugin may modify the state
virtual bool OnAfterSettingsParsed() = 0;
// Returns true if we want to perform a handshake for this peer.
virtual bool WantHandshake(WgPeer *peer) = 0;
// Called right before handshake initiation is sent out. Can't drop packets.
virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0;
// Called after handshake initiation is parsed, but before handshake response is sent.
enum {
kHandshakeResponseDrop = 0xffffffff,
kHandshakeResponseFail = 0x80000000
};
// Called before handshake initiation is sent out. Can write extra headers. Can't drop packets.
virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0;
// Called after handshake initiation is parsed, but before handshake response is sent.
// Packet can be dropped or keypair failed.
virtual uint32 OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt_in[WG_PUBLIC_KEY_LEN], uint8 *extout, uint32 extout_size, const uint8 salt_out[WG_PUBLIC_KEY_LEN]) = 0;
// Called when handshake response is parsed
virtual uint32 OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0;
// Called right before an outgoing non-data packet is sent out, but before it's scrambled.
virtual void OnOutgoingHandshakePacket(WgPeer *peer, Packet *packet) = 0;
};
@ -378,7 +382,9 @@ public:
void SetPlugin(WgPlugin *del) { plugin_ = del; }
WgPlugin *plugin() { return plugin_; }
MultithreadedDelayedDelete *GetDelayedDelete() { return &delayed_delete_; }
private:
std::pair<WgPeer*, WgKeypair*> *LookupPeerInKeyIdLookup(uint32 key_id);
WgKeypair *LookupKeypairByKeyId(uint32 key_id);
@ -773,6 +779,9 @@ bool WgKeypairDecryptPayload(uint8 *dst, const size_t src_len,
struct WgExtensionHooksDefault {
static uint32 GetIpv4Target(Packet *packet, uint8 *data) { return ReadBE32(data + 16); }
static void OnPeerIncomingUdp(WgPeer *peer, const Packet *packet) { }
static void OnPeerOutgoingUdp(WgPeer *peer, Packet *packet) { }
static bool DisableSourceAddressVerification(WgPeer *peer) { return false; }
};
#ifndef WG_EXTENSION_HOOKS