AI实战 | Tensorflow自定义数据集和迁移学习(附代码下载)

简介: AI实战 | Tensorflow自定义数据集和迁移学习(附代码下载)

自定义数据集


做深度学习项目时,我们一般都不用网上公开的数据集,而是用自己制作的数据集。那么,怎么用Tensorflow2.0来制作自己的数据集并把数据喂给神经网络呢?且看这篇文章慢慢道来。

Pokemon Datasets


这篇文章我们用的datasets是Pokemon datasets,也就是皮卡丘电影中的一些角色,如下图所示:

ec319a1d60ae94e1449a4bda0a0607b5.png

数据集下载


链接: https://pan.baidu.com/s/1V_ZJ7ufjUUFZwD2NHSNMFw

提取码:dsxl

数据集划分


image.png

由上图可知,60%的数据集用来train,20%的数据集用来validation,同样20%用来test

四个步骤


  • Load data:加载数据
  • Build model:建立模型
  • Train-Val-Test:训练和测试
  • Transfer Learning:迁移模型

加载数据


981d4663dcbce14f273858bb71f349fc.png

首先对数据进行预处理,把像素值的Numpy类型转换为Tensor类型,并归一化到[0~1]。把数据集的标签做one-hot编码。

def preprocess(x,y):
    # x: 图片的路径,y:图片的数字编码
    x = tf.io.read_file(x)
    x = tf.image.decode_jpeg(x, channels=3) # RGBA
    x = tf.image.resize(x, [244, 244])
    return x, y

数据集标准处理流程


代码中load_pokemon用的是自己的数据集写的代码,具体可阅读pokemon.py文件。

# 创建训练集Datset对象
images, labels, table = load_pokemon('pokemon',mode='train')
db_train = tf.data.Dataset.from_tensor_slices((images, labels))
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz)
# 创建验证集Datset对象
images2, labels2, table = load_pokemon('pokemon',mode='val')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2))
db_val = db_val.map(preprocess).batch(batchsz)
# 创建测试集Datset对象
images3, labels3, table = load_pokemon('pokemon',mode='test')
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3))
db_test = db_test.map(preprocess).batch(batchsz)

图片数据增强及标准化


7f55cb46a4356637adc42b67e20071e0.png

一般数据集较少的话需要使用数据增强以增加数据集,防止训练网络过拟合。比如旋转角度、裁剪等,并归一化到[0~1]。把数据集的标签做one-hot编码。所示代码如下:

# x = tf.image.random_flip_left_right(x)
    x = tf.image.random_flip_up_down(x)
    x = tf.image.random_crop(x, [224,224,3])
    # x: [0,255]=> -1~1
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalize(x)
    y = tf.convert_to_tensor(y)
    y = tf.one_hot(y, depth=5)

建立网络


e18319e4e3c66f3cd114dd9389c16dc4.png

神经网络从零开始训练,backbone用李沐大神的resnet网络。详细代码请查看resnet.py文件。部分代码如下:

class ResNet(keras.Model):
    def __init__(self, num_classes, initial_filters=16, **kwargs):
        super(ResNet, self).__init__(**kwargs)
        self.stem = layers.Conv2D(initial_filters, 3, strides=3, padding='valid')
        self.blocks = keras.models.Sequential([
            ResnetBlock(initial_filters * 2, strides=3),
            ResnetBlock(initial_filters * 2, strides=1),
            # layers.Dropout(rate=0.5),
            ResnetBlock(initial_filters * 4, strides=3),
            ResnetBlock(initial_filters * 4, strides=1),
            ResnetBlock(initial_filters * 8, strides=2),
            ResnetBlock(initial_filters * 8, strides=1),
            ResnetBlock(initial_filters * 16, strides=2),
            ResnetBlock(initial_filters * 16, strides=1),
        ])
        self.final_bn = layers.BatchNormalization()
        self.avg_pool = layers.GlobalMaxPool2D()
        self.fc = layers.Dense(num_classes)
    def call(self, inputs, training=None):
        # print('x:',inputs.shape)
        out = self.stem(inputs,training=training)
        out = tf.nn.relu(out)
        # print('stem:',out.shape)
        out = self.blocks(out, training=training)
        # print('res:',out.shape)
        out = self.final_bn(out, training=training)
        # out = tf.nn.relu(out)
        out = self.avg_pool(out)
        # print('avg_pool:',out.shape)
        out = self.fc(out)
        # print('out:',out.shape)
        return out

训练和测试


9143e2f8bf5e3c1300eca217a58bf971.png

部分代码如下:

resnet = keras.Sequential([
    layers.Conv2D(16,5,3),
    layers.MaxPool2D(3,3),
    layers.ReLU(),
    layers.Conv2D(64,5,3),
    layers.MaxPool2D(2,2),
    layers.ReLU(),
    layers.Flatten(),
    layers.Dense(64),
    layers.ReLU(),
    layers.Dense(5)
])
resnet = ResNet(5)
resnet.build(input_shape=(4, 224, 224, 3))
resnet.summary()
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    min_delta=0.001,
    patience=5
)
resnet.compile(optimizer=optimizers.Adam(lr=1e-3),
               loss=losses.CategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])
resnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
           callbacks=[early_stopping])
resnet.evaluate(db_test)

迁移网络学习


网络可以丛零开始训练,也可以从别的训练好的参数模型迁移过来,本次实战用Tensorflow预训练的vgg19模型来加载训练,从而加快训练过程。

迁移学习的原理如下图所示:

158e6a21e2674d9c17fd8eb5230b4784.png

4f42abe715993c8479f9bd37da0d7c8a.png

4f42abe715993c8479f9bd37da0d7c8a.png

部分代码如下:

net = keras.applications.VGG19(weights='imagenet', include_top=False,
                               pooling='max')
net.trainable = False
newnet = keras.Sequential([
    net,
    layers.Dense(5)
])
newnet.build(input_shape=(4,224,224,3))
newnet.summary()
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    min_delta=0.001,
    patience=5
)
newnet.compile(optimizer=optimizers.Adam(lr=1e-3),
               loss=losses.CategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])
newnet.fit(db_train, validation_data=db_val, validation_freq=1, epochs=100,
           callbacks=[early_stopping])
newnet.evaluate(db_test)

代码下载


本篇文章完整代码在公众号对话框回复 “pokemon” 就可得到百度云链接,建议直接复制再去公众号回复。

参考资料


本篇文章主要参考网易云课堂龙龙老师的《深度学习与TensorFlow 2入门实战》

课程链接:https://study.163.com/course/courseMain.htm?courseId=1209092816&share=1&shareId=1026182418



相关文章
|
1月前
|
人工智能 机器人 UED
不怕不会设计logo拉-本篇教你如何使用AI设计logo-如何快速用AI设计logo-附上AI绘图logo设计的咒语-优雅草央千澈-实战教程
不怕不会设计logo拉-本篇教你如何使用AI设计logo-如何快速用AI设计logo-附上AI绘图logo设计的咒语-优雅草央千澈-实战教程
147 86
不怕不会设计logo拉-本篇教你如何使用AI设计logo-如何快速用AI设计logo-附上AI绘图logo设计的咒语-优雅草央千澈-实战教程
|
2月前
|
机器学习/深度学习 人工智能 物联网
AI赋能大学计划·大模型技术与应用实战学生训练营——湖南大学站圆满结营
12月14日,由中国软件行业校园招聘与实习公共服务平台携手魔搭社区共同举办的AI赋能大学计划·大模型技术与产业趋势高校行AIGC项目实战营·湖南大学站圆满结营。
AI赋能大学计划·大模型技术与应用实战学生训练营——湖南大学站圆满结营
|
1月前
|
人工智能 数据处理 语音技术
Pipecat实战:5步快速构建语音与AI整合项目,创建你的第一个多模态语音 AI 助手
Pipecat 是一个开源的 Python 框架,专注于构建语音和多模态对话代理,支持与多种 AI 服务集成,提供实时处理能力,适用于语音助手、企业服务等场景。
108 23
Pipecat实战:5步快速构建语音与AI整合项目,创建你的第一个多模态语音 AI 助手
|
13天前
|
人工智能 IDE 程序员
通义灵码 2.0 AI 程序员下载安装
通义灵码2.0 AI程序员支持JetBrains IDEs、Visual Studio Code及远程开发场景,暂不支持Visual Studio。用户可通过插件市场搜索“TONGYI Lingma”安装,确保版本升级至2.0以上。安装后登录阿里云账号即可使用,个人版和企业版均免费。新手可参考官方指南进行IDE安装配置。
582 9
|
2月前
|
机器学习/深度学习 人工智能 JSON
【实战干货】AI大模型工程应用于车联网场景的实战总结
本文介绍了图像生成技术在AIGC领域的发展历程、关键技术和当前趋势,以及这些技术如何应用于新能源汽车行业的车联网服务中。
576 37
|
2月前
|
机器学习/深度学习 人工智能 物联网
AI赋能大学计划·大模型技术与应用实战学生训练营——电子科技大学站圆满结营
12月05日,由中国软件行业校园招聘与实习公共服务平台携手阿里魔搭社区共同举办的AI赋能大学计划·大模型技术与产业趋势高校行AIGC项目实战营·电子科技大学站圆满结营。
AI赋能大学计划·大模型技术与应用实战学生训练营——电子科技大学站圆满结营
|
28天前
|
存储 人工智能 自然语言处理
AI 工程学习 - 三张图说明白什么是 RAG
RAG(检索增强生成)是一种结合信息检索和生成模型的自然语言处理框架,通过引入外部知识库(如文档库、数据库等),增强生成模型的回答准确性与相关性。其核心在于避免模型仅依赖训练数据产生不准确或“幻觉”内容,而是通过实时检索外部资料,确保回答更精准、丰富且上下文相关。RAG的实现包括建立索引(清洗、分割、嵌入存储)和检索生成(计算相似度、选择最优片段、整合提示词模板提交给大模型)。
123 0
|
3月前
|
人工智能 自然语言处理 前端开发
VideoChat:高效学习新神器!一键解读音视频内容,结合 AI 生成总结内容、思维导图和智能问答
VideoChat 是一款智能音视频内容解读助手,支持批量上传音视频文件并自动转录为文字。通过 AI 技术,它能快速生成内容总结、详细解读和思维导图,并提供智能对话功能,帮助用户更高效地理解和分析音视频内容。
212 6
VideoChat:高效学习新神器!一键解读音视频内容,结合 AI 生成总结内容、思维导图和智能问答
|
2月前
|
人工智能 自然语言处理 算法
AI时代的企业内训全景图:从案例到实战
作为一名扎根在HR培训领域多年的“老兵”,我越来越清晰地感受到,企业内训的本质其实是为企业持续“造血”。无论是基础岗的新人培训、技能岗的操作规范培训,还是面向技术中坚力量的高阶技术研讨,抑或是管理层的战略思维提升课,内训的价值都是在帮助企业内部提升能力水平,进而提高组织生产力,减少对外部资源的依赖。更为重要的是,在当前AI、大模型、Embodied Intelligence等新兴技术快速迭代的背景下,企业必须不断为人才升级赋能,才能在市场竞争中保持领先。
|
3月前
|
机器学习/深度学习 人工智能 自然语言处理
AI驱动的个性化学习路径优化
在当前教育领域,个性化学习正逐渐成为一种趋势。本文探讨了如何利用人工智能技术来优化个性化学习路径,提高学习效率和质量。通过分析学生的学习行为、偏好和表现,AI可以动态调整学习内容和难度,实现真正的因材施教。文章还讨论了实施这种技术所面临的挑战和潜在的解决方案。
180 7

热门文章

最新文章