# Source code for secure_learning.models.secure_lasso

"""
Implementation of LASSO regression model.
"""
from tno.mpc.mpyc.secure_learning.models.secure_linear import Linear
from tno.mpc.mpyc.secure_learning.models.secure_model import PenaltyTypes, SolverTypes


class Lasso(Linear):
    r"""
    Solver for LASSO regression.

    Optimizes a model with objective function

    $$\frac{1}{2 n_{\textrm{samples}}} \|y - Xw\|_2^2 + \alpha \|w\|_1$$

    where :math:`X` is the data matrix, :math:`y` the target vector,
    :math:`w` the model weights and :math:`\alpha` the ``alpha``
    regularization parameter passed to the constructor.
    """

    name = "LASSO regression"

    def __init__(
        self,
        solver_type: SolverTypes = SolverTypes.GD,
        alpha: float = 1,
    ) -> None:
        """
        Constructor method.

        Configures the underlying linear model with an L1 penalty, which is
        what distinguishes LASSO from plain linear regression.

        :param solver_type: Solver type to use (e.g. Gradient Descent aka GD)
        :param alpha: Regularization parameter; the weight of the L1 penalty
            term on the model weights in the objective function
        """
        # The L1 penalty is fixed here; only its strength (alpha) is configurable.
        super().__init__(solver_type, penalty=PenaltyTypes.L1, alpha=alpha)