将 tf.dataset 转换为 PyTorch 数据集?

Converting a tf.dataset to a PyTorch Dataset?

提问人:skidjoe 提问时间:5/1/2021 最后编辑:Innatskidjoe 更新时间:7/12/2023 访问量:3706

问:

我正在做这个项目,其中所有数据都经过预处理并准备好作为 TensorFlow 数据集,如下所示:

<MapDataset shapes: {input_ids: (128,), input_mask: (128,), label_ids: (), segment_ids: (128,)}, types: {input_ids: tf.int64, input_mask: tf.int64, label_ids: tf.int64, segment_ids: tf.int64}>

我拥有的脚本位于 PyTorch 中,并采用一个 Dataset 对象,如下所示:

Dataset({
    features: [
        'attention_mask', 
        'input_ids', 
        'label', 
        'sentence', 
        'token_type_ids'
    ],
    num_rows: 12
})

有什么方法可以将一个转换为另一个?我对这两个 API 都很陌生,所以我不太确定它们是如何工作的。我可以将一个转换为另一个吗?

TensorFlow Keras PyTorch 数据集 TensorFlow-Datasets

评论


答:

1赞 Jaideep Heer 4/24/2022 #1

我用作模型训练的数据加载器。为了转换传递给我的模型的数据,我在模型的 forward 函数中使用。tfds.as_numpy(dataset)torch.as_tensor(data, device=<device>)

import tensorflow_datasets as tfds
import torch.nn as nn

def train_dataloader(batch_size):
    return tfds.as_numpy(tfds.load('mnist').batch(batch_size))

class Model(nn.Module):
    def forward(self, x):
        x = torch.as_tensor(x, device='cuda')
        ...
0赞 Innat 7/12/2023 #2

更新

Keras-Core -Keras Core是Keras代码库的完全重写,将其重新定位在模块化后端架构之上.它使得在任意框架上运行 Keras 工作流成为可能——从 TensorFlow、JAX 和 PyTorch 开始。

...您可以在 上训练模型,也可以在 上训练模型。Keras Core + TensorFlowPyTorch DataLoaderKeras Core + PyTorchtf.data.Dataset


这是一个迟到的答案,但它可能会对未来的读者有所帮助。虽然这不是一个很好的解决方案,但目前正在进行一些可能会有所帮助的讨论,即链接 1链接 2

为了演示,我们将首先构建一个加载器并将其转换为加载器。tf. datatorch.utils.data

tf.data

BATCH_SIZE = 64
(x_train, y_train), _ = keras.datasets.cifar10.load_data()

def normalize(image, label, denorm=False):
    rescale = keras.layers.Rescaling(scale=1./255.)
    norms = keras.layers.Normalization(
        mean=[0.4914, 0.4822, 0.4465], 
        variance=[np.square(0.2023), np.square(0.1994), np.square(0.2010)], 
        invert=denorm,
        axis=-1,
    )
    if not denorm:
        image = rescale(image)
    return norms(image), label

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.map(normalize)
train_ds = train_ds.shuffle(buffer_size=8*BATCH_SIZE)
train_ds = train_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

x, y = next(iter(train_ds))
x.shape, y.shape
(TensorShape([64, 32, 32, 3]), TensorShape([64, 1]))

torch.utils.data

# Define a custom PyTorch dataset class
class TFDatasetWrapper(Dataset):
    def __init__(self, tf_dataset):
        self.tf_dataset = tf_dataset

    def __len__(self):
        return len(x_train)

    def __getitem__(self, idx):
        return next(iter(self.tf_dataset.skip(idx).take(1)))

def tf_collate_fn(batch):
    x, y = zip(*batch)
    x = torch.stack(x).permute(0, 3, 1, 2).type(torch.FloatTensor)
    y = torch.stack(y)
    return x, y

def iter_tf_data(train_ds):
    x_list = []
    y_list = []
    for data in train_ds.as_numpy_iterator():
        x, y = data
        x_list += [torch.from_numpy(x)] 
        y_list += [torch.from_numpy(y)]
    x_list_cat = torch.cat(x_list, axis=0)
    y_list_cat = torch.cat(y_list, axis=0)
    return [x_list_cat, y_list_cat]
def tf_dataset_to_pytorch_dataloader(
    tf_dataset, batch_size, shuffle=True, num_workers=0
):
    """Converts a TensorFlow Dataset to a PyTorch DataLoader."""
    data_list = iter_tf_data(tf_dataset)
    pytorch_dataset = TensorDataset(*data_list)
    pytorch_dataloader = DataLoader(
        pytorch_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=tf_collate_fn
    )
    return pytorch_dataloader
train_ds_torch = tf_dataset_to_pytorch_dataloader(
    train_ds, batch_size=BATCH_SIZE // 2, shuffle=True
)
x, y = next(iter(train_ds_torch))
x.shape, y.shape
(torch.Size([32, 3, 32, 32]), torch.Size([32, 1]))

最后,让我们可视化 loader 中的一些示例。torch

import matplotlib.pyplot as plt
from torchvision.utils import make_grid

fig, ax = plt.subplots(figsize=(12, 6))
plt.title("CIFAR10 dataset")
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(make_grid(x, nrow=8).permute(1, 2, 0))
plt.show()

download