Google开源Tunix:JAX生态的LLM微调方案来了

本文涉及的产品
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时数仓Hologres,5000CU*H 100GB 3个月
实时计算 Flink 版,1000CU*H 3个月
简介: Tunix是Google推出的基于JAX的LLM后训练库,支持微调、强化学习与知识蒸馏,集成Flax NNX,主打TPU优化与模块化设计,支持QLoRA等高效训练方法,适用于高性能分布式训练场景。

JAX生态这两年在LLM训练这块追赶得挺快。PyTorch虽然还是主流但JAX在并行计算、TPU加速和API组合性上确实有些独特的优势。Google今天放出了Tunix这个库,专门做LLM的后训练——微调、强化学习、知识蒸馏这些都能搞。

Tunix是什么

这是个构建在JAX之上的后训练库,和Flax NNX集成得比较紧密。主要解决三类问题:

  • 监督微调(Supervised Fine-Tuning)
  • 强化学习(Reinforcement Learning)
  • 知识蒸馏(Knowledge Distillation)

现在还在早期开发阶段,功能在持续迭代,支持的模型也在慢慢扩展。

核心功能

监督微调:既支持全参数微调,也支持LoRA和Q-LoRA这类参数高效的方法。内存和算力受限的时候,PEFT方案还是挺实用的。

强化学习:实现了几个主流算法:PPO(Proximal Policy Optimization)、GRPO(Group Relative Policy Optimization)、还有token级别的GSPO。另外还有DPO(Direct Preference Optimization)做偏好对齐,这个在RLHF场景用得比较多。

知识蒸馏:支持几种策略,包括基于logit的概率分布匹配、注意力机制的转移和投影、跨架构的特征池化与投影。这几种方法在不同场景下各有用处。

库的设计比较模块化,组件可以自由组合,想扩展自定义流程也不算麻烦。分布式训练支持数据并行(DP)、完全分片数据并行(FSDP)和张量并行(TP),对TPU做了专门优化。

安装

三种装法:

从PyPI装(推荐):

 pip install "tunix[prod]"

或者直接从GitHub主分支:

 pip install git+https://github.com/google/tunix

开发模式从源码装:

 git clone https://github.com/google/tunix.git  
 cd tunix  
 pip install -e".[dev]"

TPU上用QLoRA微调Gemma

拿个英译法的任务来演示。用的是Google的Gemma 2B模型,跑在TPU v5e-8上。

环境准备

 pip install -q kagglehub safetensors tensorflow tensorflow_datasets tensorboardX transformers grain datasets  
 pip install -q git+https://github.com/google/tunix  
 pip install -q git+https://github.com/google/qwix  

 # Flax需要升级到最新版
 pip uninstall -q -y flax  
 pip install -q git+https://github.com/google/flax.git

完整流程

第一步,从Kaggle拉预训练checkpoint:

 import kagglehub  

 model_path = "google/gemma/flax/2b"  
 kaggle_ckpt_path = kagglehub.model_download(model_path)

初始化模型和tokenizer:

 from flax import nnx  
from tunix.models.gemma import model as gemma_lib, params as params_lib  
from tunix.generate import tokenizer_adapter as tokenizer_lib  

base_model = gemma_lib.Transformer.from_params(  
    params_lib.load_and_format_params(kaggle_ckpt_path, "2b"),  
    version="2b"  
)  
 tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=f"{kaggle_ckpt_path}/tokenizer.model")

挂上QLoRA adapter:

 import qwix  

lora_provider = qwix.LoraProvider(  
    module_path=".*(q_einsum|kv_einsum|proj)",  
    rank=16,  
    alpha=2.0,  
    weight_qtype="nf4"  # enable QLoRA quantization
)  
 lora_model = qwix.apply_lora_to_model(base_model, lora_provider)

这里rank设成16,alpha是2.0,weight_qtype指定nf4量化格式。

加载训练数据:

 from tunix.examples.data import translation_dataset  

train_ds, validation_ds = translation_dataset.create_datasets(  
    dataset_name="mtnt/en-fr",  
    global_batch_size=16,  
    max_target_length=256,  
    num_train_epochs=3,  
    tokenizer=tokenizer,  
 )

用的是mtnt的英法平行语料,batch size 16,目标序列最长256个token。

开始训练:

 from tunix.sft import peft_trainer, utils  
import optax  

trainer=peft_trainer.PeftTrainer(  
    lora_model,  
    optimizer=optax.adamw(1e-3),  
    config=peft_trainer.TrainingConfig(max_steps=100)  
)  
 trainer.train(train_ds, validation_ds)

优化器用AdamW,学习率1e-3,跑100步看看效果。

推理测试:

训练完直接用adapter过的模型做生成。Tunix提供了Sampler工具:

 from tunix.generate import sampler as sampler_lib  

# initialize sampler
sampler = sampler_lib.Sampler(  
    transformer=lora_model,  
    tokenizer=tokenizer,  
    cache_config=sampler_lib.CacheConfig(  
        cache_size=256,  
        num_layers=base_model.num_layers,  
        num_kv_heads=base_model.num_kv_heads,  
        head_dim=base_model.head_dim,  
    ),  
)  

# test prompts
input_batch = [  
    "Translate this into French:\nHello, my name is Morgane.\n",  
    "Translate this into French:\nThis dish is delicious!\n",  
    "Translate this into French:\nI am a student.\n",  
    "Translate this into French:\nHow's the weather today?\n",  
]  

# generate predictions
out_data = sampler(  
    input_strings=input_batch,  
    max_generation_steps=20,  
)  

# print results
for input_string, out_string in zip(input_batch, out_data.text):  
    print(f"----------------------")  
    print(f"Prompt:\n{input_string}")  
     print(f"Output:\n{out_string}")

如果用的是QLoRA,把lora_model换成qlora_model就行。生产环境可以考虑把adapter合并回基模型,推理延迟能降下来。

总结

100步训练之后,模型已经能生成一些翻译结果了,虽然质量还不够好。多训练一段时间,准确率会明显提升,而且内存开销和训练速度都保持在不错的水平。

Tunix现在还比较新,但已经能看出一些潜力。TPU优先的设计、模块化的API、LoRA/QLoRA支持、完整的分布式训练策略,这些对做LLM适配研究的人来说都挺有用。

后续应该会继续扩展支持的模型类型和训练算法,值得关注。

地址:https://avoid.overfit.cn/post/c434311d8a894922b6c52ea179cf8d97

作者:Abish Pius

目录
相关文章
|
1月前
|
数据采集 机器学习/深度学习 自然语言处理
98_数据增强:提升LLM微调效果的关键技术
在大语言模型(LLM)的微调过程中,数据质量与数量往往是决定最终性能的关键因素。然而,获取高质量、多样化且标注准确的训练数据却常常面临诸多挑战:数据标注成本高昂、领域特定数据稀缺、数据分布不均等问题都会直接影响微调效果。在这种背景下,数据增强技术作为一种能够有效扩充训练数据并提升其多样性的方法,正发挥着越来越重要的作用。
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
37_开源LLM:LLaMA与Mistral的突破_深度解析
在人工智能领域,2025年已经成为开源大语言模型的黄金时代。从Meta的LLaMA系列到欧洲初创公司Mistral AI的创新突破,开源LLM正在重塑整个AI生态系统的格局。截至2025年4月,Meta的LLaMA系列已成为全球下载量最高、社区使用最活跃的开源大语言模型之一,并被集成于数百个学术项目、创业平台和AI产品之中
|
7月前
|
人工智能 自然语言处理 测试技术
能够双向推理的LLM!Dream-7B:港大联合华为开源的扩散推理模型,能够同时考虑前后文信息
Dream-7B是由香港大学与华为诺亚方舟实验室联合研发的开源扩散大语言模型,采用独特的掩码扩散范式,在文本生成、数学推理和代码编写等任务中展现出卓越性能。
325 3
能够双向推理的LLM!Dream-7B:港大联合华为开源的扩散推理模型,能够同时考虑前后文信息
|
3月前
|
数据可视化 物联网 开发者
深度解析四大LLM微调工具:从单卡到千亿级训练的四大解决方案
本文详解大语言模型微调四大工具——Unsloth、Axolotl、LlamaFactory、DeepSpeed,覆盖从单卡实验到万亿参数分布式训练场景,助你掌握主流框架选型策略,提升微调效率。建议点赞收藏。
1052 1
|
4月前
|
缓存 异构计算 Docker
构建高性能LLM推理服务的完整方案:单GPU处理172个查询/秒、10万并发仅需15美元/小时
本文将通过系统性实验不同的优化技术来构建自定义LLaMA模型服务,目标是高效处理约102,000个并行查询请求,并通过对比分析确定最优解决方案。
344 0
构建高性能LLM推理服务的完整方案:单GPU处理172个查询/秒、10万并发仅需15美元/小时
|
3月前
|
人工智能 JSON 前端开发
告别无效调参!ReAct代理设计:让LLM精准执行复杂任务的终极方案
ReAct模式通过“推理+行动”循环,使大语言模型能自主调用工具、获取实时信息并执行多步骤任务,有效突破LLM固有局限,提升任务准确性和智能化水平。
527 0
|
3月前
|
机器学习/深度学习 人工智能 自然语言处理
MCP、LLM与Agent:企业AI实施的新基建设计方案
MCP+LLM+Agent架构通过"大脑-神经网络-手脚"的协同机制,实现从数据贯通到自主执行的智能闭环。本文将深度解析该架构如何将产线排查效率提升5倍、让LLM专业术语识别准确率提升26%,并提供从技术选型到分层落地的实战指南,助力企业打造真正融入业务流的"数字员工"。通过协议标准化、动态规划与自愈执行的三重突破,推动AI从演示场景迈向核心业务深水区。
|
7月前
|
机器学习/深度学习 人工智能 算法
RAGEN:RL训练LLM推理新范式!开源强化学习框架让Agent学会多轮决策
RAGEN是一个基于StarPO框架的开源强化学习系统,通过马尔可夫决策过程形式化Agent与环境的交互,支持PPO、GRPO等多种优化算法,显著提升多轮推理训练的稳定性。
736 5
RAGEN:RL训练LLM推理新范式!开源强化学习框架让Agent学会多轮决策
|
数据可视化 定位技术 Sentinel
如何用Google Earth Engine快速、大量下载遥感影像数据?
【2月更文挑战第9天】本文介绍在谷歌地球引擎(Google Earth Engine,GEE)中,批量下载指定时间范围、空间范围的遥感影像数据(包括Landsat、Sentinel等)的方法~
4823 1
如何用Google Earth Engine快速、大量下载遥感影像数据?

推荐镜像

更多