本文共 2693 字,大约阅读时间需要 8 分钟。
通常情况下,我们在训练的时候都想边训练边找来一批验证数据看看在验证数据上的性能。但是由于tensorflow中如果使用queue的方法来读取数据的话,这方法就不好实现了,主要问题是由于queue读取数据的时候,我们的graph的构建就不是使用placeholder的方法而是,那么在同一个session中就没办法换输入数据了。为了更好的描述这个问题,我先给出代码:
#使用queue方法得到的数据train_images, train_label = get_batch_train_data(batch_size) valid_images, valid_label = get_batch_valid_data(batch_size)def build_graph(x, y): #the first layer w1 = tf...... b1 = tf..... h1 = tf.nn.relu(tf.matmul(x,w1)+b)... #the second layer ..... #the xx layer ...... h = tf.nn.relu(..) loss = ..... accuracy = ... train_op = .... return loss, accuracy, train_opwith tf.Session() as sess: .... loss, accuracy, train_op = build_graph(train_images, train_label) coord = tf.train.Coordinator() enqueue_threads = qr.create_threads(sess, coord=coord, start=True) try: for step in xrange(1000000): if coord.should_stop(): break _, acc_str, los_str = sess.run([train_op, accuracy, loss]) if step % 100 == 0: # 这个地方我想加入对验证集合的处理咋办?请看我在代码后面的解释 except Exception, e: coord.request_stop(e) finally: coord.request_stop() coord.join(threads)
一个容易的方法很多人想着是再来一个
loss, accuracy, train_op = build_graph(valid_images, valid_label)
直接换成验证数据进去不就好了?其实是不行的,因为这意味着你再建立了一个graph,你必须重新initialize_global_variables,一旦你再初始化,那么之前训练好的weight就重新随机了。很多教程给出的方法是,再这个地方把模型和训练好的参数存起来,然后重新再调用上句话,这样当然是可以的,但是这样对于训练和验证来回折腾的话 就一点都不方便了。 那么一个直接的做法就是改写对graph的定义。
def build_graph(): is_training = tf.placeholder(dtype=tf.bool, shape=()) x = tf.cond(is_training, lambda:train_images, lambda:valid_images) y = tf.cond(is_training, lambda:train_label, lambda:valid_label) #the first layer w1 = tf...... b1 = tf..... h1 = tf.nn.relu(tf.matmul(x,w1)+b)... #the second layer ..... #the xx layer ...... h = tf.nn.relu(..) loss = ..... accuracy = ... train_op = .... return loss, accuracy, train_op, is_trainingwith tf.Session() as sess: .... loss, accuracy, train_op, is_training = build_graph() coord = tf.train.Coordinator() enqueue_threads = qr.create_threads(sess, coord=coord, start=True) try: for step in xrange(1000000): if coord.should_stop(): break _, acc_str, los_str = sess.run([train_op, accuracy, loss], {is_training :True}) if step % 100 == 0: valid_acc_str, valid_los_str = sess.run([accuracy, loss], {is_training:False}) except Exception, e: coord.request_stop(e) finally: coord.request_stop() coord.join(threads)
仔细对比上下代码就可以发现,就是在构建graph的时候传入了一个变量,通过可视化可以看到此时tf构建了1个大图,里面包含了分支条件。其实此处还有一个问题就是在可视化summary的问题,因为常用的方法都是训练一个batch加入到summary中,而验证的时候我是在全部验证集上做的,怎么办?下次再说!加入2个placeholder就行了。。
转载地址:http://roini.baihongyu.com/