CMT: Directional Coupler

CMT: Directional Coupler

GDS Layout

We use the generic PDK that ships with GDSFactory. First, we create the layouts of a directional coupler and of an MZI lattice:

import gdsfactory as gf
import jax.numpy as jnp
import matplotlib.pyplot as plt
import sax

# Top-level component holding one directional coupler
# (gap=0.3, length=5, dy=7); the instance's ports are re-exported.
coupler_cell = gf.Component()
r = coupler_cell.add_ref(gf.c.coupler(length=5, gap=0.3, dy=7))
coupler_cell.add_ports(r)
# Annotate the ports on the layout, then render it.
coupler_cell.draw_ports()
coupler_cell.plot()


# MZI lattice with two couplers (lengths 5 and 10, gaps 0.3 and 0.4)
# and a single arm-length difference of 10; its ports are re-exported.
mzi_cell = gf.Component()
r = mzi_cell.add_ref(
    gf.components.mzi_lattice(
        coupler_lengths=(5, 10),
        coupler_gaps=(0.3, 0.4),
        delta_lengths=(10,),
    )
)
mzi_cell.add_ports(r)
# Annotate the ports on the layout, then render it.
mzi_cell.draw_ports()
mzi_cell.plot()
../_images/dc060b8bc6ca43f9972d49483e13b36f14d8a7f873e6e195f801010f0044c9b2.png ../_images/61033ff773df43e5e197b974b7be53547aeeda6343bd4ebc85cfc1acc831d0a9.png

TODO: add an MZI simulation (the MZI lattice layout is shown above but is not yet modeled with SAX).

SAX Models

import jax
import jax.numpy as jnp
import sax
import xarray as xr
from jaxtyping import ArrayLike
from rich import print as rprint


def waveguide_1st_order_dispersion(
    wl: float = 1.55,
    wl0: float = 1.55,
    neff: float = 2.34,
    ng: float = 3.4,
    length: float = 10.0,
    loss_db_per_cm: float = 0.5,
) -> sax.SDict:
    """Straight waveguide with first-order (linear) effective-index dispersion.

    The effective index is expanded linearly around ``wl0``::

        n(wl) = neff - (wl - wl0) * (ng - neff) / wl0

    i.e. the slope is chosen so that ``ng = neff - wl0 * dn/dwl`` at ``wl0``.

    Args:
        wl: wavelength in microns (scalar or array).
        wl0: reference wavelength in microns at which ``neff`` and ``ng`` hold.
        neff: effective index at ``wl0``.
        ng: group index at ``wl0``.
        length: length of the waveguide in microns.
        loss_db_per_cm: propagation loss in dB/cm.

    Returns:
        A reciprocal two-port S-dictionary with the ``o1 <-> o2`` transmission.
    """
    # Linear dispersion of the effective index around the reference wavelength.
    slope = (ng - neff) / wl0
    n_at_wl = neff - (wl - wl0) * slope

    # dB/cm -> dB/um (1 cm = 1e4 um); total dB loss -> field amplitude 10^(-dB/20).
    loss_db = loss_db_per_cm * 1e-4 * length
    field_amplitude = jnp.asarray(10 ** (-loss_db / 20), dtype=complex)

    phase = 2 * jnp.pi * n_at_wl * length / wl
    s21 = field_amplitude * jnp.exp(1j * phase)
    return sax.reciprocal({("o1", "o2"): s21})


with jax.ensure_compile_time_eval():
    # `xr.load_dataarray` opens the file, reads it fully into memory and closes
    # it again; `xr.open_dataarray(...).load()` would keep the file handle open.
    # A trailing "kappa" dimension with the single coordinate "kappa" is
    # appended so the interpolation result can be indexed with ["kappa"]
    # (as done in `_interpolate_kappa`).
    xarr_dc = xr.load_dataarray("data/directional_coupler.nc").expand_dims(
        {"kappa": ["kappa"]}, -1
    )


def _interpolate_kappa(xarr: xr.DataArray, **kwargs: ArrayLike) -> jnp.ndarray:
    """Interpolate the "kappa" entry of ``xarr`` at broadcast query points.

    Every coordinate of ``xarr`` other than "kappa" is an interpolation axis
    and must be supplied as a keyword argument; extra keywords are ignored.
    The query values are broadcast against each other, flattened for
    ``sax.interpolate_xarray``, and the result is reshaped back to the
    broadcast shape.

    Args:
        xarr: data array whose coordinates define the interpolation grid.
        **kwargs: query values, one per non-"kappa" coordinate name.

    Returns:
        Interpolated kappa values with the broadcast shape of the inputs.

    Raises:
        ValueError: if a required coordinate is missing from ``kwargs``.
    """
    axes = [name for name in xarr.coords if name != "kappa"]

    absent = [name for name in axes if name not in kwargs]
    if absent:
        raise ValueError(f"Missing required interpolation inputs: {absent}")

    # Broadcast all queries to a common shape, then flatten for interpolation.
    queries = jnp.broadcast_arrays(*(jnp.asarray(kwargs[name]) for name in axes))
    out_shape = queries[0].shape
    flat_queries = dict(zip(axes, (q.ravel() for q in queries), strict=True))

    values = sax.interpolate_xarray(xarr, **flat_queries)["kappa"]
    return values.reshape(out_shape)

def directional_coupler_no_phase(
    *,
    wl: float = 1.3,
    coupler_length: float = 10.0,
    gap: float = 0.5,
    offset: float = 20,
    bend_radius: float = 25,
) -> sax.SDict:
    r"""Coupling-region model without propagation phase.

    Semi-analytical model developed by GDSFactory; use at your own risk.
    The field coupling coefficient ``kappa`` is interpolated from the
    tabulated data in ``xarr_dc``. The through coefficient is
    ``tau = sqrt(1 - kappa**2)``. Cross-port terms carry a factor ``1j``.
    No propagation phase is added.

    Args:
        wl: wavelength [µm]; documented range 1.5 to 1.6 µm.
            NOTE(review): the default 1.3 µm lies outside this range.
        coupler_length: length of the coupling region [µm]; 0 to 100 µm.
        gap: gap between the two waveguides [µm]; 0.05 to 1.5 µm.
        offset: offset between the two waveguides [µm]; 5 to 100 µm.
        bend_radius: bend radius [µm]; 5 to 100 µm.

    Returns:
        Reciprocal 4-port S-dictionary: o1<->o4 and o2<->o3 carry ``tau``;
        o1<->o3 and o2<->o4 carry ``1j * kappa``.
    """
    # Map model arguments onto the coordinate names of the tabulated data.
    coupling = _interpolate_kappa(
        xarr=xarr_dc,
        wavelength=wl,
        radius=bend_radius,
        gap=gap,
        length_x=coupler_length,
        v_offset=offset,
    )
    through = jnp.sqrt(1 - jnp.array(coupling) ** 2)
    cross = 1j * coupling

    return sax.reciprocal(
        {
            ("o1", "o4"): through,
            ("o1", "o3"): cross,
            ("o2", "o4"): cross,
            ("o2", "o3"): through,
        }
    )


def directional_coupler(
    *,
    wl: float = 1.3,
    length: float = 10.0,
    gap: float = 0.5,
    offset: float = 20,
    bend_radius: float = 25,
) -> sax.SDict:
    r"""Directional coupler model.

    Semi-analytical model for directional couplers developed by GDSFactory.
    Use at your own risk.

    The model is a SAX circuit. A phase-free coupling region
    (``directional_coupler_no_phase``) has a straight waveguide
    (``waveguide_1st_order_dispersion``) attached to each of its four ports.
    Each straight is ``length / 2`` plus the arc length of an S-bend made of
    two circular arcs of radius ``bend_radius`` with total lateral
    displacement ``offset``. The straights therefore supply the propagation
    phase that the coupling region omits.

    Args:
        wl: wavelength [µm]; scalar or array.
        length: length of the coupling region [µm].
        gap: gap between the two waveguides [µm].
        offset: offset between the two waveguides [µm]. It sets the S-bend
            lengths and is also forwarded to the coupling-region model.
        bend_radius: bend radius of the S-bends [µm].

    Returns:
        Reciprocal 4-port S-dictionary with the through (o1-o4, o2-o3) and
        cross (o1-o3, o2-o4) transmissions.
    """

    def sbend_length(radius, dy):
        # Two arcs of radius `radius`, each turning by acos(1 - dy / (2 * radius)),
        # displace the waveguide laterally by `dy` in total.
        # Kept as a JAX value (no float()) so traced inputs (jit/grad) work.
        return 2 * radius * jnp.arccos(1 - dy / 2 / radius)

    coupler_circuit, _ = sax.circuit(
        netlist={
            "instances": {
                "s1": "straight",
                "s2": "straight",
                "s3": "straight",
                "s4": "straight",
                "dc": "coupling_area",
            },
            "connections": {
                "s1,o1": "dc,o1",
                "s2,o1": "dc,o2",
                "s3,o1": "dc,o3",
                "s4,o1": "dc,o4",
            },
            "ports": {
                "o1": "s1,o2",
                "o2": "s2,o2",
                "o3": "s3,o2",
                "o4": "s4,o2",
            },
        },
        models={
            "straight": waveguide_1st_order_dispersion,
            "coupling_area": directional_coupler_no_phase,
        },
    )

    # All four arms are identical: half the coupling region plus one S-bend.
    arm_length = length / 2 + sbend_length(bend_radius, offset)

    s = coupler_circuit(
        wl=wl,
        dc={
            "coupler_length": length,
            "gap": gap,
            # Previously omitted, so kappa was always interpolated at the
            # coupling-region default offset regardless of `offset`.
            "offset": offset,
            "bend_radius": bend_radius,
        },
        s1={"length": arm_length},
        s2={"length": arm_length},
        s3={"length": arm_length},
        s4={"length": arm_length},
    )

    return sax.reciprocal(
        {
            ("o1", "o4"): s["o1", "o4"],
            ("o1", "o3"): s["o1", "o3"],
            ("o2", "o4"): s["o2", "o4"],
            ("o2", "o3"): s["o2", "o3"],
        }
    )
# Cross-port power |S(o1, o3)|^2 of the coupler model versus wavelength for
# several gaps: each gap gets its own gdsfactory layout, whose netlist is
# turned into a SAX circuit with `directional_coupler` bound to "coupler".
wl = jnp.linspace(1.5, 1.6, 1000)  # wavelength sweep [µm], shared by all gaps

plt.figure()
for g in [0.3, 0.4, 0.5]:
    cell = gf.Component()
    r = cell << gf.c.coupler(gap=g)
    cell.add_ports(r.ports)

    netlist_dict = cell.get_netlist(recursive=False)
    circuit_sax_obj, _ = sax.circuit(
        netlist_dict, models={"coupler": directional_coupler}
    )
    s = circuit_sax_obj(wl=wl)
    plt.plot(wl, jnp.abs(s["o1", "o3"]) ** 2, label=f"gap={g}")
plt.xlabel("wavelength [µm]")
plt.ylabel("cross-port power |S(o1, o3)|^2")
plt.legend()
<matplotlib.legend.Legend at 0x152999af0>
../_images/ef910df69fc2f075fa4cbf16d03021e254403c73a772f316a8cef612e4dd7325.png