Source code for pyDRESCALk.pyDRESCAL

# @author:  Namita Kharat,Manish Bhattarai
import numpy as np
from .data_io import *
from .dist_rescal import *
#from .dist_svd import *
from .utils import *
from numpy.random import SeedSequence, default_rng
from multiprocessing import Pool


class pyDRESCAL():
    r"""
    Performs the distributed RESCAL decomposition of a local tensor chunk.

    Every local slice ``X_ijk[i]`` is approximated as
    ``A_i @ R_ijk[i] @ A_j.T`` (see :meth:`relative_err`).

    Parameters
    ----------
    X_ijk : sequence of ndarray
        Local chunk of the data. ``len(X_ijk)`` is the number of slices and
        ``X_ijk[0]`` is a 2-D array whose dtype is used for all factors.
    factors : tuple, optional
        Pre-computed factors ``(A_i, R_ijk)``. When omitted the factors are
        initialised randomly by :meth:`init_factors`.
    save_factors : bool, optional
        When true, :meth:`fit` writes ``[A_i, R_ijk]`` through ``data_write``.
    params : object
        Attribute container. The following attributes are read directly:

        params.np (module) : array module used for all numerics
        params.init (str) : initialisation name; defaults to 'rand' when falsy
        params.p_r (int) : Cartesian grid row count
        params.p_c (int) : Cartesian grid column count (must equal p_r)
        params.k (int) : rank of the decomposition
        params.comm1 (object) : global communicator
        params.comm (object) : communicator freed at the end of fit()
        params.row_comm (object) : sub communicator along rows
        params.col_comm (object) : sub communicator along columns
        params.verbose (bool) : print the relative error on rank 0;
            ``None`` is treated as ``True``
        params.perturbation (int) : perturbation index used to derive seeds;
            defaults to 0 when falsy

        The following are read through ``var_init`` with the given defaults:
        ``norm`` ('fb'), ``method`` ('mu'), ``itr`` (5000),
        ``A_update`` (True), ``symmetric`` (False).

        The constructor also writes ``params.eps``, ``params.itr``,
        ``params.globaln``, ``params.globalm``, ``params.n`` and ``params.m``.

    Raises
    ------
    Exception
        If ``params.p_r != params.p_c``.
    """

    @comm_timing()
    def __init__(self, X_ijk, factors=None, save_factors=False, params=None):
        self.X_ijk = X_ijk
        self.params = params
        self.np = self.params.np
        # Number of local slices and the row count of one slice.
        self.m_loc, self.n_loc, _ = len(self.X_ijk), self.X_ijk[0].shape[0], self.X_ijk[0].shape[1]
        self.init = self.params.init if self.params.init else 'rand'
        self.p_r, self.p_c, self.k = self.params.p_r, self.params.p_c, self.params.k
        self.p = self.p_r * self.p_c
        self.comm1 = self.params.comm1
        self.cart_1d_row = self.params.row_comm
        self.cart_1d_column = self.params.col_comm
        self.comm = self.params.comm
        # Only an unset (None) flag falls back to True; an explicit False is
        # honoured. The previous `x if x else True` form could never be False.
        verbose = self.params.verbose
        self.verbose = True if verbose is None else verbose
        self.perturbation = self.params.perturbation if self.params.perturbation else 0
        self.rank = self.comm1.rank
        # Machine epsilon of the data dtype; used as the clipping floor in fit().
        self.eps = self.np.finfo(X_ijk[0].dtype).eps
        self.params.eps = self.eps
        self.norm = var_init(self.params, 'norm', default='fb')
        self.method = var_init(self.params, 'method', default='mu')
        self.save_factors = save_factors
        self.params.itr = var_init(self.params, 'itr', default=5000)
        self.seed = None
        self.itr = self.params.itr
        self.A_update = var_init(self.params, 'A_update', default=True)
        self.symmetric = var_init(self.params, 'symmetric', default=False)
        if self.p_r == self.p_c:
            self.topo = '2d'
        else:
            self.topo = '1d'
            raise Exception('Current implementation only supports p_r==p_c')
        self.compute_global_dim()
        if factors is not None:
            # Use the slice dtype: X_ijk itself may be a plain list of arrays
            # and therefore has no `.dtype` attribute.
            dtype = self.X_ijk[0].dtype
            self.A_i = factors[0].astype(dtype)
            self.A_j = self.cart_1d_row.bcast(self.A_i, root=self.cart_1d_column.Get_rank())
            self.cart_1d_row.barrier()
            self.R_ijk = factors[1].astype(dtype)
        else:
            self.init_factors()

    @comm_timing()
    def compute_global_dim(self):
        """Computes the global dimensions m and n from the local chunk sizes.

        Stores the results in ``params.globalm``/``params.m`` and
        ``params.globaln``/``params.n``.
        """
        self.loc_m = len(self.X_ijk)
        # NOTE(review): the original tuple assignment bound `loc_n` twice, so
        # the value that survived was shape[1]; that behaviour is kept here.
        # Confirm whether shape[0] was intended for non-square local slices.
        self.loc_n = self.X_ijk[0].shape[1]
        # Only processes whose row-communicator rank is a multiple of p_r
        # contribute their local extent to the global sum.
        if self.cart_1d_row.Get_rank() % self.p_r == 0:
            contribution = self.loc_n
        else:
            contribution = 0
        self.params.globaln = self.comm1.allreduce(contribution)
        self.params.globalm = self.loc_m
        self.params.n = self.params.globaln
        self.params.m = self.params.globalm

    @comm_timing()
    def init_factors(self):
        """Initializes the RESCAL factors A_i, A_j and R_ijk with random values.

        Only uniform random initialisation is implemented here; ``self.init``
        is not consulted. Seeds are derived from the perturbation index and
        the row/column communicator ranks, so two processes with the same
        rank value generate identical ``A`` blocks.
        """
        dtype = self.X_ijk[0].dtype
        base_seed = self.perturbation * 9999
        row_rank = self.cart_1d_row.Get_rank()
        col_rank = self.cart_1d_column.Get_rank()
        grid_ranks = range(self.p_r)

        # Different seed per perturbation and row.
        if row_rank in grid_ranks:
            self.seed = base_seed + row_rank * 1000 + row_rank
            self._set_seed()
            self.A_i = self.np.random.rand(self.n_loc, self.k).astype(dtype)

        # Different seed per perturbation and column.
        if col_rank in grid_ranks:
            self.seed = base_seed + col_rank * 1000 + col_rank
            self._set_seed()
            self.A_j = self.np.random.rand(self.n_loc, self.k).astype(dtype)

        # Different seed per perturbation only: R uses the same seed everywhere.
        self.seed = base_seed
        self._set_seed()
        self.R_ijk = self.np.stack(
            [self.np.random.rand(self.k, self.k).astype(dtype) for _ in range(self.m_loc)],
            axis=0,
        )
        if self.symmetric:
            # Make every k x k slice symmetric.
            self.R_ijk = self.R_ijk + self.R_ijk.transpose(0, 2, 1)

    def _set_seed(self):
        """Seeds the array module's global random generator with ``self.seed``."""
        self.np.random.seed(self.seed)

    @count_memory()
    @comm_timing()
    def fit(self):
        r"""
        Runs ``self.itr`` update iterations of the distributed RESCAL
        decomposition and evaluates the relative reconstruction error.

        Every tenth iteration (starting with the first) all factors are
        clipped from below at ``self.eps``. After the final iteration the
        relative error is computed, optionally printed on rank 0 and the
        factors are optionally saved. ``self.comm`` is freed before returning,
        so the communicator cannot be reused afterwards.

        Returns
        -------
        A_i : ndarray
            Local factor A for the row block
        A_j : ndarray
            Local factor A for the column block
        R_ijk : ndarray
            Local core slices, one k x k matrix per slice
        recon_err : float
            Relative reconstruction error
        """
        for i in range(self.itr):
            self.A_i, self.A_j, self.R_ijk = rescal_algorithms_2D(
                self.X_ijk, self.A_i, self.A_j, self.R_ijk, params=self.params
            ).update()
            if i % 10 == 0:
                self.R_ijk = self.np.maximum(self.R_ijk, self.eps)
                self.A_i = self.np.maximum(self.A_i, self.eps)
                self.A_j = self.np.maximum(self.A_j, self.eps)

        # Finalisation runs after the loop rather than inside its last
        # iteration, so `recon_err` is defined even when `itr` is 0.
        self.relative_err()
        if self.verbose and self.rank == 0:
            print('relative error is:', self.recon_err)
        if self.save_factors:
            data_write(self.params).save_factors([self.A_i, self.R_ijk])
        self.comm.Free()
        return self.A_i, self.A_j, self.R_ijk, self.recon_err

    @comm_timing()
    def normalize_features(self, Wall, Wall1, Hall):
        """Normalizes the columns of Wall and Wall1 and rescales Hall.

        ``Wall`` and ``Wall1`` are divided in place by their column sums
        (each local sum is offset by ``self.eps`` and then summed over
        ``comm1``); ``Hall`` is multiplied in place by the transposed column
        sums of ``Wall``. The three modified arrays are returned.

        NOTE(review): ``MPI`` is not imported explicitly in this module; it is
        presumably provided by one of the star imports - confirm.
        """
        Wall_norm = Wall.sum(axis=0, keepdims=True) + self.eps
        Wall1_norm = Wall1.sum(axis=0, keepdims=True) + self.eps
        Wall_norm = self.comm1.allreduce(Wall_norm, op=MPI.SUM)
        Wall1_norm = self.comm1.allreduce(Wall1_norm, op=MPI.SUM)
        Wall /= Wall_norm
        Wall1 /= Wall1_norm
        Hall *= Wall_norm.T
        return Wall, Wall1, Hall

    @comm_timing()
    def relative_err(self):
        """Computes the relative reconstruction error of the decomposition.

        Sets ``self.glob_norm_err`` (norm of the residual
        ``X_ijk[i] - A_i @ R_ijk[i] @ A_j.T`` over all local slices),
        ``self.glob_norm_X`` (norm of the data) and ``self.recon_err``
        (their ratio).
        """
        residual = [
            self.X_ijk[i] - self.A_i @ self.R_ijk[i] @ self.A_j.T
            for i in range(self.m_loc)
        ]
        self.glob_norm_err = self.dist_norm(residual, proc=self.p)
        self.glob_norm_X = self.dist_norm(self.X_ijk)
        self.recon_err = self.glob_norm_err / self.glob_norm_X

    @comm_timing()
    def dist_norm(self, X, proc=-1, norm='fro', axis=None):
        """Computes the distributed norm of X.

        Parameters
        ----------
        X : array_like
            Local data whose norm is taken with ``linalg.norm`` defaults.
        proc : int, optional
            Number of processes. When it equals 1 the local norm is returned
            directly; otherwise the squared local norms are summed over
            ``comm1`` and the square root of the sum is returned.
        norm, axis : optional
            Accepted for interface compatibility; currently unused.

        Returns
        -------
        float
            The global norm.
        """
        local_norm = self.np.linalg.norm(X)
        if proc == 1:
            # Single process: the local norm already is the global norm.
            # Taking the square root here would return sqrt(||X||).
            return local_norm
        return self.np.sqrt(self.comm1.allreduce(local_norm ** 2))