用 PyTorch 实现 LLM-JEPA:不预测 token,预测嵌入

简介: 本文从零实现LLM-JEPA:将大语言模型与联合嵌入预测架构(JEPA)结合。通过span遮蔽构造context/target双视图,用可训练编码器预测目标编码器在遮蔽位置的归一化嵌入,以余弦距离为对齐损失,并通过EMA稳定训练。代码简洁清晰,逐行注释,助你深入理解JEPA核心思想。

这篇文章从头实现 LLM-JEPA: Large Language Models Meet Joint Embedding Predictive Architectures。需要说明的是,这里写的是一个简洁的最小化训练脚本,目标是了解 JEPA 的本质:对同一文本创建两个视图,预测被遮蔽片段的嵌入,用表示对齐损失来训练。

本文的目标是让你真正理解这套方法。代码会逐行讲解,每个函数的用途都会解释清楚,并和论文的核心直觉对应起来。每个代码块都会详细说明,方便你根据自己的实验需求进行修改。

代码

整个 LLM-JEPA 训练脚本放在一个文件里:

它接收原始文本然后创建两个视图:context 视图把某些片段替换成 [MASK],target 视图保留原始文本但只在被遮蔽位置做监督。Context 编码器是可训练的,负责预测 target 编码器在遮蔽位置的表示。Target 编码器则是 context 编码器的 EMA 副本,不参与梯度计算。损失函数用的是预测嵌入和目标嵌入之间的余弦距离。

运行示例:

 # 小型冒烟测试(无需下载,随机初始化)
python llm_jepa_train.py --smoke_test

# 使用 HF 模型骨干训练
python llm_jepa_train.py --model_name distilbert-base-uncased --steps 200 --batch_size 8

# 在自己的文本文件上训练
 python llm_jepa_train.py --model_name distilbert-base-uncased --text_file data.txt --steps 2000

这是一个简洁的参考实现,不是完整的仓库代码。编码器用的是 Transformers 库。

 import argparse  
import math  
import os  
import random  
from dataclasses import dataclass  
from typing import List, Tuple, Optional  

import torch  
import torch.nn as nn  
import torch.nn.functional as F  
from torch.utils.data import Dataset, DataLoader  

try:  
    from transformers import AutoTokenizer, AutoModel, AutoConfig  
except Exception:  
    AutoTokenizer = None  
    AutoModel = None  
    AutoConfig = None

# -----------------------------  
# Utilities  
# -----------------------------  
def set_seed(seed: int):  
    random.seed(seed)  
    torch.manual_seed(seed)  
    torch.cuda.manual_seed_all(seed)

def pick_device(device_str: str) -> torch.device:  
    if device_str == "auto":  
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")  
    return torch.device(device_str)

# -----------------------------  
# Span masking (simple + effective)  
# -----------------------------  
def sample_span_mask(  
    seq_len: int,  
    mask_ratio: float,  
    mean_span_len: int,  
    special_positions: Optional[set] = None,  
) -> torch.BoolTensor:  
    """  
    Returns a boolean mask of length seq_len indicating which positions are masked.  
    We mask contiguous spans until we reach approximately mask_ratio of tokens.  
    """  
    if special_positions is None:  
        special_positions = set()  

    mask = torch.zeros(seq_len, dtype=torch.bool)  
    if seq_len <= 0:  
        return mask  

    target_to_mask = max(1, int(round(seq_len * mask_ratio)))  
    masked = 0  

    attempts = 0  
    max_attempts = seq_len * 4  

    while masked < target_to_mask and attempts < max_attempts:  
        attempts += 1  

        span_len = max(1, int(random.expovariate(1.0 / max(1, mean_span_len))))  
        span_len = min(span_len, seq_len)  

        start = random.randint(0, seq_len - 1)  
        end = min(seq_len, start + span_len)  

        span_positions = [i for i in range(start, end) if i not in special_positions]  
        if not span_positions:  
            continue  

        newly = 0  
        for i in span_positions:  
            if not mask[i]:  
                mask[i] = True  
                newly += 1  

        masked += newly  

    return mask

def apply_mask_to_input_ids(  
    input_ids: torch.LongTensor,  
    attention_mask: torch.LongTensor,  
    tokenizer,  
    mask_ratio: float,  
    mean_span_len: int,  
) -> Tuple[torch.LongTensor, torch.BoolTensor]:  
    """  
    Masks spans inside non-special, non-padding tokens.  
    Returns:  
      masked_input_ids: input ids with masked tokens replaced by [MASK]  
      pred_mask: boolean mask over positions where we apply JEPA loss  
    """  
    assert input_ids.dim() == 1  
    seq_len = int(attention_mask.sum().item())  

    # Identify special token positions (CLS, SEP, etc.) in the visible region  
    special_positions = set()  
    for i in range(seq_len):  
        tid = int(input_ids[i].item())  
        if tid in {  
            tokenizer.cls_token_id,  
            tokenizer.sep_token_id,  
            tokenizer.pad_token_id,  
        }:  
            special_positions.add(i)  

    pred_mask = sample_span_mask(  
        seq_len=seq_len,  
        mask_ratio=mask_ratio,  
        mean_span_len=mean_span_len,  
        special_positions=special_positions,  
    )  

    masked_input_ids = input_ids.clone()  
    mask_token_id = tokenizer.mask_token_id  
    if mask_token_id is None:  
        raise ValueError("Tokenizer has no mask_token_id. Use a model with [MASK].")  

    # Replace masked positions with [MASK]  
    masked_input_ids[:seq_len][pred_mask] = mask_token_id  

    # pred_mask should be full length (includes pads as False)  
    full_mask = torch.zeros_like(attention_mask, dtype=torch.bool)  
    full_mask[:seq_len] = pred_mask  

    return masked_input_ids, full_mask

# -----------------------------  
# Dataset  
# -----------------------------  
class TextLinesDataset(Dataset):  
    def __init__(self, texts: List[str]):  
        self.texts = [t.strip() for t in texts if t.strip()]  

    def __len__(self) -> int:  
        return len(self.texts)  

    def __getitem__(self, idx: int) -> str:  
        return self.texts[idx]

def load_texts_from_file(path: str, max_lines: Optional[int] = None) -> List[str]:  
    texts = []  
    with open(path, "r", encoding="utf-8") as f:  
        for i, line in enumerate(f):  
            if max_lines is not None and i >= max_lines:  
                break  
            texts.append(line.rstrip("\n"))  
    return texts

def default_tiny_corpus() -> List[str]:  
    return [  
        "The cat sat on the mat and looked at the window.",  
        "A quick brown fox jumps over the lazy dog.",  
        "Deep learning models can learn useful representations from raw data.",  
        "Rocket Learning builds AI tools for education in India.",  
        "Transformers use attention to mix information across tokens.",  
        "Self-supervised learning can reduce the need for labels.",  
        "JEPA trains models to predict embeddings, not tokens.",  
        "Bengaluru is a major tech hub in India.",  
        "A good system design balances simplicity and scalability.",  
        "Reading code carefully helps you understand how an idea is implemented.",  
    ]

@dataclass  
class Batch:  
    input_ids: torch.LongTensor          # [B, L]  
    attention_mask: torch.LongTensor     # [B, L]  
    masked_input_ids: torch.LongTensor   # [B, L]  
    pred_mask: torch.BoolTensor          # [B, L]  positions to compute loss on

def collate_jepa(  
    batch_texts: List[str],  
    tokenizer,  
    max_length: int,  
    mask_ratio: float,  
    mean_span_len: int,  
) -> Batch:  
    toks = tokenizer(  
        batch_texts,  
        padding=True,  
        truncation=True,  
        max_length=max_length,  
        return_tensors="pt",  
    )  
    input_ids = toks["input_ids"]              # [B, L]  
    attention_mask = toks["attention_mask"]    # [B, L]  

    masked_input_ids_list = []  
    pred_mask_list = []  

    for b in range(input_ids.size(0)):  
        mi, pm = apply_mask_to_input_ids(  
            input_ids[b],  
            attention_mask[b],  
            tokenizer,  
            mask_ratio=mask_ratio,  
            mean_span_len=mean_span_len,  
        )  
        masked_input_ids_list.append(mi)  
        pred_mask_list.append(pm)  

    masked_input_ids = torch.stack(masked_input_ids_list, dim=0)  
    pred_mask = torch.stack(pred_mask_list, dim=0)  

    return Batch(  
        input_ids=input_ids,  
        attention_mask=attention_mask,  
        masked_input_ids=masked_input_ids,  
        pred_mask=pred_mask,  
    )

# -----------------------------  
# Model: Encoder + Predictor + EMA target encoder  
# -----------------------------  
class PredictorMLP(nn.Module):  
    def __init__(self, dim: int, hidden_mult: int = 4, dropout: float = 0.0):  
        super().__init__()  
        hidden = dim * hidden_mult  
        self.net = nn.Sequential(  
            nn.Linear(dim, hidden),  
            nn.GELU(),  
            nn.Dropout(dropout),  
            nn.Linear(hidden, dim),  
        )  

    def forward(self, x: torch.Tensor) -> torch.Tensor:  
        return self.net(x)

class LLMJEPA(nn.Module):  
    def __init__(self, encoder: nn.Module, dim: int, ema_m: float = 0.99, pred_hidden_mult: int = 4):  
        super().__init__()  
        self.context_encoder = encoder  
        self.target_encoder = self._copy_encoder(encoder)  
        self.predictor = PredictorMLP(dim=dim, hidden_mult=pred_hidden_mult, dropout=0.0)  
        self.ema_m = ema_m  

        for p in self.target_encoder.parameters():  
            p.requires_grad = False  

    @staticmethod  
    def _copy_encoder(enc: nn.Module) -> nn.Module:  
        import copy  
        return copy.deepcopy(enc)  

    @torch.no_grad()  
    def ema_update(self):  
        m = self.ema_m  
        for p_ctx, p_tgt in zip(self.context_encoder.parameters(), self.target_encoder.parameters()):  
            p_tgt.data.mul_(m).add_(p_ctx.data, alpha=(1.0 - m))  

    def forward(  
        self,  
        masked_input_ids: torch.LongTensor,  
        input_ids: torch.LongTensor,  
        attention_mask: torch.LongTensor,  
        pred_mask: torch.BoolTensor,  
    ) -> torch.Tensor:  
        """  
        Returns JEPA loss (scalar).  
        We compute:  
          z_ctx = context_encoder(masked_input)  
          z_tgt = target_encoder(full input)  
          pred = predictor(z_ctx)  
          loss over positions in pred_mask  
        """  
        out_ctx = self.context_encoder(input_ids=masked_input_ids, attention_mask=attention_mask)  
        z_ctx = out_ctx.last_hidden_state  # [B, L, D]  

        with torch.no_grad():  
            out_tgt = self.target_encoder(input_ids=input_ids, attention_mask=attention_mask)  
            z_tgt = out_tgt.last_hidden_state  # [B, L, D]  

        pred = self.predictor(z_ctx)  # [B, L, D]  

        # Select masked positions  
        # pred_mask: [B, L] bool  
        masked_pred = pred[pred_mask]  # [N, D]  
        masked_tgt = z_tgt[pred_mask]  # [N, D]  

        if masked_pred.numel() == 0:  
            # Safety: if a batch ends up with no masked tokens, return zero loss  
            return pred.sum() * 0.0  

        masked_pred = F.normalize(masked_pred, dim=-1)  
        masked_tgt = F.normalize(masked_tgt, dim=-1)  

        # Cosine distance  
        loss = 1.0 - (masked_pred * masked_tgt).sum(dim=-1)  
        return loss.mean()

# -----------------------------  
# Training  
# -----------------------------  
def build_hf_encoder(model_name: str):  
    if AutoModel is None:  
        raise RuntimeError("transformers is not installed. pip install transformers")  

    config = AutoConfig.from_pretrained(model_name)  
    encoder = AutoModel.from_pretrained(model_name, config=config)  
    dim = int(config.hidden_size)  
    return encoder, dim

def build_random_encoder(vocab_size: int = 30522, dim: int = 256, layers: int = 4, heads: int = 4):  
    """  
    For smoke tests only: small Transformer encoder (random init).  
    Requires a tokenizer with vocab mapping for ids.  
    """  
    encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)  
    transformer = nn.TransformerEncoder(encoder_layer, num_layers=layers)  

    class TinyEncoder(nn.Module):  
        def __init__(self):  
            super().__init__()  
            self.emb = nn.Embedding(vocab_size, dim)  
            self.pos = nn.Embedding(512, dim)  
            self.enc = transformer  

        def forward(self, input_ids, attention_mask):  
            B, L = input_ids.shape  
            pos_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, L)  
            x = self.emb(input_ids) + self.pos(pos_ids)  

            # attention_mask: 1 for keep, 0 for pad  
            # transformer expects src_key_padding_mask: True for pad  
            pad_mask = attention_mask == 0  
            h = self.enc(x, src_key_padding_mask=pad_mask)  
            return type("Out", (), {"last_hidden_state": h})  

    return TinyEncoder(), dim

def save_checkpoint(path: str, model: LLMJEPA, optimizer: torch.optim.Optimizer, step: int):  
    os.makedirs(os.path.dirname(path), exist_ok=True)  
    torch.save(  
        {  
            "step": step,  
            "context_encoder": model.context_encoder.state_dict(),  
            "target_encoder": model.target_encoder.state_dict(),  
            "predictor": model.predictor.state_dict(),  
            "optimizer": optimizer.state_dict(),  
        },  
        path,  
    )

def main():  
    parser = argparse.ArgumentParser()  
    parser.add_argument("--model_name", type=str, default="distilbert-base-uncased", help="HF encoder backbone")  
    parser.add_argument("--text_file", type=str, default="", help="Path to a newline-separated text file")  
    parser.add_argument("--max_lines", type=int, default=50000)  
    parser.add_argument("--max_length", type=int, default=128)  
    parser.add_argument("--mask_ratio", type=float, default=0.3)  
    parser.add_argument("--mean_span_len", type=int, default=5)  
    parser.add_argument("--ema_m", type=float, default=0.99)  
    parser.add_argument("--pred_hidden_mult", type=int, default=4)  

    parser.add_argument("--batch_size", type=int, default=8)  
    parser.add_argument("--lr", type=float, default=2e-5)  
    parser.add_argument("--weight_decay", type=float, default=0.01)  
    parser.add_argument("--steps", type=int, default=500)  
    parser.add_argument("--warmup_steps", type=int, default=50)  
    parser.add_argument("--log_every", type=int, default=25)  
    parser.add_argument("--save_every", type=int, default=200)  
    parser.add_argument("--save_path", type=str, default="checkpoints/llm_jepa.pt")  

    parser.add_argument("--device", type=str, default="auto")  
    parser.add_argument("--seed", type=int, default=42)  
    parser.add_argument("--smoke_test", action="store_true", help="No downloads, tiny random encoder, tiny corpus")  
    args = parser.parse_args()  

    set_seed(args.seed)  
    device = pick_device(args.device)  

    if args.smoke_test:  
        if AutoTokenizer is None:  
            raise RuntimeError("transformers is required even for smoke_test (for tokenizer).")  
        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  
        # Ensure mask token exists  
        if tokenizer.mask_token_id is None:  
            raise ValueError("Tokenizer must support [MASK]. Use a masked LM tokenizer.")  

        texts = default_tiny_corpus()  
        ds = TextLinesDataset(texts)  

        encoder, dim = build_random_encoder(vocab_size=int(tokenizer.vocab_size), dim=256, layers=4, heads=4)  
        model = LLMJEPA(encoder=encoder, dim=dim, ema_m=0.95, pred_hidden_mult=2).to(device)  

        lr = 1e-4  
    else:  
        if AutoTokenizer is None:  
            raise RuntimeError("transformers is not installed. pip install transformers")  
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)  
        if tokenizer.mask_token_id is None:  
            raise ValueError(  
                "This tokenizer has no [MASK]. Pick a masked-encoder model (BERT/DeBERTa/DistilBERT)."  
            )  

        if args.text_file:  
            texts = load_texts_from_file(args.text_file, max_lines=args.max_lines)  
        else:  
            texts = default_tiny_corpus()  

        ds = TextLinesDataset(texts)  

        encoder, dim = build_hf_encoder(args.model_name)  
        model = LLMJEPA(encoder=encoder, dim=dim, ema_m=args.ema_m, pred_hidden_mult=args.pred_hidden_mult).to(device)  

        lr = args.lr  

    # DataLoader  
    def _collate(batch_texts):  
        return collate_jepa(  
            batch_texts=batch_texts,  
            tokenizer=tokenizer,  
            max_length=args.max_length,  
            mask_ratio=args.mask_ratio,  
            mean_span_len=args.mean_span_len,  
        )  

    dl = DataLoader(ds, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=_collate)  

    # Optimizer  
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=args.weight_decay)  

    # Simple warmup + cosine schedule  
    def lr_at(step: int) -> float:  
        if step < args.warmup_steps:  
            return float(step + 1) / float(max(1, args.warmup_steps))  
        progress = (step - args.warmup_steps) / float(max(1, args.steps - args.warmup_steps))  
        progress = min(max(progress, 0.0), 1.0)  
        return 0.5 * (1.0 + math.cos(math.pi * progress))  

    model.train()  
    running = 0.0  
    step = 0  
    data_iter = iter(dl)  

    while step < args.steps:  
        try:  
            batch = next(data_iter)  
        except StopIteration:  
            data_iter = iter(dl)  
            batch = next(data_iter)  

        # Move to device  
        input_ids = batch.input_ids.to(device)  
        attention_mask = batch.attention_mask.to(device)  
        masked_input_ids = batch.masked_input_ids.to(device)  
        pred_mask = batch.pred_mask.to(device)  

        # LR schedule  
        scale = lr_at(step)  
        for pg in optimizer.param_groups:  
            pg["lr"] = lr * scale  

        loss = model(  
            masked_input_ids=masked_input_ids,  
            input_ids=input_ids,  
            attention_mask=attention_mask,  
            pred_mask=pred_mask,  
        )  

        optimizer.zero_grad(set_to_none=True)  
        loss.backward()  
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  
        optimizer.step()  

        # EMA update after optimizer step  
        model.ema_update()  

        running += float(loss.item())  
        step += 1  

        if step % args.log_every == 0:  
            avg = running / float(args.log_every)  
            running = 0.0  
            print(f"step {step:6d} | loss {avg:.4f} | lr {optimizer.param_groups[0]['lr']:.6g}")  

        if step % args.save_every == 0:  
            save_checkpoint(args.save_path, model, optimizer, step)  
            print(f"saved checkpoint to {args.save_path} at step {step}")  

    save_checkpoint(args.save_path, model, optimizer, step)  
    print(f"training done. final checkpoint: {args.save_path}")

if __name__ == "__main__":  
     main()

这个脚本在训练什么

这是一个面向文本的 JEPA 风格表示预测器。

输入普通文本行,对每个样本创建两个视图。遮蔽视图(context view)是同一个句子,但某些 span 被替换成 `[MASK];原始视图(target view)保持原样,没有遮蔽。

训练流程是这样的:遮蔽视图过一个可训练的 context 编码器,原始视图过一个不可训练的 target 编码器,然后训练一个预测器,让 context 编码器的表示能预测 target 编码器的表示——但只在被遮蔽的位置上计算损失。Target 编码器通过 EMA 更新来保持稳定。

这种设计鼓励模型学习"填补语义"的表示,而不是预测具体的 token。

set_seed 函数

 defset_seed(seed: int):  
     random.seed(seed)  
     torch.manual_seed(seed)  
     torch.cuda.manual_seed_all(seed)

这个函数确保运行可复现。

random.seed(seed)

固定 Python 的随机操作(span 遮蔽会用到),

torch.manual_seed(seed)

固定 PyTorch 在 CPU 上的随机性,

torch.cuda.manual_seed_all(seed)

固定 CUDA 内核的随机性。

span 遮蔽和模型初始化都是随机的,不设种子的话每次跑结果都不一样。

pick_device 函数

 def pick_device(device_str: str) -> torch.device:  
     if device_str == "auto":  
         return torch.device("cuda" if torch.cuda.is_available() else "cpu")  
     return torch.device(device_str)

返回 PyTorch 设备对象。如果传

--device auto

,有 GPU 就用 GPU,没有就用 CPU。也可以直接指定

--device cpu

--device cuda

张量和模型必须在同一设备上,这是基本要求。

sample_span_mask 函数

 def sample_span_mask(seq_len, mask_ratio, mean_span_len, special_positions=None)

整个脚本里最重要的函数之一。

目标是创建一个布尔掩码,标记序列中哪些位置该被遮蔽。参数包括:seq_len 是真实 token 数量(不含 padding),mask_ratio 是遮蔽比例(比如 0.3),mean_span_len 是连续遮蔽 span 的平均长度,special_positions 是永远不该遮蔽的位置(CLS、SEP、PAD)。

内部逻辑是先创建一个全 False 的掩码,然后计算需要遮蔽多少 token:

 target_to_mask=max(1, int(round(seq_len*mask_ratio)))

即使序列很短也至少遮蔽 1 个。

接下来循环采样 span 直到凑够数。Span 长度从指数分布采样:

 span_len=max(1, int(random.expovariate(1.0/max(1, mean_span_len))))

这会产出很多短 span 和少量长 span,比较符合自然分布。随机选一个起始位置,过滤掉特殊 token,把剩下的位置标记为 True。

遮蔽策略对表示学习质量影响很大。Span 遮蔽能迫使模型从周围上下文推断缺失的语义。

apply_mask_to_input_ids 函数

 defapply_mask_to_input_ids(input_ids, attention_mask, tokenizer, mask_ratio, mean_span_len)

拿到一个样本的 token ids,输出两个东西:masked_input_ids 是把遮蔽位置换成 [MASK] 后的 ids,pred_mask 是标记哪些位置要算损失的布尔掩码。

先算可见序列长度:

seq_len = int(attention_mask.sum().item())

。attention_mask 里真实 token 是 1,padding 是 0。

然后识别特殊 token 位置,CLS 和 SEP 不能遮蔽,否则模型容易出问题。调用 sample_span_mask 采样遮蔽位置,把这些位置替换成 mask_token_id:

 masked_input_ids[:seq_len][pred_mask] =mask_token_id

返回的 pred_mask 是完整长度的,padding 位置都是 False。只在遮蔽位置算 JEPA 损失,其他位置忽略。

TextLinesDataset 类

 classTextLinesDataset(Dataset):  
     def__init__(self, texts):  
         self.texts= [t.strip() fortintextsift.strip()]

极简的数据集实现,存文本行列表,去掉空行和首尾空白。

__len__

返回行数,

__getitem__

返回单条文本。

load_texts_from_file 逐行读文件,可限制最大行数,传

--text_file

时用。default_tiny_corpus 提供内置测试数据集。

Batch 数据类

 @dataclass  
 classBatch:  
     input_ids  
     attention_mask  
     masked_input_ids  
     pred_mask

用 dataclass 比返回元组清晰多了,代码可读性好。

collate_jepa 函数

DataLoader 创建批次时调用的函数。输入是原始文本列表,先用 tokenizer 做分词、padding、截断:

 toks=tokenizer(batch_texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")

产出 input_ids 和 attention_mask。然后对每个样本调 apply_mask_to_input_ids 生成遮蔽版本和 pred_mask,最后堆叠成 [B, L] 张量返回 Batch。

DataLoader 是逐样本读的,但训练需要批次。批处理和遮蔽都在这里发生。

PredictorMLP 类

预测器头,结构简单:

 nn.Linear(dim, hidden)  
 nn.GELU()  
 nn.Dropout()  
 nn.Linear(hidden, dim)

把 context 表示映射到 target 表示空间,相当于一个学习出来的适配器,帮助对齐两边的嵌入。

LLMJEPA 模型类

主模型包装器,包含四个核心部件:context_encoder 是可训练的 Transformer 编码器,target_encoder 是它的深拷贝但不可训练,predictor 是 MLP,ema_m 是 EMA 动量因子。

_copy_encoder 用

copy.deepcopy

确保 target 和 context 初始状态一致。

ema_update 缓慢更新 target 编码器权重:

 p_tgt=m*p_tgt+ (1-m) *p_ctx

m=0.99 时 target 变化非常慢,这能稳定训练、降低表示坍塌风险。

forward 的流程:把遮蔽视图过 context 编码器(可训练),原始视图过 target 编码器(无梯度),predictor 处理 context 输出,然后只取遮蔽位置的向量:

 masked_pred=pred[pred_mask]  # [N, D]  
 masked_tgt=z_tgt[pred_mask]  # [N, D]

从 [B, L, D] 变成 [N, D],N 是遮蔽 token 总数。归一化后算余弦距离:

 loss=1- (masked_pred*masked_tgt).sum(dim=-1)  
 returnloss.mean()

归一化是因为余弦相似度只看向量方向,不看大小。

build_hf_encoder 函数

加载 Hugging Face 编码器,返回模型和隐藏维度(从 config.hidden_size 读)。

build_random_encoder 函数

冒烟测试专用,从头建一个小 Transformer 编码器,包括嵌入层、位置嵌入、编码器堆栈。注意这不是掩码语言模型,只是个编码器架构。返回对象带

.last_hidden_state

属性是为了匹配 HF 输出格式。

总结

这个实现刻意追求清晰而非完整,所以没有自定义注意力掩码、多视图数据集或混合目标。但是把它当参考实现用是非常合适的。原始 LLM-JEPA 论文做得更深入,把 JEPA 和 token 预测结合起来,还利用了文本-代码这样的自然配对视图。那些设计对下游任务表现很重要,但也增加了复杂度,容易让人看不清核心机制。

论文:

https://avoid.overfit.cn/post/09eb991a93f64a83a376cdb52ac5c661

作者:azhar

目录
相关文章
|
15小时前
|
人工智能 JavaScript API
零门槛部署本地AI助手:2026年Windows系统OpenClaw(原Clawdbot/Moltbot)保姆级教程
OpenClaw(原Clawdbot/Moltbot)是一款功能全面的智能体AI助手,不仅能通过聊天互动响应需求,还具备“动手”和“跑腿”能力——“手”可读写本地文件、执行代码、操控命令行,“脚”能联网搜索、访问网页并分析内容,“大脑”则可接入Qwen、OpenAI等云端API,或利用本地GPU运行模型。本教程专为Windows系统用户打造,从环境搭建到问题排查,详细拆解全流程,即使无技术基础也能顺利部署本地AI助理。
339 5
|
17小时前
|
人工智能 弹性计算 监控
零基础入门:阿里云OpenClaw部署全流程详解(图文版)
OpenClaw(原Moltbot/Clawdbot)是一款高权限、隐私自主的本地化AI智能体,支持钉钉/QQ/飞书交互,可自动处理邮件、日程、数据查询等任务。依托阿里云轻量服务器,一键部署、开箱即用,7×24小时稳定运行,打造专属“AI员工”。
386 4
|
18小时前
|
存储 安全 数据库
2026年使用Docker部署OpenClaw(原Clawdbot/Moltbot)完整步骤教程
OpenClaw(原Clawdbot/Moltbot)是一款开源的本地运行个人AI助手,支持WhatsApp、Telegram、Slack等十余种通信渠道,兼容macOS、iOS、Android系统,还可渲染实时Canvas界面。本文提供基于Docker Compose的生产级部署指南,涵盖环境准备、源码获取、配置、构建、启动及运维等关键环节,补充生产环境必需的安全配置、数据持久化、备份与监控建议,与官方配置无冲突,适用于希望通过Docker快速部署的用户。需说明的是,OpenClaw暂无官方预构建Docker镜像,需通过源码+Dockerfile本地构建,这也是官方推荐的最稳定部署方式。
484 0
|
15小时前
|
存储 JSON 数据格式
FossFLOW:开源等距图表工具,为技术文档注入立体活力!
FossFLOW是一款创新的开源等距图表工具,专为技术文档设计。它通过立体视角将复杂的系统架构转化为直观的3D图表,支持拖放式操作和离线使用,让技术图表变得生动易懂。无需注册,数据安全存储在本地,并提供JSON导入导出功能。无论是Docker快速部署还是在线体验,FossFLOW都能为架构图、流程图注入立体活力,是提升技术文档表现力的得力助手。
30 6
|
15小时前
|
人工智能 弹性计算 机器人
阿里云部署OpenClaw(原Moltbot、Clawdbot)构建钉钉AI员工实践教程
阿里云推出OpenClaw(原Moltbot/Clawdbot)一键部署方案,无需手动安装调试,几步操作即可开箱即用!支持快速搭建钉钉AI员工,集成百炼大模型与钉钉机器人,实现智能对话、流式卡片回复,助力企业高效落地AIGC应用。
54 4
|
14小时前
|
存储 人工智能 自然语言处理
OpenClaw(原Clawdbot、Moltbot)开箱即用,阿里云无影云电脑秒变7x24h个人助理!
最近,OpenClaw人工智能工具(曾用名:Clawdbot、Moltbot)这波真的玩疯了~和主打对话的普通AI不同,OpenClaw不仅会独立思考,更能直接上手:查资料、整理文件、自动化跑流程、清空每日重复琐事,相当于给你配了位7×24小时连轴转、啥活都能干的“超级智能AI员工”。热度太猛,网友们为体验它疯狂抢购Mac mini,硬是把库存干到断货。想玩转新晋顶流AI工具OpenClaw(下文统称OpenClaw),何必花大价钱买Mac mini?用阿里云无影云电脑一键部署就够了!无需昂贵硬件投入,云端高配算力一键直达,轻松部署运行OpenClaw。
54 6
|
15小时前
|
数据采集 人工智能 自然语言处理
技术内幕:一文读懂章鱼AI的跨平台数据采集与创作架构
本文拆解了AI运营工具如何通过数据采集、分析、创作闭环,解决新手“发什么会火”的决策难题。
|
15小时前
|
人工智能 安全
智能体来了从 0 到 1 :核心挑战,是非技术性的认知与场景重构
本文探讨AI智能体从概念到落地的核心瓶颈:非模型能力,而在业务理解与结构化水平。指出智能体本质是“决策执行体”,其成败取决于能否将模糊业务目标拆解为可执行、可校验、可容错的逻辑结构,强调目标对齐、任务拆解、知识显性化与人机协同评估体系。
28 2
|
17小时前
|
人工智能 算法 自动驾驶
智能体领航员:重构效率边界与人的主体性
在数字化深水区,“智能体领航员”正推动人机关系从“工具辅助”迈向“逻辑共生”。它不再是被动软件,而是能理解模糊意图、拆解任务、过滤噪声、跨平台调度的主动协作伙伴。其核心价值在于释放人类心智带宽,让人回归决策高地——定义问题、判断价值、把握方向。这是算法时代重拾主体性的关键跃迁。(239字)
31 5
|
15小时前
|
人工智能 程序员 API
2026 AI 元年:从“单兵作战”到“智能体集群”,程序员的生存与重构
2026 年是真正的“AI Agent 元年”。大模型已从单一的文本生成进化为具备自主执行能力的“智能体集群”。本文将深度解析中国 AI 产业在这一进程中的技术贡献,探讨开发者如何从底层代码编写者转型为智能体编排专家,并揭示未来三年的行业重构路径。
43 0

热门文章

最新文章