2012-03-12 113 views
0

我尝试实现并分析下面这段 k-d 树(kd-tree)的实现代码:

#include <iostream.h> 
    #include "vector.h" 

    /**
     * Quick illustration of a two-dimensional tree.
     * No abstraction here.
     *
     * Each point is a vector whose coordinates 0 and 1 are used for
     * splitting: the root compares coordinate 0, its children compare
     * coordinate 1, and so on alternately. A point whose split coordinate
     * is strictly less than the node's goes into the left subtree; all
     * others (ties included) go into the right subtree.
     *
     * The tree owns its nodes: the destructor frees them, and copying a
     * tree (construction or assignment) makes an independent deep copy.
     */
    template <class Comparable>
    class KdTree
    {
      public:
        KdTree() : root(NULL) { }

        /** Builds a deep copy of rhs; the two trees share no nodes. */
        KdTree(const KdTree & rhs) : root(clone(rhs.root)) { }

        /** Frees every node owned by this tree. */
        ~KdTree() { makeEmpty(root); }

        /**
         * Replaces this tree's contents with a deep copy of rhs.
         * Self-assignment leaves the tree unchanged.
         */
        KdTree & operator=(const KdTree & rhs)
        {
            if(this != &rhs)
            {
                // Clone first so *this is left untouched if copying throws.
                KdNode *copy = clone(rhs.root);
                makeEmpty(root);
                root = copy;
            }
            return *this;
        }

        /** Inserts a copy of point x (x must have at least two coordinates). */
        void insert(const vector<Comparable> & x)
        {
            insert(x, root, 0);
        }

        /**
         * Print items satisfying
         * low[ 0 ] <= x[ 0 ] <= high[ 0 ] and
         * low[ 1 ] <= x[ 1 ] <= high[ 1 ]
         * one per line as "(x0,x1)", visiting nodes in pre-order.
         */
        void printRange(const vector<Comparable> & low,
                        const vector<Comparable> & high) const
        {
            printRange(low, high, root, 0);
        }

      private:
        /** One stored point plus its two child links. */
        struct KdNode
        {
            vector<Comparable> data;
            KdNode   *left;
            KdNode   *right;

            KdNode(const vector<Comparable> & item)
              : data(item), left(NULL), right(NULL) { }
        };

        KdNode *root;

        // Recursive insert; `level` (0 or 1) is the coordinate compared at t.
        void insert(const vector<Comparable> & x, KdNode * & t, int level)
        {
            if(t == NULL)
                t = new KdNode(x);
            else if(x[ level ] < t->data[ level ])
                insert(x, t->left, 1 - level);
            else
                insert(x, t->right, 1 - level);
        }

        // Recursive range report. A subtree is entered only when it can
        // contain points inside [low, high] on the split coordinate.
        void printRange(const vector<Comparable> & low,
                        const vector<Comparable> & high,
                        KdNode *t, int level) const
        {
            if(t != NULL)
            {
                if(low[ 0 ] <= t->data[ 0 ] && high[ 0 ] >= t->data[ 0 ] &&
                   low[ 1 ] <= t->data[ 1 ] && high[ 1 ] >= t->data[ 1 ])
                    cout << "(" << t->data[ 0 ] << ","
                         << t->data[ 1 ] << ")\n";

                // Left subtree holds points with x[level] < t->data[level].
                if(low[ level ] <= t->data[ level ])
                    printRange(low, high, t->left, 1 - level);
                // Right subtree holds points with x[level] >= t->data[level].
                if(high[ level ] >= t->data[ level ])
                    printRange(low, high, t->right, 1 - level);
            }
        }

        /** Returns a deep copy of the subtree rooted at t (NULL for NULL). */
        static KdNode * clone(const KdNode *t)
        {
            if(t == NULL)
                return NULL;

            KdNode *copy = new KdNode(t->data);
            try
            {
                copy->left = clone(t->left);
                copy->right = clone(t->right);
            }
            catch(...)
            {
                // Do not leak the partially built copy.
                makeEmpty(copy);
                throw;
            }
            return copy;
        }

        /** Deletes every node of the subtree rooted at t and sets t to NULL. */
        static void makeEmpty(KdNode * & t)
        {
            if(t != NULL)
            {
                makeEmpty(t->left);
                makeEmpty(t->right);
                delete t;
                t = NULL;
            }
        }
    };

     // Test program 
     int main() 
     { 
      KdTree<int> t; 

      cout << "Starting program" << endl; 
      for(int i = 300; i < 370; i++) 
      { 
       vector<int> it(2); 
       it[ 0 ] = i; 
       it[ 1 ] = 2500 - i; 
       t.insert(it); 
      } 

      vector<int> low(2), high(2); 
      low[ 0 ] = 70; 
      low[ 1 ] = 2186; 
      high[ 0 ] = 1200; 
      high[ 1 ] = 2200; 

      t.printRange(low, high); 

      return 0; 
     } 

问题是,这里用到的 vector 类(来自 "vector.h")的源码很难看懂,所以我想改用 C++ STL 中现成的 std::vector,但不知道该怎么改。请帮帮我,例如在 insert 函数中应该如何使用 std::vector 等等,谢谢!

+0

如果这些向量的 size() 总是等于 2,那么把它们换成 std::pair 可能是个好主意。 – 2012-03-13 10:54:49

回答

3

您的代码本身已经与 STL 兼容:我只是更换了头文件,另外主要为了便于阅读,引入了一个 typedef:

#include <iostream> 
#include <vector> 
using namespace std; 

/**
 * Quick illustration of a two-dimensional tree.
 * No abstraction here.
 *
 * Each point is a tVec whose coordinates 0 and 1 are used for splitting:
 * the root compares coordinate 0, its children compare coordinate 1, and
 * so on alternately. A point whose split coordinate is strictly less than
 * the node's goes into the left subtree; all others (ties included) go
 * into the right subtree.
 *
 * The tree owns its nodes: the destructor frees them, and copying a tree
 * (construction or assignment) makes an independent deep copy.
 */
template <class Comparable>
class KdTree
{
public:
    typedef std::vector<Comparable> tVec;

    KdTree() : root(NULL) { }

    /** Builds a deep copy of rhs; the two trees share no nodes. */
    KdTree(const KdTree & rhs) : root(clone(rhs.root)) { }

    /** Frees every node owned by this tree. */
    ~KdTree() { makeEmpty(root); }

    /**
     * Replaces this tree's contents with a deep copy of rhs.
     * Self-assignment leaves the tree unchanged.
     */
    KdTree & operator=(const KdTree & rhs)
    {
        if(this != &rhs)
        {
            // Clone first so *this is left untouched if copying throws.
            KdNode *copy = clone(rhs.root);
            makeEmpty(root);
            root = copy;
        }
        return *this;
    }

    /** Inserts a copy of point x (x must have at least two coordinates). */
    void insert(const tVec & x)
    {
        insert(x, root, 0);
    }

    /**
     * Print items satisfying
     * low[ 0 ] <= x[ 0 ] <= high[ 0 ] and
     * low[ 1 ] <= x[ 1 ] <= high[ 1 ]
     * one per line as "(x0,x1)", visiting nodes in pre-order.
     */
    void printRange(const tVec & low,
                    const tVec & high) const
    {
        printRange(low, high, root, 0);
    }

private:
    /** One stored point plus its two child links. */
    struct KdNode
    {
        tVec data;
        KdNode   *left;
        KdNode   *right;

        KdNode(const tVec & item)
          : data(item), left(NULL), right(NULL) { }
    };

    KdNode *root;

    // Recursive insert; `level` (0 or 1) is the coordinate compared at t.
    void insert(const tVec & x, KdNode * & t, int level)
    {
        if(t == NULL)
            t = new KdNode(x);
        else if(x[ level ] < t->data[ level ])
            insert(x, t->left, 1 - level);
        else
            insert(x, t->right, 1 - level);
    }

    // Recursive range report. A subtree is entered only when it can
    // contain points inside [low, high] on the split coordinate.
    void printRange(const tVec & low,
                    const tVec & high,
                    KdNode *t, int level) const
    {
        if(t != NULL)
        {
            if(low[ 0 ] <= t->data[ 0 ] && high[ 0 ] >= t->data[ 0 ] &&
               low[ 1 ] <= t->data[ 1 ] && high[ 1 ] >= t->data[ 1 ])
                std::cout << "(" << t->data[ 0 ] << ","
                          << t->data[ 1 ] << ")\n";

            // Left subtree holds points with x[level] < t->data[level].
            if(low[ level ] <= t->data[ level ])
                printRange(low, high, t->left, 1 - level);
            // Right subtree holds points with x[level] >= t->data[level].
            if(high[ level ] >= t->data[ level ])
                printRange(low, high, t->right, 1 - level);
        }
    }

    /** Returns a deep copy of the subtree rooted at t (NULL for NULL). */
    static KdNode * clone(const KdNode *t)
    {
        if(t == NULL)
            return NULL;

        KdNode *copy = new KdNode(t->data);
        try
        {
            copy->left = clone(t->left);
            copy->right = clone(t->right);
        }
        catch(...)
        {
            // Do not leak the partially built copy.
            makeEmpty(copy);
            throw;
        }
        return copy;
    }

    /** Deletes every node of the subtree rooted at t and sets t to NULL. */
    static void makeEmpty(KdNode * & t)
    {
        if(t != NULL)
        {
            makeEmpty(t->left);
            makeEmpty(t->right);
            delete t;
            t = NULL;
        }
    }
};

// Test program 
int main_kdtree(int, char **) 
{ 
    typedef KdTree<int> tTree; 
    tTree t; 

    cout << "Starting program" << endl; 
    for(int i = 300; i < 370; i++) 
    { 
     tTree::tVec it(2); 
     it[ 0 ] = i; 
     it[ 1 ] = 2500 - i; 
     t.insert(it); 
    } 

    tTree::tVec low(2), high(2); 
    low[ 0 ] = 70; 
    low[ 1 ] = 2186; 
    high[ 0 ] = 1200; 
    high[ 1 ] = 2200; 

    t.printRange(low, high); 

    return 0; 
} 

输出:

Starting program 
(300,2200) 
(301,2199) 
.... 
(313,2187) 
(314,2186)