From 1e414a700e3ac13900f708175a08abacca7b124a Mon Sep 17 00:00:00 2001 From: Ludvig Strigeus Date: Mon, 19 Nov 2018 21:24:43 +0100 Subject: [PATCH] Two Factor Authentication (with TOTP) --- TunSafe.rc | 49 +++ TunSafe.vcxproj | 4 + TunSafe.vcxproj.filters | 15 + crypto/sha/sha1.cpp | 116 +++++++ crypto/sha/sha1.h | 49 +++ network_win32.cpp | 55 ++- network_win32.h | 13 +- network_win32_api.h | 7 +- resource.h | 7 +- service_win32.cpp | 20 ++ service_win32.h | 10 +- service_win32_constants.h | 2 + tunsafe_amalgam.cpp | 2 + tunsafe_bsd.cpp | 27 +- tunsafe_config.h | 2 +- tunsafe_wg_plugin.cpp | 689 ++++++++++++++++++++++++++++++++++++++ tunsafe_wg_plugin.h | 46 +++ tunsafe_win32.cpp | 202 +++++++++++ wireguard.cpp | 10 +- wireguard_proto.cpp | 31 +- wireguard_proto.h | 18 +- 21 files changed, 1336 insertions(+), 38 deletions(-) create mode 100644 crypto/sha/sha1.cpp create mode 100644 crypto/sha/sha1.h create mode 100644 tunsafe_wg_plugin.cpp create mode 100644 tunsafe_wg_plugin.h diff --git a/TunSafe.rc b/TunSafe.rc index 4bdb3ed..4600597 100644 --- a/TunSafe.rc +++ b/TunSafe.rc @@ -82,6 +82,29 @@ BEGIN PUSHBUTTON "&Randomize",IDRAND,7,70,50,14 END +IDD_DIALOG3 DIALOGEX 0, 0, 211, 94 +STYLE DS_SETFONT | DS_MODALFRAME | DS_FIXEDSYS | WS_POPUP | WS_CAPTION | WS_SYSMENU +CAPTION "Two Factor Authentication" +FONT 8, "MS Shell Dlg", 400, 0, 0x1 +BEGIN + LTEXT "The server requires Two Factor authentication. Please enter the code from your authenticator.",-1,7,7,197,18 + CONTROL "",IDC_PAINTBOX,"TwoFactorEditField",WS_TABSTOP,7,32,197,34,0x4000000L + PUSHBUTTON "&Cancel",IDCANCEL,154,72,50,14 + LTEXT "",IDC_CODENOTACCEPTED,7,74,133,8,NOT WS_VISIBLE +END + +IDD_DIALOG4 DIALOGEX 0, 0, 211, 78 +STYLE DS_SETFONT | DS_MODALFRAME | DS_FIXEDSYS | WS_POPUP | WS_CAPTION | WS_SYSMENU +CAPTION "Two Factor Authentication" +FONT 8, "MS Shell Dlg", 400, 0, 0x1 +BEGIN + EDITTEXT IDC_TWOFACTOREDIT,7,29,197,19,ES_PASSWORD | ES_AUTOHSCROLL + DEFPUSHBUTTON "&OK",IDOK,99,57,50,14 + PUSHBUTTON "&Cancel",IDCANCEL,155,57,50,14 + LTEXT "The server requires Two Factor authentication. Please enter the code from your authenticator.",-1,7,7,197,18 + LTEXT "",IDC_CODENOTACCEPTED,7,53,75,17,NOT WS_VISIBLE +END + ///////////////////////////////////////////////////////////////////////////// // @@ -106,6 +129,22 @@ BEGIN TOPMARGIN, 7 BOTTOMMARGIN, 84 END + + IDD_DIALOG3, DIALOG + BEGIN + LEFTMARGIN, 7 + RIGHTMARGIN, 204 + TOPMARGIN, 7 + BOTTOMMARGIN, 87 + END + + IDD_DIALOG4, DIALOG + BEGIN + LEFTMARGIN, 7 + RIGHTMARGIN, 204 + TOPMARGIN, 7 + BOTTOMMARGIN, 71 + END END #endif // APSTUDIO_INVOKED @@ -137,6 +176,16 @@ BEGIN 0 END +IDD_DIALOG3 AFX_DIALOG_LAYOUT +BEGIN + 0 +END + +IDD_DIALOG4 AFX_DIALOG_LAYOUT +BEGIN + 0 +END + ///////////////////////////////////////////////////////////////////////////// // diff --git a/TunSafe.vcxproj b/TunSafe.vcxproj index d0b93cf..92a582c 100644 --- a/TunSafe.vcxproj +++ b/TunSafe.vcxproj @@ -184,6 +184,7 @@ + true @@ -219,6 +220,7 @@ + @@ -228,6 +230,7 @@ + true @@ -257,6 +260,7 @@ + diff --git a/TunSafe.vcxproj.filters b/TunSafe.vcxproj.filters index 10ced29..5dc191c 100644 --- a/TunSafe.vcxproj.filters +++ b/TunSafe.vcxproj.filters @@ -29,6 +29,9 @@ {0f45e1a0-f33e-4c6e-88ae-eb4639f12041} + + {1725c5b8-3480-41d8-8b8e-356f68ceda3d} + @@ -143,6 +146,12 @@ Source Files\Win32 + + crypto\sha + + + Source Files + @@ -223,6 +232,12 @@ Source Files\Win32 + + crypto\sha + + + Source Files + diff --git a/crypto/sha/sha1.cpp b/crypto/sha/sha1.cpp new file mode 100644 index 0000000..8225842 --- /dev/null +++ b/crypto/sha/sha1.cpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. +// Copyright (C) 1998, 2009, Paul E. Jones + +#include "stdafx.h" +#include "crypto/sha/sha1.h" +#include +#include "tunsafe_endian.h" + +#define SHA1Rotate(word, bits) ((((word) << (bits)) & 0xFFFFFFFF) | ((word) >> (32-(bits)))) + +static void SHA1ProcessMessageBlock(SHA1Context *ctx) { + uint32 t, temp, W[80]; + + for (t = 0; t < 16; t++) + W[t] = ReadBE32(&ctx->buffer[t * 4]); + + for (t = 16; t < 80; t++) + W[t] = SHA1Rotate( W[t - 3] ^ W[t - 8] ^ W[t - 14] ^ W[t - 16], 1); + + uint32 A = ctx->state[0], B = ctx->state[1], C = ctx->state[2], D = ctx->state[3], E = ctx->state[4]; + +#define SHA1_ROUND(x) temp = SHA1Rotate(A, 5) + (x), E = D, D = C, C = SHA1Rotate(B, 30), B = A, A = temp + for (t = 0; t < 20; t++) + SHA1_ROUND(((B & C) | ((~B) & D)) + E + W[t] + 0x5A827999); + for (t = 20; t < 40; t++) + SHA1_ROUND((B ^ C ^ D) + E + W[t] + 0x6ED9EBA1); + for (t = 40; t < 60; t++) + SHA1_ROUND(((B & C) | (B & D) | (C & D)) + E + W[t] + 0x8F1BBCDC); + for (t = 60; t < 80; t++) + SHA1_ROUND((B ^ C ^ D) + E + W[t] + 0xCA62C1D6); +#undef SHA1_ROUND + + ctx->state[0] += A; + ctx->state[1] += B; + ctx->state[2] += C; + ctx->state[3] += D; + ctx->state[4] += E; + ctx->pos = 0; +} + +void SHA1Reset(SHA1Context *ctx) { + ctx->length = ctx->pos = 0; + ctx->state[0] = 0x67452301; + ctx->state[1] = 0xEFCDAB89; + ctx->state[2] = 0x98BADCFE; + ctx->state[3] = 0x10325476; + ctx->state[4] = 0xC3D2E1F0; +} + +void SHA1Finish(SHA1Context *ctx, uint8 digest[20]) { + ctx->buffer[ctx->pos++] = 0x80; + while (ctx->pos != 56) { + if (ctx->pos == 64) + SHA1ProcessMessageBlock(ctx); + ctx->buffer[ctx->pos++] = 0; + } + WriteBE64(&ctx->buffer[56], ctx->length); + SHA1ProcessMessageBlock(ctx); + for (int i = 0; i < 5; i++) + WriteBE32(digest + i * 4, ctx->state[i]); +} + +void SHA1Input(SHA1Context *ctx, const uint8 *input, size_t input_len) { + ctx->length += input_len * 8; + while (input_len--) { + ctx->buffer[ctx->pos++] = *input++; + if (ctx->pos == 64) + SHA1ProcessMessageBlock(ctx); + } +} + +void SHA1Hash(const uint8 *data, int data_size, uint8 digest[20]) { + SHA1Context ctx; + SHA1Reset(&ctx); + SHA1Input(&ctx, data, data_size); + SHA1Finish(&ctx, digest); +} + +void SHA1HmacReset(SHA1HmacContext *hmac, const unsigned char *key, unsigned key_size) { + byte temp[64]; + byte temp2[64]; + byte digest[20]; + int i; + + if (key_size > 64) { + SHA1Hash(key, key_size, digest); + key = digest; + key_size = sizeof(digest); + } + + for (i = 0; i != key_size; i++) { + temp[i] = key[i] ^ 0x36; + temp2[i] = key[i] ^ 0x5C; + } + for (; i != 64; i++) { + temp[i] = 0x36; + temp2[i] = 0x5C; + } + + SHA1Reset(&hmac->sha1); + SHA1Reset(&hmac->sha2); + + SHA1Input(&hmac->sha1, temp, sizeof(temp)); + SHA1Input(&hmac->sha2, temp2, sizeof(temp2)); +} + +void SHA1HmacInput(SHA1HmacContext *hmac, const unsigned char *input, unsigned input_size) { + SHA1Input(&hmac->sha1, input, input_size); +} + +void SHA1HmacFinish(SHA1HmacContext *hmac, byte digest[20]) { + SHA1Finish(&hmac->sha1, digest); + SHA1Input(&hmac->sha2, digest, 20); + SHA1Finish(&hmac->sha2, digest); +} diff --git a/crypto/sha/sha1.h b/crypto/sha/sha1.h new file mode 100644 index 0000000..19bb509 --- /dev/null +++ b/crypto/sha/sha1.h @@ -0,0 +1,49 @@ +/* + * sha1.h + * + * Copyright (C) 1998, 2009 + * Paul E. Jones + * All Rights Reserved + * + ***************************************************************************** + * $Id: sha1.h 12 2009-06-22 19:34:25Z paulej $ + ***************************************************************************** + * + * Description: + * This class implements the Secure Hashing Standard as defined + * in FIPS PUB 180-1 published April 17, 1995. + * + * Many of the variable names in the SHA1Context, especially the + * single character names, were used because those were the names + * used in the publication. + * + * Please read the file sha1.c for more information. + * + */ + +#ifndef _SHA1_H_ +#define _SHA1_H_ + +#include "tunsafe_types.h" + +struct SHA1Context { + uint32 state[5]; + uint64 length; + uint8 buffer[64]; + uint32 pos; +}; + +void SHA1Reset(SHA1Context *ctx); +void SHA1Input(SHA1Context *ctx, const uint8 *input, size_t input_len); +void SHA1Finish(SHA1Context *ctx, uint8 digest[20]); +void SHA1Hash(const uint8 *data, int data_size, uint8 digest[20]); + +struct SHA1HmacContext { + SHA1Context sha1, sha2; +}; + +void SHA1HmacReset(SHA1HmacContext *hmac, const unsigned char *key, unsigned key_size); +void SHA1HmacInput(SHA1HmacContext *hmac, const unsigned char *input, unsigned input_size); +void SHA1HmacFinish(SHA1HmacContext *hmac, byte digest[20]); + +#endif diff --git a/network_win32.cpp b/network_win32.cpp index a946bc3..51dce48 100644 --- a/network_win32.cpp +++ b/network_win32.cpp @@ -22,6 +22,7 @@ #include #include "network_win32_dnsblock.h" #include "util_win32.h" +#include "tunsafe_wg_plugin.h" enum { HARD_MAXIMUM_QUEUE_SIZE = 102400, @@ -1939,6 +1940,7 @@ static void RemoveKillSwitchRoute() { TunsafeBackendWin32::TunsafeBackendWin32(Delegate *delegate) : delegate_(delegate), dns_resolver_(&dns_blocker_) { memset(&stats_, 0, sizeof(stats_)); wg_processor_ = NULL; + token_request_ = 0; InitPacketMutexes(); worker_thread_ = NULL; last_tun_adapter_failed_ = 0; @@ -1963,6 +1965,12 @@ void TunsafeBackendWin32::SetPublicKey(const uint8 key[32]) { delegate_->OnStateChanged(); } +struct PluginHolder { + PluginHolder(PluginDelegate *del) : plugin(CreateTunsafePlugin(del)) {} + ~PluginHolder() { delete plugin; } + TunsafePlugin *plugin; +}; + DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) { TunsafeBackendWin32 *backend = (TunsafeBackendWin32*)bk; int stop_mode; @@ -1971,7 +1979,10 @@ DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) { for (;;) { TunWin32Iocp tun(&backend->dns_blocker_, backend); NetworkWin32 net; + PluginHolder plugin(backend); WireguardProcessor wg_proc(&net, &tun, backend); + wg_proc.dev().SetPlugin(plugin.plugin); + plugin.plugin->Initialize(&wg_proc); net.udp().SetPacketHandler(&backend->packet_processor_); net.tcp_socket_queue().SetPacketHandler(&backend->packet_processor_); @@ -1988,6 +1999,7 @@ DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) { backend->SetPublicKey(wg_proc.dev().public_key()); backend->wg_processor_ = &wg_proc; + backend->tunsafe_wg_plugin_ = plugin.plugin; net.StartThread(); tun.StartThread(); @@ -1996,6 +2008,7 @@ DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) { tun.StopThread(); backend->wg_processor_ = NULL; + backend->tunsafe_wg_plugin_ = NULL; // Keep DNS alive if (stop_mode != MODE_EXIT) @@ -2079,6 +2092,7 @@ void TunsafeBackendWin32::Start(const char *config_file) { dns_resolver_.ResetCancel(); g_killswitch_currconn = kBlockInternet_Default; is_started_ = true; + token_request_ = 0; memset(public_key_, 0, sizeof(public_key_)); SetStatus(kStatusInitializing); delegate_->OnClearLog(); @@ -2192,15 +2206,25 @@ struct ConfigQueueItem : QueuedItem, QueuedItemCallback { virtual void OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) override; virtual void OnQueuedItemDelete(QueuedItem *ow) override; + enum Type { + SendConfigurationProtocolPacket, + SubmitToken + }; + Type type; std::string message; uint32 ident; }; void ConfigQueueItem::OnQueuedItemEvent(QueuedItem *ow, uintptr_t extra) { PacketProcessor::QueueContext *context = (PacketProcessor::QueueContext *)extra; - std::string reply; - WgConfig::HandleConfigurationProtocolMessage(context->wg, std::move(message), &reply); - context->backend->delegate_->OnConfigurationProtocolReply(ident, std::move(reply)); + + if (type == SendConfigurationProtocolPacket) { + std::string reply; + WgConfig::HandleConfigurationProtocolMessage(context->wg, std::move(message), &reply); + context->backend->delegate_->OnConfigurationProtocolReply(ident, std::move(reply)); + } else { + context->backend->tunsafe_wg_plugin_->SubmitToken((const uint8*)message.data(), message.size()); + } delete this; } @@ -2210,12 +2234,37 @@ void ConfigQueueItem::OnQueuedItemDelete(QueuedItem *ow) { void TunsafeBackendWin32::SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) { ConfigQueueItem *queue_item = new ConfigQueueItem; + queue_item->type = ConfigQueueItem::SendConfigurationProtocolPacket; queue_item->ident = identifier; queue_item->message = std::move(message); queue_item->queue_cb = queue_item; packet_processor_.ForcePost(queue_item); } +void TunsafeBackendWin32::SubmitToken(const std::string &&message) { + // Clear out the old token request so GetTokenRequest returns zero. + token_request_ = 0; + + ConfigQueueItem *queue_item = new ConfigQueueItem; + queue_item->type = ConfigQueueItem::SubmitToken; + queue_item->message = std::move(message); + queue_item->queue_cb = queue_item; + packet_processor_.ForcePost(queue_item); + +} + +uint32 TunsafeBackendWin32::GetTokenRequest() { + return token_request_; +} + +// This is called on the wireguard thread whenever it needs a token, +// it should reschedule +void TunsafeBackendWin32::OnRequestToken(WgPeer *peer, uint32 type) { + token_request_ = type; + delegate_->OnStateChanged(); +} + + void TunsafeBackendWin32::OnConnected() { if (status_ != TunsafeBackend::kStatusConnected) { const WgCidrAddr *ipv4_addr = NULL; diff --git a/network_win32.h b/network_win32.h index 5d79849..4ba31c2 100644 --- a/network_win32.h +++ b/network_win32.h @@ -10,6 +10,7 @@ #include "tunsafe_dnsresolve.h" #include "network_common.h" #include "network_win32_tcp.h" +#include "tunsafe_wg_plugin.h" enum { ADAPTER_GUID_SIZE = 40, @@ -253,7 +254,7 @@ private: TunWin32Adapter adapter_; }; -class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate { +class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate, public PluginDelegate { friend class PacketProcessor; friend class TunWin32Iocp; friend class TunWin32Overlapped; @@ -277,11 +278,16 @@ public: virtual LinearizedGraph *GetGraph(int type) override; virtual std::string GetConfigFileName() override; virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) override; - + virtual uint32 GetTokenRequest() override; + virtual void SubmitToken(const std::string &&message) override; + // -- from ProcessorDelegate virtual void OnConnected() override; virtual void OnConnectionRetry(uint32 attempts) override; + // -- from PluginDelegate + virtual void OnRequestToken(WgPeer *peer, uint32 type) override; + void SetPublicKey(const uint8 key[32]); void PostExit(int exit_code); enum { @@ -305,10 +311,13 @@ private: Delegate *delegate_; char *config_file_; + std::atomic token_request_; + DnsBlocker dns_blocker_; DnsResolver dns_resolver_; WireguardProcessor *wg_processor_; + TunsafePlugin *tunsafe_wg_plugin_; uint32 last_tun_adapter_failed_; StatsCollector stats_collector_; diff --git a/network_win32_api.h b/network_win32_api.h index fb4b964..6ae91a2 100644 --- a/network_win32_api.h +++ b/network_win32_api.h @@ -91,7 +91,6 @@ public: // Returns false if the name can't be exclusively reserved to this adapter. virtual bool SetTunAdapterName(const char *name) = 0; - virtual void Start(const char *config_file) = 0; virtual void Stop() = 0; virtual void RequestStats(bool enable) = 0; @@ -105,6 +104,12 @@ public: virtual LinearizedGraph *GetGraph(int type) = 0; virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) = 0; + // Returns a nonzero value whenever a token is requested, + // as a reply to OnStateChanged + virtual uint32 GetTokenRequest() = 0; + // Called when the UI answers the token request + virtual void SubmitToken(const std::string &&message) = 0; + bool is_started() { return is_started_; } bool is_remote() { return is_remote_; } const uint8 *public_key() { return public_key_; } diff --git a/resource.h b/resource.h index 8979ce2..d613709 100644 --- a/resource.h +++ b/resource.h @@ -6,6 +6,7 @@ #define IDI_ICON1 2 #define ID_STOP 3 #define IDRAND 3 +#define IDCANCEL2 3 #define ID_RESTART 4 #define ID_START 5 #define ID_EXIT 6 @@ -37,11 +38,15 @@ #define IDC_PRIVATE_KEY 34 #define IDD_DIALOG1 101 #define IDD_DIALOG2 105 +#define IDD_DIALOG3 106 #define IDR_MENU1 107 +#define IDD_DIALOG4 107 #define IDC_STATUSBAR 108 #define IDB_DOWNARROW 108 #define IDC_PUBLIC_KEY 109 #define IDC_TAB 110 +#define IDC_CODENOTACCEPTED 111 +#define IDC_TWOFACTOREDIT 1017 // Next default values for new objects // @@ -49,7 +54,7 @@ #ifndef APSTUDIO_READONLY_SYMBOLS #define _APS_NEXT_RESOURCE_VALUE 113 #define _APS_NEXT_COMMAND_VALUE 40030 -#define _APS_NEXT_CONTROL_VALUE 1016 +#define _APS_NEXT_CONTROL_VALUE 1018 #define _APS_NEXT_SYMED_VALUE 101 #endif #endif diff --git a/service_win32.cpp b/service_win32.cpp index b0a9600..5ffaa3d 100644 --- a/service_win32.cpp +++ b/service_win32.cpp @@ -484,6 +484,7 @@ bool TunsafeServiceManager::SwitchInterface(TunsafeServiceServer *server, const /////////////////////////////////////////////////////////////////////////////////////// TunsafeServiceBackend::TunsafeServiceBackend(TunsafeServiceManager *manager) { + token_request_flag_ = 0; manager_ = manager; historical_log_lines_count_ = historical_log_lines_pos_ = 0; memset(historical_log_lines_, 0, sizeof(historical_log_lines_)); @@ -595,6 +596,8 @@ void TunsafeServiceBackend::SendStateUpdate(TunsafeServiceServer *filter) { ss->is_started = backend_->is_started(); ss->internet_block_state = backend_->GetInternetBlockState(); ss->ipv4_ip = backend_->GetIP(); + ss->token_request = backend_->GetTokenRequest(); + ss->token_request_flag = token_request_flag_; memcpy(ss->public_key, backend_->public_key(), 32); memcpy(temp + sizeof(ServiceState), current_filename_.c_str(), current_filename_.size() + 1); for (TunsafeServiceServer *pipe_server : pipe_servers_) { @@ -787,6 +790,10 @@ bool TunsafeServiceServer::HandleMessage(int type, uint8 *data, size_t size) { service_backend_->Start(""); service_backend_->backend_->SendConfigurationProtocolPacket(unique_id_, std::string((char*)data, size)); break; + case TS_SERVICE_REQ_SUBMIT_TOKEN: + service_backend_->token_request_flag_ ^= 1; + service_backend_->backend_->SubmitToken(std::string((char*)data, size)); + break; default: return false; @@ -901,6 +908,7 @@ TunsafeServiceClient::TunsafeServiceClient(TunsafeBackend::Delegate *delegate) delegate_ = delegate; cached_graph_ = 0; last_graph_type_ = 0xffffffff; + token_request_flag_ = 0xff; memset(&service_state_, 0, sizeof(service_state_)); connection_ = pipe_manager_.GetClientConnection(); } @@ -959,6 +967,18 @@ LinearizedGraph *TunsafeServiceClient::GetGraph(int type) { void TunsafeServiceClient::SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) { } +uint32 TunsafeServiceClient::GetTokenRequest() { + mutex_.Acquire(); + uint32 rv = (token_request_flag_ == service_state_.token_request_flag) ? 0 : service_state_.token_request; + mutex_.Release(); + return rv; +} + +void TunsafeServiceClient::SubmitToken(const std::string &&token) { + token_request_flag_ = service_state_.token_request_flag; + connection_->WritePacket(TS_SERVICE_REQ_SUBMIT_TOKEN, (const uint8*)token.data(), token.size()); +} + std::string TunsafeServiceClient::GetConfigFileName() { mutex_.Acquire(); std::string rv = config_file_; diff --git a/service_win32.h b/service_win32.h index d105004..7c231ee 100644 --- a/service_win32.h +++ b/service_win32.h @@ -83,6 +83,9 @@ public: // Called to register a pipe server with this backend void AddPipeServer(TunsafeServiceServer *pipe_server); private: + // toggled every time a token submit is processed + uint8 token_request_flag_; + // Points at the service manager TunsafeServiceManager *manager_; @@ -162,9 +165,11 @@ private: struct ServiceState { uint8 is_started : 1; + uint8 token_request_flag : 1; // toggled each time token_request changes uint8 reserved1; uint16 internet_block_state; - uint8 reserved[24 + 64]; + uint8 reserved[20 + 64]; + uint32 token_request; uint32 ipv4_ip; uint8 public_key[32]; }; @@ -190,6 +195,8 @@ public: virtual void SetServiceStartupFlags(uint32 flags); virtual LinearizedGraph *GetGraph(int type); virtual void SendConfigurationProtocolPacket(uint32 identifier, const std::string &&message) override; + virtual uint32 GetTokenRequest() override; + virtual void SubmitToken(const std::string &&token) override; // -- from PipeConnection::Delegate virtual bool HandleMessage(int type, uint8 *data, size_t size) override; @@ -203,6 +210,7 @@ public: protected: TunsafeBackend::Delegate *delegate_; uint8 want_stats_; + uint8 token_request_flag_; bool got_state_from_control_; ServiceState service_state_; std::string config_file_; diff --git a/service_win32_constants.h b/service_win32_constants.h index 903ad34..1c1f771 100644 --- a/service_win32_constants.h +++ b/service_win32_constants.h @@ -29,6 +29,8 @@ enum { TS_SERVICE_REQ_GETINTERFACES = 19, TS_SERVICE_REQ_GETINTERFACES_REPLY = 20, + + TS_SERVICE_REQ_SUBMIT_TOKEN = 21, }; enum { diff --git a/tunsafe_amalgam.cpp b/tunsafe_amalgam.cpp index 9a28774..b6545d3 100644 --- a/tunsafe_amalgam.cpp +++ b/tunsafe_amalgam.cpp @@ -9,6 +9,7 @@ #include "wireguard.cpp" #include "wireguard_proto.cpp" #include "wireguard_config.cpp" +#include "tunsafe_wg_plugin.cpp" #include "util.cpp" #include "tunsafe_threading.cpp" #include "tunsafe_cpu.cpp" @@ -19,6 +20,7 @@ #include "crypto/blake2s/blake2s.cpp" #include "crypto/siphash/siphash.cpp" #include "crypto/aesgcm/aesgcm.cpp" +#include "crypto/sha/sha1.cpp" #include "network_common.cpp" #if defined(WITH_NETWORK_BSD) diff --git a/tunsafe_bsd.cpp b/tunsafe_bsd.cpp index ded6515..af55a46 100644 --- a/tunsafe_bsd.cpp +++ b/tunsafe_bsd.cpp @@ -2,6 +2,7 @@ // Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. #include "tunsafe_bsd.h" #include "tunsafe_endian.h" +#include "tunsafe_wg_plugin.h" #include "util.h" #include @@ -43,6 +44,8 @@ #include #endif +static bool g_daemon_mode; + #if defined(OS_MACOSX) || defined(OS_FREEBSD) struct MyRouteMsg { struct rt_msghdr hdr; @@ -646,7 +649,7 @@ const char *print_ip(char buf[kSizeOfAddress], in_addr_t ip) { return buf; } -class TunsafeBackendBsdImpl : public TunsafeBackendBsd, public NetworkBsd::NetworkBsdDelegate, public ProcessorDelegate { +class TunsafeBackendBsdImpl : public TunsafeBackendBsd, public NetworkBsd::NetworkBsdDelegate, public ProcessorDelegate, public PluginDelegate { public: TunsafeBackendBsdImpl(); virtual ~TunsafeBackendBsdImpl(); @@ -669,6 +672,9 @@ public: virtual void OnConnected() override; virtual void OnConnectionRetry(uint32 attempts) override; + // -- from PluginDelegate + virtual void OnRequestToken(WgPeer *peer, uint32 type) override; + WireguardProcessor *processor() { return &processor_; } private: @@ -679,6 +685,7 @@ private: bool is_connected_; uint8 close_orphan_counter_; + TunsafePlugin *plugin_; WireguardProcessor processor_; NetworkBsd network_; TunSocketBsd tun_; @@ -690,15 +697,18 @@ private: TunsafeBackendBsdImpl::TunsafeBackendBsdImpl() : is_connected_(false), close_orphan_counter_(0), + plugin_(CreateTunsafePlugin(this)), processor_(this, this, this), network_(this, 1000), tun_(&network_, &processor_), udp_(&network_, &processor_), unix_socket_listener_(&network_, &processor_), tcp_socket_listener_(&network_, &processor_) { + processor_.dev().SetPlugin(plugin_); } TunsafeBackendBsdImpl::~TunsafeBackendBsdImpl() { + delete plugin_; } bool TunsafeBackendBsdImpl::InitializeTun(char devname[16]) { @@ -808,6 +818,19 @@ void TunsafeBackendBsdImpl::OnConnectionRetry(uint32 attempts) { } } +void TunsafeBackendBsdImpl::OnRequestToken(WgPeer *peer, uint32 type) { + if (!g_daemon_mode) { + fprintf(stderr, "A two factor token is required to login. Please enter the value from your authenticator.\nToken: "); + char buf[100], *rv; + while (!(rv = fgets(buf, 100, stdin)) && errno == EINTR) {} + if (rv) { + size_t len = strlen(buf); + while (len && buf[len-1] == '\n') buf[--len] = 0; + plugin_->SubmitToken((const uint8*)buf, strlen(buf)); + } + } +} + void TunsafeBackendBsdImpl::CloseOrphanTcpConnections() { // Add all incoming tcp connections into a lookup table WG_HASHTABLE_IMPL lookup; @@ -866,6 +889,8 @@ int main(int argc, char **argv) { return 1; if (cmd.daemon) { + g_daemon_mode = true; + fprintf(stderr, "Switching to daemon mode...\n"); if (daemon(0, 0) == -1) perror("daemon() failed"); diff --git a/tunsafe_config.h b/tunsafe_config.h index 37214c2..1fb11a8 100644 --- a/tunsafe_config.h +++ b/tunsafe_config.h @@ -5,7 +5,7 @@ #define TUNSAFE_VERSION_STRING "TunSafe 1.5-rc1" #define TUNSAFE_VERSION_STRING_LONG "TunSafe 1.5-rc1" -#define WITH_HANDSHAKE_EXT 0 +#define WITH_HANDSHAKE_EXT 1 #define WITH_CIPHER_SUITES 0 #define WITH_BOOLEAN_FEATURES 0 #define WITH_PACKET_COMPRESSION 0 diff --git a/tunsafe_wg_plugin.cpp b/tunsafe_wg_plugin.cpp new file mode 100644 index 0000000..a873305 --- /dev/null +++ b/tunsafe_wg_plugin.cpp @@ -0,0 +1,689 @@ +#include "stdafx.h" +#include "tunsafe_wg_plugin.h" +#include "wireguard.h" +#include "util.h" +#include "crypto/curve25519/curve25519-donna.h" +#include "crypto/sha/sha1.h" +#include "crypto/chacha20poly1305.h" +#include "crypto/siphash/siphash.h" +#include "crypto/blake2s/blake2s.h" +#include "tunsafe_endian.h" +#include + +enum { + WG_SESSION_ID_LEN = 32, + WG_SESSION_AUTH_LEN = 16, +}; + +class PluginPeer; +class TunsafePluginImpl; + +static const char kTwoFactorTokenTag[] = "Two-Factor Token"; + +class ExtFieldWriter { +public: + ExtFieldWriter(uint8 *target, uint32 target_size) : target_(target), target_size_(target_size), target_pos_(0), fail_flag_(false) { } + + bool WriteField(uint8 code, const uint8 *data, uint32 size); + void BlockLogin() { fail_flag_ = true; } + bool fail_flag() { return fail_flag_; } + uint32 length() { + return target_pos_; + } +private: + uint8 *target_; + uint32 target_size_; + uint32 target_pos_; + bool fail_flag_; +}; + +bool ExtFieldWriter::WriteField(uint8 code, const uint8 *data, uint32 size) { + assert(size < 256); + uint8 *dst = &target_[target_pos_]; + if (target_pos_ + size + 2 > target_size_) + return false; + target_pos_ += size + 2; + dst[0] = code; + dst[1] = size; + memcpy(dst + 2, data, size); + return true; +} + +enum { + kExtensionType_Padding = 0x00, + // The other peer has no way of identifying a specific instance of + // a connection. There's no way to distinguish a periodic handshake from + // a new client connection. Add a session ID to the Peer to solve this. + // We don't send the actual session id, instead we send: + // Hash(plaintext ephemeral public key, session id) + kExtensionType_SessionIDAuth = 0x01, + kExtensionType_SetSessionID = 0x02, + + // This is sent by the server to request an additional token to allow + // login, for example a TOTP token, or a password. + // By cleverly using session ids, the server can avoid having to request + // this for every new handshake, even when roaming. + kExtensionType_TokenRequest = 0x03, + + // This holds the token reply. + kExtensionType_TokenReply = 0x04, +}; + +class TokenClientHandler { +public: + TokenClientHandler(PluginPeer *pp); + ~TokenClientHandler(); + + void SetSessionId(const uint8 id[WG_SESSION_ID_LEN]); + void SetToken(const uint8 *token, size_t token_size); + void OnHandshakeCreate(WgPeer *peer, ExtFieldWriter &writer, const uint8 salt[WG_PUBLIC_KEY_LEN]); + void OnTokenRequest(const uint8 *data, uint32 data_size); + void OnHandshakeComplete(); + + bool waiting_for_token() { return waiting_for_token_; } + uint32 token_request() { return token_request_type_; } + + bool WantHandshake() { return !waiting_for_token_; } + void WriteSessionId(ExtFieldWriter &writer, const uint8 salt[WG_PUBLIC_KEY_LEN]); +private: + PluginPeer *pp_; + + // Set to true if we're waiting for the UI to set the TOTP-token, so login can continue. + bool waiting_for_token_; + uint8 token_size_; + bool has_session_id_; + + uint32 token_request_type_; + + // The session id + uint8 session_id_[WG_SESSION_ID_LEN]; + + // Crypto key for tokens + uint8 token_crypto_key_[WG_SYMMETRIC_KEY_LEN]; + + // This is set to the token given by the UI + uint8 token_[TunsafePlugin::kMaxTokenLen]; +}; + +class TotpTokenAuthenticator { +public: + TotpTokenAuthenticator() : secret_size_(0), window_size_(30), block_reuse_(false), digits_(6), precision_(0), next_allowed_code_(0) {} + bool Initialize(const char *config); + bool Authenticate(const uint8 *data, size_t size, uint64 *last_code); + bool configured() { return secret_size_ != 0; } + uint8 digits() { return digits_; } +private: + uint32 GetValueForTimestamp(uint64 now); + uint16 window_size_; + uint16 precision_; + bool block_reuse_; + uint8 digits_; + uint8 secret_size_; + uint64 next_allowed_code_; + uint8 secret_[64]; +}; + +class TokenServerHandler { +public: + TokenServerHandler(); + ~TokenServerHandler(); + bool OnHandshake(uint8 *token_reply, int token_reply_size, bool has_valid_session_id, ExtFieldWriter &writer, const siphash_key_t *siphash_key); + bool OnHandshake2(bool has_valid_session_id); + bool OnUnknownPeerSetting(const char *key, const char *value); + bool WantHandshake() { return !stop_reconnects_; } + bool VerifySessionId(const uint8 session_id_auth[WG_SESSION_AUTH_LEN], const uint8 salt[WG_PUBLIC_KEY_LEN]); + +private: + bool has_session_id_; + bool is_session_id_authed_; + bool stop_reconnects_; + uint8 num_failures_; + uint8 reset_recovery_counter_; + uint8 token_bucket_; + uint8 authentication_type_; + uint8 last_login_status_; + uint64 last_attempt_; + uint64 reset_recovery_last_code_; + uint64 last_cksum_; + uint64 cksum_equal_timestamp_; + enum { + // Allow one token attempt every 30 seconds + kTokenBucketCost = 30, + + // And the bucket size is 8, so you can perform 8 attempts + // in a row. + kTokenBucketFull = 240, + + // Failed attempts until lockout. When locked out, you need to + // perform 3 successful 2fa attempts in a row before it's unlocked + // for 3 different codes. + kAttemptsUntilLockout = 10, + + kAttemptsUntilLockoutRemoved = 3, + }; + + TotpTokenAuthenticator token_authenticator_; + uint8 session_id_[WG_SESSION_ID_LEN]; + uint8 token_crypto_key_[WG_SYMMETRIC_KEY_LEN]; +}; + +class PluginPeer : public WgPeerExtraData { +public: + PluginPeer(TunsafePluginImpl *plugin, WgPeer *peer) : plugin(plugin), peer(peer), token_client_handler(this) {} + ~PluginPeer(); + WgPeer *peer; + TunsafePluginImpl *plugin; + TokenClientHandler token_client_handler; + TokenServerHandler token_server_handler; +}; + +// Toplevel wireguard plugin +class TunsafePluginImpl : public TunsafePlugin { + friend class PluginPeer; +public: + TunsafePluginImpl(PluginDelegate *del) { + delegate_ = del; + peer_doing_2fa_ = NULL; + proc_ = NULL; + OsGetRandomBytes((uint8*)&siphash_key_, sizeof(siphash_key_)); + } + + void DeletingPeer(PluginPeer *peer) { + if (peer_doing_2fa_ == peer) + peer_doing_2fa_ = NULL; + } + + PluginDelegate *delegate() { return delegate_; } + + void OnTokenRequest(PluginPeer *peer); + +private: + virtual bool HandleUnknownPeerId(uint8 public_key[WG_PUBLIC_KEY_LEN], Packet *packet) override { return false; } + virtual bool OnUnknownInterfaceSetting(const char *key, const char *value) override { return false; } + virtual bool OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) override; + virtual bool WantHandshake(WgPeer *peer) override; + virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) override; + virtual uint32 OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt_in[WG_PUBLIC_KEY_LEN], uint8 *extout, uint32 extout_size, const uint8 salt_out[WG_PUBLIC_KEY_LEN]) override; + virtual uint32 OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) override; + PluginPeer *GetPluginPeer(WgPeer *peer); + + virtual void SubmitToken(const uint8 *text, size_t text_len) override; + + PluginPeer *peer_doing_2fa_; + PluginDelegate *delegate_; + siphash_key_t siphash_key_; +}; + +PluginPeer::~PluginPeer() { + plugin->DeletingPeer(this); +} + + +TokenClientHandler::TokenClientHandler(PluginPeer *pp) { + pp_ = pp; + waiting_for_token_ = false; + token_size_ = false; + has_session_id_ = false; + token_request_type_ = 0; + memset(token_crypto_key_, 0, sizeof(token_crypto_key_)); +} + +TokenClientHandler::~TokenClientHandler() { + +} + +void TokenClientHandler::SetSessionId(const uint8 id[WG_SESSION_ID_LEN]) { + has_session_id_ = true; + memcpy(session_id_, id, WG_SESSION_ID_LEN); +} + +void TokenClientHandler::SetToken(const uint8 *token, size_t token_size) { + if (token_size > TunsafePlugin::kMaxTokenLen || !waiting_for_token_) + return; + waiting_for_token_ = false; + token_size_ = (uint8)token_size; + memcpy(token_, token, token_size); +} + +void TokenClientHandler::WriteSessionId(ExtFieldWriter &writer, const uint8 salt[WG_PUBLIC_KEY_LEN]) { + if (has_session_id_) { + uint8 buf[WG_SESSION_AUTH_LEN]; + blake2s(buf, WG_SESSION_AUTH_LEN, salt, WG_PUBLIC_KEY_LEN, session_id_, sizeof(session_id_)); + writer.WriteField(kExtensionType_SessionIDAuth, buf, WG_SESSION_AUTH_LEN); + } +} + +// This is called to include a token (if the server has set one) in outgoing handshakes. +void TokenClientHandler::OnHandshakeCreate(WgPeer *peer, ExtFieldWriter &writer, const uint8 salt[WG_PUBLIC_KEY_LEN]) { + WriteSessionId(writer, salt); + + if (token_size_ && has_session_id_) { + // Encrypt and include the token in the response + // NOTE: Must not reuse the key to send different tokens, but we send + // only one token as a reply to TokenRequest so that's fine. + uint8 buf[TunsafePlugin::kMaxTokenLen + 16]; + chacha20poly1305_encrypt(buf, token_, token_size_, NULL, 0, 0, token_crypto_key_); + writer.WriteField(kExtensionType_TokenReply, buf, 16 + token_size_); + + static const uint8 kPadding[16] = {0}; + writer.WriteField(kExtensionType_Padding, kPadding, -token_size_ & 0xF); + } +} + +// This runs on the initiator, after the handshake has been parsed +void TokenClientHandler::OnHandshakeComplete() { + // Forget an old token, we'll request it again if needed. + token_size_ = 0; + memset(token_crypto_key_, 0, sizeof(token_crypto_key_)); +} + +// This runs when backend requests a token, ask the user for the token +// and then call SetToken. +void TokenClientHandler::OnTokenRequest(const uint8 *data, uint32 data_size) { + if (data_size >= WG_SYMMETRIC_KEY_LEN + 2 && !waiting_for_token_) { + memcpy(token_crypto_key_, data, WG_SYMMETRIC_KEY_LEN); + token_request_type_ = ReadLE16(data + WG_SYMMETRIC_KEY_LEN); + if (token_size_ && (token_request_type_ & kTokenRequestStatus_Mask) == kTokenRequestStatus_None) + token_request_type_ |= kTokenRequestStatus_NotAccepted; + waiting_for_token_ = true; + pp_->plugin->OnTokenRequest(pp_); + } +} + +// Decode a base32 string, skip whitespace and =. +// returns 0 on failure. +static size_t DecodeBase32String(const char *string, size_t string_len, uint8 *output, size_t output_len) { + // static const char kBase32Charset[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"; + size_t n = 0; + uint32 bitbuff = 0, nbits = 0, v; + for (size_t i = 0; i < string_len; i++) { + uint8 c = string[i]; + if (c >= '2' && c <= '7') { + v = c - '2' + 26; + } else if ((c | 32) >= 'a' && (c | 32) <= 'z') { + v = (c | 32) - 'a'; + } else if (c == ' ' || c == '=' || c == '\t') { + continue; + } else { + return 0; + } + bitbuff = bitbuff * 32 + v; + nbits += 5; + if (nbits >= 8) { + nbits -= 8; + if (n == output_len) + return 0; + output[n++] = (uint8)(bitbuff >> nbits); + } + } + return n; +} + +// RequireToken=totp-sha1:ALPHABETAGAMMAOSCAR,digits=6,period=30,precision=15 +bool TotpTokenAuthenticator::Initialize(const char *config) { + const char *end = config + strlen(config); + const char *comma = strchr(config, ','); + size_t rv = DecodeBase32String(config, (comma ? comma : end) - config, secret_, sizeof(secret_)); + if (!rv) + return false; + secret_size_ = (uint8)rv; + + bool has_precision = false; + while (comma) { + comma += 1; + while (*comma == ' ') comma++; + if (strncmp(comma, "digits=", 7) == 0) { + int v = atoi(comma + 7); + if (v < 6 || v > 8) + return false; + digits_ = (uint8)v; + } else if (strncmp(comma, "period=", 7) == 0) { + int v = atoi(comma + 7); + if (v < 1 || v > 3600) + return false; + window_size_ = (uint16)v; + } else if (strncmp(comma, "precision=", 10) == 0) { + int v = atoi(comma + 10); + if (v < 0 || v > 3600) + return false; + has_precision = true; + precision_ = (uint16)v; + } else if (strncmp(comma, "reuse=0", 7) == 0) { + block_reuse_ = true; + } else { + return false; + } + comma = strchr(comma, ','); + } + if (!has_precision) + precision_ = window_size_ >> 1; + + return true; +} + +extern int memcmp_crypto(const uint8 *a, const uint8 *b, size_t n); + +uint32 TotpTokenAuthenticator::GetValueForTimestamp(uint64 now) { + uint8 hmacbuf[20]; + SHA1HmacContext hmac; + SHA1HmacReset(&hmac, secret_, secret_size_); + uint8 timebuf[8]; + WriteBE64(timebuf, now); + SHA1HmacInput(&hmac, timebuf, 8); + SHA1HmacFinish(&hmac, hmacbuf); + uint32 tmp; + memcpy(&tmp, hmacbuf + (hmacbuf[19] & 0xF), 4); + uint32 value = ReadBE32(&tmp) & 0x7FFFFFFF; + switch (digits_) { + case 6: value %= 1000000; break; + case 7: value %= 10000000; break; + case 8: value %= 100000000; break; + } + return value; +} + +bool TotpTokenAuthenticator::Authenticate(const uint8 *data, size_t size, uint64 *code_out) { + uint64 now = time(NULL); + uint64 first_period = (now - precision_) / window_size_; + uint64 last_period = (now + precision_) / window_size_; + for (; first_period <= last_period; first_period++) { + char buf[16]; + int r = snprintf(buf, sizeof(buf), "%.*u", digits_, GetValueForTimestamp(first_period)); +// RINFO("Checking if %.*s equals %s", (int)size, data, buf); + if (r == size && memcmp_crypto((uint8*)buf, data, size) == 0) { + // Disable code reuse if requested. + if (block_reuse_ && first_period < next_allowed_code_) + return false; + + next_allowed_code_ = first_period + 1; + *code_out = first_period; + return true; + } + } + return false; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +TokenServerHandler::TokenServerHandler() { + num_failures_ = 0; + has_session_id_ = false; + is_session_id_authed_ = false; + stop_reconnects_ = false; + last_attempt_ = 0; + token_bucket_ = kTokenBucketFull; + reset_recovery_counter_ = 0; + reset_recovery_last_code_ = 0; + last_cksum_ = 0; + cksum_equal_timestamp_ = 0; + last_login_status_ = 0; +} + +TokenServerHandler::~TokenServerHandler() { +} + +bool TokenServerHandler::OnHandshake(uint8 *token_reply, int token_reply_size, bool has_valid_session_id, ExtFieldWriter &writer, const siphash_key_t *siphash_key) { + // Tokens not required for this peer? + if (!token_authenticator_.configured()) + return true; + // check that the client has a valid token session, otherwise block login. + if (!has_valid_session_id) { + OsGetRandomBytes(session_id_, sizeof(session_id_)); + writer.WriteField(kExtensionType_SetSessionID, session_id_, WG_SESSION_ID_LEN); + is_session_id_authed_ = false; + has_session_id_ = true; +request_token: + // TODO: Make it optional to reveal last_login_status_ + uint8 data[WG_SYMMETRIC_KEY_LEN + 2]; + OsGetRandomBytes(token_crypto_key_, sizeof(token_crypto_key_)); + data[WG_SYMMETRIC_KEY_LEN] = authentication_type_; + data[WG_SYMMETRIC_KEY_LEN + 1] = last_login_status_; + memcpy(data, token_crypto_key_, WG_SYMMETRIC_KEY_LEN); + writer.WriteField(kExtensionType_TokenRequest, data, sizeof(data)); + writer.BlockLogin(); + return true; + } + + if (!is_session_id_authed_ && token_reply && token_reply_size >= 16) { + // allow only so many attempts per second + uint64 now = OsGetMilliseconds(), code_out; + uint64 secs = (now - last_attempt_) >> 10; + last_attempt_ += secs << 10; + token_bucket_ = (uint8)std::min(token_bucket_ + secs, kTokenBucketFull); + + bool authenticated = false; + // Decrypt and verify the supplied key. If this fails, we're likely using an old key. + if (chacha20poly1305_decrypt(token_reply, token_reply, token_reply_size, NULL, 0, 0, token_crypto_key_)) { + authenticated = token_authenticator_.Authenticate(token_reply, token_reply_size - 16, &code_out); + + // Account is locked after 10 failed attempts. To unlock the account, you need to login 3 times successfully + // in a row, with distincts, increasing codes, with no failed attempts in between. + if (num_failures_ >= kAttemptsUntilLockout) { + if (authenticated && code_out > reset_recovery_last_code_) { + reset_recovery_last_code_ = code_out; + if (reset_recovery_counter_++ == kAttemptsUntilLockoutRemoved - 1) { + RINFO("Account unlocked."); + num_failures_ = 0; + reset_recovery_counter_ = 0; + token_bucket_ = kTokenBucketFull; + } else { + authenticated = false; + } + } else { + reset_recovery_counter_ = 0; + authenticated = false; + } + } else { + // Check if the password is the same as the previous attempt, this could indicate a retransmission of the packet. + // Don't increase num_failures_ based on this, but after a minute, force authenticated to false. + uint64 cksum = siphash(token_reply, token_reply_size - 16, siphash_key); + if (cksum == last_cksum_) { + if (now >= cksum_equal_timestamp_ + 60000) { + authenticated = false; + } else { + if (authenticated) + num_failures_ = 0; + } + } else { + cksum_equal_timestamp_ = now; + last_cksum_ = cksum; + num_failures_ = authenticated ? 0 : num_failures_ + 1; + if (num_failures_ == kAttemptsUntilLockout) + RINFO("Account locked because of %d failed login attempts.", num_failures_); + } + } + } + last_login_status_ = (num_failures_ >= kAttemptsUntilLockout) ? (kTokenRequestStatus_Locked >> 8) : + (token_bucket_ >= kTokenBucketCost) ? (kTokenRequestStatus_Wrong >> 8) : + (kTokenRequestStatus_Ratelimit >> 8); + // Fail when toket bucket is exceeded + if (token_bucket_ >= kTokenBucketCost) { + token_bucket_ -= kTokenBucketCost; + } else { + authenticated = false; + } + is_session_id_authed_ = authenticated; + } + + if (!is_session_id_authed_) + goto request_token; + + last_login_status_ = 0; + stop_reconnects_ = false; + return true; +} + +bool TokenServerHandler::OnHandshake2(bool has_valid_session_id) { + // Tokens not required for this peer? + if (!token_authenticator_.configured()) + return true; + // Only allow configuration in this direction if a valid session key was provided. + if (has_valid_session_id && is_session_id_authed_) + return true; + + // Stop further reconnections to this peer until further notice. + stop_reconnects_ = true; + return false; +} + + +bool TokenServerHandler::OnUnknownPeerSetting(const char *key, const char *value) { + if (strcmp(key, "RequireToken") != 0) + return false; + + if (strncmp(value, "totp-sha1:", 10) != 0) + return false; + + if (!token_authenticator_.Initialize(value + 10)) + return false; + authentication_type_ = kTokenRequestType_6digits + (token_authenticator_.digits() - 6); + return true; +} + +bool TokenServerHandler::VerifySessionId(const uint8 session_id_auth[WG_SESSION_AUTH_LEN], const uint8 salt[WG_PUBLIC_KEY_LEN]) { + if (!has_session_id_) + return false; + uint8 buf[WG_SESSION_AUTH_LEN]; + blake2s(buf, WG_SESSION_AUTH_LEN, salt, WG_PUBLIC_KEY_LEN, session_id_, sizeof(session_id_)); + return memcmp_crypto(buf, session_id_auth, WG_SESSION_AUTH_LEN) == 0; +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +PluginPeer *TunsafePluginImpl::GetPluginPeer(WgPeer *peer) { + PluginPeer *rv = (PluginPeer *)peer->extradata(); + if (!rv) { + rv = new PluginPeer(this, peer); + peer->SetExtradata(rv); + } + return rv; +} + +bool TunsafePluginImpl::WantHandshake(WgPeer *peer) { + PluginPeer *pp = GetPluginPeer(peer); + return pp->token_server_handler.WantHandshake() && + pp->token_client_handler.WantHandshake(); +} + +// This runs on client and appends data +uint32 TunsafePluginImpl::OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) { + PluginPeer *pp = GetPluginPeer(peer); + ExtFieldWriter writer(extout, extout_size); + pp->token_client_handler.OnHandshakeCreate(peer, writer, salt); + return writer.length(); +} + +// This runs on the server to parse init and send response +uint32 TunsafePluginImpl::OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt_in[WG_PUBLIC_KEY_LEN], uint8 *extout, uint32 extout_size, const uint8 salt_out[WG_PUBLIC_KEY_LEN]) { + PluginPeer *pp = GetPluginPeer(peer); + ExtFieldWriter writer(extout, extout_size); + + bool has_valid_session_id = false; + uint8 *token_reply = NULL; + uint8 token_reply_size = 0; + + while (ext_size >= 2) { + uint8 type = ext[0], size = ext[1]; + ext += 2, ext_size -= 2; + if (size > ext_size) + return false; + switch (type) { + case kExtensionType_SessionIDAuth: + if (size == WG_SESSION_AUTH_LEN) + has_valid_session_id = pp->token_server_handler.VerifySessionId(ext, salt_in); + break; + + case kExtensionType_TokenReply: + token_reply = (uint8*)ext; + token_reply_size = size; + break; + } + ext += size, ext_size -= size; + } + if (ext_size != 0) + return kHandshakeResponseDrop; + + // If this is a handshake in the other direction, also include session id. + pp->token_client_handler.WriteSessionId(writer, salt_out); + + if (!pp->token_server_handler.OnHandshake(token_reply, token_reply_size, has_valid_session_id, writer, &siphash_key_)) + return kHandshakeResponseDrop; + + return writer.length() + writer.fail_flag() * WgPlugin::kHandshakeResponseFail; +} + +// This runs on client and parses response +uint32 TunsafePluginImpl::OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) { + PluginPeer *pp = GetPluginPeer(peer); + + bool has_valid_session_id = false; + + while (ext_size >= 2) { + uint8 type = ext[0], size = ext[1]; + ext += 2, ext_size -= 2; + if (size > ext_size) + return false; + switch (type) { + case kExtensionType_SessionIDAuth: + if (size == WG_SESSION_AUTH_LEN) + has_valid_session_id = pp->token_server_handler.VerifySessionId(ext, salt); + break; + // All token requests mean that handshake has failed. + case kExtensionType_TokenRequest: + pp->token_client_handler.OnTokenRequest(ext, size); + return kHandshakeResponseDrop; + case kExtensionType_SetSessionID: + if (size == WG_SESSION_ID_LEN) + pp->token_client_handler.SetSessionId(ext); + break; + } + ext += size, ext_size -= size; + } + if (ext_size != 0) + return kHandshakeResponseDrop; + + // Stop outgoing handshakes if client didn't supply a valid session id. + if (!pp->token_server_handler.OnHandshake2(has_valid_session_id)) + return kHandshakeResponseDrop; + + pp->token_client_handler.OnHandshakeComplete(); + return 0; +} + +bool TunsafePluginImpl::OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) { + PluginPeer *pp = GetPluginPeer(peer); + return pp->token_server_handler.OnUnknownPeerSetting(key, value); +} + +void TunsafePluginImpl::OnTokenRequest(PluginPeer *peer) { + if (peer_doing_2fa_ != NULL) + return; + peer_doing_2fa_ = peer; + delegate_->OnRequestToken(peer->peer, peer->token_client_handler.token_request()); +} + +void TunsafePluginImpl::SubmitToken(const uint8 *text, size_t text_len) { + if (peer_doing_2fa_ == NULL) + return; + assert(peer_doing_2fa_->peer->dev()->IsMainThread()); + peer_doing_2fa_->token_client_handler.SetToken(text, text_len); + proc_->ForceSendHandshakeInitiation(peer_doing_2fa_->peer); + peer_doing_2fa_ = NULL; + + // Find the next peer requiring a token + for (WgPeer *peer = proc_->dev().first_peer(); peer; peer = peer->next_peer()) { + PluginPeer *pp = (PluginPeer *)peer->extradata(); + if (!pp) + continue; + if (pp->token_client_handler.waiting_for_token()) { + OnTokenRequest(pp); + return; + } + } +} + +TunsafePlugin *CreateTunsafePlugin(PluginDelegate *delegate) { + return new TunsafePluginImpl(delegate); +} + diff --git a/tunsafe_wg_plugin.h b/tunsafe_wg_plugin.h new file mode 100644 index 0000000..73cbc01 --- /dev/null +++ b/tunsafe_wg_plugin.h @@ -0,0 +1,46 @@ +#pragma once +#include "wireguard_proto.h" + + +enum { + kTokenRequestType_Mask = 0xff, + kTokenRequestType_Text = 1, + kTokenRequestType_Password = 2, + kTokenRequestType_6digits = 3, + kTokenRequestType_7digits = 4, + kTokenRequestType_8digits = 5, + + + kTokenRequestStatus_Mask = 0xff00, + kTokenRequestStatus_None = 0x0000, + kTokenRequestStatus_NotAccepted = 0x0100, + kTokenRequestStatus_Wrong = 0x0200, + kTokenRequestStatus_Locked = 0x0300, + kTokenRequestStatus_Ratelimit = 0x0400, +}; + + + +class PluginDelegate { +public: + // Called when 2FA is requested for a peer. + virtual void OnRequestToken(WgPeer *peer, uint32 type) = 0; +}; + +class TunsafePlugin : public WgPlugin { +public: + enum { + kMaxTokenLen = 128, + }; + + + void Initialize(WireguardProcessor *proc) { proc_ = proc; } + + // Called after OnRequest2FA to supply the token. + virtual void SubmitToken(const uint8 *text, size_t text_len) = 0; + +protected: + WireguardProcessor *proc_; +}; + +TunsafePlugin *CreateTunsafePlugin(PluginDelegate *del); diff --git a/tunsafe_win32.cpp b/tunsafe_win32.cpp index b9bd86b..583c9ae 100644 --- a/tunsafe_win32.cpp +++ b/tunsafe_win32.cpp @@ -4,6 +4,7 @@ #include "wireguard_config.h" #include "network_win32_api.h" #include "network_win32_dnsblock.h" +#include "tunsafe_wg_plugin.h" #include #include #include @@ -34,6 +35,8 @@ void InitCpuFeatures(); void PrintCpuFeatures(); void Benchmark(); +void ShowTwoFactorDialog(); + static const char *GetCurrentConfigTitle(char *buf, size_t max_size); static char *PrintMB(char *buf, int64 bytes); static void LoadConfigFile(const char *filename, bool save, bool force_start); @@ -73,6 +76,8 @@ static UINT g_message_taskbar_created; static int g_current_tab; static bool wm_dropfiles_recursive; static bool g_has_icon; +static bool g_twofactor_dialog_shown; +static uint32 g_twofactor_dialog_request; static int g_selected_graph_type; static RECT comborect; static HBITMAP arrowbitmap; @@ -228,6 +233,15 @@ public: SetDlgItemText(g_ui_window, ID_START, running ? "Re&connect" : "&Connect"); InvalidatePaintbox(); EnableWindow(GetDlgItem(g_ui_window, ID_STOP), running); + + if (running && !g_twofactor_dialog_shown) { + uint32 token_request = g_backend->GetTokenRequest(); + if (token_request != 0) { + g_twofactor_dialog_shown = true; + g_twofactor_dialog_request = token_request; + PostMessage(g_ui_window, WM_USER + 3, 0, 0); // show two factor dialog + } + } } virtual void OnStatusCode(TunsafeBackend::StatusCode status) override { @@ -1045,6 +1059,10 @@ static INT_PTR WINAPI DlgProc(HWND hWnd, UINT message, WPARAM wParam, g_backend_delegate->DoWork(); return true; + case WM_USER + 3: + ShowTwoFactorDialog(); + return true; + case WM_INITMENU: { HMENU menu = GetMenu(g_ui_window); @@ -1666,6 +1684,189 @@ static LRESULT CALLBACK AdvancedBoxWndProc(HWND hwnd, UINT uMsg, WPARAM wParam return DefWindowProc(hwnd, uMsg, wParam, lParam); } +static char twofactordig[9]; +static uint8 twofactornum, twofactormax; + +static void DrawInTwoFactorBox(HDC hdc, int w, int h) { + RECT rect = {0, 0, w, h}; + FillRect(hdc, &rect, (HBRUSH)(COLOR_3DFACE + 1)); + + HPEN dc_pen = (HPEN)GetStockObject(DC_PEN); + HGDIOBJ original = SelectObject(hdc, dc_pen); + + HFONT font = CreateFontHelper(32, 0, "Tahoma"); + SelectObject(hdc, font); + SetBkMode(hdc, TRANSPARENT); + + HPEN sel_pen = CreatePen(PS_SOLID, RescaleDpi(3), GetSysColor(COLOR_HIGHLIGHT)); + SetDCPenColor(hdc, GetSysColor(COLOR_3DSHADOW)); + + int item_width = 35, item_spacing = 43, n = 6, xmarg = 15, middle_spacing = 12; + if (twofactormax == 8) { + item_width = 32; + xmarg = 2; + item_spacing = 36; + middle_spacing = 6; + } else if (twofactormax == 7) { + item_width = 35; + item_spacing = 41; + xmarg = 6; + middle_spacing = 0; + } + int radius = RescaleDpi(10); + for (int i = 0; i < twofactormax; i++) { + int x = xmarg + item_spacing * i + (i * 2 >= twofactormax) * middle_spacing; + + SelectObject(hdc, i == twofactornum ? sel_pen : dc_pen); + RECT r2 = {x, 5, x + item_width, 5 + 42}; + RECT r = RescaleDpiRect(r2); + RoundRect(hdc, r.left, r.top, r.right, r.bottom, radius, radius); + + if (i < twofactornum) + DrawText(hdc, twofactordig + i, 1, &r, DT_CENTER | DT_VCENTER | DT_SINGLELINE | DT_NOPREFIX | DT_NOCLIP); + } + + DeleteObject(font); + DeleteObject(sel_pen); + SelectObject(hdc, original); +} + +static LRESULT CALLBACK TwoFactorEditFieldWndProc(HWND hwnd, UINT uMsg, WPARAM wParam, LPARAM lParam) { + switch (uMsg) { + case WM_PAINT: { + HandleWmPaintPaintbox(hwnd, &DrawInTwoFactorBox); + return TRUE; + } + case WM_GETDLGCODE: + return DLGC_WANTCHARS; + case WM_KEYDOWN: + if (wParam >= '0' && wParam <= '9' && twofactornum < twofactormax) { + twofactordig[twofactornum++] = (char)wParam; + if (twofactornum == twofactormax) { + twofactordig[twofactornum] = 0; + g_backend->SubmitToken(std::string(twofactordig)); + SendMessage(GetParent(hwnd), WM_CLOSE, 0, 0); + } else { + InvalidateRect(hwnd, NULL, FALSE); + } + return FALSE; + } else if (wParam == VK_BACK) { + if (twofactornum > 0) { + twofactornum--; + InvalidateRect(hwnd, NULL, FALSE); + } + return FALSE; + } else if (wParam == 'V' && GetAsyncKeyState(VK_CONTROL) < 0) { + if (twofactornum == 0) { + std::string digits = GetClipboardString(); + if (digits.size() == twofactormax) { + g_backend->SubmitToken(std::move(digits)); + SendMessage(GetParent(hwnd), WM_CLOSE, 0, 0); + } + } + return FALSE; + } + break; + case WM_CHAR: + return FALSE; + + case WM_ERASEBKGND: + return TRUE; + + } + return DefWindowProc(hwnd, uMsg, wParam, lParam); +} + +static INT_PTR WINAPI TwoFactorDlgProc(HWND hWnd, UINT message, WPARAM wParam, + LPARAM lParam) { + static HFONT twofactorfont; + + switch (message) { + case WM_INITDIALOG: { + uint32 failreason = g_twofactor_dialog_request & kTokenRequestStatus_Mask; + if (failreason) { + size_t index = (failreason >> 8) - 1; + + static const char * const kFailReasons[] = { + "Code Not Accepted. Please try again.", + "Incorrect code. Please try again.", + "Account locked.", + "Rate limited. Please wait 30 seconds.", + }; + + HWND label = GetDlgItem(hWnd, IDC_CODENOTACCEPTED); + SetWindowText(label, kFailReasons[index >= ARRAYSIZE(kFailReasons) ? 0 : index]); + ShowWindow(label, SW_SHOW); + } + + int type = g_twofactor_dialog_request & kTokenRequestType_Mask; + if (type >= kTokenRequestType_6digits && type <= kTokenRequestType_8digits) { + twofactormax = (uint8)(type - kTokenRequestType_6digits + 6); + } else { + if (type != kTokenRequestType_Password) + SendDlgItemMessage(hWnd, IDC_TWOFACTOREDIT, EM_SETPASSWORDCHAR, 0, 0); + + twofactorfont = CreateFontHelper(20, 0, "Tahoma", 0); + SendDlgItemMessage(hWnd, IDC_TWOFACTOREDIT, WM_SETFONT, (WPARAM)twofactorfont, 0); + } + return TRUE; + } + + case WM_DESTROY: + if (twofactorfont) + DeleteObject(exch_null(twofactorfont)); + return FALSE; + + + case WM_CTLCOLORSTATIC: + if (GetWindowLong((HWND)lParam, GWL_ID) == IDC_CODENOTACCEPTED) { + SetTextColor((HDC)wParam, RGB(255, 0, 0)); + SetBkMode((HDC)wParam, TRANSPARENT); + return (LRESULT)GetSysColorBrush(COLOR_3DFACE); + } + break; + + case WM_CLOSE: + EndDialog(hWnd, 0); + g_twofactor_dialog_shown = false; + return TRUE; + case WM_COMMAND: + switch (wParam) { + case IDCANCEL: + EndDialog(hWnd, 0); + g_twofactor_dialog_shown = false; + return TRUE; + case IDOK: { + wchar_t buf[TunsafePlugin::kMaxTokenLen + 1]; + char utf8buf[TunsafePlugin::kMaxTokenLen + 1]; + buf[0] = 0; + int nw = GetDlgItemTextW(hWnd, IDC_TWOFACTOREDIT, buf, ARRAYSIZE(buf)) + 1; + int nutf8 = WideCharToMultiByte(CP_UTF8, 0, buf, nw, utf8buf, ARRAYSIZE(utf8buf), 0, NULL); + if (nutf8) { + g_backend->SubmitToken(std::string(utf8buf)); + EndDialog(hWnd, 0); + g_twofactor_dialog_shown = false; + return TRUE; + } + } + } + break; + } + return FALSE; +} + +void ShowTwoFactorDialog() { + twofactornum = 0; + int type = g_twofactor_dialog_request & kTokenRequestType_Mask; + int dialog; + if (type >= kTokenRequestType_6digits && type <= kTokenRequestType_8digits) { + dialog = IDD_DIALOG3; + } else { + dialog = IDD_DIALOG4; + } + DialogBox(g_hinstance, MAKEINTRESOURCE(dialog), g_ui_window, &TwoFactorDlgProc); +} + void InitializeClass(WNDPROC wndproc, const char *name) { WNDCLASSEX wce = {0}; wce.cbSize = sizeof(wce); @@ -1687,6 +1888,7 @@ static bool CreateMainWindow() { InitializeClass(&PaintBoxWndProc, "PaintBox"); InitializeClass(&GraphBoxWndProc, "GraphBox"); InitializeClass(&AdvancedBoxWndProc, "AdvancedBox"); + InitializeClass(&TwoFactorEditFieldWndProc, "TwoFactorEditField"); HDC dc = GetDC(0); g_large_fonts = GetDeviceCaps(dc, LOGPIXELSX); diff --git a/wireguard.cpp b/wireguard.cpp index 84a8af5..6c218dc 100644 --- a/wireguard.cpp +++ b/wireguard.cpp @@ -630,14 +630,14 @@ void WireguardProcessor::ForceSendHandshakeInitiation(WgPeer *peer) { void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) { assert(dev_.IsMainThread()); - if (!peer->CheckHandshakeRateLimit() || peer->endpoint_.sin.sin_family == 0) + if (!peer->CheckHandshakeRateLimit() || + peer->endpoint_.sin.sin_family == 0 || + (dev_.plugin_ && !dev_.plugin_->WantHandshake(peer))) return; Packet *packet = AllocPacket(); if (packet) { - if (!peer->CreateMessageHandshakeInitiation(packet)) { - FreePacket(packet); - return; - } + peer->CreateMessageHandshakeInitiation(packet); + stats_.handshakes_out++; WG_ACQUIRE_LOCK(peer->mutex_); int attempts = ++peer->total_handshake_attempts_; diff --git a/wireguard_proto.cpp b/wireguard_proto.cpp index d2e8437..5d13fb8 100644 --- a/wireguard_proto.cpp +++ b/wireguard_proto.cpp @@ -520,23 +520,12 @@ void WgPeer::SetPresharedKey(const uint8 preshared_key[WG_SYMMETRIC_KEY_LEN]) { } // run on the client -bool WgPeer::CreateMessageHandshakeInitiation(Packet *packet) { +void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) { assert(dev_->IsMainThread()); uint8 k[WG_SYMMETRIC_KEY_LEN]; MessageHandshakeInitiation *dst = (MessageHandshakeInitiation *)packet->data; - int extfield_size = 0; - if (WITH_HANDSHAKE_EXT && supports_handshake_extensions_) - extfield_size = WriteHandshakeExtension(dst->timestamp_enc + WG_TIMESTAMP_LEN, NULL); - - if (dev_->plugin_) { - uint32 rv = dev_->plugin_->OnHandshake0(this, dst->timestamp_enc + WG_TIMESTAMP_LEN + extfield_size, MAX_SIZE_OF_HANDSHAKE_EXTENSION - extfield_size); - if (rv & WgPlugin::kHandshakeResponseFail) - return false; - extfield_size += rv; - } - // Ci := HASH(CONSTRUCTION) memcpy(hs_.ci, kWgInitChainingKey, sizeof(hs_.ci)); // Hi := HASH(Ci || IDENTIFIER) @@ -563,6 +552,17 @@ bool WgPeer::CreateMessageHandshakeInitiation(Packet *packet) { // TAI64N OsGetTimestampTAI64N(dst->timestamp_enc); + + int extfield_size = 0; + if (WITH_HANDSHAKE_EXT && supports_handshake_extensions_) + extfield_size = WriteHandshakeExtension(dst->timestamp_enc + WG_TIMESTAMP_LEN, NULL); + + if (dev_->plugin_) { + uint32 rv = dev_->plugin_->OnHandshake0(this, dst->timestamp_enc + WG_TIMESTAMP_LEN + extfield_size, MAX_SIZE_OF_HANDSHAKE_EXTENSION - extfield_size, dst->ephemeral); + assert(!(rv & WgPlugin::kHandshakeResponseFail)); + extfield_size += rv; + } + // msg.timestamp := AEAD(K, 0, timestamp, hi) chacha20poly1305_encrypt(dst->timestamp_enc, dst->timestamp_enc, extfield_size + WG_TIMESTAMP_LEN, hs_.hi, sizeof(hs_.hi), 0, k); // Hi := HASH(Hi || msg.timestamp) @@ -574,7 +574,6 @@ bool WgPeer::CreateMessageHandshakeInitiation(Packet *packet) { dst->type = MESSAGE_HANDSHAKE_INITIATION; memzero_crypto(k, sizeof(k)); WriteMacToPacket((uint8*)dst, (MessageMacs*)((uint8*)&dst->mac + extfield_size)); - return true; } // Parsed by server @@ -683,8 +682,8 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { // Allow plugin to determine what to do with the packet, // it can append new headers to the response, and decide what to do. if (dev->plugin_) { - uint32 rv = dev->plugin_->OnHandshake1(peer, extbuf + WG_TIMESTAMP_LEN, extfield_size, - dst->empty_enc + extfield_out_size, MAX_SIZE_OF_HANDSHAKE_EXTENSION - extfield_out_size); + uint32 rv = dev->plugin_->OnHandshake1(peer, extbuf + WG_TIMESTAMP_LEN, extfield_size, e_remote, + dst->empty_enc + extfield_out_size, MAX_SIZE_OF_HANDSHAKE_EXTENSION - extfield_out_size, dst->ephemeral); if (rv == WgPlugin::kHandshakeResponseDrop) goto getout; if (rv & WgPlugin::kHandshakeResponseFail) @@ -770,7 +769,7 @@ WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packe // Allow plugin to determine what to do with the packet, // it can append new headers to the response, and decide what to do. if (dev->plugin_) { - uint32 rv = dev->plugin_->OnHandshake2(peer, src->empty_enc, extfield_size); + uint32 rv = dev->plugin_->OnHandshake2(peer, src->empty_enc, extfield_size, src->ephemeral); if (rv & WgPlugin::kHandshakeResponseFail) { delete keypair; goto getout; diff --git a/wireguard_proto.h b/wireguard_proto.h index 101e2b4..6438003 100644 --- a/wireguard_proto.h +++ b/wireguard_proto.h @@ -318,18 +318,22 @@ public: virtual bool OnUnknownInterfaceSetting(const char *key, const char *value) = 0; virtual bool OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) = 0; + // Returns true if we want to perform a handshake for this peer. + virtual bool WantHandshake(WgPeer *peer) = 0; + + // Called right before handshake initiation is sent out. Can't drop packets. + virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0; + // Called after handshake initiation is parsed, but before handshake response is sent. + enum { kHandshakeResponseDrop = 0xffffffff, kHandshakeResponseFail = 0x80000000 }; - - // Called right before handshake initiation is sent out. Can be dropped. - virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size) = 0; - // Called after handshake initiation is parsed, but before handshake response is sent. // Packet can be dropped or keypair failed. - virtual uint32 OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, uint8 *extout, uint32 extout_size) = 0; + virtual uint32 OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt_in[WG_PUBLIC_KEY_LEN], uint8 *extout, uint32 extout_size, const uint8 salt_out[WG_PUBLIC_KEY_LEN]) = 0; + // Called when handshake response is parsed - virtual uint32 OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size) = 0; + virtual uint32 OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) = 0; }; @@ -491,7 +495,7 @@ public: static WgPeer *ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet); static WgPeer *ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packet); static void ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCookie *src); - bool CreateMessageHandshakeInitiation(Packet *packet); + void CreateMessageHandshakeInitiation(Packet *packet); bool CheckSwitchToNextKey_Locked(WgKeypair *keypair); void RemovePeer(); bool CheckHandshakeRateLimit();