Background thread for DNS resolve

This commit is contained in:
Ludvig Strigeus 2018-10-07 19:36:52 +02:00
parent c08b18c028
commit 83183b2193
5 changed files with 317 additions and 19 deletions

View file

@ -2088,7 +2088,7 @@ void TunsafeBackendWin32::Stop() {
void TunsafeBackendWin32::Start(const char *config_file) {
StopInner(true);
dns_resolver_.SetAbortFlag(false);
dns_resolver_.ResetCancel();
is_started_ = true;
memset(public_key_, 0, sizeof(public_key_));
SetStatus(kStatusInitializing);
@ -2107,7 +2107,7 @@ void TunsafeBackendWin32::PostExit(int exit_code) {
void TunsafeBackendWin32::StopInner(bool is_restart) {
if (worker_thread_) {
ipv4_ip_ = 0;
dns_resolver_.SetAbortFlag(true);
dns_resolver_.Cancel();
PostExit(is_restart ? MODE_RESTART : MODE_EXIT);
WaitForSingleObject(worker_thread_, INFINITE);
CloseHandle(worker_thread_);

View file

@ -3,6 +3,110 @@
#include "stdafx.h"
#include "tunsafe_threading.h"
#include <stdlib.h>
#include <assert.h>
#if defined(OS_POSIX)
Thread::Thread() {
thread_ = 0;
}
Thread::~Thread() {
assert(thread_ == 0);
}
static void *ThreadMainStatic(void *x) {
Thread::Runner *t = (Thread::Runner*)x;
t->ThreadMain();
return 0;
}
void Thread::StartThread(Runner *runner) {
assert(thread_ == 0);
if (pthread_create(&thread_, NULL, &ThreadMainStatic, runner) != 0)
tunsafe_die("pthread_create failed");
}
void Thread::StopThread() {
if (thread_) {
void *x;
pthread_join(thread_, &x);
thread_ = 0;
}
}
void Thread::DetachThread() {
if (thread_) {
pthread_detach(thread_);
thread_ = 0;
}
}
bool Thread::is_started() {
return thread_ != 0;
}
void ConditionVariable::WaitTimed(Mutex *mutex, int millis) {
#if !defined(OS_MACOSX)
struct timespec ts;
clock_gettime(CLOCK_REALTIME, &ts);
ts.tv_sec += (millis / 1000);
ts.tv_nsec += (millis % 1000) * 1000000;
if (ts.tv_nsec >= 1000000000) {
ts.tv_nsec -= 1000000000;
ts.tv_sec++;
}
pthread_cond_timedwait(&condvar_, &mutex->lock_, &ts);
#else
struct timespec ts;
ts.tv_sec = millis / 1000;
ts.tv_nsec = (millis % 1000) * 1000000;
pthread_cond_timedwait_relative_np(&condvar_, &mutex->lock_, &ts);
#endif
}
#endif // defined(OS_POSIX)
#if defined(OS_WIN)
Thread::Thread() {
thread_ = 0;
}
Thread::~Thread() {
assert(thread_ == 0);
}
static DWORD WINAPI ThreadMainStatic(void *x) {
Thread::Runner *t = (Thread::Runner*)x;
t->ThreadMain();
return 0;
}
void Thread::StartThread(Runner *runner) {
assert(thread_ == 0);
DWORD thread_id;
thread_ = CreateThread(NULL, 0, &ThreadMainStatic, (LPVOID)runner, 0, &thread_id);
}
void Thread::StopThread() {
if (thread_) {
WaitForSingleObject(thread_, INFINITE);
CloseHandle(thread_);
thread_ = 0;
}
}
void Thread::DetachThread() {
if (thread_) {
CloseHandle(thread_);
thread_ = 0;
}
}
bool Thread::is_started() {
return thread_ != 0;
}
#endif
MultithreadedDelayedDelete::MultithreadedDelayedDelete() {
table_ = NULL;

View file

@ -23,6 +23,7 @@ private:
};
class Mutex {
friend class ConditionVariable;
public:
#if defined(_DEBUG)
bool locked_;
@ -46,6 +47,17 @@ private:
SRWLOCK lock_;
};
class ConditionVariable {
public:
ConditionVariable() { InitializeConditionVariable(&condvar_); }
void Wait(Mutex *mutex) { SleepConditionVariableSRW(&condvar_, &mutex->lock_, INFINITE, 0); }
void WaitTimed(Mutex *mutex, int millis) { SleepConditionVariableSRW(&condvar_, &mutex->lock_, millis, 0); }
void Wake() { WakeConditionVariable(&condvar_); }
private:
CONDITION_VARIABLE condvar_;
};
typedef uint32 ThreadId;
static inline bool CurrentThreadIdEquals(ThreadId thread_id) {
@ -72,6 +84,7 @@ private:
};
class Mutex {
friend class ConditionVariable;
public:
#if defined(_DEBUG)
bool locked_;
@ -104,6 +117,17 @@ private:
pthread_mutex_t lock_;
};
class ConditionVariable {
public:
ConditionVariable() { pthread_cond_init(&condvar_, NULL); }
void Wait(Mutex *mutex) { pthread_cond_wait(&condvar_, &mutex->lock_); }
void WaitTimed(Mutex *mutex, int millis);
void Wake() { pthread_cond_signal(&condvar_); }
private:
pthread_cond_t condvar_;
};
typedef pthread_t ThreadId;
static inline bool CurrentThreadIdEquals(ThreadId thread_id) {
@ -140,6 +164,28 @@ private:
Mutex *lock_;
};
class Thread {
public:
class Runner {
public:
virtual void ThreadMain() = 0;
};
Thread();
~Thread();
void StartThread(Runner *runner);
void StopThread();
void DetachThread();
bool is_started();
private:
#if defined(OS_WIN)
HANDLE thread_;
#else // defined(OS_WIN)
pthread_t thread_;
#endif // !defined(OS_WIN)
};
// This class deletes objects delayed. All participating threads will call a function,
// and then once all threads did, all registered objects will get deleted.
class MultithreadedDelayedDelete {

View file

@ -57,8 +57,6 @@ char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]) {
return buf;
}
struct Addr {
byte addr[4];
uint8 cidr;
@ -88,9 +86,147 @@ bool ParseCidrAddr(char *s, WgCidrAddr *out) {
return false;
}
static Mutex g_dns_mutex;
// This starts a background thread for running DNS resolving.
class DnsResolverThread : private Thread::Runner {
public:
DnsResolverThread();
~DnsResolverThread();
// Resolve the hostname and store the result in |result|.
// The function will block until it's resolved. If the cancellation
// token or becomes signalled, the call will fail.
bool Resolve(const char *hostname, IpAddr *result, DnsResolverCanceller *token);
private:
virtual void ThreadMain();
void StartThread();
struct Entry {
enum {
// Set when it's been posted to the job queue
POSTED = 0,
// Set when the thread has finished and original thread should delete
COMPLETE = 1,
// Set when the original thread has cancelled and worker thread should delete
CANCELLED = 2,
};
Entry() : hostname(NULL) {}
~Entry() { free(hostname); }
char *hostname;
IpAddr *result;
Entry *next;
uint32 state;
ConditionVariable *condvar;
};
Entry *entry_;
Thread thread_;
bool thread_active_;
};
DnsResolverThread::DnsResolverThread() {
thread_active_ = false;
entry_ = NULL;
}
DnsResolverThread::~DnsResolverThread() {
assert(entry_ == NULL);
thread_.StopThread();
}
void DnsResolverCanceller::Cancel() {
g_dns_mutex.Acquire();
cancel_ = true;
condvar_.Wake();
g_dns_mutex.Release();
}
bool DnsResolverThread::Resolve(const char *hostname, IpAddr *result, DnsResolverCanceller *token) {
if (token->cancel_)
return false;
Entry *e = new Entry;
e->hostname = _strdup(hostname);
e->result = result;
e->next = NULL;
e->state = Entry::POSTED;
e->condvar = &token->condvar_;
result->sin.sin_family = 0;
// Push it to the queue and start thread
g_dns_mutex.Acquire();
Entry **p = &entry_;
while (*p) p = &(*p)->next;
*p = e;
if (!thread_active_)
StartThread();
// Wait for something to happen with it.
while (!token->cancel_ && e->state == Entry::POSTED)
token->condvar_.Wait(&g_dns_mutex);
if (e->state == Entry::COMPLETE) {
delete e;
} else {
e->state = Entry::CANCELLED;
}
g_dns_mutex.Release();
return result->sin.sin_family != 0;
}
void DnsResolverThread::StartThread() {
thread_.StopThread();
thread_active_ = true;
thread_.StartThread(this);
}
void DnsResolverThread::ThreadMain() {
Entry *e = NULL;
struct hostent *he = NULL;
for (;;) {
g_dns_mutex.Acquire();
if (e) {
if (e->state == Entry::CANCELLED) {
delete e;
} else {
if (he) {
e->result->sin.sin_family = AF_INET;
e->result->sin.sin_port = 0;
memcpy(&e->result->sin.sin_addr, he->h_addr_list[0], 4);
}
e->state = Entry::COMPLETE;
e->condvar->Wake();
}
}
if (!(e = entry_)) {
thread_active_ = false;
break;
}
entry_ = e->next;
g_dns_mutex.Release();
he = gethostbyname(e->hostname);
}
g_dns_mutex.Release();
}
static DnsResolverThread g_dnsresolver_thread;
bool InterruptibleSleep(int delay, DnsResolverCanceller *token) {
g_dns_mutex.Acquire();
uint32 time_at_start = (uint32)OsGetMilliseconds();
while (delay > 0 && !token->cancel_) {
token->condvar_.WaitTimed(&g_dns_mutex, delay);
uint32 now = (uint32)OsGetMilliseconds();
delay -= (now - time_at_start);
time_at_start = now;
}
g_dns_mutex.Release();
return (delay <= 0);
}
DnsResolver::DnsResolver(DnsBlocker *dns_blocker) {
dns_blocker_ = dns_blocker;
abort_flag_ = false;
}
DnsResolver::~DnsResolver() {
@ -126,23 +262,17 @@ bool DnsResolver::Resolve(const char *hostname, IpAddr *result) {
#endif // defined(OS_WIN)
for (;;) {
hostent *he = gethostbyname(hostname);
if (abort_flag_)
return false;
if (he) {
result->sin.sin_family = AF_INET;
result->sin.sin_port = 0;
memcpy(&result->sin.sin_addr, he->h_addr_list[0], 4);
if (g_dnsresolver_thread.Resolve(hostname, result, &token_)) {
// add to cache
cache_.emplace_back(hostname, *result);
RINFO("Resolved %s to %s%s", hostname, PrintIpAddr(*result, buf), "");
return true;
}
if (token_.is_cancelled())
return false;
RINFO("Unable to resolve %s. Trying again in %d second(s)", hostname, retry_delays[attempt]);
OsInterruptibleSleep(retry_delays[attempt] * 1000);
if (abort_flag_)
if (!InterruptibleSleep(retry_delays[attempt] * 1000, &token_))
return false;
if (attempt != ARRAY_SIZE(retry_delays) - 1)
@ -150,7 +280,11 @@ bool DnsResolver::Resolve(const char *hostname, IpAddr *result) {
}
}
bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver) {
bool ParseSockaddrInWithPort(const char *si, IpAddr *sin, DnsResolver *resolver) {
size_t len = strlen(si) + 1;
char *s = (char*)alloca(len);
memcpy(s, si, len);
memset(sin, 0, sizeof(IpAddr));
if (*s == '[') {
char *end = strchr(s, ']');

View file

@ -4,20 +4,32 @@
#define TINYVPN_TINYVPN_H_
#include "netapi.h"
#include "tunsafe_threading.h"
class WireguardProcessor;
class DnsBlocker;
class DnsResolverCanceller {
public:
DnsResolverCanceller() : cancel_(false) {}
void Cancel();
void Reset() { cancel_ = false; }
bool is_cancelled() { return cancel_; }
public:
bool cancel_;
ConditionVariable condvar_;
};
class DnsResolver {
public:
explicit DnsResolver(DnsBlocker *dns_blocker);
~DnsResolver();
bool Resolve(const char *hostname, IpAddr *result);
void ClearCache();
void SetAbortFlag(bool v) { abort_flag_ = v; }
void Cancel() { token_.Cancel(); }
void ResetCancel() { token_.Reset(); }
private:
struct Entry {
std::string name;
@ -25,8 +37,8 @@ private:
Entry(const std::string &name, const IpAddr &ip) : name(name), ip(ip) {}
};
std::vector<Entry> cache_;
bool abort_flag_;
DnsBlocker *dns_blocker_;
DnsResolverCanceller token_;
};
@ -43,6 +55,8 @@ bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, DnsR
const char *print_ip_prefix(char buf[kSizeOfAddress], int family, const void *ip, int prefixlen);
char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]);
char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]);
bool ParseCidrAddr(char *s, WgCidrAddr *out);
bool ParseSockaddrInWithPort(const char *s, IpAddr *sin, DnsResolver *resolver);
#endif // TINYVPN_TINYVPN_H_