Network

The end2end_imaging.network module provides image reconstruction networks and loss functions for end-to-end training.

Looking for PSF surrogate networks?

The PSF surrogate models (MLP, SIREN, etc.) are part of DeepLens — see Surrogate Networks in the DeepLens API reference.

Reconstruction Networks

Image restoration networks that recover a clean image from a degraded (aberrated) sensor capture.

end2end_imaging.network.NAFNet

NAFNet(in_chan=3, out_chan=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 28], dec_blk_nums=[1, 1, 1, 1])

Bases: Module

Nonlinear Activation Free Network for image restoration.

A U-Net-style encoder-decoder with NAFBlocks that replace nonlinear activations with SimpleGate (element-wise multiplication of channel-split halves). Includes a global residual connection from input to output.

Reference: "Simple Baselines for Image Restoration" (ECCV 2022).
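The SimpleGate operation described above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the library's NAFBlock code: the class name `SimpleGateSketch` is hypothetical, and the only behaviour assumed is the channel split and element-wise product stated in the description.

```python
import torch
import torch.nn as nn


class SimpleGateSketch(nn.Module):
    """Hypothetical stand-in for the gate used inside a NAFBlock."""

    def forward(self, x):
        # Split the channel axis into two equal halves and multiply them.
        a, b = x.chunk(2, dim=1)
        return a * b


x = torch.randn(2, 8, 16, 16)
y = SimpleGateSketch()(x)
print(y.shape)  # torch.Size([2, 4, 16, 16])
```

Because the gate halves the channel count, any block built around it has to supply twice as many channels as it wants to get out.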

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_chan | | Number of input channels. Defaults to 3. | 3 |
| out_chan | | Number of output channels. Defaults to 3. | 3 |
| width | | Base channel width. Defaults to 32. | 32 |
| middle_blk_num | | Number of NAFBlocks in the bottleneck. Defaults to 1. | 1 |
| enc_blk_nums | | Number of NAFBlocks per encoder stage. Defaults to [1, 1, 1, 28]. | [1, 1, 1, 28] |
| dec_blk_nums | | Number of NAFBlocks per decoder stage. Defaults to [1, 1, 1, 1]. | [1, 1, 1, 1] |

Source code in end2endimaging-src/end2end_imaging/network/reconstruction/nafnet.py
def __init__(
    self,
    in_chan=3,
    out_chan=3,
    width=32,  # 64
    middle_blk_num=1,
    enc_blk_nums=[1, 1, 1, 28],
    dec_blk_nums=[1, 1, 1, 1],
):
    super().__init__()

    self.intro = nn.Conv2d(
        in_channels=in_chan,
        out_channels=width,
        kernel_size=3,
        padding=1,
        stride=1,
        groups=1,
        bias=True,
    )
    self.ending = nn.Conv2d(
        in_channels=width,
        out_channels=out_chan,
        kernel_size=3,
        padding=1,
        stride=1,
        groups=1,
        bias=True,
    )

    self.encoders = nn.ModuleList()
    self.decoders = nn.ModuleList()
    self.middle_blks = nn.ModuleList()
    self.ups = nn.ModuleList()
    self.downs = nn.ModuleList()

    chan = width
    for num in enc_blk_nums:
        self.encoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
        self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))
        chan = chan * 2

    self.middle_blks = nn.Sequential(
        *[NAFBlock(chan) for _ in range(middle_blk_num)]
    )

    for num in dec_blk_nums:
        self.ups.append(
            nn.Sequential(
                nn.Conv2d(chan, chan * 2, 1, bias=False), nn.PixelShuffle(2)
            )
        )
        chan = chan // 2
        self.decoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))

    self.padder_size = 2 ** len(self.encoders)

    # Initialize weights  
    self.initialize_weights()  
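The two constructor loops above fix the channel width at every stage. The helper below retraces that arithmetic in plain Python; `nafnet_channel_plan` is a hypothetical name used only for illustration, and it models nothing beyond the channel bookkeeping visible in `__init__`.

```python
def nafnet_channel_plan(width=32, enc_blk_nums=(1, 1, 1, 28), dec_blk_nums=(1, 1, 1, 1)):
    """Trace the channel widths implied by the constructor loops (illustrative helper)."""
    chan = width
    encoder_widths = []
    for _ in enc_blk_nums:
        encoder_widths.append(chan)  # NAFBlocks of this stage run at `chan` channels
        chan *= 2                    # strided 2x2 conv doubles channels, halves H and W
    middle_width = chan
    decoder_widths = []
    for _ in dec_blk_nums:
        # The 1x1 conv doubles the channels, then PixelShuffle(2) divides them by 4.
        chan = (chan * 2) // 4
        decoder_widths.append(chan)
    padder_size = 2 ** len(enc_blk_nums)
    return encoder_widths, middle_width, decoder_widths, padder_size


enc, mid, dec, pad = nafnet_channel_plan()
print(enc, mid, dec, pad)  # [32, 64, 128, 256] 512 [256, 128, 64, 32] 16
```

The decoder widths mirror the encoder widths in reverse, which is what allows the additive skip connections in `forward` to combine tensors of equal shape. This relies on `enc_blk_nums` and `dec_blk_nums` having the same length.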

initialize_weights

initialize_weights()

Initialize all module weights.

Uses truncated-normal initialization (std 0.02) for conv and linear layers per the NAFNet paper, sets BatchNorm to identity scale, and zeros the final conv so the global residual yields an exact identity on the first out_chan input channels at the start of training.
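The effect of zeroing the final convolution can be checked on a toy model. The sketch below is an assumption-laden stand-in: `body` is a placeholder for the real encoder-decoder, and the channel counts are arbitrary. It only demonstrates that a zeroed last layer plus the global residual reproduces the first `out_chan` input channels exactly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_chan, out_chan, width = 4, 3, 8

# Placeholder for the U-Net body; the real network uses NAFBlocks here.
body = nn.Sequential(nn.Conv2d(in_chan, width, 3, padding=1), nn.GELU())
ending = nn.Conv2d(width, out_chan, 3, padding=1)

nn.init.trunc_normal_(body[0].weight, std=0.02)
nn.init.zeros_(body[0].bias)
nn.init.zeros_(ending.weight)  # final conv outputs exactly zero at step 0
nn.init.zeros_(ending.bias)

inp = torch.randn(1, in_chan, 16, 16)
with torch.no_grad():
    out = ending(body(inp)) + inp[:, :out_chan]  # global residual
print(torch.equal(out, inp[:, :out_chan]))  # True
```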

Source code in end2endimaging-src/end2end_imaging/network/reconstruction/nafnet.py
def initialize_weights(self):
    """Initialize all module weights.

    Uses truncated-normal initialization (std 0.02) for conv and linear
    layers per the NAFNet paper, sets BatchNorm to identity scale, and
    zeros the final conv so the global residual yields an exact identity
    on the first ``out_chan`` input channels at the start of training.
    """
    # NAFNet has no ReLU (uses SimpleGate); kaiming-relu inflates activations by sqrt(2)
    # at every layer. Use trunc_normal(std=0.02) per the NAFNet paper.
    for m in self.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
    # Zero the final conv so the global residual makes the network an exact identity
    # on the first `out_chan` input channels at step 0. Training then learns the correction.
    nn.init.zeros_(self.ending.weight)
    if self.ending.bias is not None:
        nn.init.zeros_(self.ending.bias)

forward

forward(inp)

Forward pass with global residual connection.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| inp | | Input image tensor of shape (B, in_chan, H, W). | required |

Returns:

| Type | Description |
| --- | --- |
| | Restored image tensor of shape (B, out_chan, H, W). |

Source code in end2endimaging-src/end2end_imaging/network/reconstruction/nafnet.py
def forward(self, inp):
    """Forward pass with global residual connection.

    Args:
        inp: Input image tensor of shape ``(B, in_chan, H, W)``.

    Returns:
        Restored image tensor of shape ``(B, out_chan, H, W)``.
    """
    B, C, H, W = inp.shape
    inp = self.check_image_size(inp)

    x = self.intro(inp)

    encs = []

    for encoder, down in zip(self.encoders, self.downs):
        x = encoder(x)
        encs.append(x)
        x = down(x)

    x = self.middle_blks(x)

    for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
        x = up(x)
        x = x + enc_skip
        x = decoder(x)

    x = self.ending(x)
    x = x + inp[:, :x.shape[1], :, :]

    return x[:, :, :H, :W]

check_image_size

check_image_size(x)

Pad the input so its spatial dims are divisible by padder_size.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | | Input tensor of shape (B, C, H, W). | required |

Returns:

| Type | Description |
| --- | --- |
| | Zero-padded tensor whose height and width are multiples of self.padder_size. |

Source code in end2endimaging-src/end2end_imaging/network/reconstruction/nafnet.py
def check_image_size(self, x):
    """Pad the input so its spatial dims are divisible by ``padder_size``.

    Args:
        x: Input tensor of shape ``(B, C, H, W)``.

    Returns:
        Zero-padded tensor whose height and width are multiples of
        ``self.padder_size``.
    """
    _, _, h, w = x.size()
    mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
    mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
    x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
    return x
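The padding arithmetic above can be checked without any tensors. The sketch below reuses the same modulo expression in a hypothetical helper, `pad_amounts`, and assumes the default NAFNet with four encoder stages, so `padder_size` is 16.

```python
def pad_amounts(h, w, padder_size):
    """Bottom/right padding needed to reach the next multiple of padder_size."""
    mod_pad_h = (padder_size - h % padder_size) % padder_size
    mod_pad_w = (padder_size - w % padder_size) % padder_size
    return mod_pad_h, mod_pad_w


padder_size = 2 ** 4  # four encoder stages, as in the default NAFNet
for h, w in [(256, 256), (250, 301), (17, 16)]:
    pad_h, pad_w = pad_amounts(h, w, padder_size)
    print((h, w), "->", (h + pad_h, w + pad_w))
# (256, 256) -> (256, 256)
# (250, 301) -> (256, 304)
# (17, 16) -> (32, 16)
```

The outer `% padder_size` is what keeps already-aligned sizes unpadded. `forward` then crops the result back to the original `H` and `W`, so the padding is invisible to the caller.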

end2end_imaging.network.UNet

UNet(in_channels=3, out_channels=3)

Bases: Module

U-Net with residual skip connections for image restoration.

A 3-level encoder-decoder with dense BasicBlocks and PixelShuffle upsampling. Uses additive skip connections between encoder and decoder stages.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_channels | | Number of input channels. Defaults to 3. | 3 |
| out_channels | | Number of output channels. Defaults to 3. | 3 |

Source code in end2endimaging-src/end2end_imaging/network/reconstruction/unet.py
def __init__(self, in_channels=3, out_channels=3):
    super().__init__()
    self.pre = nn.Sequential(
        nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1), nn.PReLU(16)
    )
    self.conv00 = BasicBlock(16, 32)
    self.down0 = nn.MaxPool2d((2, 2))
    self.conv10 = BasicBlock(32, 64)
    self.down1 = nn.MaxPool2d((2, 2))
    self.conv20 = BasicBlock(64, 128)
    self.down2 = nn.MaxPool2d((2, 2))
    self.conv30 = BasicBlock(128, 256)
    self.conv31 = BasicBlock(256, 512)
    self.up2 = nn.PixelShuffle(2)
    self.conv21 = BasicBlock(128, 256)
    self.up1 = nn.PixelShuffle(2)
    self.conv11 = BasicBlock(64, 128)
    self.up0 = nn.PixelShuffle(2)
    self.conv01 = BasicBlock(32, 64)

    self.post = nn.Sequential(
        nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1),
        nn.PReLU(16),
        nn.Conv2d(16, out_channels, kernel_size=3, stride=1, padding=1),
    )

forward

forward(x)

Forward pass.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | | Input image tensor of shape (B, in_channels, H, W). | required |

Returns:

| Type | Description |
| --- | --- |
| | Output tensor of shape (B, out_channels, H, W). |

Source code in end2endimaging-src/end2end_imaging/network/reconstruction/unet.py
def forward(self, x):
    """Forward pass.

    Args:
        x: Input image tensor of shape ``(B, in_channels, H, W)``.

    Returns:
        Output tensor of shape ``(B, out_channels, H, W)``.
    """
    x0 = self.pre(x)
    x0 = self.conv00(x0)
    x1 = self.down0(x0)
    x1 = self.conv10(x1)
    x2 = self.down1(x1)
    x2 = self.conv20(x2)
    x3 = self.down2(x2)
    x3 = self.conv30(x3)
    x3 = self.conv31(x3)
    x2 = x2 + self.up2(x3)
    x2 = self.conv21(x2)
    x1 = x1 + self.up1(x2)
    x1 = self.conv11(x1)
    x0 = x0 + self.up0(x1)
    x0 = self.conv01(x0)
    x = self.post(x0)
    return x
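The additive skips in the forward pass above only work when the upsampled tensor matches the encoder tensor in both channels and spatial size. The plain-Python trace below follows the shapes through the listed layers. It assumes that `BasicBlock(in, out)` keeps the spatial size and returns `out` channels, which the listing does not show; the helper names are hypothetical.

```python
def pixel_shuffle_shape(c, h, w, r=2):
    """Output shape of nn.PixelShuffle(r) for a (c, h, w) feature map."""
    assert c % (r * r) == 0
    return c // (r * r), h * r, w * r


def trace_unet_skips(h, w):
    """Return (skip_shape, upsampled_shape) pairs for the three additive skips."""
    x0 = (32, h, w)              # pre + conv00
    x1 = (64, h // 2, w // 2)    # down0 (MaxPool2d floors odd sizes) + conv10
    x2 = (128, h // 4, w // 4)   # down1 + conv20
    x3 = (512, h // 8, w // 8)   # down2 + conv30 + conv31
    up2 = pixel_shuffle_shape(*x3)           # 512 -> 128 channels
    up1 = pixel_shuffle_shape(256, *x2[1:])  # conv21 output -> 64 channels
    up0 = pixel_shuffle_shape(128, *x1[1:])  # conv11 output -> 32 channels
    return [(x2, up2), (x1, up1), (x0, up0)]


for size in [(64, 64), (100, 100)]:
    ok = all(skip == up for skip, up in trace_unet_skips(*size))
    print(size, "skips line up:", ok)
# (64, 64) skips line up: True
# (100, 100) skips line up: False
```

Under these assumptions, and unlike NAFNet, this forward pass does no padding, so inputs whose height or width is not a multiple of 8 would need to be padded or cropped by the caller.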

end2end_imaging.network.Restormer

Restormer(inp_channels=3, out_channels=3, dim=48, num_blocks=[4, 6, 6, 8], num_refinement_blocks=4, heads=[1, 2, 4, 8], ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', dual_pixel_task=False)

Bases: Module

Restormer: Efficient Transformer for high-resolution image restoration.

A multi-scale encoder-decoder transformer using Multi-DConv Head Transposed Self-Attention (MDTA) and Gated-DConv Feed-Forward Networks (GDFN). Unless dual_pixel_task is set, the input image is added to the output as a global residual, so inp_channels and out_channels should match.

Reference: Zamir et al., "Restormer: Efficient Transformer for High-Resolution Image Restoration" (CVPR 2022).
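The "transposed" part of MDTA means that attention is computed across channels rather than across pixels. The sketch below shows only that core step as described in the Restormer paper. It is not the code in restormer.py: it leaves out the 1x1 and depthwise convolutions that produce the queries, keys and values, the learned per-head temperature, and the output projection. The function name `transposed_attention` is hypothetical.

```python
import torch
import torch.nn.functional as F


def transposed_attention(q, k, v, num_heads, temperature=1.0):
    """Channel-wise ("transposed") attention on (B, C, H, W) tensors."""
    b, c, h, w = q.shape
    # Fold spatial positions into one axis: (B, heads, C // heads, H * W).
    q = q.reshape(b, num_heads, c // num_heads, h * w)
    k = k.reshape(b, num_heads, c // num_heads, h * w)
    v = v.reshape(b, num_heads, c // num_heads, h * w)

    # L2-normalise along the spatial axis before the dot product.
    q = F.normalize(q, dim=-1)
    k = F.normalize(k, dim=-1)

    # The attention map is (C // heads) x (C // heads), independent of H * W.
    attn = (q @ k.transpose(-2, -1)) * temperature
    attn = attn.softmax(dim=-1)

    out = attn @ v
    return out.reshape(b, c, h, w), attn


x = torch.randn(1, 48, 32, 32)
out, attn = transposed_attention(x, x, x, num_heads=1)
print(out.shape, attn.shape)  # torch.Size([1, 48, 32, 32]) torch.Size([1, 1, 48, 48])
```

Since the attention map size depends on the channel count and not on the image size, its cost grows linearly with the number of pixels, which is the property that makes this design practical for high-resolution restoration.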

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| inp_channels | | Number of input channels. Defaults to 3. | 3 |
| out_channels | | Number of output channels. Defaults to 3. | 3 |
| dim | | Base embedding dimension. Defaults to 48. | 48 |
| num_blocks | | Number of transformer blocks per encoder/decoder stage. Defaults to [4, 6, 6, 8]. | [4, 6, 6, 8] |
| num_refinement_blocks | | Number of refinement blocks after the decoder. Defaults to 4. | 4 |
| heads | | Number of attention heads per stage. Defaults to [1, 2, 4, 8]. | [1, 2, 4, 8] |
| ffn_expansion_factor | | Hidden dimension multiplier in GDFN. Defaults to 2.66. | 2.66 |
| bias | | Whether to use bias in convolutions. Defaults to False. | False |
| LayerNorm_type | | "WithBias" or "BiasFree". Defaults to "WithBias". | 'WithBias' |
| dual_pixel_task | | If True, replaces the global input residual with a feature-level skip (skip_conv applied to the patch embedding) for dual-pixel defocus deblurring; also set inp_channels=6. Defaults to False. | False |

Source code in end2endimaging-src/end2end_imaging/network/reconstruction/restormer.py
def __init__(
    self,
    inp_channels=3,
    out_channels=3,
    dim=48,
    num_blocks=[4, 6, 6, 8],
    num_refinement_blocks=4,
    heads=[1, 2, 4, 8],
    ffn_expansion_factor=2.66,
    bias=False,
    LayerNorm_type="WithBias",  ## Other option 'BiasFree'
    dual_pixel_task=False,  ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
):
    super(Restormer, self).__init__()

    self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

    self.encoder_level1 = nn.Sequential(
        *[
            TransformerBlock(
                dim=dim,
                num_heads=heads[0],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_blocks[0])
        ]
    )

    self.down1_2 = Downsample(dim)  ## From Level 1 to Level 2
    self.encoder_level2 = nn.Sequential(
        *[
            TransformerBlock(
                dim=int(dim * 2**1),
                num_heads=heads[1],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_blocks[1])
        ]
    )

    self.down2_3 = Downsample(int(dim * 2**1))  ## From Level 2 to Level 3
    self.encoder_level3 = nn.Sequential(
        *[
            TransformerBlock(
                dim=int(dim * 2**2),
                num_heads=heads[2],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_blocks[2])
        ]
    )

    self.down3_4 = Downsample(int(dim * 2**2))  ## From Level 3 to Level 4
    self.latent = nn.Sequential(
        *[
            TransformerBlock(
                dim=int(dim * 2**3),
                num_heads=heads[3],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_blocks[3])
        ]
    )

    self.up4_3 = Upsample(int(dim * 2**3))  ## From Level 4 to Level 3
    self.reduce_chan_level3 = nn.Conv2d(
        int(dim * 2**3), int(dim * 2**2), kernel_size=1, bias=bias
    )
    self.decoder_level3 = nn.Sequential(
        *[
            TransformerBlock(
                dim=int(dim * 2**2),
                num_heads=heads[2],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_blocks[2])
        ]
    )

    self.up3_2 = Upsample(int(dim * 2**2))  ## From Level 3 to Level 2
    self.reduce_chan_level2 = nn.Conv2d(
        int(dim * 2**2), int(dim * 2**1), kernel_size=1, bias=bias
    )
    self.decoder_level2 = nn.Sequential(
        *[
            TransformerBlock(
                dim=int(dim * 2**1),
                num_heads=heads[1],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_blocks[1])
        ]
    )

    self.up2_1 = Upsample(
        int(dim * 2**1)
    )  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)

    self.decoder_level1 = nn.Sequential(
        *[
            TransformerBlock(
                dim=int(dim * 2**1),
                num_heads=heads[0],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_blocks[0])
        ]
    )

    self.refinement = nn.Sequential(
        *[
            TransformerBlock(
                dim=int(dim * 2**1),
                num_heads=heads[0],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_refinement_blocks)
        ]
    )

    #### For Dual-Pixel Defocus Deblurring Task ####
    self.dual_pixel_task = dual_pixel_task
    if self.dual_pixel_task:
        self.skip_conv = nn.Conv2d(dim, int(dim * 2**1), kernel_size=1, bias=bias)
    ###########################

    self.output = nn.Conv2d(
        int(dim * 2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias
    )

forward

forward(inp_img)

Forward pass with global residual connection.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| inp_img | | Input image tensor of shape (B, inp_channels, H, W). | required |

Returns:

| Type | Description |
| --- | --- |
| | Restored image tensor of shape (B, out_channels, H, W). |

Source code in end2endimaging-src/end2end_imaging/network/reconstruction/restormer.py
def forward(self, inp_img):
    """Forward pass with global residual connection.

    Args:
        inp_img: Input image tensor of shape ``(B, inp_channels, H, W)``.

    Returns:
        Restored image tensor of shape ``(B, out_channels, H, W)``.
    """
    inp_enc_level1 = self.patch_embed(inp_img)
    out_enc_level1 = self.encoder_level1(inp_enc_level1)

    inp_enc_level2 = self.down1_2(out_enc_level1)
    out_enc_level2 = self.encoder_level2(inp_enc_level2)

    inp_enc_level3 = self.down2_3(out_enc_level2)
    out_enc_level3 = self.encoder_level3(inp_enc_level3)

    inp_enc_level4 = self.down3_4(out_enc_level3)
    latent = self.latent(inp_enc_level4)

    inp_dec_level3 = self.up4_3(latent)
    inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
    inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
    out_dec_level3 = self.decoder_level3(inp_dec_level3)

    inp_dec_level2 = self.up3_2(out_dec_level3)
    inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
    inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
    out_dec_level2 = self.decoder_level2(inp_dec_level2)

    inp_dec_level1 = self.up2_1(out_dec_level2)
    inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
    out_dec_level1 = self.decoder_level1(inp_dec_level1)

    out_dec_level1 = self.refinement(out_dec_level1)

    #### For Dual-Pixel Defocus Deblurring Task ####
    if self.dual_pixel_task:
        out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
        out_dec_level1 = self.output(out_dec_level1)
    ###########################
    else:
        out_dec_level1 = self.output(out_dec_level1) + inp_img

    return out_dec_level1

Loss Functions

Differentiable image-quality losses for training reconstruction networks.

end2end_imaging.network.PerceptualLoss

PerceptualLoss(device=None, weights=[1.0, 1.0, 1.0, 1.0, 1.0])

Bases: Module

Perceptual loss based on VGG16 features.

Initialize perceptual loss.

Parameters:

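| Name | Type | Description | Default |
| --- | --- | --- | --- |
| device | | Device to put the VGG model on. If None, uses cuda if available. | None |
| weights | | Per-layer weights, applied in order to the tapped VGG16 feature maps (relu1_2, relu2_2, relu3_3, relu4_3, relu5_3). | [1.0, 1.0, 1.0, 1.0, 1.0] |

Source code in end2endimaging-src/end2end_imaging/network/loss/perceptual_loss.py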
def __init__(self, device=None, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
    """Initialize perceptual loss.

    Args:
        device: Device to put the VGG model on. If None, uses cuda if available.
        weights: Weights for different feature layers.
    """
    super(PerceptualLoss, self).__init__()

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    self.vgg = models.vgg16(weights=VGG16_Weights.DEFAULT).features.to(device)
    self.layer_name_mapping = {
        '3': "relu1_2",
        '8': "relu2_2",
        '15': "relu3_3",
        '22': "relu4_3",
        '29': "relu5_3"
    }

    self.weights = weights

    for param in self.vgg.parameters():
        param.requires_grad = False

forward

forward(x, y)

Calculate perceptual loss.

Parameters:

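| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | | Predicted tensor. | required |
| y | | Target tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| | Perceptual loss: the weighted sum over feature layers of the mean squared difference between the VGG16 features of x and y. |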

Source code in end2endimaging-src/end2end_imaging/network/loss/perceptual_loss.py
def forward(self, x, y):
    """Calculate perceptual loss.

    Args:
        x: Predicted tensor.
        y: Target tensor.

    Returns:
        Perceptual loss.
    """
    x_vgg, y_vgg = self._get_features(x), self._get_features(y)

    content_loss = 0.0
    for i, (key, value) in enumerate(x_vgg.items()):
        content_loss += self.weights[i] * torch.mean((value - y_vgg[key]) ** 2)

    return content_loss
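The weighted feature-space MSE computed above can be tried without downloading VGG16 weights. The sketch below swaps in a tiny frozen convolution stack as the feature extractor, which is purely an assumption for illustration. The tap names, `layer_weights`, and `get_features` are hypothetical and only mimic what `_get_features` and `layer_name_mapping` appear to do.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the frozen VGG16 feature stack (assumption: avoids a weight download).
extractor = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
)
for param in extractor.parameters():
    param.requires_grad = False

tap_layers = {"1": "relu_a", "3": "relu_b"}  # child index -> feature name (hypothetical)
layer_weights = [1.0, 0.5]


def get_features(x):
    """Run x through the stack and keep the outputs of the tapped layers."""
    features = {}
    for name, module in extractor.named_children():
        x = module(x)
        if name in tap_layers:
            features[tap_layers[name]] = x
    return features


def perceptual_loss_sketch(pred, target):
    """Weighted sum of per-layer feature MSEs, mirroring the loop in forward."""
    pred_feats, target_feats = get_features(pred), get_features(target)
    loss = 0.0
    for i, (key, value) in enumerate(pred_feats.items()):
        loss = loss + layer_weights[i] * torch.mean((value - target_feats[key]) ** 2)
    return loss


pred = torch.rand(1, 3, 32, 32, requires_grad=True)
target = torch.rand(1, 3, 32, 32)
loss = perceptual_loss_sketch(pred, target)
loss.backward()  # gradients reach `pred` even though the extractor is frozen
```

Freezing the extractor's parameters stops them from being updated, but gradients still flow through it to the prediction, which is what lets the loss train the reconstruction network.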

end2end_imaging.network.PSNRLoss

PSNRLoss(loss_weight=1.0, reduction='mean', toY=False)

Bases: Module

Peak Signal-to-Noise Ratio (PSNR) loss.

Initialize PSNR loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss_weight | | Weight for the loss. | 1.0 |
| reduction | | Reduction method; only "mean" is supported. | 'mean' |
| toY | | Whether to convert RGB to Y channel. | False |

Source code in end2endimaging-src/end2end_imaging/network/loss/psnr_loss.py
def __init__(self, loss_weight=1.0, reduction="mean", toY=False):
    """Initialize PSNR loss.

    Args:
        loss_weight: Weight for the loss.
        reduction: Reduction method, only "mean" is supported.
        toY: Whether to convert RGB to Y channel.
    """
    super(PSNRLoss, self).__init__()
    assert reduction == "mean"
    self.loss_weight = loss_weight
    self.scale = 10 / np.log(10)
    self.toY = toY
    self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
    self.first = True

forward

forward(pred, target)

Calculate PSNR loss.

Parameters:

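| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pred | | Predicted tensor. | required |
| target | | Target tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| | Scalar loss: loss_weight × (10 / ln 10) × ln(MSE + 1e-8), computed per image and averaged over the batch. For inputs in [0, 1] this is the negative PSNR in dB, so minimizing it maximizes PSNR. |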

Source code in end2endimaging-src/end2end_imaging/network/loss/psnr_loss.py
def forward(self, pred, target):
    """Calculate PSNR loss.

    Args:
        pred: Predicted tensor.
        target: Target tensor.

    Returns:
        PSNR loss.
    """
    assert len(pred.size()) == 4
    if self.toY:
        if self.first:
            self.coef = self.coef.to(pred.device)
            self.first = False

        pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.0
        target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.0

        pred, target = pred / 255.0, target / 255.0
        pass
    assert len(pred.size()) == 4

    return (
        self.loss_weight
        * self.scale
        * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
    ) 
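The returned expression is the negative of the usual PSNR when images are scaled to [0, 1]. The standalone check below reproduces the formula with the standard library only; both function names are hypothetical, and the peak value of 1.0 is an assumption about the input range.

```python
import math


def psnr_loss_from_mse(mse, loss_weight=1.0, eps=1e-8):
    """Per-image value of the expression returned by forward, given the image MSE."""
    scale = 10 / math.log(10)  # turns a natural log into 10 * log10
    return loss_weight * scale * math.log(mse + eps)


def psnr_db(mse, peak=1.0):
    """Conventional PSNR in dB for signals with the given peak value."""
    return 10 * math.log10(peak ** 2 / mse)


for mse in (1e-2, 1e-3, 1e-4):
    print(mse, round(psnr_loss_from_mse(mse), 2), round(psnr_db(mse), 2))
# 0.01 -20.0 20.0
# 0.001 -30.0 30.0
# 0.0001 -40.0 40.0
```

The `1e-8` term keeps the logarithm finite for a perfect reconstruction, which bounds the loss at about -80 for `loss_weight=1.0`.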

end2end_imaging.network.SSIMLoss

SSIMLoss(window_size=11, size_average=True)

Bases: Module

Structural Similarity Index (SSIM) loss.

Initialize SSIM loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| window_size | | Size of the window. | 11 |
| size_average | | Whether to average the loss. | True |

Source code in end2endimaging-src/end2end_imaging/network/loss/ssim_loss.py
def __init__(self, window_size=11, size_average=True):
    """Initialize SSIM loss.

    Args:
        window_size: Size of the window.
        size_average: Whether to average the loss.
    """
    super(SSIMLoss, self).__init__()
    self.window_size = window_size
    self.size_average = size_average
    self.channel = 1
    self.window = self._create_window(window_size, self.channel)

forward

forward(pred, target)

Calculate SSIM loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pred | | Predicted tensor. | required |
| target | | Target tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| | 1 - SSIM value. |

Source code in end2endimaging-src/end2end_imaging/network/loss/ssim_loss.py
def forward(self, pred, target):
    """Calculate SSIM loss.

    Args:
        pred: Predicted tensor.
        target: Target tensor.

    Returns:
        1 - SSIM value.
    """
    return 1 - self._ssim(pred, target)
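The private `_ssim` helper is not shown on this page. For orientation, the sketch below is a common Gaussian-window SSIM formulation, written under several assumptions: a window sigma of 1.5, the constants 0.01² and 0.03² for inputs in [0, 1], and zero padding at the borders. The library's implementation may differ in any of these, and the function names here are hypothetical.

```python
import torch
import torch.nn.functional as F


def gaussian_window(window_size=11, sigma=1.5, channels=3):
    """Normalised 2-D Gaussian kernel, one copy per channel for a grouped conv."""
    coords = torch.arange(window_size, dtype=torch.float32) - window_size // 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g = g / g.sum()
    window_2d = torch.outer(g, g)
    return window_2d.expand(channels, 1, window_size, window_size).contiguous()


def ssim_sketch(pred, target, window_size=11, c1=0.01 ** 2, c2=0.03 ** 2):
    """Mean SSIM over all pixels and channels for inputs in [0, 1]."""
    channels = pred.shape[1]
    window = gaussian_window(window_size, channels=channels).to(pred)
    pad = window_size // 2

    # Local means, variances and covariance via Gaussian-weighted averages.
    mu_x = F.conv2d(pred, window, padding=pad, groups=channels)
    mu_y = F.conv2d(target, window, padding=pad, groups=channels)
    sigma_x = F.conv2d(pred * pred, window, padding=pad, groups=channels) - mu_x ** 2
    sigma_y = F.conv2d(target * target, window, padding=pad, groups=channels) - mu_y ** 2
    sigma_xy = F.conv2d(pred * target, window, padding=pad, groups=channels) - mu_x * mu_y

    numerator = (2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (sigma_x + sigma_y + c2)
    return (numerator / denominator).mean()


torch.manual_seed(0)
a = torch.rand(1, 3, 32, 32)
b = torch.rand(1, 3, 32, 32)
loss_same = 1 - ssim_sketch(a, a)  # close to 0 for identical images
loss_diff = 1 - ssim_sketch(a, b)  # clearly larger for unrelated noise
```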