Task-Driven Lens Design
Design a lens from scratch using only a downstream task loss (here, image classification), with no classical optical design objectives such as spot size, PSF, or MTF. The example explores unconventional lens designs optimized purely for task performance.
Script: 4_tasklens_img_classi.py
Reference: Xinge Yang, Qiang Fu, Yunfeng Nie, Wolfgang Heidrich, "Task-Driven Lens Design," Optics Express 34(5), 8961–8975 (2026). doi:10.1364/OE.588912
Overview
A GeoLens renders images through the differentiable camera pipeline, and a pretrained classifier (e.g., ResNet via timm) provides the training signal. The lens parameters are optimized to minimize the classification loss, which serves as a differentiable proxy for classification accuracy, rather than to improve traditional image quality metrics.
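Written as an objective, this setup amounts to roughly the following. The symbols are notation introduced here for illustration and are not taken from the reference: θ denotes the lens parameters, R_θ the differentiable rendering through the lens, f the pretrained classifier, and (x, y) an image and its label drawn from the dataset D.

```latex
\theta^{*} = \arg\min_{\theta} \; \mathbb{E}_{(x, y) \sim \mathcal{D}}
\left[ \mathcal{L}_{\mathrm{CE}}\big( f(R_{\theta}(x)),\, y \big) \right]
```

No spot-size, PSF, or MTF term appears in this objective. The optics receive gradients only through the classification loss.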
import timm
import torch
from end2end_imaging import GeoLens

# Load the lens and a pretrained classifier
lens = GeoLens(filename="datasets/lenses/cellphone/cellphone80deg.json")
classifier = timm.create_model("resnet50", pretrained=True)
classifier.eval()  # only the lens is updated in this snippet

# Optimize the lens parameters against the classification loss.
# num_steps, img_batch, and labels are placeholders for the training config and data loader.
opt_lens = torch.optim.Adam(lens.parameters(), lr=1e-4)
for step in range(num_steps):
    opt_lens.zero_grad()

    # Render through the lens, then classify the rendered image
    img_degraded = lens.render(img_batch, method="psf_patch")
    logits = classifier(img_degraded)
    loss = torch.nn.functional.cross_entropy(logits, labels)
    loss.backward()
    opt_lens.step()
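The loop above depends on the lens file, pretrained weights, and a dataset. To see the same gradient path in isolation, here is a minimal sketch that runs with PyTorch alone. Everything in it is an assumption made for illustration: `ToyBlurLens` is a hypothetical stand-in (a single learnable Gaussian blur width), not GeoLens; the classifier is a tiny randomly initialized network, not a pretrained ResNet; and the images and labels are random tensors.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyBlurLens(nn.Module):
    """Hypothetical stand-in for a lens: one learnable Gaussian blur width."""

    def __init__(self, init_sigma: float = 2.0, kernel_size: int = 9):
        super().__init__()
        # Optimize log(sigma) so that sigma stays positive.
        self.log_sigma = nn.Parameter(torch.log(torch.tensor(init_sigma)))
        self.kernel_size = kernel_size

    def psf(self) -> torch.Tensor:
        """Return a normalized 2-D Gaussian PSF of shape (k, k)."""
        half = self.kernel_size // 2
        x = torch.arange(-half, half + 1, dtype=torch.float32)
        g = torch.exp(-x**2 / (2 * self.log_sigma.exp() ** 2))
        kernel = torch.outer(g, g)
        return kernel / kernel.sum()

    def render(self, img: torch.Tensor) -> torch.Tensor:
        """Blur each channel of a (B, C, H, W) batch with the PSF."""
        channels = img.shape[1]
        weight = self.psf()[None, None].repeat(channels, 1, 1, 1)
        return F.conv2d(img, weight, padding=self.kernel_size // 2, groups=channels)


torch.manual_seed(0)

# Tiny frozen classifier (assumption: stands in for a pretrained network).
classifier = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
classifier.eval()
for p in classifier.parameters():
    p.requires_grad_(False)

lens = ToyBlurLens()
opt_lens = torch.optim.Adam(lens.parameters(), lr=1e-2)

# Random data (assumption: replaces a real image batch and its labels).
img_batch = torch.rand(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))

sigma_before = lens.log_sigma.exp().item()
for step in range(5):
    opt_lens.zero_grad()
    logits = classifier(lens.render(img_batch))
    loss = F.cross_entropy(logits, labels)
    loss.backward()  # gradients reach log_sigma through the frozen classifier
    opt_lens.step()
sigma_after = lens.log_sigma.exp().item()
```

The classifier's weights are frozen, so the only parameter the optimizer can change is the blur width. With random data the resulting value of sigma is not meaningful. The sketch shows only that the task loss alone is enough to drive an optical parameter.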