Allow TOTP to be turned on/off

This commit is contained in:
Ludvig Strigeus 2018-12-10 23:48:47 +01:00
parent 008dc6c785
commit f7b09c43fd
4 changed files with 85 additions and 53 deletions

View file

@ -15,11 +15,13 @@ enum {
WG_SESSION_AUTH_LEN = 16, WG_SESSION_AUTH_LEN = 16,
}; };
enum {
WITH_TWO_FACTOR_AUTHENTICATION = 1,
};
class PluginPeer; class PluginPeer;
class TunsafePluginImpl; class TunsafePluginImpl;
static const char kTwoFactorTokenTag[] = "Two-Factor Token";
class ExtFieldWriter { class ExtFieldWriter {
public: public:
ExtFieldWriter(uint8 *target, uint32 target_size) : target_(target), target_size_(target_size), target_pos_(0), fail_flag_(false) { } ExtFieldWriter(uint8 *target, uint32 target_size) : target_(target), target_size_(target_size), target_pos_(0), fail_flag_(false) { }
@ -186,10 +188,10 @@ public:
class TunsafePluginImpl : public TunsafePlugin { class TunsafePluginImpl : public TunsafePlugin {
friend class PluginPeer; friend class PluginPeer;
public: public:
TunsafePluginImpl(PluginDelegate *del) { TunsafePluginImpl(PluginDelegate *del, WireguardProcessor *proc) {
delegate_ = del; delegate_ = del;
proc_ = proc;
peer_doing_2fa_ = NULL; peer_doing_2fa_ = NULL;
proc_ = NULL;
OsGetRandomBytes((uint8*)&siphash_key_, sizeof(siphash_key_)); OsGetRandomBytes((uint8*)&siphash_key_, sizeof(siphash_key_));
} }
@ -204,19 +206,25 @@ public:
private: private:
virtual bool HandleUnknownPeerId(uint8 public_key[WG_PUBLIC_KEY_LEN], Packet *packet) override { return false; } virtual bool HandleUnknownPeerId(uint8 public_key[WG_PUBLIC_KEY_LEN], Packet *packet) override { return false; }
virtual bool OnUnknownInterfaceSetting(const char *key, const char *value) override { return false; } virtual bool OnUnknownInterfaceSetting(const char *key, const char *value) override;
virtual bool OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) override; virtual bool OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) override;
virtual bool WantHandshake(WgPeer *peer) override; virtual bool WantHandshake(WgPeer *peer) override;
virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) override; virtual uint32 OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) override;
virtual uint32 OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt_in[WG_PUBLIC_KEY_LEN], uint8 *extout, uint32 extout_size, const uint8 salt_out[WG_PUBLIC_KEY_LEN]) override; virtual uint32 OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt_in[WG_PUBLIC_KEY_LEN], uint8 *extout, uint32 extout_size, const uint8 salt_out[WG_PUBLIC_KEY_LEN]) override;
virtual uint32 OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) override; virtual uint32 OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ext_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) override;
virtual bool OnAfterSettingsParsed() override;
virtual void OnOutgoingHandshakePacket(WgPeer *peer, Packet *packet) override;
PluginPeer *GetPluginPeer(WgPeer *peer); PluginPeer *GetPluginPeer(WgPeer *peer);
virtual void SubmitToken(const uint8 *text, size_t text_len) override; virtual void SubmitToken(const uint8 *text, size_t text_len) override;
WireguardProcessor *proc_;
PluginPeer *peer_doing_2fa_; PluginPeer *peer_doing_2fa_;
PluginDelegate *delegate_; PluginDelegate *delegate_;
siphash_key_t siphash_key_; siphash_key_t siphash_key_;
}; };
PluginPeer::~PluginPeer() { PluginPeer::~PluginPeer() {
@ -553,9 +561,9 @@ bool TokenServerHandler::VerifySessionId(const uint8 session_id_auth[WG_SESSION_
return memcmp_crypto(buf, session_id_auth, WG_SESSION_AUTH_LEN) == 0; return memcmp_crypto(buf, session_id_auth, WG_SESSION_AUTH_LEN) == 0;
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////
PluginPeer *TunsafePluginImpl::GetPluginPeer(WgPeer *peer) { PluginPeer *TunsafePluginImpl::GetPluginPeer(WgPeer *peer) {
PluginPeer *rv = (PluginPeer *)peer->extradata(); PluginPeer *rv = (PluginPeer *)peer->extradata();
if (!rv) { if (!rv) {
@ -567,15 +575,20 @@ PluginPeer *TunsafePluginImpl::GetPluginPeer(WgPeer *peer) {
bool TunsafePluginImpl::WantHandshake(WgPeer *peer) { bool TunsafePluginImpl::WantHandshake(WgPeer *peer) {
PluginPeer *pp = GetPluginPeer(peer); PluginPeer *pp = GetPluginPeer(peer);
return pp->token_server_handler.WantHandshake() && if (WITH_TWO_FACTOR_AUTHENTICATION) {
pp->token_client_handler.WantHandshake(); return pp->token_server_handler.WantHandshake() &&
pp->token_client_handler.WantHandshake();
} else {
return true;
}
} }
// This runs on client and appends data // This runs on client and appends data
uint32 TunsafePluginImpl::OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) { uint32 TunsafePluginImpl::OnHandshake0(WgPeer *peer, uint8 *extout, uint32 extout_size, const uint8 salt[WG_PUBLIC_KEY_LEN]) {
PluginPeer *pp = GetPluginPeer(peer); PluginPeer *pp = GetPluginPeer(peer);
ExtFieldWriter writer(extout, extout_size); ExtFieldWriter writer(extout, extout_size);
pp->token_client_handler.OnHandshakeCreate(peer, writer, salt); if (WITH_TWO_FACTOR_AUTHENTICATION)
pp->token_client_handler.OnHandshakeCreate(peer, writer, salt);
return writer.length(); return writer.length();
} }
@ -593,27 +606,31 @@ uint32 TunsafePluginImpl::OnHandshake1(WgPeer *peer, const uint8 *ext, uint32 ex
ext += 2, ext_size -= 2; ext += 2, ext_size -= 2;
if (size > ext_size) if (size > ext_size)
return false; return false;
switch (type) { if (WITH_TWO_FACTOR_AUTHENTICATION) {
case kExtensionType_SessionIDAuth: switch (type) {
if (size == WG_SESSION_AUTH_LEN) case kExtensionType_SessionIDAuth:
has_valid_session_id = pp->token_server_handler.VerifySessionId(ext, salt_in); if (size == WG_SESSION_AUTH_LEN)
break; has_valid_session_id = pp->token_server_handler.VerifySessionId(ext, salt_in);
break;
case kExtensionType_TokenReply: case kExtensionType_TokenReply:
token_reply = (uint8*)ext; token_reply = (uint8*)ext;
token_reply_size = size; token_reply_size = size;
break; break;
}
} }
ext += size, ext_size -= size; ext += size, ext_size -= size;
} }
if (ext_size != 0) if (ext_size != 0)
return kHandshakeResponseDrop; return kHandshakeResponseDrop;
// If this is a handshake in the other direction, also include session id. if (WITH_TWO_FACTOR_AUTHENTICATION) {
pp->token_client_handler.WriteSessionId(writer, salt_out); // If this is a handshake in the other direction, also include session id.
pp->token_client_handler.WriteSessionId(writer, salt_out);
if (!pp->token_server_handler.OnHandshake(token_reply, token_reply_size, has_valid_session_id, writer, &siphash_key_)) if (!pp->token_server_handler.OnHandshake(token_reply, token_reply_size, has_valid_session_id, writer, &siphash_key_))
return kHandshakeResponseDrop; return kHandshakeResponseDrop;
}
return writer.length() + writer.fail_flag() * WgPlugin::kHandshakeResponseFail; return writer.length() + writer.fail_flag() * WgPlugin::kHandshakeResponseFail;
} }
@ -629,47 +646,58 @@ uint32 TunsafePluginImpl::OnHandshake2(WgPeer *peer, const uint8 *ext, uint32 ex
ext += 2, ext_size -= 2; ext += 2, ext_size -= 2;
if (size > ext_size) if (size > ext_size)
return false; return false;
switch (type) {
case kExtensionType_SessionIDAuth: if (WITH_TWO_FACTOR_AUTHENTICATION) {
if (size == WG_SESSION_AUTH_LEN) switch (type) {
has_valid_session_id = pp->token_server_handler.VerifySessionId(ext, salt); case kExtensionType_SessionIDAuth:
break; if (size == WG_SESSION_AUTH_LEN)
// All token requests mean that handshake has failed. has_valid_session_id = pp->token_server_handler.VerifySessionId(ext, salt);
case kExtensionType_TokenRequest: break;
pp->token_client_handler.OnTokenRequest(ext, size); // All token requests mean that handshake has failed.
return kHandshakeResponseDrop; case kExtensionType_TokenRequest:
case kExtensionType_SetSessionID: pp->token_client_handler.OnTokenRequest(ext, size);
if (size == WG_SESSION_ID_LEN) return kHandshakeResponseDrop;
pp->token_client_handler.SetSessionId(ext); case kExtensionType_SetSessionID:
break; if (size == WG_SESSION_ID_LEN)
pp->token_client_handler.SetSessionId(ext);
break;
}
} }
ext += size, ext_size -= size; ext += size, ext_size -= size;
} }
if (ext_size != 0) if (ext_size != 0)
return kHandshakeResponseDrop; return kHandshakeResponseDrop;
// Stop outgoing handshakes if client didn't supply a valid session id. if (WITH_TWO_FACTOR_AUTHENTICATION) {
if (!pp->token_server_handler.OnHandshake2(has_valid_session_id)) // Stop outgoing handshakes if client didn't supply a valid session id.
return kHandshakeResponseDrop; if (!pp->token_server_handler.OnHandshake2(has_valid_session_id))
return kHandshakeResponseDrop;
pp->token_client_handler.OnHandshakeComplete(); pp->token_client_handler.OnHandshakeComplete();
}
return 0; return 0;
} }
bool TunsafePluginImpl::OnUnknownInterfaceSetting(const char *key, const char *value) {
return false;
}
bool TunsafePluginImpl::OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) { bool TunsafePluginImpl::OnUnknownPeerSetting(WgPeer *peer, const char *key, const char *value) {
PluginPeer *pp = GetPluginPeer(peer); PluginPeer *pp = GetPluginPeer(peer);
return pp->token_server_handler.OnUnknownPeerSetting(key, value); if (WITH_TWO_FACTOR_AUTHENTICATION && pp->token_server_handler.OnUnknownPeerSetting(key, value))
return true;
return false;
} }
void TunsafePluginImpl::OnTokenRequest(PluginPeer *peer) { void TunsafePluginImpl::OnTokenRequest(PluginPeer *peer) {
if (peer_doing_2fa_ != NULL) if (!WITH_TWO_FACTOR_AUTHENTICATION || peer_doing_2fa_ != NULL)
return; return;
peer_doing_2fa_ = peer; peer_doing_2fa_ = peer;
delegate_->OnRequestToken(peer->peer, peer->token_client_handler.token_request()); delegate_->OnRequestToken(peer->peer, peer->token_client_handler.token_request());
} }
void TunsafePluginImpl::SubmitToken(const uint8 *text, size_t text_len) { void TunsafePluginImpl::SubmitToken(const uint8 *text, size_t text_len) {
if (peer_doing_2fa_ == NULL) if (!WITH_TWO_FACTOR_AUTHENTICATION || peer_doing_2fa_ == NULL)
return; return;
assert(peer_doing_2fa_->peer->dev()->IsMainThread()); assert(peer_doing_2fa_->peer->dev()->IsMainThread());
peer_doing_2fa_->token_client_handler.SetToken(text, text_len); peer_doing_2fa_->token_client_handler.SetToken(text, text_len);
@ -688,7 +716,14 @@ void TunsafePluginImpl::SubmitToken(const uint8 *text, size_t text_len) {
} }
} }
TunsafePlugin *CreateTunsafePlugin(PluginDelegate *delegate) { bool TunsafePluginImpl::OnAfterSettingsParsed() {
return new TunsafePluginImpl(delegate); return true;
}
void TunsafePluginImpl::OnOutgoingHandshakePacket(WgPeer *peer, Packet *packet) {
}
TunsafePlugin *CreateTunsafePlugin(PluginDelegate *delegate, WireguardProcessor *wgp) {
return new TunsafePluginImpl(delegate, wgp);
} }

View file

@ -32,15 +32,8 @@ public:
enum { enum {
kMaxTokenLen = 128, kMaxTokenLen = 128,
}; };
void Initialize(WireguardProcessor *proc) { proc_ = proc; }
// Called after OnRequest2FA to supply the token. // Called after OnRequest2FA to supply the token.
virtual void SubmitToken(const uint8 *text, size_t text_len) = 0; virtual void SubmitToken(const uint8 *text, size_t text_len) = 0;
protected:
WireguardProcessor *proc_;
}; };
TunsafePlugin *CreateTunsafePlugin(PluginDelegate *del); TunsafePlugin *CreateTunsafePlugin(PluginDelegate *del, WireguardProcessor *wgp);

View file

@ -571,6 +571,8 @@ getout_discard:
void WireguardProcessor::PrepareOutgoingHandshakePacket(WgPeer *peer, Packet *packet) { void WireguardProcessor::PrepareOutgoingHandshakePacket(WgPeer *peer, Packet *packet) {
assert(dev_.IsMainThread()); assert(dev_.IsMainThread());
if (dev_.plugin_)
dev_.plugin_->OnOutgoingHandshakePacket(peer, packet);
stats_.udp_packets_out++; stats_.udp_packets_out++;
stats_.udp_bytes_out += packet->size; stats_.udp_bytes_out += packet->size;
} }

View file

@ -550,6 +550,8 @@ public:
void SetExtradata(WgPeerExtraData *ex) { peer_extra_data_ = ex; } void SetExtradata(WgPeerExtraData *ex) { peer_extra_data_ = ex; }
WgDevice *dev() { return dev_; } WgDevice *dev() { return dev_; }
const uint8 *epriv() { return hs_.e_priv; }
private: private:
bool ParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size_t data_size); bool ParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size_t data_size);
static WgKeypair *CreateNewKeypair(bool is_initiator, const uint8 key[WG_HASH_LEN], uint32 send_key_id); static WgKeypair *CreateNewKeypair(bool is_initiator, const uint8 key[WG_HASH_LEN], uint32 send_key_id);