我正在修改Tensorflow中的一个简单CNN,并且当我索引4d数组时,出现此错误。 我可再现的例子是:索引4D数组时的Tensorflow错误:ValueError:形状必须相等,但是1和0
from __future__ import print_function
import pdb
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
def conv2d(x, W, stride=1):
return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding='SAME')
def max_pool_2d(x, k=10):
return tf.nn.max_pool(x, ksize=[1, k, k, 1],
strides=[1, k, k, 1], padding='SAME')
indices = np.array([[0, 1], [5, 2],[300, 400]]).astype(np.int32)
input_updatable = weight_variable(shape=[1, 1200, 600, 100])
# Convolutional layer 1
W_conv1 = weight_variable([5, 5, 100, 100])
b_conv1 = bias_variable([100])
h_conv1 = tf.nn.relu(conv2d(input_updatable, W_conv1) + b_conv1)
h_pool1 = max_pool_2d(h_conv1)
# Convolutional layer 2
W_conv2 = weight_variable([5, 5, 100, 100])
b_conv2 = bias_variable([100])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2d(h_conv2)
#extract vectoris based on input
l1_vecs = input_updatable[0, indices[:, 0], indices[:, 1], :]
# Training steps
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
max_steps = 1000
for step in range(max_steps):
l1 = sess.run(l1_vecs)
pdb.set_trace()
此代码引发以下错误:
l1_vecs = input_updatable[0, indices[:, 0], indices[:, 1], :]
File "/home/arahimi/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 722, in _SliceHelperVar
return _SliceHelper(var._AsTensor(), slice_spec, var)
File "/home/arahimi/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 480, in _SliceHelper
stack(begin), stack(end), stack(strides))
File "/home/arahimi/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 824, in stack
return gen_array_ops._pack(values, axis=axis, name=name)
File "/home/arahimi/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 2041, in _pack
result = _op_def_lib.apply_op("Pack", values=values, axis=axis, name=name)
File "/home/arahimi/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
op_def=op_def)
File "/home/arahimi/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2329, in create_op
set_shapes_for_outputs(ret)
File "/home/arahimi/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1717, in set_shapes_for_outputs
shapes = shape_func(op)
File "/home/arahimi/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1667, in call_with_requiring
return call_cpp_shape_fn(op, require_shape_fn=True)
File "/home/arahimi/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn
debug_python_shape_fn, require_shape_fn)
File "/home/arahimi/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 676, in _call_cpp_shape_fn_impl
raise ValueError(err.message)
ValueError: Shapes must be equal rank, but are 1 and 0
From merging shape 2 with other shapes. for 'strided_slice/stack_1' (op: 'Pack') with input shapes: [], [3], [3], [].
需要注意的是,当我提取input_updatable值有:
ip = sess.run(input_updatable)
然后我就可以建立索引使用:
l1_vecs = input_updatable[0, indices[:, 0], indices[:, 1], :]
我不确定是什么原因。
你尝试过使用tf.gather_nd()吗? https://www.tensorflow.org/api_docs/python/tf/gather_nd – hars
作为@hars提到tf.gather_nd()的作品。 TF不支持numpy的高级索引,所以我不得不将索引更改为3D矩阵来索引input_updatable。 – Ash