计算机视觉PyTorch实现风格迁移

简介: 计算机视觉PyTorch实现风格迁移

神经网络风格迁移


它主要是通过神经网络,将一幅艺术风格画(style image)和一张普通的照片(content image)巧妙地融合,形成一张非常有意思的图片。


大白话说,图像往往由风格与内容组成,比如我们常常说画家的画风是怎么样的,毕加索的画风、动漫的画风。

风格迁移就是保留一张图片的内容(物体,人物),用另一张图片的色彩画图风格去填充。


风格迁移原理


在介绍原理之前先普及一个知识点:


通常将图像输出到卷积神经网络中,在神经网络第一层隐藏层通常会找出一些简单的特征,比如边缘或者颜色阴影。在神经网络深层部分的一个隐藏单元会看到一张图片更大的部分,在极端的情况下,可以假设图像中每一个像素都会影响到神经网络层更深层的输出,靠后的隐藏层可以看到更大的图片块。


也就是说在神经网络中隐藏单元里从第一层的边缘到第二层的质地再到更深层的复杂物体


原理


首先我们需要获取一张内容图片和一张风格图片;然后定义二个度量,一个度量值为内容度量值,另一个度量为风格度量值,其中内容度量值通过生成代价函数来衡量二个图片之间的内容差异程度,风格度量也通过生成代价函数衡量图片之间风格差异程度,最后建立生成图像的神经网络模型,对内容图片中的内容和风格图片的风格进行提取,以内容图片为基准将其输入建立的模型中,通过代价函数梯度下降来调整内容度量值和风格度量值,让它们趋近于最小,最后输出的图片就是内容和风格融合的图片。



1.生成图像代价函数


想要生成出我们想要的图像,就需要定义一个代价函数,通过最小化代价函数,你可以生成任何图像

我们用C表示内容图像,用S表示风格图像,用G表示想要生成的图像


图像代价函数


1687221498794.png


15d6a47055c1480ea1d1864758abd1bc.png


代价函数梯度下降优化过程,生成图像的变化:



2.内容代价函数


1687221529425.png

内容代价函数


1687221569604.png


3.风格代价函数

1687221608298.png


Gram矩阵计算公式如下

1687221625519.png

风格代价函数

1687221648024.png

风格迁移算法实现


1687221672909.png

f64832264ea241afbca8cacdc3a79484.png

VGG-16网络结构如下图所示:

aa17fee8fa154ddd9326a9c39f62f89b.png

Fast Neuarl Style训练步骤如下:


1687221725427.png


PyTorch代码实现如下:


from __future__ import print_function
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,models
from PIL import Image
import matplotlib.pyplot as plt
from torch.autograd import Variable
import copy
#%% 图像预处理
transform=transforms.Compose([transforms.Scale([128,128]),
                              transforms.ToTensor()])
def loadimg(path=None):
    img=Image.open(path)
    img=transform(img)
    img=Variable(img)
    img=img.unsqueeze(0)
    return img
content_img=loadimg("3.jpg")
style_img=loadimg("1.jpg")
#%% 显示图片
unloader = transforms.ToPILImage()  # 重新转换成PIL图像
plt.ion()
def imshow(tensor, title=None):
    image = tensor.clone().cpu()  # 我们克隆张量以不对其进行修改
    image = image.view(3, 128, 128)  # 删除批量处理维度
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001) # pause a bit so that plots are updated
plt.figure()
imshow(style_img.data, title='Style Image')
plt.figure()
imshow(content_img.data, title='Content Image')
#%% 内容损失
class Content_loss(nn.Module):
    def __init__(self,target,weight):
        super(Content_loss, self).__init__()
        self.weight=weight
        self.target=target.detach()*weight
        self.loss_fn=nn.MSELoss()
    def forward(self,input):
        self.loss=self.loss_fn(input*self.weight,self.target)
        self.output = input
        return self.output
    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
#%% 风格损失
class Style_loss(nn.Module):
    def __init__(self, target, weight):
        super(Style_loss, self).__init__()
        self.target = target.detach() * weight
        self.weight = weight
        self.gram = Gram_matrix()
        self.loss_fn = nn.MSELoss()
    def forward(self, input):
        self.output = input.clone()
        self.G = self.gram(input)
        self.G.mul_(self.weight)
        self.loss = self.loss_fn(self.G, self.target)
        return self.output
    def backward(self, retain_graph=True):
        self.loss.backward(retain_graph=retain_graph)
        return self.loss
class Gram_matrix(nn.Module):
    def forward(self,input):
        a,b,c,d=input.size()
        feature=input.view(a*b,c*d)
        gram=torch.mm(feature, feature.t())
        return gram.div(a*b*c*d)
#%% 模型搭建
vgg=models.vgg19(pretrained=True).features
content_layer=["Conv_4"]
style_layer=['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
def get_style_model_and_losses(vgg, style_img, content_img,
                               style_weight=1000, content_weight=1,
                               content_layers=content_layer,
                               style_layers=style_layer):
    vgg = copy.deepcopy(vgg)
    # just in order to have an iterable access to or list of content/syle
    # losses
    content_losses = []
    style_losses = []
    model = nn.Sequential()  # the new Sequential module network
    gram = Gram_matrix() # we need a gram module in order to compute style targets
    # move these modules to the GPU if possible:
    i = 1
    for layer in list(vgg):
        if isinstance(layer, nn.Conv2d):
            name = "conv_" + str(i)
            model.add_module(name, layer)
            if name in content_layers:
                # add content loss:
                target = model(content_img).clone()
                content_loss = Content_loss(target, content_weight)
                model.add_module("content_loss_" + str(i), content_loss)
                content_losses.append(content_loss)
            if name in style_layers:
                # add style loss:
                target_feature = model(style_img).clone()
                target_feature_gram = gram(target_feature)
                style_loss = Style_loss(target_feature_gram, style_weight)
                model.add_module("style_loss_" + str(i), style_loss)
                style_losses.append(style_loss)
        if isinstance(layer, nn.ReLU):
            name = "relu_" + str(i)
            model.add_module(name, layer)
            if name in content_layers:
                # add content loss:
                target = model(content_img).clone()
                content_loss = Content_loss(target, content_weight)
                model.add_module("content_loss_" + str(i), content_loss)
                content_losses.append(content_loss)
            if name in style_layers:
                # add style loss:
                target_feature = model(style_img).clone()
                target_feature_gram = gram(target_feature)
                style_loss = Style_loss(target_feature_gram, style_weight)
                model.add_module("style_loss_" + str(i), style_loss)
                style_losses.append(style_loss)
            i += 1
        if isinstance(layer, nn.MaxPool2d):
            name = "pool_" + str(i)
            model.add_module(name, layer)  # ***
    return model, style_losses, content_losses
#%%输入图像
input_img = content_img.clone()
# if you want to use a white noise instead uncomment the below line:
# input_img = Variable(torch.randn(content_img.data.size())).type(dtype)
# add the original input image to the figue:
plt.figure()
imshow(input_img.data, title='Input Image')    
#%%梯度下降
def get_input_param_optimizer(input_img):
    # this line to show that input is a parameter that requires a gradient
    input_param = nn.Parameter(input_img.data)
    optimizer = torch.optim.LBFGS([input_param])
    return input_param, optimizer        
input_param,optimizer=get_input_param_optimizer(input_img)        
#%%参数优化
def run_style_transfer(cnn, content_img, style_img, input_img, num_steps=300,
                       style_weight=1000, content_weight=1):
    """Run the style transfer."""
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(cnn,
        style_img, content_img, style_weight, content_weight)
    input_param, optimizer = get_input_param_optimizer(input_img)
    print('Optimizing..')
    run = [0]
    while run[0] <= num_steps:
        def closure():
            # correct the values of updated input image
            input_param.data.clamp_(0, 1)
            optimizer.zero_grad()
            model(input_param)
            style_score = 0
            content_score = 0
            for sl in style_losses:
                style_score += sl.backward()
            for cl in content_losses:
                content_score += cl.backward()
            run[0] += 1
            if run[0] % 50 == 0:
                print("run {}:".format(run))
                print('Style Loss : {:5f} Content Loss: {:5f}'.format(
                    style_score, content_score))
            return style_score + content_score
        optimizer.step(closure)
    # a last correction...
    input_param.data.clamp_(0, 1)
    return input_param.data         
#%% 输出图像
output = run_style_transfer(vgg, content_img, style_img, input_img)
plt.figure()
imshow(output,title="Output Image")
# sphinx_gallery_thumbnail_number = 4
plt.ioff()
plt.show()

1b04f7565c19419e81d99b7f11aa6786.png9c9c642b0b774481af0295ac4207931d.png


bf4b30a73cf744a2bc85bf6469efce9f.png

相关文章
|
5天前
|
数据可视化 PyTorch 算法框架/工具
使用PyTorch搭建VGG模型进行图像风格迁移实战(附源码和数据集)
使用PyTorch搭建VGG模型进行图像风格迁移实战(附源码和数据集)
163 1
|
5天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch与迁移学习:利用预训练模型提升性能
【4月更文挑战第18天】PyTorch支持迁移学习,助力提升深度学习性能。预训练模型(如ResNet、VGG)在大规模数据集(如ImageNet)训练后,可在新任务中加速训练,提高准确率。通过选择模型、加载预训练权重、修改结构和微调,可适应不同任务需求。迁移学习节省资源,但也需考虑源任务与目标任务的相似度及超参数选择。实践案例显示,预训练模型能有效提升小数据集上的图像分类任务性能。未来,迁移学习将继续在深度学习领域发挥重要作用。
|
5天前
|
机器学习/深度学习 PyTorch 语音技术
Pytorch迁移学习使用Resnet50进行模型训练预测猫狗二分类
深度学习在图像分类、目标检测、语音识别等领域取得了重大突破,但是随着网络层数的增加,梯度消失和梯度爆炸问题逐渐凸显。随着层数的增加,梯度信息在反向传播过程中逐渐变小,导致网络难以收敛。同时,梯度爆炸问题也会导致网络的参数更新过大,无法正常收敛。 为了解决这些问题,ResNet提出了一个创新的思路:引入残差块(Residual Block)。残差块的设计允许网络学习残差映射,从而减轻了梯度消失问题,使得网络更容易训练。
123 0
|
9月前
|
机器学习/深度学习 并行计算 PyTorch
迁移学习的 PyTorch 实现
迁移学习的 PyTorch 实现
|
5天前
|
机器学习/深度学习 PyTorch 调度
迁移学习的 PyTorch 实现
迁移学习的 PyTorch 实现
|
11月前
|
机器学习/深度学习 PyTorch 算法框架/工具
计算机视觉PyTorch迁移学习 - (二)
计算机视觉PyTorch迁移学习 - (二)
|
11月前
|
机器学习/深度学习 数据采集 PyTorch
计算机视觉PyTorch迁移学习 - (一)
计算机视觉PyTorch迁移学习 - (一)
|
11月前
|
数据采集 XML 数据挖掘
计算机视觉PyTorch - 数据处理(库数据和训练自己的数据)
计算机视觉PyTorch - 数据处理(库数据和训练自己的数据)
|
11月前
|
算法 PyTorch 算法框架/工具
计算机视觉PyTorch实现图像着色 - (二)
计算机视觉PyTorch实现图像着色 - (二)
131 0
计算机视觉PyTorch实现图像着色 - (二)
|
11月前
|
PyTorch 算法框架/工具 计算机视觉
计算机视觉PyTorch实现图像着色 - (一)
计算机视觉PyTorch实现图像着色 - (一)