下面是三个变体,演示如何在Theano中重新实现部分代码。
请注意,Theano的Unique
操作不支持在GPU上运行,也不支持渐变。因此版本3很多用处不大。第2版提供了一种解决方法:计算Theano以外的唯一值并将它们传入。版本1只是您的numpy代码的最后一行的Theano实现。
要解决您的特定问题:不需要使用nonzero
;在这种情况下,Theano的索引工作就像在numpy中工作一样。也许你在y
和Y
之间感到困惑? (常见的Python风格是为所有变量和参数名称使用小写)。
import numpy as np
import theano
import theano.tensor as tt
import theano.tensor.extra_ops
def numpy_ver(y, y_hat):
Y = np.zeros(shape=(len(y), len(np.unique(y))), dtype=np.int64)
Y_hat = np.zeros_like(Y, dtype=np.int64)
rows = np.arange(len(y), dtype=np.int64)
Y[rows, y] = 1
Y_hat[rows, y_hat] = 1
return ((Y_hat == Y) & (Y == 1)).sum(axis=0), Y, Y_hat
def compile_theano_ver1():
Y = tt.matrix(dtype='int64')
Y_hat = tt.matrix(dtype='int64')
z = (tt.eq(Y_hat, Y) & tt.eq(Y, 1)).sum(axis=0)
return theano.function([Y, Y_hat], outputs=z)
def compile_theano_ver2():
y = tt.vector(dtype='int64')
y_hat = tt.vector(dtype='int64')
y_uniq = tt.vector(dtype='int64')
Y = tt.zeros(shape=(y.shape[0], y_uniq.shape[0]), dtype='int64')
Y_hat = tt.zeros_like(Y, dtype='int64')
rows = tt.arange(y.shape[0], dtype='int64')
Y = tt.set_subtensor(Y[rows, y], 1)
Y_hat = tt.set_subtensor(Y_hat[rows, y_hat], 1)
z = (tt.eq(Y_hat, Y) & tt.eq(Y, 1)).sum(axis=0)
return theano.function([y, y_hat, y_uniq], outputs=z)
def compile_theano_ver3():
y = tt.vector(dtype='int64')
y_hat = tt.vector(dtype='int64')
y_uniq = tt.extra_ops.Unique()(y)
Y = tt.zeros(shape=(y.shape[0], y_uniq.shape[0]), dtype='int64')
Y_hat = tt.zeros_like(Y, dtype='int64')
rows = tt.arange(y.shape[0], dtype='int64')
Y = tt.set_subtensor(Y[rows, y], 1)
Y_hat = tt.set_subtensor(Y_hat[rows, y_hat], 1)
z = (tt.eq(Y_hat, Y) & tt.eq(Y, 1)).sum(axis=0)
return theano.function([y, y_hat], outputs=z)
def main():
y = np.array([1, 0, 1, 2, 2], dtype=np.int64)
y_hat = np.array([2, 0, 1, 1, 0], dtype=np.int64)
y_uniq = np.unique(y)
result, Y, Y_hat = numpy_ver(y, y_hat)
print result
theano_ver1 = compile_theano_ver1()
print theano_ver1(Y, Y_hat)
theano_ver2 = compile_theano_ver2()
print theano_ver2(y, y_hat, y_uniq)
theano_ver3 = compile_theano_ver3()
print theano_ver3(y, y_hat)
main()
非常感谢。我应该先报告导致我首先提出这个问题的错误 - 这是我提出问题的第一条规则之一。我知道我第一次尝试简单地使用
((Y_hat == Y) & (Y == 1)).sum(axis=0)
,并且该错误与没有sum
方法的布尔值有关。无论如何,你的解决方案都可以工作再次感谢。 – ndronen