자연어처리를 위한 RNN, LSTM, GRU

아래 내용은 CS224N 2019 Lecture 6: Language Models and Recurrent Neural Networks
7: Vanishing Gradients and Fancy RNNs을 참조하였습니다.

RNN(Recurrent Neural Network)은 순서가 있는 데이터를 처리하기 위한 Neural Network입니다.
순서가 있는 데이터는 음성, 언어, 주가 등 발생 순서가 중요한 데이터를 의미합니다.

문장의 경우는 이전에 발생한 단어를 보고 다음 단어를 예측하는 경우를 생각할 수 있습니다.


1. RNN

rnn-01.png

RNN은 위 그림과 같이 이전 시점의 hidden state $h_{t-1}$과 현재 시점의 입력 $x_t$를 이용해 현재 시점의 hidden state $h_t$를 계산합니다. 수식은 다음과 같습니다.

$$ h_{t}=tanh(W_h h_{t-1} + W_x x_t + b) $$

이때 weight $W_h$, $W_x$와 bias $b$는 시간 $t$와 상관없이 모두 동일한 값을 사용합니다.

'I went to home'이라는 4개의 단어로 된 문장에 대한 RNN의 동작은 다음과 같습니다.

rnn-02.gif

위 그림과 같이 'I', 'went', 'to', 'home'4 개의 단어로 구성된 문장을 순서대로 위 수식을 이용해서 hidden state를 계산합니다. 이때 $h_0$는 초깃값으로 모두 0인 값을 사용하거나 필요에 따라서는 초깃값을 별도로 지정할 수 있습니다.

RNN의 특징은 다음과 같습니다.

1.1. 순서대로 처리해야 하기 때문에 느림

RNN은 $h_t$를 계산하기 위해서는 $h_{t-1}$ 값이 있어야 합니다. 즉 순차적으로 계산을 해야 합니다. 순차적으로 계산을 해야 하기 때문에 여러 개의 GPU를 사용해서 병렬처리가 어렵습니다. 처리해야 할 데이터의 길이가 길어질수록 RNN은 속도가 느려지는 경향이 있습니다.

1.2. Vanishing gradient problem

RNN은 거리가 먼 데이터의 경우에 gradient가 감소하는 경향이 있습니다.

rnn-03.png

위 그림과 같이 $h_1$의 gradient는 $J_4(\theta)$보다 $J_2(\theta)$가 더 많은 영향을 주는 경향이 있습니다.

$J_2(\theta)$에 대한 $h_1$의 gradient는 아래 수식과 같습니다.

$$ \begin{equation} \begin{split} {\partial J_2(\theta) \over \partial h_1} &= {\partial h_2 \over \partial h_1} \times {\partial J_4(\theta) \over \partial h_2} \end{split} \end{equation} $$

$J_4(\theta)$에 대한 $h_1$의 gradient는 아래 수식과 같습니다.

$$ \begin{equation} \begin{split} {\partial J_4(\theta) \over \partial h_1} &= {\partial h_2 \over \partial h_1} \times {\partial h_3 \over \partial h_2} \times {\partial h_4 \over \partial h_3} \times {\partial J_4(\theta) \over \partial h_4} \end{split} \end{equation} $$

위 두 수식에서 ${\partial h_3 \over \partial h_2} \times {\partial h_4 \over \partial h_3}$이 작을 경우 $J_4(\theta)$는 $J_2(\theta)$보다 $h_1$의 gradient에 영향을 작게 주게 됩니다. 이런 현상은 거리가 멀어질수록 커지게 됩니다.

rnn-04.png

결과적으로 위 그림과 같이 'The writer of the books' 다음에 올 단어는 주어가 writer이므로 is가 정답이지만 거리가 가까운 books에 영향을 받아서 are를 예측할 가능성이 커집니다.

1.3. Exploding gradient problem

$$ \begin{equation} \begin{split} {\partial J_4(\theta) \over \partial h_1} &= {\partial h_2 \over \partial h_1} \times {\partial h_3 \over \partial h_2} \times {\partial h_4 \over \partial h_3} \times {\partial J_4(\theta) \over \partial h_4} \end{split} \end{equation} $$

위 수식에서 ${\partial h_3 \over \partial h_2} \times {\partial h_4 \over \partial h_3}$이 큰 값을 가질 경우 $h_1$의 gradient는 매우 큰 값을 가지게 됩니다. 이럴 경우 학습이 잘 안되거나 loss가 INF, NaN 등의 값을 가질 수 있습니다.

Exploding gradient problem은 gradient clipping을 통해서 문제를 완화할 수 있습니다.


2. LSTM (Long-Short Term Memory)

LSTM은 RNN의 Vanishing gradient problem을 완화하기 위한 Nural Network입니다.

lstm-01.png

위 그림과 같이 hidden state에 추가로 메모리에 이전 정보를 간직하기 위한 cell state가 추가되었습니다.
이전 정보를 간직하거나 새로운 정보의 반영은 3개의 gate에 의해서 조절이 됩니다.

2.1. Forget gate

이전 step의 cell state 정보를 얼마나 사용할지를 결정하는 gate입니다. 수식은 아래와 같습니다.

$$ f_t = \sigma(W_{hf} h_{t-1} + W_{xf} x_t + b_f) $$

위 수식에서 사용된 $\sigma$는 sigmoid로 결과가 (0, 1) 사이의 값을 가지게 됩니다. 즉 값이 0에 가까울수록 과거 cell state 정보를 조금 사용하고 1에 가까울수록 과거 cell state 정보를 많이 사용합니다.

2.2. Input gate

현재 step의 cell state 정보를 얼마나 사용할지를 결정하는 gate입니다. 수식은 아래와 같습니다.

$$ i_t = \sigma(W_{hi} h_{t-1} + W_{xi} x_t + b_i) $$

위 수식의 결과가 0에 가까울수록 현재 cell state 정보를 조금 사용하고 1에 가까울수록 현재 cell state 정보를 많이 사용합니다.

2.3. Output gate

현재 stpe의 hidden state를 계산할 때 cell state 정보를 얼마나 사용할지를 결정하는 gate입니다. 수식은 아래와 같습니다.

$$ o_t = \sigma(W_{ho} h_{t-1} + W_{xo} x_t + b_o) $$

위 수식의 결과가 0에 가까울수록 cell state 정보를 조금 사용하고 1에 가까울수록 cell state 정보를 많이 사용합니다.

2.4. New cell contents

현재 stpe의 cell state 후보입니다. 수식은 아래와 같습니다.

$$ \tilde{C}_t = tanh(W_{hc} h_{t-1} + W_{xc} x_t + b_c) $$

2.5. Cell state

현재 stpe의 cell state입니다. 이전 step의 cell state와 new cell contents의 합으로 구성됩니다. 수식은 아래와 같습니다.

$$ C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t $$

이전 step의 cell state와 forget gate의 element wise product 한 값과 new cell contents와 input gate의 element wise product 한 값의 합입니다.
forget gate 값이 크고 input gate 값이 작은 경우는 이전 step의 cell state가 많이 사용되어 오랫동안 과거의 정보를 기억할 수 있습니다.

2.6. Hidden state

현재 stpe의 hidden state 값 입니다. 수식은 아래와 같습니다.

$$ h_t = o_t \odot tanh(C_t) $$

Output gate를 이용해 hidden state 값이 조절 됩니다.


3. GRU (Gated Recurrent Unit)

GRU는 LSTM에 비해서 단순한 구조이면서도 긴 데이터를 잘 처리하는 Nural Network입니다.

gru-01.png

위 그림과 같이 cell state를 사용하지 않고 hidden state만을 사용했습니다.
Gate의 숫자도 2개로 줄어들었습니다.

3.1. Reset gate

이전 step의 hidden state 정보를 얼마나 사용할지를 결정하는 gate입니다. 수식은 아래와 같습니다.

$$ r_t = \sigma(W_{hr} h_{t-1} + b_{hr}+ W_{xr} x_t + b_{xr}) $$

위 수식의 결과가 0에 가까울수록 이전 hidden state 정보를 조금 사용하고 1에 가까울수록 이전 hidden state 정보를 많이 사용합니다.

3.2. Update gate

현재 step의 new hidden contents 정보를 얼마나 사용할지를 결정하는 gate입니다. 수식은 아래와 같습니다.

$$ u_t = \sigma(W_{hz} h_{t-1} + b_{hz}+ W_{xz} x_t + b_{xz}) $$

위 수식의 결과가 0에 가까울수록 현재 new hidden contents 정보를 조금 사용하고 1에 가까울수록 현재 new hidden contents 정보를 많이 사용합니다.

3.3. New hidden contents

현재 stpe의 hidden state 후보입니다. 수식은 아래와 같습니다.

$$ \tilde{h}_t = tanh(r \odot (W_{hg}h_{t-1} + b_{hg}) + W_{xg} x_t + b_{xg}) $$

Reset gate를 이용해 이전 hidden state의 사용량을 조절합니다.

3.4. Hidden state

현재 stpe의 hidden state 값 입니다. 수식은 아래와 같습니다.

$$ h_t = u_t \odot h_{t - 1} + (1 - u_t) \odot \tilde{h}_t $$

Update gate를 이용해 이전 step의 hidden state와 현재 new hidden contents 사용량이 조절됩니다.
Update gate 값이 1에 가까울수록 이전 step의 hidden state가 많이 사용되어 오랫동안 과거의 정보를 기억할 수 있습니다.