2016-11-09 42 views
3

我必须找到最佳的解决方案> 10^7方程系统有5个方程,每个变量有2个变量(5次测量,找到2个参数,误差最小的系列)。 下面的代码(通常用来做曲线拟合)做什么,我想:有效地解决大量的线性最小二乘法系统

#Create_example_Data 
n = 100 
T_Arm = np.arange(10*n).reshape(-1, 5, 2) 
Erg = np.arange(5*n).reshape(-1, 5) 
m = np.zeros(n) 
c = np.zeros(n) 
#Run 
for counter in xrange(n): 
    m[counter], c[counter] = np.linalg.lstsq(T_Arm[counter, :, :], 
               Erg[counter, :])[0] 

可惜实在是太慢了。有什么办法可以显着提高代码的速度吗?我试图引导它,但我没有成功。将最后一个解决方案用作初始猜测也可能是一个好主意。使用scipy.optimize.leastsq也没有加速。

+0

什么是'Inputlen'?是'n'吗? – TuanDT

+0

n是等式系统的数目,等于Inputlen,我更正了代码 – Okapi575

+0

,我认为它应该是'xrange(n)'而不是'xrange(len(n))',因为'n'只是一个整数(100在这种情况下) – TuanDT

回答

3

您可以使用稀疏矩阵A存储T_Arm的(5,2)项在其对角线上,并求解AX = b,其中b是由堆积条目Erg组成的向量。然后用scipy.sparse.linalg.lsqr(A,b)解决系统问题。

为了构建A和B我使用n = 3的用于可视化目的:

import numpy as np 
import scipy 
from scipy.sparse import bsr_matrix 
n = 3 
col = np.hstack(5 * [np.arange(10 * n/5).reshape(n, 2)]).flatten() 
array([ 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 2., 3., 2., 
     3., 2., 3., 2., 3., 2., 3., 4., 5., 4., 5., 4., 5., 
     4., 5., 4., 5.]) 

row = np.tile(np.arange(10 * n/2), (2, 1)).T.flatten() 
array([ 0., 0., 1., 1., 2., 2., 3., 3., 4., 4., 5., 
     5., 6., 6., 7., 7., 8., 8., 9., 9., 10., 10., 
     11., 11., 12., 12., 13., 13., 14., 14.]) 

A = bsr_matrix((T_Arm[:n].flatten(), (row, col)), shape=(5 * n, 2 * n)) 
A.toarray() 
array([[ 0, 1, 0, 0, 0, 0], 
     [ 2, 3, 0, 0, 0, 0], 
     [ 4, 5, 0, 0, 0, 0], 
     [ 6, 7, 0, 0, 0, 0], 
     [ 8, 9, 0, 0, 0, 0], 
     [ 0, 0, 10, 11, 0, 0], 
     [ 0, 0, 12, 13, 0, 0], 
     [ 0, 0, 14, 15, 0, 0], 
     [ 0, 0, 16, 17, 0, 0], 
     [ 0, 0, 18, 19, 0, 0], 
     [ 0, 0, 0, 0, 20, 21], 
     [ 0, 0, 0, 0, 22, 23], 
     [ 0, 0, 0, 0, 24, 25], 
     [ 0, 0, 0, 0, 26, 27], 
     [ 0, 0, 0, 0, 28, 29]], dtype=int64) 

b = Erg[:n].flatten() 

然后

scipy.sparse.linalg.lsqr(A, b)[0] 
array([ 5.00000000e-01, -1.39548109e-14, 5.00000000e-01, 
     8.71088538e-16, 5.00000000e-01, 2.35398726e-15]) 

编辑:因为它似乎A没有在存储器中作为巨大:多个上块稀疏矩阵here

+0

太棒了! – piRSquared

+0

不错的想法。但数组A包含(n * 5)*(n * 2)个值,它们都是int64,所以对于n = 1000,需要80 MB,对于n = 10000 8GB等等。所有这些字节需要被读取和写入。也许使用n = 100并且以块处理数据比原始解决方案更快。我会试一试。 – Okapi575

+1

@ Okapi575 A不是一个数组(我只使用A.toarray()来表示它),而是一个稀疏矩阵,所以内存消耗会更低。我试图看看实施是否有效,但我承认我没有计时。 –