使用QLoRa微调Llama 2

简介: 上篇文章我们介绍了Llama 2的量化和部署,本篇文章将介绍使用PEFT库和QLoRa方法对Llama 27b预训练模型进行微调。我们将使用自定义数据集来构建情感分析模型。只有可以对数据进行微调我们才可以将这种大模型进行符合我们数据集的定制化。

一些前置的知识

如果熟悉Google Colab、Weights & Biases (W&B)、HF库,可以跳过这一节。

虽然Google Colab(托管的Jupyter笔记本环境)不是真正的先决条件,但我们建议使用它来访问GPU并进行快速实验。如果是付费的用户,则可以使用高级GPU访问,比如A100这样的GPU。

W&B帐户的作用是记录进度和训练指标,这个如果不需要也可以用tensorboard替代,但是我们是演示Google Colab环境所以直接用它。

然后就是需要一个HF帐户。然后转到settings,创建至少具有读权限的API令牌。因为在训练脚本时将使用它下载预训练的Llama 2模型和数据集。

最后就是请求访问Llama 2模型。等待Meta AI和HF的邮件。这可能要1-2天。

准备数据集

指令微调是一种常用技术,用于为特定的下游用例微调基本LLM。训练示例如下:

 Below is an instruction that describes a sentiment analysis task...

 ### Instruction:
 Analyze the following comment and classify the tone as...

 ### Input:
 I love reading your articles...

 ### Response:
 friendly & constructive

我们建议使用json,因为这样比较灵活。比如为每个示例创建一个JSON对象,其中只有一个文本字段。像这样:

 { "text": "Below is an instruction ... ### Instruction: Analyze the... ### Input: I love... ### Response: friendly" },
 { "text": "Below is an instruction ... ### Instruction: ..." }

有很多很多方法可以提取原始数据、处理和创建训练数据集作为json文件。下面是一个简单的脚本:

 with open('train.jsonl', 'a') as outfile:
     for example in raw_data:
         text = '<process_example>'
         # now append entry to the jsonl file.
         outfile.write('{"text": "' + text + '"}')
         outfile.write('\n')

如HF的Datasets库也是一个选择,但是我个人觉得他不好用。

在我们开始训练之前,我们要将文件作为数据集存储库推送到HF。可以直接使用huggingface-cli上传数据集。

训练

Parameter-Efficient Fine-Tuning(PEFT)可以用于在不触及LLM的所有参数的情况下对LLM进行有效的微调。PEFT支持QLoRa方法,通过4位量化对LLM参数的一小部分进行微调。

Transformer Reinforcement Learning (TRL)是一个使用强化学习来训练语言模型的库。TRL也提供的监督微调(SFT)训练器API可以让我们快速的微调模型。

 !pip install -q huggingface_hub
 !pip install -q -U trl transformers accelerate peft
 !pip install -q -U datasets bitsandbytes einops wandb

 # Uncomment to install new features that support latest models like Llama 2
 # !pip install git+https://github.com/huggingface/peft.git
 # !pip install git+https://github.com/huggingface/transformers.git

 # When prompted, paste the HF access token you created earlier.
 from huggingface_hub import notebook_login
 notebook_login()

 from datasets import load_dataset
 import torch
 from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, TrainingArguments
 from peft import LoraConfig
 from trl import SFTTrainer

 dataset_name = "<your_hf_dataset>"
 dataset = load_dataset(dataset_name, split="train")

 base_model_name = "meta-llama/Llama-2-7b-hf"

 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
     bnb_4bit_compute_dtype=torch.float16,
 )

 device_map = {"": 0}

 base_model = AutoModelForCausalLM.from_pretrained(
     base_model_name,
     quantization_config=bnb_config,
     device_map=device_map,
     trust_remote_code=True,
     use_auth_token=True
 )
 base_model.config.use_cache = False

 # More info: https://github.com/huggingface/transformers/pull/24906
 base_model.config.pretraining_tp = 1 

 peft_config = LoraConfig(
     lora_alpha=16,
     lora_dropout=0.1,
     r=64,
     bias="none",
     task_type="CAUSAL_LM",
 )

 tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
 tokenizer.pad_token = tokenizer.eos_token

 output_dir = "./results"

 training_args = TrainingArguments(
     output_dir=output_dir,
     per_device_train_batch_size=4,
     gradient_accumulation_steps=4,
     learning_rate=2e-4,
     logging_steps=10,
     max_steps=500
 )

 max_seq_length = 512

 trainer = SFTTrainer(
     model=base_model,
     train_dataset=dataset,
     peft_config=peft_config,
     dataset_text_field="text",
     max_seq_length=max_seq_length,
     tokenizer=tokenizer,
     args=training_args,
 )

 trainer.train()

 import os
 output_dir = os.path.join(output_dir, "final_checkpoint")
 trainer.model.save_pretrained(output_dir)

上面的脚本就是一个微调的简单代码,这里可以添加命令行参数解析器模块,如HfArgumentParser,这样就不必硬编码这些值

测试

下面时一个简单的加载模型并进行完整性测试的快速方法。

 from peft import AutoPeftModelForCausalLM

 model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map=device_map, torch_dtype=torch.bfloat16)
 text = "..."
 inputs = tokenizer(text, return_tensors="pt").to(device)
 outputs = model.generate(input_ids=inputs["input_ids"].to("cuda"), attention_mask=inputs["attention_mask"], max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)

 print(tokenizer.decode(outputs[0], skip_special_tokens=True))

这样就能够查看我们的结果了。

本文作者:UD

原文地址:

https://avoid.overfit.cn/post/e2b178db4f9344c2a659925689c1f049

目录
相关文章
|
1月前
|
监控 计算机视觉 知识图谱
YOLOv10的改进、部署和微调训练总结
YOLOv10在实时目标检测中提升性能与效率,通过无NMS训练解决延迟问题,采用一致的双任务和效率-精度驱动的模型设计。YOLOv10-S比RT-DETR-R18快1.8倍,YOLOv10-B比YOLOv9-C延迟减少46%。新方法包括一致性双标签分配,优化计算冗余和增强模型能力。实验结果显示YOLOv10在AP和延迟上均有显著改善。文章还提供了部署和微调YOLOv10的示例代码。
200 2
|
2月前
|
机器学习/深度学习 算法 测试技术
使用ORPO微调Llama 3
ORPO是一种结合监督微调和偏好对齐的新型微调技术,旨在减少训练大型语言模型所需资源和时间。通过在一个综合训练过程中结合这两种方法,ORPO优化了语言模型的目标,强化了对首选响应的奖励,弱化对不期望回答的惩罚。实验证明ORPO在不同模型和基准上优于其他对齐方法。本文使用Llama 3 8b模型测试ORPO,结果显示即使只微调1000条数据一个epoch,性能也有所提升,证实了ORPO的有效性。完整代码和更多细节可在相关链接中找到。
248 10
|
2月前
|
数据采集 机器学习/深度学习 存储
使用LORA微调RoBERTa
模型微调是指在一个已经训练好的模型的基础上,针对特定任务或者特定数据集进行再次训练以提高性能的过程。微调可以在使其适应特定任务时产生显着的结果。
152 0
|
9月前
|
人工智能 搜索推荐 算法
曼曼心理咨询【基于ChatGLM-6B微调】
曼曼心理咨询【基于ChatGLM-6B微调】
430 0
|
12月前
|
人工智能 搜索推荐 物联网
如何训练个人的Gpt4ALL
如何训练个人的Gpt4ALL
2894 0
如何训练个人的Gpt4ALL
|
机器学习/深度学习 JSON 物联网
ChatGLM-6B 部署与 P-Tuning 微调实战
自从 ChatGPT 爆火以来,树先生一直琢磨想打造一个垂直领域的 LLM 专属模型,但学习文本大模型的技术原理,从头打造一个 LLM 模型难度极大。。。
2816 1
|
26天前
|
缓存 自然语言处理 分布式计算
LLM 推理的极限速度
【6月更文挑战第9天】自然语言处理中的大型语言模型面临着推理速度挑战。为了实现快速推理,优化涉及硬件(如使用高性能GPU)、软件(模型架构设计和算法优化)、数据预处理等方面。代码示例展示了Python中LLM推理时间的计算。其他加速方法包括模型量化、缓存机制和分布式计算。通过多方位优化,可提升LLM的性能,以满足实时应用需求。未来技术发展有望带来更大突破。
95 5
|
2月前
|
人工智能 物联网 调度
Llama 3 训练推理,上阿里云!
Llama 3 训练推理,上阿里云!
161 1
|
2月前
spinbox微调器
spinbox微调器
28 5
|
2月前
|
自然语言处理 C++
GPT4 vs Llama,大模型训练的坑
训练大模型,总觉得效果哪里不对,查了三天,终于发现了原因
107 0

热门文章

最新文章

  • 1
    流量控制系统,用正则表达式提取汉字
    25
  • 2
    Redis09-----List类型,有序,元素可以重复,插入和删除快,查询速度一般,一般保存一些有顺序的数据,如朋友圈点赞列表,评论列表等,LPUSH user 1 2 3可以一个一个推
    26
  • 3
    Redis08命令-Hash类型,也叫散列,其中value是一个无序字典,类似于java的HashMap结构,Hash结构可以将对象中的每个字段独立存储,可以针对每字段做CRUD
    25
  • 4
    Redis07命令-String类型字符串,不管是哪种格式,底层都是字节数组形式存储的,最大空间不超过512m,SET添加,MSET批量添加,INCRBY age 2可以,MSET,INCRSETEX
    27
  • 5
    S外部函数可以访问函数内部的变量的闭包-闭包最简单的用不了,闭包是内层函数+外层函数的变量,简称为函数套函数,外部函数可以访问函数内部的变量,存在函数套函数
    23
  • 6
    Redis06-Redis常用的命令,模糊的搜索查询往往会对服务器产生很大的压力,MSET k1 v1 k2 v2 k3 v3 添加,DEL是删除的意思,EXISTS age 可以用来查询是否有存在1
    30
  • 7
    Redis05数据结构介绍,数据结构介绍,官方网站中看到
    21
  • 8
    JS字符串数据类型转换,字符串如何转成变量,+号只要有一个是字符串,就会把另外一个转成字符串,- * / 都会把数据转成数字类型,数字型控制台是蓝色,字符型控制台是黑色,
    19
  • 9
    JS数组操作---删除,arr.pop()方法从数组中删除最后一个元素,并返回该元素的值,arr.shift() 删除第一个值,arr.splice()方法,删除指定元素,arr.splice,从第一
    19
  • 10
    定义好变量,${age}模版字符串,对象可以放null,检验数据类型console.log(typeof str)
    19