Source code for cox_regression.server

"""
Server module for cox regression
"""

from __future__ import annotations

import logging

import numpy as np

from tno.fl.protocols.logistic_regression.server import Server as LogRegServer
from tno.mpc.communication import Pool

from tno.fl.protocols.cox_regression import msg_ids
from tno.fl.protocols.cox_regression.survival_stacking import TimeBinType

# Module-level logger, named after this module; used for progress messages in Server.run.
logger = logging.getLogger(__name__)


class Server:
    """
    The Server class. Responsible for aggregating results of the clients.
    Based on the logistic regression server.
    """

    def __init__(self, pool: Pool, n_time_bins: int, max_iter: int = 25) -> None:
        """
        Initializes the server.

        :param pool: The communication pool.
        :param n_time_bins: The number of values produced by _split_time_bins
            when the time bins are generated automatically. Not used when
            explicit time bins are passed to run().
        :param max_iter: The max number of epochs (forwarded to the logistic
            regression server).
        """
        self.pool = pool
        self.n_time_bins = n_time_bins
        # The actual training is delegated to the logistic regression server
        # (see run() and compute_statistics()).
        self.log_reg_solver = LogRegServer(pool=pool, max_iter=max_iter)

    async def _compute_global_max_time(self) -> float:
        """
        Receive the local maximum event times of the clients and return the
        largest one.

        This method only receives; nothing is sent to the clients here. The
        clients are informed afterwards, in run(), via the time-bin broadcast.

        :return: The global maximum event time.
        :raises ValueError: If no local maximum event times were received.
        """
        local_max_times = await self.pool.recv_all(msg_id=msg_ids.LOCAL_MAX_TIME)
        # Each received item is unpacked as an (<ignored>, max_time) pair;
        # only the max_time value is used.
        max_times = [max_time for _, max_time in local_max_times]
        if not max_times:
            raise ValueError("No local maximum event times were received.")
        return float(max(max_times))

    def _split_time_bins(self, global_max_time: float) -> TimeBinType:
        """
        Compute time bins evenly given a max event time.

        Returns ``self.n_time_bins`` evenly spaced values from 0 up to and
        including ``global_max_time + 1``. The ``+ 1`` makes the last value
        strictly larger than the largest observed event time.

        :param global_max_time: The global max event time
        :return: The time bins
        """
        return np.linspace(0, global_max_time + 1, self.n_time_bins)

    async def _compute_time_bins(
        self, time_bins: TimeBinType | None = None
    ) -> TimeBinType:
        """
        Compute time bins, based on the input of the clients.

        The clients' local maximum event times are received in both cases.
        Explicitly given bins and automatically generated bins are validated
        against the global maximum in the same way.

        :param time_bins: Optional parameter specifying the time bins. If None,
            the bins will be spaced according to the _split_time_bins function.
        :return: The time bins (the given object itself when time_bins is not None).
        :raises ValueError: If the time bins are empty or their maximum is smaller
            than the global maximum event time (this includes automatically
            generated bins when n_time_bins is too small to cover the data).
        """
        global_max_time = await self._compute_global_max_time()
        if time_bins is None:
            time_bins = self._split_time_bins(global_max_time)
        self._validate_time_bins(time_bins, global_max_time)
        return time_bins

    @staticmethod
    def _validate_time_bins(time_bins: TimeBinType, global_max_time: float) -> None:
        """
        Check that the time bins are non-empty and cover the global max time.

        :param time_bins: The candidate time bins.
        :param global_max_time: The global maximum event time.
        :raises ValueError: If the bins are empty or do not reach global_max_time.
        """
        # np.max on an empty array raises an opaque error; report it explicitly.
        if np.size(time_bins) == 0:
            raise ValueError("Time bins must not be empty.")
        if np.max(time_bins) < global_max_time:
            raise ValueError("Global max time is greater than maximum time bin.")

    async def run(self, time_bins: TimeBinType | None = None) -> None:
        """
        Runs the entire learning process.

        First the time bins are determined (generated from, or validated
        against, the clients' maximum event times) and broadcast to all
        clients. Then the logistic regression server is run.

        :param time_bins: Optional parameter specifying the time bins. If None,
            the bins will be spaced according to the _split_time_bins function.
        :raises ValueError: If the time bins are invalid (see _compute_time_bins).
        """
        # Compute the time bins and distribute them to the clients
        logger.info("Computing time bins..")
        time_bins = await self._compute_time_bins(time_bins)
        await self.pool.broadcast(time_bins, msg_ids.TIME_BINS)
        # Perform the logistic regression
        logger.info("Starting Logistic Regression..")
        await self.log_reg_solver.run()

    async def compute_statistics(self) -> None:
        """
        Perform server role in computing the statistics.

        Delegates entirely to the logistic regression server.
        """
        await self.log_reg_solver.compute_statistics()