#!/usr/bin/env python
# coding: utf-8

# ## Recurrent Networks
#
# In this notebook we'll see how to use recurrent networks to create a character-level language model for text generation. We'll start with a simple fully-connected network and show how it can be used as an "unrolled" recurrent layer, then gradually build up from there until we have a model capable of generating semi-reasonable sounding text. Much of this content is based on Jeremy Howard's [fast.ai lessons](http://course.fast.ai/), specifically lesson 6 from the first course. However, we'll use Keras instead of PyTorch and build out all of the code from scratch rather than relying on the fast.ai library.
#
# The text corpus we're using for this task consists of the works of the philosopher Nietzsche. The whole corpus can be found [here](https://s3.amazonaws.com/text-datasets/nietzsche.txt). Let's start by loading the data into memory and taking a peek at the beginning of the text.

# In[2]:


get_ipython().run_line_magic('matplotlib', 'inline')
import io
import numpy as np
import keras
from keras.utils.data_utils import get_file

path = get_file('nietzsche.txt', origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
with io.open(path, encoding='utf-8') as f:
    text = f.read().lower()
len(text)


# In[3]:


text[:400]


# Get the unique set of characters that appear in the text. This is our vocabulary.

# In[4]:


chars = sorted(list(set(text)))
vocab_size = len(chars)
vocab_size


# In[5]:


''.join(chars)


# Create a dictionary that maps each unique character to an integer, which is what we'll feed into the model. The actual integer used isn't important; it just has to be unique (here we just take the index from the "chars" list above). It's also useful to have a reverse mapping to get back to characters in order to do something with the model output. Finally, create a "mapped" corpus where each character in the data has been replaced with its corresponding integer.
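# To make the mapping concrete before running it on the whole corpus, here is the same idea on a
# short made-up string. The string and the toy_* names are illustrative only. Sorting the set of
# characters gives a deterministic vocabulary. Because the two dictionaries are exact inverses,
# encoding and then decoding returns the original text.

# In[ ]:


toy_text = 'the cat sat'                                      # illustrative input, not from the corpus
toy_chars = sorted(set(toy_text))                             # deterministic toy vocabulary
toy_char_indices = {c: i for i, c in enumerate(toy_chars)}    # character -> integer
toy_indices_char = {i: c for i, c in enumerate(toy_chars)}    # integer -> character

toy_idx = [toy_char_indices[c] for c in toy_text]             # the "mapped" toy corpus
toy_decoded = ''.join(toy_indices_char[i] for i in toy_idx)   # back to text
toy_idx, toy_decoded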
# In[6]:


char_indices = {c: i for i, c in enumerate(chars)}
indices_char = {i: c for i, c in enumerate(chars)}
idx = [char_indices[c] for c in text]
idx[:20]


# Example of how to convert from integers back to characters.

# In[7]:


''.join(indices_char[i] for i in idx[:100])


# For our first attempt, we'll build a model that accepts a 3-character sequence as input and tries to predict the following character in the text. For simplicity, we can just manually create each character sequence. Start by creating lists that take every 3rd character. Each list is offset by 0, 1, 2 or 3 positions, and the list offset by 3 holds the target character.

# In[8]:


cs = 3
c1 = [idx[i] for i in range(0, len(idx) - cs, cs)]
c2 = [idx[i + 1] for i in range(0, len(idx) - cs, cs)]
c3 = [idx[i + 2] for i in range(0, len(idx) - cs, cs)]
c4 = [idx[i + 3] for i in range(0, len(idx) - cs, cs)]


# This just converts the lists to numpy arrays. Notice that this approach results in non-overlapping sequences, i.e. we use characters 0-2 to predict character 3, then characters 3-5 to predict character 6, etc. That's why the arrays are about 1/3 the length of the original text. We'll see how to improve on this later.

# In[9]:


x1 = np.stack(c1)
x2 = np.stack(c2)
x3 = np.stack(c3)
y = np.stack(c4)
x1.shape, y.shape


# Our model will use embeddings to represent each character. This is why we converted them to integers earlier: each integer gets turned into a vector in the embedding layer. Set some variables for the embedding vector size and the number of hidden units to use in the model. Finally, we need to convert the target variable to a one-hot character encoding. This is because the model outputs a probability for each character, and in order to score this properly it needs to compare that output with an array that's structured the same way.

# In[10]:


n_factors = 42
n_hidden = 256
y_cat = keras.utils.to_categorical(y)
y_cat.shape


# Now we get to the first iteration of our model.
# I've structured this by defining the layers of the model so that they can be re-used across multiple inputs. For example, rather than create an embedding layer for each of the three character inputs, we're instead creating one embedding layer and sharing it. This is a reasonable approach to handling sequences since each input comes from the same distribution (every position draws from the same character vocabulary).
#
# The next thing to observe is the part where h is defined. The first character is fed through the hidden layer like normal, but the other characters in the sequence are doing something different. We're re-using the same layer, but instead of taking just the character as input, we're using the sum of the character's activations and the previous output h. This is the "hidden" state of the model. I think about it in the following way: "give me the output of this layer for character c, conditioned on the fact that these other characters (represented by h) came before it".
#
# You'll notice that there's no use of an RNN class at all. Basically, what's going on here is that we're implementing an "unrolled" RNN from scratch on our own.
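# As a framework-free companion to the Keras model below, here is a minimal NumPy sketch of the
# same unrolled forward pass. It uses random, untrained weights, and the weight names (E, W_in,
# W_h, W_out) and toy sizes are illustrative assumptions, not values from the trained model. The
# point it shows is that one set of weights is reused at every step while h carries the running
# state. It assumes NumPy 1.17+ for default_rng.

# In[ ]:


import numpy as np

rng = np.random.default_rng(0)
V, F, H = 10, 4, 8                                      # toy vocab size, embedding size, hidden units (assumed)

E = rng.normal(size=(V, F))                             # shared embedding table
W_in, b_in = rng.normal(size=(F, H)), np.zeros(H)       # shared "input_layer" (relu)
W_h, b_h = rng.normal(size=(H, H)), np.zeros(H)         # shared "hidden_layer" (tanh)
W_out, b_out = rng.normal(size=(H, V)), np.zeros(V)     # output layer (softmax)

def relu(z):
    return np.maximum(z, 0)

def softmax(z):
    z = z - z.max()                                     # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def char3_forward(i1, i2, i3):
    # Embed each character index and pass it through the same input layer.
    c1, c2, c3 = (relu(E[i] @ W_in + b_in) for i in (i1, i2, i3))
    h = np.tanh(c1 @ W_h + b_h)                         # first character: no previous state
    h = np.tanh((h + c2) @ W_h + b_h)                   # add the state, reuse the same layer
    h = np.tanh((h + c3) @ W_h + b_h)
    return softmax(h @ W_out + b_out)                   # one probability per vocabulary entry

probs = char3_forward(1, 2, 3)
probs.shape, probs.argmax()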
# In[11]:


from keras import backend as K
from keras.models import Model
from keras.layers import add
from keras.layers import Input, Reshape, Dense, Add
from keras.layers.embeddings import Embedding
from keras.optimizers import Adam

def Char3Model(vocab_size, n_factors, n_hidden):
    # Each layer is created once and shared across all three character positions.
    embed_layer = Embedding(vocab_size, n_factors)
    reshape_layer = Reshape((n_factors,))
    input_layer = Dense(n_hidden, activation='relu')
    hidden_layer = Dense(n_hidden, activation='tanh')
    # One output unit per character in the vocabulary.
    output_layer = Dense(vocab_size, activation='softmax')

    in1 = Input(shape=(1,))
    in2 = Input(shape=(1,))
    in3 = Input(shape=(1,))

    c1 = input_layer(reshape_layer(embed_layer(in1)))
    c2 = input_layer(reshape_layer(embed_layer(in2)))
    c3 = input_layer(reshape_layer(embed_layer(in3)))

    # Unrolled recurrence: the same hidden layer sees the running state plus each new character.
    h = hidden_layer(c1)
    h = hidden_layer(add([h, c2]))
    h = hidden_layer(add([h, c3]))

    out = output_layer(h)

    model = Model(inputs=[in1, in2, in3], outputs=out)
    opt = Adam(lr=0.01)
    model.compile(loss='categorical_crossentropy', optimizer=opt)

    return model


# Train the model for a few iterations.

# In[12]:


# Encode the targets against the full vocabulary so their width matches the vocab_size-way output layer.
y_cat = keras.utils.to_categorical(y, num_classes=vocab_size)
model = Char3Model(vocab_size, n_factors, n_hidden)
history = model.fit(x=[x1, x2, x3], y=y_cat, batch_size=512, epochs=3, verbose=1)


# To make sense of the model's output, we need a helper function that converts the character probability array it returns into an actual character. This is where the reverse lookup from index to character comes in handy!

# In[13]:


def get_next_char(model, s):
    # One single-element integer array per input character, in order.
    idxs = [np.array([char_indices[c]]) for c in s]
    pred = model.predict(idxs)
    char_idx = np.argmax(pred)
    return chars[char_idx]

get_next_char(model, ' th')


# In[14]:


get_next_char(model, 'and')


# It appears to be spitting out sensible results. The 3-character approach is very limiting though. That's not enough context for even a full word most of the time. For our next step, let's expand the input window to 8 characters.
# We can create an input array by using some list comprehension magic to output a list of lists and then stacking them together into an array. Try experimenting with the logic below yourself to get a better sense of what it's doing. The target array is created in a similar manner as before.

# In[15]:


cs = 8
c_in = [[idx[i + j] for i in range(cs)] for j in range(len(idx) - cs)]
c_out = [idx[j + cs] for j in range(len(idx) - cs)]

X = np.stack(c_in, axis=0)
y = np.stack(c_out)


# Notice that this time we're making better use of our data by making the sequences overlap. For example, the first "row" in the data uses characters 0-7 to predict character 8. The next "row" uses characters 1-8 to predict character 9, and so on. We just increment by one each time. It does create a lot of duplicate data, but that's not a huge issue with a corpus of this size.

# In[16]:


X.shape, y.shape


# It helps to look at an example to see how the data is formatted. Each row is a sequence of 8 characters from the text. As you go down the rows, it's apparent that they're offset by one character.

# In[17]:


X[:cs, :cs]


# In[18]:


y[:cs]


# Since we have separate inputs for each character, Keras expects separate arrays rather than one big array. We also need to one-hot encode the target again.

# In[19]:


X_array = [X[:, i] for i in range(X.shape[1])]
y_cat = keras.utils.to_categorical(y, num_classes=vocab_size)


# The 8-character model works exactly the same way as the 3-character model; there are just more of the same steps. Rather than write them all out in code, I converted it to a loop. Again, this is almost exactly the way an RNN works under the hood.
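# Going back to the data preparation for a moment: NumPy's sliding_window_view can also produce the
# overlapping windows built above with nested list comprehensions. The sketch assumes NumPy 1.20 or
# later, where that function was added. It uses a toy integer sequence, and the names are
# illustrative. It checks that both approaches agree: row j holds positions j..j+7 and its target
# is position j+8.

# In[ ]:


import numpy as np
from numpy.lib.stride_tricks import sliding_window_view   # NumPy >= 1.20 (assumed available)

toy_seq = np.arange(20)   # stand-in for the mapped corpus `idx`
win = 8                   # window length, same role as cs

# List-comprehension version, mirroring the notebook's c_in / c_out construction.
lc_in = [[toy_seq[i + j] for i in range(win)] for j in range(len(toy_seq) - win)]
lc_out = [toy_seq[j + win] for j in range(len(toy_seq) - win)]
X_lc, y_lc = np.stack(lc_in), np.stack(lc_out)

# Vectorised version: every length-`win` window, minus the last one (it has no target).
X_sw = sliding_window_view(toy_seq, win)[:-1]
y_sw = toy_seq[win:]

X_sw.shape, np.array_equal(X_sw, X_lc), np.array_equal(y_sw, y_lc)

# sliding_window_view returns a read-only view of the original array instead of a copy, so the
# duplicated characters in overlapping rows are not stored more than once.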
# In[20]:


def CharLoopModel(vocab_size, n_chars, n_factors, n_hidden):
    embed_layer = Embedding(vocab_size, n_factors)
    reshape_layer = Reshape((n_factors,))
    input_layer = Dense(n_hidden, activation='relu')
    hidden_layer = Dense(n_hidden, activation='tanh')
    output_layer = Dense(vocab_size, activation='softmax')

    inputs = []
    for i in range(n_chars):
        inp = Input(shape=(1,))
        inputs.append(inp)
        c = input_layer(reshape_layer(embed_layer(inp)))
        # The first step has no previous state; later steps add it before the shared hidden layer.
        if i == 0:
            h = hidden_layer(c)
        else:
            h = hidden_layer(add([h, c]))

    out = output_layer(h)

    model = Model(inputs=inputs, outputs=out)
    opt = Adam(lr=0.001)
    model.compile(loss='categorical_crossentropy', optimizer=opt)

    return model


# Train the model a bit and generate some predictions.

# In[21]:


model = CharLoopModel(vocab_size, cs, n_factors, n_hidden)
history = model.fit(x=X_array, y=y_cat, batch_size=512, epochs=5, verbose=1)


# In[22]:


get_next_char(model, 'for thos')


# In[23]:


get_next_char(model, 'queens a')


# Now we're ready to replace the loop with a real recurrent layer. The first thing to notice is that we no longer need to create separate inputs for each step in the sequence: recurrent layers in Keras are designed to accept 3-dimensional arrays where the 2nd dimension is the number of timesteps. We just need to add an extra dimension to the input shape with the number of characters.
#
# The second wrinkle is the use of the "TimeDistributed" class on the embedding layer. Just as with the input, this is a more convenient way of doing what we were already doing by defining and re-using layers. Wrapping a layer with "TimeDistributed" basically says "apply this to every timestep in the array". Like the RNN, it expects an input with a timestep dimension. Here it turns the (batch, 8, 1) input into a (batch, 8, 1, n_factors) array of embeddings. The reshape then drops the leftover singleton dimension to give (batch, 8, n_factors), which is the 3-dimensional layout the RNN expects. The RNN layer itself is very straightforward.
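# To see what the TimeDistributed embedding, the Reshape, and the SimpleRNN do to array shapes,
# here is a NumPy sketch with random weights. All names and toy sizes are illustrative assumptions.
# The loop applies the update that Keras's SimpleRNN cell computes (ignoring dropout):
# h_t = tanh(x_t W + h_{t-1} U + b). By default only the last h is returned. The hand-built loop
# model above instead passes (h + c) through one shared Dense layer, so the two are close
# relatives rather than identical parameterisations.

# In[ ]:


import numpy as np

rng = np.random.default_rng(0)
batch, n_steps, V, F, H = 2, 8, 10, 4, 6                     # toy sizes (assumed)

X_toy = rng.integers(0, V, size=(batch, n_steps, 1))        # (batch, timesteps, 1) character indices
E = rng.normal(size=(V, F))                                 # embedding table

emb = E[X_toy]                                   # per-timestep lookup: (batch, n_steps, 1, F)
x_seq = emb.reshape((batch, n_steps, F))         # what Reshape((n_chars, n_factors)) does

W_x = rng.normal(size=(F, H))                    # input kernel
W_rec = rng.normal(size=(H, H))                  # recurrent kernel
b_rec = np.zeros(H)

h_t = np.zeros((batch, H))                       # initial state
for t in range(n_steps):                         # the loop the RNN layer runs internally
    h_t = np.tanh(x_seq[:, t, :] @ W_x + h_t @ W_rec + b_rec)

emb.shape, x_seq.shape, h_t.shape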
# In[24]:


from keras.layers import TimeDistributed, SimpleRNN

def CharRnn(vocab_size, n_chars, n_factors, n_hidden):
    i = Input(shape=(n_chars, 1))
    # (batch, n_chars, 1) -> (batch, n_chars, 1, n_factors)
    x = TimeDistributed(Embedding(vocab_size, n_factors))(i)
    # -> (batch, n_chars, n_factors)
    x = Reshape((n_chars, n_factors))(x)
    # Returns only the final hidden state: (batch, n_hidden)
    x = SimpleRNN(n_hidden, activation='tanh')(x)
    x = Dense(vocab_size, activation='softmax')(x)

    model = Model(inputs=i, outputs=x)
    opt = Adam(lr=0.001)
    model.compile(loss='categorical_crossentropy', optimizer=opt)

    return model


# Let's look at a summary of the model. Notice that the array shapes carry an extra timestep dimension until we get to the other side of the RNN.

# In[25]:


model = CharRnn(vocab_size, cs, n_factors, n_hidden)
model.summary()


# Reshape the input to match the 3-dimensional input format of (rows, timesteps, features). Since we only have one feature, the last dimension is trivially set to one.

# In[26]:


X = X.reshape((X.shape[0], cs, 1))
X.shape


# Train the model for a bit. Notice that the loss looks almost identical to the last model! All we really did is shuffle things around to take advantage of some built-in classes that Keras provides. The model does essentially the same job, so performance should look much the same as before.

# In[27]:


history = model.fit(x=X, y=y_cat, batch_size=512, epochs=5, verbose=1)


# We can train it a bit longer at a lower learning rate to reduce the loss further.

# In[28]:


K.set_value(model.optimizer.lr, 0.0001)
history = model.fit(x=X, y=y_cat, batch_size=512, epochs=3, verbose=1)


# In[29]:


def get_next_char(model, s):
    # Shape the seed as a single (1, timesteps, 1) sample to match the RNN input.
    idxs = np.array([char_indices[c] for c in s])
    idxs = idxs.reshape((1, idxs.shape[0], 1))
    pred = model.predict(idxs)
    char_idx = np.argmax(pred)
    return chars[char_idx]

get_next_char(model, 'for thos')


# Since the model is getting better, we can now try to generate more than one character of text. All we need is an initial seed of 8 characters, and it can go on as long as we like.
# To do this, we'll create a simple helper function that repeatedly predicts the next character using the last 8 characters it has produced so far (starting with the seed value).

# In[30]:


def get_next_n_chars(model, s, n):
    r = s
    for i in range(n):
        c = get_next_char(model, s)
        r += c
        # Slide the 8-character window forward by one.
        s = s[1:] + c
    return r

get_next_n_chars(model, 'for thos', 40)


# It's definitely getting better. There are more improvements we can make though! In the current model, each instance of the data is completely independent. When a new sequence comes in, the model has no idea what came before that sequence. That "hidden state" mentioned earlier (which is now part of the RNN layer) gets thrown away. However, there's a way to set this up so that the hidden state persists through to the next part of the sequence. In other words, it conditions the output not only on the current 8 characters but on all the characters that came before them as well.
#
# The good news is that this capability is built into Keras's recurrent layers; we just need to set a flag to true! The bad news is that we need to re-think how the data is structured. Stateful models require 1) a fixed batch size, which is specified in the model input, and 2) that each batch be a "slice" of sequences such that row i of the next batch continues row i of the current batch. In other words, we need to split up our data (which is one long continuous stream of text) into n equal-length chunks, where n is the batch size. Then we need to carve these n chunks into sequences of length 8 (the sequence length the model looks at), with the character following each position being its target (the thing we're predicting).
#
# If that sounds confusing and complicated, that's because it is. It took me a while to make sense of it (and figure out how to express it in code), but hopefully you can follow along. Below is the first step, which splits the data up into chunks and stacks them vertically into an array.
# The result is 64 equal-length continuous sequences of text.

# In[31]:


bs = 64
seg_len = len(text) // bs
# Row k holds the k-th contiguous block of the corpus (any remainder at the end is dropped).
segments = [idx[i*seg_len:(i+1)*seg_len] for i in range(bs)]
segments = np.stack(segments)
segments.shape


# One other change happening at the same time is that we're no longer staggering the input by one character (which duplicates a lot of text because most of it is repeated in each row). Instead, we're now carving the data into chunks of non-overlapping characters like we did originally. However, we're going to make better use of it this time. Instead of just predicting character 8 based on characters 0-7, we're going to predict characters 1-8, each conditioned on the characters in the sequence that came before it. Each pass will actually make 8 character predictions, and the loss function will be calculated across all of those outputs (we'll see how to do this in a minute).
#
# Below we're creating two lists of arrays, where each array holds one 8-character sequence per segment (64 x 8). The second list is offset by one character (this is our target).

# In[32]:


n_seq = (seg_len - 1) // cs  # leave one spare character per segment for the shifted target
c_in = [segments[:, i*cs:(i+1)*cs] for i in range(n_seq)]
c_out = [segments[:, (i*cs)+1:((i+1)*cs)+1] for i in range(n_seq)]


# Now we just need to concatenate and reshape these into arrays that we can use with the model. We end up with ~75,000 non-overlapping 8-character sequences.

# In[33]:


X = np.concatenate(c_in)
X = X.reshape((X.shape[0], X.shape[1], 1))
y = np.concatenate(c_out)
y_cat = keras.utils.to_categorical(y, num_classes=vocab_size)


# In[34]:


X.shape, y_cat.shape


# Crucially, they are ordered such that the 65th row is a continuation of the 1st row, the 66th row is a continuation of the 2nd row, and so on all the way down.

# In[35]:


''.join(indices_char[i] for i in np.concatenate((X[0, :, 0], X[64, :, 0], X[128, :, 0])))


# Next we can create the stateful RNN model. It's similar to the last one, but there are a few wrinkles.
# The input specifies "batch_shape", which fixes the batch size along with the (timesteps, features) dimensions. A fixed batch size is a hard requirement for stateful RNNs in Keras, and it gets quite annoying at inference time. We've set "return_sequences" to true, which changes the shape that the RNN returns and gives us an output for each step in the sequence. We've set "stateful" to true, for the reasons already discussed. Finally, we've wrapped the last dense layer with "TimeDistributed". This is because the RNN now returns a higher-dimensional array to account for the output at each timestep. Everything else works basically the same way.

# In[36]:


def CharStatefulRnn(vocab_size, n_chars, n_factors, n_hidden, bs):
    # Stateful layers need the full batch shape, including a fixed batch size.
    i = Input(batch_shape=(bs, n_chars, 1))
    x = TimeDistributed(Embedding(vocab_size, n_factors))(i)
    x = Reshape((n_chars, n_factors))(x)
    # return_sequences=True -> (bs, n_chars, n_hidden); stateful=True -> each row's final state
    # becomes the initial state for the same row in the next batch.
    x = SimpleRNN(n_hidden, activation='tanh', return_sequences=True, stateful=True)(x)
    x = TimeDistributed(Dense(vocab_size, activation='softmax'))(x)

    model = Model(inputs=i, outputs=x)
    opt = Adam(lr=0.001)
    model.compile(loss='categorical_crossentropy', optimizer=opt)

    return model


# Looking at the output shapes, we can see the effect of turning on "return_sequences". Note that the number of model parameters has not changed. The complexity is identical; we've just changed the task and the information available to solve it.

# In[37]:


model = CharStatefulRnn(vocab_size, cs, n_factors, n_hidden, bs)
model.summary()


# One quirk of using stateful RNNs is that we now have to reset the model state manually; it never goes away until we tell it to. I just created a simple callback that resets the state at the end of every epoch.

# In[38]:


from keras.callbacks import Callback

class ResetModelState(Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Clear the carried-over hidden state before the next pass through the data.
        self.model.reset_states()

reset_state = ResetModelState()


# Train the model for a while as before, with the addition of the callback to reset state between epochs. Shuffling is turned off so that batches arrive in order and each row's state continues the right stretch of text.
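# The stateful setup depends on the row ordering produced by the segment-and-concatenate steps.
# With batch size bs, row k + bs must pick up exactly where row k left off, and the data must be
# fed in order (hence shuffle=False). This sketch rebuilds that layout on a toy integer sequence.
# The toy sizes and names are assumptions. Each value equals its position in the "corpus", which
# makes continuity easy to check.

# In[ ]:


import numpy as np

toy_corpus = np.arange(100)     # stand-in for idx; value == position in the text
toy_bs, toy_cs = 4, 3           # toy batch size and sequence length (assumed)

toy_seg_len = len(toy_corpus) // toy_bs                    # 25 characters per segment
toy_segments = np.stack([toy_corpus[k*toy_seg_len:(k+1)*toy_seg_len] for k in range(toy_bs)])

toy_n_seq = (toy_seg_len - 1) // toy_cs                    # leave room for the shifted target
toy_X = np.concatenate([toy_segments[:, i*toy_cs:(i+1)*toy_cs] for i in range(toy_n_seq)])
toy_y = np.concatenate([toy_segments[:, i*toy_cs+1:(i+1)*toy_cs+1] for i in range(toy_n_seq)])

# Batch b is rows b*toy_bs .. (b+1)*toy_bs - 1; row k of one batch continues row k of the previous one.
toy_X[0], toy_X[toy_bs], toy_X[2 * toy_bs], toy_y[0]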
# In[39]:


model.fit(x=X, y=y_cat, batch_size=bs, epochs=8, verbose=1, callbacks=[reset_state], shuffle=False)


# In[40]:


K.set_value(model.optimizer.lr, 0.0001)
model.fit(x=X, y=y_cat, batch_size=bs, epochs=3, verbose=1, callbacks=[reset_state], shuffle=False)


# The "get next" functions need to be updated since our approach has changed. One of the annoying things about stateful models is that the batch size is fixed, so even when making a prediction the model needs an array of that size, even if we only want to predict one sequence. I got around this with some numpy hackery: repeat the seed bs times and read only the first row of the output.

# In[41]:


def get_next_char(model, bs, s):
    idxs = np.array([char_indices[c] for c in s])
    idxs = idxs.reshape((1, idxs.shape[0], 1))
    # The stateful model only accepts full batches, so tile the seed to bs identical rows.
    idxs = np.repeat(idxs, bs, axis=0)
    pred = model.predict(idxs, batch_size=bs)
    # Row 0, last timestep: the prediction for the character that follows the seed.
    char_idx = np.argmax(pred[0, -1])
    return chars[char_idx]

def get_next_n_chars(model, bs, s, n):
    r = s
    for i in range(n):
        c = get_next_char(model, bs, s)
        r += c
        s = s[1:] + c
    return r


# In[42]:


get_next_n_chars(model, bs, 'for thos', 40)


# The output is actually a bit worse than before, but we're still using simple RNNs, which aren't that great to begin with. The real fun comes when we make the jump to a more complex unit like the LSTM. The details of LSTMs are beyond my scope here, but there's a great blog post that everyone links to as the canonical explainer for LSTMs, which you can find [here](http://colah.github.io/posts/2015-08-Understanding-LSTMs/). This is the easiest step yet, as the only thing we need to do is replace the class name. The only other change I made is increasing the number of hidden units. Everything else stays exactly the same.
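# To get a feel for what the LSTM adds over SimpleRNN, here is one step of the standard LSTM cell
# (the formulation in the colah post linked above) written in NumPy. Weight names, gate order, and
# toy sizes are illustrative assumptions. Keras's own LSTM may differ in details such as the gate
# activation and how the four gates' weights are packed. The helper at the end counts parameters
# for this formulation: four gates, each with an input kernel, a recurrent kernel, and a bias.

# In[ ]:


import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    # W: (n_in, 4*n_hidden), U: (n_hidden, 4*n_hidden), b: (4*n_hidden,)
    z = x_t @ W + h_prev @ U + b
    i, f, g, o = np.split(z, 4, axis=-1)       # input, forget, candidate, output (assumed order)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    g = np.tanh(g)
    c = f * c_prev + i * g                     # cell state: keep part of the old, add part of the new
    h = o * np.tanh(c)                         # hidden state / output at this step
    return h, c

def lstm_param_count(n_in, n_hidden):
    return 4 * (n_in * n_hidden + n_hidden * n_hidden + n_hidden)

rng = np.random.default_rng(0)
n_in_toy, n_hidden_toy = 4, 6                  # toy sizes (assumed)
W_toy = rng.normal(size=(n_in_toy, 4 * n_hidden_toy))
U_toy = rng.normal(size=(n_hidden_toy, 4 * n_hidden_toy))
b_toy = np.zeros(4 * n_hidden_toy)

h_toy, c_toy = np.zeros((1, n_hidden_toy)), np.zeros((1, n_hidden_toy))
for t in range(8):                             # run the cell over 8 random "timesteps"
    h_toy, c_toy = lstm_step(rng.normal(size=(1, n_in_toy)), h_toy, c_toy, W_toy, U_toy, b_toy)

h_toy.shape, lstm_param_count(42, 512)

# With the notebook's n_factors=42 inputs and n_hidden=512 units, this formula gives 1,136,640
# parameters. That is about four times a SimpleRNN of the same size, because there are four gates.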
# In[43]:


from keras.layers import LSTM

n_hidden = 512

def CharStatefulLSTM(vocab_size, n_chars, n_factors, n_hidden, bs):
    i = Input(batch_shape=(bs, n_chars, 1))
    x = TimeDistributed(Embedding(vocab_size, n_factors))(i)
    x = Reshape((n_chars, n_factors))(x)
    # Drop-in replacement for SimpleRNN; everything else is unchanged.
    x = LSTM(n_hidden, return_sequences=True, stateful=True)(x)
    x = TimeDistributed(Dense(vocab_size, activation='softmax'))(x)

    model = Model(inputs=i, outputs=x)
    opt = Adam(lr=0.001)
    model.compile(loss='categorical_crossentropy', optimizer=opt)

    return model


# LSTMs need to train for a bit longer. We'll do 20 epochs at each learning rate.

# In[44]:


model = CharStatefulLSTM(vocab_size, cs, n_factors, n_hidden, bs)
model.fit(x=X, y=y_cat, batch_size=bs, epochs=20, verbose=1, callbacks=[reset_state], shuffle=False)


# In[45]:


K.set_value(model.optimizer.lr, 0.0001)
model.fit(x=X, y=y_cat, batch_size=bs, epochs=20, verbose=1, callbacks=[reset_state], shuffle=False)


# And now the moment of truth!

# In[55]:


from pprint import pprint

pprint(get_next_n_chars(model, bs, 'for thos', 400))


# Ha, well, I wouldn't quite call it sensible, but it's not super-terrible either. It's forming mostly complete words, occasionally using punctuation, etc. Not bad for being trained one character at a time. There are many ways this can be improved, of course, but hopefully this has illustrated the key concepts behind building a sequence model.

# In[ ]:
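# The greedy generation loop in get_next_n_chars doesn't depend on Keras. All it needs is something
# that maps a seed window to a probability vector. Here is a sketch of that loop written against
# any such function, which makes it easy to test without a trained model. The predict_probs
# callable, generate_greedy, and the toy "model" below are hypothetical stand-ins, not part of the
# notebook.

# In[ ]:


import numpy as np

def generate_greedy(predict_probs, chars, seed, n):
    """Append n characters to seed, always taking the most likely next character.

    predict_probs(window) must return one probability per entry of `chars`.
    The window length stays equal to len(seed), as in get_next_n_chars.
    """
    out = seed
    window = seed
    for _ in range(n):
        probs = predict_probs(window)
        c = chars[int(np.argmax(probs))]
        out += c
        window = window[1:] + c    # slide the window forward by one character
    return out

# Hypothetical stand-in "model": always predicts the character after the window's last one, cycling.
toy_chars = list('abcd')

def toy_predict(window):
    probs = np.zeros(len(toy_chars))
    probs[(toy_chars.index(window[-1]) + 1) % len(toy_chars)] = 1.0
    return probs

generate_greedy(toy_predict, toy_chars, 'ab', 5)

# Always taking the argmax is deterministic: the same seed always produces the same continuation.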