Source code for communication.serializer_plugins.pandas

"""
(De)serialization logic for pandas objects.
"""

from __future__ import annotations

import datetime
import io
import sys
import warnings
from typing import Any, Callable, Dict, cast

from packaging.version import parse

from tno.mpc.communication.exceptions import OptionalImportError
from tno.mpc.communication.functions import (
    redirect_importerror_oserror_to_optionalimporterror,
)
from tno.mpc.communication.serialization import Serialization

with redirect_importerror_oserror_to_optionalimporterror():
    import numpy as np
    import pandas as pd
    from pandas import DataFrame, Series

# ``ArrowInvalid`` is caught by the dataframe serializer when parquet rejects element types.
# pyarrow is optional: if importing it fails with OptionalImportError, a stand-in exception
# class is defined so that the ``except ArrowInvalid`` clauses in this module stay valid.
# Nothing in this module raises the stand-in itself.
try:
    with redirect_importerror_oserror_to_optionalimporterror():
        from pyarrow import ArrowInvalid
except OptionalImportError:

    class ArrowInvalid(Exception):  # type: ignore[no-redef]
        """Dummy exception class in case pyarrow is unavailable."""


# Element types that are left untouched when a dataframe is prepared for parquet; elements of
# any other type are first passed through ``Serialization.serialize``.
ARROW_SUPPORTED_TYPES = (
    bool,
    datetime.datetime,
    float,
    int,
    type(None),  # https://stackoverflow.com/a/41928862
    np.number,
    str,
)

# Placeholder column name given to an unnamed ``pd.Series`` while it travels as a dataframe.
TEMP_COLUMN_NAME = "TNO_MPC_COMMUNICATION_TEMPNAME"

# Name of the element-wise DataFrame method: pandas 2.1.0 deprecates ``applymap`` in favour
# of ``map``, so pick whichever matches the installed pandas version.
if parse(pd.__version__) >= parse("2.1.0"):
    DF_MAPPING_METHOD = "map"
else:
    DF_MAPPING_METHOD = "applymap"


def pandas_serialize_dataframe(  # pylint: disable=missing-raises-doc
    obj: DataFrame, use_pickle: bool, **kwargs: Any
) -> bytes | dict[str, Any]:
    r"""
    Function for serializing pandas dataframes

    Attempt to use parquet for smaller serialized dataframe, but fallback to dictionaries
    otherwise. The strategy is:

    1. Serialize the dataframe directly with parquet.
    2. If that fails with ``ArrowInvalid`` or ``OverflowError``, serialize every element that
       is not natively parquet-compatible with :meth:`Serialization.serialize` and retry
       parquet on the partially serialized dataframe.
    3. If parquet is unavailable (``ImportError``), the column names are not strings, or the
       retry also fails with ``ArrowInvalid``, fall back to
       ``DataFrame.to_dict(orient="split")``. Other ``ValueError``\s are re-raised.

    :param obj: pandas object to serialize
    :param use_pickle: set to True to enable serialization fallback to pickle
    :param \**kwargs: optional extra keyword arguments, forwarded to
        :meth:`Serialization.serialize` for elements that parquet cannot store natively
    :return: serialized dataframe, either parquet bytes or the dictionary produced by
        ``DataFrame.to_dict(orient="split")``
    """
    try:
        # Attempt to serialize with parquet
        return obj.to_parquet()
    except ImportError:
        warnings.warn(
            "Package tno.mpc.communication more efficiently serializes pandas objects (with "
            "built-in type elements) with parquet, which requires additional dependencies. Please "
            "consider installing tno.mpc.communication[pandas]."
        )
    except (ArrowInvalid, OverflowError):
        # Object contains unsupported types. We serialize these and let parquet do the rest.
        # Python ints with more bits than ``sys.maxsize`` are treated as unsupported even though
        # ``int`` is listed in ARROW_SUPPORTED_TYPES.
        max_int_bit_length = sys.maxsize.bit_length()
        is_parquet_serializable: Callable[[Any], bool] = lambda x: (
            isinstance(x, ARROW_SUPPORTED_TYPES)
            and not (isinstance(x, int) and x.bit_length() > max_int_bit_length)
        )
        obj_partially_serialized: pd.DataFrame = getattr(obj, DF_MAPPING_METHOD)(
            lambda x: (
                x
                if is_parquet_serializable(x)
                else Serialization.serialize(x, use_pickle=use_pickle, **kwargs)
            ),
        )
        try:
            return obj_partially_serialized.to_parquet()
        except ArrowInvalid:
            # Still not parquet-compatible; use the dictionary fallback below.
            pass
    except ValueError as exc:
        # Turn a very specific exception into a warning, re-raise unperturbed otherwise.
        # ``str(exc)`` is used instead of ``exc.args[0]``: the latter raises IndexError for a
        # ValueError without arguments (and TypeError for a non-string first argument), which
        # would mask the original error.
        if "string column" in str(exc):
            # Parquet requires string column names.
            warnings.warn(
                "Failed to serialize a pandas object with parquet as the column names are not of "
                "type <str>. This might be resolved by using "
                "'df.columns = df.columns.astype(str)'. Falling back to serialization via "
                "dictionary."
            )
        else:
            raise
    # Fall-back to dictionary serialization
    return cast(Dict[str, Any], obj.to_dict(orient="split"))
def pandas_deserialize_dataframe(
    obj: bytes | dict[str, Any], use_pickle: bool, **_kwargs: Any
) -> DataFrame:
    r"""
    Function for deserializing pandas dataframe

    Accepts both formats emitted by :func:`pandas_serialize_dataframe`: parquet bytes, or a
    dictionary of ``pd.DataFrame`` constructor keyword arguments (as produced by
    ``DataFrame.to_dict(orient="split")``). Every element that is a dictionary with both a
    ``"type"`` and a ``"data"`` key is then passed through :meth:`Serialization.deserialize`.

    NOTE(review): a genuine dictionary element that happens to have both keys is treated as a
    serialized object as well.

    :param obj: pandas dataframe to deserialize
    :param use_pickle: set to True to enable serialization fallback to pickle
    :param \**_kwargs: optional extra keyword arguments
    :raise ImportError: Object was serialized with parquet, but required dependencies for
        deserialization are missing.
    :return: deserialized dataframe
    """
    if not isinstance(obj, bytes):
        # Dictionary fallback format: feed the stored keyword arguments straight back.
        dataframe = pd.DataFrame(**obj)
    else:
        try:
            dataframe = pd.read_parquet(io.BytesIO(obj))
        except ImportError as exc:
            raise ImportError(
                "The pandas object was serialized to parquet, but the required dependencies for "
                "deserializing this format are missing. Please install "
                "tno.mpc.communication[pandas]."
            ) from exc

    def restore_cell(cell: Any) -> Any:
        # Only cells that look like a serialized object are handed to Serialization.
        looks_serialized = isinstance(cell, dict) and "type" in cell and "data" in cell
        if looks_serialized:
            return Serialization.deserialize(cell, use_pickle=use_pickle)
        return cell

    restored: pd.DataFrame = getattr(dataframe, DF_MAPPING_METHOD)(restore_cell)
    return restored
def pandas_serialize_series(obj: Series[Any], **_kwargs: Any) -> bytes | dict[str, Any]:
    r"""
    Function for serializing pandas series

    The series is wrapped in a single-column dataframe and handed to
    :func:`pandas_serialize_dataframe`. An unnamed series gets the placeholder column name
    ``TEMP_COLUMN_NAME``, which :func:`pandas_deserialize_series` turns back into ``None``.

    :param obj: pandas series to serialize
    :param \**_kwargs: optional extra keyword arguments, passed on unchanged to
        :func:`pandas_serialize_dataframe` (which requires ``use_pickle``)
    :return: serialized series
    """
    if obj.name is not None:
        as_frame = pd.DataFrame(obj)
    else:
        as_frame = pd.DataFrame(obj, columns=[TEMP_COLUMN_NAME])
    return pandas_serialize_dataframe(as_frame, **_kwargs)
def pandas_deserialize_series(
    obj: bytes | dict[str, Any], **kwargs: Any
) -> Series:  # type: ignore[type-arg]
    r"""
    Function for deserializing pandas series

    The payload is decoded as a dataframe and its first column is returned. If that column
    carries the placeholder name ``TEMP_COLUMN_NAME`` (used for unnamed series), the name is
    reset to ``None``.

    :param obj: pandas series to deserialize
    :param \**kwargs: optional extra keyword arguments, passed on to
        :func:`pandas_deserialize_dataframe`
    :return: deserialized series
    """
    first_column = pandas_deserialize_dataframe(obj, **kwargs).iloc[:, 0]
    if first_column.name == TEMP_COLUMN_NAME:
        # Undo the placeholder introduced when the original series had no name.
        first_column.name = None
    return first_column
def pandas_serialize_timestamp(
    obj: pd.Timestamp, use_pickle: bool, **_kwargs: Any
) -> str:
    r"""
    Function for serializing pandas timestamp

    The timestamp is rendered as an ISO 8601 string with ``pd.Timestamp.isoformat``, which keeps
    nanosecond precision (and the UTC offset of timezone-aware timestamps), so that
    ``pd.Timestamp(result)`` reproduces the original instant.

    :param obj: pandas timestamp to serialize
    :param use_pickle: set to True to enable serialization fallback to pickle
    :param \**_kwargs: optional extra keyword arguments
    :return: serialized timestamp
    """
    # ``obj.to_pydatetime().isoformat()`` would truncate nanoseconds (with a UserWarning), so a
    # round trip lost precision. ``Timestamp.isoformat`` yields the same string whenever the
    # nanosecond component is zero and appends the nanoseconds otherwise.
    return obj.isoformat()
def pandas_deserialize_timestamp(
    obj: str, use_pickle: bool, **kwargs: Any
) -> pd.Timestamp:
    r"""
    Function for deserializing pandas timestamp

    Parses the textual representation back into a ``pd.Timestamp``.

    :param obj: pandas timestamp to deserialize
    :param use_pickle: set to True to enable serialization fallback to pickle
    :param \**kwargs: optional extra keyword arguments
    :return: deserialized timestamp
    """
    parsed = pd.Timestamp(obj)
    return parsed
def register() -> None:
    """
    Register the pandas (de)serializers with :class:`Serialization`.

    Handlers are registered, in order, for ``pd.DataFrame``, ``pd.Series`` and
    ``pd.Timestamp``, each keyed by the class ``__name__``. The timestamp pair is registered
    with ``check_annotations=False``.
    """
    registrations = (
        (pandas_serialize_dataframe, pandas_deserialize_dataframe, pd.DataFrame.__name__, {}),
        (pandas_serialize_series, pandas_deserialize_series, pd.Series.__name__, {}),
        (
            pandas_serialize_timestamp,
            pandas_deserialize_timestamp,
            pd.Timestamp.__name__,
            {"check_annotations": False},
        ),
    )
    for serializer, deserializer, type_name, options in registrations:
        Serialization.register(serializer, deserializer, type_name, **options)