
GeoLens API Reference

The GeoLens class is the primary lens model in AutoLens: a differentiable multi-element refractive lens modeled with geometric ray tracing.


GeoLens

Differentiable multi-element refractive lens. Composes nine mixin classes for PSF computation, evaluation, optimization, I/O, visualization, and tolerancing.

src.GeoLens

GeoLens(filename=None, device=None, dtype=torch.float32)

Bases: GeoLensPSF, GeoLensEval, GeoLensSeidel, GeoLensOptim, GeoLensSurfOps, GeoLensVis, GeoLensIO, GeoLensTolerance, GeoLensVis3D, Lens

Differentiable geometric lens using vectorised ray tracing.

The primary lens model in DeepLens. Supports multi-element refractive (and partially reflective) systems loaded from JSON, Zemax .zmx, or CODE V .seq files. Its ray-tracing accuracy is aligned with Zemax OpticStudio.

Uses a mixin architecture, in which nine specialised mixin classes are composed at class definition time to keep each concern isolated:

  • GeoLensPSF (deeplens.optics.geolens_pkg.psf_compute) – PSF computation (geometric, coherent, and Huygens models).
  • GeoLensEval (deeplens.optics.geolens_pkg.eval) – optical performance evaluation (spot, MTF, distortion, vignetting).
  • GeoLensSeidel – Seidel aberration analysis.
  • GeoLensOptim (deeplens.optics.geolens_pkg.optim) – loss functions and gradient-based optimisation.
  • GeoLensSurfOps (deeplens.optics.geolens_pkg.optim_ops) – surface geometry operations (aspheric conversion, pruning, shape correction, material matching).
  • GeoLensVis (deeplens.optics.geolens_pkg.vis) – 2-D layout and ray visualisation.
  • GeoLensIO (deeplens.optics.geolens_pkg.io) – read/write JSON and Zemax .zmx files.
  • GeoLensTolerance (deeplens.optics.geolens_pkg.eval_tolerance) – manufacturing tolerance analysis.
  • GeoLensVis3D (deeplens.optics.geolens_pkg.vis3d) – 3-D mesh visualisation.

Key differentiability trick: ray-surface intersection (Surface.newtons_method in deeplens.optics.geometric_surface.base) runs a non-differentiable Newton loop to convergence, then takes one differentiable Newton step from the converged point, so gradients flow through the intersection without backpropagating through every iteration.
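A minimal pure-Python sketch of why a single Newton step is enough, assuming a toy 2-D ray and a hypothetical parabolic surface sag z = D + c·x²/2 (not the library's surface model). The converged root t* is treated as a constant, as it would be after a no-grad loop; differentiating one Newton step from t* with respect to the curvature c reproduces the implicit-function-theorem derivative dt*/dc = -f_c / f_t. Central finite differences stand in for autograd here.

```python
# Toy ray: origin (X0, Z0), direction (DX, DZ); hypothetical surface z = D + c * x**2 / 2.
X0, Z0, DX, DZ, D = 0.5, 0.0, 0.1, 1.0, 5.0


def f(t, c):
    """Residual: ray z-coordinate minus surface sag at the ray's x-coordinate."""
    x = X0 + t * DX
    return (Z0 + t * DZ) - (D + 0.5 * c * x * x)


def f_t(t, c):
    """Partial derivative of f with respect to the ray parameter t."""
    x = X0 + t * DX
    return DZ - c * x * DX


def newton_root(c, t=0.0, iters=20):
    """Newton loop run to convergence; in PyTorch this part would sit under torch.no_grad()."""
    for _ in range(iters):
        t = t - f(t, c) / f_t(t, c)
    return t


def one_newton_step(c, t_fixed):
    """Single Newton step from a detached starting point: the only 'differentiable' part."""
    return t_fixed - f(t_fixed, c) / f_t(t_fixed, c)


c, h = 0.2, 1e-6
t_star = newton_root(c)

# Derivative of the single step w.r.t. c, holding t_star fixed (as autograd would after detaching).
grad_one_step = (one_newton_step(c + h, t_star) - one_newton_step(c - h, t_star)) / (2 * h)

# Implicit-function-theorem reference: dt*/dc = -f_c / f_t, with f_c = -x**2 / 2.
x_star = X0 + t_star * DX
grad_implicit = 0.5 * x_star**2 / f_t(t_star, c)

# Derivative of the true root, re-solving the intersection at c +/- h.
grad_root = (newton_root(c + h) - newton_root(c - h)) / (2 * h)

print(f"t* = {t_star:.6f}")
print(f"one-step grad = {grad_one_step:.6f}, implicit = {grad_implicit:.6f}, root FD = {grad_root:.6f}")
```

Because f(t*, c) = 0, the derivative of the one-step update collapses to -f_c / f_t, which is why gradient flow does not require differentiating through the whole loop.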

Attributes:

  • surfaces (list[Surface]): Ordered list of optical surfaces.
  • materials (list[Material]): Optical materials between surfaces.
  • d_sensor (Tensor): Back focal distance [mm].
  • foclen (float): Effective focal length [mm].
  • fnum (float): F-number.
  • rfov_eff (float): Effective half-diagonal field of view [radians] (pinhole model).
  • rfov (float): Half-diagonal field of view [radians] (ray-traced).
  • sensor_size (tuple): Physical sensor size (W, H) [mm].
  • sensor_res (tuple): Sensor resolution (W, H) [pixels].
  • pixel_size (float): Pixel pitch [mm].

References

Xinge Yang et al., "Curriculum learning for ab initio deep learned refractive optics," Nature Communications 2024.

Initialize a refractive lens.

There are two ways to initialize a GeoLens:
  1. Read a lens from a .json, .zmx, or .seq file.
  2. Create a lens without a lens file, then add surfaces and materials manually.

Parameters:

  • filename (str, default: None): Path to the lens file (.json, .zmx, or .seq).
  • device (torch.device, default: None): Device for tensor computations.
  • dtype (torch.dtype, default: torch.float32): Data type for computations.

Source code in src/geolens.py
def __init__(
    self,
    filename=None,
    device=None,
    dtype=torch.float32,
):
    """Initialize a refractive lens.

    There are two ways to initialize a GeoLens:
        1. Read a lens from .json/.zmx/.seq file
        2. Initialize a lens with no lens file, then manually add surfaces and materials

    Args:
        filename (str, optional): Path to lens file (.json, .zmx, or .seq). Defaults to None.
        device (torch.device, optional): Device for tensor computations. Defaults to None.
        dtype (torch.dtype, optional): Data type for computations. Defaults to torch.float32.
    """
    super().__init__(device=device, dtype=dtype)

    self.aper_idx = None

    # Load lens file
    if filename is not None:
        self.read_lens(filename)
    else:
        self.surfaces = []
        self.materials = []
        # Set default sensor size and resolution
        self.sensor_size = (8.0, 8.0)
        self.sensor_res = (2000, 2000)
        self.to(self.device)

read_lens

read_lens(filename)

Read a GeoLens from a file.

Supported file formats:
  • .json: DeepLens native JSON format
  • .zmx: Zemax lens file format
  • .seq: CODE V sequence file format

A .txt file (deprecated format) or any other extension raises ValueError.

Parameters:

  • filename (str, required): Path to the lens file.

Note

Sensor size and resolution are taken from the file when it defines them; otherwise the defaults (8.0, 8.0) mm and (2000, 2000) pixels are used and a message is printed.

Source code in src/geolens.py
def read_lens(self, filename):
    """Read a GeoLens from a file.

    Supported file formats:
        - .json: DeepLens native JSON format
        - .zmx: Zemax lens file format
        - .seq: CODE V sequence file format

    Args:
        filename (str): Path to the lens file.

    Note:
        Sensor size and resolution will usually be overwritten by values from the file.
    """
    # Load lens file
    if filename[-4:] == ".txt":
        raise ValueError("File format .txt has been deprecated.")
    elif filename[-5:] == ".json":
        self.read_lens_json(filename)
    elif filename[-4:] == ".zmx":
        self.read_lens_zmx(filename)
    elif filename[-4:] == ".seq":
        self.read_lens_seq(filename)
    else:
        raise ValueError(f"File format {filename[-4:]} not supported.")

    # Complete sensor size and resolution if not set from lens file
    if not hasattr(self, "sensor_size"):
        self.sensor_size = (8.0, 8.0)
        print(
            f"Sensor_size not found in lens file. Using default: {self.sensor_size} mm. "
            "Consider specifying sensor_size in the lens file or using set_sensor()."
        )

    if not hasattr(self, "sensor_res"):
        self.sensor_res = (2000, 2000)
        print(
            f"Sensor_res not found in lens file. Using default: {self.sensor_res} pixels. "
            "Consider specifying sensor_res in the lens file or using set_sensor()."
        )
        self.set_sensor_res(self.sensor_res)

    # After loading lens, find aperture and compute derived properties
    self.to(self.device)
    self.astype(self.dtype)
    if self.aper_idx is None:
        self.find_aperture()
    self.post_computation()
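The suffix checks in read_lens() compare fixed-length tails of the filename. A sketch of the same dispatch keyed on pathlib's suffix, assuming a case-insensitive match is acceptable (the code above is case-sensitive); the dispatch table and pick_reader are hypothetical, with plain strings standing in for the read_lens_json, read_lens_zmx, and read_lens_seq methods.

```python
from pathlib import Path

# Hypothetical dispatch table: file suffix -> name of the reader method to call.
READERS = {
    ".json": "read_lens_json",
    ".zmx": "read_lens_zmx",
    ".seq": "read_lens_seq",
}


def pick_reader(filename):
    """Return the reader method name for a lens file, mirroring the checks in read_lens()."""
    suffix = Path(filename).suffix.lower()
    if suffix == ".txt":
        raise ValueError("File format .txt has been deprecated.")
    if suffix not in READERS:
        raise ValueError(f"File format {suffix!r} not supported.")
    return READERS[suffix]


print(pick_reader("lenses/cellphone.JSON"))  # read_lens_json
```

Inside a class, getattr(self, pick_reader(filename))(filename) would then call the selected reader.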

post_computation

post_computation()

Compute derived optical properties after loading or modifying lens.

Calculates and caches:
  • Effective focal length (EFL)
  • Entrance and exit pupil positions and radii
  • Field of view (FoV) in horizontal, vertical, and diagonal directions
  • F-number
Note

This method should be called after any changes to the lens geometry.

Source code in src/geolens.py
def post_computation(self):
    """Compute derived optical properties after loading or modifying lens.

    Calculates and caches:
        - Effective focal length (EFL)
        - Entrance and exit pupil positions and radii
        - Field of view (FoV) in horizontal, vertical, and diagonal directions
        - F-number

    Note:
        This method should be called after any changes to the lens geometry.
    """
    if self.aper_idx is None:
        self.find_aperture()
    self.calc_foclen()
    self.calc_pupil()
    self.calc_fov()
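For orientation, a worked sketch of the standard paraxial relations behind two of these quantities, assuming F-number = EFL / entrance-pupil diameter and the pinhole (effective) half-diagonal FoV = atan(half sensor diagonal / EFL). The calc_pupil() and calc_fov() implementations are not shown on this page and may differ in detail (for example, rfov is ray-traced); the helper names are illustrative.

```python
import math


def f_number(foclen_mm, entrance_pupil_radius_mm):
    """Paraxial F-number: focal length divided by entrance-pupil diameter."""
    return foclen_mm / (2.0 * entrance_pupil_radius_mm)


def pinhole_half_diag_fov(foclen_mm, sensor_size_mm):
    """Half-diagonal FoV [rad] of an ideal pinhole camera with the same focal length."""
    w, h = sensor_size_mm
    half_diag = 0.5 * math.hypot(w, h)
    return math.atan(half_diag / foclen_mm)


# Example values: 50 mm lens, 12.5 mm entrance-pupil radius, 36 x 24 mm sensor.
fnum = f_number(50.0, 12.5)
rfov_eff = pinhole_half_diag_fov(50.0, (36.0, 24.0))
print(f"F/{fnum:.1f}, half-diagonal FoV = {math.degrees(rfov_eff):.2f} deg")
```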

__call__

__call__(ray)

Trace rays through the lens system.

Makes the GeoLens callable: lens(ray) is equivalent to lens.trace(ray) and returns the same (ray_out, ray_o_rec) tuple.

Source code in src/geolens.py
def __call__(self, ray):
    """Trace rays through the lens system.

    Makes the GeoLens callable, allowing ray tracing with function call syntax.
    """
    return self.trace(ray)

sample_grid_rays

sample_grid_rays(depth=float('inf'), num_grid=(11, 11), num_rays=SPP_PSF, wvln=DEFAULT_WAVE, uniform_fov=True, sample_more_off_axis=False, scale_pupil=1.0)

Sample a grid of rays from object space. If depth is infinite, parallel (collimated) rays are sampled at different field angles; if depth is finite, point-source rays are sampled from the object plane.

This function is typically used for PSF-map, RMS-error-map, and spot-diagram calculations.

Parameters:

  • depth (float, default: float('inf')): Sampling depth [mm].
  • num_grid (tuple or int, default: (11, 11)): Number of grid points in (x, y); an int n is treated as (n, n).
  • num_rays (int, default: SPP_PSF): Number of rays per grid point.
  • wvln (float, default: DEFAULT_WAVE): Ray wavelength [µm].
  • uniform_fov (bool, default: True): If True, sample evenly spaced field angles; if False, sample an evenly spaced grid on the object plane.
  • sample_more_off_axis (bool, default: False): If True, concentrate samples toward the edge of the field by mapping each normalized grid coordinate x to sign(x)·sqrt(|x|).
  • scale_pupil (float, default: 1.0): Scale factor for the pupil radius.

Returns:

  • ray (Ray): Ray object with shape [num_grid[1], num_grid[0], num_rays, 3].

Source code in src/geolens.py
@torch.no_grad()
def sample_grid_rays(
    self,
    depth=float("inf"),
    num_grid=(11, 11),
    num_rays=SPP_PSF,
    wvln=DEFAULT_WAVE,
    uniform_fov=True,
    sample_more_off_axis=False,
    scale_pupil=1.0,
):
    """Sample grid rays from object space.
        (1) If depth is infinite, sample parallel rays at different field angles.
        (2) If depth is finite, sample point source rays from the object plane.

    This function is usually used for (1) PSF map, (2) RMS error map, and (3) spot diagram calculation.

    Args:
        depth (float, optional): sampling depth. Defaults to float("inf").
        num_grid (tuple or int, optional): number of grid points. Defaults to (11, 11).
        num_rays (int, optional): number of rays. Defaults to SPP_PSF.
        wvln (float, optional): ray wvln. Defaults to DEFAULT_WAVE.
        uniform_fov (bool, optional): If True, sample uniform FoV angles.
        sample_more_off_axis (bool, optional): If True, sample more off-axis rays.
        scale_pupil (float, optional): Scale factor for pupil radius.

    Returns:
        ray (Ray object): Ray object. Shape [num_grid[1], num_grid[0], num_rays, 3]
    """
    # Normalize num_grid to a tuple if it's an int
    if isinstance(num_grid, int):
        num_grid = (num_grid, num_grid)

    # Calculate field angles for grid source. Top-left field has positive fov_x and negative fov_y
    x_list = [x for x in np.linspace(1, -1, num_grid[0])]
    y_list = [y for y in np.linspace(-1, 1, num_grid[1])]
    if sample_more_off_axis:
        x_list = [np.sign(x) * np.abs(x) ** 0.5 for x in x_list]
        y_list = [np.sign(y) * np.abs(y) ** 0.5 for y in y_list]

    # Calculate FoV_x and FoV_y
    if uniform_fov:
        # Sample uniform FoV angles
        fov_x_list = [x * self.vfov / 2 for x in x_list]
        fov_y_list = [y * self.hfov / 2 for y in y_list]
        fov_x_list = [float(np.rad2deg(fov_x)) for fov_x in fov_x_list]
        fov_y_list = [float(np.rad2deg(fov_y)) for fov_y in fov_y_list]
    else:
        # Sample uniform object grid
        fov_x_list = [np.arctan(x * np.tan(self.vfov / 2)) for x in x_list]
        fov_y_list = [np.arctan(y * np.tan(self.hfov / 2)) for y in y_list]
        fov_x_list = [float(np.rad2deg(fov_x)) for fov_x in fov_x_list]
        fov_y_list = [float(np.rad2deg(fov_y)) for fov_y in fov_y_list]

    # Sample rays (collimated or point source via unified API)
    rays = self.sample_from_fov(
        fov_x=fov_x_list,
        fov_y=fov_y_list,
        depth=depth,
        num_rays=num_rays,
        wvln=wvln,
        scale_pupil=scale_pupil,
    )
    return rays
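A NumPy sketch of the per-axis field-angle computation above. half_fov_rad stands for the relevant half field of view (the code uses self.vfov / 2 for x and self.hfov / 2 for y), and field_angles_deg is a hypothetical helper. With uniform_fov=True the angles are evenly spaced; with uniform_fov=False their tangents (object-plane heights) are evenly spaced; sample_more_off_axis warps the normalized coordinate with sign(u)·sqrt(|u|).

```python
import numpy as np


def field_angles_deg(num, half_fov_rad, uniform_fov=True, more_off_axis=False, descending=False):
    """Return field angles [deg] for one grid axis, mirroring sample_grid_rays()."""
    # x runs from +1 to -1 in the code (descending), y from -1 to +1.
    u = np.linspace(1, -1, num) if descending else np.linspace(-1, 1, num)
    if more_off_axis:
        # Push samples toward the field edge: |u|**0.5 >= |u| on [-1, 1].
        u = np.sign(u) * np.abs(u) ** 0.5
    if uniform_fov:
        ang = u * half_fov_rad                     # evenly spaced angles
    else:
        ang = np.arctan(u * np.tan(half_fov_rad))  # evenly spaced object heights
    return np.rad2deg(ang)


half = np.deg2rad(30.0)
print(field_angles_deg(5, half))                     # uniform angles
print(field_angles_deg(5, half, uniform_fov=False))  # uniform object grid
```

Both modes reach the same full-field angle at the grid edges; they differ only in how the interior samples are spaced.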

sample_radial_rays

sample_radial_rays(num_field=5, depth=float('inf'), num_rays=SPP_PSF, wvln=DEFAULT_WAVE, direction='y')

Sample radial rays at evenly spaced field angles, from on-axis (0) to the full field (rfov), along a chosen direction.

Parameters:

  • num_field (int, default: 5): Number of field angles from on-axis to full field.
  • depth (float, default: float('inf')): Object distance [mm]. Use float('inf') for collimated light.
  • num_rays (int, default: SPP_PSF): Rays per field position.
  • wvln (float, default: DEFAULT_WAVE): Wavelength [µm].
  • direction (str, default: 'y'): Sampling direction: "y" (meridional), "x" (sagittal), or "diagonal" (fov_x = fov_y at each sample, i.e. along the 45° diagonal). Any other value raises ValueError.

Returns:

  • Ray: Ray object with shape [num_field, num_rays, 3].

Source code in src/geolens.py
@torch.no_grad()
def sample_radial_rays(
    self,
    num_field=5,
    depth=float("inf"),
    num_rays=SPP_PSF,
    wvln=DEFAULT_WAVE,
    direction="y",
):
    """Sample radial rays at evenly-spaced field angles along a chosen direction.

    Args:
        num_field (int): Number of field angles from on-axis to full-field.
            Defaults to 5.
        depth (float): Object distance in mm. Use ``float('inf')`` for
            collimated light. Defaults to ``float('inf')``.
        num_rays (int): Rays per field position. Defaults to ``SPP_PSF``.
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.
        direction (str): Sampling direction —
            ``"y"`` (meridional, default),
            ``"x"`` (sagittal),
            ``"diagonal"`` (45°, x = y).

    Returns:
        Ray: Ray object with shape ``[num_field, num_rays, 3]``.
    """
    device = self.device
    fov_deg = float(np.rad2deg(self.rfov))
    fov_list = torch.linspace(0, fov_deg, num_field, device=device)

    if direction == "y":
        ray = self.sample_from_fov(
            fov_x=0.0, fov_y=fov_list, depth=depth, num_rays=num_rays, wvln=wvln
        )
    elif direction == "x":
        ray = self.sample_from_fov(
            fov_x=fov_list, fov_y=0.0, depth=depth, num_rays=num_rays, wvln=wvln
        )
    elif direction == "diagonal":
        # sample_from_fov creates a meshgrid; for pairwise diagonal, loop
        rays = [
            self.sample_from_fov(
                fov_x=f.item(), fov_y=f.item(), depth=depth, num_rays=num_rays, wvln=wvln
            )
            for f in fov_list
        ]
        ray_o = torch.stack([r.o for r in rays], dim=0)
        ray_d = torch.stack([r.d for r in rays], dim=0)
        ray = Ray(ray_o, ray_d, wvln, device=device)
    else:
        raise ValueError(f"Invalid direction: {direction!r}. Use 'x', 'y', or 'diagonal'.")
    return ray
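Because sample_from_fov() builds collimated directions as (tan fov_x, tan fov_y, 1), the "diagonal" option (fov_x = fov_y = θ) yields rays whose angle to the optical axis is atan(√2 · tan θ), which is larger than θ for 0 < θ < 90°. A short check of that relation; off_axis_angle_deg is an illustrative helper, not part of the library.

```python
import math


def off_axis_angle_deg(fov_x_deg, fov_y_deg):
    """Angle between the direction (tan fx, tan fy, 1) and the z-axis, in degrees."""
    tx = math.tan(math.radians(fov_x_deg))
    ty = math.tan(math.radians(fov_y_deg))
    return math.degrees(math.atan(math.hypot(tx, ty)))


for theta in (0.0, 10.0, 20.0):
    print(theta, round(off_axis_angle_deg(theta, theta), 3))
```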

sample_from_points

sample_from_points(points=[[0.0, 0.0, -10000.0]], num_rays=SPP_PSF, wvln=DEFAULT_WAVE, scale_pupil=1.0)

Sample rays from point sources in object space (absolute physical coordinates).

Used for PSF and chief ray calculation.

Parameters:

  • points (list or Tensor, default: [[0.0, 0.0, -10000.0]]): Ray origins with shape [3], [N, 3], or [Nx, Ny, 3].
  • num_rays (int, default: SPP_PSF): Number of rays per point.
  • wvln (float, default: DEFAULT_WAVE): Wavelength of rays.
  • scale_pupil (float, default: 1.0): Scale factor for the pupil radius.

Returns:

  • Ray: Sampled rays with shape (*points.shape[:-1], num_rays, 3).

Source code in src/geolens.py
@torch.no_grad()
def sample_from_points(
    self,
    points=[[0.0, 0.0, -10000.0]],
    num_rays=SPP_PSF,
    wvln=DEFAULT_WAVE,
    scale_pupil=1.0,
):
    """
    Sample rays from point sources in object space (absolute physical coordinates).

    Used for PSF and chief ray calculation.

    Args:
        points (list or Tensor): Ray origins in shape [3], [N, 3], or [Nx, Ny, 3].
        num_rays (int): Number of rays per point. Default: SPP_PSF.
        wvln (float): Wavelength of rays. Default: DEFAULT_WAVE.
        scale_pupil (float): Scale factor for pupil radius.

    Returns:
        Ray: Sampled rays with shape ``(\\*points.shape[:-1], num_rays, 3)``.
    """
    # Ray origin is given
    if not torch.is_tensor(points):
        ray_o = torch.tensor(points, device=self.device)
    else:
        ray_o = points.to(self.device)

    # Sample points on the pupil
    pupilz, pupilr = self.get_entrance_pupil()
    pupilr *= scale_pupil
    ray_o2 = self.sample_circle(
        r=pupilr, z=pupilz, shape=(*ray_o.shape[:-1], num_rays)
    )

    # Compute ray directions
    if len(ray_o.shape) == 1:
        # Input point shape is [3]
        ray_o = ray_o.unsqueeze(0).repeat(num_rays, 1)  # shape [num_rays, 3]
        ray_d = ray_o2 - ray_o

    elif len(ray_o.shape) == 2:
        # Input point shape is [N, 3]
        ray_o = ray_o.unsqueeze(1).repeat(1, num_rays, 1)  # shape [N, num_rays, 3]
        ray_d = ray_o2 - ray_o

    elif len(ray_o.shape) == 3:
        # Input point shape is [Nx, Ny, 3]
        ray_o = ray_o.unsqueeze(2).repeat(
            1, 1, num_rays, 1
        )  # shape [Nx, Ny, num_rays, 3]
        ray_d = ray_o2 - ray_o

    else:
        raise Exception("The shape of input object positions is not supported.")

    # Calculate rays
    rays = Ray(ray_o, ray_d, wvln, device=self.device)
    return rays
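The three shape branches above all do the same thing: give each point source a copy along a new num_rays axis and subtract it from the pupil samples. A NumPy sketch of that step using broadcasting instead of repeat, which also accepts any number of leading dimensions; it assumes uniform disk sampling on a pupil at z = pupil_z with radius pupil_r, and point_source_rays is a hypothetical name.

```python
import numpy as np


def point_source_rays(points, num_rays, pupil_z, pupil_r, rng=None):
    """Return (origins, directions), each of shape (*points.shape[:-1], num_rays, 3)."""
    rng = np.random.default_rng(0) if rng is None else rng
    points = np.asarray(points, dtype=float)
    batch = points.shape[:-1]

    # Uniform samples on the pupil disk (sqrt of a uniform variable gives uniform area density).
    theta = rng.uniform(0.0, 2.0 * np.pi, size=(*batch, num_rays))
    radius = pupil_r * np.sqrt(rng.uniform(0.0, 1.0, size=(*batch, num_rays)))
    pupil_pts = np.stack(
        (radius * np.cos(theta), radius * np.sin(theta), np.full_like(theta, pupil_z)), axis=-1
    )

    # Insert a num_rays axis and broadcast the sources against the pupil samples.
    origins = np.broadcast_to(points[..., None, :], pupil_pts.shape)
    directions = pupil_pts - origins
    return origins, directions


o, d = point_source_rays([[0.0, 0.0, -1000.0], [5.0, 0.0, -1000.0]], num_rays=8, pupil_z=2.0, pupil_r=1.0)
print(o.shape, d.shape)  # (2, 8, 3) (2, 8, 3)
```

np.broadcast_to returns a read-only view, which avoids materializing num_rays copies of each origin until they are actually needed.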

sample_from_fov

sample_from_fov(fov_x=[0.0], fov_y=[0.0], depth=float('inf'), num_rays=SPP_CALC, wvln=DEFAULT_WAVE, entrance_pupil=True, scale_pupil=1.0)

Sample rays from object space at given field angles.

For infinite depth, generates collimated parallel rays: origins are distributed on the entrance pupil and all rays in a field share the same direction determined by the FOV angle.

For finite depth, generates diverging point-source rays: the point source position is determined by FOV angle and depth, and rays fan out toward the entrance pupil.

Parameters:

  • fov_x (float or list, default: [0.0]): Field angle(s) in the xz plane [degrees].
  • fov_y (float or list, default: [0.0]): Field angle(s) in the yz plane [degrees].
  • depth (float, default: float('inf')): Object distance [mm]; float('inf') gives collimated rays, a finite value gives point-source rays.
  • num_rays (int, default: SPP_CALC): Number of rays per field point.
  • wvln (float, default: DEFAULT_WAVE): Wavelength [µm].
  • entrance_pupil (bool, default: True): If True, sample on the entrance pupil; otherwise on surface 0.
  • scale_pupil (float, default: 1.0): Scale factor for the pupil radius.

Returns:

  • Ray: Rays with shape [len(fov_y), len(fov_x), num_rays, 3]; each leading dimension is squeezed when the corresponding fov input is a scalar.

Source code in src/geolens.py
@torch.no_grad()
def sample_from_fov(
    self,
    fov_x=[0.0],
    fov_y=[0.0],
    depth=float("inf"),
    num_rays=SPP_CALC,
    wvln=DEFAULT_WAVE,
    entrance_pupil=True,
    scale_pupil=1.0,
):
    """Sample rays from object space at given field angles.

    For infinite depth, generates collimated parallel rays: origins are
    distributed on the entrance pupil and all rays in a field share the
    same direction determined by the FOV angle.

    For finite depth, generates diverging point-source rays: the point
    source position is determined by FOV angle and depth, and rays fan
    out toward the entrance pupil.

    Args:
        fov_x (float or list): Field angle(s) in the xz plane (degrees).
        fov_y (float or list): Field angle(s) in the yz plane (degrees).
        depth (float): Object distance in mm. ``float('inf')`` for
            collimated rays, finite for point-source rays.
        num_rays (int): Number of rays per field point.
        wvln (float): Wavelength in micrometers.
        entrance_pupil (bool): If True, sample on entrance pupil;
            otherwise on surface 0. Default: True.
        scale_pupil (float): Scale factor for pupil radius.

    Returns:
        Ray: Rays with shape ``[..., num_rays, 3]``, where leading dims
            are squeezed when the corresponding fov input is scalar.
    """
    # Track which inputs were scalar for output shape
    x_scalar = isinstance(fov_x, (float, int))
    y_scalar = isinstance(fov_y, (float, int))
    if x_scalar:
        fov_x = [float(fov_x)]
    if y_scalar:
        fov_y = [float(fov_y)]

    fov_x_rad = torch.tensor([fx * torch.pi / 180 for fx in fov_x], device=self.device)
    fov_y_rad = torch.tensor([fy * torch.pi / 180 for fy in fov_y], device=self.device)
    fov_x_grid, fov_y_grid = torch.meshgrid(fov_x_rad, fov_y_rad, indexing="xy")

    # Pupil position and radius
    if entrance_pupil:
        pupilz, pupilr = self.get_entrance_pupil()
    else:
        pupilz, pupilr = 0.0, self.surfaces[0].r
    pupilr *= scale_pupil

    if depth == float("inf"):
        # Collimated rays: origins on pupil, uniform direction per field
        ray_o = self.sample_circle(
            r=pupilr, z=pupilz, shape=[len(fov_y), len(fov_x), num_rays]
        )
        dx = torch.tan(fov_x_grid).unsqueeze(-1).expand_as(ray_o[..., 0])
        dy = torch.tan(fov_y_grid).unsqueeze(-1).expand_as(ray_o[..., 1])
        dz = torch.ones_like(ray_o[..., 2])
        ray_d = torch.stack((dx, dy, dz), dim=-1)

        if x_scalar:
            ray_o = ray_o.squeeze(1)
            ray_d = ray_d.squeeze(1)
        if y_scalar:
            ray_o = ray_o.squeeze(0)
            ray_d = ray_d.squeeze(0)

        rays = Ray(ray_o, ray_d, wvln, device=self.device)
        rays.prop_to(-1.0)

    else:
        # Point-source rays: origin at object point, fan toward pupil
        x = torch.tan(fov_x_grid) * depth
        y = torch.tan(fov_y_grid) * depth
        z = torch.full_like(x, depth)
        points = torch.stack((x, y, z), dim=-1)

        if x_scalar:
            points = points.squeeze(-2)
        if y_scalar:
            points = points.squeeze(0)

        rays = self.sample_from_points(
            points=points, num_rays=num_rays, wvln=wvln, scale_pupil=scale_pupil
        )

    return rays
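A NumPy sketch of the two geometries above for a small field grid, assuming the same meshgrid(..., indexing="xy") layout, so the leading dimensions are [len(fov_y), len(fov_x)]. Mirroring the code, collimated rays of a field share the direction (tan fov_x, tan fov_y, 1), and a point source sits at (tan fov_x · depth, tan fov_y · depth, depth). field_geometry is a hypothetical helper, the normalization is added only for readability, and depth=-1000.0 is an arbitrary example value.

```python
import numpy as np


def field_geometry(fov_x_deg, fov_y_deg, depth):
    """Per-field unit directions (collimated case) and object points (finite-depth case)."""
    fx, fy = np.meshgrid(np.deg2rad(fov_x_deg), np.deg2rad(fov_y_deg), indexing="xy")
    d = np.stack((np.tan(fx), np.tan(fy), np.ones_like(fx)), axis=-1)
    d /= np.linalg.norm(d, axis=-1, keepdims=True)  # normalize for readability
    points = np.stack((np.tan(fx) * depth, np.tan(fy) * depth, np.full_like(fx, depth)), axis=-1)
    return d, points


dirs, pts = field_geometry([0.0, 10.0, 20.0], [0.0, 5.0], depth=-1000.0)
print(dirs.shape, pts.shape)  # (2, 3, 3) (2, 3, 3)
```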

sample_sensor

sample_sensor(spp=64, wvln=DEFAULT_WAVE, sub_pixel=False)

Sample rays from sensor pixels (backward rays) for ray-tracing rendering. Each ray starts on the sensor plane (z = d_sensor) and points toward a random point on the exit pupil.

Parameters:

  • spp (int, default: 64): Samples per pixel.
  • wvln (float, default: DEFAULT_WAVE): Ray wavelength [µm].
  • sub_pixel (bool, default: False): If True, offset each sample origin by a random amount of up to one pixel pitch in x and y; otherwise all samples of a pixel share the same origin.

Returns:

  • ray (Ray): Ray object with shape [H, W, spp, 3], where (W, H) = sensor_res.

Source code in src/geolens.py
@torch.no_grad()
def sample_sensor(self, spp=64, wvln=DEFAULT_WAVE, sub_pixel=False):
    """Sample rays from sensor pixels (backward rays). Used for ray tracing rendering.

    Args:
        spp (int, optional): Samples per pixel. Defaults to 64.
        wvln (float, optional): Ray wavelength. Defaults to DEFAULT_WAVE.
        sub_pixel (bool, optional): whether to sample multiple points inside the pixel. Defaults to False.

    Returns:
        ray (Ray object): Ray object. Shape [H, W, spp, 3]
    """
    w, h = self.sensor_size
    W, H = self.sensor_res
    device = self.device

    # Sample points on sensor plane
    # Use top-left point as reference in rendering, so here we should sample bottom-right point
    x1, y1 = torch.meshgrid(
        torch.linspace(
            -w / 2,
            w / 2,
            W + 1,
            device=device,
        )[1:],
        torch.linspace(
            h / 2,
            -h / 2,
            H + 1,
            device=device,
        )[1:],
        indexing="xy",
    )
    z1 = torch.full_like(x1, self.d_sensor)

    # Sample second points on the pupil
    pupilz, pupilr = self.get_exit_pupil()
    ray_o2 = self.sample_circle(r=pupilr, z=pupilz, shape=(H, W, spp))

    # Form rays
    ray_o = torch.stack((x1, y1, z1), 2)
    ray_o = ray_o.unsqueeze(2).repeat(1, 1, spp, 1)  # [H, W, spp, 3]

    # Sub-pixel sampling for more realistic rendering
    if sub_pixel:
        delta_ox = (
            torch.rand(ray_o.shape[:-1], device=device)
            * self.pixel_size
        )
        delta_oy = (
            -torch.rand(ray_o.shape[:-1], device=device)
            * self.pixel_size
        )
        delta_oz = torch.zeros_like(delta_ox)
        delta_o = torch.stack((delta_ox, delta_oy, delta_oz), -1)
        ray_o = ray_o + delta_o

    # Form rays
    ray_d = ray_o2 - ray_o  # shape [H, W, spp, 3]
    ray = Ray(ray_o, ray_d, wvln, device=device)
    return ray
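To make the sensor sampling concrete: linspace(-w/2, w/2, W + 1)[1:] keeps the right edge of each pixel column and linspace(h/2, -h/2, H + 1)[1:] keeps the bottom edge of each pixel row, so every pixel is represented by its bottom-right corner, as the code comment says. A NumPy sketch for a tiny 4 × 2-pixel sensor (sizes are arbitrary example values; sensor_sample_points is a hypothetical helper):

```python
import numpy as np


def sensor_sample_points(sensor_size, sensor_res):
    """Bottom-right corner (x, y) of every pixel, shape [H, W, 2], mirroring sample_sensor()."""
    w, h = sensor_size
    W, H = sensor_res
    xs = np.linspace(-w / 2, w / 2, W + 1)[1:]  # right edge of each column
    ys = np.linspace(h / 2, -h / 2, H + 1)[1:]  # bottom edge of each row (top row first)
    x, y = np.meshgrid(xs, ys, indexing="xy")
    return np.stack((x, y), axis=-1)


pts = sensor_sample_points(sensor_size=(8.0, 4.0), sensor_res=(4, 2))
print(pts[..., 0])  # every row: x = -2, 0, 2, 4
print(pts[..., 1])  # row 0: y = 0; row 1: y = -2
```

With sub_pixel=True, the code then adds U(0, 1)·pixel_size to x and subtracts U(0, 1)·pixel_size from y at each of these points.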

sample_circle

sample_circle(r, z, shape=[16, 16, 512])

Randomly sample points inside a circle, uniformly distributed over its area.

Parameters:

  • r (float, required): Radius of the circle.
  • z (float, required): Z-coordinate for all sampled points.
  • shape (list, default: [16, 16, 512]): Shape of the output tensor, excluding the final xyz dimension.

Returns:

  • torch.Tensor: Sampled points with shape (*shape, 3).

Source code in src/geolens.py
def sample_circle(self, r, z, shape=[16, 16, 512]):
    """Sample points inside a circle.

    Args:
        r (float): Radius of the circle.
        z (float): Z-coordinate for all sampled points.
        shape (list): Shape of the output tensor.

    Returns:
        torch.Tensor: Sampled points, shape ``(\\*shape, 3)``.
    """
    device = self.device

    # Generate random angles and radii
    theta = torch.rand(*shape, device=device) * 2 * torch.pi
    r2 = torch.rand(*shape, device=device) * r**2
    radius = torch.sqrt(r2)

    # Stack to form 3D points
    x = radius * torch.cos(theta)
    y = radius * torch.sin(theta)
    z_tensor = torch.full_like(x, z)
    points = torch.stack((x, y, z_tensor), dim=-1)

    # Manually sample chief ray
    # points[..., 0, :2] = 0.0

    return points
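Drawing r² (not r) uniformly and taking the square root is what makes sample_circle() uniform over the disk area: the fraction of points within radius ρ is then (ρ/r)². A NumPy check of that property, with an arbitrary seed and sample count; sample_disk is an illustrative re-implementation, not the library function.

```python
import numpy as np


def sample_disk(r, z, shape, rng):
    """Uniform-area samples in a disk of radius r on the plane z; returns (*shape, 3)."""
    theta = rng.uniform(0.0, 2.0 * np.pi, size=shape)
    radius = np.sqrt(rng.uniform(0.0, r**2, size=shape))  # sqrt of uniform r^2
    return np.stack((radius * np.cos(theta), radius * np.sin(theta), np.full(shape, z)), axis=-1)


rng = np.random.default_rng(42)
pts = sample_disk(r=2.0, z=5.0, shape=(200_000,), rng=rng)
rho = np.hypot(pts[..., 0], pts[..., 1])
print(f"inside r/2: {np.mean(rho <= 1.0):.3f} (uniform-area target 0.25)")

naive = 2.0 * rng.uniform(0.0, 1.0, size=200_000)  # radius drawn uniformly instead
print(f"naive inside r/2: {np.mean(naive <= 1.0):.3f} (over-samples the center)")
```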

trace

trace(ray, surf_range=None, record=False)

Trace rays through the lens.

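The tracing direction is chosen automatically from the ray directions: if any ray has a positive z-direction component, the whole batch is traced forward (object side to image side); otherwise it is traced backward.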

Parameters:

  • ray (Ray, required): Rays to trace.
  • surf_range (range or list, default: None): Indices of the surfaces to trace through; None means all surfaces.
  • record (bool, default: False): If True, record the ray path.

Returns:

  • ray_final (Ray): Rays after the optical system.
  • ray_o_rec (list or None): Intersection points at each surface, or None if record is False.

Source code in src/geolens.py
def trace(self, ray, surf_range=None, record=False):
    """Trace rays through the lens.

    Forward or backward tracing is automatically determined by the ray direction.

    Args:
        ray (Ray object): Ray object.
        surf_range (list): Surface index range.
        record (bool): record ray path or not.

    Returns:
        ray_final (Ray object): ray after optical system.
        ray_o_rec (list): list of intersection points.
    """
    if surf_range is None:
        surf_range = range(0, len(self.surfaces))

    if (ray.d[..., 2] > 0).any():
        ray_out, ray_o_rec = self.forward_tracing(ray, surf_range, record=record)
    else:
        ray_out, ray_o_rec = self.backward_tracing(ray, surf_range, record=record)

    return ray_out, ray_o_rec
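The direction test in trace() is batch-wide: a single ray with a positive z-direction sends the whole batch through forward_tracing(). A pure-Python sketch of that decision and the resulting surface visiting order; surface_order is a hypothetical helper.

```python
def surface_order(dz_values, num_surfaces, surf_range=None):
    """Return ("forward" or "backward", surface indices in the order they are visited)."""
    if surf_range is None:
        surf_range = range(num_surfaces)
    if any(dz > 0 for dz in dz_values):
        return "forward", list(surf_range)
    return "backward", list(reversed(surf_range))


print(surface_order([1.0, 0.9], num_surfaces=4))    # ('forward', [0, 1, 2, 3])
print(surface_order([-1.0, -0.5], num_surfaces=4))  # ('backward', [3, 2, 1, 0])
print(surface_order([-1.0, 0.2], num_surfaces=4))   # ('forward', [0, 1, 2, 3])
```

The last call shows the consequence: mixing forward and backward rays in one batch is not supported, so such rays should be traced in separate calls.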

trace2obj

trace2obj(ray)

Traces rays backwards through all lens surfaces from sensor side to object side.

Parameters:

  • ray (Ray, required): Ray object to trace backwards.

Returns:

  • Ray: Ray object after backward propagation through the lens.

Source code in src/geolens.py
def trace2obj(self, ray):
    """Traces rays backwards through all lens surfaces from sensor side
    to object side.

    Args:
        ray (Ray): Ray object to trace backwards.

    Returns:
        Ray: Ray object after backward propagation through the lens.
    """
    ray, _ = self.trace(ray)
    return ray

trace2sensor

trace2sensor(ray, record=False)

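Trace rays forward through the lens to the sensor plane. If any ray origin lies at z < -100 mm, all rays are first propagated to z = -10 mm to avoid numerical instability.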

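Parameters:

  • ray (Ray, required): Rays to trace.
  • record (bool, default: False): If True, record the ray path.

Returns:

  • ray_out (Ray): Rays on the sensor plane.
  • ray_o_record (list): Intersection points at each surface and at the sensor, with invalid rays set to NaN. Returned, as the second element of a tuple, only when record is True; otherwise only ray_out is returned.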

Source code in src/geolens.py
def trace2sensor(self, ray, record=False):
    """Forward trace rays through the lens to sensor plane.

    Args:
        ray (Ray object): Ray object.
        record (bool): record ray path or not.

    Returns:
        ray_out (Ray object): ray after optical system.
        ray_o_record (list): list of intersection points.
    """
    # Manually propagate ray to a shallow depth to avoid numerical instability
    if ray.o[..., 2].min() < -100.0:
        ray = ray.prop_to(-10.0)

    # Trace rays
    ray, ray_o_record = self.trace(ray, record=record)
    ray = ray.prop_to(self.d_sensor)

    if record:
        ray_o = ray.o.clone().detach()
        # Set to NaN to be skipped in 2d layout visualization
        ray_o[ray.is_valid == 0] = float("nan")
        ray_o_record.append(ray_o)
        return ray, ray_o_record
    else:
        return ray
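trace2sensor() relies on Ray.prop_to(z) to move rays to a plane. Ray's implementation is not shown on this page; the sketch below assumes the usual straight-line propagation o + t·d with t = (z - o_z) / d_z, and prop_to_plane is a hypothetical helper.

```python
import numpy as np


def prop_to_plane(o, d, z):
    """Propagate rays (origins o, directions d; shape [..., 3]) to the plane at z."""
    o = np.asarray(o, dtype=float)
    d = np.asarray(d, dtype=float)
    t = (z - o[..., 2]) / d[..., 2]          # ray parameter at the target plane
    return o + np.expand_dims(t, -1) * d


# A ray 10 m in front of the lens, moved to z = -10 mm as trace2sensor() does.
o = np.array([[0.0, 0.0, -10000.0]])
d = np.array([[0.01, 0.0, 1.0]])
print(prop_to_plane(o, d, -10.0))  # x = 0.01 * 9990 = 99.9 mm at z = -10 mm
```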

trace2exit_pupil

trace2exit_pupil(ray)

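Trace rays forward through the lens to the sensor plane, then propagate them along straight lines back to the exit-pupil plane returned by get_exit_pupil().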

Parameters:

  • ray (Ray, required): Ray object to trace.

Returns:

  • Ray: Ray object propagated to the exit pupil plane.

Source code in src/geolens.py
def trace2exit_pupil(self, ray):
    """Forward trace rays through the lens to exit pupil plane.

    Args:
        ray (Ray): Ray object to trace.

    Returns:
        Ray: Ray object propagated to the exit pupil plane.
    """
    ray = self.trace2sensor(ray)
    pupil_z, _ = self.get_exit_pupil()
    ray = ray.prop_to(pupil_z)
    return ray

forward_tracing

forward_tracing(ray, surf_range, record)

Trace rays forward through each surface in the specified range, from the object side to the image side.

Parameters:

  • ray (Ray, required): Ray object to trace.
  • surf_range (range, required): Range of surface indices to trace through.
  • record (bool, required): If True, record ray positions at each surface.

Returns:

  • tuple (ray_out, ray_o_record):
      - ray_out (Ray): Ray after propagation through all surfaces.
      - ray_o_record (list or None): List of ray positions at each surface, or None if record is False.

Source code in src/geolens.py
def forward_tracing(self, ray, surf_range, record):
    """Forward traces rays through each surface in the specified range from object side to image side.

    Args:
        ray (Ray): Ray object to trace.
        surf_range (range): Range of surface indices to trace through.
        record (bool): If True, record ray positions at each surface.

    Returns:
        tuple: (ray_out, ray_o_record) where:
            - ray_out (Ray): Ray after propagation through all surfaces.
            - ray_o_record (list or None): List of ray positions at each surface,
                or None if record is False.
    """
    if record:
        ray_o_record = []
        ray_o_record.append(ray.o.clone().detach())
    else:
        ray_o_record = None

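    # Initial material: the medium in front of the first surface in the range
    # (air unless the range starts inside the lens), mirroring backward_tracing.
    min_idx = min(surf_range)
    mat1 = self.surfaces[min_idx - 1].mat2 if min_idx > 0 else Material("air")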
    for i in surf_range:
        n1 = mat1.ior(ray.wvln)
        n2 = self.surfaces[i].mat2.ior(ray.wvln)
        ray = self.surfaces[i].ray_reaction(ray, n1, n2)
        mat1 = self.surfaces[i].mat2

        if record:
            ray_out_o = ray.o.clone().detach()
            ray_out_o[ray.is_valid == 0] = float("nan")
            ray_o_record.append(ray_out_o)

    return ray, ray_o_record

backward_tracing

backward_tracing(ray, surf_range, record)

Trace rays backward through each surface in reverse order, from the image side to the object side.

Parameters:

  • ray (Ray, required): Ray object to trace.
  • surf_range (range, required): Range of surface indices to trace through.
  • record (bool, required): If True, record ray positions at each surface.

Returns:

  • tuple (ray_out, ray_o_record):
      - ray_out (Ray): Ray after backward propagation through all surfaces.
      - ray_o_record (list or None): List of ray positions at each surface, or None if record is False.

Source code in src/geolens.py
def backward_tracing(self, ray, surf_range, record):
    """Backward traces rays through each surface in reverse order from image side to object side.

    Args:
        ray (Ray): Ray object to trace.
        surf_range (range): Range of surface indices to trace through.
        record (bool): If True, record ray positions at each surface.

    Returns:
        tuple: (ray_out, ray_o_record) where:
            - ray_out (Ray): Ray after backward propagation through all surfaces.
            - ray_o_record (list or None): List of ray positions at each surface,
                or None if record is False.
    """
    if record:
        ray_o_record = []
        ray_o_record.append(ray.o.clone().detach())
    else:
        ray_o_record = None

    # Initial material: the material the ray is in when entering the
    # backward trace. If the range ends before the last surface, the ray
    # starts inside surfaces[max_idx].mat2, not air.
    max_idx = max(surf_range)
    if max_idx < len(self.surfaces) - 1:
        mat1 = self.surfaces[max_idx].mat2
    else:
        mat1 = Material("air")

    for i in np.flip(surf_range):
        n1 = mat1.ior(ray.wvln)
        n2 = self.surfaces[i - 1].mat2.ior(ray.wvln) if i > 0 else Material("air").ior(ray.wvln)
        ray = self.surfaces[i].ray_reaction(ray, n1, n2)
        mat1 = self.surfaces[i - 1].mat2 if i > 0 else Material("air")

        if record:
            ray_out_o = ray.o.clone().detach()
            ray_out_o[ray.is_valid == 0] = float("nan")
            ray_o_record.append(ray_out_o)

    return ray, ray_o_record
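To make the material bookkeeping of the backward loop easier to follow, the plain-Python sketch below pairs up incident and outgoing media per surface for both trace directions, using material names instead of Material objects. The helper index_pairs and the list mats (where mats[i] stands in for surfaces[i].mat2) are hypothetical; object and image space are assumed to be air, as in the source above.

```python
def index_pairs(mats, surf_range, backward=False):
    """List (surface index, incident medium, outgoing medium) per surface.

    mats[i] plays the role of surfaces[i].mat2: the medium that follows
    surface i in the forward (object -> image) direction.
    """
    pairs = []
    if not backward:
        medium = "air"  # forward_tracing always starts in air
        for i in surf_range:
            pairs.append((i, medium, mats[i]))
            medium = mats[i]
        return pairs

    # Backward: start in the medium behind the last traced surface
    max_idx = max(surf_range)
    medium = mats[max_idx] if max_idx < len(mats) - 1 else "air"
    for i in reversed(surf_range):
        # The medium in front of surface i is the one after surface i - 1
        nxt = mats[i - 1] if i > 0 else "air"
        pairs.append((i, medium, nxt))
        medium = nxt
    return pairs


# Two singlets: N-BK7 between surfaces 0-1, F2 between surfaces 2-3
mats = ["N-BK7", "air", "F2", "air"]
for row in index_pairs(mats, range(4), backward=True):
    print(row)
```

Each backward pair is the matching forward pair with the two media swapped, which is why backward_tracing reads n2 from surfaces[i - 1].mat2 rather than surfaces[i].mat2.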

render

render(img_obj, depth=DEPTH, method='ray_tracing', **kwargs)

Differentiable image simulation.

Image simulation methods:

  [1] PSF map block convolution ('psf_map').
  [2] PSF patch convolution ('psf_patch').
  [3] Ray tracing rendering ('ray_tracing').

Parameters:

  • img_obj (Tensor, required): Input image object in raw space. Shape of [N, C, H, W]. For 'psf_map' and 'ray_tracing', H and W must match the sensor resolution.
  • depth (float, default DEPTH): Depth of the object.
  • method (str, default 'ray_tracing'): Image simulation method. One of 'psf_map', 'psf_patch', or 'ray_tracing'.
  • **kwargs: Additional method-specific arguments:
      - psf_grid (tuple): Grid size for the PSF map method. Defaults to (10, 10).
      - psf_ks (int): Kernel size for the PSF methods. Defaults to PSF_KS.
      - patch_center (tuple): Center position for the PSF patch method. Defaults to (0.0, 0.0).
      - spp (int): Samples per pixel for ray tracing. Defaults to SPP_RENDER.

Returns:

  • Tensor: Rendered image tensor. Shape of [N, C, H, W].

Source code in src/geolens.py
def render(self, img_obj, depth=DEPTH, method="ray_tracing", **kwargs):
    """Differentiable image simulation.

    Image simulation methods:
        [1] PSF map block convolution.
        [2] PSF patch convolution.
        [3] Ray tracing rendering.

    Args:
        img_obj (Tensor): Input image object in raw space. Shape of [N, C, H, W].
        depth (float, optional): Depth of the object. Defaults to DEPTH.
        method (str, optional): Image simulation method. One of 'psf_map', 'psf_patch',
            or 'ray_tracing'. Defaults to 'ray_tracing'.
        **kwargs: Additional arguments for different methods:
            - psf_grid (tuple): Grid size for PSF map method. Defaults to (10, 10).
            - psf_ks (int): Kernel size for PSF methods. Defaults to PSF_KS.
            - patch_center (tuple): Center position for PSF patch method. Defaults to (0.0, 0.0).
            - spp (int): Samples per pixel for ray tracing. Defaults to SPP_RENDER.

    Returns:
        Tensor: Rendered image tensor. Shape of [N, C, H, W].
    """
    B, C, Himg, Wimg = img_obj.shape
    Wsensor, Hsensor = self.sensor_res

    # Image simulation
    if method == "psf_map":
        # PSF rendering - uses PSF map to render image
        assert Wimg == Wsensor and Himg == Hsensor, (
            f"Sensor resolution {Wsensor}x{Hsensor} must match input image {Wimg}x{Himg}."
        )
        psf_grid = kwargs.get("psf_grid", (10, 10))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        img_render = self.render_psf_map(
            img_obj, depth=depth, psf_grid=psf_grid, psf_ks=psf_ks
        )

    elif method == "psf_patch":
        # PSF patch rendering - uses a single PSF to render a patch of the image
        patch_center = kwargs.get("patch_center", (0.0, 0.0))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        img_render = self.render_psf_patch(
            img_obj, depth=depth, patch_center=patch_center, psf_ks=psf_ks
        )

    elif method == "ray_tracing":
        # Ray tracing rendering
        assert Wimg == Wsensor and Himg == Hsensor, (
            f"Sensor resolution {Wsensor}x{Hsensor} must match input image {Wimg}x{Himg}."
        )
        spp = kwargs.get("spp", SPP_RENDER)
        img_render = self.render_raytracing(img_obj, depth=depth, spp=spp)

    else:
        raise Exception(f"Image simulation method {method} is not supported.")

    return img_render
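Method [2] applies a single PSF to the whole input patch, which amounts to an ordinary shift-invariant convolution per color channel; method [1] instead uses a grid of PSFs across the field. The NumPy sketch below illustrates only the single-PSF case. It is not the library's render_psf_patch: the function name convolve_psf_patch, the [C, k, k] PSF layout, and edge-replicated borders are assumptions made for illustration.

```python
import numpy as np


def convolve_psf_patch(img, psf):
    """Convolve an image batch with one PSF per channel (shift-invariant blur).

    img: array of shape [N, C, H, W].
    psf: array of shape [C, k, k] with odd k; each channel should sum to 1.
    Borders use edge replication, similar to "replicate" padding.
    """
    n, c, h, w = img.shape
    k = psf.shape[-1]
    r = k // 2
    padded = np.pad(img, ((0, 0), (0, 0), (r, r), (r, r)), mode="edge")
    out = np.zeros(img.shape, dtype=float)
    for dy in range(k):
        for dx in range(k):
            # Index the kernel flipped so this is a convolution, not a correlation
            weight = psf[:, k - 1 - dy, k - 1 - dx][None, :, None, None]
            out += weight * padded[:, :, dy:dy + h, dx:dx + w]
    return out


rng = np.random.default_rng(0)
img = rng.random((1, 3, 8, 8))
box = np.full((3, 3, 3), 1.0 / 9.0)  # 3x3 box blur for every channel
blurred = convolve_psf_patch(img, box)
print(blurred.shape)
```

A delta PSF leaves the image unchanged, which is a quick sanity check for any PSF convolution routine.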

render_raytracing

render_raytracing(img, depth=DEPTH, spp=SPP_RENDER, vignetting=False)

Render RGB image using ray tracing rendering.

Parameters:

  • img (Tensor, required): RGB image tensor. Shape of [N, 3, H, W].
  • depth (float, default DEPTH): Depth of the object.
  • spp (int, default SPP_RENDER): Samples per pixel.
  • vignetting (bool, default False): Whether to consider the vignetting effect.

Returns:

  • img_render (Tensor): Rendered RGB image tensor. Shape of [N, 3, H, W].

Source code in src/geolens.py
def render_raytracing(self, img, depth=DEPTH, spp=SPP_RENDER, vignetting=False):
    """Render RGB image using ray tracing rendering.

    Args:
        img (tensor): RGB image tensor. Shape of [N, 3, H, W].
        depth (float, optional): Depth of the object. Defaults to DEPTH.
        spp (int, optional): Samples per pixel. Defaults to SPP_RENDER.
        vignetting (bool, optional): Whether to consider the vignetting effect. Defaults to False.

    Returns:
        img_render (tensor): Rendered RGB image tensor. Shape of [N, 3, H, W].
    """
    img_render = torch.zeros_like(img)
    for i in range(3):
        img_render[:, i, :, :] = self.render_raytracing_mono(
            img=img[:, i, :, :],
            wvln=WAVE_RGB[i],
            depth=depth,
            spp=spp,
            vignetting=vignetting,
        )
    return img_render

render_raytracing_mono

render_raytracing_mono(img, wvln, depth=DEPTH, spp=64, vignetting=False)

Render monochrome image using ray tracing rendering.

Parameters:

  • img (Tensor, required): Monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].
  • wvln (float, required): Wavelength of the light.
  • depth (float, default DEPTH): Depth of the object.
  • spp (int, default 64): Samples per pixel.
  • vignetting (bool, default False): Whether to consider the vignetting effect.

Returns:

  • img_mono (Tensor): Rendered monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].

Source code in src/geolens.py
def render_raytracing_mono(self, img, wvln, depth=DEPTH, spp=64, vignetting=False):
    """Render monochrome image using ray tracing rendering.

    Args:
        img (tensor): Monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].
        wvln (float): Wavelength of the light.
        depth (float, optional): Depth of the object. Defaults to DEPTH.
        spp (int, optional): Samples per pixel. Defaults to 64.
        vignetting (bool, optional): Whether to consider the vignetting effect. Defaults to False.

    Returns:
        img_mono (tensor): Rendered monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].
    """
    img = torch.flip(img, [-2, -1])
    scale = self.calc_scale(depth=depth)
    ray = self.sample_sensor(spp=spp, wvln=wvln)
    ray = self.trace2obj(ray)
    img_mono = self.render_compute_image(
        img, depth=depth, scale=scale, ray=ray, vignetting=vignetting
    )
    return img_mono
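render_raytracing_mono flips the input along both spatial axes before tracing. The NumPy snippet below (NumPy standing in for torch, purely for illustration) confirms that this double flip is a 180-degree rotation. Presumably it compensates for the inverted image a positive lens forms, so the rendered sensor image comes out upright; that motivation is an inference, not something the source states.

```python
import numpy as np

img = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # [N, H, W] monochrome batch

# Flipping both spatial axes, as torch.flip(img, [-2, -1]) does ...
flipped = np.flip(img, axis=(-2, -1))

# ... is the same as rotating each image by 180 degrees
rotated = np.rot90(img, k=2, axes=(-2, -1))
assert np.array_equal(flipped, rotated)

# The top-left pixel moves to the bottom-right corner
print(img[0, 0, 0], flipped[0, -1, -1])
```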

render_compute_image

render_compute_image(img, depth, scale, ray, vignetting=False)

Computes the intersection points between rays and the object image plane, then generates the rendered image following the rendering equation.

Back-propagation gradient flow: image -> w_i -> u -> p -> ray -> surface

Parameters:

  • img (Tensor, required): Image tensor of shape [N, C, H, W] or [N, H, W].
  • depth (float, required): Depth of the object.
  • scale (float, required): Scale factor (object height / image height), as returned by calc_scale.
  • ray (Ray, required): Ray object with shape [H, W, spp, 3].
  • vignetting (bool, default False): Whether to consider the vignetting effect.

Returns:

  • image (Tensor): Rendered image tensor of shape [N, C, H, W] or [N, H, W].

Source code in src/geolens.py
def render_compute_image(self, img, depth, scale, ray, vignetting=False):
    """Computes the intersection points between rays and the object image plane, then generates the rendered image following the rendering equation.

    Back-propagation gradient flow: image -> w_i -> u -> p -> ray -> surface

    Args:
        img (tensor): [N, C, H, W] or [N, H, W] shape image tensor.
        depth (float): depth of the object.
        scale (float): scale factor.
        ray (Ray object): Ray object. Shape [H, W, spp, 3].
        vignetting (bool): whether to consider vignetting effect.

    Returns:
        image (tensor): [N, C, H, W] or [N, H, W] shape rendered image tensor.
    """
    assert torch.is_tensor(img), "Input image should be Tensor."

    # Padding
    H, W = img.shape[-2:]
    if len(img.shape) == 3:
        img = F.pad(img.unsqueeze(1), (1, 1, 1, 1), "replicate").squeeze(1)
    elif len(img.shape) == 4:
        img = F.pad(img, (1, 1, 1, 1), "replicate")
    else:
        raise ValueError("Input image should be [N, C, H, W] or [N, H, W] tensor.")

    # Scale object image physical size to get 1:1 pixel-pixel alignment with sensor image
    ray = ray.prop_to(depth)
    p = ray.o[..., :2]
    pixel_size = scale * self.pixel_size
    ray.is_valid = (
        ray.is_valid
        * (torch.abs(p[..., 0] / pixel_size) < (W / 2 + 1))
        * (torch.abs(p[..., 1] / pixel_size) < (H / 2 + 1))
    )

    # Convert to uv coordinates in object image coordinate
    # (we do padding so coordinates should add 1)
    u = torch.clamp(W / 2 + p[..., 0] / pixel_size, min=-0.99, max=W - 0.01)
    v = torch.clamp(H / 2 + p[..., 1] / pixel_size, min=0.01, max=H + 0.99)

    # (idx_i, idx_j) denotes left-top pixel (reference pixel). Index does not store gradients
    # (idx + 1 because we did padding)
    idx_i = H - v.ceil().long() + 1
    idx_j = u.floor().long() + 1

    # Gradients are stored in interpolation weight parameters
    w_i = v - v.floor().long()
    w_j = u.ceil().long() - u

    # Bilinear interpolation
    # (img shape [N, C, H+2, W+2] after padding, idx_i shape [H, W, spp], w_i shape [H, W, spp], irr_img shape [N, C, H, W, spp])
    irr_img = img[..., idx_i, idx_j] * w_i * w_j
    irr_img += img[..., idx_i + 1, idx_j] * (1 - w_i) * w_j
    irr_img += img[..., idx_i, idx_j + 1] * w_i * (1 - w_j)
    irr_img += img[..., idx_i + 1, idx_j + 1] * (1 - w_i) * (1 - w_j)

    # Computation image
    if not vignetting:
        image = torch.sum(irr_img * ray.is_valid, -1) / (
            torch.sum(ray.is_valid, -1) + EPSILON
        )
    else:
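        # Divide by the number of rays per pixel (spp), not by the total
        # number of elements in ray.is_valid ([H, W, spp])
        image = torch.sum(irr_img * ray.is_valid, -1) / ray.is_valid.shape[-1]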

    return image
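The last stage of render_compute_image bilinearly samples the object image at each ray's landing point and then reduces the per-ray samples of every pixel to one value. The NumPy sketch below isolates those two steps under simplifying assumptions: pixel centers at integer (row, col) coordinates, and border clamping instead of the source's padding and flipped row index. The function names are hypothetical. With vignetting=False the sketch averages over valid rays only; with vignetting=True it divides by the per-pixel ray count spp, so blocked rays darken the pixel.

```python
import numpy as np


def bilinear_sample(img, rows, cols):
    """Bilinearly sample an [H, W] image at float (row, col) positions.

    Pixel centers sit at integer coordinates; positions outside the image
    are clamped to the border, similar in spirit to replicate padding.
    """
    h, w = img.shape
    rows = np.clip(rows, 0, h - 1)
    cols = np.clip(cols, 0, w - 1)
    r0 = np.floor(rows).astype(int)
    c0 = np.floor(cols).astype(int)
    r1 = np.minimum(r0 + 1, h - 1)
    c1 = np.minimum(c0 + 1, w - 1)
    wr = rows - r0  # fractional offset toward the next row
    wc = cols - c0  # fractional offset toward the next column
    return (img[r0, c0] * (1 - wr) * (1 - wc)
            + img[r1, c0] * wr * (1 - wc)
            + img[r0, c1] * (1 - wr) * wc
            + img[r1, c1] * wr * wc)


def accumulate(samples, valid, vignetting=False, eps=1e-9):
    """Reduce per-ray samples of shape [..., spp] to one value per pixel."""
    valid = valid.astype(float)
    total = (samples * valid).sum(axis=-1)
    if vignetting:
        # Blocked rays contribute zero, so edge pixels darken
        return total / valid.shape[-1]
    # Average over valid rays only, so brightness is preserved
    return total / (valid.sum(axis=-1) + eps)


img = np.arange(12, dtype=float).reshape(3, 4)  # img[r, c] = 4 * r + c
print(bilinear_sample(img, np.array([0.5, 1.25]), np.array([0.5, 2.75])))

samples = np.ones((1, 4))  # one pixel, spp = 4
valid = np.array([[1, 1, 0, 0]])  # half the rays are blocked
print(accumulate(samples, valid), accumulate(samples, valid, vignetting=True))
```

Because the sampled image here is linear in (row, col), bilinear interpolation reproduces it exactly, which makes the sketch easy to check by hand.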

unwarp

unwarp(img, depth=DEPTH, num_grid=128, crop=True, flip=True)

Unwarp rendered images using distortion map.

Parameters:

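  • img (Tensor, required): Rendered image tensor. Shape of [N, C, H, W].
  • depth (float, default DEPTH): Depth of the object.
  • num_grid (int, default 128): Number of grid points per side of the distortion map.
  • crop (bool, default True): Whether to crop the image. Not referenced by the current implementation.
  • flip (bool, default True): Not referenced by the current implementation (the flip step is commented out).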

Returns:

  • img_unwarpped (Tensor): Unwarped image tensor. Shape of [N, C, H, W].

Source code in src/geolens.py
def unwarp(self, img, depth=DEPTH, num_grid=128, crop=True, flip=True):
    """Unwarp rendered images using distortion map.

    Args:
        img (tensor): Rendered image tensor. Shape of [N, C, H, W].
        depth (float, optional): Depth of the object. Defaults to DEPTH.
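        num_grid (int, optional): Number of grid points per side of the distortion map. Defaults to 128.
        crop (bool, optional): Whether to crop the image. Defaults to True.
        flip (bool, optional): Currently unused (the flip step is commented out). Defaults to True.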

    Returns:
        img_unwarpped (tensor): Unwarped image tensor. Shape of [N, C, H, W].
    """
    # Calculate distortion map, shape (num_grid, num_grid, 2)
    distortion_map = self.calc_distortion_map(depth=depth, num_grid=num_grid)

    # Interpolate distortion map to image resolution
    distortion_map = distortion_map.permute(2, 0, 1).unsqueeze(1)
    # distortion_map = torch.flip(distortion_map, [-2]) if flip else distortion_map
    distortion_map = F.interpolate(
        distortion_map, img.shape[-2:], mode="bilinear", align_corners=True
    )  # shape (2, 1, Himg, Wimg)
    distortion_map = distortion_map.permute(1, 2, 3, 0).repeat(
        img.shape[0], 1, 1, 1
    )  # shape (B, Himg, Wimg, 2)

    # Unwarp using grid_sample function
    img_unwarpped = F.grid_sample(
        img, distortion_map, align_corners=True
    )  # shape (B, C, Himg, Wimg)
    return img_unwarpped
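unwarp relies on torch.nn.functional.grid_sample, whose grid holds normalized (x, y) coordinates in [-1, 1], with x running along the width. With align_corners=True, as used above, -1 and 1 land on the centers of the corner pixels. The NumPy sketch below reproduces that convention and builds an identity grid; a distortion map passed to grid_sample can be read as such a grid with each entry displaced to where the corrected pixel is fetched from. The function names are hypothetical and the code does not call PyTorch.

```python
import numpy as np


def normalized_to_pixel(coord, size, align_corners=True):
    """Map a grid_sample coordinate in [-1, 1] to a pixel coordinate."""
    if align_corners:
        # -1 and 1 are the centers of the first and last pixels
        return (coord + 1) / 2 * (size - 1)
    # -1 and 1 are the outer edges of the first and last pixels
    return ((coord + 1) * size - 1) / 2


def identity_grid(h, w):
    """Sampling grid of shape [h, w, 2] that leaves an image unchanged.

    The last axis is (x, y): x indexes width and y indexes height,
    which is the ordering grid_sample expects.
    """
    ys = np.linspace(-1.0, 1.0, h)
    xs = np.linspace(-1.0, 1.0, w)
    gx, gy = np.meshgrid(xs, ys)  # both of shape [h, w]
    return np.stack([gx, gy], axis=-1)


grid = identity_grid(3, 5)
print(grid[0, 0], grid[0, -1], grid[-1, -1])
print(normalized_to_pixel(grid[0, -1, 0], 5))
```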

find_aperture

find_aperture()

Find and set the aperture stop index.

Called after loading when no surface was marked with is_aperture in the lens file. Looks for an Aperture surface instance first, then falls back to the surface with the smallest semi-diameter.

Sets

self.aper_idx (int): Index of the aperture surface.

Source code in src/geolens.py
def find_aperture(self):
    """Find and set the aperture stop index.

    Called after loading when no surface was marked with ``is_aperture``
    in the lens file. Looks for an ``Aperture`` surface instance first,
    then falls back to the surface with the smallest semi-diameter.

    Sets:
        self.aper_idx (int): Index of the aperture surface.
    """
    for i, s in enumerate(self.surfaces):
        if isinstance(s, Aperture):
            self.aper_idx = i
            return

    self.aper_idx = int(np.argmin([s.r for s in self.surfaces]))
    print("No aperture found, use the smallest surface as aperture.")

find_diff_surf

find_diff_surf()

Get differentiable/optimizable surface indices.

Returns a list of surface indices that can be optimized during lens design. Excludes the aperture surface from optimization.

Returns:

  • list or range: Surface indices excluding the aperture (a range over all surfaces if aper_idx is None).

Source code in src/geolens.py
def find_diff_surf(self):
    """Get differentiable/optimizable surface indices.

    Returns a list of surface indices that can be optimized during lens design.
    Excludes the aperture surface from optimization.

    Returns:
        list or range: Surface indices excluding the aperture.
    """
    if self.aper_idx is None:
        diff_surf_range = range(len(self.surfaces))
    else:
        diff_surf_range = list(range(0, self.aper_idx)) + list(
            range(self.aper_idx + 1, len(self.surfaces))
        )
    return diff_surf_range
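find_aperture and find_diff_surf together decide which surfaces are optimized. The sketch below mirrors their selection logic on stand-in objects: Surface and Aperture here are hypothetical dataclasses carrying only a semi-diameter r, not the library's surface classes, and the helper names are illustrative.

```python
from dataclasses import dataclass


@dataclass
class Surface:
    r: float  # semi-diameter in mm


class Aperture(Surface):
    """Marker subclass standing in for an aperture-stop surface."""


def find_aperture_index(surfaces):
    # Prefer an explicit aperture surface ...
    for i, s in enumerate(surfaces):
        if isinstance(s, Aperture):
            return i
    # ... otherwise fall back to the smallest semi-diameter (first on ties)
    return min(range(len(surfaces)), key=lambda i: surfaces[i].r)


def optimizable_indices(num_surfaces, aper_idx):
    if aper_idx is None:
        return list(range(num_surfaces))
    return [i for i in range(num_surfaces) if i != aper_idx]


lens = [Surface(6.0), Surface(5.5), Surface(2.0), Surface(4.0), Surface(4.5)]
aper = find_aperture_index(lens)
print(aper, optimizable_indices(len(lens), aper))
```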

calc_foclen

calc_foclen(test_fov_deg=1.0)

Compute effective focal length (EFL).

Two-step approach:

  1. Trace on-axis parallel rays to find the paraxial focal point z. This is necessary because the sensor may not be at the focal plane (e.g. finite-conjugate designs or defocused systems).
  2. Trace off-axis rays at a small angle to the focal point, measure the image height, and compute EFL = imgh / tan(angle).

The default 1-degree field avoids the numerical noise of truly paraxial angles (0.01 rad) while remaining small enough to approximate the paraxial regime.

Parameters:

  • test_fov_deg (float, default 1.0): Chief-ray field angle, in degrees, used for the focal-length estimate.

Returns:

  • float: Effective focal length, or NaN if no valid rays reach the focus.

Updates

  • self.efl: Effective focal length.
  • self.foclen: Alias for the effective focal length.
  • self.bfl: Back focal length (distance from the last surface to the sensor).

Source code in src/geolens.py
@torch.no_grad()
def calc_foclen(self, test_fov_deg=1.0):
    """Compute effective focal length (EFL).

    Two-step approach:
    1. Trace on-axis parallel rays to find the paraxial focal point z.
       This is necessary because the sensor may not be at the focal plane
       (e.g. finite-conjugate designs or defocused systems).
    2. Trace off-axis rays at a small angle to the focal point, measure
       image height, and compute EFL = imgh / tan(angle).

    The default 1-degree field avoids the numerical noise of truly paraxial
    angles (0.01 rad) while remaining small enough to approximate the
    paraxial regime.

    Args:
        test_fov_deg (float, optional): Chief-ray field angle used for the
            focal-length estimate.  Defaults to 1.0 degree.

    Updates:
        self.efl: Effective focal length.
        self.foclen: Alias for effective focal length.
        self.bfl: Back focal length (distance from last surface to sensor).
    """
    # Step 1: Trace on-axis parallel rays to find paraxial focus z
    ray_axis = self.sample_from_fov(
        fov_x=0.0, fov_y=0.0, entrance_pupil=False, scale_pupil=0.2
    )
    ray_axis, _ = self.trace(ray_axis)
    valid_axis = ray_axis.is_valid > 0
    if valid_axis.sum() <= 0:
        self.efl = float('nan')
        self.foclen = float('nan')
        self.bfl = self.d_sensor.item() - self.surfaces[-1].d.item()
        return float('nan')

    # Find where rays cross the optical axis (minimize transverse distance)
    t = -(ray_axis.d[valid_axis, 0] * ray_axis.o[valid_axis, 0]
          + ray_axis.d[valid_axis, 1] * ray_axis.o[valid_axis, 1]) / (
        ray_axis.d[valid_axis, 0] ** 2 + ray_axis.d[valid_axis, 1] ** 2
    )
    focus_z = ray_axis.o[valid_axis, 2] + t * ray_axis.d[valid_axis, 2]
    focus_z = focus_z[~torch.isnan(focus_z) & (focus_z > 0)]
    if len(focus_z) == 0:
        self.efl = float('nan')
        self.foclen = float('nan')
        self.bfl = self.d_sensor.item() - self.surfaces[-1].d.item()
        return float('nan')
    paraxial_focus_z = float(torch.mean(focus_z))

    # Step 2: Trace off-axis ray to focal point, measure image height
    test_fov_rad = float(np.deg2rad(test_fov_deg))
    ray = self.sample_from_fov(
        fov_x=0.0, fov_y=test_fov_deg, entrance_pupil=False, scale_pupil=0.2
    )
    ray, _ = self.trace(ray)
    ray = ray.prop_to(paraxial_focus_z)

    valid_sum = ray.is_valid.sum()
    if valid_sum <= 0:
        eff_foclen = float('nan')
    else:
        imgh = (ray.o[:, 1] * ray.is_valid).sum() / valid_sum
        eff_foclen = imgh.item() / float(np.tan(test_fov_rad))

    self.efl = eff_foclen
    self.foclen = eff_foclen
    self.bfl = self.d_sensor.item() - self.surfaces[-1].d.item()
    return eff_foclen
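Both steps reduce to small geometric formulas. In step 1, the ray parameter that minimizes a ray's transverse distance from the axis is t = -(d_x o_x + d_y o_y) / (d_x^2 + d_y^2), the same expression used in the source; step 2 divides the measured image height by tan(angle). The NumPy sketch below applies both to an idealized converging bundle chosen for illustration rather than to a traced lens; the helper names are hypothetical.

```python
import numpy as np


def axis_crossing_z(o, d):
    """z at which each ray passes closest to the optical axis (x = y = 0).

    o, d: arrays of shape [num_rays, 3] with ray origins and directions.
    Rays parallel to the axis (d_x = d_y = 0) yield NaN.
    """
    t = -(d[:, 0] * o[:, 0] + d[:, 1] * o[:, 1]) / (d[:, 0] ** 2 + d[:, 1] ** 2)
    return o[:, 2] + t * d[:, 2]


def efl_from_image_height(img_height, field_deg):
    """Step 2: EFL = image height / tan(field angle)."""
    return img_height / np.tan(np.deg2rad(field_deg))


# Step 1 on synthetic rays that converge to z = 50 mm from a plane at z = 0
heights = np.array([[0.5, 0.0], [0.0, 1.0], [1.0, 1.0]])
o = np.column_stack([heights, np.zeros(len(heights))])
d = np.column_stack([-heights, np.full(len(heights), 50.0)])
d /= np.linalg.norm(d, axis=1, keepdims=True)
print(axis_crossing_z(o, d))

# Step 2: a 1 degree field imaged 50 * tan(1 deg) mm off axis
print(efl_from_image_height(50.0 * np.tan(np.deg2rad(1.0)), 1.0))
```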

calc_numerical_aperture

calc_numerical_aperture(n=1.0)

Compute numerical aperture (NA) as NA = n * sin(arctan(1 / (2 * fnum))).

Parameters:

  • n (float, default 1.0): Refractive index of the medium.

Returns:

  • NA (float): Numerical aperture.

Reference

[1] https://en.wikipedia.org/wiki/Numerical_aperture

Source code in src/geolens.py
@torch.no_grad()
def calc_numerical_aperture(self, n=1.0):
    """Compute numerical aperture (NA).

    Args:
        n (float, optional): Refractive index. Defaults to 1.0.

    Returns:
        NA (float): Numerical aperture.

    Reference:
        [1] https://en.wikipedia.org/wiki/Numerical_aperture
    """
    return n * math.sin(math.atan(1 / 2 / self.fnum))
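calc_numerical_aperture converts the F-number into the half-angle of the marginal ray cone, theta = atan(1 / (2 * fnum)), and returns n * sin(theta); the F-number itself comes from calc_pupil as foclen / (2 * entr_pupilr). The worked example below chains both relations; the 50 mm focal length and 12.5 mm entrance-pupil radius are made-up values, and the helper names are illustrative.

```python
import math


def f_number(foclen, entr_pupil_radius):
    # Same relation as calc_pupil: fnum = foclen / (2 * entr_pupilr)
    return foclen / (2 * entr_pupil_radius)


def numerical_aperture(fnum, n=1.0):
    # Same relation as calc_numerical_aperture: marginal-ray half-angle
    # theta = atan(1 / (2 * fnum)), NA = n * sin(theta)
    return n * math.sin(math.atan(1 / 2 / fnum))


fnum = f_number(50.0, 12.5)  # a 50 mm lens with a 25 mm entrance pupil
na = numerical_aperture(fnum)
print(fnum, round(na, 4), 1 / (2 * fnum))
```

For this f/2 lens the NA is about 0.2425, slightly below the small-angle estimate 1 / (2 * fnum) = 0.25.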

calc_focal_plane

calc_focal_plane(wvln=DEFAULT_WAVE)

Compute the focus distance in object space. Rays start from the sensor center and are traced to object space.

Parameters:

  • wvln (float, default DEFAULT_WAVE): Wavelength.

Returns:

  • focal_plane (float): z position of the in-focus object plane (negative, i.e. in object space). Raises ValueError if no valid rays are found.

Source code in src/geolens.py
@torch.no_grad()
def calc_focal_plane(self, wvln=DEFAULT_WAVE):
    """Compute the focus distance in object space. Rays start from the sensor center and are traced to object space.

    Args:
        wvln (float, optional): Wavelength. Defaults to DEFAULT_WAVE.

    Returns:
        focal_plane (float): Focal plane in the object space.
    """
    device = self.device

    # Sample point source rays from sensor center
    o1 = torch.zeros(SPP_CALC, 3, device=device)
    o1[:, 2] = self.d_sensor

    # Sample the first surface as pupil
    # o2 = self.sample_circle(self.surfaces[0].r, z=0.0, shape=[SPP_CALC])
    # o2 *= 0.5  # Shrink sample region to improve accuracy
    pupilz, pupilr = self.get_exit_pupil()
    o2 = self.sample_circle(pupilr, pupilz, shape=[SPP_CALC])
    d = o2 - o1
    ray = Ray(o1, d, wvln, device=device)

    # Trace rays to object space
    ray = self.trace2obj(ray)

    # Optical axis intersection
    t = (ray.d[..., 0] * ray.o[..., 0] + ray.d[..., 1] * ray.o[..., 1]) / (
        ray.d[..., 0] ** 2 + ray.d[..., 1] ** 2
    )
    focus_z = (ray.o[..., 2] - ray.d[..., 2] * t)[ray.is_valid > 0].cpu().numpy()
    focus_z = focus_z[~np.isnan(focus_z) & (focus_z < 0)]

    if len(focus_z) > 0:
        focal_plane = float(np.mean(focus_z))
    else:
        raise ValueError(
            "No valid rays found, focal plane in the object space cannot be computed."
        )

    return focal_plane
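calc_foclen, calc_focal_plane, and calc_sensor_plane all locate focus the same way: for each ray they find the point of closest approach to the optical axis. The standard least-squares step is shown below. calc_focal_plane and calc_sensor_plane write t with the opposite sign and subtract t * d_z, which gives the same z_focus.

```latex
\min_{t}\; \lVert \mathbf{o}_{xy} + t\,\mathbf{d}_{xy} \rVert^2
\;\Rightarrow\;
2\,\mathbf{d}_{xy}\cdot\left(\mathbf{o}_{xy} + t\,\mathbf{d}_{xy}\right) = 0
\;\Rightarrow\;
t^\ast = -\frac{d_x o_x + d_y o_y}{d_x^2 + d_y^2},
\qquad
z_{\text{focus}} = o_z + t^\ast d_z .
```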

calc_sensor_plane

calc_sensor_plane(depth=float('inf'))

Calculate in-focus sensor plane.

Parameters:

  • depth (float, default float('inf')): Depth of the object plane.

Returns:

  • d_sensor (Tensor): z position of the in-focus sensor plane in image space.

Source code in src/geolens.py
@torch.no_grad()
def calc_sensor_plane(self, depth=float("inf")):
    """Calculate in-focus sensor plane.

    Args:
        depth (float, optional): Depth of the object plane. Defaults to float("inf").

    Returns:
        d_sensor (torch.Tensor): Sensor plane in the image space.
    """
    # Sample and trace rays, shape [SPP_CALC, 3]
    ray = self.sample_from_fov(
        fov_x=0.0, fov_y=0.0, depth=depth, num_rays=SPP_CALC, wvln=DEFAULT_WAVE
    )
    ray = self.trace2sensor(ray)

    # Calculate in-focus sensor position
    t = (ray.d[:, 0] * ray.o[:, 0] + ray.d[:, 1] * ray.o[:, 1]) / (
        ray.d[:, 0] ** 2 + ray.d[:, 1] ** 2
    )
    focus_z = ray.o[:, 2] - ray.d[:, 2] * t
    focus_z = focus_z[ray.is_valid > 0]
    focus_z = focus_z[~torch.isnan(focus_z) & (focus_z > 0)]
    d_sensor = torch.mean(focus_z)
    return d_sensor

calc_fov

calc_fov()

Compute field of view (FoV) of the lens in radians.

Calculates FoV using two methods:

  1. Perspective projection: from the focal length and sensor size (effective FoV, ignoring distortion).
  2. Ray tracing: sweeps field angles on the object side, traces rays forward to the sensor, and keeps the angle whose image height is closest to the sensor half-diagonal (real FoV, including distortion).

Does nothing if self.foclen has not been computed yet.

Updates

  • self.vfov (float): Vertical FoV in radians.
  • self.hfov (float): Horizontal FoV in radians.
  • self.dfov (float): Diagonal FoV in radians.
  • self.rfov_eff (float): Half-diagonal (radius) FoV in radians.
  • self.rfov (float): Real half-diagonal FoV from ray tracing.
  • self.real_dfov (float): Real diagonal FoV from ray tracing.
  • self.eqfl (float): 35mm equivalent focal length in mm.

Reference

[1] https://en.wikipedia.org/wiki/Angle_of_view_(photography)

Source code in src/geolens.py
@torch.no_grad()
def calc_fov(self):
    """Compute field of view (FoV) of the lens in radians.

    Calculates FoV using two methods:
        1. **Perspective projection**: from focal length and sensor size
           (effective FoV, ignoring distortion).
        2. **Ray tracing**: sweeps field angles on the object side, traces
           rays forward to the sensor, and keeps the angle whose image
           height is closest to r_sensor (real FoV including distortion).

    Updates:
        self.vfov (float): Vertical FoV in radians.
        self.hfov (float): Horizontal FoV in radians.
        self.dfov (float): Diagonal FoV in radians.
        self.rfov_eff (float): Half-diagonal (radius) FoV in radians.
        self.rfov (float): Real half-diagonal FoV from ray tracing.
        self.real_dfov (float): Real diagonal FoV from ray tracing.
        self.eqfl (float): 35mm equivalent focal length in mm.

    Reference:
        [1] https://en.wikipedia.org/wiki/Angle_of_view_(photography)
    """
    if not hasattr(self, "foclen"):
        return

    # 1. Perspective projection (effective FoV)
    self.vfov = 2 * math.atan(self.sensor_size[0] / 2 / self.foclen)
    self.hfov = 2 * math.atan(self.sensor_size[1] / 2 / self.foclen)
    self.dfov = 2 * math.atan(self.r_sensor / self.foclen)
    self.rfov_eff = self.dfov / 2  # radius (half diagonal) FoV

    # 2. Forward ray tracing to calculate real FoV (distortion-affected)
    # Sweep FOV angles from object side, trace to sensor, and find which
    # angle produces an image height matching r_sensor.
    num_fov = 64
    fov_lo = float(np.rad2deg(self.rfov_eff)) * 0.5
    fov_hi = min(float(np.rad2deg(self.rfov_eff)) * 1.8, 89.0)
    fov_samples = torch.linspace(fov_lo, fov_hi, num_fov, device=self.device)

    ray = self.sample_from_fov(
        fov_x=0.0, fov_y=fov_samples.tolist(), num_rays=256
    )
    ray = self.trace2sensor(ray)

    # Centroid image height per FOV angle, shape [num_fov]
    valid = ray.is_valid > 0  # [num_fov, num_rays]
    masked_y = ray.o[..., 1] * valid
    n_valid = valid.sum(dim=-1).clamp(min=1)
    imgh = (masked_y.sum(dim=-1) / n_valid).abs()

    # Find the FOV angle whose image height is closest to r_sensor
    has_valid = valid.sum(dim=-1) > 10
    if has_valid.any():
        imgh[~has_valid] = float("inf")
        diff = (imgh - self.r_sensor).abs()
        best_idx = diff.argmin().item()
        rfov = fov_samples[best_idx].item() * math.pi / 180.0
        self.rfov = rfov
        self.real_dfov = 2 * rfov
    else:
        self.rfov = self.rfov_eff
        self.real_dfov = self.dfov

    # 3. Compute 35mm equivalent focal length. A 36 mm x 24 mm full-frame
    #    sensor has a half-diagonal of about 21.63 mm.
    self.eqfl = 21.63 / math.tan(self.rfov_eff)
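The perspective formulas, the 35mm-equivalent conversion, and the field-angle sweep can be illustrated without a lens model. In this NumPy sketch the function names are hypothetical and the traced image height is replaced by an assumed analytic mapping (an ideal f-theta lens), so only the search logic is shown, not GeoLens ray tracing. The constant 21.63 mm is the half-diagonal of a 36 mm x 24 mm frame.

```python
import numpy as np


def perspective_half_fov(r_sensor, foclen):
    # Effective (distortion-free) half-diagonal FoV in radians
    return np.arctan(r_sensor / foclen)


def equivalent_focal_length(half_fov_rad):
    half_diag_35mm = np.hypot(36.0, 24.0) / 2  # about 21.63 mm
    return half_diag_35mm / np.tan(half_fov_rad)


def real_half_fov_deg(image_height, r_sensor, rfov_eff_deg, num_fov=64):
    """Sweep field angles and return the one imaged closest to r_sensor.

    image_height maps field angles in degrees to image heights in mm; in
    GeoLens this comes from ray tracing, here it is any callable.
    """
    hi = min(1.8 * rfov_eff_deg, 89.0)
    fov = np.linspace(0.5 * rfov_eff_deg, hi, num_fov)
    imgh = np.abs(image_height(fov))
    return fov[np.argmin(np.abs(imgh - r_sensor))]


# An idealized f-theta lens (image height = f * theta) has barrel distortion
# relative to the pinhole model, so its real FoV exceeds the effective one
def f_theta(deg):
    return foclen * np.deg2rad(deg)


foclen, r_sensor = 4.0, 3.0  # mm
rfov_eff = perspective_half_fov(r_sensor, foclen)
print(np.rad2deg(rfov_eff), equivalent_focal_length(rfov_eff))
print(real_half_fov_deg(f_theta, r_sensor, np.rad2deg(rfov_eff)))
```

The sweep resolution is set by num_fov, so the returned real FoV is only accurate to about half a sweep step.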

calc_scale

calc_scale(depth)

Calculate the scale factor (object height / image height).

Uses the pinhole camera model to compute magnification.

Parameters:

  • depth (float, required): Object distance from the lens (negative z direction).

Returns:

  • float: Scale factor relating object height to image height, computed as -depth / foclen.

Source code in src/geolens.py
@torch.no_grad()
def calc_scale(self, depth):
    """Calculate the scale factor (object height / image height).

    Uses the pinhole camera model to compute magnification.

    Args:
        depth (float): Object distance from the lens (negative z direction).

    Returns:
        float: Scale factor relating object height to image height.
    """
    return -depth / self.foclen
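As a worked example with made-up values (depth in mm, negative in front of the lens), the snippet below evaluates the same expression as calc_scale and shows how render_compute_image uses the result: pixel_size = scale * self.pixel_size gives the footprint of one sensor pixel on the object plane. The standalone function is illustrative and takes foclen as an argument instead of reading self.foclen.

```python
def calc_scale(depth, foclen):
    # Pinhole model: object height / image height = -depth / foclen
    return -depth / foclen


scale = calc_scale(depth=-1000.0, foclen=50.0)  # object 1 m in front of a 50 mm lens
sensor_pixel = 0.002  # 2 um sensor pixel, in mm (assumed value)
object_pixel = scale * sensor_pixel  # footprint used by render_compute_image
print(scale, object_pixel)
```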

calc_pupil

calc_pupil()

Compute entrance and exit pupil positions and radii.

The entrance and exit pupils must be recalculated whenever
  • First-order parameters change (e.g., field of view, object height, image height),
  • Lens geometry or materials change (e.g., surface curvatures, refractive indices, thicknesses),
  • Or generally, any time the lens configuration is modified.
Updates

  • self.aper_idx: Index of the aperture surface.
  • self.exit_pupilz, self.exit_pupilr: Exit pupil position and radius.
  • self.entr_pupilz, self.entr_pupilr: Entrance pupil position and radius.
  • self.exit_pupilz_parax, self.exit_pupilr_parax: Paraxial exit pupil.
  • self.entr_pupilz_parax, self.entr_pupilr_parax: Paraxial entrance pupil.
  • self.fnum: F-number calculated from focal length and entrance pupil.

Source code in src/geolens.py
@torch.no_grad()
def calc_pupil(self):
    """Compute entrance and exit pupil positions and radii.

    The entrance and exit pupils must be recalculated whenever:
        - First-order parameters change (e.g., field of view, object height, image height),
        - Lens geometry or materials change (e.g., surface curvatures, refractive indices, thicknesses),
        - Or generally, any time the lens configuration is modified.

    Updates:
        self.aper_idx: Index of the aperture surface.
        self.exit_pupilz, self.exit_pupilr: Exit pupil position and radius.
        self.entr_pupilz, self.entr_pupilr: Entrance pupil position and radius.
        self.exit_pupilz_parax, self.exit_pupilr_parax: Paraxial exit pupil.
        self.entr_pupilz_parax, self.entr_pupilr_parax: Paraxial entrance pupil.
        self.fnum: F-number calculated from focal length and entrance pupil.
    """
    # Compute entrance and exit pupil
    self.exit_pupilz, self.exit_pupilr = self.calc_exit_pupil(paraxial=False)
    self.entr_pupilz, self.entr_pupilr = self.calc_entrance_pupil(paraxial=False)
    self.exit_pupilz_parax, self.exit_pupilr_parax = self.calc_exit_pupil(
        paraxial=True
    )
    self.entr_pupilz_parax, self.entr_pupilr_parax = self.calc_entrance_pupil(
        paraxial=True
    )

    # Compute F-number
    self.fnum = self.foclen / (2 * self.entr_pupilr)

get_entrance_pupil

get_entrance_pupil(paraxial=False)

Get entrance pupil location and radius.

Parameters:

  • paraxial (bool, default False): If True, return paraxial approximation values. If False, return real ray-traced values.

Returns:

  • tuple: (z_position, radius) of the entrance pupil in [mm].

Source code in src/geolens.py
def get_entrance_pupil(self, paraxial=False):
    """Get entrance pupil location and radius.

    Args:
        paraxial (bool, optional): If True, return paraxial approximation values.
            If False, return real ray-traced values. Defaults to False.

    Returns:
        tuple: (z_position, radius) of the entrance pupil in [mm].
    """
    if paraxial:
        return self.entr_pupilz_parax, self.entr_pupilr_parax
    else:
        return self.entr_pupilz, self.entr_pupilr

get_exit_pupil

get_exit_pupil(paraxial=False)

Get exit pupil location and radius.

Parameters:

  • paraxial (bool, default False): If True, return paraxial approximation values. If False, return real ray-traced values.

Returns:

  • tuple: (z_position, radius) of the exit pupil in [mm].

Source code in src/geolens.py
def get_exit_pupil(self, paraxial=False):
    """Get exit pupil location and radius.

    Args:
        paraxial (bool, optional): If True, return paraxial approximation values.
            If False, return real ray-traced values. Defaults to False.

    Returns:
        tuple: (z_position, radius) of the exit pupil in [mm].
    """
    if paraxial:
        return self.exit_pupilz_parax, self.exit_pupilr_parax
    else:
        return self.exit_pupilz, self.exit_pupilr

calc_exit_pupil

calc_exit_pupil(paraxial=False)

Calculate exit pupil location and radius.

Paraxial mode

Rays are emitted from near the center of the aperture stop and are close to the optical axis. This mode estimates the exit pupil position and radius under ideal (first-order) optical assumptions. It is fast and stable.

Non-paraxial mode

Many rays are emitted from the edge of the aperture stop. The exit pupil position and radius are determined from the intersection points of these rays. This mode is slower and is affected by aperture-related aberrations.

Use paraxial mode unless precise ray aiming is required.

Parameters:

  • paraxial (bool, default False): If True, use paraxial mode (near-axis rays from the aperture center); if False, use non-paraxial mode (rays from the aperture edge).

Returns:

  • avg_pupilz (float): z coordinate of the exit pupil.
  • avg_pupilr (float): Radius of the exit pupil.

Reference

[1] Exit pupil: the image of the aperture stop formed by the elements behind it, which bounds the rays that can pass from the sensor to object space.
[2] https://en.wikipedia.org/wiki/Exit_pupil
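In non-paraxial mode the exit pupil is estimated from where the traced rays cross one another in the x-z plane. The source's compute_intersection_points_2d is not shown on this page, so the NumPy sketch below is an assumed, brute-force stand-in: it intersects every pair of 2-D lines and averages the results, with column 0 as x (pupil radius) and column 1 as z (pupil position), matching how calc_exit_pupil reads intersection_points. The function name and the synthetic rays are hypothetical.

```python
import numpy as np


def pairwise_intersections_2d(o, d, eps=1e-12):
    """Intersect every pair of 2-D lines o_i + t_i * d_i.

    o, d: arrays of shape [num_rays, 2] holding (x, z) origins and directions.
    Returns an array of shape [num_pairs, 2]; near-parallel pairs are skipped.
    """
    points = []
    for i in range(len(o)):
        for j in range(i + 1, len(o)):
            # Solve t_i * d_i - t_j * d_j = o_j - o_i for (t_i, t_j)
            a = np.column_stack([d[i], -d[j]])
            if abs(np.linalg.det(a)) < eps:
                continue
            t_i, _ = np.linalg.solve(a, o[j] - o[i])
            points.append(o[i] + t_i * d[i])
    return np.array(points).reshape(-1, 2)


# Rays leaving z = 10 that all pass through the point (x, z) = (2, 30)
x0 = np.array([0.0, 1.0, 3.0, 5.0])
o = np.column_stack([x0, np.full_like(x0, 10.0)])
d = np.column_stack([2.0 - x0, np.full_like(x0, 20.0)])
pts = pairwise_intersections_2d(o, d)
print(pts.mean(axis=0))  # average (x, z) of all pairwise intersections
```

The pairwise loop is O(n^2) in the number of rays, which is fine for a sketch; a production version would likely vectorize it or use a least-squares formulation.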

Source code in src/geolens.py
@torch.no_grad()
def calc_exit_pupil(self, paraxial=False):
    """Calculate exit pupil location and radius.

    Paraxial mode:
        Rays are emitted from near the center of the aperture stop and are close to the optical axis.
        This mode estimates the exit pupil position and radius under ideal (first-order) optical assumptions.
        It is fast and stable.

    Non-paraxial mode:
        Many rays are emitted from the edge of the aperture stop.
        The exit pupil position and radius are determined based on the intersection points of these rays.
        This mode is slower and affected by aperture-related aberrations.

    Use paraxial mode unless precise ray aiming is required.

    Args:
        paraxial (bool): If True, trace near-axis rays from the aperture
            center (paraxial mode); if False, trace rays from the aperture edge.

    Returns:
        avg_pupilz (float): z coordinate of exit pupil.
        avg_pupilr (float): radius of exit pupil.

    Reference:
        [1] Exit pupil: the image of the aperture stop, which bounds the rays
            that can pass from the sensor to object space.
        [2] https://en.wikipedia.org/wiki/Exit_pupil
    """
    if not hasattr(self, "aper_idx") or self.aper_idx is None:
        print("No aperture, use the last surface as exit pupil.")
        return self.surfaces[-1].d.item(), self.surfaces[-1].r

    # Sample rays from aperture (edge or center)
    aper_idx = self.aper_idx
    aper_z = self.surfaces[aper_idx].d.item()
    aper_r = self.surfaces[aper_idx].r

    if paraxial:
        ray_o = torch.tensor([[DELTA_PARAXIAL, 0, aper_z]], device=self.device).repeat(32, 1)
        phi_rad = torch.linspace(-0.01, 0.01, 32, device=self.device)
    else:
        ray_o = torch.tensor([[aper_r, 0, aper_z]], device=self.device).repeat(SPP_CALC, 1)
        rfov_eff = float(np.arctan(self.r_sensor / self.foclen))
        phi_rad = torch.linspace(-rfov_eff / 2, rfov_eff / 2, SPP_CALC, device=self.device)

    d = torch.stack(
        (torch.sin(phi_rad), torch.zeros_like(phi_rad), torch.cos(phi_rad)), axis=-1
    )
    ray = Ray(ray_o, d, device=self.device)

    # Trace rays from the aperture stop to the last surface
    surf_range = range(self.aper_idx + 1, len(self.surfaces))
    ray, _ = self.trace(ray, surf_range=surf_range)

    # Compute intersection points, solving the equation: o1+d1*t1 = o2+d2*t2
    ray_o = torch.stack(
        [ray.o[ray.is_valid != 0][:, 0], ray.o[ray.is_valid != 0][:, 2]], dim=-1
    )
    ray_d = torch.stack(
        [ray.d[ray.is_valid != 0][:, 0], ray.d[ray.is_valid != 0][:, 2]], dim=-1
    )
    intersection_points = self.compute_intersection_points_2d(ray_o, ray_d)

    # Handle the case where no intersection points are found or small pupil
    if len(intersection_points) == 0:
        print("No intersection points found, use the last surface as exit pupil.")
        avg_pupilr = self.surfaces[-1].r
        avg_pupilz = self.surfaces[-1].d.item()
    else:
        avg_pupilr = torch.mean(intersection_points[:, 0]).item()
        avg_pupilz = torch.mean(intersection_points[:, 1]).item()

        if paraxial:
            avg_pupilr = abs(avg_pupilr / DELTA_PARAXIAL * aper_r)

        if avg_pupilr < EPSILON:
            print(
                "Zero or negative exit pupil is detected, use the last surface as pupil."
            )
            avg_pupilr = self.surfaces[-1].r
            avg_pupilz = self.surfaces[-1].d.item()

    return avg_pupilz, avg_pupilr
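To illustrate the paraxial branch above, here is a minimal NumPy sketch on a toy system, assuming a single thin lens behind the stop and paraxial refraction (u' = u - x/f). It launches two rays from a tiny height (a hypothetical stand-in for DELTA_PARAXIAL), intersects the refracted rays, and rescales the intersection height by aper_r / delta, mirroring the `abs(avg_pupilr / DELTA_PARAXIAL * aper_r)` step. The helper names and numbers are illustrative only; the result is cross-checked against the Gaussian image of the stop.

```python
import numpy as np


def trace_thin_lens(x0, slopes, z0, lens_z, f):
    """Paraxially trace rays starting at height x0 (plane z0) through a thin lens at lens_z.

    Returns 2-D (x, z) points on the lens plane and (dx, dz) directions after refraction.
    """
    x_lens = x0 + slopes * (lens_z - z0)          # ray heights on the lens
    slopes_out = slopes - x_lens / f               # paraxial thin-lens refraction: u' = u - x / f
    origins = np.stack([x_lens, np.full_like(x_lens, lens_z)], axis=-1)
    directions = np.stack([slopes_out, np.ones_like(slopes_out)], axis=-1)
    return origins, directions


def intersect_lines(o1, d1, o2, d2):
    """Intersection of the 2-D lines o1 + s * d1 and o2 + t * d2."""
    s, _ = np.linalg.solve(np.column_stack([d1, -d2]), o2 - o1)
    return o1 + s * d1


# Toy system (lengths in mm): stop at z = 0, thin rear lens with f = 50 mm at z = 20 mm.
aper_z, aper_r, lens_z, f = 0.0, 5.0, 20.0, 50.0
delta = 1e-3                                       # small launch height (hypothetical DELTA_PARAXIAL)
o, d = trace_thin_lens(delta, np.array([-0.01, 0.01]), aper_z, lens_z, f)
x_img, pupil_z = intersect_lines(o[0], d[0], o[1], d[1])
pupil_r = abs(x_img / delta * aper_r)              # paraxial heights scale linearly with launch height

# Gaussian cross-check: the exit pupil is the image of the stop formed by the rear lens.
s_obj = lens_z - aper_z                            # stop-to-lens distance
s_img = f * s_obj / (s_obj - f)                    # from 1/s + 1/s' = 1/f; negative means a virtual image
print(f"traced:   z = {pupil_z:.4f} mm, r = {pupil_r:.4f} mm")
print(f"Gaussian: z = {lens_z + s_img:.4f} mm, r = {abs(s_img / s_obj) * aper_r:.4f} mm")
```

For these numbers both lines agree (a virtual exit pupil at about z = -13.33 mm with radius about 8.33 mm). Because paraxial ray heights are linear in the launch height, tracing from a tiny height and rescaling avoids aberrations from the real stop edge.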

calc_entrance_pupil

calc_entrance_pupil(paraxial=False)

Calculate entrance pupil of the lens.

The entrance pupil is the optical image of the physical aperture stop, as seen through the optical elements in front of the stop. Rays are sampled at the aperture stop, traced backward to the first surface, and their straight-line extensions are intersected pairwise. The average of the intersection points defines the entrance pupil position and radius.

Parameters:

Name Type Description Default
paraxial bool

Ray sampling mode. If True, rays are emitted near the centre of the aperture stop (fast, paraxially stable). If False, rays are emitted from the stop edge in larger quantities (slower, accounts for aperture aberrations). Defaults to False.

False

Returns:

Name Type Description
tuple

(z_position, radius) of entrance pupil.

Note

[1] Use paraxial mode unless precise ray aiming is required.
[2] This function only works for distant objects. For microscopes, it usually returns a negative entrance pupil.

References

[1] Entrance pupil: how many rays can come from object space to sensor.
[2] https://en.wikipedia.org/wiki/Entrance_pupil: "In an optical system, the entrance pupil is the optical image of the physical aperture stop, as 'seen' through the optical elements in front of the stop."
[3] Zemax LLC, OpticStudio User Manual, Version 19.4, Document No. 2311, 2019.

Source code in src/geolens.py
@torch.no_grad()
def calc_entrance_pupil(self, paraxial=False):
    """Calculate entrance pupil of the lens.

    The entrance pupil is the optical image of the physical aperture stop, as seen through the optical elements in front of the stop. We sample backward rays from the aperture stop and trace them to the first surface, then find the intersection points of the reverse extension of the rays. The average of the intersection points defines the entrance pupil position and radius.

    Args:
        paraxial (bool): Ray sampling mode.  If ``True``, rays are emitted
            near the centre of the aperture stop (fast, paraxially stable).
            If ``False``, rays are emitted from the stop edge in larger
            quantities (slower, accounts for aperture aberrations).
            Defaults to ``False``.

    Returns:
        tuple: (z_position, radius) of entrance pupil.

    Note:
        [1] Use paraxial mode unless precise ray aiming is required.
        [2] This function only works for object at a far distance. For microscopes, this function usually returns a negative entrance pupil.

    References:
        [1] Entrance pupil: how many rays can come from object space to sensor.
        [2] https://en.wikipedia.org/wiki/Entrance_pupil: "In an optical system, the entrance pupil is the optical image of the physical aperture stop, as 'seen' through the optical elements in front of the stop."
        [3] Zemax LLC, *OpticStudio User Manual*, Version 19.4, Document No. 2311, 2019.
    """
    # Check hasattr first so a missing attribute does not raise AttributeError
    if not hasattr(self, "aper_idx") or self.aper_idx is None:
        print("No aperture stop, use the first surface as entrance pupil.")
        return self.surfaces[0].d.item(), self.surfaces[0].r

    # Sample rays from the aperture stop (edge or center)
    aper_idx = self.aper_idx
    aper_surf = self.surfaces[aper_idx]
    aper_z = aper_surf.d.item()
    if aper_surf.is_square:
        aper_r = float(np.sqrt(2)) * aper_surf.r
    else:
        aper_r = aper_surf.r

    if paraxial:
        ray_o = torch.tensor([[DELTA_PARAXIAL, 0, aper_z]]).repeat(32, 1)
        phi = torch.linspace(-0.01, 0.01, 32)
    else:
        ray_o = torch.tensor([[aper_r, 0, aper_z]]).repeat(SPP_CALC, 1)
        rfov_eff = float(np.arctan(self.r_sensor / self.foclen))
        phi = torch.linspace(-rfov_eff / 2, rfov_eff / 2, SPP_CALC)

    d = torch.stack(
        (torch.sin(phi), torch.zeros_like(phi), -torch.cos(phi)), axis=-1
    )
    ray = Ray(ray_o, d, device=self.device)

    # Trace rays backward from the aperture stop to the first surface
    surf_range = range(0, self.aper_idx)
    if len(surf_range) == 0:
        # Aperture is the first surface — entrance pupil is at the aperture
        return aper_z, aper_r
    ray, _ = self.trace(ray, surf_range=surf_range)

    # Compute intersection points, solving the equation: o1+d1*t1 = o2+d2*t2
    ray_o = torch.stack(
        [ray.o[ray.is_valid > 0][:, 0], ray.o[ray.is_valid > 0][:, 2]], dim=-1
    )
    ray_d = torch.stack(
        [ray.d[ray.is_valid > 0][:, 0], ray.d[ray.is_valid > 0][:, 2]], dim=-1
    )
    intersection_points = self.compute_intersection_points_2d(ray_o, ray_d)

    # Handle the case where no intersection points are found or small entrance pupil
    if len(intersection_points) == 0:
        print(
            "No intersection points found, use the first surface as entrance pupil."
        )
        avg_pupilr = self.surfaces[0].r
        avg_pupilz = self.surfaces[0].d.item()
    else:
        avg_pupilr = torch.mean(intersection_points[:, 0]).item()
        avg_pupilz = torch.mean(intersection_points[:, 1]).item()

        if paraxial:
            avg_pupilr = abs(avg_pupilr / DELTA_PARAXIAL * aper_r)

        if avg_pupilr < EPSILON:
            print(
                "Zero or negative entrance pupil is detected, use the first surface as entrance pupil."
            )
            avg_pupilr = self.surfaces[0].r
            avg_pupilz = self.surfaces[0].d.item()

    return avg_pupilz, avg_pupilr

compute_intersection_points_2d staticmethod

compute_intersection_points_2d(origins, directions)

Compute the pairwise intersection points of 2D lines using least squares.

Parameters:

Name Type Description Default
origins Tensor

Origins of the lines. Shape: [N, 2]

required
directions Tensor

Directions of the lines. Shape: [N, 2]

required

Returns:

Type Description

torch.Tensor: Intersection points. Shape: [N*(N-1)/2, 2]

Source code in src/geolens.py
@staticmethod
def compute_intersection_points_2d(origins, directions):
    """Compute the intersection points of 2D lines.

    Args:
        origins (torch.Tensor): Origins of the lines. Shape: [N, 2]
        directions (torch.Tensor): Directions of the lines. Shape: [N, 2]

    Returns:
        torch.Tensor: Intersection points. Shape: [N*(N-1)/2, 2]
    """
    N = origins.shape[0]

    # Create pairwise combinations of indices
    idx = torch.arange(N)
    idx_i, idx_j = torch.combinations(idx, r=2).unbind(1)

    Oi = origins[idx_i]  # Shape: [N*(N-1)/2, 2]
    Oj = origins[idx_j]  # Shape: [N*(N-1)/2, 2]
    Di = directions[idx_i]  # Shape: [N*(N-1)/2, 2]
    Dj = directions[idx_j]  # Shape: [N*(N-1)/2, 2]

    # Vector from Oi to Oj
    b = Oj - Oi  # Shape: [N*(N-1)/2, 2]

    # Coefficients matrix A
    A = torch.stack([Di, -Dj], dim=-1)  # Shape: [N*(N-1)/2, 2, 2]

    # Solve the linear system Ax = b
    # Using least squares to handle the case of no exact solution
    if A.device.type == "mps":
        # Perform lstsq on CPU for MPS devices and move result back
        x, _ = torch.linalg.lstsq(A.cpu(), b.unsqueeze(-1).cpu())[:2]
        x = x.to(A.device)
    else:
        x, _ = torch.linalg.lstsq(A, b.unsqueeze(-1))[:2]
    x = x.squeeze(-1)  # Shape: [N*(N-1)/2, 2]
    s = x[:, 0]
    t = x[:, 1]

    # Calculate the intersection points using either ray
    P_i = Oi + s.unsqueeze(-1) * Di  # Shape: [N*(N-1)/2, 2]
    P_j = Oj + t.unsqueeze(-1) * Dj  # Shape: [N*(N-1)/2, 2]

    # Take the average to mitigate numerical precision issues
    P = (P_i + P_j) / 2

    return P
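The least-squares solve above keeps every ray pair, including parallel or nearly parallel ones. The NumPy sketch below computes the same pairwise intersections with Cramer's rule for the 2x2 system and, as an addition that is not part of the method above, drops pairs whose determinant falls below a threshold (`min_abs_det`, a hypothetical parameter). Coordinates follow the (x, z) ordering used by calc_exit_pupil and calc_entrance_pupil.

```python
import numpy as np


def pairwise_line_intersections(origins, directions, min_abs_det=1e-12):
    """Intersections of all pairs of 2-D lines o + s * d (NumPy sketch).

    Pairs whose direction vectors are (nearly) parallel are skipped instead of
    being passed to a least-squares solver.
    """
    i, j = np.triu_indices(len(origins), k=1)        # all index pairs with i < j
    di, dj = directions[i], directions[j]
    b = origins[j] - origins[i]                       # right-hand side of [di, -dj] @ [s, t] = b
    # det([di, -dj]) = -(di_x * dj_z - di_z * dj_x)
    det = -(di[:, 0] * dj[:, 1] - di[:, 1] * dj[:, 0])
    keep = np.abs(det) > min_abs_det
    # Cramer's rule for s: replace the first column of the matrix with b
    s = (b[keep, 0] * -dj[keep, 1] - b[keep, 1] * -dj[keep, 0]) / det[keep]
    return origins[i][keep] + s[:, None] * di[keep]


# Three lines that all pass through (x, z) = (1, 2)
o = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 2.0]])
d = np.array([[1.0, 2.0], [0.0, 1.0], [-1.0, 0.0]])
pts = pairwise_line_intersections(o, d)
print(pts.mean(axis=0))   # average intersection, as used for the pupil estimate
```

Filtering removes pairs whose intersection would be numerically meaningless; the threshold depends on how the direction vectors are scaled.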

refocus

refocus(foc_dist=float('inf'))

Refocus the lens to a depth distance by changing sensor position.

Parameters:

Name Type Description Default
foc_dist float

Focus (object) distance to refocus to. Defaults to float('inf'), i.e. focus at infinity.

float('inf')
Note

DSLR cameras commonly use phase-detection autofocus (PDAF). Here the problem is simplified by computing the in-focus sensor position for green light.

Source code in src/geolens.py
@torch.no_grad()
def refocus(self, foc_dist=float("inf")):
    """Refocus the lens to a depth distance by changing sensor position.

    Args:
        foc_dist (float): Focus (object) distance to refocus to. Defaults to infinity.

    Note:
        DSLR cameras commonly use phase-detection autofocus (PDAF). Here the problem is simplified by computing the in-focus sensor position for green light.
    """
    # Calculate in-focus sensor position
    d_sensor_new = self.calc_sensor_plane(depth=foc_dist)

    # Update sensor position
    assert d_sensor_new > 0, "Obtained negative sensor position."
    self.d_sensor = d_sensor_new

    # FoV will be slightly changed
    self.post_computation()
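calc_sensor_plane is not shown in this section, so the following is only a sketch of one common way to locate best focus, not necessarily what the library does: pick the plane that minimizes the RMS spread of a ray bundle. For rays x(z) = x0 + u (z - z0) the variance across rays is quadratic in z, which gives a closed-form minimizer. `best_focus_z` is a hypothetical helper.

```python
import numpy as np


def best_focus_z(x0, z0, slopes):
    """Return the z-plane that minimizes the RMS spread of rays x(z) = x0 + u * (z - z0).

    Var[x(z)] = Var[x0] + 2 * (z - z0) * Cov[x0, u] + (z - z0)**2 * Var[u] is quadratic
    in z, so its minimizer is z* = z0 - Cov[x0, u] / Var[u] (requires Var[u] > 0).
    """
    x0 = np.asarray(x0, dtype=float)
    u = np.asarray(slopes, dtype=float)
    cov = np.mean((x0 - x0.mean()) * (u - u.mean()))
    return z0 - cov / np.var(u)


# Rays leaving the plane z0 = 10 mm that all converge to x = 0 at z = 60 mm.
x0 = np.linspace(-5.0, 5.0, 11)
z_focus = best_focus_z(x0, 10.0, -x0 / 50.0)
print(z_focus)  # close to 60.0
```

For a 2-D spot, add the x and y covariances in the numerator and the x and y slope variances in the denominator. A collimated bundle (all slopes equal) has no finite best-focus plane, which shows up here as a division by zero.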

set_fnum

set_fnum(fnum)

Set the F-number by adjusting the aperture stop radius with a binary search on the entrance pupil radius.

Parameters:

Name Type Description Default
fnum float

target F-number.

required
Source code in src/geolens.py
@torch.no_grad()
def set_fnum(self, fnum):
    """Set F-number and aperture radius using binary search.

    Args:
        fnum (float): target F-number.
    """
    current_fnum = self.fnum
    current_aper_r = self.surfaces[self.aper_idx].r
    target_pupil_r = self.foclen / fnum / 2

    # Binary search to find aperture radius that gives desired entrance pupil radius
    aper_r = current_aper_r * (current_fnum / fnum)
    aper_r_min = 0.5 * aper_r
    aper_r_max = 2.0 * aper_r

    for _ in range(16):
        self.surfaces[self.aper_idx].r = aper_r
        _, pupilr = self.calc_entrance_pupil()

        if abs(pupilr - target_pupil_r) < 0.1:  # Close enough
            break

        if pupilr > target_pupil_r:
            # Current radius is too large, decrease it
            aper_r_max = aper_r
            aper_r = (aper_r_min + aper_r) / 2
        else:
            # Current radius is too small, increase it
            aper_r_min = aper_r
            aper_r = (aper_r_max + aper_r) / 2

    self.surfaces[self.aper_idx].r = aper_r

    # Update pupil after setting aperture radius
    self.calc_pupil()
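The bisection in set_fnum can be studied in isolation from the lens model. This plain-Python sketch replaces calc_entrance_pupil() with a hypothetical `pupil_radius` callable (here a toy linear model in which the front group magnifies the stop by 1.2x) and keeps the same bracket, 0.5x to 2x the initial guess, and the same update rule; `tol` and `max_iter` are illustrative values.

```python
def bisect_aperture_radius(pupil_radius, target_pupil_r, aper_r_init, tol=1e-6, max_iter=60):
    """Bisection on the stop radius so that pupil_radius(aper_r) reaches target_pupil_r.

    Assumes pupil_radius increases with aper_r and that the solution lies in
    [0.5 * aper_r_init, 2.0 * aper_r_init], the same bracket used by set_fnum.
    """
    lo, hi = 0.5 * aper_r_init, 2.0 * aper_r_init
    aper_r = aper_r_init
    for _ in range(max_iter):
        pupil_r = pupil_radius(aper_r)
        if abs(pupil_r - target_pupil_r) < tol:
            break
        if pupil_r > target_pupil_r:
            hi = aper_r          # stop too large: lower the upper bound
        else:
            lo = aper_r          # stop too small: raise the lower bound
        aper_r = 0.5 * (lo + hi)
    return aper_r


# Toy model: entrance pupil radius = 1.2 * stop radius.
foclen, fnum = 50.0, 2.0
target = foclen / fnum / 2                   # entrance pupil radius for F/2 at f = 50 mm: 12.5 mm
current_aper_r, current_fnum = 8.0, 2.5
aper_r = bisect_aperture_radius(lambda r: 1.2 * r, target, current_aper_r * current_fnum / fnum)
print(aper_r)                                # about 12.5 / 1.2
```

Because the bracket is fixed, a target outside 0.5x to 2x of the initial guess cannot be reached, and the search settles at a bracket edge. set_fnum itself stops at an absolute tolerance of 0.1 on the pupil radius.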

set_target_fov_fnum

set_target_fov_fnum(rfov_eff, fnum)

Set the FoV and F-number design targets and derive the focal length from the sensor half-diagonal (ImgH). Use this function only to assign design targets.

Parameters:

Name Type Description Default
rfov_eff float

Half-diagonal FoV in radians. Values greater than π are treated as degrees and converted to radians.

required
fnum float

Target F-number.

required
Source code in src/geolens.py
@torch.no_grad()
def set_target_fov_fnum(self, rfov_eff, fnum):
    """Set FoV, ImgH and F number, only use this function to assign design targets.

    Args:
        rfov_eff (float): half diagonal-FoV in radian.
        fnum (float): F number.
    """
    if rfov_eff > math.pi:
        self.rfov_eff = rfov_eff / 180.0 * math.pi
    else:
        self.rfov_eff = rfov_eff

    self.rfov = self.rfov_eff
    self.foclen = self.r_sensor / math.tan(self.rfov_eff)
    self.eqfl = 21.63 / math.tan(self.rfov_eff)
    self.fnum = fnum
    aper_r = self.foclen / fnum / 2
    self.surfaces[self.aper_idx].update_r(float(aper_r))

    # Update pupil after setting aperture radius
    self.calc_pupil()
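The quantities that set_target_fov_fnum derives can be checked by hand. This plain-Python sketch reproduces them with the same pinhole relations; `design_targets` is a hypothetical helper. The constant 21.63 mm is the half-diagonal of a 36 x 24 mm full-frame sensor, which is why `eqfl` is a 35 mm-equivalent focal length.

```python
import math

FULL_FRAME_HALF_DIAG_MM = 21.63  # half-diagonal of a 36 x 24 mm sensor, as used for eqfl


def design_targets(rfov_eff, fnum, r_sensor):
    """Focal length, 35 mm-equivalent focal length, and stop radius derived from FoV/F-number targets."""
    if rfov_eff > math.pi:                      # same heuristic: large values are taken as degrees
        rfov_eff = math.radians(rfov_eff)
    foclen = r_sensor / math.tan(rfov_eff)      # pinhole model: ImgH = f * tan(half-FoV)
    eqfl = FULL_FRAME_HALF_DIAG_MM / math.tan(rfov_eff)
    aper_r = foclen / fnum / 2                  # stop radius set to f / (2N)
    return foclen, eqfl, aper_r


# 40 degree half-diagonal FoV, F/2, 4 mm sensor half-diagonal
foclen, eqfl, aper_r = design_targets(math.radians(40), 2.0, r_sensor=4.0)
print(f"f = {foclen:.3f} mm, eqfl = {eqfl:.2f} mm, stop radius = {aper_r:.3f} mm")
```

Note the degree heuristic: a small half-FoV given in degrees (for example 3) is not greater than π and would be interpreted as radians. Setting the stop radius to f / (2N) also assumes the stop and the entrance pupil have the same size.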

set_fov

set_fov(rfov_eff)

Set half-diagonal field of view as a design target.

Unlike calc_fov(), which derives the FoV from focal length and sensor size, this method directly assigns the target FoV for lens optimisation.

Parameters:

Name Type Description Default
rfov_eff float

Half-diagonal FoV in radians.

required
Source code in src/geolens.py
@torch.no_grad()
def set_fov(self, rfov_eff):
    """Set half-diagonal field of view as a design target.

    Unlike ``calc_fov()`` which derives FoV from focal length and sensor
    size, this method directly assigns the target FoV for lens optimisation.

    Args:
        rfov_eff (float): Half-diagonal FoV in radians.
    """
    self.rfov_eff = rfov_eff
    self.rfov = rfov_eff
    self.eqfl = 21.63 / math.tan(self.rfov_eff)

Constructor

GeoLens(filename=None, device=None, dtype=torch.float32)
Parameter Type Description
filename str or None Path to .json, .zmx, or .seq lens file. None creates an empty lens.
device str or None PyTorch device. Auto-selects CUDA/MPS/CPU if None.
dtype torch.dtype Floating-point precision. Default torch.float32.

Key Attributes

Attribute Type Description
surfaces list[Surface] Ordered list of optical surfaces
materials list[Material] Materials between surfaces
d_sensor Tensor Axial (z) position of the sensor plane (mm)
foclen float Effective focal length (mm)
fnum float F-number
rfov float Half-diagonal FoV, ray-traced (radians)
rfov_eff float Half-diagonal FoV, pinhole model (radians)
sensor_size tuple Physical sensor size (W, H) in mm
r_sensor float Sensor half-diagonal (mm)
aper_idx int Index of aperture stop surface

Ray Tracing

Core ray tracing methods for propagating rays through the lens system.

trace(ray, surf_range, record)

Forward or backward ray trace through a range of surfaces.

trace2sensor(ray, record)

Trace rays forward from object space to the sensor plane.

ray = lens.sample_from_fov(fov_x=0.0, fov_y=10.0, num_rays=1000, wvln=0.587)
ray_out = lens.trace2sensor(ray)

trace2obj(ray)

Backward trace from sensor side to object space.

forward_tracing(ray, surf_range, record)

Sequential forward ray trace through surfaces. Returns modified ray.

backward_tracing(ray, surf_range, record)

Sequential backward ray trace through surfaces.


Ray Sampling

Methods for generating ray bundles from various source configurations.

sample_from_fov(fov_x, fov_y, depth, num_rays, wvln)

Sample rays from a point source at a given field angle (degrees).

ray = lens.sample_from_fov(fov_x=0.0, fov_y=15.0, num_rays=2048, wvln=0.587)

sample_grid_rays(depth, num_grid, num_rays, wvln)

Sample a 2D grid of field positions across the full field of view.

sample_radial_rays(num_field, depth, num_rays, wvln, direction)

Sample rays from field points along the y-axis, x-axis, or diagonal.

sample_from_points(points, num_rays, wvln, scale_pupil)

Generate rays from specific point sources in object space.

sample_sensor(spp, wvln, sub_pixel)

Sample backward rays from the sensor plane (for image simulation).


Lens Properties

calc_foclen(test_fov_deg)

Compute effective focal length via ray tracing.

calc_fov()

Compute field of view angles (half-diagonal, horizontal, vertical).

calc_pupil()

Compute entrance and exit pupil positions and radii.

calc_numerical_aperture(n)

Compute numerical aperture in medium with refractive index n.

set_fnum(fnum)

Adjust the aperture stop radius to reach a target F-number.

set_target_fov_fnum(rfov_eff, fnum)

Set both field of view and f-number targets.

import math
lens.set_target_fov_fnum(rfov_eff=math.radians(40), fnum=2.0)  # 40° half-diagonal FoV, F/2
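The implementation of calc_numerical_aperture is not shown here. For orientation only, these are the textbook relations between numerical aperture and F-number for a lens in air focused at infinity, sketched in Python; they are not claimed to match the library's method, and the helper names are hypothetical.

```python
import math


def na_from_half_angle(half_angle_rad, n=1.0):
    """Numerical aperture from the marginal-ray half-angle in a medium of index n: NA = n sin(theta)."""
    return n * math.sin(half_angle_rad)


def na_aplanatic(fnum):
    """Image-space NA (in air) of an aplanatic lens focused at infinity: sin(theta) = 1 / (2N)."""
    return 1.0 / (2.0 * fnum)


def na_cone_geometry(fnum):
    """NA implied by the cone geometry tan(theta) = 1 / (2N), in air."""
    return na_from_half_angle(math.atan(1.0 / (2.0 * fnum)))


print(na_aplanatic(2.0), na_cone_geometry(2.0))   # F/2: 0.25 vs. about 0.2425
```

The two definitions agree for slow lenses and diverge as the F-number drops, so the convention matters when comparing NA values across tools.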

PSF Computation

src.geolens_pkg.psf_compute.GeoLensPSF

Mixin providing PSF computation for GeoLens.

All three PSF models are exposed through a single psf() dispatcher. The geometric and coherent models are differentiable; Huygens is not.

This class is not instantiated directly; it is mixed into GeoLens (deeplens.optics.geolens.GeoLens).

psf

psf(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=None, recenter=True, model='geometric')

Calculate Point Spread Function (PSF) for given point sources.

Supports multiple PSF calculation models
  • geometric: Incoherent intensity ray tracing (fast, differentiable)
  • coherent: Coherent ray tracing with free-space propagation (accurate, differentiable)
  • huygens: Huygens-Fresnel integration (accurate, not differentiable)

Parameters:

Name Type Description Default
points Tensor

Point source positions. Shape [N, 3]: x, y are normalized coordinates in [-1, 1]; z is the object depth in [-Inf, 0].

required
ks int

Output kernel size in pixels. Defaults to PSF_KS.

PSF_KS
wvln float

Wavelength in [um]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
spp int

Samples per pixel. If None, uses model-specific default.

None
recenter bool

If True, center PSF using chief ray. Defaults to True.

True
model str

PSF model type. One of 'geometric', 'coherent', 'huygens'. Defaults to 'geometric'.

'geometric'

Returns:

Name Type Description
Tensor

PSF normalized to sum to 1. Shape [ks, ks] or [N, ks, ks].

Source code in src/geolens_pkg/psf_compute.py
def psf(
    self,
    points,
    ks=PSF_KS,
    wvln=DEFAULT_WAVE,
    spp=None,
    recenter=True,
    model="geometric",
):
    """Calculate Point Spread Function (PSF) for given point sources.

    Supports multiple PSF calculation models:
        - geometric: Incoherent intensity ray tracing (fast, differentiable)
        - coherent: Coherent ray tracing with free-space propagation (accurate, differentiable)
        - huygens: Huygens-Fresnel integration (accurate, not differentiable)

    Args:
        points (Tensor): Point source positions. Shape [N, 3] with x, y in [-1, 1]
            and z in [-Inf, 0]. Normalized coordinates.
        ks (int, optional): Output kernel size in pixels. Defaults to PSF_KS.
        wvln (float, optional): Wavelength in [um]. Defaults to DEFAULT_WAVE.
        spp (int, optional): Samples per pixel. If None, uses model-specific default.
        recenter (bool, optional): If True, center PSF using chief ray. Defaults to True.
        model (str, optional): PSF model type. One of 'geometric', 'coherent', 'huygens'.
            Defaults to 'geometric'.

    Returns:
        Tensor: PSF normalized to sum to 1. Shape [ks, ks] or [N, ks, ks].
    """
    if model == "geometric":
        spp = SPP_PSF if spp is None else spp
        return self.psf_geometric(points, ks, wvln, spp, recenter)
    elif model == "coherent":
        spp = SPP_COHERENT if spp is None else spp
        return self.psf_coherent(points, ks, wvln, spp, recenter)
    elif model == "huygens":
        spp = SPP_COHERENT if spp is None else spp
        return self.psf_huygens(points, ks, wvln, spp, recenter)
    else:
        raise ValueError(f"Unknown PSF model: {model}")

psf_geometric

psf_geometric(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_PSF, recenter=True)

Single wavelength geometric PSF calculation.

Parameters:

Name Type Description Default
points Tensor

Point source positions. Shape [N, 3] (or [3] for a single point): x, y are normalized coordinates in [-1, 1]; z is the object depth in [-Inf, 0].

required
ks int

Output kernel size in pixels.

PSF_KS
wvln float

Wavelength in [um].

DEFAULT_WAVE
spp int

Number of rays sampled per point source.

SPP_PSF
recenter bool

Recenter PSF using chief ray.

True

Returns:

Name Type Description
psf

Shape of [ks, ks] or [N, ks, ks].

References

[1] https://optics.ansys.com/hc/en-us/articles/42661723066515-What-is-a-Point-Spread-Function

Source code in src/geolens_pkg/psf_compute.py
def psf_geometric(
    self, points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_PSF, recenter=True
):
    """Single wavelength geometric PSF calculation.

    Args:
        points (Tensor): Normalized point source position. Shape of [N, 3], x, y in range [-1, 1], z in range [-Inf, 0].
        ks (int, optional): Output kernel size.
        wvln (float, optional): Wavelength.
        spp (int, optional): Sample per pixel.
        recenter (bool, optional): Recenter PSF using chief ray.

    Returns:
        psf: Shape of [ks, ks] or [N, ks, ks].

    References:
        [1] https://optics.ansys.com/hc/en-us/articles/42661723066515-What-is-a-Point-Spread-Function
    """
    sensor_w, sensor_h = self.sensor_size
    pixel_size = self.pixel_size
    device = self.device

    # Points shape of [N, 3]
    if not torch.is_tensor(points):
        points = torch.tensor(points, device=device)

    if len(points.shape) == 1:
        single_point = True
        points = points.unsqueeze(0)
    else:
        single_point = False

    # Sample rays. Ray position in the object space by perspective projection
    depth = points[:, 2]
    scale = self.calc_scale(depth)
    point_obj_x = points[..., 0] * scale * sensor_w / 2
    point_obj_y = points[..., 1] * scale * sensor_h / 2
    point_obj = torch.stack([point_obj_x, point_obj_y, points[..., 2]], dim=-1)
    ray = self.sample_from_points(points=point_obj, num_rays=spp, wvln=wvln)

    # Trace rays to sensor plane (incoherent)
    ray.coherent = False
    ray = self.trace2sensor(ray)

    # Calculate PSF center, shape [N, 2]
    if recenter:
        pointc = self.psf_center(point_obj, method="chief_ray")
    else:
        pointc = self.psf_center(point_obj, method="pinhole")

    # Monte Carlo integration
    psf = forward_integral(ray.flip_xy(), ps=pixel_size, ks=ks, pointc=pointc)

    # Intensity normalization
    psf = psf / (torch.sum(psf, dim=(-2, -1), keepdim=True) + EPSILON)

    if single_point:
        psf = psf.squeeze(0)

    return diff_float(psf)
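psf_geometric delegates the binning to forward_integral, whose internals are not shown here. As a rough stand-in, this NumPy sketch uses nearest-bin histogramming (np.histogram2d) to accumulate ray hits into a ks x ks patch around the PSF center and then normalizes the patch to sum to 1, as the code above does. `bin_psf` is a hypothetical name; rows are ordered by increasing y, and the library's `ray.flip_xy()` orientation convention is not reproduced.

```python
import numpy as np


def bin_psf(hits_xy, center_xy, pixel_size, ks):
    """Histogram ray hits on the sensor into a ks x ks PSF patch centered at center_xy.

    hits_xy: [M, 2] ray intersection points, same length unit as pixel_size.
    Returns an intensity patch normalized to sum to 1 (rows = y, columns = x).
    """
    half = ks * pixel_size / 2
    edges_x = np.linspace(center_xy[0] - half, center_xy[0] + half, ks + 1)
    edges_y = np.linspace(center_xy[1] - half, center_xy[1] + half, ks + 1)
    # First coordinate indexes rows, so pass y first; hits outside the patch are dropped.
    hist, _, _ = np.histogram2d(hits_xy[:, 1], hits_xy[:, 0], bins=[edges_y, edges_x])
    return hist / (hist.sum() + 1e-12)           # intensity normalization with a small epsilon


rng = np.random.default_rng(0)
hits = rng.normal(loc=[0.01, -0.02], scale=0.002, size=(20000, 2))   # synthetic spot, mm
psf = bin_psf(hits, center_xy=(0.01, -0.02), pixel_size=0.001, ks=31)
```

Nearest-bin accumulation is the simplest Monte Carlo estimator; a differentiable variant would instead spread each hit over neighbouring pixels with weights that depend smoothly on its position.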

psf_coherent

psf_coherent(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True)

Alias for psf_pupil_prop. Calculates PSF by coherent ray tracing to exit pupil followed by Angular Spectrum Method (ASM) propagation.

Source code in src/geolens_pkg/psf_compute.py
def psf_coherent(
    self, points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True
):
    """Alias for psf_pupil_prop. Calculates PSF by coherent ray tracing to exit pupil followed by Angular Spectrum Method (ASM) propagation."""
    return self.psf_pupil_prop(points, ks=ks, wvln=wvln, spp=spp, recenter=recenter)

psf_pupil_prop

psf_pupil_prop(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True)

Single point monochromatic PSF using exit-pupil diffraction model. This function is differentiable.

Steps

1. Compute the complex wavefield at the exit-pupil plane by coherent ray tracing.
2. Propagate the field through free space to the sensor plane and compute the intensity PSF.

Parameters:

Name Type Description Default
points Tensor

Single point source position [x, y, z], with x, y normalized to [-1, 1] and z (depth) in [-Inf, 0].

required
ks int

size of the PSF patch. Defaults to PSF_KS.

PSF_KS
wvln float

Wavelength in [um]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
spp int

number of rays to sample. Defaults to SPP_COHERENT.

SPP_COHERENT
recenter bool

Recenter PSF using chief ray. Defaults to True.

True

Returns:

Name Type Description
psf_out Tensor

PSF patch. Normalized to sum to 1. Shape [ks, ks]

Reference

[1] "End-to-End Hybrid Refractive-Diffractive Lens Design with Differentiable Ray-Wave Model", SIGGRAPH Asia 2024.

Note

[1] This function is similar to the ZEMAX FFT_PSF but implements free-space propagation with the Angular Spectrum Method (ASM) instead of an FFT transform. ASM propagation is more accurate than the FFT approach because the FFT (as used in ZEMAX) assumes a far-field condition (e.g., a chief ray perpendicular to the image plane).

Source code in src/geolens_pkg/psf_compute.py
def psf_pupil_prop(
    self, points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True
):
    """Single point monochromatic PSF using exit-pupil diffraction model. This function is differentiable.

    Steps:
        1, Calculate complex wavefield at exit-pupil plane by coherent ray tracing.
        2, Free-space propagation to sensor plane and calculate intensity PSF.

    Args:
        points (torch.Tensor, optional): [x, y, z] coordinates of the point source. Defaults to torch.Tensor([0,0,-10000]).
        ks (int, optional): size of the PSF patch. Defaults to PSF_KS.
        wvln (float, optional): wvln. Defaults to DEFAULT_WAVE.
        spp (int, optional): number of rays to sample. Defaults to SPP_COHERENT.
        recenter (bool, optional): Recenter PSF using chief ray. Defaults to True.

    Returns:
        psf_out (torch.Tensor): PSF patch. Normalized to sum to 1. Shape [ks, ks]

    Reference:
        [1] "End-to-End Hybrid Refractive-Diffractive Lens Design with Differentiable Ray-Wave Model", SIGGRAPH Asia 2024.

    Note:
        [1] This function is similar to ZEMAX FFT_PSF but implement free-space propagation with Angular Spectrum Method (ASM) rather than FFT transform. Free-space propagation using ASM is more accurate than doing FFT, because FFT (as used in ZEMAX) assumes far-field condition (e.g., chief ray perpendicular to image plane).
    """
    # Pupil field by coherent ray tracing
    wavefront, psfc = self.pupil_field(
        points=points, wvln=wvln, spp=spp, recenter=recenter
    )

    # Propagate to sensor plane and get intensity
    pupilz, pupilr = self.get_exit_pupil()
    h, w = wavefront.shape
    # Manually pad wave field
    wavefront = F.pad(
        wavefront.unsqueeze(0).unsqueeze(0),
        [w // 2, w // 2, h // 2, h // 2],  # F.pad pads the last dim (w) first, then h
        mode="constant",
        value=0,
    )
    # Free-space propagation using Angular Spectrum Method (ASM)
    sensor_field = AngularSpectrumMethod(
        wavefront,
        z=self.d_sensor - pupilz,
        wvln=wvln,
        ps=self.pixel_size,
        padding=False,
    )
    # Get intensity
    psf_inten = sensor_field.abs() ** 2

    # Calculate PSF center
    h, w = psf_inten.shape[-2:]
    # Consider both interpolation and padding
    psfc_idx_i = ((2 - psfc[1]) * h / 4).round().long()
    psfc_idx_j = ((2 + psfc[0]) * w / 4).round().long()

    # Crop valid PSF region and normalize
    if ks is not None:
        psf_inten_pad = (
            F.pad(
                psf_inten,
                [ks // 2, ks // 2, ks // 2, ks // 2],
                mode="constant",
                value=0,
            )
            .squeeze(0)
            .squeeze(0)
        )
        psf = psf_inten_pad[
            psfc_idx_i : psfc_idx_i + ks, psfc_idx_j : psfc_idx_j + ks
        ]
    else:
        psf = psf_inten

    # Intensity normalization, shape of [ks, ks] or [h, w]
    psf = psf / (torch.sum(psf, dim=(-2, -1), keepdim=True) + EPSILON)

    return diff_float(psf)
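The free-space step can be illustrated with a textbook angular spectrum propagator. This NumPy sketch is not the library's AngularSpectrumMethod, whose internals are not shown here: it multiplies the field's spectrum by the transfer function exp(i 2π z sqrt(1/λ² - fx² - fy²)) and discards evanescent components. psf_pupil_prop also zero-pads the pupil field by half its size on each side before propagating, which the sketch leaves to the caller. Any consistent length unit works; here all lengths are in micrometres.

```python
import numpy as np


def angular_spectrum_propagate(field, z, wvln, pixel_size):
    """Propagate a sampled complex field by distance z with the angular spectrum method.

    field: [ny, nx] complex array; z, wvln and pixel_size share one length unit.
    Evanescent components (fx^2 + fy^2 > 1 / wvln^2) are discarded.
    """
    ny, nx = field.shape
    fx = np.fft.fftfreq(nx, d=pixel_size)          # spatial frequencies, cycles per unit length
    fy = np.fft.fftfreq(ny, d=pixel_size)
    FX, FY = np.meshgrid(fx, fy)                    # both [ny, nx]
    arg = 1.0 / wvln**2 - FX**2 - FY**2
    propagating = arg > 0
    kz = 2 * np.pi * np.sqrt(np.where(propagating, arg, 0.0))
    transfer = np.where(propagating, np.exp(1j * kz * z), 0.0)
    return np.fft.ifft2(np.fft.fft2(field) * transfer)


# Gaussian beam sampled at 1 um pitch, propagated 100 um at a 0.5 um wavelength.
coords = np.arange(64) - 32
X, Y = np.meshgrid(coords, coords)
field0 = np.exp(-(X**2 + Y**2) / (2 * 4.0**2)).astype(complex)
field_z = angular_spectrum_propagate(field0, z=100.0, wvln=0.5, pixel_size=1.0)
intensity = np.abs(field_z) ** 2                   # the quantity psf_pupil_prop keeps before cropping
```

Since the transfer function has unit magnitude for every propagating frequency, total energy is conserved; the FFT is circular, so without padding light leaving the window wraps around to the opposite edge.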

pupil_field

pupil_field(points, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True)

Compute complex wavefront at exit pupil plane by coherent ray tracing.

The wavefront is flipped for subsequent PSF calculation and has the same size as the image sensor. This function is differentiable.

Parameters:

Name Type Description Default
points Tensor or list

Single point source position. Shape [3] or [1, 3], with x, y in [-1, 1] and z in [-Inf, 0].

required
wvln float

Wavelength in [um]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
spp int

Number of rays to sample. Must be >= 1,000,000 for accurate coherent simulation. Defaults to SPP_COHERENT.

SPP_COHERENT
recenter bool

If True, center using chief ray. Defaults to True.

True

Returns:

Name Type Description
tuple

(wavefront, psf_center), where:
  • wavefront (Tensor): Complex wavefront at the exit pupil. Shape [H, H].
  • psf_center (list): Normalized PSF center coordinates [x, y] in [-1, 1].

Note

The default dtype must be torch.float64 (for example via torch.set_default_dtype(torch.float64)) for accurate phase calculation; the function asserts this.
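A back-of-envelope sketch of why float64 matters, assuming optical path lengths in mm and wavelengths in um as elsewhere on this page: a single rounding of a 50 mm optical path already shifts the phase by one unit in the last place (ULP), and that step converts to radians as 2π · ULP / λ. `phase_step_rad` is a hypothetical helper, and errors accumulated over many surfaces can be larger than this single-rounding bound.

```python
import numpy as np


def phase_step_rad(opl_mm, wvln_um, dtype):
    """Phase change (radians) corresponding to one ULP of an optical path length stored in `dtype`."""
    ulp_mm = float(np.spacing(dtype(opl_mm)))    # gap to the next representable value
    return 2 * np.pi * (ulp_mm * 1e3) / wvln_um  # mm -> um, then phase = 2*pi*OPL/lambda


err32 = phase_step_rad(50.0, 0.55, np.float32)   # a 50 mm optical path at 550 nm
err64 = phase_step_rad(50.0, 0.55, np.float64)
print(f"float32: {err32:.3g} rad, float64: {err64:.3g} rad")
```

In float32 one rounding step is already about 0.04 rad, whereas in float64 it is on the order of 1e-10 rad, which is why the coherent models insist on a float64 default dtype.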

Source code in src/geolens_pkg/psf_compute.py
def pupil_field(self, points, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True):
    """Compute complex wavefront at exit pupil plane by coherent ray tracing.

    The wavefront is flipped for subsequent PSF calculation and has the same
    size as the image sensor. This function is differentiable.

    Args:
        points (Tensor or list): Single point source position. Shape [3] or [1, 3],
            with x, y in [-1, 1] and z in [-Inf, 0].
        wvln (float, optional): Wavelength in [um]. Defaults to DEFAULT_WAVE.
        spp (int, optional): Number of rays to sample. Must be >= 1,000,000 for
            accurate coherent simulation. Defaults to SPP_COHERENT.
        recenter (bool, optional): If True, center using chief ray. Defaults to True.

    Returns:
        tuple: (wavefront, psf_center) where:
            - wavefront (Tensor): Complex wavefront at exit pupil. Shape [H, H].
            - psf_center (list): Normalized PSF center coordinates [x, y] in [-1, 1].

    Note:
        Default dtype must be torch.float64 for accurate phase calculation.
    """
    assert spp >= 1_000_000, (
        f"Ray sampling {spp} is too small for coherent ray tracing, which may lead to inaccurate simulation."
    )
    assert torch.get_default_dtype() == torch.float64, (
        "Default dtype must be set to float64 for accurate phase calculation."
    )

    sensor_w, sensor_h = self.sensor_size
    device = self.device

    if isinstance(points, list):
        points = torch.tensor(points, device=device).unsqueeze(0)  # [1, 3]
    elif torch.is_tensor(points) and len(points.shape) == 1:
        points = points.unsqueeze(0).to(device)  # [1, 3]
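    elif torch.is_tensor(points) and len(points.shape) == 2:
        assert points.shape[0] == 1, (
            f"pupil_field only supports single point input, got shape {points.shape}"
        )
        points = points.to(device)  # [1, 3]
    else:
        # type(points) also works for non-tensor inputs (e.g. tuples, NumPy arrays)
        raise ValueError(f"Unsupported point type {type(points).__name__}.")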

    assert points.shape[0] == 1, (
        "Only one point is supported for pupil field calculation."
    )

    # Ray origin in the object space
    scale = self.calc_scale(points[:, 2].item())
    point_obj_x = points[:, 0] * scale * sensor_w / 2
    point_obj_y = points[:, 1] * scale * sensor_h / 2
    points_obj = torch.stack([point_obj_x, point_obj_y, points[:, 2]], dim=-1)

    # Ray center determined by chief ray
    # Shape of [N, 2], un-normalized physical coordinates
    if recenter:
        pointc = self.psf_center(points_obj, method="chief_ray")
    else:
        pointc = self.psf_center(points_obj, method="pinhole")

    # Ray-tracing to exit_pupil
    ray = self.sample_from_points(points=points_obj, num_rays=spp, wvln=wvln)
    ray.coherent = True
    ray = self.trace2exit_pupil(ray)

    # Calculate complex field (same physical size and resolution as the sensor)
    # Complex field is flipped here for further PSF calculation
    pointc_ref = torch.zeros_like(points[:, :2])  # [N, 2]
    wavefront = forward_integral(
        ray.flip_xy(),
        ps=self.pixel_size,
        ks=self.sensor_res[1],
        pointc=pointc_ref,
    )
    wavefront = wavefront.squeeze(0)  # [H, H]

    # PSF center (on the sensor plane)
    pointc = pointc[0, :]
    psf_center = [
        pointc[0] / sensor_w * 2,
        pointc[1] / sensor_h * 2,
    ]

    return wavefront, psf_center

psf_huygens

psf_huygens(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True)

Single wavelength Huygens PSF calculation.

This function is not differentiable; its heavy computational cost makes gradient tracking impractical.

Steps

  1. Trace coherent rays to the exit-pupil plane.
  2. Treat every ray as a secondary point source emitting a spherical wave, and sum these waves at each PSF pixel.
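The two steps can be illustrated with a small NumPy sketch. It assumes the exit-pupil ray positions, z-direction cosines, and optical path lengths (all in mm) are already available as arrays. The function huygens_psf_sketch and its arguments are hypothetical, not part of the library. Like the source below, it weights each spherical wave by the obliquity factor (1 + cos θ)/2 and by 1/r, and processes rays in batches to bound memory. It omits the chief-ray recentering offset and the final flip that psf_huygens applies, and it feeds in a toy converging wave instead of traced rays.

```python
import numpy as np


def huygens_psf_sketch(pos, dir_z, opl, sensor_z, ks, pixel_size, wvln_um, batch=10_000):
    """Coherently sum spherical waves from secondary sources on a ks x ks pixel grid.

    pos:   [M, 3] secondary-source positions on the exit pupil (mm)
    dir_z: [M] z-component of each unit ray direction
    opl:   [M] optical path length accumulated through the lens (mm)
    """
    wvln_mm = wvln_um * 1e-3
    half = ks / 2 * pixel_size
    # Pixel-center coordinates centered on the optical axis; y runs top to bottom
    xs = np.linspace(-half + pixel_size / 2, half - pixel_size / 2, ks)
    px, py = np.meshgrid(xs, xs[::-1], indexing="xy")  # [ks, ks] each

    field = np.zeros((ks, ks), dtype=np.complex128)
    opl_min = opl.min()
    for start in range(0, len(opl), batch):
        sl = slice(start, start + batch)
        dx = px[..., None] - pos[sl, 0]  # [ks, ks, b]
        dy = py[..., None] - pos[sl, 1]  # [ks, ks, b]
        dz = sensor_z - pos[sl, 2]       # [b]
        r = np.sqrt(dx**2 + dy**2 + dz**2)
        amp = 0.5 * (1.0 + np.abs(dir_z[sl]))  # Huygens-Fresnel obliquity factor
        phase = 2 * np.pi * np.fmod((opl[sl] + r - opl_min) / wvln_mm, 1.0)
        field += (amp / r * np.exp(1j * phase)).sum(axis=-1)  # 1/r spherical decay

    psf = np.abs(field) ** 2
    return psf / psf.sum()  # unit total energy


# Toy input: 2000 sources on a 1 mm-radius disk at z = 0, with OPLs chosen so
# that every wave arrives in phase at the on-axis point (0, 0, 10).
rng = np.random.default_rng(0)
n = 2000
rho = np.sqrt(rng.uniform(0.0, 1.0, n))
theta = rng.uniform(0.0, 2 * np.pi, n)
pos = np.stack([rho * np.cos(theta), rho * np.sin(theta), np.zeros(n)], axis=-1)
opl = 20.0 - np.sqrt(rho**2 + 10.0**2)
psf = huygens_psf_sketch(pos, np.ones(n), opl, sensor_z=10.0, ks=11,
                         pixel_size=0.001, wvln_um=0.55)
print(np.unravel_index(np.argmax(psf), psf.shape))  # peak at the center pixel
```

Because all waves share the same total path at the geometric focus, they add in phase there, so the intensity peak lands on the central pixel.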

Parameters:

  points (Tensor): Normalized point-source position of shape [3] or [1, 3] (only a single point is supported); x, y in [-1, 1], z in [-Inf, 0]. Required.
  ks (int): Output kernel size. Default: PSF_KS.
  wvln (float): Wavelength in micrometers. Default: DEFAULT_WAVE.
  spp (int): Number of rays sampled from the point source. Default: SPP_COHERENT.
  recenter (bool): Recenter the PSF using the chief ray; if False, the pinhole model is used. Default: True.

Returns:

  psf (Tensor): Normalized intensity PSF of shape [ks, ks].

References

[1] "Optical Aberrations Correction in Postprocessing Using Imaging Simulation", TOG 2021

Note

This differs from the Zemax Huygens PSF, which traces rays to the image plane and performs plane-wave integration.

Source code in src/geolens_pkg/psf_compute.py
def psf_huygens(
    self, points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True
):
    """Single-wavelength Huygens PSF calculation.

    This function is not differentiable, given its heavy computational cost.

    Steps:
        1. Trace coherent rays to the exit-pupil plane.
        2. Treat every ray as a secondary point source emitting a spherical wave.

    Args:
        points (Tensor): Normalized point source position. Shape of [3] or [1, 3]
            (single point only); x, y in range [-1, 1], z in range [-Inf, 0].
        ks (int, optional): Output kernel size.
        wvln (float, optional): Wavelength in micrometers.
        spp (int, optional): Number of rays sampled from the point source.
        recenter (bool, optional): Recenter PSF using chief ray.

    Returns:
        psf: Normalized intensity PSF of shape [ks, ks].

    References:
        [1] "Optical Aberrations Correction in Postprocessing Using Imaging Simulation", TOG 2021

    Note:
        This differs from the Zemax Huygens PSF, which traces rays to the image
        plane and performs plane-wave integration.
    """
    assert torch.get_default_dtype() == torch.float64, (
        "Default dtype must be set to float64 for accurate phase calculation."
    )

    sensor_w, sensor_h = self.sensor_size
    pixel_size = self.pixel_size
    device = self.device
    wvln_mm = wvln * 1e-3  # Convert wavelength to mm

    # Points shape of [N, 3]
    if not torch.is_tensor(points):
        points = torch.tensor(points, device=device)

    if len(points.shape) == 1:
        single_point = True
        points = points.unsqueeze(0)
    elif len(points.shape) == 2 and points.shape[0] == 1:
        single_point = True
    else:
        raise ValueError(
            f"Points must be of shape [3] or [1, 3], got {points.shape}."
        )

    # Sample rays from object point
    depth = points[:, 2]
    scale = self.calc_scale(depth)
    point_obj_x = points[..., 0] * scale * sensor_w / 2
    point_obj_y = points[..., 1] * scale * sensor_h / 2
    point_obj = torch.stack([point_obj_x, point_obj_y, points[..., 2]], dim=-1)
    ray = self.sample_from_points(points=point_obj, num_rays=spp, wvln=wvln)

    # Trace rays coherently through the lens to exit pupil
    ray.coherent = True
    ray = self.trace2exit_pupil(ray)

    # Calculate PSF center (not flipped here)
    if recenter:
        pointc = -self.psf_center(point_obj, method="chief_ray")
    else:
        pointc = -self.psf_center(point_obj, method="pinhole")

    # Build PSF pixel coordinates (sensor plane at z = d_sensor)
    sensor_z = self.d_sensor.item()
    psf_half_size = (ks / 2) * pixel_size  # Physical half-size of PSF region
    x_coords = torch.linspace(
        -psf_half_size + pixel_size / 2,
        psf_half_size - pixel_size / 2,
        ks,
        device=device,
    )
    y_coords = torch.linspace(
        psf_half_size - pixel_size / 2,
        -psf_half_size + pixel_size / 2,
        ks,
        device=device,
    )
    psf_x, psf_y = torch.meshgrid(
        pointc[0, 0] + x_coords, pointc[0, 1] + y_coords, indexing="xy"
    )  # [ks, ks] each

    # Get valid rays only
    valid_mask = ray.is_valid > 0
    valid_pos = ray.o[valid_mask]  # [num_valid, 3]
    valid_dir = ray.d[valid_mask]  # [num_valid, 3]
    valid_opl = ray.opl[valid_mask]  # [num_valid]
    num_valid = valid_pos.shape[0]

    # Huygens integration: sum spherical waves from each secondary source
    psf_complex = torch.zeros(ks, ks, dtype=torch.complex128, device=device)
    opl_min = valid_opl.min()

    # Compute distance from each secondary source to each pixel
    batch_size = min(num_valid, 10_000)  # Process rays in batches
    for batch_start in range(0, num_valid, batch_size):
        batch_end = min(batch_start + batch_size, num_valid)

        # Batch ray data
        batch_pos = valid_pos[batch_start:batch_end]  # [batch, 3]
        batch_dir = valid_dir[batch_start:batch_end]  # [batch, 3]
        batch_opl = valid_opl[batch_start:batch_end].squeeze(-1)  # [batch]

        # Distance from each secondary source to each pixel
        # batch_pos: [batch, 3], psf_x: [ks, ks]
        dx = psf_x.unsqueeze(-1) - batch_pos[:, 0]  # [ks, ks, batch]
        dy = psf_y.unsqueeze(-1) - batch_pos[:, 1]  # [ks, ks, batch]
        dz = sensor_z - batch_pos[:, 2]  # [batch]

        # Distance r from secondary source to pixel
        r = torch.sqrt(dx**2 + dy**2 + dz**2)  # [ks, ks, batch]

        # Obliquity factor: cos(theta) where theta is angle from normal
        # Using ray direction at exit pupil (dz component)
        obliq = torch.abs(batch_dir[:, 2])  # [batch]
        amp = 0.5 * (1.0 + obliq)  # Huygens–Fresnel obliquity factor

        # Total optical path = OPL through lens + distance to pixel
        total_opl = batch_opl + r  # [ks, ks, batch]

        # Phase relative to reference
        phase = torch.fmod((total_opl - opl_min) / wvln_mm, 1.0) * (
            2 * torch.pi
        )  # [ks, ks, batch]

        # Complex amplitude: A * exp(i * phase) / r (spherical wave decay)
        # We use 1/r for spherical wave amplitude decay
        complex_amp = (amp / r) * torch.exp(1j * phase)  # [ks, ks, batch]

        # Sum contributions from this batch
        psf_complex += complex_amp.sum(dim=-1)  # [ks, ks]

    # Convert complex field to intensity
    psf = psf_complex.abs() ** 2

    # Intensity normalization
    psf = psf / (torch.sum(psf, dim=(-2, -1), keepdim=True) + EPSILON)

    # Flip PSF
    psf = torch.flip(psf, [-2, -1])

    if single_point:
        psf = psf.squeeze(0)

    return diff_float(psf)

psf_map

psf_map(depth=DEPTH, grid=(7, 7), ks=PSF_KS, spp=SPP_PSF, wvln=DEFAULT_WAVE, recenter=True)

Compute the geometric PSF map at a given depth.

Overrides the base method in the Lens class for efficiency: all field points are ray-traced in parallel in a single call.

Parameters:

  depth (float): Depth of the object plane. Default: DEPTH.
  grid (int | tuple[int, int]): Grid size (grid_w, grid_h); an int n is expanded to (n, n). Default: (7, 7).
  ks (int): Kernel size. Default: PSF_KS.
  spp (int): Samples per pixel. Default: SPP_PSF.
  wvln (float): Wavelength in micrometers. Default: DEFAULT_WAVE.
  recenter (bool): Recenter PSF using chief ray. Default: True.

Returns:

  psf_map (Tensor): PSF map of shape [grid_h, grid_w, 1, ks, ks].

Source code in src/geolens_pkg/psf_compute.py
def psf_map(
    self,
    depth=DEPTH,
    grid=(7, 7),
    ks=PSF_KS,
    spp=SPP_PSF,
    wvln=DEFAULT_WAVE,
    recenter=True,
):
    """Compute the geometric PSF map at a given depth.

    Overrides the base method in the Lens class for efficiency: all field
    points are ray-traced in parallel in a single call.

    Args:
        depth (float, optional): Depth of the object plane. Defaults to DEPTH.
        grid (int | tuple[int, int], optional): Grid size (grid_w, grid_h). An int
            n is expanded to (n, n). Defaults to (7, 7).
        ks (int, optional): Kernel size. Defaults to PSF_KS.
        spp (int, optional): Samples per pixel. Defaults to SPP_PSF.
        wvln (float, optional): Wavelength in micrometers. Defaults to DEFAULT_WAVE.
        recenter (bool, optional): Recenter PSF using chief ray. Defaults to True.

    Returns:
        psf_map: PSF map. Shape of [grid_h, grid_w, 1, ks, ks].
    """
    if isinstance(grid, int):
        grid = (grid, grid)
    points = self.point_source_grid(depth=depth, grid=grid)
    points = points.reshape(-1, 3)
    psfs = self.psf(
        points=points, ks=ks, recenter=recenter, spp=spp, wvln=wvln
    ).unsqueeze(1)  # [grid_h * grid_w, 1, ks, ks]

    psf_map = psfs.reshape(grid[1], grid[0], 1, ks, ks)
    return psf_map
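The reshape in psf_map relies on the flattened points being in row-major [grid_h, grid_w] order, with grid given as (grid_w, grid_h). The NumPy sketch below makes the mapping from flat index k to grid cell (i, j) = divmod(k, grid_w) explicit. It fills dummy kernels with their flat index, which is an assumption for illustration and not what self.psf returns.

```python
import numpy as np

grid = (4, 3)  # (grid_w, grid_h), as in psf_map
ks = 5
grid_w, grid_h = grid

# Stand-in for self.psf(...): one ks x ks kernel per flattened field point,
# filled with its flat index so the mapping is easy to inspect.
psfs = np.stack([np.full((ks, ks), k, dtype=float) for k in range(grid_w * grid_h)])
psfs = psfs[:, None]  # [grid_h * grid_w, 1, ks, ks], mirrors .unsqueeze(1)

psf_map = psfs.reshape(grid_h, grid_w, 1, ks, ks)

# Cell (i, j) holds the kernel of flat point k = i * grid_w + j
i, j = 2, 1
print(psf_map[i, j, 0, 0, 0])  # 9.0
```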

psf_center

psf_center(points_obj, method='chief_ray')

Compute the reference PSF center for a given point source, flipped to match the orientation of the original point.

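Parameters:

  points_obj (Tensor): Un-normalized object-space point(s) of shape [..., 3], in mm; x, y in [-Inf, Inf], z in [-Inf, 0]. Required.
  method (str): "chief_ray" (centroid of rays traced through a half-size pupil; falls back to "pinhole" if no ray is valid) or "pinhole" (perspective projection, distortion ignored). Default: "chief_ray".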

Returns:

  psf_center (Tensor): Un-normalized PSF center(s) on the sensor plane, shape [..., 2].

Source code in src/geolens_pkg/psf_compute.py
@torch.no_grad()
def psf_center(self, points_obj, method="chief_ray"):
    """Compute reference PSF center (flipped to match the original point) for given point source.

    Args:
        points_obj: [..., 3] un-normalized point in object space (mm), with x, y
            in [-Inf, Inf] and z in [-Inf, 0].
        method: "chief_ray" or "pinhole". Defaults to "chief_ray".

    Returns:
        psf_center: [..., 2] un-normalized PSF center on the sensor plane.
    """
    if method == "chief_ray":
        # Shrink the pupil and calculate green light centroid ray as the chief ray
        ray = self.sample_from_points(points_obj, scale_pupil=0.5, num_rays=SPP_CALC)
        ray = self.trace2sensor(ray)
        if not ray.is_valid.any():
            logging.warning(
                "No valid chief ray for PSF center; falling back to pinhole model."
            )
            return self.psf_center(points_obj, method="pinhole")
        psf_center = ray.centroid()
        psf_center = -psf_center[..., :2]  # shape [..., 2]

    elif method == "pinhole":
        # Pinhole camera perspective projection, distortion not considered
        # Warn if the nearest point source is closer than 100 mm
        if points_obj[..., 2].abs().min() < 100:
            logging.warning(
                "Point source is too close, pinhole model may be inaccurate for PSF center calculation."
            )
        tan_point_fov_x = -points_obj[..., 0] / points_obj[..., 2]
        tan_point_fov_y = -points_obj[..., 1] / points_obj[..., 2]
        psf_center_x = self.foclen * tan_point_fov_x
        psf_center_y = self.foclen * tan_point_fov_y
        psf_center = torch.stack([psf_center_x, psf_center_y], dim=-1).to(
            self.device
        )

    else:
        raise ValueError(
            f"Unsupported method for PSF center calculation: {method}."
        )

    return psf_center
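The pinhole branch reduces to a perspective projection: an object point at (x, y, z) with z < 0 maps to f·(-x/z, -y/z), which has the same sign as the object point (the flipped convention of psf_center). Here is a dependency-free sketch. pinhole_center is a hypothetical name, and the z < 0 guard is an addition that is not in the source. Like the source, it ignores distortion.

```python
def pinhole_center(x, y, z, foclen):
    """Pinhole PSF center (mm) for an object point at (x, y, z) mm, with z < 0.

    Mirrors the 'pinhole' branch of psf_center: the result is in flipped
    coordinates, so it has the same sign as the object point.
    """
    if z >= 0:  # guard added for this sketch; not present in the source
        raise ValueError("object must lie in front of the lens (z < 0)")
    tan_x = -x / z  # tangent of the field angle in x
    tan_y = -y / z  # tangent of the field angle in y
    return foclen * tan_x, foclen * tan_y


# A point 100 mm off-axis at 1 m, imaged by an assumed 50 mm lens
print(pinhole_center(100.0, 0.0, -1000.0, 50.0))
```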

psf(point, ks, wvln)

Compute the point spread function. Dispatches to geometric or coherent methods.

psf_geometric(point, ks, wvln)

Geometric PSF via spot diagram binning.

psf_coherent(point, ks, wvln)

Coherent PSF via pupil wavefront propagation.

psf_huygens(points, ks, wvln, spp, recenter)

Huygens PSF via summation of spherical waves emitted by secondary sources on the exit pupil.

psf_map(depth, grid, ks, spp, wvln, recenter)

Compute PSFs across a grid of field positions.

psf_center(points_obj, method)

Compute the PSF center position for a given point source.


Evaluation

src.geolens_pkg.eval.GeoLensEval

Mixin that adds classical optical evaluation methods to GeoLens.

This class is never instantiated on its own. It is mixed into GeoLens via multiple inheritance, so every method can access lens geometry (self.d_sensor, self.rfov_eff, …) and ray-tracing routines (self.trace(), self.trace2sensor(), …) directly through self.
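Here is a toy version of this mixin pattern. All class and attribute names are hypothetical stand-ins, not the GeoLens classes. The mixin defines a method that reads attributes it never sets itself, and the composed class supplies them.

```python
class LensBase:
    """Stand-in for the lens core: owns the geometry attributes."""

    def __init__(self, foclen, fnum):
        self.foclen = foclen  # focal length in mm
        self.fnum = fnum      # F-number


class EvalMixin:
    """Never instantiated alone; relies on attributes provided by LensBase."""

    def entrance_pupil_diameter(self):
        # F-number N = f / D, so D = f / N, read through self
        return self.foclen / self.fnum


class ToyLens(EvalMixin, LensBase):
    """Composition via multiple inheritance, in the spirit of GeoLens."""


lens = ToyLens(foclen=50.0, fnum=2.0)
print(lens.entrance_pupil_diameter())  # 25.0
```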

All evaluation functions follow the same pattern:
  1. Sample rays from object space (parallel / grid / radial).
  2. Trace rays through the lens (self.trace or self.trace2sensor).
  3. Analyze ray positions / directions at the sensor plane.
  4. Optionally produce a matplotlib figure saved to disk.

Results are aligned in accuracy with Zemax OpticStudio for the same lens prescriptions and ray-sampling densities.

Attributes consumed from GeoLens (via self):
  • d_sensor (float): Axial position of the sensor plane (mm).
  • sensor_size (tuple[float, float]): Sensor (width, height) in mm.
  • pixel_size (float): Pixel pitch in mm.
  • sensor_res (tuple[int, int]): Sensor resolution (H, W) in pixels.
  • rfov_eff (float): Effective half field-of-view in radians (pinhole model).
  • rfov (float): Half field-of-view in radians (ray-traced).
  • foclen (float): Equivalent focal length in mm.
  • fnum (float): F-number.
  • aper_idx (int): Index of the aperture stop surface.
  • device (torch.device): Compute device (CPU / CUDA).

spot_points

spot_points(points, num_rays=SPP_PSF, wvln=DEFAULT_WAVE)

Trace rays from object points to sensor and return the traced Ray.

Samples rays from each physical object point toward the entrance pupil, traces through all lens surfaces (refraction + clipping), and returns the resulting Ray object on the sensor plane.

This is the shared computational core for spot diagrams (draw_spot_radial, draw_spot_map) and RMS error maps (rms_map, rms_map_rgb).

Algorithm
  1. self.sample_from_points(points, num_rays, wvln) generates a fan of num_rays rays per object point, aimed at the entrance pupil.
  2. self.trace2sensor() propagates through all surfaces and clips vignetted rays.

Parameters:

  points (Tensor): Physical 3D object-space coordinates of shape [..., 3] (mm). Supported layouts:
    • [3]: single point.
    • [N, 3]: N points (e.g. radial field positions).
    • [H, W, 3]: 2-D field grid.
    Generated by self.point_source_grid(normalized=False) for grid sampling, or self.point_source_radial(normalized=False) for radial sampling. Required.
  num_rays (int): Number of rays sampled per object point. Default: SPP_PSF.
  wvln (float): Wavelength in micrometers. Default: DEFAULT_WAVE.

Returns:

  Ray: Traced ray on the sensor plane. Positions ray.o have shape [..., num_rays, 3] and the validity mask ray.is_valid has shape [..., num_rays]. Use ray.o[..., :2] for transverse positions; ray.centroid() gives the weighted centroid.

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def spot_points(self, points, num_rays=SPP_PSF, wvln=DEFAULT_WAVE):
    """Trace rays from object points to sensor and return the traced Ray.

    Samples rays from each physical object point toward the entrance pupil,
    traces through all lens surfaces (refraction + clipping), and returns
    the resulting Ray object on the sensor plane.

    This is the shared computational core for spot diagrams
    (``draw_spot_radial``, ``draw_spot_map``) and RMS error maps
    (``rms_map``, ``rms_map_rgb``).

    Algorithm:
        1. ``self.sample_from_points(points, num_rays, wvln)`` generates a
           fan of ``num_rays`` rays per object point, aimed at the entrance
           pupil.
        2. ``self.trace2sensor()`` propagates through all surfaces and
           clips vignetted rays.

    Args:
        points (torch.Tensor): Physical 3D object-space coordinates with
            shape ``[..., 3]`` (mm).  Supported layouts:
            - ``[3]`` — single point.
            - ``[N, 3]`` — N points (e.g. radial field positions).
            - ``[H, W, 3]`` — 2-D field grid.
            Generated by ``self.point_source_grid(normalized=False)`` for
            grid sampling, or ``self.point_source_radial(normalized=False)``
            for radial sampling.
        num_rays (int): Number of rays sampled per object point.
            Defaults to ``SPP_PSF``.
        wvln (float): Wavelength in micrometers.
            Defaults to ``DEFAULT_WAVE``.

    Returns:
        Ray: Traced ray on the sensor plane, with shape
            ``[..., num_rays, 3]`` for positions and ``[..., num_rays]``
            for validity mask. Use ``ray.o[..., :2]`` for transverse
            positions and ``ray.is_valid`` for the validity mask.
            ``ray.centroid()`` gives the weighted centroid.
    """
    ray = self.sample_from_points(points=points, num_rays=num_rays, wvln=wvln)
    return self.trace2sensor(ray)
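Downstream code combines ray.o with ray.is_valid so that vignetted rays do not bias the result. Below is a NumPy sketch of a validity-masked centroid over the ray axis, assuming the shapes documented above. masked_centroid is a hypothetical helper, and the library's ray.centroid() may weight rays differently.

```python
import numpy as np


def masked_centroid(o, valid, eps=1e-9):
    """Centroid of valid rays.

    o:     [..., num_rays, 3] ray positions on the sensor
    valid: [..., num_rays] validity mask (1 = reached the sensor, 0 = clipped)
    returns [..., 3]
    """
    w = valid[..., None].astype(float)  # [..., num_rays, 1]
    return (o * w).sum(axis=-2) / (w.sum(axis=-2) + eps)


# Two field points, three rays each; the clipped ray at x = 100 is ignored
o = np.array([[[0.0, 0.0, 5.0], [2.0, 0.0, 5.0], [100.0, 0.0, 5.0]],
              [[1.0, 1.0, 5.0], [1.0, 3.0, 5.0], [1.0, 5.0, 5.0]]])
valid = np.array([[1, 1, 0], [1, 1, 1]])
print(masked_centroid(o, valid))
```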

draw_spot_radial

draw_spot_radial(save_name='./lens_spot_radial.png', num_fov=5, depth=DEPTH, num_rays=SPP_PSF, wvln_list=WAVE_RGB, direction='y', show=False)

Draw spot diagrams at evenly spaced field positions along a chosen direction.

A spot diagram visualizes the transverse ray-intercept distribution on the sensor plane for a point source at a given field angle and depth. It reveals the combined effect of all aberrations (spherical, coma, astigmatism, field curvature, chromatic, …).

Algorithm

  1. Place num_fov object-space points along the chosen direction, at evenly spaced heights from 0 up to 98% of abs(depth) * tan(self.rfov), so that the ray-traced field of view covers the full sensor even for lenses with significant distortion.
  2. For each wavelength in wvln_list, self.spot_points() samples rays and traces them to the sensor.
  3. Valid ray (x, y) positions are scatter-plotted, one subplot per field position.
  All wavelengths are overlaid in a single figure with RGB coloring.

Parameters:

  save_name (str): File path for the output PNG. Default: './lens_spot_radial.png'.
  num_fov (int): Number of field positions sampled uniformly from on-axis (0) to full field. Default: 5.
  depth (float): Object distance in mm (negative = real object); float("inf") is replaced by DEPTH. Default: DEPTH.
  num_rays (int): Rays per field position per wavelength. Default: SPP_PSF.
  wvln_list (list[float]): Wavelengths in micrometers. Default: WAVE_RGB (red, green, blue).
  direction (str): Sampling direction: "y" (meridional), "x" (sagittal), or "diagonal" (45°); any other value is treated as diagonal. Default: 'y'.
  show (bool): If True, display the figure interactively instead of saving to disk. Default: False.

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def draw_spot_radial(
    self,
    save_name="./lens_spot_radial.png",
    num_fov=5,
    depth=DEPTH,
    num_rays=SPP_PSF,
    wvln_list=WAVE_RGB,
    direction="y",
    show=False,
):
    """Draw spot diagrams at evenly spaced field positions along a chosen direction.

    A *spot diagram* visualizes the transverse ray-intercept distribution on
    the sensor plane for a point source at a given field angle and depth.
    It reveals the combined effect of all aberrations (spherical, coma,
    astigmatism, field curvature, chromatic, …).

    Algorithm:
        1. Place ``num_fov`` object-space points along the chosen direction at
           evenly spaced heights from 0 to 98% of ``abs(depth) * tan(self.rfov)``.
        2. For each wavelength in ``wvln_list``, ``self.spot_points()``
           samples rays and traces them to the sensor.
        3. Valid ray (x, y) positions are scatter-plotted per subplot.
        All wavelengths are overlaid in a single figure with RGB coloring.

    Args:
        save_name (str): File path for the output PNG.
            Defaults to ``'./lens_spot_radial.png'``.
        num_fov (int): Number of field positions sampled uniformly from
            on-axis (0) to full-field. Defaults to 5.
        depth (float): Object distance in mm (negative = real object).
            Defaults to ``DEPTH``.
        num_rays (int): Rays per field position per wavelength.
            Defaults to ``SPP_PSF``.
        wvln_list (list[float]): Wavelengths in micrometers.
            Defaults to ``WAVE_RGB`` (red, green, blue).
        direction (str): Sampling direction —
            ``"y"`` (meridional, default), ``"x"`` (sagittal),
            ``"diagonal"`` (45°).
        show (bool): If ``True``, display the figure interactively instead
            of saving to disk. Defaults to ``False``.
    """
    assert isinstance(wvln_list, list), "wvln_list must be a list"
    if depth == float("inf"):
        depth = DEPTH

    # Generate physical object-space points using ray-traced FoV to cover
    # the full sensor, even for lenses with significant distortion.
    max_obj_height = abs(depth) * float(np.tan(self.rfov)) * 0.98
    r = torch.linspace(0, 1.0, num_fov, device=self.device)

    if direction == "y":
        px, py = torch.zeros_like(r), r * max_obj_height
    elif direction == "x":
        px, py = r * max_obj_height, torch.zeros_like(r)
    else:  # diagonal
        px, py = r * max_obj_height, r * max_obj_height

    z = torch.full_like(px, depth)
    points = torch.stack([px, py, z], dim=-1)

    # Prepare figure
    fig, axs = plt.subplots(1, num_fov, figsize=(num_fov * 3.5, 3))
    axs = np.atleast_1d(axs)

    # Trace and draw each wavelength separately, overlaying results
    for wvln_idx, wvln in enumerate(wvln_list):
        ray = self.spot_points(points, num_rays=num_rays, wvln=wvln)
        ray_o = ray.o[..., :2].cpu().numpy()
        ray_valid_np = ray.is_valid.cpu().numpy()

        color = RGB_COLORS[wvln_idx % len(RGB_COLORS)]

        # Plot multiple spot diagrams in one figure
        for i in range(num_fov):
            valid = ray_valid_np[i, :]
            xi, yi = ray_o[i, :, 0], ray_o[i, :, 1]

            # Filter valid rays
            mask = valid > 0
            x_valid, y_valid = xi[mask], yi[mask]

            # Plot valid ray hits for this wavelength
            axs[i].scatter(x_valid, y_valid, 2, color=color, alpha=0.5)
            axs[i].set_aspect("equal", adjustable="datalim")
            axs[i].tick_params(axis="both", which="major", labelsize=6)

    if show:
        plt.show()
    else:
        assert save_name.endswith(".png"), "save_name must end with .png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)
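The field sampling at the top of draw_spot_radial can be reproduced without a lens object. In this dependency-free sketch, radial_field_points is a hypothetical name, and rfov is passed in explicitly instead of being read from self.rfov. It places num_fov points at evenly spaced object heights up to 98% of abs(depth)·tan(rfov).

```python
import math


def radial_field_points(depth, rfov, num_fov=5, direction="y", margin=0.98):
    """Object-space (x, y, z) points in mm, mirroring draw_spot_radial.

    depth: object distance in mm (negative = real object)
    rfov:  ray-traced half field of view in radians
    """
    max_h = abs(depth) * math.tan(rfov) * margin
    # Evenly spaced in object height (not in field angle), from 0 to max_h inclusive
    if num_fov > 1:
        heights = [max_h * k / (num_fov - 1) for k in range(num_fov)]
    else:
        heights = [0.0]
    points = []
    for h in heights:
        if direction == "y":
            x, y = 0.0, h
        elif direction == "x":
            x, y = h, 0.0
        else:  # any other value is treated as diagonal, as in the source
            x, y = h, h
        points.append((x, y, depth))
    return points


# Assumed values: object at 1 m, 30 degree ray-traced half field of view
for p in radial_field_points(-1000.0, math.radians(30.0)):
    print(p)
```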

draw_spot_map

draw_spot_map(save_name='./lens_spot_map.png', num_grid=5, depth=DEPTH, num_rays=SPP_PSF, wvln_list=WAVE_RGB, show=False)

Draw a 2-D grid of spot diagrams across the full field of view.

Unlike draw_spot_radial (which samples only a radial slice), this method samples a num_grid × num_grid grid of field positions covering both the x (sagittal) and y (meridional) axes, revealing off-axis aberrations that are invisible in a 1-D radial scan.

Algorithm

  1. self.point_source_grid(normalized=False) creates physical object-space grid points of shape [grid_h, grid_w, 3] (once, shared by all wavelengths).
  2. For each wavelength in wvln_list, self.spot_points() samples rays and traces them to the sensor.
  3. Valid (x, y) positions are scatter-plotted in the corresponding subplot of the num_grid × num_grid figure.
  All wavelengths are overlaid with RGB coloring.

Parameters:

  save_name (str): File path for the output PNG. Default: './lens_spot_map.png'.
  num_grid (int | tuple[int, int]): Number of grid points along each axis, given as (grid_w, grid_h) when a tuple; total subplots = grid_w * grid_h. Default: 5.
  depth (float): Object distance in mm. Default: DEPTH.
  num_rays (int): Rays per grid cell per wavelength. Default: SPP_PSF.
  wvln_list (list[float]): Wavelengths in micrometers. Default: WAVE_RGB.
  show (bool): If True, display interactively instead of saving to disk. Default: False.

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def draw_spot_map(
    self,
    save_name="./lens_spot_map.png",
    num_grid=5,
    depth=DEPTH,
    num_rays=SPP_PSF,
    wvln_list=WAVE_RGB,
    show=False,
):
    """Draw a 2-D grid of spot diagrams across the full field of view.

    Unlike ``draw_spot_radial`` (which samples only a radial slice),
    this method samples a ``num_grid × num_grid`` grid of field positions
    covering both the x (sagittal) and y (meridional) axes, revealing
    off-axis aberrations that are invisible in a 1-D radial scan.

    Algorithm:
        For each wavelength in ``wvln_list``:
            1. ``self.point_source_grid(normalized=False)`` creates physical
               object-space grid points, shape ``[grid_h, grid_w, 3]``.
            2. ``self.spot_points()`` samples rays and traces to sensor.
            3. Valid (x, y) positions are scatter-plotted in the
               corresponding subplot of the ``num_grid × num_grid`` figure.
        All wavelengths are overlaid with RGB coloring.

    Args:
        save_name (str): File path for the output PNG.
            Defaults to ``'./lens_spot_map.png'``.
        num_grid (int | tuple[int, int]): Number of grid points along each
            axis. Total subplots = ``grid_w * grid_h``. Defaults to 5.
        depth (float): Object distance in mm. Defaults to ``DEPTH``.
        num_rays (int): Rays per grid cell per wavelength.
            Defaults to ``SPP_PSF``.
        wvln_list (list[float]): Wavelengths in micrometers.
            Defaults to ``WAVE_RGB``.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    assert isinstance(wvln_list, list), "wvln_list must be a list"
    if isinstance(num_grid, int):
        num_grid = (num_grid, num_grid)

    # Generate physical object-space grid points, shape [grid_h, grid_w, 3]
    points = self.point_source_grid(depth=depth, grid=num_grid, normalized=False)

    grid_w, grid_h = num_grid
    fig, axs = plt.subplots(
        grid_h, grid_w, figsize=(grid_w * 3, grid_h * 3)
    )
    axs = np.atleast_2d(axs)

    # Loop wavelengths and overlay scatters
    for wvln_idx, wvln in enumerate(wvln_list):
        ray = self.spot_points(points, num_rays=num_rays, wvln=wvln)

        # Convert to numpy, shape [grid_h, grid_w, num_rays, 2]
        ray_o = -ray.o[..., :2].cpu().numpy()
        ray_valid_np = ray.is_valid.cpu().numpy()

        color = RGB_COLORS[wvln_idx % len(RGB_COLORS)]

        # Draw per grid cell
        for i in range(grid_h):
            for j in range(grid_w):
                valid = ray_valid_np[i, j, :]
                xi, yi = ray_o[i, j, :, 0], ray_o[i, j, :, 1]

                # Filter valid rays
                mask = valid > 0
                x_valid, y_valid = xi[mask], yi[mask]

                # Plot points for this wavelength
                axs[i, j].scatter(x_valid, y_valid, 2, color=color, alpha=0.5)
                axs[i, j].set_aspect("equal", adjustable="datalim")
                axs[i, j].tick_params(axis="both", which="major", labelsize=6)

    if show:
        plt.show()
    else:
        assert save_name.endswith(".png"), "save_name must end with .png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

rms_map

rms_map(num_grid=32, depth=DEPTH, wvln=DEFAULT_WAVE, center=None)

Compute per-field-position RMS spot radius for a single wavelength.

Traces SPP_PSF rays per grid cell and computes the root-mean-square distance of valid ray hits from a reference centroid. When center is None, each cell uses its own centroid (monochromatic blur). When an external center is provided (e.g. the green-channel centroid), the RMS includes the chromatic shift from that reference.

Algorithm
  1. self.point_source_grid(normalized=False) generates physical object points on a [num_grid, num_grid] field grid.
  2. self.spot_points() samples SPP_PSF rays per point and traces to sensor.
  3. If center is None, compute per-cell centroid c = mean(valid ray_xy); otherwise use the provided center.
  4. RMS = sqrt( mean( ||ray_xy - c||^2 ) ).
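Steps 3 and 4 can be written out for a single field position in plain Python. The sketch below (rms_spot is a hypothetical helper) uses only valid rays, measured either from their own centroid or from an externally supplied reference center, mirroring the two modes of rms_map. Like the source code, it returns the cell's own centroid in both modes.

```python
import math


def rms_spot(xy, valid, center=None):
    """RMS distance (same units as xy) of valid ray hits from a reference point.

    xy:     list of (x, y) ray intercepts on the sensor
    valid:  list of 0/1 flags, same length as xy
    center: optional (cx, cy); if None, the centroid of the valid rays is used
    Returns (rms, own_centroid).
    """
    pts = [p for p, v in zip(xy, valid) if v]
    n = len(pts)
    if n == 0:
        return 0.0, None  # no valid rays (the source guards with EPSILON instead)
    own = (sum(x for x, _ in pts) / n, sum(y for _, y in pts) / n)
    cx, cy = center if center is not None else own
    mse = sum((x - cx) ** 2 + (y - cy) ** 2 for x, y in pts) / n
    return math.sqrt(mse), own


# Four valid hits on a square of half-width 0.003 mm plus one clipped ray
xy = [(0.003, 0.003), (-0.003, 0.003), (-0.003, -0.003), (0.003, -0.003), (1.0, 1.0)]
valid = [1, 1, 1, 1, 0]
print(rms_spot(xy, valid))
```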

Parameters:

  num_grid (int | tuple[int, int]): Spatial resolution of the field sampling grid. Default: 32.
  depth (float): Object distance in mm. Default: DEPTH.
  wvln (float): Wavelength in micrometers. Default: DEFAULT_WAVE.
  center (Tensor | None): External reference centroid of shape [grid_h, grid_w, 2]. If None, each cell's own centroid is used. Default: None.

Returns:

  tuple[torch.Tensor, torch.Tensor]:
    • rms: RMS spot error map of shape [grid_h, grid_w], in mm.
    • centroid: Per-cell centroid of the traced rays, shape [grid_h, grid_w, 2] (returned even when an external center is given). Useful for passing as center to subsequent calls (e.g. in rms_map_rgb).

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def rms_map(self, num_grid=32, depth=DEPTH, wvln=DEFAULT_WAVE, center=None):
    """Compute per-field-position RMS spot radius for a single wavelength.

    Traces ``SPP_PSF`` rays per grid cell and computes the root-mean-square
    distance of valid ray hits from a reference centroid.  When ``center``
    is ``None``, each cell uses its own centroid (monochromatic blur).
    When an external ``center`` is provided (e.g. the green-channel
    centroid), the RMS includes the chromatic shift from that reference.

    Algorithm:
        1. ``self.point_source_grid(normalized=False)`` generates physical
           object points on a ``[num_grid, num_grid]`` field grid.
        2. ``self.spot_points()`` samples ``SPP_PSF`` rays per point and
           traces to sensor.
        3. If ``center`` is ``None``, compute per-cell centroid
           ``c = mean(valid ray_xy)``; otherwise use the provided ``center``.
        4. ``RMS = sqrt( mean( ||ray_xy - c||^2 ) )``.

    Args:
        num_grid (int | tuple[int, int]): Spatial resolution of the field
            sampling grid. Defaults to 32.
        depth (float): Object distance in mm. Defaults to ``DEPTH``.
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.
        center (torch.Tensor | None): External reference centroid with shape
            ``[grid_h, grid_w, 2]``.  If ``None``, each cell's own
            centroid is used. Defaults to ``None``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - **rms**: RMS spot error map, shape ``[grid_h, grid_w]``,
              in mm.
            - **centroid**: Per-cell centroid of the traced rays, shape
              ``[grid_h, grid_w, 2]`` (returned even when ``center`` is
              given).  Useful for passing as ``center`` to subsequent
              calls (e.g. in ``rms_map_rgb``).
    """
    if isinstance(num_grid, int):
        num_grid = (num_grid, num_grid)

    # Generate physical grid points and trace rays to sensor
    points = self.point_source_grid(depth=depth, grid=num_grid, normalized=False)
    ray = self.spot_points(points, num_rays=SPP_PSF, wvln=wvln)

    # Reuse Ray.centroid() — shape [grid_h, grid_w, 3], slice to [grid_h, grid_w, 2]
    centroid = ray.centroid()[..., :2]

    # Use external center if provided, otherwise own centroid
    ref = center if center is not None else centroid

    # RMS relative to reference, shape [grid_h, grid_w]
    ray_xy = ray.o[..., :2]
    ray_valid = ray.is_valid
    rms = torch.sqrt(
        (((ray_xy - ref.unsqueeze(-2)) ** 2).sum(-1) * ray_valid).sum(-1)
        / (ray_valid.sum(-1) + EPSILON)
    )

    return rms, centroid

rms_map_rgb

rms_map_rgb(num_grid=32, depth=DEPTH)

Compute per-field-position RMS spot radius for R, G, B wavelengths.

The RMS spot radius is a standard measure of geometrical image quality. For each field position in a num_grid × num_grid grid, this method traces SPP_PSF rays per wavelength and computes the root-mean-square distance of valid ray hits from a common reference centroid.

The reference centroid is the green-channel centroid. Using a common reference means the returned RMS values include lateral chromatic aberration (the shift between R/G/B centroids), making the map useful as a polychromatic image-quality metric.

Algorithm
  1. Call rms_map(wvln=green) to get the green RMS map and the green centroid.
  2. Call rms_map(wvln=red, center=green_centroid) and rms_map(wvln=blue, center=green_centroid) to measure R/B blur relative to the green reference.
  3. Stack as [R, G, B].
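The common green reference matters because a red spot that is perfectly sharp but laterally shifted has near-zero RMS about its own centroid, yet a nonzero RMS about the green centroid. The minimal NumPy sketch below uses made-up spot data (values in mm, chosen only for illustration), and rms_about is a hypothetical helper.

```python
import numpy as np


def rms_about(xy, center):
    """RMS distance of ray hits xy [num_rays, 2] from a reference center [2]."""
    return float(np.sqrt(((xy - center) ** 2).sum(axis=-1).mean()))


green = np.array([[0.0, 0.001], [0.0, -0.001], [0.001, 0.0], [-0.001, 0.0]])
red = np.array([[0.0, 0.004]] * 4)  # sharp red spot shifted by 4 um (lateral color)

green_centroid = green.mean(axis=0)
print(rms_about(red, red.mean(axis=0)))  # own centroid: about 0 (no blur)
print(rms_about(red, green_centroid))    # green reference: about 0.004 (includes the shift)
```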

Parameters:

  num_grid (int): Spatial resolution of the field sampling grid. Default: 32.
  depth (float): Object distance in mm. Default: DEPTH.

Returns:

  torch.Tensor: RMS spot error map of shape [3, num_grid, num_grid] (channels ordered R, G, B), in mm (same units as sensor coordinates).

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def rms_map_rgb(self, num_grid=32, depth=DEPTH):
    """Compute per-field-position RMS spot radius for R, G, B wavelengths.

    The RMS spot radius is a standard measure of geometrical image quality.
    For each field position in a ``num_grid × num_grid`` grid, this method
    traces ``SPP_PSF`` rays per wavelength and computes the root-mean-square
    distance of valid ray hits from a **common** reference centroid.

    The reference centroid is the green-channel centroid.  Using a common
    reference means the returned RMS values include *lateral chromatic
    aberration* (the shift between R/G/B centroids), making the map useful
    as a polychromatic image-quality metric.

    Algorithm:
        1. Call ``rms_map(wvln=green)`` to get the green RMS map **and**
           the green centroid.
        2. Call ``rms_map(wvln=red, center=green_centroid)`` and
           ``rms_map(wvln=blue, center=green_centroid)`` to measure R/B
           blur relative to the green reference.
        3. Stack as ``[R, G, B]``.

    Args:
        num_grid (int): Spatial resolution of the field sampling grid.
            Defaults to 32.
        depth (float): Object distance in mm. Defaults to ``DEPTH``.

    Returns:
        torch.Tensor: RMS spot error map with shape ``[3, num_grid, num_grid]``
            (channels ordered R, G, B). Units are mm (same as sensor
            coordinates).
    """
    # Green first to obtain the shared reference centroid
    rms_g, green_centroid = self.rms_map(
        num_grid=num_grid, depth=depth, wvln=WAVE_RGB[1]
    )

    # Red and blue relative to the green centroid
    rms_r, _ = self.rms_map(
        num_grid=num_grid, depth=depth, wvln=WAVE_RGB[0], center=green_centroid
    )
    rms_b, _ = self.rms_map(
        num_grid=num_grid, depth=depth, wvln=WAVE_RGB[2], center=green_centroid
    )

    return torch.stack([rms_r, rms_g, rms_b], dim=0)

calc_distortion_radial

calc_distortion_radial(num_points=GEO_GRID, wvln=DEFAULT_WAVE, plane='meridional', ray_aiming=True)

Compute fractional distortion at evenly-spaced field angles along the meridional (or sagittal) direction.

Distortion is defined as (h_actual - h_ideal) / h_ideal, where h_ideal = f * tan(theta) (rectilinear projection) and h_actual is the chief-ray image height on the sensor. A positive value means pincushion distortion; negative means barrel distortion.

This is the computational counterpart to draw_spot_radial: it samples num_points field angles uniformly from 0 to the half field of view self.rfov and returns both the sampled angles and the corresponding distortion values, making it easy to pair with other radial evaluation functions.

Algorithm
  1. Derive rfov_deg from self.rfov (radians → degrees).
  2. Sample num_points field angles uniformly in [0, rfov_deg]. The on-axis sample (0°) is replaced by a tiny positive angle to avoid 0/0.
  3. Compute h_ideal = foclen * tan(angle) for each sample.
  4. Trace the chief ray (via calc_chief_ray_infinite) through the full lens to the sensor plane.
  5. Extract h_actual from the appropriate transverse coordinate (x for sagittal, y for meridional).
  6. Return (h_actual - h_ideal) / h_ideal.
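Steps 3 and 6 reduce to a masked division. The sketch below assumes the chief-ray heights are already known (it skips the trace in steps 4 and 5) and uses a hypothetical equidistant (f-theta) lens, h = f * theta, which shows barrel distortion relative to the rectilinear f * tan(theta) model. Unlike the original, it simply masks the on-axis 0/0 case instead of evaluating it at a tiny positive angle; fractional_distortion is an illustrative name.

```python
import numpy as np


def fractional_distortion(angles_deg, actual_heights, foclen):
    """(h_actual - h_ideal) / h_ideal with h_ideal = f * tan(theta).

    Entries where h_ideal is ~0 (the on-axis sample) are set to 0, and the
    denominator is replaced there so no division by zero is evaluated.
    """
    ideal = foclen * np.tan(np.deg2rad(angles_deg))
    mask = np.abs(ideal) < 1e-12
    return np.where(mask, 0.0, (actual_heights - ideal) / np.where(mask, 1.0, ideal))


# Hypothetical equidistant ("f-theta") lens: h_actual = f * theta
foclen = 10.0  # mm
angles = np.linspace(0.0, 40.0, 5)  # degrees
h_actual = foclen * np.deg2rad(angles)

dist = fractional_distortion(angles, h_actual, foclen)
print(np.round(dist * 100, 2))  # percent; negative values indicate barrel distortion
```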

Parameters:

  num_points (int): Number of evenly-spaced field-angle samples from on-axis (0°) to full field (self.rfov). Default: GEO_GRID.
  wvln (float): Wavelength in micrometers. Default: DEFAULT_WAVE.
  plane (str): 'meridional' (y-axis) or 'sagittal' (x-axis). Default: 'meridional'.
  ray_aiming (bool): If True, the chief ray is aimed to pass through the center of the aperture stop (more accurate for wide-angle lenses). Default: True.

Returns:

  tuple[np.ndarray, np.ndarray]:
    - rfov_samples: Field angles in degrees, shape [num_points].
    - distortions: Fractional distortion at each angle, shape [num_points]. Dimensionless (multiply by 100 for percent).

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def calc_distortion_radial(
    self,
    num_points=GEO_GRID,
    wvln=DEFAULT_WAVE,
    plane="meridional",
    ray_aiming=True,
):
    """Compute fractional distortion at evenly-spaced field angles along the meridional (or sagittal) direction.

    Distortion is defined as ``(h_actual - h_ideal) / h_ideal``, where
    ``h_ideal = f * tan(theta)`` (rectilinear projection) and ``h_actual``
    is the chief-ray image height on the sensor.  A positive value means
    pincushion distortion; negative means barrel distortion.

    This is the computational counterpart to ``draw_spot_radial``: it
    samples ``num_points`` field angles uniformly from 0 to the half field
    of view ``self.rfov`` and returns both the sampled angles and the
    corresponding distortion values, making it easy to pair with other
    radial evaluation functions.

    Algorithm:
        1. Derive ``rfov_deg`` from ``self.rfov`` (radians → degrees).
        2. Sample ``num_points`` field angles uniformly in
           ``[0, rfov_deg]``.  The on-axis sample (0°) is replaced by a
           tiny positive angle to avoid 0/0.
        3. Compute ``h_ideal = foclen * tan(angle)`` for each sample.
        4. Trace the chief ray (via ``calc_chief_ray_infinite``) through the
           full lens to the sensor plane.
        5. Extract ``h_actual`` from the appropriate transverse coordinate
           (x for sagittal, y for meridional).
        6. Return ``(h_actual - h_ideal) / h_ideal``.

    Args:
        num_points (int): Number of evenly-spaced field-angle samples from
            on-axis (0°) to full field (``self.rfov``).
            Defaults to ``GEO_GRID``.
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.
        plane (str): ``'meridional'`` (y-axis) or ``'sagittal'`` (x-axis).
            Defaults to ``'meridional'``.
        ray_aiming (bool): If ``True``, the chief ray is aimed to pass
            through the center of the aperture stop (more accurate for
            wide-angle lenses). Defaults to ``True``.

    Returns:
        tuple[np.ndarray, np.ndarray]:
            - **rfov_samples**: Field angles in degrees, shape ``[num_points]``.
            - **distortions**: Fractional distortion at each angle, shape
              ``[num_points]``.  Dimensionless (multiply by 100 for
              percent).
    """
    rfov_deg = float(self.rfov) * 180.0 / np.pi

    # Sample field angles uniformly from 0 to rfov_deg.
    # For the on-axis point (FOV=0), distortion is 0/0.  We compute it at a
    # tiny positive angle to obtain the correct limit, which may be non-zero
    # when the sensor is not at the paraxial focus.
    rfov_samples = torch.linspace(0, rfov_deg, num_points)
    rfov_compute = rfov_samples.clone()
    if rfov_compute[0] == 0:
        rfov_compute[0] = min(0.01, rfov_samples[1].item() * 0.01)

    # Ideal image height: h_ideal = f * tan(theta)
    eff_foclen = float(self.foclen)
    ideal_imgh = eff_foclen * np.tan(rfov_compute.numpy() * np.pi / 180)

    # Trace chief rays to the sensor plane
    chief_ray_o, chief_ray_d = self.calc_chief_ray_infinite(
        rfov=rfov_compute, wvln=wvln, plane=plane, ray_aiming=ray_aiming
    )
    ray = Ray(chief_ray_o, chief_ray_d, wvln=wvln, device=self.device)
    ray, _ = self.trace(ray)
    t = (self.d_sensor - ray.o[..., 2]) / ray.d[..., 2]

    # Actual image height from the appropriate transverse coordinate
    if plane == "sagittal":
        actual_imgh = (ray.o[..., 0] + ray.d[..., 0] * t).abs()
    elif plane == "meridional":
        actual_imgh = (ray.o[..., 1] + ray.d[..., 1] * t).abs()
    else:
        raise ValueError(f"Invalid plane: {plane}")

    actual_imgh = actual_imgh.cpu().numpy()

    # Fractional distortion, with safe handling of the on-axis singularity
    ideal_imgh = np.asarray(ideal_imgh)
    mask = np.abs(ideal_imgh) < EPSILON
    distortions = np.where(
        mask, 0.0, (actual_imgh - ideal_imgh) / np.where(mask, 1.0, ideal_imgh)
    )

    return rfov_samples.numpy(), distortions

draw_distortion_radial

draw_distortion_radial(save_name=None, num_points=GEO_GRID, wvln=DEFAULT_WAVE, plane='meridional', ray_aiming=True, show=False)

Draw distortion-vs-field-angle curve in Zemax style.

Produces a plot with field angle on the y-axis and percent distortion on the x-axis, matching the layout convention used in Zemax OpticStudio. Useful for quick visual assessment of barrel / pincushion distortion.

Algorithm
  1. Call calc_distortion_radial to obtain field angles and fractional distortion values.
  2. Convert distortion to percent and plot.

Parameters:

  save_name (str | None): File path for the output PNG. If None, auto-generates './{plane}_distortion_inf.png'. Default: None.
  num_points (int): Number of field-angle samples. Default: GEO_GRID.
  wvln (float): Wavelength in micrometers. Default: DEFAULT_WAVE.
  plane (str): 'meridional' or 'sagittal'. Default: 'meridional'.
  ray_aiming (bool): Whether to use ray aiming for chief-ray computation. Default: True.
  show (bool): If True, display interactively instead of saving to file. Default: False.

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def draw_distortion_radial(
    self,
    save_name=None,
    num_points=GEO_GRID,
    wvln=DEFAULT_WAVE,
    plane="meridional",
    ray_aiming=True,
    show=False,
):
    """Draw distortion-vs-field-angle curve in Zemax style.

    Produces a plot with field angle on the y-axis and percent distortion
    on the x-axis, matching the layout convention used in Zemax OpticStudio.
    Useful for quick visual assessment of barrel / pincushion distortion.

    Algorithm:
        1. Call ``calc_distortion_radial`` to obtain field angles and
           fractional distortion values.
        2. Convert distortion to percent and plot.

    Args:
        save_name (str | None): File path for the output PNG.  If ``None``,
            auto-generates ``'./{plane}_distortion_inf.png'``.
        num_points (int): Number of field-angle samples.
            Defaults to ``GEO_GRID``.
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.
        plane (str): ``'meridional'`` or ``'sagittal'``.
            Defaults to ``'meridional'``.
        ray_aiming (bool): Whether to use ray aiming for chief-ray
            computation. Defaults to ``True``.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    rfov_deg = float(self.rfov) * 180.0 / np.pi

    # Calculate distortion at evenly-spaced field angles
    rfov_samples, distortions = self.calc_distortion_radial(
        num_points=num_points, wvln=wvln, plane=plane, ray_aiming=ray_aiming
    )

    # Convert to percentage and handle NaN
    values = np.nan_to_num(distortions * 100, nan=0.0).tolist()

    # Create figure
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_title(f"{plane} Surface Distortion")

    # Draw distortion curve
    ax.plot(values, rfov_samples, linestyle="-", color="g", linewidth=1.5)

    # Draw reference line (vertical line)
    ax.axvline(x=0, color="k", linestyle="-", linewidth=0.8)

    # Set grid
    ax.grid(True, color="gray", linestyle="-", linewidth=0.5, alpha=1)

    # Dynamically adjust x-axis range
    value = max(abs(v) for v in values)
    margin = value * 0.2  # 20% margin
    x_min, x_max = -max(0.2, value + margin), max(0.2, value + margin)

    # Set ticks
    x_ticks = np.linspace(-value, value, 3)
    y_ticks = np.linspace(0, rfov_deg, 3)

    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)

    # Format tick labels
    x_labels = [f"{x:.1f}%" for x in x_ticks]
    y_labels = [f"{y:.1f}" for y in y_ticks]

    ax.set_xticklabels(x_labels)
    ax.set_yticklabels(y_labels)

    # Set axis labels
    ax.set_xlabel("Distortion (%)")
    ax.set_ylabel("Field of View (degrees)")

    # Set axis range
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(0, rfov_deg)

    if show:
        plt.show()
    else:
        if save_name is None:
            save_name = f"./{plane}_distortion_inf.png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

calc_distortion_map

calc_distortion_map(num_grid=16, depth=DEPTH, wvln=DEFAULT_WAVE)

Compute a 2-D distortion grid mapping ideal to actual image positions.

For each cell in a num_grid × num_grid field grid, rays are traced to the sensor and their centroid is computed. The centroid is then normalized to [-1, 1] sensor coordinates, producing a map that shows how each ideal image point is displaced by lens distortion.

This map can be used with torch.nn.functional.grid_sample to warp or unwarp rendered images.
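calc_distortion_map returns a coarse num_grid × num_grid grid of positions in [-1, 1]. To warp or unwarp an image with it, one would typically upsample the grid to the image resolution (for example with torch.nn.functional.interpolate) and pass it as the grid argument of torch.nn.functional.grid_sample. The dependency-free NumPy sketch below shows only the sampling step, under stated assumptions: bilinear interpolation, the align_corners=True convention, and border clamping. grid_sample_bilinear is a hypothetical helper, and whether the result undistorts or distorts depends on which image you sample and on the map's sign conventions.

```python
import numpy as np


def grid_sample_bilinear(image, grid):
    """Bilinearly sample `image` [H, W] at normalized positions `grid` [h, w, 2].

    grid[..., 0] is x (columns) and grid[..., 1] is y (rows), both in [-1, 1],
    following the align_corners=True convention of
    torch.nn.functional.grid_sample. Out-of-range positions are clamped to
    the border (like padding_mode="border").
    """
    H, W = image.shape
    # Map normalized coordinates to fractional pixel indices
    x = np.clip((grid[..., 0] + 1) / 2 * (W - 1), 0, W - 1)
    y = np.clip((grid[..., 1] + 1) / 2 * (H - 1), 0, H - 1)
    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)
    x1 = np.minimum(x0 + 1, W - 1)
    y1 = np.minimum(y0 + 1, H - 1)
    wx, wy = x - x0, y - y0
    # Interpolate along x on the two neighboring rows, then along y
    top = image[y0, x0] * (1 - wx) + image[y0, x1] * wx
    bottom = image[y1, x0] * (1 - wx) + image[y1, x1] * wx
    return top * (1 - wy) + bottom * wy


# Identity grid reproduces the image; shifting x by half a pixel on the
# linear ramp f(y, x) = 10 * y + x adds 0.5 (bilinear is exact for linear data)
H, W = 4, 5
image = np.add.outer(10.0 * np.arange(H), np.arange(W, dtype=float))
ys, xs = np.meshgrid(np.linspace(-1, 1, H), np.linspace(-1, 1, W), indexing="ij")
identity = np.stack([xs, ys], axis=-1)
shifted = identity.copy()
shifted[..., 0] += 1.0 / (W - 1)  # half a pixel in normalized units

out_identity = grid_sample_bilinear(image, identity)
out_shifted = grid_sample_bilinear(image, shifted)  # last column clamps to the border
print(out_shifted)
```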

Parameters:

  num_grid (int): Grid resolution along each axis. Default: 16.
  depth (float): Object distance in mm. Default: DEPTH.
  wvln (float): Wavelength in micrometers. Default: DEFAULT_WAVE.

Returns:

  torch.Tensor: Distortion grid with shape [num_grid, num_grid, 2]. Each entry (dx, dy) is in normalized sensor coordinates [-1, 1], representing the actual centroid position for the corresponding ideal grid position.

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def calc_distortion_map(self, num_grid=16, depth=DEPTH, wvln=DEFAULT_WAVE):
    """Compute a 2-D distortion grid mapping ideal to actual image positions.

    For each cell in a ``num_grid × num_grid`` field grid, rays are traced
    to the sensor and their centroid is computed.  The centroid is then
    normalized to ``[-1, 1]`` sensor coordinates, producing a map that
    shows how each ideal image point is displaced by lens distortion.

    This map can be used with ``torch.nn.functional.grid_sample`` to warp
    or unwarp rendered images.

    Args:
        num_grid (int): Grid resolution along each axis. Defaults to 16.
        depth (float): Object distance in mm. Defaults to ``DEPTH``.
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.

    Returns:
        torch.Tensor: Distortion grid with shape ``[num_grid, num_grid, 2]``.
            Each entry ``(dx, dy)`` is in normalized sensor coordinates
            ``[-1, 1]``, representing the actual centroid position for the
            corresponding ideal grid position.
    """
    # Sample and trace rays, shape (grid_size, grid_size, num_rays, 3)
    ray = self.sample_grid_rays(depth=depth, num_grid=num_grid, wvln=wvln, uniform_fov=False)
    ray = self.trace2sensor(ray)

    # Calculate centroid of the rays, shape (grid_size, grid_size, 2)
    ray_xy = ray.centroid()[..., :2]
    x_dist = -ray_xy[..., 0] / self.sensor_size[1] * 2
    y_dist = ray_xy[..., 1] / self.sensor_size[0] * 2
    distortion_grid = torch.stack((x_dist, y_dist), dim=-1)
    return distortion_grid

distortion_center

distortion_center(points)

Compute the distorted image centroid for arbitrary normalized object points.

Given object points in normalized coordinates, this method converts them to physical object-space positions, traces rays from each point through the lens, and returns the ray centroid on the sensor in normalized [-1, 1] coordinates. This is the inverse mapping needed for distortion correction (unwarping).

Algorithm
  1. Convert normalized (x, y) ∈ [-1, 1] to physical object-space positions using self.calc_scale(depth) and self.sensor_size.
  2. self.sample_from_points() generates rays from each point.
  3. self.trace2sensor() propagates rays.
  4. Compute centroid and normalize back to [-1, 1].
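Steps 1 and 4 are pure coordinate conversions. The sketch below isolates them, assuming that calc_scale(depth) returns the object-space size per unit of sensor-space size at that depth, and replacing the ray trace (steps 2 and 3) with a hypothetical distortion-free, inverted image. Under those assumptions the round trip returns the input (x, y), the expected result for a lens without distortion; the helper names are illustrative only.

```python
import numpy as np


def normalized_to_object(points, scale, sensor_w, sensor_h):
    """Step 1: normalized (x, y, z) -> physical object-space (X, Y, z) in mm."""
    x, y, z = points[..., 0], points[..., 1], points[..., 2]
    return np.stack([x * scale * sensor_w / 2, y * scale * sensor_h / 2, z], axis=-1)


def sensor_to_normalized(centroid_xy, sensor_w, sensor_h):
    """Step 4: sensor centroid (mm) -> normalized [-1, 1], undoing image inversion."""
    c = -centroid_xy
    return np.stack([c[..., 0] / (sensor_w / 2), c[..., 1] / (sensor_h / 2)], axis=-1)


sensor_w, sensor_h = 6.0, 4.0  # mm
scale = 200.0  # assumed object-space mm per sensor-space mm at this depth
points = np.array([[0.0, 0.0, -1000.0], [0.5, -0.25, -1000.0], [1.0, 1.0, -1000.0]])

obj = normalized_to_object(points, scale, sensor_w, sensor_h)
ideal_centroid = -obj[..., :2] / scale  # stand-in for steps 2-3: inverted, undistorted image
result = sensor_to_normalized(ideal_centroid, sensor_w, sensor_h)
print(result)  # equals points[:, :2] for a distortion-free lens
```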

Parameters:

  points (Tensor): Normalized point source positions with shape [N, 3] or [..., 3]. x, y ∈ [-1, 1] encode the field position; z ∈ (-∞, 0] is the object depth in mm. Required.

Returns:

  torch.Tensor: Normalized distortion centroid positions with shape [N, 2] or [..., 2]. x, y ∈ [-1, 1].

Source code in src/geolens_pkg/eval.py
def distortion_center(self, points):
    """Compute the distorted image centroid for arbitrary normalized object points.

    Given object points in normalized coordinates, this method converts them
    to physical object-space positions, traces rays from each point through
    the lens, and returns the ray centroid on the sensor in normalized
    ``[-1, 1]`` coordinates.  This is the inverse mapping needed for
    distortion correction (unwarping).

    Algorithm:
        1. Convert normalized ``(x, y)`` ∈ [-1, 1] to physical object-space
           positions using ``self.calc_scale(depth)`` and ``self.sensor_size``.
        2. ``self.sample_from_points()`` generates rays from each point.
        3. ``self.trace2sensor()`` propagates rays.
        4. Compute centroid and normalize back to ``[-1, 1]``.

    Args:
        points (torch.Tensor): Normalized point source positions with shape
            ``[N, 3]`` or ``[..., 3]``.  ``x, y`` ∈ [-1, 1] encode the
            field position; ``z`` ∈ (-∞, 0] is the object depth in mm.

    Returns:
        torch.Tensor: Normalized distortion centroid positions with shape
            ``[N, 2]`` or ``[..., 2]``.  ``x, y`` ∈ [-1, 1].
    """
    sensor_w, sensor_h = self.sensor_size

    # Convert normalized points to object space coordinates
    depth = points[..., 2]
    scale = self.calc_scale(depth)
    points_obj_x = points[..., 0] * scale * sensor_w / 2
    points_obj_y = points[..., 1] * scale * sensor_h / 2
    points_obj = torch.stack([points_obj_x, points_obj_y, depth], dim=-1)

    # Sample rays and trace to sensor
    ray = self.sample_from_points(points=points_obj)
    ray = self.trace2sensor(ray)

    # Calculate centroid and normalize to [-1, 1]
    ray_center = -ray.centroid()  # shape [..., 3]
    distortion_center_x = ray_center[..., 0] / (sensor_w / 2)
    distortion_center_y = ray_center[..., 1] / (sensor_h / 2)
    distortion_center = torch.stack((distortion_center_x, distortion_center_y), dim=-1)
    return distortion_center

draw_distortion_map

draw_distortion_map(save_name=None, num_grid=16, depth=DEPTH, wvln=DEFAULT_WAVE, show=False)

Draw a scatter plot of the distortion grid.

Visualizes the output of calc_distortion_map() as a scatter plot on [-1, 1] normalized sensor coordinates. An undistorted lens would show a perfect rectilinear grid; deviations reveal barrel or pincushion distortion.

Parameters:

  save_name (str | None): File path for the output PNG. If None, auto-generates './distortion_{depth}.png'. Default: None.
  num_grid (int): Grid resolution per axis. Default: 16.
  depth (float): Object distance in mm. Default: DEPTH.
  wvln (float): Wavelength in micrometers. Default: DEFAULT_WAVE.
  show (bool): If True, display interactively instead of saving to file. Default: False.

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def draw_distortion_map(
    self, save_name=None, num_grid=16, depth=DEPTH, wvln=DEFAULT_WAVE, show=False
):
    """Draw a scatter plot of the distortion grid.

    Visualizes the output of ``calc_distortion_map()`` as a scatter plot on
    ``[-1, 1]`` normalized sensor coordinates.  An undistorted lens would
    show a perfect rectilinear grid; deviations reveal barrel or pincushion
    distortion.

    Args:
        save_name (str | None): File path for the output PNG.  If ``None``,
            auto-generates ``'./distortion_{depth}.png'``.
        num_grid (int): Grid resolution per axis. Defaults to 16.
        depth (float): Object distance in mm. Defaults to ``DEPTH``.
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    # Ray tracing to calculate distortion map
    distortion_grid = self.calc_distortion_map(num_grid=num_grid, depth=depth, wvln=wvln)
    x1 = distortion_grid[..., 0].cpu().numpy()
    y1 = distortion_grid[..., 1].cpu().numpy()

    # Draw image
    fig, ax = plt.subplots()
    ax.set_title("Lens distortion")
    ax.scatter(x1, y1, s=2)
    ax.axis("scaled")
    ax.grid(True)

    # Add grid lines based on grid_size
    ax.set_xticks(np.linspace(-1, 1, num_grid))
    ax.set_yticks(np.linspace(-1, 1, num_grid))

    if show:
        plt.show()
    else:
        depth_str = "inf" if depth == float("inf") else f"{-depth}mm"
        if save_name is None:
            save_name = f"./distortion_{depth_str}.png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

mtf

mtf(fov, wvln=DEFAULT_WAVE)

Compute the geometric MTF at a single field position.

The Modulation Transfer Function describes how well the lens preserves contrast as a function of spatial frequency. MTF equals 1 at zero frequency (full contrast) and generally decreases as spatial frequency increases; for a diffraction-limited system it reaches 0 at the optical cutoff frequency. The sensor's Nyquist frequency (0.5 / pixel_size) is the highest frequency the pixel grid can sample without aliasing.

This implementation uses the geometric (ray-based) approach:
  1. Compute the PSF at the given field position via self.psf().
  2. Convert PSF → MTF via psf2mtf() (project onto tangential and sagittal axes, then take the magnitude of the 1-D FFT).

Tangential MTF captures resolution in the meridional (radial) direction; sagittal MTF captures resolution perpendicular to it. The difference between the two indicates astigmatism.
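To see how the tangential/sagittal split exposes an anisotropic blur, the sketch below feeds a hypothetical elliptical Gaussian PSF (wider along y, the tangential direction) through the same line-spread-function projection that psf2mtf uses. It is a NumPy stand-in for self.psf() plus psf2mtf(), not the GeoLens code path; line_spread_mtf is an illustrative name.

```python
import numpy as np


def line_spread_mtf(psf, pixel_size):
    """Tangential/sagittal MTF of a square PSF via line-spread functions."""
    lsf_tan = psf.sum(axis=1)  # integrate along x -> function of y (rows)
    lsf_sag = psf.sum(axis=0)  # integrate along y -> function of x (columns)
    mtf_tan = np.abs(np.fft.rfft(lsf_tan))
    mtf_sag = np.abs(np.fft.rfft(lsf_sag))
    freq = np.fft.rfftfreq(psf.shape[1], d=pixel_size)  # cycles/mm
    # Drop DC and normalize so that MTF(0) == 1
    return freq[1:], mtf_tan[1:] / mtf_tan[0], mtf_sag[1:] / mtf_sag[0]


# Hypothetical elliptical Gaussian PSF: sigma_y = 6 um (tangential), sigma_x = 3 um
ks, pixel_size = 64, 0.002  # pixels, mm
coords = (np.arange(ks) - ks / 2) * pixel_size
yy, xx = np.meshgrid(coords, coords, indexing="ij")
psf = np.exp(-(xx**2 / (2 * 0.003**2) + yy**2 / (2 * 0.006**2)))

freq, mtf_tan, mtf_sag = line_spread_mtf(psf, pixel_size)
print(mtf_tan[:5])  # tangential contrast drops faster ...
print(mtf_sag[:5])  # ... than sagittal contrast
```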

Parameters:

  fov (float): Field angle in radians, from 0 (on-axis) to self.rfov_eff (edge of the effective field). Internally normalized to the point [0, -fov / self.rfov_eff, DEPTH]. Required.
  wvln (float): Wavelength in micrometers. Default: DEFAULT_WAVE.

Returns:

  tuple[np.ndarray, np.ndarray, np.ndarray]:
    - freq: Spatial frequency axis in cycles/mm (positive frequencies only, excluding DC).
    - mtf_tan: Tangential (meridional) MTF values, normalized so that MTF → 1 at low frequency.
    - mtf_sag: Sagittal MTF values, same normalization.

Source code in src/geolens_pkg/eval.py
def mtf(self, fov, wvln=DEFAULT_WAVE):
    """Compute the geometric MTF at a single field position.

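    The *Modulation Transfer Function* describes how well the lens preserves
    contrast as a function of spatial frequency.  MTF equals 1 at zero
    frequency (full contrast) and generally decreases as spatial frequency
    increases; for a diffraction-limited system it reaches 0 at the optical
    cutoff frequency.  The sensor's Nyquist frequency (``0.5 / pixel_size``)
    is the highest frequency the pixel grid can sample without aliasing.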

    This implementation uses the *geometric* (ray-based) approach:
        1. Compute the PSF at the given field position via ``self.psf()``.
        2. Convert PSF → MTF via ``psf2mtf()`` (project onto tangential and
           sagittal axes, then take the magnitude of the 1-D FFT).

    Tangential MTF captures resolution in the meridional (radial) direction;
    sagittal MTF captures resolution perpendicular to it.  The difference
    between the two indicates astigmatism.

    Args:
        fov (float): Field angle in radians, from 0 (on-axis) to
            ``self.rfov_eff`` (edge of the effective field).  Internally
            normalized to the point ``[0, -fov / self.rfov_eff, DEPTH]``.
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]:
            - **freq**: Spatial frequency axis in cycles/mm (positive
              frequencies only, excluding DC).
            - **mtf_tan**: Tangential (meridional) MTF values, normalized
              so that MTF → 1 at low frequency.
            - **mtf_sag**: Sagittal MTF values, same normalization.
    """
    point = [0, -fov / self.rfov_eff, DEPTH]
    psf = self.psf(points=point, recenter=True, wvln=wvln)
    freq, mtf_tan, mtf_sag = self.psf2mtf(psf, pixel_size=self.pixel_size)
    return freq, mtf_tan, mtf_sag

psf2mtf staticmethod

psf2mtf(psf, pixel_size)

Convert a 2-D point-spread function to tangential and sagittal MTF curves.

The MTF is the magnitude of the optical transfer function (OTF), which is the Fourier transform of the PSF. For separable 1-D analysis:
  1. Integrate the PSF along the x-axis → tangential line-spread function (LSF_tan).
  2. Integrate the PSF along the y-axis → sagittal line-spread function (LSF_sag).
  3. Take |FFT(LSF)| and normalize by the DC component so that MTF(0) = 1.

Only positive frequencies (excluding DC) are returned, following the convention used in Zemax MTF plots.
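A convenient sanity check for this procedure is a Gaussian PSF, whose MTF is known in closed form: a Gaussian line-spread function with standard deviation σ has MTF(f) = exp(-2π²σ²f²). The sketch below repeats the steps above in NumPy on a hypothetical, well-sampled Gaussian PSF and compares the result with that formula. It also shows that, for an even kernel size, dropping DC leaves ks // 2 frequency samples ending at the Nyquist frequency 0.5 / pixel_size.

```python
import numpy as np

# Hypothetical rotationally symmetric Gaussian PSF on a 128 x 128 grid
ks, pixel_size, sigma = 128, 0.001, 0.004  # pixels, mm/pixel, mm
coords = (np.arange(ks) - ks // 2) * pixel_size
yy, xx = np.meshgrid(coords, coords, indexing="ij")
psf = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))

# Same steps as psf2mtf: line-spread function -> |rFFT| -> normalize by DC
lsf_tan = psf.sum(axis=1)  # integrate along x
spectrum = np.abs(np.fft.rfft(lsf_tan))
mtf_tan = spectrum / spectrum[0]
freq = np.fft.rfftfreq(ks, d=pixel_size)  # cycles/mm; last entry is Nyquist for even ks

# Drop DC, as psf2mtf does
freq, mtf_tan = freq[1:], mtf_tan[1:]

# Closed-form MTF of a Gaussian line-spread function
analytic = np.exp(-2 * np.pi**2 * sigma**2 * freq**2)
print(freq.size, freq[-1])                 # ks // 2 samples, ending at 0.5 / pixel_size
print(np.max(np.abs(mtf_tan - analytic)))  # small for a well-sampled PSF
```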

Parameters:

  psf (Tensor | ndarray): Square 2-D PSF with shape [H, W], H == W. The array's y-axis (rows) corresponds to the tangential (meridional) direction; x-axis (columns) to the sagittal direction. Required.
  pixel_size (float): Pixel pitch in mm. Determines the frequency-axis scaling: Nyquist = 0.5 / pixel_size cycles/mm. Required.

Returns:

  tuple[np.ndarray, np.ndarray, np.ndarray]:
    - freq: Spatial frequency in cycles/mm (positive, excluding DC). Length is H // 2.
    - mtf_tan: Tangential MTF, normalized to 1 at DC.
    - mtf_sag: Sagittal MTF, normalized to 1 at DC.

References
  • https://en.wikipedia.org/wiki/Optical_transfer_function
  • Edmund Optics: Introduction to Modulation Transfer Function.
Source code in src/geolens_pkg/eval.py
@staticmethod
def psf2mtf(psf, pixel_size):
    """Convert a 2-D point-spread function to tangential and sagittal MTF curves.

    The MTF is the magnitude of the optical transfer function (OTF), which
    is the Fourier transform of the PSF.  For separable 1-D analysis:
        1. Integrate the PSF along the x-axis → *tangential* line-spread
           function (LSF_tan).
        2. Integrate the PSF along the y-axis → *sagittal* LSF_sag.
        3. Take ``|FFT(LSF)|`` and normalize by the DC component so that
           MTF(0) = 1.

    Only positive frequencies (excluding DC) are returned, following the
    convention used in Zemax MTF plots.

    Args:
        psf (torch.Tensor | np.ndarray): Square 2-D PSF with shape ``[H, W]``
            (``H == W``).  The array's y-axis (rows) corresponds to the
            **tangential** (meridional) direction; x-axis (columns) to the
            **sagittal** direction.
        pixel_size (float): Pixel pitch in mm.  Determines the frequency
            axis scaling: ``Nyquist = 0.5 / pixel_size`` cycles/mm.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]:
            - **freq**: Spatial frequency in cycles/mm (positive, excluding
              DC).  Length is ``H // 2``.
            - **mtf_tan**: Tangential MTF, normalized to 1 at DC.
            - **mtf_sag**: Sagittal MTF, normalized to 1 at DC.

    References:
        - https://en.wikipedia.org/wiki/Optical_transfer_function
        - Edmund Optics: Introduction to Modulation Transfer Function.
    """
    # Convert to numpy (supports torch tensors and numpy arrays)
    if isinstance(psf, torch.Tensor):
        psf_np = psf.detach().cpu().numpy()
    else:
        psf_np = np.asarray(psf)

    # Both curves share one frequency axis, so the PSF must be square
    if psf_np.ndim != 2 or psf_np.shape[0] != psf_np.shape[1]:
        raise ValueError(f"psf must be a square 2-D array, got shape {psf_np.shape}")

    # Compute line spread functions (integrate PSF over orthogonal axes)
    # y-axis corresponds to tangential; x-axis corresponds to sagittal
    lsf_sagittal = psf_np.sum(axis=0)  # function of x
    lsf_tangential = psf_np.sum(axis=1)  # function of y

    # One-sided spectra (for real inputs)
    mtf_sag = np.abs(np.fft.rfft(lsf_sagittal))
    mtf_tan = np.abs(np.fft.rfft(lsf_tangential))

    # Normalize by DC so that MTF(0) == 1 (skipped for an all-zero PSF)
    if mtf_sag[0] != 0:
        mtf_sag = mtf_sag / mtf_sag[0]
    if mtf_tan[0] != 0:
        mtf_tan = mtf_tan / mtf_tan[0]

    # Frequency axis in cycles/mm (one-sided), shared by both curves
    freq = np.fft.rfftfreq(lsf_sagittal.size, d=pixel_size)
    positive_freq_idx = freq > 0

    return (
        freq[positive_freq_idx],
        mtf_tan[positive_freq_idx],
        mtf_sag[positive_freq_idx],
    )

draw_mtf

draw_mtf(save_name='./lens_mtf.png', relative_fov_list=[0.0, 0.7, 1.0], depth_list=[DEPTH], psf_ks=128, show=False)

Draw a grid of tangential MTF curves for multiple depths and field positions.

Produces a len(depth_list) × len(relative_fov_list) subplot grid. Each subplot shows the tangential MTF for R, G, B wavelengths plus a vertical line at the sensor Nyquist frequency (0.5 / pixel_size cycles/mm).

Algorithm per subplot
  1. Compute the RGB PSF via self.psf_rgb() at the specified (depth, relative_fov) with kernel size psf_ks.
  2. For each wavelength channel, call psf2mtf() to obtain the tangential MTF curve.
  3. Plot frequency vs MTF with RGB coloring.
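A common way to condense these curves into a single number is to read off the MTF at the Nyquist frequency (or at half of it). The sketch below does this by linear interpolation on a hypothetical Gaussian-blur tangential MTF curve; the pixel size, blur width, and the helper mtf_at are assumptions for illustration, not part of GeoLens.

```python
import numpy as np


def mtf_at(freq, mtf, f_query):
    """Linearly interpolate an MTF curve at a spatial frequency (cycles/mm)."""
    return float(np.interp(f_query, freq, mtf))


pixel_size = 0.002  # mm (assumed 2 um pixels)
nyquist = 0.5 / pixel_size  # ~250 cycles/mm, the dotted line drawn in each subplot

# Hypothetical tangential MTF curve: Gaussian blur with sigma = 1.5 um
freq = np.linspace(1.0, nyquist, 200)
mtf_tan = np.exp(-2 * np.pi**2 * 0.0015**2 * freq**2)

print(mtf_at(freq, mtf_tan, nyquist), mtf_at(freq, mtf_tan, nyquist / 2))
```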

Parameters:

  save_name (str): File path for the output PNG; must end with '.png' when saving. Default: './lens_mtf.png'.
  relative_fov_list (list[float]): Relative field positions in [0, 1], where 0 = on-axis and 1 = full field. Default: [0.0, 0.7, 1.0].
  depth_list (list[float]): Object distances in mm. float('inf') is automatically replaced by DEPTH. Default: [DEPTH].
  psf_ks (int): PSF kernel size in pixels (controls the frequency resolution of the resulting MTF). Default: 128.
  show (bool): If True, display interactively instead of saving to file. Default: False.

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def draw_mtf(
    self,
    save_name="./lens_mtf.png",
    relative_fov_list=[0.0, 0.7, 1.0],
    depth_list=[DEPTH],
    psf_ks=128,
    show=False,
):
    """Draw a grid of tangential MTF curves for multiple depths and field positions.

    Produces a ``len(depth_list) × len(relative_fov_list)`` subplot grid.
    Each subplot shows the tangential MTF for R, G, B wavelengths plus a
    vertical line at the sensor Nyquist frequency
    (``0.5 / pixel_size`` cycles/mm).

    Algorithm per subplot:
        1. Compute the RGB PSF via ``self.psf_rgb()`` at the specified
           ``(depth, relative_fov)`` with kernel size ``psf_ks``.
        2. For each wavelength channel, call ``psf2mtf()`` to obtain the
           tangential MTF curve.
        3. Plot frequency vs MTF with RGB coloring.

    Args:
        save_name (str): File path for the output PNG.
            Defaults to ``'./lens_mtf.png'``.
        relative_fov_list (list[float]): Relative field positions in
            ``[0, 1]``, where 0 = on-axis and 1 = full field.
            Defaults to ``[0.0, 0.7, 1.0]``.
        depth_list (list[float]): Object distances in mm.
            ``float('inf')`` is automatically replaced by ``DEPTH``.
            Defaults to ``[DEPTH]``.
        psf_ks (int): PSF kernel size in pixels (controls frequency
            resolution of the resulting MTF). Defaults to 128.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    pixel_size = self.pixel_size
    nyquist_freq = 0.5 / pixel_size
    num_fovs = len(relative_fov_list)
    if float("inf") in depth_list:
        depth_list = [DEPTH if x == float("inf") else x for x in depth_list]
    num_depths = len(depth_list)

    # Create figure and subplots (num_depths * num_fovs subplots)
    fig, axs = plt.subplots(
        num_depths, num_fovs, figsize=(num_fovs * 3, num_depths * 3), squeeze=False
    )

    # Correct for distortion: scale normalized FoV so 1.0 maps to the true
    # sensor edge, not the pinhole-model edge.
    fov_scale = float(np.tan(self.rfov) / np.tan(self.rfov_eff)) if self.rfov_eff > 0 else 1.0

    # Iterate over depth and field of view
    for depth_idx, depth in enumerate(depth_list):
        for fov_idx, fov_relative in enumerate(relative_fov_list):
            # Calculate rgb PSF (scale fov_relative to correct for distortion)
            point = [0, -fov_relative * fov_scale, depth]
            psf_rgb = self.psf_rgb(points=point, ks=psf_ks, recenter=False)

            # Calculate MTF curves for rgb wavelengths
            for wvln_idx, wvln in enumerate(WAVE_RGB):
                # Calculate MTF curves from PSF
                psf = psf_rgb[wvln_idx]
                freq, mtf_tan, _ = self.psf2mtf(psf, pixel_size)

                # Plot MTF curves
                ax = axs[depth_idx, fov_idx]
                color = RGB_COLORS[wvln_idx % len(RGB_COLORS)]
                wvln_label = RGB_LABELS[wvln_idx % len(RGB_LABELS)]
                wvln_nm = int(wvln * 1000)
                ax.plot(
                    freq,
                    mtf_tan,
                    color=color,
                    label=f"{wvln_label}({wvln_nm}nm)-Tan",
                )

            # Draw Nyquist frequency
            ax.axvline(
                x=nyquist_freq,
                color="k",
                linestyle=":",
                linewidth=1.2,
                label="Nyquist",
            )

            # Set title and label for subplot
            fov_deg = round(fov_relative * self.rfov * 180 / np.pi, 1)
            depth_str = "inf" if depth == float("inf") else f"{depth}"
            ax.set_title(f"FOV: {fov_deg}deg, Depth: {depth_str}mm", fontsize=8)
            ax.set_xlabel("Spatial Frequency [cycles/mm]", fontsize=8)
            ax.set_ylabel("MTF", fontsize=8)
            ax.legend(fontsize=6)
            ax.tick_params(axis="both", which="major", labelsize=7)
            ax.grid(True)
            ax.set_ylim(0, 1.05)

    plt.tight_layout()
    if show:
        plt.show()
    else:
        assert save_name.endswith(".png"), "save_name must end with .png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

draw_field_curvature

draw_field_curvature(save_name=None, num_points=64, z_span=1.0, z_steps=201, wvln_list=WAVE_RGB, spp=256, show=False)

Draw field curvature: best-focus defocus (Δz) vs field angle for RGB.

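Field curvature causes off-axis image points to focus on a curved surface rather than on the flat sensor. At each field angle, this method finds the axial position of minimum tangential (y) RMS spot size and plots its deviation from the nominal sensor plane. Because only the y-spread is minimized, the result is the tangential field curve, which includes astigmatism as well as the Petzval contribution.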

Algorithm (fully vectorized per wavelength):
  1. Construct a meridional ray fan at num_points field angles, each with spp rays spanning the entrance pupil.
  2. Trace all rays through the lens in a single batched call.
  3. For each of z_steps defocus planes within ±z_span mm of self.d_sensor, propagate rays analytically (linear extension) and compute the variance of the y-coordinate.
  4. The defocus with minimum variance is the best-focus plane. Parabolic interpolation on the three-point neighborhood gives sub-grid-step precision.
  5. Repeat for each wavelength and overlay the R/G/B curves on a single plot.
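The defocus sweep and parabolic refinement in steps 3 and 4 can be sketched in NumPy on a synthetic ray bundle. This is a minimal illustration, not the DeepLens implementation. `best_focus_1d` is a hypothetical helper, rays are given as plain arrays of y-positions and direction components, and every ray is assumed valid, so there is no vignetting mask.

```python
import numpy as np


def best_focus_1d(ray_y, ray_z, ray_dy, ray_dz, z_grid):
    """Return the sub-grid z that minimizes the y-spread of a ray bundle.

    Assumes every ray is valid and z_grid is uniformly spaced.
    """
    # Propagate each ray linearly to every candidate plane: shape [N, Z]
    t = (z_grid[None, :] - ray_z[:, None]) / ray_dz[:, None]
    pos_y = ray_y[:, None] + ray_dy[:, None] * t
    # Mean-square spread about the centroid at each plane: shape [Z]
    ms = pos_y.var(axis=0)

    # Grid minimum, kept one step from the ends so a 3-point stencil exists
    i = int(np.clip(np.argmin(ms), 1, len(z_grid) - 2))
    y_l, y_c, y_r = ms[i - 1], ms[i], ms[i + 1]

    # Vertex of the parabola through (-1, y_l), (0, y_c), (+1, y_r)
    shift = (y_l - y_r) / (2.0 * (y_l - 2.0 * y_c + y_r))
    shift = float(np.clip(shift, -0.5, 0.5))
    step = z_grid[1] - z_grid[0]
    return float(z_grid[i] + shift * step)


# Synthetic bundle: rays leave z = 0 at heights h and all cross y = 0 at z_focus
z_focus = 50.0123
h = np.linspace(-2.0, 2.0, 64)
z_grid = 50.0 + np.linspace(-1.0, 1.0, 201)  # 0.01 mm steps
best = best_focus_1d(h, np.zeros_like(h), -h / z_focus, np.ones_like(h), z_grid)
print(best)
```

The spread of this synthetic bundle is exactly quadratic in z, so the parabolic step recovers a focus that lies between grid planes. For real aberrated bundles the refinement is an approximation.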

Parameters:

  • save_name (str | None, default None): File path for the output PNG. If None, defaults to './field_curvature.png'.
  • num_points (int, default 64): Number of field-angle samples from 0 to self.rfov.
  • z_span (float, default 1.0): Half-range of the defocus sweep in mm. If the best focus lands on a sweep boundary, a warning is printed.
  • z_steps (int, default 201): Number of uniformly spaced defocus planes within ±z_span. Higher values give finer axial resolution.
  • wvln_list (list[float], default WAVE_RGB): Wavelengths in micrometers.
  • spp (int, default 256): Rays per field point, sampled uniformly across the entrance pupil in the meridional plane.
  • show (bool, default False): If True, display the plot interactively instead of saving it.

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def draw_field_curvature(
    self,
    save_name=None,
    num_points=64,
    z_span=1.0,
    z_steps=201,
    wvln_list=WAVE_RGB,
    spp=256,
    show=False,
):
    """Draw field curvature: best-focus defocus (Δz) vs field angle for RGB.

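    *Field curvature* causes off-axis image points to focus on a curved
    surface rather than on the flat sensor.  This method finds the axial
    position of minimum tangential (y) RMS spot size at each field angle
    and plots its deviation from the nominal sensor plane.  Because only
    the y-spread is minimized, the result is the tangential field curve,
    which includes astigmatism as well as the Petzval contribution.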

    Algorithm (fully vectorized per wavelength):
        1. Construct a meridional ray fan at ``num_points`` field angles,
           each with ``spp`` rays spanning the entrance pupil.
        2. Trace all rays through the lens in a single batched call.
        3. For each of ``z_steps`` defocus planes within ``±z_span`` mm of
           ``self.d_sensor``, propagate rays analytically (linear
           extension) and compute the variance of the y-coordinate.
        4. The defocus with minimum variance is the best-focus plane.
           Parabolic interpolation on the three-point neighborhood gives
           sub-grid-step precision.
        5. Repeat for each wavelength; overlay R/G/B curves on a single plot.

    Args:
        save_name (str | None): File path for the output PNG.  If ``None``,
            defaults to ``'./field_curvature.png'``.
        num_points (int): Number of field-angle samples from 0 to
            ``self.rfov``. Defaults to 64.
        z_span (float): Half-range of the defocus sweep in mm.  If the
            best-focus hits the boundary, a warning is printed.
            Defaults to 1.0.
        z_steps (int): Number of uniformly-spaced defocus planes within
            ``±z_span``. Higher values give finer axial resolution.
            Defaults to 201.
        wvln_list (list[float]): Wavelengths in micrometers.
            Defaults to ``WAVE_RGB``.
        spp (int): Rays per field point (sampled uniformly across the
            entrance pupil in the meridional plane). Defaults to 256.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    device = self.device
    rfov_deg = float(self.rfov) * 180.0 / np.pi

    # Sample field angles [0, rfov_deg], shape [F]
    rfov_samples = torch.linspace(0.0, rfov_deg, num_points, device=device)

    # Entrance pupil (computed once)
    pupilz, pupilr = self.get_entrance_pupil()

    # Defocus sweep grid, shape [Z]
    d_sensor = self.d_sensor
    z_grid = d_sensor + torch.linspace(-z_span, z_span, z_steps, device=device)

    delta_z_tan = []

    for wvln in wvln_list:
        # --- Batch ray construction for all field angles ---
        # Pupil positions: shape [spp]
        pupil_y = torch.linspace(-pupilr, pupilr, spp, device=device) * 0.99

        # Ray origins: shape [F, spp, 3] (meridional plane: x=0)
        ray_o = torch.zeros(num_points, spp, 3, device=device)
        ray_o[..., 1] = pupil_y.unsqueeze(0)  # y = pupil sample
        ray_o[..., 2] = pupilz  # z = entrance pupil z

        # Ray directions: shape [F, spp, 3] (meridional: dx=0)
        fov_rad = rfov_samples * (np.pi / 180.0)  # [F]
        sin_fov = torch.sin(fov_rad)  # [F]
        cos_fov = torch.cos(fov_rad)  # [F]
        ray_d = torch.zeros(num_points, spp, 3, device=device)
        ray_d[..., 1] = sin_fov.unsqueeze(-1)  # [F, 1] -> [F, spp]
        ray_d[..., 2] = cos_fov.unsqueeze(-1)

        # Create batched ray and trace all field angles at once
        ray = Ray(ray_o, ray_d, wvln=wvln, device=device)
        ray, _ = self.trace(ray)

        # --- Vectorized best-focus for all field angles ---
        # ray.o: [F, spp, 3], ray.d: [F, spp, 3]
        oz = ray.o[..., 2:3]  # [F, spp, 1]
        dz = ray.d[..., 2:3]  # [F, spp, 1]
        t = (z_grid.view(1, 1, -1) - oz) / (dz + EPSILON)  # [F, spp, Z]

        oa = ray.o[..., 1:2]  # y-axis (tangential)
        da = ray.d[..., 1:2]
        pos_y = oa + da * t  # [F, spp, Z]

        w = ray.is_valid.unsqueeze(-1).float()  # [F, spp, 1]
        pos_y = pos_y * w  # mask invalid rays
        w_sum = w.sum(dim=1)  # [F, 1]

        centroid = pos_y.sum(dim=1) / (w_sum + EPSILON)  # [F, Z]
        ms = (((pos_y - centroid.unsqueeze(1)) ** 2) * w).sum(dim=1) / (
            w_sum + EPSILON
        )  # [F, Z]

        best_idx = torch.argmin(ms, dim=1)  # [F]

        # Warn if best focus hits z_span boundary
        boundary_hit = (best_idx == 0) | (best_idx == z_steps - 1)
        if boundary_hit.any():
            n_boundary = boundary_hit.sum().item()
            print(
                f"Warning: {n_boundary}/{num_points} field angles hit z_span "
                f"boundary. Consider increasing z_span (currently {z_span} mm)."
            )

        # Parabolic interpolation for sub-grid precision
        idx_c = best_idx.clamp(1, z_steps - 2)  # avoid boundary
        f_range = torch.arange(num_points, device=device)
        y_l = ms[f_range, idx_c - 1]
        y_c = ms[f_range, idx_c]
        y_r = ms[f_range, idx_c + 1]
        denom = 2.0 * (y_l - 2.0 * y_c + y_r)
        shift = (y_l - y_r) / (denom + EPSILON)  # fractional index offset
        shift = shift.clamp(-0.5, 0.5)  # safety clamp

        z_step_size = (2.0 * z_span) / (z_steps - 1)
        best_z = z_grid[idx_c] + shift * z_step_size  # [F]
        dz_tan = (best_z - d_sensor).cpu().numpy()

        # Mark fully-vignetted field angles as NaN (gaps in plot)
        valid_count = w.sum(dim=1).squeeze(-1)  # [F]
        fully_vignetted = (valid_count < 2).cpu().numpy()
        dz_tan[fully_vignetted] = np.nan

        delta_z_tan.append(dz_tan)

    # Plot
    fov_np = rfov_samples.detach().cpu().numpy()
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.set_title("Field Curvature (Δz vs Field Angle)")

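    all_vals = np.abs(np.concatenate(delta_z_tan)) if len(delta_z_tan) > 0 else np.array([0.0])
    # Ignore NaNs from fully vignetted fields; otherwise max() returns NaN and
    # the axis silently falls back to ±0.2 mm.
    finite_vals = all_vals[np.isfinite(all_vals)]
    x_range = float(max(0.2, finite_vals.max() * 1.2)) if finite_vals.size > 0 else 0.2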

    for w_idx in range(len(wvln_list)):
        color = RGB_COLORS[w_idx % len(RGB_COLORS)]
        lbl = RGB_LABELS[w_idx % len(RGB_LABELS)]
        ax.plot(
            delta_z_tan[w_idx],
            fov_np,
            color=color,
            linestyle="-",
            label=f"{lbl}-Tan",
        )

    ax.axvline(x=0, color="k", linestyle="-", linewidth=0.8)
    ax.grid(True, color="gray", linestyle="-", linewidth=0.5, alpha=1.0)
    ax.set_xlabel("Defocus Δz (mm) relative to sensor plane")
    ax.set_ylabel("Field Angle (deg)")
    ax.set_xlim(-x_range, x_range)
    ax.set_ylim(0, rfov_deg)
    ax.legend(fontsize=8)
    plt.tight_layout()

    if show:
        plt.show()
    else:
        if save_name is None:
            save_name = "./field_curvature.png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

vignetting

vignetting(depth=DEPTH, num_grid=32, num_rays=512)

Compute the relative-illumination (vignetting) map across the field.

Vignetting measures how much light is lost at each field position because rays are clipped by lens apertures or barrel edges. For each grid cell it is computed as the number of traced rays that remain valid (not vignetted) divided by the number of launched rays.

A value of 1.0 means all rays reach the sensor (no vignetting); 0.0 means complete light blockage. Real lenses typically show 1.0 on-axis and fall off toward the field edges due to mechanical vignetting and the cos⁴ illumination law. Because every surviving ray is counted equally, this map captures ray clipping only; the cos⁴ falloff is not included.

Algorithm
  1. self.sample_grid_rays() with uniform_fov=False (uniform image-space sampling) to ensure correct sensor-plane mapping.
  2. self.trace2sensor() propagates rays and marks clipped ones as invalid.
  3. Per-cell throughput = count(valid) / num_rays.
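To illustrate the count(valid) / num_rays metric, the sketch below uses a hypothetical two-aperture toy model instead of a traced lens. Parallel rays fill a circular entrance pupil, travel a distance `gap`, and are clipped by a second circular aperture. The function name and all dimensions are assumptions chosen for illustration.

```python
import numpy as np


def vignetting_fraction(field_deg, pupil_r=1.0, stop_r=1.2, gap=5.0, n=201):
    """Fraction of pupil rays that clear a second circular aperture (toy model)."""
    # Uniform square grid over the pupil, keeping only points inside the disk
    u = np.linspace(-pupil_r, pupil_r, n)
    x, y = np.meshgrid(u, u)
    inside = x**2 + y**2 <= pupil_r**2
    x, y = x[inside], y[inside]

    # Parallel rays at this field angle shift by gap * tan(theta) before the aperture
    y_at_stop = y + gap * np.tan(np.radians(field_deg))
    valid = x**2 + y_at_stop**2 <= stop_r**2

    # Throughput = count(valid) / number of launched rays
    return float(valid.mean())


for angle in (0.0, 5.0, 10.0, 20.0, 60.0):
    print(angle, round(vignetting_fraction(angle), 3))
```

On-axis, the whole pupil fits inside the larger second aperture, so the throughput is exactly 1.0. It decreases as the beam walks off the aperture and reaches 0.0 once the beam misses it entirely.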

Parameters:

  • depth (float, default DEPTH): Object distance in mm.
  • num_grid (int, default 32): Grid resolution per axis.
  • num_rays (int, default 512): Rays launched per grid cell. Higher values reduce Monte-Carlo noise.

Returns:

  • torch.Tensor: Vignetting map with shape [num_grid, num_grid], values in [0, 1].

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def vignetting(self, depth=DEPTH, num_grid=32, num_rays=512):
    """Compute the relative-illumination (vignetting) map across the field.

    Vignetting measures how much light is lost at each field position due to
    rays being clipped by lens apertures or barrel edges.  It is computed as
    the fraction of traced rays that remain valid (not vignetted) at each
    grid cell, normalized by the total number of launched rays.

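    A value of 1.0 means all rays reach the sensor (no vignetting); 0.0
    means complete light blockage.  Real lenses typically show 1.0 on-axis
    and fall off toward the field edges due to mechanical vignetting and the
    cos⁴ illumination law.  Because every surviving ray is counted equally,
    this map captures ray clipping only; the cos⁴ falloff is not included.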

    Algorithm:
        1. ``self.sample_grid_rays()`` with ``uniform_fov=False`` (uniform
           image-space sampling) to ensure correct sensor-plane mapping.
        2. ``self.trace2sensor()`` propagates rays and marks clipped ones as
           invalid.
        3. Per-cell throughput = ``count(valid) / num_rays``.

    Args:
        depth (float): Object distance in mm. Defaults to ``DEPTH``.
        num_grid (int): Grid resolution per axis. Defaults to 32.
        num_rays (int): Rays launched per grid cell.  Higher values reduce
            Monte-Carlo noise. Defaults to 512.

    Returns:
        torch.Tensor: Vignetting map with shape ``[num_grid, num_grid]``,
            values in ``[0, 1]``.
    """
    # Sample rays in uniform image space (not FOV angles) for correct sensor mapping
    # shape [num_grid, num_grid, num_rays, 3]
    ray = self.sample_grid_rays(
        depth=depth, num_grid=num_grid, num_rays=num_rays, uniform_fov=False
    )

    # Trace rays to sensor
    ray = self.trace2sensor(ray)

    # Calculate vignetting map
    vignetting = ray.is_valid.sum(-1) / (ray.is_valid.shape[-1])
    return vignetting

draw_vignetting

draw_vignetting(filename=None, depth=DEPTH, resolution=512, show=False)

Draw the vignetting map as a grayscale image with a colorbar.

Computes the vignetting map via self.vignetting(), bilinearly upsamples it to resolution × resolution, and displays it as a grayscale image where white = no vignetting and black = fully vignetted.
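The upsampling step can be reproduced in NumPy to show what happens to a coarse 32 × 32 map. This sketch is written independently of PyTorch. It is intended to follow the `align_corners=False` convention of `F.interpolate`: half-pixel-center mapping, with edge values used beyond the border. The helper names are hypothetical.

```python
import numpy as np


def _source_index(out_size, in_size):
    """Neighbor indices and weights for align_corners=False linear resampling."""
    scale = in_size / out_size
    # Half-pixel-center mapping; negative source coordinates are clamped to 0
    src = np.maximum((np.arange(out_size) + 0.5) * scale - 0.5, 0.0)
    i0 = np.minimum(np.floor(src).astype(int), in_size - 1)
    i1 = np.minimum(i0 + 1, in_size - 1)  # repeat the edge pixel at the border
    frac = src - i0
    return i0, i1, frac


def bilinear_upsample(img, out_h, out_w):
    """Resize a 2-D map with bilinear interpolation (align_corners=False style)."""
    r0, r1, fr = _source_index(out_h, img.shape[0])
    c0, c1, fc = _source_index(out_w, img.shape[1])
    # Interpolate along columns within the two neighboring rows, then blend rows
    top = img[r0][:, c0] * (1.0 - fc) + img[r0][:, c1] * fc
    bottom = img[r1][:, c0] * (1.0 - fc) + img[r1][:, c1] * fc
    return top * (1.0 - fr)[:, None] + bottom * fr[:, None]


coarse = np.array([[0.0, 1.0], [0.0, 1.0]])
print(bilinear_upsample(coarse, 4, 4)[0])
```

Every output pixel is a convex combination of input pixels, so an upsampled vignetting map stays within [0, 1]. That is why the fixed `vmin=0.0`, `vmax=1.0` color scale remains valid.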

Parameters:

  • filename (str | None, default None): File path for the output PNG. If None, auto-generates './vignetting_{depth}.png'.
  • depth (float, default DEPTH): Object distance in mm.
  • resolution (int, default 512): Output image size in pixels (square).
  • show (bool, default False): If True, display the figure interactively instead of saving it.

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def draw_vignetting(self, filename=None, depth=DEPTH, resolution=512, show=False):
    """Draw the vignetting map as a grayscale image with a colorbar.

    Computes the vignetting map via ``self.vignetting()``, bilinearly
    upsamples it to ``resolution × resolution``, and displays it as a
    grayscale image where white = no vignetting and black = fully vignetted.

    Args:
        filename (str | None): File path for the output PNG.  If ``None``,
            auto-generates ``'./vignetting_{depth}.png'``.
        depth (float): Object distance in mm. Defaults to ``DEPTH``.
        resolution (int): Output image size in pixels (square).
            Defaults to 512.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    # Calculate vignetting map
    vignetting = self.vignetting(depth=depth)

    # Interpolate vignetting map to desired resolution
    vignetting = F.interpolate(
        vignetting.unsqueeze(0).unsqueeze(0),
        size=(resolution, resolution),
        mode="bilinear",
        align_corners=False,
    ).squeeze()

    fig, ax = plt.subplots()
    ax.set_title("Relative Illumination (Vignetting)")
    im = ax.imshow(vignetting.cpu().numpy(), cmap="gray", vmin=0.0, vmax=1.0)
    fig.colorbar(im, ax=ax, ticks=[0.0, 0.25, 0.5, 0.75, 1.0])

    if show:
        plt.show()
    else:
        if filename is None:
            filename = f"./vignetting_{depth}.png"
        plt.savefig(filename, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

wavefront_error

wavefront_error(relative_fov=0.0, depth=DEPTH, wvln=DEFAULT_WAVE, num_rays=SPP_COHERENT, ks=256)

Compute wavefront error (OPD) at the exit pupil for a given field position.

The wavefront error is the optical path difference between the actual wavefront and the ideal spherical reference wavefront. The reference sphere is centered at the ideal image point (chief ray intersection with the sensor) and passes through the exit pupil center.

By Fermat's principle, a perfect lens has equal total optical path (object → lens → image) for all rays. The deviation from this equal-path condition is the wavefront error:

OPD(x, y) = [OPL(x, y) + r(x, y)] - mean_pupil[OPL + r]

where OPL(x, y) is the accumulated optical path from the object through the lens to the exit pupil, r(x, y) is the geometric distance from the exit-pupil point to the ideal image point, and the mean is taken over all valid (unvignetted) rays, which removes piston.

Uses the same coherent ray-tracing infrastructure as pupil_field().
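The path bookkeeping can be checked on synthetic data. The NumPy sketch below uses `opd_stats`, a hypothetical helper. It takes exit-pupil ray positions and their accumulated OPL, adds the free-space distance to the ideal image point, removes piston, and reports RMS, PV, and the exponential Strehl estimate. Like the source code, it places every ray on the exit-pupil plane. It also assumes that all rays are valid.

```python
import math

import numpy as np


def opd_stats(ray_x, ray_y, opl_mm, pupil_z, img_xyz, wvln_um):
    """Piston-removed OPD statistics (in waves) for rays at the exit pupil."""
    img_x, img_y, img_z = img_xyz
    # Free-space distance from each exit-pupil point to the ideal image point
    r = np.sqrt((ray_x - img_x) ** 2 + (ray_y - img_y) ** 2 + (pupil_z - img_z) ** 2)
    total = opl_mm + r
    # Remove piston (mean over rays) and convert mm -> waves
    opd_waves = (total - total.mean()) / (wvln_um * 1e-3)
    rms = float(np.sqrt(np.mean(opd_waves**2)))
    pv = float(opd_waves.max() - opd_waves.min())
    strehl = math.exp(-((2.0 * math.pi * rms) ** 2))
    return rms, pv, strehl


# Synthetic exit pupil: radius 2 mm at z = 0, ideal image point 50 mm behind it
wvln_um, pupil_r = 0.55, 2.0
u = np.linspace(-pupil_r, pupil_r, 201)
gx, gy = np.meshgrid(u, u)
keep = gx**2 + gy**2 <= pupil_r**2
x, y = gx[keep], gy[keep]
r_img = np.sqrt(x**2 + y**2 + 50.0**2)

# Perfect lens: OPL + r is the same constant for every ray
opl_perfect = 100.0 - r_img
# Add 0.1 wave of defocus, W = a * rho^2 (converted to mm)
rho2 = (x**2 + y**2) / pupil_r**2
opl_defocus = opl_perfect + 0.1 * rho2 * wvln_um * 1e-3

print(opd_stats(x, y, opl_perfect, 0.0, (0.0, 0.0, 50.0), wvln_um))
print(opd_stats(x, y, opl_defocus, 0.0, (0.0, 0.0, 50.0), wvln_um))
```

For pure defocus a·ρ² over a uniformly filled circular pupil, the piston-removed RMS is a/√12. The sampled result should therefore be close to 0.1/√12 ≈ 0.029 waves, and the PV is the full 0.1 wave.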

Parameters:

  • relative_fov (float, default 0.0): Relative field of view in [-1, 1] along the meridional (y) direction; 0 = on-axis, 1 = full field.
  • depth (float, default DEPTH): Object distance [mm]. Use DEPTH for practical infinity.
  • wvln (float, default DEFAULT_WAVE): Wavelength [µm].
  • num_rays (int, default SPP_COHERENT): Number of rays to sample through the pupil.
  • ks (int, default 256): Grid resolution for the OPD map at the exit pupil.

Returns:

dict with the following keys:
  • opd_map (Tensor): OPD map on the exit-pupil grid, shape [ks, ks], in waves. Invalid (vignetted) regions are zero.
  • rms (float): RMS wavefront error in waves (piston removed).
  • pv (float): Peak-to-valley wavefront error in waves.
  • valid_mask (Tensor): Boolean mask of valid pupil pixels, shape [ks, ks].
  • strehl (float): Strehl ratio from the Maréchal approximation, exp(-(2π·rms)²), with rms in waves.
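As a quick numeric check of the exponential Maréchal form used for `strehl`: the classic Maréchal criterion of λ/14 RMS wavefront error corresponds to a Strehl ratio of about 0.8. The snippet compares the exponential form with the quadratic approximation 1 - (2πσ)².

```python
import math

sigma = 1.0 / 14.0  # RMS wavefront error in waves (Maréchal criterion)
strehl_exp = math.exp(-(2.0 * math.pi * sigma) ** 2)
strehl_quad = 1.0 - (2.0 * math.pi * sigma) ** 2
print(round(strehl_exp, 3), round(strehl_quad, 3))  # 0.818 0.799
```

The two forms agree closely for small σ and diverge as σ grows.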
Note

This method calls self.astype(torch.float64) for phase accuracy (consistent with pupil_field()); the lens is left in float64 after the call.

References

[1] V. N. Mahajan, "Optical Imaging and Aberrations, Part II", Ch. 1.
[2] Zemax OpticStudio, "Wavefront Error Analysis".

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def wavefront_error(
    self,
    relative_fov=0.0,
    depth=DEPTH,
    wvln=DEFAULT_WAVE,
    num_rays=SPP_COHERENT,
    ks=256,
):
    """Compute wavefront error (OPD) at the exit pupil for a given field position.

    The wavefront error is the optical path difference between the actual
    wavefront and the ideal spherical reference wavefront. The reference sphere
    is centered at the ideal image point (chief ray intersection with the sensor)
    and passes through the exit pupil center.

    By Fermat's principle, a perfect lens has equal total optical path (object →
    lens → image) for all rays. The deviation from this equal-path condition is
    the wavefront error:

        ``OPD(x,y) = [OPL(x,y) + r(x,y)] - mean_over_pupil``

    where ``OPL(x,y)`` is the accumulated optical path from the object through
    the lens to the exit pupil, and ``r(x,y)`` is the geometric distance from
    the exit pupil point to the ideal image point. Piston (mean) is removed.

    Uses the same coherent ray-tracing infrastructure as :meth:`pupil_field`.

    Args:
        relative_fov (float): Relative field of view in ``[-1, 1]`` along the
            meridional (y) direction. ``0`` = on-axis, ``1`` = full field.
        depth (float): Object distance [mm]. Use ``DEPTH`` for practical infinity.
        wvln (float): Wavelength [µm].
        num_rays (int): Number of rays to sample through the pupil.
        ks (int): Grid resolution for the OPD map at the exit pupil.

    Returns:
        dict:
            - ``opd_map`` (Tensor): OPD map on exit pupil grid, shape ``[ks, ks]``,
              in waves. Invalid (vignetted) regions are zero.
            - ``rms`` (float): RMS wavefront error in waves (piston removed).
            - ``pv`` (float): Peak-to-valley wavefront error in waves.
            - ``valid_mask`` (Tensor): Boolean mask of valid pupil pixels ``[ks, ks]``.
            - ``strehl`` (float): Maréchal approximation Strehl ratio.

    Note:
        This method calls ``self.astype(torch.float64)`` for phase accuracy
        (consistent with :meth:`pupil_field`); the lens is left in float64
        after the call.

    References:
        [1] V. N. Mahajan, "Optical Imaging and Aberrations, Part II", Ch. 1.
        [2] Zemax OpticStudio, "Wavefront Error Analysis".
    """
    # Float64 required for accurate OPL accumulation
    self.astype(torch.float64)
    device = self.device
    sensor_w, sensor_h = self.sensor_size
    wvln_mm = wvln * 1e-3

    # Build normalized point: positive relative_fov -> negative y (convention)
    point_norm = torch.tensor(
        [0.0, -relative_fov, depth], dtype=torch.float64, device=device
    )
    points = point_norm.unsqueeze(0)  # [1, 3]

    # Convert to physical object coordinates
    scale = self.calc_scale(points[:, 2].item())
    point_obj_x = points[:, 0] * scale * sensor_w / 2
    point_obj_y = points[:, 1] * scale * sensor_h / 2
    point_obj = torch.stack([point_obj_x, point_obj_y, points[:, 2]], dim=-1)

    # Find ideal image point via chief ray
    # psf_center returns negated centroid, so negate back to get actual image position
    chief_pointc = self.psf_center(point_obj, method="chief_ray")  # [1, 2]
    img_x = -chief_pointc[0, 0]
    img_y = -chief_pointc[0, 1]
    img_z = float(self.d_sensor)

    # Sample rays and trace coherently to exit pupil
    ray = self.sample_from_points(
        points=point_obj, num_rays=num_rays, wvln=wvln
    )
    ray.coherent = True
    ray = self.trace2exit_pupil(ray)

    # Get exit pupil parameters
    pupilz, pupilr = self.get_exit_pupil()
    pupilr = float(pupilr)
    pupilz = float(pupilz)

    # Extract valid rays (squeeze batch dim since single point)
    valid = ray.is_valid.squeeze(0) > 0  # [num_rays]
    ray_x = ray.o[0, :, 0]  # [num_rays]
    ray_y = ray.o[0, :, 1]
    opl = ray.opl[0, :, 0]  # [num_rays]

    if valid.sum() == 0:
        raise RuntimeError(
            f"No valid rays at relative_fov={relative_fov}. "
            "The field may be fully vignetted."
        )

    # Distance from each ray's exit pupil position to ideal image point
    dist_to_img = torch.sqrt(
        (ray_x - img_x) ** 2
        + (ray_y - img_y) ** 2
        + (pupilz - img_z) ** 2
    )

    # Total optical path = OPL through lens to exit pupil + free-space to image
    total_path = opl + dist_to_img  # [num_rays]

    # Remove piston (mean over valid rays) to get wavefront error
    total_path_valid = total_path[valid]
    mean_path = total_path_valid.mean()
    opd_mm = total_path - mean_path  # OPD in [mm]
    opd_waves = opd_mm / wvln_mm  # OPD in [waves]

    # Compute RMS and PV from per-ray values (more accurate than from grid)
    opd_valid = opd_waves[valid]
    rms_waves = torch.sqrt(torch.mean(opd_valid**2)).item()
    pv_waves = (opd_valid.max() - opd_valid.min()).item()

    # Maréchal approximation: Strehl ≈ exp(-(2π·σ)²)
    strehl = math.exp(-(2 * math.pi * rms_waves) ** 2)

    # Bin OPD values onto exit pupil grid using assign_points_to_pixels
    # Grid covers [-pupilr, pupilr] x [-pupilr, pupilr]
    pupil_range = [-pupilr, pupilr]
    pupil_points = torch.stack([ray_x[valid], ray_y[valid]], dim=-1)  # [N, 2]
    pupil_mask = torch.ones(pupil_points.shape[0], device=device)

    # Sum of weighted OPD values
    opd_sum = assign_points_to_pixels(
        points=pupil_points,
        mask=pupil_mask,
        ks=ks,
        x_range=pupil_range,
        y_range=pupil_range,
        value=opd_valid,
    )
    # Sum of weights (count)
    count = assign_points_to_pixels(
        points=pupil_points,
        mask=pupil_mask,
        ks=ks,
        x_range=pupil_range,
        y_range=pupil_range,
        value=torch.ones_like(opd_valid),
    )
    valid_mask = count > 0
    opd_map = torch.where(valid_mask, opd_sum / count, torch.zeros_like(opd_sum))

    return {
        "opd_map": opd_map,
        "rms": rms_waves,
        "pv": pv_waves,
        "valid_mask": valid_mask,
        "strehl": strehl,
    }
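The last step of `wavefront_error` averages the per-ray OPD values that fall in the same pupil pixel (sum / count). `assign_points_to_pixels` is a DeepLens helper whose exact splatting scheme is not shown here. The NumPy sketch below assumes nearest-pixel binning and a row = y, column = x layout. Treat it as an illustration of the sum/count averaging, not a drop-in replacement. `bin_average` is hypothetical.

```python
import numpy as np


def bin_average(points, values, ks, lo, hi):
    """Average scattered values onto a ks x ks grid by nearest-pixel binning."""
    # Map coordinates in [lo, hi] to integer pixel indices 0..ks-1
    idx = np.floor((points - lo) / (hi - lo) * ks).astype(int)
    idx = np.clip(idx, 0, ks - 1)
    col, row = idx[:, 0], idx[:, 1]  # assumed layout: x -> column, y -> row
    total = np.zeros((ks, ks))
    count = np.zeros((ks, ks))
    np.add.at(total, (row, col), values)  # unbuffered, so repeated pixels accumulate
    np.add.at(count, (row, col), 1.0)
    valid = count > 0
    avg = np.where(valid, total / np.maximum(count, 1.0), 0.0)
    return avg, valid


pts = np.array([[-0.9, -0.9], [-0.8, -0.8], [0.9, 0.9]])
avg, valid = bin_average(pts, np.array([1.0, 3.0, 5.0]), ks=2, lo=-1.0, hi=1.0)
print(avg)
print(valid)
```

The first two rays share a pixel and average to 2.0. Empty pixels are zero and excluded by the mask, mirroring the zero-filled invalid regions of `opd_map`.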

rms_wavefront_error

rms_wavefront_error(relative_fov=0.0, depth=DEPTH, wvln=DEFAULT_WAVE, num_rays=SPP_COHERENT)

Compute scalar RMS wavefront error at a given field position.

Convenience wrapper around wavefront_error() that returns only its rms value.

Parameters:

  • relative_fov (float, default 0.0): Relative field of view in [-1, 1].
  • depth (float, default DEPTH): Object distance [mm].
  • wvln (float, default DEFAULT_WAVE): Wavelength [µm].
  • num_rays (int, default SPP_COHERENT): Number of rays to sample.

Returns:

  • float: RMS wavefront error in waves.

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def rms_wavefront_error(
    self,
    relative_fov=0.0,
    depth=DEPTH,
    wvln=DEFAULT_WAVE,
    num_rays=SPP_COHERENT,
):
    """Compute scalar RMS wavefront error at a given field position.

    Convenience wrapper around :meth:`wavefront_error`.

    Args:
        relative_fov (float): Relative field of view in ``[-1, 1]``.
        depth (float): Object distance [mm].
        wvln (float): Wavelength [µm].
        num_rays (int): Number of rays to sample.

    Returns:
        float: RMS wavefront error in waves.
    """
    result = self.wavefront_error(
        relative_fov=relative_fov,
        depth=depth,
        wvln=wvln,
        num_rays=num_rays,
    )
    return result["rms"]

draw_wavefront_error

draw_wavefront_error(save_name='./wavefront_error.png', num_fov=5, depth=DEPTH, wvln=DEFAULT_WAVE, num_rays=SPP_COHERENT, ks=256, show=False)

Draw wavefront error (OPD) maps at multiple field positions.

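Evaluates the wavefront error at num_fov field positions evenly spaced along the meridional (y) direction from on-axis to full field. Each OPD map is drawn on a shared symmetric color scale with RMS and PV annotations. Fully vignetted fields are shown as blank panels.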
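In the source, all panels share one symmetric color range (vmin = -vmax). vmax is the largest |OPD| over the valid pixels of every field that was not vignetted, with a fallback of 1.0. A minimal sketch of that rule follows. The helper name is hypothetical, and the result dictionaries mimic the `opd_map` and `valid_mask` keys of `wavefront_error`, using NumPy arrays in place of tensors.

```python
import numpy as np


def shared_symmetric_limit(results):
    """Largest |OPD| over valid pixels of all non-vignetted results, else 1.0."""
    vmax = 0.0
    for res in results:
        if res is None:  # field that raised RuntimeError (fully vignetted)
            continue
        vals = res["opd_map"][res["valid_mask"]]
        if vals.size > 0:
            vmax = max(vmax, float(np.abs(vals).max()))
    return vmax if vmax > 0 else 1.0  # same fallback as the source


maps = [
    {"opd_map": np.array([[0.1, -0.3], [9.0, 0.0]]),
     "valid_mask": np.array([[True, True], [False, True]])},
    None,
]
limit = shared_symmetric_limit(maps)
print(limit)  # 0.3
```

The masked-out value 9.0 is ignored, so a stray value outside the pupil cannot wash out the color scale of every panel.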

Parameters:

  • save_name (str, default './wavefront_error.png'): Filename for the saved figure; must end with '.png'.
  • num_fov (int, default 5): Number of field positions to evaluate.
  • depth (float, default DEPTH): Object distance [mm].
  • wvln (float, default DEFAULT_WAVE): Wavelength [µm].
  • num_rays (int, default SPP_COHERENT): Number of rays to sample per field position.
  • ks (int, default 256): Grid resolution for each OPD map.
  • show (bool, default False): If True, display the figure interactively instead of saving it.

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def draw_wavefront_error(
    self,
    save_name="./wavefront_error.png",
    num_fov=5,
    depth=DEPTH,
    wvln=DEFAULT_WAVE,
    num_rays=SPP_COHERENT,
    ks=256,
    show=False,
):
    """Draw wavefront error (OPD) maps at multiple field positions.

    Evaluates the wavefront error along the meridional (y) direction from
    on-axis to full field, and displays each OPD map with RMS and PV
    annotations.

    Args:
        save_name (str): Filename to save the figure.
        num_fov (int): Number of field positions to evaluate.
        depth (float): Object distance [mm].
        wvln (float): Wavelength [µm].
        num_rays (int): Number of rays to sample per field position.
        ks (int): Grid resolution for each OPD map.
        show (bool): If True, display the figure interactively.
    """
    fov_list = torch.linspace(0, 1, num_fov).tolist()

    fig, axs = plt.subplots(1, num_fov, figsize=(num_fov * 3.5, 3.5))
    axs = np.atleast_1d(axs)

    # Collect all OPD ranges to use a shared color scale
    results = []
    vmax = 0.0
    for fov in fov_list:
        try:
            result = self.wavefront_error(
                relative_fov=fov,
                depth=depth,
                wvln=wvln,
                num_rays=num_rays,
                ks=ks,
            )
            results.append(result)
            opd_valid = result["opd_map"][result["valid_mask"]]
            if len(opd_valid) > 0:
                vmax = max(vmax, opd_valid.abs().max().item())
        except RuntimeError:
            results.append(None)

    if vmax == 0:
        vmax = 1.0  # fallback

    for i, (fov, result) in enumerate(zip(fov_list, results)):
        if result is None:
            axs[i].set_title(f"FoV={fov:.2f}\n(vignetted)", fontsize=8)
            axs[i].axis("off")
            continue

        opd = result["opd_map"].cpu().numpy()
        mask = result["valid_mask"].cpu().numpy()
        rms = result["rms"]
        pv = result["pv"]

        # Mask invalid regions with NaN for visualization
        opd_vis = np.where(mask, opd, np.nan)

        im = axs[i].imshow(
            opd_vis,
            cmap="RdBu_r",
            vmin=-vmax,
            vmax=vmax,
            interpolation="bilinear",
        )
        axs[i].set_title(
            f"FoV={fov:.2f}\nRMS={rms:.4f}λ  PV={pv:.3f}λ",
            fontsize=8,
        )
        axs[i].axis("off")
        fig.colorbar(
            im,
            ax=axs[i],
            fraction=0.046,
            pad=0.04,
            label="OPD [waves]",
        )

    fig.suptitle(
        f"Wavefront Error (λ={wvln}µm, depth={depth}mm)", fontsize=10
    )
    plt.tight_layout()

    if show:
        plt.show()
    else:
        assert save_name.endswith(".png"), "save_name must end with .png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

field_curvature

field_curvature()

Compute field curvature data (best-focus defocus vs field angle).

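Field curvature is the axial shift of the best-focus surface away from the flat sensor plane as a function of field angle. Its Petzval component is set by the Petzval sum of the lens surface curvatures and refractive indices. Astigmatism then separates the tangential and sagittal image surfaces from the Petzval surface.

Not yet implemented: the method body is a stub and currently returns None. See draw_field_curvature() for a plotting version that already performs the underlying computation.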
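For reference, the Petzval sum is conventionally computed surface by surface as Σ (n' - n) / (n·n'·R), where n and n' are the refractive indices before and after a surface of radius R. The helper below is illustrative and not part of GeoLens. For a thin lens in air the sum reduces to 1/(n·f), which the example checks.

```python
import math


def petzval_sum(surfaces):
    """Sum of (n' - n) / (n * n' * R) over refracting surfaces.

    surfaces: iterable of (R_mm, n_before, n_after); use math.inf for a flat surface.
    """
    total = 0.0
    for radius, n, n_prime in surfaces:
        if math.isinf(radius):
            continue  # flat surfaces add no Petzval curvature
        total += (n_prime - n) / (n * n_prime * radius)
    return total


# Equiconvex thin singlet in air: R1 = 50 mm, R2 = -50 mm, n = 1.5 (f = 50 mm)
n = 1.5
p = petzval_sum([(50.0, 1.0, n), (-50.0, n, 1.0)])
f = 1.0 / ((n - 1.0) * (1.0 / 50.0 - 1.0 / -50.0))
print(round(p, 6), round(1.0 / (n * f), 6))  # 0.013333 0.013333
print(round(1.0 / p, 1))  # |Petzval radius| in air: 75.0 mm
```

Sign conventions for the resulting Petzval radius vary between texts, so only its magnitude (n·f = 75 mm here) is printed.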

Source code in src/geolens_pkg/eval.py
def field_curvature(self):
    """Compute field curvature data (best-focus defocus vs field angle).

    Field curvature is the axial shift of the best-focus surface away from
    the flat sensor plane as a function of field angle.  It is caused by
    the Petzval sum of lens surface curvatures and refractive indices.

    Not yet implemented.  See ``draw_field_curvature()`` for a plotting
    version that already performs the underlying computation.
    """
    pass

calc_chief_ray

calc_chief_ray(fov, plane='sagittal')

Find the chief ray for a given field angle using 2-D ray tracing.

The chief ray (also called the principal ray) is the ray from an off-axis object point that passes through the center of the aperture stop. It defines the image height for distortion calculations and sets the reference axis for coma and lateral color analysis.

Algorithm
  1. Sample a fan of parallel rays at the specified fov in the chosen plane, entering through the entrance pupil.
  2. Trace the fan up to (but not through) the aperture stop.
  3. Select the ray whose transverse position at the stop is closest to the optical axis — this is the chief ray.
  4. Return its incident (object-space) origin and direction.
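Steps 1 to 3 can be illustrated with a paraxial toy model in NumPy. A parallel fan at angle fov crosses the first surface at heights h and reaches a stop located d mm behind it at height h + d·tan(fov). To keep the example short, it assumes no refraction before the stop. The chief ray is the fan member with the smallest |height| at the stop, measured along the fan's own axis. The names are hypothetical.

```python
import numpy as np


def pick_chief_ray(stop_positions, plane="sagittal"):
    """Index of the ray closest to the optical axis at the stop."""
    # Measure along the fan's own direction: x (sagittal) or y (meridional)
    axis = 0 if plane == "sagittal" else 1
    return int(np.argmin(np.abs(stop_positions[:, axis])))


# Meridional fan of parallel rays at 10 degrees; the stop sits d = 8 mm behind
# the first surface, and refraction before the stop is ignored (toy assumption)
fov, d = np.radians(10.0), 8.0
h = np.linspace(-3.0, 3.0, 601)  # launch heights (y) at the first surface
stop = np.zeros((h.size, 3))
stop[:, 1] = h + d * np.tan(fov)  # straight-line propagation to the stop
stop[:, 2] = d

idx = pick_chief_ray(stop, plane="meridional")
print(h[idx], -d * np.tan(fov))
```

With the x-axis criterion, a meridional fan would show x = 0 for every ray and return index 0. This is why the measured coordinate has to follow the fan's plane.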

Parameters:

  • fov (float, required): Incident half-angle in degrees.
  • plane (str, default 'sagittal'): 'sagittal' (x-axis) or 'meridional' (y-axis).

Returns:

tuple[torch.Tensor, torch.Tensor]:
  • chief_ray_o: Origin of the chief ray in object space, shape [3].
  • chief_ray_d: Unit direction of the chief ray, shape [3].

Note

This is a 2-D (meridional or sagittal plane) search. For a full 3-D chief ray, one would shrink the pupil and trace the centroid ray.

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def calc_chief_ray(self, fov, plane="sagittal"):
    """Find the chief ray for a given field angle using 2-D ray tracing.

    The *chief ray* (also called the *principal ray*) is the ray from an
    off-axis object point that passes through the center of the aperture
    stop.  It defines the image height for distortion calculations and sets
    the reference axis for coma and lateral color analysis.

    Algorithm:
        1. Sample a fan of parallel rays at the specified ``fov`` in the
           chosen plane, entering through the entrance pupil.
        2. Trace the fan up to (but not through) the aperture stop.
        3. Select the ray whose transverse position at the stop is closest
           to the optical axis — this is the chief ray.
        4. Return its *incident* (object-space) origin and direction.

    Args:
        fov (float): Incident half-angle in **degrees**.
        plane (str): ``'sagittal'`` (x-axis) or ``'meridional'`` (y-axis).
            Defaults to ``'sagittal'``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - **chief_ray_o**: Origin of the chief ray in object space,
              shape ``[3]``.
            - **chief_ray_d**: Unit direction of the chief ray, shape ``[3]``.

    Note:
        This is a 2-D (meridional or sagittal plane) search.  For a full
        3-D chief ray, one would shrink the pupil and trace the centroid ray.
    """
    # Sample parallel rays from object space
    ray = self.sample_parallel_2D(
        fov=fov, num_rays=SPP_CALC, entrance_pupil=True, plane=plane
    )
    inc_ray = ray.clone()

    # Trace to the aperture
    surf_range = range(0, self.aper_idx)
    ray, _ = self.trace(ray, surf_range=surf_range)

    # Look for the ray closest to the optical axis at the stop:
    # x for the sagittal fan, y for the meridional fan
    axis = 0 if plane == "sagittal" else 1
    center_idx = torch.argmin(torch.abs(ray.o[:, axis])).item()
    chief_ray_o, chief_ray_d = inc_ray.o[center_idx, :], inc_ray.d[center_idx, :]

    return chief_ray_o, chief_ray_d
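The selection step can be illustrated with a toy paraxial system. The sketch below is an assumption-laden stand-in, not the library's tracer: a hypothetical thin lens of focal length f sits at z = 0 with the stop at z = d, and the helpers trace_to_stop and find_chief_ray_height are invented for illustration. It samples a fan of parallel rays, propagates each one to the stop, and keeps the ray with the smallest transverse height, which is the same argmin that calc_chief_ray performs on real surfaces.

```python
import math


def trace_to_stop(h, theta, f, d):
    """Paraxial height at the stop of a ray entering the toy system at height h, angle theta."""
    u_after = theta - h / f  # thin-lens refraction at z = 0
    return h + d * u_after   # free-space transfer to the stop at z = d


def find_chief_ray_height(theta, f, d, pupil_radius, num_rays=1001):
    """Entry height of the fan ray that crosses the stop closest to the axis."""
    heights = [-pupil_radius + 2.0 * pupil_radius * i / (num_rays - 1) for i in range(num_rays)]
    at_stop = [abs(trace_to_stop(h, theta, f, d)) for h in heights]
    best = min(range(num_rays), key=at_stop.__getitem__)  # argmin over the fan
    return heights[best]


theta = math.radians(5.0)
h_chief = find_chief_ray_height(theta, f=50.0, d=10.0, pupil_radius=5.0)
h_exact = -10.0 * theta / (1.0 - 10.0 / 50.0)  # closed form for this toy system
```

Because the fan is discrete, the answer is only as accurate as the fan spacing (0.01 mm here). The batched calc_chief_ray_infinite narrows its fan around a paraxial estimate, which presumably serves the same purpose.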

calc_chief_ray_infinite

calc_chief_ray_infinite(rfov, depth=0.0, wvln=DEFAULT_WAVE, plane='meridional', num_rays=SPP_CALC, ray_aiming=True)

Compute chief rays for one or more field angles with optional ray aiming.

This is the batched, production version of calc_chief_ray. It evaluates multiple field angles in one vectorized pass and implements ray aiming: it launches a narrow fan of parallel rays at each field angle toward the entrance pupil and selects the one that passes closest to the aperture-stop center. Ray aiming is essential for accurate distortion measurement in wide-angle or fisheye lenses, where the paraxial approximation breaks down.

Algorithm
  1. For on-axis (rfov = 0): chief ray is trivially along the z-axis.
  2. For off-axis angles with ray_aiming=False: the chief ray is aimed at the entrance pupil center (paraxial approximation).
  3. For off-axis angles with ray_aiming=True:
     a. Estimate the object-space y (or x) position from the entrance pupil geometry.
     b. Create a narrow fan of num_rays parallel rays bracketing that estimate (half-width = 5 % of y_distance, but never less than 0.05 * pupil_radius).
     c. Trace the fan to the aperture stop.
     d. Pick the ray closest to the optical axis at the stop.

Parameters:

Name Type Description Default
rfov float | Tensor

Field angle(s) in degrees. A scalar is converted to [0, rfov] (two-element tensor). A tensor of shape [N] is used directly.

required
depth float | Tensor

Object depth(s) in mm. Defaults to 0.0 (object at the first surface).

0.0
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
plane str

'sagittal' or 'meridional'. Defaults to 'meridional'.

'meridional'
num_rays int

Size of the search fan for ray aiming. Defaults to SPP_CALC.

SPP_CALC
ray_aiming bool

If True, perform fan-search ray aiming for accurate chief-ray identification; if False, aim at the entrance pupil center (paraxial approximation). Defaults to True.

True

Returns:

Type Description

tuple[torch.Tensor, torch.Tensor]:
- chief_ray_o: Origins, shape [N, 3].
- chief_ray_d: Unit directions, shape [N, 3].

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def calc_chief_ray_infinite(
    self,
    rfov,
    depth=0.0,
    wvln=DEFAULT_WAVE,
    plane="meridional",
    num_rays=SPP_CALC,
    ray_aiming=True,
):
    """Compute chief rays for one or more field angles with optional ray aiming.

    This is the batched, production version of ``calc_chief_ray``.  It
    evaluates multiple field angles in one vectorized pass and implements
    *ray aiming*: it launches a narrow fan of parallel rays at each field
    angle toward the entrance pupil and selects the one that passes
    closest to the aperture-stop center.  Ray aiming is essential for
    accurate distortion measurement in wide-angle or fisheye lenses, where
    the paraxial approximation breaks down.

    Algorithm:
        1. For on-axis (``rfov = 0``): chief ray is trivially along the
           z-axis.
        2. For off-axis angles with ``ray_aiming=False``: the chief ray is
           aimed at the entrance pupil center (paraxial approximation).
        3. For off-axis angles with ``ray_aiming=True``:
           a. Estimate the object-space y (or x) position from the entrance
              pupil geometry.
           b. Create a narrow fan of ``num_rays`` parallel rays bracketing
              that estimate (half-width = 5 % of y_distance, but never less
              than ``0.05 * pupil_radius``).
           c. Trace the fan to the aperture stop.
           d. Pick the ray closest to the optical axis at the stop.

    Args:
        rfov (float | torch.Tensor): Field angle(s) in **degrees**.
            A scalar is converted to ``[0, rfov]`` (two-element tensor).
            A tensor of shape ``[N]`` is used directly.
        depth (float | torch.Tensor): Object depth(s) in mm.
            Defaults to 0.0 (object at the first surface).
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.
        plane (str): ``'sagittal'`` or ``'meridional'``.
            Defaults to ``'meridional'``.
        num_rays (int): Size of the search fan for ray aiming.
            Defaults to ``SPP_CALC``.
        ray_aiming (bool): If ``True``, perform fan-search ray aiming for
            accurate chief-ray identification. Defaults to ``True``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - **chief_ray_o**: Origins, shape ``[N, 3]``.
            - **chief_ray_d**: Unit directions, shape ``[N, 3]``.
    """
    if not torch.is_tensor(rfov):
        # A scalar field angle becomes [0, rfov], as documented
        rfov = torch.linspace(0, float(rfov), 2)
    rfov = rfov.to(self.device)

    if not isinstance(depth, torch.Tensor):
        depth = torch.tensor(depth, device=self.device).repeat(len(rfov))

    # set chief ray
    chief_ray_o = torch.zeros([len(rfov), 3]).to(self.device)
    chief_ray_d = torch.zeros([len(rfov), 3]).to(self.device)

    # Convert rfov to radian
    rfov = rfov * torch.pi / 180.0

    if torch.any(rfov == 0):
        chief_ray_o[0, ...] = torch.tensor(
            [0.0, 0.0, depth[0]], device=self.device, dtype=torch.float32
        )
        chief_ray_d[0, ...] = torch.tensor(
            [0.0, 0.0, 1.0], device=self.device, dtype=torch.float32
        )
        if len(rfov) == 1:
            return chief_ray_o, chief_ray_d

    # Extract non-zero rfov entries for processing
    has_zero = torch.any(rfov == 0)
    if has_zero:
        start_idx = 1
        rfovs = rfov[1:]
        depths = depth[1:]
    else:
        start_idx = 0
        rfovs = rfov
        depths = depth

    if self.aper_idx == 0:
        if plane == "sagittal":
            chief_ray_o[start_idx:, ...] = torch.stack(
                [depths * torch.tan(rfovs), torch.zeros_like(rfovs), depths], dim=-1
            )
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.sin(rfovs), torch.zeros_like(rfovs), torch.cos(rfovs)],
                dim=-1,
            )
        else:
            chief_ray_o[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), depths * torch.tan(rfovs), depths], dim=-1
            )
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), torch.sin(rfovs), torch.cos(rfovs)],
                dim=-1,
            )

        return chief_ray_o, chief_ray_d

    # Paraxial estimate of the chief-ray offset at the object plane
    pupilz, pupilr = self.calc_entrance_pupil()
    y_distance = torch.tan(rfovs) * (abs(depths) + pupilz)

    if ray_aiming:
        scale = 0.05
        min_delta = 0.05 * pupilr  # minimum search range based on pupil radius
        delta = torch.clamp(scale * y_distance, min=min_delta)

    if not ray_aiming:
        if plane == "sagittal":
            chief_ray_o[start_idx:, ...] = torch.stack(
                [-y_distance, torch.zeros_like(rfovs), depths], dim=-1
            )
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.sin(rfovs), torch.zeros_like(rfovs), torch.cos(rfovs)],
                dim=-1,
            )
        else:
            chief_ray_o[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), -y_distance, depths], dim=-1
            )
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), torch.sin(rfovs), torch.cos(rfovs)],
                dim=-1,
            )

    else:
        min_y = -y_distance - delta
        max_y = -y_distance + delta
        t = torch.linspace(0, 1, num_rays, device=min_y.device)
        o1_linspace = min_y.unsqueeze(-1) + t * (max_y - min_y).unsqueeze(-1)

        o1 = torch.zeros([len(rfovs), num_rays, 3], device=self.device)
        o1[:, :, 2] = depths.unsqueeze(-1)  # per-field object depth

        o2_linspace = -delta.unsqueeze(-1) + t * (2 * delta).unsqueeze(-1)

        o2 = torch.zeros([len(rfovs), num_rays, 3], device=self.device)
        o2[:, :, 2] = pupilz

        if plane == "sagittal":
            o1[:, :, 0] = o1_linspace
            o2[:, :, 0] = o2_linspace
        else:
            o1[:, :, 1] = o1_linspace
            o2[:, :, 1] = o2_linspace

        # Trace until the aperture
        ray = Ray(o1, o2 - o1, wvln=wvln, device=self.device)
        inc_ray = ray.clone()
        surf_range = range(0, self.aper_idx + 1)
        ray, _ = self.trace(ray, surf_range=surf_range)

        # Look for the ray that is closest to the optical axis
        if plane == "sagittal":
            _, center_idx = torch.min(torch.abs(ray.o[..., 0]), dim=1)
            chief_ray_o[start_idx:, ...] = inc_ray.o[
                torch.arange(len(rfovs)), center_idx.long(), ...
            ]
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.sin(rfovs), torch.zeros_like(rfovs), torch.cos(rfovs)],
                dim=-1,
            )
        else:
            _, center_idx = torch.min(torch.abs(ray.o[..., 1]), dim=1)
            chief_ray_o[start_idx:, ...] = inc_ray.o[
                torch.arange(len(rfovs)), center_idx.long(), ...
            ]
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), torch.sin(rfovs), torch.cos(rfovs)],
                dim=-1,
            )

    return chief_ray_o, chief_ray_d
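The fan geometry in the ray_aiming=True branch can be checked in isolation. The pure-Python sketch below re-derives, for a single meridional field angle, the fan start points (min_y to max_y on the plane z = depth) and the aim points on the entrance-pupil plane. aiming_fan is a hypothetical helper, and it assumes the object lies at depth <= 0, as the abs(depths) term in the source suggests. Every start/aim pair is offset by exactly y_distance, so all fan rays are parallel at the field angle, and the fan's center ray crosses the axis on the pupil plane; that center ray is also the ray used when ray_aiming=False.

```python
import math


def aiming_fan(rfov_deg, depth, pupil_z, pupil_r, num_rays=5, scale=0.05):
    """Rebuild the ray-aiming search fan for one field angle in the y-z plane.

    Follows the arithmetic of the ray_aiming=True branch and returns
    (origins, directions) as lists of (y, z) tuples. Assumes depth <= 0.
    """
    rfov = math.radians(rfov_deg)
    y_distance = math.tan(rfov) * (abs(depth) + pupil_z)
    delta = max(scale * y_distance, 0.05 * pupil_r)  # half-width, lower-bounded
    origins, directions = [], []
    for i in range(num_rays):
        t = i / (num_rays - 1)
        y_start = -y_distance - delta + t * 2.0 * delta  # on the plane z = depth
        y_aim = -delta + t * 2.0 * delta                 # on the pupil plane z = pupil_z
        dy, dz = y_aim - y_start, pupil_z - depth
        norm = math.hypot(dy, dz)
        origins.append((y_start, depth))
        directions.append((dy / norm, dz / norm))
    return origins, directions


origins, directions = aiming_fan(rfov_deg=30.0, depth=-1000.0, pupil_z=5.0, pupil_r=2.0)
```

Because the fan rays are parallel, the chief-ray direction can be written directly as (0, sin(rfov), cos(rfov)) and only the origin has to come from the search, which is what the source does.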

analysis_rendering

analysis_rendering(img_org, save_name=None, depth=DEPTH, spp=SPP_RENDER, unwarp=False, method='ray_tracing', show=False)

Render a test image through the lens and report PSNR / SSIM.

Simulates what the sensor would capture if the given image were placed at the specified object distance. The rendering accounts for all geometric aberrations (blur, distortion, vignetting, chromatic effects). Optionally applies an inverse distortion warp (unwarp) and reports quality metrics for both the raw and unwarped renderings.

Algorithm
  1. Convert img_org to a [1, 3, H, W] float tensor and temporarily set the sensor resolution to match.
  2. Call self.render() with the chosen method (ray tracing or PSF convolution).
  3. Compute PSNR and SSIM between the original and rendered images.
  4. If unwarp=True, apply self.unwarp() to correct geometric distortion and report metrics again.
  5. Restore the original sensor resolution.

Parameters:

Name Type Description Default
img_org ndarray | Tensor

Source image with shape [H, W, 3], either uint8 [0, 255] or float [0, 1].

required
save_name str | None

Path prefix for saved PNGs. If not None, saves '{save_name}.png' and (if unwarped) '{save_name}_unwarped.png'. Defaults to None.

None
depth float

Object distance in mm. Defaults to DEPTH.

DEPTH
spp int

Samples (rays) per pixel for rendering. Defaults to SPP_RENDER.

SPP_RENDER
unwarp bool

If True, apply distortion correction after rendering. Defaults to False.

False
method str

Rendering backend — 'ray_tracing' or 'psf_conv'. Defaults to 'ray_tracing'.

'ray_tracing'
show bool

If True, display the result with matplotlib. Defaults to False.

False

Returns:

Type Description

torch.Tensor: Rendered (and optionally unwarped) image with shape [1, 3, H, W], float values in [0, 1].

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def analysis_rendering(
    self,
    img_org,
    save_name=None,
    depth=DEPTH,
    spp=SPP_RENDER,
    unwarp=False,
    method="ray_tracing",
    show=False,
):
    """Render a test image through the lens and report PSNR / SSIM.

    Simulates what the sensor would capture if the given image were placed
    at the specified object distance.  The rendering accounts for all
    geometric aberrations (blur, distortion, vignetting, chromatic effects).
    Optionally applies an inverse distortion warp (``unwarp``) and reports
    quality metrics for both the raw and unwarped renderings.

    Algorithm:
        1. Convert ``img_org`` to a ``[1, 3, H, W]`` float tensor and
           temporarily set the sensor resolution to match.
        2. Call ``self.render()`` with the chosen method (ray tracing or PSF
           convolution).
        3. Compute PSNR and SSIM between the original and rendered images.
        4. If ``unwarp=True``, apply ``self.unwarp()`` to correct geometric
           distortion and report metrics again.
        5. Restore the original sensor resolution.

    Args:
        img_org (np.ndarray | torch.Tensor): Source image with shape
            ``[H, W, 3]``, either uint8 ``[0, 255]`` or float ``[0, 1]``.
        save_name (str | None): Path prefix for saved PNGs.  If not
            ``None``, saves ``'{save_name}.png'`` and (if unwarped)
            ``'{save_name}_unwarped.png'``. Defaults to ``None``.
        depth (float): Object distance in mm. Defaults to ``DEPTH``.
        spp (int): Samples (rays) per pixel for rendering.
            Defaults to ``SPP_RENDER``.
        unwarp (bool): If ``True``, apply distortion correction after
            rendering. Defaults to ``False``.
        method (str): Rendering backend — ``'ray_tracing'`` or
            ``'psf_conv'``. Defaults to ``'ray_tracing'``.
        show (bool): If ``True``, display the result with matplotlib.
            Defaults to ``False``.

    Returns:
        torch.Tensor: Rendered (and optionally unwarped) image with shape
            ``[1, 3, H, W]``, float values in ``[0, 1]``.
    """
    # Change sensor resolution to match the image
    sensor_res_original = self.sensor_res
    if isinstance(img_org, np.ndarray):
        img = torch.from_numpy(img_org)
    elif torch.is_tensor(img_org):
        img = img_org
    else:
        raise TypeError("img_org must be a numpy.ndarray or a torch.Tensor")
    # uint8 input is in [0, 255]; float input is expected in [0, 1]
    is_uint8 = img.dtype == torch.uint8
    img = img.permute(2, 0, 1).unsqueeze(0).float()
    if is_uint8 or img.max() > 1.0:
        img = img / 255.0
    img = img.to(self.device)
    self.set_sensor_res(sensor_res=img.shape[-2:])

    # Image rendering
    img_render = self.render(img, depth=depth, method=method, spp=spp)

    # Compute PSNR and SSIM
    img_np = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    render_np = img_render.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().detach().numpy()
    render_psnr = round(peak_signal_noise_ratio(img_np, render_np, data_range=1.0), 3)
    render_ssim = round(structural_similarity(img_np, render_np, channel_axis=2, data_range=1.0), 4)
    print(f"Rendered image: PSNR={render_psnr:.3f}, SSIM={render_ssim:.4f}")

    # Save image
    if save_name is not None:
        save_image(img_render, f"{save_name}.png")

    # Unwarp to correct geometry distortion
    if unwarp:
        img_render = self.unwarp(img_render, depth)

        # Compute PSNR and SSIM
        render_np = img_render.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().detach().numpy()
        render_psnr = round(peak_signal_noise_ratio(img_np, render_np, data_range=1.0), 3)
        render_ssim = round(structural_similarity(img_np, render_np, channel_axis=2, data_range=1.0), 4)
        print(
            f"Rendered image (unwarped): PSNR={render_psnr:.3f}, SSIM={render_ssim:.4f}"
        )

        if save_name is not None:
            save_image(img_render, f"{save_name}_unwarped.png")

    # Change the sensor resolution back
    self.set_sensor_res(sensor_res=sensor_res_original)

    # Show image
    if show:
        plt.imshow(img_render.cpu().squeeze(0).permute(1, 2, 0).numpy())
        plt.title("Rendered image")
        plt.axis("off")
        plt.show()
        plt.close()

    return img_render

analysis_spot

analysis_spot(num_field=3, depth=float('inf'))

Compute RMS and geometric spot radii at multiple field positions for RGB.

Traces rays at num_field evenly spaced field positions along the meridional direction for three wavelengths (G, R, B), computes per-wavelength RMS and maximum (geometric) spot radii referenced to the green centroid, then averages the three wavelengths.

This provides a quick polychromatic spot-size summary for comparing designs. The results are printed to stdout, and analysis() calls this method on every run.

Algorithm
  1. For each wavelength (G first, then R, B):
     a. self.sample_radial_rays() → [num_field, SPP_PSF, 3].
     b. self.trace2sensor() → sensor-plane positions.
     c. Green centroid c_G is computed on the first iteration and used as the common reference for all wavelengths.
     d. RMS = sqrt(mean(||xy - c_G||^2)) per field position, averaged over valid rays.
     e. radius = max(||xy - c_G||) per field position.
  2. Average RMS and radius over the three wavelengths.
  3. Convert from mm to μm (× 1000).
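The per-field arithmetic of steps 1c through 3 can be sketched without the tracer. In the standard-library sketch below, the sensor hits are hand-made (x_mm, y_mm, is_valid) tuples for a single field rather than traced rays, EPSILON is an assumed small constant, and the helpers are hypothetical. Invalid rays are simply skipped, which matches the masked sums in the source whenever at least one ray is valid; the sketch also assumes ray.centroid() averages the valid rays.

```python
import math

EPSILON = 1e-9  # assumed small guard, standing in for the source's EPSILON constant


def centroid(points):
    """Mean (x, y) of the valid points; points are (x_mm, y_mm, is_valid) tuples."""
    valid = [(x, y) for x, y, ok in points if ok]
    return (sum(x for x, _ in valid) / len(valid), sum(y for _, y in valid) / len(valid))


def spot_stats(points, ref):
    """RMS and maximum radius (mm) of the valid points about a reference point."""
    r2 = [(x - ref[0]) ** 2 + (y - ref[1]) ** 2 for x, y, ok in points if ok]
    rms = math.sqrt(sum(r2) / (len(r2) + EPSILON))
    radius = math.sqrt(max(r2)) if r2 else 0.0
    return rms, radius


def polychromatic_spot_um(spots):
    """spots maps 'G', 'R', 'B' to the sensor hits of one field; returns (rms_um, radius_um)."""
    ref = centroid(spots["G"])  # the green centroid is the common reference
    stats = [spot_stats(spots[w], ref) for w in ("G", "R", "B")]
    rms_um = sum(s[0] for s in stats) / len(stats) * 1000.0  # mm -> um
    radius_um = sum(s[1] for s in stats) / len(stats) * 1000.0
    return rms_um, radius_um


spots = {
    "G": [(0.001, 0.0, True), (-0.001, 0.0, True), (5.0, 5.0, False)],
    "R": [(0.002, 0.0, True), (0.0, 0.002, True)],
    "B": [(0.003, 0.0, True)],
}
rms_um, radius_um = polychromatic_spot_um(spots)
```

Referencing all three colors to the green centroid means lateral color shows up as a larger R and B spot rather than being hidden by per-color recentering.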

Parameters:

Name Type Description Default
num_field int

Number of field positions sampled from on-axis to full-field. Defaults to 3.

3
depth float

Object distance in mm. Use float('inf') for collimated light. Defaults to float('inf').

float('inf')

Returns:

Type Description

dict[str, dict[str, float]]: Spot analysis results keyed by field position string (e.g., 'fov0.0', 'fov0.5', 'fov1.0'). Each value is a dict with:
- 'rms': Polychromatic RMS spot radius in μm.
- 'radius': Polychromatic geometric spot radius in μm.
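The keys are built from rounded, evenly spaced field positions, so they can be predicted from num_field. The hypothetical helper below reproduces that naming in plain Python. It agrees with the source for small counts such as 3 or 5, but the source computes positions with torch.linspace in float32, so rounding could in principle differ for other counts.

```python
def spot_result_keys(num_field):
    """Keys analysis_spot would produce for num_field evenly spaced fields."""
    if num_field == 1:
        return ["fov0.0"]  # a one-point linspace(0, 1, 1) holds only 0.0
    return [f"fov{round(i / (num_field - 1), 2)}" for i in range(num_field)]


print(spot_result_keys(3))  # ['fov0.0', 'fov0.5', 'fov1.0']
```

For example, the full-field RMS radius of a default run is read as results["fov1.0"]["rms"].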

Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def analysis_spot(self, num_field=3, depth=float("inf")):
    """Compute RMS and geometric spot radii at multiple field positions for RGB.

    Traces rays at ``num_field`` evenly-spaced field positions along the
    meridional direction for three wavelengths (G, R, B), computes per-
    wavelength RMS and maximum (geometric) spot radii referenced to the
    **green centroid**, then averages the three wavelengths.

    This provides a quick polychromatic spot-size summary used for design
    comparisons and printed to stdout during ``analysis()``.

    Algorithm:
        1. For each wavelength (G first, then R, B):
           a. ``self.sample_radial_rays()`` → ``[num_field, SPP_PSF, 3]``.
           b. ``self.trace2sensor()`` → sensor-plane positions.
           c. Green centroid ``c_G`` is computed on the first iteration and
              used as the common reference for all wavelengths.
           d. ``RMS = sqrt(mean(||xy - c_G||^2))`` per field position.
           e. ``radius = max(||xy - c_G||)`` per field position.
        2. Average RMS and radius over the three wavelengths.
        3. Convert from mm to μm (× 1000).

    Args:
        num_field (int): Number of field positions sampled from on-axis
            to full-field. Defaults to 3.
        depth (float): Object distance in mm.  Use ``float('inf')`` for
            collimated light. Defaults to ``float('inf')``.

    Returns:
        dict[str, dict[str, float]]: Spot analysis results keyed by field
            position string (e.g., ``'fov0.0'``, ``'fov0.5'``, ``'fov1.0'``).
            Each value is a dict with:
                - ``'rms'``: Polychromatic RMS spot radius in μm.
                - ``'radius'``: Polychromatic geometric spot radius in μm.
    """
    rms_radius_fields = []
    geo_radius_fields = []
    for i, wvln in enumerate([WAVE_RGB[1], WAVE_RGB[0], WAVE_RGB[2]]):
        # Sample rays along meridional (y) direction, shape [num_field, num_rays, 3]
        ray = self.sample_radial_rays(
            num_field=num_field, depth=depth, num_rays=SPP_PSF, wvln=wvln
        )
        ray = self.trace2sensor(ray)

        # Green light point center for reference, shape [num_field, 1, 2]
        if i == 0:
            ray_xy_center_green = ray.centroid()[..., :2].unsqueeze(-2)

        # Calculate RMS spot size and radius for different FoVs
        ray_xy_norm = (
            ray.o[..., :2] - ray_xy_center_green
        ) * ray.is_valid.unsqueeze(-1)
        spot_rms = (
            ((ray_xy_norm**2).sum(-1) * ray.is_valid).sum(-1)
            / (ray.is_valid.sum(-1) + EPSILON)
        ).sqrt()
        spot_radius = (ray_xy_norm**2).sum(-1).sqrt().max(dim=-1).values

        # Append to list
        rms_radius_fields.append(spot_rms)
        geo_radius_fields.append(spot_radius)

    # Average over wavelengths, shape [num_field]
    avg_rms_radius_um = torch.stack(rms_radius_fields, dim=0).mean(dim=0) * 1000.0
    avg_geo_radius_um = torch.stack(geo_radius_fields, dim=0).mean(dim=0) * 1000.0

    # Print results
    print(f"Ray spot analysis results for depth {depth}:")
    print(
        f"RMS radius: FoV (0.0) {avg_rms_radius_um[0]:.3f} um, FoV (0.5) {avg_rms_radius_um[num_field // 2]:.3f} um, FoV (1.0) {avg_rms_radius_um[-1]:.3f} um"
    )
    print(
        f"Geo radius: FoV (0.0) {avg_geo_radius_um[0]:.3f} um, FoV (0.5) {avg_geo_radius_um[num_field // 2]:.3f} um, FoV (1.0) {avg_geo_radius_um[-1]:.3f} um"
    )

    # Save to dict
    rms_results = {}
    fov_ls = torch.linspace(0, 1, num_field)
    for i in range(num_field):
        fov = round(fov_ls[i].item(), 2)
        rms_results[f"fov{fov}"] = {
            "rms": round(avg_rms_radius_um[i].item(), 4),
            "radius": round(avg_geo_radius_um[i].item(), 4),
        }

    return rms_results

analysis

analysis(save_name='./lens', depth=float('inf'), full_eval=False, render=False, render_unwarp=False, lens_title=None, show=False)

Run a comprehensive optical analysis pipeline for the lens.

This is the main entry point for evaluating a lens design. It chains multiple evaluation steps in order, saving all plots with a common save_name prefix.

Execution flow
  1. Always: draw the lens layout (draw_layout) and compute polychromatic spot RMS/radius (analysis_spot).
  2. If full_eval=True: additionally generate:
     - Spot diagram (draw_spot_radial).
     - MTF grid (draw_mtf).
     - Distortion curve (draw_distortion_radial).
     - Field curvature plot (draw_field_curvature).
     - Vignetting map (draw_vignetting).
  3. If render=True: render a test chart image through the lens and report PSNR/SSIM (analysis_rendering).
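The files written by one call follow directly from the flags. The sketch below (expected_outputs is a hypothetical helper) lists the filenames analysis() hands to each plotting or rendering call, assuming each draw_* method writes to the filename it receives; analysis_spot only prints and writes nothing.

```python
def expected_outputs(save_name="./lens", full_eval=False, render=False, render_unwarp=False):
    """List the file names analysis() passes to its plotting and rendering helpers."""
    files = [f"{save_name}.png"]  # draw_layout
    if full_eval:
        files += [
            f"{save_name}_spot.png",             # draw_spot_radial
            f"{save_name}_mtf.png",              # draw_mtf
            f"{save_name}_distortion.png",       # draw_distortion_radial
            f"{save_name}_field_curvature.png",  # draw_field_curvature
            f"{save_name}_vignetting.png",       # draw_vignetting
        ]
    if render:
        # analysis_rendering receives save_name=f"{save_name}_render"
        files.append(f"{save_name}_render.png")
        if render_unwarp:
            files.append(f"{save_name}_render_unwarped.png")
    return files
```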

Parameters:

Name Type Description Default
save_name str

Path prefix for all output files. Each plot appends a suffix (e.g., '_spot.png', '_mtf.png'). Defaults to './lens'.

'./lens'
depth float

Object distance in mm. float('inf') is replaced by DEPTH for the MTF, vignetting, and rendering steps. Defaults to float('inf').

float('inf')
full_eval bool

If True, run all evaluation plots. If False, only layout + spot RMS. Defaults to False.

False
render bool

If True, render the test chart ./datasets/charts/NBS_1963_1k.png (path relative to the working directory) through the lens. Defaults to False.

False
render_unwarp bool

If True (and render=True), also produce an unwarped rendering. Defaults to False.

False
lens_title str | None

Title string for the layout plot. Defaults to None.

None
show bool

If True, display all plots interactively. Defaults to False.

False
Source code in src/geolens_pkg/eval.py
@torch.no_grad()
def analysis(
    self,
    save_name="./lens",
    depth=float("inf"),
    full_eval=False,
    render=False,
    render_unwarp=False,
    lens_title=None,
    show=False,
):
    """Run a comprehensive optical analysis pipeline for the lens.

    This is the main entry point for evaluating a lens design.  It chains
    multiple evaluation steps in order, saving all plots with a common
    ``save_name`` prefix.

    Execution flow:
        1. **Always**: draw the lens layout (``draw_layout``) and compute
           polychromatic spot RMS/radius (``analysis_spot``).
        2. **If** ``full_eval=True``: additionally generate:
           - Spot diagram (``draw_spot_radial``).
           - MTF grid (``draw_mtf``).
           - Distortion curve (``draw_distortion_radial``).
           - Field curvature plot (``draw_field_curvature``).
           - Vignetting map (``draw_vignetting``).
        3. **If** ``render=True``: render a test chart image through the
           lens and report PSNR/SSIM (``analysis_rendering``).

    Args:
        save_name (str): Path prefix for all output files.  Each plot
            appends a suffix (e.g., ``'_spot.png'``, ``'_mtf.png'``).
            Defaults to ``'./lens'``.
        depth (float): Object distance in mm.  ``float('inf')`` is replaced
            by ``DEPTH`` for the MTF, vignetting, and rendering steps.
            Defaults to ``float('inf')``.
        full_eval (bool): If ``True``, run all evaluation plots.  If
            ``False``, only layout + spot RMS. Defaults to ``False``.
        render (bool): If ``True``, render the test chart
            ``./datasets/charts/NBS_1963_1k.png`` (relative to the working
            directory) through the lens. Defaults to ``False``.
        render_unwarp (bool): If ``True`` (and ``render=True``), also
            produce an unwarped rendering. Defaults to ``False``.
        lens_title (str | None): Title string for the layout plot.
            Defaults to ``None``.
        show (bool): If ``True``, display all plots interactively.
            Defaults to ``False``.
    """
    # Draw lens layout and ray path
    self.draw_layout(
        filename=f"{save_name}.png",
        lens_title=lens_title,
        depth=depth,
        show=show,
    )

    # Calculate RMS error
    self.analysis_spot(depth=depth)

    # Comprehensive optical evaluation
    if full_eval:
        # Draw spot diagram
        self.draw_spot_radial(
            save_name=f"{save_name}_spot.png",
            depth=depth,
            show=show,
        )

        # Draw MTF
        if depth == float("inf"):
            self.draw_mtf(
                depth_list=[DEPTH],
                save_name=f"{save_name}_mtf.png",
                show=show,
            )
        else:
            self.draw_mtf(
                depth_list=[depth],
                save_name=f"{save_name}_mtf.png",
                show=show,
            )

        # Draw distortion
        self.draw_distortion_radial(
            save_name=f"{save_name}_distortion.png",
            show=show,
        )

        # Draw field curvature
        self.draw_field_curvature(
            save_name=f"{save_name}_field_curvature.png",
            show=show,
        )

        # Draw vignetting
        eval_depth = DEPTH if depth == float("inf") else depth
        self.draw_vignetting(
            filename=f"{save_name}_vignetting.png",
            depth=eval_depth,
            show=show,
        )

    # Render an image, compute PSNR and SSIM
    if render:
        depth = DEPTH if depth == float("inf") else depth
        img_org = Image.open("./datasets/charts/NBS_1963_1k.png").convert("RGB")
        img_org = np.array(img_org)
        self.analysis_rendering(
            img_org,
            depth=depth,
            spp=SPP_RENDER,
            unwarp=render_unwarp,
            save_name=f"{save_name}_render",
            show=show,
        )

Spot Diagrams

lens.draw_spot_map()          # Spot diagram grid across the field
lens.draw_spot_radial()       # Spot diagram along radial direction
lens.rms_map()                # RMS spot size map

MTF

lens.draw_mtf()               # Modulation transfer function curves

Distortion

lens.draw_distortion_radial() # Radial distortion curve
lens.draw_distortion_map()    # 2D distortion map

Wavefront Error

lens.draw_wavefront_error()   # Wavefront error map
lens.rms_wavefront_error()    # RMS wavefront error

Full Analysis

lens.analysis(render=True)    # Complete evaluation suite

Seidel Aberrations

src.geolens_pkg.eval_seidel.GeoLensSeidel

Mixin for Seidel (third-order) aberration analysis.

seidel_coefficients

seidel_coefficients(wvln: float = WVLN_d, include_chromatic: bool = True) -> Dict

Compute per-surface Seidel (third-order) aberration coefficients.

Parameters:

Name Type Description Default
wvln float

Reference wavelength in µm (default: d-line 0.5876 µm).

WVLN_d
include_chromatic bool

If True, also compute longitudinal and transverse chromatic aberration (C_L, C_T).

True

Returns:

Type Description
Dict

Dict with keys:
- S1..S5: per-surface lists of Seidel sums [mm]
- CL, CT: per-surface chromatic aberrations [mm]
- labels: surface labels (e.g. ["S1", "S2", ...])
- sums: dict of system totals for each aberration
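The spherical-surface terms in the source listing can be reproduced for one surface in isolation. The sketch assumes the paraxial refraction equation n'u' = n·u − y·c·(n' − n) to obtain u' (the library takes u' from _paraxial_trace instead), and surface_seidel is a hypothetical helper. For a single surface, the Lagrange invariant H computed here equals the first-surface value the source uses. Two consistency checks follow from the formulas: the refraction invariant A is unchanged across the surface, and S2² = S1·S3.

```python
def surface_seidel(n, n_after, c, y, u, ybar, ubar):
    """Seidel sums S1..S5 for one spherical surface, using the formulas of seidel_coefficients."""
    u_after = (n * u - y * c * (n_after - n)) / n_after  # assumed paraxial refraction
    H = n * (ybar * u - y * ubar)    # Lagrange invariant
    A = n * (u + y * c)              # refraction invariant, marginal ray
    Abar = n * (ubar + ybar * c)     # refraction invariant, chief ray
    d_u_over_n = u_after / n_after - u / n
    d_inv_n = 1.0 / n_after - 1.0 / n
    S1 = -A * A * y * d_u_over_n
    S2 = -A * Abar * y * d_u_over_n
    S3 = -Abar * Abar * y * d_u_over_n
    S4 = -H * H * c * d_inv_n
    S5 = (Abar / A) * (S3 + S4) if abs(A) > 1e-12 else 0.0
    return {"S1": S1, "S2": S2, "S3": S3, "S4": S4, "S5": S5, "u_after": u_after, "A": A}


res = surface_seidel(n=1.0, n_after=1.5, c=0.02, y=5.0, u=0.0, ybar=-2.0, ubar=0.1)
```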

Source code in src/geolens_pkg/eval_seidel.py
@torch.no_grad()
def seidel_coefficients(
    self,
    wvln: float = WVLN_d,
    include_chromatic: bool = True,
) -> Dict:
    """Compute per-surface Seidel (third-order) aberration coefficients.

    Args:
        wvln: Reference wavelength in µm (default: d-line 0.5876 µm).
        include_chromatic: If True, also compute longitudinal and
            transverse chromatic aberration (C_L, C_T).

    Returns:
        Dict with keys:
            S1..S5 — per-surface lists of Seidel sums [mm]
            CL, CT — per-surface chromatic aberrations [mm]
            labels — surface labels (e.g. ["S1", "S2", ...])
            sums   — dict of system totals for each aberration
    """
    tr = self._paraxial_trace(wvln)
    y = tr["y"]
    u = tr["u"]
    u_aft = tr["u_after"]
    yb = tr["ybar"]
    ub = tr["ubar"]
    ub_aft = tr["ubar_after"]
    n = tr["n"]
    np_ = tr["np"]
    c = tr["c"]
    surf_indices = tr["surf_indices"]
    num = len(y)

    # Lagrange invariant: H = n * (y_bar * u - y * u_bar)
    # Compute at first surface
    H = n[0] * (yb[0] * u[0] - y[0] * ub[0])

    S1 = [0.0] * num  # Spherical
    S2 = [0.0] * num  # Coma
    S3 = [0.0] * num  # Astigmatism
    S4 = [0.0] * num  # Petzval
    S5 = [0.0] * num  # Distortion
    CL = [0.0] * num  # Longitudinal chromatic
    CT = [0.0] * num  # Transverse chromatic

    wvln_t = torch.tensor([wvln])
    wvln_F_t = torch.tensor([WVLN_F])
    wvln_C_t = torch.tensor([WVLN_C])

    mat_before = Material("air")

    for j in range(num):
        si = surf_indices[j]
        surf = self.surfaces[si]

        # Refraction invariant A = n*(u + y*c), Abar = n*(ubar + ybar*c)
        A = n[j] * (u[j] + y[j] * c[j])
        Abar = n[j] * (ub[j] + yb[j] * c[j])

        # Delta(u/n) = u'/n' - u/n
        delta_u_over_n = u_aft[j] / np_[j] - u[j] / n[j]

        # Delta(1/n) = 1/n' - 1/n
        delta_inv_n = 1.0 / np_[j] - 1.0 / n[j]

        # --- Spherical surface contributions ---
        S1[j] = -A * A * y[j] * delta_u_over_n
        S2[j] = -A * Abar * y[j] * delta_u_over_n
        S3[j] = -Abar * Abar * y[j] * delta_u_over_n
        S4[j] = -H * H * c[j] * delta_inv_n
        # S5 = (Abar/A) * (S3 + S4), guarding A ≈ 0
        if abs(A) > 1e-12:
            S5[j] = (Abar / A) * (S3[j] + S4[j])
        else:
            S5[j] = 0.0

        # --- Aspheric correction ---
        if isinstance(surf, Aspheric):
            k_tensor = surf.k
            k_val = float(k_tensor.detach().item()) if torch.is_tensor(k_tensor) else float(k_tensor)
            c_val = c[j]
            # Fourth-order deformation: b4 = k*c^3/8 + a4
            a4 = 0.0
            if surf.ai is not None and len(surf.ai) > 0:
                a4_tensor = surf.ai[0]
                a4 = float(a4_tensor.detach().item()) if torch.is_tensor(a4_tensor) else float(a4_tensor)
            b4 = k_val * c_val**3 / 8.0 + a4

            dn = np_[j] - n[j]
            y4 = y[j] ** 4

            dS1 = -8.0 * dn * y4 * b4
            S1[j] += dS1

            if abs(y[j]) > 1e-12:
                # Off-axis terms scale with powers of ybar / y
                # (the aspheric term leaves the Petzval sum unchanged)
                ratio = yb[j] / y[j]
                dS2 = ratio * dS1
                dS3 = (ratio**2) * dS1
                dS5 = (ratio**3) * dS1
                S2[j] += dS2
                S3[j] += dS3
                S5[j] += dS5

        # --- Chromatic aberration ---
        if include_chromatic:
            n_F = float(mat_before.ior(wvln_F_t))
            n_C = float(mat_before.ior(wvln_C_t))
            np_F = float(surf.mat2.ior(wvln_F_t))
            np_C = float(surf.mat2.ior(wvln_C_t))

            delta_n = n_F - n_C
            delta_np = np_F - np_C

            # Δ(δn / n_d) = δn'/n'_d - δn/n_d
            delta_dn_over_nd = delta_np / np_[j] - delta_n / n[j]

            CL[j] = -y[j] * A * delta_dn_over_nd
            CT[j] = -y[j] * Abar * delta_dn_over_nd

        mat_before = surf.mat2

    # Labels
    labels = [f"S{si + 1}" for si in surf_indices]

    # System sums
    sums = {
        "S1": sum(S1),
        "S2": sum(S2),
        "S3": sum(S3),
        "S4": sum(S4),
        "S5": sum(S5),
        "CL": sum(CL),
        "CT": sum(CT),
    }

    result = {
        "S1": S1,
        "S2": S2,
        "S3": S3,
        "S4": S4,
        "S5": S5,
        "CL": CL,
        "CT": CT,
        "labels": labels,
        "sums": sums,
    }

    logger.info(
        "Seidel sums: S1=%.4f S2=%.4f S3=%.4f S4=%.4f S5=%.4f CL=%.4f CT=%.4f",
        sums["S1"], sums["S2"], sums["S3"], sums["S4"], sums["S5"],
        sums["CL"], sums["CT"],
    )

    return result

aberration_histogram

aberration_histogram(wvln: float = WVLN_d, save_name: Optional[str] = None, show: bool = False, include_chromatic: bool = True) -> Dict

Draw a Zemax-style Seidel aberration bar chart.

Parameters:

  • wvln (float, default WVLN_d): Reference wavelength in µm.
  • save_name (Optional[str], default None): Path to save the figure. If None, "./seidel_aberration.png" is used.
  • show (bool, default False): If True, call plt.show() instead of saving.
  • include_chromatic (bool, default True): Include the C_L and C_T bars.

Returns:

  • Dict: The Seidel coefficients dict (same as seidel_coefficients).

Source code in src/geolens_pkg/eval_seidel.py
@torch.no_grad()
def aberration_histogram(
    self,
    wvln: float = WVLN_d,
    save_name: Optional[str] = None,
    show: bool = False,
    include_chromatic: bool = True,
) -> Dict:
    """Draw a Zemax-style Seidel aberration bar chart.

    Args:
        wvln: Reference wavelength in µm.
        save_name: Path to save the figure. Defaults to
            ``"./seidel_aberration.png"``.
        show: If True, call ``plt.show()`` instead of saving.
        include_chromatic: Include C_L and C_T bars.

    Returns:
        The Seidel coefficients dict (same as ``seidel_coefficients``).
    """
    coeffs = self.seidel_coefficients(wvln=wvln, include_chromatic=include_chromatic)

    labels = coeffs["labels"]
    sums = coeffs["sums"]

    # Aberration keys and display config
    if include_chromatic:
        ab_keys = ["S1", "S2", "S3", "S4", "S5", "CL", "CT"]
        ab_names = [
            "S_I (Spherical)",
            "S_II (Coma)",
            "S_III (Astigmatism)",
            "S_IV (Petzval)",
            "S_V (Distortion)",
            "C_L (Axial Color)",
            "C_T (Lateral Color)",
        ]
        colors = ["#1f77b4", "#2ca02c", "#d62728", "#17becf", "#9467bd", "#bcbd22", "#ff7f0e"]
    else:
        ab_keys = ["S1", "S2", "S3", "S4", "S5"]
        ab_names = [
            "S_I (Spherical)",
            "S_II (Coma)",
            "S_III (Astigmatism)",
            "S_IV (Petzval)",
            "S_V (Distortion)",
        ]
        colors = ["#1f77b4", "#2ca02c", "#d62728", "#17becf", "#9467bd"]

    n_ab = len(ab_keys)
    n_surf = len(labels)
    x_labels = labels + ["SUM"]
    n_groups = n_surf + 1  # surfaces + SUM

    x = np.arange(n_groups)
    bar_width = 0.8 / n_ab

    fig, ax = plt.subplots(figsize=(max(8, n_groups * 0.8 + 2), 5))

    for k, (key, name, color) in enumerate(zip(ab_keys, ab_names, colors)):
        vals = coeffs[key] + [sums[key]]
        offset = (k - n_ab / 2.0 + 0.5) * bar_width
        ax.bar(x + offset, vals, bar_width, label=name, color=color, edgecolor="white", linewidth=0.5)

    ax.set_xlabel("Surface")
    ax.set_ylabel("Aberration Coefficient [mm]")
    ax.set_title("Seidel Aberration Diagram")
    ax.set_xticks(x)
    ax.set_xticklabels(x_labels, rotation=45, ha="right")
    ax.legend(fontsize=7, loc="best")
    ax.axhline(y=0, color="black", linewidth=0.5)
    ax.grid(axis="y", alpha=0.3)

    plt.tight_layout()

    if show:
        plt.show()
    else:
        if save_name is None:
            save_name = "./seidel_aberration.png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

    return coeffs
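The grouped layout in aberration_histogram places n_ab bars of width 0.8 / n_ab inside each x slot, each shifted by (k - n_ab/2 + 0.5) * bar_width. A minimal sketch of that offset arithmetic, using a hypothetical helper bar_offsets (not part of the library), shows that each group is centred on its tick and spans 0.8 units:

```python
def bar_offsets(n_ab: int, group_width: float = 0.8) -> list:
    """Hypothetical helper: offsets of the bars inside one group, mirroring aberration_histogram."""
    bar_width = group_width / n_ab
    # Same expression as in the plotting loop: (k - n_ab / 2 + 0.5) * bar_width
    return [(k - n_ab / 2.0 + 0.5) * bar_width for k in range(n_ab)]


offsets = bar_offsets(7)  # 7 bars when include_chromatic=True, 5 otherwise
width = 0.8 / 7
left_edge = offsets[0] - width / 2
right_edge = offsets[-1] + width / 2
print(f"offsets={[round(o, 3) for o in offsets]}, span=({left_edge:.2f}, {right_edge:.2f})")
```

The SUM column is just one more slot (n_groups = n_surf + 1), so the same offsets apply to it.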

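Third-order (Seidel) aberration analysis via paraxial ray tracing. Computes per-surface Seidel sums \(S_I\) to \(S_V\) (spherical aberration, coma, astigmatism, Petzval curvature, distortion) and the chromatic terms \(C_L\) (axial color) and \(C_T\) (lateral color). These correspond to the wavefront coefficients \(W_{040}\) (spherical), \(W_{131}\) (coma), \(W_{222}\) (astigmatism), \(W_{220}\) (field curvature, which combines \(S_{III}\) and \(S_{IV}\)) and \(W_{311}\) (distortion).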
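The Seidel sums returned by seidel_coefficients can be converted to wavefront coefficients with the standard Hopkins/Welford relations, assuming the expansion W = W040 ρ⁴ + W131 H ρ³ cos θ + W222 H² ρ² cos² θ + W220 H² ρ² + W311 H³ ρ cos θ. This is only a sketch: texts differ in sign and normalisation conventions, and we have not checked which one seidel_coefficients follows beyond its S1 to S5 labels. The helper name seidel_to_wavefront is hypothetical.

```python
def seidel_to_wavefront(sums: dict) -> dict:
    """Hypothetical helper: map Seidel sums S1..S5 to third-order wavefront coefficients.

    Assumes the Hopkins/Welford convention; results are in the same length
    units as the input sums.
    """
    s1, s2, s3, s4, s5 = (sums[key] for key in ("S1", "S2", "S3", "S4", "S5"))
    return {
        "W040": s1 / 8.0,         # spherical aberration
        "W131": s2 / 2.0,         # coma
        "W222": s3 / 2.0,         # astigmatism
        "W220": (s3 + s4) / 4.0,  # field curvature (astigmatism + Petzval)
        "W311": s5 / 2.0,         # distortion
    }


# Made-up sums, only to exercise the mapping (in practice: coeffs["sums"]).
w = seidel_to_wavefront({"S1": 0.8, "S2": 0.2, "S3": 0.1, "S4": 0.3, "S5": 0.04})
print(w)
```

Under this convention the tangential field curvature is W220 + W222 = (3 S_III + S_IV) / 4, and the Petzval contribution alone is S_IV / 4.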


Optimization

src.geolens_pkg.optim.GeoLensOptim

Mixin providing differentiable optimisation for GeoLens.

Implements gradient-based lens design using PyTorch autograd:

  • Loss functions – RMS spot error, focus, surface regularity, gap constraints, material validity.
  • Constraint initialisation – edge-thickness and self-intersection guards.
  • Optimizer helpers – parameter groups with per-type learning rates and cosine annealing schedules.
  • High-level optimize() – curriculum-learning training loop.

This class is not instantiated directly; it is mixed into deeplens.optics.geolens.GeoLens.

References

Xinge Yang et al., "Curriculum learning for ab initio deep learned refractive optics," Nature Communications 2024.

init_constraints

init_constraints(constraint_params=None)

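Initialize the design constraints: minimum and maximum air gaps and glass thicknesses (center and edge), back focal length and total track length limits, surface shape limits, and ray angle limits. Two built-in presets are used: if the sensor radius r_sensor is below 12.0, the lens is treated as a cellphone lens (is_cellphone = True) and the cellphone limits apply; otherwise the general-purpose limits apply.

Parameters:

  • constraint_params (dict, default None): Constraint parameters. Currently ignored (reserved for future use); the limits always come from the built-in presets.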
Source code in src/geolens_pkg/optim.py
def init_constraints(self, constraint_params=None):
    """Initialize constraints for the lens design.

    Args:
        constraint_params (dict): Constraint parameters.
    """
    # In the future, we want to use constraint_params to set the constraints.
    if constraint_params is None:
        constraint_params = {}
        print("Lens design constraints initialized with default values.")

    if self.r_sensor < 12.0:
        self.is_cellphone = True

        self.air_min_edge = 0.025
        self.air_max_edge = 3.0
        self.air_min_center = 0.025
        self.air_max_center = 1.5

        self.thick_min_edge = 0.25
        self.thick_max_edge = 2.0
        self.thick_min_center = 0.25
        self.thick_max_center = 3.0

        self.bfl_min = 0.8
        self.bfl_max = 3.0
        self.ttl_max = 15.0

        # Surface shape constraints
        self.sag2diam_max = 0.1
        self.grad_max = 0.57 # tan(30deg)
        self.diam2thick_max = 15.0
        self.tmax2tmin_max = 5.0

        # Ray angle constraints
        self.chief_ray_angle_max = 30.0 # deg
        self.obliq_min = 0.6

    else:
        self.is_cellphone = False

        self.air_min_edge = 0.1
        self.air_max_edge = 100.0  # float("inf")
        self.air_min_center = 0.1
        self.air_max_center = 100.0  # float("inf")

        self.thick_min_edge = 1.0
        self.thick_max_edge = 20.0
        self.thick_min_center = 2.0
        self.thick_max_center = 20.0

        self.bfl_min = 5.0
        self.bfl_max = 100.0  # float("inf")
        self.ttl_max = 300.0

        # Surface shape constraints
        self.sag2diam_max = 0.2
        self.grad_max = 0.84 # tan(40deg)
        self.diam2thick_max = 20.0
        self.tmax2tmin_max = 10.0

        # Ray angle constraints
        self.chief_ray_angle_max = 40.0 # deg
        self.obliq_min = 0.4

loss_reg

loss_reg(w_focus=10.0, w_ray_angle=2.0, w_intersec=1.0, w_thickness=0.1, w_surf=1.0)

Compute combined regularization loss for lens design.

Aggregates multiple constraint losses to keep the lens physically valid during gradient-based optimisation.

Parameters:

  • w_focus (float, default 10.0): Weight for the focus loss. Currently has no effect, because the focus term is commented out in the implementation.
  • w_ray_angle (float, default 2.0): Weight for the chief ray angle loss.
  • w_intersec (float, default 1.0): Weight for the self-intersection loss.
  • w_thickness (float, default 0.1): Weight for the thickness / TTL loss.
  • w_surf (float, default 1.0): Weight for the surface shape loss.

Returns:

  • tuple: (loss_reg, loss_dict), where loss_reg (Tensor) is the scalar combined regularization loss and loss_dict (dict) holds per-component loss values for logging.

Source code in src/geolens_pkg/optim.py
def loss_reg(self, w_focus=10.0, w_ray_angle=2.0, w_intersec=1.0, w_thickness=0.1, w_surf=1.0):
    """Compute combined regularization loss for lens design.

    Aggregates multiple constraint losses to keep the lens physically valid
    during gradient-based optimisation.

    Args:
        w_focus (float, optional): Weight for focus loss. Defaults to 10.0.
        w_ray_angle (float, optional): Weight for chief ray angle loss. Defaults to 2.0.
        w_intersec (float, optional): Weight for self-intersection loss. Defaults to 1.0.
        w_thickness (float, optional): Weight for thickness / TTL loss. Defaults to 0.1.
        w_surf (float, optional): Weight for surface shape loss. Defaults to 1.0.

    Returns:
        tuple: (loss_reg, loss_dict) where:
            - loss_reg (Tensor): Scalar combined regularization loss.
            - loss_dict (dict): Per-component loss values for logging.
    """
    # Loss functions for regularization
    # loss_focus = self.loss_infocus()
    loss_ray_angle = self.loss_ray_angle()
    loss_intersec = self.loss_intersec()
    loss_thickness = self.loss_thickness()
    loss_surf = self.loss_surface()
    # loss_mat = self.loss_mat()
    loss_reg = (
        # w_focus * loss_focus
        + w_intersec * loss_intersec
        + w_thickness * loss_thickness
        + w_surf * loss_surf
        + w_ray_angle * loss_ray_angle
        # + loss_mat
    )

    # Return loss and loss dictionary
    loss_dict = {
        # "loss_focus": loss_focus.item(),
        "loss_intersec": loss_intersec.item(),
        "loss_thickness": loss_thickness.item(),
        "loss_surf": loss_surf.item(),
        'loss_ray_angle': loss_ray_angle.item(),
        # 'loss_mat': loss_mat.item(),
    }
    return loss_reg, loss_dict
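With the current code, loss_reg is a plain weighted sum of four terms, and w_focus is accepted but has no effect because the focus term is commented out. A sketch with plain floats standing in for the tensors (the function name combine_reg is hypothetical):

```python
def combine_reg(terms, w_focus=10.0, w_ray_angle=2.0, w_intersec=1.0,
                w_thickness=0.1, w_surf=1.0):
    """Hypothetical helper: weighted sum of the active regularisation terms, mirroring loss_reg.

    w_focus is kept only for signature parity; like in loss_reg, it is unused.
    """
    total = (
        w_intersec * terms["loss_intersec"]
        + w_thickness * terms["loss_thickness"]
        + w_surf * terms["loss_surf"]
        + w_ray_angle * terms["loss_ray_angle"]
    )
    return total, dict(terms)


terms = {"loss_intersec": 0.1, "loss_thickness": 0.5,
         "loss_surf": 0.2, "loss_ray_angle": 0.05}
total, log = combine_reg(terms)
print(total, log)
```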

loss_infocus

loss_infocus(target=0.005, wvln=None)

Trace parallel on-axis rays (zero field angle) to the sensor and penalize the RMS spot size once it exceeds target, via softplus(rms_error - target, beta=50).

Parameters:

  • target (float, default 0.005): Target RMS spot size in mm.
  • wvln (float, default None): Wavelength in µm. If None, WAVE_RGB[1] is used.

Returns:

  • Tensor: Scalar focus penalty.
Source code in src/geolens_pkg/optim.py
def loss_infocus(self, target=0.005, wvln=None):
    """Sample parallel rays and compute RMS loss on the sensor plane, minimize focus loss.

    Args:
        target (float, optional): target of RMS loss. Defaults to 0.005 [mm].
        wvln (float, optional): Wavelength in um. Defaults to WAVE_RGB[1].
    """
    if wvln is None:
        wvln = WAVE_RGB[1]
    loss = torch.tensor(0.0, device=self.device)

    # Ray tracing and calculate RMS error
    ray = self.sample_from_fov(fov_x=0.0, fov_y=0.0, wvln=wvln, num_rays=SPP_CALC)
    ray = self.trace2sensor(ray)
    rms_error = ray.rms_error()

    # Smooth penalty: activates when rms_error exceeds target
    loss += torch.nn.functional.softplus(rms_error - target, beta=50.0)

    return loss
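All penalties in this mixin apply torch.nn.functional.softplus(x, beta=50.0) to a signed violation. The scalar re-implementation below (a sketch, not the library code) follows PyTorch's formula log(1 + exp(beta*x)) / beta, including its switch to the identity when beta*x exceeds threshold=20. With beta = 50 the transition is about 1/50 = 0.02 mm wide, wider than the 0.005 mm default target of loss_infocus, so the penalty and its gradient sigmoid(beta*x) stay noticeable even when the RMS spot is below the target.

```python
import math


def softplus(x: float, beta: float = 50.0, threshold: float = 20.0) -> float:
    """Scalar version of torch.nn.functional.softplus(x, beta, threshold)."""
    if beta * x > threshold:
        return x  # PyTorch reverts to the identity here for numerical stability
    return math.log1p(math.exp(beta * x)) / beta


def softplus_grad(x: float, beta: float = 50.0) -> float:
    """d softplus / dx = sigmoid(beta * x)."""
    return 1.0 / (1.0 + math.exp(-beta * x))


target = 0.005  # mm, default target of loss_infocus
for rms in (0.0, 0.005, 0.02, 0.1):
    x = rms - target
    print(f"rms={rms:.3f} mm  penalty={softplus(x):.4f}  grad={softplus_grad(x):.3f}")
```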

loss_surface

loss_surface()

Penalize extreme surface shapes that are difficult to manufacture.

Checks four constraints for each optimisable surface:
  1. Sag-to-diameter ratio exceeding sag2diam_max.
  2. Maximum surface gradient exceeding grad_max.
  3. Diameter-to-thickness ratio exceeding diam2thick_max.
  4. Maximum-to-minimum thickness ratio exceeding tmax2tmin_max.

Checks 3 and 4 apply only to surfaces followed by glass (mat2 is not air), i.e. to lens elements. Each violation is penalized with softplus(beta=50).

Returns:

  • Tensor: Scalar surface shape penalty loss.

Source code in src/geolens_pkg/optim.py
def loss_surface(self):
    """Penalize extreme surface shapes that are difficult to manufacture.

    Checks four constraints for each optimisable surface:
        1. Sag-to-diameter ratio exceeding ``sag2diam_max``.
        2. Maximum surface gradient exceeding ``grad_max``.
        3. Diameter-to-thickness ratio exceeding ``diam2thick_max``.
        4. Maximum-to-minimum thickness ratio exceeding ``tmax2tmin_max``.

    Returns:
        Tensor: Scalar surface shape penalty loss.
    """
    sag2diam_max = self.sag2diam_max
    grad_max_allowed = self.grad_max
    diam2thick_max = self.diam2thick_max
    tmax2tmin_max = self.tmax2tmin_max

    loss_grad = torch.tensor(0.0, device=self.device)
    loss_diam2thick = torch.tensor(0.0, device=self.device)
    loss_tmax2tmin = torch.tensor(0.0, device=self.device)
    loss_sag2diam = torch.tensor(0.0, device=self.device)
    for i in self.find_diff_surf():
        # Sample points on the surface
        x_ls = torch.linspace(0.0, 1.0, 32, device=self.device) * self.surfaces[i].r
        y_ls = torch.zeros_like(x_ls)

        # Sag
        sag_ls = self.surfaces[i].sag(x_ls, y_ls)
        sag2diam = sag_ls.abs().max() / self.surfaces[i].r / 2
        loss_sag2diam += torch.nn.functional.softplus(sag2diam - sag2diam_max, beta=50.0)

        # 1st-order derivative
        grad_ls = self.surfaces[i].dfdxyz(x_ls, y_ls)[0]
        grad_max = grad_ls.abs().max()
        loss_grad += torch.nn.functional.softplus(grad_max - grad_max_allowed, beta=50.0)

        # Diameter to thickness ratio, thick_max to thick_min ratio
        if not self.surfaces[i].mat2.name == "air":
            surf2 = self.surfaces[i + 1]
            surf1 = self.surfaces[i]

            # Penalize diameter to thickness ratio
            diam2thick = 2 * max(surf2.r, surf1.r) / (surf2.d - surf1.d)
            loss_diam2thick += torch.nn.functional.softplus(diam2thick - diam2thick_max, beta=50.0)

            # Penalize thick_max to thick_min ratio.
            # Use torch.maximum/minimum for differentiable max/min.
            r_edge = min(surf2.r, surf1.r)
            thick_center = surf2.d - surf1.d
            thick_edge = surf2.surface_with_offset(r_edge, 0.0) - surf1.surface_with_offset(r_edge, 0.0)
            thick_max = torch.maximum(thick_center, thick_edge)
            thick_min = torch.minimum(thick_center, thick_edge).clamp(min=0.01)
            tmax2tmin = thick_max / thick_min

            loss_tmax2tmin += torch.nn.functional.softplus(tmax2tmin - tmax2tmin_max, beta=50.0)

    return loss_sag2diam + loss_grad + loss_diam2thick + loss_tmax2tmin
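For a spherical surface the first two checks have closed forms: the sag is z(r) = c r² / (1 + sqrt(1 - c² r²)) and its slope is dz/dr = c r / sqrt(1 - c² r²). The sketch below samples 32 radii like loss_surface and evaluates sag2diam and the maximum slope. It assumes that dfdxyz(x, y)[0] is the sag slope along x, which this excerpt does not confirm; all helper names are hypothetical.

```python
import math


def sphere_sag(c: float, r: float) -> float:
    """Hypothetical helper: sag of a sphere with curvature c (1/mm) at radial height r (mm)."""
    return c * r**2 / (1.0 + math.sqrt(1.0 - (c * r) ** 2))


def sphere_slope(c: float, r: float) -> float:
    """Hypothetical helper: radial derivative dz/dr of the same sphere."""
    return c * r / math.sqrt(1.0 - (c * r) ** 2)


def shape_metrics(c: float, r_max: float, n: int = 32):
    """Hypothetical helper: (sag2diam, grad_max) over n radii from 0 to r_max, as in loss_surface."""
    radii = [r_max * i / (n - 1) for i in range(n)]
    sag2diam = max(abs(sphere_sag(c, r)) for r in radii) / r_max / 2
    grad_max = max(abs(sphere_slope(c, r)) for r in radii)
    return sag2diam, grad_max


# R = 10 mm sphere with 3 mm semi-aperture vs. a steep R = 5 mm sphere at 4 mm
print(shape_metrics(0.1, 3.0))  # within the cellphone limits (0.1, 0.57)
print(shape_metrics(0.2, 4.0))  # exceeds both general limits (0.2, 0.84)
```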

loss_intersec

loss_intersec()

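Loss function to avoid self-intersection.

This function penalizes surfaces that are too close to each other, which could cause self-intersection or manufacturing issues. For each pair of adjacent surfaces, the center gap and the smallest gap sampled between 0.5 r and r are compared with air_min_center / air_min_edge (air spaces) or thick_min_center / thick_min_edge (glass). The back focal length, taken as the smallest sampled distance from the last surface to the sensor, is compared with bfl_min. Each shortfall is penalized with softplus(beta=50).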

Source code in src/geolens_pkg/optim.py
def loss_intersec(self):
    """Loss function to avoid self-intersection.

    This function penalizes when surfaces are too close to each other,
    which could cause self-intersection or manufacturing issues.
    """
    # Constraints
    air_min_center = self.air_min_center
    air_min_edge = self.air_min_edge
    thick_min_center = self.thick_min_center
    thick_min_edge = self.thick_min_edge
    bfl_min = self.bfl_min

    # Loss
    loss = torch.tensor(0.0, device=self.device)
    for i in range(len(self.surfaces) - 1):
        # Sample evaluation points on the two surfaces
        current_surf = self.surfaces[i]
        next_surf = self.surfaces[i + 1]

        r_center = torch.tensor(0.0, device=self.device) * current_surf.r
        z_prev_center = current_surf.surface_with_offset(r_center, 0.0, valid_check=False)
        z_next_center = next_surf.surface_with_offset(r_center, 0.0, valid_check=False)

        r_edge = torch.linspace(0.5, 1.0, 16, device=self.device) * current_surf.r
        z_prev_edge = current_surf.surface_with_offset(r_edge, 0.0, valid_check=False)
        z_next_edge = next_surf.surface_with_offset(r_edge, 0.0, valid_check=False)

        # Next surface is air
        if self.surfaces[i].mat2.name == "air":
            # Center air gap
            dist_center = z_next_center - z_prev_center
            loss += torch.nn.functional.softplus(air_min_center - dist_center, beta=50.0)

            # Edge air gap
            dist_edge = torch.min(z_next_edge - z_prev_edge)
            loss += torch.nn.functional.softplus(air_min_edge - dist_edge, beta=50.0)

        # Next surface is lens
        else:
            # Center thickness
            dist_center = z_next_center - z_prev_center
            loss += torch.nn.functional.softplus(thick_min_center - dist_center, beta=50.0)

            # Edge thickness
            dist_edge = torch.min(z_next_edge - z_prev_edge)
            loss += torch.nn.functional.softplus(thick_min_edge - dist_edge, beta=50.0)

    # Distance to sensor (back focal length)
    last_surf = self.surfaces[-1]
    r = torch.linspace(0.0, 1.0, 32, device=self.device) * last_surf.r
    z_last_surf = self.d_sensor - last_surf.surface_with_offset(r, 0.0)

    bfl = torch.min(z_last_surf)
    loss += torch.nn.functional.softplus(bfl_min - bfl, beta=50.0)

    # Loss (softplus already produces positive penalties; return as-is)
    return loss
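The center and edge checks in loss_intersec can be illustrated with a biconvex element made of two spheres. The sketch assumes surface_with_offset(r, 0) returns the surface vertex position d plus the sag at r; all helper names are hypothetical. Note that a gap sitting exactly at its minimum still costs ln(2)/50 ≈ 0.014, because softplus is smooth at the boundary.

```python
import math


def softplus(x: float, beta: float = 50.0, threshold: float = 20.0) -> float:
    """Scalar stand-in for torch.nn.functional.softplus."""
    return x if beta * x > threshold else math.log1p(math.exp(beta * x)) / beta


def sphere_sag(c: float, r: float) -> float:
    """Sag of a sphere with curvature c at radial height r."""
    return c * r**2 / (1.0 + math.sqrt(1.0 - (c * r) ** 2))


def gap_lower_bound_penalty(d1, c1, d2, c2, r, min_center, min_edge, n_edge=16):
    """Hypothetical helper: center gap, smallest edge gap, and their softplus penalty."""
    def z1(rr):  # front surface position at height rr
        return d1 + sphere_sag(c1, rr)

    def z2(rr):  # back surface position at height rr
        return d2 + sphere_sag(c2, rr)

    center = z2(0.0) - z1(0.0)
    # Edge samples span 0.5 r .. r, like torch.linspace(0.5, 1.0, 16) * r
    edge_radii = [r * (0.5 + 0.5 * i / (n_edge - 1)) for i in range(n_edge)]
    edge = min(z2(rr) - z1(rr) for rr in edge_radii)
    penalty = softplus(min_center - center) + softplus(min_edge - edge)
    return center, edge, penalty


# Biconvex element: R1 = +10 mm, R2 = -10 mm, 2 mm center thickness, 3 mm semi-aperture,
# checked against the general preset (thick_min_center=2.0, thick_min_edge=1.0).
print(gap_lower_bound_penalty(0.0, 0.1, 2.0, -0.1, 3.0, min_center=2.0, min_edge=1.0))
```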

loss_thickness

loss_thickness()

Penalize excessive air gaps, lens thicknesses, and total track length.

Checks three types of upper-bound constraints:
  1. Per-gap air and glass thickness (center and edge).
  2. Back focal length (BFL).
  3. Total track length (TTL) from first surface to sensor.

For edge gaps and the BFL, the largest sampled distance is compared with the limit. Each violation is penalized with softplus(beta=50).

Returns:

  • Tensor: Scalar thickness penalty loss.

Source code in src/geolens_pkg/optim.py
def loss_thickness(self):
    """Penalize excessive air gaps, lens thicknesses, and total track length.

    Checks three types of upper-bound constraints:
        1. Per-gap air and glass thickness (center and edge).
        2. Back focal length (BFL).
        3. Total track length (TTL) from first surface to sensor.

    Returns:
        Tensor: Scalar thickness penalty loss.
    """
    # Constraints
    air_max_center = self.air_max_center
    air_max_edge = self.air_max_edge
    thick_max_center = self.thick_max_center
    thick_max_edge = self.thick_max_edge
    bfl_max = self.bfl_max
    ttl_max = self.ttl_max

    # Loss
    loss = torch.tensor(0.0, device=self.device)

    # Distance between surfaces
    for i in range(len(self.surfaces) - 1):
        # Sample evaluation points on the two surfaces
        current_surf = self.surfaces[i]
        next_surf = self.surfaces[i + 1]

        r_center = torch.tensor(0.0, device=self.device) * current_surf.r
        z_prev_center = current_surf.surface_with_offset(r_center, 0.0, valid_check=False)
        z_next_center = next_surf.surface_with_offset(r_center, 0.0, valid_check=False)

        r_edge = torch.linspace(0.5, 1.0, 16, device=self.device) * current_surf.r
        z_prev_edge = current_surf.surface_with_offset(r_edge, 0.0, valid_check=False)
        z_next_edge = next_surf.surface_with_offset(r_edge, 0.0, valid_check=False)

        # Air gap
        if self.surfaces[i].mat2.name == "air":
            # Center air gap
            dist_center = z_next_center - z_prev_center
            loss += torch.nn.functional.softplus(dist_center - air_max_center, beta=50.0)

            # Edge air gap
            dist_edge = torch.max(z_next_edge - z_prev_edge)
            loss += torch.nn.functional.softplus(dist_edge - air_max_edge, beta=50.0)

        # Lens thickness
        else:
            # Center thickness
            dist_center = z_next_center - z_prev_center
            loss += torch.nn.functional.softplus(dist_center - thick_max_center, beta=50.0)

            # Edge thickness
            dist_edge = torch.max(z_next_edge - z_prev_edge)
            loss += torch.nn.functional.softplus(dist_edge - thick_max_edge, beta=50.0)

    # Distance to sensor (back focal length)
    last_surf = self.surfaces[-1]
    r = torch.linspace(0.0, 1.0, 32, device=self.device) * last_surf.r
    z_last_surf = self.d_sensor - last_surf.surface_with_offset(r, 0.0)

    bfl = torch.max(z_last_surf)
    loss += torch.nn.functional.softplus(bfl - bfl_max, beta=50.0)

    # Total track length (first surface to sensor)
    ttl = self.d_sensor - self.surfaces[0].d
    loss += torch.nn.functional.softplus(ttl - ttl_max, beta=50.0)

    # Loss, minimize loss
    return loss
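loss_thickness and loss_intersec read the back focal length differently: the upper-bound check uses the largest sampled distance from the last surface to the sensor, the lower-bound check uses the smallest. The sketch below computes both, plus the TTL, for a last surface modelled as a sphere. It assumes surface_with_offset(r, 0) = d + sag(r), and all names are hypothetical.

```python
import math


def sphere_sag(c: float, r: float) -> float:
    """Sag of a sphere with curvature c at radial height r."""
    return c * r**2 / (1.0 + math.sqrt(1.0 - (c * r) ** 2))


def bfl_and_ttl(d_first, d_last, c_last, r_last, d_sensor, n=32):
    """Hypothetical helper: (smallest BFL, largest BFL, TTL) as read by loss_intersec / loss_thickness."""
    radii = [r_last * i / (n - 1) for i in range(n)]
    gaps = [d_sensor - (d_last + sphere_sag(c_last, r)) for r in radii]
    ttl = d_sensor - d_first  # first surface vertex to sensor
    return min(gaps), max(gaps), ttl


# Last surface convex toward the sensor (R = -20 mm), 2 mm semi-aperture
bfl_lo, bfl_hi, ttl = bfl_and_ttl(d_first=0.0, d_last=10.0, c_last=-0.05, r_last=2.0, d_sensor=12.0)
print(bfl_lo, bfl_hi, ttl)
# Cellphone preset: bfl_min=0.8, bfl_max=3.0, ttl_max=15.0, so nothing is violated here
```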

loss_ray_angle

loss_ray_angle()

Penalize rays that violate chief ray angle or obliquity constraints.

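Uses softplus on the violation amount (cos_ref - cos_cra), where cos_ref = cos(chief_ray_angle_max), so the loss is always non-negative, smooth at the boundary, and grows with the size of the violation. Minimising the loss pushes cos(CRA) upward (i.e. reduces the chief ray angle). The CRA term traces rays through a pupil scaled to 0.2 as an approximation of chief rays. A second term applies the same softplus penalty when the accumulated obliquity factor falls below obliq_min. Both terms are averaged over valid rays only.

Returns:

  • Tensor: Scalar ray-angle penalty loss (always >= 0).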

Source code in src/geolens_pkg/optim.py
def loss_ray_angle(self):
    """Penalize rays that violate chief ray angle or obliquity constraints.

    Uses softplus on the violation amount (cos_ref - cos_cra) so the loss
    is always non-negative, smooth at the boundary, and proportional to
    violation severity.  Minimising the loss pushes cos(CRA) upward
    (i.e. reduces the chief ray angle).

    Returns:
        Tensor: Scalar ray-angle penalty loss (always >= 0).
    """
    cos_cra_ref = float(np.cos(np.deg2rad(self.chief_ray_angle_max)))

    # Loss on chief ray angle: softplus(cos_ref - cos_cra)
    # Positive when cos_cra < cos_ref (CRA exceeds limit).
    # Gradient pushes cos_cra upward → reduces CRA.
    ray = self.sample_ring_arm_rays(num_ring=4, num_arm=8, spp=SPP_CALC, scale_pupil=0.2)
    ray = self.trace2sensor(ray)
    cos_cra = ray.d[..., 2]
    valid = ray.is_valid > 0
    penalty_cra = torch.nn.functional.softplus(cos_cra_ref - cos_cra, beta=50.0)
    loss_cra = (penalty_cra * valid).sum() / (valid.sum() + EPSILON)

    # Loss on accumulated oblique term: softplus(obliq_min - obliq)
    # Positive when obliq < obliq_min.
    # Gradient pushes obliq upward.
    ray = self.sample_ring_arm_rays(num_ring=4, num_arm=8, spp=SPP_CALC, scale_pupil=1.0)
    ray = self.trace2sensor(ray)
    obliq = ray.obliq.squeeze(-1)
    valid = ray.is_valid > 0
    penalty_obliq = torch.nn.functional.softplus(self.obliq_min - obliq, beta=50.0)
    loss_obliq = (penalty_obliq * valid).sum() / (valid.sum() + EPSILON)

    return loss_cra + loss_obliq
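loss_ray_angle compares the axial direction cosine of each traced ray, ray.d[..., 2], with cos(chief_ray_angle_max) and averages the softplus penalty over valid rays only, so vignetted rays neither add penalty nor enter the denominator. A scalar sketch of that masked mean follows; eps stands in for the source's EPSILON, whose value is not shown here, and the helper name is hypothetical.

```python
import math


def softplus(x: float, beta: float = 50.0, threshold: float = 20.0) -> float:
    """Scalar stand-in for torch.nn.functional.softplus."""
    return x if beta * x > threshold else math.log1p(math.exp(beta * x)) / beta


def masked_cra_loss(cos_cra, valid, cra_max_deg=30.0, eps=1e-9):
    """Hypothetical helper: mean softplus(cos_ref - cos_cra) over rays whose valid flag is 1."""
    cos_ref = math.cos(math.radians(cra_max_deg))
    total = sum(softplus(cos_ref - c) * v for c, v in zip(cos_cra, valid))
    return total / (sum(valid) + eps)


# Three rays at roughly 18, 37 and 84 degrees; the last one is flagged invalid.
cos_cra = [0.95, 0.80, 0.10]
print(masked_cra_loss(cos_cra, valid=[1, 1, 0]))
```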

loss_mat

loss_mat()

Penalize material parameters outside manufacturable ranges.

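Constrains refractive index n to [1.5, 1.9] and Abbe number V to [30, 70] for each non-air surface material. Values outside a range are penalized linearly, normalized by the width of that range. This term is currently not included in loss_reg, where its call is commented out.

Returns:

  • Tensor: Scalar material penalty loss.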

Source code in src/geolens_pkg/optim.py
def loss_mat(self):
    """Penalize material parameters outside manufacturable ranges.

    Constrains refractive index *n* to [1.5, 1.9] and Abbe number *V* to
    [30, 70] for each non-air surface material.

    Returns:
        Tensor: Scalar material penalty loss.
    """
    n_max = 1.9
    n_min = 1.5
    V_max = 70
    V_min = 30
    loss_mat = torch.tensor(0.0, device=self.device)
    for i in range(len(self.surfaces)):
        if self.surfaces[i].mat2.name != "air":
            if self.surfaces[i].mat2.n > n_max:
                loss_mat += (self.surfaces[i].mat2.n - n_max) / (n_max - n_min)
            if self.surfaces[i].mat2.n < n_min:
                loss_mat += (n_min - self.surfaces[i].mat2.n) / (n_max - n_min)
            if self.surfaces[i].mat2.V > V_max:
                loss_mat += (self.surfaces[i].mat2.V - V_max) / (V_max - V_min)
            if self.surfaces[i].mat2.V < V_min:
                loss_mat += (V_min - self.surfaces[i].mat2.V) / (V_max - V_min)

    return loss_mat
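Unlike the other terms, loss_mat uses a linear (hinge-style) penalty rather than softplus. A scalar sketch with the same bounds; the helper name is hypothetical, and the BK7-like values are approximate catalogue numbers used only for illustration.

```python
def material_penalty(n, V, n_min=1.5, n_max=1.9, V_min=30.0, V_max=70.0):
    """Hypothetical helper: linear penalty for n outside [n_min, n_max] and V outside [V_min, V_max]."""
    loss = 0.0
    if n > n_max:
        loss += (n - n_max) / (n_max - n_min)
    if n < n_min:
        loss += (n_min - n) / (n_max - n_min)
    if V > V_max:
        loss += (V - V_max) / (V_max - V_min)
    if V < V_min:
        loss += (V_min - V) / (V_max - V_min)
    return loss


print(material_penalty(1.5168, 64.17))  # BK7-like crown glass: inside both ranges
print(material_penalty(1.95, 20.0))     # high index, high dispersion: penalised on both
```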

loss_rms

loss_rms(num_grid=GEO_GRID, depth=DEPTH, num_rays=SPP_PSF, sample_more_off_axis=False)

Compute the RMS spot error over the three RGB wavelengths.

Parameters:

  • num_grid (int, default GEO_GRID): Number of grid points.
  • depth (float, default DEPTH): Object-plane depth.
  • num_rays (int, default SPP_PSF): Number of rays.
  • sample_more_off_axis (bool, default False): Whether to sample more off-axis rays.

Returns:

  • avg_rms_error (Tensor): RMS error averaged over wavelengths and grid points. All wavelengths are measured from the same reference center, computed for the green wavelength (WAVE_RGB[1]) with the pinhole model.

Source code in src/geolens_pkg/optim.py
def loss_rms(
    self,
    num_grid=GEO_GRID,
    depth=DEPTH,
    num_rays=SPP_PSF,
    sample_more_off_axis=False,
):
    """Loss function to compute RGB spot error RMS.

    Args:
        num_grid (int, optional): Number of grid points. Defaults to GEO_GRID.
        depth (float, optional): Object-plane depth. Defaults to DEPTH.
        num_rays (int, optional): Number of rays. Defaults to SPP_PSF.
        sample_more_off_axis (bool, optional): Whether to sample more off-axis rays. Defaults to False.

    Returns:
        avg_rms_error (torch.Tensor): RMS error averaged over wavelengths and grid points.
    """
    all_rms_errors = []
    for i, wvln in enumerate([WAVE_RGB[1], WAVE_RGB[0], WAVE_RGB[2]]):
        ray = self.sample_grid_rays(
            depth=depth,
            num_grid=num_grid,
            num_rays=num_rays,
            wvln=wvln,
            sample_more_off_axis=sample_more_off_axis,
        )

        # Calculate reference center, shape of (..., 2)
        if i == 0:
            with torch.no_grad():
                ray_center_green = -self.psf_center(points_obj=ray.o[:, :, 0, :], method="pinhole")

        ray = self.trace2sensor(ray)

        # # Green light centroid for reference
        # if i == 0:
        #     with torch.no_grad():
        #         ray_center_green = ray.centroid()

        # Calculate RMS error with reference center
        rms_error = ray.rms_error(center_ref=ray_center_green)
        all_rms_errors.append(rms_error)

    # Calculate average RMS error
    avg_rms_error = torch.stack(all_rms_errors).mean(dim=0)
    return avg_rms_error
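In loss_rms every wavelength's spot is measured from the same green-channel reference point, so a lateral shift of the red or blue spot (lateral color) raises the loss even if each spot is sharp on its own. A scalar sketch, assuming ray.rms_error(center_ref=...) is the usual root-mean-square radial distance from that reference (the excerpt does not show its body); helper names are hypothetical.

```python
import math


def rms_spot(points, center):
    """Hypothetical helper: RMS radial distance of 2-D spot points from a reference center."""
    sq = [(x - center[0]) ** 2 + (y - center[1]) ** 2 for x, y in points]
    return math.sqrt(sum(sq) / len(sq))


def rgb_rms(spots_by_wvln, center_green):
    """Hypothetical helper: average the per-wavelength RMS, all measured from the green reference."""
    errors = [rms_spot(points, center_green) for points in spots_by_wvln]
    return sum(errors) / len(errors)


green = [(0.001, 0.0), (-0.001, 0.0), (0.0, 0.001), (0.0, -0.001)]  # mm
red = [(x + 0.003, y) for x, y in green]  # same spot shape, laterally shifted by 3 um
blue = list(green)
center = (0.0, 0.0)
print(rgb_rms([green, red, blue], center))  # order as in loss_rms: G, R, B
```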

sample_ring_arm_rays

sample_ring_arm_rays(num_ring=8, num_arm=8, spp=2048, depth=DEPTH, wvln=DEFAULT_WAVE, scale_pupil=1.0, sample_more_off_axis=True)

Sample rays from object space using a ring-arm pattern.

This method distributes sampling points (origins of ray bundles) on a polar grid in the object plane, defined by field angle, which is useful for capturing lens performance across the full field. It creates num_ring rings with num_arm points each, with ring field angles running from 0 to the lens's rfov. The first ring lies at zero field, so the on-axis point is repeated num_arm times.

Parameters:

  • num_ring (int, default 8): Number of rings in the field of view, including the on-axis ring.
  • num_arm (int, default 8): Number of arms (spokes) per ring.
  • spp (int, default 2048): Total number of rays to be sampled, distributed among field points.
  • depth (float, default DEPTH): Depth of the object plane.
  • wvln (float, default DEFAULT_WAVE): Wavelength of the rays.
  • scale_pupil (float, default 1.0): Scale factor for the pupil size.
  • sample_more_off_axis (bool, default True): If True, ring field angles follow the square root of a uniform grid, which places more rings toward the edge of the field.

Returns:

  • Ray: A Ray object containing the sampled rays.

Source code in src/geolens_pkg/optim.py
def sample_ring_arm_rays(self, num_ring=8, num_arm=8, spp=2048, depth=DEPTH, wvln=DEFAULT_WAVE, scale_pupil=1.0, sample_more_off_axis=True):
    """Sample rays from object space using a ring-arm pattern.

    This method distributes sampling points (origins of ray bundles) on a polar grid in the object plane,
    defined by field of view. This is useful for capturing lens performance across the full field.
    The points include the center and `num_ring` rings with `num_arm` points on each.

    Args:
        num_ring (int): Number of rings to sample in the field of view.
        num_arm (int): Number of arms (spokes) to sample for each ring.
        spp (int): Total number of rays to be sampled, distributed among field points.
        depth (float): Depth of the object plane.
        wvln (float): Wavelength of the rays.
        scale_pupil (float): Scale factor for the pupil size.

    Returns:
        Ray: A Ray object containing the sampled rays.
    """
    # Create points on rings and arms
    max_fov_rad = self.rfov
    if sample_more_off_axis:
        beta_values = torch.linspace(0.0, 1.0, num_ring, device=self.device)
        beta_transformed = beta_values ** 0.5
        ring_fovs = max_fov_rad * beta_transformed
    else:
        ring_fovs = max_fov_rad * torch.linspace(0.0, 1.0, num_ring, device=self.device)

    arm_angles = torch.linspace(0.0, 2 * torch.pi, num_arm + 1, device=self.device)[:-1]
    ring_grid, arm_grid = torch.meshgrid(ring_fovs, arm_angles, indexing="ij")
    x = depth * torch.tan(ring_grid) * torch.cos(arm_grid)
    y = depth * torch.tan(ring_grid) * torch.sin(arm_grid)        
    z = torch.full_like(x, depth)
    points = torch.stack([x, y, z], dim=-1)  # shape: [num_ring, num_arm, 3]

    # Sample rays
    rays = self.sample_from_points(points=points, num_rays=spp, wvln=wvln, scale_pupil=scale_pupil)
    return rays
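The field layout of sample_ring_arm_rays can be reproduced without tracing: ring field angles are max_fov * linspace(0, 1, num_ring), optionally passed through a square root, and each ring carries num_arm points at equal azimuths. The sketch mirrors that arithmetic in plain Python; the function names are hypothetical, and max_fov stands for the lens's rfov (assumed to be in radians, as the source variable name max_fov_rad suggests).

```python
import math


def ring_field_angles(max_fov, num_ring, sample_more_off_axis=True):
    """Hypothetical helper: max_fov * linspace(0, 1, num_ring), optionally sqrt-warped."""
    if num_ring == 1:
        betas = [0.0]  # torch.linspace(0, 1, 1) returns a single 0.0
    else:
        betas = [i / (num_ring - 1) for i in range(num_ring)]
    if sample_more_off_axis:
        betas = [b ** 0.5 for b in betas]  # pushes rings toward the field edge
    return [max_fov * b for b in betas]


def ring_arm_points(max_fov, num_ring, num_arm, depth, sample_more_off_axis=True):
    """Hypothetical helper: object-plane points of shape [num_ring][num_arm] as (x, y, z)."""
    arms = [2 * math.pi * j / num_arm for j in range(num_arm)]  # linspace(0, 2pi, num_arm + 1)[:-1]
    points = []
    for fov in ring_field_angles(max_fov, num_ring, sample_more_off_axis):
        rho = depth * math.tan(fov)
        points.append([(rho * math.cos(a), rho * math.sin(a), depth) for a in arms])
    return points


print([round(a, 3) for a in ring_field_angles(1.0, 4, True)])
print([round(a, 3) for a in ring_field_angles(1.0, 4, False)])
```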

optimize

optimize(lrs=[0.001, 0.001, 0.01, 0.0001], iterations=5000, test_per_iter=100, centroid=False, optim_mat=False, shape_control=True, momentum_decay=0.1, result_dir=None)

Optimise the lens by minimising RGB RMS spot errors.

Runs a curriculum-learning training loop with Adam optimiser and cosine annealing. Periodically evaluates the lens, saves intermediate results, and optionally corrects surface shapes.

Parameters:

  • lrs (list, default [0.001, 0.001, 0.01, 0.0001]): Learning rates for the [d, c, k, a] parameter groups.
  • iterations (int, default 5000): Total training iterations.
  • test_per_iter (int, default 100): Evaluate and save every N iterations.
  • centroid (bool, default False): If True, use the chief ray (method="chief_ray") as the PSF centre reference; otherwise use the pinhole model.
  • optim_mat (bool, default False): If True, include material parameters (n, V) in optimisation.
  • shape_control (bool, default True): If True, call correct_shape() at each evaluation step after the first.
  • momentum_decay (float, default 0.1): Factor that scales Adam's first moment (exp_avg) at each evaluation step. This prevents stale momentum from the previous ray batch from corrupting gradients after resampling. The second moment (exp_avg_sq) is left untouched to preserve adaptive learning-rate scaling. Set to 0.0 for a full reset or 1.0 to disable.
  • result_dir (str, default None): Directory to save results. If None, a timestamped directory under ./results is generated.

Note

Debug hints:
  1. Optimise slowly with a small learning rate.
  2. FoV and thickness should match well.
  3. Keep parameter ranges reasonable.
  4. Higher aspheric orders help but make the design more sensitive.
  5. More iterations with larger ray sampling improve convergence.
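optimize drives the learning rate with get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=iterations). Assuming this is, or behaves like, the Hugging Face transformers function of that name (the import is not part of this excerpt), the per-step multiplier applied to each group's base rate is a linear warmup followed by a half cosine. A sketch with a hypothetical helper name:

```python
import math


def cosine_warmup_factor(step, num_warmup_steps=100, num_training_steps=5000):
    """Hypothetical helper: linear warmup to 1, then half-cosine decay to 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


base_lrs = [1e-3, 1e-3, 1e-2, 1e-4]  # default lrs for the d, c, k, a groups
for step in (0, 50, 100, 2550, 5000):
    factor = cosine_warmup_factor(step)
    print(step, [lr * factor for lr in base_lrs])
```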

Source code in src/geolens_pkg/optim.py
def optimize(
    self,
    lrs=[1e-3, 1e-3, 1e-2, 1e-4],
    iterations=5000,
    test_per_iter=100,
    centroid=False,
    optim_mat=False,
    shape_control=True,
    momentum_decay=0.1,
    result_dir=None,
):
    """Optimise the lens by minimising RGB RMS spot errors.

    Runs a curriculum-learning training loop with Adam optimiser and cosine
    annealing. Periodically evaluates the lens, saves intermediate results,
    and optionally corrects surface shapes.

    Args:
        lrs (list, optional): Learning rates for [d, c, k, a] parameter groups.
            Defaults to [1e-3, 1e-3, 1e-2, 1e-4].
        iterations (int, optional): Total training iterations. Defaults to 5000.
        test_per_iter (int, optional): Evaluate and save every N iterations.
            Defaults to 100.
        centroid (bool, optional): If True, use chief-ray centroid as PSF centre
            reference; otherwise use pinhole model. Defaults to False.
        optim_mat (bool, optional): If True, include material parameters (n, V)
            in optimisation. Defaults to False.
        shape_control (bool, optional): If True, call ``correct_shape()`` at each
            evaluation step. Defaults to True.
        momentum_decay (float, optional): Factor to scale Adam's first
            moment (exp_avg) at each evaluation step. Prevents stale
            momentum from the previous ray batch from corrupting gradients
            after resampling. The second moment (exp_avg_sq) is left
            untouched to preserve adaptive lr scaling. Set to 0.0 for a
            full reset, 1.0 to disable. Defaults to 0.1.
        result_dir (str, optional): Directory to save results. If None,
            auto-generates a timestamped directory. Defaults to None.

    Note:
        Debug hints:
            1. Slowly optimise with small learning rate.
            2. FoV and thickness should match well.
            3. Keep parameter ranges reasonable.
            4. Higher aspheric order is better but more sensitive.
            5. More iterations with larger ray sampling improves convergence.
    """
    # Experiment settings
    depth = DEPTH
    num_ring = 32
    num_arm = 8
    spp = 2048

    # Result directory and logger
    if result_dir is None:
        result_dir = f"./results/{datetime.now().strftime('%m%d-%H%M%S')}-DesignLens"

    os.makedirs(result_dir, exist_ok=True)
    if not logging.getLogger().hasHandlers():
        logger = logging.getLogger()
        logger.setLevel("DEBUG")
        fmt = logging.Formatter("%(asctime)s:%(levelname)s:%(message)s", "%Y-%m-%d %H:%M:%S")
        sh = logging.StreamHandler()
        sh.setFormatter(fmt)
        sh.setLevel("INFO")
        fh = logging.FileHandler(f"{result_dir}/output.log")
        fh.setFormatter(fmt)
        fh.setLevel("INFO")
        logger.addHandler(sh)
        logger.addHandler(fh)
    logging.info(f"lr:{lrs}, iterations:{iterations}, num_ring:{num_ring}, num_arm:{num_arm}, rays_per_fov:{spp}.")
    logging.info("If Out-of-Memory, try to reduce num_ring, num_arm, and rays_per_fov.")

    # Optimizer and scheduler
    optimizer = self.get_optimizer(lrs, optim_mat=optim_mat)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=iterations)

    # Training loop
    pbar = tqdm(
        total=iterations + 1,
        desc="Progress",
        postfix={"loss_rms": 0, "loss_focus": 0},
    )
    for i in range(iterations + 1):
        # ===> Evaluate the lens
        if i % test_per_iter == 0:
            with torch.no_grad():
                if shape_control and i > 0:
                    self.correct_shape()
                    # self.refocus()

                self.write_lens_json(f"{result_dir}/iter{i}.json")
                self.analysis(f"{result_dir}/iter{i}")

                # Sample rays
                self.calc_pupil()
                rays_backup = []
                for wv in WAVE_RGB:
                    ray = self.sample_ring_arm_rays(num_ring=num_ring, num_arm=num_arm, spp=spp, depth=depth, wvln=wv, scale_pupil=1.05, sample_more_off_axis=False)
                    rays_backup.append(ray)

                # Calculate ray centers
                method = "chief_ray" if centroid else "pinhole"
                center_ref = -self.psf_center(points_obj=ray.o[:, :, 0, :], method=method)
                center_ref = center_ref.unsqueeze(-2).repeat(1, 1, spp, 1)

            # Soft momentum reset: decay Adam buffers before new ray batch
            if i > 0 and momentum_decay < 1.0:
                for pg in optimizer.param_groups:
                    for p in pg["params"]:
                        state = optimizer.state.get(p)
                        if state and "exp_avg" in state:
                            state["exp_avg"].mul_(momentum_decay)

        # Compute error-adaptive weight mask every iteration
        with torch.no_grad():
            ray_wm = rays_backup[0].clone()
            ray_wm = self.trace2sensor(ray_wm)
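            ray_err_wm = ray_wm.o[..., :2] - center_ref
            ray_valid_wm = ray_wm.is_valid
            # Zero out invalid rays before squaring to avoid inf * 0 = nan
            ray_err_wm = torch.where(
                ray_valid_wm.bool().unsqueeze(-1), ray_err_wm, torch.zeros_like(ray_err_wm)
            )
            w_mask = (ray_err_wm**2).sum(-1).sum(-1)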
            w_mask /= ray_valid_wm.sum(-1) + EPSILON
            w_mask /= w_mask.mean() + EPSILON

        # ===> Optimize lens by minimizing RMS
        loss_rms_ls = []
        for wv_idx, wv in enumerate(WAVE_RGB):
            # Ray tracing to sensor, [num_grid, num_grid, num_rays, 3]
            ray = rays_backup[wv_idx].clone()
            ray = self.trace2sensor(ray)

            # Ray error to center and valid mask.
            # Use torch.where to zero out invalid rays BEFORE squaring,
            # preventing NaN from Inf*0 (IEEE 754: inf * 0 = nan).
            ray_xy = ray.o[..., :2]
            ray_valid = ray.is_valid
            ray_err = ray_xy - center_ref
            ray_err = torch.where(
                ray_valid.bool().unsqueeze(-1), ray_err, torch.zeros_like(ray_err)
            )

            # Loss on RMS error
            l_rms = (ray_err**2).sum(-1).sum(-1)
            l_rms /= ray_valid.sum(-1) + EPSILON
            l_rms = (l_rms + EPSILON).sqrt()

            l_rms_weighted = (l_rms * w_mask).sum()
            l_rms_weighted /= w_mask.sum() + EPSILON
            loss_rms_ls.append(l_rms_weighted)

        # RMS loss for all wavelengths
        loss_rms = sum(loss_rms_ls) / len(loss_rms_ls)

        # Total loss
        w_focus = 0.05
        loss_focus = self.loss_infocus()

        w_reg = 0.05
        loss_reg, loss_dict = self.loss_reg()

        L_total = loss_rms + w_focus * loss_focus + w_reg * loss_reg

        # Back-propagation
        optimizer.zero_grad()
        L_total.backward()
        optimizer.step()
        scheduler.step()

        pbar.set_postfix(loss_rms=loss_rms.item(), loss_focus=loss_focus.item(), **loss_dict)
        pbar.update(1)

    pbar.close()
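
The per-field loss in optimize() has two parts: an RMS spot error averaged over valid rays only, and error-adaptive weights proportional to each field's mean squared error, normalised to a mean of 1. The plain-Python sketch below mirrors that arithmetic on nested lists instead of tensors, so it has no autograd; EPS stands in for the library's EPSILON (its value is assumed) and all function names are hypothetical. In the real method the weights are computed under torch.no_grad() from the first wavelength only, so they steer the optimiser without receiving gradients.

```python
import math

EPS = 1e-9  # stand-in for the library's EPSILON (value assumed)


def _sq_errors(errs, mask):
    # Invalid rays contribute 0.0 *before* squaring, so an inf position
    # never reaches inf * 0 (which is nan under IEEE 754).
    return [(dx * dx + dy * dy) if ok else 0.0 for (dx, dy), ok in zip(errs, mask)]


def masked_rms(errors, valid):
    """Per-field RMS spot error over valid rays only.

    errors: one list per field of (dx, dy) offsets from the reference centre.
    valid:  one list per field of bools marking rays that reached the sensor.
    """
    out = []
    for errs, mask in zip(errors, valid):
        mean_sq = sum(_sq_errors(errs, mask)) / (sum(mask) + EPS)
        out.append(math.sqrt(mean_sq + EPS))
    return out


def error_adaptive_weights(errors, valid):
    """Per-field mean squared error, normalised so the weights average to 1."""
    ms = [sum(_sq_errors(e, m)) / (sum(m) + EPS) for e, m in zip(errors, valid)]
    mean = sum(ms) / len(ms)
    return [x / (mean + EPS) for x in ms]


def weighted_rms_loss(errors, valid):
    """Weighted average of the per-field RMS errors (one wavelength)."""
    rms = masked_rms(errors, valid)
    w = error_adaptive_weights(errors, valid)
    return sum(r * wi for r, wi in zip(rms, w)) / (sum(w) + EPS)


# Two fields: the second has one ray that missed the sensor (inf position).
errors = [[(0.0, 0.0), (1.0, 0.0)], [(3.0, 4.0), (float("inf"), 0.0)]]
valid = [[True, True], [True, False]]
print(masked_rms(errors, valid))
print(weighted_rms_loss(errors, valid))
```

Because the weights average to 1, the weighted loss stays on the same scale as an unweighted mean while fields with larger errors receive proportionally more weight.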

curriculum_design

curriculum_design(lrs=[0.001, 0.001, 0.01, 0.0001], iterations=5000, test_per_iter=100, optim_mat=False, match_mat=False, shape_control=True, result_dir='./results')

Optimise the lens from scratch using curriculum aperture growth.

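Gradually opens the aperture stop from 25% to 100% of its initial radius along a cosine ramp, updating the radius at each evaluation step. This turns a hard global optimisation into a sequence of easier subproblems. The learning rate follows cosine annealing with warm restarts every iterations // 4 steps.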

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| lrs | list | Learning rates for [d, c, k, ai]. | [0.001, 0.001, 0.01, 0.0001] |
| iterations | int | Total training iterations. | 5000 |
| test_per_iter | int | Evaluate and save every N iterations. | 100 |
| optim_mat | bool | Optimise material parameters. | False |
| match_mat | bool | Call match_materials() at each evaluation step; only takes effect when optim_mat is also True. | False |
| shape_control | bool | Correct surface shapes at each evaluation step. | True |
| result_dir | str | Directory to save results. | './results' |
Source code in src/geolens_pkg/optim.py
def curriculum_design(
    self,
    lrs=[1e-3, 1e-3, 1e-2, 1e-4],
    iterations=5000,
    test_per_iter=100,
    optim_mat=False,
    match_mat=False,
    shape_control=True,
    result_dir="./results",
):
    """Optimise the lens from scratch using curriculum aperture growth.

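    Gradually opens the aperture stop from 25% to 100% of its initial
    radius along a cosine ramp (updated at each evaluation step),
    transforming a hard global optimisation into a sequence of easier
    subproblems. The learning rate uses cosine annealing with warm
    restarts every ``iterations // 4`` steps.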

    Args:
        lrs (list, optional): Learning rates for [d, c, k, ai].
        iterations (int, optional): Total training iterations.
        test_per_iter (int, optional): Evaluate and save every N iterations.
        optim_mat (bool, optional): Optimise material parameters.
        match_mat (bool, optional): Call ``match_materials()`` at each
            evaluation step; only takes effect when ``optim_mat`` is True.
        shape_control (bool, optional): Correct surface shapes at each evaluation.
        result_dir (str, optional): Directory to save results.
    """
    depth = DEPTH
    num_ring = 16
    num_arm = 4
    spp = 2048

    aper_start = self.surfaces[self.aper_idx].r * 0.25
    aper_final = self.surfaces[self.aper_idx].r

    os.makedirs(result_dir, exist_ok=True)
    if not logging.getLogger().hasHandlers():
        logger = logging.getLogger()
        logger.setLevel("DEBUG")
        fmt = logging.Formatter("%(asctime)s:%(levelname)s:%(message)s", "%Y-%m-%d %H:%M:%S")
        sh = logging.StreamHandler()
        sh.setFormatter(fmt)
        sh.setLevel("INFO")
        fh = logging.FileHandler(f"{result_dir}/output.log")
        fh.setFormatter(fmt)
        fh.setLevel("INFO")
        logger.addHandler(sh)
        logger.addHandler(fh)
    logging.info(
        f"lr:{lrs}, iterations:{iterations}, spp:{spp}, "
        f"num_ring:{num_ring}, num_arm:{num_arm}."
    )

    optimizer = self.get_optimizer(lrs, optim_mat=optim_mat)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=iterations // 4, T_mult=1
    )

    pbar = tqdm(
        total=iterations + 1, desc="Curriculum", postfix={"loss_rms": 0, "loss_reg": 0}
    )
    for i in range(iterations + 1):
        # === Evaluate ===
        if i % test_per_iter == 0:
            with torch.no_grad():
                progress = 0.5 * (1 + math.cos(math.pi * (1 - i / iterations)))
                aper_r = min(
                    aper_start + (aper_final - aper_start) * progress,
                    aper_final,
                )
                self.surfaces[self.aper_idx].update_r(aper_r)
                self.calc_pupil()

                if i > 0:
                    if shape_control:
                        self.correct_shape()
                    if optim_mat and match_mat:
                        self.match_materials()

                self.write_lens_json(f"{result_dir}/iter{i}.json")
                self.analysis(f"{result_dir}/iter{i}")

                rays_backup = []
                for wv in WAVE_RGB:
                    ray = self.sample_ring_arm_rays(
                        num_ring=num_ring,
                        num_arm=num_arm,
                        depth=depth,
                        spp=spp,
                        wvln=wv,
                        scale_pupil=1.10,
                    )
                    rays_backup.append(ray)

                center_ref = -self.psf_center(
                    points_obj=ray.o[:, :, 0, :], method="pinhole"
                )
                center_ref = center_ref.unsqueeze(-2).repeat(1, 1, spp, 1)

        # Compute error-adaptive weight mask every iteration
        with torch.no_grad():
            ray_wm = rays_backup[0].clone()
            ray_wm = self.trace2sensor(ray_wm)
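            ray_err_wm = ray_wm.o[..., :2] - center_ref
            ray_valid_wm = ray_wm.is_valid
            # Zero out invalid rays before squaring to avoid inf * 0 = nan
            ray_err_wm = torch.where(
                ray_valid_wm.bool().unsqueeze(-1), ray_err_wm, torch.zeros_like(ray_err_wm)
            )
            w_mask = (ray_err_wm**2).sum(-1).sum(-1)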
            w_mask /= ray_valid_wm.sum(-1) + EPSILON
            w_mask /= w_mask.mean() + EPSILON

        # === Compute RMS loss ===
        loss_rms = []
        for wv_idx, wv in enumerate(WAVE_RGB):
            ray = rays_backup[wv_idx].clone()
            ray = self.trace2sensor(ray)

            ray_xy = ray.o[..., :2]
            ray_valid = ray.is_valid
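            ray_err = ray_xy - center_ref
            # Zero out invalid rays before squaring to avoid inf * 0 = nan
            ray_err = torch.where(
                ray_valid.bool().unsqueeze(-1), ray_err, torch.zeros_like(ray_err)
            )

            l_rms = (ray_err**2).sum(-1).sum(-1)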
            l_rms /= ray_valid.sum(-1) + EPSILON
            l_rms = (l_rms + EPSILON).sqrt()

            l_rms_weighted = (l_rms * w_mask).sum()
            l_rms_weighted /= w_mask.sum() + EPSILON
            loss_rms.append(l_rms_weighted)

        loss_rms = sum(loss_rms) / len(loss_rms)

        w_focus = 0.05
        loss_focus = self.loss_infocus()

        w_reg = 0.05
        loss_reg, loss_dict = self.loss_reg()

        L_total = loss_rms + w_focus * loss_focus + w_reg * loss_reg

        optimizer.zero_grad()
        L_total.backward()
        optimizer.step()
        scheduler.step()

        pbar.set_postfix(loss_rms=loss_rms.item(), **loss_dict)
        pbar.update(1)

    pbar.close()
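
curriculum_design() recomputes the aperture radius only at evaluation steps, using progress = 0.5 * (1 + cos(pi * (1 - i / iterations))). This rises from 0 to 1 slowly at first, fastest mid-training, and slowly again at the end. The helper below is a hypothetical pure-Python restatement of that schedule for inspecting the radius at any iteration; it is not part of the library.

```python
import math


def curriculum_aperture(i, iterations, test_per_iter, r_final, start_frac=0.25):
    """Aperture radius in effect at iteration i of the curriculum schedule.

    Mirrors the update in curriculum_design(): the radius only changes at
    evaluation steps (multiples of test_per_iter) and is held in between.
    """
    i_eval = (i // test_per_iter) * test_per_iter  # most recent evaluation step
    r_start = r_final * start_frac
    progress = 0.5 * (1 + math.cos(math.pi * (1 - i_eval / iterations)))
    return min(r_start + (r_final - r_start) * progress, r_final)


# Radius at a few points of a 5000-iteration run with a final radius of 2.0
for i in (0, 1250, 2500, 2599, 3750, 5000):
    print(i, round(curriculum_aperture(i, 5000, 100, 2.0), 4))
```

Because the radius changes only at evaluation steps, a smaller test_per_iter gives a smoother curriculum at the cost of more frequent analysis and checkpointing.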

optimize_lbfgs

optimize_lbfgs(lr=0.01, iterations=500, test_per_iter=50, centroid=False, optim_mat=False, shape_control=True, max_iter=5, history_size=10, line_search_fn='strong_wolfe', result_dir=None)

Optimise the lens using L-BFGS, a quasi-Newton method.

L-BFGS uses curvature information for faster convergence on smooth objectives, making it well-suited for fine-tuning pre-optimised lenses where the loss landscape is relatively smooth.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| lr | float | Learning rate (step size). | 0.01 |
| iterations | int | Number of outer iterations. | 500 |
| test_per_iter | int | Evaluate every N iterations. | 50 |
| centroid | bool | Use the chief-ray centroid as the PSF centre reference. | False |
| optim_mat | bool | Optimise material parameters. | False |
| shape_control | bool | Apply correct_shape() at each evaluation step. | True |
| max_iter | int | Maximum L-BFGS inner iterations per step. | 5 |
| history_size | int | L-BFGS history size; raised to the number of parameter tensors if that is larger. | 10 |
| line_search_fn | str | Line search strategy: "strong_wolfe" or None. | 'strong_wolfe' |
| result_dir | str | Directory to save results; auto-generated (timestamped) if None. | None |
Source code in src/geolens_pkg/optim.py
def optimize_lbfgs(
    self,
    lr=0.01,
    iterations=500,
    test_per_iter=50,
    centroid=False,
    optim_mat=False,
    shape_control=True,
    max_iter=5,
    history_size=10,
    line_search_fn="strong_wolfe",
    result_dir=None,
):
    """Optimise the lens using L-BFGS, a quasi-Newton method.

    L-BFGS uses curvature information for faster convergence on smooth
    objectives, making it well-suited for fine-tuning pre-optimised lenses
    where the loss landscape is relatively smooth.

    Args:
        lr (float, optional): Learning rate (step size). Defaults to 0.01.
        iterations (int, optional): Number of outer iterations. Defaults to 500.
        test_per_iter (int, optional): Evaluate every N iterations.
            Defaults to 50.
        centroid (bool, optional): Use chief-ray centroid as PSF centre.
            Defaults to False.
        optim_mat (bool, optional): Optimise material parameters.
            Defaults to False.
        shape_control (bool, optional): Apply ``correct_shape()`` at each
            evaluation step. Defaults to True.
        max_iter (int, optional): Max L-BFGS inner iterations per step.
            Defaults to 5.
        history_size (int, optional): L-BFGS history size. Raised to the
            number of parameter tensors if that is larger. Defaults to 10.
        line_search_fn (str, optional): Line search strategy. Use
            ``"strong_wolfe"`` or ``None``. Defaults to ``"strong_wolfe"``.
        result_dir (str, optional): Directory to save results.
            Auto-generated (timestamped) if None.
    """
    # Experiment settings
    depth = DEPTH
    num_ring = 32
    num_arm = 8
    spp = 2048

    # Result directory and logger
    if result_dir is None:
        result_dir = f"./results/{datetime.now().strftime('%m%d-%H%M%S')}-DesignLens-LBFGS"

    os.makedirs(result_dir, exist_ok=True)
    if not logging.getLogger().hasHandlers():
        logger = logging.getLogger()
        logger.setLevel("DEBUG")
        fmt = logging.Formatter("%(asctime)s:%(levelname)s:%(message)s", "%Y-%m-%d %H:%M:%S")
        sh = logging.StreamHandler()
        sh.setFormatter(fmt)
        sh.setLevel("INFO")
        fh = logging.FileHandler(f"{result_dir}/output.log")
        fh.setFormatter(fmt)
        fh.setLevel("INFO")
        logger.addHandler(sh)
        logger.addHandler(fh)
    logging.info(
        f"[LBFGS] lr:{lr}, iterations:{iterations}, max_iter:{max_iter}, "
        f"history_size:{history_size}, line_search:{line_search_fn}, "
        f"num_ring:{num_ring}, num_arm:{num_arm}, rays_per_fov:{spp}."
    )

    # Initialize constraints and collect all learnable parameters
    # Reuse get_optimizer_params to ensure consistency with Adam path,
    # then extract the raw tensors for LBFGS (single lr, no param groups).
    # reparam=True normalizes all params to ~O(1) for single-lr L-BFGS
    self.init_constraints()
    dummy_lrs = [1e-3, 1e-3, 1e-3, 1e-3]
    param_groups = self.get_optimizer_params(
        lrs=dummy_lrs, optim_mat=optim_mat, reparam=True
    )
    all_params = []
    for pg in param_groups:
        p = pg["params"]
        if isinstance(p, list):
            all_params.extend(p)
        else:
            all_params.append(p)

    optimizer = torch.optim.LBFGS(
        all_params,
        lr=lr,
        max_iter=max_iter,
        history_size=max(history_size, len(all_params)),
        line_search_fn=line_search_fn,
    )
    logging.info(
        f"[LBFGS] history_size raised to {optimizer.defaults['history_size']} "
        f"(= max(history_size, n_params={len(all_params)})) for full-BFGS equivalence."
    )

    # Training loop
    pbar = tqdm(
        total=iterations + 1,
        desc="LBFGS Progress",
        postfix={"loss_rms": 0, "loss_focus": 0},
    )

    # Shared state for closure logging
    last_losses = {}

    for i in range(iterations + 1):
        # ===> Evaluate the lens
        if i % test_per_iter == 0:
            with torch.no_grad():
                if shape_control and i > 0:
                    self.correct_shape()

                self.write_lens_json(f"{result_dir}/iter{i}.json")
                self.analysis(f"{result_dir}/iter{i}")

                # Sample rays
                self.calc_pupil()
                rays_backup = []
                for wv in WAVE_RGB:
                    ray = self.sample_ring_arm_rays(
                        num_ring=num_ring, num_arm=num_arm, spp=spp,
                        depth=depth, wvln=wv, scale_pupil=1.05,
                        sample_more_off_axis=False,
                    )
                    rays_backup.append(ray)

                # Calculate ray centers
                if centroid:
                    center_ref = -self.psf_center(
                        points_obj=ray.o[:, :, 0, :], method="chief_ray"
                    )
                else:
                    center_ref = -self.psf_center(
                        points_obj=ray.o[:, :, 0, :], method="pinhole"
                    )
                center_ref = center_ref.unsqueeze(-2).repeat(1, 1, spp, 1)

        # ===> L-BFGS closure
        def closure():
            optimizer.zero_grad()

            loss_rms_ls = []
            for wv_idx, wv in enumerate(WAVE_RGB):
                ray = rays_backup[wv_idx].clone()
                ray = self.trace2sensor(ray)

                ray_xy = ray.o[..., :2]
                ray_valid = ray.is_valid
                ray_err = ray_xy - center_ref
                ray_err = torch.where(
                    ray_valid.bool().unsqueeze(-1),
                    ray_err,
                    torch.zeros_like(ray_err),
                )

                l_rms = (ray_err**2).sum(-1).sum(-1)
                l_rms /= ray_valid.sum(-1) + EPSILON
                l_rms = (l_rms + EPSILON).sqrt()
                loss_rms_ls.append(l_rms.mean())

            loss_rms = sum(loss_rms_ls) / len(loss_rms_ls)

            w_focus = 1.0
            loss_focus = self.loss_infocus()
            w_reg = 0.1
            loss_reg, loss_dict = self.loss_reg()

            L_total = loss_rms + w_focus * loss_focus + w_reg * loss_reg

            L_total.backward()

            last_losses["loss_rms"] = loss_rms.item()
            last_losses["loss_focus"] = loss_focus.item()
            last_losses.update(loss_dict)

            return L_total

        optimizer.step(closure)

        pbar.set_postfix(**{k: round(v, 6) for k, v in last_losses.items()})
        pbar.update(1)

    pbar.close()
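
optimize_lbfgs() wraps the loss in a closure because torch.optim.LBFGS may evaluate the objective several times per step (up to max_iter inner iterations, plus extra evaluations inside the strong-Wolfe line search). The sketch below illustrates the underlying technique in plain Python: the L-BFGS two-loop recursion with a simple Armijo backtracking line search, on a small quadratic. It is not PyTorch's implementation (which uses a strong-Wolfe search when line_search_fn="strong_wolfe"), and all names in it are hypothetical.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def two_loop(grad, s_hist, y_hist):
    """L-BFGS two-loop recursion: returns the search direction -H_k @ grad."""
    q = list(grad)
    saved = []  # (rho, alpha) pairs, newest first
    for s, y in zip(reversed(s_hist), reversed(y_hist)):
        rho = 1.0 / dot(y, s)
        alpha = rho * dot(s, q)
        saved.append((rho, alpha))
        q = [qi - alpha * yi for qi, yi in zip(q, y)]
    # Scaled identity as the initial inverse Hessian
    gamma = dot(s_hist[-1], y_hist[-1]) / dot(y_hist[-1], y_hist[-1]) if s_hist else 1.0
    r = [gamma * qi for qi in q]
    for s, y, (rho, alpha) in zip(s_hist, y_hist, reversed(saved)):  # oldest first
        beta = rho * dot(y, r)
        r = [ri + (alpha - beta) * si for ri, si in zip(r, s)]
    return [-ri for ri in r]


def lbfgs_minimize(f, grad_f, x0, history_size=10, iters=100, tol=1e-16):
    x, g = list(x0), grad_f(x0)
    s_hist, y_hist = [], []
    for _ in range(iters):
        if dot(g, g) < tol:
            break
        d = two_loop(g, s_hist, y_hist)
        # Armijo backtracking: these repeated evaluations of f at trial points
        # are why a closure-based optimiser must be able to recompute the loss.
        t, fx, slope = 1.0, f(x), dot(g, d)
        while f([xi + t * di for xi, di in zip(x, d)]) > fx + 1e-4 * t * slope and t > 1e-12:
            t *= 0.5
        x_new = [xi + t * di for xi, di in zip(x, d)]
        g_new = grad_f(x_new)
        s = [a - b for a, b in zip(x_new, x)]
        y = [a - b for a, b in zip(g_new, g)]
        if dot(s, y) > 1e-12:  # keep only pairs with positive curvature
            s_hist.append(s)
            y_hist.append(y)
            if len(s_hist) > history_size:  # drop the oldest pair
                s_hist.pop(0)
                y_hist.pop(0)
        x, g = x_new, g_new
    return x


# Quadratic f(x) = 0.5 x^T A x - b^T x, minimised at A^{-1} b = [0.2, 0.4]
A = [[3.0, 1.0], [1.0, 2.0]]
b = [1.0, 1.0]


def f(x):
    return 0.5 * dot(x, [dot(row, x) for row in A]) - dot(b, x)


def grad_f(x):
    return [dot(row, x) - bi for row, bi in zip(A, b)]


print(lbfgs_minimize(f, grad_f, [0.0, 0.0]))
```

The history_size argument bounds how many (s, y) pairs the two-loop recursion keeps; the oldest pair is discarded once the limit is reached.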

optimize_bfgs

optimize_bfgs(lrs=[0.001, 0.001, 0.01, 0.0001], lr=1.0, max_step_factor=10.0, iterations=500, test_per_iter=50, centroid=False, optim_mat=False, shape_control=True, reset_hessian_on_resample=False, result_dir=None, w_reg=0.1, w_focus=1.0)

Optimise the lens using full BFGS (no closure, Adam-like loop).

Maintains the complete N x N inverse Hessian approximation, giving superlinear (quasi-Newton) convergence on smooth objectives. Uses the same zero_grad -> backward -> step loop as Adam, so ray re-sampling at evaluation boundaries is straightforward.

The inverse Hessian is initialised as diag(per_param_lr) using the learning rates from lrs, so the first BFGS step is a gradient step scaled per parameter by the same learning rates Adam would use. Per-element clamping ensures no parameter changes by more than max_step_factor * its_lr per step.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| lrs | list | Per-type learning rates [d, c, k, ai], used to initialise the diagonal of the inverse Hessian and as per-element clamp bounds. | [0.001, 0.001, 0.01, 0.0001] |
| lr | float | Global multiplier on the BFGS direction, decayed to 0 over training by a cosine schedule. | 1.0 |
| max_step_factor | float | Per-element clamp: each scalar parameter changes by at most max_step_factor * lr_i per step. | 10.0 |
| iterations | int | Total outer iterations. | 500 |
| test_per_iter | int | Evaluate and re-sample rays every N iterations. | 50 |
| centroid | bool | Use the chief-ray centroid as the PSF centre reference. | False |
| optim_mat | bool | Include material parameters in optimisation. | False |
| shape_control | bool | Call correct_shape() at each evaluation step. | True |
| reset_hessian_on_resample | bool | Reset the inverse Hessian to its initial diagonal when rays are re-sampled. | False |
| result_dir | str or None | Directory for results; auto-generated if None. | None |
| w_reg | float | Weight for the regularization losses (loss_surface, loss_intersec, loss_thickness). | 0.1 |
| w_focus | float | Weight for the focus loss (loss_infocus). | 1.0 |
Source code in src/geolens_pkg/optim.py
def optimize_bfgs(
    self,
    lrs=[1e-3, 1e-3, 1e-2, 1e-4],
    lr=1.0,
    max_step_factor=10.0,
    iterations=500,
    test_per_iter=50,
    centroid=False,
    optim_mat=False,
    shape_control=True,
    reset_hessian_on_resample=False,
    result_dir=None,
    w_reg=0.1,
    w_focus=1.0,
):
    """Optimise the lens using full BFGS (no closure, Adam-like loop).

    Maintains the complete N x N inverse Hessian approximation, giving
    superlinear (quasi-Newton) convergence on smooth objectives. Uses the
    same ``zero_grad -> backward -> step`` loop as Adam, so ray re-sampling
    at evaluation boundaries is straightforward.

    The inverse Hessian is initialised as ``diag(per_param_lr)`` using the
    learning rates from ``lrs``, so the first BFGS step is a gradient step
    scaled per parameter by the same learning rates Adam would use.
    Per-element clamping ensures no parameter changes by more than
    ``max_step_factor * its_lr`` per step.

    Args:
        lrs (list): Per-type learning rates [d, c, k, ai], used to
            initialise the diagonal of the inverse Hessian and as
            per-element clamp bounds. Default: [1e-3, 1e-3, 1e-2, 1e-4].
        lr (float): Global multiplier on the BFGS direction.  Default: 1.0.
        max_step_factor (float): Per-element clamp — each scalar param
            changes by at most ``max_step_factor * lr_i`` per step.
            Default: 10.0.
        iterations (int): Total outer iterations. Default: 500.
        test_per_iter (int): Evaluate / re-sample rays every N iters.
            Default: 50.
        centroid (bool): Use chief-ray centroid as PSF centre reference.
        optim_mat (bool): Include material parameters in optimisation.
        shape_control (bool): Call ``correct_shape()`` at each eval step.
        reset_hessian_on_resample (bool): Reset inverse Hessian to
            initial diagonal when rays are re-sampled.  Default: False.
        result_dir (str | None): Directory for results. Auto-generated
            if None.
        w_reg (float): Weight for regularization losses (loss_surface,
            loss_intersec, loss_thickness). Default: 0.1.
        w_focus (float): Weight for focus loss (loss_infocus).
            Default: 1.0.
    """
    # Experiment settings
    depth = DEPTH
    num_ring = 32
    num_arm = 8
    spp = 2048

    # Result directory and logger
    if result_dir is None:
        result_dir = f"./results/{datetime.now().strftime('%m%d-%H%M%S')}-DesignLens-BFGS"

    os.makedirs(result_dir, exist_ok=True)
    if not logging.getLogger().hasHandlers():
        root_logger = logging.getLogger()
        root_logger.setLevel("DEBUG")
        fmt = logging.Formatter(
            "%(asctime)s:%(levelname)s:%(message)s", "%Y-%m-%d %H:%M:%S"
        )
        sh = logging.StreamHandler()
        sh.setFormatter(fmt)
        sh.setLevel("INFO")
        fh = logging.FileHandler(f"{result_dir}/output.log")
        fh.setFormatter(fmt)
        fh.setLevel("INFO")
        root_logger.addHandler(sh)
        root_logger.addHandler(fh)

    # Collect learnable params WITH per-param lr from get_optimizer_params
    # FullBFGS reads the "lr" from each param group to build diag(H_0)
    # reparam=True normalizes all params to ~O(1) for single-lr BFGS
    self.init_constraints()
    param_groups = self.get_optimizer_params(
        lrs=lrs, optim_mat=optim_mat, reparam=True
    )

    n_scalar = sum(
        (p.numel() if isinstance(p, torch.Tensor) else sum(pp.numel() for pp in p))
        for pg in param_groups
        for p in [pg["params"]]
    )
    optimizer = FullBFGS(param_groups, lr=lr, max_step_factor=max_step_factor)

    logging.info(
        f"[BFGS] lrs:{lrs}, lr:{lr}, max_step_factor:{max_step_factor}, "
        f"iterations:{iterations}, n_params:{n_scalar}, "
        f"H_size:{n_scalar}x{n_scalar}, "
        f"reset_on_resample:{reset_hessian_on_resample}, "
        f"num_ring:{num_ring}, num_arm:{num_arm}, rays_per_fov:{spp}."
    )

    # Cosine lr schedule: the multiplier decays from lr to 0 over training
    def get_cosine_lr(step, total_steps, lr_init):
        return lr_init * 0.5 * (1.0 + math.cos(math.pi * step / total_steps))

    # Training loop
    pbar = tqdm(
        total=iterations + 1,
        desc="BFGS Progress",
        postfix={"loss_rms": 0, "loss_focus": 0},
    )

    for i in range(iterations + 1):
        # Update lr via cosine schedule
        current_lr = get_cosine_lr(i, iterations, lr)
        optimizer.set_lr(current_lr)

        # ===> Evaluate the lens
        if i % test_per_iter == 0:
            with torch.no_grad():
                if shape_control and i > 0:
                    self.correct_shape()

                self.write_lens_json(f"{result_dir}/iter{i}.json")
                self.analysis(f"{result_dir}/iter{i}")

                # Re-sample rays
                self.calc_pupil()
                rays_backup = []
                for wv in WAVE_RGB:
                    ray = self.sample_ring_arm_rays(
                        num_ring=num_ring,
                        num_arm=num_arm,
                        spp=spp,
                        depth=depth,
                        wvln=wv,
                        scale_pupil=1.05,
                        sample_more_off_axis=False,
                    )
                    rays_backup.append(ray)

                # Calculate ray centers
                if centroid:
                    center_ref = -self.psf_center(
                        points_obj=ray.o[:, :, 0, :], method="chief_ray"
                    )
                else:
                    center_ref = -self.psf_center(
                        points_obj=ray.o[:, :, 0, :], method="pinhole"
                    )
                center_ref = center_ref.unsqueeze(-2).repeat(1, 1, spp, 1)

            if reset_hessian_on_resample and i > 0:
                optimizer.reset_hessian()

        # ===> Compute loss (RMS as in optimize(), but without error-adaptive weighting)
        loss_rms_ls = []
        for wv_idx, wv in enumerate(WAVE_RGB):
            ray = rays_backup[wv_idx].clone()
            ray = self.trace2sensor(ray)

            ray_xy = ray.o[..., :2]
            ray_valid = ray.is_valid
            ray_err = ray_xy - center_ref
            ray_err = torch.where(
                ray_valid.bool().unsqueeze(-1),
                ray_err,
                torch.zeros_like(ray_err),
            )

            l_rms = (ray_err**2).sum(-1).sum(-1)
            l_rms /= ray_valid.sum(-1) + EPSILON
            l_rms = (l_rms + EPSILON).sqrt()
            loss_rms_ls.append(l_rms.mean())

        loss_rms = sum(loss_rms_ls) / len(loss_rms_ls)

        # Regularization losses (weights from method parameters)
        loss_focus = self.loss_infocus()
        loss_reg, loss_dict = self.loss_reg()

        L_total = loss_rms + w_focus * loss_focus + w_reg * loss_reg

        # Standard backward + step (no closure)
        optimizer.zero_grad()
        L_total.backward()
        optimizer.step()

        pbar.set_postfix(
            loss_rms=loss_rms.item(), loss_focus=loss_focus.item(), **loss_dict
        )
        pbar.update(1)

    pbar.close()
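
optimize_bfgs() delegates the update to a FullBFGS optimiser whose source is not shown in this section. The plain-Python sketch below shows the textbook technique its docstring describes, assuming FullBFGS follows it: a dense inverse-Hessian approximation initialised as diag(lr_i), the standard BFGS inverse update, a per-element step clamp of max_step_factor * lr_i, and the same cosine multiplier as get_cosine_lr(). All function names are hypothetical.

```python
import math


def matvec(M, v):
    return [sum(m * vj for m, vj in zip(row, v)) for row in M]


def matmul(P, Q):
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q))) for j in range(len(Q[0]))]
            for i in range(len(P))]


def bfgs_update(H, s, y):
    """BFGS inverse update: H+ = (I - rho s y^T) H (I - rho y s^T) + rho s s^T."""
    n = len(s)
    rho = 1.0 / sum(a * b for a, b in zip(y, s))
    left = [[float(i == j) - rho * s[i] * y[j] for j in range(n)] for i in range(n)]
    right = [[float(i == j) - rho * y[i] * s[j] for j in range(n)] for i in range(n)]
    core = matmul(matmul(left, H), right)
    return [[core[i][j] + rho * s[i] * s[j] for j in range(n)] for i in range(n)]


def clamped_step(x, g, H, lrs, lr, max_step_factor):
    """x <- x - lr * H g, with each |dx_i| clamped to max_step_factor * lrs[i]."""
    dx = []
    for d_i, lr_i in zip(matvec(H, g), lrs):
        bound = max_step_factor * lr_i
        dx.append(max(-bound, min(bound, -lr * d_i)))
    return [xi + dxi for xi, dxi in zip(x, dx)], dx


def cosine_lr(step, total_steps, lr_init):
    # Same formula as get_cosine_lr() in optimize_bfgs(): lr_init -> 0
    return lr_init * 0.5 * (1.0 + math.cos(math.pi * step / total_steps))


def bfgs_minimize(grad_f, x0, lrs, iterations=200, lr=1.0, max_step_factor=10.0):
    n = len(x0)
    H = [[lrs[i] if i == j else 0.0 for j in range(n)] for i in range(n)]  # H_0 = diag(lr_i)
    x, g = list(x0), grad_f(x0)
    for i in range(iterations):
        x_new, s = clamped_step(x, g, H, lrs, cosine_lr(i, iterations, lr), max_step_factor)
        g_new = grad_f(x_new)
        y = [a - b for a, b in zip(g_new, g)]
        if sum(a * b for a, b in zip(s, y)) > 1e-12:  # skip pairs that break positive definiteness
            H = bfgs_update(H, s, y)
        x, g = x_new, g_new
    return x


# Gradient of f(x) = 0.5 x^T A x - b^T x; the minimiser is A^{-1} b = [0.2, 0.4]
A = [[3.0, 1.0], [1.0, 2.0]]
b = [1.0, 1.0]
print(bfgs_minimize(lambda x: [r - bi for r, bi in zip(matvec(A, x), b)], [0.0, 0.0], [0.1, 0.1]))
```

Storing H densely costs O(N^2) memory per step, which is why the method logs H_size; for lenses with many aspheric coefficients, the limited-memory variant in optimize_lbfgs() avoids that cost.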

get_optimizer_params

get_optimizer_params(lrs=[0.001, 0.001, 0.01, 0.0001], optim_mat=False, optim_surf_range=None, reparam=False)

Build optimizer parameter groups for the lens surfaces and the sensor distance.

Recommendation:

  • Cellphone lens: [d, c, k, a] = [1e-4, 1e-4, 1e-1, 1e-4]
  • Camera lens: [d, c, k, a] = [1e-3, 1e-4, 0, 0] (lr 0 for k and a)

Parameters:

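| Name | Type | Description | Default |
|---|---|---|---|
| lrs | list | Learning rates [d, c, k, a] for axial distance, curvature, conic constant and aspheric coefficients. Phase surfaces also read a fifth entry, lrs[4]. If lrs[0] is itself a list, per-surface rates are handled by get_optimizer_params_manual(). | [0.001, 0.001, 0.01, 0.0001] |
| optim_mat | bool | Whether to optimize material parameters. | False |
| optim_surf_range | list | Surface indices to optimize; all surfaces if None. | None |
| reparam | bool | Use normalized reparametrization for Aspheric surfaces. Set True for single-lr optimizers (BFGS). | False |

Returns:

| Type | Description |
|---|---|
| list | Parameter-group dicts ("params", "lr") for a torch.optim optimizer. |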

Source code in src/geolens_pkg/optim.py
def get_optimizer_params(
    self,
    lrs=[1e-3, 1e-3, 1e-2, 1e-4],
    optim_mat=False,
    optim_surf_range=None,
    reparam=False,
):
    """Get optimizer parameters for different lens surface.

    Recommendation:
        For cellphone lens: [d, c, k, a], [1e-4, 1e-4, 1e-1, 1e-4]
        For camera lens: [d, c, 0, 0], [1e-3, 1e-4, 0, 0]

    Args:
        lrs (list): learning rate for different parameters.
        optim_mat (bool): whether to optimize material. Defaults to False.
        optim_surf_range (list): surface indices to be optimized. Defaults to None.
        reparam (bool): use normalized reparametrization for Aspheric
            surfaces. Set True for single-lr optimizers (BFGS).
            Defaults to False.

    Returns:
        list: optimizer parameters
    """
    # Find surfaces to be optimized
    if optim_surf_range is None:
        # optim_surf_range = self.find_diff_surf()
        optim_surf_range = range(len(self.surfaces))

    # If per-surface learning-rate lists are given
    if isinstance(lrs[0], list):
        return self.get_optimizer_params_manual(
            lrs=lrs, optim_mat=optim_mat, optim_surf_range=optim_surf_range
        )

    # Optimize lens surface parameters
    params = []
    for surf_idx in optim_surf_range:
        surf = self.surfaces[surf_idx]

        if isinstance(surf, Aperture):
            params += surf.get_optimizer_params(lrs=[lrs[0]])

        elif isinstance(surf, Aspheric):
            params += surf.get_optimizer_params(
                lrs=lrs[:4], optim_mat=optim_mat, reparam=reparam
            )

        elif isinstance(surf, Phase):
            params += surf.get_optimizer_params(lrs=[lrs[0], lrs[4]])

        # elif isinstance(surf, GaussianRBF):
        #     params += surf.get_optimizer_params(lrs=lr, optim_mat=optim_mat)

        # elif isinstance(surf, NURBS):
        #     params += surf.get_optimizer_params(lrs=lr, optim_mat=optim_mat)

        elif isinstance(surf, Plane):
            params += surf.get_optimizer_params(lrs=[lrs[0]], optim_mat=optim_mat)

        # elif isinstance(surf, PolyEven):
        #     params += surf.get_optimizer_params(lrs=lr, optim_mat=optim_mat)

        elif isinstance(surf, Spheric):
            params += surf.get_optimizer_params(
                lrs=[lrs[0], lrs[1]], optim_mat=optim_mat
            )

        elif isinstance(surf, ThinLens):
            params += surf.get_optimizer_params(
                lrs=[lrs[0], lrs[1]], optim_mat=optim_mat
            )

        else:
            raise Exception(
                f"Surface type {surf.__class__.__name__} is not supported for optimization yet."
            )

    # Optimize sensor place
    self.d_sensor.requires_grad = True
    params += [{"params": self.d_sensor, "lr": lrs[0]}]

    return params
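
get_optimizer_params() dispatches on surface type, slices lrs accordingly, and then appends one group for the sensor distance. The toy sketch below reproduces only that dispatch logic, using hypothetical stand-in classes and string placeholders instead of tensors. The real surface classes build their own groups through their own get_optimizer_params(), whose exact keys are not shown here, and the "phase" key name is assumed. The sketch also makes one consequence of the source visible: a Phase surface reads lrs[4], so a four-element lrs list raises IndexError for a lens that contains one.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for the library's surface classes
@dataclass
class Spheric:
    name: str


@dataclass
class Aspheric:
    name: str


@dataclass
class Phase:
    name: str


def build_param_groups(surfaces, lrs):
    """Map each surface's learnable quantities to {"params", "lr"} groups.

    lrs = [d, c, k, a(, phase)]; strings stand in for parameter tensors.
    """
    groups = []
    for surf in surfaces:
        if isinstance(surf, Aspheric):
            keys, rates = ["d", "c", "k", "ai"], lrs[:4]
        elif isinstance(surf, Phase):
            keys, rates = ["d", "phase"], [lrs[0], lrs[4]]  # needs a 5th lr
        elif isinstance(surf, Spheric):
            keys, rates = ["d", "c"], lrs[:2]
        else:
            raise NotImplementedError(f"{type(surf).__name__} is not supported")
        groups += [{"params": f"{surf.name}.{k}", "lr": r} for k, r in zip(keys, rates)]
    # The sensor distance is always optimised, with the same lr as d
    groups.append({"params": "d_sensor", "lr": lrs[0]})
    return groups


for group in build_param_groups([Spheric("s0"), Aspheric("s1")], [1e-3, 1e-3, 1e-2, 1e-4]):
    print(group)
```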

get_optimizer

get_optimizer(lrs=[0.0001, 0.0001, 0.1, 0.0001], optim_surf_range=None, optim_mat=False)

Initialise the design constraints and create an Adam optimizer for the lens parameters.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| lrs | list | Learning rates for the parameter types [d, c, k, a]. | [0.0001, 0.0001, 0.1, 0.0001] |
| optim_surf_range | list | Surface indices to optimize; all surfaces if None. | None |
| optim_mat | bool | Whether to optimize material parameters. | False |

Returns:

| Type | Description |
|---|---|
| torch.optim.Adam | Adam optimizer over the parameter groups from get_optimizer_params(). |

Source code in src/geolens_pkg/optim.py
def get_optimizer(
    self,
    lrs=[1e-4, 1e-4, 1e-1, 1e-4],
    optim_surf_range=None,
    optim_mat=False,
):
    """Create an Adam optimizer for the lens parameters.

    Args:
        lrs (list): learning rates for the parameter groups [c, d, k, a]. Defaults to [1e-4, 1e-4, 1e-1, 1e-4].
        optim_surf_range (list): surface indices to be optimized. Defaults to None.
        optim_mat (bool): whether to optimize material. Defaults to False.

    Returns:
        torch.optim.Adam: optimizer over the selected parameter groups.
    """
    # Initialize lens design constraints (edge thickness, etc.)
    self.init_constraints()

    # Get optimizer (reparam=False for Adam — uses per-order lr scaling)
    params = self.get_optimizer_params(
        lrs=lrs, optim_surf_range=optim_surf_range, optim_mat=optim_mat,
        reparam=False,
    )
    optimizer = torch.optim.Adam(params)
    # optimizer = torch.optim.SGD(params)
    return optimizer

optimize(lrs, iterations, test_per_iter, centroid, optim_mat, shape_control, result_dir)

Adam optimizer loop with cosine warmup scheduler and error-adaptive field weighting.

lens.optimize(
    lrs=[1e-3, 1e-4, 1e-2, 1e-4],
    iterations=2000,
    test_per_iter=100,
    centroid=False,
    result_dir="results/finetune",
)
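The exact scheduler inside optimize() is not shown on this page. As a rough illustration of a "cosine warmup" schedule, here is a linear warmup followed by cosine decay written as a pure function of the step. The function name and the warmup_steps/min_lr parameters are hypothetical and not part of the GeoLens API.

```python
import math


def warmup_cosine_lr(step: int, total_steps: int, base_lr: float,
                     warmup_steps: int = 0, min_lr: float = 0.0) -> float:
    """Learning rate at `step` for linear warmup followed by cosine decay."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    if step < warmup_steps:
        # Linear ramp: base_lr / warmup_steps at step 0, base_lr at the last warmup step.
        return base_lr * (step + 1) / warmup_steps
    # Cosine decay from base_lr (end of warmup) down to min_lr (end of the run).
    decay_steps = max(total_steps - warmup_steps, 1)
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Sample the schedule for a 2000-iteration run with 100 warmup steps.
for step in (0, 99, 1000, 1999):
    print(step, warmup_cosine_lr(step, 2000, 1e-3, warmup_steps=100))
```

To drive a PyTorch optimizer with it, divide by base_lr and pass the resulting multiplier as lr_lambda to torch.optim.lr_scheduler.LambdaLR.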

optimize_lbfgs(iterations, lr, reparam, result_dir)

L-BFGS optimizer with optional parameter reparametrization.

lens.optimize_lbfgs(
    iterations=500,
    lr=0.5,
    reparam=True,
    result_dir="results/lbfgs",
)

curriculum_design(iterations, test_per_iter, lrs, optim_mat, result_dir)

Curriculum learning with gradual aperture opening and material optimization.
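One simple way to picture "gradual aperture opening" is a schedule that starts from a fraction of the full aperture and ramps linearly to the full aperture over the early iterations. This is a hypothetical sketch: the function name, the 0.5 start fraction, and the ramp length are assumptions, not the values curriculum_design uses.

```python
def aperture_fraction(step: int, total_steps: int, start_frac: float = 0.5,
                      ramp_frac: float = 0.5) -> float:
    """Fraction of the full aperture radius to trace at `step` (hypothetical schedule).

    Ramps linearly from `start_frac` to 1.0 over the first `ramp_frac` of the run,
    then holds the full aperture for the remaining iterations.
    """
    ramp_steps = max(int(total_steps * ramp_frac), 1)
    if step >= ramp_steps:
        return 1.0
    return start_frac + (1.0 - start_frac) * step / ramp_steps


# Example: aperture radius to use at a few iterations for a 4 mm full aperture.
full_r = 4.0
for it in (0, 250, 500, 900):
    print(it, full_r * aperture_fraction(it, 1000))
```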

Loss Functions

Method Returns
loss_rms() RGB spot RMS error
loss_infocus() On-axis focus penalty
loss_reg() Composite regularization
loss_surface() Surface shape penalty
loss_intersec() Self-intersection penalty
loss_thickness() Thickness / TTL penalty
loss_ray_angle() Chief ray angle penalty
loss_mat() Material bounds penalty
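As a sketch of what an RGB spot RMS loss measures, assume it averages, over three wavelengths, the RMS distance of sensor-plane hit points from their centroid. The centroid=False option of optimize() suggests other reference points are possible, so treat the centroid choice and the helper names below as assumptions.

```python
import math
from typing import Dict, List, Tuple

Point = Tuple[float, float]


def rms_spot_radius(points: List[Point]) -> float:
    """RMS distance of sensor-plane hit points (x, y) from their centroid."""
    if not points:
        raise ValueError("need at least one ray hit")
    n = len(points)
    cx = sum(x for x, _ in points) / n
    cy = sum(y for _, y in points) / n
    return math.sqrt(sum((x - cx) ** 2 + (y - cy) ** 2 for x, y in points) / n)


def rgb_spot_rms(spots: Dict[str, List[Point]]) -> float:
    """Average the per-wavelength RMS radii, e.g. for keys "R", "G", "B"."""
    return sum(rms_spot_radius(p) for p in spots.values()) / len(spots)


ring = [(1.0, 0.0), (-1.0, 0.0), (0.0, 1.0), (0.0, -1.0)]
print(rgb_spot_rms({"R": ring, "G": [(0.0, 0.0)] * 4, "B": ring}))
```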

Surface Operations

src.geolens_pkg.optim_ops.GeoLensSurfOps

Mixin providing surface geometry operations for GeoLens.

Methods:

Name Description
- add_aspheric

Convert a spherical surface to aspheric.

- increase_aspheric_order

Add higher-order polynomial terms.

- prune_surf

Size clear apertures by ray tracing.

- correct_shape

Fix lens geometry during optimisation.

add_aspheric

add_aspheric(surf_idx=None, ai_degree=4)

Convert a spherical surface to aspheric for improved aberration correction.

If surf_idx is given, converts that specific surface. Otherwise, automatically selects the best candidate following established optical design principles:

  1. First asphere: placed near the aperture stop (corrects spherical aberration).
  2. Subsequent aspheres: placed far from the stop (corrects field-dependent aberrations like coma, astigmatism, distortion).
  3. Prefer air-glass interfaces over cemented surfaces.
  4. Among candidates at similar stop-distances, prefer larger semi-diameter (higher marginal ray height → more SA contribution).
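The four selection rules above can be sketched as a small filtering routine. This is a hypothetical stand-in for the private _find_best_asphere_candidate, whose body is not shown here. The SurfInfo fields, the 0.5 mm similarity tolerance, treating rule 3 as a hard filter when any air-glass candidate exists, and passing the existing asphere count as an argument are all assumptions.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SurfInfo:
    idx: int            # surface index in the lens
    stop_dist: float    # |z_surface - z_stop| in mm
    air_glass: bool     # True for an air-glass interface, False if cemented
    semi_diam: float    # semi-diameter in mm
    is_spheric: bool    # only spherical surfaces are eligible for conversion


def pick_asphere(surfs: List[SurfInfo], num_existing_aspheres: int,
                 tol: float = 0.5) -> int:
    cands = [s for s in surfs if s.is_spheric]
    if not cands:
        raise ValueError("no eligible spherical surface")
    # Rule 3: prefer air-glass interfaces when any are available.
    air = [s for s in cands if s.air_glass]
    if air:
        cands = air
    if num_existing_aspheres == 0:
        # Rule 1: the first asphere goes near the stop.
        best = min(s.stop_dist for s in cands)
        close = [s for s in cands if s.stop_dist - best <= tol]
    else:
        # Rule 2: later aspheres go far from the stop.
        best = max(s.stop_dist for s in cands)
        close = [s for s in cands if best - s.stop_dist <= tol]
    # Rule 4: among similar stop distances, prefer the larger semi-diameter.
    return max(close, key=lambda s: s.semi_diam).idx
```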

The new surface starts with k=0 and all polynomial coefficients at zero, so it is initially identical to the original spherical surface.

Note

After calling this method, any existing optimizer is stale. Call get_optimizer() again to include the new parameters.

Parameters:

Name Type Description Default
surf_idx int or None

Surface index to convert. If None, auto-selects the best candidate.

None
ai_degree int

Number of even-order aspheric coefficients [a4, a6, a8, ...]. Defaults to 4.

4

Returns:

Name Type Description
int

Index of the converted surface.

Raises:

Type Description
IndexError

If surf_idx is out of range.

ValueError

If surf_idx points to a non-Spheric surface, or no eligible candidate exists for auto-selection.

References

Design principles from research/aspheric_design_principles.md.

Source code in src/geolens_pkg/optim_ops.py
@torch.no_grad()
def add_aspheric(self, surf_idx=None, ai_degree=4):
    """Convert a spherical surface to aspheric for improved aberration correction.

    If ``surf_idx`` is given, converts that specific surface. Otherwise,
    automatically selects the best candidate following established optical
    design principles:

    1. First asphere: placed near the aperture stop (corrects spherical
       aberration).
    2. Subsequent aspheres: placed far from the stop (corrects field-dependent
       aberrations like coma, astigmatism, distortion).
    3. Prefer air-glass interfaces over cemented surfaces.
    4. Among candidates at similar stop-distances, prefer larger semi-diameter
       (higher marginal ray height → more SA contribution).

    The new surface starts with ``k=0`` and all polynomial coefficients at
    zero, so it is initially identical to the original spherical surface.

    Note:
        After calling this method, any existing optimizer is stale.
        Call ``get_optimizer()`` again to include the new parameters.

    Args:
        surf_idx (int or None): Surface index to convert. If ``None``,
            auto-selects the best candidate.
        ai_degree (int): Number of even-order aspheric coefficients
            ``[a4, a6, a8, ...]``. Defaults to 4.

    Returns:
        int: Index of the converted surface.

    Raises:
        IndexError: If ``surf_idx`` is out of range.
        ValueError: If ``surf_idx`` points to a non-Spheric surface, or no
            eligible candidate exists for auto-selection.

    References:
        Design principles from ``research/aspheric_design_principles.md``.
    """
    if surf_idx is not None:
        if surf_idx < 0 or surf_idx >= len(self.surfaces):
            raise IndexError(
                f"surf_idx={surf_idx} out of range [0, {len(self.surfaces) - 1}]."
            )
        if not isinstance(self.surfaces[surf_idx], Spheric):
            raise ValueError(
                f"Surface {surf_idx} is {type(self.surfaces[surf_idx]).__name__}, "
                f"expected Spheric. To add higher-order terms to an existing "
                f"Aspheric surface, use increase_aspheric_order(surf_idx={surf_idx})."
            )
        self._spheric_to_aspheric(surf_idx, ai_degree)
        logging.info(
            f"Converted surface {surf_idx} from Spheric to Aspheric "
            f"(ai_degree={ai_degree})."
        )
        return surf_idx

    # Auto-select best candidate
    surf_idx = self._find_best_asphere_candidate()
    self._spheric_to_aspheric(surf_idx, ai_degree)
    logging.info(
        f"Auto-selected surface {surf_idx} as best asphere candidate. "
        f"Converted to Aspheric (ai_degree={ai_degree})."
    )
    return surf_idx

increase_aspheric_order

increase_aspheric_order(surf_idx=None, increment=1)

Add higher-order polynomial terms to an existing Aspheric surface.

Appends increment additional even-order coefficients (initialised to zero). For example, degree 4 [a4, a6, a8, a10] becomes degree 5 [a4, a6, a8, a10, a12] after increment=1.

Follows the principle of start low, add incrementally: increase order only when residual higher-order aberrations persist after optimisation at the current order.
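To see why appending zero coefficients leaves the lens unchanged, here is the standard even-asphere sag with the coefficient list assumed to start at the r^4 term, matching the [a4, a6, a8, a10] example above. The helper names are illustrative and are not GeoLens methods.

```python
import math
from typing import List


def even_asphere_sag(r: float, c: float, k: float, ai: List[float]) -> float:
    """Sag z(r) with ai = [a4, a6, a8, ...] (assumed convention).

    z = c r^2 / (1 + sqrt(1 - (1 + k) c^2 r^2)) + sum_j ai[j] * r^(4 + 2j)
    """
    r2 = r * r
    base = c * r2 / (1.0 + math.sqrt(1.0 - (1.0 + k) * c * c * r2))
    poly = sum(a * r ** (4 + 2 * j) for j, a in enumerate(ai))
    return base + poly


def increase_order(ai: List[float], increment: int = 1) -> List[float]:
    """Append `increment` zero-valued higher-order coefficients."""
    if increment < 1:
        raise ValueError(f"increment must be >= 1, got {increment}.")
    return list(ai) + [0.0] * increment


ai4 = [1e-4, -2e-6, 3e-8, 0.0]
ai5 = increase_order(ai4)
print(even_asphere_sag(2.0, 0.05, -0.5, ai4), even_asphere_sag(2.0, 0.05, -0.5, ai5))
```

The new coefficient only starts to matter once the optimizer moves it away from zero, which is why a fresh optimizer is needed to pick it up.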

Note

After calling this method, any existing optimizer is stale. Call get_optimizer() again to include the new parameters.

Parameters:

Name Type Description Default
surf_idx int or None

Surface index. If None, auto-selects the best candidate (see _find_best_order_increase_candidate).

None
increment int

Number of additional coefficients to add. Defaults to 1.

1

Returns:

Name Type Description
int

Index of the surface whose order was increased.

Raises:

Type Description
IndexError

If surf_idx is out of range.

ValueError

If surf_idx is given but is not Aspheric, if no Aspheric surfaces exist when surf_idx is None, or if increment < 1.

Source code in src/geolens_pkg/optim_ops.py
@torch.no_grad()
def increase_aspheric_order(self, surf_idx=None, increment=1):
    """Add higher-order polynomial terms to an existing Aspheric surface.

    Appends ``increment`` additional even-order coefficients (initialised
    to zero). For example, degree 4 ``[a4, a6, a8, a10]`` becomes degree 5
    ``[a4, a6, a8, a10, a12]`` after ``increment=1``.

    Follows the principle of *start low, add incrementally*: increase
    order only when residual higher-order aberrations persist after
    optimisation at the current order.

    Note:
        After calling this method, any existing optimizer is stale.
        Call ``get_optimizer()`` again to include the new parameters.

    Args:
        surf_idx (int or None): Surface index. If ``None``, auto-selects
            the best candidate (see ``_find_best_order_increase_candidate``).
        increment (int): Number of additional coefficients to add.
            Defaults to 1.

    Returns:
        int: Index of the surface whose order was increased.

    Raises:
        IndexError: If ``surf_idx`` is out of range.
        ValueError: If ``surf_idx`` is given but is not Aspheric, if
            no Aspheric surfaces exist when ``surf_idx`` is ``None``,
            or if ``increment`` < 1.
    """
    if increment < 1:
        raise ValueError(f"increment must be >= 1, got {increment}.")
    if surf_idx is not None:
        if surf_idx < 0 or surf_idx >= len(self.surfaces):
            raise IndexError(
                f"surf_idx={surf_idx} out of range [0, {len(self.surfaces) - 1}]."
            )
    else:
        surf_idx = self._find_best_order_increase_candidate()

    surf = self.surfaces[surf_idx]
    if not isinstance(surf, Aspheric):
        raise ValueError(
            f"Surface {surf_idx} is {type(surf).__name__}, expected Aspheric."
        )
    old_degree = surf.ai_degree
    self._increase_surface_order(surf, increment)
    logging.info(
        f"Surface {surf_idx}: aspheric order {old_degree} -> {surf.ai_degree}."
    )

    return surf_idx

prune_surf

prune_surf(expand_factor=None, mounting_margin=None)

Prune surfaces to allow all valid rays to go through.

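Determines the clear aperture of each surface by ray tracing at the full field of view and adds a margin. It then shrinks radii wherever two surfaces across an air gap would come closer than a minimum clearance (0.05 mm when the sensor radius is under 10 mm, otherwise 0.1 mm), and clamps the aperture stop radius to the clear apertures of its neighboring surfaces.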
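The margin rule from step 3 of the prune_surf source can be restated as a standalone function. pruned_radius is a hypothetical name; the 10 % default, the [0.1, 2.0] mm clamp, and the mounting_margin override are taken from the source.

```python
from typing import Optional


def pruned_radius(r_clear: float, expand_factor: Optional[float] = None,
                  mounting_margin: Optional[float] = None) -> float:
    """New semi-diameter [mm] from a ray-traced clear aperture radius [mm]."""
    if mounting_margin is not None:
        # An absolute mounting margin replaces the proportional expansion.
        return r_clear + mounting_margin
    if expand_factor is None:
        expand_factor = 0.10  # default used when None
    # Proportional margin, clamped to between 0.1 mm and 2.0 mm.
    r_expand = max(min(r_clear * expand_factor, 2.0), 0.1)
    return r_clear + r_expand


print(pruned_radius(0.5), pruned_radius(5.0), pruned_radius(50.0))
```

The clamp keeps tiny surfaces from getting a negligible margin and large ones from growing by more than 2 mm.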

Parameters:

Name Type Description Default
expand_factor float

Fractional expansion applied to the ray-traced clear aperture radius. Defaults to 0.10 (10 %) when None; the resulting margin is clamped to [0.1, 2.0] mm.

None
mounting_margin float

Absolute margin [mm] added to the clear aperture for mechanical mounting. When given, this replaces the proportional expand_factor expansion.

None
Source code in src/geolens_pkg/optim_ops.py
@torch.no_grad()
def prune_surf(self, expand_factor=None, mounting_margin=None):
    """Prune surfaces to allow all valid rays to go through.

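    Determines the clear aperture of each surface by ray tracing at the
    full field of view and adds a margin. It then shrinks radii wherever
    two surfaces across an air gap would come closer than a minimum
    clearance (0.05 mm when the sensor radius is under 10 mm, otherwise
    0.1 mm), and clamps the aperture stop radius to the clear apertures
    of its neighboring surfaces.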

    Args:
        expand_factor (float, optional): Fractional expansion applied to
            the ray-traced clear aperture radius. Defaults to 0.10 (10 %)
            when None; the resulting margin is clamped to [0.1, 2.0] mm.
        mounting_margin (float, optional): Absolute margin [mm] added to
            the clear aperture for mechanical mounting.  When given, this
            replaces the proportional ``expand_factor`` expansion.
    """
    surface_range = self.find_diff_surf()
    num_surfs = len(self.surfaces)

    # Set expansion factor
    if expand_factor is None:
        expand_factor = 0.10

    # ------------------------------------------------------------------
    # 1. Temporarily remove radius limits so the trace is unclipped
    # ------------------------------------------------------------------
    saved_radii = [self.surfaces[i].r for i in range(num_surfs)]
    for i in surface_range:
        self.surfaces[i].r = self.surfaces[i].max_height()

    # ------------------------------------------------------------------
    # 2. Trace rays at full FoV to find maximum ray height per surface
    # ------------------------------------------------------------------
    if self.rfov is not None:
        fov_deg = self.rfov * 180 / torch.pi
    elif self.rfov_eff is not None:
        fov_deg = self.rfov_eff * 180 / torch.pi
    else:
        fov = np.arctan(self.r_sensor / self.foclen)
        fov_deg = float(fov) * 180 / torch.pi
        print(f"Using fov_deg: {fov_deg} during surface pruning.")

    fov_y = [f * fov_deg / 10 for f in range(0, 11)]
    ray = self.sample_from_fov(
        fov_x=[0.0], fov_y=fov_y, num_rays=SPP_CALC, scale_pupil=1.0
    )
    _, ray_o_record = self.trace2sensor(ray=ray, record=True)

    # Ray record, shape [num_rays, num_surfaces + 2, 3]
    ray_o_record = torch.stack(ray_o_record, dim=-2)
    ray_o_record = torch.nan_to_num(ray_o_record, 0.0)
    ray_o_record = ray_o_record.reshape(-1, ray_o_record.shape[-2], 3)

    # Compute the maximum ray height for each surface
    ray_r_record = (ray_o_record[..., :2] ** 2).sum(-1).sqrt()
    surf_r_max = ray_r_record.max(dim=0)[0][1:-1]

    # Restore original radii before updating
    for i in range(num_surfs):
        self.surfaces[i].r = saved_radii[i]

    # ------------------------------------------------------------------
    # 3. Set new surface radii = ray-traced clear aperture + margin
    # ------------------------------------------------------------------
    for i in surface_range:
        if surf_r_max[i] > 0:
            r_clear = surf_r_max[i].item()
            if mounting_margin is not None:
                r_new = r_clear + mounting_margin
            else:
                r_expand = r_clear * expand_factor
                r_expand = max(min(r_expand, 2.0), 0.1)
                r_new = r_clear + r_expand
            self.surfaces[i].update_r(r_new)
        else:
            print(f"No valid rays for Surf {i}, expand existing radius.")
            if mounting_margin is not None:
                self.surfaces[i].update_r(self.surfaces[i].r + mounting_margin)
            else:
                r_expand = self.surfaces[i].r * expand_factor
                r_expand = max(min(r_expand, 2.0), 0.1)
                self.surfaces[i].update_r(self.surfaces[i].r + r_expand)

    # ------------------------------------------------------------------
    # 4. Air gap clearance check
    #    For each air gap (surface i with mat2 = "air"), ensure that
    #    surfaces do not physically intersect at the clear aperture edge.
    # ------------------------------------------------------------------
    if self.r_sensor < 10.0:
        air_gap_min = 0.05  # mm
    else:
        air_gap_min = 0.1  # mm

    for i in range(num_surfs - 1):
        if self.surfaces[i].mat2.name != "air":
            continue
        if isinstance(self.surfaces[i], Aperture):
            continue

        curr = self.surfaces[i]
        nxt = self.surfaces[i + 1]
        r_check = min(curr.r, nxt.r)

        if r_check <= 0:
            continue

        # Check gap at multiple radial points along the edge
        r_pts = torch.linspace(0.5 * r_check, r_check, 8, device=self.device)
        z_curr = curr.surface_with_offset(r_pts, 0.0, valid_check=False)
        z_nxt = nxt.surface_with_offset(r_pts, 0.0, valid_check=False)
        min_gap = (z_nxt - z_curr).min().item()

        if min_gap < air_gap_min:
            # Shrink radius until air gap is met (binary search)
            r_lo, r_hi = 0.0, r_check
            for _ in range(20):
                r_mid = (r_lo + r_hi) / 2
                r_pts = torch.linspace(0.5 * r_mid, r_mid, 8, device=self.device)
                z_c = curr.surface_with_offset(r_pts, 0.0, valid_check=False)
                z_n = nxt.surface_with_offset(r_pts, 0.0, valid_check=False)
                if (z_n - z_c).min().item() >= air_gap_min:
                    r_lo = r_mid
                else:
                    r_hi = r_mid

            r_safe = r_lo
            if r_safe > 0 and r_safe < r_check:
                print(
                    f"Surf {i}-{i+1}: air gap {min_gap:.3f} mm "
                    f"< {air_gap_min} mm, shrinking radius {r_check:.3f} -> {r_safe:.3f} mm."
                )
                if curr.r > r_safe:
                    curr.update_r(r_safe)
                if nxt.r > r_safe:
                    nxt.update_r(r_safe)

    # ------------------------------------------------------------------
    # 5. Validate aperture radius consistency
    #    The aperture (stop) radius should not exceed the clear aperture
    #    of its neighboring surfaces.
    # ------------------------------------------------------------------
    if self.aper_idx is not None:
        aper = self.surfaces[self.aper_idx]
        # Find neighboring non-aperture surfaces
        neighbor_r = []
        if self.aper_idx > 0:
            neighbor_r.append(self.surfaces[self.aper_idx - 1].r)
        if self.aper_idx < num_surfs - 1:
            neighbor_r.append(self.surfaces[self.aper_idx + 1].r)

        if neighbor_r:
            max_aper_r = min(neighbor_r)
            if aper.r > max_aper_r:
                print(
                    f"Aperture radius {aper.r:.3f} mm exceeds neighbor "
                    f"clear aperture {max_aper_r:.3f} mm, clamping."
                )
                aper.r = max_aper_r
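Step 4 above shrinks a pair of radii by bisection until the sampled air gap meets the minimum clearance. The sketch below reproduces that search for two spherical surfaces in pure Python. sphere_surface, min_gap, and safe_radius are hypothetical helpers; the sampling (8 points between 0.5 r and r) and the 20 bisection iterations mirror the source.

```python
import math
from typing import Callable

Sag = Callable[[float], float]


def sphere_surface(c: float, z0: float) -> Sag:
    """Axial position z(r) = z0 + spherical sag, valid while |c * r| < 1."""
    def z(r: float) -> float:
        return z0 + c * r * r / (1.0 + math.sqrt(1.0 - c * c * r * r))
    return z


def min_gap(z_curr: Sag, z_next: Sag, r: float, n: int = 8) -> float:
    """Smallest axial gap sampled at n evenly spaced radii between 0.5*r and r."""
    pts = [0.5 * r + 0.5 * r * i / (n - 1) for i in range(n)]
    return min(z_next(p) - z_curr(p) for p in pts)


def safe_radius(z_curr: Sag, z_next: Sag, r_check: float, gap_min: float,
                iters: int = 20) -> float:
    """Bisection for the largest radius whose sampled gap stays >= gap_min."""
    lo, hi = 0.0, r_check
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if min_gap(z_curr, z_next, mid) >= gap_min:
            lo = mid
        else:
            hi = mid
    return lo


# A convex rear surface at z = 0 facing a concave front surface at z = 1 mm.
front, back = sphere_surface(0.1, 0.0), sphere_surface(-0.1, 1.0)
print(safe_radius(front, back, r_check=5.0, gap_min=0.1))
```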

correct_shape

correct_shape(expand_factor=None, mounting_margin=None)

Correct wrong lens shape during lens design optimization.

Applies correction rules to ensure valid lens geometry:
  1. Move the first surface to z = 0.0
  2. Fix aperture distance if aperture is at the front
  3. Prune all surfaces to allow valid rays through
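Rules 1 and 2 reduce to simple arithmetic on axial positions. The sketch below applies them to plain floats instead of surface objects; the function names are hypothetical, and the 0.05 mm base gap is the value used in the correct_shape source.

```python
from typing import List, Tuple


def shift_to_origin(z: List[float], d_sensor: float) -> Tuple[List[float], float]:
    """Rule 1: translate all positions so the first surface sits at z = 0."""
    offset = z[0]
    return [zi - offset for zi in z], d_sensor - offset


def front_aperture_gap(first_lens_sag_at_aper_r: float, base_gap: float = 0.05) -> float:
    """Rule 2: spacing between a front stop and the first lens surface.

    A negative sag at the aperture radius means the lens curves toward the
    stop, so the gap is widened by that amount to avoid a collision.
    """
    extra = -first_lens_sag_at_aper_r
    return base_gap + extra if extra > 0 else base_gap


def reposition_after_front_stop(z: List[float], d_sensor: float,
                                gap: float) -> Tuple[List[float], float]:
    """Move every surface after the stop (and the sensor) so z[1] equals gap."""
    delta = z[1] - gap
    return [z[0]] + [zi - delta for zi in z[1:]], d_sensor - delta


z, d_sensor = shift_to_origin([2.0, 4.0, 6.5], 30.0)
print(reposition_after_front_stop(z, d_sensor, front_aperture_gap(-0.3)))
```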

Parameters:

Name Type Description Default
expand_factor float

Height expansion factor for surface pruning. If None, prune_surf() falls back to its default of 0.10 (10 %). Defaults to None.

None
mounting_margin float

Absolute mounting margin [mm] for surface pruning. Passed through to prune_surf().

None

Returns:

Name Type Description
bool

True if any shape corrections were made, False otherwise.

Source code in src/geolens_pkg/optim_ops.py
@torch.no_grad()
def correct_shape(self, expand_factor=None, mounting_margin=None):
    """Correct wrong lens shape during lens design optimization.

    Applies correction rules to ensure valid lens geometry:
        1. Move the first surface to z = 0.0
        2. Fix aperture distance if aperture is at the front
        3. Prune all surfaces to allow valid rays through

    Args:
        expand_factor (float, optional): Height expansion factor for surface pruning.
            If None, ``prune_surf`` falls back to 0.10 (10 %). Defaults to None.
        mounting_margin (float, optional): Absolute mounting margin [mm] for
            surface pruning.  Passed through to :meth:`prune_surf`.

    Returns:
        bool: True if any shape corrections were made, False otherwise.
    """
    aper_idx = self.aper_idx
    optim_surf_range = self.find_diff_surf()
    shape_changed = False

    # Rule 1: Move the first surface to z = 0.0
    move_dist = self.surfaces[0].d.item()
    if move_dist != 0.0:
        for surf in self.surfaces:
            surf.d -= move_dist
        self.d_sensor -= move_dist
        shape_changed = True

    # Rule 2: Fix aperture distance to the first surface if the aperture is at the front.
    if aper_idx == 0:
        d_aper = 0.05

        # If the first surface is concave, use the maximum negative sag.
        aper_r = torch.tensor(self.surfaces[aper_idx].r, device=self.device)
        sag1 = -self.surfaces[aper_idx + 1].sag(aper_r, 0).item()

        if sag1 > 0:
            d_aper += sag1

        # Update position of all surfaces.
        delta_aper = self.surfaces[1].d.item() - d_aper
        if delta_aper != 0.0:
            for i in optim_surf_range:
                self.surfaces[i].d -= delta_aper
            self.d_sensor -= delta_aper
            shape_changed = True

    # Rule 3: Prune all surfaces
    self.prune_surf(expand_factor=expand_factor, mounting_margin=mounting_margin)

    if shape_changed:
        print("Surface shape corrected.")
    return shape_changed

match_materials

match_materials(mat_table='CDGM')

Match lens materials to a glass catalog.

Parameters:

Name Type Description Default
mat_table str

Glass catalog name. Common options include 'CDGM', 'SCHOTT', 'OHARA'. Defaults to 'CDGM'.

'CDGM'
Source code in src/geolens_pkg/optim_ops.py
@torch.no_grad()
def match_materials(self, mat_table="CDGM"):
    """Match lens materials to a glass catalog.

    Args:
        mat_table (str, optional): Glass catalog name. Common options include
            'CDGM', 'SCHOTT', 'OHARA'. Defaults to 'CDGM'.
    """
    for surf in self.surfaces:
        surf.mat2.match_material(mat_table=mat_table)
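How match_material picks a glass is not shown on this page. A common approach is a nearest-neighbour search in (n_d, V_d) space; the sketch below assumes that approach. The Abbe-number weight is arbitrary, and the catalog values are approximate and for illustration only.

```python
import math
from typing import Dict, Tuple

# Approximate (n_d, V_d) values, for illustration only.
CATALOG: Dict[str, Tuple[float, float]] = {
    "N-BK7": (1.5168, 64.17),
    "N-F2": (1.6200, 36.43),
    "N-SF11": (1.7847, 25.68),
}


def nearest_glass(nd: float, vd: float, catalog: Dict[str, Tuple[float, float]],
                  vd_weight: float = 0.01) -> str:
    """Return the name of the catalog glass closest to (nd, vd).

    Distance is hypot(nd - nd_i, vd_weight * (vd - vd_i)); the weight is an
    assumption that puts Abbe-number differences on a scale similar to index ones.
    """
    return min(
        catalog,
        key=lambda name: math.hypot(nd - catalog[name][0],
                                    vd_weight * (vd - catalog[name][1])),
    )


print(nearest_glass(1.52, 63.0, CATALOG))
```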

add_aspheric()

Convert a spherical surface to aspheric, auto-selecting the surface if no index is given.

increase_aspheric_order()

Increase the order of existing aspheric polynomials.

prune_surf(expand_factor)

Resize surface clear apertures to actual ray footprint plus margin.

correct_shape()

Move the first surface to z = 0, fix the spacing behind a front aperture, and prune surface radii via prune_surf().

match_materials()

Snap floating refractive indices to the nearest real glass in the catalog.


File I/O

src.geolens_pkg.io.GeoLensIO

Mixin providing file I/O for GeoLens.

Supports reading and writing lens prescriptions in three formats:

  • JSON (primary): human-readable, supports parenthesised optimisable parameters, e.g. "(d)": 5.0.
  • Zemax .zmx: industry-standard sequential lens file.
  • Code V .seq: Code V sequential format (read-only).

This class is not instantiated directly; it is mixed into GeoLens (deeplens.optics.geolens.GeoLens).

read_lens_zmx

read_lens_zmx(filename='./test.zmx')

Load the lens from a Zemax .zmx sequential lens file.

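Parses STANDARD, EVENASPH, and BINARY_2 surface types, glass materials, field definitions (YFLN), and entrance pupil settings (ENPD/FLOA). A STANDARD surface with air on both sides is read as an aperture; unsupported surface types are skipped with a message.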
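The first pass of read_lens_zmx groups KEY value lines under their SURF index. Below is a trimmed-down, standalone version of that loop plus the YFLN field parsing, assuming (as the source comments do) that the largest field value equals 0.99 times the half field of view. parse_zmx_text and SAMPLE_ZMX are hypothetical.

```python
import math
from typing import Dict, Optional, Tuple

SAMPLE_ZMX = """MODE SEQ
YFLN 0.0 14.14 19.8
SURF 0
TYPE STANDARD
CURV 0.0
DISZ INFINITY
SURF 1
STOP
TYPE STANDARD
DIAM 2.0
DISZ 1.0
SURF 2
TYPE EVENASPH
CURV 0.05
PARM 2 1.0E-4
DISZ 3.0
DIAM 5.0
"""


def parse_zmx_text(text: str) -> Tuple[Dict[int, Dict[str, object]], Optional[float]]:
    """Return ({surf_idx: {key: value}}, rfov in radians or None)."""
    surfs: Dict[int, Dict[str, object]] = {}
    current: Optional[int] = None
    rfov_rad: Optional[float] = None
    for raw in text.splitlines():
        line = raw.strip()
        if line.startswith("SURF"):
            current = int(line.split()[1])
            surfs[current] = {}
        elif current is not None and line:
            if line == "STOP":
                surfs[current]["STOP"] = True
                continue
            parts = line.split(maxsplit=1)
            if len(parts) == 1:
                continue  # bare keywords without a value are ignored
            key, value = parts
            if key in ("PARM", "XDAT"):
                # "PARM 2 1.0E-4" is stored as {"PARM2": "1.0E-4"}
                num, val = value.split()[:2]
                surfs[current][key + num] = val
            else:
                surfs[current][key] = value
        elif line.startswith("YFLN"):
            fields = [abs(float(x)) for x in line.split()[1:] if float(x) != 0.0]
            if fields:
                rfov_rad = math.radians(max(fields) / 0.99)
    return surfs, rfov_rad


surfs, rfov = parse_zmx_text(SAMPLE_ZMX)
print(surfs[2], math.degrees(rfov))
```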

Parameters:

Name Type Description Default
filename str

Path to the .zmx file. Supports both UTF-8 and UTF-16 encoded files. Defaults to './test.zmx'.

'./test.zmx'

Returns:

Name Type Description
GeoLens

self, for method chaining.

Source code in src/geolens_pkg/io.py
def read_lens_zmx(self, filename="./test.zmx"):
    """Load the lens from a Zemax .zmx sequential lens file.

    Parses STANDARD, EVENASPH, and BINARY_2 surface types, glass materials,
    field definitions (YFLN), and entrance pupil settings (ENPD/FLOA).
    A STANDARD surface with air on both sides is read as an aperture.

    Args:
        filename (str, optional): Path to the .zmx file. Supports both
            UTF-8 and UTF-16 encoded files. Defaults to './test.zmx'.

    Returns:
        GeoLens: ``self``, for method chaining.
    """
    # Read .zmx file
    try:
        with open(filename, "r", encoding="utf-8") as file:
            lines = file.readlines()
    except UnicodeDecodeError:
        with open(filename, "r", encoding="utf-16") as file:
            lines = file.readlines()

    # Iterate through the lines and extract SURF dict
    surfs_dict = {}
    current_surf = None
    for line in lines:
        # Strip leading/trailing whitespace for consistent parsing
        stripped_line = line.strip()

        if stripped_line.startswith("SURF"):
            current_surf = int(stripped_line.split()[1])
            surfs_dict[current_surf] = {}

        elif current_surf is not None and stripped_line != "":
            if stripped_line == "STOP":
                surfs_dict[current_surf]["STOP"] = True
                continue
            if len(stripped_line.split(maxsplit=1)) == 1:
                continue
            else:
                key, value = stripped_line.split(maxsplit=1)
                if key == "PARM":
                    new_key = "PARM" + value.split()[0]
                    new_value = value.split()[1]
                    surfs_dict[current_surf][new_key] = new_value
                elif key == "XDAT":
                    new_key = "XDAT" + value.split()[0]
                    new_value = value.split()[1]
                    surfs_dict[current_surf][new_key] = new_value
                else:
                    surfs_dict[current_surf][key] = value

        elif stripped_line.startswith("FLOA") or stripped_line.startswith("ENPD"):
            if stripped_line.startswith("FLOA"):
                self.float_enpd = True
                self.enpd = None
            else:
                self.float_enpd = False
                self.enpd = float(stripped_line.split()[1])

        elif stripped_line.startswith("YFLN"):
            # Parse field of view from YFLN line (field coordinates in degrees)
            # YFLN format: YFLN 0.0 <0.707*rfov_deg> <0.99*rfov_deg>
            parts = stripped_line.split()
            if len(parts) > 1:
                field_values = [abs(float(x)) for x in parts[1:] if float(x) != 0.0]
                if field_values:
                    # The largest field value is typically 0.99 * rfov_deg
                    max_field_deg = max(field_values) / 0.99
                    self.rfov_eff = (
                        max_field_deg * math.pi / 180.0
                    )  # Convert to radians

    self.float_foclen = False
    self.float_rfov = False
    if not hasattr(self, "float_enpd"):
        self.float_enpd = True
        self.enpd = None
    # Set default rfov_eff if not parsed from file
    if not hasattr(self, "rfov_eff"):
        self.rfov_eff = None

    # Read the extracted data from each SURF
    self.surfaces = []
    d = 0.0
    mat1_name = "air"
    for surf_idx, surf_dict in surfs_dict.items():
        if surf_idx > 0 and surf_idx < current_surf:
            # Lens surface parameters
            if "GLAS" in surf_dict:
                if surf_dict["GLAS"].split()[0] == "___BLANK":
                    mat2_name = f"{surf_dict['GLAS'].split()[3]}/{surf_dict['GLAS'].split()[4]}"
                else:
                    mat2_name = surf_dict["GLAS"].split()[0].lower()
            else:
                mat2_name = "air"

            surf_r = (
                float(surf_dict["DIAM"].split()[0]) if "DIAM" in surf_dict else 1.0
            )
            surf_c = (
                float(surf_dict["CURV"].split()[0]) if "CURV" in surf_dict else 0.0
            )
            surf_d_next = (
                float(surf_dict["DISZ"].split()[0]) if "DISZ" in surf_dict else 0.0
            )
            surf_conic = float(surf_dict.get("CONI", 0.0))
            surf_param2 = float(surf_dict.get("PARM2", 0.0))
            surf_param3 = float(surf_dict.get("PARM3", 0.0))
            surf_param4 = float(surf_dict.get("PARM4", 0.0))
            surf_param5 = float(surf_dict.get("PARM5", 0.0))
            surf_param6 = float(surf_dict.get("PARM6", 0.0))
            surf_param7 = float(surf_dict.get("PARM7", 0.0))
            surf_param8 = float(surf_dict.get("PARM8", 0.0))

            # Create surface object
            if surf_dict["TYPE"] == "STANDARD":
                if mat2_name == "air" and mat1_name == "air":
                    # Aperture
                    s = Aperture(r=surf_r, d=d)
                else:
                    # Spherical surface
                    s = Spheric(c=surf_c, r=surf_r, d=d, mat2=mat2_name)

            elif surf_dict["TYPE"] == "EVENASPH":
                # Aspherical surface
                s = Aspheric(
                    c=surf_c,
                    r=surf_r,
                    d=d,
                    ai=[
                        surf_param2,
                        surf_param3,
                        surf_param4,
                        surf_param5,
                        surf_param6,
                        surf_param7,
                        surf_param8,
                    ],
                    k=surf_conic,
                    mat2=mat2_name,
                )

            elif surf_dict["TYPE"] == "BINARY_2":
                # Binary 2 (metalens / DOE) phase surface
                norm_radii = float(surf_dict.get("XDAT2", surf_r))
                s = Binary2Phase(
                    r=surf_r,
                    d=d,
                    order2=float(surf_dict.get("XDAT3", 0.0)),
                    order4=float(surf_dict.get("XDAT4", 0.0)),
                    order6=float(surf_dict.get("XDAT5", 0.0)),
                    order8=float(surf_dict.get("XDAT6", 0.0)),
                    order10=float(surf_dict.get("XDAT7", 0.0)),
                    order12=float(surf_dict.get("XDAT8", 0.0)),
                    norm_radii=norm_radii,
                    mat2=mat2_name,
                )

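            else:
                print(f"Surface type {surf_dict['TYPE']} not implemented; skipping.")
                # Still advance the axial position so later surfaces keep their place.
                d += surf_d_next
                mat1_name = mat2_name
                continue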

            self.surfaces.append(s)
            d += surf_d_next
            mat1_name = mat2_name

        elif surf_idx == current_surf:
            # Image sensor
            self.r_sensor = float(surf_dict["DIAM"].split()[0])

        else:
            pass

    self.d_sensor = torch.tensor(d)
    return self

write_lens_zmx

write_lens_zmx(filename='./test.zmx')

Write the lens to a Zemax .zmx sequential lens file.

Exports surfaces (STANDARD or EVENASPH), materials, field definitions, and entrance pupil settings in Zemax OpticStudio format.
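GeoLens stores absolute surface positions, while a .zmx file stores per-surface thicknesses (DISZ). The sketch below shows that conversion, mirroring the d_next logic in the source, and the YFLN line format written in the header. The helper names are hypothetical, and yfln_line converts the field angle with math.degrees.

```python
import math
from typing import List


def thicknesses(z: List[float], d_sensor: float) -> List[float]:
    """DISZ per surface: distance to the next surface, or to the sensor for the last one."""
    return [
        (z[i + 1] if i < len(z) - 1 else d_sensor) - z[i]
        for i in range(len(z))
    ]


def yfln_line(rfov_rad: float) -> str:
    """Field line with the 0, 0.707 and 0.99 fractions of the half field of view."""
    deg = math.degrees(rfov_rad)
    return f"YFLN 0.0 {0.707 * deg} {0.99 * deg}"


print(thicknesses([0.0, 2.0, 5.0], 10.0))
print(yfln_line(math.radians(20.0)))
```

Reading the YFLN line back with the reader's rule (largest value divided by 0.99) recovers the original half field of view.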

Parameters:

Name Type Description Default
filename str

Output file path. Defaults to './test.zmx'.

'./test.zmx'
Source code in src/geolens_pkg/io.py
def write_lens_zmx(self, filename="./test.zmx"):
    """Write the lens to a Zemax .zmx sequential lens file.

    Exports surfaces (STANDARD or EVENASPH), materials, field definitions,
    and entrance pupil settings in Zemax OpticStudio format.

    Args:
        filename (str, optional): Output file path. Defaults to './test.zmx'.
    """
    lens_zmx_str = ""
    if self.float_enpd:
        enpd_str = "FLOA"
    else:
        enpd_str = f"ENPD {self.enpd}"
    # Head string
    head_str = f"""VERS 190513 80 123457 L123457
MODE SEQ
NAME 
PFIL 0 0 0
LANG 0
UNIT MM X W X CM MR CPMM
{enpd_str}
ENVD 2.0E+1 1 0
GFAC 0 0
GCAT OSAKAGASCHEMICAL MISC
XFLN 0. 0. 0.
YFLN 0.0 {0.707 * math.degrees(self.rfov_eff)} {0.99 * math.degrees(self.rfov_eff)}
WAVL 0.4861327 0.5875618 0.6562725
RAIM 0 0 1 1 0 0 0 0 0
PUSH 0 0 0 0 0 0
SDMA 0 1 0
FTYP 0 0 3 3 0 0 0
ROPD 2
PICB 1
PWAV 2
POLS 1 0 1 0 0 1 0
GLRS 1 0
GSTD 0 100.000 100.000 100.000 100.000 100.000 100.000 0 1 1 0 0 1 1 1 1 1 1
NSCD 100 500 0 1.0E-3 5 1.0E-6 0 0 0 0 0 0 1000000 0 2
COFN QF "COATING.DAT" "SCATTER_PROFILE.DAT" "ABG_DATA.DAT" "PROFILE.GRD"
COFN COATING.DAT SCATTER_PROFILE.DAT ABG_DATA.DAT PROFILE.GRD
SURF 0
TYPE STANDARD
CURV 0.0
DISZ INFINITY
"""
    lens_zmx_str += head_str

    # Surface string
    for i, s in enumerate(self.surfaces):
        d_next = (
            self.surfaces[i + 1].d - self.surfaces[i].d
            if i < len(self.surfaces) - 1
            else self.d_sensor - self.surfaces[i].d
        )
        surf_str = s.zmx_str(surf_idx=i + 1, d_next=d_next)
        lens_zmx_str += surf_str

    # Sensor string
    sensor_str = f"""SURF {len(self.surfaces) + 1}
TYPE STANDARD
CURV 0.
DISZ 0.0
DIAM {self.r_sensor}
"""
    lens_zmx_str += sensor_str

    # Write lens zmx string into file
    with open(filename, "w") as f:
        f.write(lens_zmx_str)
    print(f"Lens written to {filename}")

read_lens_seq

read_lens_seq(filename='./test.seq')

Load the lens from a CODE V .seq sequential file.

Parses standard and aspheric surfaces (with conic and polynomial coefficients A–I), entrance pupil diameter (EPD), field angles (YAN), aperture stop (STO), and image surface (SI).
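The header handling of read_lens_seq (EPD and YAN) can be reproduced in a few lines. This standalone sketch is hypothetical and skips surface parsing entirely; the default semi-diameter of EPD / 2 and the use of the largest absolute YAN angle follow the source.

```python
import math
from typing import Dict, Optional

SAMPLE_SEQ = """RDM
TITLE 'Demo doublet'
EPD 10.0
DIM M
WL 656.3 587.6 486.1
YAN 0.0 -10.0 14.0
SO 0.0 1e10
"""


def parse_seq_header(text: str) -> Dict[str, Optional[float]]:
    """Extract entrance pupil and field data from CODE V .seq text."""
    out: Dict[str, Optional[float]] = {
        "enpd": None, "default_semi_diam": None, "hfov_deg": None, "rfov_rad": None,
    }
    for raw in text.splitlines():
        line = raw.strip()
        if line.startswith("EPD"):
            out["enpd"] = float(line.split()[1])
            out["default_semi_diam"] = out["enpd"] / 2.0
        elif line.startswith("YAN"):
            angles = [abs(float(x)) for x in line.split()[1:] if float(x) != 0.0]
            if angles:
                out["hfov_deg"] = max(angles)
                out["rfov_rad"] = math.radians(max(angles))
    return out


print(parse_seq_header(SAMPLE_SEQ))
```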

Parameters:

Name Type Description Default
filename str

Path to the .seq file. Supports both UTF-8 and Latin-1 encoded files. Defaults to './test.seq'.

'./test.seq'

Returns:

Name Type Description
GeoLens

self, for method chaining.

Source code in src/geolens_pkg/io.py
def read_lens_seq(self, filename="./test.seq"):
    """Load the lens from a CODE V .seq sequential file.

    Parses standard and aspheric surfaces (with conic and polynomial
    coefficients A–I), entrance pupil diameter (EPD), field angles (YAN),
    aperture stop (STO), and image surface (SI).

    Args:
        filename (str, optional): Path to the .seq file. Supports both
            UTF-8 and Latin-1 encoded files. Defaults to './test.seq'.

    Returns:
        GeoLens: ``self``, for method chaining.
    """
    print(f"\n{'=' * 60}")
    print(f"Start reading CODE V file: {filename}")
    print(f"{'=' * 60}\n")

    # Read .seq file
    try:
        with open(filename, "r", encoding="utf-8") as file:
            lines = file.readlines()
        print("File read successfully (UTF-8)")
    except UnicodeDecodeError:
        try:
            with open(filename, "r", encoding="latin-1") as file:
                lines = file.readlines()
            print("File read successfully (Latin-1)")
        except Exception as e:
            print(f"Failed to read file: {e}")
            return self
    print(f"Total lines: {len(lines)}\n")

    # ============ Step 1: Parse file structure ============
    surfaces = []
    current_surface = {}
    surface_index = 0
    global_diameter = None

    print("Beginning to parse surface data...\n")

    for line_num, line in enumerate(lines, 1):
        line = line.strip()

        # Skip irrelevant lines
        if not line or line.startswith(
            (
                "RDM",
                "TITLE",
                "UID",
                "GO",
                "WL",
                "XAN",
                "REF",
                "WTW",
                "INI",
                "WTF",
                "VUY",
                "VLY",
                "DOR",
                "DIM",
                "THC",
            )
        ):
            continue
        # Read entrance pupil diameter
        if line.startswith("EPD"):
            self.enpd = float(line.split()[1])
            self.float_enpd = False
            global_diameter = self.enpd / 2.0
            print(
                f"[Line {line_num}] EPD={self.enpd} -> default radius={global_diameter}"
            )
            continue
        # Read field of view angle
        if line.startswith("YAN"):
            angles = [abs(float(x)) for x in line.split()[1:] if float(x) != 0.0]
            if angles:
                self.hfov = max(angles)
                # Also set rfov_eff in radians for consistency with write functions
                self.rfov_eff = self.hfov * math.pi / 180.0
                print(f"[Line {line_num}] Max field of view={self.hfov} deg")
            continue
        # Object surface
        if line.startswith("SO"):
            parts = line.split()
            thickness = float(parts[2]) if len(parts) > 2 else 1e10

            current_surface = {
                "type": "OBJECT",
                "thickness": thickness,
                "index": surface_index,
            }
            surfaces.append(current_surface)
            print(f"[Line {line_num}] Object surface: T={thickness}")
            surface_index += 1
            current_surface = {}
            continue
        # Standard surface
        if line.startswith("S "):
            # Save the previous surface
            if current_surface:
                surfaces.append(current_surface)
                surface_index += 1

            parts = line.split()
            radius_value = float(parts[1]) if len(parts) > 1 else 0.0
            thickness = float(parts[2]) if len(parts) > 2 else 0.0
            material = parts[3].upper() if len(parts) > 3 else "AIR"

            # Key: compute curvature C = 1/R
            if abs(radius_value) > 1e-10:
                curvature = 1.0 / radius_value
            else:
                curvature = 0.0

            current_surface = {
                "type": "STANDARD",
                "radius": radius_value,
                "curvature": curvature,
                "thickness": thickness,
                "material": material,
                "index": surface_index,
                "diameter": global_diameter,
                "conic": 0.0,
                "asph_coeffs": {},
                "is_stop": False,
            }

            print(
                f"[Line {line_num}] Surface{surface_index}: R={radius_value:.4f} → C={curvature:.6f}, T={thickness}, Mat={material}"
            )
            continue
        # Image surface - do not append yet, wait for CIR
        if line.startswith("SI"):
            if current_surface:
                surfaces.append(current_surface)
                surface_index += 1

            parts = line.split()
            thickness = float(parts[1]) if len(parts) > 1 else 0.0

            current_surface = {
                "type": "IMAGE",
                "thickness": thickness,
                "diameter": None,  # Set to None first, wait for CIR line to update
                "index": surface_index,
            }
            print(f"[Line {line_num}] Image surface")
            continue
        # Handle surface attributes (CIR, STO, ASP, K, A~J, etc.)
        if current_surface:
            if line.startswith("CIR"):
                current_surface["diameter"] = float(
                    line.split()[1].replace(";", "")
                )
                print(f"[Line {line_num}]   → CIR={current_surface['diameter']}")

            elif line.startswith("STO"):
                current_surface["is_stop"] = True
                print(f"[Line {line_num}]   → Aperture stop flag")

            elif line.startswith("ASP"):
                current_surface["type"] = "ASPHERIC"
                print(f"[Line {line_num}]   → Aspheric surface")

            elif line.startswith("K "):
                current_surface["conic"] = float(line.split()[1].replace(";", ""))
                print(f"[Line {line_num}]   → K={current_surface['conic']}")

            # Only extract single-letter coefficients A-J
            elif any(
                line.startswith(p)
                for p in [
                    "A ",
                    "B ",
                    "C ",
                    "D ",
                    "E ",
                    "F ",
                    "G ",
                    "H ",
                    "I ",
                    "J ",
                ]
            ):
                parts = line.replace(";", "").split()
                i = 0
                while i < len(parts) - 1:
                    try:
                        key = parts[i]
                        # Only accept single letters within the range A-J
                        if len(key) == 1 and key in [
                            "A",
                            "B",
                            "C",
                            "D",
                            "E",
                            "F",
                            "G",
                            "H",
                            "I",
                            "J",
                        ]:
                            value = float(parts[i + 1])
                            current_surface["asph_coeffs"][key] = value
                            print(f"[Line {line_num}]   → {key}={value}")
                        i += 2
                    except (ValueError, KeyError):  # non-numeric token or surface without asph_coeffs
                        i += 1

    # Save the last surface
    if current_surface:
        surfaces.append(current_surface)

    print(f"\nParsing complete, total {len(surfaces)} surfaces\n")

    # ============ Step 2: Create surface objects ============
    print(f"{'=' * 60}")
    print("Start creating surface objects:")
    print(f"{'=' * 60}\n")

    self.surfaces = []
    d = 0.0  # Cumulative distance from the first optical surface to the current surface
    previous_material = "air"

    for surf in surfaces:
        surf_idx = surf["index"]
        surf_type = surf["type"]

        print(f"{'=' * 50}")
        print(f"Processing surface{surf_idx} ({surf_type}), current d={d:.4f}")

        # Handle object surface
        if surf_type == "OBJECT":
            obj_thickness = surf["thickness"]
            if obj_thickness < 1e9:  # Finite object distance
                d += obj_thickness
                print(
                    f"   Object surface thickness={obj_thickness} → accumulated d={d:.4f}"
                )
            else:
                print("   Object surface at infinity")
            previous_material = "air"
            continue

        # Handle image surface
        if surf_type == "IMAGE":
            self.d_sensor = torch.tensor(d)
            # Read diameter from surf dictionary (CIR value)
            self.r_sensor = (
                surf.get("diameter") if surf.get("diameter") is not None else 18.0
            )
            print(
                f"   Image plane position: d_sensor={d:.4f}, r_sensor={self.r_sensor:.4f}"
            )
            break

        # Get surface parameters
        current_material = surf.get("material", "AIR")
        if current_material in ["AIR", "0.0", "", None]:
            current_material = "air"
        else:
            current_material = current_material.lower()

        c = surf.get("curvature", 0.0)
        r = surf.get("diameter") if surf.get("diameter") is not None else 10.0
        d_next = surf.get("thickness", 0.0)
        is_stop = surf.get("is_stop", False)

        print(f"   C={c:.6f}, R_aperture={r:.4f}, T={d_next:.4f}")
        print(f"   Material: {previous_material} → {current_material}")
        print(f"   is_stop={is_stop}")

        # Create surface object
        try:
            # Case 1: pure aperture (air on both sides + STO flag)
            if is_stop and current_material == "air" and previous_material == "air":
                aperture = Aperture(r=r, d=d)
                self.surfaces.append(aperture)
                print(f"   Created pure aperture: Aperture(r={r:.4f}, d={d:.4f})")

            # Case 2: refractive surface (material change)
            elif current_material != previous_material:
                if surf_type == "STANDARD":
                    s = Spheric(c=c, r=r, d=d, mat2=current_material)
                    self.surfaces.append(s)
                    status = " (stop surface)" if is_stop else ""
                    print(
                        f"   Created spherical surface{status}: Spheric(c={c:.6f}, r={r:.4f}, d={d:.4f}, mat2='{current_material}')"
                    )

                elif surf_type == "ASPHERIC":
                    k = surf.get("conic", 0.0)
                    asph_coeffs = surf.get("asph_coeffs", {})

                    # CODE V aspheric coefficient mapping (shift forward by one position):
                    # A → ai[1] (2nd term, ρ²)
                    # B → ai[2] (4th term, ρ⁴)
                    # C → ai[3] (6th term, ρ⁶)
                    # D → ai[4] (8th term, ρ⁸)
                    # E → ai[5] (10th term, ρ¹⁰)
                    # F → ai[6] (12th term, ρ¹²)
                    # G → ai[7] (14th term, ρ¹⁴)
                    # H → ai[8] (16th term, ρ¹⁶)
                    # I → ai[9] (18th term, ρ¹⁸)

                    # Initialize ai array (10 elements)
                    ai = [0.0] * 10
                    ai[0] = 0.0  # ρ⁰ term (unused)
                    ai[1] = asph_coeffs.get("A", 0.0)  # ρ²
                    ai[2] = asph_coeffs.get("B", 0.0)  # ρ⁴
                    ai[3] = asph_coeffs.get("C", 0.0)  # ρ⁶
                    ai[4] = asph_coeffs.get("D", 0.0)  # ρ⁸
                    ai[5] = asph_coeffs.get("E", 0.0)  # ρ¹⁰
                    ai[6] = asph_coeffs.get("F", 0.0)  # ρ¹²
                    ai[7] = asph_coeffs.get("G", 0.0)  # ρ¹⁴
                    ai[8] = asph_coeffs.get("H", 0.0)  # ρ¹⁶
                    ai[9] = asph_coeffs.get("I", 0.0)  # ρ¹⁸

                    s = Aspheric(c=c, r=r, d=d, ai=ai, k=k, mat2=current_material)
                    self.surfaces.append(s)
                    status = " (stop surface)" if is_stop else ""
                    print(
                        f"   Created aspheric surface{status}: Aspheric(c={c:.6f}, r={r:.4f}, d={d:.4f}, k={k}, mat2='{current_material}')"
                    )
                    if any(
                        ai[1:]
                    ):  # If there are non-zero higher-order terms (starting from ai[1])
                        print(
                            f"      Aspheric coefficients: A={ai[1]:.2e}, B={ai[2]:.2e}, C={ai[3]:.2e}, D={ai[4]:.2e}"
                        )

            else:
                print(f"   Skipped (same material on both sides and no stop flag)")

        except Exception as e:
            print(f"   Failed to create surface: {e}")
            import traceback

            traceback.print_exc()

        # Key: accumulate distance at the end of the loop
        d += d_next
        print(f"   After accumulation: d={d:.4f}")
        previous_material = current_material

    print(f"\n{'=' * 60}")
    print(f"   Done! Created {len(self.surfaces)} objects")
    print(f"   d_sensor={self.d_sensor:.4f}")
    print(f"   r_sensor={self.r_sensor:.4f}")
    print(f"   hfov={self.hfov:.4f}°")
    print(f"{'=' * 60}\n")

    return self
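read_lens_seq works in two passes. It first collects per-surface dictionaries, then turns each CODE V thickness into an absolute vertex position d by accumulation, so the image surface ends up at d_sensor. The sketch below shows that accumulation. It assumes an object at infinity (the method adds the object thickness to d only when it is finite). Unlike the method, it keeps air-to-air surfaces instead of dropping them or turning them into Aperture objects. seq_surface_positions is a hypothetical name.

```python
def seq_surface_positions(lines):
    """Hypothetical helper: turn 'S R T [MAT]' lines into vertex positions.

    Returns (surfaces, d_sensor), where each surface dict holds curvature c,
    vertex position d, and the material after the surface.
    """
    surfaces = []
    d = 0.0  # vertex position of the current surface, measured from the first surface
    for raw in lines:
        parts = raw.strip().split()
        if not parts:
            continue
        if parts[0] == "S":
            radius = float(parts[1]) if len(parts) > 1 else 0.0
            thickness = float(parts[2]) if len(parts) > 2 else 0.0
            material = parts[3].lower() if len(parts) > 3 else "air"
            # CODE V stores the radius; a radius of 0 denotes a flat surface.
            c = 1.0 / radius if abs(radius) > 1e-10 else 0.0
            surfaces.append({"c": c, "d": d, "mat2": material})
            d += thickness  # the next surface sits one thickness further along z
        elif parts[0] == "SI":
            break  # image surface: d now equals the sensor position
    return surfaces, d


surfs, d_sensor = seq_surface_positions(["SO 0.0 1e14", "S 50.0 5.0 N-BK7", "S -50.0 40.0", "SI 0.0 0.0"])
print([s["d"] for s in surfs], d_sensor)  # [0.0, 5.0] 45.0
```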

write_lens_seq

write_lens_seq(filename='./test.seq')

Write the lens to a CODE V .seq sequential file.

Exports surfaces, materials, field definitions, and entrance pupil settings in CODE V format.
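The exported header samples three field angles on the YAN line: on-axis, 0.707 of the half field of view, and 0.99 of it. These are converted from the radian value rfov_eff. The small sketch below performs that conversion with math.degrees. format_yan is a hypothetical helper, and the 4-decimal formatting is an arbitrary choice.

```python
import math


def format_yan(rfov_eff, fractions=(0.0, 0.707, 0.99)):
    """Hypothetical helper: build a CODE V YAN line from a half field of view in radians."""
    angles_deg = [f * math.degrees(rfov_eff) for f in fractions]
    return "YAN   " + " ".join(f"{a:.4f}" for a in angles_deg)


print(format_yan(math.radians(20.0)))  # YAN   0.0000 14.1400 19.8000
```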

Parameters:

filename (str): Output file path. Defaults to './test.seq'.

Returns:

GeoLens: self, for method chaining.

Source code in src/geolens_pkg/io.py
def write_lens_seq(self, filename="./test.seq"):
    """Write the lens to a CODE V .seq sequential file.

    Exports surfaces, materials, field definitions, and entrance pupil
    settings in CODE V format.

    Args:
        filename (str, optional): Output file path. Defaults to './test.seq'.

    Returns:
        GeoLens: ``self``, for method chaining.
    """

    import datetime

    current_date = datetime.datetime.now().strftime("%d-%b-%Y")

    head_str = f"""RDM;LEN       "VERSION: 2023.03       LENS VERSION: 89       Creation Date:  {current_date}"
TITLE 'Lens Design'
EPD   {self.enpd}
DIM   M
WL    650.0 550.0 480.0
REF   2
WTW   1 2 1
INI   '   '
XAN   0.0 0.0 0.0
YAN   0.0  {0.707 * math.degrees(self.rfov_eff)} {0.99 * math.degrees(self.rfov_eff)}
WTF   1.0 1.0 1.0
VUY   0.0 0.0 0.0
VLY   0.0 0.0 0.0
DOR   1.15 1.05
SO    0.0 0.1e14
"""

    lens_seq_str = head_str
    previous_material = "air"

    for i, surf in enumerate(self.surfaces):
        # Convert to plain floats so tensors are not written as "tensor(...)"
        if i < len(self.surfaces) - 1:
            d_next = float(self.surfaces[i + 1].d - surf.d)
        else:
            d_next = float(self.d_sensor - surf.d)

        current_material = getattr(surf, "mat2", "air")

        if current_material is None or current_material == "air":
            material_str = ""
            material_name = "air"
        elif isinstance(current_material, str):
            material_str = f" {current_material.upper()}"
            material_name = current_material
        else:
            material_name = getattr(current_material, "name", str(current_material))
            material_str = f" {material_name.upper()}"

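        is_aperture = surf.__class__.__name__ == "Aperture"

        if is_aperture:
            # Write the aperture as a flat air surface flagged as the stop, so the
            # spacing from the aperture to the next surface is kept in the file.
            lens_seq_str += f"S     0.0 {d_next}\n"
            lens_seq_str += "  STO\n"
            lens_seq_str += f"  CIR {surf.r}\n"
            previous_material = "air"
            continue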

        is_aspheric = surf.__class__.__name__ == "Aspheric"
        is_stop_surface = getattr(surf, "is_stop", False)

        if is_aspheric:
            c_val = float(surf.c.detach().item()) if torch.is_tensor(surf.c) else float(surf.c)
            if abs(c_val) > 1e-10:
                radius = 1.0 / c_val
            else:
                radius = 0.0

            k = float(surf.k.detach().item()) if hasattr(surf, "k") and torch.is_tensor(surf.k) else float(getattr(surf, "k", 0.0))
            ai_tensor = surf.ai if hasattr(surf, "ai") else None
            if ai_tensor is None:
                ai = [0.0] * 10
            else:
                ai = [float(a.detach().item()) if torch.is_tensor(a) else float(a) for a in ai_tensor]

            surf_str = f"S     {radius} {d_next}{material_str}\n"
            surf_str += f"  CCY 0; THC 0\n"
            surf_str += f"  CIR {surf.r}\n"
            if is_stop_surface:
                surf_str += f"  STO\n"
            surf_str += f"  ASP\n"
            surf_str += f"  K   {k}\n"

            if len(ai) > 4 and any(ai[1:5]):
                # Keep A-D on one line so read_lens_seq can parse every coefficient
                surf_str += f"  A   {ai[1]:.16e}; B {ai[2]:.16e}; C {ai[3]:.16e}; D {ai[4]:.16e}\n"

            if len(ai) > 8 and any(ai[5:9]):
                surf_str += f"  E   {ai[5]:.16e}; F {ai[6]:.16e}; G {ai[7]:.16e}; H {ai[8]:.16e}\n"

        else:
            c_val = float(surf.c.detach().item()) if torch.is_tensor(surf.c) else float(surf.c)
            radius = 1.0 / c_val if abs(c_val) > 1e-10 else 0.0

            surf_str = f"S     {radius} {d_next}{material_str}\n"
            surf_str += f"  CCY 0; THC 0\n"

            if is_stop_surface:
                surf_str += f"  STO\n"

            surf_str += f"  CIR {surf.r}\n"

        lens_seq_str += surf_str
        previous_material = material_name

    sensor_str = f"SI    0.0 0.0\n"
    sensor_str += f"  CIR {self.r_sensor}\n"
    lens_seq_str += sensor_str
    lens_seq_str += "GO \n"

    with open(filename, "w") as f:
        f.write(lens_seq_str)

    print(f"Lens written to CODE V file: {filename}")
    return self
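GeoLens stores absolute vertex positions, but CODE V stores thicknesses. Each surface line therefore needs the gap to the next vertex, or to the sensor for the last surface. The sketch below works on plain dictionaries with hypothetical keys. It writes an aperture as a flat air surface carrying STO instead of dropping it, so the thicknesses still sum to the total track length d_sensor. seq_surface_lines is a hypothetical helper.

```python
def seq_surface_lines(surfaces, d_sensor):
    """Hypothetical helper: emit CODE V 'S' blocks from surface dicts.

    Each dict has vertex position 'd', curvature 'c', semi-aperture 'r',
    material 'mat2', and an optional 'is_aperture' flag.
    """
    out = []
    for i, s in enumerate(surfaces):
        # Thickness is the gap to the next vertex (or to the sensor for the last surface).
        d_end = surfaces[i + 1]["d"] if i + 1 < len(surfaces) else d_sensor
        d_next = d_end - s["d"]
        radius = 1.0 / s["c"] if abs(s["c"]) > 1e-10 else 0.0  # 0 means flat in CODE V
        mat = "" if s["mat2"] == "air" else f" {s['mat2'].upper()}"
        out.append(f"S     {radius} {d_next}{mat}")
        if s.get("is_aperture"):
            out.append("  STO")
        out.append(f"  CIR {s['r']}")
    return out


for line in seq_surface_lines(
    [
        {"d": 0.0, "c": 0.02, "r": 5.0, "mat2": "n-bk7"},
        {"d": 5.0, "c": 0.0, "r": 3.0, "mat2": "air", "is_aperture": True},
        {"d": 6.0, "c": -0.02, "r": 5.0, "mat2": "air"},
    ],
    d_sensor=46.0,
):
    print(line)
```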

read_lens_json

read_lens_json(filename='./test.json')

Read the lens from a JSON file.

Loads lens configuration including surfaces, materials, and optical properties from the DeepLens native JSON format.

Parameters:

filename (str): Path to the JSON lens file. Defaults to './test.json'.

Note

After loading, the lens is moved to self.device and post_computation is called to calculate derived properties.
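The JSON file stores gaps (d_next), not absolute positions. read_lens_json rebuilds each surface's vertex position d and the sensor position with a running sum. The sketch below uses only the fields needed for that bookkeeping. It omits the surface-specific fields consumed by each class's init_from_dict, and surface_positions_from_json is a hypothetical name.

```python
import json


def surface_positions_from_json(text):
    """Hypothetical helper: recover vertex positions from DeepLens-style JSON text.

    Each surface starts where the previous surface's d_next ends, and the
    sensor sits after the last d_next.
    """
    data = json.loads(text)
    positions = []
    d = 0.0
    for surf in data["surfaces"]:
        positions.append(d)
        d += surf["d_next"]
    return positions, d, data["r_sensor"]


example = json.dumps({"r_sensor": 4.0, "surfaces": [
    {"type": "Spheric", "d_next": 5.0},
    {"type": "Aperture", "d_next": 1.0},
    {"type": "Spheric", "d_next": 40.0},
]})
print(surface_positions_from_json(example))  # ([0.0, 5.0, 6.0], 46.0, 4.0)
```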

Source code in src/geolens_pkg/io.py
def read_lens_json(self, filename="./test.json"):
    """Read the lens from a JSON file.

    Loads lens configuration including surfaces, materials, and optical properties
    from the DeepLens native JSON format.

    Args:
        filename (str, optional): Path to the JSON lens file. Defaults to './test.json'.

    Note:
        After loading, the lens is moved to self.device and post_computation is called
        to calculate derived properties.
    """
    self.surfaces = []
    self.materials = []
    with open(filename, "r") as f:
        data = json.load(f)
        d = 0.0
        for idx, surf_dict in enumerate(data["surfaces"]):
            surf_dict["d"] = d
            surf_dict["surf_idx"] = idx

            if surf_dict["type"] == "Aperture":
                s = Aperture.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Aspheric":
                s = Aspheric.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Cubic":
                s = Cubic.init_from_dict(surf_dict)

            # elif surf_dict["type"] == "GaussianRBF":
            #     s = GaussianRBF.init_from_dict(surf_dict)

            # elif surf_dict["type"] == "NURBS":
            #     s = NURBS.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Binary2Phase":
                s = Binary2Phase.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Phase":
                s = Phase.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Plane":
                s = Plane.init_from_dict(surf_dict)

            # elif surf_dict["type"] == "PolyEven":
            #     s = PolyEven.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Stop":
                s = Aperture.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Spheric":
                s = Spheric.init_from_dict(surf_dict)

            elif surf_dict["type"] == "ThinLens":
                s = ThinLens.init_from_dict(surf_dict)

            else:
                raise Exception(
                    f"Surface type {surf_dict['type']} is not implemented in GeoLens.read_lens_json()."
                )

            if surf_dict.get("is_aperture", False):
                self.aper_idx = idx

            self.surfaces.append(s)
            d += surf_dict["d_next"]

    self.d_sensor = torch.tensor(d)
    self.lens_info = data.get("info", "None")
    self.enpd = data.get("enpd", None)
    self.float_enpd = True if self.enpd is None else False
    self.float_foclen = False
    self.float_rfov = False
    self.r_sensor = data["r_sensor"]

    self.to(self.device)

    # Set sensor size and resolution
    sensor_res = data.get("sensor_res", (2000, 2000))
    self.set_sensor_res(sensor_res=sensor_res)
    self.post_computation()

write_lens_json

write_lens_json(filename='./test.json')

Write the lens to a JSON file.

Saves the complete lens configuration including all surfaces, materials, focal length, F-number, and sensor properties to the DeepLens JSON format.

Parameters:

filename (str): Path for the output JSON file. Defaults to './test.json'.

Source code in src/geolens_pkg/io.py
def write_lens_json(self, filename="./test.json"):
    """Write the lens to a JSON file.

    Saves the complete lens configuration including all surfaces, materials,
    focal length, F-number, and sensor properties to the DeepLens JSON format.

    Args:
        filename (str, optional): Path for the output JSON file. Defaults to './test.json'.
    """
    data = {}
    data["info"] = self.lens_info if hasattr(self, "lens_info") else "None"
    data["foclen"] = round(self.foclen, 4)
    data["fnum"] = round(self.fnum, 4)
    if self.float_enpd is False:
        data["enpd"] = round(self.enpd, 4)
    data["r_sensor"] = self.r_sensor
    data["(d_sensor)"] = round(self.d_sensor.item(), 4)
    data["(sensor_size)"] = [round(i, 4) for i in self.sensor_size]
    data["surfaces"] = []
    for i, s in enumerate(self.surfaces):
        surf_dict = {"idx": i}
        surf_dict.update(s.surf_dict())
        if i == self.aper_idx and not isinstance(s, Aperture):
            surf_dict["is_aperture"] = True
        if i < len(self.surfaces) - 1:
            surf_dict["d_next"] = round(
                self.surfaces[i + 1].d.item() - self.surfaces[i].d.item(), 4
            )
        else:
            surf_dict["d_next"] = round(
                self.d_sensor.item() - self.surfaces[i].d.item(), 4
            )

        data["surfaces"].append(surf_dict)

    with open(filename, "w") as f:
        json.dump(data, f, indent=4)
    print(f"Lens written to {filename}")
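write_lens_json performs the inverse conversion and rounds each gap to 4 decimals. The parenthesised keys such as (d_sensor) appear to be informational only, because read_lens_json recomputes d_sensor from the d_next values. The minimal sketch below shows the conversion, with gaps_to_json as a hypothetical helper.

```python
import json


def gaps_to_json(positions, d_sensor, r_sensor):
    """Hypothetical helper: serialise vertex positions as write_lens_json stores them.

    Each surface records d_next, the gap to the next vertex (or to the sensor
    for the last surface), rounded to 4 decimals.
    """
    surfaces = []
    for i, d in enumerate(positions):
        d_end = positions[i + 1] if i + 1 < len(positions) else d_sensor
        surfaces.append({"idx": i, "d_next": round(d_end - d, 4)})
    data = {"r_sensor": r_sensor, "(d_sensor)": round(d_sensor, 4), "surfaces": surfaces}
    return json.dumps(data, indent=4)


print(gaps_to_json([0.0, 5.0, 6.0], 46.0, 4.0))
```

Because each gap is rounded independently, the reconstructed sensor position can drift by up to 5e-5 per surface in the file's length unit.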

read_lens_json(filename)

Load a lens from JSON format.

write_lens_json(filename)

Save a lens to JSON format.

read_lens_zmx(filename)

Load a lens from Zemax .zmx format.

write_lens_zmx(filename)

Save a lens to Zemax .zmx format.

read_lens_seq(filename)

Load a lens from Code V .seq format.

write_lens_seq(filename)

Save a lens to Code V .seq format.


Visualization

src.geolens_pkg.vis.GeoLensVis

Mixin providing 2-D lens layout and ray visualisation for GeoLens.

Generates publication-quality cross-section plots showing lens surfaces and traced ray bundles in either the meridional or sagittal plane.

This class is not instantiated directly; it is mixed into deeplens.optics.geolens.GeoLens.

sample_parallel_2D

sample_parallel_2D(fov=0.0, num_rays=7, wvln=DEFAULT_WAVE, plane='meridional', entrance_pupil=True, depth=0.0)

Sample parallel rays (2D) in object space.

Used for (1) drawing the lens layout and (2) 2D geometric-optics calculations, for example refocusing to infinity.

Parameters:

fov (float): Incident angle in degrees. Defaults to 0.0.
num_rays (int): Number of rays. Defaults to 7.
wvln (float): Ray wavelength. Defaults to DEFAULT_WAVE.
plane (str): Sampling plane, "meridional" (y-z plane) or "sagittal" (x-z plane). Any other value raises ValueError. Defaults to "meridional".
entrance_pupil (bool): If True, sample across the entrance pupil; otherwise sample across the first surface at z = 0. Defaults to True.
depth (float): z position to which the sampled rays are propagated. Defaults to 0.0.

Returns:

ray (Ray): Ray object with shape [num_rays, 3].

Source code in src/geolens_pkg/vis.py
@torch.no_grad()
def sample_parallel_2D(
    self,
    fov=0.0,
    num_rays=7,
    wvln=DEFAULT_WAVE,
    plane="meridional",
    entrance_pupil=True,
    depth=0.0,
):
    """Sample parallel rays (2D) in object space.

    Used for (1) drawing the lens layout and (2) 2D geometric-optics calculations, for example refocusing to infinity.

    Args:
        fov (float, optional): Incident angle in degrees. Defaults to 0.0.
        depth (float, optional): z position to which the sampled rays are propagated. Defaults to 0.0.
        num_rays (int, optional): Number of rays. Defaults to 7.
        wvln (float, optional): Ray wavelength. Defaults to DEFAULT_WAVE.
        plane (str, optional): Sampling plane, "meridional" (y-z plane) or "sagittal" (x-z plane). Defaults to "meridional".
        entrance_pupil (bool, optional): Whether to sample across the entrance pupil. Defaults to True.

    Returns:
        ray (Ray): Ray object with shape [num_rays, 3].
    """
    # Sample points on the pupil
    if entrance_pupil:
        pupilz, pupilr = self.calc_entrance_pupil()
    else:
        pupilz, pupilr = 0, self.surfaces[0].r

    # Sample ray origins, shape [num_rays, 3]
    if plane == "sagittal":
        ray_o = torch.stack(
            (
                torch.linspace(-pupilr, pupilr, num_rays) * 0.99,
                torch.full((num_rays,), 0),
                torch.full((num_rays,), pupilz),
            ),
            axis=-1,
        )
    elif plane == "meridional":
        ray_o = torch.stack(
            (
                torch.full((num_rays,), 0),
                torch.linspace(-pupilr, pupilr, num_rays) * 0.99,
                torch.full((num_rays,), pupilz),
            ),
            axis=-1,
        )
    else:
        raise ValueError(f"Invalid plane: {plane}")

    # Sample ray directions, shape [num_rays, 3]
    if plane == "sagittal":
        ray_d = torch.stack(
            (
                torch.full((num_rays,), float(np.sin(np.deg2rad(fov)))),
                torch.zeros((num_rays,)),
                torch.full((num_rays,), float(np.cos(np.deg2rad(fov)))),
            ),
            axis=-1,
        )
    elif plane == "meridional":
        ray_d = torch.stack(
            (
                torch.zeros((num_rays,)),
                torch.full((num_rays,), float(np.sin(np.deg2rad(fov)))),
                torch.full((num_rays,), float(np.cos(np.deg2rad(fov)))),
            ),
            axis=-1,
        )
    else:
        raise ValueError(f"Invalid plane: {plane}")

    # Form rays and propagate to the target depth
    rays = Ray(ray_o, ray_d, wvln, device=self.device)
    rays.prop_to(depth)
    return rays
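sample_parallel_2D builds a collimated fan. The origins spread over 99% of the pupil diameter along y (meridional) or x (sagittal), and every ray shares the direction (0, sin θ, cos θ) or (sin θ, 0, cos θ) for field angle θ. The dependency-free sketch below builds the same fan, with tuples in place of torch tensors and without the final propagation to depth. parallel_fan is a hypothetical name.

```python
import math


def parallel_fan(fov_deg, pupil_z, pupil_r, num_rays=7, plane="meridional"):
    """Hypothetical helper: origins and directions of a 2D collimated ray fan.

    Origins span 99% of the pupil diameter along y (meridional) or x (sagittal);
    every ray shares the direction set by the field angle.
    """
    sin_a = math.sin(math.radians(fov_deg))
    cos_a = math.cos(math.radians(fov_deg))
    step = 2 * pupil_r / (num_rays - 1) if num_rays > 1 else 0.0
    rays = []
    for k in range(num_rays):
        h = (-pupil_r + k * step) * 0.99  # same 0.99 margin as sample_parallel_2D
        if plane == "meridional":
            rays.append(((0.0, h, pupil_z), (0.0, sin_a, cos_a)))
        elif plane == "sagittal":
            rays.append(((h, 0.0, pupil_z), (sin_a, 0.0, cos_a)))
        else:
            raise ValueError(f"Invalid plane: {plane}")
    return rays


for origin, direction in parallel_fan(10.0, 0.0, 5.0, num_rays=3):
    print(origin, direction)
```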

sample_point_source_2D

sample_point_source_2D(fov=0.0, depth=DEPTH, num_rays=7, wvln=DEFAULT_WAVE, entrance_pupil=True)

Sample point source rays (2D) in object space.

Used for drawing the lens layout.

Parameters:

fov (float): Incident angle in degrees. Defaults to 0.0.
depth (float): Object-space z position of the point source. Defaults to DEPTH.
num_rays (int): Number of rays. Defaults to 7.
wvln (float): Ray wavelength. Defaults to DEFAULT_WAVE.
entrance_pupil (bool): If True, aim the rays at points across the entrance pupil; otherwise aim them across the first surface at z = 0. Defaults to True.

Returns:

ray (Ray): Ray object with shape [num_rays, 3].

Source code in src/geolens_pkg/vis.py
@torch.no_grad()
def sample_point_source_2D(
    self,
    fov=0.0,
    depth=DEPTH,
    num_rays=7,
    wvln=DEFAULT_WAVE,
    entrance_pupil=True,
):
    """Sample point source rays (2D) in object space.

    Used for drawing the lens layout.

    Args:
        fov (float, optional): Incident angle in degrees. Defaults to 0.0.
        depth (float, optional): Object-space z position of the point source. Defaults to DEPTH.
        num_rays (int, optional): Number of rays. Defaults to 7.
        wvln (float, optional): Ray wavelength. Defaults to DEFAULT_WAVE.
        entrance_pupil (bool, optional): Whether to aim the rays at the entrance pupil. Defaults to True.

    Returns:
        ray (Ray): Ray object with shape [num_rays, 3].
    """
    # Sample point on the object plane
    ray_o = torch.tensor([depth * float(np.tan(np.deg2rad(fov))), 0.0, depth])
    ray_o = ray_o.unsqueeze(0).repeat(num_rays, 1)

    # Sample points (second point) on the pupil
    if entrance_pupil:
        pupilz, pupilr = self.calc_entrance_pupil()
    else:
        pupilz, pupilr = 0, self.surfaces[0].r

    x2 = torch.linspace(-pupilr, pupilr, num_rays) * 0.99
    y2 = torch.zeros_like(x2)
    z2 = torch.full_like(x2, pupilz)
    ray_o2 = torch.stack((x2, y2, z2), axis=1)

    # Form the rays
    ray_d = ray_o2 - ray_o
    ray = Ray(ray_o, ray_d, wvln, device=self.device)

    # Propagate rays to the sampling depth
    ray.prop_to(depth)
    return ray
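For a finite object, sample_point_source_2D places a single source at x = depth * tan(fov) on the plane z = depth and aims one ray at each pupil sample. The sketch below reproduces that geometry in plain Python and normalises the directions explicitly. This excerpt does not show whether Ray normalises its direction argument. point_source_fan is a hypothetical helper.

```python
import math


def point_source_fan(fov_deg, depth, pupil_z, pupil_r, num_rays=7):
    """Hypothetical helper: unit directions from an off-axis point to pupil samples.

    The source sits at x = depth * tan(fov) on the plane z = depth; the targets
    span 99% of the pupil along x, as in sample_point_source_2D.
    """
    src = (depth * math.tan(math.radians(fov_deg)), 0.0, depth)
    step = 2 * pupil_r / (num_rays - 1) if num_rays > 1 else 0.0
    dirs = []
    for k in range(num_rays):
        tgt = ((-pupil_r + k * step) * 0.99, 0.0, pupil_z)
        v = [t - s for t, s in zip(tgt, src)]
        norm = math.sqrt(sum(c * c for c in v))
        dirs.append(tuple(c / norm for c in v))  # normalised here for clarity
    return src, dirs


src, dirs = point_source_fan(5.0, -100.0, 0.0, 5.0, num_rays=3)
print(src, dirs[1])
```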

draw_layout

draw_layout(filename, depth=float('inf'), zmx_format=True, multi_plot=False, lens_title=None, show=False)

Plot 2D lens layout with ray tracing.

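Parameters:

filename (str): Output image path. The figure is saved as a 300-dpi PNG unless show is True. Required.
depth (float): Object depth used for ray sampling. With float('inf'), parallel ray bundles are traced; otherwise rays fan out from a point source at that depth. Defaults to float('inf').
zmx_format (bool): If True, draw stepped edge connections matching Zemax layout style (passed to draw_lens_2d). Defaults to True.
multi_plot (bool): If True, draw one subplot per wavelength in WAVE_RGB; otherwise overlay all fields in a single plot. Defaults to False.
lens_title (str): Title for the plot. If None, a title is built from the focal length, F-number, field of view, and sensor diagonal. Defaults to None.
show (bool): If True, display the figure instead of saving it. Defaults to False.
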
Source code in src/geolens_pkg/vis.py
def draw_layout(
    self,
    filename,
    depth=float("inf"),
    zmx_format=True,
    multi_plot=False,
    lens_title=None,
    show=False,
):
    """Plot 2D lens layout with ray tracing.

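    Args:
        filename (str): Output image path (saved as PNG when show is False).
        depth (float, optional): Object depth for ray sampling; float("inf") traces parallel ray bundles.
        zmx_format (bool, optional): Whether to draw lens edges in Zemax layout style.
        multi_plot (bool, optional): Whether to draw one subplot per wavelength.
        lens_title (str, optional): Plot title; generated from the lens parameters if None.
        show (bool, optional): Whether to show the plot instead of saving it.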
    """
    num_rays = 11
    num_views = 3

    # Lens title
    if lens_title is None:
        eff_foclen = round(float(self.foclen), 2)
        eq_foclen = round(float(self.eqfl), 2)
        fov_deg = round(2 * self.rfov * 180 / torch.pi, 1)
        sensor_r = round(self.r_sensor, 1)
        sensor_w, sensor_h = self.sensor_size
        sensor_w = round(sensor_w, 1)
        sensor_h = round(sensor_h, 1)

        if self.aper_idx is not None:
            _, pupil_r = self.calc_entrance_pupil()
            fnum = round(eff_foclen / pupil_r / 2, 2)
            lens_title = f"FocLen{eff_foclen}mm - F/{fnum} - FoV{fov_deg}(Equivalent {eq_foclen}mm) - Sensor Diagonal {2 * sensor_r}mm"
        else:
            lens_title = f"FocLen{eff_foclen}mm - FoV{fov_deg}(Equivalent {eq_foclen}mm) - Sensor Diagonal {2 * sensor_r}mm"

    # Draw lens layout
    colors_list = ["#CC0000", "#006600", "#0066CC"]
    rfov_deg = float(np.rad2deg(self.rfov))
    fov_ls = np.linspace(0, rfov_deg * 0.99, num=num_views)

    if not multi_plot:
        ax, fig = self.draw_lens_2d(zmx_format=zmx_format)
        fig.suptitle(lens_title, fontsize=10)
        for i, fov in enumerate(fov_ls):
            # Sample rays, shape (num_rays, 3)
            if depth == float("inf"):
                ray = self.sample_parallel_2D(
                    fov=fov,
                    wvln=WAVE_RGB[2 - i],
                    num_rays=num_rays,
                    depth=-1.0,
                    plane="sagittal",
                )
            else:
                ray = self.sample_point_source_2D(
                    fov=fov,
                    depth=depth,
                    num_rays=num_rays,
                    wvln=WAVE_RGB[2 - i],
                )
                ray.prop_to(-1.0)

            # Trace rays to sensor and plot ray paths
            _, ray_o_record = self.trace2sensor(ray=ray, record=True)
            ax, fig = self.draw_ray_2d(
                ray_o_record, ax=ax, fig=fig, color=colors_list[i]
            )

        ax.axis("off")

    else:
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        fig.suptitle(lens_title, fontsize=10)
        for i, wvln in enumerate(WAVE_RGB):
            ax = axs[i]
            ax, fig = self.draw_lens_2d(ax=ax, fig=fig, zmx_format=zmx_format)
            for fov in fov_ls:
                # Sample rays, shape (num_rays, 3)
                if depth == float("inf"):
                    ray = self.sample_parallel_2D(
                        fov=fov,
                        num_rays=num_rays,
                        wvln=wvln,
                        plane="sagittal",
                    )
                else:
                    ray = self.sample_point_source_2D(
                        fov=fov,
                        depth=depth,
                        num_rays=num_rays,
                        wvln=wvln,
                    )

                # Trace rays to sensor and plot ray paths
                ray_out, ray_o_record = self.trace2sensor(ray=ray, record=True)
                ax, fig = self.draw_ray_2d(
                    ray_o_record, ax=ax, fig=fig, color=colors_list[i]
                )
                ax.axis("off")

    if show:
        fig.show()
    else:
        fig.savefig(filename, format="png", dpi=300)
        plt.close(fig)

draw_lens_2d

draw_lens_2d(ax=None, fig=None, color='k', linestyle='-', zmx_format=False, fix_bound=False)

Draw lens cross-section layout in a 2D plot.

Renders each surface profile, connects lens elements with edge lines, and draws the sensor plane.

Parameters:

  • ax (Axes, optional): Existing axes to draw on. If None, a new figure is created. Defaults to None.
  • fig (Figure, optional): Existing figure. Defaults to None.
  • color (str, optional): Line colour for lens outlines. Defaults to 'k'.
  • linestyle (str, optional): Line style. Defaults to '-'.
  • zmx_format (bool, optional): If True, draw stepped edge connections matching the Zemax layout style; otherwise a single diagonal line joins the element edges. Defaults to False.
  • fix_bound (bool, optional): If True, use fixed axis limits [-1, 7] x [-4, 4]. Defaults to False.

Returns:

  • tuple: (ax, fig) matplotlib axes and figure objects.

Source code in src/geolens_pkg/vis.py
def draw_lens_2d(
    self,
    ax=None,
    fig=None,
    color="k",
    linestyle="-",
    zmx_format=False,
    fix_bound=False,
):
    """Draw lens cross-section layout in a 2D plot.

    Renders each surface profile, connects lens elements with edge lines,
    and draws the sensor plane.

    Args:
        ax (matplotlib.axes.Axes, optional): Existing axes to draw on. If None,
            creates a new figure. Defaults to None.
        fig (matplotlib.figure.Figure, optional): Existing figure. Defaults to None.
        color (str, optional): Line colour for lens outlines. Defaults to 'k'.
        linestyle (str, optional): Line style. Defaults to '-'.
        zmx_format (bool, optional): If True, draw stepped edge connections
            matching Zemax layout style. Defaults to False.
        fix_bound (bool, optional): If True, use fixed axis limits [-1,7]x[-4,4].
            Defaults to False.

    Returns:
        tuple: (ax, fig) matplotlib axes and figure objects.
    """
    # If no axes are given, create a new figure and axes.
    if ax is None:
        fig, ax = plt.subplots()

    # Draw lens surfaces
    for s in self.surfaces:
        s.draw_widget(ax)

    # Connect the front and rear surfaces of each lens element. The last
    # surface is skipped so that self.surfaces[i + 1] always exists.
    for i in range(len(self.surfaces) - 1):
        if self.surfaces[i].mat2.n > 1.1:
            s_prev = self.surfaces[i]
            s = self.surfaces[i + 1]

            r_prev = float(s_prev.draw_r())
            r = float(s.draw_r())
            sag_prev = s_prev.surface_with_offset(
                r_prev, 0.0, valid_check=False
            ).item()
            sag = s.surface_with_offset(
                r, 0.0, valid_check=False
            ).item()

            if r_prev >= r:
                # Front surface wider: go axially forward at r_prev, then step radially inward
                z = np.array([sag_prev, sag, sag])
                x = np.array([r_prev, r_prev, r])
            else:
                # Rear surface wider: step radially outward at z_prev, then go axially forward
                z = np.array([sag_prev, sag_prev, sag])
                x = np.array([r_prev, r, r])

            if not zmx_format:
                # In non-zmx mode use a direct diagonal between the two outer edges
                z = np.array([z[0], z[-1]])
                x = np.array([x[0], x[-1]])

            ax.plot(z, -x, color=color, linestyle=linestyle, linewidth=0.75)
            ax.plot(z, x, color=color, linestyle=linestyle, linewidth=0.75)

    # Draw sensor
    ax.plot(
        [self.d_sensor.item(), self.d_sensor.item()],
        [-self.r_sensor, self.r_sensor],
        color=color,
        linestyle=linestyle,
    )

    # Set aspect ratio and axis limits
    if fix_bound:
        ax.set_aspect("equal")
        ax.set_xlim(-1, 7)
        ax.set_ylim(-4, 4)
    else:
        ax.set_aspect("equal", adjustable="datalim", anchor="C")
        ax.minorticks_on()
        ax.set_xlim(-0.5, 7.5)
        ax.set_ylim(-4, 4)
        ax.autoscale()

    return ax, fig
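The edge-connection step in draw_lens_2d is plain geometry and can be checked without matplotlib. Below is a minimal sketch that mirrors the branch logic in the source above; edge_connector is a hypothetical helper, not a GeoLens method, and it takes the rim sags and semi-apertures as plain floats instead of calling surface_with_offset().

```python
def edge_connector(sag_prev, r_prev, sag, r, zmx_format=False):
    """Return the (z, x) polyline joining the rims of two consecutive surfaces.

    Hypothetical helper mirroring draw_lens_2d:
    sag_prev, r_prev: axial position and semi-aperture of the front surface rim.
    sag, r: axial position and semi-aperture of the rear surface rim.
    """
    if r_prev >= r:
        # Front surface wider: run axially at r_prev, then step radially inward to r
        z = [sag_prev, sag, sag]
        x = [r_prev, r_prev, r]
    else:
        # Rear surface wider: step radially outward at sag_prev, then run axially
        z = [sag_prev, sag_prev, sag]
        x = [r_prev, r, r]
    if not zmx_format:
        # Keep only the end points: one diagonal segment between the two rims
        z, x = [z[0], z[-1]], [x[0], x[-1]]
    return z, x


# Stepped (Zemax-style) and diagonal connectors for a front rim wider than the rear rim
print(edge_connector(0.5, 5.0, 3.2, 4.0, zmx_format=True))
print(edge_connector(0.5, 5.0, 3.2, 4.0))
```

draw_lens_2d plots each such polyline twice, as (z, x) and (z, -x), so the element outline is symmetric about the optical axis.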

draw_ray_2d

draw_ray_2d(ray_o_record, ax, fig, color='b')

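Plot recorded ray paths on a 2D lens layout.

Parameters:

  • ray_o_record (list): Ray-position tensors recorded during tracing, e.g. by trace2sensor(record=True). Required.
  • ax (Axes): Matplotlib axes. Required.
  • fig (Figure): Matplotlib figure. Required.
  • color (str, optional): Ray line colour. Defaults to 'b'.

Returns:

  • tuple: (ax, fig) matplotlib axes and figure objects.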
Source code in src/geolens_pkg/vis.py
def draw_ray_2d(self, ray_o_record, ax, fig, color="b"):
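    """Plot recorded ray paths on a 2D lens layout.

    Args:
        ray_o_record (list): Ray-position tensors recorded during tracing,
            e.g. by ``trace2sensor(record=True)``.
        ax (matplotlib.axes.Axes): Matplotlib axes.
        fig (matplotlib.figure.Figure): Matplotlib figure.
        color (str, optional): Ray line colour. Defaults to 'b'.

    Returns:
        tuple: (ax, fig) matplotlib axes and figure objects.
    """
    # Stack along the path axis: shape (num_view, num_rays, num_path, 3), or
    # (num_rays, num_path, 3) for a single view; the last axis is (x, y, z).
    ray_o_record = torch.stack(ray_o_record, dim=-2).detach().cpu().numpy()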
    if ray_o_record.ndim == 3:
        ray_o_record = ray_o_record[None, ...]

    for idx_view in range(ray_o_record.shape[0]):
        for idx_ray in range(ray_o_record.shape[1]):
            ax.plot(
                ray_o_record[idx_view, idx_ray, :, 2],
                ray_o_record[idx_view, idx_ray, :, 0],
                color,
                linewidth=0.8,
            )

    return ax, fig
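draw_ray_2d stacks the record along a path axis and then plots z (index 2) horizontally against x (index 0). A minimal pure-Python sketch of that reshaping for a single view, assuming each record entry is a list of (x, y, z) tuples in a fixed ray order; ray_paths_zx is a hypothetical name.

```python
def ray_paths_zx(ray_o_record):
    """Convert a per-step record into one (z, x) polyline per ray.

    ray_o_record: list with one entry per recorded step along the trace; each
    entry is a list of (x, y, z) positions, one per ray, in the same ray order.
    Returns a list of (z_list, x_list) pairs, one per ray.
    """
    num_rays = len(ray_o_record[0])
    paths = []
    for idx_ray in range(num_rays):
        # Walk along the optical path of this ray, keeping z (axis) and x (height)
        z = [step[idx_ray][2] for step in ray_o_record]
        x = [step[idx_ray][0] for step in ray_o_record]
        paths.append((z, x))
    return paths


# Two rays recorded at three steps: start plane, lens surface, sensor
record = [
    [(0.0, 0.0, -1.0), (1.0, 0.0, -1.0)],
    [(0.0, 0.0, 0.0), (0.9, 0.0, 0.0)],
    [(0.0, 0.0, 5.0), (0.1, 0.0, 5.0)],
]
for z, x in ray_paths_zx(record):
    print(z, x)
```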

create_barrier

create_barrier(filename, barrier_thickness=1.0, ring_height=0.5, ring_size=1.0)

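Compute barrier segments for the lens system and overlay them on the 2D lens layout.

Parameters:

  • filename: Path of the output PNG figure. Required.
  • barrier_thickness: Thickness of the barrier; stored with each segment but not drawn. Defaults to 1.0.
  • ring_height: Height of the annular ring; currently unused. Defaults to 0.5.
  • ring_size: Size of the annular ring; currently unused. Defaults to 1.0.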
Source code in src/geolens_pkg/vis.py
def create_barrier(
    self, filename, barrier_thickness=1.0, ring_height=0.5, ring_size=1.0
):
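    """Compute barrier segments for the lens system and plot them on the 2D layout.

    Each segment spans one lens element (or a group of cemented elements) and
    extends to the middle of the following air gap; the last segment extends
    to the sensor.

    Args:
        filename: Path of the output PNG figure.
        barrier_thickness: Thickness of the barrier (stored, not drawn).
        ring_height: Height of the annular ring (currently unused).
        ring_size: Size of the annular ring (currently unused).
    """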
    barriers = []
    rings = []

    # Create barriers
    barrier_z = 0.0
    barrier_r = 0.0
    barrier_length = 0.0
    for i in range(len(self.surfaces)):
        barrier_r = max(self.surfaces[i].r, barrier_r)

        if self.surfaces[i].mat2.get_name() == "air":
            # Air follows this surface: extend the barrier to the middle of the
            # air gap before the next surface (or to the sensor after the last one)
            max_curr_surf_d = self.surfaces[i].d.item() + max(
                self.surfaces[i].surface_sag(0.0, self.surfaces[i].r), 0.0
            )
            if i < len(self.surfaces) - 1:
                min_next_surf_d = self.surfaces[i + 1].d.item() + min(
                    self.surfaces[i + 1].surface_sag(0.0, self.surfaces[i + 1].r),
                    0.0,
                )
                extra_space = (min_next_surf_d - max_curr_surf_d) / 2
            else:
                min_next_surf_d = self.d_sensor.item()
                extra_space = min_next_surf_d - max_curr_surf_d

            barrier_length = max_curr_surf_d + extra_space - barrier_z

            # Create a barrier
            barrier = {
                "pos_z": barrier_z,
                "pos_r": barrier_r,
                "length": barrier_length,
                "thickness": barrier_thickness,
            }
            barriers.append(barrier)

            # Reset the barrier parameters
            barrier_z = barrier_length + barrier_z
            barrier_r = 0.0
            barrier_length = 0.0

    # Plot lens layout (draw_lens_2d returns (ax, fig); draw_layout does not)
    ax, fig = self.draw_lens_2d()

    # Plot barrier
    barrier_z_ls = []
    barrier_r_ls = []
    for b in barriers:
        barrier_z_ls.append(b["pos_z"])
        barrier_z_ls.append(b["pos_z"] + b["length"])
        barrier_r_ls.append(b["pos_r"])
        barrier_r_ls.append(b["pos_r"])
    ax.plot(barrier_z_ls, barrier_r_ls, "green", linewidth=1.0)
    ax.plot(barrier_z_ls, [-i for i in barrier_r_ls], "green", linewidth=1.0)

    fig.savefig(filename, format="png", dpi=300)
    plt.close(fig)
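The barrier logic in create_barrier walks the surfaces, tracks the largest semi-aperture since the last air gap, and closes a segment at the middle of each air gap (or at the sensor after the last surface). A minimal sketch of that loop, assuming a simplified surface record: SimpleSurface is hypothetical, edge_sag stands in for surface_sag(0.0, r), and mat2 is a plain material-name string.

```python
from dataclasses import dataclass


@dataclass
class SimpleSurface:
    """Hypothetical stand-in for a GeoLens surface."""
    d: float         # axial position of the surface vertex
    r: float         # semi-aperture
    edge_sag: float  # sag at the rim (y = r), relative to the vertex
    mat2: str        # material after the surface ("air" or a glass name)


def barrier_segments(surfaces, d_sensor, thickness=1.0):
    barriers = []
    z0, r_max = 0.0, 0.0
    for i, s in enumerate(surfaces):
        r_max = max(r_max, s.r)
        if s.mat2 != "air":
            continue  # still inside an element: keep accumulating r_max
        # Rearmost point of the current surface rim
        z_curr = s.d + max(s.edge_sag, 0.0)
        if i < len(surfaces) - 1:
            # Stop halfway to the frontmost point of the next surface rim
            nxt = surfaces[i + 1]
            z_next = nxt.d + min(nxt.edge_sag, 0.0)
            extra = (z_next - z_curr) / 2
        else:
            # Last surface: extend all the way to the sensor
            extra = d_sensor - z_curr
        length = z_curr + extra - z0
        barriers.append({"pos_z": z0, "pos_r": r_max, "length": length, "thickness": thickness})
        z0 += length   # next segment starts where this one ends
        r_max = 0.0
    return barriers


# Two singlets separated by an air gap, sensor at z = 20
lens = [
    SimpleSurface(0.0, 5.0, 0.5, "N-BK7"),
    SimpleSurface(3.0, 5.0, -0.4, "air"),
    SimpleSurface(5.0, 4.0, 0.3, "N-SF5"),
    SimpleSurface(7.0, 4.0, -0.2, "air"),
]
for b in barrier_segments(lens, d_sensor=20.0):
    print(b)
```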

draw_layout()

Draw 2D lens cross-section with ray fans.

lens.draw_layout()

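draw_lens_2d(ax=None, fig=None, color='k', linestyle='-', zmx_format=False, fix_bound=False)

Draw lens element outlines on a matplotlib axis.

draw_ray_2d(ray_o_record, ax, fig, color='b')

Draw traced rays on a matplotlib axis.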


Tolerancing

src.geolens_pkg.eval_tolerance.GeoLensTolerance

Mixin providing tolerance analysis for GeoLens.

Implements two complementary approaches:

  • Sensitivity analysis – first-order gradient-based estimation of how each manufacturing error affects optical performance.
  • Monte-Carlo analysis – statistical sampling of random manufacturing errors to predict yield and worst-case performance.

This class is not instantiated directly; it is mixed into GeoLens (deeplens.optics.geolens.GeoLens).

References

Jun Dai et al., "Tolerance-Aware Deep Optics," arXiv:2502.04719, 2025.

init_tolerance

init_tolerance(tolerance_params=None)

Initialize manufacturing tolerance parameters for all surfaces.

Sets up tolerance ranges (e.g., curvature, thickness, decenter, tilt) on each surface. These are used by sample_tolerance() to simulate random manufacturing errors.

Parameters:

  • tolerance_params (dict, optional): Custom tolerance specifications, passed to each surface's init_tolerance(). If None, each surface uses its own defaults. Defaults to None.
Source code in src/geolens_pkg/eval_tolerance.py
def init_tolerance(self, tolerance_params=None):
    """Initialize manufacturing tolerance parameters for all surfaces.

    Sets up tolerance ranges (e.g., curvature, thickness, decenter, tilt)
    on each surface. These are used by ``sample_tolerance()`` to simulate
    random manufacturing errors.

    Args:
        tolerance_params (dict, optional): Custom tolerance specifications.
            If None, each surface uses its own defaults. Defaults to None.
    """
    if tolerance_params is None:
        tolerance_params = {}

    for i in range(len(self.surfaces)):
        self.surfaces[i].init_tolerance(tolerance_params=tolerance_params)

sample_tolerance

sample_tolerance()

Apply random manufacturing errors to all surfaces.

Randomly perturbs each surface according to its tolerance ranges and then refocuses the lens to compensate for the focus shift.

Source code in src/geolens_pkg/eval_tolerance.py
@torch.no_grad()
def sample_tolerance(self):
    """Apply random manufacturing errors to all surfaces.

    Randomly perturbs each surface according to its tolerance ranges and
    then refocuses the lens to compensate for the focus shift.
    """
    # Randomly perturb all surfaces
    for i in range(len(self.surfaces)):
        self.surfaces[i].sample_tolerance()

    # Refocus the lens
    self.refocus()

zero_tolerance

zero_tolerance()

Reset all manufacturing errors to zero (nominal lens state).

Clears the perturbations on every surface and refocuses the lens.

Source code in src/geolens_pkg/eval_tolerance.py
@torch.no_grad()
def zero_tolerance(self):
    """Reset all manufacturing errors to zero (nominal lens state).

    Clears the perturbations on every surface and refocuses the lens.
    """
    for i in range(len(self.surfaces)):
        self.surfaces[i].zero_tolerance()

    # Refocus the lens
    self.refocus()
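init_tolerance(), sample_tolerance() and zero_tolerance() delegate to per-surface methods of the same names. The toy class below sketches that lifecycle under stated assumptions: ToySurface, its default tolerances, the merge of user overrides, and the uniform error distribution are all hypothetical, and the refocus() step is omitted because the toy has no optics.

```python
import random


class ToySurface:
    """Hypothetical surface with a curvature c and an axial position d.

    Assumption: each error is drawn uniformly within +/- its tolerance; the
    real surfaces may use a different distribution.
    """

    DEFAULTS = {"c": 1e-3, "d": 0.02}  # hypothetical default tolerances

    def __init__(self, c, d):
        self.c, self.d = c, d
        self.tol = {}
        self.err = {"c": 0.0, "d": 0.0}

    def init_tolerance(self, tolerance_params=None):
        # Start from the defaults, then override with user-supplied ranges
        self.tol = {**self.DEFAULTS, **(tolerance_params or {})}

    def sample_tolerance(self, rng=random):
        # Draw one random manufacturing error per toleranced parameter
        for name, t in self.tol.items():
            self.err[name] = rng.uniform(-t, t)

    def zero_tolerance(self):
        # Back to the nominal (error-free) state
        for name in self.err:
            self.err[name] = 0.0

    @property
    def c_eff(self):
        return self.c + self.err["c"]

    @property
    def d_eff(self):
        return self.d + self.err["d"]


rng = random.Random(0)
surfaces = [ToySurface(0.05, 0.0), ToySurface(-0.02, 3.0)]
for s in surfaces:
    s.init_tolerance({"d": 0.05})  # loosen the thickness tolerance only
    s.sample_tolerance(rng)
print([(round(s.c_eff, 5), round(s.d_eff, 4)) for s in surfaces])
for s in surfaces:
    s.zero_tolerance()
```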

tolerancing_sensitivity

tolerancing_sensitivity(tolerance_params=None)

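Use sensitivity analysis (first-order gradients) to compute the tolerance score.

Initializes the tolerances, backpropagates the nominal loss from loss_rms(), collects the sensitivity_score() entries of every surface, and combines the nominal loss with the sum of all entries whose key ends in "_score" into a root-sum-square (RSS) loss.

Parameters:

  • tolerance_params (dict, optional): Tolerance specifications passed to init_tolerance(). Defaults to None.

Returns:

  • dict: The per-surface sensitivity entries, plus loss_nominal and loss_rss (both rounded to 6 decimals).

References

[1] Page 10 of: https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/08/8-Tolerancing-1.pdf
[2] Fast sensitivity control method with differentiable optics. Optics Express 2025.
[3] Optical Design Tolerancing. CODE V.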

Source code in src/geolens_pkg/eval_tolerance.py
def tolerancing_sensitivity(self, tolerance_params=None):
    """Use sensitivity analysis (1st order gradient) to compute the tolerance score.

    References:
        [1] Page 10 from: https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/08/8-Tolerancing-1.pdf
        [2] Fast sensitivity control method with differentiable optics. Optics Express 2025.
        [3] Optical Design Tolerancing. CODE V.
    """
    # Initialize tolerance
    self.init_tolerance(tolerance_params=tolerance_params)

    # AutoDiff to compute the gradient/sensitivity
    self.get_optimizer_params()
    loss = self.loss_rms()
    loss.backward()

    # Calculate sensitivity results
    sensitivity_results = {}
    for i in range(len(self.surfaces)):
        sensitivity_results.update(self.surfaces[i].sensitivity_score())

    # Toleranced RSS (Root Sum Square) loss
    tolerancing_score = sum(
        v for k, v in sensitivity_results.items() if k.endswith("_score")
    )
    loss_rss = torch.sqrt(loss**2 + tolerancing_score).item()
    sensitivity_results["loss_nominal"] = round(loss.item(), 6)
    sensitivity_results["loss_rss"] = round(loss_rss, 6)
    return sensitivity_results
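tolerancing_sensitivity() combines the nominal loss \(L_0\) and the per-parameter scores \(S_k\) as \(L_{\text{RSS}} = \sqrt{L_0^2 + \sum_k S_k}\). The sketch below reproduces that aggregation in plain Python. How sensitivity_score() computes each \(S_k\) is not shown in this section; first_order_score, which uses the common first-order form \((\partial L/\partial p \cdot \Delta p)^2\), is an assumption for illustration.

```python
import math


def first_order_score(grad, tol):
    """Squared first-order loss change, (dL/dp * delta_p) ** 2.

    Assumed convention; the library's sensitivity_score() may differ.
    """
    return (grad * tol) ** 2


def rss_loss(loss_nominal, sensitivity_results):
    """Combine the nominal loss with per-parameter scores, as in tolerancing_sensitivity."""
    # Only entries whose key ends with "_score" contribute, mirroring the source
    tolerancing_score = sum(
        v for k, v in sensitivity_results.items() if k.endswith("_score")
    )
    return math.sqrt(loss_nominal ** 2 + tolerancing_score)


results = {
    "surf0_c_score": first_order_score(grad=0.2, tol=0.1),
    "surf1_d_score": first_order_score(grad=0.5, tol=0.04),
    "surf0_c_grad": 0.2,  # ignored: key does not end with "_score"
}
print(rss_loss(0.03, results))  # about 0.0412
```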

tolerancing_monte_carlo

tolerancing_monte_carlo(trials=200, spp=SPP_CALC, tolerance_params=None)

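Use Monte Carlo simulation to compute the tolerance.

Each trial applies a random perturbation to all surfaces, recomputes the sensor plane (calc_sensor_plane), and evaluates an MTF merit: the mean of tangential and sagittal MTF at a quarter of the Nyquist frequency, averaged over the field positions 0.0, 0.5 and 1.0. A histogram/CDF plot is saved to Monte_Carlo_Tolerance.png in the working directory.

The default trials=200 is tuned for about 3 minutes of runtime on a GPU. For production-quality yield estimates (especially the 95th/99th percentile tails), increase it to 1000 or more.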
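merit_func picks the sampled frequency closest to a quarter of the sensor Nyquist frequency \(f_N = 0.5/\text{pixel\_size}\) and averages the tangential and sagittal MTF there. Assuming pixel_size is in millimetres, a 0.002 mm pixel gives \(f_N = 250\) cycles/mm and an evaluation frequency of 62.5 cycles/mm. A minimal sketch of that selection step on plain lists; quarter_nyquist_merit is a hypothetical name, and the PSF and MTF computation itself is not reproduced.

```python
def quarter_nyquist_merit(freq, mtf_tan, mtf_sag, pixel_size):
    """Mean of tangential and sagittal MTF at the sampled frequency nearest f_N / 4.

    freq is in cycles per unit length, in the same unit as pixel_size.
    """
    nyquist = 0.5 / pixel_size
    target = 0.25 * nyquist
    # Index of the sampled frequency nearest the target; like torch.argmin,
    # min() returns the first index when there is a tie
    idx = min(range(len(freq)), key=lambda i: abs(freq[i] - target))
    return (mtf_tan[idx] + mtf_sag[idx]) / 2


freq = [0, 20, 40, 60, 80, 100]             # cycles/mm
mtf_tan = [1.0, 0.9, 0.75, 0.6, 0.45, 0.3]
mtf_sag = [1.0, 0.92, 0.8, 0.7, 0.55, 0.4]
# Target is 62.5 cycles/mm, so the 60 cycles/mm sample is used
print(quarter_nyquist_merit(freq, mtf_tan, mtf_sag, pixel_size=0.002))
```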

Parameters:

  • trials (int): Number of Monte Carlo trials. Defaults to 200.
  • spp (int): Samples per pixel for the PSF calculation. Lower values run faster at the cost of noisier MTF estimates. Defaults to SPP_CALC (1024), which is about 16x faster than the full SPP_PSF.
  • tolerance_params (dict, optional): Tolerance specifications passed to init_tolerance(). Defaults to None.

Returns:

  • dict: Monte Carlo results with the keys method, trials, baseline_merit, merit_std, merit_mean, merit_yield (e.g. "90% > ": the merit exceeded by 90% of trials) and merit_percentile (e.g. "90% < ": the merit that 90% of trials fall below).

References

[1] https://optics.ansys.com/hc/en-us/articles/43071088477587-How-to-analyze-your-tolerance-results
[2] Optical Design Tolerancing. CODE V.

Source code in src/geolens_pkg/eval_tolerance.py
@torch.no_grad()
def tolerancing_monte_carlo(self, trials=200, spp=SPP_CALC, tolerance_params=None):
    """Use Monte Carlo simulation to compute the tolerance.

    The default ``trials=200`` is tuned for ~3 min runtime on GPU.
    For production-quality yield estimates (especially 95th/99th
    percentile tails), increase to 1000+.

    Args:
        trials (int): Number of Monte Carlo trials. Defaults to 200.
        spp (int): Samples per pixel for PSF calculation. Lower values
            run faster at the cost of noisier MTF estimates. Defaults to
            SPP_CALC (1024), which is ~16x faster than the full SPP_PSF.
        tolerance_params (dict): Tolerance parameters.

    Returns:
        dict: Monte Carlo tolerance analysis results.

    References:
        [1] https://optics.ansys.com/hc/en-us/articles/43071088477587-How-to-analyze-your-tolerance-results
        [2] Optical Design Tolerancing. CODE V.
    """

    def merit_func(lens, fov=0.0, depth=DEPTH):
        """Evaluate MTF merit at a single field point."""
        try:
            point = [0, -fov / lens.rfov_eff, depth]
            psf = lens.psf(points=point, spp=spp, recenter=True)
            freq, mtf_tan, mtf_sag = lens.psf2mtf(psf, pixel_size=lens.pixel_size)

            # Evaluate MTF at quarter-Nyquist frequency
            nyquist_freq = 0.5 / lens.pixel_size
            eval_freq = 0.25 * nyquist_freq
            idx = torch.argmin(torch.abs(torch.tensor(freq) - eval_freq))
            score = (mtf_tan[idx] + mtf_sag[idx]) / 2
            return score.item()
        except RuntimeError:
            # Perturbed lens may block all rays at extreme fields
            return 0.0

    def multi_field_merit(lens, depth=DEPTH):
        """Evaluate average MTF merit across multiple field positions."""
        fov_points = [0.0, 0.5, 1.0]
        scores = [merit_func(lens, fov=fov, depth=depth) for fov in fov_points]
        return float(np.mean(scores))

    # Initialize tolerance
    self.init_tolerance(tolerance_params=tolerance_params)

    # Monte Carlo simulation
    merit_ls = []
    with torch.no_grad():
        for i in tqdm(range(trials)):
            # Sample a random perturbation and refocus sensor only
            # (skip full post_computation — focal length, pupil, and FoV
            # don't change meaningfully under small tolerance errors).
            for surf in self.surfaces:
                surf.sample_tolerance()
            self.d_sensor = self.calc_sensor_plane()

            # Evaluate perturbed performance across multiple field positions
            perturbed_merit = multi_field_merit(lens=self, depth=DEPTH)
            merit_ls.append(perturbed_merit)

            # Clear perturbation (no refocus needed — next iteration
            # will set sensor position after sampling).
            for surf in self.surfaces:
                surf.zero_tolerance()

    merit_ls = np.array(merit_ls)

    # Baseline merit (nominal lens)
    self.refocus()
    baseline_merit = multi_field_merit(lens=self, depth=DEPTH)

    # Results plot — histogram + CDF
    fig, ax1 = plt.subplots(figsize=(9, 5))

    # Histogram
    ax1.hist(
        merit_ls,
        bins=30,
        color="#4C72B0",
        alpha=0.6,
        edgecolor="white",
        label="Frequency",
    )
    ax1.set_xlabel("MTF Merit Score (higher is better)", fontsize=12)
    ax1.set_ylabel("Count", fontsize=12, color="#4C72B0")
    ax1.tick_params(axis="y", labelcolor="#4C72B0")

    # CDF on secondary axis
    ax2 = ax1.twinx()
    sorted_merit = np.sort(merit_ls)
    cdf = np.arange(1, len(sorted_merit) + 1) / len(sorted_merit) * 100
    ax2.plot(sorted_merit, cdf, color="#C44E52", linewidth=2, label="CDF")
    ax2.set_ylabel("Cumulative % of Lenses", fontsize=12, color="#C44E52")
    ax2.tick_params(axis="y", labelcolor="#C44E52")
    ax2.set_ylim(0, 105)

    # Baseline reference
    ax1.axvline(
        baseline_merit,
        color="green",
        linestyle="--",
        linewidth=1.5,
        label=f"Nominal = {baseline_merit:.3f}",
    )

    # Yield annotations — 90% and 50% yield lines
    p90 = float(np.percentile(merit_ls, 10))  # 90% of lenses exceed this
    p50 = float(np.percentile(merit_ls, 50))
    ax1.axvline(
        p90, color="orange", linestyle=":", linewidth=1.5,
        label=f"90% yield > {p90:.3f}",
    )
    ax1.axvline(
        p50, color="gray", linestyle=":", linewidth=1.5,
        label=f"50% yield > {p50:.3f}",
    )

    # Title and legend
    ax1.set_title(
        f"Monte Carlo Tolerance Analysis  ({trials} trials)",
        fontsize=13,
        fontweight="bold",
    )
    ax1.legend(loc="upper left", fontsize=9, framealpha=0.9)
    ax1.grid(True, alpha=0.2)
    fig.tight_layout()
    fig.savefig(
        "Monte_Carlo_Tolerance.png", dpi=300, bbox_inches="tight"
    )
    plt.close(fig)

    # Results dict
    results = {
        "method": "monte_carlo",
        "trials": trials,
        "baseline_merit": round(baseline_merit, 6),
        "merit_std": round(float(np.std(merit_ls)), 6),
        "merit_mean": round(float(np.mean(merit_ls)), 6),
        "merit_yield": {
            "99% > ": round(float(np.percentile(merit_ls, 1)), 4),
            "95% > ": round(float(np.percentile(merit_ls, 5)), 4),
            "90% > ": round(float(np.percentile(merit_ls, 10)), 4),
            "80% > ": round(float(np.percentile(merit_ls, 20)), 4),
            "70% > ": round(float(np.percentile(merit_ls, 30)), 4),
            "60% > ": round(float(np.percentile(merit_ls, 40)), 4),
            "50% > ": round(float(np.percentile(merit_ls, 50)), 4),
        },
        "merit_percentile": {
            "99% < ": round(float(np.percentile(merit_ls, 99)), 4),
            "95% < ": round(float(np.percentile(merit_ls, 95)), 4),
            "90% < ": round(float(np.percentile(merit_ls, 90)), 4),
            "80% < ": round(float(np.percentile(merit_ls, 80)), 4),
            "70% < ": round(float(np.percentile(merit_ls, 70)), 4),
            "60% < ": round(float(np.percentile(merit_ls, 60)), 4),
            "50% < ": round(float(np.percentile(merit_ls, 50)), 4),
        },
    }
    return results
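The Monte Carlo loop and the yield table can be illustrated without a lens. In the sketch below, toy_merit is a hypothetical quadratic merit model and the errors are drawn uniformly within each tolerance (an assumption); percentile() reproduces numpy's default linear interpolation, so the merit_yield keys read the same way as in the source: "90% > x" means 90% of trials have a merit above x.

```python
import random
import statistics


def percentile(values, q):
    """Linear-interpolation percentile (numpy's default method)."""
    s = sorted(values)
    pos = (len(s) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)


def toy_merit(errors, nominal=0.6, weights=(40.0, 15.0, 5.0)):
    """Hypothetical merit model: quadratic degradation from the nominal value."""
    loss = sum(w * e * e for w, e in zip(weights, errors))
    return max(nominal - loss, 0.0)


def monte_carlo(trials=200, tolerances=(0.05, 0.08, 0.1), seed=0):
    rng = random.Random(seed)
    merits = []
    for _ in range(trials):
        # Sample one random error per toleranced parameter and evaluate;
        # the nominal state is implicit, so this toy needs no explicit reset
        errors = [rng.uniform(-t, t) for t in tolerances]
        merits.append(toy_merit(errors))
    return {
        "trials": trials,
        "baseline_merit": toy_merit([0.0] * len(tolerances)),
        "merit_mean": statistics.fmean(merits),
        "merit_std": statistics.pstdev(merits),  # population std, like np.std
        # "p% > x": p% of trials exceed x, i.e. x is the (100 - p)th percentile
        "merit_yield": {f"{p}% > ": percentile(merits, 100 - p) for p in (99, 95, 90, 50)},
    }


res = monte_carlo(trials=1000)
print(res["baseline_merit"], res["merit_yield"])
```

With more trials the low-percentile entries scatter less from run to run, which is why the docstring recommends 1000+ trials for the 95th/99th percentile tails.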

tolerancing_wavefront

tolerancing_wavefront(tolerance_params=None)

Use the wavefront differential method to compute the tolerance.

The wavefront differential method was proposed in [1], but its detailed implementation is not published. I (Xinge Yang) assume that symbolic differentiation is used to compute the gradient (Jacobian) of the wavefront error. Since AutoDiff can compute the same Jacobian through gradient backpropagation, implementing this method is left as future work.

Parameters:

  • tolerance_params (dict, optional): Tolerance parameters. Defaults to None.

Returns:

  • dict: Wavefront tolerance analysis results (not implemented yet; the current stub returns None).

References

[1] Optical Design Tolerancing. CODE V.

Source code in src/geolens_pkg/eval_tolerance.py
def tolerancing_wavefront(self, tolerance_params=None):
    """Use wavefront differential method to compute the tolerance.

    The wavefront differential method was proposed in [1], but its detailed
    implementation is not published. I (Xinge Yang) assume that symbolic
    differentiation is used to compute the gradient (Jacobian) of the
    wavefront error. Since AutoDiff can compute the same Jacobian through
    gradient backpropagation, implementing this method is left as future work.

    Args:
        tolerance_params (dict, optional): Tolerance parameters.

    Returns:
        dict: Wavefront tolerance analysis results (not implemented yet;
            currently returns None).

    References:
        [1] Optical Design Tolerancing. CODE V.
    """
    pass

Manufacturing tolerance analysis following Jun Dai et al. (arXiv:2502.04719, 2025).

tolerancing_sensitivity(tolerance_params=None)

Sensitivity-based tolerance analysis.

tolerancing_monte_carlo(trials=200, spp=SPP_CALC, tolerance_params=None)

Monte Carlo tolerance analysis with random perturbations.


Reparametrization

enable_reparam()

Enable parameter reparametrization on all Aspheric surfaces. Rescales the curvature \(c\), conic constant \(k\), and aspheric coefficients \(a_i\) so that the optimized variables are of order \(\mathcal{O}(1)\).

disable_reparam()

Disable reparametrization and restore physical parameter values.
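This section does not say how enable_reparam() chooses its scales. As a minimal sketch, assume each parameter \(p\) is replaced by an optimizer variable \(\tilde p = p/s\) with a fixed, hand-picked scale \(s\); the physical value is recovered as \(p = s\tilde p\), and by the chain rule \(\partial L/\partial\tilde p = s\,\partial L/\partial p\). The helper names and scale values below are hypothetical.

```python
def normalize(params, scales):
    """Map physical parameters p to optimizer variables p_tilde = p / s."""
    return {k: v / scales[k] for k, v in params.items()}


def denormalize(nparams, scales):
    """Recover physical parameters p = s * p_tilde."""
    return {k: v * scales[k] for k, v in nparams.items()}


def grad_normalized(grad_physical, scales):
    """Chain rule: dL/dp_tilde = s * dL/dp."""
    return {k: g * scales[k] for k, g in grad_physical.items()}


# Hypothetical scales for curvature, conic constant and one aspheric coefficient
scales = {"c": 0.05, "k": 1.0, "a4": 1e-5}
params = {"c": 0.04, "k": -0.5, "a4": 3e-6}
n = normalize(params, scales)
print(n)                       # all values of order 1
print(denormalize(n, scales))  # back to the physical values
```

With this scaling, one gradient step of size \(\eta\) on \(\tilde p\) changes \(p\) by \(-\eta s^2\,\partial L/\partial p\), which is one way to avoid tuning a separate learning rate for each coefficient.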