Ray & Light API Reference

The src.light module contains ray and wave optics representations for geometric and physical optics simulation.


Ray

Batched geometric ray bundle carrying origin, direction, wavelength, validity mask, energy, and optical path length.

src.light.Ray

Ray(o, d, wvln=DEFAULT_WAVE, coherent=False, device='cpu')

Bases: DeepObj

Batched ray bundle for optical simulation.

Stores ray origins, directions, wavelength, validity mask, energy, obliquity, and (in coherent mode) optical path length. All tensor attributes share the same batch shape (*batch_size, num_rays).

Attributes:

Name Type Description
o Tensor

Ray origins, shape (*batch, num_rays, 3) [mm].

d Tensor

Unit ray directions, shape (*batch, num_rays, 3).

wvln Tensor

Wavelength scalar [µm].

is_valid Tensor

Binary validity mask, shape (*batch, num_rays).

en Tensor

Energy weight, shape (*batch, num_rays, 1).

obliq Tensor

Obliquity factor, shape (*batch, num_rays, 1).

opl Tensor

Optical path length (coherent mode only), shape (*batch, num_rays, 1) [mm].

coherent bool

Whether OPL tracking is enabled.

Initialize a ray object.

Parameters:

Name Type Description Default
o Tensor

Ray origin, shape (..., num_rays, 3) [mm].

required
d Tensor

Ray direction, shape (..., num_rays, 3).

required
wvln float

Ray wavelength [µm].

DEFAULT_WAVE
coherent bool

Enable optical path length tracking for coherent tracing. Defaults to False.

False
device str

Compute device. Defaults to "cpu".

'cpu'
Source code in src/light/ray.py
def __init__(self, o, d, wvln=DEFAULT_WAVE, coherent=False, device="cpu"):
    """Initialize a ray object.

    Args:
        o (torch.Tensor): Ray origin, shape ``(..., num_rays, 3)`` [mm].
        d (torch.Tensor): Ray direction, shape ``(..., num_rays, 3)``.
        wvln (float): Ray wavelength [µm].
        coherent (bool): Enable optical path length tracking for coherent
            tracing. Defaults to ``False``.
        device (str): Compute device. Defaults to ``"cpu"``.
    """
    # Basic ray parameters - move to device
    self.o = (o if torch.is_tensor(o) else torch.tensor(o)).to(device)
    self.d = (d if torch.is_tensor(d) else torch.tensor(d)).to(device)
    self.shape = self.o.shape[:-1]

    # Wavelength
    assert wvln > 0.1 and wvln < 10.0, "Ray wavelength unit should be [um]"
    self.wvln = torch.tensor(wvln, device=device)

    # Auxiliary ray parameters - create directly on device
    self.is_valid = torch.ones(self.shape, device=device)
    self.en = torch.ones((*self.shape, 1), device=device)
    self.obliq = torch.ones((*self.shape, 1), device=device)

    # Coherent ray tracing
    self.coherent = coherent  # bool
    self.opl = torch.zeros((*self.shape, 1), device=device)

    self.device = device
    self.d = F.normalize(self.d, p=2, dim=-1)
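
The constructor does two things that are easy to miss: it rescales every direction to unit length, and it rejects wavelengths that are not plausibly in micrometers. The following dependency-free sketch reproduces both steps for a single ray. The helper names `normalize` and `check_wavelength_um` are hypothetical and are not part of `src.light`.

```python
import math

def normalize(d, eps=1e-12):
    """Scale a 3-vector to unit L2 length, like F.normalize(d, p=2, dim=-1)."""
    norm = math.sqrt(sum(c * c for c in d))
    # F.normalize clamps the norm from below by eps to avoid dividing by zero
    norm = max(norm, eps)
    return [c / norm for c in d]

def check_wavelength_um(wvln):
    """Mirror the constructor's range check: wavelengths are given in micrometers."""
    return 0.1 < wvln < 10.0

d_unit = normalize([0.0, 3.0, 4.0])   # direction after auto-normalization
ok_um = check_wavelength_um(0.587)    # 0.587 um passes the check
ok_nm = check_wavelength_um(587.0)    # the same wavelength in nm is rejected
```

A value such as 587.0 fails the check, which is how the assertion catches wavelengths passed in nanometers by mistake.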

prop_to

prop_to(z, n=1.0)

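Propagate the rays to a given depth plane (constant z). Rays that are nearly parallel to the plane (|d_z| < 1e-6) are marked invalid and left in place.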

Parameters:

Name Type Description Default
z float

Target depth, i.e. the z-coordinate of the destination plane [mm].

required
n float

Refractive index of the medium, used to accumulate optical path length in coherent mode. Defaults to 1.

1.0
Source code in src/light/ray.py
def prop_to(self, z, n=1.0):
    """Ray propagates to a given depth plane.

    Args:
        z (float): depth.
        n (float, optional): refractive index. Defaults to 1.
    """
    dz = self.d[..., 2]
    nearly_parallel = dz.abs() < 1e-6
    safe_dz = torch.where(nearly_parallel, torch.ones_like(dz), dz)
    t = (z - self.o[..., 2]) / safe_dz
    t = torch.where(nearly_parallel, torch.zeros_like(t), t)
    self.is_valid = self.is_valid * (~nearly_parallel).float()

    new_o = self.o + self.d * t.unsqueeze(-1)
    valid_mask = (self.is_valid > 0).unsqueeze(-1)
    self.o = torch.where(valid_mask, new_o, self.o)

    if self.coherent:
        if t.dtype != torch.float64:
            raise Warning("Should use float64 in coherent ray tracing.")
        else:
            new_opl = self.opl + n * t.unsqueeze(-1)
            self.opl = torch.where(valid_mask, new_opl, self.opl)

    return self
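
The propagation above is a ray-plane intersection: solve `o_z + t * d_z = z` for `t`, then advance the origin by `t * d`. The sketch below works through the same arithmetic for one ray in plain Python. The function name `prop_to_plane` and its return tuple are assumptions made for illustration only.

```python
def prop_to_plane(o, d, z, n=1.0, opl=0.0, tol=1e-6):
    """Propagate one ray (origin o, unit direction d) to the plane at depth z.

    Returns (new_origin, new_opl, is_valid).
    """
    dz = d[2]
    if abs(dz) < tol:
        # Nearly parallel to the plane: mark invalid and leave the ray unchanged
        return list(o), opl, False
    t = (z - o[2]) / dz                       # signed distance along the ray
    new_o = [o[i] + d[i] * t for i in range(3)]
    return new_o, opl + n * t, True           # OPL grows by n * t

new_o, new_opl, valid = prop_to_plane([0.0, 0.0, 0.0], [0.0, 0.6, 0.8], z=8.0)
_, _, valid_parallel = prop_to_plane([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], z=5.0)
```

In the library method the optical path length is only updated when `coherent` is set; the sketch always returns it for brevity.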

centroid

centroid()

Calculate the centroid of the valid ray origins. The origins have shape (..., num_rays, 3); the centroid is taken over the num_rays dimension.

Returns:

Type Description

torch.Tensor: Centroid of the ray, shape (..., 3)

Source code in src/light/ray.py
def centroid(self):
    """Calculate the centroid of the ray, shape (..., num_rays, 3)

    Returns:
        torch.Tensor: Centroid of the ray, shape (..., 3)
    """
    return (self.o * self.is_valid.unsqueeze(-1)).sum(-2) / self.is_valid.sum(
        -1
    ).add(EPSILON).unsqueeze(-1)
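
The centroid is a mean of the origins weighted by the validity mask, so invalid rays contribute nothing. A minimal plain-Python sketch of the same formula follows. The value used for `eps` is an assumption, since the library's `EPSILON` constant is not shown here.

```python
def centroid(origins, is_valid, eps=1e-9):
    """Validity-weighted mean of ray origins (one bundle, no batch dimension)."""
    n_valid = sum(is_valid) + eps   # eps guards against a bundle with no valid rays
    return [
        sum(o[i] * v for o, v in zip(origins, is_valid)) / n_valid
        for i in range(3)
    ]

origins = [[1.0, 0.0, 0.0], [3.0, 2.0, 0.0], [100.0, 100.0, 0.0]]
is_valid = [1.0, 1.0, 0.0]          # the third ray is masked out
center = centroid(origins, is_valid)
```

Here the outlier at (100, 100, 0) is ignored and the result is approximately (2, 1, 0).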

rms_error

rms_error(center_ref=None)

Calculate the RMS spot error of the ray bundle, measured in the xy-plane relative to a reference center.

Parameters:

Name Type Description Default
center_ref Tensor

Reference center of the ray, shape (..., 3). If None, use the centroid of the ray as reference.

None

Returns:

Type Description

torch.Tensor: RMS error averaged over the batch dimensions (scalar)

Source code in src/light/ray.py
def rms_error(self, center_ref=None):
    """Calculate the RMS error of the ray.

    Args:
        center_ref (torch.Tensor): Reference center of the ray, shape (..., 3). If None, use the centroid of the ray as reference.

    Returns:
        torch.Tensor: average RMS error of the ray
    """
    # Calculate the centroid of the ray as reference
    if center_ref is None:
        with torch.no_grad():
            center_ref = self.centroid()

    center_ref = center_ref.unsqueeze(-2)

    # Calculate RMS error for each region
    rms_error = ((self.o[..., :2] - center_ref[..., :2]) ** 2).sum(-1)
    rms_error = (rms_error * self.is_valid).sum(-1) / (
        self.is_valid.sum(-1) + EPSILON
    )
    rms_error = rms_error.sqrt()

    # Average RMS error
    return rms_error.mean()
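
For a single bundle, the computation above reduces to the square root of the mean squared xy-distance of the valid rays from the reference center. The sketch below shows this in plain Python. The function name `rms_spot_radius` and the `eps` value are assumptions, and the final averaging over batch dimensions is omitted.

```python
import math

def rms_spot_radius(origins, is_valid, center_ref=None, eps=1e-9):
    """RMS xy-distance of valid rays from center_ref (defaults to their centroid)."""
    n_valid = sum(is_valid) + eps
    if center_ref is None:
        center_ref = [
            sum(o[i] * v for o, v in zip(origins, is_valid)) / n_valid
            for i in range(2)
        ]
    sq_sum = sum(
        ((o[0] - center_ref[0]) ** 2 + (o[1] - center_ref[1]) ** 2) * v
        for o, v in zip(origins, is_valid)
    )
    return math.sqrt(sq_sum / n_valid)

# Four valid rays on the unit circle around the origin
spots = [[1.0, 0.0, 5.0], [-1.0, 0.0, 5.0], [0.0, 1.0, 5.0], [0.0, -1.0, 5.0]]
rms = rms_spot_radius(spots, [1.0, 1.0, 1.0, 1.0])
```

Only x and y enter the result; the z-coordinate of the spots is ignored, as in the method above.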

flip_xy

flip_xy()

Flip the x and y coordinates of the ray.

This function is used when calculating point spread function and wavefront distribution.

Source code in src/light/ray.py
def flip_xy(self):
    """Flip the x and y coordinates of the ray.

    This function is used when calculating point spread function and wavefront distribution.
    """
    self.o = torch.cat([-self.o[..., :2], self.o[..., 2:]], dim=-1)
    self.d = torch.cat([-self.d[..., :2], self.d[..., 2:]], dim=-1)
    return self

clone

clone(device=None)

Clone the ray.

The target device can be specified. This is useful when all rays are stored on the CPU and moved to the GPU only when they are used.

Source code in src/light/ray.py
def clone(self, device=None):
    """Clone the ray.

    The target device can be specified. This is useful when all rays are stored on the CPU and moved to the GPU only when they are used.
    """
    if device is None:
        return copy.deepcopy(self).to(self.device)
    else:
        return copy.deepcopy(self).to(device)

squeeze

squeeze(dim=None)

Squeeze the ray.

Parameters:

Name Type Description Default
dim int

dimension to squeeze. Defaults to None.

None
Source code in src/light/ray.py
def squeeze(self, dim=None):
    """Squeeze the ray.

    Args:
        dim (int, optional): dimension to squeeze. Defaults to None.
    """
    self.o = self.o.squeeze(dim)
    self.d = self.d.squeeze(dim)
    # wvln is a single element tensor, no squeeze needed
    self.is_valid = self.is_valid.squeeze(dim)
    self.en = self.en.squeeze(dim)
    self.opl = self.opl.squeeze(dim)
    self.obliq = self.obliq.squeeze(dim)
    return self

unsqueeze

unsqueeze(dim=None)

Unsqueeze the ray.

Parameters:

Name Type Description Default
dim int

dimension to unsqueeze. Defaults to None.

None
Source code in src/light/ray.py
def unsqueeze(self, dim=None):
    """Unsqueeze the ray.

    Args:
        dim (int, optional): dimension to unsqueeze. Defaults to None.
    """
    self.o = self.o.unsqueeze(dim)
    self.d = self.d.unsqueeze(dim)
    # wvln is a single element tensor, no unsqueeze needed
    self.is_valid = self.is_valid.unsqueeze(dim)
    self.en = self.en.unsqueeze(dim)
    self.opl = self.opl.unsqueeze(dim)
    self.obliq = self.obliq.unsqueeze(dim)
    return self

Ray(o, d, wvln=0.587, coherent=False, device="cpu")

| Parameter | Type | Description |
| --- | --- | --- |
| o | Tensor | Ray origins, shape (*batch, num_rays, 3), in mm |
| d | Tensor | Ray directions, shape (*batch, num_rays, 3) (auto-normalized) |
| wvln | float | Wavelength in micrometers (must satisfy 0.1 < wvln < 10) |
| coherent | bool | Enable optical path length tracking |
| device | str | Compute device (default "cpu") |

Key Attributes

| Attribute | Type | Shape | Description |
| --- | --- | --- | --- |
| o | Tensor | (*batch, N, 3) | Ray origins (mm) |
| d | Tensor | (*batch, N, 3) | Unit ray directions |
| wvln | Tensor | scalar | Wavelength (µm) |
| is_valid | Tensor | (*batch, N) | Binary validity mask |
| en | Tensor | (*batch, N, 1) | Energy weight |
| obliq | Tensor | (*batch, N, 1) | Accumulated obliquity factor |
| opl | Tensor | (*batch, N, 1) | Optical path length (mm, coherent mode) |

Key Methods

prop_to(z, n=1.0)

Propagate all rays to a given z-plane.

ray.prop_to(z=50.0, n=1.0)  # Propagate to z=50mm in air

centroid()

Weighted mean position of valid rays.

center = ray.centroid()  # shape: (..., 3)

rms_error(center_ref=None)

RMS spot radius relative to a reference center.

rms = ray.rms_error()  # scalar in mm

clone(device=None)

Deep copy, optionally moving to a new device.

flip_xy()

Negate x and y components (used for PSF centering).


Wave Optics

Complex electromagnetic field representations with scalar diffraction propagation.

src.light.ComplexWave

ComplexWave(u=None, wvln=0.55, z=0.0, phy_size=(4.0, 4.0), res=(2000, 2000))

Bases: DeepObj

Complex scalar wave field for diffraction simulation.

Represents a monochromatic, coherent complex amplitude on a uniform rectangular grid. Propagation methods (ASM, Fresnel, Fraunhofer) are implemented as member functions and use torch.fft for efficiency.

Attributes:

Name Type Description
u Tensor

Complex amplitude, shape [1, 1, H, W].

wvln float

Wavelength [µm].

k float

Wave number 2π / (λ × 10⁻³) [mm⁻¹].

phy_size tuple

Physical aperture size (W, H) [mm].

ps float

Pixel pitch [mm] (must be square).

res tuple

Grid resolution (H, W) in pixels.

z float

Current axial position [mm].

Initialize a complex wave field.

Parameters:

Name Type Description Default
u Tensor or None

Initial complex amplitude. Accepted shapes: [H, W], [1, H, W], or [1, 1, H, W]. If None a zero field is created with the given res.

None
wvln float

Wavelength [µm]. Defaults to 0.55.

0.55
z float

Initial axial position [mm]. Defaults to 0.0.

0.0
phy_size tuple

Physical aperture (W, H) [mm]. Defaults to (4.0, 4.0).

(4.0, 4.0)
res tuple

Grid resolution (H, W) [pixels]. Only used when u is None. Defaults to (2000, 2000).

(2000, 2000)

Raises:

Type Description
AssertionError

If the pixel pitch is not square or the wavelength is outside the range (0.1, 10) µm.

Source code in src/light/wave.py
def __init__(
    self,
    u=None,
    wvln=0.55,
    z=0.0,
    phy_size=(4.0, 4.0),
    res=(2000, 2000),
):
    """Initialize a complex wave field.

    Args:
        u (torch.Tensor or None, optional): Initial complex amplitude.
            Accepted shapes: ``[H, W]``, ``[1, H, W]``, or
            ``[1, 1, H, W]``.  If ``None`` a zero field is created with
            the given *res*.
        wvln (float, optional): Wavelength [µm].  Defaults to ``0.55``.
        z (float, optional): Initial axial position [mm].  Defaults to
            ``0.0``.
        phy_size (tuple, optional): Physical aperture (W, H) [mm].
            Defaults to ``(4.0, 4.0)``.
        res (tuple, optional): Grid resolution (H, W) [pixels].  Only
            used when *u* is ``None``.  Defaults to ``(2000, 2000)``.

    Raises:
        AssertionError: If the pixel pitch is not square or the
            wavelength is outside the range ``(0.1, 10)`` µm.
    """
    if u is not None:
        if not u.dtype == torch.complex128:
            print(
                "A complex wave field is created with single precision. In the future, we want to always use double precision."
            )

        self.u = u if torch.is_tensor(u) else torch.from_numpy(u)
        if not self.u.is_complex():
            self.u = self.u.to(torch.complex64)

        # [H, W] or [1, H, W] to [1, 1, H, W]
        # Reshape self.u (not the raw input) so the complex conversion above is kept
        if len(self.u.shape) == 2:
            self.u = self.u.unsqueeze(0).unsqueeze(0)
        elif len(self.u.shape) == 3:
            self.u = self.u.unsqueeze(0)

        self.res = self.u.shape[-2:]

    else:
        # Initialize a zero complex wave field
        amp = torch.zeros(res).unsqueeze(0).unsqueeze(0)
        phi = torch.zeros(res).unsqueeze(0).unsqueeze(0)
        self.u = amp + 1j * phi
        self.res = res

    # Wave field parameters
    assert wvln > 0.1 and wvln < 10.0, "Wavelength should be in [um]."
    self.wvln = wvln  # [um], wavelength
    self.k = 2 * torch.pi / (self.wvln * 1e-3)  # [mm^-1], wave number
    self.phy_size = phy_size  # [mm], physical size
    assert phy_size[0] / self.res[0] == phy_size[1] / self.res[1], (
        "Pixel size is not square."
    )
    self.ps = phy_size[0] / self.res[0]  # [mm], pixel size

    # Wave field grid
    self.x, self.y = self.gen_xy_grid()  # x, y grid
    self.z = torch.full_like(self.x, z)  # z grid

    # Cache propagation method boundaries (depend only on wvln, ps, phy_size)
    self._asm_zmax = Nyquist_ASM_zmax(wvln=self.wvln, ps=self.ps, side_length=self.phy_size[0])
    self._fresnel_zmin = Fresnel_zmin(wvln=self.wvln, ps=self.ps, side_length=self.phy_size[0])

point_wave classmethod

point_wave(point=(0, 0, -1000.0), wvln=0.55, z=0.0, phy_size=(4.0, 4.0), res=(2000, 2000), valid_r=None)

Create a spherical wave field on x0y plane originating from a point source.

Parameters:

Name Type Description Default
point tuple

Point source position in object space. [mm]. Defaults to (0, 0, -1000.0).

(0, 0, -1000.0)
wvln float

Wavelength. [um]. Defaults to 0.55.

0.55
z float

Field z position. [mm]. Defaults to 0.0.

0.0
phy_size tuple

Valid plane on x0y plane. [mm]. Defaults to (4.0, 4.0).

(4.0, 4.0)
res tuple

Valid plane resolution. Defaults to (2000, 2000).

(2000, 2000)
valid_r float

Valid circle radius. [mm]. Defaults to None.

None

Returns:

Name Type Description
field ComplexWave

Complex field on x0y plane.

Source code in src/light/wave.py
@classmethod
def point_wave(
    cls,
    point=(0, 0, -1000.0),
    wvln=0.55,
    z=0.0,
    phy_size=(4.0, 4.0),
    res=(2000, 2000),
    valid_r=None,
):
    """Create a spherical wave field on x0y plane originating from a point source.

    Args:
        point (tuple): Point source position in object space. [mm]. Defaults to (0, 0, -1000.0).
        wvln (float): Wavelength. [um]. Defaults to 0.55.
        z (float): Field z position. [mm]. Defaults to 0.0.
        phy_size (tuple): Valid plane on x0y plane. [mm]. Defaults to (4.0, 4.0).
        res (tuple): Valid plane resolution. Defaults to (2000, 2000).
        valid_r (float): Valid circle radius. [mm]. Defaults to None.

    Returns:
        field (ComplexWave): Complex field on x0y plane.
    """
    assert wvln > 0.1 and wvln < 10.0, "Wavelength should be in [um]."
    k = 2 * torch.pi / (wvln * 1e-3)  # [mm^-1], wave number

    # Create meshgrid on target plane
    x, y = torch.meshgrid(
        torch.linspace(
            -0.5 * phy_size[0], 0.5 * phy_size[0], res[0], dtype=torch.float64
        ),
        torch.linspace(
            0.5 * phy_size[1], -0.5 * phy_size[1], res[1], dtype=torch.float64
        ),
        indexing="xy",
    )

    # Calculate distance to point source, and calculate spherical wave phase
    r = torch.sqrt((x - point[0]) ** 2 + (y - point[1]) ** 2 + (z - point[2]) ** 2)
    if point[2] < z:
        phi = k * r
    else:
        phi = -k * r
    u = (r.min() / r) * torch.exp(1j * phi)

    # Apply valid circle if provided, e.g., the aperture of a lens
    if valid_r is not None:
        mask = (x - point[0]) ** 2 + (y - point[1]) ** 2 < valid_r**2
        u = u * mask

    # Create wave field
    return cls(u=u, wvln=wvln, phy_size=phy_size, res=res, z=z)
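
Each sample of the spherical wave has phase `±k·r` and an amplitude that falls off as `1/r`, normalized by the smallest distance on the grid. The sketch below evaluates one sample at a time in plain Python. The function `spherical_sample` and its `r_ref` argument (standing in for `r.min()`) are assumptions made for illustration.

```python
import cmath
import math

def spherical_sample(x, y, z, point, wvln_um, r_ref):
    """Complex amplitude at (x, y, z) of a spherical wave from a point source."""
    k = 2 * math.pi / (wvln_um * 1e-3)            # wave number [mm^-1]
    r = math.sqrt(
        (x - point[0]) ** 2 + (y - point[1]) ** 2 + (z - point[2]) ** 2
    )
    # Diverging wave if the source lies before the plane, converging otherwise
    phi = k * r if point[2] < z else -k * r
    return (r_ref / r) * cmath.exp(1j * phi)

point = (0.0, 0.0, -1000.0)
u_axis = spherical_sample(0.0, 0.0, 0.0, point, 0.55, r_ref=1000.0)
u_edge = spherical_sample(2.0, 0.0, 0.0, point, 0.55, r_ref=1000.0)
```

The on-axis sample has unit magnitude because it is the closest point to the source, and every off-axis sample is slightly weaker.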

plane_wave classmethod

plane_wave(wvln=0.55, z=0.0, phy_size=(4.0, 4.0), res=(2000, 2000), valid_r=None)

Create a planar wave field on x0y plane.

Parameters:

Name Type Description Default
wvln float

Wavelength. [um].

0.55
z float

Field z position. [mm].

0.0
phy_size tuple

Physical size of the field. [mm].

(4.0, 4.0)
res tuple

Resolution.

(2000, 2000)
valid_r float

Valid circle radius. [mm].

None

Returns:

Name Type Description
field ComplexWave

Complex field.

Source code in src/light/wave.py
@classmethod
def plane_wave(
    cls,
    wvln=0.55,
    z=0.0,
    phy_size=(4.0, 4.0),
    res=(2000, 2000),
    valid_r=None,
):
    """Create a planar wave field on x0y plane.

    Args:
        wvln (float): Wavelength. [um].
        z (float): Field z position. [mm].
        phy_size (tuple): Physical size of the field. [mm].
        res (tuple): Resolution.
        valid_r (float): Valid circle radius. [mm].

    Returns:
        field (ComplexWave): Complex field.
    """
    assert wvln > 0.1 and wvln < 10.0, "Wavelength should be in [um]."

    # Create a plane wave field
    u = torch.ones(res, dtype=torch.float64) + 0j

    # Apply valid circle if provided
    if valid_r is not None:
        x, y = torch.meshgrid(
            torch.linspace(-0.5 * phy_size[0], 0.5 * phy_size[0], res[0]),
            torch.linspace(-0.5 * phy_size[1], 0.5 * phy_size[1], res[1]),
            indexing="xy",
        )
        mask = (x**2 + y**2) < valid_r**2
        u = u * mask

    # Create wave field
    return cls(u=u, phy_size=phy_size, wvln=wvln, res=res, z=z)
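
The `valid_r` option multiplies the unit-amplitude field by a circular mask centered on the optical axis. The sketch below builds such a mask on a tiny 5 × 5 grid in plain Python so the result can be inspected by eye. The `linspace` and `circular_mask` helpers are hypothetical stand-ins for the tensor operations above.

```python
def linspace(start, stop, num):
    """Evenly spaced samples including both endpoints (num >= 2)."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]

def circular_mask(phy_size, res, valid_r):
    """1.0 inside the circle x^2 + y^2 < valid_r^2, 0.0 outside."""
    xs = linspace(-0.5 * phy_size[0], 0.5 * phy_size[0], res[0])
    ys = linspace(-0.5 * phy_size[1], 0.5 * phy_size[1], res[1])
    return [
        [1.0 if x * x + y * y < valid_r ** 2 else 0.0 for x in xs]
        for y in ys
    ]

mask = circular_mask((4.0, 4.0), (5, 5), valid_r=1.5)
n_open = sum(sum(row) for row in mask)   # number of samples inside the aperture
```

On this coarse grid the samples sit at -2, -1, 0, 1, 2 mm, so only the central 3 × 3 block lies inside a 1.5 mm radius.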

image_wave classmethod

image_wave(img, wvln=0.55, z=0.0, phy_size=(4.0, 4.0))

Initialize a complex wave field from an image.

Parameters:

Name Type Description Default
img Tensor

Input image with shape [H, W] or [B, C, H, W]. Data range is [0, 1].

required
wvln float

Wavelength. [um].

0.55
z float

Field z position. [mm].

0.0
phy_size tuple

Physical size of the field. [mm].

(4.0, 4.0)

Returns:

Name Type Description
field ComplexWave

Complex field.

Source code in src/light/wave.py
@classmethod
def image_wave(cls, img, wvln=0.55, z=0.0, phy_size=(4.0, 4.0)):
    """Initialize a complex wave field from an image.

    Args:
        img (torch.Tensor): Input image with shape [H, W] or [B, C, H, W]. Data range is [0, 1].
        wvln (float): Wavelength. [um].
        z (float): Field z position. [mm].
        phy_size (tuple): Physical size of the field. [mm].

    Returns:
        field (ComplexWave): Complex field.
    """
    assert img.dtype == torch.float32, "Image must be float32."

    amp = torch.sqrt(img)
    phi = torch.zeros_like(amp)
    u = amp + 1j * phi

    return cls(u=u, wvln=wvln, phy_size=phy_size, res=u.shape[-2:], z=z)
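
The image is interpreted as intensity, so the field amplitude is its square root and the phase is zero. The plain-Python sketch below confirms that squaring the field magnitude recovers the input. The function `image_to_field` is a hypothetical stand-in that operates on nested lists instead of tensors.

```python
import math

def image_to_field(img):
    """Zero-phase complex field whose intensity |u|^2 equals the image."""
    return [[complex(math.sqrt(v), 0.0) for v in row] for row in img]

img = [[0.0, 0.25], [0.5, 1.0]]     # intensities in [0, 1]
u = image_to_field(img)
recovered = [[abs(c) ** 2 for c in row] for row in u]
```

A pixel value of 0.25 therefore maps to an amplitude of 0.5, not 0.25.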

prop

prop(prop_dist, n=1.0)

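Propagate the field by the distance prop_dist. Only planar wave fields can be propagated. The method is chosen from the distance: the Angular Spectrum Method below the cached ASM limit, and Fresnel diffraction above the cached Fresnel limit.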

Reference

[1] Modeling and propagation of near-field diffraction patterns: A more complete approach. Table 1.
[2] https://github.com/kaanaksit/odak/blob/master/odak/wave/classical.py
[3] https://spie.org/samples/PM103.pdf
[4] "Non-approximated Rayleigh Sommerfeld diffraction integral: advantages and disadvantages in the propagation of complex wave fields"

Parameters:

Name Type Description Default
prop_dist float

propagation distance, unit [mm].

required
n float

refractive index.

1.0

Returns:

Name Type Description
self

propagated complex wave field.

Source code in src/light/wave.py
def prop(self, prop_dist, n=1.0):
    """Propagate the field by distance z. Can only propagate planar wave.

    Reference:
        [1] Modeling and propagation of near-field diffraction patterns: A more complete approach. Table 1.
        [2] https://github.com/kaanaksit/odak/blob/master/odak/wave/classical.py
        [3] https://spie.org/samples/PM103.pdf
        [4] "Non-approximated Rayleigh Sommerfeld diffraction integral: advantages and disadvantages in the propagation of complex wave fields"

    Args:
        prop_dist (float): propagation distance, unit [mm].
        n (float): refractive index.

    Returns:
        self: propagated complex wave field.
    """
    # Determine propagation method using cached boundaries
    wvln_mm = self.wvln * 1e-3  # [um] to [mm]

    # Wave propagation methods
    if prop_dist < DELTA:
        # Zero distance: do nothing
        pass

    elif prop_dist < wvln_mm:
        # Sub-wavelength distance: full wave method (e.g., FDTD)
        raise Exception(
            "The propagation distance in sub-wavelength range is not implemented yet. Have to use full wave method (e.g., FDTD)."
        )

    elif prop_dist < self._asm_zmax:
        # Angular Spectrum Method (ASM)
        self.u = AngularSpectrumMethod(self.u, z=prop_dist, wvln=self.wvln, ps=self.ps, n=n)

    elif prop_dist > self._fresnel_zmin:
        # Fresnel diffraction
        self.u = FresnelDiffraction(self.u, z=prop_dist, wvln=self.wvln, ps=self.ps, n=n)

    else:
        raise Exception(f"Propagation method not implemented for distance {prop_dist} mm.")

    # Update z grid
    self.z += prop_dist
    return self
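
The branch structure above can be read as a small decision function over the propagation distance. The sketch below isolates that logic in plain Python. The function name, the returned labels, the default `delta`, and the example limits of 10 mm and 50 mm are all assumptions; the real limits come from `Nyquist_ASM_zmax` and `Fresnel_zmin`.

```python
def select_prop_method(prop_dist, wvln_um, asm_zmax, fresnel_zmin, delta=1e-6):
    """Return which propagation branch applies to prop_dist (all lengths in mm)."""
    wvln_mm = wvln_um * 1e-3
    if prop_dist < delta:
        return "none"                  # zero (or negative) distance: field untouched
    if prop_dist < wvln_mm:
        raise NotImplementedError("Sub-wavelength distance needs a full-wave method.")
    if prop_dist < asm_zmax:
        return "asm"                   # Angular Spectrum Method
    if prop_dist > fresnel_zmin:
        return "fresnel"               # Fresnel diffraction
    raise NotImplementedError(f"No method for distance {prop_dist} mm.")

near = select_prop_method(5.0, 0.55, asm_zmax=10.0, fresnel_zmin=50.0)
far = select_prop_method(100.0, 0.55, asm_zmax=10.0, fresnel_zmin=50.0)
```

Two consequences are worth noting. A distance between the two limits raises an error, and, as the branches are written, a negative distance takes the first branch, so the field is left unchanged.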

prop_to

prop_to(z, n=1)

Propagate the field to plane z.

Parameters:

Name Type Description Default
z float

Destination plane z-coordinate [mm].

required
Source code in src/light/wave.py
def prop_to(self, z, n=1):
    """Propagate the field to plane z.

    Args:
        z (float): destination plane z coordinate.
    """
    # Use float() instead of .item() to avoid GPU-CPU sync on CUDA tensors
    # (self.z is a full grid but all values are identical; [0,0] is representative)
    prop_dist = float(z) - float(self.z[0, 0])
    self.prop(prop_dist, n=n)
    return self

gen_xy_grid

gen_xy_grid()

Generate the x and y grid.

Source code in src/light/wave.py
def gen_xy_grid(self):
    """Generate the x and y grid."""
    x, y = torch.meshgrid(
        torch.linspace(-0.5 * self.phy_size[1], 0.5 * self.phy_size[1], self.res[0],),
        torch.linspace(0.5 * self.phy_size[0], -0.5 * self.phy_size[0], self.res[1],),
        indexing="xy",
    )
    return x, y

gen_freq_grid

gen_freq_grid()

Generate the frequency grid.

Source code in src/light/wave.py
def gen_freq_grid(self):
    """Generate the frequency grid."""
    x, y = self.gen_xy_grid()
    fx = x / (self.ps * self.phy_size[0])
    fy = y / (self.ps * self.phy_size[1])
    return fx, fy

load_npz

load_npz(filepath)

Load data from npz file.

Source code in src/light/wave.py
def load_npz(self, filepath):
    """Load data from npz file."""
    data = np.load(filepath)
    self.u = torch.from_numpy(data["u"])
    self.x = torch.from_numpy(data["x"])
    self.y = torch.from_numpy(data["y"])
    self.wvln = data["wvln"].item()
    self.phy_size = data["phy_size"].tolist()
    self.res = self.u.shape[-2:]

save

save(filepath='./wavefield.npz')

Save the complex wave field to a npz file.

Source code in src/light/wave.py
def save(self, filepath="./wavefield.npz"):
    """Save the complex wave field to a npz file."""
    if filepath.endswith(".npz"):
        self.save_npz(filepath)
    else:
        raise Exception("Unimplemented file format.")

save_npz

save_npz(filepath='./wavefield.npz')

Save the complex wave field to a npz file.

Source code in src/light/wave.py
def save_npz(self, filepath="./wavefield.npz"):
    """Save the complex wave field to a npz file."""
    # Save data
    np.savez_compressed(
        filepath,
        u=self.u.cpu().numpy(),
        x=self.x.cpu().numpy(),
        y=self.y.cpu().numpy(),
        wvln=np.array(self.wvln),
        phy_size=np.array(self.phy_size),
    )

    # Save intensity, amplitude, and phase images
    u = self.u.cpu()
    save_image(u.abs() ** 2, f"{filepath[:-4]}_intensity.png", normalize=True)
    save_image(u.abs(), f"{filepath[:-4]}_amp.png", normalize=True)
    save_image(u.angle(), f"{filepath[:-4]}_phase.png", normalize=True)

show

show(save_name=None, data='irr')

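Display the field as an image, or save it to save_name when one is given. The data argument selects what is shown: "irr" (irradiance), "amp", "phi" or "phase", "real", or "imag".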

Source code in src/light/wave.py
def show(self, save_name=None, data="irr"):
    """Save the field as an image."""
    cmap = "gray"
    if data == "irr":
        value = self.u.detach().abs() ** 2
    elif data == "amp":
        value = self.u.detach().abs()
    elif data == "phi" or data == "phase":
        value = torch.angle(self.u).detach()
        cmap = "hsv"
    elif data == "real":
        value = self.u.real.detach()
    elif data == "imag":
        value = self.u.imag.detach()
    else:
        raise Exception(f"Unimplemented visualization: {data}.")

    if len(self.u.shape) == 2:
        raise Exception("Deprecated.")
        if save_name is not None:
            save_image(value, save_name, normalize=True)
        else:
            value = value.cpu().numpy()
            plt.imshow(
                value,
                cmap=cmap,
                extent=[
                    -self.phy_size[0] / 2,
                    self.phy_size[0] / 2,
                    -self.phy_size[1] / 2,
                    self.phy_size[1] / 2,
                ],
            )

    elif len(self.u.shape) == 4:
        B, C, H, W = self.u.shape
        if B == 1:
            if save_name is not None:
                save_image(value, save_name, normalize=True)
            else:
                value = value.cpu().numpy()
                plt.imshow(
                    value[0, 0, :, :],
                    cmap=cmap,
                    extent=[
                        -self.phy_size[0] / 2,
                        self.phy_size[0] / 2,
                        -self.phy_size[1] / 2,
                        self.phy_size[1] / 2,
                    ],
                )
        else:
            if save_name is not None:
                plt.savefig(save_name)
            else:
                value = value.cpu().numpy()
                fig, axs = plt.subplots(1, B)
                for i in range(B):
                    axs[i].imshow(
                        value[i, 0, :, :],
                        cmap=cmap,
                        extent=[
                            -self.phy_size[0] / 2,
                            self.phy_size[0] / 2,
                            -self.phy_size[1] / 2,
                            self.phy_size[1] / 2,
                        ],
                    )
                fig.show()
    else:
        raise Exception("Unsupported complex field shape.")

pad

pad(Hpad, Wpad)

Pad the input field by (Hpad, Hpad, Wpad, Wpad). This step will also expand physical size of the field.

Parameters:

Name Type Description Default
Hpad int

Number of pixels to pad on the top and bottom.

required
Wpad int

Number of pixels to pad on the left and right.

required

Returns:

Name Type Description
self

Padded complex wave field.

Source code in src/light/wave.py
def pad(self, Hpad, Wpad):
    """Pad the input field by (Hpad, Hpad, Wpad, Wpad). This step will also expand physical size of the field.

    Args:
        Hpad (int): Number of pixels to pad on the top and bottom.
        Wpad (int): Number of pixels to pad on the left and right.

    Returns:
        self: Padded complex wave field.
    """
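    # F.pad orders the tuple as (left, right, top, bottom) for the last two dims,
    # so the W padding comes first and the H padding second.
    self.u = F.pad(self.u, (Wpad, Wpad, Hpad, Hpad), mode="constant", value=0)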

    Horg, Worg = self.res
    self.res = [Horg + 2 * Hpad, Worg + 2 * Wpad]
    self.phy_size = [
        self.phy_size[0] * self.res[0] / Horg,
        self.phy_size[1] * self.res[1] / Worg,
    ]
    self.x, self.y = self.gen_xy_grid()
    self.z = torch.full_like(self.x, float(self.z[0, 0]))
    return self
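
Padding enlarges the grid while keeping the pixel pitch fixed, which is why the physical size is rescaled by the ratio of new to old resolution. The plain-Python sketch below shows only that bookkeeping. The function `padded_geometry` is hypothetical and does not touch any field data.

```python
def padded_geometry(res, phy_size, Hpad, Wpad):
    """Resolution and physical size after padding Hpad/Wpad pixels on each side."""
    H, W = res
    new_res = (H + 2 * Hpad, W + 2 * Wpad)
    new_phy = (
        phy_size[0] * new_res[0] / H,   # scale each extent by the growth ratio
        phy_size[1] * new_res[1] / W,
    )
    return new_res, new_phy

new_res, new_phy = padded_geometry((2000, 2000), (4.0, 4.0), Hpad=500, Wpad=500)
pitch_before = 4.0 / 2000
pitch_after = new_phy[0] / new_res[0]
```

Padding a 2000 × 2000 field of 4 mm by 500 pixels per side gives a 3000 × 3000 field of 6 mm, with the pitch unchanged at 0.002 mm.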

flip

flip()

Flip the field horizontally and vertically.

Source code in src/light/wave.py
def flip(self):
    """Flip the field horizontally and vertically."""
    self.u = torch.flip(self.u, [-1, -2])
    self.x = torch.flip(self.x, [-1, -2])
    self.y = torch.flip(self.y, [-1, -2])
    self.z = torch.flip(self.z, [-1, -2])
    return self

Propagation Methods

| Class | Method | Use Case |
| --- | --- | --- |
| AngularSpectrumMethod | Fourier-domain propagation | Near-field, high-NA |
| ScalableASM | Scalable ASM with chirp-z transform | Large propagation distances |
| FresnelDiffraction | Fresnel approximation | Moderate distances |
| FraunhoferDiffraction | Fraunhofer (far-field) | Far-field patterns |
| RayleighSommerfeld | Rayleigh-Sommerfeld integral | Reference method |

Utility Functions

| Function | Description |
| --- | --- |
| Nyquist_ASM_zmax(wvln, ps, side_length) | Maximum valid propagation distance for ASM |
| Fresnel_zmin(wvln, ps, side_length) | Minimum valid distance for Fresnel approximation |

Default Wavelengths

Defined in src/config.py:

| Constant | Value | Description |
| --- | --- | --- |
| DEFAULT_WAVE | 0.587 µm | d-line (green) |
| WAVE_RGB | [0.656, 0.587, 0.486] µm | C-line, d-line, F-line (R, G, B) |
| DEPTH | -20000.0 mm | Approximate optical infinity |