import matplotlib.pyplot as plt
import numpy as np


def plot_data_iq_plane(data, groupby=None):
    """Scatter the complex ``data["data"]`` variable on the I/Q plane.

    Parameters
    ----------
    data
        Dataset-like object with a complex ``"data"`` variable and
        ``"frequency"`` / ``"amplitude"`` entries, supporting
        ``data.sel({dim: value})``. Presumably an ``xarray.Dataset``; the
        name-based broadcasting in ``_broadcast_to_data`` relies on the
        xarray ``broadcast_like`` / ``transpose`` API.
    groupby
        ``"amplitude"``: one labelled trace per amplitude value, with points
        coloured by frequency in MHz and the label at the last point.
        ``"frequency"``: one labelled trace per frequency value, with points
        coloured by amplitude and the label at the first point.
        Any other value (including ``None``): every point in a single
        scatter, coloured by frequency in MHz.

    A new figure is created with ``plt.subplots``; nothing is returned.
    """
    _, ax = plt.subplots(figsize=(10, 8))

    # --- Setup group-specific config ---
    if groupby == "amplitude":
        group_values = data["amplitude"].values
        group_dim = "amplitude"
        color_dim = "frequency"
        color_label = "Frequency (MHz)"
        label_fmt = lambda val: f"{val:.1f} a.u."  # noqa: E731
        label_pos = -1  # label at last point
        color_scale = 1e6  # convert Hz → MHz

    elif groupby == "frequency":
        group_values = data["frequency"].values
        group_dim = "frequency"
        color_dim = "amplitude"
        color_label = "Amplitude"
        label_fmt = lambda val: f"{val / 1e6:.2f} MHz"  # noqa: E731
        label_pos = 0  # label at first point
        color_scale = 1

    else:
        # Default: one scatter of every point, colored by frequency. The
        # colours are broadcast by dimension name, so they line up with the
        # I/Q points whatever the axis order of data["data"] is.
        iq = data["data"].values.ravel()
        freqs = _broadcast_to_data(data, "frequency").ravel() / 1e6

        sc = ax.scatter(np.real(iq), np.imag(iq), c=freqs, s=35, alpha=0.8)
        _finalize_plot(ax, sc, "I/Q Data Scatter Plot", "Frequency (MHz)")
        return

    # Stays None when there are no groups, so _finalize_plot skips the
    # colorbar instead of the function raising NameError.
    sc = None

    # --- Main plotting loop ---
    for val in group_values:
        data_slice = data.sel({group_dim: val})
        # Flatten so that label_pos always selects a single point, even if
        # the slice still has more than one dimension.
        iq = data_slice["data"].values.ravel()
        i_vals = np.real(iq)
        q_vals = np.imag(iq)
        colors = _broadcast_to_data(data_slice, color_dim).ravel() / color_scale

        sc = ax.scatter(i_vals, q_vals, c=colors, s=35, alpha=0.8)

        if iq.size == 0:
            continue  # no point to attach a label to

        # Place label at specified point
        ax.text(
            i_vals[label_pos],
            q_vals[label_pos],
            label_fmt(val),
            fontsize=12,
            color="cyan",
            va="center",
            ha="left",
            bbox=dict(facecolor="black", alpha=0.5, edgecolor="none", boxstyle="round,pad=0.2"),
        )

    # --- Finalize plot ---
    _finalize_plot(ax, sc, "I/Q Data Scatter Plot", color_label)


def _broadcast_to_data(ds, name):
    """Return ``ds[name]`` broadcast against ``ds["data"]`` as a NumPy array.

    Broadcasting is done by dimension name. The result is then transposed
    into ``ds["data"]``'s own axis order, so element ``k`` of the flattened
    result belongs to element ``k`` of the flattened I/Q values. A plain
    ``np.broadcast_to`` only guarantees this when *name* is the trailing
    axis.
    """
    target = ds["data"]
    return ds[name].broadcast_like(target).transpose(*target.dims).values


def _finalize_plot(ax, sc, title, cbar_label=None):
    """Decorate *ax* with colorbar, I/Q axis labels, title, aspect and grid.

    Parameters
    ----------
    ax
        The matplotlib Axes to decorate.
    sc
        Mappable returned by ``ax.scatter``, or ``None``. A colorbar is drawn
        only when *sc* is not ``None`` and *cbar_label* is truthy.
    title
        Text used as the Axes title.
    cbar_label
        Colorbar label, drawn rotated by 270 degrees.
    """
    if sc is not None and cbar_label:
        # Use the figure that owns ax rather than pyplot's global "current"
        # figure, so the colorbar lands in the right place even if another
        # figure has become current in the meantime.
        cbar = ax.figure.colorbar(sc, ax=ax)
        cbar.set_label(cbar_label, rotation=270, labelpad=15)

    ax.set_xlabel("I (Real Part)")
    ax.set_ylabel("Q (Imaginary Part)")
    ax.set_title(title)
    # Equal scaling keeps circles in the I/Q plane circular.
    ax.set_aspect("equal", adjustable="box")
    ax.grid(True, linestyle="--", alpha=0.6)
