Spaces:
Sleeping
Sleeping
| import timm | |
| import torch.nn as nn | |
def _build_replacement_head(module, num_classes):
    """Return a fresh, trainable layer of the same kind as *module* with *num_classes* outputs."""
    if isinstance(module, nn.Linear):
        head = nn.Linear(module.in_features, num_classes)
    else:
        # The new conv must consume what the old one consumed (in_channels),
        # and keep the spatial hyperparameters so the output size is unchanged.
        head = nn.Conv2d(
            module.in_channels,
            num_classes,
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            padding_mode=module.padding_mode,
        )
    # Place the new layer on the same device / dtype as the layer it replaces.
    head.to(device=module.weight.device, dtype=module.weight.dtype)
    # requires_grad_ flips the flag on every parameter; assigning a
    # `requires_grad` attribute on the Module itself has no effect on training.
    head.requires_grad_(True)
    return head


def _assign_submodule(model, name, new_module):
    """Replace the submodule at dotted path *name* inside *model* with *new_module*."""
    if not name:
        # named_modules() reports the root module with an empty name; there is
        # no parent to assign into, so it cannot be swapped in place.
        raise ValueError("cannot replace the root module in place")
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name)  # "" -> model itself
    setattr(parent, child_name, new_module)


def replace_last_layer(model, num_classes):
    """Swap the last ``nn.Linear`` / ``nn.Conv2d`` in *model* for a new head.

    "Last" means last in ``model.named_modules()`` (registration) order. The
    replacement keeps the original layer's input size (``in_features`` /
    ``in_channels``) and conv geometry, outputs ``num_classes`` features or
    channels, and has all of its parameters trainable. If the model contains
    no ``nn.Linear`` or ``nn.Conv2d``, it is left unchanged.

    Args:
        model: the ``nn.Module`` to modify in place.
        num_classes: number of outputs of the new head.

    Raises:
        ValueError: if the matching layer is *model* itself (the root module),
            which cannot be replaced in place.
    """
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            _assign_submodule(model, name, _build_replacement_head(module, num_classes))
            break
def enable_first_layer_grad(model):
    """Make the parameters of the first ``nn.Conv2d`` in *model* trainable.

    "First" is the first ``nn.Conv2d`` yielded by ``model.modules()``
    (registration order). That is usually, but not guaranteed to be, the
    input stem. If the model has no ``nn.Conv2d``, nothing changes.

    Args:
        model: the ``nn.Module`` to modify in place.
    """
    first_conv = next((m for m in model.modules() if isinstance(m, nn.Conv2d)), None)
    if first_conv is not None:
        # requires_grad_ updates every parameter of the layer; assigning a
        # `requires_grad` attribute on the Module would leave them frozen.
        first_conv.requires_grad_(True)
def create_model(key, in_chans=1, num_classes=1, *, pretrained=False):
    """Create a timm model, freeze it, then re-enable selected layers.

    Args:
        key: timm model name, forwarded to ``timm.create_model``.
        in_chans: number of input channels, forwarded to ``timm.create_model``.
        num_classes: number of outputs; forwarded to timm and used again by
            ``replace_last_layer`` to rebuild the head.
        pretrained: forwarded to ``timm.create_model``. Defaults to ``False``,
            matching the previous hard-coded value.

    Returns:
        The model with every parameter frozen, after which
        ``enable_first_layer_grad`` and ``replace_last_layer`` are called,
        intended to leave only the stem conv and the new head trainable.
    """
    model = timm.create_model(
        key, pretrained=pretrained, in_chans=in_chans, num_classes=num_classes
    )
    # NOTE(review): with pretrained=False the frozen backbone keeps its random
    # initialization and is never updated; confirm this is intended.
    model.requires_grad_(False)
    enable_first_layer_grad(model)
    replace_last_layer(model, num_classes)
    return model