Source code for pyDRESCALk.dist_rescal

# @author: Namita Kharat, Manish Bhattarai
from numpy import matlib
from . import config

from .utils import *


class rescal_algorithms_2D():
    """
    Performs the distributed RESCAL operation along a 2D cartesian grid

    Parameters:
        X_ijk (ndarray) : Distributed data tensor slices
        A_i (ndarray) : Distributed factor A (row block)
        A_j (ndarray) : Distributed factor A (column block)
        R_ijk (ndarray) : Distributed factor R slices
        params (class) : Class which comprises the following attributes
        params.comm1 (object) : Global communicator
        params.comm (object) : Modified communicator object
        params.k (int) : Rank for decomposition
        params.m (int) : Global dimension m
        params.n (int) : Global dimension n
        params.p_r (int) : Cartesian grid row count
        params.p_c (int) : Cartesian grid column count
        params.row_comm (object) : Sub communicator along rows
        params.col_comm (object) : Sub communicator along columns
        params.A_update (bool) : Flag to set the A update True/False
        params.norm (str) : RESCAL norm to be minimized
        params.method (str) : RESCAL optimization method
        params.eps (float) : Epsilon value
    """

    @comm_timing()
    def __init__(self, X_ijk, A_i, A_j, R_ijk, params=None):
        self.params = params
        self.m, self.n, self.p_r, self.p_c, self.k = self.params.m, self.params.n, self.params.p_r, self.params.p_c, self.params.k
        self.comm1 = self.params.comm1
        self.cartesian1d_row, self.cartesian1d_column, self.comm = self.params.row_comm, self.params.col_comm, self.params.comm
        self.X_ijk, self.A_i, self.A_j, self.R_ijk = X_ijk, A_i, A_j, R_ijk
        self.eps = self.params.eps
        self.p = self.p_r * self.p_c
        self.A_update = self.params.A_update
        self.norm = self.params.norm
        self.method = self.params.method
        self.rank = self.comm1.rank
        self.local_A_n = self.A_i.shape[0]
        self.local_R_m = self.R_ijk.shape[0]
        self.np = self.params.np

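    # A hypothetical construction sketch (the variable names below are
    # illustrative, not part of this module): every rank of a p_r x p_c grid
    # builds the object from its local blocks plus a params object exposing
    # the attributes documented above, e.g.
    #
    #   algo = rescal_algorithms_2D(X_ijk_local, A_i_local, A_j_local,
    #                               R_ijk_local, params=params)
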
    def update(self):
        """Performs one update step for the factors A and R based on the RESCAL
        method and the corresponding norm minimization

        Returns
        -------
        A_i : ndarray
            Distributed factor A (row block)
        A_j : ndarray
            Distributed factor A (column block)
        R_ijk : ndarray
            Distributed factor R slices
        """
        if self.norm.upper() == 'FRO':
            if self.method.upper() == 'MU':
                self.Fro_MU_update(self.A_update)
            else:
                raise Exception('Not a valid method: Choose (mu)')
        else:
            raise Exception('Not a valid norm: Choose (fro)')
        return self.A_i, self.A_j, self.R_ijk

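    # A minimal driver sketch (the loop bound and names are assumed, not part
    # of this class): callers typically iterate update() until the factors
    # stabilize, e.g.
    #
    #   for _ in range(max_iters):
    #       A_i, A_j, R_ijk = algo.update()
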
    @comm_timing()
    def row_reduce(self, A):
        """Performs an all-reduce (sum) along the row sub communicator"""
        A_TA_glob = self.cartesian1d_row.allreduce(A, op=MPI.SUM)
        self.cartesian1d_row.barrier()
        return A_TA_glob

    @comm_timing()
    def column_reduce(self, A):
        """Performs an all-reduce (sum) along the column sub communicator"""
        A_TA_glob = self.cartesian1d_column.allreduce(A, op=MPI.SUM)
        self.cartesian1d_column.barrier()
        return A_TA_glob

    @comm_timing()
    def row_broadcast(self, A):
        """Performs a broadcast along the row sub communicator"""
        A_broadcast = self.cartesian1d_row.bcast(A, root=self.cartesian1d_column.Get_rank())
        self.cartesian1d_row.barrier()
        return A_broadcast

    @comm_timing()
    def column_broadcast(self, A):
        """Performs a broadcast along the column sub communicator"""
        A_column_broadcast = self.cartesian1d_column.bcast(A, root=self.cartesian1d_row.Get_rank())
        self.cartesian1d_column.barrier()
        return A_column_broadcast

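    # Note on the broadcast roots: assuming row_comm groups ranks that share a
    # grid row and col_comm groups ranks that share a grid column, every member
    # of a row communicator has the same column-communicator rank, so the chosen
    # root corresponds to the "diagonal" processor of that row (and vice versa
    # for column_broadcast).
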
    @count_memory()
    @count_flops()
    @comm_timing()
    def matrix_mul(self, A, B):
        """Computes the local matrix multiplication of matrices A and B"""
        AB_local = A @ B
        return AB_local

    @count_memory()
    @count_flops()
    @comm_timing()
    def gram_mul(self, A):
        """Computes the local Gram operation of matrix A"""
        A_TA_local = A.T @ A
        return A_TA_local

    @comm_timing()
    def global_gram(self, A):
        r"""Distributed Gram computation

        Computes the global Gram operation of matrix A

        .. math:: A^TA

        Parameters
        ----------
        A : ndarray

        Returns
        -------
        A_TA_glob : ndarray
        """
        A_TA_loc = self.gram_mul(A)
        A_TA_glob = self.row_reduce(A_TA_loc)
        return A_TA_glob

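    # Usage sketch (shapes assumed for illustration): if each rank in a grid
    # row holds a local (m_loc x k) block A_loc of a conformally partitioned A,
    # then
    #
    #   AtA = self.global_gram(A_loc)   # k x k, identical on all ranks of the row
    #
    # because the local Gram products A_loc.T @ A_loc sum to the global A^T A.
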
    @comm_timing()
    def row_mm(self, A, B):
        r"""Distributed matrix multiplication along the rows of the grid

        Computes the matrix multiplication of matrices A and B along the row sub communicator

        .. math:: AB

        Parameters
        ----------
        A : ndarray
        B : ndarray

        Returns
        -------
        AB_glob : ndarray
        """
        AB_loc = self.matrix_mul(A, B)
        AB_glob = self.row_reduce(AB_loc)
        return AB_glob

    @comm_timing()
    def column_mm(self, A, B):
        r"""Distributed matrix multiplication along the columns of the grid

        Computes the matrix multiplication of matrices A and B along the column sub communicator

        .. math:: AB

        Parameters
        ----------
        A : ndarray
        B : ndarray

        Returns
        -------
        AB_glob : ndarray
        """
        AB_loc = self.matrix_mul(A, B)
        AB_glob = self.column_reduce(AB_loc)
        return AB_glob

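    # row_mm and column_mm share one pattern: form the local partial product,
    # then sum-reduce it over the sub communicator that spans the contracted
    # dimension. Sketch (block names and shapes assumed for illustration):
    #
    #   XAj = self.column_mm(X_loc, A_j_loc)   # local X_loc @ A_j_loc, summed
    #                                          # over the column sub communicator
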
    @count_memory()
    @count_flops()
    def element_op(self, A, B, operation):
        """Performs element-wise operations ("mul" or "div") between A and B"""
        if operation == "mul":
            return A * B
        else:
            return A / B

    def Fro_MU_update(self, A_update=True):
        r"""Frobenius norm based multiplicative update of the A and R parameters

        Computes the updated A and R parameters for each MPI rank

        Parameters
        ----------
        A_update : bool
            Flag to enable/disable the A update (mirrors self.A_update)

        Returns
        -------
        self.A_i : ndarray
        self.R_ijk : ndarray
        """
        AtA = self.global_gram(self.A_i)  # internally row reduce
        NumeratorA = self.np.zeros(self.A_i.shape).astype(self.A_i.dtype)
        DenominatorA = self.np.zeros(self.A_i.shape).astype(self.A_i.dtype)
        for x in range(self.m):
            # Compute Rx
            XAj = self.column_mm(self.X_ijk[x], self.A_j)  # internally column reduce
            AtXA = self.row_mm(self.A_i.T, XAj)
            RAtA = self.matrix_mul(self.R_ijk[x], AtA)
            DenominatorR = self.matrix_mul(AtA, RAtA) + self.eps
            temp = self.element_op(AtXA, DenominatorR, "div")
            self.R_ijk[x] = self.element_op(self.R_ijk[x], temp, "mul")
            # Compute A
            if self.A_update:
                XARt = self.matrix_mul(XAj, self.R_ijk[x].T)
                AR = self.matrix_mul(self.A_i, self.R_ijk[x])
                XtAR = self.row_mm(self.X_ijk[x].T, AR)
                XtAR = self.column_broadcast(XtAR)
                NumeratorA += XARt + XtAR
                AtAR = self.matrix_mul(AtA, self.R_ijk[x])
                ARt = self.matrix_mul(self.A_i, self.R_ijk[x].T)
                ARtAtAR = self.matrix_mul(ARt, AtAR)
                AtARt = self.matrix_mul(AtA, self.R_ijk[x].T)
                ARAtARt = self.matrix_mul(AR, AtARt)
                DenominatorA += ARtAtAR + ARAtARt + self.eps
        if self.A_update:
            tempA = self.element_op(NumeratorA, DenominatorA, "div")
            self.A_i = self.element_op(self.A_i, tempA, "mul")
            self.A_j = self.row_broadcast(self.A_i)
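
# For reference, the multiplicative updates realized by Fro_MU_update above
# follow the Frobenius-norm RESCAL scheme, applied per slice x (with eps added
# for numerical stability):
#
#   R_x <- R_x * (A^T X_x A) / (A^T A R_x A^T A + eps)
#   A   <- A * sum_x (X_x A R_x^T + X_x^T A R_x)
#              / sum_x (A R_x^T A^T A R_x + A R_x A^T A R_x^T + eps)
#
# where * and / are element-wise; the code assembles each numerator and
# denominator term from distributed matrix products and grid-wise reductions.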