我从here获得了这个卷积神经网络(CNN)。它接受32 x 32图像,默认为10类。但是,我有500个班级的64 x 64图像。当我传递64 x 64图像(批量大小保持恒定为32)时,出现以下错误。
ValueError: Expected input batch_size (128) to match target batch_size (32).
The stack trace starts at the line loss = loss_fn(outputs, labels)
. The outputs.shape
is [128, 500]
and the labels.shape
is [32]
.
The code is listed here for completeness.
class Unit(nn.Module):
def __init__(self,in_channels,out_channels):
super(Unit,self).__init__()
self.conv = nn.Conv2d(in_channels=in_channels,kernel_size=3,out_channels=out_channels,stride=1,padding=1)
self.bn = nn.BatchNorm2d(num_features=out_channels)
self.relu = nn.ReLU()
def forward(self,input):
output = self.conv(input)
output = self.bn(output)
output = self.relu(output)
return output
class SimpleNet(nn.Module):
def __init__(self,num_classes=10):
super(SimpleNet,self).__init__()
self.unit1 = Unit(in_channels=3,out_channels=32)
self.unit2 = Unit(in_channels=32, out_channels=32)
self.unit3 = Unit(in_channels=32, out_channels=32)
self.pool1 = nn.MaxPool2d(kernel_size=2)
self.unit4 = Unit(in_channels=32, out_channels=64)
self.unit5 = Unit(in_channels=64, out_channels=64)
self.unit6 = Unit(in_channels=64, out_channels=64)
self.unit7 = Unit(in_channels=64, out_channels=64)
self.pool2 = nn.MaxPool2d(kernel_size=2)
self.unit8 = Unit(in_channels=64, out_channels=128)
self.unit9 = Unit(in_channels=128, out_channels=128)
self.unit10 = Unit(in_channels=128, out_channels=128)
self.unit11 = Unit(in_channels=128, out_channels=128)
self.pool3 = nn.MaxPool2d(kernel_size=2)
self.unit12 = Unit(in_channels=128, out_channels=128)
self.unit13 = Unit(in_channels=128, out_channels=128)
self.unit14 = Unit(in_channels=128, out_channels=128)
self.avgpool = nn.AvgPool2d(kernel_size=4)
self.net = nn.Sequential(self.unit1, self.unit2, self.unit3, self.pool1, self.unit4, self.unit5, self.unit6
,self.unit7, self.pool2, self.unit8, self.unit9, self.unit10, self.unit11, self.pool3,
self.unit12, self.unit13, self.unit14, self.avgpool)
self.fc = nn.Linear(in_features=128,out_features=num_classes)
def forward(self, input):
output = self.net(input)
output = output.view(-1,128)
output = self.fc(output)
return output
关于如何修改此CNN以接受并正确返回输出的任何想法?
参考方案
问题是最后的重塑(视图)不兼容。
最后,您使用的是一种“扁平化”,这与“全局池化”不同。两者均对CNN有效,但只有全局池与任何图像大小兼容。
扁平网(您的情况)
对于您的情况,使用展平,您需要跟踪所有图像尺寸,以便知道如何在最后进行整形。
所以:
以64x64输入
池1至32x32
池2至16x16
池3至8x8
AvgPool至2x2
然后,最后您得到的形状为(batch, 128, 2, 2)
。如果图像为32x32,则为最终数字的四倍。
然后,最后的重塑应为output = output.view(-1,128*2*2)
。
但是,这是一个具有不同分类层的不同网络,因为in_features=512
。
全球池网
另一方面,如果将最后一个池替换为全局池,则对于大于等于32的任何图像大小,可以使用相同的模型,相同的层和相同的权重:
def flatChannels(x):
size = x.size()
return x.view(size[0],size[1],size[2]*size[3])
def globalAvgPool2D(x):
return flatChannels(x).mean(dim=-1)
def globalMaxPool2D(x):
return flatChannels(x).max(dim=-1)
模型的结尾:
#removed the pool from here to put it in forward
self.net = nn.Sequential(self.unit1, self.unit2, self.unit3, self.pool1, self.unit4,
self.unit5, self.unit6, self.unit7, self.pool2, self.unit8,
self.unit9, self.unit10, self.unit11, self.pool3,
self.unit12, self.unit13, self.unit14)
self.fc = nn.Linear(in_features=128,out_features=num_classes)
def forward(self, input):
output = self.net(input)
output = globalAvgPool2D(output) #or globalMaxPool2D
output = self.fc(output)
return output
如何在PyQt4的动态复选框列表中检查stateChanged - python所以我要从PyQt4的列表中添加复选框。但是我找不到在Window类中对每个状态使用stateChanged的方法。这是从列表元素添加它们的功能: def addCheckbox(self): colunas = Graphic(self.caminho).getColunas() for col in colunas: c = QtGui.QCheckBo…
Python GPU资源利用 - python我有一个Python脚本在某些深度学习模型上运行推理。有什么办法可以找出GPU资源的利用率水平?例如,使用着色器,float16乘法器等。我似乎在网上找不到太多有关这些GPU资源的文档。谢谢! 参考方案 您可以尝试在像Renderdoc这样的GPU分析器中运行pyxthon应用程序。它将分析您的跑步情况。您将能够获得有关已使用资源,已用缓冲区,不同渲染状态上…
管理多个会话和图形的合理方法 - python我想在多个会话中管理多个Keras模型。构建我的应用程序后,除了创建,保存和加载模型外,还可以同时运行它们。处理这种情况的正确方法是什么?当前,一个模型由包装类的实例表示。它用于训练,保存,加载和预测。每个实例创建一个tf.Graph和tf.Session,它们在需要实际模型的每个函数中使用。class NN: def __init__(self): sel…
Python sqlite3数据库已锁定 - python我在Windows上使用Python 3和sqlite3。我正在开发一个使用数据库存储联系人的小型应用程序。我注意到,如果应用程序被强制关闭(通过错误或通过任务管理器结束),则会收到sqlite3错误(sqlite3.OperationalError:数据库已锁定)。我想这是因为在应用程序关闭之前,我没有正确关闭数据库连接。我已经试过了: connectio…
python中类初始化的最佳实践 - python嗨,我想知道最佳实践是在python中初始化类,同时确保我的属性具有正确的数据类型。我应该使用默认值初始化类属性还是调用检查功能?class Foo: # Call with default value def __init__(self, bar=""): self._bar = bar # Calling set-function d…