2017-08-04 43 views
1

的列表numpy的数组:广播通过与给定一个邻接表阵列

adj_list = [array([0,1]),array([0,1,2]),array([0,2])] 

和指数的阵列,

ind_arr = array([0,1,2]) 

目标:

A = np.zeros((3,3)) 
for i in ind_arr: 
    A[i,list(adj_list[x])] = 1.0/float(adj_list[x].shape[0]) 


目前,我写道:

A[ind_list[:],adj_list[:]] = 1./len(adj_list[:]) 

并尝试在此脚手架索引的各种配置。

+0

'adj_list'的数组元素是否总是采用该序列格式 - '[0,1,2 ...]'? – Divakar

+0

@Divakar他们按升序排列。但是,它们很稀疏 – bordeo

回答

1

这里有一个方法 -

lens = np.array([len(i) for i in adj_list]) 
col_idx = np.concatenate(adj_list) 
out = np.zeros((len(lens), col_idx.max()+1)) 
row_idx = np.repeat(np.arange(len(lens)), lens) 
vals = np.repeat(1.0/lens, lens) 
out[row_idx, col_idx] = vals 

样品输入,输出 -

In [494]: adj_list = [np.array([0,2]),np.array([0,1,4])] 

In [496]: out 
Out[496]: 
array([[ 0.5  , 0.  , 0.5  , 0.  , 0.  ], 
     [ 0.33333333, 0.33333333, 0.  , 0.  , 0.33333333]]) 

稀疏矩阵输出

此外,如果你想节省内存并创建稀疏矩阵代替,这是一个简单的扩展 -

In [506]: from scipy.sparse import csr_matrix 

In [507]: csr_matrix((vals, (row_idx, col_idx)), shape=(len(lens), col_idx.max()+1)) 
Out[507]: 
<2x5 sparse matrix of type '<type 'numpy.float64'>' 
    with 5 stored elements in Compressed Sparse Row format> 

In [508]: _.toarray() 
Out[508]: 
array([[ 0.5  , 0.  , 0.5  , 0.  , 0.  ], 
     [ 0.33333333, 0.33333333, 0.  , 0.  , 0.33333333]]) 
1

我不认为你可以完全消除由于混合数据类型的循环,但可以减少循环嵌套双到一个单一的一个:

A = np.zeros((2, 3)) 
for i, arr in enumerate(adj_list): 
    arr_size = len(arr) 
    A[i, :arr_size] = 1./arr_size 

A 
# array([[ 0.5  , 0.5  , 0.  ], 
#  [ 0.33333333, 0.33333333, 0.33333333]]) 

或者,如果阵列中的数字是实际上列的位置:

A = np.zeros((2, 3)) 
for i, arr in enumerate(adj_list): 
    A[i, arr] = 1./len(arr) 

A 
# array([[ 0.5  , 0.5  , 0.  ], 
#  [ 0.33333333, 0.33333333, 0.33333333]]) 

使用从sklearnMultiLabelBinarizer另一种选择(但可能不会那样有效):

from sklearn.preprocessing import MultiLabelBinarizer 
​ 
mlb = MultiLabelBinarizer() 

adj_list = [np.array([0,1]),np.array([0,1,2])] 
​ 
sizes = np.fromiter(map(len, adj_list), dtype=int) 
mlb.fit_transform(adj_list)/sizes[:,None] 

# array([[ 0.5  , 0.5  , 0.  ], 
#  [ 0.33333333, 0.33333333, 0.33333333]])