from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
import scipy
from scipy.interpolate import interp1d
from scipy.signal import lfilter


# Function to apply BIAS-T distortion to a signal
def bias_tee_distort(
    ysig: np.ndarray, tau: float, g: float = 1, sampling_rate: float = 1
) -> np.ndarray:
    """
    Simulate the high-pass droop a BIAS-T imposes on a signal.

    The signal is run through the first-order IIR filter

        y[n] = x[n] - x[n-1] + exp(-1 / tau) * y[n-1]

    (zero initial state, so y[0] = x[0]), and the result is scaled by ``g``.

    Args:
        ysig (np.ndarray): The input signal to be distorted.
        tau (float): Filter time constant, in samples. It enters only through the
            per-sample pole ``exp(-1 / tau)``.
        g (float, optional): Multiplicative gain applied to the filtered output. Defaults to 1.
        sampling_rate (float, optional): Sampling rate of the input signal. The
            computation currently does not use it. Defaults to 1.

    Returns:
        np.ndarray: The filtered signal multiplied by ``g``.

    """
    pole = np.exp(-1 / tau)
    feedforward = [1, -1]  # discrete difference: zero at z = 1 removes the DC component
    feedback = [1, -pole]  # single pole at exp(-1/tau) sets the droop rate
    filtered = lfilter(feedforward, feedback, ysig)
    return g * filtered


# Function to pre-distort a signal to correct for BIAS-T effects
def bias_tee_correction(sig: np.ndarray, tau: float) -> np.ndarray:
    """
    Pre-distort a signal to compensate for the BIAS-T distortion.

    The output is the input plus a scaled running sum of the input:

        out[n] = sig[n] + k * (sig[0] + sig[1] + ... + sig[n]),  k = 2 / (2 * tau + 1)

    The growing running-sum term offsets the droop introduced by the high-pass BIAS-T.

    Args:
        sig (np.ndarray): The input signal to be pre-distorted.
        tau (float): The time constant of the BIAS-T filter (same units as used by
            ``bias_tee_distort``, i.e. samples).

    Returns:
        np.ndarray: The pre-distorted signal to counteract BIAS-T effects.

    """
    gain = 2 / (2 * tau + 1)

    # b = [gain], a = [1, -1] is a discrete integrator: y[n] = gain * x[n] + y[n-1].
    running_sum = lfilter([gain], np.array([1, -1]), sig)

    return running_sum + sig


# Plotting the waveform and rectangle
def _label_waveform_axes(title: str) -> None:
    """Draw the dashed zero line and apply the shared title/axis labels to the current axes."""
    plt.axhline(0, color="red", linestyle="--")
    plt.title(title)
    plt.xlabel("Time (seconds)")
    plt.ylabel("Amplitude ")


def plot_waveform_and_rectangle(
    x: np.ndarray,
    y: np.ndarray,
    decay: float,
    second_square_duration: float,
    amp_set_points: np.ndarray,
) -> None:
    """
    Plot the original waveform and its distorted version after applying BIAS-T distortion.

    Two figures are shown one after the other (each followed by ``plt.show()``):
    the raw waveform, and the waveform passed through ``bias_tee_distort`` with a
    rectangle around the second-to-last pulse, whose geometry comes from
    ``draw_rectangle``.

    Args:
        x (np.ndarray): Array representing the x-axis values (e.g., time in seconds).
        y (np.ndarray): Array representing the y-axis values (e.g., amplitude in Volts).
        decay (float): Time constant passed as ``tau`` to ``bias_tee_distort``.
        second_square_duration (float): Duration of the second square pulse in the waveform.
        amp_set_points (np.ndarray): Amplitude set points used in the waveform generation
            (only their count is used, to space the peak search).

    """
    # Figure 1: the waveform as produced by the instrument.
    plt.figure(figsize=(6, 5))
    plt.plot(x, y)
    _label_waveform_axes("Output of the instrument")
    plt.tight_layout()
    plt.show()

    # Figure 2: the same waveform after the simulated BIAS-T.
    plt.figure(figsize=(6, 5))
    new_y = bias_tee_distort(y, decay)
    plt.plot(x, new_y)
    _label_waveform_axes("After Bias-T")

    x_center, y_center, width, height = draw_rectangle(
        x, second_square_duration, amp_set_points, new_y
    )

    # Rectangle takes its bottom-left corner, so shift from the centre by half the size.
    rectangle = plt.Rectangle(
        (x_center - width / 2, y_center - height / 2),
        width,
        height,
        color="r",
        fill=False,
        linewidth=2,
        label="Highlighted Region",
        zorder=2,
    )
    # add_patch (unlike add_artist) updates the axes data limits, so autoscaling
    # keeps the whole rectangle in view.
    plt.gca().add_patch(rectangle)

    plt.legend()
    plt.tight_layout()
    plt.show()


# Compute the geometry of a rectangle around the second-to-last square pulse
def draw_rectangle(
    x: np.ndarray,
    second_square_duration: float,
    amp_set_points: np.ndarray,
    new_y: np.ndarray,
    sampling_rate: float = 1e9,
) -> tuple[float, float, float, float]:
    """
    Determine the dimensions and position of a rectangle to highlight a region in the waveform.

    Peaks of ``new_y`` are searched with a minimum spacing of
    ``len(new_y) / len(amp_set_points)`` samples. The rectangle is anchored on the
    second-to-last peak and on the sample one pulse-duration later.

    Args:
        x (np.ndarray): Array representing the x-axis values, in the same time units
            as ``second_square_duration``.
        second_square_duration (float): Duration of the second square pulse in the waveform.
        amp_set_points (np.ndarray): Amplitude set points for the waveform. Only their
            count is used, to set the minimum peak spacing.
        new_y (np.ndarray): The y-values of the waveform after the BIAS-T distortion,
            sampled on ``x``.
        sampling_rate (float, optional): Samples per unit of ``second_square_duration``.
            It converts the pulse duration into a sample offset. Defaults to 1e9
            (durations in seconds sampled at 1 GS/s).

    Returns:
        tuple[float, float, float, float]: The center x and y coordinates, width, and height of the rectangle.

    Raises:
        IndexError: If fewer than two peaks are found in ``new_y``.

    """
    peaks = scipy.signal.find_peaks(new_y, distance=len(new_y) / len(amp_set_points))[0]
    if len(peaks) < 2:
        raise IndexError(
            f"expected at least two peaks in the waveform, found {len(peaks)}"
        )

    start = peaks[-2]
    # Round rather than truncate: duration * rate is meant to be a whole number of
    # samples, and int() would turn a product like 99.99999999999999 into 99.
    offset = int(round(second_square_duration * sampling_rate))
    # Clamp so a pulse that runs to the end of the record does not index past the array.
    end = min(start + offset, len(new_y) - 1)

    y_start = new_y[start]
    y_end = new_y[end]

    # Centre horizontally on the middle of the pulse; vertically between its two ends.
    x_center = x[start] + second_square_duration / 2
    y_center = (y_start + y_end) / 2

    # Make the box three times the pulse duration and three times the droop, for visibility.
    width = 3 * second_square_duration
    height = 3 * (y_start - y_end)

    return x_center, y_center, width, height


# Extract axes data from a compiled schedule plot
def get_axis(compiled_schedule: object) -> tuple[np.ndarray, np.ndarray]:
    """
    Extract the x and y axis data from the pulse diagram of a compiled schedule.

    The pulse diagram is rendered with interactive mode disabled, then its figure
    is closed. The caller's interactive-mode setting is restored afterwards, even
    if plotting fails. Each line's data is linearly interpolated onto a grid 30
    times denser than its original samples. All lines are merged into one
    time-sorted trace whose time axis starts at 0.

    Args:
        compiled_schedule: A compiled schedule object whose ``plot_pulse_diagram()``
            returns a 2-tuple. Only the second element is used, and it must be a
            matplotlib Axes.

    Returns:
         tuple[np.ndarray, np.ndarray]: The merged x values sorted ascending and
         shifted so the minimum is 0, and the matching y values.

    Raises:
        ValueError: If the pulse diagram contains no lines.

    """
    # Render off-screen, remembering the caller's setting so it is restored exactly.
    was_interactive = plt.isinteractive()
    plt.ioff()
    try:
        _, ax = compiled_schedule.plot_pulse_diagram()
        # Snapshot the line data; get_xdata() may return the original list, so coerce.
        line_data = [
            (np.asarray(line.get_xdata()), np.asarray(line.get_ydata()))
            for line in ax.get_lines()
        ]
        plt.close(ax.figure)
    finally:
        if was_interactive:
            plt.ion()

    if not line_data:
        raise ValueError("pulse diagram contains no lines to extract")

    x_chunks = []
    y_chunks = []
    for x_line, y_line in line_data:
        interpolate = interp1d(x_line, y_line, kind="linear")
        x_dense = np.linspace(x_line.min(), x_line.max(), len(x_line) * 30)
        y_dense = interpolate(x_dense)
        # NOTE(review): the first and last dense samples of every line are dropped,
        # presumably to avoid duplicate points where adjacent lines meet; confirm.
        x_chunks.append(x_dense[1:-1])
        y_chunks.append(y_dense[1:-1])

    x_all = np.concatenate(x_chunks)
    y_all = np.concatenate(y_chunks)
    order = np.argsort(x_all)
    sorted_x = x_all[order]
    sorted_y = y_all[order]

    # Ensure time samples start from 0.
    sorted_x = sorted_x - sorted_x.min()
    return sorted_x, sorted_y
