Efficient ways to compute nth root in C++ without using cmath

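Asked by: Ξένη Γήινος  Asked: 9/26/2023  Last edited by: Ξένη Γήινος  Updated: 9/26/2023  Views: 250

Q:

How can I efficiently compute the nth root of a number, correct to at least 12 decimal places, without using cmath and the like?

I tried to solve it myself. My idea is to find an approximation and then use Newton's method to refine it.

I implemented two methods: one uses binary search, the other is based on the fast inverse square root algorithm.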
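
For reference, the standard Newton updates for the two formulations can be written out as follows. The notation is introduced here for illustration and is not part of the original post: $a$ is the number whose root is wanted, $n$ is the root order, $x_k$ approximates the root and $y_k$ approximates its reciprocal.

```latex
% Direct form: solve f(x) = x^n - a = 0
x_{k+1} = x_k - \frac{x_k^{n} - a}{n\,x_k^{n-1}}
        = \frac{1}{n}\left((n-1)\,x_k + \frac{a}{x_k^{n-1}}\right)

% Inverse form: solve g(y) = y^{-n} - a = 0, then take 1/y
y_{k+1} = y_k - \frac{y_k^{-n} - a}{-n\,y_k^{-n-1}}
        = \frac{y_k\left(n + 1 - a\,y_k^{n}\right)}{n},
\qquad \sqrt[n]{a} = \frac{1}{\lim_{k\to\infty} y_k}
```

The inverse form needs no division inside the loop, only one at the end.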

#include <array>
#include <bit>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <vector>

using std::chrono::steady_clock;
using std::chrono::duration;
using std::cout;
using std::vector;
float r = 0.0;

inline float power(float base, int exp) {
    if (not exp) {
        return 1.0;
    }
    if (exp < 0) {
        base = 1 / base;
        exp = -exp;
    }
    float p = 1.0;
    while (exp > 1) {
        if (exp % 2) {
            p = base * p;
        }
        base *= base;
        exp /= 2;
    }
    return base * p;
}

inline float nth_root(float base, int n) {
    float lo, hi, x, p, r, v;
    int n1 = n - 1;
    lo = 0;
    hi = base;
    for (int i = 0; i < 12; i++) {
        x = (lo + hi) / 2;
        p = power(x, n1);
        v = x * p - base;
        r = n * p;
        if (v <= 0) {
            lo = x;
        }
        else {
            hi = x;
        }
    }
    x = (lo + hi) / 2;
    r = 1.0 / n;
    for (int i = 0; i < 12; i++) {
        x = r * (n1 * x + base / power(x, n1));  // Newton step: ((n-1)x + base/x^(n-1)) / n
    }
    return x;
}

inline float fast_nth_root(float base, int n)
{
   uint32_t i = std::bit_cast<uint32_t>(base);
   float rn = 1.0 / n;
   i = 0x3F7A3BEA * rn * (n + 1) - i * rn;
   float x = std::bit_cast<float>(i);
   for (int j = 0; j < 6; j++) {
       x = x * (n + 1 - base * power(x, n)) * rn;
   }
   return 1.0 / x;
}

int main()
{
    vector<float> bases(256);
    vector<int> ns(256);
    float r256 = 1.0 / 256;
    for (int i = 0; i < 256; i++) {
        bases[i] = 1.0 + rand() % 16384 + (rand() % 256) * r256;
        ns[i] = 2 + rand() % 30;
    }
    auto start = steady_clock::now();
    for (int64_t i = 0; i < 1048576; i++) {
        r += nth_root(bases[i % 256], ns[i % 256]);
    }
    auto end = steady_clock::now();
    duration<double, std::nano> time = end - start;
    cout << "nth_root: " << time.count() / 1048576 << " nanoseconds\n";
    start = steady_clock::now();
    for (int64_t i = 0; i < 1048576; i++) {
        r += pow(bases[i % 256], 1.0 / ns[i % 256]);
    }
    end = steady_clock::now();
    time = end - start;
    cout << "pow: " << time.count() / 1048576 << " nanoseconds\n";
    start = steady_clock::now();
    for (int64_t i = 0; i < 1048576; i++) {
        r += fast_nth_root(bases[i % 256], ns[i % 256]);
    }
    end = steady_clock::now();
    time = end - start;
    cout << "fast_nth_root: " << time.count() / 1048576 << " nanoseconds\n";
}

Compiled with:

g++.exe -Wall -fexceptions -fomit-frame-pointer -fexpensive-optimizations -flto -O3 -m64 --std=c++20 -march=native -ffast-math  -c D:\MyScript\CodeBlocks\testapp\main.cpp -o obj\Release\main.o
g++.exe  -o bin\Release\testapp.exe obj\Release\main.o  -O3 -flto -s -static-libstdc++ -static-libgcc -static -m64  
PS C:\Users\Xeni> D:\MyScript\CodeBlocks\testapp\bin\Release\testapp.exe
nth_root: 318.12 nanoseconds
pow: 104.77 nanoseconds
fast_nth_root: 53.4222 nanoseconds

As expected, the first method is very slow. The second is faster than the library code, but it is probably less accurate.

According to my tests in Python:

import random, struct

def root(num, p, lim, lin):
    lo = 0
    hi = num
    for _ in range(lim):
        x = (lo + hi) / 2
        po = x ** (p - 1)
        v = x * po - num
        r = p * po
        if v <= 0:
            lo = x
        else:
            hi = x

    x = (lo + hi) / 2
    r = 1 / p
    p -= 1
    for _ in range(lin):
        x = r * (p * x + num / x ** p)
    
    return x

stats = {}
for _ in range(64):
    n = random.randrange(1, 16384)
    p = random.randrange(2, 32)
    x = n ** (1 / p)
    i = 2
    j = 2
    b = 0
    while abs((y := root(n, p, i, j)) - x) > 1e-13:
        if i < 16 and b or i >= 16:
            j += 1
        else:
            i += 1
        b = not b

    stats[(n, p)] = (i, j, x, y)

def fast_nth_root(x: float, n: int, lim: int) -> float:
    t = int.from_bytes(struct.pack('>f', x), 'big')
    t = round(0x3F7A3BEA / n * (n + 1) - t / n)
    t = struct.unpack('>f', struct.pack('>i', t))[0]
    for _ in range(lim):
        t = t * (n + 1 - x * t ** n) / n
    
    return 1 / t


stats1 = {}
for _ in range(64):
    n = random.randrange(1, 16384)
    p = random.randrange(2, 32)
    x = n ** (1 / p)
    i = 2
    while abs((y := fast_nth_root(n, p, i)) - x) > 1e-13:
        i += 1

    stats1[(n, p)] = (i, x, y)
In [1434]: stats
Out[1434]:
{(7162, 13): (11, 10, 1.979434344458145, 1.979434344458145),
 (15510, 2): (5, 5, 124.53915047084591, 124.53915047084595),
 (3054, 25): (10, 9, 1.3784618821753079, 1.3784618821753076),
 (1601, 25): (9, 8, 1.3433081611539914, 1.3433081611540099),
 (3522, 21): (10, 9, 1.475348878994169, 1.475348878994169),
 (15107, 14): (12, 11, 1.9884410975573956, 1.9884410975573954),
 (15200, 16): (12, 11, 1.8254301706001883, 1.825430170600188),
 (1900, 15): (9, 8, 1.6541830766301984, 1.6541830766301981),
 (16145, 20): (12, 11, 1.6233116388185762, 1.623311638818576),
 (2580, 4): (7, 6, 7.126969930959522, 7.126969930959522),
 (1702, 27): (11, 10, 1.3172407839773748, 1.3172407839773748),
 (9875, 29): (13, 13, 1.3732280275029944, 1.3732280275029942),
 (15687, 15): (12, 11, 1.9041565885196923, 1.904156588519692),
 (5774, 16): (12, 11, 1.718273525571221, 1.718273525571221),
 (6186, 2): (5, 4, 78.65112840894274, 78.65112840894277),
 (4476, 23): (12, 11, 1.441233509663115, 1.441233509663115),
 (4161, 24): (12, 11, 1.4151416228042228, 1.4151416228042226),
 (16116, 13): (12, 11, 2.1068575543742956, 2.1068575543742956),
 (12380, 14): (11, 11, 1.9603661231094314, 1.9603661231094311),
 (9736, 19): (13, 12, 1.621491836717726, 1.6214918367177258),
 (8612, 26): (13, 12, 1.4169357394205302, 1.4169357394205302),
 (4586, 7): (9, 8, 3.334740217355978, 3.3347402173559777),
 (5232, 24): (12, 12, 1.428711330576587, 1.4287113305765868),
 (14698, 17): (12, 11, 1.7584613929955697, 1.7584613929955695),
 (4931, 13): (10, 9, 1.923410237452901, 1.9234102374529014),
 (7391, 4): (8, 7, 9.272050761175075, 9.272050761175075),
 (9949, 6): (9, 9, 4.637635073009885, 4.637635073009886),
 (4767, 18): (12, 11, 1.6008364077669808, 1.6008364077669806),
 (16318, 8): (11, 10, 3.3618889684623863, 3.3618889684623863),
 (7610, 28): (13, 12, 1.3760077520394016, 1.3760077520394014),
 (13632, 6): (10, 9, 4.887573066390476, 4.887573066390476),
 (8380, 21): (11, 11, 1.5375213098103222, 1.5375213098103224),
 (7247, 14): (11, 10, 1.88679879448582, 1.8867987944858202),
 (11343, 18): (13, 12, 1.6798196720467486, 1.6798196720467486),
 (6468, 17): (11, 10, 1.6755714114675964, 1.6755714114675964),
 (11801, 6): (10, 9, 4.771480299610415, 4.771480299610415),
 (441, 28): (9, 8, 1.2429230307022932, 1.2429230307022932),
 (15341, 14): (12, 11, 1.9906254301097404, 1.9906254301097404),
 (8501, 20): (11, 10, 1.5720758667453518, 1.5720758667453518),
 (2777, 19): (10, 10, 1.5178918732605664, 1.5178918732605664),
 (14842, 30): (14, 13, 1.3773672345540857, 1.3773672345540857),
 (6149, 28): (11, 10, 1.3655715058830975, 1.3655715058830973),
 (13374, 21): (12, 11, 1.5721306482454025, 1.5721306482454025),
 (9947, 30): (13, 13, 1.3591156205784112, 1.359115620578415),
 (14423, 16): (12, 11, 1.8194535606369682, 1.8194535606369682),
 (9341, 31): (13, 12, 1.3430036888404402, 1.3430036888404402),
 (14558, 7): (10, 9, 3.9330441035217714, 3.9330441035217714),
 (152, 16): (8, 7, 1.3688795144738382, 1.368879514473838),
 (13593, 18): (12, 11, 1.6967920812890351, 1.6967920812890351),
 (2834, 7): (8, 8, 3.113149507653915, 3.1131495076539637),
 (11545, 14): (11, 10, 1.950612466604169, 1.9506124666041689),
 (12416, 21): (12, 11, 1.566576147387506, 1.5665761473875057),
 (8998, 6): (9, 8, 4.560624662646501, 4.560624662646504),
 (5245, 27): (11, 10, 1.3733092699013858, 1.3733092699013856),
 (5693, 29): (11, 10, 1.3473937347050562, 1.3473937347050562),
 (3508, 26): (10, 10, 1.3688266297365765, 1.368826629736599),
 (16237, 9): (11, 10, 2.936526854011741, 2.936526854011741),
 (2911, 6): (8, 7, 3.778682197915392, 3.7786821979153924),
 (387, 22): (10, 9, 1.311061987204142, 1.311061987204142),
 (3324, 4): (7, 7, 7.593032412784651, 7.593032412784651),
 (15300, 22): (12, 11, 1.5495773040868939, 1.5495773040868939),
 (5469, 26): (11, 10, 1.3924053694535468, 1.392405369453547),
 (1195, 13): (8, 8, 1.7247279767538015, 1.724727976753898),
 (7998, 13): (11, 10, 1.9963162339549785, 1.9963162339549787)}

In [1435]: stats1
Out[1435]:
{(10882, 3): (4, 22.159990703206965, 22.159990703206965),
 (6673, 28): (6, 1.3695657835909767, 1.3695657835909767),
 (4803, 10): (4, 2.3342709144708604, 2.3342709144708604),
 (1802, 27): (5, 1.3200291160996098, 1.3200291160996294),
 (8380, 15): (4, 1.8262053100463662, 1.826205310046366),
 (12898, 21): (5, 1.569419919895426, 1.5694199198954262),
 (10227, 2): (4, 101.12863096077193, 101.12863096077193),
 (4857, 25): (5, 1.4042832772764529, 1.4042832772765208),
 (1351, 12): (4, 1.8234251715932501, 1.8234251715932501),
 (10180, 16): (3, 1.7802632882832108, 1.7802632882832108),
 (6948, 28): (6, 1.37154252932099, 1.3715425293209902),
 (13901, 10): (4, 2.5959994991170756, 2.5959994991171),
 (7513, 21): (5, 1.529545998990047, 1.529545998990047),
 (7902, 18): (4, 1.6464211857566509, 1.6464211857566777),
 (3277, 31): (6, 1.2983819395882321, 1.2983819395882321),
 (3499, 10): (4, 2.261488555258189, 2.2614885552581887),
 (15234, 30): (6, 1.3785646304878574, 1.3785646304878574),
 (5739, 13): (5, 1.9459928850146935, 1.9459928850146935),
 (1823, 24): (5, 1.3673072329614473, 1.367307232961453),
 (15105, 16): (4, 1.8247150144295399, 1.8247150144295399),
 (16215, 12): (3, 2.2429852247131876, 2.2429852247131876),
 (15844, 20): (5, 1.6217848596088165, 1.6217848596088162),
 (15677, 26): (5, 1.4499608216310222, 1.4499608216310922),
 (11839, 22): (5, 1.5316187797414815, 1.5316187797414818),
 (10163, 26): (6, 1.4259891723870766, 1.4259891723870766),
 (1550, 18): (5, 1.5039751132184096, 1.5039751132184096),
 (15194, 5): (4, 6.8601628355219795, 6.860162835521979),
 (15612, 24): (5, 1.4952969217858556, 1.495296921785857),
 (9469, 12): (4, 2.1446611088057317, 2.144661108805732),
 (4030, 20): (5, 1.5144859625611744, 1.5144859625611742),
 (11729, 3): (4, 22.720627875592783, 22.72062787559279),
 (12709, 29): (6, 1.3852274277113097, 1.3852274277113097),
 (12263, 31): (6, 1.354846884624788, 1.354846884624788),
 (6372, 9): (4, 2.6466548424778766, 2.6466548424779086),
 (7119, 5): (4, 5.8949998528103436, 5.8949998528103436),
 (10737, 27): (6, 1.4102365332846034, 1.4102365332846034),
 (2231, 11): (4, 2.01562182326949, 2.0156218232694902),
 (412, 9): (4, 1.952289125066342, 1.9522891250663446),
 (8417, 5): (4, 6.095810609109851, 6.095810609109851),
 (6759, 31): (6, 1.3290600231725829, 1.3290600231725826),
 (2207, 23): (5, 1.3975994148304356, 1.3975994148304371),
 (4755, 16): (4, 1.6975473914419268, 1.697547391441927),
 (7978, 13): (4, 1.995931787083301, 1.9959317870833602),
 (14957, 19): (4, 1.658550339552885, 1.658550339552935),
 (745, 28): (5, 1.2664178106847905, 1.2664178106847914),
 (2696, 15): (4, 1.6932249497186587, 1.6932249497186587),
 (5484, 7): (4, 3.421029175738748, 3.421029175738748),
 (15410, 10): (4, 2.6228911266357167, 2.6228911266357264),
 (315, 10): (3, 1.7775877772276876, 1.7775877772276876),
 (9252, 13): (4, 2.0188081438649, 2.0188081438649035),
 (1562, 27): (5, 1.313059731029047, 1.3130597310290748),
 (9803, 8): (4, 3.1544225980493534, 3.154422598049354),
 (14443, 16): (4, 1.8196111450479415, 1.8196111450479413),
 (5033, 23): (5, 1.4486017214956801, 1.4486017214956872),
 (16175, 2): (4, 127.18097341976905, 127.18097341976903),
 (15125, 5): (4, 6.853920721113647, 6.853920721113645),
 (16292, 19): (4, 1.6660301757087943, 1.666030175708797),
 (11486, 20): (5, 1.5959101637580406, 1.5959101637580406),
 (13824, 7): (4, 3.9040835527337374, 3.9040835527337383),
 (8604, 3): (4, 20.491172086350442, 20.491172086350442),
 (1225, 3): (4, 10.699874805650794, 10.699874805650795),
 (9163, 21): (5, 1.54407524840236, 1.5440752484023603),
 (7833, 21): (5, 1.53258704058841, 1.53258704058841),
 (8425, 24): (5, 1.4573551932369144, 1.4573551932369178)}

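On average, the first method needs about 12 binary search iterations and 12 Newton iterations to get the error below 10^-13, while the second method needs 6 Newton iterations to reach the same accuracy.

Is there a way to make the code run faster with the same number of iterations, or a way to make the math converge faster?


This is not an assignment. It is a self-imposed programming challenge.

c++ algorithm math c++20 nth-root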

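Comments

1 upvote  Some programmer dude  9/26/2023
Unrelated, but what problem does your own power function solve that std::pow doesn't?
0 upvotes  Some programmer dude  9/26/2023
What is your actual assignment? What are its requirements? What are its limitations? The "no library" part you mention is negated by your use of plenty of C++ standard library code. You can't even produce output without a library. If it's only the math functions that should be done without "library functions", then you shouldn't even be using std::bit_cast<uint32_t>(base).
3 upvotes  Some programmer dude  9/26/2023
Also, why are you using float? Unless you have very specific requirements, you should use double instead. And no, it won't be less "efficient" or less "optimal". Why is r a global variable? It is shadowed in the nth_root function, and its use can be local to the main function.
2 upvotes  Some programmer dude  9/26/2023
By the way, does the code you show work (a hard requirement)? Then, since what you are really asking for seems to be a code review, you should post on Code Review SE (but make sure it really is on topic there).
1 upvote  Simon Goater  9/26/2023
I don't know the rules for C++, but at least in C, casting a float to uint32_t is a strict aliasing violation. In C, the correct way is to use a union or memcpy. Also, as already mentioned, you can't get 12 significant digits out of a float, so your analysis of the results looks suspect.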
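
To make the two points in the last comment concrete, here is a minimal sketch. It assumes a 32-bit IEEE-754 float; the names float_bits, float_digits and double_digits are made up for illustration. It copies the bits with memcpy (well-defined in both C and C++, and the pre-C++20 counterpart of std::bit_cast) and exposes how many decimal digits each type can reliably hold.

```cpp
#include <cstdint>
#include <cstring>
#include <limits>

// Copy the object representation of a float into a 32-bit unsigned integer.
// memcpy avoids the aliasing problem of casting pointers between types.
inline std::uint32_t float_bits(float f) {
    static_assert(sizeof(float) == sizeof(std::uint32_t),
                  "this sketch assumes a 32-bit float");
    std::uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    return u;
}

// Number of decimal digits that survive a text -> type -> text round trip.
constexpr int float_digits  = std::numeric_limits<float>::digits10;
constexpr int double_digits = std::numeric_limits<double>::digits10;
```

On an IEEE-754 platform float_digits is 6 and double_digits is 15, so a 12-digit target needs double for both the computation and the returned value.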

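Answers:

2 upvotes  Matt Timmermans  9/26/2023  #1

Using a binary search to get the initial guess in nth_root is not particularly efficient.

You can use a trick on the floating-point representation, similar to the one in fast_nth..., before proceeding to a proper implementation of Newton's method.

Like this: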

constexpr int32_t ONE_BITS = std::bit_cast<int32_t>(1.0f);

inline float nth_root(float base, int n) {
    int32_t x_bits = std::bit_cast<int32_t>(base);
    x_bits = (x_bits - ONE_BITS)/n + ONE_BITS;
    // first guess
    float x = std::bit_cast<float>(x_bits);

The important effect of this is just that it divides the floating-point exponent by n.
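
The snippet above stops at the first guess. One way it might be finished is sketched below; this is an illustration, not the answerer's code. It assumes a positive, normal IEEE-754 float and n >= 2, uses memcpy instead of std::bit_cast so it also compiles before C++20, and runs the Newton steps in double. The names to_bits, from_bits, ipow, nth_root_sketch and the step count NEWTON_STEPS are hypothetical; 8 steps is an untuned choice, not a proven worst-case bound.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

inline std::int32_t to_bits(float f) {
    std::int32_t i;
    std::memcpy(&i, &f, sizeof i);
    return i;
}

inline float from_bits(std::int32_t i) {
    float f;
    std::memcpy(&f, &i, sizeof f);
    return f;
}

// x^e for e >= 0 by repeated squaring.
constexpr double ipow(double x, int e) {
    double p = 1.0;
    while (e > 0) {
        if (e & 1) p *= x;
        x *= x;
        e >>= 1;
    }
    return p;
}

// Sketch: nth root of a positive, normal float, for n >= 2.
inline double nth_root_sketch(float base, int n) {
    assert(base > 0.0f && n >= 2);
    constexpr std::int32_t ONE_BITS = 0x3F800000;  // bit pattern of 1.0f
    constexpr int NEWTON_STEPS = 8;                // assumed, not tuned

    // First guess: divide the (biased-removed) exponent field by n.
    std::int32_t x_bits = (to_bits(base) - ONE_BITS) / n + ONE_BITS;
    double x = from_bits(x_bits);

    // Newton's method on f(x) = x^n - a, carried out in double.
    const double a = base;
    for (int i = 0; i < NEWTON_STEPS; ++i) {
        x = ((n - 1) * x + a / ipow(x, n - 1)) / n;
    }
    return x;
}
```

For exact powers of two such as nth_root_sketch(8.0f, 3), the first guess is already the root and the Newton steps leave it unchanged.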

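Comments

0 upvotes  Oersted  9/26/2023
If you care about portability, you shouldn't use the floating-point binary representation, which isn't standardized. You can, however, access the exponent through std functions, which is the only portable way (though it may carry an overhead compared with directly accessing the underlying memory, if that happens to work on a given architecture/compiler).
0 upvotes  Matt Timmermans  9/26/2023
OP is trying to avoid cmath functions and the like.
0 upvotes  Oersted  9/26/2023
Understood, but does he want the implementation to be portable? He should tell us, in order to get more relevant answers.
0 upvotes  Jérôme Richard  9/29/2023
@Oersted Actually, it is standardized. The IEEE-754 standard specifies the exact representation of 32-bit floating-point numbers (more specifically, IEC-559 floats). All mainstream platforms have supported it properly for more than a decade. This includes Intel and AMD x86-64 processors, as well as ARM and RISC processors (AFAIK >99% of the server/PC/mobile market share). Nowadays even mainstream GPUs are IEEE-754 compliant. Only some rare, exotic embedded devices may not support it. In any case, if needed, compliance can be checked in C++ using std::numeric_limits<double>::is_iec559.
0 upvotes  Oersted  10/2/2023
@JérômeRichard As I understand it, IEEE-754 only proposes a memory layout but does not impose one. If that's not the case, I'd appreciate being pointed to the wording. Also, I'm not "fluent" in the naming of standards: what is the difference (if any) between iec559 and IEEE-754? Thanks in advance.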
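
For anyone who wants the bit tricks to fail loudly on an unusual platform, the compliance check mentioned above can be turned into compile-time guards. This is a small sketch; the constant names kMantissaBits and kExponentBias are made up here.

```cpp
#include <cstdint>
#include <limits>

// Refuse to compile if float is not an IEC 559 (IEEE 754) 32-bit type.
static_assert(std::numeric_limits<float>::is_iec559,
              "float is not IEC 559 / IEEE 754");
static_assert(sizeof(float) == sizeof(std::uint32_t),
              "float is not 32 bits wide");

// Layout constants derived from numeric_limits rather than hard-coded.
// digits counts the implicit leading bit, hence the -1.
constexpr int kMantissaBits = std::numeric_limits<float>::digits - 1;
constexpr int kExponentBias = std::numeric_limits<float>::max_exponent - 1;
```

With these in place, code that shifts by kMantissaBits or subtracts kExponentBias only compiles where those values mean what the bit manipulation assumes.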
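2 upvotes  user21508463  9/26/2023  #2

You didn't say what range of values your function should cover, nor the statistical distribution of those values, so we have to suggest a general-purpose approach.

The strategy is to first find a good initial approximation for Newton's iterations by working on the order of magnitude.

If you have access to the exponent of the floating-point representation (by hacking the binary representation), a good starting value is √2/2 times 2^(exponent/n). We choose √2/2 rather than 1 so as to be centered in [1/2, 1]. For efficiency, you can precompute these constants for all exponents and all root orders. In any case, to save space, it is better to decompose the exponent into the integer quotient and remainder of its division by n.

If the exponent is not available, you can search for it by successive doublings (or halvings), starting from 1 (so 1=2^0, 2=2^1, 4=2^2, 8=2^3...). This amounts to a linear search among the exponents. It is more efficient, however, to use squarings, which implements an exponential search among the exponents (the power doubles each time: 2=2^1, 4=2^2, 16=2^4, 256=2^8...). Once a range of possible exponents has been found, revert to a linear search. You can also optimize this with precomputed values.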
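
A minimal sketch of that search for inputs x >= 1, using only multiplications and comparisons (no access to the bit pattern). The function name floor_log2_search is made up for this illustration, and values below 1 would need the mirrored version with halvings.

```cpp
#include <cassert>

// Largest e with 2^e <= x, for x >= 1, found without touching the bits.
inline int floor_log2_search(double x) {
    assert(x >= 1.0);
    if (x < 2.0) return 0;

    double p = 2.0;  // invariant: p == 2^e and p <= x
    int e = 1;

    // Exponential phase: squaring p doubles the exponent each time.
    while (p * p <= x) {
        p *= p;
        e *= 2;
    }
    // Linear phase: the answer now lies in [e, 2e), step by doubling.
    while (p * 2.0 <= x) {
        p *= 2.0;
        ++e;
    }
    return e;
}
```

The exponent found this way can then be divided by n to build the starting value described in the previous paragraph.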

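Finally, you can start Newton's iterations. For the square-root case, you can use the inverse square root to avoid divisions. Unfortunately, this does not generalize to higher orders.

Last but not least, it can be beneficial to determine the number of iterations required in the worst case (with the worst initial approximation) and always use that number, rather than testing for convergence against a specific tolerance.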
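
One way to find such a fixed count empirically is to measure, offline, how many Newton steps each representative input needs and keep the maximum. The helper below is a sketch under that assumption; steps_to_converge, ipow and the max_steps safety cap are hypothetical names and choices, and the stopping test uses the relative residual |x^n - a| <= tol * a.

```cpp
#include <cassert>

// x^e for e >= 0 by repeated squaring.
constexpr double ipow(double x, int e) {
    double p = 1.0;
    while (e > 0) {
        if (e & 1) p *= x;
        x *= x;
        e >>= 1;
    }
    return p;
}

// Count the Newton steps needed from starting guess x0 until the
// relative residual of x^n = a drops to tol or below.
inline int steps_to_converge(double a, int n, double x0, double tol,
                             int max_steps = 100) {
    assert(a > 0.0 && n >= 2 && x0 > 0.0);
    double x = x0;
    int steps = 0;
    while (steps < max_steps) {
        double err = ipow(x, n) - a;
        if (err < 0.0) err = -err;
        if (err <= tol * a) break;
        x = ((n - 1) * x + a / ipow(x, n - 1)) / n;  // Newton update
        ++steps;
    }
    return steps;
}
```

Running this over the intended input range with the chosen initial-guess scheme and taking the maximum gives the constant to hard-code in the production loop, which then needs no convergence test at all.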