$$ \begin{align*} h_t &= \tanh(W_h h_{t-1} + W_x x_t + b_h)\\ o_t &= W_o h_t + b_o \end{align*} $$
class RNN(nn.Module):
    """Minimal Elman RNN unrolled over a sequence.

    Implements h_t = tanh(W_h h_{t-1} + W_x x_t + b_h) with a single
    Linear layer applied to the concatenation [x_t, h_{t-1}].

    NOTE(review): the accompanying equations also define an output
    projection o_t = W_o h_t + b_o, but this module never implemented
    it — ``output_size`` is accepted for interface compatibility but
    unused, and ``forward`` returns raw hidden states only. Confirm
    with the author whether the projection should be added.
    """

    def __init__(self, input_size, hidden_size, output_size):
        # Fixed typo: the parameter was previously spelled ``output_siez``.
        super().__init__()
        self.input_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.tanh = nn.Tanh()

    def forward(self, X, h, batch_first=False):
        """Run the cell over X from initial hidden state h.

        X: (seq, batch, input_size), or (batch, seq, input_size) when
           ``batch_first=True``.
        h: (batch, hidden_size) initial hidden state.
        Returns the stacked per-step hidden states, laid out like X.
        """
        if batch_first:
            X = X.transpose(0, 1)  # make time the leading dim so we can iterate steps
        outputs = []
        for x in X:
            i = torch.cat((x, h), dim=1)
            h = self.tanh(self.input_layer(i))
            outputs.append(h)
        outputs = torch.stack(outputs)
        if batch_first:
            outputs.transpose_(0, 1)  # restore (batch, seq, hidden) layout
        return outputs
$R_t$ 重置门 (reset gate), $Z_t$ 更新门 (update gate)
$$ \begin{align*} R_t &= \sigma(X_tW_{xr} + H_{t-1}W_{hr}+b_r)\\ Z_t &= \sigma(X_tW_{xz} + H_{t-1}W_{hz} + b_z)\\ \tilde{H}_t &= \tanh(X_tW_{xh} + (R_t\odot H_{t-1})W_{hh} + b_h)\\ H_t &= Z_t\odot H_{t-1} + (1-Z_t)\odot \tilde{H}_t \end{align*} $$
class GRU(nn.Module):
    """Hand-rolled gated recurrent unit built from three Linear gates.

    Per time step:
        R = sigmoid([x, h] W_r + b_r)          # reset gate
        Z = sigmoid([x, h] W_z + b_z)          # update gate
        H~ = tanh([x, R * h] W_h + b_h)        # candidate state
        h  = Z * h + (1 - Z) * H~
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        combined = input_size + hidden_size  # gates see [x_t, h_{t-1}]
        self.reset_gate = nn.Linear(combined, hidden_size)
        self.upgrade_gate = nn.Linear(combined, hidden_size)
        self.output_layer = nn.Linear(combined, hidden_size)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, X, h, batch_first=False):
        """Unroll over X from initial hidden state h; returns the
        stacked per-step hidden states in the same layout as X."""
        if batch_first:
            X = X.transpose(0, 1)
        states = []
        for x_t in X:
            concat = torch.cat((x_t, h), dim=1)
            reset = self.sigmoid(self.reset_gate(concat))
            update = self.sigmoid(self.upgrade_gate(concat))
            gated = torch.cat((x_t, h * reset), dim=1)
            candidate = self.tanh(self.output_layer(gated))
            h = update * h + (1 - update) * candidate
            states.append(h)
        stacked = torch.stack(states)
        return stacked.transpose(0, 1) if batch_first else stacked
$$ \begin{align*} I_t&=\sigma(X_tW_{xi} + H_{t-1}W_{hi} + b_i)\\ F_t&=\sigma(X_tW_{xf} + H_{t-1}W_{hf} + b_f)\\ O_t&=\sigma(X_tW_{xo} + H_{t-1}W_{ho} + b_o)\\ \tilde{C}_t &= \tanh(X_tW_{xc} + H_{t-1}W_{hc} + b_c)\\ C_t &= F_t\odot C_{t-1} + I_t\odot \tilde{C}_t\\ H_t &= O_t\odot \tanh(C_t) \end{align*} $$
class LSTM(nn.Module):
    """Hand-rolled long short-term memory cell unrolled over a sequence.

    Per time step:
        I = sigmoid([x, h] W_i + b_i)    # input gate
        F = sigmoid([x, h] W_f + b_f)    # forget gate
        O = sigmoid([x, h] W_o + b_o)    # output gate
        C~ = tanh([x, h] W_c + b_c)      # candidate cell state
        c = F * c + I * C~
        h = O * tanh(c)

    Returns only the stacked hidden states; the final (c, h) pair is
    not exposed to the caller.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        width = input_size + hidden_size  # every gate sees [x_t, h_{t-1}]
        self.input_gate = nn.Linear(width, hidden_size)
        self.forget_gate = nn.Linear(width, hidden_size)
        self.output_gate = nn.Linear(width, hidden_size)
        self.cell_gate = nn.Linear(width, hidden_size)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, X, c, h, batch_first=False):
        """Unroll over X given initial cell state c and hidden state h;
        returns the stacked hidden states in the same layout as X."""
        if batch_first:
            X = X.transpose(0, 1)
        hidden_states = []
        for step_input in X:
            joint = torch.cat((step_input, h), dim=1)
            i_gate = self.sigmoid(self.input_gate(joint))
            f_gate = self.sigmoid(self.forget_gate(joint))
            o_gate = self.sigmoid(self.output_gate(joint))
            candidate = self.tanh(self.cell_gate(joint))
            c = f_gate * c + i_gate * candidate
            h = o_gate * self.tanh(c)
            hidden_states.append(h)
        result = torch.stack(hidden_states)
        return result.transpose(0, 1) if batch_first else result