"""
Player that performs the steps of the encrypted risk-score protocol on behalf of a bank
"""
import asyncio
from typing import Dict, Optional, Tuple, Union
import pandas as pd
from tno.mpc.communication import Pool
from tno.mpc.encryption_schemes.paillier import PaillierCiphertext
from tno.mpc.encryption_schemes.utils import FixedPoint
from tno.mpc.protocols.distributed_keygen import DistributedPaillier
from .bank import Bank
class Player:
    """
    Player class performing the steps of the protocol on behalf of one bank.

    A player owns the accounts of its own bank and keeps a mirror ``Bank`` for
    every other party in the communication pool. Risk scores are encrypted with
    a distributed Paillier scheme, exchanged with the other banks every
    iteration, updated homomorphically and finally decrypted.
    """

    # This factor is used to make sure that the computation of the risk scores
    # remains accurate enough (we have to compensate for scalar division).
    # Every call to `_compute_new_risk_scores` scales the scores by roughly this
    # factor; `decrypt` undoes it by dividing by COMPENSATION_FACTOR ** iterations.
    COMPENSATION_FACTOR = 10 ** 12

    def __init__(
        self,
        name: str,
        accounts: pd.DataFrame,
        transactions: pd.DataFrame,
        pool: Pool,
        paillier: DistributedPaillier,
        delta: float = 0.5,
    ):
        """
        Initializes a player instance

        :param name: the name of the player, also used as the name of its own bank
        :param accounts: a dataframe of accounts containing an initial risk score per account
        :param transactions: a dataframe of transactions
        :param pool: the communication pool to use; every handler in the pool is
            treated as another bank in the protocol
        :param paillier: an instance of DistributedPaillier
        :param delta: the delta used when processing this player's accounts
            (passed to ``Bank.process_accounts``)
        """
        self._iteration = 0
        self._pool = pool
        self._paillier = paillier
        self.bank: Bank = Bank(name)
        # One mirror Bank per peer in the communication pool.
        self._other_banks: Tuple[Bank, ...] = tuple(
            Bank(handler_name) for handler_name in self._pool.pool_handlers.keys()
        )
        self._decrypted_scores: Optional[Dict[str, FixedPoint]] = None
        self.bank.process_accounts(accounts, delta)
        # Every bank (own and mirrored peers) processes the full transaction set.
        for bank in self.banks:
            bank.process_transactions(transactions)

    @property
    def banks(self) -> Tuple[Bank, ...]:
        """
        All banks in the protocol

        :return: all banks in the protocol, this player's own bank first
        """
        return (self.bank,) + self.other_banks

    @property
    def other_banks(self) -> Tuple[Bank, ...]:
        """
        The other banks in the protocol

        :return: the other banks in the protocol (one per pool handler)
        """
        return self._other_banks

    @property
    def risk_scores(self) -> Dict[str, FixedPoint]:
        """
        The plaintext risk scores belonging to this player's bank

        :return: plaintext dictionary of risk scores
        :raise AttributeError: raised when risk scores are not available, i.e.
            `decrypt` has not completed yet
        """
        if self._decrypted_scores is None:
            raise AttributeError("Risk scores haven't been decrypted (yet)")
        return self._decrypted_scores

    def _compute_new_risk_scores(self) -> Dict[str, PaillierCiphertext]:
        """
        Computes new (encrypted) risk scores for the accounts of this player's bank.

        For each account the new score is::

            int(CF * (1 - delta)) * account.safe_risk_score
            + sum((int(CF * delta) // total_income) * amount * scores[incoming])

        where CF is COMPENSATION_FACTOR and the sum runs over
        ``account.linked_accounts``. The sum is skipped when the account's total
        income is zero (avoiding a division by zero). Each call therefore scales
        the scores by roughly CF.

        :return: dictionary mapping account names to their new encrypted risk scores
        """
        scores = self.bank.risk_scores
        updated_scores: Dict[str, PaillierCiphertext] = {}
        for account_name, account in self.bank.accounts_dict.items():
            scaled_delta_diff = int(self.COMPENSATION_FACTOR * (1 - account.delta))
            new_score = scaled_delta_diff * account.safe_risk_score
            if account.total_income != 0:
                # NOTE(review): the division happens once per account, before
                # multiplying by incoming_amount, so its truncation error is
                # amplified by the amount. Multiplying first would be more
                # precise but could overflow if the amounts are numpy int64;
                # confirm the types before changing.
                scaled_total_income_recip = (
                    int(self.COMPENSATION_FACTOR * account.delta)
                    // account.total_income
                )
                for incoming_account, incoming_amount in account.linked_accounts:
                    weight = scaled_total_income_recip * incoming_amount
                    new_score += weight * scores[incoming_account]
            updated_scores[account_name] = new_score
        return updated_scores

    async def _decrypt_bank(
        self, party: str
    ) -> Dict[str, Union[int, float, FixedPoint, None]]:
        """
        Decrypts the risk scores of party and reveals them to party.

        If party is one of the other banks, its encrypted scores are received
        from it. Otherwise this player's own encrypted scores are sent to every
        handler in the pool. Every score is then decrypted through the
        distributed Paillier scheme with ``receivers=[party]``.

        :param party: the party whose risk scores are decrypted
        :return: a dictionary of decrypted risk scores, as returned by
            ``DistributedPaillier.decrypt`` (presumably only meaningful for the
            receiving party; verify)
        """
        if any(other.name == party for other in self.other_banks):
            risk_scores = await self._pool.recv(party, msg_id=f"Decryption {party}")
        else:
            risk_scores = self.bank.get_risk_scores()
            await asyncio.gather(
                *(
                    self._pool.send(
                        handler, risk_scores, msg_id=f"Decryption {self.bank.name}"
                    )
                    for handler in self._pool.pool_handlers.keys()
                )
            )
        decrypted_risk_scores = {}
        # Kept sequential: presumably all players must run these joint
        # decryptions in the same order (confirm before parallelising).
        for key, risk_score in risk_scores.items():
            decrypted_risk_scores[key] = await self._paillier.decrypt(
                risk_score, receivers=[party]
            )
        return decrypted_risk_scores

    async def _receive_update(self) -> None:
        """
        Receives updated risk scores from all other banks
        """
        await asyncio.gather(
            *(self._receive_updated_risk_scores(bank) for bank in self.other_banks)
        )

    async def _receive_updated_risk_scores(self, bank: Bank) -> None:
        """
        Receives the updated risk scores sent by bank for the current iteration
        and stores them as external risk scores of this player's bank.

        :param bank: the bank to receive the update from
        """
        risk_scores = await self._pool.recv(
            bank.name, msg_id=f"Iteration {self._iteration}"
        )
        for label, risk_score in risk_scores.items():
            self.bank.set_risk_score(label, risk_score, external=True)

    async def _send_update(self) -> None:
        """
        Sends updated risk scores to all other banks
        """
        await asyncio.gather(
            *(self._send_updated_risk_scores(bank) for bank in self.other_banks)
        )

    async def _send_updated_risk_scores(self, bank: Bank) -> None:
        """
        Sends the current encrypted risk scores of ``bank.external_accounts``
        (as selected by ``Bank.get_risk_scores``) to bank.

        :param bank: the bank to send update to
        """
        risk_scores = self.bank.get_risk_scores(bank.external_accounts)
        await self._pool.send(
            bank.name, risk_scores, msg_id=f"Iteration {self._iteration}"
        )

    async def decrypt(self) -> None:
        """
        Decryption of the risk scores per bank, compensating for the COMPENSATION_FACTOR.

        The scores of every party in the distributed Paillier scheme are decrypted
        in turn, and this player takes part in each decryption. Only the scores of
        this player's own bank (party key ``"self"``) are kept. They are divided by
        COMPENSATION_FACTOR ** (number of iterations performed).

        :raise RuntimeError: raised when ``"self"`` is not among the party indices,
            so none of this player's own scores were decrypted
        """
        decrypted_scores = None
        for party in self._paillier.party_indices.keys():
            if party == "self":
                decrypted_scores = await self._decrypt_bank(party)
            else:
                await self._decrypt_bank(party)
        # Checked only after the loop so the other parties' joint decryptions
        # are not left waiting on this player.
        if decrypted_scores is None:
            raise RuntimeError(
                'No party "self" in the distributed Paillier party indices; '
                "own risk scores were not decrypted"
            )
        divisor = self.COMPENSATION_FACTOR ** self._iteration
        self._decrypted_scores = {
            account: scaled_score / divisor
            for account, scaled_score in decrypted_scores.items()
        }

    def encrypt_initial_risk_scores(self) -> None:
        """
        Encrypt the initialised risk scores of this player's accounts
        """
        self.bank.encrypt(self._paillier)

    async def iteration(self) -> None:
        """
        Perform a single iteration.

        This exchanges risk scores with all other banks, recomputes this bank's
        risk scores and advances the iteration counter.
        """
        await asyncio.gather(self._send_update(), self._receive_update())
        # Run the (potentially heavy) homomorphic update in the default
        # executor so the event loop stays responsive.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self.update_risk_scores)
        self._iteration += 1

    async def run_protocol(self, iterations: int) -> None:
        """
        Runs the entire protocol.

        The initial risk scores are encrypted, the given number of iterations is
        performed, and the resulting scores are decrypted.

        :param iterations: the number of iterations to perform
        """
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self.encrypt_initial_risk_scores)
        for _ in range(iterations):
            await self.iteration()
        await self.decrypt()

    def update_risk_scores(self) -> None:
        """
        Updates the risk scores of all accounts of this player's bank
        """
        updated_scores = self._compute_new_risk_scores()
        for account_name, risk_score in updated_scores.items():
            self.bank.set_risk_score(account_name, risk_score)