diff --git a/tunsafe_wg_plugin.cpp b/tunsafe_wg_plugin.cpp index 33ef486..49e1493 100644 --- a/tunsafe_wg_plugin.cpp +++ b/tunsafe_wg_plugin.cpp @@ -15,11 +15,13 @@ enum { WG_SESSION_AUTH_LEN = 16, }; +enum { + WITH_TWO_FACTOR_AUTHENTICATION = 1, +}; + class PluginPeer; class TunsafePluginImpl; -static const char kTwoFactorTokenTag[] = "Two-Factor Token"; - class ExtFieldWriter { public: ExtFieldWriter(uint8 *target, uint32 target_size) : target_(target), target_size_(target_size), target_pos_(0), fail_flag_(false) { } @@ -186,10 +188,10 @@ public: class TunsafePluginImpl : public TunsafePlugin { friend class PluginPeer; public: - TunsafePluginImpl(PluginDelegate *del) { + TunsafePluginImpl(PluginDelegate *del, WireguardProcessor *proc) { delegate_ = del; + proc_ = proc; peer_doing_2fa_ = NULL; - proc_ = NULL; OsGetRandomBytes((uint8*)&siphash_key_, sizeof(siphash_key_)); } @@ -204,19 +206,25 @@ public: private: virtual bool HandleUnknownPeerId(uint8 public_key[WG_PUBLIC_KEY_LEN], Packet *packet) override { return false; } - virtual bool OnUnknownInterfaceSetting(const char *key, const char *value) override { return false; } + virtual bool OnUnknownInterfaceSetting(const char *key, const char *value) override; virtual bool OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) override; virtual bool WantHandshake(WgPeer *peer) override; virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) override; virtual uint32 OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt_in[WG_PUBLIC_KEY_LEN], uint8 *extout, uint32 extout_size, const uint8 salt_out[WG_PUBLIC_KEY_LEN]) override; virtual uint32 OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) override; + virtual bool OnAfterSettingsParsed() override; + virtual void OnOutgoingHandshakePacket(WgPeer *peer, Packet *packet) override; + PluginPeer *GetPluginPeer(WgPeer *peer); virtual void SubmitToken(const uint8 *text, size_t text_len) override; + WireguardProcessor *proc_; PluginPeer *peer_doing_2fa_; PluginDelegate *delegate_; + siphash_key_t siphash_key_; + }; PluginPeer::~PluginPeer() { @@ -553,9 +561,9 @@ bool TokenServerHandler::VerifySessionId(const uint8 session_id_auth[WG_SESSION_ return memcmp_crypto(buf, session_id_auth, WG_SESSION_AUTH_LEN) == 0; } - //////////////////////////////////////////////////////////////////////////////////////////////////////////////// + PluginPeer *TunsafePluginImpl::GetPluginPeer(WgPeer *peer) { PluginPeer *rv = (PluginPeer *)peer->extradata(); if (!rv) { @@ -567,15 +575,20 @@ PluginPeer *TunsafePluginImpl::GetPluginPeer(WgPeer *peer) { bool TunsafePluginImpl::WantHandshake(WgPeer *peer) { PluginPeer *pp = GetPluginPeer(peer); - return pp->token_server_handler.WantHandshake() && - pp->token_client_handler.WantHandshake(); + if (WITH_TWO_FACTOR_AUTHENTICATION) { + return pp->token_server_handler.WantHandshake() && + pp->token_client_handler.WantHandshake(); + } else { + return true; + } } // This runs on client and appends data uint32 TunsafePluginImpl::OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) { PluginPeer *pp = GetPluginPeer(peer); ExtFieldWriter writer(extout, extout_size); - pp->token_client_handler.OnHandshakeCreate(peer, writer, salt); + if (WITH_TWO_FACTOR_AUTHENTICATION) + pp->token_client_handler.OnHandshakeCreate(peer, writer, salt); return writer.length(); } @@ -593,27 +606,31 @@ uint32 TunsafePluginImpl::OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ex ext += 2, ext_size -= 2; if (size > ext_size) return false; - switch (type) { - case kExtensionType_SessionIDAuth: - if (size == WG_SESSION_AUTH_LEN) - has_valid_session_id = pp->token_server_handler.VerifySessionId(ext, salt_in); - break; + if (WITH_TWO_FACTOR_AUTHENTICATION) { + switch (type) { + case kExtensionType_SessionIDAuth: + if (size == WG_SESSION_AUTH_LEN) + has_valid_session_id = pp->token_server_handler.VerifySessionId(ext, salt_in); + break; - case kExtensionType_TokenReply: - token_reply = (uint8*)ext; - token_reply_size = size; - break; + case kExtensionType_TokenReply: + token_reply = (uint8*)ext; + token_reply_size = size; + break; + } } ext += size, ext_size -= size; } if (ext_size != 0) return kHandshakeResponseDrop; - // If this is a handshake in the other direction, also include session id. - pp->token_client_handler.WriteSessionId(writer, salt_out); + if (WITH_TWO_FACTOR_AUTHENTICATION) { + // If this is a handshake in the other direction, also include session id. + pp->token_client_handler.WriteSessionId(writer, salt_out); - if (!pp->token_server_handler.OnHandshake(token_reply, token_reply_size, has_valid_session_id, writer, &siphash_key_)) - return kHandshakeResponseDrop; + if (!pp->token_server_handler.OnHandshake(token_reply, token_reply_size, has_valid_session_id, writer, &siphash_key_)) + return kHandshakeResponseDrop; + } return writer.length() + writer.fail_flag() * WgPlugin::kHandshakeResponseFail; } @@ -629,47 +646,58 @@ uint32 TunsafePluginImpl::OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ex ext += 2, ext_size -= 2; if (size > ext_size) return false; - switch (type) { - case kExtensionType_SessionIDAuth: - if (size == WG_SESSION_AUTH_LEN) - has_valid_session_id = pp->token_server_handler.VerifySessionId(ext, salt); - break; - // All token requests mean that handshake has failed. - case kExtensionType_TokenRequest: - pp->token_client_handler.OnTokenRequest(ext, size); - return kHandshakeResponseDrop; - case kExtensionType_SetSessionID: - if (size == WG_SESSION_ID_LEN) - pp->token_client_handler.SetSessionId(ext); - break; + + if (WITH_TWO_FACTOR_AUTHENTICATION) { + switch (type) { + case kExtensionType_SessionIDAuth: + if (size == WG_SESSION_AUTH_LEN) + has_valid_session_id = pp->token_server_handler.VerifySessionId(ext, salt); + break; + // All token requests mean that handshake has failed. + case kExtensionType_TokenRequest: + pp->token_client_handler.OnTokenRequest(ext, size); + return kHandshakeResponseDrop; + case kExtensionType_SetSessionID: + if (size == WG_SESSION_ID_LEN) + pp->token_client_handler.SetSessionId(ext); + break; + } } ext += size, ext_size -= size; } if (ext_size != 0) return kHandshakeResponseDrop; - // Stop outgoing handshakes if client didn't supply a valid session id. - if (!pp->token_server_handler.OnHandshake2(has_valid_session_id)) - return kHandshakeResponseDrop; - - pp->token_client_handler.OnHandshakeComplete(); + if (WITH_TWO_FACTOR_AUTHENTICATION) { + // Stop outgoing handshakes if client didn't supply a valid session id. + if (!pp->token_server_handler.OnHandshake2(has_valid_session_id)) + return kHandshakeResponseDrop; + pp->token_client_handler.OnHandshakeComplete(); + } return 0; } +bool TunsafePluginImpl::OnUnknownInterfaceSetting(const char *key, const char *value) { + return false; +} + bool TunsafePluginImpl::OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) { PluginPeer *pp = GetPluginPeer(peer); - return pp->token_server_handler.OnUnknownPeerSetting(key, value); + if (WITH_TWO_FACTOR_AUTHENTICATION && pp->token_server_handler.OnUnknownPeerSetting(key, value)) + return true; + + return false; } void TunsafePluginImpl::OnTokenRequest(PluginPeer *peer) { - if (peer_doing_2fa_ != NULL) + if (!WITH_TWO_FACTOR_AUTHENTICATION || peer_doing_2fa_ != NULL) return; peer_doing_2fa_ = peer; delegate_->OnRequestToken(peer->peer, peer->token_client_handler.token_request()); } void TunsafePluginImpl::SubmitToken(const uint8 *text, size_t text_len) { - if (peer_doing_2fa_ == NULL) + if (!WITH_TWO_FACTOR_AUTHENTICATION || peer_doing_2fa_ == NULL) return; assert(peer_doing_2fa_->peer->dev()->IsMainThread()); peer_doing_2fa_->token_client_handler.SetToken(text, text_len); @@ -688,7 +716,14 @@ void TunsafePluginImpl::SubmitToken(const uint8 *text, size_t text_len) { } } -TunsafePlugin *CreateTunsafePlugin(PluginDelegate *delegate) { - return new TunsafePluginImpl(delegate); +bool TunsafePluginImpl::OnAfterSettingsParsed() { + return true; +} + +void TunsafePluginImpl::OnOutgoingHandshakePacket(WgPeer *peer, Packet *packet) { +} + +TunsafePlugin *CreateTunsafePlugin(PluginDelegate *delegate, WireguardProcessor *wgp) { + return new TunsafePluginImpl(delegate, wgp); } diff --git a/tunsafe_wg_plugin.h b/tunsafe_wg_plugin.h index 73cbc01..d21eb1c 100644 --- a/tunsafe_wg_plugin.h +++ b/tunsafe_wg_plugin.h @@ -32,15 +32,8 @@ public: enum { kMaxTokenLen = 128, }; - - - void Initialize(WireguardProcessor *proc) { proc_ = proc; } - // Called after OnRequest2FA to supply the token. virtual void SubmitToken(const uint8 *text, size_t text_len) = 0; - -protected: - WireguardProcessor *proc_; }; -TunsafePlugin *CreateTunsafePlugin(PluginDelegate *del); +TunsafePlugin *CreateTunsafePlugin(PluginDelegate *del, WireguardProcessor *wgp); diff --git a/wireguard.cpp b/wireguard.cpp index 5184e81..60574bd 100644 --- a/wireguard.cpp +++ b/wireguard.cpp @@ -571,6 +571,8 @@ getout_discard: void WireguardProcessor::PrepareOutgoingHandshakePacket(WgPeer *peer, Packet *packet) { assert(dev_.IsMainThread()); + if (dev_.plugin_) + dev_.plugin_->OnOutgoingHandshakePacket(peer, packet); stats_.udp_packets_out++; stats_.udp_bytes_out += packet->size; } diff --git a/wireguard_proto.h b/wireguard_proto.h index 1471036..4ffbac9 100644 --- a/wireguard_proto.h +++ b/wireguard_proto.h @@ -550,6 +550,8 @@ public: void SetExtradata(WgPeerExtraData *ex) { peer_extra_data_ = ex; } WgDevice *dev() { return dev_; } + const uint8 *epriv() { return hs_.e_priv; } + private: bool ParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size_t data_size); static WgKeypair *CreateNewKeypair(bool is_initiator, const uint8 key[WG_HASH_LEN], uint32 send_key_id);