// SPDX-License-Identifier: AGPL-1.0-only // Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. #ifndef TUNSAFE_NETWORK_BSD_H_ #define TUNSAFE_NETWORK_BSD_H_ #include #include #include #include #include "network_common.h" class BaseSocketBsd; class TcpSocketBsd; class WireguardProcessor; class Packet; class NetworkBsd { friend class BaseSocketBsd; friend class TcpSocketBsd; friend class UdpSocketBsd; friend class TunSocketBsd; public: enum { #if defined(OS_ANDROID) WithSigalarmSupport = 0, #else WithSigalarmSupport = 1 #endif }; class NetworkBsdDelegate { public: virtual void OnSecondLoop(uint64 now) {} virtual void RunAllMainThreadScheduled() {} }; explicit NetworkBsd(NetworkBsdDelegate *delegate, int max_sockets); ~NetworkBsd(); void RunLoop(const sigset_t *sigmask); void PostExit() { exit_ = true; } bool *exit_flag() { return &exit_; } bool *sigalarm_flag() { return &sigalarm_flag_; } TcpSocketBsd *tcp_sockets() { return tcp_sockets_; } bool overload() { return overload_; } private: void RemoveFromRoundRobin(int slot); void ReallocateIov(size_t i); void EnsureIovAllocated(); Packet *read_packet_; bool exit_; bool overload_; bool sigalarm_flag_; enum { // This controls the max # of sockets we can support kMaxIovec = 16, }; int num_sock_; int num_roundrobin_; int num_endloop_; int max_sockets_; SimplePacketPool packet_pool_; NetworkBsdDelegate *delegate_; struct pollfd *pollfd_; BaseSocketBsd **sockets_; BaseSocketBsd **roundrobin_; BaseSocketBsd **endloop_; // Linked list of all tcp sockets TcpSocketBsd *tcp_sockets_; struct iovec iov_[kMaxIovec]; Packet *iov_packets_[kMaxIovec]; }; class BaseSocketBsd { friend class NetworkBsd; public: BaseSocketBsd(NetworkBsd *network) : pollfd_slot_(-1), roundrobin_slot_(-1), endloop_slot_(-1), fd_(-1), network_(network) {} virtual ~BaseSocketBsd(); virtual void HandleEvents(int revents) = 0; // Return |false| to remove socket from roundrobin list. virtual bool DoRoundRobin() { return false; } virtual void DoEndloop() {} virtual void Periodic() {} // Make sure this socket gets called during each round robin step. void AddToRoundRobin(); // Make sure this sockets get called at the end of the loop void AddToEndLoop(); int GetFd() { return fd_; } protected: void SetPollFlags(int events) { network_->pollfd_[pollfd_slot_].events = events; } void InitPollSlot(int fd, int events); bool HasFreePollSlot() { return network_->num_sock_ != network_->max_sockets_; } void CloseSocket(); NetworkBsd *network_; int pollfd_slot_; int roundrobin_slot_; int endloop_slot_; int fd_; }; class TunSocketBsd : public BaseSocketBsd { public: explicit TunSocketBsd(NetworkBsd *network, WireguardProcessor *processor); virtual ~TunSocketBsd(); bool Initialize(int fd); virtual void HandleEvents(int revents) override; virtual bool DoRoundRobin() override; void WritePacket(Packet *packet); bool tun_interface_gone() const { return tun_interface_gone_; } private: bool DoRead(); bool DoWrite(); bool tun_readable_, tun_writable_; bool tun_interface_gone_; Packet *tun_queue_, **tun_queue_end_; WireguardProcessor *processor_; }; class UdpSocketBsd : public BaseSocketBsd { public: explicit UdpSocketBsd(NetworkBsd *network, WireguardProcessor *processor); virtual ~UdpSocketBsd(); bool Initialize(int listen_port); virtual void HandleEvents(int revents) override; virtual bool DoRoundRobin() override; bool DoRead(); bool DoWrite(); void WritePacket(Packet *packet); private: bool udp_readable_, udp_writable_; Packet *udp_queue_, **udp_queue_end_; WireguardProcessor *processor_; }; #if defined(OS_LINUX) // Keeps track of when the unix socket gets deleted class UnixSocketDeletionWatcher { public: UnixSocketDeletionWatcher(); ~UnixSocketDeletionWatcher(); bool Start(const char *path, bool *flag_to_set); void Stop(); bool Poll(const char *path) { return false; } private: static void *RunThread(void *arg); void *RunThreadInner(); const char *path_; int inotify_fd_; int pid_; int pipes_[2]; pthread_t thread_; bool *flag_to_set_; }; #else // !defined(OS_LINUX) // all other platforms that lack inotify class UnixSocketDeletionWatcher { public: UnixSocketDeletionWatcher() {} ~UnixSocketDeletionWatcher() {} bool Start(const char *path, bool *flag_to_set) { return true; } void Stop() {} bool Poll(const char *path); }; #endif // !defined(OS_LINUX) class UnixDomainSocketListenerBsd : public BaseSocketBsd { public: explicit UnixDomainSocketListenerBsd(NetworkBsd *network, WireguardProcessor *processor); virtual ~UnixDomainSocketListenerBsd(); bool Initialize(const char *devname); bool Start(bool *exit_flag) { return un_deletion_watcher_.Start(un_addr_.sun_path, exit_flag); } void Stop() { un_deletion_watcher_.Stop(); } virtual void HandleEvents(int revents) override; virtual void Periodic() override; private: struct sockaddr_un un_addr_; WireguardProcessor *processor_; UnixSocketDeletionWatcher un_deletion_watcher_; }; class UnixDomainSocketChannelBsd : public BaseSocketBsd { public: explicit UnixDomainSocketChannelBsd(NetworkBsd *network, WireguardProcessor *processor, int fd); virtual ~UnixDomainSocketChannelBsd(); virtual void HandleEvents(int revents) override; private: bool HandleEventsInner(int revents); WireguardProcessor *processor_; std::string inbuf_, outbuf_; }; class TcpSocketListenerBsd : public BaseSocketBsd { public: explicit TcpSocketListenerBsd(NetworkBsd *bsd, WireguardProcessor *processor); virtual ~TcpSocketListenerBsd(); bool Initialize(int port); virtual void HandleEvents(int revents) override; virtual void Periodic() override; private: WireguardProcessor *processor_; }; class TcpSocketBsd : public BaseSocketBsd { public: explicit TcpSocketBsd(NetworkBsd *bsd, WireguardProcessor *processor); virtual ~TcpSocketBsd(); void InitializeIncoming(int fd, const IpAddr &addr); bool InitializeOutgoing(const IpAddr &addr); void WritePacket(Packet *packet); virtual void HandleEvents(int revents) override; virtual void DoEndloop() override; TcpSocketBsd *next() { return next_; } uint8 endpoint_protocol() { return endpoint_protocol_; } const IpAddr &endpoint() { return endpoint_; } public: uint8 age; uint8 handshake_attempts; private: void DoRead(); void DoWrite(); void CloseSocketAndDestroy(); bool readable_, writable_; bool got_eof_; uint8 endpoint_protocol_; bool want_connect_; uint32 wqueue_bytes_; Packet *wqueue_, **wqueue_end_; TcpSocketBsd *next_; WireguardProcessor *processor_; TcpPacketHandler tcp_packet_handler_; IpAddr endpoint_; }; class NotificationPipeBsd : public BaseSocketBsd { public: NotificationPipeBsd(NetworkBsd *network); ~NotificationPipeBsd(); typedef void CallbackFunc(void *x); void InjectCallback(CallbackFunc *func, void *param); void Wakeup(); virtual void HandleEvents(int revents) override; private: struct CallbackState { CallbackFunc *func; void *param; CallbackState *next; }; int pipe_fds_[2]; std::atomic injected_cb_; }; #endif // TUNSAFE_NETWORK_BSD_H_