一、导读
今日,智谱AI发布了最新的代码模型CodeGeeX2-6B(https://mp.weixin.qq.com/s/qw31ThM4AjG6RrjNwsfZwg),并已在魔搭社区开源。
CodeGeeX2是多语言代码生成模型CodeGeeX的第二代模型,基于ChatGLM2架构并加入代码预训练实现,具有多种特性,如更强大的代码能力、更优秀的模型特性、更全面的AI编程助手和更开放的协议等。
本文提供了CodeGeeX2的微调教程,希望更多开发者基于开源模型和数据集微调CodeGeeX2,共同建设AI生态。期待通过这次开源,让CodeGeeX2成为每一位程序员的编程助手。
魔搭开源链接:https://modelscope.cn/models/ZhipuAI/codegeex2-6b/summary
二、环境配置与安装
本文使用ModelScope的Notebook免费环境测试,python>=3.8
升级ModelScope环境:
ModelScope需要升级到github上最新的master版本(预计8月1号发布版本),进入Notebook的Terminal环境:
更新ModelScope版本:
# Install ModelScope from the latest master branch (run in the Notebook Terminal)
git clone https://github.com/modelscope/modelscope.git
cd modelscope
pip install .
三、模型链接及下载
CodeGeeX2-6B
模型链接:https://modelscope.cn/models/ZhipuAI/codegeex2-6b/summary
使用notebook进行模型weights下载(飞一样的速度,可以达到百兆每秒):
from modelscope.hub.snapshot_download import snapshot_download

# Download the CodeGeeX2-6B weights (pinned to revision v1.0.0) from ModelScope;
# model_dir is the local directory the snapshot was saved to.
model_dir = snapshot_download('ZhipuAI/codegeex2-6b', revision='v1.0.0')
四、模型推理
CodeGeeX2-6B推理代码如下。注意:在ModelScope新版本正式发布之前,需要在Notebook的Terminal中(例如启动python交互环境)运行以下代码:
import torch
from modelscope import AutoModel, AutoTokenizer

model_id = 'ZhipuAI/codegeex2-6b'

# trust_remote_code=True is required because CodeGeeX2 ships its own
# (ChatGLM2-based) modeling/tokenizer code with the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_id,
    device_map={'': 'cuda:0'},  # or device_map='auto'
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model = model.eval()  # inference mode: disables dropout etc.

# remember adding a language tag for better performance
prompt = "# language: python\n# write a bubble sort function\n"
inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

# max_length counts the prompt tokens too; use max_new_tokens instead if you
# want to bound only the generated continuation.
outputs = model.generate(inputs, max_length=256)
response = tokenizer.decode(outputs[0])
print(response)
推理显存占用:约13GB
五、效果体验
体验了一下使用python解决八皇后问题,效果还是不错的!
>>> prompt = "# language: python\n# solve eight queen problem\n" >>> inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device) >>> outputs = model.generate(inputs, max_length=512) >>> response = tokenizer.decode(outputs[0]) >>> print(response) # language: python # solve eight queen problem def conflict(state, nextX): nextY = len(state) for i in range(nextY): if abs(state[i] - nextX) in (0, nextY - i): return True return False def queens(num=8, state=()): for pos in range(num): if not conflict(state, pos): if len(state) == num - 1: yield (pos,) else: for result in queens(num, state + (pos,)): yield (pos,) + result if __name__ == "__main__": print(list(queens(8)))
使用C++实现快速排序:
>>> prompt = "// language: C++\n// write a quick sort function\n" >>> inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device) >>> outputs = model.generate(inputs, max_length=512) >>> response = tokenizer.decode(outputs[0]) >>> print(response) // language: C++ // write a quick sort function #include <iostream> #include <vector> using namespace std; void quickSort(vector<int> &arr, int start, int end) { if (start >= end) { return; } int pivot = arr[start]; int left = start; int right = end; while (left < right) { while (left < right && arr[right] >= pivot) { right--; } while (left < right && arr[left] <= pivot) { left++; } if (left < right) { swap(arr[left], arr[right]); } } swap(arr[left], arr[start]); quickSort(arr, start, left - 1); quickSort(arr, left + 1, end); } int main() { vector<int> arr = {5, 3, 4, 1, 2, 8, 7, 9, 6, 0}; quickSort(arr, 0, arr.size() - 1); for (int i = 0; i < arr.size(); i++) { cout << arr[i] << " "; } return 0; }