Skip to content

Isotropic Solver

The standard 2×2 transfer matrix method for isotropic multi-layer films. This is the fastest solver in DiffTMM (~190× faster than NumPy TMM) and the right default whenever every layer is isotropic. It computes the complex transmission and reflection coefficients (ts, tp, rs, rp) with full phase and supports bidirectional propagation: angles in [0, π/2] propagate forward (top → bottom), and angles in (π/2, π] propagate in reverse (bottom → top).

difftmm.IsotropicFilmSolver

IsotropicFilmSolver(mat_in, mat_out, mat_ls, thickness_ls=None, thickness_min=0.0, thickness_max=0.2, batch_size=1, sigmoid_param=False, device=torch.device('cuda'))

Multi-layer coating film solver for isotropic materials.

Uses the standard 2x2 transfer matrix method which is much faster than the general 4x4 anisotropic formulation. This solver calculates (ts, tp, rs, rp) with phase shifts using rigorous electromagnetic wave propagation through multi-layer coating stacks.

Initialize the isotropic film solver.

Parameters:

| Name | Description | Default |
|------|-------------|---------|
| `mat_in` | Refractive index of the incident medium: float, complex, or str material name. | *required* |
| `mat_out` | Refractive index of the exit medium. Same types as `mat_in`. | *required* |
| `mat_ls` | Refractive indices of the interior layers: list of float/complex scalars or str material names. | *required* |
| `thickness_ls` | Thicknesses of the interior layers in µm, list or tensor of length N. If None, thicknesses are initialized randomly. | `None` |
| `thickness_min` | Minimum layer thickness in µm. | `0.0` |
| `thickness_max` | Maximum layer thickness in µm. | `0.2` |
| `batch_size` | Number of film stacks in the batch dimension. | `1` |
| `sigmoid_param` | If True, use sigmoid parameterization for thickness. | `False` |
| `device` | PyTorch device. | `device('cuda')` |
Source code in difftmm-src/difftmm/film_solver_isotropic.py
def __init__(
    self,
    mat_in,
    mat_out,
    mat_ls,
    thickness_ls=None,
    thickness_min=0.0,
    thickness_max=0.2,
    batch_size=1,
    sigmoid_param=False,
    device=torch.device("cuda"),
):
    """
    Initialize the isotropic film solver.

    Args:
        mat_in: Refractive index of incident medium. float, complex, or str material name.
        mat_out: Refractive index of exit medium. Same types as mat_in.
        mat_ls: Refractive indices of interior layers. List of float/complex
                  scalars or str material names.
        thickness_ls: Thicknesses of interior layers in um. Either a list /
                      1-D tensor of length N (shared by every stack in the
                      batch) or a 2-D tensor of shape (batch_size, N) with one
                      row per stack. Values are clamped to
                      [thickness_min, thickness_max]. If None, the normalized
                      parameters are initialized to 0.5 plus small Gaussian
                      noise (std 0.01), i.e. near the middle of the range.
        thickness_min: Minimum layer thickness in um.
        thickness_max: Maximum layer thickness in um.
        batch_size: Number of film stacks in the batch dimension.
        sigmoid_param: If True, use sigmoid parameterization for thickness.
        device: PyTorch device.

    Raises:
        ValueError: if thickness_ls does not have shape (N,) or
            (batch_size, N), where N is the number of layer materials.
    """
    self.batch_size = batch_size
    self.device = device

    self.mat_in = Material(mat_in, device=device)
    self.mat_out = Material(mat_out, device=device)
    self.mat_ls = [Material(s, device=device) for s in mat_ls]
    self.num_layers = len(self.mat_ls)

    self.thickness_min = thickness_min
    self.thickness_max = thickness_max
    self._thickness_range = self.thickness_max - self.thickness_min

    self.sigmoid_param = sigmoid_param
    if thickness_ls is not None:
        if not torch.is_tensor(thickness_ls):
            thickness_ls = torch.tensor(thickness_ls, dtype=torch.float32)
        # Reject thickness lists that do not line up with mat_ls up front; a
        # mismatched width would otherwise only surface (or silently
        # broadcast) later inside the TMM.
        valid_shapes = ((self.num_layers,), (batch_size, self.num_layers))
        if tuple(thickness_ls.shape) not in valid_shapes:
            raise ValueError(
                f"thickness_ls must have shape ({self.num_layers},) or "
                f"({batch_size}, {self.num_layers}) to match mat_ls; "
                f"got {tuple(thickness_ls.shape)}"
            )
        # Map physical thickness into the normalized [0, 1] parameter space.
        normalized = (
            thickness_ls.clamp(self.thickness_min, self.thickness_max)
            - self.thickness_min
        ) / self._thickness_range
        # expand() replicates a shared 1-D stack across the batch (and is a
        # no-op for a per-stack 2-D tensor); clone() gives film_params its
        # own storage instead of an expanded view.
        self.film_params = normalized.expand(batch_size, -1).clone()
    else:
        self.film_params = torch.randn(batch_size, self.num_layers) * 0.01 + 0.5

    if self.sigmoid_param:
        # Store logits; clamp first so the inverse sigmoid stays finite.
        self.film_params = inv_sigmoid(self.film_params.clamp(1e-6, 1 - 1e-6))

    self.to(device)

to

to(device)

Move tensors to specified device.

Source code in difftmm-src/difftmm/film_solver_isotropic.py
def to(self, device):
    """Relocate the solver's tensors onto ``device``.

    Moves ``film_params`` and asks every material (incident, exit, then each
    interior layer, in that order) to move itself.

    Args:
        device: target PyTorch device.

    Returns:
        The solver itself, so calls can be chained.
    """
    self.device = device
    self.film_params = self.film_params.to(device, non_blocking=True)
    every_material = (self.mat_in, self.mat_out, *self.mat_ls)
    for material in every_material:
        material.to(device)
    return self

load_ckpt

load_ckpt(ckpt_path)

Load thicknesses (and spec metadata) from a checkpoint.

Source code in difftmm-src/difftmm/film_solver_isotropic.py
def load_ckpt(self, ckpt_path):
    """Load thicknesses (and spec metadata) from a checkpoint.

    The checkpoint is a dict such as the one written by ``save_ckpt``.
    ``film_thickness`` (in um) is clamped to [thickness_min, thickness_max]
    and converted back into the solver's internal parameterization (sigmoid
    logits when ``sigmoid_param`` is set, normalized [0, 1] values otherwise).
    A 1-D ``film_thickness`` is treated as one stack shared by the whole
    current batch. If the checkpoint also carries ``mat_ls``, the incident,
    exit and layer materials are rebuilt from ``mat_in`` / ``mat_out`` /
    ``mat_ls``.

    ``batch_size`` and ``num_layers`` are updated from the loaded thicknesses
    so they stay consistent with ``film_params``.

    Args:
        ckpt_path: path (or file-like object) accepted by ``torch.load``.

    Raises:
        ValueError: if ``film_thickness`` is not 1-D or 2-D, or its layer
            count differs from the number of layer materials in effect after
            loading. Solver state is left untouched in that case.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects; only load checkpoints from trusted sources.
    ckpt = torch.load(ckpt_path, map_location=self.device, weights_only=False)

    film_thickness = ckpt["film_thickness"]
    if film_thickness.dim() == 1:
        # One shared stack: replicate it across the current batch.
        film_thickness = film_thickness.unsqueeze(0).expand(self.batch_size, -1)
    elif film_thickness.dim() != 2:
        raise ValueError(
            "checkpoint 'film_thickness' must be 1-D or 2-D, "
            f"got shape {tuple(film_thickness.shape)}"
        )

    # Validate against the layer materials that will be active after loading,
    # before mutating any state.
    num_materials = len(ckpt["mat_ls"]) if "mat_ls" in ckpt else len(self.mat_ls)
    if film_thickness.shape[-1] != num_materials:
        raise ValueError(
            f"checkpoint has {film_thickness.shape[-1]} layer thicknesses but "
            f"{num_materials} layer materials"
        )

    film_thickness = torch.clamp(
        film_thickness, self.thickness_min, self.thickness_max
    )
    # Same normalization constant get_film_thickness() uses to decode.
    film_thickness_normalized = (
        film_thickness - self.thickness_min
    ) / self._thickness_range
    if self.sigmoid_param:
        # Keep away from exactly 0/1 so the inverse sigmoid (presumably a
        # logit) stays finite.
        film_thickness_normalized = torch.clamp(
            film_thickness_normalized, 1e-6, 1 - 1e-6
        )
        self.film_params = inv_sigmoid(film_thickness_normalized).to(self.device)
    else:
        self.film_params = film_thickness_normalized.to(self.device)

    self.batch_size, self.num_layers = film_thickness.shape

    if "mat_ls" in ckpt:
        self.mat_in = Material(ckpt["mat_in"], device=self.device)
        self.mat_out = Material(ckpt["mat_out"], device=self.device)
        self.mat_ls = [Material(v, device=self.device) for v in ckpt["mat_ls"]]

save_ckpt

save_ckpt(save_path)

Save thicknesses and material specs to a checkpoint.

Material objects are persisted by name; scalars are persisted by value. Per-axis 3-tuples are persisted element-wise.

Source code in difftmm-src/difftmm/film_solver_isotropic.py
def save_ckpt(self, save_path):
    """Save thicknesses and material specs to a checkpoint.

    The payload is a plain dict written with ``torch.save``:

    - ``film_thickness``: physical thicknesses in um from
      ``get_film_thickness()``, detached from the autograd graph and moved
      to CPU. The checkpoint therefore never carries ``requires_grad`` or
      depends on the training device.
    - ``batch_size`` / ``num_layers``: the solver's batch and layer counts.
    - ``mat_in`` / ``mat_out`` / ``mat_ls``: each material's ``name``
      attribute, which ``load_ckpt`` passes back to ``Material``. How scalar
      or per-axis indices are encoded in ``name`` is defined by ``Material``.

    Args:
        save_path: destination path (or file-like object) for ``torch.save``.
    """
    payload = {
        # detach(): thicknesses derived from trainable film_params would
        # otherwise be saved with requires_grad=True.
        "film_thickness": self.get_film_thickness().detach().cpu(),
        "batch_size": self.batch_size,
        "num_layers": self.num_layers,
        "mat_in": self.mat_in.name,
        "mat_out": self.mat_out.name,
        "mat_ls": [m.name for m in self.mat_ls],
    }
    torch.save(payload, save_path)

get_film_thickness

get_film_thickness()

Convert optimization-friendly film parameters to real film thickness.

Returns:

| Name | Description |
|------|-------------|
| `film_thickness` | Tensor of shape (batch_size, num_layers), in µm. |

Source code in difftmm-src/difftmm/film_solver_isotropic.py
def get_film_thickness(self):
    """Decode the optimizer-facing parameters into physical layer thicknesses.

    With ``sigmoid_param`` the parameters are unbounded values squashed into
    [0, 1] by a sigmoid before rescaling. Otherwise they are already
    normalized values, rescaled linearly and then clamped so the result never
    leaves [thickness_min, thickness_max].

    Returns:
        film_thickness: tensor of shape (batch_size, num_layers), in [um].
    """
    if not self.sigmoid_param:
        rescaled = self.film_params * self._thickness_range + self.thickness_min
        return rescaled.clamp(self.thickness_min, self.thickness_max)

    unit_interval = torch.sigmoid(self.film_params)
    return unit_interval * self._thickness_range + self.thickness_min

simulate

simulate(theta, wvln)

Calculate (ts, tp, rs, rp) using TMM for specified angles and wavelengths.

Parameters:

| Name | Description | Default |
|------|-------------|---------|
| `theta` | Incident angles in radians. Either a 1D tensor of shape (n_angles,), giving the same angles for all film stacks, or a 2D tensor of shape (batch_size, n_angles), giving different angles per film stack. | *required* |
| `wvln` | Wavelengths in micrometers. Either a list or 1D tensor of shape (n_wvlns,), or a scalar / 0D tensor for a single wavelength. | *required* |

Returns:

| Description |
|-------------|
| `ts, tp, rs, rp`: complex transmission/reflection coefficients, each of shape (batch_size, n_wvlns, n_angles). |

Source code in difftmm-src/difftmm/film_solver_isotropic.py
def simulate(self, theta, wvln):
    """Run the 2x2 isotropic TMM over every film stack, wavelength and angle.

    The raw angle/wavelength inputs are normalized and batched by
    ``_prepare_simulate_inputs`` together with the current thicknesses and
    material indices. The result is then handed to
    ``create_jones_matrix_isotropic``.

    Args:
        theta: Incident angles in radians. Can be:
               - 1D tensor of shape (n_angles,): same angles for all film stacks
               - 2D tensor of shape (batch_size, n_angles): different angles per film stack
        wvln: Wavelengths in micrometers. Can be:
              - List or 1D tensor of shape (n_wvlns,)
              - Scalar or 0D tensor: single wavelength

    Returns:
        ts, tp, rs, rp: Complex transmission/reflection coefficients.
                       Shape: (batch_size, n_wvlns, n_angles)
    """
    (
        theta_batched,
        wvln_batched,
        thickness_batched,
        n_top,
        n_bottom,
        n_stack,
    ) = self._prepare_simulate_inputs(theta, wvln)

    ts, tp, rs, rp = create_jones_matrix_isotropic(
        n_layers_1d=n_stack,
        d_1d=thickness_batched,
        wv_1d=wvln_batched,
        n_in=n_top,
        n_out=n_bottom,
        theta_1d=theta_batched,
    )
    return ts, tp, rs, rp

__call__

__call__(theta, wvln)

Forward pass using simulate.

Source code in difftmm-src/difftmm/film_solver_isotropic.py
def __call__(self, theta, wvln):
    """Make the solver callable: ``solver(theta, wvln)`` runs ``solver.simulate(theta, wvln)``.

    Returns exactly what ``simulate`` returns.
    """
    run = self.simulate
    return run(theta, wvln)

Functional API

The solver is a thin wrapper around a pure function. Use create_jones_matrix_isotropic directly when you want to differentiate through the film stack without holding solver state — for example, when the thicknesses are an external torch.nn.Parameter in an inverse-design loop.

difftmm.create_jones_matrix_isotropic

create_jones_matrix_isotropic(n_layers_1d, d_1d, wv_1d, n_in, n_out, theta_1d)

Fast Jones matrix calculation for isotropic multi-layer films.

Uses the standard 2x2 transfer matrix method which is much faster than the general 4x4 anisotropic formulation when materials are isotropic. Avoids eigenvalue decomposition entirely.

Supports bidirectional propagation:

- theta in [0, pi/2]: forward direction (top to bottom, n_in -> layers -> n_out).
- theta in (pi/2, pi]: reverse direction (bottom to top, n_out -> layers -> n_in). Internally, this is converted to the equivalent forward problem with swapped media and reversed layers.

Parameters:

| Name | Description | Default |
|------|-------------|---------|
| `n_layers_1d` | Refractive index of each layer. Shape (batchsize, n_layer) for wavelength-independent indices, or (batchsize, n_wls, n_layer) for indices dispersive per wavelength. Complex. | *required* |
| `d_1d` | Thicknesses of all layers, shape (batchsize, n_layer). Real or complex. | *required* |
| `wv_1d` | Wavelengths of the simulations, shape (batchsize, n_wls). Real. | *required* |
| `n_in` | Refractive index of the incident (top) medium. A scalar (Python float/complex or 0-d tensor), or a tensor of shape (batchsize, n_wls) for a per-wavelength dispersive incident medium. | *required* |
| `n_out` | Refractive index of the exit (bottom) medium. A scalar (Python float/complex or 0-d tensor), or a tensor of shape (batchsize, n_wls) for a per-wavelength dispersive exit medium. | *required* |
| `theta_1d` | Incident angles, shape (batchsize, n_angles). Real, in [0, pi]. Angles > pi/2 represent reverse propagation. | *required* |

Returns:

| Description |
|-------------|
| `ts, tp, rs, rp`: complex transmission/reflection coefficients, each of shape (batchsize, n_wls, n_angles). |

Source code in difftmm-src/difftmm/film_solver_isotropic.py
def create_jones_matrix_isotropic(n_layers_1d, d_1d, wv_1d, n_in, n_out, theta_1d):
    """
    Fast Jones matrix calculation for isotropic multi-layer films.

    Uses the standard 2x2 transfer matrix method which is much faster
    than the general 4x4 anisotropic formulation when materials are isotropic.
    Avoids eigenvalue decomposition entirely.

    Supports bidirectional propagation:
    - theta in [0, pi/2]: Forward direction (top to bottom, n_in -> layers -> n_out)
    - theta in (pi/2, pi]: Reverse direction (bottom to top, n_out -> layers -> n_in)
      Internally converted to the equivalent forward problem at angle
      pi - theta with swapped media and reversed layer order.

    When both directions are requested, a single forward solve is reused for
    the reverse angles only when that is exact. This requires the two media to
    be equal scalars and the layer stack (indices and thicknesses) to be
    palindromic. Otherwise the reflection coefficients seen from the two sides
    of the stack differ (in phase, and in magnitude for absorbing layers), so
    both directions are solved.

    Args:
        n_layers_1d: refractive index of each layer.
                     Shape (batchsize, n_layer) — wavelength-independent, OR
                     shape (batchsize, n_wls, n_layer) — dispersive per wavelength. Complex.
        d_1d: thicknesses of all layers, shape (batchsize, n_layer). Real or Complex.
        wv_1d: wavelengths of simulations, shape (batchsize, n_wls). Real.
        n_in: incident media refractive index (top medium).
              Scalar (Python float/complex or 0-d tensor), OR tensor of shape
              (batchsize, n_wls) for per-wavelength dispersive incident medium.
        n_out: transmit media refractive index (bottom medium).
               Scalar (Python float/complex or 0-d tensor), OR tensor of shape
               (batchsize, n_wls) for per-wavelength dispersive exit medium.
        theta_1d: incident angles, shape (batchsize, n_angles). Real.
                  Range [0, pi]. Angles > pi/2 represent reverse propagation.

    Returns:
        ts, tp, rs, rp: complex transmission/reflection coefficients
                        each with shape (batchsize, n_wls, n_angles)

    Raises:
        ValueError: if n_layers_1d is neither 2-D nor 3-D.
    """
    device = n_layers_1d.device
    dtype = torch.complex64

    batchsize = d_1d.shape[0]
    num_wv = wv_1d.shape[1]
    num_angles = theta_1d.shape[1]
    num_layer = d_1d.shape[1]

    if n_layers_1d.dim() not in (2, 3):
        raise ValueError(
            f"n_layers_1d must be 2-D or 3-D, got shape {n_layers_1d.shape}"
        )

    def _to_per_wvln(x):
        """Normalize a medium index to a complex tensor broadcastable as (batch, num_wv, 1, 1)."""
        if torch.is_tensor(x) and x.dim() == 2:
            # Already (batch, num_wv)
            return x.to(dtype=dtype, device=device).unsqueeze(-1).unsqueeze(-1)
        if torch.is_tensor(x) and x.dim() == 0:
            x = x.item()
        # Scalar (Python or 0-d tensor that we just unwrapped)
        return torch.tensor(complex(x), dtype=dtype, device=device).view(1, 1, 1, 1)

    def _layers_to_4d(n_layers):
        """Normalize layer indices to a complex (batch, num_wv or 1, 1, num_layer) tensor."""
        if n_layers.dim() == 2:
            # (batch, num_layer) — wavelength-independent
            return n_layers.unsqueeze(1).unsqueeze(2).to(dtype=dtype, device=device)
        # (batch, num_wv, num_layer) — wavelength-dependent
        return n_layers.unsqueeze(2).to(dtype=dtype, device=device)

    # (batch, num_wv, 1, 1); shared by both propagation directions.
    wv = wv_1d.unsqueeze(2).unsqueeze(3).to(dtype)

    def _solve(n_layers, d, n_top, n_bottom, theta):
        """Run the forward 2x2 TMM for one media/layer ordering at every angle."""
        return _compute_isotropic_tmm(
            _layers_to_4d(n_layers),
            d.unsqueeze(1).unsqueeze(2).to(dtype),  # (batch, 1, 1, num_layer)
            wv,
            n_top,
            n_bottom,
            theta.unsqueeze(1).unsqueeze(3).to(dtype),  # (batch, 1, angles, 1)
            batchsize,
            num_wv,
            num_angles,
            num_layer,
            device,
            dtype,
        )

    n_in_t = _to_per_wvln(n_in)
    n_out_t = _to_per_wvln(n_out)

    # theta == pi/2 counts as forward.
    is_reverse = theta_1d > torch.pi / 2  # (batch, angles)
    # Fold every angle into [0, pi/2]: reverse angles become their supplement.
    # Both solves below run on these folded angles, so the solver is only
    # evaluated inside its nominal domain even for entries that are later
    # discarded. An inf/NaN in the discarded branch of torch.where could
    # otherwise still poison gradients.
    theta_folded = torch.where(is_reverse, torch.pi - theta_1d, theta_1d)

    has_forward = bool((~is_reverse).any())
    has_reverse = bool(is_reverse.any())

    if not (has_forward or has_reverse):
        # No angles at all: nothing to solve.
        empty = torch.zeros((batchsize, num_wv, num_angles), dtype=dtype, device=device)
        return empty, empty.clone(), empty.clone(), empty.clone()

    if not has_reverse:
        return _solve(n_layers_1d, d_1d, n_in_t, n_out_t, theta_folded)

    # Reverse incidence == forward incidence on the flipped stack with the
    # media swapped (flip the layer axis; last dim works for 2-D and 3-D).
    n_layers_rev = torch.flip(n_layers_1d, dims=[-1])
    d_rev = torch.flip(d_1d, dims=[1])

    if not has_forward:
        return _solve(n_layers_rev, d_rev, n_out_t, n_in_t, theta_folded)

    forward = _solve(n_layers_1d, d_1d, n_in_t, n_out_t, theta_folded)

    same_scalar_media = (
        not torch.is_tensor(n_in)
        and not torch.is_tensor(n_out)
        and abs(n_in - n_out) < 1e-10
    )
    if (
        same_scalar_media
        and torch.equal(n_layers_1d, n_layers_rev)
        and torch.equal(d_1d, d_rev)
    ):
        # Mirror-symmetric problem: the reversed solve would be identical,
        # so the folded-angle forward solve already answers both directions.
        return forward

    reverse = _solve(n_layers_rev, d_rev, n_out_t, n_in_t, theta_folded)

    reverse_mask = is_reverse.unsqueeze(1).expand(-1, num_wv, -1)
    ts_out, tp_out, rs_out, rp_out = (
        torch.where(reverse_mask, rev, fwd) for fwd, rev in zip(forward, reverse)
    )
    return ts_out, tp_out, rs_out, rp_out