我正在尝试使this MNIST example适应二进制分类。
但是,当我将NLABELS
从NLABELS=2
更改为NLABELS=1
时,损失函数始终返回0(精度为1)。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
# Import data
mnist = input_data.read_data_sets('data', one_hot=True)
NLABELS = 2
sess = tf.InteractiveSession()
# Create the model
x = tf.placeholder(tf.float32, [None, 784], name='x-input')
W = tf.Variable(tf.zeros([784, NLABELS]), name='weights')
b = tf.Variable(tf.zeros([NLABELS], name='bias'))
y = tf.nn.softmax(tf.matmul(x, W) + b)
# Add summary ops to collect data
_ = tf.histogram_summary('weights', W)
_ = tf.histogram_summary('biases', b)
_ = tf.histogram_summary('y', y)
# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, NLABELS], name='y-input')
# More name scopes will clean up the graph representation
with tf.name_scope('cross_entropy'):
cross_entropy = -tf.reduce_mean(y_ * tf.log(y))
_ = tf.scalar_summary('cross entropy', cross_entropy)
with tf.name_scope('train'):
train_step = tf.train.GradientDescentOptimizer(10.).minimize(cross_entropy)
with tf.name_scope('test'):
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
_ = tf.scalar_summary('accuracy', accuracy)
# Merge all the summaries and write them out to /tmp/mnist_logs
merged = tf.merge_all_summaries()
writer = tf.train.SummaryWriter('logs', sess.graph_def)
tf.initialize_all_variables().run()
# Train the model, and feed in test data and record summaries every 10 steps
for i in range(1000):
if i % 10 == 0: # Record summary data and the accuracy
labels = mnist.test.labels[:, 0:NLABELS]
feed = {x: mnist.test.images, y_: labels}
result = sess.run([merged, accuracy, cross_entropy], feed_dict=feed)
summary_str = result[0]
acc = result[1]
loss = result[2]
writer.add_summary(summary_str, i)
print('Accuracy at step %s: %s - loss: %f' % (i, acc, loss))
else:
batch_xs, batch_ys = mnist.train.next_batch(100)
batch_ys = batch_ys[:, 0:NLABELS]
feed = {x: batch_xs, y_: batch_ys}
sess.run(train_step, feed_dict=feed)
我已经检查了batch_ys
(送入y
)和_y
的尺寸,当NLABELS=1
时它们都是1xN矩阵,因此问题似乎早于此。也许与矩阵乘法有关?
我实际上在一个真实的项目中也遇到了同样的问题,因此,任何帮助将不胜感激……谢谢!
参考方案
原始的MNIST示例使用one-hot encoding表示数据中的标签:这意味着,如果存在NLABELS = 10
类(如MNIST),则目标输出是类0的[1 0 0 0 0 0 0 0 0 0]
,类1的[0 1 0 0 0 0 0 0 0 0]
等。 tf.nn.softmax()
运算符将将tf.matmul(x, W) + b
计算的logit分布到不同输出类的概率分布中,然后将其与y_
的输入值进行比较。
如果是NLABELS = 1
,则它的作用就好像只有一个类,并且tf.nn.softmax()
op将计算该类的1.0
的概率,导致0.0
的交叉熵,因为tf.log(1.0)
在所有示例中都是0.0
。
您可以尝试(至少)两种方法进行二进制分类:
NLABELS = 2
,然后将训练数据编码为标签0的[1 0]
和标签1的[0 1]
。This answer建议如何执行此操作。 0
的建议将标签保留为整数1
和tf.nn.sparse_softmax_cross_entropy_with_logits()
并使用this answer。 我在Angular工作,正在使用Http请求和响应。是否可以在“响应”中发送多个参数。角度文件:this.http.get("api/agent/applicationaware").subscribe((data:any)... python文件:def get(request): ... return Response(seriali…
Python exchangelib在子文件夹中读取邮件 - python我想从Outlook邮箱的子文件夹中读取邮件。Inbox ├──myfolder 我可以使用account.inbox.all()阅读收件箱,但我想阅读myfolder中的邮件我尝试了此页面folder部分中的内容,但无法正确完成https://pypi.python.org/pypi/exchangelib/ 参考方案 您需要首先掌握Folder的myfo…
R'relaimpo'软件包的Python端口 - python我需要计算Lindeman-Merenda-Gold(LMG)分数,以进行回归分析。我发现R语言的relaimpo包下有该文件。不幸的是,我对R没有任何经验。我检查了互联网,但找不到。这个程序包有python端口吗?如果不存在,是否可以通过python使用该包? python参考方案 最近,我遇到了pingouin库。
Python ThreadPoolExecutor抑制异常 - pythonfrom concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED def div_zero(x): print('In div_zero') return x / 0 with ThreadPoolExecutor(max_workers=4) as execut…
如何用'-'解析字符串到节点js本地脚本? - python我正在使用本地节点js脚本来处理字符串。我陷入了将'-'字符串解析为本地节点js脚本的问题。render.js:#! /usr/bin/env -S node -r esm let argv = require('yargs') .usage('$0 [string]') .argv; console.log(argv…