import theano.tensor as tt

from keras import activations, initializations
from keras.layers.recurrent import GRU
from keras.layers.core import Dense, MaskedLayer, Merge
from keras.models import Graph
from keras.utils.theano_utils import shared_zeros
class SoftSequentialAttentionLayer(MaskedLayer):

    def __init__(self, memory_dim, driver_dim, inner_dim=128,
                 init='glorot_uniform', inner_activation='relu'):
        super(SoftSequentialAttentionLayer, self).__init__()
        self.init = initializations.get(init)
        self.inner_activation = activations.get(inner_activation)
        # W_m projects each memory timestep, W_d the attention driver, and
        # W_a maps the combined hidden representation to a scalar score
        self.W_m = self.init((memory_dim, inner_dim))
        self.W_d = self.init((driver_dim, inner_dim))
        self.W_a = self.init((inner_dim, 1))
        self.b_inner = shared_zeros(inner_dim)
        self.b_out = shared_zeros(1)
        self.params = [self.W_m, self.W_d, self.W_a, self.b_inner, self.b_out]
    def set_previous(self, *previous_layers):
        type_name = self.__class__.__name__
        if len(previous_layers) != 2:
            raise ValueError("{}.set_previous expects 2 input layers, got {}".format(
                type_name, len(previous_layers)))
        sequential_memory, attention_driver = previous_layers
        if not sequential_memory.return_sequences:
            raise ValueError("The first input of {} should be a recurrent layer with"
                             " return_sequences=True".format(type_name))
        self.sequential_memory = sequential_memory
        self.attention_driver = attention_driver
    def get_input(self, train=False):
        return [self.sequential_memory.get_output(train=train),
                self.attention_driver.get_output(train=train)]
    def get_output(self, train=False):
        sequential_memory, attention_driver = self.get_input(train=train)
        # sequential_memory shape: (nb_samples, time (padded with zeros), memory_dim)
        # attention_driver shape: (nb_samples, driver_dim)
        # mask shape after shuffling: (time, nb_samples, 1); pad=0 because we
        # reduce over the time axis directly instead of iterating with theano.scan
        padded_mask = self.get_padded_shuffled_mask(train, sequential_memory, pad=0)
        # new shape: (time, nb_samples, memory_dim), so that time is the leading
        # axis and broadcasting against the 2D driver works
        sequential_memory = sequential_memory.dimshuffle((1, 0, 2))
        h = self.inner_activation(tt.dot(sequential_memory, self.W_m)
                                  + tt.dot(attention_driver, self.W_d)
                                  + self.b_inner)
        # unnormalized attention scores, zeroed on padded timesteps
        a = tt.exp(tt.dot(h, self.W_a) + self.b_out) * padded_mask
        # variable-length softmax over time, then attention-weighted average
        attention = a / a.sum(axis=0, keepdims=True)
        output = (attention * sequential_memory).sum(axis=0)
        return output
    def _variable_length_softmax_step(self, a_t, sum_t):
        # one step of the variable-length softmax: normalize the masked
        # exponentials of a single timestep by the total over time (kept for
        # a theano.scan-based variant of get_output)
        return a_t / sum_t
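
To sanity-check the shapes, here is a minimal numpy sketch of the masked softmax-and-average that get_output computes. All arrays below are made-up illustration data, not part of the layer:

import numpy as np

time, nb_samples, memory_dim = 6, 2, 3
memory = np.random.randn(time, nb_samples, memory_dim)  # (time, nb, dim)
scores = np.random.randn(time, nb_samples, 1)           # unnormalized scores
mask = np.ones((time, nb_samples, 1))
mask[4:, 0] = 0.0                                       # sample 0 has only 4 real steps

e = np.exp(scores) * mask                               # masked exponentials
attention = e / e.sum(axis=0, keepdims=True)            # softmax over the time axis
output = (attention * memory).sum(axis=0)               # (nb_samples, memory_dim)
assert output.shape == (nb_samples, memory_dim)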
class CustomGraph(Graph):

    def add_node(self, layer, name, input=None, inputs=[],
                 merge_mode='concat', create_output=False):
        if hasattr(layer, 'set_name'):
            layer.set_name(name)
        if name in self.namespace:
            raise Exception('Duplicate node identifier: ' + name)
        if input:
            if input not in self.namespace:
                raise Exception('Unknown node/input identifier: ' + input)
            if input in self.nodes:
                layer.set_previous(self.nodes[input])
            elif input in self.inputs:
                layer.set_previous(self.inputs[input])
        if inputs:
            to_merge = []
            for n in inputs:
                if n in self.nodes:
                    to_merge.append(self.nodes[n])
                elif n in self.inputs:
                    to_merge.append(self.inputs[n])
                else:
                    raise Exception('Unknown identifier: ' + n)
            # The only change from Graph.add_node: the 'distinct' merge mode
            # passes the parent layers to set_previous individually instead
            # of wrapping them in a concatenating Merge layer.
            if merge_mode == 'distinct':
                layer.set_previous(*to_merge)
            else:
                merge = Merge(to_merge, mode=merge_mode)
                layer.set_previous(merge)
        self.namespace.add(name)
        self.nodes[name] = layer
        self.node_config.append({'name': name,
                                 'input': input,
                                 'inputs': inputs,
                                 'merge_mode': merge_mode})
        layer.init_updates()
        params, regularizers, constraints, updates = layer.get_params()
        self.params += params
        self.regularizers += regularizers
        self.constraints += constraints
        self.updates += updates
        if create_output:
            self.add_output(name, input=name)
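
With the 'distinct' merge mode in place, the attention layer can be given its two parents separately: a recurrent memory and a dense attention driver.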
graph = CustomGraph()
graph.add_input(name='context_sequences', ndim=3)
graph.add_node(GRU(32, return_sequences=True), name='dense1', input='context_sequences')
graph.add_node(Dense(32, 4), name='dense2', input='context_sequences')
# memory_dim matches the GRU output (keras 0.x GRU defaults to output_dim=128);
# driver_dim matches the Dense output dim
graph.add_node(SoftSequentialAttentionLayer(memory_dim=128, driver_dim=4),
               name='attention', inputs=['dense1', 'dense2'],
               merge_mode='distinct')
graph.add_output(name='output1', input='dense2')
graph.add_output(name='output2', input='attention')
graph.nodes
{'attention': <__main__.SoftSequentialAttentionLayer at 0x10873d630>, 'dense1': <keras.layers.recurrent.GRU at 0x1085caeb8>, 'dense2': <keras.layers.core.Dense at 0x10873f438>}
graph.namespace
{'attention', 'context_sequences', 'dense1', 'dense2'}
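
For completeness, a sketch of how training would proceed with the keras 0.x Graph API. X, y1, and y2 are hypothetical arrays, and this assumes the attention layer's masking holds up end to end (untested):

import numpy as np

X = np.random.randn(16, 10, 32)   # (nb_samples, time, input_dim)
y1 = np.random.randn(16, 10, 4)   # hypothetical target for the Dense sequence output
y2 = np.random.randn(16, 128)     # hypothetical target for the attention summary

graph.compile('rmsprop', {'output1': 'mse', 'output2': 'mse'})
graph.fit({'context_sequences': X, 'output1': y1, 'output2': y2},
          batch_size=8, nb_epoch=1)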
import numpy as np
# sanity check: numpy generalizes dot so that a 3D array times a 2D matrix
# contracts the last axis, mapping over the leading axes
x = np.arange(3 * 4 * 5).reshape(5, 3, 4)
a = np.arange(4 * 2).reshape(4, 2)
np.dot(x, a).shape
(5, 3, 2)
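
The same generalized dot is what tt.dot(sequential_memory, self.W_m) relies on inside the layer; a quick symbolic check of the same shapes:

import theano
import theano.tensor as tt

X3 = tt.tensor3()
W = tt.matrix()
f = theano.function([X3, W], tt.dot(X3, W))
f(x.astype(theano.config.floatX), a.astype(theano.config.floatX)).shape
# -> (5, 3, 2), matching the numpy result above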