import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class LexicalConfig(PretrainedConfig):
    model_type = "lexical_embedding"

    def __init__(
        self,
        vocab_size=30522,
        embed_dim=2048,
        padding_idx=0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.padding_idx = padding_idx


class LexicalHFModel(PreTrainedModel):
    config_class = LexicalConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embedding = nn.Embedding(
            config.vocab_size,
            config.embed_dim,
            padding_idx=config.padding_idx,
        )

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # Token embedding lookup: (batch, seq_len, embed_dim)
        embeds = self.embedding(input_ids)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        # Masked mean pooling over the sequence dimension, ignoring padding positions
        mask_expanded = attention_mask.unsqueeze(-1).expand(embeds.size()).float()
        sum_embeddings = torch.sum(embeds * mask_expanded, 1)
        sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)  # avoid division by zero
        mean_pooled = sum_embeddings / sum_mask
        # L2-normalize so that dot products between outputs equal cosine similarity
        return torch.nn.functional.normalize(mean_pooled, p=2, dim=1)
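

# A minimal usage sketch (not from the original listing): the token IDs below are
# arbitrary placeholders, 0 is assumed to be the padding ID per LexicalConfig's
# default, and the save path is a hypothetical local directory.
config = LexicalConfig(vocab_size=30522, embed_dim=2048, padding_idx=0)
model = LexicalHFModel(config)

input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102, 0, 0]])
attention_mask = (input_ids != config.padding_idx).long()

embeddings = model(input_ids, attention_mask=attention_mask)
print(embeddings.shape)  # torch.Size([1, 2048]); rows are L2-normalized

# Because the model subclasses PreTrainedModel, the standard save/load API applies
model.save_pretrained("lexical-embedding-demo")
reloaded = LexicalHFModel.from_pretrained("lexical-embedding-demo")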