对于基于回合的游戏,我想计算每个玩家可以在地图的每个区域上移动或生成的最大单位数。

我需要计算的所有数据已经​​存储在几个numpy数组中,但是我一直在努力寻找先进的数组索引技术来尽可能快地进行计算。

为了解决这个问题,我用一些For循环以最简单的方式重写了该函数:

import numpy as np

def get_max_units_on_zone_per_player(unitCountPerPlayer, zoneOwner, playerAvailableUnits, zoneLinks, blockedMovesPerPlayer):
    """
    Parameters
    ----------
    unitCountPerPlayer: np.array((zoneCount, playerCount), dtype=int)
        How many units each player has on a zone
    zoneOwner: np.array(zoneCount, dtype=int)
        Which player is owning a zone (-1 for none)
    playerAvailableUnits: np.array(playerCount, dtype=int)
        How many units each player can spawn
    zoneLinks: np.array((zoneCount, zoneCount), dtype=int)
        > 0 if zone1 is connected to zone2 (directed and weighted graph)
    blockedMovesPerPlayer: np.array((playerCount, zoneCount, zoneCount), dtype=bool)
        True if player can not move from zone1 to zone2

    Returns
    -------
    np.array((zoneCount, playerCount), dtype=int)
        Maximum count of units each player can have on each zone
    """

    zoneCount, playerCount = unitCountPerPlayer.shape

    # Adding units already on zone
    result = np.zeros((zoneCount, playerCount), dtype=int) + unitCountPerPlayer

    for p in xrange(playerCount):
        for z1 in xrange(zoneCount):

            if zoneOwner[z1] in (-1, p):
                # Player can spawn on neutral or owned zones
                result[z1, p] += playerAvailableUnits[p]

            for z2 in xrange(zoneCount):

                if zoneLinks[z1, z2] > 0 and not blockedMovesPerPlayer[p, z1, z2]:
                    # If z1 and z2 are connected and player can move from z1 to z2, adding units count on z1 to z2
                    result[z2, p] += unitCountPerPlayer[z1, p]
    return result


问题是我无法使用此函数,每次调用大约需要30毫秒,而且我确定它可以重写,因为某些numpy操作需要不到5毫秒的时间来处理。

有人可以帮我吗?还有一个逐步的过程,以便下次我可以自己做吗?我已经阅读了numpy关于数组和索引的文档多次,但是它还不是很清晰,我只是想不通。

编辑:根据要求,以下是一些随机数据可以用作示例:

zoneCount=8 ; playerCount=2

unitCountPerPlayer:
[[1 2]
 [1 3]
 [1 3]
 [3 2]
 [1 2]
 [3 2]
 [0 2]
 [3 2]]

zoneOwner:
[ 1  0 -1 -1 -1  0 -1 -1]

playerAvailableUnits:
[2 2]

zoneLinks:
[[0 1 1 1 0 1 0 0]
 [1 0 0 1 0 0 0 1]
 [1 1 1 1 0 1 0 1]
 [0 1 1 1 1 0 1 0]
 [0 0 1 1 1 0 1 1]
 [0 0 1 1 1 1 1 1]
 [1 0 0 0 0 1 0 1]
 [1 1 1 1 0 1 1 1]]

blockedMovesPerPlayer:
[[[False False False False False False False False]
  [ True False False False False False False False]
  [ True False False False False False False False]
  [False False False False False False False False]
  [False False False False False False False False]
  [False False False False False False False False]
  [ True False False False False False False False]
  [ True False False False False False False False]]

 [[False  True False False False  True False False]
  [False False False False False False False False]
  [False  True False False False  True False False]
  [False  True False False False False False False]
  [False False False False False False False False]
  [False False False False False  True False False]
  [False False False False False  True False False]
  [False  True False False False  True False False]]]

get_max_units_on_zone_per_player():
[[ 1 14]
 [11  3]
 [15 18]
 [18 20]
 [10 10]
 [13  2]
 [12 12]
 [14 18]]




复制/粘贴数据:

zoneCount = 8
playerCount = 2

unitCountPerPlayer = np.array([[1,2], [1,3], [1,3], [3,2],
                               [1,2], [3,2], [0,2], [3,2]])

zoneOwner = np.array([1, 0, -1, -1, -1, 0, -1, -1])

playerAvailableUnits = np.array([2,2])

zoneLinks = np.array([[0,1,1,1,0,1,0,0], [1,0,0,1,0,0,0,1],
                      [1,1,1,1,0,1,0,1], [0,1,1,1,1,0,1,0],
                      [0,0,1,1,1,0,1,1], [0,0,1,1,1,1,1,1],
                      [1,0,0,0,0,1,0,1], [1,1,1,1,0,1,1,1]])

bmpp = [[[False, False, False, False, False, False, False, False],
         [ True, False, False, False, False, False, False, False],
         [ True, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False],
         [ True, False, False, False, False, False, False, False],
         [ True, False, False, False, False, False, False, False]],
        [[False,  True, False, False, False,  True, False, False],
         [False, False, False, False, False, False, False, False],
         [False,  True, False, False, False,  True, False, False],
         [False,  True, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False],
         [False, False, False, False, False,  True, False, False],
         [False, False, False, False, False,  True, False, False],
         [False,  True, False, False, False,  True, False, False]]]
blockedMovesPerPlayer = np.array(bmpp)

最佳答案

[更新:numpy方法的实现,避免了for循环]

这是我的get_max_units_on_zone_per_player()新实现:

def get_max_units_on_zone_per_player(unitCountPerPlayer, zoneOwner, playerAvailableUnits, zoneLinks, blockedMovesPerPlayer):
    result = unitCountPerPlayer.copy()
    result[zoneOwner < 0] += playerAvailableUnits
    _z1 = np.where(zoneOwner >= 0)
    result[_z1, zoneOwner[_z1]] += playerAvailableUnits[zoneOwner[_z1]]
    _p, _z1, _z2 = np.where(np.logical_and(zoneLinks > 0, np.logical_not(blockedMovesPerPlayer)))
    np.add.at(result, [_z2, _p], unitCountPerPlayer[_z1, _p])
    return result


我使用以下设置测试了这两种实现:

zoneCount = 100
playerCount = 1000
maxUnits = 500

unitCountPerPlayer = np.random.randint(0, maxUnits, size=(zoneCount, playerCount))
zoneOwner = np.random.randint(-1, playerCount, size=zoneCount)
playerAvailableUnits = np.random.randint(0, maxUnits, size=playerCount)
zoneLinks = np.random.randint(0, maxUnits, size=(zoneCount, zoneCount))
blockedMovesPerPlayer = np.random.randint(0, 2, size=(playerCount, zoneCount, zoneCount), dtype=bool)


这是测试结果(使用%timeit


fbparis的原始实现:

每个循环7.27 s±10毫秒(平均±标准偏差,共7次运行,每个循环1次)
我的新实现:

每个循环645 ms±490 µs(平均±标准偏差,共运行7次,每个循环1个)

关于python - 如何用单个numpy数组操作替换此三重For循环?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/48472961/

10-12 12:50
查看更多