# Source: MyEndoSfMLearner / networks / brightness_decoder2.py
# Last commit: @planck on 17 Nov 2020 (2 KB) -- "initial commit" (最初のコミット)
# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

from __future__ import absolute_import, division, print_function

import torch
import torch.nn as nn
from collections import OrderedDict
import networks


class brightness_decoder2(nn.Module):
    """Predict a per-sample pair of scalars ``(a, b)`` from encoder features.

    The last feature map of every input feature list is squeezed from 512 to
    256 channels (1x1 conv + ReLU). The squeezed maps are concatenated along
    the channel axis and passed through two 3x3 conv + ReLU stages. Two
    parallel 1x1 conv heads then each produce one channel, which is averaged
    over the spatial axes.

    NOTE(review): given the class name, ``a``/``b`` are presumably a gain and
    an offset for a brightness adjustment -- confirm against the consuming loss.

    Args:
        num_ch_enc: Stored as ``self.num_ch_enc`` only. The squeeze conv
            hard-codes 512 input channels.
        num_input_features: Number of feature lists given to ``forward``. The
            first 3x3 conv expects ``num_input_features * 256`` channels.
        num_frames_to_predict_for: Ignored. It is always overridden to 1 (see
            the NOTE in ``__init__``).
        stride: Stride of the two 3x3 convolutions (padding is 1).
    """

    def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=1, stride=1):
        super(brightness_decoder2, self).__init__()

        self.num_ch_enc = num_ch_enc
        self.num_input_features = num_input_features
        # NOTE(review): the caller's value is discarded and every head emits a
        # single channel. Honouring the argument would change parameter shapes
        # (and break existing checkpoints), so the override is preserved.
        self.num_frames_to_predict_for = 1
        n_out = self.num_frames_to_predict_for

        # Creation order matters twice: it fixes the random-initialisation
        # sequence and the ``net.<i>`` indices used as state_dict keys.
        layers = OrderedDict()
        layers["squeeze"] = nn.Conv2d(512, 256, 1)
        layers[("brightness", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
        layers[("brightness", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
        layers[("brightness", 2)] = nn.Conv2d(256, 1 * n_out, 1)
        layers[("brightness", 3)] = nn.Conv2d(256, 1 * n_out, 1)
        self.convs = layers

        self.relu = nn.ReLU()

        # ``self.convs`` is a plain dict and is not tracked by nn.Module. The
        # ModuleList registers the layers so .parameters(), .to() and
        # state_dict() see them.
        self.net = nn.ModuleList(list(self.convs.values()))

    def forward(self, input_features):
        """Compute the ``(a, b)`` pair for a batch.

        Args:
            input_features: Iterable of ``num_input_features`` sequences of
                tensors. Only the last element of each sequence is used. It
                must have 512 channels (assumed ``(B, 512, H, W)`` -- TODO
                confirm with callers).

        Returns:
            Tuple ``(a, b)``, each of shape ``(B, 1)``. ``a`` lies in
            ``(0.8, 2.8)`` and ``b`` lies in ``(-1, 1)``.
        """
        squeeze = self.convs["squeeze"]
        squeezed = []
        for feats in input_features:
            squeezed.append(self.relu(squeeze(feats[-1])))
        x = torch.cat(squeezed, 1)

        # Two shared 3x3 conv + ReLU stages feeding both heads.
        for idx in (0, 1):
            x = self.relu(self.convs[("brightness", idx)](x))

        # Spatial average of each single-channel head: (B, 1, H, W) -> (B, 1).
        raw_a = self.convs[("brightness", 2)](x).mean(3).mean(2)
        raw_b = self.convs[("brightness", 3)](x).mean(3).mean(2)

        # NOTE(review): this maps tanh's (-1, 1) onto (0.8, 2.8), centred at
        # 1.8. If a gain centred at 1.0 was intended, ``1.0 + 0.8 * tanh``
        # would be the usual form -- confirm before changing (alters training).
        a = 1.0 + (0.8 + torch.tanh(raw_a))
        b = torch.tanh(raw_b)
        return a, b


if __name__ == '__main__':
    # Smoke test: push a random batch through a ResNet-18 encoder and the
    # brightness decoder. The statement order is kept so that RNG consumption
    # (input, then encoder init, then decoder init) is unchanged.
    # Presumably 6 channels = 2 stacked RGB frames (num_input_images=2).
    frames = torch.rand((4, 6, 352, 480))
    encoder = networks.ResnetEncoder(18, False, num_input_images=2)
    # brightness_decoder2 overrides num_frames_to_predict_for to 1, so the 2
    # passed here has no effect on the output shape.
    decoder = brightness_decoder2(512, num_input_features=1, num_frames_to_predict_for=2)
    enc_feats = encoder(frames)
    a, b = decoder([enc_feats])