在Colab上测试Mamba

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时计算 Flink 版,5000CU*H 3个月
简介: 我们在前面的文章介绍了研究人员推出了一种挑战Transformer的新架构Mamba

他们的研究表明,Mamba是一种状态空间模型(SSM),在不同的模式(如语言、音频和时间序列)中表现出卓越的性能。为了说明这一点,研究人员使用Mamba-3B模型进行了语言建模实验。该模型超越了基于相同大小的Transformer的其他模型,并且在预训练和下游评估期间,它的表现与大小为其两倍的Transformer模型一样好。

Mamba的独特之处在于它的快速处理能力,选择性SSM层,以及受FlashAttention启发的硬件友好设计。这些特点使Mamba超越Transformer(Transformer没有了传统的注意力和MLP块)。

有很多人希望自己测试Mamba的效果,所以本文整理了一个能够在Colab上完整运行Mamba代码,代码中还使用了Mamba官方的3B模型来进行实际运行测试。

首先我们安装依赖,这是官网介绍的:

 !pip install causal-conv1d==1.0.0
 !pip install mamba-ssm==1.0.1
AI 代码解读

然后直接使用transformers库读取预训练的Mamba-3B

 import torch
 import os
 from transformers import AutoTokenizer
 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
 tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
 model = MambaLMHeadModel.from_pretrained(os.path.expanduser("state-spaces/mamba-2.8b"), device="cuda", dtype=torch.bfloat16)
AI 代码解读

可以看到,3b的模型有11G

然后就是测试生成内容

 tokens = tokenizer("What is the meaning of life", return_tensors="pt")
 input_ids = tokens.input_ids.to(device="cuda")
 max_length = input_ids.shape[1] + 80
 fn = lambda: model.generate(
         input_ids=input_ids, max_length=max_length, cg=True,
         return_dict_in_generate=True, output_scores=True,
         enable_timing=False, temperature=0.1, top_k=10, top_p=0.1,)
 out = fn()
 print(tokenizer.decode(out[0][0]))
AI 代码解读

这里还有一个chat的示例

 import torch
 from transformers import AutoTokenizer
 from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

 device = "cuda"
 tokenizer = AutoTokenizer.from_pretrained("havenhq/mamba-chat")
 tokenizer.eos_token = "<|endoftext|>"
 tokenizer.pad_token = tokenizer.eos_token
 tokenizer.chat_template = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").chat_template

 model = MambaLMHeadModel.from_pretrained("havenhq/mamba-chat", device="cuda", dtype=torch.float16)


 messages = []
 user_message = """
 What is the date for announcement
 On August 10 said that its arm JSW Neo Energy has agreed to buy a portfolio of 1753 mega watt renewable energy generation capacity from Mytrah Energy India Pvt Ltd for Rs 10,530 crore.
  """

 messages.append(dict(role="user",content=user_message))
 input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to("cuda")
 out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id)
 decoded = tokenizer.batch_decode(out)
 messages.append(dict(role="assistant",content=decoded[0].split("<|assistant|>\n")[-1]))
 print("Model:", decoded[0].split("<|assistant|>\n")[-1])
AI 代码解读

这里我将所有代码整理成了Colab Notebook,有兴趣的可以直接使用:

https://avoid.overfit.cn/post/ed2d2cc2460d4e0683a270e2761e10ea

目录
打赏
0
1
2
1
538
分享
相关文章
利用谷歌colab跑github代码AttnGAN详细步骤 深度学习实验(colab+pytorch+jupyter+github+AttnGAN)
Google Colab,全名Colaboratory,是由谷歌提供的免费的云平台,可以使用pytorch、keras、tensorflow等框架进行深度学习。其GPU为Tesla T4 GPU,有很强的算力,对于刚入门机器学习或深度学习的用户,这个平台是不二之选。
利用谷歌colab跑github代码AttnGAN详细步骤 深度学习实验(colab+pytorch+jupyter+github+AttnGAN)
极智AI | ubuntu编译Darknet与YOLO训练
大家好,我是极智视界,本文介绍了在 ubuntu 上编译 darknet 及 yolo 训练的方法。
158 0
YOLOv3物体/目标检测之实战篇(Windows系统、Python3、TensorFlow2版本)
 基于YOLO进行物体检测、对象识别,在搭建好开发环境后,先和大家进行实践应用中,体验YOLOv3物体/目标检测效果和魅力;同时逐步了解YOLOv3的不足和优化思路。
303 0
我用 PyTorch 复现了 LeNet-5 神经网络(自定义数据集篇)!
详细介绍了卷积神经网络 LeNet-5 的理论部分和使用 PyTorch 复现 LeNet-5 网络来解决 MNIST 数据集和 CIFAR10 数据集。然而大多数实际应用中,我们需要自己构建数据集,进行识别。因此,本文将讲解一下如何使用 LeNet-5 训练自己的数据。
320 0