Spaces:
Build error
Build error
| import torch | |
| import torch | |
| import torch.nn as nn | |
| from transformers import ViTImageProcessor, ViTModel, BertTokenizerFast, BertModel | |
| from PIL import Image | |
| import gradio as gr | |
class VisionLanguageModel(nn.Module):
    """Late-fusion classifier over a ViT image encoder and a BERT text encoder.

    The pooled output of each backbone is concatenated along the feature
    dimension and passed through a single linear layer that yields one logit
    per class.

    Args:
        num_labels: Number of output classes. Defaults to 2, matching the
            original benign/malignant head (and therefore the layout of any
            checkpoint saved from the original class).
        vision_model_name: Hugging Face checkpoint used for the ViT backbone.
        language_model_name: Hugging Face checkpoint used for the BERT backbone.
    """

    def __init__(
        self,
        num_labels: int = 2,
        vision_model_name: str = 'google/vit-base-patch16-224-in21k',
        language_model_name: str = 'bert-base-uncased',
    ):
        super().__init__()
        self.vision_model = ViTModel.from_pretrained(vision_model_name)
        self.language_model = BertModel.from_pretrained(language_model_name)
        # The classifier sees both pooled vectors side by side, so its input
        # width is the sum of the two backbones' hidden sizes.
        fused_dim = (
            self.vision_model.config.hidden_size
            + self.language_model.config.hidden_size
        )
        self.classifier = nn.Linear(fused_dim, num_labels)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Return classification logits for a batch of image/text pairs.

        Args:
            input_ids: Token ids produced by the BERT tokenizer.
            attention_mask: Attention mask produced alongside ``input_ids``.
            pixel_values: Preprocessed images from the ViT image processor.

        Returns:
            Logits whose last dimension has size ``num_labels``.
        """
        vision_pooled = self.vision_model(pixel_values=pixel_values).pooler_output
        language_pooled = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).pooler_output
        # Concatenate per-sample feature vectors (dim=1), keeping the batch axis.
        combined_features = torch.cat((vision_pooled, language_pooled), dim=1)
        return self.classifier(combined_features)
# Build the fusion model, then overwrite its weights with the fine-tuned
# checkpoint, mapping every tensor onto the CPU.
model = VisionLanguageModel()
_checkpoint = torch.load(
    'best_model.pth',
    map_location='cpu',
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is safer than a full pickle load. It needs a PyTorch release that
    # accepts this keyword (1.13+).
    weights_only=True,
)
model.load_state_dict(_checkpoint)
del _checkpoint  # weights are now copied into the model; drop the extra reference
model.eval()

# Preprocessors for the same checkpoints the backbones were created from.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
def predict(image, text_input):
    """Classify a skin-lesion image plus clinical text as benign or malignant.

    Args:
        image: The uploaded lesion image (a PIL image when called from the
            Gradio ``gr.Image(type="pil")`` input; anything the ViT image
            processor accepts also works). ``None`` means nothing was uploaded.
        text_input: Free-text clinical information. ``None`` is treated as an
            empty string.

    Returns:
        "Malignant" if the model's highest-scoring class is index 1, otherwise
        "Benign". If no image was supplied, returns a short prompt asking for
        one instead of raising.
    """
    # Gradio passes None when the user submits without uploading an image;
    # the image processor would raise on it, so answer with a prompt instead.
    if image is None:
        return "Please upload a skin lesion image."

    # Normalize palette/RGBA/grayscale uploads to 3-channel RGB, the input
    # format the ViT backbone works with.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

    # The tokenizer rejects None, so fall back to an empty clinical note.
    encoding = tokenizer(
        text_input if text_input is not None else "",
        add_special_tokens=True,
        # NOTE(review): fixed 256-token padding presumably mirrors training — confirm.
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )

    with torch.no_grad():
        logits = model(
            input_ids=encoding['input_ids'],
            attention_mask=encoding['attention_mask'],
            pixel_values=pixel_values,
        )
    # argmax picks the first maximal index on ties, same as torch.max(..., dim=1).
    prediction = logits.argmax(dim=1)
    return "Malignant" if prediction.item() == 1 else "Benign"
# Gradio UI: a lesion photo and free-text clinical notes go in (passed to
# predict positionally, in this order) and a plain-text verdict comes out.
_lesion_image = gr.Image(type="pil", label="Upload Skin Lesion Image")
_clinical_notes = gr.Textbox(label="Clinical Information (e.g., patient age, symptoms)")

iface = gr.Interface(
    fn=predict,
    inputs=[_lesion_image, _clinical_notes],
    outputs="text",
    title="Skin Lesion Classification Demo",
    description="This model classifies skin lesions as benign or malignant based on an image and clinical information.",
)

# Start the web server as soon as the module runs.
iface.launch()