Asked by: kleite Asked: 10/13/2023 Last edited by: snakecharmerbkleite Updated: 10/18/2023 Views: 201
How to properly round to the nearest integer in double-double arithmetic
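Q:
I have to analyze a large amount of data using Python 3 (the PyPy implementation), in which I perform some operations on fairly large floating-point numbers and must check whether the results are close enough to an integer.
For example, say I'm generating random pairs of numbers and checking whether they form Pythagorean triples (that is, the sides of a right triangle with integer sides):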
from math import hypot
from pprint import pprint
from random import randrange
from time import time

def gen_rand_tuples(start, stop, amount):
    '''
    Generates random integer pairs and converts them to tuples of floats.
    '''
    for _ in range(amount):
        yield (float(randrange(start, stop)), float(randrange(start, stop)))

t0 = time()
## Results are those pairs that result in integer hypotenuses, or
## at least very close, to within 1e-12.
results = [t for t in gen_rand_tuples(1, 2**32, 10_000_000) if abs((h := hypot(*t)) - int(h)) < 1e-12]
print('Results found:')
pprint(results)
print('finished in:', round(time() - t0, 2), 'seconds.')
Running it, I get:
Python 3.9.17 (a61d7152b989, Aug 13 2023, 10:27:46)
[PyPy 7.3.12 with GCC 13.2.1 20230728 (Red Hat 13.2.1-1)] on linux
Type "help", "copyright", "credits" or "license()" for more information.
>>>
===== RESTART: /home/user/Downloads/pythagorean_test_floats.py ====
Results found:
[(2176124225.0, 2742331476.0),
(342847595.0, 3794647043.0),
(36.0, 2983807908.0),
(791324089.0, 2122279232.0)]
finished in: 2.64 seconds.
Interestingly, it runs fast, processing 10 million data points in a little over 2 seconds, and I even found some matches. The hypotenuses are apparently integers:
>>> pprint([hypot(*x) for x in results])
[3500842551.0, 3810103759.0, 2983807908.0, 2265008378.0]
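But that is not the case: if we check the results with the arbitrary-precision decimal module, we see that they are not actually close enough to integers: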
>>> from decimal import Decimal
>>> pprint([(x[0]*x[0] + x[1]*x[1]).sqrt() for x in (tuple(map(Decimal, x)) for x in results)])
[Decimal('3500842551.000000228516418075'),
Decimal('3810103758.999999710375341513'),
Decimal('2983807908.000000217172157183'),
Decimal('2265008377.999999748566051441')]
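So I think the problem is that the numbers are large enough to fall in a range where Python floats lack the precision, so false positives are returned.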
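To make the precision limit concrete, here is a minimal sketch using the first reported pair. It assumes `math.ulp` is available (Python 3.9 or later) and uses exact integer arithmetic for the comparison; the variable names are illustrative only.

```python
import math

# First "match" reported above, kept as exact Python integers.
x, y = 2176124225, 2742331476
h = 3500842551  # the value hypot() returned, as an integer

# Spacing between adjacent doubles near h.
spacing = math.ulp(float(h))

# Exact amount by which x*x + y*y exceeds h*h (integer arithmetic, no rounding).
excess = x * x + y * y - h * h

# First-order estimate of sqrt(x*x + y*y) - h.
offset = excess / (2 * h)

print(spacing)  # 4.76837158203125e-07, i.e. 2**-21
print(excess)   # 1600
print(offset)   # roughly 2.3e-07
```

The true hypotenuse sits only about 2.3e-07 above h, which is less than half the spacing between neighbouring doubles, so the nearest representable float is h itself and the 1e-12 test cannot see the difference.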
Now, we can change the program to use arbitrary-precision decimals everywhere:
from decimal import Decimal
from pprint import pprint
from random import randrange
from time import time

def dec_hypot(x, y):
    return (x*x + y*y).sqrt()

def gen_rand_tuples(start, stop, amount):
    '''
    Generates random integer pairs and converts them to tuples of decimals.
    '''
    for _ in range(amount):
        yield (Decimal(randrange(start, stop)), Decimal(randrange(start, stop)))

t0 = time()
## Results are those pairs that result in integer hypotenuses, or
## at least very close, to within 1e-12.
results = [t for t in gen_rand_tuples(1, 2**32, 10_000_000) if abs((h := dec_hypot(*t)) - h.to_integral_value()) < Decimal(1e-12)]
print('Results found:')
pprint(results)
print('finished in:', round(time() - t0, 2), 'seconds.')
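Now we get no false positives, but we take a big performance hit. What used to take a little over 2 seconds now takes over 100 seconds. Decimals don't seem to be JIT-friendly: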
====== RESTART: /home/user/Downloads/pythagorean_test_dec.py ======
Results found:
[]
finished in: 113.82 seconds.
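I found an answer to the question "CPython and PyPy Decimal operation performance" that suggests using double-double numbers as a faster, JIT-friendly alternative to decimals for getting better precision than the built-in floats. So I pip-installed the third-party doubledouble module and changed the program accordingly: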
from doubledouble import DoubleDouble
from decimal import Decimal
from pprint import pprint
from random import randrange
from time import time

def dd_hypot(x, y):
    return (x*x + y*y).sqrt()

def gen_rand_tuples(start, stop, amount):
    for _ in range(amount):
        yield (DoubleDouble(randrange(start, stop)), DoubleDouble(randrange(start, stop)))

t0 = time()
print('Results found:')
results = [t for t in gen_rand_tuples(1, 2**32, 10_000_000) if abs((h := dd_hypot(*t)) - int(h)) < DoubleDouble(1e-12)]
pprint(results)
print('finished in:', round(time() - t0, 2), 'seconds.')
But I get this error:
======= RESTART: /home/user/Downloads/pythagorean_test_dd.py ======
Results found:
Traceback (most recent call last):
  File "/home/user/Downloads/pythagorean_test_dd.py", line 24, in <module>
    results = [t for t in gen_rand_tuples(1, 2**32, 10_000_000) if abs((h := dd_hypot(*t)) - int(h)) < DoubleDouble(1e-12)]
  File "/home/user/Downloads/pythagorean_test_dd.py", line 24, in <listcomp>
    results = [t for t in gen_rand_tuples(1, 2**32, 10_000_000) if abs((h := dd_hypot(*t)) - int(h)) < DoubleDouble(1e-12)]
TypeError: int() argument must be a string, a bytes-like object or a number, not 'DoubleDouble'
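I think the problem is that the module doesn't define a conversion to int or a round-to-nearest-integer method. The best I could write is a very contrived "int" function that rounds a double-double to the nearest integer by round-tripping through strings and decimals and returning a DoubleDouble: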
def contrived_int(dd):
    rounded_str = (Decimal(dd.x) + Decimal(dd.y)).to_integral_value()
    hi = float(rounded_str)
    lo = float(Decimal(rounded_str) - Decimal(hi))
    return DoubleDouble(hi, lo)
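But it is very roundabout, defeats the purpose of avoiding decimals, and makes the program even slower than the all-decimal version.
So I ask: is there a fast way to round a double-double number directly to the nearest integer, without the intermediate step of going through decimals or strings?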
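For illustration, one way such a rounding could be done with float operations only is sketched below. It works on a raw (hi, lo) pair of floats rather than on the doubledouble module's own type, follows the general "round the high word, consult the low word on ties" approach used for double-double numbers, and is not part of that module's API. The function names are hypothetical, and the pair is assumed to be finite and normalized.

```python
def quick_two_sum(a, b):
    """Return (s, err) with s + err == a + b exactly, assuming |a| >= |b|."""
    s = a + b
    err = b - (s - a)
    return s, err


def dd_round(hi, lo):
    """Round the double-double value hi + lo to the nearest integer.

    Assumes a normalized, finite pair (|lo| is at most half an ulp of hi).
    Returns the result as a new (hi, lo) pair of floats.
    """
    r_hi = float(round(hi))
    if r_hi == hi:
        # hi is already an integer, so any fractional part lives in lo.
        r_lo = float(round(lo))
        return quick_two_sum(r_hi, r_lo)
    # hi has a fractional part; lo is then too small to matter except on a tie.
    diff = r_hi - hi
    if diff == 0.5 and lo < 0.0:
        r_hi -= 1.0  # hi was rounded up from x.5, but the true value is below x.5
    elif diff == -0.5 and lo > 0.0:
        r_hi += 1.0  # hi was rounded down from x.5, but the true value is above x.5
    return r_hi, 0.0


def dd_distance_to_int(hi, lo):
    """Absolute distance from hi + lo to its nearest integer, as a float."""
    r_hi, r_lo = dd_round(hi, lo)
    return abs((hi - r_hi) + (lo - r_lo))
```

With the doubledouble module this might be called as `dd_distance_to_int(h.x, h.y) < 1e-12`, assuming `x` and `y` hold the high and low words as they do in `contrived_int` above. Exact ties (a value of precisely k + 0.5) may resolve in either direction.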
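A:
Not an answer to the question you asked directly, but here is at least one way to check whether an integer of any size is a perfect square (I'm sure there are faster ways, but this should at least always work, and it has logarithmic complexity):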
def is_square(n):
    low = 0
    high = 1
    while high * high <= n:
        low = high
        high *= 2
    while low < high:
        mid = (low + high) >> 1
        if mid * mid == n:
            return True
        if mid * mid > n:
            high = mid
        else:
            low = mid + 1
    return False
It is just doing a binary search.
Comments
math.isqrt()
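math.modf splits a number into its fractional and integer parts, and you can compare the fractional part against the threshold. You can also move this check into gen_rand_tuples to reduce overhead (not by much, but it still helps):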
import math
from pprint import pprint
from random import randrange
from time import time

from doubledouble import DoubleDouble

def dd_hypot(x, y):
    return (x * x + y * y).sqrt()

def gen_rand_tuples(start, stop, amount):
    t = DoubleDouble(1e-12)
    for _ in range(amount):
        x, y = DoubleDouble(randrange(start, stop)), DoubleDouble(randrange(start, stop))
        # Note: math.modf() converts its argument to a plain float first.
        if math.modf(dd_hypot(x, y))[0] < t:
            yield float(x), float(y)

t0 = time()
# gen_rand_tuples is a generator, so no work happens until it is consumed below.
results = gen_rand_tuples(1, 2 ** 32, 10_000_000)
print('results found in', round(time() - t0, 2), 'seconds:')
pprint([t for t in results])
print('finished in:', round(time() - t0, 2), 'seconds.')
Output:
results found in 0.0 seconds:
[(680368648.0, 3711917722.0),
(3725230685.0, 4052331950.0),
(3105505826.0, 4185910333.0),
(4149112881.0, 1954134663.0),
(2526797500.0, 3295693164.0),
(1386817952.0, 1040113474.0)]
finished in: 49.76 seconds.
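As a small illustration of what math.modf returns, and of one way to measure the distance to the nearest integer on either side, here is a sketch on plain floats. The helper name `distance_to_nearest_int` is hypothetical, and because it works on floats it inherits their precision limits.

```python
import math


def distance_to_nearest_int(value):
    """Distance from a float to the closest integer, in the range [0, 0.5]."""
    frac, _whole = math.modf(value)  # modf returns (fractional, integer) parts
    frac = abs(frac)
    return min(frac, 1.0 - frac)


print(math.modf(3.75))                # (0.75, 3.0)
print(distance_to_nearest_int(3.75))  # 0.25
print(distance_to_nearest_int(4.0))   # 0.0
```

A value just below an integer, such as 2.9999999, has a fractional part close to 1, so a bare `frac < threshold` test would not flag it, while the symmetric distance does.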
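Comments
One thing you could try is using multiprocessing to generate the random pairs and test them.
In the code below, multiprocessing.cpu_count() is called to determine pool_size, the number of processes the multiprocessing pool should contain. We then submit pool_size tasks, passing to each task n_pairs, the number of pairs the task should generate and test, which is N_TUPLETS // pool_size (where N_TUPLETS is the total number of pairs to generate and test). The results of these pool_size tasks are accumulated into results.
Obviously, the more CPU cores you have, the greater the reduction in time. Also, multiprocessing only improves performance when the number of pairs each task generates and tests is large enough that the overhead of multiprocessing is compensated for by running the tasks in parallel, which is the case here: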
from decimal import Decimal
from random import randrange
from pprint import pprint
from time import time

N_TUPLETS = 10_000_000

def dec_hypot(x, y):
    return (x*x + y*y).sqrt()

def gen_rand_tuples(start, stop, amount):
    '''
    Generates random integer pairs and converts them to tuples of decimals.
    '''
    for _ in range(amount):
        yield (Decimal(randrange(start, stop)), Decimal(randrange(start, stop)))

def generate_and_test_tuplets(n_pairs):
    ## Results are those pairs that result in integer hypotenuses, or
    ## at least very close, to within 1e-12.
    return [t for t in gen_rand_tuples(1, 2**32, n_pairs) if abs((h := dec_hypot(*t)) - h.to_integral_value()) < Decimal(1e-12)]

def serial_test():
    t0 = time()
    results = generate_and_test_tuplets(N_TUPLETS)
    print('Serial results found:')
    pprint(results)
    print('finished in:', round(time() - t0, 2), 'seconds.')

def parallel_test():
    from multiprocessing import Pool, cpu_count

    t0 = time()
    pool_size = cpu_count()
    print('The pool size is', pool_size)
    n_pairs_list = [N_TUPLETS // pool_size] * (pool_size - 1)
    n_pairs_list.append(N_TUPLETS - sum(n_pairs_list))
    with Pool(pool_size) as pool:
        results = []
        for result in pool.imap_unordered(generate_and_test_tuplets, n_pairs_list):
            results.extend(result)
    print('Parallel results found:')
    pprint(results)
    print('finished in:', round(time() - t0, 2), 'seconds.')

if __name__ == '__main__':
    serial_test()
    print()
    parallel_test()
Prints:
Serial results found:
[]
finished in: 78.92 seconds.
The pool size is 8
Parallel results found:
[]
finished in: 16.83 seconds.
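With 8 logical cores (4 physical cores), the multiprocessing version reduced the running time by roughly a factor of 5 (78.92 s down to 16.83 s). Nothing prevents you from also incorporating the improvements suggested in the other answers.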
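As a rough sketch of what such a combination might look like, the following pairs the pool-based work split with an exact integer test based on math.isqrt instead of Decimal. The function names `split_work`, `find_integer_hypotenuses` and `parallel_search` are hypothetical, and no timing is claimed for it.

```python
import math
from multiprocessing import Pool, cpu_count
from random import randrange


def split_work(total, parts):
    """Split total into `parts` chunk sizes that add up to total."""
    base, extra = divmod(total, parts)
    return [base + 1 if i < extra else base for i in range(parts)]


def find_integer_hypotenuses(n_pairs):
    """Generate n_pairs random integer pairs; keep those with an integer hypotenuse."""
    found = []
    for _ in range(n_pairs):
        x, y = randrange(1, 2**32), randrange(1, 2**32)
        s = x * x + y * y
        if math.isqrt(s) ** 2 == s:  # exact perfect-square test on integers
            found.append((x, y))
    return found


def parallel_search(total, pool_size=None):
    """Run find_integer_hypotenuses across a pool of worker processes."""
    pool_size = pool_size or cpu_count()
    results = []
    with Pool(pool_size) as pool:
        chunks = split_work(total, pool_size)
        for chunk in pool.imap_unordered(find_integer_hypotenuses, chunks):
            results.extend(chunk)
    return results
```

As in the code above, the call would go under an `if __name__ == '__main__':` guard, for example `print(parallel_search(10_000_000))`.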
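Comments
Since Python integers have no upper bound and you are looking for integral results, you should stick with integer inputs and integer arithmetic. In your example, you can use math.isqrt to perform an integer square root and avoid floating-point imprecision entirely: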
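import math
# Assumes gen_rand_tuples yields plain int pairs, since math.isqrt only accepts integers.
results = [
    (x, y) for x, y in gen_rand_tuples(1, 2 ** 32, 10_000_000)
    if (s := x * x + y * y) == math.isqrt(s) ** 2
]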
In testing, this is about as fast as the first attempt with floating-point arithmetic, but without any imprecision:
Demo: Try it online!
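To see the integer test in isolation, here is a small self-contained sketch that applies it to a known triple and to two of the false positives from the question. The helper name `has_integer_hypotenuse` is hypothetical.

```python
import math


def has_integer_hypotenuse(x, y):
    """True if sqrt(x*x + y*y) is exactly an integer (x and y must be ints)."""
    s = x * x + y * y
    return math.isqrt(s) ** 2 == s


print(has_integer_hypotenuse(3, 4))                    # True (the 3-4-5 triangle)
print(has_integer_hypotenuse(2176124225, 2742331476))  # False (a float false positive)
print(has_integer_hypotenuse(36, 2983807908))          # False
```

Because every step stays in integer arithmetic, no tolerance such as 1e-12 is needed.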
Comments