*回溯

写在前面

从这一节开始,我们就要讨论一些更复杂的算法问题了。对于之前的内容,如果还有不是特别明白的小伙伴呢,建议你们先理解前面所学的内容,再来看这一部分。因为这一整章的内容都是打了星号的,属于算法的较为高阶的部分,所以你也可以选择直接跳过这一部分内容。

引入

我们知道,组合问题的求解是最耗时也是最难解决的问题之一,如果用暴力求解的方式来解决它,我们需要列举出这个问题的所有情况,也就是穷举出所有的子集,然后从子集中选择满足条件的一种来作为最后的解。例如,如果从集合 {1, 2, 5, 6} 中选出元素之和为 d = 8 的子集,我们就需要先把所有子集列举出来,求出每一个子集的元素之和,然后找出总和为 8 的组合来作为最后的解。

问题

下面来看看其中的一种情况,当我们的子集为 {5, 6} 时,元素之和为 11,已经不满足问题中要求的 s = 8。这时候,如果再判断含有 {5, 6} 的子集 {5, 6, 1} 之和是否等于 8 就会显得有些多此一举,因为 {5, 6} 之和已经都大于 8 了,此时如果再加入一个非负的数肯定也是大于 8 的。我们可以看到,有一些情况是多余的,无需再求和判断。因此,我们需要设计出一种算法来帮助我们及时发现这些多余的情况。

回溯

回溯backtracking)是一种类似枚举的搜索尝试过程,在搜索尝试的过程中寻找问题的解,如果在搜索的过程中发现不满足条件,到了“死胡同”(dead end),我们就回溯,回归到上一个状态,然后继续搜索下一个可能的解,如果满足要求就返回问题的解。

继续来看前面的问题,假设我们搜索到了状态 {5, 6},由于 5 + 6 > 8 不满足条件,所以我们要回溯到上一个状态 {5},而不是继续搜索下一个状态 {5, 6, 1},回到状态 {5} 之后,我们到下一个状态 {5, 1},发现小于 8,然后我们继续搜索下一个状态 {5, 1, 2},其元素之和等于 8,于是返回问题的解。

回溯的过程就有点像走迷宫,我们从起点出发,向上下左右四个方向探索路径,一旦遇到了死胡同,我们就退回,然后继续探索其他的路径,直到我们最后走出迷宫。

迷宫

八皇后问题

了解了回溯的基本概念之后,我们来看看一个例子吧。八皇后问题就是回溯算法的一道经典的例子,下过国际象棋的同学一定知道,皇后在棋盘中可以横着、竖着、斜着三个方向走,下面的图就展示了其中的一种解法。

棋盘

八皇后问题的描述是这样子的:我们需要在 8 × 8 的棋盘中放置 8 个皇后,使得所有皇后之间不能相互攻击,也就是每一横行、竖行、斜行只能放一个皇后。就像下面那样。

棋盘

要解决这个问题,我们可以先将八皇后问题变为一个一般的 n 皇后问题,当 n = 1 时,即棋盘大小是 1 × 1 的,皇后数为 1,这时只需要把皇后直接放到 (1, 1) (第一行第一列) 即可。

棋盘

接下来,我们发现当 n = 2 和 3 时问题是无解的。

棋盘

当 n = 4 时,我们采用回溯的方式来求解这个问题。首先,我们在将皇后 1 放置在棋盘 (1, 1) 的位置,然后将皇后 2 放置在 (2, 1) 的位置,发现两个皇后处于同一列,不满足规则,所以回溯到只有皇后 1 的状态那里,然后再将皇后 2 尝试放到 (2, 2) 处,发现还是不满足规则,于是再回溯,然后再探索棋盘的下一个位置 (2, 3),放完之后发现无论皇后 3 放置在第三行哪一列都不满足条件,因此我们再要回溯然后把皇后 2 放置在 (2, 4) 处。再通过同样的方式,我们将皇后 3 放置在棋盘的 (3, 2) 处,结果发现我们不能在第四行放最后一个皇后,然后我们回溯到状态 (3, 2),当皇后 3 尝试完剩下的 (3, 3) 和 (3, 4) 发现也不满足要求后,我们回溯到最初的状态(棋盘为空),然后将皇后 1 放在第二列 (1, 2) 处,然后我们再按照前面的方法探索,不满足题目要求就回溯,否则就继续探索,最后得到问题的解。

我们可以用一个树形结构来表示这一过程,树的每一个节点代表一个状态,每到一个节点,我们就判断一下是否满足要求,如满足就在树的下一层新建节点,如不满足就返回到父节点。我们把这种树称作状态空间树(state-space tree)。

状态树

最后就是用代码实现这个过程了,我们用一个table来存储每一行皇后放置的列,回溯的过程用递归来实现,如果is_promisingTrue,我们就向下递归,同时row加一。当row = n时,我们就认为已经找到了问题的解,最后我们用display函数打印出棋盘以及皇后的位置。

def is_promising(table, row):
    # 是否处于同一列
    if len(set(table[: row+1])) != len(table[: row+1]):
        return False
    # 是否处于同一对角线
    for i in range(row):
        if abs(table[row] - table[i]) == row - i:
            return False
    return True
def display_board(table):
    n = len(table)
    for col in table:
        print("+ " * (col) + "Q " + "+ " * (n-1-col))
def display(table, cnt):
    n = len(table)
    if cnt:
        print("Solution %d:" % cnt)
        display_board(table)
    else:
        print("Solution: ")
        display_board(table)
def backtracking(table, n, is_mul, row=0):
    if row == n:  # 递归出口
        if is_mul:
            global cnt  # 计数器
            display(table, cnt+1); cnt += 1
        else:
            display(table, None)
            exit()  # 终止程序
        return
    for col in range(n):
        table[row] = col
        if is_promising(table, row):
            backtracking(table, n, is_mul, row+1)  # 搜索下一层
def queen(n, multiple=False):
    col_tab = [-1] * n
    backtracking(col_tab, n, is_mul=multiple)

最后我们将 n 设为 8 来看看最后的结果:

>>> python queen.py
>>> Q + + + + + + +
    + + + + Q + + +
    + + + + + + + Q
    + + + + + Q + +
    + + Q + + + + +
    + + + + + + Q +
    + Q + + + + + +
    + + + Q + + + +

本节全部代码