Spaces:
Runtime error
Runtime error
| # src/model_loader.py | |
| import os | |
| import math | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer, AutoConfig | |
| from huggingface_hub import snapshot_download | |
# Hugging Face Hub repository id; passed as `repo_id` when downloading the snapshot.
MODEL_NAME = "OpenGVLab/InternVL3-14B"
# Local directory the snapshot is downloaded into and later loaded from.
CACHE_DIR = "/data/internvl3_model"
# === Automatically distribute model layers across multiple GPUs (approach suggested for InternVL3) ===
def split_model(model_path):
    """Build a ``device_map`` spreading the language-model layers over all visible GPUs.

    GPU 0 also receives the vision model and the other fixed components listed
    below, so it is weighted as half a device and gets roughly half as many
    decoder layers as each of the remaining GPUs.

    Args:
        model_path: Path or repo id handed to ``AutoConfig.from_pretrained``;
            the layer count is read from ``config.llm_config.num_hidden_layers``.

    Returns:
        dict mapping module names (str) to GPU indices (int).

    Raises:
        RuntimeError: If no CUDA device is visible.
    """
    world_size = torch.cuda.device_count()
    if world_size < 1:
        # Without this guard the quota list below would be empty and indexing
        # its first element would fail with an uninformative IndexError.
        raise RuntimeError(
            "split_model requires at least one CUDA device, "
            "but torch.cuda.device_count() returned 0."
        )

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    num_layers = config.llm_config.num_hidden_layers

    # GPU 0 counts as half a device: divide by (world_size - 0.5), then halve
    # the quota of GPU 0.
    layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
    quotas = [layers_per_gpu] * world_size
    quotas[0] = math.ceil(layers_per_gpu * 0.5)

    device_map = {}
    layer_cnt = 0
    for gpu_index, quota in enumerate(quotas):
        # The ceil() roundings can make the quotas sum to more than num_layers;
        # never emit entries for layers that do not exist.
        for _ in range(min(quota, num_layers - layer_cnt)):
            device_map[f'language_model.model.layers.{layer_cnt}'] = gpu_index
            layer_cnt += 1

    # Fixed components are pinned to GPU 0. The final decoder layer is pinned
    # there as well, overriding whatever the loop above assigned to it.
    for key in [
        'vision_model', 'mlp1',
        'language_model.model.tok_embeddings',
        'language_model.model.embed_tokens',
        'language_model.output',
        'language_model.model.norm',
        'language_model.model.rotary_emb',
        'language_model.lm_head',
        f'language_model.model.layers.{num_layers - 1}',
    ]:
        device_map[key] = 0

    return device_map
# === Model loading function ===
def _model_cache_ready(cache_dir):
    """Return True when ``cache_dir`` looks like a complete model snapshot."""
    import json  # local import so this block does not depend on file-level imports

    if not os.path.isfile(os.path.join(cache_dir, "config.json")):
        return False

    index_path = os.path.join(cache_dir, "model.safetensors.index.json")
    if not os.path.isfile(index_path):
        # No safetensors shard index present: nothing further to verify here.
        return True

    try:
        with open(index_path, encoding="utf-8") as fh:
            shard_names = set(json.load(fh).get("weight_map", {}).values())
    except (OSError, ValueError, AttributeError):
        # Unreadable or malformed index: treat the cache as incomplete.
        return False

    return all(
        os.path.isfile(os.path.join(cache_dir, shard)) for shard in shard_names
    )


def load_model():
    """Ensure the model snapshot is present in ``CACHE_DIR`` and load it.

    The snapshot is downloaded when the cache directory is missing or looks
    incomplete (no ``config.json``, or weight shards listed in the safetensors
    index are absent). A bare directory-existence check would treat a directory
    left behind by an interrupted download as a finished one.

    Returns:
        tuple: ``(tokenizer, model)``; the model is loaded in bfloat16 with the
        device map from ``split_model`` and switched to eval mode.
    """
    if not _model_cache_ready(CACHE_DIR):
        print("⏬ First run: downloading model to persistent storage...")
        snapshot_download(repo_id=MODEL_NAME, local_dir=CACHE_DIR)
    else:
        print("✅ Loaded model from persistent cache.")

    device_map = split_model(CACHE_DIR)
    tokenizer = AutoTokenizer.from_pretrained(CACHE_DIR, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        CACHE_DIR,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_flash_attn=False,  # or True, once FlashAttention is confirmed to be installed
        trust_remote_code=True,
        device_map=device_map,
    ).eval()
    return tokenizer, model