我想使用自定义keras层来操纵上一层的激活。下一层只是将数字与上一层的激活数相乘。
class myLayer(Layer):
def __init__(self, **kwargs):
super(myLayer, self).__init__(**kwargs)
def build(self, input_shape):
self.output_dim = input_shape[0][1]
super(myLayer, self).build(input_shape)
def call(self, inputs, **kwargs):
if not isinstance(inputs, list):
raise ValueError('This layer should be called on a list of inputs.')
mainInput = inputs[0]
nInput = inputs[1]
changed = tf.multiply(mainInput,nInput)
forTest = changed
forTrain = inputs[0]
return K.in_train_phase(forTrain, forTest)
def compute_output_shape(self, input_shape):
print(input_shape)
return (input_shape[0][0], self.output_dim)
我正在创建模型
inputTensor = Input((5,))
out = Dense(units, input_shape=(5,),activation='relu')(inputTensor)
n = K.placeholder(shape=(1,))
auxInput = Input(tensor=n)
out = myLayer()([out, auxInput])
out = Dense(units, activation='relu')(out)
out = Dense(3, activation='softmax')(out)
model = Model(inputs=[inputTensor, auxInput], outputs=out)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics='acc'])
尝试使用时出现此错误
model.fit(X_train, Y_train, epochs=epochs, verbose=1)
错误
InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder_3' with dtype float and shape [1]
当我尝试将值赋予占位符时
model.fit([X_train, np.array([3])], Y_train, epochs=epochs, verbose=1)
我得到:
ValueError: Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 1 arrays but instead got the following list of 2 arrays:
我应该如何初始化该占位符?我的目标是使用model.evaluate在推理期间测试模型中n个不同值的效果。
谢谢。
参考方案
我找到了一种避免对n
使用数组的解决方案。
代替使用placeholder
,而使用K.variable
:
n = K.variable([someInitialValue])
auxInput = Input(tensor=n)
然后,即使在编译模型之后,您也可以随时设置n
的值:
K.set_value(n,[anotherValue])
这使您可以继续训练,而不必重新编译模型,也无需将n
传递给fit
方法。
model.fit(X_train,Y_train,....)
如果使用许多类似的输入,则可以:
n = K.variable([val1,val2,val3,val4]) #tensor definition
K.set_value(n,[new1,new2,new3,new4]) #changing values
在图层内部,第二个输入是n
的张量将具有4个元素:
n1 = inputs[1][0]
n2 = inputs[1][1]
....
R'relaimpo'软件包的Python端口 - python我需要计算Lindeman-Merenda-Gold(LMG)分数,以进行回归分析。我发现R语言的relaimpo包下有该文件。不幸的是,我对R没有任何经验。我检查了互联网,但找不到。这个程序包有python端口吗?如果不存在,是否可以通过python使用该包? python参考方案 最近,我遇到了pingouin库。
如何用'-'解析字符串到节点js本地脚本? - python我正在使用本地节点js脚本来处理字符串。我陷入了将'-'字符串解析为本地节点js脚本的问题。render.js:#! /usr/bin/env -S node -r esm let argv = require('yargs') .usage('$0 [string]') .argv; console.log(argv…
Python:传递记录器是个好主意吗? - python我的Web服务器的API日志如下:started started succeeded failed 那是同时收到的两个请求。很难说哪一个成功或失败。为了彼此分离请求,我为每个请求创建了一个随机数,并将其用作记录器的名称logger = logging.getLogger(random_number) 日志变成[111] started [222] start…
Python-Excel导出 - python我有以下代码:import pandas as pd import requests from bs4 import BeautifulSoup res = requests.get("https://www.bankier.pl/gielda/notowania/akcje") soup = BeautifulSoup(res.cont…
Python sqlite3数据库已锁定 - python我在Windows上使用Python 3和sqlite3。我正在开发一个使用数据库存储联系人的小型应用程序。我注意到,如果应用程序被强制关闭(通过错误或通过任务管理器结束),则会收到sqlite3错误(sqlite3.OperationalError:数据库已锁定)。我想这是因为在应用程序关闭之前,我没有正确关闭数据库连接。我已经试过了: connectio…