Allow multiple DNS servers

This commit is contained in:
Ludvig Strigeus 2018-09-10 23:46:49 +02:00
parent 0cea6c2960
commit de6e187db9
6 changed files with 136 additions and 93 deletions

View file

@ -79,7 +79,6 @@ public:
std::vector<std::string> post_down;
};
struct TunConfig {
// IP address and netmask of the tun device
in_addr_t ip;
@ -96,10 +95,6 @@ public:
// Set this to configure a default route for ipv6
bool use_ipv6_default_route;
// DHCP settings
const byte *dhcp_options;
size_t dhcp_options_size;
// This holds the address of the vpn endpoint, so those get routed to the old iface.
uint32 default_route_endpoint_v4;
@ -110,10 +105,9 @@ public:
uint8 ipv6_address[16];
uint8 ipv6_cidr;
bool set_ipv6_dns;
// Set this to configure DNS server.
uint8 dns_server_v6[16];
// Set this to configure DNS server for ipv4,ipv6
std::vector<IpAddr> ipv4_dns;
std::vector<IpAddr> ipv6_dns;
// This holds the address of the vpn endpoint, so those get routed to the old iface.
uint8 default_route_endpoint_v6[16];

View file

@ -33,6 +33,12 @@ enum {
ROUTE_BLOCK_ON = 2,
ROUTE_BLOCK_PENDING = 3,
};
enum {
kMetricNone = -1,
kMetricAutomatic = 0,
};
static uint8 internet_route_blocking_state;
static SLIST_HEADER freelist_head;
@ -53,6 +59,10 @@ void FreePacket(Packet *packet) {
InterlockedPushEntrySList(&freelist_head, &packet->list_entry);
}
static bool IsIpv6AddressSet(const void *p) {
return (ReadLE64(p) | ReadLE64((char*)p + 8)) != 0;
}
void OsGetRandomBytes(uint8 *data, size_t data_size) {
static BOOLEAN(APIENTRY *pfn)(void*, ULONG);
static bool resolved;
@ -968,15 +978,19 @@ DWORD SetMtuOnNetworkAdapter(NET_LUID *InterfaceLuid, ADDRESS_FAMILY family, int
return err;
}
DWORD SetMetricOnNetworkAdapter(NET_LUID *InterfaceLuid, ADDRESS_FAMILY family, int new_metric) {
DWORD SetMetricOnNetworkAdapter(NET_LUID *InterfaceLuid, ADDRESS_FAMILY family, int new_metric, int *old_metric) {
MIB_IPINTERFACE_ROW row;
DWORD err;
if (old_metric)
*old_metric = kMetricNone;
InitializeIpInterfaceEntry(&row);
row.Family = family;
row.InterfaceLuid = *InterfaceLuid;
if ((err = GetIpInterfaceEntry(&row)) == 0) {
if (old_metric)
*old_metric = row.UseAutomaticMetric ? kMetricAutomatic : row.Metric;
row.Metric = new_metric;
row.UseAutomaticMetric = (new_metric == 0);
row.UseAutomaticMetric = (new_metric == kMetricAutomatic);
if (row.Family == AF_INET)
row.SitePrefixLength = 0;
err = SetIpInterfaceEntry(&row);
@ -993,9 +1007,20 @@ static const char *PrintIPV6(const uint8 new_address[16]) {
return buf;
}
static bool SetIPV6AddressOnInterface(NET_LUID *InterfaceLuid, const uint8 new_address[16], int new_cidr) {
static void AssignIpv6Address(const void *new_address, int new_cidr, WgCidrAddr *target) {
target->size = 128;
target->cidr = new_cidr;
memcpy(target->addr, new_address, 16);
}
// Set new_cidr to 0 to clear it.
static bool SetIPV6AddressOnInterface(NET_LUID *InterfaceLuid, const uint8 new_address[16], int new_cidr, WgCidrAddr *old_address) {
NETIO_STATUS Status;
PMIB_UNICASTIPADDRESS_TABLE table = NULL;
if (old_address)
memset(old_address, 0, sizeof(WgCidrAddr));
Status = GetUnicastIpAddressTable(AF_INET6, &table);
if (Status != 0) {
RERROR("GetUnicastAddressTable Failed. Error %d\n", Status);
@ -1011,6 +1036,8 @@ static bool SetIPV6AddressOnInterface(NET_LUID *InterfaceLuid, const uint8 new_a
found_row = true;
continue;
}
if (old_address != NULL)
AssignIpv6Address(&row->Address.Ipv6.sin6_addr, row->OnLinkPrefixLength, old_address);
Status = DeleteUnicastIpAddressEntry(row);
if (Status)
RERROR("Error %d deleting IPv6 address: %s/%d", Status, PrintIPV6((uint8*)&row->Address.Ipv6.sin6_addr), row->OnLinkPrefixLength);
@ -1026,6 +1053,12 @@ static bool SetIPV6AddressOnInterface(NET_LUID *InterfaceLuid, const uint8 new_a
return true;
}
if (!IsIpv6AddressSet(new_address))
return true;
if (old_address != NULL)
old_address->size = 128;
MIB_UNICASTIPADDRESS_ROW Row;
InitializeUnicastIpAddressEntry(&Row);
Row.OnLinkPrefixLength = new_cidr;
@ -1041,26 +1074,25 @@ static bool SetIPV6AddressOnInterface(NET_LUID *InterfaceLuid, const uint8 new_a
return true;
}
static bool IsIpv6AddressSet(const void *p) {
return (ReadLE64(p) | ReadLE64((char*)p + 8)) != 0;
}
static bool SetIPV6DnsOnInterface(NET_LUID *InterfaceLuid, const uint8 new_address[16]) {
static bool SetIPV6DnsOnInterface(NET_LUID *InterfaceLuid, const IpAddr *new_address, size_t new_address_size) {
char buf[128];
char ipv6[128];
NET_IFINDEX InterfaceIndex;
if (ConvertInterfaceLuidToIndex(InterfaceLuid, &InterfaceIndex))
return false;
if (IsIpv6AddressSet(new_address)) {
if (!inet_ntop(AF_INET6, (void*)new_address, ipv6, sizeof(ipv6)))
return false;
snprintf(buf, sizeof(buf), "netsh interface ipv6 set dns name=%d static %s validate=no", InterfaceIndex, ipv6);
if (new_address_size) {
for (size_t i = 0; i < new_address_size; i++) {
if (!inet_ntop(AF_INET6, (void*)&new_address[i].sin6.sin6_addr, ipv6, sizeof(ipv6)))
return false;
snprintf(buf, sizeof(buf), "netsh interface ipv6 %s dns name=%d static %s validate=no", (i == 0) ? "set" : "add", InterfaceIndex, ipv6);
if (!RunNetsh(buf))
return false;
}
return true;
} else {
snprintf(buf, sizeof(buf), "netsh interface ipv6 delete dns name=%d all", InterfaceIndex);
return RunNetsh(buf);
}
return RunNetsh(buf);
}
static uint32 ComputeIpv4DefaultRoute(uint32 ip, uint32 netmask) {
@ -1100,6 +1132,10 @@ static bool AddMultipleCatchallRoutes(int inet, int bits, const uint8 *target, c
TunWin32Adapter::TunWin32Adapter(DnsBlocker *dns_blocker) {
handle_ = NULL;
dns_blocker_ = dns_blocker;
old_ipv6_address_.size = 0;
old_ipv6_metric_ = kMetricNone;
old_ipv4_metric_ = kMetricNone;
has_dns6_setting_ = false;
}
TunWin32Adapter::~TunWin32Adapter() {
@ -1122,7 +1158,10 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt
RERROR("Pre command failed!");
return false;
}
pre_down_ = std::move(config.pre_post_commands.pre_down);
post_down_ = std::move(config.pre_post_commands.post_down);
memset(info, 0, sizeof(info));
if (DeviceIoControl(handle_, TAP_IOCTL_GET_VERSION, &info, sizeof(info),
&info, sizeof(info), &len, NULL)) {
@ -1167,17 +1206,22 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt
}
}
bool has_dns_setting = false;
// Set DHCP config string
if (config.dhcp_options_size != 0) {
if (config.ipv4_dns.size()) {
enum { kMaxDnsServers = 4 };
uint8 dhcp_options[2 + kMaxDnsServers * 4]; // max 4 dns servers
size_t num_dns = std::min<size_t>(config.ipv4_dns.size(), kMaxDnsServers);
dhcp_options[0] = 6;
dhcp_options[1] = (uint8)(num_dns * 4);
for(size_t i = 0; i < num_dns; i++)
memcpy(&dhcp_options[2 + i * 4], &config.ipv4_dns[i].sin.sin_addr, num_dns * 4);
DWORD dhcp_options_size = (DWORD)(num_dns * 4 + 2);
byte output[10];
if (!DeviceIoControl(handle_, TAP_IOCTL_CONFIG_DHCP_SET_OPT,
(void*)config.dhcp_options, (DWORD)config.dhcp_options_size, output, sizeof(output), &len, NULL)) {
(void*)dhcp_options, dhcp_options_size, output, sizeof(output), &len, NULL)) {
RERROR("DeviceIoControl(TAP_IOCTL_CONFIG_DHCP_SET_OPT) failed");
return false;
}
has_dns_setting = true;
}
// Get device MAC address
@ -1196,8 +1240,7 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt
return false;
}
NET_LUID InterfaceLuid = {0};
bool has_interface_luid = GetNetLuidFromGuid(guid_, &InterfaceLuid);
bool has_interface_luid = GetNetLuidFromGuid(guid_, &interface_luid_);
if (!has_interface_luid) {
RERROR("Unable to determine interface luid for %s.", guid_);
@ -1207,36 +1250,38 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt
DWORD err;
if (config.mtu) {
err = SetMtuOnNetworkAdapter(&InterfaceLuid, AF_INET, config.mtu);
err = SetMtuOnNetworkAdapter(&interface_luid_, AF_INET, config.mtu);
if (err)
RERROR("SetMtuOnNetworkAdapter IPv4 failed: %d", err);
if (config.ipv6_cidr) {
err = SetMtuOnNetworkAdapter(&InterfaceLuid, AF_INET6, config.mtu);
err = SetMtuOnNetworkAdapter(&interface_luid_, AF_INET6, config.mtu);
if (err)
RERROR("SetMtuOnNetworkAdapter IPv6 failed: %d", err);
}
}
has_dns6_setting_ = false;
if (config.ipv6_cidr) {
SetIPV6AddressOnInterface(&InterfaceLuid, config.ipv6_address, config.ipv6_cidr);
if (config.set_ipv6_dns) {
has_dns_setting |= IsIpv6AddressSet(config.dns_server_v6);
if (!SetIPV6DnsOnInterface(&InterfaceLuid, config.dns_server_v6)) {
SetIPV6AddressOnInterface(&interface_luid_, config.ipv6_address, config.ipv6_cidr, &old_ipv6_address_);
if (config.ipv6_dns.size()) {
has_dns6_setting_ = true;
if (!SetIPV6DnsOnInterface(&interface_luid_, config.ipv6_dns.data(), config.ipv6_dns.size())) {
RERROR("SetIPV6DnsOnInterface: failed");
}
}
}
if (has_dns_setting && config.block_dns_on_adapters) {
if ((config.ipv4_dns.size() || has_dns6_setting_) && config.block_dns_on_adapters) {
RINFO("Blocking standard DNS on all adapters");
dns_blocker_->BlockDnsExceptOnAdapter(InterfaceLuid, config.ipv6_cidr != 0);
dns_blocker_->BlockDnsExceptOnAdapter(interface_luid_, has_dns6_setting_);
err = SetMetricOnNetworkAdapter(&InterfaceLuid, AF_INET, 2);
err = SetMetricOnNetworkAdapter(&interface_luid_, AF_INET, 2, &old_ipv4_metric_);
if (err)
RERROR("SetMetricOnNetworkAdapter IPv4 failed: %d", err);
if (config.ipv6_cidr) {
err = SetMetricOnNetworkAdapter(&InterfaceLuid, AF_INET6, 2);
err = SetMetricOnNetworkAdapter(&interface_luid_, AF_INET6, 2, &old_ipv6_metric_);
if (err)
RERROR("SetMetricOnNetworkAdapter IPv6 failed: %d", err);
}
@ -1257,14 +1302,14 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt
uint32 default_route_endpoint_v4 = ToBE32(config.default_route_endpoint_v4);
// Delete any current /1 default routes and read some stuff from the routing table.
if (!GetDefaultRouteAndDeleteOldRoutes(AF_INET, &InterfaceLuid, block_all_traffic_route, config.use_ipv4_default_route ? (uint8*)&default_route_endpoint_v4 : NULL, &ri)) {
if (!GetDefaultRouteAndDeleteOldRoutes(AF_INET, &interface_luid_, block_all_traffic_route, config.use_ipv4_default_route ? (uint8*)&default_route_endpoint_v4 : NULL, &ri)) {
RERROR("Unable to read old default gateway and delete old default routes.");
return false;
}
if (config.ipv6_cidr) {
// Delete any current /1 default routes and read some stuff from the routing table.
if (!GetDefaultRouteAndDeleteOldRoutes(AF_INET6, &InterfaceLuid, block_all_traffic_route, config.use_ipv6_default_route ? (uint8*)config.default_route_endpoint_v6 : NULL, &ri6)) {
if (!GetDefaultRouteAndDeleteOldRoutes(AF_INET6, &interface_luid_, block_all_traffic_route, config.use_ipv6_default_route ? (uint8*)config.default_route_endpoint_v6 : NULL, &ri6)) {
RERROR("Unable to read old default gateway and delete old default routes for IPv6.");
return false;
}
@ -1293,7 +1338,7 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt
if (ibs & kBlockInternet_Firewall) {
RINFO("Blocking all regular Internet traffic%s", ri.found_default_adapter ? " (except DHCP)" : "");
AddPersistentInternetBlocking(ri.found_default_adapter ? &ri.default_adapter : NULL, InterfaceLuid, config.ipv6_cidr != 0);
AddPersistentInternetBlocking(ri.found_default_adapter ? &ri.default_adapter : NULL, interface_luid_, config.ipv6_cidr != 0);
} else {
SetInternetFwBlockingState(false);
}
@ -1313,7 +1358,7 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt
}
// Either add 4 routes or 2 routes, depending on if we use route blocking.
uint32 be = ToBE32(default_route_v4);
if (!AddMultipleCatchallRoutes(AF_INET, block_all_traffic_route ? 2 : 1, (uint8*)&be, InterfaceLuid, &routes_to_undo_))
if (!AddMultipleCatchallRoutes(AF_INET, block_all_traffic_route ? 2 : 1, (uint8*)&be, interface_luid_, &routes_to_undo_))
RERROR("Unable to add new default ipv4 route.");
}
@ -1332,7 +1377,7 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt
return false;
}
}
if (!AddMultipleCatchallRoutes(AF_INET6, block_all_traffic_route ? 2 : 1, default_route_v6, InterfaceLuid, &routes_to_undo_))
if (!AddMultipleCatchallRoutes(AF_INET6, block_all_traffic_route ? 2 : 1, default_route_v6, interface_luid_, &routes_to_undo_))
RERROR("Unable to add new default ipv6 route.");
}
}
@ -1341,9 +1386,9 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt
for (auto it = config.extra_routes.begin(); it != config.extra_routes.end(); ++it) {
if (it->size == 32) {
uint32 be = ToBE32(default_route_v4);
AddRoute(AF_INET, it->addr, it->cidr, &be, &InterfaceLuid, &routes_to_undo_);
AddRoute(AF_INET, it->addr, it->cidr, &be, &interface_luid_, &routes_to_undo_);
} else if (it->size == 128 && config.ipv6_cidr) {
AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6, &InterfaceLuid, &routes_to_undo_);
AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6, &interface_luid_, &routes_to_undo_);
}
}
@ -1359,7 +1404,7 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt
}
NET_IFINDEX InterfaceIndex;
if (ConvertInterfaceLuidToIndex(&InterfaceLuid, &InterfaceIndex)) {
if (ConvertInterfaceLuidToIndex(&interface_luid_, &InterfaceIndex)) {
RERROR("Unable to get index of adapter");
return false;
}
@ -1376,9 +1421,6 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt
RunPrePostCommand(config.pre_post_commands.post_up);
pre_down_ = std::move(config.pre_post_commands.pre_down);
post_down_ = std::move(config.pre_post_commands.post_down);
return true;
}
@ -1394,6 +1436,21 @@ void TunWin32Adapter::CloseAdapter() {
handle_ = NULL;
}
if (old_ipv6_address_.size != 0)
SetIPV6AddressOnInterface(&interface_luid_, old_ipv6_address_.addr, old_ipv6_address_.cidr, NULL);
if (old_ipv4_metric_ != kMetricNone)
SetMetricOnNetworkAdapter(&interface_luid_, AF_INET, old_ipv4_metric_, NULL);
if (old_ipv6_metric_ != kMetricNone)
SetMetricOnNetworkAdapter(&interface_luid_, AF_INET6, old_ipv6_metric_, NULL);
old_ipv4_metric_ = old_ipv6_metric_ = -1;
old_ipv6_address_.size = 0;
if (has_dns6_setting_) {
has_dns6_setting_ = false;
SetIPV6DnsOnInterface(&interface_luid_, NULL, 0);
}
for (auto it = routes_to_undo_.begin(); it != routes_to_undo_.end(); ++it)
DeleteRoute(&*it);
routes_to_undo_.clear();
@ -1527,14 +1584,15 @@ TunWin32Iocp::~TunWin32Iocp() {
bool TunWin32Iocp::Initialize(const TunConfig &&config, TunConfigOut *out) {
assert(thread_ == NULL);
if (!adapter_.OpenAdapter(&backend_->stop_mode_, FILE_FLAG_OVERLAPPED))
return false;
completion_port_handle_ = CreateIoCompletionPort(adapter_.handle(), NULL, NULL, 0);
if (completion_port_handle_ == NULL)
return false;
return adapter_.InitAdapter(std::move(config), out);
if (adapter_.OpenAdapter(&backend_->stop_mode_, FILE_FLAG_OVERLAPPED)) {
completion_port_handle_ = CreateIoCompletionPort(adapter_.handle(), NULL, NULL, 0);
if (completion_port_handle_ != NULL) {
if (adapter_.InitAdapter(std::move(config), out))
return true;
}
}
CloseTun();
return false;
}
void TunWin32Iocp::CloseTun() {
@ -1759,8 +1817,11 @@ TunWin32Overlapped::~TunWin32Overlapped() {
bool TunWin32Overlapped::Initialize(const TunConfig &&config, TunConfigOut *out) {
CloseTun();
return adapter_.OpenAdapter(&backend_->stop_mode_, FILE_FLAG_OVERLAPPED) &&
adapter_.InitAdapter(std::move(config), out);
if (adapter_.OpenAdapter(&backend_->stop_mode_, FILE_FLAG_OVERLAPPED) &&
adapter_.InitAdapter(std::move(config), out))
return true;
CloseTun();
return false;
}
void TunWin32Overlapped::CloseTun() {

View file

@ -112,10 +112,17 @@ private:
std::vector<MIB_IPFORWARD_ROW2> routes_to_undo_;
uint8 mac_adress_[6];
bool has_dns6_setting_;
int mtu_;
char guid_[64];
int old_ipv4_metric_, old_ipv6_metric_;
WgCidrAddr old_ipv6_address_;
NET_LUID interface_luid_;
std::vector<std::string> pre_down_, post_down_;
char guid_[64];
};
// Implementation of TUN interface handling using IO Completion Ports

View file

@ -36,7 +36,6 @@ WireguardProcessor::WireguardProcessor(UdpInterface *udp, TunInterface *tun, Pro
add_routes_mode_ = true;
dns_blocking_ = true;
internet_blocking_ = kBlockInternet_Default;
dns6_addr_.sin.sin_family = dns_addr_.sin.sin_family = 0;
stats_last_bytes_in_ = 0;
stats_last_bytes_out_ = 0;
@ -54,12 +53,9 @@ void WireguardProcessor::SetListenPort(int listen_port) {
}
bool WireguardProcessor::AddDnsServer(const IpAddr &sin) {
IpAddr *target = (sin.sin.sin_family == AF_INET6) ? &dns6_addr_ : &dns_addr_;
if (target->sin.sin_family != 0)
return false;
*target = sin;
return true;
void WireguardProcessor::AddDnsServer(const IpAddr &sin) {
std::vector<IpAddr> *target = (sin.sin.sin_family == AF_INET6) ? &dns6_addr_ : &dns_addr_;
target->push_back(sin);
}
bool WireguardProcessor::SetTunAddress(const WgCidrAddr &addr) {
@ -201,24 +197,12 @@ bool WireguardProcessor::Start() {
}
}
uint8 dhcp_options[6];
config.block_dns_on_adapters = dns_blocking_ && ((config.use_ipv4_default_route && dns_addr_.sin.sin_family == AF_INET) ||
(config.use_ipv6_default_route && dns6_addr_.sin6.sin6_family == AF_INET6));
config.block_dns_on_adapters = dns_blocking_ && ((config.use_ipv4_default_route && dns_addr_.size()) ||
(config.use_ipv6_default_route && dns6_addr_.size()));
config.internet_blocking = internet_blocking_;
if (dns_addr_.sin.sin_family == AF_INET) {
dhcp_options[0] = 6;
dhcp_options[1] = 4;
memcpy(&dhcp_options[2], &dns_addr_.sin.sin_addr, 4);
config.dhcp_options = dhcp_options;
config.dhcp_options_size = sizeof(dhcp_options);
}
if (dns6_addr_.sin6.sin6_family == AF_INET6) {
config.set_ipv6_dns = true;
memcpy(&config.dns_server_v6, &dns6_addr_.sin6.sin6_addr, 16);
}
config.ipv4_dns = dns_addr_;
config.ipv6_dns = dns6_addr_;
TunInterface::TunConfigOut config_out;
if (!tun_->Initialize(std::move(config), &config_out))

View file

@ -69,7 +69,7 @@ public:
~WireguardProcessor();
void SetListenPort(int listen_port);
bool AddDnsServer(const IpAddr &sin);
void AddDnsServer(const IpAddr &sin);
bool SetTunAddress(const WgCidrAddr &addr);
void AddExcludedIp(const WgCidrAddr &cidr_addr);
void SetMtu(int mtu);
@ -138,7 +138,7 @@ private:
WgCidrAddr tun_addr_;
WgCidrAddr tun6_addr_;
IpAddr dns_addr_, dns6_addr_;
std::vector<IpAddr> dns_addr_, dns6_addr_;
TunInterface::PrePostCommands pre_post_;

View file

@ -333,10 +333,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) {
for (size_t i = 0; i < ss.size(); i++) {
if (!ParseSockaddrInWithoutPort(ss[i], &sin, dns_resolver_))
return false;
if (!wg_->AddDnsServer(sin)) {
RERROR("Multiple DNS not allowed.");
return false;
}
wg_->AddDnsServer(sin);
}
} else if (strcmp(key, "BlockDNS") == 0) {
bool v;