"""
This module contains the serialization logic used in sending and receiving arbitrary objects.
"""
from __future__ import annotations
import inspect
import pickle
import sys
from functools import partial
from typing import Any, Callable, TypeVar, Union
import ormsgpack
from mypy_extensions import Arg, KwArg
from tno.mpc.communication import serializer_plugins
from tno.mpc.communication.exceptions import AnnotationError, RepetitionError
from tno.mpc.communication.functions import init
if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol
# Module-level logger obtained from the package's `init` helper; used below to log
# (de)serialization and (un)packing failures before the exception is re-raised.
logger = init(__name__)
# Accepted call shapes for a serializer: the object to serialize is the first argument (named
# `self` for methods or `obj` for free functions), optionally followed by a boolean
# `use_pickle` argument, and always followed by arbitrary keyword arguments.
SerializerFunction = Union[
Callable[[Arg(Any, "self"), KwArg(Any)], Any],
Callable[[Arg(Any, "obj"), KwArg(Any)], Any],
Callable[[Arg(Any, "self"), Arg(bool, "use_pickle"), KwArg(Any)], Any],
Callable[[Arg(Any, "obj"), Arg(bool, "use_pickle"), KwArg(Any)], Any],
]
# Accepted call shapes for a deserializer: the serialized payload is passed as `obj`,
# optionally followed by a boolean `use_pickle` argument, plus arbitrary keyword arguments.
DeserializerFunction = Union[
Callable[[Arg(Any, "obj"), KwArg(Any)], Any],
Callable[[Arg(Any, "obj"), Arg(bool, "use_pickle"), KwArg(Any)], Any],
]
# Type variable bound to "serializer or deserializer"; `Serialization._register` uses it so that
# the registry dictionary and the function stored in it are typed consistently.
DorSFunction = TypeVar(
"DorSFunction", bound=Union[SerializerFunction, DeserializerFunction]
)
# Default ormsgpack option bit mask; it is the default value of `option` in `Serialization.pack`.
# NOTE(review): per the ormsgpack documentation OPT_SERIALIZE_NUMPY serializes numpy arrays
# natively, while the OPT_PASSTHROUGH_* flags hand big integers and tuples to the `default`
# callback instead of encoding them natively -- confirm against the pinned ormsgpack version.
DEFAULT_PACK_OPTION = (
ormsgpack.OPT_SERIALIZE_NUMPY
| ormsgpack.OPT_PASSTHROUGH_BIG_INT
| ormsgpack.OPT_PASSTHROUGH_TUPLE
)
[docs]
class SupportsSerialization(Protocol):
    """
    Structural type for classes that ship their own (de)serialization logic.

    A class matches this protocol when it offers a ``serialize`` method and a static
    ``deserialize`` function with the signatures below. Such classes can be handed to
    ``Serialization.register_class``.
    """

    def serialize(self, **kwargs: Any) -> Any:
        r"""
        Serialize this object.

        :param \**kwargs: Optional extra keyword arguments.
        :return: Serialized representation of this instance.
        """

    @staticmethod
    def deserialize(obj: Any, **kwargs: Any) -> SupportsSerialization:
        r"""
        Deserialize the given object into an object of this class.

        :param obj: Object to be deserialized.
        :param \**kwargs: Optional extra keyword arguments.
        :return: Deserialized object.
        """
# Registry of serializers keyed by type name. `Serialization.serialize` looks an object up by
# `obj.__class__.__name__`, so keys are bare class names (not module-qualified).
SERIALIZER_FUNCS: dict[
str,
SerializerFunction,
] = {}
# Registry of deserializers keyed by type name. `Serialization.deserialize` looks a payload up
# by the value stored under its "type" key.
DESERIALIZER_FUNCS: dict[
str,
DeserializerFunction,
] = {}
[docs]
class Serialization:
"""
Virtual class that provides packing and unpacking functions used for communications.
The methods defined here are static; the registered (de)serializers are stored in the
module-level ``SERIALIZER_FUNCS`` and ``DESERIALIZER_FUNCS`` dictionaries, not on instances.
The outline is as follows:
- serialization functions for different classes
- packing function that handles metadata and determines which serialization needs to happen
- deserialization functions for different classes
- unpacking function that handles metadata and determines which deserialization needs to happen
"""
[docs]
@staticmethod
def register_class(
    obj_class: type[SupportsSerialization],
    check_annotations: bool = True,
    overwrite: bool = False,
) -> None:
    """
    Register the ``serialize``/``deserialize`` pair of a class under the name of that class.

    The pair is forwarded to ``Serialization.register`` with ``obj_class.__name__`` as the
    only type name.

    :param obj_class: Class whose ``serialize`` and ``deserialize`` attributes are registered.
    :param check_annotations: Validate signatures and annotations of both functions.
    :param overwrite: Allow (silent) overwrite of currently registered (de)serializers.
    :raise RepetitionError: Logic is already registered for the class name and ``overwrite``
        is False.
    :raise TypeError: One of the two attributes is not callable or fails signature validation.
    :raise AnnotationError: The annotations of the two functions are inconsistent.
    """
    type_name = obj_class.__name__
    class_serializer: SerializerFunction = obj_class.serialize
    class_deserializer: DeserializerFunction = obj_class.deserialize
    Serialization.register(
        class_serializer,
        class_deserializer,
        type_name,
        overwrite=overwrite,
        check_annotations=check_annotations,
    )
[docs]
@staticmethod
def register(
    serializer: SerializerFunction,
    deserializer: DeserializerFunction,
    *types: str,
    check_annotations: bool = True,
    overwrite: bool = False,
) -> None:
    """
    Register a serializer and its matching deserializer for one or more type names.

    The serializer is registered first, the deserializer second. If registering the
    deserializer raises, the serializer registration has already taken place.

    :param serializer: Serializer function.
    :param deserializer: Deserializer function.
    :param types: Names of the object types that the pair can handle.
    :param check_annotations: Verify annotations of the (de)serializer conform to the protocol.
    :param overwrite: Allow (silent) overwrite of currently registered (de)serializers.
    :raise RepetitionError: Attempted overwrite of registered (de)serialization function.
    :raise TypeError: Annotations do not conform to the protocol.
    """
    shared_options = {"check_annotations": check_annotations, "overwrite": overwrite}
    Serialization._register_serializer(serializer, types, **shared_options)
    Serialization._register_deserializer(deserializer, types, **shared_options)
@staticmethod
def _register_serializer(
    serializer: SerializerFunction,
    types: tuple[str, ...],
    check_annotations: bool = True,
    overwrite: bool = False,
) -> None:
    """
    Validate a serializer and store it in ``SERIALIZER_FUNCS`` for every given type name.

    :param serializer: Serializer function.
    :param types: Names of the object types that the serializer can serialize.
    :param check_annotations: Verify that the serializer accepts ``**kwargs`` and that its
        return annotation matches the ``obj`` annotation of deserializers already registered
        for the same type names.
    :param overwrite: Allow (silent) overwrite of currently registered serializers.
    :raise RepetitionError: Attempted overwrite of registered serialization function.
    :raise TypeError: Serializer is not callable or its signature lacks ``**kwargs``.
    :raise AnnotationError: Serializer is inconsistent with an already registered deserializer.
    """
    if not callable(serializer):
        raise TypeError("The provided serializer is not a function.")

    if check_annotations:
        serializer_signature = inspect.signature(serializer)
        _validate_signature_has_kwargs(serializer_signature)
        # Compare against every deserializer that is already registered for one of the
        # requested type names.
        for registered_type, registered_deserializer in DESERIALIZER_FUNCS.items():
            if registered_type not in types:
                continue
            _validate_signatures_consistent(
                serializer_signature=serializer_signature,
                deserializer_signature=inspect.signature(registered_deserializer),
            )

    Serialization._register(SERIALIZER_FUNCS, serializer, types, overwrite=overwrite)
@staticmethod
def _register_deserializer(
    deserializer: DeserializerFunction,
    types: tuple[str, ...],
    check_annotations: bool = True,
    overwrite: bool = False,
) -> None:
    """
    Validate a deserializer and store it in ``DESERIALIZER_FUNCS`` for every given type name.

    :param deserializer: Deserializer function.
    :param types: Names of the object types that the deserializer can produce.
    :param check_annotations: Verify that the deserializer accepts ``**kwargs``, has an ``obj``
        parameter, declares a return annotation matching ``types``, and is consistent with
        serializers already registered for the same type names.
    :param overwrite: Allow (silent) overwrite of currently registered deserializers.
    :raise RepetitionError: Attempted overwrite of registered deserialization function.
    :raise TypeError: Deserializer is not callable or its signature does not conform.
    :raise AnnotationError: Return annotation disagrees with ``types`` or the deserializer is
        inconsistent with an already registered serializer.
    """
    if not callable(deserializer):
        raise TypeError("The provided deserializer is not a function.")

    if check_annotations:
        deserializer_signature = inspect.signature(deserializer)
        _validate_signature_has_kwargs(deserializer_signature)
        _validate_provided_return_annotation(deserializer_signature, types)
        _validate_signature_accepts_keyword(deserializer_signature, "obj")
        # Compare against every serializer that is already registered for one of the
        # requested type names.
        for registered_type, registered_serializer in SERIALIZER_FUNCS.items():
            if registered_type not in types:
                continue
            _validate_signatures_consistent(
                serializer_signature=inspect.signature(registered_serializer),
                deserializer_signature=deserializer_signature,
            )

    Serialization._register(DESERIALIZER_FUNCS, deserializer, types, overwrite=overwrite)
@staticmethod
def _register(
    target_dict: dict[str, DorSFunction],
    d_or_s_function: DorSFunction,
    types: tuple[str, ...],
    overwrite: bool,
) -> None:
    """
    In-place add a (de)serializer to a target dictionary for multiple keys.

    The registration is all-or-nothing: when ``overwrite`` is False every type name is checked
    before the dictionary is touched, so a conflict never leaves a partial registration behind.

    :param target_dict: Target dictionary.
    :param d_or_s_function: (De)serializer to register in the target dictionary.
    :param types: Types of objects that the provided (de)serializer can be applied to.
    :param overwrite: Allow (silent) overwrite of currently registered (de)serializers.
    :raise RepetitionError: Attempted overwrite of registered (de)serializer, or the same type
        name was passed more than once while ``overwrite`` is False.
    """
    if not overwrite:
        # Validate first, mutate afterwards. Names seen earlier in this call count as taken,
        # which keeps the error for duplicated names in `types`.
        checked: set[str] = set()
        for type_ in types:
            if type_ in target_dict or type_ in checked:
                raise RepetitionError(
                    f"The logic for type {type_} has already been set"
                )
            checked.add(type_)
    for type_ in types:
        target_dict[type_] = d_or_s_function
[docs]
@staticmethod
def clear_serialization_logic(reload_defaults: bool = True) -> None:
    """
    Remove every registered serializer and deserializer.

    :param reload_defaults: After clearing, call ``serializer_plugins.register_defaults()``
        to register the (de)serialization logic that is provided by the package again.
    """
    for registry in (SERIALIZER_FUNCS, DESERIALIZER_FUNCS):
        registry.clear()
    if not reload_defaults:
        return
    serializer_plugins.register_defaults()
[docs]
@staticmethod
def default_serialize(obj: Any, use_pickle: bool, **_kwargs: Any) -> bytes:
    r"""
    Fall-back serializer for objects without registered serialization logic.

    The object is pickled when ``use_pickle`` is set; otherwise the call fails.

    :param obj: Object to serialize.
    :param use_pickle: Set to true if one wishes to use pickle as a fallback serializer.
    :param \**_kwargs: Optional extra keyword arguments (ignored).
    :raise NotImplementedError: ``use_pickle`` is false, so no serializer is available.
    :return: Pickled object.
    """
    if not use_pickle:
        raise NotImplementedError(
            "There is no serialization function defined for "
            f"{obj.__class__.__name__} objects."
        )
    return pickle.dumps(obj)
[docs]
@staticmethod
def serialize(
    obj: Any,
    use_pickle: bool,
    **kwargs: Any,
) -> bytes | dict[str, bytes]:
    r"""
    Select the serializer that belongs to the class name of the object and apply it.

    The serializer is looked up in ``SERIALIZER_FUNCS`` under ``obj.__class__.__name__``; if
    none is registered, ``Serialization.default_serialize`` is used. Failures are logged and
    re-raised.

    :param obj: Object to serialize.
    :param use_pickle: Set to true if one wishes to use pickle as a fallback serializer.
    :param \**kwargs: Optional extra keyword arguments, forwarded to the serializer.
    :return: Dictionary with the class name under "type" and the serializer output under
        "data".
    """
    # pylint: disable=missing-raises-doc
    type_name = obj.__class__.__name__
    try:
        # Use the logic that was registered for this class name, if any.
        chosen_serializer = SERIALIZER_FUNCS[type_name]
    except KeyError:
        chosen_serializer = partial(
            Serialization.default_serialize, use_pickle=use_pickle
        )

    try:
        payload = chosen_serializer(obj, use_pickle=use_pickle, **kwargs)
    except Exception:
        logger.exception("Serialization failed!")
        raise
    return {"type": type_name, "data": payload}
[docs]
@staticmethod
def pack(
    obj: Any,
    msg_id: str | int,
    use_pickle: bool,
    option: int | None = DEFAULT_PACK_OPTION,
    **kwargs: Any,
) -> bytes:
    r"""
    Wrap the object together with its message id and encode the result with ormsgpack.

    Values that ormsgpack hands to its ``default`` callback are serialized through
    ``Serialization.serialize``.

    :param obj: Object to pack.
    :param msg_id: Message identifier associated to the message.
    :param use_pickle: Set to true if one wishes to use pickle as a fallback packer.
    :param option: ormsgpack options can be specified through this parameter.
    :param \**kwargs: Optional extra keyword arguments, forwarded to the serializers.
    :raise TypeError: Failed to serialize the provided object.
    :return: Packed object (serialized and annotated).
    """

    def serialize_unsupported(unsupported: Any) -> Any:
        # Callback for values that ormsgpack does not encode by itself.
        return Serialization.serialize(unsupported, use_pickle, **kwargs)

    message = {"object": obj, "id": msg_id}
    try:
        return ormsgpack.packb(
            message,
            default=serialize_unsupported,
            option=option,
        )
    except TypeError:
        logger.exception(
            "Packing failed, consider 1) enabling use_pickle for"
            " inefficient/slow fallback to pickle, or 2) implement"
            " a serialization method for this type/structure, or 3)"
            " resolve the error by setting an option (if available)."
        )
        raise
[docs]
@staticmethod
def default_deserialize(
    obj: bytes, use_pickle: bool = False, **_kwargs: Any
) -> Any:
    r"""
    Fall-back deserializer for payloads without registered deserialization logic.

    The payload is unpickled when ``use_pickle`` is set; otherwise the call fails.

    :param obj: Object to deserialize.
    :param use_pickle: Set to true if one wishes to use pickle as a fallback deserializer.
    :param \**_kwargs: Optional extra keyword arguments (ignored).
    :raise NotImplementedError: ``use_pickle`` is false, so no deserializer is available.
    :return: Deserialized object.
    """
    if not use_pickle:
        raise NotImplementedError(
            "There is no deserialization function defined for "
            f"{obj.__class__.__name__} objects."
        )
    # NOTE(review): unpickling can execute arbitrary code. Enable `use_pickle` only when the
    # payload comes from a trusted sender -- confirm how callers obtain `obj`.
    return pickle.loads(obj)
[docs]
@staticmethod
def collection_deserialize(
    collection_obj: list[Any] | dict[str, Any],
    **kwargs: Any,
) -> dict[str, Any] | list[Any]:
    r"""
    Deserialize the elements of a list or the values of a dictionary.

    A dictionary that has both a "type" and a "data" key keeps that outer structure: only the
    values inside its "data" mapping are deserialized.

    :param collection_obj: Collection to deserialize.
    :param \**kwargs: Optional extra keyword arguments, forwarded to
        ``Serialization.deserialize``.
    :raise ValueError: The provided object is neither a list nor a dictionary.
    :return: New collection holding the deserialized elements.
    """
    if isinstance(collection_obj, list):
        return [
            Serialization.deserialize(element, **kwargs) for element in collection_obj
        ]

    if not isinstance(collection_obj, dict):
        raise ValueError("Cannot process collection")

    if "type" in collection_obj and "data" in collection_obj:
        return {
            "type": collection_obj["type"],
            "data": {
                key: Serialization.deserialize(value, **kwargs)
                for key, value in collection_obj["data"].items()
            },
        }

    return {
        key: Serialization.deserialize(value, **kwargs)
        for key, value in collection_obj.items()
    }
[docs]
@staticmethod
def deserialize(obj: Any, use_pickle: bool = False, **kwargs: Any) -> Any:
    r"""
    Detect which deserialization logic applies to the given object and run it.

    Lists and plain dictionaries are deserialized element-wise. A dictionary with both a
    "type" and a "data" key is passed to the deserializer registered under its "type" value
    (falling back to ``Serialization.default_deserialize``). Any other value is returned
    unchanged.

    :param obj: Object to deserialize.
    :param use_pickle: Set to true if one wishes to use pickle as a fallback deserializer.
    :param \**kwargs: Optional extra keyword arguments, forwarded to the deserializers.
    :return: Deserialized object.
    """
    if isinstance(obj, list):
        return Serialization.collection_deserialize(
            obj, use_pickle=use_pickle, **kwargs
        )
    if not isinstance(obj, dict):
        return obj
    if "type" not in obj or "data" not in obj:
        return Serialization.collection_deserialize(
            obj, use_pickle=use_pickle, **kwargs
        )

    # From here on the dictionary carries both a "type" and a "data" entry.
    if isinstance(obj["data"], dict):
        # Deserialize the nested values before the registered deserializer sees them.
        obj = Serialization.collection_deserialize(
            obj, use_pickle=use_pickle, **kwargs
        )
    fallback: DeserializerFunction = partial(
        Serialization.default_deserialize, use_pickle=use_pickle
    )
    chosen_deserializer = DESERIALIZER_FUNCS.get(obj["type"], fallback)
    return chosen_deserializer(obj["data"], use_pickle=use_pickle, **kwargs)
# endregion
[docs]
@staticmethod
def unpack(
    obj: bytes,
    use_pickle: bool = False,
    option: int | None = None,
    **kwargs: Any,
) -> tuple[str, Any]:
    r"""
    Decode a packed message and deserialize the object it carries.

    The bytes are decoded with ormsgpack. The value stored under "object" is then passed to
    ``Serialization.deserialize`` and returned together with the value stored under "id".

    :param obj: Bytes object to unpack.
    :param use_pickle: Set to true if one wishes to use pickle as a fallback deserializer.
    :param option: ormsgpack options can be specified through this parameter.
    :param \**kwargs: Optional extra keyword arguments, forwarded to the deserializers.
    :raise TypeError: Failed to deserialize the provided object.
    :raise ValueError: The provided bytes could not be decoded by ormsgpack.
    :return: Tuple of message identifier and unpacked object.
    """
    try:
        dict_obj = ormsgpack.unpackb(obj, option=option)
    except (TypeError, ValueError):
        # ormsgpack signals undecodable input with MsgpackDecodeError, a ValueError subclass;
        # catching TypeError alone would never log a decode failure. The exception is
        # re-raised unchanged.
        logger.exception(
            "Unpacking failed, consider 1) enabling use_pickle for"
            " inefficient/slow fallback to pickle, or 2) implement"
            " a serialization method for this type/structure, or 3)"
            " resolve the error by setting an option (if available)."
        )
        raise
    deserialized_object = Serialization.deserialize(
        dict_obj["object"], use_pickle=use_pickle, **kwargs
    )
    return dict_obj["id"], deserialized_object
def _validate_signature_has_kwargs(signature: inspect.Signature) -> None:
    """
    Validate that the provided signature has a variadic keyword (``**kwargs``) parameter.

    :param signature: Signature to validate.
    :raise TypeError: Signature does not contain kwargs.
    """
    for parameter in signature.parameters.values():
        if parameter.kind == inspect.Parameter.VAR_KEYWORD:
            return
    raise TypeError(
        "The provided (de)serializer does not accept a dict of keyword arguments that aren't "
        "bound to any other parameter, i.e. a '**kwargs' parameter. This is required in the "
        "function definition. These keyword arguments should also be forwarded to the next "
        "(de)serialization call."
    )
def _validate_provided_return_annotation(
    signature: inspect.Signature, types: tuple[str, ...]
) -> None:
    """
    Validate that the return annotation of the signature agrees with the provided types.

    The annotation is accepted when it is itself one of the type names (a string annotation) or
    when its ``__name__`` is one of them (a class annotation).

    :param signature: Signature to validate.
    :param types: Types that are supposedly consistent with the signature.
    :raise AnnotationError: Types and signature do not agree.
    """
    annotation = signature.return_annotation
    # String annotations (e.g. under `from __future__ import annotations`) and `None` have no
    # `__name__`; read it defensively so a mismatch raises AnnotationError, not AttributeError.
    annotation_name = getattr(annotation, "__name__", None)
    if annotation in types or annotation_name in types:
        return
    raise AnnotationError(
        f"Expected the provided deserialization function to return objects of type {types}, "
        f"but detected return type annotation for {signature.return_annotation}. Make sure "
        f"the function has type annotation '{types}' or set 'check_annotations' to False if "
        "this is intentional behaviour."
    )
def _validate_signature_accepts_keyword(
    signature: inspect.Signature, word: str
) -> None:
    """
    Validate that the signature declares a parameter with the given name.

    :param signature: Signature to validate.
    :param word: Parameter name to look for.
    :raise TypeError: Signature has no parameter with that name.
    """
    declared_parameters = signature.parameters
    try:
        declared_parameters[word]
    except KeyError as missing:
        message = (
            "The provided (de)serializer is missing the following parameter in its signature: "
            f"{word}."
        )
        raise TypeError(message) from missing
def _validate_signatures_consistent(
    serializer_signature: inspect.Signature, deserializer_signature: inspect.Signature
) -> None:
    """
    Validate that the serializer output annotation matches the deserializer input annotation.

    The return annotation of the serializer is compared with the annotation of the ``obj``
    parameter of the deserializer.

    :param serializer_signature: Signature of serializer.
    :param deserializer_signature: Signature of deserializer.
    :raise AnnotationError: Return type of serializer does not agree with expected input type of
        deserializer.
    """
    produced = serializer_signature.return_annotation
    accepted = deserializer_signature.parameters["obj"].annotation
    if produced != accepted:
        raise AnnotationError(
            f"Return type of serialization function ({produced}) "
            "does not match type of 'obj' parameter in deserialization function "
            f"({accepted})."
        )