Skip to content

fix from_pretrained in_channels, etc. #94

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions efficientnet_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
load_pretrained_weights,
Swish,
MemoryEfficientSwish,
Mish,
)

class MBConvBlock(nn.Module):
Expand Down Expand Up @@ -92,11 +93,15 @@ def forward(self, inputs, drop_connect_rate=None):
x = drop_connect(x, p=drop_connect_rate, training=self.training)
x = x + inputs # skip connection
return x

def set_swish(self, memory_efficient=True):
def set_swish(self, act_="mem_swish"):
"""Sets swish function as memory efficient (for training) or standard (for export)"""
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()

if act_ == "swish":
self._swish = Swish()
elif act_ == "mish":
self._swish = Mish()
else:
self._swish = MemoryEfficientSwish()

class EfficientNet(nn.Module):
"""
Expand Down Expand Up @@ -161,12 +166,16 @@ def __init__(self, blocks_args=None, global_params=None):
self._fc = nn.Linear(out_channels, self._global_params.num_classes)
self._swish = MemoryEfficientSwish()

def set_swish(self, memory_efficient=True):
def set_swish(self, act_="mem_swish"):
"""Sets swish function as memory efficient (for training) or standard (for export)"""
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
if act_ == "swish":
self._swish = Swish()
elif act_ == "mish":
self._swish = Mish()
else:
self._swish = MemoryEfficientSwish()
for block in self._blocks:
block.set_swish(memory_efficient)

block.set_swish(act_)

def extract_features(self, inputs):
""" Returns output of the final convolution layer """
Expand Down Expand Up @@ -200,21 +209,27 @@ def forward(self, inputs):
return x

@classmethod
def from_name(cls, model_name, override_params=None):
def _from_name(cls, model_name, override_params=None):
cls._check_model_name_is_valid(model_name)
blocks_args, global_params = get_model_params(model_name, override_params)
return cls(blocks_args, global_params)

@classmethod
def from_pretrained(cls, model_name, advprop=False, num_classes=1000, in_channels=3):
model = cls.from_name(model_name, override_params={'num_classes': num_classes})
load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
def from_name(cls, model_name, num_classes=1000, pretrained=False, advprop=False, in_channels=3, dropout=-1, drop_connect=-1):
override_params={'num_classes': num_classes}
if 0 < dropout:
override_params['dropout_rate'] = dropout
if 0 < drop_connect:
override_params['drop_connect_rate'] = drop_connect
model = cls._from_name(model_name, override_params=override_params)
if pretrained:
load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
if in_channels != 3:
Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size)
out_channels = round_filters(32, model._global_params)
model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
return model

@classmethod
def get_image_size(cls, model_name):
cls._check_model_name_is_valid(model_name)
Expand Down
3 changes: 3 additions & 0 deletions efficientnet_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)

class Mish(nn.Module):
def forward(self, x):
return x * torch.tanh(F.softplus(x))

def round_filters(filters, global_params):
""" Calculate and round number of filters based on depth multiplier. """
Expand Down