2017-04-16 46 views
2

计算和序列假设我有一个data.table通过序列

data.table(A=c(1,2,3,4,5,6,4,2)) 

我如何计算n元素序列的总和?

假设n=3,A的顺序总和应列seq_sum的结果,

data.table(A=c(1,2,3,4,5,6,4,2),seq_sum=c(1+2+3,2+3+4,3+4+5,4+5+6,5+6+4,6+4+2,4+2,2)) 

如何有效地做到这一点?

回答

7

另一种选择是使用Reduceshift

dt[, seq_sum := Reduce(`+`, shift(A, 0:2, 0, 'lead'))] 

其给出:

> dt 
    A seq_sum 
1: 1  6 
2: 2  9 
3: 3  12 
4: 4  15 
5: 5  15 
6: 6  12 
7: 4  6 
8: 2  2 

全符号与参数名:

dt[, seq_sum := Reduce(`+`, shift(A, n = 0:2, fill = 0, type = 'lead'))] 
+0

谢谢!没有意识到R中有'shift'功能。 – WCMC

+0

@WCMC'shift'是'data.table'包中的一个函数;它与“滞后”和“引导”功能相当 – Jaap

2
library(data.table) 
dt <- data.table(A=c(1,2,3,4,5,6,4,2)) 
n = 3 
sapply(1:(length(dt$A)), function(i) {sum(dt$A[i:(min(i+n-1,length(dt$A)))])})  

    # [1] 6 9 12 15 15 12 6 2 
3

更新基于评论:

您还可以使用rollapplyzoo包:

library(data.table) 
library(zoo) 
dt <- data.table(A=c(1,2,3,4,5,6,4,2)) 
dt[, seq_sum := rollapply(A, width = 3, FUN = "sum", align = "left", partial = TRUE)] 

# > dt 
# A seq_sum 
# 1: 1  6 
# 2: 2  9 
# 3: 3  12 
# 4: 4  15 
# 5: 5  15 
# 6: 6  12 
# 7: 4  6 
# 8: 2  2 
+0

为什么最后一个元素9而不是2? – epi99

+0

在我看来,最后一个值是错误的。它应该是2而不是9。 – KoenV

+0

'dt [,seq_sum:= rollapply(A,3,sum,partial = TRUE,align =“left”)]'更正结果 – epi99

2
library(zoo) 

dtab <- data.table(A=c(1,2,3,4,5,6,4,2)) 
dtab[, seq_sum := rollapplyr(A, 3, sum, partial = TRUE, align = "left")] 
3

为了避免重复求和,一个累积性总和可以存储:

n = 3 
A2 = c(A, rep_len(0, n - 1)) 

cs = cumsum(A2) 

并减去相应的差异:

cs[-seq_len(n - 1)] - c(0, cs[seq_len(length(A2) - n)]) 
#[1] 6 9 12 15 15 12 6 2 

,或等效:

diff(c(0, cs), n) 
#[1] 6 9 12 15 15 12 6 2 
1

下面是使用RcppRoll:suml有的计时供你参考的另一种方法。 @ Jaap的解决方案使用data.table内置函数是最快的。

library(data.table) 
library(microbenchmark) 

N <- 1e5 
set.seed(0L) 
dt <- data.table(A=rnorm(N)) 
n <- 3 

dt_cumsum <- copy(dt) 
fun_cumsum <- function() { 
    dt_cumsum[, seq_sum := { 
     cs <- cumsum(c(A, rep_len(0, n - 1))) 
     diff(c(0, cs), n) 
    }] 
} 

dt_Reduce <- copy(dt) 
fun_Reduce <- function() { 
    dt_Reduce[, seq_sum := Reduce(`+`, shift(A, n = seq_len(n) - 1, fill = 0, type = 'lead'))] 
} 

library(zoo) 
dt_zoo <- copy(dt) 
fun_zoo <- function() { 
    dt_zoo[, seq_sum := rollapply(A, width = n, FUN = "sum", align = "left", partial = TRUE)] 
} 

fun_base <- function() { 
    sapply(1:(length(dt$A)), function(i) {sum(dt$A[i:(min(i+n-1,length(dt$A)))])}) 
} 

library(RcppRoll) 
dt_RcppRoll <- copy(dt) 
fun_RcppRoll <- function() { 
    dt_RcppRoll[, seq_sum:=head(roll_suml(c(A, rep_len(0, n - 1)), n), -(n-1))] 
} 

ans <- capture.output(microbenchmark(
    fun_cumsum(), 
    fun_Reduce(), 
    fun_zoo(), 
    fun_base(), 
    fun_RcppRoll(), 
    times=5L)) 
writeLines(paste("#", ans)) 

# Unit: milliseconds 
#   expr  min  lq  mean median  uq  max neval 
# fun_cumsum() 2.5983 2.6427 2.67526 2.6462 2.7311 2.7580  5 
# fun_Reduce() 1.3903 1.4274 2.84070 1.6620 1.7047 8.0191  5 
#  fun_zoo() 1225.1620 1242.9112 1289.76416 1258.1143 1355.1070 1367.5263  5 
#  fun_base() 2731.6609 2849.1003 2909.27308 2922.9430 2971.9956 3070.6656  5 
# fun_RcppRoll() 1.7890 1.8430 3.49892 1.9663 2.0774 9.8189  5