以尽可能快的速度比较 (a + sqrt(b)) 形式的两个值?

Comparing two values in the form (a + sqrt(b)) as fast as possible?

提问人:Bernard 提问时间:5/8/2019 最后编辑:Peter CordesBernard 更新时间:5/10/2019 访问量:2638

问:

作为我正在编写的程序的一部分,我需要比较两个值,其形式为 where 和 是无符号整数。由于这是一个紧密循环的一部分,我希望这个比较尽可能快地运行。(如果重要的话,我在 x86-64 机器上运行代码,并且无符号整数不大于 10^6。另外,我知道一个事实。a + sqrt(b)aba1<a2

作为一个独立的功能,这就是我试图优化的。我的数字足够小,可以(甚至)准确地表示它们,但结果中的舍入误差不能改变结果。doublefloatsqrt

// known pre-condition: a1 < a2  in case that helps
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1+sqrt(b1) < a2+sqrt(b2);  // computed mathematically exactly
}

测试用例:应该返回 true,但如注释所示,@wim计算它将返回 false。因此,将截断回整数。is_smaller(900000, 1000000, 900001, 998002)sqrtf()(int)sqrt()

a1+sqrt(b1) = 90100和。最接近的浮点数正好是 90100。a2+sqrt(b2) = 901000.00050050037512481206


由于即使在现代 x86-64 上,当完全内联为指令时,该函数通常也非常昂贵,因此我尽量避免调用。sqrt()sqrtsdsqrt()

通过平方删除 sqrt 还可能通过使所有计算准确来避免舍入错误的任何危险。

相反,如果功能是这样的......

bool is_smaller(unsigned a1, unsigned b1, unsigned x) {
    return a1+sqrt(b1) < x;
}

...那我就可以了return x-a1>=0 && static_cast<uint64_t>(x-a1)*(x-a1)>b1;

但是现在由于有两个项,我不能做同样的代数运算。sqrt(...)

通过使用以下公式,我可以对值进行两次平方:

      a1 + sqrt(b1) = a2 + sqrt(b2)
<==>  a1 - a2 = sqrt(b2) - sqrt(b1)
<==>  (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1) * sqrt(b2)
<==>  (a1 - a2) * (a1 - a2) = b1 + b2 - 2 * sqrt(b1 * b2)
<==>  (a1 - a2) * (a1 - a2) - (b1 + b2) = - 2 * sqrt(b1 * b2)
<==>  ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 2 = sqrt(b1 * b2)
<==>  ((b1 + b2) - (a1 - a2) * (a1 - a2)) * ((b1 + b2) - (a1 - a2) * (a1 - a2)) / 4 = b1 * b2

无符号除以 4 很便宜,因为它只是一个位移,但由于我将数字平方两次,我需要使用 128 位整数,并且我需要引入一些检查(因为我比较的是不等式而不是相等式)。>=0

感觉可能有一种方法可以更快地做到这一点,通过将更好的代数应用于这个问题。有没有办法更快地做到这一点?

C++ 优化 SQRT

评论

8赞 500 - Internal Server Error 5/8/2019
只是一个观察:如果为 true,那么您可以跳过 的计算。a1+sqrt(b1)<a2sqrt(b2)
4赞 StPiere 5/8/2019
您还可以观察到 max(sqrt(b)) = 1000 如果 b <= 10^6。因此,您只需要进一步调查 abs(a1-a2) <= 1000。否则,总是存在不平等
2赞 Dominique 5/8/2019
恐怕你进入代码的速度太快了:还有另一个问题“stackoverflow.com/questions/52807071/...”,我给出了一种减少浮点运算的方法,只需解释用例即可。您能向我们解释一下用例吗,也许我们可以想出更好的解决方案?(“a1+sqrt(b1)<a2+sqrt(b2)”是什么意思?
2赞 Peter Cordes 5/8/2019
@StPiere:如果输入分布相当均匀,那么在现代 x86 上将 LUT 用于 sqrt 将是可怕的。4MiB 缓存占用空间比 L2 缓存大小(通常为 256kiB)大得多,因此您最多只能获得 L3 命中,例如 Skylake 上的 45 个周期延迟。但即使在非常旧的 Core 2 上,单精度 sqrt 也有最坏情况下的 29 个周期延迟。(首先还有几个周期可以转换为 FP)。在 Skylake 上,FP sqrt 延迟 ~= L2 缓存达到延迟,并以吞吐量 = latency/4 进行流水线传输。更不用说缓存污染对其他代码的影响了。
3赞 kvantour 5/8/2019
由于 ,您已经可以直接排除所有情况a1 < a2b1 < b2

答:

4赞 Brendan 5/8/2019 #1

我累了,可能犯了一个错误;但我敢肯定,如果我这样做了,有人会指出来。.

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    a_diff = a1-a2;   // May be negative

    if(a_diff < 0) {
        if(b1 < b2) {
            return true;
        }
        temp = a_diff+sqrt(b1);
        if(temp < 0) {
            return true;
        }
        return temp*temp < b2;
    } else {
        if(b1 >= b2) {
            return false;
        }
    }
//  return a_diff+sqrt(b1) < sqrt(b2);

    temp = a_diff+sqrt(b1);
    return temp*temp < b2;
}

如果你知道,那么它可能会变成:a1 < a2

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    a_diff = a2-a1;    // Will be positive

    if(b1 > b2) {
        return false;
    }
    if(b1 >= a_diff*a_diff) {
        return false;
    }
    temp = a_diff+sqrt(b2);
    return b1 < temp*temp;
}

评论

2赞 Acorn 5/8/2019
我们知道 ,所以没有必要测试 .但也许值得对其进行测试(反转)。a1 < a2a_diff < 0> 1000
1赞 Peter Cordes 5/8/2019
你需要一个有符号的差值,如果你想对它进行平方,在不做你想要的实际 sqrt 的情况下检查条件。int a_diffint64_t
0赞 Brendan 5/8/2019
呵呵 - 我确实犯了一个错误(如果是负数,那么你不能在不使符号无聊的情况下将其平方) - 已修复。还添加了“如果你知道 a1 < a2”。a_diff+sqrt(b1);
0赞 kvantour 5/8/2019
只需为第二种情况定义即可。第一个检查应该是 。第一个平方根,如果你只是检查是否a_diff = a2 - a1b1 <= b2b1/a_diff < a_diff
0赞 Peter Cordes 5/8/2019
我也累了。半成品尝试:godbolt.org/z/U758t6。我认为我们可以使用(对于数字 >= 1.0 和 0 都是如此)。(ping @kvantour)。但也许最好先检查一下,然后sqrt( abs(b1-b2) ) <= sqrt(b1) - sqrt(b2)b1 < b2abs(a1-a2) <= 1000
19赞 geza 5/8/2019 #2

这是一个没有 的版本,尽管我不确定它是否比只有一个的版本更快(这可能取决于值的分布)。sqrtsqrt

这是数学(如何删除两个 sqrts):

ad = a2-a1
bd = b2-b1

a1+sqrt(b1) < a2+sqrt(b2)              // subtract a1
   sqrt(b1) < ad+sqrt(b2)              // square it
        b1  < ad^2+2*ad*sqrt(b2)+b2    // arrange
   ad^2+bd  > -2*ad*sqrt(b2)

在这里,右边总是负的。如果左侧为正数,则我们必须返回 true。

如果左边为负数,那么我们可以对不等式进行平方:

ad^4+bd^2+2*bd*ad^2 < 4*ad^2*b2

这里要注意的关键是,如果 ,则总是返回(因为最大值为 1000)。如果 ,则为小数,因此将始终适合 64 位(不需要 128 位算术)。代码如下:a2>=a1+1000is_smallertruesqrt(b1)a2<=a1+1000adad^4

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    if (ad>1000) {
        return true;
    }

    int bd = b2 - b1;
    if (ad*ad+bd>0) {
        return true;
    }

    int ad2 = ad*ad;

    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}

编辑:正如 Peter Cordes 所注意到的,第一个是不必要的,因为第二个 if 处理它,所以代码变得更小更快:if

bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    int ad = a2 - a1;
    int bd = b2 - b1;
    if ((long long int)ad*ad+bd>0) {
        return true;
    }

    int ad2 = ad*ad;
    return (long long int)ad2*ad2 + (long long int)bd*bd + 2ll*bd*ad2 < 4ll*ad2*b2;
}

评论

2赞 Peter Cordes 5/8/2019
省略分支可能会更好;我认为分支涵盖了一切。一个在大多数情况下是正确的分支比 2 个分支要好,因为每个分支都在某些时候被用于分支预测。除非检查捕获了大部分输入,否则值得额外执行 1 个 和 。(可能是 a 到 64 位)ad>1000ad*ad+bd>0ad>1000subimulmovzx
0赞 geza 5/8/2019
@PeterCordes:很好的观察(像往常一样:)),谢谢!
0赞 Peter Cordes 5/8/2019
哦,你这样做是为了不溢出一个?内联后,x86-64 编译器可以将零扩展优化为 64 位(因为编写 32 位寄存器会隐式执行此操作),因此我们可以将无符号输入提升到 ,然后减去得到 。(32 位减法需要对可能的负结果进行符号扩展,因此请避免这样做。ad*adint32_tuint64_tint64_tmovsxd
0赞 geza 5/8/2019
@PeterCordes:差不多:)我添加了它,所以肯定不会溢出 64 位。但是,是的,正如您所说,将乘法作为 64 位可以轻松处理这种情况。ad^4ad*ad
1赞 geza 5/9/2019
@wim:我已经验证了我的解决方案是正常的。但是,我发现 Brendan 的第二个版本(也许第一个版本)存在一些问题:)
2赞 StPiere 5/8/2019 #3

还有牛顿方法用于计算整数 sqrts,如此处所述 另一种方法是不计算平方根,而是通过二进制搜索搜索 floor(sqrt(n)) ...“只有”1000 个小于 10^6 的全平方数。 这可能性能不佳,但将是一种有趣的方法。我没有测量过其中任何一个,但这里有一些例子:

#include <iostream>
#include <array>
#include <algorithm>        // std::lower_bound
#include <cassert>          


bool is_smaller_sqrt(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    return a1 + sqrt(b1) < a2 + sqrt(b2);
}

static std::array<int, 1001> squares;

template <typename C>
void squares_init(C& c)
{
    for (int i = 0; i < c.size(); ++i)
        c[i] = i*i;
}

inline bool greater(const int& l, const int& r)
{
    return r < l;
}

inline bool is_smaller_bsearch(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    // return a1 + sqrt(b1) < a2 + sqrt(b2)

    // find floor(sqrt(b1)) - binary search withing 1000 elems
    auto it_b1 = std::lower_bound(crbegin(squares), crend(squares), b1, greater).base();

    // find floor(sqrt(b2)) - binary search withing 1000 elems
    auto it_b2 = std::lower_bound(crbegin(squares), crend(squares), b2, greater).base();

    return (a2 - a1) > (it_b1 - it_b2);
}

unsigned int sqrt32(unsigned long n)
{
    unsigned int c = 0x8000;
    unsigned int g = 0x8000;

    for (;;) {
        if (g*g > n) {
            g ^= c;
        }

        c >>= 1;

        if (c == 0) {
            return g;
        }

        g |= c;
    }
}

bool is_smaller_sqrt32(unsigned a1, unsigned b1, unsigned a2, unsigned b2)
{
    return a1 + sqrt32(b1) < a2 + sqrt32(b2);
}

int main()
{
    squares_init(squares);

    // now can use is_smaller
    assert(is_smaller_sqrt(1, 4, 3, 1) == is_smaller_sqrt32(1, 4, 3, 1));
    assert(is_smaller_sqrt(1, 2, 3, 3) == is_smaller_sqrt32(1, 2, 3, 3));
    assert(is_smaller_sqrt(1000, 4, 1001, 1) == is_smaller_sqrt32(1000, 4, 1001, 1));
    assert(is_smaller_sqrt(1, 300, 3, 200) == is_smaller_sqrt32(1, 300, 3, 200));
}

评论

0赞 Peter Cordes 5/8/2019
我认为,您的整数会运行该循环的固定 16 次迭代。您可以根据位扫描从较小的位置开始,以找到最高设置位并除以 2。或者只是从较低的固定起始点开始,因为已知的最大值是 1 密耳,而不是 ~40 亿。因此,我们可以保存大约 12 / 2 = 6 次迭代。但这可能仍然比转换为单精度 for 和 back 慢。也许如果你在整数循环中并行做两个平方根,那么更新和循环开销就会摊销,并且会有 2 个 dep 链sqrt32nnfloatsqrtssc
0赞 Peter Cordes 5/8/2019
二进制搜索反转表是一个有趣的想法,但在现代 x86-64 上仍然可能很糟糕,因为硬件 sqrt 不是很慢,但相对于具有更短/更简单管道的设计,分支错误预测的成本非常高。也许这个答案中的一些对有相同问题但在微控制器上的人有用。
1赞 Eric Towers 5/9/2019 #4

可能不比其他答案更好,但使用了不同的想法(以及大量的预分析)。

// Compute approximate integer square root of input in the range [0,10^6].
// Uses a piecewise linear approximation to sqrt() with bounded error in each piece:
//   0 <= x <= 784 : x/28
//   784 < x <= 7056 : 21 + x/112
//   7056 < x <= 28224 : 56 + x/252
//   28224 < x <= 78400 : 105 + x/448
//   78400 < x <= 176400 : 168 + x/700
//   176400 < x <= 345744 : 245 + x/1008
//   345744 < x <= 614656 : 336 + x/1372
//   614656 < x <= 1000000 : (784000+x)/1784
// It is the case that sqrt(x) - 7.9992711366390365897... <= pseudosqrt(x) <= sqrt(x).
unsigned pseudosqrt(unsigned x) {
    return 
        x <= 78400 ? 
            x <= 7056 ?
                x <= 764 ? x/28 : 21 + x/112
              : x <= 28224 ? 56 + x/252 : 105 + x/448
          : x <= 345744 ?
                x <= 176400 ? 168 + x/700 : 245 + x/1008
              : x <= 614656 ? 336 + x/1372 : (x+784000)/1784 ;
}

// known pre-conditions: a1 < a2, 
//                  0 <= b1 <= 1000000
//                  0 <= b2 <= 1000000
bool is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
// Try three refinements:
// 1: a1 + sqrt(b1) <= a1 + 1000, 
//    so is a1 + 1000 < a2 ?  
//    Convert to a2 - a1 > 1000 .
// 2: a1 + sqrt(b1) <= a1 + pseudosqrt(b1) + 8 and
//    a2 + pseudosqrt(b2) <= a2 + sqrt(b2), 
//    so is  a1 + pseudosqrt(b1) + 8 < a2 + pseudosqrt(b2) ?
//    Convert to a2 - a1 > pseudosqrt(b1) - pseudosqrt(b2) + 8 .
// 3: Actually do the work.
//    Convert to a2 - a1 > sqrt(b1) - sqrt(b2)
// Use short circuit evaluation to stop when resolved.
    unsigned ad = a2 - a1;
    return (ad > 1000)
           || (ad > pseudosqrt(b1) - pseudosqrt(b2) + 8)
           || ((int) ad > (int)(sqrt(b1) - sqrt(b2)));
}

(我手边没有编译器,所以这可能包含一两个错别字。

评论

0赞 Peter Cordes 5/9/2019
它确实与 (godbolt.org/z/VH4I3g) 编译,但我认为它不会做得很好。@Geza的答案有一个更快的整数早期输出,并且比这少做更少的整数数学运算(分支也更少),通过平方完全避免,而只需要 64 位整数。#include <math.h>sqrt
0赞 Eric Towers 5/9/2019
@PeterCordes:我没有办法做时间比较。Geza 的早期输出处理 99.946% 的均匀分布输入,并遵循两个减法和一个乘法、加法和比较。我的早期出局是一个减法和比较。我看不出“早期整数更快”的基础。我的早期输出处理了 99.8% 的均匀分布输入。该案例将处理的输入集提高到 99.9725%,额外花费 7 个比较和加法和减法各一个。(续)longlong longunsignedpseudosqrt()unsignedunsigned
0赞 Eric Towers 5/9/2019
@PeterCordes : 除非操作速度和现在的操作一样快,否则您的“更快”的说法会令人惊讶。事实似乎并非如此。[stackoverflow.com/questions/48779619/......long longunsigned
1赞 Peter Cordes 5/9/2019
比较和分支的成本很高,除非分支预测完美运行。比乘法贵得多(3c 延迟,现代 x86-64 上的 1 个周期吞吐量,如 AMD Zen 或 Intel 自 Nehalem 以来)。在现代 x86-64 上,只有更宽类型的除法成本更高,其他操作不依赖于数据或类型宽度。一些较旧的 x86-64 CPU(如 Bulldozer-family 或 Silvermont)具有较慢的 64 位乘法速度。agner.org/optimize。(当然,我们谈论的是标量;使用 SIMD 进行自动矢量化使窄类型变得有价值,因为您可以对每个向量执行更多操作)long long
1赞 Peter Cordes 5/9/2019
@wim:是的,无论 OP 做什么,一个相当详尽的单元测试都是有序的!(不过,1M ^4 太昂贵了,因此需要对搜索空间进行一些修剪以查看一些大值,以及一些使不等式两边几乎相等的值。
2赞 wim 5/9/2019 #5

我不确定代数运算是否与整数相结合 算术,必然导致最快的解决方案。您将需要 在这种情况下,许多标量相乘(不是很快),和/或 分支预测可能会失败,从而降低性能。 显然,您必须进行基准测试,以查看哪种解决方案在您的特定情况下最快。

一种方法可以制作 更快一点的是将选项添加到 gcc 或 clang。 在这种情况下,编译器不必检查负输入。 使用 icc 时,这是默认设置。sqrt-fno-math-errno

通过使用矢量化指令而不是标量指令可以提高更多的性能。 Peter Cordes 已经证明 clang 能够自动矢量化这段代码, 这样它就会生成这个.sqrtsqrtpdsqrtsqrtsdsqrtpd

但是,自动矢量化的成功程度很大程度上取决于正确的编译器设置 以及使用的编译器(clang、gcc、icc 等)。使用 或更早时,clang 不会矢量化。-march=nehalem

使用以下内部代码可以获得更可靠的矢量化结果,请参见下文。 对于可移植性,我们只假设 SSE2 支持,这是 x86-64 的基线。

/* gcc -m64 -O3 -fno-math-errno smaller.c                      */
/* Adding e.g. -march=nehalem or -march=skylake might further  */
/* improve the generated code                                  */
/* Note that SSE2 in guaranteed to exist with x86-64           */
#include<immintrin.h>
#include<math.h>
#include<stdio.h>
#include<stdint.h>

int is_smaller_v5(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    uint64_t a64    =  (((uint64_t)a2)<<32) | ((uint64_t)a1); /* Avoid too much port 5 pressure by combining 2 32 bit integers in one 64 bit integer */
    uint64_t b64    =  (((uint64_t)b2)<<32) | ((uint64_t)b1); 
    __m128i ax      = _mm_cvtsi64_si128(a64);         /* Move integer from gpr to xmm register                  */
    __m128i bx      = _mm_cvtsi64_si128(b64);         
    __m128d a       = _mm_cvtepi32_pd(ax);            /* Convert 2 integers to double                           */
    __m128d b       = _mm_cvtepi32_pd(bx);            /* We don't need _mm_cvtepu32_pd since a,b < 1e6          */
    __m128d sqrt_b  = _mm_sqrt_pd(b);                 /* Vectorized sqrt: compute 2 sqrt-s with 1 instruction   */
    __m128d sum     = _mm_add_pd(a, sqrt_b);
    __m128d sum_lo  = sum;                            /* a1 + sqrt(b1) in the lower 64 bits                     */
    __m128d sum_hi  =  _mm_unpackhi_pd(sum, sum);     /* a2 + sqrt(b2) in the lower 64 bits                     */
    return _mm_comilt_sd(sum_lo, sum_hi);
}


int is_smaller(unsigned a1, unsigned b1, unsigned a2, unsigned b2) {
    return a1+sqrt(b1) < a2+sqrt(b2);
}


int main(){
    unsigned a1; unsigned b1; unsigned a2; unsigned b2;
    a1 = 11; b1 = 10; a2 = 10; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 11; a2 = 10; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 10; a2 = 11; b2 = 10;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));
    a1 = 10; b1 = 10; a2 = 10; b2 = 11;
    printf("smaller?  %i  %i \n",is_smaller(a1,b1,a2,b2), is_smaller_v5(a1,b1,a2,b2));

    return 0;
}


有关生成的程序集,请参阅此 Godbolt 链接

在 Intel Skylake 上的简单吞吐量测试中,使用编译器选项,我发现了吞吐量 其中比原来的好 2.6 倍:6.8 个 CPU 周期对 18 个 CPU 周期,包括循环开销。然而,在一个(太? 简单的延迟测试,其中输入取决于之前的结果,我没有看到任何改进。(39.7 个周期对 39 个周期)。gcc -m64 -O3 -fno-math-errno -march=nehalemis_smaller_v5()is_smaller()a1, a2, b1, b2is_smaller(_v5)

评论

0赞 Peter Cordes 5/9/2019
Clang 已经像这样自动矢量化:Pgodbolt.org/z/GvNe2B 查看 double 和 signed int 版本。但仅为 ,不 .对于吞吐量,您绝对应该将浮点数与此策略一起使用,因为打包转换只有 1 uop,并且具有更好的吞吐量。OP 的数字都是 100 万或更少,因此可以精确地用 表示,它们的平方根也可以。顺便说一句,看起来你忘了设置,所以你的 gcc 选择了存储/重新加载策略而不是 ALU。doublefloatsqrtpsfloat-mtune=haswell_mm_set_epi32movd
0赞 wim 5/9/2019
@PeterCordes:单一精度不够准确,请在此处查看我的评论。我们只知道目标是 x86-64。不知何故,在这种情况下,clang 不会矢量化。即使有.Clang 确实以 4 秒产生更好的 asm。-march=nehalemmovd
1赞 wim 5/10/2019
@PeterCordes:请注意,在吞吐量测试中,自动矢量化函数很容易在端口 5 上出现瓶颈。如果我没记错的话,Clang 会生成 9 个 p5 微操作(Skylake)。
0赞 Peter Cordes 5/10/2019
我没有仔细看。这并不奇怪,这不是最佳的。:P有趣的是,我没有意识到 .当然,有是有道理的,但我以前从未注意到它。如果您使用 if 到 “dead” 变量而不是 (如果您在没有 AVX 的情况下编译),您应该能够保存 a 。但这需要大量的强制转换,因为内部函数使得帮助编译器以这种方式优化其随机播放变得不方便。(u)comisdmovapsmovhlpsunpckhpd
0赞 wim 5/10/2019
@PeterCordes 实际上,我以前从未使用过内在函数,但在这里它似乎很有用。(u)comisd