Converting SVD code from MATLAB to Python

Asked by: Harley Towler Asked: 7/20/2023 Updated: 7/20/2023 Views: 54

Q:

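I have some MATLAB code, centred on an SVD, that I need to convert to Python. With the same inputs I cannot get the same results. Does anyone know where I am going wrong?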

MATLAB code

function v = homography_solve(pin, pout)
% HOMOGRAPHY_SOLVE finds a homography from point pairs
%   V = HOMOGRAPHY_SOLVE(PIN, POUT) takes a 2xN matrix of input vectors and
%   a 2xN matrix of output vectors, and returns the homogeneous
%   transformation matrix that maps the inputs to the outputs, to some
%   approximation if there is noise.
%
%   This uses the SVD method of
%   http://www.robots.ox.ac.uk/%7Evgg/presentations/bmvc97/criminispaper/node3.html
% David Young, University of Sussex, February 2008
if ~isequal(size(pin), size(pout))
    error('Points matrices different sizes');
end
if size(pin, 1) ~= 2
    error('Points matrices must have two rows');
end
n = size(pin, 2);
if n < 4
    error('Need at least 4 matching points');
end
% Solve equations using SVD
x = pout(1, :); y = pout(2,:); X = pin(1,:); Y = pin(2,:);
rows0 = zeros(3, n);
rowsXY = -[X; Y; ones(1,n)];
hx = [rowsXY; rows0; x.*X; x.*Y; x];
hy = [rows0; rowsXY; y.*X; y.*Y; y];
h = [hx hy];
if n == 4
    [U, ~, ~] = svd(h);
else
    [U, ~, ~] = svd(h, 'econ');
end
v = (reshape(U(:,9), 3, 3)).';
end
pin =

  [167.4787  300.2447  430.9681  114.3723  298.2021  479.9894;
  200.5000  199.1383  202.5426  226.3723  228.4149  228.4149]

pout = 
   [-3.0500         0    3.0500   -3.0500         0    3.0500;
    6.7050    6.7050    6.7050    1.9800    1.9800    1.9800]

This returns 

v =    

   [-0.0006   -0.0000    0.1934;
   -0.0000    0.0040   -0.9801;
    0.0000   -0.0004    0.0449]

My current Python code is below, but it does not give a similar result. It returns a 3x3 matrix, but one that differs from MATLAB's.

import numpy as np
from scipy import linalg

def homography_solve(pin, pout):
    # Check if input and output matrices have the same size
    if pin.shape != pout.shape:
        raise ValueError('Points matrices different sizes')
    
    # Check if the matrices have two rows
    if pin.shape[0] != 2:
        raise ValueError('Points matrices must have two rows')
    
    n = pin.shape[1]
    
    # Check if we have at least 4 matching points
    if n < 4:
        raise ValueError('Need at least 4 matching points')
    
    # Solve equations using SVD
    x = pout[0, :]
    y = pout[1, :]
    X = pin[0, :]
    Y = pin[1, :]
    
    rows0 = np.zeros((3, n))
    rowsXY = np.vstack((-X, -Y, -np.ones(n)))
    hx = np.vstack((rowsXY, rows0, x * X, x * Y, x))
    hy = np.vstack((rows0, rowsXY, y * X, y * Y, y))
    h = np.hstack((hx, hy))
    

    if n == 4:
        U, sdiag, VH = np.linalg.svd(h)
    else:
        U, sdiag, VH = np.linalg.svd(h, full_matrices=False)
        
    S = np.zeros((h.shape[0], h.shape[1]))
    np.fill_diagonal(S, sdiag)
    V = VH.T.conj() 
    
    v = (U[-1,:].reshape(3, 3)).T
    
    return v
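
For comparison, here is a minimal sketch of a more literal port. It assumes the Python version drifts from the MATLAB one in two places: the indexing of U (MATLAB's U(:,9) is the last column, whereas U[-1, :] is the last row) and the reshape order (MATLAB fills column-major, so reshape(..., 3, 3).' corresponds to a plain row-major reshape in NumPy, with no transpose). The name homography_solve_port and the synthetic matrix H_true are invented for illustration, and the result is only defined up to scale and sign.

```python
import numpy as np


def homography_solve_port(pin: np.ndarray, pout: np.ndarray) -> np.ndarray:
    """Sketch of a direct port of the MATLAB routine (hypothetical name)."""
    if pin.shape != pout.shape:
        raise ValueError('Points matrices different sizes')
    if pin.shape[0] != 2:
        raise ValueError('Points matrices must have two rows')
    n = pin.shape[1]
    if n < 4:
        raise ValueError('Need at least 4 matching points')

    x, y = pout
    X, Y = pin
    rows0 = np.zeros((3, n))
    rowsXY = -np.vstack((X, Y, np.ones(n)))
    hx = np.vstack((rowsXY, rows0, x * X, x * Y, x))
    hy = np.vstack((rows0, rowsXY, y * X, y * Y, y))
    h = np.hstack((hx, hy))  # shape (9, 2n)

    # Full SVD so that U is always 9x9, including the n == 4 case.
    U, _, _ = np.linalg.svd(h, full_matrices=True)

    # MATLAB U(:,9) is the last *column*; a column-major reshape followed by
    # a transpose equals a row-major (C-order) reshape in NumPy.
    return U[:, -1].reshape(3, 3)


# Synthetic check with an invented homography (not the data from the question).
H_true = np.array([[2.0, 0.1, 1.0], [0.2, 1.5, -1.0], [0.01, 0.02, 1.0]])
pin_demo = np.array([[0.0, 1.0, 0.0, 1.0, 2.0, 3.0],
                     [0.0, 0.0, 1.0, 1.0, 3.0, 1.0]])
mapped = H_true @ np.vstack((pin_demo, np.ones(6)))
pout_demo = mapped[:2] / mapped[2]   # perspective division

v_demo = homography_solve_port(pin_demo, pout_demo)
v_demo = v_demo / v_demo[2, 2]       # remove the arbitrary scale and sign
print(v_demo)
```

Because the singular vector is a unit vector with arbitrary sign, dividing by the bottom-right entry is one common way to make two results comparable; it assumes that entry is not zero.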
Python MATLAB computer-vision SVD

Comments

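0 votes D.L 7/20/2023
Please show what you expect to get in Python...
0 votes D.L 7/20/2023
Also, you have imported linalg from scipy, but the code you use calls np.linalg... did you do that on purpose?
0 votes D.L 7/20/2023
Finally, pin.shape and pout.shape are not defined. Presumably they are numpy arrays with 2 rows.
0 votes Harley Towler 7/21/2023
I want to get v as in the MATLAB code. Both linalg and scipy are imported because some resources suggested one or the other, so both were imported but only one is used. pin.shape[0] refers to the number of rows and pin.shape[1] to the number of columns, or at least it should.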

A:

0 votes Reinderien 7/20/2023 #1

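Your reference v as shown is a non-solution, and back-multiplying it with pin gives something that looks nothing like your reference pout. So it is a good thing that the code does not give a similar result, but the output of the Python solution is also wrong. The simplest solution is to call lstsq: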

import numpy as np


def homography_solve(pin: np.ndarray, pout: np.ndarray) -> np.ndarray:
    if pin.shape != pout.shape:
        raise ValueError('Points matrices different sizes')

    return np.linalg.lstsq(
        a=np.hstack((pin.T, np.ones(shape=(pin.shape[1], 1), dtype=pin.dtype))),
        b=pout.T,
        rcond=None,
    )[0]


ain = np.array([
    [167.4787,  300.2447,  430.9681,  114.3723,  298.2021,  479.9894],
    [200.5000,  199.1383,  202.5426,  226.3723,  228.4149,  228.4149],
])
aout = np.array([
      [-3.0500,         0,    3.0500,   -3.0500,         0,    3.0500],
      [ 6.7050,    6.7050,    6.7050,    1.9800,    1.9800,    1.9800],
])
output = homography_solve(ain, aout)
np.set_printoptions(linewidth=200)
print(output)
print(aout)
print(output.T @ np.vstack((ain, np.ones(6))))
[[ 1.88837036e-02  1.10075526e-03]
 [ 1.72936013e-03 -1.74139133e-01]
 [-6.00807070e+00  4.13197913e+01]]
[[-3.05   0.     3.05  -3.05   0.     3.05 ]
 [ 6.705  6.705  6.705  1.98   1.98   1.98 ]]
[[-2.49871587e+00  6.04304700e-03  2.48047224e+00 -3.45681886e+00  1.81009790e-02  3.45091846e+00]
 [ 6.58924815e+00  6.97251628e+00  6.52358890e+00  2.02541110e+00  1.87206612e+00  2.07216945e+00]]

Your data are overdetermined; if you trim a few columns from the input, the solution becomes exact:

[[ 2.30741752e-02 -2.05810916e-17]
 [ 9.88907056e-03  2.15105711e-16]
 [-8.89719152e+00  6.70500000e+00]]
[[-3.05   0.     3.05 ]
 [ 6.705  6.705  6.705]]
[[-3.05   0.     3.05 ]
 [ 6.705  6.705  6.705]]
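
The exact fit above can presumably be reproduced by passing only the first three point pairs; that slicing is an assumption read off the printed pout columns. A minimal sketch, with the least-squares fit repeated under the hypothetical name affine_fit so that the snippet runs on its own:

```python
import numpy as np


def affine_fit(pin: np.ndarray, pout: np.ndarray) -> np.ndarray:
    """Least-squares affine map, same formulation as the answer above."""
    design = np.hstack((pin.T, np.ones((pin.shape[1], 1))))
    return np.linalg.lstsq(design, pout.T, rcond=None)[0]


ain = np.array([
    [167.4787, 300.2447, 430.9681, 114.3723, 298.2021, 479.9894],
    [200.5000, 199.1383, 202.5426, 226.3723, 228.4149, 228.4149],
])
aout = np.array([
    [-3.0500, 0, 3.0500, -3.0500, 0, 3.0500],
    [6.7050, 6.7050, 6.7050, 1.9800, 1.9800, 1.9800],
])

# Three points give three equations for the three unknowns of each output
# coordinate, so the affine fit passes through them (up to rounding).
coeffs = affine_fit(ain[:, :3], aout[:, :3])
reprojected = coeffs.T @ np.vstack((ain[:, :3], np.ones(3)))
print(reprojected)
```

An exact fit on three points says nothing about how well the map generalises to the remaining points; it only reflects that the system is no longer overdetermined.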

Comments

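0 votes Harley Towler 7/22/2023
Thanks for the reply. Essentially, the solution above is meant to give 3D coordinates from the 2D coordinates (in pixels) of a single image, assuming the points lie in a plane, e.g. the ground with z coordinate = 0. So the code should take coordinates in pixels (pin) and their known positions (pout), and v should be the matrix that converts pixel coordinates into real-world 3D coordinates. Any other code that can do this is much appreciated. The code above needs at least 4 points but works best with more than 5 (or so I think).
0 votes Reinderien 7/22/2023
That is what it does, isn't it?
0 votes Harley Towler 7/23/2023
I have two new coordinates that I am trying to look up: array([[296.79422895], [203.24645222]]) and array([[428.0672238 ], [464.78329654]]). They should be around (0, 6) and (0, -6.5) respectively. The second coordinate is way off and I don't know why. The first gives me a reasonable value.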
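
If v is meant as a projective homography rather than an affine map, a new pixel has to be lifted to homogeneous coordinates, multiplied by the 3x3 matrix, and divided by the third component. A minimal sketch of that step follows; apply_homography and the matrix H_demo are hypothetical and are not derived from the data in this thread. One possible reason for the second point being far off is that its pixel Y (about 465) lies well outside the 199 to 228 range of the fitted points, where an affine fit has no perspective term to lean on, though that is only a guess.

```python
import numpy as np


def apply_homography(H: np.ndarray, pts: np.ndarray) -> np.ndarray:
    """Map 2xN points through a 3x3 homography with perspective division."""
    homog = np.vstack((pts, np.ones(pts.shape[1])))  # lift to 3xN
    mapped = H @ homog
    return mapped[:2] / mapped[2]                    # divide by w


# Invented matrix and points, for illustration only.
H_demo = np.array([[1.0, 0.0, 2.0],
                   [0.0, 1.0, -1.0],
                   [0.0, 0.5, 1.0]])
pts_demo = np.array([[0.0, 4.0],
                     [0.0, 2.0]])
result_demo = apply_homography(H_demo, pts_demo)
print(result_demo)
```

The division by the third row is what distinguishes this from the affine product output.T @ [pin; 1] used in the answer, which keeps the scale fixed across the image.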