# Repository: https://gitlab.com/qblox/packages/software/qblox-scheduler
# Licensed according to the LICENSE file on the main branch
#
# Copyright 2020-2025, Quantify Consortium
# Copyright 2025, Qblox B.V.
"""
Types that support validation in Pydantic.
Pydantic recognizes the magic method ``__get_pydantic_core_schema__`` to obtain a
custom validation and serialization schema, which can be used, e.g., for custom
serialization and deserialization.
We implement several custom types here to tune behavior of our models.
See `Pydantic documentation`_ for more information about implementing new types.
.. _Pydantic documentation: https://docs.pydantic.dev/latest/usage/types/custom/
"""
from __future__ import annotations
import base64
import math
from typing import TYPE_CHECKING, Annotated, Any, TypedDict
import networkx as nx
import numpy as np
from annotated_types import Ge
from pydantic import AfterValidator, AllowInfNan, GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic_core import core_schema
if TYPE_CHECKING:
    from collections.abc import Callable
    from pydantic.json_schema import JsonSchemaValue
def validate_non_negative_or_nan(value: float) -> float:
    """
    Check that ``value`` is either NaN or not below zero.

    Parameters
    ----------
    value
        The number to check.

    Returns
    -------
    float
        The unchanged input value.

    Raises
    ------
    ValueError
        If ``value`` is a number smaller than 0.

    """
    # NaN is explicitly permitted, so it is passed through untouched.
    if math.isnan(value):
        return value
    if value < 0:
        raise ValueError("input should be non-negative or NaN.")
    return value
# NOTE: the bare string below each alias is an attribute docstring picked up by
# the documentation tooling; it has no runtime effect.
Amplitude = Annotated[float, AllowInfNan(True)]
"""Type alias for a float that can be NaN (or infinite)."""

Delay = Annotated[float, AllowInfNan(False)]
"""Type alias for a float that can't be NaN (or infinite)."""

Duration = Annotated[float, Ge(0)]
"""Type alias for a float that must be >= 0 and not NaN."""

Frequency = Annotated[float, AllowInfNan(True), AfterValidator(validate_non_negative_or_nan)]
"""Type alias for a float that must be >= 0 but can be NaN."""
class _SerializedNDArray(TypedDict):
    """JSON-compatible dictionary representation of a :class:`numpy.ndarray`."""

    # Base64-encoded (ASCII) raw bytes of the array.
    data: str
    # Shape to restore the flat buffer to when decoding.
    shape: tuple[int, ...]
    # String form of the array's dtype (``str(array.dtype)``).
    dtype: str
class _NDArrayPydanticAnnotation:
    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,  # noqa: ANN401
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        """
        Pydantic-compatible version of :class:`numpy.ndarray`.
        Serialization is implemented using custom methods :meth:`.ndarray_to_dict` and
        :meth:`.validate_from_any`. Data array is encoded in Base64.
        """
        return core_schema.json_or_python_schema(
            json_schema=core_schema.chain_schema(
                [
                    core_schema.dict_schema(),
                    core_schema.no_info_plain_validator_function(cls.validate_from_any),
                ]
            ),
            python_schema=core_schema.union_schema(
                [
                    # check if it's an instance first before doing any further work
                    core_schema.is_instance_schema(np.ndarray),
                    core_schema.no_info_plain_validator_function(cls.validate_from_any),
                ],
            ),
            serialization=core_schema.plain_serializer_function_ser_schema(
                cls.ndarray_to_dict,
                when_used="json",
            ),
        )
    @classmethod
    def __get_pydantic_json_schema__(
        cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        # Use the same schema that would be used for `dict`
        return handler(core_schema.dict_schema())
    @staticmethod
    def ndarray_to_dict(v: np.ndarray) -> _SerializedNDArray:
        """Convert the given array to JSON-compatible dictionary."""
        return {
            "data": base64.b64encode(v.tobytes()).decode("ascii"),
            "shape": v.shape,
            "dtype": str(v.dtype),
        }
    @staticmethod
    def validate_from_any(v: _SerializedNDArray | list[Any] | np.ndarray) -> np.ndarray:
        match v:
            case dict():
                return np.frombuffer(base64.b64decode(v["data"]), dtype=v["dtype"]).reshape(
                    v["shape"]
                )
            case list():
                return np.array(v)
            case np.ndarray():
                return v
            case _:
                raise TypeError(f"Unsupported NumPy array: {v}")
# We now create an `Annotated` wrapper that we'll use as the annotation for fields.
NDArray = Annotated[np.ndarray, _NDArrayPydanticAnnotation]
"""Type alias for a :class:`numpy.ndarray` that Pydantic can validate and serialize."""

class Graph(nx.Graph):
    """Pydantic-compatible version of :class:`networkx.Graph`."""

    # Avoid showing inherited init docstring (which leads to cross-reference issues)
    def __init__(self, incoming_graph_data=None, **attr) -> None:  # noqa: ANN001
        """Create a new graph instance."""
        super().__init__(incoming_graph_data, **attr)

    @classmethod
    def __get_pydantic_core_schema__(
        cls: type[Graph],
        _source_type: Any,  # noqa: ANN401
        _handler: Callable[[Any], core_schema.CoreSchema],
    ) -> core_schema.CoreSchema:
        """
        Return the schema Pydantic uses for fields of this type.

        Input is validated with :meth:`validate`; on serialization the graph is
        always converted to node-link data with its edges stored under the
        ``"links"`` key.
        """
        return core_schema.no_info_plain_validator_function(
            cls.validate,
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda g: nx.node_link_data(g, edges="links"), when_used="always"
            ),
        )

    @classmethod
    def validate(cls: type[Graph], v: Any) -> Graph:  # noqa: ANN401
        """
        Validate the data and cast from all known representations.

        Parameters
        ----------
        v
            Either node-link data (a dictionary, as produced when serializing
            this type) or anything accepted by the graph constructor.

        Returns
        -------
        :
            A new graph instance built from ``v``.

        """
        if isinstance(v, dict):
            # Name the edge key explicitly so that it matches the serializer
            # above instead of depending on the library default.
            return cls(nx.node_link_graph(v, edges="links"))
        return cls(v)
 
# Names exported as the public API of this module (``from ... import *``).
__all__ = [
    "Amplitude",
    "Delay",
    "Duration",
    "Frequency",
    "Graph",
    "NDArray",
]