diff --git a/service_win32.cpp b/service_win32.cpp index 002abc9..48866ff 100644 --- a/service_win32.cpp +++ b/service_win32.cpp @@ -401,9 +401,11 @@ PipeMessageHandler::PipeMessageHandler(const char *pipe_name, bool is_server_pip thread_ = NULL; packets_end_ = &packets_; write_overlapped_active_ = false; - exit_ = false; + exit_thread_ = false; connection_established_ = false; thread_id_ = 0; + state_ = kStateNone; + tmp_packet_buf_ = NULL; } PipeMessageHandler::~PipeMessageHandler() { @@ -414,7 +416,7 @@ PipeMessageHandler::~PipeMessageHandler() { free(pipe_name_); } -bool PipeMessageHandler::InitializeServerPipe() { +bool PipeMessageHandler::InitializeServerPipeAndWait() { int BUFSIZE = 2048; SECURITY_ATTRIBUTES saPipeSecurity = {0}; uint8 buf[SECURITY_DESCRIPTOR_MIN_LENGTH]; @@ -437,15 +439,23 @@ bool PipeMessageHandler::InitializeServerPipe() { PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS | PIPE_WAIT, PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE, 0, &saPipeSecurity); - return pipe_ != INVALID_HANDLE_VALUE; + if (pipe_ == INVALID_HANDLE_VALUE) + return false; + + memset(&read_overlapped_, 0, sizeof(read_overlapped_)); + read_overlapped_.hEvent = wait_handles_[0]; + if (!ConnectNamedPipe(pipe_, &read_overlapped_)) { + DWORD rv = GetLastError(); + if (rv != ERROR_PIPE_CONNECTED && rv != ERROR_IO_PENDING) + return false; + } + return true; } bool PipeMessageHandler::InitializeClientPipe() { assert(pipe_ == INVALID_HANDLE_VALUE); - pipe_ = CreateFile( - pipe_name_, - GENERIC_READ | GENERIC_WRITE, 0, - NULL, OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL); + pipe_ = CreateFile(pipe_name_, GENERIC_READ | GENERIC_WRITE, 0, NULL, + OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL); if (pipe_ == INVALID_HANDLE_VALUE) return false; DWORD mode = PIPE_READMODE_MESSAGE; @@ -462,6 +472,12 @@ void PipeMessageHandler::ClosePipe() { connection_established_ = false; write_overlapped_active_ = false; + free(tmp_packet_buf_); + tmp_packet_buf_ = NULL; + + ResetEvent(wait_handles_[0]); + ResetEvent(wait_handles_[2]); + packets_mutex_.Acquire(); OutgoingPacket *packets = packets_; packets_ = NULL; @@ -517,107 +533,83 @@ void PipeMessageHandler::SendNextQueuedWrite() { write_overlapped_.hEvent = wait_handles_[2]; if (WriteFile(pipe_, p->data, p->size, NULL, &write_overlapped_) || GetLastError() == ERROR_IO_PENDING) write_overlapped_active_ = true; + } else { + ResetEvent(wait_handles_[2]); } } } -uint8 *PipeMessageHandler::ReadNamedPipeAsync(size_t *packet_size) { - OVERLAPPED ov = {0}; - uint8 *result = NULL; - DWORD bytes_waiting = 0; - DWORD rv; - ov.hEvent = wait_handles_[0]; - if (!ReadFile(pipe_, NULL, 0, NULL, &ov)) { - rv = GetLastError(); - if (rv != ERROR_IO_PENDING && rv != ERROR_MORE_DATA) - goto getout; - } +#define TS_WAIT_BEGIN(t) switch(state_) { case t: +#define TS_WAIT_POINT(t) state_ = (t); return; case t: +#define TS_WAIT_END() } - if (!WaitAndHandleWrites(INFINITE)) { - CancelIo(pipe_); - write_overlapped_active_ = false; - goto getout; - } +void PipeMessageHandler::AdvanceStateMachine() { + DWORD rv, bytes_read; - PeekNamedPipe(pipe_, NULL, 0, NULL, &bytes_waiting, NULL); - if (bytes_waiting == 0) - goto getout; // this is typically what happens when pipe closes. - - result = (uint8*)malloc(bytes_waiting); - if (!result) - goto getout; - - if (!ReadFile(pipe_, result, bytes_waiting, NULL, &ov)) { - rv = GetLastError(); - if (rv != ERROR_IO_PENDING) - goto getout; - } - if (!WaitAndHandleWrites(1000)) { - CancelIo(pipe_); - write_overlapped_active_ = false; - free(result); - result = NULL; - goto getout; - } - bytes_waiting = (uint32)ov.InternalHigh; - if (bytes_waiting == 0) { - free(result); - result = NULL; - goto getout; - } - *packet_size = bytes_waiting; -getout: - return result; -} - -bool PipeMessageHandler::ConnectNamedPipeAsync() { - OVERLAPPED ov = {0}; - DWORD rv; - bool result = false; - ov.hEvent = wait_handles_[0]; - if (!ConnectNamedPipe(pipe_, &ov)) { - rv = GetLastError(); - if (rv != ERROR_PIPE_CONNECTED && rv != ERROR_IO_PENDING) - goto getout; - } - if (!WaitAndHandleWrites(INFINITE)) { - CancelIo(pipe_); - write_overlapped_active_ = false; - goto getout; - } - result = true; -getout: - return result; -} - -bool PipeMessageHandler::WaitAndHandleWrites(int delay) { - DWORD rv; - assert(thread_id_ == GetCurrentThreadId()); - -again: - rv = WaitForMultipleObjects(2 + write_overlapped_active_, wait_handles_, FALSE, delay); - if (rv == WAIT_OBJECT_0 + 2) { - assert(write_overlapped_active_); - write_overlapped_active_ = false; - // Remove the packet from the front of the queue, now - // that it was sent. - packets_mutex_.Acquire(); - OutgoingPacket *p = packets_; - if ((packets_ = p->next) == NULL) - packets_end_ = &packets_; - packets_mutex_.Release(); - free(p); + TS_WAIT_BEGIN(kStateNone) + for(;;) { + // Create a named pipe and wait for connections from the UI process + if (is_server_pipe_) { + if (!InitializeServerPipeAndWait()) { + if (!exit_thread_) + ExitProcess(1); + break; + } + TS_WAIT_POINT(kStateWaitConnect); + } else { + if (!InitializeClientPipe()) { + RINFO("Unable to connect to the TunSafe Service. Please make sure it's running."); + break; + } + } + connection_established_ = true; + delegate_->HandleNewConnection(); SendNextQueuedWrite(); - goto again; - } - if (rv == WAIT_OBJECT_0 + 1) { - if (exit_ || !delegate_->HandleNotify()) - return false; - SendNextQueuedWrite(); - goto again; - } - return rv == WAIT_OBJECT_0; + for (;;) { + memset(&read_overlapped_, 0, sizeof(read_overlapped_)); + read_overlapped_.hEvent = wait_handles_[0]; + if (!ReadFile(pipe_, NULL, 0, NULL, &read_overlapped_)) { + rv = GetLastError(); + if (rv != ERROR_IO_PENDING && rv != ERROR_MORE_DATA) + break; + } + TS_WAIT_POINT(kStateWaitReadLength); + PeekNamedPipe(pipe_, NULL, 0, NULL, &tmp_packet_size_, NULL); + if (tmp_packet_size_ == 0) + break; + + free(tmp_packet_buf_); + tmp_packet_buf_ = (uint8*)malloc(tmp_packet_size_); + if (!tmp_packet_buf_) + break; + + memset(&read_overlapped_, 0, sizeof(read_overlapped_)); + read_overlapped_.hEvent = wait_handles_[0]; + if (!ReadFile(pipe_, tmp_packet_buf_, tmp_packet_size_, NULL, &read_overlapped_)) { + rv = GetLastError(); + if (rv != ERROR_IO_PENDING) + break; + } + TS_WAIT_POINT(kStateWaitReadPayload); + bytes_read = (uint32)read_overlapped_.InternalHigh; + if (bytes_read == 0) + break; + if (!delegate_->HandleMessage(tmp_packet_buf_[0], tmp_packet_buf_ + 1, bytes_read - 1)) { + ResetEvent(wait_handles_[0]); + TS_WAIT_POINT(kStateWaitTimeout); + break; + } + } + if (exit_thread_) + break; + delegate_->HandleDisconnect(); + if (!is_server_pipe_) + break; + ClosePipe(); + } + TS_WAIT_END() + ClosePipe(); } DWORD WINAPI PipeMessageHandler::StaticThreadMain(void *x) { @@ -630,73 +622,49 @@ bool PipeMessageHandler::VerifyThread() { DWORD PipeMessageHandler::ThreadMain() { assert((thread_id_ = GetCurrentThreadId()) != 0); + assert(state_ == kStateNone); - while (!exit_) { - // Create a named pipe and wait for connections from the UI process - if (is_server_pipe_) { - if (!InitializeServerPipe()) { - if (!exit_) - ExitProcess(1); + AdvanceStateMachine(); + + for(;;) { + DWORD rv = WaitForMultipleObjects(3, wait_handles_, FALSE, (state_ == kStateWaitTimeout) ? 1000 : INFINITE); + + // packet write finished? + if (rv == WAIT_OBJECT_0 + 2) { + assert(write_overlapped_active_); + + write_overlapped_active_ = false; + + // Remove the packet from the front of the queue, now that it was sent. + packets_mutex_.Acquire(); + OutgoingPacket *p = packets_; + if ((packets_ = p->next) == NULL) + packets_end_ = &packets_; + packets_mutex_.Release(); + free(p); + SendNextQueuedWrite(); + + // notification + } else if (rv == WAIT_OBJECT_0 + 1) { + if (exit_thread_ || !delegate_->HandleNotify()) break; - } - // Wait for a client to connect to us. - if (!ConnectNamedPipeAsync()) { - if (!exit_) - ExitProcess(1); - break; - } + // The notification event is set when there might be new messages to send, + // so try to send them. + SendNextQueuedWrite(); + + // read finished? + } else if (rv == WAIT_OBJECT_0) { + AdvanceStateMachine(); + } else if (rv == WAIT_TIMEOUT) { + if (state_ == kStateWaitTimeout) + AdvanceStateMachine(); } else { - if (!InitializeClientPipe()) { - RINFO("Unable to connect to the TunSafe Service. Please make sure it's running."); - break; - } + assert(0); } - - connection_established_ = true; - if (!delegate_->HandleNewConnection()) - goto closepipe; - - SendNextQueuedWrite(); - - // Read/Process each message - for (;;) { - size_t message_size; - uint8 *message = ReadNamedPipeAsync(&message_size); - if (!message) - break; - - if (message_size) { - if (!delegate_->HandleMessage(message[0], message + 1, message_size - 1)) { - FlushWrites(1000); - break; - } - } - free(message); - } - - if (exit_) - break; - - delegate_->HandleDisconnect(); - - if (!is_server_pipe_) - break; - -closepipe: - ClosePipe(); } - - - ClosePipe(); - return 0; } -void PipeMessageHandler::FlushWrites(int delay) { - ResetEvent(wait_handles_[0]); - WaitAndHandleWrites(1000); -} - bool PipeMessageHandler::StartThread() { DWORD thread_id; assert(thread_ == NULL); @@ -706,7 +674,7 @@ bool PipeMessageHandler::StartThread() { void PipeMessageHandler::StopThread() { if (thread_ != NULL) { - exit_ = true; + exit_thread_ = true; SetEvent(wait_handles_[1]); WaitForSingleObject(thread_, INFINITE); CloseHandle(thread_); @@ -749,8 +717,6 @@ static wchar_t *RegReadStrW(HKEY hkey, const wchar_t *key, const wchar_t *def) { } unsigned TunsafeServiceImpl::OnStart(int argc, wchar_t **argv) { - message_handler_.StartThread(); - uint32 service_flags = RegReadInt(hkey_, "ServiceStartupFlags", 0); if ( (service_flags & kStartupFlag_BackgroundService) && (service_flags & kStartupFlag_ConnectWhenWindowsStarts) ) { char *conf = RegReadStr(hkey_, "LastUsedConfigFile", ""); @@ -761,6 +727,7 @@ unsigned TunsafeServiceImpl::OnStart(int argc, wchar_t **argv) { free(conf); } + message_handler_.StartThread(); return 0; } @@ -870,11 +837,10 @@ bool TunsafeServiceImpl::HandleNotify() { return true; } -bool TunsafeServiceImpl::HandleNewConnection() { +void TunsafeServiceImpl::HandleNewConnection() { did_send_getstate_ = false; did_authenticate_user_ = false; last_line_sent_ = 0; - return true; } void TunsafeServiceImpl::HandleDisconnect() { @@ -1151,11 +1117,10 @@ bool TunsafeServiceClient::HandleNotify() { } -bool TunsafeServiceClient::HandleNewConnection() { +void TunsafeServiceClient::HandleNewConnection() { message_handler_.WritePacket(SERVICE_REQ_LOGIN, (uint8*)&kTunsafeServiceProtocolVersion, 8); if (want_stats_) message_handler_.WritePacket(SERVICE_REQ_GETSTATS, &want_stats_, 1); - return true; } void TunsafeServiceClient::HandleDisconnect() { diff --git a/service_win32.h b/service_win32.h index eee61be..e19766e 100644 --- a/service_win32.h +++ b/service_win32.h @@ -28,7 +28,7 @@ public: public: virtual bool HandleMessage(int type, uint8 *data, size_t size) = 0; virtual bool HandleNotify() = 0; - virtual bool HandleNewConnection() = 0; + virtual void HandleNewConnection() = 0; virtual void HandleDisconnect() = 0; }; @@ -45,17 +45,14 @@ public: bool VerifyThread(); - void FlushWrites(int delay); bool is_connected() { return connection_established_; } private: - bool InitializeServerPipe(); + bool InitializeServerPipeAndWait(); bool InitializeClientPipe(); + void AdvanceStateMachine(); void ClosePipe(); DWORD ThreadMain(); void SendNextQueuedWrite(); - uint8 *ReadNamedPipeAsync(size_t *packet_size); - bool ConnectNamedPipeAsync(); - bool WaitAndHandleWrites(int delay); static DWORD WINAPI StaticThreadMain(void *x); Delegate *delegate_; @@ -63,26 +60,38 @@ private: HANDLE pipe_; HANDLE thread_; HANDLE wait_handles_[3]; - OVERLAPPED write_overlapped_; bool write_overlapped_active_; - bool exit_; + bool exit_thread_; bool is_server_pipe_; bool connection_established_; char *pipe_name_; + enum State { + kStateNone, + kStateWaitConnect, + kStateWaitReadLength, + kStateWaitReadPayload, + kStateWaitTimeout, + }; + + int state_; + struct OutgoingPacket { OutgoingPacket *next; uint32 size; uint8 data[0]; }; OutgoingPacket *packets_, **packets_end_; + uint8 *tmp_packet_buf_; + DWORD tmp_packet_size_; + + OVERLAPPED write_overlapped_, read_overlapped_; Mutex packets_mutex_; DWORD thread_id_; }; - class TunsafeServiceImpl : public TunsafeBackend::Delegate, public PipeMessageHandler::Delegate { public: TunsafeServiceImpl(); @@ -99,7 +108,7 @@ public: // -- from PipeMessageHandler::Delegate virtual bool HandleMessage(int type, uint8 *data, size_t size); virtual bool HandleNotify(); - virtual bool HandleNewConnection(); + virtual void HandleNewConnection(); virtual void HandleDisconnect(); // virtual methods @@ -155,7 +164,7 @@ public: // -- from PipeMessageHandler::Delegate virtual bool HandleMessage(int type, uint8 *data, size_t size); virtual bool HandleNotify(); - virtual bool HandleNewConnection(); + virtual void HandleNewConnection(); virtual void HandleDisconnect(); protected: