
Monday, February 28, 2022

RNNs and LSTMs

Let's imagine we are trying to predict the next word in a sentence, given the words up until this point. A neural network that attempted to predict the next word would need to take into account not only the current word but a variable number of prior inputs. If we instead used only a simple feedforward MLP, the network would essentially have to process the entire sentence up to each word as a single fixed-length vector. This introduces the problem of either having to pad variable-length inputs to a common length without preserving any notion of relevance (that is, which words in the sentence matter more than others in generating the next prediction), or only using the last word at each step as the input, which discards the context of the rest of the sentence and all the information it can provide. This kind of problem inspired the "vanilla" RNN, which incorporates not only the current input but also the prior step's hidden state in computing a neuron's output:
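h(t) = tanh(W·x(t) + U·h(t-1) + b)

where x(t) is the input at step t, h(t-1) is the previous hidden state, and W, U, and b are the input weights, recurrent weights, and bias (standard notation assumed here for illustration).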

One way to visualize this is to imagine each layer feeding recursively into the next timestep in a sequence. In effect, if we "unroll" each part of the sequence, we end up with a very deep neural network, where each layer shares the same weights.
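To make this unrolling concrete, here is a minimal NumPy sketch of a vanilla RNN forward pass (rnn_forward and the weight names W, U, and b are assumptions for this example, not taken from any particular library). Each pass through the loop plays the role of one "layer" of the unrolled network, and every layer reuses the same weights:

```python
import numpy as np

def rnn_forward(xs, W, U, b):
    """Vanilla RNN forward pass: the same W, U, and b are reused at every
    time step, which is the unrolled "deep network with shared weights"
    picture described above."""
    h = np.zeros(U.shape[0])
    hidden_states = []
    for x_t in xs:                        # one unrolled "layer" per time step
        h = np.tanh(W @ x_t + U @ h + b)  # current input plus prior hidden state
        hidden_states.append(h)
    return hidden_states

rng = np.random.default_rng(0)
n_in, n_hid = 4, 3
W = rng.normal(size=(n_hid, n_in))   # input-to-hidden weights
U = rng.normal(size=(n_hid, n_hid))  # hidden-to-hidden (recurrent) weights
b = np.zeros(n_hid)
states = rnn_forward(rng.normal(size=(6, n_in)), W, U, b)  # a 6-step sequence
print(states[-1])
```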

The same difficulties that characterize training deep feedforward networks also apply to RNNs; gradients tend to die out over long distances with traditional activation functions (or explode if their magnitudes grow greater than 1).

However, unlike feedforward networks, RNNs aren't trained with traditional backpropagation, but rather with a variant known as backpropagation through time (BPTT): the network is unrolled, as before, and backpropagation is applied, averaging over the errors at each time point (since an "output," the hidden state, occurs at each step). Also, in the case of RNNs, we run into the problem that the network has a very short memory; it only incorporates information from the most recent units before the current one and has trouble maintaining long-range context. For applications such as translation, this is clearly a problem, as the interpretation of a word at the end of a sentence may depend on terms near the beginning, not just those directly preceding it.
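Concretely, if L(t) is the error at step t and W is the shared recurrent weight matrix (symbols chosen here for illustration), the gradient of each step's error chains back through every earlier step of the unrolled network:

∂L(t)/∂W = Σ_{k=1..t} ∂L(t)/∂h(t) · (∏_{j=k+1..t} ∂h(j)/∂h(j-1)) · ∂h(k)/∂W

The product of Jacobians ∂h(j)/∂h(j-1) is what shrinks toward zero or blows up as the distance t−k grows, which is one reason the plain RNN struggles to carry information across long spans.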

The LSTM network was developed to allow RNNs to maintain a context or state over long sequences.

In a vanilla RNN, we only maintain a short-term memory h coming from the prior step's hidden unit activations. In addition to this short-term memory, the LSTM architecture introduces an additional layer c, the "long-term" memory, which can persist over many timesteps. The design is in some ways reminiscent of an electrical capacitor: the c layer can store up or hold "charge" and discharge it once it has reached some threshold. To compute these updates, an LSTM unit consists of a number of related neurons, or gates, that act together to transform the input at each time step.

Given an input vector x and the hidden state h from the previous time step t-1, at each time step an LSTM first computes a value from 0 to 1 for each element of c, representing what fraction of the information in each element of the vector is "forgotten":
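f(t) = σ(W_f·x(t) + U_f·h(t-1) + b_f)

where σ is the logistic sigmoid (which maps to the 0–1 range) and W_f, U_f, and b_f are the forget gate's own weights and bias (symbol names assumed here for illustration).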

We make a second, similar calculation to determine how much of the input value to preserve:
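i(t) = σ(W_i·x(t) + U_i·h(t-1) + b_i)

with its own (again, assumed) weights W_i, U_i and bias b_i; this input gate likewise takes values between 0 and 1.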

We now know which elements of c will be updated; we can compute the update itself as follows:
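c(t) = f(t) ∘ c(t-1) + i(t) ∘ tanh(W_c·x(t) + U_c·h(t-1) + b_c)

(with W_c, U_c, and b_c following the same naming convention assumed above)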

where ∘ is the Hadamard product (element-wise multiplication). In essence, this equation tells us how to compute updates using the tanh transform, filter them using the input gate, and combine them with the prior time step's long-term memory using the forget gate to potentially filter out old values.


To compute the output at each time step, we compute one more gate, the output gate:
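o(t) = σ(W_o·x(t) + U_o·h(t-1) + b_o)

which, like the other gates, uses its own (assumed) weights W_o, U_o and bias b_o to decide how much of the long-term memory to expose.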

And to compute the final output at each step (the hidden layer fed as short-term memory to the next step) we have:
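h(t) = o(t) ∘ tanh(c(t))

so the short-term memory passed to the next step is the long-term memory, squashed through tanh and filtered by the output gate.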

Many variants of this basic design have been proposed; for example, the "peephole" LSTM substitutes c(t-1) for h(t-1) (thus each operation gets to "peep" at the long-term memory cell), while the GRU simplifies the overall design by removing the output gate. What these designs all have in common is that they avoid the vanishing (or exploding) gradient difficulties seen during the training of RNNs, since the long-term memory acts as a buffer that maintains the gradient and propagates neuronal activations over many timesteps.
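Putting the gate equations together, here is a minimal NumPy sketch of a single LSTM step plus a toy loop over a short sequence. All of the names, shapes, and random weights below are chosen for illustration under the notation assumed above; they are not taken from any particular library.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, params):
    """One LSTM time step following the gate equations above.

    params holds input weights W_*, recurrent weights U_*, and biases b_*
    for the forget (f), input (i), and output (o) gates and the candidate
    cell update (c); these names are assumptions made for this sketch.
    """
    f = sigmoid(params["W_f"] @ x_t + params["U_f"] @ h_prev + params["b_f"])        # forget gate
    i = sigmoid(params["W_i"] @ x_t + params["U_i"] @ h_prev + params["b_i"])        # input gate
    o = sigmoid(params["W_o"] @ x_t + params["U_o"] @ h_prev + params["b_o"])        # output gate
    c_tilde = np.tanh(params["W_c"] @ x_t + params["U_c"] @ h_prev + params["b_c"])  # candidate update
    c_t = f * c_prev + i * c_tilde   # long-term memory: keep part of the old, add part of the new
    h_t = o * np.tanh(c_t)           # short-term memory passed on to the next step
    return h_t, c_t

# Toy usage: 4-dimensional inputs, 3-dimensional hidden/cell state, 5 time steps.
rng = np.random.default_rng(0)
n_in, n_hid = 4, 3
params = {}
for gate in ("f", "i", "o", "c"):
    params[f"W_{gate}"] = rng.normal(size=(n_hid, n_in))
    params[f"U_{gate}"] = rng.normal(size=(n_hid, n_hid))
    params[f"b_{gate}"] = np.zeros(n_hid)

h, c = np.zeros(n_hid), np.zeros(n_hid)
for x in rng.normal(size=(5, n_in)):
    h, c = lstm_step(x, h, c, params)
print(h, c)
```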

