Asked by: Sol  Asked: 10/27/2023  Modified: 10/28/2023  Viewed: 16
BrokenProcessPool while running k-fold cross-validation
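Q:
I've been trying to run k-fold cross-validation on a perceptron model. There was an earlier error, but thanks to someone's help I was able to solve it. Then I ran into a new error message, shown below.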
import tensorflow as tf
from sklearn.model_selection import RepeatedKFold, cross_val_score
from tensorflow.keras.wrappers.scikit_learn import KerasRegressor

class Perceptron(tf.keras.Model):
    # Single-unit perceptron: one Dense layer with a sigmoid activation
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(units=1, activation='sigmoid')

    def call(self, inputs):
        return self.dense(inputs)

def buildmodel():
    # Called by the wrapper to create and compile a fresh model for each fit
    model_kfold = Perceptron()
    model_kfold.compile(optimizer='sgd', loss='binary_crossentropy', metrics=['accuracy'])
    return model_kfold

# x_train and y_train are defined earlier in the notebook
estimator = KerasRegressor(build_fn=buildmodel, epochs=100, batch_size=10, verbose=0)
kfold = RepeatedKFold(n_splits=5, n_repeats=10)
results = cross_val_score(estimator, x_train, y_train, cv=kfold, n_jobs=2)  # 2 cpus
results.mean()  # Mean MSE
Here is the error message:
---------------------------------------------------------------------------
_RemoteTraceback Traceback (most recent call last)
_RemoteTraceback:
"""
Traceback (most recent call last):
File "/home/jinzzasol/miniconda3/envs/tensorflow/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 391, in _process_worker
call_item = call_queue.get(block=True, timeout=timeout)
File "/home/jinzzasol/miniconda3/envs/tensorflow/lib/python3.10/multiprocessing/queues.py", line 122, in get
return _ForkingPickler.loads(res)
ModuleNotFoundError: No module named 'keras.wrappers'
"""
The above exception was the direct cause of the following exception:
BrokenProcessPool Traceback (most recent call last)
/home/jinzzasol/Code/cs5834/assignment4/hw4.ipynb Cell 43 line 1
11 estimator= KerasRegressor(build_fn=buildmodel, epochs=100, batch_size=10, verbose=0)
12 kfold= RepeatedKFold(n_splits=5, n_repeats=10)
---> 13 results= cross_val_score(estimator, x_train, y_train, cv=kfold, n_jobs=2) # 2 cpus
14 results.mean() # Mean MSE
File ~/miniconda3/envs/tensorflow/lib/python3.10/site-packages/sklearn/model_selection/_validation.py:562, in cross_val_score(estimator, X, y, groups, scoring, cv, n_jobs, verbose, fit_params, pre_dispatch, error_score)
559 # To ensure multimetric format is not supported
560 scorer = check_scoring(estimator, scoring=scoring)
--> 562 cv_results = cross_validate(
...
404 finally:
405 # Break a reference cycle with the exception in self._exception
406 self = None
BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.
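The first traceback shows the worker process failing inside `_ForkingPickler.loads` with `ModuleNotFoundError: No module named 'keras.wrappers'`. `pickle` stores a class by its module path and name rather than by value, so a worker can only rebuild the estimator if it can import that same module. The stdlib-only sketch below illustrates that mechanism under stated assumptions: the module name `fake_wrappers` and the `Wrapper` class are hypothetical stand-ins (not Keras or joblib code), and deleting the module from `sys.modules` merely simulates a worker that cannot import it.

```python
import pickle
import sys
import types

# Hypothetical module "fake_wrappers", created at runtime to stand in for keras.wrappers.
mod = types.ModuleType("fake_wrappers")
exec(
    "class Wrapper:\n"
    "    def __init__(self, epochs):\n"
    "        self.epochs = epochs\n",
    mod.__dict__,
)
sys.modules["fake_wrappers"] = mod

# pickle records the class as "fake_wrappers.Wrapper", not the class body itself.
payload = pickle.dumps(mod.Wrapper(epochs=100))

# Simulate a worker process in which that module cannot be imported.
del sys.modules["fake_wrappers"]


def try_load(data):
    """Try to unpickle data, returning "ok" or the import error text."""
    try:
        pickle.loads(data)
        return "ok"
    except ModuleNotFoundError as exc:
        return f"ModuleNotFoundError: {exc}"


print(try_load(payload))
```

Re-registering the module (`sys.modules["fake_wrappers"] = mod`) makes `try_load(payload)` succeed again, which is why this kind of failure depends on what the worker process is able to import, not on the data itself.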
A:
0 votes
Sol
10/28/2023
#1
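I found that this happened because the custom model I had written was not compatible with the cross_val_score method. I ended up building the model with Keras functions instead of my own model class. Thanks.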