Source code for qblox_scheduler.structure.types

# Repository: https://gitlab.com/qblox/packages/software/qblox-scheduler
# Licensed according to the LICENSE file on the main branch
#
# Copyright 2020-2025, Quantify Consortium
# Copyright 2025, Qblox B.V.
"""
Types that support validation in Pydantic.

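Pydantic recognizes the magic methods ``__get_pydantic_core_schema__`` and
``__get_pydantic_json_schema__`` (``__get_validators__`` in Pydantic v1), which let a
type define its own validation and JSON schema, e.g., for custom serialization and
deserialization.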
We implement several custom types here to tune behavior of our models.

See `Pydantic documentation`_ for more information about implementing new types.

.. _Pydantic documentation: https://docs.pydantic.dev/latest/usage/types/custom/
"""

from __future__ import annotations

import base64
import math
from typing import TYPE_CHECKING, Annotated, Any, TypedDict

import networkx as nx
import numpy as np
from annotated_types import Ge
from pydantic import AfterValidator, AllowInfNan, GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic_core import core_schema

if TYPE_CHECKING:
    from collections.abc import Callable

    from pydantic.json_schema import JsonSchemaValue


def validate_non_negative_or_nan(value: float) -> float:
    """
    Accept NaN or any number that is at least zero.

    Parameters
    ----------
    value
        The number to check.

    Returns
    -------
    float
        ``value`` itself, unchanged, when it passes the check.

    Raises
    ------
    ValueError
        If ``value`` is not NaN and is strictly below zero (``-inf`` included).
    """
    # NaN compares False against everything, so it must be let through explicitly.
    if math.isnan(value) or value >= 0:
        return value
    raise ValueError("input should be non-negative or NaN.")


Amplitude = Annotated[float, AllowInfNan(True)]
"""Type alias for a float that can be NaN or infinite."""
Delay = Annotated[float, AllowInfNan(False)]
"""Type alias for a finite float: NaN and infinity are both rejected."""
Duration = Annotated[float, Ge(0)]
"""Type alias for a float that must be >= 0 and not NaN."""
# NaN/inf are explicitly allowed so they reach the after-validator, which accepts NaN
# and rejects every value below zero (including -inf).
Frequency = Annotated[float, AllowInfNan(True), AfterValidator(validate_non_negative_or_nan)]
"""Type alias for a float that must be >= 0 but can be NaN.""" class _SerializedNDArray(TypedDict): data: str shape: tuple[int, ...] dtype: str class _NDArrayPydanticAnnotation: @classmethod def __get_pydantic_core_schema__( cls, _source_type: Any, # noqa: ANN401 _handler: GetCoreSchemaHandler, ) -> core_schema.CoreSchema: """ Pydantic-compatible version of :class:`numpy.ndarray`. Serialization is implemented using custom methods :meth:`.ndarray_to_dict` and :meth:`.validate_from_any`. Data array is encoded in Base64. """ return core_schema.json_or_python_schema( json_schema=core_schema.chain_schema( [ core_schema.dict_schema(), core_schema.no_info_plain_validator_function(cls.validate_from_any), ] ), python_schema=core_schema.union_schema( [ # check if it's an instance first before doing any further work core_schema.is_instance_schema(np.ndarray), core_schema.no_info_plain_validator_function(cls.validate_from_any), ], ), serialization=core_schema.plain_serializer_function_ser_schema( cls.ndarray_to_dict, when_used="json", ), ) @classmethod def __get_pydantic_json_schema__( cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler ) -> JsonSchemaValue: # Use the same schema that would be used for `dict` return handler(core_schema.dict_schema()) @staticmethod def ndarray_to_dict(v: np.ndarray) -> _SerializedNDArray: """Convert the given array to JSON-compatible dictionary.""" return { "data": base64.b64encode(v.tobytes()).decode("ascii"), "shape": v.shape, "dtype": str(v.dtype), } @staticmethod def validate_from_any(v: _SerializedNDArray | list[Any] | np.ndarray) -> np.ndarray: match v: case dict(): return np.frombuffer(base64.b64decode(v["data"]), dtype=v["dtype"]).reshape( v["shape"] ) case list(): return np.array(v) case np.ndarray(): return v case _: raise TypeError(f"Unsupported NumPy array: {v}") # We now create an `Annotated` wrapper that we'll use as the annotation for fields.
NDArray = Annotated[np.ndarray, _NDArrayPydanticAnnotation]
"""Type alias for a :class:`numpy.ndarray` field; dumped to JSON as a dict of Base64 ``data``, ``shape`` and ``dtype``."""
class Graph(nx.Graph):
    """Pydantic-compatible version of :class:`networkx.Graph`."""

    # Avoid showing inherited init docstring (which leads to cross-reference issues)
    def __init__(self, incoming_graph_data=None, **attr) -> None:  # noqa: ANN001
        """Create a new graph instance."""
        super().__init__(incoming_graph_data, **attr)

    @classmethod
    def __get_pydantic_core_schema__(
        cls: type[Graph],
        _source_type: Any,  # noqa: ANN401
        _handler: Callable[[Any], core_schema.CoreSchema],
    ) -> core_schema.CoreSchema:
        """
        Build the pydantic core schema for :class:`Graph`.

        Input is validated by :meth:`validate`. On dump (Python and JSON modes alike) the
        graph is converted with :func:`networkx.node_link_data`, storing edges under the
        ``"links"`` key.
        """
        return core_schema.no_info_plain_validator_function(
            cls.validate,
            serialization=core_schema.plain_serializer_function_ser_schema(
                # Keep this edges key in sync with ``validate``.
                lambda g: nx.node_link_data(g, edges="links"),
                when_used="always",
            ),
        )

    @classmethod
    def validate(cls: type[Graph], v: Any) -> Graph:  # noqa: ANN401
        """
        Validate the data and cast from all known representations.

        Parameters
        ----------
        v
            Either a node-link dict with edges under ``"links"`` (the format produced by
            this type's serializer), or anything accepted by the
            :class:`networkx.Graph` constructor.

        Returns
        -------
        Graph
            A new :class:`Graph` built from ``v``.
        """
        if isinstance(v, dict):
            # Pin the same edges key the serializer writes. Relying on the networkx default
            # is deprecated (FutureWarning since networkx 3.4) and the default is set to change.
            return cls(nx.node_link_graph(v, edges="links"))
        return cls(v)
# Names exported by ``from ... import *`` and documented as this module's public API.
__all__ = ["Amplitude", "Delay", "Duration", "Frequency", "Graph", "NDArray"]