Update service code to use a state machine instead of blocking

This commit is contained in:
Ludvig Strigeus 2018-09-11 19:33:38 +02:00
parent 64078ee051
commit e499a3d4f7
2 changed files with 151 additions and 177 deletions

View file

@ -401,9 +401,11 @@ PipeMessageHandler::PipeMessageHandler(const char *pipe_name, bool is_server_pip
thread_ = NULL; thread_ = NULL;
packets_end_ = &packets_; packets_end_ = &packets_;
write_overlapped_active_ = false; write_overlapped_active_ = false;
exit_ = false; exit_thread_ = false;
connection_established_ = false; connection_established_ = false;
thread_id_ = 0; thread_id_ = 0;
state_ = kStateNone;
tmp_packet_buf_ = NULL;
} }
PipeMessageHandler::~PipeMessageHandler() { PipeMessageHandler::~PipeMessageHandler() {
@ -414,7 +416,7 @@ PipeMessageHandler::~PipeMessageHandler() {
free(pipe_name_); free(pipe_name_);
} }
bool PipeMessageHandler::InitializeServerPipe() { bool PipeMessageHandler::InitializeServerPipeAndWait() {
int BUFSIZE = 2048; int BUFSIZE = 2048;
SECURITY_ATTRIBUTES saPipeSecurity = {0}; SECURITY_ATTRIBUTES saPipeSecurity = {0};
uint8 buf[SECURITY_DESCRIPTOR_MIN_LENGTH]; uint8 buf[SECURITY_DESCRIPTOR_MIN_LENGTH];
@ -437,15 +439,23 @@ bool PipeMessageHandler::InitializeServerPipe() {
PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS | PIPE_WAIT, PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS | PIPE_WAIT,
PIPE_UNLIMITED_INSTANCES, PIPE_UNLIMITED_INSTANCES,
BUFSIZE, BUFSIZE, 0, &saPipeSecurity); BUFSIZE, BUFSIZE, 0, &saPipeSecurity);
return pipe_ != INVALID_HANDLE_VALUE; if (pipe_ == INVALID_HANDLE_VALUE)
return false;
memset(&read_overlapped_, 0, sizeof(read_overlapped_));
read_overlapped_.hEvent = wait_handles_[0];
if (!ConnectNamedPipe(pipe_, &read_overlapped_)) {
DWORD rv = GetLastError();
if (rv != ERROR_PIPE_CONNECTED && rv != ERROR_IO_PENDING)
return false;
}
return true;
} }
bool PipeMessageHandler::InitializeClientPipe() { bool PipeMessageHandler::InitializeClientPipe() {
assert(pipe_ == INVALID_HANDLE_VALUE); assert(pipe_ == INVALID_HANDLE_VALUE);
pipe_ = CreateFile( pipe_ = CreateFile(pipe_name_, GENERIC_READ | GENERIC_WRITE, 0, NULL,
pipe_name_, OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL);
GENERIC_READ | GENERIC_WRITE, 0,
NULL, OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL);
if (pipe_ == INVALID_HANDLE_VALUE) if (pipe_ == INVALID_HANDLE_VALUE)
return false; return false;
DWORD mode = PIPE_READMODE_MESSAGE; DWORD mode = PIPE_READMODE_MESSAGE;
@ -462,6 +472,12 @@ void PipeMessageHandler::ClosePipe() {
connection_established_ = false; connection_established_ = false;
write_overlapped_active_ = false; write_overlapped_active_ = false;
free(tmp_packet_buf_);
tmp_packet_buf_ = NULL;
ResetEvent(wait_handles_[0]);
ResetEvent(wait_handles_[2]);
packets_mutex_.Acquire(); packets_mutex_.Acquire();
OutgoingPacket *packets = packets_; OutgoingPacket *packets = packets_;
packets_ = NULL; packets_ = NULL;
@ -517,107 +533,83 @@ void PipeMessageHandler::SendNextQueuedWrite() {
write_overlapped_.hEvent = wait_handles_[2]; write_overlapped_.hEvent = wait_handles_[2];
if (WriteFile(pipe_, p->data, p->size, NULL, &write_overlapped_) || GetLastError() == ERROR_IO_PENDING) if (WriteFile(pipe_, p->data, p->size, NULL, &write_overlapped_) || GetLastError() == ERROR_IO_PENDING)
write_overlapped_active_ = true; write_overlapped_active_ = true;
} else {
ResetEvent(wait_handles_[2]);
} }
} }
} }
uint8 *PipeMessageHandler::ReadNamedPipeAsync(size_t *packet_size) { #define TS_WAIT_BEGIN(t) switch(state_) { case t:
OVERLAPPED ov = {0}; #define TS_WAIT_POINT(t) state_ = (t); return; case t:
uint8 *result = NULL; #define TS_WAIT_END() }
DWORD bytes_waiting = 0;
DWORD rv;
ov.hEvent = wait_handles_[0];
if (!ReadFile(pipe_, NULL, 0, NULL, &ov)) {
rv = GetLastError();
if (rv != ERROR_IO_PENDING && rv != ERROR_MORE_DATA)
goto getout;
}
if (!WaitAndHandleWrites(INFINITE)) { void PipeMessageHandler::AdvanceStateMachine() {
CancelIo(pipe_); DWORD rv, bytes_read;
write_overlapped_active_ = false;
goto getout;
}
PeekNamedPipe(pipe_, NULL, 0, NULL, &bytes_waiting, NULL); TS_WAIT_BEGIN(kStateNone)
if (bytes_waiting == 0) for(;;) {
goto getout; // this is typically what happens when pipe closes. // Create a named pipe and wait for connections from the UI process
if (is_server_pipe_) {
result = (uint8*)malloc(bytes_waiting); if (!InitializeServerPipeAndWait()) {
if (!result) if (!exit_thread_)
goto getout; ExitProcess(1);
break;
if (!ReadFile(pipe_, result, bytes_waiting, NULL, &ov)) { }
rv = GetLastError(); TS_WAIT_POINT(kStateWaitConnect);
if (rv != ERROR_IO_PENDING) } else {
goto getout; if (!InitializeClientPipe()) {
} RINFO("Unable to connect to the TunSafe Service. Please make sure it's running.");
if (!WaitAndHandleWrites(1000)) { break;
CancelIo(pipe_); }
write_overlapped_active_ = false; }
free(result); connection_established_ = true;
result = NULL; delegate_->HandleNewConnection();
goto getout;
}
bytes_waiting = (uint32)ov.InternalHigh;
if (bytes_waiting == 0) {
free(result);
result = NULL;
goto getout;
}
*packet_size = bytes_waiting;
getout:
return result;
}
bool PipeMessageHandler::ConnectNamedPipeAsync() {
OVERLAPPED ov = {0};
DWORD rv;
bool result = false;
ov.hEvent = wait_handles_[0];
if (!ConnectNamedPipe(pipe_, &ov)) {
rv = GetLastError();
if (rv != ERROR_PIPE_CONNECTED && rv != ERROR_IO_PENDING)
goto getout;
}
if (!WaitAndHandleWrites(INFINITE)) {
CancelIo(pipe_);
write_overlapped_active_ = false;
goto getout;
}
result = true;
getout:
return result;
}
bool PipeMessageHandler::WaitAndHandleWrites(int delay) {
DWORD rv;
assert(thread_id_ == GetCurrentThreadId());
again:
rv = WaitForMultipleObjects(2 + write_overlapped_active_, wait_handles_, FALSE, delay);
if (rv == WAIT_OBJECT_0 + 2) {
assert(write_overlapped_active_);
write_overlapped_active_ = false;
// Remove the packet from the front of the queue, now
// that it was sent.
packets_mutex_.Acquire();
OutgoingPacket *p = packets_;
if ((packets_ = p->next) == NULL)
packets_end_ = &packets_;
packets_mutex_.Release();
free(p);
SendNextQueuedWrite(); SendNextQueuedWrite();
goto again;
}
if (rv == WAIT_OBJECT_0 + 1) {
if (exit_ || !delegate_->HandleNotify())
return false;
SendNextQueuedWrite(); for (;;) {
goto again; memset(&read_overlapped_, 0, sizeof(read_overlapped_));
} read_overlapped_.hEvent = wait_handles_[0];
return rv == WAIT_OBJECT_0; if (!ReadFile(pipe_, NULL, 0, NULL, &read_overlapped_)) {
rv = GetLastError();
if (rv != ERROR_IO_PENDING && rv != ERROR_MORE_DATA)
break;
}
TS_WAIT_POINT(kStateWaitReadLength);
PeekNamedPipe(pipe_, NULL, 0, NULL, &tmp_packet_size_, NULL);
if (tmp_packet_size_ == 0)
break;
free(tmp_packet_buf_);
tmp_packet_buf_ = (uint8*)malloc(tmp_packet_size_);
if (!tmp_packet_buf_)
break;
memset(&read_overlapped_, 0, sizeof(read_overlapped_));
read_overlapped_.hEvent = wait_handles_[0];
if (!ReadFile(pipe_, tmp_packet_buf_, tmp_packet_size_, NULL, &read_overlapped_)) {
rv = GetLastError();
if (rv != ERROR_IO_PENDING)
break;
}
TS_WAIT_POINT(kStateWaitReadPayload);
bytes_read = (uint32)read_overlapped_.InternalHigh;
if (bytes_read == 0)
break;
if (!delegate_->HandleMessage(tmp_packet_buf_[0], tmp_packet_buf_ + 1, bytes_read - 1)) {
ResetEvent(wait_handles_[0]);
TS_WAIT_POINT(kStateWaitTimeout);
break;
}
}
if (exit_thread_)
break;
delegate_->HandleDisconnect();
if (!is_server_pipe_)
break;
ClosePipe();
}
TS_WAIT_END()
ClosePipe();
} }
DWORD WINAPI PipeMessageHandler::StaticThreadMain(void *x) { DWORD WINAPI PipeMessageHandler::StaticThreadMain(void *x) {
@ -630,73 +622,49 @@ bool PipeMessageHandler::VerifyThread() {
DWORD PipeMessageHandler::ThreadMain() { DWORD PipeMessageHandler::ThreadMain() {
assert((thread_id_ = GetCurrentThreadId()) != 0); assert((thread_id_ = GetCurrentThreadId()) != 0);
assert(state_ == kStateNone);
while (!exit_) { AdvanceStateMachine();
// Create a named pipe and wait for connections from the UI process
if (is_server_pipe_) { for(;;) {
if (!InitializeServerPipe()) { DWORD rv = WaitForMultipleObjects(3, wait_handles_, FALSE, (state_ == kStateWaitTimeout) ? 1000 : INFINITE);
if (!exit_)
ExitProcess(1); // packet write finished?
if (rv == WAIT_OBJECT_0 + 2) {
assert(write_overlapped_active_);
write_overlapped_active_ = false;
// Remove the packet from the front of the queue, now that it was sent.
packets_mutex_.Acquire();
OutgoingPacket *p = packets_;
if ((packets_ = p->next) == NULL)
packets_end_ = &packets_;
packets_mutex_.Release();
free(p);
SendNextQueuedWrite();
// notification
} else if (rv == WAIT_OBJECT_0 + 1) {
if (exit_thread_ || !delegate_->HandleNotify())
break; break;
} // The notification event is set when there might be new messages to send,
// Wait for a client to connect to us. // so try to send them.
if (!ConnectNamedPipeAsync()) { SendNextQueuedWrite();
if (!exit_)
ExitProcess(1); // read finished?
break; } else if (rv == WAIT_OBJECT_0) {
} AdvanceStateMachine();
} else if (rv == WAIT_TIMEOUT) {
if (state_ == kStateWaitTimeout)
AdvanceStateMachine();
} else { } else {
if (!InitializeClientPipe()) { assert(0);
RINFO("Unable to connect to the TunSafe Service. Please make sure it's running.");
break;
}
} }
connection_established_ = true;
if (!delegate_->HandleNewConnection())
goto closepipe;
SendNextQueuedWrite();
// Read/Process each message
for (;;) {
size_t message_size;
uint8 *message = ReadNamedPipeAsync(&message_size);
if (!message)
break;
if (message_size) {
if (!delegate_->HandleMessage(message[0], message + 1, message_size - 1)) {
FlushWrites(1000);
break;
}
}
free(message);
}
if (exit_)
break;
delegate_->HandleDisconnect();
if (!is_server_pipe_)
break;
closepipe:
ClosePipe();
} }
ClosePipe();
return 0; return 0;
} }
void PipeMessageHandler::FlushWrites(int delay) {
ResetEvent(wait_handles_[0]);
WaitAndHandleWrites(1000);
}
bool PipeMessageHandler::StartThread() { bool PipeMessageHandler::StartThread() {
DWORD thread_id; DWORD thread_id;
assert(thread_ == NULL); assert(thread_ == NULL);
@ -706,7 +674,7 @@ bool PipeMessageHandler::StartThread() {
void PipeMessageHandler::StopThread() { void PipeMessageHandler::StopThread() {
if (thread_ != NULL) { if (thread_ != NULL) {
exit_ = true; exit_thread_ = true;
SetEvent(wait_handles_[1]); SetEvent(wait_handles_[1]);
WaitForSingleObject(thread_, INFINITE); WaitForSingleObject(thread_, INFINITE);
CloseHandle(thread_); CloseHandle(thread_);
@ -749,8 +717,6 @@ static wchar_t *RegReadStrW(HKEY hkey, const wchar_t *key, const wchar_t *def) {
} }
unsigned TunsafeServiceImpl::OnStart(int argc, wchar_t **argv) { unsigned TunsafeServiceImpl::OnStart(int argc, wchar_t **argv) {
message_handler_.StartThread();
uint32 service_flags = RegReadInt(hkey_, "ServiceStartupFlags", 0); uint32 service_flags = RegReadInt(hkey_, "ServiceStartupFlags", 0);
if ( (service_flags & kStartupFlag_BackgroundService) && (service_flags & kStartupFlag_ConnectWhenWindowsStarts) ) { if ( (service_flags & kStartupFlag_BackgroundService) && (service_flags & kStartupFlag_ConnectWhenWindowsStarts) ) {
char *conf = RegReadStr(hkey_, "LastUsedConfigFile", ""); char *conf = RegReadStr(hkey_, "LastUsedConfigFile", "");
@ -761,6 +727,7 @@ unsigned TunsafeServiceImpl::OnStart(int argc, wchar_t **argv) {
free(conf); free(conf);
} }
message_handler_.StartThread();
return 0; return 0;
} }
@ -870,11 +837,10 @@ bool TunsafeServiceImpl::HandleNotify() {
return true; return true;
} }
bool TunsafeServiceImpl::HandleNewConnection() { void TunsafeServiceImpl::HandleNewConnection() {
did_send_getstate_ = false; did_send_getstate_ = false;
did_authenticate_user_ = false; did_authenticate_user_ = false;
last_line_sent_ = 0; last_line_sent_ = 0;
return true;
} }
void TunsafeServiceImpl::HandleDisconnect() { void TunsafeServiceImpl::HandleDisconnect() {
@ -1151,11 +1117,10 @@ bool TunsafeServiceClient::HandleNotify() {
} }
bool TunsafeServiceClient::HandleNewConnection() { void TunsafeServiceClient::HandleNewConnection() {
message_handler_.WritePacket(SERVICE_REQ_LOGIN, (uint8*)&kTunsafeServiceProtocolVersion, 8); message_handler_.WritePacket(SERVICE_REQ_LOGIN, (uint8*)&kTunsafeServiceProtocolVersion, 8);
if (want_stats_) if (want_stats_)
message_handler_.WritePacket(SERVICE_REQ_GETSTATS, &want_stats_, 1); message_handler_.WritePacket(SERVICE_REQ_GETSTATS, &want_stats_, 1);
return true;
} }
void TunsafeServiceClient::HandleDisconnect() { void TunsafeServiceClient::HandleDisconnect() {

View file

@ -28,7 +28,7 @@ public:
public: public:
virtual bool HandleMessage(int type, uint8 *data, size_t size) = 0; virtual bool HandleMessage(int type, uint8 *data, size_t size) = 0;
virtual bool HandleNotify() = 0; virtual bool HandleNotify() = 0;
virtual bool HandleNewConnection() = 0; virtual void HandleNewConnection() = 0;
virtual void HandleDisconnect() = 0; virtual void HandleDisconnect() = 0;
}; };
@ -45,17 +45,14 @@ public:
bool VerifyThread(); bool VerifyThread();
void FlushWrites(int delay);
bool is_connected() { return connection_established_; } bool is_connected() { return connection_established_; }
private: private:
bool InitializeServerPipe(); bool InitializeServerPipeAndWait();
bool InitializeClientPipe(); bool InitializeClientPipe();
void AdvanceStateMachine();
void ClosePipe(); void ClosePipe();
DWORD ThreadMain(); DWORD ThreadMain();
void SendNextQueuedWrite(); void SendNextQueuedWrite();
uint8 *ReadNamedPipeAsync(size_t *packet_size);
bool ConnectNamedPipeAsync();
bool WaitAndHandleWrites(int delay);
static DWORD WINAPI StaticThreadMain(void *x); static DWORD WINAPI StaticThreadMain(void *x);
Delegate *delegate_; Delegate *delegate_;
@ -63,26 +60,38 @@ private:
HANDLE pipe_; HANDLE pipe_;
HANDLE thread_; HANDLE thread_;
HANDLE wait_handles_[3]; HANDLE wait_handles_[3];
OVERLAPPED write_overlapped_;
bool write_overlapped_active_; bool write_overlapped_active_;
bool exit_; bool exit_thread_;
bool is_server_pipe_; bool is_server_pipe_;
bool connection_established_; bool connection_established_;
char *pipe_name_; char *pipe_name_;
enum State {
kStateNone,
kStateWaitConnect,
kStateWaitReadLength,
kStateWaitReadPayload,
kStateWaitTimeout,
};
int state_;
struct OutgoingPacket { struct OutgoingPacket {
OutgoingPacket *next; OutgoingPacket *next;
uint32 size; uint32 size;
uint8 data[0]; uint8 data[0];
}; };
OutgoingPacket *packets_, **packets_end_; OutgoingPacket *packets_, **packets_end_;
uint8 *tmp_packet_buf_;
DWORD tmp_packet_size_;
OVERLAPPED write_overlapped_, read_overlapped_;
Mutex packets_mutex_; Mutex packets_mutex_;
DWORD thread_id_; DWORD thread_id_;
}; };
class TunsafeServiceImpl : public TunsafeBackend::Delegate, public PipeMessageHandler::Delegate { class TunsafeServiceImpl : public TunsafeBackend::Delegate, public PipeMessageHandler::Delegate {
public: public:
TunsafeServiceImpl(); TunsafeServiceImpl();
@ -99,7 +108,7 @@ public:
// -- from PipeMessageHandler::Delegate // -- from PipeMessageHandler::Delegate
virtual bool HandleMessage(int type, uint8 *data, size_t size); virtual bool HandleMessage(int type, uint8 *data, size_t size);
virtual bool HandleNotify(); virtual bool HandleNotify();
virtual bool HandleNewConnection(); virtual void HandleNewConnection();
virtual void HandleDisconnect(); virtual void HandleDisconnect();
// virtual methods // virtual methods
@ -155,7 +164,7 @@ public:
// -- from PipeMessageHandler::Delegate // -- from PipeMessageHandler::Delegate
virtual bool HandleMessage(int type, uint8 *data, size_t size); virtual bool HandleMessage(int type, uint8 *data, size_t size);
virtual bool HandleNotify(); virtual bool HandleNotify();
virtual bool HandleNewConnection(); virtual void HandleNewConnection();
virtual void HandleDisconnect(); virtual void HandleDisconnect();
protected: protected: