diff --git a/modules/simclr.py b/modules/simclr.py index 326735e..e6ba22a 100644 --- a/modules/simclr.py +++ b/modules/simclr.py @@ -27,17 +27,6 @@ def __init__(self, args, encoder, n_features): nn.Linear(self.n_features, args.projection_dim, bias=False), ) - def get_resnet(self, name): - resnets = { - "resnet18": torchvision.models.resnet18(), - "resnet50": torchvision.models.resnet50(), - } - if name not in resnets.keys(): - raise KeyError(f"{name} is not a valid ResNet version") - return modify_resnet_model( - resnets[name], cifar_stem=self.args.dataset.startswith("CIFAR"), v1=True - ) - def forward(self, x_i, x_j): h_i = self.encoder(x_i) h_j = self.encoder(x_j)