RNN

$$ \begin{align*} h_t &= \tanh(W_h h_{t-1} + W_x x_t + b_h)\\ o_t &= W_o h_t + b_o \end{align*} $$

class RNN(nn.Module):
    """Single-layer Elman RNN implementing the formulas above:

        h_t = tanh(W_h h_{t-1} + W_x x_t + b_h)
        o_t = W_o h_t + b_o

    One Linear over the concatenation [x_t; h_{t-1}] is mathematically
    equivalent to separate W_x and W_h matrices sharing one bias.
    """

    def __init__(self, input_size, hidden_size, output_size):
        # NOTE: fixed typo `output_siez` -> `output_size`.
        super().__init__()
        self.input_layer = nn.Linear(input_size + hidden_size, hidden_size)
        # Output projection o_t = W_o h_t + b_o; previously `output_size`
        # was accepted but never used and o_t was never computed.
        self.output_layer = nn.Linear(hidden_size, output_size)
        self.tanh = nn.Tanh()

    def forward(self, X, h, batch_first=False):
        """Run the RNN over a sequence.

        X: (seq, batch, input_size), or (batch, seq, input_size) when
           batch_first=True.
        h: (batch, hidden_size) initial hidden state.
        Returns the stacked outputs o_t with shape
        (seq, batch, output_size) (or batch-first if requested).
        """
        if batch_first:
            X = X.transpose(0, 1)
        outputs = []
        for x in X:  # x: (batch, input_size) — one time step
            i = torch.cat((x, h), dim=1)
            h = self.tanh(self.input_layer(i))
            outputs.append(self.output_layer(h))
        outputs = torch.stack(outputs)
        if batch_first:
            outputs = outputs.transpose(0, 1)
        return outputs

GRU

image.png

$R_t$ is the reset gate, $Z_t$ is the update gate.

$$ \begin{align*} R_t &= \sigma(X_tW_{xr} + H_{t-1}W_{hr}+b_r)\\ Z_t &= \sigma(X_tW_{xz} + H_{t-1}W_{hz} + b_z)\\ \tilde{H}_t &= \tanh(X_tW_{xh} + (R_t\odot H_{t-1})W_{hh} + b_h)\\ H_t &= Z_t\odot H_{t-1} + (1-Z_t)\odot \tilde{H}_t \end{align*} $$

class GRU(nn.Module):
    """Minimal GRU cell unrolled over a sequence, per the equations above:

        R_t = sigma(X_t W_xr + H_{t-1} W_hr + b_r)
        Z_t = sigma(X_t W_xz + H_{t-1} W_hz + b_z)
        H~_t = tanh(X_t W_xh + (R_t * H_{t-1}) W_hh + b_h)
        H_t = Z_t * H_{t-1} + (1 - Z_t) * H~_t

    Each gate is a single Linear over the concatenation [x_t; h_{t-1}],
    which is equivalent to separate input/hidden weight matrices.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.reset_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.upgrade_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.output_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, X, h, batch_first=False):
        """Run the GRU over X with initial hidden state h.

        X: (seq, batch, input) — or (batch, seq, input) when batch_first.
        h: (batch, hidden).
        Returns hidden states stacked over time: (seq, batch, hidden)
        (batch-first if requested).
        """
        if batch_first:
            X = X.transpose(0, 1)
        steps = []
        for x_t in X:
            combined = torch.cat((x_t, h), dim=1)
            r_t = self.sigmoid(self.reset_gate(combined))
            z_t = self.sigmoid(self.upgrade_gate(combined))
            # Candidate state sees the *reset* previous hidden state.
            candidate = self.tanh(
                self.output_layer(torch.cat((x_t, r_t * h), dim=1))
            )
            h = z_t * h + (1.0 - z_t) * candidate
            steps.append(h)
        result = torch.stack(steps)
        return result.transpose(0, 1) if batch_first else result

LSTM

image.png

$$ \begin{align*} I_t&=\sigma(X_tW_{xi} + H_{t-1}W_{hi} + b_i)\\ F_t&=\sigma(X_tW_{xf} + H_{t-1}W_{hf} + b_f)\\ O_t&=\sigma(X_tW_{xo} + H_{t-1}W_{ho} + b_o)\\ \tilde{C}_t &= \tanh(X_tW_{xc} + H_{t-1}W_{hc} + b_c)\\ C_t &= F_t\odot C_{t-1} + I_t\odot \tilde{C}_t\\ H_t &= O_t\odot \tanh(C_t) \end{align*} $$

class LSTM(nn.Module):
    """Minimal LSTM cell unrolled over a sequence, per the equations above:

        I_t = sigma(X_t W_xi + H_{t-1} W_hi + b_i)
        F_t = sigma(X_t W_xf + H_{t-1} W_hf + b_f)
        O_t = sigma(X_t W_xo + H_{t-1} W_ho + b_o)
        C~_t = tanh(X_t W_xc + H_{t-1} W_hc + b_c)
        C_t = F_t * C_{t-1} + I_t * C~_t
        H_t = O_t * tanh(C_t)

    All four gates share the same input [x_t; h_{t-1}]; each Linear over
    the concatenation stands in for a (W_x*, W_h*) pair with one bias.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.forget_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.output_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.cell_gate = nn.Linear(input_size + hidden_size, hidden_size)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, X, c, h, batch_first=False):
        """Run the LSTM over X with initial cell state c and hidden state h.

        X: (seq, batch, input) — or (batch, seq, input) when batch_first.
        c, h: (batch, hidden).
        Returns the hidden states stacked over time (cell state is not
        returned): (seq, batch, hidden), batch-first if requested.
        """
        if batch_first:
            X = X.transpose(0, 1)
        hidden_states = []
        for x_t in X:
            gates_in = torch.cat((x_t, h), dim=1)
            i_t = self.sigmoid(self.input_gate(gates_in))
            f_t = self.sigmoid(self.forget_gate(gates_in))
            o_t = self.sigmoid(self.output_gate(gates_in))
            c_tilde = self.tanh(self.cell_gate(gates_in))
            # Forget part of the old cell state, write the gated candidate.
            c = f_t * c + i_t * c_tilde
            h = o_t * self.tanh(c)
            hidden_states.append(h)
        result = torch.stack(hidden_states)
        return result.transpose(0, 1) if batch_first else result