1. Summary
MemN2N (the End-To-End Memory Network) can be viewed as a generalization of an RNN:
1) A sentence (memory slot) in MemN2N plays the role that a word plays in an RNN.
2. Core Code
- Build Model
def build_model(self):
    """Add the output projection, loss and clipped-SGD training op to the graph.

    Relies on state created by ``build_memory`` (``self.hid``, ``self.A``,
    ``self.B``, ``self.C``, ``self.T_A``, ``self.T_B``, ``self.global_step``),
    so that method must have been run before this one.

    Sets ``self.W``, ``self.loss``, ``self.lr``, ``self.opt``, ``self.optim``
    and ``self.saver``.
    """
    # Project the final hop's controller state onto the vocabulary (logits).
    self.W = tf.Variable(tf.random_normal([self.edim, self.nwords], stddev=self.init_std))
    z = tf.matmul(self.hid[-1], self.W)

    # softmax_cross_entropy_with_logits needs labels with the same shape as z.
    # NOTE(review): the labels argument was lost in the original snippet; the
    # label placeholder name ``self.target`` is assumed here -- confirm it.
    self.loss = tf.nn.softmax_cross_entropy_with_logits(logits=z, labels=self.target)

    # Learning rate held in a Variable, presumably so it can be reassigned
    # during training without rebuilding the graph.
    self.lr = tf.Variable(self.current_lr)
    self.opt = tf.train.GradientDescentOptimizer(self.lr)

    params = [self.A, self.B, self.C, self.T_A, self.T_B, self.W]
    grads_and_vars = self.opt.compute_gradients(self.loss, params)
    # Each gradient tensor is clipped to max_grad_norm independently
    # (per-tensor norm, not a global norm across all parameters).
    clipped_grads_and_vars = [
        (tf.clip_by_norm(grad, self.max_grad_norm), var)
        for grad, var in grads_and_vars
    ]

    # Running self.optim first increments global_step, because the update op
    # carries a control dependency on the increment.
    inc = self.global_step.assign_add(1)
    with tf.control_dependencies([inc]):
        self.optim = self.opt.apply_gradients(clipped_grads_and_vars)

    self.saver = tf.train.Saver()
- Build Memory
def build_memory(self):
    """Create the embedding variables and build the multi-hop memory read-out.

    Creates ``self.global_step``, the embedding matrices ``self.A`` (memory
    inputs), ``self.B`` (memory outputs), ``self.C`` (linear map applied to the
    controller state between hops) and the temporal encodings ``self.T_A`` /
    ``self.T_B``. Then it runs ``self.nhop`` attention hops and appends each new
    controller state to ``self.hid``. ``self.hid`` must already hold the
    initial state, because ``self.hid[-1]`` is read before the first append.

    The matmul/reshape ops below imply that ``self.context`` is
    (batch, mem_size) token ids (one token per memory slot) and that the
    controller state ``self.hid[-1]`` is (batch, edim).
    """
    self.global_step = tf.Variable(0, name="global_step")

    self.A = tf.Variable(tf.random_normal([self.nwords, self.edim], stddev=self.init_std))
    self.B = tf.Variable(tf.random_normal([self.nwords, self.edim], stddev=self.init_std))
    self.C = tf.Variable(tf.random_normal([self.edim, self.edim], stddev=self.init_std))

    # Temporal Encoding: one learned vector per memory-slot position.
    self.T_A = tf.Variable(tf.random_normal([self.mem_size, self.edim], stddev=self.init_std))
    self.T_B = tf.Variable(tf.random_normal([self.mem_size, self.edim], stddev=self.init_std))

    # Memory input vectors: m_i = A x_i + T_A(i)
    # (presumably self.time holds the slot positions 0..mem_size-1 -- confirm).
    Ain_c = tf.nn.embedding_lookup(self.A, self.context)
    Ain_t = tf.nn.embedding_lookup(self.T_A, self.time)
    Ain = tf.add(Ain_c, Ain_t)

    # Memory output vectors: c_i = B x_i + T_B(i)
    Bin_c = tf.nn.embedding_lookup(self.B, self.context)
    Bin_t = tf.nn.embedding_lookup(self.T_B, self.time)
    Bin = tf.add(Bin_c, Bin_t)

    for _ in range(self.nhop):
        # Attention weights p_i = softmax(u^T m_i), shape (batch, mem_size).
        self.hid3dim = tf.reshape(self.hid[-1], [-1, 1, self.edim])
        Aout = tf.matmul(self.hid3dim, Ain, adjoint_b=True)
        Aout2dim = tf.reshape(Aout, [-1, self.mem_size])
        P = tf.nn.softmax(Aout2dim)

        # Read-out o = sum_i p_i c_i, shape (batch, edim).
        probs3dim = tf.reshape(P, [-1, 1, self.mem_size])
        Bout = tf.matmul(probs3dim, Bin)
        Bout2dim = tf.reshape(Bout, [-1, self.edim])

        # Next controller state (before the nonlinearity): u C + o.
        Cout = tf.matmul(self.hid[-1], self.C)
        Dout = tf.add(Cout, Bout2dim)

        # The first `lindim` units stay linear and the remaining units go
        # through ReLU; the two extremes skip the slicing entirely.
        if self.lindim == self.edim:
            self.hid.append(Dout)
        elif self.lindim == 0:
            self.hid.append(tf.nn.relu(Dout))
        else:
            # Size -1 keeps the whole batch dimension, matching the -1
            # reshapes above instead of requiring exactly batch_size rows.
            F = tf.slice(Dout, [0, 0], [-1, self.lindim])
            G = tf.slice(Dout, [0, self.lindim], [-1, self.edim - self.lindim])
            K = tf.nn.relu(G)
            self.hid.append(tf.concat(axis=1, values=[F, K]))
3. Reference
