2017-02-09 42 views
0

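I am trying to write a function that lets me flexibly run a grid search over a subset of the parameters in a dictionary: an iterator that loops over ranges of values for a subset of the dictionary's keys. The specific behavior I want is as follows: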

def my_grid_searching_function(fiducial_dict, **param_iterators):
    for params in desired_iterator:
        fiducial_dict.update(params)
        # compute chi^2
        # write new fiducial_dict values and associated chi^2 value to disk

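My specific goal is to figure out how to write desired_iterator.

The function my_grid_searching_function accepts an arbitrary subset of keyword arguments. Each keyword is interpreted as the name of a parameter in fiducial_dict, and its value is the range of values to iterate over.

This seems like a job for itertools.product, but I have run into a problem. In the implementation below, I use product to turn what would effectively be nested loops over the input iterators into a single loop: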

from itertools import product

def my_failed_grid_searching_function(fiducial_dict, **param_iterators):
    desired_iterator = product(*list(param_iterators.values()))
    for params in desired_iterator:
        print(params)

fiducial_dict = {'x': 0, 'y': 0, 'z': 9}
my_failed_grid_searching_function(fiducial_dict, x=[4, 5, 6], y=[1, 2])

(1, 4) 
(1, 5) 
(1, 6) 
(2, 4) 
(2, 5) 
(2, 6) 

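The problem, of course, is that the input param_iterators has been collapsed into an ordinary dictionary, so inside the namespace of my_failed_grid_searching_function I do not know what order the values are in.

Can anyone offer hints on how to write desired_iterator so that it yields enough information to update fiducial_dict, as shown above?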

Answers

1

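Since you are using arbitrary keyword arguments, you can grab the keys of the param_iterators dictionary to keep track of which position in each product tuple belongs to which parameter. Alternatively, I would recommend using the sklearn package to perform the grid search.

In any case, try this solution: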

from itertools import product

def my_grid_searching_function(fiducial_dict, **param_iterators):
    # list() so the keys can be indexed by position (required on Python 3)
    keys = list(param_iterators.keys())
    desired_iterator = list(product(*list(param_iterators.values())))
    for i in range(len(desired_iterator)):
        print("Epoch: ", i)
        for loc in range(len(desired_iterator[i])):
            print(keys[loc], desired_iterator[i][loc])
        # update your fiducial_dict here

my_grid_searching_function({'x': 0, 'y': 0}, x=[1,2,3,4], y=[6,7,8])

Output:

('Epoch: ', 0) 
('y', 6) 
('x', 1) 
('Epoch: ', 1) 
('y', 6) 
('x', 2) 
('Epoch: ', 2) 
('y', 6) 
('x', 3) 
('Epoch: ', 3) 
('y', 6) 
('x', 4) 
('Epoch: ', 4) 
('y', 7) 
('x', 1) 
('Epoch: ', 5) 
('y', 7) 
('x', 2) 
('Epoch: ', 6) 
('y', 7) 
('x', 3) 
('Epoch: ', 7) 
('y', 7) 
('x', 4) 
('Epoch: ', 8) 
('y', 8) 
('x', 1) 
('Epoch: ', 9) 
('y', 8) 
('x', 2) 
('Epoch: ', 10) 
('y', 8) 
('x', 3) 
('Epoch: ', 11) 
('y', 8) 
('x', 4) 
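
The approach above works because keys() and values() of the same dictionary iterate in matching order as long as the dictionary is not modified in between, so position i of each tuple from product belongs to key i. A minimal sketch of that pairing, using zip instead of index arithmetic; the variable names and the sample input are illustrative only and are not part of the answer's code:

```python
from itertools import product

# Illustrative input; in the real function this is the **param_iterators dict.
param_iterators = {'y': [6, 7, 8], 'x': [1, 2, 3, 4]}

names = list(param_iterators.keys())
# product is lazy: combinations are produced one at a time, not stored.
combinations = product(*param_iterators.values())

first = next(combinations)
# Position i of the tuple belongs to names[i].
labelled = dict(zip(names, first))
print(labelled)
```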

+0

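Thanks @Scratch'N'Purr. Your first remark was the most helpful: it was the pointer I needed to solve my problem, so you deserve significant credit for the solution I am about to post. The drawback of your proposed solution is that it allocates all of the memory at once, so it is not truly an iterator solution. Also, although sklearn grid_search is very useful when exploring sklearn models, it only works with sklearn models, and I want to avoid converting every problem I work on into an sklearn-compatible model just to run a simple grid search. – aph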

+0

No problem @aph! I noticed that you use a yield statement in your generator, so it looks like you have already solved the memory issue :) –

1

Thanks to Scratch'N'Purr for pointing out that the order of the sequence can be determined simply from the .keys() method.

from itertools import product

def param_grid_search_generator(**param_iterators):
    param_names = list(param_iterators.keys())
    param_combination_generator = product(*list(param_iterators.values()))
    for param_combination in param_combination_generator:
        # pair each name with the value at the same position
        yield {param_names[i]: param_combination[i] for i in range(len(param_names))}
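
For completeness, here is a rough sketch of how this generator could be plugged back into the function from the question. The chi^2 computation is not given anywhere in this thread, so toy_chi2 and the chi2_func parameter are hypothetical stand-ins. The sketch also collects results in a list rather than writing them to disk, and copies fiducial_dict instead of updating it in place; both are assumptions made to keep the example self-contained.

```python
from itertools import product


def param_grid_search_generator(**param_iterators):
    """Lazily yield one {name: value} dict per grid point."""
    param_names = list(param_iterators.keys())
    for combination in product(*param_iterators.values()):
        yield dict(zip(param_names, combination))


def toy_chi2(params):
    # Hypothetical stand-in for the real chi^2 computation.
    return (params['x'] - 5) ** 2 + (params['y'] - 2) ** 2 + params['z']


def my_grid_searching_function(fiducial_dict, chi2_func, **param_iterators):
    results = []
    for params in param_grid_search_generator(**param_iterators):
        trial = dict(fiducial_dict)  # copy so the caller's dict is untouched
        trial.update(params)         # overwrite only the searched parameters
        results.append((trial, chi2_func(trial)))
    return results


results = my_grid_searching_function(
    {'x': 0, 'y': 0, 'z': 9}, toy_chi2, x=[4, 5, 6], y=[1, 2]
)
best_params, best_chi2 = min(results, key=lambda item: item[1])
print(best_params, best_chi2)
```

Parameters that are not passed as keyword arguments (z here) keep their fiducial values, which is the "subset of keys" behavior the question asks for.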