我非常确定dask可以解决的问题需要帮助。
但是我不知道该如何解决。

我需要递归构造一棵树。

对于每个节点,如果满足条件,则进行计算(compute_val),否则将创建2个新子级。对子代(build)执行相同的处理。
然后,如果节点的所有子节点都执行了计算,则可以继续进行合并(merge)。合并可以执行子项的融合(如果它们都符合条件)或什么都不进行。
目前,我只能并行处理第一个级别,而且我不知道我应该使用哪种轻巧的工具来提高效率。
这是我要实现的简化的MRE顺序:

import numpy as np
import time

class Node:
    def __init__(self, level):
        self.level = level
        self.val = None

def merge(node, childs):
    values = [child.val for child in childs]
    if all(values) and sum(values)<0.1:
        node.val = np.mean(values)
    else:
        node.childs = childs
    return node

def compute_val():
    time.sleep(0.1)
    return np.random.rand(1)

def build(node):
    print(node.level)
    if (np.random.rand(1) < 0.1 and node.level>1) or node.level>5:
        node.val = compute_val()
    else:
        childs = [build(Node(level=node.level+1)) for _ in range(2)]
        node = merge(node, childs)
    return node

tree = build(Node(level=0))

最佳答案

据我了解,解决递归(或任何动态计算)的方法是在任务中创建任务。

我正在尝试类似的方法,因此下面是我的5分钟说明性解决方案。您必须根据算法的特征对其进行优化。

请记住,任务会增加开销,因此您需要对计算进行分块以获得最佳结果。

相关文件:


https://distributed.dask.org/en/latest/task-launch.html


API参考:


https://distributed.dask.org/en/latest/api.html#distributed.worker_client
https://distributed.dask.org/en/latest/api.html#distributed.Client.gather
https://distributed.dask.org/en/latest/api.html#distributed.Client.submit


import numpy as np
import time
from dask.distributed import Client, worker_client

# Create a dask client
# For convenience, I'm creating a localcluster.
client = Client(threads_per_worker=1, n_workers=8)
client

class Node:
    def __init__(self, level):
        self.level = level
        self.val = None
        self.childs = None   # This was missing

def merge(node, childs):
    values = [child.val for child in childs]
    if all(values) and sum(values)<0.1:
        node.val = np.mean(values)
    else:
        node.childs = childs
    return node

def compute_val():
    time.sleep(0.1)            # Is this required.
    return np.random.rand(1)

def build(node):
    print(node.level)
    if (np.random.rand(1) < 0.1 and node.level>1) or node.level>5:
        node.val = compute_val()
    else:
        with worker_client() as client:
            child_futures = [client.submit(build, Node(level=node.level+1)) for _ in range(2)]
            childs = client.gather(child_futures)
        node = merge(node, childs)
    return node

tree_future = client.submit(build, Node(level=0))
tree = tree_future.result()

10-08 12:55