提问人:skidjoe 提问时间:5/1/2021 最后编辑:Innatskidjoe 更新时间:7/12/2023 访问量:3706
将 tf.dataset 转换为 PyTorch 数据集?
Converting a tf.dataset to a PyTorch Dataset?
问:
我正在做这个项目,其中所有数据都经过预处理并准备好作为 TensorFlow 数据集,如下所示:
<MapDataset shapes: {input_ids: (128,), input_mask: (128,), label_ids: (), segment_ids: (128,)}, types: {input_ids: tf.int64, input_mask: tf.int64, label_ids: tf.int64, segment_ids: tf.int64}>
我拥有的脚本位于 PyTorch 中,并采用一个 Dataset 对象,如下所示:
Dataset({
features: [
'attention_mask',
'input_ids',
'label',
'sentence',
'token_type_ids'
],
num_rows: 12
})
有什么方法可以将一个转换为另一个?我对这两个 API 都很陌生,所以我不太确定它们是如何工作的。我可以将一个转换为另一个吗?
答:
我用作模型训练的数据加载器。为了转换传递给我的模型的数据,我在模型的 forward 函数中使用。tfds.as_numpy(dataset)
torch.as_tensor(data, device=<device>)
import tensorflow_datasets as tfds
import torch.nn as nn
def train_dataloader(batch_size):
return tfds.as_numpy(tfds.load('mnist').batch(batch_size))
class Model(nn.Module):
def forward(self, x):
x = torch.as_tensor(x, device='cuda')
...
更新
Keras-Core -Keras Core是Keras代码库的完全重写,将其重新定位在模块化后端架构之上.它使得在任意框架上运行 Keras 工作流成为可能——从 TensorFlow、JAX 和 PyTorch 开始。
...您可以在 上训练模型,也可以在 上训练模型。
Keras Core + TensorFlow
PyTorch DataLoader
Keras Core + PyTorch
tf.data.Dataset
这是一个迟到的答案,但它可能会对未来的读者有所帮助。虽然这不是一个很好的解决方案,但目前正在进行一些可能会有所帮助的讨论,即链接 1、链接 2
为了演示,我们将首先构建一个加载器并将其转换为加载器。tf. data
torch.utils.data
建tf.data
BATCH_SIZE = 64
(x_train, y_train), _ = keras.datasets.cifar10.load_data()
def normalize(image, label, denorm=False):
rescale = keras.layers.Rescaling(scale=1./255.)
norms = keras.layers.Normalization(
mean=[0.4914, 0.4822, 0.4465],
variance=[np.square(0.2023), np.square(0.1994), np.square(0.2010)],
invert=denorm,
axis=-1,
)
if not denorm:
image = rescale(image)
return norms(image), label
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.map(normalize)
train_ds = train_ds.shuffle(buffer_size=8*BATCH_SIZE)
train_ds = train_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
x, y = next(iter(train_ds))
x.shape, y.shape
(TensorShape([64, 32, 32, 3]), TensorShape([64, 1]))
建torch.utils.data
# Define a custom PyTorch dataset class
class TFDatasetWrapper(Dataset):
def __init__(self, tf_dataset):
self.tf_dataset = tf_dataset
def __len__(self):
return len(x_train)
def __getitem__(self, idx):
return next(iter(self.tf_dataset.skip(idx).take(1)))
def tf_collate_fn(batch):
x, y = zip(*batch)
x = torch.stack(x).permute(0, 3, 1, 2).type(torch.FloatTensor)
y = torch.stack(y)
return x, y
def iter_tf_data(train_ds):
x_list = []
y_list = []
for data in train_ds.as_numpy_iterator():
x, y = data
x_list += [torch.from_numpy(x)]
y_list += [torch.from_numpy(y)]
x_list_cat = torch.cat(x_list, axis=0)
y_list_cat = torch.cat(y_list, axis=0)
return [x_list_cat, y_list_cat]
def tf_dataset_to_pytorch_dataloader(
tf_dataset, batch_size, shuffle=True, num_workers=0
):
"""Converts a TensorFlow Dataset to a PyTorch DataLoader."""
data_list = iter_tf_data(tf_dataset)
pytorch_dataset = TensorDataset(*data_list)
pytorch_dataloader = DataLoader(
pytorch_dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
collate_fn=tf_collate_fn
)
return pytorch_dataloader
train_ds_torch = tf_dataset_to_pytorch_dataloader(
train_ds, batch_size=BATCH_SIZE // 2, shuffle=True
)
x, y = next(iter(train_ds_torch))
x.shape, y.shape
(torch.Size([32, 3, 32, 32]), torch.Size([32, 1]))
最后,让我们可视化 loader 中的一些示例。torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
fig, ax = plt.subplots(figsize=(12, 6))
plt.title("CIFAR10 dataset")
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(make_grid(x, nrow=8).permute(1, 2, 0))
plt.show()
评论