"""
Contains utils for matrices and vectors such as transposing and
secure signing
"""
from typing import TypeVar, Union, cast, overload
from mpyc.runtime import mpc
from mpyc.sectypes import SecureFixedPoint, SecureObject
from tno.mpc.mpyc.stubs.asyncoro import mpc_coro_ignore, returnType
from tno.mpc.mpyc.secure_learning.utils.types import Matrix, SecNumTypesTV, Vector
# Unconstrained type variable used below as the element type of the generic
# Vector[...] / Matrix[...] annotations.
AnyTV = TypeVar("AnyTV")
@overload
def matrix_transpose(matrix: Matrix[SecNumTypesTV]) -> Matrix[SecNumTypesTV]:
    ...


@overload
def matrix_transpose(matrix: Matrix[float]) -> Matrix[float]:
    ...


def matrix_transpose(matrix: Matrix[AnyTV]) -> Matrix[AnyTV]:
    """
    Transpose a matrix that is stored as a list of lists.

    .. code-block:: python

        A = [[31, 64], [32, 68], [33, 72], [34, 76]]
        matrix_transpose(A) == [[31, 32, 33, 34], [64, 68, 72, 76]]

    An empty matrix transposes to an empty list. Rows of unequal length are
    cut off at the length of the shortest row, because the columns are
    produced by `zip`.

    :param matrix: Matrix stored as list of lists
    :return: Transpose of `matrix` as a new list of new lists
    """
    # zip(*matrix) yields the columns as tuples; turn each one into a list so
    # the result has the same list-of-lists layout as the input.
    return [list(column) for column in zip(*matrix)]
@overload
def vector_to_matrix(
    vector: Vector[SecNumTypesTV], transpose: bool = False
) -> Matrix[SecNumTypesTV]:
    ...


@overload
def vector_to_matrix(vector: Vector[float], transpose: bool = False) -> Matrix[float]:
    ...


def vector_to_matrix(vector: Vector[AnyTV], transpose: bool = False) -> Matrix[AnyTV]:
    """
    Convert vector to matrix.

    .. code-block:: python

        vector_to_matrix([1, 2, 3]) == [[1, 2, 3]]
        vector_to_matrix([1, 2, 3], transpose=True) == [[1], [2], [3]]

    :param vector: Row vector to be converted
    :param transpose: Interpret vector as column vector
    :return: Matrix that encapsulates vector. Without `transpose` the single
        row of the result is the `vector` object itself, not a copy.
    """
    if transpose:
        # Column layout: every element becomes its own one-element row.
        return [[element] for element in vector]
    # Row layout: the vector is the one and only row.
    return [vector]
[docs]
def mat_to_vec(matrix: Matrix[AnyTV], transpose: bool = False) -> Vector[AnyTV]:
    """
    Extract the vector from a matrix that holds a single row or column.

    .. code-block:: python

        mat_to_vec([[1, 2, 3]]) == [1, 2, 3]
        mat_to_vec([[1], [2], [3]], transpose=True) == [1, 2, 3]

    :param matrix: Vector in matrix format
    :param transpose: Interpret `matrix` as a column vector, i.e. as a list
        of one-element rows; otherwise the first row is returned
    :return: Vector
    """
    if transpose:
        # The one-element unpacking rejects rows that do not hold exactly one
        # entry.
        return [entry for (entry,) in matrix]
    return matrix[0]
[docs]
def mat_vec_mult(
    matrix: Matrix[SecNumTypesTV],
    vector: Union[Vector[SecNumTypesTV], Vector[float]],
    transpose: bool = False,
) -> Vector[SecNumTypesTV]:
    """
    Compute matrix-vector multiplication.

    :param matrix: Matrix input with dimensions $m * r$.
        Dimensions may be $r * m$ when combined with `transpose=True`
    :param vector: Vector input of length $r$, treated as a column vector
    :param transpose: If `True`, first transpose `matrix`
    :return: Vector with the matrix-vector products
    """
    # Either way the product is computed on a new outer list, never on the
    # caller's own list object.
    left_operand = matrix_transpose(matrix) if transpose else matrix.copy()
    # The vector enters the product as a single-column matrix ...
    column = vector_to_matrix(vector, transpose=True)
    # ... and the single-column result is flattened back into a vector.
    return mat_to_vec(mpc.matrix_prod(left_operand, column), transpose=True)
@overload
def scale_vector_or_matrix(
    factor: float, x: Vector[SecNumTypesTV]
) -> Vector[SecNumTypesTV]:
    ...


@overload
def scale_vector_or_matrix(
    factor: float, x: Matrix[SecNumTypesTV]
) -> Matrix[SecNumTypesTV]:
    ...


def scale_vector_or_matrix(factor: float, x: Vector[AnyTV]) -> Vector[AnyTV]:
    """
    Scale every element of a vector or matrix by a given factor.

    Whether `x` is a matrix is decided by looking at its first element only:
    if that is a list, every element of `x` is scaled as a vector of its own.

    :param factor: Factor to scale matrix or vector
    :param x: Vector or matrix to be scaled; may be empty
    :return: Scaled vector or matrix as new lists; an empty input gives an
        empty list
    """
    # Nothing to scale. This also protects the x[0] probe below, which would
    # otherwise raise IndexError for empty input (or for empty matrix rows
    # reached through the recursion).
    if len(x) == 0:
        return []
    if isinstance(x[0], list):
        # The target type is quoted because cast() does not need it at
        # runtime.
        return cast(
            "Vector[AnyTV]", [scale_vector_or_matrix(factor, row) for row in x]
        )
    return [factor * element for element in x]
[docs]
@mpc_coro_ignore
async def mult_scalar_mul(
    scalars: Union[float, Vector[float], SecureFixedPoint, Vector[SecureFixedPoint]],
    matrix: Matrix[SecureFixedPoint],
    transpose: bool = False,
) -> Matrix[SecureFixedPoint]:
    """
    Vectorized version of mpc.scalar_mul.

    Every column of `matrix` is multiplied by its own scalar (every row when
    `transpose=True`). A single, non-list scalar is applied to all columns
    (rows). The rows of the input matrix are copied before anything is
    modified.

    .. code-block:: python

        scalars = [2, -1]
        mat = [[1, 2], [3, 4], [5, 6]]
        mult_scalar_mul(scalars, mat) == [[2, -2], [6, -4], [10, -6]]

    :param scalars: Single scalar or vector of scalars. Only the first
        scalar is inspected to decide whether the scalars are secure objects
        and whether they are integral.
    :param matrix: Matrix of which the columns need to be scaled. Must be
        non-empty: its first row and first element are inspected.
    :param transpose: If `True`, scale the rows of matrix instead.
    :return: Matrix with scaled columns
    """
    # Copy every row so the caller's lists are not modified in place.
    matrix = [row[:] for row in matrix]
    rows = len(matrix)
    columns = len(matrix[0])
    scalars_list: Union[Vector[float], Vector[SecureFixedPoint]]
    # A single scalar is repeated once per column (once per row when
    # transposing) so the loops below can always index into a list.
    if not isinstance(scalars, list):
        scalars_list = [scalars] * columns if not transpose else [scalars] * rows # type: ignore
    else:
        scalars_list = scalars
    # The secure type of the result is taken from the first matrix element.
    stype = type(matrix[0][0])
    frac_length = stype.frac_length
    # NOTE(review): returnType presumably announces the type and shape of the
    # result to the MPyC coroutine machinery before the computation starts --
    # confirm against the asyncoro stubs.
    if not frac_length:
        await returnType(stype, rows, columns)
    else:
        # The scalars count as integral when the first one is a plain int or
        # a SecureFixedPoint whose `integral` attribute is set.
        a_integral_first_requirement = isinstance(scalars_list[0], int)
        a_integral_second_requirement = False
        if isinstance(scalars_list[0], SecureFixedPoint):
            a_integral_second_requirement = scalars_list[0].integral
        a_integral = a_integral_first_requirement or a_integral_second_requirement
        await returnType((stype, a_integral and matrix[0][0].integral), rows, columns)
    # a_integral is only assigned when frac_length is non-zero; the
    # `frac_length and ...` conditions below short-circuit before reading it
    # in the other case.
    if not isinstance(scalars_list[0], SecureObject):
        # Scalars are not secure objects: multiply element-wise with `*`.
        for row in range(rows):
            for column in range(columns):
                matrix[row][column] = matrix[row][column] * (
                    scalars_list[column] if not transpose else scalars_list[row]
                )
    else:
        # Secure scalars: gather scalars and matrix before multiplying.
        # NOTE(review): presumably this yields the underlying share values --
        # confirm against the MPyC documentation.
        scalars_list_sec, matrix = await mpc.gather(scalars_list, matrix)
        if frac_length and a_integral:
            # Integral scalars are shifted right by frac_length before the
            # multiplication.
            for index, row in enumerate(scalars_list_sec):
                scalars_list_sec[index] = row >> frac_length # NB: no in-place rshift!
        for row in range(rows):
            for column in range(columns):
                matrix[row][column] = matrix[row][column] * (
                    scalars_list_sec[column] if not transpose else scalars_list_sec[row]
                )
            # Reshare a row once all of its columns have been multiplied.
            matrix[row] = await mpc._reshare(matrix[row])
        if frac_length and not a_integral:
            # Non-integral scalars: truncate every row with f=frac_length and
            # gather the truncated rows.
            for row in range(rows):
                matrix[row] = mpc.trunc(matrix[row], f=frac_length, l=stype.bit_length)
            matrix = await mpc.gather(matrix)
    return matrix
[docs]
@mpc_coro_ignore
async def matrix_sum(
    matrix: Matrix[SecureFixedPoint], cols: bool = False
) -> Vector[SecureFixedPoint]:
    """
    Securely add up the rows (or the columns) of a matrix.

    :param matrix: Matrix to be summed. Must be non-empty: its first element
        determines the secure type of the result.
    :param cols: If `False`, the rows are added together, giving one total
        per column. If `True`, the columns are added together instead,
        giving one total per row.
    :return: Vector of sums
    """
    # Each inner list of `groups` is summed into one output element. Adding
    # rows together means summing per column, hence the transpose.
    groups = [row[:] for row in matrix]
    if not cols:
        groups = matrix_transpose(groups)
    first_element = groups[0][0]
    sectype = type(first_element)
    # Fixed-point types with a fractional part are declared together with the
    # `integral` flag of the first element; other types are declared as-is.
    if sectype.frac_length:
        await returnType((sectype, first_element.integral), len(groups))
    else:
        await returnType(sectype, len(groups))
    return [mpc.sum(group) for group in groups]