Skip to content

Lens

Abstract base class for every lens model in DeepLens. Lens defines the shared interface — psf(), psf_rgb(), render(), sensor configuration, and file I/O — that GeoLens, HybridLens, DiffractiveLens, PSFNetLens, and DefocusLens all inherit.

deeplens.Lens

Lens(dtype=torch.float32, device=None, primary_wvln=DEFAULT_WAVE, wvln_rgb=WAVE_RGB, obj_depth=DEPTH)

Bases: DeepObj

Initialize a lens class.

Parameters:

Name Type Description Default
dtype dtype

Data type. Defaults to torch.float32.

float32
device str

Device to run the lens. Defaults to None.

None
primary_wvln float

Primary design wavelength [µm]. Used as fallback when a method is called without an explicit wvln. Defaults to DEFAULT_WAVE (0.587, d-line).

DEFAULT_WAVE
wvln_rgb sequence of float

Three wavelengths used for RGB (polychromatic) computations, ordered [R, G, B] in µm. Defaults to WAVE_RGB.

WAVE_RGB
obj_depth float

Default object depth [mm] used as fallback when a method is called without an explicit depth. Should be negative (object in front of the lens). Defaults to DEPTH (−20 000 mm, practical infinity).

DEPTH
Source code in deeplens-src/deeplens/lens.py
def __init__(
    self,
    dtype=torch.float32,
    device=None,
    primary_wvln=DEFAULT_WAVE,
    wvln_rgb=WAVE_RGB,
    obj_depth=DEPTH,
):
    """Set up device, dtype and the default optical configuration of a lens.

    Args:
        dtype (torch.dtype, optional): Default data type of the lens.
            Defaults to ``torch.float32``.
        device (str, optional): Device to run the lens on. When ``None``
            (default) the device is chosen by ``init_device()``.
        primary_wvln (float, optional): Primary design wavelength [µm],
            used whenever a method is called without an explicit ``wvln``.
            Defaults to ``DEFAULT_WAVE`` (0.587, d-line).
        wvln_rgb (sequence of float, optional): The three wavelengths [µm]
            used for RGB computations, ordered ``[R, G, B]``. Defaults to
            ``WAVE_RGB``.
        obj_depth (float, optional): Default object depth [mm], used
            whenever a method is called without an explicit ``depth``. Must
            be negative (object in front of the lens). Defaults to
            ``DEPTH`` (−20 000 mm, practical infinity).

    Raises:
        ValueError: If any argument has the wrong number of elements or lies
            outside its valid range (0.1 < wavelength < 10 µm, depth < 0).
    """
    # Resolve the device first: auto-detect when none is given.
    self.device = init_device() if device is None else torch.device(device)
    self.dtype = dtype

    # Promote every input to a float64 tensor so Python scalars, lists,
    # numpy arrays and tensors are validated the same way.
    wvln_main = torch.as_tensor(primary_wvln, dtype=torch.float64)
    wvln_triplet = torch.as_tensor(wvln_rgb, dtype=torch.float64)
    depth_default = torch.as_tensor(obj_depth, dtype=torch.float64)

    # Element-count checks.
    if wvln_main.numel() != 1:
        raise ValueError("primary_wvln must be a scalar wavelength in [µm].")
    if wvln_triplet.numel() != 3:
        raise ValueError("wvln_rgb must contain exactly three wavelengths in [µm].")
    if depth_default.numel() != 1:
        raise ValueError("obj_depth must be a scalar depth in [mm].")

    # Range checks, phrased as "not (valid)" so that NaN is rejected too.
    main_value = wvln_main.item()
    if not 0.1 < main_value < 10.0:
        raise ValueError("primary_wvln must be in [µm] and satisfy 0.1 < primary_wvln < 10.")
    if not torch.all((wvln_triplet > 0.1) & (wvln_triplet < 10.0)):
        raise ValueError("wvln_rgb must be in [µm] and every value must satisfy 0.1 < wvln < 10.")
    depth_value = depth_default.item()
    if not depth_value < 0.0:
        raise ValueError("obj_depth must be negative [mm], with the object in front of the lens.")

    # Store plain Python floats. Design wavelengths [µm]; IO may override.
    self.primary_wvln = float(main_value)
    self.wvln_rgb = list(map(float, wvln_triplet.tolist()))

    # Default object depth [mm].
    self.obj_depth = float(depth_value)

read_lens_json

read_lens_json(filename)

Read lens from a json file.

Source code in deeplens-src/deeplens/lens.py
def read_lens_json(self, filename):
    """Load the lens description from a JSON file.

    The base class provides no implementation; concrete lens types override it.

    Args:
        filename: Path of the JSON file to read.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    raise NotImplementedError()

write_lens_json

write_lens_json(filename)

Write lens to a json file.

Source code in deeplens-src/deeplens/lens.py
def write_lens_json(self, filename):
    """Save the lens description to a JSON file.

    The base class provides no implementation; concrete lens types override it.

    Args:
        filename: Path of the JSON file to write.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    raise NotImplementedError()

set_sensor

set_sensor(sensor_size, sensor_res)

Set sensor size and resolution.

Parameters:

Name Type Description Default
sensor_size tuple

Sensor size (w, h) in [mm].

required
sensor_res tuple

Sensor resolution (W, H) in [pixels].

required
Source code in deeplens-src/deeplens/lens.py
def set_sensor(self, sensor_size, sensor_res):
    """Set sensor size and resolution.

    Also derives ``pixel_size`` (sensor width / horizontal resolution, i.e.
    mm per pixel) and ``r_sensor`` (half of the sensor diagonal, in mm).
    It then refreshes the field of view via ``calc_fov``.

    Args:
        sensor_size (tuple): Sensor size (w, h) in [mm].
        sensor_res (tuple): Sensor resolution (W, H) in [pixels].

    Raises:
        AssertionError: If the aspect ratio of ``sensor_res`` does not match
            that of ``sensor_size`` (compared with a small relative
            tolerance).
    """
    import math

    # Compare aspect ratios by cross-multiplication with a relative
    # tolerance. Sizes are floats, e.g. those produced by set_sensor_res,
    # so exact ``==`` can reject a genuinely matching pair (0.1 * 3 != 0.3).
    # The check raises explicitly instead of using ``assert`` so it still
    # runs under ``python -O``. AssertionError is kept for backward
    # compatibility with existing callers.
    if not math.isclose(sensor_size[0] * sensor_res[1], sensor_size[1] * sensor_res[0]):
        raise AssertionError(
            "Sensor resolution aspect ratio does not match sensor size aspect ratio."
        )
    self.sensor_size = sensor_size
    self.sensor_res = sensor_res
    self.pixel_size = self.sensor_size[0] / self.sensor_res[0]
    self.r_sensor = float(np.sqrt(sensor_size[0] ** 2 + sensor_size[1] ** 2)) / 2
    self.calc_fov()

set_sensor_res

set_sensor_res(sensor_res)

Set sensor resolution (and aspect ratio) while keeping sensor radius unchanged.

Parameters:

Name Type Description Default
sensor_res tuple

Sensor resolution (W, H) in [pixels].

required
Source code in deeplens-src/deeplens/lens.py
def set_sensor_res(self, sensor_res):
    """Change the pixel resolution while keeping the sensor half-diagonal fixed.

    The physical sensor size is rescaled so that its aspect ratio follows
    ``sensor_res`` while its diagonal stays ``2 * self.r_sensor``.
    ``pixel_size`` and the field of view (via ``calc_fov``) are refreshed.

    Args:
        sensor_res (tuple): Sensor resolution (W, H) in [pixels].
    """
    self.sensor_res = sensor_res
    res_w, res_h = self.sensor_res[0], self.sensor_res[1]

    # Diagonal of the pixel grid, used to split the fixed physical diagonal
    # between the width and height axes.
    res_diag = float(np.sqrt(res_w ** 2 + res_h ** 2))
    diameter = 2 * self.r_sensor
    self.sensor_size = (diameter * res_w / res_diag, diameter * res_h / res_diag)

    self.pixel_size = self.sensor_size[0] / res_w
    self.calc_fov()

calc_fov

calc_fov()

Compute FoV (radian) of the lens.

Reference

[1] https://en.wikipedia.org/wiki/Angle_of_view_(photography)

Source code in deeplens-src/deeplens/lens.py
@torch.no_grad()
def calc_fov(self):
    """Compute the field of view (in radians) of the lens.

    Uses the paraxial relation ``fov = 2 * arctan(d / (2 * foclen))``, where
    ``d`` is the sensor width (horizontal), height (vertical) or diagonal
    (``2 * r_sensor``). The method does nothing until the lens has a
    ``foclen`` attribute.

    Sets ``hfov``, ``vfov`` and ``dfov`` (full angles), plus ``rfov_eff`` and
    ``rfov`` (half-diagonal angle).

    Reference:
        [1] https://en.wikipedia.org/wiki/Angle_of_view_(photography)
    """
    if not hasattr(self, "foclen"):
        return

    # sensor_size is (w, h): the horizontal FoV spans the sensor width and
    # the vertical FoV spans the sensor height.
    sensor_w, sensor_h = self.sensor_size[0], self.sensor_size[1]
    self.hfov = 2 * float(np.arctan(sensor_w / 2 / self.foclen))
    self.vfov = 2 * float(np.arctan(sensor_h / 2 / self.foclen))
    self.dfov = 2 * float(np.arctan(self.r_sensor / self.foclen))
    self.rfov_eff = self.dfov / 2  # effective (paraxial) half-diagonal FoV
    self.rfov = self.rfov_eff  # default to effective; GeoLens overrides with ray-traced value

psf

psf(points, wvln=None, ks=PSF_KS, **kwargs)

Compute the monochromatic PSF for one or more point sources.

Subclasses must override this method with a differentiable implementation. Three computation models are common in practice: geometric ray binning, coherent ray-wave, and Huygens spherical-wave integration.

Parameters:

Name Type Description Default
points Tensor

Point source coordinates, shape [N, 3] or [3]. x, y are normalised to [-1, 1] (relative to the sensor half-diagonal); z is depth in mm (must be negative, i.e. in front of the lens).

required
wvln float

Wavelength in micrometers. When None (default), falls back to self.primary_wvln.

None
ks int

Output PSF kernel size in pixels. Defaults to PSF_KS (64).

PSF_KS
**kwargs

Additional keyword arguments forwarded to the underlying PSF computation (e.g. spp, model, recenter).

{}

Returns:

Type Description

torch.Tensor: PSF intensity map, shape [ks, ks] for a single point or [N, ks, ks] for a batch.

Raises:

Type Description
NotImplementedError

This base implementation must be overridden.

Notes

The method is differentiable with respect to all optimisable lens parameters so it can be used directly inside a training loop.

Example

point = torch.tensor([0.0, 0.0, -10000.0])
psf = lens.psf(points=point, ks=64, model="geometric")
print(psf.shape)  # torch.Size([64, 64])

Source code in deeplens-src/deeplens/lens.py
def psf(self, points, wvln=None, ks=PSF_KS, **kwargs):
    """Return the monochromatic PSF of one or more point sources.

    This base class only defines the interface. Every concrete lens must
    override it with a differentiable implementation. Common choices are
    geometric ray binning, coherent ray-wave propagation, and Huygens
    spherical-wave integration.

    Args:
        points (torch.Tensor): Point sources, shape ``[N, 3]`` or ``[3]``.
            ``x, y`` are normalised to ``[-1, 1]`` (relative to the sensor
            half-diagonal); ``z`` is the depth in mm and must be negative
            (object in front of the lens).
        wvln (float, optional): Wavelength in micrometers. ``None``
            (default) means ``self.primary_wvln``.
        ks (int, optional): Side length of the output PSF kernel in pixels.
            Defaults to ``PSF_KS`` (64).
        **kwargs: Extra options for the concrete PSF model (e.g. ``spp``,
            ``model``, ``recenter``).

    Returns:
        torch.Tensor: PSF intensity map, shape ``[ks, ks]`` for a single
        point or ``[N, ks, ks]`` for a batch.

    Raises:
        NotImplementedError: Always, in the base class.

    Notes:
        Overrides are expected to be differentiable with respect to all
        optimisable lens parameters, so the PSF can be used inside a
        training loop.

    Example:
        >>> point = torch.tensor([0.0, 0.0, -10000.0])
        >>> psf = lens.psf(points=point, ks=64, model="geometric")
        >>> print(psf.shape)  # torch.Size([64, 64])
    """
    raise NotImplementedError()

psf_rgb

psf_rgb(points, ks=PSF_KS, **kwargs)

Compute the RGB (tri-chromatic) PSF by stacking three wavelength calls.

Calls psf three times for the RGB primary wavelengths stored in self.wvln_rgb and stacks the results along the channel axis.

Parameters:

Name Type Description Default
points Tensor

Point source coordinates, shape [N, 3] or [3]. Same convention as psf.

required
ks int

PSF kernel size. Defaults to PSF_KS.

PSF_KS
**kwargs

Forwarded to psf (e.g. spp, model).

{}

Returns:

Type Description

torch.Tensor: RGB PSF, shape [3, ks, ks] for a single point or [N, 3, ks, ks] for a batch.

Source code in deeplens-src/deeplens/lens.py
def psf_rgb(self, points, ks=PSF_KS, **kwargs):
    """Compute the RGB (tri-chromatic) PSF from three monochromatic PSFs.

    ``psf`` is evaluated once per wavelength in ``self.wvln_rgb`` (ordered
    ``[R, G, B]``). The results are stacked on a new channel axis placed
    just before the two kernel axes.

    Args:
        points (torch.Tensor): Point sources, shape ``[N, 3]`` or ``[3]``,
            using the same convention as `psf`.
        ks (int, optional): PSF kernel size. Defaults to ``PSF_KS``.
        **kwargs: Passed through to `psf` (e.g. ``spp``, ``model``).

    Returns:
        torch.Tensor: RGB PSF, shape ``[3, ks, ks]`` for a single point
        or ``[N, 3, ks, ks]`` for a batch.
    """
    channels = [
        self.psf(points=points, ks=ks, wvln=wavelength, **kwargs)
        for wavelength in self.wvln_rgb
    ]
    # dim=-3 inserts the colour axis in front of the (ks, ks) kernel axes.
    return torch.stack(channels, dim=-3)

point_source_grid

point_source_grid(depth, grid=(9, 9), normalized=True, quater=False, center=True)

Generate point source grid for PSF calculation.

Parameters:

Name Type Description Default
depth float

Depth of the point source.

required
grid tuple

Grid size (grid_w, grid_h). Defaults to (9, 9), meaning 9x9 grid.

(9, 9)
normalized bool

Return normalized object source coordinates. Defaults to True, meaning object sources xy coordinates range from [-1, 1].

True
quater bool

Keep only one quarter of the sensor plane (the quadrant with x ≥ 0 and y ≥ 0) to save memory. Defaults to False.

False
center bool

Use center of each patch. Defaults to True.

True

Returns:

Name Type Description
point_source

Object source coordinates of shape [grid_h, grid_w, 3]. When normalized, x ∈ [-1, 1] and y ∈ [-1, 1]; z is the depth, in (-∞, 0].

Source code in deeplens-src/deeplens/lens.py
def point_source_grid(
    self, depth, grid=(9, 9), normalized=True, quater=False, center=True
):
    """Generate a regular grid of point sources for PSF calculation.

    Args:
        depth (float): Depth (z-coordinate) of the point sources.
        grid (tuple or int): Grid size (grid_w, grid_h). An int ``n`` is
            treated as ``(n, n)``. Defaults to (9, 9), i.e. a 9x9 grid.
        normalized (bool): If True (default), x and y are normalized
            coordinates in [-1, 1]. If False, they are scaled by
            ``calc_scale(depth)`` times half the sensor width/height.
        quater (bool): Keep only a quarter of the grid to save memory: the
            rows with y >= 0 and the columns with x >= 0 (including the
            central row/column for odd sizes). Defaults to False.
        center (bool): If True (default), sample inset positions
            ``±(1 - 1 / (2 * (n - 1)))`` at the ends of each axis.
            Otherwise span [-0.98, 0.98].

    Returns:
        torch.Tensor: Point sources of shape [grid_h, grid_w, 3] holding
        (x, y, z). Rows run from +y (top) to -y; columns from -x to +x.
        An axis with a single sample sits at 0. With ``quater`` set, only
        the quarter described above is returned.
    """
    # Accept a scalar grid size as a square grid.
    if isinstance(grid, (int, np.integer)):
        grid_w, grid_h = int(grid), int(grid)
    else:
        grid_w, grid_h = grid[0], grid[1]

    if grid_w == 1:
        assert not quater, "Quater should be False when grid is 1."

    def _axis(num, descending):
        # 1-D sample positions along one axis; a single sample sits at 0.
        if num == 1:
            return torch.zeros(1, device=self.device)
        # Each axis uses its own sample count, so non-square grids are
        # inset consistently in x and in y.
        edge = 1 - 1 / 2 / (num - 1) if center else 0.98
        start, end = (edge, -edge) if descending else (-edge, edge)
        return torch.linspace(start, end, num, device=self.device)

    # "xy" indexing -> x and y both have shape [grid_h, grid_w].
    x, y = torch.meshgrid(
        _axis(grid_w, descending=False),
        _axis(grid_h, descending=True),
        indexing="xy",
    )
    z = torch.full_like(x, depth)
    point_source = torch.stack([x, y, z], dim=-1)

    # Quarter of the sensor plane: rows index y (grid_h), columns index x
    # (grid_w). Keep the y >= 0 rows and the x >= 0 columns.
    if quater:
        point_source = point_source[: (grid_h + 1) // 2, grid_w // 2 :, :]

    # De-normalize object source coordinates to physical coordinates
    if not normalized:
        scale = self.calc_scale(depth)
        point_source[..., 0] *= scale * self.sensor_size[0] / 2
        point_source[..., 1] *= scale * self.sensor_size[1] / 2

    return point_source

psf_map

psf_map(grid=(5, 5), wvln=None, depth=None, ks=PSF_KS, **kwargs)

Compute monochrome PSF map.

Parameters:

Name Type Description Default
grid tuple

Grid size (grid_w, grid_h). Defaults to (5, 5), meaning 5x5 grid.

(5, 5)
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None
depth float

Depth of the object. When None (default), falls back to self.obj_depth.

None
ks int

Kernel size. Defaults to PSF_KS.

PSF_KS

Returns:

Name Type Description
psf_map

Shape of [grid_h, grid_w, 1, ks, ks] (a single channel; psf_map_rgb concatenates three of these into [grid_h, grid_w, 3, ks, ks]).

Source code in deeplens-src/deeplens/lens.py
def psf_map(self, grid=(5, 5), wvln=None, depth=None, ks=PSF_KS, **kwargs):
    """Compute a monochromatic PSF map over a regular grid of field points.

    Args:
        grid (tuple or int): Grid size (grid_w, grid_h); an int ``n`` means
            ``(n, n)``. Defaults to (5, 5), meaning a 5x5 grid.
        wvln (float): Wavelength in µm. When ``None`` (default), falls back
            to ``self.primary_wvln``.
        depth (float): Depth of the object. When ``None`` (default), falls
            back to ``self.obj_depth``.
        ks (int): Kernel size. Defaults to PSF_KS.
        **kwargs: Forwarded to ``psf`` for every field point (e.g. ``spp``,
            ``model``).

    Returns:
        torch.Tensor: PSF map of shape [grid_h, grid_w, 1, ks, ks] (a single
        channel; ``psf_map_rgb`` concatenates three of these).
    """
    wvln = self.primary_wvln if wvln is None else wvln
    depth = self.obj_depth if depth is None else depth
    if isinstance(grid, (int, np.integer)):
        grid_w, grid_h = int(grid), int(grid)
    else:
        grid_w, grid_h = grid[0], grid[1]

    # Field points in row-major order: [grid_h, grid_w, 3] -> [grid_h * grid_w, 3]
    points = self.point_source_grid(depth=depth, grid=(grid_w, grid_h), center=True)
    points = points.reshape(-1, 3)

    # One PSF per field point. Extra keyword arguments are forwarded to psf;
    # they used to be accepted and then silently dropped.
    psfs = [self.psf(points=point, wvln=wvln, ks=ks, **kwargs) for point in points]
    psf_map = torch.stack(psfs).unsqueeze(1)  # shape [grid_h * grid_w, 1, ks, ks]

    # Reshape PSF map from [grid_h * grid_w, 1, ks, ks] -> [grid_h, grid_w, 1, ks, ks]
    return psf_map.reshape(grid_h, grid_w, 1, ks, ks)

psf_map_rgb

psf_map_rgb(grid=(5, 5), ks=PSF_KS, depth=None, **kwargs)

Compute RGB PSF map.

Parameters:

Name Type Description Default
grid tuple

Grid size (grid_w, grid_h). Defaults to (5, 5), meaning 5x5 grid.

(5, 5)
ks int

Kernel size. Defaults to PSF_KS, meaning PSF_KS x PSF_KS kernel size.

PSF_KS
depth float

Depth of the object. When None (default), falls back to self.obj_depth.

None
**kwargs

Additional arguments for psf_map().

{}

Returns:

Name Type Description
psf_map

Shape of [grid_h, grid_w, 3, ks, ks].

Source code in deeplens-src/deeplens/lens.py
def psf_map_rgb(self, grid=(5, 5), ks=PSF_KS, depth=None, **kwargs):
    """Compute an RGB PSF map by concatenating one monochrome map per colour.

    Args:
        grid (tuple): Grid size (grid_w, grid_h). Defaults to (5, 5), i.e. a
            5x5 grid.
        ks (int): Kernel size. Defaults to PSF_KS (PSF_KS x PSF_KS kernels).
        depth (float): Depth of the object. ``None`` (default) means
            ``self.obj_depth``.
        **kwargs: Passed through to psf_map().

    Returns:
        psf_map: Shape of [grid_h, grid_w, 3, ks, ks].
    """
    if depth is None:
        depth = self.obj_depth

    # One [grid_h, grid_w, 1, ks, ks] map per wavelength in self.wvln_rgb,
    # joined along the channel axis.
    per_channel = [
        self.psf_map(grid=grid, ks=ks, depth=depth, wvln=wavelength, **kwargs)
        for wavelength in self.wvln_rgb
    ]
    return torch.cat(per_channel, dim=2)  # shape [grid_h, grid_w, 3, ks, ks]

draw_psf_map

draw_psf_map(grid=(7, 7), ks=PSF_KS, depth=None, log_scale=False, save_name='./psf_map.png', show=False)

Draw RGB PSF map of the lens.

Source code in deeplens-src/deeplens/lens.py
@torch.no_grad()
def draw_psf_map(
    self,
    grid=(7, 7),
    ks=PSF_KS,
    depth=None,
    log_scale=False,
    save_name="./psf_map.png",
    show=False,
):
    """Draw the RGB PSF map of the lens as a tiled image.

    Each tile is one RGB PSF, normalized on its own. By default it is divided
    by its peak; with ``log_scale`` it is min-max scaled after a log
    transform. A 100 µm scale bar is drawn near the lower-left corner.

    Args:
        grid (tuple or int): Grid size (grid_w, grid_h); an int ``n`` means
            ``(n, n)``. Defaults to (7, 7).
        ks (int): PSF kernel size in pixels. Defaults to ``PSF_KS``.
        depth (float, optional): Object depth; ``None`` means
            ``self.obj_depth``.
        log_scale (bool): Use log-scale normalization. Defaults to False.
        save_name (str): Output image path, used when ``show`` is False.
        show (bool): If True, return ``(fig, ax)`` instead of saving.

    Returns:
        tuple or None: ``(fig, ax)`` when ``show`` is True, otherwise None.
    """
    depth = self.obj_depth if depth is None else depth

    # Normalize the grid once, so the PSF computation and the tiling below
    # agree whether the caller passed an int, a tuple or a list.
    if isinstance(grid, (int, np.integer)):
        grid_w, grid_h = int(grid), int(grid)
    else:
        grid_w, grid_h = int(grid[0]), int(grid[1])

    # Calculate RGB PSF map, shape [grid_h, grid_w, 3, ks, ks]
    psf_map = self.psf_map_rgb(depth=depth, grid=(grid_w, grid_h), ks=ks)

    # Tile every PSF into vis_map, shape [3, grid_h * ks, grid_w * ks]
    vis_map = torch.zeros(
        (3, grid_h * ks, grid_w * ks), device=psf_map.device, dtype=psf_map.dtype
    )
    for i in range(grid_h):
        for j in range(grid_w):
            psf = psf_map[i, j]  # shape [3, ks, ks]
            if log_scale:
                # Log scale for better visualization; 1e-4 is an empirical offset
                psf = torch.log(psf + 1e-4)
                psf = (psf - psf.min()) / (psf.max() - psf.min() + 1e-8)
            else:
                # Linear normalization by the tile's peak (skip all-zero tiles)
                local_max = psf.max()
                if local_max > 0:
                    psf = psf / local_max
            vis_map[:, i * ks : (i + 1) * ks, j * ks : (j + 1) * ks] = psf

    fig, ax = plt.subplots(figsize=(10, 10))
    vis_np = vis_map.permute(1, 2, 0).cpu().numpy()
    ax.imshow(vis_np)

    # Scale bar near bottom-left. pixel_size is in mm (see set_sensor), so
    # pixel_size * 1e3 is the pixel pitch in µm.
    H = vis_np.shape[0]
    scale_bar_length = 100  # [µm]
    arrow_length = scale_bar_length / (self.pixel_size * 1e3)  # [pixels]
    y_position = H - 20  # a little above the lower edge
    x_start = 20
    x_end = x_start + arrow_length

    ax.annotate(
        "",
        xy=(x_start, y_position),
        xytext=(x_end, y_position),
        arrowprops=dict(arrowstyle="-", color="white"),
        annotation_clip=False,
    )
    ax.text(
        x_end + 5,
        y_position,
        f"{scale_bar_length} μm",
        color="white",
        fontsize=12,
        ha="left",
        va="center",
        clip_on=False,
    )

    # Clean up axes and save
    ax.axis("off")
    fig.tight_layout(pad=0)

    if show:
        return fig, ax
    # Save through the figure object so the right figure is written even
    # if another figure has become current in the meantime.
    fig.savefig(save_name, dpi=300, bbox_inches="tight", pad_inches=0)
    plt.close(fig)

point_source_radial

point_source_radial(depth, grid=9, center=False, direction='diagonal', normalized=True)

Generate radial point sources from center to edge of the field.

Produces grid evenly-spaced points along a chosen radial direction (diagonal, meridional, or sagittal) in normalized or physical object-space coordinates.

Parameters:

Name Type Description Default
depth float

Object depth (z-coordinate) in mm.

required
grid int

Number of sample points. Defaults to 9.

9
center bool

If True, offset positions to bin centers. Defaults to False.

False
direction str

Sampling direction — "diagonal" (x = y, 45°, default), "y" (meridional, x = 0), "x" (sagittal, y = 0).

'diagonal'
normalized bool

If True, return coordinates in [0, 1]. If False, scale to physical object-space positions (mm). Defaults to True.

True

Returns:

Type Description

torch.Tensor: Point source positions, shape [grid, 3].

Source code in deeplens-src/deeplens/lens.py
def point_source_radial(self, depth, grid=9, center=False, direction="diagonal", normalized=True):
    """Sample point sources along a radial line from the field center outwards.

    ``grid`` radii are spread evenly from 0 towards the field edge and then
    placed along the requested direction.

    Args:
        depth (float): Object depth (z-coordinate) in mm.
        grid (int): Number of sample points. Defaults to 9. A single sample
            sits on axis.
        center (bool): If ``True``, the outermost radius is pulled in by half
            a bin (``1 - 1 / (2 * (grid - 1))``). Otherwise the samples span
            ``[0, 0.98]``. Defaults to ``False``.
        direction (str): ``"diagonal"`` (x = y, 45°, default), ``"y"``
            (meridional, x = 0) or ``"x"`` (sagittal, y = 0).
        normalized (bool): If ``True``, keep normalized coordinates in
            [0, 1]. If ``False``, scale x and y to physical object-space
            positions (mm) using ``calc_scale(depth)`` and half the sensor
            width/height. Defaults to ``True``.

    Returns:
        torch.Tensor: Point source positions, shape ``[grid, 3]``.

    Raises:
        ValueError: If ``direction`` is not one of the supported values.
    """
    # Radial sample positions
    if grid == 1:
        radius = torch.tensor([0.0], device=self.device)
    elif center:
        radius = torch.linspace(0, 1 - 1 / 2 / (grid - 1), grid, device=self.device)
    else:
        radius = torch.linspace(0, 0.98, grid, device=self.device)

    if direction not in ("diagonal", "x", "y"):
        raise ValueError(f"Invalid direction: {direction!r}. Use 'diagonal', 'x', or 'y'.")

    # "y" zeroes the x-coordinate and "x" zeroes the y-coordinate; "diagonal"
    # keeps both equal to the radius.
    zeros = torch.zeros_like(radius)
    px = zeros if direction == "y" else radius
    py = zeros if direction == "x" else radius

    points = torch.stack([px, py, torch.full_like(px, depth)], dim=-1)

    if not normalized:
        scale = self.calc_scale(depth)
        points[..., 0] = points[..., 0] * scale * self.sensor_size[0] / 2
        points[..., 1] = points[..., 1] * scale * self.sensor_size[1] / 2

    return points

draw_psf_radial

draw_psf_radial(M=3, depth=None, ks=PSF_KS, log_scale=False, save_name='./psf_radial.png')

Draw M RGB PSFs, each of size ks x ks, sampled along the 45° field diagonal from the on-axis point to the corner, and save them as a single image.

Source code in deeplens-src/deeplens/lens.py
@torch.no_grad()
def draw_psf_radial(
    self, M=3, depth=None, ks=PSF_KS, log_scale=False, save_name="./psf_radial.png"
):
    """Draw ``M`` RGB PSFs along the 45° field diagonal and save them as one image.

    The field points are evenly spaced from the on-axis point (0, 0) to the
    normalized corner (1, 1) at ``depth``. Each ``ks`` x ``ks`` PSF is
    peak-normalized, optionally log-scaled, and tiled into a single row.

    Args:
        M (int): Number of PSFs to draw. Defaults to 3.
        depth (float, optional): Object depth; ``None`` means
            ``self.obj_depth``.
        ks (int): PSF kernel size. Defaults to ``PSF_KS``.
        log_scale (bool): Apply a log transform before min-max scaling.
        save_name (str): Output image path.
    """
    from torchvision.utils import make_grid, save_image

    depth = self.obj_depth if depth is None else depth
    # Create the field points on the lens device, like the other
    # point-source helpers do.
    r = torch.linspace(0, 1, M, device=self.device)
    points = torch.stack((r, r, torch.full_like(r, depth)), dim=-1)

    psfs = []
    for point in points:
        psf = self.psf_rgb(points=point, ks=ks, recenter=True, spp=SPP_PSF)

        # Scale PSF for a better visualization; skip all-zero PSFs, which
        # would otherwise turn into NaNs.
        peak = psf.max()
        if peak > 0:
            psf = psf / peak

        if log_scale:
            psf = torch.log(psf + EPSILON)
            psf = (psf - psf.min()) / (psf.max() - psf.min()).clamp_min(EPSILON)

        psfs.append(psf)

    psf_grid = make_grid(psfs, nrow=M, padding=1, pad_value=0.0)
    save_image(psf_grid, save_name, normalize=True)

render

render(img_obj, depth=None, method='psf_patch', **kwargs)

Differentiable image simulation for a 2D (flat) scene.

Performs only the optical component of image simulation and is fully differentiable.

For incoherent imaging the intensity PSF is convolved with the object-space image. For coherent imaging the complex PSF is convolved with the complex object image before squaring for intensity.

Parameters:

Name Type Description Default
img_obj Tensor

Input image in linear (raw) space, shape [B, C, H, W].

required
depth float

Object depth in mm (negative value). When None (default), falls back to self.obj_depth.

None
method str

Rendering method. One of:

  • "psf_patch" – convolve a single PSF evaluated at patch_center (default).
  • "psf_map" – spatially-varying PSF block convolution.
'psf_patch'
**kwargs

Method-specific keyword arguments:

  • For "psf_map": psf_grid (tuple, default (10, 10)), psf_ks (int, default PSF_KS).
  • For "psf_patch": patch_center (tuple or Tensor, default (0.0, 0.0)), psf_ks (int).
{}

Returns:

Type Description

torch.Tensor: Rendered image, shape [B, C, H, W].

Raises:

Type Description
AssertionError

If method is "psf_map" and the image resolution does not match the sensor resolution.

Exception

If method is not recognised.

References

[1] "Optical Aberration Correction in Postprocessing using Imaging Simulation", TOG 2021. [2] "Efficient depth- and spatially-varying image simulation for defocus deblur", ICCVW 2025.

Example

img_rendered = lens.render(img, depth=-10000.0, method="psf_patch", patch_center=(0.3, 0.0), psf_ks=64)

Source code in deeplens-src/deeplens/lens.py
def render(self, img_obj, depth=None, method="psf_patch", **kwargs):
    """Differentiable image simulation for a 2D (flat) scene.

    Performs only the optical component of image simulation and is fully
    differentiable.

    For incoherent imaging the intensity PSF is convolved with the
    object-space image.  For coherent imaging the complex PSF is convolved
    with the complex object image before squaring for intensity.

    Args:
        img_obj (torch.Tensor): Input image in linear (raw) space,
            shape ``[B, C, H, W]``.
        depth (float, optional): Object depth in mm (negative value).
            When ``None`` (default), falls back to ``self.obj_depth``.
        method (str, optional): Rendering method.  One of:

            * ``"psf_patch"`` – convolve a single PSF evaluated at
              *patch_center* (default).
            * ``"psf_map"`` – spatially-varying PSF block convolution.

        **kwargs: Method-specific keyword arguments:

            * For ``"psf_map"``: ``psf_grid`` (tuple, default ``(10, 10)``),
              ``psf_ks`` (int, default ``PSF_KS``), ``psf_spp`` (int,
              default ``SPP_PSF``).
            * For ``"psf_patch"``: ``patch_center`` (tuple or Tensor,
              default ``(0.0, 0.0)``), ``psf_ks`` (int).

    Returns:
        torch.Tensor: Rendered image, shape ``[B, C, H, W]``.

    Raises:
        AssertionError: If *method* is ``"psf_map"`` and the image
            resolution does not match the sensor resolution.
        NotImplementedError: If *method* is ``"psf_pixel"``.
        ValueError: If *method* is not recognised. ValueError subclasses
            Exception, so existing ``except Exception`` handlers still apply.

    References:
        [1] "Optical Aberration Correction in Postprocessing using Imaging Simulation", TOG 2021.
        [2] "Efficient depth- and spatially-varying image simulation for defocus deblur", ICCVW 2025.

    Example:
        >>> img_rendered = lens.render(img, depth=-10000.0, method="psf_patch",
        ...                            patch_center=(0.3, 0.0), psf_ks=64)
    """
    depth = self.obj_depth if depth is None else depth
    # The 4-way unpack also enforces a 4-D [B, C, H, W] input.
    _, _, Himg, Wimg = img_obj.shape

    # Image simulation (in RAW space)
    if method == "psf_map":
        # A full-resolution render requires the image to match the sensor.
        # sensor_res is only read here, so psf_patch works without a sensor.
        Wsensor, Hsensor = self.sensor_res
        assert Wimg == Wsensor and Himg == Hsensor, (
            f"Sensor resolution {Wsensor}x{Hsensor} must match input image {Wimg}x{Himg}."
        )
        img_render = self.render_psf_map(
            img_obj,
            depth=depth,
            psf_grid=kwargs.get("psf_grid", (10, 10)),
            psf_ks=kwargs.get("psf_ks", PSF_KS),
            psf_spp=kwargs.get("psf_spp", SPP_PSF),
        )

    elif method == "psf_patch":
        # Render an image patch with its corresponding PSF
        img_render = self.render_psf_patch(
            img_obj,
            depth=depth,
            patch_center=kwargs.get("patch_center", (0.0, 0.0)),
            psf_ks=kwargs.get("psf_ks", PSF_KS),
        )

    elif method == "psf_pixel":
        raise NotImplementedError(
            "Per-pixel PSF convolution has not been implemented."
        )

    else:
        raise ValueError(f"Image simulation method {method} is not supported.")

    return img_render

render_psf

render_psf(img_obj, depth=None, patch_center=(0, 0), psf_ks=PSF_KS)

Render an image patch using PSF convolution. This is an alias of render_psf_patch; prefer calling render_psf_patch directly to avoid confusion.

Source code in deeplens-src/deeplens/lens.py
def render_psf(self, img_obj, depth=None, patch_center=(0, 0), psf_ks=PSF_KS):
    """Alias of ``render_psf_patch``; prefer calling that method directly.

    ``depth`` defaults to ``self.obj_depth`` when ``None``. All other
    arguments are passed through unchanged.
    """
    if depth is None:
        depth = self.obj_depth
    return self.render_psf_patch(img_obj, depth=depth, patch_center=patch_center, psf_ks=psf_ks)

render_psf_patch

render_psf_patch(img_obj, depth=None, patch_center=(0, 0), psf_ks=PSF_KS)

Render an image patch by convolving it with the RGB PSF evaluated at the patch center. Only the rendered image is returned.

Parameters:

Name Type Description Default
img_obj tensor

Input image object in raw space. Shape of [B, C, H, W].

required
depth float

Depth of the object. When None (default), falls back to self.obj_depth.

None
patch_center tensor

Center of the image patch. Shape of [2] or [B, 2].

(0, 0)
psf_ks int

PSF kernel size. Defaults to PSF_KS.

PSF_KS

Returns:

Name Type Description
img_render

Rendered image. Shape of [B, C, H, W].

Source code in deeplens-src/deeplens/lens.py
def render_psf_patch(self, img_obj, depth=None, patch_center=(0, 0), psf_ks=PSF_KS):
    """Render an image patch by convolving it with a single RGB PSF.

    The PSF is evaluated at the object-space point
    ``(patch_center[0], patch_center[1], depth)`` and applied uniformly to
    ``img_obj`` with ``conv_psf``.

    Args:
        img_obj (tensor): Input image object in raw space. Shape of [B, C, H, W].
        depth (float): Depth of the object. When ``None`` (default), falls
            back to ``self.obj_depth``.
        patch_center (tuple | list | tensor): Center of the image patch.
            A 2-element list/tuple, or a tensor of shape [2] or [B, 2].
        psf_ks (int): PSF kernel size. Defaults to PSF_KS.

    Returns:
        img_render: Rendered image. Shape of [B, C, H, W].

    Raises:
        TypeError: If ``patch_center`` is not a list, tuple, or tensor.
    """
    depth = self.obj_depth if depth is None else depth

    # Build the object-space point(s) (x, y, z) at which the PSF is evaluated.
    if isinstance(patch_center, (list, tuple)):
        # Cast to float so all-integer inputs (e.g. the default (0, 0) with an
        # integer depth) do not produce an int64 tensor, and create the tensor
        # on the lens device instead of always on CPU.
        points = torch.tensor(
            [float(patch_center[0]), float(patch_center[1]), float(depth)],
            device=self.device,
        ).unsqueeze(0)
    elif isinstance(patch_center, torch.Tensor):
        # Broadcast the scalar depth to every patch center.
        depth = torch.full_like(patch_center[..., 0], depth)
        points = torch.stack(
            [patch_center[..., 0], patch_center[..., 1], depth], dim=-1
        )
    else:
        raise TypeError(
            f"Patch center must be a list or tuple or tensor, but got {type(patch_center)}."
        )

    # Compute the RGB PSF and convolve. squeeze(0) drops a leading size-1
    # point dimension (presumably psf_rgb returns [N, 3, ks, ks] — confirm).
    psf = self.psf_rgb(points=points, ks=psf_ks).squeeze(0)
    img_render = conv_psf(img_obj, psf=psf)
    return img_render

render_psf_map

render_psf_map(img_obj, depth=None, psf_grid=7, psf_ks=PSF_KS, psf_spp=SPP_PSF)

Render image using PSF block convolution.

Note

Larger psf_grid and psf_ks typically give more accurate rendering but are slower.

Parameters:

Name Type Description Default
img_obj tensor

Input image object in raw space. Shape of [B, C, H, W].

required
depth float

Depth of the object. When None (default), falls back to self.obj_depth.

None
psf_grid int

PSF grid size.

7
psf_ks int

PSF kernel size. Defaults to PSF_KS.

PSF_KS
psf_spp int

Sample count forwarded to psf_map_rgb as spp. Defaults to SPP_PSF.

SPP_PSF

Returns:

Name Type Description
img_render

Rendered image. Shape of [B, C, H, W].

Source code in deeplens-src/deeplens/lens.py
def render_psf_map(
    self,
    img_obj,
    depth=None,
    psf_grid=7,
    psf_ks=PSF_KS,
    psf_spp=SPP_PSF,
):
    """Render a full image with spatially varying PSFs (PSF block convolution).

    A grid of RGB PSFs at a single object depth is computed with
    ``psf_map_rgb`` and then applied to the image by ``conv_psf_map``.

    Note:
        Larger ``psf_grid`` and ``psf_ks`` typically give more accurate
        rendering but are slower.

    Args:
        img_obj (tensor): Input image object in raw space. Shape of [B, C, H, W].
        depth (float): Depth of the object. ``None`` (default) means
            ``self.obj_depth``.
        psf_grid (int): PSF grid size, forwarded to ``psf_map_rgb`` as ``grid``.
        psf_ks (int): PSF kernel size. Defaults to PSF_KS.
        psf_spp (int): Sample count forwarded to ``psf_map_rgb`` as ``spp``.
            Defaults to SPP_PSF.

    Returns:
        img_render: Rendered image. Shape of [B, C, H, W].
    """
    if depth is None:
        depth = self.obj_depth
    grid_psfs = self.psf_map_rgb(
        grid=psf_grid, ks=psf_ks, depth=depth, spp=psf_spp
    )
    return conv_psf_map(img_obj, grid_psfs)

render_rgbd

render_rgbd(img_obj, depth_map, method='psf_patch', **kwargs)

Render RGBD image.

TODO: add obstruction-aware image simulation.

Parameters:

Name Type Description Default
img_obj tensor

Object image. Shape of [B, C, H, W].

required
depth_map tensor

Depth map [mm]. Shape of [B, 1, H, W] (a [B, H, W] map is also accepted). Values must be non-negative.

required
method str

Image simulation method. Defaults to "psf_patch".

'psf_patch'
**kwargs

Additional arguments for different methods. - interp_mode (str): "depth" or "disparity". Defaults to "disparity".

{}

Returns:

Name Type Description
img_render

Rendered image. Shape of [B, C, H, W].

Reference

[1] "Aberration-Aware Depth-from-Focus", TPAMI 2023. [2] "Efficient Depth- and Spatially-Varying Image Simulation for Defocus Deblur", ICCVW 2025.

Source code in deeplens-src/deeplens/lens.py
def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs):
    """Render an RGBD image with depth-varying PSFs.

    TODO: add obstruction-aware image simulation.

    Args:
        img_obj (tensor): Object image. Shape of [B, C, H, W].
        depth_map (tensor): Depth map [mm]. Shape of [B, 1, H, W] or
            [B, H, W]. Values must be non-negative; they are negated
            internally before being used for PSF interpolation/calculation.
        method (str, optional): Image simulation method, one of
            ``"psf_patch"``, ``"psf_map"`` or ``"psf_pixel"``. Defaults to
            ``"psf_patch"``.
        **kwargs: Additional arguments for different methods.
            - psf_ks (int): PSF kernel size. Defaults to PSF_KS.
            - depth_min / depth_max: depth range [mm] used to sample PSF
              layers ("psf_patch", "psf_map"). Default to
              ``depth_map.min()`` / ``depth_map.max()``.
            - num_layers (int): number of sampled depth layers
              ("psf_patch", "psf_map"). Defaults to 16.
            - interp_mode (str): "depth" or "disparity". Defaults to
              "disparity".
            - patch_center (tuple): field position of the patch
              ("psf_patch"). Defaults to (0.0, 0.0).
            - psf_grid (tuple): PSF grid (grid_w, grid_h) ("psf_map").
              Defaults to (8, 8).

    Returns:
        img_render: Rendered image. Shape of [B, C, H, W].

    Raises:
        ValueError: If ``depth_map`` contains negative values, if
            ``method == "psf_pixel"`` and the batch size is not 1, or if
            ``method`` is not supported.

    Reference:
        [1] "Aberration-Aware Depth-from-Focus", TPAMI 2023.
        [2] "Efficient Depth- and Spatially-Varying Image Simulation for Defocus Deblur", ICCVW 2025.
    """
    if depth_map.min() < 0:
        raise ValueError("Depth map should be positive.")

    if len(depth_map.shape) == 3:
        # [B, H, W] -> [B, 1, H, W]
        depth_map = depth_map.unsqueeze(1)

    if method == "psf_patch":
        # Render an image patch (same FoV, different depth)
        patch_center = kwargs.get("patch_center", (0.0, 0.0))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        depth_min = kwargs.get("depth_min", depth_map.min())
        depth_max = kwargs.get("depth_max", depth_map.max())
        num_layers = kwargs.get("num_layers", 16)
        interp_mode = kwargs.get("interp_mode", "disparity")

        # Sample reference depth layers. depths_ref is presumably in the
        # negative (object-space) convention, matching -depth_map below — confirm.
        _, depths_ref = self._sample_depth_layers(depth_min, depth_max, num_layers)

        points = torch.stack(
            [
                torch.full_like(depths_ref, patch_center[0]),
                torch.full_like(depths_ref, patch_center[1]),
                depths_ref,
            ],
            dim=-1,
        )
        psfs = self.psf_rgb(points=points, ks=psf_ks)  # (num_layers, 3, ks, ks)

        # Image simulation
        img_render = conv_psf_depth_interp(
            img_obj, -depth_map, psfs, depths_ref, interp_mode=interp_mode
        )
        return img_render

    elif method == "psf_map":
        # Render full resolution image with PSF map convolution (different FoV, different depth)
        psf_grid = kwargs.get("psf_grid", (8, 8))  # (grid_w, grid_h)
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        depth_min = kwargs.get("depth_min", depth_map.min())
        depth_max = kwargs.get("depth_max", depth_map.max())
        num_layers = kwargs.get("num_layers", 16)
        interp_mode = kwargs.get("interp_mode", "disparity")

        # Calculate PSF map at different depths (convert to negative for PSF calculation)
        _, depths_ref = self._sample_depth_layers(depth_min, depth_max, num_layers)

        psf_maps = []
        from tqdm import tqdm
        for depth in tqdm(depths_ref):
            psf_map = self.psf_map_rgb(grid=psf_grid, ks=psf_ks, depth=depth)
            psf_maps.append(psf_map)
        psf_map = torch.stack(
            psf_maps, dim=2
        )  # shape [grid_h, grid_w, num_layers, 3, ks, ks]

        # Image simulation
        img_render = conv_psf_map_depth_interp(
            img_obj, -depth_map, psf_map, depths_ref, interp_mode=interp_mode
        )
        return img_render

    elif method == "psf_pixel":
        # Render full resolution image with per-pixel PSF splatting. This method is computationally expensive.
        # Validate with an exception (not assert, which is stripped under -O),
        # before doing any work.
        if img_obj.shape[0] != 1:
            raise ValueError(
                f"psf_pixel rendering only supports batch size 1, got {img_obj.shape[0]}."
            )
        psf_ks = kwargs.get("psf_ks", PSF_KS)

        # Calculate points in the object space
        points_xy = torch.meshgrid(
            torch.linspace(-1, 1, img_obj.shape[-1], device=self.device),
            torch.linspace(1, -1, img_obj.shape[-2], device=self.device),
            indexing="xy",
        )
        points_xy = torch.stack(points_xy, dim=0).unsqueeze(0)
        points = torch.cat([points_xy, -depth_map], dim=1)  # shape [B, 3, H, W]

        # Calculate PSF at different pixels. This step is the most time-consuming.
        points = points.permute(0, 2, 3, 1).reshape(-1, 3)  # shape [H*W, 3]
        psfs = self.psf_rgb(points=points, ks=psf_ks)  # shape [H*W, 3, ks, ks]
        psfs = psfs.reshape(
            img_obj.shape[-2], img_obj.shape[-1], 3, psf_ks, psf_ks
        )  # shape [H, W, 3, ks, ks]

        # Image simulation
        img_render = splat_psf_per_pixel(img_obj, psfs)  # shape [1, C, H, W]
        return img_render

    else:
        raise ValueError(f"Image simulation method {method} is not supported.")

activate_grad

activate_grad(activate=True)

Activate gradient for each surface.

Source code in deeplens-src/deeplens/lens.py
def activate_grad(self, activate=True):
    """Toggle gradient tracking for each surface of the lens.

    The base ``Lens`` provides no implementation; concrete lens classes
    must override this hook.

    Args:
        activate (bool, optional): Flag forwarded by callers; ``True``
            (default) requests that gradients be activated. The exact
            semantics are defined by the overriding subclass.

    Raises:
        NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError()

get_optimizer_params

get_optimizer_params(lr=[0.0001, 0.0001, 0.1, 0.001])

Get optimizer parameters for different lens parameters.

Source code in deeplens-src/deeplens/lens.py
def get_optimizer_params(self, lr=[1e-4, 1e-4, 1e-1, 1e-3]):
    """Build optimizer parameter groups for the lens's optimizable parameters.

    The base ``Lens`` provides no implementation; concrete lens classes
    must override this hook.

    Args:
        lr (list of float, optional): Learning rates, presumably one per
            parameter group; the mapping is defined by the subclass.

    Raises:
        NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError()

get_optimizer

get_optimizer(lr=[0.0001, 0.0001, 0, 0.001])

Get optimizer.

Source code in deeplens-src/deeplens/lens.py
def get_optimizer(self, lr=None):
    """Create an Adam optimizer over the lens's optimizable parameters.

    Args:
        lr (list of float, optional): Learning rates forwarded to
            ``get_optimizer_params``. When ``None`` (default), a fresh
            ``[1e-4, 1e-4, 0, 1e-3]`` list is built on every call, so an
            override of ``get_optimizer_params`` that mutates ``lr`` cannot
            leak changes into later calls (mutable-default pitfall).

    Returns:
        torch.optim.Adam: Optimizer built from the parameters returned by
        ``get_optimizer_params``.
    """
    if lr is None:
        lr = [1e-4, 1e-4, 0, 1e-3]
    params = self.get_optimizer_params(lr)
    optimizer = torch.optim.Adam(params)
    return optimizer