Source code for communication.serializer_plugins.numpy

"""
(De)serialization logic for numpy objects. Used only when
ormsgpack.packb(..., option=(ormsgpack.OPT_SERIALIZE_NUMPY, ...)) fails.
"""

# pylint: disable=unused-argument

from __future__ import annotations

from typing import Any

from tno.mpc.communication import Serializer
from tno.mpc.communication.functions import (
    redirect_importerror_oserror_to_optionalimporterror,
)
from tno.mpc.communication.packers import DeserializerOpts, SerializerOpts

with redirect_importerror_oserror_to_optionalimporterror():
    import numpy as np
    import numpy.typing as npt


# called only if ormsgpack fails serializing (see module docstring)
def numpy_serialize(
    obj: npt.NDArray[Any],
    opts: SerializerOpts,
) -> dict[str, Any]:
    """
    Serialize a numpy (object) array into a plain dictionary.

    :param obj: numpy array to serialize.
    :param opts: options to change the behaviour of the serialization; this
        function does not use them.
    :return: dictionary holding the nested-list contents of the array under
        ``"values"`` and the array shape under ``"shape"``.
    """
    nested_values = obj.tolist()
    array_shape = obj.shape
    return {"values": nested_values, "shape": array_shape}

def numpy_deserialize(
    obj: dict[str, Any], opts: DeserializerOpts
) -> npt.NDArray[np.object_]:
    """
    Function for deserializing numpy object arrays.

    Rebuilds an object-dtype array from a dictionary with a ``"values"`` entry
    (nested-list contents) and a ``"shape"`` entry (array shape), as produced
    by ``numpy_serialize``.

    :param obj: serialized numpy array (dictionary with ``"values"`` and
        ``"shape"`` entries).
    :param opts: options to change the behaviour of the deserialization.
    :return: deserialized object-dtype array with the serialized shape.
    """
    # ormsgpack can handle native numpy dtypes; the contained values may still
    # need to be transformed back into their non-native representation.
    obj_dict = Serializer.transform_into_nonnative(obj, opts)
    # The shape may arrive as a list after a msgpack round trip.
    shape = tuple(obj_dict["shape"])
    values = obj_dict["values"]

    result: npt.NDArray[np.object_] = np.empty(shape, dtype=object)
    # Walk exactly ``len(shape)`` levels of nesting and store each element as a
    # single object. A bulk assignment (``result[:] = values``) or
    # ``np.array(values)`` would let numpy treat list- or array-valued elements
    # as extra dimensions (broadcast errors, or a 0-d array turning into 1-d).
    # For shape ``()`` this runs once with index ``()``; zero-size shapes run
    # zero times.
    for index in np.ndindex(*shape):
        element = values
        for position in index:
            element = element[position]
        result[index] = element
    return result

def register() -> None:
    """
    Register the numpy serializer and deserializer with ``Serializer``.

    The pair is registered under the name of the ``np.ndarray`` type, with
    annotation checking disabled.
    """
    ndarray_type_name = np.ndarray.__name__
    Serializer.register(
        numpy_serialize,
        numpy_deserialize,
        ndarray_type_name,
        check_annotations=False,
    )