"""Utility functions for executing Schedules on Qblox hardware."""

from __future__ import annotations

import copy
import itertools
import json
import math

import numpy as np

from qblox_scheduler import Schedule
from qblox_scheduler.backends.qblox.visualization import (
    _display_compiled_instructions as display_dict,  # noqa: F401
)
from qblox_scheduler.operations import RampPulse, SquarePulse


def alternate_order(arr: np.ndarray) -> list:
    """
    Reorder values so they alternate between the largest and smallest remaining.

    The input is sorted ascending, then items are drawn alternately from the top
    and the bottom of the sorted sequence, e.g. ``[1, 2, 3, 4] -> [4, 1, 3, 2]``.

    Args:
    arr (np.ndarray): Values to reorder (anything accepted by ``np.sort``).

    Returns:
    list: The values in max/min alternating order.

    """
    ordered = np.sort(arr)
    low, high = 0, len(ordered) - 1
    pick_high = True
    result = []
    # Two pointers walk inwards; every value is emitted exactly once.
    while low <= high:
        if pick_high:
            result.append(ordered[high])
            high -= 1
        else:
            result.append(ordered[low])
            low += 1
        pick_high = not pick_high
    return result


def alternate_lists(list1: np.ndarray, list2: np.ndarray) -> list:
    """
    Interleave two sequences element by element.

    Items are taken pairwise (``list1[0], list2[0], list1[1], list2[1], ...``)
    until the shorter input is exhausted; whatever remains of the longer input is
    then appended unchanged.

    Args:
    list1 (np.ndarray): Sequence whose item comes first in each pair.
    list2 (np.ndarray): Sequence whose item comes second in each pair.

    Returns:
    list: The interleaved items followed by the leftover tail.

    """
    common = min(len(list1), len(list2))
    interleaved = []
    for left, right in zip(list1, list2):
        interleaved += (left, right)
    # When the lengths are equal the tail slice is simply empty.
    leftover = list1 if len(list1) > len(list2) else list2
    interleaved.extend(leftover[common:])
    return interleaved


def create_grid(list1: np.ndarray, list2: np.ndarray) -> list:
    """
    Create a grid of ``(row, column)`` tuples that visits extreme values alternately.

    ``list2`` is reordered with ``alternate_order``. Rows of ``list1`` are then
    consumed from the outside in:
    - Each iteration pairs the current first and last remaining row values.
    - It sweeps the reordered columns forward and then backward.
    - The row value alternates between the two at every step, so every column is
      visited once for each of the two rows.
    - If an odd number of rows leaves a single middle value, that value is paired
      with every column once.

    Args:
    list1 (np.ndarray): Row values; any 1-D sequence accepted by ``np.asarray``.
    list2 (np.ndarray): Column values.

    Returns:
    list: List of ``(row, column)`` tuples representing the grid.

    """
    # Accept plain Python sequences too (the original crashed on ``.size`` for
    # lists of length <= 1).
    rows = np.asarray(list1)
    cols = alternate_order(list2)
    n_cols = len(cols)
    # The forward-then-backward column sweep is identical for every row pair.
    col_sweep = np.concatenate((cols, cols[::-1]))

    grid = []
    # Two pointers replace repeated ``np.delete`` calls (which re-copied the rows
    # on every iteration).
    lo, hi = 0, len(rows) - 1
    while lo < hi:
        row_pair = alternate_lists(
            np.full(n_cols, rows[lo]).tolist(),
            np.full(n_cols, rows[hi]).tolist(),
        )
        grid.extend(zip(row_pair, col_sweep))
        lo += 1
        hi -= 1

    # Odd number of rows: the middle row is swept once.
    if lo == hi:
        grid.extend(zip([rows[lo]] * n_cols, cols))

    return grid


class PulseOperations:  # noqa N801
    """
    Class to manage QCM output voltage operations for qubit pairs.

    Provides methods to initialize, reload configuration, and create schedules
    for voltage ramps and hold operations.

    The JSON configuration maps each qubit pair name to its named voltage points,
    each a ``{element_name: voltage}`` dictionary (e.g. ``"control_point"`` and
    ``"readout_point"``). Each pair also has a ``"ramp_times"`` table keyed
    ``"<start>_to_<end>"``. The file is re-read on every public call, so edits made
    on disk are picked up.

    Pair names are assumed to look like ``"P1_P2"``: characters 0-1 and 3-4 name
    the two dots of the pair (TODO confirm against the config naming scheme).
    """

    # Parsed JSON configuration; refreshed from disk by ``_reload_config``.
    _csd_config: dict | None = None
    # Path of the JSON configuration file; set by ``initialize``.
    _config_path: str | None = None

    @classmethod
    def initialize(cls, config_path: str) -> None:
        """
        Initialize the PulseOperations class with the configuration path.

        Parameters
        ----------
        config_path : str
            Path to the JSON configuration file. It is read by every public method
            and written back by ``set_control_point`` and ``set_readout_point``.

        """
        cls._config_path = config_path

    @classmethod
    def _reload_config(cls) -> None:
        """
        Reload the configuration from the JSON file.

        Raises
        ------
        ValueError
            If the configuration path is not set.

        """
        if cls._config_path is None:
            raise ValueError("Config path not set. Call initialize() first.")
        with open(cls._config_path, encoding="utf-8") as csd_config:
            cls._csd_config = json.load(csd_config)

    @classmethod
    def _merged_pair_config(cls, qubit_pair: str, override: dict | None) -> dict:
        """
        Reload the configuration and return the (optionally overridden) pair entry.

        Parameters
        ----------
        qubit_pair : str
            The qubit pair whose configuration is requested.
        override : dict or None
            Optional ``{qubit_pair: {voltage_point: {element: value}}}`` overrides.
            Existing voltage points are updated element-wise and new ones are
            added. The loaded configuration itself is never modified.

        Returns
        -------
        dict
            The pair configuration with overrides applied.

        Raises
        ------
        TypeError
            If the pair is missing from the configuration, or a non-empty override
            does not contain it.

        """
        cls._reload_config()
        if qubit_pair not in cls._csd_config:
            raise TypeError(
                "The designated pair is not in the dictionary. A list of pairs in the dictionary is given below: \n"
                + ",".join(cls._csd_config)
            )
        override = override or {}
        if override and qubit_pair not in override:
            raise TypeError(
                "The designated pair is not in the override dictionary. A list of pairs in the override dictionary is given below: \n"
                + ",".join(override)
            )

        base_config = cls._csd_config[qubit_pair]
        override_config = override.get(qubit_pair) or {}
        if not override_config:
            return base_config

        # Deep merge on a copy so the loaded configuration stays untouched;
        # override values take precedence.
        merged_config = copy.deepcopy(base_config)
        for voltage_point, elems in override_config.items():
            if voltage_point in merged_config:
                merged_config[voltage_point].update(elems)
            else:
                merged_config[voltage_point] = copy.deepcopy(elems)
        return merged_config

    @staticmethod
    def _require_voltage_point(config: dict, point: str, role: str, method: str) -> None:
        """Raise ``TypeError`` listing the available points if ``point`` is not in ``config``."""
        if point not in config:
            raise TypeError(
                f"The designated {role} voltage point set for {method} is not in the dictionary. "
                "A list of voltages in the dictionary is given below: \n" + ",".join(config)
            )

    @classmethod
    def setv(
        cls,
        schedule: Schedule,
        qubit_pair: str,
        start_V: str,
        end_V: str,
        override: dict | None = None,
    ) -> Schedule:
        """
        Add ramps moving a qubit pair between two voltage points to a schedule.

        One ``RampPulse`` is added per element of the start voltage point. The first
        ramp starts at the end of the last operation already in ``schedule``. All
        other ramps are anchored to the start of that first ramp, so every element
        ramps simultaneously (previously each ramp was chained after the previous
        one, serializing them).

        Parameters
        ----------
        schedule: Schedule
            The schedule within which to add defined pulses. It is modified in place.
        qubit_pair : str
            The qubit pair to apply the voltage ramp.
        start_V : str
            The starting voltage point.
        end_V : str
            The ending voltage point.
        override: dict, optional
            The dictionary which contains the new values for the qubit voltages and
            ramp rates on the CSD, keyed by qubit pair. Voltage points introduced by
            the override may be used as ``start_V``/``end_V``.

        Returns
        -------
        Schedule
            The same ``schedule`` object, now containing the voltage ramp.

        Raises
        ------
        TypeError
            If the pair or either voltage point is unknown.

        """
        config = cls._merged_pair_config(qubit_pair, override)
        cls._require_voltage_point(config, start_V, "start", "PulseOperations.setv()")
        cls._require_voltage_point(config, end_V, "end", "PulseOperations.setv()")

        start_point = config[start_V]
        end_point = config[end_V]
        ramp_time = config["ramp_times"][f"{start_V}_to_{end_V}"]

        first_ramp = None
        for elem_name, start_level in start_point.items():
            ramp = RampPulse(
                amp=end_point[elem_name] - start_level,
                duration=ramp_time,
                port=f"{elem_name}:gt",
                offset=start_level,
            )
            if first_ramp is None:
                first_ramp = schedule.add(ramp, ref_pt="end")
            else:
                schedule.add(ramp, ref_op=first_ramp, ref_pt="start")

        return schedule

    @classmethod
    def wait(
        cls, qubit_pair: str, start_V: str, duration: float, override: dict | None = None
    ) -> Schedule:
        """
        Create a schedule to hold voltage at a given point for a specified duration.

        One ``SquarePulse`` per element of the voltage point is added, each anchored
        to the start of the previously added pulse so they play simultaneously.

        Parameters
        ----------
        qubit_pair : str
            The qubit pair to apply the hold operation.
        start_V : str
            The voltage point to hold.
        duration : float
            The duration of the hold operation in seconds.
        override: dict, optional
            The dictionary which contains the new values for the qubit voltages and
            ramp rates on the CSD, keyed by qubit pair.

        Returns
        -------
        Schedule
            A new schedule named ``wait_<start_V>`` containing the hold operation.

        Raises
        ------
        TypeError
            If the pair or the voltage point is unknown.

        """
        config = cls._merged_pair_config(qubit_pair, override)
        cls._require_voltage_point(config, start_V, "start", "PulseOperations.wait()")

        schedule = Schedule(f"wait_{start_V}")
        for elem_name, level in config[start_V].items():
            schedule.add(
                SquarePulse(amp=level, duration=duration, port=f"{elem_name}:gt"),
                ref_pt="start",
            )

        return schedule

    @classmethod
    def return_readout_sweep_points(cls, qubit_pair: str, detuning_array: np.ndarray) -> np.ndarray:
        """
        Return an array of points to be swept around the readout point in both directions.

        The sweep runs along the unit vector pointing from the control point to the
        readout point. Only the leading ``control_point`` entries whose names match
        one of the pair's two dots are used; iteration stops at the first other
        entry.

        Parameters
        ----------
        qubit_pair : str
            The qubit pair whose control/readout points define the sweep.
        detuning_array : np.ndarray
            Signed distances along the sweep direction, one per sweep step.

        Returns
        -------
        np.ndarray
            One row per detuning value: ``readout_point + d * direction``.

        Raises
        ------
        ValueError
            If the sweep direction is undefined (points coincide or no dot entries).

        """
        cls._reload_config()
        pair_config = cls._csd_config[qubit_pair]
        control_cfg = pair_config["control_point"]
        readout_cfg = pair_config["readout_point"]
        dot_names = (qubit_pair[0:2], qubit_pair[3:5])

        control_point, readout_point = [], []
        # zip() bounds the walk by the shorter of the two point dictionaries.
        for key_name, _ in zip(control_cfg, readout_cfg):
            if key_name not in dot_names:
                break
            control_point.append(control_cfg[key_name])
            readout_point.append(readout_cfg[key_name])

        readout_vec = np.array(readout_point)
        offset = readout_vec - np.array(control_point)
        norm = np.linalg.norm(offset)
        if norm == 0:
            # Previously this silently produced an array of NaNs.
            raise ValueError(
                f"Cannot define a sweep direction for {qubit_pair!r}: the readout and "
                "control points coincide (or no matching dot entries were found)."
            )
        sweep_dir_norm = offset / norm
        return np.array([readout_vec + sweep_point * sweep_dir_norm for sweep_point in detuning_array])

    @classmethod
    def _store_point(
        cls,
        qubit_pair: str,
        point_name: str,
        first_voltage_value: float,
        second_voltage_value: float,
    ) -> None:
        """Set both dot voltages of ``point_name`` for a pair and persist the config."""
        cls._reload_config()
        point = cls._csd_config[qubit_pair][point_name]
        point[qubit_pair[0:2]] = first_voltage_value
        point[qubit_pair[3:5]] = second_voltage_value
        # Serialize before opening the file: opening with "w" truncates it, so a
        # serialization error must not leave an empty config behind.
        serialized = json.dumps(cls._csd_config)
        # Write back to the file the config was read from (previously a hard-coded
        # "configs/CSD_config.json").
        with open(cls._config_path, "w", encoding="utf-8") as updated_csd_config:
            updated_csd_config.write(serialized)

    @classmethod
    def set_control_point(
        cls, qubit_pair: str, first_voltage_value: float, second_voltage_value: float
    ) -> None:
        """
        Set the plunger gate voltages of both dots to a defined control point that can be determined from the charge stability diagram.

        The updated configuration is written back to the configured JSON file.

        Parameters
        ----------
        qubit_pair : str
            The qubit pair whose control point is updated.
        first_voltage_value : float
            The plunger gate voltage value to be set for the first quantum dot of the selected pair (from the left).
        second_voltage_value : float
            The plunger gate voltage value to be set for the second quantum dot of the selected pair (from the left).

        """
        cls._store_point(qubit_pair, "control_point", first_voltage_value, second_voltage_value)

    @classmethod
    def set_readout_point(
        cls, qubit_pair: str, first_voltage_value: float, second_voltage_value: float
    ) -> None:
        """
        Set the plunger gate voltages of both dots to a defined readout point that can be determined from the charge stability diagram.

        The updated configuration is written back to the configured JSON file.

        Parameters
        ----------
        qubit_pair : str
            The qubit pair whose readout point is updated.
        first_voltage_value : float
            The plunger gate voltage value to be set for the first quantum dot of the selected pair (from the left).
        second_voltage_value : float
            The plunger gate voltage value to be set for the second quantum dot of the selected pair (from the left).

        """
        cls._store_point(qubit_pair, "readout_point", first_voltage_value, second_voltage_value)

    @classmethod
    def return_occupation_voltages(cls, qubit_pair: str) -> dict:
        """
        Return the occupation voltages for readout point calibration.

        Entries keyed by the pair name itself are skipped. The first two remaining
        entries define a 2-D vector ``(dx, dy)`` from the control point to the
        readout point. The readout point is then shifted using the distance and
        angle of that vector.

        Parameters
        ----------
        qubit_pair : str
            The qubit pair whose points are used.

        Returns
        -------
        dict
            ``"control_voltages"``: control-point values of all used entries;
            ``"readout_voltages"``: the shifted 2-D readout point, rounded to 2 decimals.

        Raises
        ------
        ValueError
            If fewer than two usable entries are found.

        """
        cls._reload_config()
        readout_cfg = cls._csd_config[qubit_pair]["readout_point"]
        control_cfg = cls._csd_config[qubit_pair]["control_point"]

        read_points, control_points = [], []
        # zip() bounds the walk by the shorter of the two point dictionaries.
        for key_name, _ in zip(readout_cfg, control_cfg):
            if key_name == qubit_pair:
                continue
            read_points.append(readout_cfg[key_name])
            control_points.append(control_cfg[key_name])

        if len(read_points) < 2:
            raise ValueError(
                f"At least two voltage entries are required for {qubit_pair!r}; found {len(read_points)}."
            )

        delta_first = read_points[0] - control_points[0]
        delta_second = read_points[1] - control_points[1]
        # Euclidean distance (the original squared the second component twice).
        distance = math.hypot(delta_first, delta_second)
        # Keep the angle in radians: sin/cos expect radians (the original converted
        # to degrees first).
        angle = math.atan2(-delta_second, delta_first)

        # NOTE(review): with this sign convention the applied offset equals
        # (-delta_second, -delta_first); confirm this is the intended occupation point.
        shifted_read_points = [
            round(read_points[0] + distance * math.sin(angle), 2),
            round(read_points[1] - distance * math.cos(angle), 2),
        ]

        return {"control_voltages": control_points, "readout_voltages": shifted_read_points}
