Update service code to use a state machine instead of blocking

This commit is contained in:
Ludvig Strigeus 2018-09-11 19:33:38 +02:00
parent 64078ee051
commit e499a3d4f7
2 changed files with 151 additions and 177 deletions

View file

@ -401,9 +401,11 @@ PipeMessageHandler::PipeMessageHandler(const char *pipe_name, bool is_server_pip
thread_ = NULL;
packets_end_ = &packets_;
write_overlapped_active_ = false;
exit_ = false;
exit_thread_ = false;
connection_established_ = false;
thread_id_ = 0;
state_ = kStateNone;
tmp_packet_buf_ = NULL;
}
PipeMessageHandler::~PipeMessageHandler() {
@ -414,7 +416,7 @@ PipeMessageHandler::~PipeMessageHandler() {
free(pipe_name_);
}
bool PipeMessageHandler::InitializeServerPipe() {
bool PipeMessageHandler::InitializeServerPipeAndWait() {
int BUFSIZE = 2048;
SECURITY_ATTRIBUTES saPipeSecurity = {0};
uint8 buf[SECURITY_DESCRIPTOR_MIN_LENGTH];
@ -437,15 +439,23 @@ bool PipeMessageHandler::InitializeServerPipe() {
PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS | PIPE_WAIT,
PIPE_UNLIMITED_INSTANCES,
BUFSIZE, BUFSIZE, 0, &saPipeSecurity);
return pipe_ != INVALID_HANDLE_VALUE;
if (pipe_ == INVALID_HANDLE_VALUE)
return false;
memset(&read_overlapped_, 0, sizeof(read_overlapped_));
read_overlapped_.hEvent = wait_handles_[0];
if (!ConnectNamedPipe(pipe_, &read_overlapped_)) {
DWORD rv = GetLastError();
if (rv != ERROR_PIPE_CONNECTED && rv != ERROR_IO_PENDING)
return false;
}
return true;
}
bool PipeMessageHandler::InitializeClientPipe() {
assert(pipe_ == INVALID_HANDLE_VALUE);
pipe_ = CreateFile(
pipe_name_,
GENERIC_READ | GENERIC_WRITE, 0,
NULL, OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL);
pipe_ = CreateFile(pipe_name_, GENERIC_READ | GENERIC_WRITE, 0, NULL,
OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL);
if (pipe_ == INVALID_HANDLE_VALUE)
return false;
DWORD mode = PIPE_READMODE_MESSAGE;
@ -462,6 +472,12 @@ void PipeMessageHandler::ClosePipe() {
connection_established_ = false;
write_overlapped_active_ = false;
free(tmp_packet_buf_);
tmp_packet_buf_ = NULL;
ResetEvent(wait_handles_[0]);
ResetEvent(wait_handles_[2]);
packets_mutex_.Acquire();
OutgoingPacket *packets = packets_;
packets_ = NULL;
@ -517,107 +533,83 @@ void PipeMessageHandler::SendNextQueuedWrite() {
write_overlapped_.hEvent = wait_handles_[2];
if (WriteFile(pipe_, p->data, p->size, NULL, &write_overlapped_) || GetLastError() == ERROR_IO_PENDING)
write_overlapped_active_ = true;
} else {
ResetEvent(wait_handles_[2]);
}
}
}
uint8 *PipeMessageHandler::ReadNamedPipeAsync(size_t *packet_size) {
OVERLAPPED ov = {0};
uint8 *result = NULL;
DWORD bytes_waiting = 0;
DWORD rv;
ov.hEvent = wait_handles_[0];
if (!ReadFile(pipe_, NULL, 0, NULL, &ov)) {
rv = GetLastError();
if (rv != ERROR_IO_PENDING && rv != ERROR_MORE_DATA)
goto getout;
}
#define TS_WAIT_BEGIN(t) switch(state_) { case t:
#define TS_WAIT_POINT(t) state_ = (t); return; case t:
#define TS_WAIT_END() }
if (!WaitAndHandleWrites(INFINITE)) {
CancelIo(pipe_);
write_overlapped_active_ = false;
goto getout;
}
void PipeMessageHandler::AdvanceStateMachine() {
DWORD rv, bytes_read;
PeekNamedPipe(pipe_, NULL, 0, NULL, &bytes_waiting, NULL);
if (bytes_waiting == 0)
goto getout; // this is typically what happens when pipe closes.
result = (uint8*)malloc(bytes_waiting);
if (!result)
goto getout;
if (!ReadFile(pipe_, result, bytes_waiting, NULL, &ov)) {
rv = GetLastError();
if (rv != ERROR_IO_PENDING)
goto getout;
}
if (!WaitAndHandleWrites(1000)) {
CancelIo(pipe_);
write_overlapped_active_ = false;
free(result);
result = NULL;
goto getout;
}
bytes_waiting = (uint32)ov.InternalHigh;
if (bytes_waiting == 0) {
free(result);
result = NULL;
goto getout;
}
*packet_size = bytes_waiting;
getout:
return result;
}
bool PipeMessageHandler::ConnectNamedPipeAsync() {
OVERLAPPED ov = {0};
DWORD rv;
bool result = false;
ov.hEvent = wait_handles_[0];
if (!ConnectNamedPipe(pipe_, &ov)) {
rv = GetLastError();
if (rv != ERROR_PIPE_CONNECTED && rv != ERROR_IO_PENDING)
goto getout;
}
if (!WaitAndHandleWrites(INFINITE)) {
CancelIo(pipe_);
write_overlapped_active_ = false;
goto getout;
}
result = true;
getout:
return result;
}
bool PipeMessageHandler::WaitAndHandleWrites(int delay) {
DWORD rv;
assert(thread_id_ == GetCurrentThreadId());
again:
rv = WaitForMultipleObjects(2 + write_overlapped_active_, wait_handles_, FALSE, delay);
if (rv == WAIT_OBJECT_0 + 2) {
assert(write_overlapped_active_);
write_overlapped_active_ = false;
// Remove the packet from the front of the queue, now
// that it was sent.
packets_mutex_.Acquire();
OutgoingPacket *p = packets_;
if ((packets_ = p->next) == NULL)
packets_end_ = &packets_;
packets_mutex_.Release();
free(p);
TS_WAIT_BEGIN(kStateNone)
for(;;) {
// Create a named pipe and wait for connections from the UI process
if (is_server_pipe_) {
if (!InitializeServerPipeAndWait()) {
if (!exit_thread_)
ExitProcess(1);
break;
}
TS_WAIT_POINT(kStateWaitConnect);
} else {
if (!InitializeClientPipe()) {
RINFO("Unable to connect to the TunSafe Service. Please make sure it's running.");
break;
}
}
connection_established_ = true;
delegate_->HandleNewConnection();
SendNextQueuedWrite();
goto again;
}
if (rv == WAIT_OBJECT_0 + 1) {
if (exit_ || !delegate_->HandleNotify())
return false;
SendNextQueuedWrite();
goto again;
}
return rv == WAIT_OBJECT_0;
for (;;) {
memset(&read_overlapped_, 0, sizeof(read_overlapped_));
read_overlapped_.hEvent = wait_handles_[0];
if (!ReadFile(pipe_, NULL, 0, NULL, &read_overlapped_)) {
rv = GetLastError();
if (rv != ERROR_IO_PENDING && rv != ERROR_MORE_DATA)
break;
}
TS_WAIT_POINT(kStateWaitReadLength);
PeekNamedPipe(pipe_, NULL, 0, NULL, &tmp_packet_size_, NULL);
if (tmp_packet_size_ == 0)
break;
free(tmp_packet_buf_);
tmp_packet_buf_ = (uint8*)malloc(tmp_packet_size_);
if (!tmp_packet_buf_)
break;
memset(&read_overlapped_, 0, sizeof(read_overlapped_));
read_overlapped_.hEvent = wait_handles_[0];
if (!ReadFile(pipe_, tmp_packet_buf_, tmp_packet_size_, NULL, &read_overlapped_)) {
rv = GetLastError();
if (rv != ERROR_IO_PENDING)
break;
}
TS_WAIT_POINT(kStateWaitReadPayload);
bytes_read = (uint32)read_overlapped_.InternalHigh;
if (bytes_read == 0)
break;
if (!delegate_->HandleMessage(tmp_packet_buf_[0], tmp_packet_buf_ + 1, bytes_read - 1)) {
ResetEvent(wait_handles_[0]);
TS_WAIT_POINT(kStateWaitTimeout);
break;
}
}
if (exit_thread_)
break;
delegate_->HandleDisconnect();
if (!is_server_pipe_)
break;
ClosePipe();
}
TS_WAIT_END()
ClosePipe();
}
DWORD WINAPI PipeMessageHandler::StaticThreadMain(void *x) {
@ -630,73 +622,49 @@ bool PipeMessageHandler::VerifyThread() {
DWORD PipeMessageHandler::ThreadMain() {
assert((thread_id_ = GetCurrentThreadId()) != 0);
assert(state_ == kStateNone);
while (!exit_) {
// Create a named pipe and wait for connections from the UI process
if (is_server_pipe_) {
if (!InitializeServerPipe()) {
if (!exit_)
ExitProcess(1);
AdvanceStateMachine();
for(;;) {
DWORD rv = WaitForMultipleObjects(3, wait_handles_, FALSE, (state_ == kStateWaitTimeout) ? 1000 : INFINITE);
// packet write finished?
if (rv == WAIT_OBJECT_0 + 2) {
assert(write_overlapped_active_);
write_overlapped_active_ = false;
// Remove the packet from the front of the queue, now that it was sent.
packets_mutex_.Acquire();
OutgoingPacket *p = packets_;
if ((packets_ = p->next) == NULL)
packets_end_ = &packets_;
packets_mutex_.Release();
free(p);
SendNextQueuedWrite();
// notification
} else if (rv == WAIT_OBJECT_0 + 1) {
if (exit_thread_ || !delegate_->HandleNotify())
break;
}
// Wait for a client to connect to us.
if (!ConnectNamedPipeAsync()) {
if (!exit_)
ExitProcess(1);
break;
}
// The notification event is set when there might be new messages to send,
// so try to send them.
SendNextQueuedWrite();
// read finished?
} else if (rv == WAIT_OBJECT_0) {
AdvanceStateMachine();
} else if (rv == WAIT_TIMEOUT) {
if (state_ == kStateWaitTimeout)
AdvanceStateMachine();
} else {
if (!InitializeClientPipe()) {
RINFO("Unable to connect to the TunSafe Service. Please make sure it's running.");
break;
}
assert(0);
}
connection_established_ = true;
if (!delegate_->HandleNewConnection())
goto closepipe;
SendNextQueuedWrite();
// Read/Process each message
for (;;) {
size_t message_size;
uint8 *message = ReadNamedPipeAsync(&message_size);
if (!message)
break;
if (message_size) {
if (!delegate_->HandleMessage(message[0], message + 1, message_size - 1)) {
FlushWrites(1000);
break;
}
}
free(message);
}
if (exit_)
break;
delegate_->HandleDisconnect();
if (!is_server_pipe_)
break;
closepipe:
ClosePipe();
}
ClosePipe();
return 0;
}
void PipeMessageHandler::FlushWrites(int delay) {
ResetEvent(wait_handles_[0]);
WaitAndHandleWrites(1000);
}
bool PipeMessageHandler::StartThread() {
DWORD thread_id;
assert(thread_ == NULL);
@ -706,7 +674,7 @@ bool PipeMessageHandler::StartThread() {
void PipeMessageHandler::StopThread() {
if (thread_ != NULL) {
exit_ = true;
exit_thread_ = true;
SetEvent(wait_handles_[1]);
WaitForSingleObject(thread_, INFINITE);
CloseHandle(thread_);
@ -749,8 +717,6 @@ static wchar_t *RegReadStrW(HKEY hkey, const wchar_t *key, const wchar_t *def) {
}
unsigned TunsafeServiceImpl::OnStart(int argc, wchar_t **argv) {
message_handler_.StartThread();
uint32 service_flags = RegReadInt(hkey_, "ServiceStartupFlags", 0);
if ( (service_flags & kStartupFlag_BackgroundService) && (service_flags & kStartupFlag_ConnectWhenWindowsStarts) ) {
char *conf = RegReadStr(hkey_, "LastUsedConfigFile", "");
@ -761,6 +727,7 @@ unsigned TunsafeServiceImpl::OnStart(int argc, wchar_t **argv) {
free(conf);
}
message_handler_.StartThread();
return 0;
}
@ -870,11 +837,10 @@ bool TunsafeServiceImpl::HandleNotify() {
return true;
}
bool TunsafeServiceImpl::HandleNewConnection() {
void TunsafeServiceImpl::HandleNewConnection() {
did_send_getstate_ = false;
did_authenticate_user_ = false;
last_line_sent_ = 0;
return true;
}
void TunsafeServiceImpl::HandleDisconnect() {
@ -1151,11 +1117,10 @@ bool TunsafeServiceClient::HandleNotify() {
}
bool TunsafeServiceClient::HandleNewConnection() {
void TunsafeServiceClient::HandleNewConnection() {
message_handler_.WritePacket(SERVICE_REQ_LOGIN, (uint8*)&kTunsafeServiceProtocolVersion, 8);
if (want_stats_)
message_handler_.WritePacket(SERVICE_REQ_GETSTATS, &want_stats_, 1);
return true;
}
void TunsafeServiceClient::HandleDisconnect() {

View file

@ -28,7 +28,7 @@ public:
public:
virtual bool HandleMessage(int type, uint8 *data, size_t size) = 0;
virtual bool HandleNotify() = 0;
virtual bool HandleNewConnection() = 0;
virtual void HandleNewConnection() = 0;
virtual void HandleDisconnect() = 0;
};
@ -45,17 +45,14 @@ public:
bool VerifyThread();
void FlushWrites(int delay);
bool is_connected() { return connection_established_; }
private:
bool InitializeServerPipe();
bool InitializeServerPipeAndWait();
bool InitializeClientPipe();
void AdvanceStateMachine();
void ClosePipe();
DWORD ThreadMain();
void SendNextQueuedWrite();
uint8 *ReadNamedPipeAsync(size_t *packet_size);
bool ConnectNamedPipeAsync();
bool WaitAndHandleWrites(int delay);
static DWORD WINAPI StaticThreadMain(void *x);
Delegate *delegate_;
@ -63,26 +60,38 @@ private:
HANDLE pipe_;
HANDLE thread_;
HANDLE wait_handles_[3];
OVERLAPPED write_overlapped_;
bool write_overlapped_active_;
bool exit_;
bool exit_thread_;
bool is_server_pipe_;
bool connection_established_;
char *pipe_name_;
enum State {
kStateNone,
kStateWaitConnect,
kStateWaitReadLength,
kStateWaitReadPayload,
kStateWaitTimeout,
};
int state_;
struct OutgoingPacket {
OutgoingPacket *next;
uint32 size;
uint8 data[0];
};
OutgoingPacket *packets_, **packets_end_;
uint8 *tmp_packet_buf_;
DWORD tmp_packet_size_;
OVERLAPPED write_overlapped_, read_overlapped_;
Mutex packets_mutex_;
DWORD thread_id_;
};
class TunsafeServiceImpl : public TunsafeBackend::Delegate, public PipeMessageHandler::Delegate {
public:
TunsafeServiceImpl();
@ -99,7 +108,7 @@ public:
// -- from PipeMessageHandler::Delegate
virtual bool HandleMessage(int type, uint8 *data, size_t size);
virtual bool HandleNotify();
virtual bool HandleNewConnection();
virtual void HandleNewConnection();
virtual void HandleDisconnect();
// virtual methods
@ -155,7 +164,7 @@ public:
// -- from PipeMessageHandler::Delegate
virtual bool HandleMessage(int type, uint8 *data, size_t size);
virtual bool HandleNotify();
virtual bool HandleNewConnection();
virtual void HandleNewConnection();
virtual void HandleDisconnect();
protected: