diff --git a/efficientnet_pytorch/model.py b/efficientnet_pytorch/model.py index f2bc2af..21f477d 100755 --- a/efficientnet_pytorch/model.py +++ b/efficientnet_pytorch/model.py @@ -12,6 +12,7 @@ load_pretrained_weights, Swish, MemoryEfficientSwish, + Mish, ) class MBConvBlock(nn.Module): @@ -92,11 +93,15 @@ def forward(self, inputs, drop_connect_rate=None): x = drop_connect(x, p=drop_connect_rate, training=self.training) x = x + inputs # skip connection return x - - def set_swish(self, memory_efficient=True): + + def set_swish(self, act_="mem_swish"): """Sets swish function as memory efficient (for training) or standard (for export)""" - self._swish = MemoryEfficientSwish() if memory_efficient else Swish() - + if act_ == "swish": + self._swish = Swish() + elif act_ == "mish": + self._swish = Mish() + else: + self._swish = MemoryEfficientSwish() class EfficientNet(nn.Module): """ @@ -161,12 +166,16 @@ def __init__(self, blocks_args=None, global_params=None): self._fc = nn.Linear(out_channels, self._global_params.num_classes) self._swish = MemoryEfficientSwish() - def set_swish(self, memory_efficient=True): + def set_swish(self, act_="mem_swish"): """Sets swish function as memory efficient (for training) or standard (for export)""" - self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + if act_ == "swish": + self._swish = Swish() + elif act_ == "mish": + self._swish = Mish() + else: + self._swish = MemoryEfficientSwish() for block in self._blocks: - block.set_swish(memory_efficient) - + block.set_swish(act_) def extract_features(self, inputs): """ Returns output of the final convolution layer """ @@ -200,21 +209,27 @@ def forward(self, inputs): return x @classmethod - def from_name(cls, model_name, override_params=None): + def _from_name(cls, model_name, override_params=None): cls._check_model_name_is_valid(model_name) blocks_args, global_params = get_model_params(model_name, override_params) return cls(blocks_args, global_params) - + @classmethod - def from_pretrained(cls, model_name, advprop=False, num_classes=1000, in_channels=3): - model = cls.from_name(model_name, override_params={'num_classes': num_classes}) - load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop) + def from_name(cls, model_name, num_classes=1000, pretrained=False, advprop=False, in_channels=3, dropout=-1, drop_connect=-1): + override_params={'num_classes': num_classes} + if 0 < dropout: + override_params['dropout_rate'] = dropout + if 0 < drop_connect: + override_params['drop_connect_rate'] = drop_connect + model = cls._from_name(model_name, override_params=override_params) + if pretrained: + load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop) if in_channels != 3: Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size) out_channels = round_filters(32, model._global_params) model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) return model - + @classmethod def get_image_size(cls, model_name): cls._check_model_name_is_valid(model_name) diff --git a/efficientnet_pytorch/utils.py b/efficientnet_pytorch/utils.py index f9b59ab..46a781c 100755 --- a/efficientnet_pytorch/utils.py +++ b/efficientnet_pytorch/utils.py @@ -55,6 +55,9 @@ class Swish(nn.Module): def forward(self, x): return x * torch.sigmoid(x) +class Mish(nn.Module): + def forward(self, x): + return x * torch.tanh(F.softplus(x)) def round_filters(filters, global_params): """ Calculate and round number of filters based on depth multiplier. """