5-min reads: why propagating cell state through mini-batches slows convergence
learning a neural network from the inside out
In this 5-min read I’ll try to explain why convergence can be slower for LSTM networks whose memory (cell) state is carried over from one mini-batch to the next. An example of this problem can be found in my project about a microsleep detector.
Convergence of an LSTM with memory passed through mini-batches
The parameters determine the cell state, which is used to make a prediction, whose error is the basis for the parameter update: $$ \theta_b^i = g\left(C_b^i, X_b, Y_b\right) + \theta_{b-1}^i $$
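As an illustration only (the post does not say which optimiser is used), if g were plain stochastic gradient descent with an assumed learning rate η and an assumed loss function ℒ, the update above would become:

$$ g\left(C_b^i, X_b, Y_b\right) = -\eta\, \nabla_{\theta}\, \mathcal{L}\left(C_b^i, X_b, Y_b\right) \quad\Rightarrow\quad \theta_b^i = \theta_{b-1}^i - \eta\, \nabla_{\theta}\, \mathcal{L}\left(C_b^i, X_b, Y_b\right) $$

Under that assumption, η sets how far each update moves the parameters away from the setting that produced the cell state being carried forward.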
The cell state calculated before a parameter update is then used to calculate the cell state of the next batch: $$ C_b = f\left(C_{b-1}^{i-1}, \theta_{b-1}^i, X_b\right) $$ So the cell state passed forward as ‘context’ takes on a different meaning, because the parameters are no longer the ones that created that ‘context’ (cell state). This implies a leak of information across training epochs, i.e. $$ C_b^i \neq C_b^{i-1} $$
$$\begin{array}{l}\theta:\ \text{network parameter setting},\quad C:\ \text{memory cell state} \\ g\left(\cdot\right):\ \text{parameter update rule},\quad f\left(\cdot\right):\ \text{memory calculation},\quad b, i:\ \text{batch and epoch}\end{array}$$
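To make the leak concrete, here is a toy sketch in Python. It is not an LSTM: as an assumption, a single scalar weight `w` stands in for θ, `tanh(w * (c_prev + x))` stands in for the memory calculation f, and one gradient step on a squared error (treating the carried-in state as a constant, as in truncated backpropagation) stands in for g. Memory starts from zero at the beginning of each epoch, following the C_0 definition used in this post. `train_stateful` and the data values are hypothetical.

```python
import math


def train_stateful(xs, ys, w, epochs, lr=0.5):
    """Toy stateful training loop (hypothetical scalar stand-in for an LSTM).

    Returns history[i][b]: the cell state C_b computed in epoch i.
    """
    history = []
    for _ in range(epochs):
        c = 0.0  # memory reset to zero at the start of each epoch (C_0 uses [0])
        states = []
        for x, y in zip(xs, ys):
            c_prev = c  # carried-over state, computed under the *previous* w
            c = math.tanh(w * (c_prev + x))  # toy f(C_{b-1}, theta_{b-1}, X_b)
            # d/dw of (c - y)^2, holding c_prev fixed (truncated backprop)
            grad = 2 * (c - y) * (1 - c * c) * (c_prev + x)
            w = w - lr * grad  # toy theta_b = theta_{b-1} + g(...)
            states.append(c)
        history.append(states)
    return history


xs, ys = [0.5, -0.3, 0.8], [0.2, -0.1, 0.6]
drifting = train_stateful(xs, ys, w=1.0, epochs=2, lr=0.5)
frozen = train_stateful(xs, ys, w=1.0, epochs=2, lr=0.0)
print(drifting[0] != drifting[1])  # True: parameters moved, so C_b^i != C_b^{i-1}
print(frozen[0] == frozen[1])      # True: no updates, so every epoch repeats exactly
```

Setting the learning rate to zero removes the drift entirely in this toy, which is consistent with the claim that the mismatch comes from the parameter updates rather than from the data.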
Without passing memory through mini-batches, each parameter update optimises the network only for the inputs of that one mini-batch. That isn’t a problem: over hundreds of epochs every batch is seen, and the batches stay the same from epoch to epoch. A carried-over memory cell, however, puts a different ‘lens’ on those same inputs each time they are repeated, so in effect every parameter update is based on a different input. There is less consistency for the network to learn from.
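A toy sketch can show the difference in ‘lens’. As an assumption, a scalar weight `w` stands in for θ and `tanh(w * (c_prev + x))` for the memory calculation f; it is not an LSTM, and `starting_states` is a hypothetical helper. It records the memory state each batch starts from: without carried memory every batch starts from zero in every epoch, so the pair (starting state, X_b) repeats exactly; with carried memory the starting state of every batch after the first changes from epoch to epoch.

```python
import math


def starting_states(xs, ys, w, epochs, lr=0.5, stateful=True):
    """Return seen[i][b]: the memory state batch b starts from in epoch i (toy model)."""
    seen = []
    for _ in range(epochs):
        c = 0.0  # memory reset at the start of each epoch
        starts = []
        for x, y in zip(xs, ys):
            c_prev = c if stateful else 0.0  # stateless: every batch starts from zero
            starts.append(c_prev)
            c = math.tanh(w * (c_prev + x))
            grad = 2 * (c - y) * (1 - c * c) * (c_prev + x)  # c_prev held fixed
            w -= lr * grad  # parameters change after every batch in both modes
        seen.append(starts)
    return seen


xs, ys = [0.5, -0.3, 0.8], [0.2, -0.1, 0.6]
stateless = starting_states(xs, ys, w=1.0, epochs=3, stateful=False)
stateful = starting_states(xs, ys, w=1.0, epochs=3, stateful=True)
print(stateless)                   # every epoch: [0.0, 0.0, 0.0]
print(stateful[0] != stateful[1])  # True: same X_b, different starting state
```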
Because only one or two of the transition patterns the network is trying to learn occur per batch, if two adjacent patterns are very different, the leak in the ‘meaning’ of memory could grow by the time the same batch is trained on again. Training hyperparameters therefore need to be chosen carefully to get good convergence.
Cell state is a function of a lagged parameter setting and a lagged cell state, across all batches and epochs:
$$\begin{array}{l} C_b = f\left(C_{b-1}, \theta_{b-1}^i, X_b\right) \\ C_0 = f\left(\left[0\right], \theta_b^{i-1}, X_b\right) \end{array}$$
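Unrolling this recursion shows how many parameter settings a single cell state mixes. The sketch below builds the expansion as a string. It assumes, as one reading of the C_0 line, that θ_b^{i-1} denotes the last parameter setting of the previous epoch (written `theta^{i-1}`) and that the first batch in an epoch is indexed 0 with input X_0; `unroll` is a hypothetical helper.

```python
def unroll(num_batches):
    """Symbolically expand C_0 ... C_{num_batches-1} within epoch i.

    Assumes C_0 starts from zero memory under the previous epoch's final
    parameters (written theta^{i-1}), then applies C_b = f(C_{b-1}, theta_{b-1}^i, X_b).
    """
    exprs = ["f([0], theta^{i-1}, X_0)"]
    for b in range(1, num_batches):
        # wrap the previous cell state in one more application of f
        exprs.append(f"f({exprs[-1]}, theta_{b-1}^i, X_{b})")
    return exprs


for b, expr in enumerate(unroll(3)):
    print(f"C_{b} = {expr}")
```

Under these assumptions C_2 depends on three different parameter settings, and only the latest of them is still in place when C_2 is used.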