Newer
Older
SC-SfMLearner_pytorch / networks / brightness_decoder.py
@planck planck on 9 Nov 2020 1 KB 最初のコミット
import torch
import torch.nn as nn


class brightness_decoder(nn.Module):
    """Decode two per-sample scalars ``(a, b)`` from a 4-D feature map.

    The input is reduced to a single channel by a 1x1 convolution, flattened
    per sample, passed through a two-layer MLP, and fed to two linear heads.
    NOTE(review): the class name suggests ``(a, b)`` are used as an affine
    brightness transform (e.g. ``a * I + b``) by the caller -- confirm there.

    Args:
        example_tensor: tensor of shape ``(B, C, H, W)`` used only to size the
            layers. ``C``, ``H`` and ``W`` fix the input geometry accepted by
            ``forward``. ``B`` is stored as ``self.batch_size`` for backward
            compatibility, but ``forward`` accepts any batch size.
    """

    def __init__(self, example_tensor):
        super(brightness_decoder, self).__init__()
        B, C, H, W = example_tensor.shape
        # Retained for callers that read it; forward() does not depend on it.
        self.batch_size = B
        self.conv1 = nn.Conv2d(C, 1, 1, stride=1)
        # The 1x1 conv leaves one channel, so the flattened size is 1 * H * W.
        self.dense1 = nn.Linear(1 * H * W, 128)
        self.dense2 = nn.Linear(128, 64)
        self.alpha_decode = nn.Linear(64, 1)
        self.beta_decode = nn.Linear(64, 1)

    def forward(self, x, offset=0.2):
        """Compute ``(a, b)`` for every sample in ``x``.

        Args:
            x: tensor of shape ``(N, C, H, W)``. ``C``, ``H`` and ``W`` must
                match the example tensor; ``N`` may be any batch size.
            offset: currently unused. It is kept for API compatibility; it
                belonged to a clamping step that is disabled.

        Returns:
            Tuple ``(a, b)``, each of shape ``(N, 1)``. ``a = relu(2 * tanh(.))``
            lies in ``[0, 2]``; ``b = tanh(.)`` lies in ``[-1, 1]``.
        """
        out = torch.relu(self.conv1(x))
        # Flatten per sample using the *runtime* batch size. The original used
        # the example tensor's batch size, which broke for any other N.
        out = torch.flatten(out, start_dim=1)
        out = torch.relu(self.dense1(out))
        out = torch.relu(self.dense2(out))
        a = torch.relu(2.0 * torch.tanh(self.alpha_decode(out)))
        b = torch.tanh(self.beta_decode(out))
        # A previous revision clamped a to [1 - offset, 1 + offset] and b to
        # [-offset, offset]; that clamping is intentionally disabled.
        return a, b


if __name__ == '__main__':
    # Smoke run: size the decoder from a random batch, then decode that batch.
    sample_batch = torch.rand((8, 512, 11, 15))
    decoder = brightness_decoder(sample_batch)
    alpha, beta = decoder(sample_batch)