Error "could not convert string to float" when setting user-defined heterogeneous distance metric for mixed-type dataset in sklearn.neighbors

Asked by: Tomas H. Asked: 10/27/2023 Last edited by: DataJanitor, Tomas H. Updated: 10/27/2023 Views: 38

Q:

I hope someone can help me with the following problem:

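I have a mixed-type dataset in Python containing numerical (dtypes: int, float, bool) and categorical (dtype: categorical) variables. I would like to train a nearest neighbors algorithm on this dataset using the NearestNeighbors class from sklearn.neighbors. To handle the different data types, I want to initialize the metric parameter with a heterogeneous distance metric. The description in sklearn.neighbors states that "any metric from scikit-learn or scipy.spatial.distance can be used" to define this parameter. Because (as far as I know) these do not include heterogeneous distance metrics, I decided to use distython: a user-defined distance metric class that can compute heterogeneous distances on datasets with both numerical and categorical variables.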

My code is as follows:

from sklearn.neighbors import NearestNeighbors
from distython import HEOM

X_train # dataset with numerical & categorical variables
catIndices # column indices of categorical variables

# initialize heterogeneous distance metric
heom_metric = HEOM(X_train, catIndices)
    
# Construct & train nearest neighbor algorithm
neigh = NearestNeighbors(n_neighbors=5, metric = heom_metric.heom)
neigh.fit(X_train)

However, I get the following error:

Cell In[17], line 12
     11 neigh = NearestNeighbors(n_neighbors=5, metric = heom_metric.heom)
---> 12 neigh.fit(X_train)

File ~\anaconda3\Lib\site-packages\sklearn\base.py:1152, in _fit_context.<locals>.decorator.<locals>.wrapper(estimator, *args, **kwargs)
   1145     estimator._validate_params()
   1147 with config_context(
   1148     skip_parameter_validation=(
   1149         prefer_skip_nested_validation or global_skip_validation
   1150     )
   1151 ):
-> 1152     return fit_method(estimator, *args, **kwargs)

File ~\anaconda3\Lib\site-packages\sklearn\neighbors\_unsupervised.py:175, in NearestNeighbors.fit(self, X, y)
    154 @_fit_context(
    155     # NearestNeighbors.metric is not validated yet
    156     prefer_skip_nested_validation=False
    157 )
    158 def fit(self, X, y=None):
    159     """Fit the nearest neighbors estimator from the training dataset.
    160 
    161     Parameters
   (...)
    173         The fitted nearest neighbors estimator.
    174     """
--> 175     return self._fit(X)

File ~\anaconda3\Lib\site-packages\sklearn\neighbors\_base.py:498, in NeighborsBase._fit(self, X, y)
    496 else:
    497     if not isinstance(X, (KDTree, BallTree, NeighborsBase)):
--> 498         X = self._validate_data(X, accept_sparse="csr", order="C")
    500 self._check_algorithm_metric()
    501 if self.metric_params is None:

File ~\anaconda3\Lib\site-packages\sklearn\base.py:605, in BaseEstimator._validate_data(self, X, y, reset, validate_separately, cast_to_ndarray, **check_params)
    603         out = X, y
    604 elif not no_val_X and no_val_y:
--> 605     out = check_array(X, input_name="X", **check_params)
    606 elif no_val_X and not no_val_y:
    607     out = _check_y(y, **check_params)

File ~\anaconda3\Lib\site-packages\sklearn\utils\validation.py:915, in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)
    913         array = xp.astype(array, dtype, copy=False)
    914     else:
--> 915         array = _asarray_with_order(array, order=order, dtype=dtype, xp=xp)
    916 except ComplexWarning as complex_warning:
    917     raise ValueError(
    918         "Complex data not supported\n{}\n".format(array)
    919     ) from complex_warning

File ~\anaconda3\Lib\site-packages\sklearn\utils\_array_api.py:380, in _asarray_with_order(array, dtype, order, copy, xp)
    378     array = numpy.array(array, order=order, dtype=dtype)
    379 else:
--> 380     array = numpy.asarray(array, order=order, dtype=dtype)
    382 # At this point array is a NumPy ndarray. We convert it to an array
    383 # container that is consistent with the input's namespace.
    384 return xp.asarray(array)

ValueError: could not convert string to float: [element from categorical column]

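I know that the fit() method of the NearestNeighbors class does not handle categorical objects, because the kNN algorithm cannot compute distances between string elements (obviously). However, I don't understand why this error is also raised when I explicitly declare a distance metric that takes a combination of numerical and categorical values as input and returns a numerical distance value as output.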

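My hypothesis is that, for some reason, my user-defined distance metric is not properly "recognized" as heterogeneous by the NearestNeighbors class, so fit() raises this ValueError before the metric even has a chance to compute a distance. Strangely, this error does not occur in the example provided on the GitHub page of the user-defined distance metric.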
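
The traceback above suggests the error is raised during input validation (check_array), before any distance is computed. A minimal sketch that checks this, assuming a small made-up object array and a hypothetical toy_metric that stands in for the distython metric and simply records whether it was called:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

calls = []  # records every invocation of the custom metric


def toy_metric(x, y):
    # Hypothetical stand-in for a heterogeneous metric (not distython's HEOM):
    # counts the positions where the two rows differ.
    calls.append(1)
    return float(np.sum(x != y))


# Made-up mixed-type data: one numeric column, one string column.
X_mixed = np.array([[1.0, "red"], [2.0, "blue"], [3.0, "red"]], dtype=object)

neigh = NearestNeighbors(n_neighbors=2, metric=toy_metric)
try:
    neigh.fit(X_mixed)
    error_message = None
except ValueError as exc:
    # fit() tries to cast X to a float array before looking at the metric.
    error_message = str(exc)

print(error_message)
print("metric calls:", len(calls))
```

If the metric call count stays at zero, the ValueError does not come from the metric at all: the input array is rejected during validation regardless of which metric is passed.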

My question is: how can I solve this and make sure the NearestNeighbors class properly accepts my user-defined distance metric?
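
One possible direction, sketched under assumptions rather than confirmed as a fix: encode the categorical columns as numeric codes before calling fit(), so that validation passes, and let the metric treat those columns as categorical by position. The heom_like function below is a simplified, hand-written HEOM-style distance used only for illustration. It is not distython's implementation, and whether distython's HEOM behaves the same way on integer-coded columns is not verified here. The toy data and the explicit algorithm="brute" setting are likewise assumptions.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import OrdinalEncoder

# Made-up toy data: one numeric column and one categorical (string) column.
num = np.array([[1.0], [2.0], [10.0], [11.0]])
cat = np.array([["red"], ["red"], ["blue"], ["blue"]])

# Replace category labels with numeric codes so the array is all-float.
cat_codes = OrdinalEncoder().fit_transform(cat)
X_encoded = np.hstack([num, cat_codes])

num_idx = [0]  # positions of numeric columns in X_encoded
cat_idx = [1]  # positions of encoded categorical columns in X_encoded

# Range of each numeric column, used to scale differences to [0, 1].
# (Assumes no numeric column is constant; a zero range would divide by zero.)
ranges = X_encoded[:, num_idx].max(axis=0) - X_encoded[:, num_idx].min(axis=0)


def heom_like(x, y):
    # Categorical part: 0 if the codes match, 1 otherwise (overlap distance).
    # Only equality is checked, so the arbitrary code order has no effect.
    cat_part = (x[cat_idx] != y[cat_idx]).astype(float)
    # Numeric part: range-normalized absolute difference.
    num_part = np.abs(x[num_idx] - y[num_idx]) / ranges
    return float(np.sqrt(np.sum(cat_part ** 2) + np.sum(num_part ** 2)))


# Brute-force search calls the metric only on actual pairs of rows.
neigh = NearestNeighbors(n_neighbors=2, metric=heom_like, algorithm="brute")
neigh.fit(X_encoded)
distances, indices = neigh.kneighbors(X_encoded)
print(indices)
```

The design choice here is to keep the data numeric for scikit-learn while keeping the categorical semantics inside the metric, via the column index lists.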

Thanks in advance.

scikit-learn user-defined-functions distance metrics nearest-neighbor

A: No answers yet