Asked by: G. Lippolis  Asked: 10/29/2023  Updated: 10/30/2023  Views: 24
How to use TensorFlow Dataset.from_generator
Q:
I am trying to fit a model on a dataset built with tf.data.Dataset.from_generator, but the fit fails.
Here is the code that builds the dataset:
cd_gen = CordicDatasetFT(14)
cos = (tf.TensorSpec(shape=(14, 3), dtype=tf.float32, name=None),
       tf.TensorSpec(shape=(14, 3), dtype=tf.float32, name=None))
cds = tf.data.Dataset.from_generator(cd_gen, output_signature=cos)
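CordicDatasetFT itself is not shown in the question. For reference, here is a minimal hypothetical stand-in that would be compatible with this output_signature; from_generator expects a callable that returns an iterator yielding (x, y) pairs matching the spec:

import numpy as np
import tensorflow as tf

class CordicDatasetFT:
    # Hypothetical stand-in: the real class is the asker's own code.
    def __init__(self, length):
        self.length = length          # time steps per sample (14 here)

    def __call__(self):
        # from_generator calls this to obtain a fresh iterator
        print("[CordicDatasetFT]: call")
        while True:                   # endless stream of samples
            x = np.random.rand(self.length, 3).astype(np.float32)
            y = np.random.rand(self.length, 3).astype(np.float32)
            yield x, y                # each element matches shape (14, 3), float32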
It looks like it is ready to train my model:
print(type(cds))
cds_tst = cds.batch(512)
for batch_it in cds_tst:
    x, y = batch_it
    y_pre = model.predict(x)
    print(y_pre.shape)
    print("step")
    break
<class 'tensorflow.python.data.ops.flat_map_op._FlatMapDataset'>
[CordicDatasetFT]: call
16/16 [==============================] - 0s 7ms/step
(512, 14, 3)
step
However, if I try to fit:
history=model.fit(cds, epochs=1)
I get this error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[19], line 4
1 #model.fit(ds, validation_data=ds, batch_size=512, epochs=20, steps_per_epoch=256, validation_steps=32)
2 #history=model.fit(cds, batch_size=512, epochs=75, steps_per_epoch=256)
3 print(type(cds))
----> 4 history=model.fit(cds, epochs=1)
File /opt/conda/lib/python3.10/site-packages/keras/utils/traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
67 filtered_tb = _process_traceback_frames(e.__traceback__)
68 # To get the full stack trace, call:
69 # `tf.debugging.disable_traceback_filtering()`
---> 70 raise e.with_traceback(filtered_tb) from None
71 finally:
72 del filtered_tb
File /opt/conda/lib/python3.10/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py:926, in Function._call(self, *args, **kwds)
923 self._lock.release()
924 # In this case we have created variables on the first call, so we run the
925 # defunned version which is guaranteed to never create variables.
--> 926 return self._no_variable_creation_fn(*args, **kwds) # pylint: disable=not-callable
927 elif self._variable_creation_fn is not None:
928 # Release the lock early so that multiple threads can perform the call
929 # in parallel.
930 self._lock.release()
TypeError: 'NoneType' object is not callable
Where am I going wrong?
A:
0 votes
G. Lippolis
10/30/2023
#1
I found the solution: the dataset has to be batched. When a tf.data.Dataset is passed to fit, Keras does not batch it for you, so the dataset itself must yield batches:
cds = tf.data.Dataset.from_generator(cd_gen, output_signature=cos)
cds = cds.batch(512)
print(cds.element_spec)
(TensorSpec(shape=(None, 14, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 14, 3), dtype=tf.float32, name=None))
I also had to adjust the fit call:
history=model.fit(cds, epochs=1, steps_per_epoch=256)
256/256 [==============================] - 35s 135ms/step - loss: 5.0874 - mean_squared_error: 5.0874
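Putting it together, a minimal sketch of the corrected pipeline, assuming cd_gen, cos, and model are the objects defined in the question and that the generator yields samples indefinitely; the .prefetch call is an optional addition, not part of the original answer:

cds = tf.data.Dataset.from_generator(cd_gen, output_signature=cos)
cds = cds.batch(512).prefetch(tf.data.AUTOTUNE)   # fit expects the dataset to yield batches

# With an endless generator, steps_per_epoch bounds each epoch
# (otherwise the first epoch would never finish).
history = model.fit(cds, epochs=1, steps_per_epoch=256)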