
Sensor

The end2end_imaging.sensor module provides differentiable sensor models with noise simulation and a full image signal processing (ISP) pipeline.

Sensor Models

Base sensor class shared by all sensor models.

end2end_imaging.sensor.Sensor

Sensor(size=(8.0, 6.0), res=(4000, 3000))

Bases: Module

Minimal image sensor with gamma-only ISP.

The simplest sensor model: records physical size and resolution, and applies only a gamma correction in the ISP forward pass. For a sensor with noise simulation and Bayer demosaicing use RGBSensor.

Attributes:

Name | Type | Description
size | tuple | Physical sensor size (W, H) [mm].
res | tuple | Pixel resolution (W, H).
pixel_size | float | Physical pixel pitch [mm], size[0] / res[0].
isp | Sequential | ISP pipeline (GammaCorrection by default).

Example

sensor = Sensor(size=(8.0, 6.0), res=(4000, 3000))
sensor = Sensor.from_config("sensor.json")

Initialize a minimal sensor.

Parameters:

Name | Type | Description | Default
size | tuple | Physical sensor size (W, H) [mm]. | (8.0, 6.0)
res | tuple | Pixel resolution (W, H). | (4000, 3000)
Source code in end2endimaging-src/end2end_imaging/sensor/sensor.py
def __init__(self, size=(8.0, 6.0), res=(4000, 3000)):
    """Initialize a minimal sensor.

    Args:
        size (tuple, optional): Physical sensor size (W, H) [mm].
            Defaults to ``(8.0, 6.0)``.
        res (tuple, optional): Pixel resolution (W, H).
            Defaults to ``(4000, 3000)``.
    """
    super().__init__()

    # Sensor size and resolution
    self.size = size
    self.res = res
    self.pixel_size = size[0] / res[0]  # mm per pixel

    # ISP: gamma correction only
    self.isp = nn.Sequential(
        GammaCorrection(),
    )
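
As a quick check of the attribute arithmetic, the sketch below recomputes the pixel pitch the same way the constructor does (`size[0] / res[0]`). The helper name `pixel_pitch_mm` is hypothetical and not part of the library; it also reports the vertical pitch, which the constructor shown above does not store.

```python
def pixel_pitch_mm(size, res):
    """Return (horizontal, vertical) pixel pitch in mm.

    Hypothetical helper: the Sensor constructor only stores the
    horizontal value, size[0] / res[0].
    """
    pitch_w = size[0] / res[0]
    pitch_h = size[1] / res[1]
    return pitch_w, pitch_h


# Default Sensor geometry: 8.0 mm x 6.0 mm at 4000 x 3000 pixels.
pitch_w, pitch_h = pixel_pitch_mm((8.0, 6.0), (4000, 3000))
print(f"{pitch_w * 1000:.1f} um x {pitch_h * 1000:.1f} um")
```

With the defaults both directions give the same pitch. For non-square pixels the two values would differ, and only the horizontal one ends up in `pixel_size`.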

from_config classmethod

from_config(sensor_file)

Create a Sensor from a JSON config file.

Parameters:

Name Type Description Default
sensor_file

Path to JSON sensor config file.

required

Returns:

Type Description

Sensor instance.

Source code in end2endimaging-src/end2end_imaging/sensor/sensor.py
@classmethod
def from_config(cls, sensor_file):
    """Create a Sensor from a JSON config file.

    Args:
        sensor_file: Path to JSON sensor config file.

    Returns:
        Sensor instance.
    """
    with open(sensor_file, "r") as f:
        config = json.load(f)

    return cls(
        size=config.get("sensor_size", (8.0, 6.0)),
        res=config.get("sensor_res", (4000, 3000)),
    )
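
A minimal sketch of a config file that `Sensor.from_config` could read, assuming only the two keys visible in the source above (`sensor_size`, `sensor_res`). It mirrors the lookup logic with the standard library instead of importing the package, so the file name and values are illustrative only.

```python
import json
import os
import tempfile

# Keys read by Sensor.from_config; both are optional there.
config = {"sensor_size": [6.4, 4.8], "sensor_res": [3200, 2400]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sensor.json")  # illustrative file name
    with open(path, "w") as f:
        json.dump(config, f)

    # Same read pattern as from_config, without the library.
    with open(path, "r") as f:
        loaded = json.load(f)

size = loaded.get("sensor_size", (8.0, 6.0))
res = loaded.get("sensor_res", (4000, 3000))

# When a key is absent, the tuple default is used instead.
fallback_size = {}.get("sensor_size", (8.0, 6.0))
```

Note that JSON arrays are decoded as Python lists, so `size` is a list when it comes from the file and a tuple when the default applies.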

to

to(device)

Move the sensor and its ISP pipeline to a device.

Parameters:

Name Type Description Default
device

Target device (e.g. torch.device("cuda")).

required

Returns:

Name Type Description
Sensor

This sensor instance, for call chaining.

Source code in end2endimaging-src/end2end_imaging/sensor/sensor.py
def to(self, device):
    """Move the sensor and its ISP pipeline to a device.

    Args:
        device: Target device (e.g. ``torch.device("cuda")``).

    Returns:
        Sensor: This sensor instance, for call chaining.
    """
    self.device = device
    self.isp.to(device)
    return self

response_curve

response_curve(img_irr)

Apply response curve to the irradiance image to get the raw image.

Default is identity (linear response).

Parameters:

Name Type Description Default
img_irr

Irradiance image

required

Returns:

Name Type Description
img_raw

Raw image

Source code in end2endimaging-src/end2end_imaging/sensor/sensor.py
def response_curve(self, img_irr):
    """Apply response curve to the irradiance image to get the raw image.

    Default is identity (linear response).

    Args:
        img_irr: Irradiance image

    Returns:
        img_raw: Raw image
    """
    return img_irr

unprocess

unprocess(img)

Inverse ISP: convert sRGB image back to linear RGB.

Parameters:

Name Type Description Default
img

Tensor of shape (B, C, H, W), range [0, 1] in sRGB space.

required

Returns:

Name Type Description
img_linear

Tensor of shape (B, C, H, W), range [0, 1] in linear space.

Source code in end2endimaging-src/end2end_imaging/sensor/sensor.py
def unprocess(self, img):
    """Inverse ISP: convert sRGB image back to linear RGB.

    Args:
        img: Tensor of shape (B, C, H, W), range [0, 1] in sRGB space.

    Returns:
        img_linear: Tensor of shape (B, C, H, W), range [0, 1] in linear space.
    """
    # Inverse gamma correction (isp[0] is GammaCorrection)
    return self.isp[0].reverse(img)
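
The exact curve used by `GammaCorrection` is not shown in this section. As an illustration of what a forward/reverse pair looks like, the sketch below assumes a plain power law with exponent 2.2 (the `gamma_param` default of `RGBSensor`); the function names are hypothetical and the real module may use a different curve.

```python
GAMMA = 2.2  # assumed exponent, not confirmed for GammaCorrection


def gamma_forward(x, gamma=GAMMA):
    """Linear [0, 1] -> display-referred [0, 1] (assumed power law)."""
    return x ** (1.0 / gamma)


def gamma_reverse(y, gamma=GAMMA):
    """Display-referred [0, 1] -> linear [0, 1]; inverse of gamma_forward."""
    return y ** gamma


linear = 0.18                      # a mid-grey value in linear light
encoded = gamma_forward(linear)    # brighter than the linear value
recovered = gamma_reverse(encoded)  # round trip back to linear
```

Under this assumption the reverse step is an exact inverse on [0, 1], which is the property `unprocess` relies on.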

linrgb2raw

linrgb2raw(img_linear)

Convert linear RGB image to raw sensor space.

For the base Sensor, raw is the linear image itself (identity).

Parameters:

Name Type Description Default
img_linear

Tensor of shape (B, C, H, W), range [0, 1].

required

Returns:

Name Type Description
img_raw

Tensor of shape (B, C, H, W), range [0, 1].

Source code in end2endimaging-src/end2end_imaging/sensor/sensor.py
def linrgb2raw(self, img_linear):
    """Convert linear RGB image to raw sensor space.

    For the base Sensor, raw is the linear image itself (identity).

    Args:
        img_linear: Tensor of shape (B, C, H, W), range [0, 1].

    Returns:
        img_raw: Tensor of shape (B, C, H, W), range [0, 1].
    """
    return img_linear

simu_noise

simu_noise(img)

Simulate sensor noise.

Default is identity (no noise).

Parameters:

Name Type Description Default
img

Input image

required

Returns:

Name Type Description
img

Same image unchanged

Source code in end2endimaging-src/end2end_imaging/sensor/sensor.py
def simu_noise(self, img):
    """Simulate sensor noise.

    Default is identity (no noise).

    Args:
        img: Input image

    Returns:
        img: Same image unchanged
    """
    return img

Full RGB sensor with Bayer pattern, noise model (read noise + shot noise), and ISP pipeline (black level compensation, white balance, demosaicing, color correction, gamma).

end2end_imaging.sensor.RGBSensor

RGBSensor(size=(36.0, 24.0), res=(5472, 3648), bit=10, black_level=64, bayer_pattern='rggb', white_balance=(2.0, 1.0, 1.8), color_matrix=None, gamma_param=2.2, iso_base=100, read_noise_std=0.5, shot_noise_std_alpha=0.4, shot_noise_std_beta=0.0, wavelengths=None, red_response=None, green_response=None, blue_response=None)

Bases: Sensor

RGB Bayer-pattern sensor with physics-based noise model and invertible ISP.

Simulates the full image-capture pipeline from linear photon counts to display-ready sRGB:

  1. Spectral integration – optional per-channel spectral response.
  2. Bayer mosaic – pixel-level colour filtering to a single-channel raw image.
  3. Noise – shot noise (signal-dependent Gaussian) + read noise (signal-independent Gaussian), both scaled by the ISO gain and added to the n-bit raw data.
  4. ISP (forward) – via an InvertibleISP: black-level correction → white balance → colour matrix → demosaicing → gamma correction.

The ISP is invertible: unprocess() converts sRGB back to linear RGB for training data generation.

Attributes:

Name | Type | Description
bit | int | ADC bit depth.
nbit_max | int | Maximum digital number 2**bit - 1.
black_level | int | Black level pedestal [DN].
bayer_pattern | str | Bayer pattern (e.g. "rggb").
iso_base | int | Base ISO, at which the ISO gain (iso / iso_base) is 1.
readnoise_std | float | Read-noise standard deviation [DN].
shotnoise_std_alpha | float | Shot-noise scale coefficient.
shotnoise_std_beta | float | Shot-noise offset coefficient.
isp | InvertibleISP | Embedded invertible ISP pipeline.

Initialize an RGB sensor with a physics-based noise model and invertible ISP.

Parameters:

Name | Type | Description | Default
size | tuple | Sensor physical size in mm (W, H). | (36.0, 24.0)
res | tuple | Sensor resolution in pixels (W, H). | (5472, 3648)
bit | int | Bit depth. | 10
black_level | int | Black level. | 64
bayer_pattern | str | Bayer pattern, e.g. "rggb". | 'rggb'
white_balance | tuple | White balance gains. | (2.0, 1.0, 1.8)
color_matrix | list or Tensor | Color correction matrix. | None
gamma_param | float | Gamma correction parameter. | 2.2
iso_base | int | Base ISO. | 100
read_noise_std | float | Read noise std. | 0.5
shot_noise_std_alpha | float | Shot noise alpha. | 0.4
shot_noise_std_beta | float | Shot noise beta. | 0.0
wavelengths | list | Wavelengths. | None
red_response | list | Red channel spectral response. | None
green_response | list | Green channel spectral response. | None
blue_response | list | Blue channel spectral response. | None
Source code in end2endimaging-src/end2end_imaging/sensor/rgb_sensor.py
def __init__(
    self,
    size=(36.0, 24.0),
    res=(5472, 3648),
    bit=10,
    black_level=64,
    bayer_pattern="rggb",
    white_balance=(2.0, 1.0, 1.8),
    color_matrix=None,
    gamma_param=2.2,
    iso_base=100,
    read_noise_std=0.5,
    shot_noise_std_alpha=0.4,
    shot_noise_std_beta=0.0,
    wavelengths=None,
    red_response=None,
    green_response=None,
    blue_response=None,
):
    """Initialize an RGB sensor with a physics-based noise model and invertible ISP.

    Args:
        size (tuple): Sensor physical size in mm (W, H). Default (36.0, 24.0).
        res (tuple): Sensor resolution in pixels (W, H). Default (5472, 3648).
        bit (int): Bit depth. Default 10.
        black_level (int): Black level. Default 64.
        bayer_pattern (str): Bayer pattern e.g. "rggb". Default "rggb".
        white_balance (tuple): White balance gains. Default (2.0, 1.0, 1.8).
        color_matrix (list or Tensor): Color correction matrix.
        gamma_param (float): Gamma correction parameter. Default 2.2.
        iso_base (int): Base ISO. Default 100.
        read_noise_std (float): Read noise std. Default 0.5.
        shot_noise_std_alpha (float): Shot noise alpha. Default 0.4.
        shot_noise_std_beta (float): Shot noise beta. Default 0.0.
        wavelengths (list): Wavelengths.
        red_response (list): Red channel spectral response.
        green_response (list): Green channel spectral response.
        blue_response (list): Blue channel spectral response.
    """
    super().__init__(size=size, res=res)

    self.bit = bit
    self.nbit_max = 2**bit - 1
    self.black_level = black_level
    self.bayer_pattern = bayer_pattern

    # Noise parameters
    self.iso_base = iso_base
    self.readnoise_std = read_noise_std
    self.shotnoise_std_alpha = shot_noise_std_alpha
    self.shotnoise_std_beta = shot_noise_std_beta

    # Spectral response curves
    self.wavelengths = wavelengths
    if self.wavelengths is not None:
        green_sum = sum(green_response)
        self.red_response = torch.tensor(red_response) / green_sum
        self.green_response = torch.tensor(green_response) / green_sum
        self.blue_response = torch.tensor(blue_response) / green_sum

    # ISP
    self.isp = InvertibleISP(
        bit=bit,
        black_level=black_level,
        bayer_pattern=bayer_pattern,
        white_balance=white_balance,
        color_matrix=color_matrix,
        gamma_param=gamma_param,
    )
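
The constructor divides all three spectral curves by the sum of the green curve. The sketch below reproduces that normalization on plain lists; the helper name and the three-sample curves are made up for illustration.

```python
def normalize_by_green(red, green, blue):
    """Scale all three curves by the sum of the green curve."""
    green_sum = sum(green)

    def scale(curve):
        return [value / green_sum for value in curve]

    return scale(red), scale(green), scale(blue)


# Made-up responses sampled at three wavelengths.
red, green, blue = normalize_by_green(
    red=[0.2, 0.4, 1.0],
    green=[0.5, 1.0, 0.5],
    blue=[1.0, 0.4, 0.2],
)
```

Because the three curves share one divisor, the green curve sums to 1 afterwards while red and blue keep their sensitivity relative to green.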

from_config classmethod

from_config(sensor_file)

Create an RGBSensor from a JSON config file.

Parameters:

Name Type Description Default
sensor_file

Path to JSON sensor config file.

required

Returns:

Type Description

RGBSensor instance.

Source code in end2endimaging-src/end2end_imaging/sensor/rgb_sensor.py
@classmethod
def from_config(cls, sensor_file):
    """Create an RGBSensor from a JSON config file.

    Args:
        sensor_file: Path to JSON sensor config file.

    Returns:
        RGBSensor instance.
    """
    with open(sensor_file, "r") as f:
        config = json.load(f)

    return cls(
        size=config["sensor_size"],
        res=config["sensor_res"],
        bit=config["bit"],
        black_level=config["black_level"],
        bayer_pattern=config.get("bayer_pattern", "rggb"),
        white_balance=config.get("white_balance_d50", (2.0, 1.0, 1.8)),
        color_matrix=config.get("color_matrix_d50", None),
        gamma_param=config.get("gamma_param", 2.2),
        iso_base=config.get("iso_base", 100),
        read_noise_std=config.get("read_noise_std", 0.5),
        shot_noise_std_alpha=config.get("shot_noise_std_alpha", 0.4),
        shot_noise_std_beta=config.get("shot_noise_std_beta", 0.0),
        wavelengths=config.get("wavelengths", None),
        red_response=config.get("red_spectral_response", None),
        green_response=config.get("green_spectral_response", None),
        blue_response=config.get("blue_spectral_response", None),
    )
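
Unlike `Sensor.from_config`, the listing above indexes four keys directly (`sensor_size`, `sensor_res`, `bit`, `black_level`), so a config without them raises `KeyError`. The sketch below shows an illustrative config and a hypothetical pre-check helper, `missing_required_keys`, that is not part of the library.

```python
import json

# Keys accessed with config[...] in RGBSensor.from_config.
REQUIRED_KEYS = ("sensor_size", "sensor_res", "bit", "black_level")


def missing_required_keys(config):
    """Return the required keys that are absent from a config dict."""
    return [key for key in REQUIRED_KEYS if key not in config]


# Illustrative config; the values copy the constructor defaults.
config_text = """
{
  "sensor_size": [36.0, 24.0],
  "sensor_res": [5472, 3648],
  "bit": 10,
  "black_level": 64,
  "white_balance_d50": [2.0, 1.0, 1.8],
  "iso_base": 100
}
"""
config = json.loads(config_text)
missing = missing_required_keys(config)
```

Some JSON key names differ from the constructor arguments: `white_balance_d50` maps to `white_balance`, `color_matrix_d50` to `color_matrix`, and `red_spectral_response` (and its green and blue counterparts) to `red_response`.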

to

to(device)

Move the sensor, ISP pipeline, and spectral-response tensors to a device.

Parameters:

Name Type Description Default
device

Target device (e.g. torch.device("cuda")).

required

Returns:

Name Type Description
RGBSensor

This sensor instance, for call chaining.

Source code in end2endimaging-src/end2end_imaging/sensor/rgb_sensor.py
def to(self, device):
    """Move the sensor, ISP pipeline, and spectral-response tensors to a device.

    Args:
        device: Target device (e.g. ``torch.device("cuda")``).

    Returns:
        RGBSensor: This sensor instance, for call chaining.
    """
    super().to(device)
    if self.wavelengths is not None:
        self.red_response = self.red_response.to(device)
        self.green_response = self.green_response.to(device)
        self.blue_response = self.blue_response.to(device)
    return self

response_curve

response_curve(img_spectral)

Apply response curve to the spectral image to get the raw image.

Parameters:

Name Type Description Default
img_spectral

Spectral image, shape (B, C, H, W), range [0, 1]

required

Returns:

Name Type Description
img_raw

Raw image, shape (B, 3, H, W), range [0, 1]

Reference

[1] Spectral Sensitivity Estimation Without a Camera. ICCP 2023.
[2] https://github.com/COLOR-Lab-Eilat/Spectral-sensitivity-estimation

Source code in end2endimaging-src/end2end_imaging/sensor/rgb_sensor.py
def response_curve(self, img_spectral):
    """Apply response curve to the spectral image to get the raw image.

    Args:
        img_spectral: Spectral image, shape (B, C, H, W), range [0, 1]

    Returns:
        img_raw: Raw image, shape (B, 3, H, W), range [0, 1]

    Reference:
        [1] Spectral Sensitivity Estimation Without a Camera. ICCP 2023.
        [2] https://github.com/COLOR-Lab-Eilat/Spectral-sensitivity-estimation
    """
    if self.wavelengths is not None:
        img_raw = torch.zeros(
            (
                img_spectral.shape[0],
                3,
                img_spectral.shape[2],
                img_spectral.shape[3],
            ),
            device=img_spectral.device,
        )
        img_raw[:, 0, :, :] = (
            img_spectral * self.red_response.view(1, -1, 1, 1)
        ).sum(dim=1)
        img_raw[:, 1, :, :] = (
            img_spectral * self.green_response.view(1, -1, 1, 1)
        ).sum(dim=1)
        img_raw[:, 2, :, :] = (
            img_spectral * self.blue_response.view(1, -1, 1, 1)
        ).sum(dim=1)
    else:
        assert img_spectral.shape[1] == 3, (
            "No spectral response curves provided, input image must have 3 channels"
        )
        img_raw = img_spectral

    return img_raw
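
For a single pixel, the spectral branch above reduces to a weighted sum of the spectral samples with each channel's response. A plain-Python sketch of that per-pixel operation follows; the helper name and all numbers are made up for illustration.

```python
def integrate_spectrum(spectrum, response):
    """Weighted sum of one pixel's spectral samples."""
    if len(spectrum) != len(response):
        raise ValueError("spectrum and response must have the same length")
    return sum(s * r for s, r in zip(spectrum, response))


# Made-up responses at three wavelengths.
red_response = [0.05, 0.15, 0.60]
green_response = [0.10, 0.80, 0.10]
blue_response = [0.70, 0.10, 0.02]

pixel_spectrum = [0.2, 0.5, 0.9]
raw_rgb = [
    integrate_spectrum(pixel_spectrum, response)
    for response in (red_response, green_response, blue_response)
]
```

The library version does the same for every pixel at once by broadcasting the response over the channel dimension and summing along it.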

unprocess

unprocess(image, in_type='rgb')

Unprocess an image to unbalanced RAW RGB space.

Parameters:

Name Type Description Default
image

Tensor of shape (B, 3, H, W), range [0, 1]

required
in_type

Input image type, either "rgb" or "linear_rgb"

'rgb'

Returns:

Name Type Description
image

Tensor of shape (B, 3, H, W), range [0, 1] in raw space

Source code in end2endimaging-src/end2end_imaging/sensor/rgb_sensor.py
def unprocess(self, image, in_type="rgb"):
    """Unprocess an image to unbalanced RAW RGB space.

    Args:
        image: Tensor of shape (B, 3, H, W), range [0, 1]
        in_type: Input image type, either "rgb" or "linear_rgb"

    Returns:
        image: Tensor of shape (B, 3, H, W), range [0, 1] in raw space
    """
    isp = self.isp

    # Inverse gamma correction
    if in_type == "linear_rgb":
        pass
    elif in_type == "rgb":
        image = isp.gamma.reverse(image)
    else:
        raise ValueError(f"Invalid input type: {in_type}")

    # Inverse color correction matrix
    image = isp.ccm.reverse(image)

    # Inverse auto white balance
    image = isp.awb.reverse(image)  # (B, 3, H, W), [0, 1]

    return image

linrgb2raw

linrgb2raw(img_linrgb)

Convert linear RGB image to raw Bayer space.

Parameters:

Name Type Description Default
img_linrgb

Tensor of shape (B, 3, H, W), range [0, 1]

required

Returns:

Name Type Description
bayer_nbit

Tensor of shape (B, 1, H, W), range [~black_level, 2**bit - 1]

Source code in end2endimaging-src/end2end_imaging/sensor/rgb_sensor.py
def linrgb2raw(self, img_linrgb):
    """Convert linear RGB image to raw Bayer space.

    Args:
        img_linrgb: Tensor of shape (B, 3, H, W), range [0, 1]

    Returns:
        bayer_nbit: Tensor of shape (B, 1, H, W), range [~black_level, 2**bit - 1]
    """
    black_level = self.black_level
    bit = self.bit

    bayer_float = self.isp.demosaic.reverse(img_linrgb)
    bayer_nbit = bayer_float * (2**bit - 1 - black_level) + black_level
    bayer_nbit = torch.round(bayer_nbit)
    return bayer_nbit
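
The scaling step in `linrgb2raw` maps a normalized value onto the range between the black level and the maximum digital number. The sketch below applies the same formula to scalars, together with the inverse mapping used in `bayer2rggb`; the helper names are hypothetical.

```python
def to_digital_number(value, bit=10, black_level=64):
    """Map a normalized value in [0, 1] to an n-bit digital number."""
    nbit_max = 2**bit - 1
    return round(value * (nbit_max - black_level) + black_level)


def to_normalized(dn, bit=10, black_level=64):
    """Inverse mapping: n-bit digital number back to [0, 1]."""
    nbit_max = 2**bit - 1
    return (dn - black_level) / (nbit_max - black_level)


dark = to_digital_number(0.0)      # black level pedestal
bright = to_digital_number(1.0)    # full scale for a 10-bit sensor
quarter = to_digital_number(0.25)  # 0.25 * 959 + 64 = 303.75, rounded
```

Rounding makes the round trip approximate: the recovered value can differ from the input by up to half a quantization step.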

simu_noise

simu_noise(img_raw, iso)

Simulate sensor noise considering sensor quantization and noise model.

Parameters:

Name Type Description Default
img_raw

N-bit clean image, (B, C, H, W), range [0, 2**bit - 1]

required
iso

ISO values, tensor of shape (B,), range [0, 800]

required

Returns:

Name Type Description
img_raw_noisy

N-bit noisy image, (B, C, H, W), range [0, 2**bit - 1]

Reference

[1] "Unprocessing Images for Learned Raw Denoising."
[2] https://www.photonstophotos.net/Charts/RN_ADU.htm
[3] https://www.photonstophotos.net/Investigations/Measurement_and_Sample_Variation.htm
[4] https://www.dpreview.com/forums/thread/4669806

Source code in end2endimaging-src/end2end_imaging/sensor/rgb_sensor.py
def simu_noise(self, img_raw, iso):
    """Simulate sensor noise considering sensor quantization and noise model.

    Args:
        img_raw: N-bit clean image, (B, C, H, W), range [0, 2**bit - 1]
        iso: ISO values, tensor of shape (B,), range [0, 800]

    Returns:
        img_raw_noisy: N-bit noisy image, (B, C, H, W), range [0, 2**bit - 1]

    Reference:
        [1] "Unprocessing Images for Learned Raw Denoising."
        [2] https://www.photonstophotos.net/Charts/RN_ADU.htm
        [3] https://www.photonstophotos.net/Investigations/Measurement_and_Sample_Variation.htm
        [4] https://www.dpreview.com/forums/thread/4669806
    """
    nbit_max = self.nbit_max
    black_level = self.black_level
    device = img_raw.device

    # Calculate noise standard deviation
    shotnoise_std = torch.clamp(
        self.shotnoise_std_alpha * torch.sqrt(torch.clamp(img_raw - black_level, min=0.0))
        + self.shotnoise_std_beta,
        0.0,
    )
    if (iso > 800).any():
        raise ValueError(f"Currently noise model only works for low ISO <= 800, got {iso}")
    gain_analog = 1.0  # we only measured analog gain = 1.0
    gain_digit = (iso / self.iso_base).view(-1, 1, 1, 1)
    noise_std = torch.sqrt(
        shotnoise_std**2 * gain_digit * gain_analog
        + self.readnoise_std**2 * gain_digit**2
    )

    # Sample random noise
    noise_sample = (
        torch.normal(mean=0.0, std=1.0, size=img_raw.size(), device=device)
        * noise_std
    )
    img_raw_noisy = img_raw + noise_sample

    # Clip and quantize
    img_raw_noisy = torch.clip(img_raw_noisy, 0.0, nbit_max)
    img_raw_noisy = torch.round(img_raw_noisy)
    return img_raw_noisy
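
The standard deviation computed above combines the two noise sources as std = sqrt(shot_std**2 * gain + read_std**2 * gain**2), with gain = iso / iso_base and the analog gain fixed at 1. The scalar sketch below evaluates that expression for one pixel value using the constructor defaults; `noise_std` is a hypothetical helper, not a library function.

```python
import math


def noise_std(dn, iso, black_level=64, iso_base=100,
              read_std=0.5, alpha=0.4, beta=0.0):
    """Per-pixel noise standard deviation in digital numbers."""
    signal = max(dn - black_level, 0.0)           # clamp below black level
    shot_std = max(alpha * math.sqrt(signal) + beta, 0.0)
    gain = iso / iso_base                         # digital gain
    return math.sqrt(shot_std**2 * gain + read_std**2 * gain**2)


# A pixel 400 DN above the black level: shot_std = 0.4 * sqrt(400) = 8.
std_iso100 = noise_std(464, iso=100)
std_iso400 = noise_std(464, iso=400)
std_dark = noise_std(0, iso=100)  # only read noise remains
```

Raising the ISO from 100 to 400 roughly doubles the standard deviation for this pixel, since the shot-noise variance grows linearly with gain and dominates the read-noise term.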

sample_augmentation

sample_augmentation()

Randomly sample a set of augmentation parameters for ISP modules. Used for data augmentation during training.

Source code in end2endimaging-src/end2end_imaging/sensor/rgb_sensor.py
def sample_augmentation(self):
    """Randomly sample a set of augmentation parameters for ISP modules. Used for data augmentation during training."""
    self.isp.gamma.sample_augmentation()
    self.isp.ccm.sample_augmentation()
    self.isp.awb.sample_augmentation()

reset_augmentation

reset_augmentation()

Reset parameters for ISP modules. Used for evaluation.

Source code in end2endimaging-src/end2end_imaging/sensor/rgb_sensor.py
def reset_augmentation(self):
    """Reset parameters for ISP modules. Used for evaluation."""
    self.isp.gamma.reset_augmentation()
    self.isp.ccm.reset_augmentation()
    self.isp.awb.reset_augmentation()

process2rgb

process2rgb(image, in_type='rggb')

Process a raw image into an RGB image.

Parameters:

Name Type Description Default
image

Tensor of shape (B, 4, H/2, W/2), range [0, 1] for "rggb", or (B, 1, H, W), range [~black_level, 2**bit - 1] for "bayer"

required
in_type

Input image type, either "rggb" or "bayer"

'rggb'

Returns:

Name Type Description
image

Tensor of shape (B, 3, H, W), range [0, 1]

Source code in end2endimaging-src/end2end_imaging/sensor/rgb_sensor.py
def process2rgb(self, image, in_type="rggb"):
    """Process a raw image into an RGB image.

    Args:
        image: Tensor of shape (B, 4, H/2, W/2), range [0, 1] for "rggb", or
            (B, 1, H, W), range [~black_level, 2**bit - 1] for "bayer"
        in_type: Input image type, either "rggb" or "bayer"

    Returns:
        image: Tensor of shape (B, 3, H, W), range [0, 1]
    """
    # Process to RGB
    if in_type == "rggb":
        image = self.isp(self.rggb2bayer(image))
    elif in_type == "bayer":
        image = self.isp(image)
    else:
        raise ValueError(f"Invalid input type: {in_type}")

    return image

bayer2rggb

bayer2rggb(bayer_nbit)

Convert RAW Bayer image to RAW RGGB image.

Parameters:

Name Type Description Default
bayer_nbit

Tensor of shape (B, 1, H, W), range [~black_level, 2**bit - 1]

required

Returns:

Name Type Description
rggb

Tensor of shape (B, 4, H/2, W/2), range [0, 1]

Source code in end2endimaging-src/end2end_imaging/sensor/rgb_sensor.py
def bayer2rggb(self, bayer_nbit):
    """Convert RAW Bayer image to RAW RGGB image.

    Args:
        bayer_nbit: Tensor of shape (B, 1, H, W), range [~black_level, 2**bit - 1]

    Returns:
        rggb: Tensor of shape (B, 4, H/2, W/2), range [0, 1]
    """
    black_level = self.black_level
    bit = self.bit

    if len(bayer_nbit.shape) == 2:
        bayer_nbit = bayer_nbit.unsqueeze(0).unsqueeze(0)
        single_image = True
    else:
        single_image = False

    B, _, H, W = bayer_nbit.shape
    bayer_rggb = torch.zeros(
        (B, 4, H // 2, W // 2), dtype=bayer_nbit.dtype, device=bayer_nbit.device
    )

    bayer_rggb[:, 0, :, :] = bayer_nbit[:, 0, 0:H:2, 0:W:2]
    bayer_rggb[:, 1, :, :] = bayer_nbit[:, 0, 0:H:2, 1:W:2]
    bayer_rggb[:, 2, :, :] = bayer_nbit[:, 0, 1:H:2, 0:W:2]
    bayer_rggb[:, 3, :, :] = bayer_nbit[:, 0, 1:H:2, 1:W:2]

    # Data range [black_level, 2**bit - 1] -> [0, 1]
    rggb = (bayer_rggb - black_level) / (2**bit - 1 - black_level)

    if single_image:
        rggb = rggb.squeeze(0)

    return rggb

rggb2bayer

rggb2bayer(rggb)

Convert RGGB image to RAW Bayer.

Parameters:

Name Type Description Default
rggb

Tensor of shape [4, H/2, W/2] or [B, 4, H/2, W/2], range [0, 1]

required

Returns:

Name Type Description
bayer

Tensor of shape [1, H, W] or [B, 1, H, W], range [~black_level, 2**bit - 1]

Source code in end2endimaging-src/end2end_imaging/sensor/rgb_sensor.py
def rggb2bayer(self, rggb):
    """Convert RGGB image to RAW Bayer.

    Args:
        rggb: Tensor of shape [4, H/2, W/2] or [B, 4, H/2, W/2], range [0, 1]

    Returns:
        bayer: Tensor of shape [1, H, W] or [B, 1, H, W], range [~black_level, 2**bit - 1]
    """
    black_level = self.black_level
    bit = self.bit

    if len(rggb.shape) == 3:
        rggb = rggb.unsqueeze(0)
        single_image = True
    else:
        single_image = False

    B, _, H, W = rggb.shape
    bayer = torch.zeros((B, 1, H * 2, W * 2), dtype=rggb.dtype, device=rggb.device)

    bayer[:, 0, 0 : 2 * H : 2, 0 : 2 * W : 2] = rggb[:, 0, :, :]
    bayer[:, 0, 0 : 2 * H : 2, 1 : 2 * W : 2] = rggb[:, 1, :, :]
    bayer[:, 0, 1 : 2 * H : 2, 0 : 2 * W : 2] = rggb[:, 2, :, :]
    bayer[:, 0, 1 : 2 * H : 2, 1 : 2 * W : 2] = rggb[:, 3, :, :]

    # Data range [0, 1] -> [black_level, 2**bit - 1]
    # bayer = torch.round(bayer * (2**bit - 1 - black_level) + black_level)
    bayer = bayer * (2**bit - 1 - black_level) + black_level

    if single_image:
        bayer = bayer.squeeze(0)

    return bayer
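
`bayer2rggb` and `rggb2bayer` are inverse re-indexings of the same pixels: each 2x2 Bayer cell is split into four half-resolution planes and interleaved back. The sketch below shows that index pattern on nested lists, leaving out the black-level scaling; the helper names are hypothetical.

```python
def pack_bayer(mosaic):
    """Split an H x W mosaic (list of rows) into four H/2 x W/2 planes."""
    return [
        [row[0::2] for row in mosaic[0::2]],  # even rows, even columns
        [row[1::2] for row in mosaic[0::2]],  # even rows, odd columns
        [row[0::2] for row in mosaic[1::2]],  # odd rows, even columns
        [row[1::2] for row in mosaic[1::2]],  # odd rows, odd columns
    ]


def unpack_bayer(planes):
    """Interleave four H/2 x W/2 planes back into an H x W mosaic."""
    h, w = len(planes[0]), len(planes[0][0])
    mosaic = [[0] * (2 * w) for _ in range(2 * h)]
    for index, plane in enumerate(planes):
        dy, dx = divmod(index, 2)  # plane index -> offset inside the 2x2 cell
        for y in range(h):
            for x in range(w):
                mosaic[2 * y + dy][2 * x + dx] = plane[y][x]
    return mosaic


# 4x4 test mosaic whose values encode (row, column) as 10 * row + column.
mosaic = [[10 * r + c for c in range(4)] for r in range(4)]
planes = pack_bayer(mosaic)
restored = unpack_bayer(planes)
```

For an "rggb" layout, plane 0 holds the red samples, planes 1 and 2 the two green samples, and plane 3 the blue samples.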

Monochrome sensor without color filter array.

end2end_imaging.sensor.MonoSensor

MonoSensor(bit=10, black_level=64, size=(8.0, 6.0), res=(4000, 3000), read_noise_std=0.5, shot_noise_std_alpha=0.4, shot_noise_std_beta=0.0, iso_base=100, wavelengths=None, spectral_response=None)

Bases: Sensor

Monochrome sensor with noise simulation and ISP.

Initialize a monochrome sensor.

Parameters:

Name | Type | Description | Default
bit | int | Bit depth of the sensor. | 10
black_level | float | Black level value. | 64
size | tuple | Sensor physical size in mm (W, H). | (8.0, 6.0)
res | tuple | Sensor resolution in pixels (W, H). | (4000, 3000)
read_noise_std | float | Read noise standard deviation. | 0.5
shot_noise_std_alpha | float | Shot noise alpha parameter. | 0.4
shot_noise_std_beta | float | Shot noise beta parameter. | 0.0
iso_base | int | Base ISO value. | 100
wavelengths | list | Wavelengths for spectral response. | None
spectral_response | list | Spectral response values. | None
Source code in end2endimaging-src/end2end_imaging/sensor/mono_sensor.py
def __init__(
    self,
    bit=10,
    black_level=64,
    size=(8.0, 6.0),
    res=(4000, 3000),
    read_noise_std=0.5,
    shot_noise_std_alpha=0.4,
    shot_noise_std_beta=0.0,
    iso_base=100,
    wavelengths=None,
    spectral_response=None,
):
    """Initialize a monochrome sensor.

    Args:
        bit (int): Bit depth of the sensor. Default 10.
        black_level (float): Black level value. Default 64.
        size (tuple): Sensor physical size in mm (W, H). Default (8.0, 6.0).
        res (tuple): Sensor resolution in pixels (W, H). Default (4000, 3000).
        read_noise_std (float): Read noise standard deviation. Default 0.5.
        shot_noise_std_alpha (float): Shot noise alpha parameter. Default 0.4.
        shot_noise_std_beta (float): Shot noise beta parameter. Default 0.0.
        iso_base (int): Base ISO value. Default 100.
        wavelengths (list, optional): Wavelengths for spectral response.
        spectral_response (list, optional): Spectral response values.
    """
    super().__init__(size=size, res=res)

    self.bit = bit
    self.nbit_max = 2**bit - 1
    self.black_level = black_level

    # Sensor noise statistics (measured in n-bit digital value space)
    self.iso_base = iso_base
    self.readnoise_std = read_noise_std
    self.shotnoise_std_alpha = shot_noise_std_alpha
    self.shotnoise_std_beta = shot_noise_std_beta

    # Spectral response curve
    self.wavelengths = wavelengths
    if self.wavelengths is not None:
        response = torch.tensor(spectral_response, dtype=torch.float32)
        self.spectral_response = response / response.sum()

    # ISP: black level compensation + gamma
    self.isp = nn.Sequential(
        BlackLevelCompensation(bit, black_level),
        GammaCorrection(),
    )

from_config classmethod

from_config(sensor_file)

Create a MonoSensor from a JSON config file.

Parameters:

Name Type Description Default
sensor_file

Path to JSON sensor config file.

required

Returns:

Type Description

MonoSensor instance.

Source code in end2endimaging-src/end2end_imaging/sensor/mono_sensor.py
@classmethod
def from_config(cls, sensor_file):
    """Create a MonoSensor from a JSON config file.

    Args:
        sensor_file: Path to JSON sensor config file.

    Returns:
        MonoSensor instance.
    """
    with open(sensor_file, "r") as f:
        config = json.load(f)

    spectral_response = config.get("spectral_response", None)
    wavelengths = config.get("wavelengths", None) if spectral_response is not None else None

    return cls(
        size=config.get("sensor_size", (8.0, 6.0)),
        res=config.get("sensor_res", (4000, 3000)),
        bit=config.get("bit", 10),
        black_level=config.get("black_level", 64),
        iso_base=config.get("iso_base", 100),
        read_noise_std=config.get("read_noise_std", 0.5),
        shot_noise_std_alpha=config.get("shot_noise_std_alpha", 0.4),
        shot_noise_std_beta=config.get("shot_noise_std_beta", 0.0),
        wavelengths=wavelengths,
        spectral_response=spectral_response,
    )
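
One detail of the listing above is easy to miss: `wavelengths` is only forwarded when `spectral_response` is also present in the config. The sketch below isolates that rule in a hypothetical helper operating on a plain dict.

```python
def resolve_spectral_fields(config):
    """Return (wavelengths, spectral_response) as from_config would pass them."""
    spectral_response = config.get("spectral_response", None)
    # Wavelengths are dropped when no response curve accompanies them.
    wavelengths = (
        config.get("wavelengths", None) if spectral_response is not None else None
    )
    return wavelengths, spectral_response


only_wavelengths = resolve_spectral_fields({"wavelengths": [450, 550, 650]})
both = resolve_spectral_fields(
    {"wavelengths": [450, 550, 650], "spectral_response": [0.2, 0.6, 0.2]}
)
```

This avoids constructing a sensor that has wavelengths set but no response values to normalize.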

to

to(device)

Move the sensor, its ISP pipeline, and spectral response to a device.

Parameters:

Name Type Description Default
device

Target device (e.g. torch.device("cuda")).

required

Returns:

Name Type Description
MonoSensor

This sensor instance, for call chaining.

Source code in end2endimaging-src/end2end_imaging/sensor/mono_sensor.py
def to(self, device):
    """Move the sensor, its ISP pipeline, and spectral response to a device.

    Args:
        device: Target device (e.g. ``torch.device("cuda")``).

    Returns:
        MonoSensor: This sensor instance, for call chaining.
    """
    super().to(device)
    if self.wavelengths is not None:
        self.spectral_response = self.spectral_response.to(device)
    return self

response_curve

response_curve(img_spectral)

Apply spectral response curve to get a monochrome raw image.

Parameters:

Name Type Description Default
img_spectral

Spectral image, (B, N_wavelengths, H, W)

required

Returns:

Name Type Description
img_raw

Monochrome raw image, (B, 1, H, W)

Source code in end2endimaging-src/end2end_imaging/sensor/mono_sensor.py
def response_curve(self, img_spectral):
    """Apply spectral response curve to get a monochrome raw image.

    Args:
        img_spectral: Spectral image, (B, N_wavelengths, H, W)

    Returns:
        img_raw: Monochrome raw image, (B, 1, H, W)
    """
    if self.wavelengths is not None:
        img_raw = (
            img_spectral * self.spectral_response.view(1, -1, 1, 1)
        ).sum(dim=1, keepdim=True)
    else:
        if img_spectral.shape[1] == 1:
            img_raw = img_spectral
        else:
            # Average across channels as fallback
            img_raw = img_spectral.mean(dim=1, keepdim=True)

    return img_raw
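
The three branches above (weighted sum, pass-through, channel mean) can be illustrated for a single pixel as follows. This is a simplified sketch: it selects the branch on whether a response is passed, whereas the library checks `self.wavelengths`, and it normalizes the response inline rather than in the constructor. The helper name and numbers are made up.

```python
def mono_response(pixel_channels, spectral_response=None):
    """Collapse one pixel's channels to a single monochrome value."""
    if spectral_response is not None:
        total = sum(spectral_response)
        weights = [value / total for value in spectral_response]
        return sum(c * w for c, w in zip(pixel_channels, weights))
    if len(pixel_channels) == 1:
        return pixel_channels[0]  # already monochrome
    # Fallback: plain average across channels.
    return sum(pixel_channels) / len(pixel_channels)


weighted = mono_response([0.2, 0.4, 0.6], spectral_response=[1.0, 1.0, 2.0])
averaged = mono_response([0.2, 0.4, 0.6])
passthrough = mono_response([0.3])
```

Because the response is normalized to sum to 1, the weighted result stays within the range of the input channels.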

unprocess

unprocess(img)

Inverse ISP: convert gamma-corrected image back to linear space.

Parameters:

Name Type Description Default
img

Tensor of shape (B, C, H, W), range [0, 1] in display space.

required

Returns:

Name Type Description
img_linear

Tensor of shape (B, C, H, W), range [0, 1] in linear space.

Source code in end2endimaging-src/end2end_imaging/sensor/mono_sensor.py
def unprocess(self, img):
    """Inverse ISP: convert gamma-corrected image back to linear space.

    Args:
        img: Tensor of shape (B, C, H, W), range [0, 1] in display space.

    Returns:
        img_linear: Tensor of shape (B, C, H, W), range [0, 1] in linear space.
    """
    # Only reverse gamma correction
    # isp[0] = BlackLevelCompensation, isp[1] = GammaCorrection
    img_linear = self.isp[1].reverse(img)
    return img_linear

linrgb2raw

linrgb2raw(img_linear)

Convert linear image to n-bit raw digital number.

Applies spectral response (RGB to Mono) and quantization.

Parameters:

Name Type Description Default
img_linear

Tensor of shape (B, C, H, W), range [0, 1].

required

Returns:

Name Type Description
img_nbit

Tensor of shape (B, 1, H, W), range [~black_level, 2**bit - 1].

Source code in end2endimaging-src/end2end_imaging/sensor/mono_sensor.py
def linrgb2raw(self, img_linear):
    """Convert linear image to n-bit raw digital number.

    Applies spectral response (RGB to Mono) and quantization.

    Args:
        img_linear: Tensor of shape (B, C, H, W), range [0, 1].

    Returns:
        img_nbit: Tensor of shape (B, 1, H, W), range [~black_level, 2**bit - 1].
    """
    # 1. Apply spectral response (RGB -> Mono)
    img_mono = self.response_curve(img_linear)

    # 2. Scale and add black level
    img_nbit = img_mono * (self.nbit_max - self.black_level) + self.black_level

    # 3. Quantize
    img_nbit = torch.round(img_nbit)
    return img_nbit
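
The scale-and-quantize arithmetic in steps 2 and 3 can be checked by hand. The sketch below is a scalar, torch-free stand-in; `to_raw` and `to_linear` are hypothetical helper names, and the 10-bit depth and black level of 64 are assumed values borrowed from the ISP defaults elsewhere on this page, not confirmed defaults of this sensor class.

```python
def to_raw(x, bit=10, black_level=64):
    """Map a linear value in [0, 1] to an n-bit digital number (scalar sketch)."""
    nbit_max = 2**bit - 1
    return round(x * (nbit_max - black_level) + black_level)


def to_linear(dn, bit=10, black_level=64):
    """Undo the scaling; the rounding error itself is not recoverable."""
    nbit_max = 2**bit - 1
    return (dn - black_level) / (nbit_max - black_level)


x = 0.25
dn = to_raw(x)  # 0.25 * 959 + 64 = 303.75, rounded to 304
err = abs(to_linear(dn) - x)

# Quantization costs at most half a code value, expressed in linear units.
assert err <= 0.5 / (2**10 - 1 - 64)
```

Because rounding has zero gradient almost everywhere, a step like this is best kept out of any path that needs gradients.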

simu_noise

simu_noise(img_raw, iso)

Simulate sensor noise considering sensor quantization and noise model.

Parameters:

Name Type Description Default
img_raw

N-bit clean image, (B, C, H, W), range [0, 2**bit - 1]

required
iso

ISO value per batch element, shape (B,), range [0, 800]

required

Returns:

Name Type Description
img_raw_noisy

N-bit noisy image, (B, C, H, W), range [0, 2**bit - 1]

Raises:

Type Description
ValueError

If any ISO value exceeds 800, since the noise model is only calibrated for low ISO (<= 800).

Reference

[1] "Unprocessing Images for Learned Raw Denoising."
[2] https://www.photonstophotos.net/Charts/RN_ADU.htm
[3] https://www.photonstophotos.net/Investigations/Measurement_and_Sample_Variation.htm
[4] https://www.dpreview.com/forums/thread/4669806

Source code in end2endimaging-src/end2end_imaging/sensor/mono_sensor.py
def simu_noise(self, img_raw, iso):
    """Simulate sensor noise considering sensor quantization and noise model.

    Args:
        img_raw: N-bit clean image, (B, C, H, W), range [0, 2**bit - 1]
        iso: ISO value per batch element, shape (B,), range [0, 800]

    Returns:
        img_raw_noisy: N-bit noisy image, (B, C, H, W), range [0, 2**bit - 1]

    Raises:
        ValueError: If any ISO value exceeds 800, since the noise model is
            only calibrated for low ISO (<= 800).

    Reference:
        [1] "Unprocessing Images for Learned Raw Denoising."
        [2] https://www.photonstophotos.net/Charts/RN_ADU.htm
        [3] https://www.photonstophotos.net/Investigations/Measurement_and_Sample_Variation.htm
        [4] https://www.dpreview.com/forums/thread/4669806
    """
    nbit_max = self.nbit_max
    black_level = self.black_level
    device = img_raw.device

    # Calculate noise standard deviation
    shotnoise_std = torch.clamp(
        self.shotnoise_std_alpha * torch.sqrt(torch.clamp(img_raw - black_level, min=0.0))
        + self.shotnoise_std_beta,
        0.0,
    )
    if (iso > 800).any():
        raise ValueError(f"Currently noise model only works for low ISO <= 800, got {iso}")
    gain_analog = 1.0  # we only measured analog gain = 1.0
    gain_digit = (iso / self.iso_base).view(-1, 1, 1, 1)
    noise_std = torch.sqrt(
        shotnoise_std**2 * gain_digit * gain_analog
        + self.readnoise_std**2 * gain_digit**2
    )

    # Sample random noise
    noise_sample = (
        torch.normal(mean=0.0, std=1.0, size=img_raw.size(), device=device)
        * noise_std
    )
    img_raw_noisy = img_raw + noise_sample

    # Clip and quantize
    img_raw_noisy = torch.clip(img_raw_noisy, 0.0, nbit_max)
    img_raw_noisy = torch.round(img_raw_noisy)
    return img_raw_noisy
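
To see how the ISO gain enters the noise budget, the per-pixel standard deviation from the listing above can be evaluated for a single value. This is a scalar sketch only: `noise_std` is a hypothetical helper, and every default (`alpha`, `beta`, `read_std`, `iso_base`, `black_level`) is a made-up placeholder rather than a calibrated number from the library.

```python
import math


def noise_std(dn, iso, *, black_level=64.0, iso_base=100.0,
              alpha=0.5, beta=0.0, read_std=2.0):
    """Per-pixel noise standard deviation, mirroring the formula above.

    All default values are placeholders chosen for easy arithmetic.
    """
    if iso > 800:
        raise ValueError(f"noise model only covers ISO <= 800, got {iso}")
    signal = max(dn - black_level, 0.0)
    shot_std = max(alpha * math.sqrt(signal) + beta, 0.0)
    gain_analog = 1.0
    gain_digit = iso / iso_base
    # Shot variance scales linearly with gain, read variance quadratically.
    return math.sqrt(shot_std**2 * gain_digit * gain_analog
                     + read_std**2 * gain_digit**2)


# A pixel 100 DN above black: shot std = 0.5 * sqrt(100) = 5.
low = noise_std(164.0, iso=100)   # sqrt(25 + 4)
high = noise_std(164.0, iso=400)  # sqrt(25 * 4 + 4 * 16)
```

With these placeholder numbers the read-noise term grows from 4 to 64 between ISO 100 and ISO 400, while the shot term grows from 25 to 100.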

ISP Pipeline

The full, differentiable image signal processing pipeline that chains the modules below. RGBSensor uses it to turn raw sensor readings into an sRGB image, and its reverse method reconstructs an approximate raw Bayer image from sRGB.

end2end_imaging.sensor.isp_modules.isp.InvertibleISP

InvertibleISP(bit=10, black_level=64, bayer_pattern='rggb', white_balance=(2.0, 1.0, 1.8), color_matrix=None, gamma_param=2.2)

Bases: Module

Invertible and differentiable Bayer-sRGB ISP pipeline.

Reference

[1] Architectural Analysis of a Baseline ISP Pipeline. https://link.springer.com/chapter/10.1007/978-94-017-9987-4_2. (page 23, 50)

Initialize the invertible ISP pipeline.

Parameters:

Name Type Description Default
bit int

Bit depth of the input bayer image. Default 10.

10
black_level float

Black level value to subtract. Default 64.

64
bayer_pattern str

Bayer pattern of the input. Default "rggb".

'rggb'
white_balance tuple

Manual white balance gains (R, G, B). Default (2.0, 1.0, 1.8).

(2.0, 1.0, 1.8)
color_matrix optional

Color correction matrix. If None, the ColorCorrectionMatrix module uses its default.

None
gamma_param float

Gamma parameter. Default 2.2.

2.2
Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/isp.py
def __init__(
    self,
    bit=10,
    black_level=64,
    bayer_pattern="rggb",
    white_balance=(2.0, 1.0, 1.8),
    color_matrix=None,
    gamma_param=2.2,
):
    """Initialize the invertible ISP pipeline.

    Args:
        bit (int): Bit depth of the input bayer image. Default 10.
        black_level (float): Black level value to subtract. Default 64.
        bayer_pattern (str): Bayer pattern of the input. Default "rggb".
        white_balance (tuple): Manual white balance gains (R, G, B).
            Default (2.0, 1.0, 1.8).
        color_matrix (optional): Color correction matrix. If None, the
            ColorCorrectionMatrix module uses its default.
        gamma_param (float): Gamma parameter. Default 2.2.
    """
    super().__init__()

    self.bit = bit
    self.black_level = black_level
    self.bayer_pattern = bayer_pattern
    self.white_balance = white_balance
    self.color_matrix = color_matrix
    self.gamma_param = gamma_param

    self.blc = BlackLevelCompensation(bit=bit, black_level=black_level)
    self.demosaic = Demosaic(bayer_pattern=bayer_pattern, method="malvar")
    self.awb = AutoWhiteBalance(awb_method="manual", white_balance=white_balance)
    self.ccm = ColorCorrectionMatrix(ccm_matrix=color_matrix)
    self.gamma = GammaCorrection(gamma_param=gamma_param)

    self.isp = nn.Sequential(
        self.blc,
        self.demosaic,
        self.awb,
        self.ccm,
        self.gamma,
    )

forward

forward(bayer_nbit)

A basic differentiable and invertible ISP pipeline.

Parameters:

Name Type Description Default
bayer_nbit

Input tensor of shape [B, 1, H, W], data range [~black_level, 2^bit-1].

required

Returns:

Name Type Description
rgb

Output tensor of shape [B, 3, H, W], data range [0, 1].

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/isp.py
def forward(self, bayer_nbit):
    """A basic differentiable and invertible ISP pipeline.

    Args:
        bayer_nbit: Input tensor of shape [B, 1, H, W], data range [~black_level, 2^bit-1].

    Returns:
        rgb: Output tensor of shape [B, 3, H, W], data range [0, 1].
    """
    img = self.isp(bayer_nbit)
    return img

reverse

reverse(img)

Inverse ISP.

Parameters:

Name Type Description Default
img

Input tensor of shape [B, 3, H, W], data range [0, 1].

required

Returns:

Name Type Description
bayer_Nbit

Output tensor of shape [B, 1, H, W], data range [~black_level, 2^bit-1].

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/isp.py
def reverse(self, img):
    """Inverse ISP.

    Args:
        img: Input tensor of shape [B, 3, H, W], data range [0, 1].

    Returns:
        bayer_Nbit: Output tensor of shape [B, 1, H, W], data range [~black_level, 2^bit-1].
    """
    img = self.gamma.reverse(img)
    img = self.ccm.reverse(img)
    img = self.awb.reverse(img)
    bayer = self.demosaic.reverse(img)
    bayer = self.blc.reverse(bayer)
    return bayer
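
The forward and reverse passes mirror each other: reverse undoes the stages last to first. A minimal scalar sketch of that idea follows, assuming only three of the five stages (black level, a single white balance gain, gamma). Demosaic and CCM are left out, and all constants are illustrative.

```python
BIT, BLACK, GAIN, GAMMA = 10, 64, 2.0, 2.2
SCALE = 2**BIT - 1 - BLACK

# Each stage is a (forward, reverse) pair acting on one scalar sample.
STAGES = [
    (lambda x: (x - BLACK) / SCALE, lambda y: y * SCALE + BLACK),  # black level
    (lambda x: x * GAIN, lambda y: y / GAIN),                      # white balance
    (lambda x: x ** (1 / GAMMA), lambda y: y ** GAMMA),            # gamma
]


def forward(x):
    """Run the stages first to last."""
    for fwd, _ in STAGES:
        x = fwd(x)
    return x


def reverse(y):
    """Undo the stages last to first, as InvertibleISP.reverse does."""
    for _, rev in reversed(STAGES):
        y = rev(y)
    return y


raw = 300.0
srgb = forward(raw)
recovered = reverse(srgb)
```

In this reduced sketch the round trip is exact up to floating-point error. The full pipeline is only approximately invertible, because Demosaic.reverse keeps one channel per pixel and ColorCorrectionMatrix.reverse clamps its output.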

ISP Modules

Individual image signal processing stages used inside RGBSensor. Each module is a torch.nn.Module.

end2end_imaging.sensor.isp_modules.BlackLevelCompensation

BlackLevelCompensation(bit=10, black_level=64)

Bases: Module

Black level compensation (BLC).

Black level compensation subtracts the sensor's black level offset from the n-bit raw image and normalizes the result to [0, 1].

Initialize black level compensation.

Parameters:

Name Type Description Default
bit

Bit depth of the input image.

10
black_level

Black level value.

64
Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/black_level.py
def __init__(self, bit=10, black_level=64):
    """Initialize black level compensation.

    Args:
        bit: Bit depth of the input image.
        black_level: Black level value.
    """
    super().__init__()
    self.bit = bit
    self.black_level = black_level

forward

forward(bayer)

Black Level Compensation.

Parameters:

Name Type Description Default
bayer Tensor

Input n-bit bayer image [B, 1, H, W], data range [~black_level, 2**bit - 1].

required

Returns:

Name Type Description
bayer_float Tensor

Output float bayer image [B, 1, H, W], data range [0, 1].

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/black_level.py
def forward(self, bayer):
    """Black Level Compensation.

    Args:
        bayer (torch.Tensor): Input n-bit bayer image [B, 1, H, W], data range [~black_level, 2**bit - 1].

    Returns:
        bayer_float (torch.Tensor): Output float bayer image [B, 1, H, W], data range [0, 1].
    """
    # Subtract black level
    bayer_float = (bayer - self.black_level) / (2**self.bit - 1 - self.black_level)

    # Clamp to [0, 1]; only has an effect when the input falls outside [black_level, 2**bit - 1]
    bayer_float = torch.clamp(bayer_float, 0.0, 1.0)

    return bayer_float

reverse

reverse(bayer, quantize=False)

Inverse black level compensation.

Parameters:

Name Type Description Default
bayer

Input tensor of shape [B, 1, H, W], data range [0, 1].

required
quantize

If True, round to integer values (non-differentiable).

False

Returns:

Name Type Description
bayer_nbit

Output tensor of shape [B, 1, H, W], data range [black_level, 2^bit-1].

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/black_level.py
def reverse(self, bayer, quantize=False):
    """Inverse black level compensation.

    Args:
        bayer: Input tensor of shape [B, 1, H, W], data range [0, 1].
        quantize: If True, round to integer values (non-differentiable).

    Returns:
        bayer_nbit: Output tensor of shape [B, 1, H, W], data range [black_level, 2^bit-1].
    """
    max_value = 2**self.bit - 1
    bayer_nbit = bayer * (max_value - self.black_level) + self.black_level
    if quantize:
        # Note: torch.round() is not differentiable
        bayer_nbit = torch.round(bayer_nbit)
    return bayer_nbit
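
The forward clamp matters for the round trip: values inside the valid range survive it, while values below the black level do not. The scalar sketch below illustrates this with hypothetical helpers `blc_forward` and `blc_reverse` that copy the arithmetic of the two methods above.

```python
def blc_forward(dn, bit=10, black_level=64):
    """Normalize an n-bit value to [0, 1], clamping like the forward pass."""
    x = (dn - black_level) / (2**bit - 1 - black_level)
    return min(max(x, 0.0), 1.0)


def blc_reverse(x, bit=10, black_level=64, quantize=False):
    """Map a [0, 1] value back to digital numbers."""
    dn = x * (2**bit - 1 - black_level) + black_level
    return round(dn) if quantize else dn


# In-range values survive the round trip.
assert abs(blc_reverse(blc_forward(500.0)) - 500.0) < 1e-9
# Values below the black level are clamped, so they come back as black_level.
assert blc_reverse(blc_forward(40.0)) == 64.0
```

This situation can arise after noise simulation, where a dark pixel may land below the black level before the ISP runs.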

end2end_imaging.sensor.isp_modules.AutoWhiteBalance

AutoWhiteBalance(awb_method='gray_world', white_balance=(2.0, 1.0, 1.8))

Bases: Module

Auto white balance (AWB).

Initialize auto white balance.

Parameters:

Name Type Description Default
awb_method

AWB method, "gray_world" or "manual".

'gray_world'
white_balance

RGB white balance for manual AWB, shape [3].

(2.0, 1.0, 1.8)
Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/white_balance.py
def __init__(self, awb_method="gray_world", white_balance=(2.0, 1.0, 1.8)):
    """Initialize auto white balance.

    Args:
        awb_method: AWB method, "gray_world" or "manual".
        white_balance: RGB white balance for manual AWB, shape [3].
    """
    super().__init__()
    self.awb_method = awb_method
    self.register_buffer('white_balance', torch.tensor(white_balance))

sample_augmentation

sample_augmentation()

Sample augmentation for synthetic data generation.

Perturbs the white balance gains with Gaussian noise, caching the original gains on first call so they can be restored later.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/white_balance.py
def sample_augmentation(self):
    """Sample augmentation for synthetic data generation.

    Perturbs the white balance gains with Gaussian noise, caching the
    original gains on first call so they can be restored later.
    """
    if not hasattr(self, "white_balance_org"):
        self.white_balance_org = self.white_balance
    self.white_balance = self.white_balance_org + torch.randn_like(self.white_balance_org) * 0.1

reset_augmentation

reset_augmentation()

Reset augmentation for evaluation by restoring the original gains.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/white_balance.py
def reset_augmentation(self):
    """Reset augmentation for evaluation by restoring the original gains."""
    # No-op if sample_augmentation() has never been called.
    if hasattr(self, "white_balance_org"):
        self.white_balance = self.white_balance_org
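
The sample/reset pair follows a cache-and-restore pattern: cache the original value once, always perturb the cached copy, and restore it on reset. A torch-free sketch of that pattern is shown below. `GainAugment` is a hypothetical toy class, not part of the library, and it uses `random.gauss` in place of `torch.randn_like`.

```python
import random


class GainAugment:
    """Toy stand-in for the cache-and-restore augmentation pattern."""

    def __init__(self, gains=(2.0, 1.0, 1.8), sigma=0.1):
        self.gains = list(gains)
        self.sigma = sigma
        self._gains_org = None  # filled on the first sample_augmentation()

    def sample_augmentation(self):
        if self._gains_org is None:
            self._gains_org = list(self.gains)
        # Always perturb the cached originals so jitter does not accumulate.
        self.gains = [g + random.gauss(0.0, self.sigma) for g in self._gains_org]

    def reset_augmentation(self):
        # Guarded: resetting before any sample is a no-op.
        if self._gains_org is not None:
            self.gains = list(self._gains_org)


aug = GainAugment()
aug.reset_augmentation()  # safe even though nothing was sampled
for _ in range(3):
    aug.sample_augmentation()
aug.reset_augmentation()
```

Perturbing the cached originals rather than the current gains keeps repeated calls from drifting away from the nominal white balance.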

apply_awb_bayer

apply_awb_bayer(bayer)

Apply white balance to Bayer pattern image.

Parameters:

Name Type Description Default
bayer

Input tensor of shape [B, 1, H, W].

required

Returns:

Name Type Description
bayer_wb

Output tensor with same shape as input.

Raises:

Type Description
ValueError

If awb_method is not "gray_world" or "manual".

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/white_balance.py
def apply_awb_bayer(self, bayer):
    """Apply white balance to Bayer pattern image.

    Args:
        bayer: Input tensor of shape [B, 1, H, W].

    Returns:
        bayer_wb: Output tensor with same shape as input.

    Raises:
        ValueError: If awb_method is not "gray_world" or "manual".
    """
    B, _, H, W = bayer.shape

    # Create masks for R, G, B pixels (assuming RGGB pattern)
    r_mask = torch.zeros((H, W), device=bayer.device)
    g_mask = torch.zeros((H, W), device=bayer.device)
    b_mask = torch.zeros((H, W), device=bayer.device)

    r_mask[0::2, 0::2] = 1  # R at top-left
    g_mask[0::2, 1::2] = 1  # G at top-right
    g_mask[1::2, 0::2] = 1  # G at bottom-left
    b_mask[1::2, 1::2] = 1  # B at bottom-right

    # Apply masks to extract color channels
    r = bayer * r_mask.view(1, 1, H, W)
    g = bayer * g_mask.view(1, 1, H, W)
    b = bayer * b_mask.view(1, 1, H, W)

    if self.awb_method == "gray_world":
        # Calculate average for each channel (excluding zeros)
        r_avg = torch.sum(r, dim=[2, 3]) / torch.sum(r_mask)
        g_avg = torch.sum(g, dim=[2, 3]) / torch.sum(g_mask)
        b_avg = torch.sum(b, dim=[2, 3]) / torch.sum(b_mask)

        # Calculate white balance to make averages equal
        g_gain = torch.ones_like(g_avg)
        r_gain = g_avg / (r_avg + 1e-6)
        b_gain = g_avg / (b_avg + 1e-6)

        # Apply gains
        bayer_wb = bayer.clone()
        bayer_wb = bayer_wb * (
            r_mask.view(1, 1, H, W) * r_gain.view(B, 1, 1, 1)
            + g_mask.view(1, 1, H, W) * g_gain.view(B, 1, 1, 1)
            + b_mask.view(1, 1, H, W) * b_gain.view(B, 1, 1, 1)
        )

    elif self.awb_method == "manual":
        # Apply manual gains
        bayer_wb = bayer.clone()
        bayer_wb = bayer_wb * (
            r_mask.view(1, 1, H, W) * self.white_balance[0]
            + g_mask.view(1, 1, H, W) * self.white_balance[1]
            + b_mask.view(1, 1, H, W) * self.white_balance[2]
        )
    else:
        raise ValueError(f"Unknown AWB method: {self.awb_method}")

    return bayer_wb
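
The gray-world branch can be traced on a tiny mosaic. The sketch below recomputes the gains for an RGGB layout using nested lists instead of tensors and masks; `gray_world_gains` is a hypothetical helper and the 4x4 mosaic is invented test data.

```python
def gray_world_gains(bayer, eps=1e-6):
    """Gray-world gains for an RGGB mosaic given as a list of rows."""
    sums = {"r": 0.0, "g": 0.0, "b": 0.0}
    counts = {"r": 0, "g": 0, "b": 0}
    for y, row in enumerate(bayer):
        for x, value in enumerate(row):
            if y % 2 == 0 and x % 2 == 0:
                ch = "r"  # top-left of each 2x2 tile
            elif y % 2 == 1 and x % 2 == 1:
                ch = "b"  # bottom-right of each 2x2 tile
            else:
                ch = "g"  # the two remaining sites
            sums[ch] += value
            counts[ch] += 1
    avg = {ch: sums[ch] / counts[ch] for ch in sums}
    # Green is the reference channel, so its gain stays at 1.
    return avg["g"] / (avg["r"] + eps), 1.0, avg["g"] / (avg["b"] + eps)


mosaic = [
    [0.2, 0.4, 0.2, 0.4],
    [0.4, 0.8, 0.4, 0.8],
    [0.2, 0.4, 0.2, 0.4],
    [0.4, 0.8, 0.4, 0.8],
]
r_gain, g_gain, b_gain = gray_world_gains(mosaic)
```

Here red averages half of green and blue averages twice green, so the gains come out close to 2 and 0.5. The small `eps` keeps the division safe for an all-black channel.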

apply_awb_rgb

apply_awb_rgb(rgb)

Apply white balance to RGB image.

Parameters:

Name Type Description Default
rgb

Input tensor of shape [B, 3, H, W].

required

Returns:

Name Type Description
rgb_wb

Output tensor with same shape as input.

Raises:

Type Description
ValueError

If awb_method is not "gray_world" or "manual".

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/white_balance.py
def apply_awb_rgb(self, rgb):
    """Apply white balance to RGB image.

    Args:
        rgb: Input tensor of shape [B, 3, H, W].

    Returns:
        rgb_wb: Output tensor with same shape as input.

    Raises:
        ValueError: If awb_method is not "gray_world" or "manual".
    """
    if self.awb_method == "gray_world":
        # Calculate average for each channel
        rgb_avg = torch.mean(rgb, dim=[2, 3], keepdim=True)

        # Calculate gains to make averages equal
        g_avg = rgb_avg[:, 1:2, :, :]
        gains = g_avg / (rgb_avg + 1e-6)

        # Apply gains
        rgb_wb = rgb * gains

    elif self.awb_method == "manual":
        # Apply manual gains
        rgb_wb = rgb * self.white_balance.view(1, 3, 1, 1)

    else:
        raise ValueError(f"Unknown AWB method: {self.awb_method}")

    return rgb_wb

forward

forward(input_tensor)

Auto White Balance (AWB).

Parameters:

Name Type Description Default
input_tensor

Input tensor of shape [B, 1, H, W] or [B, 3, H, W].

required

Returns:

Name Type Description
output_tensor

Output tensor [B, 1, H, W] or [B, 3, H, W].

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/white_balance.py
def forward(self, input_tensor):
    """Auto White Balance (AWB).

    Args:
        input_tensor: Input tensor of shape [B, 1, H, W] or [B, 3, H, W].

    Returns:
        output_tensor: Output tensor [B, 1, H, W] or [B, 3, H, W].
    """
    if input_tensor.shape[1] == 1:
        return self.apply_awb_bayer(input_tensor)
    else:
        return self.apply_awb_rgb(input_tensor)

reverse

reverse(img)

Inverse auto white balance (differentiable).

Parameters:

Name Type Description Default
img

Input tensor of shape [3, H, W] or [B, 3, H, W].

required

Returns:

Name Type Description
rgb_unbalanced

Output tensor with inverse white balance applied.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/white_balance.py
def reverse(self, img):
    """Inverse auto white balance (differentiable).

    Args:
        img: Input tensor of shape [3, H, W] or [B, 3, H, W].

    Returns:
        rgb_unbalanced: Output tensor with inverse white balance applied.
    """
    # Compute inverse gains
    inv_gains = 1.0 / self.white_balance  # [3]

    # Apply inverse gains (element-wise multiplication by the reciprocal gains, differentiable)
    if img.dim() == 3:
        # Shape: [3, H, W]
        rgb_unbalanced = img * inv_gains.view(3, 1, 1)
    else:
        # Shape: [B, 3, H, W]
        rgb_unbalanced = img * inv_gains.view(1, 3, 1, 1)

    return rgb_unbalanced

safe_reverse_awb

safe_reverse_awb(img)

Inverse auto white balance with highlight-safe gains.

Applies inverse white balance gains while attenuating the correction in bright (near-saturated) regions to avoid pushing highlights out of range.

Parameters:

Name Type Description Default
img

Input tensor of shape [3, H, W] or [B, 3, H, W].

required

Returns:

Name Type Description
rgb_unbalanced

Output tensor with inverse white balance applied, same shape as input.

Raises:

Type Description
ValueError

If img is not 3- or 4-dimensional.

Reference

https://github.com/google-research/google-research/blob/master/unprocessing/unprocess.py#L92C1-L102C28

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/white_balance.py
def safe_reverse_awb(self, img):
    """Inverse auto white balance with highlight-safe gains.

    Applies inverse white balance gains while attenuating the correction in
    bright (near-saturated) regions to avoid pushing highlights out of range.

    Args:
        img: Input tensor of shape [3, H, W] or [B, 3, H, W].

    Returns:
        rgb_unbalanced: Output tensor with inverse white balance applied,
            same shape as input.

    Raises:
        ValueError: If img is not 3- or 4-dimensional.

    Reference:
        https://github.com/google-research/google-research/blob/master/unprocessing/unprocess.py#L92C1-L102C28
    """
    r_gain = self.white_balance[0]
    g_gain = self.white_balance[1]
    b_gain = self.white_balance[2]

    # Safely inverse AWB
    if len(img.shape) == 3:
        white_balance = (
            torch.tensor([1.0 / r_gain, 1.0 / g_gain, 1.0 / b_gain], device=img.device)
            .unsqueeze(-1)
            .unsqueeze(-1)
        )

        gray = torch.mean(img, dim=0, keepdim=True)
        inflection = 0.9
        mask = (torch.clamp(gray - inflection, min=0.0) / (1.0 - inflection)) ** 2.0
        safe_gains = torch.max(mask + (1.0 - mask) * white_balance, white_balance)

        rgb_unbalanced = img * safe_gains

    elif len(img.shape) == 4:
        white_balance = (
            torch.tensor([1.0 / r_gain, 1.0 / g_gain, 1.0 / b_gain], device=img.device)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(0)
        )

        gray = torch.mean(img, dim=1, keepdim=True)
        inflection = 0.9
        mask = (torch.clamp(gray - inflection, min=0.0) / (1.0 - inflection)) ** 2.0
        safe_gains = torch.max(mask + (1.0 - mask) * white_balance, white_balance)

        rgb_unbalanced = img * safe_gains

    else:
        raise ValueError("Invalid rgb shape")

    return rgb_unbalanced
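
The highlight-safe blend is easiest to follow on single pixels. The sketch below evaluates the same mask and gain formula for one RGB triple; `safe_inverse_gains` is a hypothetical helper, and the gains default to the `(2.0, 1.0, 1.8)` tuple used on this page.

```python
def safe_inverse_gains(pixel, gains=(2.0, 1.0, 1.8), inflection=0.9):
    """Highlight-safe inverse white-balance gains for one RGB pixel."""
    inv = [1.0 / g for g in gains]
    gray = sum(pixel) / len(pixel)
    # mask is 0 below the inflection point and rises quadratically to 1 at gray = 1.
    mask = (max(gray - inflection, 0.0) / (1.0 - inflection)) ** 2
    return [max(mask + (1.0 - mask) * g, g) for g in inv]


dark = safe_inverse_gains((0.3, 0.3, 0.3))    # plain inverse gains
bright = safe_inverse_gains((1.0, 1.0, 1.0))  # gains pulled up to 1
```

For the dark pixel the mask is 0 and the result is the plain reciprocal gains. For the saturated pixel the mask is 1, so every gain becomes 1 and the highlight stays white instead of picking up a color cast.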

end2end_imaging.sensor.isp_modules.Demosaic

Demosaic(bayer_pattern='rggb', method='malvar')

Bases: Module

Demosaic, or color filter array (CFA) interpolation.

Converts a Bayer pattern image to a full RGB image by interpolating missing color values at each pixel location.

Supported methods

  • "bilinear": Simple bilinear interpolation (fast, lower quality)
  • "malvar": Malvar-He-Cutler high-quality gradient-corrected interpolation

Reference

[1] Malvar, He, Cutler. "High-Quality Linear Interpolation for Demosaicing of Bayer-Patterned Color Images", ICASSP 2004.

Initialize demosaic.

Parameters:

Name Type Description Default
bayer_pattern

Bayer pattern, "rggb" or "bggr".

'rggb'
method

Demosaic method, "bilinear" or "malvar".

'malvar'
Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/demosaic.py
def __init__(self, bayer_pattern="rggb", method="malvar"):
    """Initialize demosaic.

    Args:
        bayer_pattern: Bayer pattern, "rggb" or "bggr".
        method: Demosaic method, "bilinear" or "malvar".
    """
    super().__init__()
    self.bayer_pattern = bayer_pattern
    self.method = method

    # Pre-compute Malvar kernels if using that method
    if method == "malvar":
        self._init_malvar_kernels()

forward

forward(bayer)

Demosaic a Bayer pattern image to RGB.

Parameters:

Name Type Description Default
bayer

Input tensor of shape [1, H, W] or [B, 1, H, W].

required

Returns:

Name Type Description
raw_rgb

Output tensor of shape [3, H, W] or [B, 3, H, W], matching the dimensionality of the input.

Raises:

Type Description
ValueError

If self.method is not "bilinear" or "malvar".

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/demosaic.py
def forward(self, bayer):
    """Demosaic a Bayer pattern image to RGB.

    Args:
        bayer: Input tensor of shape [1, H, W] or [B, 1, H, W].

    Returns:
        raw_rgb: Output tensor of shape [3, H, W] or [B, 3, H, W],
            matching the dimensionality of the input.

    Raises:
        ValueError: If ``self.method`` is not "bilinear" or "malvar".
    """
    if bayer.dim() == 3:
        bayer = bayer.unsqueeze(0)
        batch_dim = False
    else:
        batch_dim = True

    if self.method == "bilinear":
        raw_rgb = self._bilinear_demosaic(bayer)
    elif self.method == "malvar":
        raw_rgb = self._malvar_demosaic(bayer)
    else:
        raise ValueError(f"Invalid demosaic method: {self.method}. Use 'bilinear' or 'malvar'.")

    if not batch_dim:
        raw_rgb = raw_rgb.squeeze(0)

    return raw_rgb

reverse

reverse(img)

Inverse demosaic from RAW RGB to RAW Bayer.

Parameters:

Name Type Description Default
img Tensor

RAW RGB image, shape [3, H, W] or [B, 3, H, W], data range [0, 1].

required

Returns:

Type Description

torch.Tensor: Bayer image, shape [1, H, W] or [B, 1, H, W], data range [0, 1].

Raises:

Type Description
ValueError

If the input does not have 3 or 4 dimensions, or if the channel dimension is not 3.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/demosaic.py
def reverse(self, img):
    """Inverse demosaic from RAW RGB to RAW Bayer.

    Args:
        img (torch.Tensor): RAW RGB image, shape [3, H, W] or [B, 3, H, W], data range [0, 1].

    Returns:
        torch.Tensor: Bayer image, shape [1, H, W] or [B, 1, H, W], data range [0, 1].

    Raises:
        ValueError: If the input does not have 3 or 4 dimensions, or if the
            channel dimension is not 3.
    """
    if img.ndim == 3:
        # Input shape: [3, H, W]
        batch_dim = False
        C, H, W = img.shape
    elif img.ndim == 4:
        # Input shape: [B, 3, H, W]
        batch_dim = True
        B, C, H, W = img.shape
    else:
        raise ValueError(
            "Input image must have 3 or 4 dimensions corresponding to [3, H, W] or [B, 3, H, W]."
        )

    if C != 3:
        raise ValueError("Input image must have 3 channels corresponding to RGB.")

    if batch_dim:
        bayer = torch.zeros((B, 1, H, W), dtype=img.dtype, device=img.device)
        bayer[:, 0, 0::2, 0::2] = img[:, 0, 0::2, 0::2]
        bayer[:, 0, 0::2, 1::2] = img[:, 1, 0::2, 1::2]
        bayer[:, 0, 1::2, 0::2] = img[:, 1, 1::2, 0::2]
        bayer[:, 0, 1::2, 1::2] = img[:, 2, 1::2, 1::2]
    else:
        bayer = torch.zeros((1, H, W), dtype=img.dtype, device=img.device)
        bayer[0, 0::2, 0::2] = img[0, 0::2, 0::2]
        bayer[0, 0::2, 1::2] = img[1, 0::2, 1::2]
        bayer[0, 1::2, 0::2] = img[1, 1::2, 0::2]
        bayer[0, 1::2, 1::2] = img[2, 1::2, 1::2]

    return bayer
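
The strided assignments above amount to picking one channel per pixel according to its row and column parity. The sketch below does the same with nested lists; `mosaic_rggb` is a hypothetical helper and the constant planes are invented test data. Like the listing above, it always samples an RGGB layout.

```python
def mosaic_rggb(rgb):
    """Sample an RGGB Bayer mosaic from three H x W planes (lists of rows)."""
    r, g, b = rgb
    height, width = len(r), len(r[0])
    bayer = [[0.0] * width for _ in range(height)]
    for y in range(height):
        for x in range(width):
            if y % 2 == 0 and x % 2 == 0:
                bayer[y][x] = r[y][x]  # R at even row, even column
            elif y % 2 == 1 and x % 2 == 1:
                bayer[y][x] = b[y][x]  # B at odd row, odd column
            else:
                bayer[y][x] = g[y][x]  # G at the two mixed-parity sites
    return bayer


planes = (
    [[0.1, 0.1], [0.1, 0.1]],  # R plane
    [[0.5, 0.5], [0.5, 0.5]],  # G plane
    [[0.9, 0.9], [0.9, 0.9]],  # B plane
)
bayer = mosaic_rggb(planes)
```

Two of the three channel values at every pixel are discarded, which is why demosaicing followed by this step is not an exact round trip in general.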

end2end_imaging.sensor.isp_modules.ColorCorrectionMatrix

ColorCorrectionMatrix(ccm_matrix=None)

Bases: Module

Color correction matrix (CCM).

The color correction matrix is stored as a 4x3 matrix: the first three rows form a 3x3 linear color transform and the fourth row is a per-channel bias.

Initialize color correction matrix.

Parameters:

Name Type Description Default
ccm_matrix

Color correction matrix as a list of shape [4, 3] or [3, 3] (a [3, 3] matrix is padded with a zero bias row to [4, 3]). If None (default), an identity matrix with zero bias is used. Example: [[1.8506, -0.7920, -0.0605], [-0.1562, 1.6455, -0.4912], [ 0.0176, -0.5439, 1.5254], [ 0.0, 0.0, 0.0 ]]

None

Raises:

Type Description
ValueError

If ccm_matrix is neither None nor a list.

Reference

[1] https://github.com/QiuJueqin/fast-openISP/blob/master/configs/nikon_d3200.yaml#L57
[2] https://github.com/timothybrooks/hdr-plus/blob/master/src/finish.cpp#L626
[3] https://www.dxomark.com/Cameras/Canon/EOS-R6---Measurements, "Color Response"

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/color_matrix.py
def __init__(self, ccm_matrix=None):
    """Initialize color correction matrix.

    Args:
        ccm_matrix: Color correction matrix as a list of shape [4, 3] or
            [3, 3] (a [3, 3] matrix is padded with a zero bias row to
            [4, 3]). If None (default), an identity matrix with zero bias
            is used. Example:
            [[1.8506, -0.7920, -0.0605],
             [-0.1562,  1.6455, -0.4912],
             [ 0.0176, -0.5439,  1.5254],
             [ 0.0,     0.0,     0.0   ]]

    Raises:
        ValueError: If ccm_matrix is neither None nor a list.

    Reference:
        [1] https://github.com/QiuJueqin/fast-openISP/blob/master/configs/nikon_d3200.yaml#L57
        [2] https://github.com/timothybrooks/hdr-plus/blob/master/src/finish.cpp#L626
        [3] https://www.dxomark.com/Cameras/Canon/EOS-R6---Measurements, "Color Response"
    """
    super().__init__()
    if ccm_matrix is None:
        ccm_matrix = torch.tensor(
            [
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
                [0.0, 0.0, 0.0],
            ]
        )
    elif isinstance(ccm_matrix, list):
        ccm_matrix = torch.tensor(ccm_matrix)
        if ccm_matrix.shape == (3, 3):
            ccm_matrix = torch.cat([ccm_matrix, torch.zeros(1, 3)], dim=0)
    else:
        raise ValueError(f"Unknown type of ccm_matrix: {type(ccm_matrix)}")

    self.register_buffer("ccm_matrix", ccm_matrix)

sample_augmentation

sample_augmentation()

Sample augmentation for synthetic data generation.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/color_matrix.py
def sample_augmentation(self):
    """Sample augmentation for synthetic data generation."""
    if not hasattr(self, "ccm_org"):
        self.ccm_org = self.ccm_matrix
    self.ccm_matrix = self.ccm_org + torch.randn_like(self.ccm_org) * 0.01

reset_augmentation

reset_augmentation()

Reset augmentation for evaluation.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/color_matrix.py
def reset_augmentation(self):
    """Reset augmentation for evaluation."""
    # No-op if sample_augmentation() has never been called.
    if hasattr(self, "ccm_org"):
        self.ccm_matrix = self.ccm_org

forward

forward(rgb_image)

Color Correction Matrix. Convert RGB image to sensor color space.

Parameters:

Name Type Description Default
rgb_image

Input tensor of shape [B, 3, H, W] in RGB format.

required

Returns:

Name Type Description
rgb_corrected

Corrected RGB image in sensor color space.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/color_matrix.py
def forward(self, rgb_image):
    """Color Correction Matrix. Convert RGB image to sensor color space.

    Args:
        rgb_image: Input tensor of shape [B, 3, H, W] in RGB format.

    Returns:
        rgb_corrected: Corrected RGB image in sensor color space.
    """
    # Extract matrix and bias
    matrix = self.ccm_matrix[:3, :]  # Shape: (3, 3)
    bias = self.ccm_matrix[3, :].view(1, 3, 1, 1)  # Shape: (1, 3, 1, 1)

    # Apply CCM
    # Reshape rgb_image to [B, H, W, 3] for matrix multiplication
    rgb_image_perm = rgb_image.permute(0, 2, 3, 1)  # [B, H, W, 3]
    rgb_corrected = torch.matmul(rgb_image_perm, matrix.T) + bias.squeeze()
    rgb_corrected = rgb_corrected.permute(0, 3, 1, 2)  # [B, 3, H, W]

    return rgb_corrected
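
Per pixel, the permute-and-matmul above is an affine map: each output channel is one matrix row dotted with the input RGB, plus the bias. A scalar sketch follows, using the example matrix quoted in the constructor docstring; `apply_ccm` is a hypothetical helper.

```python
def apply_ccm(pixel, ccm):
    """Apply a 4x3 CCM (three matrix rows plus a bias row) to one RGB pixel."""
    matrix, bias = ccm[:3], ccm[3]
    # out[i] = sum_j matrix[i][j] * pixel[j] + bias[i], i.e. pixel @ matrix.T + bias
    return [sum(m * p for m, p in zip(row, pixel)) + bias[i]
            for i, row in enumerate(matrix)]


# Example matrix quoted in the ColorCorrectionMatrix docstring.
CCM = [
    [1.8506, -0.7920, -0.0605],
    [-0.1562, 1.6455, -0.4912],
    [0.0176, -0.5439, 1.5254],
    [0.0, 0.0, 0.0],
]

gray = apply_ccm([0.5, 0.5, 0.5], CCM)
red = apply_ccm([1.0, 0.0, 0.0], CCM)
```

Each row of this matrix sums to roughly 1, so a neutral gray stays nearly neutral. A pure red input picks out the first column, including a negative green component, which shows that the output is not confined to [0, 1].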

reverse

reverse(img)

Inverse color correction matrix. Convert sensor color space to RGB image.

Parameters:

Name Type Description Default
img

Input tensor of shape [B, 3, H, W] in sensor color space.

required

Returns:

Name Type Description
img_original

RGB image of shape [B, 3, H, W], clamped to [0, 1].

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/color_matrix.py
def reverse(self, img):
    """Inverse color correction matrix. Convert sensor color space to RGB image.

    Args:
        img: Input tensor of shape [B, 3, H, W] in sensor color space.

    Returns:
        img_original: RGB image of shape [B, 3, H, W], clamped to [0, 1].
    """
    ccm_matrix = self.ccm_matrix

    # Extract matrix and bias from CCM
    matrix = ccm_matrix[:3, :]  # Shape: (3, 3)
    bias = ccm_matrix[3, :].view(1, 3, 1, 1)  # Shape: (1, 3, 1, 1)

    # Compute the inverse of the CCM matrix
    inv_matrix = torch.inverse(matrix)  # Shape: (3, 3)

    # Prepare rgb_corrected for matrix multiplication
    img_perm = img.permute(0, 2, 3, 1)  # [B, H, W, 3]

    # Subtract bias
    img_minus_bias = img_perm - bias.squeeze()

    # Apply Inverse CCM
    img_original = torch.matmul(img_minus_bias, inv_matrix.T)  # [B, H, W, 3]
    img_original = img_original.permute(0, 3, 1, 2)  # [B, 3, H, W]

    # Clip the values to ensure they are within the valid range
    img_original = torch.clamp(img_original, 0.0, 1.0)

    return img_original
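
The inverse path can be checked on one pixel: subtract the bias, multiply by the inverse matrix, and clamp. The sketch below swaps `torch.inverse` for a hand-written adjugate inverse so that it needs no dependencies. `inverse_3x3`, `matvec`, and `ccm_reverse` are hypothetical helpers, and the matrix is the docstring example.

```python
def inverse_3x3(m):
    """Invert a 3x3 matrix (list of rows) via the adjugate formula."""
    (a, b, c), (d, e, f), (g, h, i) = m
    det = a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
    if abs(det) < 1e-12:
        raise ValueError("matrix is singular")
    return [
        [(e * i - f * h) / det, (c * h - b * i) / det, (b * f - c * e) / det],
        [(f * g - d * i) / det, (a * i - c * g) / det, (c * d - a * f) / det],
        [(d * h - e * g) / det, (b * g - a * h) / det, (a * e - b * d) / det],
    ]


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(x * y for x, y in zip(row, v)) for row in m]


def ccm_reverse(pixel, matrix, bias):
    """Subtract the bias, apply the inverse matrix, then clamp to [0, 1]."""
    shifted = [p - b for p, b in zip(pixel, bias)]
    return [min(max(v, 0.0), 1.0) for v in matvec(inverse_3x3(matrix), shifted)]


MATRIX = [[1.8506, -0.7920, -0.0605],
          [-0.1562, 1.6455, -0.4912],
          [0.0176, -0.5439, 1.5254]]
BIAS = [0.0, 0.0, 0.0]

original = [0.4, 0.5, 0.6]
corrected = matvec(MATRIX, original)
restored = ccm_reverse(corrected, MATRIX, BIAS)
```

The round trip recovers in-range pixels to floating-point precision. The final clamp makes the inverse lossy for any pixel whose pre-CCM value lay outside [0, 1].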

end2end_imaging.sensor.isp_modules.GammaCorrection

GammaCorrection(gamma_param=2.2)

Bases: Module

Gamma correction (GC).

Gamma correction encodes a linear image for display by raising each value to the power 1 / gamma_param.

Initialize gamma correction module.

Parameters:

Name Type Description Default
gamma_param

Gamma parameter. Default is 2.2.

2.2
Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/gamma_correction.py
def __init__(self, gamma_param=2.2):
    """Initialize gamma correction module.

    Args:
        gamma_param: Gamma parameter. Default is 2.2.
    """
    super().__init__()
    self.register_buffer("gamma_param", torch.tensor(gamma_param))

sample_augmentation

sample_augmentation()

Sample augmentation for synthetic data generation.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/gamma_correction.py
def sample_augmentation(self):
    """Sample augmentation for synthetic data generation."""
    if not hasattr(self, "gamma_param_org"):
        self.gamma_param_org = self.gamma_param
    self.gamma_param = (
        self.gamma_param_org + torch.randn_like(self.gamma_param_org) * 0.01
    )

reset_augmentation

reset_augmentation()

Reset augmentation for evaluation.

No-op if no augmentation has been sampled yet, in which case gamma_param already holds the original value.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/gamma_correction.py
def reset_augmentation(self):
    """Reset augmentation for evaluation.

    No-op if no augmentation has been sampled yet, in which case
    ``gamma_param`` already holds the original value.
    """
    if hasattr(self, "gamma_param_org"):
        self.gamma_param = self.gamma_param_org

forward

forward(img, quantize=False)

Gamma Correction (differentiable).

Parameters:

Name Type Description Default
img tensor

Input image. Shape of [B, C, H, W].

required
quantize bool

Whether to quantize the image to 8-bit. WARNING: quantize=True makes this non-differentiable!

False

Returns:

Name Type Description
img_gamma tensor

Gamma corrected image. Shape of [B, C, H, W].

Reference

[1] "There is no restriction as to where stage gamma correction is placed," page 35, Architectural Analysis of a Baseline ISP Pipeline.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/gamma_correction.py
def forward(self, img, quantize=False):
    """Gamma Correction (differentiable).

    Args:
        img (tensor): Input image. Shape of [B, C, H, W].
        quantize (bool): Whether to quantize the image to 8-bit.
                         WARNING: quantize=True makes this non-differentiable!

    Returns:
        img_gamma (tensor): Gamma corrected image. Shape of [B, C, H, W].

    Reference:
        [1] "There is no restriction as to where stage gamma correction is placed," page 35, Architectural Analysis of a Baseline ISP Pipeline.
    """
    img_gamma = torch.pow(torch.clamp(img, min=1e-8), 1 / self.gamma_param)
    if quantize:
        # WARNING: torch.round() is NOT differentiable!
        # Use only for final output, not during training
        img_gamma = torch.round(img_gamma * 255) / 255
    return img_gamma
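
The warning about `quantize=True` can be checked directly with autograd. The sketch below re-implements the arithmetic of `forward` as a free function (an illustration, not the library module) and compares the input gradients of the smooth and the quantized paths.

```python
import torch

gamma = torch.tensor(2.2)


def gamma_forward(img, quantize=False):
    # Same arithmetic as GammaCorrection.forward, written as a free function.
    out = torch.pow(torch.clamp(img, min=1e-8), 1 / gamma)
    if quantize:
        out = torch.round(out * 255) / 255
    return out


img = torch.tensor([0.1, 0.5, 0.9], requires_grad=True)

# Smooth path: gradients reach the input.
(grad_smooth,) = torch.autograd.grad(gamma_forward(img).sum(), img)

# Quantized path: torch.round contributes a zero gradient,
# so nothing useful flows back to the input.
(grad_quant,) = torch.autograd.grad(
    gamma_forward(img, quantize=True).sum(), img
)
```

The quantized path still runs under autograd, but its gradient is zero everywhere, which is why it suits only final output.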

reverse

reverse(img)

Inverse gamma correction.

Parameters:

Name Type Description Default
img tensor

Input image. Shape of [B, C, H, W].

required

Returns:

Name Type Description
img tensor

Inverse gamma corrected image. Shape of [B, C, H, W].

Reference

[1] https://github.com/google-research/google-research/blob/master/unprocessing/unprocess.py#L78

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/gamma_correction.py
def reverse(self, img):
    """Inverse gamma correction.

    Args:
        img (tensor): Input image. Shape of [B, C, H, W].

    Returns:
        img (tensor): Inverse gamma corrected image. Shape of [B, C, H, W].

    Reference:
        [1] https://github.com/google-research/google-research/blob/master/unprocessing/unprocess.py#L78
    """
    gamma_param = self.gamma_param
    img = torch.clip(img, 1e-8) ** gamma_param
    return img
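
A scalar sketch of the forward/reverse pair, in plain Python so the arithmetic can be checked by hand. The function names are illustrative. It assumes the default gamma of 2.2 and the same 1e-8 clamp floor used in both methods.

```python
def gamma_forward(x, gamma=2.2):
    # Scalar version of forward: clamp, then raise to 1 / gamma.
    return max(x, 1e-8) ** (1.0 / gamma)


def gamma_reverse(y, gamma=2.2):
    # Scalar version of reverse: clamp, then raise to gamma.
    return max(y, 1e-8) ** gamma


mid_gray = gamma_reverse(gamma_forward(0.18))  # round trip of a linear value
black = gamma_reverse(gamma_forward(0.0))      # hits the 1e-8 clamp floor
```

Values above the clamp floor round-trip to floating-point precision. An exact zero comes back as 1e-8, because the clamp in `forward` replaces it before the power is taken.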

end2end_imaging.sensor.isp_modules.ToneMapping

ToneMapping(method='reinhard', exposure=1.0)

Bases: Module

Global tone mapping operator.

Maps HDR linear radiance values to displayable [0, 1] range using a global (per-pixel, spatially invariant) curve.

Supported methods
  • "reinhard": L / (1 + L), from [Reinhard et al. 2002].
  • "aces": ACES filmic curve approximation, from [Narkowicz 2015].
  • "hable": Uncharted 2 filmic curve, from [Hable 2010].

Reference

[1] Reinhard et al., "Photographic Tone Reproduction for Digital Images", SIGGRAPH 2002.
[2] Narkowicz, "ACES Filmic Tone Mapping Curve", 2015.
[3] Hable, "Filmic Tonemapping Operators", GDC 2010.

Initialize tone mapping module.

Parameters:

Name Type Description Default
method

Tone mapping method, one of "reinhard", "aces", "hable".

'reinhard'
exposure

Exposure multiplier applied before tone mapping.

1.0

Raises:

Type Description
ValueError

If method is not one of "reinhard", "aces", "hable".

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/tone_mapping.py
def __init__(self, method="reinhard", exposure=1.0):
    """Initialize tone mapping module.

    Args:
        method: Tone mapping method, one of "reinhard", "aces", "hable".
        exposure: Exposure multiplier applied before tone mapping.

    Raises:
        ValueError: If ``method`` is not one of "reinhard", "aces", "hable".
    """
    super().__init__()
    if method not in ("reinhard", "aces", "hable"):
        raise ValueError(f"Unknown tone mapping method: {method}")
    self.method = method
    self.register_buffer("exposure", torch.tensor(exposure))

forward

forward(img)

Apply global tone mapping.

Parameters:

Name Type Description Default
img

HDR linear image, (B, C, H, W), range [0, +inf).

required

Returns:

Name Type Description
img_tm

Tone-mapped image, (B, C, H, W), range [0, 1].

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/tone_mapping.py
def forward(self, img):
    """Apply global tone mapping.

    Args:
        img: HDR linear image, (B, C, H, W), range [0, +inf).

    Returns:
        img_tm: Tone-mapped image, (B, C, H, W), range [0, 1].
    """
    img = torch.clamp(img, min=0.0) * self.exposure

    if self.method == "reinhard":
        img_tm = img / (1.0 + img)
    elif self.method == "aces":
        img_tm = self._aces(img)
    elif self.method == "hable":
        img_tm = self._hable(img)

    return torch.clamp(img_tm, 0.0, 1.0)
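
The three curves can be compared on scalars. Only the Reinhard formula appears in the source above. The private helpers `_aces` and `_hable` are not shown, so the ACES and Hable expressions below are the commonly published forms of those curves and are assumptions about what the library computes. The white point of 11.2 and the omission of an exposure bias are likewise assumptions.

```python
def reinhard(x):
    # L / (1 + L), as in the forward pass above.
    return x / (1.0 + x)


def aces_narkowicz(x):
    # Narkowicz's published rational fit to the ACES curve (assumed form).
    a, b, c, d, e = 2.51, 0.03, 2.43, 0.59, 0.14
    return (x * (a * x + b)) / (x * (c * x + d) + e)


def hable(x, white=11.2):
    # Hable's Uncharted 2 curve, normalised so `white` maps to 1 (assumed form).
    A, B, C, D, E, F = 0.15, 0.50, 0.10, 0.20, 0.02, 0.30

    def partial(v):
        return (v * (A * v + C * B) + D * E) / (v * (A * v + B) + D * F) - E / F

    return partial(x) / partial(white)


def tone_map(x, method="reinhard", exposure=1.0):
    curves = {"reinhard": reinhard, "aces": aces_narkowicz, "hable": hable}
    x = max(x, 0.0) * exposure  # clamp negatives, then apply exposure
    return min(max(curves[method](x), 0.0), 1.0)  # final clamp to [0, 1]


levels = [0.0, 0.18, 1.0, 4.0, 16.0]
table = {m: [tone_map(v, m) for v in levels] for m in ("reinhard", "aces", "hable")}
```

Under these assumed forms, the ACES fit and the Hable curve both exceed 1 for large inputs, so the final clamp is doing real work there. Reinhard only approaches 1 asymptotically.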

reverse

reverse(img)

Inverse tone mapping (recover linear HDR from tone-mapped image).

Uses the closed-form inverse for "reinhard". For "aces" and "hable", uses an iterative Newton's method approximation.

Parameters:

Name Type Description Default
img

Tone-mapped image, (B, C, H, W), range [0, 1].

required

Returns:

Name Type Description
img_hdr

Recovered linear image, (B, C, H, W), range [0, +inf).

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/tone_mapping.py
def reverse(self, img):
    """Inverse tone mapping (recover linear HDR from tone-mapped image).

    Uses the closed-form inverse for "reinhard". For "aces" and "hable",
    uses an iterative Newton's method approximation.

    Args:
        img: Tone-mapped image, (B, C, H, W), range [0, 1].

    Returns:
        img_hdr: Recovered linear image, (B, C, H, W), range [0, +inf).
    """
    img = torch.clamp(img, 0.0, 1.0 - 1e-6)

    if self.method == "reinhard":
        img_hdr = img / (1.0 - img)
    elif self.method == "aces":
        img_hdr = self._aces_reverse(img)
    elif self.method == "hable":
        img_hdr = self._hable_reverse(img)

    return torch.clamp(img_hdr, min=0.0) / self.exposure
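
A scalar sketch of both inversion styles. The Reinhard inverse follows the source above, including the order of operations: undo the curve first, then divide by exposure. The numeric inverse deliberately uses bisection rather than the Newton iteration the docstring mentions. That is a substitution made here because bisection is guaranteed to converge on a monotonic curve, and the library's `_aces_reverse` is not shown. The ACES formula is the published Narkowicz fit, assumed for illustration.

```python
def reinhard(x):
    return x / (1.0 + x)


def reinhard_inverse(y):
    # Closed form: solve y = x / (1 + x) for x.
    return y / (1.0 - y)


def aces_narkowicz(x):
    # Assumed curve: Narkowicz's rational fit, increasing for x >= 0.
    return (x * (2.51 * x + 0.03)) / (x * (2.43 * x + 0.59) + 0.14)


def invert_monotonic(curve, y, lo=0.0, hi=16.0, steps=60):
    # Bisection: halve the bracket [lo, hi] until it pins down curve(x) == y.
    for _ in range(steps):
        mid = 0.5 * (lo + hi)
        if curve(mid) < y:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


exposure = 2.0
y = min(max(reinhard(0.75 * exposure), 0.0), 1.0 - 1e-6)  # forward, then clamp
x_reinhard = reinhard_inverse(y) / exposure               # undo curve, then exposure
x_aces = invert_monotonic(aces_narkowicz, aces_narkowicz(1.0))
```

The clamp to `1 - 1e-6` matters for the closed form: at exactly 1 the Reinhard inverse divides by zero.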

end2end_imaging.sensor.isp_modules.DeadPixelCorrection

DeadPixelCorrection(threshold=0.1, kernel_size=3, soft_blend=True, temperature=0.01)

Bases: Module

Dead pixel correction (DPC).

Detects and corrects dead/stuck pixels by comparing each pixel to its neighbors and replacing outliers with a local mean value.

Note: Uses differentiable operations (mean instead of median, soft mask).

Reference

[1] https://github.com/QiuJueqin/fast-openISP/blob/master/modules/dpc.py

Initialize dead pixel correction.

Parameters:

Name Type Description Default
threshold

Threshold for detecting dead pixels (as fraction of max value).

0.1
kernel_size

Size of the averaging kernel. Even values are rounded up to the next odd size.

3
soft_blend

If True, use differentiable soft blending. If False, use hard threshold.

True
temperature

Temperature for soft sigmoid blending (lower = sharper transition).

0.01
Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/dead_pixel.py
def __init__(self, threshold=0.1, kernel_size=3, soft_blend=True, temperature=0.01):
    """Initialize dead pixel correction.

    Args:
        threshold: Threshold for detecting dead pixels (as fraction of max value).
        kernel_size: Size of the averaging kernel. Even values are rounded up to the next odd size.
        soft_blend: If True, use differentiable soft blending. If False, use hard threshold.
        temperature: Temperature for soft sigmoid blending (lower = sharper transition).
    """
    super().__init__()
    self.threshold = threshold
    self.kernel_size = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
    self.soft_blend = soft_blend
    self.temperature = temperature

    # Pre-compute averaging kernel (excluding center pixel)
    kernel = torch.ones(1, 1, self.kernel_size, self.kernel_size)
    center = self.kernel_size // 2
    kernel[0, 0, center, center] = 0  # Exclude center pixel
    kernel = kernel / kernel.sum()  # Normalize
    self.register_buffer("avg_kernel", kernel)

forward

forward(bayer)

Dead Pixel Correction (differentiable).

Parameters:

Name Type Description Default
bayer Tensor

Input bayer image [B, 1, H, W], data range [0, 1].

required

Returns:

Name Type Description
bayer_corrected Tensor

Corrected bayer image [B, 1, H, W].

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/dead_pixel.py
def forward(self, bayer):
    """Dead Pixel Correction (differentiable).

    Args:
        bayer (torch.Tensor): Input bayer image [B, 1, H, W], data range [0, 1].

    Returns:
        bayer_corrected (torch.Tensor): Corrected bayer image [B, 1, H, W].
    """
    padding = self.kernel_size // 2

    # Compute local mean (excluding center pixel) - differentiable
    local_mean = F.conv2d(bayer, self.avg_kernel.to(bayer.dtype), padding=padding)

    # Compute difference from local mean
    diff = torch.abs(bayer - local_mean)

    if self.soft_blend:
        # Soft differentiable blending using sigmoid
        # blend_weight approaches 1 when diff >> threshold (use local_mean)
        # blend_weight approaches 0 when diff << threshold (use original)
        blend_weight = torch.sigmoid((diff - self.threshold) / self.temperature)
        result = (1 - blend_weight) * bayer + blend_weight * local_mean
    else:
        # Hard threshold (not differentiable through the mask)
        mask = (diff > self.threshold).float()
        result = (1 - mask) * bayer + mask * local_mean

    return result
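
The soft blend can be traced for a single pixel in plain Python. `soft_dpc_pixel` is an illustrative helper, not a library function. It applies the same three steps as `forward` to one center value and its eight neighbors, using the default threshold 0.1 and temperature 0.01.

```python
import math


def soft_dpc_pixel(center, neighbors, threshold=0.1, temperature=0.01):
    """Soft-blend one pixel toward the mean of its neighbours (centre excluded)."""
    local_mean = sum(neighbors) / len(neighbors)
    diff = abs(center - local_mean)
    # Sigmoid weight: ~0 keeps the pixel, ~1 replaces it with the local mean.
    weight = 1.0 / (1.0 + math.exp(-(diff - threshold) / temperature))
    return (1.0 - weight) * center + weight * local_mean


# A stuck pixel (1.0) on a flat 0.2 patch is pulled back to the background.
hot_pixel = soft_dpc_pixel(1.0, [0.2] * 8)

# Its neighbours see the hot pixel inside their own 3x3 window: their local
# mean rises to 0.3, the difference sits right at the 0.1 threshold, and they
# are blended halfway toward the contaminated mean.
neighbor = soft_dpc_pixel(0.2, [0.2] * 7 + [1.0])

# A pixel on a clean flat patch is left essentially untouched.
clean = soft_dpc_pixel(0.2, [0.2] * 8)
```

The second case is worth noting when choosing `threshold`: because only the center is excluded from the mean, a very bright defect can nudge its neighbors (here from 0.2 to about 0.25) even though they are not defective.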

reverse

reverse(bayer)

Reverse dead pixel correction (identity).

Note: Dead pixel correction is a lossy operation that cannot be reversed. This returns the input unchanged.

Parameters:

Name Type Description Default
bayer Tensor

Input bayer image [B, 1, H, W].

required

Returns:

Name Type Description
bayer Tensor

Input unchanged.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/dead_pixel.py
def reverse(self, bayer):
    """Reverse dead pixel correction (identity).

    Note: Dead pixel correction is a lossy operation that cannot be reversed.
    This returns the input unchanged.

    Args:
        bayer (torch.Tensor): Input bayer image [B, 1, H, W].

    Returns:
        bayer (torch.Tensor): Input unchanged.
    """
    # Dead pixel correction cannot be reversed; return input as-is
    return bayer

end2end_imaging.sensor.isp_modules.Denoise

Denoise(method='gaussian', kernel_size=3, sigma=0.5, sigma_color=0.1)

Bases: Module

Noise reduction (differentiable).

Applies denoising filters to reduce sensor noise in the image. Supports Gaussian filtering (differentiable) and bilateral filtering.

Note: Median filtering is NOT differentiable, so we use Gaussian or bilateral instead.

Initialize denoise.

Parameters:

Name Type Description Default
method

Noise reduction method: "gaussian", "bilateral", "none", or None.

'gaussian'
kernel_size

Size of the kernel. Even values are rounded up to the next odd size.

3
sigma

Standard deviation for spatial Gaussian kernel.

0.5
sigma_color

Standard deviation for color/intensity similarity (bilateral only).

0.1
Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/denoise.py
def __init__(self, method="gaussian", kernel_size=3, sigma=0.5, sigma_color=0.1):
    """Initialize denoise.

    Args:
        method: Noise reduction method: "gaussian", "bilateral", "none", or None.
        kernel_size: Size of the kernel. Even values are rounded up to the next odd size.
        sigma: Standard deviation for spatial Gaussian kernel.
        sigma_color: Standard deviation for color/intensity similarity (bilateral only).
    """
    super().__init__()
    self.method = method
    self.kernel_size = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
    self.sigma = sigma
    self.sigma_color = sigma_color

    # Pre-compute Gaussian kernel
    kernel = self._create_gaussian_kernel(self.kernel_size, self.sigma)
    self.register_buffer("gaussian_kernel", kernel)

forward

forward(img)

Apply denoise (differentiable).

Parameters:

Name Type Description Default
img Tensor

Input tensor of shape [B, C, H, W], data range [0, 1].

required

Returns:

Name Type Description
img_filtered Tensor

Denoised image of shape [B, C, H, W], data range [0, 1]. If the method is None or "none", the input is returned unchanged.

Raises:

Type Description
ValueError

If self.method is not None, "none", "gaussian", or "bilateral".

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/denoise.py
def forward(self, img):
    """Apply denoise (differentiable).

    Args:
        img (torch.Tensor): Input tensor of shape [B, C, H, W], data range [0, 1].

    Returns:
        img_filtered (torch.Tensor): Denoised image of shape [B, C, H, W],
            data range [0, 1]. If the method is None or "none", the input is
            returned unchanged.

    Raises:
        ValueError: If ``self.method`` is not None, "none", "gaussian", or
            "bilateral".
    """
    if self.method is None or self.method == "none":
        return img

    if self.method == "gaussian":
        img_filtered = self._gaussian_filter(img)

    elif self.method == "bilateral":
        img_filtered = self._bilateral_filter(img)

    else:
        raise ValueError(f"Unknown noise reduction method: {self.method}")

    return img_filtered
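
The difference between the two methods is easiest to see on a 1-D step edge. The private helpers `_gaussian_filter` and `_bilateral_filter` are not shown here, so the sketch below uses the textbook definitions with the default `sigma=0.5` and `sigma_color=0.1`. Border handling (edge replication) and the 1-D setting are simplifications chosen for illustration and may differ from the library.

```python
import math


def gaussian_1d(signal, sigma=0.5, radius=1):
    """Spatial Gaussian smoothing with edge replication at the borders."""
    out = []
    for i in range(len(signal)):
        num = den = 0.0
        for k in range(-radius, radius + 1):
            j = min(max(i + k, 0), len(signal) - 1)
            w = math.exp(-0.5 * (k / sigma) ** 2)
            num += w * signal[j]
            den += w
        out.append(num / den)
    return out


def bilateral_1d(signal, sigma=0.5, sigma_color=0.1, radius=1):
    """Same spatial weights, multiplied by an intensity-similarity weight."""
    out = []
    for i in range(len(signal)):
        num = den = 0.0
        for k in range(-radius, radius + 1):
            j = min(max(i + k, 0), len(signal) - 1)
            w_space = math.exp(-0.5 * (k / sigma) ** 2)
            w_color = math.exp(-0.5 * ((signal[j] - signal[i]) / sigma_color) ** 2)
            num += w_space * w_color * signal[j]
            den += w_space * w_color
        out.append(num / den)
    return out


step = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
blurred = gaussian_1d(step)     # the edge leaks into both sides
preserved = bilateral_1d(step)  # the edge stays sharp
```

With `sigma_color=0.1`, a neighbor that differs by a full intensity unit receives a weight of about exp(-50), so it contributes essentially nothing. That is what keeps the edge intact.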

reverse

reverse(img)

Reverse denoising (identity).

Note: Denoising is a lossy operation that cannot be reversed. This returns the input unchanged.

Parameters:

Name Type Description Default
img Tensor

Input tensor of shape [B, C, H, W].

required

Returns:

Name Type Description
img Tensor

Input unchanged.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/denoise.py
def reverse(self, img):
    """Reverse denoising (identity).

    Note: Denoising is a lossy operation that cannot be reversed.
    This returns the input unchanged.

    Args:
        img (torch.Tensor): Input tensor of shape [B, C, H, W].

    Returns:
        img (torch.Tensor): Input unchanged.
    """
    # Denoising cannot be reversed; return input as-is
    return img

end2end_imaging.sensor.isp_modules.LensShadingCorrection

LensShadingCorrection(shading_map=None, strength=1.0, falloff_model='radial')

Bases: Module

Lens shading correction (LSC).

Corrects vignetting (darkening at edges/corners) caused by lens optical properties by applying a spatially-varying gain map.

Initialize lens shading correction module.

Parameters:

Name Type Description Default
shading_map

Pre-computed shading gain map of shape [H, W] or [1, 1, H, W]. If None, a radial falloff model is used. Default is None.

None
strength

Strength of the correction (0-1). 0 = no correction, 1 = full. Default is 1.0.

1.0
falloff_model

Model for computing gain map. Options: "radial", "polynomial". Only used if shading_map is None. Default is "radial".

'radial'
Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/lens_shading.py
def __init__(self, shading_map=None, strength=1.0, falloff_model="radial"):
    """Initialize lens shading correction module.

    Args:
        shading_map: Pre-computed shading gain map of shape [H, W] or [1, 1, H, W].
                     If None, a radial falloff model is used. Default is None.
        strength: Strength of the correction (0-1). 0 = no correction, 1 = full. Default is 1.0.
        falloff_model: Model for computing gain map. Options: "radial", "polynomial".
                       Only used if shading_map is None. Default is "radial".
    """
    super().__init__()
    self.strength = strength
    self.falloff_model = falloff_model

    if shading_map is not None:
        if isinstance(shading_map, torch.Tensor):
            if shading_map.dim() == 2:
                shading_map = shading_map.unsqueeze(0).unsqueeze(0)
            self.register_buffer("shading_map", shading_map)
        else:
            raise ValueError("shading_map must be a torch.Tensor")
    else:
        self.shading_map = None

    # Polynomial coefficients for vignetting model (typical values)
    # V(r) = 1 + k1*r^2 + k2*r^4 + k3*r^6
    self.register_buffer("poly_coeffs", torch.tensor([0.3, 0.15, 0.05]))

forward

forward(x)

Apply lens shading correction to remove vignetting.

Parameters:

Name Type Description Default
x

Input tensor of shape [B, C, H, W], data range [0, 1].

required

Returns:

Name Type Description
x_corrected

Corrected tensor of shape [B, C, H, W].

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/lens_shading.py
def forward(self, x):
    """Apply lens shading correction to remove vignetting.

    Args:
        x: Input tensor of shape [B, C, H, W], data range [0, 1].

    Returns:
        x_corrected: Corrected tensor of shape [B, C, H, W].
    """
    if self.strength == 0:
        return x

    B, C, H, W = x.shape

    # Get or compute the gain map
    if self.shading_map is not None:
        # Resize shading map to match input if needed
        if self.shading_map.shape[-2:] != (H, W):
            gain_map = F.interpolate(
                self.shading_map, size=(H, W), mode="bilinear", align_corners=True
            )
        else:
            gain_map = self.shading_map
    else:
        # Compute gain map on-the-fly
        gain_map = self._compute_radial_gain(H, W, x.device, x.dtype)

    # Apply strength-weighted correction
    # gain = 1 + strength * (computed_gain - 1)
    effective_gain = 1 + self.strength * (gain_map - 1)

    # Apply correction
    x_corrected = x * effective_gain

    # Clamp to valid range
    x_corrected = torch.clamp(x_corrected, 0.0, 1.0)

    return x_corrected
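
A scalar sketch of how the gain is built and applied. It uses the polynomial form from the comment in `__init__` with the registered coefficients (0.3, 0.15, 0.05) and the strength blend from `forward`. `_compute_radial_gain` is not shown, so the normalisation of `r` (taken here as 0 at the image center and 1 at the corner) is an assumption, and the function names are illustrative.

```python
def polynomial_gain(r, coeffs=(0.3, 0.15, 0.05)):
    # V(r) = 1 + k1*r^2 + k2*r^4 + k3*r^6, with r the normalised radius.
    k1, k2, k3 = coeffs
    r2 = r * r
    return 1.0 + k1 * r2 + k2 * r2 ** 2 + k3 * r2 ** 3


def effective_gain(gain, strength):
    # Linear blend between "no correction" (1.0) and the full gain.
    return 1.0 + strength * (gain - 1.0)


def correct(x, r, strength=1.0):
    # Multiply by the blended gain, then clamp to [0, 1] as forward does.
    return min(max(x * effective_gain(polynomial_gain(r), strength), 0.0), 1.0)


center = correct(0.4, r=0.0)                     # gain 1.0, unchanged
corner_full = correct(0.4, r=1.0)                # gain 1.5
corner_half = correct(0.4, r=1.0, strength=0.5)  # gain 1.25
saturated = correct(0.8, r=1.0)                  # 1.2 before the clamp, 1.0 after
```

With these coefficients the corner gain is 1.5 under the assumed normalisation, so any corner pixel brighter than about 0.67 saturates at full strength.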

reverse

reverse(x)

Reverse lens shading correction (add vignetting back).

Parameters:

Name Type Description Default
x

Input tensor of shape [B, C, H, W], data range [0, 1].

required

Returns:

Name Type Description
x_vignetted

Tensor with vignetting applied, shape [B, C, H, W].

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/lens_shading.py
def reverse(self, x):
    """Reverse lens shading correction (add vignetting back).

    Args:
        x: Input tensor of shape [B, C, H, W], data range [0, 1].

    Returns:
        x_vignetted: Tensor with vignetting applied, shape [B, C, H, W].
    """
    if self.strength == 0:
        return x

    B, C, H, W = x.shape

    # Get or compute the gain map
    if self.shading_map is not None:
        if self.shading_map.shape[-2:] != (H, W):
            gain_map = F.interpolate(
                self.shading_map, size=(H, W), mode="bilinear", align_corners=True
            )
        else:
            gain_map = self.shading_map
    else:
        gain_map = self._compute_radial_gain(H, W, x.device, x.dtype)

    # Compute inverse gain
    effective_gain = 1 + self.strength * (gain_map - 1)
    inverse_gain = 1.0 / effective_gain

    # Apply inverse correction (add vignetting)
    x_vignetted = x * inverse_gain

    return x_vignetted
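
Since `forward` clamps and `reverse` does not, the pair is an exact round trip only for pixels that stay below 1.0 after correction. A scalar sketch, with an illustrative gain of 1.5 standing in for one entry of the gain map:

```python
def lsc_forward(x, gain, strength=1.0):
    eff = 1.0 + strength * (gain - 1.0)
    return min(max(x * eff, 0.0), 1.0)  # forward clamps to [0, 1]


def lsc_reverse(x, gain, strength=1.0):
    eff = 1.0 + strength * (gain - 1.0)
    return x / eff                      # reverse divides, no clamp


gain = 1.5                                          # illustrative corner gain
dark = lsc_reverse(lsc_forward(0.4, gain), gain)    # recovered
bright = lsc_reverse(lsc_forward(0.8, gain), gain)  # clipped at 1.0 first
```

The dark pixel returns to 0.4, while the bright one comes back as about 0.667 instead of 0.8, because the information above 1.0 was discarded by the clamp.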

end2end_imaging.sensor.isp_modules.AntiAliasingFilter

AntiAliasingFilter(method='weighted_average', kernel_size=3)

Bases: Module

Anti-Aliasing Filter (AAF).

Anti-aliasing filter is applied to raw Bayer data to reduce moiré patterns and aliasing artifacts before demosaicing.

Reference

[1] https://github.com/QiuJueqin/fast-openISP/blob/master/modules/aaf.py

Initialize the Anti-Aliasing Filter.

Parameters:

Name Type Description Default
method str

Filtering method. Options: "weighted_average", "gaussian", "none", or None.

'weighted_average'
kernel_size int

Size of the filter kernel. Even values are rounded up to the next odd size.

3
Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/anti_alising.py
def __init__(self, method="weighted_average", kernel_size=3):
    """Initialize the Anti-Aliasing Filter.

    Args:
        method (str): Filtering method. Options: "weighted_average", "gaussian", "none", or None.
        kernel_size (int): Size of the filter kernel. Even values are rounded up to the next odd size.
    """
    super().__init__()
    self.method = method
    self.kernel_size = kernel_size if kernel_size % 2 == 1 else kernel_size + 1

    # Pre-compute kernels
    if method == "weighted_average":
        # Weighted average kernel: center pixel gets higher weight
        kernel = torch.ones(1, 1, self.kernel_size, self.kernel_size)
        center = self.kernel_size // 2
        kernel[0, 0, center, center] = 8.0
        kernel = kernel / kernel.sum()
        self.register_buffer("kernel", kernel)
    elif method == "gaussian":
        # Gaussian kernel
        sigma = self.kernel_size / 6.0
        x = torch.arange(self.kernel_size) - self.kernel_size // 2
        x = x.float()
        kernel_1d = torch.exp(-0.5 * (x / sigma) ** 2)
        kernel_2d = torch.outer(kernel_1d, kernel_1d)
        kernel_2d = kernel_2d / kernel_2d.sum()
        self.register_buffer(
            "kernel", kernel_2d.view(1, 1, self.kernel_size, self.kernel_size)
        )
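
The two pre-computed kernels can be written out in plain Python to see the actual weights. The helper names are illustrative. The arithmetic follows the constructor above: a center weight of 8 for the weighted average, and `sigma = kernel_size / 6` for the Gaussian.

```python
import math


def weighted_average_kernel(size=3, center_weight=8.0):
    # All-ones kernel with a heavier centre, normalised to sum to 1.
    c = size // 2
    k = [[1.0] * size for _ in range(size)]
    k[c][c] = center_weight
    total = sum(map(sum, k))
    return [[v / total for v in row] for row in k]


def gaussian_kernel(size=3):
    # Separable Gaussian with sigma = size / 6, normalised to sum to 1.
    sigma = size / 6.0
    c = size // 2
    g = [math.exp(-0.5 * ((i - c) / sigma) ** 2) for i in range(size)]
    total = sum(g) ** 2
    return [[a * b / total for b in g] for a in g]


wa = weighted_average_kernel()  # centre 8/16, neighbours 1/16 each
ga = gaussian_kernel()          # centre about 0.62 for the 3x3 case
```

For the default 3x3 size, the Gaussian variant keeps more weight on the center pixel than the weighted average does, so it smooths slightly less.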

forward

forward(bayer)

Apply anti-aliasing filter to reduce moiré patterns.

Parameters:

Name Type Description Default
bayer

Input tensor of shape [B, 1, H, W], data range [0, 1].

required

Returns:

Type Description

Filtered bayer tensor of same shape as input.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/anti_alising.py
def forward(self, bayer):
    """Apply anti-aliasing filter to reduce moiré patterns.

    Args:
        bayer: Input tensor of shape [B, 1, H, W], data range [0, 1].

    Returns:
        Filtered bayer tensor of same shape as input.
    """
    if self.method is None or self.method == "none":
        return bayer

    if self.method in ["weighted_average", "gaussian"]:
        padding = self.kernel_size // 2
        # Apply convolution filter
        filtered = F.conv2d(bayer, self.kernel.to(bayer.dtype), padding=padding)
        return filtered

    else:
        raise ValueError(f"Unknown anti-aliasing method: {self.method}")
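
One consequence of calling `F.conv2d` with `padding=kernel_size // 2` is that the border is zero-padded. The sketch below rebuilds the default weighted-average kernel and filters a constant image to show the effect. It is a standalone illustration rather than a call into the library.

```python
import torch
import torch.nn.functional as F

# Default weighted-average kernel: centre 8/16, eight neighbours 1/16 each.
kernel = torch.ones(1, 1, 3, 3)
kernel[0, 0, 1, 1] = 8.0
kernel = kernel / kernel.sum()

flat = torch.ones(1, 1, 6, 6)                 # constant image, value 1.0
filtered = F.conv2d(flat, kernel, padding=1)  # zero padding, as in forward

interior = filtered[0, 0, 2, 2].item()  # full 3x3 support
edge = filtered[0, 0, 0, 2].item()      # three taps fall in the padding
corner = filtered[0, 0, 0, 0].item()    # five taps fall in the padding
```

The output keeps the input shape, but a constant image comes back darker along its border (13/16 on edges and 11/16 in corners) because the padded taps contribute zero.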

reverse

reverse(bayer)

Reverse anti-aliasing filter (approximation).

Note: Anti-aliasing is a lossy operation, so perfect reversal is not possible. This returns the input unchanged as an approximation.

Parameters:

Name Type Description Default
bayer

Input tensor of shape [B, 1, H, W], data range [0, 1].

required

Returns:

Type Description

Input tensor unchanged.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/anti_alising.py
def reverse(self, bayer):
    """Reverse anti-aliasing filter (approximation).

    Note: Anti-aliasing is a lossy operation, so perfect reversal is not possible.
    This returns the input unchanged as an approximation.

    Args:
        bayer: Input tensor of shape [B, 1, H, W], data range [0, 1].

    Returns:
        Input tensor unchanged.
    """
    # Anti-aliasing filtering is lossy; we cannot perfectly reverse it
    # Return input unchanged as best approximation
    return bayer

end2end_imaging.sensor.isp_modules.ColorSpaceConversion

ColorSpaceConversion()

Bases: Module

Color space conversion (CSC).

Converts images between RGB and YCrCb. Channels are ordered (Y, Cr, Cb), and the two chroma channels are offset by 0.5 so that neutral gray maps to 0.5.

Initialize color space conversion module.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/color_space.py
def __init__(self):
    """Initialize color space conversion module."""
    super().__init__()

    # RGB to YCrCb conversion matrix
    self.register_buffer(
        "rgb_to_ycrcb_matrix",
        torch.tensor(
            [
                [0.299, 0.587, 0.114],
                [0.5, -0.4187, -0.0813],
                [-0.1687, -0.3313, 0.5],
            ]
        ),
    )

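    # YCrCb to RGB conversion matrix.
    # Columns follow the (Y, Cr, Cb) channel order produced by rgb_to_ycrcb:
    # R = Y + 1.402 * Cr, G = Y - 0.714136 * Cr - 0.344136 * Cb, B = Y + 1.772 * Cb
    self.register_buffer(
        "ycrcb_to_rgb_matrix",
        torch.tensor(
            [[1.0, 1.402, 0.0], [1.0, -0.714136, -0.344136], [1.0, 0.0, 1.772]]
        ),
    )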

rgb_to_ycrcb

rgb_to_ycrcb(rgb_image)

Convert RGB to YCrCb (differentiable).

Parameters:

Name Type Description Default
rgb_image

Input tensor of shape [B, 3, H, W] in RGB format.

required

Returns:

Name Type Description
ycrcb_image

Output tensor of shape [B, 3, H, W] in YCrCb format.

Reference

[1] https://github.com/QiuJueqin/fast-openISP/blob/master/modules/csc.py

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/color_space.py
def rgb_to_ycrcb(self, rgb_image):
    """Convert RGB to YCrCb (differentiable).

    Args:
        rgb_image: Input tensor of shape [B, 3, H, W] in RGB format.

    Returns:
        ycrcb_image: Output tensor of shape [B, 3, H, W] in YCrCb format.

    Reference:
        [1] https://github.com/QiuJueqin/fast-openISP/blob/master/modules/csc.py
    """
    # Reshape for matrix multiplication
    rgb_reshaped = rgb_image.permute(0, 2, 3, 1)  # [B, H, W, 3]

    # Apply transformation
    ycrcb = torch.matmul(rgb_reshaped, self.rgb_to_ycrcb_matrix.T)

    # Add offset to Cr and Cb (non-in-place for gradient flow)
    offset = torch.tensor([0.0, 0.5, 0.5], device=ycrcb.device, dtype=ycrcb.dtype)
    ycrcb = ycrcb + offset

    # Reshape back
    ycrcb_image = ycrcb.permute(0, 3, 1, 2)  # [B, 3, H, W]

    return ycrcb_image

ycrcb_to_rgb

ycrcb_to_rgb(ycrcb_image)

Convert YCrCb to RGB (differentiable).

Parameters:

Name Type Description Default
ycrcb_image

Input tensor of shape [B, 3, H, W] in YCrCb format.

required

Returns:

Name Type Description
rgb_image

Output tensor of shape [B, 3, H, W] in RGB format.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/color_space.py
def ycrcb_to_rgb(self, ycrcb_image):
    """Convert YCrCb to RGB (differentiable).

    Args:
        ycrcb_image: Input tensor of shape [B, 3, H, W] in YCrCb format.

    Returns:
        rgb_image: Output tensor of shape [B, 3, H, W] in RGB format.
    """
    # Reshape for matrix multiplication
    ycrcb = ycrcb_image.permute(0, 2, 3, 1)  # [B, H, W, 3]

    # Subtract offset from Cr and Cb (non-in-place for gradient flow)
    offset = torch.tensor([0.0, 0.5, 0.5], device=ycrcb.device, dtype=ycrcb.dtype)
    ycrcb_adj = ycrcb - offset

    # Apply transformation
    rgb = torch.matmul(ycrcb_adj, self.ycrcb_to_rgb_matrix.T)

    # Clamp values to [0, 1]
    rgb = torch.clamp(rgb, 0.0, 1.0)

    # Reshape back
    rgb_image = rgb.permute(0, 3, 1, 2)  # [B, 3, H, W]

    return rgb_image

forward

forward(image, conversion='rgb_to_ycrcb')

Convert between color spaces.

Parameters:

Name Type Description Default
image

Input tensor of shape [B, 3, H, W].

required
conversion

Conversion direction, "rgb_to_ycrcb" or "ycrcb_to_rgb".

'rgb_to_ycrcb'

Returns:

Name Type Description
converted_image

Output tensor of shape [B, 3, H, W].

Raises:

Type Description
ValueError

If conversion is not "rgb_to_ycrcb" or "ycrcb_to_rgb".

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/color_space.py
def forward(self, image, conversion="rgb_to_ycrcb"):
    """Convert between color spaces.

    Args:
        image: Input tensor of shape [B, 3, H, W].
        conversion: Conversion direction, "rgb_to_ycrcb" or "ycrcb_to_rgb".

    Returns:
        converted_image: Output tensor of shape [B, 3, H, W].

    Raises:
        ValueError: If conversion is not "rgb_to_ycrcb" or "ycrcb_to_rgb".
    """
    if conversion == "rgb_to_ycrcb":
        return self.rgb_to_ycrcb(image)
    elif conversion == "ycrcb_to_rgb":
        return self.ycrcb_to_rgb(image)
    else:
        raise ValueError(f"Unknown conversion: {conversion}")

reverse

reverse(image)

Reverse color space conversion (YCrCb to RGB).

This is the inverse of the forward pass (which defaults to rgb_to_ycrcb).

Parameters:

Name Type Description Default
image

Input tensor of shape [B, 3, H, W] in YCrCb format.

required

Returns:

Name Type Description
rgb_image

Output tensor of shape [B, 3, H, W] in RGB format.

Source code in end2endimaging-src/end2end_imaging/sensor/isp_modules/color_space.py
def reverse(self, image):
    """Reverse color space conversion (YCrCb to RGB).

    This is the inverse of the forward pass (which defaults to rgb_to_ycrcb).

    Args:
        image: Input tensor of shape [B, 3, H, W] in YCrCb format.

    Returns:
        rgb_image: Output tensor of shape [B, 3, H, W] in RGB format.
    """
    return self.ycrcb_to_rgb(image)
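
A round trip on single pixels makes the channel order concrete. The sketch below is plain Python and illustrative only. It uses the forward coefficients from `rgb_to_ycrcb_matrix`, whose rows produce Y, Cr, Cb in that order, and writes the inverse coefficients out for the same order, so that column 1 multiplies Cr and column 2 multiplies Cb.

```python
RGB_TO_YCRCB = [
    [0.299, 0.587, 0.114],    # Y
    [0.5, -0.4187, -0.0813],  # Cr
    [-0.1687, -0.3313, 0.5],  # Cb
]

# Inverse written for the same (Y, Cr, Cb) order.
YCRCB_TO_RGB = [
    [1.0, 1.402, 0.0],            # R = Y + 1.402 * Cr
    [1.0, -0.714136, -0.344136],  # G = Y - 0.714136 * Cr - 0.344136 * Cb
    [1.0, 0.0, 1.772],            # B = Y + 1.772 * Cb
]
OFFSET = [0.0, 0.5, 0.5]  # chroma channels are shifted to centre on 0.5


def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def rgb_to_ycrcb(rgb):
    return [c + o for c, o in zip(matvec(RGB_TO_YCRCB, rgb), OFFSET)]


def ycrcb_to_rgb(ycc):
    rgb = matvec(YCRCB_TO_RGB, [c - o for c, o in zip(ycc, OFFSET)])
    return [min(max(c, 0.0), 1.0) for c in rgb]  # clamp as the method does


red_ycc = rgb_to_ycrcb([1.0, 0.0, 0.0])   # Cr saturates at 1.0 for pure red
red_back = ycrcb_to_rgb(red_ycc)          # recovers red up to rounding
gray_ycc = rgb_to_ycrcb([0.5, 0.5, 0.5])  # neutral gray has chroma 0.5
```

The round trip is accurate only to about three decimal places, because the forward coefficients are rounded to four digits. The inverse must use the same channel order as the forward matrix, or red and blue chroma are exchanged.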