2016-02-28 53 views
0

在我的函数中,有时我得到的结果是2D形式的一维numpy数组,因此它的形状是nx1(n,1)。其他时候,我可能会得到它的形式1xn array.shape =(1,n)测试Numpy数组以查看它是否为列形式

其他时候,我只得到一个numpy数组,其形状是(n,)。

当我运行下面的测试中,我得到一方面的错误,并且在另一假阳性(因为一个形状属性的长度总是大于1,显然):

y_predicted = forest.predict(testX) 
    if y_predicted.shape[1] != None: 
     y_predicted = y_predicted.T[0] 

y_predicted = forest.predict(testX) 
    if len(y_predicted.shape) > 1: 
     y_predicted = y_predicted.T[0] 

我只是需要确保y的最终形状总是在形式(N),而不是(N,1)或(1,N)...

+0

'squeeze','ravel'和'flatten'都会做这个工作;但请阅读他们的文档,以便了解他们的差异。 – hpaulj

回答

3

你应该使用numpy.squeeze

numpy.squeeze(a)从数组形状中删除单维条目。

例子:

>>> x = np.array([[1,2,3]]) 
>>> x.shape 
(1, 3) 
>>> np.squeeze(x) 
array([1, 2, 3]) 
>>> np.squeeze(x).shape 
(3,) 
+0

很好的答案。不知道这种方法。 –

相关问题