Diffusion models with a ViT backbone, now open source!


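Introduction

OpenAI's recently released Sora model has drawn wide attention. One of its core techniques is to convert visual data into a unified patch representation and to combine Transformers with diffusion models, which has shown excellent scaling behavior.

The paper "Scalable Diffusion Models with Transformers" (DiT), widely shared on Twitter, is also regarded as an important foundation of the technology behind Sora. Its publication did not go smoothly: the paper was rejected by CVPR 2023.

Coincidentally, while DiT was rejected, CVPR 2023 did accept U-ViT, "All are Worth Words: A ViT Backbone for Diffusion Models", a joint work by Tsinghua University, Renmin University of China, the Beijing Academy of Artificial Intelligence and other institutions. The study designs a simple, general ViT-based architecture (U-ViT) that replaces the convolutional neural network (CNN) in U-Net as the backbone for image generation with diffusion models.

The work is now open source:

GitHub link: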

GitHub - baofff/U-ViT: A PyTorch implementation of the paper "All are Worth Words: A ViT Backbone for Diffusion Models".

Paper link:

[2209.12152] All are Worth Words: A ViT Backbone for Diffusion Models (arxiv.org)

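Model link:

imagenet256_uvit_huge · Model Hub (modelscope.cn)

We also recognize that successfully scaling up a Transformer-based diffusion model, as Sora did, requires not only expert-level understanding of the underlying algorithms but also a firm grasp of the whole deep learning engineering stack. That is harder than demonstrating a workable architecture on academic datasets.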

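Walkthrough of the paper and code

1. Model architecture overview

U-Net based on convolutional neural networks (CNNs) has been the dominant backbone in earlier diffusion models. A CNN-based U-Net is characterized by a group of down-sampling blocks, a group of up-sampling blocks, and long skip connections between the two groups. Meanwhile, vision Transformers (ViT) have shown promise across a range of vision tasks, where they match or even outperform CNN-based methods. The paper therefore opens with a question: is it necessary for diffusion models to rely on a CNN-based U-Net?

Following the design of Transformers, U-ViT treats all inputs, including the time, the condition and the noisy image patches, as tokens. Crucially, inspired by U-Net, U-ViT uses long skip connections between shallow and deep layers. Low-level features matter for the pixel-level prediction objective in diffusion models, and long skip connections ease the training of the corresponding prediction network. In addition, U-ViT optionally adds an extra 3×3 convolutional block before the output for better visual quality. See the model architecture diagram.

As the diagram shows, the U-ViT architecture for diffusion models treats all inputs, including the time, the condition and the noisy image patches, as tokens, and uses (#Blocks-1)/2 long skip connections between shallow and deep layers.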
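
To make the (#Blocks-1)/2 count concrete, here is a minimal dependency-free sketch of how the long skip connections pair shallow and deep blocks. It assumes the layout of the UViT class shown later in this article (depth // 2 in-blocks, one mid-block, depth // 2 out-blocks, with saved outputs consumed last-in-first-out). The function name and the integer block labels are hypothetical and not part of the repository.

```python
def long_skip_pairs(depth):
    """Pair each in-block with the out-block that consumes its output.

    Assumed layout: depth // 2 in-blocks, one mid-block, depth // 2 out-blocks.
    Block indices are illustrative labels counted from 0.
    """
    half = depth // 2
    stack = []                        # outputs saved by the in-blocks
    for in_idx in range(half):
        stack.append(in_idx)
    mid_idx = half                    # the mid-block takes no skip input
    pairs = []
    for k in range(half):
        out_idx = mid_idx + 1 + k
        pairs.append((stack.pop(), out_idx))  # last saved, first used
    return pairs


print(long_skip_pairs(4))  # [(1, 3), (0, 4)]
```

With depth=4 there are five blocks in total and (5-1)/2 = 2 skip connections: the shallowest block feeds the deepest one, which is the U-shaped pairing that gives U-ViT its name.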

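With the diagram in mind, that is a rough outline of the U-ViT structure to give a first impression. Each module mentioned is described in detail below.

2. Model inputs

First, let's look at what the model's inputs are and what they look like.

The input stage matches the paper's title, all as words (tokens): every input, including the time, the condition and the noisy image patches, is represented as a token and then passed through an embedding layer.

An embedding layer converts input data in some format into a vector representation that the model can process and that describes the information in the original data.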

timestep_embedding

timestep_embedding generates a sinusoidal embedding of the timestep, which injects the diffusion time information into the model:

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
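
As a sanity check on the formula, the following sketch recomputes the same embedding for a single timestep using only the standard library. The name timestep_embedding_scalar is hypothetical; it is meant to mirror the tensor version above, not to replace it.

```python
import math


def timestep_embedding_scalar(t, dim, max_period=10000):
    """Sinusoidal embedding of one timestep t as a plain Python list."""
    half = dim // 2
    # Frequencies decay geometrically from 1 towards 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    # Cosine terms first, then sine terms, as in the tensor version.
    emb = [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]
    if dim % 2:
        emb.append(0.0)  # zero-pad when dim is odd
    return emb


print(timestep_embedding_scalar(0, 8))
# [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```

At t = 0 every cosine term is 1 and every sine term is 0. Larger timesteps rotate each pair at its own frequency, which gives the network a smooth, multi-scale encoding of time.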

The patchify function

The patchify function splits an image into patches.

Code:

def patchify(imgs, patch_size):
    x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
    return x
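
The einops pattern can be hard to read at first, so here is a minimal pure-Python sketch of the same rearrangement for a single image stored as nested lists. The helper name patchify_single is hypothetical; the element order inside each patch follows the (p1 p2 C) grouping of the pattern above, with the channel index varying fastest.

```python
def patchify_single(img, patch_size):
    """img: nested list [C][H][W]. Returns (H/p)*(W/p) patches of length p*p*C."""
    C, H, W = len(img), len(img[0]), len(img[0][0])
    p = patch_size
    assert H % p == 0 and W % p == 0
    patches = []
    for i in range(H // p):              # patch row (h in the pattern)
        for j in range(W // p):          # patch column (w in the pattern)
            patch = []
            for a in range(p):           # p1: row inside the patch
                for b in range(p):       # p2: column inside the patch
                    for c in range(C):   # C varies fastest
                        patch.append(img[c][i * p + a][j * p + b])
            patches.append(patch)
    return patches


# One channel, 4x4 image holding the values 0..15 in row-major order.
img = [[[r * 4 + c for c in range(4)] for r in range(4)]]
print(patchify_single(img, 2))
# [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
```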

PatchEmbed:

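A convolution turns the image into patch tokens: the image is split into patches and each patch is projected into a vector space of the specified dimension.

Code: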

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, patch_size, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x):
        B, C, H, W = x.shape
        assert H % self.patch_size == 0 and W % self.patch_size == 0
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
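
Because the kernel size equals the stride, each patch produces exactly one token. The sketch below counts the resulting sequence length under the assumption, taken from the UViT class later in this article, that one time token is always prepended and one label token is added for class-conditional models. The function name is hypothetical.

```python
def uvit_sequence_length(latent_size, patch_size, class_conditional=True):
    """Number of tokens the Transformer sees for a square input."""
    assert latent_size % patch_size == 0
    num_patches = (latent_size // patch_size) ** 2
    extras = 2 if class_conditional else 1  # time token (+ label token)
    return extras + num_patches


# Settings used in the hands-on example: latent size = image size // 8.
print(uvit_sequence_length(32, 2))  # 256x256 images, patch size 2 -> 258
print(uvit_sequence_length(64, 4))  # 512x512 images, patch size 4 -> 258
```

Both configurations yield the same sequence length, which suggests why the example doubles the patch size when the image resolution doubles: the attention cost stays the same.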

3. Encoder and decoder design

Reference code:

U-ViT/libs/autoencoder.py at main · baofff/U-ViT · GitHub

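Upsampling code. Depending on configuration, it combines upsampling with an optional convolution: the input feature map is upsampled and, optionally, passed through a convolution for feature extraction.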

class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x
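
For intuition, nearest-neighbor interpolation with scale_factor=2.0 simply repeats every value twice along each spatial axis. A minimal sketch on a 2-D nested list follows; the helper name is hypothetical and a single channel is assumed.

```python
def upsample_nearest_2x(grid):
    """Repeat every element twice horizontally and every row twice vertically."""
    out = []
    for row in grid:
        wide = [v for v in row for _ in range(2)]  # duplicate each column
        out.append(wide)
        out.append(list(wide))                     # duplicate the row
    return out


print(upsample_nearest_2x([[1, 2], [3, 4]]))
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
```

The optional 3×3 convolution that follows in Upsample can then smooth the blocky result, since nearest-neighbor upsampling on its own adds no new information.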

Downsampling code. It downsamples a 2-D input feature map (for example an image), reducing its spatial dimensions.

class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)
    def forward(self, x):
        if self.with_conv:
            pad = (0,1,0,1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
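
The effect of the manual padding can be checked with the standard convolution output-size formula. The sketch below is plain arithmetic under that formula; the helper name is hypothetical. In the code above, pad = (0,1,0,1) adds one column on the right and one row at the bottom.

```python
def conv_output_size(size, kernel, stride, pad_before=0, pad_after=0):
    """Output length of a convolution or pooling window along one axis."""
    return (size + pad_before + pad_after - kernel) // stride + 1


# with_conv=True: pad one pixel after, then a 3x3 convolution with stride 2.
print(conv_output_size(256, kernel=3, stride=2, pad_after=1))  # 128
# with_conv=False: 2x2 average pooling with stride 2.
print(conv_output_size(256, kernel=2, stride=2))               # 128
```

Both branches halve an even-sized feature map, so the rest of the network does not need to know which one was selected.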

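The Encoder uses a deep residual network structure with attention at selected resolutions to strengthen feature learning. It aims to capture the key information in the input image efficiently and turn it into a latent representation suited to downstream tasks such as image generation.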

class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)
        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
    def forward(self, x):
        # timestep embedding
        temb = None
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))
        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)
        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
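
The Encoder downsamples once per resolution level except the last, so the latent size follows directly from len(ch_mult). The sketch below works this out. The ch_mult=(1, 2, 4, 4) and z_channels=4 values in the usage line are assumptions chosen to be consistent with z_size = image_size // 8 and in_chans=4 in the hands-on example later in this article; they are not read from the repository's configuration.

```python
def encoder_output_shape(resolution, ch_mult, z_channels, double_z=True):
    """Return (channels, height, width) of the Encoder output for a square input."""
    num_downsamples = len(ch_mult) - 1    # the last level has no Downsample
    spatial = resolution
    for _ in range(num_downsamples):
        spatial //= 2
    # double_z doubles the channels, typically to hold a mean and a log-variance.
    channels = 2 * z_channels if double_z else z_channels
    return channels, spatial, spatial


print(encoder_output_shape(256, (1, 2, 4, 4), z_channels=4))  # (8, 32, 32)
```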

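The Decoder mirrors the Encoder: it is built from convolution layers, residual blocks (ResnetBlock), attention (make_attn) and upsampling layers (Upsample), and progressively turns a low-dimensional latent vector (z) into a high-resolution image. Its code is in the same libs/autoencoder.py file referenced above.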


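4. Transformer design

Attention mechanism

U-ViT implements multi-head attention and selects the computation path according to ATTENTION_MODE: Flash Attention, xFormers attention, or plain math attention.

The Attention module is implemented as follows: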

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
    def forward(self, x):
        B, L, C = x.shape
        qkv = self.qkv(x)
        if ATTENTION_MODE == 'flash':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'xformers':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
            x = xformers.ops.memory_efficient_attention(q, k, v)
            x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
        elif ATTENTION_MODE == 'math':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            raise NotImplementedError
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
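
All three branches compute the same quantity, softmax(q kᵀ · scale) v per head. The following dependency-free sketch spells out the 'math' branch for a single head, without dropout or the output projection. The function name is hypothetical, and the max subtraction is a standard numerical-stability step that does not change the result.

```python
import math


def attention_single_head(q, k, v):
    """q, k: lists of L vectors of length D; v: list of L vectors. Returns L vectors."""
    scale = len(q[0]) ** -0.5                      # head_dim ** -0.5
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)                            # stabilise the softmax
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


# Identical keys give uniform weights, so each output is the mean of the values.
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 1.0], [1.0, 1.0]]
v = [[2.0, 0.0], [0.0, 4.0]]
print(attention_single_head(q, k, v))  # [[1.0, 2.0], [1.0, 2.0]]
```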

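Transformer block design

The Block class is a Transformer block containing a multi-head attention layer and an MLP layer. It optionally takes a long skip connection input and optionally uses gradient checkpointing to save memory.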

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)
        self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
        self.use_checkpoint = use_checkpoint
    def forward(self, x, skip=None):
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
        else:
            return self._forward(x, skip)
    def _forward(self, x, skip=None):
        if self.skip_linear is not None:
            x = self.skip_linear(torch.cat([x, skip], dim=-1))
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

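5. The UViT backbone:

Initialization sets up the patch embedding module, the time embedding module (optionally an MLP time embedding), the class embedding, the position embedding, and so on.

It defines the in_blocks and out_blocks, with a mid_block in between.

Finally, the decoder_pred linear layer predicts patch-level features, which are then reassembled to the original image size.

In the forward pass, the model first patchifies the input image, then prepends the time token (and the label token when a class label is given) and adds the position embedding. After the in-blocks, the mid-block and the out-blocks, the output is mapped back to patch-level features, reassembled to image size, and finally passed through an optional 3×3 convolution layer or an identity mapping to produce the result.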

class UViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, norm_layer=nn.LayerNorm, mlp_time_embed=False, num_classes=-1,
                 use_checkpoint=False, conv=True, skip=True):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_classes = num_classes
        self.in_chans = in_chans
        self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = (img_size // patch_size) ** 2
        self.time_embed = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.SiLU(),
            nn.Linear(4 * embed_dim, embed_dim),
        ) if mlp_time_embed else nn.Identity()
        if self.num_classes > 0:
            self.label_emb = nn.Embedding(self.num_classes, embed_dim)
            self.extras = 2
        else:
            self.extras = 1
        self.pos_embed = nn.Parameter(torch.zeros(1, self.extras + num_patches, embed_dim))
        self.in_blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                norm_layer=norm_layer, use_checkpoint=use_checkpoint)
            for _ in range(depth // 2)])
        self.mid_block = Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                norm_layer=norm_layer, use_checkpoint=use_checkpoint)
        self.out_blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                norm_layer=norm_layer, skip=skip, use_checkpoint=use_checkpoint)
            for _ in range(depth // 2)])
        self.norm = norm_layer(embed_dim)
        self.patch_dim = patch_size ** 2 * in_chans
        self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
        self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, 3, padding=1) if conv else nn.Identity()
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed'}
    def forward(self, x, timesteps, y=None):
        x = self.patch_embed(x)
        B, L, D = x.shape
        time_token = self.time_embed(timestep_embedding(timesteps, self.embed_dim))
        time_token = time_token.unsqueeze(dim=1)
        x = torch.cat((time_token, x), dim=1)
        if y is not None:
            label_emb = self.label_emb(y)
            label_emb = label_emb.unsqueeze(dim=1)
            x = torch.cat((label_emb, x), dim=1)
        x = x + self.pos_embed
        skips = []
        for blk in self.in_blocks:
            x = blk(x)
            skips.append(x)
        x = self.mid_block(x)
        for blk in self.out_blocks:
            x = blk(x, skips.pop())
        x = self.norm(x)
        x = self.decoder_pred(x)
        assert x.size(1) == self.extras + L
        x = x[:, self.extras:, :]
        x = unpatchify(x, self.in_chans)
        x = self.final_layer(x)
        return x
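
The forward pass calls unpatchify, which this article does not list. The sketch below assumes it is the inverse of the patchify rearrangement shown earlier, for a square patch grid, and works on one image stored as nested lists. The name unpatchify_single is hypothetical and this is not the repository's implementation.

```python
import math


def unpatchify_single(patches, channels):
    """patches: h*w flattened patches ordered (p1 p2 C). Returns an image [C][H][W]."""
    n = len(patches)
    h = w = math.isqrt(n)
    assert h * w == n                                  # square grid assumed
    p = math.isqrt(len(patches[0]) // channels)
    assert p * p * channels == len(patches[0])
    H, W = h * p, w * p
    img = [[[0] * W for _ in range(H)] for _ in range(channels)]
    for idx, patch in enumerate(patches):
        i, j = divmod(idx, w)                          # patch row and column
        values = iter(patch)
        for a in range(p):
            for b in range(p):
                for c in range(channels):              # channel varies fastest
                    img[c][i * p + a][j * p + b] = next(values)
    return img


patches = [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
print(unpatchify_single(patches, channels=1))
# [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]
```

Note that the forward pass drops the first self.extras tokens (time and, if present, label) before this step, so only the image patch tokens are reassembled.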

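Hands-on example

The example uses a diffusion model with UViT as the backbone to generate image samples conditioned on given class labels, and displays the results.

It is adapted from the official UViT demo code (U-ViT/UViT_ImageNet_demo.ipynb at main · baofff/U-ViT · GitHub) and can be run directly on the free compute provided by the ModelScope community.

ModelScope example link: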

modelscope/examples/pytorch/UViT_ImageNet_demo.ipynb at master · modelscope/modelscope · GitHub

1. Install dependencies

!git clone https://github.com/baofff/U-ViT
!pip install einops
import os
os.chdir('/mnt/workspace/U-ViT')
os.environ['PYTHONPATH'] = '/env/python:/mnt/workspace/U-ViT'
import torch
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
import libs.autoencoder
from libs.uvit import UViT
import einops
from torchvision.utils import save_image
from PIL import Image
from modelscope.hub.file_download import model_file_download

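2. Load the UViT model

Set the image size to 256, download the matching UViT model, compute the size of the low-dimensional latent representation, initialize the UViT model structure, and load the weights.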

image_size = "256" #@param [256, 512]
image_size = int(image_size)
if image_size == 256:
    #download uvit model
    model_file_download(model_id='thu-ml/imagenet256_uvit_huge',file_path='imagenet256_uvit_huge.pth', cache_dir='/mnt/workspace')
    !mv /mnt/workspace/thu-ml/imagenet256_uvit_huge/imagenet256_uvit_huge.pth /mnt/workspace/U-ViT
else:
    model_file_download(model_id='thu-ml/imagenet512_uvit_huge',file_path='imagenet512_uvit_huge.pth', cache_dir='/mnt/workspace')
    !mv /mnt/workspace/thu-ml/imagenet512_uvit_huge/imagenet512_uvit_huge.pth /mnt/workspace/U-ViT
z_size = image_size // 8
patch_size = 2 if image_size == 256 else 4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
nnet = UViT(img_size=z_size,
       patch_size=patch_size,
       in_chans=4,
       embed_dim=1152,
       depth=28,
       num_heads=16,
       num_classes=1001,
       conv=False)
nnet.to(device)
nnet.load_state_dict(torch.load(f'imagenet{image_size}_uvit_huge.pth', map_location='cpu'))
nnet.eval()

3. Download and load the autoencoder model

model_file_download(model_id='AI-ModelScope/autoencoder_kl_ema',file_path='autoencoder_kl_ema.pth', cache_dir='/mnt/workspace')
!mv /mnt/workspace/AI-ModelScope/autoencoder_kl_ema/autoencoder_kl_ema.pth /mnt/workspace/U-ViT
autoencoder = libs.autoencoder.get_model('autoencoder_kl_ema.pth')
autoencoder.to(device)

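4. Image generation with UViT and a diffusion model

Using a diffusion model with UViT as the backbone, generate image samples conditioned on the given class labels and display the results. Sampling uses dpm_solver.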

seed = 42 #@param {type:"number"}
steps = 25 #@param {type:"slider", min:0, max:1000, step:1}
cfg_scale = 3 #@param {type:"slider", min:0, max:10, step:0.1}
class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:"raw"}
samples_per_row = 4 #@param {type:"number"}
torch.manual_seed(seed)
def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    _betas = (
        torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    )
    return _betas.numpy()
_betas = stable_diffusion_beta_schedule()  # set the noise schedule
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
y = torch.tensor(class_labels, device=device)
y = einops.repeat(y, 'B -> (B N)', N=samples_per_row)
def model_fn(x, t_continuous):
    t = t_continuous * len(_betas)
    _cond = nnet(x, t, y=y)
    _uncond = nnet(x, t, y=torch.tensor([1000] * x.size(0), device=device))
    return _cond + cfg_scale * (_cond - _uncond)  # classifier free guidance
z_init = torch.randn(len(y), 4, z_size, z_size, device=device)
dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
with torch.no_grad():
  with torch.cuda.amp.autocast():  # inference with mixed precision
    z = dpm_solver.sample(z_init, steps=steps, eps=1. / len(_betas), T=1.)
    samples = autoencoder.decode(z)
samples = 0.5 * (samples + 1.)
samples.clamp_(0., 1.)
save_image(samples, "sample.png", nrow=samples_per_row * 2, padding=0)
samples = Image.open("sample.png")
display(samples)
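
Two pieces of arithmetic in the script above are easy to miss: the noise schedule is linear in the square root of beta, and the guidance formula adds cfg_scale times the difference between the conditional and unconditional predictions on top of the conditional one. The sketch below reproduces both with plain Python lists; the function names are hypothetical and the lists stand in for tensors.

```python
def scaled_linear_betas(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    """Betas spaced linearly in sqrt(beta), then squared."""
    lo, hi = linear_start ** 0.5, linear_end ** 0.5
    step = (hi - lo) / (n_timestep - 1)
    return [(lo + i * step) ** 2 for i in range(n_timestep)]


def guided_prediction(cond, uncond, cfg_scale):
    """Classifier-free guidance in the form used by model_fn above."""
    return [c + cfg_scale * (c - u) for c, u in zip(cond, uncond)]


betas = scaled_linear_betas()
print(len(betas), betas[0], betas[-1])

# With cfg_scale = 3 the conditional direction is amplified four-fold
# relative to the unconditional prediction.
print(guided_prediction([1.0], [0.0], 3))  # [4.0]
```

In this form cfg_scale = 0 returns the purely conditional prediction, and larger values push samples further towards the requested class. In model_fn the unconditional branch is obtained by passing label 1000, which fits the model having been built with num_classes=1001.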

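Summary

As the paper reports, a latent diffusion model with U-ViT achieves a record FID of 2.29 in class-conditional image generation on ImageNet 256×256 and an FID of 5.48 in text-to-image generation on MS-COCO. Thanks to the U-ViT authors for open-sourcing this work. We believe that, building on U-ViT, developers can better pursue cutting-edge research on technologies like Sora.

GitHub link: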

GitHub - baofff/U-ViT: A PyTorch implementation of the paper "All are Worth Words: A ViT Backbone for Diffusion Models".
