GeoLens

Differentiable multi-element refractive lens via geometric ray tracing. GeoLens is the primary lens model in DeepLens: it ray-traces through a stack of optical surfaces to compute PSFs, render images, and optimize lens geometry end-to-end.

deeplens.GeoLens

GeoLens(filename=None, device=None, dtype=torch.float32, primary_wvln=DEFAULT_WAVE, wvln_rgb=WAVE_RGB, obj_depth=DEPTH)

Bases: GeoLensPSF, GeoLensEval, GeoLensOptim, GeoLensSurfOps, GeoLensVis, GeoLensIO, GeoLensVis3D, Lens

Differentiable geometric lens using vectorised ray tracing.

The primary lens model in DeepLens. Supports multi-element refractive (and partially reflective) systems loaded from JSON, Zemax .zmx, or Code V .seq files. Accuracy is aligned with Zemax OpticStudio.

Uses a mixin architecture – seven specialised mixin classes are composed at class definition time to keep each concern isolated:

  • GeoLensPSF – PSF computation (geometric, coherent, Huygens models).
  • GeoLensEval – optical performance evaluation (spot, MTF, distortion, vignetting).
  • GeoLensOptim – loss functions and gradient-based optimisation.
  • GeoLensSurfOps – surface geometry operations (aspheric conversion, pruning, shape correction, material matching).
  • GeoLensVis – 2-D layout and ray visualisation.
  • GeoLensIO – read/write JSON, Zemax .zmx.
  • GeoLensVis3D – 3-D mesh visualisation.

Key differentiability trick: Ray-surface intersection (newtons_method) uses a non-differentiable Newton loop followed by one differentiable Newton step to enable gradient flow.
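
The trick can be checked on a scalar toy problem. The sketch below is not the DeepLens newtons_method: it assumes a 2-D ray hitting a parabola z = 0.5 * c * x**2, the names residual, newton_step, and solve are illustrative, and finite differences stand in for autograd. It suggests that differentiating one Newton step taken from a frozen, converged root gives the same sensitivity to the surface parameter c as differentiating the whole solve.

```python
def residual(t, c, x0=0.3, z0=-1.0, dx=0.1, dz=1.0):
    """Ray height minus surface sag for a toy parabola z = 0.5 * c * x**2."""
    x = x0 + t * dx
    z = z0 + t * dz
    return z - 0.5 * c * x * x


def residual_dt(t, c, x0=0.3, dx=0.1, dz=1.0):
    """Derivative of the residual with respect to the ray parameter t."""
    x = x0 + t * dx
    return dz - c * x * dx


def newton_step(t, c):
    """One Newton update of t for a fixed surface parameter c."""
    return t - residual(t, c) / residual_dt(t, c)


def solve(c, iters=20):
    """Plain Newton loop; plays the role of the non-differentiable part."""
    t = 0.0
    for _ in range(iters):
        t = newton_step(t, c)
    return t


c, h = 0.5, 1e-6
t_star = solve(c)  # converged root, treated as a constant from here on

# Sensitivity of ONE extra Newton step to c, with t_star frozen.
grad_one_step = (newton_step(t_star, c + h) - newton_step(t_star, c - h)) / (2 * h)

# Sensitivity of the full solve to c (what we actually want).
grad_full = (solve(c + h) - solve(c - h)) / (2 * h)

print(grad_one_step, grad_full)  # the two values agree closely
```

The agreement follows from the implicit function theorem: at a root the residual is zero, so the derivative of the single step reduces to -(df/dc) / (df/dt), which is the derivative of the root itself.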

Attributes:

Name Type Description
surfaces list[Surface]

Ordered list of optical surfaces.

materials list[Material]

Optical materials between surfaces.

d_sensor Tensor

Back focal distance [mm].

foclen float

Effective focal length [mm].

fnum float

F-number.

rfov float

Half-diagonal field of view [radians].

sensor_size tuple

Physical sensor size (W, H) [mm].

sensor_res tuple

Sensor resolution (W, H) [pixels].

pixel_size float

Pixel pitch [mm].

References

Xinge Yang et al., "Curriculum learning for ab initio deep learned refractive optics," Nature Communications 2024.

Initialize a refractive lens.

There are two ways to initialize a GeoLens:
  1. Read a lens from .json/.zmx/.seq file
  2. Initialize a lens with no lens file, then manually add surfaces and materials

Parameters:

Name Type Description Default
filename str

Path to lens file (.json, .zmx, or .seq). Defaults to None.

None
device device

Device for tensor computations. Defaults to None.

None
dtype dtype

Data type for computations. Defaults to torch.float32.

float32
primary_wvln float

Primary design wavelength [µm]. Used as fallback when a method is called without an explicit wvln. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
wvln_rgb sequence of float

Three wavelengths used for RGB computations, ordered [R, G, B] in µm. Defaults to WAVE_RGB.

WAVE_RGB
obj_depth float

Default object depth [mm], used when a method is called without an explicit depth. Defaults to DEPTH.

DEPTH
Source code in deeplens-src/deeplens/geolens.py
def __init__(
    self,
    filename=None,
    device=None,
    dtype=torch.float32,
    primary_wvln=DEFAULT_WAVE,
    wvln_rgb=WAVE_RGB,
    obj_depth=DEPTH,
):
    """Initialize a refractive lens.

    There are two ways to initialize a GeoLens:
        1. Read a lens from .json/.zmx/.seq file
        2. Initialize a lens with no lens file, then manually add surfaces and materials

    Args:
        filename (str, optional): Path to lens file (.json, .zmx, or .seq). Defaults to None.
        device (torch.device, optional): Device for tensor computations. Defaults to None.
        dtype (torch.dtype, optional): Data type for computations. Defaults to torch.float32.
        primary_wvln (float, optional): Primary design wavelength [µm].
            Used as fallback when a method is called without an explicit
            ``wvln``.  Defaults to ``DEFAULT_WAVE``.
        wvln_rgb (sequence of float, optional): Three wavelengths used
            for RGB computations, ordered ``[R, G, B]`` in µm.  Defaults
            to ``WAVE_RGB``.
        obj_depth (float, optional): Default object depth [mm], used
            when a method is called without an explicit ``depth``.
            Defaults to ``DEPTH``.
    """
    super().__init__(
        device=device,
        dtype=dtype,
        primary_wvln=primary_wvln,
        wvln_rgb=wvln_rgb,
        obj_depth=obj_depth,
    )

    # Load lens file
    if filename is not None:
        self.read_lens(filename)
    else:
        self.surfaces = []
        self.materials = []
        # Set default sensor size and resolution
        self.sensor_size = (8.0, 8.0)
        self.sensor_res = (2000, 2000)
        self.to(self.device)

read_lens

read_lens(filename)

Read a GeoLens from a file.

Supported file formats
  • .json: DeepLens native JSON format
  • .zmx: Zemax lens file format
  • .seq: CODE V sequence file format

Parameters:

Name Type Description Default
filename str

Path to the lens file.

required
Note

Sensor size and resolution will usually be overwritten by values from the file.

Source code in deeplens-src/deeplens/geolens.py
def read_lens(self, filename):
    """Read a GeoLens from a file.

    Supported file formats:
        - .json: DeepLens native JSON format
        - .zmx: Zemax lens file format
        - .seq: CODE V sequence file format

    Args:
        filename (str): Path to the lens file.

    Note:
        Sensor size and resolution will usually be overwritten by values from the file.
    """
    # Load lens file
    if filename[-4:] == ".txt":
        raise ValueError("File format .txt has been deprecated.")
    elif filename[-5:] == ".json":
        self.read_lens_json(filename)
    elif filename[-4:] == ".zmx":
        self.read_lens_zmx(filename)
    elif filename[-4:] == ".seq":
        self.read_lens_seq(filename)
    else:
        raise ValueError(f"File format {filename[-4:]} not supported.")

    # Complete sensor size and resolution if not set from lens file
    if not hasattr(self, "sensor_size"):
        self.sensor_size = (8.0, 8.0)
        print(
            f"Sensor_size not found in lens file. Using default: {self.sensor_size} mm. "
            "Consider specifying sensor_size in the lens file or using set_sensor()."
        )

    if not hasattr(self, "sensor_res"):
        self.sensor_res = (2000, 2000)
        print(
            f"Sensor_res not found in lens file. Using default: {self.sensor_res} pixels. "
            "Consider specifying sensor_res in the lens file or using set_sensor()."
        )
        self.set_sensor_res(self.sensor_res)

    # After loading lens, compute foclen, fov and fnum
    self.to(self.device)
    self.astype(self.dtype)
    self.post_computation()
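
The suffix dispatch above can be sketched outside the class. This is an illustrative alternative, not the library's code: pick_reader and READERS are hypothetical names, the file paths are made up, and the case-insensitive matching via pathlib is a design choice of the sketch (the source compares filename[-4:] and filename[-5:] exactly).

```python
from pathlib import Path

# Hypothetical table: file suffix -> name of the reader method used in the source.
READERS = {
    ".json": "read_lens_json",
    ".zmx": "read_lens_zmx",
    ".seq": "read_lens_seq",
}


def pick_reader(filename):
    """Return the reader method name for a lens file path."""
    suffix = Path(filename).suffix.lower()  # full suffix, lower-cased
    if suffix == ".txt":
        raise ValueError("File format .txt has been deprecated.")
    try:
        return READERS[suffix]
    except KeyError:
        raise ValueError(f"File format {suffix!r} not supported.") from None


print(pick_reader("lenses/example.json"))  # read_lens_json
print(pick_reader("DOUBLET.ZMX"))          # read_lens_zmx
```

Using the full suffix also makes the error message report ".json5" or ".yaml" whole rather than the last four characters.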

post_computation

post_computation()

Compute derived optical properties after loading or modifying lens.

Calculates and caches
  • Effective focal length (EFL)
  • Entrance and exit pupil positions and radii
  • Field of view (FoV) in horizontal, vertical, and diagonal directions
  • F-number
  • Lens design constraints (edge/center thickness bounds, etc.)
Note

This method should be called after any changes to the lens geometry.

Source code in deeplens-src/deeplens/geolens.py
def post_computation(self):
    """Compute derived optical properties after loading or modifying lens.

    Calculates and caches:
        - Effective focal length (EFL)
        - Entrance and exit pupil positions and radii
        - Field of view (FoV) in horizontal, vertical, and diagonal directions
        - F-number
        - Lens design constraints (edge/center thickness bounds, etc.)

    Note:
        This method should be called after any changes to the lens geometry.
    """
    self.calc_foclen()
    self.calc_pupil()
    self.calc_fov()
    self.init_constraints()

__call__

__call__(ray)

Trace rays through the lens system.

Makes the GeoLens callable: lens(ray) is equivalent to lens.trace(ray) and returns the same (ray_out, ray_o_rec) tuple.

Source code in deeplens-src/deeplens/geolens.py
def __call__(self, ray):
    """Trace rays through the lens system.

    Makes the GeoLens callable, allowing ray tracing with function call syntax.
    """
    return self.trace(ray)

sample_grid_rays

sample_grid_rays(depth=float('inf'), num_grid=(11, 11), num_rays=SPP_PSF, wvln=None, uniform_fov=True, sample_more_off_axis=False, scale_pupil=1.0)

Sample grid rays from object space. (1) If depth is infinite, sample parallel rays at different field angles. (2) If depth is finite, sample point source rays from the object plane.

This function is usually used for (1) PSF map, (2) RMS error map, and (3) spot diagram calculation.

Parameters:

Name Type Description Default
depth float

Object depth in mm. An infinite depth samples parallel rays; a finite depth samples point-source rays. Defaults to float("inf").

float('inf')
num_grid tuple

Number of grid points along x and y. An int n is expanded to (n, n). Defaults to (11, 11).

(11, 11)
num_rays int

number of rays. Defaults to SPP_PSF.

SPP_PSF
wvln float

Ray wavelength in µm. When None (default), falls back to self.primary_wvln.

None
uniform_fov bool

If True, sample uniform FoV angles.

True
sample_more_off_axis bool

If True, sample more off-axis rays.

False
scale_pupil float

Scale factor for pupil radius.

1.0

Returns:

Name Type Description
ray Ray object

Ray object. Shape [num_grid[1], num_grid[0], num_rays, 3]

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def sample_grid_rays(
    self,
    depth=float("inf"),
    num_grid=(11, 11),
    num_rays=SPP_PSF,
    wvln=None,
    uniform_fov=True,
    sample_more_off_axis=False,
    scale_pupil=1.0,
):
    """Sample grid rays from object space.
        (1) If depth is infinite, sample parallel rays at different field angles.
        (2) If depth is finite, sample point source rays from the object plane.

    This function is usually used for (1) PSF map, (2) RMS error map, and (3) spot diagram calculation.

    Args:
        depth (float, optional): sampling depth. Defaults to float("inf").
        num_grid (tuple or int, optional): number of grid points along x and y. An int n is expanded to (n, n). Defaults to (11, 11).
        num_rays (int, optional): number of rays. Defaults to SPP_PSF.
        wvln (float, optional): ray wvln in µm. When ``None`` (default),
            falls back to ``self.primary_wvln``.
        uniform_fov (bool, optional): If True, sample uniform FoV angles.
        sample_more_off_axis (bool, optional): If True, sample more off-axis rays.
        scale_pupil (float, optional): Scale factor for pupil radius.

    Returns:
        ray (Ray object): Ray object. Shape [num_grid[1], num_grid[0], num_rays, 3]
    """
    wvln = self.primary_wvln if wvln is None else wvln

    # Normalize num_grid to a tuple if it's an int
    if isinstance(num_grid, int):
        num_grid = (num_grid, num_grid)

    # Calculate field angles for grid source. Top-left field has positive fov_x and negative fov_y
    x_list = [x for x in np.linspace(1, -1, num_grid[0])]
    y_list = [y for y in np.linspace(-1, 1, num_grid[1])]
    if sample_more_off_axis:
        x_list = [np.sign(x) * np.abs(x) ** 0.5 for x in x_list]
        y_list = [np.sign(y) * np.abs(y) ** 0.5 for y in y_list]

    # Calculate FoV_x and FoV_y
    if uniform_fov:
        # Sample uniform FoV angles
        fov_x_list = [x * self.vfov / 2 for x in x_list]
        fov_y_list = [y * self.hfov / 2 for y in y_list]
        fov_x_list = [float(np.rad2deg(fov_x)) for fov_x in fov_x_list]
        fov_y_list = [float(np.rad2deg(fov_y)) for fov_y in fov_y_list]
    else:
        # Sample uniform object grid
        fov_x_list = [np.arctan(x * np.tan(self.vfov / 2)) for x in x_list]
        fov_y_list = [np.arctan(y * np.tan(self.hfov / 2)) for y in y_list]
        fov_x_list = [float(np.rad2deg(fov_x)) for fov_x in fov_x_list]
        fov_y_list = [float(np.rad2deg(fov_y)) for fov_y in fov_y_list]

    # Sample rays (collimated or point source via unified API)
    rays = self.sample_from_fov(
        fov_x=fov_x_list,
        fov_y=fov_y_list,
        depth=depth,
        num_rays=num_rays,
        wvln=wvln,
        scale_pupil=scale_pupil,
    )
    return rays
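
The two field-sampling modes differ only in how a normalized grid position in [-1, 1] maps to an angle. The sketch below reproduces that mapping for one axis in plain Python; field_fractions and field_angles_deg are illustrative names, and the 80° full field of view is an arbitrary assumption.

```python
import math


def field_fractions(n, more_off_axis=False):
    """n normalized positions from -1 to 1 (n >= 2), optionally pushed outward."""
    fracs = [-1.0 + 2.0 * i / (n - 1) for i in range(n)]
    if more_off_axis:
        # sign(x) * |x| ** 0.5 moves interior samples toward the field edge
        fracs = [math.copysign(abs(f) ** 0.5, f) for f in fracs]
    return fracs


def field_angles_deg(full_fov_rad, n, uniform_fov=True, more_off_axis=False):
    """Field angles in degrees for one axis of the grid."""
    half = full_fov_rad / 2
    fracs = field_fractions(n, more_off_axis)
    if uniform_fov:
        angles = [f * half for f in fracs]                        # equal angle steps
    else:
        angles = [math.atan(f * math.tan(half)) for f in fracs]   # equal steps on the object plane
    return [math.degrees(a) for a in angles]


full_fov = math.radians(80.0)  # assumed full field of view
uniform_angle = field_angles_deg(full_fov, 5, uniform_fov=True)
uniform_object = field_angles_deg(full_fov, 5, uniform_fov=False)
print(uniform_angle)
print(uniform_object)
```

Both modes share the same extreme angles; the object-grid mode places interior samples at larger angles because tan grows faster than its argument.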

sample_radial_rays

sample_radial_rays(num_field=5, depth=float('inf'), num_rays=SPP_PSF, wvln=None, direction='y')

Sample radial rays at evenly-spaced field angles along a chosen direction.

Parameters:

Name Type Description Default
num_field int

Number of field angles from on-axis to full-field. Defaults to 5.

5
depth float

Object distance in mm. Use float('inf') for collimated light. Defaults to float('inf').

float('inf')
num_rays int

Rays per field position. Defaults to SPP_PSF.

SPP_PSF
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None
direction str

Sampling direction: "y" (meridional, default), "x" (sagittal), or "diagonal" (45°, x = y).

'y'

Returns:

Name Type Description
Ray

Ray object with shape [num_field, num_rays, 3].

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def sample_radial_rays(
    self,
    num_field=5,
    depth=float("inf"),
    num_rays=SPP_PSF,
    wvln=None,
    direction="y",
):
    """Sample radial rays at evenly-spaced field angles along a chosen direction.

    Args:
        num_field (int): Number of field angles from on-axis to full-field.
            Defaults to 5.
        depth (float): Object distance in mm. Use ``float('inf')`` for
            collimated light. Defaults to ``float('inf')``.
        num_rays (int): Rays per field position. Defaults to ``SPP_PSF``.
        wvln (float): Wavelength in µm. When ``None`` (default), falls
            back to ``self.primary_wvln``.
        direction (str): Sampling direction —
            ``"y"`` (meridional, default),
            ``"x"`` (sagittal),
            ``"diagonal"`` (45°, x = y).

    Returns:
        Ray: Ray object with shape ``[num_field, num_rays, 3]``.
    """
    wvln = self.primary_wvln if wvln is None else wvln
    device = self.device
    fov_deg = self.rfov * 180 / torch.pi
    fov_list = torch.linspace(0, fov_deg, num_field, device=device)

    if direction == "y":
        ray = self.sample_from_fov(
            fov_x=0.0, fov_y=fov_list, depth=depth, num_rays=num_rays, wvln=wvln
        )
    elif direction == "x":
        ray = self.sample_from_fov(
            fov_x=fov_list, fov_y=0.0, depth=depth, num_rays=num_rays, wvln=wvln
        )
    elif direction == "diagonal":
        # sample_from_fov creates a meshgrid; for pairwise diagonal, loop
        rays = [
            self.sample_from_fov(
                fov_x=f.item(), fov_y=f.item(), depth=depth, num_rays=num_rays, wvln=wvln
            )
            for f in fov_list
        ]
        ray_o = torch.stack([r.o for r in rays], dim=0)
        ray_d = torch.stack([r.d for r in rays], dim=0)
        ray = Ray(ray_o, ray_d, wvln, device=device)
    else:
        raise ValueError(f"Invalid direction: {direction!r}. Use 'x', 'y', or 'diagonal'.")
    return ray
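
The diagonal branch loops because a meshgrid-style sampler pairs every x angle with every y angle. A minimal sketch of the difference, using hypothetical field angles and plain Python tuples in place of rays:

```python
from itertools import product

fov_list = [0.0, 10.0, 20.0]  # degrees, hypothetical values

# Passing the same list for both axes to a meshgrid-style sampler
# yields every (fov_x, fov_y) combination.
grid_fields = list(product(fov_list, fov_list))

# The diagonal direction needs only the matched pairs (f, f),
# which is what the per-field loop in the source produces.
diagonal_fields = list(zip(fov_list, fov_list))

print(len(grid_fields), len(diagonal_fields))  # 9 3
```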

sample_from_points

sample_from_points(points=[[0.0, 0.0, -10000.0]], num_rays=SPP_PSF, wvln=None, scale_pupil=1.0)

Sample rays from point sources in object space (absolute physical coordinates).

Used for PSF and chief ray calculation.

Parameters:

Name Type Description Default
points list or Tensor

Ray origins in shape [3], [N, 3], or [Nx, Ny, 3].

[[0.0, 0.0, -10000.0]]
num_rays int

Number of rays per point. Default: SPP_PSF.

SPP_PSF
wvln float

Wavelength of rays in µm. When None (default), falls back to self.primary_wvln.

None
scale_pupil float

Scale factor for pupil radius.

1.0

Returns:

Name Type Description
Ray

Sampled rays with shape (*points.shape[:-1], num_rays, 3).

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def sample_from_points(
    self,
    points=[[0.0, 0.0, -10000.0]],
    num_rays=SPP_PSF,
    wvln=None,
    scale_pupil=1.0,
):
    """
    Sample rays from point sources in object space (absolute physical coordinates).

    Used for PSF and chief ray calculation.

    Args:
        points (list or Tensor): Ray origins in shape [3], [N, 3], or [Nx, Ny, 3].
        num_rays (int): Number of rays per point. Default: SPP_PSF.
        wvln (float): Wavelength of rays in µm. When ``None`` (default),
            falls back to ``self.primary_wvln``.
        scale_pupil (float): Scale factor for pupil radius.

    Returns:
        Ray: Sampled rays with shape ``(\\*points.shape[:-1], num_rays, 3)``.
    """
    wvln = self.primary_wvln if wvln is None else wvln

    # Ray origin is given
    if not torch.is_tensor(points):
        ray_o = torch.tensor(points, device=self.device)
    else:
        ray_o = points.to(self.device)

    # Sample points on the pupil
    pupilz, pupilr = self.get_entrance_pupil()
    pupilr *= scale_pupil
    ray_o2 = self.sample_circle(
        r=pupilr, z=pupilz, shape=(*ray_o.shape[:-1], num_rays)
    )

    # Compute ray directions
    if len(ray_o.shape) == 1:
        # Input point shape is [3]
        ray_o = ray_o.unsqueeze(0).repeat(num_rays, 1)  # shape [num_rays, 3]
        ray_d = ray_o2 - ray_o

    elif len(ray_o.shape) == 2:
        # Input point shape is [N, 3]
        ray_o = ray_o.unsqueeze(1).repeat(1, num_rays, 1)  # shape [N, num_rays, 3]
        ray_d = ray_o2 - ray_o

    elif len(ray_o.shape) == 3:
        # Input point shape is [Nx, Ny, 3]
        ray_o = ray_o.unsqueeze(2).repeat(
            1, 1, num_rays, 1
        )  # shape [Nx, Ny, num_rays, 3]
        ray_d = ray_o2 - ray_o

    else:
        raise Exception("The shape of input object positions is not supported.")

    # Calculate rays
    rays = Ray(ray_o, ray_d, wvln, device=self.device)
    return rays

sample_from_fov

sample_from_fov(fov_x=[0.0], fov_y=[0.0], depth=float('inf'), num_rays=SPP_CALC, wvln=None, entrance_pupil=True, scale_pupil=1.0)

Sample rays from object space at given field angles.

For infinite depth, generates collimated parallel rays: origins are distributed on the entrance pupil and all rays in a field share the same direction determined by the FOV angle.

For finite depth, generates diverging point-source rays: the point source position is determined by FOV angle and depth, and rays fan out toward the entrance pupil.

Parameters:

Name Type Description Default
fov_x float or list

Field angle(s) in the xz plane (degrees).

[0.0]
fov_y float or list

Field angle(s) in the yz plane (degrees).

[0.0]
depth float

Object distance in mm. float('inf') for collimated rays, finite for point-source rays.

float('inf')
num_rays int

Number of rays per field point.

SPP_CALC
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None
entrance_pupil bool

If True, sample on entrance pupil; otherwise on surface 0. Default: True.

True
scale_pupil float

Scale factor for pupil radius.

1.0

Returns:

Name Type Description
Ray

Rays with shape [..., num_rays, 3], where leading dims are squeezed when the corresponding fov input is scalar.

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def sample_from_fov(
    self,
    fov_x=[0.0],
    fov_y=[0.0],
    depth=float("inf"),
    num_rays=SPP_CALC,
    wvln=None,
    entrance_pupil=True,
    scale_pupil=1.0,
):
    """Sample rays from object space at given field angles.

    For infinite depth, generates collimated parallel rays: origins are
    distributed on the entrance pupil and all rays in a field share the
    same direction determined by the FOV angle.

    For finite depth, generates diverging point-source rays: the point
    source position is determined by FOV angle and depth, and rays fan
    out toward the entrance pupil.

    Args:
        fov_x (float or list): Field angle(s) in the xz plane (degrees).
        fov_y (float or list): Field angle(s) in the yz plane (degrees).
        depth (float): Object distance in mm. ``float('inf')`` for
            collimated rays, finite for point-source rays.
        num_rays (int): Number of rays per field point.
        wvln (float): Wavelength in µm. When ``None`` (default), falls
            back to ``self.primary_wvln``.
        entrance_pupil (bool): If True, sample on entrance pupil;
            otherwise on surface 0. Default: True.
        scale_pupil (float): Scale factor for pupil radius.

    Returns:
        Ray: Rays with shape ``[..., num_rays, 3]``, where leading dims
            are squeezed when the corresponding fov input is scalar.
    """
    wvln = self.primary_wvln if wvln is None else wvln

    # Track which inputs were scalar for output shape
    x_scalar = isinstance(fov_x, (float, int))
    y_scalar = isinstance(fov_y, (float, int))
    if x_scalar:
        fov_x = [float(fov_x)]
    if y_scalar:
        fov_y = [float(fov_y)]

    fov_x_rad = torch.tensor([fx * torch.pi / 180 for fx in fov_x], device=self.device)
    fov_y_rad = torch.tensor([fy * torch.pi / 180 for fy in fov_y], device=self.device)
    fov_x_grid, fov_y_grid = torch.meshgrid(fov_x_rad, fov_y_rad, indexing="xy")

    # Pupil position and radius
    if entrance_pupil:
        pupilz, pupilr = self.get_entrance_pupil()
    else:
        pupilz, pupilr = self.surfaces[0].d.item(), self.surfaces[0].r
    pupilr *= scale_pupil

    if depth == float("inf"):
        # Collimated rays: origins on pupil, uniform direction per field
        ray_o = self.sample_circle(
            r=pupilr, z=pupilz, shape=[len(fov_y), len(fov_x), num_rays]
        )
        dx = torch.tan(fov_x_grid).unsqueeze(-1).expand_as(ray_o[..., 0])
        dy = torch.tan(fov_y_grid).unsqueeze(-1).expand_as(ray_o[..., 1])
        dz = torch.ones_like(ray_o[..., 2])
        ray_d = torch.stack((dx, dy, dz), dim=-1)

        if x_scalar:
            ray_o = ray_o.squeeze(1)
            ray_d = ray_d.squeeze(1)
        if y_scalar:
            ray_o = ray_o.squeeze(0)
            ray_d = ray_d.squeeze(0)

        rays = Ray(ray_o, ray_d, wvln, device=self.device)
        rays.prop_to(-1.0)

    else:
        # Point-source rays: origin at object point, fan toward pupil
        x = torch.tan(fov_x_grid) * depth
        y = torch.tan(fov_y_grid) * depth
        z = torch.full_like(x, depth)
        points = torch.stack((x, y, z), dim=-1)

        if x_scalar:
            points = points.squeeze(-2)
        if y_scalar:
            points = points.squeeze(0)

        rays = self.sample_from_points(
            points=points, num_rays=num_rays, wvln=wvln, scale_pupil=scale_pupil
        )

    return rays
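
The geometry of the two branches can be written out for a single field in plain Python. The function names below are illustrative, and the negative depth is an assumption taken from the default object point [0.0, 0.0, -10000.0] used elsewhere on this page.

```python
import math


def collimated_direction(fov_x_deg, fov_y_deg):
    """Unnormalized direction (tan fx, tan fy, 1) shared by all rays of one field."""
    return (
        math.tan(math.radians(fov_x_deg)),
        math.tan(math.radians(fov_y_deg)),
        1.0,
    )


def point_source_position(fov_x_deg, fov_y_deg, depth):
    """Object point for a finite depth: (tan fx * depth, tan fy * depth, depth)."""
    return (
        math.tan(math.radians(fov_x_deg)) * depth,
        math.tan(math.radians(fov_y_deg)) * depth,
        depth,
    )


d = collimated_direction(0.0, 45.0)
p = point_source_position(0.0, 45.0, depth=-1000.0)  # assumed object at z = -1000 mm

# The vector from the object point to the origin is parallel to d,
# so the two branches describe the same field direction.
to_origin = tuple(-v for v in p)
print(d, to_origin)
```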

sample_sensor

sample_sensor(spp=64, wvln=None, sub_pixel=False)

Sample rays from sensor pixels (backward rays). Used for ray-tracing based rendering.

Parameters:

Name Type Description Default
spp int

Samples per pixel. Defaults to 64.

64
wvln float

Ray wavelength in µm. When None (default), falls back to self.primary_wvln.

None
sub_pixel bool

whether to sample multiple points inside the pixel. Defaults to False.

False

Returns:

Name Type Description
ray Ray object

Ray object. Shape [H, W, spp, 3]

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def sample_sensor(self, spp=64, wvln=None, sub_pixel=False):
    """Sample rays from sensor pixels (backward rays). Used for ray-tracing based rendering.

    Args:
        spp (int, optional): samples per pixel. Defaults to 64.
        wvln (float, optional): ray wvln in µm. When ``None`` (default),
            falls back to ``self.primary_wvln``.
        sub_pixel (bool, optional): whether to sample multiple points inside the pixel. Defaults to False.

    Returns:
        ray (Ray object): Ray object. Shape [H, W, spp, 3]
    """
    wvln = self.primary_wvln if wvln is None else wvln
    w, h = self.sensor_size
    W, H = self.sensor_res
    device = self.device

    # Sample points on sensor plane
    # Use top-left point as reference in rendering, so here we should sample bottom-right point
    x1, y1 = torch.meshgrid(
        torch.linspace(-w / 2, w / 2, W + 1, device=device,)[1:],
        torch.linspace(h / 2, -h / 2, H + 1, device=device,)[1:],
        indexing="xy",
    )
    z1 = torch.full_like(x1, self.d_sensor.item())

    # Sample second points on the pupil
    # sensor_res is (W, H) but meshgrid with indexing="xy" gives (H, W) arrays
    pupilz, pupilr = self.get_exit_pupil()
    ray_o2 = self.sample_circle(r=pupilr, z=pupilz, shape=(H, W, spp))

    # Form rays
    ray_o = torch.stack((x1, y1, z1), 2)
    ray_o = ray_o.unsqueeze(2).repeat(1, 1, spp, 1)  # [H, W, spp, 3]

    # Sub-pixel sampling for more realistic rendering
    if sub_pixel:
        delta_ox = (
            torch.rand(ray_o.shape[:-1], device=device)
            * self.pixel_size
        )
        delta_oy = (
            -torch.rand(ray_o.shape[:-1], device=device)
            * self.pixel_size
        )
        delta_oz = torch.zeros_like(delta_ox)
        delta_o = torch.stack((delta_ox, delta_oy, delta_oz), -1)
        ray_o = ray_o + delta_o

    # Form rays
    ray_d = ray_o2 - ray_o  # shape [H, W, spp, 3]
    ray = Ray(ray_o, ray_d, wvln, device=device)
    return ray
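
The pixel coordinates above come from slicing off the first element of a linspace with one more point than pixels. The sketch below mimics that with plain lists; edge_coords is an illustrative name, the sensor size and resolution are made up, and pixel_size = w / W is an assumption about how the pixel pitch relates to them.

```python
import random


def edge_coords(start, stop, res):
    """Mimic linspace(start, stop, res + 1)[1:]: drop the first edge, keep the rest."""
    step = (stop - start) / res
    return [start + step * (i + 1) for i in range(res)]


w, h = 8.0, 4.0  # sensor size in mm (hypothetical)
W, H = 4, 2      # sensor resolution in pixels (hypothetical)

xs = edge_coords(-w / 2, w / 2, W)  # x increases along a row
ys = edge_coords(h / 2, -h / 2, H)  # y decreases down a column
print(xs)  # [-2.0, 0.0, 2.0, 4.0]
print(ys)  # [0.0, -2.0]

# Sub-pixel jitter with the same signs as the source: +x and -y offsets.
pixel_size = w / W  # assumed pixel pitch
rng = random.Random(0)
jitter_x = rng.random() * pixel_size    # in [0, pixel_size)
jitter_y = -rng.random() * pixel_size   # in (-pixel_size, 0]
```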

sample_circle

sample_circle(r, z, shape=[16, 16, 512])

Sample points inside a circle.

Parameters:

Name Type Description Default
r float

Radius of the circle.

required
z float

Z-coordinate for all sampled points.

required
shape list

Shape of the output tensor.

[16, 16, 512]

Returns:

Type Description

torch.Tensor: Sampled points, shape (*shape, 3).

Source code in deeplens-src/deeplens/geolens.py
def sample_circle(self, r, z, shape=[16, 16, 512]):
    """Sample points inside a circle.

    Args:
        r (float): Radius of the circle.
        z (float): Z-coordinate for all sampled points.
        shape (list): Shape of the output tensor.

    Returns:
        torch.Tensor: Sampled points, shape ``(\\*shape, 3)``.
    """
    device = self.device

    # Generate random angles and radii
    theta = torch.rand(*shape, device=device) * 2 * torch.pi
    r2 = torch.rand(*shape, device=device) * r**2
    radius = torch.sqrt(r2)

    # Stack to form 3D points
    x = radius * torch.cos(theta)
    y = radius * torch.sin(theta)
    z_tensor = torch.full_like(x, z)
    points = torch.stack((x, y, z_tensor), dim=-1)

    # Manually sample chief ray
    # points[..., 0, :2] = 0.0

    return points
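
Drawing the squared radius uniformly and then taking the square root is what makes the samples uniform in area rather than clustered at the center. A stdlib sketch of the same idea, with an illustrative sample_disk function and an arbitrary radius of 2:

```python
import math
import random


def sample_disk(radius, n, rng):
    """Return n (x, y) points distributed uniformly over a disk."""
    points = []
    for _ in range(n):
        theta = rng.random() * 2 * math.pi
        r = math.sqrt(rng.random() * radius**2)  # sqrt of uniform r^2 -> uniform area density
        points.append((r * math.cos(theta), r * math.sin(theta)))
    return points


rng = random.Random(0)
pts = sample_disk(2.0, 20000, rng)

# A disk of half the radius covers a quarter of the area,
# so about a quarter of the samples should land inside it.
inner = sum(1 for x, y in pts if math.hypot(x, y) <= 1.0) / len(pts)
print(inner)  # close to 0.25
```

Sampling the radius itself uniformly would instead put about half of the points inside the half-radius disk.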

trace

trace(ray, surf_range=None, record=False)

Trace rays through the lens.

Forward or backward tracing is automatically determined by the ray direction.

Parameters:

Name Type Description Default
ray Ray object

Ray object.

required
surf_range list

Surface index range.

None
record bool

Whether to record the ray path.

False

Returns:

Name Type Description
ray_final Ray object

ray after optical system.

ray_o_rec list

List of intersection points, or None when record is False.

Source code in deeplens-src/deeplens/geolens.py
def trace(self, ray, surf_range=None, record=False):
    """Trace rays through the lens.

    Forward or backward tracing is automatically determined by the ray direction.

    Args:
        ray (Ray object): Ray object.
        surf_range (list): Surface index range.
        record (bool): record ray path or not.

    Returns:
        ray_final (Ray object): ray after optical system.
        ray_o_rec (list): list of intersection points.
    """
    if surf_range is None:
        surf_range = range(0, len(self.surfaces))

    if (ray.d[..., 2] > 0).any():
        ray_out, ray_o_rec = self.forward_tracing(ray, surf_range, record=record)
    else:
        ray_out, ray_o_rec = self.backward_tracing(ray, surf_range, record=record)

    return ray_out, ray_o_rec

trace2obj

trace2obj(ray)

Traces rays backwards through all lens surfaces from sensor side to object side.

Parameters:

Name Type Description Default
ray Ray

Ray object to trace backwards.

required

Returns:

Name Type Description
Ray

Ray object after backward propagation through the lens.

Source code in deeplens-src/deeplens/geolens.py
def trace2obj(self, ray):
    """Traces rays backwards through all lens surfaces from sensor side
    to object side.

    Args:
        ray (Ray): Ray object to trace backwards.

    Returns:
        Ray: Ray object after backward propagation through the lens.
    """
    ray, _ = self.trace(ray)
    return ray

trace2sensor

trace2sensor(ray, record=False)

Forward trace rays through the lens to sensor plane.

Parameters:

Name Type Description Default
ray Ray object

Ray object.

required
record bool

record ray path or not.

False

Returns:

Name Type Description
ray_out Ray object

Ray propagated to the sensor plane.

ray_o_record list

List of intersection points. Returned only when record is True; otherwise the ray is returned alone.

Source code in deeplens-src/deeplens/geolens.py
def trace2sensor(self, ray, record=False):
    """Forward trace rays through the lens to sensor plane.

    Args:
        ray (Ray object): Ray object.
        record (bool): record ray path or not.

    Returns:
        ray_out (Ray object): ray after optical system.
        ray_o_record (list): list of intersection points.
    """
    # Manually propagate ray to a shallow depth to avoid numerical instability
    if ray.o[..., 2].min() < -100.0:
        ray = ray.prop_to(-10.0)

    # Trace rays
    ray, ray_o_record = self.trace(ray, record=record)
    ray = ray.prop_to(self.d_sensor)

    if record:
        ray_o = ray.o.clone().detach()
        # Set to NaN to be skipped in 2d layout visualization
        ray_o[ray.is_valid == 0] = float("nan")
        ray_o_record.append(ray_o)
        return ray, ray_o_record
    else:
        return ray

trace2exit_pupil

trace2exit_pupil(ray)

Forward trace rays through the lens to exit pupil plane.

Parameters:

Name Type Description Default
ray Ray

Ray object to trace.

required

Returns:

Name Type Description
Ray

Ray object propagated to the exit pupil plane.

Source code in deeplens-src/deeplens/geolens.py
def trace2exit_pupil(self, ray):
    """Forward trace rays through the lens to exit pupil plane.

    Args:
        ray (Ray): Ray object to trace.

    Returns:
        Ray: Ray object propagated to the exit pupil plane.
    """
    ray = self.trace2sensor(ray)
    pupil_z, _ = self.get_exit_pupil()
    ray = ray.prop_to(pupil_z)
    return ray

forward_tracing

forward_tracing(ray, surf_range, record)

Forward traces rays through each surface in the specified range from object side to image side.

Parameters:

Name Type Description Default
ray Ray

Ray object to trace.

required
surf_range range

Range of surface indices to trace through.

required
record bool

If True, record ray positions at each surface.

required

Returns:

Name Type Description
tuple

(ray_out, ray_o_record) where:
  • ray_out (Ray): Ray after propagation through all surfaces.
  • ray_o_record (list or None): List of ray positions at each surface, or None if record is False.

Source code in deeplens-src/deeplens/geolens.py
def forward_tracing(self, ray, surf_range, record):
    """Forward traces rays through each surface in the specified range from object side to image side.

    Args:
        ray (Ray): Ray object to trace.
        surf_range (range): Range of surface indices to trace through.
        record (bool): If True, record ray positions at each surface.

    Returns:
        tuple: (ray_out, ray_o_record) where:
            - ray_out (Ray): Ray after propagation through all surfaces.
            - ray_o_record (list or None): List of ray positions at each surface,
                or None if record is False.
    """
    if record:
        ray_o_record = []
        ray_o_record.append(ray.o.clone().detach())
    else:
        ray_o_record = None

    mat1 = Material("air")
    for i in surf_range:
        n1 = mat1.ior(ray.wvln)
        n2 = self.surfaces[i].mat2.ior(ray.wvln)
        ray = self.surfaces[i].ray_reaction(ray, n1, n2)
        mat1 = self.surfaces[i].mat2

        if record:
            ray_out_o = ray.o.clone().detach()
            ray_out_o[ray.is_valid == 0] = float("nan")
            ray_o_record.append(ray_out_o)

    return ray, ray_o_record

backward_tracing

backward_tracing(ray, surf_range, record)

Backward traces rays through each surface in reverse order from image side to object side.

Parameters:

Name Type Description Default
ray Ray

Ray object to trace.

required
surf_range range

Range of surface indices to trace through.

required
record bool

If True, record ray positions at each surface.

required

Returns:

Name Type Description
tuple

(ray_out, ray_o_record) where:
  • ray_out (Ray): Ray after backward propagation through all surfaces.
  • ray_o_record (list or None): List of ray positions, starting with the initial position followed by the position after each surface (invalid rays are recorded as NaN), or None if record is False.

Source code in deeplens-src/deeplens/geolens.py
def backward_tracing(self, ray, surf_range, record):
    """Backward traces rays through each surface in reverse order from image side to object side.

    Args:
        ray (Ray): Ray object to trace.
        surf_range (range): Range of surface indices to trace through.
        record (bool): If True, record ray positions at each surface.

    Returns:
        tuple: (ray_out, ray_o_record) where:
            - ray_out (Ray): Ray after backward propagation through all surfaces.
            - ray_o_record (list or None): List of ray positions at each surface,
                or None if record is False.
    """
    if record:
        ray_o_record = []
        ray_o_record.append(ray.o.clone().detach())
    else:
        ray_o_record = None

    surf_indices = list(surf_range)
    mat1 = self.surfaces[surf_indices[-1]].mat2 if surf_indices else Material("air")
    for i in reversed(surf_indices):
        n1 = mat1.ior(ray.wvln)
        mat2 = Material("air") if i == 0 else self.surfaces[i - 1].mat2
        n2 = mat2.ior(ray.wvln)
        ray = self.surfaces[i].ray_reaction(ray, n1, n2)
        mat1 = mat2

        if record:
            ray_out_o = ray.o.clone().detach()
            ray_out_o[ray.is_valid == 0] = float("nan")
            ray_o_record.append(ray_out_o)

    return ray, ray_o_record
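
The forward and backward loops differ mainly in how they pick the refractive indices on either side of each surface. The following is a minimal sketch of that bookkeeping only, with no ray geometry. Plain floats stand in for `Material.ior(...)`, and the helper names `forward_index_pairs` and `backward_index_pairs` as well as the sample index values are illustrative assumptions, not part of DeepLens.

```python
AIR = 1.0  # stand-in for Material("air").ior(wvln)


def forward_index_pairs(mat2_iors, surf_range):
    """Return (surface index, n1, n2) tuples in forward tracing order."""
    pairs = []
    n1 = AIR  # the forward loop starts in air
    for i in surf_range:
        n2 = mat2_iors[i]  # medium behind surface i
        pairs.append((i, n1, n2))
        n1 = n2
    return pairs


def backward_index_pairs(mat2_iors, surf_range):
    """Return (surface index, n1, n2) tuples in backward tracing order."""
    idx = list(surf_range)
    pairs = []
    n1 = mat2_iors[idx[-1]] if idx else AIR  # medium behind the last surface
    for i in reversed(idx):
        n2 = AIR if i == 0 else mat2_iors[i - 1]  # medium in front of surface i
        pairs.append((i, n1, n2))
        n1 = n2
    return pairs


# Hypothetical two-element lens: glass, air gap, glass, air.
mat2_iors = [1.5, 1.0, 1.6, 1.0]
fwd = forward_index_pairs(mat2_iors, range(4))
bwd = backward_index_pairs(mat2_iors, range(4))
```

For a full surface range, the backward pairs are the forward pairs in reverse order with `n1` and `n2` swapped.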

render

render(img_obj, depth=None, method='ray_tracing', **kwargs)

Differentiable image simulation.

Image simulation methods

[1] PSF map block convolution. [2] PSF patch convolution. [3] Ray tracing rendering.

Parameters:

Name Type Description Default
img_obj Tensor

Input image object in raw space. Shape of [N, C, H, W].

required
depth float

Depth of the object. When None (default), falls back to self.obj_depth.

None
method str

Image simulation method. One of 'psf_map', 'psf_patch', or 'ray_tracing'. Defaults to 'ray_tracing'.

'ray_tracing'
**kwargs

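Additional arguments for different methods:
  • psf_grid (tuple): Grid size for the PSF map method. Defaults to (10, 10).
  • psf_ks (int): Kernel size for the PSF methods. Defaults to PSF_KS.
  • psf_spp (int): Ray samples passed to the PSF map method for PSF computation. Defaults to SPP_PSF.
  • warp_grid (int): Resolution of the inverse distortion grid used to warp the input in the PSF map method. Defaults to 128.
  • patch_center (tuple): Center position for the PSF patch method. Defaults to (0.0, 0.0).
  • spp (int): Samples per pixel for ray tracing. Defaults to SPP_RENDER.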

{}

Returns:

Name Type Description
Tensor

Rendered image tensor. Shape of [N, C, H, W].

Source code in deeplens-src/deeplens/geolens.py
def render(self, img_obj, depth=None, method="ray_tracing", **kwargs):
    """Differentiable image simulation.

    Image simulation methods:
        [1] PSF map block convolution.
        [2] PSF patch convolution.
        [3] Ray tracing rendering.

    Args:
        img_obj (Tensor): Input image object in raw space. Shape of [N, C, H, W].
        depth (float, optional): Depth of the object. When ``None`` (default),
            falls back to ``self.obj_depth``.
        method (str, optional): Image simulation method. One of 'psf_map', 'psf_patch',
            or 'ray_tracing'. Defaults to 'ray_tracing'.
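        **kwargs: Additional arguments for different methods:
            - psf_grid (tuple): Grid size for the PSF map method. Defaults to (10, 10).
            - psf_ks (int): Kernel size for the PSF methods. Defaults to PSF_KS.
            - psf_spp (int): Ray samples passed to the PSF map method for PSF computation.
                Defaults to SPP_PSF.
            - warp_grid (int): Resolution of the inverse distortion grid used to warp the
                input in the PSF map method. Defaults to 128.
            - patch_center (tuple): Center position for the PSF patch method. Defaults to (0.0, 0.0).
            - spp (int): Samples per pixel for ray tracing. Defaults to SPP_RENDER.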

    Returns:
        Tensor: Rendered image tensor. Shape of [N, C, H, W].
    """
    depth = self.obj_depth if depth is None else depth
    B, C, Himg, Wimg = img_obj.shape
    Wsensor, Hsensor = self.sensor_res

    # Image simulation
    if method == "psf_map":
        # PSF rendering - uses PSF map to render image
        assert Wimg == Wsensor and Himg == Hsensor, (
            f"Sensor resolution {Wsensor}x{Hsensor} must match input image {Wimg}x{Himg}."
        )
        psf_grid = kwargs.get("psf_grid", (10, 10))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        psf_spp = kwargs.get("psf_spp", SPP_PSF)
        warp_grid = kwargs.get("warp_grid", 128)
        img_obj = self.warp(img_obj, depth=depth, num_grid=warp_grid)
        img_render = self.render_psf_map(
            img_obj,
            depth=depth,
            psf_grid=psf_grid,
            psf_ks=psf_ks,
            psf_spp=psf_spp,
        )

    elif method == "psf_patch":
        # PSF patch rendering - uses a single PSF to render a patch of the image
        patch_center = kwargs.get("patch_center", (0.0, 0.0))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        img_render = self.render_psf_patch(
            img_obj, depth=depth, patch_center=patch_center, psf_ks=psf_ks
        )

    elif method == "ray_tracing":
        # Ray tracing rendering
        assert Wimg == Wsensor and Himg == Hsensor, (
            f"Sensor resolution {Wsensor}x{Hsensor} must match input image {Wimg}x{Himg}."
        )
        spp = kwargs.get("spp", SPP_RENDER)
        img_render = self.render_raytracing(img_obj, depth=depth, spp=spp)

    else:
        raise ValueError(f"Image simulation method {method} is not supported.")

    return img_render

render_raytracing

render_raytracing(img, depth=None, spp=SPP_RENDER, vignetting=False)

Render RGB image using ray tracing rendering.

Parameters:

Name Type Description Default
img tensor

RGB image tensor. Shape of [N, 3, H, W].

required
depth float

Depth of the object. When None (default), falls back to self.obj_depth.

None
spp int

Samples per pixel. Defaults to SPP_RENDER.

SPP_RENDER
vignetting bool

Whether to consider the vignetting effect. Defaults to False.

False

Returns:

Name Type Description
img_render tensor

Rendered RGB image tensor. Shape of [N, 3, H, W].

Source code in deeplens-src/deeplens/geolens.py
def render_raytracing(self, img, depth=None, spp=SPP_RENDER, vignetting=False):
    """Render RGB image using ray tracing rendering.

    Args:
        img (tensor): RGB image tensor. Shape of [N, 3, H, W].
        depth (float, optional): Depth of the object. When ``None`` (default),
            falls back to ``self.obj_depth``.
        spp (int, optional): Samples per pixel. Defaults to ``SPP_RENDER``.
        vignetting (bool, optional): Whether to consider the vignetting effect. Defaults to False.

    Returns:
        img_render (tensor): Rendered RGB image tensor. Shape of [N, 3, H, W].
    """
    depth = self.obj_depth if depth is None else depth
    img_render = torch.zeros_like(img)
    for i in range(3):
        img_render[:, i, :, :] = self.render_raytracing_mono(
            img=img[:, i, :, :],
            wvln=self.wvln_rgb[i],
            depth=depth,
            spp=spp,
            vignetting=vignetting,
        )
    return img_render

render_raytracing_mono

render_raytracing_mono(img, wvln, depth=None, spp=64, vignetting=False)

Render monochrome image using ray tracing rendering.

Parameters:

Name Type Description Default
img tensor

Monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].

required
wvln float

Wavelength of the light.

required
depth float

Depth of the object. When None (default), falls back to self.obj_depth.

None
spp int

Samples per pixel. Defaults to 64.

64
vignetting bool

Whether to consider the vignetting effect. Defaults to False.

False

Returns:

Name Type Description
img_mono tensor

Rendered monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].

Source code in deeplens-src/deeplens/geolens.py
def render_raytracing_mono(self, img, wvln, depth=None, spp=64, vignetting=False):
    """Render monochrome image using ray tracing rendering.

    Args:
        img (tensor): Monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].
        wvln (float): Wavelength of the light.
        depth (float, optional): Depth of the object. When ``None`` (default),
            falls back to ``self.obj_depth``.
        spp (int, optional): Samples per pixel. Defaults to 64.
        vignetting (bool, optional): Whether to consider the vignetting effect. Defaults to False.

    Returns:
        img_mono (tensor): Rendered monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].
    """
    depth = self.obj_depth if depth is None else depth
    img = torch.flip(img, [-2, -1])
    scale = self.calc_scale(depth=depth)
    ray = self.sample_sensor(spp=spp, wvln=wvln)
    ray = self.trace2obj(ray)
    img_mono = self.render_compute_image(
        img, depth=depth, scale=scale, ray=ray, vignetting=vignetting
    )
    return img_mono

render_compute_image

render_compute_image(img, depth, scale, ray, vignetting=False)

Computes the intersection points between the rays and the object image plane, then generates the rendered image following the rendering equation.

Back-propagation gradient flow: image -> w_i -> u -> p -> ray -> surface

Parameters:

Name Type Description Default
img tensor

[N, C, H, W] or [N, H, W] shape image tensor.

required
depth float

depth of the object.

required
scale float

scale factor.

required
ray Ray object

Ray object. Shape [H, W, spp, 3].

required
vignetting bool

whether to consider vignetting effect.

False

Returns:

Name Type Description
image tensor

[N, C, H, W] or [N, H, W] shape rendered image tensor.

Source code in deeplens-src/deeplens/geolens.py
def render_compute_image(self, img, depth, scale, ray, vignetting=False):
    """Computes the intersection points between the rays and the object image plane, then generates the rendered image following the rendering equation.

    Back-propagation gradient flow: image -> w_i -> u -> p -> ray -> surface

    Args:
        img (tensor): [N, C, H, W] or [N, H, W] shape image tensor.
        depth (float): depth of the object.
        scale (float): scale factor.
        ray (Ray object): Ray object. Shape [H, W, spp, 3].
        vignetting (bool): whether to consider vignetting effect.

    Returns:
        image (tensor): [N, C, H, W] or [N, H, W] shape rendered image tensor.
    """
    assert torch.is_tensor(img), "Input image should be Tensor."

    H, W = img.shape[-2:]
    squeeze_channel = False
    if len(img.shape) == 3:
        img = img.unsqueeze(1)
        squeeze_channel = True
    elif len(img.shape) == 4:
        pass
    else:
        raise ValueError("Input image should be [N, C, H, W] or [N, H, W] tensor.")

    # Scale object image physical size to get 1:1 pixel-pixel alignment with sensor image
    ray = ray.prop_to(depth)
    p = ray.o[..., :2]
    pixel_size = scale * self.pixel_size
    ray.is_valid = (
        ray.is_valid
        * (torch.abs(p[..., 0] / pixel_size) < (W / 2 + 1))
        * (torch.abs(p[..., 1] / pixel_size) < (H / 2 + 1))
    )

    image = backward_integral(
        ray=ray,
        img_obj=img,
        ps=pixel_size,
        vignetting=vignetting,
    )
    if squeeze_channel:
        image = image.squeeze(1)

    return image
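
The validity update in this method keeps only rays that land inside the scaled object image, with a one-pixel margin on each side. A pure-Python sketch of that test for a single hit point follows. The helper name `hits_object_image` and all numbers are illustrative assumptions; the method itself applies the same comparison to whole tensors.

```python
def hits_object_image(x, y, scale, sensor_pixel_size, width, height):
    """Return True if the point (x, y) [mm] on the object plane lies inside the image."""
    pixel_size = scale * sensor_pixel_size  # object-plane pixel pitch [mm]
    return (abs(x / pixel_size) < width / 2 + 1) and (
        abs(y / pixel_size) < height / 2 + 1
    )


# Hypothetical setup: 100 x 80 pixel image, 4 um sensor pixels, scale factor 20.
inside = hits_object_image(3.0, -2.0, scale=20.0, sensor_pixel_size=0.004, width=100, height=80)
outside = hits_object_image(4.2, 0.0, scale=20.0, sensor_pixel_size=0.004, width=100, height=80)
```

Here the first point is 37.5 pixels from the center in x and passes, while the second is 52.5 pixels out and exceeds the limit of 51.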

warp

warp(img, depth=None, num_grid=128)

Apply lens distortion to an image using inverse distortion mapping.

Parameters:

Name Type Description Default
img tensor

Undistorted image tensor, shape [B, C, H, W].

required
depth float

Object depth. When None (default), falls back to self.obj_depth.

None
num_grid int or tuple

Resolution of the inverse distortion grid.

128

Returns:

Name Type Description
tensor

Distorted image tensor, shape [B, C, H, W].

Source code in deeplens-src/deeplens/geolens.py
def warp(self, img, depth=None, num_grid=128):
    """Apply lens distortion to an image using inverse distortion mapping.

    Args:
        img (tensor): Undistorted image tensor, shape ``[B, C, H, W]``.
        depth (float, optional): Object depth. When ``None`` (default),
            falls back to ``self.obj_depth``.
        num_grid (int or tuple): Resolution of the inverse distortion grid.

    Returns:
        tensor: Distorted image tensor, shape ``[B, C, H, W]``.
    """
    depth = self.obj_depth if depth is None else depth
    inv_distortion_map = self.calc_inv_distortion_map(
        depth=depth, num_grid=num_grid
    )
    inv_distortion_map = inv_distortion_map.permute(2, 0, 1).unsqueeze(0)
    inv_distortion_map = F.interpolate(
        inv_distortion_map, img.shape[-2:], mode="bilinear", align_corners=True
    )
    inv_distortion_map = inv_distortion_map.permute(0, 2, 3, 1).repeat(
        img.shape[0], 1, 1, 1
    )
    img_warped = F.grid_sample(img, inv_distortion_map, align_corners=True)
    return img_warped
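
`F.grid_sample` reads its sampling grid in normalized coordinates. With `align_corners=True`, as used in this method, -1 and +1 refer to the centers of the first and last pixels along an axis. A small sketch of that mapping is shown here; the helper name `grid_to_pixel` and the image width of 640 are illustrative assumptions.

```python
def grid_to_pixel(coord, size):
    """Map a normalized coordinate in [-1, 1] to a pixel index (align_corners=True)."""
    return (coord + 1.0) / 2.0 * (size - 1)


left = grid_to_pixel(-1.0, 640)   # center of the first pixel
center = grid_to_pixel(0.0, 640)  # midpoint between the two middle pixels
right = grid_to_pixel(1.0, 640)   # center of the last pixel
```

Each entry of the inverse distortion map is such a normalized (x, y) pair, telling `grid_sample` where in the input image an output pixel is read from.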

unwarp

unwarp(img, depth=None, num_grid=128, crop=True, flip=True)

Unwarp rendered images using distortion map.

Parameters:

Name Type Description Default
img tensor

Rendered image tensor. Shape of [N, C, H, W].

required
depth float

Depth of the object. When None (default), falls back to self.obj_depth.

None
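num_grid int

Resolution of the distortion grid. Defaults to 128.

128
crop bool

Whether to crop the image. Defaults to True. Not used by the current implementation.

True
flip bool

Defaults to True. Not used by the current implementation (the flip step is commented out).

True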

Returns:

Name Type Description
img_unwarpped tensor

Unwarped image tensor. Shape of [N, C, H, W].

Source code in deeplens-src/deeplens/geolens.py
def unwarp(self, img, depth=None, num_grid=128, crop=True, flip=True):
    """Unwarp rendered images using distortion map.

    Args:
        img (tensor): Rendered image tensor. Shape of [N, C, H, W].
        depth (float, optional): Depth of the object. When ``None`` (default),
            falls back to ``self.obj_depth``.
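        num_grid (int, optional): Resolution of the distortion grid. Defaults to 128.
        crop (bool, optional): Whether to crop the image. Defaults to True.
            Not used by the current implementation.
        flip (bool, optional): Defaults to True. Not used by the current
            implementation (the flip step is commented out).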

    Returns:
        img_unwarpped (tensor): Unwarped image tensor. Shape of [N, C, H, W].
    """
    depth = self.obj_depth if depth is None else depth
    # Calculate distortion map, shape (num_grid, num_grid, 2)
    distortion_map = self.calc_distortion_map(depth=depth, num_grid=num_grid)

    # Interpolate distortion map to image resolution
    distortion_map = distortion_map.permute(2, 0, 1).unsqueeze(1)
    # distortion_map = torch.flip(distortion_map, [-2]) if flip else distortion_map
    distortion_map = F.interpolate(
        distortion_map, img.shape[-2:], mode="bilinear", align_corners=True
    )  # shape (2, 1, Himg, Wimg)
    distortion_map = distortion_map.permute(1, 2, 3, 0).repeat(
        img.shape[0], 1, 1, 1
    )  # shape (B, Himg, Wimg, 2)

    # Unwarp using grid_sample function
    img_unwarpped = F.grid_sample(
        img, distortion_map, align_corners=True
    )  # shape (B, C, Himg, Wimg)
    return img_unwarpped

calc_foclen

calc_foclen(paraxial_fov=0.01)

Compute effective focal length (EFL).

Two-step approach:
  1. Trace on-axis parallel rays to find the paraxial focal point z. This is necessary because the sensor may not be at the focal plane (e.g. finite-conjugate designs or defocused systems).
  2. Trace off-axis rays at a small angle to the focal point, measure image height, and compute EFL = imgh / tan(angle).

Parameters:

Name Type Description Default
paraxial_fov float

Paraxial field of view in radians for the off-axis ray trace. Defaults to 0.01.

0.01
Updates

  • self.efl: Effective focal length.
  • self.foclen: Alias for effective focal length.
  • self.bfl: Back focal length (distance from last surface to sensor).

Reference

[1] https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/10/Tutorial_MorelSophie.pdf
[2] https://rafcamera.com/info/imaging-theory/back-focal-length

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def calc_foclen(self, paraxial_fov=0.01):
    """Compute effective focal length (EFL).

    Two-step approach:
    1. Trace on-axis parallel rays to find the paraxial focal point z.
       This is necessary because the sensor may not be at the focal plane
       (e.g. finite-conjugate designs or defocused systems).
    2. Trace off-axis rays at a small angle to the focal point, measure
       image height, and compute EFL = imgh / tan(angle).

    Args:
        paraxial_fov (float, optional): Paraxial field of view in radians
            for the off-axis ray trace. Defaults to 0.01.

    Returns:
        float: Effective focal length [mm].

    Updates:
        self.efl: Effective focal length.
        self.foclen: Alias for effective focal length.
        self.bfl: Back focal length (distance from last surface to sensor).

    Reference:
        [1] https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/10/Tutorial_MorelSophie.pdf
        [2] https://rafcamera.com/info/imaging-theory/back-focal-length
    """
    # Convert the paraxial field angle to degrees, the value passed to sample_from_fov below
    paraxial_fov_deg = float(np.rad2deg(paraxial_fov))

    # 1. Trace on-axis parallel rays to find paraxial focus z (equivalent to infinite focus)
    ray_axis = self.sample_from_fov(
        fov_x=0.0, fov_y=0.0, entrance_pupil=False, scale_pupil=0.2
    )
    ray_axis, _ = self.trace(ray_axis)
    valid_axis = ray_axis.is_valid > 0
    t = -(ray_axis.d[valid_axis, 0] * ray_axis.o[valid_axis, 0]
          + ray_axis.d[valid_axis, 1] * ray_axis.o[valid_axis, 1]) / (
        ray_axis.d[valid_axis, 0] ** 2 + ray_axis.d[valid_axis, 1] ** 2
    )
    focus_z = ray_axis.o[valid_axis, 2] + t * ray_axis.d[valid_axis, 2]
    focus_z = focus_z[~torch.isnan(focus_z) & (focus_z > 0)]
    paraxial_focus_z = float(torch.mean(focus_z))

    # 2. Trace off-axis paraxial ray to paraxial focus, measure image height
    ray = self.sample_from_fov(
        fov_x=0.0, fov_y=paraxial_fov_deg, entrance_pupil=False, scale_pupil=0.2
    )
    ray, _ = self.trace(ray)
    ray = ray.prop_to(paraxial_focus_z)

    # Compute the effective focal length
    paraxial_imgh = (ray.o[:, 1] * ray.is_valid).sum() / ray.is_valid.sum()
    eff_foclen = paraxial_imgh.item() / float(np.tan(paraxial_fov))
    self.efl = eff_foclen
    self.foclen = eff_foclen

    # Compute the back focal length
    self.bfl = self.d_sensor.item() - self.surfaces[-1].d.item()

    return eff_foclen
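
Both steps of this method reduce to short formulas: the z at which a traced ray comes closest to the optical axis, and EFL = imgh / tan(angle). A minimal sketch for a single ray follows. The helper names `axis_crossing_z` and `efl_from_image_height` and the sample values are illustrative assumptions; the method itself averages over many valid rays.

```python
import math


def axis_crossing_z(o, d):
    """z at which the ray o + t * d comes closest to the optical axis (x = y = 0)."""
    ox, oy, oz = o
    dx, dy, dz = d
    # Minimizing (ox + t*dx)^2 + (oy + t*dy)^2 over t gives:
    t = -(dx * ox + dy * oy) / (dx**2 + dy**2)
    return oz + t * dz


def efl_from_image_height(img_height, fov_rad):
    """Effective focal length from paraxial image height and field angle."""
    return img_height / math.tan(fov_rad)


# A ray leaving the last surface 1 mm above the axis, heading down toward it.
focus_z = axis_crossing_z(o=(0.0, 1.0, 0.0), d=(0.0, -0.02, 1.0))

# An image height of 0.5 mm at a 0.01 rad field angle.
efl = efl_from_image_height(img_height=0.5, fov_rad=0.01)
```

With these numbers the ray crosses the axis at z = 50 mm and the focal length comes out just under 50 mm. The formula does not require `d` to be normalized.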

calc_numerical_aperture

calc_numerical_aperture(n=1.0)

Compute numerical aperture (NA).

Parameters:

Name Type Description Default
n float

Refractive index. Defaults to 1.0.

1.0

Returns:

Name Type Description
NA float

Numerical aperture.

Reference

[1] https://en.wikipedia.org/wiki/Numerical_aperture

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def calc_numerical_aperture(self, n=1.0):
    """Compute numerical aperture (NA).

    Args:
        n (float, optional): Refractive index. Defaults to 1.0.

    Returns:
        NA (float): Numerical aperture.

    Reference:
        [1] https://en.wikipedia.org/wiki/Numerical_aperture
    """
    return n * math.sin(math.atan(1 / 2 / self.fnum))
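
The return expression evaluates NA = n · sin(arctan(1 / (2N))), where N is the F-number. As a worked example, the sketch below compares it with the small-angle approximation n / (2N) for a hypothetical f/2 lens in air. The standalone function `numerical_aperture` is an illustration that takes the F-number as an argument instead of reading `self.fnum`.

```python
import math


def numerical_aperture(fnum, n=1.0):
    """NA from the F-number, using the same expression as the method."""
    return n * math.sin(math.atan(1 / 2 / fnum))


na_exact = numerical_aperture(2.0)  # about 0.2425
na_paraxial = 1 / (2 * 2.0)         # small-angle approximation, 0.25
```

The two values agree to within a few percent at f/2 and converge as the F-number grows.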

calc_focal_plane

calc_focal_plane(wvln=None)

Compute the focus distance in the object space. Rays start at the sensor center and are traced to the object space.

Parameters:

Name Type Description Default
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None

Returns:

Name Type Description
focal_plane float

Focal plane in the object space.

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def calc_focal_plane(self, wvln=None):
    """Compute the focus distance in the object space. Rays start at the sensor center and are traced to the object space.

    Args:
        wvln (float, optional): Wavelength in µm. When ``None`` (default),
            falls back to ``self.primary_wvln``.

    Returns:
        focal_plane (float): Focal plane in the object space.
    """
    wvln = self.primary_wvln if wvln is None else wvln
    device = self.device

    # Sample point source rays from sensor center
    o1 = torch.zeros(SPP_CALC, 3, device=device)
    o1[:, 2] = self.d_sensor

    # Sample the first surface as pupil
    # o2 = self.sample_circle(self.surfaces[0].r, z=0.0, shape=[SPP_CALC])
    # o2 *= 0.5  # Shrink sample region to improve accuracy
    pupilz, pupilr = self.get_exit_pupil()
    o2 = self.sample_circle(pupilr, pupilz, shape=[SPP_CALC])
    d = o2 - o1
    ray = Ray(o1, d, wvln, device=device)

    # Trace rays to object space
    ray = self.trace2obj(ray)

    # Optical axis intersection
    t = (ray.d[..., 0] * ray.o[..., 0] + ray.d[..., 1] * ray.o[..., 1]) / (
        ray.d[..., 0] ** 2 + ray.d[..., 1] ** 2
    )
    focus_z = (ray.o[..., 2] - ray.d[..., 2] * t)[ray.is_valid > 0].cpu().numpy()
    focus_z = focus_z[~np.isnan(focus_z) & (focus_z < 0)]

    if len(focus_z) > 0:
        focal_plane = float(np.mean(focus_z))
    else:
        raise ValueError(
            "No valid rays found, focal plane in the object space cannot be computed."
        )

    return focal_plane

calc_sensor_plane

calc_sensor_plane(depth=float('inf'))

Calculate in-focus sensor plane.

Parameters:

Name Type Description Default
depth float

Depth of the object plane. Defaults to float("inf").

float('inf')

Returns:

Name Type Description
d_sensor Tensor

Sensor plane in the image space.

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def calc_sensor_plane(self, depth=float("inf")):
    """Calculate in-focus sensor plane.

    Args:
        depth (float, optional): Depth of the object plane. Defaults to float("inf").

    Returns:
        d_sensor (torch.Tensor): Sensor plane in the image space.
    """
    # Sample and trace rays, shape [SPP_CALC, 3]
    ray = self.sample_from_fov(
        fov_x=0.0, fov_y=0.0, depth=depth, num_rays=SPP_CALC
    )
    ray = self.trace2sensor(ray)

    # Calculate in-focus sensor position
    t = (ray.d[:, 0] * ray.o[:, 0] + ray.d[:, 1] * ray.o[:, 1]) / (
        ray.d[:, 0] ** 2 + ray.d[:, 1] ** 2
    )
    focus_z = ray.o[:, 2] - ray.d[:, 2] * t
    focus_z = focus_z[ray.is_valid > 0]
    focus_z = focus_z[~torch.isnan(focus_z) & (focus_z > 0)]
    d_sensor = torch.mean(focus_z)
    return d_sensor

calc_fov

calc_fov()

Compute field of view (FoV) of the lens in radians.

Calculates FoV using two methods
  1. Perspective projection — from focal length and sensor size (effective FoV, ignoring distortion).
  2. Forward ray tracing — sweeps FOV angles from object side, traces to sensor, and finds the angle whose centroid image height matches the sensor half-diagonal. This avoids the failure of the old backward-tracing approach on wide-angle lenses where pupil aberration at full field leaves zero valid rays.
Updates

  • self.vfov (float): Vertical FoV in radians.
  • self.hfov (float): Horizontal FoV in radians.
  • self.dfov (float): Diagonal FoV in radians.
  • self.rfov_eff (float): Effective half-diagonal FoV in radians (paraxial, ignoring distortion).
  • self.rfov (float): Real half-diagonal FoV from ray tracing (accounts for distortion).
  • self.real_dfov (float): Real diagonal FoV from ray tracing.
  • self.eqfl (float): 35mm equivalent focal length in mm.

Reference

[1] https://en.wikipedia.org/wiki/Angle_of_view_(photography)

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def calc_fov(self):
    """Compute field of view (FoV) of the lens in radians.

    Calculates FoV using two methods:
        1. **Perspective projection** — from focal length and sensor size
           (effective FoV, ignoring distortion).
        2. **Forward ray tracing** — sweeps FOV angles from object side,
           traces to sensor, and finds the angle whose centroid image height
           matches the sensor half-diagonal. This avoids the failure of the
           old backward-tracing approach on wide-angle lenses where pupil
           aberration at full field leaves zero valid rays.

    Updates:
        self.vfov (float): Vertical FoV in radians.
        self.hfov (float): Horizontal FoV in radians.
        self.dfov (float): Diagonal FoV in radians.
        self.rfov_eff (float): Effective half-diagonal FoV in radians (paraxial, ignoring distortion).
        self.rfov (float): Real half-diagonal FoV from ray tracing (accounts for distortion).
        self.real_dfov (float): Real diagonal FoV from ray tracing.
        self.eqfl (float): 35mm equivalent focal length in mm.

    Reference:
        [1] https://en.wikipedia.org/wiki/Angle_of_view_(photography)
    """
    if not hasattr(self, "foclen"):
        return

    # 1. Perspective projection (effective FoV)
    # sensor_size is (W, H): the vertical FoV uses the height, the horizontal FoV the width
    self.vfov = 2 * math.atan(self.sensor_size[1] / 2 / self.foclen)
    self.hfov = 2 * math.atan(self.sensor_size[0] / 2 / self.foclen)
    self.dfov = 2 * math.atan(self.r_sensor / self.foclen)
    self.rfov_eff = self.dfov / 2  # effective (paraxial) half-diagonal FoV

    # 2. Forward ray tracing to calculate real FoV (distortion-affected)
    # Sweep FOV angles from object side, trace to sensor, and find which
    # angle produces an image height matching r_sensor.
    num_fov = 64
    fov_lo = float(np.rad2deg(self.rfov_eff)) * 0.5
    fov_hi = min(float(np.rad2deg(self.rfov_eff)) * 1.8, 89.0)
    fov_samples = torch.linspace(fov_lo, fov_hi, num_fov, device=self.device)

    ray = self.sample_from_fov(
        fov_x=0.0, fov_y=fov_samples.tolist(), num_rays=256
    )
    ray = self.trace2sensor(ray)

    # Centroid image height per FOV angle, shape [num_fov]
    valid = ray.is_valid > 0  # [num_fov, num_rays]
    masked_y = ray.o[..., 1] * valid
    n_valid = valid.sum(dim=-1).clamp(min=1)
    imgh = (masked_y.sum(dim=-1) / n_valid).abs()

    # Find the FOV angle whose image height is closest to r_sensor
    has_valid = valid.sum(dim=-1) > 10
    if has_valid.any():
        imgh[~has_valid] = float("inf")
        diff = (imgh - self.r_sensor).abs()
        best_idx = diff.argmin().item()
        rfov = fov_samples[best_idx].item() * math.pi / 180.0
        self.rfov = rfov
        self.real_dfov = 2 * rfov
    else:
        self.rfov = self.rfov_eff
        self.real_dfov = self.dfov

    # 3. Compute 35mm equivalent focal length (21.63 mm is the half-diagonal of a 36 mm x 24 mm sensor)
    self.eqfl = 21.63 / math.tan(self.rfov_eff)
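
As a worked example of the diagonal relations and of the sweep selection, here is a pure-Python sketch. It assumes that `r_sensor` is the sensor half-diagonal. The helper names `diagonal_fov`, `equivalent_focal_length`, and `closest_fov`, and the sample sweep values, are illustrative assumptions; the method itself obtains the image heights by ray tracing.

```python
import math


def diagonal_fov(sensor_w, sensor_h, foclen):
    """Diagonal FoV [rad] from perspective projection."""
    r_sensor = math.hypot(sensor_w, sensor_h) / 2  # half-diagonal [mm]
    return 2 * math.atan(r_sensor / foclen)


def equivalent_focal_length(rfov_eff):
    """35mm equivalent focal length [mm] from the effective half-diagonal FoV."""
    return 21.63 / math.tan(rfov_eff)


def closest_fov(fov_samples_deg, image_heights, r_sensor):
    """Pick the swept angle whose image height is closest to r_sensor, in radians."""
    best = min(
        range(len(fov_samples_deg)),
        key=lambda i: abs(image_heights[i] - r_sensor),
    )
    return math.radians(fov_samples_deg[best])


# Hypothetical 50 mm lens on a 36 mm x 24 mm sensor.
dfov = diagonal_fov(36.0, 24.0, 50.0)
eqfl = equivalent_focal_length(dfov / 2)

# Made-up sweep: three field angles and their centroid image heights [mm].
rfov = closest_fov([20.0, 22.0, 24.0], [18.0, 20.1, 22.3], r_sensor=21.63)
```

For this full-frame case the diagonal FoV is about 46.8 degrees and the equivalent focal length is close to the actual 50 mm. The sweep picks 24 degrees because its image height is nearest to the half-diagonal.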

calc_scale

calc_scale(depth)

Calculate the scale factor (object height / image height).

Uses the pinhole camera model to compute magnification.

Parameters:

Name Type Description Default
depth float

Object distance from the lens (negative z direction).

required

Returns:

Name Type Description
float

Scale factor relating object height to image height.

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def calc_scale(self, depth):
    """Calculate the scale factor (object height / image height).

    Uses the pinhole camera model to compute magnification.

    Args:
        depth (float): Object distance from the lens (negative z direction).

    Returns:
        float: Scale factor relating object height to image height.
    """
    return -depth / self.foclen
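
As a worked example of the pinhole relation, assume a 50 mm lens, an object plane at z = -1000 mm, and a 4 µm sensor pixel pitch (all hypothetical values). The standalone function below mirrors the method but takes the focal length as an argument.

```python
def calc_scale(depth, foclen):
    """Object height / image height under the pinhole model."""
    return -depth / foclen


scale = calc_scale(depth=-1000.0, foclen=50.0)

# Size of one sensor pixel when projected onto the object plane [mm].
object_pixel = scale * 0.004
```

The scale factor is 20, so a 4 µm sensor pixel covers 0.08 mm on the object plane.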

calc_pupil

calc_pupil()

Compute entrance and exit pupil positions and radii.

The entrance and exit pupils must be recalculated whenever
  • First-order parameters change (e.g., field of view, object height, image height),
  • Lens geometry or materials change (e.g., surface curvatures, refractive indices, thicknesses),
  • Or generally, any time the lens configuration is modified.
Updates

  • self.aper_idx: Index of the aperture surface.
  • self.exit_pupilz, self.exit_pupilr: Exit pupil position and radius.
  • self.entr_pupilz, self.entr_pupilr: Entrance pupil position and radius.
  • self.exit_pupilz_parax, self.exit_pupilr_parax: Paraxial exit pupil.
  • self.entr_pupilz_parax, self.entr_pupilr_parax: Paraxial entrance pupil.
  • self.fnum: F-number calculated from focal length and entrance pupil.

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def calc_pupil(self):
    """Compute entrance and exit pupil positions and radii.

    The entrance and exit pupils must be recalculated whenever:
        - First-order parameters change (e.g., field of view, object height, image height),
        - Lens geometry or materials change (e.g., surface curvatures, refractive indices, thicknesses),
        - Or generally, any time the lens configuration is modified.

    Updates:
        self.aper_idx: Index of the aperture surface.
        self.exit_pupilz, self.exit_pupilr: Exit pupil position and radius.
        self.entr_pupilz, self.entr_pupilr: Entrance pupil position and radius.
        self.exit_pupilz_parax, self.exit_pupilr_parax: Paraxial exit pupil.
        self.entr_pupilz_parax, self.entr_pupilr_parax: Paraxial entrance pupil.
        self.fnum: F-number calculated from focal length and entrance pupil.
    """
    # Find aperture
    self.aper_idx = None
    for i in range(len(self.surfaces)):
        if getattr(self.surfaces[i], "is_aperture", False):
            self.aper_idx = i
            break

    if self.aper_idx is None:
        for i in range(len(self.surfaces)):
            if isinstance(self.surfaces[i], Aperture):
                self.aper_idx = i
                break

    if self.aper_idx is None:
        self.aper_idx = np.argmin([s.r for s in self.surfaces])
        print("No aperture found, use the smallest surface as aperture.")

    # Compute entrance and exit pupil
    self.exit_pupilz, self.exit_pupilr = self.calc_exit_pupil(paraxial=False)
    self.entr_pupilz, self.entr_pupilr = self.calc_entrance_pupil(paraxial=False)
    self.exit_pupilz_parax, self.exit_pupilr_parax = self.calc_exit_pupil(
        paraxial=True
    )
    self.entr_pupilz_parax, self.entr_pupilr_parax = self.calc_entrance_pupil(
        paraxial=True
    )

    # Compute F-number
    self.fnum = self.foclen / (2 * self.entr_pupilr)
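
The aperture search falls back in three stages: an explicit `is_aperture` flag, then the first `Aperture` instance, then the surface with the smallest radius. A pure-Python sketch of that order and of the F-number formula follows. The `Surf` dataclass, its `is_aperture_type` field (standing in for `isinstance(s, Aperture)`), and the helper names are illustrative assumptions.

```python
from dataclasses import dataclass


@dataclass
class Surf:
    """Hypothetical stand-in for a surface: radius plus two aperture markers."""
    r: float
    is_aperture: bool = False
    is_aperture_type: bool = False  # stands in for isinstance(s, Aperture)


def find_aperture_index(surfaces):
    """Return the aperture index using the same three-stage fallback."""
    for i, s in enumerate(surfaces):
        if s.is_aperture:
            return i
    for i, s in enumerate(surfaces):
        if s.is_aperture_type:
            return i
    # No aperture found: fall back to the smallest surface.
    return min(range(len(surfaces)), key=lambda i: surfaces[i].r)


def f_number(foclen, entr_pupilr):
    """F-number from focal length and entrance pupil radius."""
    return foclen / (2 * entr_pupilr)


surfaces = [Surf(10.0), Surf(8.0), Surf(4.0, is_aperture_type=True), Surf(9.0)]
aper_idx = find_aperture_index(surfaces)
no_stop_idx = find_aperture_index([Surf(10.0), Surf(3.0), Surf(9.0)])
fnum = f_number(foclen=50.0, entr_pupilr=12.5)
```

Here the marked surface at index 2 is chosen, the unmarked stack falls back to index 1, and a 50 mm lens with a 12.5 mm entrance pupil radius is f/2.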

get_entrance_pupil

get_entrance_pupil(paraxial=False)

Get entrance pupil location and radius.

Parameters:

Name Type Description Default
paraxial bool

If True, return paraxial approximation values. If False, return real ray-traced values. Defaults to False.

False

Returns:

Name Type Description
tuple

(z_position, radius) of the entrance pupil in [mm].

Source code in deeplens-src/deeplens/geolens.py
def get_entrance_pupil(self, paraxial=False):
    """Get entrance pupil location and radius.

    Args:
        paraxial (bool, optional): If True, return paraxial approximation values.
            If False, return real ray-traced values. Defaults to False.

    Returns:
        tuple: (z_position, radius) of the entrance pupil in [mm].
    """
    if paraxial:
        return self.entr_pupilz_parax, self.entr_pupilr_parax
    else:
        return self.entr_pupilz, self.entr_pupilr

get_exit_pupil

get_exit_pupil(paraxial=False)

Get exit pupil location and radius.

Parameters:

Name Type Description Default
paraxial bool

If True, return paraxial approximation values. If False, return real ray-traced values. Defaults to False.

False

Returns:

Name Type Description
tuple

(z_position, radius) of the exit pupil in [mm].

Source code in deeplens-src/deeplens/geolens.py
def get_exit_pupil(self, paraxial=False):
    """Get exit pupil location and radius.

    Args:
        paraxial (bool, optional): If True, return paraxial approximation values.
            If False, return real ray-traced values. Defaults to False.

    Returns:
        tuple: (z_position, radius) of the exit pupil in [mm].
    """
    if paraxial:
        return self.exit_pupilz_parax, self.exit_pupilr_parax
    else:
        return self.exit_pupilz, self.exit_pupilr

calc_exit_pupil

calc_exit_pupil(paraxial=False)

Calculate exit pupil location and radius.

Paraxial mode

Rays are emitted from near the center of the aperture stop and are close to the optical axis. This mode estimates the exit pupil position and radius under ideal (first-order) optical assumptions. It is fast and stable.

Non-paraxial mode

Rays are emitted from the edge of the aperture stop in large quantities. The exit pupil position and radius are determined based on the intersection points of these rays. This mode is slower and affected by aperture-related aberrations.

Use paraxial mode unless precise ray aiming is required.

Parameters:

Name Type Description Default
paraxial bool

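Ray sampling mode. If True, rays are emitted near the center of the aperture stop (paraxial mode). If False, rays are emitted from the edge of the aperture stop (non-paraxial mode).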

False

Returns:

Name Type Description
avg_pupilz float

z coordinate of exit pupil.

avg_pupilr float

radius of exit pupil.

Reference

[1] Exit pupil: how many rays can come from sensor to object space.
[2] https://en.wikipedia.org/wiki/Exit_pupil

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def calc_exit_pupil(self, paraxial=False):
    """Calculate exit pupil location and radius.

    Paraxial mode:
        Rays are emitted from near the center of the aperture stop and are close to the optical axis.
        This mode estimates the exit pupil position and radius under ideal (first-order) optical assumptions.
        It is fast and stable.

    Non-paraxial mode:
        Rays are emitted from the edge of the aperture stop in large quantities.
        The exit pupil position and radius are determined based on the intersection points of these rays.
        This mode is slower and affected by aperture-related aberrations.

    Use paraxial mode unless precise ray aiming is required.

    Args:
        paraxial (bool): Ray sampling mode. If True, rays are emitted near the center
            of the aperture stop; if False, from the edge of the aperture stop.

    Returns:
        avg_pupilz (float): z coordinate of exit pupil.
        avg_pupilr (float): radius of exit pupil.

    Reference:
        [1] Exit pupil: how many rays can come from sensor to object space.
        [2] https://en.wikipedia.org/wiki/Exit_pupil
    """
    # Check for the attribute first so a missing aper_idx does not raise AttributeError
    if not hasattr(self, "aper_idx") or self.aper_idx is None:
        print("No aperture, use the last surface as exit pupil.")
        return self.surfaces[-1].d.item(), self.surfaces[-1].r

    # Sample rays from aperture (edge or center)
    aper_idx = self.aper_idx
    aper_z = self.surfaces[aper_idx].d.item()
    aper_r = self.surfaces[aper_idx].r

    if paraxial:
        ray_o = torch.tensor([[DELTA_PARAXIAL, 0, aper_z]], device=self.device).repeat(32, 1)
        phi_rad = torch.linspace(-0.01, 0.01, 32, device=self.device)
    else:
        ray_o = torch.tensor([[aper_r, 0, aper_z]], device=self.device).repeat(SPP_CALC, 1)
        rfov = float(np.arctan(self.r_sensor / self.foclen))
        phi_rad = torch.linspace(-rfov / 2, rfov / 2, SPP_CALC, device=self.device)

    d = torch.stack(
        (torch.sin(phi_rad), torch.zeros_like(phi_rad), torch.cos(phi_rad)), axis=-1
    )
    ray = Ray(ray_o, d, wvln=self.primary_wvln, device=self.device)

    # Ray tracing from aperture edge to last surface
    surf_range = range(self.aper_idx + 1, len(self.surfaces))
    ray, _ = self.trace(ray, surf_range=surf_range)

    # Compute intersection points, solving the equation: o1+d1*t1 = o2+d2*t2
    ray_o = torch.stack(
        [ray.o[ray.is_valid != 0][:, 0], ray.o[ray.is_valid != 0][:, 2]], dim=-1
    )
    ray_d = torch.stack(
        [ray.d[ray.is_valid != 0][:, 0], ray.d[ray.is_valid != 0][:, 2]], dim=-1
    )
    intersection_points = self.compute_intersection_points_2d(ray_o, ray_d)

    # Handle the case where no intersection points are found or small pupil
    if len(intersection_points) == 0:
        print("No intersection points found, use the last surface as exit pupil.")
        avg_pupilr = self.surfaces[-1].r
        avg_pupilz = self.surfaces[-1].d.item()
    else:
        avg_pupilr = torch.mean(intersection_points[:, 0]).item()
        avg_pupilz = torch.mean(intersection_points[:, 1]).item()

        if paraxial:
            avg_pupilr = abs(avg_pupilr / DELTA_PARAXIAL * aper_r)

        if avg_pupilr < EPSILON:
            print(
                "Zero or negative exit pupil is detected, use the last surface as pupil."
            )
            avg_pupilr = self.surfaces[-1].r
            avg_pupilz = self.surfaces[-1].d.item()

    return avg_pupilz, avg_pupilr
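The paraxial branch above can be checked by hand on a toy system. The sketch below is not the DeepLens tracer: it assumes a single ideal thin lens behind the stop, and `DELTA`, `trace_thin_lens`, `intersect` and `paraxial_exit_pupil` are hypothetical names introduced only for illustration. Two rays leave a near-axis point on the stop, are refracted by the lens, and their intersection gives the image of that point; the image height divided by the starting height is the pupil magnification, which is then applied to the stop radius.

```python
import math

DELTA = 1e-3  # hypothetical small ray height on the stop [mm], stands in for DELTA_PARAXIAL


def trace_thin_lens(x0, slope, lens_z, focal_len):
    """Propagate a ray from the stop plane (z = 0) to a thin lens and refract it.

    Returns the ray height at the lens and the ray slope after the lens.
    """
    x_lens = x0 + slope * lens_z
    return x_lens, slope - x_lens / focal_len


def intersect(x1, u1, x2, u2, lens_z):
    """Intersect two refracted rays of the form x(z) = x_i + u_i * (z - lens_z)."""
    dz = (x2 - x1) / (u1 - u2)
    return x1 + u1 * dz, lens_z + dz


def paraxial_exit_pupil(aper_r, lens_z, focal_len):
    # Two rays leave the same near-axis point with slightly different angles.
    xa, ua = trace_thin_lens(DELTA, math.tan(-0.01), lens_z, focal_len)
    xb, ub = trace_thin_lens(DELTA, math.tan(0.01), lens_z, focal_len)
    img_x, img_z = intersect(xa, ua, xb, ub, lens_z)
    # Image height over object height is the pupil magnification.
    pupil_r = abs(img_x / DELTA * aper_r)
    return img_z, pupil_r


# Stop at z = 0, thin lens 10 mm behind it, focal length 50 mm, stop radius 2 mm.
pupil_z, pupil_r = paraxial_exit_pupil(aper_r=2.0, lens_z=10.0, focal_len=50.0)
```

For these numbers the thin-lens equation predicts a virtual image of the stop 12.5 mm in front of the lens with magnification 1.25, which is what the two-ray intersection reproduces.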

calc_entrance_pupil

calc_entrance_pupil(paraxial=False)

Calculate entrance pupil of the lens.

The entrance pupil is the optical image of the physical aperture stop, as seen through the optical elements in front of the stop. We sample backward rays from the aperture stop and trace them to the first surface, then find the intersection points of the reverse extension of the rays. The average of the intersection points defines the entrance pupil position and radius.

Parameters:

Name Type Description Default
paraxial bool

Ray sampling mode. If True, rays are emitted near the centre of the aperture stop (fast, paraxially stable). If False, rays are emitted from the stop edge in larger quantities (slower, accounts for aperture aberrations). Defaults to False.

False

Returns:

Name Type Description
tuple

(z_position, radius) of entrance pupil.

Note

[1] Use paraxial mode unless precise ray aiming is required.
[2] This function only works for objects at a far distance. For microscopes, it usually returns a negative entrance pupil.

References

[1] Entrance pupil: how many rays can come from object space to sensor.
[2] https://en.wikipedia.org/wiki/Entrance_pupil: "In an optical system, the entrance pupil is the optical image of the physical aperture stop, as 'seen' through the optical elements in front of the stop."
[3] Zemax LLC, OpticStudio User Manual, Version 19.4, Document No. 2311, 2019.

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def calc_entrance_pupil(self, paraxial=False):
    """Calculate entrance pupil of the lens.

    The entrance pupil is the optical image of the physical aperture stop, as seen through the optical elements in front of the stop. We sample backward rays from the aperture stop and trace them to the first surface, then find the intersection points of the reverse extension of the rays. The average of the intersection points defines the entrance pupil position and radius.

    Args:
        paraxial (bool): Ray sampling mode.  If ``True``, rays are emitted
            near the centre of the aperture stop (fast, paraxially stable).
            If ``False``, rays are emitted from the stop edge in larger
            quantities (slower, accounts for aperture aberrations).
            Defaults to ``False``.

    Returns:
        tuple: (z_position, radius) of entrance pupil.

    Note:
        [1] Use paraxial mode unless precise ray aiming is required.
        [2] This function only works for objects at a far distance. For microscopes, it usually returns a negative entrance pupil.

    References:
        [1] Entrance pupil: how many rays can come from object space to sensor.
        [2] https://en.wikipedia.org/wiki/Entrance_pupil: "In an optical system, the entrance pupil is the optical image of the physical aperture stop, as 'seen' through the optical elements in front of the stop."
        [3] Zemax LLC, *OpticStudio User Manual*, Version 19.4, Document No. 2311, 2019.
    """
    # Check for the attribute first so a missing aper_idx does not raise AttributeError
    if not hasattr(self, "aper_idx") or self.aper_idx is None:
        print("No aperture stop, use the first surface as entrance pupil.")
        return self.surfaces[0].d.item(), self.surfaces[0].r

    # Sample rays from edge of aperture stop
    aper_idx = self.aper_idx
    aper_surf = self.surfaces[aper_idx]
    aper_z = aper_surf.d.item()
    if aper_surf.is_square:
        aper_r = float(np.sqrt(2)) * aper_surf.r
    else:
        aper_r = aper_surf.r

    if paraxial:
        ray_o = torch.tensor([[DELTA_PARAXIAL, 0, aper_z]], device=self.device).repeat(32, 1)
        phi = torch.linspace(-0.01, 0.01, 32, device=self.device)
    else:
        ray_o = torch.tensor([[aper_r, 0, aper_z]], device=self.device).repeat(SPP_CALC, 1)
        rfov = float(np.arctan(self.r_sensor / self.foclen))
        phi = torch.linspace(-rfov / 2, rfov / 2, SPP_CALC, device=self.device)

    d = torch.stack(
        (torch.sin(phi), torch.zeros_like(phi), -torch.cos(phi)), axis=-1
    )
    ray = Ray(ray_o, d, wvln=self.primary_wvln, device=self.device)

    # Ray tracing from aperture edge to first surface
    surf_range = range(0, self.aper_idx)
    ray, _ = self.trace(ray, surf_range=surf_range)

    # Compute intersection points, solving the equation: o1+d1*t1 = o2+d2*t2
    ray_o = torch.stack(
        [ray.o[ray.is_valid > 0][:, 0], ray.o[ray.is_valid > 0][:, 2]], dim=-1
    )
    ray_d = torch.stack(
        [ray.d[ray.is_valid > 0][:, 0], ray.d[ray.is_valid > 0][:, 2]], dim=-1
    )
    intersection_points = self.compute_intersection_points_2d(ray_o, ray_d)

    # Handle the case where no intersection points are found or small entrance pupil
    if len(intersection_points) == 0:
        print(
            "No intersection points found, use the first surface as entrance pupil."
        )
        avg_pupilr = self.surfaces[0].r
        avg_pupilz = self.surfaces[0].d.item()
    else:
        avg_pupilr = torch.mean(intersection_points[:, 0]).item()
        avg_pupilz = torch.mean(intersection_points[:, 1]).item()

        if paraxial:
            avg_pupilr = abs(avg_pupilr / DELTA_PARAXIAL * aper_r)

        if avg_pupilr < EPSILON:
            print(
                "Zero or negative entrance pupil is detected, use the first surface as entrance pupil."
            )
            avg_pupilr = self.surfaces[0].r
            avg_pupilz = self.surfaces[0].d.item()

    return avg_pupilz, avg_pupilr

compute_intersection_points_2d staticmethod

compute_intersection_points_2d(origins, directions)

Compute the intersection points of 2D lines.

Parameters:

Name Type Description Default
origins Tensor

Origins of the lines. Shape: [N, 2]

required
directions Tensor

Directions of the lines. Shape: [N, 2]

required

Returns:

Type Description

torch.Tensor: Intersection points. Shape: [N*(N-1)/2, 2]

Source code in deeplens-src/deeplens/geolens.py
@staticmethod
def compute_intersection_points_2d(origins, directions):
    """Compute the intersection points of 2D lines.

    Args:
        origins (torch.Tensor): Origins of the lines. Shape: [N, 2]
        directions (torch.Tensor): Directions of the lines. Shape: [N, 2]

    Returns:
        torch.Tensor: Intersection points. Shape: [N*(N-1)/2, 2]
    """
    N = origins.shape[0]

    # Create pairwise combinations of indices
    idx = torch.arange(N)
    idx_i, idx_j = torch.combinations(idx, r=2).unbind(1)

    Oi = origins[idx_i]  # Shape: [N*(N-1)/2, 2]
    Oj = origins[idx_j]  # Shape: [N*(N-1)/2, 2]
    Di = directions[idx_i]  # Shape: [N*(N-1)/2, 2]
    Dj = directions[idx_j]  # Shape: [N*(N-1)/2, 2]

    # Vector from Oi to Oj
    b = Oj - Oi  # Shape: [N*(N-1)/2, 2]

    # Coefficients matrix A
    A = torch.stack([Di, -Dj], dim=-1)  # Shape: [N*(N-1)/2, 2, 2]

    # Solve the linear system Ax = b
    # Using least squares to handle the case of no exact solution
    if A.device.type == "mps":
        # Perform lstsq on CPU for MPS devices and move result back
        x, _ = torch.linalg.lstsq(A.cpu(), b.unsqueeze(-1).cpu())[:2]
        x = x.to(A.device)
    else:
        x, _ = torch.linalg.lstsq(A, b.unsqueeze(-1))[:2]
    x = x.squeeze(-1)  # Shape: [N*(N-1)/2, 2]
    s = x[:, 0]
    t = x[:, 1]

    # Calculate the intersection point from each ray of the pair
    P_i = Oi + s.unsqueeze(-1) * Di  # Shape: [N*(N-1)/2, 2]
    P_j = Oj + t.unsqueeze(-1) * Dj  # Shape: [N*(N-1)/2, 2]

    # Take the average to mitigate numerical precision issues
    P = (P_i + P_j) / 2

    return P
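The pairwise system solved above, `o_i + s * d_i = o_j + t * d_j`, can be written out without torch. The sketch below is a plain-Python illustration, not the library code: the function name `intersect_lines_2d` is hypothetical, it uses Cramer's rule instead of `torch.linalg.lstsq`, and it skips parallel pairs rather than returning a least-squares answer for them.

```python
from itertools import combinations


def intersect_lines_2d(origins, directions):
    """Return the intersection point of every pair of 2D lines.

    Solves o_i + s * d_i = o_j + t * d_j for s with Cramer's rule.
    Parallel pairs have no unique intersection and are skipped.
    """
    points = []
    for i, j in combinations(range(len(origins)), 2):
        (oix, oiy), (dix, diy) = origins[i], directions[i]
        (ojx, ojy), (djx, djy) = origins[j], directions[j]
        bx, by = ojx - oix, ojy - oiy
        # Determinant of the 2x2 matrix A = [d_i, -d_j]
        det = dix * (-djy) - (-djx) * diy
        if abs(det) < 1e-12:
            continue
        s = (bx * (-djy) - (-djx) * by) / det
        points.append((oix + s * dix, oiy + s * diy))
    return points


# Three lines that all pass through the point (1, 2).
origins = [(0.0, 2.0), (1.0, 0.0), (0.0, 1.0)]
directions = [(1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
pts = intersect_lines_2d(origins, directions)
```

Three lines give 3 * 2 / 2 = 3 pairs, matching the `[N*(N-1)/2, 2]` output shape documented above.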

refocus

refocus(foc_dist=float('inf'))

Refocus the lens to a depth distance by changing sensor position.

Parameters:

Name Type Description Default
foc_dist float

Focus distance, i.e. the object depth to bring into focus.

float('inf')
Note

In DSLR cameras, phase detection autofocus (PDAF) is a popular and efficient method. Here we simplify the problem by calculating the in-focus sensor position for green light.

Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def refocus(self, foc_dist=float("inf")):
    """Refocus the lens to a depth distance by changing sensor position.

    Args:
        foc_dist (float): focal distance.

    Note:
        In DSLR cameras, phase detection autofocus (PDAF) is a popular and efficient method. Here we simplify the problem by calculating the in-focus sensor position for green light.
    """
    # Calculate in-focus sensor position
    d_sensor_new = self.calc_sensor_plane(depth=foc_dist)

    # Update sensor position
    assert d_sensor_new > 0, "Obtained negative sensor position."
    self.d_sensor = d_sensor_new

    # FoV will be slightly changed
    self.post_computation()
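As a rough intuition for how the sensor position moves with focus distance, the thin-lens equation 1/f = 1/s + 1/s' can be evaluated directly. This is only an idealised sketch: `calc_sensor_plane` is not shown in this section and presumably works on the real multi-surface lens, and the function name `thin_lens_sensor_distance` and the use of `abs()` for the sign convention are assumptions made here.

```python
import math


def thin_lens_sensor_distance(foclen, foc_dist=float("inf")):
    """Image distance [mm] behind an ideal thin lens for an object at foc_dist [mm]."""
    if math.isinf(foc_dist):
        # An object at infinity focuses in the focal plane.
        return foclen
    # abs() makes the sketch independent of the sign convention used for depth.
    return 1.0 / (1.0 / foclen - 1.0 / abs(foc_dist))


d_inf = thin_lens_sensor_distance(50.0)           # object at infinity
d_near = thin_lens_sensor_distance(50.0, 1000.0)  # object 1 m away
```

Focusing closer pushes the in-focus plane away from the lens, which is why the field of view changes slightly after refocusing.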

set_fnum

set_fnum(fnum)

Set F-number and aperture radius using binary search.

Parameters:

Name Type Description Default
fnum float

target F-number.

required
Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def set_fnum(self, fnum):
    """Set F-number and aperture radius using binary search.

    Args:
        fnum (float): target F-number.
    """
    target_pupil_r = self.foclen / fnum / 2
    aper_r = self.surfaces[self.aper_idx].r
    lo, hi = 0.1 * aper_r, 5.0 * aper_r

    pupilr = None
    for _ in range(40):
        mid = 0.5 * (lo + hi)
        self.surfaces[self.aper_idx].r = mid
        _, pupilr = self.calc_entrance_pupil()
        if abs(pupilr - target_pupil_r) / target_pupil_r < 1e-3:
            break
        if pupilr > target_pupil_r:
            hi = mid
        else:
            lo = mid
    else:
        logging.warning(
            f"set_fnum: did not converge, pupil_r={pupilr:.4f}, target={target_pupil_r:.4f}"
        )

    self.calc_pupil()
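The search loop above is ordinary bisection on the stop radius. The sketch below isolates it from the lens: `bisect_aperture_radius` is a hypothetical helper, and the entrance pupil is replaced by an assumed stand-in function (a constant 1.25x magnification of the stop) so that the block runs on its own.

```python
def bisect_aperture_radius(pupil_radius_of, target_pupil_r, aper_r0,
                           rel_tol=1e-3, max_iter=40):
    """Find the stop radius whose entrance pupil radius matches target_pupil_r.

    pupil_radius_of must increase monotonically with the stop radius.
    Returns (radius, converged).
    """
    lo, hi = 0.1 * aper_r0, 5.0 * aper_r0
    mid = 0.5 * (lo + hi)
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        pupil_r = pupil_radius_of(mid)
        if abs(pupil_r - target_pupil_r) / target_pupil_r < rel_tol:
            return mid, True
        if pupil_r > target_pupil_r:
            hi = mid
        else:
            lo = mid
    return mid, False


# Hypothetical lens: the front group magnifies the stop by a constant 1.25x.
foclen, fnum = 50.0, 2.8
target = foclen / fnum / 2
radius, converged = bisect_aperture_radius(lambda r: 1.25 * r, target, aper_r0=5.0)
```

The bracket is fixed at 0.1x to 5x of the current stop radius, so a target pupil outside that range cannot be reached and the search ends without converging.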

set_target_fov_fnum

set_target_fov_fnum(rfov, fnum)

Set FoV, ImgH and F-number. Use this function only to assign design targets.

Parameters:

Name Type Description Default
rfov float

half diagonal-FoV in radian.

required
fnum float

F number.

required
Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def set_target_fov_fnum(self, rfov, fnum):
    """Set FoV, ImgH and F-number. Use this function only to assign design targets.

    Args:
        rfov (float): half diagonal-FoV in radian.
        fnum (float): F number.
    """
    if rfov > math.pi:
        # A value above pi cannot be a half-FoV in radians, so treat it as degrees
        self.rfov_eff = rfov / 180.0 * math.pi
    else:
        self.rfov_eff = rfov

    self.rfov = self.rfov_eff
    self.real_dfov = 2 * self.rfov
    self.foclen = self.r_sensor / math.tan(self.rfov_eff)
    self.eqfl = 21.63 / math.tan(self.rfov_eff)
    self.fnum = fnum
    aper_r = self.foclen / fnum / 2
    self.surfaces[self.aper_idx].update_r(float(aper_r))

    # Update pupil after setting aperture radius
    self.calc_pupil()
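The arithmetic in this method can be reproduced in isolation. The sketch below is a standalone illustration with a hypothetical function name (`design_targets`); it mirrors the formulas in the source above and assumes `r_sensor` is the sensor half-diagonal in mm. The constant 21.63 mm is the half-diagonal of a 36 x 24 mm full-frame sensor, which is why it yields a 35 mm equivalent focal length.

```python
import math

FULL_FRAME_HALF_DIAG = 21.63  # mm, half-diagonal of a 36 x 24 mm sensor


def design_targets(rfov, fnum, r_sensor):
    """Derive focal length, 35 mm equivalent focal length and nominal aperture radius [mm]."""
    if rfov > math.pi:
        # Values above pi are assumed to be degrees, as in the source above.
        rfov = rfov / 180.0 * math.pi
    foclen = r_sensor / math.tan(rfov)
    eqfl = FULL_FRAME_HALF_DIAG / math.tan(rfov)
    aper_r = foclen / fnum / 2
    return foclen, eqfl, aper_r


# 30 degree half-FoV, f/2, sensor half-diagonal of 4 mm.
foclen, eqfl, aper_r = design_targets(rfov=math.radians(30.0), fnum=2.0, r_sensor=4.0)
```

Note that `foclen / fnum / 2` is the entrance pupil radius implied by the F-number; the method assigns it to the stop as a starting value, whereas `set_fnum` searches for the stop radius whose traced entrance pupil matches it.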

set_fov

set_fov(rfov)

Set half-diagonal field of view as a design target.

Unlike calc_fov() which derives FoV from focal length and sensor size, this method directly assigns the target FoV for lens optimisation.

Parameters:

Name Type Description Default
rfov float

Half-diagonal FoV in radians.

required
Source code in deeplens-src/deeplens/geolens.py
@torch.no_grad()
def set_fov(self, rfov):
    """Set half-diagonal field of view as a design target.

    Unlike ``calc_fov()`` which derives FoV from focal length and sensor
    size, this method directly assigns the target FoV for lens optimisation.

    Args:
        rfov (float): Half-diagonal FoV in radians.
    """
    self.rfov_eff = rfov
    self.rfov = rfov
    self.real_dfov = 2 * self.rfov
    self.eqfl = 21.63 / math.tan(self.rfov_eff)

Components

GeoLens uses a mixin architecture — its functionality is split across the focused classes below. You normally interact only with GeoLens itself; these are documented for reference.

deeplens.geolens_pkg.psf_compute.GeoLensPSF

Mixin providing PSF computation for GeoLens.

All three PSF models are exposed through a single psf dispatcher. The geometric and coherent models are differentiable; Huygens is not.

This class is not instantiated directly; it is mixed into GeoLens.

psf

psf(points, ks=PSF_KS, wvln=None, spp=None, recenter=True, model='geometric')

Calculate Point Spread Function (PSF) for given point sources.

Supports multiple PSF calculation models:
  • geometric: Incoherent intensity ray tracing (fast, differentiable)
  • coherent: Coherent ray tracing with free-space propagation (accurate, differentiable)
  • huygens: Huygens-Fresnel integration (accurate, not differentiable)

Parameters:

Name Type Description Default
points Tensor

Point source positions. Shape [N, 3] with x, y in [-1, 1] and z in [-Inf, 0]. Normalized coordinates.

required
ks int

Output kernel size in pixels. Defaults to PSF_KS.

PSF_KS
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None
spp int

Samples per pixel. If None, uses model-specific default.

None
recenter bool

If True, center PSF using chief ray. Defaults to True.

True
model str

PSF model type. One of 'geometric', 'coherent', 'huygens'. Defaults to 'geometric'.

'geometric'

Returns:

Name Type Description
Tensor

PSF normalized to sum to 1. Shape [ks, ks] or [N, ks, ks].

Source code in deeplens-src/deeplens/geolens_pkg/psf_compute.py
def psf(
    self,
    points,
    ks=PSF_KS,
    wvln=None,
    spp=None,
    recenter=True,
    model="geometric",
):
    """Calculate Point Spread Function (PSF) for given point sources.

    Supports multiple PSF calculation models:
        - geometric: Incoherent intensity ray tracing (fast, differentiable)
        - coherent: Coherent ray tracing with free-space propagation (accurate, differentiable)
        - huygens: Huygens-Fresnel integration (accurate, not differentiable)

    Args:
        points (Tensor): Point source positions. Shape [N, 3] with x, y in [-1, 1]
            and z in [-Inf, 0]. Normalized coordinates.
        ks (int, optional): Output kernel size in pixels. Defaults to PSF_KS.
        wvln (float, optional): Wavelength in µm. When ``None`` (default),
            falls back to ``self.primary_wvln``.
        spp (int, optional): Samples per pixel. If None, uses model-specific default.
        recenter (bool, optional): If True, center PSF using chief ray. Defaults to True.
        model (str, optional): PSF model type. One of 'geometric', 'coherent', 'huygens'.
            Defaults to 'geometric'.

    Returns:
        Tensor: PSF normalized to sum to 1. Shape [ks, ks] or [N, ks, ks].
    """
    wvln = self.primary_wvln if wvln is None else wvln
    if model == "geometric":
        spp = SPP_PSF if spp is None else spp
        return self.psf_geometric(points, ks, wvln, spp, recenter)
    elif model == "coherent":
        spp = SPP_COHERENT if spp is None else spp
        return self.psf_coherent(points, ks, wvln, spp, recenter)
    elif model == "huygens":
        spp = SPP_COHERENT if spp is None else spp
        return self.psf_huygens(points, ks, wvln, spp, recenter)
    else:
        raise ValueError(f"Unknown PSF model: {model}")

psf_geometric

psf_geometric(points, ks=PSF_KS, wvln=None, spp=SPP_PSF, recenter=True)

Single wavelength geometric PSF calculation.

Parameters:

Name Type Description Default
points Tensor

Normalized point source position. Shape of [N, 3], x, y in range [-1, 1], z in range [-Inf, 0].

required
ks int

Output kernel size.

PSF_KS
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None
spp int

Sample per pixel.

SPP_PSF
recenter bool

Recenter PSF using chief ray.

True

Returns:

Name Type Description
psf

Shape of [ks, ks] or [N, ks, ks].

References

[1] https://optics.ansys.com/hc/en-us/articles/42661723066515-What-is-a-Point-Spread-Function

Source code in deeplens-src/deeplens/geolens_pkg/psf_compute.py
def psf_geometric(
    self, points, ks=PSF_KS, wvln=None, spp=SPP_PSF, recenter=True
):
    """Single wavelength geometric PSF calculation.

    Args:
        points (Tensor): Normalized point source position. Shape of [N, 3], x, y in range [-1, 1], z in range [-Inf, 0].
        ks (int, optional): Output kernel size.
        wvln (float, optional): Wavelength in µm. When ``None`` (default),
            falls back to ``self.primary_wvln``.
        spp (int, optional): Sample per pixel.
        recenter (bool, optional): Recenter PSF using chief ray.

    Returns:
        psf: Shape of [ks, ks] or [N, ks, ks].

    References:
        [1] https://optics.ansys.com/hc/en-us/articles/42661723066515-What-is-a-Point-Spread-Function
    """
    wvln = self.primary_wvln if wvln is None else wvln
    sensor_w, sensor_h = self.sensor_size
    pixel_size = self.pixel_size
    device = self.device

    # Points shape of [N, 3]
    if not torch.is_tensor(points):
        points = torch.tensor(points, device=device)

    if len(points.shape) == 1:
        single_point = True
        points = points.unsqueeze(0)
    else:
        single_point = False

    # Sample rays. Ray position in the object space by perspective projection
    depth = points[:, 2]
    scale = self.calc_scale(depth)
    point_obj_x = points[..., 0] * scale * sensor_w / 2
    point_obj_y = points[..., 1] * scale * sensor_h / 2
    point_obj = torch.stack([point_obj_x, point_obj_y, points[..., 2]], dim=-1)
    ray = self.sample_from_points(points=point_obj, num_rays=spp, wvln=wvln)

    # Trace rays to sensor plane (incoherent)
    ray.is_coherent = False
    ray = self.trace2sensor(ray)

    # Calculate PSF center, shape [N, 2]
    if recenter:
        pointc = self.psf_center(point_obj, method="chief_ray")
    else:
        pointc = self.psf_center(point_obj, method="pinhole")

    # Monte Carlo integration
    psf = forward_integral(ray.flip_xy(), ps=pixel_size, ks=ks, pointc=pointc)

    # Intensity normalization
    psf = psf / (torch.sum(psf, dim=(-2, -1), keepdim=True) + EPSILON)

    if single_point:
        psf = psf.squeeze(0)

    return diff_float(psf)
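The Monte Carlo integration and normalization steps can be pictured as a histogram of ray hit points on the sensor. The sketch below is a simplified stand-in, not the `forward_integral` implementation (which is not shown in this section and may weight rays differently to stay differentiable): `bin_rays_to_psf` is a hypothetical name, rays are assigned to the nearest pixel, and the row orientation is an assumption.

```python
import random


def bin_rays_to_psf(hits, pixel_size, ks, center):
    """Accumulate ray hit points [mm] into a ks x ks grid centred on `center`, then normalise."""
    psf = [[0.0] * ks for _ in range(ks)]
    half = ks * pixel_size / 2
    for x, y in hits:
        col = int((x - center[0] + half) // pixel_size)
        row = int((center[1] + half - y) // pixel_size)  # row 0 is the top of the patch
        if 0 <= row < ks and 0 <= col < ks:
            psf[row][col] += 1.0
    # Small epsilon avoids division by zero when no ray lands in the patch.
    total = sum(map(sum, psf)) + 1e-12
    return [[v / total for v in r] for r in psf]


# Synthetic hit points: a Gaussian blur spot with a 2 micron standard deviation.
random.seed(0)
hits = [(random.gauss(0.0, 0.002), random.gauss(0.0, 0.002)) for _ in range(10_000)]
psf = bin_rays_to_psf(hits, pixel_size=0.001, ks=11, center=(0.0, 0.0))
```

Choosing `center` from the chief ray rather than the ideal pinhole position corresponds to `recenter=True`: the patch follows the actual spot, so distortion does not push the PSF out of the kernel.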

psf_coherent

psf_coherent(points, ks=PSF_KS, wvln=None, spp=SPP_COHERENT, recenter=True)

Alias for psf_pupil_prop. Calculates PSF by coherent ray tracing to exit pupil followed by Angular Spectrum Method (ASM) propagation.

Source code in deeplens-src/deeplens/geolens_pkg/psf_compute.py
def psf_coherent(
    self, points, ks=PSF_KS, wvln=None, spp=SPP_COHERENT, recenter=True
):
    """Alias for psf_pupil_prop. Calculates PSF by coherent ray tracing to exit pupil followed by Angular Spectrum Method (ASM) propagation."""
    wvln = self.primary_wvln if wvln is None else wvln
    return self.psf_pupil_prop(points, ks=ks, wvln=wvln, spp=spp, recenter=recenter)

psf_pupil_prop

psf_pupil_prop(points, ks=PSF_KS, wvln=None, spp=SPP_COHERENT, recenter=True)

Single point monochromatic PSF using exit-pupil diffraction model. This function is differentiable.

Steps

1. Calculate complex wavefield at exit-pupil plane by coherent ray tracing.
2. Free-space propagation to sensor plane and calculate intensity PSF.

Parameters:

Name Type Description Default
points Tensor

[x, y, z] coordinates of the point source, e.g. torch.Tensor([0, 0, -10000]). This argument is required; it has no default value.

required
ks int

size of the PSF patch. Defaults to PSF_KS.

PSF_KS
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None
spp int

number of rays to sample. Defaults to SPP_COHERENT.

SPP_COHERENT
recenter bool

Recenter PSF using chief ray. Defaults to True.

True

Returns:

Name Type Description
psf_out Tensor

PSF patch. Normalized to sum to 1. Shape [ks, ks]

Reference

[1] "End-to-End Hybrid Refractive-Diffractive Lens Design with Differentiable Ray-Wave Model", SIGGRAPH Asia 2024.

Note

[1] This function is similar to ZEMAX FFT_PSF but implements free-space propagation with the Angular Spectrum Method (ASM) rather than an FFT. Free-space propagation using ASM is more accurate, because the FFT approach (as used in ZEMAX) assumes a far-field condition (e.g., chief ray perpendicular to the image plane).

Source code in deeplens-src/deeplens/geolens_pkg/psf_compute.py
def psf_pupil_prop(
    self, points, ks=PSF_KS, wvln=None, spp=SPP_COHERENT, recenter=True
):
    """Single point monochromatic PSF using exit-pupil diffraction model. This function is differentiable.

    Steps:
        1, Calculate complex wavefield at exit-pupil plane by coherent ray tracing.
        2, Free-space propagation to sensor plane and calculate intensity PSF.

    Args:
        points (torch.Tensor): [x, y, z] coordinates of the point source, e.g. torch.Tensor([0, 0, -10000]). Required.
        ks (int, optional): size of the PSF patch. Defaults to PSF_KS.
        wvln (float, optional): Wavelength in µm. When ``None`` (default),
            falls back to ``self.primary_wvln``.
        spp (int, optional): number of rays to sample. Defaults to SPP_COHERENT.
        recenter (bool, optional): Recenter PSF using chief ray. Defaults to True.

    Returns:
        psf_out (torch.Tensor): PSF patch. Normalized to sum to 1. Shape [ks, ks]

    Reference:
        [1] "End-to-End Hybrid Refractive-Diffractive Lens Design with Differentiable Ray-Wave Model", SIGGRAPH Asia 2024.

    Note:
        [1] This function is similar to ZEMAX FFT_PSF but implements free-space propagation with the Angular Spectrum Method (ASM) rather than an FFT. Free-space propagation using ASM is more accurate, because the FFT approach (as used in ZEMAX) assumes a far-field condition (e.g., chief ray perpendicular to the image plane).
    """
    wvln = self.primary_wvln if wvln is None else wvln
    # Pupil field by coherent ray tracing
    wavefront, psfc = self.pupil_field(
        points=points, wvln=wvln, spp=spp, recenter=recenter
    )

    # Propagate to sensor plane and get intensity
    pupilz, pupilr = self.get_exit_pupil()
    h, w = wavefront.shape
    # Manually pad wave field
    wavefront = F.pad(
        wavefront.unsqueeze(0).unsqueeze(0),
        [h // 2, h // 2, w // 2, w // 2],
        mode="constant",
        value=0,
    )
    # Free-space propagation using Angular Spectrum Method (ASM)
    sensor_field = AngularSpectrumMethod(
        wavefront,
        z=self.d_sensor - pupilz,
        wvln=wvln,
        ps=self.pixel_size,
        padding=False,
    )
    # Get intensity
    psf_inten = sensor_field.abs() ** 2

    # Calculate PSF center
    h, w = psf_inten.shape[-2:]
    # Consider both interpolation and padding
    psfc_idx_i = ((2 - psfc[1]) * h / 4).round().long()
    psfc_idx_j = ((2 + psfc[0]) * w / 4).round().long()

    # Crop valid PSF region and normalize
    if ks is not None:
        psf_inten_pad = (
            F.pad(
                psf_inten,
                [ks // 2, ks // 2, ks // 2, ks // 2],
                mode="constant",
                value=0,
            )
            .squeeze(0)
            .squeeze(0)
        )
        psf = psf_inten_pad[
            psfc_idx_i : psfc_idx_i + ks, psfc_idx_j : psfc_idx_j + ks
        ]
    else:
        psf = psf_inten

    # Intensity normalization, shape of [ks, ks] or [h, w]
    psf = psf / (torch.sum(psf, dim=(-2, -1), keepdim=True) + EPSILON)

    return diff_float(psf)
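The index arithmetic used for the PSF center follows from the 2x zero padding. If the sensor-sized field has H rows, the padded field has h = 2H rows and the sensor occupies rows H/2 to 3H/2, so a normalized coordinate y in [-1, 1] (y pointing up, rows counting down) maps to row H/2 + (1 - y) * H/2 = (2 - y) * h / 4. The sketch below reproduces this in plain Python; `psf_center_index` and `crop_bounds` are hypothetical names, and Python's `round` stands in for `.round().long()`.

```python
def psf_center_index(psfc, h, w):
    """Map a normalised PSF centre (x, y in [-1, 1]) to (row, col) in the 2x zero-padded field.

    h and w are the padded sizes, i.e. twice the sensor-sized wavefront.
    """
    row = round((2 - psfc[1]) * h / 4)
    col = round((2 + psfc[0]) * w / 4)
    return row, col


def crop_bounds(index, ks):
    """Range of the unpadded axis covered by a ks-wide crop taken after ks // 2 padding."""
    start = index - ks // 2
    return start, start + ks


center = psf_center_index((0.0, 0.0), h=200, w=200)  # on-axis point
corner = psf_center_index((1.0, 1.0), h=200, w=200)  # top-right sensor corner
bounds = crop_bounds(center[0], ks=51)
```

Padding the intensity by `ks // 2` before slicing `[idx : idx + ks]` is what keeps the crop centred on the index, and it also keeps the slice in bounds when the PSF sits near the edge of the field.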

pupil_field

pupil_field(points, wvln=None, spp=SPP_COHERENT, recenter=True)

Compute complex wavefront at exit pupil plane by coherent ray tracing.

The wavefront is flipped for subsequent PSF calculation and has the same size as the image sensor. This function is differentiable.

Parameters:

Name Type Description Default
points Tensor or list

Single point source position. Shape [3] or [1, 3], with x, y in [-1, 1] and z in [-Inf, 0].

required
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None
spp int

Number of rays to sample. Must be >= 1,000,000 for accurate coherent simulation. Defaults to SPP_COHERENT.

SPP_COHERENT
recenter bool

If True, center using chief ray. Defaults to True.

True

Returns:

Name Type Description
tuple

(wavefront, psf_center) where:
  • wavefront (Tensor): Complex wavefront at exit pupil. Shape [H, H].
  • psf_center (list): Normalized PSF center coordinates [x, y] in [-1, 1].

Note

Default dtype must be torch.float64 for accurate phase calculation.

Source code in deeplens-src/deeplens/geolens_pkg/psf_compute.py
def pupil_field(self, points, wvln=None, spp=SPP_COHERENT, recenter=True):
    """Compute complex wavefront at exit pupil plane by coherent ray tracing.

    The wavefront is flipped for subsequent PSF calculation and has the same
    size as the image sensor. This function is differentiable.

    Args:
        points (Tensor or list): Single point source position. Shape [3] or [1, 3],
            with x, y in [-1, 1] and z in [-Inf, 0].
        wvln (float, optional): Wavelength in µm. When ``None`` (default),
            falls back to ``self.primary_wvln``.
        spp (int, optional): Number of rays to sample. Must be >= 1,000,000 for
            accurate coherent simulation. Defaults to SPP_COHERENT.
        recenter (bool, optional): If True, center using chief ray. Defaults to True.

    Returns:
        tuple: (wavefront, psf_center) where:
            - wavefront (Tensor): Complex wavefront at exit pupil. Shape [H, H].
            - psf_center (list): Normalized PSF center coordinates [x, y] in [-1, 1].

    Note:
        Default dtype must be torch.float64 for accurate phase calculation.
    """
    wvln = self.primary_wvln if wvln is None else wvln
    assert spp >= 1_000_000, (
        f"Ray sampling {spp} is too small for coherent ray tracing, which may lead to inaccurate simulation."
    )
    assert torch.get_default_dtype() == torch.float64, (
        "Default dtype must be set to float64 for accurate phase calculation."
    )

    sensor_w, sensor_h = self.sensor_size
    device = self.device

    if isinstance(points, list):
        points = torch.tensor(points, device=device).unsqueeze(0)  # [1, 3]
    elif torch.is_tensor(points) and len(points.shape) == 1:
        points = points.unsqueeze(0).to(device)  # [1, 3]
    elif torch.is_tensor(points) and len(points.shape) == 2:
        assert points.shape[0] == 1, (
            f"pupil_field only supports single point input, got shape {points.shape}"
        )
    else:
        # type(points) works for any input, whereas points.type() only exists on tensors
        raise ValueError(f"Unsupported point type or shape: {type(points).__name__}.")

    assert points.shape[0] == 1, (
        "Only one point is supported for pupil field calculation."
    )

    # Ray origin in the object space
    scale = self.calc_scale(points[:, 2].item())
    point_obj_x = points[:, 0] * scale * sensor_w / 2
    point_obj_y = points[:, 1] * scale * sensor_h / 2
    points_obj = torch.stack([point_obj_x, point_obj_y, points[:, 2]], dim=-1)

    # Ray center determined by chief ray
    # Shape of [N, 2], un-normalized physical coordinates
    if recenter:
        pointc = self.psf_center(points_obj, method="chief_ray")
    else:
        pointc = self.psf_center(points_obj, method="pinhole")

    # Ray-tracing to exit_pupil
    ray = self.sample_from_points(points=points_obj, num_rays=spp, wvln=wvln)
    ray.is_coherent = True
    ray = self.trace2exit_pupil(ray)

    # Calculate complex field (same physical size and resolution as the sensor)
    # Complex field is flipped here for further PSF calculation
    pointc_ref = torch.zeros_like(points[:, :2])  # [N, 2]
    wavefront = forward_integral(
        ray.flip_xy(),
        ps=self.pixel_size,
        ks=self.sensor_res[1],
        pointc=pointc_ref,
    )
    wavefront = wavefront.squeeze(0)  # [H, H]

    # PSF center (on the sensor plane)
    pointc = pointc[0, :]
    psf_center = [
        pointc[0] / sensor_w * 2,
        pointc[1] / sensor_h * 2,
    ]

    return wavefront, psf_center
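The float64 requirement can be motivated with a small calculation: optical path lengths are on the order of millimetres to hundreds of millimetres, while the phase depends on them in units of the wavelength (about 5e-4 mm). The sketch below uses only the standard library to round a path length to float32 and measure the resulting phase error; the path length and wavelength are illustrative values, not taken from any lens, and `to_float32` and `phase_error` are hypothetical names.

```python
import math
import struct


def to_float32(value):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", value))[0]


def phase_error(opl_mm, wvln_um):
    """Phase error [rad] caused by storing an optical path length in float32."""
    wvln_mm = wvln_um * 1e-3  # wavelength in mm, same conversion as in the source above
    exact = 2 * math.pi * opl_mm / wvln_mm
    rounded = 2 * math.pi * to_float32(opl_mm) / wvln_mm
    return abs(exact - rounded)


# Illustrative values: a 123 mm optical path at 0.55 micron.
err = phase_error(opl_mm=123.456789, wvln_um=0.55)
```

A single float32 rounding already costs on the order of 1e-2 rad here, and the error grows as path lengths are accumulated surface by surface, which is consistent with the assertion on `torch.get_default_dtype()`.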

psf_huygens

psf_huygens(points, ks=PSF_KS, wvln=None, spp=SPP_COHERENT, recenter=True)

Single wavelength Huygens PSF calculation.

This function is not differentiable due to its heavy computational cost.

Steps

1. Trace coherent rays to exit-pupil plane.
2. Treat every ray as a secondary point source emitting a spherical wave.

Parameters:

Name Type Description Default
points Tensor

Normalized point source position. Shape of [N, 3], x, y in range [-1, 1], z in range [-Inf, 0].

required
ks int

Output kernel size.

PSF_KS
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None
spp int

Sample per pixel.

SPP_COHERENT
recenter bool

Recenter PSF using chief ray.

True

Returns:

Name Type Description
psf

Shape of [ks, ks].

References

[1] "Optical Aberrations Correction in Postprocessing Using Imaging Simulation", TOG 2021

Note

This differs from the ZEMAX Huygens PSF, which traces rays to the image plane and performs plane-wave integration.
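The two steps above can be sketched without any DeepLens dependency. The following is a minimal pure-Python illustration of the Huygens summation, assuming the secondary sources (position, accumulated OPL, z-direction cosine) are already known. The function name `huygens_psf` and the ideal-lens test data are hypothetical, and the sketch omits the recentering and final flip performed by the library code.

```python
import cmath
import math


def huygens_psf(sources, ks, pixel_size, sensor_z, wvln_mm):
    """Sum spherical wavelets from secondary sources on a ks x ks pixel grid.

    sources: list of (x, y, z, opl, dir_z) tuples, lengths in mm.
    Returns a ks x ks nested list of intensities that sums to 1.
    """
    half = (ks / 2) * pixel_size
    # Pixel-centre coordinates: x ascending, y descending (top row first)
    xs = [-half + pixel_size / 2 + i * pixel_size for i in range(ks)]
    ys = [half - pixel_size / 2 - j * pixel_size for j in range(ks)]
    opl_min = min(s[3] for s in sources)

    intensity = []
    for py in ys:
        row = []
        for px in xs:
            field = 0j
            for sx, sy, sz, opl, dir_z in sources:
                r = math.sqrt((px - sx) ** 2 + (py - sy) ** 2 + (sensor_z - sz) ** 2)
                amp = 0.5 * (1.0 + abs(dir_z))  # Huygens-Fresnel obliquity factor
                # Keep only the fractional number of wavelengths before forming the phase
                cycles = math.fmod((opl + r - opl_min) / wvln_mm, 1.0)
                field += (amp / r) * cmath.exp(2j * math.pi * cycles)
            row.append(abs(field) ** 2)
        intensity.append(row)

    total = sum(map(sum, intensity))
    return [[v / total for v in row] for row in intensity]


# Hypothetical aberration-free case: sources on a 1 mm radius pupil at z = 0,
# with OPLs chosen so that all wavelets arrive in phase at (0, 0, focal).
focal = 10.0
sources = []
n = 9
for i in range(n):
    for j in range(n):
        x = -1.0 + 2.0 * i / (n - 1)
        y = -1.0 + 2.0 * j / (n - 1)
        if x * x + y * y <= 1.0:
            d = math.sqrt(x * x + y * y + focal * focal)
            sources.append((x, y, 0.0, 100.0 - d, focal / d))

psf = huygens_psf(sources, ks=5, pixel_size=1e-3, sensor_z=focal, wvln_mm=0.55e-3)
```

For this ideal converging wave the normalized PSF peaks at the central pixel, since every wavelet has the same total optical path there.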

Source code in deeplens-src/deeplens/geolens_pkg/psf_compute.py
def psf_huygens(
    self, points, ks=PSF_KS, wvln=None, spp=SPP_COHERENT, recenter=True
):
    """Single wavelength Huygens PSF calculation.

    This function is not differentiable due to its heavy computational cost.

    Steps:
        1, Trace coherent rays to exit-pupil plane.
        2, Treat every ray as a secondary point source emitting a spherical wave.

    Args:
        points (Tensor): Normalized point source position. Shape of [3] or [1, 3] (single point only), x, y in range [-1, 1], z in range [-Inf, 0].
        ks (int, optional): Output kernel size.
        wvln (float, optional): Wavelength in µm. When ``None`` (default),
            falls back to ``self.primary_wvln``.
        spp (int, optional): Samples per pixel.
        recenter (bool, optional): Recenter PSF using chief ray.

    Returns:
        psf: Shape of [ks, ks].

    References:
        [1] "Optical Aberrations Correction in Postprocessing Using Imaging Simulation", TOG 2021

    Note:
        This differs from the ZEMAX Huygens PSF, which traces rays to the image plane and performs plane-wave integration.
    """
    wvln = self.primary_wvln if wvln is None else wvln
    assert torch.get_default_dtype() == torch.float64, (
        "Default dtype must be set to float64 for accurate phase calculation."
    )

    sensor_w, sensor_h = self.sensor_size
    pixel_size = self.pixel_size
    device = self.device
    wvln_mm = wvln * 1e-3  # Convert wavelength to mm

    # Points shape of [N, 3]
    if not torch.is_tensor(points):
        points = torch.tensor(points, device=device)

    if len(points.shape) == 1:
        single_point = True
        points = points.unsqueeze(0)
    elif len(points.shape) == 2 and points.shape[0] == 1:
        single_point = True
    else:
        raise ValueError(
            f"Points must be of shape [3] or [1, 3], got {points.shape}."
        )

    # Sample rays from object point
    depth = points[:, 2]
    scale = self.calc_scale(depth)
    point_obj_x = points[..., 0] * scale * sensor_w / 2
    point_obj_y = points[..., 1] * scale * sensor_h / 2
    point_obj = torch.stack([point_obj_x, point_obj_y, points[..., 2]], dim=-1)
    ray = self.sample_from_points(points=point_obj, num_rays=spp, wvln=wvln)

    # Trace rays coherently through the lens to exit pupil
    ray.is_coherent = True
    ray = self.trace2exit_pupil(ray)

    # Calculate PSF center (not flipped here)
    if recenter:
        pointc = -self.psf_center(point_obj, method="chief_ray")
    else:
        pointc = -self.psf_center(point_obj, method="pinhole")

    # Build PSF pixel coordinates (sensor plane at z = d_sensor)
    sensor_z = self.d_sensor.item()
    psf_half_size = (ks / 2) * pixel_size  # Physical half-size of PSF region
    x_coords = torch.linspace(
        -psf_half_size + pixel_size / 2,
        psf_half_size - pixel_size / 2,
        ks,
        device=device,
    )
    y_coords = torch.linspace(
        psf_half_size - pixel_size / 2,
        -psf_half_size + pixel_size / 2,
        ks,
        device=device,
    )
    psf_x, psf_y = torch.meshgrid(
        pointc[0, 0] + x_coords, pointc[0, 1] + y_coords, indexing="xy"
    )  # [ks, ks] each

    # Get valid rays only
    valid_mask = ray.is_valid > 0
    valid_pos = ray.o[valid_mask]  # [num_valid, 3]
    valid_dir = ray.d[valid_mask]  # [num_valid, 3]
    valid_opl = ray.opl[valid_mask]  # [num_valid]
    num_valid = valid_pos.shape[0]

    # Huygens integration: sum spherical waves from each secondary source
    psf_complex = torch.zeros(ks, ks, dtype=torch.complex128, device=device)
    opl_min = valid_opl.min()

    # Compute distance from each secondary source to each pixel
    batch_size = min(num_valid, 10_000)  # Process rays in batches
    for batch_start in range(0, num_valid, batch_size):
        batch_end = min(batch_start + batch_size, num_valid)

        # Batch ray data
        batch_pos = valid_pos[batch_start:batch_end]  # [batch, 3]
        batch_dir = valid_dir[batch_start:batch_end]  # [batch, 3]
        batch_opl = valid_opl[batch_start:batch_end].squeeze(-1)  # [batch]

        # Distance from each secondary source to each pixel
        # batch_pos: [batch, 3], psf_x: [ks, ks]
        dx = psf_x.unsqueeze(-1) - batch_pos[:, 0]  # [ks, ks, batch]
        dy = psf_y.unsqueeze(-1) - batch_pos[:, 1]  # [ks, ks, batch]
        dz = sensor_z - batch_pos[:, 2]  # [batch]

        # Distance r from secondary source to pixel
        r = torch.sqrt(dx**2 + dy**2 + dz**2)  # [ks, ks, batch]

        # Obliquity factor: cos(theta) where theta is angle from normal
        # Using ray direction at exit pupil (dz component)
        obliq = torch.abs(batch_dir[:, 2])  # [batch]
        amp = 0.5 * (1.0 + obliq)  # Huygens–Fresnel obliquity factor

        # Total optical path = OPL through lens + distance to pixel
        total_opl = batch_opl + r  # [ks, ks, batch]

        # Phase relative to reference
        phase = torch.fmod((total_opl - opl_min) / wvln_mm, 1.0) * (
            2 * torch.pi
        )  # [ks, ks, batch]

        # Complex amplitude: A * exp(i * phase) / r (spherical wave decay)
        # We use 1/r for spherical wave amplitude decay
        complex_amp = (amp / r) * torch.exp(1j * phase)  # [ks, ks, batch]

        # Sum contributions from this batch
        psf_complex += complex_amp.sum(dim=-1)  # [ks, ks]

    # Convert complex field to intensity
    psf = psf_complex.abs() ** 2

    # Intensity normalization
    psf = psf / (torch.sum(psf, dim=(-2, -1), keepdim=True) + EPSILON)

    # Flip PSF
    psf = torch.flip(psf, [-2, -1])

    if single_point:
        psf = psf.squeeze(0)

    return diff_float(psf)
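The float64 assertion at the top of psf_huygens can be motivated with a short numerical check. The sketch below is an illustration with hypothetical values (an OPL of about 88 mm and a 0.55 µm wavelength): it rounds the OPL to float32 and measures how far the resulting phase moves.

```python
import struct


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


wvln_mm = 0.55e-3    # hypothetical wavelength in mm
opl_mm = 87.654321   # hypothetical optical path length in mm

# Phase error caused only by storing the OPL in float32
cycles64 = opl_mm / wvln_mm
cycles32 = to_float32(opl_mm) / wvln_mm
phase_err32_deg = abs(cycles64 - cycles32) * 360.0

# Worst-case rounding error (half a unit in the last place) for values in [64, 128) mm
worst_case32_deg = (2.0 ** -18) / wvln_mm * 360.0
worst_case64_deg = (2.0 ** -47) / wvln_mm * 360.0
```

Under these assumptions a single float32 rounding can shift the phase by up to roughly 2.5 degrees, and such errors accumulate over the surfaces of a lens, whereas the float64 bound is many orders of magnitude smaller.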

psf_map

psf_map(depth=None, grid=(7, 7), ks=PSF_KS, spp=SPP_PSF, wvln=None, recenter=True)

Compute the geometric PSF map at a given depth.

Overrides the base method in the Lens class and improves efficiency by ray tracing all field points in parallel.

Parameters:

Name Type Description Default
depth float

Depth of the object plane. When None (default), falls back to self.obj_depth.

None
grid (int, tuple)

Grid size (grid_w, grid_h). An int is expanded to a square grid. Defaults to (7, 7).

(7, 7)
ks int

Kernel size. Defaults to PSF_KS.

PSF_KS
spp int

Samples per pixel. Defaults to SPP_PSF.

SPP_PSF
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None
recenter bool

Recenter PSF using chief ray. Defaults to True.

True

Returns:

Name Type Description
psf_map

PSF map. Shape of [grid_h, grid_w, 1, ks, ks].
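The relationship between the (grid_w, grid_h) argument and the [grid_h, grid_w, ...] output layout is easy to get backwards. The sketch below mimics the flatten-then-reshape bookkeeping with plain lists; the helper names are hypothetical and the integers stand in for per-field PSFs, assuming the field points are flattened in row-major order.

```python
def as_grid(grid):
    """Expand an int grid size to (grid_w, grid_h); pass tuples through."""
    return (grid, grid) if isinstance(grid, int) else tuple(grid)


def reshape_row_major(flat, grid):
    """Regroup a flat list of per-point results into grid_h rows of grid_w items."""
    grid_w, grid_h = as_grid(grid)
    if len(flat) != grid_w * grid_h:
        raise ValueError("flat list length must equal grid_w * grid_h")
    return [flat[r * grid_w:(r + 1) * grid_w] for r in range(grid_h)]


flat_results = list(range(6))                       # stand-ins for 6 PSFs
psf_grid = reshape_row_major(flat_results, (3, 2))  # 2 rows (grid_h), 3 columns (grid_w)
```

Indexing the result as psf_grid[row][col] therefore corresponds to psf_map[i, j] with i along grid_h and j along grid_w.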

Source code in deeplens-src/deeplens/geolens_pkg/psf_compute.py
def psf_map(
    self,
    depth=None,
    grid=(7, 7),
    ks=PSF_KS,
    spp=SPP_PSF,
    wvln=None,
    recenter=True,
):
    """Compute the geometric PSF map at given depth.

    Overrides the base method in Lens class to improve efficiency by parallel ray tracing over different field points.

    Args:
        depth (float, optional): Depth of the object plane. When ``None``
            (default), falls back to ``self.obj_depth``.
        grid (int, tuple): Grid size (grid_w, grid_h). An int is expanded to a square grid. Defaults to (7, 7).
        ks (int, optional): Kernel size. Defaults to PSF_KS.
        spp (int, optional): Samples per pixel. Defaults to SPP_PSF.
        wvln (float, optional): Wavelength in µm. When ``None`` (default),
            falls back to ``self.primary_wvln``.
        recenter (bool, optional): Recenter PSF using chief ray. Defaults to True.

    Returns:
        psf_map: PSF map. Shape of [grid_h, grid_w, 1, ks, ks].
    """
    wvln = self.primary_wvln if wvln is None else wvln
    depth = self.obj_depth if depth is None else depth
    if isinstance(grid, int):
        grid = (grid, grid)
    points = self.point_source_grid(depth=depth, grid=grid)
    points = points.reshape(-1, 3)
    psfs = self.psf(
        points=points, ks=ks, recenter=recenter, spp=spp, wvln=wvln
    ).unsqueeze(1)  # [grid_h * grid_w, 1, ks, ks]

    psf_map = psfs.reshape(grid[1], grid[0], 1, ks, ks)
    return psf_map

psf_center

psf_center(points_obj, method='chief_ray')

Compute the reference PSF center (flipped to match the original point) for a given point source.

Parameters:

Name Type Description Default
points_obj

Un-normalized point(s) in object space, shape [..., 3], with x and y in [-Inf, Inf] and z in [-Inf, 0].

required
method

"chief_ray" or "pinhole". Defaults to "chief_ray".

'chief_ray'

Returns:

Name Type Description
psf_center

Un-normalized PSF center on the sensor plane, shape [..., 2].
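The "pinhole" method reduces to a perspective projection scaled by the focal length. A minimal sketch for a single point follows; the function name and the sample values are hypothetical, and distortion is ignored as in the library's pinhole branch.

```python
def pinhole_psf_center(point_obj, foclen):
    """Project an object-space point (x, y, z < 0) to a PSF center on the sensor (mm)."""
    x, y, z = point_obj
    if z >= 0:
        raise ValueError("object depth z must be negative")
    tan_fov_x = -x / z
    tan_fov_y = -y / z
    return (foclen * tan_fov_x, foclen * tan_fov_y)


# Hypothetical point 1 m in front of a 50 mm lens
center = pinhole_psf_center((100.0, -50.0, -1000.0), foclen=50.0)
```

The returned center has the same sign as the object point, which matches the "flipped to match the original point" convention in the description above.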

Source code in deeplens-src/deeplens/geolens_pkg/psf_compute.py
@torch.no_grad()
def psf_center(self, points_obj, method="chief_ray"):
    """Compute reference PSF center (flipped to match the original point) for given point source.

    Args:
        points_obj: [..., 3] un-normalized point in object plane. [-Inf, Inf] * [-Inf, Inf] * [-Inf, 0]
        method: "chief_ray" or "pinhole". Defaults to "chief_ray".

    Returns:
        psf_center: [..., 2] un-normalized psf center in sensor plane.
    """
    if method == "chief_ray":
        # Shrink the pupil and calculate centroid ray as the chief ray
        ray = self.sample_from_points(points_obj, scale_pupil=0.5, num_rays=SPP_CALC)
        ray = self.trace2sensor(ray)
        if ray.is_valid.any():
            psf_center = ray.centroid()
            psf_center = -psf_center[..., :2]  # shape [..., 2]
        else:
            # Fallback to pinhole when chief ray fails (can happen during optimization)
            return self.psf_center(points_obj, method="pinhole")

    elif method == "pinhole":
        # Pinhole camera perspective projection, distortion not considered
        if points_obj[..., 2].min().abs() < 100:
            print(
                "Point source is too close, pinhole model may be inaccurate for PSF center calculation."
            )
        tan_point_fov_x = -points_obj[..., 0] / points_obj[..., 2]
        tan_point_fov_y = -points_obj[..., 1] / points_obj[..., 2]
        psf_center_x = self.foclen * tan_point_fov_x
        psf_center_y = self.foclen * tan_point_fov_y
        psf_center = torch.stack([psf_center_x, psf_center_y], dim=-1).to(
            self.device
        )

    else:
        raise ValueError(
            f"Unsupported method for PSF center calculation: {method}."
        )

    return psf_center

deeplens.geolens_pkg.eval.GeoLensEval

Mixin that adds classical optical evaluation methods to GeoLens.

This class is never instantiated on its own. It is mixed into GeoLens via multiple inheritance, so every method can access lens geometry (self.d_sensor, self.rfov, …) and ray-tracing routines (self.trace(), self.trace2sensor(), …) directly through self.

All evaluation functions follow the same pattern:
  1. Sample rays from object space (parallel / grid / radial).
  2. Trace rays through the lens (self.trace or self.trace2sensor).
  3. Analyze ray positions / directions at the sensor plane.
  4. Optionally produce a matplotlib figure saved to disk.

Results are accuracy-aligned with Zemax OpticStudio for the same lens prescriptions and ray-sampling densities.

Attributes consumed from GeoLens (via self):
  • d_sensor (float): Axial position of the sensor plane (mm).
  • sensor_size (tuple[float, float]): Sensor (width, height) in mm.
  • pixel_size (float): Pixel pitch in mm.
  • sensor_res (tuple[int, int]): Sensor resolution (W, H) in pixels.
  • rfov (float): Half field-of-view in radians.
  • foclen (float): Equivalent focal length in mm.
  • fnum (float): F-number.
  • aper_idx (int): Index of the aperture stop surface.
  • device (torch.device): Compute device (CPU / CUDA).

spot_points

spot_points(points, num_rays=SPP_PSF, wvln=None)

Trace rays from object points to sensor and return the traced Ray.

Samples rays from each physical object point toward the entrance pupil, traces through all lens surfaces (refraction + clipping), and returns the resulting Ray object on the sensor plane.

This is the shared computational core for spot diagrams (draw_spot_radial, draw_spot_map) and RMS error maps (rms_map, rms_map_rgb).

Algorithm
  1. self.sample_from_points(points, num_rays, wvln) generates a fan of num_rays rays per object point, aimed at the entrance pupil.
  2. self.trace2sensor() propagates through all surfaces and clips vignetted rays.

Parameters:

Name Type Description Default
points Tensor

Physical 3D object-space coordinates with shape [..., 3] (mm). Supported layouts:
  • [3]: single point.
  • [N, 3]: N points (e.g. radial field positions).
  • [H, W, 3]: 2-D field grid.
Generated by self.point_source_grid(normalized=False) for grid sampling, or self.point_source_radial(normalized=False) for radial sampling.

required
num_rays int

Number of rays sampled per object point. Defaults to SPP_PSF.

SPP_PSF
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None

Returns:

Name Type Description
Ray

Traced ray on the sensor plane, with shape [..., num_rays, 3] for positions and [..., num_rays] for validity mask. Use ray.o[..., :2] for transverse positions and ray.is_valid for the validity mask. ray.centroid() gives the weighted centroid.

Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def spot_points(self, points, num_rays=SPP_PSF, wvln=None):
    """Trace rays from object points to sensor and return the traced Ray.

    Samples rays from each physical object point toward the entrance pupil,
    traces through all lens surfaces (refraction + clipping), and returns
    the resulting Ray object on the sensor plane.

    This is the shared computational core for spot diagrams
    (``draw_spot_radial``, ``draw_spot_map``) and RMS error maps
    (``rms_map``, ``rms_map_rgb``).

    Algorithm:
        1. ``self.sample_from_points(points, num_rays, wvln)`` generates a
           fan of ``num_rays`` rays per object point, aimed at the entrance
           pupil.
        2. ``self.trace2sensor()`` propagates through all surfaces and
           clips vignetted rays.

    Args:
        points (torch.Tensor): Physical 3D object-space coordinates with
            shape ``[..., 3]`` (mm).  Supported layouts:
            - ``[3]`` — single point.
            - ``[N, 3]`` — N points (e.g. radial field positions).
            - ``[H, W, 3]`` — 2-D field grid.
            Generated by ``self.point_source_grid(normalized=False)`` for
            grid sampling, or ``self.point_source_radial(normalized=False)``
            for radial sampling.
        num_rays (int): Number of rays sampled per object point.
            Defaults to ``SPP_PSF``.
        wvln (float): Wavelength in µm. When ``None`` (default), falls
            back to ``self.primary_wvln``.

    Returns:
        Ray: Traced ray on the sensor plane, with shape
            ``[..., num_rays, 3]`` for positions and ``[..., num_rays]``
            for validity mask. Use ``ray.o[..., :2]`` for transverse
            positions and ``ray.is_valid`` for the validity mask.
            ``ray.centroid()`` gives the weighted centroid.
    """
    wvln = self.primary_wvln if wvln is None else wvln
    ray = self.sample_from_points(points=points, num_rays=num_rays, wvln=wvln)
    return self.trace2sensor(ray)

draw_spot_radial

draw_spot_radial(save_name='./lens_spot_radial.png', num_fov=5, depth=None, num_rays=SPP_PSF, wvln_list=None, direction='y', show=False)

Draw spot diagrams at evenly-spaced field angles along a chosen direction.

A spot diagram visualizes the transverse ray-intercept distribution on the sensor plane for a point source at a given field angle and depth. It reveals the combined effect of all aberrations (spherical, coma, astigmatism, field curvature, chromatic, …).

Algorithm

For each wavelength in wvln_list:
  1. self.point_source_radial(direction, normalized=False) generates physical object-space points along the chosen direction.
  2. self.spot_points() samples rays and traces to sensor.
  3. Valid ray (x, y) positions are scatter-plotted per subplot.

All wavelengths are overlaid in a single figure with RGB coloring.

Parameters:

Name Type Description Default
save_name str

File path for the output PNG. Defaults to './lens_spot_radial.png'.

'./lens_spot_radial.png'
num_fov int

Number of field positions sampled uniformly from on-axis (0) to full-field. Defaults to 5.

5
depth float

Object distance in mm (negative = real object). When None (default), falls back to self.obj_depth.

None
num_rays int

Rays per field position per wavelength. Defaults to SPP_PSF.

SPP_PSF
wvln_list list[float]

Wavelengths in µm. When None (default), falls back to self.wvln_rgb.

None
direction str

Sampling direction — "y" (meridional, default), "x" (sagittal), "diagonal" (45°).

'y'
show bool

If True, display the figure interactively instead of saving to disk. Defaults to False.

False
Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def draw_spot_radial(
    self,
    save_name="./lens_spot_radial.png",
    num_fov=5,
    depth=None,
    num_rays=SPP_PSF,
    wvln_list=None,
    direction="y",
    show=False,
):
    """Draw spot diagrams at evenly-spaced field angles along a chosen direction.

    A *spot diagram* visualizes the transverse ray-intercept distribution on
    the sensor plane for a point source at a given field angle and depth.
    It reveals the combined effect of all aberrations (spherical, coma,
    astigmatism, field curvature, chromatic, …).

    Algorithm:
        For each wavelength in ``wvln_list``:
            1. ``self.point_source_radial(direction, normalized=False)``
               generates physical object-space points along the chosen
               direction.
            2. ``self.spot_points()`` samples rays and traces to sensor.
            3. Valid ray (x, y) positions are scatter-plotted per subplot.
        All wavelengths are overlaid in a single figure with RGB coloring.

    Args:
        save_name (str): File path for the output PNG.
            Defaults to ``'./lens_spot_radial.png'``.
        num_fov (int): Number of field positions sampled uniformly from
            on-axis (0) to full-field. Defaults to 5.
        depth (float): Object distance in mm (negative = real object).
            When ``None`` (default), falls back to ``self.obj_depth``.
        num_rays (int): Rays per field position per wavelength.
            Defaults to ``SPP_PSF``.
        wvln_list (list[float]): Wavelengths in µm.  When ``None``
            (default), falls back to ``self.wvln_rgb``.
        direction (str): Sampling direction —
            ``"y"`` (meridional, default), ``"x"`` (sagittal),
            ``"diagonal"`` (45°).
        show (bool): If ``True``, display the figure interactively instead
            of saving to disk. Defaults to ``False``.
    """
    wvln_list = self.wvln_rgb if wvln_list is None else wvln_list
    assert isinstance(wvln_list, list), "wvln_list must be a list"
    if depth is None or depth == float("inf"):
        depth = self.obj_depth

    # Generate physical object-space points along the chosen direction
    points = self.point_source_radial(
        depth=depth, grid=num_fov, direction=direction, normalized=False
    )

    # Prepare figure
    fig, axs = plt.subplots(1, num_fov, figsize=(num_fov * 3.5, 3))
    axs = np.atleast_1d(axs)

    # Trace and draw each wavelength separately, overlaying results
    for wvln_idx, wvln in enumerate(wvln_list):
        ray = self.spot_points(points, num_rays=num_rays, wvln=wvln)
        ray_o = ray.o[..., :2].cpu().numpy()
        ray_valid_np = ray.is_valid.cpu().numpy()

        color = RGB_COLORS[wvln_idx % len(RGB_COLORS)]

        # Plot multiple spot diagrams in one figure
        for i in range(num_fov):
            valid = ray_valid_np[i, :]
            xi, yi = ray_o[i, :, 0], ray_o[i, :, 1]

            # Filter valid rays
            mask = valid > 0
            x_valid, y_valid = xi[mask], yi[mask]

            # Plot valid ray hits for this wavelength
            axs[i].scatter(x_valid, y_valid, 2, color=color, alpha=0.5)
            axs[i].set_aspect("equal", adjustable="datalim")
            axs[i].tick_params(axis="both", which="major", labelsize=6)

    if show:
        plt.show()
    else:
        assert save_name.endswith(".png"), "save_name must end with .png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

draw_spot_map

draw_spot_map(save_name='./lens_spot_map.png', num_grid=5, depth=None, num_rays=SPP_PSF, wvln_list=None, show=False)

Draw a 2-D grid of spot diagrams across the full field of view.

Unlike draw_spot_radial (which samples only a radial slice), this method samples a num_grid × num_grid grid of field positions covering both the x (sagittal) and y (meridional) axes, revealing off-axis aberrations that are invisible in a 1-D radial scan.

Algorithm

For each wavelength in wvln_list:
  1. self.point_source_grid(normalized=False) creates physical object-space grid points, shape [grid_h, grid_w, 3].
  2. self.spot_points() samples rays and traces to sensor.
  3. Valid (x, y) positions are scatter-plotted in the corresponding subplot of the num_grid × num_grid figure.

All wavelengths are overlaid with RGB coloring.

Parameters:

Name Type Description Default
save_name str

File path for the output PNG. Defaults to './lens_spot_map.png'.

'./lens_spot_map.png'
num_grid int | tuple[int, int]

Number of grid points along each axis. Total subplots = grid_w * grid_h. Defaults to 5.

5
depth float

Object distance in mm. When None (default), falls back to self.obj_depth.

None
num_rays int

Rays per grid cell per wavelength. Defaults to SPP_PSF.

SPP_PSF
wvln_list list[float]

Wavelengths in µm. When None (default), falls back to self.wvln_rgb.

None
show bool

If True, display interactively. Defaults to False.

False
Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def draw_spot_map(
    self,
    save_name="./lens_spot_map.png",
    num_grid=5,
    depth=None,
    num_rays=SPP_PSF,
    wvln_list=None,
    show=False,
):
    """Draw a 2-D grid of spot diagrams across the full field of view.

    Unlike ``draw_spot_radial`` (which samples only a radial slice),
    this method samples a ``num_grid × num_grid`` grid of field positions
    covering both the x (sagittal) and y (meridional) axes, revealing
    off-axis aberrations that are invisible in a 1-D radial scan.

    Algorithm:
        For each wavelength in ``wvln_list``:
            1. ``self.point_source_grid(normalized=False)`` creates physical
               object-space grid points, shape ``[grid_h, grid_w, 3]``.
            2. ``self.spot_points()`` samples rays and traces to sensor.
            3. Valid (x, y) positions are scatter-plotted in the
               corresponding subplot of the ``num_grid × num_grid`` figure.
        All wavelengths are overlaid with RGB coloring.

    Args:
        save_name (str): File path for the output PNG.
            Defaults to ``'./lens_spot_map.png'``.
        num_grid (int | tuple[int, int]): Number of grid points along each
            axis. Total subplots = ``grid_w * grid_h``. Defaults to 5.
        depth (float): Object distance in mm. When ``None`` (default),
            falls back to ``self.obj_depth``.
        num_rays (int): Rays per grid cell per wavelength.
            Defaults to ``SPP_PSF``.
        wvln_list (list[float]): Wavelengths in µm.  When ``None``
            (default), falls back to ``self.wvln_rgb``.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    wvln_list = self.wvln_rgb if wvln_list is None else wvln_list
    depth = self.obj_depth if depth is None else depth
    assert isinstance(wvln_list, list), "wvln_list must be a list"
    if isinstance(num_grid, int):
        num_grid = (num_grid, num_grid)

    # Generate physical object-space grid points, shape [grid_h, grid_w, 3]
    points = self.point_source_grid(depth=depth, grid=num_grid, normalized=False)

    grid_w, grid_h = num_grid
    fig, axs = plt.subplots(
        grid_h, grid_w, figsize=(grid_w * 3, grid_h * 3)
    )
    axs = np.atleast_2d(axs)

    # Loop wavelengths and overlay scatters
    for wvln_idx, wvln in enumerate(wvln_list):
        ray = self.spot_points(points, num_rays=num_rays, wvln=wvln)

        # Convert to numpy, shape [grid_h, grid_w, num_rays, 2]
        ray_o = -ray.o[..., :2].cpu().numpy()
        ray_valid_np = ray.is_valid.cpu().numpy()

        color = RGB_COLORS[wvln_idx % len(RGB_COLORS)]

        # Draw per grid cell
        for i in range(grid_h):
            for j in range(grid_w):
                valid = ray_valid_np[i, j, :]
                xi, yi = ray_o[i, j, :, 0], ray_o[i, j, :, 1]

                # Filter valid rays
                mask = valid > 0
                x_valid, y_valid = xi[mask], yi[mask]

                # Plot points for this wavelength
                axs[i, j].scatter(x_valid, y_valid, 2, color=color, alpha=0.5)
                axs[i, j].set_aspect("equal", adjustable="datalim")
                axs[i, j].tick_params(axis="both", which="major", labelsize=6)

    if show:
        plt.show()
    else:
        assert save_name.endswith(".png"), "save_name must end with .png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

rms_map

rms_map(num_grid=32, depth=None, wvln=None, center=None)

Compute per-field-position RMS spot radius for a single wavelength.

Traces SPP_PSF rays per grid cell and computes the root-mean-square distance of valid ray hits from a reference centroid. When center is None, each cell uses its own centroid (monochromatic blur). When an external center is provided (e.g. the green-channel centroid), the RMS includes the chromatic shift from that reference.

Algorithm
  1. self.point_source_grid(normalized=False) generates physical object points on a [num_grid, num_grid] field grid.
  2. self.spot_points() samples SPP_PSF rays per point and traces to sensor.
  3. If center is None, compute per-cell centroid c = mean(valid ray_xy); otherwise use the provided center.
  4. RMS = sqrt( mean( ||ray_xy - c||^2 ) ).

Parameters:

Name Type Description Default
num_grid int | tuple[int, int]

Spatial resolution of the field sampling grid. Defaults to 32.

32
depth float

Object distance in mm. When None (default), falls back to self.obj_depth.

None
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None
center Tensor | None

External reference centroid with shape [grid_h, grid_w, 2]. If None, each cell's own centroid is used. Defaults to None.

None

Returns:

Type Description

tuple[torch.Tensor, torch.Tensor]:
  • rms: RMS spot error map, shape [grid_h, grid_w], in mm.
  • centroid: Per-cell centroid of the traced rays, shape [grid_h, grid_w, 2]. It is returned even when an external center is supplied, and is useful for passing as center to subsequent calls (e.g. in rms_map_rgb).
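The RMS computation for a single grid cell can be sketched with plain Python. This is an illustration under two assumptions: the centroid is the validity-weighted mean of the ray hits (Ray.centroid() may weight differently), and EPS is a hypothetical stand-in for the library's EPSILON constant.

```python
import math

EPS = 1e-9  # hypothetical stand-in for EPSILON


def spot_centroid(points, valid):
    """Validity-weighted mean of (x, y) ray hits."""
    n = sum(valid)
    cx = sum(p[0] * v for p, v in zip(points, valid)) / n
    cy = sum(p[1] * v for p, v in zip(points, valid)) / n
    return (cx, cy)


def rms_radius(points, valid, center=None):
    """Return (rms, own_centroid); measure about `center` when it is given."""
    own = spot_centroid(points, valid)
    ref = own if center is None else center
    sq_sum = sum(
        ((x - ref[0]) ** 2 + (y - ref[1]) ** 2) * v
        for (x, y), v in zip(points, valid)
    )
    return math.sqrt(sq_sum / (sum(valid) + EPS)), own


hits = [(1.0, 0.0), (-1.0, 0.0), (0.0, 1.0), (0.0, -1.0), (5.0, 5.0)]
valid = [1, 1, 1, 1, 0]  # the last ray was vignetted and is ignored

rms_own, c_own = rms_radius(hits, valid)
rms_ref, _ = rms_radius(hits, valid, center=(0.3, 0.4))
```

Measuring about the shifted reference gives a larger value than measuring about the cell's own centroid, which is the effect rms_map_rgb relies on.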

Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def rms_map(self, num_grid=32, depth=None, wvln=None, center=None):
    """Compute per-field-position RMS spot radius for a single wavelength.

    Traces ``SPP_PSF`` rays per grid cell and computes the root-mean-square
    distance of valid ray hits from a reference centroid.  When ``center``
    is ``None``, each cell uses its own centroid (monochromatic blur).
    When an external ``center`` is provided (e.g. the green-channel
    centroid), the RMS includes the chromatic shift from that reference.

    Algorithm:
        1. ``self.point_source_grid(normalized=False)`` generates physical
           object points on a ``[num_grid, num_grid]`` field grid.
        2. ``self.spot_points()`` samples ``SPP_PSF`` rays per point and
           traces to sensor.
        3. If ``center`` is ``None``, compute per-cell centroid
           ``c = mean(valid ray_xy)``; otherwise use the provided ``center``.
        4. ``RMS = sqrt( mean( ||ray_xy - c||^2 ) )``.

    Args:
        num_grid (int | tuple[int, int]): Spatial resolution of the field
            sampling grid. Defaults to 32.
        depth (float): Object distance in mm. When ``None`` (default),
            falls back to ``self.obj_depth``.
        wvln (float): Wavelength in µm. When ``None`` (default), falls
            back to ``self.primary_wvln``.
        center (torch.Tensor | None): External reference centroid with shape
            ``[grid_h, grid_w, 2]``.  If ``None``, each cell's own
            centroid is used. Defaults to ``None``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - **rms**: RMS spot error map, shape ``[grid_h, grid_w]``,
              in mm.
            - **centroid**: Per-cell centroid used as reference, shape
              ``[grid_h, grid_w, 2]``.  Useful for passing as
              ``center`` to subsequent calls (e.g. in ``rms_map_rgb``).
    """
    wvln = self.primary_wvln if wvln is None else wvln
    depth = self.obj_depth if depth is None else depth
    if isinstance(num_grid, int):
        num_grid = (num_grid, num_grid)

    # Generate physical grid points and trace rays to sensor
    points = self.point_source_grid(depth=depth, grid=num_grid, normalized=False)
    ray = self.spot_points(points, num_rays=SPP_PSF, wvln=wvln)

    # Reuse Ray.centroid() — shape [grid_h, grid_w, 3], slice to [grid_h, grid_w, 2]
    centroid = ray.centroid()[..., :2]

    # Use external center if provided, otherwise own centroid
    ref = center if center is not None else centroid

    # RMS relative to reference, shape [grid_h, grid_w]
    ray_xy = ray.o[..., :2]
    ray_valid = ray.is_valid
    rms = torch.sqrt(
        (((ray_xy - ref.unsqueeze(-2)) ** 2).sum(-1) * ray_valid).sum(-1)
        / (ray_valid.sum(-1) + EPSILON)
    )

    return rms, centroid

rms_map_rgb

rms_map_rgb(num_grid=32, depth=None)

Compute per-field-position RMS spot radius for R, G, B wavelengths.

The RMS spot radius is a standard measure of geometrical image quality. For each field position in a num_grid × num_grid grid, this method traces SPP_PSF rays per wavelength and computes the root-mean-square distance of valid ray hits from a common reference centroid.

The reference centroid is the green-channel centroid. Using a common reference means the returned RMS values include lateral chromatic aberration (the shift between R/G/B centroids), making the map useful as a polychromatic image-quality metric.

Algorithm
  1. Call rms_map(wvln=green) to get the green RMS map and the green centroid.
  2. Call rms_map(wvln=red, center=green_centroid) and rms_map(wvln=blue, center=green_centroid) to measure R/B blur relative to the green reference.
  3. Stack as [R, G, B].

Parameters:

Name Type Description Default
num_grid int

Spatial resolution of the field sampling grid. Defaults to 32.

32
depth float

Object distance in mm. When None (default), falls back to self.obj_depth.

None

Returns:

Type Description

torch.Tensor: RMS spot error map with shape [3, num_grid, num_grid] (channels ordered R, G, B). Units are mm (same as sensor coordinates).
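Why a shared reference folds lateral color into the metric can be made explicit. Assuming the per-cell centroid c is the plain mean of the N valid ray hits p_i (an assumption about how Ray.centroid() weights rays), the RMS about any other reference c_ref decomposes as:

```latex
\mathrm{RMS}_{\mathrm{ref}}^{2}
  = \frac{1}{N}\sum_{i=1}^{N} \lVert p_i - c_{\mathrm{ref}} \rVert^{2}
  = \frac{1}{N}\sum_{i=1}^{N} \lVert p_i - c \rVert^{2}
    + \lVert c - c_{\mathrm{ref}} \rVert^{2}
```

The cross term vanishes because the deviations p_i - c sum to zero. Under that assumption, the red and blue maps equal each channel's own-centroid blur combined in quadrature with its centroid shift from green.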

Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def rms_map_rgb(self, num_grid=32, depth=None):
    """Compute per-field-position RMS spot radius for R, G, B wavelengths.

    The RMS spot radius is a standard measure of geometrical image quality.
    For each field position in a ``num_grid × num_grid`` grid, this method
    traces ``SPP_PSF`` rays per wavelength and computes the root-mean-square
    distance of valid ray hits from a **common** reference centroid.

    The reference centroid is the green-channel centroid.  Using a common
    reference means the returned RMS values include *lateral chromatic
    aberration* (the shift between R/G/B centroids), making the map useful
    as a polychromatic image-quality metric.

    Algorithm:
        1. Call ``rms_map(wvln=green)`` to get the green RMS map **and**
           the green centroid.
        2. Call ``rms_map(wvln=red, center=green_centroid)`` and
           ``rms_map(wvln=blue, center=green_centroid)`` to measure R/B
           blur relative to the green reference.
        3. Stack as ``[R, G, B]``.

    Args:
        num_grid (int): Spatial resolution of the field sampling grid.
            Defaults to 32.
        depth (float): Object distance in mm. When ``None`` (default),
            falls back to ``self.obj_depth``.

    Returns:
        torch.Tensor: RMS spot error map with shape ``[3, num_grid, num_grid]``
            (channels ordered R, G, B). Units are mm (same as sensor
            coordinates).
    """
    depth = self.obj_depth if depth is None else depth
    # Green first to obtain the shared reference centroid
    rms_g, green_centroid = self.rms_map(
        num_grid=num_grid, depth=depth, wvln=self.wvln_rgb[1]
    )

    # Red and blue relative to the green centroid
    rms_r, _ = self.rms_map(
        num_grid=num_grid, depth=depth, wvln=self.wvln_rgb[0], center=green_centroid
    )
    rms_b, _ = self.rms_map(
        num_grid=num_grid, depth=depth, wvln=self.wvln_rgb[2], center=green_centroid
    )

    return torch.stack([rms_r, rms_g, rms_b], dim=0)

calc_distortion_radial

calc_distortion_radial(num_points=GEO_GRID, wvln=None, plane='meridional', ray_aiming=True)

Compute fractional distortion at evenly spaced field angles in the chosen plane (meridional by default).

Distortion is defined as (h_actual - h_ideal) / h_ideal, where h_ideal = f * tan(theta) (rectilinear projection) and h_actual is the chief-ray image height on the sensor. A positive value means pincushion distortion; negative means barrel distortion.

This is the computational counterpart to draw_spot_radial: it samples num_points field angles uniformly from 0 to self.rfov and returns both the sampled angles and the corresponding distortion values, making it easy to pair with other radial evaluation functions.

Algorithm
  1. Derive rfov_deg from self.rfov (radians → degrees).
  2. Sample num_points field angles uniformly in [0, rfov_deg]. The on-axis sample (0°) is replaced by a tiny positive angle to avoid 0/0.
  3. Compute h_ideal = foclen * tan(angle) for each sample.
  4. Trace the chief ray (via calc_chief_ray_infinite) through the full lens to the sensor plane.
  5. Extract h_actual from the appropriate transverse coordinate (x for sagittal, y for meridional).
  6. Return (h_actual - h_ideal) / h_ideal.
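
The distortion definition and the on-axis workaround can be sketched in plain Python. The focal length, field of view, and the f-theta image-height model below are hypothetical stand-ins for a traced chief ray. An f-theta mapping (h = f * theta) falls short of the rectilinear height f * tan(theta), so it should register as barrel (negative) distortion.

```python
import math


def fractional_distortion(h_actual, foclen, angle_deg):
    """(h_actual - h_ideal) / h_ideal with h_ideal = f * tan(theta)."""
    h_ideal = foclen * math.tan(math.radians(angle_deg))
    return (h_actual - h_ideal) / h_ideal


foclen = 50.0    # hypothetical focal length [mm]
rfov_deg = 30.0  # hypothetical half field of view [deg]
num_points = 7

# Uniform samples in [0, rfov_deg]; nudge the on-axis sample off zero
angles = [rfov_deg * i / (num_points - 1) for i in range(num_points)]
angles_compute = list(angles)
angles_compute[0] = min(0.01, angles[1] * 0.01)

# Hypothetical f-theta lens instead of a traced chief ray
h_actual = [foclen * math.radians(a) for a in angles_compute]
distortions = [
    fractional_distortion(h, foclen, a) for h, a in zip(h_actual, angles_compute)
]

for a, d in zip(angles, distortions):
    print(f"{a:5.1f} deg  {100 * d:8.4f} %")
```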

Parameters:

Name | Type | Description | Default
num_points | int | Number of evenly-spaced field-angle samples from on-axis (0°) to full-field (self.rfov). Defaults to GEO_GRID. | GEO_GRID
wvln | float | Wavelength in µm. When None (default), falls back to self.primary_wvln. | None
plane | str | 'meridional' (y-axis) or 'sagittal' (x-axis). Defaults to 'meridional'. | 'meridional'
ray_aiming | bool | If True, the chief ray is aimed to pass through the center of the aperture stop (more accurate for wide-angle lenses). Defaults to True. | True

Returns:

Type Description

tuple[np.ndarray, np.ndarray]:
  • rfov_samples: Field angles in degrees, shape [num_points].
  • distortions: Fractional distortion at each angle, shape [num_points]. Dimensionless (multiply by 100 for percent).

Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def calc_distortion_radial(
    self,
    num_points=GEO_GRID,
    wvln=None,
    plane="meridional",
    ray_aiming=True,
):
    """Compute fractional distortion at evenly spaced field angles in the chosen plane.

    Distortion is defined as ``(h_actual - h_ideal) / h_ideal``, where
    ``h_ideal = f * tan(theta)`` (rectilinear projection) and ``h_actual``
    is the chief-ray image height on the sensor.  A positive value means
    pincushion distortion; negative means barrel distortion.

    This is the computational counterpart to ``draw_spot_radial``: it
    samples ``num_points`` field angles uniformly from 0 to ``self.rfov``
    and returns both the sampled angles and the corresponding distortion
    values, making it easy to pair with other radial evaluation functions.

    Algorithm:
        1. Derive ``rfov_deg`` from ``self.rfov`` (radians → degrees).
        2. Sample ``num_points`` field angles uniformly in
           ``[0, rfov_deg]``.  The on-axis sample (0°) is replaced by a
           tiny positive angle to avoid 0/0.
        3. Compute ``h_ideal = foclen * tan(angle)`` for each sample.
        4. Trace the chief ray (via ``calc_chief_ray_infinite``) through the
           full lens to the sensor plane.
        5. Extract ``h_actual`` from the appropriate transverse coordinate
           (x for sagittal, y for meridional).
        6. Return ``(h_actual - h_ideal) / h_ideal``.

    Args:
        num_points (int): Number of evenly-spaced field-angle samples from
            on-axis (0°) to full-field (``self.rfov``).
            Defaults to ``GEO_GRID``.
        wvln (float): Wavelength in µm. When ``None`` (default), falls
            back to ``self.primary_wvln``.
        plane (str): ``'meridional'`` (y-axis) or ``'sagittal'`` (x-axis).
            Defaults to ``'meridional'``.
        ray_aiming (bool): If ``True``, the chief ray is aimed to pass
            through the center of the aperture stop (more accurate for
            wide-angle lenses). Defaults to ``True``.

    Returns:
        tuple[np.ndarray, np.ndarray]:
            - **rfov_samples**: Field angles in degrees, shape ``[num_points]``.
            - **distortions**: Fractional distortion at each angle, shape
              ``[num_points]``.  Dimensionless (multiply by 100 for
              percent).
    """
    wvln = self.primary_wvln if wvln is None else wvln
    rfov_deg = self.rfov * 180 / torch.pi

    # Sample field angles uniformly from 0 to rfov_deg.
    # For the on-axis point (FOV=0), distortion is 0/0.  We compute it at a
    # tiny positive angle to obtain the correct limit, which may be non-zero
    # when the sensor is not at the paraxial focus.
    rfov_samples = torch.linspace(0, rfov_deg, num_points)
    rfov_compute = rfov_samples.clone()
    if rfov_compute[0] == 0:
        rfov_compute[0] = min(0.01, rfov_samples[1].item() * 0.01)

    # Ideal image height: h_ideal = f * tan(theta)
    eff_foclen = float(self.foclen)
    ideal_imgh = eff_foclen * np.tan(rfov_compute.numpy() * np.pi / 180)

    # Trace chief rays to the sensor plane
    chief_ray_o, chief_ray_d = self.calc_chief_ray_infinite(
        rfov=rfov_compute, wvln=wvln, plane=plane, ray_aiming=ray_aiming
    )
    ray = Ray(chief_ray_o, chief_ray_d, wvln=wvln, device=self.device)
    ray, _ = self.trace(ray)
    t = (self.d_sensor - ray.o[..., 2]) / ray.d[..., 2]

    # Actual image height from the appropriate transverse coordinate
    if plane == "sagittal":
        actual_imgh = (ray.o[..., 0] + ray.d[..., 0] * t).abs()
    elif plane == "meridional":
        actual_imgh = (ray.o[..., 1] + ray.d[..., 1] * t).abs()
    else:
        raise ValueError(f"Invalid plane: {plane}")

    actual_imgh = actual_imgh.cpu().numpy()

    # Fractional distortion, with safe handling of the on-axis singularity
    ideal_imgh = np.asarray(ideal_imgh)
    mask = np.abs(ideal_imgh) < EPSILON
    distortions = np.where(
        mask, 0.0, (actual_imgh - ideal_imgh) / np.where(mask, 1.0, ideal_imgh)
    )

    return rfov_samples.numpy(), distortions

draw_distortion_radial

draw_distortion_radial(save_name=None, num_points=GEO_GRID, wvln=None, plane='meridional', ray_aiming=True, show=False)

Draw distortion-vs-field-angle curve in Zemax style.

Produces a plot with field angle on the y-axis and percent distortion on the x-axis, matching the layout convention used in Zemax OpticStudio. Useful for quick visual assessment of barrel / pincushion distortion.

Algorithm
  1. Call calc_distortion_radial to obtain field angles and fractional distortion values.
  2. Convert distortion to percent and plot.

Parameters:

Name | Type | Description | Default
save_name | str | None | File path for the output PNG. If None, auto-generates './{plane}_distortion_inf.png'. | None
num_points | int | Number of field-angle samples. Defaults to GEO_GRID. | GEO_GRID
wvln | float | Wavelength in µm. When None (default), falls back to self.primary_wvln. | None
plane | str | 'meridional' or 'sagittal'. Defaults to 'meridional'. | 'meridional'
ray_aiming | bool | Whether to use ray aiming for chief-ray computation. Defaults to True. | True
show | bool | If True, display interactively. Defaults to False. | False
Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def draw_distortion_radial(
    self,
    save_name=None,
    num_points=GEO_GRID,
    wvln=None,
    plane="meridional",
    ray_aiming=True,
    show=False,
):
    """Draw distortion-vs-field-angle curve in Zemax style.

    Produces a plot with field angle on the y-axis and percent distortion
    on the x-axis, matching the layout convention used in Zemax OpticStudio.
    Useful for quick visual assessment of barrel / pincushion distortion.

    Algorithm:
        1. Call ``calc_distortion_radial`` to obtain field angles and
           fractional distortion values.
        2. Convert distortion to percent and plot.

    Args:
        save_name (str | None): File path for the output PNG.  If ``None``,
            auto-generates ``'./{plane}_distortion_inf.png'``.
        num_points (int): Number of field-angle samples.
            Defaults to ``GEO_GRID``.
        wvln (float): Wavelength in µm. When ``None`` (default), falls
            back to ``self.primary_wvln``.
        plane (str): ``'meridional'`` or ``'sagittal'``.
            Defaults to ``'meridional'``.
        ray_aiming (bool): Whether to use ray aiming for chief-ray
            computation. Defaults to ``True``.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    wvln = self.primary_wvln if wvln is None else wvln
    rfov_deg = self.rfov * 180 / torch.pi

    # Calculate distortion at evenly-spaced field angles
    rfov_samples, distortions = self.calc_distortion_radial(
        num_points=num_points, wvln=wvln, plane=plane, ray_aiming=ray_aiming
    )

    # Convert to percentage and handle NaN
    values = np.nan_to_num(distortions * 100, nan=0.0).tolist()

    # Create figure
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_title(f"{plane} Surface Distortion")

    # Draw distortion curve
    ax.plot(values, rfov_samples, linestyle="-", color="g", linewidth=1.5)

    # Draw reference line (vertical line)
    ax.axvline(x=0, color="k", linestyle="-", linewidth=0.8)

    # Set grid
    ax.grid(True, color="gray", linestyle="-", linewidth=0.5, alpha=1)

    # Dynamically adjust x-axis range
    value = max(abs(v) for v in values)
    margin = value * 0.2  # 20% margin
    x_min, x_max = -max(0.2, value + margin), max(0.2, value + margin)

    # Set ticks
    x_ticks = np.linspace(-value, value, 3)
    y_ticks = np.linspace(0, rfov_deg, 3)

    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)

    # Format tick labels
    x_labels = [f"{x:.1f}%" for x in x_ticks]
    y_labels = [f"{y:.1f}" for y in y_ticks]

    ax.set_xticklabels(x_labels)
    ax.set_yticklabels(y_labels)

    # Set axis labels
    ax.set_xlabel("Distortion (%)")
    ax.set_ylabel("Field of View (degrees)")

    # Set axis range
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(0, rfov_deg)

    if show:
        plt.show()
    else:
        if save_name is None:
            save_name = f"./{plane}_distortion_inf.png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

calc_distortion_map

calc_distortion_map(num_grid=16, depth=None, wvln=None)

Compute a 2-D distortion grid mapping ideal to actual image positions.

For each cell in a num_grid × num_grid field grid, rays are traced to the sensor and their centroid is computed. The centroid is then normalized to [-1, 1] sensor coordinates, producing a map that shows how each ideal image point is displaced by lens distortion.

This map can be used with torch.nn.functional.grid_sample to warp or unwarp rendered images.

Parameters:

Name | Type | Description | Default
num_grid | int | Grid resolution along each axis. Defaults to 16. | 16
depth | float | Object distance in mm. When None (default), falls back to self.obj_depth. | None
wvln | float | Wavelength in µm. When None (default), falls back to self.primary_wvln. | None

Returns:

Type Description

torch.Tensor: Distortion grid with shape [num_grid, num_grid, 2]. Each entry (dx, dy) is in normalized sensor coordinates [-1, 1], representing the actual centroid position for the corresponding ideal grid position.
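
The normalization step can be sketched without ray tracing. The example below assumes a hypothetical 36 × 24 mm sensor and a made-up radial distortion model r' = r (1 + k r²) in place of traced centroids. It works directly in upright image coordinates, so the sign flip is omitted. Only the per-axis division by the sensor half-extent mirrors the description above.

```python
import numpy as np

sensor_w, sensor_h = 36.0, 24.0  # hypothetical sensor size [mm]
num_grid = 5

# Ideal (undistorted) image positions on the sensor, in mm
xs = np.linspace(-sensor_w / 2, sensor_w / 2, num_grid)
ys = np.linspace(-sensor_h / 2, sensor_h / 2, num_grid)
x_ideal, y_ideal = np.meshgrid(xs, ys, indexing="xy")

# Hypothetical barrel distortion (k < 0) standing in for traced centroids
k = -1e-4
r2 = x_ideal**2 + y_ideal**2
x_actual = x_ideal * (1 + k * r2)
y_actual = y_ideal * (1 + k * r2)

# Normalize each axis by its own half-extent so a non-square sensor maps to [-1, 1]
grid = np.stack((x_actual / (sensor_w / 2), y_actual / (sensor_h / 2)), axis=-1)

print(grid.shape)    # (num_grid, num_grid, 2)
print(grid[-1, -1])  # corner pulled inward by the barrel term
```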

Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def calc_distortion_map(self, num_grid=16, depth=None, wvln=None):
    """Compute a 2-D distortion grid mapping ideal to actual image positions.

    For each cell in a ``num_grid × num_grid`` field grid, rays are traced
    to the sensor and their centroid is computed.  The centroid is then
    normalized to ``[-1, 1]`` sensor coordinates, producing a map that
    shows how each ideal image point is displaced by lens distortion.

    This map can be used with ``torch.nn.functional.grid_sample`` to warp
    or unwarp rendered images.

    Args:
        num_grid (int): Grid resolution along each axis. Defaults to 16.
        depth (float): Object distance in mm. When ``None`` (default),
            falls back to ``self.obj_depth``.
        wvln (float): Wavelength in µm. When ``None`` (default), falls
            back to ``self.primary_wvln``.

    Returns:
        torch.Tensor: Distortion grid with shape ``[num_grid, num_grid, 2]``.
            Each entry ``(dx, dy)`` is in normalized sensor coordinates
            ``[-1, 1]``, representing the actual centroid position for the
            corresponding ideal grid position.
    """
    wvln = self.primary_wvln if wvln is None else wvln
    depth = self.obj_depth if depth is None else depth
    # Sample and trace rays, shape (grid_size, grid_size, num_rays, 3)
    ray = self.sample_grid_rays(depth=depth, num_grid=num_grid, wvln=wvln, uniform_fov=False)
    ray = self.trace2sensor(ray)

    # Calculate centroid of the rays, shape (grid_size, grid_size, 2).
    # Normalize each axis by its own half-extent so non-square sensors
    # map correctly to [-1, 1]: x by sensor_size[0] (width, W),
    # y by sensor_size[1] (height, H).  Sign is flipped on both axes to
    # undo image inversion, matching ``distortion_center``.
    sensor_w, sensor_h = self.sensor_size
    ray_xy = -ray.centroid()[..., :2]
    x_dist = ray_xy[..., 0] / (sensor_w / 2)
    y_dist = ray_xy[..., 1] / (sensor_h / 2)
    distortion_grid = torch.stack((x_dist, y_dist), dim=-1)
    return distortion_grid

calc_inv_distortion_map

calc_inv_distortion_map(num_grid=16, depth=None, wvln=None)

Compute a grid for applying lens distortion with grid_sample.

For each point on the distorted sensor grid, backward rays are traced through the lens to the target object-depth plane. The traced object intersections are converted to normalized ideal image coordinates. Passing this grid to torch.nn.functional.grid_sample samples an undistorted image and produces a distorted image.

Parameters:

Name | Type | Description | Default
num_grid | int or tuple | Grid resolution. If a tuple is supplied, it is interpreted as (grid_w, grid_h). | 16
depth | float | Object distance in mm. When None (default), falls back to self.obj_depth. | None
wvln | float | Wavelength in µm. When None (default), falls back to self.primary_wvln. | None

Returns:

Type Description

torch.Tensor: Inverse distortion grid with shape [grid_h, grid_w, 2] in grid_sample coordinates.
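
Because num_grid is given as (grid_w, grid_h) but the result is laid out as [grid_h, grid_w, 2], a small NumPy sketch of the sensor sampling grid may help. The function name sensor_grid and the sensor size are hypothetical. The descending linspace ranges and the "xy" indexing follow the source listing.

```python
import numpy as np


def sensor_grid(sensor_size, num_grid):
    """Physical sensor sample positions, shape [grid_h, grid_w, 2]."""
    if isinstance(num_grid, int):
        num_grid = (num_grid, num_grid)
    grid_w, grid_h = num_grid
    sensor_w, sensor_h = sensor_size
    # "xy" indexing puts x along columns and y along rows
    x, y = np.meshgrid(
        np.linspace(sensor_w / 2, -sensor_w / 2, grid_w),
        np.linspace(sensor_h / 2, -sensor_h / 2, grid_h),
        indexing="xy",
    )
    return np.stack((x, y), axis=-1)


grid = sensor_grid((36.0, 24.0), (6, 4))  # hypothetical 36 x 24 mm sensor
print(grid.shape)  # rows follow grid_h, columns follow grid_w
print(grid[0, 0], grid[-1, -1])
```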

Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def calc_inv_distortion_map(self, num_grid=16, depth=None, wvln=None):
    """Compute a grid for applying lens distortion with ``grid_sample``.

    For each point on the distorted sensor grid, backward rays are traced
    through the lens to the target object-depth plane. The traced object
    intersections are converted to normalized ideal image coordinates.
    Passing this grid to ``torch.nn.functional.grid_sample`` samples an
    undistorted image and produces a distorted image.

    Args:
        num_grid (int or tuple): Grid resolution. If a tuple is supplied,
            it is interpreted as ``(grid_w, grid_h)``.
        depth (float): Object distance in mm. When ``None`` (default),
            falls back to ``self.obj_depth``.
        wvln (float): Wavelength in µm. When ``None`` (default), falls
            back to ``self.primary_wvln``.

    Returns:
        torch.Tensor: Inverse distortion grid with shape
        ``[grid_h, grid_w, 2]`` in ``grid_sample`` coordinates.
    """
    wvln = self.primary_wvln if wvln is None else wvln
    depth = self.obj_depth if depth is None else depth
    if isinstance(num_grid, int):
        num_grid = (num_grid, num_grid)

    grid_w, grid_h = num_grid
    sensor_w, sensor_h = self.sensor_size
    device = self.device

    # Convert grid_sample output coordinates to physical sensor positions.
    # Existing distortion maps use -sensor_centroid as image coordinates.
    x, y = torch.meshgrid(
        torch.linspace(sensor_w / 2, -sensor_w / 2, grid_w, device=device),
        torch.linspace(sensor_h / 2, -sensor_h / 2, grid_h, device=device),
        indexing="xy",
    )
    z = torch.full_like(x, self.d_sensor.item())

    pupilz, pupilr = self.get_exit_pupil()
    ray_o2 = self.sample_circle(r=pupilr, z=pupilz, shape=(grid_h, grid_w, SPP_CALC))
    ray_o = torch.stack((x, y, z), dim=-1).unsqueeze(2).repeat(1, 1, SPP_CALC, 1)
    ray = Ray(ray_o, ray_o2 - ray_o, wvln, device=device)

    ray = self.trace2obj(ray)
    ray = ray.prop_to(depth)
    point_obj = ray.centroid()[..., :2]

    scale = self.calc_scale(depth)
    x_ideal = point_obj[..., 0] / (scale * sensor_w / 2)
    y_ideal = point_obj[..., 1] / (scale * sensor_h / 2)
    inv_distortion_grid = torch.stack((x_ideal, y_ideal), dim=-1)
    return inv_distortion_grid

distortion_center

distortion_center(points)

Compute the distorted image centroid for arbitrary normalized object points.

Given object points in normalized coordinates, this method converts them to physical object-space positions, traces rays from each point through the lens, and returns the ray centroid on the sensor in normalized [-1, 1] coordinates. This is the inverse mapping needed for distortion correction (unwarping).

Algorithm
  1. Convert normalized (x, y) ∈ [-1, 1] to physical object-space positions using self.calc_scale(depth) and self.sensor_size.
  2. self.sample_from_points() generates rays from each point.
  3. self.trace2sensor() propagates rays.
  4. Compute centroid and normalize back to [-1, 1].
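
The coordinate conversions in steps 1 and 4 can be sketched in plain Python. The example replaces ray tracing with an idealized, distortion-free lens whose image is simply the inverted object scaled by 1 / scale. The scale value, sensor size, and both helper names are hypothetical. For such a lens the normalized input should come back unchanged.

```python
def normalized_to_object(x_norm, y_norm, scale, sensor_size):
    """Step 1: normalized field position -> physical object-space position [mm]."""
    sensor_w, sensor_h = sensor_size
    return x_norm * scale * sensor_w / 2, y_norm * scale * sensor_h / 2


def ideal_centroid_to_normalized(x_obj, y_obj, scale, sensor_size):
    """Step 4 for an idealized lens: inverted, demagnified image -> [-1, 1]."""
    sensor_w, sensor_h = sensor_size
    # Assumption: a distortion-free lens images the point to -obj / scale
    x_img, y_img = -x_obj / scale, -y_obj / scale
    # Negate to undo the image inversion, then divide by the half-extent
    return -x_img / (sensor_w / 2), -y_img / (sensor_h / 2)


scale = 100.0               # hypothetical object-to-image size ratio
sensor_size = (36.0, 24.0)  # hypothetical sensor size [mm]

x_obj, y_obj = normalized_to_object(0.5, -0.25, scale, sensor_size)
x_n, y_n = ideal_centroid_to_normalized(x_obj, y_obj, scale, sensor_size)
print(x_obj, y_obj, x_n, y_n)
```

With a real lens the returned coordinates differ from the input, and that difference is the distortion.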

Parameters:

Name | Type | Description | Default
points | Tensor | Normalized point source positions with shape [N, 3] or [..., 3]. x, y ∈ [-1, 1] encode the field position; z ∈ (-∞, 0] is the object depth in mm. | required

Returns:

Type Description

torch.Tensor: Normalized distortion centroid positions with shape [N, 2] or [..., 2]. x, y ∈ [-1, 1].

Source code in deeplens-src/deeplens/geolens_pkg/eval.py
def distortion_center(self, points):
    """Compute the distorted image centroid for arbitrary normalized object points.

    Given object points in normalized coordinates, this method converts them
    to physical object-space positions, traces rays from each point through
    the lens, and returns the ray centroid on the sensor in normalized
    ``[-1, 1]`` coordinates.  This is the inverse mapping needed for
    distortion correction (unwarping).

    Algorithm:
        1. Convert normalized ``(x, y)`` ∈ [-1, 1] to physical object-space
           positions using ``self.calc_scale(depth)`` and ``self.sensor_size``.
        2. ``self.sample_from_points()`` generates rays from each point.
        3. ``self.trace2sensor()`` propagates rays.
        4. Compute centroid and normalize back to ``[-1, 1]``.

    Args:
        points (torch.Tensor): Normalized point source positions with shape
            ``[N, 3]`` or ``[..., 3]``.  ``x, y`` ∈ [-1, 1] encode the
            field position; ``z`` ∈ (-∞, 0] is the object depth in mm.

    Returns:
        torch.Tensor: Normalized distortion centroid positions with shape
            ``[N, 2]`` or ``[..., 2]``.  ``x, y`` ∈ [-1, 1].
    """
    sensor_w, sensor_h = self.sensor_size

    # Convert normalized points to object space coordinates
    depth = points[..., 2]
    scale = self.calc_scale(depth)
    points_obj_x = points[..., 0] * scale * sensor_w / 2
    points_obj_y = points[..., 1] * scale * sensor_h / 2
    points_obj = torch.stack([points_obj_x, points_obj_y, depth], dim=-1)

    # Sample rays and trace to sensor
    ray = self.sample_from_points(points=points_obj)
    ray = self.trace2sensor(ray)

    # Calculate centroid and normalize to [-1, 1]
    ray_center = -ray.centroid()  # shape [..., 3]
    distortion_center_x = ray_center[..., 0] / (sensor_w / 2)
    distortion_center_y = ray_center[..., 1] / (sensor_h / 2)
    distortion_center = torch.stack((distortion_center_x, distortion_center_y), dim=-1)
    return distortion_center

draw_distortion_map

draw_distortion_map(save_name=None, num_grid=16, depth=None, wvln=None, show=False)

Draw a scatter plot of the distortion grid.

Visualizes the output of calc_distortion_map() as a scatter plot on [-1, 1] normalized sensor coordinates. An undistorted lens would show a perfect rectilinear grid; deviations reveal barrel or pincushion distortion.

Parameters:

Name | Type | Description | Default
save_name | str | None | File path for the output PNG. If None, auto-generates './distortion_{depth_str}.png', where depth_str is 'inf' or '{-depth}mm'. | None
num_grid | int | Grid resolution per axis. Defaults to 16. | 16
depth | float | Object distance in mm. When None (default), falls back to self.obj_depth. | None
wvln | float | Wavelength in µm. When None (default), falls back to self.primary_wvln. | None
show | bool | If True, display interactively. Defaults to False. | False
Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def draw_distortion_map(
    self, save_name=None, num_grid=16, depth=None, wvln=None, show=False
):
    """Draw a scatter plot of the distortion grid.

    Visualizes the output of ``calc_distortion_map()`` as a scatter plot on
    ``[-1, 1]`` normalized sensor coordinates.  An undistorted lens would
    show a perfect rectilinear grid; deviations reveal barrel or pincushion
    distortion.

    Args:
        save_name (str | None): File path for the output PNG.  If ``None``,
            auto-generates ``'./distortion_{depth_str}.png'``, where
            ``depth_str`` is ``'inf'`` or ``'{-depth}mm'``.
        num_grid (int): Grid resolution per axis. Defaults to 16.
        depth (float): Object distance in mm. When ``None`` (default),
            falls back to ``self.obj_depth``.
        wvln (float): Wavelength in µm. When ``None`` (default), falls
            back to ``self.primary_wvln``.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    wvln = self.primary_wvln if wvln is None else wvln
    depth = self.obj_depth if depth is None else depth
    # Ray tracing to calculate distortion map
    distortion_grid = self.calc_distortion_map(num_grid=num_grid, depth=depth, wvln=wvln)
    # Scale axes so the plot preserves the physical sensor aspect ratio:
    # longer side → ±1, shorter side → ±(shorter/longer).
    sensor_w, sensor_h = self.sensor_size
    max_half = max(sensor_w, sensor_h) / 2
    aspect_x = (sensor_w / 2) / max_half
    aspect_y = (sensor_h / 2) / max_half
    x1 = distortion_grid[..., 0].cpu().numpy() * aspect_x
    y1 = distortion_grid[..., 1].cpu().numpy() * aspect_y

    # Draw image
    fig, ax = plt.subplots()
    ax.set_axisbelow(True)
    ax.grid(True)
    ax.scatter(x1, y1, s=20, zorder=3)
    ax.axis("scaled")

    # Grid lines based on grid_size, scaled per axis so the overlay
    # matches the data extent (±aspect_x × ±aspect_y).
    ax.set_xticks(np.linspace(-aspect_x, aspect_x, num_grid))
    ax.set_yticks(np.linspace(-aspect_y, aspect_y, num_grid))
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.tick_params(length=0)
    for spine in ax.spines.values():
        spine.set_visible(False)

    if show:
        plt.show()
    else:
        depth_str = "inf" if abs(depth) == float("inf") else f"{-depth}mm"
        if save_name is None:
            save_name = f"./distortion_{depth_str}.png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

mtf

mtf(fov, wvln=None)

Compute the geometric MTF at a single field position.

The Modulation Transfer Function describes how well the lens preserves contrast as a function of spatial frequency. MTF approaches 1 as the spatial frequency approaches 0 (full contrast) and generally falls toward 0 at higher frequencies; the returned frequency axis extends up to about the sensor Nyquist frequency (0.5 / pixel_size cycles/mm).

This implementation uses the geometric (ray-based) approach:
  1. Compute the PSF at the given field position via self.psf().
  2. Convert PSF → MTF via psf2mtf() (project onto tangential and sagittal axes, then take the magnitude of the 1-D FFT).

Tangential MTF captures resolution in the meridional (radial) direction; sagittal MTF captures resolution perpendicular to it. The difference between the two indicates astigmatism.

Parameters:

Name | Type | Description | Default
fov | float | Field angle in radians. Internally mapped to a normalized point [0, -fov/rfov, self.obj_depth]. | required
wvln | float | Wavelength in µm. When None (default), falls back to self.primary_wvln. | None

Returns:

Type Description

tuple[np.ndarray, np.ndarray, np.ndarray]:
  • freq: Spatial frequency axis in cycles/mm (positive frequencies only, excluding DC).
  • mtf_tan: Tangential (meridional) MTF values, normalized so that MTF → 1 at low frequency.
  • mtf_sag: Sagittal MTF values, same normalization.

Source code in deeplens-src/deeplens/geolens_pkg/eval.py
def mtf(self, fov, wvln=None):
    """Compute the geometric MTF at a single field position.

    The *Modulation Transfer Function* describes how well the lens preserves
    contrast as a function of spatial frequency.  MTF approaches 1 as the
    spatial frequency approaches 0 (full contrast) and generally falls
    toward 0 at higher frequencies; the returned frequency axis extends up
    to about the sensor Nyquist frequency (``0.5 / pixel_size`` cycles/mm).

    This implementation uses the *geometric* (ray-based) approach:
        1. Compute the PSF at the given field position via ``self.psf()``.
        2. Convert PSF → MTF via ``psf2mtf()`` (project onto tangential and
           sagittal axes, then take the magnitude of the 1-D FFT).

    Tangential MTF captures resolution in the meridional (radial) direction;
    sagittal MTF captures resolution perpendicular to it.  The difference
    between the two indicates astigmatism.

    Args:
        fov (float): Field angle in radians.  Internally mapped to a
            normalized point ``[0, -fov/rfov, self.obj_depth]``.
        wvln (float): Wavelength in µm. When ``None`` (default), falls
            back to ``self.primary_wvln``.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]:
            - **freq**: Spatial frequency axis in cycles/mm (positive
              frequencies only, excluding DC).
            - **mtf_tan**: Tangential (meridional) MTF values, normalized
              so that MTF → 1 at low frequency.
            - **mtf_sag**: Sagittal MTF values, same normalization.
    """
    wvln = self.primary_wvln if wvln is None else wvln
    point = [0, -fov / self.rfov, self.obj_depth]
    psf = self.psf(points=point, recenter=True, wvln=wvln)
    freq, mtf_tan, mtf_sag = self.psf2mtf(psf, pixel_size=self.pixel_size)
    return freq, mtf_tan, mtf_sag

psf2mtf staticmethod

psf2mtf(psf, pixel_size)

Convert a 2-D point-spread function to tangential and sagittal MTF curves.

The MTF is the magnitude of the optical transfer function (OTF), which is the Fourier transform of the PSF. For separable 1-D analysis:
  1. Integrate the PSF along the x-axis → tangential line-spread function (LSF_tan).
  2. Integrate the PSF along the y-axis → sagittal LSF_sag.
  3. Take |FFT(LSF)| and normalize by the DC component so that MTF(0) = 1.

Only positive frequencies (excluding DC) are returned, following the convention used in Zemax MTF plots.

Parameters:

Name | Type | Description | Default
psf | Tensor | ndarray | 2-D PSF with shape [H, W]. The array's y-axis (rows) corresponds to the tangential (meridional) direction; x-axis (columns) to the sagittal direction. | required
pixel_size | float | Pixel pitch in mm. Determines the frequency axis scaling: Nyquist = 0.5 / pixel_size cycles/mm. | required

Returns:

Type Description

tuple[np.ndarray, np.ndarray, np.ndarray]:
  • freq: Spatial frequency in cycles/mm (positive, excluding DC). Length is W // 2; the axis is derived from the sagittal LSF, so a square PSF (H = W) is assumed.
  • mtf_tan: Tangential MTF, normalized to 1 at DC.
  • mtf_sag: Sagittal MTF, normalized to 1 at DC.

References
  • https://en.wikipedia.org/wiki/Optical_transfer_function
  • Edmund Optics: Introduction to Modulation Transfer Function.
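
The three steps can be checked on a case with a known answer. The sketch below is a simplified NumPy restatement of them (psf_to_mtf is a hypothetical name, not the library method) applied to a synthetic Gaussian PSF. The pixel pitch and blur widths are made-up values. A Gaussian LSF with standard deviation σ has the analytic MTF exp(-2 π² σ² f²), which gives a reference curve.

```python
import numpy as np


def psf_to_mtf(psf, pixel_size):
    """Tangential/sagittal MTF from a square 2-D PSF (rows = y, columns = x)."""
    lsf_tan = psf.sum(axis=1)  # integrate over x -> function of y
    lsf_sag = psf.sum(axis=0)  # integrate over y -> function of x
    mtf_tan = np.abs(np.fft.rfft(lsf_tan))
    mtf_sag = np.abs(np.fft.rfft(lsf_sag))
    mtf_tan = mtf_tan / mtf_tan[0]  # normalize by the DC component
    mtf_sag = mtf_sag / mtf_sag[0]
    freq = np.fft.rfftfreq(lsf_sag.size, d=pixel_size)
    keep = freq > 0  # drop DC
    return freq[keep], mtf_tan[keep], mtf_sag[keep]


pixel_size = 0.002           # hypothetical pixel pitch [mm]
sigma_y, sigma_x = 3.0, 1.5  # hypothetical Gaussian blur [pixels]
n = 64
coords = np.arange(n) - n // 2
yy, xx = np.meshgrid(coords, coords, indexing="ij")
psf = np.exp(-0.5 * ((yy / sigma_y) ** 2 + (xx / sigma_x) ** 2))

freq, mtf_tan, mtf_sag = psf_to_mtf(psf, pixel_size)

# Analytic MTF of a Gaussian LSF whose standard deviation is sigma_y pixels
mtf_tan_ref = np.exp(-2 * (np.pi * sigma_y * pixel_size * freq) ** 2)
print(freq[-1], np.abs(mtf_tan - mtf_tan_ref).max())
```

The wider blur along y lowers the tangential curve relative to the sagittal one, which is the signature the two curves are meant to expose.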
Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@staticmethod
def psf2mtf(psf, pixel_size):
    """Convert a 2-D point-spread function to tangential and sagittal MTF curves.

    The MTF is the magnitude of the optical transfer function (OTF), which
    is the Fourier transform of the PSF.  For separable 1-D analysis:
        1. Integrate the PSF along the x-axis → *tangential* line-spread
           function (LSF_tan).
        2. Integrate the PSF along the y-axis → *sagittal* LSF_sag.
        3. Take ``|FFT(LSF)|`` and normalize by the DC component so that
           MTF(0) = 1.

    Only positive frequencies (excluding DC) are returned, following the
    convention used in Zemax MTF plots.

    Args:
        psf (torch.Tensor | np.ndarray): 2-D PSF with shape ``[H, W]``.
            The array's y-axis (rows) corresponds to the **tangential**
            (meridional) direction; x-axis (columns) to the **sagittal**
            direction.
        pixel_size (float): Pixel pitch in mm.  Determines the frequency
            axis scaling: ``Nyquist = 0.5 / pixel_size`` cycles/mm.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]:
            - **freq**: Spatial frequency in cycles/mm (positive, excluding
              DC).  Length is ``W // 2``; the axis is derived from the
              sagittal LSF, so a square PSF (``H == W``) is assumed.
            - **mtf_tan**: Tangential MTF, normalized to 1 at DC.
            - **mtf_sag**: Sagittal MTF, normalized to 1 at DC.

    References:
        - https://en.wikipedia.org/wiki/Optical_transfer_function
        - Edmund Optics: Introduction to Modulation Transfer Function.
    """
    # Convert to numpy (supports torch tensors and numpy arrays)
    try:
        psf_np = psf.detach().cpu().numpy()
    except AttributeError:
        try:
            psf_np = psf.cpu().numpy()
        except AttributeError:
            psf_np = np.asarray(psf)

    # Compute line spread functions (integrate PSF over orthogonal axes)
    # y-axis corresponds to tangential; x-axis corresponds to sagittal
    lsf_sagittal = psf_np.sum(axis=0)  # function of x
    lsf_tangential = psf_np.sum(axis=1)  # function of y

    # One-sided spectra (for real inputs)
    mtf_sag = np.abs(np.fft.rfft(lsf_sagittal))
    mtf_tan = np.abs(np.fft.rfft(lsf_tangential))

    # Normalize by DC to ensure MTF(0) == 1
    dc_sag = mtf_sag[0] if mtf_sag.size > 0 else 1.0
    dc_tan = mtf_tan[0] if mtf_tan.size > 0 else 1.0
    if dc_sag != 0:
        mtf_sag = mtf_sag / dc_sag
    if dc_tan != 0:
        mtf_tan = mtf_tan / dc_tan

    # Frequency axis in cycles/mm (one-sided)
    fx = np.fft.rfftfreq(lsf_sagittal.size, d=pixel_size)
    freq = fx
    positive_freq_idx = freq > 0

    return (
        freq[positive_freq_idx],
        mtf_tan[positive_freq_idx],
        mtf_sag[positive_freq_idx],
    )
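
The PSF-to-MTF conversion above can be tried outside the lens class. The following is a minimal NumPy sketch, assuming a square PSF; the function name `psf_to_mtf`, the synthetic Gaussian PSF, and the 4 µm pixel pitch are illustrative choices, not part of DeepLens.

```python
import numpy as np


def psf_to_mtf(psf, pixel_size):
    """Return (freq, mtf_tan, mtf_sag) for a square 2-D PSF (illustrative helper)."""
    psf = np.asarray(psf, dtype=np.float64)
    lsf_sag = psf.sum(axis=0)  # integrate over rows    -> function of x
    lsf_tan = psf.sum(axis=1)  # integrate over columns -> function of y
    mtf_sag = np.abs(np.fft.rfft(lsf_sag))
    mtf_tan = np.abs(np.fft.rfft(lsf_tan))
    mtf_sag = mtf_sag / mtf_sag[0]  # normalize so that MTF(0) == 1
    mtf_tan = mtf_tan / mtf_tan[0]
    freq = np.fft.rfftfreq(psf.shape[1], d=pixel_size)  # cycles/mm
    return freq[1:], mtf_tan[1:], mtf_sag[1:]  # drop the DC bin


# Synthetic anisotropic Gaussian PSF: blur is wider along y (rows) than x.
ks, pixel_size = 64, 0.004  # hypothetical kernel size and pixel pitch [mm]
y, x = np.mgrid[:ks, :ks] - (ks - 1) / 2
psf = np.exp(-(x**2) / (2 * 1.0**2) - (y**2) / (2 * 2.0**2))

freq, mtf_tan, mtf_sag = psf_to_mtf(psf, pixel_size)
print(freq[-1], mtf_tan[0], mtf_sag[0])
```

Because the synthetic blur is wider along the rows, the tangential curve falls off faster than the sagittal one, and the last frequency sample coincides with the Nyquist frequency `0.5 / pixel_size`.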

draw_mtf

draw_mtf(save_name='./lens_mtf.png', relative_fov_list=[0.0, 0.7, 1.0], depth_list=None, psf_ks=128, show=False)

Draw a grid of tangential MTF curves for multiple depths and field positions.

Produces a len(depth_list) × len(relative_fov_list) subplot grid. Each subplot shows the tangential MTF for R, G, B wavelengths plus a vertical line at the sensor Nyquist frequency (0.5 / pixel_size cycles/mm).

Algorithm per subplot
  1. Compute the RGB PSF via self.psf_rgb() at the specified (depth, relative_fov) with kernel size psf_ks.
  2. For each wavelength channel, call psf2mtf() to obtain the tangential MTF curve.
  3. Plot frequency vs MTF with RGB coloring.
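
How `psf_ks` and `pixel_size` set the frequency axis can be sketched as follows. This is a stand-alone NumPy illustration; `mtf_frequency_axis` is a hypothetical helper, and the 4 µm pixel pitch is an assumed value.

```python
import numpy as np


def mtf_frequency_axis(psf_ks, pixel_size):
    """Positive frequency samples, Nyquist frequency and sample spacing [cycles/mm]."""
    freq = np.fft.rfftfreq(psf_ks, d=pixel_size)[1:]  # exclude DC
    nyquist = 0.5 / pixel_size
    spacing = 1.0 / (psf_ks * pixel_size)
    return freq, nyquist, spacing


# Hypothetical sensor with 4 um pixels and the default 128-pixel PSF kernel.
freq, nyquist, spacing = mtf_frequency_axis(psf_ks=128, pixel_size=0.004)
print(len(freq), nyquist, spacing)
```

A larger `psf_ks` gives finer frequency spacing but does not extend the axis, which always ends at the Nyquist frequency for an even kernel size.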

Parameters:

Name Type Description Default
save_name str

File path for the output PNG. Defaults to './lens_mtf.png'.

'./lens_mtf.png'
relative_fov_list list[float]

Relative field positions in [0, 1], where 0 = on-axis and 1 = full field. Defaults to [0.0, 0.7, 1.0].

[0.0, 0.7, 1.0]
depth_list list[float]

Object distances in mm. float('inf') is automatically replaced by self.obj_depth. When None (default), uses [self.obj_depth].

None
psf_ks int

PSF kernel size in pixels (controls frequency resolution of the resulting MTF). Defaults to 128.

128
show bool

If True, display interactively. Defaults to False.

False
Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def draw_mtf(
    self,
    save_name="./lens_mtf.png",
    relative_fov_list=[0.0, 0.7, 1.0],
    depth_list=None,
    psf_ks=128,
    show=False,
):
    """Draw a grid of tangential MTF curves for multiple depths and field positions.

    Produces a ``len(depth_list) × len(relative_fov_list)`` subplot grid.
    Each subplot shows the tangential MTF for R, G, B wavelengths plus a
    vertical line at the sensor Nyquist frequency
    (``0.5 / pixel_size`` cycles/mm).

    Algorithm per subplot:
        1. Compute the RGB PSF via ``self.psf_rgb()`` at the specified
           ``(depth, relative_fov)`` with kernel size ``psf_ks``.
        2. For each wavelength channel, call ``psf2mtf()`` to obtain the
           tangential MTF curve.
        3. Plot frequency vs MTF with RGB coloring.

    Args:
        save_name (str): File path for the output PNG.
            Defaults to ``'./lens_mtf.png'``.
        relative_fov_list (list[float]): Relative field positions in
            ``[0, 1]``, where 0 = on-axis and 1 = full field.
            Defaults to ``[0.0, 0.7, 1.0]``.
        depth_list (list[float]): Object distances in mm.
            ``float('inf')`` is automatically replaced by
            ``self.obj_depth``.  When ``None`` (default), uses
            ``[self.obj_depth]``.
        psf_ks (int): PSF kernel size in pixels (controls frequency
            resolution of the resulting MTF). Defaults to 128.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    if depth_list is None:
        depth_list = [self.obj_depth]
    pixel_size = self.pixel_size
    nyquist_freq = 0.5 / pixel_size
    num_fovs = len(relative_fov_list)
    if float("inf") in depth_list:
        depth_list = [self.obj_depth if x == float("inf") else x for x in depth_list]
    num_depths = len(depth_list)

    # Create figure and subplots (num_depths * num_fovs subplots)
    fig, axs = plt.subplots(
        num_depths, num_fovs, figsize=(num_fovs * 3, num_depths * 3), squeeze=False
    )

    # Iterate over depth and field of view
    for depth_idx, depth in enumerate(depth_list):
        for fov_idx, fov_relative in enumerate(relative_fov_list):
            # Calculate rgb PSF
            point = [0, -fov_relative, depth]
            psf_rgb = self.psf_rgb(points=point, ks=psf_ks, recenter=True)

            # Calculate MTF curves for rgb wavelengths
            for wvln_idx, wvln in enumerate(self.wvln_rgb):
                # Calculate MTF curves from PSF
                psf = psf_rgb[wvln_idx]
                freq, mtf_tan, _ = self.psf2mtf(psf, pixel_size)

                # Plot MTF curves
                ax = axs[depth_idx, fov_idx]
                color = RGB_COLORS[wvln_idx % len(RGB_COLORS)]
                wvln_label = RGB_LABELS[wvln_idx % len(RGB_LABELS)]
                wvln_nm = int(wvln * 1000)
                ax.plot(
                    freq,
                    mtf_tan,
                    color=color,
                    label=f"{wvln_label}({wvln_nm}nm)-Tan",
                )

            # Draw Nyquist frequency
            ax.axvline(
                x=nyquist_freq,
                color="k",
                linestyle=":",
                linewidth=1.2,
                label="Nyquist",
            )

            # Set title and label for subplot
            fov_deg = round(fov_relative * self.rfov * 180 / np.pi, 1)
            depth_str = "inf" if depth == float("inf") else f"{depth}"
            ax.set_title(f"FOV: {fov_deg}deg, Depth: {depth_str}mm", fontsize=8)
            ax.set_xlabel("Spatial Frequency [cycles/mm]", fontsize=8)
            ax.set_ylabel("MTF", fontsize=8)
            ax.legend(fontsize=6)
            ax.tick_params(axis="both", which="major", labelsize=7)
            ax.grid(True)
            ax.set_ylim(0, 1.05)

    plt.tight_layout()
    if show:
        plt.show()
    else:
        assert save_name.endswith(".png"), "save_name must end with .png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

vignetting

vignetting(depth=None, num_grid=32, num_rays=512)

Compute the relative-illumination (vignetting) map across the field.

Vignetting measures how much light is lost at each field position due to rays being clipped by lens apertures or barrel edges. It is computed at each grid cell as the fraction of launched rays that remain valid (not vignetted) after tracing to the sensor.

A value of 1.0 means all rays reach the sensor (no vignetting); 0.0 means complete light blockage. Real lenses typically show 1.0 on-axis and fall off toward the field edges due to mechanical vignetting and the cos⁴ illumination law.

Algorithm
  1. self.sample_grid_rays() with uniform_fov=False (uniform image-space sampling) to ensure correct sensor-plane mapping.
  2. self.trace2sensor() propagates rays and marks clipped ones as invalid.
  3. Per-cell throughput = count(valid) / num_rays.
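
Step 3 can be illustrated without a lens model. The sketch below is an assumption-laden toy: the validity mask is drawn at random from a made-up survival probability that falls off with field radius, standing in for the mask that ray tracing would produce. `throughput_map` is a hypothetical helper name.

```python
import numpy as np


def throughput_map(is_valid):
    """Fraction of valid rays per grid cell.

    is_valid: bool array of shape [num_grid, num_grid, num_rays].
    """
    is_valid = np.asarray(is_valid, dtype=bool)
    return is_valid.sum(axis=-1) / is_valid.shape[-1]


rng = np.random.default_rng(0)
num_grid, num_rays = 5, 512
coords = np.linspace(-1.0, 1.0, num_grid)
yy, xx = np.meshgrid(coords, coords, indexing="ij")

# Made-up survival probability: 1.0 on axis, 0.4 in the corners.
p_survive = 1.0 - 0.3 * (xx**2 + yy**2)
is_valid = rng.random((num_grid, num_grid, num_rays)) < p_survive[..., None]

vmap = throughput_map(is_valid)
print(vmap.round(2))
```

With a finite `num_rays` the off-axis cells carry sampling noise, which is why a higher ray count gives a smoother map.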

Parameters:

Name Type Description Default
depth float

Object distance in mm. When None (default), falls back to self.obj_depth.

None
num_grid int

Grid resolution per axis. Defaults to 32.

32
num_rays int

Rays launched per grid cell. Higher values reduce Monte-Carlo noise. Defaults to 512.

512

Returns:

Type Description

torch.Tensor: Vignetting map with shape [num_grid, num_grid], values in [0, 1].

Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def vignetting(self, depth=None, num_grid=32, num_rays=512):
    """Compute the relative-illumination (vignetting) map across the field.

    Vignetting measures how much light is lost at each field position due to
    rays being clipped by lens apertures or barrel edges.  It is computed at
    each grid cell as the fraction of launched rays that remain valid (not
    vignetted) after tracing to the sensor.

    A value of 1.0 means all rays reach the sensor (no vignetting); 0.0
    means complete light blockage.  Real lenses typically show 1.0 on-axis
    and fall off toward the field edges due to mechanical vignetting and the
    cos⁴ illumination law.

    Algorithm:
        1. ``self.sample_grid_rays()`` with ``uniform_fov=False`` (uniform
           image-space sampling) to ensure correct sensor-plane mapping.
        2. ``self.trace2sensor()`` propagates rays and marks clipped ones as
           invalid.
        3. Per-cell throughput = ``count(valid) / num_rays``.

    Args:
        depth (float): Object distance in mm. When ``None`` (default),
            falls back to ``self.obj_depth``.
        num_grid (int): Grid resolution per axis. Defaults to 32.
        num_rays (int): Rays launched per grid cell.  Higher values reduce
            Monte-Carlo noise. Defaults to 512.

    Returns:
        torch.Tensor: Vignetting map with shape ``[num_grid, num_grid]``,
            values in ``[0, 1]``.
    """
    depth = self.obj_depth if depth is None else depth
    # Sample rays in uniform image space (not FOV angles) for correct sensor mapping
    # shape [num_grid, num_grid, num_rays, 3]
    ray = self.sample_grid_rays(
        depth=depth, num_grid=num_grid, num_rays=num_rays, uniform_fov=False
    )

    # Trace rays to sensor
    ray = self.trace2sensor(ray)

    # Calculate vignetting map
    vignetting = ray.is_valid.sum(-1) / (ray.is_valid.shape[-1])
    return vignetting

draw_vignetting

draw_vignetting(filename=None, depth=None, resolution=512, show=False)

Draw the vignetting map as a grayscale image with a colorbar.

Computes the vignetting map via self.vignetting(), bilinearly upsamples it to resolution × resolution, and displays it as a grayscale image where white = no vignetting and black = fully vignetted.

Parameters:

Name Type Description Default
filename str | None

File path for the output PNG. If None, auto-generates './vignetting_{depth}.png'.

None
depth float

Object distance in mm. When None (default), falls back to self.obj_depth.

None
resolution int

Output image size in pixels (square). Defaults to 512.

512
show bool

If True, display interactively. Defaults to False.

False
Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def draw_vignetting(self, filename=None, depth=None, resolution=512, show=False):
    """Draw the vignetting map as a grayscale image with a colorbar.

    Computes the vignetting map via ``self.vignetting()``, bilinearly
    upsamples it to ``resolution × resolution``, and displays it as a
    grayscale image where white = no vignetting and black = fully vignetted.

    Args:
        filename (str | None): File path for the output PNG.  If ``None``,
            auto-generates ``'./vignetting_{depth}.png'``.
        depth (float): Object distance in mm. When ``None`` (default),
            falls back to ``self.obj_depth``.
        resolution (int): Output image size in pixels (square).
            Defaults to 512.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    depth = self.obj_depth if depth is None else depth
    # Calculate vignetting map
    vignetting = self.vignetting(depth=depth)

    # Interpolate vignetting map to desired resolution
    vignetting = F.interpolate(
        vignetting.unsqueeze(0).unsqueeze(0),
        size=(resolution, resolution),
        mode="bilinear",
        align_corners=False,
    ).squeeze()

    fig, ax = plt.subplots()
    ax.set_title("Relative Illumination (Vignetting)")
    im = ax.imshow(vignetting.cpu().numpy(), cmap="gray", vmin=0.0, vmax=1.0)
    fig.colorbar(im, ax=ax, ticks=[0.0, 0.25, 0.5, 0.75, 1.0])

    if show:
        plt.show()
    else:
        if filename is None:
            filename = f"./vignetting_{depth}.png"
        plt.savefig(filename, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

calc_chief_ray_infinite

calc_chief_ray_infinite(rfov, depth=0.0, wvln=None, plane='meridional', num_rays=SPP_CALC, ray_aiming=True)

Compute chief rays for one or more field angles with optional ray aiming.

Computes chief rays for multiple field angles in one vectorized pass and implements ray aiming: a single-pass search that launches a fan of parallel rays toward the entrance pupil and selects the one that passes closest to the aperture-stop center. Ray aiming is essential for accurate distortion measurement in wide-angle or fisheye lenses where the paraxial approximation breaks down.

Algorithm
  1. For on-axis (rfov = 0): chief ray is trivially along the z-axis.
  2. For off-axis angles with ray_aiming=False: the chief ray is aimed at the entrance pupil center (paraxial approximation).
  3. For off-axis angles with ray_aiming=True:
     a. Estimate the object-space y (or x) position from the entrance pupil geometry.
     b. Create a narrow fan of num_rays rays bracketing that estimate (half-width = 5 % of y_distance, but at least 0.05 * pupil_radius).
     c. Trace the fan to the aperture stop.
     d. Pick the ray closest to the optical axis at the stop.
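
The fan search in step 3 can be sketched in one dimension. This is a minimal NumPy illustration, not the DeepLens implementation: `pick_chief_ray` is a hypothetical helper, and `toy_height_at_stop` is a made-up linear stand-in for tracing a ray from its launch height to the aperture stop.

```python
import numpy as np


def pick_chief_ray(y_distance, pupil_radius, height_at_stop, num_rays=101, scale=0.05):
    """Return (launch height of the chief ray, search half-width)."""
    # Half-width: 5 % of y_distance, but never below 0.05 * pupil_radius.
    delta = max(scale * y_distance, 0.05 * pupil_radius)
    # Fan of candidate launch heights bracketing the paraxial estimate.
    candidates = np.linspace(-y_distance - delta, -y_distance + delta, num_rays)
    # "Trace" every candidate to the stop and keep the one nearest the axis.
    stop_heights = np.array([height_at_stop(y0) for y0 in candidates])
    best = int(np.argmin(np.abs(stop_heights)))
    return candidates[best], delta


def toy_height_at_stop(y0):
    # Made-up mapping: the ray launched at y0 = -10.2 crosses the stop center.
    return 0.8 * (y0 + 10.2)


chief_y0, half_width = pick_chief_ray(
    y_distance=10.0, pupil_radius=2.0, height_at_stop=toy_height_at_stop
)
print(chief_y0, half_width)
```

The accuracy of the result is limited by the fan spacing `2 * delta / (num_rays - 1)`, so a denser fan gives a chief ray that passes closer to the stop center.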

Parameters:

Name Type Description Default
rfov float | Tensor

Field angle(s) in degrees. A positive scalar is converted to [0, rfov] (two-element tensor); any other scalar becomes a one-element tensor. A tensor of shape [N] is used directly.

required
depth float | Tensor

Object depth(s) in mm. Defaults to 0.0 (object at the first surface).

0.0
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None
plane str

'sagittal' or 'meridional'. Defaults to 'meridional'.

'meridional'
num_rays int

Size of the search fan for ray aiming. Defaults to SPP_CALC.

SPP_CALC
ray_aiming bool

If True, perform ray aiming (a fan search at the aperture stop) for accurate chief-ray identification. Defaults to True.

True

Returns:

Type Description

tuple[torch.Tensor, torch.Tensor]:
  • chief_ray_o: Origins, shape [N, 3].
  • chief_ray_d: Unit directions, shape [N, 3].

Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def calc_chief_ray_infinite(
    self,
    rfov,
    depth=0.0,
    wvln=None,
    plane="meridional",
    num_rays=SPP_CALC,
    ray_aiming=True,
):
    """Compute chief rays for one or more field angles with optional ray aiming.

    Computes chief rays for multiple field angles in one vectorized pass
    and implements *ray aiming*: a single-pass search that launches a fan
    of parallel rays toward the entrance pupil and selects the one that
    passes closest to the aperture-stop center.  Ray aiming is essential
    for accurate distortion measurement in wide-angle or fisheye lenses
    where the paraxial approximation breaks down.

    Algorithm:
        1. For on-axis (``rfov = 0``): chief ray is trivially along the
           z-axis.
        2. For off-axis angles with ``ray_aiming=False``: the chief ray is
           aimed at the entrance pupil center (paraxial approximation).
        3. For off-axis angles with ``ray_aiming=True``:
           a. Estimate the object-space y (or x) position from the entrance
              pupil geometry.
           b. Create a narrow fan of ``num_rays`` rays bracketing that
              estimate (half-width = 5 % of y_distance, but at least
              ``0.05 * pupil_radius``).
           c. Trace the fan to the aperture stop.
           d. Pick the ray closest to the optical axis at the stop.

    Args:
        rfov (float | torch.Tensor): Field angle(s) in **degrees**.
            A positive scalar is converted to ``[0, rfov]`` (two-element
            tensor); any other scalar becomes a one-element tensor.
            A tensor of shape ``[N]`` is used directly.
        depth (float | torch.Tensor): Object depth(s) in mm.
            Defaults to 0.0 (object at the first surface).
        wvln (float): Wavelength in µm. When ``None`` (default), falls
            back to ``self.primary_wvln``.
        plane (str): ``'sagittal'`` or ``'meridional'``.
            Defaults to ``'meridional'``.
        num_rays (int): Size of the search fan for ray aiming.
            Defaults to ``SPP_CALC``.
        ray_aiming (bool): If ``True``, perform ray aiming (a fan search
            at the aperture stop) for accurate chief-ray identification.
            Defaults to ``True``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - **chief_ray_o**: Origins, shape ``[N, 3]``.
            - **chief_ray_d**: Unit directions, shape ``[N, 3]``.
    """
    wvln = self.primary_wvln if wvln is None else wvln
    if isinstance(rfov, (int, float)):
        if rfov > 0:
            rfov = torch.linspace(0, rfov, 2, device=self.device)
        else:
            rfov = torch.tensor([float(rfov)], device=self.device)
    else:
        rfov = rfov.to(self.device)

    if not isinstance(depth, torch.Tensor):
        depth = torch.tensor(depth, device=self.device).repeat(len(rfov))

    # set chief ray
    chief_ray_o = torch.zeros([len(rfov), 3], device=self.device)
    chief_ray_d = torch.zeros([len(rfov), 3], device=self.device)

    # Convert rfov to radian
    rfov = rfov * torch.pi / 180.0

    if torch.any(rfov == 0):
        chief_ray_o[0, ...] = torch.tensor(
            [0.0, 0.0, depth[0]], device=self.device, dtype=torch.float32
        )
        chief_ray_d[0, ...] = torch.tensor(
            [0.0, 0.0, 1.0], device=self.device, dtype=torch.float32
        )
        if len(rfov) == 1:
            return chief_ray_o, chief_ray_d

    # Extract non-zero rfov entries for processing
    has_zero = torch.any(rfov == 0)
    if has_zero:
        start_idx = 1
        rfovs = rfov[1:]
        depths = depth[1:]
    else:
        start_idx = 0
        rfovs = rfov
        depths = depth

    if self.aper_idx == 0:
        if plane == "sagittal":
            chief_ray_o[start_idx:, ...] = torch.stack(
                [depths * torch.tan(rfovs), torch.zeros_like(rfovs), depths], dim=-1
            )
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.sin(rfovs), torch.zeros_like(rfovs), torch.cos(rfovs)],
                dim=-1,
            )
        else:
            chief_ray_o[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), depths * torch.tan(rfovs), depths], dim=-1
            )
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), torch.sin(rfovs), torch.cos(rfovs)],
                dim=-1,
            )

        return chief_ray_o, chief_ray_d

    # Scale factor
    pupilz, pupilr = self.calc_entrance_pupil()
    y_distance = torch.tan(rfovs) * (abs(depths) + pupilz)

    if ray_aiming:
        scale = 0.05
        min_delta = 0.05 * pupilr  # minimum search range based on pupil radius
        delta = torch.clamp(scale * y_distance, min=min_delta)

    if not ray_aiming:
        if plane == "sagittal":
            chief_ray_o[start_idx:, ...] = torch.stack(
                [-y_distance, torch.zeros_like(rfovs), depths], dim=-1
            )
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.sin(rfovs), torch.zeros_like(rfovs), torch.cos(rfovs)],
                dim=-1,
            )
        else:
            chief_ray_o[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), -y_distance, depths], dim=-1
            )
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), torch.sin(rfovs), torch.cos(rfovs)],
                dim=-1,
            )

    else:
        min_y = -y_distance - delta
        max_y = -y_distance + delta
        t = torch.linspace(0, 1, num_rays, device=min_y.device)
        o1_linspace = min_y.unsqueeze(-1) + t * (max_y - min_y).unsqueeze(-1)

        o1 = torch.zeros([len(rfovs), num_rays, 3], device=self.device)
        o1[:, :, 2] = depths[0]

        o2_linspace = -delta.unsqueeze(-1) + t * (2 * delta).unsqueeze(-1)

        o2 = torch.zeros([len(rfovs), num_rays, 3], device=self.device)
        o2[:, :, 2] = pupilz

        if plane == "sagittal":
            o1[:, :, 0] = o1_linspace
            o2[:, :, 0] = o2_linspace
        else:
            o1[:, :, 1] = o1_linspace
            o2[:, :, 1] = o2_linspace

        # Trace until the aperture
        ray = Ray(o1, o2 - o1, wvln=wvln, device=self.device)
        inc_ray = ray.clone()
        surf_range = range(0, self.aper_idx + 1)
        ray, _ = self.trace(ray, surf_range=surf_range)

        # Look for the ray that is closest to the optical axis
        if plane == "sagittal":
            _, center_idx = torch.min(torch.abs(ray.o[..., 0]), dim=1)
            chief_ray_o[start_idx:, ...] = inc_ray.o[
                torch.arange(len(rfovs)), center_idx.long(), ...
            ]
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.sin(rfovs), torch.zeros_like(rfovs), torch.cos(rfovs)],
                dim=-1,
            )
        else:
            _, center_idx = torch.min(torch.abs(ray.o[..., 1]), dim=1)
            chief_ray_o[start_idx:, ...] = inc_ray.o[
                torch.arange(len(rfovs)), center_idx.long(), ...
            ]
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), torch.sin(rfovs), torch.cos(rfovs)],
                dim=-1,
            )

    return chief_ray_o, chief_ray_d

analysis_rendering

analysis_rendering(img_org, save_name=None, depth=None, spp=SPP_RENDER, unwarp=False, method='ray_tracing', show=False)

Render a test image through the lens and report PSNR / SSIM.

Simulates what the sensor would capture if the given image were placed at the specified object distance. The rendering accounts for all geometric aberrations (blur, distortion, vignetting, chromatic effects). Optionally applies an inverse distortion warp (unwarp) and reports quality metrics for both the raw and unwarped renderings.

Algorithm
  1. Convert img_org to a [1, 3, H, W] float tensor and temporarily set the sensor resolution to match.
  2. Call self.render() with the chosen method (ray tracing or PSF convolution).
  3. Compute PSNR and SSIM between the original and rendered images.
  4. If unwarp=True, apply self.unwarp() to correct geometric distortion and report metrics again.
  5. Restore the original sensor resolution.
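
Steps 1 and 3 can be illustrated with plain NumPy. The function above relies on scikit-image for its metrics; the sketch below instead writes out the standard PSNR definition, and normalizes the input by dtype, which is an assumption made here for clarity. The helper names `to_unit_float` and `psnr` are hypothetical.

```python
import numpy as np


def to_unit_float(img):
    """Map a uint8 [0, 255] image to float [0, 1]; pass float images through."""
    img = np.asarray(img)
    if img.dtype == np.uint8:
        return img.astype(np.float64) / 255.0
    return img.astype(np.float64)


def psnr(reference, rendered, data_range=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range^2 / MSE)."""
    mse = np.mean((reference - rendered) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(data_range**2 / mse)


# A white uint8 test image and a "rendering" that is uniformly 0.1 darker.
reference = to_unit_float(np.full((8, 8, 3), 255, dtype=np.uint8))
rendered = np.clip(reference - 0.1, 0.0, 1.0)
value = psnr(reference, rendered)
print(round(value, 3))
```

A uniform error of 0.1 gives an MSE of 0.01 and hence a PSNR of about 20 dB, which is a convenient sanity check before comparing real renderings.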

Parameters:

Name Type Description Default
img_org ndarray | Tensor

Source image with shape [H, W, 3], either uint8 [0, 255] or float [0, 1].

required
save_name str | None

Path prefix for saved PNGs. If not None, saves '{save_name}.png' and (if unwarped) '{save_name}_unwarped.png'. Defaults to None.

None
depth float

Object distance in mm. When None (default), falls back to self.obj_depth.

None
spp int

Samples (rays) per pixel for rendering. Defaults to SPP_RENDER.

SPP_RENDER
unwarp bool

If True, apply distortion correction after rendering. Defaults to False.

False
method str

Rendering backend — 'ray_tracing' or 'psf_conv'. Defaults to 'ray_tracing'.

'ray_tracing'
show bool

If True, display the result with matplotlib. Defaults to False.

False

Returns:

Type Description

torch.Tensor: Rendered (and optionally unwarped) image with shape [1, 3, H, W], float values in [0, 1].

Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def analysis_rendering(
    self,
    img_org,
    save_name=None,
    depth=None,
    spp=SPP_RENDER,
    unwarp=False,
    method="ray_tracing",
    show=False,
):
    """Render a test image through the lens and report PSNR / SSIM.

    Simulates what the sensor would capture if the given image were placed
    at the specified object distance.  The rendering accounts for all
    geometric aberrations (blur, distortion, vignetting, chromatic effects).
    Optionally applies an inverse distortion warp (``unwarp``) and reports
    quality metrics for both the raw and unwarped renderings.

    Algorithm:
        1. Convert ``img_org`` to a ``[1, 3, H, W]`` float tensor and
           temporarily set the sensor resolution to match.
        2. Call ``self.render()`` with the chosen method (ray tracing or PSF
           convolution).
        3. Compute PSNR and SSIM between the original and rendered images.
        4. If ``unwarp=True``, apply ``self.unwarp()`` to correct geometric
           distortion and report metrics again.
        5. Restore the original sensor resolution.

    Args:
        img_org (np.ndarray | torch.Tensor): Source image with shape
            ``[H, W, 3]``, either uint8 ``[0, 255]`` or float ``[0, 1]``.
        save_name (str | None): Path prefix for saved PNGs.  If not
            ``None``, saves ``'{save_name}.png'`` and (if unwarped)
            ``'{save_name}_unwarped.png'``. Defaults to ``None``.
        depth (float): Object distance in mm. When ``None`` (default),
            falls back to ``self.obj_depth``.
        spp (int): Samples (rays) per pixel for rendering.
            Defaults to ``SPP_RENDER``.
        unwarp (bool): If ``True``, apply distortion correction after
            rendering. Defaults to ``False``.
        method (str): Rendering backend — ``'ray_tracing'`` or
            ``'psf_conv'``. Defaults to ``'ray_tracing'``.
        show (bool): If ``True``, display the result with matplotlib.
            Defaults to ``False``.

    Returns:
        torch.Tensor: Rendered (and optionally unwarped) image with shape
            ``[1, 3, H, W]``, float values in ``[0, 1]``.
    """
    from skimage.metrics import peak_signal_noise_ratio, structural_similarity
    from torchvision.utils import save_image
    depth = self.obj_depth if depth is None else depth
    # Change sensor resolution to match the image
    sensor_res_original = self.sensor_res
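    if isinstance(img_org, np.ndarray):
        img = torch.from_numpy(img_org).permute(2, 0, 1).unsqueeze(0).float()
        # uint8 input is in [0, 255]; float input is already in [0, 1]
        if img_org.dtype == np.uint8:
            img = img / 255.0
    elif torch.is_tensor(img_org):
        img = img_org.permute(2, 0, 1).unsqueeze(0).float()
        if img.max() > 1.0:
            img = img / 255.0
    else:
        raise TypeError("img_org must be a np.ndarray or torch.Tensor")
    img = img.to(self.device)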
    self.set_sensor_res(sensor_res=img.shape[-2:])

    # Image rendering
    img_render = self.render(img, depth=depth, method=method, spp=spp)

    # Compute PSNR and SSIM
    img_np = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    render_np = img_render.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().detach().numpy()
    render_psnr = round(peak_signal_noise_ratio(img_np, render_np, data_range=1.0), 3)
    render_ssim = round(structural_similarity(img_np, render_np, channel_axis=2, data_range=1.0), 4)
    print(f"Rendered image: PSNR={render_psnr:.3f}, SSIM={render_ssim:.4f}")

    # Save image
    if save_name is not None:
        save_image(img_render, f"{save_name}.png")

    # Unwarp to correct geometry distortion
    if unwarp:
        img_render = self.unwarp(img_render, depth)

        # Compute PSNR and SSIM
        render_np = img_render.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().detach().numpy()
        render_psnr = round(peak_signal_noise_ratio(img_np, render_np, data_range=1.0), 3)
        render_ssim = round(structural_similarity(img_np, render_np, channel_axis=2, data_range=1.0), 4)
        print(
            f"Rendered image (unwarped): PSNR={render_psnr:.3f}, SSIM={render_ssim:.4f}"
        )

        if save_name is not None:
            save_image(img_render, f"{save_name}_unwarped.png")

    # Change the sensor resolution back
    self.set_sensor_res(sensor_res=sensor_res_original)

    # Show image
    if show:
        plt.imshow(img_render.cpu().squeeze(0).permute(1, 2, 0).numpy())
        plt.title("Rendered image")
        plt.axis("off")
        plt.show()
        plt.close()

    return img_render

analysis_spot

analysis_spot(num_field=3, depth=float('inf'))

Compute RMS and geometric spot radii at multiple field positions for RGB.

Traces rays at num_field evenly-spaced field positions along the meridional direction for three wavelengths (R, G, B), and computes polychromatic RMS and geometric spot radii referenced to the combined centroid across all wavelengths (matching Zemax's default "RMS Spot Radius w.r.t. Centroid").

This provides a quick polychromatic spot-size summary used for design comparisons and printed to stdout during analysis().

Algorithm (per field point)
  1. Trace R, G, B rays through the lens to the sensor.
  2. Pool all valid ray intercepts (across all three wavelengths) and compute one combined centroid c.
  3. RMS = sqrt(mean(||xy - c||²)) over all pooled rays, giving a single polychromatic RMS that includes lateral chromatic aberration.
  4. radius = max(||xy - c||) over all pooled rays.
  5. Convert from mm to μm (× 1000).
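
The pooled-centroid computation for a single field point can be sketched with NumPy. The helper name `polychromatic_spot` and the hand-picked ray intercepts are illustrative assumptions; real intercepts would come from ray tracing.

```python
import numpy as np


def polychromatic_spot(xy_per_wvln, valid_per_wvln):
    """Return (rms_um, radius_um) about the combined centroid of all valid rays.

    xy_per_wvln: list of [num_rays, 2] arrays of sensor intercepts in mm.
    valid_per_wvln: list of [num_rays] boolean masks.
    """
    xy = np.concatenate(xy_per_wvln, axis=0)
    valid = np.concatenate(valid_per_wvln, axis=0).astype(bool)
    pts = xy[valid]                                # pool valid rays of all wavelengths
    centroid = pts.mean(axis=0)                    # one combined centroid
    dist = np.linalg.norm(pts - centroid, axis=1)
    rms_um = np.sqrt(np.mean(dist**2)) * 1000.0    # mm -> um
    radius_um = dist.max() * 1000.0                # mm -> um
    return rms_um, radius_um


# Hand-picked intercepts [mm] for R, G, B; the far-off blue ray is marked invalid.
xy_r = np.array([[0.001, 0.0], [-0.001, 0.0]])
xy_g = np.array([[0.0, 0.001], [0.0, -0.001]])
xy_b = np.array([[0.0, 0.0], [5.0, 5.0]])
valid = [np.array([True, True]), np.array([True, True]), np.array([True, False])]

rms_um, radius_um = polychromatic_spot([xy_r, xy_g, xy_b], valid)
print(round(rms_um, 4), round(radius_um, 4))
```

Masking before pooling matters: the clipped ray at (5, 5) mm would otherwise dominate both the centroid and the radii.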

Parameters:

Name Type Description Default
num_field int

Number of field positions sampled from on-axis to full-field. Defaults to 3.

3
depth float

Object distance in mm. Use float('inf') for collimated light. Defaults to float('inf').

float('inf')

Returns:

Type Description

dict[str, dict[str, float]]: Spot analysis results keyed by field position string (e.g., 'fov0.0', 'fov0.5', 'fov1.0'). Each value is a dict with:
  • 'rms': Polychromatic RMS spot radius in μm.
  • 'radius': Polychromatic geometric spot radius in μm.

Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def analysis_spot(self, num_field=3, depth=float("inf")):
    """Compute RMS and geometric spot radii at multiple field positions for RGB.

    Traces rays at ``num_field`` evenly-spaced field positions along the
    meridional direction for three wavelengths (R, G, B), and computes
    polychromatic RMS and geometric spot radii referenced to the
    **combined centroid across all wavelengths** (matching Zemax's
    default "RMS Spot Radius w.r.t. Centroid").

    This provides a quick polychromatic spot-size summary used for design
    comparisons and printed to stdout during ``analysis()``.

    Algorithm (per field point):
        1. Trace R, G, B rays through the lens to the sensor.
        2. Pool all valid ray intercepts (across all three wavelengths)
           and compute one combined centroid ``c``.
        3. RMS = sqrt(mean(||xy - c||²)) over all pooled rays — a single
           polychromatic RMS that includes lateral chromatic aberration.
        4. radius = max(||xy - c||) over all pooled rays.
        5. Convert from mm to μm (× 1000).

    Args:
        num_field (int): Number of field positions sampled from on-axis
            to full-field. Defaults to 3.
        depth (float): Object distance in mm.  Use ``float('inf')`` for
            collimated light. Defaults to ``float('inf')``.

    Returns:
        dict[str, dict[str, float]]: Spot analysis results keyed by field
            position string (e.g., ``'fov0.0'``, ``'fov0.5'``, ``'fov1.0'``).
            Each value is a dict with:
                - ``'rms'``: Polychromatic RMS spot radius in μm.
                - ``'radius'``: Polychromatic geometric spot radius in μm.
    """
    # Trace each wavelength and pool rays across wavelengths per field
    xy_list = []
    valid_list = []
    for wvln in self.wvln_rgb:
        ray = self.sample_radial_rays(
            num_field=num_field, depth=depth, num_rays=SPP_PSF, wvln=wvln
        )
        ray = self.trace2sensor(ray)
        xy_list.append(ray.o[..., :2])
        valid_list.append(ray.is_valid)

    # Pool over wavelengths, shape [num_field, 3*num_rays, 2] and [num_field, 3*num_rays]
    xy_all = torch.cat(xy_list, dim=-2)
    valid_all = torch.cat(valid_list, dim=-1)

    # Combined polychromatic centroid per field, shape [num_field, 1, 2]
    valid_mask = valid_all.unsqueeze(-1)
    center = (xy_all * valid_mask).sum(-2) / (
        valid_all.sum(-1, keepdim=True) + EPSILON
    )
    center = center.unsqueeze(-2)

    # Squared distance to combined centroid, shape [num_field, 3*num_rays]
    dist_sq = ((xy_all - center) ** 2).sum(-1)

    # Polychromatic RMS spot radius per field, shape [num_field]
    spot_rms = (
        (dist_sq * valid_all).sum(-1) / (valid_all.sum(-1) + EPSILON)
    ).sqrt()
    # Geometric spot radius (max distance among valid rays)
    dist_masked = torch.where(
        valid_all > 0, dist_sq, torch.full_like(dist_sq, -1.0)
    )
    spot_radius = dist_masked.max(dim=-1).values.clamp(min=0.0).sqrt()

    # Convert mm → μm
    avg_rms_radius_um = spot_rms * 1000.0
    avg_geo_radius_um = spot_radius * 1000.0

    # Print results
    print(f"Ray spot analysis results for depth {depth}:")
    print(
        f"RMS radius: FoV (0.0) {avg_rms_radius_um[0]:.3f} um, FoV (0.5) {avg_rms_radius_um[num_field // 2]:.3f} um, FoV (1.0) {avg_rms_radius_um[-1]:.3f} um"
    )
    print(
        f"Geo radius: FoV (0.0) {avg_geo_radius_um[0]:.3f} um, FoV (0.5) {avg_geo_radius_um[num_field // 2]:.3f} um, FoV (1.0) {avg_geo_radius_um[-1]:.3f} um"
    )

    # Save to dict
    rms_results = {}
    fov_ls = torch.linspace(0, 1, num_field)
    for i in range(num_field):
        fov = round(fov_ls[i].item(), 2)
        rms_results[f"fov{fov}"] = {
            "rms": round(avg_rms_radius_um[i].item(), 4),
            "radius": round(avg_geo_radius_um[i].item(), 4),
        }

    return rms_results
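
The pooled-centroid statistics in the listing above can be sketched in plain Python. This is a minimal illustration, not the library code: `spot_stats`, the sample points, and the `eps` value are assumptions made for the example, and the real method works on batched tensors per field.

```python
import math

def spot_stats(points, valid, eps=1e-12):
    """Return (rms, geo_radius) of the valid (x, y) points about their centroid."""
    n_valid = sum(valid)
    # Centroid over valid points only (eps guards the all-invalid case).
    cx = sum(x for (x, _), v in zip(points, valid) if v) / (n_valid + eps)
    cy = sum(y for (_, y), v in zip(points, valid) if v) / (n_valid + eps)
    # Squared distance of each valid point to the centroid.
    dist_sq = [(x - cx) ** 2 + (y - cy) ** 2
               for (x, y), v in zip(points, valid) if v]
    rms = math.sqrt(sum(dist_sq) / (n_valid + eps))
    geo = math.sqrt(max(dist_sq, default=0.0))  # farthest valid ray
    return rms, geo

# Sensor hits pooled across wavelengths (mm); the last ray is vignetted.
pooled = [(0.001, 0.0), (-0.001, 0.0), (0.0, 0.002), (0.0, -0.002), (9.0, 9.0)]
mask = [1, 1, 1, 1, 0]
rms_mm, geo_mm = spot_stats(pooled, mask)
rms_um, geo_um = rms_mm * 1000.0, geo_mm * 1000.0  # mm -> um
```

The invalid ray at (9, 9) affects neither the centroid nor either radius, which is the purpose of the validity mask.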

analysis

analysis(save_name='./lens', depth=float('inf'), full_eval=False, render=False, render_unwarp=False, lens_title=None, show=False)

Run a comprehensive optical analysis pipeline for the lens.

This is the main entry point for evaluating a lens design. It chains multiple evaluation steps in order, saving all plots with a common save_name prefix.

Execution flow
  1. Always: draw the lens layout (draw_layout) and compute polychromatic spot RMS/radius (analysis_spot).
  2. If full_eval=True: additionally generate:
     • Spot diagram (draw_spot_radial).
     • MTF grid (draw_mtf).
     • Distortion curve (draw_distortion_radial).
     • Vignetting map (draw_vignetting).
  3. If render=True: render a test chart image through the lens and report PSNR/SSIM (analysis_rendering).

Parameters:

Name Type Description Default
save_name str

Path prefix for all output files. Each plot appends a suffix (e.g., '_spot.png', '_mtf.png'). Defaults to './lens'.

'./lens'
depth float

Object distance in mm. float('inf') is replaced by self.obj_depth for the MTF, vignetting, and rendering steps. Defaults to float('inf').

float('inf')
full_eval bool

If True, run all evaluation plots. If False, only layout + spot RMS. Defaults to False.

False
render bool

If True, render a test image through the lens. Defaults to False.

False
render_unwarp bool

If True (and render=True), also produce an unwarped rendering. Defaults to False.

False
lens_title str | None

Title string for the layout plot. Defaults to None.

None
show bool

If True, display all plots interactively. Defaults to False.

False
Source code in deeplens-src/deeplens/geolens_pkg/eval.py
@torch.no_grad()
def analysis(
    self,
    save_name="./lens",
    depth=float("inf"),
    full_eval=False,
    render=False,
    render_unwarp=False,
    lens_title=None,
    show=False,
):
    """Run a comprehensive optical analysis pipeline for the lens.

    This is the main entry point for evaluating a lens design.  It chains
    multiple evaluation steps in order, saving all plots with a common
    ``save_name`` prefix.

    Execution flow:
        1. **Always**: draw the lens layout (``draw_layout``) and compute
           polychromatic spot RMS/radius (``analysis_spot``).
        2. **If** ``full_eval=True``: additionally generate:
           - Spot diagram (``draw_spot_radial``).
           - MTF grid (``draw_mtf``).
           - Distortion curve (``draw_distortion_radial``).
           - Vignetting map (``draw_vignetting``).
        3. **If** ``render=True``: render a test chart image through the
           lens and report PSNR/SSIM (``analysis_rendering``).

    Args:
        save_name (str): Path prefix for all output files.  Each plot
            appends a suffix (e.g., ``'_spot.png'``, ``'_mtf.png'``).
            Defaults to ``'./lens'``.
        depth (float): Object distance in mm.  ``float('inf')`` is replaced
            by ``self.obj_depth`` for the MTF, vignetting, and rendering
            steps. Defaults to ``float('inf')``.
        full_eval (bool): If ``True``, run all evaluation plots.  If
            ``False``, only layout + spot RMS. Defaults to ``False``.
        render (bool): If ``True``, render a test image through the lens.
            Defaults to ``False``.
        render_unwarp (bool): If ``True`` (and ``render=True``), also
            produce an unwarped rendering. Defaults to ``False``.
        lens_title (str | None): Title string for the layout plot.
            Defaults to ``None``.
        show (bool): If ``True``, display all plots interactively.
            Defaults to ``False``.
    """
    # Draw lens layout and ray path
    self.draw_layout(
        filename=f"{save_name}.png",
        lens_title=lens_title,
        depth=depth,
        show=show,
    )

    # Calculate RMS error
    self.analysis_spot(depth=depth)

    # Comprehensive optical evaluation
    if full_eval:
        # Draw spot diagram
        self.draw_spot_radial(
            save_name=f"{save_name}_spot.png",
            depth=depth,
            show=show,
        )

        # Draw MTF
        if depth == float("inf"):
            self.draw_mtf(
                depth_list=[self.obj_depth],
                save_name=f"{save_name}_mtf.png",
                show=show,
            )
        else:
            self.draw_mtf(
                depth_list=[depth],
                save_name=f"{save_name}_mtf.png",
                show=show,
            )

        # Draw distortion
        self.draw_distortion_radial(
            save_name=f"{save_name}_distortion.png",
            show=show,
        )

        # Draw vignetting
        eval_depth = self.obj_depth if depth == float("inf") else depth
        self.draw_vignetting(
            filename=f"{save_name}_vignetting.png",
            depth=eval_depth,
            show=show,
        )

    # Render an image, compute PSNR and SSIM
    if render:
        depth = self.obj_depth if depth == float("inf") else depth
        img_org = Image.open("./datasets/charts/NBS_1963_1k.png").convert("RGB")
        img_org = np.array(img_org)
        self.analysis_rendering(
            img_org,
            depth=depth,
            spp=SPP_RENDER,
            unwarp=render_unwarp,
            save_name=f"{save_name}_render",
            show=show,
        )
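
The file naming implied by the f-strings in the listing above can be summarised with a small helper. `analysis_outputs` is a hypothetical function written for this illustration and is not part of DeepLens. It only mirrors the path strings passed to each drawing call.

```python
def analysis_outputs(save_name="./lens", full_eval=False, render=False):
    """List the output paths/prefixes that analysis() passes to each step."""
    outputs = [f"{save_name}.png"]  # layout, always drawn
    if full_eval:
        outputs += [
            f"{save_name}_spot.png",
            f"{save_name}_mtf.png",
            f"{save_name}_distortion.png",
            f"{save_name}_vignetting.png",
        ]
    if render:
        # analysis_rendering receives this as a prefix; the final file
        # names it writes are not shown in this section.
        outputs.append(f"{save_name}_render")
    return outputs

quick = analysis_outputs("results/lens")
full = analysis_outputs("results/lens", full_eval=True, render=True)
```

With the defaults only the layout image is listed. Enabling both flags adds the four evaluation plots and the render prefix.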

deeplens.geolens_pkg.optim.GeoLensOptim

Mixin providing differentiable optimisation for GeoLens.

Implements gradient-based lens design using PyTorch autograd:

  • Loss functions – RMS spot error, focus, surface regularity, gap constraints, material validity.
  • Constraint initialisation – edge-thickness and self-intersection guards.
  • Optimizer helpers – parameter groups with per-type learning rates and cosine annealing schedules.
  • High-level optimize() – curriculum-learning training loop.

This class is not instantiated directly; it is mixed into GeoLens.

References

Xinge Yang et al., "Curriculum learning for ab initio deep learned refractive optics," Nature Communications 2024.

init_constraints

init_constraints(constraint_params=None)

Initialize constraints for the lens design.

Parameters:

Name Type Description Default
constraint_params dict

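Constraint parameters. Currently unused: the limits are hard-coded presets selected by self.r_sensor (values below 12.0 select the cellphone preset).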

None
Source code in deeplens-src/deeplens/geolens_pkg/optim.py
def init_constraints(self, constraint_params=None):
    """Initialize constraints for the lens design.

    Args:
        constraint_params (dict): Constraint parameters.
    """
    # In the future, we want to use constraint_params to set the constraints.
    if constraint_params is None:
        constraint_params = {}

    if self.r_sensor < 12.0:
        self.is_cellphone = True

        self.air_edge_min = 0.05
        self.air_edge_max = 5.0
        self.air_center_min = 0.05
        self.air_center_max = 5.0

        self.thick_edge_min = 0.25
        self.thick_edge_max = 5.0
        self.thick_center_min = 0.25
        self.thick_center_max = 5.0

        self.bfl_min = 0.8
        self.bfl_max = 5.0

        self.ttl_min = 0.0
        self.ttl_max = 50.0

        # Surface shape constraints
        self.sag2diam_max = 0.5
        self.diam2thick_max = 15.0
        self.tmax2tmin_max = 5.0
        self.surf_angle_max = 45.0  # deg

        # Ray angle constraints
        self.chief_ray_angle_max = 45.0  # deg
        self.bend_angle_max = 30.0  # deg

        # Distortion constraint
        self.distortion_max = 0.10  # 10 % relative distortion

    else:
        self.is_cellphone = False

        self.air_edge_min = 0.1
        self.air_edge_max = 100.0  # float("inf")
        self.air_center_min = 0.1
        self.air_center_max = 100.0  # float("inf")

        self.thick_edge_min = 1.0
        self.thick_edge_max = 20.0
        self.thick_center_min = 2.0
        self.thick_center_max = 20.0

        self.bfl_min = 5.0
        self.bfl_max = 100.0  # float("inf")

        self.ttl_min = 0.0  # disabled by default
        self.ttl_max = 300.0  # float("inf")

        # Surface shape constraints
        self.sag2diam_max = 0.5
        self.diam2thick_max = 20.0
        self.tmax2tmin_max = 10.0
        self.surf_angle_max = 45.0  # deg

        # Ray angle constraints
        self.chief_ray_angle_max = 45.0  # deg
        self.bend_angle_max = 30.0  # deg

        # Distortion constraint
        self.distortion_max = 0.02  # 2 % relative distortion

    # Propagate bend angle limit onto every surface so refract() reads it.
    for s in self.surfaces:
        s.bend_angle_max = self.bend_angle_max

loss_reg

loss_reg(w_focus=1.0, w_cra=1.0, w_ray_bend=1.0, w_clearance=1.0, w_envelope=1.0, w_profile=1.0)

Compute combined regularization loss for lens design.

Aggregates multiple constraint losses to keep the lens physically valid during gradient-based optimisation.

Parameters:

Name Type Description Default
w_focus float

Weight for focus loss. Currently has no effect because the focus term is commented out in the implementation. Defaults to 1.0.

1.0
w_cra float

Weight for chief ray angle loss. Defaults to 1.0.

1.0
w_ray_bend float

Weight for per-surface bend penalty. Defaults to 1.0.

1.0
w_clearance float

Weight for the clearance penalty (min air gap, min thickness, min BFL, min TTL). Defaults to 1.0.

1.0
w_envelope float

Weight for the envelope penalty (max air gap, max thickness, max BFL, max TTL). Defaults to 1.0.

1.0
w_profile float

Weight for per-surface profile feasibility (sag, slope). Defaults to 1.0.

1.0

Returns:

Name Type Description
tuple

(loss_reg, loss_dict) where: - loss_reg (Tensor): Scalar combined regularization loss. - loss_dict (dict): Per-component loss values for logging.

Source code in deeplens-src/deeplens/geolens_pkg/optim.py
def loss_reg(
    self,
    w_focus=1.0,
    w_cra=1.0,
    w_ray_bend=1.0,
    w_clearance=1.0,
    w_envelope=1.0,
    w_profile=1.0,
):
    """Compute combined regularization loss for lens design.

    Aggregates multiple constraint losses to keep the lens physically valid
    during gradient-based optimisation.

    Args:
        w_focus (float, optional): Weight for focus loss. Currently has no
            effect because the focus term is commented out. Defaults to 1.0.
        w_cra (float, optional): Weight for chief ray angle loss. Defaults to 1.0.
        w_ray_bend (float, optional): Weight for per-surface bend penalty. Defaults to 1.0.
        w_clearance (float, optional): Weight for the clearance penalty
            (min air gap, min thickness, min BFL, min TTL). Defaults to 1.0.
        w_envelope (float, optional): Weight for the envelope penalty
            (max air gap, max thickness, max BFL, max TTL). Defaults to 1.0.
        w_profile (float, optional): Weight for per-surface profile
            feasibility (sag, slope). Defaults to 1.0.

    Returns:
        tuple: (loss_reg, loss_dict) where:
            - loss_reg (Tensor): Scalar combined regularization loss.
            - loss_dict (dict): Per-component loss values for logging.
    """
    # Loss functions for regularization
    # loss_focus = self.loss_infocus()
    loss_cra = self.loss_cra()
    loss_ray_bend = self.loss_ray_bend()
    loss_clearance, loss_envelope = self.loss_bound()
    loss_profile = self.loss_profile()
    # loss_mat = self.loss_mat()
    loss_reg = (
        # w_focus * loss_focus
        w_clearance * loss_clearance
        + w_envelope * loss_envelope
        + w_profile * loss_profile
        + w_cra * loss_cra
        + w_ray_bend * loss_ray_bend
        # w_mat * loss_mat
    )

    # Return loss and loss dictionary
    loss_dict = {
        # "loss_focus": loss_focus.item(),
        "loss_clearance": loss_clearance.item(),
        "loss_envelope": loss_envelope.item(),
        "loss_profile": loss_profile.item(),
        "loss_cra": loss_cra.item(),
        "loss_ray_bend": loss_ray_bend.item(),
        # 'loss_mat': loss_mat.item(),
    }
    return loss_reg, loss_dict
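
The aggregation pattern above, a weighted sum plus a dictionary of raw values for logging, can be sketched without tensors. `combine_losses` and the numbers below are illustrative assumptions, not values produced by DeepLens.

```python
def combine_losses(components, weights):
    """Weighted sum of named penalty terms, plus raw values for logging."""
    total = sum(weights.get(name, 1.0) * value
                for name, value in components.items())
    log = {f"loss_{name}": value for name, value in components.items()}
    return total, log

# Made-up penalty values for one optimisation step.
components = {"clearance": 0.02, "envelope": 0.0, "profile": 0.1,
              "cra": 0.0, "ray_bend": 0.05}
weights = {"clearance": 1.0, "envelope": 1.0, "profile": 0.5,
           "cra": 1.0, "ray_bend": 2.0}
total, log = combine_losses(components, weights)
```

The logged values are unweighted, as in the listing, so they stay comparable across runs that use different weights.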

loss_infocus

loss_infocus(target=0.005, wvln=None)

Sample on-axis parallel rays, trace them to the sensor plane, and penalize the RMS spot error above target.

Parameters:

Name Type Description Default
target float

Target RMS spot error [mm]. Defaults to 0.005.

0.005
wvln float

Wavelength in µm. When None (default), falls back to the green channel of self.wvln_rgb.

None
Source code in deeplens-src/deeplens/geolens_pkg/optim.py
def loss_infocus(self, target=0.005, wvln=None):
    """Sample on-axis parallel rays and penalize the RMS spot error above ``target``.

    Args:
        target (float, optional): Target RMS spot error [mm]. Defaults to 0.005.
        wvln (float, optional): Wavelength in µm.  When ``None`` (default),
            falls back to the green channel of ``self.wvln_rgb``.
    """
    if wvln is None:
        wvln = self.wvln_rgb[1]
    loss = torch.tensor(0.0, device=self.device)

    # Ray tracing and calculate RMS error
    ray = self.sample_from_fov(fov_x=0.0, fov_y=0.0, wvln=wvln, num_rays=SPP_CALC)
    ray = self.trace2sensor(ray)
    rms_error = ray.rms_error()

    # Hinge (relu) penalty: nonzero only when rms_error exceeds target
    loss += relu(rms_error - target)

    return loss
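
The thresholding behaviour of this loss can be shown with scalar values. This is a minimal sketch in which `hinge` stands in for the `relu` call and the spot sizes are invented for illustration.

```python
def hinge(value, target):
    """relu(value - target): zero at or below target, linear above it."""
    return max(value - target, 0.0)

in_focus = hinge(0.003, target=0.005)   # 3 um spot, below the 5 um target
defocused = hinge(0.020, target=0.005)  # 20 um spot, 15 um over target
```

Once the on-axis spot is at or below the target the term is zero, so it stops contributing gradient and leaves the remaining losses to drive the design.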

loss_profile

loss_profile()

Penalize infeasible per-surface profile shapes.

The "profile" is the z(r) curve of a single surface. This loss keeps each surface physically manufacturable by penalizing:
  1. Sag-to-diameter ratio exceeding sag2diam_max.
  2. Maximum surface slope angle exceeding surf_angle_max (deg).

Returns:

Name Type Description
Tensor

Scalar profile feasibility penalty.

Source code in deeplens-src/deeplens/geolens_pkg/optim.py
def loss_profile(self):
    """Penalize infeasible per-surface profile shapes.

    The "profile" is the z(r) curve of a single surface. This loss makes
    sure each surface is physically manufacturable by checking:
        1. Sag-to-diameter ratio exceeding ``sag2diam_max``.
        2. Maximum surface slope angle exceeding ``surf_angle_max`` (deg).

    Returns:
        Tensor: Scalar profile feasibility penalty.
    """
    sag2diam_max = self.sag2diam_max
    grad_max = math.tan(math.radians(self.surf_angle_max))

    loss_grad = torch.tensor(0.0, device=self.device)
    loss_sag2diam = torch.tensor(0.0, device=self.device)
    for i in self.find_diff_surf():
        # Sample points on the surface
        x_ls = torch.linspace(0.0, 1.0, 32, device=self.device) * self.surfaces[i].r
        y_ls = torch.zeros_like(x_ls)

        # Sag
        sag_ls = self.surfaces[i].sag(x_ls, y_ls)
        sag2diam = sag_ls.abs().max() / self.surfaces[i].r / 2
        loss_sag2diam += relu(
            (sag2diam - sag2diam_max) / sag2diam_max)

        # 1st-order derivative
        grad_ls = self.surfaces[i].dfdxyz(x_ls, y_ls)[0]
        grad = grad_ls.abs().max()
        loss_grad += relu((grad - grad_max) / grad_max)

        # # Diameter to thickness ratio, thick_max to thick_min ratio
        # if not self.surfaces[i].mat2.name == "air":
        #     surf2 = self.surfaces[i + 1]
        #     surf1 = self.surfaces[i]

        #     # Penalize diameter to thickness ratio
        #     diam2thick = 2 * max(surf2.r, surf1.r) / (surf2.d - surf1.d)
        #     loss_diam2thick += torch.nn.functional.relu(diam2thick - diam2thick_max)

        #     # Penalize thick_max to thick_min ratio.
        #     # Use torch.maximum/minimum for differentiable max/min.
        #     r_edge = min(surf2.r, surf1.r)
        #     thick_center = surf2.d - surf1.d
        #     thick_edge = surf2.surface_with_offset(r_edge, 0.0) - surf1.surface_with_offset(r_edge, 0.0)
        #     thick_max = torch.maximum(thick_center, thick_edge)
        #     thick_min = torch.minimum(thick_center, thick_edge).clamp(min=0.01)
        #     tmax2tmin = thick_max / thick_min

        #     loss_tmax2tmin += torch.nn.functional.relu(tmax2tmin - tmax2tmin_max)

    return loss_sag2diam + loss_grad
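
To make the two checks concrete, the sketch below evaluates them for a single spherical surface using the closed-form sag and slope. The spherical shape, the `profile_penalty` name, and the sample curvatures are assumptions for illustration. The library evaluates `sag` and `dfdxyz` on whatever surface type is present.

```python
import math

def profile_penalty(c, r_max, sag2diam_max=0.5, surf_angle_max=45.0, n=32):
    """Sag/diameter and slope hinge penalties for a sphere of curvature c."""
    grad_max = math.tan(math.radians(surf_angle_max))
    radii = [r_max * i / (n - 1) for i in range(n)]
    # Spherical sag z(r) = c r^2 / (1 + sqrt(1 - c^2 r^2)) and slope dz/dr.
    root = [math.sqrt(1.0 - (c * r) ** 2) for r in radii]
    sag = [c * r * r / (1.0 + s) for r, s in zip(radii, root)]
    slope = [c * r / s for r, s in zip(radii, root)]
    sag2diam = max(abs(z) for z in sag) / (2.0 * r_max)
    grad = max(abs(g) for g in slope)
    loss_sag = max((sag2diam - sag2diam_max) / sag2diam_max, 0.0)
    loss_grad = max((grad - grad_max) / grad_max, 0.0)
    return loss_sag + loss_grad

gentle = profile_penalty(c=0.1, r_max=5.0)  # edge slope 30 deg: no penalty
steep = profile_penalty(c=0.1, r_max=8.0)   # edge slope about 53 deg
```

For the steeper case only the slope term fires: the edge slope is 4/3 against a limit of tan(45°) = 1, while the sag-to-diameter ratio of 0.25 stays under 0.5.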

loss_bound

loss_bound()

Penalize geometry-bound violations in a single surface-sampling pass.

Each surface pair is sampled once and its distances feed both the clearance (min) and envelope (max) relu penalties for air gaps, glass thickness, BFL, and TTL.

Returns:

Name Type Description
tuple

(loss_clearance, loss_envelope) scalar tensors, so callers can weight them independently. Clearance penalizes parts that are too close / too thin, envelope penalizes the overall assembly growing beyond its spatial budget.

Source code in deeplens-src/deeplens/geolens_pkg/optim.py
def loss_bound(self):
    """Penalize geometry-bound violations in a single surface-sampling pass.

    Each surface pair is sampled once and its distances feed both the
    clearance (min) and envelope (max) relu penalties for air gaps,
    glass thickness, BFL, and TTL.

    Returns:
        tuple: ``(loss_clearance, loss_envelope)`` scalar tensors, so
            callers can weight them independently. Clearance penalizes
            parts that are too close / too thin, envelope penalizes the
            overall assembly growing beyond its spatial budget.
    """
    # Min bounds (clearance)
    air_center_min = self.air_center_min
    air_edge_min = self.air_edge_min
    thick_center_min = self.thick_center_min
    thick_edge_min = self.thick_edge_min
    bfl_min = self.bfl_min
    ttl_min = self.ttl_min

    # Max bounds (envelope)
    air_center_max = self.air_center_max
    air_edge_max = self.air_edge_max
    thick_center_max = self.thick_center_max
    thick_edge_max = self.thick_edge_max
    bfl_max = self.bfl_max
    ttl_max = self.ttl_max

    loss_clearance = torch.tensor(0.0, device=self.device)
    loss_envelope = torch.tensor(0.0, device=self.device)
    air_c_range = air_center_max - air_center_min
    air_e_range = air_edge_max - air_edge_min
    thick_c_range = thick_center_max - thick_center_min
    thick_e_range = thick_edge_max - thick_edge_min
    bfl_range = bfl_max - bfl_min
    ttl_range = ttl_max - ttl_min

    for i in range(len(self.surfaces) - 1):
        current_surf = self.surfaces[i]
        next_surf = self.surfaces[i + 1]

        # Sample surfaces once and reuse for both clearance and envelope
        r_center = torch.tensor(0.0, device=self.device) * current_surf.r
        z_prev_center = current_surf.surface_with_offset(
            r_center, 0.0, valid_check=False
        )
        z_next_center = next_surf.surface_with_offset(
            r_center, 0.0, valid_check=False
        )

        r_edge = torch.linspace(0.5, 1.0, 16, device=self.device) * current_surf.r
        z_prev_edge = current_surf.surface_with_offset(
            r_edge, 0.0, valid_check=False
        )
        z_next_edge = next_surf.surface_with_offset(r_edge, 0.0, valid_check=False)

        dist_center = z_next_center - z_prev_center
        dist_edges = z_next_edge - z_prev_edge
        dist_edge_lo = torch.min(dist_edges)
        dist_edge_hi = torch.max(dist_edges)

        if current_surf.mat2.name == "air":
            loss_clearance += relu((air_center_min - dist_center) / air_c_range)
            loss_clearance += relu((air_edge_min - dist_edge_lo) / air_e_range)
            loss_envelope += relu((dist_center - air_center_max) / air_c_range)
            loss_envelope += relu((dist_edge_hi - air_edge_max) / air_e_range)
        else:
            loss_clearance += relu((thick_center_min - dist_center) / thick_c_range)
            loss_clearance += relu((thick_edge_min - dist_edge_lo) / thick_e_range)
            loss_envelope += relu((dist_center - thick_center_max) / thick_c_range)
            loss_envelope += relu((dist_edge_hi - thick_edge_max) / thick_e_range)

    # Back focal length
    last_surf = self.surfaces[-1]
    r = torch.linspace(0.0, 1.0, 32, device=self.device) * last_surf.r
    z_last_surf = self.d_sensor - last_surf.surface_with_offset(r, 0.0)
    bfl_lo = torch.min(z_last_surf)
    bfl_hi = torch.max(z_last_surf)
    loss_clearance += relu((bfl_min - bfl_lo) / bfl_range)
    loss_envelope += relu((bfl_hi - bfl_max) / bfl_range)

    # Total track length
    ttl = self.d_sensor - self.surfaces[0].d
    loss_clearance += relu((ttl_min - ttl) / ttl_range)
    loss_envelope += relu((ttl - ttl_max) / ttl_range)

    return loss_clearance, loss_envelope
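
Each bound in the listing above follows the same two-sided pattern, which can be sketched for one scalar distance. `bound_penalties` is a hypothetical helper. The limits are the cellphone-preset air-gap values from `init_constraints`.

```python
def bound_penalties(dist, lo, hi):
    """Return (clearance, envelope) hinge penalties for one distance.

    Both terms are divided by (hi - lo) so that bounds of different
    magnitude contribute on a comparable scale.
    """
    span = hi - lo
    clearance = max((lo - dist) / span, 0.0)  # too close / too thin
    envelope = max((dist - hi) / span, 0.0)   # too far / too thick
    return clearance, envelope

# Air gap limits of the cellphone preset: 0.05 mm to 5.0 mm.
too_tight = bound_penalties(0.02, lo=0.05, hi=5.0)
in_range = bound_penalties(1.00, lo=0.05, hi=5.0)
too_wide = bound_penalties(6.00, lo=0.05, hi=5.0)
```

At most one of the two terms is nonzero for a given distance, which is why they can be returned separately and weighted independently.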

loss_cra

loss_cra()

Penalize chief ray angle at sensor exceeding chief_ray_angle_max.

Uses a near-paraxial pupil sample (scale_pupil=0.2) over the full FoV. Penalty is relu((cos_ref - cos(CRA)) / cra_scale) where cra_scale = 1 - cos_ref normalizes the argument by the width of the allowed cosine range [cos_ref, 1].

Returns:

Name Type Description
Tensor

Scalar CRA penalty (always >= 0).

Source code in deeplens-src/deeplens/geolens_pkg/optim.py
def loss_cra(self):
    """Penalize chief ray angle at sensor exceeding chief_ray_angle_max.

    Uses a near-paraxial pupil sample (scale_pupil=0.2) over the full FoV.
    Penalty is ``relu((cos_ref - cos(CRA)) / cra_scale)`` where
    ``cra_scale = 1 - cos_ref`` normalizes the argument by the width of the
    allowed cosine range ``[cos_ref, 1]``.

    Returns:
        Tensor: Scalar CRA penalty (always >= 0).
    """
    cos_cra_ref = float(np.cos(np.deg2rad(self.chief_ray_angle_max)))
    cra_scale = 1.0 - cos_cra_ref

    ray = self.sample_ring_arm_rays(num_ring=8, num_arm=2, spp=SPP_CALC, scale_pupil=0.2)
    ray = self.trace2sensor(ray)
    cos_cra = ray.d[..., 2]
    valid = ray.is_valid > 0
    penalty_cra = relu((cos_cra_ref - cos_cra) / cra_scale)
    return (penalty_cra * valid).sum() / (valid.sum() + EPSILON)
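
The cosine-based penalty and its masked mean can be checked with a few angles. In this sketch the angles are given in degrees and converted to the z-component of a unit direction, which is what `ray.d[..., 2]` provides in the listing. `cra_penalty` and `eps` are assumptions for the example.

```python
import math

def cra_penalty(cra_deg, valid, cra_max_deg=45.0, eps=1e-12):
    """Masked mean of relu((cos_ref - cos(CRA)) / (1 - cos_ref))."""
    cos_ref = math.cos(math.radians(cra_max_deg))
    scale = 1.0 - cos_ref
    total = 0.0
    for angle, v in zip(cra_deg, valid):
        cos_cra = math.cos(math.radians(angle))  # z-component of unit direction
        total += v * max((cos_ref - cos_cra) / scale, 0.0)
    return total / (sum(valid) + eps)

within = cra_penalty([0.0, 30.0, 45.0], [1, 1, 1])
over = cra_penalty([30.0, 60.0, 80.0], [1, 1, 0])  # the 80 deg ray is invalid
```

Working in cosines avoids an `acos` call. A ray exactly at the limit contributes zero, and the invalid ray is excluded from both the sum and the count.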

loss_ray_bend

loss_ray_bend()

Penalize accumulated per-surface bend angles exceeding bend_angle_max.

Reads ray.bend_penalty, an additive sum of per-surface relu contributions collected during trace2sensor. Each surface contributes independently, so large bends at one surface are not hidden by small bends at another. Uses a full-pupil sample (scale_pupil=1.0).

Returns:

Name Type Description
Tensor

Scalar bend penalty (always >= 0).

Source code in deeplens-src/deeplens/geolens_pkg/optim.py
def loss_ray_bend(self):
    """Penalize accumulated per-surface bend angles exceeding bend_angle_max.

    Reads ``ray.bend_penalty``, an additive sum of per-surface relu
    contributions collected during ``trace2sensor``.  Each surface
    contributes independently, so large bends at one surface are not hidden
    by small bends at another.  Uses a full-pupil sample (scale_pupil=1.0).

    Returns:
        Tensor: Scalar bend penalty (always >= 0).
    """
    ray = self.sample_ring_arm_rays(num_ring=8, num_arm=2, spp=SPP_CALC, scale_pupil=1.0)
    ray = self.trace2sensor(ray)
    bend_penalty = ray.bend_penalty.squeeze(-1)
    valid = ray.is_valid > 0
    return (bend_penalty * valid).sum() / (valid.sum() + EPSILON)

loss_mat

loss_mat()

Penalize material parameters outside manufacturable ranges.

Constrains refractive index n to [1.5, 1.9] and Abbe number V to [30, 70] for each non-air surface material.

Returns:

Name Type Description
Tensor

Scalar material penalty loss.

Source code in deeplens-src/deeplens/geolens_pkg/optim.py
def loss_mat(self):
    """Penalize material parameters outside manufacturable ranges.

    Constrains refractive index *n* to [1.5, 1.9] and Abbe number *V* to
    [30, 70] for each non-air surface material.

    Returns:
        Tensor: Scalar material penalty loss.
    """
    n_max = 1.9
    n_min = 1.5
    V_max = 70
    V_min = 30
    loss_mat = torch.tensor(0.0, device=self.device)
    for i in range(len(self.surfaces)):
        if self.surfaces[i].mat2.name != "air":
            if self.surfaces[i].mat2.n > n_max:
                loss_mat += (self.surfaces[i].mat2.n - n_max) / (n_max - n_min)
            if self.surfaces[i].mat2.n < n_min:
                loss_mat += (n_min - self.surfaces[i].mat2.n) / (n_max - n_min)
            if self.surfaces[i].mat2.V > V_max:
                loss_mat += (self.surfaces[i].mat2.V - V_max) / (V_max - V_min)
            if self.surfaces[i].mat2.V < V_min:
                loss_mat += (V_min - self.surfaces[i].mat2.V) / (V_max - V_min)

    return loss_mat

loss_rms

loss_rms(num_grid=GEO_GRID, depth=None, num_rays=SPP_PSF, sample_more_off_axis=False)

Compute the RMS spot error loss over the RGB wavelengths.

Parameters:

Name Type Description Default
num_grid int

Number of grid points. Defaults to GEO_GRID.

GEO_GRID
depth float

Object depth. When None (default), falls back to self.obj_depth.

None
num_rays int

Number of rays. Defaults to SPP_PSF.

SPP_PSF
sample_more_off_axis bool

Whether to sample more off-axis rays. Defaults to False.

False

Returns:

Name Type Description
avg_rms_error Tensor

RMS error averaged over wavelengths and grid points.

Source code in deeplens-src/deeplens/geolens_pkg/optim.py
def loss_rms(
    self,
    num_grid=GEO_GRID,
    depth=None,
    num_rays=SPP_PSF,
    sample_more_off_axis=False,
):
    """Compute the RMS spot error loss over the RGB wavelengths.

    Args:
        num_grid (int, optional): Number of grid points. Defaults to GEO_GRID.
        depth (float, optional): Object depth. When ``None`` (default),
            falls back to ``self.obj_depth``.
        num_rays (int, optional): Number of rays. Defaults to SPP_PSF.
        sample_more_off_axis (bool, optional): Whether to sample more off-axis rays. Defaults to False.

    Returns:
        avg_rms_error (torch.Tensor): RMS error averaged over wavelengths and grid points.
    """
    depth = self.obj_depth if depth is None else depth
    # Iterate green first so the error-adaptive weight mask is anchored
    # on the reference (green) wavelength.
    loss_rms_ls = []
    w_mask = None
    for i, wvln in enumerate(
        [self.wvln_rgb[1], self.wvln_rgb[0], self.wvln_rgb[2]]
    ):
        ray = self.sample_grid_rays(
            depth=depth,
            num_grid=num_grid,
            num_rays=num_rays,
            wvln=wvln,
            sample_more_off_axis=sample_more_off_axis,
        )

        # Reference center from green chief-ray (pinhole), broadcast to rays.
        if i == 0:
            with torch.no_grad():
                center_ref = -self.psf_center(
                    points_obj=ray.o[:, :, 0, :], method="pinhole"
                )
            center_ref = center_ref.unsqueeze(-2)

        ray = self.trace2sensor(ray)

        # Per-FOV MSE → RMS, zeroing invalid rays before squaring to
        # avoid Inf*0 = NaN.
        ray_xy = ray.o[..., :2]
        ray_valid = ray.is_valid
        ray_err = ray_xy - center_ref
        ray_err = torch.where(
            ray_valid.bool().unsqueeze(-1), ray_err, torch.zeros_like(ray_err)
        )
        mse = (ray_err**2).sum(-1).sum(-1) / (ray_valid.sum(-1) + EPSILON)
        l_rms = (mse + EPSILON).sqrt()

        # First wavelength (green) defines the detached weight mask.
        if w_mask is None:
            w_mask = mse.detach()
            w_mask = w_mask / (w_mask.mean() + EPSILON)

        l_rms_weighted = (l_rms * w_mask).sum() / (w_mask.sum() + EPSILON)
        loss_rms_ls.append(l_rms_weighted)

    avg_rms_error = torch.stack(loss_rms_ls).mean(dim=0)
    return avg_rms_error
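
The error-adaptive weighting in the listing above can be illustrated with two field points. This is a plain-Python sketch: `weighted_rms`, the MSE values, and `eps` are assumptions, and the weights are ordinary floats standing in for the detached tensor.

```python
import math

def weighted_rms(mse_per_field, ref_mse, eps=1e-12):
    """Average per-field RMS with weights proportional to a reference MSE."""
    mean_ref = sum(ref_mse) / len(ref_mse)
    weights = [m / (mean_ref + eps) for m in ref_mse]  # treated as constants
    rms = [math.sqrt(m + eps) for m in mse_per_field]
    return sum(r * w for r, w in zip(rms, weights)) / (sum(weights) + eps)

# Per-field MSE in mm^2 for two field points; the second is four times worse.
mse_green = [1.0e-6, 4.0e-6]
uniform = sum(math.sqrt(m) for m in mse_green) / len(mse_green)
adaptive = weighted_rms(mse_green, ref_mse=mse_green)
```

The weighted average (about 1.8 µm) exceeds the uniform one (1.5 µm) because the worse field point carries more weight, which steers the optimiser toward the fields that currently perform worst.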

sample_ring_arm_rays

sample_ring_arm_rays(num_ring=8, num_arm=2, spp=2048, depth=None, wvln=None, scale_pupil=1.0, sample_more_off_axis=True)

Sample rays from object space using a ring-arm pattern.

This method distributes sampling points (origins of ray bundles) on a polar grid in the object plane, defined by field of view. This is useful for capturing lens performance across the full field. The points lie on num_ring rings, the first of which is the on-axis field center, with num_arm points on each.

Uses self.rfov (ray-traced real FoV, accounts for distortion) rather than self.rfov_eff (paraxial pinhole FoV) so the full distorted field is covered.

Parameters:

Name Type Description Default
num_ring int

Number of rings to sample in the field of view.

8
num_arm int

Number of arms (spokes) to sample for each ring.

2
spp int

Total number of rays to be sampled, distributed among field points.

2048
depth float

Depth of the object plane. When None (default), falls back to self.obj_depth.

None
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None
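scale_pupil float

Scale factor for the pupil size.

1.0
sample_more_off_axis bool

If True, space the ring field angles as the square root of a uniform grid, which places more rings toward the edge of the field. Defaults to True.

True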

Returns:

Name Type Description
Ray

A Ray object containing the sampled rays.

Source code in deeplens-src/deeplens/geolens_pkg/optim.py
def sample_ring_arm_rays(
    self,
    num_ring=8,
    num_arm=2,
    spp=2048,
    depth=None,
    wvln=None,
    scale_pupil=1.0,
    sample_more_off_axis=True,
):
    """Sample rays from object space using a ring-arm pattern.

    This method distributes sampling points (origins of ray bundles) on a polar grid in the object plane,
    defined by field of view. This is useful for capturing lens performance across the full field.
    The points include the center and `num_ring` rings with `num_arm` points on each.

    Uses ``self.rfov`` (ray-traced real FoV, accounts for distortion) rather than
    ``self.rfov_eff`` (paraxial pinhole FoV) so the full distorted field is covered.

    Args:
        num_ring (int): Number of rings to sample in the field of view.
        num_arm (int): Number of arms (spokes) to sample for each ring.
        spp (int): Total number of rays to be sampled, distributed among field points.
        depth (float): Depth of the object plane. When ``None`` (default),
            falls back to ``self.obj_depth``.
        wvln (float): Wavelength in µm. When ``None`` (default), falls
            back to ``self.primary_wvln``.
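        scale_pupil (float): Scale factor for the pupil size.
        sample_more_off_axis (bool): If ``True``, space the ring field angles as
            the square root of a uniform grid, which places more rings toward
            the edge of the field. Defaults to ``True``.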

    Returns:
        Ray: A Ray object containing the sampled rays.
    """
    wvln = self.primary_wvln if wvln is None else wvln
    depth = self.obj_depth if depth is None else depth
    # Create points on rings and arms
    max_fov_rad = self.rfov
    if sample_more_off_axis:
        beta_values = torch.linspace(0.0, 1.0, num_ring, device=self.device)
        beta_transformed = beta_values**0.5
        ring_fovs = max_fov_rad * beta_transformed
    else:
        ring_fovs = max_fov_rad * torch.linspace(
            0.0, 1.0, num_ring, device=self.device
        )

    arm_angles = torch.linspace(0.0, 2 * torch.pi, num_arm + 1, device=self.device)[
        :-1
    ]
    ring_grid, arm_grid = torch.meshgrid(ring_fovs, arm_angles, indexing="ij")
    x = depth * torch.tan(ring_grid) * torch.cos(arm_grid)
    y = depth * torch.tan(ring_grid) * torch.sin(arm_grid)
    z = torch.full_like(x, depth)
    points = torch.stack([x, y, z], dim=-1)  # shape: [num_ring, num_arm, 3]

    # Sample rays
    rays = self.sample_from_points(
        points=points, num_rays=spp, wvln=wvln, scale_pupil=scale_pupil
    )
    return rays
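
The field-point layout built above can be reproduced with the standard library to see where the points fall. `ring_arm_points` is a sketch written for this page. It returns nested lists instead of a tensor and needs `num_ring >= 2`.

```python
import math

def ring_arm_points(rfov, depth, num_ring=8, num_arm=2, sample_more_off_axis=True):
    """Return a num_ring x num_arm grid of (x, y, z) object points."""
    betas = [i / (num_ring - 1) for i in range(num_ring)]
    if sample_more_off_axis:
        betas = [b ** 0.5 for b in betas]  # sqrt spacing pushes rings outward
    ring_fovs = [rfov * b for b in betas]
    # num_arm angles in [0, 2*pi), endpoint excluded.
    arm_angles = [2.0 * math.pi * j / num_arm for j in range(num_arm)]
    return [
        [(depth * math.tan(f) * math.cos(a), depth * math.tan(f) * math.sin(a), depth)
         for a in arm_angles]
        for f in ring_fovs
    ]

pts = ring_arm_points(rfov=math.radians(40.0), depth=-1000.0, num_ring=3, num_arm=2)
```

With three rings the middle ring sits at about 28.3° instead of the 20° a uniform grid would give. The first ring is the on-axis point, repeated once per arm.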

optimize

optimize(lrs=[0.001, 0.0001, 0.1, 0.0001], iterations=5000, test_per_iter=100, optim_mat=False, shape_control=True, sample_more_off_axis=False, result_dir=None)

Optimise the lens by minimising RGB RMS spot errors.

Runs a curriculum-learning training loop with Adam optimiser and cosine annealing. Periodically evaluates the lens, saves intermediate results, and optionally corrects surface shapes.

Parameters:

Name Type Description Default
lrs list

Learning rates for [d, c, k, a] parameter groups. Defaults to [1e-3, 1e-4, 1e-1, 1e-4].

[0.001, 0.0001, 0.1, 0.0001]
iterations int

Total training iterations. Defaults to 5000.

5000
test_per_iter int

Evaluate and save every N iterations. Defaults to 100.

100
optim_mat bool

If True, include material parameters (n, V) in optimisation. Defaults to False.

False
shape_control bool

If True, call correct_shape() at each evaluation step. Defaults to True.

True
sample_more_off_axis bool

If True, concentrate ray samples toward the edge of the field to improve off-axis correction. Passed directly to sample_ring_arm_rays. Defaults to False.

False
result_dir str

Directory to save results. If None, auto-generates a timestamped directory. Defaults to None.

None
Note

Debug hints:

  1. Optimise slowly, with a small learning rate.
  2. FoV and thickness should match well.
  3. Keep parameter ranges reasonable.
  4. Higher aspheric order is better but more sensitive.
  5. More iterations with larger ray sampling improve convergence.

Source code in deeplens-src/deeplens/geolens_pkg/optim.py
def optimize(
    self,
    lrs=[1e-3, 1e-4, 1e-1, 1e-4],
    iterations=5000,
    test_per_iter=100,
    optim_mat=False,
    shape_control=True,
    sample_more_off_axis=False,
    result_dir=None,
):
    """Optimise the lens by minimising RGB RMS spot errors.

    Runs a curriculum-learning training loop with Adam optimiser and cosine
    annealing. Periodically evaluates the lens, saves intermediate results,
    and optionally corrects surface shapes.

    Args:
        lrs (list, optional): Learning rates for [d, c, k, a] parameter groups.
            Defaults to [1e-3, 1e-4, 1e-1, 1e-4].
        iterations (int, optional): Total training iterations. Defaults to 5000.
        test_per_iter (int, optional): Evaluate and save every N iterations.
            Defaults to 100.
        optim_mat (bool, optional): If True, include material parameters (n, V)
            in optimisation. Defaults to False.
        shape_control (bool, optional): If True, call ``correct_shape()`` at each
            evaluation step after iteration 0. Defaults to True.
        sample_more_off_axis (bool, optional): If True, concentrate ray samples
            toward the edge of the field to improve off-axis correction.
            Passed directly to ``sample_ring_arm_rays``. Defaults to False.
        result_dir (str, optional): Directory to save results. If None,
            auto-generates a timestamped directory. Defaults to None.

    Note:
        Debug hints:
            1. Slowly optimise with small learning rate.
            2. FoV and thickness should match well.
            3. Keep parameter ranges reasonable.
            4. Higher aspheric order is better but more sensitive.
            5. More iterations with larger ray sampling improve convergence.
    """
    # Experiment settings
    depth = self.obj_depth
    num_ring = 32
    num_arm = 8
    spp = 2048

    # Result directory and logger
    if result_dir is None:
        result_dir = (
            f"./results/{datetime.now().strftime('%m%d-%H%M%S')}-DesignLens"
        )

    os.makedirs(result_dir, exist_ok=True)
    if not logging.getLogger().hasHandlers():
        logger = logging.getLogger()
        logger.setLevel("DEBUG")
        fmt = logging.Formatter(
            "%(asctime)s:%(levelname)s:%(message)s", "%Y-%m-%d %H:%M:%S"
        )
        sh = logging.StreamHandler()
        sh.setFormatter(fmt)
        sh.setLevel("INFO")
        fh = logging.FileHandler(f"{result_dir}/output.log")
        fh.setFormatter(fmt)
        fh.setLevel("INFO")
        logger.addHandler(sh)
        logger.addHandler(fh)
    logging.info(
        f"lr:{lrs}, iterations:{iterations}, num_ring:{num_ring}, num_arm:{num_arm}, rays_per_fov:{spp}."
    )
    logging.info(
        "If Out-of-Memory, try to reduce num_ring, num_arm, and rays_per_fov."
    )

    # Optimizer and scheduler
    optimizer = self.get_optimizer(lrs, optim_mat=optim_mat)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=100, num_training_steps=iterations
    )

    # Training loop
    pbar = tqdm(
        total=iterations + 1,
        desc="Progress",
        postfix={"loss_rms": 0},
    )
    for i in range(iterations + 1):
        # ===> Evaluate the lens
        if i % test_per_iter == 0:
            with torch.no_grad():
                if shape_control and i > 0:
                    self.correct_shape()

                self.write_lens_json(f"{result_dir}/iter{i}.json")
                self.analysis(f"{result_dir}/iter{i}")

                # Sample rays
                self.calc_pupil()
                rays_backup = []
                for wv in self.wvln_rgb:
                    ray = self.sample_ring_arm_rays(
                        num_ring=num_ring,
                        num_arm=num_arm,
                        spp=spp,
                        depth=depth,
                        wvln=wv,
                        scale_pupil=1.05,
                        sample_more_off_axis=sample_more_off_axis,
                    )
                    rays_backup.append(ray)

                # Pinhole ideal for distortion reference (distortion-free).
                pinhole_ref = -self.psf_center(
                    points_obj=ray.o[:, :, 0, :], method="pinhole"
                )

        # ===> Optimize lens by minimizing RMS
        # Green is traced first: its centroid sets center_ref and drives
        # the distortion penalty; red and blue reuse the same center_ref.
        loss_rms_ls = []
        loss_distortion = torch.tensor(0.0, device=self.device)
        w_mask = None
        center_ref = None
        wvln_order = [1, 0, 2]  # green, red, blue
        for wv_idx in wvln_order:
            # Ray tracing to sensor, [num_ring, num_arm, num_rays, 3]
            ray = rays_backup[wv_idx].clone()
            ray = self.trace2sensor(ray)

            if center_ref is None:
                # Green centroid at sensor, shape [num_ring, num_arm, 2].
                centroid_xy = ray.centroid()[..., :2]

                # Distortion: relative displacement of green centroid from
                # pinhole ideal, averaged equally over all off-axis fields.
                ideal_height = pinhole_ref.norm(dim=-1)
                field_mask = ideal_height > EPSILON
                distortion = (centroid_xy - pinhole_ref).norm(dim=-1)
                distortion = distortion / ideal_height.clamp_min(EPSILON)
                violation = distortion - self.distortion_max
                penalty = relu(violation / self.distortion_max)
                n_fields = field_mask.sum().clamp_min(1)
                loss_distortion = (penalty * field_mask.float()).sum() / n_fields

                # Detach so RMS gradient moves spot shape, not its
                # position; distortion loss handles placement.
                center_ref = centroid_xy.detach().unsqueeze(-2)

            # Ray error to center and valid mask
            ray_valid = ray.is_valid
            ray_err = ray.o[..., :2] - center_ref
            ray_err = torch.where(
                ray_valid.bool().unsqueeze(-1), ray_err, torch.zeros_like(ray_err)
            )

            # MSE per field point, shape [num_ring, num_arm]
            mse = (ray_err**2).sum(-1).sum(-1) / (ray_valid.sum(-1) + EPSILON)

            # Weight mask
            if w_mask is None:
                w_mask = mse.detach().sqrt().clone()
                w_mask = w_mask / (w_mask.mean() + EPSILON)
                w_mask[0, :] = 1.0

            # RMS and weighted loss
            l_rms = torch.clamp(mse, min=EPSILON).sqrt()
            l_rms_weighted = (l_rms * w_mask).sum() / (w_mask.sum() + EPSILON)
            loss_rms_ls.append(l_rms_weighted)

        # RMS loss for all wavelengths
        loss_rms = sum(loss_rms_ls) / len(loss_rms_ls)

        # Total loss
        w_reg = 0.1
        loss_reg, loss_dict = self.loss_reg()
        L_total = loss_rms + w_reg * (loss_reg + loss_distortion)

        # Back-propagation
        optimizer.zero_grad()
        L_total.backward()
        optimizer.step()
        scheduler.step()

        pbar.set_postfix(
            loss_rms=loss_rms.item(),
            loss_dist=loss_distortion.item(),
            **loss_dict,
        )
        pbar.update(1)

    pbar.close()
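
The field-weighted RMS term in the loop above can be illustrated for a single wavelength in plain Python. This is a sketch under stated assumptions, not the library code: `field_weighted_rms` and `EPS` are hypothetical names, fields are stored as a flat list with index 0 on-axis, and the weights are recomputed on every call, whereas the listing computes them once from the green channel and reuses them for red and blue.

```python
import math

EPS = 1e-12  # hypothetical stand-in for the EPSILON constant in the listing


def field_weighted_rms(errors, valid):
    """Weighted mean of per-field RMS spot radii.

    errors[f][r] is the (dx, dy) offset of ray r from the reference centre
    of field f; valid[f][r] marks rays that reached the sensor.
    """
    mse = []
    for errs_f, valid_f in zip(errors, valid):
        sq_sum = sum(dx * dx + dy * dy for (dx, dy), ok in zip(errs_f, valid_f) if ok)
        mse.append(sq_sum / (sum(valid_f) + EPS))

    rms = [math.sqrt(max(m, EPS)) for m in mse]

    # Fields with larger spots get proportionally larger weights.
    raw = [math.sqrt(m) for m in mse]
    mean_raw = sum(raw) / len(raw)
    weights = [w / (mean_raw + EPS) for w in raw]
    weights[0] = 1.0  # the on-axis field keeps unit weight

    return sum(r * w for r, w in zip(rms, weights)) / (sum(weights) + EPS)
```

For two fields with RMS radii 1 and 3, the normalised weights are 0.5 and 1.5; after the on-axis weight is reset to 1 the loss is (1·1 + 3·1.5) / 2.5 = 2.2.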

find_diff_surf

find_diff_surf()

Get differentiable/optimizable surface indices.

Returns a list of surface indices that can be optimized during lens design. Excludes the aperture surface from optimization.
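
As a plain-Python illustration of this selection, the sketch below reproduces the index logic on bare integers. `diff_surf_indices` is a hypothetical stand-in that takes the surface count and aperture index as arguments instead of reading them from the lens.

```python
def diff_surf_indices(num_surfaces, aper_idx=None):
    """Indices of all surfaces except the aperture stop (if there is one)."""
    if aper_idx is None:
        return list(range(num_surfaces))
    return [i for i in range(num_surfaces) if i != aper_idx]
```

For a five-surface lens with the stop at index 2, this yields `[0, 1, 3, 4]`.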

Returns:

Type Description

list or range: Surface indices excluding the aperture.

Source code in deeplens-src/deeplens/geolens_pkg/optim.py
def find_diff_surf(self):
    """Get differentiable/optimizable surface indices.

    Returns a list of surface indices that can be optimized during lens design.
    Excludes the aperture surface from optimization.

    Returns:
        list or range: Surface indices excluding the aperture.
    """
    if self.aper_idx is None:
        diff_surf_range = range(len(self.surfaces))
    else:
        diff_surf_range = list(range(0, self.aper_idx)) + list(
            range(self.aper_idx + 1, len(self.surfaces))
        )
    return diff_surf_range

get_optimizer_params

get_optimizer_params(lrs=[0.0001, 0.0001, 0.01, 0.0001], optim_mat=False, optim_surf_range=None)

Get optimizer parameters for different lens surfaces.

Recommendation

  • For cellphone lens: [d, c, k, a], [1e-4, 1e-4, 1e-1, 1e-4]
  • For camera lens: [d, c, 0, 0], [1e-3, 1e-4, 0, 0]
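
How the entries of `lrs` are handed to each surface type can be read off the source listing below. The sketch summarises that dispatch with surface types given as strings; `lrs_for_surface` is a hypothetical helper, and the explicit length check for `Phase` is an addition for illustration (the listing simply indexes `lrs[4]`, so a four-element list would fail there).

```python
def lrs_for_surface(kind, lrs):
    """Return the slice of `lrs` that a surface of type `kind` receives."""
    if kind in ("Aperture", "Plane"):
        return [lrs[0]]              # position only
    if kind in ("Spheric", "ThinLens"):
        return [lrs[0], lrs[1]]      # first two entries
    if kind == "Aspheric":
        return list(lrs[:4])         # d, c, k, a
    if kind == "Phase":
        if len(lrs) < 5:
            raise ValueError("Phase surfaces read lrs[4]; pass a fifth learning rate.")
        return [lrs[0], lrs[4]]
    raise ValueError(f"Surface type {kind} is not supported for optimization yet.")
```

With the camera-lens recommendation, a `Spheric` surface therefore receives `[1e-3, 1e-4]` and the zero entries are never consulted.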

Parameters:

Name Type Description Default
lrs list

learning rate for different parameters.

[0.0001, 0.0001, 0.01, 0.0001]
optim_mat bool

whether to optimize material. Defaults to False.

False
optim_surf_range list

surface indices to be optimized. Defaults to None.

None

Returns:

Name Type Description
list

optimizer parameters

Source code in deeplens-src/deeplens/geolens_pkg/optim.py
def get_optimizer_params(
    self,
    lrs=[1e-4, 1e-4, 1e-2, 1e-4],
    optim_mat=False,
    optim_surf_range=None,
):
    """Get optimizer parameters for different lens surfaces.

    Recommendation:
        For cellphone lens: [d, c, k, a], [1e-4, 1e-4, 1e-1, 1e-4]
        For camera lens: [d, c, 0, 0], [1e-3, 1e-4, 0, 0]

    Args:
        lrs (list): learning rate for different parameters.
        optim_mat (bool): whether to optimize material. Defaults to False.
        optim_surf_range (list): surface indices to be optimized. Defaults to None.

    Returns:
        list: optimizer parameters
    """
    # Find surfaces to be optimized
    if optim_surf_range is None:
        # optim_surf_range = self.find_diff_surf()
        optim_surf_range = range(len(self.surfaces))

    # If lrs is a list of per-surface learning-rate lists
    if isinstance(lrs[0], list):
        return self.get_optimizer_params_manual(
            lrs=lrs, optim_mat=optim_mat, optim_surf_range=optim_surf_range
        )

    # Optimize lens surface parameters
    params = []
    for surf_idx in optim_surf_range:
        surf = self.surfaces[surf_idx]

        if isinstance(surf, Aperture):
            params += surf.get_optimizer_params(lrs=[lrs[0]])

        elif isinstance(surf, Aspheric):
            params += surf.get_optimizer_params(lrs=lrs[:4], optim_mat=optim_mat)

        elif isinstance(surf, Phase):
            params += surf.get_optimizer_params(lrs=[lrs[0], lrs[4]])

        # elif isinstance(surf, GaussianRBF):
        #     params += surf.get_optimizer_params(lrs=lr, optim_mat=optim_mat)

        # elif isinstance(surf, NURBS):
        #     params += surf.get_optimizer_params(lrs=lr, optim_mat=optim_mat)

        elif isinstance(surf, Plane):
            params += surf.get_optimizer_params(lrs=[lrs[0]], optim_mat=optim_mat)

        # elif isinstance(surf, PolyEven):
        #     params += surf.get_optimizer_params(lrs=lr, optim_mat=optim_mat)

        elif isinstance(surf, Spheric):
            params += surf.get_optimizer_params(
                lrs=[lrs[0], lrs[1]], optim_mat=optim_mat
            )

        elif isinstance(surf, ThinLens):
            params += surf.get_optimizer_params(
                lrs=[lrs[0], lrs[1]], optim_mat=optim_mat
            )

        else:
            raise Exception(
                f"Surface type {surf.__class__.__name__} is not supported for optimization yet."
            )

    # Optimize sensor place
    self.d_sensor.requires_grad = True
    params += [{"params": self.d_sensor, "lr": lrs[0]}]

    return params

get_optimizer

get_optimizer(lrs=[0.0001, 0.0001, 0.1, 0.0001], optim_surf_range=None, optim_mat=False)

Build an Adam optimizer over all trainable lens parameters.

Parameters:

Name Type Description Default
lrs list

learning rates for parameter groups [d, c, k, ai]. Defaults to [1e-4, 1e-4, 1e-1, 1e-4].

[0.0001, 0.0001, 0.1, 0.0001]
optim_surf_range list

surface indices to optimise. If None, all surfaces are included. Defaults to None.

None
optim_mat bool

whether to include material parameters (n, V). Defaults to False.

False

Returns:

Type Description

torch.optim.Adam: configured optimizer.

Source code in deeplens-src/deeplens/geolens_pkg/optim.py
def get_optimizer(
    self,
    lrs=[1e-4, 1e-4, 1e-1, 1e-4],
    optim_surf_range=None,
    optim_mat=False,
):
    """Build an Adam optimizer over all trainable lens parameters.

    Args:
        lrs (list): learning rates for parameter groups [d, c, k, ai].
            Defaults to [1e-4, 1e-4, 1e-1, 1e-4].
        optim_surf_range (list): surface indices to optimise. If None,
            all surfaces are included. Defaults to None.
        optim_mat (bool): whether to include material parameters (n, V).
            Defaults to False.

    Returns:
        torch.optim.Adam: configured optimizer.
    """
    # Get optimizer
    params = self.get_optimizer_params(
        lrs=lrs, optim_surf_range=optim_surf_range, optim_mat=optim_mat
    )
    optimizer = torch.optim.Adam(params)
    # optimizer = torch.optim.SGD(params)
    return optimizer

deeplens.geolens_pkg.optim_ops.GeoLensSurfOps

Mixin providing surface geometry operations for GeoLens.

Methods:

Name Description
- add_aspheric

Convert a spherical surface to aspheric.

- increase_aspheric_order

Add higher-order polynomial terms.

- prune_surf

Size clear apertures by ray tracing.

- correct_shape

Fix lens geometry during optimisation.

add_aspheric

add_aspheric(surf_idx=None, ai_degree=4)

Convert a spherical surface to aspheric for improved aberration correction.

If surf_idx is given, converts that specific surface. Otherwise, automatically selects the best candidate following established optical design principles:

  1. First asphere: placed near the aperture stop (corrects spherical aberration).
  2. Subsequent aspheres: placed far from the stop (corrects field-dependent aberrations like coma, astigmatism, distortion).
  3. Prefer air-glass interfaces over cemented surfaces.
  4. Among candidates at similar stop-distances, prefer larger semi-diameter (higher marginal ray height → more SA contribution).

The new surface starts with k=0 and all polynomial coefficients at zero, so it is initially identical to the original spherical surface.
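
Why the converted surface is initially unchanged can be checked numerically. The sketch assumes the standard even-asphere sag equation (conic base term plus an even-power polynomial); the exact form used by the `Aspheric` class is not shown in this section, and `first_power` is a hypothetical parameter standing in for whichever power the first coefficient multiplies.

```python
import math


def sphere_sag(r, c):
    """Sag of a sphere with curvature c (= 1/R) at radial height r."""
    return c * r * r / (1.0 + math.sqrt(1.0 - c * c * r * r))


def even_asphere_sag(r, c, k=0.0, ai=(), first_power=4):
    """Standard even-asphere sag: conic base term plus polynomial terms.

    The i-th coefficient multiplies r ** (first_power + 2 * i); the value
    of first_power is an assumption, not taken from the library.
    """
    base = c * r * r / (1.0 + math.sqrt(1.0 - (1.0 + k) * c * c * r * r))
    poly = sum(a * r ** (first_power + 2 * i) for i, a in enumerate(ai))
    return base + poly
```

With `k=0` and all coefficients zero, the conic term reduces to the spherical sag and the polynomial vanishes, so the two functions agree at every radius.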

Note

After calling this method, any existing optimizer is stale. Call get_optimizer() again to include the new parameters.

Parameters:

Name Type Description Default
surf_idx int or None

Surface index to convert. If None, auto-selects the best candidate.

None
ai_degree int

Number of even-order aspheric coefficients [a2, a4, a6, ...]. Defaults to 4.

4

Returns:

Name Type Description
int

Index of the converted surface.

Raises:

Type Description
IndexError

If surf_idx is out of range.

ValueError

If surf_idx points to a non-Spheric surface, or no eligible candidate exists for auto-selection.

References

Design principles from research/aspheric_design_principles.md.

Source code in deeplens-src/deeplens/geolens_pkg/optim_ops.py
@torch.no_grad()
def add_aspheric(self, surf_idx=None, ai_degree=4):
    """Convert a spherical surface to aspheric for improved aberration correction.

    If ``surf_idx`` is given, converts that specific surface. Otherwise,
    automatically selects the best candidate following established optical
    design principles:

    1. First asphere: placed near the aperture stop (corrects spherical
       aberration).
    2. Subsequent aspheres: placed far from the stop (corrects field-dependent
       aberrations like coma, astigmatism, distortion).
    3. Prefer air-glass interfaces over cemented surfaces.
    4. Among candidates at similar stop-distances, prefer larger semi-diameter
       (higher marginal ray height → more SA contribution).

    The new surface starts with ``k=0`` and all polynomial coefficients at
    zero, so it is initially identical to the original spherical surface.

    Note:
        After calling this method, any existing optimizer is stale.
        Call ``get_optimizer()`` again to include the new parameters.

    Args:
        surf_idx (int or None): Surface index to convert. If ``None``,
            auto-selects the best candidate.
        ai_degree (int): Number of even-order aspheric coefficients
            ``[a2, a4, a6, ...]``. Defaults to 4.

    Returns:
        int: Index of the converted surface.

    Raises:
        IndexError: If ``surf_idx`` is out of range.
        ValueError: If ``surf_idx`` points to a non-Spheric surface, or no
            eligible candidate exists for auto-selection.

    References:
        Design principles from ``research/aspheric_design_principles.md``.
    """
    if surf_idx is not None:
        if surf_idx < 0 or surf_idx >= len(self.surfaces):
            raise IndexError(
                f"surf_idx={surf_idx} out of range [0, {len(self.surfaces) - 1}]."
            )
        if not isinstance(self.surfaces[surf_idx], Spheric):
            raise ValueError(
                f"Surface {surf_idx} is {type(self.surfaces[surf_idx]).__name__}, "
                f"expected Spheric. To add higher-order terms to an existing "
                f"Aspheric surface, use increase_aspheric_order(surf_idx={surf_idx})."
            )
        self._spheric_to_aspheric(surf_idx, ai_degree)
        logging.info(
            f"Converted surface {surf_idx} from Spheric to Aspheric "
            f"(ai_degree={ai_degree})."
        )
        return surf_idx

    # Auto-select best candidate
    surf_idx = self._find_best_asphere_candidate()
    self._spheric_to_aspheric(surf_idx, ai_degree)
    logging.info(
        f"Auto-selected surface {surf_idx} as best asphere candidate. "
        f"Converted to Aspheric (ai_degree={ai_degree})."
    )
    return surf_idx

increase_aspheric_order

increase_aspheric_order(surf_idx=None, increment=1)

Add higher-order polynomial terms to existing Aspheric surfaces.

Appends increment additional even-order coefficients (initialised to zero). For example, degree 4 [a4, a6, a8, a10] becomes degree 5 [a4, a6, a8, a10, a12] after increment=1.
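
On a bare coefficient list, the operation amounts to padding with zeros. The helper below is a hypothetical stand-in for the private `_increase_surface_order`, whose implementation is not shown on this page.

```python
def increase_order(ai, increment=1):
    """Return a copy of `ai` with `increment` zero coefficients appended."""
    if increment < 1:
        raise ValueError(f"increment must be >= 1, got {increment}.")
    return list(ai) + [0.0] * increment
```

Because the new coefficients start at zero, the surface shape is unchanged until the next optimisation run updates them.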

Follows the principle of start low, add incrementally: increase order only when residual higher-order aberrations persist after optimisation at the current order.

Note

After calling this method, any existing optimizer is stale. Call get_optimizer() again to include the new parameters.
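
One way to apply the "start low, add incrementally" principle is a staged loop built from the methods documented on this page. The sketch is an assumption about workflow, not a recipe from the library: `staged_design` and its iteration counts are hypothetical. Since `optimize()` builds its own optimizer on every call (see its source listing), no stale optimizer is carried across stages.

```python
def staged_design(lens, stage_iterations=(1000, 1000, 1000),
                  lrs=(1e-3, 1e-4, 1e-1, 1e-4)):
    """Optimise, add one asphere, optimise, raise its order, optimise again.

    `lens` is expected to expose `optimize`, `add_aspheric` and
    `increase_aspheric_order` with the signatures documented on this page.
    """
    first, second, third = stage_iterations

    lens.optimize(lrs=list(lrs), iterations=first)

    # Auto-select a spherical surface and convert it to aspheric.
    surf_idx = lens.add_aspheric()
    lens.optimize(lrs=list(lrs), iterations=second)

    # Add one more coefficient to the same surface, then refine again.
    lens.increase_aspheric_order(surf_idx=surf_idx, increment=1)
    lens.optimize(lrs=list(lrs), iterations=third)
    return surf_idx
```

In practice the order increase would be applied only if residual higher-order aberrations remain after the second stage.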

Parameters:

Name Type Description Default
surf_idx int or None

Surface index. If None, auto-selects the best candidate (see _find_best_order_increase_candidate).

None
increment int

Number of additional coefficients to add. Defaults to 1.

1

Returns:

Name Type Description
int

Index of the surface whose order was increased.

Raises:

Type Description
IndexError

If surf_idx is out of range.

ValueError

If surf_idx is given but is not Aspheric, if no Aspheric surfaces exist when surf_idx is None, or if increment < 1.

Source code in deeplens-src/deeplens/geolens_pkg/optim_ops.py
@torch.no_grad()
def increase_aspheric_order(self, surf_idx=None, increment=1):
    """Add higher-order polynomial terms to existing Aspheric surfaces.

    Appends ``increment`` additional even-order coefficients (initialised
    to zero). For example, degree 4 ``[a4, a6, a8, a10]`` becomes degree 5
    ``[a4, a6, a8, a10, a12]`` after ``increment=1``.

    Follows the principle of *start low, add incrementally*: increase
    order only when residual higher-order aberrations persist after
    optimisation at the current order.

    Note:
        After calling this method, any existing optimizer is stale.
        Call ``get_optimizer()`` again to include the new parameters.

    Args:
        surf_idx (int or None): Surface index. If ``None``, auto-selects
            the best candidate (see ``_find_best_order_increase_candidate``).
        increment (int): Number of additional coefficients to add.
            Defaults to 1.

    Returns:
        int: Index of the surface whose order was increased.

    Raises:
        IndexError: If ``surf_idx`` is out of range.
        ValueError: If ``surf_idx`` is given but is not Aspheric, if
            no Aspheric surfaces exist when ``surf_idx`` is ``None``,
            or if ``increment`` < 1.
    """
    if increment < 1:
        raise ValueError(f"increment must be >= 1, got {increment}.")
    if surf_idx is not None:
        if surf_idx < 0 or surf_idx >= len(self.surfaces):
            raise IndexError(
                f"surf_idx={surf_idx} out of range [0, {len(self.surfaces) - 1}]."
            )
    else:
        surf_idx = self._find_best_order_increase_candidate()

    surf = self.surfaces[surf_idx]
    if not isinstance(surf, Aspheric):
        raise ValueError(
            f"Surface {surf_idx} is {type(surf).__name__}, expected Aspheric."
        )
    old_degree = surf.ai_degree
    self._increase_surface_order(surf, increment)
    logging.info(
        f"Surface {surf_idx}: aspheric order {old_degree} -> {surf.ai_degree}."
    )

    return surf_idx

prune_surf

prune_surf(mounting_margin=None)

Prune surfaces to allow all valid rays to go through.

Determines the clear aperture for each surface by ray tracing, then adds a mounting margin and enforces manufacturability constraints (edge thickness and air-gap clearance).

Parameters:

Name Type Description Default
mounting_margin float

Absolute mounting margin in mm added to the ray-traced clear aperture radius. If None, the margin is auto-selected per surface: 10 % of the ray-traced radius when it is below 5 mm, otherwise 1 mm.

None
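
The margin rule described above can be written as a small function. This is a sketch of the first proposal step only (before the sag and edge-clearance caps are applied); `proposed_radius` and `small_fraction` are hypothetical names, and the 10 % figure is taken from the parameter description.

```python
def proposed_radius(ray_radius, max_height, mounting_margin=None, small_fraction=0.10):
    """Ray-traced radius plus mounting margin, capped at the surface's max height."""
    if mounting_margin is None:
        # Relative margin for small elements, fixed 1 mm otherwise.
        margin = small_fraction * ray_radius if ray_radius < 5.0 else 1.0
    else:
        margin = float(mounting_margin)
    return min(ray_radius + margin, max_height)
```

For example, a 2 mm ray-traced radius becomes 2.2 mm, while an 8 mm radius becomes 9 mm unless the surface's maximum height is smaller.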
Source code in deeplens-src/deeplens/geolens_pkg/optim_ops.py
@torch.no_grad()
def prune_surf(self, mounting_margin=None):
    """Prune surfaces to allow all valid rays to go through.

    Determines the clear aperture for each surface by ray tracing, then
    adds a mounting margin and enforces manufacturability constraints
    (edge thickness and air-gap clearance).

    Args:
        mounting_margin (float, optional): Absolute mounting margin in mm
            added to the ray-traced clear aperture radius. If ``None``,
            the margin is auto-selected per surface: 10 % of the
            ray-traced radius when it is below 5 mm, otherwise 1 mm.
    """
    surface_range = self.find_diff_surf()
    num_surfs = len(self.surfaces)

    # ------------------------------------------------------------------
    # 1. Temporarily remove radius limits so the trace is unclipped
    # ------------------------------------------------------------------
    saved_radii = [self.surfaces[i].r for i in range(num_surfs)]
    for i in surface_range:
        self.surfaces[i].r = self.surfaces[i].max_height()

    # ------------------------------------------------------------------
    # 2. Trace rays at full FoV to find maximum ray height per surface
    # ------------------------------------------------------------------
    assert self.rfov is not None, "prune_surf() requires self.rfov."
    fov_deg = self.rfov * 180 / torch.pi
    num_fov_samples = 16
    fov_y = torch.linspace(0.0, fov_deg, num_fov_samples, device=self.device)
    ray = self.sample_from_fov(fov_x=[0.0], fov_y=fov_y)
    _, ray_o_record = self.trace2sensor(ray=ray, record=True)

    # Ray record, shape [num_rays, num_surfaces + 2, 3]
    ray_o_record = torch.stack(ray_o_record, dim=-2)
    ray_o_record = torch.nan_to_num(
        ray_o_record, nan=0.0, posinf=0.0, neginf=0.0
    )
    ray_o_record = ray_o_record.reshape(-1, ray_o_record.shape[-2], 3)

    # Compute the maximum ray height for each surface
    ray_r_record = (ray_o_record[..., :2] ** 2).sum(-1).sqrt()
    surf_r_max = ray_r_record.max(dim=0)[0][1:-1]

    # ------------------------------------------------------------------
    # 3. Propose new radii (not yet committed to surfaces).
    # ------------------------------------------------------------------
    proposed_r = [float(self.surfaces[i].r) for i in range(num_surfs)]
    for i in surface_range:
        # Surface radius required by ray tracing
        if surf_r_max[i] > 0:
            base = float(surf_r_max[i].item())
        else:
            base = float(self.surfaces[i].r)

        # Expand the ray-traced radius by a mounting margin
        if mounting_margin is None:
            r_expand = 0.1 * base if base < 5.0 else 1.0
        else:
            r_expand = float(mounting_margin)

        # Propose the new radius, capped at the surface's physical maximum height
        proposed_r[i] = min(base + r_expand, float(self.surfaces[i].max_height()))

    # ------------------------------------------------------------------
    # 3b. Sag cap: edge sag must not exceed sag_factor * proposed radius.
    # Grid-search for the largest r in [r_min, proposed_r] where the
    # constraint holds. The grid is dense enough for typical aspheric sag
    # profiles; non-monotonic extremes are handled conservatively.
    # ------------------------------------------------------------------
    sag_factor = 0.4
    for i in surface_range:
        if not isinstance(self.surfaces[i], Aperture):
            r_prop = proposed_r[i]
            r_cands = torch.linspace(r_prop / 64, r_prop, 64, device=self.device)
            z0 = self.surfaces[i].surface_with_offset(
                torch.tensor(0.0, device=self.device), 0.0, valid_check=False
            )
            z_cands = self.surfaces[i].surface_with_offset(
                r_cands, torch.zeros_like(r_cands), valid_check=False
            )
            sag_valid = (z_cands - z0).abs() <= sag_factor * r_cands
            if sag_valid.any():
                proposed_r[i] = min(r_prop, float(r_cands[sag_valid].max().item()))
            else:
                proposed_r[i] = float(r_cands[0].item())

    # ------------------------------------------------------------------
    # 4. Edge-clearance pass — proactively cap adjacent pairs so the
    #    committed radii never produce self-intersection at the edge.
    #    Thresholds match loss_bound. The cap uses the common
    #    clear-aperture overlap between adjacent surfaces so one surface is
    #    not pruned against regions where the neighbour has already been
    #    apertured away. Aperture surfaces are skipped; the stop size is an
    #    optical specification and should not be changed by pruning. The cap
    #    is computed via a single vectorized grid search rather than a
    #    serial binary loop.
    #
    #    Each pruned surface is checked against both neighbours. The
    #    previous implementation only capped surface i against i + 1,
    #    which allowed surface i to expand into i - 1 and later crash
    #    tracing/optimization.
    # ------------------------------------------------------------------
    min_radius_floor = 0.1  # mm — guard against update_r(0) killing a surface
    n_cand = 64
    n_edge = 64
    r_frac = torch.linspace(0.5, 1.0, n_edge, device=self.device)
    cand_frac = torch.linspace(1.0 / n_cand, 1.0, n_cand, device=self.device)

    def cap_radius_against_pair(cap_idx, prev_idx, next_idx):
        prev_surf = self.surfaces[prev_idx]
        next_surf = self.surfaces[next_idx]
        if isinstance(prev_surf, Aperture) or isinstance(next_surf, Aperture):
            return
        if isinstance(self.surfaces[cap_idx], Aperture):
            return

        edge_min = 0.1  # mm
        r_check = proposed_r[cap_idx]

        other_idx = next_idx if cap_idx == prev_idx else prev_idx
        other_r = proposed_r[other_idx]

        required_r = max(
            float(surf_r_max[cap_idx].item()),
            min_radius_floor,
        )

        # Vectorized cap: evaluate gap for 64 candidate radii in one pass.
        cand_r = cand_frac * r_check
        cand_overlap_r = torch.minimum(
            cand_r, torch.tensor(other_r, device=self.device)
        )
        r_grid = cand_overlap_r.unsqueeze(1) * r_frac.unsqueeze(0)
        z_prev_grid = prev_surf.surface_with_offset(
            r_grid.reshape(-1), 0.0, valid_check=False
        ).reshape(n_cand, n_edge)
        z_next_grid = next_surf.surface_with_offset(
            r_grid.reshape(-1), 0.0, valid_check=False
        ).reshape(n_cand, n_edge)
        per_cand_gap = (z_next_grid - z_prev_grid).min(dim=-1).values
        overlap_ok = per_cand_gap >= edge_min

        # Sag-bracket: the cap surface's edge z (at candidate r) must not
        # axially cross the other surface's edge z. Catches the case
        # where high-order aspheric terms blow up beyond the surface's
        # design r and drag its edge past the neighbour, while the
        # in-overlap gap above is still fine.
        cap_surf = self.surfaces[cap_idx]
        other_surf = self.surfaces[other_idx]
        z_other_edge = other_surf.surface_with_offset(
            torch.tensor(other_r, device=self.device),
            torch.tensor(0.0, device=self.device),
            valid_check=False,
        )
        z_cap_at_cand = cap_surf.surface_with_offset(
            cand_r, torch.zeros_like(cand_r), valid_check=False
        )
        if cap_idx > other_idx:
            # cap is later in light path — must stay axially after other
            bracket_ok = z_cap_at_cand > z_other_edge + edge_min
        else:
            # cap is earlier — must stay axially before other
            bracket_ok = z_cap_at_cand < z_other_edge - edge_min

        valid_mask = overlap_ok & bracket_ok
        if not bool(valid_mask.any()):
            logging.warning(
                f"Surf {prev_idx}-{next_idx} "
                f"({prev_surf.mat2.name}): no candidate "
                f"radius satisfies edge_min {edge_min:.3f} mm at "
                f"r_check {r_check:.3f} mm (possible sag crossing near "
                f"axis). Reducing surface {cap_idx} to the ray-required radius "
                f"{required_r:.3f} mm, but edge clearance may remain "
                f"violated."
            )
            proposed_r[cap_idx] = min(proposed_r[cap_idx], required_r)
            return

        r_safe = float((cand_frac[valid_mask].max() * r_check).item())
        if r_safe < required_r:
            logging.warning(
                f"Surf {prev_idx}-{next_idx} "
                f"({prev_surf.mat2.name}): ray-required "
                f"radius {required_r:.3f} mm exceeds edge-clearance-safe "
                f"radius {r_safe:.3f} mm for edge_min {edge_min:.3f} mm. "
                f"Reducing surface {cap_idx} to the ray-required radius; edge "
                f"clearance may remain violated."
            )
            proposed_r[cap_idx] = min(proposed_r[cap_idx], required_r)
            return

        r_safe = max(r_safe, min_radius_floor)
        if proposed_r[cap_idx] > r_safe:
            proposed_r[cap_idx] = r_safe

    for i in surface_range:
        if i > 0:
            cap_radius_against_pair(i, i - 1, i)
        if i < num_surfs - 1:
            cap_radius_against_pair(i, i, i + 1)

    # ------------------------------------------------------------------
    # 4b. Commit the capped proposed radii to the surfaces.
    # ------------------------------------------------------------------
    for i in surface_range:
        if proposed_r[i] > 0:
            self.surfaces[i].update_r(proposed_r[i])
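
The sag cap in step 3b of the listing can be illustrated on a plain sag function. The sketch below mirrors the grid search with 64 candidates and a factor of 0.4; `sag_capped_radius` and `sphere_sag` are hypothetical helpers, and a sphere of radius 10 mm is used only as a test profile.

```python
import math


def sag_capped_radius(sag, r_prop, sag_factor=0.4, n_cand=64):
    """Largest candidate r in (0, r_prop] with |sag(r) - sag(0)| <= sag_factor * r."""
    cands = [r_prop * k / n_cand for k in range(1, n_cand + 1)]
    z0 = sag(0.0)
    valid = [r for r in cands if abs(sag(r) - z0) <= sag_factor * r]
    if valid:
        return min(r_prop, max(valid))
    return cands[0]  # nothing satisfies the cap: fall back to the smallest candidate


def sphere_sag(r, radius_of_curvature=10.0):
    """Sag of a sphere, used here only as an example profile."""
    return radius_of_curvature - math.sqrt(radius_of_curvature ** 2 - r ** 2)
```

For this sphere the constraint holds up to roughly r = 6.9 mm, so a proposed radius of 5 mm is kept while a proposed radius of 9 mm is reduced to the nearest grid point below that limit.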

correct_shape

correct_shape(mounting_margin=None)

Correct wrong lens shape during lens design optimization.

Applies correction rules to ensure valid lens geometry:
  1. Move the first surface to z = 0.0
  2. Prune all surfaces to allow valid rays through
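
Rule 1 is a rigid shift along the optical axis. A minimal sketch on plain numbers follows; `shift_first_surface_to_origin` is a hypothetical helper, and the listing below operates in place on `surf.d` and `d_sensor` instead.

```python
def shift_first_surface_to_origin(surface_z, d_sensor):
    """Shift all axial positions so that the first surface sits at z = 0."""
    move = surface_z[0]
    return [z - move for z in surface_z], d_sensor - move
```

All spacings between surfaces, and between the last surface and the sensor, are preserved by the shift.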

Parameters:

Name Type Description Default
mounting_margin float

Absolute mounting margin [mm] for surface pruning. Passed through to prune_surf.

None
Source code in deeplens-src/deeplens/geolens_pkg/optim_ops.py
@torch.no_grad()
def correct_shape(self, mounting_margin=None):
    """Correct wrong lens shape during lens design optimization.

    Applies correction rules to ensure valid lens geometry:
        1. Move the first surface to z = 0.0
        2. Prune all surfaces to allow valid rays through

    Args:
        mounting_margin (float, optional): Absolute mounting margin [mm] for
            surface pruning.  Passed through to `prune_surf`.
    """
    # Rule 1: Move the first surface to z = 0.0
    move_dist = self.surfaces[0].d.item()
    for surf in self.surfaces:
        surf.d -= move_dist
    self.d_sensor -= move_dist

    # Rule 2: Prune all surfaces
    self.prune_surf(mounting_margin=mounting_margin)
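
Rule 1 is a rigid shift of the whole system along the optical axis. A minimal sketch with plain floats instead of tensors (the function name `shift_to_origin` is hypothetical) shows that the spacings between surfaces and the sensor are unchanged:

```python
def shift_to_origin(surface_z, d_sensor):
    """Shift all axial positions so the first surface sits at z = 0."""
    move_dist = surface_z[0]
    shifted = [z - move_dist for z in surface_z]
    return shifted, d_sensor - move_dist


z, d_sensor = shift_to_origin([1.5, 4.0, 9.25], 30.0)
print(z, d_sensor)
```

Because every position moves by the same `move_dist`, the back focal distance measured from the last surface stays the same.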

match_materials

match_materials(mat_table='CDGM')

Match lens materials to a glass catalog.

Parameters:

Name Type Description Default
mat_table str

Glass catalog name. Common options include 'CDGM', 'SCHOTT', 'OHARA'. Defaults to 'CDGM'.

'CDGM'
Source code in deeplens-src/deeplens/geolens_pkg/optim_ops.py
@torch.no_grad()
def match_materials(self, mat_table="CDGM"):
    """Match lens materials to a glass catalog.

    Args:
        mat_table (str, optional): Glass catalog name. Common options include
            'CDGM', 'SCHOTT', 'OHARA'. Defaults to 'CDGM'.
    """
    for surf in self.surfaces:
        surf.mat2.match_material(mat_table=mat_table)
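
The matching criterion used by `match_material` is not shown on this page. As an illustration only, one common approach is a nearest-neighbour search in refractive index and Abbe number. The sketch below assumes that approach, a three-glass toy catalog (published Schott values), and a hypothetical weighting `v_weight` between the two terms:

```python
# Toy catalog: name -> (n_d, V_d). Real catalogs hold hundreds of glasses.
CATALOG = {
    "n-bk7": (1.5168, 64.17),
    "f2": (1.62004, 36.37),
    "n-sf11": (1.78472, 25.68),
}


def nearest_glass(n, v, catalog=CATALOG, v_weight=0.01):
    """Return the catalog glass closest to (n, v) under a weighted L1 cost."""

    def cost(item):
        n_cat, v_cat = item[1]
        return abs(n - n_cat) + v_weight * abs(v - v_cat)

    return min(catalog.items(), key=cost)[0]


print(nearest_glass(1.60, 40.0))
```

The weight matters because Abbe numbers span tens of units while indices differ in the second decimal place; without scaling, the Abbe term would dominate the match.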

deeplens.geolens_pkg.io.GeoLensIO

Mixin providing file I/O for GeoLens.

Supports reading and writing lens prescriptions in three formats:

  • JSON (primary): human-readable, supports parenthesised optimisable parameters, e.g. "(d)": 5.0.
  • Zemax .zmx: industry-standard sequential lens file.
  • Code V .seq: Code V sequential format (read and write).

This class is not instantiated directly; it is mixed into GeoLens.

read_lens_zmx

read_lens_zmx(filename='./test.zmx')

Load the lens from a Zemax .zmx sequential lens file.

Parses STANDARD and EVENASPH surface types, glass materials, field definitions (YFLN), and entrance pupil settings (ENPD/FLOA).

Parameters:

Name Type Description Default
filename str

Path to the .zmx file. Supports both UTF-8 and UTF-16 encoded files. Defaults to './test.zmx'.

'./test.zmx'

Returns:

Name Type Description
GeoLens

self, for method chaining.

Source code in deeplens-src/deeplens/geolens_pkg/io.py
def read_lens_zmx(self, filename="./test.zmx"):
    """Load the lens from a Zemax .zmx sequential lens file.

    Parses STANDARD and EVENASPH surface types, glass materials, field
    definitions (YFLN), and entrance pupil settings (ENPD/FLOA).

    Args:
        filename (str, optional): Path to the .zmx file. Supports both
            UTF-8 and UTF-16 encoded files. Defaults to './test.zmx'.

    Returns:
        GeoLens: ``self``, for method chaining.
    """
    # Read .zmx file
    try:
        with open(filename, "r", encoding="utf-8") as file:
            lines = file.readlines()
    except UnicodeDecodeError:
        with open(filename, "r", encoding="utf-16") as file:
            lines = file.readlines()

    # Iterate through the lines and extract SURF dict
    surfs_dict = {}
    current_surf = None
    for line in lines:
        # Strip leading/trailing whitespace for consistent parsing
        stripped_line = line.strip()

        if stripped_line.startswith("SURF"):
            current_surf = int(stripped_line.split()[1])
            surfs_dict[current_surf] = {}

        elif current_surf is not None and stripped_line != "":
            if len(stripped_line.split(maxsplit=1)) == 1:
                continue
            else:
                key, value = stripped_line.split(maxsplit=1)
                if key == "PARM":
                    new_key = "PARM" + value.split()[0]
                    new_value = value.split()[1]
                    surfs_dict[current_surf][new_key] = new_value
                else:
                    surfs_dict[current_surf][key] = value

        elif stripped_line.startswith("FLOA") or stripped_line.startswith("ENPD"):
            if stripped_line.startswith("FLOA"):
                self.float_enpd = True
                self.enpd = None
            else:
                self.float_enpd = False
                self.enpd = float(stripped_line.split()[1])

        elif stripped_line.startswith("YFLN"):
            # Parse field of view from YFLN line (field coordinates in degrees)
            # YFLN format: YFLN 0.0 <0.707*rfov_deg> <0.99*rfov_deg>
            parts = stripped_line.split()
            if len(parts) > 1:
                field_values = [abs(float(x)) for x in parts[1:] if float(x) != 0.0]
                if field_values:
                    # The largest field value is typically 0.99 * rfov_deg
                    max_field_deg = max(field_values) / 0.99
                    self.rfov_eff = (
                        max_field_deg * math.pi / 180.0
                    )  # Convert to radians

    self.float_foclen = False
    self.float_rfov = False
    # Set default rfov_eff if not parsed from file
    if not hasattr(self, "rfov_eff"):
        self.rfov_eff = None

    # Read the extracted data from each SURF
    self.surfaces = []
    d = 0.0
    mat1_name = "air"
    for surf_idx, surf_dict in surfs_dict.items():
        if surf_idx > 0 and surf_idx < current_surf:
            # Lens surface parameters
            if "GLAS" in surf_dict:
                if surf_dict["GLAS"].split()[0] == "___BLANK":
                    mat2_name = f"{surf_dict['GLAS'].split()[3]}/{surf_dict['GLAS'].split()[4]}"
                else:
                    mat2_name = surf_dict["GLAS"].split()[0].lower()
            else:
                mat2_name = "air"

            surf_r = (
                float(surf_dict["DIAM"].split()[0]) if "DIAM" in surf_dict else 1.0
            )
            surf_c = (
                float(surf_dict["CURV"].split()[0]) if "CURV" in surf_dict else 0.0
            )
            surf_d_next = (
                float(surf_dict["DISZ"].split()[0]) if "DISZ" in surf_dict else 0.0
            )
            surf_conic = float(surf_dict.get("CONI", 0.0))
            surf_param2 = float(surf_dict.get("PARM2", 0.0))
            surf_param3 = float(surf_dict.get("PARM3", 0.0))
            surf_param4 = float(surf_dict.get("PARM4", 0.0))
            surf_param5 = float(surf_dict.get("PARM5", 0.0))
            surf_param6 = float(surf_dict.get("PARM6", 0.0))
            surf_param7 = float(surf_dict.get("PARM7", 0.0))
            surf_param8 = float(surf_dict.get("PARM8", 0.0))

            # Create surface object
            if surf_dict["TYPE"] == "STANDARD":
                if mat2_name == "air" and mat1_name == "air":
                    # Aperture
                    s = Aperture(r=surf_r, d=d)
                else:
                    # Spherical surface
                    s = Spheric(c=surf_c, r=surf_r, d=d, mat2=mat2_name)

            elif surf_dict["TYPE"] == "EVENASPH":
                # Aspherical surface
                s = Aspheric(
                    c=surf_c,
                    r=surf_r,
                    d=d,
                    ai=[
                        surf_param2,
                        surf_param3,
                        surf_param4,
                        surf_param5,
                        surf_param6,
                        surf_param7,
                        surf_param8,
                    ],
                    k=surf_conic,
                    mat2=mat2_name,
                )

            else:
                print(f"Surface type {surf_dict['TYPE']} not implemented.")
                continue

            self.surfaces.append(s)
            d += surf_d_next
            mat1_name = mat2_name

        elif surf_idx == current_surf:
            # Image sensor
            self.r_sensor = float(surf_dict["DIAM"].split()[0])

        else:
            pass

    self.d_sensor = torch.tensor(d)
    return self
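
The first pass of the reader, which groups keyword lines under their `SURF` index and expands `PARM n value` into `PARMn`, can be tried on an in-memory string. This is a simplified sketch: the fragment below is hand-written for illustration and is not a complete .zmx file, and `parse_surf_blocks` is a hypothetical helper name.

```python
ZMX_TEXT = """\
SURF 0
  TYPE STANDARD
  DISZ INFINITY
SURF 1
  TYPE EVENASPH
  CURV 0.05 0 0 0 0
  PARM 2 1.0E-4
  DISZ 3.0
  GLAS N-BK7 0 0 1.5168 64.17
SURF 2
  TYPE STANDARD
  DIAM 5.0 0 0 0 1
"""


def parse_surf_blocks(text):
    """Group 'KEY value' lines under the SURF index they belong to."""
    surfs, current = {}, None
    for raw in text.splitlines():
        line = raw.strip()
        if line.startswith("SURF"):
            current = int(line.split()[1])
            surfs[current] = {}
        elif current is not None and line:
            parts = line.split(maxsplit=1)
            if len(parts) == 1:  # keyword without a value
                continue
            key, value = parts
            if key == "PARM":  # "PARM 2 1.0E-4" becomes {"PARM2": "1.0E-4"}
                idx, val = value.split()[:2]
                surfs[current]["PARM" + idx] = val
            else:
                surfs[current][key] = value
    return surfs


surfs = parse_surf_blocks(ZMX_TEXT)
print(surfs[1]["TYPE"], surfs[1]["PARM2"], surfs[1]["GLAS"].split()[0].lower())
```

Values stay as strings at this stage; the second pass converts only the first token of each value with `float(...)`.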

write_lens_zmx

write_lens_zmx(filename='./test.zmx')

Write the lens to a Zemax .zmx sequential lens file.

Exports surfaces (STANDARD or EVENASPH), materials, field definitions, and entrance pupil settings in Zemax OpticStudio format.

Parameters:

Name Type Description Default
filename str

Output file path. Defaults to './test.zmx'.

'./test.zmx'
Source code in deeplens-src/deeplens/geolens_pkg/io.py
def write_lens_zmx(self, filename="./test.zmx"):
    """Write the lens to a Zemax .zmx sequential lens file.

    Exports surfaces (STANDARD or EVENASPH), materials, field definitions,
    and entrance pupil settings in Zemax OpticStudio format.

    Args:
        filename (str, optional): Output file path. Defaults to './test.zmx'.
    """
    lens_zmx_str = ""
    if self.float_enpd:
        enpd_str = "FLOA"
    else:
        enpd_str = f"ENPD {self.enpd}"
    # Head string
    head_str = f"""VERS 190513 80 123457 L123457
MODE SEQ
NAME 
PFIL 0 0 0
LANG 0
UNIT MM X W X CM MR CPMM
{enpd_str}
ENVD 2.0E+1 1 0
GFAC 0 0
GCAT OSAKAGASCHEMICAL MISC
XFLN 0. 0. 0.
YFLN 0.0 {0.707 * math.degrees(self.rfov_eff)} {0.99 * math.degrees(self.rfov_eff)}
WAVL {self.wvln_rgb[2]:.7f} {self.wvln_rgb[1]:.7f} {self.wvln_rgb[0]:.7f}
RAIM 0 0 1 1 0 0 0 0 0
PUSH 0 0 0 0 0 0
SDMA 0 1 0
FTYP 0 0 3 3 0 0 0
ROPD 2
PICB 1
PWAV 2
POLS 1 0 1 0 0 1 0
GLRS 1 0
GSTD 0 100.000 100.000 100.000 100.000 100.000 100.000 0 1 1 0 0 1 1 1 1 1 1
NSCD 100 500 0 1.0E-3 5 1.0E-6 0 0 0 0 0 0 1000000 0 2
COFN QF "COATING.DAT" "SCATTER_PROFILE.DAT" "ABG_DATA.DAT" "PROFILE.GRD"
COFN COATING.DAT SCATTER_PROFILE.DAT ABG_DATA.DAT PROFILE.GRD
SURF 0
TYPE STANDARD
CURV 0.0
DISZ INFINITY
"""
    lens_zmx_str += head_str

    # Surface string
    for i, s in enumerate(self.surfaces):
        d_next = (
            self.surfaces[i + 1].d - self.surfaces[i].d
            if i < len(self.surfaces) - 1
            else self.d_sensor - self.surfaces[i].d
        )
        surf_str = s.zmx_str(surf_idx=i + 1, d_next=d_next)
        lens_zmx_str += surf_str

    # Sensor string
    sensor_str = f"""SURF {len(self.surfaces) + 1}
TYPE STANDARD
CURV 0.
DISZ 0.0
DIAM {self.r_sensor}
"""
    lens_zmx_str += sensor_str

    # Write lens zmx string into file
    with open(filename, "w") as f:
        f.write(lens_zmx_str)
        print(f"Lens written to {filename}")
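
The `YFLN` line written here stores three field angles in degrees: 0, 0.707 and 0.99 times the half field of view. `read_lens_zmx` recovers the half field of view by dividing the largest value by 0.99. A small round-trip sketch of that convention (the function names are hypothetical, and the conversion uses `math.degrees` / `math.radians`):

```python
import math


def yfln_fields(rfov_rad):
    """Field angles in degrees as written to the YFLN line."""
    rfov_deg = math.degrees(rfov_rad)
    return [0.0, 0.707 * rfov_deg, 0.99 * rfov_deg]


def rfov_from_yfln(fields_deg):
    """Invert the convention: the largest field is 0.99 * rfov."""
    nonzero = [abs(v) for v in fields_deg if v != 0.0]
    return math.radians(max(nonzero) / 0.99) if nonzero else None


fields = yfln_fields(math.radians(30.0))
rfov = rfov_from_yfln(fields)
print(fields, math.degrees(rfov))
```

Files produced by other tools may not follow the 0.99 scaling, in which case the recovered angle is about 1% larger than the largest field actually listed.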

read_lens_seq

read_lens_seq(filename='./test.seq')

Load the lens from a CODE V .seq sequential file.

Parses standard and aspheric surfaces (with conic and polynomial coefficients A–I), entrance pupil diameter (EPD), field angles (YAN), aperture stop (STO), and image surface (SI).

Parameters:

Name Type Description Default
filename str

Path to the .seq file. Supports both UTF-8 and Latin-1 encoded files. Defaults to './test.seq'.

'./test.seq'

Returns:

Name Type Description
GeoLens

self, for method chaining.

Source code in deeplens-src/deeplens/geolens_pkg/io.py
def read_lens_seq(self, filename="./test.seq"):
    """Load the lens from a CODE V .seq sequential file.

    Parses standard and aspheric surfaces (with conic and polynomial
    coefficients A–I), entrance pupil diameter (EPD), field angles (YAN),
    aperture stop (STO), and image surface (SI).

    Args:
        filename (str, optional): Path to the .seq file. Supports both
            UTF-8 and Latin-1 encoded files. Defaults to './test.seq'.

    Returns:
        GeoLens: ``self``, for method chaining.
    """
    print(f"\n{'=' * 60}")
    print(f"Start reading CODE V file: {filename}")
    print(f"{'=' * 60}\n")

    # Read .seq file
    try:
        with open(filename, "r", encoding="utf-8") as file:
            lines = file.readlines()
        print("File read successfully (UTF-8)")
    except UnicodeDecodeError:
        try:
            with open(filename, "r", encoding="latin-1") as file:
                lines = file.readlines()
            print("File read successfully (Latin-1)")
        except Exception as e:
            print(f"Failed to read file: {e}")
            return self
    print(f"Total lines: {len(lines)}\n")

    # ============ Step 1: Parse file structure ============
    surfaces = []
    current_surface = {}
    surface_index = 0
    global_diameter = None

    print("Beginning to parse surface data...\n")

    for line_num, line in enumerate(lines, 1):
        line = line.strip()

        # Skip irrelevant lines
        if not line or line.startswith(
            (
                "RDM",
                "TITLE",
                "UID",
                "GO",
                "WL",
                "XAN",
                "REF",
                "WTW",
                "INI",
                "WTF",
                "VUY",
                "VLY",
                "DOR",
                "DIM",
                "THC",
            )
        ):
            continue
        # Read entrance pupil diameter
        if line.startswith("EPD"):
            self.enpd = float(line.split()[1])
            self.float_enpd = False
            global_diameter = self.enpd / 2.0
            print(
                f"[Line {line_num}] EPD={self.enpd} -> default radius={global_diameter}"
            )
            continue
        # Read field of view angle
        if line.startswith("YAN"):
            angles = [abs(float(x)) for x in line.split()[1:] if float(x) != 0.0]
            if angles:
                self.hfov = max(angles)
                # Also set rfov in radians for consistency with write functions
                self.rfov_eff = self.hfov * math.pi / 180.0
                print(f"[Line {line_num}] Max field of view={self.hfov} deg")
            continue
        # Object surface
        if line.startswith("SO"):
            parts = line.split()
            thickness = float(parts[2]) if len(parts) > 2 else 1e10

            current_surface = {
                "type": "OBJECT",
                "thickness": thickness,
                "index": surface_index,
            }
            surfaces.append(current_surface)
            print(f"[Line {line_num}] Object surface: T={thickness}")
            surface_index += 1
            current_surface = {}
            continue
        # Standard surface
        if line.startswith("S "):
            # Save the previous surface
            if current_surface:
                surfaces.append(current_surface)
                surface_index += 1

            parts = line.split()
            radius_value = float(parts[1]) if len(parts) > 1 else 0.0
            thickness = float(parts[2]) if len(parts) > 2 else 0.0
            material = parts[3].upper() if len(parts) > 3 else "AIR"

            # Key: compute curvature C = 1/R
            if abs(radius_value) > 1e-10:
                curvature = 1.0 / radius_value
            else:
                curvature = 0.0

            current_surface = {
                "type": "STANDARD",
                "radius": radius_value,
                "curvature": curvature,
                "thickness": thickness,
                "material": material,
                "index": surface_index,
                "diameter": global_diameter,
                "conic": 0.0,
                "asph_coeffs": {},
                "is_stop": False,
            }

            print(
                f"[Line {line_num}] Surface{surface_index}: R={radius_value:.4f} → C={curvature:.6f}, T={thickness}, Mat={material}"
            )
            continue
        # Image surface - do not append yet, wait for CIR
        if line.startswith("SI"):
            if current_surface:
                surfaces.append(current_surface)
                surface_index += 1

            parts = line.split()
            thickness = float(parts[1]) if len(parts) > 1 else 0.0

            current_surface = {
                "type": "IMAGE",
                "thickness": thickness,
                "diameter": None,  # Set to None first, wait for CIR line to update
                "index": surface_index,
            }
            print(f"[Line {line_num}] Image surface")
            continue
        # Handle surface attributes (CIR, STO, ASP, K, A~J, etc.)
        if current_surface:
            if line.startswith("CIR"):
                current_surface["diameter"] = float(
                    line.split()[1].replace(";", "")
                )
                print(f"[Line {line_num}]   → CIR={current_surface['diameter']}")

            elif line.startswith("STO"):
                current_surface["is_stop"] = True
                print(f"[Line {line_num}]   → Aperture stop flag")

            elif line.startswith("ASP"):
                current_surface["type"] = "ASPHERIC"
                print(f"[Line {line_num}]   → Aspheric surface")

            elif line.startswith("K "):
                current_surface["conic"] = float(line.split()[1].replace(";", ""))
                print(f"[Line {line_num}]   → K={current_surface['conic']}")

            # Only extract single-letter coefficients A-J
            elif any(
                line.startswith(p)
                for p in [
                    "A ",
                    "B ",
                    "C ",
                    "D ",
                    "E ",
                    "F ",
                    "G ",
                    "H ",
                    "I ",
                    "J ",
                ]
            ):
                parts = line.replace(";", "").split()
                i = 0
                while i < len(parts) - 1:
                    try:
                        key = parts[i]
                        # Only accept single letters within the range A-J
                        if len(key) == 1 and key in [
                            "A",
                            "B",
                            "C",
                            "D",
                            "E",
                            "F",
                            "G",
                            "H",
                            "I",
                            "J",
                        ]:
                            value = float(parts[i + 1])
                            current_surface["asph_coeffs"][key] = value
                            print(f"[Line {line_num}]   → {key}={value}")
                        i += 2
                    except ValueError:
                        i += 1

    # Save the last surface
    if current_surface:
        surfaces.append(current_surface)

    print(f"\nParsing complete, total {len(surfaces)} surfaces\n")

    # ============ Step 2: Create surface objects ============
    print(f"{'=' * 60}")
    print("Start creating surface objects:")
    print(f"{'=' * 60}\n")

    self.surfaces = []
    d = 0.0  # Cumulative distance from the first optical surface to the current surface
    previous_material = "air"

    for surf in surfaces:
        surf_idx = surf["index"]
        surf_type = surf["type"]

        print(f"{'=' * 50}")
        print(f"Processing surface{surf_idx} ({surf_type}), current d={d:.4f}")

        # Handle object surface
        if surf_type == "OBJECT":
            obj_thickness = surf["thickness"]
            if obj_thickness < 1e9:  # Finite object distance
                d += obj_thickness
                print(
                    f"   Object surface thickness={obj_thickness} → accumulated d={d:.4f}"
                )
            else:
                print("   Object surface at infinity")
            previous_material = "air"
            continue

        # Handle image surface
        if surf_type == "IMAGE":
            self.d_sensor = torch.tensor(d)
            # Read diameter from surf dictionary (CIR value)
            self.r_sensor = (
                surf.get("diameter") if surf.get("diameter") is not None else 18.0
            )
            print(
                f"   Image plane position: d_sensor={d:.4f}, r_sensor={self.r_sensor:.4f}"
            )
            break

        # Get surface parameters
        current_material = surf.get("material", "AIR")
        if current_material in ["AIR", "0.0", "", None]:
            current_material = "air"
        else:
            current_material = current_material.lower()

        c = surf.get("curvature", 0.0)
        r = surf.get("diameter")
        if r is None:  # key exists but is None when the file has no EPD or CIR
            r = 10.0
        d_next = surf.get("thickness", 0.0)
        is_stop = surf.get("is_stop", False)

        print(f"   C={c:.6f}, R_aperture={r:.4f}, T={d_next:.4f}")
        print(f"   Material: {previous_material} -> {current_material}")
        print(f"   is_stop={is_stop}")

        # Create surface object
        try:
            # Case 1: pure aperture (air on both sides + STO flag)
            if is_stop and current_material == "air" and previous_material == "air":
                aperture = Aperture(r=r, d=d)
                self.surfaces.append(aperture)
                print(f"   Created pure aperture: Aperture(r={r:.4f}, d={d:.4f})")

            # Case 2: refractive surface (material change)
            elif current_material != previous_material:
                if surf_type == "STANDARD":
                    s = Spheric(c=c, r=r, d=d, mat2=current_material)
                    self.surfaces.append(s)
                    status = " (stop surface)" if is_stop else ""
                    print(
                        f"   Created spherical surface{status}: Spheric(c={c:.6f}, r={r:.4f}, d={d:.4f}, mat2='{current_material}')"
                    )

                elif surf_type == "ASPHERIC":
                    k = surf.get("conic", 0.0)
                    asph_coeffs = surf.get("asph_coeffs", {})

                    # CODE V aspheric coefficient mapping (shift forward by one position):
                    # A → ai[1] (2nd term, ρ²)
                    # B → ai[2] (4th term, ρ⁴)
                    # C → ai[3] (6th term, ρ⁶)
                    # D → ai[4] (8th term, ρ⁸)
                    # E → ai[5] (10th term, ρ¹⁰)
                    # F → ai[6] (12th term, ρ¹²)
                    # G → ai[7] (14th term, ρ¹⁴)
                    # H → ai[8] (16th term, ρ¹⁶)
                    # I → ai[9] (18th term, ρ¹⁸)

                    # Initialize ai array (10 elements)
                    ai = [0.0] * 10
                    ai[0] = 0.0  # ρ⁰ term (unused)
                    ai[1] = asph_coeffs.get("A", 0.0)  # ρ²
                    ai[2] = asph_coeffs.get("B", 0.0)  # ρ⁴
                    ai[3] = asph_coeffs.get("C", 0.0)  # ρ⁶
                    ai[4] = asph_coeffs.get("D", 0.0)  # ρ⁸
                    ai[5] = asph_coeffs.get("E", 0.0)  # ρ¹⁰
                    ai[6] = asph_coeffs.get("F", 0.0)  # ρ¹²
                    ai[7] = asph_coeffs.get("G", 0.0)  # ρ¹⁴
                    ai[8] = asph_coeffs.get("H", 0.0)  # ρ¹⁶
                    ai[9] = asph_coeffs.get("I", 0.0)  # ρ¹⁸

                    s = Aspheric(c=c, r=r, d=d, ai=ai, k=k, mat2=current_material)
                    self.surfaces.append(s)
                    status = " (stop surface)" if is_stop else ""
                    print(
                        f"   Created aspheric surface{status}: Aspheric(c={c:.6f}, r={r:.4f}, d={d:.4f}, k={k}, mat2='{current_material}')"
                    )
                    if any(
                        ai[1:]
                    ):  # If there are non-zero higher-order terms (starting from ai[1])
                        print(
                            f"      Aspheric coefficients: A={ai[1]:.2e}, B={ai[2]:.2e}, C={ai[3]:.2e}, D={ai[4]:.2e}"
                        )

            else:
                print(f"   Skipped (same material on both sides and no stop flag)")

        except Exception as e:
            print(f"   Failed to create surface: {e}")
            import traceback

            traceback.print_exc()

        # Key: accumulate distance at the end of the loop
        d += d_next
        print(f"   After accumulation: d={d:.4f}")
        previous_material = current_material

    print(f"\n{'=' * 60}")
    print(f"   Done! Created {len(self.surfaces)} objects")
    print(f"   d_sensor={self.d_sensor:.4f}")
    print(f"   r_sensor={self.r_sensor:.4f}")
    if hasattr(self, "hfov"):  # only set when the file contains a YAN line
        print(f"   hfov={self.hfov:.4f}°")
    print(f"{'=' * 60}\n")

    return self
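
Two parsing steps in this reader are easy to check in isolation: turning an `S <radius> <thickness> [material]` record into a curvature, and reading `;`-separated aspheric coefficients. The sketch below mirrors that logic in simplified form. The helper names are hypothetical, and the coefficient parser pairs tokens with `zip` instead of the reader's index loop, so it does not resynchronise after a malformed token.

```python
def parse_s_line(line):
    """Parse 'S <radius> <thickness> [material]' into a surface dict."""
    parts = line.split()
    radius = float(parts[1]) if len(parts) > 1 else 0.0
    thickness = float(parts[2]) if len(parts) > 2 else 0.0
    material = parts[3].upper() if len(parts) > 3 else "AIR"
    # A radius of 0 denotes a flat surface, so the curvature is 0 as well.
    curvature = 1.0 / radius if abs(radius) > 1e-10 else 0.0
    return {"radius": radius, "curvature": curvature,
            "thickness": thickness, "material": material}


def parse_coeffs(line):
    """Parse 'A 1e-4; B -2e-6' into {'A': 1e-4, 'B': -2e-6}."""
    parts = line.replace(";", "").split()
    coeffs = {}
    for key, val in zip(parts[0::2], parts[1::2]):
        if len(key) == 1 and key in "ABCDEFGHIJ":
            try:
                coeffs[key] = float(val)
            except ValueError:
                pass  # skip tokens that are not numbers
    return coeffs


surf = parse_s_line("S 25.0 3.5 NBK7")
flat = parse_s_line("S 0.0 2.0")
coeffs = parse_coeffs("A 1.0e-4; B -2.0e-6")
print(surf, flat["curvature"], coeffs)
```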

write_lens_seq

write_lens_seq(filename='./test.seq')

Write the lens to a CODE V .seq sequential file.

Exports surfaces, materials, field definitions, and entrance pupil settings in CODE V format.

Parameters:

Name Type Description Default
filename str

Output file path. Defaults to './test.seq'.

'./test.seq'

Returns:

Name Type Description
GeoLens

self, for method chaining.

Source code in deeplens-src/deeplens/geolens_pkg/io.py
def write_lens_seq(self, filename="./test.seq"):
    """Write the lens to a CODE V .seq sequential file.

    Exports surfaces, materials, field definitions, and entrance pupil
    settings in CODE V format.

    Args:
        filename (str, optional): Output file path. Defaults to './test.seq'.

    Returns:
        GeoLens: ``self``, for method chaining.
    """

    import datetime

    current_date = datetime.datetime.now().strftime("%d-%b-%Y")

    head_str = f"""RDM;LEN       "VERSION: 2023.03       LENS VERSION: 89       Creation Date:  {current_date}"
TITLE 'Lens Design'
EPD   {self.enpd}
DIM   M
WL    650.0 550.0 480.0
REF   2
WTW   1 2 1
INI   '   '
XAN   0.0 0.0 0.0
YAN   0.0  {0.707 * math.degrees(self.rfov_eff)} {0.99 * math.degrees(self.rfov_eff)}
WTF   1.0 1.0 1.0
VUY   0.0 0.0 0.0
VLY   0.0 0.0 0.0
DOR   1.15 1.05
SO    0.0 0.1e14
"""

    lens_seq_str = head_str
    previous_material = "air"

    for i, surf in enumerate(self.surfaces):
        if i < len(self.surfaces) - 1:
            d_next = float(self.surfaces[i + 1].d - surf.d)
        else:
            d_next = float(self.d_sensor - surf.d)

        current_material = getattr(surf, "mat2", "air")

        if current_material is None or current_material == "air":
            material_str = ""
            material_name = "air"
        elif isinstance(current_material, str):
            material_str = f" {current_material.upper()}"
            material_name = current_material
        else:
            material_name = getattr(current_material, "name", str(current_material))
            material_str = f" {material_name.upper()}"

        is_aperture = surf.__class__.__name__ == "Aperture"

        if is_aperture:
            continue

        is_aspheric = surf.__class__.__name__ == "Aspheric"
        is_stop_surface = getattr(surf, "is_stop", False)

        if is_aspheric:
            if abs(surf.c) > 1e-10:
                radius = 1.0 / surf.c
            else:
                radius = 0.0

            k = surf.k if hasattr(surf, "k") else 0.0
            ai = surf.ai if hasattr(surf, "ai") else [0.0] * 10

            surf_str = f"S     {radius} {d_next}{material_str}\n"
            surf_str += f"  CCY 0; THC 0\n"
            surf_str += f"  CIR {surf.r}\n"
            if is_stop_surface:
                surf_str += f"  STO\n"
            surf_str += f"  ASP\n"
            surf_str += f"  K   {k}\n"

            if len(ai) > 4 and any(ai[1:5]):
                surf_str += f"  A   {ai[1]:.16e}; B {ai[2]:.16e}; C&\n"
                surf_str += f"   {ai[3]:.16e}; D {ai[4]:.16e}\n"

            if len(ai) > 8 and any(ai[5:9]):
                surf_str += f"  E   {ai[5]:.16e}; F {ai[6]:.16e}; G {ai[7]:.16e}; H {ai[8]:.16e}\n"

        else:
            if abs(surf.c) > 1e-10:
                radius = 1.0 / surf.c
            else:
                radius = 0.0

            surf_str = f"S     {radius} {d_next}{material_str}\n"
            surf_str += f"  CCY 0; THC 0\n"

            if is_stop_surface:
                surf_str += f"  STO\n"

            surf_str += f"  CIR {surf.r}\n"

        lens_seq_str += surf_str
        previous_material = material_name

    sensor_str = f"SI    0.0 0.0\n"
    sensor_str += f"  CIR {self.r_sensor}\n"
    lens_seq_str += sensor_str
    lens_seq_str += "GO \n"

    with open(filename, "w") as f:
        f.write(lens_seq_str)

    print(f"Lens written to CODE V file: {filename}")
    return self
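
The record emitted for a non-aspheric surface can be reproduced with a few lines of string formatting. This sketch uses plain floats and a hypothetical helper `seq_surface_record`; it follows the same conventions as the writer: radius is `1 / c`, a flat surface is written with radius 0.0, and air is written as an empty material field.

```python
def seq_surface_record(c, d_next, r, material="", is_stop=False):
    """Format one spherical surface as a CODE V style record."""
    radius = 1.0 / c if abs(c) > 1e-10 else 0.0
    mat = f" {material.upper()}" if material and material != "air" else ""
    lines = [f"S     {radius} {d_next}{mat}", "  CCY 0; THC 0"]
    if is_stop:
        lines.append("  STO")
    lines.append(f"  CIR {r}")
    return "\n".join(lines) + "\n"


record = seq_surface_record(0.125, 3.5, 3.0, "n-bk7", is_stop=True)
flat = seq_surface_record(0.0, 1.0, 3.0)
print(record, flat, sep="")
```

Converting tensors to Python floats before formatting matters here: an unconverted value would be written with its tensor representation instead of a bare number.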

read_lens_json

read_lens_json(filename='./test.json')

Read the lens from a JSON file.

Loads lens configuration including surfaces, materials, and optical properties from the DeepLens native JSON format.

Parameters:

Name Type Description Default
filename str

Path to the JSON lens file. Defaults to './test.json'.

'./test.json'
Note

After loading, the lens is moved to self.device.

Source code in deeplens-src/deeplens/geolens_pkg/io.py
def read_lens_json(self, filename="./test.json"):
    """Read the lens from a JSON file.

    Loads lens configuration including surfaces, materials, and optical properties
    from the DeepLens native JSON format.

    Args:
        filename (str, optional): Path to the JSON lens file. Defaults to './test.json'.

    Note:
        After loading, the lens is moved to self.device.
    """
    self.surfaces = []
    self.materials = []
    with open(filename, "r") as f:
        data = json.load(f)
        d = 0.0
        for idx, surf_dict in enumerate(data["surfaces"]):
            surf_dict["d"] = d
            surf_dict["surf_idx"] = idx

            if surf_dict["type"] == "Aperture":
                s = Aperture.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Aspheric":
                s = Aspheric.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Cubic":
                s = Cubic.init_from_dict(surf_dict)

            # elif surf_dict["type"] == "GaussianRBF":
            #     s = GaussianRBF.init_from_dict(surf_dict)

            # elif surf_dict["type"] == "NURBS":
            #     s = NURBS.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Phase":
                s = Phase.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Binary2Phase":
                s = Binary2Phase.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Plane":
                s = Plane.init_from_dict(surf_dict)

            # elif surf_dict["type"] == "PolyEven":
            #     s = PolyEven.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Stop":
                s = Aperture.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Spheric":
                s = Spheric.init_from_dict(surf_dict)

            elif surf_dict["type"] == "ThinLens":
                s = ThinLens.init_from_dict(surf_dict)

            else:
                raise NotImplementedError(
                    f"Surface type {surf_dict['type']} is not implemented in GeoLens.read_lens_json()."
                )

            s.is_aperture = bool(surf_dict.get("is_aperture", False))
            self.surfaces.append(s)
            d += surf_dict["d_next"]

    self.d_sensor = torch.tensor(d)
    self.lens_info = data.get("info", "None")
    self.enpd = data.get("enpd", None)
    self.float_enpd = self.enpd is None
    self.float_foclen = False
    self.float_rfov = False
    self.r_sensor = data["r_sensor"]

    self.to(self.device)

    # Set sensor size and resolution
    sensor_res = data.get("sensor_res", (2000, 2000))
    self.set_sensor_res(sensor_res=sensor_res)
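
The JSON format stores only the spacing to the next element (`d_next`); absolute positions are rebuilt by a running sum while loading, and the final sum becomes `d_sensor`. A minimal sketch of that accumulation follows. The JSON fragment is reduced to the keys used here (real files carry curvature, radius, material and more), and `surface_positions` is a hypothetical helper.

```python
import json

LENS_JSON = """
{"r_sensor": 4.0,
 "surfaces": [
   {"type": "Spheric",  "d_next": 2.0},
   {"type": "Spheric",  "d_next": 1.5},
   {"type": "Aperture", "d_next": 10.0}
 ]}
"""


def surface_positions(data):
    """Return absolute z of each surface and the sensor position."""
    d, positions = 0.0, []
    for surf in data["surfaces"]:
        positions.append(d)  # this surface sits at the running distance
        d += surf["d_next"]  # advance to the next surface (or the sensor)
    return positions, d


positions, d_sensor = surface_positions(json.loads(LENS_JSON))
print(positions, d_sensor)
```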

write_lens_json

write_lens_json(filename='./test.json')

Write the lens to a JSON file.

Saves the complete lens configuration including all surfaces, materials, focal length, F-number, and sensor properties to the DeepLens JSON format.

Parameters:

Name Type Description Default
filename str

Path for the output JSON file. Defaults to './test.json'.

'./test.json'
Source code in deeplens-src/deeplens/geolens_pkg/io.py
def write_lens_json(self, filename="./test.json"):
    """Write the lens to a JSON file.

    Saves the complete lens configuration including all surfaces, materials,
    focal length, F-number, and sensor properties to the DeepLens JSON format.

    Args:
        filename (str, optional): Path for the output JSON file. Defaults to './test.json'.
    """
    data = {}
    data["info"] = self.lens_info if hasattr(self, "lens_info") else "None"
    data["foclen"] = round(self.foclen, 4)
    data["fnum"] = round(self.fnum, 4)
    if self.float_enpd is False:
        data["enpd"] = round(self.enpd, 4)
    data["r_sensor"] = self.r_sensor
    data["(d_sensor)"] = round(self.d_sensor.item(), 4)
    data["(sensor_size)"] = [round(i, 4) for i in self.sensor_size]
    data["sensor_res"] = list(self.sensor_res)
    data["surfaces"] = []
    for i, s in enumerate(self.surfaces):
        surf_dict = {"idx": i}
        surf_dict.update(s.surf_dict())
        if getattr(s, "is_aperture", False):
            surf_dict["is_aperture"] = True
        if i < len(self.surfaces) - 1:
            surf_dict["d_next"] = round(
                self.surfaces[i + 1].d.item() - self.surfaces[i].d.item(), 4
            )
        else:
            surf_dict["d_next"] = round(
                self.d_sensor.item() - self.surfaces[i].d.item(), 4
            )

        data["surfaces"].append(surf_dict)

    with open(filename, "w") as f:
        json.dump(data, f, indent=4)
    print(f"Lens written to {filename}")
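
The `d_next` field written above is the axial gap to the following surface (or to the sensor for the last surface), and `read_lens_json` rebuilds absolute positions by accumulating those gaps. The round trip can be sketched with plain floats instead of tensors. The helper names `positions_to_gaps` and `gaps_to_positions` are hypothetical and not part of DeepLens, and a first surface at z = 0 is assumed.

```python
def positions_to_gaps(surface_z, sensor_z, ndigits=4):
    """Convert absolute axial positions [mm] into per-surface d_next gaps."""
    # The "next" position of the last surface is the sensor plane.
    next_z = list(surface_z[1:]) + [sensor_z]
    return [round(b - a, ndigits) for a, b in zip(surface_z, next_z)]


def gaps_to_positions(gaps, first_z=0.0):
    """Accumulate d_next gaps back into surface positions and the sensor position."""
    positions, z = [], first_z
    for gap in gaps:
        positions.append(z)
        z += gap
    return positions, z


# Illustrative values only (not taken from a real lens file).
surface_z = [0.0, 1.5, 2.0, 4.25]
sensor_z = 10.0
gaps = positions_to_gaps(surface_z, sensor_z)
positions, rebuilt_sensor_z = gaps_to_positions(gaps)
```

Because the gaps are rounded to four decimals on write, a real lens may not round-trip bit-exactly; the values here are chosen so that it does.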

deeplens.geolens_pkg.vis.GeoLensVis

Mixin providing 2-D lens layout and ray visualisation for GeoLens.

Generates publication-quality cross-section plots showing lens surfaces and traced ray bundles in either the meridional or sagittal plane.

This class is not instantiated directly; it is mixed into GeoLens.

sample_parallel_2D

sample_parallel_2D(fov=0.0, num_rays=7, wvln=None, plane='meridional', entrance_pupil=True, depth=0.0)

Sample parallel rays (2D) in object space.

Used for (1) drawing the lens setup and (2) 2D geometric-optics calculations, for example refocusing to infinity.

Parameters:

Name Type Description Default
fov float

incident angle (in degree). Defaults to 0.0.

0.0
depth float

sampling depth. Defaults to 0.0.

0.0
num_rays int

ray number. Defaults to 7.

7
wvln float

ray wvln in µm. When None (default), falls back to self.primary_wvln.

None
plane str

sampling plane. Defaults to "meridional" (y-z plane).

'meridional'
entrance_pupil bool

whether to use entrance pupil. Defaults to True.

True

Returns:

Name Type Description
ray Ray object

Ray object. Shape [num_rays, 3]

Source code in deeplens-src/deeplens/geolens_pkg/vis.py
@torch.no_grad()
def sample_parallel_2D(
    self,
    fov=0.0,
    num_rays=7,
    wvln=None,
    plane="meridional",
    entrance_pupil=True,
    depth=0.0,
):
    """Sample parallel rays (2D) in object space.

    Used for (1) drawing the lens setup and (2) 2D geometric-optics calculations, for example refocusing to infinity.

    Args:
        fov (float, optional): incident angle (in degree). Defaults to 0.0.
        depth (float, optional): sampling depth. Defaults to 0.0.
        num_rays (int, optional): ray number. Defaults to 7.
        wvln (float, optional): ray wvln in µm. When ``None`` (default),
            falls back to ``self.primary_wvln``.
        plane (str, optional): sampling plane. Defaults to "meridional" (y-z plane).
        entrance_pupil (bool, optional): whether to use entrance pupil. Defaults to True.

    Returns:
        ray (Ray object): Ray object. Shape [num_rays, 3]
    """
    wvln = self.primary_wvln if wvln is None else wvln
    # Sample points on the pupil
    if entrance_pupil:
        pupilz, pupilr = self.get_entrance_pupil()
    else:
        pupilz, pupilr = self.surfaces[0].d.item(), self.surfaces[0].r

    # Sample ray origins, shape [num_rays, 3]
    if plane == "sagittal":
        ray_o = torch.stack(
            (
                torch.linspace(-pupilr, pupilr, num_rays) * 0.99,
                torch.full((num_rays,), 0),
                torch.full((num_rays,), pupilz),
            ),
            axis=-1,
        )
    elif plane == "meridional":
        ray_o = torch.stack(
            (
                torch.full((num_rays,), 0),
                torch.linspace(-pupilr, pupilr, num_rays) * 0.99,
                torch.full((num_rays,), pupilz),
            ),
            axis=-1,
        )
    else:
        raise ValueError(f"Invalid plane: {plane}")

    # Sample ray directions, shape [num_rays, 3]
    if plane == "sagittal":
        ray_d = torch.stack(
            (
                torch.full((num_rays,), float(np.sin(np.deg2rad(fov)))),
                torch.zeros((num_rays,)),
                torch.full((num_rays,), float(np.cos(np.deg2rad(fov)))),
            ),
            axis=-1,
        )
    elif plane == "meridional":
        ray_d = torch.stack(
            (
                torch.zeros((num_rays,)),
                torch.full((num_rays,), float(np.sin(np.deg2rad(fov)))),
                torch.full((num_rays,), float(np.cos(np.deg2rad(fov)))),
            ),
            axis=-1,
        )
    else:
        raise ValueError(f"Invalid plane: {plane}")

    # Form rays and propagate to the target depth
    rays = Ray(ray_o, ray_d, wvln, device=self.device)
    rays.prop_to(depth)
    return rays
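
The sampling geometry above can be reproduced without PyTorch. The following is a minimal sketch, assuming the pupil position and radius are already known; `parallel_rays_2d` is a hypothetical name, and the final `prop_to(depth)` step of the real method is omitted.

```python
import math


def parallel_rays_2d(fov_deg, pupil_z, pupil_r, num_rays=7, plane="meridional"):
    """Return (origins, directions) as lists of (x, y, z) tuples.

    Requires num_rays >= 2. Mirrors the layout used above: the "sagittal"
    branch varies x, the "meridional" branch varies y.
    """
    if plane not in ("meridional", "sagittal"):
        raise ValueError(f"Invalid plane: {plane}")

    # Evenly spaced heights across 99% of the pupil diameter.
    step = 2.0 * pupil_r / (num_rays - 1)
    heights = [(-pupil_r + i * step) * 0.99 for i in range(num_rays)]

    # All rays share one unit direction tilted by the field angle.
    s = math.sin(math.radians(fov_deg))
    c = math.cos(math.radians(fov_deg))
    if plane == "sagittal":
        origins = [(h, 0.0, pupil_z) for h in heights]
        direction = (s, 0.0, c)
    else:
        origins = [(0.0, h, pupil_z) for h in heights]
        direction = (0.0, s, c)
    return origins, [direction] * num_rays


origins, directions = parallel_rays_2d(10.0, pupil_z=0.0, pupil_r=2.0, num_rays=5)
```

Since the direction is built from the sine and cosine of the same angle, it is already unit length and needs no further normalisation.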

sample_point_source_2D

sample_point_source_2D(fov=0.0, depth=None, num_rays=7, wvln=None, entrance_pupil=True)

Sample point source rays (2D) in object space.

Used for drawing the lens setup.

Parameters:

Name Type Description Default
fov float

incident angle (in degree). Defaults to 0.0.

0.0
depth float

sampling depth. When None (default), falls back to self.obj_depth.

None
num_rays int

ray number. Defaults to 7.

7
wvln float

ray wvln in µm. When None (default), falls back to self.primary_wvln.

None
entrance_pupil bool

whether to use entrance pupil. Defaults to True.

True

Returns:

Name Type Description
ray Ray object

Ray object. Shape [num_rays, 3]

Source code in deeplens-src/deeplens/geolens_pkg/vis.py
@torch.no_grad()
def sample_point_source_2D(
    self,
    fov=0.0,
    depth=None,
    num_rays=7,
    wvln=None,
    entrance_pupil=True,
):
    """Sample point source rays (2D) in object space.

    Used for drawing the lens setup.

    Args:
        fov (float, optional): incident angle (in degree). Defaults to 0.0.
        depth (float, optional): sampling depth. When ``None`` (default),
            falls back to ``self.obj_depth``.
        num_rays (int, optional): ray number. Defaults to 7.
        wvln (float, optional): ray wvln in µm. When ``None`` (default),
            falls back to ``self.primary_wvln``.
        entrance_pupil (bool, optional): whether to use entrance pupil. Defaults to True.

    Returns:
        ray (Ray object): Ray object. Shape [num_rays, 3]
    """
    wvln = self.primary_wvln if wvln is None else wvln
    depth = self.obj_depth if depth is None else depth
    # Sample point on the object plane
    ray_o = torch.tensor([depth * float(np.tan(np.deg2rad(fov))), 0.0, depth])
    ray_o = ray_o.unsqueeze(0).repeat(num_rays, 1)

    # Sample points (second point) on the pupil
    if entrance_pupil:
        pupilz, pupilr = self.calc_entrance_pupil()
    else:
        pupilz, pupilr = self.surfaces[0].d.item(), self.surfaces[0].r

    x2 = torch.linspace(-pupilr, pupilr, num_rays) * 0.99
    y2 = torch.zeros_like(x2)
    z2 = torch.full_like(x2, pupilz)
    ray_o2 = torch.stack((x2, y2, z2), axis=1)

    # Form the rays
    ray_d = ray_o2 - ray_o
    ray = Ray(ray_o, ray_d, wvln, device=self.device)

    # Propagate rays to the sampling depth
    ray.prop_to(depth)
    return ray
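
The construction above (one object point, a fan of targets across the pupil) can be sketched in plain Python. `point_source_rays_2d` is a hypothetical name. The explicit normalisation is an assumption of this sketch: the source passes the unnormalised difference to `Ray`, and whether `Ray` normalises it is not shown here. A negative `depth` (object in front of the lens) is also assumed.

```python
import math


def point_source_rays_2d(fov_deg, depth, pupil_z, pupil_r, num_rays=7):
    """Return the shared origin and one unit direction per pupil sample."""
    # Object point at the given depth, offset in x by the field angle.
    origin = (depth * math.tan(math.radians(fov_deg)), 0.0, depth)

    step = 2.0 * pupil_r / (num_rays - 1)
    directions = []
    for i in range(num_rays):
        # Second point of the ray: a sample across 99% of the pupil diameter.
        target = ((-pupil_r + i * step) * 0.99, 0.0, pupil_z)
        d = tuple(t - o for t, o in zip(target, origin))
        norm = math.sqrt(sum(v * v for v in d))
        directions.append(tuple(v / norm for v in d))  # assumed normalisation
    return origin, directions


origin, directions = point_source_rays_2d(
    fov_deg=0.0, depth=-1000.0, pupil_z=0.0, pupil_r=2.0, num_rays=5
)
```

For an on-axis point the fan is symmetric about the optical axis, and the central ray travels straight along +z.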

draw_layout

draw_layout(filename, depth=float('inf'), zmx_format=True, multi_plot=False, lens_title=None, show=False)

Plot 2D lens layout with ray tracing.

The title is auto-generated when lens_title is None: it includes focal length, F-number, FoV, IMGH, RGB wavelengths, and a second line with per-FoV RMS spot radii from analysis_spot().

Parameters:

Name Type Description Default
filename

Output filename.

required
depth

Object distance for ray tracing [mm]. Use float('inf') for collimated input. Defaults to float('inf').

float('inf')
zmx_format

If True, draw surfaces in Zemax style. Defaults to True.

True
multi_plot

If True, create one sub-plot per wavelength. Defaults to False.

False
lens_title

Title string. If None, auto-generated. Defaults to None.

None
show

If True, display the figure interactively instead of saving it to filename. Defaults to False.

False
Source code in deeplens-src/deeplens/geolens_pkg/vis.py
def draw_layout(
    self,
    filename,
    depth=float("inf"),
    zmx_format=True,
    multi_plot=False,
    lens_title=None,
    show=False,
):
    """Plot 2D lens layout with ray tracing.

    The title is auto-generated when ``lens_title`` is None: it includes
    focal length, F-number, FoV, IMGH, RGB wavelengths, and a second line
    with per-FoV RMS spot radii from ``analysis_spot()``.

    Args:
        filename: Output filename.
        depth: Object distance for ray tracing [mm]. Use ``float('inf')``
            for collimated input. Defaults to ``float('inf')``.
        zmx_format: If True, draw surfaces in Zemax style. Defaults to True.
        multi_plot: If True, create one sub-plot per wavelength.
            Defaults to False.
        lens_title: Title string. If None, auto-generated. Defaults to None.
        show: If True, display the figure interactively instead of saving it
            to ``filename``. Defaults to False.
    """
    num_rays = 11
    num_views = 3

    # Lens title
    if lens_title is None:
        eff_foclen = round(self.foclen, 2)
        fov_deg = round(2 * self.rfov * 180 / torch.pi, 1)
        imgh = round(self.r_sensor, 1)
        wvl_nm = [int(round(w * 1000)) for w in self.wvln_rgb]  # µm → nm

        if self.aper_idx is not None:
            _, pupil_r = self.calc_entrance_pupil()
            fnum = round(eff_foclen / pupil_r / 2, 2)
            line1 = (
                f"FocLen{eff_foclen}mm - F/{fnum} - FoV{fov_deg} - "
                f"IMGH{imgh}mm - RGB({wvl_nm[0]}/{wvl_nm[1]}/{wvl_nm[2]}nm)"
            )
        else:
            line1 = (
                f"FocLen{eff_foclen}mm - FoV{fov_deg} - "
                f"IMGH{imgh}mm - RGB({wvl_nm[0]}/{wvl_nm[1]}/{wvl_nm[2]}nm)"
            )

        spot = self.analysis_spot(num_field=3)
        rms0 = spot["fov0.0"]["rms"]
        rms5 = spot["fov0.5"]["rms"]
        rms10 = spot["fov1.0"]["rms"]
        line2 = f"RMS spot: 0.0FoV={rms0:.2f}\u03bcm  0.5FoV={rms5:.2f}\u03bcm  1.0FoV={rms10:.2f}\u03bcm"
        lens_title = f"{line1}\n{line2}"

    # Draw lens layout
    colors_list = ["#CC0000", "#006600", "#0066CC"]
    rfov_deg = float(np.rad2deg(self.rfov))
    fov_ls = np.linspace(0, rfov_deg * 0.99, num=num_views)

    if not multi_plot:
        ax, fig = self.draw_lens_2d(zmx_format=zmx_format)
        fig.suptitle(lens_title, fontsize=10, fontfamily="Nimbus Sans")
        for i, fov in enumerate(fov_ls):
            # Sample rays, shape (num_rays, 3)
            if depth == float("inf"):
                ray = self.sample_parallel_2D(
                    fov=fov,
                    wvln=self.wvln_rgb[2 - i],
                    num_rays=num_rays,
                    depth=-1.0,
                    plane="sagittal",
                )
            else:
                ray = self.sample_point_source_2D(
                    fov=fov,
                    depth=depth,
                    num_rays=num_rays,
                    wvln=self.wvln_rgb[2 - i],
                )
                ray.prop_to(-1.0)

            # Trace rays to sensor and plot ray paths
            _, ray_o_record = self.trace2sensor(ray=ray, record=True)
            ax, fig = self.draw_ray_2d(
                ray_o_record, ax=ax, fig=fig, color=colors_list[i]
            )

        ax.axis("off")

    else:
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        fig.suptitle(lens_title, fontsize=10, fontfamily="Nimbus Sans")
        for i, wvln in enumerate(self.wvln_rgb):
            ax = axs[i]
            ax, fig = self.draw_lens_2d(ax=ax, fig=fig, zmx_format=zmx_format)
            for fov in fov_ls:
                # Sample rays, shape (num_rays, 3)
                if depth == float("inf"):
                    ray = self.sample_parallel_2D(
                        fov=fov,
                        num_rays=num_rays,
                        wvln=wvln,
                        plane="sagittal",
                    )
                else:
                    ray = self.sample_point_source_2D(
                        fov=fov,
                        depth=depth,
                        num_rays=num_rays,
                        wvln=wvln,
                    )

                # Trace rays to sensor and plot ray paths
                ray_out, ray_o_record = self.trace2sensor(ray=ray, record=True)
                ax, fig = self.draw_ray_2d(
                    ray_o_record, ax=ax, fig=fig, color=colors_list[i]
                )
                ax.axis("off")

    if show:
        fig.show()
    else:
        fig.savefig(filename, format="png", dpi=300)
        plt.close()
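
The first line of the auto-generated title is plain arithmetic on a few lens attributes. Below is a minimal sketch of that formatting outside the class; `layout_title_line` is a hypothetical name and the numeric inputs are illustrative only, not DeepLens defaults.

```python
import math


def layout_title_line(foclen, rfov_rad, r_sensor, wvln_rgb_um, pupil_r=None):
    """Build the first title line the same way draw_layout does."""
    eff_foclen = round(foclen, 2)
    fov_deg = round(2 * rfov_rad * 180 / math.pi, 1)  # full FoV in degrees
    imgh = round(r_sensor, 1)
    wvl_nm = [int(round(w * 1000)) for w in wvln_rgb_um]  # um -> nm

    parts = [f"FocLen{eff_foclen}mm"]
    if pupil_r is not None:
        # F-number = focal length / entrance pupil diameter
        parts.append(f"F/{round(eff_foclen / pupil_r / 2, 2)}")
    parts += [
        f"FoV{fov_deg}",
        f"IMGH{imgh}mm",
        f"RGB({wvl_nm[0]}/{wvl_nm[1]}/{wvl_nm[2]}nm)",
    ]
    return " - ".join(parts)


title = layout_title_line(
    foclen=50.0,
    rfov_rad=math.radians(20.0),
    r_sensor=18.2,
    wvln_rgb_um=[0.656, 0.588, 0.486],
    pupil_r=12.5,
)
```

Passing `pupil_r=None` corresponds to the branch without an aperture stop, where the F-number is left out of the title.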

draw_lens_2d

draw_lens_2d(ax=None, fig=None, color='k', linestyle='-', zmx_format=False, fix_bound=False)

Draw lens cross-section layout in a 2D plot.

Renders each surface profile, connects lens elements with edge lines, and draws the sensor plane.

Parameters:

Name Type Description Default
ax Axes

Existing axes to draw on. If None, creates a new figure. Defaults to None.

None
fig Figure

Existing figure. Defaults to None.

None
color str

Line colour for lens outlines. Defaults to 'k'.

'k'
linestyle str

Line style. Defaults to '-'.

'-'
zmx_format bool

If True, draw stepped edge connections matching Zemax layout style. Defaults to False.

False
fix_bound bool

If True, use fixed axis limits [-1,7]x[-4,4]. Defaults to False.

False

Returns:

Name Type Description
tuple

(ax, fig) matplotlib axes and figure objects.

Source code in deeplens-src/deeplens/geolens_pkg/vis.py
def draw_lens_2d(
    self,
    ax=None,
    fig=None,
    color="k",
    linestyle="-",
    zmx_format=False,
    fix_bound=False,
):
    """Draw lens cross-section layout in a 2D plot.

    Renders each surface profile, connects lens elements with edge lines,
    and draws the sensor plane.

    Args:
        ax (matplotlib.axes.Axes, optional): Existing axes to draw on. If None,
            creates a new figure. Defaults to None.
        fig (matplotlib.figure.Figure, optional): Existing figure. Defaults to None.
        color (str, optional): Line colour for lens outlines. Defaults to 'k'.
        linestyle (str, optional): Line style. Defaults to '-'.
        zmx_format (bool, optional): If True, draw stepped edge connections
            matching Zemax layout style. Defaults to False.
        fix_bound (bool, optional): If True, use fixed axis limits [-1,7]x[-4,4].
            Defaults to False.

    Returns:
        tuple: (ax, fig) matplotlib axes and figure objects.
    """
    # If no ax is given, generate a new one.
    if ax is None and fig is None:
        # fig, ax = plt.subplots(figsize=(6, 6))
        fig, ax = plt.subplots()

    # Draw lens surfaces
    for i, s in enumerate(self.surfaces):
        s.draw_widget(ax)

    # Connect two surfaces
    for i in range(len(self.surfaces)):
        if self.surfaces[i].mat2.n > 1.1:
            s_prev = self.surfaces[i]
            s = self.surfaces[i + 1]

            r_prev = float(s_prev.draw_r())
            r = float(s.draw_r())
            sag_prev = s_prev.surface_with_offset(
                r_prev, 0.0, valid_check=False
            ).item()
            sag = s.surface_with_offset(
                r, 0.0, valid_check=False
            ).item()

            if r_prev >= r:
                # Front surface wider: go axially forward at r_prev, then step radially inward
                z = np.array([sag_prev, sag, sag])
                x = np.array([r_prev, r_prev, r])
            else:
                # Rear surface wider: step radially outward at z_prev, then go axially forward
                z = np.array([sag_prev, sag_prev, sag])
                x = np.array([r_prev, r, r])

            if not zmx_format:
                # In non-zmx mode use a direct diagonal between the two outer edges
                z = np.array([z[0], z[-1]])
                x = np.array([x[0], x[-1]])

            ax.plot(z, -x, color, linewidth=0.75)
            ax.plot(z, x, color, linewidth=0.75)
            s_prev = s

    # Draw sensor
    ax.plot(
        [self.d_sensor.item(), self.d_sensor.item()],
        [-self.r_sensor, self.r_sensor],
        color,
    )

    # Set figure size
    if fix_bound:
        ax.set_aspect("equal")
        ax.set_xlim(-1, 7)
        ax.set_ylim(-4, 4)
    else:
        ax.set_aspect("equal", adjustable="datalim", anchor="C")
        ax.minorticks_on()
        ax.set_xlim(-0.5, 7.5)
        ax.set_ylim(-4, 4)
        ax.autoscale()

    return ax, fig
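
The edge-connection rule in the loop above depends only on the two edge radii and sags, so it can be isolated as a small pure function. This is a sketch with a hypothetical name, `edge_polyline`, not a DeepLens API.

```python
def edge_polyline(r_prev, sag_prev, r, sag, zmx_format=False):
    """Return (z, x) vertices of the line joining two surface edges."""
    if r_prev >= r:
        # Front surface wider: go axially at r_prev, then step radially inward.
        z = [sag_prev, sag, sag]
        x = [r_prev, r_prev, r]
    else:
        # Rear surface wider: step radially outward first, then go axially.
        z = [sag_prev, sag_prev, sag]
        x = [r_prev, r, r]

    if not zmx_format:
        # Non-Zemax style: keep only the end points (a direct diagonal).
        z, x = [z[0], z[-1]], [x[0], x[-1]]
    return z, x


stepped = edge_polyline(3.0, 0.5, 2.0, 2.5, zmx_format=True)
diagonal = edge_polyline(3.0, 0.5, 2.0, 2.5, zmx_format=False)
```

The method then plots each polyline twice, once at `+x` and once at `-x`, to draw both rims of the cross-section.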

draw_ray_2d

draw_ray_2d(ray_o_record, ax, fig, color='b')

Plot ray paths.

Parameters:

Name Type Description Default
ray_o_record list

list of intersection points.

required
ax Axes

matplotlib axes.

required
fig Figure

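matplotlib figure.

required
color str

Line colour for the ray paths. Defaults to 'b'.

'b'

Returns:

Name Type Description
tuple

(ax, fig) matplotlib axes and figure objects.
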
Source code in deeplens-src/deeplens/geolens_pkg/vis.py
def draw_ray_2d(self, ray_o_record, ax, fig, color="b"):
    """Plot ray paths.

    Args:
        ray_o_record (list): list of intersection points.
        ax (matplotlib.axes.Axes): matplotlib axes.
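        fig (matplotlib.figure.Figure): matplotlib figure.
        color (str, optional): Line colour for the ray paths. Defaults to 'b'.

    Returns:
        tuple: (ax, fig) matplotlib axes and figure objects.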
    """
    # shape (num_view, num_rays, num_path, 3); only z (index 2) and x (index 0) are plotted
    ray_o_record = torch.stack(ray_o_record, dim=-2).cpu().numpy()
    if ray_o_record.ndim == 3:
        ray_o_record = ray_o_record[None, ...]

    for idx_view in range(ray_o_record.shape[0]):
        for idx_ray in range(ray_o_record.shape[1]):
            ax.plot(
                ray_o_record[idx_view, idx_ray, :, 2],
                ray_o_record[idx_view, idx_ray, :, 0],
                color,
                linewidth=0.8,
            )

            # ax.scatter(
            #     ray_o_record[idx_view, idx_ray, :, 2],
            #     ray_o_record[idx_view, idx_ray, :, 0],
            #     "b",
            #     marker="x",
            # )

    return ax, fig
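
The record is a list with one entry per recorded position, each holding one point per ray, while plotting needs one polyline per ray. The reshaping can be sketched with nested lists in place of tensors. `ray_paths_zx` is a hypothetical helper and the coordinates are illustrative.

```python
def ray_paths_zx(ray_o_record):
    """Turn per-position records into per-ray (z, x) polylines.

    ray_o_record: list over recorded positions; each item is a list of
    (x, y, z) points, one per ray.
    """
    paths = []
    # zip(*...) regroups the data so that each item follows a single ray.
    for ray_points in zip(*ray_o_record):
        z = [p[2] for p in ray_points]  # horizontal plot axis
        x = [p[0] for p in ray_points]  # vertical plot axis
        paths.append((z, x))
    return paths


record = [
    [(1.0, 0.0, -1.0), (-1.0, 0.0, -1.0)],  # ray start
    [(0.5, 0.0, 0.0), (-0.5, 0.0, 0.0)],    # a lens surface
    [(0.0, 0.0, 5.0), (0.0, 0.0, 5.0)],     # sensor
]
paths = ray_paths_zx(record)
```

Each `(z, x)` pair can be passed directly to `ax.plot(z, x)`, which is what the loop in `draw_ray_2d` does per ray.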

create_barrier

create_barrier(filename, barrier_thickness=1.0, ring_height=0.5, ring_size=1.0)

Create a 3D barrier for the lens system.

Parameters:

Name Type Description Default
filename

Path to save the figure

required
barrier_thickness

Thickness of the barrier

1.0
ring_height

Height of the annular ring

0.5
ring_size

Size of the annular ring

1.0
Source code in deeplens-src/deeplens/geolens_pkg/vis.py
def create_barrier(
    self, filename, barrier_thickness=1.0, ring_height=0.5, ring_size=1.0
):
    """Create a 3D barrier for the lens system.

    Args:
        filename: Path to save the figure
        barrier_thickness: Thickness of the barrier
        ring_height: Height of the annular ring
        ring_size: Size of the annular ring
    """
    barriers = []
    rings = []

    # Create barriers
    barrier_z = 0.0
    barrier_r = 0.0
    barrier_length = 0.0
    for i in range(len(self.surfaces)):
        barrier_r = max(self.surfaces[i].r, barrier_r)

        if self.surfaces[i].mat2.get_name() != "air":
            # Update the barrier radius
            # barrier_r = max(geolens.surfaces[i].r, barrier_r)
            pass
        else:
            # Extend the barrier till middle of the air space to the next surface
            max_curr_surf_d = self.surfaces[i].d.item() + max(
                self.surfaces[i].surface_sag(0.0, self.surfaces[i].r), 0.0
            )
            if i < len(self.surfaces) - 1:
                min_next_surf_d = self.surfaces[i + 1].d.item() + min(
                    self.surfaces[i + 1].surface_sag(0.0, self.surfaces[i + 1].r),
                    0.0,
                )
                extra_space = (min_next_surf_d - max_curr_surf_d) / 2
            else:
                min_next_surf_d = self.d_sensor.item()
                extra_space = min_next_surf_d - max_curr_surf_d

            barrier_length = max_curr_surf_d + extra_space - barrier_z

            # Create a barrier
            barrier = {
                "pos_z": barrier_z,
                "pos_r": barrier_r,
                "length": barrier_length,
                "thickness": barrier_thickness,
            }
            barriers.append(barrier)

            # Reset the barrier parameters
            barrier_z = barrier_length + barrier_z
            barrier_r = 0.0
            barrier_length = 0.0

    # # Create rings
    # for i in range(len(geolens.surfaces)):
    #     if geolens.surfaces[i].mat2.get_name() != "air":
    #         ring = {
    #             "pos_z": geolens.surfaces[i].d.item(),

    # Plot lens layout
    ax, fig = self.draw_layout(filename)

    # Plot barrier
    barrier_z_ls = []
    barrier_r_ls = []
    for b in barriers:
        barrier_z_ls.append(b["pos_z"])
        barrier_z_ls.append(b["pos_z"] + b["length"])
        barrier_r_ls.append(b["pos_r"])
        barrier_r_ls.append(b["pos_r"])
    ax.plot(barrier_z_ls, barrier_r_ls, "green", linewidth=1.0)
    ax.plot(barrier_z_ls, [-i for i in barrier_r_ls], "green", linewidth=1.0)

    # Plot rings

    fig.savefig(filename, format="png", dpi=300)
    plt.close()

    pass

deeplens.geolens_pkg.vis3d.GeoLensVis3D

Mixin providing 3-D mesh visualisation for GeoLens.

Creates lens surface, aperture, barrier, sensor, and ray-path meshes as polygon data and optionally renders them with PyVista. All geometry is expressed in millimetres and stored as CrossPoly (vertex/face) objects that can be saved to .obj files for external renderers.

This class is not instantiated directly; it is mixed into GeoLens.

create_mesh

create_mesh(mesh_rings: int = 32, mesh_arms: int = 128, is_wrap: bool = False)

Create all lens/bridge/sensor/aperture meshes.

Parameters:

Name Type Description Default
mesh_rings int

The number of rings in the mesh.

32
mesh_arms int

The number of arms in the mesh.

128
is_wrap bool

Whether to wrap the lens bridge around the lens as cylinder.

False

Returns:

Name Type Description
surf_meshes List[FaceMesh]

Lens surface meshes, converted to face meshes.

bridge_meshes List[List[FaceMesh]]

Bridge meshes, one group per lens element.

element_groups List[List[int]]

Surface indices grouped by lens element.

sensor_mesh RectangleMesh

Sensor mesh (only rectangular sensors are supported for now).

Source code in deeplens-src/deeplens/geolens_pkg/vis3d.py
def create_mesh(
    self,
    mesh_rings: int = 32,
    mesh_arms: int = 128,
    is_wrap: bool = False,
):
    """Create all lens/bridge/sensor/aperture meshes.

    Args:
        mesh_rings (int): The number of rings in the mesh.
        mesh_arms (int): The number of arms in the mesh.
        is_wrap (bool): Whether to wrap the lens bridge around the lens as cylinder.
    Returns:
        surf_meshes (List[FaceMesh]): Lens surface meshes, converted to face meshes.
        bridge_meshes (List[List[FaceMesh]]): Bridge meshes, one group per lens element.
        element_groups (List[List[int]]): Surface indices grouped by lens element.
        sensor_mesh (RectangleMesh): Sensor mesh. (only rectangular sensors are supported for now)
    """
    surf_meshes = []
    element_group = []
    element_groups = []
    bridge_meshes = []  # change to nested list for wrap around
    sensor_mesh = None

    # Create the surface meshes
    for i, surf in enumerate(self.surfaces):
        # Create the surface mesh (list of Surface objects)
        surf_meshes.append(surf.create_mesh(n_rings=mesh_rings, n_arms=mesh_arms))

        # Add the surface to the element group
        element_group.append(i)
        if surf.mat2.name == "air":
            element_groups.append(element_group)
            element_group = []

    # Create the bridge meshes (list of FaceMesh objects)
    for i, pair in enumerate(element_groups):
        if len(pair) == 1:
            bridge_meshes.append([])
            continue
        elif len(pair) == 2:
            a_idx, b_idx = pair
            a = surf_meshes[a_idx]
            b = surf_meshes[b_idx]
            bridge_mesh_group = []
            if not is_wrap:
                bridge_mesh = bridge(a.rim, b.rim)
                bridge_mesh_group.append(bridge_mesh)
            else:
                # create wrap by creating a new rim
                # from projecting the larger rim onto the smaller rim plane
                # assume the elements are always ordered on z-axis
                r_a = self.surfaces[a_idx].r
                r_b = self.surfaces[b_idx].r
                d_rim_a = np.mean(
                    a.rim.vertices[:, 2], keepdims=False
                )  # calc rim mean z
                d_rim_b = np.mean(b.rim.vertices[:, 2], keepdims=False)

                if r_a > r_b:
                    z = line_translate(a.rim, 0, 0, d_rim_b - d_rim_a)
                    bridge_mesh_wrap = bridge(z, b.rim)
                    bridge_mesh = bridge(a.rim, z)
                    bridge_mesh_group.append(bridge_mesh_wrap)
                elif r_a < r_b:
                    z = line_translate(b.rim, 0, 0, d_rim_a - d_rim_b)
                    bridge_mesh_wrap = bridge(a.rim, z)
                    bridge_mesh = bridge(z, b.rim)
                    bridge_mesh_group.append(bridge_mesh_wrap)
                else:
                    bridge_mesh = bridge(a.rim, b.rim)
                bridge_mesh_group.append(bridge_mesh)
            bridge_meshes.append(bridge_mesh_group)

        elif len(pair) == 3:
            a_idx, b_idx, c_idx = pair
            a = surf_meshes[a_idx]
            b = surf_meshes[b_idx]
            c = surf_meshes[c_idx]
            bridge_mesh_group = []
            if not is_wrap:
                bridge_mesh = bridge(a.rim, b.rim)
                bridge_mesh_group.append(bridge_mesh)
                bridge_mesh = bridge(b.rim, c.rim)
                bridge_mesh_group.append(bridge_mesh)
            else:
                # create wrap by creating a new rim
                # from projecting the larger rim onto the smaller rim plane
                # assume the elements are always ordered on z-axis
                r_a = self.surfaces[a_idx].r
                r_b = self.surfaces[b_idx].r
                r_c = self.surfaces[c_idx].r
                d_rim_a = np.mean(
                    a.rim.vertices[:, 2], keepdims=False
                )  # calc rim mean z
                d_rim_b = np.mean(b.rim.vertices[:, 2], keepdims=False)
                d_rim_c = np.mean(c.rim.vertices[:, 2], keepdims=False)

                rim_list = [a.rim, b.rim, c.rim]
                r_list = [r_a, r_b, r_c]
                d_rim_list = [d_rim_a, d_rim_b, d_rim_c]
                idx_wrap = r_list.index(max(r_list))
                r_wrap = r_list[idx_wrap]
                d_rim_wrap = d_rim_list[idx_wrap]

                for i in range(3):
                    if i != idx_wrap and r_list[i] != r_wrap:
                        # substitute the rim with the wrapped rim
                        d_diff = d_rim_list[i] - d_rim_wrap
                        z = line_translate(rim_list[idx_wrap], 0, 0, d_diff)
                        # add the wrap bridge between older rim and wrapped one
                        wrap_mesh = bridge(rim_list[i], z)
                        # update the rim
                        rim_list[i] = z
                        bridge_mesh_group.append(wrap_mesh)
                bridge_mesh = bridge(rim_list[0], rim_list[1])
                bridge_mesh_group.append(bridge_mesh)
                bridge_mesh = bridge(rim_list[1], rim_list[2])
                bridge_mesh_group.append(bridge_mesh)
            bridge_meshes.append(bridge_mesh_group)

        else:
            raise ValueError(f"Invalid bridge group length: {len(pair)}")

    # Create the sensor mesh (RectangleMesh object)
    sensor_d = self.d_sensor.item()
    sensor_r = self.r_sensor
    h, w = sensor_r * 1.4142, sensor_r * 1.4142
    sensor_mesh = RectangleMesh(
        np.array([0, 0, sensor_d]), np.array([1, 0, 0]), np.array([0, 1, 0]), w, h
    )

    # turn surf_meshes to list of FaceMesh
    surf_meshes_cvt = [surf_to_face_mesh(surf) for surf in surf_meshes]
    return surf_meshes_cvt, bridge_meshes, element_groups, sensor_mesh
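
The element grouping at the start of `create_mesh` closes a group whenever the medium after a surface is air. Here is a minimal sketch of that rule on material names alone; `group_elements` is a hypothetical helper and the glass names are illustrative.

```python
def group_elements(next_materials):
    """Group surface indices into elements.

    next_materials[i] is the name of the medium after surface i.
    A group is closed each time that medium is "air".
    """
    groups, current = [], []
    for i, name in enumerate(next_materials):
        current.append(i)
        if name == "air":
            groups.append(current)
            current = []
    return groups


# Singlet (2 surfaces), aperture (1 surface), cemented doublet (3 surfaces).
groups = group_elements(["n-bk7", "air", "air", "n-sf5", "n-bk7", "air"])
```

Group lengths of 1, 2, and 3 are exactly the cases handled by the bridge loop; any longer group raises `ValueError` there.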

draw_lens_3d

draw_lens_3d(plotter=None, save_dir: Optional[str] = None, mesh_rings: int = 32, mesh_arms: int = 128, surface_color: List[float] = [0.06, 0.3, 0.6], draw_rays: bool = True, fovs: List[float] = [0.0], fov_phis: List[float] = [0.0], ray_rings: int = 6, ray_arms: int = 8, is_wrap: bool = False)

Draw lens 3D layout with rays using pyvista.

Note: PyVista is imported lazily only when this method is called.

Parameters:

Name Type Description Default
plotter

pv.Plotter. Optional pyvista Plotter instance. If None, a new one is created.

None
save_dir str

The directory to save the image.

None
mesh_rings int

The number of rings in the mesh.

32
mesh_arms int

The number of arms in the mesh.

128
surface_color List[float]

The color of the surfaces.

[0.06, 0.3, 0.6]
draw_rays bool

Whether to show the rays.

True
fovs List[float]

The FoV angles to be sampled, unit: degree.

[0.0]
fov_phis List[float]

The FoV azimuthal angles to be sampled, unit: degree.

[0.0]
ray_rings int

The number of pupil rings to be sampled.

6
ray_arms int

The number of pupil arms to be sampled.

8
is_wrap bool

Whether to wrap the lens bridge around the lens as cylinder.

False

Returns:

Name Type Description
plotter

pv.Plotter. The pyvista Plotter instance.

Source code in deeplens-src/deeplens/geolens_pkg/vis3d.py
def draw_lens_3d(
    self,
    plotter=None,
    save_dir: Optional[str] = None,
    mesh_rings: int = 32,
    mesh_arms: int = 128,
    surface_color: List[float] = [0.06, 0.3, 0.6],
    draw_rays: bool = True,
    fovs: List[float] = [0.0],
    fov_phis: List[float] = [0.0],
    ray_rings: int = 6,
    ray_arms: int = 8,
    is_wrap: bool = False,
):
    """Draw lens 3D layout with rays using pyvista.

    Note: PyVista is imported lazily only when this method is called.

    Args:
        plotter: pv.Plotter. Optional pyvista Plotter instance. If None, a new one is created.
        save_dir (str): The directory to save the image.
        mesh_rings (int): The number of rings in the mesh.
        mesh_arms (int): The number of arms in the mesh.
        surface_color (List[float]): The color of the surfaces.
        draw_rays (bool): Whether to show the rays.
        fovs (List[float]): The FoV angles to be sampled, unit: degree.
        fov_phis (List[float]): The FoV azimuthal angles to be sampled, unit: degree.
        ray_rings (int): The number of pupil rings to be sampled.
        ray_arms (int): The number of pupil arms to be sampled.
        is_wrap (bool): Whether to wrap the lens bridge around the lens as cylinder.

    Returns:
        plotter: pv.Plotter. The pyvista Plotter instance.
    """
    # Lazy import of pyvista
    try:
        import pyvista as pv
    except ImportError as e:
        raise ImportError(
            "PyVista is required for 3D GUI rendering. Install with `pip install pyvista`."
        ) from e

    # Create plotter if not provided
    if plotter is None:
        plotter = pv.Plotter()

    surf_color = surface_color
    sensor_color = [0.5, 0.5, 0.5]

    # Create meshes
    surf_meshes, bridge_meshes, _, sensor_mesh = self.create_mesh(
        mesh_rings, mesh_arms, is_wrap
    )

    # Draw meshes
    for surf in surf_meshes:
        if not isinstance(surf, Aperture):
            _draw_mesh_to_plotter(
                plotter, surf, color=surf_color, opacity=0.5, pv=pv
            )

    for bridge_group in bridge_meshes:
        for bridge_mesh in bridge_group:
            _draw_mesh_to_plotter(
                plotter, bridge_mesh, color=surf_color, opacity=0.5, pv=pv
            )

    _draw_mesh_to_plotter(
        plotter, sensor_mesh, color=sensor_color, opacity=1.0, pv=pv
    )

    # Draw rays
    if draw_rays:
        rays_curve = geolens_ray_poly(
            self, fovs, fov_phis, n_rings=ray_rings, n_arms=ray_arms
        )

        rays_poly_list = [curve_list_to_polydata(r) for r in rays_curve]
        rays_poly_fov = [merge(r) for r in rays_poly_list]
        rays_poly_fov = [_wrap_base_poly_to_pyvista(r, pv) for r in rays_poly_fov]
        for r in rays_poly_fov:
            plotter.add_mesh(r)

    # Save images
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        plotter.show(screenshot=os.path.join(save_dir, "lens_layout3d.png"))

    return plotter
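
The function above defers `import pyvista` until it is called, so the rest of the package can be used without the optional dependency. A minimal sketch of that lazy-import pattern, using only the standard library, is shown below. The helper `require_optional` and the module name `not_a_real_module_xyz` are hypothetical names chosen for illustration and are not part of DeepLens.

```python
import importlib


def require_optional(module_name: str, install_hint: str):
    """Import an optional dependency on first use, with a helpful error.

    Hypothetical helper for illustration; not part of DeepLens.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as e:
        # Chain the original error so the real cause stays in the traceback.
        raise ImportError(
            f"{module_name} is required for this feature. {install_hint}"
        ) from e


# A standard-library module imports normally.
json_mod = require_optional("json", "It ships with Python.")

# A missing module yields the custom message instead of a bare ImportError.
try:
    require_optional("not_a_real_module_xyz", "Install it with pip.")
except ImportError as err:
    message = str(err)
```

Raising with `from e` mirrors the source above: the user sees the install hint first, while the underlying import failure remains available for debugging.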

save_lens_obj

save_lens_obj(save_dir: str, mesh_rings: int = 64, mesh_arms: int = 128, save_rays: bool = False, fovs: List[float] = [0.0], fov_phis: List[float] = [0.0], ray_rings: int = 6, ray_arms: int = 8, is_wrap: bool = False, save_elements: bool = True)

Save lens geometry and rays as .obj files using pyvista.

Note: use #F2F7FFFF as the color for lens when rendering in Blender.
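
Many tools take colors as four floats in [0, 1] rather than a hex string. The sketch below converts the value from the note, assuming the eight hex digits are ordered RRGGBBAA. The helper name `hex_rgba_to_floats` is hypothetical, and the source does not say whether the target application interprets the result as sRGB or linear, so treat the output as a starting point.

```python
def hex_rgba_to_floats(hex_color: str) -> tuple:
    """Convert '#RRGGBBAA' to four floats in [0, 1].

    Hypothetical helper; assumes RRGGBBAA channel order.
    """
    digits = hex_color.lstrip("#")
    if len(digits) != 8:
        raise ValueError("expected 8 hex digits (RRGGBBAA)")
    # Each channel is two hex digits, i.e. an integer in 0..255.
    return tuple(int(digits[i:i + 2], 16) / 255 for i in range(0, 8, 2))


rgba = hex_rgba_to_floats("#F2F7FFFF")
# Channels are 242/255, 247/255, 255/255, 255/255.
```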

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| save_dir | str | The directory in which the .obj files are saved. | required |
| mesh_rings | int | The number of rings in the mesh. | 64 |
| mesh_arms | int | The number of arms in the mesh. | 128 |
| save_rays | bool | Whether to save the rays. | False |
| fovs | List[float] | The FoV angles to be sampled, unit: degree. | [0.0] |
| fov_phis | List[float] | The FoV azimuthal angles to be sampled, unit: degree. | [0.0] |
| ray_rings | int | The number of pupil rings to be sampled. | 6 |
| ray_arms | int | The number of pupil arms to be sampled. | 8 |
| is_wrap | bool | Whether to wrap the lens bridge around the lens as a cylinder. | False |
| save_elements | bool | Whether to save each lens element as its own element_{i}.obj file. | True |

Source code in deeplens-src/deeplens/geolens_pkg/vis3d.py
def save_lens_obj(
    self,
    save_dir: str,
    mesh_rings: int = 64,
    mesh_arms: int = 128,
    save_rays: bool = False,
    fovs: List[float] = [0.0],
    fov_phis: List[float] = [0.0],
    ray_rings: int = 6,
    ray_arms: int = 8,
    is_wrap: bool = False,
    save_elements: bool = True,
):
    """Save lens geometry and rays as .obj files using pyvista.

    Note: use #F2F7FFFF as the color for lens when rendering in Blender.

    Args:
        save_dir (str): The directory in which the .obj files are saved.
        mesh_rings (int): The number of rings in the mesh. (default: 64)
        mesh_arms (int): The number of arms in the mesh. (default: 128)
        save_rays (bool): Whether to save the rays.
        fovs (List[float]): The FoV angles to be sampled, unit: degree.
        fov_phis (List[float]): The FoV azimuthal angles to be sampled, unit: degree.
        ray_rings (int): The number of pupil rings to be sampled. (default: 6)
        ray_arms (int): The number of pupil arms to be sampled. (default: 8)
        is_wrap (bool): Whether to wrap the lens bridge around the lens as a cylinder.
        save_elements (bool): Whether to save the elements.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Create surfaces & bridges meshes
    surf_meshes, bridge_meshes, element_groups, sensor_mesh = self.create_mesh(
        mesh_rings, mesh_arms, is_wrap
    )

    # Save individual lens elements (surfaces + bridges merged)
    if save_elements:
        for i, pair in enumerate(element_groups):
            print(f"Running in pair {i} with pair length {len(pair)}")
            # Collect surface polydata
            surf_polydata_list = [surf_meshes[idx].get_polydata() for idx in pair]

            # Collect bridge polydata if available
            bridge_polydata_list = []
            if i < len(bridge_meshes) and len(bridge_meshes[i]) > 0:
                print(f"Bridge mesh group number: {len(bridge_meshes[i])}")
                bridge_polydata_list = [b.get_polydata() for b in bridge_meshes[i]]

            # Merge surfaces and bridges together
            all_polydata = surf_polydata_list + bridge_polydata_list
            if len(all_polydata) == 1:
                element = all_polydata[0]
            else:
                element = merge(all_polydata)
            element.save(os.path.join(save_dir, f"element_{i}.obj"))

    # Merge all surfaces and bridges, and save as single lens.obj file
    surf_polydata = [
        surf.get_polydata()
        for surf in surf_meshes
        if not isinstance(surf, Aperture)
    ]
    bridge_polydata = [
        b.get_polydata() for group in bridge_meshes for b in group
    ]  # flatten the nested list
    lens_polydata = surf_polydata + bridge_polydata
    lens_polydata = merge(lens_polydata)
    lens_polydata.save(os.path.join(save_dir, "lens.obj"))

    # Save sensor
    sensor_polydata = sensor_mesh.get_polydata()
    sensor_polydata.save(os.path.join(save_dir, "sensor.obj"))

    # Save rays
    if save_rays:
        rays_curve = geolens_ray_poly(
            self, fovs, fov_phis, n_rings=ray_rings, n_arms=ray_arms
        )
        rays_poly_list = [curve_list_to_polydata(r) for r in rays_curve]
        rays_poly_fov = [merge(r) for r in rays_poly_list]
        for i, r in enumerate(rays_poly_fov):
            r.save(os.path.join(save_dir, f"lens_rays_fov_{i}.obj"))
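
The file names written by the source above can be summarised with a short, standalone sketch. It does not call DeepLens. The helper `expected_obj_files` is hypothetical, `n_elements` stands in for `len(element_groups)`, and `n_ray_groups` stands in for the length of the list returned by `geolens_ray_poly`. How that length relates to `fovs` and `fov_phis` is an assumption left to the caller.

```python
from typing import List


def expected_obj_files(
    n_elements: int,
    n_ray_groups: int = 0,
    save_elements: bool = True,
    save_rays: bool = False,
) -> List[str]:
    """List the .obj file names the save routine above would write.

    Hypothetical helper for illustration; mirrors the naming in the source.
    """
    names: List[str] = []
    if save_elements:
        # One file per lens element (surfaces and bridges merged).
        names += [f"element_{i}.obj" for i in range(n_elements)]
    # The merged lens and the sensor are always written.
    names += ["lens.obj", "sensor.obj"]
    if save_rays:
        # One file per ray group.
        names += [f"lens_rays_fov_{i}.obj" for i in range(n_ray_groups)]
    return names


files = expected_obj_files(n_elements=3, n_ray_groups=2, save_rays=True)
```

A listing like this can be useful for checking an export directory after the call, for example before importing the meshes into Blender.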