Network
The end2end_imaging.network module provides image reconstruction networks and loss functions for end-to-end training.
Looking for PSF surrogate networks?
The PSF surrogate models (MLP, SIREN, etc.) are part of DeepLens — see Surrogate Networks in the DeepLens API reference.
Reconstruction Networks
Image restoration networks that recover a clean image from a degraded (aberrated) sensor capture.
end2end_imaging.network.NAFNet
NAFNet(in_chan=3, out_chan=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 28], dec_blk_nums=[1, 1, 1, 1])
Bases: Module
Nonlinear Activation Free Network for image restoration.
A U-Net-style encoder-decoder with NAFBlocks that replace nonlinear activations with SimpleGate (element-wise multiplication of channel-split halves). Includes a global residual connection from input to output.
Reference: "Simple Baselines for Image Restoration" (ECCV 2022).
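A minimal sketch of the SimpleGate operation described above, for illustration only. The class name `SimpleGate` and the use of `torch.chunk` along the channel dimension are assumptions about how this module implements the gate; they are not taken from nafnet.py.

```python
import torch
import torch.nn as nn


class SimpleGate(nn.Module):
    """Split the channels into two halves and multiply them element-wise."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=1)  # (B, 2C, H, W) -> two (B, C, H, W) halves
        return x1 * x2              # the gating product stands in for a nonlinear activation


gate = SimpleGate()
x = torch.randn(2, 8, 16, 16)
y = gate(x)
print(y.shape)  # torch.Size([2, 4, 16, 16])
```

Because the gate halves the channel count, the layer feeding it has to produce twice the desired number of channels.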
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_chan` | | Number of input channels. Defaults to 3. | `3` |
| `out_chan` | | Number of output channels. Defaults to 3. | `3` |
| `width` | | Base channel width. Defaults to 32. | `32` |
| `middle_blk_num` | | Number of NAFBlocks in the bottleneck. Defaults to 1. | `1` |
| `enc_blk_nums` | | Number of NAFBlocks per encoder stage. Defaults to `[1, 1, 1, 28]`. | `[1, 1, 1, 28]` |
| `dec_blk_nums` | | Number of NAFBlocks per decoder stage. Defaults to `[1, 1, 1, 1]`. | `[1, 1, 1, 1]` |
Source code in end2endimaging-src/end2end_imaging/network/reconstruction/nafnet.py
initialize_weights
Initialize all module weights.
Uses truncated-normal initialization (std 0.02) for conv and linear layers per the NAFNet paper, sets BatchNorm to identity scale, and zeros the final conv so that the global residual yields an exact identity mapping on the first `out_chan` input channels at the start of training.
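A minimal sketch of this initialization scheme on a toy network. `TinyResidualNet` is hypothetical and not part of the library; the sketch also assumes that conv/linear biases are zeroed and that the global residual adds `inp[:, :out_chan]`, as the docstring suggests. It shows why zeroing the final conv makes the untrained network an exact identity on those channels.

```python
import torch
import torch.nn as nn


class TinyResidualNet(nn.Module):
    """Hypothetical stand-in for a restoration net with a global residual."""

    def __init__(self, in_chan: int = 3, out_chan: int = 3, width: int = 16):
        super().__init__()
        self.out_chan = out_chan
        self.intro = nn.Conv2d(in_chan, width, 3, padding=1)
        self.body = nn.Sequential(nn.Conv2d(width, width, 3, padding=1), nn.BatchNorm2d(width))
        self.ending = nn.Conv2d(width, out_chan, 3, padding=1)  # final conv
        self.initialize_weights()

    def initialize_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)  # truncated normal, std 0.02
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)  # identity scale
                nn.init.zeros_(m.bias)
        # Zero the final conv so the learned branch outputs exactly 0 at init.
        nn.init.zeros_(self.ending.weight)
        nn.init.zeros_(self.ending.bias)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        x = self.ending(self.body(self.intro(inp)))
        return x + inp[:, : self.out_chan]  # global residual on the first out_chan channels


net = TinyResidualNet(in_chan=4, out_chan=3)
inp = torch.rand(1, 4, 8, 8)
out = net(inp)
print(torch.equal(out, inp[:, :3]))  # True
```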
forward
Forward pass with global residual connection.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `inp` | | Input image tensor of shape `(B, C, H, W)`. | required |
Returns:
| Type | Description |
|---|---|
| | Restored image tensor of shape `(B, out_chan, H, W)`. |
check_image_size
Pad the input so its spatial dims are divisible by padder_size.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | | Input tensor of shape `(B, C, H, W)`. | required |
Returns:
| Type | Description |
|---|---|
| | Zero-padded tensor whose height and width are multiples of `padder_size`. |
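A minimal sketch of the pad-then-crop pattern, assuming padding is added only on the bottom and right edges and that `padder_size` is passed in explicitly. In the upstream NAFNet code `padder_size` is `2 ** len(enc_blk_nums)`; whether this module derives it the same way is an assumption.

```python
import torch
import torch.nn.functional as F


def check_image_size(x: torch.Tensor, padder_size: int) -> torch.Tensor:
    """Zero-pad the bottom/right of x so H and W become multiples of padder_size."""
    _, _, h, w = x.shape
    mod_pad_h = (padder_size - h % padder_size) % padder_size
    mod_pad_w = (padder_size - w % padder_size) % padder_size
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    return F.pad(x, (0, mod_pad_w, 0, mod_pad_h))


x = torch.rand(1, 3, 37, 50)
padded = check_image_size(x, padder_size=16)
print(padded.shape)  # torch.Size([1, 3, 48, 64])
# After running the network body on `padded`, crop back to the original size:
restored = padded[:, :, : x.shape[2], : x.shape[3]]
```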
end2end_imaging.network.UNet
Bases: Module
U-Net with residual skip connections for image restoration.
A 3-level encoder-decoder with dense BasicBlocks and PixelShuffle upsampling. Uses additive skip connections between encoder and decoder stages.
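A sketch of one decoder step that combines PixelShuffle upsampling with an additive skip connection, as the description above outlines. `UpAddSkip`, the 3x3 expansion conv, and the channel counts are hypothetical; unet.py may arrange its BasicBlocks differently.

```python
import torch
import torch.nn as nn


class UpAddSkip(nn.Module):
    """Hypothetical decoder step: PixelShuffle x2 upsampling plus an additive skip."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        # Expand channels by 4 so PixelShuffle(2) can trade them for 2x resolution.
        self.expand = nn.Conv2d(in_ch, out_ch * 4, kernel_size=3, padding=1)
        self.shuffle = nn.PixelShuffle(2)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        up = self.shuffle(self.expand(x))  # (B, in_ch, H, W) -> (B, out_ch, 2H, 2W)
        return up + skip                   # additive (not concatenated) skip connection


dec = UpAddSkip(in_ch=64, out_ch=32)
bottleneck = torch.rand(1, 64, 8, 8)
encoder_feat = torch.rand(1, 32, 16, 16)
print(dec(bottleneck, encoder_feat).shape)  # torch.Size([1, 32, 16, 16])
```

Adding rather than concatenating the skip keeps the decoder's channel count unchanged, so no extra conv is needed to fuse the two feature maps.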
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_channels` | | Number of input channels. Defaults to 3. | `3` |
| `out_channels` | | Number of output channels. Defaults to 3. | `3` |
Source code in end2endimaging-src/end2end_imaging/network/reconstruction/unet.py
forward
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | | Input image tensor of shape `(B, C, H, W)`. | required |
Returns:
| Type | Description |
|---|---|
| | Output tensor of shape `(B, out_channels, H, W)`. |
end2end_imaging.network.Restormer
Restormer(inp_channels=3, out_channels=3, dim=48, num_blocks=[4, 6, 6, 8], num_refinement_blocks=4, heads=[1, 2, 4, 8], ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', dual_pixel_task=False)
Bases: Module
Restormer: Efficient Transformer for high-resolution image restoration.
A multi-scale encoder-decoder transformer built from Multi-Dconv Head Transposed Attention (MDTA) and Gated-Dconv Feed-Forward Network (GDFN) blocks. Includes a global residual connection from input to output.
Reference: Zamir et al., "Restormer: Efficient Transformer for High-Resolution Image Restoration" (CVPR 2022).
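A compact sketch of the two building blocks named above, written from the published Restormer design rather than from restormer.py. The class names, the learnable `temperature`, and the omission of the LayerNorm and residual wrapping around each block are simplifications and assumptions. The key idea of MDTA is that attention is computed across channels (a C×C map per head) instead of across pixels, so its cost grows linearly with image size.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MDTA(nn.Module):
    """Transposed (channel-wise) self-attention with depth-wise conv projections."""

    def __init__(self, dim: int, num_heads: int, bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))  # learnable scale
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dw = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        q, k, v = self.qkv_dw(self.qkv(x)).chunk(3, dim=1)
        # Reshape to (B, heads, C/heads, H*W): attention runs over channels, not pixels.
        q, k, v = (t.reshape(b, self.num_heads, c // self.num_heads, h * w) for t in (q, k, v))
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature  # (B, heads, C/h, C/h)
        out = attn.softmax(dim=-1) @ v                        # (B, heads, C/h, H*W)
        return self.project_out(out.reshape(b, c, h, w))


class GDFN(nn.Module):
    """Gated feed-forward: depth-wise conv, then GELU(gate) * value."""

    def __init__(self, dim: int, ffn_expansion_factor: float = 2.66, bias: bool = False):
        super().__init__()
        hidden = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden * 2, hidden * 2, kernel_size=3, padding=1, groups=hidden * 2, bias=bias)
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(x1) * x2)


x = torch.rand(1, 48, 16, 16)
y = GDFN(48)(MDTA(48, num_heads=1)(x))
print(y.shape)  # torch.Size([1, 48, 16, 16])
```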
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `inp_channels` | | Number of input channels. Defaults to 3. | `3` |
| `out_channels` | | Number of output channels. Defaults to 3. | `3` |
| `dim` | | Base embedding dimension. Defaults to 48. | `48` |
| `num_blocks` | | Number of transformer blocks per encoder/decoder stage. Defaults to `[4, 6, 6, 8]`. | `[4, 6, 6, 8]` |
| `num_refinement_blocks` | | Number of refinement blocks after the decoder. Defaults to 4. | `4` |
| `heads` | | Number of attention heads per stage. Defaults to `[1, 2, 4, 8]`. | `[1, 2, 4, 8]` |
| `ffn_expansion_factor` | | Hidden dimension multiplier in GDFN. Defaults to 2.66. | `2.66` |
| `bias` | | Whether to use bias in convolutions. Defaults to False. | `False` |
| `LayerNorm_type` | | | `'WithBias'` |
| `dual_pixel_task` | | If True, uses a skip connection for dual-pixel defocus deblurring (set …). | `False` |
Source code in end2endimaging-src/end2end_imaging/network/reconstruction/restormer.py
forward
Forward pass with global residual connection.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `inp_img` | | Input image tensor of shape `(B, C, H, W)`. | required |
Returns:
| Type | Description |
|---|---|
| | Restored image tensor of shape `(B, out_channels, H, W)`. |
Loss Functions
Differentiable image-quality losses for training reconstruction networks.
end2end_imaging.network.PerceptualLoss
Bases: Module
Perceptual loss based on VGG16 features.
Initialize perceptual loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `device` | | Device to put the VGG model on. If None, uses `cuda` if available. | `None` |
| `weights` | | Weights for the different feature layers. | `[1.0, 1.0, 1.0, 1.0, 1.0]` |
Source code in end2endimaging-src/end2end_imaging/network/loss/perceptual_loss.py
forward
Calculate perceptual loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | | Predicted tensor. | required |
| `y` | | Target tensor. | required |
Returns:
| Type | Description |
|---|---|
| | Perceptual loss. |
end2end_imaging.network.PSNRLoss
Bases: Module
Peak Signal-to-Noise Ratio (PSNR) loss.
Initialize PSNR loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `loss_weight` | | Weight for the loss. | `1.0` |
| `reduction` | | Reduction method; only `"mean"` is supported. | `'mean'` |
| `toY` | | Whether to convert RGB inputs to the Y (luma) channel. | `False` |
Source code in end2endimaging-src/end2end_imaging/network/loss/psnr_loss.py
forward
Calculate PSNR loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pred` | | Predicted tensor. | required |
| `target` | | Target tensor. | required |
Returns:
| Type | Description |
|---|---|
| | PSNR loss. |
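For images in [0, 1], PSNR = 10·log10(1 / MSE) = -10·log10(MSE), so minimizing 10·log10(MSE) maximizes PSNR. The sketch below implements that idea, modeled on the PSNR loss formulation common in NAFNet-style training code (per-image MSE, a small epsilon, mean over the batch). Whether psnr_loss.py matches it exactly is an assumption; the class name `PSNRLossSketch` and the `eps` value are hypothetical, and the `toY` conversion is omitted.

```python
import math

import torch
import torch.nn as nn


class PSNRLossSketch(nn.Module):
    """loss = loss_weight * batch_mean(10 * log10(MSE + eps)), i.e. -PSNR for [0, 1] images."""

    def __init__(self, loss_weight: float = 1.0, eps: float = 1e-8):
        super().__init__()
        self.loss_weight = loss_weight
        self.scale = 10.0 / math.log(10.0)  # turns ln(.) into 10 * log10(.)
        self.eps = eps

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        mse = ((pred - target) ** 2).mean(dim=(1, 2, 3))  # per-image MSE, shape (B,)
        return self.loss_weight * self.scale * torch.log(mse + self.eps).mean()


target = torch.zeros(2, 3, 8, 8)
pred = target + 0.1  # MSE = 0.01, i.e. PSNR = 20 dB
print(PSNRLossSketch()(pred, target).item())  # approximately -20.0
```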
end2end_imaging.network.SSIMLoss
Bases: Module
Structural Similarity Index (SSIM) loss.
Initialize SSIM loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `window_size` | | Size of the window. | `11` |
| `size_average` | | Whether to average the loss. | `True` |
Source code in end2endimaging-src/end2end_imaging/network/loss/ssim_loss.py
forward
Calculate SSIM loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `pred` | | Predicted tensor. | required |
| `target` | | Target tensor. | required |
Returns:
| Type | Description |
|---|---|
| | `1 - SSIM` value. |
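A minimal sketch of a windowed `1 - SSIM` loss. It assumes the widely used pytorch-ssim conventions: a Gaussian window with σ = 1.5, constants C1 = 0.01² and C2 = 0.03² for a data range of 1, zero padding at the borders, and `size_average=False` returning one value per image. These constants and the function names are assumptions and were not read from ssim_loss.py.

```python
import torch
import torch.nn.functional as F


def gaussian_window(window_size: int, sigma: float = 1.5) -> torch.Tensor:
    """2-D Gaussian window of shape (window_size, window_size) that sums to 1."""
    coords = torch.arange(window_size, dtype=torch.float32) - (window_size - 1) / 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g = g / g.sum()
    return g[:, None] @ g[None, :]  # outer product of the 1-D window with itself


def ssim_loss(pred: torch.Tensor, target: torch.Tensor,
              window_size: int = 11, size_average: bool = True) -> torch.Tensor:
    c = pred.shape[1]
    # One copy of the window per channel: weight shape (C, 1, k, k) for a grouped conv.
    win = gaussian_window(window_size).to(pred)[None, None].repeat(c, 1, 1, 1)
    pad = window_size // 2

    def filt(x: torch.Tensor) -> torch.Tensor:  # per-channel local Gaussian average
        return F.conv2d(x, win, padding=pad, groups=c)

    mu_x, mu_y = filt(pred), filt(target)
    sigma_x = filt(pred * pred) - mu_x ** 2
    sigma_y = filt(target * target) - mu_y ** 2
    sigma_xy = filt(pred * target) - mu_x * mu_y
    c1, c2 = 0.01 ** 2, 0.03 ** 2  # stabilizing constants for a data range of 1
    ssim_map = ((2 * mu_x * mu_y + c1) * (2 * sigma_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (sigma_x + sigma_y + c2)
    )
    ssim = ssim_map.mean() if size_average else ssim_map.mean(dim=(1, 2, 3))
    return 1 - ssim


img = torch.rand(1, 3, 32, 32)
print(ssim_loss(img, img).item())  # close to 0.0
```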