Models

model.transformers

get_embedding(text, tokenizer, model, device)

Get the mean embedding of a text using a transformer model: the text is tokenized, run through the model, and the token vectors of the last hidden state are averaged into a single vector.

Parameters:

Name       Type                     Description                                         Default
text       str                      Text to embed                                       required
tokenizer  PreTrainedTokenizerBase  Hugging Face tokenizer used to encode the text      required
model      PreTrainedModel          Hugging Face model whose last hidden state is pooled  required
device     str                      Device to run inference on, e.g. "cpu" or "cuda"    required

Returns:

Type        Description
np.ndarray  Mean embedding of the text

Source code in src/model/transformers.py
import numpy as np
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase


def get_embedding(
    text: str,
    tokenizer: PreTrainedTokenizerBase,
    model: PreTrainedModel,
    device: str,
) -> np.ndarray:
    """Get the mean embedding of a text using a transformer model.

    Args:
        text (str): Text to embed.
        tokenizer (PreTrainedTokenizerBase): Hugging Face tokenizer used to encode the text.
        model (PreTrainedModel): Hugging Face model whose last hidden state is pooled.
        device (str): Device to run inference on, e.g. "cpu" or "cuda".

    Returns:
        np.ndarray: Mean embedding of the text.
    """

    # Tokenize and move to device
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        # Forward pass to get model outputs
        outputs = model(**inputs)

        # Access last_hidden_state
        last_hidden_state = outputs.last_hidden_state
        # Compute mean of last hidden state and convert to numpy array
        mean_embedding = last_hidden_state.mean(dim=1).squeeze().cpu().numpy()

    return mean_embedding