行业报告 AI展会 数据标注 标注供求
数据标注数据集
主页 > 机器学习 > 正文

Keras实现LSTM

LSTM是优秀的循环神经网络(RNN)结构,而LSTM在结构上也比较复杂,对RNN和LSTM还稍有疑问的朋友可以参考:Recurrent Neural Networks vs LSTM

这里我们将要使用Keras搭建LSTM.Keras封装了一些优秀的深度学习框架的底层实现,使用起来相当简洁,甚至不需要深度学习的理论知识,你都可以轻松快速的搭建你的深度学习网络,强烈推荐给刚入门深度学习的同学使用,当然我也是还没入门的那个。Keras:https://keras.io/,keras的backend有,theano,TensorFlow、CNTk,这里我使用的是TensorFlow。

下面我们就开始搭建LSTM,实现mnist数据的分类。

step 0 加载包和定义参数

mnist的image是28*28的shape,我们定义LSTM的input为(28,),将image一行一行地输入到LSTM的cell中,这样time_step就是28,表示一个image有28行,LSTM的output是30个。
 
from keras.datasets import mnist
from keras.layers import Dense, LSTM
from keras.utils import to_categorical
from keras.models import Sequential

#parameters for LSTM
nb_lstm_outputs = 30  #神经元个数
nb_time_steps = 28  #时间序列长度
nb_input_vector = 28 #输入序列
 

step 1 数据预处理

特别注意label要使用one_hot encoding,x_train的shape(60000, 28,28)

#data preprocessing: tofloat32, normalization, one_hot encoding
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)
 

step 2 搭建模型

keras搭建模型相当简单,只需要在Sequential容器中不断add新的layer就可以了。

#build model
model = Sequential()
model.add(LSTM(units=nb_lstm_outputs, input_shape=(nb_time_steps, nb_input_vector)))
model.add(Dense(10, activation='softmax'))

 

step 3 compile

模型compile,指定loss function, optimizer, metrics

#compile:loss, optimizer, metrics
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

 

 step 4 train

模型训练,需要指定,epochs训练的轮次数,batch_size。

#train: epcoch, batch_size
model.fit(x_train, y_train, epochs=20, batch_size=128, verbose=1)

 step 5 evaluate

可以使用model.summary()来查看你的神经网络的架构和参数量等信息。

model.summary()

score = model.evaluate(x_test, y_test,batch_size=128, verbose=1)
print(score)

我的最后结果是:

 

 

微信公众号

声明:本站部分作品是由网友自主投稿和发布、编辑整理上传,对此类作品本站仅提供交流平台,转载的目的在于传递更多信息及用于网络分享,并不代表本站赞同其观点和对其真实性负责,不为其版权负责。如果您发现网站上有侵犯您的知识产权的作品,请与我们取得联系,我们会及时修改或删除。

网友评论:

发表评论
请自觉遵守互联网相关的政策法规,严禁发布色情、暴力、反动的言论。
评价:
表情:
用户名: 验证码:点击我更换图片
SEM推广服务

Copyright©2005-2026 Sykv.com 可思数据 版权所有    京ICP备14056871号

关于我们   免责声明   广告合作   版权声明   联系我们   原创投稿   网站地图  

可思数据 数据标注行业联盟

扫码入群
扫码关注

微信公众号

返回顶部