import copy
import os
import sys

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, dir_path)

import contextlib
from functools import partial

import torch
import torch.utils.checkpoint
import torch.nn as nn
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image

from .modeling_vit import *
from .modeling_InternLM import *
from .modeling_utils import *
from .resampler import create_resampler

from transformers.utils import logging

logger = logging.get_logger(__name__)


class InternLMXComposerForCausalLM(PreTrainedModel):
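    """Vision-language causal LM: an EVA ViT-g encoder and a perceiver
    resampler produce image features that internlm_proj maps into the
    InternLM embedding space, where they are interleaved with text."""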
    config_class = InternLMXComposerConfig
    _auto_class = "AutoModelForCausalLM"

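    # Default generation arguments; get_gen_args() deep-copies them and
    # applies per-call overrides.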
    gen_config = dict(
        num_beams=5,
        do_sample=True,
        min_length=1,
        repetition_penalty=1.5,
        length_penalty=1.0,
        temperature=1.0,
        max_new_tokens=500,
    )

    def __init__(self, config):
        super().__init__(config)

        self.max_length = config.max_length
        print(f'Set max length to {self.max_length}')
        print('Init VIT ... ', end='')
        self.visual_encoder = create_eva_vit_g(img_size=448)
        self.ln_vision = nn.Identity()
        self.supports_gradient_checkpointing = True
        print('Done')
        print('Init Perceive Sampler ... ', end='')
        with all_logging_disabled():
            self.Qformer = create_resampler(num_query_token=256)
        print('Done')

        print('Init InternLM ... ', end='')
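        # Learned flag embeddings that bracket image features in the input
        # sequence; zero-initialized here, loaded from the checkpoint, and
        # kept frozen.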
        self.flag_image_start = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.flag_image_end = nn.Parameter(torch.zeros([1, 1, 4096]))
        self.flag_image_start.requires_grad = False
        self.flag_image_end.requires_grad = False

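        # On PyTorch 2.x, build the LM on the meta device (no storage) and
        # materialize it empty on the target device, skipping the random
        # init that the checkpoint load would overwrite anyway.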
        if int(torch.__version__[0]) == 1:
            self.internlm_model = InternLMForCausalLM._from_config(config).to(
                torch.float16)
        else:
            assert int(torch.__version__[0]) == 2
            with torch.device('meta'):
                self.internlm_model = InternLMForCausalLM._from_config(config)
            self.internlm_model.to_empty(device=config.device).to(
                torch.float16)

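        # Project the 4096-d resampler outputs into the LM's hidden size.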
        self.internlm_proj = nn.Linear(4096,
                                       self.internlm_model.config.hidden_size)
        print('Done')

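        # 448x448 bicubic resize plus the standard CLIP normalization
        # statistics.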
        self.vis_processor = transforms.Compose([
            transforms.Resize((448, 448),
                              interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])

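        # Left unset on purpose: the caller attaches a tokenizer after
        # loading (e.g. model.tokenizer = AutoTokenizer.from_pretrained(...)).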
        self.tokenizer = None

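    # Sentinel tokens: eoh closes the user ("human") turn in the chat
    # template; eoa marks the end of the model's answer and is used to
    # truncate decoded output.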
    @property
    def eoh(self):
        return '<TOKENS_UNUSED_0>'

    @property
    def eoa(self):
        return '<TOKENS_UNUSED_1>'

    def get_input_embeddings(self):
        return self.internlm_model.get_input_embeddings()

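    # Propagate gradient checkpointing to the wrapped InternLM; `module`
    # is ignored because the flag is applied to the whole submodel.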
    def _set_gradient_checkpointing(self, module, value=False):
        if value:
            self.internlm_model.apply(
                partial(self.internlm_model._set_gradient_checkpointing,
                        value=True))

    def encode_img(self, image):
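        """Turn an image path or preprocessed batch tensor into LM-space
        embeddings bracketed by the image start/end flag embeddings."""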
        if image is None:
            return None
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
            image = self.vis_processor(image).unsqueeze(0).to(self.device)
        else:
            assert isinstance(image, torch.Tensor)

        device = image.device
        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
        image_atts = torch.ones(image_embeds.size()[:-1],
                                dtype=torch.long).to(device)
        query_output = self.Qformer(image_embeds)
        inputs_internlm = self.internlm_proj(query_output)

        inputs_internlm = torch.cat([
            self.flag_image_start.expand(inputs_internlm.shape[0], -1, -1),
            inputs_internlm,
            self.flag_image_end.expand(inputs_internlm.shape[0], -1, -1),
        ], dim=1)
        return inputs_internlm

    def encode_text(self, text, add_special_tokens=False):
        text_token_ids = self.tokenizer(
            text,
            return_tensors='pt',
            add_special_tokens=add_special_tokens,
        ).input_ids.to(self.device)
        text_embeds = self.internlm_model.model.embed_tokens(text_token_ids)
        return text_embeds

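    # generate() returns token ids, so despite its name `out_embeds` holds
    # ids here; decode them and cut at the end-of-answer marker.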
    def decode_text(self, out_embeds):
        out_text = self.tokenizer.batch_decode(out_embeds,
                                               skip_special_tokens=True)[0]
        out_text = out_text.split(self.eoa)[0]
        return out_text

    def wrap_text(self, user_text, bot_text='', add_special=True):
        if add_special:
            eoh = self.eoh
        else:
            eoh = ''
        text = f'<|User|>:{user_text}{eoh}\n<|Bot|>:{bot_text}'
        return text

    def get_gen_args(self, **kwargs):
        new_kwargs = copy.deepcopy(self.gen_config)
        new_kwargs.update(kwargs)
        return new_kwargs

    def generate(self, text, image=None, **kwargs):
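        """Single-turn generation: wrap the text (and optional image) in
        the chat template and return the decoded reply."""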
        text_embeds = self.encode_text(text)
        img_embeds = self.encode_img(image)
        prompt_embeds = self.wrap_prompt(text_embeds, img_embeds)
        out_embeds = self.internlm_model.generate(
            inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        out_text = self.decode_text(out_embeds)
        return out_text

    def chat(self, text, image=None, history=None, **kwargs):
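        """Multi-turn variant of generate(): prepends embedded history and
        returns (reply, updated history)."""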
        text_embeds = self.encode_text(text)
        img_embeds = self.encode_img(image)
        prompt_embeds = self.wrap_prompt(text_embeds,
                                         img_embeds,
                                         history=history)
        out_embeds = self.internlm_model.generate(
            inputs_embeds=prompt_embeds, **self.get_gen_args(**kwargs))
        out_text = self.decode_text(out_embeds)

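        # Re-embed the reply and a special-token-free copy of the prompt so
        # the stored history contains exactly what future turns should
        # condition on.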
        clean_out_text_token_ids = self.tokenizer(
            out_text, return_tensors='pt').input_ids.to(self.device)
        clean_out_text_embeds = self.internlm_model.model.embed_tokens(
            clean_out_text_token_ids)
        clean_prompt_embeds = self.wrap_prompt(text_embeds,
                                               img_embeds,
                                               add_special=False)
        cur_history = torch.cat([clean_prompt_embeds, clean_out_text_embeds],
                                dim=1)
        if history is None:
            history = []
        history.append(cur_history)
        return out_text, history

    def wrap_prompt(self,
                    text_embeds,
                    img_embeds=None,
                    history=None,
                    add_special=True):
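        """Concatenate prompt embeddings in the order
        [history] + '<|User|>:' + [image] + [text] + eoh + '<|Bot|>:',
        dropping the eoh marker when add_special is False."""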
        if add_special:
            prompt_segs = ['<|User|>:', f'{self.eoh}\n<|Bot|>:']
        else:
            prompt_segs = ['<|User|>:', '<|Bot|>:']
        prompt_seg_embeds = []
        for i, seg in enumerate(prompt_segs):
            if history is not None:
                add_special_tokens = False
            else:
                add_special_tokens = i == 0
            seg_embeds = self.encode_text(
                seg, add_special_tokens=add_special_tokens)
            prompt_seg_embeds.append(seg_embeds)
        if img_embeds is None:
            img_embeds = text_embeds.new_empty(text_embeds.size(0), 0,
                                               text_embeds.size(-1))
        prompt_seg_embeds = [
            prompt_seg_embeds[0], img_embeds, text_embeds, prompt_seg_embeds[1]
        ]
        prompt_embeds = torch.cat(prompt_seg_embeds, dim=1)
        if history is not None:
            prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
        return prompt_embeds
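
# Minimal usage sketch (assumptions: the checkpoint id below and loading
# via trust_remote_code; the tokenizer must be attached manually because
# __init__ leaves self.tokenizer as None):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   path = 'internlm/internlm-xcomposer-7b'  # assumed checkpoint id
#   model = AutoModelForCausalLM.from_pretrained(
#       path, trust_remote_code=True).cuda().eval()
#   model.tokenizer = AutoTokenizer.from_pretrained(
#       path, trust_remote_code=True)
#   reply = model.generate('Describe this image.', 'photo.jpg')
#   reply, history = model.chat('What stands out most?', history=None)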