2017-06-01 167 views
0

I have two point sets, feat_left and feat_right, obtained from a Siamese network, and I plotted these points in x, y coordinates as shown below.

[Image: computing centroids and accuracy]

Here is the Python script:

import json
import matplotlib.pyplot as plt
import numpy as np

# Load the network outputs from the JSON predictions file
with open('predictions-mnist.txt') as fh:
    data = json.load(fh)

n = len(data['outputs'])
label_list = np.zeros(n, dtype=int)
feat_left = np.zeros((n, 2))

for idx, (key, val) in enumerate(data['outputs'].items()):
    feat_left[idx] = val['feat_left']
    # the class label is parsed from the 7th "/"-separated component of the key
    label_list[idx] = int(key.split("/")[6])

f = plt.figure(figsize=(16, 9))

c = ['#ff0000', '#ffff00', '#00ff00', '#00ffff', '#0000ff',
     '#ff00ff', '#990000', '#999900', '#009900', '#009999']

# one series of points per class label, each in its own colour
for i in range(10):
    plt.plot(feat_left[label_list == i, 0], feat_left[label_list == i, 1], '.', c=c[i])
plt.legend(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])
plt.grid()
plt.show()

Now I want to compute the centroid and then the purity of each cluster.

+1

How do you define the "accuracy" of a cluster? – Shai

+0

You can use k-means (k = 10), or look at the different [clustering methods](http://scikit-learn.org/stable/modules/clustering.html) provided by the module sklearn.cluster. – Nuageux

+0

I am following this article: [Evaluation of clustering](https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html). @Shai – cpwah

Answers

2

The centroid of a class is simply the mean of its points:

centroids = np.zeros((10, 2), dtype='f4')
for i in range(10):
    # mean over all points carrying label i
    centroids[i, :] = np.mean(feat_left[label_list == i, :2], axis=0)

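As for accuracy, you can compute the squared error (squared distance) from the centroid, summed over the points of each class; divide by the number of points in the class if you want the mean squared error:

sqerr = np.zeros((10,), dtype='f4')
for i in range(10):
    # sum of squared distances of the points with label i to their centroid
    sqerr[i] = np.sum((feat_left[label_list == i, :2] - centroids[i, :]) ** 2)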
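The two loops above can be checked on a small synthetic data set. The following is only a minimal sketch, assuming NumPy is available: the arrays `feats` and `labels` are made-up stand-ins for `feat_left` and `label_list`, and it reports the mean (rather than the sum) of the squared distances per class.

```python
import numpy as np

# Hypothetical stand-ins for feat_left and label_list: 3 classes, 2-D features.
feats = np.array([[0.0, 0.0], [2.0, 0.0],
                  [5.0, 5.0], [5.0, 7.0],
                  [9.0, 1.0]])
labels = np.array([0, 0, 1, 1, 2])

classes = np.unique(labels)
centroids = np.zeros((len(classes), 2))
mse = np.zeros(len(classes))
for idx, k in enumerate(classes):
    members = feats[labels == k]           # all points with label k
    centroids[idx] = members.mean(axis=0)  # centroid = per-coordinate mean
    # squared Euclidean distance of each member to its centroid
    sq_dist = np.sum((members - centroids[idx]) ** 2, axis=1)
    mse[idx] = sq_dist.mean()

print(centroids)
print(mse)
```

Using `np.unique(labels)` instead of a hard-coded `range(10)` avoids taking the mean of an empty slice when a class has no points.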

To compute purity:

def compute_cluster_purity(gt_labels, pred_labels):
    """
    Compute purity of predicted labels (pred_labels), given
    the ground-truth labels (gt_labels).

    Assuming gt_labels and pred_labels are both lists of int of length n
    """
    n = len(gt_labels)  # number of elements
    assert len(pred_labels) == n
    purity = 0
    for l in set(pred_labels):
        # for predicted label l, what are the gt_labels of this cluster?
        gt = [gt_labels[i] for i, il in enumerate(pred_labels) if il == l]
        # most frequent gt label in this cluster:
        mfgt = max(set(gt), key=gt.count)
        # count intersection between most frequent ground truth and this cluster
        purity += gt.count(mfgt)
    return float(purity) / n

See this answer for more details on selecting the most frequent label in a cluster.
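Purity, as defined in the clustering-evaluation article linked in the comments, assigns each cluster to its most frequent ground-truth class and then measures the fraction of points that are correctly assigned. A single-pass variant of that computation could look like the sketch below: it builds one `Counter` of ground-truth labels per predicted cluster instead of calling `list.count` repeatedly. The function name `purity` and the example labels are made up for illustration.

```python
from collections import Counter, defaultdict

def purity(gt_labels, pred_labels):
    """Fraction of points that carry the majority ground-truth label of their cluster."""
    if len(gt_labels) != len(pred_labels):
        raise ValueError("label lists must have the same length")
    # one Counter of ground-truth labels per predicted cluster
    clusters = defaultdict(Counter)
    for gt, pred in zip(gt_labels, pred_labels):
        clusters[pred][gt] += 1
    # size of the majority class in each cluster, summed over clusters
    majority_total = sum(max(counts.values()) for counts in clusters.values())
    return majority_total / len(gt_labels)

# Hypothetical labels: 8 points, 3 ground-truth classes, 3 predicted clusters.
gt = [0, 0, 0, 1, 1, 2, 2, 2]
pred = [0, 0, 1, 1, 1, 2, 2, 0]
score = purity(gt, pred)
print(score)  # 0.75: majority counts are 2, 2 and 2 out of 8 points
```

With the data from the question, `gt_labels` would be `label_list` and `pred_labels` the cluster assignment produced by the clustering step.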

+0

I computed the centroids using k-means clustering; I am more interested in computing the purity measure. @Shai – cpwah

+0

@cpwah This is a **different** question: for purity, you need both ground-truth labels and the cluster assignments from 'kmeans' (e.g. 'label_list' in your example). – Shai

+0

Yes, I have revised my question. As you pointed out, the ground-truth labels are in 'label_list'. @Shai – cpwah