在Colab上测试Mamba

简介: 我们在前面的文章介绍了研究人员推出了一种挑战Transformer的新架构Mamba

他们的研究表明,Mamba是一种状态空间模型(SSM),在不同的模式(如语言、音频和时间序列)中表现出卓越的性能。为了说明这一点,研究人员使用Mamba-3B模型进行了语言建模实验。该模型超越了基于相同大小的Transformer的其他模型,并且在预训练和下游评估期间,它的表现与大小为其两倍的Transformer模型一样好。

Mamba的独特之处在于它的快速处理能力,选择性SSM层,以及受FlashAttention启发的硬件友好设计。这些特点使Mamba超越Transformer(Transformer没有了传统的注意力和MLP块)。

有很多人希望自己测试Mamba的效果,所以本文整理了一个能够在Colab上完整运行Mamba代码,代码中还使用了Mamba官方的3B模型来进行实际运行测试。

首先我们安装依赖,这是官网介绍的:

 !pip install causal-conv1d==1.0.0
 !pip install mamba-ssm==1.0.1

然后直接使用transformers库读取预训练的Mamba-3B

 import torch
 import os
 from transformers import AutoTokenizer
 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
 tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
 model = MambaLMHeadModel.from_pretrained(os.path.expanduser("state-spaces/mamba-2.8b"), device="cuda", dtype=torch.bfloat16)

可以看到,3b的模型有11G

然后就是测试生成内容

 tokens = tokenizer("What is the meaning of life", return_tensors="pt")
 input_ids = tokens.input_ids.to(device="cuda")
 max_length = input_ids.shape[1] + 80
 fn = lambda: model.generate(
         input_ids=input_ids, max_length=max_length, cg=True,
         return_dict_in_generate=True, output_scores=True,
         enable_timing=False, temperature=0.1, top_k=10, top_p=0.1,)
 out = fn()
 print(tokenizer.decode(out[0][0]))

这里还有一个chat的示例

 import torch
 from transformers import AutoTokenizer
 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

 device = "cuda"
 tokenizer = AutoTokenizer.from_pretrained("havenhq/mamba-chat")
 tokenizer.eos_token = "<|endoftext|>"
 tokenizer.pad_token = tokenizer.eos_token
 tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template

 model = MambaLMHeadModel.from_pretrained("havenhq/mamba-chat", device="cuda", dtype=torch.float16)


 messages = []
 user_message = """
 What is the date for announcement
 On August 10 said that its arm JSW Neo Energy has agreed to buy a portfolio of 1753 mega watt renewable energy generation capacity from Mytrah Energy India Pvt Ltd for Rs 10,530 crore.
  """

 messages.append(dict(role="user",content=user_message))
 input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
 out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)
 decoded = tokenizer.batch_decode(out)
 messages.append(dict(role="assistant",content=decoded[0].split("<|assistant|>\n")[-1]))
 print("Model:", decoded[0].split("<|assistant|>\n")[-1])

这里我将所有代码整理成了Colab Notebook,有兴趣的可以直接使用:

https://avoid.overfit.cn/post/ed2d2cc2460d4e0683a270e2761e10ea

目录
相关文章
|
机器学习/深度学习 数据采集 自然语言处理
机器学习模型的部署与上线:从训练到实际应用
在机器学习中,模型训练只是整个过程的一部分。将训练好的模型部署到实际应用中,并使其稳定运行,也是非常重要的。本文将介绍机器学习模型的部署与上线过程,包括数据处理、模型选择、部署环境搭建、模型调优等方面。同时,我们也会介绍一些实际应用场景,并分享一些经验和技巧。
|
缓存 NoSQL 安全
Redis 最佳实践 [后端必看]
Redis 最佳实践 [后端必看]
315 0
|
数据库
【latex】在Overleaf的IEEE会议模板中,快速插入参考文献
【latex】在Overleaf的IEEE会议模板中,快速插入参考文献
4023 1
|
11月前
|
人工智能 自然语言处理 算法
科研论文翻译神器!BabelDOC:开源AI工具让PDF论文秒变双语对照,公式图表全保留
BabelDOC 是一款专为科学论文设计的开源AI翻译工具,采用先进的无损解析技术和智能布局识别算法,能完美保留原文格式并生成双语对照翻译。
2756 67
科研论文翻译神器!BabelDOC:开源AI工具让PDF论文秒变双语对照,公式图表全保留
|
数据采集 数据可视化 数据挖掘
销售漏斗分析怎么做?提高成交率的秘密在这里
销售分析是企业提升业绩、优化策略的重要手段。通过系统化数据分析,企业能精准了解市场需求、优化流程并提高转化率。然而,许多企业在实际操作中面临数据分散、分析滞后等问题。本文从核心步骤出发,探讨如何高效开展销售分析,助力企业实现可视化管理和高效协作。具体包括明确分析目标、收集整合数据、分类清洗、深入分析及结果解读,最终将洞察转化为策略优化。借助如板栗看板等工具,可大幅提升分析效率,使企业在数据驱动下做出更精准的决策,从而提高销售业绩和市场份额。
554 23
|
机器学习/深度学习 人工智能 自然语言处理
深度学习与计算机视觉的结合:技术趋势与应用
深度学习与计算机视觉的结合:技术趋势与应用
801 9
|
存储 固态存储 安全
租用阿里云企业级云服务器最新收费标准与活动价格参考
租用阿里云企业级云服务器多少钱?阿里云服务器有多种实例分类,其中通用型、计算型、内存型、通用算力型、大数据型、本地SSD、高主频型和增强型均属于企业级云服务器,目前在阿里云的活动中,通用型、计算型、内存型和通用算力型均有优惠,下面是阿里云企业级云服务器价格表,包含最新收费标准与活动价格,以表格形式展示给大家,以供参考和了解。
租用阿里云企业级云服务器最新收费标准与活动价格参考
|
机器学习/深度学习 编解码 自然语言处理
【18】Vision Transformer:笔记总结与pytorch实现
【18】Vision Transformer:笔记总结与pytorch实现
1349 0
【18】Vision Transformer:笔记总结与pytorch实现
|
机器学习/深度学习 人工智能 自然语言处理
RT-DETR原理与简介(干翻YOLO的最新目标检测项目)
RT-DETR原理与简介(干翻YOLO的最新目标检测项目)
|
开发工具 数据库 git
向量检索服务体验评测
通过一个实用的例子带你全方位了解向量检索服务DashVector
121192 4