发布于 

tensorflow笔记(1)

激励函数(激活函数)

为了神经网络实现非线性的任务,所引入的方式。
激活函数并没有什么特别要选定的,只要可以「掰弯」线性函数就行。
不过要确保这些激励函数必须是可以微分的,因为在 backpropagation 误差反向传递的时候,只有这些可微分的激励函数才能把误差传递回去。

** 不过如果随便使用激励函数,在两三层的时候无所谓,但是在特别多层的时候。容易出现梯度爆炸和梯度消失的问题。**

  • 在卷积神经网络(Convolutional neural networks) 的卷积层中, 推荐的激励函数是 relu.
  • 在循环神经网络中(recurrent neural networks), 推荐的是 tanh 或者是 relu

简单的tensorflow训练模型(Hello World)

import tensorflow as tf
import numpy as np

#create data
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3

#搭建模型
Weights = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
biases = tf.Variable(tf.zeros([1]))
y = Weights * x_data + biases

#计算误差
loss = tf.reduce_mean(tf.square(y - y_data))

#传播误差  使用的误差传递方法是梯度下降法
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

#训练
init = tf.global_variables_initializer()  #初始化
sess = tf.Session()
sess.run(init)

for step in range(201):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(Weights), sess.run(biases))

placeholder使用例子

import tensorflow as tf

#在tensorflow 中需要定义placeholder的type,yi ban
input1 = tf.placeholder(tf.float32)
input2 = tf.placeholder(tf.float32)

# mul = multiply 是将input1和input2 做乘法运算
output = tf.multiply(input1, input2)

#传值的工作交给了sess.run()

with tf.Session() as sess:
    print(sess.run(output, feed_dict={input1: [7.], input2: [2.]}))

####session使用例子

import tensorflow as tf

martrix1 = tf.constant([[3, 3]])
martrix2 = tf.constant([[2], [2]])
product = tf.matmul(martrix1, martrix2)

# product
# method 1
sess = tf.Session()
result = sess.run(product)
print(result)
sess.close()
#[[12]]d

# method 2
with tf.Session() as sess:
    result2 = sess.run(product)
    print(result2)

variable使用例子

import tensorflow as tf

state = tf.Variable(0, name="counter")

#定义常量one
one = tf.constant(1)

#定义加法步骤(注:此步并没有直接计算)
new_value = tf.add(state, one)

#将State更新成new_value
update = tf.assign(state, new_value)

# 如果定义Variable,就一定要initialize
init = tf.global_variables_initializer()
# 
#
with tf.Session() as sess:
    sess.run(init)
    for _ in range(3):
        sess.run(update)
        print(sess.run(state))

本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议,转载请注明出处。

本站由 @shyiuanchen 创建,使用 Stellar 作为主题。