2017-09-14 136 views
5

我已经使用How to apply piecewise linear fit in Python?这个问题中发现的一些代码来执行具有单个断点的分段线性近似。具有n个断点的分段线性拟合

的代码如下:

from scipy import optimize 
import matplotlib.pyplot as plt 
import numpy as np 
%matplotlib inline 

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15], dtype=float) 
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03]) 

def piecewise_linear(x, x0, y0, k1, k2): 
    return np.piecewise(x, 
         [x < x0], 
         [lambda x:k1*x + y0-k1*x0, lambda x:k2*x + y0-k2*x0]) 

p , e = optimize.curve_fit(piecewise_linear, x, y) 
xd = np.linspace(0, 15, 100) 
plt.plot(x, y, "o") 
plt.plot(xd, piecewise_linear(xd, *p)) 

我试图找出如何我可以扩展处理ñ断点。

我试着用下面的代码来处理2断点的piecewise_linear()方法,但它不以任何方式改变断点的值。

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20], dtype=float) 
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03, 150, 152, 154, 156, 158]) 

def piecewise_linear(x, x0, x1, a1, b1, a2, b2, a3, b3): 
    return np.piecewise(x, 
         [x < x0, np.logical_and(x >= x0, x < x1), x >= x1 ], 
         [lambda x:a1*x + b1, lambda x:a2*x+b2, lambda x: a3*x + b3]) 

p , e = optimize.curve_fit(piecewise_linear, x, y) 
xd = np.linspace(0, 20, 100) 
plt.plot(x, y, "o") 
plt.plot(xd, piecewise_linear(xd, *p)) 

任何投入将不胜感激

+0

'''它不work'''是几乎无用的描述。我也认为你不能通过curve_fit()来实现这一点,当有多个断点时(需要线性约束来处理b0 sascha

+0

我认为,如果我最初在x轴上均匀分布断点,那么找到局部最小值就足以提供一个体面的非最优解。你知道另一个支持线性约束的优化模块吗? –

+0

正如我告诉你的,这不仅仅是这个。忽略平滑性和潜在的非凸性,你可以用scipy的更一般的优化函数,即COBYLA和SQSLP(唯一的两个支持约束)来解决这个问题。我看到的唯一真正的方法是混合整数凸规划,但软件是稀疏的(bonmin和couenne是两个开源解决方案,不适合从python使用; pajarito @ julialang;但是这种方法通常需要一些非 - 简单的公式)。 – sascha

回答

4

NumPy的有polyfit function这使得它很容易通过一组点找到最佳拟合线:

coefs = npoly.polyfit(xi, yi, 1) 

所以,真正唯一的困难正在找到断点。对于给定的一组 断点,通过给定数据找到最合适的线是很简单的。

因此,而不是试图一下子找到断点系数线性部分的 的位置,就足够了断点的参数空间 减少了。

由于断点可以通过它们的整数索引值来指定到x阵列, 参数空间可以被认为是对N尺寸,其中 N是断点的数目的整数网格点。

optimize.curve_fit不是一个很好的选择,因为这个问题的最小值为 ,因为参数空间是整数值。如果您要使用curve_fit, ,算法会调整参数以确定 移动的方向。如果调整小于1个单位,则断点的x值不会变为 ,因此错误不会更改,因此算法不会收到有关正确移动参数方向的信息 。因此,当参数空间基本上是整数值时,curve_fit 往往会失败。

一个更好但不是很快的最小化器将是一个强力网格搜索。如果 断点数很少(参数空间x-值小于 ),这可能就足够了。如果断点数量很大和/或参数空间很大,则可能会设置多级粗/细网格搜索(蛮力)。或者,也许有人会建议比蛮力更聪明的最小化...


import numpy as np 
import numpy.polynomial.polynomial as npoly 
from scipy import optimize 
import matplotlib.pyplot as plt 
np.random.seed(2017) 

def f(breakpoints, x, y, fcache): 
    breakpoints = tuple(map(int, sorted(breakpoints))) 
    if breakpoints not in fcache: 
     total_error = 0 
     for f, xi, yi in find_best_piecewise_polynomial(breakpoints, x, y): 
      total_error += ((f(xi) - yi)**2).sum() 
     fcache[breakpoints] = total_error 
    # print('{} --> {}'.format(breakpoints, fcache[breakpoints])) 
    return fcache[breakpoints] 

def find_best_piecewise_polynomial(breakpoints, x, y): 
    breakpoints = tuple(map(int, sorted(breakpoints))) 
    xs = np.split(x, breakpoints) 
    ys = np.split(y, breakpoints) 
    result = [] 
    for xi, yi in zip(xs, ys): 
     if len(xi) < 2: continue 
     coefs = npoly.polyfit(xi, yi, 1) 
     f = npoly.Polynomial(coefs) 
     result.append([f, xi, yi]) 
    return result 

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 
       18, 19, 20], dtype=float) 
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 
       126.14, 140.03, 150, 152, 154, 156, 158]) 
# Add some noise to make it exciting :) 
y += np.random.random(len(y))*10 

num_breakpoints = 2 
breakpoints = optimize.brute(
    f, [slice(1, len(x), 1)]*num_breakpoints, args=(x, y, {}), finish=None) 

plt.scatter(x, y, c='blue', s=50) 
for f, xi, yi in find_best_piecewise_polynomial(breakpoints, x, y): 
    x_interval = np.array([xi.min(), xi.max()]) 
    print('y = {:35s}, if x in [{}, {}]'.format(str(f), *x_interval)) 
    plt.plot(x_interval, f(x_interval), 'ro-') 


plt.show() 

打印

y = poly([ 4.58801083 2.94476604]) , if x in [1.0, 6.0] 
y = poly([-70.36472935 14.37305793]) , if x in [7.0, 15.0] 
y = poly([ 123.24565235 1.94982153]), if x in [16.0, 20.0] 

和情节

enter image description here

+0

很好的答案......我尽可能用'leastsq'和'minim'来尝试一切,但分段参数'x0'和'x1'只是没有正确优化 –

+0

完美。谢谢! –