From ad465d6703793339ca756e50cd63b1d2cb1fcb3a Mon Sep 17 00:00:00 2001 From: Ludvig Strigeus Date: Mon, 10 Sep 2018 23:07:06 +0200 Subject: [PATCH] Add WgDevice::Delegate to add peers on demand --- wireguard_proto.cpp | 7 +++++-- wireguard_proto.h | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/wireguard_proto.cpp b/wireguard_proto.cpp index 1d9e587..b582d41 100644 --- a/wireguard_proto.cpp +++ b/wireguard_proto.cpp @@ -54,6 +54,7 @@ bool ReplayDetector::CheckReplay(uint64 seq_nr) { WgDevice::WgDevice() { peers_ = NULL; + delegate_ = NULL; header_obfuscation_ = false; next_rng_slot_ = 0; memset(&compression_header_, 0, sizeof(compression_header_)); @@ -492,8 +493,10 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { // Hi := HASH(Hi || msg.static) BlakeMix(hi, src->static_enc, sizeof(src->static_enc)); // Lookup the peer with this ID - if (!(peer = dev->GetPeerFromPublicKey(spubi))) - goto getout; + while ((peer = dev->GetPeerFromPublicKey(spubi)) == NULL) { + if (dev->delegate_ == NULL || !dev->delegate_->HandleUnknownPeerId(spubi, packet)) + goto getout; + } // (Ci, K) := KDF2(Ci, DH(sprivr, spubi)) blake2s_hkdf(ci, sizeof(ci), k, sizeof(k), NULL, 32, peer->s_priv_pub_, sizeof(peer->s_priv_pub_), ci, WG_HASH_LEN); // Hi2 := Hi diff --git a/wireguard_proto.h b/wireguard_proto.h index 9e5c12f..9dc7406 100644 --- a/wireguard_proto.h +++ b/wireguard_proto.h @@ -265,6 +265,16 @@ class WgDevice { friend class WgPeer; friend class WireguardProcessor; public: + + // Can be used to customize the behavior of WgDevice + class Delegate { + public: + // This is called from the main thread whenever a public key was not found in the WgDevice, + // return true to try again or false to fail. The packet can be copied and saved + // to resume a handshake later on. + virtual bool HandleUnknownPeerId(uint8 public_key[WG_PUBLIC_KEY_LEN], Packet *packet) = 0; + }; + WgDevice(); ~WgDevice(); @@ -296,6 +306,8 @@ public: bool IsMainThread() { return CurrentThreadIdEquals(main_thread_id_); } void SetCurrentThreadAsMainThread() { main_thread_id_ = GetCurrentThreadId(); } + + void SetDelegate(Delegate *del) { delegate_ = del; } private: std::pair *LookupPeerInKeyIdLookup(uint32 key_id); WgKeypair *LookupKeypairByKeyId(uint32 key_id); @@ -320,6 +332,9 @@ private: // For enumerating all peers WgPeer *peers_; + // For hooking + Delegate *delegate_; + // Lock that protects key_id_lookup_ WG_DECLARE_RWLOCK(key_id_lookup_lock_); // Mapping from key-id to either an active keypair (if keypair is non-NULL),