风格迁移 图像合成 图像重构 更换姿态和图像背景(使用交叉注意控制进行提示到图像编辑)GAN网络增强版

简介: 风格迁移 图像合成 图像重构 更换姿态和图像背景(使用交叉注意控制进行提示到图像编辑)GAN网络增强版

前言


顺人性而为:你们最想要的是代码-------->

代码链接:[(https://download.csdn.net/download/ALiLiLiYa/86784427) 也可私信本人获取。 哈哈 但是还是要看下是否符合需求嘛!!


正文


摘要


最近的大规模文本驱动的合成模型由于其出色的生成遵循给定文本提示的高度多样化图像的能力而引起了广泛关注。这种基于文本的合成方法对习惯于口头描述其意图的人类特别有吸引力。因此,将文本驱动的图像合成扩展到文本驱动的图像编辑是很自然的。对于这些生成模型,编辑具有挑战性,因为编辑技术的固有属性是保留大部分原始图像,而在基于文本的模型中,即使对文本提示进行少量修改也通常会导致完全不同的结果。最先进的方法通过要求用户提供空间遮罩来定位编辑来减轻这种情况,因此忽略了遮罩区域内的原始结构和内容。在本文中,我们采用了一种直观的提示-提示编辑框架,其中的编辑仅由文本控制。为此,我们深入分析了文本条件模型,并观察到交叉注意层是控制图像空间布局与提示中每个单词之间关系的关键。通过此观察,我们提出了几种仅通过编辑文本提示来监视图像合成的应用程序。这包括通过替换单词进行本地化编辑,通过添加规范进行全局编辑,甚至微妙地控制单词在图像中的反映程度。我们在不同的图像和提示上展示了我们的结果,展示了高质量的合成和对编辑提示的保真度。


效果及目标


c1baac8946de2b7ee742b4b3b6333fe1_8a72410b51cf4618a91c5525acfc33e8.png


主要方法


让我是由文本引导扩散模型 [38] 使用文本提示P和随机种子s生成的图像。我们的目标是仅在编辑的提示p ∗ 的引导下编辑输入图像,从而产生编辑图像i ∗。例如,考虑从提示 “我的新自行车” 生成的图像,并假设用户想要编辑自行车的颜色、其材料,或者甚至用踏板车替换它,同时保留原始图像的外观和结构。对于用户来说,一个直观的界面是通过进一步描述自行车的外观或将其替换为另一个单词来直接更改文本提示。与以前的作品相反,我们希望避免依赖任何用户定义的掩码来帮助或表示编辑应该发生的位置。一个简单但不成功的尝试是修复内部随机性,并使用编辑后的文本提示重新生成。不幸的是,如图2所示,这导致具有不同结构和组成的完全不同的图像。我们的主要观察结果是,生成的图像的结构和外观不仅取决于随机种子,而且还取决于像素之间通过扩散过程嵌入文本的相互作用。通过修改交叉注意层中发生的像素到文本交互,我们提供了提示到提示的图像编辑功能。更具体地说,注入输入图像I的交叉注意图使我们能够保留原始构图和结构。在第3.1节中,我们回顾了如何使用交叉注意,在第3.中,我们描述了如何利用交叉注意进行编辑。有关扩散模型的其他背景,请参阅附录A。

下图描述了背景的替换和姿态的变换。

12349632e25c30cba7e0ee8becddc490_ed2627f4b52840f28dc9300d5924b1f4.png

from typing import Union, Tuple, List, Callable, Dict, Optional
import torch
import torch.nn.functional as nnf
from diffusers import DiffusionPipeline
import numpy as np
from IPython.display import display
from PIL import Image
import abc
import ptp_utils
import seq_aligner
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model_id = "CompVis/ldm-text2im-large-256"
NUM_DIFFUSION_STEPS = 50
GUIDANCE_SCALE = 5.
MAX_NUM_WORDS = 77
# load model and scheduler
ldm = DiffusionPipeline.from_pretrained(model_id).to(device)
tokenizer = ldm.tokenizer
Fetching 17 files:   0%|          | 0/17 [00:00<?, ?it/s]
{'cross_attention_dim'} was not found in config. Values will be initialized to default values.
{'set_alpha_to_one'} was not found in config. Values will be initialized to default values.
Prompt-to-Prompt Attnetion Controllers
Our main logic is implemented in the forward call in an AttentionControl object. The forward is called in each attention layer of the diffusion model and it can modify the input attnetion weights attn.
is_cross, place_in_unet in ("down", "mid", "up"), AttentionControl.cur_step can help us track the exact attention layer and timestamp during the diffusion iference.
class LocalBlend:
    def __call__(self, x_t, attention_store, step):
        k = 1
        maps = attention_store["down_cross"][:2] + attention_store["up_cross"][3:6]
        maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps]
        maps = torch.cat(maps, dim=1)
        maps = (maps * self.alpha_layers).sum(-1).mean(1)
        mask = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k))
        mask = nnf.interpolate(maps, size=(x_t.shape[2:]))
        mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
        mask = mask.gt(self.threshold)
        mask = (mask[:1] + mask).float()
        x_t = x_t[:1] + mask * (x_t - x_t[:1])
        return x_t
    def __init__(self, prompts: List[str], words: [List[List[str]]], threshold: float = .3):
        alpha_layers = torch.zeros(len(prompts),  1, 1, 1, 1, MAX_NUM_WORDS)
        for i, (prompt, words_) in enumerate(zip(prompts, words)):
            if type(words_) is str:
                words_ = [words_]
            for word in words_:
                ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
                alpha_layers[i, :, :, :, :, ind] = 1
        self.alpha_layers = alpha_layers.to(device)
        self.threshold = threshold
class AttentionControl(abc.ABC):
    def step_callback(self, x_t):
        return x_t
    def between_steps(self):
        return
    @abc.abstractmethod
    def forward (self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError
    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        h = attn.shape[0]
        attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn
    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0
    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
class EmptyControl(AttentionControl):
    def forward (self, attn, is_cross: bool, place_in_unet: str):
        return attn
class AttentionStore(AttentionControl):
    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [],  "mid_self": [],  "up_self": []}
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        if attn.shape[1] <= 16 ** 2:  # avoid memory overhead
            key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
            self.step_store[key].append(attn)
        return attn
    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()
    def get_average_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
        return average_attention
    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
    def __init__(self):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
class AttentionControlEdit(AttentionStore, abc.ABC):
    def step_callback(self, x_t):
        if self.local_blend is not None:
            x_t = self.local_blend(x_t, self.attention_store, self.cur_step)
        return x_t
    def replace_self_attention(self, attn_base, att_replace):
        if att_replace.shape[2] <= 16 ** 2:
            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
        else:
            return att_replace
    @abc.abstractmethod
    def replace_cross_attention(self, attn_base, att_replace):
        raise NotImplementedError
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
            h = attn.shape[0] // (self.batch_size)
            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
            attn_base, attn_repalce = attn[0], attn[1:]
            if is_cross:
                alpha_words = self.cross_replace_alpha[self.cur_step]
                attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce
                attn[1:] = attn_repalce_new
            else:
                attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
        return attn
    def __init__(self, prompts, num_steps: int,
                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
                 self_replace_steps: Union[float, Tuple[float, float]],
                 local_blend: Optional[LocalBlend]):
        super(AttentionControlEdit, self).__init__()
        self.batch_size = len(prompts)
        self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device)
        if type(self_replace_steps) is float:
            self_replace_steps = 0, self_replace_steps
        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
        self.local_blend = local_blend

我们使用Image文本引导的综合模型作为主干。由于构图和几何形状主要是在64 × 64分辨率下确定的,因此我们仅使用超分辨率过程来适应文本到图像的扩散模型。回想一下,每个扩散步骤t都包括从嘈杂的图像zt预测噪声 ,并使用U形网络进行文本嵌入 ψ§ [。在最后一步,这个过程产生生成的图像I = z0。最重要的是,两种模式之间的相互作用发生在噪声预测期间,其中使用交叉注意层融合视觉和文本特征的嵌入,这些交叉注意层为每个文本令牌生成空间注意图。

c02147b7ebc0766cfceb91df67eebd73_1e534ea2fae543fea60583645cced390.png

636171a3820f1d5e2f66726d058630fb_b5eab3928b5f4d9f84c0c30f93d1d84e.png

67f6925541905a9169a97f1ef9904d43_1f51302d13fe41378a8636615d47d714.png


结论


在这项工作中,我们发现了文本到图像扩散模型中交叉注意层的强大功能。我们证明了这些高维层具有可解释的空间地图表示形式,它们在将文本提示中的单词与合成图像的空间布局联系起来方面起着关键作用。通过此观察,我们展示了对提示的各种操作如何直接控制合成图像中的属性,从而为包括本地和全局编辑在内的各种应用程序铺平了道路。这项工作是为用户提供简单直观的方法来编辑图像的第一步,利用文本语义能力。它使用户能够在语义,文本空间中导航,该空间在每个步骤之后都会显示增量更改,而不是在每次文本操作之后从头开始生成所需的图像。虽然我们已经通过仅更改文本提示来展示语义控制,但我们的技术仍然受到一些限制,需要在后续工作中解决。首先,当前的反演过程会导致某些测试一下图像上的可见失真。此外,反转需要用户提出合适的提示。对于复杂的构图,这可能具有挑战性。请注意,文本引导扩散模型的反演挑战是我们工作的正交努力,将来将对此进行深入研究。其次,当前的注意力地图分辨率低,因为交叉注意力被置于网络的瓶颈中。这限制了我们执行更精确的本地化编辑的能力。为了缓解这种情况,我们建议在更高分辨率的层中也纳入交叉注意力。我们将其留给将来的工作,因为它需要分析超出我们当前范围的培训程序。最后,我们认识到,我们当前的方法不能用于在图像上空间移动现有对象,也不能将这种控制留给将来的工作。

9034a2ee13c0485e94835015ff925339.jpeg


相关文章
|
3月前
|
计算机视觉
【论文复现】经典再现:yolov4的主干网络重构(结合Slim-neck by GSConv)
【论文复现】经典再现:yolov4的主干网络重构(结合Slim-neck by GSConv)
52 0
【论文复现】经典再现:yolov4的主干网络重构(结合Slim-neck by GSConv)
|
3月前
|
机器学习/深度学习 算法 PyTorch
python手把手搭建图像多分类神经网络-代码教程(手动搭建残差网络、mobileNET)
python手把手搭建图像多分类神经网络-代码教程(手动搭建残差网络、mobileNET)
46 0
|
4月前
|
机器学习/深度学习 缓存 算法
【论文速递】CVPR2020 - CRNet:用于小样本分割的交叉参考网络
【论文速递】CVPR2020 - CRNet:用于小样本分割的交叉参考网络
|
8天前
|
机器学习/深度学习 TensorFlow 算法框架/工具
PYTHON TENSORFLOW 2二维卷积神经网络CNN对图像物体识别混淆矩阵评估|数据分享
PYTHON TENSORFLOW 2二维卷积神经网络CNN对图像物体识别混淆矩阵评估|数据分享
36 7
|
2月前
|
机器学习/深度学习 编解码 异构计算
ELAN:用于图像超分辨率的高效远程注意力网络
ELAN:用于图像超分辨率的高效远程注意力网络
33 1
|
3月前
|
机器学习/深度学习 算法 Python
【Siamese】手把手教你搭建一个孪生神经网络,比较两张图像的相似度
【Siamese】手把手教你搭建一个孪生神经网络,比较两张图像的相似度
88 0
|
3月前
|
前端开发 PyTorch 算法框架/工具
【基础实操】借用torch自带网络进行训练自己的图像数据
【基础实操】借用torch自带网络进行训练自己的图像数据
24 0
【基础实操】借用torch自带网络进行训练自己的图像数据
|
3月前
|
机器学习/深度学习 编解码 计算机视觉
YOLOv5改进 | 主干篇 | CSWinTransformer交叉形窗口网络
YOLOv5改进 | 主干篇 | CSWinTransformer交叉形窗口网络
37 0
|
10天前
|
机器学习/深度学习 缓存 监控
linux查看CPU、内存、网络、磁盘IO命令
`Linux`系统中,使用`top`命令查看CPU状态,要查看CPU详细信息,可利用`cat /proc/cpuinfo`相关命令。`free`命令用于查看内存使用情况。网络相关命令包括`ifconfig`(查看网卡状态)、`ifdown/ifup`(禁用/启用网卡)、`netstat`(列出网络连接,如`-tuln`组合)以及`nslookup`、`ping`、`telnet`、`traceroute`等。磁盘IO方面,`iostat`(如`-k -p ALL`)显示磁盘IO统计,`iotop`(如`-o -d 1`)则用于查看磁盘IO瓶颈。
|
4天前
|
网络协议 算法 Linux
【Linux】深入探索:Linux网络调试、追踪与优化
【Linux】深入探索:Linux网络调试、追踪与优化