import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class LexicalConfig(PretrainedConfig):
    """Configuration for the lexical embedding model."""

    model_type = "lexical_embedding"

    def __init__(
        self,
        vocab_size=30522,
        embed_dim=2048,
        padding_idx=0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.padding_idx = padding_idx


class LexicalHFModel(PreTrainedModel):
    """Bag-of-tokens embedding model: a single embedding table followed by
    masked mean pooling and L2 normalization."""

    config_class = LexicalConfig

    def __init__(self, config):
        super().__init__(config)
        self.embedding = nn.Embedding(
            config.vocab_size,
            config.embed_dim,
            padding_idx=config.padding_idx,
        )

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # Look up a vector for each token: (batch, seq_len, embed_dim).
        embeds = self.embedding(input_ids)

        # Without an explicit mask, treat every position as a real token.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Mean-pool over the sequence, counting only unmasked positions.
        mask_expanded = attention_mask.unsqueeze(-1).expand(embeds.size()).float()
        sum_embeddings = torch.sum(embeds * mask_expanded, 1)
        sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)  # guard against empty masks
        mean_pooled = sum_embeddings / sum_mask

        # L2-normalize so dot products between outputs are cosine similarities.
        return torch.nn.functional.normalize(mean_pooled, p=2, dim=1)
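

# --- Usage sketch (not part of the original module) ---
# A minimal example of how the model might be instantiated and queried.
# It assumes a BERT-style WordPiece tokenizer ("bert-base-uncased", whose
# 30,522-token vocabulary matches the default vocab_size); substitute
# whichever tokenizer the embedding table was actually trained against.

if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    config = LexicalConfig(vocab_size=tokenizer.vocab_size, embed_dim=2048)
    model = LexicalHFModel(config)
    model.eval()

    batch = tokenizer(
        ["a lexical embedding example", "another sentence"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        embeddings = model(batch["input_ids"], batch["attention_mask"])

    # Each row is an L2-normalized sentence embedding, so the matrix product
    # below gives pairwise cosine similarities.
    print(embeddings.shape)           # torch.Size([2, 2048])
    print(embeddings @ embeddings.T)  # pairwise cosine similarities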