我已经开始工作的一个项目,其中我需要检测训练的参数对于给定的scikit学习估计,如果可能的话,找个分类变量(和连续那些合理的时间间隔)的允许值。如何检测参数网格中允许使用哪些值? (sklearn)
我可以使用estimator.get_params()
来获取带参数的字典,然后使用estimator.set_params(**{'var1':val1, 'var2':val2})
来设置值,依此类推。
例如,对于KNN分类器,我们有以下的参数字典: {'metric': 'minkowski', 'algorithm': 'auto', 'n_neighbors': 10, 'n_jobs': 1, 'p': 2, 'metric_params': None, 'weights': 'uniform', 'leaf_size': 30}
。
现在,我可以使用这些值的类型来推断哪些是分类的(str
类型),连续的(float
),离散的(int
)等等。一个可能相关的问题是默认设置为NoneType
的参数,但我可能只是不会触及这些参数,理由很充分。
目前的挑战变得推断和定义参数网格用于例如使用RandomizedSearchCV
。对于离散变量和连续变量,这个问题可以用例如的try
的组合 - except
块与scipy.stats模块一起,可能限制了间隔为位于所述“附近”周围的默认值(但同时小心不要设置例如n_jobs
一些疯狂的值 - 即可能需要硬编码,或稍后明确设置)。如果你有类似的经验,并有一些提示/技巧你的袖子,我很想听听他们。
但现在真正的问题是:如何推断例如algorithm
表示允许的值实际上是{‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’}
??
我刚开始寻找到这个问题,也许如果我们试图将它设置为一些未允许值,我们可以分析我们得到错误信息?我就看出来的好点子在这里,因为我想避免手动做到这一点(我会,如果我必须这样做,但它似乎相当不雅...)
谢谢!
自我提醒:这可能是一个非常困难的/无法解决的问题。我在api和源代码中探索过,并且看看如何auto-sklearn解决了这个问题。看来,手动(硬编码)解决方案是现在的方式。 – Magnus
你有那里有趣的问题。除了[解析签名和默认参数](https://stackoverflow.com/questions/2677185/how-can-i-read-a-functions-signature-including-default-argument-values)我想我会尝试解析scikit-learn的文档字符串,如[this](https://stackoverflow.com/questions/713138/getting-the-docstring-from-a-function)。另一个尝试将解析字符串化函数,例如估计器的__init__',但这是一个杂乱无章的镜头,因为我没有看到任何检查正在完成,并且您可能需要查看整个层次结构。 – mkaran
你好!很高兴你发现这个主题很有趣。是的,这是/我正在考虑/正在考虑的一个选项(解析文档)。但是令我担心的是,文档编写的方式是一致的,并且没有强制的约定(但我可能是错误的),这可能会被利用。我可能会花一点时间来实现一个解析器并在一堆文档中测试它...... – Magnus