提问人:Fido 提问时间:2/3/2022 最后编辑:mkrieger1Fido 更新时间:2/4/2022 访问量:1051
如果事先知道大小,将矩阵从 CSV 文件读取到 NumPy 的快速方法是什么?
What is a fast way to read a matrix from a CSV file to NumPy if the size is known in advance?
问:
我厌倦了在使用 .在另一个 SO 问题之后,我执行了 perfplot 测试,同时包括一些其他方法。结果(源代码在末尾):csv
numpy.genfromtxt
最大输入大小的结果显示最佳方法是 ,即:read_csv
def load_read_csv(path: str):
with open(path, 'r') as csv_file:
reader = csv.reader(csv_file)
matrix = None
first_row = True
for row_index, row in enumerate(reader):
if first_row:
size = len(row)
matrix = np.zeros((size, size), dtype=int)
first_row = False
matrix[row_index] = row
return matrix
现在我怀疑逐行读取文件,将其转换为字符串列表,然后调用列表中的每个项目并将其添加到 NumPy 矩阵中是最好的方法。int()
这个函数可以进一步优化吗,或者是否有一些用于 CSV 加载的快速库(如 Java 中的 Univocity 解析器),或者可能只是一个专用的 NumPy 函数?
测试源代码:
import perfplot
import csv
import numpy as np
import pandas as pd
def load_read_csv(path: str):
with open(path, 'r') as csv_file:
reader = csv.reader(csv_file)
matrix = None
first_row = True
for row_index, row in enumerate(reader):
if first_row:
size = len(row)
matrix = np.zeros((size, size), dtype=int)
first_row = False
# matrix[row_index] = [int(item) for item in row]
matrix[row_index] = row
return matrix
def load_loadtxt(path: str):
matrix = np.loadtxt(path, dtype=int, comments=None, delimiter=",", encoding="utf-8")
return matrix
def load_genfromtxt(path: str):
matrix = np.genfromtxt(path, dtype=int, comments=None, delimiter=",", deletechars=None, replace_space=None, encoding="utf-8")
return matrix
def load_pandas(path: str):
df = pd.read_csv(path, header=None, dtype=np.int32)
return df.values
def load_pandas_engine_pyarrow(path: str):
df = pd.read_csv(path, header=None, dtype=np.int32, engine='pyarrow')
return df.values
def load_pandas_engine_python(path: str):
df = pd.read_csv(path, header=None, dtype=np.int32, engine='python')
return df.values
def setup(n):
matrix = np.random.randint(0, 10000, size=(n, n), dtype=int)
filename = f"square_matrix_of_size_{n}.csv"
np.savetxt(filename, matrix, fmt="%d", delimiter=",")
return filename
b = perfplot.bench(
setup=setup, # or setup=np.random.rand
kernels=[
load_read_csv,
load_loadtxt,
load_genfromtxt,
load_pandas,
load_pandas_engine_pyarrow,
load_pandas_engine_python
],
n_range=[2 ** k for k in range(15)]
)
b.save("out.png")
b.show()
答:
在支持多种数据类型(例如浮点数、整数、字符串)和可能格式错误的输入文件的同时正确解析 CSV 文件显然并不容易,而且有效地做到这一点实际上非常困难。此外,解码 UTF-8 字符串也比直接读取 ASCII 字符串慢得多。这就是为什么大多数 CSV 库都非常慢的原因。更不用说在 Python 中包装库可能会引入有关输入类型(尤其是字符串)的相当大的开销。
希望如果你需要读取一个CSV文件,其中包含一个假定格式正确的整数方阵,那么你可以编写一个更快的特定代码,专门用于你的需求(它不关心浮点数、字符串、UTF-8、标头解码、错误处理等)。
话虽如此,任何对基本 CPython 函数的调用都会引入巨大的开销。即使是对 + 的简单调用也相对较慢(二进制模式比文本模式快得多,但不幸的是没有那么快)。诀窍是使用 Numpy 将整个二进制文件加载到 RAM 中。这个函数非常快:它只是一次读取整个文件,将其二进制内容放入原始内存缓冲区中并返回视图。当文件位于操作系统缓存或高吞吐量 NVMe SSD 存储设备中时,它可以以几 GiB/s 的速度加载文件。open
read
np.fromfile
加载文件后,您可以使用 Numba(或 Cython)对其进行解码,因此解码速度几乎与本机代码一样快。请注意,Numba 不支持良好/有效的字符串/字节。希望该函数产生一个连续的字节数组,Numba 可以非常快速地计算它。您只需阅读第一行并计算逗号的数量即可知道矩阵的大小。然后,您可以通过动态解码整数来非常有效地填充矩阵,将它们打包到一个扁平矩阵中,并将行尾字符视为常规分隔符。请注意,由于文件是以二进制模式读取的,因此 和 都可以出现在文件中。np.fromfile
\r
\n
下面是生成的实现:
import numba as nb
import numpy as np
@nb.njit('int32[:,:](uint8[::1],)', cache=True)
def decode_csv_buffer(rawData):
COMMA = np.uint8(ord(','))
CR = np.uint8(ord('\r'))
LF = np.uint8(ord('\n'))
ZERO = np.uint8(ord('0'))
# Find the size of the matrix (`n`)
n = 0
lineSize = 0
for i in range(rawData.size):
c = rawData[i]
if c == CR or c == LF:
break
n += rawData[i] == COMMA
lineSize += 1
n += 1
# Empty matrix
if lineSize == 0:
return np.empty((0, 0), dtype=np.int32)
# Initialization
res = np.empty(n * n, dtype=np.int32)
# Fill the matrix
curInt = 0
curPos = 0
lastCharIsDigit = True
for i in range(len(rawData)):
c = rawData[i]
if c == CR or c == LF or c == COMMA:
if lastCharIsDigit:
# Write the last int in the flatten matrix
res[curPos] = curInt
curPos += 1
curInt = 0
lastCharIsDigit = False
else:
curInt = curInt * 10 + (c - ZERO)
lastCharIsDigit = True
return res.reshape(n, n)
def load_numba(filename):
# Load fully the file in a raw memory buffer
rawData = np.fromfile(filename, dtype=np.uint8)
# Decode the buffer using the Numba JIT
# This method only work for your specific needs and
# can simply crash if the file content is invalid.
return decode_csv_buffer(rawData)
请注意,代码并不可靠(任何错误的输入都会导致未定义的行为,包括崩溃),但它非常快。
以下是我机器上的结果:
正如你所看到的,上面的 Numba 实现至少比所有其他实现快一个数量级。请注意,在解码过程中,您可以使用多个线程编写更快的代码,但这会使代码更加复杂。
评论
load_read_csv
size
load_read_csv