博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Tensorflow中怎么使用queue读取数据的情况下,在同一个session中边训练边测试
阅读量:4080 次
发布时间:2019-05-25

本文共 2693 字,大约阅读时间需要 8 分钟。

通常情况下,我们在训练的时候都想边训练边找来一批验证数据看看在验证数据上的性能。但是由于tensorflow中如果使用queue的方法来读取数据的话,这方法就不好实现了,主要问题是由于queue读取数据的时候,我们的graph的构建就不是使用placeholder的方法而是,那么在同一个session中就没办法换输入数据了。为了更好的描述这个问题,我先给出代码:

#使用queue方法得到的数据train_images, train_label = get_batch_train_data(batch_size) valid_images, valid_label = get_batch_valid_data(batch_size)def build_graph(x, y):    #the first layer    w1 = tf......    b1 = tf.....    h1 = tf.nn.relu(tf.matmul(x,w1)+b)...    #the second layer    .....    #the xx layer    ......    h = tf.nn.relu(..)    loss = .....    accuracy = ...    train_op = ....    return loss, accuracy, train_opwith tf.Session() as sess:    ....    loss, accuracy, train_op = build_graph(train_images, train_label)    coord = tf.train.Coordinator()    enqueue_threads = qr.create_threads(sess, coord=coord, start=True)    try:      for step in xrange(1000000):        if coord.should_stop():          break        _, acc_str, los_str = sess.run([train_op, accuracy, loss])        if step % 100 == 0:        # 这个地方我想加入对验证集合的处理咋办?请看我在代码后面的解释    except Exception, e:      coord.request_stop(e)    finally:      coord.request_stop()      coord.join(threads)

一个容易的方法很多人想着是再来一个

loss, accuracy, train_op = build_graph(valid_images, valid_label)

直接换成验证数据进去不就好了?其实是不行的,因为这意味着你再建立了一个graph,你必须重新initialize_global_variables,一旦你再初始化,那么之前训练好的weight就重新随机了。很多教程给出的方法是,再这个地方把模型和训练好的参数存起来,然后重新再调用上句话,这样当然是可以的,但是这样对于训练和验证来回折腾的话 就一点都不方便了。 那么一个直接的做法就是改写对graph的定义。

def build_graph():    is_training = tf.placeholder(dtype=tf.bool, shape=())    x = tf.cond(is_training, lambda:train_images, lambda:valid_images)    y = tf.cond(is_training, lambda:train_label, lambda:valid_label)    #the first layer    w1 = tf......    b1 = tf.....    h1 = tf.nn.relu(tf.matmul(x,w1)+b)...    #the second layer    .....    #the xx layer    ......    h = tf.nn.relu(..)    loss = .....    accuracy = ...    train_op = ....    return loss, accuracy, train_op, is_trainingwith tf.Session() as sess:    ....    loss, accuracy, train_op, is_training = build_graph()    coord = tf.train.Coordinator()    enqueue_threads = qr.create_threads(sess, coord=coord, start=True)    try:      for step in xrange(1000000):        if coord.should_stop():          break        _, acc_str, los_str = sess.run([train_op, accuracy, loss], {is_training :True})        if step % 100 == 0:            valid_acc_str, valid_los_str = sess.run([accuracy, loss], {is_training:False})    except Exception, e:      coord.request_stop(e)    finally:      coord.request_stop()      coord.join(threads)

仔细对比上下代码就可以发现,就是在构建graph的时候传入了一个变量,通过可视化可以看到此时tf构建了1个大图,里面包含了分支条件。其实此处还有一个问题就是在可视化summary的问题,因为常用的方法都是训练一个batch加入到summary中,而验证的时候我是在全部验证集上做的,怎么办?下次再说!加入2个placeholder就行了。。

转载地址:http://roini.baihongyu.com/

你可能感兴趣的文章
github上搜了下有ROS uart方面的
查看>>
STM32和ROS的串口通信(这篇是公众号文章写得比较正规详细)
查看>>
全网最实用的STM32和ROS机器人的串口通信方案
查看>>
我觉得还是把ACfly的传感器的逻辑弄清楚,这样再去二次开发好一些。(折腾半天发现有很关键一部分没有开源,怪不得找不到,这让我很失望)
查看>>
freertos工程似乎都是先创建一个任务,再在这个任务里面创建其他任务,似乎就像任务树
查看>>
无人机的高度自适应
查看>>
别人对ACfly的评价
查看>>
还有你怎么判断ACfly是正常接收到了数据,怎么从ACfly端能看到实时的T265传给ACfly的位置数据。
查看>>
我觉得对双目VIO+无人机,单单靠VIO这边输出很好的位置信息还是不够的,无人机这边还是需要做做滤波,比如防止跳变什么的,保证无人机的稳定。
查看>>
英特尔RealSense激光雷达摄像头L515拆解分析
查看>>
优象光流使用的一些注意事项(转载)(光流数据要融合其他传感器使用比较好)
查看>>
mavlink里面有个关键词 msg
查看>>
mavlink消息帧里最重要的两个东西,一个是msgid;一个是payload
查看>>
【无人机开发】通讯协议MavLink详解
查看>>
B站这个讲mavlink的视频不错(弄懂了很多东西)
查看>>
*我发觉不管是mavlink还是传感器驱动都是基于串口协议的一个更高层的协议!!!!!!!(没有协议没有规则是没有办法进行通信的)
查看>>
STM32控制APM飞控(四)MAVLINK协议深入理解之数据结构
查看>>
STM32控制APM飞控(五)MAVLINK的C源码的解释及MAVLINK心跳包
查看>>
STM32控制APM飞控(二)MAVLINK源码集成到stm32工程中
查看>>
STM32下mavlink的使用个人总结(包含对ACfly里面mavlink的分析,包含接收T265的位置信息的二次开发教程)
查看>>