2012-04-11 39 views
3

我有一个函数,计算一些简单的node.id,node.parentId关联的treeNodes集合的左右节点值。它非常简单并且运作良好......但是,我想知道是否有更习惯的方法。特别是有没有一种方法来跟踪左/右值,而不使用一些外部跟踪值,但仍然保持美味的递归。我该如何使这种方法更Scalalicious

/* 
* A tree node 
*/ 
case class TreeNode(val id:String, val parentId: String){ 
    var left: Int = 0 
    var right: Int = 0 
} 

/* 
* a method to compute the left/right node values 
*/ 
def walktree(node: TreeNode) = { 
    /* 
    * increment state for the inner function 
    */ 
    var c = 0 

    /* 
    * A method to set the increment state 
    */ 
    def increment = { c+=1; c } // poo 

    /* 
    * the tasty inner method 
    * treeNodes is a List[TreeNode] 
    */ 
    def walk(node: TreeNode): Unit = { 
     node.left = increment 

     /* 
     * recurse on all direct descendants 
     */ 
     treeNodes filter(_.parentId == node.id) foreach (walk(_)) 

     node.right = increment 
    } 

    walk(node) 
} 

walktree(someRootNode) 

编辑 - 节点列表取自数据库。将节点拉入适当的树会花费太多时间。我正在将一个单子列表放入记忆中,而我所拥有的是一个通过节点标识符与父母和孩子相关的关联。

添加左/右节点值允许我通过单个SQL查询获得所有儿童(和儿童儿童)的快照。

如果父母 - 子女关联发生变化(他们经常这样做),计算需要非常快速地执行以保持数据完整性。

除了使用令人​​敬畏的Scala集合外,我还通过对树节点上的某些前/后过滤使用并行处理来提高速度。我想找到更加习惯的方式来跟踪左/右节点值。从@dhg看到答案后,情况变得更好了。使用groupBy而不是过滤器可以使算法(主要是?)线性而不是四边形!

val treeNodeMap = treeNodes.groupBy(_.parentId).withDefaultValue(Nil) 

def walktree(node: TreeNode) = { 
    def walk(node: TreeNode, counter: Int): Int = { 
     node.left = counter 
     node.right = 
      treeNodeMap(node.id) 
      .foldLeft(counter+1) { 
      (result, curnode) => walk(curnode, result) + 1 
     } 
     node.right 
    } 
    walk(node,1) 
} 
+0

'treeNodes'定义在哪里?有没有理由你没有递归定义TreeNode? 'walktree'有什么意义?重新编号“左”和“右”值?为什么不是与'TreeNode's相关的'left'和'right'值? – dhg 2012-04-11 16:10:13

+3

这不是codereview,但你需要少量评论,而且你至少需要一个有用的评论。您的评论完全为零说明代码尚未告诉您的任何相关内容。把它们全部扔掉,并添加两行描述目标。 – 2012-04-11 16:52:09

+0

@Rex,我有同样的想法:-) – dhg 2012-04-11 16:55:39

回答

6

您的代码看起来是计算中序遍历的编号。

我认为你想让你的代码更好的是fold,它将当前值向下并向上传递更新后的值。请注意,在walktree之前执行treeNodes.groupBy(_.parentId)以防止每次拨打walk时致电treeNodes.filter(...)也可能值得。

val treeNodes = List(TreeNode("1","0"),TreeNode("2","1"),TreeNode("3","1")) 

val treeNodeMap = treeNodes.groupBy(_.parentId).withDefaultValue(Nil) 

def walktree2(node: TreeNode) = { 
    def walk(node: TreeNode, c: Int): Int = { 
    node.left = c 
    val newC = 
     treeNodeMap(node.id)   // get the children without filtering 
     .foldLeft(c+1)((c, child) => walk(child, c) + 1) 
    node.right = newC 
    newC 
    } 

    walk(node, 1) 
} 

并可以产生相同的结果:

scala> walktree2(TreeNode("0","-1")) 
scala> treeNodes.map(n => "(%s,%s)".format(n.left,n.right)) 
res32: List[String] = List((2,7), (3,4), (5,6)) 

这就是说,我会完全重写你的代码如下:

case class TreeNode(  // class is now immutable; `walktree` returns a new tree 
    id: String, 
    value: Int,    // value to be set during `walktree` 
    left: Option[TreeNode], // recursively-defined structure 
    right: Option[TreeNode]) // makes traversal much simpler 

def walktree(node: TreeNode) = { 
    def walk(nodeOption: Option[TreeNode], c: Int): (Option[TreeNode], Int) = { 
    nodeOption match { 
     case None => (None, c) // if this child doesn't exist, do nothing 
     case Some(node) =>  // if this child exists, recursively walk 
     val (newLeft, cLeft) = walk(node.left, c)  // walk the left side 
     val newC = cLeft + 1        // update the value 
     val (newRight, cRight) = walk(node.right, newC) // walk the right side 
     (Some(TreeNode(node.id, newC, newLeft, newRight)), cRight) 
    } 
    } 

    walk(Some(node), 0)._1 
} 

然后你可以使用它像这样:

walktree(
    TreeNode("1", -1, 
    Some(TreeNode("2", -1, 
     Some(TreeNode("3", -1, None, None)), 
     Some(TreeNode("4", -1, None, None)))), 
    Some(TreeNode("5", -1, None, None)))) 

要生产:

Some(TreeNode(1,4, 
    Some(TreeNode(2,2, 
    Some(TreeNode(3,1,None,None)), 
    Some(TreeNode(4,3,None,None)))), 
    Some(TreeNode(5,5,None,None)))) 
+0

美味foldLeft。纯金。 – 2012-04-11 16:42:49

+0

+1了解什么是最好的,以找到解决方案 – 2012-04-11 16:53:09

+0

它实际上比黄金好。这会让节点散步时间减少半秒!谢谢! – 2012-04-11 16:58:56

1

如果我得到你的算法正确:

def walktree(node: TreeNode, c: Int): Int = { 
    node.left = c 

    val c2 = treeNodes.filter(_.parentId == node.id).foldLeft(c + 1) { 
     (cur, n) => walktree(n, cur) 
    } 

    node.right = c2 + 1 
    c2 + 2 
} 

walktree(new TreeNode("", ""), 0) 

关闭接一个错误都可能发生。

一些随机的想法(更适合http://codereview.stackexchange.com):

  • ,编译......我们不得不猜测,尝试张贴是TreeNode序列:

  • val是隐含的case类:

    case class TreeNode(val id: String, val parentId: String) { 
    
  • 避免显式=一个d UnitUnit功能:

    def walktree(node: TreeNode) = { 
    def walk(node: TreeNode): Unit = { 
    
  • 有副作用的方法应该有()

    def increment = {c += 1; c} 
    
  • 这是非常缓慢的,考虑存储在实际的子节点的名单:

    treeNodes filter (_.parentId == node.id) foreach (walk(_)) 
    
  • 更简洁的语法是treeNodes foreach walk

    treeNodes foreach (walk(_)) 
    
+0

谢谢,汤姆。我会看codereview – 2012-04-11 17:05:56