Background thread for DNS resolve
This commit is contained in:
parent
c08b18c028
commit
83183b2193
|
@ -2088,7 +2088,7 @@ void TunsafeBackendWin32::Stop() {
|
|||
|
||||
void TunsafeBackendWin32::Start(const char *config_file) {
|
||||
StopInner(true);
|
||||
dns_resolver_.SetAbortFlag(false);
|
||||
dns_resolver_.ResetCancel();
|
||||
is_started_ = true;
|
||||
memset(public_key_, 0, sizeof(public_key_));
|
||||
SetStatus(kStatusInitializing);
|
||||
|
@ -2107,7 +2107,7 @@ void TunsafeBackendWin32::PostExit(int exit_code) {
|
|||
void TunsafeBackendWin32::StopInner(bool is_restart) {
|
||||
if (worker_thread_) {
|
||||
ipv4_ip_ = 0;
|
||||
dns_resolver_.SetAbortFlag(true);
|
||||
dns_resolver_.Cancel();
|
||||
PostExit(is_restart ? MODE_RESTART : MODE_EXIT);
|
||||
WaitForSingleObject(worker_thread_, INFINITE);
|
||||
CloseHandle(worker_thread_);
|
||||
|
|
|
@ -3,6 +3,110 @@
|
|||
#include "stdafx.h"
|
||||
#include "tunsafe_threading.h"
|
||||
#include <stdlib.h>
|
||||
#include <assert.h>
|
||||
|
||||
#if defined(OS_POSIX)
|
||||
Thread::Thread() {
|
||||
thread_ = 0;
|
||||
}
|
||||
|
||||
Thread::~Thread() {
|
||||
assert(thread_ == 0);
|
||||
}
|
||||
|
||||
static void *ThreadMainStatic(void *x) {
|
||||
Thread::Runner *t = (Thread::Runner*)x;
|
||||
t->ThreadMain();
|
||||
return 0;
|
||||
}
|
||||
|
||||
void Thread::StartThread(Runner *runner) {
|
||||
assert(thread_ == 0);
|
||||
if (pthread_create(&thread_, NULL, &ThreadMainStatic, runner) != 0)
|
||||
tunsafe_die("pthread_create failed");
|
||||
}
|
||||
|
||||
void Thread::StopThread() {
|
||||
if (thread_) {
|
||||
void *x;
|
||||
pthread_join(thread_, &x);
|
||||
thread_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Thread::DetachThread() {
|
||||
if (thread_) {
|
||||
pthread_detach(thread_);
|
||||
thread_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
bool Thread::is_started() {
|
||||
return thread_ != 0;
|
||||
}
|
||||
|
||||
void ConditionVariable::WaitTimed(Mutex *mutex, int millis) {
|
||||
#if !defined(OS_MACOSX)
|
||||
struct timespec ts;
|
||||
clock_gettime(CLOCK_REALTIME, &ts);
|
||||
|
||||
ts.tv_sec += (millis / 1000);
|
||||
ts.tv_nsec += (millis % 1000) * 1000000;
|
||||
if (ts.tv_nsec >= 1000000000) {
|
||||
ts.tv_nsec -= 1000000000;
|
||||
ts.tv_sec++;
|
||||
}
|
||||
pthread_cond_timedwait(&condvar_, &mutex->lock_, &ts);
|
||||
#else
|
||||
struct timespec ts;
|
||||
ts.tv_sec = millis / 1000;
|
||||
ts.tv_nsec = (millis % 1000) * 1000000;
|
||||
|
||||
pthread_cond_timedwait_relative_np(&condvar_, &mutex->lock_, &ts);
|
||||
#endif
|
||||
}
|
||||
#endif // defined(OS_POSIX)
|
||||
|
||||
#if defined(OS_WIN)
|
||||
Thread::Thread() {
|
||||
thread_ = 0;
|
||||
}
|
||||
|
||||
Thread::~Thread() {
|
||||
assert(thread_ == 0);
|
||||
}
|
||||
|
||||
static DWORD WINAPI ThreadMainStatic(void *x) {
|
||||
Thread::Runner *t = (Thread::Runner*)x;
|
||||
t->ThreadMain();
|
||||
return 0;
|
||||
}
|
||||
|
||||
void Thread::StartThread(Runner *runner) {
|
||||
assert(thread_ == 0);
|
||||
DWORD thread_id;
|
||||
thread_ = CreateThread(NULL, 0, &ThreadMainStatic, (LPVOID)runner, 0, &thread_id);
|
||||
}
|
||||
|
||||
void Thread::StopThread() {
|
||||
if (thread_) {
|
||||
WaitForSingleObject(thread_, INFINITE);
|
||||
CloseHandle(thread_);
|
||||
thread_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Thread::DetachThread() {
|
||||
if (thread_) {
|
||||
CloseHandle(thread_);
|
||||
thread_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
bool Thread::is_started() {
|
||||
return thread_ != 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
MultithreadedDelayedDelete::MultithreadedDelayedDelete() {
|
||||
table_ = NULL;
|
||||
|
|
|
@ -23,6 +23,7 @@ private:
|
|||
};
|
||||
|
||||
class Mutex {
|
||||
friend class ConditionVariable;
|
||||
public:
|
||||
#if defined(_DEBUG)
|
||||
bool locked_;
|
||||
|
@ -46,6 +47,17 @@ private:
|
|||
SRWLOCK lock_;
|
||||
};
|
||||
|
||||
class ConditionVariable {
|
||||
public:
|
||||
ConditionVariable() { InitializeConditionVariable(&condvar_); }
|
||||
void Wait(Mutex *mutex) { SleepConditionVariableSRW(&condvar_, &mutex->lock_, INFINITE, 0); }
|
||||
void WaitTimed(Mutex *mutex, int millis) { SleepConditionVariableSRW(&condvar_, &mutex->lock_, millis, 0); }
|
||||
void Wake() { WakeConditionVariable(&condvar_); }
|
||||
|
||||
private:
|
||||
CONDITION_VARIABLE condvar_;
|
||||
};
|
||||
|
||||
typedef uint32 ThreadId;
|
||||
|
||||
static inline bool CurrentThreadIdEquals(ThreadId thread_id) {
|
||||
|
@ -72,6 +84,7 @@ private:
|
|||
};
|
||||
|
||||
class Mutex {
|
||||
friend class ConditionVariable;
|
||||
public:
|
||||
#if defined(_DEBUG)
|
||||
bool locked_;
|
||||
|
@ -104,6 +117,17 @@ private:
|
|||
pthread_mutex_t lock_;
|
||||
};
|
||||
|
||||
class ConditionVariable {
|
||||
public:
|
||||
ConditionVariable() { pthread_cond_init(&condvar_, NULL); }
|
||||
void Wait(Mutex *mutex) { pthread_cond_wait(&condvar_, &mutex->lock_); }
|
||||
void WaitTimed(Mutex *mutex, int millis);
|
||||
void Wake() { pthread_cond_signal(&condvar_); }
|
||||
private:
|
||||
pthread_cond_t condvar_;
|
||||
};
|
||||
|
||||
|
||||
typedef pthread_t ThreadId;
|
||||
|
||||
static inline bool CurrentThreadIdEquals(ThreadId thread_id) {
|
||||
|
@ -140,6 +164,28 @@ private:
|
|||
Mutex *lock_;
|
||||
};
|
||||
|
||||
class Thread {
|
||||
public:
|
||||
class Runner {
|
||||
public:
|
||||
virtual void ThreadMain() = 0;
|
||||
};
|
||||
Thread();
|
||||
~Thread();
|
||||
void StartThread(Runner *runner);
|
||||
void StopThread();
|
||||
void DetachThread();
|
||||
bool is_started();
|
||||
|
||||
private:
|
||||
#if defined(OS_WIN)
|
||||
HANDLE thread_;
|
||||
#else // defined(OS_WIN)
|
||||
pthread_t thread_;
|
||||
#endif // !defined(OS_WIN)
|
||||
};
|
||||
|
||||
|
||||
// This class deletes objects delayed. All participating threads will call a function,
|
||||
// and then once all threads did, all registered objects will get deleted.
|
||||
class MultithreadedDelayedDelete {
|
||||
|
|
|
@ -57,8 +57,6 @@ char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]) {
|
|||
return buf;
|
||||
}
|
||||
|
||||
|
||||
|
||||
struct Addr {
|
||||
byte addr[4];
|
||||
uint8 cidr;
|
||||
|
@ -88,9 +86,147 @@ bool ParseCidrAddr(char *s, WgCidrAddr *out) {
|
|||
return false;
|
||||
}
|
||||
|
||||
static Mutex g_dns_mutex;
|
||||
|
||||
// This starts a background thread for running DNS resolving.
|
||||
class DnsResolverThread : private Thread::Runner {
|
||||
public:
|
||||
DnsResolverThread();
|
||||
~DnsResolverThread();
|
||||
|
||||
// Resolve the hostname and store the result in |result|.
|
||||
// The function will block until it's resolved. If the cancellation
|
||||
// token or becomes signalled, the call will fail.
|
||||
bool Resolve(const char *hostname, IpAddr *result, DnsResolverCanceller *token);
|
||||
|
||||
private:
|
||||
virtual void ThreadMain();
|
||||
void StartThread();
|
||||
|
||||
struct Entry {
|
||||
enum {
|
||||
// Set when it's been posted to the job queue
|
||||
POSTED = 0,
|
||||
// Set when the thread has finished and original thread should delete
|
||||
COMPLETE = 1,
|
||||
// Set when the original thread has cancelled and worker thread should delete
|
||||
CANCELLED = 2,
|
||||
};
|
||||
|
||||
Entry() : hostname(NULL) {}
|
||||
~Entry() { free(hostname); }
|
||||
|
||||
char *hostname;
|
||||
IpAddr *result;
|
||||
Entry *next;
|
||||
uint32 state;
|
||||
ConditionVariable *condvar;
|
||||
};
|
||||
Entry *entry_;
|
||||
Thread thread_;
|
||||
bool thread_active_;
|
||||
};
|
||||
|
||||
DnsResolverThread::DnsResolverThread() {
|
||||
thread_active_ = false;
|
||||
entry_ = NULL;
|
||||
}
|
||||
|
||||
DnsResolverThread::~DnsResolverThread() {
|
||||
assert(entry_ == NULL);
|
||||
thread_.StopThread();
|
||||
}
|
||||
|
||||
void DnsResolverCanceller::Cancel() {
|
||||
g_dns_mutex.Acquire();
|
||||
cancel_ = true;
|
||||
condvar_.Wake();
|
||||
g_dns_mutex.Release();
|
||||
}
|
||||
|
||||
bool DnsResolverThread::Resolve(const char *hostname, IpAddr *result, DnsResolverCanceller *token) {
|
||||
if (token->cancel_)
|
||||
return false;
|
||||
|
||||
Entry *e = new Entry;
|
||||
e->hostname = _strdup(hostname);
|
||||
e->result = result;
|
||||
e->next = NULL;
|
||||
e->state = Entry::POSTED;
|
||||
e->condvar = &token->condvar_;
|
||||
result->sin.sin_family = 0;
|
||||
|
||||
// Push it to the queue and start thread
|
||||
g_dns_mutex.Acquire();
|
||||
Entry **p = &entry_;
|
||||
while (*p) p = &(*p)->next;
|
||||
*p = e;
|
||||
if (!thread_active_)
|
||||
StartThread();
|
||||
// Wait for something to happen with it.
|
||||
while (!token->cancel_ && e->state == Entry::POSTED)
|
||||
token->condvar_.Wait(&g_dns_mutex);
|
||||
if (e->state == Entry::COMPLETE) {
|
||||
delete e;
|
||||
} else {
|
||||
e->state = Entry::CANCELLED;
|
||||
}
|
||||
g_dns_mutex.Release();
|
||||
return result->sin.sin_family != 0;
|
||||
}
|
||||
|
||||
void DnsResolverThread::StartThread() {
|
||||
thread_.StopThread();
|
||||
thread_active_ = true;
|
||||
thread_.StartThread(this);
|
||||
}
|
||||
|
||||
void DnsResolverThread::ThreadMain() {
|
||||
Entry *e = NULL;
|
||||
struct hostent *he = NULL;
|
||||
for (;;) {
|
||||
g_dns_mutex.Acquire();
|
||||
if (e) {
|
||||
if (e->state == Entry::CANCELLED) {
|
||||
delete e;
|
||||
} else {
|
||||
if (he) {
|
||||
e->result->sin.sin_family = AF_INET;
|
||||
e->result->sin.sin_port = 0;
|
||||
memcpy(&e->result->sin.sin_addr, he->h_addr_list[0], 4);
|
||||
}
|
||||
e->state = Entry::COMPLETE;
|
||||
e->condvar->Wake();
|
||||
}
|
||||
}
|
||||
if (!(e = entry_)) {
|
||||
thread_active_ = false;
|
||||
break;
|
||||
}
|
||||
entry_ = e->next;
|
||||
g_dns_mutex.Release();
|
||||
he = gethostbyname(e->hostname);
|
||||
}
|
||||
g_dns_mutex.Release();
|
||||
}
|
||||
|
||||
static DnsResolverThread g_dnsresolver_thread;
|
||||
|
||||
bool InterruptibleSleep(int delay, DnsResolverCanceller *token) {
|
||||
g_dns_mutex.Acquire();
|
||||
uint32 time_at_start = (uint32)OsGetMilliseconds();
|
||||
while (delay > 0 && !token->cancel_) {
|
||||
token->condvar_.WaitTimed(&g_dns_mutex, delay);
|
||||
uint32 now = (uint32)OsGetMilliseconds();
|
||||
delay -= (now - time_at_start);
|
||||
time_at_start = now;
|
||||
}
|
||||
g_dns_mutex.Release();
|
||||
return (delay <= 0);
|
||||
}
|
||||
|
||||
DnsResolver::DnsResolver(DnsBlocker *dns_blocker) {
|
||||
dns_blocker_ = dns_blocker;
|
||||
abort_flag_ = false;
|
||||
}
|
||||
|
||||
DnsResolver::~DnsResolver() {
|
||||
|
@ -126,23 +262,17 @@ bool DnsResolver::Resolve(const char *hostname, IpAddr *result) {
|
|||
#endif // defined(OS_WIN)
|
||||
|
||||
for (;;) {
|
||||
hostent *he = gethostbyname(hostname);
|
||||
if (abort_flag_)
|
||||
return false;
|
||||
|
||||
if (he) {
|
||||
result->sin.sin_family = AF_INET;
|
||||
result->sin.sin_port = 0;
|
||||
memcpy(&result->sin.sin_addr, he->h_addr_list[0], 4);
|
||||
if (g_dnsresolver_thread.Resolve(hostname, result, &token_)) {
|
||||
// add to cache
|
||||
cache_.emplace_back(hostname, *result);
|
||||
RINFO("Resolved %s to %s%s", hostname, PrintIpAddr(*result, buf), "");
|
||||
return true;
|
||||
}
|
||||
if (token_.is_cancelled())
|
||||
return false;
|
||||
|
||||
RINFO("Unable to resolve %s. Trying again in %d second(s)", hostname, retry_delays[attempt]);
|
||||
OsInterruptibleSleep(retry_delays[attempt] * 1000);
|
||||
if (abort_flag_)
|
||||
if (!InterruptibleSleep(retry_delays[attempt] * 1000, &token_))
|
||||
return false;
|
||||
|
||||
if (attempt != ARRAY_SIZE(retry_delays) - 1)
|
||||
|
@ -150,7 +280,11 @@ bool DnsResolver::Resolve(const char *hostname, IpAddr *result) {
|
|||
}
|
||||
}
|
||||
|
||||
bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver) {
|
||||
bool ParseSockaddrInWithPort(const char *si, IpAddr *sin, DnsResolver *resolver) {
|
||||
size_t len = strlen(si) + 1;
|
||||
char *s = (char*)alloca(len);
|
||||
memcpy(s, si, len);
|
||||
|
||||
memset(sin, 0, sizeof(IpAddr));
|
||||
if (*s == '[') {
|
||||
char *end = strchr(s, ']');
|
||||
|
|
|
@ -4,20 +4,32 @@
|
|||
#define TINYVPN_TINYVPN_H_
|
||||
|
||||
#include "netapi.h"
|
||||
#include "tunsafe_threading.h"
|
||||
|
||||
class WireguardProcessor;
|
||||
class DnsBlocker;
|
||||
|
||||
class DnsResolverCanceller {
|
||||
public:
|
||||
DnsResolverCanceller() : cancel_(false) {}
|
||||
void Cancel();
|
||||
void Reset() { cancel_ = false; }
|
||||
bool is_cancelled() { return cancel_; }
|
||||
public:
|
||||
bool cancel_;
|
||||
ConditionVariable condvar_;
|
||||
};
|
||||
|
||||
class DnsResolver {
|
||||
public:
|
||||
explicit DnsResolver(DnsBlocker *dns_blocker);
|
||||
~DnsResolver();
|
||||
|
||||
bool Resolve(const char *hostname, IpAddr *result);
|
||||
|
||||
void ClearCache();
|
||||
|
||||
void SetAbortFlag(bool v) { abort_flag_ = v; }
|
||||
void Cancel() { token_.Cancel(); }
|
||||
void ResetCancel() { token_.Reset(); }
|
||||
private:
|
||||
struct Entry {
|
||||
std::string name;
|
||||
|
@ -25,8 +37,8 @@ private:
|
|||
Entry(const std::string &name, const IpAddr &ip) : name(name), ip(ip) {}
|
||||
};
|
||||
std::vector<Entry> cache_;
|
||||
bool abort_flag_;
|
||||
DnsBlocker *dns_blocker_;
|
||||
DnsResolverCanceller token_;
|
||||
};
|
||||
|
||||
|
||||
|
@ -43,6 +55,8 @@ bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, DnsR
|
|||
const char *print_ip_prefix(char buf[kSizeOfAddress], int family, const void *ip, int prefixlen);
|
||||
char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]);
|
||||
char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]);
|
||||
|
||||
bool ParseCidrAddr(char *s, WgCidrAddr *out);
|
||||
bool ParseSockaddrInWithPort(const char *s, IpAddr *sin, DnsResolver *resolver);
|
||||
|
||||
#endif // TINYVPN_TINYVPN_H_
|
||||
|
|
Loading…
Reference in a new issue