我必须使用python numpy库实现随机梯度下降。为了这个目的,我给下面的函数定义:Python的numpy随机梯度下降实现
def compute_stoch_gradient(y, tx, w):
"""Compute a stochastic gradient for batch data."""
def stochastic_gradient_descent(
y, tx, initial_w, batch_size, max_epochs, gamma):
"""Stochastic gradient descent algorithm."""
我也给予以下帮助功能:
def batch_iter(y, tx, batch_size, num_batches=1, shuffle=True):
"""
Generate a minibatch iterator for a dataset.
Takes as input two iterables (here the output desired values 'y' and the input data 'tx')
Outputs an iterator which gives mini-batches of `batch_size` matching elements from `y` and `tx`.
Data can be randomly shuffled to avoid ordering in the original data messing with the randomness of the minibatches.
Example of use :
for minibatch_y, minibatch_tx in batch_iter(y, tx, 32):
<DO-SOMETHING>
"""
data_size = len(y)
if shuffle:
shuffle_indices = np.random.permutation(np.arange(data_size))
shuffled_y = y[shuffle_indices]
shuffled_tx = tx[shuffle_indices]
else:
shuffled_y = y
shuffled_tx = tx
for batch_num in range(num_batches):
start_index = batch_num * batch_size
end_index = min((batch_num + 1) * batch_size, data_size)
if start_index != end_index:
yield shuffled_y[start_index:end_index], shuffled_tx[start_index:end_index]
我实现了以下两个功能:
def compute_stoch_gradient(y, tx, w):
"""Compute a stochastic gradient for batch data."""
e = y - tx.dot(w)
return (-1/y.shape[0])*tx.transpose().dot(e)
def stochastic_gradient_descent(y, tx, initial_w, batch_size, max_epochs, gamma):
"""Stochastic gradient descent algorithm."""
ws = [initial_w]
losses = []
w = initial_w
for n_iter in range(max_epochs):
for minibatch_y,minibatch_x in batch_iter(y,tx,batch_size):
w = ws[n_iter] - gamma * compute_stoch_gradient(minibatch_y,minibatch_x,ws[n_iter])
ws.append(np.copy(w))
loss = y - tx.dot(w)
losses.append(loss)
return losses, ws
我不确定迭代应该在范围内(max_epochs)还是在更大范围内完成。我这样说是因为我读到一个时代是“每次我们贯穿整个数据集”。所以我认为一个时代包含更多的迭代......
对于第二个问题:读了* *批**,**小批**和**时代**关于sgd。 – sascha
您在内部循环中调用'batch_iter',每次调用时都会实例化一个新的生成器对象。相反,你想要在循环外实例化一个单独的生成器,然后迭代它, '对于minibatch_y,minibatch_x在batch_iter(...)'中。 –