位置编码公式
偶数位置用sin,奇数位置用cos. d_model 表示token的维度;pos表示token在序列中的位置;i表示每个token编码的第i个位置,属于[0,d_model)。
 
torch实现
import math
import torch
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
class PositionalEncoder(nn.Module):
    def __init__(self, max_seq_len=50, d_model=128):
        super().__init__()
        self.d_model = d_model  # d_model 表示token的维度
        pe = torch.zeros(max_seq_len, d_model)  # max_seq_len * d_model 的二维张量   例如: 50*128
        for pos in range(max_seq_len):  # 重新初始化
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** (i / d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** (i / d_model)))
        pe = pe.unsqueeze(0)  # 1*50*128
        self.register_buffer('pe', pe)
    def forward(self, x):
        x = x * math.sqrt(self.d_model)
        seq_len = x.size(1)
        x = x + Variable(self.pe[:, :seq_len], requires_grad=False).cuda()
        return x
if __name__ == '__main__':
    positional_encoder = PositionalEncoder(50, 128)
    plt.pcolormesh(positional_encoder.pe.numpy()[0], cmap='RdBu')
    plt.xlabel('Depth')  # 50
    plt.xlim((0, 128))
    plt.ylabel('Position')  # 128
    plt.colorbar()
    plt.show() 
位置编码可视化




















