diff --git a/.gitignore b/.gitignore index fc80387..71ce91a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,6 @@ /Build /Win32/ /TunSafe.aps -/ipch /*.sdf /*vcxproj.user /*.opensdf @@ -15,4 +14,5 @@ /*.psess /*.vspx /installer/*.zip -/config/ \ No newline at end of file +/config/ +/tunsafe.com/ diff --git a/TunSafe.conf b/TunSafe.conf index 073c9e5..f1c4a46 100644 --- a/TunSafe.conf +++ b/TunSafe.conf @@ -4,7 +4,6 @@ ListenPort = 51820 Address = 192.168.2.2/24 MTU = 1420 - [Peer] PublicKey = 2m1BdGW9AwwF5dqaGm0NgMggdDZDUPFAL4JxCySdgBw= #AllowedIPs = 0.0.0.0/0, fc00::2/64 @@ -14,3 +13,4 @@ Endpoint = 192.168.1.4:8040 PersistentKeepalive = 25 + diff --git a/TunSafe.rc b/TunSafe.rc index 7139e02..55f2de5 100644 Binary files a/TunSafe.rc and b/TunSafe.rc differ diff --git a/TunSafe.sln b/TunSafe.sln index cc929b1..40906e6 100644 --- a/TunSafe.sln +++ b/TunSafe.sln @@ -25,22 +25,4 @@ Global GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection - GlobalSection(Performance) = preSolution - HasPerformanceSessions = true - EndGlobalSection - GlobalSection(Performance) = preSolution - HasPerformanceSessions = true - EndGlobalSection - GlobalSection(Performance) = preSolution - HasPerformanceSessions = true - EndGlobalSection - GlobalSection(Performance) = preSolution - HasPerformanceSessions = true - EndGlobalSection - GlobalSection(Performance) = preSolution - HasPerformanceSessions = true - EndGlobalSection - GlobalSection(Performance) = preSolution - HasPerformanceSessions = true - EndGlobalSection EndGlobal diff --git a/TunSafe.vcxproj b/TunSafe.vcxproj index f9118c6..2645b39 100644 --- a/TunSafe.vcxproj +++ b/TunSafe.vcxproj @@ -103,7 +103,6 @@ Windows true kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies);ws2_32.lib;Iphlpapi.lib - RequireAdministrator @@ -122,7 +121,6 @@ kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies);ws2_32.lib;Iphlpapi.lib;Comctl32.lib - RequireAdministrator @@ -142,7 +140,6 @@ true true kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies);ws2_32.lib;Iphlpapi.lib - RequireAdministrator @@ -167,11 +164,13 @@ true true kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;%(AdditionalDependencies);ws2_32.lib;Iphlpapi.lib - RequireAdministrator + + + @@ -179,12 +178,15 @@ + + + @@ -193,13 +195,18 @@ + + + + + NotUsing @@ -229,10 +236,10 @@ + - diff --git a/TunSafe.vcxproj.filters b/TunSafe.vcxproj.filters index 220b7f6..2a444bb 100644 --- a/TunSafe.vcxproj.filters +++ b/TunSafe.vcxproj.filters @@ -53,6 +53,9 @@ Source Files + + Source Files + crypto @@ -71,6 +74,21 @@ Source Files + + Source Files\Win32 + + + Source Files\Win32 + + + Source Files + + + Source Files + + + Source Files\Win32 + @@ -109,6 +127,9 @@ Source Files + + Source Files + crypto @@ -121,6 +142,18 @@ Source Files + + Source Files\Win32 + + + Source Files\Win32 + + + Source Files + + + Source Files + @@ -128,8 +161,8 @@ - + diff --git a/benchmark.cpp b/benchmark.cpp index 4d30d80..43b0ee4 100644 --- a/benchmark.cpp +++ b/benchmark.cpp @@ -66,11 +66,15 @@ void Benchmark() { fake_glb = dst; +size_t max_bytes = 1000000000; +#if defined(ARCH_CPU_ARM_FAMILY) + max_bytes = 100000000; +#endif auto RunOneBenchmark = [&](const char *name, const std::function &ff) { uint64 bytes = 0; QueryPerformanceCounter((LARGE_INTEGER*)&b); size_t i; - for (i = 0; bytes < 1000000000; i++) + for (i = 0; bytes < max_bytes; i++) bytes += ff(i); QueryPerformanceCounter((LARGE_INTEGER*)&a); RINFO("%s: %f MB/s", name, (double)bytes * 0.000001 / (a - b) * f); diff --git a/build.py b/build.py index 934e577..8bcfe72 100644 --- a/build.py +++ b/build.py @@ -11,9 +11,10 @@ import re MSBUILD_PATH = r"C:\Dev\VS2017\MSBuild\15.0\Bin\MSBuild.exe" NSIS_PATH = r'C:\Dev\NSIS\makeNSIS.EXE' + SIGNTOOL_PATH = r'c:\Program Files (x86)\Windows Kits\10\bin\10.0.15063.0\x86\signtool.exe' -SIGNTOOL_KEY_PATH = '' # put key here -SIGNTOOL_PASS = '' # put key pass here +SIGNTOOL_KEY_PATH = "" # path to key file +SIGNTOOL_PASS = "" # password def RmTree(path): try: diff --git a/build_freebsd.sh b/build_freebsd.sh old mode 100644 new mode 100755 index 8f516b5..b546216 --- a/build_freebsd.sh +++ b/build_freebsd.sh @@ -1,5 +1,4 @@ -g++7 -I . -O2 -static -mssse3 -o tunsafe benchmark.cpp tunsafe_cpu.cpp wireguard_config.cpp \ +g++7 -I . -O2 -DNDEBUG -static -mssse3 -o tunsafe benchmark.cpp tunsafe_cpu.cpp wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \ wireguard.cpp wireguard_proto.cpp util.cpp network_bsd.cpp network_bsd_common.cpp \ crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp crypto/curve25519-donna.cpp \ crypto/siphash.cpp crypto/chacha20_x64_gas.s crypto/poly1305_x64_gas.s ipzip2/ipzip2.cpp -lrt -pthread - diff --git a/build_linux.sh b/build_linux.sh old mode 100644 new mode 100755 index 85a8a0b..026955e --- a/build_linux.sh +++ b/build_linux.sh @@ -1,6 +1,6 @@ #!/bin/sh clang++-6.0 -c -march=skylake-avx512 crypto/poly1305_x64_gas.s crypto/chacha20_x64_gas.s -clang++-6.0 -I . -O3 -mssse3 -pthread -lrt -o tunsafe util.cpp wireguard_config.cpp wireguard.cpp \ +clang++-6.0 -I . -O3 -DNDEBUG -mssse3 -pthread -lrt -o tunsafe util.cpp wireguard_config.cpp wireguard.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \ wireguard_proto.cpp network_bsd.cpp network_bsd_common.cpp tunsafe_cpu.cpp benchmark.cpp crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp \ crypto/curve25519-donna.cpp crypto/siphash.cpp chacha20_x64_gas.o crypto/aesgcm/aesni_gcm_x64_gas.s \ crypto/aesgcm/aesni_x64_gas.s crypto/aesgcm/aesgcm.cpp poly1305_x64_gas.o ipzip2/ipzip2.cpp \ diff --git a/build_linux_rpi.sh b/build_linux_rpi.sh new file mode 100755 index 0000000..96208a8 --- /dev/null +++ b/build_linux_rpi.sh @@ -0,0 +1,11 @@ +#!/bin/sh + +set -e + +cpp -D__ARM_ARCH__=7 crypto/chacha20/chacha20-arm.s > crypto/chacha20/chacha20-arm.preprocessed.s +cpp -D__ARM_ARCH__=7 crypto/poly1305/poly1305-arm.s > crypto/poly1305/poly1305-arm.preprocessed.s + +g++-6 -mfpu=neon -I . -g -O2 -DNDEBUG -fno-omit-frame-pointer -march=armv7-a -mthumb -std=c++11 -pthread -lrt -o tunsafe util.cpp wireguard_config.cpp wireguard.cpp ip_to_peer_map.cpp tunsafe_threading.cpp \ +wireguard_proto.cpp network_bsd.cpp network_bsd_common.cpp tunsafe_cpu.cpp benchmark.cpp crypto/blake2s.cpp crypto/chacha20poly1305.cpp \ +crypto/curve25519-donna.cpp crypto/siphash.cpp crypto/aesgcm/aesgcm.cpp ipzip2/ipzip2.cpp \ +crypto/chacha20/chacha20-arm.preprocessed.s crypto/poly1305/poly1305-arm.preprocessed.s diff --git a/build_osx.sh b/build_osx.sh old mode 100644 new mode 100755 index c77dc01..29a02c0 --- a/build_osx.sh +++ b/build_osx.sh @@ -4,7 +4,7 @@ set -e clang++ -c -mavx512f -mavx512vl crypto/poly1305_x64_gas_macosx.s crypto/chacha20_x64_gas_macosx.s clang++ -g -O3 -I . -std=c++11 -DNDEBUG=1 -fno-exceptions -fno-rtti -ffunction-sections -o tunsafe \ -wireguard_config.cpp wireguard.cpp wireguard_proto.cpp util.cpp network_bsd.cpp network_bsd_common.cpp benchmark.cpp tunsafe_cpu.cpp \ +wireguard_config.cpp ip_to_peer_map.cpp tunsafe_threading.cpp wireguard.cpp wireguard_proto.cpp util.cpp network_bsd.cpp network_bsd_common.cpp benchmark.cpp tunsafe_cpu.cpp \ crypto/blake2s.cpp crypto/blake2s_sse.cpp crypto/chacha20poly1305.cpp crypto/curve25519-donna.cpp \ crypto/siphash.cpp crypto/aesgcm/aesgcm.cpp ipzip2/ipzip2.cpp \ crypto/aesgcm/aesni_gcm_x64_gas_macosx.s crypto/aesgcm/aesni_x64_gas_macosx.s crypto/aesgcm/ghash_x64_gas_macosx.s \ diff --git a/crypto/blake2s.cpp b/crypto/blake2s.cpp old mode 100644 new mode 100755 diff --git a/crypto/blake2s_sse.cpp b/crypto/blake2s_sse.cpp old mode 100644 new mode 100755 diff --git a/crypto/chacha20/chacha20-arm.pl b/crypto/chacha20/chacha20-arm.pl new file mode 100644 index 0000000..cec1b89 --- /dev/null +++ b/crypto/chacha20/chacha20-arm.pl @@ -0,0 +1,1160 @@ +#! /usr/bin/env perl +# Copyright 2016 The OpenSSL Project Authors. All Rights Reserved. +# +# Licensed under the OpenSSL license (the "License"). You may not use +# this file except in compliance with the License. You can obtain a copy +# in the file LICENSE in the source distribution or at +# https://www.openssl.org/source/license.html + +# +# ==================================================================== +# Written by Andy Polyakov for the OpenSSL +# project. The module is, however, dual licensed under OpenSSL and +# CRYPTOGAMS licenses depending on where you obtain it. For further +# details see http://www.openssl.org/~appro/cryptogams/. +# ==================================================================== +# +# December 2014 +# +# ChaCha20 for ARMv4. +# +# Performance in cycles per byte out of large buffer. +# +# IALU/gcc-4.4 1xNEON 3xNEON+1xIALU +# +# Cortex-A5 19.3(*)/+95% 21.8 14.1 +# Cortex-A8 10.5(*)/+160% 13.9 6.35 +# Cortex-A9 12.9(**)/+110% 14.3 6.50 +# Cortex-A15 11.0/+40% 16.0 5.00 +# Snapdragon S4 11.5/+125% 13.6 4.90 +# +# (*) most "favourable" result for aligned data on little-endian +# processor, result for misaligned data is 10-15% lower; +# (**) this result is a trade-off: it can be improved by 20%, +# but then Snapdragon S4 and Cortex-A8 results get +# 20-25% worse; + +$flavour = shift; +if ($flavour=~/\w[\w\-]*\.\w+$/) { $output=$flavour; undef $flavour; } +else { while (($output=shift) && ($output!~/\w[\w\-]*\.\w+$/)) {} } + +if ($flavour && $flavour ne "void") { + $0 =~ m/(.*[\/\\])[^\/\\]+$/; $dir=$1; + ( $xlate="${dir}../arm-xlate.pl" and -f $xlate ) or + ( $xlate="${dir}../../perlasm/arm-xlate.pl" and -f $xlate) or + die "can't locate arm-xlate.pl"; + + open STDOUT,"| \"$^X\" $xlate $flavour $output"; +} else { + open STDOUT,">$output"; +} + +sub AUTOLOAD() # thunk [simplified] x86-style perlasm +{ my $opcode = $AUTOLOAD; $opcode =~ s/.*:://; $opcode =~ s/_/\./; + my $arg = pop; + $arg = "#$arg" if ($arg*1 eq $arg); + $code .= "\t$opcode\t".join(',',@_,$arg)."\n"; +} + +my @x=map("r$_",(0..7,"x","x","x","x",12,"x",14,"x")); +my @t=map("r$_",(8..11)); + +sub ROUND { +my ($a0,$b0,$c0,$d0)=@_; +my ($a1,$b1,$c1,$d1)=map(($_&~3)+(($_+1)&3),($a0,$b0,$c0,$d0)); +my ($a2,$b2,$c2,$d2)=map(($_&~3)+(($_+1)&3),($a1,$b1,$c1,$d1)); +my ($a3,$b3,$c3,$d3)=map(($_&~3)+(($_+1)&3),($a2,$b2,$c2,$d2)); +my $odd = $d0&1; +my ($xc,$xc_) = (@t[0..1]); +my ($xd,$xd_) = $odd ? (@t[2],@x[$d1]) : (@x[$d0],@t[2]); +my @ret; + + # Consider order in which variables are addressed by their + # index: + # + # a b c d + # + # 0 4 8 12 < even round + # 1 5 9 13 + # 2 6 10 14 + # 3 7 11 15 + # 0 5 10 15 < odd round + # 1 6 11 12 + # 2 7 8 13 + # 3 4 9 14 + # + # 'a', 'b' are permanently allocated in registers, @x[0..7], + # while 'c's and pair of 'd's are maintained in memory. If + # you observe 'c' column, you'll notice that pair of 'c's is + # invariant between rounds. This means that we have to reload + # them once per round, in the middle. This is why you'll see + # bunch of 'c' stores and loads in the middle, but none in + # the beginning or end. If you observe 'd' column, you'll + # notice that 15 and 13 are reused in next pair of rounds. + # This is why these two are chosen for offloading to memory, + # to make loads count more. + push @ret,( + "&add (@x[$a0],@x[$a0],@x[$b0])", + "&mov ($xd,$xd,'ror#16')", + "&add (@x[$a1],@x[$a1],@x[$b1])", + "&mov ($xd_,$xd_,'ror#16')", + "&eor ($xd,$xd,@x[$a0],'ror#16')", + "&eor ($xd_,$xd_,@x[$a1],'ror#16')", + + "&add ($xc,$xc,$xd)", + "&mov (@x[$b0],@x[$b0],'ror#20')", + "&add ($xc_,$xc_,$xd_)", + "&mov (@x[$b1],@x[$b1],'ror#20')", + "&eor (@x[$b0],@x[$b0],$xc,'ror#20')", + "&eor (@x[$b1],@x[$b1],$xc_,'ror#20')", + + "&add (@x[$a0],@x[$a0],@x[$b0])", + "&mov ($xd,$xd,'ror#24')", + "&add (@x[$a1],@x[$a1],@x[$b1])", + "&mov ($xd_,$xd_,'ror#24')", + "&eor ($xd,$xd,@x[$a0],'ror#24')", + "&eor ($xd_,$xd_,@x[$a1],'ror#24')", + + "&add ($xc,$xc,$xd)", + "&mov (@x[$b0],@x[$b0],'ror#25')" ); + push @ret,( + "&str ($xd,'[sp,#4*(16+$d0)]')", + "&ldr ($xd,'[sp,#4*(16+$d2)]')" ) if ($odd); + push @ret,( + "&add ($xc_,$xc_,$xd_)", + "&mov (@x[$b1],@x[$b1],'ror#25')" ); + push @ret,( + "&str ($xd_,'[sp,#4*(16+$d1)]')", + "&ldr ($xd_,'[sp,#4*(16+$d3)]')" ) if (!$odd); + push @ret,( + "&eor (@x[$b0],@x[$b0],$xc,'ror#25')", + "&eor (@x[$b1],@x[$b1],$xc_,'ror#25')" ); + + $xd=@x[$d2] if (!$odd); + $xd_=@x[$d3] if ($odd); + push @ret,( + "&str ($xc,'[sp,#4*(16+$c0)]')", + "&ldr ($xc,'[sp,#4*(16+$c2)]')", + "&add (@x[$a2],@x[$a2],@x[$b2])", + "&mov ($xd,$xd,'ror#16')", + "&str ($xc_,'[sp,#4*(16+$c1)]')", + "&ldr ($xc_,'[sp,#4*(16+$c3)]')", + "&add (@x[$a3],@x[$a3],@x[$b3])", + "&mov ($xd_,$xd_,'ror#16')", + "&eor ($xd,$xd,@x[$a2],'ror#16')", + "&eor ($xd_,$xd_,@x[$a3],'ror#16')", + + "&add ($xc,$xc,$xd)", + "&mov (@x[$b2],@x[$b2],'ror#20')", + "&add ($xc_,$xc_,$xd_)", + "&mov (@x[$b3],@x[$b3],'ror#20')", + "&eor (@x[$b2],@x[$b2],$xc,'ror#20')", + "&eor (@x[$b3],@x[$b3],$xc_,'ror#20')", + + "&add (@x[$a2],@x[$a2],@x[$b2])", + "&mov ($xd,$xd,'ror#24')", + "&add (@x[$a3],@x[$a3],@x[$b3])", + "&mov ($xd_,$xd_,'ror#24')", + "&eor ($xd,$xd,@x[$a2],'ror#24')", + "&eor ($xd_,$xd_,@x[$a3],'ror#24')", + + "&add ($xc,$xc,$xd)", + "&mov (@x[$b2],@x[$b2],'ror#25')", + "&add ($xc_,$xc_,$xd_)", + "&mov (@x[$b3],@x[$b3],'ror#25')", + "&eor (@x[$b2],@x[$b2],$xc,'ror#25')", + "&eor (@x[$b3],@x[$b3],$xc_,'ror#25')" ); + + @ret; +} + +$code.=<<___; +#include "arm_arch.h" + +.text +#if defined(__thumb2__) || defined(__clang__) +.syntax unified +#endif +#if defined(__thumb2__) +.thumb +#else +.code 32 +#endif + +#if defined(__thumb2__) || defined(__clang__) +#define ldrhsb ldrbhs +#endif + +.align 5 +.Lsigma: +.long 0x61707865,0x3320646e,0x79622d32,0x6b206574 @ endian-neutral +.Lone: +.long 1,0,0,0 +#if __ARM_MAX_ARCH__>=7 +.LOPENSSL_armcap: +.word OPENSSL_armcap_P-.LChaCha20_ctr32 +#else +.word -1 +#endif + +.globl ChaCha20_ctr32 +.type ChaCha20_ctr32,%function +.align 5 +ChaCha20_ctr32: +.LChaCha20_ctr32: + ldr r12,[sp,#0] @ pull pointer to counter and nonce + stmdb sp!,{r0-r2,r4-r11,lr} +#if __ARM_ARCH__<7 && !defined(__thumb2__) + sub r14,pc,#16 @ ChaCha20_ctr32 +#else + adr r14,.LChaCha20_ctr32 +#endif + cmp r2,#0 @ len==0? +#ifdef __thumb2__ + itt eq +#endif + addeq sp,sp,#4*3 + beq .Lno_data +#if __ARM_MAX_ARCH__>=7 + cmp r2,#192 @ test len + bls .Lshort + ldr r4,[r14,#-32] + ldr r4,[r14,r4] +# ifdef __APPLE__ + ldr r4,[r4] +# endif + tst r4,#ARMV7_NEON + bne .LChaCha20_neon +.Lshort: +#endif + ldmia r12,{r4-r7} @ load counter and nonce + sub sp,sp,#4*(16) @ off-load area + sub r14,r14,#64 @ .Lsigma + stmdb sp!,{r4-r7} @ copy counter and nonce + ldmia r3,{r4-r11} @ load key + ldmia r14,{r0-r3} @ load sigma + stmdb sp!,{r4-r11} @ copy key + stmdb sp!,{r0-r3} @ copy sigma + str r10,[sp,#4*(16+10)] @ off-load "@x[10]" + str r11,[sp,#4*(16+11)] @ off-load "@x[11]" + b .Loop_outer_enter + +.align 4 +.Loop_outer: + ldmia sp,{r0-r9} @ load key material + str @t[3],[sp,#4*(32+2)] @ save len + str r12, [sp,#4*(32+1)] @ save inp + str r14, [sp,#4*(32+0)] @ save out +.Loop_outer_enter: + ldr @t[3], [sp,#4*(15)] + ldr @x[12],[sp,#4*(12)] @ modulo-scheduled load + ldr @t[2], [sp,#4*(13)] + ldr @x[14],[sp,#4*(14)] + str @t[3], [sp,#4*(16+15)] + mov @t[3],#10 + b .Loop + +.align 4 +.Loop: + subs @t[3],@t[3],#1 +___ + foreach (&ROUND(0, 4, 8,12)) { eval; } + foreach (&ROUND(0, 5,10,15)) { eval; } +$code.=<<___; + bne .Loop + + ldr @t[3],[sp,#4*(32+2)] @ load len + + str @t[0], [sp,#4*(16+8)] @ modulo-scheduled store + str @t[1], [sp,#4*(16+9)] + str @x[12],[sp,#4*(16+12)] + str @t[2], [sp,#4*(16+13)] + str @x[14],[sp,#4*(16+14)] + + @ at this point we have first half of 512-bit result in + @ @x[0-7] and second half at sp+4*(16+8) + + cmp @t[3],#64 @ done yet? +#ifdef __thumb2__ + itete lo +#endif + addlo r12,sp,#4*(0) @ shortcut or ... + ldrhs r12,[sp,#4*(32+1)] @ ... load inp + addlo r14,sp,#4*(0) @ shortcut or ... + ldrhs r14,[sp,#4*(32+0)] @ ... load out + + ldr @t[0],[sp,#4*(0)] @ load key material + ldr @t[1],[sp,#4*(1)] + +#if __ARM_ARCH__>=6 || !defined(__ARMEB__) +# if __ARM_ARCH__<7 + orr @t[2],r12,r14 + tst @t[2],#3 @ are input and output aligned? + ldr @t[2],[sp,#4*(2)] + bne .Lunaligned + cmp @t[3],#64 @ restore flags +# else + ldr @t[2],[sp,#4*(2)] +# endif + ldr @t[3],[sp,#4*(3)] + + add @x[0],@x[0],@t[0] @ accumulate key material + add @x[1],@x[1],@t[1] +# ifdef __thumb2__ + itt hs +# endif + ldrhs @t[0],[r12],#16 @ load input + ldrhs @t[1],[r12,#-12] + + add @x[2],@x[2],@t[2] + add @x[3],@x[3],@t[3] +# ifdef __thumb2__ + itt hs +# endif + ldrhs @t[2],[r12,#-8] + ldrhs @t[3],[r12,#-4] +# if __ARM_ARCH__>=6 && defined(__ARMEB__) + rev @x[0],@x[0] + rev @x[1],@x[1] + rev @x[2],@x[2] + rev @x[3],@x[3] +# endif +# ifdef __thumb2__ + itt hs +# endif + eorhs @x[0],@x[0],@t[0] @ xor with input + eorhs @x[1],@x[1],@t[1] + add @t[0],sp,#4*(4) + str @x[0],[r14],#16 @ store output +# ifdef __thumb2__ + itt hs +# endif + eorhs @x[2],@x[2],@t[2] + eorhs @x[3],@x[3],@t[3] + ldmia @t[0],{@t[0]-@t[3]} @ load key material + str @x[1],[r14,#-12] + str @x[2],[r14,#-8] + str @x[3],[r14,#-4] + + add @x[4],@x[4],@t[0] @ accumulate key material + add @x[5],@x[5],@t[1] +# ifdef __thumb2__ + itt hs +# endif + ldrhs @t[0],[r12],#16 @ load input + ldrhs @t[1],[r12,#-12] + add @x[6],@x[6],@t[2] + add @x[7],@x[7],@t[3] +# ifdef __thumb2__ + itt hs +# endif + ldrhs @t[2],[r12,#-8] + ldrhs @t[3],[r12,#-4] +# if __ARM_ARCH__>=6 && defined(__ARMEB__) + rev @x[4],@x[4] + rev @x[5],@x[5] + rev @x[6],@x[6] + rev @x[7],@x[7] +# endif +# ifdef __thumb2__ + itt hs +# endif + eorhs @x[4],@x[4],@t[0] + eorhs @x[5],@x[5],@t[1] + add @t[0],sp,#4*(8) + str @x[4],[r14],#16 @ store output +# ifdef __thumb2__ + itt hs +# endif + eorhs @x[6],@x[6],@t[2] + eorhs @x[7],@x[7],@t[3] + str @x[5],[r14,#-12] + ldmia @t[0],{@t[0]-@t[3]} @ load key material + str @x[6],[r14,#-8] + add @x[0],sp,#4*(16+8) + str @x[7],[r14,#-4] + + ldmia @x[0],{@x[0]-@x[7]} @ load second half + + add @x[0],@x[0],@t[0] @ accumulate key material + add @x[1],@x[1],@t[1] +# ifdef __thumb2__ + itt hs +# endif + ldrhs @t[0],[r12],#16 @ load input + ldrhs @t[1],[r12,#-12] +# ifdef __thumb2__ + itt hi +# endif + strhi @t[2],[sp,#4*(16+10)] @ copy "@x[10]" while at it + strhi @t[3],[sp,#4*(16+11)] @ copy "@x[11]" while at it + add @x[2],@x[2],@t[2] + add @x[3],@x[3],@t[3] +# ifdef __thumb2__ + itt hs +# endif + ldrhs @t[2],[r12,#-8] + ldrhs @t[3],[r12,#-4] +# if __ARM_ARCH__>=6 && defined(__ARMEB__) + rev @x[0],@x[0] + rev @x[1],@x[1] + rev @x[2],@x[2] + rev @x[3],@x[3] +# endif +# ifdef __thumb2__ + itt hs +# endif + eorhs @x[0],@x[0],@t[0] + eorhs @x[1],@x[1],@t[1] + add @t[0],sp,#4*(12) + str @x[0],[r14],#16 @ store output +# ifdef __thumb2__ + itt hs +# endif + eorhs @x[2],@x[2],@t[2] + eorhs @x[3],@x[3],@t[3] + str @x[1],[r14,#-12] + ldmia @t[0],{@t[0]-@t[3]} @ load key material + str @x[2],[r14,#-8] + str @x[3],[r14,#-4] + + add @x[4],@x[4],@t[0] @ accumulate key material + add @x[5],@x[5],@t[1] +# ifdef __thumb2__ + itt hi +# endif + addhi @t[0],@t[0],#1 @ next counter value + strhi @t[0],[sp,#4*(12)] @ save next counter value +# ifdef __thumb2__ + itt hs +# endif + ldrhs @t[0],[r12],#16 @ load input + ldrhs @t[1],[r12,#-12] + add @x[6],@x[6],@t[2] + add @x[7],@x[7],@t[3] +# ifdef __thumb2__ + itt hs +# endif + ldrhs @t[2],[r12,#-8] + ldrhs @t[3],[r12,#-4] +# if __ARM_ARCH__>=6 && defined(__ARMEB__) + rev @x[4],@x[4] + rev @x[5],@x[5] + rev @x[6],@x[6] + rev @x[7],@x[7] +# endif +# ifdef __thumb2__ + itt hs +# endif + eorhs @x[4],@x[4],@t[0] + eorhs @x[5],@x[5],@t[1] +# ifdef __thumb2__ + it ne +# endif + ldrne @t[0],[sp,#4*(32+2)] @ re-load len +# ifdef __thumb2__ + itt hs +# endif + eorhs @x[6],@x[6],@t[2] + eorhs @x[7],@x[7],@t[3] + str @x[4],[r14],#16 @ store output + str @x[5],[r14,#-12] +# ifdef __thumb2__ + it hs +# endif + subhs @t[3],@t[0],#64 @ len-=64 + str @x[6],[r14,#-8] + str @x[7],[r14,#-4] + bhi .Loop_outer + + beq .Ldone +# if __ARM_ARCH__<7 + b .Ltail + +.align 4 +.Lunaligned: @ unaligned endian-neutral path + cmp @t[3],#64 @ restore flags +# endif +#endif +#if __ARM_ARCH__<7 + ldr @t[3],[sp,#4*(3)] +___ +for ($i=0;$i<16;$i+=4) { +my $j=$i&0x7; + +$code.=<<___ if ($i==4); + add @x[0],sp,#4*(16+8) +___ +$code.=<<___ if ($i==8); + ldmia @x[0],{@x[0]-@x[7]} @ load second half +# ifdef __thumb2__ + itt hi +# endif + strhi @t[2],[sp,#4*(16+10)] @ copy "@x[10]" + strhi @t[3],[sp,#4*(16+11)] @ copy "@x[11]" +___ +$code.=<<___; + add @x[$j+0],@x[$j+0],@t[0] @ accumulate key material +___ +$code.=<<___ if ($i==12); +# ifdef __thumb2__ + itt hi +# endif + addhi @t[0],@t[0],#1 @ next counter value + strhi @t[0],[sp,#4*(12)] @ save next counter value +___ +$code.=<<___; + add @x[$j+1],@x[$j+1],@t[1] + add @x[$j+2],@x[$j+2],@t[2] +# ifdef __thumb2__ + itete lo +# endif + eorlo @t[0],@t[0],@t[0] @ zero or ... + ldrhsb @t[0],[r12],#16 @ ... load input + eorlo @t[1],@t[1],@t[1] + ldrhsb @t[1],[r12,#-12] + + add @x[$j+3],@x[$j+3],@t[3] +# ifdef __thumb2__ + itete lo +# endif + eorlo @t[2],@t[2],@t[2] + ldrhsb @t[2],[r12,#-8] + eorlo @t[3],@t[3],@t[3] + ldrhsb @t[3],[r12,#-4] + + eor @x[$j+0],@t[0],@x[$j+0] @ xor with input (or zero) + eor @x[$j+1],@t[1],@x[$j+1] +# ifdef __thumb2__ + itt hs +# endif + ldrhsb @t[0],[r12,#-15] @ load more input + ldrhsb @t[1],[r12,#-11] + eor @x[$j+2],@t[2],@x[$j+2] + strb @x[$j+0],[r14],#16 @ store output + eor @x[$j+3],@t[3],@x[$j+3] +# ifdef __thumb2__ + itt hs +# endif + ldrhsb @t[2],[r12,#-7] + ldrhsb @t[3],[r12,#-3] + strb @x[$j+1],[r14,#-12] + eor @x[$j+0],@t[0],@x[$j+0],lsr#8 + strb @x[$j+2],[r14,#-8] + eor @x[$j+1],@t[1],@x[$j+1],lsr#8 +# ifdef __thumb2__ + itt hs +# endif + ldrhsb @t[0],[r12,#-14] @ load more input + ldrhsb @t[1],[r12,#-10] + strb @x[$j+3],[r14,#-4] + eor @x[$j+2],@t[2],@x[$j+2],lsr#8 + strb @x[$j+0],[r14,#-15] + eor @x[$j+3],@t[3],@x[$j+3],lsr#8 +# ifdef __thumb2__ + itt hs +# endif + ldrhsb @t[2],[r12,#-6] + ldrhsb @t[3],[r12,#-2] + strb @x[$j+1],[r14,#-11] + eor @x[$j+0],@t[0],@x[$j+0],lsr#8 + strb @x[$j+2],[r14,#-7] + eor @x[$j+1],@t[1],@x[$j+1],lsr#8 +# ifdef __thumb2__ + itt hs +# endif + ldrhsb @t[0],[r12,#-13] @ load more input + ldrhsb @t[1],[r12,#-9] + strb @x[$j+3],[r14,#-3] + eor @x[$j+2],@t[2],@x[$j+2],lsr#8 + strb @x[$j+0],[r14,#-14] + eor @x[$j+3],@t[3],@x[$j+3],lsr#8 +# ifdef __thumb2__ + itt hs +# endif + ldrhsb @t[2],[r12,#-5] + ldrhsb @t[3],[r12,#-1] + strb @x[$j+1],[r14,#-10] + strb @x[$j+2],[r14,#-6] + eor @x[$j+0],@t[0],@x[$j+0],lsr#8 + strb @x[$j+3],[r14,#-2] + eor @x[$j+1],@t[1],@x[$j+1],lsr#8 + strb @x[$j+0],[r14,#-13] + eor @x[$j+2],@t[2],@x[$j+2],lsr#8 + strb @x[$j+1],[r14,#-9] + eor @x[$j+3],@t[3],@x[$j+3],lsr#8 + strb @x[$j+2],[r14,#-5] + strb @x[$j+3],[r14,#-1] +___ +$code.=<<___ if ($i<12); + add @t[0],sp,#4*(4+$i) + ldmia @t[0],{@t[0]-@t[3]} @ load key material +___ +} +$code.=<<___; +# ifdef __thumb2__ + it ne +# endif + ldrne @t[0],[sp,#4*(32+2)] @ re-load len +# ifdef __thumb2__ + it hs +# endif + subhs @t[3],@t[0],#64 @ len-=64 + bhi .Loop_outer + + beq .Ldone +#endif + +.Ltail: + ldr r12,[sp,#4*(32+1)] @ load inp + add @t[1],sp,#4*(0) + ldr r14,[sp,#4*(32+0)] @ load out + +.Loop_tail: + ldrb @t[2],[@t[1]],#1 @ read buffer on stack + ldrb @t[3],[r12],#1 @ read input + subs @t[0],@t[0],#1 + eor @t[3],@t[3],@t[2] + strb @t[3],[r14],#1 @ store output + bne .Loop_tail + +.Ldone: + add sp,sp,#4*(32+3) +.Lno_data: + ldmia sp!,{r4-r11,pc} +.size ChaCha20_ctr32,.-ChaCha20_ctr32 +___ + +{{{ +my ($a0,$b0,$c0,$d0,$a1,$b1,$c1,$d1,$a2,$b2,$c2,$d2,$t0,$t1,$t2,$t3) = + map("q$_",(0..15)); + +sub NEONROUND { +my $odd = pop; +my ($a,$b,$c,$d,$t)=@_; + + ( + "&vadd_i32 ($a,$a,$b)", + "&veor ($d,$d,$a)", + "&vrev32_16 ($d,$d)", # vrot ($d,16) + + "&vadd_i32 ($c,$c,$d)", + "&veor ($t,$b,$c)", + "&vshr_u32 ($b,$t,20)", + "&vsli_32 ($b,$t,12)", + + "&vadd_i32 ($a,$a,$b)", + "&veor ($t,$d,$a)", + "&vshr_u32 ($d,$t,24)", + "&vsli_32 ($d,$t,8)", + + "&vadd_i32 ($c,$c,$d)", + "&veor ($t,$b,$c)", + "&vshr_u32 ($b,$t,25)", + "&vsli_32 ($b,$t,7)", + + "&vext_8 ($c,$c,$c,8)", + "&vext_8 ($b,$b,$b,$odd?12:4)", + "&vext_8 ($d,$d,$d,$odd?4:12)" + ); +} + +$code.=<<___; +#if __ARM_MAX_ARCH__>=7 +.arch armv7-a +.fpu neon + +.type ChaCha20_neon,%function +.align 5 +ChaCha20_neon: + ldr r12,[sp,#0] @ pull pointer to counter and nonce + stmdb sp!,{r0-r2,r4-r11,lr} +.LChaCha20_neon: + adr r14,.Lsigma + vstmdb sp!,{d8-d15} @ ABI spec says so + stmdb sp!,{r0-r3} + + vld1.32 {$b0-$c0},[r3] @ load key + ldmia r3,{r4-r11} @ load key + + sub sp,sp,#4*(16+16) + vld1.32 {$d0},[r12] @ load counter and nonce + add r12,sp,#4*8 + ldmia r14,{r0-r3} @ load sigma + vld1.32 {$a0},[r14]! @ load sigma + vld1.32 {$t0},[r14] @ one + vst1.32 {$c0-$d0},[r12] @ copy 1/2key|counter|nonce + vst1.32 {$a0-$b0},[sp] @ copy sigma|1/2key + + str r10,[sp,#4*(16+10)] @ off-load "@x[10]" + str r11,[sp,#4*(16+11)] @ off-load "@x[11]" + vshl.i32 $t1#lo,$t0#lo,#1 @ two + vstr $t0#lo,[sp,#4*(16+0)] + vshl.i32 $t2#lo,$t0#lo,#2 @ four + vstr $t1#lo,[sp,#4*(16+2)] + vmov $a1,$a0 + vstr $t2#lo,[sp,#4*(16+4)] + vmov $a2,$a0 + vmov $b1,$b0 + vmov $b2,$b0 + b .Loop_neon_enter + +.align 4 +.Loop_neon_outer: + ldmia sp,{r0-r9} @ load key material + cmp @t[3],#64*2 @ if len<=64*2 + bls .Lbreak_neon @ switch to integer-only + vmov $a1,$a0 + str @t[3],[sp,#4*(32+2)] @ save len + vmov $a2,$a0 + str r12, [sp,#4*(32+1)] @ save inp + vmov $b1,$b0 + str r14, [sp,#4*(32+0)] @ save out + vmov $b2,$b0 +.Loop_neon_enter: + ldr @t[3], [sp,#4*(15)] + vadd.i32 $d1,$d0,$t0 @ counter+1 + ldr @x[12],[sp,#4*(12)] @ modulo-scheduled load + vmov $c1,$c0 + ldr @t[2], [sp,#4*(13)] + vmov $c2,$c0 + ldr @x[14],[sp,#4*(14)] + vadd.i32 $d2,$d1,$t0 @ counter+2 + str @t[3], [sp,#4*(16+15)] + mov @t[3],#10 + add @x[12],@x[12],#3 @ counter+3 + b .Loop_neon + +.align 4 +.Loop_neon: + subs @t[3],@t[3],#1 +___ + my @thread0=&NEONROUND($a0,$b0,$c0,$d0,$t0,0); + my @thread1=&NEONROUND($a1,$b1,$c1,$d1,$t1,0); + my @thread2=&NEONROUND($a2,$b2,$c2,$d2,$t2,0); + my @thread3=&ROUND(0,4,8,12); + + foreach (@thread0) { + eval; eval(shift(@thread3)); + eval(shift(@thread1)); eval(shift(@thread3)); + eval(shift(@thread2)); eval(shift(@thread3)); + } + + @thread0=&NEONROUND($a0,$b0,$c0,$d0,$t0,1); + @thread1=&NEONROUND($a1,$b1,$c1,$d1,$t1,1); + @thread2=&NEONROUND($a2,$b2,$c2,$d2,$t2,1); + @thread3=&ROUND(0,5,10,15); + + foreach (@thread0) { + eval; eval(shift(@thread3)); + eval(shift(@thread1)); eval(shift(@thread3)); + eval(shift(@thread2)); eval(shift(@thread3)); + } +$code.=<<___; + bne .Loop_neon + + add @t[3],sp,#32 + vld1.32 {$t0-$t1},[sp] @ load key material + vld1.32 {$t2-$t3},[@t[3]] + + ldr @t[3],[sp,#4*(32+2)] @ load len + + str @t[0], [sp,#4*(16+8)] @ modulo-scheduled store + str @t[1], [sp,#4*(16+9)] + str @x[12],[sp,#4*(16+12)] + str @t[2], [sp,#4*(16+13)] + str @x[14],[sp,#4*(16+14)] + + @ at this point we have first half of 512-bit result in + @ @x[0-7] and second half at sp+4*(16+8) + + ldr r12,[sp,#4*(32+1)] @ load inp + ldr r14,[sp,#4*(32+0)] @ load out + + vadd.i32 $a0,$a0,$t0 @ accumulate key material + vadd.i32 $a1,$a1,$t0 + vadd.i32 $a2,$a2,$t0 + vldr $t0#lo,[sp,#4*(16+0)] @ one + + vadd.i32 $b0,$b0,$t1 + vadd.i32 $b1,$b1,$t1 + vadd.i32 $b2,$b2,$t1 + vldr $t1#lo,[sp,#4*(16+2)] @ two + + vadd.i32 $c0,$c0,$t2 + vadd.i32 $c1,$c1,$t2 + vadd.i32 $c2,$c2,$t2 + vadd.i32 $d1#lo,$d1#lo,$t0#lo @ counter+1 + vadd.i32 $d2#lo,$d2#lo,$t1#lo @ counter+2 + + vadd.i32 $d0,$d0,$t3 + vadd.i32 $d1,$d1,$t3 + vadd.i32 $d2,$d2,$t3 + + cmp @t[3],#64*4 + blo .Ltail_neon + + vld1.8 {$t0-$t1},[r12]! @ load input + mov @t[3],sp + vld1.8 {$t2-$t3},[r12]! + veor $a0,$a0,$t0 @ xor with input + veor $b0,$b0,$t1 + vld1.8 {$t0-$t1},[r12]! + veor $c0,$c0,$t2 + veor $d0,$d0,$t3 + vld1.8 {$t2-$t3},[r12]! + + veor $a1,$a1,$t0 + vst1.8 {$a0-$b0},[r14]! @ store output + veor $b1,$b1,$t1 + vld1.8 {$t0-$t1},[r12]! + veor $c1,$c1,$t2 + vst1.8 {$c0-$d0},[r14]! + veor $d1,$d1,$t3 + vld1.8 {$t2-$t3},[r12]! + + veor $a2,$a2,$t0 + vld1.32 {$a0-$b0},[@t[3]]! @ load for next iteration + veor $t0#hi,$t0#hi,$t0#hi + vldr $t0#lo,[sp,#4*(16+4)] @ four + veor $b2,$b2,$t1 + vld1.32 {$c0-$d0},[@t[3]] + veor $c2,$c2,$t2 + vst1.8 {$a1-$b1},[r14]! + veor $d2,$d2,$t3 + vst1.8 {$c1-$d1},[r14]! + + vadd.i32 $d0#lo,$d0#lo,$t0#lo @ next counter value + vldr $t0#lo,[sp,#4*(16+0)] @ one + + ldmia sp,{@t[0]-@t[3]} @ load key material + add @x[0],@x[0],@t[0] @ accumulate key material + ldr @t[0],[r12],#16 @ load input + vst1.8 {$a2-$b2},[r14]! + add @x[1],@x[1],@t[1] + ldr @t[1],[r12,#-12] + vst1.8 {$c2-$d2},[r14]! + add @x[2],@x[2],@t[2] + ldr @t[2],[r12,#-8] + add @x[3],@x[3],@t[3] + ldr @t[3],[r12,#-4] +# ifdef __ARMEB__ + rev @x[0],@x[0] + rev @x[1],@x[1] + rev @x[2],@x[2] + rev @x[3],@x[3] +# endif + eor @x[0],@x[0],@t[0] @ xor with input + add @t[0],sp,#4*(4) + eor @x[1],@x[1],@t[1] + str @x[0],[r14],#16 @ store output + eor @x[2],@x[2],@t[2] + str @x[1],[r14,#-12] + eor @x[3],@x[3],@t[3] + ldmia @t[0],{@t[0]-@t[3]} @ load key material + str @x[2],[r14,#-8] + str @x[3],[r14,#-4] + + add @x[4],@x[4],@t[0] @ accumulate key material + ldr @t[0],[r12],#16 @ load input + add @x[5],@x[5],@t[1] + ldr @t[1],[r12,#-12] + add @x[6],@x[6],@t[2] + ldr @t[2],[r12,#-8] + add @x[7],@x[7],@t[3] + ldr @t[3],[r12,#-4] +# ifdef __ARMEB__ + rev @x[4],@x[4] + rev @x[5],@x[5] + rev @x[6],@x[6] + rev @x[7],@x[7] +# endif + eor @x[4],@x[4],@t[0] + add @t[0],sp,#4*(8) + eor @x[5],@x[5],@t[1] + str @x[4],[r14],#16 @ store output + eor @x[6],@x[6],@t[2] + str @x[5],[r14,#-12] + eor @x[7],@x[7],@t[3] + ldmia @t[0],{@t[0]-@t[3]} @ load key material + str @x[6],[r14,#-8] + add @x[0],sp,#4*(16+8) + str @x[7],[r14,#-4] + + ldmia @x[0],{@x[0]-@x[7]} @ load second half + + add @x[0],@x[0],@t[0] @ accumulate key material + ldr @t[0],[r12],#16 @ load input + add @x[1],@x[1],@t[1] + ldr @t[1],[r12,#-12] +# ifdef __thumb2__ + it hi +# endif + strhi @t[2],[sp,#4*(16+10)] @ copy "@x[10]" while at it + add @x[2],@x[2],@t[2] + ldr @t[2],[r12,#-8] +# ifdef __thumb2__ + it hi +# endif + strhi @t[3],[sp,#4*(16+11)] @ copy "@x[11]" while at it + add @x[3],@x[3],@t[3] + ldr @t[3],[r12,#-4] +# ifdef __ARMEB__ + rev @x[0],@x[0] + rev @x[1],@x[1] + rev @x[2],@x[2] + rev @x[3],@x[3] +# endif + eor @x[0],@x[0],@t[0] + add @t[0],sp,#4*(12) + eor @x[1],@x[1],@t[1] + str @x[0],[r14],#16 @ store output + eor @x[2],@x[2],@t[2] + str @x[1],[r14,#-12] + eor @x[3],@x[3],@t[3] + ldmia @t[0],{@t[0]-@t[3]} @ load key material + str @x[2],[r14,#-8] + str @x[3],[r14,#-4] + + add @x[4],@x[4],@t[0] @ accumulate key material + add @t[0],@t[0],#4 @ next counter value + add @x[5],@x[5],@t[1] + str @t[0],[sp,#4*(12)] @ save next counter value + ldr @t[0],[r12],#16 @ load input + add @x[6],@x[6],@t[2] + add @x[4],@x[4],#3 @ counter+3 + ldr @t[1],[r12,#-12] + add @x[7],@x[7],@t[3] + ldr @t[2],[r12,#-8] + ldr @t[3],[r12,#-4] +# ifdef __ARMEB__ + rev @x[4],@x[4] + rev @x[5],@x[5] + rev @x[6],@x[6] + rev @x[7],@x[7] +# endif + eor @x[4],@x[4],@t[0] +# ifdef __thumb2__ + it hi +# endif + ldrhi @t[0],[sp,#4*(32+2)] @ re-load len + eor @x[5],@x[5],@t[1] + eor @x[6],@x[6],@t[2] + str @x[4],[r14],#16 @ store output + eor @x[7],@x[7],@t[3] + str @x[5],[r14,#-12] + sub @t[3],@t[0],#64*4 @ len-=64*4 + str @x[6],[r14,#-8] + str @x[7],[r14,#-4] + bhi .Loop_neon_outer + + b .Ldone_neon + +.align 4 +.Lbreak_neon: + @ harmonize NEON and integer-only stack frames: load data + @ from NEON frame, but save to integer-only one; distance + @ between the two is 4*(32+4+16-32)=4*(20). + + str @t[3], [sp,#4*(20+32+2)] @ save len + add @t[3],sp,#4*(32+4) + str r12, [sp,#4*(20+32+1)] @ save inp + str r14, [sp,#4*(20+32+0)] @ save out + + ldr @x[12],[sp,#4*(16+10)] + ldr @x[14],[sp,#4*(16+11)] + vldmia @t[3],{d8-d15} @ fulfill ABI requirement + str @x[12],[sp,#4*(20+16+10)] @ copy "@x[10]" + str @x[14],[sp,#4*(20+16+11)] @ copy "@x[11]" + + ldr @t[3], [sp,#4*(15)] + ldr @x[12],[sp,#4*(12)] @ modulo-scheduled load + ldr @t[2], [sp,#4*(13)] + ldr @x[14],[sp,#4*(14)] + str @t[3], [sp,#4*(20+16+15)] + add @t[3],sp,#4*(20) + vst1.32 {$a0-$b0},[@t[3]]! @ copy key + add sp,sp,#4*(20) @ switch frame + vst1.32 {$c0-$d0},[@t[3]] + mov @t[3],#10 + b .Loop @ go integer-only + +.align 4 +.Ltail_neon: + cmp @t[3],#64*3 + bhs .L192_or_more_neon + cmp @t[3],#64*2 + bhs .L128_or_more_neon + cmp @t[3],#64*1 + bhs .L64_or_more_neon + + add @t[0],sp,#4*(8) + vst1.8 {$a0-$b0},[sp] + add @t[2],sp,#4*(0) + vst1.8 {$c0-$d0},[@t[0]] + b .Loop_tail_neon + +.align 4 +.L64_or_more_neon: + vld1.8 {$t0-$t1},[r12]! + vld1.8 {$t2-$t3},[r12]! + veor $a0,$a0,$t0 + veor $b0,$b0,$t1 + veor $c0,$c0,$t2 + veor $d0,$d0,$t3 + vst1.8 {$a0-$b0},[r14]! + vst1.8 {$c0-$d0},[r14]! + + beq .Ldone_neon + + add @t[0],sp,#4*(8) + vst1.8 {$a1-$b1},[sp] + add @t[2],sp,#4*(0) + vst1.8 {$c1-$d1},[@t[0]] + sub @t[3],@t[3],#64*1 @ len-=64*1 + b .Loop_tail_neon + +.align 4 +.L128_or_more_neon: + vld1.8 {$t0-$t1},[r12]! + vld1.8 {$t2-$t3},[r12]! + veor $a0,$a0,$t0 + veor $b0,$b0,$t1 + vld1.8 {$t0-$t1},[r12]! + veor $c0,$c0,$t2 + veor $d0,$d0,$t3 + vld1.8 {$t2-$t3},[r12]! + + veor $a1,$a1,$t0 + veor $b1,$b1,$t1 + vst1.8 {$a0-$b0},[r14]! + veor $c1,$c1,$t2 + vst1.8 {$c0-$d0},[r14]! + veor $d1,$d1,$t3 + vst1.8 {$a1-$b1},[r14]! + vst1.8 {$c1-$d1},[r14]! + + beq .Ldone_neon + + add @t[0],sp,#4*(8) + vst1.8 {$a2-$b2},[sp] + add @t[2],sp,#4*(0) + vst1.8 {$c2-$d2},[@t[0]] + sub @t[3],@t[3],#64*2 @ len-=64*2 + b .Loop_tail_neon + +.align 4 +.L192_or_more_neon: + vld1.8 {$t0-$t1},[r12]! + vld1.8 {$t2-$t3},[r12]! + veor $a0,$a0,$t0 + veor $b0,$b0,$t1 + vld1.8 {$t0-$t1},[r12]! + veor $c0,$c0,$t2 + veor $d0,$d0,$t3 + vld1.8 {$t2-$t3},[r12]! + + veor $a1,$a1,$t0 + veor $b1,$b1,$t1 + vld1.8 {$t0-$t1},[r12]! + veor $c1,$c1,$t2 + vst1.8 {$a0-$b0},[r14]! + veor $d1,$d1,$t3 + vld1.8 {$t2-$t3},[r12]! + + veor $a2,$a2,$t0 + vst1.8 {$c0-$d0},[r14]! + veor $b2,$b2,$t1 + vst1.8 {$a1-$b1},[r14]! + veor $c2,$c2,$t2 + vst1.8 {$c1-$d1},[r14]! + veor $d2,$d2,$t3 + vst1.8 {$a2-$b2},[r14]! + vst1.8 {$c2-$d2},[r14]! + + beq .Ldone_neon + + ldmia sp,{@t[0]-@t[3]} @ load key material + add @x[0],@x[0],@t[0] @ accumulate key material + add @t[0],sp,#4*(4) + add @x[1],@x[1],@t[1] + add @x[2],@x[2],@t[2] + add @x[3],@x[3],@t[3] + ldmia @t[0],{@t[0]-@t[3]} @ load key material + + add @x[4],@x[4],@t[0] @ accumulate key material + add @t[0],sp,#4*(8) + add @x[5],@x[5],@t[1] + add @x[6],@x[6],@t[2] + add @x[7],@x[7],@t[3] + ldmia @t[0],{@t[0]-@t[3]} @ load key material +# ifdef __ARMEB__ + rev @x[0],@x[0] + rev @x[1],@x[1] + rev @x[2],@x[2] + rev @x[3],@x[3] + rev @x[4],@x[4] + rev @x[5],@x[5] + rev @x[6],@x[6] + rev @x[7],@x[7] +# endif + stmia sp,{@x[0]-@x[7]} + add @x[0],sp,#4*(16+8) + + ldmia @x[0],{@x[0]-@x[7]} @ load second half + + add @x[0],@x[0],@t[0] @ accumulate key material + add @t[0],sp,#4*(12) + add @x[1],@x[1],@t[1] + add @x[2],@x[2],@t[2] + add @x[3],@x[3],@t[3] + ldmia @t[0],{@t[0]-@t[3]} @ load key material + + add @x[4],@x[4],@t[0] @ accumulate key material + add @t[0],sp,#4*(8) + add @x[5],@x[5],@t[1] + add @x[4],@x[4],#3 @ counter+3 + add @x[6],@x[6],@t[2] + add @x[7],@x[7],@t[3] + ldr @t[3],[sp,#4*(32+2)] @ re-load len +# ifdef __ARMEB__ + rev @x[0],@x[0] + rev @x[1],@x[1] + rev @x[2],@x[2] + rev @x[3],@x[3] + rev @x[4],@x[4] + rev @x[5],@x[5] + rev @x[6],@x[6] + rev @x[7],@x[7] +# endif + stmia @t[0],{@x[0]-@x[7]} + add @t[2],sp,#4*(0) + sub @t[3],@t[3],#64*3 @ len-=64*3 + +.Loop_tail_neon: + ldrb @t[0],[@t[2]],#1 @ read buffer on stack + ldrb @t[1],[r12],#1 @ read input + subs @t[3],@t[3],#1 + eor @t[0],@t[0],@t[1] + strb @t[0],[r14],#1 @ store output + bne .Loop_tail_neon + +.Ldone_neon: + add sp,sp,#4*(32+4) + vldmia sp,{d8-d15} + add sp,sp,#4*(16+3) + ldmia sp!,{r4-r11,pc} +.size ChaCha20_neon,.-ChaCha20_neon +.comm OPENSSL_armcap_P,4,4 +#endif +___ +}}} + +foreach (split("\n",$code)) { + s/\`([^\`]*)\`/eval $1/geo; + + s/\bq([0-9]+)#(lo|hi)/sprintf "d%d",2*$1+($2 eq "hi")/geo; + + print $_,"\n"; +} +close STDOUT; diff --git a/crypto/chacha20/chacha20-arm.s b/crypto/chacha20/chacha20-arm.s new file mode 100644 index 0000000..2e22fd1 --- /dev/null +++ b/crypto/chacha20/chacha20-arm.s @@ -0,0 +1,1475 @@ +/* SPDX-License-Identifier: OpenSSL OR (BSD-3-Clause OR GPL-2.0) + * + * Copyright (C) 2015-2018 Jason A. Donenfeld . All Rights Reserved. + * Copyright 2016 The OpenSSL Project Authors. All Rights Reserved. + */ + +/*#include */ + +.text +#if defined(__thumb2__) || defined(__clang__) +.syntax unified +#endif +#if defined(__thumb2__) +.thumb +#else +.code 32 +#endif + +#if defined(__thumb2__) || defined(__clang__) +#define ldrbhs ldrbhs +#endif + +.align 5 +.Lsigma: +.long 0x61707865,0x3320646e,0x79622d32,0x6b206574 @ endian-neutral +.Lone: +.long 1,0,0,0 +.word -1 + +#if __ARM_ARCH__ >= 7 +.arch armv7-a +.fpu neon + +.align 5 +.globl chacha20_neon +.type chacha20_neon,%function +chacha20_neon: + ldr r12,[sp,#0] @ pull pointer to counter and nonce + stmdb sp!,{r0-r2,r4-r11,lr} + cmp r2,#0 @ len==0? +#ifdef __thumb2__ + itt eq +#endif + addeq sp,sp,#4*3 + beq .Lno_data_neon + cmp r2,#192 @ test len + bls .Lshort +.Lchacha20_neon_begin: + adr r14,.Lsigma + vstmdb sp!,{d8-d15} @ ABI spec says so + stmdb sp!,{r0-r3} + + vld1.32 {q1-q2},[r3] @ load key + ldmia r3,{r4-r11} @ load key + + sub sp,sp,#4*(16+16) + vld1.32 {q3},[r12] @ load counter and nonce + add r12,sp,#4*8 + ldmia r14,{r0-r3} @ load sigma + vld1.32 {q0},[r14]! @ load sigma + vld1.32 {q12},[r14] @ one + vst1.32 {q2-q3},[r12] @ copy 1/2key|counter|nonce + vst1.32 {q0-q1},[sp] @ copy sigma|1/2key + + str r10,[sp,#4*(16+10)] @ off-load "rx" + str r11,[sp,#4*(16+11)] @ off-load "rx" + vshl.i32 d26,d24,#1 @ two + vstr d24,[sp,#4*(16+0)] + vshl.i32 d28,d24,#2 @ four + vstr d26,[sp,#4*(16+2)] + vmov q4,q0 + vstr d28,[sp,#4*(16+4)] + vmov q8,q0 + vmov q5,q1 + vmov q9,q1 + b .Loop_neon_enter + +.align 4 +.Loop_neon_outer: + ldmia sp,{r0-r9} @ load key material + cmp r11,#64*2 @ if len<=64*2 + bls .Lbreak_neon @ switch to integer-only + vmov q4,q0 + str r11,[sp,#4*(32+2)] @ save len + vmov q8,q0 + str r12, [sp,#4*(32+1)] @ save inp + vmov q5,q1 + str r14, [sp,#4*(32+0)] @ save out + vmov q9,q1 +.Loop_neon_enter: + ldr r11, [sp,#4*(15)] + vadd.i32 q7,q3,q12 @ counter+1 + ldr r12,[sp,#4*(12)] @ modulo-scheduled load + vmov q6,q2 + ldr r10, [sp,#4*(13)] + vmov q10,q2 + ldr r14,[sp,#4*(14)] + vadd.i32 q11,q7,q12 @ counter+2 + str r11, [sp,#4*(16+15)] + mov r11,#10 + add r12,r12,#3 @ counter+3 + b .Loop_neon + +.align 4 +.Loop_neon: + subs r11,r11,#1 + vadd.i32 q0,q0,q1 + add r0,r0,r4 + vadd.i32 q4,q4,q5 + mov r12,r12,ror#16 + vadd.i32 q8,q8,q9 + add r1,r1,r5 + veor q3,q3,q0 + mov r10,r10,ror#16 + veor q7,q7,q4 + eor r12,r12,r0,ror#16 + veor q11,q11,q8 + eor r10,r10,r1,ror#16 + vrev32.16 q3,q3 + add r8,r8,r12 + vrev32.16 q7,q7 + mov r4,r4,ror#20 + vrev32.16 q11,q11 + add r9,r9,r10 + vadd.i32 q2,q2,q3 + mov r5,r5,ror#20 + vadd.i32 q6,q6,q7 + eor r4,r4,r8,ror#20 + vadd.i32 q10,q10,q11 + eor r5,r5,r9,ror#20 + veor q12,q1,q2 + add r0,r0,r4 + veor q13,q5,q6 + mov r12,r12,ror#24 + veor q14,q9,q10 + add r1,r1,r5 + vshr.u32 q1,q12,#20 + mov r10,r10,ror#24 + vshr.u32 q5,q13,#20 + eor r12,r12,r0,ror#24 + vshr.u32 q9,q14,#20 + eor r10,r10,r1,ror#24 + vsli.32 q1,q12,#12 + add r8,r8,r12 + vsli.32 q5,q13,#12 + mov r4,r4,ror#25 + vsli.32 q9,q14,#12 + add r9,r9,r10 + vadd.i32 q0,q0,q1 + mov r5,r5,ror#25 + vadd.i32 q4,q4,q5 + str r10,[sp,#4*(16+13)] + vadd.i32 q8,q8,q9 + ldr r10,[sp,#4*(16+15)] + veor q12,q3,q0 + eor r4,r4,r8,ror#25 + veor q13,q7,q4 + eor r5,r5,r9,ror#25 + veor q14,q11,q8 + str r8,[sp,#4*(16+8)] + vshr.u32 q3,q12,#24 + ldr r8,[sp,#4*(16+10)] + vshr.u32 q7,q13,#24 + add r2,r2,r6 + vshr.u32 q11,q14,#24 + mov r14,r14,ror#16 + vsli.32 q3,q12,#8 + str r9,[sp,#4*(16+9)] + vsli.32 q7,q13,#8 + ldr r9,[sp,#4*(16+11)] + vsli.32 q11,q14,#8 + add r3,r3,r7 + vadd.i32 q2,q2,q3 + mov r10,r10,ror#16 + vadd.i32 q6,q6,q7 + eor r14,r14,r2,ror#16 + vadd.i32 q10,q10,q11 + eor r10,r10,r3,ror#16 + veor q12,q1,q2 + add r8,r8,r14 + veor q13,q5,q6 + mov r6,r6,ror#20 + veor q14,q9,q10 + add r9,r9,r10 + vshr.u32 q1,q12,#25 + mov r7,r7,ror#20 + vshr.u32 q5,q13,#25 + eor r6,r6,r8,ror#20 + vshr.u32 q9,q14,#25 + eor r7,r7,r9,ror#20 + vsli.32 q1,q12,#7 + add r2,r2,r6 + vsli.32 q5,q13,#7 + mov r14,r14,ror#24 + vsli.32 q9,q14,#7 + add r3,r3,r7 + vext.8 q2,q2,q2,#8 + mov r10,r10,ror#24 + vext.8 q6,q6,q6,#8 + eor r14,r14,r2,ror#24 + vext.8 q10,q10,q10,#8 + eor r10,r10,r3,ror#24 + vext.8 q1,q1,q1,#4 + add r8,r8,r14 + vext.8 q5,q5,q5,#4 + mov r6,r6,ror#25 + vext.8 q9,q9,q9,#4 + add r9,r9,r10 + vext.8 q3,q3,q3,#12 + mov r7,r7,ror#25 + vext.8 q7,q7,q7,#12 + eor r6,r6,r8,ror#25 + vext.8 q11,q11,q11,#12 + eor r7,r7,r9,ror#25 + vadd.i32 q0,q0,q1 + add r0,r0,r5 + vadd.i32 q4,q4,q5 + mov r10,r10,ror#16 + vadd.i32 q8,q8,q9 + add r1,r1,r6 + veor q3,q3,q0 + mov r12,r12,ror#16 + veor q7,q7,q4 + eor r10,r10,r0,ror#16 + veor q11,q11,q8 + eor r12,r12,r1,ror#16 + vrev32.16 q3,q3 + add r8,r8,r10 + vrev32.16 q7,q7 + mov r5,r5,ror#20 + vrev32.16 q11,q11 + add r9,r9,r12 + vadd.i32 q2,q2,q3 + mov r6,r6,ror#20 + vadd.i32 q6,q6,q7 + eor r5,r5,r8,ror#20 + vadd.i32 q10,q10,q11 + eor r6,r6,r9,ror#20 + veor q12,q1,q2 + add r0,r0,r5 + veor q13,q5,q6 + mov r10,r10,ror#24 + veor q14,q9,q10 + add r1,r1,r6 + vshr.u32 q1,q12,#20 + mov r12,r12,ror#24 + vshr.u32 q5,q13,#20 + eor r10,r10,r0,ror#24 + vshr.u32 q9,q14,#20 + eor r12,r12,r1,ror#24 + vsli.32 q1,q12,#12 + add r8,r8,r10 + vsli.32 q5,q13,#12 + mov r5,r5,ror#25 + vsli.32 q9,q14,#12 + str r10,[sp,#4*(16+15)] + vadd.i32 q0,q0,q1 + ldr r10,[sp,#4*(16+13)] + vadd.i32 q4,q4,q5 + add r9,r9,r12 + vadd.i32 q8,q8,q9 + mov r6,r6,ror#25 + veor q12,q3,q0 + eor r5,r5,r8,ror#25 + veor q13,q7,q4 + eor r6,r6,r9,ror#25 + veor q14,q11,q8 + str r8,[sp,#4*(16+10)] + vshr.u32 q3,q12,#24 + ldr r8,[sp,#4*(16+8)] + vshr.u32 q7,q13,#24 + add r2,r2,r7 + vshr.u32 q11,q14,#24 + mov r10,r10,ror#16 + vsli.32 q3,q12,#8 + str r9,[sp,#4*(16+11)] + vsli.32 q7,q13,#8 + ldr r9,[sp,#4*(16+9)] + vsli.32 q11,q14,#8 + add r3,r3,r4 + vadd.i32 q2,q2,q3 + mov r14,r14,ror#16 + vadd.i32 q6,q6,q7 + eor r10,r10,r2,ror#16 + vadd.i32 q10,q10,q11 + eor r14,r14,r3,ror#16 + veor q12,q1,q2 + add r8,r8,r10 + veor q13,q5,q6 + mov r7,r7,ror#20 + veor q14,q9,q10 + add r9,r9,r14 + vshr.u32 q1,q12,#25 + mov r4,r4,ror#20 + vshr.u32 q5,q13,#25 + eor r7,r7,r8,ror#20 + vshr.u32 q9,q14,#25 + eor r4,r4,r9,ror#20 + vsli.32 q1,q12,#7 + add r2,r2,r7 + vsli.32 q5,q13,#7 + mov r10,r10,ror#24 + vsli.32 q9,q14,#7 + add r3,r3,r4 + vext.8 q2,q2,q2,#8 + mov r14,r14,ror#24 + vext.8 q6,q6,q6,#8 + eor r10,r10,r2,ror#24 + vext.8 q10,q10,q10,#8 + eor r14,r14,r3,ror#24 + vext.8 q1,q1,q1,#12 + add r8,r8,r10 + vext.8 q5,q5,q5,#12 + mov r7,r7,ror#25 + vext.8 q9,q9,q9,#12 + add r9,r9,r14 + vext.8 q3,q3,q3,#4 + mov r4,r4,ror#25 + vext.8 q7,q7,q7,#4 + eor r7,r7,r8,ror#25 + vext.8 q11,q11,q11,#4 + eor r4,r4,r9,ror#25 + bne .Loop_neon + + add r11,sp,#32 + vld1.32 {q12-q13},[sp] @ load key material + vld1.32 {q14-q15},[r11] + + ldr r11,[sp,#4*(32+2)] @ load len + + str r8, [sp,#4*(16+8)] @ modulo-scheduled store + str r9, [sp,#4*(16+9)] + str r12,[sp,#4*(16+12)] + str r10, [sp,#4*(16+13)] + str r14,[sp,#4*(16+14)] + + @ at this point we have first half of 512-bit result in + @ rx and second half at sp+4*(16+8) + + ldr r12,[sp,#4*(32+1)] @ load inp + ldr r14,[sp,#4*(32+0)] @ load out + + vadd.i32 q0,q0,q12 @ accumulate key material + vadd.i32 q4,q4,q12 + vadd.i32 q8,q8,q12 + vldr d24,[sp,#4*(16+0)] @ one + + vadd.i32 q1,q1,q13 + vadd.i32 q5,q5,q13 + vadd.i32 q9,q9,q13 + vldr d26,[sp,#4*(16+2)] @ two + + vadd.i32 q2,q2,q14 + vadd.i32 q6,q6,q14 + vadd.i32 q10,q10,q14 + vadd.i32 d14,d14,d24 @ counter+1 + vadd.i32 d22,d22,d26 @ counter+2 + + vadd.i32 q3,q3,q15 + vadd.i32 q7,q7,q15 + vadd.i32 q11,q11,q15 + + cmp r11,#64*4 + blo .Ltail_neon + + vld1.8 {q12-q13},[r12]! @ load input + mov r11,sp + vld1.8 {q14-q15},[r12]! + veor q0,q0,q12 @ xor with input + veor q1,q1,q13 + vld1.8 {q12-q13},[r12]! + veor q2,q2,q14 + veor q3,q3,q15 + vld1.8 {q14-q15},[r12]! + + veor q4,q4,q12 + vst1.8 {q0-q1},[r14]! @ store output + veor q5,q5,q13 + vld1.8 {q12-q13},[r12]! + veor q6,q6,q14 + vst1.8 {q2-q3},[r14]! + veor q7,q7,q15 + vld1.8 {q14-q15},[r12]! + + veor q8,q8,q12 + vld1.32 {q0-q1},[r11]! @ load for next iteration + veor d25,d25,d25 + vldr d24,[sp,#4*(16+4)] @ four + veor q9,q9,q13 + vld1.32 {q2-q3},[r11] + veor q10,q10,q14 + vst1.8 {q4-q5},[r14]! + veor q11,q11,q15 + vst1.8 {q6-q7},[r14]! + + vadd.i32 d6,d6,d24 @ next counter value + vldr d24,[sp,#4*(16+0)] @ one + + ldmia sp,{r8-r11} @ load key material + add r0,r0,r8 @ accumulate key material + ldr r8,[r12],#16 @ load input + vst1.8 {q8-q9},[r14]! + add r1,r1,r9 + ldr r9,[r12,#-12] + vst1.8 {q10-q11},[r14]! + add r2,r2,r10 + ldr r10,[r12,#-8] + add r3,r3,r11 + ldr r11,[r12,#-4] +#ifdef __ARMEB__ + rev r0,r0 + rev r1,r1 + rev r2,r2 + rev r3,r3 +#endif + eor r0,r0,r8 @ xor with input + add r8,sp,#4*(4) + eor r1,r1,r9 + str r0,[r14],#16 @ store output + eor r2,r2,r10 + str r1,[r14,#-12] + eor r3,r3,r11 + ldmia r8,{r8-r11} @ load key material + str r2,[r14,#-8] + str r3,[r14,#-4] + + add r4,r4,r8 @ accumulate key material + ldr r8,[r12],#16 @ load input + add r5,r5,r9 + ldr r9,[r12,#-12] + add r6,r6,r10 + ldr r10,[r12,#-8] + add r7,r7,r11 + ldr r11,[r12,#-4] +#ifdef __ARMEB__ + rev r4,r4 + rev r5,r5 + rev r6,r6 + rev r7,r7 +#endif + eor r4,r4,r8 + add r8,sp,#4*(8) + eor r5,r5,r9 + str r4,[r14],#16 @ store output + eor r6,r6,r10 + str r5,[r14,#-12] + eor r7,r7,r11 + ldmia r8,{r8-r11} @ load key material + str r6,[r14,#-8] + add r0,sp,#4*(16+8) + str r7,[r14,#-4] + + ldmia r0,{r0-r7} @ load second half + + add r0,r0,r8 @ accumulate key material + ldr r8,[r12],#16 @ load input + add r1,r1,r9 + ldr r9,[r12,#-12] +#ifdef __thumb2__ + it hi +#endif + strhi r10,[sp,#4*(16+10)] @ copy "rx" while at it + add r2,r2,r10 + ldr r10,[r12,#-8] +#ifdef __thumb2__ + it hi +#endif + strhi r11,[sp,#4*(16+11)] @ copy "rx" while at it + add r3,r3,r11 + ldr r11,[r12,#-4] +#ifdef __ARMEB__ + rev r0,r0 + rev r1,r1 + rev r2,r2 + rev r3,r3 +#endif + eor r0,r0,r8 + add r8,sp,#4*(12) + eor r1,r1,r9 + str r0,[r14],#16 @ store output + eor r2,r2,r10 + str r1,[r14,#-12] + eor r3,r3,r11 + ldmia r8,{r8-r11} @ load key material + str r2,[r14,#-8] + str r3,[r14,#-4] + + add r4,r4,r8 @ accumulate key material + add r8,r8,#4 @ next counter value + add r5,r5,r9 + str r8,[sp,#4*(12)] @ save next counter value + ldr r8,[r12],#16 @ load input + add r6,r6,r10 + add r4,r4,#3 @ counter+3 + ldr r9,[r12,#-12] + add r7,r7,r11 + ldr r10,[r12,#-8] + ldr r11,[r12,#-4] +#ifdef __ARMEB__ + rev r4,r4 + rev r5,r5 + rev r6,r6 + rev r7,r7 +#endif + eor r4,r4,r8 +#ifdef __thumb2__ + it hi +#endif + ldrhi r8,[sp,#4*(32+2)] @ re-load len + eor r5,r5,r9 + eor r6,r6,r10 + str r4,[r14],#16 @ store output + eor r7,r7,r11 + str r5,[r14,#-12] + sub r11,r8,#64*4 @ len-=64*4 + str r6,[r14,#-8] + str r7,[r14,#-4] + bhi .Loop_neon_outer + + b .Ldone_neon + +.align 4 +.Lbreak_neon: + @ harmonize NEON and integer-only stack frames: load data + @ from NEON frame, but save to integer-only one; distance + @ between the two is 4*(32+4+16-32)=4*(20). + + str r11, [sp,#4*(20+32+2)] @ save len + add r11,sp,#4*(32+4) + str r12, [sp,#4*(20+32+1)] @ save inp + str r14, [sp,#4*(20+32+0)] @ save out + + ldr r12,[sp,#4*(16+10)] + ldr r14,[sp,#4*(16+11)] + vldmia r11,{d8-d15} @ fulfill ABI requirement + str r12,[sp,#4*(20+16+10)] @ copy "rx" + str r14,[sp,#4*(20+16+11)] @ copy "rx" + + ldr r11, [sp,#4*(15)] + ldr r12,[sp,#4*(12)] @ modulo-scheduled load + ldr r10, [sp,#4*(13)] + ldr r14,[sp,#4*(14)] + str r11, [sp,#4*(20+16+15)] + add r11,sp,#4*(20) + vst1.32 {q0-q1},[r11]! @ copy key + add sp,sp,#4*(20) @ switch frame + vst1.32 {q2-q3},[r11] + mov r11,#10 + b .Loop @ go integer-only + +.align 4 +.Ltail_neon: + cmp r11,#64*3 + bhs .L192_or_more_neon + cmp r11,#64*2 + bhs .L128_or_more_neon + cmp r11,#64*1 + bhs .L64_or_more_neon + + add r8,sp,#4*(8) + vst1.8 {q0-q1},[sp] + add r10,sp,#4*(0) + vst1.8 {q2-q3},[r8] + b .Loop_tail_neon + +.align 4 +.L64_or_more_neon: + vld1.8 {q12-q13},[r12]! + vld1.8 {q14-q15},[r12]! + veor q0,q0,q12 + veor q1,q1,q13 + veor q2,q2,q14 + veor q3,q3,q15 + vst1.8 {q0-q1},[r14]! + vst1.8 {q2-q3},[r14]! + + beq .Ldone_neon + + add r8,sp,#4*(8) + vst1.8 {q4-q5},[sp] + add r10,sp,#4*(0) + vst1.8 {q6-q7},[r8] + sub r11,r11,#64*1 @ len-=64*1 + b .Loop_tail_neon + +.align 4 +.L128_or_more_neon: + vld1.8 {q12-q13},[r12]! + vld1.8 {q14-q15},[r12]! + veor q0,q0,q12 + veor q1,q1,q13 + vld1.8 {q12-q13},[r12]! + veor q2,q2,q14 + veor q3,q3,q15 + vld1.8 {q14-q15},[r12]! + + veor q4,q4,q12 + veor q5,q5,q13 + vst1.8 {q0-q1},[r14]! + veor q6,q6,q14 + vst1.8 {q2-q3},[r14]! + veor q7,q7,q15 + vst1.8 {q4-q5},[r14]! + vst1.8 {q6-q7},[r14]! + + beq .Ldone_neon + + add r8,sp,#4*(8) + vst1.8 {q8-q9},[sp] + add r10,sp,#4*(0) + vst1.8 {q10-q11},[r8] + sub r11,r11,#64*2 @ len-=64*2 + b .Loop_tail_neon + +.align 4 +.L192_or_more_neon: + vld1.8 {q12-q13},[r12]! + vld1.8 {q14-q15},[r12]! + veor q0,q0,q12 + veor q1,q1,q13 + vld1.8 {q12-q13},[r12]! + veor q2,q2,q14 + veor q3,q3,q15 + vld1.8 {q14-q15},[r12]! + + veor q4,q4,q12 + veor q5,q5,q13 + vld1.8 {q12-q13},[r12]! + veor q6,q6,q14 + vst1.8 {q0-q1},[r14]! + veor q7,q7,q15 + vld1.8 {q14-q15},[r12]! + + veor q8,q8,q12 + vst1.8 {q2-q3},[r14]! + veor q9,q9,q13 + vst1.8 {q4-q5},[r14]! + veor q10,q10,q14 + vst1.8 {q6-q7},[r14]! + veor q11,q11,q15 + vst1.8 {q8-q9},[r14]! + vst1.8 {q10-q11},[r14]! + + beq .Ldone_neon + + ldmia sp,{r8-r11} @ load key material + add r0,r0,r8 @ accumulate key material + add r8,sp,#4*(4) + add r1,r1,r9 + add r2,r2,r10 + add r3,r3,r11 + ldmia r8,{r8-r11} @ load key material + + add r4,r4,r8 @ accumulate key material + add r8,sp,#4*(8) + add r5,r5,r9 + add r6,r6,r10 + add r7,r7,r11 + ldmia r8,{r8-r11} @ load key material +#ifdef __ARMEB__ + rev r0,r0 + rev r1,r1 + rev r2,r2 + rev r3,r3 + rev r4,r4 + rev r5,r5 + rev r6,r6 + rev r7,r7 +#endif + stmia sp,{r0-r7} + add r0,sp,#4*(16+8) + + ldmia r0,{r0-r7} @ load second half + + add r0,r0,r8 @ accumulate key material + add r8,sp,#4*(12) + add r1,r1,r9 + add r2,r2,r10 + add r3,r3,r11 + ldmia r8,{r8-r11} @ load key material + + add r4,r4,r8 @ accumulate key material + add r8,sp,#4*(8) + add r5,r5,r9 + add r4,r4,#3 @ counter+3 + add r6,r6,r10 + add r7,r7,r11 + ldr r11,[sp,#4*(32+2)] @ re-load len +#ifdef __ARMEB__ + rev r0,r0 + rev r1,r1 + rev r2,r2 + rev r3,r3 + rev r4,r4 + rev r5,r5 + rev r6,r6 + rev r7,r7 +#endif + stmia r8,{r0-r7} + add r10,sp,#4*(0) + sub r11,r11,#64*3 @ len-=64*3 + +.Loop_tail_neon: + ldrb r8,[r10],#1 @ read buffer on stack + ldrb r9,[r12],#1 @ read input + subs r11,r11,#1 + eor r8,r8,r9 + strb r8,[r14],#1 @ store output + bne .Loop_tail_neon + +.Ldone_neon: + add sp,sp,#4*(32+4) + vldmia sp,{d8-d15} + add sp,sp,#4*(16+3) +.Lno_data_neon: + ldmia sp!,{r4-r11,pc} +.size chacha20_neon,.-chacha20_neon +#endif + +.align 5 +.Lsigma2: +.long 0x61707865,0x3320646e,0x79622d32,0x6b206574 @ endian-neutral +.Lone2: +.long 1,0,0,0 +.word -1 + +.align 5 +.globl chacha20_arm +.type chacha20_arm,%function +chacha20_arm: + ldr r12,[sp,#0] @ pull pointer to counter and nonce + stmdb sp!,{r0-r2,r4-r11,lr} + cmp r2,#0 @ len==0? +#ifdef __thumb2__ + itt eq +#endif + addeq sp,sp,#4*3 + beq .Lno_data_arm +.Lshort: + ldmia r12,{r4-r7} @ load counter and nonce + sub sp,sp,#4*(16) @ off-load area +#if __ARM_ARCH__ < 7 && !defined(__thumb2__) + sub r14,pc,#100 @ .Lsigma2 +#else + adr r14,.Lsigma2 @ .Lsigma2 +#endif + stmdb sp!,{r4-r7} @ copy counter and nonce + ldmia r3,{r4-r11} @ load key + ldmia r14,{r0-r3} @ load sigma + stmdb sp!,{r4-r11} @ copy key + stmdb sp!,{r0-r3} @ copy sigma + str r10,[sp,#4*(16+10)] @ off-load "rx" + str r11,[sp,#4*(16+11)] @ off-load "rx" + b .Loop_outer_enter + +.align 4 +.Loop_outer: + ldmia sp,{r0-r9} @ load key material + str r11,[sp,#4*(32+2)] @ save len + str r12, [sp,#4*(32+1)] @ save inp + str r14, [sp,#4*(32+0)] @ save out +.Loop_outer_enter: + ldr r11, [sp,#4*(15)] + ldr r12,[sp,#4*(12)] @ modulo-scheduled load + ldr r10, [sp,#4*(13)] + ldr r14,[sp,#4*(14)] + str r11, [sp,#4*(16+15)] + mov r11,#10 + b .Loop + +.align 4 +.Loop: + subs r11,r11,#1 + add r0,r0,r4 + mov r12,r12,ror#16 + add r1,r1,r5 + mov r10,r10,ror#16 + eor r12,r12,r0,ror#16 + eor r10,r10,r1,ror#16 + add r8,r8,r12 + mov r4,r4,ror#20 + add r9,r9,r10 + mov r5,r5,ror#20 + eor r4,r4,r8,ror#20 + eor r5,r5,r9,ror#20 + add r0,r0,r4 + mov r12,r12,ror#24 + add r1,r1,r5 + mov r10,r10,ror#24 + eor r12,r12,r0,ror#24 + eor r10,r10,r1,ror#24 + add r8,r8,r12 + mov r4,r4,ror#25 + add r9,r9,r10 + mov r5,r5,ror#25 + str r10,[sp,#4*(16+13)] + ldr r10,[sp,#4*(16+15)] + eor r4,r4,r8,ror#25 + eor r5,r5,r9,ror#25 + str r8,[sp,#4*(16+8)] + ldr r8,[sp,#4*(16+10)] + add r2,r2,r6 + mov r14,r14,ror#16 + str r9,[sp,#4*(16+9)] + ldr r9,[sp,#4*(16+11)] + add r3,r3,r7 + mov r10,r10,ror#16 + eor r14,r14,r2,ror#16 + eor r10,r10,r3,ror#16 + add r8,r8,r14 + mov r6,r6,ror#20 + add r9,r9,r10 + mov r7,r7,ror#20 + eor r6,r6,r8,ror#20 + eor r7,r7,r9,ror#20 + add r2,r2,r6 + mov r14,r14,ror#24 + add r3,r3,r7 + mov r10,r10,ror#24 + eor r14,r14,r2,ror#24 + eor r10,r10,r3,ror#24 + add r8,r8,r14 + mov r6,r6,ror#25 + add r9,r9,r10 + mov r7,r7,ror#25 + eor r6,r6,r8,ror#25 + eor r7,r7,r9,ror#25 + add r0,r0,r5 + mov r10,r10,ror#16 + add r1,r1,r6 + mov r12,r12,ror#16 + eor r10,r10,r0,ror#16 + eor r12,r12,r1,ror#16 + add r8,r8,r10 + mov r5,r5,ror#20 + add r9,r9,r12 + mov r6,r6,ror#20 + eor r5,r5,r8,ror#20 + eor r6,r6,r9,ror#20 + add r0,r0,r5 + mov r10,r10,ror#24 + add r1,r1,r6 + mov r12,r12,ror#24 + eor r10,r10,r0,ror#24 + eor r12,r12,r1,ror#24 + add r8,r8,r10 + mov r5,r5,ror#25 + str r10,[sp,#4*(16+15)] + ldr r10,[sp,#4*(16+13)] + add r9,r9,r12 + mov r6,r6,ror#25 + eor r5,r5,r8,ror#25 + eor r6,r6,r9,ror#25 + str r8,[sp,#4*(16+10)] + ldr r8,[sp,#4*(16+8)] + add r2,r2,r7 + mov r10,r10,ror#16 + str r9,[sp,#4*(16+11)] + ldr r9,[sp,#4*(16+9)] + add r3,r3,r4 + mov r14,r14,ror#16 + eor r10,r10,r2,ror#16 + eor r14,r14,r3,ror#16 + add r8,r8,r10 + mov r7,r7,ror#20 + add r9,r9,r14 + mov r4,r4,ror#20 + eor r7,r7,r8,ror#20 + eor r4,r4,r9,ror#20 + add r2,r2,r7 + mov r10,r10,ror#24 + add r3,r3,r4 + mov r14,r14,ror#24 + eor r10,r10,r2,ror#24 + eor r14,r14,r3,ror#24 + add r8,r8,r10 + mov r7,r7,ror#25 + add r9,r9,r14 + mov r4,r4,ror#25 + eor r7,r7,r8,ror#25 + eor r4,r4,r9,ror#25 + bne .Loop + + ldr r11,[sp,#4*(32+2)] @ load len + + str r8, [sp,#4*(16+8)] @ modulo-scheduled store + str r9, [sp,#4*(16+9)] + str r12,[sp,#4*(16+12)] + str r10, [sp,#4*(16+13)] + str r14,[sp,#4*(16+14)] + + @ at this point we have first half of 512-bit result in + @ rx and second half at sp+4*(16+8) + + cmp r11,#64 @ done yet? +#ifdef __thumb2__ + itete lo +#endif + addlo r12,sp,#4*(0) @ shortcut or ... + ldrhs r12,[sp,#4*(32+1)] @ ... load inp + addlo r14,sp,#4*(0) @ shortcut or ... + ldrhs r14,[sp,#4*(32+0)] @ ... load out + + ldr r8,[sp,#4*(0)] @ load key material + ldr r9,[sp,#4*(1)] + +#if __ARM_ARCH__ >= 6 || !defined(__ARMEB__) +#if __ARM_ARCH__ < 7 + orr r10,r12,r14 + tst r10,#3 @ are input and output aligned? + ldr r10,[sp,#4*(2)] + bne .Lunaligned + cmp r11,#64 @ restore flags +#else + ldr r10,[sp,#4*(2)] +#endif + ldr r11,[sp,#4*(3)] + + add r0,r0,r8 @ accumulate key material + add r1,r1,r9 +#ifdef __thumb2__ + itt hs +#endif + ldrhs r8,[r12],#16 @ load input + ldrhs r9,[r12,#-12] + + add r2,r2,r10 + add r3,r3,r11 +#ifdef __thumb2__ + itt hs +#endif + ldrhs r10,[r12,#-8] + ldrhs r11,[r12,#-4] +#if __ARM_ARCH__ >= 6 && defined(__ARMEB__) + rev r0,r0 + rev r1,r1 + rev r2,r2 + rev r3,r3 +#endif +#ifdef __thumb2__ + itt hs +#endif + eorhs r0,r0,r8 @ xor with input + eorhs r1,r1,r9 + add r8,sp,#4*(4) + str r0,[r14],#16 @ store output +#ifdef __thumb2__ + itt hs +#endif + eorhs r2,r2,r10 + eorhs r3,r3,r11 + ldmia r8,{r8-r11} @ load key material + str r1,[r14,#-12] + str r2,[r14,#-8] + str r3,[r14,#-4] + + add r4,r4,r8 @ accumulate key material + add r5,r5,r9 +#ifdef __thumb2__ + itt hs +#endif + ldrhs r8,[r12],#16 @ load input + ldrhs r9,[r12,#-12] + add r6,r6,r10 + add r7,r7,r11 +#ifdef __thumb2__ + itt hs +#endif + ldrhs r10,[r12,#-8] + ldrhs r11,[r12,#-4] +#if __ARM_ARCH__ >= 6 && defined(__ARMEB__) + rev r4,r4 + rev r5,r5 + rev r6,r6 + rev r7,r7 +#endif +#ifdef __thumb2__ + itt hs +#endif + eorhs r4,r4,r8 + eorhs r5,r5,r9 + add r8,sp,#4*(8) + str r4,[r14],#16 @ store output +#ifdef __thumb2__ + itt hs +#endif + eorhs r6,r6,r10 + eorhs r7,r7,r11 + str r5,[r14,#-12] + ldmia r8,{r8-r11} @ load key material + str r6,[r14,#-8] + add r0,sp,#4*(16+8) + str r7,[r14,#-4] + + ldmia r0,{r0-r7} @ load second half + + add r0,r0,r8 @ accumulate key material + add r1,r1,r9 +#ifdef __thumb2__ + itt hs +#endif + ldrhs r8,[r12],#16 @ load input + ldrhs r9,[r12,#-12] +#ifdef __thumb2__ + itt hi +#endif + strhi r10,[sp,#4*(16+10)] @ copy "rx" while at it + strhi r11,[sp,#4*(16+11)] @ copy "rx" while at it + add r2,r2,r10 + add r3,r3,r11 +#ifdef __thumb2__ + itt hs +#endif + ldrhs r10,[r12,#-8] + ldrhs r11,[r12,#-4] +#if __ARM_ARCH__ >= 6 && defined(__ARMEB__) + rev r0,r0 + rev r1,r1 + rev r2,r2 + rev r3,r3 +#endif +#ifdef __thumb2__ + itt hs +#endif + eorhs r0,r0,r8 + eorhs r1,r1,r9 + add r8,sp,#4*(12) + str r0,[r14],#16 @ store output +#ifdef __thumb2__ + itt hs +#endif + eorhs r2,r2,r10 + eorhs r3,r3,r11 + str r1,[r14,#-12] + ldmia r8,{r8-r11} @ load key material + str r2,[r14,#-8] + str r3,[r14,#-4] + + add r4,r4,r8 @ accumulate key material + add r5,r5,r9 +#ifdef __thumb2__ + itt hi +#endif + addhi r8,r8,#1 @ next counter value + strhi r8,[sp,#4*(12)] @ save next counter value +#ifdef __thumb2__ + itt hs +#endif + ldrhs r8,[r12],#16 @ load input + ldrhs r9,[r12,#-12] + add r6,r6,r10 + add r7,r7,r11 +#ifdef __thumb2__ + itt hs +#endif + ldrhs r10,[r12,#-8] + ldrhs r11,[r12,#-4] +#if __ARM_ARCH__ >= 6 && defined(__ARMEB__) + rev r4,r4 + rev r5,r5 + rev r6,r6 + rev r7,r7 +#endif +#ifdef __thumb2__ + itt hs +#endif + eorhs r4,r4,r8 + eorhs r5,r5,r9 +#ifdef __thumb2__ + it ne +#endif + ldrne r8,[sp,#4*(32+2)] @ re-load len +#ifdef __thumb2__ + itt hs +#endif + eorhs r6,r6,r10 + eorhs r7,r7,r11 + str r4,[r14],#16 @ store output + str r5,[r14,#-12] +#ifdef __thumb2__ + it hs +#endif + subhs r11,r8,#64 @ len-=64 + str r6,[r14,#-8] + str r7,[r14,#-4] + bhi .Loop_outer + + beq .Ldone +#if __ARM_ARCH__ < 7 + b .Ltail + +.align 4 +.Lunaligned: @ unaligned endian-neutral path + cmp r11,#64 @ restore flags +#endif +#endif +#if __ARM_ARCH__ < 7 + ldr r11,[sp,#4*(3)] + add r0,r0,r8 @ accumulate key material + add r1,r1,r9 + add r2,r2,r10 +#ifdef __thumb2__ + itete lo +#endif + eorlo r8,r8,r8 @ zero or ... + ldrbhs r8,[r12],#16 @ ... load input + eorlo r9,r9,r9 + ldrbhs r9,[r12,#-12] + + add r3,r3,r11 +#ifdef __thumb2__ + itete lo +#endif + eorlo r10,r10,r10 + ldrbhs r10,[r12,#-8] + eorlo r11,r11,r11 + ldrbhs r11,[r12,#-4] + + eor r0,r8,r0 @ xor with input (or zero) + eor r1,r9,r1 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r8,[r12,#-15] @ load more input + ldrbhs r9,[r12,#-11] + eor r2,r10,r2 + strb r0,[r14],#16 @ store output + eor r3,r11,r3 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r10,[r12,#-7] + ldrbhs r11,[r12,#-3] + strb r1,[r14,#-12] + eor r0,r8,r0,lsr#8 + strb r2,[r14,#-8] + eor r1,r9,r1,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r8,[r12,#-14] @ load more input + ldrbhs r9,[r12,#-10] + strb r3,[r14,#-4] + eor r2,r10,r2,lsr#8 + strb r0,[r14,#-15] + eor r3,r11,r3,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r10,[r12,#-6] + ldrbhs r11,[r12,#-2] + strb r1,[r14,#-11] + eor r0,r8,r0,lsr#8 + strb r2,[r14,#-7] + eor r1,r9,r1,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r8,[r12,#-13] @ load more input + ldrbhs r9,[r12,#-9] + strb r3,[r14,#-3] + eor r2,r10,r2,lsr#8 + strb r0,[r14,#-14] + eor r3,r11,r3,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r10,[r12,#-5] + ldrbhs r11,[r12,#-1] + strb r1,[r14,#-10] + strb r2,[r14,#-6] + eor r0,r8,r0,lsr#8 + strb r3,[r14,#-2] + eor r1,r9,r1,lsr#8 + strb r0,[r14,#-13] + eor r2,r10,r2,lsr#8 + strb r1,[r14,#-9] + eor r3,r11,r3,lsr#8 + strb r2,[r14,#-5] + strb r3,[r14,#-1] + add r8,sp,#4*(4+0) + ldmia r8,{r8-r11} @ load key material + add r0,sp,#4*(16+8) + add r4,r4,r8 @ accumulate key material + add r5,r5,r9 + add r6,r6,r10 +#ifdef __thumb2__ + itete lo +#endif + eorlo r8,r8,r8 @ zero or ... + ldrbhs r8,[r12],#16 @ ... load input + eorlo r9,r9,r9 + ldrbhs r9,[r12,#-12] + + add r7,r7,r11 +#ifdef __thumb2__ + itete lo +#endif + eorlo r10,r10,r10 + ldrbhs r10,[r12,#-8] + eorlo r11,r11,r11 + ldrbhs r11,[r12,#-4] + + eor r4,r8,r4 @ xor with input (or zero) + eor r5,r9,r5 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r8,[r12,#-15] @ load more input + ldrbhs r9,[r12,#-11] + eor r6,r10,r6 + strb r4,[r14],#16 @ store output + eor r7,r11,r7 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r10,[r12,#-7] + ldrbhs r11,[r12,#-3] + strb r5,[r14,#-12] + eor r4,r8,r4,lsr#8 + strb r6,[r14,#-8] + eor r5,r9,r5,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r8,[r12,#-14] @ load more input + ldrbhs r9,[r12,#-10] + strb r7,[r14,#-4] + eor r6,r10,r6,lsr#8 + strb r4,[r14,#-15] + eor r7,r11,r7,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r10,[r12,#-6] + ldrbhs r11,[r12,#-2] + strb r5,[r14,#-11] + eor r4,r8,r4,lsr#8 + strb r6,[r14,#-7] + eor r5,r9,r5,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r8,[r12,#-13] @ load more input + ldrbhs r9,[r12,#-9] + strb r7,[r14,#-3] + eor r6,r10,r6,lsr#8 + strb r4,[r14,#-14] + eor r7,r11,r7,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r10,[r12,#-5] + ldrbhs r11,[r12,#-1] + strb r5,[r14,#-10] + strb r6,[r14,#-6] + eor r4,r8,r4,lsr#8 + strb r7,[r14,#-2] + eor r5,r9,r5,lsr#8 + strb r4,[r14,#-13] + eor r6,r10,r6,lsr#8 + strb r5,[r14,#-9] + eor r7,r11,r7,lsr#8 + strb r6,[r14,#-5] + strb r7,[r14,#-1] + add r8,sp,#4*(4+4) + ldmia r8,{r8-r11} @ load key material + ldmia r0,{r0-r7} @ load second half +#ifdef __thumb2__ + itt hi +#endif + strhi r10,[sp,#4*(16+10)] @ copy "rx" + strhi r11,[sp,#4*(16+11)] @ copy "rx" + add r0,r0,r8 @ accumulate key material + add r1,r1,r9 + add r2,r2,r10 +#ifdef __thumb2__ + itete lo +#endif + eorlo r8,r8,r8 @ zero or ... + ldrbhs r8,[r12],#16 @ ... load input + eorlo r9,r9,r9 + ldrbhs r9,[r12,#-12] + + add r3,r3,r11 +#ifdef __thumb2__ + itete lo +#endif + eorlo r10,r10,r10 + ldrbhs r10,[r12,#-8] + eorlo r11,r11,r11 + ldrbhs r11,[r12,#-4] + + eor r0,r8,r0 @ xor with input (or zero) + eor r1,r9,r1 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r8,[r12,#-15] @ load more input + ldrbhs r9,[r12,#-11] + eor r2,r10,r2 + strb r0,[r14],#16 @ store output + eor r3,r11,r3 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r10,[r12,#-7] + ldrbhs r11,[r12,#-3] + strb r1,[r14,#-12] + eor r0,r8,r0,lsr#8 + strb r2,[r14,#-8] + eor r1,r9,r1,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r8,[r12,#-14] @ load more input + ldrbhs r9,[r12,#-10] + strb r3,[r14,#-4] + eor r2,r10,r2,lsr#8 + strb r0,[r14,#-15] + eor r3,r11,r3,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r10,[r12,#-6] + ldrbhs r11,[r12,#-2] + strb r1,[r14,#-11] + eor r0,r8,r0,lsr#8 + strb r2,[r14,#-7] + eor r1,r9,r1,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r8,[r12,#-13] @ load more input + ldrbhs r9,[r12,#-9] + strb r3,[r14,#-3] + eor r2,r10,r2,lsr#8 + strb r0,[r14,#-14] + eor r3,r11,r3,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r10,[r12,#-5] + ldrbhs r11,[r12,#-1] + strb r1,[r14,#-10] + strb r2,[r14,#-6] + eor r0,r8,r0,lsr#8 + strb r3,[r14,#-2] + eor r1,r9,r1,lsr#8 + strb r0,[r14,#-13] + eor r2,r10,r2,lsr#8 + strb r1,[r14,#-9] + eor r3,r11,r3,lsr#8 + strb r2,[r14,#-5] + strb r3,[r14,#-1] + add r8,sp,#4*(4+8) + ldmia r8,{r8-r11} @ load key material + add r4,r4,r8 @ accumulate key material +#ifdef __thumb2__ + itt hi +#endif + addhi r8,r8,#1 @ next counter value + strhi r8,[sp,#4*(12)] @ save next counter value + add r5,r5,r9 + add r6,r6,r10 +#ifdef __thumb2__ + itete lo +#endif + eorlo r8,r8,r8 @ zero or ... + ldrbhs r8,[r12],#16 @ ... load input + eorlo r9,r9,r9 + ldrbhs r9,[r12,#-12] + + add r7,r7,r11 +#ifdef __thumb2__ + itete lo +#endif + eorlo r10,r10,r10 + ldrbhs r10,[r12,#-8] + eorlo r11,r11,r11 + ldrbhs r11,[r12,#-4] + + eor r4,r8,r4 @ xor with input (or zero) + eor r5,r9,r5 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r8,[r12,#-15] @ load more input + ldrbhs r9,[r12,#-11] + eor r6,r10,r6 + strb r4,[r14],#16 @ store output + eor r7,r11,r7 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r10,[r12,#-7] + ldrbhs r11,[r12,#-3] + strb r5,[r14,#-12] + eor r4,r8,r4,lsr#8 + strb r6,[r14,#-8] + eor r5,r9,r5,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r8,[r12,#-14] @ load more input + ldrbhs r9,[r12,#-10] + strb r7,[r14,#-4] + eor r6,r10,r6,lsr#8 + strb r4,[r14,#-15] + eor r7,r11,r7,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r10,[r12,#-6] + ldrbhs r11,[r12,#-2] + strb r5,[r14,#-11] + eor r4,r8,r4,lsr#8 + strb r6,[r14,#-7] + eor r5,r9,r5,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r8,[r12,#-13] @ load more input + ldrbhs r9,[r12,#-9] + strb r7,[r14,#-3] + eor r6,r10,r6,lsr#8 + strb r4,[r14,#-14] + eor r7,r11,r7,lsr#8 +#ifdef __thumb2__ + itt hs +#endif + ldrbhs r10,[r12,#-5] + ldrbhs r11,[r12,#-1] + strb r5,[r14,#-10] + strb r6,[r14,#-6] + eor r4,r8,r4,lsr#8 + strb r7,[r14,#-2] + eor r5,r9,r5,lsr#8 + strb r4,[r14,#-13] + eor r6,r10,r6,lsr#8 + strb r5,[r14,#-9] + eor r7,r11,r7,lsr#8 + strb r6,[r14,#-5] + strb r7,[r14,#-1] +#ifdef __thumb2__ + it ne +#endif + ldrne r8,[sp,#4*(32+2)] @ re-load len +#ifdef __thumb2__ + it hs +#endif + subhs r11,r8,#64 @ len-=64 + bhi .Loop_outer + + beq .Ldone +#endif + +.Ltail: + ldr r12,[sp,#4*(32+1)] @ load inp + add r9,sp,#4*(0) + ldr r14,[sp,#4*(32+0)] @ load out + +.Loop_tail: + ldrb r10,[r9],#1 @ read buffer on stack + ldrb r11,[r12],#1 @ read input + subs r8,r8,#1 + eor r11,r11,r10 + strb r11,[r14],#1 @ store output + bne .Loop_tail + +.Ldone: + add sp,sp,#4*(32+3) +.Lno_data_arm: + ldmia sp!,{r4-r11,pc} +.size chacha20_arm,.-chacha20_arm diff --git a/crypto/chacha20/chacha20-arm64.pl b/crypto/chacha20/chacha20-arm64.pl new file mode 100644 index 0000000..4a838bc --- /dev/null +++ b/crypto/chacha20/chacha20-arm64.pl @@ -0,0 +1,1136 @@ +#! /usr/bin/env perl +# Copyright 2016 The OpenSSL Project Authors. All Rights Reserved. +# +# Licensed under the OpenSSL license (the "License"). You may not use +# this file except in compliance with the License. You can obtain a copy +# in the file LICENSE in the source distribution or at +# https://www.openssl.org/source/license.html + +# +# ==================================================================== +# Written by Andy Polyakov for the OpenSSL +# project. The module is, however, dual licensed under OpenSSL and +# CRYPTOGAMS licenses depending on where you obtain it. For further +# details see http://www.openssl.org/~appro/cryptogams/. +# ==================================================================== +# +# June 2015 +# +# ChaCha20 for ARMv8. +# +# Performance in cycles per byte out of large buffer. +# +# IALU/gcc-4.9 3xNEON+1xIALU 6xNEON+2xIALU +# +# Apple A7 5.50/+49% 3.33 1.70 +# Cortex-A53 8.40/+80% 4.72 4.72(*) +# Cortex-A57 8.06/+43% 4.90 4.43(**) +# Denver 4.50/+82% 2.63 2.67(*) +# X-Gene 9.50/+46% 8.82 8.89(*) +# Mongoose 8.00/+44% 3.64 3.25 +# Kryo 8.17/+50% 4.83 4.65 +# +# (*) it's expected that doubling interleave factor doesn't help +# all processors, only those with higher NEON latency and +# higher instruction issue rate; +# (**) expected improvement was actually higher; + +$flavour=shift; +$output=shift; + +$0 =~ m/(.*[\/\\])[^\/\\]+$/; $dir=$1; +( $xlate="${dir}arm-xlate.pl" and -f $xlate ) or +( $xlate="${dir}../../perlasm/arm-xlate.pl" and -f $xlate) or +die "can't locate arm-xlate.pl"; + +open OUT,"| \"$^X\" $xlate $flavour $output"; +*STDOUT=*OUT; + +sub AUTOLOAD() # thunk [simplified] x86-style perlasm +{ my $opcode = $AUTOLOAD; $opcode =~ s/.*:://; $opcode =~ s/_/\./; + my $arg = pop; + $arg = "#$arg" if ($arg*1 eq $arg); + $code .= "\t$opcode\t".join(',',@_,$arg)."\n"; +} + +my ($out,$inp,$len,$key,$ctr) = map("x$_",(0..4)); + +my @x=map("x$_",(5..17,19..21)); +my @d=map("x$_",(22..28,30)); + +sub ROUND { +my ($a0,$b0,$c0,$d0)=@_; +my ($a1,$b1,$c1,$d1)=map(($_&~3)+(($_+1)&3),($a0,$b0,$c0,$d0)); +my ($a2,$b2,$c2,$d2)=map(($_&~3)+(($_+1)&3),($a1,$b1,$c1,$d1)); +my ($a3,$b3,$c3,$d3)=map(($_&~3)+(($_+1)&3),($a2,$b2,$c2,$d2)); + + ( + "&add_32 (@x[$a0],@x[$a0],@x[$b0])", + "&add_32 (@x[$a1],@x[$a1],@x[$b1])", + "&add_32 (@x[$a2],@x[$a2],@x[$b2])", + "&add_32 (@x[$a3],@x[$a3],@x[$b3])", + "&eor_32 (@x[$d0],@x[$d0],@x[$a0])", + "&eor_32 (@x[$d1],@x[$d1],@x[$a1])", + "&eor_32 (@x[$d2],@x[$d2],@x[$a2])", + "&eor_32 (@x[$d3],@x[$d3],@x[$a3])", + "&ror_32 (@x[$d0],@x[$d0],16)", + "&ror_32 (@x[$d1],@x[$d1],16)", + "&ror_32 (@x[$d2],@x[$d2],16)", + "&ror_32 (@x[$d3],@x[$d3],16)", + + "&add_32 (@x[$c0],@x[$c0],@x[$d0])", + "&add_32 (@x[$c1],@x[$c1],@x[$d1])", + "&add_32 (@x[$c2],@x[$c2],@x[$d2])", + "&add_32 (@x[$c3],@x[$c3],@x[$d3])", + "&eor_32 (@x[$b0],@x[$b0],@x[$c0])", + "&eor_32 (@x[$b1],@x[$b1],@x[$c1])", + "&eor_32 (@x[$b2],@x[$b2],@x[$c2])", + "&eor_32 (@x[$b3],@x[$b3],@x[$c3])", + "&ror_32 (@x[$b0],@x[$b0],20)", + "&ror_32 (@x[$b1],@x[$b1],20)", + "&ror_32 (@x[$b2],@x[$b2],20)", + "&ror_32 (@x[$b3],@x[$b3],20)", + + "&add_32 (@x[$a0],@x[$a0],@x[$b0])", + "&add_32 (@x[$a1],@x[$a1],@x[$b1])", + "&add_32 (@x[$a2],@x[$a2],@x[$b2])", + "&add_32 (@x[$a3],@x[$a3],@x[$b3])", + "&eor_32 (@x[$d0],@x[$d0],@x[$a0])", + "&eor_32 (@x[$d1],@x[$d1],@x[$a1])", + "&eor_32 (@x[$d2],@x[$d2],@x[$a2])", + "&eor_32 (@x[$d3],@x[$d3],@x[$a3])", + "&ror_32 (@x[$d0],@x[$d0],24)", + "&ror_32 (@x[$d1],@x[$d1],24)", + "&ror_32 (@x[$d2],@x[$d2],24)", + "&ror_32 (@x[$d3],@x[$d3],24)", + + "&add_32 (@x[$c0],@x[$c0],@x[$d0])", + "&add_32 (@x[$c1],@x[$c1],@x[$d1])", + "&add_32 (@x[$c2],@x[$c2],@x[$d2])", + "&add_32 (@x[$c3],@x[$c3],@x[$d3])", + "&eor_32 (@x[$b0],@x[$b0],@x[$c0])", + "&eor_32 (@x[$b1],@x[$b1],@x[$c1])", + "&eor_32 (@x[$b2],@x[$b2],@x[$c2])", + "&eor_32 (@x[$b3],@x[$b3],@x[$c3])", + "&ror_32 (@x[$b0],@x[$b0],25)", + "&ror_32 (@x[$b1],@x[$b1],25)", + "&ror_32 (@x[$b2],@x[$b2],25)", + "&ror_32 (@x[$b3],@x[$b3],25)" + ); +} + +$code.=<<___; +#include "arm_arch.h" + +.text + +.extern OPENSSL_armcap_P + +.align 5 +.Lsigma: +.quad 0x3320646e61707865,0x6b20657479622d32 // endian-neutral +.Lone: +.long 1,0,0,0 +.LOPENSSL_armcap_P: +#ifdef __ILP32__ +.long OPENSSL_armcap_P-. +#else +.quad OPENSSL_armcap_P-. +#endif +.asciz "ChaCha20 for ARMv8, CRYPTOGAMS by " + +.globl ChaCha20_ctr32 +.type ChaCha20_ctr32,%function +.align 5 +ChaCha20_ctr32: + cbz $len,.Labort + adr @x[0],.LOPENSSL_armcap_P + cmp $len,#192 + b.lo .Lshort +#ifdef __ILP32__ + ldrsw @x[1],[@x[0]] +#else + ldr @x[1],[@x[0]] +#endif + ldr w17,[@x[1],@x[0]] + tst w17,#ARMV7_NEON + b.ne ChaCha20_neon + +.Lshort: + stp x29,x30,[sp,#-96]! + add x29,sp,#0 + + adr @x[0],.Lsigma + stp x19,x20,[sp,#16] + stp x21,x22,[sp,#32] + stp x23,x24,[sp,#48] + stp x25,x26,[sp,#64] + stp x27,x28,[sp,#80] + sub sp,sp,#64 + + ldp @d[0],@d[1],[@x[0]] // load sigma + ldp @d[2],@d[3],[$key] // load key + ldp @d[4],@d[5],[$key,#16] + ldp @d[6],@d[7],[$ctr] // load counter +#ifdef __ARMEB__ + ror @d[2],@d[2],#32 + ror @d[3],@d[3],#32 + ror @d[4],@d[4],#32 + ror @d[5],@d[5],#32 + ror @d[6],@d[6],#32 + ror @d[7],@d[7],#32 +#endif + +.Loop_outer: + mov.32 @x[0],@d[0] // unpack key block + lsr @x[1],@d[0],#32 + mov.32 @x[2],@d[1] + lsr @x[3],@d[1],#32 + mov.32 @x[4],@d[2] + lsr @x[5],@d[2],#32 + mov.32 @x[6],@d[3] + lsr @x[7],@d[3],#32 + mov.32 @x[8],@d[4] + lsr @x[9],@d[4],#32 + mov.32 @x[10],@d[5] + lsr @x[11],@d[5],#32 + mov.32 @x[12],@d[6] + lsr @x[13],@d[6],#32 + mov.32 @x[14],@d[7] + lsr @x[15],@d[7],#32 + + mov $ctr,#10 + subs $len,$len,#64 +.Loop: + sub $ctr,$ctr,#1 +___ + foreach (&ROUND(0, 4, 8,12)) { eval; } + foreach (&ROUND(0, 5,10,15)) { eval; } +$code.=<<___; + cbnz $ctr,.Loop + + add.32 @x[0],@x[0],@d[0] // accumulate key block + add @x[1],@x[1],@d[0],lsr#32 + add.32 @x[2],@x[2],@d[1] + add @x[3],@x[3],@d[1],lsr#32 + add.32 @x[4],@x[4],@d[2] + add @x[5],@x[5],@d[2],lsr#32 + add.32 @x[6],@x[6],@d[3] + add @x[7],@x[7],@d[3],lsr#32 + add.32 @x[8],@x[8],@d[4] + add @x[9],@x[9],@d[4],lsr#32 + add.32 @x[10],@x[10],@d[5] + add @x[11],@x[11],@d[5],lsr#32 + add.32 @x[12],@x[12],@d[6] + add @x[13],@x[13],@d[6],lsr#32 + add.32 @x[14],@x[14],@d[7] + add @x[15],@x[15],@d[7],lsr#32 + + b.lo .Ltail + + add @x[0],@x[0],@x[1],lsl#32 // pack + add @x[2],@x[2],@x[3],lsl#32 + ldp @x[1],@x[3],[$inp,#0] // load input + add @x[4],@x[4],@x[5],lsl#32 + add @x[6],@x[6],@x[7],lsl#32 + ldp @x[5],@x[7],[$inp,#16] + add @x[8],@x[8],@x[9],lsl#32 + add @x[10],@x[10],@x[11],lsl#32 + ldp @x[9],@x[11],[$inp,#32] + add @x[12],@x[12],@x[13],lsl#32 + add @x[14],@x[14],@x[15],lsl#32 + ldp @x[13],@x[15],[$inp,#48] + add $inp,$inp,#64 +#ifdef __ARMEB__ + rev @x[0],@x[0] + rev @x[2],@x[2] + rev @x[4],@x[4] + rev @x[6],@x[6] + rev @x[8],@x[8] + rev @x[10],@x[10] + rev @x[12],@x[12] + rev @x[14],@x[14] +#endif + eor @x[0],@x[0],@x[1] + eor @x[2],@x[2],@x[3] + eor @x[4],@x[4],@x[5] + eor @x[6],@x[6],@x[7] + eor @x[8],@x[8],@x[9] + eor @x[10],@x[10],@x[11] + eor @x[12],@x[12],@x[13] + eor @x[14],@x[14],@x[15] + + stp @x[0],@x[2],[$out,#0] // store output + add @d[6],@d[6],#1 // increment counter + stp @x[4],@x[6],[$out,#16] + stp @x[8],@x[10],[$out,#32] + stp @x[12],@x[14],[$out,#48] + add $out,$out,#64 + + b.hi .Loop_outer + + ldp x19,x20,[x29,#16] + add sp,sp,#64 + ldp x21,x22,[x29,#32] + ldp x23,x24,[x29,#48] + ldp x25,x26,[x29,#64] + ldp x27,x28,[x29,#80] + ldp x29,x30,[sp],#96 +.Labort: + ret + +.align 4 +.Ltail: + add $len,$len,#64 +.Less_than_64: + sub $out,$out,#1 + add $inp,$inp,$len + add $out,$out,$len + add $ctr,sp,$len + neg $len,$len + + add @x[0],@x[0],@x[1],lsl#32 // pack + add @x[2],@x[2],@x[3],lsl#32 + add @x[4],@x[4],@x[5],lsl#32 + add @x[6],@x[6],@x[7],lsl#32 + add @x[8],@x[8],@x[9],lsl#32 + add @x[10],@x[10],@x[11],lsl#32 + add @x[12],@x[12],@x[13],lsl#32 + add @x[14],@x[14],@x[15],lsl#32 +#ifdef __ARMEB__ + rev @x[0],@x[0] + rev @x[2],@x[2] + rev @x[4],@x[4] + rev @x[6],@x[6] + rev @x[8],@x[8] + rev @x[10],@x[10] + rev @x[12],@x[12] + rev @x[14],@x[14] +#endif + stp @x[0],@x[2],[sp,#0] + stp @x[4],@x[6],[sp,#16] + stp @x[8],@x[10],[sp,#32] + stp @x[12],@x[14],[sp,#48] + +.Loop_tail: + ldrb w10,[$inp,$len] + ldrb w11,[$ctr,$len] + add $len,$len,#1 + eor w10,w10,w11 + strb w10,[$out,$len] + cbnz $len,.Loop_tail + + stp xzr,xzr,[sp,#0] + stp xzr,xzr,[sp,#16] + stp xzr,xzr,[sp,#32] + stp xzr,xzr,[sp,#48] + + ldp x19,x20,[x29,#16] + add sp,sp,#64 + ldp x21,x22,[x29,#32] + ldp x23,x24,[x29,#48] + ldp x25,x26,[x29,#64] + ldp x27,x28,[x29,#80] + ldp x29,x30,[sp],#96 + ret +.size ChaCha20_ctr32,.-ChaCha20_ctr32 +___ + +{{{ +my ($A0,$B0,$C0,$D0,$A1,$B1,$C1,$D1,$A2,$B2,$C2,$D2,$T0,$T1,$T2,$T3) = + map("v$_.4s",(0..7,16..23)); +my (@K)=map("v$_.4s",(24..30)); +my $ONE="v31.4s"; + +sub NEONROUND { +my $odd = pop; +my ($a,$b,$c,$d,$t)=@_; + + ( + "&add ('$a','$a','$b')", + "&eor ('$d','$d','$a')", + "&rev32_16 ('$d','$d')", # vrot ($d,16) + + "&add ('$c','$c','$d')", + "&eor ('$t','$b','$c')", + "&ushr ('$b','$t',20)", + "&sli ('$b','$t',12)", + + "&add ('$a','$a','$b')", + "&eor ('$t','$d','$a')", + "&ushr ('$d','$t',24)", + "&sli ('$d','$t',8)", + + "&add ('$c','$c','$d')", + "&eor ('$t','$b','$c')", + "&ushr ('$b','$t',25)", + "&sli ('$b','$t',7)", + + "&ext ('$c','$c','$c',8)", + "&ext ('$d','$d','$d',$odd?4:12)", + "&ext ('$b','$b','$b',$odd?12:4)" + ); +} + +$code.=<<___; + +.type ChaCha20_neon,%function +.align 5 +ChaCha20_neon: + stp x29,x30,[sp,#-96]! + add x29,sp,#0 + + adr @x[0],.Lsigma + stp x19,x20,[sp,#16] + stp x21,x22,[sp,#32] + stp x23,x24,[sp,#48] + stp x25,x26,[sp,#64] + stp x27,x28,[sp,#80] + cmp $len,#512 + b.hs .L512_or_more_neon + + sub sp,sp,#64 + + ldp @d[0],@d[1],[@x[0]] // load sigma + ld1 {@K[0]},[@x[0]],#16 + ldp @d[2],@d[3],[$key] // load key + ldp @d[4],@d[5],[$key,#16] + ld1 {@K[1],@K[2]},[$key] + ldp @d[6],@d[7],[$ctr] // load counter + ld1 {@K[3]},[$ctr] + ld1 {$ONE},[@x[0]] +#ifdef __ARMEB__ + rev64 @K[0],@K[0] + ror @d[2],@d[2],#32 + ror @d[3],@d[3],#32 + ror @d[4],@d[4],#32 + ror @d[5],@d[5],#32 + ror @d[6],@d[6],#32 + ror @d[7],@d[7],#32 +#endif + add @K[3],@K[3],$ONE // += 1 + add @K[4],@K[3],$ONE + add @K[5],@K[4],$ONE + shl $ONE,$ONE,#2 // 1 -> 4 + +.Loop_outer_neon: + mov.32 @x[0],@d[0] // unpack key block + lsr @x[1],@d[0],#32 + mov $A0,@K[0] + mov.32 @x[2],@d[1] + lsr @x[3],@d[1],#32 + mov $A1,@K[0] + mov.32 @x[4],@d[2] + lsr @x[5],@d[2],#32 + mov $A2,@K[0] + mov.32 @x[6],@d[3] + mov $B0,@K[1] + lsr @x[7],@d[3],#32 + mov $B1,@K[1] + mov.32 @x[8],@d[4] + mov $B2,@K[1] + lsr @x[9],@d[4],#32 + mov $D0,@K[3] + mov.32 @x[10],@d[5] + mov $D1,@K[4] + lsr @x[11],@d[5],#32 + mov $D2,@K[5] + mov.32 @x[12],@d[6] + mov $C0,@K[2] + lsr @x[13],@d[6],#32 + mov $C1,@K[2] + mov.32 @x[14],@d[7] + mov $C2,@K[2] + lsr @x[15],@d[7],#32 + + mov $ctr,#10 + subs $len,$len,#256 +.Loop_neon: + sub $ctr,$ctr,#1 +___ + my @thread0=&NEONROUND($A0,$B0,$C0,$D0,$T0,0); + my @thread1=&NEONROUND($A1,$B1,$C1,$D1,$T1,0); + my @thread2=&NEONROUND($A2,$B2,$C2,$D2,$T2,0); + my @thread3=&ROUND(0,4,8,12); + + foreach (@thread0) { + eval; eval(shift(@thread3)); + eval(shift(@thread1)); eval(shift(@thread3)); + eval(shift(@thread2)); eval(shift(@thread3)); + } + + @thread0=&NEONROUND($A0,$B0,$C0,$D0,$T0,1); + @thread1=&NEONROUND($A1,$B1,$C1,$D1,$T1,1); + @thread2=&NEONROUND($A2,$B2,$C2,$D2,$T2,1); + @thread3=&ROUND(0,5,10,15); + + foreach (@thread0) { + eval; eval(shift(@thread3)); + eval(shift(@thread1)); eval(shift(@thread3)); + eval(shift(@thread2)); eval(shift(@thread3)); + } +$code.=<<___; + cbnz $ctr,.Loop_neon + + add.32 @x[0],@x[0],@d[0] // accumulate key block + add $A0,$A0,@K[0] + add @x[1],@x[1],@d[0],lsr#32 + add $A1,$A1,@K[0] + add.32 @x[2],@x[2],@d[1] + add $A2,$A2,@K[0] + add @x[3],@x[3],@d[1],lsr#32 + add $C0,$C0,@K[2] + add.32 @x[4],@x[4],@d[2] + add $C1,$C1,@K[2] + add @x[5],@x[5],@d[2],lsr#32 + add $C2,$C2,@K[2] + add.32 @x[6],@x[6],@d[3] + add $D0,$D0,@K[3] + add @x[7],@x[7],@d[3],lsr#32 + add.32 @x[8],@x[8],@d[4] + add $D1,$D1,@K[4] + add @x[9],@x[9],@d[4],lsr#32 + add.32 @x[10],@x[10],@d[5] + add $D2,$D2,@K[5] + add @x[11],@x[11],@d[5],lsr#32 + add.32 @x[12],@x[12],@d[6] + add $B0,$B0,@K[1] + add @x[13],@x[13],@d[6],lsr#32 + add.32 @x[14],@x[14],@d[7] + add $B1,$B1,@K[1] + add @x[15],@x[15],@d[7],lsr#32 + add $B2,$B2,@K[1] + + b.lo .Ltail_neon + + add @x[0],@x[0],@x[1],lsl#32 // pack + add @x[2],@x[2],@x[3],lsl#32 + ldp @x[1],@x[3],[$inp,#0] // load input + add @x[4],@x[4],@x[5],lsl#32 + add @x[6],@x[6],@x[7],lsl#32 + ldp @x[5],@x[7],[$inp,#16] + add @x[8],@x[8],@x[9],lsl#32 + add @x[10],@x[10],@x[11],lsl#32 + ldp @x[9],@x[11],[$inp,#32] + add @x[12],@x[12],@x[13],lsl#32 + add @x[14],@x[14],@x[15],lsl#32 + ldp @x[13],@x[15],[$inp,#48] + add $inp,$inp,#64 +#ifdef __ARMEB__ + rev @x[0],@x[0] + rev @x[2],@x[2] + rev @x[4],@x[4] + rev @x[6],@x[6] + rev @x[8],@x[8] + rev @x[10],@x[10] + rev @x[12],@x[12] + rev @x[14],@x[14] +#endif + ld1.8 {$T0-$T3},[$inp],#64 + eor @x[0],@x[0],@x[1] + eor @x[2],@x[2],@x[3] + eor @x[4],@x[4],@x[5] + eor @x[6],@x[6],@x[7] + eor @x[8],@x[8],@x[9] + eor $A0,$A0,$T0 + eor @x[10],@x[10],@x[11] + eor $B0,$B0,$T1 + eor @x[12],@x[12],@x[13] + eor $C0,$C0,$T2 + eor @x[14],@x[14],@x[15] + eor $D0,$D0,$T3 + ld1.8 {$T0-$T3},[$inp],#64 + + stp @x[0],@x[2],[$out,#0] // store output + add @d[6],@d[6],#4 // increment counter + stp @x[4],@x[6],[$out,#16] + add @K[3],@K[3],$ONE // += 4 + stp @x[8],@x[10],[$out,#32] + add @K[4],@K[4],$ONE + stp @x[12],@x[14],[$out,#48] + add @K[5],@K[5],$ONE + add $out,$out,#64 + + st1.8 {$A0-$D0},[$out],#64 + ld1.8 {$A0-$D0},[$inp],#64 + + eor $A1,$A1,$T0 + eor $B1,$B1,$T1 + eor $C1,$C1,$T2 + eor $D1,$D1,$T3 + st1.8 {$A1-$D1},[$out],#64 + + eor $A2,$A2,$A0 + eor $B2,$B2,$B0 + eor $C2,$C2,$C0 + eor $D2,$D2,$D0 + st1.8 {$A2-$D2},[$out],#64 + + b.hi .Loop_outer_neon + + ldp x19,x20,[x29,#16] + add sp,sp,#64 + ldp x21,x22,[x29,#32] + ldp x23,x24,[x29,#48] + ldp x25,x26,[x29,#64] + ldp x27,x28,[x29,#80] + ldp x29,x30,[sp],#96 + ret + +.Ltail_neon: + add $len,$len,#256 + cmp $len,#64 + b.lo .Less_than_64 + + add @x[0],@x[0],@x[1],lsl#32 // pack + add @x[2],@x[2],@x[3],lsl#32 + ldp @x[1],@x[3],[$inp,#0] // load input + add @x[4],@x[4],@x[5],lsl#32 + add @x[6],@x[6],@x[7],lsl#32 + ldp @x[5],@x[7],[$inp,#16] + add @x[8],@x[8],@x[9],lsl#32 + add @x[10],@x[10],@x[11],lsl#32 + ldp @x[9],@x[11],[$inp,#32] + add @x[12],@x[12],@x[13],lsl#32 + add @x[14],@x[14],@x[15],lsl#32 + ldp @x[13],@x[15],[$inp,#48] + add $inp,$inp,#64 +#ifdef __ARMEB__ + rev @x[0],@x[0] + rev @x[2],@x[2] + rev @x[4],@x[4] + rev @x[6],@x[6] + rev @x[8],@x[8] + rev @x[10],@x[10] + rev @x[12],@x[12] + rev @x[14],@x[14] +#endif + eor @x[0],@x[0],@x[1] + eor @x[2],@x[2],@x[3] + eor @x[4],@x[4],@x[5] + eor @x[6],@x[6],@x[7] + eor @x[8],@x[8],@x[9] + eor @x[10],@x[10],@x[11] + eor @x[12],@x[12],@x[13] + eor @x[14],@x[14],@x[15] + + stp @x[0],@x[2],[$out,#0] // store output + add @d[6],@d[6],#4 // increment counter + stp @x[4],@x[6],[$out,#16] + stp @x[8],@x[10],[$out,#32] + stp @x[12],@x[14],[$out,#48] + add $out,$out,#64 + b.eq .Ldone_neon + sub $len,$len,#64 + cmp $len,#64 + b.lo .Less_than_128 + + ld1.8 {$T0-$T3},[$inp],#64 + eor $A0,$A0,$T0 + eor $B0,$B0,$T1 + eor $C0,$C0,$T2 + eor $D0,$D0,$T3 + st1.8 {$A0-$D0},[$out],#64 + b.eq .Ldone_neon + sub $len,$len,#64 + cmp $len,#64 + b.lo .Less_than_192 + + ld1.8 {$T0-$T3},[$inp],#64 + eor $A1,$A1,$T0 + eor $B1,$B1,$T1 + eor $C1,$C1,$T2 + eor $D1,$D1,$T3 + st1.8 {$A1-$D1},[$out],#64 + b.eq .Ldone_neon + sub $len,$len,#64 + + st1.8 {$A2-$D2},[sp] + b .Last_neon + +.Less_than_128: + st1.8 {$A0-$D0},[sp] + b .Last_neon +.Less_than_192: + st1.8 {$A1-$D1},[sp] + b .Last_neon + +.align 4 +.Last_neon: + sub $out,$out,#1 + add $inp,$inp,$len + add $out,$out,$len + add $ctr,sp,$len + neg $len,$len + +.Loop_tail_neon: + ldrb w10,[$inp,$len] + ldrb w11,[$ctr,$len] + add $len,$len,#1 + eor w10,w10,w11 + strb w10,[$out,$len] + cbnz $len,.Loop_tail_neon + + stp xzr,xzr,[sp,#0] + stp xzr,xzr,[sp,#16] + stp xzr,xzr,[sp,#32] + stp xzr,xzr,[sp,#48] + +.Ldone_neon: + ldp x19,x20,[x29,#16] + add sp,sp,#64 + ldp x21,x22,[x29,#32] + ldp x23,x24,[x29,#48] + ldp x25,x26,[x29,#64] + ldp x27,x28,[x29,#80] + ldp x29,x30,[sp],#96 + ret +.size ChaCha20_neon,.-ChaCha20_neon +___ +{ +my ($T0,$T1,$T2,$T3,$T4,$T5)=@K; +my ($A0,$B0,$C0,$D0,$A1,$B1,$C1,$D1,$A2,$B2,$C2,$D2, + $A3,$B3,$C3,$D3,$A4,$B4,$C4,$D4,$A5,$B5,$C5,$D5) = map("v$_.4s",(0..23)); + +$code.=<<___; +.type ChaCha20_512_neon,%function +.align 5 +ChaCha20_512_neon: + stp x29,x30,[sp,#-96]! + add x29,sp,#0 + + adr @x[0],.Lsigma + stp x19,x20,[sp,#16] + stp x21,x22,[sp,#32] + stp x23,x24,[sp,#48] + stp x25,x26,[sp,#64] + stp x27,x28,[sp,#80] + +.L512_or_more_neon: + sub sp,sp,#128+64 + + ldp @d[0],@d[1],[@x[0]] // load sigma + ld1 {@K[0]},[@x[0]],#16 + ldp @d[2],@d[3],[$key] // load key + ldp @d[4],@d[5],[$key,#16] + ld1 {@K[1],@K[2]},[$key] + ldp @d[6],@d[7],[$ctr] // load counter + ld1 {@K[3]},[$ctr] + ld1 {$ONE},[@x[0]] +#ifdef __ARMEB__ + rev64 @K[0],@K[0] + ror @d[2],@d[2],#32 + ror @d[3],@d[3],#32 + ror @d[4],@d[4],#32 + ror @d[5],@d[5],#32 + ror @d[6],@d[6],#32 + ror @d[7],@d[7],#32 +#endif + add @K[3],@K[3],$ONE // += 1 + stp @K[0],@K[1],[sp,#0] // off-load key block, invariant part + add @K[3],@K[3],$ONE // not typo + str @K[2],[sp,#32] + add @K[4],@K[3],$ONE + add @K[5],@K[4],$ONE + add @K[6],@K[5],$ONE + shl $ONE,$ONE,#2 // 1 -> 4 + + stp d8,d9,[sp,#128+0] // meet ABI requirements + stp d10,d11,[sp,#128+16] + stp d12,d13,[sp,#128+32] + stp d14,d15,[sp,#128+48] + + sub $len,$len,#512 // not typo + +.Loop_outer_512_neon: + mov $A0,@K[0] + mov $A1,@K[0] + mov $A2,@K[0] + mov $A3,@K[0] + mov $A4,@K[0] + mov $A5,@K[0] + mov $B0,@K[1] + mov.32 @x[0],@d[0] // unpack key block + mov $B1,@K[1] + lsr @x[1],@d[0],#32 + mov $B2,@K[1] + mov.32 @x[2],@d[1] + mov $B3,@K[1] + lsr @x[3],@d[1],#32 + mov $B4,@K[1] + mov.32 @x[4],@d[2] + mov $B5,@K[1] + lsr @x[5],@d[2],#32 + mov $D0,@K[3] + mov.32 @x[6],@d[3] + mov $D1,@K[4] + lsr @x[7],@d[3],#32 + mov $D2,@K[5] + mov.32 @x[8],@d[4] + mov $D3,@K[6] + lsr @x[9],@d[4],#32 + mov $C0,@K[2] + mov.32 @x[10],@d[5] + mov $C1,@K[2] + lsr @x[11],@d[5],#32 + add $D4,$D0,$ONE // +4 + mov.32 @x[12],@d[6] + add $D5,$D1,$ONE // +4 + lsr @x[13],@d[6],#32 + mov $C2,@K[2] + mov.32 @x[14],@d[7] + mov $C3,@K[2] + lsr @x[15],@d[7],#32 + mov $C4,@K[2] + stp @K[3],@K[4],[sp,#48] // off-load key block, variable part + mov $C5,@K[2] + str @K[5],[sp,#80] + + mov $ctr,#5 + subs $len,$len,#512 +.Loop_upper_neon: + sub $ctr,$ctr,#1 +___ + my @thread0=&NEONROUND($A0,$B0,$C0,$D0,$T0,0); + my @thread1=&NEONROUND($A1,$B1,$C1,$D1,$T1,0); + my @thread2=&NEONROUND($A2,$B2,$C2,$D2,$T2,0); + my @thread3=&NEONROUND($A3,$B3,$C3,$D3,$T3,0); + my @thread4=&NEONROUND($A4,$B4,$C4,$D4,$T4,0); + my @thread5=&NEONROUND($A5,$B5,$C5,$D5,$T5,0); + my @thread67=(&ROUND(0,4,8,12),&ROUND(0,5,10,15)); + my $diff = ($#thread0+1)*6 - $#thread67 - 1; + my $i = 0; + + foreach (@thread0) { + eval; eval(shift(@thread67)); + eval(shift(@thread1)); eval(shift(@thread67)); + eval(shift(@thread2)); eval(shift(@thread67)); + eval(shift(@thread3)); eval(shift(@thread67)); + eval(shift(@thread4)); eval(shift(@thread67)); + eval(shift(@thread5)); eval(shift(@thread67)); + } + + @thread0=&NEONROUND($A0,$B0,$C0,$D0,$T0,1); + @thread1=&NEONROUND($A1,$B1,$C1,$D1,$T1,1); + @thread2=&NEONROUND($A2,$B2,$C2,$D2,$T2,1); + @thread3=&NEONROUND($A3,$B3,$C3,$D3,$T3,1); + @thread4=&NEONROUND($A4,$B4,$C4,$D4,$T4,1); + @thread5=&NEONROUND($A5,$B5,$C5,$D5,$T5,1); + @thread67=(&ROUND(0,4,8,12),&ROUND(0,5,10,15)); + + foreach (@thread0) { + eval; eval(shift(@thread67)); + eval(shift(@thread1)); eval(shift(@thread67)); + eval(shift(@thread2)); eval(shift(@thread67)); + eval(shift(@thread3)); eval(shift(@thread67)); + eval(shift(@thread4)); eval(shift(@thread67)); + eval(shift(@thread5)); eval(shift(@thread67)); + } +$code.=<<___; + cbnz $ctr,.Loop_upper_neon + + add.32 @x[0],@x[0],@d[0] // accumulate key block + add @x[1],@x[1],@d[0],lsr#32 + add.32 @x[2],@x[2],@d[1] + add @x[3],@x[3],@d[1],lsr#32 + add.32 @x[4],@x[4],@d[2] + add @x[5],@x[5],@d[2],lsr#32 + add.32 @x[6],@x[6],@d[3] + add @x[7],@x[7],@d[3],lsr#32 + add.32 @x[8],@x[8],@d[4] + add @x[9],@x[9],@d[4],lsr#32 + add.32 @x[10],@x[10],@d[5] + add @x[11],@x[11],@d[5],lsr#32 + add.32 @x[12],@x[12],@d[6] + add @x[13],@x[13],@d[6],lsr#32 + add.32 @x[14],@x[14],@d[7] + add @x[15],@x[15],@d[7],lsr#32 + + add @x[0],@x[0],@x[1],lsl#32 // pack + add @x[2],@x[2],@x[3],lsl#32 + ldp @x[1],@x[3],[$inp,#0] // load input + add @x[4],@x[4],@x[5],lsl#32 + add @x[6],@x[6],@x[7],lsl#32 + ldp @x[5],@x[7],[$inp,#16] + add @x[8],@x[8],@x[9],lsl#32 + add @x[10],@x[10],@x[11],lsl#32 + ldp @x[9],@x[11],[$inp,#32] + add @x[12],@x[12],@x[13],lsl#32 + add @x[14],@x[14],@x[15],lsl#32 + ldp @x[13],@x[15],[$inp,#48] + add $inp,$inp,#64 +#ifdef __ARMEB__ + rev @x[0],@x[0] + rev @x[2],@x[2] + rev @x[4],@x[4] + rev @x[6],@x[6] + rev @x[8],@x[8] + rev @x[10],@x[10] + rev @x[12],@x[12] + rev @x[14],@x[14] +#endif + eor @x[0],@x[0],@x[1] + eor @x[2],@x[2],@x[3] + eor @x[4],@x[4],@x[5] + eor @x[6],@x[6],@x[7] + eor @x[8],@x[8],@x[9] + eor @x[10],@x[10],@x[11] + eor @x[12],@x[12],@x[13] + eor @x[14],@x[14],@x[15] + + stp @x[0],@x[2],[$out,#0] // store output + add @d[6],@d[6],#1 // increment counter + mov.32 @x[0],@d[0] // unpack key block + lsr @x[1],@d[0],#32 + stp @x[4],@x[6],[$out,#16] + mov.32 @x[2],@d[1] + lsr @x[3],@d[1],#32 + stp @x[8],@x[10],[$out,#32] + mov.32 @x[4],@d[2] + lsr @x[5],@d[2],#32 + stp @x[12],@x[14],[$out,#48] + add $out,$out,#64 + mov.32 @x[6],@d[3] + lsr @x[7],@d[3],#32 + mov.32 @x[8],@d[4] + lsr @x[9],@d[4],#32 + mov.32 @x[10],@d[5] + lsr @x[11],@d[5],#32 + mov.32 @x[12],@d[6] + lsr @x[13],@d[6],#32 + mov.32 @x[14],@d[7] + lsr @x[15],@d[7],#32 + + mov $ctr,#5 +.Loop_lower_neon: + sub $ctr,$ctr,#1 +___ + @thread0=&NEONROUND($A0,$B0,$C0,$D0,$T0,0); + @thread1=&NEONROUND($A1,$B1,$C1,$D1,$T1,0); + @thread2=&NEONROUND($A2,$B2,$C2,$D2,$T2,0); + @thread3=&NEONROUND($A3,$B3,$C3,$D3,$T3,0); + @thread4=&NEONROUND($A4,$B4,$C4,$D4,$T4,0); + @thread5=&NEONROUND($A5,$B5,$C5,$D5,$T5,0); + @thread67=(&ROUND(0,4,8,12),&ROUND(0,5,10,15)); + + foreach (@thread0) { + eval; eval(shift(@thread67)); + eval(shift(@thread1)); eval(shift(@thread67)); + eval(shift(@thread2)); eval(shift(@thread67)); + eval(shift(@thread3)); eval(shift(@thread67)); + eval(shift(@thread4)); eval(shift(@thread67)); + eval(shift(@thread5)); eval(shift(@thread67)); + } + + @thread0=&NEONROUND($A0,$B0,$C0,$D0,$T0,1); + @thread1=&NEONROUND($A1,$B1,$C1,$D1,$T1,1); + @thread2=&NEONROUND($A2,$B2,$C2,$D2,$T2,1); + @thread3=&NEONROUND($A3,$B3,$C3,$D3,$T3,1); + @thread4=&NEONROUND($A4,$B4,$C4,$D4,$T4,1); + @thread5=&NEONROUND($A5,$B5,$C5,$D5,$T5,1); + @thread67=(&ROUND(0,4,8,12),&ROUND(0,5,10,15)); + + foreach (@thread0) { + eval; eval(shift(@thread67)); + eval(shift(@thread1)); eval(shift(@thread67)); + eval(shift(@thread2)); eval(shift(@thread67)); + eval(shift(@thread3)); eval(shift(@thread67)); + eval(shift(@thread4)); eval(shift(@thread67)); + eval(shift(@thread5)); eval(shift(@thread67)); + } +$code.=<<___; + cbnz $ctr,.Loop_lower_neon + + add.32 @x[0],@x[0],@d[0] // accumulate key block + ldp @K[0],@K[1],[sp,#0] + add @x[1],@x[1],@d[0],lsr#32 + ldp @K[2],@K[3],[sp,#32] + add.32 @x[2],@x[2],@d[1] + ldp @K[4],@K[5],[sp,#64] + add @x[3],@x[3],@d[1],lsr#32 + add $A0,$A0,@K[0] + add.32 @x[4],@x[4],@d[2] + add $A1,$A1,@K[0] + add @x[5],@x[5],@d[2],lsr#32 + add $A2,$A2,@K[0] + add.32 @x[6],@x[6],@d[3] + add $A3,$A3,@K[0] + add @x[7],@x[7],@d[3],lsr#32 + add $A4,$A4,@K[0] + add.32 @x[8],@x[8],@d[4] + add $A5,$A5,@K[0] + add @x[9],@x[9],@d[4],lsr#32 + add $C0,$C0,@K[2] + add.32 @x[10],@x[10],@d[5] + add $C1,$C1,@K[2] + add @x[11],@x[11],@d[5],lsr#32 + add $C2,$C2,@K[2] + add.32 @x[12],@x[12],@d[6] + add $C3,$C3,@K[2] + add @x[13],@x[13],@d[6],lsr#32 + add $C4,$C4,@K[2] + add.32 @x[14],@x[14],@d[7] + add $C5,$C5,@K[2] + add @x[15],@x[15],@d[7],lsr#32 + add $D4,$D4,$ONE // +4 + add @x[0],@x[0],@x[1],lsl#32 // pack + add $D5,$D5,$ONE // +4 + add @x[2],@x[2],@x[3],lsl#32 + add $D0,$D0,@K[3] + ldp @x[1],@x[3],[$inp,#0] // load input + add $D1,$D1,@K[4] + add @x[4],@x[4],@x[5],lsl#32 + add $D2,$D2,@K[5] + add @x[6],@x[6],@x[7],lsl#32 + add $D3,$D3,@K[6] + ldp @x[5],@x[7],[$inp,#16] + add $D4,$D4,@K[3] + add @x[8],@x[8],@x[9],lsl#32 + add $D5,$D5,@K[4] + add @x[10],@x[10],@x[11],lsl#32 + add $B0,$B0,@K[1] + ldp @x[9],@x[11],[$inp,#32] + add $B1,$B1,@K[1] + add @x[12],@x[12],@x[13],lsl#32 + add $B2,$B2,@K[1] + add @x[14],@x[14],@x[15],lsl#32 + add $B3,$B3,@K[1] + ldp @x[13],@x[15],[$inp,#48] + add $B4,$B4,@K[1] + add $inp,$inp,#64 + add $B5,$B5,@K[1] + +#ifdef __ARMEB__ + rev @x[0],@x[0] + rev @x[2],@x[2] + rev @x[4],@x[4] + rev @x[6],@x[6] + rev @x[8],@x[8] + rev @x[10],@x[10] + rev @x[12],@x[12] + rev @x[14],@x[14] +#endif + ld1.8 {$T0-$T3},[$inp],#64 + eor @x[0],@x[0],@x[1] + eor @x[2],@x[2],@x[3] + eor @x[4],@x[4],@x[5] + eor @x[6],@x[6],@x[7] + eor @x[8],@x[8],@x[9] + eor $A0,$A0,$T0 + eor @x[10],@x[10],@x[11] + eor $B0,$B0,$T1 + eor @x[12],@x[12],@x[13] + eor $C0,$C0,$T2 + eor @x[14],@x[14],@x[15] + eor $D0,$D0,$T3 + ld1.8 {$T0-$T3},[$inp],#64 + + stp @x[0],@x[2],[$out,#0] // store output + add @d[6],@d[6],#7 // increment counter + stp @x[4],@x[6],[$out,#16] + stp @x[8],@x[10],[$out,#32] + stp @x[12],@x[14],[$out,#48] + add $out,$out,#64 + st1.8 {$A0-$D0},[$out],#64 + + ld1.8 {$A0-$D0},[$inp],#64 + eor $A1,$A1,$T0 + eor $B1,$B1,$T1 + eor $C1,$C1,$T2 + eor $D1,$D1,$T3 + st1.8 {$A1-$D1},[$out],#64 + + ld1.8 {$A1-$D1},[$inp],#64 + eor $A2,$A2,$A0 + ldp @K[0],@K[1],[sp,#0] + eor $B2,$B2,$B0 + ldp @K[2],@K[3],[sp,#32] + eor $C2,$C2,$C0 + eor $D2,$D2,$D0 + st1.8 {$A2-$D2},[$out],#64 + + ld1.8 {$A2-$D2},[$inp],#64 + eor $A3,$A3,$A1 + eor $B3,$B3,$B1 + eor $C3,$C3,$C1 + eor $D3,$D3,$D1 + st1.8 {$A3-$D3},[$out],#64 + + ld1.8 {$A3-$D3},[$inp],#64 + eor $A4,$A4,$A2 + eor $B4,$B4,$B2 + eor $C4,$C4,$C2 + eor $D4,$D4,$D2 + st1.8 {$A4-$D4},[$out],#64 + + shl $A0,$ONE,#1 // 4 -> 8 + eor $A5,$A5,$A3 + eor $B5,$B5,$B3 + eor $C5,$C5,$C3 + eor $D5,$D5,$D3 + st1.8 {$A5-$D5},[$out],#64 + + add @K[3],@K[3],$A0 // += 8 + add @K[4],@K[4],$A0 + add @K[5],@K[5],$A0 + add @K[6],@K[6],$A0 + + b.hs .Loop_outer_512_neon + + adds $len,$len,#512 + ushr $A0,$ONE,#2 // 4 -> 1 + + ldp d8,d9,[sp,#128+0] // meet ABI requirements + ldp d10,d11,[sp,#128+16] + ldp d12,d13,[sp,#128+32] + ldp d14,d15,[sp,#128+48] + + stp @K[0],$ONE,[sp,#0] // wipe off-load area + stp @K[0],$ONE,[sp,#32] + stp @K[0],$ONE,[sp,#64] + + b.eq .Ldone_512_neon + + cmp $len,#192 + sub @K[3],@K[3],$A0 // -= 1 + sub @K[4],@K[4],$A0 + sub @K[5],@K[5],$A0 + add sp,sp,#128 + b.hs .Loop_outer_neon + + eor @K[1],@K[1],@K[1] + eor @K[2],@K[2],@K[2] + eor @K[3],@K[3],@K[3] + eor @K[4],@K[4],@K[4] + eor @K[5],@K[5],@K[5] + eor @K[6],@K[6],@K[6] + b .Loop_outer + +.Ldone_512_neon: + ldp x19,x20,[x29,#16] + add sp,sp,#128+64 + ldp x21,x22,[x29,#32] + ldp x23,x24,[x29,#48] + ldp x25,x26,[x29,#64] + ldp x27,x28,[x29,#80] + ldp x29,x30,[sp],#96 + ret +.size ChaCha20_512_neon,.-ChaCha20_512_neon +___ +} +}}} + +foreach (split("\n",$code)) { + s/\`([^\`]*)\`/eval $1/geo; + + (s/\b([a-z]+)\.32\b/$1/ and (s/x([0-9]+)/w$1/g or 1)) or + (m/\b(eor|ext|mov)\b/ and (s/\.4s/\.16b/g or 1)) or + (s/\b((?:ld|st)1)\.8\b/$1/ and (s/\.4s/\.16b/g or 1)) or + (m/\b(ld|st)[rp]\b/ and (s/v([0-9]+)\.4s/q$1/g or 1)) or + (s/\brev32\.16\b/rev32/ and (s/\.4s/\.8h/g or 1)); + + #s/\bq([0-9]+)#(lo|hi)/sprintf "d%d",2*$1+($2 eq "hi")/geo; + + print $_,"\n"; +} +close STDOUT; # flush diff --git a/crypto/chacha20/chacha20-arm64.s b/crypto/chacha20/chacha20-arm64.s new file mode 100644 index 0000000..c3d1243 --- /dev/null +++ b/crypto/chacha20/chacha20-arm64.s @@ -0,0 +1,1940 @@ +/* SPDX-License-Identifier: OpenSSL OR (BSD-3-Clause OR GPL-2.0) + * + * Copyright (C) 2015-2018 Jason A. Donenfeld . All Rights Reserved. + * Copyright 2016 The OpenSSL Project Authors. All Rights Reserved. + */ + +#include + +.text +.align 5 +.Lsigma: +.quad 0x3320646e61707865,0x6b20657479622d32 // endian-neutral +.Lone: +.long 1,0,0,0 + +.align 5 +ENTRY(chacha20_arm) + cbz x2,.Labort +.Lshort: + stp x29,x30,[sp,#-96]! + add x29,sp,#0 + + adr x5,.Lsigma + stp x19,x20,[sp,#16] + stp x21,x22,[sp,#32] + stp x23,x24,[sp,#48] + stp x25,x26,[sp,#64] + stp x27,x28,[sp,#80] + sub sp,sp,#64 + + ldp x22,x23,[x5] // load sigma + ldp x24,x25,[x3] // load key + ldp x26,x27,[x3,#16] + ldp x28,x30,[x4] // load counter +#ifdef __ARMEB__ + ror x24,x24,#32 + ror x25,x25,#32 + ror x26,x26,#32 + ror x27,x27,#32 + ror x28,x28,#32 + ror x30,x30,#32 +#endif + +.Loop_outer: + mov w5,w22 // unpack key block + lsr x6,x22,#32 + mov w7,w23 + lsr x8,x23,#32 + mov w9,w24 + lsr x10,x24,#32 + mov w11,w25 + lsr x12,x25,#32 + mov w13,w26 + lsr x14,x26,#32 + mov w15,w27 + lsr x16,x27,#32 + mov w17,w28 + lsr x19,x28,#32 + mov w20,w30 + lsr x21,x30,#32 + + mov x4,#10 + subs x2,x2,#64 +.Loop: + sub x4,x4,#1 + add w5,w5,w9 + add w6,w6,w10 + add w7,w7,w11 + add w8,w8,w12 + eor w17,w17,w5 + eor w19,w19,w6 + eor w20,w20,w7 + eor w21,w21,w8 + ror w17,w17,#16 + ror w19,w19,#16 + ror w20,w20,#16 + ror w21,w21,#16 + add w13,w13,w17 + add w14,w14,w19 + add w15,w15,w20 + add w16,w16,w21 + eor w9,w9,w13 + eor w10,w10,w14 + eor w11,w11,w15 + eor w12,w12,w16 + ror w9,w9,#20 + ror w10,w10,#20 + ror w11,w11,#20 + ror w12,w12,#20 + add w5,w5,w9 + add w6,w6,w10 + add w7,w7,w11 + add w8,w8,w12 + eor w17,w17,w5 + eor w19,w19,w6 + eor w20,w20,w7 + eor w21,w21,w8 + ror w17,w17,#24 + ror w19,w19,#24 + ror w20,w20,#24 + ror w21,w21,#24 + add w13,w13,w17 + add w14,w14,w19 + add w15,w15,w20 + add w16,w16,w21 + eor w9,w9,w13 + eor w10,w10,w14 + eor w11,w11,w15 + eor w12,w12,w16 + ror w9,w9,#25 + ror w10,w10,#25 + ror w11,w11,#25 + ror w12,w12,#25 + add w5,w5,w10 + add w6,w6,w11 + add w7,w7,w12 + add w8,w8,w9 + eor w21,w21,w5 + eor w17,w17,w6 + eor w19,w19,w7 + eor w20,w20,w8 + ror w21,w21,#16 + ror w17,w17,#16 + ror w19,w19,#16 + ror w20,w20,#16 + add w15,w15,w21 + add w16,w16,w17 + add w13,w13,w19 + add w14,w14,w20 + eor w10,w10,w15 + eor w11,w11,w16 + eor w12,w12,w13 + eor w9,w9,w14 + ror w10,w10,#20 + ror w11,w11,#20 + ror w12,w12,#20 + ror w9,w9,#20 + add w5,w5,w10 + add w6,w6,w11 + add w7,w7,w12 + add w8,w8,w9 + eor w21,w21,w5 + eor w17,w17,w6 + eor w19,w19,w7 + eor w20,w20,w8 + ror w21,w21,#24 + ror w17,w17,#24 + ror w19,w19,#24 + ror w20,w20,#24 + add w15,w15,w21 + add w16,w16,w17 + add w13,w13,w19 + add w14,w14,w20 + eor w10,w10,w15 + eor w11,w11,w16 + eor w12,w12,w13 + eor w9,w9,w14 + ror w10,w10,#25 + ror w11,w11,#25 + ror w12,w12,#25 + ror w9,w9,#25 + cbnz x4,.Loop + + add w5,w5,w22 // accumulate key block + add x6,x6,x22,lsr#32 + add w7,w7,w23 + add x8,x8,x23,lsr#32 + add w9,w9,w24 + add x10,x10,x24,lsr#32 + add w11,w11,w25 + add x12,x12,x25,lsr#32 + add w13,w13,w26 + add x14,x14,x26,lsr#32 + add w15,w15,w27 + add x16,x16,x27,lsr#32 + add w17,w17,w28 + add x19,x19,x28,lsr#32 + add w20,w20,w30 + add x21,x21,x30,lsr#32 + + b.lo .Ltail + + add x5,x5,x6,lsl#32 // pack + add x7,x7,x8,lsl#32 + ldp x6,x8,[x1,#0] // load input + add x9,x9,x10,lsl#32 + add x11,x11,x12,lsl#32 + ldp x10,x12,[x1,#16] + add x13,x13,x14,lsl#32 + add x15,x15,x16,lsl#32 + ldp x14,x16,[x1,#32] + add x17,x17,x19,lsl#32 + add x20,x20,x21,lsl#32 + ldp x19,x21,[x1,#48] + add x1,x1,#64 +#ifdef __ARMEB__ + rev x5,x5 + rev x7,x7 + rev x9,x9 + rev x11,x11 + rev x13,x13 + rev x15,x15 + rev x17,x17 + rev x20,x20 +#endif + eor x5,x5,x6 + eor x7,x7,x8 + eor x9,x9,x10 + eor x11,x11,x12 + eor x13,x13,x14 + eor x15,x15,x16 + eor x17,x17,x19 + eor x20,x20,x21 + + stp x5,x7,[x0,#0] // store output + add x28,x28,#1 // increment counter + stp x9,x11,[x0,#16] + stp x13,x15,[x0,#32] + stp x17,x20,[x0,#48] + add x0,x0,#64 + + b.hi .Loop_outer + + ldp x19,x20,[x29,#16] + add sp,sp,#64 + ldp x21,x22,[x29,#32] + ldp x23,x24,[x29,#48] + ldp x25,x26,[x29,#64] + ldp x27,x28,[x29,#80] + ldp x29,x30,[sp],#96 +.Labort: + ret + +.align 4 +.Ltail: + add x2,x2,#64 +.Less_than_64: + sub x0,x0,#1 + add x1,x1,x2 + add x0,x0,x2 + add x4,sp,x2 + neg x2,x2 + + add x5,x5,x6,lsl#32 // pack + add x7,x7,x8,lsl#32 + add x9,x9,x10,lsl#32 + add x11,x11,x12,lsl#32 + add x13,x13,x14,lsl#32 + add x15,x15,x16,lsl#32 + add x17,x17,x19,lsl#32 + add x20,x20,x21,lsl#32 +#ifdef __ARMEB__ + rev x5,x5 + rev x7,x7 + rev x9,x9 + rev x11,x11 + rev x13,x13 + rev x15,x15 + rev x17,x17 + rev x20,x20 +#endif + stp x5,x7,[sp,#0] + stp x9,x11,[sp,#16] + stp x13,x15,[sp,#32] + stp x17,x20,[sp,#48] + +.Loop_tail: + ldrb w10,[x1,x2] + ldrb w11,[x4,x2] + add x2,x2,#1 + eor w10,w10,w11 + strb w10,[x0,x2] + cbnz x2,.Loop_tail + + stp xzr,xzr,[sp,#0] + stp xzr,xzr,[sp,#16] + stp xzr,xzr,[sp,#32] + stp xzr,xzr,[sp,#48] + + ldp x19,x20,[x29,#16] + add sp,sp,#64 + ldp x21,x22,[x29,#32] + ldp x23,x24,[x29,#48] + ldp x25,x26,[x29,#64] + ldp x27,x28,[x29,#80] + ldp x29,x30,[sp],#96 + ret +ENDPROC(chacha20_arm) + +.align 5 +ENTRY(chacha20_neon) + cbz x2,.Labort_neon + cmp x2,#192 + b.lo .Lshort + + stp x29,x30,[sp,#-96]! + add x29,sp,#0 + + adr x5,.Lsigma + stp x19,x20,[sp,#16] + stp x21,x22,[sp,#32] + stp x23,x24,[sp,#48] + stp x25,x26,[sp,#64] + stp x27,x28,[sp,#80] + cmp x2,#512 + b.hs .L512_or_more_neon + + sub sp,sp,#64 + + ldp x22,x23,[x5] // load sigma + ld1 {v24.4s},[x5],#16 + ldp x24,x25,[x3] // load key + ldp x26,x27,[x3,#16] + ld1 {v25.4s,v26.4s},[x3] + ldp x28,x30,[x4] // load counter + ld1 {v27.4s},[x4] + ld1 {v31.4s},[x5] +#ifdef __ARMEB__ + rev64 v24.4s,v24.4s + ror x24,x24,#32 + ror x25,x25,#32 + ror x26,x26,#32 + ror x27,x27,#32 + ror x28,x28,#32 + ror x30,x30,#32 +#endif + add v27.4s,v27.4s,v31.4s // += 1 + add v28.4s,v27.4s,v31.4s + add v29.4s,v28.4s,v31.4s + shl v31.4s,v31.4s,#2 // 1 -> 4 + +.Loop_outer_neon: + mov w5,w22 // unpack key block + lsr x6,x22,#32 + mov v0.16b,v24.16b + mov w7,w23 + lsr x8,x23,#32 + mov v4.16b,v24.16b + mov w9,w24 + lsr x10,x24,#32 + mov v16.16b,v24.16b + mov w11,w25 + mov v1.16b,v25.16b + lsr x12,x25,#32 + mov v5.16b,v25.16b + mov w13,w26 + mov v17.16b,v25.16b + lsr x14,x26,#32 + mov v3.16b,v27.16b + mov w15,w27 + mov v7.16b,v28.16b + lsr x16,x27,#32 + mov v19.16b,v29.16b + mov w17,w28 + mov v2.16b,v26.16b + lsr x19,x28,#32 + mov v6.16b,v26.16b + mov w20,w30 + mov v18.16b,v26.16b + lsr x21,x30,#32 + + mov x4,#10 + subs x2,x2,#256 +.Loop_neon: + sub x4,x4,#1 + add v0.4s,v0.4s,v1.4s + add w5,w5,w9 + add v4.4s,v4.4s,v5.4s + add w6,w6,w10 + add v16.4s,v16.4s,v17.4s + add w7,w7,w11 + eor v3.16b,v3.16b,v0.16b + add w8,w8,w12 + eor v7.16b,v7.16b,v4.16b + eor w17,w17,w5 + eor v19.16b,v19.16b,v16.16b + eor w19,w19,w6 + rev32 v3.8h,v3.8h + eor w20,w20,w7 + rev32 v7.8h,v7.8h + eor w21,w21,w8 + rev32 v19.8h,v19.8h + ror w17,w17,#16 + add v2.4s,v2.4s,v3.4s + ror w19,w19,#16 + add v6.4s,v6.4s,v7.4s + ror w20,w20,#16 + add v18.4s,v18.4s,v19.4s + ror w21,w21,#16 + eor v20.16b,v1.16b,v2.16b + add w13,w13,w17 + eor v21.16b,v5.16b,v6.16b + add w14,w14,w19 + eor v22.16b,v17.16b,v18.16b + add w15,w15,w20 + ushr v1.4s,v20.4s,#20 + add w16,w16,w21 + ushr v5.4s,v21.4s,#20 + eor w9,w9,w13 + ushr v17.4s,v22.4s,#20 + eor w10,w10,w14 + sli v1.4s,v20.4s,#12 + eor w11,w11,w15 + sli v5.4s,v21.4s,#12 + eor w12,w12,w16 + sli v17.4s,v22.4s,#12 + ror w9,w9,#20 + add v0.4s,v0.4s,v1.4s + ror w10,w10,#20 + add v4.4s,v4.4s,v5.4s + ror w11,w11,#20 + add v16.4s,v16.4s,v17.4s + ror w12,w12,#20 + eor v20.16b,v3.16b,v0.16b + add w5,w5,w9 + eor v21.16b,v7.16b,v4.16b + add w6,w6,w10 + eor v22.16b,v19.16b,v16.16b + add w7,w7,w11 + ushr v3.4s,v20.4s,#24 + add w8,w8,w12 + ushr v7.4s,v21.4s,#24 + eor w17,w17,w5 + ushr v19.4s,v22.4s,#24 + eor w19,w19,w6 + sli v3.4s,v20.4s,#8 + eor w20,w20,w7 + sli v7.4s,v21.4s,#8 + eor w21,w21,w8 + sli v19.4s,v22.4s,#8 + ror w17,w17,#24 + add v2.4s,v2.4s,v3.4s + ror w19,w19,#24 + add v6.4s,v6.4s,v7.4s + ror w20,w20,#24 + add v18.4s,v18.4s,v19.4s + ror w21,w21,#24 + eor v20.16b,v1.16b,v2.16b + add w13,w13,w17 + eor v21.16b,v5.16b,v6.16b + add w14,w14,w19 + eor v22.16b,v17.16b,v18.16b + add w15,w15,w20 + ushr v1.4s,v20.4s,#25 + add w16,w16,w21 + ushr v5.4s,v21.4s,#25 + eor w9,w9,w13 + ushr v17.4s,v22.4s,#25 + eor w10,w10,w14 + sli v1.4s,v20.4s,#7 + eor w11,w11,w15 + sli v5.4s,v21.4s,#7 + eor w12,w12,w16 + sli v17.4s,v22.4s,#7 + ror w9,w9,#25 + ext v2.16b,v2.16b,v2.16b,#8 + ror w10,w10,#25 + ext v6.16b,v6.16b,v6.16b,#8 + ror w11,w11,#25 + ext v18.16b,v18.16b,v18.16b,#8 + ror w12,w12,#25 + ext v3.16b,v3.16b,v3.16b,#12 + ext v7.16b,v7.16b,v7.16b,#12 + ext v19.16b,v19.16b,v19.16b,#12 + ext v1.16b,v1.16b,v1.16b,#4 + ext v5.16b,v5.16b,v5.16b,#4 + ext v17.16b,v17.16b,v17.16b,#4 + add v0.4s,v0.4s,v1.4s + add w5,w5,w10 + add v4.4s,v4.4s,v5.4s + add w6,w6,w11 + add v16.4s,v16.4s,v17.4s + add w7,w7,w12 + eor v3.16b,v3.16b,v0.16b + add w8,w8,w9 + eor v7.16b,v7.16b,v4.16b + eor w21,w21,w5 + eor v19.16b,v19.16b,v16.16b + eor w17,w17,w6 + rev32 v3.8h,v3.8h + eor w19,w19,w7 + rev32 v7.8h,v7.8h + eor w20,w20,w8 + rev32 v19.8h,v19.8h + ror w21,w21,#16 + add v2.4s,v2.4s,v3.4s + ror w17,w17,#16 + add v6.4s,v6.4s,v7.4s + ror w19,w19,#16 + add v18.4s,v18.4s,v19.4s + ror w20,w20,#16 + eor v20.16b,v1.16b,v2.16b + add w15,w15,w21 + eor v21.16b,v5.16b,v6.16b + add w16,w16,w17 + eor v22.16b,v17.16b,v18.16b + add w13,w13,w19 + ushr v1.4s,v20.4s,#20 + add w14,w14,w20 + ushr v5.4s,v21.4s,#20 + eor w10,w10,w15 + ushr v17.4s,v22.4s,#20 + eor w11,w11,w16 + sli v1.4s,v20.4s,#12 + eor w12,w12,w13 + sli v5.4s,v21.4s,#12 + eor w9,w9,w14 + sli v17.4s,v22.4s,#12 + ror w10,w10,#20 + add v0.4s,v0.4s,v1.4s + ror w11,w11,#20 + add v4.4s,v4.4s,v5.4s + ror w12,w12,#20 + add v16.4s,v16.4s,v17.4s + ror w9,w9,#20 + eor v20.16b,v3.16b,v0.16b + add w5,w5,w10 + eor v21.16b,v7.16b,v4.16b + add w6,w6,w11 + eor v22.16b,v19.16b,v16.16b + add w7,w7,w12 + ushr v3.4s,v20.4s,#24 + add w8,w8,w9 + ushr v7.4s,v21.4s,#24 + eor w21,w21,w5 + ushr v19.4s,v22.4s,#24 + eor w17,w17,w6 + sli v3.4s,v20.4s,#8 + eor w19,w19,w7 + sli v7.4s,v21.4s,#8 + eor w20,w20,w8 + sli v19.4s,v22.4s,#8 + ror w21,w21,#24 + add v2.4s,v2.4s,v3.4s + ror w17,w17,#24 + add v6.4s,v6.4s,v7.4s + ror w19,w19,#24 + add v18.4s,v18.4s,v19.4s + ror w20,w20,#24 + eor v20.16b,v1.16b,v2.16b + add w15,w15,w21 + eor v21.16b,v5.16b,v6.16b + add w16,w16,w17 + eor v22.16b,v17.16b,v18.16b + add w13,w13,w19 + ushr v1.4s,v20.4s,#25 + add w14,w14,w20 + ushr v5.4s,v21.4s,#25 + eor w10,w10,w15 + ushr v17.4s,v22.4s,#25 + eor w11,w11,w16 + sli v1.4s,v20.4s,#7 + eor w12,w12,w13 + sli v5.4s,v21.4s,#7 + eor w9,w9,w14 + sli v17.4s,v22.4s,#7 + ror w10,w10,#25 + ext v2.16b,v2.16b,v2.16b,#8 + ror w11,w11,#25 + ext v6.16b,v6.16b,v6.16b,#8 + ror w12,w12,#25 + ext v18.16b,v18.16b,v18.16b,#8 + ror w9,w9,#25 + ext v3.16b,v3.16b,v3.16b,#4 + ext v7.16b,v7.16b,v7.16b,#4 + ext v19.16b,v19.16b,v19.16b,#4 + ext v1.16b,v1.16b,v1.16b,#12 + ext v5.16b,v5.16b,v5.16b,#12 + ext v17.16b,v17.16b,v17.16b,#12 + cbnz x4,.Loop_neon + + add w5,w5,w22 // accumulate key block + add v0.4s,v0.4s,v24.4s + add x6,x6,x22,lsr#32 + add v4.4s,v4.4s,v24.4s + add w7,w7,w23 + add v16.4s,v16.4s,v24.4s + add x8,x8,x23,lsr#32 + add v2.4s,v2.4s,v26.4s + add w9,w9,w24 + add v6.4s,v6.4s,v26.4s + add x10,x10,x24,lsr#32 + add v18.4s,v18.4s,v26.4s + add w11,w11,w25 + add v3.4s,v3.4s,v27.4s + add x12,x12,x25,lsr#32 + add w13,w13,w26 + add v7.4s,v7.4s,v28.4s + add x14,x14,x26,lsr#32 + add w15,w15,w27 + add v19.4s,v19.4s,v29.4s + add x16,x16,x27,lsr#32 + add w17,w17,w28 + add v1.4s,v1.4s,v25.4s + add x19,x19,x28,lsr#32 + add w20,w20,w30 + add v5.4s,v5.4s,v25.4s + add x21,x21,x30,lsr#32 + add v17.4s,v17.4s,v25.4s + + b.lo .Ltail_neon + + add x5,x5,x6,lsl#32 // pack + add x7,x7,x8,lsl#32 + ldp x6,x8,[x1,#0] // load input + add x9,x9,x10,lsl#32 + add x11,x11,x12,lsl#32 + ldp x10,x12,[x1,#16] + add x13,x13,x14,lsl#32 + add x15,x15,x16,lsl#32 + ldp x14,x16,[x1,#32] + add x17,x17,x19,lsl#32 + add x20,x20,x21,lsl#32 + ldp x19,x21,[x1,#48] + add x1,x1,#64 +#ifdef __ARMEB__ + rev x5,x5 + rev x7,x7 + rev x9,x9 + rev x11,x11 + rev x13,x13 + rev x15,x15 + rev x17,x17 + rev x20,x20 +#endif + ld1 {v20.16b,v21.16b,v22.16b,v23.16b},[x1],#64 + eor x5,x5,x6 + eor x7,x7,x8 + eor x9,x9,x10 + eor x11,x11,x12 + eor x13,x13,x14 + eor v0.16b,v0.16b,v20.16b + eor x15,x15,x16 + eor v1.16b,v1.16b,v21.16b + eor x17,x17,x19 + eor v2.16b,v2.16b,v22.16b + eor x20,x20,x21 + eor v3.16b,v3.16b,v23.16b + ld1 {v20.16b,v21.16b,v22.16b,v23.16b},[x1],#64 + + stp x5,x7,[x0,#0] // store output + add x28,x28,#4 // increment counter + stp x9,x11,[x0,#16] + add v27.4s,v27.4s,v31.4s // += 4 + stp x13,x15,[x0,#32] + add v28.4s,v28.4s,v31.4s + stp x17,x20,[x0,#48] + add v29.4s,v29.4s,v31.4s + add x0,x0,#64 + + st1 {v0.16b,v1.16b,v2.16b,v3.16b},[x0],#64 + ld1 {v0.16b,v1.16b,v2.16b,v3.16b},[x1],#64 + + eor v4.16b,v4.16b,v20.16b + eor v5.16b,v5.16b,v21.16b + eor v6.16b,v6.16b,v22.16b + eor v7.16b,v7.16b,v23.16b + st1 {v4.16b,v5.16b,v6.16b,v7.16b},[x0],#64 + + eor v16.16b,v16.16b,v0.16b + eor v17.16b,v17.16b,v1.16b + eor v18.16b,v18.16b,v2.16b + eor v19.16b,v19.16b,v3.16b + st1 {v16.16b,v17.16b,v18.16b,v19.16b},[x0],#64 + + b.hi .Loop_outer_neon + + ldp x19,x20,[x29,#16] + add sp,sp,#64 + ldp x21,x22,[x29,#32] + ldp x23,x24,[x29,#48] + ldp x25,x26,[x29,#64] + ldp x27,x28,[x29,#80] + ldp x29,x30,[sp],#96 + ret + +.Ltail_neon: + add x2,x2,#256 + cmp x2,#64 + b.lo .Less_than_64 + + add x5,x5,x6,lsl#32 // pack + add x7,x7,x8,lsl#32 + ldp x6,x8,[x1,#0] // load input + add x9,x9,x10,lsl#32 + add x11,x11,x12,lsl#32 + ldp x10,x12,[x1,#16] + add x13,x13,x14,lsl#32 + add x15,x15,x16,lsl#32 + ldp x14,x16,[x1,#32] + add x17,x17,x19,lsl#32 + add x20,x20,x21,lsl#32 + ldp x19,x21,[x1,#48] + add x1,x1,#64 +#ifdef __ARMEB__ + rev x5,x5 + rev x7,x7 + rev x9,x9 + rev x11,x11 + rev x13,x13 + rev x15,x15 + rev x17,x17 + rev x20,x20 +#endif + eor x5,x5,x6 + eor x7,x7,x8 + eor x9,x9,x10 + eor x11,x11,x12 + eor x13,x13,x14 + eor x15,x15,x16 + eor x17,x17,x19 + eor x20,x20,x21 + + stp x5,x7,[x0,#0] // store output + add x28,x28,#4 // increment counter + stp x9,x11,[x0,#16] + stp x13,x15,[x0,#32] + stp x17,x20,[x0,#48] + add x0,x0,#64 + b.eq .Ldone_neon + sub x2,x2,#64 + cmp x2,#64 + b.lo .Less_than_128 + + ld1 {v20.16b,v21.16b,v22.16b,v23.16b},[x1],#64 + eor v0.16b,v0.16b,v20.16b + eor v1.16b,v1.16b,v21.16b + eor v2.16b,v2.16b,v22.16b + eor v3.16b,v3.16b,v23.16b + st1 {v0.16b,v1.16b,v2.16b,v3.16b},[x0],#64 + b.eq .Ldone_neon + sub x2,x2,#64 + cmp x2,#64 + b.lo .Less_than_192 + + ld1 {v20.16b,v21.16b,v22.16b,v23.16b},[x1],#64 + eor v4.16b,v4.16b,v20.16b + eor v5.16b,v5.16b,v21.16b + eor v6.16b,v6.16b,v22.16b + eor v7.16b,v7.16b,v23.16b + st1 {v4.16b,v5.16b,v6.16b,v7.16b},[x0],#64 + b.eq .Ldone_neon + sub x2,x2,#64 + + st1 {v16.16b,v17.16b,v18.16b,v19.16b},[sp] + b .Last_neon + +.Less_than_128: + st1 {v0.16b,v1.16b,v2.16b,v3.16b},[sp] + b .Last_neon +.Less_than_192: + st1 {v4.16b,v5.16b,v6.16b,v7.16b},[sp] + b .Last_neon + +.align 4 +.Last_neon: + sub x0,x0,#1 + add x1,x1,x2 + add x0,x0,x2 + add x4,sp,x2 + neg x2,x2 + +.Loop_tail_neon: + ldrb w10,[x1,x2] + ldrb w11,[x4,x2] + add x2,x2,#1 + eor w10,w10,w11 + strb w10,[x0,x2] + cbnz x2,.Loop_tail_neon + + stp xzr,xzr,[sp,#0] + stp xzr,xzr,[sp,#16] + stp xzr,xzr,[sp,#32] + stp xzr,xzr,[sp,#48] + +.Ldone_neon: + ldp x19,x20,[x29,#16] + add sp,sp,#64 + ldp x21,x22,[x29,#32] + ldp x23,x24,[x29,#48] + ldp x25,x26,[x29,#64] + ldp x27,x28,[x29,#80] + ldp x29,x30,[sp],#96 + ret + +.L512_or_more_neon: + sub sp,sp,#128+64 + + ldp x22,x23,[x5] // load sigma + ld1 {v24.4s},[x5],#16 + ldp x24,x25,[x3] // load key + ldp x26,x27,[x3,#16] + ld1 {v25.4s,v26.4s},[x3] + ldp x28,x30,[x4] // load counter + ld1 {v27.4s},[x4] + ld1 {v31.4s},[x5] +#ifdef __ARMEB__ + rev64 v24.4s,v24.4s + ror x24,x24,#32 + ror x25,x25,#32 + ror x26,x26,#32 + ror x27,x27,#32 + ror x28,x28,#32 + ror x30,x30,#32 +#endif + add v27.4s,v27.4s,v31.4s // += 1 + stp q24,q25,[sp,#0] // off-load key block, invariant part + add v27.4s,v27.4s,v31.4s // not typo + str q26,[sp,#32] + add v28.4s,v27.4s,v31.4s + add v29.4s,v28.4s,v31.4s + add v30.4s,v29.4s,v31.4s + shl v31.4s,v31.4s,#2 // 1 -> 4 + + stp d8,d9,[sp,#128+0] // meet ABI requirements + stp d10,d11,[sp,#128+16] + stp d12,d13,[sp,#128+32] + stp d14,d15,[sp,#128+48] + + sub x2,x2,#512 // not typo + +.Loop_outer_512_neon: + mov v0.16b,v24.16b + mov v4.16b,v24.16b + mov v8.16b,v24.16b + mov v12.16b,v24.16b + mov v16.16b,v24.16b + mov v20.16b,v24.16b + mov v1.16b,v25.16b + mov w5,w22 // unpack key block + mov v5.16b,v25.16b + lsr x6,x22,#32 + mov v9.16b,v25.16b + mov w7,w23 + mov v13.16b,v25.16b + lsr x8,x23,#32 + mov v17.16b,v25.16b + mov w9,w24 + mov v21.16b,v25.16b + lsr x10,x24,#32 + mov v3.16b,v27.16b + mov w11,w25 + mov v7.16b,v28.16b + lsr x12,x25,#32 + mov v11.16b,v29.16b + mov w13,w26 + mov v15.16b,v30.16b + lsr x14,x26,#32 + mov v2.16b,v26.16b + mov w15,w27 + mov v6.16b,v26.16b + lsr x16,x27,#32 + add v19.4s,v3.4s,v31.4s // +4 + mov w17,w28 + add v23.4s,v7.4s,v31.4s // +4 + lsr x19,x28,#32 + mov v10.16b,v26.16b + mov w20,w30 + mov v14.16b,v26.16b + lsr x21,x30,#32 + mov v18.16b,v26.16b + stp q27,q28,[sp,#48] // off-load key block, variable part + mov v22.16b,v26.16b + str q29,[sp,#80] + + mov x4,#5 + subs x2,x2,#512 +.Loop_upper_neon: + sub x4,x4,#1 + add v0.4s,v0.4s,v1.4s + add w5,w5,w9 + add v4.4s,v4.4s,v5.4s + add w6,w6,w10 + add v8.4s,v8.4s,v9.4s + add w7,w7,w11 + add v12.4s,v12.4s,v13.4s + add w8,w8,w12 + add v16.4s,v16.4s,v17.4s + eor w17,w17,w5 + add v20.4s,v20.4s,v21.4s + eor w19,w19,w6 + eor v3.16b,v3.16b,v0.16b + eor w20,w20,w7 + eor v7.16b,v7.16b,v4.16b + eor w21,w21,w8 + eor v11.16b,v11.16b,v8.16b + ror w17,w17,#16 + eor v15.16b,v15.16b,v12.16b + ror w19,w19,#16 + eor v19.16b,v19.16b,v16.16b + ror w20,w20,#16 + eor v23.16b,v23.16b,v20.16b + ror w21,w21,#16 + rev32 v3.8h,v3.8h + add w13,w13,w17 + rev32 v7.8h,v7.8h + add w14,w14,w19 + rev32 v11.8h,v11.8h + add w15,w15,w20 + rev32 v15.8h,v15.8h + add w16,w16,w21 + rev32 v19.8h,v19.8h + eor w9,w9,w13 + rev32 v23.8h,v23.8h + eor w10,w10,w14 + add v2.4s,v2.4s,v3.4s + eor w11,w11,w15 + add v6.4s,v6.4s,v7.4s + eor w12,w12,w16 + add v10.4s,v10.4s,v11.4s + ror w9,w9,#20 + add v14.4s,v14.4s,v15.4s + ror w10,w10,#20 + add v18.4s,v18.4s,v19.4s + ror w11,w11,#20 + add v22.4s,v22.4s,v23.4s + ror w12,w12,#20 + eor v24.16b,v1.16b,v2.16b + add w5,w5,w9 + eor v25.16b,v5.16b,v6.16b + add w6,w6,w10 + eor v26.16b,v9.16b,v10.16b + add w7,w7,w11 + eor v27.16b,v13.16b,v14.16b + add w8,w8,w12 + eor v28.16b,v17.16b,v18.16b + eor w17,w17,w5 + eor v29.16b,v21.16b,v22.16b + eor w19,w19,w6 + ushr v1.4s,v24.4s,#20 + eor w20,w20,w7 + ushr v5.4s,v25.4s,#20 + eor w21,w21,w8 + ushr v9.4s,v26.4s,#20 + ror w17,w17,#24 + ushr v13.4s,v27.4s,#20 + ror w19,w19,#24 + ushr v17.4s,v28.4s,#20 + ror w20,w20,#24 + ushr v21.4s,v29.4s,#20 + ror w21,w21,#24 + sli v1.4s,v24.4s,#12 + add w13,w13,w17 + sli v5.4s,v25.4s,#12 + add w14,w14,w19 + sli v9.4s,v26.4s,#12 + add w15,w15,w20 + sli v13.4s,v27.4s,#12 + add w16,w16,w21 + sli v17.4s,v28.4s,#12 + eor w9,w9,w13 + sli v21.4s,v29.4s,#12 + eor w10,w10,w14 + add v0.4s,v0.4s,v1.4s + eor w11,w11,w15 + add v4.4s,v4.4s,v5.4s + eor w12,w12,w16 + add v8.4s,v8.4s,v9.4s + ror w9,w9,#25 + add v12.4s,v12.4s,v13.4s + ror w10,w10,#25 + add v16.4s,v16.4s,v17.4s + ror w11,w11,#25 + add v20.4s,v20.4s,v21.4s + ror w12,w12,#25 + eor v24.16b,v3.16b,v0.16b + add w5,w5,w10 + eor v25.16b,v7.16b,v4.16b + add w6,w6,w11 + eor v26.16b,v11.16b,v8.16b + add w7,w7,w12 + eor v27.16b,v15.16b,v12.16b + add w8,w8,w9 + eor v28.16b,v19.16b,v16.16b + eor w21,w21,w5 + eor v29.16b,v23.16b,v20.16b + eor w17,w17,w6 + ushr v3.4s,v24.4s,#24 + eor w19,w19,w7 + ushr v7.4s,v25.4s,#24 + eor w20,w20,w8 + ushr v11.4s,v26.4s,#24 + ror w21,w21,#16 + ushr v15.4s,v27.4s,#24 + ror w17,w17,#16 + ushr v19.4s,v28.4s,#24 + ror w19,w19,#16 + ushr v23.4s,v29.4s,#24 + ror w20,w20,#16 + sli v3.4s,v24.4s,#8 + add w15,w15,w21 + sli v7.4s,v25.4s,#8 + add w16,w16,w17 + sli v11.4s,v26.4s,#8 + add w13,w13,w19 + sli v15.4s,v27.4s,#8 + add w14,w14,w20 + sli v19.4s,v28.4s,#8 + eor w10,w10,w15 + sli v23.4s,v29.4s,#8 + eor w11,w11,w16 + add v2.4s,v2.4s,v3.4s + eor w12,w12,w13 + add v6.4s,v6.4s,v7.4s + eor w9,w9,w14 + add v10.4s,v10.4s,v11.4s + ror w10,w10,#20 + add v14.4s,v14.4s,v15.4s + ror w11,w11,#20 + add v18.4s,v18.4s,v19.4s + ror w12,w12,#20 + add v22.4s,v22.4s,v23.4s + ror w9,w9,#20 + eor v24.16b,v1.16b,v2.16b + add w5,w5,w10 + eor v25.16b,v5.16b,v6.16b + add w6,w6,w11 + eor v26.16b,v9.16b,v10.16b + add w7,w7,w12 + eor v27.16b,v13.16b,v14.16b + add w8,w8,w9 + eor v28.16b,v17.16b,v18.16b + eor w21,w21,w5 + eor v29.16b,v21.16b,v22.16b + eor w17,w17,w6 + ushr v1.4s,v24.4s,#25 + eor w19,w19,w7 + ushr v5.4s,v25.4s,#25 + eor w20,w20,w8 + ushr v9.4s,v26.4s,#25 + ror w21,w21,#24 + ushr v13.4s,v27.4s,#25 + ror w17,w17,#24 + ushr v17.4s,v28.4s,#25 + ror w19,w19,#24 + ushr v21.4s,v29.4s,#25 + ror w20,w20,#24 + sli v1.4s,v24.4s,#7 + add w15,w15,w21 + sli v5.4s,v25.4s,#7 + add w16,w16,w17 + sli v9.4s,v26.4s,#7 + add w13,w13,w19 + sli v13.4s,v27.4s,#7 + add w14,w14,w20 + sli v17.4s,v28.4s,#7 + eor w10,w10,w15 + sli v21.4s,v29.4s,#7 + eor w11,w11,w16 + ext v2.16b,v2.16b,v2.16b,#8 + eor w12,w12,w13 + ext v6.16b,v6.16b,v6.16b,#8 + eor w9,w9,w14 + ext v10.16b,v10.16b,v10.16b,#8 + ror w10,w10,#25 + ext v14.16b,v14.16b,v14.16b,#8 + ror w11,w11,#25 + ext v18.16b,v18.16b,v18.16b,#8 + ror w12,w12,#25 + ext v22.16b,v22.16b,v22.16b,#8 + ror w9,w9,#25 + ext v3.16b,v3.16b,v3.16b,#12 + ext v7.16b,v7.16b,v7.16b,#12 + ext v11.16b,v11.16b,v11.16b,#12 + ext v15.16b,v15.16b,v15.16b,#12 + ext v19.16b,v19.16b,v19.16b,#12 + ext v23.16b,v23.16b,v23.16b,#12 + ext v1.16b,v1.16b,v1.16b,#4 + ext v5.16b,v5.16b,v5.16b,#4 + ext v9.16b,v9.16b,v9.16b,#4 + ext v13.16b,v13.16b,v13.16b,#4 + ext v17.16b,v17.16b,v17.16b,#4 + ext v21.16b,v21.16b,v21.16b,#4 + add v0.4s,v0.4s,v1.4s + add w5,w5,w9 + add v4.4s,v4.4s,v5.4s + add w6,w6,w10 + add v8.4s,v8.4s,v9.4s + add w7,w7,w11 + add v12.4s,v12.4s,v13.4s + add w8,w8,w12 + add v16.4s,v16.4s,v17.4s + eor w17,w17,w5 + add v20.4s,v20.4s,v21.4s + eor w19,w19,w6 + eor v3.16b,v3.16b,v0.16b + eor w20,w20,w7 + eor v7.16b,v7.16b,v4.16b + eor w21,w21,w8 + eor v11.16b,v11.16b,v8.16b + ror w17,w17,#16 + eor v15.16b,v15.16b,v12.16b + ror w19,w19,#16 + eor v19.16b,v19.16b,v16.16b + ror w20,w20,#16 + eor v23.16b,v23.16b,v20.16b + ror w21,w21,#16 + rev32 v3.8h,v3.8h + add w13,w13,w17 + rev32 v7.8h,v7.8h + add w14,w14,w19 + rev32 v11.8h,v11.8h + add w15,w15,w20 + rev32 v15.8h,v15.8h + add w16,w16,w21 + rev32 v19.8h,v19.8h + eor w9,w9,w13 + rev32 v23.8h,v23.8h + eor w10,w10,w14 + add v2.4s,v2.4s,v3.4s + eor w11,w11,w15 + add v6.4s,v6.4s,v7.4s + eor w12,w12,w16 + add v10.4s,v10.4s,v11.4s + ror w9,w9,#20 + add v14.4s,v14.4s,v15.4s + ror w10,w10,#20 + add v18.4s,v18.4s,v19.4s + ror w11,w11,#20 + add v22.4s,v22.4s,v23.4s + ror w12,w12,#20 + eor v24.16b,v1.16b,v2.16b + add w5,w5,w9 + eor v25.16b,v5.16b,v6.16b + add w6,w6,w10 + eor v26.16b,v9.16b,v10.16b + add w7,w7,w11 + eor v27.16b,v13.16b,v14.16b + add w8,w8,w12 + eor v28.16b,v17.16b,v18.16b + eor w17,w17,w5 + eor v29.16b,v21.16b,v22.16b + eor w19,w19,w6 + ushr v1.4s,v24.4s,#20 + eor w20,w20,w7 + ushr v5.4s,v25.4s,#20 + eor w21,w21,w8 + ushr v9.4s,v26.4s,#20 + ror w17,w17,#24 + ushr v13.4s,v27.4s,#20 + ror w19,w19,#24 + ushr v17.4s,v28.4s,#20 + ror w20,w20,#24 + ushr v21.4s,v29.4s,#20 + ror w21,w21,#24 + sli v1.4s,v24.4s,#12 + add w13,w13,w17 + sli v5.4s,v25.4s,#12 + add w14,w14,w19 + sli v9.4s,v26.4s,#12 + add w15,w15,w20 + sli v13.4s,v27.4s,#12 + add w16,w16,w21 + sli v17.4s,v28.4s,#12 + eor w9,w9,w13 + sli v21.4s,v29.4s,#12 + eor w10,w10,w14 + add v0.4s,v0.4s,v1.4s + eor w11,w11,w15 + add v4.4s,v4.4s,v5.4s + eor w12,w12,w16 + add v8.4s,v8.4s,v9.4s + ror w9,w9,#25 + add v12.4s,v12.4s,v13.4s + ror w10,w10,#25 + add v16.4s,v16.4s,v17.4s + ror w11,w11,#25 + add v20.4s,v20.4s,v21.4s + ror w12,w12,#25 + eor v24.16b,v3.16b,v0.16b + add w5,w5,w10 + eor v25.16b,v7.16b,v4.16b + add w6,w6,w11 + eor v26.16b,v11.16b,v8.16b + add w7,w7,w12 + eor v27.16b,v15.16b,v12.16b + add w8,w8,w9 + eor v28.16b,v19.16b,v16.16b + eor w21,w21,w5 + eor v29.16b,v23.16b,v20.16b + eor w17,w17,w6 + ushr v3.4s,v24.4s,#24 + eor w19,w19,w7 + ushr v7.4s,v25.4s,#24 + eor w20,w20,w8 + ushr v11.4s,v26.4s,#24 + ror w21,w21,#16 + ushr v15.4s,v27.4s,#24 + ror w17,w17,#16 + ushr v19.4s,v28.4s,#24 + ror w19,w19,#16 + ushr v23.4s,v29.4s,#24 + ror w20,w20,#16 + sli v3.4s,v24.4s,#8 + add w15,w15,w21 + sli v7.4s,v25.4s,#8 + add w16,w16,w17 + sli v11.4s,v26.4s,#8 + add w13,w13,w19 + sli v15.4s,v27.4s,#8 + add w14,w14,w20 + sli v19.4s,v28.4s,#8 + eor w10,w10,w15 + sli v23.4s,v29.4s,#8 + eor w11,w11,w16 + add v2.4s,v2.4s,v3.4s + eor w12,w12,w13 + add v6.4s,v6.4s,v7.4s + eor w9,w9,w14 + add v10.4s,v10.4s,v11.4s + ror w10,w10,#20 + add v14.4s,v14.4s,v15.4s + ror w11,w11,#20 + add v18.4s,v18.4s,v19.4s + ror w12,w12,#20 + add v22.4s,v22.4s,v23.4s + ror w9,w9,#20 + eor v24.16b,v1.16b,v2.16b + add w5,w5,w10 + eor v25.16b,v5.16b,v6.16b + add w6,w6,w11 + eor v26.16b,v9.16b,v10.16b + add w7,w7,w12 + eor v27.16b,v13.16b,v14.16b + add w8,w8,w9 + eor v28.16b,v17.16b,v18.16b + eor w21,w21,w5 + eor v29.16b,v21.16b,v22.16b + eor w17,w17,w6 + ushr v1.4s,v24.4s,#25 + eor w19,w19,w7 + ushr v5.4s,v25.4s,#25 + eor w20,w20,w8 + ushr v9.4s,v26.4s,#25 + ror w21,w21,#24 + ushr v13.4s,v27.4s,#25 + ror w17,w17,#24 + ushr v17.4s,v28.4s,#25 + ror w19,w19,#24 + ushr v21.4s,v29.4s,#25 + ror w20,w20,#24 + sli v1.4s,v24.4s,#7 + add w15,w15,w21 + sli v5.4s,v25.4s,#7 + add w16,w16,w17 + sli v9.4s,v26.4s,#7 + add w13,w13,w19 + sli v13.4s,v27.4s,#7 + add w14,w14,w20 + sli v17.4s,v28.4s,#7 + eor w10,w10,w15 + sli v21.4s,v29.4s,#7 + eor w11,w11,w16 + ext v2.16b,v2.16b,v2.16b,#8 + eor w12,w12,w13 + ext v6.16b,v6.16b,v6.16b,#8 + eor w9,w9,w14 + ext v10.16b,v10.16b,v10.16b,#8 + ror w10,w10,#25 + ext v14.16b,v14.16b,v14.16b,#8 + ror w11,w11,#25 + ext v18.16b,v18.16b,v18.16b,#8 + ror w12,w12,#25 + ext v22.16b,v22.16b,v22.16b,#8 + ror w9,w9,#25 + ext v3.16b,v3.16b,v3.16b,#4 + ext v7.16b,v7.16b,v7.16b,#4 + ext v11.16b,v11.16b,v11.16b,#4 + ext v15.16b,v15.16b,v15.16b,#4 + ext v19.16b,v19.16b,v19.16b,#4 + ext v23.16b,v23.16b,v23.16b,#4 + ext v1.16b,v1.16b,v1.16b,#12 + ext v5.16b,v5.16b,v5.16b,#12 + ext v9.16b,v9.16b,v9.16b,#12 + ext v13.16b,v13.16b,v13.16b,#12 + ext v17.16b,v17.16b,v17.16b,#12 + ext v21.16b,v21.16b,v21.16b,#12 + cbnz x4,.Loop_upper_neon + + add w5,w5,w22 // accumulate key block + add x6,x6,x22,lsr#32 + add w7,w7,w23 + add x8,x8,x23,lsr#32 + add w9,w9,w24 + add x10,x10,x24,lsr#32 + add w11,w11,w25 + add x12,x12,x25,lsr#32 + add w13,w13,w26 + add x14,x14,x26,lsr#32 + add w15,w15,w27 + add x16,x16,x27,lsr#32 + add w17,w17,w28 + add x19,x19,x28,lsr#32 + add w20,w20,w30 + add x21,x21,x30,lsr#32 + + add x5,x5,x6,lsl#32 // pack + add x7,x7,x8,lsl#32 + ldp x6,x8,[x1,#0] // load input + add x9,x9,x10,lsl#32 + add x11,x11,x12,lsl#32 + ldp x10,x12,[x1,#16] + add x13,x13,x14,lsl#32 + add x15,x15,x16,lsl#32 + ldp x14,x16,[x1,#32] + add x17,x17,x19,lsl#32 + add x20,x20,x21,lsl#32 + ldp x19,x21,[x1,#48] + add x1,x1,#64 +#ifdef __ARMEB__ + rev x5,x5 + rev x7,x7 + rev x9,x9 + rev x11,x11 + rev x13,x13 + rev x15,x15 + rev x17,x17 + rev x20,x20 +#endif + eor x5,x5,x6 + eor x7,x7,x8 + eor x9,x9,x10 + eor x11,x11,x12 + eor x13,x13,x14 + eor x15,x15,x16 + eor x17,x17,x19 + eor x20,x20,x21 + + stp x5,x7,[x0,#0] // store output + add x28,x28,#1 // increment counter + mov w5,w22 // unpack key block + lsr x6,x22,#32 + stp x9,x11,[x0,#16] + mov w7,w23 + lsr x8,x23,#32 + stp x13,x15,[x0,#32] + mov w9,w24 + lsr x10,x24,#32 + stp x17,x20,[x0,#48] + add x0,x0,#64 + mov w11,w25 + lsr x12,x25,#32 + mov w13,w26 + lsr x14,x26,#32 + mov w15,w27 + lsr x16,x27,#32 + mov w17,w28 + lsr x19,x28,#32 + mov w20,w30 + lsr x21,x30,#32 + + mov x4,#5 +.Loop_lower_neon: + sub x4,x4,#1 + add v0.4s,v0.4s,v1.4s + add w5,w5,w9 + add v4.4s,v4.4s,v5.4s + add w6,w6,w10 + add v8.4s,v8.4s,v9.4s + add w7,w7,w11 + add v12.4s,v12.4s,v13.4s + add w8,w8,w12 + add v16.4s,v16.4s,v17.4s + eor w17,w17,w5 + add v20.4s,v20.4s,v21.4s + eor w19,w19,w6 + eor v3.16b,v3.16b,v0.16b + eor w20,w20,w7 + eor v7.16b,v7.16b,v4.16b + eor w21,w21,w8 + eor v11.16b,v11.16b,v8.16b + ror w17,w17,#16 + eor v15.16b,v15.16b,v12.16b + ror w19,w19,#16 + eor v19.16b,v19.16b,v16.16b + ror w20,w20,#16 + eor v23.16b,v23.16b,v20.16b + ror w21,w21,#16 + rev32 v3.8h,v3.8h + add w13,w13,w17 + rev32 v7.8h,v7.8h + add w14,w14,w19 + rev32 v11.8h,v11.8h + add w15,w15,w20 + rev32 v15.8h,v15.8h + add w16,w16,w21 + rev32 v19.8h,v19.8h + eor w9,w9,w13 + rev32 v23.8h,v23.8h + eor w10,w10,w14 + add v2.4s,v2.4s,v3.4s + eor w11,w11,w15 + add v6.4s,v6.4s,v7.4s + eor w12,w12,w16 + add v10.4s,v10.4s,v11.4s + ror w9,w9,#20 + add v14.4s,v14.4s,v15.4s + ror w10,w10,#20 + add v18.4s,v18.4s,v19.4s + ror w11,w11,#20 + add v22.4s,v22.4s,v23.4s + ror w12,w12,#20 + eor v24.16b,v1.16b,v2.16b + add w5,w5,w9 + eor v25.16b,v5.16b,v6.16b + add w6,w6,w10 + eor v26.16b,v9.16b,v10.16b + add w7,w7,w11 + eor v27.16b,v13.16b,v14.16b + add w8,w8,w12 + eor v28.16b,v17.16b,v18.16b + eor w17,w17,w5 + eor v29.16b,v21.16b,v22.16b + eor w19,w19,w6 + ushr v1.4s,v24.4s,#20 + eor w20,w20,w7 + ushr v5.4s,v25.4s,#20 + eor w21,w21,w8 + ushr v9.4s,v26.4s,#20 + ror w17,w17,#24 + ushr v13.4s,v27.4s,#20 + ror w19,w19,#24 + ushr v17.4s,v28.4s,#20 + ror w20,w20,#24 + ushr v21.4s,v29.4s,#20 + ror w21,w21,#24 + sli v1.4s,v24.4s,#12 + add w13,w13,w17 + sli v5.4s,v25.4s,#12 + add w14,w14,w19 + sli v9.4s,v26.4s,#12 + add w15,w15,w20 + sli v13.4s,v27.4s,#12 + add w16,w16,w21 + sli v17.4s,v28.4s,#12 + eor w9,w9,w13 + sli v21.4s,v29.4s,#12 + eor w10,w10,w14 + add v0.4s,v0.4s,v1.4s + eor w11,w11,w15 + add v4.4s,v4.4s,v5.4s + eor w12,w12,w16 + add v8.4s,v8.4s,v9.4s + ror w9,w9,#25 + add v12.4s,v12.4s,v13.4s + ror w10,w10,#25 + add v16.4s,v16.4s,v17.4s + ror w11,w11,#25 + add v20.4s,v20.4s,v21.4s + ror w12,w12,#25 + eor v24.16b,v3.16b,v0.16b + add w5,w5,w10 + eor v25.16b,v7.16b,v4.16b + add w6,w6,w11 + eor v26.16b,v11.16b,v8.16b + add w7,w7,w12 + eor v27.16b,v15.16b,v12.16b + add w8,w8,w9 + eor v28.16b,v19.16b,v16.16b + eor w21,w21,w5 + eor v29.16b,v23.16b,v20.16b + eor w17,w17,w6 + ushr v3.4s,v24.4s,#24 + eor w19,w19,w7 + ushr v7.4s,v25.4s,#24 + eor w20,w20,w8 + ushr v11.4s,v26.4s,#24 + ror w21,w21,#16 + ushr v15.4s,v27.4s,#24 + ror w17,w17,#16 + ushr v19.4s,v28.4s,#24 + ror w19,w19,#16 + ushr v23.4s,v29.4s,#24 + ror w20,w20,#16 + sli v3.4s,v24.4s,#8 + add w15,w15,w21 + sli v7.4s,v25.4s,#8 + add w16,w16,w17 + sli v11.4s,v26.4s,#8 + add w13,w13,w19 + sli v15.4s,v27.4s,#8 + add w14,w14,w20 + sli v19.4s,v28.4s,#8 + eor w10,w10,w15 + sli v23.4s,v29.4s,#8 + eor w11,w11,w16 + add v2.4s,v2.4s,v3.4s + eor w12,w12,w13 + add v6.4s,v6.4s,v7.4s + eor w9,w9,w14 + add v10.4s,v10.4s,v11.4s + ror w10,w10,#20 + add v14.4s,v14.4s,v15.4s + ror w11,w11,#20 + add v18.4s,v18.4s,v19.4s + ror w12,w12,#20 + add v22.4s,v22.4s,v23.4s + ror w9,w9,#20 + eor v24.16b,v1.16b,v2.16b + add w5,w5,w10 + eor v25.16b,v5.16b,v6.16b + add w6,w6,w11 + eor v26.16b,v9.16b,v10.16b + add w7,w7,w12 + eor v27.16b,v13.16b,v14.16b + add w8,w8,w9 + eor v28.16b,v17.16b,v18.16b + eor w21,w21,w5 + eor v29.16b,v21.16b,v22.16b + eor w17,w17,w6 + ushr v1.4s,v24.4s,#25 + eor w19,w19,w7 + ushr v5.4s,v25.4s,#25 + eor w20,w20,w8 + ushr v9.4s,v26.4s,#25 + ror w21,w21,#24 + ushr v13.4s,v27.4s,#25 + ror w17,w17,#24 + ushr v17.4s,v28.4s,#25 + ror w19,w19,#24 + ushr v21.4s,v29.4s,#25 + ror w20,w20,#24 + sli v1.4s,v24.4s,#7 + add w15,w15,w21 + sli v5.4s,v25.4s,#7 + add w16,w16,w17 + sli v9.4s,v26.4s,#7 + add w13,w13,w19 + sli v13.4s,v27.4s,#7 + add w14,w14,w20 + sli v17.4s,v28.4s,#7 + eor w10,w10,w15 + sli v21.4s,v29.4s,#7 + eor w11,w11,w16 + ext v2.16b,v2.16b,v2.16b,#8 + eor w12,w12,w13 + ext v6.16b,v6.16b,v6.16b,#8 + eor w9,w9,w14 + ext v10.16b,v10.16b,v10.16b,#8 + ror w10,w10,#25 + ext v14.16b,v14.16b,v14.16b,#8 + ror w11,w11,#25 + ext v18.16b,v18.16b,v18.16b,#8 + ror w12,w12,#25 + ext v22.16b,v22.16b,v22.16b,#8 + ror w9,w9,#25 + ext v3.16b,v3.16b,v3.16b,#12 + ext v7.16b,v7.16b,v7.16b,#12 + ext v11.16b,v11.16b,v11.16b,#12 + ext v15.16b,v15.16b,v15.16b,#12 + ext v19.16b,v19.16b,v19.16b,#12 + ext v23.16b,v23.16b,v23.16b,#12 + ext v1.16b,v1.16b,v1.16b,#4 + ext v5.16b,v5.16b,v5.16b,#4 + ext v9.16b,v9.16b,v9.16b,#4 + ext v13.16b,v13.16b,v13.16b,#4 + ext v17.16b,v17.16b,v17.16b,#4 + ext v21.16b,v21.16b,v21.16b,#4 + add v0.4s,v0.4s,v1.4s + add w5,w5,w9 + add v4.4s,v4.4s,v5.4s + add w6,w6,w10 + add v8.4s,v8.4s,v9.4s + add w7,w7,w11 + add v12.4s,v12.4s,v13.4s + add w8,w8,w12 + add v16.4s,v16.4s,v17.4s + eor w17,w17,w5 + add v20.4s,v20.4s,v21.4s + eor w19,w19,w6 + eor v3.16b,v3.16b,v0.16b + eor w20,w20,w7 + eor v7.16b,v7.16b,v4.16b + eor w21,w21,w8 + eor v11.16b,v11.16b,v8.16b + ror w17,w17,#16 + eor v15.16b,v15.16b,v12.16b + ror w19,w19,#16 + eor v19.16b,v19.16b,v16.16b + ror w20,w20,#16 + eor v23.16b,v23.16b,v20.16b + ror w21,w21,#16 + rev32 v3.8h,v3.8h + add w13,w13,w17 + rev32 v7.8h,v7.8h + add w14,w14,w19 + rev32 v11.8h,v11.8h + add w15,w15,w20 + rev32 v15.8h,v15.8h + add w16,w16,w21 + rev32 v19.8h,v19.8h + eor w9,w9,w13 + rev32 v23.8h,v23.8h + eor w10,w10,w14 + add v2.4s,v2.4s,v3.4s + eor w11,w11,w15 + add v6.4s,v6.4s,v7.4s + eor w12,w12,w16 + add v10.4s,v10.4s,v11.4s + ror w9,w9,#20 + add v14.4s,v14.4s,v15.4s + ror w10,w10,#20 + add v18.4s,v18.4s,v19.4s + ror w11,w11,#20 + add v22.4s,v22.4s,v23.4s + ror w12,w12,#20 + eor v24.16b,v1.16b,v2.16b + add w5,w5,w9 + eor v25.16b,v5.16b,v6.16b + add w6,w6,w10 + eor v26.16b,v9.16b,v10.16b + add w7,w7,w11 + eor v27.16b,v13.16b,v14.16b + add w8,w8,w12 + eor v28.16b,v17.16b,v18.16b + eor w17,w17,w5 + eor v29.16b,v21.16b,v22.16b + eor w19,w19,w6 + ushr v1.4s,v24.4s,#20 + eor w20,w20,w7 + ushr v5.4s,v25.4s,#20 + eor w21,w21,w8 + ushr v9.4s,v26.4s,#20 + ror w17,w17,#24 + ushr v13.4s,v27.4s,#20 + ror w19,w19,#24 + ushr v17.4s,v28.4s,#20 + ror w20,w20,#24 + ushr v21.4s,v29.4s,#20 + ror w21,w21,#24 + sli v1.4s,v24.4s,#12 + add w13,w13,w17 + sli v5.4s,v25.4s,#12 + add w14,w14,w19 + sli v9.4s,v26.4s,#12 + add w15,w15,w20 + sli v13.4s,v27.4s,#12 + add w16,w16,w21 + sli v17.4s,v28.4s,#12 + eor w9,w9,w13 + sli v21.4s,v29.4s,#12 + eor w10,w10,w14 + add v0.4s,v0.4s,v1.4s + eor w11,w11,w15 + add v4.4s,v4.4s,v5.4s + eor w12,w12,w16 + add v8.4s,v8.4s,v9.4s + ror w9,w9,#25 + add v12.4s,v12.4s,v13.4s + ror w10,w10,#25 + add v16.4s,v16.4s,v17.4s + ror w11,w11,#25 + add v20.4s,v20.4s,v21.4s + ror w12,w12,#25 + eor v24.16b,v3.16b,v0.16b + add w5,w5,w10 + eor v25.16b,v7.16b,v4.16b + add w6,w6,w11 + eor v26.16b,v11.16b,v8.16b + add w7,w7,w12 + eor v27.16b,v15.16b,v12.16b + add w8,w8,w9 + eor v28.16b,v19.16b,v16.16b + eor w21,w21,w5 + eor v29.16b,v23.16b,v20.16b + eor w17,w17,w6 + ushr v3.4s,v24.4s,#24 + eor w19,w19,w7 + ushr v7.4s,v25.4s,#24 + eor w20,w20,w8 + ushr v11.4s,v26.4s,#24 + ror w21,w21,#16 + ushr v15.4s,v27.4s,#24 + ror w17,w17,#16 + ushr v19.4s,v28.4s,#24 + ror w19,w19,#16 + ushr v23.4s,v29.4s,#24 + ror w20,w20,#16 + sli v3.4s,v24.4s,#8 + add w15,w15,w21 + sli v7.4s,v25.4s,#8 + add w16,w16,w17 + sli v11.4s,v26.4s,#8 + add w13,w13,w19 + sli v15.4s,v27.4s,#8 + add w14,w14,w20 + sli v19.4s,v28.4s,#8 + eor w10,w10,w15 + sli v23.4s,v29.4s,#8 + eor w11,w11,w16 + add v2.4s,v2.4s,v3.4s + eor w12,w12,w13 + add v6.4s,v6.4s,v7.4s + eor w9,w9,w14 + add v10.4s,v10.4s,v11.4s + ror w10,w10,#20 + add v14.4s,v14.4s,v15.4s + ror w11,w11,#20 + add v18.4s,v18.4s,v19.4s + ror w12,w12,#20 + add v22.4s,v22.4s,v23.4s + ror w9,w9,#20 + eor v24.16b,v1.16b,v2.16b + add w5,w5,w10 + eor v25.16b,v5.16b,v6.16b + add w6,w6,w11 + eor v26.16b,v9.16b,v10.16b + add w7,w7,w12 + eor v27.16b,v13.16b,v14.16b + add w8,w8,w9 + eor v28.16b,v17.16b,v18.16b + eor w21,w21,w5 + eor v29.16b,v21.16b,v22.16b + eor w17,w17,w6 + ushr v1.4s,v24.4s,#25 + eor w19,w19,w7 + ushr v5.4s,v25.4s,#25 + eor w20,w20,w8 + ushr v9.4s,v26.4s,#25 + ror w21,w21,#24 + ushr v13.4s,v27.4s,#25 + ror w17,w17,#24 + ushr v17.4s,v28.4s,#25 + ror w19,w19,#24 + ushr v21.4s,v29.4s,#25 + ror w20,w20,#24 + sli v1.4s,v24.4s,#7 + add w15,w15,w21 + sli v5.4s,v25.4s,#7 + add w16,w16,w17 + sli v9.4s,v26.4s,#7 + add w13,w13,w19 + sli v13.4s,v27.4s,#7 + add w14,w14,w20 + sli v17.4s,v28.4s,#7 + eor w10,w10,w15 + sli v21.4s,v29.4s,#7 + eor w11,w11,w16 + ext v2.16b,v2.16b,v2.16b,#8 + eor w12,w12,w13 + ext v6.16b,v6.16b,v6.16b,#8 + eor w9,w9,w14 + ext v10.16b,v10.16b,v10.16b,#8 + ror w10,w10,#25 + ext v14.16b,v14.16b,v14.16b,#8 + ror w11,w11,#25 + ext v18.16b,v18.16b,v18.16b,#8 + ror w12,w12,#25 + ext v22.16b,v22.16b,v22.16b,#8 + ror w9,w9,#25 + ext v3.16b,v3.16b,v3.16b,#4 + ext v7.16b,v7.16b,v7.16b,#4 + ext v11.16b,v11.16b,v11.16b,#4 + ext v15.16b,v15.16b,v15.16b,#4 + ext v19.16b,v19.16b,v19.16b,#4 + ext v23.16b,v23.16b,v23.16b,#4 + ext v1.16b,v1.16b,v1.16b,#12 + ext v5.16b,v5.16b,v5.16b,#12 + ext v9.16b,v9.16b,v9.16b,#12 + ext v13.16b,v13.16b,v13.16b,#12 + ext v17.16b,v17.16b,v17.16b,#12 + ext v21.16b,v21.16b,v21.16b,#12 + cbnz x4,.Loop_lower_neon + + add w5,w5,w22 // accumulate key block + ldp q24,q25,[sp,#0] + add x6,x6,x22,lsr#32 + ldp q26,q27,[sp,#32] + add w7,w7,w23 + ldp q28,q29,[sp,#64] + add x8,x8,x23,lsr#32 + add v0.4s,v0.4s,v24.4s + add w9,w9,w24 + add v4.4s,v4.4s,v24.4s + add x10,x10,x24,lsr#32 + add v8.4s,v8.4s,v24.4s + add w11,w11,w25 + add v12.4s,v12.4s,v24.4s + add x12,x12,x25,lsr#32 + add v16.4s,v16.4s,v24.4s + add w13,w13,w26 + add v20.4s,v20.4s,v24.4s + add x14,x14,x26,lsr#32 + add v2.4s,v2.4s,v26.4s + add w15,w15,w27 + add v6.4s,v6.4s,v26.4s + add x16,x16,x27,lsr#32 + add v10.4s,v10.4s,v26.4s + add w17,w17,w28 + add v14.4s,v14.4s,v26.4s + add x19,x19,x28,lsr#32 + add v18.4s,v18.4s,v26.4s + add w20,w20,w30 + add v22.4s,v22.4s,v26.4s + add x21,x21,x30,lsr#32 + add v19.4s,v19.4s,v31.4s // +4 + add x5,x5,x6,lsl#32 // pack + add v23.4s,v23.4s,v31.4s // +4 + add x7,x7,x8,lsl#32 + add v3.4s,v3.4s,v27.4s + ldp x6,x8,[x1,#0] // load input + add v7.4s,v7.4s,v28.4s + add x9,x9,x10,lsl#32 + add v11.4s,v11.4s,v29.4s + add x11,x11,x12,lsl#32 + add v15.4s,v15.4s,v30.4s + ldp x10,x12,[x1,#16] + add v19.4s,v19.4s,v27.4s + add x13,x13,x14,lsl#32 + add v23.4s,v23.4s,v28.4s + add x15,x15,x16,lsl#32 + add v1.4s,v1.4s,v25.4s + ldp x14,x16,[x1,#32] + add v5.4s,v5.4s,v25.4s + add x17,x17,x19,lsl#32 + add v9.4s,v9.4s,v25.4s + add x20,x20,x21,lsl#32 + add v13.4s,v13.4s,v25.4s + ldp x19,x21,[x1,#48] + add v17.4s,v17.4s,v25.4s + add x1,x1,#64 + add v21.4s,v21.4s,v25.4s + +#ifdef __ARMEB__ + rev x5,x5 + rev x7,x7 + rev x9,x9 + rev x11,x11 + rev x13,x13 + rev x15,x15 + rev x17,x17 + rev x20,x20 +#endif + ld1 {v24.16b,v25.16b,v26.16b,v27.16b},[x1],#64 + eor x5,x5,x6 + eor x7,x7,x8 + eor x9,x9,x10 + eor x11,x11,x12 + eor x13,x13,x14 + eor v0.16b,v0.16b,v24.16b + eor x15,x15,x16 + eor v1.16b,v1.16b,v25.16b + eor x17,x17,x19 + eor v2.16b,v2.16b,v26.16b + eor x20,x20,x21 + eor v3.16b,v3.16b,v27.16b + ld1 {v24.16b,v25.16b,v26.16b,v27.16b},[x1],#64 + + stp x5,x7,[x0,#0] // store output + add x28,x28,#7 // increment counter + stp x9,x11,[x0,#16] + stp x13,x15,[x0,#32] + stp x17,x20,[x0,#48] + add x0,x0,#64 + st1 {v0.16b,v1.16b,v2.16b,v3.16b},[x0],#64 + + ld1 {v0.16b,v1.16b,v2.16b,v3.16b},[x1],#64 + eor v4.16b,v4.16b,v24.16b + eor v5.16b,v5.16b,v25.16b + eor v6.16b,v6.16b,v26.16b + eor v7.16b,v7.16b,v27.16b + st1 {v4.16b,v5.16b,v6.16b,v7.16b},[x0],#64 + + ld1 {v4.16b,v5.16b,v6.16b,v7.16b},[x1],#64 + eor v8.16b,v8.16b,v0.16b + ldp q24,q25,[sp,#0] + eor v9.16b,v9.16b,v1.16b + ldp q26,q27,[sp,#32] + eor v10.16b,v10.16b,v2.16b + eor v11.16b,v11.16b,v3.16b + st1 {v8.16b,v9.16b,v10.16b,v11.16b},[x0],#64 + + ld1 {v8.16b,v9.16b,v10.16b,v11.16b},[x1],#64 + eor v12.16b,v12.16b,v4.16b + eor v13.16b,v13.16b,v5.16b + eor v14.16b,v14.16b,v6.16b + eor v15.16b,v15.16b,v7.16b + st1 {v12.16b,v13.16b,v14.16b,v15.16b},[x0],#64 + + ld1 {v12.16b,v13.16b,v14.16b,v15.16b},[x1],#64 + eor v16.16b,v16.16b,v8.16b + eor v17.16b,v17.16b,v9.16b + eor v18.16b,v18.16b,v10.16b + eor v19.16b,v19.16b,v11.16b + st1 {v16.16b,v17.16b,v18.16b,v19.16b},[x0],#64 + + shl v0.4s,v31.4s,#1 // 4 -> 8 + eor v20.16b,v20.16b,v12.16b + eor v21.16b,v21.16b,v13.16b + eor v22.16b,v22.16b,v14.16b + eor v23.16b,v23.16b,v15.16b + st1 {v20.16b,v21.16b,v22.16b,v23.16b},[x0],#64 + + add v27.4s,v27.4s,v0.4s // += 8 + add v28.4s,v28.4s,v0.4s + add v29.4s,v29.4s,v0.4s + add v30.4s,v30.4s,v0.4s + + b.hs .Loop_outer_512_neon + + adds x2,x2,#512 + ushr v0.4s,v31.4s,#2 // 4 -> 1 + + ldp d8,d9,[sp,#128+0] // meet ABI requirements + ldp d10,d11,[sp,#128+16] + ldp d12,d13,[sp,#128+32] + ldp d14,d15,[sp,#128+48] + + stp q24,q31,[sp,#0] // wipe off-load area + stp q24,q31,[sp,#32] + stp q24,q31,[sp,#64] + + b.eq .Ldone_512_neon + + cmp x2,#192 + sub v27.4s,v27.4s,v0.4s // -= 1 + sub v28.4s,v28.4s,v0.4s + sub v29.4s,v29.4s,v0.4s + add sp,sp,#128 + b.hs .Loop_outer_neon + + eor v25.16b,v25.16b,v25.16b + eor v26.16b,v26.16b,v26.16b + eor v27.16b,v27.16b,v27.16b + eor v28.16b,v28.16b,v28.16b + eor v29.16b,v29.16b,v29.16b + eor v30.16b,v30.16b,v30.16b + b .Loop_outer + +.Ldone_512_neon: + ldp x19,x20,[x29,#16] + add sp,sp,#128+64 + ldp x21,x22,[x29,#32] + ldp x23,x24,[x29,#48] + ldp x25,x26,[x29,#64] + ldp x27,x28,[x29,#80] + ldp x29,x30,[sp],#96 +.Labort_neon: + ret +ENDPROC(chacha20_neon) diff --git a/crypto/chacha20poly1305.cpp b/crypto/chacha20poly1305.cpp index a5c222d..6ec5a75 100644 --- a/crypto/chacha20poly1305.cpp +++ b/crypto/chacha20poly1305.cpp @@ -43,8 +43,22 @@ void _cdecl poly1305_emit_avx(void *ctx, uint8 mac[16], const uint32 nonce[4]); void _cdecl poly1305_blocks_avx(void *ctx, const uint8 *inp, size_t len, uint32 padbit); void _cdecl poly1305_blocks_avx2(void *ctx, const uint8 *inp, size_t len, uint32 padbit); void _cdecl poly1305_blocks_avx512(void *ctx, const uint8 *inp, size_t len, uint32 padbit); + +#if defined(ARCH_CPU_ARM_FAMILY) +void chacha20_arm(uint8 *out, const uint8 *in, size_t len, const uint32 key[8], const uint32 counter[4]); +void chacha20_neon(uint8 *out, const uint8 *in, size_t len, const uint32 key[8], const uint32 counter[4]); +#endif +void poly1305_init_arm(void *ctx, const uint8 key[16]); +void poly1305_blocks_arm(void *ctx, const uint8 *inp, size_t len, uint32 padbit); +void poly1305_emit_arm(void *ctx, uint8 mac[16], const uint32 nonce[4]); +void poly1305_blocks_neon(void *ctx, const uint8 *inp, size_t len, uint32 padbit); +void poly1305_emit_neon(void *ctx, uint8 mac[16], const uint32 nonce[4]); + } + + + struct chacha20_ctx { uint32 state[CHACHA20_BLOCK_SIZE / sizeof(uint32)]; }; @@ -193,6 +207,17 @@ SAFEBUFFERS static void chacha20_crypt(struct chacha20_ctx *ctx, uint8 *dst, con } #endif // defined(ARCH_CPU_X86_64) +#if defined(ARCH_CPU_ARM_FAMILY) + if (ARM_PCAP_NEON) { + chacha20_neon(dst, src, bytes, &ctx->state[4], &ctx->state[12]); + } else { + chacha20_arm(dst, src, bytes, &ctx->state[4], &ctx->state[12]); + } + ctx->state[12] += (bytes + 63) / 64; + return; +#endif // defined(ARCH_CPU_ARM_FAMILY) + + if (dst != src) memcpy(dst, src, bytes); @@ -385,7 +410,7 @@ SAFEBUFFERS static void poly1305_init(struct poly1305_ctx *ctx, const uint8 key[ #if defined(ARCH_CPU_X86_64) poly1305_init_x86_64(ctx->opaque, key); -#elif defined(CONFIG_ARM) || defined(CONFIG_ARM64) +#elif defined(ARCH_CPU_ARM_FAMILY) poly1305_init_arm(ctx->opaque, key); #elif defined(CONFIG_MIPS) && defined(CONFIG_64BIT) poly1305_init_mips(ctx->opaque, key); @@ -409,7 +434,12 @@ static inline void poly1305_blocks(void *ctx, const uint8 *inp, size_t len, uint poly1305_blocks_avx(ctx, inp, len, padbit); else poly1305_blocks_x86_64(ctx, inp, len, padbit); -#else // defined(ARCH_CPU_X86_64) +#elif defined(ARCH_CPU_ARM_FAMILY) + if (ARM_PCAP_NEON) + poly1305_blocks_neon(ctx, inp, len, padbit); + else + poly1305_blocks_arm(ctx, inp, len, padbit); +#else poly1305_blocks_generic(ctx, inp, len, padbit); #endif // defined(ARCH_CPU_X86_64) } @@ -421,6 +451,11 @@ static inline void poly1305_emit(void *ctx, uint8 mac[16], const uint32 nonce[4] poly1305_emit_avx(ctx, mac, nonce); else poly1305_emit_x86_64(ctx, mac, nonce); +#elif defined(ARCH_CPU_ARM_FAMILY) + if (ARM_PCAP_NEON) + poly1305_emit_neon(ctx, mac, nonce); + else + poly1305_emit_arm(ctx, mac, nonce); #else // defined(ARCH_CPU_X86_64) poly1305_emit_generic(ctx, mac, nonce); #endif // defined(ARCH_CPU_X86_64) diff --git a/crypto/curve25519-donna.h b/crypto/curve25519-donna.h index 6985273..93380b3 100644 --- a/crypto/curve25519-donna.h +++ b/crypto/curve25519-donna.h @@ -1,17 +1,17 @@ -#ifndef TUNSAFE_CRYPTO_CURVE25519_DONNA_H_ -#define TUNSAFE_CRYPTO_CURVE25519_DONNA_H_ - -#include "tunsafe_types.h" - -void curve25519_donna_ref(uint8 *mypublic, const uint8 *secret, const uint8 *basepoint); -extern "C" void curve25519_donna_x64(uint8 *mypublic, const uint8 *secret, const uint8 *basepoint); - -#if defined(ARCH_CPU_X86_64) && defined(COMPILER_MSVC) -#define curve25519_donna curve25519_donna_x64 -#else -#define curve25519_donna curve25519_donna_ref -#endif - -void curve25519_normalize(uint8 *e); - +#ifndef TUNSAFE_CRYPTO_CURVE25519_DONNA_H_ +#define TUNSAFE_CRYPTO_CURVE25519_DONNA_H_ + +#include "tunsafe_types.h" + +void curve25519_donna_ref(uint8 *mypublic, const uint8 *secret, const uint8 *basepoint); +extern "C" void curve25519_donna_x64(uint8 *mypublic, const uint8 *secret, const uint8 *basepoint); + +#if defined(ARCH_CPU_X86_64) && defined(COMPILER_MSVC) +#define curve25519_donna curve25519_donna_x64 +#else +#define curve25519_donna curve25519_donna_ref +#endif + +void curve25519_normalize(uint8 *e); + #endif // TUNSAFE_CRYPTO_CURVE25519_DONNA_H_ \ No newline at end of file diff --git a/crypto/make_all_asm_files.sh b/crypto/make_all_asm_files.sh old mode 100644 new mode 100755 diff --git a/crypto/make_poly1305_x64.pl b/crypto/make_poly1305_x64.pl old mode 100644 new mode 100755 diff --git a/crypto/poly1305/poly1305-arm.pl b/crypto/poly1305/poly1305-arm.pl new file mode 100644 index 0000000..5cdb6be --- /dev/null +++ b/crypto/poly1305/poly1305-arm.pl @@ -0,0 +1,1253 @@ +#! /usr/bin/env perl +# Copyright 2016-2018 The OpenSSL Project Authors. All Rights Reserved. +# +# Licensed under the OpenSSL license (the "License"). You may not use +# this file except in compliance with the License. You can obtain a copy +# in the file LICENSE in the source distribution or at +# https://www.openssl.org/source/license.html + +# +# ==================================================================== +# Written by Andy Polyakov for the OpenSSL +# project. The module is, however, dual licensed under OpenSSL and +# CRYPTOGAMS licenses depending on where you obtain it. For further +# details see http://www.openssl.org/~appro/cryptogams/. +# ==================================================================== +# +# IALU(*)/gcc-4.4 NEON +# +# ARM11xx(ARMv6) 7.78/+100% - +# Cortex-A5 6.35/+130% 3.00 +# Cortex-A8 6.25/+115% 2.36 +# Cortex-A9 5.10/+95% 2.55 +# Cortex-A15 3.85/+85% 1.25(**) +# Snapdragon S4 5.70/+100% 1.48(**) +# +# (*) this is for -march=armv6, i.e. with bunch of ldrb loading data; +# (**) these are trade-off results, they can be improved by ~8% but at +# the cost of 15/12% regression on Cortex-A5/A7, it's even possible +# to improve Cortex-A9 result, but then A5/A7 loose more than 20%; + +$flavour = shift; +if ($flavour=~/\w[\w\-]*\.\w+$/) { $output=$flavour; undef $flavour; } +else { while (($output=shift) && ($output!~/\w[\w\-]*\.\w+$/)) {} } + +if ($flavour && $flavour ne "void") { + $0 =~ m/(.*[\/\\])[^\/\\]+$/; $dir=$1; + ( $xlate="${dir}arm-xlate.pl" and -f $xlate ) or + ( $xlate="${dir}../../perlasm/arm-xlate.pl" and -f $xlate) or + die "can't locate arm-xlate.pl"; + + open STDOUT,"| \"$^X\" $xlate $flavour $output"; +} else { + open STDOUT,">$output"; +} + +($ctx,$inp,$len,$padbit)=map("r$_",(0..3)); + +$code.=<<___; +#include "arm_arch.h" + +.text +#if defined(__thumb2__) +.syntax unified +.thumb +#else +.code 32 +#endif + +.globl poly1305_emit +.globl poly1305_blocks +.globl poly1305_init +.type poly1305_init,%function +.align 5 +poly1305_init: +.Lpoly1305_init: + stmdb sp!,{r4-r11} + + eor r3,r3,r3 + cmp $inp,#0 + str r3,[$ctx,#0] @ zero hash value + str r3,[$ctx,#4] + str r3,[$ctx,#8] + str r3,[$ctx,#12] + str r3,[$ctx,#16] + str r3,[$ctx,#36] @ is_base2_26 + add $ctx,$ctx,#20 + +#ifdef __thumb2__ + it eq +#endif + moveq r0,#0 + beq .Lno_key + +#if __ARM_MAX_ARCH__>=7 + adr r11,.Lpoly1305_init + ldr r12,.LOPENSSL_armcap +#endif + ldrb r4,[$inp,#0] + mov r10,#0x0fffffff + ldrb r5,[$inp,#1] + and r3,r10,#-4 @ 0x0ffffffc + ldrb r6,[$inp,#2] + ldrb r7,[$inp,#3] + orr r4,r4,r5,lsl#8 + ldrb r5,[$inp,#4] + orr r4,r4,r6,lsl#16 + ldrb r6,[$inp,#5] + orr r4,r4,r7,lsl#24 + ldrb r7,[$inp,#6] + and r4,r4,r10 + +#if __ARM_MAX_ARCH__>=7 + ldr r12,[r11,r12] @ OPENSSL_armcap_P +# ifdef __APPLE__ + ldr r12,[r12] +# endif +#endif + ldrb r8,[$inp,#7] + orr r5,r5,r6,lsl#8 + ldrb r6,[$inp,#8] + orr r5,r5,r7,lsl#16 + ldrb r7,[$inp,#9] + orr r5,r5,r8,lsl#24 + ldrb r8,[$inp,#10] + and r5,r5,r3 + +#if __ARM_MAX_ARCH__>=7 + tst r12,#ARMV7_NEON @ check for NEON +# ifdef __APPLE__ + adr r9,poly1305_blocks_neon + adr r11,poly1305_blocks +# ifdef __thumb2__ + it ne +# endif + movne r11,r9 + adr r12,poly1305_emit + adr r10,poly1305_emit_neon +# ifdef __thumb2__ + it ne +# endif + movne r12,r10 +# else +# ifdef __thumb2__ + itete eq +# endif + addeq r12,r11,#(poly1305_emit-.Lpoly1305_init) + addne r12,r11,#(poly1305_emit_neon-.Lpoly1305_init) + addeq r11,r11,#(poly1305_blocks-.Lpoly1305_init) + addne r11,r11,#(poly1305_blocks_neon-.Lpoly1305_init) +# endif +# ifdef __thumb2__ + orr r12,r12,#1 @ thumb-ify address + orr r11,r11,#1 +# endif +#endif + ldrb r9,[$inp,#11] + orr r6,r6,r7,lsl#8 + ldrb r7,[$inp,#12] + orr r6,r6,r8,lsl#16 + ldrb r8,[$inp,#13] + orr r6,r6,r9,lsl#24 + ldrb r9,[$inp,#14] + and r6,r6,r3 + + ldrb r10,[$inp,#15] + orr r7,r7,r8,lsl#8 + str r4,[$ctx,#0] + orr r7,r7,r9,lsl#16 + str r5,[$ctx,#4] + orr r7,r7,r10,lsl#24 + str r6,[$ctx,#8] + and r7,r7,r3 + str r7,[$ctx,#12] +#if __ARM_MAX_ARCH__>=7 + stmia r2,{r11,r12} @ fill functions table + mov r0,#1 +#else + mov r0,#0 +#endif +.Lno_key: + ldmia sp!,{r4-r11} +#if __ARM_ARCH__>=5 + ret @ bx lr +#else + tst lr,#1 + moveq pc,lr @ be binary compatible with V4, yet + bx lr @ interoperable with Thumb ISA:-) +#endif +.size poly1305_init,.-poly1305_init +___ +{ +my ($h0,$h1,$h2,$h3,$h4,$r0,$r1,$r2,$r3)=map("r$_",(4..12)); +my ($s1,$s2,$s3)=($r1,$r2,$r3); + +$code.=<<___; +.type poly1305_blocks,%function +.align 5 +poly1305_blocks: +.Lpoly1305_blocks: + stmdb sp!,{r3-r11,lr} + + ands $len,$len,#-16 + beq .Lno_data + + cmp $padbit,#0 + add $len,$len,$inp @ end pointer + sub sp,sp,#32 + + ldmia $ctx,{$h0-$r3} @ load context + + str $ctx,[sp,#12] @ offload stuff + mov lr,$inp + str $len,[sp,#16] + str $r1,[sp,#20] + str $r2,[sp,#24] + str $r3,[sp,#28] + b .Loop + +.Loop: +#if __ARM_ARCH__<7 + ldrb r0,[lr],#16 @ load input +# ifdef __thumb2__ + it hi +# endif + addhi $h4,$h4,#1 @ 1<<128 + ldrb r1,[lr,#-15] + ldrb r2,[lr,#-14] + ldrb r3,[lr,#-13] + orr r1,r0,r1,lsl#8 + ldrb r0,[lr,#-12] + orr r2,r1,r2,lsl#16 + ldrb r1,[lr,#-11] + orr r3,r2,r3,lsl#24 + ldrb r2,[lr,#-10] + adds $h0,$h0,r3 @ accumulate input + + ldrb r3,[lr,#-9] + orr r1,r0,r1,lsl#8 + ldrb r0,[lr,#-8] + orr r2,r1,r2,lsl#16 + ldrb r1,[lr,#-7] + orr r3,r2,r3,lsl#24 + ldrb r2,[lr,#-6] + adcs $h1,$h1,r3 + + ldrb r3,[lr,#-5] + orr r1,r0,r1,lsl#8 + ldrb r0,[lr,#-4] + orr r2,r1,r2,lsl#16 + ldrb r1,[lr,#-3] + orr r3,r2,r3,lsl#24 + ldrb r2,[lr,#-2] + adcs $h2,$h2,r3 + + ldrb r3,[lr,#-1] + orr r1,r0,r1,lsl#8 + str lr,[sp,#8] @ offload input pointer + orr r2,r1,r2,lsl#16 + add $s1,$r1,$r1,lsr#2 + orr r3,r2,r3,lsl#24 +#else + ldr r0,[lr],#16 @ load input +# ifdef __thumb2__ + it hi +# endif + addhi $h4,$h4,#1 @ padbit + ldr r1,[lr,#-12] + ldr r2,[lr,#-8] + ldr r3,[lr,#-4] +# ifdef __ARMEB__ + rev r0,r0 + rev r1,r1 + rev r2,r2 + rev r3,r3 +# endif + adds $h0,$h0,r0 @ accumulate input + str lr,[sp,#8] @ offload input pointer + adcs $h1,$h1,r1 + add $s1,$r1,$r1,lsr#2 + adcs $h2,$h2,r2 +#endif + add $s2,$r2,$r2,lsr#2 + adcs $h3,$h3,r3 + add $s3,$r3,$r3,lsr#2 + + umull r2,r3,$h1,$r0 + adc $h4,$h4,#0 + umull r0,r1,$h0,$r0 + umlal r2,r3,$h4,$s1 + umlal r0,r1,$h3,$s1 + ldr $r1,[sp,#20] @ reload $r1 + umlal r2,r3,$h2,$s3 + umlal r0,r1,$h1,$s3 + umlal r2,r3,$h3,$s2 + umlal r0,r1,$h2,$s2 + umlal r2,r3,$h0,$r1 + str r0,[sp,#0] @ future $h0 + mul r0,$s2,$h4 + ldr $r2,[sp,#24] @ reload $r2 + adds r2,r2,r1 @ d1+=d0>>32 + eor r1,r1,r1 + adc lr,r3,#0 @ future $h2 + str r2,[sp,#4] @ future $h1 + + mul r2,$s3,$h4 + eor r3,r3,r3 + umlal r0,r1,$h3,$s3 + ldr $r3,[sp,#28] @ reload $r3 + umlal r2,r3,$h3,$r0 + umlal r0,r1,$h2,$r0 + umlal r2,r3,$h2,$r1 + umlal r0,r1,$h1,$r1 + umlal r2,r3,$h1,$r2 + umlal r0,r1,$h0,$r2 + umlal r2,r3,$h0,$r3 + ldr $h0,[sp,#0] + mul $h4,$r0,$h4 + ldr $h1,[sp,#4] + + adds $h2,lr,r0 @ d2+=d1>>32 + ldr lr,[sp,#8] @ reload input pointer + adc r1,r1,#0 + adds $h3,r2,r1 @ d3+=d2>>32 + ldr r0,[sp,#16] @ reload end pointer + adc r3,r3,#0 + add $h4,$h4,r3 @ h4+=d3>>32 + + and r1,$h4,#-4 + and $h4,$h4,#3 + add r1,r1,r1,lsr#2 @ *=5 + adds $h0,$h0,r1 + adcs $h1,$h1,#0 + adcs $h2,$h2,#0 + adcs $h3,$h3,#0 + adc $h4,$h4,#0 + + cmp r0,lr @ done yet? + bhi .Loop + + ldr $ctx,[sp,#12] + add sp,sp,#32 + stmia $ctx,{$h0-$h4} @ store the result + +.Lno_data: +#if __ARM_ARCH__>=5 + ldmia sp!,{r3-r11,pc} +#else + ldmia sp!,{r3-r11,lr} + tst lr,#1 + moveq pc,lr @ be binary compatible with V4, yet + bx lr @ interoperable with Thumb ISA:-) +#endif +.size poly1305_blocks,.-poly1305_blocks +___ +} +{ +my ($ctx,$mac,$nonce)=map("r$_",(0..2)); +my ($h0,$h1,$h2,$h3,$h4,$g0,$g1,$g2,$g3)=map("r$_",(3..11)); +my $g4=$h4; + +$code.=<<___; +.type poly1305_emit,%function +.align 5 +poly1305_emit: + stmdb sp!,{r4-r11} +.Lpoly1305_emit_enter: + + ldmia $ctx,{$h0-$h4} + adds $g0,$h0,#5 @ compare to modulus + adcs $g1,$h1,#0 + adcs $g2,$h2,#0 + adcs $g3,$h3,#0 + adc $g4,$h4,#0 + tst $g4,#4 @ did it carry/borrow? + +#ifdef __thumb2__ + it ne +#endif + movne $h0,$g0 + ldr $g0,[$nonce,#0] +#ifdef __thumb2__ + it ne +#endif + movne $h1,$g1 + ldr $g1,[$nonce,#4] +#ifdef __thumb2__ + it ne +#endif + movne $h2,$g2 + ldr $g2,[$nonce,#8] +#ifdef __thumb2__ + it ne +#endif + movne $h3,$g3 + ldr $g3,[$nonce,#12] + + adds $h0,$h0,$g0 + adcs $h1,$h1,$g1 + adcs $h2,$h2,$g2 + adc $h3,$h3,$g3 + +#if __ARM_ARCH__>=7 +# ifdef __ARMEB__ + rev $h0,$h0 + rev $h1,$h1 + rev $h2,$h2 + rev $h3,$h3 +# endif + str $h0,[$mac,#0] + str $h1,[$mac,#4] + str $h2,[$mac,#8] + str $h3,[$mac,#12] +#else + strb $h0,[$mac,#0] + mov $h0,$h0,lsr#8 + strb $h1,[$mac,#4] + mov $h1,$h1,lsr#8 + strb $h2,[$mac,#8] + mov $h2,$h2,lsr#8 + strb $h3,[$mac,#12] + mov $h3,$h3,lsr#8 + + strb $h0,[$mac,#1] + mov $h0,$h0,lsr#8 + strb $h1,[$mac,#5] + mov $h1,$h1,lsr#8 + strb $h2,[$mac,#9] + mov $h2,$h2,lsr#8 + strb $h3,[$mac,#13] + mov $h3,$h3,lsr#8 + + strb $h0,[$mac,#2] + mov $h0,$h0,lsr#8 + strb $h1,[$mac,#6] + mov $h1,$h1,lsr#8 + strb $h2,[$mac,#10] + mov $h2,$h2,lsr#8 + strb $h3,[$mac,#14] + mov $h3,$h3,lsr#8 + + strb $h0,[$mac,#3] + strb $h1,[$mac,#7] + strb $h2,[$mac,#11] + strb $h3,[$mac,#15] +#endif + ldmia sp!,{r4-r11} +#if __ARM_ARCH__>=5 + ret @ bx lr +#else + tst lr,#1 + moveq pc,lr @ be binary compatible with V4, yet + bx lr @ interoperable with Thumb ISA:-) +#endif +.size poly1305_emit,.-poly1305_emit +___ +{ +my ($R0,$R1,$S1,$R2,$S2,$R3,$S3,$R4,$S4) = map("d$_",(0..9)); +my ($D0,$D1,$D2,$D3,$D4, $H0,$H1,$H2,$H3,$H4) = map("q$_",(5..14)); +my ($T0,$T1,$MASK) = map("q$_",(15,4,0)); + +my ($in2,$zeros,$tbl0,$tbl1) = map("r$_",(4..7)); + +$code.=<<___; +#if __ARM_MAX_ARCH__>=7 +.fpu neon + +.type poly1305_init_neon,%function +.align 5 +poly1305_init_neon: + ldr r4,[$ctx,#20] @ load key base 2^32 + ldr r5,[$ctx,#24] + ldr r6,[$ctx,#28] + ldr r7,[$ctx,#32] + + and r2,r4,#0x03ffffff @ base 2^32 -> base 2^26 + mov r3,r4,lsr#26 + mov r4,r5,lsr#20 + orr r3,r3,r5,lsl#6 + mov r5,r6,lsr#14 + orr r4,r4,r6,lsl#12 + mov r6,r7,lsr#8 + orr r5,r5,r7,lsl#18 + and r3,r3,#0x03ffffff + and r4,r4,#0x03ffffff + and r5,r5,#0x03ffffff + + vdup.32 $R0,r2 @ r^1 in both lanes + add r2,r3,r3,lsl#2 @ *5 + vdup.32 $R1,r3 + add r3,r4,r4,lsl#2 + vdup.32 $S1,r2 + vdup.32 $R2,r4 + add r4,r5,r5,lsl#2 + vdup.32 $S2,r3 + vdup.32 $R3,r5 + add r5,r6,r6,lsl#2 + vdup.32 $S3,r4 + vdup.32 $R4,r6 + vdup.32 $S4,r5 + + mov $zeros,#2 @ counter + +.Lsquare_neon: + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ d0 = h0*r0 + h4*5*r1 + h3*5*r2 + h2*5*r3 + h1*5*r4 + @ d1 = h1*r0 + h0*r1 + h4*5*r2 + h3*5*r3 + h2*5*r4 + @ d2 = h2*r0 + h1*r1 + h0*r2 + h4*5*r3 + h3*5*r4 + @ d3 = h3*r0 + h2*r1 + h1*r2 + h0*r3 + h4*5*r4 + @ d4 = h4*r0 + h3*r1 + h2*r2 + h1*r3 + h0*r4 + + vmull.u32 $D0,$R0,${R0}[1] + vmull.u32 $D1,$R1,${R0}[1] + vmull.u32 $D2,$R2,${R0}[1] + vmull.u32 $D3,$R3,${R0}[1] + vmull.u32 $D4,$R4,${R0}[1] + + vmlal.u32 $D0,$R4,${S1}[1] + vmlal.u32 $D1,$R0,${R1}[1] + vmlal.u32 $D2,$R1,${R1}[1] + vmlal.u32 $D3,$R2,${R1}[1] + vmlal.u32 $D4,$R3,${R1}[1] + + vmlal.u32 $D0,$R3,${S2}[1] + vmlal.u32 $D1,$R4,${S2}[1] + vmlal.u32 $D3,$R1,${R2}[1] + vmlal.u32 $D2,$R0,${R2}[1] + vmlal.u32 $D4,$R2,${R2}[1] + + vmlal.u32 $D0,$R2,${S3}[1] + vmlal.u32 $D3,$R0,${R3}[1] + vmlal.u32 $D1,$R3,${S3}[1] + vmlal.u32 $D2,$R4,${S3}[1] + vmlal.u32 $D4,$R1,${R3}[1] + + vmlal.u32 $D3,$R4,${S4}[1] + vmlal.u32 $D0,$R1,${S4}[1] + vmlal.u32 $D1,$R2,${S4}[1] + vmlal.u32 $D2,$R3,${S4}[1] + vmlal.u32 $D4,$R0,${R4}[1] + + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ lazy reduction as discussed in "NEON crypto" by D.J. Bernstein + @ and P. Schwabe + @ + @ H0>>+H1>>+H2>>+H3>>+H4 + @ H3>>+H4>>*5+H0>>+H1 + @ + @ Trivia. + @ + @ Result of multiplication of n-bit number by m-bit number is + @ n+m bits wide. However! Even though 2^n is a n+1-bit number, + @ m-bit number multiplied by 2^n is still n+m bits wide. + @ + @ Sum of two n-bit numbers is n+1 bits wide, sum of three - n+2, + @ and so is sum of four. Sum of 2^m n-m-bit numbers and n-bit + @ one is n+1 bits wide. + @ + @ >>+ denotes Hnext += Hn>>26, Hn &= 0x3ffffff. This means that + @ H0, H2, H3 are guaranteed to be 26 bits wide, while H1 and H4 + @ can be 27. However! In cases when their width exceeds 26 bits + @ they are limited by 2^26+2^6. This in turn means that *sum* + @ of the products with these values can still be viewed as sum + @ of 52-bit numbers as long as the amount of addends is not a + @ power of 2. For example, + @ + @ H4 = H4*R0 + H3*R1 + H2*R2 + H1*R3 + H0 * R4, + @ + @ which can't be larger than 5 * (2^26 + 2^6) * (2^26 + 2^6), or + @ 5 * (2^52 + 2*2^32 + 2^12), which in turn is smaller than + @ 8 * (2^52) or 2^55. However, the value is then multiplied by + @ by 5, so we should be looking at 5 * 5 * (2^52 + 2^33 + 2^12), + @ which is less than 32 * (2^52) or 2^57. And when processing + @ data we are looking at triple as many addends... + @ + @ In key setup procedure pre-reduced H0 is limited by 5*4+1 and + @ 5*H4 - by 5*5 52-bit addends, or 57 bits. But when hashing the + @ input H0 is limited by (5*4+1)*3 addends, or 58 bits, while + @ 5*H4 by 5*5*3, or 59[!] bits. How is this relevant? vmlal.u32 + @ instruction accepts 2x32-bit input and writes 2x64-bit result. + @ This means that result of reduction have to be compressed upon + @ loop wrap-around. This can be done in the process of reduction + @ to minimize amount of instructions [as well as amount of + @ 128-bit instructions, which benefits low-end processors], but + @ one has to watch for H2 (which is narrower than H0) and 5*H4 + @ not being wider than 58 bits, so that result of right shift + @ by 26 bits fits in 32 bits. This is also useful on x86, + @ because it allows to use paddd in place for paddq, which + @ benefits Atom, where paddq is ridiculously slow. + + vshr.u64 $T0,$D3,#26 + vmovn.i64 $D3#lo,$D3 + vshr.u64 $T1,$D0,#26 + vmovn.i64 $D0#lo,$D0 + vadd.i64 $D4,$D4,$T0 @ h3 -> h4 + vbic.i32 $D3#lo,#0xfc000000 @ &=0x03ffffff + vadd.i64 $D1,$D1,$T1 @ h0 -> h1 + vbic.i32 $D0#lo,#0xfc000000 + + vshrn.u64 $T0#lo,$D4,#26 + vmovn.i64 $D4#lo,$D4 + vshr.u64 $T1,$D1,#26 + vmovn.i64 $D1#lo,$D1 + vadd.i64 $D2,$D2,$T1 @ h1 -> h2 + vbic.i32 $D4#lo,#0xfc000000 + vbic.i32 $D1#lo,#0xfc000000 + + vadd.i32 $D0#lo,$D0#lo,$T0#lo + vshl.u32 $T0#lo,$T0#lo,#2 + vshrn.u64 $T1#lo,$D2,#26 + vmovn.i64 $D2#lo,$D2 + vadd.i32 $D0#lo,$D0#lo,$T0#lo @ h4 -> h0 + vadd.i32 $D3#lo,$D3#lo,$T1#lo @ h2 -> h3 + vbic.i32 $D2#lo,#0xfc000000 + + vshr.u32 $T0#lo,$D0#lo,#26 + vbic.i32 $D0#lo,#0xfc000000 + vshr.u32 $T1#lo,$D3#lo,#26 + vbic.i32 $D3#lo,#0xfc000000 + vadd.i32 $D1#lo,$D1#lo,$T0#lo @ h0 -> h1 + vadd.i32 $D4#lo,$D4#lo,$T1#lo @ h3 -> h4 + + subs $zeros,$zeros,#1 + beq .Lsquare_break_neon + + add $tbl0,$ctx,#(48+0*9*4) + add $tbl1,$ctx,#(48+1*9*4) + + vtrn.32 $R0,$D0#lo @ r^2:r^1 + vtrn.32 $R2,$D2#lo + vtrn.32 $R3,$D3#lo + vtrn.32 $R1,$D1#lo + vtrn.32 $R4,$D4#lo + + vshl.u32 $S2,$R2,#2 @ *5 + vshl.u32 $S3,$R3,#2 + vshl.u32 $S1,$R1,#2 + vshl.u32 $S4,$R4,#2 + vadd.i32 $S2,$S2,$R2 + vadd.i32 $S1,$S1,$R1 + vadd.i32 $S3,$S3,$R3 + vadd.i32 $S4,$S4,$R4 + + vst4.32 {${R0}[0],${R1}[0],${S1}[0],${R2}[0]},[$tbl0]! + vst4.32 {${R0}[1],${R1}[1],${S1}[1],${R2}[1]},[$tbl1]! + vst4.32 {${S2}[0],${R3}[0],${S3}[0],${R4}[0]},[$tbl0]! + vst4.32 {${S2}[1],${R3}[1],${S3}[1],${R4}[1]},[$tbl1]! + vst1.32 {${S4}[0]},[$tbl0,:32] + vst1.32 {${S4}[1]},[$tbl1,:32] + + b .Lsquare_neon + +.align 4 +.Lsquare_break_neon: + add $tbl0,$ctx,#(48+2*4*9) + add $tbl1,$ctx,#(48+3*4*9) + + vmov $R0,$D0#lo @ r^4:r^3 + vshl.u32 $S1,$D1#lo,#2 @ *5 + vmov $R1,$D1#lo + vshl.u32 $S2,$D2#lo,#2 + vmov $R2,$D2#lo + vshl.u32 $S3,$D3#lo,#2 + vmov $R3,$D3#lo + vshl.u32 $S4,$D4#lo,#2 + vmov $R4,$D4#lo + vadd.i32 $S1,$S1,$D1#lo + vadd.i32 $S2,$S2,$D2#lo + vadd.i32 $S3,$S3,$D3#lo + vadd.i32 $S4,$S4,$D4#lo + + vst4.32 {${R0}[0],${R1}[0],${S1}[0],${R2}[0]},[$tbl0]! + vst4.32 {${R0}[1],${R1}[1],${S1}[1],${R2}[1]},[$tbl1]! + vst4.32 {${S2}[0],${R3}[0],${S3}[0],${R4}[0]},[$tbl0]! + vst4.32 {${S2}[1],${R3}[1],${S3}[1],${R4}[1]},[$tbl1]! + vst1.32 {${S4}[0]},[$tbl0] + vst1.32 {${S4}[1]},[$tbl1] + + ret @ bx lr +.size poly1305_init_neon,.-poly1305_init_neon + +.type poly1305_blocks_neon,%function +.align 5 +poly1305_blocks_neon: + ldr ip,[$ctx,#36] @ is_base2_26 + ands $len,$len,#-16 + beq .Lno_data_neon + + cmp $len,#64 + bhs .Lenter_neon + tst ip,ip @ is_base2_26? + beq .Lpoly1305_blocks + +.Lenter_neon: + stmdb sp!,{r4-r7} + vstmdb sp!,{d8-d15} @ ABI specification says so + + tst ip,ip @ is_base2_26? + bne .Lbase2_26_neon + + stmdb sp!,{r1-r3,lr} + bl poly1305_init_neon + + ldr r4,[$ctx,#0] @ load hash value base 2^32 + ldr r5,[$ctx,#4] + ldr r6,[$ctx,#8] + ldr r7,[$ctx,#12] + ldr ip,[$ctx,#16] + + and r2,r4,#0x03ffffff @ base 2^32 -> base 2^26 + mov r3,r4,lsr#26 + veor $D0#lo,$D0#lo,$D0#lo + mov r4,r5,lsr#20 + orr r3,r3,r5,lsl#6 + veor $D1#lo,$D1#lo,$D1#lo + mov r5,r6,lsr#14 + orr r4,r4,r6,lsl#12 + veor $D2#lo,$D2#lo,$D2#lo + mov r6,r7,lsr#8 + orr r5,r5,r7,lsl#18 + veor $D3#lo,$D3#lo,$D3#lo + and r3,r3,#0x03ffffff + orr r6,r6,ip,lsl#24 + veor $D4#lo,$D4#lo,$D4#lo + and r4,r4,#0x03ffffff + mov r1,#1 + and r5,r5,#0x03ffffff + str r1,[$ctx,#36] @ is_base2_26 + + vmov.32 $D0#lo[0],r2 + vmov.32 $D1#lo[0],r3 + vmov.32 $D2#lo[0],r4 + vmov.32 $D3#lo[0],r5 + vmov.32 $D4#lo[0],r6 + adr $zeros,.Lzeros + + ldmia sp!,{r1-r3,lr} + b .Lbase2_32_neon + +.align 4 +.Lbase2_26_neon: + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ load hash value + + veor $D0#lo,$D0#lo,$D0#lo + veor $D1#lo,$D1#lo,$D1#lo + veor $D2#lo,$D2#lo,$D2#lo + veor $D3#lo,$D3#lo,$D3#lo + veor $D4#lo,$D4#lo,$D4#lo + vld4.32 {$D0#lo[0],$D1#lo[0],$D2#lo[0],$D3#lo[0]},[$ctx]! + adr $zeros,.Lzeros + vld1.32 {$D4#lo[0]},[$ctx] + sub $ctx,$ctx,#16 @ rewind + +.Lbase2_32_neon: + add $in2,$inp,#32 + mov $padbit,$padbit,lsl#24 + tst $len,#31 + beq .Leven + + vld4.32 {$H0#lo[0],$H1#lo[0],$H2#lo[0],$H3#lo[0]},[$inp]! + vmov.32 $H4#lo[0],$padbit + sub $len,$len,#16 + add $in2,$inp,#32 + +# ifdef __ARMEB__ + vrev32.8 $H0,$H0 + vrev32.8 $H3,$H3 + vrev32.8 $H1,$H1 + vrev32.8 $H2,$H2 +# endif + vsri.u32 $H4#lo,$H3#lo,#8 @ base 2^32 -> base 2^26 + vshl.u32 $H3#lo,$H3#lo,#18 + + vsri.u32 $H3#lo,$H2#lo,#14 + vshl.u32 $H2#lo,$H2#lo,#12 + vadd.i32 $H4#hi,$H4#lo,$D4#lo @ add hash value and move to #hi + + vbic.i32 $H3#lo,#0xfc000000 + vsri.u32 $H2#lo,$H1#lo,#20 + vshl.u32 $H1#lo,$H1#lo,#6 + + vbic.i32 $H2#lo,#0xfc000000 + vsri.u32 $H1#lo,$H0#lo,#26 + vadd.i32 $H3#hi,$H3#lo,$D3#lo + + vbic.i32 $H0#lo,#0xfc000000 + vbic.i32 $H1#lo,#0xfc000000 + vadd.i32 $H2#hi,$H2#lo,$D2#lo + + vadd.i32 $H0#hi,$H0#lo,$D0#lo + vadd.i32 $H1#hi,$H1#lo,$D1#lo + + mov $tbl1,$zeros + add $tbl0,$ctx,#48 + + cmp $len,$len + b .Long_tail + +.align 4 +.Leven: + subs $len,$len,#64 + it lo + movlo $in2,$zeros + + vmov.i32 $H4,#1<<24 @ padbit, yes, always + vld4.32 {$H0#lo,$H1#lo,$H2#lo,$H3#lo},[$inp] @ inp[0:1] + add $inp,$inp,#64 + vld4.32 {$H0#hi,$H1#hi,$H2#hi,$H3#hi},[$in2] @ inp[2:3] (or 0) + add $in2,$in2,#64 + itt hi + addhi $tbl1,$ctx,#(48+1*9*4) + addhi $tbl0,$ctx,#(48+3*9*4) + +# ifdef __ARMEB__ + vrev32.8 $H0,$H0 + vrev32.8 $H3,$H3 + vrev32.8 $H1,$H1 + vrev32.8 $H2,$H2 +# endif + vsri.u32 $H4,$H3,#8 @ base 2^32 -> base 2^26 + vshl.u32 $H3,$H3,#18 + + vsri.u32 $H3,$H2,#14 + vshl.u32 $H2,$H2,#12 + + vbic.i32 $H3,#0xfc000000 + vsri.u32 $H2,$H1,#20 + vshl.u32 $H1,$H1,#6 + + vbic.i32 $H2,#0xfc000000 + vsri.u32 $H1,$H0,#26 + + vbic.i32 $H0,#0xfc000000 + vbic.i32 $H1,#0xfc000000 + + bls .Lskip_loop + + vld4.32 {${R0}[1],${R1}[1],${S1}[1],${R2}[1]},[$tbl1]! @ load r^2 + vld4.32 {${R0}[0],${R1}[0],${S1}[0],${R2}[0]},[$tbl0]! @ load r^4 + vld4.32 {${S2}[1],${R3}[1],${S3}[1],${R4}[1]},[$tbl1]! + vld4.32 {${S2}[0],${R3}[0],${S3}[0],${R4}[0]},[$tbl0]! + b .Loop_neon + +.align 5 +.Loop_neon: + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ ((inp[0]*r^4+inp[2]*r^2+inp[4])*r^4+inp[6]*r^2 + @ ((inp[1]*r^4+inp[3]*r^2+inp[5])*r^3+inp[7]*r + @ \___________________/ + @ ((inp[0]*r^4+inp[2]*r^2+inp[4])*r^4+inp[6]*r^2+inp[8])*r^2 + @ ((inp[1]*r^4+inp[3]*r^2+inp[5])*r^4+inp[7]*r^2+inp[9])*r + @ \___________________/ \____________________/ + @ + @ Note that we start with inp[2:3]*r^2. This is because it + @ doesn't depend on reduction in previous iteration. + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ d4 = h4*r0 + h3*r1 + h2*r2 + h1*r3 + h0*r4 + @ d3 = h3*r0 + h2*r1 + h1*r2 + h0*r3 + h4*5*r4 + @ d2 = h2*r0 + h1*r1 + h0*r2 + h4*5*r3 + h3*5*r4 + @ d1 = h1*r0 + h0*r1 + h4*5*r2 + h3*5*r3 + h2*5*r4 + @ d0 = h0*r0 + h4*5*r1 + h3*5*r2 + h2*5*r3 + h1*5*r4 + + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ inp[2:3]*r^2 + + vadd.i32 $H2#lo,$H2#lo,$D2#lo @ accumulate inp[0:1] + vmull.u32 $D2,$H2#hi,${R0}[1] + vadd.i32 $H0#lo,$H0#lo,$D0#lo + vmull.u32 $D0,$H0#hi,${R0}[1] + vadd.i32 $H3#lo,$H3#lo,$D3#lo + vmull.u32 $D3,$H3#hi,${R0}[1] + vmlal.u32 $D2,$H1#hi,${R1}[1] + vadd.i32 $H1#lo,$H1#lo,$D1#lo + vmull.u32 $D1,$H1#hi,${R0}[1] + + vadd.i32 $H4#lo,$H4#lo,$D4#lo + vmull.u32 $D4,$H4#hi,${R0}[1] + subs $len,$len,#64 + vmlal.u32 $D0,$H4#hi,${S1}[1] + it lo + movlo $in2,$zeros + vmlal.u32 $D3,$H2#hi,${R1}[1] + vld1.32 ${S4}[1],[$tbl1,:32] + vmlal.u32 $D1,$H0#hi,${R1}[1] + vmlal.u32 $D4,$H3#hi,${R1}[1] + + vmlal.u32 $D0,$H3#hi,${S2}[1] + vmlal.u32 $D3,$H1#hi,${R2}[1] + vmlal.u32 $D4,$H2#hi,${R2}[1] + vmlal.u32 $D1,$H4#hi,${S2}[1] + vmlal.u32 $D2,$H0#hi,${R2}[1] + + vmlal.u32 $D3,$H0#hi,${R3}[1] + vmlal.u32 $D0,$H2#hi,${S3}[1] + vmlal.u32 $D4,$H1#hi,${R3}[1] + vmlal.u32 $D1,$H3#hi,${S3}[1] + vmlal.u32 $D2,$H4#hi,${S3}[1] + + vmlal.u32 $D3,$H4#hi,${S4}[1] + vmlal.u32 $D0,$H1#hi,${S4}[1] + vmlal.u32 $D4,$H0#hi,${R4}[1] + vmlal.u32 $D1,$H2#hi,${S4}[1] + vmlal.u32 $D2,$H3#hi,${S4}[1] + + vld4.32 {$H0#hi,$H1#hi,$H2#hi,$H3#hi},[$in2] @ inp[2:3] (or 0) + add $in2,$in2,#64 + + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ (hash+inp[0:1])*r^4 and accumulate + + vmlal.u32 $D3,$H3#lo,${R0}[0] + vmlal.u32 $D0,$H0#lo,${R0}[0] + vmlal.u32 $D4,$H4#lo,${R0}[0] + vmlal.u32 $D1,$H1#lo,${R0}[0] + vmlal.u32 $D2,$H2#lo,${R0}[0] + vld1.32 ${S4}[0],[$tbl0,:32] + + vmlal.u32 $D3,$H2#lo,${R1}[0] + vmlal.u32 $D0,$H4#lo,${S1}[0] + vmlal.u32 $D4,$H3#lo,${R1}[0] + vmlal.u32 $D1,$H0#lo,${R1}[0] + vmlal.u32 $D2,$H1#lo,${R1}[0] + + vmlal.u32 $D3,$H1#lo,${R2}[0] + vmlal.u32 $D0,$H3#lo,${S2}[0] + vmlal.u32 $D4,$H2#lo,${R2}[0] + vmlal.u32 $D1,$H4#lo,${S2}[0] + vmlal.u32 $D2,$H0#lo,${R2}[0] + + vmlal.u32 $D3,$H0#lo,${R3}[0] + vmlal.u32 $D0,$H2#lo,${S3}[0] + vmlal.u32 $D4,$H1#lo,${R3}[0] + vmlal.u32 $D1,$H3#lo,${S3}[0] + vmlal.u32 $D3,$H4#lo,${S4}[0] + + vmlal.u32 $D2,$H4#lo,${S3}[0] + vmlal.u32 $D0,$H1#lo,${S4}[0] + vmlal.u32 $D4,$H0#lo,${R4}[0] + vmov.i32 $H4,#1<<24 @ padbit, yes, always + vmlal.u32 $D1,$H2#lo,${S4}[0] + vmlal.u32 $D2,$H3#lo,${S4}[0] + + vld4.32 {$H0#lo,$H1#lo,$H2#lo,$H3#lo},[$inp] @ inp[0:1] + add $inp,$inp,#64 +# ifdef __ARMEB__ + vrev32.8 $H0,$H0 + vrev32.8 $H1,$H1 + vrev32.8 $H2,$H2 + vrev32.8 $H3,$H3 +# endif + + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ lazy reduction interleaved with base 2^32 -> base 2^26 of + @ inp[0:3] previously loaded to $H0-$H3 and smashed to $H0-$H4. + + vshr.u64 $T0,$D3,#26 + vmovn.i64 $D3#lo,$D3 + vshr.u64 $T1,$D0,#26 + vmovn.i64 $D0#lo,$D0 + vadd.i64 $D4,$D4,$T0 @ h3 -> h4 + vbic.i32 $D3#lo,#0xfc000000 + vsri.u32 $H4,$H3,#8 @ base 2^32 -> base 2^26 + vadd.i64 $D1,$D1,$T1 @ h0 -> h1 + vshl.u32 $H3,$H3,#18 + vbic.i32 $D0#lo,#0xfc000000 + + vshrn.u64 $T0#lo,$D4,#26 + vmovn.i64 $D4#lo,$D4 + vshr.u64 $T1,$D1,#26 + vmovn.i64 $D1#lo,$D1 + vadd.i64 $D2,$D2,$T1 @ h1 -> h2 + vsri.u32 $H3,$H2,#14 + vbic.i32 $D4#lo,#0xfc000000 + vshl.u32 $H2,$H2,#12 + vbic.i32 $D1#lo,#0xfc000000 + + vadd.i32 $D0#lo,$D0#lo,$T0#lo + vshl.u32 $T0#lo,$T0#lo,#2 + vbic.i32 $H3,#0xfc000000 + vshrn.u64 $T1#lo,$D2,#26 + vmovn.i64 $D2#lo,$D2 + vaddl.u32 $D0,$D0#lo,$T0#lo @ h4 -> h0 [widen for a sec] + vsri.u32 $H2,$H1,#20 + vadd.i32 $D3#lo,$D3#lo,$T1#lo @ h2 -> h3 + vshl.u32 $H1,$H1,#6 + vbic.i32 $D2#lo,#0xfc000000 + vbic.i32 $H2,#0xfc000000 + + vshrn.u64 $T0#lo,$D0,#26 @ re-narrow + vmovn.i64 $D0#lo,$D0 + vsri.u32 $H1,$H0,#26 + vbic.i32 $H0,#0xfc000000 + vshr.u32 $T1#lo,$D3#lo,#26 + vbic.i32 $D3#lo,#0xfc000000 + vbic.i32 $D0#lo,#0xfc000000 + vadd.i32 $D1#lo,$D1#lo,$T0#lo @ h0 -> h1 + vadd.i32 $D4#lo,$D4#lo,$T1#lo @ h3 -> h4 + vbic.i32 $H1,#0xfc000000 + + bhi .Loop_neon + +.Lskip_loop: + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ multiply (inp[0:1]+hash) or inp[2:3] by r^2:r^1 + + add $tbl1,$ctx,#(48+0*9*4) + add $tbl0,$ctx,#(48+1*9*4) + adds $len,$len,#32 + it ne + movne $len,#0 + bne .Long_tail + + vadd.i32 $H2#hi,$H2#lo,$D2#lo @ add hash value and move to #hi + vadd.i32 $H0#hi,$H0#lo,$D0#lo + vadd.i32 $H3#hi,$H3#lo,$D3#lo + vadd.i32 $H1#hi,$H1#lo,$D1#lo + vadd.i32 $H4#hi,$H4#lo,$D4#lo + +.Long_tail: + vld4.32 {${R0}[1],${R1}[1],${S1}[1],${R2}[1]},[$tbl1]! @ load r^1 + vld4.32 {${R0}[0],${R1}[0],${S1}[0],${R2}[0]},[$tbl0]! @ load r^2 + + vadd.i32 $H2#lo,$H2#lo,$D2#lo @ can be redundant + vmull.u32 $D2,$H2#hi,$R0 + vadd.i32 $H0#lo,$H0#lo,$D0#lo + vmull.u32 $D0,$H0#hi,$R0 + vadd.i32 $H3#lo,$H3#lo,$D3#lo + vmull.u32 $D3,$H3#hi,$R0 + vadd.i32 $H1#lo,$H1#lo,$D1#lo + vmull.u32 $D1,$H1#hi,$R0 + vadd.i32 $H4#lo,$H4#lo,$D4#lo + vmull.u32 $D4,$H4#hi,$R0 + + vmlal.u32 $D0,$H4#hi,$S1 + vld4.32 {${S2}[1],${R3}[1],${S3}[1],${R4}[1]},[$tbl1]! + vmlal.u32 $D3,$H2#hi,$R1 + vld4.32 {${S2}[0],${R3}[0],${S3}[0],${R4}[0]},[$tbl0]! + vmlal.u32 $D1,$H0#hi,$R1 + vmlal.u32 $D4,$H3#hi,$R1 + vmlal.u32 $D2,$H1#hi,$R1 + + vmlal.u32 $D3,$H1#hi,$R2 + vld1.32 ${S4}[1],[$tbl1,:32] + vmlal.u32 $D0,$H3#hi,$S2 + vld1.32 ${S4}[0],[$tbl0,:32] + vmlal.u32 $D4,$H2#hi,$R2 + vmlal.u32 $D1,$H4#hi,$S2 + vmlal.u32 $D2,$H0#hi,$R2 + + vmlal.u32 $D3,$H0#hi,$R3 + it ne + addne $tbl1,$ctx,#(48+2*9*4) + vmlal.u32 $D0,$H2#hi,$S3 + it ne + addne $tbl0,$ctx,#(48+3*9*4) + vmlal.u32 $D4,$H1#hi,$R3 + vmlal.u32 $D1,$H3#hi,$S3 + vmlal.u32 $D2,$H4#hi,$S3 + + vmlal.u32 $D3,$H4#hi,$S4 + vorn $MASK,$MASK,$MASK @ all-ones, can be redundant + vmlal.u32 $D0,$H1#hi,$S4 + vshr.u64 $MASK,$MASK,#38 + vmlal.u32 $D4,$H0#hi,$R4 + vmlal.u32 $D1,$H2#hi,$S4 + vmlal.u32 $D2,$H3#hi,$S4 + + beq .Lshort_tail + + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ (hash+inp[0:1])*r^4:r^3 and accumulate + + vld4.32 {${R0}[1],${R1}[1],${S1}[1],${R2}[1]},[$tbl1]! @ load r^3 + vld4.32 {${R0}[0],${R1}[0],${S1}[0],${R2}[0]},[$tbl0]! @ load r^4 + + vmlal.u32 $D2,$H2#lo,$R0 + vmlal.u32 $D0,$H0#lo,$R0 + vmlal.u32 $D3,$H3#lo,$R0 + vmlal.u32 $D1,$H1#lo,$R0 + vmlal.u32 $D4,$H4#lo,$R0 + + vmlal.u32 $D0,$H4#lo,$S1 + vld4.32 {${S2}[1],${R3}[1],${S3}[1],${R4}[1]},[$tbl1]! + vmlal.u32 $D3,$H2#lo,$R1 + vld4.32 {${S2}[0],${R3}[0],${S3}[0],${R4}[0]},[$tbl0]! + vmlal.u32 $D1,$H0#lo,$R1 + vmlal.u32 $D4,$H3#lo,$R1 + vmlal.u32 $D2,$H1#lo,$R1 + + vmlal.u32 $D3,$H1#lo,$R2 + vld1.32 ${S4}[1],[$tbl1,:32] + vmlal.u32 $D0,$H3#lo,$S2 + vld1.32 ${S4}[0],[$tbl0,:32] + vmlal.u32 $D4,$H2#lo,$R2 + vmlal.u32 $D1,$H4#lo,$S2 + vmlal.u32 $D2,$H0#lo,$R2 + + vmlal.u32 $D3,$H0#lo,$R3 + vmlal.u32 $D0,$H2#lo,$S3 + vmlal.u32 $D4,$H1#lo,$R3 + vmlal.u32 $D1,$H3#lo,$S3 + vmlal.u32 $D2,$H4#lo,$S3 + + vmlal.u32 $D3,$H4#lo,$S4 + vorn $MASK,$MASK,$MASK @ all-ones + vmlal.u32 $D0,$H1#lo,$S4 + vshr.u64 $MASK,$MASK,#38 + vmlal.u32 $D4,$H0#lo,$R4 + vmlal.u32 $D1,$H2#lo,$S4 + vmlal.u32 $D2,$H3#lo,$S4 + +.Lshort_tail: + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ horizontal addition + + vadd.i64 $D3#lo,$D3#lo,$D3#hi + vadd.i64 $D0#lo,$D0#lo,$D0#hi + vadd.i64 $D4#lo,$D4#lo,$D4#hi + vadd.i64 $D1#lo,$D1#lo,$D1#hi + vadd.i64 $D2#lo,$D2#lo,$D2#hi + + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ lazy reduction, but without narrowing + + vshr.u64 $T0,$D3,#26 + vand.i64 $D3,$D3,$MASK + vshr.u64 $T1,$D0,#26 + vand.i64 $D0,$D0,$MASK + vadd.i64 $D4,$D4,$T0 @ h3 -> h4 + vadd.i64 $D1,$D1,$T1 @ h0 -> h1 + + vshr.u64 $T0,$D4,#26 + vand.i64 $D4,$D4,$MASK + vshr.u64 $T1,$D1,#26 + vand.i64 $D1,$D1,$MASK + vadd.i64 $D2,$D2,$T1 @ h1 -> h2 + + vadd.i64 $D0,$D0,$T0 + vshl.u64 $T0,$T0,#2 + vshr.u64 $T1,$D2,#26 + vand.i64 $D2,$D2,$MASK + vadd.i64 $D0,$D0,$T0 @ h4 -> h0 + vadd.i64 $D3,$D3,$T1 @ h2 -> h3 + + vshr.u64 $T0,$D0,#26 + vand.i64 $D0,$D0,$MASK + vshr.u64 $T1,$D3,#26 + vand.i64 $D3,$D3,$MASK + vadd.i64 $D1,$D1,$T0 @ h0 -> h1 + vadd.i64 $D4,$D4,$T1 @ h3 -> h4 + + cmp $len,#0 + bne .Leven + + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ store hash value + + vst4.32 {$D0#lo[0],$D1#lo[0],$D2#lo[0],$D3#lo[0]},[$ctx]! + vst1.32 {$D4#lo[0]},[$ctx] + + vldmia sp!,{d8-d15} @ epilogue + ldmia sp!,{r4-r7} +.Lno_data_neon: + ret @ bx lr +.size poly1305_blocks_neon,.-poly1305_blocks_neon + +.type poly1305_emit_neon,%function +.align 5 +poly1305_emit_neon: + ldr ip,[$ctx,#36] @ is_base2_26 + + stmdb sp!,{r4-r11} + + tst ip,ip + beq .Lpoly1305_emit_enter + + ldmia $ctx,{$h0-$h4} + eor $g0,$g0,$g0 + + adds $h0,$h0,$h1,lsl#26 @ base 2^26 -> base 2^32 + mov $h1,$h1,lsr#6 + adcs $h1,$h1,$h2,lsl#20 + mov $h2,$h2,lsr#12 + adcs $h2,$h2,$h3,lsl#14 + mov $h3,$h3,lsr#18 + adcs $h3,$h3,$h4,lsl#8 + adc $h4,$g0,$h4,lsr#24 @ can be partially reduced ... + + and $g0,$h4,#-4 @ ... so reduce + and $h4,$h3,#3 + add $g0,$g0,$g0,lsr#2 @ *= 5 + adds $h0,$h0,$g0 + adcs $h1,$h1,#0 + adcs $h2,$h2,#0 + adcs $h3,$h3,#0 + adc $h4,$h4,#0 + + adds $g0,$h0,#5 @ compare to modulus + adcs $g1,$h1,#0 + adcs $g2,$h2,#0 + adcs $g3,$h3,#0 + adc $g4,$h4,#0 + tst $g4,#4 @ did it carry/borrow? + + it ne + movne $h0,$g0 + ldr $g0,[$nonce,#0] + it ne + movne $h1,$g1 + ldr $g1,[$nonce,#4] + it ne + movne $h2,$g2 + ldr $g2,[$nonce,#8] + it ne + movne $h3,$g3 + ldr $g3,[$nonce,#12] + + adds $h0,$h0,$g0 @ accumulate nonce + adcs $h1,$h1,$g1 + adcs $h2,$h2,$g2 + adc $h3,$h3,$g3 + +# ifdef __ARMEB__ + rev $h0,$h0 + rev $h1,$h1 + rev $h2,$h2 + rev $h3,$h3 +# endif + str $h0,[$mac,#0] @ store the result + str $h1,[$mac,#4] + str $h2,[$mac,#8] + str $h3,[$mac,#12] + + ldmia sp!,{r4-r11} + ret @ bx lr +.size poly1305_emit_neon,.-poly1305_emit_neon + +.align 5 +.Lzeros: +.long 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +.LOPENSSL_armcap: +.word OPENSSL_armcap_P-.Lpoly1305_init +#endif +___ +} } +$code.=<<___; +.asciz "Poly1305 for ARMv4/NEON, CRYPTOGAMS by " +.align 2 +#if __ARM_MAX_ARCH__>=7 +.comm OPENSSL_armcap_P,4,4 +#endif +___ + +foreach (split("\n",$code)) { + s/\`([^\`]*)\`/eval $1/geo; + + s/\bq([0-9]+)#(lo|hi)/sprintf "d%d",2*$1+($2 eq "hi")/geo or + s/\bret\b/bx lr/go or + s/\bbx\s+lr\b/.word\t0xe12fff1e/go; # make it possible to compile with -march=armv4 + + print $_,"\n"; +} +close STDOUT; # enforce flush diff --git a/crypto/poly1305/poly1305-arm.s b/crypto/poly1305/poly1305-arm.s new file mode 100644 index 0000000..1893360 --- /dev/null +++ b/crypto/poly1305/poly1305-arm.s @@ -0,0 +1,1127 @@ +/* SPDX-License-Identifier: OpenSSL OR (BSD-3-Clause OR GPL-2.0) + * + * Copyright (C) 2015-2018 Jason A. Donenfeld . All Rights Reserved. + * Copyright 2016 The OpenSSL Project Authors. All Rights Reserved. + */ + +//#include + +.text +#if defined(__thumb2__) +.syntax unified +.thumb +#else +.code 32 +#endif + +.align 5 +.globl poly1305_init_arm +.type poly1305_init_arm,%function +poly1305_init_arm: + stmdb sp!,{r4-r11} + + eor r3,r3,r3 + cmp r1,#0 + str r3,[r0,#0] @ zero hash value + str r3,[r0,#4] + str r3,[r0,#8] + str r3,[r0,#12] + str r3,[r0,#16] + str r3,[r0,#36] @ is_base2_26 + add r0,r0,#20 + +#ifdef __thumb2__ + it eq +#endif + moveq r0,#0 + beq .Lno_key + + ldrb r4,[r1,#0] + mov r10,#0x0fffffff + ldrb r5,[r1,#1] + and r3,r10,#-4 @ 0x0ffffffc + ldrb r6,[r1,#2] + ldrb r7,[r1,#3] + orr r4,r4,r5,lsl#8 + ldrb r5,[r1,#4] + orr r4,r4,r6,lsl#16 + ldrb r6,[r1,#5] + orr r4,r4,r7,lsl#24 + ldrb r7,[r1,#6] + and r4,r4,r10 + + ldrb r8,[r1,#7] + orr r5,r5,r6,lsl#8 + ldrb r6,[r1,#8] + orr r5,r5,r7,lsl#16 + ldrb r7,[r1,#9] + orr r5,r5,r8,lsl#24 + ldrb r8,[r1,#10] + and r5,r5,r3 + + ldrb r9,[r1,#11] + orr r6,r6,r7,lsl#8 + ldrb r7,[r1,#12] + orr r6,r6,r8,lsl#16 + ldrb r8,[r1,#13] + orr r6,r6,r9,lsl#24 + ldrb r9,[r1,#14] + and r6,r6,r3 + + ldrb r10,[r1,#15] + orr r7,r7,r8,lsl#8 + str r4,[r0,#0] + orr r7,r7,r9,lsl#16 + str r5,[r0,#4] + orr r7,r7,r10,lsl#24 + str r6,[r0,#8] + and r7,r7,r3 + str r7,[r0,#12] +.Lno_key: + ldmia sp!,{r4-r11} +#if __ARM_ARCH__ >= 5 + bx lr @ bx lr +#else + tst lr,#1 + moveq pc,lr @ be binary compatible with V4, yet + .word 0xe12fff1e @ interoperable with Thumb ISA:-) +#endif +.size poly1305_init_arm,.-poly1305_init_arm + +.align 5 +.globl poly1305_blocks_arm +.type poly1305_blocks_arm,%function +poly1305_blocks_arm: +.Lpoly1305_blocks_arm: + stmdb sp!,{r3-r11,lr} + + ands r2,r2,#-16 + beq .Lno_data + + cmp r3,#0 + add r2,r2,r1 @ end pointer + sub sp,sp,#32 + + ldmia r0,{r4-r12} @ load context + + str r0,[sp,#12] @ offload stuff + mov lr,r1 + str r2,[sp,#16] + str r10,[sp,#20] + str r11,[sp,#24] + str r12,[sp,#28] + b .Loop + +.Loop: +#if __ARM_ARCH__ < 7 + ldrb r0,[lr],#16 @ load input +#ifdef __thumb2__ + it hi +#endif + addhi r8,r8,#1 @ 1<<128 + ldrb r1,[lr,#-15] + ldrb r2,[lr,#-14] + ldrb r3,[lr,#-13] + orr r1,r0,r1,lsl#8 + ldrb r0,[lr,#-12] + orr r2,r1,r2,lsl#16 + ldrb r1,[lr,#-11] + orr r3,r2,r3,lsl#24 + ldrb r2,[lr,#-10] + adds r4,r4,r3 @ accumulate input + + ldrb r3,[lr,#-9] + orr r1,r0,r1,lsl#8 + ldrb r0,[lr,#-8] + orr r2,r1,r2,lsl#16 + ldrb r1,[lr,#-7] + orr r3,r2,r3,lsl#24 + ldrb r2,[lr,#-6] + adcs r5,r5,r3 + + ldrb r3,[lr,#-5] + orr r1,r0,r1,lsl#8 + ldrb r0,[lr,#-4] + orr r2,r1,r2,lsl#16 + ldrb r1,[lr,#-3] + orr r3,r2,r3,lsl#24 + ldrb r2,[lr,#-2] + adcs r6,r6,r3 + + ldrb r3,[lr,#-1] + orr r1,r0,r1,lsl#8 + str lr,[sp,#8] @ offload input pointer + orr r2,r1,r2,lsl#16 + add r10,r10,r10,lsr#2 + orr r3,r2,r3,lsl#24 +#else + ldr r0,[lr],#16 @ load input +#ifdef __thumb2__ + it hi +#endif + addhi r8,r8,#1 @ padbit + ldr r1,[lr,#-12] + ldr r2,[lr,#-8] + ldr r3,[lr,#-4] +#ifdef __ARMEB__ + rev r0,r0 + rev r1,r1 + rev r2,r2 + rev r3,r3 +#endif + adds r4,r4,r0 @ accumulate input + str lr,[sp,#8] @ offload input pointer + adcs r5,r5,r1 + add r10,r10,r10,lsr#2 + adcs r6,r6,r2 +#endif + add r11,r11,r11,lsr#2 + adcs r7,r7,r3 + add r12,r12,r12,lsr#2 + + umull r2,r3,r5,r9 + adc r8,r8,#0 + umull r0,r1,r4,r9 + umlal r2,r3,r8,r10 + umlal r0,r1,r7,r10 + ldr r10,[sp,#20] @ reload r10 + umlal r2,r3,r6,r12 + umlal r0,r1,r5,r12 + umlal r2,r3,r7,r11 + umlal r0,r1,r6,r11 + umlal r2,r3,r4,r10 + str r0,[sp,#0] @ future r4 + mul r0,r11,r8 + ldr r11,[sp,#24] @ reload r11 + adds r2,r2,r1 @ d1+=d0>>32 + eor r1,r1,r1 + adc lr,r3,#0 @ future r6 + str r2,[sp,#4] @ future r5 + + mul r2,r12,r8 + eor r3,r3,r3 + umlal r0,r1,r7,r12 + ldr r12,[sp,#28] @ reload r12 + umlal r2,r3,r7,r9 + umlal r0,r1,r6,r9 + umlal r2,r3,r6,r10 + umlal r0,r1,r5,r10 + umlal r2,r3,r5,r11 + umlal r0,r1,r4,r11 + umlal r2,r3,r4,r12 + ldr r4,[sp,#0] + mul r8,r9,r8 + ldr r5,[sp,#4] + + adds r6,lr,r0 @ d2+=d1>>32 + ldr lr,[sp,#8] @ reload input pointer + adc r1,r1,#0 + adds r7,r2,r1 @ d3+=d2>>32 + ldr r0,[sp,#16] @ reload end pointer + adc r3,r3,#0 + add r8,r8,r3 @ h4+=d3>>32 + + and r1,r8,#-4 + and r8,r8,#3 + add r1,r1,r1,lsr#2 @ *=5 + adds r4,r4,r1 + adcs r5,r5,#0 + adcs r6,r6,#0 + adcs r7,r7,#0 + adc r8,r8,#0 + + cmp r0,lr @ done yet? + bhi .Loop + + ldr r0,[sp,#12] + add sp,sp,#32 + stmia r0,{r4-r8} @ store the result + +.Lno_data: +#if __ARM_ARCH__ >= 5 + ldmia sp!,{r3-r11,pc} +#else + ldmia sp!,{r3-r11,lr} + tst lr,#1 + moveq pc,lr @ be binary compatible with V4, yet + .word 0xe12fff1e @ interoperable with Thumb ISA:-) +#endif +.size poly1305_blocks_arm,.-poly1305_blocks_arm + + +.align 5 +.globl poly1305_emit_arm +.type poly1305_emit_arm,%function +poly1305_emit_arm: + stmdb sp!,{r4-r11} +.Lpoly1305_emit_enter: + ldmia r0,{r3-r7} + adds r8,r3,#5 @ compare to modulus + adcs r9,r4,#0 + adcs r10,r5,#0 + adcs r11,r6,#0 + adc r7,r7,#0 + tst r7,#4 @ did it carry/borrow? + +#ifdef __thumb2__ + it ne +#endif + movne r3,r8 + ldr r8,[r2,#0] +#ifdef __thumb2__ + it ne +#endif + movne r4,r9 + ldr r9,[r2,#4] +#ifdef __thumb2__ + it ne +#endif + movne r5,r10 + ldr r10,[r2,#8] +#ifdef __thumb2__ + it ne +#endif + movne r6,r11 + ldr r11,[r2,#12] + + adds r3,r3,r8 + adcs r4,r4,r9 + adcs r5,r5,r10 + adc r6,r6,r11 + +#if __ARM_ARCH__ >= 7 +#ifdef __ARMEB__ + rev r3,r3 + rev r4,r4 + rev r5,r5 + rev r6,r6 +#endif + str r3,[r1,#0] + str r4,[r1,#4] + str r5,[r1,#8] + str r6,[r1,#12] +#else + strb r3,[r1,#0] + mov r3,r3,lsr#8 + strb r4,[r1,#4] + mov r4,r4,lsr#8 + strb r5,[r1,#8] + mov r5,r5,lsr#8 + strb r6,[r1,#12] + mov r6,r6,lsr#8 + + strb r3,[r1,#1] + mov r3,r3,lsr#8 + strb r4,[r1,#5] + mov r4,r4,lsr#8 + strb r5,[r1,#9] + mov r5,r5,lsr#8 + strb r6,[r1,#13] + mov r6,r6,lsr#8 + + strb r3,[r1,#2] + mov r3,r3,lsr#8 + strb r4,[r1,#6] + mov r4,r4,lsr#8 + strb r5,[r1,#10] + mov r5,r5,lsr#8 + strb r6,[r1,#14] + mov r6,r6,lsr#8 + + strb r3,[r1,#3] + strb r4,[r1,#7] + strb r5,[r1,#11] + strb r6,[r1,#15] +#endif + ldmia sp!,{r4-r11} +#if __ARM_ARCH__ >= 5 + bx lr @ bx lr +#else + tst lr,#1 + moveq pc,lr @ be binary compatible with V4, yet + .word 0xe12fff1e @ interoperable with Thumb ISA:-) +#endif +.size poly1305_emit_arm,.-poly1305_emit_arm + + +#if __ARM_ARCH__ >= 7 +.fpu neon + +.align 5 +.type poly1305_init_neon,%function +poly1305_init_neon: +.Lpoly1305_init_neon: + ldr r4,[r0,#20] @ load key base 2^32 + ldr r5,[r0,#24] + ldr r6,[r0,#28] + ldr r7,[r0,#32] + + and r2,r4,#0x03ffffff @ base 2^32 -> base 2^26 + mov r3,r4,lsr#26 + mov r4,r5,lsr#20 + orr r3,r3,r5,lsl#6 + mov r5,r6,lsr#14 + orr r4,r4,r6,lsl#12 + mov r6,r7,lsr#8 + orr r5,r5,r7,lsl#18 + and r3,r3,#0x03ffffff + and r4,r4,#0x03ffffff + and r5,r5,#0x03ffffff + + vdup.32 d0,r2 @ r^1 in both lanes + add r2,r3,r3,lsl#2 @ *5 + vdup.32 d1,r3 + add r3,r4,r4,lsl#2 + vdup.32 d2,r2 + vdup.32 d3,r4 + add r4,r5,r5,lsl#2 + vdup.32 d4,r3 + vdup.32 d5,r5 + add r5,r6,r6,lsl#2 + vdup.32 d6,r4 + vdup.32 d7,r6 + vdup.32 d8,r5 + + mov r5,#2 @ counter + +.Lsquare_neon: + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ d0 = h0*r0 + h4*5*r1 + h3*5*r2 + h2*5*r3 + h1*5*r4 + @ d1 = h1*r0 + h0*r1 + h4*5*r2 + h3*5*r3 + h2*5*r4 + @ d2 = h2*r0 + h1*r1 + h0*r2 + h4*5*r3 + h3*5*r4 + @ d3 = h3*r0 + h2*r1 + h1*r2 + h0*r3 + h4*5*r4 + @ d4 = h4*r0 + h3*r1 + h2*r2 + h1*r3 + h0*r4 + + vmull.u32 q5,d0,d0[1] + vmull.u32 q6,d1,d0[1] + vmull.u32 q7,d3,d0[1] + vmull.u32 q8,d5,d0[1] + vmull.u32 q9,d7,d0[1] + + vmlal.u32 q5,d7,d2[1] + vmlal.u32 q6,d0,d1[1] + vmlal.u32 q7,d1,d1[1] + vmlal.u32 q8,d3,d1[1] + vmlal.u32 q9,d5,d1[1] + + vmlal.u32 q5,d5,d4[1] + vmlal.u32 q6,d7,d4[1] + vmlal.u32 q8,d1,d3[1] + vmlal.u32 q7,d0,d3[1] + vmlal.u32 q9,d3,d3[1] + + vmlal.u32 q5,d3,d6[1] + vmlal.u32 q8,d0,d5[1] + vmlal.u32 q6,d5,d6[1] + vmlal.u32 q7,d7,d6[1] + vmlal.u32 q9,d1,d5[1] + + vmlal.u32 q8,d7,d8[1] + vmlal.u32 q5,d1,d8[1] + vmlal.u32 q6,d3,d8[1] + vmlal.u32 q7,d5,d8[1] + vmlal.u32 q9,d0,d7[1] + + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ lazy reduction as discussed in "NEON crypto" by D.J. Bernstein + @ and P. Schwabe + @ + @ H0>>+H1>>+H2>>+H3>>+H4 + @ H3>>+H4>>*5+H0>>+H1 + @ + @ Trivia. + @ + @ Result of multiplication of n-bit number by m-bit number is + @ n+m bits wide. However! Even though 2^n is a n+1-bit number, + @ m-bit number multiplied by 2^n is still n+m bits wide. + @ + @ Sum of two n-bit numbers is n+1 bits wide, sum of three - n+2, + @ and so is sum of four. Sum of 2^m n-m-bit numbers and n-bit + @ one is n+1 bits wide. + @ + @ >>+ denotes Hnext += Hn>>26, Hn &= 0x3ffffff. This means that + @ H0, H2, H3 are guaranteed to be 26 bits wide, while H1 and H4 + @ can be 27. However! In cases when their width exceeds 26 bits + @ they are limited by 2^26+2^6. This in turn means that *sum* + @ of the products with these values can still be viewed as sum + @ of 52-bit numbers as long as the amount of addends is not a + @ power of 2. For example, + @ + @ H4 = H4*R0 + H3*R1 + H2*R2 + H1*R3 + H0 * R4, + @ + @ which can't be larger than 5 * (2^26 + 2^6) * (2^26 + 2^6), or + @ 5 * (2^52 + 2*2^32 + 2^12), which in turn is smaller than + @ 8 * (2^52) or 2^55. However, the value is then multiplied by + @ by 5, so we should be looking at 5 * 5 * (2^52 + 2^33 + 2^12), + @ which is less than 32 * (2^52) or 2^57. And when processing + @ data we are looking at triple as many addends... + @ + @ In key setup procedure pre-reduced H0 is limited by 5*4+1 and + @ 5*H4 - by 5*5 52-bit addends, or 57 bits. But when hashing the + @ input H0 is limited by (5*4+1)*3 addends, or 58 bits, while + @ 5*H4 by 5*5*3, or 59[!] bits. How is this relevant? vmlal.u32 + @ instruction accepts 2x32-bit input and writes 2x64-bit result. + @ This means that result of reduction have to be compressed upon + @ loop wrap-around. This can be done in the process of reduction + @ to minimize amount of instructions [as well as amount of + @ 128-bit instructions, which benefits low-end processors], but + @ one has to watch for H2 (which is narrower than H0) and 5*H4 + @ not being wider than 58 bits, so that result of right shift + @ by 26 bits fits in 32 bits. This is also useful on x86, + @ because it allows to use paddd in place for paddq, which + @ benefits Atom, where paddq is ridiculously slow. + + vshr.u64 q15,q8,#26 + vmovn.i64 d16,q8 + vshr.u64 q4,q5,#26 + vmovn.i64 d10,q5 + vadd.i64 q9,q9,q15 @ h3 -> h4 + vbic.i32 d16,#0xfc000000 @ &=0x03ffffff + vadd.i64 q6,q6,q4 @ h0 -> h1 + vbic.i32 d10,#0xfc000000 + + vshrn.u64 d30,q9,#26 + vmovn.i64 d18,q9 + vshr.u64 q4,q6,#26 + vmovn.i64 d12,q6 + vadd.i64 q7,q7,q4 @ h1 -> h2 + vbic.i32 d18,#0xfc000000 + vbic.i32 d12,#0xfc000000 + + vadd.i32 d10,d10,d30 + vshl.u32 d30,d30,#2 + vshrn.u64 d8,q7,#26 + vmovn.i64 d14,q7 + vadd.i32 d10,d10,d30 @ h4 -> h0 + vadd.i32 d16,d16,d8 @ h2 -> h3 + vbic.i32 d14,#0xfc000000 + + vshr.u32 d30,d10,#26 + vbic.i32 d10,#0xfc000000 + vshr.u32 d8,d16,#26 + vbic.i32 d16,#0xfc000000 + vadd.i32 d12,d12,d30 @ h0 -> h1 + vadd.i32 d18,d18,d8 @ h3 -> h4 + + subs r5,r5,#1 + beq .Lsquare_break_neon + + add r6,r0,#(48+0*9*4) + add r7,r0,#(48+1*9*4) + + vtrn.32 d0,d10 @ r^2:r^1 + vtrn.32 d3,d14 + vtrn.32 d5,d16 + vtrn.32 d1,d12 + vtrn.32 d7,d18 + + vshl.u32 d4,d3,#2 @ *5 + vshl.u32 d6,d5,#2 + vshl.u32 d2,d1,#2 + vshl.u32 d8,d7,#2 + vadd.i32 d4,d4,d3 + vadd.i32 d2,d2,d1 + vadd.i32 d6,d6,d5 + vadd.i32 d8,d8,d7 + + vst4.32 {d0[0],d1[0],d2[0],d3[0]},[r6]! + vst4.32 {d0[1],d1[1],d2[1],d3[1]},[r7]! + vst4.32 {d4[0],d5[0],d6[0],d7[0]},[r6]! + vst4.32 {d4[1],d5[1],d6[1],d7[1]},[r7]! + vst1.32 {d8[0]},[r6,:32] + vst1.32 {d8[1]},[r7,:32] + + b .Lsquare_neon + +.align 4 +.Lsquare_break_neon: + add r6,r0,#(48+2*4*9) + add r7,r0,#(48+3*4*9) + + vmov d0,d10 @ r^4:r^3 + vshl.u32 d2,d12,#2 @ *5 + vmov d1,d12 + vshl.u32 d4,d14,#2 + vmov d3,d14 + vshl.u32 d6,d16,#2 + vmov d5,d16 + vshl.u32 d8,d18,#2 + vmov d7,d18 + vadd.i32 d2,d2,d12 + vadd.i32 d4,d4,d14 + vadd.i32 d6,d6,d16 + vadd.i32 d8,d8,d18 + + vst4.32 {d0[0],d1[0],d2[0],d3[0]},[r6]! + vst4.32 {d0[1],d1[1],d2[1],d3[1]},[r7]! + vst4.32 {d4[0],d5[0],d6[0],d7[0]},[r6]! + vst4.32 {d4[1],d5[1],d6[1],d7[1]},[r7]! + vst1.32 {d8[0]},[r6] + vst1.32 {d8[1]},[r7] + + bx lr @ bx lr +.size poly1305_init_neon,.-poly1305_init_neon + +.align 5 +.globl poly1305_blocks_neon +.type poly1305_blocks_neon,%function +poly1305_blocks_neon: + ldr ip,[r0,#36] @ is_base2_26 + ands r2,r2,#-16 + beq .Lno_data_neon + + cmp r2,#64 + bhs .Lenter_neon + tst ip,ip @ is_base2_26? + beq .Lpoly1305_blocks_arm + +.Lenter_neon: + stmdb sp!,{r4-r7} + vstmdb sp!,{d8-d15} @ ABI specification says so + + tst ip,ip @ is_base2_26? + bne .Lbase2_26_neon + + stmdb sp!,{r1-r3,lr} + bl .Lpoly1305_init_neon + + ldr r4,[r0,#0] @ load hash value base 2^32 + ldr r5,[r0,#4] + ldr r6,[r0,#8] + ldr r7,[r0,#12] + ldr ip,[r0,#16] + + and r2,r4,#0x03ffffff @ base 2^32 -> base 2^26 + mov r3,r4,lsr#26 + veor d10,d10,d10 + mov r4,r5,lsr#20 + orr r3,r3,r5,lsl#6 + veor d12,d12,d12 + mov r5,r6,lsr#14 + orr r4,r4,r6,lsl#12 + veor d14,d14,d14 + mov r6,r7,lsr#8 + orr r5,r5,r7,lsl#18 + veor d16,d16,d16 + and r3,r3,#0x03ffffff + orr r6,r6,ip,lsl#24 + veor d18,d18,d18 + and r4,r4,#0x03ffffff + mov r1,#1 + and r5,r5,#0x03ffffff + str r1,[r0,#36] @ is_base2_26 + + vmov.32 d10[0],r2 + vmov.32 d12[0],r3 + vmov.32 d14[0],r4 + vmov.32 d16[0],r5 + vmov.32 d18[0],r6 + adr r5,.Lzeros + + ldmia sp!,{r1-r3,lr} + b .Lbase2_32_neon + +.align 4 +.Lbase2_26_neon: + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ load hash value + + veor d10,d10,d10 + veor d12,d12,d12 + veor d14,d14,d14 + veor d16,d16,d16 + veor d18,d18,d18 + vld4.32 {d10[0],d12[0],d14[0],d16[0]},[r0]! + adr r5,.Lzeros + vld1.32 {d18[0]},[r0] + sub r0,r0,#16 @ rewind + +.Lbase2_32_neon: + add r4,r1,#32 + mov r3,r3,lsl#24 + tst r2,#31 + beq .Leven + + vld4.32 {d20[0],d22[0],d24[0],d26[0]},[r1]! + vmov.32 d28[0],r3 + sub r2,r2,#16 + add r4,r1,#32 + +#ifdef __ARMEB__ + vrev32.8 q10,q10 + vrev32.8 q13,q13 + vrev32.8 q11,q11 + vrev32.8 q12,q12 +#endif + vsri.u32 d28,d26,#8 @ base 2^32 -> base 2^26 + vshl.u32 d26,d26,#18 + + vsri.u32 d26,d24,#14 + vshl.u32 d24,d24,#12 + vadd.i32 d29,d28,d18 @ add hash value and move to #hi + + vbic.i32 d26,#0xfc000000 + vsri.u32 d24,d22,#20 + vshl.u32 d22,d22,#6 + + vbic.i32 d24,#0xfc000000 + vsri.u32 d22,d20,#26 + vadd.i32 d27,d26,d16 + + vbic.i32 d20,#0xfc000000 + vbic.i32 d22,#0xfc000000 + vadd.i32 d25,d24,d14 + + vadd.i32 d21,d20,d10 + vadd.i32 d23,d22,d12 + + mov r7,r5 + add r6,r0,#48 + + cmp r2,r2 + b .Long_tail + +.align 4 +.Leven: + subs r2,r2,#64 + it lo + movlo r4,r5 + + vmov.i32 q14,#1<<24 @ padbit, yes, always + vld4.32 {d20,d22,d24,d26},[r1] @ inp[0:1] + add r1,r1,#64 + vld4.32 {d21,d23,d25,d27},[r4] @ inp[2:3] (or 0) + add r4,r4,#64 + itt hi + addhi r7,r0,#(48+1*9*4) + addhi r6,r0,#(48+3*9*4) + +#ifdef __ARMEB__ + vrev32.8 q10,q10 + vrev32.8 q13,q13 + vrev32.8 q11,q11 + vrev32.8 q12,q12 +#endif + vsri.u32 q14,q13,#8 @ base 2^32 -> base 2^26 + vshl.u32 q13,q13,#18 + + vsri.u32 q13,q12,#14 + vshl.u32 q12,q12,#12 + + vbic.i32 q13,#0xfc000000 + vsri.u32 q12,q11,#20 + vshl.u32 q11,q11,#6 + + vbic.i32 q12,#0xfc000000 + vsri.u32 q11,q10,#26 + + vbic.i32 q10,#0xfc000000 + vbic.i32 q11,#0xfc000000 + + bls .Lskip_loop + + vld4.32 {d0[1],d1[1],d2[1],d3[1]},[r7]! @ load r^2 + vld4.32 {d0[0],d1[0],d2[0],d3[0]},[r6]! @ load r^4 + vld4.32 {d4[1],d5[1],d6[1],d7[1]},[r7]! + vld4.32 {d4[0],d5[0],d6[0],d7[0]},[r6]! + b .Loop_neon + +.align 5 +.Loop_neon: + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ ((inp[0]*r^4+inp[2]*r^2+inp[4])*r^4+inp[6]*r^2 + @ ((inp[1]*r^4+inp[3]*r^2+inp[5])*r^3+inp[7]*r + @ ___________________/ + @ ((inp[0]*r^4+inp[2]*r^2+inp[4])*r^4+inp[6]*r^2+inp[8])*r^2 + @ ((inp[1]*r^4+inp[3]*r^2+inp[5])*r^4+inp[7]*r^2+inp[9])*r + @ ___________________/ ____________________/ + @ + @ Note that we start with inp[2:3]*r^2. This is because it + @ doesn't depend on reduction in previous iteration. + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ d4 = h4*r0 + h3*r1 + h2*r2 + h1*r3 + h0*r4 + @ d3 = h3*r0 + h2*r1 + h1*r2 + h0*r3 + h4*5*r4 + @ d2 = h2*r0 + h1*r1 + h0*r2 + h4*5*r3 + h3*5*r4 + @ d1 = h1*r0 + h0*r1 + h4*5*r2 + h3*5*r3 + h2*5*r4 + @ d0 = h0*r0 + h4*5*r1 + h3*5*r2 + h2*5*r3 + h1*5*r4 + + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ inp[2:3]*r^2 + + vadd.i32 d24,d24,d14 @ accumulate inp[0:1] + vmull.u32 q7,d25,d0[1] + vadd.i32 d20,d20,d10 + vmull.u32 q5,d21,d0[1] + vadd.i32 d26,d26,d16 + vmull.u32 q8,d27,d0[1] + vmlal.u32 q7,d23,d1[1] + vadd.i32 d22,d22,d12 + vmull.u32 q6,d23,d0[1] + + vadd.i32 d28,d28,d18 + vmull.u32 q9,d29,d0[1] + subs r2,r2,#64 + vmlal.u32 q5,d29,d2[1] + it lo + movlo r4,r5 + vmlal.u32 q8,d25,d1[1] + vld1.32 d8[1],[r7,:32] + vmlal.u32 q6,d21,d1[1] + vmlal.u32 q9,d27,d1[1] + + vmlal.u32 q5,d27,d4[1] + vmlal.u32 q8,d23,d3[1] + vmlal.u32 q9,d25,d3[1] + vmlal.u32 q6,d29,d4[1] + vmlal.u32 q7,d21,d3[1] + + vmlal.u32 q8,d21,d5[1] + vmlal.u32 q5,d25,d6[1] + vmlal.u32 q9,d23,d5[1] + vmlal.u32 q6,d27,d6[1] + vmlal.u32 q7,d29,d6[1] + + vmlal.u32 q8,d29,d8[1] + vmlal.u32 q5,d23,d8[1] + vmlal.u32 q9,d21,d7[1] + vmlal.u32 q6,d25,d8[1] + vmlal.u32 q7,d27,d8[1] + + vld4.32 {d21,d23,d25,d27},[r4] @ inp[2:3] (or 0) + add r4,r4,#64 + + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ (hash+inp[0:1])*r^4 and accumulate + + vmlal.u32 q8,d26,d0[0] + vmlal.u32 q5,d20,d0[0] + vmlal.u32 q9,d28,d0[0] + vmlal.u32 q6,d22,d0[0] + vmlal.u32 q7,d24,d0[0] + vld1.32 d8[0],[r6,:32] + + vmlal.u32 q8,d24,d1[0] + vmlal.u32 q5,d28,d2[0] + vmlal.u32 q9,d26,d1[0] + vmlal.u32 q6,d20,d1[0] + vmlal.u32 q7,d22,d1[0] + + vmlal.u32 q8,d22,d3[0] + vmlal.u32 q5,d26,d4[0] + vmlal.u32 q9,d24,d3[0] + vmlal.u32 q6,d28,d4[0] + vmlal.u32 q7,d20,d3[0] + + vmlal.u32 q8,d20,d5[0] + vmlal.u32 q5,d24,d6[0] + vmlal.u32 q9,d22,d5[0] + vmlal.u32 q6,d26,d6[0] + vmlal.u32 q8,d28,d8[0] + + vmlal.u32 q7,d28,d6[0] + vmlal.u32 q5,d22,d8[0] + vmlal.u32 q9,d20,d7[0] + vmov.i32 q14,#1<<24 @ padbit, yes, always + vmlal.u32 q6,d24,d8[0] + vmlal.u32 q7,d26,d8[0] + + vld4.32 {d20,d22,d24,d26},[r1] @ inp[0:1] + add r1,r1,#64 +#ifdef __ARMEB__ + vrev32.8 q10,q10 + vrev32.8 q11,q11 + vrev32.8 q12,q12 + vrev32.8 q13,q13 +#endif + + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ lazy reduction interleaved with base 2^32 -> base 2^26 of + @ inp[0:3] previously loaded to q10-q13 and smashed to q10-q14. + + vshr.u64 q15,q8,#26 + vmovn.i64 d16,q8 + vshr.u64 q4,q5,#26 + vmovn.i64 d10,q5 + vadd.i64 q9,q9,q15 @ h3 -> h4 + vbic.i32 d16,#0xfc000000 + vsri.u32 q14,q13,#8 @ base 2^32 -> base 2^26 + vadd.i64 q6,q6,q4 @ h0 -> h1 + vshl.u32 q13,q13,#18 + vbic.i32 d10,#0xfc000000 + + vshrn.u64 d30,q9,#26 + vmovn.i64 d18,q9 + vshr.u64 q4,q6,#26 + vmovn.i64 d12,q6 + vadd.i64 q7,q7,q4 @ h1 -> h2 + vsri.u32 q13,q12,#14 + vbic.i32 d18,#0xfc000000 + vshl.u32 q12,q12,#12 + vbic.i32 d12,#0xfc000000 + + vadd.i32 d10,d10,d30 + vshl.u32 d30,d30,#2 + vbic.i32 q13,#0xfc000000 + vshrn.u64 d8,q7,#26 + vmovn.i64 d14,q7 + vaddl.u32 q5,d10,d30 @ h4 -> h0 [widen for a sec] + vsri.u32 q12,q11,#20 + vadd.i32 d16,d16,d8 @ h2 -> h3 + vshl.u32 q11,q11,#6 + vbic.i32 d14,#0xfc000000 + vbic.i32 q12,#0xfc000000 + + vshrn.u64 d30,q5,#26 @ re-narrow + vmovn.i64 d10,q5 + vsri.u32 q11,q10,#26 + vbic.i32 q10,#0xfc000000 + vshr.u32 d8,d16,#26 + vbic.i32 d16,#0xfc000000 + vbic.i32 d10,#0xfc000000 + vadd.i32 d12,d12,d30 @ h0 -> h1 + vadd.i32 d18,d18,d8 @ h3 -> h4 + vbic.i32 q11,#0xfc000000 + + bhi .Loop_neon + +.Lskip_loop: + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ multiply (inp[0:1]+hash) or inp[2:3] by r^2:r^1 + + add r7,r0,#(48+0*9*4) + add r6,r0,#(48+1*9*4) + adds r2,r2,#32 + it ne + movne r2,#0 + bne .Long_tail + + vadd.i32 d25,d24,d14 @ add hash value and move to #hi + vadd.i32 d21,d20,d10 + vadd.i32 d27,d26,d16 + vadd.i32 d23,d22,d12 + vadd.i32 d29,d28,d18 + +.Long_tail: + vld4.32 {d0[1],d1[1],d2[1],d3[1]},[r7]! @ load r^1 + vld4.32 {d0[0],d1[0],d2[0],d3[0]},[r6]! @ load r^2 + + vadd.i32 d24,d24,d14 @ can be redundant + vmull.u32 q7,d25,d0 + vadd.i32 d20,d20,d10 + vmull.u32 q5,d21,d0 + vadd.i32 d26,d26,d16 + vmull.u32 q8,d27,d0 + vadd.i32 d22,d22,d12 + vmull.u32 q6,d23,d0 + vadd.i32 d28,d28,d18 + vmull.u32 q9,d29,d0 + + vmlal.u32 q5,d29,d2 + vld4.32 {d4[1],d5[1],d6[1],d7[1]},[r7]! + vmlal.u32 q8,d25,d1 + vld4.32 {d4[0],d5[0],d6[0],d7[0]},[r6]! + vmlal.u32 q6,d21,d1 + vmlal.u32 q9,d27,d1 + vmlal.u32 q7,d23,d1 + + vmlal.u32 q8,d23,d3 + vld1.32 d8[1],[r7,:32] + vmlal.u32 q5,d27,d4 + vld1.32 d8[0],[r6,:32] + vmlal.u32 q9,d25,d3 + vmlal.u32 q6,d29,d4 + vmlal.u32 q7,d21,d3 + + vmlal.u32 q8,d21,d5 + it ne + addne r7,r0,#(48+2*9*4) + vmlal.u32 q5,d25,d6 + it ne + addne r6,r0,#(48+3*9*4) + vmlal.u32 q9,d23,d5 + vmlal.u32 q6,d27,d6 + vmlal.u32 q7,d29,d6 + + vmlal.u32 q8,d29,d8 + vorn q0,q0,q0 @ all-ones, can be redundant + vmlal.u32 q5,d23,d8 + vshr.u64 q0,q0,#38 + vmlal.u32 q9,d21,d7 + vmlal.u32 q6,d25,d8 + vmlal.u32 q7,d27,d8 + + beq .Lshort_tail + + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ (hash+inp[0:1])*r^4:r^3 and accumulate + + vld4.32 {d0[1],d1[1],d2[1],d3[1]},[r7]! @ load r^3 + vld4.32 {d0[0],d1[0],d2[0],d3[0]},[r6]! @ load r^4 + + vmlal.u32 q7,d24,d0 + vmlal.u32 q5,d20,d0 + vmlal.u32 q8,d26,d0 + vmlal.u32 q6,d22,d0 + vmlal.u32 q9,d28,d0 + + vmlal.u32 q5,d28,d2 + vld4.32 {d4[1],d5[1],d6[1],d7[1]},[r7]! + vmlal.u32 q8,d24,d1 + vld4.32 {d4[0],d5[0],d6[0],d7[0]},[r6]! + vmlal.u32 q6,d20,d1 + vmlal.u32 q9,d26,d1 + vmlal.u32 q7,d22,d1 + + vmlal.u32 q8,d22,d3 + vld1.32 d8[1],[r7,:32] + vmlal.u32 q5,d26,d4 + vld1.32 d8[0],[r6,:32] + vmlal.u32 q9,d24,d3 + vmlal.u32 q6,d28,d4 + vmlal.u32 q7,d20,d3 + + vmlal.u32 q8,d20,d5 + vmlal.u32 q5,d24,d6 + vmlal.u32 q9,d22,d5 + vmlal.u32 q6,d26,d6 + vmlal.u32 q7,d28,d6 + + vmlal.u32 q8,d28,d8 + vorn q0,q0,q0 @ all-ones + vmlal.u32 q5,d22,d8 + vshr.u64 q0,q0,#38 + vmlal.u32 q9,d20,d7 + vmlal.u32 q6,d24,d8 + vmlal.u32 q7,d26,d8 + +.Lshort_tail: + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ horizontal addition + + vadd.i64 d16,d16,d17 + vadd.i64 d10,d10,d11 + vadd.i64 d18,d18,d19 + vadd.i64 d12,d12,d13 + vadd.i64 d14,d14,d15 + + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ lazy reduction, but without narrowing + + vshr.u64 q15,q8,#26 + vand.i64 q8,q8,q0 + vshr.u64 q4,q5,#26 + vand.i64 q5,q5,q0 + vadd.i64 q9,q9,q15 @ h3 -> h4 + vadd.i64 q6,q6,q4 @ h0 -> h1 + + vshr.u64 q15,q9,#26 + vand.i64 q9,q9,q0 + vshr.u64 q4,q6,#26 + vand.i64 q6,q6,q0 + vadd.i64 q7,q7,q4 @ h1 -> h2 + + vadd.i64 q5,q5,q15 + vshl.u64 q15,q15,#2 + vshr.u64 q4,q7,#26 + vand.i64 q7,q7,q0 + vadd.i64 q5,q5,q15 @ h4 -> h0 + vadd.i64 q8,q8,q4 @ h2 -> h3 + + vshr.u64 q15,q5,#26 + vand.i64 q5,q5,q0 + vshr.u64 q4,q8,#26 + vand.i64 q8,q8,q0 + vadd.i64 q6,q6,q15 @ h0 -> h1 + vadd.i64 q9,q9,q4 @ h3 -> h4 + + cmp r2,#0 + bne .Leven + + @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + @ store hash value + + vst4.32 {d10[0],d12[0],d14[0],d16[0]},[r0]! + vst1.32 {d18[0]},[r0] + + vldmia sp!,{d8-d15} @ epilogue + ldmia sp!,{r4-r7} +.Lno_data_neon: + bx lr @ bx lr +.size poly1305_blocks_neon,.-poly1305_blocks_neon + +.align 5 +.globl poly1305_emit_neon +.type poly1305_emit_neon,%function +poly1305_emit_neon: + ldr ip,[r0,#36] @ is_base2_26 + + stmdb sp!,{r4-r11} + + tst ip,ip + beq .Lpoly1305_emit_enter + + ldmia r0,{r3-r7} + eor r8,r8,r8 + + adds r3,r3,r4,lsl#26 @ base 2^26 -> base 2^32 + mov r4,r4,lsr#6 + adcs r4,r4,r5,lsl#20 + mov r5,r5,lsr#12 + adcs r5,r5,r6,lsl#14 + mov r6,r6,lsr#18 + adcs r6,r6,r7,lsl#8 + adc r7,r8,r7,lsr#24 @ can be partially reduced ... + + and r8,r7,#-4 @ ... so reduce + and r7,r6,#3 + add r8,r8,r8,lsr#2 @ *= 5 + adds r3,r3,r8 + adcs r4,r4,#0 + adcs r5,r5,#0 + adcs r6,r6,#0 + adc r7,r7,#0 + + adds r8,r3,#5 @ compare to modulus + adcs r9,r4,#0 + adcs r10,r5,#0 + adcs r11,r6,#0 + adc r7,r7,#0 + tst r7,#4 @ did it carry/borrow? + + it ne + movne r3,r8 + ldr r8,[r2,#0] + it ne + movne r4,r9 + ldr r9,[r2,#4] + it ne + movne r5,r10 + ldr r10,[r2,#8] + it ne + movne r6,r11 + ldr r11,[r2,#12] + + adds r3,r3,r8 @ accumulate nonce + adcs r4,r4,r9 + adcs r5,r5,r10 + adc r6,r6,r11 + +#ifdef __ARMEB__ + rev r3,r3 + rev r4,r4 + rev r5,r5 + rev r6,r6 +#endif + str r3,[r1,#0] @ store the result + str r4,[r1,#4] + str r5,[r1,#8] + str r6,[r1,#12] + + ldmia sp!,{r4-r11} + bx lr @ bx lr +.size poly1305_emit_neon,.-poly1305_emit_neon + +.align 5 +.Lzeros: +.long 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 +#endif diff --git a/crypto/poly1305/poly1305-arm64.pl b/crypto/poly1305/poly1305-arm64.pl new file mode 100644 index 0000000..ac06457 --- /dev/null +++ b/crypto/poly1305/poly1305-arm64.pl @@ -0,0 +1,944 @@ +#! /usr/bin/env perl +# Copyright 2016 The OpenSSL Project Authors. All Rights Reserved. +# +# Licensed under the OpenSSL license (the "License"). You may not use +# this file except in compliance with the License. You can obtain a copy +# in the file LICENSE in the source distribution or at +# https://www.openssl.org/source/license.html + +# +# ==================================================================== +# Written by Andy Polyakov for the OpenSSL +# project. The module is, however, dual licensed under OpenSSL and +# CRYPTOGAMS licenses depending on where you obtain it. For further +# details see http://www.openssl.org/~appro/cryptogams/. +# ==================================================================== +# +# This module implements Poly1305 hash for ARMv8. +# +# June 2015 +# +# Numbers are cycles per processed byte with poly1305_blocks alone. +# +# IALU/gcc-4.9 NEON +# +# Apple A7 1.86/+5% 0.72 +# Cortex-A53 2.69/+58% 1.47 +# Cortex-A57 2.70/+7% 1.14 +# Denver 1.64/+50% 1.18(*) +# X-Gene 2.13/+68% 2.27 +# Mongoose 1.77/+75% 1.12 +# Kryo 2.70/+55% 1.13 +# +# (*) estimate based on resources availability is less than 1.0, +# i.e. measured result is worse than expected, presumably binary +# translator is not almighty; + +$flavour=shift; +$output=shift; + +$0 =~ m/(.*[\/\\])[^\/\\]+$/; $dir=$1; +( $xlate="${dir}arm-xlate.pl" and -f $xlate ) or +( $xlate="${dir}../../perlasm/arm-xlate.pl" and -f $xlate) or +die "can't locate arm-xlate.pl"; + +open OUT,"| \"$^X\" $xlate $flavour $output"; +*STDOUT=*OUT; + +my ($ctx,$inp,$len,$padbit) = map("x$_",(0..3)); +my ($mac,$nonce)=($inp,$len); + +my ($h0,$h1,$h2,$r0,$r1,$s1,$t0,$t1,$d0,$d1,$d2) = map("x$_",(4..14)); + +$code.=<<___; +#include "arm_arch.h" + +.text + +// forward "declarations" are required for Apple +.extern OPENSSL_armcap_P +.globl poly1305_blocks +.globl poly1305_emit + +.globl poly1305_init +.type poly1305_init,%function +.align 5 +poly1305_init: + cmp $inp,xzr + stp xzr,xzr,[$ctx] // zero hash value + stp xzr,xzr,[$ctx,#16] // [along with is_base2_26] + + csel x0,xzr,x0,eq + b.eq .Lno_key + +#ifdef __ILP32__ + ldrsw $t1,.LOPENSSL_armcap_P +#else + ldr $t1,.LOPENSSL_armcap_P +#endif + adr $t0,.LOPENSSL_armcap_P + + ldp $r0,$r1,[$inp] // load key + mov $s1,#0xfffffffc0fffffff + movk $s1,#0x0fff,lsl#48 + ldr w17,[$t0,$t1] +#ifdef __ARMEB__ + rev $r0,$r0 // flip bytes + rev $r1,$r1 +#endif + and $r0,$r0,$s1 // &=0ffffffc0fffffff + and $s1,$s1,#-4 + and $r1,$r1,$s1 // &=0ffffffc0ffffffc + stp $r0,$r1,[$ctx,#32] // save key value + + tst w17,#ARMV7_NEON + + adr $d0,poly1305_blocks + adr $r0,poly1305_blocks_neon + adr $d1,poly1305_emit + adr $r1,poly1305_emit_neon + + csel $d0,$d0,$r0,eq + csel $d1,$d1,$r1,eq + +#ifdef __ILP32__ + stp w12,w13,[$len] +#else + stp $d0,$d1,[$len] +#endif + + mov x0,#1 +.Lno_key: + ret +.size poly1305_init,.-poly1305_init + +.type poly1305_blocks,%function +.align 5 +poly1305_blocks: + ands $len,$len,#-16 + b.eq .Lno_data + + ldp $h0,$h1,[$ctx] // load hash value + ldp $r0,$r1,[$ctx,#32] // load key value + ldr $h2,[$ctx,#16] + add $s1,$r1,$r1,lsr#2 // s1 = r1 + (r1 >> 2) + b .Loop + +.align 5 +.Loop: + ldp $t0,$t1,[$inp],#16 // load input + sub $len,$len,#16 +#ifdef __ARMEB__ + rev $t0,$t0 + rev $t1,$t1 +#endif + adds $h0,$h0,$t0 // accumulate input + adcs $h1,$h1,$t1 + + mul $d0,$h0,$r0 // h0*r0 + adc $h2,$h2,$padbit + umulh $d1,$h0,$r0 + + mul $t0,$h1,$s1 // h1*5*r1 + umulh $t1,$h1,$s1 + + adds $d0,$d0,$t0 + mul $t0,$h0,$r1 // h0*r1 + adc $d1,$d1,$t1 + umulh $d2,$h0,$r1 + + adds $d1,$d1,$t0 + mul $t0,$h1,$r0 // h1*r0 + adc $d2,$d2,xzr + umulh $t1,$h1,$r0 + + adds $d1,$d1,$t0 + mul $t0,$h2,$s1 // h2*5*r1 + adc $d2,$d2,$t1 + mul $t1,$h2,$r0 // h2*r0 + + adds $d1,$d1,$t0 + adc $d2,$d2,$t1 + + and $t0,$d2,#-4 // final reduction + and $h2,$d2,#3 + add $t0,$t0,$d2,lsr#2 + adds $h0,$d0,$t0 + adcs $h1,$d1,xzr + adc $h2,$h2,xzr + + cbnz $len,.Loop + + stp $h0,$h1,[$ctx] // store hash value + str $h2,[$ctx,#16] + +.Lno_data: + ret +.size poly1305_blocks,.-poly1305_blocks + +.type poly1305_emit,%function +.align 5 +poly1305_emit: + ldp $h0,$h1,[$ctx] // load hash base 2^64 + ldr $h2,[$ctx,#16] + ldp $t0,$t1,[$nonce] // load nonce + + adds $d0,$h0,#5 // compare to modulus + adcs $d1,$h1,xzr + adc $d2,$h2,xzr + + tst $d2,#-4 // see if it's carried/borrowed + + csel $h0,$h0,$d0,eq + csel $h1,$h1,$d1,eq + +#ifdef __ARMEB__ + ror $t0,$t0,#32 // flip nonce words + ror $t1,$t1,#32 +#endif + adds $h0,$h0,$t0 // accumulate nonce + adc $h1,$h1,$t1 +#ifdef __ARMEB__ + rev $h0,$h0 // flip output bytes + rev $h1,$h1 +#endif + stp $h0,$h1,[$mac] // write result + + ret +.size poly1305_emit,.-poly1305_emit +___ +my ($R0,$R1,$S1,$R2,$S2,$R3,$S3,$R4,$S4) = map("v$_.4s",(0..8)); +my ($IN01_0,$IN01_1,$IN01_2,$IN01_3,$IN01_4) = map("v$_.2s",(9..13)); +my ($IN23_0,$IN23_1,$IN23_2,$IN23_3,$IN23_4) = map("v$_.2s",(14..18)); +my ($ACC0,$ACC1,$ACC2,$ACC3,$ACC4) = map("v$_.2d",(19..23)); +my ($H0,$H1,$H2,$H3,$H4) = map("v$_.2s",(24..28)); +my ($T0,$T1,$MASK) = map("v$_",(29..31)); + +my ($in2,$zeros)=("x16","x17"); +my $is_base2_26 = $zeros; # borrow + +$code.=<<___; +.type poly1305_mult,%function +.align 5 +poly1305_mult: + mul $d0,$h0,$r0 // h0*r0 + umulh $d1,$h0,$r0 + + mul $t0,$h1,$s1 // h1*5*r1 + umulh $t1,$h1,$s1 + + adds $d0,$d0,$t0 + mul $t0,$h0,$r1 // h0*r1 + adc $d1,$d1,$t1 + umulh $d2,$h0,$r1 + + adds $d1,$d1,$t0 + mul $t0,$h1,$r0 // h1*r0 + adc $d2,$d2,xzr + umulh $t1,$h1,$r0 + + adds $d1,$d1,$t0 + mul $t0,$h2,$s1 // h2*5*r1 + adc $d2,$d2,$t1 + mul $t1,$h2,$r0 // h2*r0 + + adds $d1,$d1,$t0 + adc $d2,$d2,$t1 + + and $t0,$d2,#-4 // final reduction + and $h2,$d2,#3 + add $t0,$t0,$d2,lsr#2 + adds $h0,$d0,$t0 + adcs $h1,$d1,xzr + adc $h2,$h2,xzr + + ret +.size poly1305_mult,.-poly1305_mult + +.type poly1305_splat,%function +.align 5 +poly1305_splat: + and x12,$h0,#0x03ffffff // base 2^64 -> base 2^26 + ubfx x13,$h0,#26,#26 + extr x14,$h1,$h0,#52 + and x14,x14,#0x03ffffff + ubfx x15,$h1,#14,#26 + extr x16,$h2,$h1,#40 + + str w12,[$ctx,#16*0] // r0 + add w12,w13,w13,lsl#2 // r1*5 + str w13,[$ctx,#16*1] // r1 + add w13,w14,w14,lsl#2 // r2*5 + str w12,[$ctx,#16*2] // s1 + str w14,[$ctx,#16*3] // r2 + add w14,w15,w15,lsl#2 // r3*5 + str w13,[$ctx,#16*4] // s2 + str w15,[$ctx,#16*5] // r3 + add w15,w16,w16,lsl#2 // r4*5 + str w14,[$ctx,#16*6] // s3 + str w16,[$ctx,#16*7] // r4 + str w15,[$ctx,#16*8] // s4 + + ret +.size poly1305_splat,.-poly1305_splat + +.type poly1305_blocks_neon,%function +.align 5 +poly1305_blocks_neon: + ldr $is_base2_26,[$ctx,#24] + cmp $len,#128 + b.hs .Lblocks_neon + cbz $is_base2_26,poly1305_blocks + +.Lblocks_neon: + stp x29,x30,[sp,#-80]! + add x29,sp,#0 + + ands $len,$len,#-16 + b.eq .Lno_data_neon + + cbz $is_base2_26,.Lbase2_64_neon + + ldp w10,w11,[$ctx] // load hash value base 2^26 + ldp w12,w13,[$ctx,#8] + ldr w14,[$ctx,#16] + + tst $len,#31 + b.eq .Leven_neon + + ldp $r0,$r1,[$ctx,#32] // load key value + + add $h0,x10,x11,lsl#26 // base 2^26 -> base 2^64 + lsr $h1,x12,#12 + adds $h0,$h0,x12,lsl#52 + add $h1,$h1,x13,lsl#14 + adc $h1,$h1,xzr + lsr $h2,x14,#24 + adds $h1,$h1,x14,lsl#40 + adc $d2,$h2,xzr // can be partially reduced... + + ldp $d0,$d1,[$inp],#16 // load input + sub $len,$len,#16 + add $s1,$r1,$r1,lsr#2 // s1 = r1 + (r1 >> 2) + + and $t0,$d2,#-4 // ... so reduce + and $h2,$d2,#3 + add $t0,$t0,$d2,lsr#2 + adds $h0,$h0,$t0 + adcs $h1,$h1,xzr + adc $h2,$h2,xzr + +#ifdef __ARMEB__ + rev $d0,$d0 + rev $d1,$d1 +#endif + adds $h0,$h0,$d0 // accumulate input + adcs $h1,$h1,$d1 + adc $h2,$h2,$padbit + + bl poly1305_mult + ldr x30,[sp,#8] + + cbz $padbit,.Lstore_base2_64_neon + + and x10,$h0,#0x03ffffff // base 2^64 -> base 2^26 + ubfx x11,$h0,#26,#26 + extr x12,$h1,$h0,#52 + and x12,x12,#0x03ffffff + ubfx x13,$h1,#14,#26 + extr x14,$h2,$h1,#40 + + cbnz $len,.Leven_neon + + stp w10,w11,[$ctx] // store hash value base 2^26 + stp w12,w13,[$ctx,#8] + str w14,[$ctx,#16] + b .Lno_data_neon + +.align 4 +.Lstore_base2_64_neon: + stp $h0,$h1,[$ctx] // store hash value base 2^64 + stp $h2,xzr,[$ctx,#16] // note that is_base2_26 is zeroed + b .Lno_data_neon + +.align 4 +.Lbase2_64_neon: + ldp $r0,$r1,[$ctx,#32] // load key value + + ldp $h0,$h1,[$ctx] // load hash value base 2^64 + ldr $h2,[$ctx,#16] + + tst $len,#31 + b.eq .Linit_neon + + ldp $d0,$d1,[$inp],#16 // load input + sub $len,$len,#16 + add $s1,$r1,$r1,lsr#2 // s1 = r1 + (r1 >> 2) +#ifdef __ARMEB__ + rev $d0,$d0 + rev $d1,$d1 +#endif + adds $h0,$h0,$d0 // accumulate input + adcs $h1,$h1,$d1 + adc $h2,$h2,$padbit + + bl poly1305_mult + +.Linit_neon: + and x10,$h0,#0x03ffffff // base 2^64 -> base 2^26 + ubfx x11,$h0,#26,#26 + extr x12,$h1,$h0,#52 + and x12,x12,#0x03ffffff + ubfx x13,$h1,#14,#26 + extr x14,$h2,$h1,#40 + + stp d8,d9,[sp,#16] // meet ABI requirements + stp d10,d11,[sp,#32] + stp d12,d13,[sp,#48] + stp d14,d15,[sp,#64] + + fmov ${H0},x10 + fmov ${H1},x11 + fmov ${H2},x12 + fmov ${H3},x13 + fmov ${H4},x14 + + ////////////////////////////////// initialize r^n table + mov $h0,$r0 // r^1 + add $s1,$r1,$r1,lsr#2 // s1 = r1 + (r1 >> 2) + mov $h1,$r1 + mov $h2,xzr + add $ctx,$ctx,#48+12 + bl poly1305_splat + + bl poly1305_mult // r^2 + sub $ctx,$ctx,#4 + bl poly1305_splat + + bl poly1305_mult // r^3 + sub $ctx,$ctx,#4 + bl poly1305_splat + + bl poly1305_mult // r^4 + sub $ctx,$ctx,#4 + bl poly1305_splat + ldr x30,[sp,#8] + + add $in2,$inp,#32 + adr $zeros,.Lzeros + subs $len,$len,#64 + csel $in2,$zeros,$in2,lo + + mov x4,#1 + str x4,[$ctx,#-24] // set is_base2_26 + sub $ctx,$ctx,#48 // restore original $ctx + b .Ldo_neon + +.align 4 +.Leven_neon: + add $in2,$inp,#32 + adr $zeros,.Lzeros + subs $len,$len,#64 + csel $in2,$zeros,$in2,lo + + stp d8,d9,[sp,#16] // meet ABI requirements + stp d10,d11,[sp,#32] + stp d12,d13,[sp,#48] + stp d14,d15,[sp,#64] + + fmov ${H0},x10 + fmov ${H1},x11 + fmov ${H2},x12 + fmov ${H3},x13 + fmov ${H4},x14 + +.Ldo_neon: + ldp x8,x12,[$in2],#16 // inp[2:3] (or zero) + ldp x9,x13,[$in2],#48 + + lsl $padbit,$padbit,#24 + add x15,$ctx,#48 + +#ifdef __ARMEB__ + rev x8,x8 + rev x12,x12 + rev x9,x9 + rev x13,x13 +#endif + and x4,x8,#0x03ffffff // base 2^64 -> base 2^26 + and x5,x9,#0x03ffffff + ubfx x6,x8,#26,#26 + ubfx x7,x9,#26,#26 + add x4,x4,x5,lsl#32 // bfi x4,x5,#32,#32 + extr x8,x12,x8,#52 + extr x9,x13,x9,#52 + add x6,x6,x7,lsl#32 // bfi x6,x7,#32,#32 + fmov $IN23_0,x4 + and x8,x8,#0x03ffffff + and x9,x9,#0x03ffffff + ubfx x10,x12,#14,#26 + ubfx x11,x13,#14,#26 + add x12,$padbit,x12,lsr#40 + add x13,$padbit,x13,lsr#40 + add x8,x8,x9,lsl#32 // bfi x8,x9,#32,#32 + fmov $IN23_1,x6 + add x10,x10,x11,lsl#32 // bfi x10,x11,#32,#32 + add x12,x12,x13,lsl#32 // bfi x12,x13,#32,#32 + fmov $IN23_2,x8 + fmov $IN23_3,x10 + fmov $IN23_4,x12 + + ldp x8,x12,[$inp],#16 // inp[0:1] + ldp x9,x13,[$inp],#48 + + ld1 {$R0,$R1,$S1,$R2},[x15],#64 + ld1 {$S2,$R3,$S3,$R4},[x15],#64 + ld1 {$S4},[x15] + +#ifdef __ARMEB__ + rev x8,x8 + rev x12,x12 + rev x9,x9 + rev x13,x13 +#endif + and x4,x8,#0x03ffffff // base 2^64 -> base 2^26 + and x5,x9,#0x03ffffff + ubfx x6,x8,#26,#26 + ubfx x7,x9,#26,#26 + add x4,x4,x5,lsl#32 // bfi x4,x5,#32,#32 + extr x8,x12,x8,#52 + extr x9,x13,x9,#52 + add x6,x6,x7,lsl#32 // bfi x6,x7,#32,#32 + fmov $IN01_0,x4 + and x8,x8,#0x03ffffff + and x9,x9,#0x03ffffff + ubfx x10,x12,#14,#26 + ubfx x11,x13,#14,#26 + add x12,$padbit,x12,lsr#40 + add x13,$padbit,x13,lsr#40 + add x8,x8,x9,lsl#32 // bfi x8,x9,#32,#32 + fmov $IN01_1,x6 + add x10,x10,x11,lsl#32 // bfi x10,x11,#32,#32 + add x12,x12,x13,lsl#32 // bfi x12,x13,#32,#32 + movi $MASK.2d,#-1 + fmov $IN01_2,x8 + fmov $IN01_3,x10 + fmov $IN01_4,x12 + ushr $MASK.2d,$MASK.2d,#38 + + b.ls .Lskip_loop + +.align 4 +.Loop_neon: + //////////////////////////////////////////////////////////////// + // ((inp[0]*r^4+inp[2]*r^2+inp[4])*r^4+inp[6]*r^2 + // ((inp[1]*r^4+inp[3]*r^2+inp[5])*r^3+inp[7]*r + // \___________________/ + // ((inp[0]*r^4+inp[2]*r^2+inp[4])*r^4+inp[6]*r^2+inp[8])*r^2 + // ((inp[1]*r^4+inp[3]*r^2+inp[5])*r^4+inp[7]*r^2+inp[9])*r + // \___________________/ \____________________/ + // + // Note that we start with inp[2:3]*r^2. This is because it + // doesn't depend on reduction in previous iteration. + //////////////////////////////////////////////////////////////// + // d4 = h0*r4 + h1*r3 + h2*r2 + h3*r1 + h4*r0 + // d3 = h0*r3 + h1*r2 + h2*r1 + h3*r0 + h4*5*r4 + // d2 = h0*r2 + h1*r1 + h2*r0 + h3*5*r4 + h4*5*r3 + // d1 = h0*r1 + h1*r0 + h2*5*r4 + h3*5*r3 + h4*5*r2 + // d0 = h0*r0 + h1*5*r4 + h2*5*r3 + h3*5*r2 + h4*5*r1 + + subs $len,$len,#64 + umull $ACC4,$IN23_0,${R4}[2] + csel $in2,$zeros,$in2,lo + umull $ACC3,$IN23_0,${R3}[2] + umull $ACC2,$IN23_0,${R2}[2] + ldp x8,x12,[$in2],#16 // inp[2:3] (or zero) + umull $ACC1,$IN23_0,${R1}[2] + ldp x9,x13,[$in2],#48 + umull $ACC0,$IN23_0,${R0}[2] +#ifdef __ARMEB__ + rev x8,x8 + rev x12,x12 + rev x9,x9 + rev x13,x13 +#endif + + umlal $ACC4,$IN23_1,${R3}[2] + and x4,x8,#0x03ffffff // base 2^64 -> base 2^26 + umlal $ACC3,$IN23_1,${R2}[2] + and x5,x9,#0x03ffffff + umlal $ACC2,$IN23_1,${R1}[2] + ubfx x6,x8,#26,#26 + umlal $ACC1,$IN23_1,${R0}[2] + ubfx x7,x9,#26,#26 + umlal $ACC0,$IN23_1,${S4}[2] + add x4,x4,x5,lsl#32 // bfi x4,x5,#32,#32 + + umlal $ACC4,$IN23_2,${R2}[2] + extr x8,x12,x8,#52 + umlal $ACC3,$IN23_2,${R1}[2] + extr x9,x13,x9,#52 + umlal $ACC2,$IN23_2,${R0}[2] + add x6,x6,x7,lsl#32 // bfi x6,x7,#32,#32 + umlal $ACC1,$IN23_2,${S4}[2] + fmov $IN23_0,x4 + umlal $ACC0,$IN23_2,${S3}[2] + and x8,x8,#0x03ffffff + + umlal $ACC4,$IN23_3,${R1}[2] + and x9,x9,#0x03ffffff + umlal $ACC3,$IN23_3,${R0}[2] + ubfx x10,x12,#14,#26 + umlal $ACC2,$IN23_3,${S4}[2] + ubfx x11,x13,#14,#26 + umlal $ACC1,$IN23_3,${S3}[2] + add x8,x8,x9,lsl#32 // bfi x8,x9,#32,#32 + umlal $ACC0,$IN23_3,${S2}[2] + fmov $IN23_1,x6 + + add $IN01_2,$IN01_2,$H2 + add x12,$padbit,x12,lsr#40 + umlal $ACC4,$IN23_4,${R0}[2] + add x13,$padbit,x13,lsr#40 + umlal $ACC3,$IN23_4,${S4}[2] + add x10,x10,x11,lsl#32 // bfi x10,x11,#32,#32 + umlal $ACC2,$IN23_4,${S3}[2] + add x12,x12,x13,lsl#32 // bfi x12,x13,#32,#32 + umlal $ACC1,$IN23_4,${S2}[2] + fmov $IN23_2,x8 + umlal $ACC0,$IN23_4,${S1}[2] + fmov $IN23_3,x10 + + //////////////////////////////////////////////////////////////// + // (hash+inp[0:1])*r^4 and accumulate + + add $IN01_0,$IN01_0,$H0 + fmov $IN23_4,x12 + umlal $ACC3,$IN01_2,${R1}[0] + ldp x8,x12,[$inp],#16 // inp[0:1] + umlal $ACC0,$IN01_2,${S3}[0] + ldp x9,x13,[$inp],#48 + umlal $ACC4,$IN01_2,${R2}[0] + umlal $ACC1,$IN01_2,${S4}[0] + umlal $ACC2,$IN01_2,${R0}[0] +#ifdef __ARMEB__ + rev x8,x8 + rev x12,x12 + rev x9,x9 + rev x13,x13 +#endif + + add $IN01_1,$IN01_1,$H1 + umlal $ACC3,$IN01_0,${R3}[0] + umlal $ACC4,$IN01_0,${R4}[0] + and x4,x8,#0x03ffffff // base 2^64 -> base 2^26 + umlal $ACC2,$IN01_0,${R2}[0] + and x5,x9,#0x03ffffff + umlal $ACC0,$IN01_0,${R0}[0] + ubfx x6,x8,#26,#26 + umlal $ACC1,$IN01_0,${R1}[0] + ubfx x7,x9,#26,#26 + + add $IN01_3,$IN01_3,$H3 + add x4,x4,x5,lsl#32 // bfi x4,x5,#32,#32 + umlal $ACC3,$IN01_1,${R2}[0] + extr x8,x12,x8,#52 + umlal $ACC4,$IN01_1,${R3}[0] + extr x9,x13,x9,#52 + umlal $ACC0,$IN01_1,${S4}[0] + add x6,x6,x7,lsl#32 // bfi x6,x7,#32,#32 + umlal $ACC2,$IN01_1,${R1}[0] + fmov $IN01_0,x4 + umlal $ACC1,$IN01_1,${R0}[0] + and x8,x8,#0x03ffffff + + add $IN01_4,$IN01_4,$H4 + and x9,x9,#0x03ffffff + umlal $ACC3,$IN01_3,${R0}[0] + ubfx x10,x12,#14,#26 + umlal $ACC0,$IN01_3,${S2}[0] + ubfx x11,x13,#14,#26 + umlal $ACC4,$IN01_3,${R1}[0] + add x8,x8,x9,lsl#32 // bfi x8,x9,#32,#32 + umlal $ACC1,$IN01_3,${S3}[0] + fmov $IN01_1,x6 + umlal $ACC2,$IN01_3,${S4}[0] + add x12,$padbit,x12,lsr#40 + + umlal $ACC3,$IN01_4,${S4}[0] + add x13,$padbit,x13,lsr#40 + umlal $ACC0,$IN01_4,${S1}[0] + add x10,x10,x11,lsl#32 // bfi x10,x11,#32,#32 + umlal $ACC4,$IN01_4,${R0}[0] + add x12,x12,x13,lsl#32 // bfi x12,x13,#32,#32 + umlal $ACC1,$IN01_4,${S2}[0] + fmov $IN01_2,x8 + umlal $ACC2,$IN01_4,${S3}[0] + fmov $IN01_3,x10 + fmov $IN01_4,x12 + + ///////////////////////////////////////////////////////////////// + // lazy reduction as discussed in "NEON crypto" by D.J. Bernstein + // and P. Schwabe + // + // [see discussion in poly1305-armv4 module] + + ushr $T0.2d,$ACC3,#26 + xtn $H3,$ACC3 + ushr $T1.2d,$ACC0,#26 + and $ACC0,$ACC0,$MASK.2d + add $ACC4,$ACC4,$T0.2d // h3 -> h4 + bic $H3,#0xfc,lsl#24 // &=0x03ffffff + add $ACC1,$ACC1,$T1.2d // h0 -> h1 + + ushr $T0.2d,$ACC4,#26 + xtn $H4,$ACC4 + ushr $T1.2d,$ACC1,#26 + xtn $H1,$ACC1 + bic $H4,#0xfc,lsl#24 + add $ACC2,$ACC2,$T1.2d // h1 -> h2 + + add $ACC0,$ACC0,$T0.2d + shl $T0.2d,$T0.2d,#2 + shrn $T1.2s,$ACC2,#26 + xtn $H2,$ACC2 + add $ACC0,$ACC0,$T0.2d // h4 -> h0 + bic $H1,#0xfc,lsl#24 + add $H3,$H3,$T1.2s // h2 -> h3 + bic $H2,#0xfc,lsl#24 + + shrn $T0.2s,$ACC0,#26 + xtn $H0,$ACC0 + ushr $T1.2s,$H3,#26 + bic $H3,#0xfc,lsl#24 + bic $H0,#0xfc,lsl#24 + add $H1,$H1,$T0.2s // h0 -> h1 + add $H4,$H4,$T1.2s // h3 -> h4 + + b.hi .Loop_neon + +.Lskip_loop: + dup $IN23_2,${IN23_2}[0] + add $IN01_2,$IN01_2,$H2 + + //////////////////////////////////////////////////////////////// + // multiply (inp[0:1]+hash) or inp[2:3] by r^2:r^1 + + adds $len,$len,#32 + b.ne .Long_tail + + dup $IN23_2,${IN01_2}[0] + add $IN23_0,$IN01_0,$H0 + add $IN23_3,$IN01_3,$H3 + add $IN23_1,$IN01_1,$H1 + add $IN23_4,$IN01_4,$H4 + +.Long_tail: + dup $IN23_0,${IN23_0}[0] + umull2 $ACC0,$IN23_2,${S3} + umull2 $ACC3,$IN23_2,${R1} + umull2 $ACC4,$IN23_2,${R2} + umull2 $ACC2,$IN23_2,${R0} + umull2 $ACC1,$IN23_2,${S4} + + dup $IN23_1,${IN23_1}[0] + umlal2 $ACC0,$IN23_0,${R0} + umlal2 $ACC2,$IN23_0,${R2} + umlal2 $ACC3,$IN23_0,${R3} + umlal2 $ACC4,$IN23_0,${R4} + umlal2 $ACC1,$IN23_0,${R1} + + dup $IN23_3,${IN23_3}[0] + umlal2 $ACC0,$IN23_1,${S4} + umlal2 $ACC3,$IN23_1,${R2} + umlal2 $ACC2,$IN23_1,${R1} + umlal2 $ACC4,$IN23_1,${R3} + umlal2 $ACC1,$IN23_1,${R0} + + dup $IN23_4,${IN23_4}[0] + umlal2 $ACC3,$IN23_3,${R0} + umlal2 $ACC4,$IN23_3,${R1} + umlal2 $ACC0,$IN23_3,${S2} + umlal2 $ACC1,$IN23_3,${S3} + umlal2 $ACC2,$IN23_3,${S4} + + umlal2 $ACC3,$IN23_4,${S4} + umlal2 $ACC0,$IN23_4,${S1} + umlal2 $ACC4,$IN23_4,${R0} + umlal2 $ACC1,$IN23_4,${S2} + umlal2 $ACC2,$IN23_4,${S3} + + b.eq .Lshort_tail + + //////////////////////////////////////////////////////////////// + // (hash+inp[0:1])*r^4:r^3 and accumulate + + add $IN01_0,$IN01_0,$H0 + umlal $ACC3,$IN01_2,${R1} + umlal $ACC0,$IN01_2,${S3} + umlal $ACC4,$IN01_2,${R2} + umlal $ACC1,$IN01_2,${S4} + umlal $ACC2,$IN01_2,${R0} + + add $IN01_1,$IN01_1,$H1 + umlal $ACC3,$IN01_0,${R3} + umlal $ACC0,$IN01_0,${R0} + umlal $ACC4,$IN01_0,${R4} + umlal $ACC1,$IN01_0,${R1} + umlal $ACC2,$IN01_0,${R2} + + add $IN01_3,$IN01_3,$H3 + umlal $ACC3,$IN01_1,${R2} + umlal $ACC0,$IN01_1,${S4} + umlal $ACC4,$IN01_1,${R3} + umlal $ACC1,$IN01_1,${R0} + umlal $ACC2,$IN01_1,${R1} + + add $IN01_4,$IN01_4,$H4 + umlal $ACC3,$IN01_3,${R0} + umlal $ACC0,$IN01_3,${S2} + umlal $ACC4,$IN01_3,${R1} + umlal $ACC1,$IN01_3,${S3} + umlal $ACC2,$IN01_3,${S4} + + umlal $ACC3,$IN01_4,${S4} + umlal $ACC0,$IN01_4,${S1} + umlal $ACC4,$IN01_4,${R0} + umlal $ACC1,$IN01_4,${S2} + umlal $ACC2,$IN01_4,${S3} + +.Lshort_tail: + //////////////////////////////////////////////////////////////// + // horizontal add + + addp $ACC3,$ACC3,$ACC3 + ldp d8,d9,[sp,#16] // meet ABI requirements + addp $ACC0,$ACC0,$ACC0 + ldp d10,d11,[sp,#32] + addp $ACC4,$ACC4,$ACC4 + ldp d12,d13,[sp,#48] + addp $ACC1,$ACC1,$ACC1 + ldp d14,d15,[sp,#64] + addp $ACC2,$ACC2,$ACC2 + + //////////////////////////////////////////////////////////////// + // lazy reduction, but without narrowing + + ushr $T0.2d,$ACC3,#26 + and $ACC3,$ACC3,$MASK.2d + ushr $T1.2d,$ACC0,#26 + and $ACC0,$ACC0,$MASK.2d + + add $ACC4,$ACC4,$T0.2d // h3 -> h4 + add $ACC1,$ACC1,$T1.2d // h0 -> h1 + + ushr $T0.2d,$ACC4,#26 + and $ACC4,$ACC4,$MASK.2d + ushr $T1.2d,$ACC1,#26 + and $ACC1,$ACC1,$MASK.2d + add $ACC2,$ACC2,$T1.2d // h1 -> h2 + + add $ACC0,$ACC0,$T0.2d + shl $T0.2d,$T0.2d,#2 + ushr $T1.2d,$ACC2,#26 + and $ACC2,$ACC2,$MASK.2d + add $ACC0,$ACC0,$T0.2d // h4 -> h0 + add $ACC3,$ACC3,$T1.2d // h2 -> h3 + + ushr $T0.2d,$ACC0,#26 + and $ACC0,$ACC0,$MASK.2d + ushr $T1.2d,$ACC3,#26 + and $ACC3,$ACC3,$MASK.2d + add $ACC1,$ACC1,$T0.2d // h0 -> h1 + add $ACC4,$ACC4,$T1.2d // h3 -> h4 + + //////////////////////////////////////////////////////////////// + // write the result, can be partially reduced + + st4 {$ACC0,$ACC1,$ACC2,$ACC3}[0],[$ctx],#16 + st1 {$ACC4}[0],[$ctx] + +.Lno_data_neon: + ldr x29,[sp],#80 + ret +.size poly1305_blocks_neon,.-poly1305_blocks_neon + +.type poly1305_emit_neon,%function +.align 5 +poly1305_emit_neon: + ldr $is_base2_26,[$ctx,#24] + cbz $is_base2_26,poly1305_emit + + ldp w10,w11,[$ctx] // load hash value base 2^26 + ldp w12,w13,[$ctx,#8] + ldr w14,[$ctx,#16] + + add $h0,x10,x11,lsl#26 // base 2^26 -> base 2^64 + lsr $h1,x12,#12 + adds $h0,$h0,x12,lsl#52 + add $h1,$h1,x13,lsl#14 + adc $h1,$h1,xzr + lsr $h2,x14,#24 + adds $h1,$h1,x14,lsl#40 + adc $h2,$h2,xzr // can be partially reduced... + + ldp $t0,$t1,[$nonce] // load nonce + + and $d0,$h2,#-4 // ... so reduce + add $d0,$d0,$h2,lsr#2 + and $h2,$h2,#3 + adds $h0,$h0,$d0 + adcs $h1,$h1,xzr + adc $h2,$h2,xzr + + adds $d0,$h0,#5 // compare to modulus + adcs $d1,$h1,xzr + adc $d2,$h2,xzr + + tst $d2,#-4 // see if it's carried/borrowed + + csel $h0,$h0,$d0,eq + csel $h1,$h1,$d1,eq + +#ifdef __ARMEB__ + ror $t0,$t0,#32 // flip nonce words + ror $t1,$t1,#32 +#endif + adds $h0,$h0,$t0 // accumulate nonce + adc $h1,$h1,$t1 +#ifdef __ARMEB__ + rev $h0,$h0 // flip output bytes + rev $h1,$h1 +#endif + stp $h0,$h1,[$mac] // write result + + ret +.size poly1305_emit_neon,.-poly1305_emit_neon + +.align 5 +.Lzeros: +.long 0,0,0,0,0,0,0,0 +.LOPENSSL_armcap_P: +#ifdef __ILP32__ +.long OPENSSL_armcap_P-. +#else +.quad OPENSSL_armcap_P-. +#endif +.asciz "Poly1305 for ARMv8, CRYPTOGAMS by " +.align 2 +___ + +foreach (split("\n",$code)) { + s/\b(shrn\s+v[0-9]+)\.[24]d/$1.2s/ or + s/\b(fmov\s+)v([0-9]+)[^,]*,\s*x([0-9]+)/$1d$2,x$3/ or + (m/\bdup\b/ and (s/\.[24]s/.2d/g or 1)) or + (m/\b(eor|and)/ and (s/\.[248][sdh]/.16b/g or 1)) or + (m/\bum(ul|la)l\b/ and (s/\.4s/.2s/g or 1)) or + (m/\bum(ul|la)l2\b/ and (s/\.2s/.4s/g or 1)) or + (m/\bst[1-4]\s+{[^}]+}\[/ and (s/\.[24]d/.s/g or 1)); + + s/\.[124]([sd])\[/.$1\[/; + + print $_,"\n"; +} +close STDOUT; diff --git a/crypto/poly1305/poly1305-arm64.s b/crypto/poly1305/poly1305-arm64.s new file mode 100644 index 0000000..911b57e --- /dev/null +++ b/crypto/poly1305/poly1305-arm64.s @@ -0,0 +1,820 @@ +/* SPDX-License-Identifier: OpenSSL OR (BSD-3-Clause OR GPL-2.0) + * + * Copyright (C) 2015-2018 Jason A. Donenfeld . All Rights Reserved. + * Copyright 2016 The OpenSSL Project Authors. All Rights Reserved. + */ + +#include +.text + +.align 5 +ENTRY(poly1305_init_arm) + cmp x1,xzr + stp xzr,xzr,[x0] // zero hash value + stp xzr,xzr,[x0,#16] // [along with is_base2_26] + + csel x0,xzr,x0,eq + b.eq .Lno_key + + ldp x7,x8,[x1] // load key + mov x9,#0xfffffffc0fffffff + movk x9,#0x0fff,lsl#48 +#ifdef __ARMEB__ + rev x7,x7 // flip bytes + rev x8,x8 +#endif + and x7,x7,x9 // &=0ffffffc0fffffff + and x9,x9,#-4 + and x8,x8,x9 // &=0ffffffc0ffffffc + stp x7,x8,[x0,#32] // save key value + +.Lno_key: + ret +ENDPROC(poly1305_init_arm) + +.align 5 +ENTRY(poly1305_blocks_arm) + ands x2,x2,#-16 + b.eq .Lno_data + + ldp x4,x5,[x0] // load hash value + ldp x7,x8,[x0,#32] // load key value + ldr x6,[x0,#16] + add x9,x8,x8,lsr#2 // s1 = r1 + (r1 >> 2) + b .Loop + +.align 5 +.Loop: + ldp x10,x11,[x1],#16 // load input + sub x2,x2,#16 +#ifdef __ARMEB__ + rev x10,x10 + rev x11,x11 +#endif + adds x4,x4,x10 // accumulate input + adcs x5,x5,x11 + + mul x12,x4,x7 // h0*r0 + adc x6,x6,x3 + umulh x13,x4,x7 + + mul x10,x5,x9 // h1*5*r1 + umulh x11,x5,x9 + + adds x12,x12,x10 + mul x10,x4,x8 // h0*r1 + adc x13,x13,x11 + umulh x14,x4,x8 + + adds x13,x13,x10 + mul x10,x5,x7 // h1*r0 + adc x14,x14,xzr + umulh x11,x5,x7 + + adds x13,x13,x10 + mul x10,x6,x9 // h2*5*r1 + adc x14,x14,x11 + mul x11,x6,x7 // h2*r0 + + adds x13,x13,x10 + adc x14,x14,x11 + + and x10,x14,#-4 // final reduction + and x6,x14,#3 + add x10,x10,x14,lsr#2 + adds x4,x12,x10 + adcs x5,x13,xzr + adc x6,x6,xzr + + cbnz x2,.Loop + + stp x4,x5,[x0] // store hash value + str x6,[x0,#16] + +.Lno_data: + ret +ENDPROC(poly1305_blocks_arm) + +.align 5 +ENTRY(poly1305_emit_arm) + ldp x4,x5,[x0] // load hash base 2^64 + ldr x6,[x0,#16] + ldp x10,x11,[x2] // load nonce + + adds x12,x4,#5 // compare to modulus + adcs x13,x5,xzr + adc x14,x6,xzr + + tst x14,#-4 // see if it's carried/borrowed + + csel x4,x4,x12,eq + csel x5,x5,x13,eq + +#ifdef __ARMEB__ + ror x10,x10,#32 // flip nonce words + ror x11,x11,#32 +#endif + adds x4,x4,x10 // accumulate nonce + adc x5,x5,x11 +#ifdef __ARMEB__ + rev x4,x4 // flip output bytes + rev x5,x5 +#endif + stp x4,x5,[x1] // write result + + ret +ENDPROC(poly1305_emit_arm) + +.align 5 +__poly1305_mult: + mul x12,x4,x7 // h0*r0 + umulh x13,x4,x7 + + mul x10,x5,x9 // h1*5*r1 + umulh x11,x5,x9 + + adds x12,x12,x10 + mul x10,x4,x8 // h0*r1 + adc x13,x13,x11 + umulh x14,x4,x8 + + adds x13,x13,x10 + mul x10,x5,x7 // h1*r0 + adc x14,x14,xzr + umulh x11,x5,x7 + + adds x13,x13,x10 + mul x10,x6,x9 // h2*5*r1 + adc x14,x14,x11 + mul x11,x6,x7 // h2*r0 + + adds x13,x13,x10 + adc x14,x14,x11 + + and x10,x14,#-4 // final reduction + and x6,x14,#3 + add x10,x10,x14,lsr#2 + adds x4,x12,x10 + adcs x5,x13,xzr + adc x6,x6,xzr + + ret + +__poly1305_splat: + and x12,x4,#0x03ffffff // base 2^64 -> base 2^26 + ubfx x13,x4,#26,#26 + extr x14,x5,x4,#52 + and x14,x14,#0x03ffffff + ubfx x15,x5,#14,#26 + extr x16,x6,x5,#40 + + str w12,[x0,#16*0] // r0 + add w12,w13,w13,lsl#2 // r1*5 + str w13,[x0,#16*1] // r1 + add w13,w14,w14,lsl#2 // r2*5 + str w12,[x0,#16*2] // s1 + str w14,[x0,#16*3] // r2 + add w14,w15,w15,lsl#2 // r3*5 + str w13,[x0,#16*4] // s2 + str w15,[x0,#16*5] // r3 + add w15,w16,w16,lsl#2 // r4*5 + str w14,[x0,#16*6] // s3 + str w16,[x0,#16*7] // r4 + str w15,[x0,#16*8] // s4 + + ret + +.align 5 +ENTRY(poly1305_blocks_neon) + ldr x17,[x0,#24] + cmp x2,#128 + b.hs .Lblocks_neon + cbz x17,poly1305_blocks_arm + +.Lblocks_neon: + stp x29,x30,[sp,#-80]! + add x29,sp,#0 + + ands x2,x2,#-16 + b.eq .Lno_data_neon + + cbz x17,.Lbase2_64_neon + + ldp w10,w11,[x0] // load hash value base 2^26 + ldp w12,w13,[x0,#8] + ldr w14,[x0,#16] + + tst x2,#31 + b.eq .Leven_neon + + ldp x7,x8,[x0,#32] // load key value + + add x4,x10,x11,lsl#26 // base 2^26 -> base 2^64 + lsr x5,x12,#12 + adds x4,x4,x12,lsl#52 + add x5,x5,x13,lsl#14 + adc x5,x5,xzr + lsr x6,x14,#24 + adds x5,x5,x14,lsl#40 + adc x14,x6,xzr // can be partially reduced... + + ldp x12,x13,[x1],#16 // load input + sub x2,x2,#16 + add x9,x8,x8,lsr#2 // s1 = r1 + (r1 >> 2) + + and x10,x14,#-4 // ... so reduce + and x6,x14,#3 + add x10,x10,x14,lsr#2 + adds x4,x4,x10 + adcs x5,x5,xzr + adc x6,x6,xzr + +#ifdef __ARMEB__ + rev x12,x12 + rev x13,x13 +#endif + adds x4,x4,x12 // accumulate input + adcs x5,x5,x13 + adc x6,x6,x3 + + bl __poly1305_mult + ldr x30,[sp,#8] + + cbz x3,.Lstore_base2_64_neon + + and x10,x4,#0x03ffffff // base 2^64 -> base 2^26 + ubfx x11,x4,#26,#26 + extr x12,x5,x4,#52 + and x12,x12,#0x03ffffff + ubfx x13,x5,#14,#26 + extr x14,x6,x5,#40 + + cbnz x2,.Leven_neon + + stp w10,w11,[x0] // store hash value base 2^26 + stp w12,w13,[x0,#8] + str w14,[x0,#16] + b .Lno_data_neon + +.align 4 +.Lstore_base2_64_neon: + stp x4,x5,[x0] // store hash value base 2^64 + stp x6,xzr,[x0,#16] // note that is_base2_26 is zeroed + b .Lno_data_neon + +.align 4 +.Lbase2_64_neon: + ldp x7,x8,[x0,#32] // load key value + + ldp x4,x5,[x0] // load hash value base 2^64 + ldr x6,[x0,#16] + + tst x2,#31 + b.eq .Linit_neon + + ldp x12,x13,[x1],#16 // load input + sub x2,x2,#16 + add x9,x8,x8,lsr#2 // s1 = r1 + (r1 >> 2) +#ifdef __ARMEB__ + rev x12,x12 + rev x13,x13 +#endif + adds x4,x4,x12 // accumulate input + adcs x5,x5,x13 + adc x6,x6,x3 + + bl __poly1305_mult + +.Linit_neon: + and x10,x4,#0x03ffffff // base 2^64 -> base 2^26 + ubfx x11,x4,#26,#26 + extr x12,x5,x4,#52 + and x12,x12,#0x03ffffff + ubfx x13,x5,#14,#26 + extr x14,x6,x5,#40 + + stp d8,d9,[sp,#16] // meet ABI requirements + stp d10,d11,[sp,#32] + stp d12,d13,[sp,#48] + stp d14,d15,[sp,#64] + + fmov d24,x10 + fmov d25,x11 + fmov d26,x12 + fmov d27,x13 + fmov d28,x14 + + ////////////////////////////////// initialize r^n table + mov x4,x7 // r^1 + add x9,x8,x8,lsr#2 // s1 = r1 + (r1 >> 2) + mov x5,x8 + mov x6,xzr + add x0,x0,#48+12 + bl __poly1305_splat + + bl __poly1305_mult // r^2 + sub x0,x0,#4 + bl __poly1305_splat + + bl __poly1305_mult // r^3 + sub x0,x0,#4 + bl __poly1305_splat + + bl __poly1305_mult // r^4 + sub x0,x0,#4 + bl __poly1305_splat + ldr x30,[sp,#8] + + add x16,x1,#32 + adr x17,.Lzeros + subs x2,x2,#64 + csel x16,x17,x16,lo + + mov x4,#1 + str x4,[x0,#-24] // set is_base2_26 + sub x0,x0,#48 // restore original x0 + b .Ldo_neon + +.align 4 +.Leven_neon: + add x16,x1,#32 + adr x17,.Lzeros + subs x2,x2,#64 + csel x16,x17,x16,lo + + stp d8,d9,[sp,#16] // meet ABI requirements + stp d10,d11,[sp,#32] + stp d12,d13,[sp,#48] + stp d14,d15,[sp,#64] + + fmov d24,x10 + fmov d25,x11 + fmov d26,x12 + fmov d27,x13 + fmov d28,x14 + +.Ldo_neon: + ldp x8,x12,[x16],#16 // inp[2:3] (or zero) + ldp x9,x13,[x16],#48 + + lsl x3,x3,#24 + add x15,x0,#48 + +#ifdef __ARMEB__ + rev x8,x8 + rev x12,x12 + rev x9,x9 + rev x13,x13 +#endif + and x4,x8,#0x03ffffff // base 2^64 -> base 2^26 + and x5,x9,#0x03ffffff + ubfx x6,x8,#26,#26 + ubfx x7,x9,#26,#26 + add x4,x4,x5,lsl#32 // bfi x4,x5,#32,#32 + extr x8,x12,x8,#52 + extr x9,x13,x9,#52 + add x6,x6,x7,lsl#32 // bfi x6,x7,#32,#32 + fmov d14,x4 + and x8,x8,#0x03ffffff + and x9,x9,#0x03ffffff + ubfx x10,x12,#14,#26 + ubfx x11,x13,#14,#26 + add x12,x3,x12,lsr#40 + add x13,x3,x13,lsr#40 + add x8,x8,x9,lsl#32 // bfi x8,x9,#32,#32 + fmov d15,x6 + add x10,x10,x11,lsl#32 // bfi x10,x11,#32,#32 + add x12,x12,x13,lsl#32 // bfi x12,x13,#32,#32 + fmov d16,x8 + fmov d17,x10 + fmov d18,x12 + + ldp x8,x12,[x1],#16 // inp[0:1] + ldp x9,x13,[x1],#48 + + ld1 {v0.4s,v1.4s,v2.4s,v3.4s},[x15],#64 + ld1 {v4.4s,v5.4s,v6.4s,v7.4s},[x15],#64 + ld1 {v8.4s},[x15] + +#ifdef __ARMEB__ + rev x8,x8 + rev x12,x12 + rev x9,x9 + rev x13,x13 +#endif + and x4,x8,#0x03ffffff // base 2^64 -> base 2^26 + and x5,x9,#0x03ffffff + ubfx x6,x8,#26,#26 + ubfx x7,x9,#26,#26 + add x4,x4,x5,lsl#32 // bfi x4,x5,#32,#32 + extr x8,x12,x8,#52 + extr x9,x13,x9,#52 + add x6,x6,x7,lsl#32 // bfi x6,x7,#32,#32 + fmov d9,x4 + and x8,x8,#0x03ffffff + and x9,x9,#0x03ffffff + ubfx x10,x12,#14,#26 + ubfx x11,x13,#14,#26 + add x12,x3,x12,lsr#40 + add x13,x3,x13,lsr#40 + add x8,x8,x9,lsl#32 // bfi x8,x9,#32,#32 + fmov d10,x6 + add x10,x10,x11,lsl#32 // bfi x10,x11,#32,#32 + add x12,x12,x13,lsl#32 // bfi x12,x13,#32,#32 + movi v31.2d,#-1 + fmov d11,x8 + fmov d12,x10 + fmov d13,x12 + ushr v31.2d,v31.2d,#38 + + b.ls .Lskip_loop + +.align 4 +.Loop_neon: + //////////////////////////////////////////////////////////////// + // ((inp[0]*r^4+inp[2]*r^2+inp[4])*r^4+inp[6]*r^2 + // ((inp[1]*r^4+inp[3]*r^2+inp[5])*r^3+inp[7]*r + // ___________________/ + // ((inp[0]*r^4+inp[2]*r^2+inp[4])*r^4+inp[6]*r^2+inp[8])*r^2 + // ((inp[1]*r^4+inp[3]*r^2+inp[5])*r^4+inp[7]*r^2+inp[9])*r + // ___________________/ ____________________/ + // + // Note that we start with inp[2:3]*r^2. This is because it + // doesn't depend on reduction in previous iteration. + //////////////////////////////////////////////////////////////// + // d4 = h0*r4 + h1*r3 + h2*r2 + h3*r1 + h4*r0 + // d3 = h0*r3 + h1*r2 + h2*r1 + h3*r0 + h4*5*r4 + // d2 = h0*r2 + h1*r1 + h2*r0 + h3*5*r4 + h4*5*r3 + // d1 = h0*r1 + h1*r0 + h2*5*r4 + h3*5*r3 + h4*5*r2 + // d0 = h0*r0 + h1*5*r4 + h2*5*r3 + h3*5*r2 + h4*5*r1 + + subs x2,x2,#64 + umull v23.2d,v14.2s,v7.s[2] + csel x16,x17,x16,lo + umull v22.2d,v14.2s,v5.s[2] + umull v21.2d,v14.2s,v3.s[2] + ldp x8,x12,[x16],#16 // inp[2:3] (or zero) + umull v20.2d,v14.2s,v1.s[2] + ldp x9,x13,[x16],#48 + umull v19.2d,v14.2s,v0.s[2] +#ifdef __ARMEB__ + rev x8,x8 + rev x12,x12 + rev x9,x9 + rev x13,x13 +#endif + + umlal v23.2d,v15.2s,v5.s[2] + and x4,x8,#0x03ffffff // base 2^64 -> base 2^26 + umlal v22.2d,v15.2s,v3.s[2] + and x5,x9,#0x03ffffff + umlal v21.2d,v15.2s,v1.s[2] + ubfx x6,x8,#26,#26 + umlal v20.2d,v15.2s,v0.s[2] + ubfx x7,x9,#26,#26 + umlal v19.2d,v15.2s,v8.s[2] + add x4,x4,x5,lsl#32 // bfi x4,x5,#32,#32 + + umlal v23.2d,v16.2s,v3.s[2] + extr x8,x12,x8,#52 + umlal v22.2d,v16.2s,v1.s[2] + extr x9,x13,x9,#52 + umlal v21.2d,v16.2s,v0.s[2] + add x6,x6,x7,lsl#32 // bfi x6,x7,#32,#32 + umlal v20.2d,v16.2s,v8.s[2] + fmov d14,x4 + umlal v19.2d,v16.2s,v6.s[2] + and x8,x8,#0x03ffffff + + umlal v23.2d,v17.2s,v1.s[2] + and x9,x9,#0x03ffffff + umlal v22.2d,v17.2s,v0.s[2] + ubfx x10,x12,#14,#26 + umlal v21.2d,v17.2s,v8.s[2] + ubfx x11,x13,#14,#26 + umlal v20.2d,v17.2s,v6.s[2] + add x8,x8,x9,lsl#32 // bfi x8,x9,#32,#32 + umlal v19.2d,v17.2s,v4.s[2] + fmov d15,x6 + + add v11.2s,v11.2s,v26.2s + add x12,x3,x12,lsr#40 + umlal v23.2d,v18.2s,v0.s[2] + add x13,x3,x13,lsr#40 + umlal v22.2d,v18.2s,v8.s[2] + add x10,x10,x11,lsl#32 // bfi x10,x11,#32,#32 + umlal v21.2d,v18.2s,v6.s[2] + add x12,x12,x13,lsl#32 // bfi x12,x13,#32,#32 + umlal v20.2d,v18.2s,v4.s[2] + fmov d16,x8 + umlal v19.2d,v18.2s,v2.s[2] + fmov d17,x10 + + //////////////////////////////////////////////////////////////// + // (hash+inp[0:1])*r^4 and accumulate + + add v9.2s,v9.2s,v24.2s + fmov d18,x12 + umlal v22.2d,v11.2s,v1.s[0] + ldp x8,x12,[x1],#16 // inp[0:1] + umlal v19.2d,v11.2s,v6.s[0] + ldp x9,x13,[x1],#48 + umlal v23.2d,v11.2s,v3.s[0] + umlal v20.2d,v11.2s,v8.s[0] + umlal v21.2d,v11.2s,v0.s[0] +#ifdef __ARMEB__ + rev x8,x8 + rev x12,x12 + rev x9,x9 + rev x13,x13 +#endif + + add v10.2s,v10.2s,v25.2s + umlal v22.2d,v9.2s,v5.s[0] + umlal v23.2d,v9.2s,v7.s[0] + and x4,x8,#0x03ffffff // base 2^64 -> base 2^26 + umlal v21.2d,v9.2s,v3.s[0] + and x5,x9,#0x03ffffff + umlal v19.2d,v9.2s,v0.s[0] + ubfx x6,x8,#26,#26 + umlal v20.2d,v9.2s,v1.s[0] + ubfx x7,x9,#26,#26 + + add v12.2s,v12.2s,v27.2s + add x4,x4,x5,lsl#32 // bfi x4,x5,#32,#32 + umlal v22.2d,v10.2s,v3.s[0] + extr x8,x12,x8,#52 + umlal v23.2d,v10.2s,v5.s[0] + extr x9,x13,x9,#52 + umlal v19.2d,v10.2s,v8.s[0] + add x6,x6,x7,lsl#32 // bfi x6,x7,#32,#32 + umlal v21.2d,v10.2s,v1.s[0] + fmov d9,x4 + umlal v20.2d,v10.2s,v0.s[0] + and x8,x8,#0x03ffffff + + add v13.2s,v13.2s,v28.2s + and x9,x9,#0x03ffffff + umlal v22.2d,v12.2s,v0.s[0] + ubfx x10,x12,#14,#26 + umlal v19.2d,v12.2s,v4.s[0] + ubfx x11,x13,#14,#26 + umlal v23.2d,v12.2s,v1.s[0] + add x8,x8,x9,lsl#32 // bfi x8,x9,#32,#32 + umlal v20.2d,v12.2s,v6.s[0] + fmov d10,x6 + umlal v21.2d,v12.2s,v8.s[0] + add x12,x3,x12,lsr#40 + + umlal v22.2d,v13.2s,v8.s[0] + add x13,x3,x13,lsr#40 + umlal v19.2d,v13.2s,v2.s[0] + add x10,x10,x11,lsl#32 // bfi x10,x11,#32,#32 + umlal v23.2d,v13.2s,v0.s[0] + add x12,x12,x13,lsl#32 // bfi x12,x13,#32,#32 + umlal v20.2d,v13.2s,v4.s[0] + fmov d11,x8 + umlal v21.2d,v13.2s,v6.s[0] + fmov d12,x10 + fmov d13,x12 + + ///////////////////////////////////////////////////////////////// + // lazy reduction as discussed in "NEON crypto" by D.J. Bernstein + // and P. Schwabe + // + // [see discussion in poly1305-armv4 module] + + ushr v29.2d,v22.2d,#26 + xtn v27.2s,v22.2d + ushr v30.2d,v19.2d,#26 + and v19.16b,v19.16b,v31.16b + add v23.2d,v23.2d,v29.2d // h3 -> h4 + bic v27.2s,#0xfc,lsl#24 // &=0x03ffffff + add v20.2d,v20.2d,v30.2d // h0 -> h1 + + ushr v29.2d,v23.2d,#26 + xtn v28.2s,v23.2d + ushr v30.2d,v20.2d,#26 + xtn v25.2s,v20.2d + bic v28.2s,#0xfc,lsl#24 + add v21.2d,v21.2d,v30.2d // h1 -> h2 + + add v19.2d,v19.2d,v29.2d + shl v29.2d,v29.2d,#2 + shrn v30.2s,v21.2d,#26 + xtn v26.2s,v21.2d + add v19.2d,v19.2d,v29.2d // h4 -> h0 + bic v25.2s,#0xfc,lsl#24 + add v27.2s,v27.2s,v30.2s // h2 -> h3 + bic v26.2s,#0xfc,lsl#24 + + shrn v29.2s,v19.2d,#26 + xtn v24.2s,v19.2d + ushr v30.2s,v27.2s,#26 + bic v27.2s,#0xfc,lsl#24 + bic v24.2s,#0xfc,lsl#24 + add v25.2s,v25.2s,v29.2s // h0 -> h1 + add v28.2s,v28.2s,v30.2s // h3 -> h4 + + b.hi .Loop_neon + +.Lskip_loop: + dup v16.2d,v16.d[0] + add v11.2s,v11.2s,v26.2s + + //////////////////////////////////////////////////////////////// + // multiply (inp[0:1]+hash) or inp[2:3] by r^2:r^1 + + adds x2,x2,#32 + b.ne .Long_tail + + dup v16.2d,v11.d[0] + add v14.2s,v9.2s,v24.2s + add v17.2s,v12.2s,v27.2s + add v15.2s,v10.2s,v25.2s + add v18.2s,v13.2s,v28.2s + +.Long_tail: + dup v14.2d,v14.d[0] + umull2 v19.2d,v16.4s,v6.4s + umull2 v22.2d,v16.4s,v1.4s + umull2 v23.2d,v16.4s,v3.4s + umull2 v21.2d,v16.4s,v0.4s + umull2 v20.2d,v16.4s,v8.4s + + dup v15.2d,v15.d[0] + umlal2 v19.2d,v14.4s,v0.4s + umlal2 v21.2d,v14.4s,v3.4s + umlal2 v22.2d,v14.4s,v5.4s + umlal2 v23.2d,v14.4s,v7.4s + umlal2 v20.2d,v14.4s,v1.4s + + dup v17.2d,v17.d[0] + umlal2 v19.2d,v15.4s,v8.4s + umlal2 v22.2d,v15.4s,v3.4s + umlal2 v21.2d,v15.4s,v1.4s + umlal2 v23.2d,v15.4s,v5.4s + umlal2 v20.2d,v15.4s,v0.4s + + dup v18.2d,v18.d[0] + umlal2 v22.2d,v17.4s,v0.4s + umlal2 v23.2d,v17.4s,v1.4s + umlal2 v19.2d,v17.4s,v4.4s + umlal2 v20.2d,v17.4s,v6.4s + umlal2 v21.2d,v17.4s,v8.4s + + umlal2 v22.2d,v18.4s,v8.4s + umlal2 v19.2d,v18.4s,v2.4s + umlal2 v23.2d,v18.4s,v0.4s + umlal2 v20.2d,v18.4s,v4.4s + umlal2 v21.2d,v18.4s,v6.4s + + b.eq .Lshort_tail + + //////////////////////////////////////////////////////////////// + // (hash+inp[0:1])*r^4:r^3 and accumulate + + add v9.2s,v9.2s,v24.2s + umlal v22.2d,v11.2s,v1.2s + umlal v19.2d,v11.2s,v6.2s + umlal v23.2d,v11.2s,v3.2s + umlal v20.2d,v11.2s,v8.2s + umlal v21.2d,v11.2s,v0.2s + + add v10.2s,v10.2s,v25.2s + umlal v22.2d,v9.2s,v5.2s + umlal v19.2d,v9.2s,v0.2s + umlal v23.2d,v9.2s,v7.2s + umlal v20.2d,v9.2s,v1.2s + umlal v21.2d,v9.2s,v3.2s + + add v12.2s,v12.2s,v27.2s + umlal v22.2d,v10.2s,v3.2s + umlal v19.2d,v10.2s,v8.2s + umlal v23.2d,v10.2s,v5.2s + umlal v20.2d,v10.2s,v0.2s + umlal v21.2d,v10.2s,v1.2s + + add v13.2s,v13.2s,v28.2s + umlal v22.2d,v12.2s,v0.2s + umlal v19.2d,v12.2s,v4.2s + umlal v23.2d,v12.2s,v1.2s + umlal v20.2d,v12.2s,v6.2s + umlal v21.2d,v12.2s,v8.2s + + umlal v22.2d,v13.2s,v8.2s + umlal v19.2d,v13.2s,v2.2s + umlal v23.2d,v13.2s,v0.2s + umlal v20.2d,v13.2s,v4.2s + umlal v21.2d,v13.2s,v6.2s + +.Lshort_tail: + //////////////////////////////////////////////////////////////// + // horizontal add + + addp v22.2d,v22.2d,v22.2d + ldp d8,d9,[sp,#16] // meet ABI requirements + addp v19.2d,v19.2d,v19.2d + ldp d10,d11,[sp,#32] + addp v23.2d,v23.2d,v23.2d + ldp d12,d13,[sp,#48] + addp v20.2d,v20.2d,v20.2d + ldp d14,d15,[sp,#64] + addp v21.2d,v21.2d,v21.2d + + //////////////////////////////////////////////////////////////// + // lazy reduction, but without narrowing + + ushr v29.2d,v22.2d,#26 + and v22.16b,v22.16b,v31.16b + ushr v30.2d,v19.2d,#26 + and v19.16b,v19.16b,v31.16b + + add v23.2d,v23.2d,v29.2d // h3 -> h4 + add v20.2d,v20.2d,v30.2d // h0 -> h1 + + ushr v29.2d,v23.2d,#26 + and v23.16b,v23.16b,v31.16b + ushr v30.2d,v20.2d,#26 + and v20.16b,v20.16b,v31.16b + add v21.2d,v21.2d,v30.2d // h1 -> h2 + + add v19.2d,v19.2d,v29.2d + shl v29.2d,v29.2d,#2 + ushr v30.2d,v21.2d,#26 + and v21.16b,v21.16b,v31.16b + add v19.2d,v19.2d,v29.2d // h4 -> h0 + add v22.2d,v22.2d,v30.2d // h2 -> h3 + + ushr v29.2d,v19.2d,#26 + and v19.16b,v19.16b,v31.16b + ushr v30.2d,v22.2d,#26 + and v22.16b,v22.16b,v31.16b + add v20.2d,v20.2d,v29.2d // h0 -> h1 + add v23.2d,v23.2d,v30.2d // h3 -> h4 + + //////////////////////////////////////////////////////////////// + // write the result, can be partially reduced + + st4 {v19.s,v20.s,v21.s,v22.s}[0],[x0],#16 + st1 {v23.s}[0],[x0] + +.Lno_data_neon: + ldr x29,[sp],#80 + ret +ENDPROC(poly1305_blocks_neon) + +.align 5 +ENTRY(poly1305_emit_neon) + ldr x17,[x0,#24] + cbz x17,poly1305_emit_arm + + ldp w10,w11,[x0] // load hash value base 2^26 + ldp w12,w13,[x0,#8] + ldr w14,[x0,#16] + + add x4,x10,x11,lsl#26 // base 2^26 -> base 2^64 + lsr x5,x12,#12 + adds x4,x4,x12,lsl#52 + add x5,x5,x13,lsl#14 + adc x5,x5,xzr + lsr x6,x14,#24 + adds x5,x5,x14,lsl#40 + adc x6,x6,xzr // can be partially reduced... + + ldp x10,x11,[x2] // load nonce + + and x12,x6,#-4 // ... so reduce + add x12,x12,x6,lsr#2 + and x6,x6,#3 + adds x4,x4,x12 + adcs x5,x5,xzr + adc x6,x6,xzr + + adds x12,x4,#5 // compare to modulus + adcs x13,x5,xzr + adc x14,x6,xzr + + tst x14,#-4 // see if it's carried/borrowed + + csel x4,x4,x12,eq + csel x5,x5,x13,eq + +#ifdef __ARMEB__ + ror x10,x10,#32 // flip nonce words + ror x11,x11,#32 +#endif + adds x4,x4,x10 // accumulate nonce + adc x5,x5,x11 +#ifdef __ARMEB__ + rev x4,x4 // flip output bytes + rev x5,x5 +#endif + stp x4,x5,[x1] // write result + + ret +ENDPROC(poly1305_emit_neon) + +.align 5 +.Lzeros: +.long 0,0,0,0,0,0,0,0 diff --git a/crypto/poly1305_x64_gas.s b/crypto/poly1305_x64_gas.s old mode 100644 new mode 100755 diff --git a/crypto_ops.h b/crypto_ops.h index 4c72280..09b598f 100644 --- a/crypto_ops.h +++ b/crypto_ops.h @@ -7,6 +7,7 @@ #include "tunsafe_types.h" #include + #if defined(COMPILER_MSVC) #include #endif // defined(COMPILER_MSVC) diff --git a/downarrow.bmp b/downarrow.bmp new file mode 100644 index 0000000..0237044 Binary files /dev/null and b/downarrow.bmp differ diff --git a/installer/ChangeLog.txt b/installer/ChangeLog.txt index 3cb9202..22120bd 100644 --- a/installer/ChangeLog.txt +++ b/installer/ChangeLog.txt @@ -1,3 +1,26 @@ +2018-08-11 - TunSafe v1.4-rc1 +1.Subfolders in the Config/ directory now show up as submenus. +2.Added a way to run TunSafe as a Windows Service. + Foreground Mode: The service will disconnect when TunSafe closes. + Background Mode: The service will stay connected in the background. + No longer required to run the TunSafe client as Admin as long as + the service is running. +3.New config setting [Interface].ExcludedIPs to configure IPs that + should not be routed through TunSafe. +4.Can now automatically start TunSafe when Windows starts +5.New UI with tabs and graphs +6.Cache DNS queries to ensure DNS will succeed if connection fails +7.Recreate tray icon when explorer.exe restarts +8.Renamed window title to TunSafe instead of TunSafe VPN Client +9.Main window is now resizable +10.Disallow roaming endpoint when using AllowedIPs=0.0.0.0/0 + Only the original endpoint is added in the routing table so + this would result in an endless loop of packets. +11.Display approximate Wireguard framing overhead in stats +12.Preparations for protocol handling with multiple threads +13.Delete the routes we made when disconnecting +14.Fix error message about unable to delete a route when connecting + 2018-06-20 - TunSafe v1.3-rc3 Changes: diff --git a/installer/servicelib.nsh b/installer/servicelib.nsh new file mode 100644 index 0000000..7796a58 --- /dev/null +++ b/installer/servicelib.nsh @@ -0,0 +1,419 @@ +; NSIS SERVICE LIBRARY - servicelib.nsh +; Version 1.8.1 - Jun 21th, 2013 +; Questions/Comments - dselkirk@hotmail.com +; +; Description: +; Provides an interface to window services +; +; Inputs: +; action - systemlib action ie. create, delete, start, stop, pause, +; continue, installed, running, status +; name - name of service to manipulate +; param - action parameters; usage: var1=value1;var2=value2;...etc. +; (don't forget to add a ';' after the last value!) +; +; Actions: +; create - creates a new windows service +; Parameters: +; path - path to service executable +; autostart - automatically start with system ie. 1|0 +; interact - interact with the desktop ie. 1|0 +; depend - service dependencies +; user - user that runs the service +; password - password of the above user +; display - display name in service's console +; description - Description of service +; starttype - start type (supersedes autostart) +; servicetype - service type (supersedes interact) +; +; delete - deletes a windows service +; start - start a stopped windows service +; stop - stops a running windows service +; pause - pauses a running windows service +; continue - continues a paused windows service +; installed - is the provided service installed +; Parameters: +; action - if true then invokes the specified action +; running - is the provided service running +; Parameters: +; action - if true then invokes the specified action +; status - check the status of the provided service +; +; Usage: +; Method 1: +; Push "action" +; Push "name" +; Push "param" +; Call Service +; Pop $0 ;response +; +; Method 2: +; !insertmacro SERVICE "action" "name" "param" +; +; History: +; 1.0 - 09/15/2003 - Initial release +; 1.1 - 09/16/2003 - Changed &l to i, thx brainsucker +; 1.2 - 02/29/2004 - Fixed documentation. +; 1.3 - 01/05/2006 - Fixed interactive flag and pop order (Kichik) +; 1.4 - 12/07/2006 - Added display and depend, fixed datatypes (Vitoco) +; 1.5 - 06/25/2008 - Added description of service.(DeSafe.com/liuqixing#gmail.com) +; 1.5.1 - 06/12/2009 - Added use of __UNINSTALL__ +; 1.6 - 08/02/2010 - Fixed description implementation (Anders) +; 1.7 - 04/11/2010 - Added get running service process id (Nico) +; 1.8 - 24/03/2011 - Added starttype and servicetype (Sergius) +; 1.8.1 - 21/06/2013 - Added dynamic ASCII & Unicode support (Zinthose) + +!ifndef SERVICELIB + !define SERVICELIB + + !define SC_MANAGER_ALL_ACCESS 0x3F + !define SC_STATUS_PROCESS_INFO 0x0 + !define SERVICE_ALL_ACCESS 0xF01FF + + !define SERVICE_CONTROL_STOP 1 + !define SERVICE_CONTROL_PAUSE 2 + !define SERVICE_CONTROL_CONTINUE 3 + + !define SERVICE_STOPPED 0x1 + !define SERVICE_START_PENDING 0x2 + !define SERVICE_STOP_PENDING 0x3 + !define SERVICE_RUNNING 0x4 + !define SERVICE_CONTINUE_PENDING 0x5 + !define SERVICE_PAUSE_PENDING 0x6 + !define SERVICE_PAUSED 0x7 + + !define SERVICE_KERNEL_DRIVER 0x00000001 + !define SERVICE_FILE_SYSTEM_DRIVER 0x00000002 + !define SERVICE_WIN32_OWN_PROCESS 0x00000010 + !define SERVICE_WIN32_SHARE_PROCESS 0x00000020 + !define SERVICE_INTERACTIVE_PROCESS 0x00000100 + + + !define SERVICE_BOOT_START 0x00000000 + !define SERVICE_SYSTEM_START 0x00000001 + !define SERVICE_AUTO_START 0x00000002 + !define SERVICE_DEMAND_START 0x00000003 + !define SERVICE_DISABLED 0x00000004 + + ## Added by Zinthose for Native Unicode Support + !ifdef NSIS_UNICODE + !define APITAG "W" + !else + !define APITAG "A" + !endif + + !macro SERVICE ACTION NAME PARAM + Push '${ACTION}' + Push '${NAME}' + Push '${PARAM}' + !ifdef __UNINSTALL__ + Call un.Service + !else + Call Service + !endif + !macroend + + !macro FUNC_GETPARAM + Push $0 + Push $1 + Push $2 + Push $3 + Push $4 + Push $5 + Push $6 + Push $7 + Exch 8 + Pop $1 ;name + Exch 8 + Pop $2 ;source + StrCpy $0 "" + StrLen $7 $2 + StrCpy $3 0 + lbl_loop: + IntCmp $3 $7 0 0 lbl_done + StrLen $4 "$1=" + StrCpy $5 $2 $4 $3 + StrCmp $5 "$1=" 0 lbl_next + IntOp $5 $3 + $4 + StrCpy $3 $5 + lbl_loop2: + IntCmp $3 $7 0 0 lbl_done + StrCpy $6 $2 1 $3 + StrCmp $6 ";" 0 lbl_next2 + IntOp $6 $3 - $5 + StrCpy $0 $2 $6 $5 + Goto lbl_done + lbl_next2: + IntOp $3 $3 + 1 + Goto lbl_loop2 + lbl_next: + IntOp $3 $3 + 1 + Goto lbl_loop + lbl_done: + Pop $5 + Pop $4 + Pop $3 + Pop $2 + Pop $1 + Exch 2 + Pop $6 + Pop $7 + Exch $0 + !macroend + + !macro CALL_GETPARAM VAR NAME DEFAULT LABEL + Push $1 + Push ${NAME} + Call ${UN}GETPARAM + Pop $6 + StrCpy ${VAR} "${DEFAULT}" + StrCmp $6 "" "${LABEL}" 0 + StrCpy ${VAR} $6 + !macroend + + !macro FUNC_SERVICE UN + Push $0 + Push $1 + Push $2 + Push $3 + Push $4 + Push $5 + Push $6 + Push $7 + Exch 8 + Pop $1 ;param + Exch 8 + Pop $2 ;name + Exch 8 + Pop $3 ;action + ;$0 return + ;$4 OpenSCManager + ;$5 OpenService + + StrCpy $0 "false" + System::Call 'advapi32::OpenSCManager${APITAG}(n, n, i ${SC_MANAGER_ALL_ACCESS}) i.r4' + IntCmp $4 0 lbl_done + StrCmp $3 "create" lbl_create + System::Call 'advapi32::OpenService${APITAG}(i r4, t r2, i ${SERVICE_ALL_ACCESS}) i.r5' + IntCmp $5 0 lbl_done + + lbl_select: + StrCmp $3 "delete" lbl_delete + StrCmp $3 "start" lbl_start + StrCmp $3 "stop" lbl_stop + StrCmp $3 "pause" lbl_pause + StrCmp $3 "continue" lbl_continue + StrCmp $3 "installed" lbl_installed + StrCmp $3 "running" lbl_running + StrCmp $3 "status" lbl_status + StrCmp $3 "processid" lbl_processid + Goto lbl_done + + ; create service + lbl_create: + Push $R1 ;depend + Push $R2 ;user + Push $R3 ;password + Push $R4 ;servicetype/interact + Push $R5 ;starttype/autostart + Push $R6 ;path + Push $R7 ;display + Push $R8 ;description + + !insertmacro CALL_GETPARAM $R1 "depend" "n" "lbl_depend" + StrCpy $R1 't "$R1"' + lbl_depend: + StrCmp $R1 "n" 0 lbl_machine ;old name of depend param + !insertmacro CALL_GETPARAM $R1 "machine" "n" "lbl_machine" + StrCpy $R1 't "$R1"' + lbl_machine: + + !insertmacro CALL_GETPARAM $R2 "user" "n" "lbl_user" + StrCpy $R2 't "$R2"' + lbl_user: + + !insertmacro CALL_GETPARAM $R3 "password" "n" "lbl_password" + StrCpy $R3 't "$R3"' + lbl_password: + + !insertmacro CALL_GETPARAM $R4 "interact" "${SERVICE_WIN32_OWN_PROCESS}" "lbl_interact" + StrCpy $6 ${SERVICE_WIN32_OWN_PROCESS} + IntCmp $R4 0 +2 + IntOp $6 $6 | ${SERVICE_INTERACTIVE_PROCESS} + StrCpy $R4 $6 + lbl_interact: + + !insertmacro CALL_GETPARAM $R4 "servicetype" "$R4" "lbl_servicetype" + lbl_servicetype: + + !insertmacro CALL_GETPARAM $R5 "autostart" "${SERVICE_DEMAND_START}" "lbl_autostart" + StrCpy $6 ${SERVICE_DEMAND_START} + IntCmp $R5 0 +2 + StrCpy $6 ${SERVICE_AUTO_START} + StrCpy $R5 $6 + lbl_autostart: + + !insertmacro CALL_GETPARAM $R5 "starttype" "$R5" "lbl_starttype" + lbl_starttype: + + !insertmacro CALL_GETPARAM $R6 "path" "n" "lbl_path" + lbl_path: + + !insertmacro CALL_GETPARAM $R7 "display" "$2" "lbl_display" + lbl_display: + + !insertmacro CALL_GETPARAM $R8 "description" "$2" "lbl_description" + lbl_description: + + System::Call 'advapi32::CreateService${APITAG}(i r4, t r2, t R7, i ${SERVICE_ALL_ACCESS}, \ + i R4, i R5, i 0, t R6, n, n, $R1, $R2, $R3) i.r6' + + ; write description of service (SERVICE_CONFIG_DESCRIPTION) + System::Call 'advapi32::ChangeServiceConfig2${APITAG}(ir6,i1,*t "$R8")i.R7' + strcmp $R7 "error" 0 lbl_descriptioncomplete + WriteRegStr HKLM "SYSTEM\CurrentControlSet\Services\$2" "Description" $R8 + lbl_descriptioncomplete: + + Pop $R8 + Pop $R7 + Pop $R6 + Pop $R5 + Pop $R4 + Pop $R3 + Pop $R2 + Pop $R1 + StrCmp $6 0 lbl_done lbl_good + + ; delete service + lbl_delete: + System::Call 'advapi32::DeleteService(i r5) i.r6' + StrCmp $6 0 lbl_done lbl_good + + ; start service + lbl_start: + System::Call 'advapi32::StartService${APITAG}(i r5, i 0, i 0) i.r6' + StrCmp $6 0 lbl_done lbl_good + + ; stop service + lbl_stop: + Push $R1 + System::Call '*(i,i,i,i,i,i,i) i.R1' + System::Call 'advapi32::ControlService(i r5, i ${SERVICE_CONTROL_STOP}, i $R1) i' + System::Free $R1 + Pop $R1 + StrCmp $6 0 lbl_done lbl_good + + ; pause service + lbl_pause: + Push $R1 + System::Call '*(i,i,i,i,i,i,i) i.R1' + System::Call 'advapi32::ControlService(i r5, i ${SERVICE_CONTROL_PAUSE}, i $R1) i' + System::Free $R1 + Pop $R1 + StrCmp $6 0 lbl_done lbl_good + + ; continue service + lbl_continue: + Push $R1 + System::Call '*(i,i,i,i,i,i,i) i.R1' + System::Call 'advapi32::ControlService(i r5, i ${SERVICE_CONTROL_CONTINUE}, i $R1) i' + System::Free $R1 + Pop $R1 + StrCmp $6 0 lbl_done lbl_good + + ; is installed + lbl_installed: + !insertmacro CALL_GETPARAM $7 "action" "" "lbl_good" + StrCpy $3 $7 + Goto lbl_select + + ; is service running + lbl_running: + Push $R1 + System::Call '*(i,i,i,i,i,i,i) i.R1' + System::Call 'advapi32::QueryServiceStatus(i r5, i $R1) i' + System::Call '*$R1(i, i.r6)' + System::Free $R1 + Pop $R1 + IntFmt $6 "0x%X" $6 + StrCmp $6 ${SERVICE_RUNNING} 0 lbl_done + !insertmacro CALL_GETPARAM $7 "action" "" "lbl_good" + StrCpy $3 $7 + Goto lbl_select + + lbl_status: + Push $R1 + System::Call '*(i,i,i,i,i,i,i) i.R1' + System::Call 'advapi32::QueryServiceStatus(i r5, i $R1) i' + System::Call '*$R1(i, i .r6)' + System::Free $R1 + Pop $R1 + IntFmt $6 "0x%X" $6 + StrCpy $0 "running" + IntCmp $6 ${SERVICE_RUNNING} lbl_done + StrCpy $0 "stopped" + IntCmp $6 ${SERVICE_STOPPED} lbl_done + StrCpy $0 "start_pending" + IntCmp $6 ${SERVICE_START_PENDING} lbl_done + StrCpy $0 "stop_pending" + IntCmp $6 ${SERVICE_STOP_PENDING} lbl_done + StrCpy $0 "running" + IntCmp $6 ${SERVICE_RUNNING} lbl_done + StrCpy $0 "continue_pending" + IntCmp $6 ${SERVICE_CONTINUE_PENDING} lbl_done + StrCpy $0 "pause_pending" + IntCmp $6 ${SERVICE_PAUSE_PENDING} lbl_done + StrCpy $0 "paused" + IntCmp $6 ${SERVICE_PAUSED} lbl_done + StrCpy $0 "unknown" + Goto lbl_done + + lbl_processid: + Push $R1 + Push $R2 + System::Call '*(i,i,i,i,i,i,i,i,i) i.R1' + System::Call '*(i 0) i.R2' + System::Call "advapi32::QueryServiceStatusEx(i r5, i ${SC_STATUS_PROCESS_INFO}, i $R1, i 36, i $R2) i" + System::Call "*$R1(i,i,i,i,i,i,i, i .r0)" + System::Free $R2 + System::Free $R1 + Pop $R2 + Pop $R1 + Goto lbl_done + + lbl_good: + StrCpy $0 "true" + lbl_done: + IntCmp $5 0 +2 + System::Call 'advapi32::CloseServiceHandle(i r5) n' + IntCmp $4 0 +2 + System::Call 'advapi32::CloseServiceHandle(i r4) n' + Pop $4 + Pop $3 + Pop $2 + Pop $1 + Exch 3 + Pop $5 + Pop $7 + Pop $6 + Exch $0 + !macroend + + Function Service + !insertmacro FUNC_SERVICE "" + FunctionEnd + + Function un.Service + !insertmacro FUNC_SERVICE "un." + FunctionEnd + + Function GetParam + !insertmacro FUNC_GETPARAM + FunctionEnd + + Function un.GetParam + !insertmacro FUNC_GETPARAM + FunctionEnd + + !undef APITAG +!endif \ No newline at end of file diff --git a/installer/tunsafe.nsi b/installer/tunsafe.nsi index 7b77322..044f485 100644 --- a/installer/tunsafe.nsi +++ b/installer/tunsafe.nsi @@ -9,6 +9,7 @@ SetCompressor /SOLID lzma !include "x64.nsh" !define MULTIUSER_EXECUTIONLEVEL Admin !include "MultiUser.nsh" +!include "servicelib.nsh" !insertmacro GetParameters !insertmacro GetOptions @@ -130,6 +131,7 @@ again: Sleep 500 Goto again done: + !insertmacro SERVICE stop TunSafeService "" FunctionEnd Function .onInit @@ -198,6 +200,10 @@ Function un.onInit FunctionEnd Section "Uninstall" + !insertmacro SERVICE stop "TunSafeService" "" + !insertmacro SERVICE delete "TunSafeService" "" + + Delete "$INSTDIR\TunSafe.exe" Delete "$INSTDIR\License.txt" Delete "$INSTDIR\ChangeLog.txt" diff --git a/ip_to_peer_map.cpp b/ip_to_peer_map.cpp new file mode 100644 index 0000000..4210e66 --- /dev/null +++ b/ip_to_peer_map.cpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. +#include "stdafx.h" +#include "ip_to_peer_map.h" +#include "bit_ops.h" +#include + +IpToPeerMap::IpToPeerMap() { + +} + +IpToPeerMap::~IpToPeerMap() { +} + +bool IpToPeerMap::InsertV4(const void *addr, int cidr, void *peer) { + uint32 mask = cidr == 32 ? 0xffffffff : ~(0xffffffff >> cidr); + Entry4 e = {ReadBE32(addr) & mask, mask, peer}; + ipv4_.push_back(e); + return true; +} + +bool IpToPeerMap::InsertV6(const void *addr, int cidr, void *peer) { + Entry6 e; + e.cidr_len = cidr; + e.peer = peer; + memcpy(e.ip, addr, 16); + ipv6_.push_back(e); + return true; +} + +void *IpToPeerMap::LookupV4(uint32 ip) { + uint32 best_mask = 0; + void *best_peer = NULL; + for (auto it = ipv4_.begin(); it != ipv4_.end(); ++it) { + if (it->ip == (ip & it->mask) && it->mask >= best_mask) { + best_mask = it->mask; + best_peer = it->peer; + } + } + return best_peer; +} + +void *IpToPeerMap::LookupV4DefaultPeer() { + for (auto it = ipv4_.begin(); it != ipv4_.end(); ++it) { + if (it->mask == 0) + return it->peer; + } + return NULL; +} + +void *IpToPeerMap::LookupV6DefaultPeer() { + for (auto it = ipv6_.begin(); it != ipv6_.end(); ++it) { + if (it->cidr_len == 0) + return it->peer; + } + return NULL; +} + +static int CalculateIPv6CommonPrefix(const uint8 *a, const uint8 *b) { + uint64 x = ToBE64(*(uint64*)&a[0] ^ *(uint64*)&b[0]); + uint64 y = ToBE64(*(uint64*)&a[8] ^ *(uint64*)&b[8]); + return x ? 64 - FindHighestSetBit64(x) : 128 - FindHighestSetBit64(y); +} + +void *IpToPeerMap::LookupV6(const void *addr) { + int best_len = 0; + void *best_peer = NULL; + for (auto it = ipv6_.begin(); it != ipv6_.end(); ++it) { + int len = CalculateIPv6CommonPrefix((const uint8*)addr, it->ip); + if (len >= it->cidr_len && len >= best_len) { + best_len = len; + best_peer = it->peer; + } + } + return best_peer; +} + +void IpToPeerMap::RemovePeer(void *peer) { + { + size_t n = ipv4_.size(); + Entry4 *r = &ipv4_[0], *w = r; + for (size_t i = 0; i != n; i++, r++) { + if (r->peer != peer) + *w++ = *r; + } + ipv4_.resize(w - &ipv4_[0]); + } + { + size_t n = ipv6_.size(); + Entry6 *r = &ipv6_[0], *w = r; + for (size_t i = 0; i != n; i++, r++) { + if (r->peer != peer) + *w++ = *r; + } + ipv6_.resize(w - &ipv6_[0]); + } +} \ No newline at end of file diff --git a/ip_to_peer_map.h b/ip_to_peer_map.h new file mode 100644 index 0000000..476f8cb --- /dev/null +++ b/ip_to_peer_map.h @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. +#pragma once + +#include "tunsafe_types.h" +#include + +// Maps CIDR addresses to a peer, always returning the longest match +// Slow O(n) implementation +class IpToPeerMap { +public: + IpToPeerMap(); + ~IpToPeerMap(); + + // Inserts an IP address of a given CIDR length into the lookup table, pointing to peer. + bool InsertV4(const void *addr, int cidr, void *peer); + bool InsertV6(const void *addr, int cidr, void *peer); + + // Lookup the peer matching the IP Address + void *LookupV4(uint32 ip); + void *LookupV6(const void *addr); + + void *LookupV4DefaultPeer(); + void *LookupV6DefaultPeer(); + + // Remove a peer from the table + void RemovePeer(void *peer); +private: + struct Entry4 { + uint32 ip; + uint32 mask; + void *peer; + }; + struct Entry6 { + uint8 ip[16]; + uint8 cidr_len; + void *peer; + }; + std::vector ipv4_; + std::vector ipv6_; +}; diff --git a/ipzip2/ipzip2.cpp b/ipzip2/ipzip2.cpp index 1b23962..5ed17a3 100644 --- a/ipzip2/ipzip2.cpp +++ b/ipzip2/ipzip2.cpp @@ -1 +1,2 @@ -// this is a placeholder for a packet compression algorithm not yet released. \ No newline at end of file +#include "stdafx.h" +// this is a placeholder for a packet compression algorithm not yet released. diff --git a/ipzip2/ipzip2.h b/ipzip2/ipzip2.h new file mode 100644 index 0000000..bfcd94b --- /dev/null +++ b/ipzip2/ipzip2.h @@ -0,0 +1 @@ +// this is a placeholder for a packet compression algorithm not yet released. diff --git a/netapi.h b/netapi.h index 56af4f6..4dfc8e1 100644 --- a/netapi.h +++ b/netapi.h @@ -121,6 +121,9 @@ public: // This holds all cidr addresses to add as additional routing entries std::vector extra_routes; + // This holds all the ips to exclude + std::vector excluded_ips; + // This holds the pre/post commands PrePostCommands pre_post_commands; }; diff --git a/network_bsd_common.cpp b/network_bsd_common.cpp index 479dff1..161f70f 100644 --- a/network_bsd_common.cpp +++ b/network_bsd_common.cpp @@ -41,6 +41,11 @@ #include #endif +void tunsafe_die(const char *msg) { + fprintf(stderr, "%s\n", msg); + exit(1); +} + void SetThreadName(const char *name) { #if defined(OS_LINUX) prctl(PR_SET_NAME, name, 0, 0, 0); @@ -438,11 +443,11 @@ static void ComputeIpv6DefaultRoute(const uint8 *ipv6_address, uint8 ipv6_cidr, default_route_v6[15] ^= 3; } -void TunsafeBackendBsd::AddRoute(uint32 ip, uint32 cidr, uint32 gw) { +void TunsafeBackendBsd::AddRoute(uint32 ip, uint32 cidr, uint32 gw, const char *dev) { uint32 ip_be, gw_be; WriteBE32(&ip_be, ip); WriteBE32(&gw_be, gw); - AddRoute(AF_INET, &ip_be, cidr, &gw_be); + AddRoute(AF_INET, &ip_be, cidr, &gw_be, dev); } static void AddOrRemoveRoute(const RouteInfo &cd, bool remove) { @@ -452,13 +457,12 @@ static void AddOrRemoveRoute(const RouteInfo &cd, bool remove) { print_ip_prefix(buf2, cd.family, cd.gw, -1); #if defined(OS_LINUX) - const char *cmd = remove ? "delete" : "add"; - if (cd.family == AF_INET) { - const char *net_or_host = (cd.cidr == 32) ? "-host" : "-net"; - RunCommand("/sbin/route %s %s %s gw %s", cmd, net_or_host, buf1, buf2); + const char *cmd = remove ? "del" : "add"; + const char *proto = (cd.family == AF_INET) ? NULL : "-6"; + if (cd.dev.empty()) { + RunCommand("/sbin/ip %s route %s %s via %s", proto, cmd, buf1, buf2); } else { - const char *net_or_host = (cd.cidr == 128) ? "-host" : "-net"; - RunCommand("/sbin/route %s %s inet6 %s gw %s", cmd, net_or_host, buf1, buf2); + RunCommand("/sbin/ip %s route %s %s dev %s", proto, cmd, buf1, cd.dev.c_str()); } #elif defined(OS_MACOSX) || defined(OS_FREEBSD) const char *cmd = remove ? "delete" : "add"; @@ -470,9 +474,10 @@ static void AddOrRemoveRoute(const RouteInfo &cd, bool remove) { #endif } -bool TunsafeBackendBsd::AddRoute(int family, const void *dest, int dest_prefix, const void *gateway) { +bool TunsafeBackendBsd::AddRoute(int family, const void *dest, int dest_prefix, const void *gateway, const char *dev) { RouteInfo c; + c.dev = dev ? dev : ""; c.family = family; size_t len = (family == AF_INET) ? 4 : 16; memcpy(c.ip, dest, len); @@ -493,7 +498,6 @@ static bool IsIpv6AddressSet(const void *p) { // Called to initialize tun bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out) override { - char def_iface[12]; char devname[16]; if (!RunPrePostCommand(config.pre_post_commands.pre_up)) { @@ -513,20 +517,24 @@ bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out) uint32 default_route_v4 = ComputeIpv4DefaultRoute(config.ip, netmask); RunCommand("/sbin/ifconfig %s %A mtu %d %A netmask %A up", devname, config.ip, config.mtu, config.ip, netmask); - AddRoute(config.ip & netmask, config.cidr, config.ip); + AddRoute(config.ip & netmask, config.cidr, config.ip, devname); if (config.use_ipv4_default_route) { if (config.default_route_endpoint_v4) { - uint32 gw; - if (!GetDefaultRoute(def_iface, sizeof(def_iface), &gw)) { + uint32 ipv4_default_gw; + char default_iface[16]; + if (!GetDefaultRoute(default_iface, sizeof(default_iface), &ipv4_default_gw)) { RERROR("Unable to determine default interface."); return false; } - AddRoute(config.default_route_endpoint_v4, 32, gw); - + AddRoute(config.default_route_endpoint_v4, 32, ipv4_default_gw, NULL); + for (auto it = config.excluded_ips.begin(); it != config.excluded_ips.end(); ++it) { + if (it->size == 32) + AddRoute(ReadBE32(it->addr), it->cidr, ipv4_default_gw, default_iface); + } } - AddRoute(0x00000000, 1, default_route_v4); - AddRoute(0x80000000, 1, default_route_v4); + AddRoute(0x00000000, 1, default_route_v4, devname); + AddRoute(0x80000000, 1, default_route_v4, devname); } uint8 default_route_v6[16]; @@ -537,23 +545,23 @@ bool TunsafeBackendBsd::Initialize(const TunConfig &&config, TunConfigOut *out) ComputeIpv6DefaultRoute(config.ipv6_address, config.ipv6_cidr, default_route_v6); - RunCommand("/sbin/ifconfig %s inet6 %s", devname, print_ip_prefix(buf, AF_INET6, config.ipv6_address, config.ipv6_cidr)); + RunCommand("/sbin/ifconfig %s inet6 add %s", devname, print_ip_prefix(buf, AF_INET6, config.ipv6_address, config.ipv6_cidr)); if (config.use_ipv6_default_route) { if (IsIpv6AddressSet(config.default_route_endpoint_v6)) { RERROR("default_route_endpoint_v6 not supported"); } - AddRoute(AF_INET6, matchall_1_route + 1, 1, default_route_v6); - AddRoute(AF_INET6, matchall_1_route + 0, 1, default_route_v6); + AddRoute(AF_INET6, matchall_1_route + 1, 1, default_route_v6, devname); + AddRoute(AF_INET6, matchall_1_route + 0, 1, default_route_v6, devname); } } // Add all the extra routes for (auto it = config.extra_routes.begin(); it != config.extra_routes.end(); ++it) { if (it->size == 32) { - AddRoute(ReadBE32(it->addr), it->cidr, default_route_v4); + AddRoute(ReadBE32(it->addr), it->cidr, default_route_v4, devname); } else if (it->size == 128 && config.ipv6_cidr) { - AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6); + AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6, devname); } } @@ -688,34 +696,38 @@ void InitCpuFeatures(); void Benchmark(); -uint32 g_ui_ip; - const char *print_ip(char buf[kSizeOfAddress], in_addr_t ip) { snprintf(buf, kSizeOfAddress, "%d.%d.%d.%d", (ip >> 24) & 0xff, (ip >> 16) & 0xff, (ip >> 8) & 0xff, (ip >> 0) & 0xff); return buf; } - class MyProcessorDelegate : public ProcessorDelegate { public: - virtual void OnConnected(in_addr_t my_ip) { - if (my_ip != g_ui_ip) { - if (my_ip) { - char buf[kSizeOfAddress]; - print_ip(buf, my_ip); - RINFO("Connection established. IP %s", buf); - } - g_ui_ip = my_ip; + MyProcessorDelegate() { + wg_processor_ = NULL; + is_connected_ = false; + } + + virtual void OnConnected() override { + if (!is_connected_) { + uint32 ipv4_ip = ReadBE32(wg_processor_->tun_addr().addr); + char buf[kSizeOfAddress]; + RINFO("Connection established. IP %s", print_ip(buf, ipv4_ip)); + is_connected_ = true; } } - virtual void OnDisconnected() { - MyProcessorDelegate::OnConnected(0); + virtual void OnConnectionRetry(uint32 attempts) override { + if (is_connected_ && attempts >= 3) { + is_connected_ = false; + RINFO("Reconnecting..."); + } } + + WireguardProcessor *wg_processor_; + bool is_connected_; }; int main(int argc, char **argv) { - bool exit_flag = false; - InitCpuFeatures(); if (argc == 2 && strcmp(argv[1], "--benchmark") == 0) { @@ -739,9 +751,12 @@ int main(int argc, char **argv) { MyProcessorDelegate my_procdel; TunsafeBackendBsd *socket_loop = CreateTunsafeBackendBsd(); WireguardProcessor wg(socket_loop, socket_loop, &my_procdel); + + my_procdel.wg_processor_ = &wg; socket_loop->SetProcessor(&wg); - if (!ParseWireGuardConfigFile(&wg, argv[1], &exit_flag)) return 1; + DnsResolver dns_resolver(NULL); + if (!ParseWireGuardConfigFile(&wg, argv[1], &dns_resolver)) return 1; if (!wg.Start()) return 1; socket_loop->RunLoop(); diff --git a/network_bsd_common.h b/network_bsd_common.h index 1db5646..cc14bc6 100644 --- a/network_bsd_common.h +++ b/network_bsd_common.h @@ -6,12 +6,14 @@ #include "netapi.h" #include "wireguard.h" #include "wireguard_config.h" +#include struct RouteInfo { uint8 family; uint8 cidr; uint8 ip[16]; uint8 gw[16]; + std::string dev; }; class TunsafeBackendBsd : public TunInterface, public UdpInterface { @@ -34,9 +36,9 @@ protected: virtual bool InitializeTun(char devname[16]) = 0; virtual void RunLoopInner() = 0; - void AddRoute(uint32 ip, uint32 cidr, uint32 gw); + void AddRoute(uint32 ip, uint32 cidr, uint32 gw, const char *dev); void DelRoute(const RouteInfo &cd); - bool AddRoute(int family, const void *dest, int dest_prefix, const void *gateway); + bool AddRoute(int family, const void *dest, int dest_prefix, const void *gateway, const char *dev); bool RunPrePostCommand(const std::vector &vec); WireguardProcessor *processor_; diff --git a/network_bsd_mt.cpp b/network_bsd_mt.cpp index 30caaf9..f7134a8 100644 --- a/network_bsd_mt.cpp +++ b/network_bsd_mt.cpp @@ -1,8 +1,11 @@ // SPDX-License-Identifier: AGPL-1.0-only // Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. +// Note: This is an experimental implementation that doesn't work, there's no way +// for the alarm signal to interrupt the tunsafe main thread. #include "network_bsd_common.h" #include "tunsafe_endian.h" #include "tunsafe_config.h" +#include "tunsafe_threading.h" #include "util.h" #include @@ -91,7 +94,7 @@ private: bool shutting_down_; bool got_sig_alarm_; - pthread_mutex_t lock_; + Mutex lock_; pthread_cond_t cond_; }; @@ -120,7 +123,7 @@ private: bool shutting_down_; - pthread_mutex_t lock_; + Mutex lock_; pthread_cond_t cond_; }; @@ -147,7 +150,7 @@ private: WorkerLoop *worker_; pthread_t read_tid_, write_tid_; Packet *queue_, **queue_end_; - pthread_mutex_t lock_; + Mutex lock_; pthread_cond_t cond_; }; @@ -158,12 +161,11 @@ WorkerLoop::WorkerLoop() { shutting_down_ = false; got_sig_alarm_ = false; processor_ = NULL; - pthread_mutex_init(&lock_, NULL); - pthread_cond_init(&cond_, NULL); + if (pthread_cond_init(&cond_, NULL) != 0) + tunsafe_die("pthread_cond_init failed"); } WorkerLoop::~WorkerLoop() { - pthread_mutex_destroy(&lock_); pthread_cond_destroy(&cond_); } @@ -174,13 +176,14 @@ bool WorkerLoop::Initialize(WireguardProcessor *processor) { void WorkerLoop::StartThread() { assert(tid_ == 0); - pthread_create(&tid_, NULL, &ThreadMainStatic, this); + if (pthread_create(&tid_, NULL, &ThreadMainStatic, this) != 0) + tunsafe_die("pthread_create failed"); } void WorkerLoop::StopThread() { - pthread_mutex_lock(&lock_); + lock_.Acquire(); shutting_down_ = true; - pthread_mutex_unlock(&lock_); + lock_.Release(); if (tid_) { void *x; @@ -198,16 +201,16 @@ void WorkerLoop::NotifyStop() { void WorkerLoop::HandlePacket(Packet *packet, int target) { // RINFO("WorkerLoop::HandlePacket"); packet->post_target = target; - pthread_mutex_lock(&lock_); + lock_.Acquire(); Packet *old_queue = queue_; *queue_end_ = packet; queue_end_ = &packet->next; packet->next = NULL; if (old_queue == NULL) { - pthread_mutex_unlock(&lock_); + lock_.Release(); pthread_cond_signal(&cond_); } else { - pthread_mutex_unlock(&lock_); + lock_.Release(); } } @@ -218,19 +221,19 @@ void *WorkerLoop::ThreadMainStatic(void *x) { void *WorkerLoop::ThreadMain() { Packet *packet_queue; - pthread_mutex_lock(&lock_); + lock_.Acquire(); for (;;) { // Grab the whole list for (;;) { while (got_sig_alarm_) { got_sig_alarm_ = false; - pthread_mutex_unlock(&lock_); + lock_.Release(); processor_->SecondLoop(); - pthread_mutex_lock(&lock_); + lock_.Acquire(); } if (shutting_down_ || queue_ != NULL) break; - pthread_cond_wait(&cond_, &lock_); + pthread_cond_wait(&cond_, lock_.impl()); } if (shutting_down_) break; @@ -238,7 +241,7 @@ void *WorkerLoop::ThreadMain() { queue_ = NULL; queue_end_ = &queue_; - pthread_mutex_unlock(&lock_); + lock_.Release(); // And send all items in the list while (packet_queue != NULL) { Packet *next = packet_queue->next; @@ -249,9 +252,9 @@ void *WorkerLoop::ThreadMain() { } packet_queue = next; } - pthread_mutex_lock(&lock_); + lock_.Acquire(); } - pthread_mutex_unlock(&lock_); + lock_.Release(); return NULL; } @@ -265,14 +268,13 @@ UdpLoop::UdpLoop() { worker_ = NULL; queue_ = NULL; queue_end_ = &queue_; - pthread_mutex_init(&lock_, NULL); - pthread_cond_init(&cond_, NULL); + if (pthread_cond_init(&cond_, NULL) != 0) + tunsafe_die("pthread_cond_init failed"); } UdpLoop::~UdpLoop() { if (fd_ != -1) close(fd_); - pthread_mutex_destroy(&lock_); pthread_cond_destroy(&cond_); } @@ -286,16 +288,18 @@ bool UdpLoop::Initialize(int listen_port, WorkerLoop *worker) { } void UdpLoop::Start() { - pthread_create(&read_tid_, NULL, &ReaderMainStatic, this); - pthread_create(&write_tid_, NULL, &WriterMainStatic, this); + if (pthread_create(&read_tid_, NULL, &ReaderMainStatic, this) != 0) + tunsafe_die("pthread_create failed"); + if (pthread_create(&write_tid_, NULL, &WriterMainStatic, this) != 0) + tunsafe_die("pthread_create failed"); } void UdpLoop::Stop() { void *x; - pthread_mutex_lock(&lock_); + lock_.Acquire(); shutting_down_ = true; - pthread_mutex_unlock(&lock_); + lock_.Release(); pthread_cond_signal(&cond_); pthread_kill(read_tid_, SIGUSR1); @@ -345,17 +349,17 @@ void *UdpLoop::ReaderMain() { void *UdpLoop::WriterMain() { Packet *queue; - pthread_mutex_lock(&lock_); + lock_.Acquire(); for (;;) { // Grab the whole list while (!shutting_down_ && queue_ == NULL) - pthread_cond_wait(&cond_, &lock_); + pthread_cond_wait(&cond_, lock_.impl()); if (shutting_down_) break; queue = queue_; queue_ = NULL; queue_end_ = &queue_; - pthread_mutex_unlock(&lock_); + lock_.Release(); // And send all items in the list while (queue != NULL) { int r = sendto(fd_, queue->data, queue->size, 0, @@ -370,9 +374,9 @@ void *UdpLoop::WriterMain() { queue = queue->next; FreePacket(to_free); } - pthread_mutex_lock(&lock_); + lock_.Acquire(); } - pthread_mutex_unlock(&lock_); + lock_.Release(); return NULL; } @@ -380,15 +384,15 @@ void UdpLoop::WriteUdpPacket(Packet *packet) { // RINFO("write udp packet to queue!"); packet->next = NULL; - pthread_mutex_lock(&lock_); + lock_.Acquire(); Packet *old_queue = queue_; *queue_end_ = packet; queue_end_ = &packet->next; if (old_queue == NULL) { - pthread_mutex_unlock(&lock_); + lock_.Release(); pthread_cond_signal(&cond_); } else { - pthread_mutex_unlock(&lock_); + lock_.Release(); } } @@ -400,14 +404,13 @@ TunLoop::TunLoop() { write_tid_ = 0; queue_ = NULL; queue_end_ = &queue_; - pthread_mutex_init(&lock_, NULL); - pthread_cond_init(&cond_, NULL); + if (pthread_cond_init(&cond_, NULL) != 0) + tunsafe_die("pthread_cond_init failed"); } TunLoop::~TunLoop() { if (fd_ != -1) close(fd_); - pthread_mutex_destroy(&lock_); pthread_cond_destroy(&cond_); } @@ -421,16 +424,18 @@ bool TunLoop::Initialize(char devname[16], WorkerLoop *worker) { } void TunLoop::Start() { - pthread_create(&read_tid_, NULL, &ReaderMainStatic, this); - pthread_create(&write_tid_, NULL, &WriterMainStatic, this); + if (pthread_create(&read_tid_, NULL, &ReaderMainStatic, this) != 0) + tunsafe_die("pthread_create failed"); + if (pthread_create(&write_tid_, NULL, &WriterMainStatic, this) != 0) + tunsafe_die("pthread_create failed"); } void TunLoop::Stop() { void *x; - pthread_mutex_lock(&lock_); + lock_.Acquire(); shutting_down_ = true; - pthread_mutex_unlock(&lock_); + lock_.Release(); pthread_kill(read_tid_, SIGUSR1); pthread_kill(write_tid_, SIGUSR1); @@ -469,18 +474,18 @@ void *TunLoop::ReaderMain() { void *TunLoop::WriterMain() { Packet *queue; - pthread_mutex_lock(&lock_); + lock_.Acquire(); for (;;) { // Grab the whole list while (!shutting_down_ && queue_ == NULL) { - pthread_cond_wait(&cond_, &lock_); + pthread_cond_wait(&cond_, lock_.impl()); } if (shutting_down_) break; queue = queue_; queue_ = NULL; queue_end_ = &queue_; - pthread_mutex_unlock(&lock_); + lock_.Release(); // And send all items in the list while (queue != NULL) { if (TUN_PREFIX_BYTES) @@ -494,24 +499,24 @@ void *TunLoop::WriterMain() { queue = queue->next; FreePacket(to_free); } - pthread_mutex_lock(&lock_); + lock_.Acquire(); } - pthread_mutex_unlock(&lock_); + lock_.Release(); return NULL; } void TunLoop::WriteTunPacket(Packet *packet) { packet->next = NULL; - pthread_mutex_lock(&lock_); + lock_.Acquire(); Packet *old_queue = queue_; *queue_end_ = packet; queue_end_ = &packet->next; if (old_queue == NULL) { - pthread_mutex_unlock(&lock_); + lock_.Release(); pthread_cond_signal(&cond_); } else { - pthread_mutex_unlock(&lock_); + lock_.Release(); } } diff --git a/network_win32.cpp b/network_win32.cpp index beb1b39..d364611 100644 --- a/network_win32.cpp +++ b/network_win32.cpp @@ -38,6 +38,8 @@ static SLIST_HEADER freelist_head; bool g_allow_pre_post; +static InternetBlockState GetInternetBlockState(bool *is_activated); + Packet *AllocPacket() { Packet *packet = (Packet*)InterlockedPopEntrySList(&freelist_head); if (packet == NULL) @@ -51,6 +53,40 @@ void FreePacket(Packet *packet) { InterlockedPushEntrySList(&freelist_head, &packet->list_entry); } +void OsGetRandomBytes(uint8 *data, size_t data_size) { + static BOOLEAN(APIENTRY *pfn)(void*, ULONG); + static bool resolved; + if (!resolved) { + pfn = (BOOLEAN(APIENTRY *)(void*, ULONG))GetProcAddress(LoadLibrary("ADVAPI32.DLL"), "SystemFunction036"); + resolved = true; + } + if (pfn && pfn(data, (ULONG)data_size)) + return; + size_t r = 0; + for (; r < data_size; r++) + data[r] = rand() >> 6; +} + +void OsInterruptibleSleep(int millis) { + SleepEx(millis, TRUE); +} + +uint64 OsGetMilliseconds() { + return GetTickCount64(); +} + +void OsGetTimestampTAI64N(uint8 dst[12]) { + SYSTEMTIME systime; + uint64 file_time_uint64 = 0; + GetSystemTime(&systime); + SystemTimeToFileTime(&systime, (FILETIME*)&file_time_uint64); + uint64 time_since_epoch_100ns = (file_time_uint64 - 116444736000000000); + uint64 secs_since_epoch = time_since_epoch_100ns / 10000000 + 0x400000000000000a; + uint32 nanos = (uint32)(time_since_epoch_100ns % 10000000) * 100; + WriteBE64(dst, secs_since_epoch); + WriteBE32(dst + 8, nanos); +} + extern "C" PSLIST_ENTRY __fastcall InterlockedPushListSList( IN PSLIST_HEADER ListHead, @@ -80,11 +116,6 @@ void InitPacketMutexes() { } } - -void CallbackUpdateUI(); -void CallbackTriggerReconnect(); -void CallbackSetPublicKey(const uint8 public_key[32]); - int tpq_last_qsize; int g_tun_reads, g_tun_writes; @@ -200,7 +231,7 @@ static bool GetTapAdapterGuid(char guid[64]) { } // Open the TAP adapter -static HANDLE OpenTunAdapter(char guid[64], int retry_count, bool *exit_thread, DWORD open_flags) { +static HANDLE OpenTunAdapter(char guid[64], int retry_count, uint32 *exit_thread, DWORD open_flags) { char path[128]; HANDLE h; int retries = 0; @@ -221,7 +252,16 @@ RETRY: if ((error_code == ERROR_FILE_NOT_FOUND || error_code == ERROR_GEN_FAILURE) && retry_count != 0 && !*exit_thread) { RERROR("OpenTapAdapter: CreateFile failed: 0x%X... retrying", error_code); retry_count--; - Sleep(250 * ++retries); + + int sleep_amount = 250 * ++retries; + for(;;) { + if (*exit_thread) + return NULL; + if (sleep_amount == 0) + break; + Sleep(50); + sleep_amount -= 50; + } goto RETRY; } @@ -239,7 +279,7 @@ RETRY: static bool AddRoute(int family, const void *dest, int dest_prefix, const void *gateway, const NET_LUID *interface_luid, - std::vector *undo_array = NULL) { + std::vector *undo_array) { MIB_IPFORWARD_ROW2 row = {0}; char buf1[kSizeOfAddress], buf2[kSizeOfAddress]; @@ -261,11 +301,12 @@ static bool AddRoute(int family, row.Metric = 100; row.Protocol = MIB_IPPROTO_NETMGMT; - if (undo_array) - undo_array->push_back(row); - DWORD error = CreateIpForwardEntry2(&row); if (error == NO_ERROR || error == ERROR_OBJECT_ALREADY_EXISTS) { + + if (undo_array) + undo_array->push_back(row); + RINFO("Added Route %s => %s", print_ip_prefix(buf1, family, dest, dest_prefix), print_ip_prefix(buf2, family, gateway, -1)); return true; @@ -352,7 +393,7 @@ static bool GetDefaultRouteAndDeleteOldRoutes(int family, const NET_LUID *Interf for (unsigned i = 0; i < table->NumEntries; i++) { MIB_IPFORWARD_ROW2 *row = &table->Table[i]; if (InterfaceLuid && memcmp(&row->InterfaceLuid, InterfaceLuid, sizeof(NET_LUID)) == 0) { - if (row->Protocol == MIB_IPPROTO_NETMGMT) + if (row->Protocol == MIB_IPPROTO_NETMGMT && !row->AutoconfigureAddress) DeleteRouteOrPrintErr(row); } else if (IsRouteOriginatingFromNullRoute(row)) { ri->found_null_routes++; @@ -422,8 +463,6 @@ UdpSocketWin32::UdpSocketWin32() { thread_ = NULL; socket_ipv6_ = INVALID_SOCKET; completion_port_handle_ = NULL; - - InitializeCriticalSectionAndSpinCount(&mutex_, 1024); } UdpSocketWin32::~UdpSocketWin32() { @@ -432,7 +471,6 @@ UdpSocketWin32::~UdpSocketWin32() { closesocket(socket_ipv6_); CloseHandle(completion_port_handle_); FreePacketList(wqueue_); - DeleteCriticalSection(&mutex_); } bool UdpSocketWin32::Initialize(int listen_on_port) { @@ -613,11 +651,11 @@ restart_read_udp: if (!pending_writes) { if (!wqueue_) break; - EnterCriticalSection(&mutex_); + mutex_.Acquire(); pending_writes = wqueue_; wqueue_end_ = &wqueue_; wqueue_ = NULL; - LeaveCriticalSection(&mutex_); + mutex_.Release(); if (!pending_writes) break; } @@ -690,11 +728,11 @@ void UdpSocketWin32::WriteUdpPacket(Packet *packet) { packet->next = NULL; qs.udp_qsize2 += packet->size; - EnterCriticalSection(&mutex_); + mutex_.Acquire(); Packet *was_empty = wqueue_; *wqueue_end_ = packet; wqueue_end_ = &packet->next; - LeaveCriticalSection(&mutex_); + mutex_.Release(); if (was_empty == NULL) { // Notify the worker thread that it should attempt more writes @@ -722,10 +760,9 @@ void UdpSocketWin32::StopThread() { thread_ = NULL; } -ThreadedPacketQueue::ThreadedPacketQueue(WireguardProcessor *wg, NetworkStats *stats) { +ThreadedPacketQueue::ThreadedPacketQueue(WireguardProcessor *wg, TunsafeBackendWin32 *backend) { wg_ = wg; - stats_ = stats; - InitializeCriticalSectionAndSpinCount(&mutex_, 1024); + backend_ = backend; event_ = CreateEvent(NULL, FALSE, FALSE, NULL); last_ptr_ = &first_; @@ -743,7 +780,6 @@ ThreadedPacketQueue::~ThreadedPacketQueue() { assert(timer_handle_ == NULL); first_ = NULL; last_ptr_ = &first_; - DeleteCriticalSection(&mutex_); CloseHandle(event_); } @@ -755,23 +791,29 @@ DWORD WINAPI ThreadedPacketQueue::ThreadedPacketQueueLauncher(VOID *x) { DWORD ThreadedPacketQueue::ThreadMain() { int free_packets_ctr = 0; int overload = 0; + Packet *packet; - EnterCriticalSection(&mutex_); + wg_->dev().SetCurrentThreadAsMainThread(); + + mutex_.Acquire(); while (!exit_flag_) { if (timer_interrupt_) { timer_interrupt_ = false; need_notify_ = 0; - LeaveCriticalSection(&mutex_); + mutex_.Release(); wg_->SecondLoop(); - EnterCriticalSection(&stats_->mutex); - if (stats_->reset_stats) { - stats_->reset_stats = false; - wg_->ResetStats(); - } - stats_->packet_stats = wg_->GetStats(); - LeaveCriticalSection(&stats_->mutex); + backend_->stats_mutex_.Acquire(); + backend_->stats_ = wg_->GetStats(); + float data[2] = { + // unit is megabits/second + backend_->stats_.tun_bytes_in_per_second * (1.0f / 125000), + backend_->stats_.tun_bytes_out_per_second * (1.0f / 125000), + }; + backend_->stats_collector_.AddSamples(data); + backend_->stats_mutex_.Release(); - CallbackUpdateUI(); + backend_->delegate_->OnGraphAvailable(); + backend_->PushStats(); // Conserve memory every 10s if (free_packets_ctr++ == 10) { @@ -780,46 +822,38 @@ DWORD ThreadedPacketQueue::ThreadMain() { } if (overload) overload -= 1; - EnterCriticalSection(&mutex_); - continue; - } - - // Grab the elements of the queue - Packet *packet = first_; - if (packet == NULL) { + } else if ((packet = first_) == NULL) { need_notify_ = 1; - LeaveCriticalSection(&mutex_); + mutex_.Release(); WaitForSingleObject(event_, INFINITE); - EnterCriticalSection(&mutex_); + } else { + // Steal the whole work queue + first_ = NULL; + last_ptr_ = &first_; + int packets_in_queue = packets_in_queue_; + packets_in_queue_ = 0; + need_notify_ = 0; + mutex_.Release(); - //SleepConditionVariableCS(&cv_, &mutex, INFINITE); - continue; + tpq_last_qsize = packets_in_queue; + if (packets_in_queue >= 1024) + overload = 2; + bool is_overload = (overload != 0); + + WireguardProcessor *procint = wg_; + do { + Packet *next = packet->next; + if (packet->post_target == TARGET_PROCESSOR_UDP) + procint->HandleUdpPacket(packet, is_overload); + else + procint->HandleTunPacket(packet); + packet = next; + } while (packet); } - // Steal the whole work queue - first_ = NULL; - last_ptr_ = &first_; - int packets_in_queue = packets_in_queue_; - packets_in_queue_ = 0; - need_notify_ = 0; - LeaveCriticalSection(&mutex_); - - tpq_last_qsize = packets_in_queue; - if (packets_in_queue >= 1024) - overload = 2; - bool is_overload = (overload != 0); - - WireguardProcessor *procint = wg_; - do { - Packet *next = packet->next; - if (packet->post_target == TARGET_PROCESSOR_UDP) - procint->HandleUdpPacket(packet, is_overload); - else - procint->HandleTunPacket(packet); - packet = next; - } while (packet); - EnterCriticalSection(&mutex_); + wg_->RunAllMainThreadScheduled(); + mutex_.Acquire(); } - LeaveCriticalSection(&mutex_); + mutex_.Release(); return 0; } @@ -837,9 +871,9 @@ void ThreadedPacketQueue::Start() { } void ThreadedPacketQueue::Stop() { - EnterCriticalSection(&mutex_); + mutex_.Acquire(); exit_flag_ = true; - LeaveCriticalSection(&mutex_); + mutex_.Release(); SetEvent(event_); @@ -859,15 +893,15 @@ void ThreadedPacketQueue::Stop() { } void ThreadedPacketQueue::AbortingDriver() { - EnterCriticalSection(&mutex_); + mutex_.Acquire(); exit_flag_ = true; - LeaveCriticalSection(&mutex_); + mutex_.Release(); } void ThreadedPacketQueue::Post(Packet *packet, Packet **end, int count) { - EnterCriticalSection(&mutex_); + mutex_.Acquire(); if (packets_in_queue_ >= HARD_MAXIMUM_QUEUE_SIZE) { - LeaveCriticalSection(&mutex_); + mutex_.Release(); FreePackets(packet, end, count); return; } @@ -883,11 +917,11 @@ void ThreadedPacketQueue::Post(Packet *packet, Packet **end, int count) { } if (need_notify_) { need_notify_ = 0; - LeaveCriticalSection(&mutex_); + mutex_.Release(); SetEvent(event_); return; } - LeaveCriticalSection(&mutex_); + mutex_.Release(); } void CALLBACK ThreadedPacketQueue::TimerRoutine(LPVOID lpArgToCompletionRoutine, DWORD dwTimerLowValue, DWORD dwTimerHighValue) { @@ -895,15 +929,15 @@ void CALLBACK ThreadedPacketQueue::TimerRoutine(LPVOID lpArgToCompletionRoutine, } void ThreadedPacketQueue::PostTimerInterrupt() { - EnterCriticalSection(&mutex_); + mutex_.Acquire(); timer_interrupt_ = true; if (need_notify_) { need_notify_ = 0; - LeaveCriticalSection(&mutex_); + mutex_.Release(); SetEvent(event_); return; } - LeaveCriticalSection(&mutex_); + mutex_.Release(); } bool GetNetLuidFromGuid(const char *adapter_guid, NET_LUID *luid) { @@ -1052,64 +1086,28 @@ static void ComputeIpv6DefaultRoute(const uint8 *ipv6_address, uint8 ipv6_cidr, } -static bool AddMultipleCatchallRoutes(int inet, int bits, const uint8 *target, const NET_LUID &luid) { +static bool AddMultipleCatchallRoutes(int inet, int bits, const uint8 *target, const NET_LUID &luid, std::vector *undo_array) { uint8 tmp[16] = {0}; bool success = true; for (int i = 0; i < (1 << bits); i++) { tmp[0] = i << (8 - bits); - success &= AddRoute(inet, tmp, bits, target, &luid); + success &= AddRoute(inet, tmp, bits, target, &luid, undo_array); } return success; } -static uint8 GetInternetRouteBlockingState() { - if (internet_route_blocking_state == ROUTE_BLOCK_UNKNOWN) { - RouteInfo ri; - internet_route_blocking_state = - (GetDefaultRouteAndDeleteOldRoutes(AF_INET, NULL, TRUE, NULL, &ri) && ri.found_null_routes == 2) + ROUTE_BLOCK_OFF; - } - return internet_route_blocking_state; -} - -static void SetInternetRouteBlockingState(bool want) { - if (want) { - internet_route_blocking_state = ROUTE_BLOCK_PENDING; - } else if (internet_route_blocking_state != ROUTE_BLOCK_OFF) { - RouteInfo ri; - GetDefaultRouteAndDeleteOldRoutes(AF_INET, NULL, FALSE, NULL, &ri); - GetDefaultRouteAndDeleteOldRoutes(AF_INET6, NULL, FALSE, NULL, &ri); - internet_route_blocking_state = ROUTE_BLOCK_OFF; - } -} - -InternetBlockState GetInternetBlockState(bool *is_activated) { - int a = GetInternetRouteBlockingState(); - int b = GetInternetFwBlockingState(); - - if (is_activated) - *is_activated = (a == ROUTE_BLOCK_ON || b == IBS_ACTIVE); - - return (InternetBlockState)( - (a >= ROUTE_BLOCK_ON) * kBlockInternet_Route + - (b >= IBS_ACTIVE) * kBlockInternet_Firewall); -} - -void SetInternetBlockState(InternetBlockState s) { - SetInternetRouteBlockingState((s & kBlockInternet_Route) != 0); - SetInternetFwBlockingState((s & kBlockInternet_Firewall) != 0); -} - -TunWin32Adapter::TunWin32Adapter() { +TunWin32Adapter::TunWin32Adapter(DnsBlocker *dns_blocker) { handle_ = NULL; - current_dns_block_ = NULL; + dns_blocker_ = dns_blocker; } TunWin32Adapter::~TunWin32Adapter() { } -bool TunWin32Adapter::OpenAdapter(bool *exit_thread, DWORD open_flags) { - int retry_count = 10; +bool TunWin32Adapter::OpenAdapter(uint32 *exit_thread, DWORD open_flags) { + assert(handle_ == NULL); + int retry_count = 20; handle_ = OpenTunAdapter(guid_, retry_count, exit_thread, open_flags); return (handle_ != NULL); } @@ -1230,7 +1228,7 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt if (has_dns_setting && config.block_dns_on_adapters) { RINFO("Blocking standard DNS on all adapters"); - current_dns_block_ = BlockDnsExceptOnAdapter(InterfaceLuid, config.ipv6_cidr != 0); + dns_blocker_->BlockDnsExceptOnAdapter(InterfaceLuid, config.ipv6_cidr != 0); err = SetMetricOnNetworkAdapter(&InterfaceLuid, AF_INET, 2); if (err) @@ -1241,6 +1239,8 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt if (err) RERROR("SetMetricOnNetworkAdapter IPv6 failed: %d", err); } + } else { + dns_blocker_->RestoreDns(); } uint8 ibs = config.internet_blocking; @@ -1279,10 +1279,10 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt RERROR("Unable to get localhost luid - while adding route based blocking."); } else { uint32 dst[4] = {0}; - if (!AddMultipleCatchallRoutes(AF_INET, 1, (uint8*)&dst, localhost_luid)) + if (!AddMultipleCatchallRoutes(AF_INET, 1, (uint8*)&dst, localhost_luid, NULL)) RERROR("Unable to add routes for route based blocking."); if (config.ipv6_cidr) { - if (!AddMultipleCatchallRoutes(AF_INET6, 1, (uint8*)&dst, localhost_luid)) + if (!AddMultipleCatchallRoutes(AF_INET6, 1, (uint8*)&dst, localhost_luid, NULL)) RERROR("Unable to add IPv6 routes for route based blocking."); } } @@ -1312,7 +1312,7 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt } // Either add 4 routes or 2 routes, depending on if we use route blocking. uint32 be = ToBE32(default_route_v4); - if (!AddMultipleCatchallRoutes(AF_INET, block_all_traffic_route ? 2 : 1, (uint8*)&be, InterfaceLuid)) + if (!AddMultipleCatchallRoutes(AF_INET, block_all_traffic_route ? 2 : 1, (uint8*)&be, InterfaceLuid, &routes_to_undo_)) RERROR("Unable to add new default ipv4 route."); } @@ -1331,7 +1331,7 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt return false; } } - if (!AddMultipleCatchallRoutes(AF_INET6, block_all_traffic_route ? 2 : 1, default_route_v6, InterfaceLuid)) + if (!AddMultipleCatchallRoutes(AF_INET6, block_all_traffic_route ? 2 : 1, default_route_v6, InterfaceLuid, &routes_to_undo_)) RERROR("Unable to add new default ipv6 route."); } } @@ -1340,9 +1340,20 @@ bool TunWin32Adapter::InitAdapter(const TunInterface::TunConfig &&config, TunInt for (auto it = config.extra_routes.begin(); it != config.extra_routes.end(); ++it) { if (it->size == 32) { uint32 be = ToBE32(default_route_v4); - AddRoute(AF_INET, it->addr, it->cidr, &be, &InterfaceLuid); + AddRoute(AF_INET, it->addr, it->cidr, &be, &InterfaceLuid, &routes_to_undo_); } else if (it->size == 128 && config.ipv6_cidr) { - AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6, &InterfaceLuid); + AddRoute(AF_INET6, it->addr, it->cidr, default_route_v6, &InterfaceLuid, &routes_to_undo_); + } + } + + // Add all the routes that should bypass vpn + for (auto it = config.excluded_ips.begin(); it != config.excluded_ips.end(); ++it) { + if (it->size == 32) { + if (ri.found_default_adapter) + AddRoute(AF_INET, it->addr, it->cidr, ri.default_gw, &ri.default_adapter, &routes_to_undo_); + } else if (it->size == 128 && config.ipv6_cidr) { + if (ri6.found_default_adapter) + AddRoute(AF_INET6, it->addr, it->cidr, ri6.default_gw, &ri6.default_adapter, &routes_to_undo_); } } @@ -1386,9 +1397,9 @@ void TunWin32Adapter::CloseAdapter() { DeleteRoute(&*it); routes_to_undo_.clear(); - RestoreDnsExceptOnAdapter(current_dns_block_); - current_dns_block_ = NULL; - + if (dns_blocker_) + dns_blocker_->RestoreDns(); + RunPrePostCommand(post_down_); } @@ -1445,7 +1456,7 @@ static bool RunOneCommand(const std::string &cmd) { char *nl2 = nl; if (nl != buf + bufstart && nl[-1] == '\r') nl--; - bufstart = nl2 - buf + 1; + bufstart = (DWORD)(nl2 - buf + 1); RINFO("%.*s", nl - st, st); } if (bufend - bufstart == sizeof(buf) || foundeof) { @@ -1496,14 +1507,13 @@ bool TunWin32Adapter::RunPrePostCommand(const std::vector &vec) { ////////////////////////////////////////////////////////////////////////////// -TunWin32Iocp::TunWin32Iocp() { +TunWin32Iocp::TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker), backend_(backend) { wqueue_end_ = &wqueue_; wqueue_ = NULL; thread_ = NULL; completion_port_handle_ = NULL; packet_handler_ = NULL; - InitializeCriticalSectionAndSpinCount(&mutex_, 1024); exit_thread_ = false; } @@ -1511,13 +1521,12 @@ TunWin32Iocp::~TunWin32Iocp() { //assert(num_reads_ == 0 && num_writes_ == 0); assert(thread_ == NULL); CloseTun(); - DeleteCriticalSection(&mutex_); } bool TunWin32Iocp::Initialize(const TunConfig &&config, TunConfigOut *out) { - CloseTun(); + assert(thread_ == NULL); - if (!adapter_.OpenAdapter(&exit_thread_, FILE_FLAG_OVERLAPPED)) + if (!adapter_.OpenAdapter(&backend_->stop_mode_, FILE_FLAG_OVERLAPPED)) return false; completion_port_handle_ = CreateIoCompletionPort(adapter_.handle(), NULL, NULL, 0); @@ -1568,13 +1577,13 @@ void TunWin32Iocp::ThreadMain() { RERROR("TunWin32: ReadFile failed 0x%X", err); - if (err == ERROR_OPERATION_ABORTED) { + if (err == ERROR_OPERATION_ABORTED || err == ERROR_FILE_NOT_FOUND) { packet_handler_->AbortingDriver(); RERROR("TAP driver stopped communicating. Attempting to restart.", err); // This can happen if we reinstall the TAP driver while there's an active connection. Wait a bit, then attempt to // restart. Sleep(1000); - CallbackTriggerReconnect(); + backend_->TunAdapterFailed(); goto EXIT; } } else { @@ -1642,11 +1651,11 @@ void TunWin32Iocp::ThreadMain() { if (!pending_writes) { if (!wqueue_) break; - EnterCriticalSection(&mutex_); + mutex_.Acquire(); pending_writes = wqueue_; wqueue_end_ = &wqueue_; wqueue_ = NULL; - LeaveCriticalSection(&mutex_); + mutex_.Release(); if (!pending_writes) break; } @@ -1711,11 +1720,11 @@ void TunWin32Iocp::StopThread() { void TunWin32Iocp::WriteTunPacket(Packet *packet) { packet->next = NULL; - EnterCriticalSection(&mutex_); + mutex_.Acquire(); Packet *was_empty = wqueue_; *wqueue_end_ = packet; wqueue_end_ = &packet->next; - LeaveCriticalSection(&mutex_); + mutex_.Release(); if (was_empty == NULL) { // Notify the worker thread that it should attempt more writes PostQueuedCompletionStatus(completion_port_handle_, NULL, NULL, NULL); @@ -1726,7 +1735,7 @@ void TunWin32Iocp::WriteTunPacket(Packet *packet) { ////////////////////////////////////////////////////////////////////////////// -TunWin32Overlapped::TunWin32Overlapped() { +TunWin32Overlapped::TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend) : adapter_(blocker), backend_(backend) { wqueue_end_ = &wqueue_; wqueue_ = NULL; @@ -1737,13 +1746,11 @@ TunWin32Overlapped::TunWin32Overlapped() { wake_event_ = CreateEvent(NULL, FALSE, FALSE, NULL); packet_handler_ = NULL; - InitializeCriticalSectionAndSpinCount(&mutex_, 1024); exit_thread_ = false; } TunWin32Overlapped::~TunWin32Overlapped() { CloseTun(); - DeleteCriticalSection(&mutex_); CloseHandle(read_event_); CloseHandle(write_event_); CloseHandle(wake_event_); @@ -1751,7 +1758,7 @@ TunWin32Overlapped::~TunWin32Overlapped() { bool TunWin32Overlapped::Initialize(const TunConfig &&config, TunConfigOut *out) { CloseTun(); - return adapter_.OpenAdapter(&exit_thread_, FILE_FLAG_OVERLAPPED) && + return adapter_.OpenAdapter(&backend_->stop_mode_, FILE_FLAG_OVERLAPPED) && adapter_.InitAdapter(std::move(config), out); } @@ -1809,11 +1816,11 @@ void TunWin32Overlapped::ThreadMain() { if (write_packet == NULL) { if (!pending_writes) { - EnterCriticalSection(&mutex_); + mutex_.Acquire(); pending_writes = wqueue_; wqueue_end_ = &wqueue_; wqueue_ = NULL; - LeaveCriticalSection(&mutex_); + mutex_.Release(); } if (pending_writes) { // Then issue writes @@ -1859,98 +1866,427 @@ void TunWin32Overlapped::StopThread() { void TunWin32Overlapped::WriteTunPacket(Packet *packet) { packet->next = NULL; - EnterCriticalSection(&mutex_); + mutex_.Acquire(); Packet *was_empty = wqueue_; *wqueue_end_ = packet; wqueue_end_ = &packet->next; - LeaveCriticalSection(&mutex_); + mutex_.Release(); if (was_empty == NULL) SetEvent(wake_event_); } - - - +void TunsafeBackendWin32::SetPublicKey(const uint8 key[32]) { + memcpy(public_key_, key, 32); + delegate_->OnStateChanged(); +} DWORD WINAPI TunsafeBackendWin32::WorkerThread(void *bk) { TunsafeBackendWin32 *backend = (TunsafeBackendWin32*)bk; + int stop_mode; - TunWin32Iocp tun; - UdpSocketWin32 udp; - WireguardProcessor wg_proc(&udp, &tun, backend->procdel_); + for(;;) { + TunWin32Iocp tun(&backend->dns_blocker_, backend); + UdpSocketWin32 udp; + WireguardProcessor wg_proc(&udp, &tun, backend); - ThreadedPacketQueue queues_for_processor(&wg_proc, &backend->stats_); + ThreadedPacketQueue queues_for_processor(&wg_proc, backend); - qs.udp_qsize1 = qs.udp_qsize2 = 0; + qs.udp_qsize1 = qs.udp_qsize2 = 0; - udp.SetPacketHandler(&queues_for_processor); - tun.SetPacketHandler(&queues_for_processor); + udp.SetPacketHandler(&queues_for_processor); + tun.SetPacketHandler(&queues_for_processor); - if (!ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->exit_flag_)) - goto getout; + wg_proc.dev().SetCurrentThreadAsMainThread(); - if (!wg_proc.Start()) - goto getout; + if (!ParseWireGuardConfigFile(&wg_proc, backend->config_file_, &backend->dns_resolver_)) + goto getout_fail; - queues_for_processor.Start(); - udp.StartThread(); - tun.StartThread(); - - CallbackSetPublicKey(wg_proc.dev().public_key()); - - while (!backend->exit_flag_) { - SleepEx(INFINITE, TRUE); + if (!wg_proc.Start()) + goto getout_fail; + + // only for use in callbacks from wg + backend->wg_processor_ = &wg_proc; + + queues_for_processor.Start(); + udp.StartThread(); + tun.StartThread(); + + backend->SetPublicKey(wg_proc.dev().public_key()); + + while ((stop_mode = InterlockedExchange(&backend->stop_mode_, MODE_NONE)) == MODE_NONE) { + SleepEx(INFINITE, TRUE); + } + + // Keep DNS alive + if (stop_mode != MODE_EXIT) + tun.adapter().DisassociateDnsBlocker(); + else + backend->dns_resolver_.ClearCache(); + + udp.StopThread(); + tun.StopThread(); + queues_for_processor.Stop(); + + backend->wg_processor_ = NULL; + + FreeAllPackets(); + + if (stop_mode != MODE_TUN_FAILED) + return 0; + + uint32 last_fail = GetTickCount(); + bool permanent_fail = (last_fail - backend->last_tun_adapter_failed_) < 5000; + backend->last_tun_adapter_failed_ = last_fail; + + backend->status_ = permanent_fail ? TunsafeBackend::kErrorTunPermanent : TunsafeBackend::kStatusTunRetrying; + backend->delegate_->OnStatusCode(backend->status_); + + if (permanent_fail) { + RERROR("Too many automatic restarts..."); + goto getout_fail; + } } - - udp.StopThread(); - tun.StopThread(); - queues_for_processor.Stop(); - - FreeAllPackets(); -getout: +getout_fail: + backend->dns_blocker_.RestoreDns(); + backend->status_ = TunsafeBackend::kErrorInitialize; + backend->delegate_->OnStatusCode(TunsafeBackend::kErrorInitialize); return 0; } static void WINAPI ExitServiceAPC(ULONG_PTR a) { - *(bool*)a = true; } -TunsafeBackendWin32::TunsafeBackendWin32() { +TunsafeBackend::TunsafeBackend() { + is_started_ = false; + is_remote_ = false; + ipv4_ip_ = 0; + status_ = kStatusStopped; + memset(public_key_, 0, sizeof(public_key_)); +} + +TunsafeBackend::~TunsafeBackend() { + +} + + +TunsafeBackendWin32::TunsafeBackendWin32(Delegate *delegate) : delegate_(delegate), dns_resolver_(&dns_blocker_) { memset(&stats_, 0, sizeof(stats_)); + wg_processor_ = NULL; InitPacketMutexes(); - InitializeCriticalSectionAndSpinCount(&stats_.mutex, 1024); worker_thread_ = NULL; + stop_mode_ = MODE_NONE; + last_tun_adapter_failed_ = 0; + want_periodic_stats_ = false; + + internet_route_blocking_state = ROUTE_BLOCK_UNKNOWN; + ClearInternetFwBlockingStateCache(); + + delegate_->OnStateChanged(); } TunsafeBackendWin32::~TunsafeBackendWin32() { - DeleteCriticalSection(&stats_.mutex); + StopInner(false); } -ProcessorStats TunsafeBackendWin32::GetStats() { - EnterCriticalSection(&stats_.mutex); - ProcessorStats stats = stats_.packet_stats; - LeaveCriticalSection(&stats_.mutex); - return stats; +bool TunsafeBackendWin32::Initialize() { + // it's always initialized + + return true; } -void TunsafeBackendWin32::Start(ProcessorDelegate *procdel, const char *config_file) { - Stop(); - procdel_ = procdel; - exit_flag_ = false; +void TunsafeBackendWin32::Teardown() { + +} + + +void TunsafeBackendWin32::RequestStats(bool enable) { + want_periodic_stats_ = enable; + PushStats(); +} + +void TunsafeBackendWin32::PushStats() { + if (want_periodic_stats_) { + stats_mutex_.Acquire(); + WgProcessorStats stats = stats_; + stats_mutex_.Release(); + delegate_->OnGetStats(stats); + } +} + +void TunsafeBackendWin32::Stop() { + StopInner(false); + delegate_->OnStatusCode(status_); + delegate_->OnStateChanged(); +} + +void TunsafeBackendWin32::Start(const char *config_file) { + StopInner(true); + stop_mode_ = MODE_NONE; // this needs to be here cause it's not reset on config file errors + dns_resolver_.SetAbortFlag(false); + is_started_ = true; + memset(public_key_, 0, sizeof(public_key_)); + status_ = kStatusInitializing; + delegate_->OnStatusCode(kStatusInitializing); + delegate_->OnClearLog(); DWORD thread_id; config_file_ = _strdup(config_file); worker_thread_ = CreateThread(NULL, 0, &WorkerThread, this, 0, &thread_id); SetThreadPriority(worker_thread_, THREAD_PRIORITY_ABOVE_NORMAL); + delegate_->OnStateChanged(); } -void TunsafeBackendWin32::Stop() { +void TunsafeBackendWin32::TunAdapterFailed() { + InterlockedExchange(&stop_mode_, MODE_TUN_FAILED); + QueueUserAPC(&ExitServiceAPC, worker_thread_, NULL); +} + +void TunsafeBackendWin32::StopInner(bool is_restart) { if (worker_thread_) { - QueueUserAPC(&ExitServiceAPC, worker_thread_, (ULONG_PTR)&exit_flag_); + ipv4_ip_ = 0; + dns_resolver_.SetAbortFlag(true); + InterlockedExchange(&stop_mode_, is_restart ? MODE_RESTART : MODE_EXIT); + QueueUserAPC(&ExitServiceAPC, worker_thread_, NULL); WaitForSingleObject(worker_thread_, INFINITE); CloseHandle(worker_thread_); worker_thread_ = NULL; free(config_file_); config_file_ = NULL; + is_started_ = false; + status_ = kStatusStopped; } } +void TunsafeBackendWin32::ResetStats() { +} + +LinearizedGraph *TunsafeBackendWin32::GetGraph(int type) { + if (type < 0 || type >= 4) + return NULL; + + size_t size = sizeof(LinearizedGraph) + 2 * (sizeof(uint32) + sizeof(float) * 120); + LinearizedGraph *graph = (LinearizedGraph *)malloc(size); + if (graph) { + graph->total_size = (uint32)size; + graph->num_charts = 2; + graph->graph_type = type; + memset(graph->reserved, 0, sizeof(graph->reserved)); + stats_mutex_.Acquire(); + + uint8 *ptr = (uint8*)(graph + 1); + for (size_t i = 0; i < 2; i++) { + *(uint32*)ptr = 120; + ptr += 4; + const StatsCollector::TimeSeries *series = stats_collector_.GetTimeSeries((int)i, type); + memcpy(postinc(ptr, (series->size - series->shift) * sizeof(float)), + series->data + series->shift, + (series->size - series->shift) * sizeof(float)); + memcpy(postinc(ptr, series->shift * sizeof(float)), series->data, series->shift * sizeof(float)); + } + stats_mutex_.Release(); + } + return graph; +} + + +static uint8 GetInternetRouteBlockingState() { + if (internet_route_blocking_state == ROUTE_BLOCK_UNKNOWN) { + RouteInfo ri; + internet_route_blocking_state = + (GetDefaultRouteAndDeleteOldRoutes(AF_INET, NULL, TRUE, NULL, &ri) && ri.found_null_routes == 2) + ROUTE_BLOCK_OFF; + } + return internet_route_blocking_state; +} + +static void SetInternetRouteBlockingState(bool want) { + if (want) { + internet_route_blocking_state = ROUTE_BLOCK_PENDING; + } else if (internet_route_blocking_state != ROUTE_BLOCK_OFF) { + RouteInfo ri; + GetDefaultRouteAndDeleteOldRoutes(AF_INET, NULL, FALSE, NULL, &ri); + GetDefaultRouteAndDeleteOldRoutes(AF_INET6, NULL, FALSE, NULL, &ri); + internet_route_blocking_state = ROUTE_BLOCK_OFF; + } +} + +static InternetBlockState GetInternetBlockState(bool *is_activated) { + int a = GetInternetRouteBlockingState(); + int b = GetInternetFwBlockingState(); + + if (is_activated) + *is_activated = (a == ROUTE_BLOCK_ON || b == IBS_ACTIVE); + + return (InternetBlockState)( + (a >= ROUTE_BLOCK_ON) * kBlockInternet_Route + + (b >= IBS_ACTIVE) * kBlockInternet_Firewall); +} + +InternetBlockState TunsafeBackendWin32::GetInternetBlockState(bool *is_activated) { + return ::GetInternetBlockState(is_activated); +} + +void TunsafeBackendWin32::SetInternetBlockState(InternetBlockState s) { + SetInternetRouteBlockingState((s & kBlockInternet_Route) != 0); + SetInternetFwBlockingState((s & kBlockInternet_Firewall) != 0); +} + +void TunsafeBackendWin32::SetServiceStartupFlags(uint32 flags) { + // not used +} + +std::string TunsafeBackendWin32::GetConfigFileName() { + return std::string(); +} + +void TunsafeBackendWin32::OnConnected() { + if (status_ != TunsafeBackend::kStatusConnected) { + ipv4_ip_ = ReadBE32(wg_processor_->tun_addr().addr); + if (status_ != TunsafeBackend::kStatusReconnecting) { + char buf[kSizeOfAddress]; + RINFO("Connection established. IP %s", print_ip_prefix(buf, AF_INET, wg_processor_->tun_addr().addr, -1)); + } + status_ = TunsafeBackend::kStatusConnected; + delegate_->OnStatusCode(TunsafeBackend::kStatusConnected); + } +} + +void TunsafeBackendWin32::OnConnectionRetry(uint32 attempts) { + if (status_ == TunsafeBackend::kStatusInitializing) { + status_ = TunsafeBackend::kStatusConnecting; + delegate_->OnStatusCode(TunsafeBackend::kStatusConnecting); + } else if (attempts >= 3 && status_ == TunsafeBackend::kStatusConnected) { + status_ = TunsafeBackend::kStatusReconnecting; + delegate_->OnStatusCode(TunsafeBackend::kStatusReconnecting); + } +} + +void TunsafeBackend::Delegate::DoWork() { + // implemented by subclasses +} + +TunsafeBackendDelegateThreaded::TunsafeBackendDelegateThreaded(TunsafeBackend::Delegate *delegate, const std::function &callback) { + callback_ = callback; + delegate_ = delegate; +} + +TunsafeBackendDelegateThreaded::~TunsafeBackendDelegateThreaded() { + for (auto it = incoming_entry_.begin(); it != incoming_entry_.end(); ++it) + FreeEntry(&*it); +} + +void TunsafeBackendDelegateThreaded::FreeEntry(Entry *e) { + if (e->lparam) { + free((void*)e->lparam); + e->lparam = NULL; + } +} + +void TunsafeBackendDelegateThreaded::DoWork() { + mutex_.Acquire(); + std::swap(incoming_entry_, processing_entry_); + mutex_.Release(); + TunsafeBackend::Delegate *delegate = delegate_; + for (auto it = processing_entry_.begin(); it != processing_entry_.end(); ++it) { + switch (it->which) { + case Id_OnGetStats: delegate->OnGetStats(*(WgProcessorStats*)it->lparam); break; + case Id_OnStateChanged: delegate->OnStateChanged(); break; + case Id_OnLogLine: delegate->OnLogLine((const char**)&it->lparam); break; + case Id_OnStatusCode: delegate->OnStatusCode((TunsafeBackend::StatusCode)it->wparam); break; + case Id_OnClearLog: delegate->OnClearLog(); break; + case Id_OnGraphAvailable: delegate->OnGraphAvailable(); break; + } + FreeEntry(&*it); + } + processing_entry_.clear(); +} + +void TunsafeBackendDelegateThreaded::AddEntry(Which which, intptr_t lparam, uint32 wparam) { + mutex_.Acquire(); + bool was_empty = incoming_entry_.empty(); + incoming_entry_.emplace_back(which, wparam, lparam); + mutex_.Release(); + if (was_empty) + callback_(); +} + +void TunsafeBackendDelegateThreaded::OnGetStats(const WgProcessorStats &stats) { + AddEntry(Id_OnGetStats, (intptr_t)memdup(&stats, sizeof(stats))); +} + +void TunsafeBackendDelegateThreaded::OnGraphAvailable() { + AddEntry(Id_OnGraphAvailable); +} + +void TunsafeBackendDelegateThreaded::OnStateChanged() { + AddEntry(Id_OnStateChanged); +} + +void TunsafeBackendDelegateThreaded::OnLogLine(const char **s) { + const char *ss = *s; + *s = NULL; + AddEntry(Id_OnLogLine, (intptr_t)ss); +} + +void TunsafeBackendDelegateThreaded::OnStatusCode(TunsafeBackend::StatusCode status) { + AddEntry(Id_OnStatusCode, 0, status); +} + +void TunsafeBackendDelegateThreaded::OnClearLog() { + AddEntry(Id_OnClearLog); +} + +TunsafeBackend::Delegate::~Delegate() { +} + +TunsafeBackend *CreateNativeTunsafeBackend(TunsafeBackend::Delegate *delegate) { + return new TunsafeBackendWin32(delegate); +} + +TunsafeBackend::Delegate *CreateTunsafeBackendDelegateThreaded(TunsafeBackend::Delegate *delegate, const std::function &callback) { + return new TunsafeBackendDelegateThreaded(delegate, callback); +} + +/////////////////////////////////////////////////// + +void StatsCollector::Init() { + Accumulator *acc = &accum_[0][0]; + static const int kAccMax[TIMEVALS] = {5, 6, 10, 0}; + + // Initialize all stats channels + for (uint32 channel = 0; channel != CHANNELS; channel++) { + for (uint32 timeval = 0; timeval != TIMEVALS; timeval++, acc++) { + acc->acc = 0; + acc->dirty = false; + acc->acc_count = 0; + acc->acc_max = kAccMax[timeval]; + acc->data.size = 120; + acc->data.data = (float*)calloc(sizeof(float), acc->data.size); + acc->data.shift = 0; + } + } +} + +void StatsCollector::AddToGraphDataSource(StatsCollector::TimeSeries *ts, float value) { + ts->data[ts->shift] = value; + if (++ts->shift == ts->size) + ts->shift = 0; +} + +void StatsCollector::AddToAccumulators(StatsCollector::Accumulator *acc, float rval) { + for (;;) { + AddToGraphDataSource(&acc->data, rval); + acc->dirty = true; + acc->acc += rval; + if (acc->acc_max == 0 || ++acc->acc_count < acc->acc_max) + break; + rval = acc->acc / (float)acc->acc_count; + acc->acc_count = 0; + acc->acc = 0.0f; + acc++; + } +} + +void StatsCollector::AddSamples(float data[CHANNELS]) { + for (size_t i = 0; i < CHANNELS; i++) + AddToAccumulators(&accum_[i][0], data[i]); +} + diff --git a/network_win32.h b/network_win32.h index a67f226..162d9b3 100644 --- a/network_win32.h +++ b/network_win32.h @@ -6,14 +6,18 @@ #include "tunsafe_types.h" #include "netapi.h" #include "network_win32_api.h" +#include "network_win32_dnsblock.h" +#include "wireguard_config.h" +#include "tunsafe_threading.h" +#include struct Packet; class WireguardProcessor; - +class TunsafeBackendWin32; class ThreadedPacketQueue { public: - explicit ThreadedPacketQueue(WireguardProcessor *wg, NetworkStats *stats); + explicit ThreadedPacketQueue(WireguardProcessor *wg, TunsafeBackendWin32 *backend); ~ThreadedPacketQueue(); enum { @@ -39,7 +43,7 @@ private: Packet **last_ptr_; uint32 packets_in_queue_; uint32 need_notify_; - CRITICAL_SECTION mutex_; + Mutex mutex_; HANDLE event_; HANDLE timer_handle_; @@ -47,7 +51,7 @@ private: WireguardProcessor *wg_; bool exit_flag_; bool timer_interrupt_; - NetworkStats *stats_; + TunsafeBackendWin32 *backend_; }; // Encapsulates a UDP socket, optionally listening for incoming packets @@ -74,7 +78,7 @@ private: // All packets queued for writing. Locked by |mutex_| Packet *wqueue_, **wqueue_end_; - CRITICAL_SECTION mutex_; + Mutex mutex_; ThreadedPacketQueue *packet_handler_; SOCKET socket_; @@ -85,22 +89,26 @@ private: bool exit_thread_; }; +class DnsBlocker; + class TunWin32Adapter { public: - TunWin32Adapter(); + TunWin32Adapter(DnsBlocker *dns_blocker); ~TunWin32Adapter(); - bool OpenAdapter(bool *exit_thread, DWORD open_flags); + bool OpenAdapter(unsigned int *exit_thread, DWORD open_flags); bool InitAdapter(const TunInterface::TunConfig &&config, TunInterface::TunConfigOut *out); void CloseAdapter(); HANDLE handle() { return handle_; } + void DisassociateDnsBlocker() { dns_blocker_ = NULL; } + private: bool RunPrePostCommand(const std::vector &vec); HANDLE handle_; - HANDLE current_dns_block_; + DnsBlocker *dns_blocker_; std::vector routes_to_undo_; uint8 mac_adress_[6]; @@ -113,7 +121,7 @@ private: // Implementation of TUN interface handling using IO Completion Ports class TunWin32Iocp : public TunInterface { public: - explicit TunWin32Iocp(); + explicit TunWin32Iocp(DnsBlocker *blocker, TunsafeBackendWin32 *backend); ~TunWin32Iocp(); void SetPacketHandler(ThreadedPacketQueue *packet_handler) { packet_handler_ = packet_handler; } @@ -125,6 +133,8 @@ public: virtual bool Initialize(const TunConfig &&config, TunConfigOut *out) override; virtual void WriteTunPacket(Packet *packet) override; + TunWin32Adapter &adapter() { return adapter_; } + private: void CloseTun(); void ThreadMain(); @@ -134,20 +144,21 @@ private: HANDLE completion_port_handle_; HANDLE thread_; - CRITICAL_SECTION mutex_; + Mutex mutex_; bool exit_thread_; // All packets queued for writing Packet *wqueue_, **wqueue_end_; + TunsafeBackendWin32 *backend_; TunWin32Adapter adapter_; }; // Implementation of TUN interface handling using Overlapped IO class TunWin32Overlapped : public TunInterface { public: - explicit TunWin32Overlapped(); + explicit TunWin32Overlapped(DnsBlocker *blocker, TunsafeBackendWin32 *backend); ~TunWin32Overlapped(); void SetPacketHandler(ThreadedPacketQueue *packet_handler) { packet_handler_ = packet_handler; } @@ -167,7 +178,7 @@ private: ThreadedPacketQueue *packet_handler_; HANDLE thread_; - CRITICAL_SECTION mutex_; + Mutex mutex_; HANDLE read_event_, write_event_, wake_event_; @@ -176,4 +187,111 @@ private: Packet *wqueue_, **wqueue_end_; TunWin32Adapter adapter_; + + TunsafeBackendWin32 *backend_; }; + +class TunsafeBackendWin32 : public TunsafeBackend, public ProcessorDelegate { + friend class ThreadedPacketQueue; + friend class TunWin32Iocp; + friend class TunWin32Overlapped; +public: + TunsafeBackendWin32(Delegate *delegate); + ~TunsafeBackendWin32(); + + // -- from TunsafeBackend + virtual bool Initialize() override; + virtual void Teardown() override; + virtual void Start(const char *config_file) override; + virtual void Stop() override; + virtual void RequestStats(bool enable) override; + virtual void ResetStats() override; + virtual InternetBlockState GetInternetBlockState(bool *is_activated) override; + virtual void SetInternetBlockState(InternetBlockState s) override; + virtual void SetServiceStartupFlags(uint32 flags) override; + virtual LinearizedGraph *GetGraph(int type) override; + virtual std::string GetConfigFileName() override; + + // -- from ProcessorDelegate + virtual void OnConnected() override; + virtual void OnConnectionRetry(uint32 attempts) override; + + void SetPublicKey(const uint8 key[32]); + void TunAdapterFailed(); +private: + + void StopInner(bool is_restart); + static DWORD WINAPI WorkerThread(void *x); + void PushStats(); + + HANDLE worker_thread_; + + enum { + MODE_NONE = 0, + MODE_EXIT = 1, + MODE_RESTART = 2, + MODE_TUN_FAILED = 3, + }; + + bool want_periodic_stats_; + unsigned int stop_mode_; + + Delegate *delegate_; + char *config_file_; + + DnsBlocker dns_blocker_; + DnsResolver dns_resolver_; + + WireguardProcessor *wg_processor_; + + uint32 last_tun_adapter_failed_; + StatsCollector stats_collector_; + + Mutex stats_mutex_; + WgProcessorStats stats_; +}; + +// This class ensures that all callbacks get rescheduled to another thread +class TunsafeBackendDelegateThreaded : public TunsafeBackend::Delegate { +public: + TunsafeBackendDelegateThreaded(TunsafeBackend::Delegate *delegate, const std::function &callback); + ~TunsafeBackendDelegateThreaded(); + +private: + virtual void OnGetStats(const WgProcessorStats &stats); + virtual void OnGraphAvailable(); + virtual void OnStateChanged(); + virtual void OnClearLog(); + virtual void OnLogLine(const char **s); + virtual void OnStatusCode(TunsafeBackend::StatusCode status); + virtual void DoWork(); + + enum Which { + Id_OnGetStats, + Id_OnStateChanged, + Id_OnClearLog, + Id_OnLogLine, + Id_OnUpdateUI, + Id_OnStatusCode, + Id_OnGraphAvailable, + }; + + void AddEntry(Which which, intptr_t lparam = 0, uint32 wparam = 0); + + TunsafeBackend::Delegate *delegate_; + std::function callback_; + + struct Entry { + uint8 which; + uint32 wparam; + intptr_t lparam; + Entry(uint8 which, uint32 wparam, intptr_t lparam) : which(which), wparam(wparam), lparam(lparam) {} + }; + + static void FreeEntry(Entry *e); + + Mutex mutex_; + std::vector incoming_entry_; + std::vector processing_entry_; +}; + diff --git a/network_win32_api.h b/network_win32_api.h index dac9856..bf5cf88 100644 --- a/network_win32_api.h +++ b/network_win32_api.h @@ -6,44 +6,115 @@ #include "tunsafe_types.h" #include "wireguard.h" -struct NetworkStats { - bool reset_stats; - CRITICAL_SECTION mutex; - ProcessorStats packet_stats; -}; +#include -class TunsafeBackendWin32 { +struct StatsCollector { public: - TunsafeBackendWin32(); - ~TunsafeBackendWin32(); - - void Start(ProcessorDelegate *procdel, const char *config_file); - void Stop(); - - ProcessorStats GetStats(); - void ResetStats() { stats_.reset_stats = true; } - - bool is_started() const { return worker_thread_ != NULL; } - + enum { + CHANNELS = 2, + TIMEVALS = 4, + }; + StatsCollector() { Init(); } + void AddSamples(float data[CHANNELS]); + struct TimeSeries { + float *data; + int size; + int shift; + }; + const TimeSeries *GetTimeSeries(int channel, int timeval) { return &accum_[channel][timeval].data; } private: - static DWORD WINAPI WorkerThread(void *x); - - NetworkStats stats_; - HANDLE worker_thread_; - bool exit_flag_; - - ProcessorDelegate *procdel_; - char *config_file_; + struct Accumulator { + float acc; + int acc_count; + int acc_max; + bool dirty; + TimeSeries data; + }; + void Init(); + static void AddToGraphDataSource(StatsCollector::TimeSeries *ts, float value); + static void AddToAccumulators(StatsCollector::Accumulator *acc, float rval); + Accumulator accum_[CHANNELS][TIMEVALS]; }; +struct LinearizedGraph { + uint32 total_size; + uint32 graph_type; + uint8 num_charts; + uint8 reserved[7]; +}; +class TunsafeBackend { +public: + // All codes < 0 are permanent errors + enum StatusCode { + kStatusStopped = 0, + kStatusInitializing = 1, + kStatusConnecting = 2, + kStatusConnected = 3, + kStatusReconnecting = 4, + kStatusTunRetrying = 10, -InternetBlockState GetInternetBlockState(bool *is_activated); + kErrorInitialize = -1, + kErrorTunPermanent = -2, + kErrorServiceLost = -3, + }; -// Returns if reconnect is needed -void SetInternetBlockState(InternetBlockState s); + static bool IsPermanentError(StatusCode status) { + return (int32)status < 0; + } + class Delegate { + public: + virtual ~Delegate(); + virtual void OnGetStats(const WgProcessorStats &stats) = 0; + virtual void OnGraphAvailable() = 0; + virtual void OnStateChanged() = 0; + virtual void OnClearLog() = 0; + virtual void OnLogLine(const char **s) = 0; + virtual void OnStatusCode(TunsafeBackend::StatusCode status) = 0; + // This function is needed for CreateTunsafeBackendDelegateThreaded, + // It's expected to be called on the main thread and then all callbacks will arrive + // on the right thread. + virtual void DoWork(); + }; + TunsafeBackend(); + virtual ~TunsafeBackend(); + + // Setup/teardown the connection to the local service (if any) + virtual bool Initialize() = 0; + virtual void Teardown() = 0; + + virtual void Start(const char *config_file) = 0; + virtual void Stop() = 0; + virtual void RequestStats(bool enable) = 0; + virtual void ResetStats() = 0; + + virtual InternetBlockState GetInternetBlockState(bool *is_activated) = 0; + virtual void SetInternetBlockState(InternetBlockState s) = 0; + virtual void SetServiceStartupFlags(uint32 flags) = 0; + + virtual std::string GetConfigFileName() = 0; + + virtual LinearizedGraph *GetGraph(int type) = 0; + + bool is_started() { return is_started_; } + bool is_remote() { return is_remote_; } + const uint8 *public_key() { return public_key_; } + + StatusCode status() { return status_; } + uint32 GetIP() { return ipv4_ip_; } + +protected: + bool is_started_; + bool is_remote_; + StatusCode status_; + uint32 ipv4_ip_; + uint8 public_key_[32]; +}; + +TunsafeBackend *CreateNativeTunsafeBackend(TunsafeBackend::Delegate *delegate); +TunsafeBackend::Delegate *CreateTunsafeBackendDelegateThreaded(TunsafeBackend::Delegate *delegate, const std::function &callback); extern int tpq_last_qsize; extern int g_tun_reads, g_tun_writes; diff --git a/network_win32_dnsblock.cpp b/network_win32_dnsblock.cpp index e17f09a..b76fb91 100644 --- a/network_win32_dnsblock.cpp +++ b/network_win32_dnsblock.cpp @@ -5,6 +5,7 @@ #include "network_win32_dnsblock.h" #include #include +#include #pragma comment (lib, "Fwpuclnt.lib") @@ -43,11 +44,19 @@ static inline bool FwpmFilterAddCheckedAleConnect(HANDLE handle, FWPM_FILTER0 *f return false; } } - return true; } -HANDLE BlockDnsExceptOnAdapter(const NET_LUID &luid, bool also_ipv6) { +DnsBlocker::DnsBlocker() { + also_ipv6_ = false; + handle_ = NULL; +} + +DnsBlocker::~DnsBlocker() { + RestoreDns(); +} + +bool DnsBlocker::BlockDnsExceptOnAdapter(const NET_LUID &luid, bool also_ipv6) { FWPM_SUBLAYER0 *sublayer = NULL; FWP_BYTE_BLOB *fwp_appid = NULL; @@ -56,6 +65,14 @@ HANDLE BlockDnsExceptOnAdapter(const NET_LUID &luid, bool also_ipv6) { DWORD err; HANDLE handle = NULL; + // Check if it already matches + if (handle_ != NULL) { + if (memcmp(&luid, &luid_, sizeof(luid)) == 0 && also_ipv6_) + return true; + FwpmEngineClose0(handle_); + handle_ = NULL; + } + { FWPM_SESSION0 session = {0}; session.flags = FWPM_SESSION_FLAG_DYNAMIC; @@ -69,7 +86,7 @@ HANDLE BlockDnsExceptOnAdapter(const NET_LUID &luid, bool also_ipv6) { { FWPM_SUBLAYER0 sublayer = {0}; sublayer.subLayerKey = TUNSAFE_DNS_SUBLAYER; - sublayer.displayData.name = L"TunSafe"; + sublayer.displayData.name = L"TunSafe DNS Block"; sublayer.weight = 0x100; err = FwpmSubLayerAdd0(handle, &sublayer, NULL); if (err != 0) { @@ -96,7 +113,7 @@ HANDLE BlockDnsExceptOnAdapter(const NET_LUID &luid, bool also_ipv6) { filter.filterCondition = filter_condition; filter.numFilterConditions = 2; filter.subLayerKey = TUNSAFE_DNS_SUBLAYER; - filter.displayData.name = L"TunSafe"; + filter.displayData.name = L"TunSafe DNS Block"; filter.weight.type = FWP_UINT8; filter.weight.uint8 = 15; filter.action.type = FWP_ACTION_PERMIT; @@ -127,15 +144,21 @@ getout: success: if (fwp_appid) FwpmFreeMemory0((void **)&fwp_appid); - return handle; + + handle_ = handle; + also_ipv6_ = also_ipv6; + luid_ = luid; + return handle != NULL; } -void RestoreDnsExceptOnAdapter(HANDLE h) { - if (h) +void DnsBlocker::RestoreDns() { + HANDLE h = handle_; + if (h) { + handle_ = NULL; FwpmEngineClose0(h); + } } - static bool RemovePersistentInternetBlockingInner(HANDLE handle) { FWPM_FILTER_ENUM_TEMPLATE0 enum_template = {0}; HANDLE enum_handle = NULL; @@ -337,6 +360,10 @@ getout: return false; } +void ClearInternetFwBlockingStateCache() { + internet_fw_blocking_state = 0; +} + uint8 GetInternetFwBlockingState() { if (internet_fw_blocking_state != 0) return internet_fw_blocking_state; diff --git a/network_win32_dnsblock.h b/network_win32_dnsblock.h index 1da7e64..3bc9c7f 100644 --- a/network_win32_dnsblock.h +++ b/network_win32_dnsblock.h @@ -2,13 +2,25 @@ // Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. #pragma once -HANDLE BlockDnsExceptOnAdapter(const NET_LUID &luid, bool also_ipv6 ); -void RestoreDnsExceptOnAdapter(HANDLE h); + +class DnsBlocker { +public: + DnsBlocker(); + ~DnsBlocker(); + + bool BlockDnsExceptOnAdapter(const NET_LUID &luid, bool also_ipv6); + void RestoreDns(); + bool IsActive() { return handle_ != NULL; } + + // Current state + NET_LUID luid_; + HANDLE handle_; + bool also_ipv6_; +}; + bool AddPersistentInternetBlocking(const NET_LUID *default_interface, const NET_LUID &luid_to_allow, bool also_ipv6); - - enum { IBS_UNKOWN, IBS_INACTIVE, @@ -18,3 +30,4 @@ enum { void SetInternetFwBlockingState(bool want); uint8 GetInternetFwBlockingState(); +void ClearInternetFwBlockingStateCache(); \ No newline at end of file diff --git a/resource.h b/resource.h index 3c10a98..3e32567 100644 Binary files a/resource.h and b/resource.h differ diff --git a/service_win32.cpp b/service_win32.cpp new file mode 100644 index 0000000..002abc9 --- /dev/null +++ b/service_win32.cpp @@ -0,0 +1,1179 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. +#include "stdafx.h" +#include "service_win32.h" +#include +#include "util.h" +#include "network_win32_api.h" +#include +#include +#include +#include "util_win32.h" + +static const uint64 kTunsafeServiceProtocolVersion = 20180809001; + +static SERVICE_STATUS_HANDLE m_statusHandle; +static TunsafeServiceImpl *g_service; + +#define SERVICE_NAME L"TunSafeService" +#define SERVICE_NAMEA "TunSafeService" +#define SERVICE_START_TYPE SERVICE_AUTO_START +#define SERVICE_DEPENDENCIES L"tap0901\0dhcp\0" +#define SERVICE_ACCOUNT NULL +//L"NT AUTHORITY\\LocalService" +#define SERVICE_PASSWORD NULL +#define PIPE_NAME "\\\\.\\pipe\\TunSafe\\ServiceControl" + + +enum { + SERVICE_REQ_LOGIN = 0, + SERVICE_REQ_START = 1, + SERVICE_REQ_STOP = 2, + SERVICE_REQ_GETSTATS = 4, + SERVICE_REQ_SET_INTERNET_BLOCKSTATE = 5, + SERVICE_REQ_RESETSTATS = 6, + SERVICE_REQ_SET_STARTUP_FLAGS = 7, + + SERVICE_MSG_STATE = 8, + SERVICE_MSG_LOGLINE = 9, + SERVICE_MSG_STATS = 11, + SERVICE_MSG_CLEARLOG = 12, + SERVICE_MSG_STATUS_CODE = 14, + + SERVICE_REQ_GET_GRAPH = 15, + SERVICE_MSG_GRAPH = 16, +}; + +struct ServiceHandles { + SC_HANDLE manager; + SC_HANDLE service; + + ServiceHandles() : manager(NULL), service(NULL) {} + ~ServiceHandles() { + if (manager) + CloseServiceHandle(manager); + if (service) + CloseServiceHandle(service); + } + + bool Open(PWSTR pszServiceName, DWORD sc_rights, DWORD service_rights); + bool StopService(); + bool StartService(); +}; + + +static DWORD InstallService(PWSTR pszServiceName, + PWSTR pszDisplayName, + DWORD dwStartType, + PWSTR pszDependencies, + PWSTR pszAccount, + PWSTR pszPassword) { + wchar_t szPath[MAX_PATH + 32]; + ServiceHandles handles; + DWORD res; + + szPath[0] = '"'; + if (GetModuleFileNameW(NULL, szPath + 1, MAX_PATH) == 0) { + res = GetLastError(); + goto Cleanup; + } + size_t len = wcslen(szPath); + memcpy(szPath + len, L"\" --service", 12 * sizeof(wchar_t)); + + // Open the local default service control manager database + handles.manager = OpenSCManagerW(NULL, NULL, SC_MANAGER_CONNECT | + SC_MANAGER_CREATE_SERVICE); + if (handles.manager == NULL) { + res = GetLastError(); + goto Cleanup; + } + + // Install the service into SCM by calling CreateService + handles.service = CreateServiceW( + handles.manager, // SCManager database + pszServiceName, // Name of service + pszDisplayName, // Name to display + SERVICE_QUERY_STATUS, // Desired access + SERVICE_WIN32_OWN_PROCESS, // Service type + dwStartType, // Service start type + SERVICE_ERROR_NORMAL, // Error control type + szPath, // Service's binary + NULL, // No load ordering group + NULL, // No tag identifier + pszDependencies, // Dependencies + pszAccount, // Service running account + pszPassword // Password of the account + ); + if (handles.service == NULL) { + res = GetLastError(); + goto Cleanup; + } + { + SERVICE_DESCRIPTIONA desc; + desc.lpDescription = "TunSafe uses this service to connect to a VPN server in the background."; + ChangeServiceConfig2A(handles.service, SERVICE_CONFIG_DESCRIPTION, &desc); + } + res = 0; +Cleanup: + if (res && res != ERROR_SERVICE_EXISTS) + RERROR("TunSafe service installation failed: %d", res); + return res; +} + +bool ServiceHandles::Open(PWSTR pszServiceName, DWORD sc_rights, DWORD service_rights) { + manager = OpenSCManagerW(NULL, NULL, sc_rights); + if (manager == NULL) + return false; + service = OpenServiceW(manager, pszServiceName, service_rights); + return (service != NULL); +} + +bool ServiceHandles::StopService() { + SERVICE_STATUS ssSvcStatus = {}; + // Try to stop the service + if (ControlService(service, SERVICE_CONTROL_STOP, &ssSvcStatus)) { + Sleep(100); + while (QueryServiceStatus(service, &ssSvcStatus)) { + if (ssSvcStatus.dwCurrentState == SERVICE_STOP_PENDING) { + Sleep(100); + } else { + break; + } + } + } + return (ssSvcStatus.dwCurrentState == SERVICE_STOPPED); +} + +static wchar_t *GetUsernameOfCurrentUser(bool use_thread_token) { + HANDLE thread_token = NULL; + wchar_t *result = NULL; + DWORD len; + PTOKEN_USER token_user = NULL; + DWORD domain_len; + WCHAR username[256], domain[256]; + SID_NAME_USE sid_type; + + if (use_thread_token) { + if (!OpenThreadToken(GetCurrentThread(), TOKEN_ALL_ACCESS, FALSE, &thread_token)) + goto getout; + } else { + if (!OpenProcessToken(GetCurrentProcess(), TOKEN_ALL_ACCESS, &thread_token)) + goto getout; + + } + len = 0; + token_user = NULL; + while (!GetTokenInformation(thread_token, TokenUser, token_user, len, &len)) { + if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) + goto getout; + token_user = (PTOKEN_USER)realloc(token_user, len); + if (!token_user) + goto getout; + } + if (!IsValidSid(token_user->User.Sid)) + goto getout; + domain_len = len = 256; + if (!LookupAccountSidW(NULL, token_user->User.Sid, username, &len, domain, &domain_len, &sid_type)) + goto getout; + + size_t alen = wcslen(username); + size_t blen = wcslen(domain); + + result = (wchar_t*)malloc((alen + blen + 2) * sizeof(wchar_t)); + if (result) { + result[alen] = '@'; + memcpy(result, username, alen * sizeof(wchar_t)); + memcpy(result + alen + 1, domain, (blen + 1) * sizeof(wchar_t)); + } +getout: + free(token_user); + CloseHandle(thread_token); + return result; +} + + +static DWORD GetNonTransientServiceStatus(SC_HANDLE service) { + SERVICE_STATUS ssSvcStatus = {}; + int delay = 100; + for(;;) { + if (!QueryServiceStatus(service, &ssSvcStatus)) + return 0; + + if (--delay == 0 || + ssSvcStatus.dwCurrentState != SERVICE_START_PENDING && + ssSvcStatus.dwCurrentState != SERVICE_STOP_PENDING) + return ssSvcStatus.dwCurrentState; + Sleep(100); + delay--; + } +} + + +bool ServiceHandles::StartService() { + DWORD state = GetNonTransientServiceStatus(service); + if (state == 0 || state == SERVICE_RUNNING) + return false; // service already running, no need to start + if (!::StartService(service, 0, NULL)) { +// if (GetLastError() == ERROR_SERVICE_ALREADY_RUNNING) +// return false; + return false; + } + return GetNonTransientServiceStatus(service) == SERVICE_RUNNING; +} + + +static bool StartTunsafeService() { + ServiceHandles handles; + + if (!handles.Open(SERVICE_NAME, SC_MANAGER_CONNECT, SERVICE_START | SERVICE_QUERY_STATUS)) + return false; + return handles.StartService(); +} + +bool IsTunsafeServiceRunning() { + ServiceHandles handles; + + if (!handles.Open(SERVICE_NAME, SC_MANAGER_CONNECT, SERVICE_QUERY_STATUS)) + return false; + + return GetNonTransientServiceStatus(handles.service) == SERVICE_RUNNING; +} + + +void StopTunsafeService() { + ServiceHandles handles; + if (!handles.Open(SERVICE_NAME, SC_MANAGER_CONNECT, + SERVICE_STOP | SERVICE_QUERY_STATUS)) + goto Cleanup; + handles.StopService(); +Cleanup: + return; +} + +static void SetTunsafeUserNameInRegistry() { + wchar_t *user = GetUsernameOfCurrentUser(false); + if (!user) { + RERROR("Unable to get current username"); + return; + } + HKEY hkey = NULL; + RegCreateKeyEx(HKEY_LOCAL_MACHINE, "Software\\TunSafe", NULL, NULL, 0, KEY_ALL_ACCESS, NULL, &hkey, NULL); + if (!hkey) { + RERROR("Unable to open registry key"); + return; + } + if (RegSetValueExW(hkey, L"AllowedUsername", NULL, REG_SZ, (BYTE*)user, (DWORD)(wcslen(user) + 1) * 2) != ERROR_SUCCESS) { + RERROR("Unable to set registry key"); + } + RegCloseKey(hkey); +} + +void InstallTunSafeWindowsService() { + InstallService(SERVICE_NAME, L"TunSafe Service", SERVICE_START_TYPE, + SERVICE_DEPENDENCIES, SERVICE_ACCOUNT, SERVICE_PASSWORD); + StartTunsafeService(); + SetTunsafeUserNameInRegistry(); +} + +bool UninstallTunSafeWindowsService() { + ServiceHandles handles; + + if (!handles.Open(SERVICE_NAME, SC_MANAGER_CONNECT, + SERVICE_STOP | SERVICE_QUERY_STATUS | DELETE)) + goto Cleanup; + + handles.StopService(); + + if (!DeleteService(handles.service)) + goto Cleanup; + return true; +Cleanup: + return false; +} + +bool IsTunSafeServiceInstalled() { + ServiceHandles handles; + return handles.Open(SERVICE_NAME, SC_MANAGER_CONNECT, SERVICE_QUERY_STATUS); +} + + +static void WriteServiceLog(const char *pszFunction, WORD dwError) { + char szMessage[260]; + snprintf(szMessage, ARRAYSIZE(szMessage), "%s failed w/err 0x%08lx", pszFunction, dwError); + HANDLE hEventSource = NULL; + LPCSTR lpszStrings[2] = {NULL, NULL}; + hEventSource = RegisterEventSourceW(NULL, SERVICE_NAME); + if (hEventSource) { + lpszStrings[0] = SERVICE_NAMEA; + lpszStrings[1] = szMessage; + + ReportEventA(hEventSource, // Event log handle + dwError, // Event type + 0, // Event category + 0, // Event identifier + NULL, // No security identifier + 2, // Size of lpszStrings array + 0, // No binary data + lpszStrings, // Array of strings + NULL // No binary data + ); + DeregisterEventSource(hEventSource); + } +} + +static void SetServiceStatus(DWORD dwCurrentState, + DWORD dwWin32ExitCode = 0, + DWORD dwWaitHint = 0) { + static DWORD dwCheckPoint = 1; + + SERVICE_STATUS m_status; + m_status.dwServiceType = SERVICE_WIN32_OWN_PROCESS; + m_status.dwControlsAccepted = SERVICE_ACCEPT_STOP | SERVICE_ACCEPT_SHUTDOWN; + m_status.dwServiceSpecificExitCode = 0; + m_status.dwCurrentState = dwCurrentState; + m_status.dwWin32ExitCode = dwWin32ExitCode; + m_status.dwWaitHint = dwWaitHint; + m_status.dwCheckPoint = + ((dwCurrentState == SERVICE_RUNNING) || + (dwCurrentState == SERVICE_STOPPED)) ? + 0 : dwCheckPoint++; + // Report the status of the service to the SCM. + ::SetServiceStatus(m_statusHandle, &m_status); +} + +static void OnServiceStart(DWORD dwArgc, PWSTR *pszArgv) { + WriteServiceLog("Service Starting", EVENTLOG_INFORMATION_TYPE); + SetServiceStatus(SERVICE_START_PENDING); + DWORD rv = g_service->OnStart(dwArgc, pszArgv); + if (rv) { + SetServiceStatus(SERVICE_STOPPED, rv); + } else { + SetServiceStatus(SERVICE_RUNNING); + } +} + +static void OnServiceStop() { + WriteServiceLog("Service Stopping", EVENTLOG_INFORMATION_TYPE); + SetServiceStatus(SERVICE_STOP_PENDING); + g_service->OnStop(); + SetServiceStatus(SERVICE_STOPPED); +} + +static void OnServiceShutdown() { + g_service->OnShutdown(); + SetServiceStatus(SERVICE_STOPPED); +} + +static void WINAPI ServiceCtrlHandler(DWORD dwCtrl) { + switch (dwCtrl) { + case SERVICE_CONTROL_STOP: OnServiceStop(); break; +// case SERVICE_CONTROL_PAUSE: OnServicePause(); break; +// case SERVICE_CONTROL_CONTINUE: OnServiceContinue(); break; + case SERVICE_CONTROL_SHUTDOWN: OnServiceShutdown(); break; + case SERVICE_CONTROL_INTERROGATE: break; + default: break; + } +} + +static void WINAPI ServiceMain(DWORD dwArgc, PWSTR *pszArgv) { + // Register the handler function for the service + m_statusHandle = RegisterServiceCtrlHandlerW(SERVICE_NAME, ServiceCtrlHandler); + if (m_statusHandle == NULL) + throw GetLastError(); + // Start the service. + OnServiceStart(dwArgc, pszArgv); +} + +static const SERVICE_TABLE_ENTRYW serviceTable[] = { + {SERVICE_NAME, ServiceMain}, + {NULL, NULL} +}; + +PipeMessageHandler::PipeMessageHandler(const char *pipe_name, bool is_server_pipe, Delegate *delegate) { + pipe_name_ = _strdup(pipe_name); + is_server_pipe_ = is_server_pipe; + delegate_ = delegate; + pipe_ = INVALID_HANDLE_VALUE; + wait_handles_[0] = CreateEvent(NULL, TRUE, FALSE, NULL); // for ReadFile + wait_handles_[1] = CreateEvent(NULL, FALSE, FALSE, NULL); // For Exit + wait_handles_[2] = CreateEvent(NULL, TRUE, FALSE, NULL); // for WriteFile + packets_ = NULL; + thread_ = NULL; + packets_end_ = &packets_; + write_overlapped_active_ = false; + exit_ = false; + connection_established_ = false; + thread_id_ = 0; +} + +PipeMessageHandler::~PipeMessageHandler() { + StopThread(); + CloseHandle(wait_handles_[0]); + CloseHandle(wait_handles_[1]); + CloseHandle(wait_handles_[2]); + free(pipe_name_); +} + +bool PipeMessageHandler::InitializeServerPipe() { + int BUFSIZE = 2048; + SECURITY_ATTRIBUTES saPipeSecurity = {0}; + uint8 buf[SECURITY_DESCRIPTOR_MIN_LENGTH]; + PSECURITY_DESCRIPTOR pPipeSD = (PSECURITY_DESCRIPTOR)buf; + + if (!InitializeSecurityDescriptor(pPipeSD, SECURITY_DESCRIPTOR_REVISION)) + return false; + + // set NULL DACL on the SD + if (!SetSecurityDescriptorDacl(pPipeSD, TRUE, (PACL)NULL, FALSE)) + return false; + + // now set up the security attributes + saPipeSecurity.nLength = sizeof(SECURITY_ATTRIBUTES); + saPipeSecurity.bInheritHandle = TRUE; + saPipeSecurity.lpSecurityDescriptor = pPipeSD; + + pipe_ = CreateNamedPipeW(L"\\\\.\\pipe\\TunSafe\\ServiceControl", + PIPE_ACCESS_DUPLEX | FILE_FLAG_OVERLAPPED, + PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS | PIPE_WAIT, + PIPE_UNLIMITED_INSTANCES, + BUFSIZE, BUFSIZE, 0, &saPipeSecurity); + return pipe_ != INVALID_HANDLE_VALUE; +} + +bool PipeMessageHandler::InitializeClientPipe() { + assert(pipe_ == INVALID_HANDLE_VALUE); + pipe_ = CreateFile( + pipe_name_, + GENERIC_READ | GENERIC_WRITE, 0, + NULL, OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL); + if (pipe_ == INVALID_HANDLE_VALUE) + return false; + DWORD mode = PIPE_READMODE_MESSAGE; + SetNamedPipeHandleState(pipe_, &mode, NULL, NULL); + return true; +} + +void PipeMessageHandler::ClosePipe() { + if (pipe_ != INVALID_HANDLE_VALUE) { + CancelIo(pipe_); + CloseHandle(pipe_); + pipe_ = INVALID_HANDLE_VALUE; + } + connection_established_ = false; + write_overlapped_active_ = false; + + packets_mutex_.Acquire(); + OutgoingPacket *packets = packets_; + packets_ = NULL; + packets_end_ = &packets_; + packets_mutex_.Release(); + while (packets) { + OutgoingPacket *p = packets; + packets = p->next; + free(p); + } +} + +bool PipeMessageHandler::WritePacket(int type, const uint8 *data, size_t data_size) { + OutgoingPacket *packet = (OutgoingPacket *)malloc(offsetof(OutgoingPacket, data[data_size + 1])); + if (packet) { + packet->size = (uint32)(data_size + 1); + packet->data[0] = type; + memcpy(packet->data + 1, data, data_size); + packet->next = NULL; + + packets_mutex_.Acquire(); + OutgoingPacket *was_empty = packets_; + // login messages are always queued up front + if (type == SERVICE_REQ_LOGIN) { + packet->next = packets_; + if (packet->next == NULL) + packets_end_ = &packet->next; + packets_ = packet; + } else { + *packets_end_ = packet; + packets_end_ = &packet->next; + } + packets_mutex_.Release(); + + if (was_empty == NULL) { + // Only allow the pipe thread to invoke the send + if (GetCurrentThreadId() == thread_id_) { + SendNextQueuedWrite(); + } else { + SetEvent(wait_handles_[1]); + } + } + } + return true; +} + +void PipeMessageHandler::SendNextQueuedWrite() { + assert(thread_id_ == GetCurrentThreadId()); + if (!write_overlapped_active_) { + OutgoingPacket *p = packets_; + if (p && connection_established_) { + memset(&write_overlapped_, 0, sizeof(write_overlapped_)); + write_overlapped_.hEvent = wait_handles_[2]; + if (WriteFile(pipe_, p->data, p->size, NULL, &write_overlapped_) || GetLastError() == ERROR_IO_PENDING) + write_overlapped_active_ = true; + } + } +} + +uint8 *PipeMessageHandler::ReadNamedPipeAsync(size_t *packet_size) { + OVERLAPPED ov = {0}; + uint8 *result = NULL; + DWORD bytes_waiting = 0; + DWORD rv; + ov.hEvent = wait_handles_[0]; + if (!ReadFile(pipe_, NULL, 0, NULL, &ov)) { + rv = GetLastError(); + if (rv != ERROR_IO_PENDING && rv != ERROR_MORE_DATA) + goto getout; + } + + if (!WaitAndHandleWrites(INFINITE)) { + CancelIo(pipe_); + write_overlapped_active_ = false; + goto getout; + } + + PeekNamedPipe(pipe_, NULL, 0, NULL, &bytes_waiting, NULL); + if (bytes_waiting == 0) + goto getout; // this is typically what happens when pipe closes. + + result = (uint8*)malloc(bytes_waiting); + if (!result) + goto getout; + + if (!ReadFile(pipe_, result, bytes_waiting, NULL, &ov)) { + rv = GetLastError(); + if (rv != ERROR_IO_PENDING) + goto getout; + } + if (!WaitAndHandleWrites(1000)) { + CancelIo(pipe_); + write_overlapped_active_ = false; + free(result); + result = NULL; + goto getout; + } + bytes_waiting = (uint32)ov.InternalHigh; + if (bytes_waiting == 0) { + free(result); + result = NULL; + goto getout; + } + *packet_size = bytes_waiting; +getout: + return result; +} + +bool PipeMessageHandler::ConnectNamedPipeAsync() { + OVERLAPPED ov = {0}; + DWORD rv; + bool result = false; + ov.hEvent = wait_handles_[0]; + if (!ConnectNamedPipe(pipe_, &ov)) { + rv = GetLastError(); + if (rv != ERROR_PIPE_CONNECTED && rv != ERROR_IO_PENDING) + goto getout; + } + if (!WaitAndHandleWrites(INFINITE)) { + CancelIo(pipe_); + write_overlapped_active_ = false; + goto getout; + } + result = true; +getout: + return result; +} + +bool PipeMessageHandler::WaitAndHandleWrites(int delay) { + DWORD rv; + assert(thread_id_ == GetCurrentThreadId()); + +again: + rv = WaitForMultipleObjects(2 + write_overlapped_active_, wait_handles_, FALSE, delay); + if (rv == WAIT_OBJECT_0 + 2) { + assert(write_overlapped_active_); + write_overlapped_active_ = false; + // Remove the packet from the front of the queue, now + // that it was sent. + packets_mutex_.Acquire(); + OutgoingPacket *p = packets_; + if ((packets_ = p->next) == NULL) + packets_end_ = &packets_; + packets_mutex_.Release(); + free(p); + SendNextQueuedWrite(); + goto again; + } + if (rv == WAIT_OBJECT_0 + 1) { + if (exit_ || !delegate_->HandleNotify()) + return false; + + SendNextQueuedWrite(); + goto again; + } + return rv == WAIT_OBJECT_0; +} + +DWORD WINAPI PipeMessageHandler::StaticThreadMain(void *x) { + return ((PipeMessageHandler*)x)->ThreadMain(); +} + +bool PipeMessageHandler::VerifyThread() { + return thread_id_ == GetCurrentThreadId(); +} + +DWORD PipeMessageHandler::ThreadMain() { + assert((thread_id_ = GetCurrentThreadId()) != 0); + + while (!exit_) { + // Create a named pipe and wait for connections from the UI process + if (is_server_pipe_) { + if (!InitializeServerPipe()) { + if (!exit_) + ExitProcess(1); + break; + } + // Wait for a client to connect to us. + if (!ConnectNamedPipeAsync()) { + if (!exit_) + ExitProcess(1); + break; + } + } else { + if (!InitializeClientPipe()) { + RINFO("Unable to connect to the TunSafe Service. Please make sure it's running."); + break; + } + } + + connection_established_ = true; + if (!delegate_->HandleNewConnection()) + goto closepipe; + + SendNextQueuedWrite(); + + // Read/Process each message + for (;;) { + size_t message_size; + uint8 *message = ReadNamedPipeAsync(&message_size); + if (!message) + break; + + if (message_size) { + if (!delegate_->HandleMessage(message[0], message + 1, message_size - 1)) { + FlushWrites(1000); + break; + } + } + free(message); + } + + if (exit_) + break; + + delegate_->HandleDisconnect(); + + if (!is_server_pipe_) + break; + +closepipe: + ClosePipe(); + } + + + ClosePipe(); + + return 0; +} + +void PipeMessageHandler::FlushWrites(int delay) { + ResetEvent(wait_handles_[0]); + WaitAndHandleWrites(1000); +} + +bool PipeMessageHandler::StartThread() { + DWORD thread_id; + assert(thread_ == NULL); + thread_ = CreateThread(NULL, 0, &StaticThreadMain, this, 0, &thread_id); + return thread_ != NULL; +} + +void PipeMessageHandler::StopThread() { + if (thread_ != NULL) { + exit_ = true; + SetEvent(wait_handles_[1]); + WaitForSingleObject(thread_, INFINITE); + CloseHandle(thread_); + thread_ = NULL; + } + ClosePipe(); +} + +TunsafeServiceImpl::TunsafeServiceImpl() + : message_handler_(PIPE_NAME, true, this) { + thread_delegate_ = CreateTunsafeBackendDelegateThreaded(this, [=] { + SetEvent(message_handler_.notify_handle()); + }); + + backend_ = CreateNativeTunsafeBackend(thread_delegate_); + historical_log_lines_count_ = historical_log_lines_pos_ = 0; + last_line_sent_ = 0; + did_send_getstate_ = false; + memset(historical_log_lines_, 0, sizeof(historical_log_lines_)); + hkey_ = NULL; + want_graph_type_ = 0xffffffff; + RegCreateKeyEx(HKEY_LOCAL_MACHINE, "Software\\TunSafe", NULL, NULL, 0, KEY_ALL_ACCESS, NULL, &hkey_, NULL); +} + +TunsafeServiceImpl::~TunsafeServiceImpl() { + RegCloseKey(hkey_); +} + +static wchar_t *RegReadStrW(HKEY hkey, const wchar_t *key, const wchar_t *def) { + wchar_t buf[1024]; + DWORD n = sizeof(buf) - 2; + DWORD type = 0; + if (RegQueryValueExW(hkey, key, NULL, &type, (BYTE*)buf, &n) != ERROR_SUCCESS || type != REG_SZ) + return def ? _wcsdup(def) : NULL; + n >>= 1; + if (n && buf[n - 1] == 0) + n--; + buf[n] = 0; + return _wcsdup(buf); +} + +unsigned TunsafeServiceImpl::OnStart(int argc, wchar_t **argv) { + message_handler_.StartThread(); + + uint32 service_flags = RegReadInt(hkey_, "ServiceStartupFlags", 0); + if ( (service_flags & kStartupFlag_BackgroundService) && (service_flags & kStartupFlag_ConnectWhenWindowsStarts) ) { + char *conf = RegReadStr(hkey_, "LastUsedConfigFile", ""); + if (conf && *conf) { + current_filename_ = (char*)conf; + backend_->Start((char*)conf); + } + free(conf); + } + + return 0; +} + +bool TunsafeServiceImpl::AuthenticateUser() { + did_authenticate_user_ = true; + + if (!ImpersonateNamedPipeClient(message_handler_.pipe_handle())) + return false; + wchar_t *user = GetUsernameOfCurrentUser(true); + RevertToSelf(); + if (!user) + return false; + wchar_t *valid_user = RegReadStrW(hkey_, L"AllowedUsername", L""); + bool rv = valid_user && wcscmp(user, valid_user) == 0; + + free(user); + free(valid_user); + return rv; +} + +bool TunsafeServiceImpl::HandleMessage(int type, uint8 *data, size_t size) { + if (!did_authenticate_user_) { + if (type != SERVICE_REQ_LOGIN || size < 8 || *(uint64*)data != kTunsafeServiceProtocolVersion) { + const char *s = "Versioning Problem: The TunSafe service is a different version than the UI."; + message_handler_.WritePacket(SERVICE_MSG_LOGLINE, (uint8*)s, strlen(s)); + return false; + } + if (!AuthenticateUser()) { + const char *s = "Permission Problem: Your Windows account is different from the account\r\nthat installed the TunSafe Service. Please reinstall it.\r\n"; + message_handler_.WritePacket(SERVICE_MSG_LOGLINE, (uint8*)s, strlen(s)); + return false; + } + } + + switch (type) { + case SERVICE_REQ_START: + if (data[size - 1] != 0) + return false; + + // Don't allow reading arbitrary files on disk + if (!EnsureValidConfigPath((char*)data)) { + char buf[MAX_PATH]; + GetConfigPath(buf, sizeof(buf)); + char *s = str_cat_alloc("Permission Problem: The Config file is in an unsafe location.\r\n Must be in:", buf, "\r\n"); + message_handler_.WritePacket(SERVICE_MSG_LOGLINE, (uint8*)s, strlen(s)); + free(s); + return false; + } + + g_allow_pre_post = RegReadInt(hkey_, "AllowPrePost", 0) != 0; + + current_filename_ = (char*)data; + backend_->Start((char*)data); + RegWriteStr(hkey_, "LastUsedConfigFile", (char*)data); + + break; + + case SERVICE_REQ_STOP: + backend_->Stop(); + RegWriteStr(hkey_, "LastUsedConfigFile", ""); + OnStateChanged(); + break; + + case SERVICE_REQ_LOGIN: + did_send_getstate_ = true; + OnStatusCode(backend_->status()); + OnStateChanged(); + SendQueuedLogLines(); + break; + + case SERVICE_REQ_GETSTATS: + if (size < 1) return false; + backend_->RequestStats(data[0] != 0); + break; + + case SERVICE_REQ_SET_INTERNET_BLOCKSTATE: + if (size < 1) + return false; + backend_->SetInternetBlockState((InternetBlockState)data[0]); + OnStateChanged(); + break; + + case SERVICE_REQ_RESETSTATS: + backend_->ResetStats(); + break; + + case SERVICE_REQ_GET_GRAPH: + if (size < 4) return false; + want_graph_type_ = *(int*)data; + TunsafeServiceImpl::OnGraphAvailable(); + break; + + case SERVICE_REQ_SET_STARTUP_FLAGS: + if (size < 4) + return false; + RegSetValueEx(hkey_, "ServiceStartupFlags", NULL, REG_DWORD, (BYTE*)data, 4); + break; + + default: + return false; + } + return true; +} + +bool TunsafeServiceImpl::HandleNotify() { + thread_delegate_->DoWork(); + return true; +} + +bool TunsafeServiceImpl::HandleNewConnection() { + did_send_getstate_ = false; + did_authenticate_user_ = false; + last_line_sent_ = 0; + return true; +} + +void TunsafeServiceImpl::HandleDisconnect() { + want_graph_type_ = 0xffffffff; + backend_->RequestStats(false); + uint32 service_flags = RegReadInt(hkey_, "ServiceStartupFlags", 0); + if (!(service_flags & kStartupFlag_BackgroundService)) + backend_->Stop(); +} + +void TunsafeServiceImpl::OnGraphAvailable() { + if (want_graph_type_ != 0xffffffff) { + LinearizedGraph *graph = backend_->GetGraph(want_graph_type_); + if (graph) + message_handler_.WritePacket(SERVICE_MSG_GRAPH, (uint8*)graph, graph->total_size); + } +} + +void TunsafeServiceImpl::SendQueuedLogLines() { + assert(message_handler_.VerifyThread()); + uint32 maxi = std::min(historical_log_lines_count_, historical_log_lines_pos_ - last_line_sent_); + last_line_sent_ = historical_log_lines_pos_; + for (uint32 i = 0; i < maxi; i++) { + const char *s = historical_log_lines_[(historical_log_lines_pos_ - maxi + i) & (LOGLINE_COUNT - 1)]; + if (s) + message_handler_.WritePacket(SERVICE_MSG_LOGLINE, (uint8*)s, strlen(s)); + } +} + +void TunsafeServiceImpl::OnClearLog() { + historical_log_lines_pos_ = 0; + historical_log_lines_count_ = 0; + message_handler_.WritePacket(SERVICE_MSG_CLEARLOG, NULL, 0); +} + +void TunsafeServiceImpl::OnLogLine(const char **s) { + assert(message_handler_.VerifyThread()); + char *ss = (char*)*s; + *s = NULL; + char *&x = historical_log_lines_[historical_log_lines_pos_++ & (LOGLINE_COUNT - 1)]; + std::swap(x, ss); + if (historical_log_lines_count_ < LOGLINE_COUNT) + historical_log_lines_count_++; + free(ss); + if (did_send_getstate_) + SendQueuedLogLines(); +} + +void TunsafeServiceImpl::OnGetStats(const WgProcessorStats &stats) { + message_handler_.WritePacket(SERVICE_MSG_STATS, (uint8*)&stats, sizeof(stats)); +} + +void TunsafeServiceImpl::OnStateChanged() { + uint8 *temp = new uint8[current_filename_.size() + 1 + sizeof(ServiceState)]; + bool is_activated; + + memset(temp, 0, sizeof(ServiceState)); + + ServiceState *ss = (ServiceState *)temp; + ss->is_started = backend_->is_started(); + ss->internet_block_state = backend_->GetInternetBlockState(&is_activated); + ss->internet_block_state_active = is_activated; + ss->ipv4_ip = backend_->GetIP(); + memcpy(ss->public_key, backend_->public_key(), 32); + + memcpy(temp + sizeof(ServiceState), current_filename_.c_str(), current_filename_.size() + 1); + message_handler_.WritePacket(SERVICE_MSG_STATE, temp, current_filename_.size() + 1 + sizeof(ServiceState)); + delete[] temp; +} + +void TunsafeServiceImpl::OnStatusCode(TunsafeBackend::StatusCode status) { + if (status == TunsafeBackend::kStatusConnected) + OnStateChanged(); // ensure we know the ip first + uint32 v32 = (uint32)status; + message_handler_.WritePacket(SERVICE_MSG_STATUS_CODE, (uint8*)&v32, 4); +} + +void TunsafeServiceImpl::OnStop() { + message_handler_.StopThread(); + backend_->Stop(); +} + +void TunsafeServiceImpl::OnShutdown() { + +} + +static void PushServiceLine(const char *s) { + if (g_service) { + char buf[64]; + SYSTEMTIME t; + + size_t l = strlen(s); + GetLocalTime(&t); + snprintf(buf, sizeof(buf), "[%.2d:%.2d:%.2d] ", t.wHour, t.wMinute, t.wSecond); + size_t tl = strlen(buf); + + char *x = (char*) malloc(tl + l + 3); + memcpy(x, buf, tl); + memcpy(x + tl, s, l); + x[l + tl] = '\r'; + x[l + tl + 1] = '\n'; + x[l + tl + 2] = '\0'; + g_service->delegate()->OnLogLine((const char**)&x); + free(x); + } else { + size_t l = strlen(s); + char buf[1024]; + SYSTEMTIME t; + GetLocalTime(&t); + + snprintf(buf, sizeof(buf), "[%.2d:%.2d:%.2d] ", t.wHour, t.wMinute, t.wSecond); + size_t tl = strlen(buf); + + if (l >= ARRAYSIZE(buf) - tl - 1) + l = ARRAYSIZE(buf) - tl - 1; + + memcpy(buf + tl, s, l); + buf[l + tl] = '\0'; + + WriteServiceLog(buf, EVENTLOG_INFORMATION_TYPE); + } +} + +BOOL RunProcessAsTunsafeServiceProcess() { + g_service = new TunsafeServiceImpl; + g_logger = &PushServiceLine; + + //g_service->OnStart(NULL, 0); + + //MessageBoxA(0, "Service running", "Service running", 0); + //return TRUE; +// while (true)Sleep(1000); + + // Connects the main thread of a service process to the service control + // manager, which causes the thread to be the service control dispatcher + // thread for the calling process. This call returns when the service has + // stopped. The process should simply terminate when the call returns. + return StartServiceCtrlDispatcherW(serviceTable); +} +TunsafeServiceClient::TunsafeServiceClient(TunsafeBackend::Delegate *delegate) + : message_handler_(PIPE_NAME, false, this) { + is_remote_ = true; + got_state_from_control_ = false; + delegate_ = delegate; + cached_graph_ = 0; + last_graph_type_ = 0xffffffff; + memset(&service_state_, 0, sizeof(service_state_)); +} + +TunsafeServiceClient::~TunsafeServiceClient() { + message_handler_.StopThread(); +} + +bool TunsafeServiceClient::Initialize() { + // Wait for the service to start + last_graph_type_ = 0xffffffff; + return message_handler_.StartThread(); +} + +void TunsafeServiceClient::Start(const char *config_file) { + message_handler_.WritePacket(SERVICE_REQ_START, (uint8*)config_file, strlen(config_file) + 1); +} + +void TunsafeServiceClient::Stop() { + message_handler_.WritePacket(SERVICE_REQ_STOP, NULL, 0); +} + +void TunsafeServiceClient::RequestStats(bool enable) { + want_stats_ = enable; + if (message_handler_.is_connected()) + message_handler_.WritePacket(SERVICE_REQ_GETSTATS, &want_stats_, 1); +} + +void TunsafeServiceClient::ResetStats() { + message_handler_.WritePacket(SERVICE_REQ_RESETSTATS, NULL, 0); +} + +InternetBlockState TunsafeServiceClient::GetInternetBlockState(bool *is_activated) { + if (is_activated) + *is_activated = service_state_.internet_block_state_active; + return (InternetBlockState)service_state_.internet_block_state; +} + +void TunsafeServiceClient::SetInternetBlockState(InternetBlockState s) { + uint8 v = (uint8)s; + message_handler_.WritePacket(SERVICE_REQ_SET_INTERNET_BLOCKSTATE, &v, 1); +} + +void TunsafeServiceClient::SetServiceStartupFlags(uint32 flags) { + message_handler_.WritePacket(SERVICE_REQ_SET_STARTUP_FLAGS, (uint8*)&flags, 4); +} + +LinearizedGraph *TunsafeServiceClient::GetGraph(int type) { + if (type != last_graph_type_) { + last_graph_type_ = type; + message_handler_.WritePacket(SERVICE_REQ_GET_GRAPH, (uint8*)&type, 4); + } + mutex_.Acquire(); + LinearizedGraph *graph = cached_graph_; + LinearizedGraph *new_graph = (graph && graph->graph_type == type) ? (LinearizedGraph*)memdup(graph, graph->total_size) : NULL; + mutex_.Release(); + return new_graph; +} + + +std::string TunsafeServiceClient::GetConfigFileName() { + mutex_.Acquire(); + std::string rv = config_file_; + mutex_.Release(); + return rv; +} + +bool TunsafeServiceClient::HandleMessage(int type, uint8 *data, size_t data_size) { + switch(type) { + case SERVICE_MSG_STATE: + if (data_size <= sizeof(service_state_) || data[data_size - 1]) + return false; + got_state_from_control_ = true; + + mutex_.Acquire(); + config_file_.assign((char*)data + sizeof(service_state_), data_size - 1 - sizeof(service_state_)); + memcpy(&service_state_, data, sizeof(service_state_)); + memcpy(public_key_, service_state_.public_key, 32); + is_started_ = service_state_.is_started; + ipv4_ip_ = service_state_.ipv4_ip; + mutex_.Release(); + delegate_->OnStateChanged(); + return true; + case SERVICE_MSG_LOGLINE: { + if (data_size == 0) + return false; + char *s = my_strndup((char*)data, data_size); + delegate_->OnLogLine((const char **)&s); + free(s); + return true; + } + case SERVICE_MSG_STATS: { + WgProcessorStats stats; + if (data_size != sizeof(WgProcessorStats)) + return false; + memcpy(&stats, data, sizeof(WgProcessorStats)); + delegate_->OnGetStats(stats); + return true; + } + case SERVICE_MSG_CLEARLOG: + delegate_->OnClearLog(); + return true; + + case SERVICE_MSG_STATUS_CODE: + if (data_size < 4) + return false; + status_ = (StatusCode)*(uint32*)data; + delegate_->OnStatusCode(status_); + return true; + + case SERVICE_MSG_GRAPH: + if (data_size < 4 || data_size != *(uint32*)data) + return false; + + LinearizedGraph *graph = (LinearizedGraph*)memdup(data, data_size); + mutex_.Acquire(); + std::swap(graph, cached_graph_); + mutex_.Release(); + free(graph); + delegate_->OnGraphAvailable(); + return true; + } + + return false; +} + +bool TunsafeServiceClient::HandleNotify() { + return true; +} + + +bool TunsafeServiceClient::HandleNewConnection() { + message_handler_.WritePacket(SERVICE_REQ_LOGIN, (uint8*)&kTunsafeServiceProtocolVersion, 8); + if (want_stats_) + message_handler_.WritePacket(SERVICE_REQ_GETSTATS, &want_stats_, 1); + return true; +} + +void TunsafeServiceClient::HandleDisconnect() { + status_ = TunsafeBackend::kErrorServiceLost; + delegate_->OnStatusCode(TunsafeBackend::kErrorServiceLost); +} + +void TunsafeServiceClient::Teardown() { + message_handler_.StopThread(); +} + +TunsafeBackend *CreateTunsafeServiceClient(TunsafeBackend::Delegate *delegate) { + TunsafeServiceClient *client = new TunsafeServiceClient(delegate); + if (client && !client->Initialize()) { + delete client; + client = NULL; + } + return client; +} + + diff --git a/service_win32.h b/service_win32.h new file mode 100644 index 0000000..eee61be --- /dev/null +++ b/service_win32.h @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. +#pragma once + +#include "service_win32_api.h" +#include +#include "util.h" +#include "network_win32_api.h" +#include "tunsafe_threading.h" +#include +#include +#include + +struct ServiceState { + uint8 is_started : 1; + uint8 internet_block_state_active : 1; + uint8 internet_block_state; + uint8 reserved[26+64]; + uint32 ipv4_ip; + uint8 public_key[32]; +}; + +STATIC_ASSERT(sizeof(ServiceState) == 128, ServiceState_wrong_size); + +class PipeMessageHandler { +public: + class Delegate { + public: + virtual bool HandleMessage(int type, uint8 *data, size_t size) = 0; + virtual bool HandleNotify() = 0; + virtual bool HandleNewConnection() = 0; + virtual void HandleDisconnect() = 0; + }; + + PipeMessageHandler(const char *pipe_name, bool is_server_pipe, Delegate *delegate); + ~PipeMessageHandler(); + + bool StartThread(); + void StopThread(); + + bool WritePacket(int type, const uint8 *data, size_t data_size); + + HANDLE notify_handle() { return wait_handles_[1]; } + HANDLE pipe_handle() { return pipe_; } + + bool VerifyThread(); + + void FlushWrites(int delay); + bool is_connected() { return connection_established_; } +private: + bool InitializeServerPipe(); + bool InitializeClientPipe(); + void ClosePipe(); + DWORD ThreadMain(); + void SendNextQueuedWrite(); + uint8 *ReadNamedPipeAsync(size_t *packet_size); + bool ConnectNamedPipeAsync(); + bool WaitAndHandleWrites(int delay); + static DWORD WINAPI StaticThreadMain(void *x); + + Delegate *delegate_; + + HANDLE pipe_; + HANDLE thread_; + HANDLE wait_handles_[3]; + OVERLAPPED write_overlapped_; + bool write_overlapped_active_; + bool exit_; + bool is_server_pipe_; + bool connection_established_; + char *pipe_name_; + + struct OutgoingPacket { + OutgoingPacket *next; + uint32 size; + uint8 data[0]; + }; + OutgoingPacket *packets_, **packets_end_; + + Mutex packets_mutex_; + + DWORD thread_id_; +}; + + +class TunsafeServiceImpl : public TunsafeBackend::Delegate, public PipeMessageHandler::Delegate { +public: + TunsafeServiceImpl(); + virtual ~TunsafeServiceImpl(); + + // -- from TunsafeBackend::Delegate + virtual void OnGetStats(const WgProcessorStats &stats); + virtual void OnClearLog(); + virtual void OnLogLine(const char **s); + virtual void OnStateChanged(); + virtual void OnStatusCode(TunsafeBackend::StatusCode status); + virtual void OnGraphAvailable(); + + // -- from PipeMessageHandler::Delegate + virtual bool HandleMessage(int type, uint8 *data, size_t size); + virtual bool HandleNotify(); + virtual bool HandleNewConnection(); + virtual void HandleDisconnect(); + + // virtual methods + virtual unsigned OnStart(int argc, wchar_t **argv); + virtual void OnStop(); + virtual void OnShutdown(); + + TunsafeBackend::Delegate *delegate() { return thread_delegate_; } + +private: + void SendQueuedLogLines(); + bool AuthenticateUser(); + + bool did_send_getstate_; + + bool did_authenticate_user_; + uint32 want_graph_type_; + + HKEY hkey_; + + TunsafeBackend *backend_; + TunsafeBackend::Delegate *thread_delegate_; + + PipeMessageHandler message_handler_; + + uint32 historical_log_lines_pos_; + uint32 historical_log_lines_count_; + uint32 last_line_sent_; + std::string current_filename_; + + enum { + LOGLINE_COUNT = 256 + }; + char *historical_log_lines_[LOGLINE_COUNT]; +}; + +class TunsafeServiceClient : public TunsafeBackend, public PipeMessageHandler::Delegate { +public: + TunsafeServiceClient(TunsafeBackend::Delegate *delegate); + virtual ~TunsafeServiceClient(); + virtual bool Initialize(); + virtual void Teardown(); + virtual void Start(const char *config_file); + virtual void Stop(); + virtual void RequestStats(bool enable); + virtual void ResetStats(); + virtual InternetBlockState GetInternetBlockState(bool *is_activated); + virtual void SetInternetBlockState(InternetBlockState s); + virtual std::string GetConfigFileName(); + virtual void SetServiceStartupFlags(uint32 flags); + virtual LinearizedGraph *GetGraph(int type); + + // -- from PipeMessageHandler::Delegate + virtual bool HandleMessage(int type, uint8 *data, size_t size); + virtual bool HandleNotify(); + virtual bool HandleNewConnection(); + virtual void HandleDisconnect(); + +protected: + TunsafeBackend::Delegate *delegate_; + uint8 want_stats_; + bool got_state_from_control_; + ServiceState service_state_; + std::string config_file_; + PipeMessageHandler message_handler_; + LinearizedGraph *cached_graph_; + uint32 last_graph_type_; + Mutex mutex_; +}; diff --git a/service_win32_api.h b/service_win32_api.h new file mode 100644 index 0000000..dac7151 --- /dev/null +++ b/service_win32_api.h @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. +#pragma once + +#include "network_win32_api.h" + +enum StartupFlags { + kStartupFlag_ForegroundService = 1, + kStartupFlag_BackgroundService = 2, + kStartupFlag_ConnectWhenWindowsStarts = 4, + kStartupFlag_MinimizeToTrayWhenWindowsStarts = 8, +}; + +BOOL RunProcessAsTunsafeServiceProcess(); + +void StopTunsafeService(); + +bool IsTunSafeServiceInstalled(); + +bool IsTunsafeServiceRunning(); +void InstallTunSafeWindowsService(); +bool UninstallTunSafeWindowsService(); + +TunsafeBackend *CreateTunsafeServiceClient(TunsafeBackend::Delegate *delegate); diff --git a/stdafx.h b/stdafx.h index bd6427f..61625f4 100644 --- a/stdafx.h +++ b/stdafx.h @@ -21,7 +21,7 @@ #include #include #include - +#include #include #else @@ -31,3 +31,5 @@ #include #include +#undef min + diff --git a/tunsafe_config.h b/tunsafe_config.h index 2f29472..d493b25 100644 --- a/tunsafe_config.h +++ b/tunsafe_config.h @@ -1,6 +1,8 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. #pragma once -#define TUNSAFE_VERSION_STRING "TunSafe 1.3-rc3" +#define TUNSAFE_VERSION_STRING "TunSafe 1.4-rc1" #define WITH_HANDSHAKE_EXT 0 #define WITH_SHORT_HEADERS 0 diff --git a/tunsafe_cpu.cpp b/tunsafe_cpu.cpp index b1ee8cc..ec14aac 100644 --- a/tunsafe_cpu.cpp +++ b/tunsafe_cpu.cpp @@ -10,6 +10,16 @@ #include +static char *strcpy_e(char *dst, char *end, const char *copy) { + size_t len = strlen(copy); + if (len >= (size_t)(end - dst)) return end; + memcpy(dst, copy, len + 1); + return dst + len; +} + + +#if defined(ARCH_CPU_X86_FAMILY) + uint32 x86_pcap[3]; #if !defined(COMPILER_MSVC) @@ -22,6 +32,7 @@ static inline void __cpuid(int info[4], int func) { } #endif + void InitCpuFeatures() { unsigned nIds, nExIds; @@ -45,13 +56,6 @@ void InitCpuFeatures() { } } -static char *strcpy_e(char *dst, char *end, const char *copy) { - size_t len = strlen(copy); - if (len >= (size_t)(end - dst)) return end; - memcpy(dst, copy, len + 1); - return dst + len; -} - void PrintCpuFeatures() { char capbuf[2048], *end = capbuf + 2048, *s = capbuf; @@ -66,3 +70,22 @@ void PrintCpuFeatures() { RINFO("Using:%s", capbuf); } + +#endif // defined(ARCH_CPU_X86_FAMILY) + +#if defined(ARCH_CPU_ARM_FAMILY) + +uint32 arm_pcap[1]; + +void InitCpuFeatures() { + arm_pcap[0] = 0xffffffff; +} + +void PrintCpuFeatures() { + char capbuf[2048], *end = capbuf + 2048, *s = capbuf; + + if (ARM_PCAP_NEON) s = strcpy_e(s, end, " neon"); + + RINFO("Using:%s", capbuf); +} +#endif // defined(ARCH_CPU_ARM_FAMILY) diff --git a/tunsafe_cpu.h b/tunsafe_cpu.h index de97b6c..c19f6b1 100644 --- a/tunsafe_cpu.h +++ b/tunsafe_cpu.h @@ -5,6 +5,9 @@ #include "tunsafe_types.h" + +#if defined(ARCH_CPU_X86_FAMILY) + extern uint32 x86_pcap[3]; // cpuid 1, edx @@ -22,8 +25,19 @@ extern uint32 x86_pcap[3]; #define X86_PCAP_AVX512F (x86_pcap[2] & (1 << 16)) #define X86_PCAP_AVX512VL (x86_pcap[2] & (1 << 31)) +#endif // defined(ARCH_CPU_X86_FAMILY) + + +#if defined(ARCH_CPU_ARM_FAMILY) + +extern uint32 arm_pcap[1]; + +#define ARM_PCAP_NEON (arm_pcap[0] & (1 << 0)) + +#endif // defined(ARCH_CPU_ARM_FAMILY) + void InitCpuFeatures(); void PrintCpuFeatures(); -#endif // TUNSAFE_CPU_H_ \ No newline at end of file +#endif // TUNSAFE_CPU_H_ diff --git a/tunsafe_endian.h b/tunsafe_endian.h index 32bce5e..45cb316 100644 --- a/tunsafe_endian.h +++ b/tunsafe_endian.h @@ -70,6 +70,7 @@ #define ReadBE32Aligned(pt) ToBE32(*(uint32*)(pt)) #define WriteBE32Aligned(ct, st) (*(uint32*)(ct) = ToBE32(st)) +// todo: these need to support unaligned pointers #define ReadBE16(pt) ToBE16(*(uint16*)(pt)) #define WriteBE16(ct, st) (*(uint16*)(ct) = ToBE16(st)) #define ReadBE32(pt) ToBE32(*(uint32*)(pt)) diff --git a/tunsafe_threading.cpp b/tunsafe_threading.cpp new file mode 100644 index 0000000..af21db3 --- /dev/null +++ b/tunsafe_threading.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. +#include "stdafx.h" +#include "tunsafe_threading.h" +#include + +MultithreadedDelayedDelete::MultithreadedDelayedDelete() { + table_ = NULL; + num_threads_ = 0; +} + +MultithreadedDelayedDelete::~MultithreadedDelayedDelete() { + free(table_); +} + +void MultithreadedDelayedDelete::Initialize(uint32 num_threads) { + num_threads_ = num_threads; + table_ = (CheckpointData*)calloc(sizeof(CheckpointData), num_threads); +} + +void MultithreadedDelayedDelete::Add(DoDeleteFunc *func, void *param) { + if (num_threads_ == 0) { + func(param); + return; + } + lock_.Acquire(); + Entry e = {func, param}; + curr_.push_back(e); + lock_.Release(); +} + +void MultithreadedDelayedDelete::Checkpoint(uint32 thread_id) { + table_[thread_id].value.store(1); +} + +void MultithreadedDelayedDelete::MainCheckpoint() { + // Wait for all threads to signal that they reached the checkpoint + for (size_t i = 0; i < num_threads_; i++) { + if (table_[i].value.load() == 0) + return; + } + + // All threads reached the checkpoint, clear the values + for (size_t i = 0; i < num_threads_; i++) + table_[i].value.store(0); + + // Swap curr and next, and delete all nexts. + lock_.Acquire(); + std::swap(curr_, next_); + std::swap(curr_, to_delete_); + lock_.Release(); + + for (auto it = to_delete_.begin(); it != to_delete_.end(); ++it) { + it->func(it->param); + } + to_delete_.clear(); +} diff --git a/tunsafe_threading.h b/tunsafe_threading.h new file mode 100644 index 0000000..1362678 --- /dev/null +++ b/tunsafe_threading.h @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. +#pragma once +#include "tunsafe_types.h" +#include +#include +#include +#if !defined(OS_WIN) +#include +#endif // !defined(OS_WIN) + +#if defined(OS_WIN) + +class ReaderWriterLock { +public: + ReaderWriterLock() : lock_(SRWLOCK_INIT) {} + void AcquireExclusive() { AcquireSRWLockExclusive(&lock_); } + void AcquireShared() { AcquireSRWLockShared(&lock_); } + void ReleaseExclusive() { ReleaseSRWLockExclusive(&lock_); } + void ReleaseShared() { ReleaseSRWLockShared(&lock_); } +private: + SRWLOCK lock_; +}; + +class Mutex { +public: +#if defined(_DEBUG) + bool locked_; + bool IsLocked() { return locked_; } +#define Mutex_SETLOCKED(x) locked_ = x; +#else + bool IsLocked() { return false; } +#define Mutex_SETLOCKED(x) +#endif + Mutex() : lock_(SRWLOCK_INIT) { Mutex_SETLOCKED(false); } + ~Mutex() { } + void Acquire() { + AcquireSRWLockExclusive(&lock_); + Mutex_SETLOCKED(true); + } + void Release() { + Mutex_SETLOCKED(false); + ReleaseSRWLockExclusive(&lock_); + } +private: + SRWLOCK lock_; +}; + +typedef uint32 ThreadId; + +static inline bool CurrentThreadIdEquals(ThreadId thread_id) { + return thread_id == GetCurrentThreadId(); +} + +#else // defined(OS_WIN) + +class ReaderWriterLock { +public: + ReaderWriterLock() { + if (pthread_rwlock_init(&lock_, NULL) != 0) + tunsafe_die("pthread_rwlock_init failed"); + } + ~ReaderWriterLock() { + pthread_rwlock_destroy(&lock_); + } + void AcquireExclusive() { int rv = pthread_rwlock_wrlock(&lock_); assert(rv == 0); } + void AcquireShared() { int rv = pthread_rwlock_rdlock(&lock_); assert(rv == 0); } + void ReleaseExclusive() { int rv = pthread_rwlock_unlock(&lock_); assert(rv == 0); } + void ReleaseShared() { int rv = pthread_rwlock_unlock(&lock_); assert(rv == 0); } +private: + pthread_rwlock_t lock_; +}; + +class Mutex { +public: +#if defined(_DEBUG) + bool locked_; + bool IsLocked() { return locked_; } +#define Mutex_SETLOCKED(x) locked_ = x; +#else + bool IsLocked() { return false; } +#define Mutex_SETLOCKED(x) +#endif + Mutex() { + if (pthread_mutex_init(&lock_, NULL) != 0) + tunsafe_die("pthread_mutex_init failed"); + Mutex_SETLOCKED(false); + } + ~Mutex() { + pthread_mutex_destroy(&lock_); + } + void Acquire() { + int rv = pthread_mutex_lock(&lock_); + assert(rv == 0); + Mutex_SETLOCKED(true); + } + void Release() { + Mutex_SETLOCKED(false); + int rv = pthread_mutex_unlock(&lock_); + assert(rv == 0); + } + pthread_mutex_t *impl() { return &lock_; } +private: + pthread_mutex_t lock_; +}; + +typedef pthread_t ThreadId; + +static inline bool CurrentThreadIdEquals(ThreadId thread_id) { + return pthread_equal(thread_id, pthread_self()) != 0; +} + +static inline ThreadId GetCurrentThreadId() { + return pthread_self(); +} + +#endif // !defined(OS_WIN) + +class ScopedLockShared { +public: + ScopedLockShared(ReaderWriterLock *lock) : lock_(lock) { lock->AcquireShared(); } + ~ScopedLockShared() { lock_->ReleaseShared(); } +private: + ReaderWriterLock *lock_; +}; + +class ScopedLockExclusive { +public: + ScopedLockExclusive(ReaderWriterLock *lock) : lock_(lock) { lock->AcquireExclusive(); } + ~ScopedLockExclusive() { lock_->ReleaseExclusive(); } +private: + ReaderWriterLock *lock_; +}; + +class ScopedLock { +public: + ScopedLock(Mutex *lock) : lock_(lock) { lock->Acquire(); } + ~ScopedLock() { lock_->Release(); } +private: + Mutex *lock_; +}; + +// This class deletes objects delayed. All participating threads will call a function, +// and then once all threads did, all registered objects will get deleted. +class MultithreadedDelayedDelete { +public: + MultithreadedDelayedDelete(); + ~MultithreadedDelayedDelete(); + + typedef void DoDeleteFunc(void *x); + void Add(DoDeleteFunc *func, void *param); + + void Initialize(uint32 num_threads); + + void Checkpoint(uint32 thread_id); + + void MainCheckpoint(); + +private: + struct Entry { + DoDeleteFunc *func; + void *param; + }; + + struct CheckpointData { + std::atomic value; + uint8 align[60]; + }; + + uint32 num_threads_; + + std::vector curr_, next_, to_delete_; + CheckpointData *table_; + Mutex lock_; +}; diff --git a/tunsafe_types.h b/tunsafe_types.h index 9ddabab..7bdc600 100644 --- a/tunsafe_types.h +++ b/tunsafe_types.h @@ -68,6 +68,6 @@ static inline uint32 rol32(uint32 x, int8_t r) { void RERROR(const char *msg, ...); void RINFO(const char *msg, ...); - +void tunsafe_die(const char *msg); #endif // TINYVPN_TYPES_H_ diff --git a/tunsafe_win32.cpp b/tunsafe_win32.cpp index 846ce28..7a59d5e 100644 --- a/tunsafe_win32.cpp +++ b/tunsafe_win32.cpp @@ -23,8 +23,9 @@ #include #include #include "crypto/curve25519-donna.h" +#include "service_win32.h" +#include "util_win32.h" -#undef min #pragma comment(lib, "iphlpapi.lib") #pragma comment(lib, "rpcrt4.lib") #pragma comment(lib,"comctl32.lib") @@ -34,119 +35,103 @@ void InitCpuFeatures(); void PrintCpuFeatures(); void Benchmark(); static const char *GetCurrentConfigTitle(char *buf, size_t max_size); +static char *PrintMB(char *buf, int64 bytes); +static void LoadConfigFile(const char *filename, bool save, bool force_start); +static void SetCurrentConfigFilename(const char *filename); +static void CreateLocalOrRemoteBackend(bool remote); +static void UpdateGraphReq(); #pragma warning(disable: 4200) -static void MyPostMessage(int msg, WPARAM wparam, LPARAM lparam); - +static bool g_is_connected_to_server; +static bool g_notified_connected_server; static HWND g_ui_window; -static in_addr_t g_ui_ip; static HICON g_icons[2]; static bool g_minimize_on_connect; static bool g_ui_visible; static char *g_current_filename; -static HKEY g_reg_key; static HINSTANCE g_hinstance; -static TunsafeBackendWin32 *g_backend; -static bool g_last_popup_is_tray; +static TunsafeBackend *g_backend; +static TunsafeBackend::Delegate *g_backend_delegate; +static const char *g_cmdline_filename; +static bool g_first_state_msg; +static bool g_is_limited_uac_account; +static bool g_is_tunsafe_service_running; +static bool g_disable_connect_on_start; +static bool g_not_first_status_msg; +static HANDLE g_runonce_mutex; +static int g_startup_flags; +static HKEY g_reg_key; +static HKEY g_hklm_reg_key; +static HKEY g_hklm_readonly_reg_key; +static HWND hwndPaintBox, hwndStatus, hwndGraphBox, hwndTab, hwndAdvancedBox, hwndEdit; +static WgProcessorStats g_processor_stats; +static int g_large_fonts; +static TunsafeBackend::StatusCode g_status_code; +static UINT g_message_taskbar_created; +static int g_current_tab; +static bool wm_dropfiles_recursive; +static bool g_has_icon; +static int g_selected_graph_type; +static RECT comborect; +static HBITMAP arrowbitmap; +static uint32 g_timestamp_of_exit_menuloop; +enum UpdateIconWhy { + UIW_NONE = 0, + UIW_STOPPED_WORKING_FAIL = 1, + UIW_START = 2, +}; +static void UpdateIcon(UpdateIconWhy error); -int RegReadInt(const char *key, int def) { - DWORD value = def, n = sizeof(value); - RegQueryValueEx(g_reg_key, key, NULL, NULL, (BYTE*)&value, &n); - return value; + +int RescaleDpi(int size) { + return (g_large_fonts == 96) ? size : size * g_large_fonts / 96; } -void RegWriteInt(const char *key, int value) { - RegSetValueEx(g_reg_key, key, NULL, REG_DWORD, (BYTE*)&value, sizeof(value)); -} - -char *RegReadStr(const char *key, const char *def) { - char buf[1024]; - DWORD n = sizeof(buf) - 1; - DWORD type = 0; - if (RegQueryValueEx(g_reg_key, key, NULL, &type, (BYTE*)buf, &n) != ERROR_SUCCESS || type != REG_SZ) - return def ? _strdup(def) : NULL; - if (n && buf[n - 1] == 0) - n--; - buf[n] = 0; - return _strdup(buf); -} - -void RegWriteStr(const char *key, const char *v) { - RegSetValueEx(g_reg_key, key, NULL, REG_SZ, (BYTE*)v, (DWORD)strlen(v) + 1); -} - -void str_set(char **x, const char *s) { - free(*x); - *x = _strdup(s); -} - -char *str_cat_alloc(const char *a, const char *b) { - size_t al = strlen(a); - size_t bl = strlen(b); - char *r = (char *)malloc(al + bl + 1); - memcpy(r, a, al); - r[al + bl] = 0; - memcpy(r + al, b, bl); - return r; -} - -static const char *FindLastFolderSep(const char *s) { - size_t len = strlen(s); - for (;;) { - if (len == 0) - return NULL; - len--; - if (s[len] == '\\' || s[len] == '/') - break; +RECT RescaleDpiRect(const RECT &r) { + RECT rr = r; + if (g_large_fonts != 96) { + rr.left = rr.left * g_large_fonts / 96; + rr.top = rr.top * g_large_fonts / 96; + rr.right = rr.right * g_large_fonts / 96; + rr.bottom = rr.bottom * g_large_fonts / 96; } - return s + len; + return rr; } +static void SetUiVisibility(bool visible) { + g_ui_visible = visible; + ShowWindow(g_ui_window, visible ? SW_SHOW : SW_HIDE); + g_backend->RequestStats(visible); + UpdateGraphReq(); +} static bool GetConfigFullName(const char *basename, char *fullname, size_t fullname_size) { size_t len = strlen(basename); - if (FindLastFolderSep(basename)) { + if (FindFilenameComponent(basename)[0]) { if (len >= fullname_size) return false; memcpy(fullname, basename, len + 1); return true; } - if (!GetModuleFileName(NULL, fullname, (DWORD)fullname_size)) + size_t clen = GetConfigPath(fullname, fullname_size); + if (clen == 0 || clen + len >= fullname_size) return false; - char *last = (char *)FindLastFolderSep(fullname); - if (!last || last + len + 8 >= fullname + fullname_size) - return false; - memcpy(last + 1, "Config\\", 7 * sizeof(last[0])); - memcpy(last + 8, basename, (len + 1) * sizeof(last[0])); + memcpy(fullname + clen, basename, (len + 1) * sizeof(fullname[0])); return true; } -enum UpdateIconWhy { - UIW_NONE = 0, - UIW_STOPPED_WORKING_FAIL = 1, - UIW_STOPPED_WORKING_RETRY = 2, - UIW_EXITING = 3, -}; -static void UpdateIcon(UpdateIconWhy error); -static void UpdateButtons(); - - -void StopService(UpdateIconWhy error) { +void StopTunsafeBackend(UpdateIconWhy why) { if (g_backend->is_started()) { g_backend->Stop(); - - g_ui_ip = 0; - - if (error != UIW_EXITING) { - UpdateIcon(error); - RINFO("Disconnecting"); - UpdateButtons(); - RegWriteInt("IsConnected", 0); - } + if (g_is_connected_to_server) + RINFO("Disconnected"); + g_is_connected_to_server = false; + UpdateIcon(why); + RegWriteInt(g_reg_key, "IsConnected", 0); } } @@ -155,44 +140,146 @@ const char *print_ip(char buf[kSizeOfAddress], in_addr_t ip) { return buf; } -class MyProcessorDelegate : public ProcessorDelegate { -public: - virtual void OnConnected(in_addr_t my_ip) { - if (my_ip != g_ui_ip) { +void StartTunsafeBackend(UpdateIconWhy reason) { + if (!*g_current_filename) + return; - if (my_ip) { - char buf[kSizeOfAddress]; - print_ip(buf, my_ip); - RINFO("Connection established. IP %s", buf); + // recreate service connection + if (g_backend->status() == TunsafeBackend::kErrorServiceLost) + CreateLocalOrRemoteBackend(g_backend->is_remote()); + + if (g_backend->is_remote() && !EnsureValidConfigPath(g_current_filename)) { + RERROR("The config file needs to be in the Config-directory. Maybe the TunSafe\r\n process doesn't match with the running service. Try selecting 'Don't Use a Service'."); + StopTunsafeBackend(UIW_NONE); + return; + } + g_notified_connected_server = false; + g_is_connected_to_server = false; + g_backend->Start(g_current_filename); + RegWriteInt(g_reg_key, "IsConnected", 1); +} + +static void InvalidatePaintbox() { + InvalidateRect(hwndPaintBox, NULL, FALSE); +} + +class MyBackendDelegate : public TunsafeBackend::Delegate { +public: + virtual void OnGraphAvailable() { + InvalidateRect(hwndGraphBox, NULL, FALSE); + } + + virtual void OnGetStats(const WgProcessorStats &stats) { + g_processor_stats = stats; + InvalidatePaintbox(); + + char buf[64]; + uint32 mbs_in = (uint32)(stats.tun_bytes_out_per_second * (1.0 / 1250)); + uint32 gb_in = (uint32)(stats.tun_bytes_out * (1.0 / (1024 * 1024 * 1024 / 100))); + + snprintf(buf, ARRAYSIZE(buf), "D: %d.%.2d Mbps (%d.%.2d GB)", mbs_in / 100, mbs_in % 100, gb_in / 100, gb_in % 100); + SendMessage(hwndStatus, SB_SETTEXT, 1, (LPARAM)buf); + + uint32 mbs_out = (uint32)(stats.tun_bytes_in_per_second * (1.0 / 1250)); + uint32 gb_out = (uint32)(stats.tun_bytes_in * (1.0 / (1024 * 1024 * 1024 / 100))); + + snprintf(buf, ARRAYSIZE(buf), "U: %d.%.2d Mbps (%d.%.2d GB)", mbs_out / 100, mbs_out % 100, gb_out / 100, gb_out % 100); + SendMessage(hwndStatus, SB_SETTEXT, 2, (LPARAM)buf); + + InvalidateRect(hwndAdvancedBox, NULL, FALSE); + } + + virtual void OnLogLine(const char **s) { + CHARRANGE cr; + cr.cpMin = -1; + cr.cpMax = -1; + // hwnd = rich edit hwnd + SendMessage(hwndEdit, EM_EXSETSEL, 0, (LPARAM)&cr); + SendMessage(hwndEdit, EM_REPLACESEL, 0, (LPARAM)*s); + } + + virtual void OnStateChanged() { + if (!g_first_state_msg) { + g_first_state_msg = true; + char fullname[1024]; + + const char *filename = g_cmdline_filename; + if (filename) { + if (GetConfigFullName(filename, fullname, sizeof(fullname))) + SetCurrentConfigFilename(fullname); + } else { + std::string currconfig = g_backend->GetConfigFileName(); + if (currconfig.empty()) { + char *conf = RegReadStr(g_reg_key, "ConfigFile", "TunSafe.conf"); + if (GetConfigFullName(conf, fullname, sizeof(fullname))) + SetCurrentConfigFilename(fullname); + free(conf); + } else { + SetCurrentConfigFilename(currconfig.c_str()); + } } - g_ui_ip = my_ip; - MyPostMessage(WM_USER + 2, 0, 0); + + if (filename != NULL || !(g_startup_flags & kStartupFlag_BackgroundService) && !g_disable_connect_on_start && RegReadInt(g_reg_key, "IsConnected", 0)) { + StartTunsafeBackend(UIW_START); + } else { + if (!g_backend->is_started()) + RINFO("Press Connect to initiate a connection to the WireGuard server."); + } + } + + bool running = g_backend->is_started(); + SetDlgItemText(g_ui_window, ID_START, running ? "Re&connect" : "&Connect"); + InvalidatePaintbox(); + EnableWindow(GetDlgItem(g_ui_window, ID_STOP), running); + } + + virtual void OnStatusCode(TunsafeBackend::StatusCode status) override { + g_status_code = status; + if (TunsafeBackend::IsPermanentError(status)) { + UpdateIcon(g_is_connected_to_server ? UIW_STOPPED_WORKING_FAIL : UIW_NONE); + InvalidatePaintbox(); + return; + } + bool is_connected = (status == TunsafeBackend::kStatusConnected); + if (is_connected && g_minimize_on_connect) { + g_minimize_on_connect = false; + SetUiVisibility(false); + } + + bool not_first = g_not_first_status_msg; + g_not_first_status_msg = true; + + if (is_connected != g_is_connected_to_server) { + g_is_connected_to_server = is_connected; + // avoid showing a notice if service is already connected + if (is_connected > not_first && (g_startup_flags & kStartupFlag_BackgroundService)) + g_notified_connected_server = true; + UpdateIcon(UIW_NONE); + InvalidatePaintbox(); } } - virtual void OnDisconnected() { - MyProcessorDelegate::OnConnected(0); + + virtual void OnClearLog() override { + SetWindowText(hwndEdit, ""); } }; -static MyProcessorDelegate my_procdel; +static MyBackendDelegate my_procdel; -void StartService(bool skip_clear = false) { - char buf[1024]; - if (!GetConfigFullName(g_current_filename, buf, ARRAYSIZE(buf))) - return; - - if (!g_backend->is_started()) { - if (!skip_clear) - PostMessage(g_ui_window, WM_USER + 6, NULL, NULL); - - g_backend->Start(&my_procdel, buf); +static void CreateLocalOrRemoteBackend(bool remote) { + delete g_backend; - UpdateButtons(); - RegWriteInt("IsConnected", 1); + g_first_state_msg = false; + + if (!remote) { + g_backend = CreateNativeTunsafeBackend(g_backend_delegate); + } else { + RINFO("Connecting to the TunSafe Service..."); + g_backend = CreateTunsafeServiceClient(g_backend_delegate); } -} -static bool g_has_icon; + g_backend->RequestStats(g_ui_visible); +} static char *PrintMB(char *buf, int64 bytes) { char *bo = buf; @@ -215,55 +302,7 @@ static char *PrintMB(char *buf, int64 bytes) { return bo; } -static void UpdateStats() { - ProcessorStats stats = g_backend->GetStats(); - - char tmp[64], tmp2[64]; - char buf[512]; - snprintf(buf, 512, "%s received (%lld packets), %s sent (%lld packets)", - PrintMB(tmp, stats.udp_bytes_in), stats.udp_packets_in, - PrintMB(tmp2, stats.udp_bytes_out), stats.udp_packets_out/*, udp_qsize2 - udp_qsize1, g_tun_reads*/); - SetDlgItemText(g_ui_window, IDTXT_UDP, buf); - - snprintf(buf, 512, "%s received (%lld packets), %s sent (%lld packets)", - PrintMB(tmp, stats.tun_bytes_in), stats.tun_packets_in, - PrintMB(tmp2, stats.tun_bytes_out), stats.tun_packets_out/*, - tpq_last_qsize, g_tun_writes*/); - SetDlgItemText(g_ui_window, IDTXT_TUN, buf); - - char *d = buf; - if (stats.last_complete_handskake_timestamp) { - uint32 ago = (uint32)((OsGetMilliseconds() - stats.last_complete_handskake_timestamp) / 1000); - uint32 hours = ago / 3600; - uint32 minutes = (ago - hours * 3600) / 60; - uint32 seconds = (ago - hours * 3600 - minutes * 60); - - if (hours) - d += snprintf(d, 32, hours == 1 ? "%d hour, " : "%d hours, ", hours); - if (minutes) - d += snprintf(d, 32, minutes == 1 ? "%d minute, " : "%d minutes, ", minutes); - if (d == buf || seconds) - d += snprintf(d, 32, seconds == 1 ? "%d second, " : "%d seconds, ", seconds); - memcpy(d - 2, " ago", 5); - } else { - memcpy(buf, "(never)", 8); - } - SetDlgItemText(g_ui_window, IDTXT_HANDSHAKE, buf); -} - -void UpdatePublicKey(char *s) { - SetDlgItemText(g_ui_window, IDC_PUBLIC_KEY, s); - free(s); -} - -static void UpdateButtons() { - bool running = g_backend->is_started(); - SetDlgItemText(g_ui_window, ID_START, running ? "Re&connect" : "&Connect"); - EnableWindow(GetDlgItem(g_ui_window, ID_STOP), running); -} - static void UpdateIcon(UpdateIconWhy why) { - in_addr_t ip = g_ui_ip; NOTIFYICONDATA nid; memset(&nid, 0, sizeof(nid)); nid.cbSize = sizeof(nid); @@ -272,18 +311,22 @@ static void UpdateIcon(UpdateIconWhy why) { nid.uVersion = NOTIFYICON_VERSION; nid.uCallbackMessage = WM_USER + 1; nid.uFlags = NIF_MESSAGE | NIF_TIP | NIF_ICON; - nid.hIcon = g_icons[ip ? 0 : 1]; + nid.hIcon = g_icons[g_is_connected_to_server ? 0 : 1]; char buf[kSizeOfAddress]; - char namebuf[64]; - if (ip != 0) { - snprintf(nid.szTip, sizeof(nid.szTip), "TunSafe [%s - %s]", GetCurrentConfigTitle(namebuf, sizeof(namebuf)), print_ip(buf, ip)); - nid.uFlags |= NIF_INFO; - snprintf(nid.szInfoTitle, sizeof(nid.szInfoTitle), "Connected to: %s", namebuf); - snprintf(nid.szInfo, sizeof(nid.szInfo), "IP: %s", buf); - nid.uTimeout = 5000; - nid.dwInfoFlags = NIIF_INFO; + char namebuf[128]; + if (g_is_connected_to_server) { + snprintf(nid.szTip, sizeof(nid.szTip), "TunSafe [%s - %s]", GetCurrentConfigTitle(namebuf, sizeof(namebuf)), print_ip(buf, g_backend->GetIP())); + if (!g_notified_connected_server) { + g_notified_connected_server = true; + nid.uFlags |= NIF_INFO; + snprintf(nid.szInfoTitle, sizeof(nid.szInfoTitle), "Connected to: %s", namebuf); + snprintf(nid.szInfo, sizeof(nid.szInfo), "IP: %s", buf); + nid.uTimeout = 5000; + nid.dwInfoFlags = NIIF_INFO; + } } else { + g_notified_connected_server = false; snprintf(nid.szTip, sizeof(nid.szTip), "TunSafe [%s]", "Disconnected"); if (why == UIW_STOPPED_WORKING_FAIL) { @@ -296,7 +339,7 @@ static void UpdateIcon(UpdateIconWhy why) { } Shell_NotifyIcon(g_has_icon ? NIM_MODIFY : NIM_ADD, &nid); - SendMessage(g_ui_window, WM_SETICON, ICON_SMALL, (LPARAM)g_icons[ip ? 0 : 1]); + SendMessage(g_ui_window, WM_SETICON, ICON_SMALL, (LPARAM)g_icons[g_is_connected_to_server ? 0 : 1]); g_has_icon = true; } @@ -312,16 +355,10 @@ static void RemoveIcon() { } } -#define MAX_CONFIG_FILES 100 +#define MAX_CONFIG_FILES 1024 #define ID_POPUP_CONFIG_FILE 10000 char *config_filenames[MAX_CONFIG_FILES]; - -static void RestartService(UpdateIconWhy why, bool only_if_active) { - if (!only_if_active || g_backend->is_started()) { - StopService(why); - StartService(why != UIW_NONE); - } -} +uint8 config_filenames_indent[MAX_CONFIG_FILES]; static char *StripConfExtension(const char *src, char *target, size_t size) { size_t len = strlen(src); @@ -335,65 +372,121 @@ static char *StripConfExtension(const char *src, char *target, size_t size) { } static const char *GetCurrentConfigTitle(char *target, size_t size) { - const char *ll = FindLastFolderSep(g_current_filename); - return StripConfExtension(ll ? ll + 1 : g_current_filename, target, size); + const char *ll = FindFilenameComponent(g_current_filename); + return StripConfExtension(ll, target, size); } -static void LoadConfigFile(const char *filename, bool save, bool force_start) { +static void SetCurrentConfigFilename(const char *filename) { str_set(&g_current_filename, filename); char namebuf[64]; - char *f = str_cat_alloc("TunSafe VPN Client - ", GetCurrentConfigTitle(namebuf, sizeof(namebuf))); + char *f = str_cat_alloc("TunSafe - ", GetCurrentConfigTitle(namebuf, sizeof(namebuf))); SetWindowText(g_ui_window, f); free(f); - RestartService(UIW_NONE, !force_start); - if (save) - RegWriteStr("ConfigFile", filename); + + InvalidateRect(hwndPaintBox, NULL, FALSE); } -static void AddToAvailableFilesPopup(HMENU menu, int max_num_items, bool is_settings) { - char buf[1024]; - int nfiles = 0; - if (!GetConfigFullName("*.*", buf, ARRAYSIZE(buf))) + +static void LoadConfigFile(const char *filename, bool save, bool force_start) { + SetCurrentConfigFilename(filename); + + if (force_start || g_backend->is_started()) + StartTunsafeBackend(UIW_START); + + if (save) + RegWriteStr(g_reg_key, "ConfigFile", filename); +} + +class ConfigMenuBuilder { +public: + ConfigMenuBuilder(); + + void Recurse(); + + int depth_; + int nfiles_; + size_t bufpos_; + WIN32_FIND_DATA wfd_; + char buf_[1024]; +}; + +ConfigMenuBuilder::ConfigMenuBuilder() + : nfiles_(0), depth_(0) { + if (!GetConfigFullName("", buf_, sizeof(buf_))) + bufpos_ = sizeof(buf_); + else + bufpos_ = strlen(buf_); +} + +void ConfigMenuBuilder::Recurse() { + if (bufpos_ >= sizeof(buf_) - 4) return; - - int selected_item = -1; - WIN32_FIND_DATA wfd; - HANDLE handle = FindFirstFile(buf, &wfd); + memcpy(buf_ + bufpos_, "*.*", 4); + HANDLE handle = FindFirstFile(buf_, &wfd_); if (handle != INVALID_HANDLE_VALUE) { do { - if (wfd.cFileName[0] == '.') + if (wfd_.cFileName[0] == '.') continue; - if (strcmp(g_current_filename, wfd.cFileName) == 0) - selected_item = nfiles; - - str_set(&config_filenames[nfiles], wfd.cFileName); - - nfiles++; - if (nfiles == MAX_CONFIG_FILES) + size_t len = strlen(wfd_.cFileName); + if (bufpos_ + len >= sizeof(buf_) - 1) + continue; + size_t old_bufpos = bufpos_; + memcpy(buf_ + bufpos_, wfd_.cFileName, len + 1); + bufpos_ = bufpos_ + len + 1; + config_filenames_indent[nfiles_] = depth_ + !!(wfd_.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY); + str_set(&config_filenames[nfiles_], buf_); + nfiles_++; + if (nfiles_ == MAX_CONFIG_FILES) break; - } while (FindNextFile(handle, &wfd)); + if (wfd_.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) { + buf_[bufpos_ - 1] = '\\'; + depth_++; + if (depth_ < 16) + Recurse(); + depth_--; + if (nfiles_ == MAX_CONFIG_FILES) + break; + } + bufpos_ = old_bufpos; + } while (FindNextFile(handle, &wfd_)); FindClose(handle); } +} - HMENU where; + +static int AddToAvailableFilesPopup(HMENU menu, int max_num_items, bool is_settings) { + ConfigMenuBuilder menu_builder; + HMENU where[16] = {0}; + + menu_builder.Recurse(); bool is_connected = g_backend->is_started(); + uint32 last_indent = 0; + where[0] = menu; - where = menu; - for (int i = 0; i < nfiles; i++) { - if (i == max_num_items) { - where = CreatePopupMenu(); - AppendMenu(menu, MF_POPUP, (UINT_PTR)where, "&More"); + for (int i = 0; i < menu_builder.nfiles_; i++) { + uint32 indent = config_filenames_indent[i]; + if (indent > last_indent) { + HMENU n = CreatePopupMenu(); + where[indent] = n; + AppendMenu(where[last_indent], MF_POPUP, (UINT_PTR)n, FindFilenameComponent(config_filenames[i])); + } else { + bool selected_item = (strcmp(g_current_filename, config_filenames[i]) == 0); + AppendMenu(where[indent], (selected_item && is_connected) ? + MF_CHECKED : 0, ID_POPUP_CONFIG_FILE + i, + StripConfExtension( + FindFilenameComponent(config_filenames[i]), menu_builder.buf_, sizeof(menu_builder.buf_))); + if (selected_item) + SetMenuDefaultItem(where[indent], ID_POPUP_CONFIG_FILE + i, MF_BYCOMMAND); } - - AppendMenu(where, (i == selected_item && is_connected) ? MF_CHECKED : 0, ID_POPUP_CONFIG_FILE + i, StripConfExtension(config_filenames[i], buf, sizeof(buf))); - - if (i == selected_item) - SetMenuDefaultItem(where, ID_POPUP_CONFIG_FILE + i, MF_BYCOMMAND); + last_indent = indent; } - if (nfiles) - AppendMenu(menu, MF_SEPARATOR, 0, 0); + + if (menu_builder.nfiles_ == 0) + AppendMenu(menu, MF_GRAYED | MF_DISABLED, 0, "(no config files found)"); + + return menu_builder.nfiles_; } static void ShowSettingsMenu(HWND wnd) { @@ -401,102 +494,64 @@ static void ShowSettingsMenu(HWND wnd) { AddToAvailableFilesPopup(menu, 10, true); - AppendMenu(menu, 0, IDSETT_OPEN_FILE, "&Import File..."); - AppendMenu(menu, 0, IDSETT_BROWSE_FILES, "&Browse in Explorer"); + //POINT pt; + //GetCursorPos(&pt); + + RECT r = GetParentRect(GetDlgItem(g_ui_window, ID_START)); + + RECT r2 = GetParentRect(hwndPaintBox); + + POINT pt = {r2.left, r.bottom}; + + ClientToScreen(g_ui_window, &pt); - AppendMenu(menu, MF_SEPARATOR, 0, 0); - AppendMenu(menu, 0, IDSETT_KEYPAIR, "Generate &Key Pair..."); - AppendMenu(menu, MF_SEPARATOR, 0, 0); - HMENU blockinternet = CreatePopupMenu(); - AppendMenu(blockinternet, 0, IDSETT_BLOCKINTERNET_OFF, "Off"); - AppendMenu(blockinternet, MF_SEPARATOR, 0, 0); - AppendMenu(blockinternet, 0, IDSETT_BLOCKINTERNET_ROUTE, "Yes, with Routing Rules"); - AppendMenu(blockinternet, 0, IDSETT_BLOCKINTERNET_FIREWALL, "Yes, with Firewall Rules"); - AppendMenu(blockinternet, 0, IDSETT_BLOCKINTERNET_BOTH, "Yes, Both Methods"); - bool is_activated = false; - int value = GetInternetBlockState(&is_activated); - CheckMenuRadioItem(blockinternet, IDSETT_BLOCKINTERNET_OFF, IDSETT_BLOCKINTERNET_BOTH, IDSETT_BLOCKINTERNET_OFF + value, MF_BYCOMMAND); - AppendMenu(menu, MF_POPUP + is_activated * MF_CHECKED, (UINT_PTR)blockinternet, "Block &All Internet Traffic"); - - if (g_allow_pre_post || GetAsyncKeyState(VK_SHIFT) < 0) { - AppendMenu(menu, g_allow_pre_post ? MF_CHECKED : 0, IDSETT_PREPOST, "&Allow Pre/Post commands"); - } - AppendMenu(menu, MF_SEPARATOR, 0, 0); - AppendMenu(menu, 0, IDSETT_WEB_PAGE, "Go to &Web Page"); - AppendMenu(menu, 0, IDSETT_OPENSOURCE, "See Open Source Licenses"); - AppendMenu(menu, 0, IDSETT_ABOUT, "&About TunSafe..."); - - POINT pt; - GetCursorPos(&pt); - g_last_popup_is_tray = false; int rv = TrackPopupMenu(menu, 0, pt.x, pt.y, 0, wnd, NULL); DestroyMenu(menu); } -void FindDesktopFolderView(REFIID riid, void **ppv) { - CComPtr spShellWindows; - spShellWindows.CoCreateInstance(CLSID_ShellWindows); - - CComVariant vtLoc(CSIDL_DESKTOP); - CComVariant vtEmpty; - long lhwnd; - CComPtr spdisp; - spShellWindows->FindWindowSW( - &vtLoc, &vtEmpty, - SWC_DESKTOP, &lhwnd, SWFO_NEEDDISPATCH, &spdisp); - - CComPtr spBrowser; - CComQIPtr(spdisp)-> - QueryService(SID_STopLevelBrowser, - IID_PPV_ARGS(&spBrowser)); - - CComPtr spView; - spBrowser->QueryActiveShellView(&spView); - - spView->QueryInterface(riid, ppv); -} - -void GetDesktopAutomationObject(REFIID riid, void **ppv) { - CComPtr spsv; - FindDesktopFolderView(IID_PPV_ARGS(&spsv)); - CComPtr spdispView; - spsv->GetItemObject(SVGIO_BACKGROUND, IID_PPV_ARGS(&spdispView)); - spdispView->QueryInterface(riid, ppv); -} - -void ShellExecuteFromExplorer( - PCSTR pszFile, - PCSTR pszParameters = nullptr, - PCSTR pszDirectory = nullptr, - PCSTR pszOperation = nullptr, - int nShowCmd = SW_SHOWNORMAL) { - CComPtr spFolderView; - GetDesktopAutomationObject(IID_PPV_ARGS(&spFolderView)); - CComPtr spdispShell; - spFolderView->get_Application(&spdispShell); - - CComQIPtr(spdispShell) - ->ShellExecute(CComBSTR(pszFile), - CComVariant(pszParameters ? pszParameters : ""), - CComVariant(pszDirectory ? pszDirectory : ""), - CComVariant(pszOperation ? pszOperation : ""), - CComVariant(nShowCmd)); +static bool HasReadWriteAccess(const char *filename) { + HANDLE fileH = CreateFile(filename, + GENERIC_READ | GENERIC_WRITE, + FILE_SHARE_READ | FILE_SHARE_WRITE, // For Exclusive access + 0, + OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, + NULL); + if (fileH != INVALID_HANDLE_VALUE) { + CloseHandle(fileH); + return true; + } + return false; } static void OpenEditor() { - char buf[MAX_PATH]; - if (GetConfigFullName(g_current_filename, buf, ARRAYSIZE(buf))) { - SHELLEXECUTEINFO shinfo = {0}; - shinfo.cbSize = sizeof(shinfo); - shinfo.fMask = SEE_MASK_CLASSNAME; - shinfo.lpFile = buf; - shinfo.lpParameters = ""; - shinfo.lpClass = ".txt"; - shinfo.nShow = SW_SHOWNORMAL; - ShellExecuteEx(&shinfo); + SHELLEXECUTEINFO shinfo = {0}; + shinfo.hwnd = g_ui_window; + shinfo.cbSize = sizeof(shinfo); + shinfo.nShow = SW_SHOWNORMAL; + + if (g_current_filename[0]) { + if (!HasReadWriteAccess(g_current_filename)) { + // Need to runas admin + char buf[1024]; + if (!ExpandEnvironmentStrings("%windir%\\system32\\notepad.exe", buf, sizeof(buf))) + return; + shinfo.lpFile = buf; + char *filename = str_cat_alloc("\"", g_current_filename, "\""); + shinfo.lpParameters = filename; + shinfo.lpVerb = "runas"; + ShellExecuteEx(&shinfo); + free(filename); + } else { + shinfo.fMask = SEE_MASK_CLASSNAME; + shinfo.lpFile = g_current_filename; + shinfo.lpParameters = ""; + shinfo.lpClass = ".txt"; + ShellExecuteEx(&shinfo); + } } } @@ -509,127 +564,62 @@ static void BrowseFiles() { } } -bool FileExists(const CHAR *fileName) { - DWORD fileAttr = GetFileAttributes(fileName); - return (0xFFFFFFFF != fileAttr); -} - -__int64 FileSize(const char* name) { - WIN32_FILE_ATTRIBUTE_DATA fad; - if (!GetFileAttributesEx(name, GetFileExInfoStandard, &fad)) - return -1; // error condition, could call GetLastError to find out more - LARGE_INTEGER size; - size.HighPart = fad.nFileSizeHigh; - size.LowPart = fad.nFileSizeLow; - return size.QuadPart; -} - -static bool is_space(uint8_t c) { - return c == ' ' || c == '\r' || c == '\n' || c == '\t'; -} - -static bool is_valid(uint8_t c) { - return c >= ' ' || c == '\r' || c == '\n' || c == '\t'; -} - -bool SanityCheckBuf(uint8 *buf, size_t n) { - for (size_t i = 0; i < n; i++) { - if (!is_space(buf[i])) { - if (buf[i] != '[' && buf[i] != '#') - return false; - for (; i < n; i++) - if (!is_valid(buf[i])) - return false; - return true; - } - } - return false; -} - -uint8* LoadFileSane(const char *name, size_t *size) { - FILE *f = fopen(name, "rb"); - uint8 *new_file = NULL, *file = NULL; - size_t j, i, n; - if (!f) return false; - fseek(f, 0, SEEK_END); - long x = ftell(f); - fseek(f, 0, SEEK_SET); - if (x < 0 || x >= 65536) goto error; - file = (uint8*)malloc(x + 1); - if (!file) goto error; - n = fread(file, 1, x + 1, f); - if (n != x || !SanityCheckBuf(file, n)) - goto error; - // Convert the file to DOS new lines - for (i = j = 0; i < n; i++) - j += (file[i] == '\n'); - new_file = (uint8*)malloc(n + 1 + j); - if (!new_file) goto error; - for (i = j = 0; i < n; i++) { - uint8 c = file[i]; - if (c == '\r') - continue; - if (c == '\n') - new_file[j++] = '\r'; - new_file[j++] = c; - } - new_file[j] = 0; - *size = j; - -error: - fclose(f); - free(file); - return new_file; -} - -bool WriteOutFile(const char *filename, uint8 *filedata, size_t filesize) { - FILE *f = fopen(filename, "wb"); - if (!f) return false; - if (fwrite(filedata, 1, filesize, f) != filesize) { - fclose(f); - return false; - } - fclose(f); - return true; -} - -void ImportFile(const char *s) { +bool ImportFile(const char *s, bool silent = false) { char buf[1024]; char mesg[1024]; size_t filesize; - const char *last = FindLastFolderSep(s); - if (!last || !GetConfigFullName(last + 1, buf, ARRAYSIZE(buf)) || _stricmp(buf, s) == 0) - return; + const char *last = FindFilenameComponent(s); + uint8 *filedata = NULL; + bool rv = false; + int filerv; - uint8 *filedata = LoadFileSane(s, &filesize); - if (!filedata) goto fail; + if (!*last || !GetConfigFullName(last, buf, ARRAYSIZE(buf)) || _stricmp(buf, s) == 0) + goto out; - if (FileExists(buf)) { - snprintf(mesg, ARRAYSIZE(mesg), "A file already exists with the name '%s' in the configuration folder. Do you want to overwrite it?", last + 1); - if (MessageBoxA(g_ui_window, mesg, "TunSafe", MB_OKCANCEL | MB_ICONEXCLAMATION) != IDOK) - goto out; - } else { - snprintf(mesg, ARRAYSIZE(mesg), "Do you want to import '%s' into TunSafe?", last + 1); - if (MessageBoxA(g_ui_window, mesg, "TunSafe", MB_OKCANCEL | MB_ICONQUESTION) != IDOK) - goto out; + filedata = LoadFileSane(s, &filesize); + if (!filedata) + goto out; + + if (!silent) { + if (FileExists(buf)) { + snprintf(mesg, ARRAYSIZE(mesg), "A file already exists with the name '%s' in the configuration folder. Do you want to overwrite it?", last); + if (MessageBoxA(g_ui_window, mesg, "TunSafe", MB_OKCANCEL | MB_ICONEXCLAMATION) != IDOK) + goto out; + } else { + snprintf(mesg, ARRAYSIZE(mesg), "Do you want to import '%s' into TunSafe?", last); + if (MessageBoxA(g_ui_window, mesg, "TunSafe", MB_OKCANCEL | MB_ICONQUESTION) != IDOK) + goto out; + } } - if (!WriteOutFile(buf, filedata, filesize)) { + filerv = WriteOutFile(buf, filedata, filesize); + + // elevate? + if (filerv == kWriteOutFile_AccessError && g_is_limited_uac_account) { + char *args = str_cat_alloc("--import \"", s, "\""); + rv = RunProcessAsAdminWithArgs(args, true); + free(args); + return rv; + } + + rv = (filerv == kWriteOutFile_Ok); + if (!rv) DeleteFileA(buf); -fail: - MessageBoxA(g_ui_window, "There was a problem importing the file.", "TunSafe", MB_ICONEXCLAMATION); - } else { - LoadConfigFile(last + 1, true, false); - } out: free(filedata); + + if (!silent) { + if (rv) + LoadConfigFile(buf, true, false); + else + MessageBoxA(g_ui_window, "There was a problem importing the file.", "TunSafe", MB_ICONEXCLAMATION); + } + return !rv; } void ShowUI(HWND hWnd) { - g_ui_visible = true; - UpdateStats(); - ShowWindow(hWnd, SW_SHOW); + SetUiVisibility(true); BringWindowToTop(hWnd); SetForegroundWindow(hWnd); } @@ -717,77 +707,280 @@ static INT_PTR WINAPI KeyPairDlgProc(HWND hWnd, UINT message, WPARAM wParam, return FALSE; } -bool wm_dropfiles_recursive; -uint64 last_auto_service_restart; +static void SetStartupFlags(int new_flags) { + // Determine whether to autorun or not. + bool autorun = (new_flags & kStartupFlag_MinimizeToTrayWhenWindowsStarts) || + !(new_flags & kStartupFlag_BackgroundService) && (new_flags & kStartupFlag_ConnectWhenWindowsStarts); + + // Update the autorun key. + HKEY hkey; + LSTATUS result; + result = RegOpenKeyEx(HKEY_CURRENT_USER, "Software\\Microsoft\\Windows\\CurrentVersion\\Run", 0, KEY_WRITE, &hkey); + if (result == 0) { + if (autorun) { + wchar_t buf[512 + 32]; + buf[0] = '"'; + DWORD len = GetModuleFileNameW(NULL, buf + 1, 512); + if (len < 512) { + memcpy(buf + len + 1, L"\" --autostart", sizeof(wchar_t) * 14); + result = RegSetValueExW(hkey, L"TunSafe", NULL, REG_SZ, (BYTE*)buf, (DWORD)(len + 15) * sizeof(wchar_t)); + } + } else { + RegDeleteValueW(hkey, L"TunSafe"); + } + RegCloseKey(hkey); + } + RegWriteInt(g_reg_key, "StartupFlags", new_flags); + + bool was_started = g_backend && g_backend->is_started(); + bool recreate_backend = false; + + if (!!(new_flags & (kStartupFlag_BackgroundService | kStartupFlag_ForegroundService))) { + // Want to run as a service - make sure service is installed and running. + if (!IsTunsafeServiceRunning()) { + g_backend->Stop(); + RINFO("Starting TunSafe service..."); + InstallTunSafeWindowsService(); + recreate_backend = true; + } +} else { + if (IsTunSafeServiceInstalled()) { + g_backend->Stop(); + g_backend->Teardown(); + + RINFO("Removing TunSafe service..."); + // Don't want to run as a service - Make sure we delete the service. + if (g_is_limited_uac_account) { + // Need to stop this early so service process is able to open. + CloseHandle(g_runonce_mutex); + if (!RunProcessAsAdminWithArgs("--delete-service-and-start", false)) { + RINFO("Unable to stop and remove service"); + uint32 m = kStartupFlag_BackgroundService | kStartupFlag_ForegroundService; + new_flags = (g_startup_flags & m) | (new_flags & ~m); + } else { + PostQuitMessage(0); + return; + } + } else { + if (!UninstallTunSafeWindowsService()) { + RINFO("Unable to stop and remove service"); + uint32 m = kStartupFlag_BackgroundService | kStartupFlag_ForegroundService; + new_flags = (g_startup_flags & m) | (new_flags & ~m); + } + } + recreate_backend = true; + } + } + if (recreate_backend) { + CreateLocalOrRemoteBackend(!!(new_flags & (kStartupFlag_BackgroundService | kStartupFlag_ForegroundService))); + if (was_started) + StartTunsafeBackend(UIW_START); + } + g_startup_flags = new_flags; + g_backend->SetServiceStartupFlags(g_startup_flags); +} + +enum { + kTab_Logs = 0, + kTab_Charts = 1, + kTab_Advanced = 2, +}; + +static void UpdateGraphReq() { + if (g_backend && (g_current_tab != 1 || !g_ui_visible)) + g_backend->GetGraph(0); +} + +static void UpdateTabSelection() { + int tab = TabCtrl_GetCurSel(hwndTab); + HWND wnd = g_ui_window; + g_current_tab = tab; + ShowWindow(hwndEdit, (tab == kTab_Logs) ? SW_SHOW : SW_HIDE); + ShowWindow(hwndGraphBox, (tab == kTab_Charts) ? SW_SHOW : SW_HIDE); + ShowWindow(hwndAdvancedBox, (tab == kTab_Advanced) ? SW_SHOW : SW_HIDE); + UpdateGraphReq(); +} + +struct WindowSizingItem { + uint16 id; + uint16 edges; +}; + +enum { + WSI_LEFT = 1, + WSI_RIGHT = 2, + WSI_TOP = 4, + WSI_BOTTOM = 8, +}; + +static const WindowSizingItem kWindowSizing[] = { + {ID_START,WSI_LEFT | WSI_RIGHT}, + {ID_STOP,WSI_LEFT | WSI_RIGHT}, + {ID_EDITCONF,WSI_LEFT | WSI_RIGHT}, + {IDC_PAINTBOX,WSI_RIGHT}, + {IDC_TAB, WSI_RIGHT | WSI_BOTTOM}, +}; + +static void HandleWindowSizing() { + RECT wr; + + GetClientRect(g_ui_window, &wr); + + static int g_orig_w, g_orig_h; + static RECT g_orig_rects[ARRAYSIZE(kWindowSizing)]; + + if (g_orig_w == 0) { + g_orig_w = wr.right; + g_orig_h = wr.bottom; + for (size_t i = 0; i < ARRAYSIZE(kWindowSizing); i++) { + const WindowSizingItem *it = &kWindowSizing[i]; + g_orig_rects[i] = GetParentRect(GetDlgItem(g_ui_window, it->id)); + } + } + + int dx = wr.right - g_orig_w; + int dy = wr.bottom - g_orig_h; + + if (dx|dy) { + HDWP dwp = BeginDeferWindowPos(10), dwp_next; + for (size_t i = 0; i < ARRAYSIZE(kWindowSizing); i++) { + const WindowSizingItem *it = &kWindowSizing[i]; + HWND wnd = GetDlgItem(g_ui_window, it->id); + RECT r = g_orig_rects[i]; + if (it->edges & WSI_LEFT) r.left += dx; + if (it->edges & WSI_RIGHT) r.right += dx; + if (it->edges & WSI_TOP) r.top += dy; + if (it->edges & WSI_BOTTOM) r.bottom += dy; + if (r.right < r.left) r.right = r.left; + if (r.bottom < r.top) r.bottom = r.top; + dwp_next = DeferWindowPos(dwp, wnd, NULL, r.left, r.top, r.right - r.left, r.bottom - r.top, SWP_NOZORDER | SWP_NOREPOSITION | SWP_NOACTIVATE); + dwp = dwp_next ? dwp_next : dwp; + } + EndDeferWindowPos(dwp); + } + + RECT rect = GetParentRect(hwndTab); + TabCtrl_AdjustRect(hwndTab, false, &rect); + MoveWindow(hwndEdit, rect.left, rect.top, rect.right - rect.left, rect.bottom - rect.top, TRUE); + MoveWindow(hwndGraphBox, rect.left, rect.top, rect.right - rect.left, rect.bottom - rect.top, TRUE); + MoveWindow(hwndAdvancedBox, rect.left, rect.top, rect.right - rect.left, rect.bottom - rect.top, TRUE); + + int parts[3] = { + (int)(wr.right * 0.2f), + (int)(wr.right * 0.6f), + (int)-1, + }; + + SendMessage(hwndStatus, SB_SETPARTS, 3, (LPARAM)parts); + SendMessage(hwndStatus, WM_SIZE, 0, 0); + InvalidateRect(hwndStatus, NULL, TRUE); +} + +static void HandleClickedItem(HWND hWnd, int wParam) { + if (wParam >= ID_POPUP_CONFIG_FILE && wParam < ID_POPUP_CONFIG_FILE + MAX_CONFIG_FILES) { + const char *new_conf = config_filenames[wParam - ID_POPUP_CONFIG_FILE]; + if (!new_conf) + return; + + if (strcmp(new_conf, g_current_filename) == 0 && g_backend->is_started()) { + StopTunsafeBackend(UIW_NONE); + } else { + LoadConfigFile(new_conf, true, GetAsyncKeyState(VK_SHIFT) >= 0); + } + + return; + } + switch (wParam) { + case ID_START: StartTunsafeBackend(UIW_START); break; + case ID_STOP: StopTunsafeBackend(UIW_NONE); break; + case ID_EXIT: PostQuitMessage(0); break; + case ID_MORE_BUTTON: ShowSettingsMenu(hWnd); break; + case IDSETT_WEB_PAGE: ShellExecute(g_ui_window, NULL, "https://tunsafe.com/", NULL, NULL, 0); break; + case IDSETT_OPENSOURCE: ShellExecute(g_ui_window, NULL, "https://tunsafe.com/open-source", NULL, NULL, 0); break; + case ID_EDITCONF: OpenEditor(); break; + case IDSETT_BROWSE_FILES:BrowseFiles(); break; + case IDSETT_OPEN_FILE: BrowseFile(hWnd); break; + case IDSETT_ABOUT: + MessageBoxA(g_ui_window, TUNSAFE_VERSION_STRING "\r\n\r\nCopyright © 2018, Ludvig Strigeus\r\n\r\nThanks for choosing TunSafe!\r\n\r\nThis version was built on " __DATE__ " " __TIME__, "About TunSafe", MB_ICONINFORMATION); + break; + case IDSETT_KEYPAIR: + DialogBox(g_hinstance, MAKEINTRESOURCE(IDD_DIALOG2), hWnd, &KeyPairDlgProc); + break; + case IDSETT_BLOCKINTERNET_OFF: + case IDSETT_BLOCKINTERNET_ROUTE: + case IDSETT_BLOCKINTERNET_FIREWALL: + case IDSETT_BLOCKINTERNET_BOTH: + { + InternetBlockState old_state = g_backend->GetInternetBlockState(NULL); + InternetBlockState new_state = (InternetBlockState)(wParam - IDSETT_BLOCKINTERNET_OFF); + + if (old_state == kBlockInternet_Off && new_state != kBlockInternet_Off) { + if (MessageBoxA(g_ui_window, "Warning! All Internet traffic will be blocked until you restart your computer. Only traffic through TunSafe will be allowed.\r\n\r\nThe blocking is activated the next time you connect to a VPN server.\r\n\r\nDo you want to continue?", "TunSafe", MB_ICONWARNING | MB_OKCANCEL) == IDCANCEL) + return; + } + + g_backend->SetInternetBlockState(new_state); + + if ((~old_state & new_state) && g_backend->is_started()) + StartTunsafeBackend(UIW_START); + return; + } + case IDSETT_SERVICE_OFF: + case IDSETT_SERVICE_FOREGROUND: + case IDSETT_SERVICE_BACKGROUND: + SetStartupFlags((int)((g_startup_flags & ~3) + wParam - IDSETT_SERVICE_OFF)); + break; + case IDSETT_SERVICE_CONNECT_AUTO: + SetStartupFlags(g_startup_flags ^ kStartupFlag_ConnectWhenWindowsStarts); + break; + case IDSETT_SERVICE_MINIMIZE_AUTO: + SetStartupFlags(g_startup_flags ^ kStartupFlag_MinimizeToTrayWhenWindowsStarts); + break; + + case IDSETT_PREPOST: + { + if (!g_hklm_reg_key) { + if (!RunProcessAsAdminWithArgs(g_allow_pre_post ? "--set-allow-pre-post 0" : "--set-allow-pre-post 1", true)) + MessageBox(g_ui_window, "You need to run TunSafe as an Administrator to be able to change this setting.", "TunSafe", MB_ICONWARNING); + g_allow_pre_post = RegReadInt(g_hklm_readonly_reg_key, "AllowPrePost", 0) != 0; + return; + } + g_allow_pre_post = !g_allow_pre_post; + RegWriteInt(g_hklm_reg_key, "AllowPrePost", g_allow_pre_post); + return; + } + } +} + static INT_PTR WINAPI DlgProc(HWND hWnd, UINT message, WPARAM wParam, LPARAM lParam) { - switch(message) { + + switch (message) { case WM_INITDIALOG: + SetMenu(hWnd, LoadMenu(g_hinstance, MAKEINTRESOURCE(IDR_MENU1))); return TRUE; case WM_CLOSE: - g_ui_visible = false; - ShowWindow(hWnd, SW_HIDE); + SetUiVisibility(false); return TRUE; - case WM_COMMAND: - if (wParam >= ID_POPUP_CONFIG_FILE && wParam < ID_POPUP_CONFIG_FILE + MAX_CONFIG_FILES) { - const char *new_conf = config_filenames[wParam - ID_POPUP_CONFIG_FILE]; - if (!new_conf) + case WM_NOTIFY: { + UINT idFrom = (UINT)((NMHDR*)lParam)->idFrom; + switch (((NMHDR*)lParam)->code) { + case TCN_SELCHANGE: + switch (idFrom) { + case IDC_TAB: + UpdateTabSelection(); return TRUE; - - if (g_last_popup_is_tray && strcmp(new_conf, g_current_filename) == 0 && g_backend->is_started()) { - StopService(UIW_NONE); - } else { - LoadConfigFile(new_conf, true, g_last_popup_is_tray); } - - return TRUE; - } - switch(wParam) { - case ID_START: - StopService(UIW_NONE); - StartService(); break; - case ID_STOP: StopService(UIW_NONE); break; - case ID_EXIT: PostQuitMessage(0); break; - case ID_RESET: g_backend->ResetStats(); break; - case ID_MORE_BUTTON: ShowSettingsMenu(hWnd); break; - case IDSETT_WEB_PAGE: ShellExecute(NULL, NULL, "https://tunsafe.com/", NULL, NULL, 0); break; - case IDSETT_OPENSOURCE: ShellExecute(NULL, NULL, "https://tunsafe.com/open-source", NULL, NULL, 0); break; - case ID_EDITCONF: OpenEditor(); break; - case IDSETT_BROWSE_FILES:BrowseFiles(); break; - case IDSETT_OPEN_FILE: BrowseFile(hWnd); break; - case IDSETT_ABOUT: - MessageBoxA(g_ui_window, TUNSAFE_VERSION_STRING "\r\n\r\nCopyright © 2018, Ludvig Strigeus\r\n\r\nThanks for choosing TunSafe!\r\n\r\nThis version was built on " __DATE__ " " __TIME__, "About TunSafe", MB_ICONINFORMATION); - break; - case IDSETT_KEYPAIR: - DialogBox(g_hinstance, MAKEINTRESOURCE(IDD_DIALOG2), hWnd, &KeyPairDlgProc); - break; - case IDSETT_BLOCKINTERNET_OFF: - case IDSETT_BLOCKINTERNET_ROUTE: - case IDSETT_BLOCKINTERNET_FIREWALL: - case IDSETT_BLOCKINTERNET_BOTH: { - InternetBlockState old_state = GetInternetBlockState(NULL); - InternetBlockState new_state = (InternetBlockState)(wParam - IDSETT_BLOCKINTERNET_OFF); - - if (old_state == kBlockInternet_Off && new_state != kBlockInternet_Off) { - if (MessageBoxA(g_ui_window, "Warning! All Internet traffic will be blocked until you restart your computer. Only traffic through TunSafe will be allowed.\r\n\r\nThe blocking is activated the next time you connect to a VPN server.\r\n\r\nDo you want to continue?", "TunSafe", MB_ICONWARNING | MB_OKCANCEL) == IDCANCEL) - return TRUE; - } - - SetInternetBlockState(new_state); - - if ((~old_state & new_state) && g_backend->is_started()) { - StopService(UIW_NONE); - StartService(); - } - return TRUE; - } - case IDSETT_PREPOST: { - g_allow_pre_post = !g_allow_pre_post; - RegWriteInt("AllowPrePost", g_allow_pre_post); - return TRUE; } + break; + } + case WM_COMMAND: + switch (HIWORD(wParam)) { + case 0: + HandleClickedItem(hWnd, (int)wParam); + break; } break; case WM_DROPFILES: @@ -800,7 +993,8 @@ static INT_PTR WINAPI DlgProc(HWND hWnd, UINT message, WPARAM wParam, case WM_USER + 1: if (lParam == WM_RBUTTONUP) { HMENU menu = CreatePopupMenu(); - AddToAvailableFilesPopup(menu, 10, false); + if (AddToAvailableFilesPopup(menu, 10, false)) + AppendMenu(menu, MF_SEPARATOR, 0, 0); bool active = g_backend->is_started(); AppendMenu(menu, 0, ID_START, active ? "Re&connect" : "&Connect"); @@ -812,150 +1006,56 @@ static INT_PTR WINAPI DlgProc(HWND hWnd, UINT message, WPARAM wParam, SetForegroundWindow(hWnd); - g_last_popup_is_tray = true; - - int rv = TrackPopupMenu(menu, 0, pt.x, pt.y, 0, hWnd, NULL); + int rv = TrackPopupMenu(menu, 0, pt.x, pt.y, 0, hWnd, NULL); DestroyMenu(menu); } else if (lParam == WM_LBUTTONDBLCLK) { if (IsWindowVisible(hWnd)) { - g_ui_visible = false; - ShowWindow(hWnd, SW_HIDE); + SetUiVisibility(false); } else { ShowUI(hWnd); } } return TRUE; case WM_USER + 2: - if (g_ui_ip != 0 && g_minimize_on_connect) { - g_minimize_on_connect = false; - g_ui_visible = false; - ShowWindow(hWnd, SW_HIDE); - } - UpdateIcon(UIW_NONE); - return TRUE; - case WM_USER + 3: { - CHARRANGE cr; - cr.cpMin = -1; - cr.cpMax = -1; - // hwnd = rich edit hwnd - SendDlgItemMessage(hWnd, IDC_RICHEDIT21, EM_EXSETSEL, 0, (LPARAM)&cr); - SendDlgItemMessage(hWnd, IDC_RICHEDIT21, EM_REPLACESEL, 0, (LPARAM)lParam); - free( (void*) lParam); + g_backend_delegate->DoWork(); return true; + + case WM_INITMENU: { + HMENU menu = GetMenu(g_ui_window); + + CheckMenuItem(menu, IDSETT_SERVICE_CONNECT_AUTO, MF_CHECKED * !!(g_startup_flags & kStartupFlag_ConnectWhenWindowsStarts)); + CheckMenuItem(menu, IDSETT_SERVICE_MINIMIZE_AUTO, MF_CHECKED * !!(g_startup_flags & kStartupFlag_MinimizeToTrayWhenWindowsStarts)); + CheckMenuItem(menu, IDSETT_PREPOST, g_allow_pre_post ? MF_CHECKED : 0); + + bool is_activated = false; + int value = g_backend->GetInternetBlockState(&is_activated); + CheckMenuRadioItem(menu, IDSETT_BLOCKINTERNET_OFF, IDSETT_BLOCKINTERNET_BOTH, IDSETT_BLOCKINTERNET_OFF + value, MF_BYCOMMAND); + CheckMenuRadioItem(menu, IDSETT_SERVICE_OFF, IDSETT_SERVICE_BACKGROUND, IDSETT_SERVICE_OFF + (g_startup_flags & 3), MF_BYCOMMAND); + + break; } - case WM_USER + 6: - SetDlgItemText(hWnd, IDC_RICHEDIT21, ""); - return true; - case WM_USER + 5: - UpdatePublicKey((char*)lParam); - return true; - case WM_USER + 4: { - UpdateStats(); - return true; - } - case WM_USER + 10: + + case WM_SIZE: + if (wParam == SIZE_MAXIMIZED || wParam == SIZE_RESTORED) { + if (g_ui_window) + HandleWindowSizing(); + } break; - case WM_USER + 11: { - uint64 now = GetTickCount64(); - if (now < last_auto_service_restart + 5000) { - RERROR("Too many automatic restarts..."); - StopService(UIW_STOPPED_WORKING_FAIL); - } else { - last_auto_service_restart = now; - RestartService(UIW_STOPPED_WORKING_RETRY, true); + case WM_EXITMENULOOP: + g_timestamp_of_exit_menuloop = GetTickCount(); + break; + + default: + if (message == g_message_taskbar_created) { + g_has_icon = false; + UpdateIcon(UIW_NONE); } break; } - } return FALSE; } -struct PostMsg { - int msg; - WPARAM wparam; - LPARAM lparam; - PostMsg(int a, WPARAM b, LPARAM c) : msg(a), wparam(b), lparam(c) {} -}; - -static HANDLE msg_event; -static CRITICAL_SECTION msg_section; -static std::vector msgvect; - -static DWORD WINAPI MessageThread(void *x) { - std::vector proc; - for(;;) { - WaitForSingleObject(msg_event, INFINITE); - proc.clear(); - EnterCriticalSection(&msg_section); - std::swap(proc, msgvect); - LeaveCriticalSection(&msg_section); - for(size_t i = 0; i != proc.size(); i++) - PostMessage(g_ui_window, proc[i].msg, proc[i].wparam, proc[i].lparam); - } -} - -static void MyPostMessage(int msg, WPARAM wparam, LPARAM lparam) { - size_t count; - EnterCriticalSection(&msg_section); - count = msgvect.size(); - msgvect.emplace_back(msg, wparam, lparam); - LeaveCriticalSection(&msg_section); - if (count == 0) SetEvent(msg_event); -} - -static void InitMyPostMessage() { - msg_event = CreateEvent(NULL, FALSE, FALSE, NULL); - InitializeCriticalSection(&msg_section); - DWORD thread_id; - CloseHandle(CreateThread(NULL, 0, &MessageThread, NULL, 0, &thread_id)); -} - - -void OsGetRandomBytes(uint8 *data, size_t data_size) { -#if defined(OS_WIN) - static BOOLEAN(APIENTRY *pfn)(void*, ULONG); - static bool resolved; - if (!resolved) { - pfn = (BOOLEAN(APIENTRY *)(void*, ULONG))GetProcAddress(LoadLibrary("ADVAPI32.DLL"), "SystemFunction036"); - resolved = true; - } - if (pfn && pfn(data, (ULONG)data_size)) - return; - int r = 0; -#else - int fd = open("/dev/urandom", O_RDONLY); - int r = read(fd, data, data_size); - if (r < 0) r = 0; - close(fd); -#endif - for (; r < data_size; r++) - data[r] = rand() >> 6; -} - -void OsInterruptibleSleep(int millis) { - SleepEx(millis, TRUE); -} - - -uint64 OsGetMilliseconds() { - return GetTickCount64(); -} - -void OsGetTimestampTAI64N(uint8 dst[12]) { - SYSTEMTIME systime; - uint64 file_time_uint64 = 0; - GetSystemTime(&systime); - SystemTimeToFileTime(&systime, (FILETIME*)&file_time_uint64); - uint64 time_since_epoch_100ns = (file_time_uint64 - 116444736000000000); - uint64 secs_since_epoch = time_since_epoch_100ns / 10000000 + 0x400000000000000a; - uint32 nanos = (uint32)(time_since_epoch_100ns % 10000000) * 100; - WriteBE64(dst, secs_since_epoch); - WriteBE32(dst + 8, nanos); -} - - - void PushLine(const char *s) { size_t l = strlen(s); char buf[64]; @@ -973,7 +1073,8 @@ void PushLine(const char *s) { x[l + tl] = '\r'; x[l + tl + 1] = '\n'; x[l + tl + 2] = '\0'; - MyPostMessage(WM_USER + 3, 0, (LPARAM)x); + g_backend_delegate->OnLogLine((const char**)&x); + free(x); } void EnsureConfigDirCreated() { @@ -986,7 +1087,6 @@ void EnableControl(int wnd, bool b) { EnableWindow(GetDlgItem(g_ui_window, wnd), b); } - LRESULT CALLBACK NotifyWndProc(HWND hwnd, UINT uMsg, WPARAM wParam, LPARAM lParam) { switch (uMsg) { case WM_USER + 10: @@ -1012,27 +1112,649 @@ void CreateNotificationWindow() { CreateWindow("TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90", "TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90", 0, 0, 0, 0, 0, 0, 0, g_hinstance, NULL); } - -void CallbackUpdateUI() { - if (g_ui_visible) - MyPostMessage(WM_USER + 4, NULL, NULL); +HFONT CreateBoldUiFont() { + LOGFONT lf; + HFONT ffont = (HFONT)SendMessage(g_ui_window, WM_GETFONT, 0, 0); + GetObject(ffont, sizeof(lf), &lf); + lf.lfWeight = FW_BOLD; + HFONT font = CreateFontIndirect(&lf); + return font; } -void CallbackTriggerReconnect() { - PostMessage(g_ui_window, WM_USER + 11, 0, 0); +void FillRectColor(HDC dc, const RECT &r, COLORREF color) { + COLORREF old = ::SetBkColor(dc, color); + ExtTextOut(dc, 0, 0, ETO_OPAQUE, &r, NULL, 0, NULL); + ::SetBkColor(dc, old); } -void CallbackSetPublicKey(const uint8 public_key[32]) { - char *str = (char*)base64_encode(public_key, 32, NULL); - PostMessage(g_ui_window, WM_USER + 5, NULL, (LPARAM)str); +void DrawRectOutline(HDC dc, const RECT &r) { + POINT points[5] = { + {r.left, r.top}, + {r.right, r.top}, + {r.right, r.bottom}, + {r.left, r.bottom}, + {r.left, r.top} + }; + Polyline(dc, points, 5); } -int WINAPI WinMain (HINSTANCE hInstance, HINSTANCE hPrevInstance, LPSTR lpCmdLine, int nShowCmd) { +static HFONT CreateFontHelper(int size, byte flags, const char *face, int angle = 0) { + return CreateFontA(-RescaleDpi(size), 0, angle, angle, flags & 1 ? FW_BOLD : 0, FALSE, flags & 2 ? 1 : 0, FALSE, DEFAULT_CHARSET, OUT_DEFAULT_PRECIS, + CLIP_DEFAULT_PRECIS, DEFAULT_QUALITY, DEFAULT_PITCH, face); +} + +static const char *StatusCodeToString(TunsafeBackend::StatusCode code) { + switch (code) { + case TunsafeBackend::kErrorInitialize: return "Configuration Error"; + case TunsafeBackend::kErrorTunPermanent: return "TUN Adapter Error"; + case TunsafeBackend::kErrorServiceLost: return "Service Lost"; + case TunsafeBackend::kStatusStopped: return "Disconnected"; + case TunsafeBackend::kStatusInitializing: return "Initializing"; + case TunsafeBackend::kStatusConnecting: return "Connecting..."; + case TunsafeBackend::kStatusReconnecting: return "Reconnecting..."; + case TunsafeBackend::kStatusConnected: return "Connected"; + case TunsafeBackend::kStatusTunRetrying: return "TUN Adapter Error, retrying..."; + default: + return "Unknown"; + } +} + +static void DrawInPaintBox(HDC hdc, int w, int h) { + RECT rect = {0, 0, w, h}; + FillRect(hdc, &rect, (HBRUSH)(COLOR_3DFACE + 1)); + + HFONT font = CreateBoldUiFont(); + + char namebuf[128]; + GetCurrentConfigTitle(namebuf, sizeof(namebuf)); + + RECT btrect = GetParentRect(GetDlgItem(g_ui_window, ID_START)); + + HPEN pen = CreatePen(PS_SOLID, 0, GetSysColor(COLOR_3DSHADOW)); + HBRUSH brush = GetSysColorBrush(COLOR_WINDOW); + + SelectObject(hdc, pen); + SelectObject(hdc, brush); + + comborect = MakeRect(0, btrect.top + 1, w, btrect.bottom - 1); + Rectangle(hdc, 0, btrect.top + 1, w, btrect.bottom - 1); + + if (arrowbitmap == NULL) + arrowbitmap = LoadBitmap(g_hinstance, MAKEINTRESOURCE(IDB_DOWNARROW)); + + int bw = RescaleDpi(6); + + HDC memdc = CreateCompatibleDC(hdc); + SelectObject(memdc, arrowbitmap); + StretchBlt(hdc, w - 1 - bw - 5, btrect.top + 1 + ((btrect.bottom - btrect.top - bw) >> 1), + bw, bw, memdc, 0, 0, 6, 6, SRCCOPY); + + int th = RescaleDpi(20); + + SelectObject(hdc, font); + SetBkColor(hdc, GetSysColor(COLOR_WINDOW)); + TextOut(hdc, RescaleDpi(4), btrect.top + RescaleDpi(4), namebuf, (int)strlen(namebuf)); + + int y = btrect.bottom + RescaleDpi(4); + + DeleteObject(pen); + + SelectObject(hdc, (HFONT)SendMessage(g_ui_window, WM_GETFONT, 0, 0)); + SetBkColor(hdc, GetSysColor(COLOR_3DFACE)); + + TunsafeBackend::StatusCode status = g_backend->status(); + my_strlcpy(namebuf, sizeof(namebuf) - 32, StatusCodeToString(status)); + if (status == TunsafeBackend::kStatusConnected || status == TunsafeBackend::kStatusReconnecting) { + uint64 when = g_processor_stats.first_complete_handshake_timestamp; + uint32 seconds = (when != 0) ? (uint32)((OsGetMilliseconds() - when + 500) / 1000) : 0; + snprintf(strchr(namebuf, 0), 32, ", %.2d:%.2d:%.2d", seconds / 3600, (seconds / 60) % 60, seconds % 60); + } + + int img = (status == TunsafeBackend::kStatusConnected) ? 0 : + g_backend->is_started() && !TunsafeBackend::IsPermanentError(status) ? 1 : 2; + + static const COLORREF kDotColors[3] = { + 0x51a600, + 0x00c0c0, + 0x0000c0, + }; + SetBkMode(hdc, TRANSPARENT); + COLORREF oldcolor = SetTextColor(hdc, kDotColors[img]); + HFONT oldfont = (HFONT)SelectObject(hdc, CreateFontHelper(18, 0, "Tahoma")); + wchar_t bullet = 0x25CF; + TextOutW(hdc, RescaleDpi(2), y - RescaleDpi(7), &bullet, 1); + DeleteObject(SelectObject(hdc, oldfont)); + SetTextColor(hdc, oldcolor); + + TextOut(hdc, RescaleDpi(2 + 14), y, namebuf, (int)strlen(namebuf)); + + y += RescaleDpi(18); + + uint32 ip = g_backend->GetIP(); + if (ip) { + print_ip(namebuf, ip); + TextOut(hdc, 2, y, namebuf, (int)strlen(namebuf)); + } + DeleteObject(font); + DeleteDC(memdc); +} + +typedef void DrawInPaintBoxFunc(HDC dc, int w, int h); +static void HandleWmPaintPaintbox(HWND hwnd, DrawInPaintBoxFunc *func) { + PAINTSTRUCT ps; + BeginPaint(hwnd, &ps); + + RECT r; + GetClientRect(hwnd, &r); + + HBITMAP bmp = CreateCompatibleBitmap(ps.hdc, r.right, r.bottom); + HDC dc = CreateCompatibleDC(ps.hdc); + SelectObject(dc, bmp); + + func(dc, r.right, r.bottom); + + BitBlt(ps.hdc, 0, 0, r.right, r.bottom, dc, 0, 0, SRCCOPY); + DeleteDC(dc); + DeleteObject(bmp); + EndPaint(hwnd, &ps); +} + +static LRESULT CALLBACK PaintBoxWndProc(HWND hwnd, UINT uMsg, WPARAM wParam, LPARAM lParam) { + switch (uMsg) { + case WM_PAINT: { + HandleWmPaintPaintbox(hwnd, &DrawInPaintBox); + return TRUE; + } + case WM_LBUTTONDOWN: { + POINT pt = {GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam)}; + if (PtInRect(&comborect, pt)) { + // Avoid showing the menu again if clicking to close. + if (GetTickCount() - g_timestamp_of_exit_menuloop >= 50u) + ShowSettingsMenu(g_ui_window); + } + return TRUE; + } + } + return DefWindowProc(hwnd, uMsg, wParam, lParam); +} + +static void DrawGraph(HDC dc, const RECT *rr, StatsCollector::TimeSeries **sources, const COLORREF *colors, int num_source, const char *xcaption, const char *ycaption) { + RECT r = *rr; + FillRectColor(dc, r, 0xffffff); + + RECT margins = { 30, 10, -10, -15 }; + margins = RescaleDpiRect(margins); + + r.left += margins.left; + r.top += margins.top; + r.right += margins.right; + r.bottom += margins.bottom; + + HPEN borderpen = CreatePen(PS_SOLID, 1, 0x808080); + SelectObject(dc, borderpen); + DrawRectOutline(dc, r); + + static const uint8 bits[4] = {0x70, 0, 0, 0}; + HBITMAP bmp = CreateBitmap(4, 1, 1, 1, &bits); + HBRUSH brush = CreatePatternBrush(bmp); + DeleteObject(bmp); + + // Draw horizontal dotted lines + { + SetTextColor(dc, 0x808080); + SetBkColor(dc, 0xffffff); + int inc = (r.bottom - r.top) >> 2; + RECT r2 = {r.left + 1, r.top + inc * 1, r.right - 1, r.top + inc * 1 + 1}; + FillRect(dc, &r2, brush); + r2.top += inc; r2.bottom += inc; + FillRect(dc, &r2, brush); + r2.top += inc; r2.bottom += inc; + FillRect(dc, &r2, brush); + } + DeleteObject(brush); + + static const uint8 bits_vertical[16] = { + 0xff, 0x0, 0xff, 0, + 0xff, 0x0, 0x0, 0, + 0xff, 0x0, 0x0, 0, + 0x0, 0x0, 0x0, 0}; + bmp = CreateBitmap(1, 4, 1, 1, &bits_vertical); + brush = CreatePatternBrush(bmp); + DeleteObject(bmp); + + { + // Draw vertical dotted lines + for (int i = 1; i < 12; i++) { + int x = (r.right - r.left) * i / 12; + RECT r2 = {r.left + x, r.top + 1, r.left + x + 1, r.bottom - 1}; + FillRect(dc, &r2, brush); + } + } + + { + // Draw legend text + HFONT font = CreateFontHelper(10, 0, "Tahoma"); + SelectObject(dc, font); + SetTextColor(dc, 0x202020); + SetBkMode(dc, TRANSPARENT); + RECT r2 = {r.left + 1, r.bottom, r.right - 1, r.bottom + RescaleDpi(15)}; + DrawText(dc, xcaption, (int)strlen(xcaption), &r2, DT_CENTER | DT_SINGLELINE | DT_VCENTER); + DeleteObject(font); + } + DeleteObject(brush); + DeleteObject(borderpen); + + // Determine the scaling factor + float mx = 1; + for (size_t j = 0; j != num_source; j++) { + const StatsCollector::TimeSeries *src = sources[j]; + for (size_t i = 0; i != src->size; i++) + mx = max(mx, src->data[i]); + } + int topval = (int)(mx + 0.5f); + // round it appropriately + if (topval >= 500) + topval = (topval + 99) / 100 * 100; + else if (topval >= 200) + topval = (topval + 49) / 50 * 50; + else if (topval >= 50) + topval = (topval + 9) / 10 * 10; + else if (topval >= 20) + topval = (topval + 4) / 5 * 5; + if (topval > mx) + mx = (float)topval; + + { + RECT r2 = {r.left - RescaleDpi(30), r.top - RescaleDpi(2), r.left - RescaleDpi(2), r.bottom}; + char buf[30]; + sprintf(buf, "%d", topval); + DrawText(dc, buf, (int)strlen(buf), &r2, DT_RIGHT | DT_SINGLELINE); + r2.top = r.bottom - RescaleDpi(12); + DrawText(dc, "0", 1, &r2, DT_RIGHT | DT_SINGLELINE); + } + + float mx_f = (1.0f / mx) * (r.bottom - r.top); + + for (size_t k = 0; k != num_source; k++) { + HPEN borderpen = CreatePen(PS_SOLID, 2, colors[k]); + SelectObject(dc, borderpen); + const StatsCollector::TimeSeries *src = sources[k]; + POINT *points = new POINT[src->size]; + for (size_t i = 0, j = src->shift; i != src->size; i++) { + points[i].x = (int)(r.left + (r.right - r.left) * i / (src->size - 1)); + points[i].y = r.bottom - (int)((float)src->data[j] * mx_f); + if (++j == src->size) j = 0; + } + Polyline(dc, points, src->size); + delete points; + DeleteObject(borderpen); + } + + if (ycaption != NULL) { + HFONT font = CreateFontHelper(10, 0, "Tahoma", 900); + SelectObject(dc, font); + TextOut(dc, r.left - RescaleDpi(18), ((r.top + r.bottom) >> 1) + RescaleDpi(12), ycaption, (int)strlen(ycaption)); + DeleteObject(font); + } +} + +static const char * const kGraphStepNames[] = { + "1 second step", + "5 second step", + "30 second step", + "5 minute step", +}; + +static void DrawInGraphBox(HDC hdc, int w, int h) { + RECT r = {0, 0, w, h}; + + static const COLORREF color[4] = { + 0x00c000, + 0xc00000, + }; + + LinearizedGraph *graph = g_backend->GetGraph(g_selected_graph_type); + StatsCollector::TimeSeries *time_series_ptr[4]; + StatsCollector::TimeSeries time_series[4]; + + int num_charts = 0; + if (graph && graph->num_charts <= 4) { + uint8 *ptr = (uint8*)(graph + 1); + for (int i = 0; i < graph->num_charts; i++) { + time_series_ptr[i] = &time_series[i]; + time_series[i].shift = 0; + time_series[i].size = *(uint32*)ptr; + time_series[i].data = (float*)(ptr + 4); + ptr += 4 + *(uint32*)ptr * 4; + if (ptr - (uint8*)graph > graph->total_size) + break; + } + num_charts = graph->num_charts; + } + + char buf[256]; + snprintf(buf, sizeof(buf), "Time (%s)", kGraphStepNames[g_selected_graph_type]); + + DrawGraph(hdc, &r, time_series_ptr, color, num_charts, buf, "Mbps"); + + free(graph); +} + +static LRESULT CALLBACK GraphBoxWndProc(HWND hwnd, UINT uMsg, WPARAM wParam, LPARAM lParam) { + switch (uMsg) { + case WM_PAINT: { + HandleWmPaintPaintbox(hwnd, &DrawInGraphBox); + return TRUE; + } + case WM_RBUTTONDOWN: { + HMENU menu = CreatePopupMenu(); + for(int i = 0; i < ARRAYSIZE(kGraphStepNames); i++) + AppendMenu(menu, (i == g_selected_graph_type) * MF_CHECKED, i + 1, kGraphStepNames[i]); + POINT pt = {GET_X_LPARAM(lParam), GET_Y_LPARAM(lParam)}; + ClientToScreen(hwnd, &pt); + int rv = TrackPopupMenu(menu, TPM_NONOTIFY | TPM_RETURNCMD, pt.x, pt.y, 0, hwnd, NULL); + DestroyMenu(menu); + if (rv != 0) { + g_selected_graph_type = rv - 1; + InvalidateRect(hwnd, NULL, FALSE); + } + return TRUE; + } + } + return DefWindowProc(hwnd, uMsg, wParam, lParam); +} + +struct AdvancedTextInfo { + uint16 y; + uint8 indent; + const char *title; +}; + +static const AdvancedTextInfo ADVANCED_TEXT_INFOS[] = { +#define Y 26 + {Y + 19 * 0, 66, "Public Key:"}, + {Y + 19 * 1, 66, "Endpoint:"}, + {Y + 19 * 2, 66, "Transfer:"}, + {Y + 19 * 3, 66, "Handshake:"}, + {Y + 19 * 4, 66, ""}, + {Y + 19 * 5, 66, "Overhead:"}, +#undef Y +}; + +static char *PrintLastHandshakeAt(char buf[256], WgProcessorStats *ps) { + char *d = buf; + if (ps->last_complete_handshake_timestamp) { + uint32 ago = (uint32)((OsGetMilliseconds() - ps->last_complete_handshake_timestamp + 500) / 1000); + uint32 hours = ago / 3600; + uint32 minutes = (ago - hours * 3600) / 60; + uint32 seconds = (ago - hours * 3600 - minutes * 60); + if (hours) + d += snprintf(d, 32, hours == 1 ? "%d hour, " : "%d hours, ", hours); + if (minutes) + d += snprintf(d, 32, minutes == 1 ? "%d minute, " : "%d minutes, ", minutes); + if (d == buf || seconds) + d += snprintf(d, 32, seconds == 1 ? "%d second, " : "%d seconds, ", seconds); + memcpy(d - 2, " ago", 5); + } else { + memcpy(buf, "(never)", 8); + } + return buf; +} + +static const char *GetAdvancedInfoValue(char buffer[256], int i) { + char tmp[64], tmp2[64]; + WgProcessorStats *ps = &g_processor_stats; + switch (i) { + case 0: { + if (IsOnlyZeros(g_backend->public_key(), 32)) + return ""; + char *str = (char*)base64_encode(g_backend->public_key(), 32, NULL); + snprintf(buffer, 256, "%s", str); + free(str); + return buffer; + } + case 1: { + char ip[kSizeOfAddress]; + if (ps->endpoint.sin.sin_family == 0) + return ""; + PrintIpAddr(ps->endpoint, ip); + snprintf(buffer, 256, "%s:%d", ip, htons(ps->endpoint.sin.sin_port)); + return buffer; + } + + case 2: + snprintf(buffer, 256, "%s in (%lld packets), %s out (%lld packets)", + PrintMB(tmp, ps->udp_bytes_in), ps->udp_packets_in, + PrintMB(tmp2, ps->udp_bytes_out), ps->udp_packets_out/*, udp_qsize2 - udp_qsize1, g_tun_reads*/); + return buffer; + case 3: return PrintLastHandshakeAt(buffer, ps); + case 4: { + snprintf(buffer, 256, "%d handshakes in (%d failed), %d handshakes out (%d failed)", + ps->handshakes_in, ps->handshakes_in - ps->handshakes_in_success, + ps->handshakes_out, ps->handshakes_out - ps->handshakes_out_success); + return buffer; + } + case 5: { + uint64 overhead_in = ps->udp_bytes_in + ps->udp_packets_in * 40 - ps->tun_bytes_out; + uint32 overhead_in_pct = ps->tun_bytes_out ? (uint32)(overhead_in * 100000 / ps->tun_bytes_out) : 0; + + uint64 overhead_out = ps->udp_bytes_out + ps->udp_packets_out * 40 - ps->tun_bytes_in; + uint32 overhead_out_pct = ps->tun_bytes_in ? (uint32)(overhead_out * 100000 / ps->tun_bytes_in) : 0; + + snprintf(buffer, 256, "%d.%.3d%% in, %d.%.3d%% out", overhead_in_pct / 1000, overhead_in_pct % 1000, + overhead_out_pct / 1000, overhead_out_pct % 1000); + return buffer; + } + default: return ""; + } +} + +static void DrawInAdvancedBox(HDC dc, int w, int h) { + RECT r = {0, 0, w, h}; + + FillRectColor(dc, r, 0xffffff); + + SelectObject(dc, (HFONT)SendMessage(g_ui_window, WM_GETFONT, 0, 0)); + SetTextColor(dc, GetSysColor(COLOR_WINDOWTEXT)); + SetBkColor(dc, GetSysColor(COLOR_WINDOW)); + + const AdvancedTextInfo *tp = ADVANCED_TEXT_INFOS; + char buffer[256]; + + for (size_t i = 0; i != ARRAYSIZE(ADVANCED_TEXT_INFOS); i++, tp++) { + int x = 8; + + RECT r = {x, tp->y, x + tp->indent, tp->y + 19}; + r = RescaleDpiRect(r); + ::ExtTextOut(dc, r.left, r.top, ETO_CLIPPED | ETO_OPAQUE, &r, tp->title, (UINT)strlen(tp->title), NULL); + + const char *s = GetAdvancedInfoValue(buffer, (int)i); + r.left = r.right; + r.right = w; + ::ExtTextOut(dc, r.left, r.top, ETO_CLIPPED | ETO_OPAQUE, &r, s, (UINT)strlen(s), NULL); + } + + SetBkColor(dc, GetSysColor(COLOR_3DFACE)); + + static const int grouptop[1] = { + 2 + }; + static const char *grouptext[1] = { + "General", + }; + + HFONT font = CreateFontHelper(12, 1, "Tahoma"); + SelectObject(dc, font); + for (size_t i = 0; i != ARRAYSIZE(grouptext); i++) { + RECT r = {RescaleDpi(4), RescaleDpi(grouptop[i]), w - RescaleDpi(4), RescaleDpi(grouptop[i] + 18)}; + ::ExtTextOut(dc, RescaleDpi(8), r.top + 1, ETO_CLIPPED | ETO_OPAQUE, &r, grouptext[i], (UINT)strlen(grouptext[i]), NULL); + } + DeleteFont(font); +} + +static LRESULT CALLBACK AdvancedBoxWndProc(HWND hwnd, UINT uMsg, WPARAM wParam, LPARAM lParam) { + switch (uMsg) { + case WM_PAINT: { + HandleWmPaintPaintbox(hwnd, &DrawInAdvancedBox); + return TRUE; + } + case WM_ERASEBKGND: + return TRUE; + + case WM_RBUTTONDOWN: { + int x = GET_X_LPARAM(lParam), y = GET_Y_LPARAM(lParam); + char buffer[256]; + + const AdvancedTextInfo *tp = ADVANCED_TEXT_INFOS; + for (size_t i = 0; i != ARRAYSIZE(ADVANCED_TEXT_INFOS); i++, tp++) { + if (x >= RescaleDpi(tp->indent) && y >= RescaleDpi(tp->y) && y < RescaleDpi(tp->y + 19)) { + HMENU menu = CreatePopupMenu(); + AppendMenu(menu, 0, 1, "Copy"); + POINT pt = {x, y}; + ClientToScreen(hwnd, &pt); + int rv = TrackPopupMenu(menu, TPM_NONOTIFY | TPM_RETURNCMD, pt.x, pt.y, 0, hwnd, NULL); + DestroyMenu(menu); + if (rv == 1) + SetClipboardString(GetAdvancedInfoValue(buffer, (int)i)); + return TRUE; + } + } + return TRUE; + } + } + return DefWindowProc(hwnd, uMsg, wParam, lParam); +} + +void InitializeClass(WNDPROC wndproc, const char *name) { + WNDCLASSEX wce = {0}; + wce.cbSize = sizeof(wce); + wce.lpfnWndProc = wndproc; + wce.hInstance = g_hinstance; + wce.lpszClassName = name; + wce.style = CS_HREDRAW | CS_VREDRAW; + wce.hCursor = LoadCursor(NULL, IDC_ARROW); + RegisterClassEx(&wce); +} + +static bool CreateMainWindow() { + LoadLibrary(TEXT("Riched20.dll")); + INITCOMMONCONTROLSEX ccx; + ccx.dwSize = sizeof(INITCOMMONCONTROLSEX); + ccx.dwICC = ICC_TAB_CLASSES; + InitCommonControlsEx(&ccx); + + InitializeClass(&PaintBoxWndProc, "PaintBox"); + InitializeClass(&GraphBoxWndProc, "GraphBox"); + InitializeClass(&AdvancedBoxWndProc, "AdvancedBox"); + + HDC dc = GetDC(0); + g_large_fonts = GetDeviceCaps(dc, LOGPIXELSX); + ReleaseDC(0, dc); + + g_message_taskbar_created = RegisterWindowMessage(TEXT("TaskbarCreated")); + + g_icons[0] = LoadIcon(GetModuleHandle(NULL), MAKEINTRESOURCE(IDI_ICON1)); + g_icons[1] = LoadIcon(GetModuleHandle(NULL), MAKEINTRESOURCE(IDI_ICON0)); + g_ui_window = CreateDialog(GetModuleHandle(NULL), MAKEINTRESOURCE(IDD_DIALOG1), NULL, &DlgProc); + + if (!g_ui_window) + return false; + + DragAcceptFiles(g_ui_window, TRUE); + + ChangeWindowMessageFilter(WM_DROPFILES, MSGFLT_ADD); + ChangeWindowMessageFilter(WM_COPYDATA, MSGFLT_ADD); + ChangeWindowMessageFilter(0x0049, MSGFLT_ADD); + ChangeWindowMessageFilter(WM_USER + 10, MSGFLT_ADD); + + TCITEM tabitem; + HWND hwnd_tab = GetDlgItem(g_ui_window, IDC_TAB); + hwndTab = hwnd_tab; + tabitem.mask = TCIF_TEXT; + tabitem.pszText = "Logs"; + TabCtrl_InsertItem(hwnd_tab, 0, &tabitem); + tabitem.pszText = "Charts"; + TabCtrl_InsertItem(hwnd_tab, 1, &tabitem); + tabitem.pszText = "Advanced"; + TabCtrl_InsertItem(hwnd_tab, 2, &tabitem); + SetWindowLong(hwnd_tab, GWL_EXSTYLE, GetWindowLong(hwnd_tab, GWL_EXSTYLE) | WS_EX_COMPOSITED); + + + + hwndEdit = GetDlgItem(g_ui_window, IDC_RICHEDIT21); + hwndPaintBox = GetDlgItem(g_ui_window, IDC_PAINTBOX); + hwndGraphBox = GetDlgItem(g_ui_window, IDC_GRAPHBOX); + hwndAdvancedBox = GetDlgItem(g_ui_window, IDC_ADVANCEDBOX); + + SetWindowLong(hwndEdit, GWL_EXSTYLE, GetWindowLong(hwndEdit, GWL_EXSTYLE) &~ WS_EX_CLIENTEDGE); + + // Create the status bar. + hwndStatus = CreateWindowEx( + WS_EX_COMPOSITED, STATUSCLASSNAME, NULL, + WS_CHILD | WS_VISIBLE, 0, 0, 0, 0, g_ui_window, + (HMENU)IDC_STATUSBAR, g_hinstance, NULL); + + HandleWindowSizing(); + UpdateTabSelection(); + return true; +} + +int WINAPI WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance, LPSTR lpCmdLine, int nShowCmd) { g_hinstance = hInstance; InitCpuFeatures(); + WSADATA wsaData = {0}; + WSAStartup(MAKEWORD(2, 2), &wsaData); + + bool minimize = false; + bool is_autostart = false; + const char *filename = NULL; + + for (int i = 1; i < __argc; i++) { + const char *arg = __argv[i]; + if (strcmp(arg, "/minimize") == 0) { + minimize = true; + } else if (strcmp(arg, "/minimize_on_connect") == 0) { + g_minimize_on_connect = true; + } else if (strcmp(arg, "/allow_pre_post") == 0) { + g_allow_pre_post = true; + } else if (strcmp(arg, "--service") == 0) { + RunProcessAsTunsafeServiceProcess(); + return 0; + } else if (strcmp(arg, "--delete-service-and-start") == 0) { + UninstallTunSafeWindowsService(); + } else if (strcmp(arg, "--autostart") == 0) { + is_autostart = true; + } else if (strcmp(arg, "--set-allow-pre-post") == 0) { + bool want = i + 1 < __argc && atoi(__argv[i + 1]) != 0; + RegCreateKeyEx(HKEY_LOCAL_MACHINE, "Software\\TunSafe", NULL, NULL, 0, KEY_ALL_ACCESS, NULL, &g_hklm_reg_key, NULL); + RegWriteInt(g_hklm_reg_key, "AllowPrePost", want); + return 0; + } else if (strcmp(arg, "--import") == 0) { + if (i + 1 >= __argc) return 1; + const char *filename = __argv[i + 1]; + return ImportFile(filename, true); + } else { + filename = arg; + break; + } + } + + SetProcessDPIAware(); + + RegCreateKeyEx(HKEY_CURRENT_USER, "Software\\TunSafe", NULL, NULL, 0, KEY_ALL_ACCESS, NULL, &g_reg_key, NULL); + RegCreateKeyEx(HKEY_LOCAL_MACHINE, "Software\\TunSafe", NULL, NULL, 0, KEY_ALL_ACCESS, NULL, &g_hklm_reg_key, NULL); + RegOpenKeyEx(HKEY_LOCAL_MACHINE, "Software\\TunSafe", 0, KEY_READ, &g_hklm_readonly_reg_key); + + g_startup_flags = RegReadInt(g_reg_key, "StartupFlags", 0); + + if (is_autostart) { + g_disable_connect_on_start = !(g_startup_flags & kStartupFlag_ConnectWhenWindowsStarts); + minimize = !!(g_startup_flags & kStartupFlag_MinimizeToTrayWhenWindowsStarts); + } + // Check if the app is already running. - CreateMutexA(0, FALSE, "TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90"); + g_runonce_mutex = CreateMutexA(0, FALSE, "TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90"); if (GetLastError() == ERROR_ALREADY_EXISTS) { HWND window = FindWindow("TunSafe-f19e092db01cbe0fb6aee132f8231e5b71c98f90", NULL); DWORD_PTR result; @@ -1041,103 +1763,64 @@ int WINAPI WinMain (HINSTANCE hInstance, HINSTANCE hPrevInstance, LPSTR lpCmdLin } return 1; } + + TOKEN_ELEVATION_TYPE toktype; + g_is_limited_uac_account = (GetProcessElevationType(&toktype) && toktype == TokenElevationTypeLimited); + g_is_tunsafe_service_running = IsTunsafeServiceRunning(); + bool want_use_service = !!(g_startup_flags & (kStartupFlag_BackgroundService | kStartupFlag_ForegroundService)); + + // Re-launch the process as administrator if the TunSafe service isn't running. + if ((!g_is_tunsafe_service_running || !want_use_service) && g_is_limited_uac_account) { + CloseHandle(g_runonce_mutex); + if (!RestartProcessAsAdministrator()) + MessageBoxA(0, "TunSafe needs to run as Administrator unless the TunSafe Service is started.", "TunSafe", MB_ICONWARNING); + return 0; + } + CreateNotificationWindow(); - WSADATA wsaData = {0}; - if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { - RERROR("WSAStartup failed"); - return 1; - } - - LoadLibrary(TEXT("Riched20.dll")); - - g_backend = new TunsafeBackendWin32(); - - InitMyPostMessage(); - InitCommonControls(); - - g_icons[0] = LoadIcon(GetModuleHandle(NULL), MAKEINTRESOURCE(IDI_ICON1)); - g_icons[1] = LoadIcon(GetModuleHandle(NULL), MAKEINTRESOURCE(IDI_ICON0)); - g_ui_window = CreateDialog(GetModuleHandle(NULL), MAKEINTRESOURCE(IDD_DIALOG1), NULL, &DlgProc); - - if (!g_ui_window) - return 1; - - RegCreateKeyEx(HKEY_CURRENT_USER, "Software\\TunSafe", NULL, NULL, 0, KEY_ALL_ACCESS, NULL, &g_reg_key, NULL); - DragAcceptFiles(g_ui_window, TRUE); - - ChangeWindowMessageFilter(WM_DROPFILES, MSGFLT_ADD); - ChangeWindowMessageFilter(WM_COPYDATA, MSGFLT_ADD); - ChangeWindowMessageFilter(0x0049, MSGFLT_ADD); - - static const int ctrls[] = {IDTXT_UDP, IDTXT_TUN, IDTXT_HANDSHAKE}; - for (int i = 0; i < 3; i++) { - HWND w = GetDlgItem(g_ui_window, ctrls[i]); - SetWindowLong(w, GWL_EXSTYLE, GetWindowLong(w, GWL_EXSTYLE) | WS_EX_COMPOSITED); - } - - g_allow_pre_post = RegReadInt("AllowPrePost", 0) != 0; - - bool minimize = false; - const char *filename = NULL; - - for (size_t i = 1; i < __argc; i++) { - const char *arg = __argv[i]; - - if (_stricmp(arg, "/minimize") == 0) { - minimize = true; - } else if (_stricmp(arg, "/minimize_on_connect") == 0) { - g_minimize_on_connect = true; - } else if (_stricmp(arg, "/allow_pre_post") == 0) { - g_allow_pre_post = true; - } else { - filename = arg; - break; - } - } - - if (!minimize) { - g_ui_visible = true; - ShowWindow(g_ui_window, SW_SHOW); - } - - UpdateIcon(UIW_NONE); - + g_backend_delegate = CreateTunsafeBackendDelegateThreaded(&my_procdel, []() { + if (g_ui_window) + PostMessage(g_ui_window, WM_USER + 2, 0, 0); + }); g_logger = &PushLine; + if (!CreateMainWindow()) + return 1; + + g_current_filename = _strdup(""); + g_cmdline_filename = filename; + + if (!g_allow_pre_post && g_hklm_readonly_reg_key) + g_allow_pre_post = RegReadInt(g_hklm_readonly_reg_key, "AllowPrePost", 0) != 0; + + // Attempt to start service... + if (want_use_service && !g_is_tunsafe_service_running) { + RINFO("Starting TunSafe service..."); + InstallTunSafeWindowsService(); + } + + CreateLocalOrRemoteBackend(want_use_service); + + if (!minimize) { + SetUiVisibility(true); + } + UpdateIcon(UIW_NONE); EnsureConfigDirCreated(); - if (filename) { - LoadConfigFile(filename, false, false); - } else { - char *conf = RegReadStr("ConfigFile", "TunSafe.conf"); - LoadConfigFile(conf, false, false); - free(conf); - } - - // PrintCpuFeatures(); - -// Benchmark(); - - if (filename != NULL || RegReadInt("IsConnected", 0)) { - StartService(); - } else { - RINFO("Press Connect to initiate a connection to the WireGuard server."); - } - MSG msg; - while (GetMessage(&msg, NULL, 0, 0)) { if (!IsDialogMessage(g_ui_window, &msg)) { TranslateMessage(&msg); DispatchMessage(&msg); } } - StopService(UIW_EXITING); + + if (!g_backend->is_remote()) + g_backend->Stop(); + + delete g_backend; RemoveIcon(); return 0; } - - - diff --git a/util.cpp b/util.cpp index a601a0b..2269a3f 100644 --- a/util.cpp +++ b/util.cpp @@ -17,6 +17,7 @@ #include #endif +#include #include "tunsafe_types.h" static char base64_alphabet[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; @@ -133,6 +134,7 @@ int RunCommand(const char *fmt, ...) { char *args[33]; char *envp[1] = {NULL}; int nargs = 0; + bool didadd = false; va_start(va, fmt); for (;;) { c = *fmt++; @@ -140,13 +142,14 @@ int RunCommand(const char *fmt, ...) { c = *fmt++; if (c == 0) goto ZERO; if (c == 's') { - tmp += va_arg(va, char*); + char *arg = va_arg(va, char*); + if (arg != NULL) { + tmp += arg; + didadd = true; + } } else if (c == 'd') { snprintf(buf, 32, "%d", va_arg(va, int)); tmp += buf; - } else if (c == 'u') { - snprintf(buf, 32, "%u", va_arg(va, int)); - tmp += buf; } else if (c == '%') { tmp += '%'; } else if (c == 'A') { @@ -156,9 +159,12 @@ int RunCommand(const char *fmt, ...) { } } else if (c == ' ' || c == 0) { ZERO: - args[nargs++] = _strdup(tmp.c_str()); - tmp.clear(); - if (nargs == 32 || c == 0) break; + if (!tmp.empty() || didadd) { + args[nargs++] = _strdup(tmp.c_str()); + tmp.clear(); + if (nargs == 32 || c == 0) break; + } + didadd = false; } else { tmp += c; } @@ -187,7 +193,7 @@ ZERO: #endif if (ret != 0) - RERROR("Command %s failed %d!", fmt_org, ret); + RERROR("Command failed %d!", ret); return ret; } @@ -265,3 +271,29 @@ void RINFO(const char *msg, ...) { fputs("\n", stderr); } } + +void *memdup(const void *p, size_t size) { + void *x = malloc(size); + if (x) + memcpy(x, p, size); + return x; +} + +char *my_strndup(const char *p, size_t size) { + char *x = (char*)malloc(size + 1); + if (x) { + x[size] = 0; + memcpy(x, p, size); + } + return x; +} + +size_t my_strlcpy(char *dst, size_t dstsize, const char *src) { + size_t len = strlen(src); + if (dstsize) { + size_t lenx = std::min(dstsize - 1, len); + dst[lenx] = 0; + memcpy(dst, src, lenx); + } + return len; +} \ No newline at end of file diff --git a/util.h b/util.h index 48b8324..e0846f3 100644 --- a/util.h +++ b/util.h @@ -12,3 +12,14 @@ typedef void Logger(const char *msg); extern Logger *g_logger; +void *memdup(const void *p, size_t size); +char *my_strndup(const char *p, size_t size); + +size_t my_strlcpy(char *dst, size_t dstsize, const char *src); + + +template static inline T postinc(T&x, U v) { + T t = x; + x += v; + return t; +} diff --git a/util_win32.cpp b/util_win32.cpp new file mode 100644 index 0000000..1dec101 --- /dev/null +++ b/util_win32.cpp @@ -0,0 +1,378 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. +#include "stdafx.h" +#include "util_win32.h" +#include +#include +#include +#include +#include +#include + +const char *FindFilenameComponent(const char *s) { + size_t len = strlen(s); + for (;;) { + if (len == 0) + return ""; + len--; + if (s[len] == '\\' || s[len] == '/') + break; + } + return s + len + 1; +} + +void str_set(char **x, const char *s) { + free(*x); + *x = _strdup(s); +} + +char *str_cat_alloc(const char * const *a, size_t n) { + if (n > 32) return NULL; + size_t len[32], totlen = 0; + for (size_t i = 0; i < n; i++) { + len[i] = strlen(a[i]); + totlen += len[i]; + } + char *r = (char *)malloc(totlen + 1); + totlen = 0; + for (size_t i = 0; i < n; i++) { + size_t n = len[i]; + memcpy(r + totlen, a[i], n); + totlen += n; + } + r[totlen] = 0; + return r; +} + +char *str_cat_alloc(const char *a, const char *b) { + const char * x[2] = {a, b}; + return str_cat_alloc(x, 2); +} + +char *str_cat_alloc(const char *a, const char *b, const char *c) { + const char * x[3] = {a, b, c}; + return str_cat_alloc(x, 3); +} + + +int RegReadInt(HKEY hkey, const char *key, int def) { + DWORD value = def, n = sizeof(value); + RegQueryValueEx(hkey, key, NULL, NULL, (BYTE*)&value, &n); + return value; +} + +void RegWriteInt(HKEY hkey, const char *key, int value) { + RegSetValueEx(hkey, key, NULL, REG_DWORD, (BYTE*)&value, sizeof(value)); +} + +char *RegReadStr(HKEY hkey, const char *key, const char *def) { + char buf[1024]; + DWORD n = sizeof(buf) - 1; + DWORD type = 0; + if (RegQueryValueEx(hkey, key, NULL, &type, (BYTE*)buf, &n) != ERROR_SUCCESS || type != REG_SZ) + return def ? _strdup(def) : NULL; + if (n && buf[n - 1] == 0) + n--; + buf[n] = 0; + return _strdup(buf); +} + +void RegWriteStr(HKEY hkey, const char *key, const char *v) { + RegSetValueEx(hkey, key, NULL, REG_SZ, (BYTE*)v, (DWORD)strlen(v) + 1); +} + +bool GetProcessElevationType(TOKEN_ELEVATION_TYPE *pOutElevationType) { + *pOutElevationType = TokenElevationTypeDefault; + bool fResult = false; + HANDLE hProcToken = NULL; + if (::OpenProcessToken(::GetCurrentProcess(), TOKEN_QUERY, &hProcToken)) { + DWORD dwSize = 0; + TOKEN_ELEVATION_TYPE elevationType = TokenElevationTypeDefault; + if (::GetTokenInformation(hProcToken, TokenElevationType, &elevationType, sizeof(elevationType), &dwSize) + && dwSize == sizeof(elevationType)) { + *pOutElevationType = elevationType; + fResult = true; + } + ::CloseHandle(hProcToken); + } + return fResult; +} + +/*++ +Routine Description: This routine returns TRUE if the caller's +process is a member of the Administrators local group. Caller is NOT +expected to be impersonating anyone and is expected to be able to +open its own process and process token. +Arguments: None. +Return Value: +TRUE - Caller has Administrators local group. +FALSE - Caller does not have Administrators local group. -- +*/ + +BOOL IsUserAdmin(VOID) { + BOOL b; + SID_IDENTIFIER_AUTHORITY NtAuthority = SECURITY_NT_AUTHORITY; + PSID AdministratorsGroup; + b = AllocateAndInitializeSid( + &NtAuthority, + 2, + SECURITY_BUILTIN_DOMAIN_RID, + DOMAIN_ALIAS_RID_ADMINS, + 0, 0, 0, 0, 0, 0, + &AdministratorsGroup); + if (b) { + if (!CheckTokenMembership(NULL, AdministratorsGroup, &b)) { + b = FALSE; + } + FreeSid(AdministratorsGroup); + } + + return(b); +} + + +const wchar_t *SkipAppNameInCommandLineArgs(const wchar_t *s) { + if (*s == '\"') { + for (;;) { + s++; + if (*s == 0) return s; + if (*s == '\"') return s + 1; + } + } else { + for (;;) { + if (*s == 0) return s; + if (*s == ' ') return s + 1; + s++; + } + } +} + + +uint8* LoadFileSane(const char *name, size_t *size) { + FILE *f = fopen(name, "rb"); + uint8 *new_file = NULL, *file = NULL; + size_t j, i, n; + if (!f) return false; + fseek(f, 0, SEEK_END); + long x = ftell(f); + fseek(f, 0, SEEK_SET); + if (x < 0 || x >= 65536) goto error; + file = (uint8*)malloc(x + 1); + if (!file) goto error; + n = fread(file, 1, x + 1, f); + if (n != x || !SanityCheckBuf(file, n)) + goto error; + // Convert the file to DOS new lines + for (i = j = 0; i < n; i++) + j += (file[i] == '\n'); + new_file = (uint8*)malloc(n + 1 + j); + if (!new_file) goto error; + for (i = j = 0; i < n; i++) { + uint8 c = file[i]; + if (c == '\r') + continue; + if (c == '\n') + new_file[j++] = '\r'; + new_file[j++] = c; + } + new_file[j] = 0; + *size = j; + +error: + fclose(f); + free(file); + return new_file; +} + +int WriteOutFile(const char *filename, uint8 *filedata, size_t filesize) { + FILE *f = fopen(filename, "wb"); + if (!f) return kWriteOutFile_AccessError; + if (fwrite(filedata, 1, filesize, f) != filesize) { + fclose(f); + return kWriteOutFile_OtherError; + } + fclose(f); + return kWriteOutFile_Ok; +} + +bool FileExists(const CHAR *fileName) { + DWORD fileAttr = GetFileAttributes(fileName); + return (0xFFFFFFFF != fileAttr); +} + +__int64 FileSize(const char* name) { + WIN32_FILE_ATTRIBUTE_DATA fad; + if (!GetFileAttributesEx(name, GetFileExInfoStandard, &fad)) + return -1; // error condition, could call GetLastError to find out more + LARGE_INTEGER size; + size.HighPart = fad.nFileSizeHigh; + size.LowPart = fad.nFileSizeLow; + return size.QuadPart; +} + +static bool is_space(uint8_t c) { + return c == ' ' || c == '\r' || c == '\n' || c == '\t'; +} + +static bool is_valid(uint8_t c) { + return c >= ' ' || c == '\r' || c == '\n' || c == '\t'; +} + +bool SanityCheckBuf(uint8 *buf, size_t n) { + for (size_t i = 0; i < n; i++) { + if (!is_space(buf[i])) { + if (buf[i] != '[' && buf[i] != '#') + return false; + for (; i < n; i++) + if (!is_valid(buf[i])) + return false; + return true; + } + } + return false; +} + +void FindDesktopFolderView(REFIID riid, void **ppv) { + CComPtr spShellWindows; + spShellWindows.CoCreateInstance(CLSID_ShellWindows); + + CComVariant vtLoc(CSIDL_DESKTOP); + CComVariant vtEmpty; + long lhwnd; + CComPtr spdisp; + spShellWindows->FindWindowSW( + &vtLoc, &vtEmpty, + SWC_DESKTOP, &lhwnd, SWFO_NEEDDISPATCH, &spdisp); + + CComPtr spBrowser; + CComQIPtr(spdisp)-> + QueryService(SID_STopLevelBrowser, + IID_PPV_ARGS(&spBrowser)); + + CComPtr spView; + spBrowser->QueryActiveShellView(&spView); + + spView->QueryInterface(riid, ppv); +} + +void GetDesktopAutomationObject(REFIID riid, void **ppv) { + CComPtr spsv; + FindDesktopFolderView(IID_PPV_ARGS(&spsv)); + CComPtr spdispView; + spsv->GetItemObject(SVGIO_BACKGROUND, IID_PPV_ARGS(&spdispView)); + spdispView->QueryInterface(riid, ppv); +} + +void ShellExecuteFromExplorer( + PCSTR pszFile, + PCSTR pszParameters, + PCSTR pszDirectory, + PCSTR pszOperation, + int nShowCmd) { + CComPtr spFolderView; + GetDesktopAutomationObject(IID_PPV_ARGS(&spFolderView)); + CComPtr spdispShell; + spFolderView->get_Application(&spdispShell); + + CComQIPtr(spdispShell) + ->ShellExecute(CComBSTR(pszFile), + CComVariant(pszParameters ? pszParameters : ""), + CComVariant(pszDirectory ? pszDirectory : ""), + CComVariant(pszOperation ? pszOperation : ""), + CComVariant(nShowCmd)); +} + +size_t GetConfigPath(char *path, size_t path_size) { + + if (!GetModuleFileName(NULL, path, (DWORD)path_size)) { + *path = 0; + return 0; + } + char *last = (char *)FindFilenameComponent(path); + if (!*last || last + 8 > path + path_size) { + *path = 0; + return 0; + } + memcpy(last, "Config\\", 8 * sizeof(last[0])); + return last + 7 - path; +} + +static bool ContainsDotDot(const char *path) { + for (uint8 last = 0, cur; (cur = path[0]) != '\0'; last = cur, path++) + if (cur == '.' && last == cur) + return true; + return false; +} + +bool EnsureValidConfigPath(const char *path) { + char buf[1024]; + + size_t len = GetConfigPath(buf, sizeof(buf)); + return (len != 0) && (strlen(path) > len && memcmp(path, buf, len) == 0 && !ContainsDotDot(path + len)); +} + +bool RunProcessAsAdminWithArgs(const char *args, bool wait_for_exit) { + SHELLEXECUTEINFO shExecInfo = {0}; + char buf[1024]; + + if (!GetModuleFileName(NULL, buf, 1024)) + return false; + shExecInfo.cbSize = sizeof(shExecInfo); + shExecInfo.lpVerb = "runas"; + shExecInfo.lpFile = buf; + shExecInfo.lpParameters = args; + shExecInfo.nShow = SW_SHOW; + shExecInfo.fMask = SEE_MASK_NOASYNC | wait_for_exit * SEE_MASK_NOCLOSEPROCESS; + if (!ShellExecuteExA(&shExecInfo)) + return false; + if (shExecInfo.hProcess) { + WaitForSingleObject(shExecInfo.hProcess, 10000); + CloseHandle(shExecInfo.hProcess); + } + return true; +} + +bool RestartProcessAsAdministrator() { + SHELLEXECUTEINFOW shExecInfo = {0}; + wchar_t buf[1024]; + + if (!GetModuleFileNameW(NULL, buf, 1024)) + return false; + +// shExecInfo.hwnd = window; + shExecInfo.cbSize = sizeof(shExecInfo); + shExecInfo.lpVerb = L"runas"; + shExecInfo.lpFile = buf; + shExecInfo.lpParameters = SkipAppNameInCommandLineArgs(GetCommandLineW()); + shExecInfo.nShow = SW_SHOW; + + return ShellExecuteExW(&shExecInfo) != 0; +} + +bool SetClipboardString(const char *string) { + bool ok = false; + if (OpenClipboard(NULL)) { + HGLOBAL hglb; + size_t len = strlen(string); + hglb = GlobalAlloc(GMEM_SHARE | GMEM_MOVEABLE, (len + 1) * sizeof(char)); + LPSTR lptstr = (LPSTR)GlobalLock(hglb); + memcpy(lptstr, string, len + 1); + GlobalUnlock(hglb); + EmptyClipboard(); + ok = SetClipboardData(CF_TEXT, hglb) != 0; + CloseClipboard(); + } + return ok; +} + +RECT GetParentRect(HWND wnd) { + RECT btrect; + GetClientRect(wnd, &btrect); + MapWindowPoints(wnd, GetParent(wnd), (LPPOINT)&btrect, 2); + return btrect; +} + +RECT MakeRect(int l, int t, int r, int b) { + RECT rr = { l, t, r, b }; + return rr; +} diff --git a/util_win32.h b/util_win32.h new file mode 100644 index 0000000..8497903 --- /dev/null +++ b/util_win32.h @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: AGPL-1.0-only +// Copyright (C) 2018 Ludvig Strigeus . All Rights Reserved. +#include "tunsafe_types.h" + +#pragma once +const char *FindFilenameComponent(const char *s); +void str_set(char **x, const char *s); + +char *str_cat_alloc(const char * const *a, size_t n); +char *str_cat_alloc(const char *a, const char *b); +char *str_cat_alloc(const char *a, const char *b, const char *c); + +int RegReadInt(HKEY hkey, const char *key, int def); +void RegWriteInt(HKEY hkey, const char *key, int value); +char *RegReadStr(HKEY hkey, const char *key, const char *def); +void RegWriteStr(HKEY hkey, const char *key, const char *v); + +// TokenElevationTypeDefault -- User is not using a split token. (e.g. UAC disabled or local admin "Administrator" account which UAC may not apply to.) +// TokenElevationTypeFull -- User has a split token, and the process is running elevated. +// TokenElevationTypeLimited -- User has a split token, but the process is not running elevated. +bool GetProcessElevationType(TOKEN_ELEVATION_TYPE *pOutElevationType); + + +const wchar_t *SkipAppNameInCommandLineArgs(const wchar_t *s); + +uint8* LoadFileSane(const char *name, size_t *size); + +enum { + kWriteOutFile_Ok = 0, + kWriteOutFile_AccessError = 1, + kWriteOutFile_OtherError = 2, +}; + +int WriteOutFile(const char *filename, uint8 *filedata, size_t filesize); + +bool SanityCheckBuf(uint8 *buf, size_t n); + +__int64 FileSize(const char* name); + +bool FileExists(const CHAR *fileName); + +void ShellExecuteFromExplorer( + PCSTR pszFile, + PCSTR pszParameters = nullptr, + PCSTR pszDirectory = nullptr, + PCSTR pszOperation = nullptr, + int nShowCmd = SW_SHOWNORMAL); + +size_t GetConfigPath(char *path, size_t path_size); +bool EnsureValidConfigPath(const char *path); + +bool RunProcessAsAdminWithArgs(const char *args, bool wait_for_exit); +bool RestartProcessAsAdministrator(); +bool SetClipboardString(const char *string); +RECT GetParentRect(HWND wnd); +RECT MakeRect(int l, int t, int r, int b); diff --git a/wireguard.cpp b/wireguard.cpp index ab9b393..2e0e72a 100644 --- a/wireguard.cpp +++ b/wireguard.cpp @@ -12,7 +12,9 @@ #include #include #include +#include "ipzip2/ipzip2.h" #include "wireguard.h" +#include "wireguard_config.h" uint64 OsGetMilliseconds(); @@ -35,11 +37,23 @@ WireguardProcessor::WireguardProcessor(UdpInterface *udp, TunInterface *tun, Pro dns_blocking_ = true; internet_blocking_ = kBlockInternet_Default; dns6_addr_.sin.sin_family = dns_addr_.sin.sin_family = 0; + + stats_last_bytes_in_ = 0; + stats_last_bytes_out_ = 0; + stats_last_ts_ = OsGetMilliseconds(); + + main_thread_scheduled_ = NULL; + main_thread_scheduled_last_ = &main_thread_scheduled_; } WireguardProcessor::~WireguardProcessor() { } +void WireguardProcessor::SetListenPort(int listen_port) { + listen_port_ = listen_port; +} + + bool WireguardProcessor::AddDnsServer(const IpAddr &sin) { IpAddr *target = (sin.sin.sin_family == AF_INET6) ? &dns6_addr_ : &dns_addr_; if (target->sin.sin_family != 0) @@ -48,7 +62,6 @@ bool WireguardProcessor::AddDnsServer(const IpAddr &sin) { return true; } - bool WireguardProcessor::SetTunAddress(const WgCidrAddr &addr) { WgCidrAddr *target = (addr.size == 128) ? &tun6_addr_ : &tun_addr_; if (target->size != 0) @@ -57,9 +70,37 @@ bool WireguardProcessor::SetTunAddress(const WgCidrAddr &addr) { return true; } +void WireguardProcessor::AddExcludedIp(const WgCidrAddr &cidr_addr) { + excluded_ips_.push_back(cidr_addr); +} -ProcessorStats WireguardProcessor::GetStats() { - stats_.last_complete_handskake_timestamp = dev_.last_complete_handskake_timestamp(); +void WireguardProcessor::SetMtu(int mtu) { + if (mtu >= 576 && mtu <= 10000) + mtu_ = mtu; +} + +void WireguardProcessor::SetAddRoutesMode(bool mode) { + add_routes_mode_ = mode; +} + +void WireguardProcessor::SetDnsBlocking(bool dns_blocking) { + dns_blocking_ = dns_blocking; +} + +void WireguardProcessor::SetInternetBlocking(InternetBlockState internet_blocking) { + internet_blocking_ = internet_blocking; +} + +void WireguardProcessor::SetHeaderObfuscation(const char *key) { + dev_.SetHeaderObfuscation(key); +} + +WgProcessorStats WireguardProcessor::GetStats() { + // todo: only supports one peer but i want this in the ui for now. + stats_.endpoint.sin.sin_family = 0; + WgPeer *peer = dev_.first_peer(); + if (peer) + stats_.endpoint = peer->endpoint_; return stats_; } @@ -92,6 +133,7 @@ static bool IsWgCidrAddrSubsetOf(const WgCidrAddr &inner, const WgCidrAddr &oute } bool WireguardProcessor::Start() { + assert(dev_.IsMainThread()); if (!udp_->Initialize(listen_port_)) return false; @@ -101,7 +143,7 @@ bool WireguardProcessor::Start() { } if (tun_addr_.cidr >= 31) { - RERROR("The TAP driver is not compatible with Address using CIDR /31 or /32. Changing to /24"); + RERROR("TAP is not compatible CIDR /31 or /32. Changing to /24"); tun_addr_.cidr = 24; } @@ -110,7 +152,8 @@ bool WireguardProcessor::Start() { config.cidr = tun_addr_.cidr; config.mtu = mtu_; config.pre_post_commands = pre_post_; - + config.excluded_ips = excluded_ips_; + uint32 netmask = tun_addr_.cidr == 32 ? 0xffffffff : 0xffffffff << (32 - tun_addr_.cidr); uint32 ipv4_broadcast_addr = (netmask == 0xffffffff) ? 0xffffffff : config.ip | ~netmask; @@ -130,6 +173,7 @@ bool WireguardProcessor::Start() { config.default_route_endpoint_v4 = (peer->endpoint_.sin.sin_family == AF_INET) ? ReadBE32(&peer->endpoint_.sin.sin_addr) : 0; // Set the default route to something config.use_ipv4_default_route = true; + peer->allow_endpoint_change_ = false; } // Also configure ipv6 gw? @@ -139,6 +183,7 @@ bool WireguardProcessor::Start() { if (peer->endpoint_.sin.sin_family == AF_INET6) memcpy(&config.default_route_endpoint_v6, &peer->endpoint_.sin6.sin6_addr, 16); config.use_ipv6_default_route = true; + peer->allow_endpoint_change_ = false; } } @@ -158,7 +203,8 @@ bool WireguardProcessor::Start() { uint8 dhcp_options[6]; - config.block_dns_on_adapters = dns_blocking_; + config.block_dns_on_adapters = dns_blocking_ && ((config.use_ipv4_default_route && dns_addr_.sin.sin_family == AF_INET) || + (config.use_ipv6_default_route && dns6_addr_.sin6.sin6_family == AF_INET6)); config.internet_blocking = internet_blocking_; if (dns_addr_.sin.sin_family == AF_INET) { @@ -187,7 +233,7 @@ bool WireguardProcessor::Start() { peer->ipv4_broadcast_addr_ = ipv4_broadcast_addr; if (peer->endpoint_.sin.sin_family != 0) { RINFO("Sending handshake..."); - SendHandshakeInitiationAndResetRetries(peer); + SendHandshakeInitiation(peer); } } @@ -222,10 +268,8 @@ struct ICMPv6NaPacketWithoutTarget { uint8 reserved[3]; uint8 target[16]; }; - #pragma pack (pop) - static uint16 ComputeIcmpv6Checksum(const uint8 *buf, int buf_size, const uint8 src_addr[16], const uint8 dst_addr[16]) { uint32 sum = 0; for (int i = 0; i < buf_size - 1; i += 2) @@ -242,28 +286,25 @@ static uint16 ComputeIcmpv6Checksum(const uint8 *buf, int buf_size, const uint8 return ((uint16)~sum); } - bool WireguardProcessor::HandleIcmpv6NeighborSolicitation(const byte *data, size_t data_size) { if (data_size < 48 + 16) return false; // Filter out neighbor solicitation - if (data[40] != kICMPv6_NeighborSolicitation || data[41] != 0) - return false; - - if (!network_discovery_spoofing_) + if (data[40] != kICMPv6_NeighborSolicitation || data[41] != 0 || !network_discovery_spoofing_) return false; bool is_broadcast = true; - if (memcmp(data + 24, kIcmpv6NeighborMulticastPrefix, sizeof(kIcmpv6NeighborMulticastPrefix)) != 0) { if (memcmp(data + 24, data + 48, 16) != 0) return false; is_broadcast = false; } - + // Target address must match a peer's range. + WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); WgPeer *peer = (WgPeer*)dev_.ip_to_peer_map().LookupV6(data + 48); + WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_) if (peer == NULL) return false; @@ -273,8 +314,7 @@ bool WireguardProcessor::HandleIcmpv6NeighborSolicitation(const byte *data, size return false; byte *odata = out->data; - - int packet_size = is_broadcast ? sizeof(ICMPv6NaPacket) : sizeof(ICMPv6NaPacketWithoutTarget); + size_t packet_size = is_broadcast ? sizeof(ICMPv6NaPacket) : sizeof(ICMPv6NaPacketWithoutTarget); memcpy(odata, data, 4); WriteBE16(odata + 4, packet_size); @@ -298,10 +338,10 @@ bool WireguardProcessor::HandleIcmpv6NeighborSolicitation(const byte *data, size // For some reason this is openvpn's 'related mac' ((ICMPv6NaPacket*)(odata + 40))->target_mac[2] += 1; } - uint16 checksum = ComputeIcmpv6Checksum(odata + 40, packet_size, odata + 8, odata + 24); + uint16 checksum = ComputeIcmpv6Checksum(odata + 40, (int)packet_size, odata + 8, odata + 24); WriteBE16(&((ICMPv6NaPacket*)(odata + 40))->checksum, checksum); - out->size = 40 + packet_size; + out->size = (unsigned)(40 + packet_size); tun_->WriteTunPacket(out); return true; } @@ -317,9 +357,6 @@ void WireguardProcessor::HandleTunPacket(Packet *packet) { unsigned ip_version, size_from_header; WgPeer *peer; - stats_.tun_bytes_in += data_size; - stats_.tun_packets_in++; - // Sanity check that it looks like a valid ipv4 or ipv6 packet, // and determine the destination peer from the ip header if (data_size < IPV4_HEADER_SIZE) @@ -328,7 +365,9 @@ void WireguardProcessor::HandleTunPacket(Packet *packet) { ip_version = *data >> 4; if (ip_version == 4) { uint32 ip = ReadBE32(data + 16); + WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); peer = (WgPeer*)dev_.ip_to_peer_map().LookupV4(ip); + WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_) if (peer == NULL) goto getout; if ((ip >= (224 << 24) || ip == peer->ipv4_broadcast_addr_) && !peer->allow_multicast_through_peer_) @@ -346,7 +385,9 @@ void WireguardProcessor::HandleTunPacket(Packet *packet) { if (data[6] == kIpProto_ICMPv6 && HandleIcmpv6NeighborSolicitation(data, data_size)) goto getout; + WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); peer = (WgPeer*)dev_.ip_to_peer_map().LookupV6(data + 24); + WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_) if (peer == NULL) goto getout; @@ -359,10 +400,10 @@ void WireguardProcessor::HandleTunPacket(Packet *packet) { } if (size_from_header > data_size) goto getout; - if (peer->endpoint_.sin.sin_family == 0) - goto getout; - WritePacketToUdp(peer, packet); + // WriteAndEncryptPacketToUdp needs a held lock + WG_ACQUIRE_LOCK(peer->mutex_); + WriteAndEncryptPacketToUdp_WillUnlock(peer, packet); return; getout: @@ -370,25 +411,52 @@ getout: FreePacket(packet); } -void WireguardProcessor::WritePacketToUdp(WgPeer *peer, Packet *packet) { - byte *data = packet->data; - size_t size = packet->size; +void WgPeer::AddPacketToPeerQueue(Packet *packet) { + assert(IsPeerLocked()); + // Keep only the first MAX_QUEUED_PACKETS packets. + while (num_queued_packets_ >= MAX_QUEUED_PACKETS_PER_PEER) { + Packet *packet = first_queued_packet_; + first_queued_packet_ = packet->next; + num_queued_packets_--; + FreePacket(packet); + } + // Add the packet to the out queue that will get sent once handshake completes + *last_queued_packet_ptr_ = packet; + last_queued_packet_ptr_ = &packet->next; + packet->next = NULL; + num_queued_packets_++; +} + +// This function must be called with the peer lock held. It will remove the lock +void WireguardProcessor::WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet) { + assert(peer->IsPeerLocked()); + uint8 *data = packet->data, *ad; + size_t size = packet->size, ad_len, orig_size = size; bool want_handshake; + WgKeypair *keypair; uint64 send_ctr; - WgKeypair *keypair = peer->curr_keypair_; - - if (keypair == NULL || - keypair->send_key_state == WgKeypair::KEY_INVALID || - keypair->send_ctr >= REJECT_AFTER_MESSAGES) - goto getout_handshake; - - want_handshake = (keypair->send_ctr >= REKEY_AFTER_MESSAGES || - keypair->send_key_state == WgKeypair::KEY_WANT_REFRESH); // Ensure packet will fit including the biggest padding - if (size > kPacketCapacity - 15 - CHACHA20POLY1305_AUTHTAGLEN) + if (peer->endpoint_.sin.sin_family == 0 || + size > kPacketCapacity - 15 - CHACHA20POLY1305_AUTHTAGLEN) goto getout_discard; + if ((keypair = peer->curr_keypair_) == NULL || + (send_ctr = keypair->send_ctr) >= REJECT_AFTER_MESSAGES) { + peer->AddPacketToPeerQueue(packet); + WG_RELEASE_LOCK(peer->mutex_); + ScheduleNewHandshake(peer); + return; + } + + stats_.tun_bytes_in += size; + stats_.tun_packets_in++; + + want_handshake = (send_ctr >= REKEY_AFTER_MESSAGES || + keypair->send_key_state == WgKeypair::KEY_WANT_REFRESH); + keypair->send_ctr = send_ctr + 1; + packet->addr = peer->endpoint_; + if (size == 0) { peer->OnKeepaliveSent(); } else { @@ -416,7 +484,6 @@ add_padding: size += padding; } } - send_ctr = keypair->send_ctr++; #if WITH_SHORT_HEADERS if (keypair->enabled_features[WG_FEATURE_ID_SHORT_HEADER]) { @@ -434,8 +501,9 @@ add_padding: WriteLE32(write -= 4, (uint32)next_expected_packet); inner_tag = WG_ACK_HEADER_COUNTER_4; } else { - WriteLE64(write -= 8, next_expected_packet); - inner_tag = WG_ACK_HEADER_COUNTER_8; + WriteLE32(write -= 4, (uint32)next_expected_packet); + WriteLE16(write -= 2, (uint16)(next_expected_packet>>32)); + inner_tag = WG_ACK_HEADER_COUNTER_6; } if (keypair->broadcast_short_key != 0) { inner_tag += keypair->addr_entry_slot; @@ -448,6 +516,7 @@ add_padding: *--write = keypair->addr_entry_slot; tag += WG_SHORT_HEADER_ACK; } + byte *write_after_ack_header = write; // Determine the distance from the most recently acked packet, // be conservative when picking a suitable packet length to send. @@ -471,61 +540,54 @@ add_padding: WriteLE32(write -= 4, keypair->remote_key_id); *--write = tag; + // Not using any fields from now on + WG_RELEASE_LOCK(peer->mutex_); header_size = data - write; - stats_.compression_wg_saved_out += (int64)16 - header_size; - packet->data = data - header_size; packet->size = (int)(size + header_size + keypair->auth_tag_length); - WgKeypairEncryptPayload(data, size, write, data - write, send_ctr, keypair); + + // todo: figure out what to actually use as ad. + ad = write_after_ack_header; + ad_len = data - write_after_ack_header; } else { need_big_packet: #else { #endif // #if WITH_SHORT_HEADERS + // Not using any fields from now on + WG_RELEASE_LOCK(peer->mutex_); + ((MessageData*)data)[-1].type = ToLE32(MESSAGE_DATA); ((MessageData*)data)[-1].receiver_id = keypair->remote_key_id; ((MessageData*)data)[-1].counter = ToLE64(send_ctr); packet->data = data - sizeof(MessageData); packet->size = (int)(size + sizeof(MessageData) + keypair->auth_tag_length); - WgKeypairEncryptPayload(data, size, NULL, 0, send_ctr, keypair); + ad = NULL; + ad_len = 0; } - packet->addr = peer->endpoint_; + WgKeypairEncryptPayload(data, size, ad, ad_len, send_ctr, keypair); + DoWriteUdpPacket(packet); if (want_handshake) - SendHandshakeInitiationAndResetRetries(peer); + ScheduleNewHandshake(peer); return; getout_discard: + WG_RELEASE_LOCK(peer->mutex_); FreePacket(packet); return; - -getout_handshake: - // Keep only the first MAX_QUEUED_PACKETS packets. - while (peer->num_queued_packets_ >= MAX_QUEUED_PACKETS_PER_PEER) { - Packet *packet = peer->first_queued_packet_; - peer->first_queued_packet_ = packet->next; - peer->num_queued_packets_--; - FreePacket(packet); - } - // Add the packet to the out queue that will get sent once handshake completes - *peer->last_queued_packet_ptr_ = packet; - peer->last_queued_packet_ptr_ = &packet->next; - packet->next = NULL; - peer->num_queued_packets_++; - - SendHandshakeInitiationAndResetRetries(peer); } // This scrambles the initial 16 bytes of the packet with the -// trailing 8 bytes of the packet. +// next 8 bytes of the packet as a seed. static void ScrambleUnscramblePacket(Packet *packet, ScramblerSiphashKeys *keys) { uint8 *data = packet->data; size_t data_size = packet->size; - if (data_size < 8) + if (data_size <= 8) return; uint64 last_uint64 = ReadLE64(data_size >= 24 ? data + 16 : data + data_size - 8); @@ -537,10 +599,12 @@ static void ScrambleUnscramblePacket(Packet *packet, ScramblerSiphashKeys *keys) ((uint64*)data)[0] ^= a; ((uint64*)data)[1] ^= b; } else { - struct { uint64 a, b; } scramblers = {a, b}; - uint8 *s = (uint8*)&scramblers; + union { + uint64 d[2]; + uint8 s[16]; + } scrambler = {{a,b}}; for (size_t i = 0; i < data_size - 8; i++) - data[i] ^= s[i]; + data[i] ^= scrambler.s[i]; } } @@ -560,38 +624,81 @@ void WireguardProcessor::DoWriteUdpPacket(Packet *packet) { ScrambleUnscrambleAndWrite(packet, &dev_.header_obfuscation_key_, udp_); } -void WireguardProcessor::SendHandshakeInitiationAndResetRetries(WgPeer *peer) { - peer->handshake_attempts_ = 0; - SendHandshakeInitiation(peer); +void WireguardProcessor::ScheduleNewHandshake(WgPeer *peer) { + if (peer->main_thread_scheduled_.fetch_or(WgPeer::kMainThreadScheduled_ScheduleHandshake) == 0) { + peer->main_thread_scheduled_next_ = NULL; + WG_ACQUIRE_LOCK(main_thread_scheduled_lock_); + *main_thread_scheduled_last_ = peer; + main_thread_scheduled_last_ = &peer->main_thread_scheduled_next_; + WG_RELEASE_LOCK(main_thread_scheduled_lock_); + // todo: in multithreaded impl need to trigger |RunAllMainThreadScheduled| to get called + } +} + +void WireguardProcessor::RunAllMainThreadScheduled() { + assert(dev_.IsMainThread()); + + if (main_thread_scheduled_ == NULL) + return; + + WG_ACQUIRE_LOCK(main_thread_scheduled_lock_); + WgPeer *peer = main_thread_scheduled_; + main_thread_scheduled_ = NULL; + main_thread_scheduled_last_ = &main_thread_scheduled_; + WG_RELEASE_LOCK(main_thread_scheduled_lock_); + + while (peer) { + // todo: for the multithreaded use case figure out whether to use atomic_thread_fence here. + WgPeer *next = peer->main_thread_scheduled_next_; + uint32 ev = peer->main_thread_scheduled_.exchange(0); + if (ev & WgPeer::kMainThreadScheduled_ScheduleHandshake) { + peer->handshake_attempts_ = 0; + SendHandshakeInitiation(peer); + } + peer = next; + } } void WireguardProcessor::SendHandshakeInitiation(WgPeer *peer) { - // Send out a handshake init packet to trigger the handshake procedure + assert(dev_.IsMainThread()); + if (!peer->CheckHandshakeRateLimit()) return; + stats_.handshakes_out++; Packet *packet = AllocPacket(); - if (!packet) - return; - peer->CreateMessageHandshakeInitiation(packet); + if (packet) { + peer->CreateMessageHandshakeInitiation(packet); + WG_ACQUIRE_LOCK(peer->mutex_); + int attempts = ++peer->total_handshake_attempts_; + if (procdel_) + procdel_->OnConnectionRetry(attempts); + peer->OnHandshakeInitSent(); + packet->addr = peer->endpoint_; + WG_RELEASE_LOCK(peer->mutex_); + DoWriteUdpPacket(packet); + if (attempts > 1 && attempts <= 20) + RINFO("Retrying handshake, attempt %d...%s", attempts, (attempts == 20) ? " (last notice)" : ""); + } +} - packet->addr = peer->endpoint_; - DoWriteUdpPacket(packet); - peer->OnHandshakeInitSent(); +bool WireguardProcessor::IsMainThreadPacket(Packet *packet) { + // TODO(ludde): Support header obfuscation + return packet->size == 0 || (packet->data[0] != MESSAGE_DATA && !(packet->data[0] & WG_SHORT_HEADER_BIT)); } // Handles an incoming WireGuard packet from the UDP side, decrypt etc. void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) { uint32 type; - stats_.udp_bytes_in += packet->size; - stats_.udp_packets_in++; - // Unscramble incoming packets #if WITH_HEADER_OBFUSCATION if (dev_.header_obfuscation_) ScrambleUnscramblePacket(packet, &dev_.header_obfuscation_key_); #endif // WITH_HEADER_OBFUSCATION + stats_.udp_bytes_in += packet->size; + stats_.udp_packets_in++; + if (packet->size < sizeof(uint32)) goto invalid_size; type = ReadLE32((uint32*)packet->data); @@ -604,22 +711,23 @@ void WireguardProcessor::HandleUdpPacket(Packet *packet, bool overload) { HandleShortHeaderFormatPacket(type, packet); #endif // WITH_SHORT_HEADERS } else if (type == MESSAGE_HANDSHAKE_COOKIE) { + assert(dev_.IsMainThread()); if (packet->size != sizeof(MessageHandshakeCookie)) goto invalid_size; HandleHandshakeCookiePacket(packet); } else if (type == MESSAGE_HANDSHAKE_INITIATION) { + assert(dev_.IsMainThread()); if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeInitiation)) : (packet->size != sizeof(MessageHandshakeInitiation))) goto invalid_size; - - if (!CheckIncomingHandshakeRateLimit(packet, overload)) - return; - HandleHandshakeInitiationPacket(packet); + stats_.handshakes_in++; + if (CheckIncomingHandshakeRateLimit(packet, overload)) + HandleHandshakeInitiationPacket(packet); } else if (type == MESSAGE_HANDSHAKE_RESPONSE) { + assert(dev_.IsMainThread()); if (WITH_HANDSHAKE_EXT ? (packet->size < sizeof(MessageHandshakeResponse)) : (packet->size != sizeof(MessageHandshakeResponse))) goto invalid_size; - if (!CheckIncomingHandshakeRateLimit(packet, overload)) - return; - HandleHandshakeResponsePacket(packet); + if (CheckIncomingHandshakeRateLimit(packet, overload)) + HandleHandshakeResponsePacket(packet); } else { // unknown packet invalid_size: @@ -628,7 +736,7 @@ invalid_size: } // Returns nonzero if two endpoints are different. -static uint32 CompareEndpoint(const IpAddr *a, const IpAddr *b) { +static uint32 CompareIpAddr(const IpAddr *a, const IpAddr *b) { uint32 rv = b->sin.sin_family ^ a->sin.sin_family; if (b->sin.sin_family != AF_INET6) { rv |= b->sin.sin_addr.s_addr ^ a->sin.sin_addr.s_addr; @@ -642,9 +750,10 @@ static uint32 CompareEndpoint(const IpAddr *a, const IpAddr *b) { return rv; } -void WgPeer::CopyEndpointToPeer(WgKeypair *keypair, const IpAddr *addr) { +void WgPeer::CopyEndpointToPeer_Locked(WgKeypair *keypair, const IpAddr *addr) { // Remember how to send packets to this peer - if (CompareEndpoint(&keypair->peer->endpoint_, addr)) { + if (keypair->peer->allow_endpoint_change_ && + CompareIpAddr(&keypair->peer->endpoint_, addr)) { #if WITH_SHORT_HEADERS // When the endpoint changes, forget about using the short key. keypair->broadcast_short_key = 0; @@ -660,28 +769,21 @@ void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packe size_t bytes_left = packet->size - 1; WgKeypair *keypair; uint64 counter, acked_counter; - uint8 ack_tag; + uint8 ack_tag, *ack_start; if ((tag & WG_SHORT_HEADER_KEY_ID_MASK) == 0x00) { // The key_id is explicitly included in the packet. if (bytes_left < 4) goto getout; uint32 key_id = ReadLE32(data); data += 4, bytes_left -= 4; - auto it = dev_.key_id_lookup().find(key_id); - if (it == dev_.key_id_lookup().end()) goto getout; - keypair = it->second.second; + keypair = dev_.LookupKeypairByKeyId(key_id); } else { // Lookup the packet source ip and port in the address mapping uint64 addr_id = packet->addr.sin.sin_addr.s_addr | ((uint64)packet->addr.sin.sin_port << 32); - auto it = dev_.addr_entry_map().find(addr_id); - if (it == dev_.addr_entry_map().end()) - goto getout; - WgAddrEntry *addr_entry = it->second; - keypair = addr_entry->keys[((tag / WG_SHORT_HEADER_KEY_ID) & 3) - 1]; + keypair = dev_.LookupKeypairInAddrEntryMap(addr_id, ((tag / WG_SHORT_HEADER_KEY_ID) & 3) - 1); } - if (!keypair || keypair->recv_key_state == WgKeypair::KEY_INVALID || - !keypair->enabled_features[WG_FEATURE_ID_SHORT_HEADER]) + if (!keypair || !keypair->enabled_features[WG_FEATURE_ID_SHORT_HEADER]) goto getout; // Pick the closest possible counter value with the same low bits. @@ -709,11 +811,13 @@ void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packe acked_counter = 0; ack_tag = 0; + ack_start = data; // If the acknowledge header is present, then parse it so we may // get an ack for the highest seen packet. if (tag & WG_SHORT_HEADER_ACK) { if (bytes_left == 0) goto getout; ack_tag = *data; + if (ack_tag & 0xF0) goto getout; // undefined bits data += 1, bytes_left -= 1; switch (ack_tag & WG_ACK_HEADER_COUNTER_MASK) { @@ -727,83 +831,104 @@ void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packe acked_counter = ReadLE32(data); data += 4, bytes_left -= 4; break; - case WG_ACK_HEADER_COUNTER_8: - if (bytes_left < 8) goto getout; - acked_counter = ReadLE64(data); - data += 8, bytes_left -= 8; + case WG_ACK_HEADER_COUNTER_6: + if (bytes_left < 6) goto getout; + acked_counter = ReadLE32(data) | ((uint64)ReadLE16(data + 4) << 32); + data += 6, bytes_left -= 6; break; default: - break; + goto getout; } } if (counter >= REJECT_AFTER_MESSAGES) goto getout; // Authenticate the packet before we can apply the state changes. - if (!WgKeypairDecryptPayload(data, bytes_left, packet->data, data - packet->data, counter, keypair)) + if (!WgKeypairDecryptPayload(data, bytes_left, ack_start, data - ack_start, counter, keypair)) goto getout; + WG_ACQUIRE_LOCK(keypair->peer->mutex_); + + if (keypair->recv_key_state == WgKeypair::KEY_INVALID) + goto getout_unlock; + if (!keypair->replay_detector.CheckReplay(counter)) - goto getout; + goto getout_unlock; stats_.compression_wg_saved_in += 16 - (data - packet->data); keypair->send_ctr_acked = std::max(keypair->send_ctr_acked, acked_counter); keypair->incoming_packet_count++; - WgPeer::CopyEndpointToPeer(keypair, &packet->addr); + WgPeer::CopyEndpointToPeer_Locked(keypair, &packet->addr); // Periodically broadcast out the short key if ((tag & WG_SHORT_HEADER_KEY_ID_MASK) == 0x00 && !keypair->did_attempt_remember_ip_port) { keypair->did_attempt_remember_ip_port = true; if (keypair->enabled_features[WG_FEATURE_ID_SKIP_KEYID_IN]) { uint64 addr_id = packet->addr.sin.sin_addr.s_addr | ((uint64)packet->addr.sin.sin_port << 32); - dev_.UpdateKeypairAddrEntry(addr_id, keypair); + dev_.UpdateKeypairAddrEntry_Locked(addr_id, keypair); } } - // Ack header may also signal that we can omit the key id in packets from now on. if (tag & WG_SHORT_HEADER_ACK) keypair->can_use_short_key_for_outgoing = (ack_tag & WG_ACK_HEADER_KEY_MASK) * WG_SHORT_HEADER_KEY_ID; - HandleAuthenticatedDataPacket(keypair, packet, data, bytes_left - keypair->auth_tag_length); + HandleAuthenticatedDataPacket_WillUnlock(keypair, packet, data, bytes_left - keypair->auth_tag_length); return; +getout_unlock: + WG_RELEASE_LOCK(keypair->peer->mutex_); getout: FreePacket(packet); return; } #endif // WITH_SHORT_HEADERS -void WireguardProcessor::HandleAuthenticatedDataPacket(WgKeypair *keypair, Packet *packet, uint8 *data, size_t data_size) { +void WireguardProcessor::NotifyHandshakeComplete() { + uint64 now = OsGetMilliseconds(); + + // todo: should lock something + stats_.last_complete_handshake_timestamp = now; + if (stats_.first_complete_handshake_timestamp == 0) + stats_.first_complete_handshake_timestamp = now; + + if (procdel_) + procdel_->OnConnected(); +} + +void WireguardProcessor::HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet, uint8 *data, size_t data_size) { WgPeer *peer = keypair->peer; + assert(peer->IsPeerLocked()); // Promote the next key to the current key when we receive a data packet, // the handshake is now complete. - if (peer->CheckSwitchToNextKey(keypair)) { - if (procdel_) { - procdel_->OnConnected(ReadBE32(tun_addr_.addr)); - } + if (peer->CheckSwitchToNextKey_Locked(keypair)) { + stats_.handshakes_in_success++; peer->OnHandshakeFullyComplete(); - SendQueuedPackets(peer); + NotifyHandshakeComplete(); + SendQueuedPackets_Locked(peer); } // Refresh when current key gets too old - if (peer->curr_keypair_ && peer->curr_keypair_->recv_key_state == WgKeypair::KEY_WANT_REFRESH) { - peer->curr_keypair_->recv_key_state = WgKeypair::KEY_DID_REFRESH; - SendHandshakeInitiationAndResetRetries(peer); + WgKeypair *curr_keypair = peer->curr_keypair_; + if (curr_keypair && curr_keypair->recv_key_state == WgKeypair::KEY_WANT_REFRESH) { + curr_keypair->recv_key_state = WgKeypair::KEY_DID_REFRESH; + ScheduleNewHandshake(peer); } if (data_size == 0) { peer->OnKeepaliveReceived(); + WG_RELEASE_LOCK(peer->mutex_); goto getout; } peer->OnDataReceived(); + WG_RELEASE_LOCK(peer->mutex_); #if WITH_HANDSHAKE_EXT // Unpack the packet headers using ipzip if (keypair->enabled_features[WG_FEATURE_ID_IPZIP]) { uint32 rv = IpzipDecompress(data, (uint32)data_size, &keypair->ipzip_state_, IPZIP_RECV_BY_CLIENT); if (rv == (uint32)-1) - goto getout; // ipzip failed decompress + goto getout; stats_.compression_hdr_saved_in += (int64)rv - data_size; data -= (int64)rv - data_size, data_size = rv; } @@ -816,36 +941,30 @@ void WireguardProcessor::HandleAuthenticatedDataPacket(WgKeypair *keypair, Packe ip_version = *data >> 4; if (ip_version == 4) { - if (data_size < IPV4_HEADER_SIZE) { - // too small ipv4 header - goto getout; - } + if (data_size < IPV4_HEADER_SIZE) + goto getout_error_header; + WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV4(ReadBE32(data + 12)); + WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_) size_from_header = ReadBE16(data + 2); if (size_from_header < IPV4_HEADER_SIZE) { // too small packet? - goto getout; + goto getout_error_header; } } else if (ip_version == 6) { - if (data_size < IPV6_HEADER_SIZE) { - // too small ipv6 header - goto getout; - } + if (data_size < IPV6_HEADER_SIZE) + goto getout_error_header; + WG_ACQUIRE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_); peer_from_header = (WgPeer*)dev_.ip_to_peer_map().LookupV6(data + 8); + WG_RELEASE_RWLOCK_SHARED(dev_.ip_to_peer_map_lock_) size_from_header = IPV6_HEADER_SIZE + ReadBE16(data + 4); } else { // invalid ip version - goto getout; + goto getout_error_header; } - if (size_from_header > data_size) { - // oversized packet? - goto getout; - } - if (peer_from_header != peer) { - // source address mismatch? - goto getout; - } - //RINFO("Outgoing TUN packet of size %d", (int)size_from_header); + if (peer_from_header != peer || size_from_header > data_size) + goto getout_error_header; + packet->data = data; packet->size = size_from_header; @@ -855,9 +974,10 @@ void WireguardProcessor::HandleAuthenticatedDataPacket(WgKeypair *keypair, Packe tun_->WriteTunPacket(packet); return; +getout_error_header: + stats_.error_header++; getout: FreePacket(packet); - return; } void WireguardProcessor::HandleDataPacket(Packet *packet) { @@ -865,29 +985,33 @@ void WireguardProcessor::HandleDataPacket(Packet *packet) { size_t data_size = packet->size; uint32 key_id = ((MessageData*)data)->receiver_id; uint64 counter = ToLE64((((MessageData*)data)->counter)); - WgKeypair *keypair; - - auto it = dev_.key_id_lookup().find(key_id); - if (it == dev_.key_id_lookup().end() || - (keypair = it->second.second) == NULL || - keypair->recv_key_state == WgKeypair::KEY_INVALID) { + WgKeypair *keypair = dev_.LookupKeypairByKeyId(key_id); + if (keypair == NULL || counter >= REJECT_AFTER_MESSAGES) { + stats_.error_key_id++; getout: FreePacket(packet); return; } - if (counter >= REJECT_AFTER_MESSAGES) - goto getout; - if (!WgKeypairDecryptPayload(data + sizeof(MessageData), data_size - sizeof(MessageData), - NULL, 0, counter, keypair)) { + NULL, 0, counter, keypair)) { + stats_.error_mac++; goto getout; } - if (!keypair->replay_detector.CheckReplay(counter)) - goto getout; - WgPeer::CopyEndpointToPeer(keypair, &packet->addr); - HandleAuthenticatedDataPacket(keypair, packet, data + sizeof(MessageData), data_size - sizeof(MessageData) - keypair->auth_tag_length); + WG_ACQUIRE_LOCK(keypair->peer->mutex_); + if (keypair->recv_key_state == WgKeypair::KEY_INVALID) { + stats_.error_key_id++; + WG_RELEASE_LOCK(keypair->peer->mutex_); + goto getout; + } else if (!keypair->replay_detector.CheckReplay(counter)) { + stats_.error_duplicate++; + WG_RELEASE_LOCK(keypair->peer->mutex_); + goto getout; + } else { + WgPeer::CopyEndpointToPeer_Locked(keypair, &packet->addr); + HandleAuthenticatedDataPacket_WillUnlock(keypair, packet, data + sizeof(MessageData), data_size - sizeof(MessageData) - keypair->auth_tag_length); + } } static uint64 GetIpForRateLimit(Packet *packet) { @@ -899,54 +1023,55 @@ static uint64 GetIpForRateLimit(Packet *packet) { } bool WireguardProcessor::CheckIncomingHandshakeRateLimit(Packet *packet, bool overload) { + assert(dev_.IsMainThread()); WgRateLimit::RateLimitResult rr = dev_.rate_limiter()->CheckRateLimit(GetIpForRateLimit(packet)); if ((overload && rr.is_rate_limited()) || !dev_.CheckCookieMac1(packet)) { FreePacket(packet); return false; } + dev_.rate_limiter()->CommitResult(rr); if (overload && !rr.is_first_ip() && !dev_.CheckCookieMac2(packet)) { - dev_.rate_limiter()->CommitResult(rr); dev_.CreateCookieMessage((MessageHandshakeCookie*)packet->data, packet, ((MessageHandshakeInitiation*)packet->data)->sender_key_id); packet->size = sizeof(MessageHandshakeCookie); DoWriteUdpPacket(packet); return false; } - dev_.rate_limiter()->CommitResult(rr); return true; } // server receives this when client wants to setup a session void WireguardProcessor::HandleHandshakeInitiationPacket(Packet *packet) { + assert(dev_.IsMainThread()); WgPeer *peer = WgPeer::ParseMessageHandshakeInitiation(&dev_, packet); - if (!peer) { + if (peer) { + DoWriteUdpPacket(packet); + } else { FreePacket(packet); - return; } - peer->OnHandshakeAuthComplete(); - DoWriteUdpPacket(packet); } // client receives this after session is established void WireguardProcessor::HandleHandshakeResponsePacket(Packet *packet) { + assert(dev_.IsMainThread()); WgPeer *peer = WgPeer::ParseMessageHandshakeResponse(&dev_, packet); - if (!peer) { - FreePacket(packet); - return; + if (peer) { + stats_.handshakes_out_success++; + WG_SCOPED_LOCK(peer->mutex_); + if (peer->allow_endpoint_change_) + peer->endpoint_ = packet->addr; + peer->OnHandshakeAuthComplete(); + peer->OnHandshakeFullyComplete(); + NotifyHandshakeComplete(); + SendKeepalive_Locked(peer); } - peer->endpoint_ = packet->addr; FreePacket(packet); - peer->OnHandshakeAuthComplete(); - peer->OnHandshakeFullyComplete(); - if (procdel_) - procdel_->OnConnected(ReadBE32(tun_addr_.addr)); - SendKeepalive(peer); } -void WireguardProcessor::SendKeepalive(WgPeer *peer) { +void WireguardProcessor::SendKeepalive_Locked(WgPeer *peer) { + assert(dev_.IsMainThread() && peer->IsPeerLocked()); // can't send keepalive if no endpoint is configured if (peer->endpoint_.sin.sin_family == 0) return; - // If nothing is queued, insert a keepalive packet if (peer->first_queued_packet_ == NULL) { Packet *packet = AllocPacket(); @@ -956,43 +1081,70 @@ void WireguardProcessor::SendKeepalive(WgPeer *peer) { packet->next = NULL; peer->first_queued_packet_ = packet; } - SendQueuedPackets(peer); + SendQueuedPackets_Locked(peer); } -void WireguardProcessor::SendQueuedPackets(WgPeer *peer) { - // Steal the packets +void WireguardProcessor::SendQueuedPackets_Locked(WgPeer *peer) { + assert(peer->IsPeerLocked()); + // Steal the queue of all packets and send them all. Packet *packet = peer->first_queued_packet_; peer->first_queued_packet_ = NULL; peer->last_queued_packet_ptr_ = &peer->first_queued_packet_; peer->num_queued_packets_ = 0; - while (packet) { + while (packet != NULL) { Packet *next = packet->next; - WritePacketToUdp(peer, packet); + WriteAndEncryptPacketToUdp_WillUnlock(peer, packet); packet = next; + WG_ACQUIRE_LOCK(peer->mutex_); // WriteAndEncryptPacketToUdp_WillUnlock releases the lock } } void WireguardProcessor::HandleHandshakeCookiePacket(Packet *packet) { + assert(dev_.IsMainThread()); WgPeer::ParseMessageHandshakeCookie(&dev_, (MessageHandshakeCookie *)packet->data); } +// Only one thread may run the second loop void WireguardProcessor::SecondLoop() { + assert(dev_.IsMainThread()); uint64 now = OsGetMilliseconds(); + + uint64 bytes_in = stats_.tun_bytes_in - stats_last_bytes_in_; + uint64 bytes_out = stats_.tun_bytes_out - stats_last_bytes_out_; + + stats_last_bytes_in_ = stats_.tun_bytes_in; + stats_last_bytes_out_ = stats_.tun_bytes_out; + + uint64 millis = now - stats_last_ts_; + stats_last_ts_ = now; + + double f = 1000.0 / std::max((uint32)millis, 500); + + stats_.tun_bytes_in_per_second = (float)(bytes_in * f); + stats_.tun_bytes_out_per_second = (float)(bytes_out * f); + for (WgPeer *peer = dev_.first_peer(); peer; peer = peer->next_peer_) { + WgKeypair *keypair = peer->curr_keypair_; // Allow ip/port to be remembered again for this keypair - if (peer->curr_keypair_) - peer->curr_keypair_->did_attempt_remember_ip_port = false; + if (keypair) + keypair->did_attempt_remember_ip_port = false; - uint32 mask = peer->CheckTimeouts(now); - if (mask == 0) - continue; - if (mask & WgPeer::ACTION_SEND_KEEPALIVE) - SendKeepalive(peer); - if (mask & WgPeer::ACTION_SEND_HANDSHAKE) - SendHandshakeInitiation(peer); + // Avoid taking the lock if it seems unneccessary + if (now >= peer->time_of_next_key_event_ || peer->timers_ != 0) { + uint32 mask; + { + WG_SCOPED_LOCK(peer->mutex_); + mask = peer->CheckTimeouts(now); + if (mask == 0) + continue; + if (mask & WgPeer::ACTION_SEND_KEEPALIVE) + SendKeepalive_Locked(peer); + } + if (mask & WgPeer::ACTION_SEND_HANDSHAKE) + SendHandshakeInitiation(peer); + } } dev_.SecondLoop(now); } - diff --git a/wireguard.h b/wireguard.h index ef050c5..75a10a6 100644 --- a/wireguard.h +++ b/wireguard.h @@ -5,24 +5,50 @@ #include "tunsafe_types.h" #include "wireguard_proto.h" -struct ProcessorStats { - // Number of bytes sent/received over the physical UDP connections - int64 udp_bytes_in, udp_bytes_out; - int64 udp_packets_in, udp_packets_out; - // Number of bytes sent/received over the TUN interface - int64 tun_bytes_in, tun_bytes_out; - int64 tun_packets_in, tun_packets_out; - uint64 last_complete_handskake_timestamp; +// todo: for multithreaded use case need to use atomic ops. +struct WgProcessorStats { + // Number of bytes sent/received over the physical UDP connection + uint64 udp_bytes_in, udp_bytes_out; + uint64 udp_packets_in, udp_packets_out; + // Number of valid packets sent/received over the TUN interface + uint64 tun_bytes_in, tun_bytes_out; + uint64 tun_packets_in, tun_packets_out; + + // Error types + uint32 error_key_id; + uint32 error_mac; + uint32 error_duplicate; + uint32 error_source_addr; + uint32 error_header; + + // Current speed of TUN packets + float tun_bytes_in_per_second, tun_bytes_out_per_second; + + // Timestamp of handshakes + uint64 first_complete_handshake_timestamp; + uint64 last_complete_handshake_timestamp; + + // How much saved from header compression int64 compression_hdr_saved_in, compression_hdr_saved_out; - int64 compression_wg_saved_in, compression_wg_saved_out; + + // Number of handshakes received and sent + // Number of successful handshakes in and out + uint32 handshakes_in, handshakes_out; + uint32 handshakes_in_success, handshakes_out_success; + + // Key stuff + uint8 public_key[32]; + + // Address of the endpoint + IpAddr endpoint; }; class ProcessorDelegate { public: - virtual void OnConnected(in_addr_t my_ip) = 0; - virtual void OnDisconnected() = 0; + virtual void OnConnected() = 0; + virtual void OnConnectionRetry(uint32 attempts) = 0; }; enum InternetBlockState { @@ -42,62 +68,46 @@ public: WireguardProcessor(UdpInterface *udp, TunInterface *tun, ProcessorDelegate *procdel); ~WireguardProcessor(); - void SetListenPort(int listen_port) { - listen_port_ = listen_port; - } - - bool SetTunAddress(const WgCidrAddr &addr); - + void SetListenPort(int listen_port); bool AddDnsServer(const IpAddr &sin); - - void SetMtu(int mtu) { - if (mtu >= 576 && mtu <= 10000) - mtu_ = mtu; - } - - void SetAddRoutesMode(bool mode) { - add_routes_mode_ = mode; - } - - void SetDnsBlocking(bool dns_blocking) { - dns_blocking_ = dns_blocking; - } - - void SetInternetBlocking(InternetBlockState internet_blocking) { - internet_blocking_ = internet_blocking; - } - - void SetHeaderObfuscation(const char *key) { - dev_.SetHeaderObfuscation(key); - } + bool SetTunAddress(const WgCidrAddr &addr); + void AddExcludedIp(const WgCidrAddr &cidr_addr); + void SetMtu(int mtu); + void SetAddRoutesMode(bool mode); + void SetDnsBlocking(bool dns_blocking); + void SetInternetBlocking(InternetBlockState internet_blocking); + void SetHeaderObfuscation(const char *key); void HandleTunPacket(Packet *packet); void HandleUdpPacket(Packet *packet, bool overload); + static bool IsMainThreadPacket(Packet *packet); + void SecondLoop(); - ProcessorStats GetStats(); + WgProcessorStats GetStats(); void ResetStats(); bool Start(); WgDevice &dev() { return dev_; } - TunInterface::PrePostCommands &prepost() { return pre_post_; } + const WgCidrAddr &tun_addr() { return tun_addr_; } + void RunAllMainThreadScheduled(); private: void DoWriteUdpPacket(Packet *packet); - void WritePacketToUdp(WgPeer *peer, Packet *packet); + void WriteAndEncryptPacketToUdp_WillUnlock(WgPeer *peer, Packet *packet); void SendHandshakeInitiation(WgPeer *peer); - void SendHandshakeInitiationAndResetRetries(WgPeer *peer); - void SendKeepalive(WgPeer *peer); - void SendQueuedPackets(WgPeer *peer); + void ScheduleNewHandshake(WgPeer *peer); + void SendKeepalive_Locked(WgPeer *peer); + void SendQueuedPackets_Locked(WgPeer *peer); void HandleHandshakeInitiationPacket(Packet *packet); void HandleHandshakeResponsePacket(Packet *packet); void HandleHandshakeCookiePacket(Packet *packet); void HandleDataPacket(Packet *packet); - void HandleAuthenticatedDataPacket(WgKeypair *keypair, Packet *packet, uint8 *data, size_t data_size); + void HandleAuthenticatedDataPacket_WillUnlock(WgKeypair *keypair, Packet *packet, uint8 *data, size_t data_size); void HandleShortHeaderFormatPacket(uint32 tag, Packet *packet); @@ -106,6 +116,7 @@ private: bool HandleIcmpv6NeighborSolicitation(const byte *data, size_t data_size); void SetupCompressionHeader(WgPacketCompressionVer01 *c); + void NotifyHandshakeComplete(); int listen_port_; @@ -113,12 +124,13 @@ private: TunInterface *tun_; UdpInterface *udp_; int mtu_; - ProcessorStats stats_; + WgProcessorStats stats_; bool dns_blocking_; uint8 internet_blocking_; bool add_routes_mode_; bool network_discovery_spoofing_; + bool did_have_first_handshake_; uint8 network_discovery_mac_[6]; WgDevice dev_; @@ -129,5 +141,15 @@ private: IpAddr dns_addr_, dns6_addr_; TunInterface::PrePostCommands pre_post_; + + // Queue of things scheduled to run on the main thread. + WG_DECLARE_LOCK(main_thread_scheduled_lock_); + WgPeer *main_thread_scheduled_, **main_thread_scheduled_last_; + + uint64 stats_last_bytes_in_, stats_last_bytes_out_; + uint64 stats_last_ts_; + + // IPs we want to map to the default route + std::vector excluded_ips_; }; diff --git a/wireguard_config.cpp b/wireguard_config.cpp index 3d51f62..dbc67b0 100644 --- a/wireguard_config.cpp +++ b/wireguard_config.cpp @@ -20,6 +20,10 @@ #include #endif +#if defined(OS_WIN) +#include "network_win32_dnsblock.h" +#endif + const char *print_ip_prefix(char buf[kSizeOfAddress], int family, const void *ip, int prefixlen) { if (!inet_ntop(family, ip, buf, kSizeOfAddress - 8)) { memcpy(buf, "unknown", 8); @@ -29,6 +33,17 @@ const char *print_ip_prefix(char buf[kSizeOfAddress], int family, const void *ip return buf; } +char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]) { + if (addr.sin.sin_family == AF_INET) { + print_ip_prefix(buf, addr.sin.sin_family, &addr.sin.sin_addr, -1); + } else if (addr.sin.sin_family == AF_INET) { + print_ip_prefix(buf, addr.sin.sin_family, &addr.sin6.sin6_addr, -1); + } else { + buf[0] = 0; + } + return buf; +} + struct Addr { byte addr[4]; uint8 cidr; @@ -58,19 +73,71 @@ static bool ParseCidrAddr(char *s, WgCidrAddr *out) { return false; } -struct hostent *gethostbyname_retry_on_failure(const char * name, bool *exit_flag) { +DnsResolver::DnsResolver(DnsBlocker *dns_blocker) { + dns_blocker_ = dns_blocker; + abort_flag_ = false; +} + +DnsResolver::~DnsResolver() { +} + +void DnsResolver::ClearCache() { + cache_.clear(); +} + +bool DnsResolver::Resolve(const char *hostname, IpAddr *result) { int attempt = 0; - static const uint8 retry_delays[] = {1, 2, 3, 5, 10, 20, 40, 60}; + static const uint8 retry_delays[] = {1, 2, 3, 5, 10}; + char buf[kSizeOfAddress]; + + memset(result, 0, sizeof(IpAddr)); + if (inet_pton(AF_INET6, hostname, &result->sin6.sin6_addr) == 1) { + result->sin.sin_family = AF_INET6; + return true; + } + + if (inet_pton(AF_INET, hostname, &result->sin.sin_addr) == 1) { + result->sin.sin_family = AF_INET; + return true; + } + + // First check cache + for (auto it = cache_.begin(); it != cache_.end(); ++it) { + if (it->name == hostname) { + + *result = it->ip; + RINFO("Resolved %s to %s%s", hostname, PrintIpAddr(*result, buf), " (cached)"); + return true; + } + } + +#if defined(OS_WIN) + // Then disable dns blocker (otherwise the windows dns client service can't resolve) + if (dns_blocker_ && dns_blocker_->IsActive()) { + RINFO("Disabling DNS blocker to resolve %s", hostname); + dns_blocker_->RestoreDns(); + } +#endif // defined(OS_WIN) for (;;) { - hostent *he = gethostbyname(name); - if (he || exit_flag == NULL || *exit_flag) - return he; + hostent *he = gethostbyname(hostname); + if (abort_flag_) + return false; - RINFO("Unable to resolve %s. Trying again in %d second(s)", name, retry_delays[attempt]); + if (he) { + result->sin.sin_family = AF_INET; + result->sin.sin_port = 0; + memcpy(&result->sin.sin_addr, he->h_addr_list[0], 4); + // add to cache + cache_.emplace_back(hostname, *result); + RINFO("Resolved %s to %s%s", hostname, PrintIpAddr(*result, buf), ""); + return true; + } + + RINFO("Unable to resolve %s. Trying again in %d second(s)", hostname, retry_delays[attempt]); OsInterruptibleSleep(retry_delays[attempt] * 1000); - if (*exit_flag) - return NULL; + if (abort_flag_) + return false; if (attempt != ARRAY_SIZE(retry_delays) - 1) attempt++; @@ -78,7 +145,9 @@ struct hostent *gethostbyname_retry_on_failure(const char * name, bool *exit_fla } -static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, bool *exit_flag) { + + +static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, DnsResolver *resolver) { memset(sin, 0, sizeof(IpAddr)); if (*s == '[') { char *end = strchr(s, ']'); @@ -97,30 +166,20 @@ static bool ParseSockaddrInWithPort(char *s, IpAddr *sin, bool *exit_flag) { char *x = strchr(s, ':'); if (!x) return false; *x = 0; - hostent *he = gethostbyname_retry_on_failure(s, exit_flag); - if (!he) { + + if (!resolver->Resolve(s, sin)) { RERROR("Unable to resolve %s", s); return false; } - sin->sin.sin_family = AF_INET; sin->sin.sin_port = htons(atoi(x + 1)); - memcpy(&sin->sin.sin_addr, he->h_addr_list[0], 4); return true; } -static bool ParseSockaddrInWithoutPort(char *s, IpAddr *sin, bool *exit_flag) { - memset(sin, 0, sizeof(IpAddr)); - if (inet_pton(AF_INET6, s, &sin->sin6.sin6_addr) == 1) { - sin->sin.sin_family = AF_INET6; - return true; - } - hostent *he = gethostbyname_retry_on_failure(s, exit_flag); - if (!he) { +static bool ParseSockaddrInWithoutPort(char *s, IpAddr *sin, DnsResolver *resolver) { + if (!resolver->Resolve(s, sin)) { RERROR("Unable to resolve %s", s); return false; } - sin->sin.sin_family = AF_INET; - memcpy(&sin->sin.sin_addr, he->h_addr_list[0], 4); return true; } @@ -131,7 +190,7 @@ static bool ParseBase64Key(const char *s, uint8 key[32]) { class WgFileParser { public: - WgFileParser(WireguardProcessor *wg, bool *exit_flag) : wg_(wg), exit_flag_(exit_flag) {} + WgFileParser(WireguardProcessor *wg, DnsResolver *resolver) : wg_(wg), dns_resolver_(resolver) {} bool ParseFlag(const char *group, const char *key, char *value); WireguardProcessor *wg_; @@ -142,7 +201,7 @@ public: }; Peer pi_; WgPeer *peer_ = NULL; - bool *exit_flag_; + DnsResolver *dns_resolver_; bool had_interface_ = false; }; @@ -271,7 +330,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) { } else if (strcmp(key, "DNS") == 0) { SplitString(value, ',', &ss); for (size_t i = 0; i < ss.size(); i++) { - if (!ParseSockaddrInWithoutPort(ss[i], &sin, exit_flag_)) + if (!ParseSockaddrInWithoutPort(ss[i], &sin, dns_resolver_)) return false; if (!wg_->AddDnsServer(sin)) { RERROR("Multiple DNS not allowed."); @@ -315,6 +374,13 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) { wg_->prepost().pre_up.emplace_back(value); } else if (strcmp(key, "PreDown") == 0) { wg_->prepost().pre_down.emplace_back(value); + } else if (strcmp(key, "ExcludedIPs") == 0) { + SplitString(value, ',', &ss); + for (size_t i = 0; i < ss.size(); i++) { + if (!ParseCidrAddr(ss[i], &addr)) + return false; + wg_->AddExcludedIp(addr); + } } else { goto err; } @@ -344,7 +410,7 @@ bool WgFileParser::ParseFlag(const char *group, const char *key, char *value) { return false; } } else if (strcmp(key, "Endpoint") == 0) { - if (!ParseSockaddrInWithPort(value, &sin, exit_flag_)) + if (!ParseSockaddrInWithPort(value, &sin, dns_resolver_)) return false; peer_->SetEndpoint(sin); } else if (strcmp(key, "PersistentKeepalive") == 0) { @@ -384,11 +450,20 @@ err: return true; } -bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, bool *exit_flag) { +static bool ContainsNonAsciiCharacter(const char *buf, size_t size) { + for (size_t i = 0; i < size; i++) { + uint8 c = buf[i]; + if (c < 32 && ((1 << c) & (1 << '\n' | 1 << '\r' | 1 << '\t')) == 0) + return true; + } + return false; +} + +bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, DnsResolver *dns_resolver) { char buf[1024]; char group[32] = {0}; - WgFileParser file_parser(wg, exit_flag); + WgFileParser file_parser(wg, dns_resolver); RINFO("Loading file: %s", filename); @@ -400,6 +475,13 @@ bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, bool while (fgets(buf, sizeof(buf), f)) { size_t l = strlen(buf); + + if (ContainsNonAsciiCharacter(buf, l)) { + RERROR("File is not a config file: %s", filename); + return false; + } + + while (l && is_space(buf[l - 1])) buf[--l] = 0; if (buf[0] == '#' || buf[0] == '\0') diff --git a/wireguard_config.h b/wireguard_config.h index 03d7899..01d9678 100644 --- a/wireguard_config.h +++ b/wireguard_config.h @@ -3,13 +3,38 @@ #ifndef TINYVPN_TINYVPN_H_ #define TINYVPN_TINYVPN_H_ -class WireguardProcessor; +#include "netapi.h" -bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, bool *exit_flag); +class WireguardProcessor; +class DnsBlocker; + +class DnsResolver { +public: + explicit DnsResolver(DnsBlocker *dns_blocker); + ~DnsResolver(); + + bool Resolve(const char *hostname, IpAddr *result); + + void ClearCache(); + + void SetAbortFlag(bool v) { abort_flag_ = v; } +private: + struct Entry { + std::string name; + IpAddr ip; + Entry(const std::string &name, const IpAddr &ip) : name(name), ip(ip) {} + }; + std::vector cache_; + bool abort_flag_; + DnsBlocker *dns_blocker_; +}; + + +bool ParseWireGuardConfigFile(WireguardProcessor *wg, const char *filename, DnsResolver *dns_resolver); #define kSizeOfAddress 64 const char *print_ip_prefix(char buf[kSizeOfAddress], int family, const void *ip, int prefixlen); - +char *PrintIpAddr(const IpAddr &addr, char buf[kSizeOfAddress]); #endif // TINYVPN_TINYVPN_H_ diff --git a/wireguard_proto.cpp b/wireguard_proto.cpp index ad20a53..9d6b0ab 100644 --- a/wireguard_proto.cpp +++ b/wireguard_proto.cpp @@ -11,7 +11,7 @@ #include "util.h" #include "crypto_ops.h" #include "bit_ops.h" -#include "tunsafe_cpu.h" +#include "tunsafe_cpu.h" #include #include #include @@ -23,97 +23,6 @@ static const uint8 kWgInitHash[WG_HASH_LEN] = {0x22,0x11,0xb3,0x61,0x08,0x1a,0xc static const uint8 kWgInitChainingKey[WG_HASH_LEN] = {0x60,0xe2,0x6d,0xae,0xf3,0x27,0xef,0xc0,0x2e,0xc3,0x35,0xe2,0xa0,0x25,0xd2,0xd0,0x16,0xeb,0x42,0x06,0xf8,0x72,0x77,0xf5,0x2d,0x38,0xd1,0x98,0x8b,0x78,0xcd,0x36}; static const uint8 kCurve25519Basepoint[32] = {9}; -IpToPeerMap::IpToPeerMap() { - -} - -IpToPeerMap::~IpToPeerMap() { -} - -bool IpToPeerMap::InsertV4(const void *addr, int cidr, void *peer) { - uint32 mask = cidr == 32 ? 0xffffffff : ~(0xffffffff >> cidr); - Entry4 e = {ReadBE32(addr) & mask, mask, peer}; - ipv4_.push_back(e); - return true; -} - -bool IpToPeerMap::InsertV6(const void *addr, int cidr, void *peer) { - Entry6 e; - e.cidr_len = cidr; - e.peer = peer; - memcpy(e.ip, addr, 16); - ipv6_.push_back(e); - return true; -} - -void *IpToPeerMap::LookupV4(uint32 ip) { - uint32 best_mask = 0; - void *best_peer = NULL; - for (auto it = ipv4_.begin(); it != ipv4_.end(); ++it) { - if (it->ip == (ip & it->mask) && it->mask >= best_mask) { - best_mask = it->mask; - best_peer = it->peer; - } - } - return best_peer; -} - -void *IpToPeerMap::LookupV4DefaultPeer() { - for (auto it = ipv4_.begin(); it != ipv4_.end(); ++it) { - if (it->mask == 0) - return it->peer; - } - return NULL; -} - -void *IpToPeerMap::LookupV6DefaultPeer() { - for (auto it = ipv6_.begin(); it != ipv6_.end(); ++it) { - if (it->cidr_len == 0) - return it->peer; - } - return NULL; -} - -static int CalculateIPv6CommonPrefix(const uint8 *a, const uint8 *b) { - uint64 x = ToBE64(*(uint64*)&a[0] ^ *(uint64*)&b[0]); - uint64 y = ToBE64(*(uint64*)&a[8] ^ *(uint64*)&b[8]); - return x ? 64 - FindHighestSetBit64(x) : 128 - FindHighestSetBit64(y); -} - -void *IpToPeerMap::LookupV6(const void *addr) { - int best_len = 0; - void *best_peer = NULL; - for (auto it = ipv6_.begin(); it != ipv6_.end(); ++it) { - int len = CalculateIPv6CommonPrefix((const uint8*)addr, it->ip); - if (len >= it->cidr_len && len >= best_len) { - best_len = len; - best_peer = it->peer; - } - } - return best_peer; -} - -void IpToPeerMap::RemovePeer(void *peer) { - { - size_t n = ipv4_.size(); - Entry4 *r = &ipv4_[0], *w = r; - for (size_t i = 0; i != n; i++, r++) { - if (r->peer != peer) - *w++ = *r; - } - ipv4_.resize(w - &ipv4_[0]); - } - { - size_t n = ipv6_.size(); - Entry6 *r = &ipv6_[0], *w = r; - for (size_t i = 0; i != n; i++, r++) { - if (r->peer != peer) - *w++ = *r; - } - ipv6_.resize(w - &ipv6_[0]); - } -} - ReplayDetector::ReplayDetector() { expected_seq_nr_ = 0; memset(bitmap_, 0, sizeof(bitmap_)); @@ -124,8 +33,9 @@ ReplayDetector::~ReplayDetector() { bool ReplayDetector::CheckReplay(uint64 seq_nr) { uint64 slot = seq_nr / BITS_PER_ENTRY; - if (seq_nr >= expected_seq_nr_) { - uint64 prev_slot = (expected_seq_nr_ + BITS_PER_ENTRY - 1) / BITS_PER_ENTRY - 1, n; + uint64 expected_seq_nr = expected_seq_nr_; + if (seq_nr >= expected_seq_nr) { + uint64 prev_slot = (expected_seq_nr + BITS_PER_ENTRY - 1) / BITS_PER_ENTRY - 1, n; if ((n = slot - prev_slot) != 0) { size_t nn = (size_t)std::min(n, BITMAP_SIZE); do { @@ -133,7 +43,7 @@ bool ReplayDetector::CheckReplay(uint64 seq_nr) { } while (--nn); } expected_seq_nr_ = seq_nr + 1; - } else if (seq_nr + WINDOW_SIZE <= expected_seq_nr_) { + } else if (seq_nr + WINDOW_SIZE <= expected_seq_nr) { return false; } uint32 mask = 1 << (seq_nr & (BITS_PER_ENTRY - 1)), prev; @@ -146,21 +56,21 @@ WgDevice::WgDevice() { peers_ = NULL; header_obfuscation_ = false; next_rng_slot_ = 0; - last_complete_handskake_timestamp_ = 0; memset(&compression_header_, 0, sizeof(compression_header_)); low_resolution_timestamp_ = cookie_secret_timestamp_ = OsGetMilliseconds(); OsGetRandomBytes(cookie_secret_, sizeof(cookie_secret_)); OsGetRandomBytes((uint8*)random_number_input_, sizeof(random_number_input_)); - + SetCurrentThreadAsMainThread(); } WgDevice::~WgDevice() { } void WgDevice::SecondLoop(uint64 now) { - low_resolution_timestamp_ = now; + assert(IsMainThread()); + low_resolution_timestamp_ = now; if (rate_limiter_.is_used()) { uint32 k[5]; for (size_t i = 0; i < ARRAY_SIZE(k); i++) @@ -170,11 +80,16 @@ void WgDevice::SecondLoop(uint64 now) { } uint32 WgDevice::InsertInKeyIdLookup(WgPeer *peer, WgKeypair *kp) { + assert(IsMainThread()); assert(peer); for (;;) { uint32 v = GetRandomNumber(); if (v == 0) continue; + + // Take the exclusive lock since we're modifying it. + WG_SCOPED_RWLOCK_EXCLUSIVE(key_id_lookup_lock_); + std::pair &peer_and_keypair = key_id_lookup_[v]; if (peer_and_keypair.first == NULL) { peer_and_keypair = std::make_pair(peer, kp); @@ -188,7 +103,24 @@ uint32 WgDevice::InsertInKeyIdLookup(WgPeer *peer, WgKeypair *kp) { } } +std::pair *WgDevice::LookupPeerInKeyIdLookup(uint32 key_id) { + // This function is only ever called by the main thread, so no need to lock, + // since the main thread is the only mutator. + assert(IsMainThread()); + auto it = key_id_lookup_.find(key_id); + return (it != key_id_lookup_.end() && it->second.second == NULL) ? &it->second : NULL; +} + +WgKeypair *WgDevice::LookupKeypairByKeyId(uint32 key_id) { + // This function can be called from any thread, so make sure to + // lock using the shared lock. + WG_SCOPED_RWLOCK_SHARED(key_id_lookup_lock_); + auto it = key_id_lookup_.find(key_id); + return (it != key_id_lookup_.end()) ? it->second.second : NULL; +} + uint32 WgDevice::GetRandomNumber() { + assert(IsMainThread()); size_t slot; if ((slot = next_rng_slot_) == 0) { blake2s(random_number_output_, sizeof(random_number_output_), random_number_input_, sizeof(random_number_input_), NULL, 0); @@ -232,6 +164,7 @@ void WgDevice::Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]) { } WgPeer *WgDevice::AddPeer() { + assert(IsMainThread()); WgPeer *peer = new WgPeer(this); WgPeer **pp = &peers_; while (*pp) @@ -241,6 +174,8 @@ WgPeer *WgDevice::AddPeer() { } WgPeer *WgDevice::GetPeerFromPublicKey(uint8 public_key[WG_PUBLIC_KEY_LEN]) { + assert(IsMainThread()); + // todo: add O(1) lookup for (WgPeer *peer = peers_; peer; peer = peer->next_peer_) { if (memcmp(peer->s_remote_, public_key, WG_PUBLIC_KEY_LEN) == 0) return peer; @@ -249,15 +184,16 @@ WgPeer *WgDevice::GetPeerFromPublicKey(uint8 public_key[WG_PUBLIC_KEY_LEN]) { } bool WgDevice::CheckCookieMac1(Packet *packet) { + assert(IsMainThread()); uint8 mac[WG_COOKIE_LEN]; const uint8 *data = packet->data; size_t data_size = packet->size; - blake2s(mac, sizeof(mac), data, data_size - WG_COOKIE_LEN * 2, precomputed_mac1_key_, sizeof(precomputed_mac1_key_)); return !memcmp_crypto(mac, data + data_size - WG_COOKIE_LEN * 2, WG_COOKIE_LEN); } void WgDevice::MakeCookie(uint8 cookie[WG_COOKIE_LEN], Packet *packet) { + assert(IsMainThread()); blake2s_state b2s; uint64 now = OsGetMilliseconds(); if (now - cookie_secret_timestamp_ >= COOKIE_SECRET_MAX_AGE_MS) { @@ -274,6 +210,7 @@ void WgDevice::MakeCookie(uint8 cookie[WG_COOKIE_LEN], Packet *packet) { } bool WgDevice::CheckCookieMac2(Packet *packet) { + assert(IsMainThread()); uint8 cookie[WG_COOKIE_LEN]; uint8 mac[WG_COOKIE_LEN]; MakeCookie(cookie, packet); @@ -282,6 +219,7 @@ bool WgDevice::CheckCookieMac2(Packet *packet) { } void WgDevice::CreateCookieMessage(MessageHandshakeCookie *dst, Packet *packet, uint32 remote_key_id) { + assert(IsMainThread()); dst->type = MESSAGE_HANDSHAKE_COOKIE; dst->receiver_key_id = remote_key_id; MakeCookie(dst->cookie_enc, packet); @@ -290,7 +228,7 @@ void WgDevice::CreateCookieMessage(MessageHandshakeCookie *dst, Packet *packet, xchacha20poly1305_encrypt(dst->cookie_enc, dst->cookie_enc, WG_COOKIE_LEN, mac->mac1, WG_COOKIE_LEN, dst->nonce, precomputed_cookie_key_); } -void WgDevice::EraseKeypairAddrEntry(WgKeypair *kp) { +void WgDevice::EraseKeypairAddrEntry_Locked(WgKeypair *kp) { WgAddrEntry *ae = kp->addr_entry; assert(ae->ref_count >= 1); @@ -308,14 +246,28 @@ void WgDevice::EraseKeypairAddrEntry(WgKeypair *kp) { } } -void WgDevice::UpdateKeypairAddrEntry(uint64 addr_id, WgKeypair *keypair) { - if (keypair->addr_entry != NULL && keypair->addr_entry->addr_entry_id == addr_id) { - keypair->broadcast_short_key = 1; - return; +WgKeypair *WgDevice::LookupKeypairInAddrEntryMap(uint64 addr, uint32 slot) { + WG_SCOPED_RWLOCK_SHARED(addr_entry_lookup_lock_); + auto it = addr_entry_lookup_.find(addr); + if (it == addr_entry_lookup_.end()) + return NULL; + WgAddrEntry *addr_entry = it->second; + return addr_entry->keys[slot]; +} + +void WgDevice::UpdateKeypairAddrEntry_Locked(uint64 addr_id, WgKeypair *keypair) { + assert(keypair->peer->IsPeerLocked()); + { + WG_SCOPED_RWLOCK_SHARED(addr_entry_lookup_lock_); + if (keypair->addr_entry != NULL && keypair->addr_entry->addr_entry_id == addr_id) { + keypair->broadcast_short_key = 1; + return; + } } + WG_SCOPED_RWLOCK_EXCLUSIVE(addr_entry_lookup_lock_); if (keypair->addr_entry != NULL) - EraseKeypairAddrEntry(keypair); + EraseKeypairAddrEntry_Locked(keypair); WgAddrEntry **aep = &addr_entry_lookup_[addr_id], *ae; @@ -362,13 +314,16 @@ void WgDevice::SetHeaderObfuscation(const char *key) { WgPeer::WgPeer(WgDevice *dev) { + assert(dev->IsMainThread()); dev_ = dev; endpoint_.sin.sin_family = 0; next_peer_ = NULL; curr_keypair_ = next_keypair_ = prev_keypair_ = NULL; expect_cookie_reply_ = false; has_mac2_cookie_ = false; + pending_keepalive_ = false; allow_multicast_through_peer_ = false; + allow_endpoint_change_ = true; supports_handshake_extensions_ = true; local_key_id_during_hs_ = 0; last_handshake_init_timestamp_ = -1000000ll; @@ -380,20 +335,43 @@ WgPeer::WgPeer(WgDevice *dev) { last_queued_packet_ptr_ = &first_queued_packet_; num_queued_packets_ = 0; handshake_attempts_ = 0; + total_handshake_attempts_ = 0; num_ciphers_ = 0; cipher_prio_ = 0; + main_thread_scheduled_ = 0; memset(last_timestamp_, 0, sizeof(last_timestamp_)); ipv4_broadcast_addr_ = 0xffffffff; memset(features_, 0, sizeof(features_)); } WgPeer::~WgPeer() { - ClearKeys(); - ClearHandshake(); - ClearPacketQueue(); + assert(dev_->IsMainThread()); + WG_ACQUIRE_LOCK(mutex_); + ClearKeys_Locked(); + ClearHandshake_Locked(); + ClearPacketQueue_Locked(); + WG_RELEASE_LOCK(mutex_); } -void WgPeer::ClearPacketQueue() { +void WgPeer::ClearKeys_Locked() { + assert(dev_->IsMainThread() && IsPeerLocked()); + DeleteKeypair(&curr_keypair_); + DeleteKeypair(&next_keypair_); + DeleteKeypair(&prev_keypair_); +} + +void WgPeer::ClearHandshake_Locked() { + assert(dev_->IsMainThread() && IsPeerLocked()); + uint32 v = local_key_id_during_hs_; + if (v != 0) { + local_key_id_during_hs_ = 0; + WG_SCOPED_RWLOCK_EXCLUSIVE(dev_->key_id_lookup_lock_); + dev_->key_id_lookup_.erase(v); + } +} + +void WgPeer::ClearPacketQueue_Locked() { + assert(dev_->IsMainThread() && IsPeerLocked()); Packet *packet; while ((packet = first_queued_packet_) != NULL) { first_queued_packet_ = packet->next; @@ -422,6 +400,8 @@ void WgPeer::Initialize(const uint8 spub[WG_PUBLIC_KEY_LEN], const uint8 preshar // run on the client void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) { + assert(dev_->IsMainThread()); + uint8 k[WG_SYMMETRIC_KEY_LEN]; MessageHandshakeInitiation *dst = (MessageHandshakeInitiation *)packet->data; @@ -463,7 +443,6 @@ void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) { packet->size = (unsigned)(sizeof(MessageHandshakeInitiation) + extfield_size); - // Insert a pointer to this object, dst->sender_key_id = dev_->InsertInKeyIdLookup(this, NULL); dst->type = MESSAGE_HANDSHAKE_INITIATION; memzero_crypto(k, sizeof(k)); @@ -472,6 +451,7 @@ void WgPeer::CreateMessageHandshakeInitiation(Packet *packet) { // Parsed by server WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { // const MessageHandshakeInitiation *src, MessageHandshakeResponse *dst) { + assert(dev->IsMainThread()); // Copy values into handshake once we've validated it all. uint8 ci[WG_HASH_LEN]; uint8 hi[WG_HASH_LEN]; @@ -562,9 +542,14 @@ WgPeer *WgPeer::ParseMessageHandshakeInitiation(WgDevice *dev, Packet *packet) { BlakeMix(hi, t, sizeof(t)); dst->receiver_key_id = remote_key_id; - keypair = peer->CreateNewKeypair(false, ci, remote_key_id, extbuf + WG_TIMESTAMP_LEN, extfield_size); + keypair = WgPeer::CreateNewKeypair(false, ci, remote_key_id, extbuf + WG_TIMESTAMP_LEN, extfield_size); if (keypair) { - peer->InsertKeypairInPeer(keypair); + + WG_ACQUIRE_LOCK(peer->mutex_); + peer->InsertKeypairInPeer_Locked(keypair); + peer->OnHandshakeAuthComplete(); + WG_RELEASE_LOCK(peer->mutex_); + dst->sender_key_id = dev->InsertInKeyIdLookup(peer, keypair); size_t extfield_out_size = 0; @@ -593,15 +578,15 @@ getout: } WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packet) { + assert(dev->IsMainThread()); MessageHandshakeResponse *src = (MessageHandshakeResponse *)packet->data; uint8 t[WG_HASH_LEN]; uint8 k[WG_SYMMETRIC_KEY_LEN]; WgKeypair *keypair; - auto it = dev->key_id_lookup().find(src->receiver_key_id); - if (it == dev->key_id_lookup().end() || it->second.second != NULL) + auto peer_and_keypair = dev->LookupPeerInKeyIdLookup(src->receiver_key_id); + if (peer_and_keypair == NULL) return NULL; - WgPeer *peer = it->second.first; - + WgPeer *peer = peer_and_keypair->first; assert(src->receiver_key_id == peer->local_key_id_during_hs_); HandshakeState hs = peer->hs_; @@ -626,16 +611,18 @@ WgPeer *WgPeer::ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packe if (!chacha20poly1305_decrypt(src->empty_enc, src->empty_enc, extfield_size + sizeof(src->empty_enc), hs.hi, sizeof(hs.hi), 0, k)) goto getout; - keypair = peer->CreateNewKeypair(true, hs.ci, src->sender_key_id, src->empty_enc, extfield_size); + keypair = WgPeer::CreateNewKeypair(true, hs.ci, src->sender_key_id, src->empty_enc, extfield_size); if (!keypair) goto getout; - peer->InsertKeypairInPeer(keypair); - // Re-map the entry in the id table so it points at this keypair instead. keypair->local_key_id = peer->local_key_id_during_hs_; peer->local_key_id_during_hs_ = 0; - it->second.second = keypair; + peer_and_keypair->second = keypair; + + WG_ACQUIRE_LOCK(peer->mutex_); + peer->InsertKeypairInPeer_Locked(keypair); + WG_RELEASE_LOCK(peer->mutex_); if (0) { getout: @@ -650,11 +637,12 @@ getout: // This is parsed by the initiator, when it needs to re-send the handshake message with a better mac. void WgPeer::ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCookie *src) { + assert(dev->IsMainThread()); uint8 cookie[WG_COOKIE_LEN]; - auto it = dev->key_id_lookup().find(src->receiver_key_id); - if (it == dev->key_id_lookup().end() || it->second.second != NULL) + auto peer_and_keypair = dev->LookupPeerInKeyIdLookup(src->receiver_key_id); + if (!peer_and_keypair) return; - WgPeer *peer = it->second.first; + WgPeer *peer = peer_and_keypair->first; if (!peer->expect_cookie_reply_) return; if (!xchacha20poly1305_decrypt(cookie, src->cookie_enc, sizeof(src->cookie_enc), @@ -756,6 +744,7 @@ void WgKeypairSetupCompressionExtension(WgKeypair *keypair, const WgPacketCompre state->server_addr_v4_subnet_bytes = (remotec->flags & 3); WriteLE32(&state->server_addr_v4_netmask, 0xffffffff >> ((remotec->flags & 3) * 8)); } + bool WgKeypairParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size_t data_size) { bool did_setup_compression = false; @@ -804,33 +793,29 @@ bool WgKeypairParseExtendedHandshake(WgKeypair *keypair, const uint8 *data, size #endif // WITH_HANDSHAKE_EXT -void WgPeer::ClearKeys() { - DeleteKeypair(&curr_keypair_); - DeleteKeypair(&next_keypair_); - DeleteKeypair(&prev_keypair_); -} - -void WgPeer::ClearHandshake() { - uint32 v = local_key_id_during_hs_; - if (v != 0) { - local_key_id_during_hs_ = 0; - dev_->key_id_lookup_.erase(v); - } +static void ActualFreeKeypair(void *x) { + WgKeypair *t = (WgKeypair*)x; + if (t->aes_gcm128_context_) + free(t->aes_gcm128_context_); + delete t; } void WgPeer::DeleteKeypair(WgKeypair **kp) { WgKeypair *t = *kp; *kp = NULL; if (t) { - if (t->addr_entry) - dev_->EraseKeypairAddrEntry(t); - - if (t->local_key_id) + assert(t->peer->IsPeerLocked()); + if (t->addr_entry) { + WG_SCOPED_RWLOCK_EXCLUSIVE(dev_->addr_entry_lookup_lock_); + dev_->EraseKeypairAddrEntry_Locked(t); + } + if (t->local_key_id) { + WG_SCOPED_RWLOCK_EXCLUSIVE(dev_->key_id_lookup_lock_); dev_->key_id_lookup_.erase(t->local_key_id); - - if (t->aes_gcm128_context_) - free(t->aes_gcm128_context_); - delete t; + t->local_key_id = 0; + } + t->recv_key_state = WgKeypair::KEY_INVALID; + dev_->delayed_delete_.Add(&ActualFreeKeypair, t); } } @@ -840,21 +825,24 @@ WgKeypair *WgPeer::CreateNewKeypair(bool is_initiator, const uint8 chaining_key[ if (!kp) return NULL; memset(kp, 0, offsetof(WgKeypair, replay_detector)); - kp->peer = this; kp->is_initiator = is_initiator; kp->remote_key_id = remote_key_id; kp->auth_tag_length = CHACHA20POLY1305_AUTHTAGLEN; #if WITH_HANDSHAKE_EXT - if (!WgKeypairParseExtendedHandshake(kp, extfield, extfield_size)) - goto fail; + if (!WgKeypairParseExtendedHandshake(kp, extfield, extfield_size)) { +fail: + delete kp; + return NULL; + } #endif // WITH_HANDSHAKE_EXT first_key = kp->send_key, second_key = kp->recv_key; if (!is_initiator) std::swap(first_key, second_key); blake2s_hkdf(first_key, sizeof(kp->send_key), second_key, sizeof(kp->recv_key), - kp->auth_tag_length != CHACHA20POLY1305_AUTHTAGLEN ? (uint8*)kp->compress_mac_keys : NULL, 32, NULL, 0, chaining_key, WG_HASH_LEN); + kp->auth_tag_length != CHACHA20POLY1305_AUTHTAGLEN ? (uint8*)kp->compress_mac_keys : NULL, 32, + NULL, 0, chaining_key, WG_HASH_LEN); if (!is_initiator) { std::swap(kp->compress_mac_keys[0][0], kp->compress_mac_keys[1][0]); @@ -870,25 +858,22 @@ WgKeypair *WgPeer::CreateNewKeypair(bool is_initiator, const uint8 chaining_key[ int key_size = (kp->cipher_suite == EXT_CIPHER_SUITE_AES128_GCM) ? 128 : 256; CRYPTO_gcm128_init(&kp->aes_gcm128_context_[0], kp->send_key, key_size); CRYPTO_gcm128_init(&kp->aes_gcm128_context_[1], kp->recv_key, key_size); -#else +#else // WITH_AESGCM goto fail; -#endif +#endif // WITH_AESGCM } #endif // WITH_HANDSHAKE_EXT kp->send_key_state = kp->recv_key_state = WgKeypair::KEY_VALID; - time_of_next_key_event_ = 0; kp->key_timestamp = OsGetMilliseconds(); - return kp; - -fail: - delete kp; - return NULL; } -void WgPeer::InsertKeypairInPeer(WgKeypair *kp) { - assert(kp->peer == this); +void WgPeer::InsertKeypairInPeer_Locked(WgKeypair *kp) { + assert(dev_->IsMainThread() && IsPeerLocked()); + assert(kp->peer == NULL); + kp->peer = this; + time_of_next_key_event_ = 0; DeleteKeypair(&prev_keypair_); if (kp->is_initiator) { // When we're the initator then we got the handshake and we can @@ -908,7 +893,8 @@ void WgPeer::InsertKeypairInPeer(WgKeypair *kp) { } } -bool WgPeer::CheckSwitchToNextKey(WgKeypair *keypair) { +bool WgPeer::CheckSwitchToNextKey_Locked(WgKeypair *keypair) { + assert(IsPeerLocked()); if (keypair != next_keypair_) return false; DeleteKeypair(&prev_keypair_); @@ -920,6 +906,7 @@ bool WgPeer::CheckSwitchToNextKey(WgKeypair *keypair) { } bool WgPeer::CheckHandshakeRateLimit() { + assert(dev_->IsMainThread()); uint64 now = OsGetMilliseconds(); if (now - last_handshake_init_timestamp_ < REKEY_TIMEOUT_MS) return false; @@ -928,6 +915,7 @@ bool WgPeer::CheckHandshakeRateLimit() { } void WgPeer::WriteMacToPacket(const uint8 *data, MessageMacs *dst) { + assert(dev_->IsMainThread()); expect_cookie_reply_ = true; blake2s(dst->mac1, sizeof(dst->mac1), data, (uint8*)dst->mac1 - data, precomputed_mac1_key_, sizeof(precomputed_mac1_key_)); memcpy(sent_mac1_, dst->mac1, sizeof(sent_mac1_)); @@ -964,6 +952,7 @@ enum { #define WgSetTimer(x) (timers_ |= (32 << (x))) void WgPeer::OnDataSent() { + assert(IsPeerLocked()); WgClearTimer(TIMER_SEND_KEEPALIVE); if (!WgIsTimerActive(TIMER_NEW_HANDSHAKE)) WgSetTimer(TIMER_NEW_HANDSHAKE); @@ -971,10 +960,12 @@ void WgPeer::OnDataSent() { } void WgPeer::OnKeepaliveSent() { + assert(IsPeerLocked()); WgSetTimer(TIMER_PERSISTENT_KEEPALIVE); } void WgPeer::OnDataReceived() { + assert(IsPeerLocked()); WgClearTimer(TIMER_NEW_HANDSHAKE); if (!WgIsTimerActive(TIMER_SEND_KEEPALIVE)) WgSetTimer(TIMER_SEND_KEEPALIVE); @@ -984,16 +975,19 @@ void WgPeer::OnDataReceived() { } void WgPeer::OnKeepaliveReceived() { + assert(IsPeerLocked()); WgClearTimer(TIMER_NEW_HANDSHAKE); WgSetTimer(TIMER_PERSISTENT_KEEPALIVE); } void WgPeer::OnHandshakeInitSent() { + assert(IsPeerLocked()); WgClearTimer(TIMER_SEND_KEEPALIVE); WgSetTimer(TIMER_RETRANSMIT_HANDSHAKE); } void WgPeer::OnHandshakeAuthComplete() { + assert(IsPeerLocked()); WgClearTimer(TIMER_NEW_HANDSHAKE); WgSetTimer(TIMER_ZERO_KEYS); WgSetTimer(TIMER_PERSISTENT_KEEPALIVE); @@ -1007,8 +1001,11 @@ static const char * const kCipherSuites[] = { }; void WgPeer::OnHandshakeFullyComplete() { + assert(IsPeerLocked()); WgClearTimer(TIMER_RETRANSMIT_HANDSHAKE); - handshake_attempts_ = 0; + total_handshake_attempts_ = handshake_attempts_ = 0; + + uint64 now = OsGetMilliseconds(); if (last_complete_handskake_timestamp_ == 0) { bool any_feature = false; @@ -1022,17 +1019,15 @@ void WgPeer::OnHandshakeFullyComplete() { curr_keypair_->enabled_features[4] ? "skip_keyid_in" : "", curr_keypair_->enabled_features[5] ? "skip_keyid_out" : ""); } - - } - - last_complete_handskake_timestamp_ = OsGetMilliseconds(); - dev_->last_complete_handskake_timestamp_ = last_complete_handskake_timestamp_; + last_complete_handskake_timestamp_ = now; // RINFO("Connection established."); } // Check if any of the timeouts have expired uint32 WgPeer::CheckTimeouts(uint64 now) { + assert(IsPeerLocked()); + uint32 t, rv = 0; if (now >= time_of_next_key_event_) @@ -1056,11 +1051,9 @@ uint32 WgPeer::CheckTimeouts(uint64 now) { if ((t & (1 << TIMER_RETRANSMIT_HANDSHAKE)) && (now32 - timer_value_[TIMER_RETRANSMIT_HANDSHAKE]) >= REKEY_TIMEOUT_MS) { t ^= (1 << TIMER_RETRANSMIT_HANDSHAKE); if (handshake_attempts_ > MAX_HANDSHAKE_ATTEMPTS) { - RINFO("Too many handshake attempts. Stopping."); t &= ~(1 << TIMER_SEND_KEEPALIVE); - ClearPacketQueue(); + ClearPacketQueue_Locked(); } else { - RINFO("Retrying handshake, attempt %d...", handshake_attempts_ + 2); handshake_attempts_++; rv |= ACTION_SEND_HANDSHAKE; } @@ -1085,13 +1078,12 @@ uint32 WgPeer::CheckTimeouts(uint64 now) { t &= ~(1 << TIMER_NEW_HANDSHAKE); handshake_attempts_ = 0; rv |= ACTION_SEND_HANDSHAKE; - RINFO("Retrying handshake with peer"); } if ((t & (1 << TIMER_ZERO_KEYS)) && (now32 - timer_value_[TIMER_ZERO_KEYS]) >= REJECT_AFTER_TIME_MS * 3) { RINFO("Expiring all keys for peer"); t &= ~(1 << TIMER_ZERO_KEYS); - ClearKeys(); - ClearHandshake(); + ClearKeys_Locked(); + ClearHandshake_Locked(); } } timers_ = t; @@ -1100,6 +1092,7 @@ uint32 WgPeer::CheckTimeouts(uint64 now) { // Check all key stuff here to avoid calling possibly expensive timestamp routines in the packet handler void WgPeer::CheckAndUpdateTimeOfNextKeyEvent(uint64 now) { + assert(IsPeerLocked()); uint64 next_time = UINT64_MAX; uint32 rv = 0; @@ -1110,8 +1103,7 @@ void WgPeer::CheckAndUpdateTimeOfNextKeyEvent(uint64 now) { // if a peer is the initiator of a current secure session, WireGuard will send a handshake initiation // message to begin a new secure session if, after transmitting a transport data message, the current secure session // is REKEY_AFTER_TIME_MS old, or if after receiving a transport data message, the current secure session is - // (REKEY_AFTER_TIME_MS - KEEPALIVE_TIMEOUT_MS - REKEY_TIMEOUT_MS) seconds old and it has not yet acted upon - // this event. + // (REKEY_AFTER_TIME_MS - KEEPALIVE_TIMEOUT_MS - REKEY_TIMEOUT_MS) seconds old and it has not yet acted upon it. if (now >= curr_keypair_->key_timestamp + (REJECT_AFTER_TIME_MS - KEEPALIVE_TIMEOUT_MS - REKEY_TIMEOUT_MS)) { next_time = curr_keypair_->key_timestamp + REJECT_AFTER_TIME_MS; if (curr_keypair_->recv_key_state == WgKeypair::KEY_VALID) @@ -1153,16 +1145,22 @@ void WgPeer::SetPersistentKeepalive(int persistent_keepalive_secs) { } bool WgPeer::AddIp(const WgCidrAddr &cidr_addr) { + assert(dev_->IsMainThread()); + if (cidr_addr.size == 32) { if (cidr_addr.cidr > 32) return false; + WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_); dev_->ip_to_peer_map_.InsertV4(cidr_addr.addr, cidr_addr.cidr, this); + WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_); allowed_ips_.push_back(cidr_addr); return true; } else if (cidr_addr.size == 128) { if (cidr_addr.cidr > 128) return false; + WG_ACQUIRE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_); dev_->ip_to_peer_map_.InsertV6(cidr_addr.addr, cidr_addr.cidr, this); + WG_RELEASE_RWLOCK_EXCLUSIVE(dev_->ip_to_peer_map_lock_); allowed_ips_.push_back(cidr_addr); return true; } else { @@ -1183,14 +1181,13 @@ bool WgPeer::AddCipher(int cipher) { return false; if (cipher == EXT_CIPHER_SUITE_AES128_GCM || cipher == EXT_CIPHER_SUITE_AES256_GCM) { -#if !WITH_AESGCM - return true; -#endif // !WITH_AESGCM +#if defined(ARCH_CPU_X86_FAMILY) && WITH_AESGCM if (!X86_PCAP_AES) return true; +#else + return true; +#endif // defined(ARCH_CPU_X86_FAMILY) && WITH_AESGCM } - - ciphers_[num_ciphers_++] = cipher; return true; } @@ -1214,15 +1211,10 @@ void WgRateLimit::Periodic(uint32 s[5]) { if (per_sec < 1) per_sec = 1; } - if ((unsigned)per_sec > packets_per_sec_) per_sec = (per_sec + packets_per_sec_ + 1) >> 1; - -// if (per_sec != packets_per_sec_) { -// RINFO("Setting pps: %d", per_sec); - packets_per_sec_ = per_sec; -// } - + + packets_per_sec_ = per_sec; used_rate_limit_ = 0; rand_xor_ = s[4]; key2_[0] = key1_[0]; @@ -1278,7 +1270,8 @@ void WgKeypairEncryptPayload(uint8 *dst, const size_t src_len, bool WgKeypairDecryptPayload(uint8 *dst, size_t src_len, const uint8 *ad, size_t ad_len, const uint64 nonce, WgKeypair *keypair) { - uint8 mac[16]; + + __aligned(16) uint8 mac[16]; if (src_len < keypair->auth_tag_length) return false; diff --git a/wireguard_proto.h b/wireguard_proto.h index cd66901..9e5c12f 100644 --- a/wireguard_proto.h +++ b/wireguard_proto.h @@ -4,9 +4,40 @@ #include "tunsafe_types.h" #include "netapi.h" +#include "ipzip2/ipzip2.h" #include "tunsafe_config.h" +#include "tunsafe_threading.h" +#include "ip_to_peer_map.h" #include #include +#include + +// Threading macros that enable locks only in MT builds +#if WITH_WG_THREADING +#define WG_SCOPED_LOCK(name) AutoLock scoped_lock(&name) +#define WG_ACQUIRE_LOCK(name) name.Acquire() +#define WG_RELEASE_LOCK(name) name.Release() +#define WG_DECLARE_LOCK(name) Mutex name; +#define WG_DECLARE_RWLOCK(name) ReaderWriterLock name; +#define WG_ACQUIRE_RWLOCK_SHARED(name) name.AcquireShared() +#define WG_RELEASE_RWLOCK_SHARED(name) name.ReleaseShared() +#define WG_ACQUIRE_RWLOCK_EXCLUSIVE(name) name.AcquireExclusive() +#define WG_RELEASE_RWLOCK_EXCLUSIVE(name) name.ReleaseExclusive() +#define WG_SCOPED_RWLOCK_SHARED(name) ScopedLockShared scoped_lock(&name) +#define WG_SCOPED_RWLOCK_EXCLUSIVE(name) ScopedLockExclusive scoped_lock(&name) +#else // WITH_WG_THREADING +#define WG_SCOPED_LOCK(name) +#define WG_ACQUIRE_LOCK(name) +#define WG_RELEASE_LOCK(name) +#define WG_DECLARE_LOCK(name) +#define WG_DECLARE_RWLOCK(name) +#define WG_ACQUIRE_RWLOCK_SHARED(name) +#define WG_RELEASE_RWLOCK_SHARED(name) +#define WG_ACQUIRE_RWLOCK_EXCLUSIVE(name) +#define WG_RELEASE_RWLOCK_EXCLUSIVE(name) +#define WG_SCOPED_RWLOCK_SHARED(name) +#define WG_SCOPED_RWLOCK_EXCLUSIVE(name) +#endif // WITH_WG_THREADING enum ProtocolTimeouts { COOKIE_SECRET_MAX_AGE_MS = 120000, @@ -17,6 +48,8 @@ enum ProtocolTimeouts { REJECT_AFTER_TIME_MS = 180000, PERSISTENT_KEEPALIVE_MS = 25000, MIN_HANDSHAKE_INTERVAL_MS = 20, + + MAX_SIZE_OF_HANDSHAKE_EXTENSION = 1024, }; enum ProtocolLimits { @@ -26,7 +59,6 @@ enum ProtocolLimits { MAX_HANDSHAKE_ATTEMPTS = 20, MAX_QUEUED_PACKETS_PER_PEER = 128, MESSAGE_MINIMUM_SIZE = 16, - MAX_SIZE_OF_HANDSHAKE_EXTENSION = 1024, }; enum MessageType { @@ -61,7 +93,7 @@ enum { WG_ACK_HEADER_COUNTER_NONE = 0x00, WG_ACK_HEADER_COUNTER_2 = 0x04, WG_ACK_HEADER_COUNTER_4 = 0x08, - WG_ACK_HEADER_COUNTER_8 = 0x0C, + WG_ACK_HEADER_COUNTER_6 = 0x0C, WG_ACK_HEADER_KEY_MASK = 3, }; @@ -166,39 +198,6 @@ STATIC_ASSERT(sizeof(WgPacketCompressionVer01) == 24, WgPacketCompressionVer01_w struct WgKeypair; class WgPeer; -// Maps CIDR addresses to a peer, always returning the longest match -class IpToPeerMap { -public: - IpToPeerMap(); - ~IpToPeerMap(); - - // Inserts an IP address of a given CIDR length into the lookup table, pointing to peer. - bool InsertV4(const void *addr, int cidr, void *peer); - bool InsertV6(const void *addr, int cidr, void *peer); - - // Lookup the peer matching the IP Address - void *LookupV4(uint32 ip); - void *LookupV6(const void *addr); - - void *LookupV4DefaultPeer(); - void *LookupV6DefaultPeer(); - - // Remove a peer from the table - void RemovePeer(void *peer); -private: - struct Entry4 { - uint32 ip; - uint32 mask; - void *peer; - }; - struct Entry6 { - uint8 ip[16]; - uint8 cidr_len; - void *peer; - }; - std::vector ipv4_; - std::vector ipv6_; -}; class WgRateLimit { public: @@ -262,7 +261,6 @@ struct ScramblerSiphashKeys { uint64 keys[4]; }; -// Implementation of most business logic of Wireguard class WgDevice { friend class WgPeer; friend class WireguardProcessor; @@ -272,7 +270,8 @@ public: // Initialize with the private key, precompute all internal keys etc. void Initialize(const uint8 private_key[WG_PUBLIC_KEY_LEN]); - + + // Create a new peer WgPeer *AddPeer(); // Setup header obfuscation @@ -281,35 +280,26 @@ public: // Check whether Mac1 appears to be valid bool CheckCookieMac1(Packet *packet); - // Check whether Mac2 appears to be valid, this also uses - // the remote ip address + // Check whether Mac2 appears to be valid, this also uses the remote ip address bool CheckCookieMac2(Packet *packet); void CreateCookieMessage(MessageHandshakeCookie *dst, Packet *packet, uint32 remote_key_id); - - void UpdateKeypairAddrEntry(uint64 addr_id, WgKeypair *keypair); + void UpdateKeypairAddrEntry_Locked(uint64 addr_id, WgKeypair *keypair); + void SecondLoop(uint64 now); IpToPeerMap &ip_to_peer_map() { return ip_to_peer_map_; } - - std::unordered_map > &key_id_lookup() { return key_id_lookup_; } - WgPeer *first_peer() { return peers_; } - - uint64 last_complete_handskake_timestamp() const { - return last_complete_handskake_timestamp_; - } - const uint8 *public_key() const { return s_pub_; } - - void SecondLoop(uint64 now); - WgRateLimit *rate_limiter() { return &rate_limiter_; } - std::unordered_map &addr_entry_map() { return addr_entry_lookup_; } - - WgPacketCompressionVer01 *compression_header() { return &compression_header_; } + + bool IsMainThread() { return CurrentThreadIdEquals(main_thread_id_); } + void SetCurrentThreadAsMainThread() { main_thread_id_ = GetCurrentThreadId(); } private: + std::pair *LookupPeerInKeyIdLookup(uint32 key_id); + WgKeypair *LookupKeypairByKeyId(uint32 key_id); + WgKeypair *LookupKeypairInAddrEntryMap(uint64 addr, uint32 slot); // Return the peer matching the |public_key| or NULL WgPeer *GetPeerFromPublicKey(uint8 public_key[WG_PUBLIC_KEY_LEN]); // Create a cookie by inspecting the source address of the |packet| @@ -319,12 +309,19 @@ private: // Get a random number uint32 GetRandomNumber(); - void EraseKeypairAddrEntry(WgKeypair *kp); + void EraseKeypairAddrEntry_Locked(WgKeypair *kp); // Maps IP addresses to peers IpToPeerMap ip_to_peer_map_; + + // This lock protects |ip_to_peer_map_|. + WG_DECLARE_RWLOCK(ip_to_peer_map_lock_); + // For enumerating all peers WgPeer *peers_; + + // Lock that protects key_id_lookup_ + WG_DECLARE_RWLOCK(key_id_lookup_lock_); // Mapping from key-id to either an active keypair (if keypair is non-NULL), // or to a handshake. std::unordered_map > key_id_lookup_; @@ -332,6 +329,7 @@ private: // Mapping from IPV4 IP/PORT to WgPeer*, so we can find the peer when a key id is // not explicitly included. std::unordered_map addr_entry_lookup_; + WG_DECLARE_RWLOCK(addr_entry_lookup_lock_); // Counter for generating new indices in |keypair_lookup_| uint8 next_rng_slot_; @@ -339,7 +337,7 @@ private: // Whether packet obfuscation is enabled bool header_obfuscation_; - uint64 last_complete_handskake_timestamp_; + ThreadId main_thread_id_; uint64 low_resolution_timestamp_; @@ -360,9 +358,12 @@ private: WgRateLimit rate_limiter_; WgPacketCompressionVer01 compression_header_; + + // For defering deletes until all worker threads are guaranteed not to use an object. + MultithreadedDelayedDelete delayed_delete_; }; -// State for Noise handshake +// State for peer class WgPeer { friend class WgDevice; friend class WireguardProcessor; @@ -387,10 +388,10 @@ public: static WgPeer *ParseMessageHandshakeResponse(WgDevice *dev, const Packet *packet); static void ParseMessageHandshakeCookie(WgDevice *dev, const MessageHandshakeCookie *src); void CreateMessageHandshakeInitiation(Packet *packet); - bool CheckSwitchToNextKey(WgKeypair *keypair); - void ClearKeys(); - void ClearHandshake(); - void ClearPacketQueue(); + bool CheckSwitchToNextKey_Locked(WgKeypair *keypair); + void ClearKeys_Locked(); + void ClearHandshake_Locked(); + void ClearPacketQueue_Locked(); bool CheckHandshakeRateLimit(); // Timer notifications @@ -408,23 +409,32 @@ public: }; uint32 CheckTimeouts(uint64 now); + void AddPacketToPeerQueue(Packet *packet); + +#if WITH_WG_THREADING + bool IsPeerLocked() { return mutex_.IsLocked(); } +#else // WITH_WG_THREADING + bool IsPeerLocked() { return true; } +#endif // WITH_WG_THREADING + private: - WgKeypair *CreateNewKeypair(bool is_initiator, const uint8 key[WG_HASH_LEN], uint32 send_key_id, const uint8 *extfield, size_t extfield_size); + static WgKeypair *CreateNewKeypair(bool is_initiator, const uint8 key[WG_HASH_LEN], uint32 send_key_id, const uint8 *extfield, size_t extfield_size); void WriteMacToPacket(const uint8 *data, MessageMacs *mac); void DeleteKeypair(WgKeypair **kp); void CheckAndUpdateTimeOfNextKeyEvent(uint64 now); - static void CopyEndpointToPeer(WgKeypair *keypair, const IpAddr *addr); + static void CopyEndpointToPeer_Locked(WgKeypair *keypair, const IpAddr *addr); size_t WriteHandshakeExtension(uint8 *dst, WgKeypair *keypair); - void InsertKeypairInPeer(WgKeypair *keypair); + void InsertKeypairInPeer_Locked(WgKeypair *keypair); WgDevice *dev_; WgPeer *next_peer_; // Keypairs, |curr_keypair_| is the used one, the other ones are // the old ones and the next one. - WgKeypair *curr_keypair_; - WgKeypair *prev_keypair_; - WgKeypair *next_keypair_; + WgKeypair *curr_keypair_, *prev_keypair_, *next_keypair_; + + // Protects shared variables of the WgPeer + WG_DECLARE_LOCK(mutex_); // Timestamp when the next key related event is going to occur. uint64 time_of_next_key_event_; @@ -433,23 +443,38 @@ private: uint32 timers_; uint32 timer_value_[5]; - // Holds the entry into the key id table during handshake + // Holds the entry into the key id table during handshake - mt only. uint32 local_key_id_during_hs_; + + // Address of peer IpAddr endpoint_; + enum { + kMainThreadScheduled_ScheduleHandshake = 1, + }; + std::atomic main_thread_scheduled_; + WgPeer *main_thread_scheduled_next_; + // The broadcast address of the IPv4 network, used to block broadcast traffic // from being sent out over the VPN link. uint32 ipv4_broadcast_addr_; + // Whether the tunsafe specific handshake extensions are supported bool supports_handshake_extensions_; + // Whether any data was sent since the keepalive timer was set bool pending_keepalive_; + + // Whether to change the endpoint on incoming packets. + bool allow_endpoint_change_; + + // Whether we've sent a mac to the peer so we may expect a cookie reply back. bool expect_cookie_reply_; // Whether we want to route incoming multicast/broadcast traffic to this peer. bool allow_multicast_through_peer_; - // Whether + // Whether |mac2_cookie_| is valid. bool has_mac2_cookie_; // Number of handshakes made so far, when this gets too high we stop connecting. @@ -462,11 +487,18 @@ private: uint8 num_queued_packets_; Packet *first_queued_packet_, **last_queued_packet_ptr_; + // For statistics uint64 last_handshake_init_timestamp_; uint64 last_complete_handskake_timestamp_; - uint64 last_handshake_init_recv_timestamp_; - enum { MAX_CIPHERS = 16 }; + // Timestamp to detect flooding of handshakes + uint64 last_handshake_init_recv_timestamp_; // main thread only + + // Number of handshake attempts since last successful handshake + uint32 total_handshake_attempts_; + + // For dynamic ciphers, holds the list of supported ciphers. + enum { MAX_CIPHERS = 4 }; uint8 cipher_prio_; uint8 num_ciphers_; uint8 ciphers_[MAX_CIPHERS]; @@ -482,19 +514,19 @@ private: uint8 e_priv[WG_PUBLIC_KEY_LEN]; }; HandshakeState hs_; - // Remote's static public key - Written only by Init + // Remote's static public key - init only. uint8 s_remote_[WG_PUBLIC_KEY_LEN]; - // Remote's preshared key - Written only by Init + // Remote's preshared key - init only. uint8 preshared_key_[WG_SYMMETRIC_KEY_LEN]; - // Precomputed DH(spriv_local, spub_remote). + // Precomputed DH(spriv_local, spub_remote) - init only. uint8 s_priv_pub_[WG_PUBLIC_KEY_LEN]; - // The most recent seen timestamp, only accept higher timestamps. - uint8 last_timestamp_[WG_TIMESTAMP_LEN]; - // Precomputed key for decrypting cookies from the peer. + // The most recent seen timestamp, only accept higher timestamps - mt only. + uint8 last_timestamp_[WG_TIMESTAMP_LEN]; + // Precomputed key for decrypting cookies from the peer - init only. uint8 precomputed_cookie_key_[WG_SYMMETRIC_KEY_LEN]; - // Precomputed key for sending MACs to the peer. + // Precomputed key for sending MACs to the peer - init only. uint8 precomputed_mac1_key_[WG_SYMMETRIC_KEY_LEN]; - // The last mac value sent, required to make cookies + // The last mac value sent, required to make cookies - mt only. uint8 sent_mac1_[WG_COOKIE_LEN]; // The mac2 cookie that gets appended to outgoing packets uint8 mac2_cookie_[WG_COOKIE_LEN]; @@ -520,10 +552,10 @@ public: BITMAP_MASK = BITMAP_SIZE - 1, }; - uint64 expected_seq_nr() const { return expected_seq_nr_; } + const uint64 expected_seq_nr() const { return expected_seq_nr_; } private: - uint64 expected_seq_nr_; + std::atomic expected_seq_nr_; uint32 bitmap_[BITMAP_SIZE]; }; @@ -574,7 +606,7 @@ struct WgKeypair { // Used so we know when to send out ack packets. uint32 incoming_packet_count; - // Id of the key in my map + // Id of the key in my map. (MainThread) uint32 local_key_id; // Id of the key in their map uint32 remote_key_id; @@ -602,7 +634,6 @@ struct WgKeypair { // State for packet compressor IpzipState ipzip_state_; #endif // WITH_HANDSHAKE_EXT - }; void WgKeypairEncryptPayload(uint8 *dst, const size_t src_len,