Part 3: Model Architecture
3.1 Transformer Fundamentals
Almost all modern code-intelligence models are built on the Transformer architecture, so understanding the Transformer is central to understanding AI-assisted code development.
# Core Transformer components
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    """
    Multi-head attention.
    Lets the model attend to different representation subspaces
    at different positions simultaneously.
    """
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # Linear projections for queries, keys, and values
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Scaled dot-product attention:
        Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, V)
        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # 1. Project and split into heads: [batch, n_heads, seq_len, d_k]
        Q = self.W_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # 2. Compute attention
        attn_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)
        # 3. Concatenate the heads back to [batch, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # 4. Final output projection
        output = self.W_o(attn_output)
        return output, attention_weights

class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding.
    Adds position information to the sequence, since attention by itself
    is order-invariant.
    """
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Precompute the positional-encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class TransformerBlock(nn.Module):
    """
    Transformer encoder block:
    multi-head attention and a feed-forward network,
    each wrapped with a residual connection and layer normalization.
    """
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),  # GELU is the usual choice over ReLU in Transformers
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        # Multi-head attention + residual connection
        attn_output, _ = self.attention(x, x, x, mask)
        x = self.norm1(x + attn_output)
        # Feed-forward network + residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + ff_output)
        return x
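A minimal sanity check of the pieces above, with illustrative sizes and a causal (lower-triangular) mask. Note that the mask is applied to attention scores of shape [batch, n_heads, seq_len, seq_len], so it must broadcast against that shape:

# Sanity check for the components above (sizes are illustrative)
seq_len, d_model = 16, 512
x = torch.randn(2, seq_len, d_model)
pos_enc = PositionalEncoding(d_model)
block = TransformerBlock(d_model, n_heads=8, d_ff=2048)
# Causal mask broadcast over batch and heads: [1, 1, seq_len, seq_len]
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
out = block(pos_enc(x), mask=causal_mask)
print(out.shape)  # torch.Size([2, 16, 512])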
3.2 A Code-Specific Transformer Architecture
Code differs from natural language: it has explicit syntactic structure (the AST) and strict grammar rules, so the architecture benefits from code-specific design choices such as the AST node-type embeddings sketched below.
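As an example of such structure, per-node AST type IDs for a Python snippet can be derived with the standard ast module. The sketch below only maps node-type names to integer IDs; aligning those IDs with the model's tokens (as the ast_types input of the model below expects) takes extra bookkeeping and is omitted. The ast_node_types helper and type_vocab mapping are illustrative, not part of the original design.

# Sketch: deriving AST node-type IDs with Python's ast module (illustrative helper)
import ast

def ast_node_types(code, type_vocab):
    """Return node-type IDs for a snippet in ast.walk traversal order."""
    type_ids = []
    for node in ast.walk(ast.parse(code)):
        name = type(node).__name__  # e.g. 'FunctionDef', 'Name', 'Call'
        type_ids.append(type_vocab.setdefault(name, len(type_vocab)))
    return type_ids

type_vocab = {}
print(ast_node_types("def add(a, b):\n    return a + b", type_vocab))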
# A code-specific Transformer
class CodeTransformer(nn.Module):
    """
    Code-specific Transformer model with three code-oriented additions:
    1. Structure awareness: AST node-type embeddings
    2. Positional encoding (standard sequence positions here; line/column
       positions are a possible extension)
    3. A syntax-constraint module that steers generation toward
       syntactically valid code
    """
    def __init__(self, vocab_size, d_model=768, n_heads=12, n_layers=12, d_ff=3072, max_len=2048):
        super().__init__()
        self.d_model = d_model
        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Code-structure embedding (100 AST node types assumed)
        self.ast_type_embedding = nn.Embedding(100, d_model)
        # Positional encoding
        self.positional_encoding = PositionalEncoding(d_model, max_len)
        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff)
            for _ in range(n_layers)
        ])
        # Output head
        self.ln_f = nn.LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size)
        # Syntax-constraint module
        self.syntax_constraint = SyntaxConstraintModule()

    def forward(self, input_ids, ast_types=None, attention_mask=None):
        """
        Forward pass.
        Args:
            input_ids: token ID sequence [batch, seq_len]
            ast_types: AST node-type IDs [batch, seq_len]
            attention_mask: padding mask [batch, seq_len] (1 = keep, 0 = mask)
        """
        # 1. Token embedding, scaled as in the original Transformer
        x = self.token_embedding(input_ids) * math.sqrt(self.d_model)
        # 2. Add AST-type embeddings
        if ast_types is not None:
            ast_emb = self.ast_type_embedding(ast_types)
            x = x + ast_emb
        # 3. Add positional encoding
        x = self.positional_encoding(x)
        # 4. Reshape the padding mask so it broadcasts against attention
        #    scores of shape [batch, n_heads, seq_len, seq_len]
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :]
        # 5. Run the Transformer layers
        for layer in self.layers:
            x = layer(x, attention_mask)
        # 6. Output head
        x = self.ln_f(x)
        logits = self.output_projection(x)
        # 7. Apply the syntax constraint
        logits = self.syntax_constraint(logits, input_ids)
        return logits

class SyntaxConstraintModule(nn.Module):
    """
    Syntax-constraint module.
    Steers the generated token sequence toward the language's grammar rules,
    e.g. brackets must match and strings must be closed.
    """
    def __init__(self, tokenizer=None):
        super().__init__()
        # Bracket pairing rules
        self.bracket_pairs = {
            '(': ')',
            '[': ']',
            '{': '}'
        }
        # Assumed injected dependency: an object exposing id_to_token /
        # token_to_id lookups for the model's vocabulary
        self.tokenizer = tokenizer

    def forward(self, logits, input_ids):
        """
        Adjust the token probability distribution given the current context,
        forbidding tokens that would violate the grammar rules.
        """
        # Current bracket-stack state
        bracket_stack = self.get_bracket_stack(input_ids)
        # Tokens that must not be generated next
        banned_tokens = self.get_banned_tokens(bracket_stack)
        # Set the banned tokens' logits to negative infinity
        for token_id in banned_tokens:
            logits[:, -1, token_id] = -float('inf')
        return logits

    def get_bracket_stack(self, input_ids):
        """Parse the current bracket-stack state."""
        stack = []
        for token_id in input_ids[0]:  # simplified: a full version handles the whole batch
            token = self.id_to_token(token_id)
            if token in self.bracket_pairs:
                stack.append(token)
            elif token in self.bracket_pairs.values():
                if stack and self.bracket_pairs[stack[-1]] == token:
                    stack.pop()
        return stack

    def get_banned_tokens(self, bracket_stack):
        """Return the tokens that are forbidden given the bracket stack."""
        banned = []
        if bracket_stack:
            # Only the closing bracket matching the top of the stack is
            # allowed; ban the other closing brackets
            for opening, closing in self.bracket_pairs.items():
                if opening != bracket_stack[-1]:
                    banned.append(self.token_to_id(closing))
        return banned

    def id_to_token(self, token_id):
        """Delegate to the injected tokenizer."""
        return self.tokenizer.id_to_token(int(token_id))

    def token_to_id(self, token):
        """Delegate to the injected tokenizer."""
        return self.tokenizer.token_to_id(token)
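A rough end-to-end shape check of the model plus the syntax constraint. ToyTokenizer and the tiny vocabulary are hypothetical stand-ins for a real code tokenizer, and the small model sizes are illustrative:

# End-to-end shape check (toy vocabulary; ToyTokenizer is a hypothetical stand-in)
class ToyTokenizer:
    def __init__(self, tokens):
        self.tokens = tokens
    def id_to_token(self, i):
        return self.tokens[int(i)]
    def token_to_id(self, t):
        return self.tokens.index(t)

toy_vocab = ['def', 'f', '(', ')', ':', 'pass', '[', ']', '{', '}']
model = CodeTransformer(vocab_size=len(toy_vocab), d_model=64, n_heads=4, n_layers=2, d_ff=256)
model.syntax_constraint = SyntaxConstraintModule(ToyTokenizer(toy_vocab))
input_ids = torch.tensor([[0, 1, 2]])  # "def f ("
logits = model(input_ids)
print(logits.shape)  # torch.Size([1, 3, 10]); the mismatched closers ']' and '}' are banned at the last step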
3.3 Code-Specific Pre-training Tasks
# Code pre-training tasks
import numpy as np


class CodePretrainingTasks:
    """
    Code-specific pre-training tasks.
    These go beyond generic masked language modeling and target
    properties that matter for code understanding.
    """
    @staticmethod
    def masked_language_modeling(code, vocab, mask_ratio=0.15):
        """
        Masked language modeling.
        Randomly mask a fraction of tokens and train the model to recover them.
        `vocab` is the list of tokens that random replacements are drawn from.
        """
        tokens = code.split()
        masked_tokens = []
        labels = []
        for token in tokens:
            if np.random.random() < mask_ratio:
                # 80%: replace with [MASK]
                # 10%: replace with a random token
                # 10%: keep the original token
                r = np.random.random()
                if r < 0.8:
                    masked_tokens.append('[MASK]')
                elif r < 0.9:
                    masked_tokens.append(np.random.choice(vocab))
                else:
                    masked_tokens.append(token)
                labels.append(token)
            else:
                masked_tokens.append(token)
                labels.append(-1)  # -1 marks positions that are not predicted
        return masked_tokens, labels

    @staticmethod
    def next_token_prediction(code_prefix, next_token):
        """
        Next-token prediction: the standard language-modeling objective.
        """
        return code_prefix, next_token

    @staticmethod
    def identifier_prediction(code):
        """
        Identifier prediction.
        Mask variable and function names and train the model to recover them,
        which helps it learn naming conventions and semantics.
        """
        import ast

        class IdentifierMasker(ast.NodeTransformer):
            def visit_Name(self, node):
                # Mask identifiers with 80% probability; keep the load/store
                # context and use a placeholder name so the result still parses
                if isinstance(node.ctx, (ast.Store, ast.Load)) and np.random.random() < 0.8:
                    return ast.Name(id='_MASK_', ctx=node.ctx)
                return node

        tree = ast.parse(code)
        masked_tree = IdentifierMasker().visit(tree)
        masked_code = ast.unparse(masked_tree)
        return masked_code, code

    @staticmethod
    def edge_prediction(code):
        """
        Edge prediction.
        Predict parent-child edges in the AST, which helps the model
        learn the structure of code.
        """
        import ast
        tree = ast.parse(code)
        # Collect AST nodes and parent-child edges
        nodes = []
        edges = []

        def collect_nodes(node, parent_id=None):
            node_id = len(nodes)
            nodes.append(node)
            if parent_id is not None:
                edges.append((parent_id, node_id))
            for child in ast.iter_child_nodes(node):
                collect_nodes(child, node_id)

        collect_nodes(tree)
        # Randomly mask some edges for the model to predict
        masked_edges = []
        labels = []
        for edge in edges:
            if np.random.random() < 0.15:
                masked_edges.append((edge[0], -1))  # -1 marks a masked child
                labels.append(edge[1])
            else:
                masked_edges.append(edge)
                labels.append(-1)
        return masked_edges, labels

    @staticmethod
    def contrastive_learning(positive_pairs, negative_pairs):
        """
        Contrastive learning.
        Pull embeddings of semantically similar code fragments together and
        push dissimilar fragments apart, typically with an InfoNCE loss.
        Left as a stub here; a sketch of the loss follows this class.
        """
        pass
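The contrastive objective above is usually implemented with an InfoNCE loss. The following is a minimal sketch, assuming each anchor's positive sits at the same row index in a second batch of embeddings; the info_nce function and the temperature value are illustrative, not from the original text:

# Minimal InfoNCE sketch for code contrastive learning (illustrative)
def info_nce(anchor_emb, positive_emb, temperature=0.07):
    """
    anchor_emb, positive_emb: [batch, dim]; row i of positive_emb is the
    positive for row i of anchor_emb, and the other rows act as negatives.
    """
    anchor = F.normalize(anchor_emb, dim=-1)
    positive = F.normalize(positive_emb, dim=-1)
    # Cosine-similarity matrix [batch, batch]
    logits = anchor @ positive.t() / temperature
    # The correct "class" for row i is column i
    targets = torch.arange(anchor.size(0), device=anchor.device)
    return F.cross_entropy(logits, targets)

loss = info_nce(torch.randn(8, 256), torch.randn(8, 256))
print(loss.item())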
Part 4: Training Strategy
4.1 Training Pipeline
# Code-model trainer
class CodeModelTrainer:
    """
    Trainer for code models, handling the specifics of code data.
    """
    def __init__(self, model, config):
        self.model = model
        self.config = config
        # Optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
        # Learning-rate schedule: cosine annealing with warm restarts
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=config.warmup_steps, T_mult=2
        )
        # Loss function; positions labeled -1 are ignored
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1)

    def train_epoch(self, dataloader):
        self.model.train()
        total_loss = 0
        for batch in dataloader:
            input_ids = batch['input_ids'].to(self.config.device)
            attention_mask = batch['attention_mask'].to(self.config.device)
            labels = batch['labels'].to(self.config.device)
            ast_types = batch.get('ast_types', None)
            if ast_types is not None:
                ast_types = ast_types.to(self.config.device)
            # Forward pass
            logits = self.model(input_ids, ast_types, attention_mask)
            # Loss
            loss = self.criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            # Gradient clipping (guards against exploding gradients)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            self.optimizer.step()
            self.scheduler.step()
            total_loss += loss.item()
        return total_loss / len(dataloader)

    def evaluate(self, dataloader):
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(self.config.device)
                attention_mask = batch['attention_mask'].to(self.config.device)
                labels = batch['labels'].to(self.config.device)
                logits = self.model(input_ids, None, attention_mask)
                loss = self.criterion(logits.view(-1, logits.size(-1)), labels.view(-1))
                total_loss += loss.item()
        return total_loss / len(dataloader)
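Wiring the trainer together might look roughly like this. The SimpleNamespace config and its field values are illustrative placeholders, and the dataloaders are assumed to yield dicts with input_ids, attention_mask, and labels:

# Hypothetical trainer wiring (config values and dataloaders are placeholders)
from types import SimpleNamespace

config = SimpleNamespace(
    learning_rate=3e-4,
    weight_decay=0.01,
    warmup_steps=1000,
    max_grad_norm=1.0,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
model = CodeTransformer(vocab_size=32000).to(config.device)
trainer = CodeModelTrainer(model, config)
# train_loss = trainer.train_epoch(train_dataloader)
# val_loss = trainer.evaluate(val_dataloader)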
4.2 Curriculum Learning
class CurriculumLearningScheduler:
    """
    Curriculum-learning scheduler.
    Trains the model on progressively more complex code, from simple to hard.
    """
    def __init__(self):
        self.curriculum_stages = [
            {
                'name': 'basic_syntax',
                'description': 'basic syntax',
                'max_complexity': 5,
                'min_line_count': 1,
                'max_line_count': 10
            },
            {
                'name': 'control_flow',
                'description': 'control flow',
                'max_complexity': 10,
                'min_line_count': 5,
                'max_line_count': 30
            },
            {
                'name': 'functions',
                'description': 'function definitions',
                'max_complexity': 15,
                'min_line_count': 10,
                'max_line_count': 50
            },
            {
                'name': 'classes',
                'description': 'classes and objects',
                'max_complexity': 20,
                'min_line_count': 20,
                'max_line_count': 100
            },
            {
                'name': 'advanced',
                'description': 'advanced features',
                'max_complexity': 100,
                'min_line_count': 30,
                'max_line_count': 500
            }
        ]
        self.current_stage = 0

    def get_next_batch(self, dataset):
        """Filter the dataset down to samples suitable for the current stage."""
        stage = self.curriculum_stages[self.current_stage]
        filtered_data = []
        for item in dataset:
            if (item['complexity'] <= stage['max_complexity'] and
                    stage['min_line_count'] <= item['line_count'] <= stage['max_line_count']):
                filtered_data.append(item)
        # If the current stage has too little data, move on to the next stage
        if len(filtered_data) < 1000 and self.current_stage < len(self.curriculum_stages) - 1:
            self.current_stage += 1
            return self.get_next_batch(dataset)
        return filtered_data
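A toy illustration of the scheduler; the records below are hypothetical, and because the toy dataset is far below the 1000-sample threshold, the scheduler advances straight to the final stage:

# Toy illustration of the curriculum scheduler (records are hypothetical)
scheduler = CurriculumLearningScheduler()
toy_dataset = [
    {'code': 'x = 1', 'complexity': 1, 'line_count': 1},
    {'code': 'for i in range(10):\n    print(i)', 'complexity': 3, 'line_count': 2},
]
pool = scheduler.get_next_batch(toy_dataset)
print(scheduler.current_stage, len(pool))  # 4 0 for this tiny dataset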