当前位置: 代码迷 >> python >> 迭代器,用于遍历字典键子集中的参数范围
  详细解决方案

迭代器,用于遍历字典键子集中的参数范围

热度:116   发布时间:2023-06-13 14:26:32.0

我正在尝试编写一个函数,该函数使我可以对字典中的参数子集灵活地运行网格搜索。 我要完成的特定行为如下:

def my_grid_searching_function(fiducial_dict, **param_iterators):
    for params in desired_iterator:
        fiducial_dict.update(params)
        # compute chi^2
        # write new fiducial_dict values and associated chi^2 value to disk

我的具体目标是弄清楚如何编写desired_iterator

功能my_grid_searching_function接受的关键字参数,其中的每一个将被解释为参数的任意子集fiducial_dict

这似乎是itertools.product的任务,但是我遇到了一个问题。 在以下实现中,我可以使用product将输入迭代器的上的嵌套循环有效地转换为单个循环:

from itertools import product
def my_failed_grid_searching_function(fiducial_dict, **param_iterators):
    desired_iterator = product(*list(param_iterators.values()))
    for params in desired_iterator:
        print(params)
fiducial_dict = {'x': 0, 'y': 0, 'z': 9}
my_failed_grid_searching_function(fiducial_dict, x=[4, 5, 6], y=[1, 2])    

(1, 4)
(1, 5)
(1, 6)
(2, 4)
(2, 5)
(2, 6)

当然,这样做的问题是输入的param_iterators已被放入普通字典中,因此在my_failed_grid_searching_function的命名空间中,我不知道值的顺序是什么。

任何人都可以提供关于我如何写任何提示desired_iterator ,使其产生足够的信息来更新fiducial_dict如上图所示?

由于您使用的是任意关键字参数,因此可以抓住param_iterators词典的键来同步param产品的位置。 另外,我建议使用sklearn包执行 。

无论如何,请尝试以下解决方案:

from itertools import product

def my_grid_searching_function(fiducial_dict, **param_iterators):
    keys = param_iterators.keys()
    desired_iterator = list(product(*list(param_iterators.values())))
    for i in range(len(desired_iterator)):
        print("Epoch: ", i)
        for loc in range(len(desired_iterator[i])):
            print(keys[loc], desired_iterator[i][loc])
        # update your fiducial_dict here

my_grid_searching_function({'x': 0, 'y': 0}, x=[1,2,3,4], y=[6,7,8])

输出:

('Epoch: ', 0)
('y', 6)
('x', 1)
('Epoch: ', 1)
('y', 6)
('x', 2)
('Epoch: ', 2)
('y', 6)
('x', 3)
('Epoch: ', 3)
('y', 6)
('x', 4)
('Epoch: ', 4)
('y', 7)
('x', 1)
('Epoch: ', 5)
('y', 7)
('x', 2)
('Epoch: ', 6)
('y', 7)
('x', 3)
('Epoch: ', 7)
('y', 7)
('x', 4)
('Epoch: ', 8)
('y', 8)
('x', 1)
('Epoch: ', 9)
('y', 8)
('x', 2)
('Epoch: ', 10)
('y', 8)
('x', 3)
('Epoch: ', 11)
('y', 8)
('x', 4)

***Repl Closed***

感谢Scratch'N'Purr指出,可以仅通过.keys()方法确定序列顺序。

from itertools import product
def param_grid_search_generator(**param_iterators):
    param_names = list(param_iterators.keys())
    param_combination_generator = product(*list(param_iterators.values()))
    for param_combination in param_combination_generator:
        yield {param_names[i]: param_combination[i] for i in range(len(param_names))}
  相关解决方案