From 83183b2193f3f72ac2a44be4f5c8d832e1b20d51 Mon Sep 17 00:00:00 2001 From: Ludvig Strigeus Date: Sun, 7 Oct 2018 19:36:52 +0200 Subject: [PATCH] Background thread for DNS resolve --- network_win32.cpp | 4 +- tunsafe_threading.cpp | 104 +++++++++++++++++++++++++++ tunsafe_threading.h | 46 ++++++++++++ wireguard_config.cpp | 162 ++++++++++++++++++++++++++++++++++++++---- wireguard_config.h | 20 +++++- 5 files changed, 317 insertions(+), 19 deletions(-) diff --git a/network_win32.cpp b/network_win32.cpp index c1fc397..b364c6c 100644 --- a/network_win32.cpp +++ b/network_win32.cpp @@ -2088,7 +2088,7 @@ void TunsafeBackendWin32::Stop() { void TunsafeBackendWin32::Start(const char *config_file) { StopInner(true); - dns_resolver_.SetAbortFlag(false); + dns_resolver_.ResetCancel(); is_started_ = true; memset(public_key_, 0, sizeof(public_key_)); SetStatus(kStatusInitializing); @@ -2107,7 +2107,7 @@ void TunsafeBackendWin32::PostExit(int exit_code) { void TunsafeBackendWin32::StopInner(bool is_restart) { if (worker_thread_) { ipv4_ip_ = 0; - dns_resolver_.SetAbortFlag(true); + dns_resolver_.Cancel(); PostExit(is_restart ? MODE_RESTART : MODE_EXIT); WaitForSingleObject(worker_thread_, INFINITE); CloseHandle(worker_thread_); diff --git a/tunsafe_threading.cpp b/tunsafe_threading.cpp index f7c64b7..54cd5e6 100644 --- a/tunsafe_threading.cpp +++ b/tunsafe_threading.cpp @@ -3,6 +3,110 @@ #include "stdafx.h" #include "tunsafe_threading.h" #include +#include + +#if defined(OS_POSIX) +Thread::Thread() { + thread_ = 0; +} + +Thread::~Thread() { + assert(thread_ == 0); +} + +static void *ThreadMainStatic(void *x) { + Thread::Runner *t = (Thread::Runner*)x; + t->ThreadMain(); + return 0; +} + +void Thread::StartThread(Runner *runner) { + assert(thread_ == 0); + if (pthread_create(&thread_, NULL, &ThreadMainStatic, runner) != 0) + tunsafe_die("pthread_create failed"); +} + +void Thread::StopThread() { + if (thread_) { + void *x; + pthread_join(thread_, &x); + thread_ = 0; + } +} + +void Thread::DetachThread() { + if (thread_) { + pthread_detach(thread_); + thread_ = 0; + } +} + +bool Thread::is_started() { + return thread_ != 0; +} + +void ConditionVariable::WaitTimed(Mutex *mutex, int millis) { +#if !defined(OS_MACOSX) + struct timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); + + ts.tv_sec += (millis / 1000); + ts.tv_nsec += (millis % 1000) * 1000000; + if (ts.tv_nsec >= 1000000000) { + ts.tv_nsec -= 1000000000; + ts.tv_sec++; + } + pthread_cond_timedwait(&condvar_, &mutex->lock_, &ts); +#else + struct timespec ts; + ts.tv_sec = millis / 1000; + ts.tv_nsec = (millis % 1000) * 1000000; + + pthread_cond_timedwait_relative_np(&condvar_, &mutex->lock_, &ts); +#endif +} +#endif // defined(OS_POSIX) + +#if defined(OS_WIN) +Thread::Thread() { + thread_ = 0; +} + +Thread::~Thread() { + assert(thread_ == 0); +} + +static DWORD WINAPI ThreadMainStatic(void *x) { + Thread::Runner *t = (Thread::Runner*)x; + t->ThreadMain(); + return 0; +} + +void Thread::StartThread(Runner *runner) { + assert(thread_ == 0); + DWORD thread_id; + thread_ = CreateThread(NULL, 0, &ThreadMainStatic, (LPVOID)runner, 0, &thread_id); +} + +void Thread::StopThread() { + if (thread_) { + WaitForSingleObject(thread_, INFINITE); + CloseHandle(thread_); + thread_ = 0; + } +} + +void Thread::DetachThread() { + if (thread_) { + CloseHandle(thread_); + thread_ = 0; + } +} + +bool Thread::is_started() { + return thread_ != 0; +} +#endif MultithreadedDelayedDelete::MultithreadedDelayedDelete() { table_ = NULL; diff --git a/tunsafe_threading.h b/tunsafe_threading.h index d595104..6e27c99 100644 --- a/tunsafe_threading.h +++ b/tunsafe_threading.h @@ -23,6 +23,7 @@ private: }; class Mutex { + friend class ConditionVariable; public: #if defined(_DEBUG) bool locked_; @@ -46,6 +47,17 @@ private: SRWLOCK lock_; }; +class ConditionVariable { +public: + ConditionVariable() { InitializeConditionVariable(&condvar_); } + void Wait(Mutex *mutex) { SleepConditionVariableSRW(&condvar_, &mutex->lock_, INFINITE, 0); } + void WaitTimed(Mutex *mutex, int millis) { SleepConditionVariableSRW(&condvar_, &mutex->lock_, millis, 0); } + void Wake() { WakeConditionVariable(&condvar_); } + +private: + CONDITION_VARIABLE condvar_; +}; + typedef uint32 ThreadId; static inline bool CurrentThreadIdEquals(ThreadId thread_id) { @@ -72,6 +84,7 @@ private: }; class Mutex { + friend class ConditionVariable; public: #if defined(_DEBUG) bool locked_; @@ -104,6 +117,17 @@ private: pthread_mutex_t lock_; }; +class ConditionVariable { +public: + ConditionVariable() { pthread_cond_init(&condvar_, NULL); } + void Wait(Mutex *mutex) { pthread_cond_wait(&condvar_, &mutex->lock_); } + void WaitTimed(Mutex *mutex, int millis); + void Wake() { pthread_cond_signal(&condvar_); } +private: + pthread_cond_t condvar_; +}; + + typedef pthread_t ThreadId; static inline bool CurrentThreadIdEquals(ThreadId thread_id) { @@ -140,6 +164,28 @@ private: Mutex *lock_; }; +class Thread { +public: + class Runner { + public: + virtual void ThreadMain() = 0; + }; + Thread(); + ~Thread(); + void StartThread(Runner *runner); + void StopThread(); + void DetachThread(); + bool is_started(); + +private: +#if defined(OS_WIN) + HANDLE thread_; +#else // defined(OS_WIN) + pthread_t thread_; +#endif // !defined(OS_WIN) +}; + + // This class deletes objects delayed. All participating threads will call a function, // and then once all threads did, all registered objects will get deleted. class MultithreadedDelayedDelete { diff --git a/wireguard_config.cpp b/wireguard_config.cpp index 2bb59da..65f32c3 100644 --- a/wireguard_config.cpp +++ b/wireguard_config.cpp @@ -57,8 +57,6 @@ char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]) { return buf; } - - struct Addr { byte addr[4]; uint8 cidr; @@ -88,9 +86,147 @@ bool ParseCidrAddr(char *s, WgCidrAddr *out) { return false; } +static Mutex g_dns_mutex; + +// This starts a background thread for running DNS resolving. +class DnsResolverThread : private Thread::Runner { +public: + DnsResolverThread(); + ~DnsResolverThread(); + + // Resolve the hostname and store the result in |result|. + // The function will block until it's resolved. If the cancellation + // token or becomes signalled, the call will fail. + bool Resolve(const char *hostname, IpAddr *result, DnsResolverCanceller *token); + +private: + virtual void ThreadMain(); + void StartThread(); + + struct Entry { + enum { + // Set when it's been posted to the job queue + POSTED = 0, + // Set when the thread has finished and original thread should delete + COMPLETE = 1, + // Set when the original thread has cancelled and worker thread should delete + CANCELLED = 2, + }; + + Entry() : hostname(NULL) {} + ~Entry() { free(hostname); } + + char *hostname; + IpAddr *result; + Entry *next; + uint32 state; + ConditionVariable *condvar; + }; + Entry *entry_; + Thread thread_; + bool thread_active_; +}; + +DnsResolverThread::DnsResolverThread() { + thread_active_ = false; + entry_ = NULL; +} + +DnsResolverThread::~DnsResolverThread() { + assert(entry_ == NULL); + thread_.StopThread(); +} + +void DnsResolverCanceller::Cancel() { + g_dns_mutex.Acquire(); + cancel_ = true; + condvar_.Wake(); + g_dns_mutex.Release(); +} + +bool DnsResolverThread::Resolve(const char *hostname, IpAddr *result, DnsResolverCanceller *token) { + if (token->cancel_) + return false; + + Entry *e = new Entry; + e->hostname = _strdup(hostname); + e->result = result; + e->next = NULL; + e->state = Entry::POSTED; + e->condvar = &token->condvar_; + result->sin.sin_family = 0; + + // Push it to the queue and start thread + g_dns_mutex.Acquire(); + Entry **p = &entry_; + while (*p) p = &(*p)->next; + *p = e; + if (!thread_active_) + StartThread(); + // Wait for something to happen with it. + while (!token->cancel_ && e->state == Entry::POSTED) + token->condvar_.Wait(&g_dns_mutex); + if (e->state == Entry::COMPLETE) { + delete e; + } else { + e->state = Entry::CANCELLED; + } + g_dns_mutex.Release(); + return result->sin.sin_family != 0; +} + +void DnsResolverThread::StartThread() { + thread_.StopThread(); + thread_active_ = true; + thread_.StartThread(this); +} + +void DnsResolverThread::ThreadMain() { + Entry *e = NULL; + struct hostent *he = NULL; + for (;;) { + g_dns_mutex.Acquire(); + if (e) { + if (e->state == Entry::CANCELLED) { + delete e; + } else { + if (he) { + e->result->sin.sin_family = AF_INET; + e->result->sin.sin_port = 0; + memcpy(&e->result->sin.sin_addr, he->h_addr_list[0], 4); + } + e->state = Entry::COMPLETE; + e->condvar->Wake(); + } + } + if (!(e = entry_)) { + thread_active_ = false; + break; + } + entry_ = e->next; + g_dns_mutex.Release(); + he = gethostbyname(e->hostname); + } + g_dns_mutex.Release(); +} + +static DnsResolverThread g_dnsresolver_thread; + +bool InterruptibleSleep(int delay, DnsResolverCanceller *token) { + g_dns_mutex.Acquire(); + uint32 time_at_start = (uint32)OsGetMilliseconds(); + while (delay > 0 && !token->cancel_) { + token->condvar_.WaitTimed(&g_dns_mutex, delay); + uint32 now = (uint32)OsGetMilliseconds(); + delay -= (now - time_at_start); + time_at_start = now; + } + g_dns_mutex.Release(); + return (delay <= 0); +} + DnsResolver::DnsResolver(DnsBlocker *dns_blocker) { dns_blocker_ = dns_blocker; - abort_flag_ = false; } DnsResolver::~DnsResolver() { @@ -126,23 +262,17 @@ bool DnsResolver::Resolve(const char *hostname, IpAddr *result) { #endif // defined(OS_WIN) for (;;) { - hostent *he = gethostbyname(hostname); - if (abort_flag_) - return false; - - if (he) { - result->sin.sin_family = AF_INET; - result->sin.sin_port = 0; - memcpy(&result->sin.sin_addr, he->h_addr_list[0], 4); + if (g_dnsresolver_thread.Resolve(hostname, result, &token_)) { // add to cache cache_.emplace_back(hostname, *result); RINFO("Resolved %s to %s%s", hostname, PrintIpAddr(*result, buf), ""); return true; } + if (token_.is_cancelled()) + return false; RINFO("Unable to resolve %s. Trying again in %d second(s)", hostname, retry_delays[attempt]); - OsInterruptibleSleep(retry_delays[attempt] * 1000); - if (abort_flag_) + if (!InterruptibleSleep(retry_delays[attempt] * 1000, &token_)) return false; if (attempt != ARRAY_SIZE(retry_delays) - 1) @@ -150,7 +280,11 @@ bool DnsResolver::Resolve(const char *hostname, IpAddr *result) { } } -bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver) { +bool ParseSockaddrInWithPort(const char *si, IpAddr *sin, DnsResolver *resolver) { + size_t len = strlen(si) + 1; + char *s = (char*)alloca(len); + memcpy(s, si, len); + memset(sin, 0, sizeof(IpAddr)); if (*s == '[') { char *end = strchr(s, ']'); diff --git a/wireguard_config.h b/wireguard_config.h index 791925d..635b2eb 100644 --- a/wireguard_config.h +++ b/wireguard_config.h @@ -4,20 +4,32 @@ #define TINYVPN_TINYVPN_H_ #include "netapi.h" +#include "tunsafe_threading.h" class WireguardProcessor; class DnsBlocker; +class DnsResolverCanceller { +public: + DnsResolverCanceller() : cancel_(false) {} + void Cancel(); + void Reset() { cancel_ = false; } + bool is_cancelled() { return cancel_; } +public: + bool cancel_; + ConditionVariable condvar_; +}; + class DnsResolver { public: explicit DnsResolver(DnsBlocker *dns_blocker); ~DnsResolver(); bool Resolve(const char *hostname, IpAddr *result); - void ClearCache(); - void SetAbortFlag(bool v) { abort_flag_ = v; } + void Cancel() { token_.Cancel(); } + void ResetCancel() { token_.Reset(); } private: struct Entry { std::string name; @@ -25,8 +37,8 @@ private: Entry(const std::string &name, const IpAddr &ip) : name(name), ip(ip) {} }; std::vector cache_; - bool abort_flag_; DnsBlocker *dns_blocker_; + DnsResolverCanceller token_; }; @@ -43,6 +55,8 @@ bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, DnsR const char *print_ip_prefix(char buf[kSizeOfAddress], int family, const void *ip, int prefixlen); char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]); char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]); + bool ParseCidrAddr(char *s, WgCidrAddr *out); +bool ParseSockaddrInWithPort(const char *s, IpAddr *sin, DnsResolver *resolver); #endif // TINYVPN_TINYVPN_H_