Newer
Older
Demo-Maker / modules / PytorchSSD / prunning / prunner.py
@mikado-4410 mikado-4410 on 10 Oct 2024 10 KB 最初のコミット
import torch
import torch.nn as nn
import logging
from heapq import nsmallest

from ..utils.model_book import ModelBook


class ModelPrunner:
    """Greedy channel pruner for Conv2d filters and Linear output features.

    Implements the criterion of https://arxiv.org/pdf/1611.06440.pdf: a channel's score
    is |sum(dC/dh * h)| over the batch (and the spatial axes for convs), where C is the
    cost and h the activation, normalised by the L2 norm of its layer's scores. Pruned
    layers are replaced in place inside ``model``; the next conv / batch norm / linear
    traced after a pruned layer during the forward pass is shrunk to match.
    """

    def __init__(self, model, train_fun, ignored_paths=None):
        """Implement the pruning algorithm described in the paper https://arxiv.org/pdf/1611.06440.pdf .

        The pruning criterion is dC/dh * h, where C is the cost and h is the activation.

        Args:
            model: the ``nn.Module`` to prune; its submodules are replaced in place.
            train_fun: callable ``train_fun(model) -> (loss, accuracy)``. It must run
                forward *and* backward passes so the hooks can record activations and
                the gradients with respect to them.
            ignored_paths: module paths (sequences of names, as used by ``ModelBook``)
                whose channels are never selected for pruning. Defaults to none.
        """
        self.model = model
        self.train_fun = train_fun
        # A shared mutable default (``[]``) would leak between instances.
        self.ignored_paths = [] if ignored_paths is None else ignored_paths
        self.book = ModelBook(self.model)
        self.outputs = {}  # path -> activation from the most recent forward call only
        self.grads = {}  # path -> gradient w.r.t. that activation, most recent backward only
        self.handles = []  # hook handles; removed and cleared by deregister_hooks()
        # The misspelt attribute name is kept for backward compatibility.
        self.decendent_batch_norms = {}  # conv path -> batch norm that ran after it
        self.last_conv_path = None  # used to trace the graph
        self.descendent_convs = {}  # conv path -> next conv that ran after it
        self.descendent_linears = {}  # conv/linear path -> next linear that ran after it
        self.last_linear_path = None  # used to trace the graph

    @staticmethod
    def _drop_index(tensor, index, dim=0, block=1):
        """Return ``tensor`` without entries ``index*block .. (index+1)*block - 1`` along ``dim``."""
        # Boolean masks (not uint8, which PyTorch deprecates for indexing) on the tensor's device.
        keep = torch.ones(tensor.shape[dim], dtype=torch.bool, device=tensor.device)
        keep[index * block:(index + 1) * block] = False
        return tensor[(slice(None),) * dim + (keep,)]

    def _make_new_conv(self, conv, filter_index, channel_type="out"):
        """Return a copy of ``conv`` with one output filter or one input channel removed.

        Args:
            conv: the ``nn.Conv2d`` to shrink.
            filter_index: index of the output filter (``"out"``) or input channel (``"in"``) to drop.
            channel_type: ``"out"`` or ``"in"``.

        Raises:
            TypeError: if ``conv`` is not an ``nn.Conv2d``.
            ValueError: if ``channel_type`` is neither ``"in"`` nor ``"out"``.

        NOTE(review): grouped convolutions (``groups > 1``) are not handled; removing a
        single channel usually makes the channel count indivisible by ``groups``.
        """
        if not isinstance(conv, nn.Conv2d):
            raise TypeError(f"The module is not Conv2d, but {type(conv)}.")
        if channel_type == "out":
            in_channels, out_channels, weight_dim = conv.in_channels, conv.out_channels - 1, 0
        elif channel_type == "in":
            in_channels, out_channels, weight_dim = conv.in_channels - 1, conv.out_channels, 1
        else:
            raise ValueError(f"{channel_type} should be either 'in' or 'out'.")

        new_conv = nn.Conv2d(
            in_channels,
            out_channels,
            conv.kernel_size,
            conv.stride,
            conv.padding,
            conv.dilation,
            conv.groups,
            conv.bias is not None,
            padding_mode=conv.padding_mode,
        )
        # Conv2d weights are laid out (out_channels, in_channels / groups, kH, kW).
        new_conv.weight.data = self._drop_index(conv.weight.data, filter_index, dim=weight_dim)
        if conv.bias is not None:
            # The bias is indexed by output channel, so removing an input channel keeps it intact.
            if channel_type == "out":
                new_conv.bias.data = self._drop_index(conv.bias.data, filter_index)
            else:
                new_conv.bias.data = conv.bias.data
        return new_conv

    @staticmethod
    def _make_new_batch_norm(batch_norm, filter_index):
        """Return a copy of ``batch_norm`` with channel ``filter_index`` removed.

        Learned affine parameters, running statistics, hyper-parameters and the
        train/eval mode are all carried over, so pruning does not reset the layer.
        """
        new_batch_norm = nn.BatchNorm2d(
            batch_norm.num_features - 1,
            eps=batch_norm.eps,
            momentum=batch_norm.momentum,
            affine=batch_norm.affine,
            track_running_stats=batch_norm.track_running_stats,
        )
        if batch_norm.affine:
            new_batch_norm.weight.data = ModelPrunner._drop_index(batch_norm.weight.data, filter_index)
            new_batch_norm.bias.data = ModelPrunner._drop_index(batch_norm.bias.data, filter_index)
        if batch_norm.running_mean is not None:
            new_batch_norm.running_mean = ModelPrunner._drop_index(batch_norm.running_mean, filter_index)
            new_batch_norm.running_var = ModelPrunner._drop_index(batch_norm.running_var, filter_index)
            new_batch_norm.num_batches_tracked = batch_norm.num_batches_tracked.clone()
        new_batch_norm.train(batch_norm.training)
        return new_batch_norm

    def remove_conv_filter(self, path, filter_index):
        """Remove output filter ``filter_index`` from the conv at ``path`` and shrink its traced descendants."""
        conv = self.book.get_module(path)
        logging.info(f'Prune Conv: {"/".join(path)}, Filter: {filter_index}, Layer: {conv}')
        self._update_model(path, self._make_new_conv(conv, filter_index, channel_type="out"))

        # The next conv in execution order is assumed to consume this conv's output.
        next_conv_path = self.descendent_convs.get(path)
        if next_conv_path:
            next_conv = self.book.get_module(next_conv_path)
            new_next_conv = self._make_new_conv(next_conv, filter_index, channel_type="in")
            self._update_model(next_conv_path, new_next_conv)

        # reduce the num_features of batch norm, keeping its learned state
        batch_norm_path = self.decendent_batch_norms.get(path)
        if batch_norm_path:
            batch_norm = self.book.get_module(batch_norm_path)
            self._update_model(batch_norm_path, self._make_new_batch_norm(batch_norm, filter_index))

        # reduce the in channels of linear layer; ``conv`` is the pre-pruning layer,
        # so its out_channels still matches the linear layer's current in_features.
        linear_path = self.descendent_linears.get(path)
        if linear_path:
            linear = self.book.get_module(linear_path)
            new_linear = self._make_new_linear(linear, filter_index, conv, channel_type="in")
            self._update_model(linear_path, new_linear)

    @staticmethod
    def _make_new_linear(linear, feature_index, conv=None, channel_type="out"):
        """Return a copy of ``linear`` with one output feature or one input block removed.

        Args:
            linear: the ``nn.Linear`` to shrink.
            feature_index: output feature (``"out"``) or input channel (``"in"``) to drop.
            conv: for ``"in"``, the conv whose flattened output feeds ``linear``; each of
                its channels then owns ``in_features // conv.out_channels`` consecutive
                inputs (assumes a channel-major flatten of the conv output). ``None``
                means one input per channel.
            channel_type: ``"out"`` or ``"in"``.

        Raises:
            ValueError: if ``channel_type`` is neither ``"in"`` nor ``"out"``.
        """
        if channel_type == "out":
            new_linear = nn.Linear(linear.in_features, linear.out_features - 1, bias=linear.bias is not None)
            new_linear.weight.data = ModelPrunner._drop_index(linear.weight.data, feature_index, dim=0)
            if linear.bias is not None:
                new_linear.bias.data = ModelPrunner._drop_index(linear.bias.data, feature_index)
        elif channel_type == "in":
            block = linear.in_features // conv.out_channels if conv else 1
            new_linear = nn.Linear(linear.in_features - block, linear.out_features, bias=linear.bias is not None)
            # nn.Linear weights are (out_features, in_features): drop a block of columns.
            new_linear.weight.data = ModelPrunner._drop_index(
                linear.weight.data, feature_index, dim=1, block=block
            )
            if linear.bias is not None:
                new_linear.bias.data = linear.bias.data
        else:
            raise ValueError(f"{channel_type} should be either 'in' or 'out'.")
        return new_linear

    def _select_lowest_ranked(self, num, dims):
        """Score every recorded channel and return the ``num`` lowest as ``(path, index, score)``.

        A channel's score is |sum over ``dims`` of grad * activation|, L2-normalised per
        layer. The result is sorted by path and then by descending index, so removing the
        entries in order never shifts the index of a channel still waiting to be removed.
        """
        ranks = []
        for path, output in self.outputs.items():
            grad = self.grads.get(path)
            if grad is None:
                logging.warning("No gradient recorded for %s; skipping it when ranking.", "/".join(path))
                continue
            scores = torch.abs((grad.data * output.data).sum(dim=dims))
            norm = torch.sqrt(torch.sum(scores * scores))
            if norm > 0:  # an all-zero layer would otherwise yield NaN scores
                scores = scores / norm
            ranks.extend((path, index, score) for index, score in enumerate(scores.tolist()))
        to_prune = nsmallest(num, ranks, key=lambda t: t[2])
        return sorted(to_prune, key=lambda t: (t[0], -t[1]))

    def prune_conv_layers(self, num=1):
        """Prune the ``num`` lowest-ranked Conv2d filters across all non-ignored conv layers.

        Returns:
            ``(after_loss - before_loss, after_accuracy - before_accuracy)`` from ``train_fun``.
        """
        self.register_conv_hooks()
        try:
            before_loss, before_accuracy = self.train_fun(self.model)
            # Conv2d outputs are (N, C, H, W): sum over batch and spatial axes per channel.
            for path, filter_index, _ in self._select_lowest_ranked(num, dims=(0, 2, 3)):
                self.remove_conv_filter(path, filter_index)
        finally:
            # Always detach the hooks, even if training or pruning fails.
            self.deregister_hooks()
        after_loss, after_accuracy = self.train_fun(self.model)
        return after_loss - before_loss, after_accuracy - before_accuracy

    def register_conv_hooks(self):
        """Register hooks that record conv activations/gradients and trace layer order.

        Run this before training for pruning. Any hooks from an earlier registration
        are removed first.
        """
        self.deregister_hooks()
        self.outputs.clear()
        self.grads.clear()
        self.last_conv_path = None
        self.decendent_batch_norms.clear()
        self.descendent_convs.clear()
        self.descendent_linears.clear()

        def reset_trace(module, inputs):
            # Each forward pass of the whole model starts a fresh trace; otherwise the
            # last conv of one pass would be linked to the first conv of the next.
            self.last_conv_path = None

        def forward_hook(m, inputs, output):
            path = self.book.get_path(m)
            if isinstance(m, nn.Conv2d):
                if path not in self.ignored_paths:
                    self.outputs[path] = output
                if self.last_conv_path:
                    self.descendent_convs[self.last_conv_path] = path
                self.last_conv_path = path
            elif isinstance(m, nn.BatchNorm2d):
                if self.last_conv_path:
                    self.decendent_batch_norms[self.last_conv_path] = path
            elif isinstance(m, nn.Linear):
                if self.last_conv_path:
                    self.descendent_linears[self.last_conv_path] = path
                self.last_conv_path = None  # after a linear layer the conv layer doesn't matter

        def backward_hook(m, grad_input, grad_output):
            self.grads[self.book.get_path(m)] = grad_output[0]

        self.handles.append(self.model.register_forward_pre_hook(reset_trace))
        for _, m in self.book.modules(module_type=(nn.Conv2d, nn.BatchNorm2d, nn.Linear)):
            self.handles.append(m.register_forward_hook(forward_hook))
            # NOTE(review): register_backward_hook is deprecated; register_full_backward_hook
            # forbids in-place ops on the hooked outputs (e.g. ReLU(inplace=True)), so check
            # the model before switching.
            self.handles.append(m.register_backward_hook(backward_hook))

    def deregister_hooks(self):
        """Remove every registered hook; run this before retraining to recover the model."""
        for handle in self.handles:
            handle.remove()
        self.handles.clear()

    def prune_linear_layers(self, num=1):
        """Prune the ``num`` lowest-ranked output features across all non-ignored linear layers.

        Returns:
            ``(after_loss - before_loss, after_accuracy - before_accuracy)`` from ``train_fun``.
        """
        self.register_linear_hooks()
        try:
            before_loss, before_accuracy = self.train_fun(self.model)
            # Linear outputs are scored per feature by summing over the batch axis.
            for path, feature_index, _ in self._select_lowest_ranked(num, dims=(0,)):
                self.remove_linear_feature(path, feature_index)
        finally:
            self.deregister_hooks()
        after_loss, after_accuracy = self.train_fun(self.model)
        return after_loss - before_loss, after_accuracy - before_accuracy

    def register_linear_hooks(self):
        """Register hooks that record linear activations/gradients and trace layer order."""
        self.deregister_hooks()
        self.outputs.clear()
        self.grads.clear()
        self.descendent_linears.clear()
        self.last_linear_path = None

        def reset_trace(module, inputs):
            # Start every forward pass with a fresh trace (see register_conv_hooks).
            self.last_linear_path = None

        def forward_hook(m, inputs, output):
            path = self.book.get_path(m)
            if path not in self.ignored_paths:
                self.outputs[path] = output
            if self.last_linear_path:
                self.descendent_linears[self.last_linear_path] = path
            self.last_linear_path = path

        def backward_hook(m, grad_input, grad_output):
            self.grads[self.book.get_path(m)] = grad_output[0]

        self.handles.append(self.model.register_forward_pre_hook(reset_trace))
        for _, m in self.book.linear_modules():
            self.handles.append(m.register_forward_hook(forward_hook))
            self.handles.append(m.register_backward_hook(backward_hook))

    def remove_linear_feature(self, path, feature_index):
        """Remove output feature ``feature_index`` from the linear at ``path`` and its traced successor's input."""
        linear = self.book.get_module(path)
        logging.info(f'Prune Linear: {"/".join(path)}, Filter: {feature_index}, Layer: {linear}')
        self._update_model(path, self._make_new_linear(linear, feature_index, channel_type="out"))

        # update following linear layers
        next_linear_path = self.descendent_linears.get(path)
        if next_linear_path:
            next_linear = self.book.get_module(next_linear_path)
            new_next_linear = self._make_new_linear(next_linear, feature_index, channel_type="in")
            self._update_model(next_linear_path, new_next_linear)

    def _update_model(self, path, module):
        """Install ``module`` at ``path`` inside its parent and record it in the model book."""
        parent = self.book.get_module(path[:-1])
        parent._modules[path[-1]] = module
        self.book.update(path, module)