Spaces:
Sleeping
Sleeping
import torch
import wandb
from lightning import LightningModule

from dataset.xray_loader import XrayData
from models.model_loader import create_model
class XrayReg(LightningModule):
    """LightningModule for X-ray regression.

    Trains a model (built by ``create_model``) with MSE loss, logs RMSE for
    train/val, and during testing collects per-sample predictions plus the
    input images for a wandb results table.
    """

    def __init__(self, config):
        """
        Args:
            config: dict with "model", "dataset", and "training" sections;
                see ``create_model`` / ``XrayData`` for the expected keys.
        """
        super().__init__()
        self.save_hyperparameters(config)
        model_config = config["model"]
        dataset_config = config["dataset"]
        self.model = create_model(model_config["name"])
        self.data = XrayData(
            dataset_config["root_dir"],
            dataset_config["label_csv"],
            dataset_config["batch_size"],
            val_split=dataset_config["val_split"],
            apply_equalization=dataset_config["apply_equalization"],
        )
        # NOTE(review): setup() here runs dataset preparation eagerly at
        # construction time; Lightning would normally defer this to a
        # DataModule. Kept for backward compatibility.
        self.data.setup()
        # Per-sample rows accumulated by test_step; cleared each test epoch.
        self.test_results = []

    def forward(self, x):
        return self.model(x)

    def _rmse_step(self, batch, stage):
        """Shared train/val step: compute MSE loss and log it as RMSE.

        Args:
            batch: (images, targets, filenames) triple from the dataloader.
            stage: "train" or "val" — prefix for the logged metric name.

        Returns:
            The (un-rooted) MSE loss used for backprop.
        """
        x, y, _filenames = batch
        # Model output is (B, 1); squeeze to (B,) to match the targets.
        y_hat = self(x).squeeze(1)
        loss = torch.nn.functional.mse_loss(y_hat, y)
        # Log RMSE (same units as the target) rather than raw MSE.
        self.log(f"{stage}_loss", torch.sqrt(loss), prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._rmse_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._rmse_step(batch, "val")

    def on_test_epoch_start(self):
        # Fix: without this reset, calling trainer.test() more than once on
        # the same module kept appending stale rows to the results table.
        self.test_results = []

    def test_step(self, batch, batch_idx):
        """Collect per-sample predictions (with images) for wandb logging."""
        x, y, filenames = batch
        y_hat = self(x).squeeze(1)
        for img, pred, file, gt in zip(x, y_hat, filenames, y):
            self.test_results.append({
                # CHW tensor -> HWC numpy array, as wandb.Image expects.
                "image": wandb.Image(img.cpu().numpy().transpose(1, 2, 0)),
                "prediction": pred.item(),
                "filename": file,
                "ground_truth": gt.item(),
                "delta": abs(pred.item() - gt.item()),
            })
        return None

    def configure_optimizers(self):
        """Adam + ReduceLROnPlateau on val_loss, checked once per epoch."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.hparams["training"]["learning_rate"])
        # Fix: dropped verbose=True — deprecated in PyTorch 2.2 and removed
        # in later releases, where passing it raises TypeError.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.1, patience=5)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_loss',
                "frequency": 1,
                "interval": "epoch",
            }
        }

    def train_dataloader(self):
        return self.data.train_dataloader()

    def val_dataloader(self):
        return self.data.val_dataloader()

    def test_dataloader(self):
        return self.data.test_dataloader()

    def save_test_results_to_wandb(self):
        """Log all rows accumulated by test_step as a single wandb Table."""
        columns = ["image", "filename", "prediction", "ground_truth", "delta"]
        wandb_table = wandb.Table(columns=columns)
        for result in self.test_results:
            wandb_table.add_data(
                result["image"],
                result["filename"],
                result["prediction"],
                result["ground_truth"],
                result["delta"],
            )
        wandb.log({"test_results": wandb_table})