How to use TensorFlow Dataset.from_generator


Asked by: G. Lippolis  Asked: 10/29/2023  Updated: 10/30/2023  Views: 24

Q:

I am trying to fit a model with a dataset built using `tf.data.Dataset.from_generator`, but the call to `fit` fails.

Here is the code for the dataset:

cd_gen=CordicDatasetFT(14)
cos=(tf.TensorSpec(shape=(14, 3), dtype=tf.float32, name=None),
     tf.TensorSpec(shape=(14, 3), dtype=tf.float32, name=None))
cds = tf.data.Dataset.from_generator(cd_gen, output_signature = cos)

It seems to be ready to feed my model, since prediction on a batched copy works:

print(type(cds))
cds_tst=cds.batch(512)

for batch_it in cds_tst:
    x, y = batch_it
    y_pre=model.predict(x)
    print(y_pre.shape)
    print("step")
    break
<class 'tensorflow.python.data.ops.flat_map_op._FlatMapDataset'>
[CordicDatasetFT]: call
16/16 [==============================] - 0s 7ms/step
(512, 14, 3)
step

However, if I try to fit:

history=model.fit(cds, epochs=1)

I get this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[19], line 4
      1 #model.fit(ds, validation_data=ds, batch_size=512, epochs=20, steps_per_epoch=256, validation_steps=32)
      2 #history=model.fit(cds, batch_size=512, epochs=75, steps_per_epoch=256)
      3 print(type(cds))
----> 4 history=model.fit(cds, epochs=1)

File /opt/conda/lib/python3.10/site-packages/keras/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     67     filtered_tb = _process_traceback_frames(e.__traceback__)
     68     # To get the full stack trace, call:
     69     # `tf.debugging.disable_traceback_filtering()`
---> 70     raise e.with_traceback(filtered_tb) from None
     71 finally:
     72     del filtered_tb

File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:926, in Function._call(self, *args, **kwds)
    923   self._lock.release()
    924   # In this case we have created variables on the first call, so we run the
    925   # defunned version which is guaranteed to never create variables.
--> 926   return self._no_variable_creation_fn(*args, **kwds)  # pylint: disable=not-callable
    927 elif self._variable_creation_fn is not None:
    928   # Release the lock early so that multiple threads can perform the call
    929   # in parallel.
    930   self._lock.release()

TypeError: 'NoneType' object is not callable

What am I doing wrong?

python tensorflow tensorflow-datasets


A:

0 votes  G. Lippolis  10/30/2023  #1

I found the solution. The problem is that the dataset has to be batched before it is passed to `fit`:

cds = tf.data.Dataset.from_generator(cd_gen, output_signature = cos)
cds = cds.batch(512)
print(cds.element_spec)

(TensorSpec(shape=(None, 14, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 14, 3), dtype=tf.float32, name=None))
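
The effect of `batch` on the element spec can be checked in isolation. The following is a minimal sketch, not the original code: `pair_gen` is a hypothetical stand-in for `CordicDatasetFT(14)` (whose implementation is not shown), assumed to yield pairs of (14, 3) float32 arrays.

```python
import numpy as np
import tensorflow as tf


def pair_gen():
    # Hypothetical stand-in for CordicDatasetFT(14): yields (x, y) pairs forever.
    rng = np.random.default_rng(0)
    while True:
        x = rng.standard_normal((14, 3)).astype(np.float32)
        yield x, x  # placeholder target, same shape as the input


spec = (
    tf.TensorSpec(shape=(14, 3), dtype=tf.float32),
    tf.TensorSpec(shape=(14, 3), dtype=tf.float32),
)

unbatched = tf.data.Dataset.from_generator(pair_gen, output_signature=spec)
batched = unbatched.batch(512)

# Without batch(), each element is a single (14, 3) sample with no batch axis.
print(unbatched.element_spec[0].shape)  # (14, 3)
# After batch(), a leading batch dimension of unknown size is added.
print(batched.element_spec[0].shape)    # (None, 14, 3)
```

Note that the `batch_size` argument of `fit` does not batch a `tf.data.Dataset`; the batching has to happen on the dataset itself, as above.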

I also needed to adjust the `fit` call:

history=model.fit(cds, epochs=1, steps_per_epoch=256)
256/256 [==============================] - 35s 135ms/step - loss: 5.0874 - mean_squared_error: 5.0874
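
The answer does not say why `steps_per_epoch` was needed. One plausible reason, stated here as an assumption, is that the generator never stops yielding, so Keras has no way to tell where an epoch ends. The sketch below illustrates that case with a hypothetical endless generator `pair_gen` and a toy model; neither is the code from the question.

```python
import numpy as np
import tensorflow as tf


def pair_gen():
    # Hypothetical endless generator; the target is simply the input.
    rng = np.random.default_rng(0)
    while True:
        x = rng.standard_normal((14, 3)).astype(np.float32)
        yield x, x


spec = (
    tf.TensorSpec(shape=(14, 3), dtype=tf.float32),
    tf.TensorSpec(shape=(14, 3), dtype=tf.float32),
)
ds = tf.data.Dataset.from_generator(pair_gen, output_signature=spec).batch(32)

# Toy model (an assumption): a Dense layer applied along the last axis,
# mapping (batch, 14, 3) to (batch, 14, 3).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(14, 3)),
    tf.keras.layers.Dense(3),
])
model.compile(optimizer="adam", loss="mse")

# The dataset is infinite, so steps_per_epoch tells fit() when an epoch ends.
history = model.fit(ds, epochs=1, steps_per_epoch=4, verbose=0)
print(len(history.history["loss"]))  # one loss value per epoch
```

If the generator were finite, `steps_per_epoch` could be omitted and each epoch would simply run until the dataset is exhausted.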