如何向 Optuna 函数添加交叉验证以调整 LSTM 的超参数?

How to add cross validation to Optuna function to tune hyperparameters for LSTM?

提问人:dingaro 提问时间:11/17/2023 最后编辑:desertnautdingaro 更新时间:11/18/2023 访问量:15

问:

我有代码来调整LSTM中的超参数。我怎样才能:

  1. 在训练数据集上添加基于 5 个 folds 的交叉验证
  2. 从训练数据集中打印每次迭代的平均平均值,分为 5 个折叠AUC
  3. 从测试数据集打印(当然不要在折叠上划分测试数据集):AUC
def objective(trial):
    start_time = time.time()
    model = create_model(trial)
    history = model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=15, verbose=0)

    y_pred = model.predict(X_test)
    auc = roc_auc_score(y_test, y_pred)
    
    end_time = time.time()
    elapsed_time = end_time - start_time

    print("iteration no:", trial.number)
    print("AUC:", auc)
    print("hyperparameters:", trial.params)
    print("time:", elapsed_time, "sec")

    return auc 

如何在Python中做到这一点?

Python 机器学习 LSTM 交叉验证 optuna

评论


答: 暂无答案