实现两种启发函数
采取两种策略实现启发函数:
- 策略1:不在目标位置的数字个数(空格 0 不计入)
- 策略2:曼哈顿距离之和(每个数字直接移动到其目标位置所需步数的总和,空格 0 不计入)
# Strategy 1: number of misplaced tiles, i.e. how many cells of state hold a
# different number than the same cell of goal_state (the blank, 0, is skipped).
def h1(state, goal_state):
    """Return the number of non-blank tiles that are not on their goal cell.

    Args:
        state: Current board as a list of rows (e.g. a 3x3 list); 0 marks
            the blank.
        goal_state: Target board laid out the same way as ``state``.

    Returns:
        The count of cells whose tile is non-zero and differs from the tile
        at the same position in ``goal_state``.
    """
    # zip() walks both boards cell by cell, so any board size works instead
    # of only the first 3x3 cells.
    return sum(
        1
        for row, goal_row in zip(state, goal_state)
        for tile, goal_tile in zip(row, goal_row)
        if tile != 0 and tile != goal_tile
    )
# Utility: locate the coordinates of the number num inside a board.
def find_num(num, goal_state):
    """Return the ``(row, column)`` of the first cell equal to ``num``.

    Args:
        num: The tile value to look for (0 for the blank).
        goal_state: Board to search, given as a list of rows. Despite the
            parameter name, any board may be passed.

    Returns:
        A ``(row, column)`` tuple for the first match in row-major order,
        or ``(-1, -1)`` when ``num`` does not occur in the board.
    """
    for i, row in enumerate(goal_state):
        for j, value in enumerate(row):
            if value == num:
                return i, j
    # Sentinel kept for backward compatibility with existing callers.
    return -1, -1
# Strategy 2: sum of Manhattan distances.
def h2(state, goal_state):
    """Return the summed Manhattan distance of every misplaced tile.

    Args:
        state: Current board as a list of rows (e.g. a 3x3 list); 0 marks
            the blank.
        goal_state: Target board laid out the same way as ``state``.

    Returns:
        The sum over all non-blank, misplaced tiles of
        ``abs(row - goal_row) + abs(col - goal_col)``.
    """
    # Index the goal once so each tile lookup is O(1) instead of a rescan.
    # setdefault keeps the first occurrence in row-major order.
    goal_pos = {}
    for gi, goal_row in enumerate(goal_state):
        for gj, goal_tile in enumerate(goal_row):
            goal_pos.setdefault(goal_tile, (gi, gj))

    distance = 0
    for i, (row, goal_row) in enumerate(zip(state, goal_state)):
        for j, (tile, goal_tile) in enumerate(zip(row, goal_row)):
            # The blank does not count, and a tile already in place costs 0.
            if tile == 0 or tile == goal_tile:
                continue
            # A tile absent from the goal falls back to (-1, -1), the same
            # sentinel a failed coordinate lookup has always produced here.
            goal_i, goal_j = goal_pos.get(tile, (-1, -1))
            distance += abs(i - goal_i) + abs(j - goal_j)
    return distance
# Quick check of both heuristics on a sample puzzle.
start_state = [
    [2, 8, 3],
    [1, 6, 4],
    [7, 0, 5],
]
goal_state = [
    [1, 2, 3],
    [8, 0, 4],
    [7, 6, 5],
]
# Tiles that are not on their goal cell: 1, 2, 8 and 6 -> 4 in total.
# Moves each of them needs to reach its goal cell:
#   1 -> 1 move
#   2 -> 1 move
#   8 -> 2 moves
#   6 -> 1 move
# so the Manhattan distance sums to 5.
print(h1(start_state, goal_state))  # 4
print(h2(start_state, goal_state))  # 5
实现A*算法
为了便于替换启发函数,将其作为参数传入函数:
# A* search for the sliding-tile puzzle.
def _freeze(state):
    """Convert a board (sequence of rows) into a hashable tuple of tuples."""
    return tuple(tuple(row) for row in state)


def _thaw(state):
    """Convert a board into a fresh list of row lists."""
    return [list(row) for row in state]


def _neighbors(state):
    """Yield every board reachable from ``state`` by one move of the blank (0)."""
    blank = next(
        (
            (i, j)
            for i, row in enumerate(state)
            for j, tile in enumerate(row)
            if tile == 0
        ),
        None,
    )
    if blank is None:
        # A board without a blank has no legal moves.
        return
    zero_i, zero_j = blank
    for di, dj in ((0, 1), (0, -1), (1, 0), (-1, 0)):
        new_i, new_j = zero_i + di, zero_j + dj
        if 0 <= new_i < len(state) and 0 <= new_j < len(state[new_i]):
            grid = _thaw(state)
            # Slide the adjacent tile into the blank's cell.
            grid[zero_i][zero_j], grid[new_i][new_j] = grid[new_i][new_j], 0
            yield _freeze(grid)


def astar(start_state, goal_state, h):
    """Search for a move sequence from ``start_state`` to ``goal_state``.

    Nodes are ordered by ``f = g + h`` where ``g`` is the number of moves
    made from the start and ``h`` is the supplied heuristic. Expanded states
    are never reopened, so the returned path is shortest when ``h`` is
    consistent.

    Args:
        start_state: Initial board as a list of rows (e.g. a 3x3 list);
            0 marks the blank.
        goal_state: Target board in the same format.
        h: Heuristic called as ``h(state, goal_state)`` with ``state`` given
            as a list of row lists; must return a number.

    Returns:
        A tuple ``(expanded_nodes, run_time, path)``:
        expanded_nodes - number of distinct states taken off the open list.
        run_time - elapsed seconds as a float.
        path - list of boards (each a list of row lists) from start to goal.
        When no path exists, ``run_time`` is 0 and ``path`` is None.
    """
    # Imported locally because the file has no top-level import of these.
    import heapq
    import itertools
    import time

    start_time = time.perf_counter()  # search begins
    start = _freeze(start_state)
    goal = _freeze(goal_state)

    # The counter breaks ties between equal f-values in insertion order, so
    # the heap never has to compare two boards with each other.
    tie = itertools.count()
    open_list = [(h(start_state, goal_state), next(tie), start)]
    g_score = {start: 0}  # cheapest known number of moves to each state
    came_from = {}        # parent of each state on its cheapest known path
    closed_set = set()    # states that have already been expanded
    expanded_nodes = 0

    while open_list:
        _, _, current = heapq.heappop(open_list)
        if current in closed_set:
            # Stale entry: this state was already expanded via a cheaper push.
            continue
        expanded_nodes += 1

        if current == goal:
            # Walk the parent links back to the start.
            path = [_thaw(current)]
            while current in came_from:
                current = came_from[current]
                path.append(_thaw(current))
            path.reverse()
            return expanded_nodes, time.perf_counter() - start_time, path

        closed_set.add(current)
        next_g = g_score[current] + 1
        for neighbor in _neighbors(current):
            if neighbor in closed_set:
                continue
            if neighbor in g_score and g_score[neighbor] <= next_g:
                # An equally good or better route to this state is known.
                continue
            g_score[neighbor] = next_g
            came_from[neighbor] = current
            new_cost = next_g + h(_thaw(neighbor), goal_state)
            heapq.heappush(open_list, (new_cost, next(tie), neighbor))

    # The open list ran dry without reaching the goal: no solution.
    return expanded_nodes, 0, None
测试
首先,定义一个函数 print_path(),用于查看路径:
def print_path(path):
    """Print every state of ``path``, each preceded by its step index.

    Args:
        path: Iterable of board states; each state is an iterable of rows.
            Steps are numbered from 0 in iteration order.
    """
    for step, state in enumerate(path):
        print("Step. ", step)
        for row in state:
            print(row)
设置初始状态和目标状态进行测试:
# Solvable instance: set up the start and goal boards.
start_state = [
    [2, 8, 3],
    [1, 6, 4],
    [7, 0, 5],
]
goal_state = [
    [1, 2, 3],
    [8, 0, 4],
    [7, 6, 5],
]

# Run A* once with each heuristic.
h1_nodes, h1_times, h1_path = astar(start_state, goal_state, h1)
h2_nodes, h2_times, h2_path = astar(start_state, goal_state, h2)

# Report for heuristic h1.
if h1_path:
    print("调用 h1 启发函数的 A* 算法共扩展 {} 个节点,耗时 {}s,路径如下:".format(h1_nodes, h1_times))
    # print_path(h1_path)
else:
    print("调用 h1 启发函数的 A* 算法无法得到可行解。")

# print("=" * 50)

# Report for heuristic h2.
if h2_path:
    print("调用 h2 启发函数的 A* 算法共扩展 {} 个节点,耗时 {}s,路径如下:".format(h2_nodes, h2_times))
    # print_path(h2_path)
else:
    print("调用 h2 启发函数的 A* 算法无法得到可行解。")
输出结果(path 输出过长,这里省略):
调用 h1 启发函数的 A* 算法共扩展 28 个节点,耗时 0.00037217140197753906s,路径如下:
调用 h2 启发函数的 A* 算法共扩展 17 个节点,耗时 0.0002200603485107422s,路径如下:
测试鲁棒性——当可行解不存在时:
# Instance with no solution: set up the start and goal boards.
start_state = [
    [7, 8, 3],
    [1, 5, 2],
    [6, 0, 4],
]
# NOTE(review): this goal holds a 9 and no blank (0), so no board derived
# from start_state can ever equal it -- confirm whether 0 was intended.
goal_state = [
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
]

# Run A* once with each heuristic.
h1_nodes, h1_times, h1_path = astar(start_state, goal_state, h1)
h2_nodes, h2_times, h2_path = astar(start_state, goal_state, h2)

# Report for heuristic h1.
if h1_path:
    print("调用 h1 启发函数的 A* 算法共扩展 {} 个节点,耗时 {}s,路径如下:".format(h1_nodes, h1_times))
    # print_path(h1_path)
else:
    print("调用 h1 启发函数的 A* 算法无法得到可行解。")

# print("=" * 50)

# Report for heuristic h2.
if h2_path:
    print("调用 h2 启发函数的 A* 算法共扩展 {} 个节点,耗时 {}s,路径如下:".format(h2_nodes, h2_times))
    # print_path(h2_path)
else:
    print("调用 h2 启发函数的 A* 算法无法得到可行解。")
输出结果(path 输出过长,这里省略):
调用 h1 启发函数的 A* 算法无法得到可行解。
调用 h2 启发函数的 A* 算法无法得到可行解。