2016-04-27 109 views
12

在机器学习任务。我们应该得到一组具有约束的随机w.r.t正态分布。我们可以通过np.random.normal()得到一个正态分布号,但它不提供任何绑定参数。我想知道如何做到这一点?如何在numpy的范围内获得正态分布?

+4

当不可t正常的随机样本按定义分布式数据是无界的? – Tom

回答

8

如果你正在寻找的Truncated normal distribution,SciPy的有一个功能叫truncnorm

这种分布的标准形式是一个标准的正常截断 区间[A,B] - 注意, a和b在标准法线的域 上定义。要转换剪辑的值对于特定均值和 标准偏差,使用:

A,B =(myclip_a - my_mean)/ my_std,(myclip_b - my_mean)/ my_std

truncnorm取A和B作为形状参数。

>>> from scipy.stats import truncnorm 
>>> truncnorm(a=-2/3., b=2/3., scale=3).rvs(size=10) 
array([-1.83136675, 0.77599978, -0.01276925, 1.87043384, 1.25024188, 
     0.59336279, -0.39343176, 1.9449987 , -1.97674358, -0.31944247]) 

上面的例子是由有界-2和2,并返回10个随机变元(使用.rvs()方法)

>>> min(truncnorm(a=-2/3., b=2/3., scale=3).rvs(size=10000)) 
-1.9996074381484044 
>>> max(truncnorm(a=-2/3., b=2/3., scale=3).rvs(size=10000)) 
1.9998486576228549 

下面是-6直方图,6:

enter image description here

+0

为什么你不使用truncnorm(a = -2,b = 2,scale = 1) – maple

+2

只是为了说清楚a和b是形状参数,否则读者可能会尝试-2,2的比例不等于1 ,然后得到外部的随机值[-2,2] – bakkal

12

参数化truncnorm复杂,所以这里是转换参数化的东西更直观的功能:

from scipy.stats import truncnorm 

def get_truncated_normal(mean=0, sd=1, low=0, upp=10): 
    return truncnorm(
     (low - mean)/sd, (upp - mean)/sd, loc=mean, scale=sd) 


如何使用它?

  1. 实例与参数发生器:意味着标准偏差,和截断范围

    >>> X = get_truncated_normal(mean=8, sd=2, low=1, upp=10) 
    
  2. 然后,可以使用X,以产生值:

    >>> X.rvs() 
    6.0491227353928894 
    
  3. 或者,numpy a rray用N产生的值:

    >>> X.rvs(10) 
    array([ 7.70231607, 6.7005871 , 7.15203887, 6.06768994, 7.25153472, 
         5.41384242, 7.75200702, 5.5725888 , 7.38512757, 7.47567455]) 
    

甲视觉例

这里是三个不同的截短的正态分布的情节:

X1 = get_truncated_normal(mean=2, sd=1, low=1, upp=10) 
X2 = get_truncated_normal(mean=5.5, sd=1, low=1, upp=10) 
X3 = get_truncated_normal(mean=8, sd=1, low=1, upp=10) 

import matplotlib.pyplot as plt 
fig, ax = plt.subplots(3, sharex=True) 
ax[0].hist(X1.rvs(10000), normed=True) 
ax[1].hist(X2.rvs(10000), normed=True) 
ax[2].hist(X3.rvs(10000), normed=True) 
plt.show() 

enter image description here

+1

精彩的回答,谢谢! – Gabriel

+0

+1。但值得注意的是,如果函数内部立即使用'get_truncated_normal.rvs()',而不是在外部调用该函数,该函数将变得更快。当然,这只有在你想要随机抽签时才有用 –