用 PyTorch 实现 LLM-JEPA:不预测 token,预测嵌入

简介: 本文从零实现LLM-JEPA:将大语言模型与联合嵌入预测架构(JEPA)结合。通过span遮蔽构造context/target双视图,用可训练编码器预测目标编码器在遮蔽位置的归一化嵌入,以余弦距离为对齐损失,并通过EMA稳定训练。代码简洁清晰,逐行注释,助你深入理解JEPA核心思想。

这篇文章从头实现 LLM-JEPA: Large Language Models Meet Joint Embedding Predictive Architectures。需要说明的是,这里写的是一个简洁的最小化训练脚本,目标是了解 JEPA 的本质:对同一文本创建两个视图,预测被遮蔽片段的嵌入,用表示对齐损失来训练。

本文的目标是让你真正理解这套方法。代码会逐行讲解,每个函数的用途都会解释清楚,并和论文的核心直觉对应起来。每个代码块都会详细说明,方便你根据自己的实验需求进行修改。

代码

整个 LLM-JEPA 训练脚本放在一个文件里:

它接收原始文本然后创建两个视图:context 视图把某些片段替换成 [MASK],target 视图保留原始文本但只在被遮蔽位置做监督。Context 编码器是可训练的,负责预测 target 编码器在遮蔽位置的表示。Target 编码器则是 context 编码器的 EMA 副本,不参与梯度计算。损失函数用的是预测嵌入和目标嵌入之间的余弦距离。

运行示例:

 # 小型冒烟测试(无需下载,随机初始化)
python llm_jepa_train.py --smoke_test

# 使用 HF 模型骨干训练
python llm_jepa_train.py --model_name distilbert-base-uncased --steps 200 --batch_size 8

# 在自己的文本文件上训练
 python llm_jepa_train.py --model_name distilbert-base-uncased --text_file data.txt --steps 2000

这是一个简洁的参考实现,不是完整的仓库代码。编码器用的是 Transformers 库。

 import argparse  
import math  
import os  
import random  
from dataclasses import dataclass  
from typing import List, Tuple, Optional  

import torch  
import torch.nn as nn  
import torch.nn.functional as F  
from torch.utils.data import Dataset, DataLoader  

try:  
    from transformers import AutoTokenizer, AutoModel, AutoConfig  
except Exception:  
    AutoTokenizer = None  
    AutoModel = None  
    AutoConfig = None

# -----------------------------  
# Utilities  
# -----------------------------  
def set_seed(seed: int):  
    random.seed(seed)  
    torch.manual_seed(seed)  
    torch.cuda.manual_seed_all(seed)

def pick_device(device_str: str) -> torch.device:  
    if device_str == "auto":  
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")  
    return torch.device(device_str)

# -----------------------------  
# Span masking (simple + effective)  
# -----------------------------  
def sample_span_mask(  
    seq_len: int,  
    mask_ratio: float,  
    mean_span_len: int,  
    special_positions: Optional[set] = None,  
) -> torch.BoolTensor:  
    """  
    Returns a boolean mask of length seq_len indicating which positions are masked.  
    We mask contiguous spans until we reach approximately mask_ratio of tokens.  
    """  
    if special_positions is None:  
        special_positions = set()  

    mask = torch.zeros(seq_len, dtype=torch.bool)  
    if seq_len <= 0:  
        return mask  

    target_to_mask = max(1, int(round(seq_len * mask_ratio)))  
    masked = 0  

    attempts = 0  
    max_attempts = seq_len * 4  

    while masked < target_to_mask and attempts < max_attempts:  
        attempts += 1  

        span_len = max(1, int(random.expovariate(1.0 / max(1, mean_span_len))))  
        span_len = min(span_len, seq_len)  

        start = random.randint(0, seq_len - 1)  
        end = min(seq_len, start + span_len)  

        span_positions = [i for i in range(start, end) if i not in special_positions]  
        if not span_positions:  
            continue  

        newly = 0  
        for i in span_positions:  
            if not mask[i]:  
                mask[i] = True  
                newly += 1  

        masked += newly  

    return mask

def apply_mask_to_input_ids(  
    input_ids: torch.LongTensor,  
    attention_mask: torch.LongTensor,  
    tokenizer,  
    mask_ratio: float,  
    mean_span_len: int,  
) -> Tuple[torch.LongTensor, torch.BoolTensor]:  
    """  
    Masks spans inside non-special, non-padding tokens.  
    Returns:  
      masked_input_ids: input ids with masked tokens replaced by [MASK]  
      pred_mask: boolean mask over positions where we apply JEPA loss  
    """  
    assert input_ids.dim() == 1  
    seq_len = int(attention_mask.sum().item())  

    # Identify special token positions (CLS, SEP, etc.) in the visible region  
    special_positions = set()  
    for i in range(seq_len):  
        tid = int(input_ids[i].item())  
        if tid in {  
            tokenizer.cls_token_id,  
            tokenizer.sep_token_id,  
            tokenizer.pad_token_id,  
        }:  
            special_positions.add(i)  

    pred_mask = sample_span_mask(  
        seq_len=seq_len,  
        mask_ratio=mask_ratio,  
        mean_span_len=mean_span_len,  
        special_positions=special_positions,  
    )  

    masked_input_ids = input_ids.clone()  
    mask_token_id = tokenizer.mask_token_id  
    if mask_token_id is None:  
        raise ValueError("Tokenizer has no mask_token_id. Use a model with [MASK].")  

    # Replace masked positions with [MASK]  
    masked_input_ids[:seq_len][pred_mask] = mask_token_id  

    # pred_mask should be full length (includes pads as False)  
    full_mask = torch.zeros_like(attention_mask, dtype=torch.bool)  
    full_mask[:seq_len] = pred_mask  

    return masked_input_ids, full_mask

# -----------------------------  
# Dataset  
# -----------------------------  
class TextLinesDataset(Dataset):  
    def __init__(self, texts: List[str]):  
        self.texts = [t.strip() for t in texts if t.strip()]  

    def __len__(self) -> int:  
        return len(self.texts)  

    def __getitem__(self, idx: int) -> str:  
        return self.texts[idx]

def load_texts_from_file(path: str, max_lines: Optional[int] = None) -> List[str]:  
    texts = []  
    with open(path, "r", encoding="utf-8") as f:  
        for i, line in enumerate(f):  
            if max_lines is not None and i >= max_lines:  
                break  
            texts.append(line.rstrip("\n"))  
    return texts

def default_tiny_corpus() -> List[str]:  
    return [  
        "The cat sat on the mat and looked at the window.",  
        "A quick brown fox jumps over the lazy dog.",  
        "Deep learning models can learn useful representations from raw data.",  
        "Rocket Learning builds AI tools for education in India.",  
        "Transformers use attention to mix information across tokens.",  
        "Self-supervised learning can reduce the need for labels.",  
        "JEPA trains models to predict embeddings, not tokens.",  
        "Bengaluru is a major tech hub in India.",  
        "A good system design balances simplicity and scalability.",  
        "Reading code carefully helps you understand how an idea is implemented.",  
    ]

@dataclass  
class Batch:  
    input_ids: torch.LongTensor          # [B, L]  
    attention_mask: torch.LongTensor     # [B, L]  
    masked_input_ids: torch.LongTensor   # [B, L]  
    pred_mask: torch.BoolTensor          # [B, L]  positions to compute loss on

def collate_jepa(  
    batch_texts: List[str],  
    tokenizer,  
    max_length: int,  
    mask_ratio: float,  
    mean_span_len: int,  
) -> Batch:  
    toks = tokenizer(  
        batch_texts,  
        padding=True,  
        truncation=True,  
        max_length=max_length,  
        return_tensors="pt",  
    )  
    input_ids = toks["input_ids"]              # [B, L]  
    attention_mask = toks["attention_mask"]    # [B, L]  

    masked_input_ids_list = []  
    pred_mask_list = []  

    for b in range(input_ids.size(0)):  
        mi, pm = apply_mask_to_input_ids(  
            input_ids[b],  
            attention_mask[b],  
            tokenizer,  
            mask_ratio=mask_ratio,  
            mean_span_len=mean_span_len,  
        )  
        masked_input_ids_list.append(mi)  
        pred_mask_list.append(pm)  

    masked_input_ids = torch.stack(masked_input_ids_list, dim=0)  
    pred_mask = torch.stack(pred_mask_list, dim=0)  

    return Batch(  
        input_ids=input_ids,  
        attention_mask=attention_mask,  
        masked_input_ids=masked_input_ids,  
        pred_mask=pred_mask,  
    )

# -----------------------------  
# Model: Encoder + Predictor + EMA target encoder  
# -----------------------------  
class PredictorMLP(nn.Module):  
    def __init__(self, dim: int, hidden_mult: int = 4, dropout: float = 0.0):  
        super().__init__()  
        hidden = dim * hidden_mult  
        self.net = nn.Sequential(  
            nn.Linear(dim, hidden),  
            nn.GELU(),  
            nn.Dropout(dropout),  
            nn.Linear(hidden, dim),  
        )  

    def forward(self, x: torch.Tensor) -> torch.Tensor:  
        return self.net(x)

class LLMJEPA(nn.Module):  
    def __init__(self, encoder: nn.Module, dim: int, ema_m: float = 0.99, pred_hidden_mult: int = 4):  
        super().__init__()  
        self.context_encoder = encoder  
        self.target_encoder = self._copy_encoder(encoder)  
        self.predictor = PredictorMLP(dim=dim, hidden_mult=pred_hidden_mult, dropout=0.0)  
        self.ema_m = ema_m  

        for p in self.target_encoder.parameters():  
            p.requires_grad = False  

    @staticmethod  
    def _copy_encoder(enc: nn.Module) -> nn.Module:  
        import copy  
        return copy.deepcopy(enc)  

    @torch.no_grad()  
    def ema_update(self):  
        m = self.ema_m  
        for p_ctx, p_tgt in zip(self.context_encoder.parameters(), self.target_encoder.parameters()):  
            p_tgt.data.mul_(m).add_(p_ctx.data, alpha=(1.0 - m))  

    def forward(  
        self,  
        masked_input_ids: torch.LongTensor,  
        input_ids: torch.LongTensor,  
        attention_mask: torch.LongTensor,  
        pred_mask: torch.BoolTensor,  
    ) -> torch.Tensor:  
        """  
        Returns JEPA loss (scalar).  
        We compute:  
          z_ctx = context_encoder(masked_input)  
          z_tgt = target_encoder(full input)  
          pred = predictor(z_ctx)  
          loss over positions in pred_mask  
        """  
        out_ctx = self.context_encoder(input_ids=masked_input_ids, attention_mask=attention_mask)  
        z_ctx = out_ctx.last_hidden_state  # [B, L, D]  

        with torch.no_grad():  
            out_tgt = self.target_encoder(input_ids=input_ids, attention_mask=attention_mask)  
            z_tgt = out_tgt.last_hidden_state  # [B, L, D]  

        pred = self.predictor(z_ctx)  # [B, L, D]  

        # Select masked positions  
        # pred_mask: [B, L] bool  
        masked_pred = pred[pred_mask]  # [N, D]  
        masked_tgt = z_tgt[pred_mask]  # [N, D]  

        if masked_pred.numel() == 0:  
            # Safety: if a batch ends up with no masked tokens, return zero loss  
            return pred.sum() * 0.0  

        masked_pred = F.normalize(masked_pred, dim=-1)  
        masked_tgt = F.normalize(masked_tgt, dim=-1)  

        # Cosine distance  
        loss = 1.0 - (masked_pred * masked_tgt).sum(dim=-1)  
        return loss.mean()

# -----------------------------  
# Training  
# -----------------------------  
def build_hf_encoder(model_name: str):  
    if AutoModel is None:  
        raise RuntimeError("transformers is not installed. pip install transformers")  

    config = AutoConfig.from_pretrained(model_name)  
    encoder = AutoModel.from_pretrained(model_name, config=config)  
    dim = int(config.hidden_size)  
    return encoder, dim

def build_random_encoder(vocab_size: int = 30522, dim: int = 256, layers: int = 4, heads: int = 4):  
    """  
    For smoke tests only: small Transformer encoder (random init).  
    Requires a tokenizer with vocab mapping for ids.  
    """  
    encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)  
    transformer = nn.TransformerEncoder(encoder_layer, num_layers=layers)  

    class TinyEncoder(nn.Module):  
        def __init__(self):  
            super().__init__()  
            self.emb = nn.Embedding(vocab_size, dim)  
            self.pos = nn.Embedding(512, dim)  
            self.enc = transformer  

        def forward(self, input_ids, attention_mask):  
            B, L = input_ids.shape  
            pos_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, L)  
            x = self.emb(input_ids) + self.pos(pos_ids)  

            # attention_mask: 1 for keep, 0 for pad  
            # transformer expects src_key_padding_mask: True for pad  
            pad_mask = attention_mask == 0  
            h = self.enc(x, src_key_padding_mask=pad_mask)  
            return type("Out", (), {"last_hidden_state": h})  

    return TinyEncoder(), dim

def save_checkpoint(path: str, model: LLMJEPA, optimizer: torch.optim.Optimizer, step: int):  
    os.makedirs(os.path.dirname(path), exist_ok=True)  
    torch.save(  
        {  
            "step": step,  
            "context_encoder": model.context_encoder.state_dict(),  
            "target_encoder": model.target_encoder.state_dict(),  
            "predictor": model.predictor.state_dict(),  
            "optimizer": optimizer.state_dict(),  
        },  
        path,  
    )

def main():  
    parser = argparse.ArgumentParser()  
    parser.add_argument("--model_name", type=str, default="distilbert-base-uncased", help="HF encoder backbone")  
    parser.add_argument("--text_file", type=str, default="", help="Path to a newline-separated text file")  
    parser.add_argument("--max_lines", type=int, default=50000)  
    parser.add_argument("--max_length", type=int, default=128)  
    parser.add_argument("--mask_ratio", type=float, default=0.3)  
    parser.add_argument("--mean_span_len", type=int, default=5)  
    parser.add_argument("--ema_m", type=float, default=0.99)  
    parser.add_argument("--pred_hidden_mult", type=int, default=4)  

    parser.add_argument("--batch_size", type=int, default=8)  
    parser.add_argument("--lr", type=float, default=2e-5)  
    parser.add_argument("--weight_decay", type=float, default=0.01)  
    parser.add_argument("--steps", type=int, default=500)  
    parser.add_argument("--warmup_steps", type=int, default=50)  
    parser.add_argument("--log_every", type=int, default=25)  
    parser.add_argument("--save_every", type=int, default=200)  
    parser.add_argument("--save_path", type=str, default="checkpoints/llm_jepa.pt")  

    parser.add_argument("--device", type=str, default="auto")  
    parser.add_argument("--seed", type=int, default=42)  
    parser.add_argument("--smoke_test", action="store_true", help="No downloads, tiny random encoder, tiny corpus")  
    args = parser.parse_args()  

    set_seed(args.seed)  
    device = pick_device(args.device)  

    if args.smoke_test:  
        if AutoTokenizer is None:  
            raise RuntimeError("transformers is required even for smoke_test (for tokenizer).")  
        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  
        # Ensure mask token exists  
        if tokenizer.mask_token_id is None:  
            raise ValueError("Tokenizer must support [MASK]. Use a masked LM tokenizer.")  

        texts = default_tiny_corpus()  
        ds = TextLinesDataset(texts)  

        encoder, dim = build_random_encoder(vocab_size=int(tokenizer.vocab_size), dim=256, layers=4, heads=4)  
        model = LLMJEPA(encoder=encoder, dim=dim, ema_m=0.95, pred_hidden_mult=2).to(device)  

        lr = 1e-4  
    else:  
        if AutoTokenizer is None:  
            raise RuntimeError("transformers is not installed. pip install transformers")  
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)  
        if tokenizer.mask_token_id is None:  
            raise ValueError(  
                "This tokenizer has no [MASK]. Pick a masked-encoder model (BERT/DeBERTa/DistilBERT)."  
            )  

        if args.text_file:  
            texts = load_texts_from_file(args.text_file, max_lines=args.max_lines)  
        else:  
            texts = default_tiny_corpus()  

        ds = TextLinesDataset(texts)  

        encoder, dim = build_hf_encoder(args.model_name)  
        model = LLMJEPA(encoder=encoder, dim=dim, ema_m=args.ema_m, pred_hidden_mult=args.pred_hidden_mult).to(device)  

        lr = args.lr  

    # DataLoader  
    def _collate(batch_texts):  
        return collate_jepa(  
            batch_texts=batch_texts,  
            tokenizer=tokenizer,  
            max_length=args.max_length,  
            mask_ratio=args.mask_ratio,  
            mean_span_len=args.mean_span_len,  
        )  

    dl = DataLoader(ds, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=_collate)  

    # Optimizer  
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=args.weight_decay)  

    # Simple warmup + cosine schedule  
    def lr_at(step: int) -> float:  
        if step < args.warmup_steps:  
            return float(step + 1) / float(max(1, args.warmup_steps))  
        progress = (step - args.warmup_steps) / float(max(1, args.steps - args.warmup_steps))  
        progress = min(max(progress, 0.0), 1.0)  
        return 0.5 * (1.0 + math.cos(math.pi * progress))  

    model.train()  
    running = 0.0  
    step = 0  
    data_iter = iter(dl)  

    while step < args.steps:  
        try:  
            batch = next(data_iter)  
        except StopIteration:  
            data_iter = iter(dl)  
            batch = next(data_iter)  

        # Move to device  
        input_ids = batch.input_ids.to(device)  
        attention_mask = batch.attention_mask.to(device)  
        masked_input_ids = batch.masked_input_ids.to(device)  
        pred_mask = batch.pred_mask.to(device)  

        # LR schedule  
        scale = lr_at(step)  
        for pg in optimizer.param_groups:  
            pg["lr"] = lr * scale  

        loss = model(  
            masked_input_ids=masked_input_ids,  
            input_ids=input_ids,  
            attention_mask=attention_mask,  
            pred_mask=pred_mask,  
        )  

        optimizer.zero_grad(set_to_none=True)  
        loss.backward()  
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  
        optimizer.step()  

        # EMA update after optimizer step  
        model.ema_update()  

        running += float(loss.item())  
        step += 1  

        if step % args.log_every == 0:  
            avg = running / float(args.log_every)  
            running = 0.0  
            print(f"step {step:6d} | loss {avg:.4f} | lr {optimizer.param_groups[0]['lr']:.6g}")  

        if step % args.save_every == 0:  
            save_checkpoint(args.save_path, model, optimizer, step)  
            print(f"saved checkpoint to {args.save_path} at step {step}")  

    save_checkpoint(args.save_path, model, optimizer, step)  
    print(f"training done. final checkpoint: {args.save_path}")

if __name__ == "__main__":  
     main()

这个脚本在训练什么

这是一个面向文本的 JEPA 风格表示预测器。

输入普通文本行,对每个样本创建两个视图。遮蔽视图(context view)是同一个句子,但某些 span 被替换成 `[MASK];原始视图(target view)保持原样,没有遮蔽。

训练流程是这样的:遮蔽视图过一个可训练的 context 编码器,原始视图过一个不可训练的 target 编码器,然后训练一个预测器,让 context 编码器的表示能预测 target 编码器的表示——但只在被遮蔽的位置上计算损失。Target 编码器通过 EMA 更新来保持稳定。

这种设计鼓励模型学习"填补语义"的表示,而不是预测具体的 token。

set_seed 函数

 defset_seed(seed: int):  
     random.seed(seed)  
     torch.manual_seed(seed)  
     torch.cuda.manual_seed_all(seed)

这个函数确保运行可复现。

random.seed(seed)

固定 Python 的随机操作(span 遮蔽会用到),

torch.manual_seed(seed)

固定 PyTorch 在 CPU 上的随机性,

torch.cuda.manual_seed_all(seed)

固定 CUDA 内核的随机性。

span 遮蔽和模型初始化都是随机的,不设种子的话每次跑结果都不一样。

pick_device 函数

 def pick_device(device_str: str) -> torch.device:  
     if device_str == "auto":  
         return torch.device("cuda" if torch.cuda.is_available() else "cpu")  
     return torch.device(device_str)

返回 PyTorch 设备对象。如果传

--device auto

,有 GPU 就用 GPU,没有就用 CPU。也可以直接指定

--device cpu

--device cuda

张量和模型必须在同一设备上,这是基本要求。

sample_span_mask 函数

 def sample_span_mask(seq_len, mask_ratio, mean_span_len, special_positions=None)

整个脚本里最重要的函数之一。

目标是创建一个布尔掩码,标记序列中哪些位置该被遮蔽。参数包括:seq_len 是真实 token 数量(不含 padding),mask_ratio 是遮蔽比例(比如 0.3),mean_span_len 是连续遮蔽 span 的平均长度,special_positions 是永远不该遮蔽的位置(CLS、SEP、PAD)。

内部逻辑是先创建一个全 False 的掩码,然后计算需要遮蔽多少 token:

 target_to_mask=max(1, int(round(seq_len*mask_ratio)))

即使序列很短也至少遮蔽 1 个。

接下来循环采样 span 直到凑够数。Span 长度从指数分布采样:

 span_len=max(1, int(random.expovariate(1.0/max(1, mean_span_len))))

这会产出很多短 span 和少量长 span,比较符合自然分布。随机选一个起始位置,过滤掉特殊 token,把剩下的位置标记为 True。

遮蔽策略对表示学习质量影响很大。Span 遮蔽能迫使模型从周围上下文推断缺失的语义。

apply_mask_to_input_ids 函数

 defapply_mask_to_input_ids(input_ids, attention_mask, tokenizer, mask_ratio, mean_span_len)

拿到一个样本的 token ids,输出两个东西:masked_input_ids 是把遮蔽位置换成 [MASK] 后的 ids,pred_mask 是标记哪些位置要算损失的布尔掩码。

先算可见序列长度:

seq_len = int(attention_mask.sum().item())

。attention_mask 里真实 token 是 1,padding 是 0。

然后识别特殊 token 位置,CLS 和 SEP 不能遮蔽,否则模型容易出问题。调用 sample_span_mask 采样遮蔽位置,把这些位置替换成 mask_token_id:

 masked_input_ids[:seq_len][pred_mask] =mask_token_id

返回的 pred_mask 是完整长度的,padding 位置都是 False。只在遮蔽位置算 JEPA 损失,其他位置忽略。

TextLinesDataset 类

 classTextLinesDataset(Dataset):  
     def__init__(self, texts):  
         self.texts= [t.strip() fortintextsift.strip()]

极简的数据集实现,存文本行列表,去掉空行和首尾空白。

__len__

返回行数,

__getitem__

返回单条文本。

load_texts_from_file 逐行读文件,可限制最大行数,传

--text_file

时用。default_tiny_corpus 提供内置测试数据集。

Batch 数据类

 @dataclass  
 classBatch:  
     input_ids  
     attention_mask  
     masked_input_ids  
     pred_mask

用 dataclass 比返回元组清晰多了,代码可读性好。

collate_jepa 函数

DataLoader 创建批次时调用的函数。输入是原始文本列表,先用 tokenizer 做分词、padding、截断:

 toks=tokenizer(batch_texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")

产出 input_ids 和 attention_mask。然后对每个样本调 apply_mask_to_input_ids 生成遮蔽版本和 pred_mask,最后堆叠成 [B, L] 张量返回 Batch。

DataLoader 是逐样本读的,但训练需要批次。批处理和遮蔽都在这里发生。

PredictorMLP 类

预测器头,结构简单:

 nn.Linear(dim, hidden)  
 nn.GELU()  
 nn.Dropout()  
 nn.Linear(hidden, dim)

把 context 表示映射到 target 表示空间,相当于一个学习出来的适配器,帮助对齐两边的嵌入。

LLMJEPA 模型类

主模型包装器,包含四个核心部件:context_encoder 是可训练的 Transformer 编码器,target_encoder 是它的深拷贝但不可训练,predictor 是 MLP,ema_m 是 EMA 动量因子。

_copy_encoder 用

copy.deepcopy

确保 target 和 context 初始状态一致。

ema_update 缓慢更新 target 编码器权重:

 p_tgt=m*p_tgt+ (1-m) *p_ctx

m=0.99 时 target 变化非常慢,这能稳定训练、降低表示坍塌风险。

forward 的流程:把遮蔽视图过 context 编码器(可训练),原始视图过 target 编码器(无梯度),predictor 处理 context 输出,然后只取遮蔽位置的向量:

 masked_pred=pred[pred_mask]  # [N, D]  
 masked_tgt=z_tgt[pred_mask]  # [N, D]

从 [B, L, D] 变成 [N, D],N 是遮蔽 token 总数。归一化后算余弦距离:

 loss=1- (masked_pred*masked_tgt).sum(dim=-1)  
 returnloss.mean()

归一化是因为余弦相似度只看向量方向,不看大小。

build_hf_encoder 函数

加载 Hugging Face 编码器,返回模型和隐藏维度(从 config.hidden_size 读)。

build_random_encoder 函数

冒烟测试专用,从头建一个小 Transformer 编码器,包括嵌入层、位置嵌入、编码器堆栈。注意这不是掩码语言模型,只是个编码器架构。返回对象带

.last_hidden_state

属性是为了匹配 HF 输出格式。

总结

这个实现刻意追求清晰而非完整,所以没有自定义注意力掩码、多视图数据集或混合目标。但是把它当参考实现用是非常合适的。原始 LLM-JEPA 论文做得更深入,把 JEPA 和 token 预测结合起来,还利用了文本-代码这样的自然配对视图。那些设计对下游任务表现很重要,但也增加了复杂度,容易让人看不清核心机制。

论文:

https://avoid.overfit.cn/post/09eb991a93f64a83a376cdb52ac5c661

作者:azhar

目录
相关文章
|
3月前
|
负载均衡 数据中心 异构计算
大模型如何训练百万 Token 上下文:上下文并行与 Ring Attention
上下文窗口暴增至千万级,但硬件难承其重:405B模型单精度权重就需6.5TB内存。为突破显存瓶颈,上下文并行与Ring Attention应运而生——将长序列切分至多卡,边传边算;Zig-Zag分配更实现因果注意力下的负载均衡。高速互连(NVLink/InfiniBand)已成刚需。
276 4
大模型如何训练百万 Token 上下文:上下文并行与 Ring Attention
|
4月前
|
机器学习/深度学习 人工智能 PyTorch
深度解析 Google JAX 全栈:带你上手开发,从零构建神经网络
Google凭借JAX AI栈实现AI全栈垂直整合,覆盖模型、应用、云与硬件。JAX结合XLA编译器,Flax构建网络,Optax优化训练,Orbax管理 checkpoint,已在Google及Anthropic、Apple等广泛应用,助力高效大规模AI训练。
592 6
|
2月前
|
机器学习/深度学习 存储 人工智能
让 AI 智能体学会自我进化:Agent Lightning 实战入门
Agent Lightning 是一个框架无关的强化学习包装层,赋能现有AI智能体实现在线持续学习。它解耦执行与训练,支持LangChain/AutoGen等任意框架,通过VERL算法解决稀疏奖励难题,让智能体从运行反馈中自动优化提示词与策略。
390 5
让 AI 智能体学会自我进化:Agent Lightning 实战入门
|
机器学习/深度学习 人工智能 数据管理
文生图的基石CLIP模型的发展综述
CLIP(Contrastive Language-Image Pre-training)是OpenAI在2021年发布的多模态模型,用于学习文本-图像对的匹配。模型由文本和图像编码器组成,通过对比学习使匹配的输入对在向量空间中靠近,非匹配对远离。预训练后,CLIP被广泛应用于各种任务,如零样本分类和语义搜索。后续研究包括ALIGN、K-LITE、OpenCLIP、MetaCLIP和DFN,它们分别在数据规模、知识增强、性能缩放和数据过滤等方面进行了改进和扩展,促进了多模态AI的发展。
2893 0
|
2月前
|
人工智能 NoSQL Redis
LangGraph 入门:用图结构构建你的第一个多智能体工作流
LangGraph 是面向多智能体系统的图编排框架,以有向状态图替代线性链式调用。通过节点(智能体)、边(条件/静态跳转)和类型化共享状态三者解耦,天然支持分支、循环、并行与汇合;内置检查点、原子状态更新与Reducer机制,保障一致性、可调试性与容错恢复能力。
2670 1
|
3月前
|
存储 人工智能 Java
用 AgentScope Java 开家 AI 奶茶店
开一家 AI 奶茶店,让 AgentScope Java 替你打理一切。
1233 43
|
2月前
|
人工智能 测试技术
LLM创造力可以被度量吗?一个基于提示词变更的探索性实验
本文探讨提示词工程为何仍是“玄学”,并通过实验证明:加入明确指令(如“Be as creative as possible”)可显著、可量化地提升LLM输出多样性,效果甚至超过调高温度。研究以embedding距离为代理指标,覆盖13个主流模型,揭示提示词迭代可度量、可预测,为LLM应用从经验走向工程化提供新路径。
185 17
LLM创造力可以被度量吗?一个基于提示词变更的探索性实验
|
2月前
|
存储 监控 搜索推荐
长上下文"记忆"的舒适陷阱:为什么更多记忆不等于更可靠
本文警示长上下文的隐性风险:它虽提升交互顺手度,却严重损害可靠性、可测试性与可重复性;共享账户导致意图混杂,“我是谁”故障频发;向量平均无法调和对立目标;上下文膨胀引发注意力稀释、幻觉加剧与约束遗忘。生产中须以预算制、会话隔离、结构化记忆和可控重置进行主动治理。
186 2
长上下文"记忆"的舒适陷阱:为什么更多记忆不等于更可靠
|
3月前
|
SQL 人工智能 安全
手把手教你调出“懂你”的AI:大模型微调实战与资源管理
本文深入浅出讲解大模型微调核心知识:用生活化比喻解析学习率、训练轮数、批量大小、截断长度和LoRA秩五大关键参数;提供适配不同显存的实操配置表;分享Liger Kernel、DeepSpeed等省显存技巧;并强调定量、定性与效率三维评估。零基础也能快速上手定制专属AI。
419 11
手把手教你调出“懂你”的AI:大模型微调实战与资源管理