From da86572a6f1538be5d616b985f2aaff06513ec63 Mon Sep 17 00:00:00 2001 From: Phoenix Date: Wed, 25 Sep 2024 19:44:45 +0800 Subject: [PATCH 1/2] upload new model : ghostnetv1,v2,v3 --- ghostnet/ghostnet.py | 315 +++++++++++++++ ghostnet/ghostnetv2.py | 324 ++++++++++++++++ ghostnet/ghostnetv3.py | 854 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 1493 insertions(+) create mode 100644 ghostnet/ghostnet.py create mode 100644 ghostnet/ghostnetv2.py create mode 100644 ghostnet/ghostnetv3.py diff --git a/ghostnet/ghostnet.py b/ghostnet/ghostnet.py new file mode 100644 index 0000000..6a6e929 --- /dev/null +++ b/ghostnet/ghostnet.py @@ -0,0 +1,315 @@ +import torch +import torch.nn as nn +import torch.onnx +import onnxsim +import onnx +import struct +import os + + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +__all__ = ['ghost_net'] + + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. + + +class SqueezeExcite(nn.Module): + def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, + act_layer=nn.ReLU, gate_fn=hard_sigmoid, divisor=4, **_): + super(SqueezeExcite, self).__init__() + self.gate_fn = gate_fn + reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + + def forward(self, x): + x_se = self.avg_pool(x) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +class ConvBnAct(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, act_layer=nn.ReLU): + super(ConvBnAct, self).__init__() + self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False) + self.bn1 = nn.BatchNorm2d(out_chs) + self.act1 = act_layer(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn1(x) + x = self.act1(x) + return x + + +class GhostModule(nn.Module): + def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True): + super(GhostModule, self).__init__() + self.oup = oup + init_channels = math.ceil(oup / ratio) + new_channels = init_channels*(ratio-1) + + self.primary_conv = nn.Sequential( + nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), + nn.BatchNorm2d(init_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + + self.cheap_operation = nn.Sequential( + nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False), + nn.BatchNorm2d(new_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + + def forward(self, x): + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + out = torch.cat([x1,x2], dim=1) + return out[:,:self.oup,:,:] + + 
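+# Illustrative shape check for the GhostModule above (example values, not part
+# of the network definition): with oup=64 and ratio=2 the primary 1x1 conv
+# produces ceil(64/2)=32 intrinsic feature maps, the cheap depth-wise conv
+# generates 32 "ghost" maps from them, and the concatenation is sliced back to
+# exactly oup channels:
+#
+#   m = GhostModule(inp=16, oup=64, kernel_size=1, ratio=2, dw_size=3)
+#   y = m(torch.randn(1, 16, 56, 56))
+#   assert y.shape == (1, 64, 56, 56)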
+class GhostBottleneck(nn.Module): + """ Ghost bottleneck w/ optional SE""" + + def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3, + stride=1, act_layer=nn.ReLU, se_ratio=0.): + super(GhostBottleneck, self).__init__() + has_se = se_ratio is not None and se_ratio > 0. + self.stride = stride + + # Point-wise expansion + self.ghost1 = GhostModule(in_chs, mid_chs, relu=True) + + # Depth-wise convolution + if self.stride > 1: + self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, + groups=mid_chs, bias=False) + self.bn_dw = nn.BatchNorm2d(mid_chs) + + # Squeeze-and-excitation + if has_se: + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) + else: + self.se = None + + # Point-wise linear projection + self.ghost2 = GhostModule(mid_chs, out_chs, relu=False) + + # shortcut + if (in_chs == out_chs and self.stride == 1): + self.shortcut = nn.Sequential() + else: + self.shortcut = nn.Sequential( + nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False), + nn.BatchNorm2d(in_chs), + nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_chs), + ) + + + def forward(self, x): + residual = x + + # 1st ghost bottleneck + x = self.ghost1(x) + + # Depth-wise convolution + if self.stride > 1: + x = self.conv_dw(x) + x = self.bn_dw(x) + + # Squeeze-and-excitation + if self.se is not None: + x = self.se(x) + + # 2nd ghost bottleneck + x = self.ghost2(x) + + x += self.shortcut(residual) + return x + + +class GhostNet(nn.Module): + def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2): + super(GhostNet, self).__init__() + # setting of inverted residual blocks + self.cfgs = cfgs + self.dropout = dropout + + # building first layer + output_channel = _make_divisible(16 * width, 4) + self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False) + self.bn1 = nn.BatchNorm2d(output_channel) + self.act1 = nn.ReLU(inplace=True) + input_channel = output_channel + + # building inverted residual blocks + stages = [] + block = GhostBottleneck + for cfg in self.cfgs: + layers = [] + for k, exp_size, c, se_ratio, s in cfg: + output_channel = _make_divisible(c * width, 4) + hidden_channel = _make_divisible(exp_size * width, 4) + layers.append(block(input_channel, hidden_channel, output_channel, k, s, + se_ratio=se_ratio)) + input_channel = output_channel + stages.append(nn.Sequential(*layers)) + + output_channel = _make_divisible(exp_size * width, 4) + stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1))) + input_channel = output_channel + + self.blocks = nn.Sequential(*stages) + + # building last several layers + output_channel = 1280 + self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) + self.conv_head = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=True) + self.act2 = nn.ReLU(inplace=True) + self.classifier = nn.Linear(output_channel, num_classes) + + def forward(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + x = x.view(x.size(0), -1) + if self.dropout > 0.: + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.classifier(x) + return x + + +def ghostnet(**kwargs): + """ + Constructs a GhostNet model + """ + cfgs = [ + # k, t, c, SE, s + # stage1 + [[3, 16, 16, 0, 1]], + # stage2 + [[3, 48, 24, 0, 2]], + [[3, 72, 24, 0, 1]], + # stage3 + [[5, 72, 40, 0.25, 2]], + [[5, 120, 40, 0.25, 1]], + # stage4 
+ [[3, 240, 80, 0, 2]], + [[3, 200, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 480, 112, 0.25, 1], + [3, 672, 112, 0.25, 1] + ], + # stage5 + [[5, 672, 160, 0.25, 2]], + [[5, 960, 160, 0, 1], + [5, 960, 160, 0.25, 1], + [5, 960, 160, 0, 1], + [5, 960, 160, 0.25, 1] + ] + ] + return GhostNet(cfgs, **kwargs) + + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + +def export_weight(model): + current_path = os.path.dirname(__file__) + f = open(current_path + "ghostnetv1.weights", 'w') + f.write("{}\n".format(len(model.state_dict().keys()))) + + for k, v in model.state_dict().items(): + print('exporting ... {}: {}'.format(k, v.shape)) + + vr = v.reshape(-1).cpu().numpy() + f.write("{} {}".format(k, len(vr))) + for vv in vr: + f.write(" ") + f.write(struct.pack(">f", float(vv)).hex()) + f.write("\n") + + f.close() + +def export_onnx(input, model): + current_path = os.path.dirname(__file__) + file = current_path + "ghostnetv1.onnx" + torch.onnx.export( + model=model, + args=(input,), + f=file, + input_names=["input0"], + output_names=["output0"], + opset_version=13 + ) + print("Finished ONNX export") + + model_onnx = onnx.load(file) + onnx.checker.check_model(model_onnx) + + print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...") + model_onnx, check = onnxsim.simplify(model_onnx) + assert check, "Simplification check failed" + onnx.save(model_onnx, file) + +def eval_model(input, model): + output = model(input) + print("------from inference------") + print(input) + print(output) + +if __name__ == "__main__": + setup_seed(1) + + model = ghostnet(num_classes=1000, width=1.0, dropout=0.2) + model.eval() + + input = torch.randn(32, 3, 320, 256) + + export_weight(model) + + export_onnx(input, model) + + eval_model(input, model) diff --git a/ghostnet/ghostnetv2.py b/ghostnet/ghostnetv2.py new file mode 100644 index 0000000..2b23a05 --- /dev/null +++ b/ghostnet/ghostnetv2.py @@ -0,0 +1,324 @@ +import torch +import torch.nn as nn +import torch.onnx +import onnxsim +import onnx +import struct +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +from timm.models.registry import register_model + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. 
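+
+# hard_sigmoid is the piecewise-linear gate used by SqueezeExcite below:
+# hard_sigmoid(x) = relu6(x + 3) / 6 = clamp((x + 3) / 6, 0, 1), applied
+# element-wise on tensors, e.g. it maps -3 -> 0, 0 -> 0.5 and 3 -> 1,
+# approximating a sigmoid without computing exp().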
+ +class SqueezeExcite(nn.Module): + def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, + act_layer=nn.ReLU, gate_fn=hard_sigmoid, divisor=4, **_): + super(SqueezeExcite, self).__init__() + self.gate_fn = gate_fn + reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + + def forward(self, x): + x_se = self.avg_pool(x) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + +class ConvBnAct(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, act_layer=nn.ReLU): + super(ConvBnAct, self).__init__() + self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False) + self.bn1 = nn.BatchNorm2d(out_chs) + self.act1 = act_layer(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn1(x) + x = self.act1(x) + return x + +class GhostModuleV2(nn.Module): + def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True,mode=None,args=None): + super(GhostModuleV2, self).__init__() + self.mode=mode + self.gate_fn=nn.Sigmoid() + + if self.mode in ['original']: + self.oup = oup + init_channels = math.ceil(oup / ratio) + new_channels = init_channels*(ratio-1) + self.primary_conv = nn.Sequential( + nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), + nn.BatchNorm2d(init_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + self.cheap_operation = nn.Sequential( + nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False), + nn.BatchNorm2d(new_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + elif self.mode in ['attn']: + self.oup = oup + init_channels = math.ceil(oup / ratio) + new_channels = init_channels*(ratio-1) + self.primary_conv = nn.Sequential( + nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False), + nn.BatchNorm2d(init_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + self.cheap_operation = nn.Sequential( + nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False), + nn.BatchNorm2d(new_channels), + nn.ReLU(inplace=True) if relu else nn.Sequential(), + ) + self.short_conv = nn.Sequential( + nn.Conv2d(inp, oup, kernel_size, stride, kernel_size//2, bias=False), + nn.BatchNorm2d(oup), + nn.Conv2d(oup, oup, kernel_size=(1,5), stride=1, padding=(0,2), groups=oup,bias=False), + nn.BatchNorm2d(oup), + nn.Conv2d(oup, oup, kernel_size=(5,1), stride=1, padding=(2,0), groups=oup,bias=False), + nn.BatchNorm2d(oup), + ) + + def forward(self, x): + if self.mode in ['original']: + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + out = torch.cat([x1,x2], dim=1) + return out[:,:self.oup,:,:] + elif self.mode in ['attn']: + res=self.short_conv(F.avg_pool2d(x,kernel_size=2,stride=2)) + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + out = torch.cat([x1,x2], dim=1) + return out[:,:self.oup,:,:]*F.interpolate(self.gate_fn(res),size=(out.shape[-2],out.shape[-1]),mode='nearest') + + +class GhostBottleneckV2(nn.Module): + + def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3, + stride=1, act_layer=nn.ReLU, se_ratio=0.,layer_id=None,args=None): + super(GhostBottleneckV2, self).__init__() + has_se = 
se_ratio is not None and se_ratio > 0. + self.stride = stride + + # Point-wise expansion + if layer_id<=1: + self.ghost1 = GhostModuleV2(in_chs, mid_chs, relu=True,mode='original',args=args) + else: + self.ghost1 = GhostModuleV2(in_chs, mid_chs, relu=True,mode='attn',args=args) + + # Depth-wise convolution + if self.stride > 1: + self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2,groups=mid_chs, bias=False) + self.bn_dw = nn.BatchNorm2d(mid_chs) + + # Squeeze-and-excitation + if has_se: + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) + else: + self.se = None + + self.ghost2 = GhostModuleV2(mid_chs, out_chs, relu=False,mode='original',args=args) + + # shortcut + if (in_chs == out_chs and self.stride == 1): + self.shortcut = nn.Sequential() + else: + self.shortcut = nn.Sequential( + nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False), + nn.BatchNorm2d(in_chs), + nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_chs), + ) + def forward(self, x): + residual = x + x = self.ghost1(x) + if self.stride > 1: + x = self.conv_dw(x) + x = self.bn_dw(x) + if self.se is not None: + x = self.se(x) + x = self.ghost2(x) + x += self.shortcut(residual) + return x + + +class GhostNetV2(nn.Module): + def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2,block=GhostBottleneckV2,args=None): + super(GhostNetV2, self).__init__() + self.cfgs = cfgs + self.dropout = dropout + + # building first layer + output_channel = _make_divisible(16 * width, 4) + self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False) + self.bn1 = nn.BatchNorm2d(output_channel) + self.act1 = nn.ReLU(inplace=True) + input_channel = output_channel + + # building inverted residual blocks + stages = [] + #block = block + layer_id=0 + for cfg in self.cfgs: + layers = [] + for k, exp_size, c, se_ratio, s in cfg: + output_channel = _make_divisible(c * width, 4) + hidden_channel = _make_divisible(exp_size * width, 4) + if block==GhostBottleneckV2: + layers.append(block(input_channel, hidden_channel, output_channel, k, s, + se_ratio=se_ratio,layer_id=layer_id,args=args)) + input_channel = output_channel + layer_id+=1 + stages.append(nn.Sequential(*layers)) + + output_channel = _make_divisible(exp_size * width, 4) + stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1))) + input_channel = output_channel + + self.blocks = nn.Sequential(*stages) + + # building last several layers + output_channel = 1280 + self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) + self.conv_head = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=True) + self.act2 = nn.ReLU(inplace=True) + self.classifier = nn.Linear(output_channel, num_classes) + + def forward(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + x = x.view(x.size(0), -1) + if self.dropout > 0.: + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.classifier(x) + return x + +@register_model +def ghostnetv2(**kwargs): + cfgs = [ + # k, t, c, SE, s + [[3, 16, 16, 0, 1]], + [[3, 48, 24, 0, 2]], + [[3, 72, 24, 0, 1]], + [[5, 72, 40, 0.25, 2]], + [[5, 120, 40, 0.25, 1]], + [[3, 240, 80, 0, 2]], + [[3, 200, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 480, 112, 0.25, 1], + [3, 672, 112, 0.25, 1] + ], + [[5, 672, 160, 0.25, 2]], + [[5, 960, 160, 0, 1], + [5, 960, 160, 0.25, 1], + 
[5, 960, 160, 0, 1], + [5, 960, 160, 0.25, 1] + ] + ] + return GhostNetV2(cfgs, num_classes=kwargs['num_classes'], + width=kwargs['width'], + dropout=kwargs['dropout'], + args=kwargs['args']) + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + +def export_weight(model): + current_path = os.path.dirname(__file__) + f = open(current_path + "ghostnetv2.weights", 'w') + f.write("{}\n".format(len(model.state_dict().keys()))) + + for k, v in model.state_dict().items(): + print('exporting ... {}: {}'.format(k, v.shape)) + + vr = v.reshape(-1).cpu().numpy() + f.write("{} {}".format(k, len(vr))) + for vv in vr: + f.write(" ") + f.write(struct.pack(">f", float(vv)).hex()) + f.write("\n") + + f.close() + +def export_onnx(input, model): + current_path = os.path.dirname(__file__) + file = current_path + "ghostnetv2.onnx" + torch.onnx.export( + model=model, + args=(input,), + f=file, + input_names=["input0"], + output_names=["output0"], + opset_version=13 + ) + print("Finished ONNX export") + + model_onnx = onnx.load(file) + onnx.checker.check_model(model_onnx) + + print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...") + model_onnx, check = onnxsim.simplify(model_onnx) + assert check, "Simplification check failed" + onnx.save(model_onnx, file) + +def eval_model(input, model): + output = model(input) + print("------from inference------") + print(input) + print(output) + +if __name__ == "__main__": + setup_seed(1) + + model = ghostnetv2(width=1.0, num_classes=1000, dropout=0.2, args=None) + model.eval() + + input = torch.randn(32, 3, 320, 256) + + export_weight(model) + + export_onnx(input, model) + + eval_model(input, model) diff --git a/ghostnet/ghostnetv3.py b/ghostnet/ghostnetv3.py new file mode 100644 index 0000000..3e6f5ed --- /dev/null +++ b/ghostnet/ghostnetv3.py @@ -0,0 +1,854 @@ +import torch +import torch.nn as nn +import torch.onnx +import onnxsim +import onnx +import struct +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +from typing import Optional, List, Tuple +from timm.models.registry import register_model + +#__all__ = ['ghost_net'] + + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. 
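+
+# Worked example for the _make_divisible helper above (illustrative values):
+# every channel count in this file is rounded to a multiple of 4, e.g.
+# _make_divisible(16 * 0.9, 4) == 16, and the 10% guard rounds up instead of
+# down when too much width would be lost, e.g. _make_divisible(20, 16) == 32
+# because rounding down to 16 would drop more than 10% of the requested 20.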
+
+
+class SqueezeExcite(nn.Module):
+    def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
+                 act_layer=nn.ReLU, gate_fn=hard_sigmoid, divisor=4, **_):
+        super(SqueezeExcite, self).__init__()
+        self.gate_fn = gate_fn
+        reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
+        self.act1 = act_layer(inplace=True)
+        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
+
+    def forward(self, x):
+        x_se = self.avg_pool(x)
+        x_se = self.conv_reduce(x_se)
+        x_se = self.act1(x_se)
+        x_se = self.conv_expand(x_se)
+        x = x * self.gate_fn(x_se)
+        return x
+
+
+class ConvBnAct(nn.Module):
+    def __init__(self, in_chs, out_chs, kernel_size,
+                 stride=1, act_layer=nn.ReLU):
+        super(ConvBnAct, self).__init__()
+        self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size//2, bias=False)
+        self.bn1 = nn.BatchNorm2d(out_chs)
+        self.act1 = act_layer(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn1(x)
+        x = self.act1(x)
+        return x
+
+
+def gcd(a,b):
+    # Greatest common divisor via the Euclidean algorithm.
+    if a < b:
+        a, b = b, a
+    while b != 0:
+        a, b = b, a % b
+    return a
+
+
+class GhostModule(nn.Module):
+    def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True, mode=None, args=None):
+        super(GhostModule, self).__init__()
+        self.mode = mode
+        self.gate_loc = 'before'
+        self.inter_mode = 'nearest'
+        self.scale = 1.0
+        self.infer_mode = False
+        self.num_conv_branches = 3
+        self.dconv_scale = True
+        self.gate_fn = nn.Sigmoid()
+
+        # 'ori' builds the plain ghost module (primary conv + cheap depth-wise
+        # conv, each trained as re-parameterizable multi-branch blocks);
+        # 'ori_shortcut_mul_conv15' additionally builds the DFC attention
+        # shortcut (1x5 and 5x1 depth-wise convs) introduced in GhostNetV2.
+        if self.mode in ['ori']:
+            self.oup = oup
+            init_channels = math.ceil(oup / ratio)
+            new_channels = init_channels*(ratio-1)
+            if self.infer_mode:
+                self.primary_conv = nn.Sequential(
+                    nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False),
+                    nn.BatchNorm2d(init_channels),
+                    nn.ReLU(inplace=True) if relu else nn.Sequential(),
+                )
+                self.cheap_operation = nn.Sequential(
+                    nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False),
+                    nn.BatchNorm2d(new_channels),
+                    nn.ReLU(inplace=True) if relu else nn.Sequential(),
+                )
+            else:
+                self.primary_rpr_skip = nn.BatchNorm2d(inp) \
+                    if inp == init_channels and stride == 1 else None
+                primary_rpr_conv = list()
+                for _ in range(self.num_conv_branches):
+                    primary_rpr_conv.append(self._conv_bn(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False))
+                self.primary_rpr_conv = nn.ModuleList(primary_rpr_conv)
+                # Re-parameterizable scale branch
+                self.primary_rpr_scale = None
+                if kernel_size > 1:
+                    self.primary_rpr_scale = self._conv_bn(inp, init_channels, 1, 1, 0, bias=False)
+                self.primary_activation = nn.ReLU(inplace=True) if relu else None
+
+
+                self.cheap_rpr_skip = nn.BatchNorm2d(init_channels) \
+                    if init_channels == new_channels else None
+                cheap_rpr_conv = list()
+                for _ in range(self.num_conv_branches):
+                    cheap_rpr_conv.append(self._conv_bn(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False))
+                self.cheap_rpr_conv = nn.ModuleList(cheap_rpr_conv)
+                # Re-parameterizable scale branch
+                self.cheap_rpr_scale = None
+                if dw_size > 1:
+                    self.cheap_rpr_scale = self._conv_bn(init_channels, new_channels, 1, 1, 0, groups=init_channels, bias=False)
+                self.cheap_activation = nn.ReLU(inplace=True) if relu else None
+                self.in_channels = init_channels
+                self.groups = init_channels
+                self.kernel_size = dw_size
+
+        elif self.mode in ['ori_shortcut_mul_conv15']:
+            self.oup = oup
+            init_channels = math.ceil(oup / ratio)
+            new_channels = init_channels*(ratio-1)
+            self.short_conv = nn.Sequential(
+                nn.Conv2d(inp, oup, kernel_size, stride, kernel_size//2, bias=False),
+                nn.BatchNorm2d(oup),
+                nn.Conv2d(oup, oup, kernel_size=(1,5), stride=1, padding=(0,2), groups=oup,bias=False),
+                nn.BatchNorm2d(oup),
+                nn.Conv2d(oup, oup, kernel_size=(5,1), stride=1, padding=(2,0), groups=oup,bias=False),
+                nn.BatchNorm2d(oup),
+            )
+            if self.infer_mode:
+                self.primary_conv = nn.Sequential(
+                    nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False),
+                    nn.BatchNorm2d(init_channels),
+                    nn.ReLU(inplace=True) if relu else nn.Sequential(),
+                )
+                self.cheap_operation = nn.Sequential(
+                    nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False),
+                    nn.BatchNorm2d(new_channels),
+                    nn.ReLU(inplace=True) if relu else nn.Sequential(),
+                )
+            else:
+                self.primary_rpr_skip = nn.BatchNorm2d(inp) \
+                    if inp == init_channels and stride == 1 else None
+                primary_rpr_conv = list()
+                for _ in range(self.num_conv_branches):
+                    primary_rpr_conv.append(self._conv_bn(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False))
+                self.primary_rpr_conv = nn.ModuleList(primary_rpr_conv)
+                # Re-parameterizable scale branch
+                self.primary_rpr_scale = None
+                if kernel_size > 1:
+                    self.primary_rpr_scale = self._conv_bn(inp, init_channels, 1, 1, 0, bias=False)
+                self.primary_activation = nn.ReLU(inplace=True) if relu else
None + + + self.cheap_rpr_skip = nn.BatchNorm2d(init_channels) \ + if init_channels == new_channels else None + cheap_rpr_conv = list() + for _ in range(self.num_conv_branches): + cheap_rpr_conv.append(self._conv_bn(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False)) + self.cheap_rpr_conv = nn.ModuleList(cheap_rpr_conv) + # Re-parameterizable scale branch + self.cheap_rpr_scale = None + if dw_size > 1: + self.cheap_rpr_scale = self._conv_bn(init_channels, new_channels, 1, 1, 0, groups=init_channels, bias=False) + self.cheap_activation = nn.ReLU(inplace=True) if relu else None + self.in_channels = init_channels + self.groups = init_channels + self.kernel_size = dw_size + + + def forward(self, x): + if self.mode in ['ori']: + if self.infer_mode: + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + else: + identity_out = 0 + if self.primary_rpr_skip is not None: + identity_out = self.primary_rpr_skip(x) + scale_out = 0 + if self.primary_rpr_scale is not None and self.dconv_scale: + scale_out = self.primary_rpr_scale(x) + x1 = scale_out + identity_out + for ix in range(self.num_conv_branches): + x1 += self.primary_rpr_conv[ix](x) + if self.primary_activation is not None: + x1 = self.primary_activation(x1) + + cheap_identity_out = 0 + if self.cheap_rpr_skip is not None: + cheap_identity_out = self.cheap_rpr_skip(x1) + cheap_scale_out = 0 + if self.cheap_rpr_scale is not None and self.dconv_scale: + cheap_scale_out = self.cheap_rpr_scale(x1) + x2 = cheap_scale_out + cheap_identity_out + for ix in range(self.num_conv_branches): + x2 += self.cheap_rpr_conv[ix](x1) + if self.cheap_activation is not None: + x2 = self.cheap_activation(x2) + + out = torch.cat([x1,x2], dim=1) + return out + + elif self.mode in ['ori_shortcut_mul_conv15']: + res=self.short_conv(F.avg_pool2d(x,kernel_size=2,stride=2)) + + if self.infer_mode: + x1 = self.primary_conv(x) + x2 = self.cheap_operation(x1) + else: + identity_out = 0 + if self.primary_rpr_skip is not None: + identity_out = self.primary_rpr_skip(x) + scale_out = 0 + if self.primary_rpr_scale is not None and self.dconv_scale: + scale_out = self.primary_rpr_scale(x) + x1 = scale_out + identity_out + for ix in range(self.num_conv_branches): + x1 += self.primary_rpr_conv[ix](x) + if self.primary_activation is not None: + x1 = self.primary_activation(x1) + + cheap_identity_out = 0 + if self.cheap_rpr_skip is not None: + cheap_identity_out = self.cheap_rpr_skip(x1) + cheap_scale_out = 0 + if self.cheap_rpr_scale is not None and self.dconv_scale: + cheap_scale_out = self.cheap_rpr_scale(x1) + x2 = cheap_scale_out + cheap_identity_out + for ix in range(self.num_conv_branches): + x2 += self.cheap_rpr_conv[ix](x1) + if self.cheap_activation is not None: + x2 = self.cheap_activation(x2) + + out = torch.cat([x1,x2], dim=1) + + if self.gate_loc=='before': + return out[:,:self.oup,:,:]*F.interpolate(self.gate_fn(res/self.scale),size=out.shape[-2:],mode=self.inter_mode) # 'nearest' +# return out*F.interpolate(self.gate_fn(res/self.scale),size=out.shape[-1].item(),mode=self.inter_mode) # 'nearest' + else: + return out[:,:self.oup,:,:]*self.gate_fn(F.interpolate(res,size=out.shape[-2:],mode=self.inter_mode)) +# return out*self.gate_fn(F.interpolate(res,size=out.shape[-1],mode=self.inter_mode)) + + + def reparameterize(self): + """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` - + https://arxiv.org/pdf/2101.03697.pdf. 
We re-parameterize multi-branched + architecture used at training time to obtain a plain CNN-like structure + for inference. + """ + if self.infer_mode: + return + primary_kernel, primary_bias = self._get_kernel_bias_primary() + self.primary_conv = nn.Conv2d(in_channels=self.primary_rpr_conv[0].conv.in_channels, + out_channels=self.primary_rpr_conv[0].conv.out_channels, + kernel_size=self.primary_rpr_conv[0].conv.kernel_size, + stride=self.primary_rpr_conv[0].conv.stride, + padding=self.primary_rpr_conv[0].conv.padding, + dilation=self.primary_rpr_conv[0].conv.dilation, + groups=self.primary_rpr_conv[0].conv.groups, + bias=True) + self.primary_conv.weight.data = primary_kernel + self.primary_conv.bias.data = primary_bias + self.primary_conv = nn.Sequential( + self.primary_conv, + self.primary_activation if self.primary_activation is not None else nn.Sequential() + ) + + cheap_kernel, cheap_bias = self._get_kernel_bias_cheap() + self.cheap_operation = nn.Conv2d(in_channels=self.cheap_rpr_conv[0].conv.in_channels, + out_channels=self.cheap_rpr_conv[0].conv.out_channels, + kernel_size=self.cheap_rpr_conv[0].conv.kernel_size, + stride=self.cheap_rpr_conv[0].conv.stride, + padding=self.cheap_rpr_conv[0].conv.padding, + dilation=self.cheap_rpr_conv[0].conv.dilation, + groups=self.cheap_rpr_conv[0].conv.groups, + bias=True) + self.cheap_operation.weight.data = cheap_kernel + self.cheap_operation.bias.data = cheap_bias + + self.cheap_operation = nn.Sequential( + self.cheap_operation, + self.cheap_activation if self.cheap_activation is not None else nn.Sequential() + ) + + # Delete un-used branches + for para in self.parameters(): + para.detach_() + if hasattr(self, 'primary_rpr_conv'): + self.__delattr__('primary_rpr_conv') + if hasattr(self, 'primary_rpr_scale'): + self.__delattr__('primary_rpr_scale') + if hasattr(self, 'primary_rpr_skip'): + self.__delattr__('primary_rpr_skip') + + if hasattr(self, 'cheap_rpr_conv'): + self.__delattr__('cheap_rpr_conv') + if hasattr(self, 'cheap_rpr_scale'): + self.__delattr__('cheap_rpr_scale') + if hasattr(self, 'cheap_rpr_skip'): + self.__delattr__('cheap_rpr_skip') + + self.infer_mode = True + + def _get_kernel_bias_primary(self) -> Tuple[torch.Tensor, torch.Tensor]: + """ Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 + + :return: Tuple of (kernel, bias) after fusing branches. + """ + # get weights and bias of scale branch + kernel_scale = 0 + bias_scale = 0 + if self.primary_rpr_scale is not None: + kernel_scale, bias_scale = self._fuse_bn_tensor(self.primary_rpr_scale) + # Pad scale branch kernel to match conv branch kernel size. + pad = self.kernel_size // 2 + kernel_scale = torch.nn.functional.pad(kernel_scale, + [pad, pad, pad, pad]) + + # get weights and bias of skip branch + kernel_identity = 0 + bias_identity = 0 + if self.primary_rpr_skip is not None: + kernel_identity, bias_identity = self._fuse_bn_tensor(self.primary_rpr_skip) + + # get weights and bias of conv branches + kernel_conv = 0 + bias_conv = 0 + for ix in range(self.num_conv_branches): + _kernel, _bias = self._fuse_bn_tensor(self.primary_rpr_conv[ix]) + kernel_conv += _kernel + bias_conv += _bias + + kernel_final = kernel_conv + kernel_scale + kernel_identity + bias_final = bias_conv + bias_scale + bias_identity + return kernel_final, bias_final + + def _get_kernel_bias_cheap(self) -> Tuple[torch.Tensor, torch.Tensor]: + """ Method to obtain re-parameterized kernel and bias. 
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 + + :return: Tuple of (kernel, bias) after fusing branches. + """ + # get weights and bias of scale branch + kernel_scale = 0 + bias_scale = 0 + if self.cheap_rpr_scale is not None: + kernel_scale, bias_scale = self._fuse_bn_tensor(self.cheap_rpr_scale) + # Pad scale branch kernel to match conv branch kernel size. + pad = self.kernel_size // 2 + kernel_scale = torch.nn.functional.pad(kernel_scale, + [pad, pad, pad, pad]) + + # get weights and bias of skip branch + kernel_identity = 0 + bias_identity = 0 + if self.cheap_rpr_skip is not None: + kernel_identity, bias_identity = self._fuse_bn_tensor(self.cheap_rpr_skip) + + # get weights and bias of conv branches + kernel_conv = 0 + bias_conv = 0 + for ix in range(self.num_conv_branches): + _kernel, _bias = self._fuse_bn_tensor(self.cheap_rpr_conv[ix]) + kernel_conv += _kernel + bias_conv += _bias + + kernel_final = kernel_conv + kernel_scale + kernel_identity + bias_final = bias_conv + bias_scale + bias_identity + return kernel_final, bias_final + + def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]: + """ Method to fuse batchnorm layer with preceeding conv layer. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 + + :param branch: + :return: Tuple of (kernel, bias) after fusing batchnorm. + """ + if isinstance(branch, nn.Sequential): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, 'id_tensor'): + input_dim = self.in_channels // self.groups + kernel_value = torch.zeros((self.in_channels, + input_dim, + self.kernel_size, + self.kernel_size), + dtype=branch.weight.dtype, + device=branch.weight.device) + for i in range(self.in_channels): + kernel_value[i, i % input_dim, + self.kernel_size // 2, + self.kernel_size // 2] = 1 + self.id_tensor = kernel_value + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def _conv_bn(self, in_channels, out_channels, kernel_size, stride, padding, groups=1, bias=False): + """ Helper method to construct conv-batchnorm layers. + + :param kernel_size: Size of the convolution kernel. + :param padding: Zero-padding size. + :return: Conv-BN module. + """ + mod_list = nn.Sequential() + mod_list.add_module('conv', nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=bias)) + mod_list.add_module('bn', nn.BatchNorm2d(out_channels)) + return mod_list + + +class GhostBottleneck(nn.Module): + """ Ghost bottleneck w/ optional SE""" + + def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3, + stride=1, act_layer=nn.ReLU, se_ratio=0.,layer_id=None,args=None): + super(GhostBottleneck, self).__init__() + has_se = se_ratio is not None and se_ratio > 0. 
+ self.stride = stride + + self.num_conv_branches = 3 + self.infer_mode = False + self.dconv_scale = True + + # Point-wise expansion + if layer_id<=1: + self.ghost1 = GhostModule(in_chs, mid_chs, relu=True,mode='ori',args=args) + else: + self.ghost1 = GhostModule(in_chs, mid_chs, relu=True,mode='ori_shortcut_mul_conv15',args=args) ####这里是扩张 mid_chs远大于in_chs + + # Depth-wise convolution + if self.stride > 1: + if self.infer_mode: + self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, + groups=mid_chs, bias=False) + self.bn_dw = nn.BatchNorm2d(mid_chs) + else: + self.dw_rpr_skip = nn.BatchNorm2d(mid_chs) if stride == 1 else None + dw_rpr_conv = list() + for _ in range(self.num_conv_branches): + dw_rpr_conv.append(self._conv_bn(mid_chs, mid_chs, dw_kernel_size, stride, (dw_kernel_size-1)//2, groups=mid_chs, bias=False)) + self.dw_rpr_conv = nn.ModuleList(dw_rpr_conv) + # Re-parameterizable scale branch + self.dw_rpr_scale = None + if dw_kernel_size > 1: + self.dw_rpr_scale = self._conv_bn(mid_chs, mid_chs, 1, 2, 0, groups=mid_chs, bias=False) + self.kernel_size = dw_kernel_size + self.in_channels = mid_chs + + # Squeeze-and-excitation + if has_se: + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio) + else: + self.se = None + + # Point-wise linear projection + if layer_id<=1: + self.ghost2 = GhostModule(mid_chs, out_chs, relu=False,mode='ori',args=args) + else: + self.ghost2 = GhostModule(mid_chs, out_chs, relu=False,mode='ori',args=args) + + # shortcut + if (in_chs == out_chs and self.stride == 1): + self.shortcut = nn.Sequential() + else: + self.shortcut = nn.Sequential( + nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride=stride, + padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False), + nn.BatchNorm2d(in_chs), + nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(out_chs), + ) + + def forward(self, x): + residual = x + + # 1st ghost bottleneck + x = self.ghost1(x) + + # Depth-wise convolution + if self.stride > 1: + if self.infer_mode: + x = self.conv_dw(x) + x = self.bn_dw(x) + else: + dw_identity_out = 0 + if self.dw_rpr_skip is not None: + dw_identity_out = self.dw_rpr_skip(x) + dw_scale_out = 0 + if self.dw_rpr_scale is not None and self.dconv_scale: + dw_scale_out = self.dw_rpr_scale(x) + x1 = dw_scale_out + dw_identity_out + for ix in range(self.num_conv_branches): + x1 += self.dw_rpr_conv[ix](x) + x = x1 + + # Squeeze-and-excitation + if self.se is not None: + x = self.se(x) + + # 2nd ghost bottleneck + x = self.ghost2(x) + + x += self.shortcut(residual) + return x + + def _conv_bn(self, in_channels, out_channels, kernel_size, stride, padding, groups=1, bias=False): + """ Helper method to construct conv-batchnorm layers. + + :param kernel_size: Size of the convolution kernel. + :param padding: Zero-padding size. + :return: Conv-BN module. + """ + mod_list = nn.Sequential() + mod_list.add_module('conv', nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=bias)) + mod_list.add_module('bn', nn.BatchNorm2d(out_channels)) + return mod_list + + + def reparameterize(self): + """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` - + https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched + architecture used at training time to obtain a plain CNN-like structure + for inference. 
+ """ + if self.infer_mode or self.stride == 1: + return + dw_kernel, dw_bias = self._get_kernel_bias_dw() + self.conv_dw = nn.Conv2d(in_channels=self.dw_rpr_conv[0].conv.in_channels, + out_channels=self.dw_rpr_conv[0].conv.out_channels, + kernel_size=self.dw_rpr_conv[0].conv.kernel_size, + stride=self.dw_rpr_conv[0].conv.stride, + padding=self.dw_rpr_conv[0].conv.padding, + dilation=self.dw_rpr_conv[0].conv.dilation, + groups=self.dw_rpr_conv[0].conv.groups, + bias=True) + self.conv_dw.weight.data = dw_kernel + self.conv_dw.bias.data = dw_bias + self.bn_dw = nn.Identity() + + # Delete un-used branches + for para in self.parameters(): + para.detach_() + if hasattr(self, 'dw_rpr_conv'): + self.__delattr__('dw_rpr_conv') + if hasattr(self, 'dw_rpr_scale'): + self.__delattr__('dw_rpr_scale') + if hasattr(self, 'dw_rpr_skip'): + self.__delattr__('dw_rpr_skip') + + self.infer_mode = True + + def _get_kernel_bias_dw(self) -> Tuple[torch.Tensor, torch.Tensor]: + """ Method to obtain re-parameterized kernel and bias. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83 + + :return: Tuple of (kernel, bias) after fusing branches. + """ + # get weights and bias of scale branch + kernel_scale = 0 + bias_scale = 0 + if self.dw_rpr_scale is not None: + kernel_scale, bias_scale = self._fuse_bn_tensor(self.dw_rpr_scale) + # Pad scale branch kernel to match conv branch kernel size. + pad = self.kernel_size // 2 + kernel_scale = torch.nn.functional.pad(kernel_scale, + [pad, pad, pad, pad]) + + # get weights and bias of skip branch + kernel_identity = 0 + bias_identity = 0 + if self.dw_rpr_skip is not None: + kernel_identity, bias_identity = self._fuse_bn_tensor(self.dw_rpr_skip) + + # get weights and bias of conv branches + kernel_conv = 0 + bias_conv = 0 + for ix in range(self.num_conv_branches): + _kernel, _bias = self._fuse_bn_tensor(self.dw_rpr_conv[ix]) + kernel_conv += _kernel + bias_conv += _bias + + kernel_final = kernel_conv + kernel_scale + kernel_identity + bias_final = bias_conv + bias_scale + bias_identity + return kernel_final, bias_final + + + def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]: + """ Method to fuse batchnorm layer with preceeding conv layer. + Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95 + + :param branch: + :return: Tuple of (kernel, bias) after fusing batchnorm. 
+ """ + if isinstance(branch, nn.Sequential): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + else: + assert isinstance(branch, nn.BatchNorm2d) + if not hasattr(self, 'id_tensor'): + input_dim = self.in_channels // self.groups + kernel_value = torch.zeros((self.in_channels, + input_dim, + self.kernel_size, + self.kernel_size), + dtype=branch.weight.dtype, + device=branch.weight.device) + for i in range(self.in_channels): + kernel_value[i, i % input_dim, + self.kernel_size // 2, + self.kernel_size // 2] = 1 + self.id_tensor = kernel_value + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + +class GhostNet(nn.Module): + def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, block=GhostBottleneck, args=None): + super(GhostNet, self).__init__() + # setting of inverted residual blocks + self.cfgs = cfgs + self.dropout = dropout + + # building first layer + output_channel = _make_divisible(16 * width, 4) + self.conv_stem = nn.Conv2d(3, output_channel, 3, 2, 1, bias=False) + self.bn1 = nn.BatchNorm2d(output_channel) + self.act1 = nn.ReLU(inplace=True) + input_channel = output_channel + + # building inverted residual blocks + stages = [] + #block = block + layer_id=0 + for cfg in self.cfgs: + layers = [] + for k, exp_size, c, se_ratio, s in cfg: + + output_channel = _make_divisible(c * width, 4) + hidden_channel = _make_divisible(exp_size * width, 4) + if block==GhostBottleneck: + layers.append(block(input_channel, hidden_channel, output_channel, k, s, + se_ratio=se_ratio,layer_id=layer_id,args=args)) + input_channel = output_channel + layer_id+=1 + stages.append(nn.Sequential(*layers)) + + output_channel = _make_divisible(exp_size * width, 4) + stages.append(nn.Sequential(ConvBnAct(input_channel, output_channel, 1))) + input_channel = output_channel + + self.blocks = nn.Sequential(*stages) + + # building last several layers + output_channel = 1280 + self.global_pool = nn.AdaptiveAvgPool2d((1, 1)) + self.conv_head = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=True) + self.act2 = nn.ReLU(inplace=True) + self.classifier = nn.Linear(output_channel, num_classes) + + def forward(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + x = x.view(x.size(0), -1) + # if self.dropout > 0.: + # x = F.dropout(x, p=self.dropout, training=self.training) + x = self.classifier(x) + x = x.squeeze() + return x + + def reparameterize(self): + for _, module in self.named_modules(): + if isinstance(module, GhostModule): + module.reparameterize() + if isinstance(module, GhostBottleneck): + module.reparameterize() + +@register_model +def ghostnetv3(**kwargs): + """ + Constructs a GhostNet model + """ + cfgs = [ + # k, t, c, SE, s + # stage1 + [[3, 16, 16, 0, 1]], + # stage2 + [[3, 48, 24, 0, 2]], + [[3, 72, 24, 0, 1]], + # stage3 + [[5, 72, 40, 0.25, 2]], + [[5, 120, 40, 0.25, 1]], + # stage4 + [[3, 240, 80, 0, 2]], + [[3, 200, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 184, 80, 0, 1], + [3, 480, 112, 0.25, 1], + [3, 672, 112, 0.25, 1] + ], + # stage5 + [[5, 672, 160, 0.25, 2]], + [[5, 960, 160, 0, 1], + [5, 960, 160, 0.25, 
1], + [5, 960, 160, 0, 1], + [5, 960, 160, 0.25, 1] + ] + ] + return GhostNet(cfgs, num_classes=1000, width=kwargs['width'], dropout=0.2) + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + +def export_weight(model): + current_path = os.path.dirname(__file__) + f = open(current_path + "ghostnetv3.weights", 'w') + f.write("{}\n".format(len(model.state_dict().keys()))) + + for k, v in model.state_dict().items(): + print('exporting ... {}: {}'.format(k, v.shape)) + + vr = v.reshape(-1).cpu().numpy() + f.write("{} {}".format(k, len(vr))) + for vv in vr: + f.write(" ") + f.write(struct.pack(">f", float(vv)).hex()) + f.write("\n") + + f.close() + +def export_onnx(input, model): + current_path = os.path.dirname(__file__) + file = current_path + "ghostnetv3.onnx" + torch.onnx.export( + model=model, + args=(input,), + f=file, + input_names=["input0"], + output_names=["output0"], + opset_version=13 + ) + print("Finished ONNX export") + + model_onnx = onnx.load(file) + onnx.checker.check_model(model_onnx) + + print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...") + model_onnx, check = onnxsim.simplify(model_onnx) + assert check, "Simplification check failed" + onnx.save(model_onnx, file) + +def eval_model(input, model): + output = model(input) + print("------from inference------") + print(input) + print(output) + +if __name__ == "__main__": + setup_seed(1) + + model = ghostnetv3(width=1.0) + model.eval() + + input = torch.randn(32, 3, 320, 256) + + export_weight(model) + + export_onnx(input, model) + + eval_model(input, model) From 999b522d6f9acc8d6670fdd583d66787e5309bc8 Mon Sep 17 00:00:00 2001 From: Phoenix Date: Wed, 25 Sep 2024 20:09:02 +0800 Subject: [PATCH 2/2] follow the code rules in the README --- ghostnet/ghosnetv3_inference.py | 76 +++++++++++++++++++++++++++++++ ghostnet/ghostnet.py | 56 +++-------------------- ghostnet/ghostnet_inference.py | 77 ++++++++++++++++++++++++++++++++ ghostnet/ghostnetv2.py | 61 +++---------------------- ghostnet/ghostnetv2_inference.py | 76 +++++++++++++++++++++++++++++++ ghostnet/ghostnetv3.py | 60 ++++++------------------- 6 files changed, 253 insertions(+), 153 deletions(-) create mode 100644 ghostnet/ghosnetv3_inference.py create mode 100644 ghostnet/ghostnet_inference.py create mode 100644 ghostnet/ghostnetv2_inference.py diff --git a/ghostnet/ghosnetv3_inference.py b/ghostnet/ghosnetv3_inference.py new file mode 100644 index 0000000..e23e93d --- /dev/null +++ b/ghostnet/ghosnetv3_inference.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +import torch.onnx +import onnxsim +import onnx +import struct +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +from ghostnetv3 import ghostnetv3 + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + +def export_weight(model): + current_path = os.path.dirname(__file__) + f = open(current_path + "ghostnetv3.weights", 'w') + f.write("{}\n".format(len(model.state_dict().keys()))) + + for k, v in model.state_dict().items(): + print('exporting ... 
{}: {}'.format(k, v.shape)) + + vr = v.reshape(-1).cpu().numpy() + f.write("{} {}".format(k, len(vr))) + for vv in vr: + f.write(" ") + f.write(struct.pack(">f", float(vv)).hex()) + f.write("\n") + + f.close() + +def export_onnx(input, model): + current_path = os.path.dirname(__file__) + file = current_path + "ghostnetv3.onnx" + torch.onnx.export( + model=model, + args=(input,), + f=file, + input_names=["input0"], + output_names=["output0"], + opset_version=13 + ) + print("Finished ONNX export") + + model_onnx = onnx.load(file) + onnx.checker.check_model(model_onnx) + + print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...") + model_onnx, check = onnxsim.simplify(model_onnx) + assert check, "Simplification check failed" + onnx.save(model_onnx, file) + +def eval_model(input, model): + output = model(input) + print("------from inference------") + print(input) + print(output) + +if __name__ == "__main__": + setup_seed(1) + + model = ghostnetv3(width=1.0) + model.eval() + + input = torch.randn(32, 3, 320, 256) + + export_weight(model) + + export_onnx(input, model) + + eval_model(input, model) \ No newline at end of file diff --git a/ghostnet/ghostnet.py b/ghostnet/ghostnet.py index 6a6e929..da4400e 100644 --- a/ghostnet/ghostnet.py +++ b/ghostnet/ghostnet.py @@ -250,66 +250,22 @@ def ghostnet(**kwargs): ] return GhostNet(cfgs, **kwargs) - def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True -def export_weight(model): - current_path = os.path.dirname(__file__) - f = open(current_path + "ghostnetv1.weights", 'w') - f.write("{}\n".format(len(model.state_dict().keys()))) - - for k, v in model.state_dict().items(): - print('exporting ... {}: {}'.format(k, v.shape)) - vr = v.reshape(-1).cpu().numpy() - f.write("{} {}".format(k, len(vr))) - for vv in vr: - f.write(" ") - f.write(struct.pack(">f", float(vv)).hex()) - f.write("\n") - - f.close() - -def export_onnx(input, model): +def export_pth(model): current_path = os.path.dirname(__file__) - file = current_path + "ghostnetv1.onnx" - torch.onnx.export( - model=model, - args=(input,), - f=file, - input_names=["input0"], - output_names=["output0"], - opset_version=13 - ) - print("Finished ONNX export") - - model_onnx = onnx.load(file) - onnx.checker.check_model(model_onnx) - - print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...") - model_onnx, check = onnxsim.simplify(model_onnx) - assert check, "Simplification check failed" - onnx.save(model_onnx, file) - -def eval_model(input, model): - output = model(input) - print("------from inference------") - print(input) - print(output) + torch.save(model.state_dict(), os.path.join(current_path, "ghostnetv1.pth")) + print("Model saved as ghostnetv1.pth") + if __name__ == "__main__": setup_seed(1) - + model = ghostnet(num_classes=1000, width=1.0, dropout=0.2) model.eval() - - input = torch.randn(32, 3, 320, 256) - - export_weight(model) - - export_onnx(input, model) - eval_model(input, model) + export_pth(model) \ No newline at end of file diff --git a/ghostnet/ghostnet_inference.py b/ghostnet/ghostnet_inference.py new file mode 100644 index 0000000..19d727f --- /dev/null +++ b/ghostnet/ghostnet_inference.py @@ -0,0 +1,77 @@ +import torch +import torch.nn as nn +import torch.onnx +import onnxsim +import onnx +import struct +import os + + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from ghostnet import ghostnet + +def setup_seed(seed): + torch.manual_seed(seed) + 
torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + +def export_weight(model): + current_path = os.path.dirname(__file__) + f = open(current_path + "ghostnetv1.weights", 'w') + f.write("{}\n".format(len(model.state_dict().keys()))) + + for k, v in model.state_dict().items(): + print('exporting ... {}: {}'.format(k, v.shape)) + + vr = v.reshape(-1).cpu().numpy() + f.write("{} {}".format(k, len(vr))) + for vv in vr: + f.write(" ") + f.write(struct.pack(">f", float(vv)).hex()) + f.write("\n") + + f.close() + +def export_onnx(input, model): + current_path = os.path.dirname(__file__) + file = current_path + "ghostnetv1.onnx" + torch.onnx.export( + model=model, + args=(input,), + f=file, + input_names=["input0"], + output_names=["output0"], + opset_version=13 + ) + print("Finished ONNX export") + + model_onnx = onnx.load(file) + onnx.checker.check_model(model_onnx) + + print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...") + model_onnx, check = onnxsim.simplify(model_onnx) + assert check, "Simplification check failed" + onnx.save(model_onnx, file) + +def eval_model(input, model): + output = model(input) + print("------from inference------") + print(input) + print(output) + +if __name__ == "__main__": + setup_seed(1) + + model = ghostnet(num_classes=1000, width=1.0, dropout=0.2) + model.eval() + + input = torch.randn(32, 3, 320, 256) + + export_weight(model) + + export_onnx(input, model) + + eval_model(input, model) \ No newline at end of file diff --git a/ghostnet/ghostnetv2.py b/ghostnet/ghostnetv2.py index 2b23a05..8cb0cbb 100644 --- a/ghostnet/ghostnetv2.py +++ b/ghostnet/ghostnetv2.py @@ -260,65 +260,14 @@ def ghostnetv2(**kwargs): dropout=kwargs['dropout'], args=kwargs['args']) -def setup_seed(seed): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - -def export_weight(model): - current_path = os.path.dirname(__file__) - f = open(current_path + "ghostnetv2.weights", 'w') - f.write("{}\n".format(len(model.state_dict().keys()))) - - for k, v in model.state_dict().items(): - print('exporting ... 
{}: {}'.format(k, v.shape)) - - vr = v.reshape(-1).cpu().numpy() - f.write("{} {}".format(k, len(vr))) - for vv in vr: - f.write(" ") - f.write(struct.pack(">f", float(vv)).hex()) - f.write("\n") - - f.close() - -def export_onnx(input, model): +def export_pth(model): current_path = os.path.dirname(__file__) - file = current_path + "ghostnetv2.onnx" - torch.onnx.export( - model=model, - args=(input,), - f=file, - input_names=["input0"], - output_names=["output0"], - opset_version=13 - ) - print("Finished ONNX export") + torch.save(model.state_dict(), os.path.join(current_path, "ghostnetv2.pth")) + print("Model saved as ghostnetv2.pth") - model_onnx = onnx.load(file) - onnx.checker.check_model(model_onnx) - - print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...") - model_onnx, check = onnxsim.simplify(model_onnx) - assert check, "Simplification check failed" - onnx.save(model_onnx, file) - -def eval_model(input, model): - output = model(input) - print("------from inference------") - print(input) - print(output) if __name__ == "__main__": - setup_seed(1) - - model = ghostnetv2(width=1.0, num_classes=1000, dropout=0.2, args=None) + model = ghostnetv2(num_classes=1000, width=1.0, dropout=0.2, args=None) model.eval() - - input = torch.randn(32, 3, 320, 256) - - export_weight(model) - - export_onnx(input, model) - eval_model(input, model) + export_pth(model) diff --git a/ghostnet/ghostnetv2_inference.py b/ghostnet/ghostnetv2_inference.py new file mode 100644 index 0000000..343bfec --- /dev/null +++ b/ghostnet/ghostnetv2_inference.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +import torch.onnx +import onnxsim +import onnx +import struct +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from ghostnetv2 import ghostnetv2 + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + +def export_weight(model): + current_path = os.path.dirname(__file__) + f = open(current_path + "ghostnetv2.weights", 'w') + f.write("{}\n".format(len(model.state_dict().keys()))) + + for k, v in model.state_dict().items(): + print('exporting ... 
{}: {}'.format(k, v.shape)) + + vr = v.reshape(-1).cpu().numpy() + f.write("{} {}".format(k, len(vr))) + for vv in vr: + f.write(" ") + f.write(struct.pack(">f", float(vv)).hex()) + f.write("\n") + + f.close() + +def export_onnx(input, model): + current_path = os.path.dirname(__file__) + file = current_path + "ghostnetv2.onnx" + torch.onnx.export( + model=model, + args=(input,), + f=file, + input_names=["input0"], + output_names=["output0"], + opset_version=13 + ) + print("Finished ONNX export") + + model_onnx = onnx.load(file) + onnx.checker.check_model(model_onnx) + + print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...") + model_onnx, check = onnxsim.simplify(model_onnx) + assert check, "Simplification check failed" + onnx.save(model_onnx, file) + +def eval_model(input, model): + output = model(input) + print("------from inference------") + print(input) + print(output) + +if __name__ == "__main__": + setup_seed(1) + + model = ghostnetv2(width=1.0, num_classes=1000, dropout=0.2, args=None) + model.eval() + + input = torch.randn(32, 3, 320, 256) + + export_weight(model) + + export_onnx(input, model) + + eval_model(input, model) \ No newline at end of file diff --git a/ghostnet/ghostnetv3.py b/ghostnet/ghostnetv3.py index 3e6f5ed..ff4f01f 100644 --- a/ghostnet/ghostnetv3.py +++ b/ghostnet/ghostnetv3.py @@ -790,65 +790,31 @@ def ghostnetv3(**kwargs): ] return GhostNet(cfgs, num_classes=1000, width=kwargs['width'], dropout=0.2) -def setup_seed(seed): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - -def export_weight(model): +def export_model(model): current_path = os.path.dirname(__file__) - f = open(current_path + "ghostnetv3.weights", 'w') - f.write("{}\n".format(len(model.state_dict().keys()))) - - for k, v in model.state_dict().items(): - print('exporting ... {}: {}'.format(k, v.shape)) + torch.save(model.state_dict(), os.path.join(current_path, "ghostnetv3.pth")) + print("Model saved as ghostnetv3.pth") - vr = v.reshape(-1).cpu().numpy() - f.write("{} {}".format(k, len(vr))) - for vv in vr: - f.write(" ") - f.write(struct.pack(">f", float(vv)).hex()) - f.write("\n") - f.close() - -def export_onnx(input, model): +def export_onnx_model(input_tensor, model): current_path = os.path.dirname(__file__) - file = current_path + "ghostnetv3.onnx" + file = os.path.join(current_path, "ghostnetv3.onnx") torch.onnx.export( - model=model, - args=(input,), + model=model, + args=(input_tensor,), f=file, - input_names=["input0"], - output_names=["output0"], + input_names=["input"], + output_names=["output"], opset_version=13 ) - print("Finished ONNX export") - - model_onnx = onnx.load(file) - onnx.checker.check_model(model_onnx) + print("ONNX model exported as ghostnetv3.onnx") - print(f"Simplifying with onnx-simplifier {onnxsim.__version__}...") - model_onnx, check = onnxsim.simplify(model_onnx) - assert check, "Simplification check failed" - onnx.save(model_onnx, file) - -def eval_model(input, model): - output = model(input) - print("------from inference------") - print(input) - print(output) if __name__ == "__main__": - setup_seed(1) - model = ghostnetv3(width=1.0) model.eval() - - input = torch.randn(32, 3, 320, 256) - - export_weight(model) - export_onnx(input, model) + export_model(model) - eval_model(input, model) + input_tensor = torch.randn(1, 3, 224, 224) + export_onnx_model(input_tensor, model)
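+
+    # Optional follow-up (illustrative, not executed by this script): GhostNetV3
+    # trains with the multi-branch re-parameterizable blocks defined above, so
+    # calling model.reparameterize() before export fuses those branches into
+    # plain convolutions and keeps the exported ONNX graph identical to the
+    # inference-time architecture:
+    #
+    #   model.reparameterize()
+    #   export_onnx_model(input_tensor, model)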