源代码/数据集已上传到 Github - tensorflow-tutorial-samples

这篇文章是 TensorFlow Tutorial 入门教程的第二篇文章。

上一篇文章TensorFlow入门(一) - mnist手写数字识别(网络搭建)介绍了神经网络输入输出独热编码损失函数等最基本的知识,并且演示了如何用最简单的模型实现mnist手写数字识别91%的正确率。但是遗留的问题是,模型保存在内存中,每次都得重新开始训练。

这篇文章解决的就是这个问题。将依次介绍tensorflow中如何保存已经训练好的模型,如何在某个训练步数的基础上继续训练,最后将演示如何加载模型,并借助pillow(Python2中称为PIL)库实现真实手写数字图片的识别。

模型的保存

  • 首先看一下项目的目录结构
1
2
3
4
5
6
7
8
9
10
11
|--mnist/
|--data_set/ 训练以及测试数据集
|--test_images/ 多张测试图片
|--0.png
|--1.png
|--4.png
|--v2/
|--ckpt/ 模型保存在这里!!!
|--model.py 网络模型
|--train.py 训练代码
|--predict.py 预测代码

第一步更改模型,记录global_step

每一次训练,会进行一次梯度下降,传入的global_step的值会自增1,因此,可以通过计算global_step这个张量的值,知道当前训练了多少步。

1
2
3
4
5
6
7
8
9
10
11
12

class Network:
def __init__(self):

self.global_step = tf.Variable(0, trainable=False)





self.train = tf.train.GradientDescentOptimizer(0.001).minimize(
self.loss, global_step=self.global_step)

第二步,每隔N步保存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
CKPT_DIR = 'ckpt' 
net = Network()
sess = tf.Session()
sess.run(tf.global_variables_initializer())




saver = tf.train.Saver(max_to_keep=10)

train_step = 10000
step = 0
save_interval = 1000

while step < train_step:


step = sess.run(net.global_step)


if step % save_interval == 0:
saver.save(sess, CKPT_DIR + '/model', global_step=step)
  • 最终保存的模型如下所示

假设训练到了2000步,保存了2次模型。ckpt文件夹下会生成7个文件,第一个文件是 checkpoint文件,保存了所有的模型的路径。其中第一行代表当前的状态,即在加载模型时,使用哪一个模型是由第一行决定的。

每个模型包含3个文件,分别是

  1. model-xxx.data-00000-of-00001
  2. model-xxx.index
  3. model-xxx.meta

checkpoint文件

1
2
3
model_checkpoint_path: "model-2000"
all_model_checkpoint_paths: "model-1000"
all_model_checkpoint_paths: "model-2000"

目录结构

1
2
3
4
5
6
7
8
9
10
11
12
|--v2/  
|--ckpt/ 模型保存在这里!!!
|--checkpoint
|--model-1000.data-00000-of-00001
|--model-1000.index
|--model-1000.meta
|--model-2000.data-00000-of-00001
|--model-2000.index
|--model-2000.meta
|--model.py 网络模型
|--train.py 训练代码
|--predict.py 预测代码

加载模型与继续训练(train.py)

假设我们当前模型已经训练到了2000步,但是由于某种原因停止了。那么是否可以在2000步的基础上继续训练呢?

  • 只需一步,训练前保存的模型restore到session中即可。这里需要注意的是,创建 tf.train.Saver对象一定要在创建tf.Session之后。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
CKPT_DIR = 'ckpt'
net = Network()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(max_to_keep=10)

train_step = 10000
step = 0
save_interval = 1000



ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)

step = sess.run(net.global_step)
print('Continue from')
print(' -> Minibatch update : ', step)

while step < train_step:

  • 再次运行代码,将打印出
1
2
3
Continue from
-> Minibatch update : 2000
第 3000步,...
  • 如果将checkpoint文件的第一行改为如下,训练将从1000开始,再次训练到2000时,会将原来的2000的模型覆盖。所以restore哪一个模型,只与checkpoint的第一行有关,即只与model_checkpoint_path有关。
    1
    model_checkpoint_path: "model-1000"
1
2
3
Continue from
-> Minibatch update : 1000
第 2000步,...

使用模型预测数字(predict.py)

第一步,restore模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import numpy as np
from PIL import Image


class Predict:
def __init__(self):
self.net = Network()
self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer())
self.restore()

def restore(self):
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(self.sess, ckpt.model_checkpoint_path)
else:
raise FileNotFoundError("未保存任何模型")

def predict(self, image_path):

第二步读入图片并预测

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Predict:


def predict(self, image_path):

img = Image.open(image_path).convert('L')
flatten_img = np.reshape(img, 784)
x = np.array([1 - flatten_img])
y = self.sess.run(self.net.y, feed_dict={self.net.x: x})



print(image_path)
print(' -> Predict digit', np.argmax(y[0]))
  • test_images目录下的0.png1.png4.png三张图片的预测结果。
    1
    2
    3
    4
    app = Predict()
    app.predict('../test_images/0.png')
    app.predict('../test_images/1.png')
    app.predict('../test_images/4.png')

最后的结果

  • 第一次 python train.py

    1
    2
    第 1000步,当前loss:26.94
    第 2000步,当前loss:28.36
  • 2000步时停止,第二次 python train.py

    1
    2
    3
    4
    5
    Continue from
    -> Minibatch update : 2000
    第 3000步,当前loss:23.49
    第 4000步,当前loss:20.40
    第 5000步,当前loss:11.65
  • python predict.py

    1
    2
    3
    4
    5
    6
    ../test_images/0.png
    -> Predict digit 0
    ../test_images/1.png
    -> Predict digit 1
    ../test_images/4.png
    -> Predict digit 4

源代码&数据集已上传到 Github

觉得还不错,不要吝惜你的star,支持是持续不断更新的动力。

附 推荐



上一篇 « TensorFlow入门(一) - mnist手写数字识别(网络搭建) 下一篇 » Pandas 数据处理(一) - DataFrame 与 Series

赞赏支持

请我吃胡萝卜 =^_^=

i ali

支付宝

i wechat

微信

© 2026 - 极客兔兔 - 沪ICP备18001798号-1

👁   📚