
Context: I am currently working on an encoder-decoder sequence-to-sequence model that takes a sequence of word embeddings as input and output, and reduces the dimensionality of the word embeddings.

The word embeddings are created using pre-trained models. I want to be able to decode the word embeddings returned by the decoder of the sequence-to-sequence model back to natural language.

Question: How can I train a decoder for this task, given the sequence of word embeddings and the original sentence?

See below for the code that generates the word embeddings:

```
from typing import List

import numpy as np
import torch
from transformers.tokenization_utils_base import BatchEncoding
from transformers import BertTokenizerFast, BertModel

TOKENIZER = BertTokenizerFast.from_pretrained('bert-base-uncased')
# output_hidden_states=True is needed so that output.hidden_states is populated below
MODEL = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)


def get_word_indices(sentence: str, separator=" ") -> List:
    sent = sentence.split(sep=separator)
    return list(range(len(sent)))


def encode_sentence(sentence: str) -> BatchEncoding:
    # return_tensors="pt" so the encoding can be passed directly to the model
    encoded_sentence = TOKENIZER(sentence, return_tensors="pt")
    return encoded_sentence


def get_hidden_states(encoded: BatchEncoding, layers: list = [-1, -2, -3, -4]) -> torch.Tensor:
    with torch.no_grad():
        output = MODEL(**encoded)
    hidden_states = output.hidden_states
    # Sum the selected layers and drop the batch dimension
    output = torch.stack([hidden_states[i] for i in layers]).sum(0).squeeze()
    return output


def get_token_ids(word_index: int, encoded: BatchEncoding):
    token_ids = np.where(np.array(encoded.word_ids()) == word_index)
    return token_ids


def embed_words(sentence: str) -> torch.Tensor:
    word_indices = get_word_indices(sentence)
    encoded_sentence = encode_sentence(sentence)
    hidden_states = get_hidden_states(encoded_sentence)
    word_embeddings = []
    for word_index in word_indices:
        # Get the ids of the word in the sentence
        # Important, because BERT sometimes splits words into subwords
        token_ids = get_token_ids(word_index, encoded_sentence)
        # Get all the hidden states for each word (or subwords belonging to one word)
        # Average the hidden states in case of subwords to retrieve word embedding
        word_embedding = hidden_states[token_ids].mean(dim=0)
        word_embeddings.append(word_embedding)
    return torch.stack(word_embeddings)
```
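For reference, a quick sanity check of what this produces (the sentence is just an illustrative example): `embed_words` returns one 768-dimensional vector per whitespace-separated word of the input.

```
# Illustrative usage: one 768-dimensional embedding per whitespace-separated word
embeddings = embed_words("The quick brown fox jumps over the lazy dog")
print(embeddings.shape)  # torch.Size([9, 768])
```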


1 Answer


You can follow this example code: https://github.com/microsoft/CodeBERT/blob/master/CodeBERT/code2nl/README.md

The example loads a pre-trained BERT model as the encoder and plugs a custom decoder on top of it. For the encoder, it uses the transformers library to load the pre-trained weights.
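Applied to your setup, a minimal sketch (not the CodeBERT code, and untested) could look like the following: a small Transformer decoder trained with teacher forcing to map the word embeddings from `embed_words` back to BERT token ids. The class name `EmbeddingToTextDecoder`, the `train_step` helper, and all hyperparameters are placeholder assumptions you would adapt; positional encodings are omitted for brevity.

```
# Minimal sketch (untested): a custom decoder on top of the frozen BERT word
# embeddings, trained to reconstruct the original token ids.
# EmbeddingToTextDecoder, train_step and the hyperparameters are illustrative
# placeholders, not taken from the CodeBERT example.
import torch
import torch.nn as nn

HIDDEN_SIZE = 768                      # size of the BERT word embeddings
VOCAB_SIZE = TOKENIZER.vocab_size      # decode back into BERT's wordpiece vocabulary


class EmbeddingToTextDecoder(nn.Module):
    def __init__(self, hidden_size=HIDDEN_SIZE, vocab_size=VOCAB_SIZE,
                 num_layers=2, num_heads=8):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        layer = nn.TransformerDecoderLayer(d_model=hidden_size, nhead=num_heads,
                                           batch_first=True)
        self.decoder = nn.TransformerDecoder(layer, num_layers=num_layers)
        self.output_proj = nn.Linear(hidden_size, vocab_size)

    def forward(self, target_ids, memory):
        # memory:     (batch, n_words, hidden) word embeddings from the encoder side
        # target_ids: (batch, n_tokens) gold token ids (teacher forcing)
        tgt = self.token_embedding(target_ids)
        tgt_len = target_ids.size(1)
        # Causal mask so each position only attends to previous target tokens
        causal_mask = torch.triu(
            torch.full((tgt_len, tgt_len), float("-inf")), diagonal=1)
        out = self.decoder(tgt, memory, tgt_mask=causal_mask)
        return self.output_proj(out)   # (batch, n_tokens, vocab)


def train_step(decoder, optimizer, sentence):
    """One teacher-forced step: word embeddings -> original BERT token ids."""
    memory = embed_words(sentence).unsqueeze(0)                   # (1, n_words, 768)
    target_ids = TOKENIZER(sentence, return_tensors="pt")["input_ids"]
    logits = decoder(target_ids[:, :-1], memory)                  # predict next token
    loss = nn.functional.cross_entropy(
        logits.reshape(-1, logits.size(-1)), target_ids[:, 1:].reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


decoder = EmbeddingToTextDecoder()
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)
loss = train_step(decoder, optimizer, "The quick brown fox jumps over the lazy dog")
```

At inference time you would decode autoregressively: start from the `[CLS]` id, repeatedly feed the argmax token back into the decoder, and stop at `[SEP]`.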
