diff --git a/network_win32.cpp b/network_win32.cpp index c2c0a7f..2a4bfe7 100644 --- a/network_win32.cpp +++ b/network_win32.cpp @@ -1237,22 +1237,6 @@ bool TunWin32Adapter::OpenAdapter(TunsafeBackendWin32 *backend, DWORD open_flags return (handle_ != NULL); } -static inline bool CheckFirstNbitsEquals(const byte *a, const byte *b, size_t n) { - return memcmp(a, b, n >> 3) == 0 && ((n & 7) == 0 || !((a[n >> 3] ^ b[n >> 3]) & (0xff << (8 - (n & 7))))); -} - -static bool IsWgCidrAddrSubsetOf(const WgCidrAddr &inner, const WgCidrAddr &outer) { - return inner.size == outer.size && inner.cidr >= outer.cidr && - CheckFirstNbitsEquals(inner.addr, outer.addr, outer.cidr); -} - -static bool IsWgCidrAddrSubsetOfAny(const WgCidrAddr &inner, const std::vector &addr) { - for (auto &a : addr) - if (IsWgCidrAddrSubsetOf(inner, a)) - return true; - return false; -} - bool TunWin32Adapter::ConfigureAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out) { DWORD len, err; diff --git a/tunsafe_ipaddr.cpp b/tunsafe_ipaddr.cpp index 94b0809..b1f5f20 100644 --- a/tunsafe_ipaddr.cpp +++ b/tunsafe_ipaddr.cpp @@ -82,6 +82,22 @@ bool ParseCidrAddr(const char *s, WgCidrAddr *out) { return false; } +static inline bool CheckFirstNbitsEquals(const byte *a, const byte *b, size_t n) { + return memcmp(a, b, n >> 3) == 0 && ((n & 7) == 0 || !((a[n >> 3] ^ b[n >> 3]) & (0xff << (8 - (n & 7))))); +} + +static bool IsWgCidrAddrSubsetOf(const WgCidrAddr &inner, const WgCidrAddr &outer) { + return inner.size == outer.size && inner.cidr >= outer.cidr && + CheckFirstNbitsEquals(inner.addr, outer.addr, outer.cidr); +} + +bool IsWgCidrAddrSubsetOfAny(const WgCidrAddr &inner, const std::vector &addr) { + for (auto &a : addr) + if (IsWgCidrAddrSubsetOf(inner, a)) + return true; + return false; +} + static Mutex g_dns_mutex; // This starts a background thread for running DNS resolving. diff --git a/tunsafe_ipaddr.h b/tunsafe_ipaddr.h index 8b398a2..2e80bc6 100644 --- a/tunsafe_ipaddr.h +++ b/tunsafe_ipaddr.h @@ -2,7 +2,7 @@ #define TUNSAFE_IPADDR_H_ #include "tunsafe_types.h" - +#include #if !defined(OS_WIN) #include #include @@ -27,9 +27,10 @@ class DnsResolver; const char *print_ip_prefix(char buf[kSizeOfAddress], int family, const void *ip, int prefixlen); char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]); char *PrintWgCidrAddr(const WgCidrAddr &addr, char buf[kSizeOfAddress]); - bool ParseCidrAddr(const char *s, WgCidrAddr *out); +bool IsWgCidrAddrSubsetOfAny(const WgCidrAddr &inner, const std::vector &addr); + enum { kParseSockaddrDontDoNAT64 = 1, }; diff --git a/wireguard.cpp b/wireguard.cpp index 4be0024..4c64613 100644 --- a/wireguard.cpp +++ b/wireguard.cpp @@ -190,9 +190,12 @@ bool WireguardProcessor::ConfigureTun() { if (it->cidr == 0) peer->allow_endpoint_change_ = false; } - // Add the peer's endpoint to the route exclusion list. + } + for (WgPeer *peer = dev_.first_peer(); peer; peer = peer->next_peer_) { + // Add the peer's endpoint to the route exclusion list, but only + // if the endpoint is covered by one of the included_routes. WgCidrAddr endpoint_addr = WgCidrAddrFromIpAddr(peer->endpoint_); - if (endpoint_addr.size != 0) + if (endpoint_addr.size != 0 && IsWgCidrAddrSubsetOfAny(endpoint_addr, config.included_routes)) config.excluded_routes.push_back(endpoint_addr); } }