初学tensorflow-分类

学习使用tensorflow的数字10分类教程

1、导入功能包

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
from tensorflow.examples.tutorials.mnist import input_data

2、下载数据集

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

3、设计隐层权重、偏置

def add_layer(inputs, in_size, out_size, activation_function=None,):
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1,)
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    if activation_function is None:
        outputs = Wx_plus_b
    else:
        outputs = activation_function(Wx_plus_b,)
    return outputs

4、模型输入

xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])
  • 此题的输入图像为28*28个像素点,需要全部输入网络中训练,所以输入结构为784列。另外此教程为10分类问题,输出结构为10列(one-hot编码)。

5、模型隐层

prediction = add_layer(xs, 784, 10,  activation_function=tf.nn.softmax)
  • 此处与回归模型不同,使用的是softmax激活函数,本次设计1层隐层。

6、模型编译

cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
                                                      reduction_indices=[1])) 
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
  • 注意此处计算误差的方式也与回归有区别,此处暂时理解为与激活函数对应。优化器有多种方案,同回归一样。

7、初始化参数

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

8、模型训练

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
       sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys})
  • 每次训练时并没有把全部数据丢入模型进行训练,此处使用的是全部数据集的100个数据,依据莫烦教程讲解,100个数据依旧能把模型训练的很好。

9、模型评估

def compute_accuracy(v_xs, v_ys):
    global prediction
    y_pre = sess.run(prediction, feed_dict={xs: v_xs})
    correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(v_ys,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))#计算准确率
    result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys})
    return result
  • 定义模型评估函数

print(compute_accuracy(
    mnist.test.images, mnist.test.labels))
  • 输出测试集的准确率,评估模型。

10、代码已上传本人github,其主要根据莫烦代码修改而成。

参考文献:1、莫烦tensorflow教程

注意:在tensorflow2.0版本是不存在MNIST教程的,需要额外下载。下载教程

打赏
  • 版权声明: 本博客所有文章除特别声明外,均采用 Apache License 2.0 许可协议。转载请注明出处!
  • © 2015-2021 高腾腾
  • Powered by Hexo Theme Ayer
  • PV: UV:

谢谢大爷

支付宝
微信