提问人:Ξένη Γήινος 提问时间:9/26/2023 最后编辑:Ξένη Γήινος 更新时间:9/26/2023 访问量:250
在不使用 cmath 的情况下在 C++ 中计算第 n 个根的有效方法
Efficient ways to compute nth root in C++ without using cmath
问:
如何在不使用 cmath 等的情况下有效地将数字的 n 次根计算到至少 12 位正确的小数位?
我试着自己解决它。我的想法是找到一个近似值,并使用牛顿的方法使近似值更准确。
我实现了 2 种方法,一种使用二叉搜索,另一种基于快速平方反根算法。
#include <array>
#include <chrono>
#include <cmath>
#include <iostream>
#include <vector>
using std::chrono::steady_clock;
using std::chrono::duration;
using std::cout;
using std::vector;
float r = 0.0;
inline float power(float base, int exp) {
if (not exp) {
return 1.0;
}
if (exp < 0) {
base = 1 / base;
exp = -exp;
}
float p = 1.0;
while (exp > 1) {
if (exp % 2) {
p = base * p;
}
base *= base;
exp /= 2;
}
return base * p;
}
inline float nth_root(float base, int n) {
float lo, hi, x, p, r, v;
int n1 = n - 1;
lo = 0;
hi = base;
for (int i = 0; i < 12; i++) {
x = (lo + hi) / 2;
p = power(x, n1);
v = x * p - base;
r = n * p;
if (v <= 0) {
lo = x;
}
else {
hi = x;
}
}
x = (lo + hi) / 2;
r = 1.0 / n;
for (int i = 0; i < 12; i++) {
x = r * (n * x + base / power(x, n));
}
return x;
}
inline float fast_nth_root(float base, int n)
{
uint32_t i = std::bit_cast<uint32_t>(base);
float rn = 1.0 / n;
i = 0x3F7A3BEA * rn * (n + 1) - i * rn;
float x = std::bit_cast<float>(i);
for (int j = 0; j < 6; j++) {
x = x * (n + 1 - base * power(x, n)) * rn;
}
return 1.0 / x;
}
int main()
{
vector<float> bases(256);
vector<int> ns(256);
float r256 = 1.0 / 256;
for (int i = 0; i < 256; i++) {
bases[i] = 1.0 + rand() % 16384 + (rand() % 256) * r256;
ns[i] = 2 + rand() % 30;
}
auto start = steady_clock::now();
for (int64_t i = 0; i < 1048576; i++) {
r += nth_root(bases[i % 256], ns[i % 256]);
}
auto end = steady_clock::now();
duration<double, std::nano> time = end - start;
cout << "nth_root: " << time.count() / 1048576 << " nanoseconds\n";
start = steady_clock::now();
for (int64_t i = 0; i < 1048576; i++) {
r += pow(bases[i % 256], 1.0 / ns[i % 256]);
}
end = steady_clock::now();
time = end - start;
cout << "pow: " << time.count() / 1048576 << " nanoseconds\n";
start = steady_clock::now();
for (int64_t i = 0; i < 1048576; i++) {
r += fast_nth_root(bases[i % 256], ns[i % 256]);
}
end = steady_clock::now();
time = end - start;
cout << "fast_nth_root: " << time.count() / 1048576 << " nanoseconds\n";
}
编译方式:
g++.exe -Wall -fexceptions -fomit-frame-pointer -fexpensive-optimizations -flto -O3 -m64 --std=c++20 -march=native -ffast-math -c D:\MyScript\CodeBlocks\testapp\main.cpp -o obj\Release\main.o
g++.exe -o bin\Release\testapp.exe obj\Release\main.o -O3 -flto -s -static-libstdc++ -static-libgcc -static -m64
PS C:\Users\Xeni> D:\MyScript\CodeBlocks\testapp\bin\Release\testapp.exe
nth_root: 318.12 nanoseconds
pow: 104.77 nanoseconds
fast_nth_root: 53.4222 nanoseconds
正如预期的那样,第一种方法非常慢,但尽管第二种方法比库代码快,但它可能不那么准确。
根据我在 Python 中的测试:
import random, struct
def root(num, p, lim, lin):
lo = 0
hi = num
for _ in range(lim):
x = (lo + hi) / 2
po = x ** (p - 1)
v = x * po - num
r = p * po
if v <= 0:
lo = x
else:
hi = x
x = (lo + hi) / 2
r = 1 / p
p -= 1
for _ in range(lin):
x = r * (p * x + num / x ** p)
return x
stats = {}
for _ in range(64):
n = random.randrange(1, 16384)
p = random.randrange(2, 32)
x = n ** (1 / p)
i = 2
j = 2
b = 0
while abs((y := root(n, p, i, j)) - x) > 1e-13:
if i < 16 and b or i >= 16:
j += 1
else:
i += 1
b = not b
stats[(n, p)] = (i, j, x, y)
def fast_nth_root(x: float, n: int, lim: int) -> float:
t = int.from_bytes(struct.pack('>f', x), 'big')
t = round(0x3F7A3BEA / n * (n + 1) - t / n)
t = struct.unpack('>f', struct.pack('>i', t))[0]
for _ in range(lim):
t = t * (n + 1 - x * t ** n) / n
return 1 / t
stats1 = {}
for _ in range(64):
n = random.randrange(1, 16384)
p = random.randrange(2, 32)
x = n ** (1 / p)
i = 2
while abs((y := fast_nth_root(n, p, i)) - x) > 1e-13:
i += 1
stats1[(n, p)] = (i, x, y)
In [1434]: stats
Out[1434]:
{(7162, 13): (11, 10, 1.979434344458145, 1.979434344458145),
(15510, 2): (5, 5, 124.53915047084591, 124.53915047084595),
(3054, 25): (10, 9, 1.3784618821753079, 1.3784618821753076),
(1601, 25): (9, 8, 1.3433081611539914, 1.3433081611540099),
(3522, 21): (10, 9, 1.475348878994169, 1.475348878994169),
(15107, 14): (12, 11, 1.9884410975573956, 1.9884410975573954),
(15200, 16): (12, 11, 1.8254301706001883, 1.825430170600188),
(1900, 15): (9, 8, 1.6541830766301984, 1.6541830766301981),
(16145, 20): (12, 11, 1.6233116388185762, 1.623311638818576),
(2580, 4): (7, 6, 7.126969930959522, 7.126969930959522),
(1702, 27): (11, 10, 1.3172407839773748, 1.3172407839773748),
(9875, 29): (13, 13, 1.3732280275029944, 1.3732280275029942),
(15687, 15): (12, 11, 1.9041565885196923, 1.904156588519692),
(5774, 16): (12, 11, 1.718273525571221, 1.718273525571221),
(6186, 2): (5, 4, 78.65112840894274, 78.65112840894277),
(4476, 23): (12, 11, 1.441233509663115, 1.441233509663115),
(4161, 24): (12, 11, 1.4151416228042228, 1.4151416228042226),
(16116, 13): (12, 11, 2.1068575543742956, 2.1068575543742956),
(12380, 14): (11, 11, 1.9603661231094314, 1.9603661231094311),
(9736, 19): (13, 12, 1.621491836717726, 1.6214918367177258),
(8612, 26): (13, 12, 1.4169357394205302, 1.4169357394205302),
(4586, 7): (9, 8, 3.334740217355978, 3.3347402173559777),
(5232, 24): (12, 12, 1.428711330576587, 1.4287113305765868),
(14698, 17): (12, 11, 1.7584613929955697, 1.7584613929955695),
(4931, 13): (10, 9, 1.923410237452901, 1.9234102374529014),
(7391, 4): (8, 7, 9.272050761175075, 9.272050761175075),
(9949, 6): (9, 9, 4.637635073009885, 4.637635073009886),
(4767, 18): (12, 11, 1.6008364077669808, 1.6008364077669806),
(16318, 8): (11, 10, 3.3618889684623863, 3.3618889684623863),
(7610, 28): (13, 12, 1.3760077520394016, 1.3760077520394014),
(13632, 6): (10, 9, 4.887573066390476, 4.887573066390476),
(8380, 21): (11, 11, 1.5375213098103222, 1.5375213098103224),
(7247, 14): (11, 10, 1.88679879448582, 1.8867987944858202),
(11343, 18): (13, 12, 1.6798196720467486, 1.6798196720467486),
(6468, 17): (11, 10, 1.6755714114675964, 1.6755714114675964),
(11801, 6): (10, 9, 4.771480299610415, 4.771480299610415),
(441, 28): (9, 8, 1.2429230307022932, 1.2429230307022932),
(15341, 14): (12, 11, 1.9906254301097404, 1.9906254301097404),
(8501, 20): (11, 10, 1.5720758667453518, 1.5720758667453518),
(2777, 19): (10, 10, 1.5178918732605664, 1.5178918732605664),
(14842, 30): (14, 13, 1.3773672345540857, 1.3773672345540857),
(6149, 28): (11, 10, 1.3655715058830975, 1.3655715058830973),
(13374, 21): (12, 11, 1.5721306482454025, 1.5721306482454025),
(9947, 30): (13, 13, 1.3591156205784112, 1.359115620578415),
(14423, 16): (12, 11, 1.8194535606369682, 1.8194535606369682),
(9341, 31): (13, 12, 1.3430036888404402, 1.3430036888404402),
(14558, 7): (10, 9, 3.9330441035217714, 3.9330441035217714),
(152, 16): (8, 7, 1.3688795144738382, 1.368879514473838),
(13593, 18): (12, 11, 1.6967920812890351, 1.6967920812890351),
(2834, 7): (8, 8, 3.113149507653915, 3.1131495076539637),
(11545, 14): (11, 10, 1.950612466604169, 1.9506124666041689),
(12416, 21): (12, 11, 1.566576147387506, 1.5665761473875057),
(8998, 6): (9, 8, 4.560624662646501, 4.560624662646504),
(5245, 27): (11, 10, 1.3733092699013858, 1.3733092699013856),
(5693, 29): (11, 10, 1.3473937347050562, 1.3473937347050562),
(3508, 26): (10, 10, 1.3688266297365765, 1.368826629736599),
(16237, 9): (11, 10, 2.936526854011741, 2.936526854011741),
(2911, 6): (8, 7, 3.778682197915392, 3.7786821979153924),
(387, 22): (10, 9, 1.311061987204142, 1.311061987204142),
(3324, 4): (7, 7, 7.593032412784651, 7.593032412784651),
(15300, 22): (12, 11, 1.5495773040868939, 1.5495773040868939),
(5469, 26): (11, 10, 1.3924053694535468, 1.392405369453547),
(1195, 13): (8, 8, 1.7247279767538015, 1.724727976753898),
(7998, 13): (11, 10, 1.9963162339549785, 1.9963162339549787)}
In [1435]: stats1
Out[1435]:
{(10882, 3): (4, 22.159990703206965, 22.159990703206965),
(6673, 28): (6, 1.3695657835909767, 1.3695657835909767),
(4803, 10): (4, 2.3342709144708604, 2.3342709144708604),
(1802, 27): (5, 1.3200291160996098, 1.3200291160996294),
(8380, 15): (4, 1.8262053100463662, 1.826205310046366),
(12898, 21): (5, 1.569419919895426, 1.5694199198954262),
(10227, 2): (4, 101.12863096077193, 101.12863096077193),
(4857, 25): (5, 1.4042832772764529, 1.4042832772765208),
(1351, 12): (4, 1.8234251715932501, 1.8234251715932501),
(10180, 16): (3, 1.7802632882832108, 1.7802632882832108),
(6948, 28): (6, 1.37154252932099, 1.3715425293209902),
(13901, 10): (4, 2.5959994991170756, 2.5959994991171),
(7513, 21): (5, 1.529545998990047, 1.529545998990047),
(7902, 18): (4, 1.6464211857566509, 1.6464211857566777),
(3277, 31): (6, 1.2983819395882321, 1.2983819395882321),
(3499, 10): (4, 2.261488555258189, 2.2614885552581887),
(15234, 30): (6, 1.3785646304878574, 1.3785646304878574),
(5739, 13): (5, 1.9459928850146935, 1.9459928850146935),
(1823, 24): (5, 1.3673072329614473, 1.367307232961453),
(15105, 16): (4, 1.8247150144295399, 1.8247150144295399),
(16215, 12): (3, 2.2429852247131876, 2.2429852247131876),
(15844, 20): (5, 1.6217848596088165, 1.6217848596088162),
(15677, 26): (5, 1.4499608216310222, 1.4499608216310922),
(11839, 22): (5, 1.5316187797414815, 1.5316187797414818),
(10163, 26): (6, 1.4259891723870766, 1.4259891723870766),
(1550, 18): (5, 1.5039751132184096, 1.5039751132184096),
(15194, 5): (4, 6.8601628355219795, 6.860162835521979),
(15612, 24): (5, 1.4952969217858556, 1.495296921785857),
(9469, 12): (4, 2.1446611088057317, 2.144661108805732),
(4030, 20): (5, 1.5144859625611744, 1.5144859625611742),
(11729, 3): (4, 22.720627875592783, 22.72062787559279),
(12709, 29): (6, 1.3852274277113097, 1.3852274277113097),
(12263, 31): (6, 1.354846884624788, 1.354846884624788),
(6372, 9): (4, 2.6466548424778766, 2.6466548424779086),
(7119, 5): (4, 5.8949998528103436, 5.8949998528103436),
(10737, 27): (6, 1.4102365332846034, 1.4102365332846034),
(2231, 11): (4, 2.01562182326949, 2.0156218232694902),
(412, 9): (4, 1.952289125066342, 1.9522891250663446),
(8417, 5): (4, 6.095810609109851, 6.095810609109851),
(6759, 31): (6, 1.3290600231725829, 1.3290600231725826),
(2207, 23): (5, 1.3975994148304356, 1.3975994148304371),
(4755, 16): (4, 1.6975473914419268, 1.697547391441927),
(7978, 13): (4, 1.995931787083301, 1.9959317870833602),
(14957, 19): (4, 1.658550339552885, 1.658550339552935),
(745, 28): (5, 1.2664178106847905, 1.2664178106847914),
(2696, 15): (4, 1.6932249497186587, 1.6932249497186587),
(5484, 7): (4, 3.421029175738748, 3.421029175738748),
(15410, 10): (4, 2.6228911266357167, 2.6228911266357264),
(315, 10): (3, 1.7775877772276876, 1.7775877772276876),
(9252, 13): (4, 2.0188081438649, 2.0188081438649035),
(1562, 27): (5, 1.313059731029047, 1.3130597310290748),
(9803, 8): (4, 3.1544225980493534, 3.154422598049354),
(14443, 16): (4, 1.8196111450479415, 1.8196111450479413),
(5033, 23): (5, 1.4486017214956801, 1.4486017214956872),
(16175, 2): (4, 127.18097341976905, 127.18097341976903),
(15125, 5): (4, 6.853920721113647, 6.853920721113645),
(16292, 19): (4, 1.6660301757087943, 1.666030175708797),
(11486, 20): (5, 1.5959101637580406, 1.5959101637580406),
(13824, 7): (4, 3.9040835527337374, 3.9040835527337383),
(8604, 3): (4, 20.491172086350442, 20.491172086350442),
(1225, 3): (4, 10.699874805650794, 10.699874805650795),
(9163, 21): (5, 1.54407524840236, 1.5440752484023603),
(7833, 21): (5, 1.53258704058841, 1.53258704058841),
(8425, 24): (5, 1.4573551932369144, 1.4573551932369178)}
平均而言,第一种方法大约需要 12 次二叉搜索迭代和 12 次牛顿方法迭代才能使误差低于 10-13,而第二种方法需要 6 次牛顿方法迭代才能获得相同的精度。
有没有办法让代码在相同次数的迭代中运行得更快,或者有办法加快所涉及的数学的收敛速度?
这不是一项任务。这是一个自我强加的编程挑战。
答:
使用二叉搜索来获得初始猜测并不是特别有效。nth_root
在进行正确的牛顿方法实现之前,您可以对浮点表示使用类似于 中的技巧。fast_nth...
像这样:
constexpr int32_t ONE_BITS = std::bit_cast<int32_t>(1.0f);
inline float nth_root(float base, int n) {
int32_t x_bits = std::bit_cast<int32_t>(base);
x_bits = (x_bits - ONE_BITS)/n + ONE_BITS;
// first guess
float x = std::bit_cast<float>(x_bits);
这样做的重要效果只是它将浮点指数除以 。n
评论
cmath
std::numeric_limits<double>::is_iec559;
你没有说你的函数应该涵盖什么范围的值,也没有说这些值的统计分布,所以我们必须建议一种通用的方法。
该策略是首先通过对数量级进行工作,为牛顿迭代找到一个好的初始近似值。
如果您可以访问浮点表示的指数(通过破解二进制表示),则一个好的起始值是 √2/2 乘以 2^(exponent/n)。我们选择 √2/2 而不是 1 以 [1/2, 1] 为中心。为了提高效率,您可以预先计算所有指数和所有根阶的这些常量。无论如何,为了节省空间,最好将指数分解为整数商和余数中的 n 上的指数。
如果指数不可用,那么你可以通过连续的加倍(或减半)来搜索它,从 1 开始(所以 1=2^0、2=2^1、4=2^2、8=2^3...)。这相当于指数之间的线性搜索。然而,更有效的是使用平方,并在指数之间实现指数搜索(每次的幂都加倍,2=2^1、4=2^2、16=2^4、256=2^8......找到可能的指数范围后,恢复为线性搜索。您还可以使用预先计算的值进行优化。
最后,你可以从牛顿的迭代开始。对于平方根的情况,您可以使用平方根反函数以避免除法。不幸的是,这并不能推广到高阶。
最后但并非最不重要的一点是,确定在最坏情况下(使用最差的初始近似值)所需的迭代次数可能是有益的,并且始终使用此数字,而不是测试具有特定容差的收敛性。
评论
power
std::pow
std::bit_cast<uint32_t>(base)
float
double
r
nth_root
main