CMT: Directional Coupler
GDS Layout
We use the generic PDK shipped with GDSFactory. First, we create the layouts of a directional coupler and of an MZI lattice:
import gdsfactory as gf
import jax.numpy as jnp
import matplotlib.pyplot as plt
import sax
# Directional coupler: place the PDK coupler in a fresh top-level component,
# re-export its ports, then draw port markers and plot the cell.
coupler_cell = gf.Component()
dc_ref = coupler_cell << gf.c.coupler(gap=0.3, length=5, dy=7)
coupler_cell.add_ports(dc_ref)
coupler_cell.draw_ports()
coupler_cell.plot()

# MZI lattice: coupler lengths (5, 10), coupler gaps (0.3, 0.4) and a single
# arm-length difference of 10 (gdsfactory units, i.e. µm).
mzi_cell = gf.Component()
lattice_ref = mzi_cell << gf.components.mzi_lattice(
    coupler_lengths=(5, 10),
    coupler_gaps=(0.3, 0.4),
    delta_lengths=(10,),
)
mzi_cell.add_ports(lattice_ref)
mzi_cell.draw_ports()
mzi_cell.plot()


TODO: add MZI layout
SAX Models
import jax
import jax.numpy as jnp
import sax
import xarray as xr
from jaxtyping import ArrayLike
from rich import print as rprint
def waveguide_1st_order_dispersion(
    wl: float = 1.55,
    wl0: float = 1.55,
    neff: float = 2.34,
    ng: float = 3.4,
    length: float = 10.0,
    loss_db_per_cm: float = 0.5,
) -> sax.SDict:
    """Straight waveguide with first-order (linear) effective-index dispersion.

    The effective index is expanded linearly around ``wl0``:
    ``n_eff(wl) = neff - (wl - wl0) * (ng - neff) / wl0``.

    Args:
        wl: wavelength in microns.
        wl0: reference wavelength in microns at which ``neff`` is specified.
        neff: effective index at ``wl0``.
        ng: group index.
        length: length of the waveguide in microns.
        loss_db_per_cm: propagation loss in dB/cm.

    Returns:
        Reciprocal S-dictionary whose only non-zero entry is the
        ("o1", "o2") transmission (and its reverse).
    """
    # From ng = neff - wl0 * dneff/dwl, the slope is dneff/dwl = -(ng - neff) / wl0.
    neff_at_wl = neff - (wl - wl0) * ((ng - neff) / wl0)

    # 1 cm = 1e4 µm. The dB figure refers to power, so the field amplitude uses /20.
    loss_db_per_um = loss_db_per_cm * 1e-4
    field_amplitude = jnp.asarray(10 ** (-loss_db_per_um * length / 20), dtype=complex)

    phase = 2 * jnp.pi * neff_at_wl * length / wl
    return sax.reciprocal({("o1", "o2"): field_amplitude * jnp.exp(1j * phase)})
with jax.ensure_compile_time_eval():
    # Coupling table for the directional coupler, loaded eagerly at import time.
    # xr.load_dataarray reads the data into memory and closes the file, whereas
    # open_dataarray(...).load() keeps the underlying file handle open.
    # A trailing singleton "kappa" dimension is appended so the interpolated
    # result can be indexed with ["kappa"] (as done in _interpolate_kappa).
    xarr_dc = xr.load_dataarray("data/directional_coupler.nc").expand_dims(
        {"kappa": ["kappa"]}, -1
    )
def _interpolate_kappa(xarr: xr.DataArray, **kwargs: ArrayLike) -> jnp.ndarray:
    """Interpolate the "kappa" entry of ``xarr`` at arbitrarily shaped query points.

    Every coordinate of ``xarr`` other than "kappa" is treated as an
    interpolation axis and must be passed as a keyword argument of the same
    name. The inputs are broadcast against each other, flattened for
    ``sax.interpolate_xarray`` and the result is reshaped back to the
    broadcast shape.

    Raises:
        ValueError: if any interpolation coordinate is missing from ``kwargs``.
    """
    interp_dims = [name for name in xarr.coords if name != "kappa"]

    missing = [name for name in interp_dims if name not in kwargs]
    if missing:
        raise ValueError(f"Missing required interpolation inputs: {missing}")

    grids = jnp.broadcast_arrays(*(jnp.asarray(kwargs[name]) for name in interp_dims))
    out_shape = grids[0].shape

    flat_inputs = dict(zip(interp_dims, (grid.ravel() for grid in grids), strict=True))
    values = sax.interpolate_xarray(xarr, **flat_inputs)["kappa"]
    return values.reshape(out_shape)
def directional_coupler_no_phase(
    *,
    wl: float = 1.3,
    coupler_length: float = 10.0,
    gap: float = 0.5,
    offset: float = 20,
    bend_radius: float = 25,
) -> sax.SDict:
    r"""Coupling region of a directional coupler, without propagation phase.

    Semi-analytical model developed by GDSFactory. Use at your own risk.

    The field coupling coefficient ``kappa`` is interpolated from the
    ``xarr_dc`` table. The through coefficient is ``tau = sqrt(1 - kappa**2)``
    (lossless). Port pairs o1-o4 and o2-o3 carry ``tau``; o1-o3 and o2-o4
    carry ``1j * kappa``. No propagation phase is included here.

    Args:
        wl: wavelength [µm]; between 1.5 and 1.6 µm.
            NOTE(review): the default 1.3 lies outside this range — confirm.
        coupler_length: length of the straight coupling section [µm];
            between 0 and 100 µm.
        gap: gap between the two waveguides [µm]; between 0.05 and 1.5 µm.
        offset: offset between the two waveguides [µm]; between 5 and 100 µm.
        bend_radius: bend radius [µm]; between 5 and 100 µm.
    """
    kappa = _interpolate_kappa(
        xarr=xarr_dc,
        wavelength=wl,
        radius=bend_radius,
        gap=gap,
        length_x=coupler_length,
        v_offset=offset,
    )
    through = jnp.sqrt(1 - jnp.array(kappa) ** 2)
    cross = 1j * kappa
    return sax.reciprocal(
        {
            ("o1", "o4"): through,
            ("o1", "o3"): cross,
            ("o2", "o4"): cross,
            ("o2", "o3"): through,
        }
    )
def directional_coupler(
    *,
    wl: float = 1.3,
    length: float = 10.0,
    gap: float = 0.5,
    offset: float = 20,
    bend_radius: float = 25,
) -> sax.SDict:
    r"""Directional coupler model: coupling region plus four waveguide arms.

    Semi-analytical model for directional couplers developed by GDSFactory.
    Use at your own risk.

    The coupler is assembled as a SAX circuit. A phase-free coupling region
    (``directional_coupler_no_phase``) has a straight waveguide
    (``waveguide_1st_order_dispersion``) attached to each of its four ports.
    Each straight is ``length / 2`` plus the arc length of one S-bend, so the
    propagation phase of the coupling section and the bends is accounted for.

    Args:
        wl: wavelength [µm].
        length: length of the straight coupling section [µm].
        gap: gap between the two waveguides [µm].
        offset: lateral offset spanned by each S-bend [µm]. It sets the S-bend
            arc length and is also forwarded to the coupling-region lookup.
        bend_radius: bend radius of the S-bends [µm].

    Returns:
        Reciprocal S-dictionary with the o1-o4 / o2-o3 (through) and
        o1-o3 / o2-o4 (cross) transmissions.
    """
    coupler_length = length

    def sbend_length(radius, offset):
        # Two circular arcs of angle theta, each rising radius * (1 - cos(theta)) = offset / 2.
        # Kept as a JAX value (no float() cast) so the model stays traceable/differentiable.
        return 2 * radius * jnp.arccos(1 - offset / 2 / radius)

    # NOTE: the circuit is rebuilt on every call; hoist it if this becomes a hot path.
    coupler_circuit, _ = sax.circuit(
        netlist={
            "instances": {
                "s1": "straight",
                "s2": "straight",
                "s3": "straight",
                "s4": "straight",
                "dc": "coupling_area",
            },
            "connections": {
                "s1,o1": "dc,o1",
                "s2,o1": "dc,o2",
                "s3,o1": "dc,o3",
                "s4,o1": "dc,o4",
            },
            "ports": {
                "o1": "s1,o2",
                "o2": "s2,o2",
                "o3": "s3,o2",
                "o4": "s4,o2",
            },
        },
        models={
            "straight": waveguide_1st_order_dispersion,
            "coupling_area": directional_coupler_no_phase,
        },
    )

    # All four arms are identical: half the coupling section plus one S-bend.
    arm_length = coupler_length / 2 + sbend_length(bend_radius, offset)
    s = coupler_circuit(
        wl=wl,
        dc={
            "coupler_length": coupler_length,
            "gap": gap,
            # Forward the caller's offset; otherwise the kappa lookup silently
            # uses directional_coupler_no_phase's default offset.
            "offset": offset,
            "bend_radius": bend_radius,
        },
        s1={"length": arm_length},
        s2={"length": arm_length},
        s3={"length": arm_length},
        s4={"length": arm_length},
    )
    return sax.reciprocal(
        {
            ("o1", "o4"): s["o1", "o4"],
            ("o1", "o3"): s["o1", "o3"],
            ("o2", "o4"): s["o2", "o4"],
            ("o2", "o3"): s["o2", "o3"],
        }
    )
# Sweep the coupler gap and plot the cross-port power |S(o1, o3)|^2 versus wavelength.
wl = jnp.linspace(1.5, 1.6, 1000)  # wavelength grid [µm]; identical for every gap
plt.figure()
for g in [0.3, 0.4, 0.5]:
    cell = gf.Component()
    r = cell << gf.c.coupler(gap=g)
    cell.add_ports(r.ports)
    netlist_dict = cell.get_netlist(recursive=False)
    circuit_sax_obj, _ = sax.circuit(netlist_dict, models={"coupler": directional_coupler})
    s = circuit_sax_obj(wl=wl)
    plt.plot(wl, jnp.abs(s["o1", "o3"]) ** 2, label=f"gap={g}")
plt.xlabel("wavelength [µm]")
plt.ylabel("|S(o1, o3)|^2")
plt.legend()
<matplotlib.legend.Legend at 0x152999af0>
