Add more extension points in wireguard code
This commit is contained in:
parent
b1ffd5738e
commit
13158f9d90
5 changed files with 53 additions and 16 deletions
6
netapi.h
6
netapi.h
|
@ -66,6 +66,12 @@ struct Packet : QueuedItem {
|
|||
HEADROOM_BEFORE = 64,
|
||||
};
|
||||
|
||||
|
||||
#ifdef PACKET_EXTENSION_FIELDS
|
||||
PACKET_EXTENSION_FIELDS
|
||||
#endif // PACKET_EXTENSION_FIELDS
|
||||
|
||||
|
||||
byte data_pre[HEADROOM_BEFORE];
|
||||
byte data_buf[0];
|
||||
|
||||
|
|
|
@ -450,6 +450,8 @@ WireguardProcessor::PacketResult WireguardProcessor::WriteAndEncryptPacketToUdp_
|
|||
packet->addr = peer->endpoint_;
|
||||
packet->protocol = peer->endpoint_protocol_;
|
||||
|
||||
WG_EXTENSION_HOOKS::OnPeerOutgoingUdp(peer, packet);
|
||||
|
||||
if (size == 0) {
|
||||
peer->OnKeepaliveSent();
|
||||
} else {
|
||||
|
@ -663,6 +665,7 @@ void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) {
|
|||
peer->OnHandshakeInitSent();
|
||||
packet->addr = peer->endpoint_;
|
||||
packet->protocol = peer->endpoint_protocol_;
|
||||
WG_EXTENSION_HOOKS::OnPeerOutgoingUdp(peer, packet);
|
||||
peer->tx_bytes_ += packet->size;
|
||||
|
||||
// If this is an incoming oneway connection (such as tcp), forget the
|
||||
|
@ -888,6 +891,8 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack
|
|||
peer->endpoint_protocol_ = packet->protocol;
|
||||
}
|
||||
|
||||
WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet);
|
||||
|
||||
// Remember how many incoming packets we've seen so we can approximate loss
|
||||
keypair->incoming_packet_count++;
|
||||
|
||||
|
@ -936,9 +941,13 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack
|
|||
if (ip_version == 4) {
|
||||
if (data_size < IPV4_HEADER_SIZE)
|
||||
goto getout_error_header;
|
||||
WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
|
||||
peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV4(ReadBE32(data + 12));
|
||||
WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
|
||||
if (!WG_EXTENSION_HOOKS::DisableSourceAddressVerification(peer)) {
|
||||
WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
|
||||
peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV4(ReadBE32(data + 12));
|
||||
WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
|
||||
if (peer_from_header != peer)
|
||||
goto getout_error_header;
|
||||
}
|
||||
size_from_header = ReadBE16(data + 2);
|
||||
if (size_from_header < IPV4_HEADER_SIZE) {
|
||||
// too small packet?
|
||||
|
@ -950,12 +959,14 @@ WireguardProcessor::PacketResult WireguardProcessor::HandleAuthenticatedDataPack
|
|||
WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
|
||||
peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV6(data + 8);
|
||||
WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_);
|
||||
if (peer_from_header != peer)
|
||||
goto getout_error_header;
|
||||
size_from_header = IPV6_HEADER_SIZE + ReadBE16(data + 4);
|
||||
} else {
|
||||
// invalid ip version
|
||||
goto getout_error_header;
|
||||
}
|
||||
if (peer_from_header != peer || size_from_header > data_size)
|
||||
if (size_from_header > data_size)
|
||||
goto getout_error_header;
|
||||
|
||||
packet->size = size_from_header;
|
||||
|
|
|
@ -97,7 +97,9 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
|
|||
|
||||
if (strcmp(group, "[Interface]") == 0) {
|
||||
if (key == NULL) return true;
|
||||
if (strcmp(key, "PrivateKey") == 0) {
|
||||
if (wg_->dev().plugin() && wg_->dev().plugin()->OnUnknownInterfaceSetting(key, value)) {
|
||||
// nothing here
|
||||
} else if (strcmp(key, "PrivateKey") == 0) {
|
||||
if (!ParseBase64Key(value, binkey))
|
||||
return false;
|
||||
had_interface_ = true;
|
||||
|
@ -185,8 +187,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
|
|||
wg_->AddExcludedIp(addr);
|
||||
}
|
||||
} else {
|
||||
if (!wg_->dev().plugin() || !wg_->dev().plugin()->OnUnknownInterfaceSetting(key, value))
|
||||
goto err;
|
||||
goto err;
|
||||
}
|
||||
} else if (strcmp(group, "[Peer]") == 0) {
|
||||
if (key == NULL) {
|
||||
|
@ -199,7 +200,10 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
|
|||
memset(&pi_, 0, sizeof(pi_));
|
||||
return true;
|
||||
}
|
||||
if (strcmp(key, "PublicKey") == 0) {
|
||||
|
||||
if (wg_->dev().plugin() && wg_->dev().plugin()->OnUnknownPeerSetting(peer_, key, value)) {
|
||||
// nothing here
|
||||
} else if (strcmp(key, "PublicKey") == 0) {
|
||||
if (!ParseBase64Key(value, pi_.pub.bytes))
|
||||
return false;
|
||||
} else if (strcmp(key, "PresharedKey") == 0) {
|
||||
|
@ -252,8 +256,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
|
|||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!wg_->dev().plugin() || !wg_->dev().plugin()->OnUnknownPeerSetting(peer_, key, value))
|
||||
goto err;
|
||||
goto err;
|
||||
}
|
||||
} else {
|
||||
err:
|
||||
|
@ -331,6 +334,11 @@ bool ParseWireGuardConfigString(WireguardProcessor *wg, const char *bufin, size_
|
|||
buf = nl + 1;
|
||||
}
|
||||
file_parser.FinishGroup();
|
||||
|
||||
// Let plugin do any final processing
|
||||
if (wg->dev().plugin() && !wg->dev().plugin()->OnAfterSettingsParsed())
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -787,6 +787,9 @@ WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packe
|
|||
peer->endpoint_ = packet->addr;
|
||||
peer->endpoint_protocol_ = packet->protocol;
|
||||
}
|
||||
|
||||
WG_EXTENSION_HOOKS::OnPeerIncomingUdp(peer, packet);
|
||||
|
||||
peer->rx_bytes_ += packet->size;
|
||||
peer->InsertKeypairInPeer_Locked(keypair);
|
||||
WG_RELEASE_LOCK(peer->mutex_);
|
||||
|
|
|
@ -318,22 +318,26 @@ public:
|
|||
virtual bool OnUnknownInterfaceSetting(const char *key, const char *value) = 0;
|
||||
virtual bool OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) = 0;
|
||||
|
||||
// Called after settings have been completely parsed, the plugin may modify the state
|
||||
virtual bool OnAfterSettingsParsed() = 0;
|
||||
|
||||
// Returns true if we want to perform a handshake for this peer.
|
||||
virtual bool WantHandshake(WgPeer *peer) = 0;
|
||||
|
||||
// Called right before handshake initiation is sent out. Can't drop packets.
|
||||
virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0;
|
||||
// Called after handshake initiation is parsed, but before handshake response is sent.
|
||||
|
||||
enum {
|
||||
kHandshakeResponseDrop = 0xffffffff,
|
||||
kHandshakeResponseFail = 0x80000000
|
||||
};
|
||||
// Called before handshake initiation is sent out. Can write extra headers. Can't drop packets.
|
||||
virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0;
|
||||
// Called after handshake initiation is parsed, but before handshake response is sent.
|
||||
// Packet can be dropped or keypair failed.
|
||||
virtual uint32 OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt_in[WG_PUBLIC_KEY_LEN], uint8 *extout, uint32 extout_size, const uint8 salt_out[WG_PUBLIC_KEY_LEN]) = 0;
|
||||
|
||||
// Called when handshake response is parsed
|
||||
virtual uint32 OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0;
|
||||
|
||||
// Called right before an outgoing non-data packet is sent out, but before it's scrambled.
|
||||
virtual void OnOutgoingHandshakePacket(WgPeer *peer, Packet *packet) = 0;
|
||||
};
|
||||
|
||||
|
||||
|
@ -378,7 +382,9 @@ public:
|
|||
|
||||
void SetPlugin(WgPlugin *del) { plugin_ = del; }
|
||||
WgPlugin *plugin() { return plugin_; }
|
||||
|
||||
|
||||
MultithreadedDelayedDelete *GetDelayedDelete() { return &delayed_delete_; }
|
||||
|
||||
private:
|
||||
std::pair<WgPeer*, WgKeypair*> *LookupPeerInKeyIdLookup(uint32 key_id);
|
||||
WgKeypair *LookupKeypairByKeyId(uint32 key_id);
|
||||
|
@ -773,6 +779,9 @@ bool WgKeypairDecryptPayload(uint8 *dst, const size_t src_len,
|
|||
|
||||
struct WgExtensionHooksDefault {
|
||||
static uint32 GetIpv4Target(Packet *packet, uint8 *data) { return ReadBE32(data + 16); }
|
||||
static void OnPeerIncomingUdp(WgPeer *peer, const Packet *packet) { }
|
||||
static void OnPeerOutgoingUdp(WgPeer *peer, Packet *packet) { }
|
||||
static bool DisableSourceAddressVerification(WgPeer *peer) { return false; }
|
||||
};
|
||||
|
||||
#ifndef WG_EXTENSION_HOOKS
|
||||
|
|
Loading…
Reference in a new issue