PokéLLMon 源码解析(四)(3)

本文涉及的产品
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
云解析 DNS,旗舰版 1个月
简介: PokéLLMon 源码解析(四)

PokéLLMon 源码解析(四)(2)https://developer.aliyun.com/article/1483673

.\PokeLLMon\poke_env\player\gpt_player.py

import json  # 导入 json 模块
import os  # 导入 os 模块
import random  # 导入 random 模块
from typing import List  # 导入 List 类型提示
from poke_env.environment.abstract_battle import AbstractBattle  # 导入 AbstractBattle 类
from poke_env.environment.double_battle import DoubleBattle  # 导入 DoubleBattle 类
from poke_env.environment.move_category import MoveCategory  # 导入 MoveCategory 类
from poke_env.environment.pokemon import Pokemon  # 导入 Pokemon 类
from poke_env.environment.side_condition import SideCondition  # 导入 SideCondition 类
from poke_env.player.player import Player, BattleOrder  # 导入 Player 和 BattleOrder 类
from typing import Dict, List, Optional, Union  # 导入 Dict, List, Optional, Union 类型提示
from poke_env.environment.move import Move  # 导入 Move 类
import time  # 导入 time 模块
import json  # 再次导入 json 模块(重复导入)
from openai import OpenAI  # 导入 OpenAI 类
from poke_env.data.gen_data import GenData  # 导入 GenData 类
def calculate_move_type_damage_multipier(type_1, type_2, type_chart, constraint_type_list):
    TYPE_list = 'BUG,DARK,DRAGON,ELECTRIC,FAIRY,FIGHTING,FIRE,FLYING,GHOST,GRASS,GROUND,ICE,NORMAL,POISON,PSYCHIC,ROCK,STEEL,WATER'.split(",")
    move_type_damage_multiplier_list = []  # 初始化一个空列表,用于存储每种类型的伤害倍率
    if type_2:  # 如果存在第二种类型
        for type in TYPE_list:  # 遍历每种类型
            move_type_damage_multiplier_list.append(type_chart[type_1][type] * type_chart[type_2][type])  # 计算两种类型之间的伤害倍率并添加到列表中
        move_type_damage_multiplier_dict = dict(zip(TYPE_list, move_type_damage_multiplier_list))  # 将类型和对应的伤害倍率组成字典
    else:  # 如果只有一种类型
        move_type_damage_multiplier_dict = type_chart[type_1]  # 直接使用第一种类型的伤害倍率字典
    effective_type_list = []  # 初始化有效类型列表
    extreme_type_list = []  # 初始化极效类型列表
    resistant_type_list = []  # 初始化抵抗类型列表
    extreme_resistant_type_list = []  # 初始化极度抵抗类型列表
    immune_type_list = []  # 初始化免疫类型列表
    for type, value in move_type_damage_multiplier_dict.items():  # 遍历每种类型及其对应的伤害倍率
        if value == 2:  # 如果伤害倍率为 2
            effective_type_list.append(type)  # 添加到有效类型列表
        elif value == 4:  # 如果伤害倍率为 4
            extreme_type_list.append(type)  # 添加到极效类型列表
        elif value == 1 / 2:  # 如果伤害倍率为 1/2
            resistant_type_list.append(type)  # 添加到抵抗类型列表
        elif value == 1 / 4:  # 如果伤害倍率为 1/4
            extreme_resistant_type_list.append(type)  # 添加到极度抵抗类型列表
        elif value == 0:  # 如果伤害倍率为 0
            immune_type_list.append(type)  # 添加到免疫类型列表
        else:  # 如果伤害倍率为 1
            continue  # 继续循环
    # 如果约束类型列表不为空
    if constraint_type_list:
        # 将极端类型列表与约束类型列表的交集作为新的极端类型列表
        extreme_type_list = list(set(extreme_type_list).intersection(set(constraint_type_list)))
        # 将有效类型列表与约束类型列表的交集作为新的有效类型列表
        effective_type_list = list(set(effective_type_list).intersection(set(constraint_type_list)))
        # 将抗性类型列表与约束类型列表的交集作为新的抗性类型列表
        resistant_type_list = list(set(resistant_type_list).intersection(set(constraint_type_list)))
        # 将极端抗性类型列表与约束类型列表的交集作为新的极端抗性类型列表
        extreme_resistant_type_list = list(set(extreme_resistant_type_list).intersection(set(constraint_type_list)))
        # 将免疫类型列表与约束类型列表的交集作为新的免疫类型列表
        immune_type_list = list(set(immune_type_list).intersection(set(constraint_type_list)))
    # 返回各类型列表的首字母大写形式
    return (list(map(lambda x: x.capitalize(), extreme_type_list)),
           list(map(lambda x: x.capitalize(), effective_type_list)),
           list(map(lambda x: x.capitalize(), resistant_type_list)),
           list(map(lambda x: x.capitalize(), extreme_resistant_type_list)),
           list(map(lambda x: x.capitalize(), immune_type_list)))
# 定义一个函数,用于计算给定精灵对应的移动类型伤害提示
def move_type_damage_wraper(pokemon, type_chart, constraint_type_list=None):
    # 初始化变量,用于存储精灵的两种类型
    type_1 = None
    type_2 = None
    # 如果精灵有第一种类型
    if pokemon.type_1:
        # 获取第一种类型的名称
        type_1 = pokemon.type_1.name
        # 如果精灵有第二种类型
        if pokemon.type_2:
            # 获取第二种类型的名称
            type_2 = pokemon.type_2.name
    # 初始化移动类型伤害提示字符串
    move_type_damage_prompt = ""
    # 调用函数计算移动类型伤害倍数,得到不同类型的列表
    extreme_effective_type_list, effective_type_list, resistant_type_list, extreme_resistant_type_list, immune_type_list = calculate_move_type_damage_multipier(
        type_1, type_2, type_chart, constraint_type_list)
    # 根据不同类型的列表生成移动类型伤害提示
    if extreme_effective_type_list:
        move_type_damage_prompt = (move_type_damage_prompt + " " + ", ".join(extreme_effective_type_list) +
                                   f"-type attack is extremely-effective (4x damage) to {pokemon.species}.")
    if effective_type_list:
        move_type_damage_prompt = (move_type_damage_prompt + " " + ", ".join(effective_type_list) +
                                   f"-type attack is super-effective (2x damage) to {pokemon.species}.")
    if resistant_type_list:
        move_type_damage_prompt = (move_type_damage_prompt + " " + ", ".join(resistant_type_list) +
                                   f"-type attack is ineffective (0.5x damage) to {pokemon.species}.")
    if extreme_resistant_type_list:
        move_type_damage_prompt = (move_type_damage_prompt + " " + ", ".join(extreme_resistant_type_list) +
                                   f"-type attack is highly ineffective (0.25x damage) to {pokemon.species}.")
    if immune_type_list:
        move_type_damage_prompt = (move_type_damage_prompt + " " + ", ".join(immune_type_list) +
                                   f"-type attack is zero effect (0x damage) to {pokemon.species}.")
    # 返回移动类型伤害提示字符串
    return move_type_damage_prompt
# 定义一个类,继承自Player类
class LLMPlayer(Player):
    # 使用 OpenAI API 进行对话生成,返回生成的文本
    def chatgpt(self, system_prompt, user_prompt, model, temperature=0.7, json_format=False, seed=None, stop=[], max_tokens=200) -> str:
        # 创建 OpenAI 客户端对象
        client = OpenAI(api_key=self.api_key)
        # 如果需要返回 JSON 格式的响应
        if json_format:
            # 调用 API 完成对话生成,返回 JSON 格式的响应
            response = client.chat.completions.create(
                response_format={"type": "json_object"},
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=temperature,
                stream=False,
                # seed=seed,
                stop=stop,
                max_tokens=max_tokens
            )
        else:
            # 调用 API 完成对话生成
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=temperature,
                stream=False,
                # seed=seed,
                max_tokens=max_tokens,
                stop=stop
            )
        # 获取生成的文本内容
        outputs = response.choices[0].message.content
        # 记录完成的 token 数量
        self.completion_tokens += response.usage.completion_tokens
        # 记录 prompt 的 token 数量
        self.prompt_tokens += response.usage.prompt_tokens
        # 返回生成的文本
        return outputs
    # 估算两只精灵之间的对战得分
    def _estimate_matchup(self, mon: Pokemon, opponent: Pokemon):
        # 计算对手对该精灵造成的伤害加成中的最大值
        score = max([opponent.damage_multiplier(t) for t in mon.types if t is not None])
        # 计算该精灵对对手造成的伤害加成中的最大值
        score -= max(
            [mon.damage_multiplier(t) for t in opponent.types if t is not None]
        )
        # 根据速度判断得分
        if mon.base_stats["spe"] > opponent.base_stats["spe"]:
            score += self.SPEED_TIER_COEFICIENT
        elif opponent.base_stats["spe"] > mon.base_stats["spe"]:
            score -= self.SPEED_TIER_COEFICIENT
        # 根据当前生命值比例调整得分
        score += mon.current_hp_fraction * self.HP_FRACTION_COEFICIENT
        score -= opponent.current_hp_fraction * self.HP_FRACTION_COEFICIENT
        return score
    # 判断是否应该使用极巨化
    def _should_dynamax(self, battle: AbstractBattle):
        # 统计队伍中剩余未倒下的精灵数量
        n_remaining_mons = len(
            [m for m in battle.team.values() if m.fainted is False]
        )
        if battle.can_dynamax and self._dynamax_disable is False:
            # 如果只剩下一只全血的精灵
            if (
                len([m for m in battle.team.values() if m.current_hp_fraction == 1])
                == 1
                and battle.active_pokemon.current_hp_fraction == 1
            ):
                return True
            # 如果有对战优势且双方都是全血状态
            if (
                self._estimate_matchup(
                    battle.active_pokemon, battle.opponent_active_pokemon
                )
                > 0
                and battle.active_pokemon.current_hp_fraction == 1
                and battle.opponent_active_pokemon.current_hp_fraction == 1
            ):
                return True
            # 如果只剩下一只精灵
            if n_remaining_mons == 1:
                return True
        return False
    # 解析LLM输出,找到JSON内容的起始位置
    json_start = llm_output.find('{')
    # 找到JSON内容的结束位置,从后往前找第一个}
    json_end = llm_output.rfind('}') + 1
    # 提取JSON内容
    json_content = llm_output[json_start:json_end]
    # 将JSON内容加载为Python对象
    llm_action_json = json.loads(json_content)
    # 初始化下一个动作为None
    next_action = None
    
    # 如果JSON中包含"move"字段
    if "move" in llm_action_json.keys():
        # 获取LLM中的移动ID并处理格式
        llm_move_id = llm_action_json["move"]
        llm_move_id = llm_move_id.replace(" ","").replace("-", "")
        # 遍历可用的移动列表,匹配LLM中的移动ID
        for i, move in enumerate(battle.available_moves):
            if move.id.lower() == llm_move_id.lower():
                # 创建相应的移动指令
                next_action = self.create_order(move, dynamax=self._should_dynamax(battle))
    # 如果JSON中包含"switch"字段
    elif "switch" in llm_action_json.keys():
        # 获取LLM中的交换精灵种类并匹配可用的交换精灵列表
        llm_switch_species = llm_action_json["switch"]
        for i, pokemon in enumerate(battle.available_switches):
            if pokemon.species.lower() == llm_switch_species.lower():
                # 创建相应的交换指令
                next_action = self.create_order(pokemon)
    # 如果下一个动作仍为None,则抛出数值错误异常
    if next_action is None:
        raise ValueError("Value Error")
    # 返回下一个动作
    return next_action
    # 解析LLM输出,找到JSON内容的起始位置
    json_start = llm_output.find('{')
    # 找到JSON内容的结束位置,从后往前找第一个}
    json_end = llm_output.rfind('}') + 1
    # 提取JSON内容
    json_content = llm_output[json_start:json_end]
    # 将JSON内容转换为Python对象
    llm_action_json = json.loads(json_content)
    next_action = None
    # 获取动作和目标
    action = llm_action_json["decision"]["action"]
    target = llm_action_json["decision"]["target"]
    # 处理目标字符串,去除空格和下划线
    target = target.replace(" ", "").replace("_", "")
    # 如果动作是移动
    if action.lower() == "move":
        # 遍历可用的移动
        for i, move in enumerate(battle.available_moves):
            # 如果移动ID匹配目标
            if move.id.lower() == target.lower():
                # 创建移动指令
                next_action = self.create_order(move, dynamax=self._should_dynamax(battle))
    # 如果动作是交换
    elif action.lower() == "switch":
        # 遍历可用的交换精灵
        for i, pokemon in enumerate(battle.available_switches):
            # 如果精灵种类匹配目标
            if pokemon.species.lower() == target.lower():
                # 创建交换指令
                next_action = self.create_order(pokemon)
    # 如果没有找到下一步动作,抛出数值错误
    if next_action is None:
        raise ValueError("Value Error")
    # 返回下一步动作
    return next_action
    # 检查状态并返回对应的字符串
    def check_status(self, status):
        if status:
            if status.value == 1:
                return "burnt"
            elif status.value == 2:
                return "fainted"
            elif status.value == 3:
                return "frozen"
            elif status.value == 4:
                return "paralyzed"
            elif status.value == 5:
                return "poisoned"
            elif status.value == 7:
                return "toxic"
            elif status.value == 6:
                return "sleeping"
        else:
            return ""
    # 根据状态和等级返回加成倍数
    def boost_multiplier(self, state, level):
        # 如果状态是准确度
        if state == "accuracy":
            # 根据等级返回对应的加成倍数
            if level == 0:
                return 1.0
            if level == 1:
                return 1.33
            if level == 2:
                return 1.66
            if level == 3:
                return 2.0
            if level == 4:
                return 2.5
            if level == 5:
                return 2.66
            if level == 6:
                return 3.0
            if level == -1:
                return 0.75
            if level == -2:
                return 0.6
            if level == -3:
                return 0.5
            if level == -4:
                return 0.43
            if level == -5:
                return 0.36
            if level == -6:
                return 0.33
        # 如果状态不是准确度
        else:
            # 根据等级返回对应的加成倍数
            if level == 0:
                return 1.0
            if level == 1:
                return 1.5
            if level == 2:
                return 2.0
            if level == 3:
                return 2.5
            if level == 4:
                return 3.0
            if level == 5:
                return 3.5
            if level == 6:
                return 4.0
            if level == -1:
                return 0.67
            if level == -2:
                return 0.5
            if level == -3:
                return 0.4
            if level == -4:
                return 0.33
            if level == -5:
                return 0.29
            if level == -6:
                return 0.25
    # 返回战斗摘要信息,包括击败得分、剩余得分、胜利列表和标签列表
    def battle_summary(self):
        
        # 初始化空列表用于存储击败得分、剩余得分、胜利列表和标签列表
        beat_list = []
        remain_list = []
        win_list = []
        tag_list = []
        
        # 遍历每场战斗,计算击败得分、剩余得分、是否胜利以及标签
        for tag, battle in self.battles.items():
            beat_score = 0
            # 计算对手队伍的击败得分
            for mon in battle.opponent_team.values():
                beat_score += (1-mon.current_hp_fraction)
            beat_list.append(beat_score)
            remain_score = 0
            # 计算己方队伍的剩余得分
            for mon in battle.team.values():
                remain_score += mon.current_hp_fraction
            remain_list.append(remain_score)
            # 如果战斗胜利,则在胜利列表中添加1
            if battle.won:
                win_list.append(1)
            tag_list.append(tag)
        # 返回击败得分列表、剩余得分列表、胜利列表和标签列表
        return beat_list, remain_list, win_list, tag_list
    # 辅助计算奖励值的函数
    def reward_computing_helper(
        self,
        battle: AbstractBattle,
        *,
        fainted_value: float = 0.0,
        hp_value: float = 0.0,
        number_of_pokemons: int = 6,
        starting_value: float = 0.0,
        status_value: float = 0.0,
        victory_value: float = 1.0,
    ) -> float:
        """A helper function to compute rewards."""
        # 如果战斗不在奖励缓冲区中,则将其添加,并设置初始值
        if battle not in self._reward_buffer:
            self._reward_buffer[battle] = starting_value
        current_value = 0
        # 遍历我方队伍中的每只精灵
        for mon in battle.team.values():
            # 根据当前生命值比例计算当前值
            current_value += mon.current_hp_fraction * hp_value
            # 如果精灵已经倒下,则减去倒下值
            if mon.fainted:
                current_value -= fainted_value
            # 如果精灵有异常状态,则减去异常状态值
            elif mon.status is not None:
                current_value -= status_value
        # 根据己方队伍中精灵数量与总精灵数量的差值计算当前值
        current_value += (number_of_pokemons - len(battle.team)) * hp_value
        # 遍历对方队伍中的每只精灵
        for mon in battle.opponent_team.values():
            # 根据当前生命值比例计算当前值
            current_value -= mon.current_hp_fraction * hp_value
            # 如果精灵已经倒下,则加上倒下值
            if mon.fainted:
                current_value += fainted_value
            # 如果精灵有异常状态,则加上异常状态值
            elif mon.status is not None:
                current_value += status_value
        # 根据对方队伍中精灵数量与总精灵数量的差值计算当前值
        current_value -= (number_of_pokemons - len(battle.opponent_team)) * hp_value
        # 如果战斗胜利,则加上胜利值
        if battle.won:
            current_value += victory_value
        # 如果战斗失败,则减去胜利值
        elif battle.lost:
            current_value -= victory_value
        # 计算当前值与奖励缓冲区中的值的差值作为返回值
        to_return = current_value - self._reward_buffer[battle] # the return value is the delta
        self._reward_buffer[battle] = current_value
        return to_return
    def choose_max_damage_move(self, battle: AbstractBattle):
        # 如果有可用的招式,则选择基础威力最大的招式
        if battle.available_moves:
            best_move = max(battle.available_moves, key=lambda move: move.base_power)
            return self.create_order(best_move)
        # 如果没有可用的招式,则随机选择一个招式
        return self.choose_random_move(battle)

.\PokeLLMon\poke_env\player\llama_player.py

# 导入所需的模块
from poke_env.player.gpt_player import LLMPlayer
from poke_env.environment.abstract_battle import AbstractBattle
import json
from peft import PeftModel
import transformers
import torch
from poke_env.player.player import BattleOrder
# 设置空字符串作为默认的令牌
my_token = ""
# 定义忽略索引
IGNORE_INDEX = -100
# 定义默认的填充令牌
DEFAULT_PAD_TOKEN = "[PAD]"
# 定义默认的结束令牌
DEFAULT_EOS_TOKEN = "</s>"
# 定义默认的开始令牌
DEFAULT_BOS_TOKEN = "<s>"
# 定义默认的未知令牌
DEFAULT_UNK_TOKEN = "<unk>"
# 定义 LLAMAPlayer 类,继承自 LLMPlayer
class LLAMAPlayer(LLMPlayer):
    # 初始化函数,接受多个参数
    def __init__(self, battle_format,
                 model_name_or_path: str = "",
                 # tokenizer_path: str = "",
                 lora_weights: str = "",
                 model_max_length: int = 2048,
                 w_reason = False,
                 log_dir = "",
                 account_configuration=None,
                 server_configuration=None,
                 ):
        # 调用父类的初始化函数
        super().__init__(battle_format=battle_format,
                         account_configuration=account_configuration,
                         server_configuration=server_configuration)
        # 初始化 LLAMA 模型
        # 加载 LLAMA 模型
        self.except_cnt = 0
        self.total_cnt = 0
        self.log_dir = log_dir
        self.w_reason = w_reason
        self.last_output = None
        self.last_state_prompt = None
        # 断言确保模型路径已指定
        assert (model_name_or_path), "Please specify the model path"
        # 使用指定的模型路径加载 tokenizer
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_name_or_path,
            model_max_length=model_max_length,
            padding_side="right",
            use_fast=False,
            use_auth_token=my_token
        )
        # 使用指定的模型路径加载模型
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            load_in_8bit=False,
            torch_dtype=torch.float16,
            device_map="auto",
            use_auth_token=my_token
        )
        # 如果有 LoRA 权重,则加载
        if lora_weights:
            print("Recover LoRA weights..")
            self.model = PeftModel.from_pretrained(
                self.model,
                lora_weights,
                torch_dtype=torch.float16,
            )
        # 输出加载完成信息
        print("Loading finished...")
        # 设置模型为评估模式
        self.model.eval()

PokéLLMon 源码解析(四)(4)https://developer.aliyun.com/article/1483675

相关文章
|
1月前
|
监控 Java 应用服务中间件
高级java面试---spring.factories文件的解析源码API机制
【11月更文挑战第20天】Spring Boot是一个用于快速构建基于Spring框架的应用程序的开源框架。它通过自动配置、起步依赖和内嵌服务器等特性,极大地简化了Spring应用的开发和部署过程。本文将深入探讨Spring Boot的背景历史、业务场景、功能点以及底层原理,并通过Java代码手写模拟Spring Boot的启动过程,特别是spring.factories文件的解析源码API机制。
69 2
|
2月前
|
缓存 Java 程序员
Map - LinkedHashSet&Map源码解析
Map - LinkedHashSet&Map源码解析
76 0
|
14天前
|
PyTorch Shell API
Ascend Extension for PyTorch的源码解析
本文介绍了Ascend对PyTorch代码的适配过程,包括源码下载、编译步骤及常见问题,详细解析了torch-npu编译后的文件结构和三种实现昇腾NPU算子调用的方式:通过torch的register方式、定义算子方式和API重定向映射方式。这对于开发者理解和使用Ascend平台上的PyTorch具有重要指导意义。
|
18天前
|
缓存 监控 Java
Java线程池提交任务流程底层源码与源码解析
【11月更文挑战第30天】嘿,各位技术爱好者们,今天咱们来聊聊Java线程池提交任务的底层源码与源码解析。作为一个资深的Java开发者,我相信你一定对线程池并不陌生。线程池作为并发编程中的一大利器,其重要性不言而喻。今天,我将以对话的方式,带你一步步深入线程池的奥秘,从概述到功能点,再到背景和业务点,最后到底层原理和示例,让你对线程池有一个全新的认识。
47 12
|
1月前
|
存储 安全 Linux
Golang的GMP调度模型与源码解析
【11月更文挑战第11天】GMP 调度模型是 Go 语言运行时系统的核心部分,用于高效管理和调度大量协程(goroutine)。它通过少量的操作系统线程(M)和逻辑处理器(P)来调度大量的轻量级协程(G),从而实现高性能的并发处理。GMP 模型通过本地队列和全局队列来减少锁竞争,提高调度效率。在 Go 源码中,`runtime.h` 文件定义了关键数据结构,`schedule()` 和 `findrunnable()` 函数实现了核心调度逻辑。通过深入研究 GMP 模型,可以更好地理解 Go 语言的并发机制。
|
1月前
|
消息中间件 缓存 安全
Future与FutureTask源码解析,接口阻塞问题及解决方案
【11月更文挑战第5天】在Java开发中,多线程编程是提高系统并发性能和资源利用率的重要手段。然而,多线程编程也带来了诸如线程安全、死锁、接口阻塞等一系列复杂问题。本文将深度剖析多线程优化技巧、Future与FutureTask的源码、接口阻塞问题及解决方案,并通过具体业务场景和Java代码示例进行实战演示。
48 3
|
2月前
|
存储
让星星⭐月亮告诉你,HashMap的put方法源码解析及其中两种会触发扩容的场景(足够详尽,有问题欢迎指正~)
`HashMap`的`put`方法通过调用`putVal`实现,主要涉及两个场景下的扩容操作:1. 初始化时,链表数组的初始容量设为16,阈值设为12;2. 当存储的元素个数超过阈值时,链表数组的容量和阈值均翻倍。`putVal`方法处理键值对的插入,包括链表和红黑树的转换,确保高效的数据存取。
61 5
|
2月前
|
Java Spring
Spring底层架构源码解析(三)
Spring底层架构源码解析(三)
144 5
|
2月前
|
XML Java 数据格式
Spring底层架构源码解析(二)
Spring底层架构源码解析(二)
|
2月前
|
算法 Java 程序员
Map - TreeSet & TreeMap 源码解析
Map - TreeSet & TreeMap 源码解析
39 0

推荐镜像

更多