Методы инициализации весов во встраиваниях: подробное руководство с примерами кода

“Инициализация весов при внедрении”

Когда дело доходит до инициализации весов во встраиваниях, вы можете рассмотреть несколько методов. Я предоставлю вам список популярных методов вместе с примерами кода на Python:

  1. Случайная инициализация.
    Этот метод инициализирует веса внедрения случайными значениями из равномерного или нормального распределения.

    import torch
    import torch.nn as nn
    embedding_dim = 300
    vocab_size = 10000
    embedding = nn.Embedding(vocab_size, embedding_dim)
    embedding.weight.data.uniform_(-0.1, 0.1)  # Random initialization
  2. Предварительно обученные внедрения слов.
    Вы можете инициализировать веса внедрения, используя предварительно обученные внедрения слов, такие как Word2Vec, GloVe или FastText. Эти внедрения обучаются на больших корпусах и фиксируют семантические отношения между словами.

    import torch
    import torch.nn as nn
    embedding_dim = 300
    vocab_size = 10000
    embedding = nn.Embedding(vocab_size, embedding_dim)
    pretrained_embeddings = torch.load('pretrained_embeddings.pt')
    embedding.weight.data.copy_(pretrained_embeddings)
  3. Инициализация Xavier/Glorot:
    Этот метод инициализирует веса внедрения, используя равномерное распределение, масштабируемое по квадратному корню из обратного измерения внедрения.

    import torch
    import torch.nn as nn
    embedding_dim = 300
    vocab_size = 10000
    embedding = nn.Embedding(vocab_size, embedding_dim)
    nn.init.xavier_uniform_(embedding.weight.data)
  4. Инициализация He:
    Этот метод похож на инициализацию Xavier, но используется для функций активации, таких как ReLU. Он инициализирует веса внедрения, используя нормальное распределение, масштабированное квадратным корнем обратного измерения внедрения.

    import torch
    import torch.nn as nn
    embedding_dim = 300
    vocab_size = 10000
    embedding = nn.Embedding(vocab_size, embedding_dim)
    nn.init.kaiming_normal_(embedding.weight.data)