强化学习:基于蒙特卡洛树和策略价值网络的深度强化学习五子棋

简介: 强化学习:基于蒙特卡洛树和策略价值网络的深度强化学习五子棋

实现了基于蒙特卡洛树和策略价值网络的深度强化学习五子棋(含码源)

  • 特点

    • 自我对弈
    • 详细注释
    • 流程简单
  • 代码结构

    • net:策略价值网络实现
    • mcts:蒙特卡洛树实现
    • server:前端界面代码
    • legacy:废弃代码
    • docs:其他文件
    • utils:工具代码
    • network.py:移植过来的网络结构代码
    • model_5400.pkl:移植过来的网络训练权重
    • train_agent.py:训练脚本
    • web_server.py:对弈服务脚本
    • web_server_demo.py:对弈服务脚本(移植网络)

1.1 流程

1.2策略价值网络

采用了类似ResNet的结构,加入了SPP模块。

(目前,由于训练太耗时间了,连续跑了三个多星期,才跑了2000多个自我对弈的棋谱,经过实验,这个策略网络的表现,目前还是不行,可能育有还没有训练充分)

同时移植了另一个开源的策略网络以及其训练权重(network.py、model_5400.pkl),用于进行仿真演示效果。

1.3 训练

根据注释调整train_agent.py文件,并运行该脚本

部分代码展示:


if __name__ == '__main__':

    conf = LinXiaoNetConfig()
    conf.set_cuda(True)
    conf.set_input_shape(8, 8)
    conf.set_train_info(5, 16, 1e-2)
    conf.set_checkpoint_config(5, 'checkpoints/v2train')
    conf.set_num_worker(0)
    conf.set_log('log/v2train.log')
    # conf.set_pretrained_path('checkpoints/v2m4000/epoch_15')

    init_logger(conf.log_file)
    logger()(conf)

    device = 'cuda' if conf.use_cuda else 'cpu'

    # 创建策略网络
    model = LinXiaoNet(3)
    model.to(device)

    loss_func = AlphaLoss()
    loss_func.to(device)

    optimizer = torch.optim.SGD(model.parameters(), conf.init_lr, 0.9, weight_decay=5e-4)
    lr_schedule = torch.optim.lr_scheduler.StepLR(optimizer, 1, 0.95)

    # initial config tree
    tree = MonteTree(model, device, chess_size=conf.input_shape[0], simulate_count=500)
    data_cache = TrainDataCache(num_worker=conf.num_worker)

    ep_num = 0
    chess_num = 0
    # config train interval
    train_every_chess = 18

    # 加载检查点
    if conf.pretrain_path is not None:
        model_data, optimizer_data, lr_schedule_data, data_cache, ep_num, chess_num = load_checkpoint(conf.pretrain_path)
        model.load_state_dict(model_data)
        optimizer.load_state_dict(optimizer_data)
        lr_schedule.load_state_dict(lr_schedule_data)
        logger()('successfully load pretrained : {}'.format(conf.pretrain_path))

    while True:
        logger()(f'self chess game no.{chess_num+1} start.')
        # 进行一次自我对弈,获取对弈记录
        chess_record = tree.self_game()
        logger()(f'self chess game no.{chess_num+1} end.')
        # 根据对弈记录生成训练数据
        train_data = generate_train_data(tree.chess_size, chess_record)
        # 将训练数据存入缓存
        for i in range(len(train_data)):
            data_cache.push(train_data[i])
        if chess_num % train_every_chess == 0:
            logger()(f'train start.')
            loader = data_cache.get_loader(conf.batch_size)
            model.train()
            for _ in range(conf.epoch_num):
                loss_record = []
                for bat_state, bat_dist, bat_winner in loader:
                    bat_state, bat_dist, bat_winner = bat_state.to(device), bat_dist.to(device), bat_winner.to(device)
                    optimizer.zero_grad()
                    prob, value = model(bat_state)
                    loss = loss_func(prob, value, bat_dist, bat_winner)
                    loss.backward()
                    optimizer.step()
                    loss_record.append(loss.item())
                logger()(f'train epoch {ep_num} loss: {sum(loss_record) / float(len(loss_record))}')
                ep_num += 1
                if ep_num % conf.checkpoint_save_every_num == 0:
                    save_checkpoint(
                        os.path.join(conf.checkpoint_save_dir, f'epoch_{ep_num}'),
                        ep_num, chess_num, model.state_dict(), optimizer.state_dict(), lr_schedule.state_dict(), data_cache
                    )
            lr_schedule.step()
            logger()(f'train end.')
        chess_num += 1
        save_chess_record(
            os.path.join(conf.checkpoint_save_dir, f'chess_record_{chess_num}.pkl'),
            chess_record
        )
        # break

    pass

1.4 仿真实验

根据注释调整web_server.py文件,加载所用的预训练权重,并运行该脚本

浏览器打开网址:http://127.0.0.1:8080/ 进行对弈

部分代码展示

# 用户查询机器落子状态
@app.route('/state/get/<state_id>', methods=['GET'])
def get_state(state_id):
    global state_result
    state_id = int(state_id)
    state = 0
    chess_state = None
    if state_id in state_result.keys() and state_result[state_id] is not None:
        state = 1
        chess_state = state_result[state_id]
        state_result[state_id] = None
    ret = {
        'code': 0,
        'msg': 'OK',
        'data': {
            'state': state,
            'chess_state': chess_state
        }
    }
    return jsonify(ret)


# 游戏开始,为这场游戏创建蒙特卡洛树
@app.route('/game/start', methods=['POST'])
def game_start():
    global trees
    global model, device, chess_size, simulate_count
    tree_id = random.randint(1000, 100000)
    trees[tree_id] = MonteTree(model, device, chess_size=chess_size, simulate_count=simulate_count)
    ret = {
        'code': 0,
        'msg': 'OK',
        'data': {
            'tree_id': tree_id
        }
    }
    return jsonify(ret)


# 游戏结束,销毁蒙特卡洛树
@app.route('/game/end/<tree_id>', methods=['POST'])
def game_end(tree_id):
    global trees
    tree_id = int(tree_id)
    trees[tree_id] = None
    ret = {
        'code': 0,
        'msg': 'OK',
        'data': {}
    }
    return ret


if __name__ == '__main__':
    app.run(
        '0.0.0.0',
        8080
    )

1.5 仿真实验(移植网络)

运行脚本:python web_server_demo.py

浏览器打开网址:http://127.0.0.1:8080/ 进行对弈

码源链接见文末

码源链接

更多优质内容请关注公号&知乎:汀丶人工智能;会提供一些相关的资源和优质文章,免费获取阅读。

相关文章
|
8天前
|
存储 安全 网络安全
云计算时代的网络安全挑战与策略
【10月更文挑战第34天】在数字化转型的浪潮中,云计算作为一项关键技术,正深刻改变着企业的运营方式。然而,随着云服务的普及,网络安全问题也日益凸显。本文将探讨云计算环境下的安全挑战,并提出相应的防护策略。
|
3天前
|
存储 安全 网络安全
云计算与网络安全:探索云服务中的信息安全策略
【10月更文挑战第39天】随着云计算的飞速发展,越来越多的企业和个人将数据和服务迁移到云端。然而,随之而来的网络安全问题也日益突出。本文将从云计算的基本概念出发,深入探讨在云服务中如何实施有效的网络安全和信息安全措施。我们将分析云服务模型(IaaS, PaaS, SaaS)的安全特性,并讨论如何在这些平台上部署安全策略。文章还将涉及最新的网络安全技术和实践,旨在为读者提供一套全面的云计算安全解决方案。
|
3天前
|
云安全 安全 网络安全
云计算与网络安全:技术挑战与解决策略
【10月更文挑战第39天】随着云计算技术的飞速发展,网络安全问题也日益凸显。本文将探讨云计算环境下的网络安全挑战,并提出相应的解决策略。通过分析云服务模型、网络安全威胁以及信息安全技术的应用,我们将揭示如何构建一个安全的云计算环境。
|
6天前
|
网络虚拟化
生成树协议(STP)及其演进版本RSTP和MSTP,旨在解决网络中的环路问题,提高网络的可靠性和稳定性
生成树协议(STP)及其演进版本RSTP和MSTP,旨在解决网络中的环路问题,提高网络的可靠性和稳定性。本文介绍了这三种协议的原理、特点及区别,并提供了思科和华为设备的命令示例,帮助读者更好地理解和应用这些协议。
20 4
|
6天前
|
云安全 安全 网络安全
云计算与网络安全:挑战与应对策略####
云计算作为信息技术的一场革命,为数据存储和计算提供了前所未有的便利和效率。然而,随着云计算的广泛应用,其带来的网络安全问题也日益凸显。本文将探讨云计算环境下的主要网络安全挑战,包括数据泄露、网络攻击、身份和访问管理等问题,并分析云服务提供商和企业用户如何通过技术手段和管理策略来应对这些挑战。此外,还将讨论云计算与信息安全领域的最新发展趋势,旨在为读者提供一个全面的理解和实用的指导。通过深入剖析云计算的工作原理和安全机制,我们可以更好地理解如何保护我们的网络和信息安全。只有云计算提供商和用户共同努力,才能建立一个安全可靠的云计算环境。 ####
|
6天前
|
监控 安全 网络安全
网络安全的盾牌:漏洞防御与加密技术的现代策略
【10月更文挑战第36天】在数字化浪潮中,网络安全成为保护个人隐私和企业资产的关键防线。本文深入探讨网络安全漏洞的成因、影响及防御措施,并分析加密技术如何为信息安全提供坚固保障。通过案例分析和代码示例,揭示提升安全意识的重要性及其在防范网络攻击中的作用,旨在为读者提供一套全面的网络安全解决方案和预防策略。
|
11天前
|
存储 安全 云计算
云上防线:云计算时代的网络安全策略
云上防线:云计算时代的网络安全策略
31 4
|
14天前
|
存储 安全 网络安全
云计算与网络安全:保护数据的新策略
【10月更文挑战第28天】随着云计算的广泛应用,网络安全问题日益突出。本文将深入探讨云计算环境下的网络安全挑战,并提出有效的安全策略和措施。我们将分析云服务中的安全风险,探讨如何通过技术和管理措施来提升信息安全水平,包括加密技术、访问控制、安全审计等。此外,文章还将分享一些实用的代码示例,帮助读者更好地理解和应用这些安全策略。
|
19天前
|
SQL 安全 算法
网络安全漏洞与加密技术:保护信息安全的关键策略
【10月更文挑战第23天】在数字化时代,网络安全漏洞和信息安全问题日益突出。本文将探讨网络安全漏洞的概念、类型以及它们对信息系统的潜在威胁,并介绍加密技术如何成为防御这些安全威胁的有力工具。同时,强调安全意识的重要性,并提出加强网络安全教育和培训的建议。最后,通过一个代码示例,展示如何在网络应用中实现基本的加密措施,以增强读者对网络安全实践的认识。
|
22天前
|
机器学习/深度学习 计算机视觉 网络架构
【YOLO11改进 - C3k2融合】C3k2融合YOLO-MS的MSBlock : 分层特征融合策略,轻量化网络结构
【YOLO11改进 - C3k2融合】C3k2融合YOLO-MS的MSBlock : 分层特征融合策略,轻量化网络结构

热门文章

最新文章