LSTM应用场景以及pytorch实例
在去年介绍的一篇paper中,应用了多任务RNN来解决问题,当时RNN指的即是LSTM。本文介绍LSTM实现以及应用。
1. LSTM简介 #
循环神经网络要点在于可以将上一时刻的信息传递给下一时刻,但是在需要长程信息依赖的场景,训练一个好的RNN十分困难,存在梯度爆炸和梯度消失的情况。LSTM通过刻意的设计来解决该问题。
简单的RNN网络中重复的模块只有一个简单的结构,例如一个relu
层,而在LSTM中重复的模块拥有4个不同的结构相互交互来完成。
1.1 首先决定从cell中丢弃什么信息 #
$$f_t = \sigma(W_f*[h_{t-1}, X_t] + b_f) \tag1$$ sigma函数在0到1选择代表丢弃与否
1.2 什么样的新信息存放到cell中 #
$$i_t = \sigma(W_i*[h_{t-1}, x_t] + b_i) \tag2$$
$$\widetilde{C_t} = tanh(W_c*[h_{t-1}, x_t] + b_c) \tag3$$
$$C_t = f_t*C_{t-1} + {i_t} * \widetilde{C_{t}} \tag4$$
4式中旧状态与$f_t$相乘,丢弃确定需要丢弃的信息,加上新的候选值。可以看到假如遗忘门一直为1,就可以保持以前的信息$C_{t-1}$
1.3 输出结果 #
$$o_t = \sigma(W_o[h_{t-1}, x_t] + b_o)\tag5$$ $$h_t = o_t*tanh(C_t)\tag6$$
2. LSTM实例以及Pytorch实现 #
循环神经网络可以应用到以下场景。
- 点对点(单个图片(文字)被分类;图像分类)
- 点对序列(单个图像(文字)被分为多个类;图像输出文字)
- 序列分析(一系列图片(文字)被分类;情感分析)
- 不等长序列对序列(机器翻译)
- 等长序列对序列(视频帧分类)
举两个例子:图像分类以及时间序列预测
2.1 LSTM图像分类 #
关于图片分类常用卷积神经网络,侧重空间上处理;而循环神经网络侧重序列处理。但是也能用来图片分类。第一个例子以常用的mnist手写字体识别为例。
2.1.1 导入所需用到的包以及超参数设置等 #
# Setup
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision.datasets as dsets
import torchvision.transforms as transforms
torch.manual_seed(1)
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
2.1.2 导入数据集 #
# Mnist手写数字
train_data = dsets.MNIST(root='./mnist/', # 保存或者提取位置
train=True, # this is tra`ining data
transform=transforms.ToTensor(), # 转换 PIL.Image or numpy.ndarray 成
# torch.FloatTensor (C x H...
剩余内容已隐藏
LSTM应用场景以及pytorch实例
在去年介绍的一篇paper中,应用了多任务RNN来解决问题,当时RNN指的即是LSTM。本文介绍LSTM实现以及应用。
1. LSTM简介 #
循环神经网络要点在于可以将上一时刻的信息传递给下一时刻,但是在需要长程信息依赖的场景,训练一个好的RNN十分困难,存在梯度爆炸和梯度消失的情况。LSTM通过刻意的设计来解决该问题。
简单的RNN网络中重复的模块只有一个简单的结构,例如一个relu
层,而在LSTM中重复的模块拥有4个不同的结构相互交互来完成。
1.1 首先决定从cell中丢弃什么信息 #
$$f_t = \sigma(W_f*[h_{t-1}, X_t] + b_f) \tag1$$ sigma函数在0到1选择代表丢弃与否
1.2 什么样的新信息存放到cell中 #
$$i_t = \sigma(W_i*[h_{t-1}, x_t] + b_i) \tag2$$
$$\widetilde{C_t} = tanh(W_c*[h_{t-1}, x_t] + b_c) \tag3$$
$$C_t = f_t*C_{t-1} + {i_t} * \widetilde{C_{t}} \tag4$$
4式中旧状态与$f_t$相乘,丢弃确定需要丢弃的信息,加上新的候选值。可以看到假如遗忘门一直为1,就可以保持以前的信息$C_{t-1}$
1.3 输出结果 #
$$o_t = \sigma(W_o[h_{t-1}, x_t] + b_o)\tag5$$ $$h_t = o_t*tanh(C_t)\tag6$$
2. LSTM实例以及Pytorch实现 #
循环神经网络可以应用到以下场景。
- 点对点(单个图片(文字)被分类;图像分类)
- 点对序列(单个图像(文字)被分为多个类;图像输出文字)
- 序列分析(一系列图片(文字)被分类;情感分析)
- 不等长序列对序列(机器翻译)
- 等长序列对序列(视频帧分类)
举两个例子:图像分类以及时间序列预测
2.1 LSTM图像分类 #
关于图片分类常用卷积神经网络,侧重空间上处理;而循环神经网络侧重序列处理。但是也能用来图片分类。第一个例子以常用的mnist手写字体识别为例。
2.1.1 导入所需用到的包以及超参数设置等 #
# Setup
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision.datasets as dsets
import torchvision.transforms as transforms
torch.manual_seed(1)
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
2.1.2 导入数据集 #
# Mnist手写数字
train_data = dsets.MNIST(root='./mnist/', # 保存或者提取位置
train=True, # this is tra`ining data
transform=transforms.ToTensor(), # 转换 PIL.Image or numpy.ndarray 成
# torch.FloatTensor (C x H...
剩余内容已隐藏