以小博大,微软开源27亿参数模型Phi-2,魔搭最佳实践来啦!

简介: 近日,微软公布了在 Microsoft Ignite 2023大会上宣布开源的 Phi-2 模型的更多细节,“打破传统语言模型缩放定律,可PK比自己大25倍的模型”、“以小博大”等评价,让Phi-2一时间在开源社区中引发关注。

近日,微软公布了在 Microsoft Ignite 2023大会上宣布开源的 Phi-2 模型的更多细节,“打破传统语言模型缩放定律,可PK比自己大25倍的模型”、“以小博大”等评价,让Phi-2一时间在开源社区中引发关注。

Phi-2是一个具有27亿参数的模型,在常识推理、语言理解、数学和代码任务的基准测试评估中, Phi-2和参数不到130亿热门开源模型对比,表现出了优异的性能。

官方总结了Phi-2背后的关键洞见:

  • 教科书级的数据集:秉持Phi系列“Textbooks Are All You Need”的宗旨,Phi-2的训练数据混合了专门用于教授模型进行常识推理和掌握一般知识的合成数据集,包括科学、日常活动和心智理论等,并进一步筛选了具有教育价值和内容质量的网络数据来增强训练语料库。
  • 知识迁移:以1.3B的Phi-1.5模型为基础,将其知识嵌入到27亿参数的Phi-2中。这种规模化的知识转移不仅加速了训练的收敛,还明显提升了Phi-2的基准得分。

值得一提的是,Phi-2没有经过基于人类反馈的增强学习(RLHF)的校准,也没有进行指导性的微调,但得益于其高质量的训练数据,与经过对齐的现有开源模型相比,Phi-2在毒性和偏见方面表现更好。

应用方面官方强调 Phi-2 当前仅用于研究目的,旨在为AI开发者和研究者提供一个探索可解释性、安全性改进及各种任务微调实验的工具。模型生成的文本/代码应被视为潜在用例的起点,而不是最终解决方案,Phi-2目前未对生产任务进行过测试,暂无法保证支持生产级别的应用程序的性能。

接下来,为大家带来Phi-2在魔搭社区的推理、微调最佳实践教程,希望对其感兴趣的小伙伴有所帮助

环境配置与安装

  1. python 3.8及以上版本
  2. pytorch 1.12及以上版本,推荐2.0及以上版本
  3. 建议使用CUDA 11.4及以上

如果你使用ModelScope进行推理:

pip install modelscope transformers -U

如果使用SWIFT进行流式输出, 推理加速和微调:

# 安装ms-swift
git clone https://github.com/modelscope/swift.git
cd swift
pip install -e .[llm]
# 如果你要进行推理加速
# vllm与cuda版本有对应关系,请按照`https://docs.vllm.ai/en/latest/getting_started/installation.html`选择版本
pip install vllm -U

本文主要演示的 Phi-2 模型的模型推理在 ModelScope 的 Notebook 的环境(这里以PAI-DSW为例)的配置下运行.

模型链接和下载

Phi-2 模型现可在ModelScope社区下载体验:

Phi-2链接:

https://modelscope.cn/models/AI-ModelScope/phi-2/summary

from modelscope import snapshot_download
model_dir = snapshot_download("AI-ModelScope/phi-2", revision = "master")

模型推理

使用ModelScope

推理代码:

import torch
from modelscope import AutoModelForCausalLM, AutoTokenizer
torch.set_default_device("cuda")
model = AutoModelForCausalLM.from_pretrained("AI-ModelScope/phi-2", torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("AI-ModelScope/phi-2", trust_remote_code=True)
inputs = tokenizer('''def print_prime(n):
   """
   Print all primes between 1 and n
   """''', return_tensors="pt", return_attention_mask=False)
outputs = model.generate(**inputs, max_length=200)
text = tokenizer.batch_decode(outputs)[0]
print(text)

资源消耗: 11GB 显存

使用SWIFT进行流式输出

推理代码:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from swift.llm import (
    get_model_tokenizer, get_template, inference_stream, 
    ModelType, get_default_template_type,
)
from swift.utils import seed_everything
model_type = ModelType.phi2_3b
template_type = get_default_template_type(model_type)
model, tokenizer = get_model_tokenizer(model_type, model_kwargs={'device_map': 'auto'})
# 修改max_new_tokens
model.generation_config.max_new_tokens = 512
template = get_template(template_type, tokenizer)
seed_everything(42)
query = """\
# Print all primes between 1 and n
```python
"""
gen = inference_stream(model, template, query, stop_words=['```\n'])
print_idx = 0
print(query, end='')
for response, history in gen:
    print(response[print_idx:], end='')
    print_idx = len(response)
print()

资源消耗: 7GB 显存

使用vllm推理加速

SWIFT集成了Phi-2和vllm. 我们可以使用SWIFT对模型进行推理加速:

文档:

https://github.com/modelscope/swift/blob/main/docs/source/LLM/VLLM推理加速与部署.md

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from swift.llm import (
    ModelType, get_vllm_engine, get_default_template_type,
    get_template, inference_vllm
)
model_type = ModelType.phi2_3b
llm_engine = get_vllm_engine(model_type)
template_type = get_default_template_type(model_type)
template = get_template(template_type, llm_engine.tokenizer)
# 与`transformers.GenerationConfig`类似的接口
llm_engine.generation_config.max_new_tokens = 512
query1 = """\
# Print all primes between 1 and n
```python
"""
query2 = """\
# quick_sort
```python
"""
request_list = [{'query': query1}, {'query': query2}]
llm_engine.generation_config.stop = ['```\n']
resp_list = inference_vllm(llm_engine, template, request_list)
for request, resp in zip(request_list, resp_list):
    print(f"{request['query']}{resp['response']}")
"""Out[0]
# Print all primes between 1 and n
```python
def is_prime(n):
    # Assume n is a positive integer
    # Check if n is 1 or less, which are not prime
    if n <= 1:
        return False
    # Check if n is divisible by any integer from 2 to its square root
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    # If no divisor is found, n is prime
    return True
for n in range(1, 101):
    if is_prime(n):
        print(n)
```
# quick_sort
```python
import random
def quick_sort(arr) -> None:
    if len(arr) <= 1:
        return
    pivot = arr[random.randint(0, len(arr)-1)]
    less = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    greater = [x for x in arr if x > pivot]
    quick_sort(less)
    quick_sort(greater)
    arr[:] = less + middle + greater
a = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5]
quick_sort(a)
print("Sorted array is:", a)
```
"""

Phi-2模型微调

使用SWIFT对Phi-2进行模型微调, LoRA微调的脚本可以查看:

https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/phi2_3b

微调脚本:

# Experimental environment: A100
# 60GB GPU memory
CUDA_VISIBLE_DEVICES=0 \
swift sft \
    --model_type phi2-3b \
    --sft_type lora \
    --template_type default \
    --train_dataset_sample 20000 \
    --eval_steps 100 \
    --output_dir output \
    --num_train_epochs 1 \
    --max_length 2048 \
    --learning_rate 1e-4 \
    --use_flash_attn true \
    --only_save_model true \
    --lora_target_modules ALL \
    --dataset codefuse-python-en \
    --gradient_checkpointing false \

训练过程也支持本地数据集,需要指定如下参数:

--custom_train_dataset_path xxx.jsonl \
--custom_val_dataset_path yyy.jsonl \

自定义数据集的格式可以参考:

https://github.com/modelscope/swift/blob/main/docs/source/LLM/自定义与拓展.md#注册数据集的方式

微调后的推理脚本:

这里的ckpt_dir需要修改为训练生成的checkpoint文件夹:

# Experimental environment: A10
# 8GB GPU memory
CUDA_VISIBLE_DEVICES=0 \
swift infer \
    --ckpt_dir "output/phi2-3b/vx_xxx/checkpoint-xxx" \
    --load_dataset_config true \
    --max_length 2048 \
    --use_flash_attn false \
    --max_new_tokens 2048 \
    --temperature 0.1 \
    --top_p 0.7 \
    --repetition_penalty 1.05 \
    --merge_lora_and_save false \

微调后生成样例

<<< quick sort, python
Sure! Here's an example of a quick sort implementation in Python with a detailed docstring explaining the code methods:
```python
def quick_sort(arr):
    """
    Sorts a given array using the Quick Sort algorithm.
    Args:
        arr (list): The input array to be sorted.
    Returns:
        list: The sorted array.
    Algorithm:
    1. Choose a pivot element from the array.
    2. Partition the array into two sub-arrays: one with elements smaller than the pivot and another with elements larger than the pivot.
    3. Recursively apply quick sort to the sub-arrays until they are fully sorted.
    4. Combine the sorted sub-arrays with the pivot element to obtain the final sorted array.
    Time Complexity:
    - Best Case: O(n log n)
    - Average Case: O(n log n)
    - Worst Case: O(n^2)
    Space Complexity: O(log n) for the recursive calls.
    """
    if len(arr) <= 1:
        return arr  # Base case: if the array has 0 or 1 element, it is already sorted
    pivot = arr[0]  # Choose the first element as the pivot
    left = []  # List to store elements smaller than the pivot
    right = []  # List to store elements larger than the pivot
    for i in range(1, len(arr)):
        if arr[i] < pivot:
            left.append(arr[i])
        else:
            right.append(arr[i])
    # Recursively sort the left and right sub-arrays
    sorted_left = quick_sort(left)
    sorted_right = quick_sort(right)
    # Combine the sorted sub-arrays with the pivot element
    return sorted_left + [pivot] + sorted_right
```
In this implementation, we use the partitioning technique to divide the array into two sub-arrays based on the pivot element. We then recursively apply quick sort to each sub-array until they are fully sorted. Finally, we combine the sorted sub-arrays with the pivot element to obtain the final sorted array.
The time complexity of quick sort is generally O(n log n), but in the worst case scenario where the pivot is always the smallest or largest element, it can become O(n^2). However, this is rare in practice.
The space complexity of quick sort is O(log n) due to the recursive calls.

跑生成的代码

print(quick_sort([10, 7, 8, 4, 3, 6]))
# [3, 4, 6, 7, 8, 10]

点击直达Phi-2模型卡片:phi-2 · 模型库 (modelscope.cn)

相关文章
|
Java
springboot实现自定义注解限流
springboot实现自定义注解限流
385 1
|
9月前
|
机器学习/深度学习 JSON 并行计算
10分钟微调,让0.6B模型媲美235B模型!免费体验进行中
本方案介绍如何通过模型蒸馏技术,利用大参数模型生成数据并微调小参数模型(如 Qwen3-0.6B),使其在特定任务(如从一句话中提取结构化信息)中达到接近大模型的效果。通过 GPU 云服务器进行高效微调,结合魔搭社区的 ms-swift 框架,用户可快速完成模型训练与部署,显著提升推理速度并降低成本。方案包含详细步骤:数据准备、模型微调、效果验证及部署建议,并提供免费试用资源,助力开发者快速上手实践。
10分钟微调,让0.6B模型媲美235B模型!免费体验进行中
|
4月前
|
机器学习/深度学习 人工智能 监控
基于强化学习的量化交易框架 TensorTrade
TensorTrade 是一个基于强化学习的开源交易算法框架。它通过环境模拟、策略训练与奖励机制,让AI在历史数据中自主学习买卖时机,构建逻辑自洽的交易策略,助力量化研究。
384 9
基于强化学习的量化交易框架 TensorTrade
|
自然语言处理 网络安全 Python
【Python】已解决:nltk.download(‘punkt’) [nltk_data] Error loading punkt: [WinError 10060] [nltk_data]
【Python】已解决:nltk.download(‘punkt’) [nltk_data] Error loading punkt: [WinError 10060] [nltk_data]
4192 1
|
7月前
|
SQL 人工智能 供应链
《AI协同供应链调度困局:从需求拆解到落地增效的全流程实践》
本文记录某制造业供应链调度系统升级的AI协同开发实践:面对旧系统“信息流滞后、决策流固化、响应流迟缓”困境及10周重构需求,团队构建“Cursor+Tabnine+Diagrams AI等”工具矩阵,以AI承接规则性工作、人聚焦核心决策。需求拆解3天完成(效率提130%),架构设计2天规避数据迁移风险,编码5天压缩重复工作,联调2小时定位性能瓶颈。项目提前3周落地,调度响应延迟2.8秒(优于目标30%),供应链成本降8%,订单延误率从15%降至3%。核心认知为AI是“认知延伸器”,需“AI生成+人工校验”闭环,工具矩阵最大化协同价值,同时需避免AI主导需求与核心编码。
435 8
|
安全 Unix Linux
VMware Workstation 17.6.3 发布下载,现在完全免费无论个人还是商业用途
VMware Workstation 17.6.3 发布下载,现在完全免费无论个人还是商业用途
127721 65
|
开发者
鸿蒙next版开发:ArkTS组件通用属性(图形变换)
在HarmonyOS 5.0中,ArkTS提供了强大的图形变换功能,支持组件的旋转、缩放和平移操作,增强用户界面的视觉效果和交互体验。本文详细解读了ArkTS中图形变换的通用属性,并提供了示例代码,包括基础变换、组合变换和动画效果的应用。通过这些示例,开发者可以轻松实现复杂的视觉效果和动态用户界面。
858 1
|
文字识别 自然语言处理 数据可视化
Qwen2.5 全链路模型体验、下载、推理、微调、部署实战!
在 Qwen2 发布后的过去三个月里,许多开发者基于 Qwen2 语言模型构建了新的模型,并提供了宝贵的反馈。在这段时间里,通义千问团队专注于创建更智能、更博学的语言模型。今天,Qwen 家族的最新成员:Qwen2.5系列正式开源
Qwen2.5 全链路模型体验、下载、推理、微调、部署实战!
|
关系型数据库 Linux 数据库
PostgreSQL 入门指南:安装、配置与基本命令
本文从零开始,详细介绍如何在 Windows、Linux 和 macOS 上安装和配置 PostgreSQL,涵盖30+个实操代码示例。内容包括安装步骤、配置远程访问和用户权限、基础数据库操作命令(如创建表、插入和查询数据),以及常见问题的解决方案。通过学习,你将掌握 PostgreSQL 的基本使用方法,并为后续深入学习打下坚实基础。
13842 1
|
存储 缓存 算法
Java代码优化指南
Java代码优化指南
333 1

热门文章

最新文章