DiffractiveLens

Pure wave-optics lens built from diffractive surfaces with scalar diffraction propagation. Use it to model flat DOE / metasurface imaging systems end-to-end.

deeplens.DiffractiveLens

DiffractiveLens(filename=None, device=None, dtype=torch.float32, primary_wvln=DEFAULT_WAVE, wvln_rgb=WAVE_RGB, obj_depth=DEPTH)

Bases: Lens

Paraxial diffractive lens in which each element is modelled as a phase surface.

Every optical element (converging lens, DOE, metasurface, …) is represented by a phase function applied to an incoming complex wavefront. Propagation between surfaces uses the Angular Spectrum Method (ASM). This model is simple and fast, but accurate only in the paraxial regime (it does not account for higher-order geometric aberrations).
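
As a reference for the model described above, the two operations can be written in their textbook scalar-diffraction form. This is a sketch of the standard equations, not a transcription of the library's code: the symbols $U$, $\phi$, $H$, $f_x$, $f_y$ are introduced here for illustration, and the actual implementation may use a different sign convention or a band-limited transfer function.

```latex
% Phase surface: multiply the incoming field by a unit-amplitude phase term.
U_{\text{out}}(x, y) = U_{\text{in}}(x, y)\, e^{\, i \phi(x, y)}

% Angular Spectrum Method: propagate the field over a distance z.
U(x, y; z) = \mathcal{F}^{-1}\!\left\{ \mathcal{F}\{ U(x, y; 0) \}(f_x, f_y)\; H(f_x, f_y; z) \right\},
\qquad
H(f_x, f_y; z) = \exp\!\left( i\, 2\pi z \sqrt{\tfrac{1}{\lambda^{2}} - f_x^{2} - f_y^{2}} \right)
\quad \text{for } f_x^{2} + f_y^{2} \le \tfrac{1}{\lambda^{2}} .
```

Spatial frequencies outside the stated range correspond to evanescent components, which decay with distance instead of propagating.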

Attributes:

Name Type Description
surfaces list

Ordered list of diffractive/phase surfaces.

d_sensor Tensor

Distance from the last surface to the sensor plane [mm].

Notes

Lens parameters default to torch.float32; pass dtype=torch.float64 for higher-precision wave propagation.
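
To illustrate why float64 can matter, the sketch below rounds an accumulated on-axis phase to float32 using only the standard library. The wavelength and distance are assumed example values, and `to_float32` is a hypothetical helper. The library's propagator may not store the absolute phase this way, so treat this as an order-of-magnitude argument only.

```python
import math
import struct


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


wvln_um = 0.55       # illustrative wavelength [µm] (assumed value)
distance_mm = 50.0   # illustrative propagation distance [mm] (assumed value)

# Accumulated on-axis phase k * z, with both lengths expressed in µm.
phase64 = 2.0 * math.pi * (distance_mm * 1e3) / wvln_um
phase32 = to_float32(phase64)

# Absolute phase error introduced by storing the value in float32.
err = abs(phase32 - phase64)
print(f"phase = {phase64:.3f} rad, float32 rounding error = {err:.5f} rad")
```

The phase here lies between 2^19 and 2^20 rad, where adjacent float32 values are 2^-4 = 0.0625 rad apart, so a single stored phase can be off by up to about 0.03 rad.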

Initialize a diffractive lens.

Parameters:

Name Type Description Default
filename str

Path to the lens configuration JSON file. If provided, loads the lens configuration from file. Defaults to None.

None
device str

Computation device ('cpu' or 'cuda'). Defaults to 'cpu'.

None
dtype dtype

Data type for the lens parameters. Defaults to torch.float32; pass torch.float64 for higher-precision wave propagation.

float32
primary_wvln float

Primary design wavelength [µm]. Used as fallback when a method is called without an explicit wvln. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
wvln_rgb sequence of float

Three wavelengths used for RGB computations, ordered [R, G, B] in µm. Defaults to WAVE_RGB.

WAVE_RGB
obj_depth float

Default object depth [mm], used when a method is called without an explicit depth. Defaults to DEPTH.

DEPTH
Source code in deeplens-src/deeplens/diffraclens.py
def __init__(
    self,
    filename=None,
    device=None,
    dtype=torch.float32,
    primary_wvln=DEFAULT_WAVE,
    wvln_rgb=WAVE_RGB,
    obj_depth=DEPTH,
):
    """Initialize a diffractive lens.

    Args:
        filename (str, optional): Path to the lens configuration JSON file. If provided, loads the lens configuration from file. Defaults to None.
        device (str, optional): Computation device ('cpu' or 'cuda'). Defaults to 'cpu'.
        dtype (torch.dtype, optional): Data type for the lens parameters.
            Defaults to torch.float32; pass torch.float64 for
            higher-precision wave propagation.
        primary_wvln (float, optional): Primary design wavelength [µm].
            Used as fallback when a method is called without an explicit
            ``wvln``.  Defaults to ``DEFAULT_WAVE``.
        wvln_rgb (sequence of float, optional): Three wavelengths used
            for RGB computations, ordered ``[R, G, B]`` in µm.  Defaults
            to ``WAVE_RGB``.
        obj_depth (float, optional): Default object depth [mm], used
            when a method is called without an explicit ``depth``.
            Defaults to ``DEPTH``.
    """
    super().__init__(
        device=device,
        dtype=dtype,
        primary_wvln=primary_wvln,
        wvln_rgb=wvln_rgb,
        obj_depth=obj_depth,
    )

    # Load lens file
    if filename is not None:
        self.read_lens_json(filename)
    else:
        self.surfaces = []
        # Set default sensor size and resolution if no file provided
        self.sensor_size = (8.0, 8.0)
        self.sensor_res = (2000, 2000)

    self.astype(self.dtype)

    # Use total track length (first element to sensor) as focal length
    if hasattr(self, "d_sensor"):
        self.foclen = float(self.d_sensor)
        self.calc_fov()

    # Move all tensors (surfaces, sensor params) to the target device.
    self.to(self.device)

read_lens_json

read_lens_json(filename)

Load the lens configuration from a JSON file.

Reads lens parameters including sensor configuration and diffractive surfaces from the specified JSON file. If sensor_size or sensor_res are not provided, defaults of 8mm x 8mm and 2000x2000 pixels will be used.

Parameters:

Name Type Description Default
filename str

Path to the JSON configuration file.

required
Source code in deeplens-src/deeplens/diffraclens.py
def read_lens_json(self, filename):
    """Load the lens configuration from a JSON file.

    Reads lens parameters including sensor configuration and diffractive surfaces
    from the specified JSON file. If sensor_size or sensor_res are not provided,
    defaults of 8mm x 8mm and 2000x2000 pixels will be used.

    Args:
        filename (str): Path to the JSON configuration file.
    """
    assert filename.endswith(".json"), "File must be a .json file."

    with open(filename, "r") as f:
        # Lens general info
        data = json.load(f)
        self.d_sensor = torch.tensor(data["d_sensor"])
        self.lens_info = data.get("info", "None")

        # Read sensor_size with default
        if "sensor_size" in data:
            sensor_size = tuple(data["sensor_size"])
        else:
            sensor_size = (8.0, 8.0)
            print(
                f"Sensor_size not found in lens file. Using default: {sensor_size} mm. "
                "Consider specifying sensor_size in the lens file or using set_sensor()."
            )

        # Read sensor_res with default
        if "sensor_res" in data:
            sensor_res = tuple(data["sensor_res"])
        else:
            sensor_res = (2000, 2000)
            print(
                f"Sensor_res not found in lens file. Using default: {sensor_res} pixels. "
                "Consider specifying sensor_res in the lens file or using set_sensor()."
            )

        # Configure sensor (also sets pixel_size and r_sensor).
        self.set_sensor(sensor_size, sensor_res)

        # Load diffractive surfaces/elements
        d = 0.0
        self.surfaces = []
        for surf_dict in data["surfaces"]:
            surf_dict["d"] = d

            if surf_dict["type"].lower() == "binary2":
                s = Binary2.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "fresnel":
                s = Fresnel.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "pixel2d":
                s = Pixel2D.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "thinlens":
                s = ThinLens.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "zernike":
                s = Zernike.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "rank1":
                s = Rank1.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "diffractedrotation":
                s = DiffractedRotation.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "rotationallysymmetric":
                s = RotationallySymmetric.init_from_dict(surf_dict)
            else:
                raise ValueError(
                    f"Diffractive surface type {surf_dict['type']} not implemented."
                )

            self.surfaces.append(s)
            d_next = surf_dict["d_next"]
            d += d_next
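
A minimal sketch of the file layout this loader expects, using only the standard library. The top-level keys (`info`, `d_sensor`, `sensor_size`, `sensor_res`, `surfaces`) and the per-surface `type` and `d_next` keys come from the code above. Any further per-surface fields depend on each surface class's `init_from_dict` and are omitted here, so this file would probably not load as-is. `surface_positions` is a hypothetical helper that only mirrors how the loader accumulates axial positions.

```python
import json
import os
import tempfile

# Top-level keys read by read_lens_json. Surface-specific parameters are
# deliberately left out (assumption: each surface type needs its own fields).
config = {
    "info": "example two-element diffractive lens",
    "d_sensor": 50.0,
    "sensor_size": [8.0, 8.0],
    "sensor_res": [2000, 2000],
    "surfaces": [
        {"type": "ThinLens", "d_next": 10.0},
        {"type": "Fresnel", "d_next": 40.0},
    ],
}

path = os.path.join(tempfile.mkdtemp(), "example_lens.json")
with open(path, "w") as f:
    json.dump(config, f, indent=4)


def surface_positions(filename):
    """Return (type, d) per surface, where d is the running sum of earlier d_next."""
    with open(filename, "r") as f:
        data = json.load(f)
    positions, d = [], 0.0
    for surf in data["surfaces"]:
        positions.append((surf["type"].lower(), d))
        d += surf["d_next"]
    return positions, d


positions, track = surface_positions(path)
print(positions, track)
```

The first surface always sits at `d = 0`, and each `d_next` is the gap to the following element (or to the sensor for the last surface).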

write_lens_json

write_lens_json(filename)

Write the lens configuration to a JSON file.

Saves all lens parameters including sensor configuration and diffractive surface data to the specified file.

Parameters:

Name Type Description Default
filename str

Output path for the JSON file.

required
Source code in deeplens-src/deeplens/diffraclens.py
def write_lens_json(self, filename):
    """Write the lens configuration to a JSON file.

    Saves all lens parameters including sensor configuration and
    diffractive surface data to the specified file.

    Args:
        filename (str): Output path for the JSON file.
    """
    assert filename.endswith(".json"), "File must be a .json file."

    # Save lens to a file
    data = {}
    data["info"] = self.lens_info if hasattr(self, "lens_info") else "None"
    data["surfaces"] = []
    data["d_sensor"] = round(self.d_sensor.item(), 3)
    data["sensor_size"] = [
        round(float(self.sensor_size[0]), 3),
        round(float(self.sensor_size[1]), 3),
    ]
    data["sensor_res"] = self.sensor_res

    # Save diffractive surfaces
    for i, s in enumerate(self.surfaces):
        surf_dict = {"idx": i + 1}

        if isinstance(s, Pixel2D):
            surf_data = s.surf_dict(filename.replace(".json", "_pixel2d.pth"))
        elif isinstance(s, (Rank1, RotationallySymmetric)):
            surf_data = s.surf_dict(filename.replace(".json", f"_surf{i + 1}.pth"))
        else:
            surf_data = s.surf_dict()

        surf_dict.update(surf_data)

        if i < len(self.surfaces) - 1:
            surf_dict["d_next"] = (
                self.surfaces[i + 1].d.item() - self.surfaces[i].d.item()
            )
        else:
            # Last surface: distance to the sensor. read_lens_json requires
            # d_next on every surface, so the file must always include it.
            surf_dict["d_next"] = round(
                float(self.d_sensor) - self.surfaces[i].d.item(), 3
            )

        data["surfaces"].append(surf_dict)

    # Save data to a file
    with open(filename, "w") as f:
        json.dump(data, f, indent=4)
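
The spacing and side-file naming logic above can be sketched in isolation. `d_next_values` and `sidecar_name` are hypothetical helpers written for this illustration. They mirror the arithmetic in `write_lens_json` on plain floats and strings and do not touch any surface objects.

```python
def d_next_values(surface_d, d_sensor):
    """Gap after each surface: to the next surface, or to the sensor for the last one."""
    values = []
    for i, d in enumerate(surface_d):
        if i < len(surface_d) - 1:
            values.append(surface_d[i + 1] - d)
        else:
            # Only the last gap is rounded, matching the code above.
            values.append(round(d_sensor - d, 3))
    return values


def sidecar_name(filename, index):
    """Side-file name used for Rank1 / RotationallySymmetric surfaces (1-based suffix)."""
    return filename.replace(".json", f"_surf{index + 1}.pth")


gaps = d_next_values([0.0, 10.0], 50.0)
print(gaps, sidecar_name("lens.json", 1))
```

Note that `Pixel2D` surfaces use the fixed suffix `_pixel2d.pth` instead of an indexed one, as shown in the source above.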

__call__

__call__(wave)

Propagate a wave through the lens system.

Source code in deeplens-src/deeplens/diffraclens.py
def __call__(self, wave):
    """Propagate a wave through the lens system."""
    return self.forward(wave)

forward

forward(wave)

Propagate a wave through the diffractive lens system to the sensor.

Sequentially applies phase modulation from each diffractive surface, then propagates the wave to the sensor plane using wave optics.

Parameters:

Name Type Description Default
wave ComplexWave

Input wave field entering the lens system.

required

Returns:

Name Type Description
ComplexWave

Output wave field at the sensor plane.

Source code in deeplens-src/deeplens/diffraclens.py
def forward(self, wave):
    """Propagate a wave through the diffractive lens system to the sensor.

    Sequentially applies phase modulation from each diffractive surface, then propagates
    the wave to the sensor plane using wave optics.

    Args:
        wave (ComplexWave): Input wave field entering the lens system.

    Returns:
        ComplexWave: Output wave field at the sensor plane.
    """
    # Propagate to DOE
    for surf in self.surfaces:
        wave = surf(wave)

    # Propagate to sensor
    wave = wave.prop_to(self.d_sensor.item())

    return wave

render_mono

render_mono(img, wvln=None, ks=None)

Simulate monochromatic lens blur by convolving an image with the point spread function.

Parameters:

Name Type Description Default
img Tensor

Input image. Shape: (B, 1, H, W)

required
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None
ks int

PSF kernel size. When None (default), the full sensor resolution (max(self.sensor_res)) is used.

None

Returns:

Type Description

torch.Tensor: Rendered image after applying lens blur with shape (B, 1, H, W).

Source code in deeplens-src/deeplens/diffraclens.py
def render_mono(self, img, wvln=None, ks=None):
    """Simulate monochromatic lens blur by convolving an image with the point spread function.

    Args:
        img (torch.Tensor): Input image. Shape: (B, 1, H, W)
        wvln (float, optional): Wavelength in µm. When ``None`` (default),
            falls back to ``self.primary_wvln``.
        ks (int, optional): PSF kernel size. When ``None`` (default), the
            full sensor resolution (``max(self.sensor_res)``) is used.

    Returns:
        torch.Tensor: Rendered image after applying lens blur with shape (B, 1, H, W).
    """
    wvln = self.primary_wvln if wvln is None else wvln
    psf = self.psf_infinite(wvln=wvln, ks=ks).unsqueeze(0)  # (1, ks, ks)
    img_render = conv_psf(img, psf)
    return img_render

psf

psf(points, wvln=None, ks=None, recenter=False, upsample_factor=1)

Calculate the monochromatic PSF for one or more point sources.

Off-axis point sources are supported. The signature follows the psf methods of the other lens classes.

Parameters:

Name Type Description Default
points Tensor or list

Point source coordinates, shape [N, 3] or [3]. x, y are normalised to [-1, 1] (relative to the sensor half-width/height); z is the depth in mm (negative; -inf for an object at infinity).

required
wvln float

Wavelength in µm. When None (default), falls back to self.primary_wvln.

None
ks int

PSF kernel size in pixels. When None (default), the full sensor resolution (max(self.sensor_res)) is used.

None
recenter bool

How the ks x ks kernel is centered (both options keep off-axis PSFs centered in the kernel). If True, crop around the measured peak (argmax of the sensor-plane intensity). If False (default), crop around the perspective (pinhole) image of the field point. The lens forms a physically inverted image, but the result is flipped so the PSF is reported in the sensor/source-sign convention (a +x source -> +x).

False
upsample_factor int

Field upsampling factor to meet the Nyquist sampling constraint. Defaults to 1.

1

Returns:

Type Description

torch.Tensor: PSF intensity map, shape [ks, ks] for a single point or [N, ks, ks] for a batch.

Note

A single Angular Spectrum Method (ASM) window is used, so very large off-axis fields can suffer from the shifted-phase/aliasing issue; see "Modeling off-axis diffraction with the least-sampling angular spectrum method".
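
How the normalised `points` coordinates turn into a physical source can be sketched with plain arithmetic. `source_geometry` is a hypothetical helper that reproduces the formulas used in this method. The 8 mm sensor matches the documented default, while the 50 mm focal length is an assumed example value.

```python
import math


def source_geometry(x_norm, y_norm, depth, sensor_size, foclen):
    """Map a normalised field point to a plane-wave tilt or an object-space point."""
    sensor_w, sensor_h = sensor_size
    if math.isinf(depth):
        # Object at infinity: tilted plane wave; the sign is negated so the
        # source images to the inverted side of the sensor.
        theta_x = math.atan(-x_norm * sensor_w / 2 / foclen)
        theta_y = math.atan(-y_norm * sensor_h / 2 / foclen)
        return {"kind": "plane", "theta_x": theta_x, "theta_y": theta_y}
    # Finite depth: scale the sensor half-size by object height / image height.
    scale = -depth / foclen
    return {
        "kind": "point",
        "x": x_norm * scale * sensor_w / 2,
        "y": y_norm * scale * sensor_h / 2,
        "z": depth,
    }


plane = source_geometry(1.0, 0.0, float("-inf"), (8.0, 8.0), 50.0)
point = source_geometry(1.0, 0.0, -1000.0, (8.0, 8.0), 50.0)
print(plane, point)
```

For the finite-depth case, a point at the sensor edge (`x_norm = 1`) with `depth = -1000` mm lies 20 times the sensor half-width off axis, because the magnification factor is `-depth / foclen = 20`.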

Source code in deeplens-src/deeplens/diffraclens.py
def psf(self, points, wvln=None, ks=None, recenter=False, upsample_factor=1):
    """Calculate the monochromatic PSF for one or more point sources.

    Off-axis point sources are supported. The signature follows the
    ``psf`` methods of the other lens classes.

    Args:
        points (torch.Tensor or list): Point source coordinates, shape
            ``[N, 3]`` or ``[3]``. ``x, y`` are normalised to ``[-1, 1]``
            (relative to the sensor half-width/height); ``z`` is the depth
            in mm (negative; ``-inf`` for an object at infinity).
        wvln (float, optional): Wavelength in µm. When ``None`` (default),
            falls back to ``self.primary_wvln``.
        ks (int, optional): PSF kernel size in pixels. When ``None``
            (default), the full sensor resolution
            (``max(self.sensor_res)``) is used.
        recenter (bool, optional): How the ks x ks kernel is centered (both
            options keep off-axis PSFs centered in the kernel). If True,
            crop around the measured peak (argmax of the sensor-plane
            intensity). If False (default), crop around the perspective
            (pinhole) image of the field point. The lens forms a physically
            inverted image, but the result is flipped so the PSF is reported
            in the sensor/source-sign convention (a +x source -> +x).
        upsample_factor (int, optional): Field upsampling factor to meet the
            Nyquist sampling constraint. Defaults to 1.

    Returns:
        torch.Tensor: PSF intensity map, shape ``[ks, ks]`` for a single
        point or ``[N, ks, ks]`` for a batch.

    Note:
        A single Angular Spectrum Method (ASM) window is used, so very large
        off-axis fields can suffer from the shifted-phase/aliasing issue;
        see "Modeling off-axis diffraction with the least-sampling angular
        spectrum method".
    """
    wvln = self.primary_wvln if wvln is None else wvln
    ks = max(int(self.sensor_res[0]), int(self.sensor_res[1])) if ks is None else ks
    if not torch.is_tensor(points):
        points = torch.tensor(points, dtype=torch.float64)
    single_point = points.dim() == 1
    points = points.reshape(-1, 3)

    # Field-plane sampling (high resolution to satisfy Nyquist).
    field_res = [
        self.surfaces[0].res[0] * upsample_factor,
        self.surfaces[0].res[1] * upsample_factor,
    ]
    field_size = [
        self.surfaces[0].res[0] * self.surfaces[0].ps,
        self.surfaces[0].res[1] * self.surfaces[0].ps,
    ]
    sensor_w, sensor_h = self.sensor_size

    psfs = []
    for pt in points:
        x_norm, y_norm, depth = float(pt[0]), float(pt[1]), float(pt[2])

        # Build the incident field for this (possibly off-axis) source.
        if math.isinf(depth):
            # Collimated source: tilted plane wave. The tilt sign is negated
            # so the source physically images to the inverted side (an object
            # at +x focuses to -x), consistent with the finite-depth point
            # source below; the inversion is undone by the flip further down.
            theta_x = math.atan(-x_norm * sensor_w / 2 / self.foclen)
            theta_y = math.atan(-y_norm * sensor_h / 2 / self.foclen)
            inp_wave = ComplexWave.plane_wave(
                wvln=wvln,
                z=0.0,
                phy_size=field_size,
                res=field_res,
                theta_x=theta_x,
                theta_y=theta_y,
            ).to(self.device)
        else:
            # Finite-depth source: spherical wave from the object point.
            scale = -depth / self.foclen  # object height / image height
            obj_x = x_norm * scale * sensor_w / 2
            obj_y = y_norm * scale * sensor_h / 2
            inp_wave = ComplexWave.point_wave(
                point=[obj_x, obj_y, depth],
                phy_size=field_size,
                res=field_res,
                wvln=wvln,
                z=0.0,
            ).to(self.device)

        # Propagate to the sensor and compute intensity. The result is 4D
        # ([B, C, H, W]), as required by F.interpolate below.
        output_wave = self.forward(inp_wave)
        intensity = output_wave.u.abs() ** 2

        # Resample to the sensor pixel pitch.
        factor = output_wave.ps / self.pixel_size
        intensity = F.interpolate(
            intensity,
            scale_factor=(factor, factor),
            mode="bilinear",
            align_corners=False,
        )[0, 0, :, :]

        # Center crop / pad to the sensor resolution. ``sensor_res`` is
        # (W, H) while the intensity tensor is indexed [H, W]; handle each
        # dimension independently so non-square sensors work correctly.
        target_h, target_w = int(self.sensor_res[1]), int(self.sensor_res[0])
        intensity_h, intensity_w = intensity.shape[-2:]
        pad_h = max(target_h - intensity_h, 0)
        pad_w = max(target_w - intensity_w, 0)
        if pad_h > 0 or pad_w > 0:
            intensity = F.pad(
                intensity,
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
                mode="constant",
                value=0,
            )
        intensity_h, intensity_w = intensity.shape[-2:]
        start_h = (intensity_h - target_h) // 2
        start_w = (intensity_w - target_w) // 2
        intensity = intensity[
            start_h : start_h + target_h, start_w : start_w + target_w
        ]

        # The lens forms a physically inverted image (an object at +x focuses
        # to -x). Flip both axes to report the PSF in the sensor / source-sign
        # convention (+x source -> +x sensor position), consistent across the
        # collimated and finite-depth paths.
        intensity = torch.flip(intensity, [0, 1])

        # Crop the ks x ks patch around the PSF location. A diffractive lens
        # has no chief ray to trace, so when ``recenter`` is True the crop
        # center is the measured PSF peak (argmax of the simulated
        # sensor-plane intensity); otherwise the crop center is the
        # perspective (pinhole) image of the source field point.
        if recenter:
            peak = torch.argmax(intensity)
            coord_c_i = int(peak // target_w)
            coord_c_j = int(peak % target_w)
        else:
            # Perspective center: paraxial image of (x_norm, y_norm).
            # +x maps to larger columns and +y to smaller rows, matching the
            # un-inverted sensor-plane intensity.
            coord_c_j = int(round(target_w * (1.0 + x_norm) / 2.0))
            coord_c_i = int(round(target_h * (1.0 - y_norm) / 2.0))
        coord_c_i = min(max(coord_c_i, 0), target_h - 1)
        coord_c_j = min(max(coord_c_j, 0), target_w - 1)
        intensity = F.pad(
            intensity,
            [ks // 2, ks // 2, ks // 2, ks // 2],
            mode="constant",
            value=0,
        )
        psf = intensity[coord_c_i : coord_c_i + ks, coord_c_j : coord_c_j + ks]
        psf = psf / psf.sum()
        psfs.append(diff_float(psf))

    psf_out = torch.stack(psfs, dim=0)
    return psf_out[0] if single_point else psf_out
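
The crop logic at the end of `psf` (with `recenter=False`) can be sketched without tensors. `perspective_center` and `crop_centered_1d` are hypothetical helpers that mirror the index arithmetic above. The second one works on a 1D list to show why padding by `ks // 2` and then slicing `[c : c + ks]` centers the window on index `c`.

```python
def perspective_center(x_norm, y_norm, target_w, target_h):
    """Pixel (row, col) of the pinhole image of a normalised field point."""
    col = int(round(target_w * (1.0 + x_norm) / 2.0))
    row = int(round(target_h * (1.0 - y_norm) / 2.0))
    # Clamp so that field points on the sensor edge stay inside the image.
    row = min(max(row, 0), target_h - 1)
    col = min(max(col, 0), target_w - 1)
    return row, col


def crop_centered_1d(values, center, ks):
    """Zero-pad by ks // 2 on both sides, then take ks samples starting at `center`."""
    padded = [0] * (ks // 2) + list(values) + [0] * (ks // 2)
    return padded[center : center + ks]


center = perspective_center(0.5, 0.5, 2000, 2000)
edge = perspective_center(1.0, -1.0, 2000, 2000)
window = crop_centered_1d(range(10, 20), 4, 3)
border = crop_centered_1d(range(10, 20), 0, 3)
print(center, edge, window, border)
```

Because the padding shifts every sample right by `ks // 2`, a slice that starts at the unpadded index `c` covers `c - ks // 2` onward. For odd `ks` the window is symmetric about `c`, and windows near the border are filled with zeros.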

draw_layout

draw_layout(save_name='./doelens.png')

Draw a 2D layout diagram of the diffractive lens.

Each diffractive surface is drawn as a vertical dashed line at its axial position z = surface.d, and the sensor as a solid rectangle at z = d_sensor.

Parameters:

Name Type Description Default
save_name str

Path to save the figure. Defaults to './doelens.png'.

'./doelens.png'
Source code in deeplens-src/deeplens/diffraclens.py
def draw_layout(self, save_name="./doelens.png"):
    """Draw a 2D layout diagram of the diffractive lens.

    Each diffractive surface is drawn as a vertical dashed line at its axial
    position ``z = surface.d``, and the sensor as a solid rectangle at
    ``z = d_sensor``.

    Args:
        save_name (str, optional): Path to save the figure. Defaults to './doelens.png'.
    """
    fig, ax = plt.subplots(figsize=(12, 4))

    default_l = float(max(self.sensor_size))

    # Draw each diffractive surface as a vertical dashed line.
    for i, surf in enumerate(self.surfaces):
        d = float(surf.d)
        surf_l = float(getattr(surf, "w", default_l))
        ax.plot(
            [d, d], [-surf_l / 2, surf_l / 2], "orange", linestyle="--", dashes=[1, 1]
        )
        ax.text(
            d, surf_l / 2 * 1.08, f"{type(surf).__name__}\n(z={d:.1f} mm)",
            ha="center", va="bottom", fontsize=8,
        )

    # Draw the sensor plane as a thin rectangle.
    d_sensor = float(self.d_sensor)
    sensor_l = float(self.sensor_size[1])
    width = max(0.01 * d_sensor, 0.2)
    rect = plt.Rectangle(
        (d_sensor - width / 2, -sensor_l / 2), width, sensor_l,
        facecolor="none", edgecolor="black", linewidth=1,
    )
    ax.add_patch(rect)
    ax.text(
        d_sensor, sensor_l / 2 * 1.08, f"Sensor\n(z={d_sensor:.1f} mm)",
        ha="center", va="bottom", fontsize=8,
    )

    # Optical axis.
    ax.plot([0, d_sensor], [0, 0], "k-", linewidth=0.5, alpha=0.3)

    ax.set_xlabel("z [mm]")
    ax.set_yticks([])
    ax.margins(x=0.05, y=0.25)
    fig.savefig(save_name, dpi=300, bbox_inches="tight")
    plt.close(fig)

draw_psf

draw_psf(depth=None, ks=None, save_name='./psf_doelens.png', log_scale=True, eps=0.0001)

Draw on-axis RGB PSF.

Computes and saves a visualization of the RGB PSF for a given depth.

Parameters:

Name Type Description Default
depth float

Depth of the point source [mm]. When None (default), falls back to self.obj_depth.

None
ks int

Size of the PSF kernel in pixels. When None (default), the full sensor resolution (max(self.sensor_res)) is used.

None
save_name str

Path to save the PSF image. Defaults to './psf_doelens.png'.

'./psf_doelens.png'
log_scale bool

If True, display PSF in log scale. Defaults to True.

True
eps float

Small value for log scale to avoid log(0). Defaults to 1e-4.

0.0001
Source code in deeplens-src/deeplens/diffraclens.py
def draw_psf(
    self,
    depth=None,
    ks=None,
    save_name="./psf_doelens.png",
    log_scale=True,
    eps=1e-4,
):
    """Draw on-axis RGB PSF.

    Computes and saves a visualization of the RGB PSF for a given depth.

    Args:
        depth (float, optional): Depth of the point source. When ``None``
            (default), falls back to ``self.obj_depth``.
        ks (int, optional): Size of the PSF kernel in pixels. When ``None``
            (default), the full sensor resolution
            (``max(self.sensor_res)``) is used.
        save_name (str, optional): Path to save the PSF image. Defaults to './psf_doelens.png'.
        log_scale (bool, optional): If True, display PSF in log scale. Defaults to True.
        eps (float, optional): Small value for log scale to avoid log(0). Defaults to 1e-4.
    """
    depth = self.obj_depth if depth is None else depth
    psf_rgb = self.psf_rgb(points=[0.0, 0.0, depth], ks=ks)

    if log_scale:
        psf_rgb = torch.log10(psf_rgb + eps)
        psf_rgb = (psf_rgb - psf_rgb.min()) / (psf_rgb.max() - psf_rgb.min())
        save_name = save_name.replace(".png", "_log.png")

    save_image(psf_rgb.unsqueeze(0), save_name, normalize=True)
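
The log-scale branch can be sketched on a plain list of values. `log_normalize` is a hypothetical helper that mirrors the two tensor operations above (log10 with an `eps` offset, then min-max scaling to [0, 1]), and the sample PSF values are made up for illustration.

```python
import math


def log_normalize(psf, eps=1e-4):
    """Apply log10(p + eps), then rescale the result to the range [0, 1]."""
    logged = [math.log10(p + eps) for p in psf]
    lo, hi = min(logged), max(logged)
    return [(v - lo) / (hi - lo) for v in logged]


psf = [0.0, 0.001, 0.999]  # made-up intensities that sum to 1
scaled = log_normalize(psf)

# The output file name gains a "_log" suffix when log scaling is enabled.
log_name = "./psf_doelens.png".replace(".png", "_log.png")
print(scaled, log_name)
```

With `eps = 1e-4` the darkest displayable level is 10^-4, so faint side lobes that would vanish on a linear scale remain visible. A perfectly uniform input would make `hi - lo` zero, a case this sketch does not guard against.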

get_optimizer

get_optimizer(lr, optim_surf_ls=None)

Build an Adam optimizer over the trainable diffractive surfaces.

Parameters:

Name Type Description Default
lr float

Learning rate.

required
optim_surf_ls list[int]

Indices of the surfaces to optimize. If None, all diffractive surfaces are optimized.

None

Returns:

Type Description

torch.optim.Optimizer: Adam optimizer over the selected surfaces' phase parameters.

Source code in deeplens-src/deeplens/diffraclens.py
def get_optimizer(self, lr, optim_surf_ls=None):
    """Build an Adam optimizer over the trainable diffractive surfaces.

    Args:
        lr (float): Learning rate.
        optim_surf_ls (list[int], optional): Indices of the surfaces to
            optimize. If None, all diffractive surfaces are optimized.

    Returns:
        torch.optim.Optimizer: Adam optimizer over the selected surfaces'
        phase parameters.
    """
    if optim_surf_ls is None:
        optim_surf_ls = list(range(len(self.surfaces)))

    params = []
    for idx in optim_surf_ls:
        params += self.surfaces[idx].get_optimizer_params(lr=lr)

    return torch.optim.Adam(params)