Source code for communication.serializer_plugins.numpy
"""
(De)serialization logic for numpy objects. Used only when
ormsgpack.packb(..., option=(ormsgpack.OPT_SERIALIZE_NUMPY, ...)) fails.
"""
from __future__ import annotations
from typing import Any
from tno.mpc.communication.functions import (
redirect_importerror_oserror_to_optionalimporterror,
)
from tno.mpc.communication.serialization import Serialization
with redirect_importerror_oserror_to_optionalimporterror():
import numpy as np
import numpy.typing as npt
# Called only if ormsgpack fails to serialize the object (see module docstring).
[docs]
def numpy_serialize(obj: npt.NDArray[Any], **_kwargs: Any) -> dict[str, Any]:
    r"""
    Serialize a numpy array into a plain dictionary.

    The result stores the array contents as returned by ``obj.tolist()`` under
    "values" and the array shape under "shape".

    :param obj: numpy object to serialize
    :param \**_kwargs: optional extra keyword arguments (ignored)
    :return: serialized object
    """
    serialized: dict[str, Any] = {}
    serialized["values"] = obj.tolist()
    serialized["shape"] = obj.shape
    return serialized
[docs]
def numpy_deserialize(
    obj: dict[str, Any], use_pickle: bool, **_kwargs: Any
) -> npt.NDArray[np.object_]:
    r"""
    Function for deserializing numpy object arrays

    :param obj: serialized numpy object to deserialize; after deserialization
        it must provide the entries "values" and "shape"
    :param use_pickle: forwarded to Serialization.deserialize; set to True to
        enable the deserialization fallback to pickle
    :param \**_kwargs: optional extra keyword arguments
    :return: deserialized object
    """
    # ormsgpack can handle native numpy dtypes, so this function only has to
    # rebuild the arrays that were serialized through numpy_serialize.
    obj_dict = Serialization.deserialize(obj, use_pickle=use_pickle)
    values = obj_dict["values"]
    shape = obj_dict["shape"]
    if not shape:
        # Empty shape: "values" holds a single item rather than a nested list.
        return np.array(values)
    result: npt.NDArray[np.object_] = np.empty(shape, dtype=object)
    if values:
        # Fill element by element, taking the dimensionality from the stored
        # shape only. A bulk assignment (result[:] = values) lets numpy
        # rediscover dimensions from the nested lists, which descends into
        # elements that are sequences themselves and then either fails to
        # broadcast or silently unwraps them.
        for index in np.ndindex(*shape):
            element = values
            for position in index:
                element = element[position]
            result[index] = element
    return result
[docs]
def register() -> None:
    """
    Register the numpy (de)serialization functions with Serialization.

    The functions are registered under the class name of np.ndarray, and
    check_annotations=False is passed to Serialization.register.
    """
    ndarray_type_name = np.ndarray.__name__
    Serialization.register(
        numpy_serialize,
        numpy_deserialize,
        ndarray_type_name,
        check_annotations=False,
    )