import torch
import torch.nn as nn
class brightness_decoder(nn.Module):
    """Predict a per-sample brightness gain ``a`` and bias ``b`` from a feature map.

    Architecture: a 1x1 convolution that collapses the ``C`` input channels to
    a single channel, followed by two ReLU fully connected layers
    (``H*W -> 128 -> 64``) and two scalar heads.

    Args:
        example_tensor: A 4-D tensor unpacked as ``(B, C, H, W)``. It is used
            only to size the layers. ``C``, ``H`` and ``W`` fix the layout that
            :meth:`forward` expects. ``B`` is stored in ``self.batch_size`` but
            does not restrict the batch size ``forward`` accepts.
    """

    def __init__(self, example_tensor):
        super().__init__()
        B, C, H, W = example_tensor.shape
        # Kept for backward compatibility; forward() no longer depends on it.
        self.batch_size = B
        # 1x1 conv: C channels -> 1 channel, spatial size H x W unchanged.
        self.conv1 = nn.Conv2d(C, 1, 1, stride=1)
        self.dense1 = nn.Linear(1 * H * W, 128)
        self.dense2 = nn.Linear(128, 64)
        self.alpha_decode = nn.Linear(64, 1)
        self.beta_decode = nn.Linear(64, 1)

    def forward(self, x, offset=0.2):
        """Decode brightness parameters for each sample in ``x``.

        Args:
            x: Tensor of shape ``(N, C, H, W)``. ``C``, ``H`` and ``W`` must
                match the example tensor given to ``__init__``; ``N`` may be
                any batch size.
            offset: Currently unused; accepted for interface compatibility.

        Returns:
            Tuple ``(a, b)`` of tensors, each of shape ``(N, 1)``:
            ``a = relu(2 * tanh(.))`` lies in ``[0, 2)`` and
            ``b = tanh(.)`` lies in ``(-1, 1)``.
        """
        out = torch.relu(self.conv1(x))
        # Flatten each sample using the runtime batch size. Reshaping with the
        # batch size of the constructor's example tensor broke on any other N.
        out = torch.flatten(out, start_dim=1)
        out = torch.relu(self.dense1(out))
        out = torch.relu(self.dense2(out))
        a = torch.relu(2.0 * torch.tanh(self.alpha_decode(out)))
        b = torch.tanh(self.beta_decode(out))
        # NOTE(review): `offset` is not applied. Clamping a to
        # [1 - offset, 1 + offset] and b to [-offset, offset] is disabled.
        return a, b
if __name__ == '__main__':
    # Smoke test: size the decoder from a random feature map, then run one
    # forward pass on that same map.
    sample = torch.rand(8, 512, 11, 15)
    decoder = brightness_decoder(sample)
    alpha, beta = decoder(sample)