提问人:Magnus 提问时间:6/30/2017 最后编辑:Zoe is on strikeMagnus 更新时间:8/17/2023 访问量:1205
如何检测参数网格中允许的值?
How to detect which values are allowable in a parameter grid?
问:
我已经开始从事一个项目,在该项目中,我需要检测给定 scikit-learn 估计器的可训练参数,如果可能的话,找到分类变量的允许值(以及连续变量的合理间隔)。
我可以获取带有参数的字典 using,然后设置一个值 using,依此类推。estimator.get_params()
estimator.set_params(**{'var1':val1, 'var2':val2})
例如,对于 KNN 分类器,我们有以下参数字典:.{'metric': 'minkowski', 'algorithm': 'auto', 'n_neighbors': 10, 'n_jobs': 1, 'p': 2, 'metric_params': None, 'weights': 'uniform', 'leaf_size': 30}
现在,我可以使用值的类型来推断哪些是分类(类型)、连续()、离散()等。一个可能相关的问题是默认值设置为 的参数,但我可能无论如何都不会碰这些参数,这是有充分理由的。str
float
int
NoneType
现在的挑战是推断和定义一个参数网格,例如。对于离散变量和连续变量,该问题很容易使用例如将 - 块与 scipy.stats 模块组合在一起,可能会将区间限制在默认值附近的“附近”(但同时要注意不要设置例如 到一些疯狂的值 - 可能需要硬编码,或者稍后显式设置)。如果您有类似经验,并且有一些提示/技巧,我很想听听。RandomizedSearchCV
try
except
n_jobs
但现在真正的问题是:如何推断例如 允许的值实际上是??algorithm
{‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’}
我刚刚开始研究这个问题,如果我们尝试将其设置为某个不允许的值,也许我们可以解析我们得到的错误消息?我在这里寻找好主意,因为我想避免手动执行此操作(如果必须的话,我会这样做,但这似乎很不优雅......
答:
我找到了我正在查看的特定示例的解决方案,但是,它不能很好地推广到其他文档字符串,因为没有固定的约定来说明它们如何为 sklearn 中的每个估算器编写。
因此,我发布我的“解决方案”,以便其他人可以接管并可能改进它。请参阅以下代码片段:
import re
from pprint import pprint
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier()
doc = knn.__doc__ # Get the doc string
#from sklearn.svm import SVC
#svc = SVC()
#doc = svc.__doc__
pattern = "([a-zA-Z_]+\s:\s)|(-\s*)'([a-zA-Z_]+)'" # Define search pattern
re.compile(pattern)
matches = re.findall(pattern, doc)
clf_params = {}
previous_param = ''
for param, _, value in matches:
if ":" in param and param[-4]!="_": # 'Hack-y'
if param not in clf_params.keys():
clf_params[param] = list()
previous_param = param
else:
if len(value)>0:
clf_params[previous_param].append(value)
pprint(clf_params)
此代码片段打印
{'algorithm : ': ['ball_tree', 'kd_tree', 'brute', 'auto'],
'leaf_size : ': [],
'metric : ': [],
'metric_params : ': [],
'n_jobs : ': [],
'n_neighbors : ': [],
'p : ': [],
'weights : ': ['uniform', 'distance']}
这是正确的。
但是,如果我们重复相同的过程,我们将看到它失败了。SVC().__doc__
我希望有人觉得这有点有用。
评论
我试图从文档字符串(LinearSVC 作为示例算法)中获取所有这些内容,这得到了以下方面的极大帮助:splitlines()
liner = str(LinearSVC().__doc__).split('Parameters\n ----------\n')[1].split('\n\n Attributes\n')[0].replace('\n ', '\n').splitlines()
这不会创建一个字典,但足够简单,只需从文档字符串中提取解释的“参数”部分,该部分解释了所有参数,并列出了所有可能的/预期/接受的值输入,这些输入被一个很好地缩进,制表符,现在我们可以使用带有条件的简单循环,使用“ : “作为我们的锚点,用于识别可能/预期/可接受的值输入线:
for i in liner:
...: if " : " in i: #<<< the key is to use " : " as our anchor
...: print(i)
最终结果打印出来:
penalty : str, 'l1' or 'l2' (default='l2')
loss : str, 'hinge' or 'squared_hinge' (default='squared_hinge')
dual : bool, (default=True)
tol : float, optional (default=1e-4)
C : float, optional (default=1.0)
multi_class : str, 'ovr' or 'crammer_singer' (default='ovr')
fit_intercept : bool, optional (default=True)
intercept_scaling : float, optional (default=1)
class_weight : {dict, 'balanced'}, optional
verbose : int, (default=0)
random_state : int, RandomState instance or None, optional (default=None)
max_iter : int, (default=1000)
很高兴我能分享,如果其他人需要完整的文档字符串参数打印输出,只需使用:
print(str(LinearSVC().__doc__).split('Parameters\n ----------\n')[1].split('\n\n Attributes\n')[0].replace('\n ', '\n'))
编辑:如果这不是要打印出来的 - 将其作为字符串对象的最佳方法是使用列表推导式,但它需要一些丑陋的替换,因为文档字符串中有广泛的符号:
docstring_short = str([i for i in liner.splitlines() if " : " in i]).replace('[" ', '').replace(' ', ',\n').replace('", "', '').replace('", \'', '').replace("', '", '').replace("', \"", '').replace(']', '')
评论
__init__