Models
model.transformers
get_embedding(text, tokenizer, model, device)
Get the mean embedding of a text using a transformer model.
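This page does not show the function body, so the block below is only a sketch of what a "mean embedding" commonly means. It takes the token vectors from the model's output and averages them over the non-padding positions given by the attention mask. The sketch assumes those arrays have already been pulled out of the model as NumPy arrays. `mean_pool`, `hidden_states` and `attention_mask` are hypothetical names and are not part of `model.transformers`.

```python
import numpy as np


def mean_pool(hidden_states: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: average token embeddings, ignoring padding.

    hidden_states:  (seq_len, hidden_dim) token vectors (assumed model output).
    attention_mask: (seq_len,) with 1 for real tokens and 0 for padding.
    """
    # Broadcast the mask over the hidden dimension: shape (seq_len, 1).
    mask = attention_mask.astype(hidden_states.dtype)[:, None]
    # Sum only the real-token vectors: shape (hidden_dim,).
    summed = (hidden_states * mask).sum(axis=0)
    # Divide by the number of real tokens, guarding against an all-padding mask.
    count = max(float(mask.sum()), 1.0)
    return summed / count


# Usage: the third (padding) row is excluded from the average.
hidden = np.array([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]])
mask = np.array([1, 1, 0])
print(mean_pool(hidden, mask))  # [2. 3.]
```

The mask matters because a plain `hidden_states.mean(axis=0)` would let padding tokens pull the embedding toward arbitrary values.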
Parameters:
Name | Type | Description | Default |
---|---|---|---|
text | str | Text to embed. | required |
tokenizer | dict | Tokenizer object. | required |
model | str | Model object. | required |
device | str | Device to run the model on. | required |
Returns:
Type | Description |
---|---|
ndarray | Mean embedding of the text. |
Source code in src/model/transformers.py