2017-06-30 15 views
1

我已经开始工作的一个项目,其中我需要检测训练的参数对于给定的scikit学习估计,如果可能的话,找个分类变量(和连续那些合理的时间间隔)的允许值。如何检测参数网格中允许使用哪些值? (sklearn)

我可以使用estimator.get_params()来获取带参数的字典,然后使用estimator.set_params(**{'var1':val1, 'var2':val2})来设置值,依此类推。

例如,对于KNN分类器,我们有以下的参数字典: {'metric': 'minkowski', 'algorithm': 'auto', 'n_neighbors': 10, 'n_jobs': 1, 'p': 2, 'metric_params': None, 'weights': 'uniform', 'leaf_size': 30}

现在,我可以使用这些值的类型来推断哪些是分类的(str类型),连续的(float),离散的(int)等等。一个可能相关的问题是默认设置为NoneType的参数,但我可能只是不会触及这些参数,理由很充分。

目前的挑战变得推断和定义参数网格用于例如使用RandomizedSearchCV。对于离散变量和连续变量,这个问题可以用例如的try的组合 - except块与scipy.stats模块一起,可能限制了间隔为位于所述“附近”周围的默认值(但同时小心不要设置例如n_jobs一些疯狂的值 - 即可能需要硬编码,或稍后明确设置)。如果你有类似的经验,并有一些提示/技巧你的袖子,我很想听听他们。

但现在真正的问题是:如何推断例如algorithm表示允许的值实际上是{‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’} ??

我刚开始寻找到这个问题,也许如果我们试图将它设置为一些未允许值,我们可以分析我们得到错误信息?我就看出来的好点子在这里,因为我想避免手动做到这一点(我会,如果我必须这样做,但它似乎相当不雅...)

谢谢!

+0

自我提醒:这可能是一个非常困难的/无法解决的问题。我在api和源代码中探索过,并且看看如何auto-sklearn解决了这个问题。看来,手动(硬编码)解决方案是现在的方式。 – Magnus

+0

你有那里有趣的问题。除了[解析签名和默认参数](https://stackoverflow.com/questions/2677185/how-can-i-read-a-functions-signature-including-default-argument-values)我想我会尝试解析scikit-learn的文档字符串,如[this](https://stackoverflow.com/questions/713138/getting-the-docstring-from-a-function)。另一个尝试将解析字符串化函数,例如估计器的__init__',但这是一个杂乱无章的镜头,因为我没有看到任何检查正在完成,并且您可能需要查看整个层次结构。 – mkaran

+0

你好!很高兴你发现这个主题很有趣。是的,这是/我正在考虑/正在考虑的一个选项(解析文档)。但是令我担心的是,文档编写的方式是一致的,并且没有强制的约定(但我可能是错误的),这可能会被利用。我可能会花一点时间来实现一个解析器并在一堆文档中测试它...... – Magnus

回答

0

我发现我一直在寻找的具体示例的解决方案,但是,它并没有因为没有一套约定WRT他们是如何在sklearn每个估计书面推广以及到其他文档串。

因此,我发表我的“解决方案”,以便其他人可以接管并在其上可能会提高。请参见下面的代码片段:

import re 
from pprint import pprint 
from sklearn.neighbors import KNeighborsClassifier 
knn = KNeighborsClassifier() 
doc = knn.__doc__ # Get the doc string 
#from sklearn.svm import SVC 
#svc = SVC() 
#doc = svc.__doc__ 
pattern = "([a-zA-Z_]+\s:\s)|(-\s*)'([a-zA-Z_]+)'" # Define search pattern 
re.compile(pattern) 
matches = re.findall(pattern, doc) 

clf_params = {} 
previous_param = '' 
for param, _, value in matches: 
    if ":" in param and param[-4]!="_": # 'Hack-y' 
     if param not in clf_params.keys(): 
      clf_params[param] = list() 
      previous_param = param 
     else: 
      if len(value)>0: 
       clf_params[previous_param].append(value) 
pprint(clf_params) 

这个片断输出

{'algorithm : ': ['ball_tree', 'kd_tree', 'brute', 'auto'], 
'leaf_size : ': [], 
'metric : ': [], 
'metric_params : ': [], 
'n_jobs : ': [], 
'n_neighbors : ': [], 
'p : ': [], 
'weights : ': ['uniform', 'distance']} 

这是正确的。

然而,如果我们重复同样的程序SVC().__doc__,我们将看到它失败。

我希望有人认为这有点用处。

相关问题