2016-12-15 158 views
0

工作在Kaggle泰坦尼克号数据集。我试图更好地理解决策树,我已经很好地使用了线性回归,但从来没有决策树。我正在尝试为我的树在Python中创建一个可视化文件。有些东西虽然不起作用。在下面检查我的代码。决策树与SKlearn和可视化

import pandas as pd 
from sklearn import tree 
from sklearn.datasets import load_iris 
import numpy as np 


train_file='.......\RUN.csv' 
train=pd.read_csv(train_file) 

#impute number values and missing values 
train["Sex"][train["Sex"] == "male"] = 0 
train["Sex"][train["Sex"] == "female"] = 1 
train["Embarked"] = train["Embarked"].fillna("S") 
train["Embarked"][train["Embarked"] == "S"]= 0 
train["Embarked"][train["Embarked"] == "C"]= 1 
train["Embarked"][train["Embarked"] == "Q"]= 2 
train["Age"] = train["Age"].fillna(train["Age"].median()) 
train["Pclass"] = train["Pclass"].fillna(train["Pclass"].median()) 
train["Fare"] = train["Fare"].fillna(train["Fare"].median()) 

target = train["Survived"].values 
features_one = train[["Pclass", "Sex", "Age", "Fare","SibSp","Parch","Embarked"]].values 


# Fit your first decision tree: my_tree_one 
my_tree_one = tree.DecisionTreeClassifier(max_depth = 10, min_samples_split = 5, random_state = 1) 

iris=load_iris() 

my_tree_one = my_tree_one.fit(features_one, target) 

tree.export_graphviz(my_tree_one, out_file='tree.dot') 

我该如何看到决策树?尝试将其可视化。

帮助感谢!

回答

2

你检查:http://scikit-learn.org/stable/modules/tree.html提到如何绘制树PNG图像:

from IPython.display import Image 
import pydotplus 
dot_data = tree.export_graphviz(my_tree_one, out_file='tree.dot') 
graph = pydotplus.graph_from_dot_data(dot_data) ` 
Image(graph.create_png()) 
+0

>>> import os >>> os.unlink('iris.dot') –

+0

I t说这样做^。但是,只是删除该文件。有任何想法吗?我也没有pydotplus。我试着用pip下载它,但没有奏效。 –

+0

我认为问题是Graphiz,你应该下载它:http://www.graphviz.org/Download..php http://stackoverflow.com/questions/18438997/why-is-pydot-unable-to-find -graphvizs-可执行文件,在窗口-8。首先安装graphiz然后pydot。或者使用linux。稍后我会回到它。 – Roxanne

0

维基百科:

的DOT语言定义的图形,但不提供用于呈现设施图形。有迹象表明,可以用来渲染,查看和操作的DOT语言图形几个方案:

的Graphviz - 库和工具的集合,操作和渲染图

Canviz - 一个JavaScript库,用于渲染点文件。

Viz.js - 一个简单的Graphviz JavaScript客户

拉帕 - 的Graphviz的局部端口到Java [4] [5]

Beluging - Python & Google云基于DOT和Beluga扩展的查看器。 [1]

郁金香可以导入点文件进行分析

的OmniGraffle可以导入DOT的子集,产生一个可编辑的文档。 (结果却无法回输到DOT。)

ZGRViewer,一个GraphViz的/ DOT查看器链接

VizierFX中,缩放图形渲染库链接

Gephi - 交互式可视化和勘探平台各种网络和复杂系统,动态和分层图形

因此,这些程序中的任何一个都能够可视化你的树。

+0

我已经使用graphviz,但无法将其显示为图像。它只是将它写入.dot文件。我已经尝试将ti更改为pdf,但似乎无法使其工作。 –

+0

我相信这应该只是写入.dot文件。然后您必须使用列出的应用程序之一来查看.dot文件。我个人喜欢格西。 –

0

我用条形图做了一个可视化。第一个图表示类的分布。第一个标题代表第一个分裂标准。所有满足这个标准的数据都会导致左下方的子图。如果不是,则右图是结果。因此,所有标题都表示下一次拆分的拆分标准。

百分比是来自初始分布的值。因此,通过查看百分比,可以容易地从初始数量的数据中获得多少分割后剩下的数据。

注意,如果你设置MAX_DEPTH高,这将需要大量的次要情节的(MAX_DEPTH,2 ^深度)

Tree visualization using bar plots

代码:

def give_nodes(nodes,amount_of_branches,left,right): 
    amount_of_branches*=2 
    nodes_splits=[] 
    for node in nodes: 
     nodes_splits.append(left[node]) 
     nodes_splits.append(right[node]) 
    return (nodes_splits,amount_of_branches) 

def plot_tree(tree, feature_names): 
    from matplotlib import gridspec 
    import matplotlib.pyplot as plt 
    from matplotlib import rc 
    import pylab 

    color = plt.cm.coolwarm(np.linspace(1,0,len(feature_names))) 

    plt.rc('text', usetex=True) 
    plt.rc('font', family='sans-serif') 
    plt.rc('font', size=14) 

    params = {'legend.fontsize': 20, 
      'axes.labelsize': 20, 
      'axes.titlesize':25, 
      'xtick.labelsize':20, 
      'ytick.labelsize':20} 
    plt.rcParams.update(params) 

    max_depth=tree.max_depth 
    left  = tree.tree_.children_left 
    right  = tree.tree_.children_right 
    threshold = tree.tree_.threshold 
    features = [feature_names[i] for i in tree.tree_.feature] 
    value = tree.tree_.value 

    fig = plt.figure(figsize=(3*2**max_depth,2*2**max_depth)) 
    gs = gridspec.GridSpec(max_depth, 2**max_depth) 
    plt.subplots_adjust(hspace = 0.6, wspace=0.8) 

    # All data 
    amount_of_branches=1 
    nodes=[0] 
    normalize=np.sum(value[0][0]) 

    for i,node in enumerate(nodes): 
     ax=fig.add_subplot(gs[0,(2**max_depth*i)/amount_of_branches:(2**max_depth*(i+1))/amount_of_branches]) 
     ax.set_title(features[node]+"$<= "+str(threshold[node])+"$") 
     if(i==0): ax.set_ylabel(r'$\%$') 
     ind=np.arange(1,len(value[node][0])+1,1) 
     width=0.2 
     bars= (np.array(value[node][0])/normalize)*100 
     plt.bar(ind-width/2, bars, width,color=color,alpha=1,linewidth=0) 
     plt.xticks(ind, [int(i) for i in ind-1]) 
     pylab.ticklabel_format(axis='y',style='sci',scilimits=(0,2)) 

    # Splits 
    for j in range(1,max_depth): 
     nodes,amount_of_branches=give_nodes(nodes,amount_of_branches,left,right) 
     for i,node in enumerate(nodes): 
      ax=fig.add_subplot(gs[j,(2**max_depth*i)/amount_of_branches:(2**max_depth*(i+1))/amount_of_branches]) 
      ax.set_title(features[node]+"$<= "+str(threshold[node])+"$") 
      if(i==0): ax.set_ylabel(r'$\%$') 
      ind=np.arange(1,len(value[node][0])+1,1) 
      width=0.2 
      bars= (np.array(value[node][0])/normalize)*100 
      plt.bar(ind-width/2, bars, width,color=color,alpha=1,linewidth=0) 
      plt.xticks(ind, [int(i) for i in ind-1]) 
      pylab.ticklabel_format(axis='y',style='sci',scilimits=(0,2)) 


    plt.tight_layout() 
    return fig 

例子:

X=[] 
Y=[] 
amount_of_labels=5 
feature_names=[ '$x_1$','$x_2$','$x_3$','$x_4$','$x_5$'] 
for i in range(200): 
    X.append([np.random.normal(),np.random.randint(0,100),np.random.uniform(200,500) ]) 
    Y.append(np.random.randint(0,amount_of_labels)) 

clf = tree.DecisionTreeClassifier(criterion='entropy',max_depth=4) 
clf = clf.fit(X,Y) 
fig=plot_tree(clf, feature_names)