我非常确定dask可以解决的问题需要帮助。
但是我不知道该如何解决。
我需要递归构造一棵树。
对于每个节点,如果满足条件,则进行计算(compute_val
),否则将创建2个新子级。对子代(build
)执行相同的处理。
然后,如果节点的所有子节点都执行了计算,则可以继续进行合并(merge
)。合并可以执行子项的融合(如果它们都符合条件)或什么都不进行。
目前,我只能并行处理第一个级别,而且我不知道我应该使用哪种轻巧的工具来提高效率。
这是我要实现的简化的MRE顺序:
import numpy as np
import time
class Node:
def __init__(self, level):
self.level = level
self.val = None
def merge(node, childs):
values = [child.val for child in childs]
if all(values) and sum(values)<0.1:
node.val = np.mean(values)
else:
node.childs = childs
return node
def compute_val():
time.sleep(0.1)
return np.random.rand(1)
def build(node):
print(node.level)
if (np.random.rand(1) < 0.1 and node.level>1) or node.level>5:
node.val = compute_val()
else:
childs = [build(Node(level=node.level+1)) for _ in range(2)]
node = merge(node, childs)
return node
tree = build(Node(level=0))
最佳答案
据我了解,解决递归(或任何动态计算)的方法是在任务中创建任务。
我正在尝试类似的方法,因此下面是我的5分钟说明性解决方案。您必须根据算法的特征对其进行优化。
请记住,任务会增加开销,因此您需要对计算进行分块以获得最佳结果。
相关文件:
https://distributed.dask.org/en/latest/task-launch.html
API参考:
https://distributed.dask.org/en/latest/api.html#distributed.worker_client
https://distributed.dask.org/en/latest/api.html#distributed.Client.gather
https://distributed.dask.org/en/latest/api.html#distributed.Client.submit
import numpy as np
import time
from dask.distributed import Client, worker_client
# Create a dask client
# For convenience, I'm creating a localcluster.
client = Client(threads_per_worker=1, n_workers=8)
client
class Node:
def __init__(self, level):
self.level = level
self.val = None
self.childs = None # This was missing
def merge(node, childs):
values = [child.val for child in childs]
if all(values) and sum(values)<0.1:
node.val = np.mean(values)
else:
node.childs = childs
return node
def compute_val():
time.sleep(0.1) # Is this required.
return np.random.rand(1)
def build(node):
print(node.level)
if (np.random.rand(1) < 0.1 and node.level>1) or node.level>5:
node.val = compute_val()
else:
with worker_client() as client:
child_futures = [client.submit(build, Node(level=node.level+1)) for _ in range(2)]
childs = client.gather(child_futures)
node = merge(node, childs)
return node
tree_future = client.submit(build, Node(level=0))
tree = tree_future.result()