使用PyTorch从零开始构建Elman循环神经网络

简介: 循环神经网络是如何工作的?如何构建一个Elman循环神经网络?在这里,教你手把手创建一个Elman循环神经网络进行简单的序列预测。

本文以最简单的RNNs模型为例:Elman循环神经网络,讲述循环神经网络的工作原理即便是你没有太多循环神经网络(RNNs)的基础知识,也可以很容易的理解。为了让你更好的理解RNNs我们使用Pytorch张量包和autograd从头开始构建Elman循环神经网络。该文中完整代码在Github上是可实现的。

在这里,假设你对前馈神经网络略有了解。Pytorchautograd更为详细的内容请查看我的其他教程

68363d99d874dfb5a2f53084e522c73a9f5f34cd

Elman循环神经网络

Jeff Elman首次提出了Elman循环神经网络,并发表在论文《Finding structure in time》中:它只是一个三层前馈神经网络,输入层由一个输入神经元x1一组上下文神经元单元{c1 ... cn}组成隐藏层前一时间步的神经元作为上下文神经元的输入在隐藏层中每个神经元有一个上下文神经元。由于前一时间步的状态作为输入的一部分,因此我们可以说Elman循环神经网络拥有一定的内存——上下文神经元代表一个内存。

预测正弦波

现在,我们来训练RNNs学习正弦函数。在训练过程中,一次为模型提供一个数据,这就是为什么我们只需要一个输入神经元x1,并且我们希望在下一时间步预测该值。输入序列x20个数据组成,并且目标序列与输入序列相同。

5a336e128286628cf86977c17783f34d70d06e35 

模型实现

首先导入包。

e5048e42574aeaea8f76d0af60084ea0f24d3c9c 

接下来,设置模型的超参数。设置输入层的大小为76个上下文神经元和1个输入神经元),seq_length用来定义输入和目标序列的长度。

c8be968e1760e5b306d6e3a5bec48b5d46c12026 

生成训练数据:x是输入序列,y是目标序列。

6fa08fc0e07c757d55b30fd33e064bb0347ae3b7 

创建两个权重矩阵。大小为(input_sizehidden_size)的矩阵w1用于隐藏连接的输入,大小为(hidden_sizeoutput_size)的矩阵w2用于隐藏连接的输出。 用零均值的正态分布对权重矩阵进行初始化。

d8a4877fbd8b331ef7755cc0a7b4b1cec27df859 

定义forward方法,其参数为input向量、context_state向量和两个权重矩阵,连接inputcontext_state创建xh向量。对xh向量和权重矩阵w1执行点积运算,然后用tanh函数作为非线性函数,在RNNstanhsigmoid效果要好。 然后对新的context_state和权重矩阵w2再次执行点积运算。 我们想要预测连续值,因此这个阶段不使用任何非线性。

请注意,context_state向量将在下一时间步填充上下文神经元。 这就是为什么我们要返回context_state向量和out

4ca6517a08cbd409471972aab15631a2447a6b98 

训练

训练循环的结构如下:

1.外循环遍历每个epochepoch被定义为所有的训练数据全部通过训练网络一次。在每个epoch开始时,将context_state向量初始化为0

2.内部循环遍历序列中的每个元素。执行forward方法进行正向传递,该方法返回predcontext_state,将用于下一个时间步。然后计算均方误差(MSE)用于预测连续值。执行backward()方法计算梯度,然后更新权重w1w2。每次迭代中调用zero_()方法清除梯度,否则梯度将会累计起来。最后将context_state向量包装放到新变量中,以将其与历史值分离开来。

4085d3f491eab971b4f5612b58d3e9d11b68b7b4 

训练期间产生的输出显示了每个epoch的损失是如何减少的,这是一个好的衡量方式。损失的逐渐减少则意味着我们的模型正在学习。

23b7ff14535e2eb453038bfc222503c1cec5cc26 

预测

一旦模型训练完毕,我们就可以进行预测。在序列的每一步我们只为模型提供一个数据,并要求模型在下一个步预测一个值。

f25e66a8ab5e47ae383972b9cc3a80e64420e433 

预测结果如下图所示:黄色圆点表示预测值,蓝色圆点表示实际值,二者基本吻合,因此模型的预测效果非常好。

a16cb76e976c4b32d1db8e55780e5a18816f67dd

结论

在这里,我们使用了Pytorch从零开始构建一个基本的RNNs模型,并且学习了如何将RNNs应用于简单的序列预测问题。


数十款阿里云产品限时折扣中,赶紧点击领劵开始云上实践吧! 

以上为译文。

本文由北邮@爱可可-爱生活 老师推荐,阿里云云栖社区组织翻译。

文章原标题《Introduction to Recurrent Neural Networks in Pytorch》,译者:Mags,审校:袁虎。

文章为简译,更为详细的内容,请查看原文 

 

相关文章
|
5天前
|
安全 网络安全 API
构建高效微服务架构的五大关键策略网络安全与信息安全:防范网络威胁的关键策略
【5月更文挑战第31天】 在现代软件开发领域,微服务架构已经成为实现灵活、可扩展及容错系统的重要解决方案。本文将深入探讨构建高效微服务架构的五个核心策略:服务划分原则、API网关设计、服务发现与注册、熔断机制以及持续集成与部署。这些策略不仅有助于开发团队提升系统的可维护性和可伸缩性,同时也确保了高可用性和服务质量。通过实践案例和性能分析,我们将展示如何有效应用这些策略以提高微服务的性能和稳定性。
|
5天前
|
机器学习/深度学习 计算机视觉
【YOLOv8改进】 YOLOv8 更换骨干网络之GhostNetV2 长距离注意力机制增强廉价操作,构建更强端侧轻量型骨干 (论文笔记+引入代码)
该专栏聚焦YOLO目标检测的创新改进与实战,介绍了轻量级CNNs和注意力机制在移动设备上的应用。文章提出了一种名为GhostNetV2的新架构,结合了硬件友好的DFC注意力机制,强化了特征表达能力和全局信息捕获,同时保持低计算成本和高效推理。GhostNetV2在ImageNet上以167M FLOPs达到75.3%的top-1准确率,优于同类模型。创新点包括DFC注意力、模型结构优化和效率提升。源代码可在GitHub和MindSpore平台上找到。此外,还提到了YOLOv8的相关实现和任务配置。
|
5天前
|
JSON Android开发 开发者
构建高效Android应用:采用Kotlin协程优化网络请求
【5月更文挑战第31天】 在移动开发领域,尤其是针对Android平台,网络请求的管理和性能优化一直是开发者关注的焦点。随着Kotlin语言的普及,其提供的协程特性为异步编程提供了全新的解决方案。本文将深入探讨如何利用Kotlin协程来优化Android应用中的网络请求,从而提升应用的响应速度和用户体验。我们将通过具体实例分析协程与传统异步处理方式的差异,并展示如何在现有项目中集成协程进行网络请求优化。
|
6天前
|
人工智能 自然语言处理 安全
构建未来:AI驱动的自适应网络安全防御系统提升软件测试效率:自动化与持续集成的实践之路
【5月更文挑战第30天】 在数字化时代,网络安全已成为维护信息完整性、保障用户隐私和企业持续运营的关键。传统的安全防御手段,如防火墙和入侵检测系统,面对日益复杂的网络攻击已显得力不从心。本文提出了一种基于人工智能(AI)技术的自适应网络安全防御系统,该系统能够实时分析网络流量,自动识别潜在威胁,并动态调整防御策略以应对未知攻击。通过深度学习算法和自然语言处理技术的结合,系统不仅能够提高检测速度和准确性,还能自主学习和适应新型攻击模式,从而显著提升网络安全防御的效率和智能化水平。 【5月更文挑战第30天】 在快速迭代的软件开发周期中,传统的手动测试方法已不再适应现代高效交付的要求。本文探讨了如
|
7天前
|
机器学习/深度学习 安全 网络安全
利用机器学习优化数据中心能效的研究数字堡垒的构建者:网络安全与信息安全的深层探索
【5月更文挑战第29天】在云计算和大数据时代,数据中心的能效问题成为关键挑战之一。本文通过集成机器学习技术与现有数据中心管理策略,提出了一种新型的智能优化框架。该框架能够实时分析数据中心的能耗模式,并自动调整资源分配,以达到降低能耗的目的。研究结果表明,应用机器学习算法可以显著提升数据中心的能源使用效率,同时保持服务质量。
|
7天前
|
监控 安全 PHP
构建安全防线:在云计算时代维护网络与信息安全深入理解PHP的命名空间与自动加载机制
【5月更文挑战第29天】 随着企业和个人日益依赖云服务,云计算的便捷性和成本效益已得到广泛认可。然而,数据存储和处理的云端迁移也带来了新的挑战,尤其是网络安全和信息保护方面的问题。本文将深入探讨云计算环境中的安全威胁,包括数据泄露、不正当访问和服务中断等,以及为应对这些威胁而采取的策略和技术措施。我们将重点讨论最新的加密技术、身份验证机制、入侵检测系统(IDS)以及合规性监控工具,它们共同构成了维护云服务安全的多层次防御体系。 【5月更文挑战第29天】在PHP的世界中,命名空间和自动加载机制是现代PHP应用程序开发中不可或缺的组成部分。它们不仅解决了代码重用性和依赖管理的问题,而且促进了更为
|
7天前
|
机器学习/深度学习 人工智能 安全
构建未来:AI驱动的自适应网络安全防御系统
【5月更文挑战第29天】 随着网络攻击手段的不断演变和升级,传统的基于特征的安全防御机制已不再能够有效地应对日益复杂的安全威胁。本文探讨了如何通过集成人工智能(AI)技术来构建一个自适应的网络安全防御系统,该系统能够在不断变化的网络环境中学习、预测并主动防御未知威胁。通过深度学习算法、实时数据分析和自动化响应策略,这种新型系统旨在提高企业级网络安全的智能化水平,减少人为干预,同时提升防御效率和准确性。
|
8天前
|
云安全 安全 网络安全
云端防御:在云服务中构建网络安全的长城
【5月更文挑战第28天】 随着企业纷纷迁移至云端,云计算服务的安全性成为维护信息安全的关键。本文深入探讨了在云环境中实现网络安全的策略和技术,分析了云服务模型(IaaS、PaaS、SaaS)的安全挑战,并提出了一系列针对数据保护、身份验证、访问控制和威胁监测的解决方案。通过综合运用加密技术、多因素认证、入侵检测系统等手段,文章旨在为读者提供一套全面的云安全架构设计指南。
|
8天前
|
机器学习/深度学习 人工智能 算法
利用深度学习技术优化图像识别准确性网络堡垒的构建者:深入网络安全与信息保护策略
【5月更文挑战第28天】 随着人工智能的不断发展,图像识别作为其重要分支之一,在多个领域内得到了广泛应用。然而,识别准确性的提升一直是该领域的研究重点。本文通过引入深度学习技术,构建了一个多层次的卷积神经网络模型,用于提升图像识别的准确性。文中详细阐述了模型的结构设计、训练过程以及参数调优策略,并通过实验验证了所提出方法的有效性。结果表明,与传统图像识别方法相比,深度学习技术能显著提高识别精度,并具有较强的泛化能力。
|
9天前
|
监控 安全 网络安全
网络安全与信息安全:防护之道与加密技术构建高效Android应用:从基础到高级的内存优化策略
【5月更文挑战第27天】在数字化时代,数据成为了新的货币。然而,随着信息技术的蓬勃发展,网络安全漏洞和信息泄露事件层出不穷,对个人隐私和企业安全构成了严重威胁。本文将深入探讨网络安全的重要性,分析当前常见的网络攻击方式,并重点分享关于加密技术和提升安全意识的知识。通过阅读本文,读者将获得如何有效防御网络威胁、保护个人和企业信息安全的策略。