位置编码

正弦余弦固定位置编码

class PositionEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    Adds position-dependent sine/cosine signals to the input embeddings:
    even feature indices get sin, odd indices get cos, with wavelengths
    forming a geometric progression from 2*pi to 10000*2*pi.
    """

    def __init__(self, hidden_size, max_length=1024):
        """
        @param hidden_size: embedding dimension (expected even so sin/cos halves match)
        @param max_length: longest sequence the table is precomputed for
        """
        super().__init__()
        # angles[pos, j] = pos / 10000^(2j / hidden_size)
        position = torch.arange(max_length, dtype=torch.float32).reshape(-1, 1)
        inv_freq = torch.pow(
            torch.tensor(10000.0),
            torch.arange(0, hidden_size, 2, dtype=torch.float32) / hidden_size,
        ).reshape(1, -1)
        angles = position / inv_freq
        P = torch.zeros((max_length, hidden_size))
        P[:, 0::2] = torch.sin(angles)
        P[:, 1::2] = torch.cos(angles)
        # register_buffer (instead of a plain attribute) so P moves with
        # .to(device)/.cuda() and is saved in state_dict, but is not trained.
        self.register_buffer("P", P)

    def forward(self, x):
        """
        @param x: shape = (batch_size, seq_length, hidden_size)
        @return: shape = (batch_size, seq_length, hidden_size)
        """
        seq_length = x.shape[1]
        # broadcast the (seq_length, hidden_size) table over the batch dim
        return x + self.P[:seq_length, :]

注意力

多头注意力

class ScaledDotAttention(nn.Module):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d)) @ v."""

    def __init__(self, dropout):
        super().__init__()
        # Dropout is applied to the attention weights (as in the Transformer paper).
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, valid_lens=None):
        """
        @param q: shape = (batch_size, query_num, hidden_size)
        @param k, v: shape = (batch_size, kv_num, hidden_size)
        @param valid_lens: shape = (batch_size,) or (batch_size, query_num)
        @return: shape = (batch_size, query_num, hidden_size)
        """
        hidden_size = q.shape[-1]
        # scores: shape = (batch_size, query_num, kv_num)
        scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(hidden_size)
        attention_weights = ScaledDotAttention.masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(attention_weights), v)

    @staticmethod
    def masked_softmax(scores, valid_lens=None):
        """Softmax over the last dim, with positions >= valid_len masked out.

        @param scores: shape = (batch_size, query_num, kv_num)
        @param valid_lens: shape = (batch_size,) or (batch_size, query_num)
        @return: shape = (batch_size, query_num, kv_num)
        """
        if valid_lens is not None:
            # Build the position index on the same device as `scores`
            # so the comparison does not fail on GPU tensors.
            positions = torch.arange(scores.shape[-1], device=scores.device)
            if valid_lens.dim() == 1:
                # one valid length per batch row, shared by every query
                mask = positions[None, :] >= valid_lens[:, None]
                mask = mask.unsqueeze(1).expand(-1, scores.shape[1], -1)
            else:
                # per-(batch, query) valid lengths
                mask = positions[None, None, :] >= valid_lens[:, :, None]
            # masked_fill returns a new tensor: do NOT mutate the caller's
            # `scores` in place (the original `scores[mask] = -inf` did).
            scores = scores.masked_fill(mask, -torch.inf)
        return F.softmax(scores, dim=-1)

class MultiHeadAttention(nn.Module):
    """Multi-head attention: linearly project q/k/v, run scaled dot-product
    attention on each head in parallel, then merge heads and project out.

    Heads are realized by folding the head axis into the batch axis, so the
    inner attention is a single batched matmul over batch_size * num_heads.
    """

    def __init__(self, query_size, key_size, value_size, hidden_size, output_size, num_heads, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.wq = nn.Linear(query_size, hidden_size, bias=False)
        self.wk = nn.Linear(key_size, hidden_size, bias=False)
        self.wv = nn.Linear(value_size, hidden_size, bias=False)
        self.scaledDotAttention = ScaledDotAttention(dropout)
        self.dense = nn.Linear(hidden_size, output_size)

    def forward(self, q, k, v, valid_lens=None):
        """
        @param q: shape = (batch_size, query_num, query_size)
        @param k: shape = (batch_size, kv_num, key_size)
        @param v: shape = (batch_size, kv_num, value_size)
        @param valid_lens: shape = (batch_size,) or (batch_size, query_num)
        @return: shape = (batch_size, query_num, output_size)
        """
        q_heads = self.transpose_qkv(self.wq(q))
        k_heads = self.transpose_qkv(self.wk(k))
        v_heads = self.transpose_qkv(self.wv(v))
        if valid_lens is not None:
            # every head of a sample shares that sample's valid length,
            # matching the head-into-batch folding below
            valid_lens = valid_lens.repeat_interleave(self.num_heads, dim=0)
        head_outputs = self.scaledDotAttention(q_heads, k_heads, v_heads, valid_lens)
        return self.dense(self.invtranspose_qkv(head_outputs))

    def transpose_qkv(self, tensor):
        """Split the feature dim into heads and fold heads into the batch dim.

        @param tensor: shape = (batch_size, num, hidden_size)
        @return: shape = (batch_size * num_heads, num, hidden_size / num_heads)
        """
        batch_size, num, _ = tensor.shape
        split = tensor.reshape(batch_size, num, self.num_heads, -1)
        return split.transpose(1, 2).reshape(batch_size * self.num_heads, num, -1)

    def invtranspose_qkv(self, tensor):
        """Inverse of transpose_qkv: unfold heads and concatenate their features.

        @param tensor: shape = (batch_size * num_heads, num, hidden_size / num_heads)
        @return: shape = (batch_size, num, hidden_size)
        """
        num, head_dim = tensor.shape[1], tensor.shape[2]
        unfolded = tensor.reshape(-1, self.num_heads, num, head_dim).transpose(1, 2)
        return unfolded.reshape(unfolded.shape[0], num, self.num_heads * head_dim)

  1. 自注意力:mask 掉 <pad> 的词
  2. 掩码注意力:mask 掉未来的词 valid_lens = torch.arange(1, len + 1).repeat(batch, 1)
  3. 交叉注意力:mask 掉 <pad> 的词

FFN

positionwise feed-forward network

位置无关,对每个词向量表示做一个 MLP(可以改变里层向量的维度)

class PositionWiseFFN(nn.Module):
    """Position-wise feed-forward network.

    The same two-layer MLP (Linear -> ReLU -> Linear) is applied to each
    position's vector independently; only the last dimension is transformed.
    """

    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs):
        super().__init__()
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        """
        @param X: shape = (..., ffn_num_input)
        @return: shape = (..., ffn_num_outputs)
        """
        hidden = self.relu(self.dense1(X))
        return self.dense2(hidden)

Add & Norm

class AddNorm(nn.Module):
    """Residual connection followed by layer normalization: LN(x + dropout(fx))."""

    def __init__(self, hidden_size, dropout=0):
        super().__init__()
        # Uses the hand-rolled LayerNorm defined elsewhere in this file
        # (nn.LayerNorm would be the stock drop-in alternative).
        self.ln = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, fx):
        """
        @param x: the sublayer input (residual branch)
        @param fx: the sublayer output; dropout is applied to it before the add
        @return: layer-normalized sum, same shape as x
        """
        residual = x + self.dropout(fx)
        return self.ln(residual)

残差连接