# AlphaGo Zero你也来造一只，PyTorch实现五脏俱全| 附代码

跳跃的样子，写成代码就是：
class BasicBlock(nn.Module):
    """
    Basic residual block: two 3x3 convolutions with batch norm and a
    skip connection added before the final ReLU activation.

    Args:
        inplanes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution.
        downsample: optional module applied to the input so the residual
            matches the main path's shape when stride/channels change.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()

        # NOTE(review): the continuation of both Conv2d calls was cut in
        # the article excerpt; stride/padding/bias reconstructed from the
        # standard ResNet basic block (padding=1 keeps the spatial size
        # for a 3x3 kernel, bias is redundant before batch norm).
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Bug fix: the excerpt accepted `downsample` but never stored or
        # used it, so the skip connection could not match a reshaped output.
        self.downsample = downsample

    def forward(self, x):
        residual = x
        if self.downsample is not None:
            residual = self.downsample(x)

        out = self.conv1(x)
        out = F.relu(self.bn1(out))

        out = self.conv2(out)
        out = self.bn2(out)

        # Skip connection joins before the last activation.
        out += residual
        out = F.relu(out)

        return out


class Extractor(nn.Module):
    """
    Feature extractor: one convolution followed by a tower of residual
    blocks, as in the AlphaGo Zero architecture.

    Args:
        inplanes: number of input channels (board feature planes).
        outplanes: number of channels produced by every layer.
        blocks: number of residual blocks; defaults to the module-level
            BLOCKS constant when not given (backward compatible).
    """

    def __init__(self, inplanes, outplanes, blocks=None):
        super(Extractor, self).__init__()
        # Resolve lazily so the BLOCKS global is only needed when the
        # caller does not pass an explicit count.
        self.blocks = BLOCKS if blocks is None else blocks

        # NOTE(review): the continuation of this Conv2d call was cut in
        # the excerpt; kernel_size=3 / padding=1 reconstructed so the
        # spatial dimensions of the board are preserved.
        self.conv1 = nn.Conv2d(inplanes, outplanes, stride=1,
                               kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(outplanes)

        for block in range(self.blocks):
            setattr(self, "res{}".format(block),
                    BasicBlock(outplanes, outplanes))

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        # Run the full residual tower. The original iterated BLOCKS - 1
        # times and then applied the final block separately, which is the
        # same computation written twice.
        for block in range(self.blocks):
            x = getattr(self, "res{}".format(block))(x)
        return x


<p style="text-align:center">![image](https://yqfile.alicdn.com/70c7e67978ad20f5587ad71e5b68dad2d7fddb9f.png)</p>
class PolicyNet(nn.Module):
    """
    Policy head: maps the extractor's feature maps to a probability
    distribution over all moves (one per board square, plus pass).
    """

    def __init__(self, inplanes, outplanes):
        super(PolicyNet, self).__init__()
        self.outplanes = outplanes
        # 1x1 convolution squeezes the feature maps down to a single plane.
        self.conv = nn.Conv2d(inplanes, 1, kernel_size=1)
        self.bn = nn.BatchNorm2d(1)
        self.fc = nn.Linear(outplanes - 1, outplanes)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return a (batch, outplanes) tensor of move probabilities."""
        plane = F.relu(self.bn(self.conv(x)))
        # One score per board square: outplanes - 1 squares on the board.
        flat = plane.view(-1, self.outplanes - 1)
        logits = self.fc(flat)
        # exp(log_softmax) equals softmax but is numerically stabler.
        return self.logsoftmax(logits).exp()



class ValueNet(nn.Module):
    """
    Value head: maps the extractor's feature maps to a scalar in
    [-1, 1] estimating the current player's chance of winning.
    """

    def __init__(self, inplanes, outplanes):
        super(ValueNet, self).__init__()
        self.outplanes = outplanes
        # 1x1 convolution squeezes the feature maps down to a single plane.
        self.conv = nn.Conv2d(inplanes, 1, kernel_size=1)
        self.bn = nn.BatchNorm2d(1)
        self.fc1 = nn.Linear(outplanes - 1, 256)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, x):
        """Return a (batch, 1) tensor of win estimates in [-1, 1]."""
        x = F.relu(self.bn(self.conv(x)))
        x = x.view(-1, self.outplanes - 1)
        x = F.relu(self.fc1(x))
        # Bug fix: torch.tanh replaces F.tanh, which is deprecated.
        winning = torch.tanh(self.fc2(x))
        return winning


class Node:
    """A node of the MCTS search tree."""

    def __init__(self, parent=None, proba=None, move=None):
        # Tree wiring: parent link, children list, move that led here.
        self.parent = parent
        self.children = []
        self.move = move
        # Statistics from the AlphaGo Zero paper: prior probability,
        # visit count, total action value, mean action value.
        self.p = proba
        self.n = 0
        self.w = 0
        self.q = 0



def select(nodes, c_puct=C_PUCT):
    """
    Pick a child index with the PUCT formula
    Q + c_puct * P * sqrt(sum(N)) / (1 + N),
    breaking ties among maximal scores uniformly at random.

    `nodes` is a 2-D array whose rows hold (Q, N, P) statistics.
    """
    # Vectorized replacement of the original per-row Python loops.
    total_count = np.sum(nodes[:, 1])

    action_scores = nodes[:, 0] + c_puct * nodes[:, 2] * \
        (np.sqrt(total_count) / (1 + nodes[:, 1]))

    # np.where over the max always yields at least one index, so the
    # original `if equals.shape[0] > 0` guard made `return equals[0]`
    # dead code; the random tie-break is the intended behavior.
    equals = np.where(action_scores == np.max(action_scores))[0]
    return np.random.choice(equals)


def is_leaf(self):
    """Return True when this node has no expanded children."""
    # An empty children list is falsy, so this equals len(...) == 0.
    return not self.children


def expand(self, probas):
    """Create one child node per move whose prior probability is positive."""
    children = []
    for idx in range(probas.shape[0]):
        # Moves the network assigns zero mass are never expanded.
        if probas[idx] > 0:
            children.append(Node(parent=self, move=idx, proba=probas[idx]))
    self.children = children


def update(self, v):
    """Fold the rollout value `v` into this node's statistics."""
    self.w += v
    # Mean action value; an unvisited node (n == 0) keeps q at 0.
    if self.n > 0:
        self.q = self.w / self.n
    else:
        self.q = 0

# Backpropagation: walk from the evaluated leaf up towards the root,
# folding the rollout value v into every ancestor's statistics.
# Note the root itself (parent is None) is not updated by this loop.
# NOTE(review): fragment — `current_node` and `v` are defined in
# surrounding code not shown in this excerpt.
while current_node.parent:
    current_node.update(v)
    current_node = current_node.parent


# Normalize the children's action scores into a probability
# distribution and sample the move to play from it, giving
# exploration during self-play.
# NOTE(review): fragment — `action_scores` comes from surrounding
# code not shown in this excerpt; assumes all scores are non-negative
# and total > 0, or np.random.choice will reject `p`.
total = np.sum(action_scores)
probas = action_scores / total
move = np.random.choice(action_scores.shape[0], p=probas)


def self_play():
    """
    Endless self-play loop: spawn parallel matches with the current
    best player and store every finished game in the database.

    NOTE(review): this excerpt is truncated — `new_player`, `player`,
    `game_id` and `db` are created on lines missing from the article;
    confirm against the full source.
    """
    while True:
        # Swap in a newly promoted network when one is available.
        if new_player:
            player = new_player

        ## Create the self-play match queue of processes
        results = create_matches(player, cores=PARALLEL_SELF_PLAY,
                                         match_number=SELF_PLAY_MATCH)
        for _ in range(SELF_PLAY_MATCH):
            # Presumably blocks until a worker finishes a game — verify.
            result = results.get()
            db.insert({
                "game": result,
                "id": game_id
            })
            game_id += 1


def train():
    """
    Training loop: optimize the network on self-played games and
    periodically evaluate a snapshot against the current best player.

    NOTE(review): this excerpt is truncated — `player`, `lr`,
    `checkpoint`, `dataloader`, `total_ite`, `collection` and `last_id`
    come from lines missing from the article, and no
    `optimizer.zero_grad()` is visible before `loss.backward()`;
    confirm both against the full source.
    """
    criterion = AlphaLoss()
    dataset = SelfPlayDataset()
    optimizer = create_optimizer(player, lr,
                                    param=checkpoint['optimizer'])
    best_player = deepcopy(player)
    # (orphaned continuation of the DataLoader construction, cut from the excerpt)
                batch_size=BATCH_SIZE, shuffle=True)

    while True:
        for batch_idx, (state, move, winner) in enumerate(dataloader):

            ## Evaluate a copy of the current network
            # Every TRAIN_STEPS iterations, snapshot the learner and
            # promote it if it beats the reigning best network.
            if total_ite % TRAIN_STEPS == 0:
                pending_player = deepcopy(player)
                result = evaluate(pending_player, best_player)

                if result:
                    best_player = pending_player

            example = {
                'state': state,
                'winner': winner,
                'move' : move
            }
            # NOTE(review): `pending_player` is only bound inside the
            # evaluation branch above — verify how the full source
            # handles the other iterations.
            winner, probas = pending_player.predict(example['state'])

            loss = criterion(winner, example['winner'], \
                            probas, example['move'])
            loss.backward()
            optimizer.step()

            ## Fetch new games
            # Periodically pull freshly self-played games into the dataset.
            if total_ite % REFRESH_TICK == 0:
                last_id = fetch_new_games(collection, dataset, last_id)


class AlphaLoss(torch.nn.Module):
    """
    AlphaGo Zero loss: the batch mean of (z - v)^2 - pi^T log(p),
    i.e. squared error on the game outcome plus cross-entropy on the
    move probabilities (L2 regularization is left to the optimizer's
    weight decay).
    """

    def __init__(self):
        super(AlphaLoss, self).__init__()

    def forward(self, pred_winner, winner, pred_probas, probas):
        value_error = (winner - pred_winner) ** 2
        # 1e-6 guards log(0) for moves the network assigns zero mass.
        policy_error = torch.sum((-probas *
                                (1e-6 + pred_probas).log()), 1)
        total_error = (value_error.view(-1) + policy_error).mean()
        # Bug fix: the original excerpt computed the loss but never
        # returned it, so callers received None.
        return total_error


def evaluate(player, new_player):
    """
    Play a series of games between the trained player (black) and the
    current best player, and decide whether to promote the newcomer.

    Returns True when black's win count reaches EVAL_THRESH times the
    number of games played.
    """
    outcomes = play(player, opponent=new_player)

    # result[0] encodes the winning color: 0 -> black, 1 -> white.
    black_wins = sum(1 for outcome in outcomes if outcome[0] == 0)
    white_wins = sum(1 for outcome in outcomes if outcome[0] == 1)

    ## Check if the trained player (black) is better than
    ## the current best player depending on the threshold
    return black_wins >= EVAL_THRESH * len(outcomes)


SuperGo还年幼，是在9x9棋盘上训练的。

Reddit上面也有同仁发来贺电。

△ 有前途的意思

AlphaGo Zero论文传送门：
https://www.nature.com/articles/nature24270.epdf

+ 订阅