2017-08-15

I have a Spark DataFrame that contains geographic information. How can I add the results of the PySpark k-means algorithm to the DataFrame?

my_df.show(2) 

## +----+----+-----------+----------+ 
## | x0 | x1 | longitude | latitude | 
## +----+----+-----------+----------+ 
## | ...| ...| 51.043 | 13.6847 | 
## | ...| ...| 42.6753 | 23.3218 | 

I extracted the longitude and latitude from my DataFrame and calculated some center points with PySpark's k-means library.

# Train a k-means model (KMeans.train is the RDD-based pyspark.mllib API)
from pyspark.mllib.clustering import KMeans
k = 120
model = KMeans.train(dataset, k)
print("Final centers: " + str(model.clusterCenters))

Output

Final centers: [array([ 51.04307692, 13.68474126]), array([-33.434  , -70.58366667]), array([ 42.67533333, 23.32185981]), array([ 45.876, -61.492]), array([ 53.07465714, 8.4655 ]), array([ 4.594, 114.262]), array([ 48.15665306, 11.54269728]), array([ 51.51729851, 7.49838806]), array([ 48.76316125, 9.15357859]), .... 

Any idea how to add the matching centers to my DataFrame?

## +----+----+-----------+----------+-----------+----------+ 
## | x0 | x1 | longitude | latitude | mean_long | mean_lat | 
## +----+----+-----------+----------+-----------+----------+ 
## | ...| ...| 51.043 | 13.6847 | 50.000 | 15.000 | 
## | ...| ...| 42.6753 | 23.3218 | 50.000 | 15.000 | 

Answers


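If you decide to use DataFrames, you should use the new pyspark.ml API, not the legacy pyspark.mllib. It provides several clustering methods, including K-Means, and the fitted model's transform method appends a prediction column to the DataFrame.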

See the ML documentation for details (the API and the required input types).
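To make the idea behind the prediction column concrete, here is a minimal plain-Python sketch (not Spark code) of what "matching centers to rows" amounts to: each row gets the index of its nearest center, and the center's coordinates are then looked up by that index. The helper names `squared_distance`, `nearest_center` and `attach_centers` are made up for this illustration, the sample values are the first ones shown in the question, and plain squared Euclidean distance on the raw coordinates is assumed.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length coordinate tuples."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest_center(point, centers):
    """Return the index of the center closest to `point`."""
    return min(range(len(centers)),
               key=lambda i: squared_distance(point, centers[i]))


def attach_centers(rows, centers):
    """Append (cluster_id, mean_long, mean_lat) to each (longitude, latitude) row."""
    result = []
    for longitude, latitude in rows:
        cluster_id = nearest_center((longitude, latitude), centers)
        mean_long, mean_lat = centers[cluster_id]
        result.append((longitude, latitude, cluster_id, mean_long, mean_lat))
    return result


# Sample rows and the first three centers, copied from the question
rows = [(51.043, 13.6847), (42.6753, 23.3218)]
centers = [(51.04307692, 13.68474126),
           (-33.434, -70.58366667),
           (42.67533333, 23.32185981)]

for row in attach_centers(rows, centers):
    print(row)
```

In Spark the same two steps would be done on DataFrames: the prediction column plays the role of `cluster_id`, and a join against a small table of centers plays the role of the lookup.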


Good suggestion, but I am currently working with a Spark 1.6.3 cluster, which does not have a Python implementation of k-means in pyspark.ml.


Hope this helps!
(I have taken the sample data from the Spark documentation page.)

from pyspark.ml.linalg import Vectors 
from pyspark.ml.clustering import KMeans 
import pandas as pd 

#generate data 
data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),), 
     (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)] 
df = sqlContext.createDataFrame(data, ["features"]) 
df.show() 

# Run the k-means clustering model
kmeans = KMeans(k=2, seed=1)
model = kmeans.fit(df)
# transform() appends a "prediction" column; rename it so it can serve as the join key
predictions = model.transform(df).withColumnRenamed("prediction", "cluster_id")

# Cluster centers as a list of arrays; the list index is the cluster id
centers = model.clusterCenters()
# Turn the centers into a DataFrame so they can be joined with the predictions
centers_p_df = pd.DataFrame(centers)
centers_p_df.insert(0, 'cluster_id', range(0, len(centers_p_df)))
centers_df = sqlContext.createDataFrame(centers_p_df, schema=['cluster_id', 'centers_col1', 'centers_col2'])

# Attach each row's center coordinates, then drop the helper key
final_df = predictions.join(centers_df, on="cluster_id").drop("cluster_id")
final_df.show()