Skip to content

Anisotropic Solver

The general 4×4 transfer matrix method, which handles both isotropic and anisotropic (birefringent) media. Use it when one or more layers have direction-dependent refractive index; each such layer is given as a (mat_x, mat_y, mat_z) tuple of per-axis indices. It is more general (and heavier) than the isotropic 2×2 solver — for fully isotropic stacks, prefer the isotropic solver for speed and lower memory.

AnisotropicFilmSolver is an alias

difftmm.AnisotropicFilmSolver is an alias for FilmSolver; both names refer to the same class. It shares the constructor and simulate(theta, wvln) signature of the isotropic solver and likewise returns complex (ts, tp, rs, rp).

difftmm.FilmSolver

FilmSolver(mat_in, mat_out, mat_ls, thickness_ls=None, thickness_min=0.0, thickness_max=0.2, batch_size=1, sigmoid_param=False, device=torch.device('cuda'))

Multi-layer coating physical film solver using transfer matrix method.

This solver calculates (ts, tp, rs, rp) with phase shifts using rigorous electromagnetic wave propagation through multi-layer coating stacks.

Initialize the anisotropic film solver.

Parameters:

Name Type Description Default
mat_in

Refractive index of incident medium. float, complex, or str material name.

required
mat_out

Refractive index of exit medium. Same types as mat_in.

required
mat_ls

Refractive indices of interior layers. Each element is either a float/complex scalar or str material name (isotropic layer), or a 3-tuple of those types for (nx, ny, nz) birefringent layers.

required
thickness_ls

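Thicknesses of interior layers in µm, as a list or tensor of length N (one entry per layer in mat_ls). Values are clamped to [thickness_min, thickness_max]. If None, thicknesses are initialized near the middle of that range with small random noise.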

None
thickness_min

Minimum layer thickness in um.

0.0
thickness_max

Maximum layer thickness in um.

0.2
batch_size

Number of film stacks in the batch dimension.

1
sigmoid_param

If True, use sigmoid parameterization for thickness.

False
device

PyTorch device.

device('cuda')
Source code in difftmm-src/difftmm/film_solver_anisotropic.py
def __init__(
    self,
    mat_in,
    mat_out,
    mat_ls,
    thickness_ls=None,
    thickness_min=0.0,
    thickness_max=0.2,
    batch_size=1,
    sigmoid_param=False,
    device=torch.device("cuda"),
):
    """Initialize the anisotropic film solver.

    Args:
        mat_in: Refractive index of incident medium. float, complex, or str material name.
        mat_out: Refractive index of exit medium. Same types as mat_in.
        mat_ls: Refractive indices of interior layers. Each element is either a
            float/complex scalar or str material name (isotropic layer), or a
            3-tuple of those types for (nx, ny, nz) birefringent layers.
        thickness_ls: Thicknesses of interior layers in um, 1-D list or tensor with
            exactly one entry per element of ``mat_ls``. Values are clamped to
            [thickness_min, thickness_max]. If None, thicknesses start near the
            middle of that range with small random noise.
        thickness_min: Minimum layer thickness in um.
        thickness_max: Maximum layer thickness in um.
        batch_size: Number of film stacks in the batch dimension.
        sigmoid_param: If True, use sigmoid parameterization for thickness.
        device: PyTorch device.

    Raises:
        ValueError: if a tuple layer spec is not a 3-tuple, or if ``thickness_ls``
            is not 1-D with one entry per layer.
    """
    self.batch_size = batch_size
    self.device = device

    self.mat_in  = Material(mat_in,  device=device)
    self.mat_out = Material(mat_out, device=device)

    def _to_mat_layer(spec):
        # A tuple marks a birefringent layer given per axis as (nx, ny, nz).
        if isinstance(spec, tuple):
            if len(spec) != 3:
                raise ValueError(
                    f"anisotropic layer spec must be 3-tuple, got len {len(spec)}"
                )
            return tuple(Material(x, device=device) for x in spec)
        return Material(spec, device=device)

    self.mat_ls = [_to_mat_layer(s) for s in mat_ls]
    self.num_layers = len(self.mat_ls)

    # When every layer is non-dispersive, cache a (batch, n_layer, 3) complex table
    # of per-axis indices; isotropic layers repeat their index on all three axes.
    # NOTE(review): simulate() rebuilds indices via Material.ior(); confirm whether
    # refract_idx_layers is still consumed anywhere.
    all_constant = all(
        (all(x.dispersion == "constant" for x in s) if isinstance(s, tuple)
         else s.dispersion == "constant")
        for s in self.mat_ls
    )
    if all_constant:
        rows = []
        for s in self.mat_ls:
            if isinstance(s, tuple):
                rows.append([s[0]._const_n, s[1]._const_n, s[2]._const_n])
            else:
                rows.append([s._const_n, s._const_n, s._const_n])
        t = torch.tensor(rows, dtype=torch.complex64)
        self.refract_idx_layers = t.unsqueeze(0).expand(batch_size, -1, -1).clone()
    else:
        self.refract_idx_layers = None

    self.thickness_min = thickness_min
    self.thickness_max = thickness_max
    self._thickness_range = self.thickness_max - self.thickness_min

    self.sigmoid_param = sigmoid_param
    if thickness_ls is not None:
        if not torch.is_tensor(thickness_ls):
            thickness_ls = torch.tensor(thickness_ls, dtype=torch.float32)
        # Fail early with a clear message instead of an opaque broadcasting
        # error deep inside the TMM computation.
        if thickness_ls.dim() != 1 or thickness_ls.shape[0] != self.num_layers:
            raise ValueError(
                f"thickness_ls must be 1-D with {self.num_layers} entries "
                f"(one per layer in mat_ls), got shape {tuple(thickness_ls.shape)}"
            )
        # Map physical thickness to the normalized [0, 1] parameter space.
        normalized = (
            thickness_ls.clamp(self.thickness_min, self.thickness_max)
            - self.thickness_min
        ) / self._thickness_range
        self.film_params = normalized.unsqueeze(0).expand(batch_size, -1).clone()
    else:
        # Start around mid-range (0.5) with small noise.
        self.film_params = torch.randn(batch_size, self.num_layers) * 0.01 + 0.5
    if self.sigmoid_param:
        # Keep away from exactly 0/1 so the inverse sigmoid stays finite.
        self.film_params = inv_sigmoid(self.film_params.clamp(1e-6, 1 - 1e-6))

    self.to(device)

to

to(device)

Move tensors to specified device.

Source code in difftmm-src/difftmm/film_solver_anisotropic.py
def to(self, device):
    """Relocate every tensor and material owned by the solver to ``device``.

    Moves the optimizable thickness parameters, the cached constant index table
    (when one exists), and each Material object, then records the new device.

    Args:
        device: Target PyTorch device.

    Returns:
        The solver itself, so calls can be chained.
    """
    self.device = device
    self.film_params = self.film_params.to(device)
    if self.refract_idx_layers is not None:
        self.refract_idx_layers = self.refract_idx_layers.to(device)

    # Flatten ambient media and layers (tuples hold per-axis materials) in order.
    owned_materials = [self.mat_in, self.mat_out]
    for layer in self.mat_ls:
        owned_materials.extend(layer if isinstance(layer, tuple) else (layer,))
    for material in owned_materials:
        material.to(device)
    return self

load_ckpt

load_ckpt(ckpt_path)

Load thicknesses (and spec metadata) from a checkpoint.

Source code in difftmm-src/difftmm/film_solver_anisotropic.py
def load_ckpt(self, ckpt_path):
    """Load thicknesses (and spec metadata) from a checkpoint.

    Reads ``film_thickness`` (in um), clamps it to
    [thickness_min, thickness_max] and converts it into this solver's parameter
    space (normalized, or inverse-sigmoid when ``sigmoid_param`` is set). If the
    checkpoint also contains ``mat_ls``, the incident/exit/layer materials are
    rebuilt from the stored specs. All state is validated before anything on
    ``self`` is modified.

    Args:
        ckpt_path: Anything ``torch.load`` accepts (path or file-like object).

    Raises:
        ValueError: if the number of layers in ``film_thickness`` does not match
            the (possibly reloaded) layer stack.

    Warning:
        ``weights_only=False`` runs the full unpickler, which can execute
        arbitrary code. Only load checkpoints from trusted sources.
    """
    # NOTE(security): full unpickling; never call this on untrusted files.
    ckpt = torch.load(ckpt_path, map_location=self.device, weights_only=False)

    film_thickness = torch.clamp(
        ckpt["film_thickness"], self.thickness_min, self.thickness_max
    )
    film_thickness_normalized = (film_thickness - self.thickness_min) / (
        self.thickness_max - self.thickness_min
    )
    if self.sigmoid_param:
        # Keep away from exactly 0/1 so the inverse sigmoid stays finite.
        film_thickness_normalized = torch.clamp(
            film_thickness_normalized, 1e-6, 1 - 1e-6
        )
        film_params = inv_sigmoid(film_thickness_normalized).to(self.device)
    else:
        film_params = film_thickness_normalized.to(self.device)

    if "mat_ls" in ckpt:
        mat_in = Material(ckpt["mat_in"], device=self.device)
        mat_out = Material(ckpt["mat_out"], device=self.device)
        # Tuple entries are per-axis (nx, ny, nz) birefringent layers.
        mat_ls = [
            tuple(Material(x, device=self.device) for x in v) if isinstance(v, tuple)
            else Material(v, device=self.device)
            for v in ckpt["mat_ls"]
        ]
    else:
        mat_in, mat_out, mat_ls = self.mat_in, self.mat_out, self.mat_ls

    num_layers = len(mat_ls)
    if film_params.shape[-1] != num_layers:
        raise ValueError(
            f"checkpoint has {film_params.shape[-1]} layer thicknesses but the "
            f"layer stack has {num_layers} layers"
        )

    # Commit only after validation so a failed load leaves the solver untouched.
    self.film_params = film_params
    self.mat_in = mat_in
    self.mat_out = mat_out
    self.mat_ls = mat_ls
    self.num_layers = num_layers
    if film_params.dim() == 2:
        # simulate() expands inputs to self.batch_size, so keep it in sync.
        self.batch_size = film_params.shape[0]
    # NOTE(review): refract_idx_layers (built in __init__) is not rebuilt here
    # when materials change; confirm whether anything still reads it.

save_ckpt

save_ckpt(save_path)

Save thicknesses and material specs to a checkpoint.

Source code in difftmm-src/difftmm/film_solver_anisotropic.py
def save_ckpt(self, save_path):
    """Persist the current stack to ``save_path`` with ``torch.save``.

    The payload holds the physical thicknesses in um (moved to CPU), the batch
    size, the layer count, and the material names. A birefringent layer's name
    is stored as a tuple of its three per-axis material names.

    Args:
        save_path: Destination handed directly to ``torch.save``.
    """
    def _spec_of(layer):
        # Plain layers serialize to one name; per-axis layers keep a 3-tuple.
        if not isinstance(layer, tuple):
            return layer.name
        return tuple(axis_mat.name for axis_mat in layer)

    thickness_cpu = self.get_film_thickness().cpu()
    layer_specs = [_spec_of(layer) for layer in self.mat_ls]
    torch.save(
        {
            "film_thickness": thickness_cpu,
            "batch_size": self.batch_size,
            "num_layers": self.num_layers,
            "mat_in": self.mat_in.name,
            "mat_out": self.mat_out.name,
            "mat_ls": layer_specs,
        },
        save_path,
    )

get_film_thickness

get_film_thickness()

Convert optimization-friendly film parameters to real film thickness.

Returns:

Name Type Description
film_thickness

tensor of shape (batch_size, num_layers), in [um].

Source code in difftmm-src/difftmm/film_solver_anisotropic.py
def get_film_thickness(self):
    """Map the optimizable parameters back to physical layer thicknesses.

    With ``sigmoid_param`` the parameters pass through a sigmoid before being
    scaled to [thickness_min, thickness_max]. Otherwise they are scaled linearly
    and then clamped to that range.

    Returns:
        film_thickness: tensor of shape (batch_size, num_layers), in [um].
    """
    lo, hi = self.thickness_min, self.thickness_max
    if not self.sigmoid_param:
        linear = self.film_params * self._thickness_range + lo
        return linear.clamp(lo, hi)
    return torch.sigmoid(self.film_params) * self._thickness_range + lo

simulate

simulate(theta, wvln)

Calculate (ts, tp, rs, rp) using 4x4 TMM for specified angles and wavelengths.

Parameters:

Name Type Description Default
theta

Incident angles in radians. Either a 1D tensor of shape (n_angles,), shared by every batch entry, or a 2D tensor of shape (batch_size, n_angles).

required
wvln

Wavelengths in micrometers. Either a list or 1D tensor of shape (n_wvlns,), or a scalar / 0D tensor for a single wavelength.

required

Returns:

Type Description

ts, tp, rs, rp: Complex transmission/reflection coefficients, shape (batch_size, n_wvlns, n_angles).

Source code in difftmm-src/difftmm/film_solver_anisotropic.py
def simulate(self, theta, wvln):
    """
    Calculate (ts, tp, rs, rp) using 4x4 TMM for specified angles and wavelengths.

    Args:
        theta: Incident angles in radians. Can be:
               - Python scalar or 0D tensor (a single angle)
               - list or 1D tensor of shape (n_angles,), shared by every batch entry
               - 2D tensor of shape (batch_size, n_angles)
        wvln: Wavelengths in micrometers. Can be:
              - List, array-like, or 1D tensor of shape (n_wvlns,)
              - Scalar or 0D tensor

    Returns:
        ts, tp, rs, rp: Complex transmission/reflection coefficients,
                        shape (batch_size, n_wvlns, n_angles).

    Raises:
        ValueError: if theta has more than 2 dimensions or wvln more than 1.
    """
    # --- Normalize theta to (batch_size, n_angles) ---
    if not torch.is_tensor(theta):
        theta = torch.tensor(theta, dtype=torch.float32, device=self.device)
    theta = theta.to(self.device)
    if theta.dim() == 0:
        theta = theta.reshape(1)
    if theta.dim() == 1:
        theta = theta.unsqueeze(0).expand(self.batch_size, -1)
    elif theta.dim() != 2:
        raise ValueError(
            f"theta must be a scalar, 1-D or 2-D, got shape {tuple(theta.shape)}"
        )

    # --- Normalize wavelengths to a 1-D tensor (n_wvlns,) ---
    if torch.is_tensor(wvln):
        wv = wvln.to(self.device)
    else:
        # as_tensor handles scalars, lists/tuples and numpy arrays uniformly.
        wv = torch.as_tensor(wvln, dtype=torch.float32, device=self.device)
    if wv.dim() == 0:
        wv = wv.reshape(1)
    elif wv.dim() != 1:
        raise ValueError(
            f"wvln must be a scalar or 1-D, got shape {tuple(wv.shape)}"
        )

    d_1d = self.get_film_thickness()
    wv_1d = wv.unsqueeze(0).expand(self.batch_size, -1)

    # Ambient-medium indices per wavelength, expanded to (batch_size, n_wvlns).
    n_in_t  = self.mat_in.ior(wv).unsqueeze(0).expand(self.batch_size, -1)
    n_out_t = self.mat_out.ior(wv).unsqueeze(0).expand(self.batch_size, -1)

    # Build n_2d_w: shape (batch, n_wvln, n_layer, 3).
    # Isotropic layers repeat their single index on all three axes.
    per_layer_axes = []
    for s in self.mat_ls:
        if isinstance(s, tuple):
            cols = torch.stack(
                [s[ax].ior(wv) for ax in range(3)],
                dim=-1,
            )
        else:
            col = s.ior(wv)
            cols = col.unsqueeze(-1).expand(-1, 3)
        per_layer_axes.append(cols)
    n_2d_w = torch.stack(per_layer_axes, dim=-2)
    n_2d_w = n_2d_w.unsqueeze(0).expand(self.batch_size, -1, -1, -1)

    # Material azimuths and the incident azimuth are both zero here.
    a_2d = torch.zeros(
        (self.batch_size, self.num_layers, 3),
        dtype=torch.complex64,
        device=self.device,
    )
    d_1d_complex = d_1d.to(torch.complex64)
    Az_1d = torch.zeros((self.batch_size, 1), device=self.device)

    Jt, Jr = create_jones_matrix_AOIAz(
        a_2d, n_2d_w, d_1d_complex, wv_1d, n_in_t, n_out_t, theta, Az_1d
    )

    # Jt/Jr: (batch, n_wvlns, n_angles, 1, 2, 2) Jones matrices in the lab basis,
    # with p in component 0 and s in component 1. Pure p (s) input selects
    # column 0 (1); the co-polarized outputs are therefore the diagonal entries.
    # Cross-polarized terms are not returned. Indexing directly is equivalent to
    # multiplying by unit vectors, without the extra matmuls.
    tp = Jt[:, :, :, :, 0, 0].squeeze(-1)
    ts = Jt[:, :, :, :, 1, 1].squeeze(-1)
    rp = Jr[:, :, :, :, 0, 0].squeeze(-1)
    rs = Jr[:, :, :, :, 1, 1].squeeze(-1)

    return ts, tp, rs, rp

__call__

__call__(theta, wvln)

Forward pass using simulate.

Source code in difftmm-src/difftmm/film_solver_anisotropic.py
def __call__(self, theta, wvln):
    """Make the solver callable; delegates unchanged to ``simulate``.

    Args:
        theta: Incident angles in radians (see ``simulate``).
        wvln: Wavelengths in micrometers (see ``simulate``).

    Returns:
        Whatever ``simulate`` returns, i.e. (ts, tp, rs, rp).
    """
    coefficients = self.simulate(theta, wvln)
    return coefficients

Functional API

Lower-level entry point used by FilmSolver.simulate(). It propagates fields through the anisotropic stack with 4×4 transfer matrices and returns the 2×2 transmission and reflection Jones matrices, expressed in the angle-of-incidence / azimuth (AOIAz) frame.

difftmm.create_jones_matrix_AOIAz

create_jones_matrix_AOIAz(a_2d, n_2d, d_1d, wv_1d, n_in, n_out, theta_x_1d, theta_y_1d)

Calculate the Jones matrix for reflected and transmitted light.

Optimized: vectorized AOI/Az setup and batched per-layer propagation matrices computed with torch.linalg.matrix_exp (no eigendecomposition, which keeps gradients stable).

Parameters:

Name Type Description Default
a_2d

azimuth angle of materials in each layer, shape (batchsize, n_layer, 3). Complex.

required
n_2d

refractive index of each layer. Accepts: - 3-D shape (batchsize, n_layer, 3): non-dispersive, broadcast across wvlns. - 4-D shape (batchsize, num_wv, n_layer, 3): dispersive per wavelength. Complex.

required
d_1d

thicknesses of all layers, shape (batchsize, n_layer). Complex.

required
wv_1d

wavelengths of simulations, shape (batchsize, n_wls). Real

required
n_in

Refractive index of the incident medium. Accepts a Python scalar or a tensor of shape (batchsize, num_wv) for per-wavelength values.

required
n_out

Refractive index of the exit (transmission) medium. Accepts the same shapes as n_in.

required
theta_x_1d

Incident zenith angle (angle of incidence), shape (batchsize, n_aoi_angles). Real.

required
theta_y_1d

azimuth angle of incident light, shape (batchsize, n_az_angles). Real.

required

Returns:

Type Description

Jones_trn, Jones_ref: Jones matrices, each with shape (batchsize, n_wls, n_aoi_angles, n_az_angles, 2, 2). Complex

Source code in difftmm-src/difftmm/film_solver_anisotropic.py
def create_jones_matrix_AOIAz(
    a_2d, n_2d, d_1d, wv_1d, n_in, n_out, theta_x_1d, theta_y_1d
):
    """
    Calculate the Jones matrix for reflected and transmitted light.

    Implementation notes: the AOI/Az grid is built by broadcasting (no Python
    loops over angles). Each layer's 4x4 propagation matrix is computed with a
    batched ``torch.linalg.matrix_exp``, which avoids the eigenvector-phase
    ambiguity of an eigendecomposition in the backward pass. The final 2x2
    scattering step uses ``torch.linalg.solve`` rather than an explicit inverse.

    Args:
        a_2d: azimuth angle of materials in each layer, shape (batchsize, n_layer, 3). Complex.
        n_2d: refractive index of each layer. Accepts:
              - 3-D shape (batchsize, n_layer, 3): non-dispersive, broadcast across wvlns.
              - 4-D shape (batchsize, num_wv, n_layer, 3): dispersive per wavelength.
              Complex.
        d_1d: thicknesses of all layers, shape (batchsize, n_layer). Complex.
              With n_layer == 0 this function uses the identity layer product
              (bare interface); the eps/enter-exit helpers it calls must also
              accept an empty layer axis for that to work end to end.
        wv_1d: wavelengths of simulations, shape (batchsize, n_wls). Real
        n_in: incident medium refractive index. Accepts Python scalar or tensor of
              shape (batchsize, num_wv) for per-wavelength values.
        n_out: exit medium refractive index. Same accepted shapes as n_in.
        theta_x_1d: incident zenith angle, shape (batchsize, n_aoi_angles). Real
        theta_y_1d: azimuth angle of incident light, shape (batchsize, n_az_angles). Real.

    Returns:
        Jones_trn, Jones_ref: Jones matrices, each with shape (batchsize, n_wls, n_aoi_angles, n_az_angles, 2, 2). Complex
    """
    device = a_2d.device

    batchsize = d_1d.shape[0]
    num_wv = wv_1d.size()[1]
    num_x = theta_x_1d.size()[1]
    num_y = theta_y_1d.size()[1]
    num_layer = d_1d.size()[1]

    # Normalize n_2d to (batch, num_wv, n_layer, 3)
    if n_2d.dim() == 3:
        # (batch, n_layer, 3) — broadcast across wvlns
        n_2d_w = n_2d.unsqueeze(1).expand(-1, num_wv, -1, -1)
    elif n_2d.dim() == 4:
        # (batch, num_wv, n_layer, 3) — already dispersive
        n_2d_w = n_2d
    else:
        raise ValueError(f"n_2d must be 3-D or 4-D, got shape {n_2d.shape}")

    # Vectorized AOI and Az calculation (no loops)
    # theta_x_1d: (batchsize, num_x), theta_y_1d: (batchsize, num_y)
    # AOI_2d: (batchsize, num_x, num_y) - broadcast theta_x over y dimension
    AOI_2d = theta_x_1d.unsqueeze(-1).expand(-1, -1, num_y).to(torch.complex64)
    # Az_2d: (batchsize, num_x, num_y) - broadcast theta_y over x dimension
    Az_2d = theta_y_1d.unsqueeze(1).expand(-1, num_x, -1).to(torch.float64)

    k0_1d = 2 * torch.pi / wv_1d
    # Effective (RMS over the three axes) index per layer: (batch, num_wv, n_layer)
    ng_4d = torch.sqrt(
        (n_2d_w[..., 0] ** 2 + n_2d_w[..., 1] ** 2 + n_2d_w[..., 2] ** 2) / 3
    )

    # n_in / n_out: scalar or shape (batch, num_wv). Build (batch, num_wv, 1, 1) eps.
    def _per_wvln(x):
        if torch.is_tensor(x) and x.dim() == 2:
            return (x ** 2).to(torch.complex64).unsqueeze(-1).unsqueeze(-1)
        return torch.tensor(complex(x) ** 2, dtype=torch.complex64, device=device).view(1, 1, 1, 1)

    eps_in = _per_wvln(n_in)
    eps_out = _per_wvln(n_out)

    # For the per-wvln Snell/AOI, also compute per-wvln n_in / n_out scalars
    # ready for broadcasting into the (batch, num_wv, num_x, num_y) theta grid.
    def _n_per_wvln(x):
        if torch.is_tensor(x) and x.dim() == 2:
            return x.to(torch.complex64).unsqueeze(-1).unsqueeze(-1)
        return torch.tensor(complex(x), dtype=torch.complex64, device=device).view(1, 1, 1, 1)

    n_in_4d = _n_per_wvln(n_in)   # (batch, num_wv, 1, 1) or (1,1,1,1)
    n_out_4d = _n_per_wvln(n_out)

    # theta inputs to EnterExitMatrix_XY are 4-D
    AOI_3d = AOI_2d.unsqueeze(1).expand(-1, num_wv, -1, -1)  # (batch, num_wv, num_x, num_y)
    theta_inc_air_3d = AOI_3d
    # Use complex_arcsin to properly handle evanescent waves beyond critical angle
    theta_inc_sub_3d = complex_arcsin(n_in_4d * torch.sin(AOI_3d) / n_out_4d)

    # ng_5d shape: (batch, num_wv, num_x, num_y, n_layer)
    ng_5d = ng_4d.unsqueeze(2).unsqueeze(3).expand(-1, -1, num_x, num_y, -1)
    # AOI_4d: (batch, num_wv, num_x, num_y, n_layer)
    AOI_4d = AOI_2d.unsqueeze(1).unsqueeze(-1).expand(-1, num_wv, -1, -1, num_layer)

    # Use complex_arcsin for angles in each layer - critical for TIR handling
    # n_in_4d is (batch, num_wv, 1, 1); after unsqueeze(-1) becomes (batch, num_wv, 1, 1, 1)
    # which broadcasts against the layer axis.
    theta_inc_medium_4d = complex_arcsin(n_in_4d.unsqueeze(-1) * torch.sin(AOI_4d) / ng_5d)
    sin_Vt_4d = ng_5d * torch.sin(theta_inc_medium_4d)

    eps_6d = create_eps_matrix_XY(a_2d, n_2d_w, Az_2d)

    # Extract epsilon components — each (batch, num_wv, num_x, num_y, num_layer)
    exx_4d = eps_6d[..., 0, 0]
    exy_4d = eps_6d[..., 0, 1]
    exz_4d = eps_6d[..., 0, 2]
    eyx_4d = eps_6d[..., 1, 0]
    eyy_4d = eps_6d[..., 1, 1]
    eyz_4d = eps_6d[..., 1, 2]
    ezx_4d = eps_6d[..., 2, 0]
    ezy_4d = eps_6d[..., 2, 1]
    ezz_4d = eps_6d[..., 2, 2]

    # Build Q matrix for all layers at once
    Q_6d = torch.zeros(
        (batchsize, num_wv, num_x, num_y, num_layer, 4, 4),
        dtype=torch.complex64,
        device=device,
    )
    Q_6d[:, :, :, :, :, 0, 0] = -ezx_4d * sin_Vt_4d / ezz_4d
    Q_6d[:, :, :, :, :, 0, 1] = 1 - sin_Vt_4d**2 / ezz_4d
    Q_6d[:, :, :, :, :, 0, 2] = -ezy_4d * sin_Vt_4d / ezz_4d
    Q_6d[:, :, :, :, :, 1, 0] = exx_4d - exz_4d * ezx_4d / ezz_4d
    Q_6d[:, :, :, :, :, 1, 1] = -exz_4d * sin_Vt_4d / ezz_4d
    Q_6d[:, :, :, :, :, 1, 2] = exy_4d - exz_4d * ezy_4d / ezz_4d
    Q_6d[:, :, :, :, :, 2, 3] = 1.0
    Q_6d[:, :, :, :, :, 3, 0] = eyx_4d - eyz_4d * ezx_4d / ezz_4d
    Q_6d[:, :, :, :, :, 3, 1] = -eyz_4d * sin_Vt_4d / ezz_4d
    Q_6d[:, :, :, :, :, 3, 2] = eyy_4d - eyz_4d * ezy_4d / ezz_4d - sin_Vt_4d**2

    dtype = torch.complex64
    k0_1d_exp = k0_1d.reshape(batchsize, num_wv, 1, 1, 1).to(dtype)
    d_1d_exp = d_1d.reshape(batchsize, 1, 1, 1, num_layer).to(dtype)
    k0d = k0_1d_exp * d_1d_exp

    # Per-layer propagation P_n = exp(i k0 d Q). matrix_exp is gradient-stable:
    # it avoids linalg_eig_backward's eigenvector phase ambiguity.
    exponent = 1j * k0d.unsqueeze(-1).unsqueeze(-1) * Q_6d
    Pn_flat = torch.linalg.matrix_exp(exponent.reshape(-1, 4, 4))
    Pn_all = Pn_flat.view(batchsize, num_wv, num_x, num_y, num_layer, 4, 4)

    # Sequential multiplication of layer transfer matrices P = Pn[n-1] @ ... @ Pn[1] @ Pn[0]
    if num_layer == 0:
        # Bare interface: the product over no layers is the identity.
        P_5d = torch.eye(4, dtype=dtype, device=device).expand(
            batchsize, num_wv, num_x, num_y, 4, 4
        )
    else:
        P_5d = Pn_all[:, :, :, :, 0, :, :].clone()
        for i_layer in range(1, num_layer):
            P_5d = torch.matmul(Pn_all[:, :, :, :, i_layer, :, :], P_5d)

    T0_5d, T_N_inv_5d = EnterExitMatrix_XY(
        eps_in, eps_out, theta_inc_air_3d, theta_inc_sub_3d
    )

    N_5d = torch.matmul(torch.matmul(T_N_inv_5d, P_5d), T0_5d)

    N11_5d = N_5d[:, :, :, :, :2, :2]
    N12_5d = N_5d[:, :, :, :, :2, 2:]
    N21_5d = N_5d[:, :, :, :, 2:, :2]
    N22_5d = N_5d[:, :, :, :, 2:, 2:]

    # Reduce the 4x4 system to 2x2 Jones matrices. The formulas below correspond
    # to N mapping [incident; reflected] to [transmitted; 0]:
    #   0 = N21 a + N22 r  =>  r = -N22^{-1} N21 a
    #   t = N11 a + N12 r  =>  t = (N11 + N12 S11) a
    # solve() avoids forming N22^{-1} explicitly (better conditioned).
    S11_5d = -torch.linalg.solve(N22_5d, N21_5d)
    S21_5d = N11_5d + torch.matmul(N12_5d, S11_5d)

    Jones_trans = S21_5d
    Jones_rflc = S11_5d

    return Jones_trans, Jones_rflc