TensorFlow实现多输入源多输出

简介: TensorFlow实现多输入源多输出

TensorFlow实现多输入多输出模型

有时我们的输入数据不只一个,会存在多个输入源,多个输出源,对于这种情况我们使用Sequential显然是不行的,因为Sequential只能够搭建线性拓扑模型,对于那种流水线型的模型较为适合,如果是非线性拓扑,复杂的拓扑使用Sequential是不能够实现的,这是我们就需要使用Function API,它会使我们处理多输入多输出变得简单。

例如,如果您要构建一个系统,该系统按照优先级对自定义问题工单进行排序,然后将工单传送到正确的部门,则此模型将具有三个输入:

  • 工单标题(文本输入),
  • 工单的文本正文(文本输入),以及
  • 用户添加的任何标签(分类输入)

此模型将具有两个输出:

  • 介于 0 和 1 之间的优先级分数(标量 Sigmoid 输出),以及
  • 应该处理工单的部门(部门范围内的 Softmax 输出)。

您可以使用函数式 API 通过几行代码构建此模型:

定义输入源

模型的数据输入源有三个:标题输入、正文输入、标签输入

# 输入源
title_input = keras.Input(shape=(None,), name='title')
body_input = keras.Input(shape=(None,), name='body')
tags_input = keras.Input(shape=(num_tags,), name='tags')

定义层之间关系

首先需要将标题和正文的文本进行Embedding处理,然后使用长短期网络LSTM进行处理,然后将处理后的特征向量进行维度拼接,将三个输入源的特征向量进行拼接

title_features = Embedding(num_words, 64)(title_input)
body_features = Embedding(num_words, 64)(body_input)
title_features = LSTM(128)(title_features)
body_features = LSTM(128)(body_features)
x = Concatenate(axis=1)([title_features, body_features, tags_input])

定义输出源

定义两个输出源,第一个是预测优先级,所以Dense为1,另外一个是预测部分,需要使用softmax,所以维度是部门的数量

# 输出源
priority_pred = Dense(1, name='priority')(x)
department_pred = Dense(num_departments, name='department')(x)

构建模型

只需要指定模型的输入和输出,然后keras.Model会自动根据输入和输出之间的关系搭建出计算图拓扑

model = keras.Model(
    inputs=[title_input, body_input, tags_input],
    outputs=[priority_pred, department_pred]
)

编译模型

为不同的输出指定不同的损失函数,并且赋值不同的损失权重

model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss={
        'priority': keras.losses.BinaryCrossentropy(from_logits=True),
        'department': keras.losses.CategoricalCrossentropy(from_logits=True)
    },
    loss_weights=[1.0, 0.2]
)

定义数据集

title_data = np.random.randint(num_words, size=(1280, 10))
body_data = np.random.randint(num_words, size=(1280, 100))
tags_data = np.random.randint(2, size=(1280, num_tags)).astype('float32')
priority_targets = np.random.random(size=(1280, 1))
department_targets = np.random.randint(2, size=(1280, num_departments))

训练模型

model.fit(
    {'title': title_data, 'body': body_data, 'tags': tags_data},
    {'priority': priority_targets, 'department': department_targets},
    epochs=2,
    batch_size=32
)

完整代码

"""
 * Created with PyCharm
 * 作者: 阿光
 * 日期: 2022/1/1
 * 时间: 19:32
 * 描述:
"""
import numpy as np
from tensorflow import keras
from tensorflow.keras.layers import *
num_tags = 12
num_words = 10000
num_departments = 4
# 输入源
title_input = keras.Input(shape=(None,), name='title')
body_input = keras.Input(shape=(None,), name='body')
tags_input = keras.Input(shape=(num_tags,), name='tags')
title_features = Embedding(num_words, 64)(title_input)
body_features = Embedding(num_words, 64)(body_input)
title_features = LSTM(128)(title_features)
body_features = LSTM(128)(body_features)
x = Concatenate(axis=1)([title_features, body_features, tags_input])
# 输出源
priority_pred = Dense(1, name='priority')(x)
department_pred = Dense(num_departments, name='department')(x)
model = keras.Model(
    inputs=[title_input, body_input, tags_input],
    outputs=[priority_pred, department_pred]
)
keras.utils.plot_model(model, "multi_input_and_output_model.png", show_shapes=True)
model.compile(
    optimizer=keras.optimizers.RMSprop(1e-3),
    loss={
        'priority': keras.losses.BinaryCrossentropy(from_logits=True),
        'department': keras.losses.CategoricalCrossentropy(from_logits=True)
    },
    loss_weights=[1.0, 0.2]
)
title_data = np.random.randint(num_words, size=(1280, 10))
body_data = np.random.randint(num_words, size=(1280, 100))
tags_data = np.random.randint(2, size=(1280, num_tags)).astype('float32')
priority_targets = np.random.random(size=(1280, 1))
department_targets = np.random.randint(2, size=(1280, num_departments))
model.fit(
    {'title': title_data, 'body': body_data, 'tags': tags_data},
    {'priority': priority_targets, 'department': department_targets},
    epochs=2,
    batch_size=32
)


目录
相关文章
|
传感器 数据采集 算法
嵌入式系统中的实时数据处理与优化
嵌入式系统中的实时数据处理与优化
343 0
嵌入式系统中的实时数据处理与优化
|
算法 JavaScript 前端开发
开源项目推荐:CNC+CRC/SoftPLC/OpenCASCADE/CAD/CAM(三)
开源项目推荐:CNC+CRC/SoftPLC/OpenCASCADE/CAD/CAM
3809 1
开源项目推荐:CNC+CRC/SoftPLC/OpenCASCADE/CAD/CAM(三)
|
5月前
|
数据采集 存储 监控
星河中的数据旅程:从普通字段到核心指标 -- 基于Dataphin的数据源资产全链路管理
在数据星河中,Starrocks星球的字段居民渴望登上资产管理平台,贡献数据力量。通过元数据采集、标准稽核与质量监控,字段们获得新身份“核心业务指标”。借助Dataphin平台功能,如自定义属性和QuickBI对接,它们最终参与经营分析报表,助力决策。Dataphin V4.4提升了全链路管理能力,新增大数据存储元数据采集、自定义指标等功能,释放数据潜力。加入Dataphin,探索数据无限可能!
162 8
|
5月前
|
存储 监控 安全
攻击者是如何利用安全支持提供程序(SSP)来转储凭据的
本文探讨了攻击者如何利用安全支持提供程序(SSP)动态链接库(DLL)窃取Windows系统中的登录凭据。通过修改注册表项或内存注入技术,攻击者可加载恶意SSP至本地安全机构(LSA)进程中,提取加密或明文密码。文章详细分析了两种方法:注册SSP DLL和内存中更新SSP,并展示了Mimikatz工具的应用。为防范此类攻击,建议使用监控解决方案检测域控制器上的异常修改,确保系统安全。
192 8
|
开发框架 监控 物联网
【Uniapp 专栏】探索 Uniapp 开发的更高级应用场景
【5月更文挑战第17天】Uniapp作为跨平台开发框架,在物联网、实时数据监控、企业级应用、地理定位和教育、电商领域展现出广泛应用潜力。通过蓝牙连接智能家居,实时展示数据变化,构建复杂业务流程,定位服务及互动学习平台,它提供了创新解决方案。随着技术发展,Uniapp将继续为开发者创造更多机遇和挑战,推动移动应用领域的前进。
425 0
【Uniapp 专栏】探索 Uniapp 开发的更高级应用场景
|
7月前
|
人工智能 监控 安全
从 DeepSeek 敏感信息泄露谈可观测系统的数据安全预防
本文将探讨 SLS 中增强数据安全的几种方式:权限精细化管控有效减少了潜在安全风险;接入层脱敏技术阻止敏感数据落库,提升了隐私保护;StoreView 字段集控制通过限制查询数据范围,降低数据泄露损害。智能监控系统提供实时监测,快速识别并阻断异常拖库行为,为企业提供了迅速响应和抵御威胁的能力。
|
11月前
|
XML 前端开发 测试技术
Postman
Postman是一款功能强大的API开发和测试工具,被广泛应用于软件开发的各个阶段
434 57
|
11月前
|
缓存 Linux 开发者
Avalonia开源控件库强力推荐-Semi.Avalonia
【11月更文挑战第3天】Semi.Avalonia 是一个基于 Avalonia 的开源控件库,提供了丰富的自定义控件和扩展功能。它支持多种样式按钮、高级输入控件和灵活的布局容器,简化了属性设置,并提供了详细的文档支持。Semi.Avalonia 还支持多种内置主题和自定义主题,具备高效的渲染机制和合理的资源管理,适用于跨平台桌面应用程序开发。
723 2
|
JavaScript 前端开发 Java
for循环、break和continue、二重循环
【10月更文挑战第12天】这段内容介绍了编程中的 `for` 循环,包括基本概念、应用场景以及 `break` 和 `continue` 语句的使用方法。`for` 循环是一种常用的流程控制语句,用于重复执行一段代码。文中通过不同语言的示例说明了如何遍历数组、计算数值和创建矩阵等。此外,还介绍了二重循环的概念及其在处理二维数据结构中的应用。
347 1
|
缓存 NoSQL 应用服务中间件
【开发系列】秒杀系统的设计
【开发系列】秒杀系统的设计