A Simple Gomoku AI Opponent: Q-Learning plus Supervised Learning

Implemented entirely in Python; the code runs as-is.

1. Introduction

We train a Gomoku-playing AI by combining reinforcement learning with supervised learning.

2. Setup

The project has very few dependencies; install them up front with pip, with no further environment configuration needed:

pip install torch numpy
pip install tqdm
Core dependencies: PyTorch (model training and inference) and NumPy (data handling and board-state maintenance); tqdm provides the training progress bar.

3. Project Architecture

The project is split across three core files with low coupling and high cohesion; each file has a clear responsibility, which keeps later extension and modification easy:

  1. gomoku_data.py: generates high-quality self-play samples, the "learning material" for the CNN;
  2. train_model.py: defines and trains the CNN, producing a model file usable for play;
  3. main.py: the interactive game UI plus the AI's core decision logic, for real-time play against a human.

The execution flow is: generate samples (gomoku_data.py) → train the model (train_model.py) → start playing (main.py).

4. Core Approach

Step 1: Generate high-quality self-play samples

Core logic

  1. Build the basic utilities: board state management, legal-move validation, win detection, and run detection (3/4/5 in a row);
  2. Design a priority-based move policy: win immediately > block the opponent's win > make our own four > block the opponent's four > make our own three > play adjacent to existing stones > play randomly, so the generated samples carry explicit attack-and-defense logic;
  3. Generate games in bulk: run 2,000 self-play games, record each move as "board state + move position", and save everything in npz format (which PyTorch loads efficiently).

Highlights

  • The samples are explicitly strategic, avoiding meaningless random moves and improving training efficiency;
  • Invalid games are filtered out automatically, keeping the generated samples usable;
  • The output file gomoku_train_data.npz contains a "board states" array and a "move indices" array, directly compatible with the training step (a quick inspection is sketched below).
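
As a quick sanity check, you can load the file and confirm the array shapes before training. A minimal sketch, assuming gomoku_data.py has already been run in the same directory:

import numpy as np

# Inspect the generated self-play samples.
data = np.load("gomoku_train_data.npz")
states, actions = data["states"], data["actions"]
print(states.shape)   # expected: (N, 15, 15, 3) -- channels-last board states
print(actions.shape)  # expected: (N,) -- flattened move indices in [0, 224]
assert 0 <= actions.min() and actions.max() < 15 * 15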

Step 2: Train a CNN move-prediction model

With the samples in hand, we train a model to learn from them where to place the next stone in a given board state.

Core logic

  1. Define a lightweight CNN: the input is a 3-channel 15x15 board (channels for "player stones, AI stones, empty cells"), the output is a score for each of the 225 move positions (15x15 = 225);
    • two convolutional layers extract spatial features from the board, with batch normalization for training stability;
    • two fully connected layers map the convolutional features to per-position move scores;
    • the architecture is simple, beginner-friendly, and fast to train, with no high-end GPU required (a quick shape check is sketched after this list).
  2. Load the self-play samples: a custom GomokuDataset class parses the npz file and converts it into tensors PyTorch can consume;
  3. Training configuration: CrossEntropyLoss (move prediction is effectively a 225-way classification task), the Adam optimizer (fast convergence), and a learning-rate scheduler (to prevent oscillation late in training);
  4. Save the trained model: after training, the weights are saved to gomoku_model.pth for the game UI to load.
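
Before launching a full training run, a one-minute shape check catches structure mismatches early. A minimal sketch, assuming train_model.py (Appendix 2) is importable from the current directory:

import torch
from train_model import GomokuModel

# Push one dummy 3-channel 15x15 board through the untrained model.
model = GomokuModel()
model.eval()  # disable Dropout for a deterministic forward pass
with torch.no_grad():
    scores = model(torch.randn(1, 3, 15, 15))
print(scores.shape)  # expected: torch.Size([1, 225]) -- one score per cell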

Highlights

  • The model definition is identical to the one in the game UI, so loading never fails with a "structure mismatch" error;
  • Training shows a progress bar with live loss values, making it easy to monitor;
  • CPU training is supported, so an ordinary computer is enough.

Step 3: Build the game UI and the AI's core decision logic

We build the visual interface with Tkinter and implement the AI's decision logic, which combines model prediction with Q-Learning optimization.

Part 1: The game UI

  1. Board drawing: a standard 15x15 Gomoku board with horizontal lines, vertical lines, and star points (at positions 3, 7, and 11), matching a real board;
  2. Stone drawing: black stones for the player and white for the AI; the player moves via mouse clicks, with automatic legality checks (out of bounds, occupied cells);
  3. Game control: "New Game" and "Reset Board" buttons that reset the game state;
  4. Win detection: after each move, the horizontal, vertical, and both diagonal directions are checked for five in a row, and a dialog announces the result.

Part 2: The AI's core decision logic

The AI's move priorities strictly follow "survive first, then win, then play normally", which rules out elementary blunders. The priorities, in order:

  1. Block the opponent's four (survive): whenever the player is detected to have four in a row, immediately block it so they cannot win on their next move;
  2. Complete our own four (win): whenever the AI is detected to have four in a row, immediately play the fifth stone and win outright;
  3. Block the opponent's three (defuse threats): detect the player's threes and block them before they become fours, blunting the attack;
  4. Extend our own three (create threats): detect the AI's own threes and extend them into fours, putting the player under defensive pressure;
  5. Model + Q-Learning prediction (routine play): with no obvious threats either way, load the trained CNN for move scores and refine the choice online with Q-Learning, for more flexible decisions;
  6. Random move (last-resort fallback): with no model file available, degrade automatically to random moves so the game still runs.

Highlights

  1. The check_opponent_threat (opponent threat detection) and check_ai_advantage (own advantage detection) methods identify every three and four by rule, with 100% reliability, so the uncertainty inherent in model training can never cause a blunder;
  2. Online Q-Learning: a Q table records the value of each (board state, move) pair and keeps being updated during live play, improving the AI's adaptability (the update rule is given below);
  3. Robust fallback: without a gomoku_model.pth file, the AI degrades to the hard-coded strategy and still attacks and defends, so the play experience is preserved.
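
For reference, the online update behind highlight 2 is the standard Q-Learning rule, where α is the learning rate, γ the discount factor, r the reward, and s′ the board state after the move:

Q(s, a) ← Q(s, a) + α × [ r + γ × max Q(s′, a′) − Q(s, a) ]

This is exactly the formula that update_q_value in main.py implements.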

5. Verifying the Results

  1. Run gomoku_data.py to produce the sample file gomoku_train_data.npz;
  2. Run train_model.py; when training finishes it produces the model file gomoku_model.pth;
  3. Run main.py and click "New Game" to start playing against the AI (the exact commands follow this list):
    • when the player forms a three or four, the AI blocks it immediately and never misses;
    • when the AI forms a three or four, it extends toward five first and never wastes the chance;
    • with no obvious threats, the AI plays from the CNN model with flexible routine strategy.
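
In shell form, the whole pipeline is three commands, run from the project directory:

python gomoku_data.py
python train_model.py
python main.py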

6. Debugging Journey

During development, the AI's decision logic went through three core iterations, evolving from "strategyless random moves" to "competent at both attack and defense". Each fix targeted a concrete problem that showed up in actual play. The full journey:

1. Initial problem: the AI plays pure random moves and reinforcement learning does nothing

Symptoms

Early in the project, only the Q-Learning framework and a basic visual UI existed, with no strategy logic at all. Running the game showed the AI moving completely at random: however the player moved, and whatever runs were forming, the AI neither defended nor tried to connect its own stones, and it would even play meaningless moves at the edge of the board. The experience was terrible. Watching the Q-table update logs, the Q values just fluctuated randomly and never converged as games accumulated; reinforcement learning was contributing nothing.

Diagnosis and root cause

Stepping through the code line by line and printing key variables eventually pinned down two core errors:

  • Error 1: the Q-Learning state hash was broken. get_state_hash mistakenly used board.tostring() (deprecated and since removed in newer NumPy versions), causing different board states to produce the same hash value; the Q table could not record the "board state → move" correspondence correctly, and reinforcement learning accumulated no usable experience;
  • Error 2: there was no baseline move policy. The AI relied solely on Q-Learning's ε-greedy rule, but the Q table started empty: the exploration rate (epsilon=0.1) was low, yet there were no useful Q values to exploit, so play was effectively still random. On top of that, no trained CNN model was loaded, so there was no baseline decision signal and reinforcement learning had no "direction" to learn in.

Fix

  1. Repair the hash generation: replace board.tostring() with board.tobytes() to match current NumPy, so every distinct board state produces a unique hash and the Q table can be updated and queried correctly;
def get_state_hash(self, board):
    """Convert a board state to a hash value (used as the Q-table key)"""
    # Before the fix: the deprecated tostring() method
    # return hash(board.tostring())  # ❌ legacy NumPy API

    # After the fix: tobytes()
    return hash(board.tobytes())  # ✅ current NumPy API
  2. Add a baseline move rule: in ai_choose_action, first restrict candidates to legal cells adjacent to existing stones and prefer those (avoiding meaningless edge moves), giving reinforcement learning a basic direction;

  3. Tune the Q-Learning configuration: set the learning rate (lr=0.1) and discount factor (gamma=0.9) and lengthen the Q-value update cycle, so the Q table converges steadily as games accumulate and the agent builds up genuinely useful move experience (a worked update is shown below).
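
To see the scale of a single update under these settings: suppose Q(s, a) = 0, the move earns an illustrative reward r = 10, and the best Q value in the next state is 2. Then

new Q = 0 + 0.1 × (10 + 0.9 × 2 − 0) = 0.1 × 11.8 = 1.18

Each update moves the estimate only a tenth of the way toward its target, which is what keeps the Q table stable across many games.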

Result

The AI no longer moves completely at random: it prefers cells near existing stones, so its moves look reasonable. The Q table updates correctly, and as games accumulate the AI gradually stops playing pointless positions; reinforcement learning is finally doing something. But the AI still cannot read the position, connect its own stones, or defend: only the basic "random flailing" problem is solved.

2. Next problem: the AI plays near existing stones but cannot read the position

Symptoms

With random play fixed, the AI preferred cells near existing stones, but an obvious flaw remained: it could not read the position. It neither built attacking threes and fours of its own nor recognized and blocked the player's runs. If the player strung together a three, the AI would still play somewhere irrelevant; and when the AI itself had the chance to form a three, it would let it slip. It stayed stuck in "passive following", and the game posed no challenge at all.

Diagnosis and root cause

The root cause was the complete absence of position-evaluation logic: the AI chose moves purely by adjacency and could recognize neither "run threats" nor "run opportunities". Concretely:

  • No run detection: methods such as count_continuous_pieces (counting consecutive stones) and check_opponent_threat (detecting the opponent's threats) had not been written, so neither the player's nor the AI's threes and fours could be recognized;
  • Confused move priorities: only "prefer adjacent cells" was in place, with no "defense > attack > routine" ordering, so the AI could not judge which cell mattered more;
  • Reinforcement learning was disconnected from position reading: Q-Learning can only optimize from historical move experience; it has no built-in notion that "a run = an advantage" or "a missed block = a loss", so experience accumulated extremely inefficiently.

Fix

  1. Add the core position-evaluation methods: implement count_continuous_pieces on the BoardManager class to count consecutive stones along one direction, and add check_opponent_threat to detect the player's threes and fours specifically, as the foundation for the defensive strategy;
def count_continuous_pieces(self, x, y, player, dx, dy):
    """Count consecutive stones along one direction (including the start cell)"""
    count = 0
    nx, ny = x, y
    while 0 <= nx < self.size and 0 <= ny < self.size and self.board[nx, ny] == player:
        count += 1
        nx += dx
        ny += dy
    return count

def check_opponent_threat(self):
    """Detect the opponent's (player's) 3/4-in-a-row threats; return the highest-priority blocking cell"""
    opponent = 1
    four_threats = []   # opponent fours (must block)
    three_threats = []  # opponent threes (block next)
    directions = [(0, 1), (1, 0), (1, 1), (1, -1)]

    for x in range(self.size):
        for y in range(self.size):
            if self.board[x, y] == 0:
                for dx, dy in directions:
                    left_count = self.count_continuous_pieces(x - dx, y - dy, opponent, -dx, -dy)
                    right_count = self.count_continuous_pieces(x + dx, y + dy, opponent, dx, dy)
                    total = left_count + right_count

                    if total == 4:
                        if (x, y) not in four_threats:
                            four_threats.append((x, y))
                    elif total == 3:
                        left_block = not (0 <= x - (left_count + 1)*dx < self.size and 0 <= y - (left_count + 1)*dy < self.size)
                        right_block = not (0 <= x + (right_count + 1)*dx < self.size and 0 <= y + (right_count + 1)*dy < self.size)
                        if not (left_block and right_block) and (x, y) not in three_threats:
                            three_threats.append((x, y))

    if four_threats:
        return four_threats[0]
    elif three_threats:
        return three_threats[0]
    return None
  2. Extend the move priorities: in ai_choose_action, add the defensive ordering "block the opponent's four → block the opponent's three", handling the player's lethal threats before any routine move;

  3. Couple reinforcement learning to position reading: feed the evaluation results into the Q-Learning reward signal. If the AI blocks the player's four, it earns a positive reward (reward=10); if it misses the block and loses, it earns a negative reward (reward=-20), steering the accumulated experience toward "defense first" (see the sketch after this list).
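
A minimal sketch of that reward hookup, assuming main.py from Appendix 3 sits in the same directory (importing it defines the classes but opens no window); the hook point and the blocked_opponent_four flag are illustrative assumptions, not verbatim project code:

import numpy as np
from main import QLearningAgent  # the agent class from Appendix 3

agent = QLearningAgent(size=15)

board_before = np.zeros((15, 15), dtype=np.int8)
board_after = board_before.copy()
board_after[7, 7] = 2                # the AI's move (2 = AI stone)
action_idx = 7 * 15 + 7              # flattened move index

blocked_opponent_four = True         # assumption: this move blocked a player four
reward = 10 if blocked_opponent_four else -20

s = agent.get_state_hash(board_before)
s_next = agent.get_state_hash(board_after)
agent.update_q_value(s, action_idx, reward, s_next)
print(agent.get_q_value(s, action_idx))  # 0 + 0.1 * (10 + 0.9 * 0 - 0) = 1.0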

Result

The AI gained basic position reading: it recognizes the player's threes and fours and blocks them first, and its moves are no longer confined to "adjacent cells" but actively target the spots that blunt the player's attack, making the game noticeably more challenging. One key flaw remained: the AI only defends and never plays for the win. Even with its own four on the board and victory one stone away, it would choose a defensive or routine move and waste the winning chance.

3. Final problem: the AI blocks your moves but never plays for the win

Symptoms

After the first two rounds of fixes, the AI's defense was essentially solid: it precisely blocked the player's threes and fours and no longer committed the blunder of missing a lethal threat. But a new problem appeared: the AI only defended passively and never attacked or pursued victory. For example, with its own four already on the board and one stone needed to make five and win, it would still block an irrelevant three of the player's, or play somewhere meaningless; and absent any player threat, it never took the initiative to build threes and fours. Stuck in "permanent defense", it could not close out a game on its own.

Diagnosis and root cause

The root cause was the missing detection of the AI's own attacking chances: the move priorities leaned entirely on defense and ignored attack. Concretely:

  • No detection of the AI's own runs: only check_opponent_threat (player threats) existed; with no method for detecting the AI's own threes and fours, the AI could not know its own run state and naturally never extended a four into five;
  • Unbalanced priorities: only defensive priorities were defined; "make five" and "make four" for the AI itself were never ranked highly, so even with a winning move available the AI chose defense and passed up the attack;
  • A defense-biased reward scheme: the earlier rewards applied only to defensive behavior, with nothing for building runs or winning, so the AI had no "incentive" to attack.

Fix

  1. Add detection of the AI's attacking chances: add a check_ai_advantage method to the BoardManager class, symmetric to check_opponent_threat, that detects the AI's own threes and fours and identifies the key cells where "one stone wins" or "one stone makes a four";
def check_ai_advantage(self):
    """Detect the AI's own 3/4-in-a-row chances; return the highest-priority extending cell"""
    ai = 2
    four_advantages = []   # AI fours (one stone makes a winning five)
    three_advantages = []  # AI threes (one stone makes a four)
    directions = [(0, 1), (1, 0), (1, 1), (1, -1)]

    for x in range(self.size):
        for y in range(self.size):
            if self.board[x, y] == 0:
                for dx, dy in directions:
                    left_count = self.count_continuous_pieces(x - dx, y - dy, ai, -dx, -dy)
                    right_count = self.count_continuous_pieces(x + dx, y + dy, ai, dx, dy)
                    total = left_count + right_count

                    if total == 4:
                        if (x, y) not in four_advantages:
                            four_advantages.append((x, y))
                    elif total == 3:
                        left_block = not (0 <= x - (left_count + 1)*dx < self.size and 0 <= y - (left_count + 1)*dy < self.size)
                        right_block = not (0 <= x + (right_count + 1)*dx < self.size and 0 <= y + (right_count + 1)*dy < self.size)
                        if not (left_block and right_block) and (x, y) not in three_advantages:
                            three_advantages.append((x, y))

    if four_advantages:
        return four_advantages[0]
    elif three_advantages:
        return three_advantages[0]
    return None
  2. Reorder the move priorities (the core change): redefine the priorities to balance defense and attack. The final order: block the opponent's four (survive) → complete our own four (make five and win) → block the opponent's three (defuse) → extend our own three (threaten) → model + Q-Learning prediction → random move. The AI thus covers lethal threats first while still pursuing its own win;
def ai_choose_action(self):
    """AI move logic: survive > win > baseline strategy"""
    # Step 1: block the opponent's four (survive)
    opponent_four_threat = self.board_manager.check_opponent_threat()
    if opponent_four_threat is not None:
        x, y = opponent_four_threat
        if self.board_manager.board[x, y] == 0:
            return x, y

    # Step 2: complete our own four (make five and win)
    ai_four_advantage = self.board_manager.check_ai_advantage()
    if ai_four_advantage is not None:
        x, y = ai_four_advantage
        if self.board_manager.board[x, y] == 0:
            return x, y

    # Step 3: block the opponent's three (defuse)
    opponent_three_threat = self.board_manager.check_opponent_threat()
    if opponent_three_threat is not None and opponent_three_threat != opponent_four_threat:
        x, y = opponent_three_threat
        if self.board_manager.board[x, y] == 0:
            return x, y

    # Step 4: extend our own three (threaten)
    ai_three_advantage = self.board_manager.check_ai_advantage()
    if ai_three_advantage is not None and ai_three_advantage != ai_four_advantage:
        x, y = ai_three_advantage
        if self.board_manager.board[x, y] == 0:
            return x, y

    # Step 5: model + Q-Learning decision (no immediate threats)
    # ... (continued in Appendix 3)
  3. Rebalance the reward scheme: add attack rewards. Completing a four into a winning five earns a large positive reward (reward=50); extending a three into a four earns a positive reward (reward=20), pushing the AI to attack and to accumulate attacking experience. The trained CNN model is also loaded so that, with no immediate threats on the board, the AI can pick the strongest attacking move from the model's predictions.

Result

The AI is now genuinely strong on both sides of the board: it precisely blocks the player's threes and fours to avoid losing, and it actively spots its own threes and fours, extending a four into a winning five or a three into a threatening four. When both sides hold lethal threats at once, it blocks the opponent's four first (survive) and then completes its own four (win), in line with the intended Gomoku logic. The Q table now accumulates both defensive and attacking experience, the AI's decisions keep getting more flexible, and the play experience finally meets expectations.

Appendix: Complete Project Code

Appendix 1: gomoku_data.py (self-play sample generation)

import numpy as np
import random

# Board manager: maintains and updates the board state
class BoardManager:
    def __init__(self):
        self.size = 15
        # 3-channel board: [0] player 1 (black), [1] player 2 (white), [2] empty cells (auxiliary marker)
        self.board = np.zeros((self.size, self.size, 3), dtype=np.int8)
        self.history = []

    def update_board(self, x, y, player):
        """Apply a move: update the corresponding channel and record the previous state"""
        if not (0 <= x < self.size and 0 <= y < self.size):
            return False

        # The target cell must be empty
        if self.board[x, y, 0] == 1 or self.board[x, y, 1] == 1:
            return False

        # Record history
        self.history.append(np.copy(self.board))

        # Clear all channels at the cell, then set the player's channel
        self.board[x, y, :] = 0
        self.board[x, y, player-1] = 1
        self.board[x, y, 2] = 0  # clear the empty-cell channel here

        # Refresh the empty-cell channel (optional; enriches the sample features)
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i, j, 0] == 0 and self.board[i, j, 1] == 0:
                    self.board[i, j, 2] = 1

        return True

    def get_state(self):
        """Return a copy of the current board state"""
        return np.copy(self.board)

# Legal-move validator
class Validator:
    def __init__(self, board_manager):
        self.board_manager = board_manager
        self.size = board_manager.size

    def check_valid(self, x, y):
        """Check whether a single position is legal (empty and on the board)"""
        if not (0 <= x < self.size and 0 <= y < self.size):
            return False
        return self.board_manager.board[x, y, 0] == 0 and self.board_manager.board[x, y, 1] == 0

    def get_all_valid_positions(self):
        """Return all legal move positions"""
        valid_pos = []
        for x in range(self.size):
            for y in range(self.size):
                if self.check_valid(x, y):
                    valid_pos.append((x, y))
        return valid_pos

# Win check: detect five in a row
def judge_win(board, x, y, player):
    size = board.shape[0]
    directions = [(0, 1), (1, 0), (1, 1), (1, -1)]  # horizontal, vertical, two diagonals
    player_channel = player - 1

    for dx, dy in directions:
        count = 1
        # Scan forward
        nx, ny = x + dx, y + dy
        while 0 <= nx < size and 0 <= ny < size and board[nx, ny, player_channel] == 1:
            count += 1
            nx += dx
            ny += dy

        # Scan backward
        nx, ny = x - dx, y - dy
        while 0 <= nx < size and 0 <= ny < size and board[nx, ny, player_channel] == 1:
            count += 1
            nx -= dx
            ny -= dy

        # Five in a row wins
        if count >= 5:
            return True
    return False

# Run-length check: detect n in a row (n = 3/4)
def has_n_pieces(board, x, y, player, n):
    size = board.shape[0]
    directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
    player_channel = player - 1

    for dx, dy in directions:
        count = 1
        # Scan forward
        nx, ny = x + dx, y + dy
        while 0 <= nx < size and 0 <= ny < size and board[nx, ny, player_channel] == 1:
            count += 1
            nx += dx
            ny += dy

        # Scan backward
        nx, ny = x - dx, y - dy
        while 0 <= nx < size and 0 <= ny < size and board[nx, ny, player_channel] == 1:
            count += 1
            nx -= dx
            ny -= dy

        # n in a row found
        if count >= n:
            return True
    return False

# Priority-based move policy: generate self-play moves with explicit logic
def get_better_action(validator, board_manager, current_player):
    valid_pos = validator.get_all_valid_positions()
    if not valid_pos:
        return None, None

    board = board_manager.get_state()
    size = board_manager.size
    opponent = 2 if current_player == 1 else 1

    # Move lists by priority
    win_pos = []         # immediate win (five in a row)
    block_win_pos = []   # block the opponent's win
    four_pos = []        # make our own four
    block_four_pos = []  # block the opponent's four
    three_pos = []       # make our own three

    # Evaluate every legal position
    for (x, y) in valid_pos:
        # Value of playing here ourselves
        temp_board_self = board.copy()
        temp_board_self[x, y, :] = 0
        temp_board_self[x, y, current_player-1] = 1
        is_self_win = judge_win(temp_board_self, x, y, current_player)
        has_self_four = has_n_pieces(temp_board_self, x, y, current_player, 4)
        has_self_three = has_n_pieces(temp_board_self, x, y, current_player, 3)

        # Value of this cell to the opponent (i.e. blocking value)
        temp_board_opp = board.copy()
        temp_board_opp[x, y, :] = 0
        temp_board_opp[x, y, opponent-1] = 1
        is_opp_win = judge_win(temp_board_opp, x, y, opponent)
        has_opp_four = has_n_pieces(temp_board_opp, x, y, opponent, 4)

        # Classify by priority
        if is_self_win:
            win_pos.append((x, y))
        elif is_opp_win:
            block_win_pos.append((x, y))
        elif has_self_four:
            four_pos.append((x, y))
        elif has_opp_four:
            block_four_pos.append((x, y))
        elif has_self_three:
            three_pos.append((x, y))

    # Pick a move by priority
    if win_pos:
        return random.choice(win_pos)
    elif block_win_pos:
        return random.choice(block_win_pos)
    elif four_pos:
        return random.choice(four_pos)
    elif block_four_pos:
        return random.choice(block_four_pos)
    elif three_pos:
        return random.choice(three_pos)

    # No run-extending moves: prefer cells adjacent to existing stones (keeps samples strategic)
    adjacent_pos = []
    for (x, y) in valid_pos:
        has_adjacent = False
        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                if dx == 0 and dy == 0:
                    continue
                nx, ny = x + dx, y + dy
                if 0 <= nx < size and 0 <= ny < size:
                    if board[nx, ny, 0] == 1 or board[nx, ny, 1] == 1:
                        has_adjacent = True
                        break
            if has_adjacent:
                break
        if has_adjacent:
            adjacent_pos.append((x, y))

    if adjacent_pos:
        return random.choice(adjacent_pos)
    else:
        return random.choice(valid_pos)

# Main self-play data generation
def generate_self_play_data(num_games=2000):
    data = []

    for game_idx in range(num_games):
        # Progress report
        if (game_idx + 1) % 200 == 0:
            print(f"Generated {game_idx+1}/{num_games} games")

        # Initialize each game
        board_manager = BoardManager()
        validator = Validator(board_manager)
        current_player = 1
        game_over = False

        while not game_over:
            valid_pos = validator.get_all_valid_positions()
            if not valid_pos:  # board full: draw
                break

            # Get a priority-based move
            x, y = get_better_action(validator, board_manager, current_player)
            if x is None or y is None:
                break

            # Record sample: pre-move state + move index (x*15 + y)
            state = board_manager.get_state()
            action_idx = x * board_manager.size + y
            data.append((state, action_idx))

            # Apply the move and check for a win
            board_manager.update_board(x, y, current_player)
            if judge_win(board_manager.get_state(), x, y, current_player):
                game_over = True
                break

            # Switch players
            current_player = 2 if current_player == 1 else 1

    # Save the data to an npz file
    if len(data) == 0:
        print("Warning: no training data was generated!")
        return

    # Convert to numpy arrays for PyTorch training
    states = np.array([d[0] for d in data], dtype=np.float32)
    actions = np.array([d[1] for d in data], dtype=np.int64)

    np.savez("gomoku_train_data.npz", states=states, actions=actions)
    print(f"Data generation complete! {len(states)} samples saved to gomoku_train_data.npz")

if __name__ == "__main__":
    generate_self_play_data(num_games=2000)

Appendix 2: train_model.py (CNN model training)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

# Device configuration: prefer GPU, fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. CNN model definition (identical to the one in main.py to avoid load mismatches)
class GomokuModel(nn.Module):
    def __init__(self):
        super(GomokuModel, self).__init__()

        # Convolutional layers: extract spatial board features
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)

        # Fully connected layers: map features to per-position move scores
        self.fc1 = nn.Linear(32 * 15 * 15, 128)
        self.fc2 = nn.Linear(128, 15 * 15)  # scores for the 225 move positions

        # Activation and dropout
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # Convolution + batch norm + activation
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))

        # Flatten the feature maps
        x = x.view(-1, 32 * 15 * 15)

        # Fully connected layers + dropout
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

# 2. Custom Dataset: loads the self-play samples
class GomokuDataset(Dataset):
    def __init__(self, data_path):
        # Load the npz file
        data = np.load(data_path)
        self.states = data["states"]
        self.actions = data["actions"]

        # Reshape: (N, H, W, C) -> (N, C, H, W), matching PyTorch's conv input layout
        self.states = np.transpose(self.states, (0, 3, 1, 2))

    def __len__(self):
        """Return the number of samples"""
        return len(self.actions)

    def __getitem__(self, idx):
        """Return a single sample (state tensor + action label)"""
        state = torch.tensor(self.states[idx], dtype=torch.float32)
        action = torch.tensor(self.actions[idx], dtype=torch.long)
        return state, action

# 3. Training loop
def train_model(data_path="gomoku_train_data.npz", epochs=50, batch_size=64, lr=0.001):
    # Load the dataset
    dataset = GomokuDataset(data_path)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # Initialize model, loss function, optimizer
    model = GomokuModel().to(device)
    criterion = nn.CrossEntropyLoss()  # classification loss, fits move prediction
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.8)  # learning-rate scheduler

    # Train
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0

        # Progress bar over the training batches
        with tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}/{epochs}") as pbar:
            for batch_idx, (states, actions) in enumerate(dataloader):
                # Move data to the device (GPU/CPU)
                states = states.to(device)
                actions = actions.to(device)

                # Forward pass
                outputs = model(states)
                loss = criterion(outputs, actions)

                # Backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Accumulate the loss
                total_loss += loss.item()
                pbar.update(1)
                pbar.set_postfix({"batch_loss": loss.item(), "avg_loss": total_loss/(batch_idx+1)})

        # Step the learning rate
        scheduler.step()

        # Report the average loss per epoch
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} average loss: {avg_loss:.6f}")

    # Save the trained model
    torch.save(model.state_dict(), "gomoku_model.pth")
    print("Training complete! Model saved to gomoku_model.pth")

if __name__ == "__main__":
    train_model(epochs=50, batch_size=64, lr=0.001)

Appendix 3: main.py (game UI and AI decision logic)

import tkinter as tk
from tkinter import messagebox
import numpy as np
import torch
import torch.nn as nn

# ---------------------- 1. Model definition (identical to train_model.py to avoid load mismatches) ----------------------
class GomokuModel(nn.Module):
    def __init__(self):
        super(GomokuModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.bn2 = nn.BatchNorm2d(32)
        self.fc1 = nn.Linear(32 * 15 * 15, 128)
        self.fc2 = nn.Linear(128, 15*15)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
    
    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = x.view(-1, 32 * 15 * 15)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# ---------------------- 2. Core utilities (board management, threat detection) ----------------------
class BoardManager:
    def __init__(self, size=15):
        self.size = size
        # Board state: 0 = empty, 1 = player (black), 2 = AI (white)
        self.board = np.zeros((size, size), dtype=np.int8)
        # 3-channel state for the model input (one-hot encoded)
        self.model_board = np.zeros((3, size, size), dtype=np.float32)
    
    def reset(self):
        """Reset the board"""
        self.board = np.zeros((self.size, self.size), dtype=np.int8)
        self.model_board = np.zeros((3, self.size, self.size), dtype=np.float32)
    
    def update_board(self, x, y, player):
        """Apply a move (player = 1/2); return whether it succeeded"""
        if 0 <= x < self.size and 0 <= y < self.size and self.board[x, y] == 0:
            self.board[x, y] = player
            
            # Rebuild the 3-channel model input (one-hot encoding)
            self.model_board = np.zeros((3, self.size, self.size), dtype=np.float32)
            self.model_board[0][self.board == 1] = 1.0  # player-stone channel
            self.model_board[1][self.board == 2] = 1.0  # AI-stone channel
            self.model_board[2][self.board == 0] = 1.0  # empty-cell channel
            
            return True
        return False
    
    def get_model_input(self):
        """Return the model input tensor (with a batch dimension for PyTorch)"""
        input_tensor = torch.from_numpy(self.model_board).unsqueeze(0)
        return input_tensor
    
    def is_win(self, x, y, player):
        """Check whether the move just played wins (five in a row)"""
        directions = [(0, 1), (1, 0), (1, 1), (1, -1)]  # horizontal, vertical, two diagonals
        
        for dx, dy in directions:
            count = 1
            
            # Scan forward
            nx, ny = x + dx, y + dy
            while 0 <= nx < self.size and 0 <= ny < self.size and self.board[nx, ny] == player:
                count += 1
                nx += dx
                ny += dy
            
            # Scan backward
            nx, ny = x - dx, y - dy
            while 0 <= nx < self.size and 0 <= ny < self.size and self.board[nx, ny] == player:
                count += 1
                nx -= dx
                ny -= dy
            
            if count >= 5:
                return True
        return False
    
    def count_continuous_pieces(self, x, y, player, dx, dy):
        """Count consecutive stones along one direction (including the start cell)"""
        count = 0
        nx, ny = x, y
        while 0 <= nx < self.size and 0 <= ny < self.size and self.board[nx, ny] == player:
            count += 1
            nx += dx
            ny += dy
        return count
    
    def check_opponent_threat(self):
        """Detect the opponent's (player's) 3/4-in-a-row threats; return the highest-priority blocking cell"""
        opponent = 1
        four_threats = []  # opponent fours (must block)
        three_threats = []  # opponent threes (block next)
        directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
        
        for x in range(self.size):
            for y in range(self.size):
                if self.board[x, y] == 0:
                    for dx, dy in directions:
                        left_count = self.count_continuous_pieces(x - dx, y - dy, opponent, -dx, -dy)
                        right_count = self.count_continuous_pieces(x + dx, y + dy, opponent, dx, dy)
                        total = left_count + right_count
                        
                        if total == 4:
                            if (x, y) not in four_threats:
                                four_threats.append((x, y))
                        elif total == 3:
                            left_block = not (0 <= x - (left_count + 1)*dx < self.size and 0 <= y - (left_count + 1)*dy < self.size)
                            right_block = not (0 <= x + (right_count + 1)*dx < self.size and 0 <= y + (right_count + 1)*dy < self.size)
                            if not (left_block and right_block) and (x, y) not in three_threats:
                                three_threats.append((x, y))
        
        if four_threats:
            return four_threats[0]
        elif three_threats:
            return three_threats[0]
        return None
    
    def check_ai_advantage(self):
        """Detect the AI's own 3/4-in-a-row chances; return the highest-priority extending cell"""
        ai = 2
        four_advantages = []  # AI fours (one stone makes a winning five)
        three_advantages = []  # AI threes (one stone makes a four)
        directions = [(0, 1), (1, 0), (1, 1), (1, -1)]
        
        for x in range(self.size):
            for y in range(self.size):
                if self.board[x, y] == 0:
                    for dx, dy in directions:
                        left_count = self.count_continuous_pieces(x - dx, y - dy, ai, -dx, -dy)
                        right_count = self.count_continuous_pieces(x + dx, y + dy, ai, dx, dy)
                        total = left_count + right_count
                        
                        if total == 4:
                            if (x, y) not in four_advantages:
                                four_advantages.append((x, y))
                        elif total == 3:
                            left_block = not (0 <= x - (left_count + 1)*dx < self.size and 0 <= y - (left_count + 1)*dy < self.size)
                            right_block = not (0 <= x + (right_count + 1)*dx < self.size and 0 <= y + (right_count + 1)*dy < self.size)
                            if not (left_block and right_block) and (x, y) not in three_advantages:
                                three_advantages.append((x, y))
        
        if four_advantages:
            return four_advantages[0]
        elif three_advantages:
            return three_advantages[0]
        return None

# ---------------------- 3. Q-Learning optimizer (refines move choice online) ----------------------
class QLearningAgent:
    def __init__(self, size=15, lr=0.1, gamma=0.9, epsilon=0.1):
        self.size = size
        self.lr = lr  # learning rate
        self.gamma = gamma  # discount factor
        self.epsilon = epsilon  # exploration rate
        self.q_table = {}  # Q table: key = board-state hash, value = per-action Q values
    
    def get_state_hash(self, board):
        """Convert a board state to a hash value (used as the Q-table key)"""
        return hash(board.tobytes())
    
    def get_q_value(self, state_hash, action):
        """Get a Q value (0 for unseen states)"""
        if state_hash not in self.q_table:
            self.q_table[state_hash] = np.zeros(self.size * self.size, dtype=np.float32)
        return self.q_table[state_hash][action]
    
    def update_q_value(self, state_hash, action, reward, next_state_hash):
        """Update the Q table (the core Q-Learning formula)"""
        current_q = self.get_q_value(state_hash, action)
        next_max_q = np.max(self.q_table.get(next_state_hash, np.zeros(self.size * self.size)))
        new_q = current_q + self.lr * (reward + self.gamma * next_max_q - current_q)
        self.q_table[state_hash][action] = new_q
    
    def choose_action(self, model_output, valid_actions, board):
        """ε-greedy: combine model scores with Q values to pick a move"""
        if np.random.random() < self.epsilon:
            # Explore: pick a random legal move
            return np.random.choice(valid_actions)
        else:
            # Exploit: combine model scores with the Q values for this board state
            # (keyed by the board-state hash, matching the keys update_q_value writes)
            q_values = self.q_table.get(self.get_state_hash(board), np.zeros_like(model_output))
            combined_scores = model_output + q_values
            valid_scores = combined_scores[valid_actions]
            best_idx = np.argmax(valid_scores)
            return valid_actions[best_idx]

# ---------------------- 4. Gomoku GUI and main logic ----------------------
class GomokuGame:
    def __init__(self, root):
        self.root = root
        self.root.title("Gomoku (vs. AI: attack and defense)")
        self.size = 15
        self.cell_size = 30
        self.canvas_width = self.size * self.cell_size
        self.canvas_height = self.size * self.cell_size
        
        # Initialize core components
        self.board_manager = BoardManager(self.size)
        self.q_agent = QLearningAgent(self.size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.load_model()
        self.game_running = False
        
        # Build the GUI
        self.create_widgets()
    
    def create_widgets(self):
        """Create the GUI elements"""
        # Board canvas
        self.canvas = tk.Canvas(
            self.root,
            width=self.canvas_width,
            height=self.canvas_height,
            bg="#F0D9B5"
        )
        self.canvas.pack(padx=10, pady=10)
        self.canvas.bind("<Button-1>", self.on_click)  # bind mouse clicks to player moves
        
        # Button frame
        self.btn_frame = tk.Frame(self.root)
        self.btn_frame.pack(pady=5)
        
        # "New Game" button
        self.start_btn = tk.Button(
            self.btn_frame,
            text="New Game",
            command=self.start_game,
            width=15
        )
        self.start_btn.grid(row=0, column=0, padx=5)
        
        # "Reset Board" button
        self.reset_btn = tk.Button(
            self.btn_frame,
            text="Reset Board",
            command=self.reset_game,
            width=15
        )
        self.reset_btn.grid(row=0, column=1, padx=5)
        
        # Draw the initial board
        self.draw_board()
    
    def draw_board(self):
        """Draw the board grid and star points"""
        self.canvas.delete("all")
        
        # Horizontal and vertical lines
        for i in range(self.size):
            # Horizontal line
            self.canvas.create_line(
                self.cell_size // 2,
                self.cell_size // 2 + i * self.cell_size,
                self.canvas_width - self.cell_size // 2,
                self.cell_size // 2 + i * self.cell_size,
                fill="#000000"
            )
            # Vertical line
            self.canvas.create_line(
                self.cell_size // 2 + i * self.cell_size,
                self.cell_size // 2,
                self.cell_size // 2 + i * self.cell_size,
                self.canvas_height - self.cell_size // 2,
                fill="#000000"
            )
        
        # Star points (standard Gomoku star positions)
        star_positions = [3, 7, 11]
        for x in star_positions:
            for y in star_positions:
                self.canvas.create_oval(
                    self.cell_size // 2 + x * self.cell_size - 5,
                    self.cell_size // 2 + y * self.cell_size - 5,
                    self.cell_size // 2 + x * self.cell_size + 5,
                    self.cell_size // 2 + y * self.cell_size + 5,
                    fill="#000000"
                )
    
    def draw_piece(self, x, y, player):
        """Draw a stone (1 = black, 2 = white)"""
        center_x = self.cell_size // 2 + x * self.cell_size
        center_y = self.cell_size // 2 + y * self.cell_size
        radius = self.cell_size // 2 - 2
        
        if player == 1:
            # Player: black stone
            self.canvas.create_oval(
                center_x - radius,
                center_y - radius,
                center_x + radius,
                center_y + radius,
                fill="#000000",
                outline="#000000"
            )
        else:
            # AI: white stone
            self.canvas.create_oval(
                center_x - radius,
                center_y - radius,
                center_x + radius,
                center_y + radius,
                fill="#FFFFFF",
                outline="#000000"
            )
    
    def load_model(self):
        """Load the trained Gomoku model"""
        model = GomokuModel().to(self.device)
        try:
            model.load_state_dict(torch.load("gomoku_model.pth", map_location=self.device))
            model.eval()  # switch to eval mode (disables Dropout)
            messagebox.showinfo("Model Loading", "Trained Gomoku model loaded successfully!")
        except FileNotFoundError:
            messagebox.showwarning("Model Loading", "gomoku_model.pth not found; the AI will use rule-based moves!")
            model = None
        except RuntimeError as e:
            messagebox.showerror("Model Loading Failed", f"Model structure mismatch: {str(e)}")
            model = None
        return model
    
    def ai_choose_action(self):
        """AI move logic: survive > win > baseline strategy"""
        # Step 1: block the opponent's four (survive)
        opponent_four_threat = self.board_manager.check_opponent_threat()
        if opponent_four_threat is not None:
            x, y = opponent_four_threat
            if self.board_manager.board[x, y] == 0:
                return x, y
        
        # Step 2: complete our own four (make five and win)
        ai_four_advantage = self.board_manager.check_ai_advantage()
        if ai_four_advantage is not None:
            x, y = ai_four_advantage
            if self.board_manager.board[x, y] == 0:
                return x, y
        
        # Step 3: block the opponent's three (defuse)
        opponent_three_threat = self.board_manager.check_opponent_threat()
        if opponent_three_threat is not None and opponent_three_threat != opponent_four_threat:
            x, y = opponent_three_threat
            if self.board_manager.board[x, y] == 0:
                return x, y
        
        # Step 4: extend our own three (threaten)
        ai_three_advantage = self.board_manager.check_ai_advantage()
        if ai_three_advantage is not None and ai_three_advantage != ai_four_advantage:
            x, y = ai_three_advantage
            if self.board_manager.board[x, y] == 0:
                return x, y
        
        # Step 5: model + Q-Learning decision (no immediate threats)
        valid_positions = np.argwhere(self.board_manager.board == 0)
        if len(valid_positions) == 0:
            return None, None
        
        valid_actions = [x * self.size + y for (x, y) in valid_positions]
        
        if self.model is not None:
            with torch.no_grad():  # no gradients needed: faster inference
                input_tensor = self.board_manager.get_model_input().to(self.device)
                model_output = self.model(input_tensor).squeeze().cpu().numpy()
            best_action_idx = self.q_agent.choose_action(model_output, valid_actions,
                                                         self.board_manager.board)
            x = best_action_idx // self.size
            y = best_action_idx % self.size
        else:
            # No model: pick a random legal cell
            x, y = valid_positions[np.random.randint(0, len(valid_positions))]
        
        return x, y
    
    def on_click(self, event):
        """Mouse-click handler (player move)"""
        if not self.game_running:
            return
        
        # Convert pixel coordinates to board coordinates
        x = int((event.x - self.cell_size // 2) // self.cell_size)
        y = int((event.y - self.cell_size // 2) // self.cell_size)
        
        # Validate and apply the player's move
        if 0 <= x < self.size and 0 <= y < self.size and self.board_manager.board[x, y] == 0:
            self.board_manager.update_board(x, y, 1)
            self.draw_piece(x, y, 1)
            
            # Did the player win?
            if self.board_manager.is_win(x, y, 1):
                messagebox.showinfo("Game Over", "Congratulations, you win!")
                self.game_running = False
                return
            
            # AI move (includes the attack/defense fallbacks)
            ai_x, ai_y = self.ai_choose_action()
            if ai_x is not None and ai_y is not None:
                self.board_manager.update_board(ai_x, ai_y, 2)
                self.draw_piece(ai_x, ai_y, 2)
                
                # Did the AI win?
                if self.board_manager.is_win(ai_x, ai_y, 2):
                    messagebox.showinfo("Game Over", "The AI wins! Try again!")
                    self.game_running = False
                    return
    
    def start_game(self):
        """Start a new game"""
        self.reset_game()
        self.game_running = True
        messagebox.showinfo("Game Started", "The AI will block your threes/fours and extend its own threes/fours toward a winning five!")
    
    def reset_game(self):
        """Reset the game"""
        self.board_manager.reset()
        self.draw_board()
        self.game_running = False

# ---------------------- 5. Entry point ----------------------
if __name__ == "__main__":
    root = tk.Tk()
    game = GomokuGame(root)
    root.mainloop()