提问人:TLW 提问时间:5/3/2021 更新时间:5/7/2021 访问量:424
正确舍入计算处理溢出的两个浮点数之和的 sqrt
Correctly rounded computation of sqrt of sum of two floats handling overflow
问:
有没有一种好的方法来计算正确舍入的结果
sqrt(a+b)
对于浮点数 和(精度相同),where 和 ?a
b
0<=a<+inf
0<=b<+inf
特别是,对于计算会溢出的输入值?a+b
(此处的“正确舍入”的含义与计算本身的意思相同,即返回最接近以无限精度计算的“真实”结果的可表示值。sqrt
(注意:一种明显的方法是以较大的浮点大小进行计算,并避免以这种方式溢出。不幸的是,这在一般情况下不起作用(例如,如果不支持更大的浮点格式)。
我试过 Herbie,但它完全放弃了。它似乎没有对 a+b 溢出的足够点进行采样来检测问题,并且似乎也没有很好地处理相关采样。不幸的是,因为它通常是一个很好的工具。
到目前为止,我一直在做的是(伪代码)
if a + b would overflow:
2*sqrt(a/4 + b/4) # Cannot overflow for finite inputs, as f::MAX/4 + f::MAX/4 <= f::MAX
else:
... # handle non-overflow case. Also interesting; not quite the topic of this question.
...这似乎在实践中大多有效,但 a) 完全无原则,b) 在实践中偶尔会返回一个结果,而 epsilon 在溢出避免部分(例如,真正的结果是,但这返回而不是x + 0.2(x.next_larger()-x)
x.next_larger()
x
)
有关 f32 中“off-by-epsilon”问题的快速示例:
>>> import decimal
>>> decimal.getcontext().prec = 256
>>> from decimal import Decimal as D
>>> from numpy import float32 as f32
>>> a = D(f32("6.0847234e31").astype(float))
>>> b = D(f32("3.4028235e38").astype(float))
>>> res_act = (a+b).sqrt()
>>> res_calc = D(f32("1.8446744e19").astype(float)) # 2*sqrt(a/4 + b/4) in f32 precision
>>> res_best = D(f32("1.8446746e19").astype(float)) # obtained by brute-force
>>> abs(res_calc - res_act) > abs(res_best - res_act)
True # oops
(你必须相信我以 f32 计算的结果,因为 Python 通常以 f64 精度运行。这也是 f32 跳舞的原因。
答:
通过2的幂进行适当的缩放,可以很容易地避免溢出,这样,大范围的参数就会向统一缩放。困难的部分是产生正确的圆角结果。我甚至不完全相信在下一个更大的 IEEE-754 二进制浮点类型中执行中间计算可以保证这一点,因为双舍五入存在潜在问题。
在没有更广泛的浮点类型的情况下,人们将不得不回退到将多个本机精度数字链接在一起,以执行具有更高中间精度的运算。Dekker 的一种常见方案称为成对精度。它使用成对的浮点数,其中较高重要的部分通常称为“头部”,而不太重要的部分称为“尾部”。这两个部分被归一化,使得尾巴的大小最多是头部大小的一半。
该方案中的有效有效位数为 2*p+1,其中 p 是底层浮点类型中的有效位数。“额外”位由尾部的符号位表示。需要注意的是,与底层基类型相比,指数范围没有变化,因此我们需要相当积极地向单位扩展,以避免在中间计算中遇到次正态操作数。成对精度计算无法保证结果的舍入正确。使用三胞胎可能会起作用,但需要付出更多的努力,而不是我所能承受的。
但是,成对精度可以提供忠实的舍入结果,并且几乎总是正确舍入。当FMA(融合乘加)可用时,可以相当有效地构建产生约2*p-1个好位的基于牛顿-拉夫森的对精度平方根。这就是我在下面的示例性 IS0-C99 代码中使用的代码,该代码使用映射到 IEEE-754 作为本机浮点类型。在编译精度对代码时,应最高遵守 IEEE-754 标准,以防止与浮点运算的书面顺序出现意外偏差。就我而言,我使用了 MSVC 2019 的命令行开关。float
binary32
/fp:strict
对于数百亿个随机测试向量,我的测试程序报告的最大误差为 0.500000179 ulp。
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <math.h>
/* compute square root of sum of two positive floating-point numbers */
float sqrt_sum_pos (float a, float b)
{
float mn, mx, res, scale_in, scale_out;
float r, s, t, u, v, w, x;
/* sort arguments according to magnitude */
mx = a < b ? b : a;
mn = a < b ? a : b;
/* select scale factor: scale argument larger in magnitude towards unity */
scale_in = (mx > 1.0f) ? 0x1.0p-64f : 0x1.0p+64f;
scale_out = (mx > 1.0f) ? 0x1.0p+32f : 0x1.0p-32f;
/* scale input arguments */
mn = mn * scale_in;
mx = mx * scale_in;
/* represent sum as a normalized pair s:t of 'float' */
s = mx + mn; // most significant bits
t = (mx - s) + mn; // least significant bits
/* compute square root of s:t. Based on Alan Karp and Peter Markstein,
"High Precision Division and Square Root", ACM TOMS, vol. 23, no. 4,
December 1997, pp. 561-589
*/
r = sqrtf (1.0f / s);
if (s == 0.0f) r = 0.0f;
x = r * s;
s = fmaf (x, -x, s);
r = 0.5f * r;
u = s + t;
v = (s - u) + t;
s = r * u;
t = fmaf (r, u, -s);
t = fmaf (r, v, t);
r = x + s;
s = (x - r) + s;
s = s + t;
t = r + s;
s = (r - t) + s;
/* Component sum of t:s represents square root with maximum error very close to 0.5 ulp */
w = s + t;
/* compensate scaling of source operands */
res = w * scale_out;
/* handle special cases: NaN, Inf */
t = a + b;
if (isinf (mx)) res = mx;
if (isnan (t)) res = t;
return res;
}
// George Marsaglia's KISS PRNG, period 2**123. Newsgroup sci.math, 21 Jan 1999
// Bug fix: Greg Rose, "KISS: A Bit Too Simple" http://eprint.iacr.org/2011/007
static uint32_t kiss_z=362436069, kiss_w=521288629;
static uint32_t kiss_jsr=123456789, kiss_jcong=380116160;
#define znew (kiss_z=36969*(kiss_z&65535)+(kiss_z>>16))
#define wnew (kiss_w=18000*(kiss_w&65535)+(kiss_w>>16))
#define MWC ((znew<<16)+wnew )
#define SHR3 (kiss_jsr^=(kiss_jsr<<13),kiss_jsr^=(kiss_jsr>>17), \
kiss_jsr^=(kiss_jsr<<5))
#define CONG (kiss_jcong=69069*kiss_jcong+1234567)
#define KISS ((MWC^CONG)+SHR3)
uint32_t float_as_uint32 (float a)
{
uint32_t r;
memcpy (&r, &a, sizeof r);
return r;
}
uint64_t double_as_uint64 (double a)
{
uint64_t r;
memcpy (&r, &a, sizeof r);
return r;
}
float uint32_as_float (uint32_t a)
{
float r;
memcpy (&r, &a, sizeof r);
return r;
}
double floatUlpErr (float res, double ref)
{
uint64_t i, j, err, refi;
int expoRef;
/* ulp error cannot be computed if either operand is NaN, infinity, zero */
if (isnan (res) || isnan (ref) || isinf (res) || isinf (ref) ||
(res == 0.0f) || (ref == 0.0f)) {
return 0.0;
}
/* Convert the float result to an "extended float". This is like a float
with 56 instead of 24 effective mantissa bits
*/
i = ((uint64_t) float_as_uint32 (res)) << 32;
/* Convert the double reference to an "extended float". If the reference is
>= 2^129, we need to clamp to the maximum "extended float". If reference
is < 2^-126, we need to denormalize because of float's limited exponent
range.
*/
refi = double_as_uint64 (ref);
expoRef = (int)(((refi >> 52) & 0x7ff) - 1023);
if (expoRef >= 129) {
j = 0x7fffffffffffffffULL;
} else if (expoRef < -126) {
j = ((refi << 11) | 0x8000000000000000ULL) >> 8;
j = j >> (-(expoRef + 126));
} else {
j = ((refi << 11) & 0x7fffffffffffffffULL) >> 8;
j = j | ((uint64_t)(expoRef + 127) << 55);
}
j = j | (refi & 0x8000000000000000ULL);
err = (i < j) ? (j - i) : (i - j);
return err / 4294967296.0;
}
int main (void)
{
float arga, argb, res, reff;
uint32_t argai, argbi, resi, refi, diff;
double ref, ulp, maxulp = 0;
unsigned long long int count = 0;
do {
/* random positive inputs */
argai = KISS & 0x7fffffff;
argbi = KISS & 0x7fffffff;
/* increase occurence of zero, infinity */
if ((argai & 0xffff) == 0x5555) argai = 0x00000000;
if ((argbi & 0xffff) == 0x3333) argbi = 0x00000000;
if ((argai & 0xffff) == 0xaaaa) argai = 0x7f800000;
if ((argbi & 0xffff) == 0xcccc) argbi = 0x7f800000;
arga = uint32_as_float (argai);
argb = uint32_as_float (argbi);
res = sqrt_sum_pos (arga, argb);
ref = sqrt ((double)arga + (double)argb);
reff = (float)ref;
ulp = floatUlpErr (res, ref);
resi = float_as_uint32 (res);
refi = float_as_uint32 (reff);
diff = (refi > resi) ? (refi - resi) : (resi - refi);
if (diff > 1) {
/* if both source operands were NaNs, result could be either NaN,
quietened if necessary
*/
if (!(isnan (arga) && isnan (argb) &&
((resi == (argai | 0x00400000)) ||
(resi == (argbi | 0x00400000))))) {
printf ("\rerror: refi=%08x resi=%08x a=% 15.8e %08x b=% 15.8e %08x\n",
refi, resi, arga, argai, argb, argbi);
return EXIT_FAILURE;
}
}
if (ulp > maxulp) {
printf ("\rulp = %.9f @ a=%14.8e (%15.6a) b=%14.8e (%15.6a) a+b=%22.13a res=%15.6a ref=%22.13a\n",
ulp, arga, arga, argb, argb, (double)arga + argb, res, ref);
maxulp = ulp;
}
count++;
if (!(count & 0xffffff)) printf ("\r%llu", count);
} while (1);
printf ("\ntest passed\n");
return EXIT_SUCCESS;
}
评论
另一种方法,既然@EricPostpischil和@njuffa突出了实际问题(即双舍入)。
(注:下面说的是“乖巧”的数字。它不考虑精度边界或亚法线,尽管可以扩展以执行此操作。
首先,请注意,保证 和 都返回与结果最接近的可表示值。问题是双舍五入。也就是说,当我们想要计算时,我们基本上是在计算。注意缺少内圆。sqrt(x)
a+b
round(sqrt(round(a+b)))
round(sqrt(a+b))
那么,内轮对结果的影响有多大呢?好吧,内轮加起来是 ±0.5 ULP 的加法结果。因此,我们粗略地假设了 -bit 尾数。sqrt((a+b)*(1 ±2**-p))
p
这减少到......但比现在更接近 1!(很接近,但不完全是,因为这是一个有限的差异。你可以从泰勒级数 1 (d/dx = 1/2) 左右看到这一点。然后,第二次舍入将结果影响另一个 ±0.5ULP。sqrt(a+b)*sqrt(1 ±2**-p)
sqrt(1 ±2**-p)
(1 ±2**-p)
(1 ±2**-(p+1))
这意味着我们保证与“真实”结果相差不超过 1 ULP。因此,如果我们能“只是”弄清楚如何选择,那么在两者之间进行选择的修复是一种可行的策略......{sqrt(a+b)-1ULP, sqrt(a+b), sqrt(a+b)+1ULP}
因此,让我们看看我们是否可以提出一种基于比较的方法,该方法在有限的精度下工作。(注:除非另有说明,否则以下为无限精度)
resy = float(sqrt(a+b))
resx = resy.prev_nearest()
resz = resy.next_nearest()
请注意,.resx < resy < resz
假设我们的浮点数有一点精度,那就变成了p
res = sqrt(a+b) // in infinite precision
resy = float(res)
resx = resy * (1 - 2**(1-p))
resz = resy * (1 + 2**(1-p))
因此,让我们比较一下:resx
resy
distx = abs(resx - res)
disty = abs(resy - res)
checkxy: distx < disty
checkxy: abs(resx - res) < abs(resy - res)
checkxy: (resx - res)**2 < (resy - res)**2
checkxy: resx**2 - 2*resx*res - res**2 < resy**2 - 2*resy*res - res**2
checkxy: resx**2 - resy**2 < 2*resx*res - 2*resy*res
checkxy: resx**2 - resy**2 < 2*res*(resx - resy)
// Assuming resx < resy
checkxy: resx+resy > 2*res
checkxy: resx+resy > 2*sqrt(a+b)
// Assuming resx+resy >= 0
checkxy: (resx+resy)**2 > 4*(a+b)
checkxy: (resy*(2 - 2**(1-p)))**2 > 4*(a+b)
checkxy: (resy**2)*((2 - 2**(1-p)))**2 > 4*(a+b)
checkxy: (resy**2)*(4 - 2*2**(1-p) + 2**(2-2p)) > 4*(a+b)
checkxy: (resy**2)*(4 - 4*2**(0-p) + 4*2**(0-2p)) > 4*(a+b)
checkxy: (resy**2)*(1 - 2**-p + 2**-2p) > a+b
...这是我们实际上可以在有限精度下进行的检查(尽管它仍然需要更高的精度,这很烦人)。
同上,因为我们得到checkyz
checkxy: disty < distz
checkyz: (resy**2)*(1 + 2**-p + 2**-2p) < a+b
从这两个检查中,您可以选择正确的结果。...然后它“只是”检查/处理我上面掩盖的边缘情况的问题。
现在,在实践中,我认为与一开始就以更高的精度进行sqrt相比,这是不值得的,至少除非有人能想出更好的选择方法。但这仍然是一个有趣的选择。
这是一个极端的例子。让我们看看 p 是浮点精度。u = 2^-p
我们有。(1+u)^2 = (1+2u) + u^2
如果我们取 ,我们有 ,a 是浮点数中的可表示值(它是 1 之后的下一个浮点数),并且 , , b 也可以表示为浮点数(作为 2^(-2p 的幂))。a = 1+2u
float(a)=a
b= u^2
float(b)=b
精确是 ,应该四舍五入为 ,由于精确的平局,它被四舍五入到最接近的甚至显着的......sqrt(a+b)
(1+u)
float(1+u)=1
float(a+b)=a
和 ,所以没关系。float(sqrt(a))=1
但是让我们改变 改变 b 的最后一点: ;,b 只是按比例缩小了两倍的精度。b=(1+2*u)*u^2
float(b)=b
我们现在有了确切的 ,因此它应该四舍五入为 。sqrt(a+b) > 1+u
float(sqrt(a+b)) = 1+2u
我们看到,稍微到 2^(-3p+1) 位(浮点精度的三倍)可以改变正确的舍入!
这意味着您不应依赖双精度来执行正确的舍入操作。
评论
sqrt(a)
b
sqrt(a+b)
a+b
sqrt(a+b)
2*sqrt(a/4 + b/4)
sqrt(a+b)