2016-08-04 115 views
5

NumPy在创建数组时非常有用。如果numpy.array的第一个参数具有__getitem____len__方法,则基于它们可能是有效序列来使用它们。防止numpy创建多维数组

不幸的是我想创建一个包含dtype=object而不是NumPy“有帮助”的数组。

分解为一个最小的例子中,类将是这样的:

import numpy as np 

class Test(object): 
    def __init__(self, iterable): 
     self.data = iterable 

    def __getitem__(self, idx): 
     return self.data[idx] 

    def __len__(self): 
     return len(self.data) 

    def __repr__(self): 
     return '{}({})'.format(self.__class__.__name__, self.data) 

,如果“iterables”有不同的长度,一切都很好,我得到正是我想要的结果:

>>> np.array([Test([1,2,3]), Test([3,2])], dtype=object) 
array([Test([1, 2, 3]), Test([3, 2])], dtype=object) 

但NumPy的创建一个多维数组,如果这些发生在具有相同的长度:

>>> np.array([Test([1,2,3]), Test([3,2,1])], dtype=object) 
array([[1, 2, 3], 
     [3, 2, 1]], dtype=object) 

不幸的是,只有ndmin的论点,所以我想知道是否有一种方法来执行ndmax或以某种方式防止NumPy将自定义类解释为另一维(不删除__len____getitem__)?

回答

3

一种解决方法是当然的创建所需要的形状的阵列,然后复制数据:

In [19]: lst = [Test([1, 2, 3]), Test([3, 2, 1])] 

In [20]: arr = np.empty(len(lst), dtype=object) 

In [21]: arr[:] = lst[:] 

In [22]: arr 
Out[22]: array([Test([1, 2, 3]), Test([3, 2, 1])], dtype=object) 

请注意,在任何情况下,我也不会感到惊讶如果numpy的行为w.r.t.解释可迭代对象(这是你想要使用的,对吧?)是依赖于numpy的版本。可能还有越野车。或者,也许这些错误中的一些实际上是功能。无论如何,当一个numpy版本发生变化时,我会警惕破碎。

相反,复制到预先创建的数组应该更健壮。

5

此行为已经过多次讨论(例如Override a dict with numpy support)。 np.array试图使尽可能高的维数组。模型案例是嵌套列表。如果它可以迭代并且子列表的长度相等,它将“向下钻取”。

这走下2层中遇到不同长度的名单之前:

In [250]: np.array([[[1,2],[3]],[1,2]],dtype=object) 
Out[250]: 
array([[[1, 2], [3]], 
     [1, 2]], dtype=object) 
In [251]: _.shape 
Out[251]: (2, 2) 

没有形状也没有办法知道我是否希望它是(2,)(2,2)的方式ndmax参数。这两种方法都适用于dtype。

它是编译好的代码,所以很难确切地看到它使用的是什么测试。它试图迭代列表和元组,但不能在集合或字典上进行迭代。

,使具有给定尺寸的对象阵列的最可靠的方法是用空的启动,并填写

In [266]: A=np.empty((2,3),object) 
In [267]: A.fill([[1,'one']]) 
In [276]: A[:]={1,2} 
In [277]: A[:]=[1,2] # broadcast error 

的另一种方式是开始与至少一个不同的元件(例如一个None) ,然后替换它。

还有一个更原始的创造者,是ndarray初具规模:

In [280]: np.ndarray((2,3),dtype=object) 
Out[280]: 
array([[None, None, None], 
     [None, None, None]], dtype=object) 

但是,这是基本相同的np.empty(除非我给它一个缓冲)。

这些都是虚假,但它们并不昂贵(时间明智)。

=====(编辑)

https://github.com/numpy/numpy/issues/5933Enh: Object array creation function.是一个增强请求。另外https://github.com/numpy/numpy/issues/5303the error message for accidentally irregular arrays is confusing

开发人员的情绪似乎赞成单独的函数来创建dtype=object数组,其中一个对初始维度和迭代深度具有更多的控制权。他们甚至可能会加强错误检查,以防止np.array创建“不规则”阵列。

这样的函数可以检测到一个规则的嵌套可迭代到指定深度的形状,并构建一个要填充的对象类型数组。

def objarray(alist, depth=1): 
    shape=[]; l=alist 
    for _ in range(depth): 
     shape.append(len(l)) 
     l = l[0] 
    arr = np.empty(shape, dtype=object) 
    arr[:]=alist 
    return arr 

随着各种深度:

In [528]: alist=[[Test([1,2,3])], [Test([3,2,1])]] 
In [529]: objarray(alist,1) 
Out[529]: array([[Test([1, 2, 3])], [Test([3, 2, 1])]], dtype=object) 
In [530]: objarray(alist,2) 
Out[530]: 
array([[Test([1, 2, 3])], 
     [Test([3, 2, 1])]], dtype=object) 
In [531]: objarray(alist,3) 
Out[531]: 
array([[[1, 2, 3]], 

     [[3, 2, 1]]], dtype=object) 
In [532]: objarray(alist,4) 
... 
TypeError: object of type 'int' has no len() 
+0

我试图寻找类似的问题,但我还没有发现任何。也许我只是搜索错误的短语。如果您有任何提及较早的问题,那就太棒了。谢谢你的回答,但我实际上并没有寻找解决方法。我更关心的是如何在事先不知道确切长度的情况下定义数组的最大深度(维),或者禁用numpy将自定义类实例解释为序列。 – MSeifert

+0

通过将您的类更改为子类'dict',我可以阻止它在您的实例上迭代。这表明'np.array'正在测试多于'__getitem__'。但是我一直无法找到执行这种检查的代码。 – hpaulj

+0

http://stackoverflow.com/questions/36663919/override-a-dict-with-numpy-support - 与同一问题的斗争;控制'np.array'是否迭代你的自定义类。同样的解决方法。 – hpaulj