"""
Client module for cox regression
"""
from __future__ import annotations
import numpy as np
import numpy.typing as npt
import pandas as pd
from tno.fl.protocols.logistic_regression.client import Client as LogRegClient
from tno.fl.protocols.cox_regression import msg_ids, survival_stacking
from tno.fl.protocols.cox_regression.config import Config
[docs]
class Client(LogRegClient):
    """
    The client class, representing data owning clients in the learning process.

    Based on the logistic regression client: it reuses that client, appends the
    time column to the loaded data and adds a survival-stacking preprocessing step.
    """

    def __init__(self, config: Config, client_name: str) -> None:
        """
        Initializes the client

        :param config: The configuration for the experiment
        :param client_name: The name of this client, passed on unchanged to the
            logistic regression client initializer.
        """
        # Annotation only (no assignment): narrows the type of ``self.config`` to the
        # cox regression ``Config`` for type checkers. The value itself is presumably
        # set by the base-class ``__init__`` called below.
        self.config: Config
        super().__init__(config, client_name)

    def load_data(self) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.bool_]]:
        """
        Load the data as in logistic regression, but add the time column to the data.

        The time column is the last column in the data and can hence be used in
        further analysis.

        :return: A data set containing covariates and time column and an array
            containing failure data.
        :raises FileNotFoundError: if the training data file does not exist.
        """
        path = self.client.train_data_path
        if not path.exists():
            # Use a single message string: two positional arguments to an OSError
            # subclass are interpreted as (errno, strerror), which would render as
            # "[Errno The training data does not exist: ] <path>".
            raise FileNotFoundError(f"The training data does not exist: {path}")
        csv_data: pd.DataFrame = pd.read_csv(path)
        # The time column is placed after the covariates so later steps can
        # address it as the last column (``data[:, -1]``).
        return (
            csv_data[self.config.data_columns + [self.config.time_column]].to_numpy(),
            csv_data[self.config.target_column].to_numpy(),
        )

    async def get_global_max_time(self) -> int:
        """
        Get the global maximum event time, used as the upper bound of the time bins.

        The local maximum of the time column is sent to the server. The global
        maximum received back is converted with ``int`` (truncation) and
        incremented by one.

        :return: The global maximum event time, truncated to an integer, plus one.
        """
        # Share local max time with server. The time column is the last column of
        # self.data (see load_data); numpy raises ValueError if self.data is empty.
        local_max_time = self.data[:, -1].max()
        await self.pool.send(
            self.SERVER_ID, local_max_time, msg_id=msg_ids.LOCAL_MAX_TIME
        )
        # Receive global max event time from server.
        # NOTE(review): the ``+ 1`` presumably keeps the upper bin edge strictly
        # above every observed (non-negative) event time; confirm.
        global_max_time = await self.pool.recv(
            self.SERVER_ID, msg_id=msg_ids.GLOBAL_MAX_TIME
        )
        return int(global_max_time) + 1

    def compute_time_bins(self, global_max_time: int) -> npt.NDArray[np.float64]:
        """
        Compute time bins given a max event time

        :param global_max_time: The global max event time
        :return: ``self.config.n_bins`` evenly spaced bin edges from 0 up to and
            including ``global_max_time`` (``np.linspace`` includes the endpoint).
        """
        # NOTE(review): n_bins edges delimit n_bins - 1 intervals; confirm this is
        # what survival_stacking.stack expects for ``time_bins``.
        return np.linspace(0, global_max_time, self.config.n_bins)

    async def preprocessing(self) -> None:
        """
        Preprocess the data: create time bins and stack the data.

        Agrees on a global maximum event time with the server, derives the time
        bins from it, and replaces ``self.data`` and ``self.target`` with the
        survival-stacked covariates and targets.
        """
        global_max_time = await self.get_global_max_time()
        time_bins = self.compute_time_bins(global_max_time)
        # self.data holds the covariates followed by the time column (last column).
        self.data, self.target = survival_stacking.stack(
            covariates=self.data[:, :-1],
            times=self.data[:, -1],
            failed=self.target,
            time_bins=time_bins,
        )