在 tf.data 中切片会导致“迭代 'tf.Graph 执行中不允许出现 Tensor'“错误

Slicing in tf.data causes "iterating over `tf.Tensor` is not allowed in Graph execution" error

提问人:momo 提问时间:4/8/2021 最后编辑:momo 更新时间:4/8/2021 访问量:528

问:

我创建了一个数据集,如下所示,其中是图像文件路径列表, 例如。.我需要提取文件夹路径,例如,然后进行一些其他操作。我尝试使用以下函数来执行此操作。image_train_path[b'/content/drive/My Drive/data/folder1/im1.png', b'/content/drive/My Drive/data/folder2/im6.png',...]'/content/drive/My Drive/data/folder1'preprocessData

dataset = tf.data.Dataset.from_tensor_slices(image_train_path)
dataset = dataset.map(preprocessData, num_parallel_calls=16)

在哪里:preprocessData

def preprocessData(images_path):
    folder=tf.strings.split(images_path,'/')
    foldername=tf.strings.join(tf.slice(folder,(0,),(6,)),'/')
    ....

但是,切片线会导致以下错误:

OperatorNotAllowedInGraphError: in user code:

    <ipython-input-21-2a9827982c16>:4 preprocessData  *
        foldername=tf.strings.join(tf.slice(folder,(0,),(6,)),'/')
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:210 wrapper  **
        result = dispatch(wrapper, args, kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py:122 dispatch
        result = dispatcher.handle(args, kwargs)
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/ragged/ragged_dispatch.py:130 handle
        for elt in x:
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:524 __iter__
        self._disallow_iteration()
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:520 _disallow_iteration
        self._disallow_in_graph_mode("iterating over `tf.Tensor`")
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py:500 _disallow_in_graph_mode
        " this function with @tf.function.".format(task))

    OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

我在 Tf2.4 和 tf nightly 中都尝试过这个。我尝试过装饰和使用.总是给出相同的错误。@tf.functiontf.data.experimental.enable_debug_mode()

我不太明白哪个部分导致了“迭代”,尽管我想问题出在切片上。有没有其他方法可以做到这一点?

Python TensorFlow 切片 tf.data.dataset

评论

0赞 krenerd 4/8/2021
你能发布完整的代码吗?preprocessData

答:

1赞 Lescurel 4/8/2021 #1

函数 tf.strings.join 需要 Tensor 的列表,如文档所述:

参数

inputs:tf 的列表。相同大小和 tf.string dtype 的张量对象。

tf.slice返回一个 Tensor,然后 join 函数将尝试遍历它,从而导致错误。

您可以使用简单的列表推导式来馈送函数:

def preprocessData(images_path):
    folder=tf.strings.split(images_path,'/')
    foldername=tf.strings.join([folder[i] for i in range(6)],"/")
    return foldername