问题描述
我有这个:
shape = (2, 4) # arbitrary, could be 3 dimensions such as (3, 5, 7), etc...
for i in itertools.product(*(range(x) for x in shape)):
print(i)
# output: (0, 0) (0, 1) (0, 2) (0, 3) (1, 0) (1, 1) (1, 2) (1, 3)
到目前为止,很好,itertools.product
在每次迭代中都推进最右边的元素.但是现在我希望能够根据以下内容指定迭代顺序:
So far, so good, itertools.product
advances the rightmost element on every iteration. But now I want to be able to specify the iteration order according to the following:
axes = (0, 1) # normal order
# output: (0, 0) (0, 1) (0, 2) (0, 3) (1, 0) (1, 1) (1, 2) (1, 3)
axes = (1, 0) # reversed order
# output: (0, 0) (1, 0) (2, 0) (3, 0) (0, 1) (1, 1) (2, 1) (3, 1)
如果shapes
具有三个维度,则axes
可能是(0, 1, 2)
或(2, 0, 1)
等,因此,不必简单地使用reversed()
.所以我写了一些代码,但是看起来效率很低:
If shapes
had three dimensions, axes
could have been for instance (0, 1, 2)
or (2, 0, 1)
etc, so it's not a matter of simply using reversed()
. So I wrote some code that does that but seems very inefficient:
axes = (1, 0)
# transposed axes
tpaxes = [0]*len(axes)
for i in range(len(axes)):
tpaxes[axes[i]] = i
for i in itertools.product(*(range(x) for x in shape)):
# reorder the output of itertools.product
x = (i[y] for y in tpaxes)
print(tuple(x))
关于如何正确执行此操作的任何想法?
Any ideas on how to properly do this?
推荐答案
好吧,这实际上是一本专门的product
手册.由于轴只需要重新排序一次,所以速度应该更快:
Well, this is in fact a manual specialised product
. It should be faster since axes are reordered only once:
def gen_chain(dest, size, idx, parent):
# iterate over the axis once
# then trigger the previous dimension to update
# until everything is exhausted
while True:
if parent: next(parent) # StopIterator is propagated upwards
for i in xrange(size):
dest[idx] = i
yield
if not parent: break
def prod(shape, axes):
buf = [0] * len(shape)
gen = None
# EDIT: fixed the axes order to be compliant with the example in OP
for s, a in zip(shape, axes):
# iterate over the axis and put to transposed
gen = gen_chain(buf, s, a, gen)
for _ in gen:
yield tuple(buf)
print list(prod((2,4), (0,1)))
# [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3)]
print list(prod((2,4), (1,0)))
# [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1), (1, 1), (2, 1), (3, 1)]
print list(prod((4,3,2),(1,2,0)))
# [(0, 0, 0), (1, 0, 0), (0, 0, 1), (1, 0, 1), (0, 0, 2), (1, 0, 2), ...
这篇关于Python itertools.product重新排列世代的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!