2016-12-28 86 views
1

我只是试图绘制两个gaussians并找到交点。我有以下代码。这不是绘制确切的十字路口,我真的不知道为什么。这就像刚刚稍微偏离了一点,但如果我们把减去的gaussians的日志记录下来,并且看起来应该是正确的,那么我通过派生的解决方案工作。谁能帮忙?非常感谢!高斯之间的交点

import numpy as np 
import matplotlib.pyplot as plt 

def plot_normal(x, mean = 0, sigma = 1): 
    return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2)) 

# found online 
def solve_gasussians(m1, s1, m2, s2): 
    a = 1.0/(2.0*s1**2) - 1.0/(2.0*s2**2) 
    b = m2/(s2**2) - m1/(s1**2) 
    c = m1**2 /(2*s1**2) - m2**2/(2.0*s2**2) - np.log(s2/s1) 
    return np.roots([a,b,c]) 

s1 = np.linspace(0, 10,300) 
s2 = np.linspace(0, 14, 300) 

solved_val = solve_gasussians(5.0, 0.5, 7.0, 1.0) 
print solved_val 
solved_val = solved_val[0] 
plt.figure('Baseline Distributions') 
plt.title('Baseline Distributions') 
plt.xlabel('Response Rate') 
plt.ylabel('Probability') 
plt.plot(s1, plot_normal(s1, 5.0, 0.5),'r', label='s1') 
plt.plot(s2, plot_normal(s2, 7.0, 1.0),'b', label='s2') 
plt.plot(solved_val, plot_normal(solved_val, 7.0, 1.0), 'mo') 
plt.legend() 
plt.show() 
+0

你能带我们去你找到了解决办法在网上让大家不必为我们自己去尝试呢? –

+0

我认为他们提到的解决方案可能是[SO问题](http://stackoverflow.com/a/22579904/752843)。所以我们不能完全责怪他们在代码中缺乏评论。 – Richard

+0

@Richard,这也是我的想法。 –

回答

0

你有plot_normal功能的小错误 - 你缺少平方根分母。正确的版本:

def plot_normal(x, mean = 0, sigma = 1): 
    return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2)) 

给出了预期的结果: enter image description here

和两个讲话。

  1. 请记住,您可以有一般方程的2个根(两个交点),这是您提供的参数的情况。
  2. 据我所知np.roots为您提供了近似的结果,但你的猫很容易地得到确切的结果,改写solve_gasussians功能:

    def solve_gasussians(m1, s1, m2, s2): 
        # coefficients of quadratic equation ax^2 + bx + c = 0 
        a = (s1**2.0) - (s2**2.0) 
        b = 2 * (m1 * s2**2.0 - m2 * s1**2.0) 
        c = m2**2.0 * s1**2.0 - m1**2.0 * s2**2.0 - 2 * s1**2.0 * s2**2.0 * np.log(s1/s2) 
        x1 = (-b + np.sqrt(b**2.0 - 4.0 * a * c))/(2.0 * a) 
        x2 = (-b - np.sqrt(b**2.0 - 4.0 * a * c))/(2.0 * a) 
        return x1, x2 
    
0

我不知道错误在哪里在你的代码中。但我想我找到了你借来的代码,并将其作为你需要的调整的一部分。

import numpy as np 
import matplotlib.pyplot as plt 
from scipy.stats import norm 

def solve(m1,m2,std1,std2): 
    a = 1/(2*std1**2) - 1/(2*std2**2) 
    b = m2/(std2**2) - m1/(std1**2) 
    c = m1**2 /(2*std1**2) - m2**2/(2*std2**2) - np.log(std2/std1) 
    return np.roots([a,b,c]) 

m1 = 5 
std1 = 0.5 
m2 = 7 
std2 = 1 

result = solve(m1,m2,std1,std2) 

x = np.linspace(-5,9,10000) 
plot1=plt.plot(x,[norm.pdf(_,m1,std1) for _ in x]) 
plot2=plt.plot(x,[norm.pdf(_,m2,std2) for _ in x]) 
plot3=plt.plot(result[0],norm.pdf(result[0],m1,std1) ,'o') 

plt.show() 

我会提供不请自来的两点建议,可能使生活更容易为你(在他们为我做的方式):

  • 当你适应代码尽量做小,增量变化和检查代码在每一步仍然有效。
  • 寻找现有的免费库。在这种情况下,来自scipy的规范是原始代码中所用内容的很好替代品。
0

错误在这里。这条线:

def plot_normal(x, mean = 0, sigma = 1): 
    return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2)) 

应该是这样的:

def plot_normal(x, mean = 0, sigma = 1): 
    return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2)) 

你忘了sqrt

这将是明智的使用预先存在的正常的PDF如果这是可用的,如:

import scipy.stats 
def plot_normal(x, mean = 0, sigma = 1): 
    return scipy.stats.norm.pdf(x,loc=mean,scale=sigma) 

它也可以解决了整整交点。 This answer为高斯交点的根提供了一个二次方程。使用极大值解决x给出以下表达式。其中虽然复杂,但不依赖于迭代方法,并且可以从简单的表达式自动生成。

def solve_gaussians(m1,s1,m2,s2): 
    x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2) 
    x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2) 
    return x1,x2 

把它完全给人:

import numpy as np 
import matplotlib.pyplot as plt 
import scipy.stats 

def plot_normal(x, mean = 0, sigma = 1): 
    return scipy.stats.norm.pdf(x,loc=mean,scale=sigma) 

#Use the equation from [this answer](https://stats.stackexchange.com/a/12213/12116) solved for x 
def solve_gaussians(m1,s1,m2,s2): 
    x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2) 
    x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2) 
    return x1,x2 

s = np.linspace(0, 14,300) 
x = solve_gaussians(5.0,0.5,7.0,1.0) 

plt.figure('Baseline Distributions') 
plt.title('Baseline Distributions') 
plt.xlabel('Response Rate') 
plt.ylabel('Probability') 
plt.plot(s, plot_normal(s, 5.0, 0.5),'r', label='s1') 
plt.plot(s, plot_normal(s, 7.0, 1.0),'b', label='s2') 
plt.plot(x[0],plot_normal(x[0],5.,0.5),'mo') 
plt.plot(x[1],plot_normal(x[1],5.,0.5),'mo') 
plt.legend() 
plt.show() 

,并提供:

Intersection of Gaussians