# Source code for pyCP_APR.numpy_backend.ttm_tensor

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Python implementation of ttm utility with Numpy backend from the MATLAB Tensor Toolbox [1].

References
========================================
[1] General software, latest release: Brett W. Bader, Tamara G. Kolda and others, Tensor Toolbox for MATLAB, Version 3.2.1, www.tensortoolbox.org, April 5, 2021.\n
"""
import numpy as np
import copy

from . tt_dimscheck import tt_dimscheck
from . tensor import TENSOR
from . ipermute_tensor import ipermute

def ttm(X, V, varargin=None):
    """
    Tensor times matrix operation (n-mode product).

    Parameters
    ----------
    X : object
        Dense tensor (``TENSOR``). Only ``X.Dimensions``, ``X.Size`` and
        ``X.data`` are read; ``X`` itself is not modified.
    V : np.ndarray or list of np.ndarray
        Matrix to multiply along mode ``n``. When a list is given, the
        matrices are applied one after another to the modes selected by
        ``varargin['n']``, as resolved by ``tt_dimscheck``.
    varargin : dict, optional
        Optional settings, looked up by key (any subset may be given):

        - ``'n'``: mode index (int) when ``V`` is a single matrix, or the
          mode(s) to use when ``V`` is a list. Defaults to all modes,
          ``np.arange(0, X.Dimensions)``, which is only usable with a list.
        - ``'tflag'``: ``'t'`` multiplies by ``V.T`` instead of ``V``.
          Defaults to ``''``.
        - ``'ver'``: algorithm version; only ``0`` is implemented.
          Defaults to ``0``. It is not forwarded to the per-matrix products
          when ``V`` is a list.

    Returns
    -------
    Y : object
        Dense tensor holding the product. It is built with ``TENSOR`` from the
        product in permuted mode order. ``ipermute`` is then applied with the
        same permutation, which is assumed to restore the original mode order,
        as MATLAB's ``ipermute`` does.

    Raises
    ------
    ValueError
        If ``V`` is a list but no modes are selected, or if ``V`` is a single
        matrix and ``n`` is not a scalar mode index.
    NotImplementedError
        If ``ver`` is not ``0`` for a single matrix.
    """
    # Avoid a shared mutable default; no options means "all defaults".
    if varargin is None:
        varargin = {}

    # Look options up by key rather than by how many were passed, so any
    # subset of {'n', 'tflag', 'ver'} is accepted.
    n = varargin.get('n', np.arange(0, X.Dimensions))
    tflag = varargin.get('tflag', '')
    ver = varargin.get('ver', 0)

    #
    # Handle a list of matrices: apply them one at a time.
    #
    if isinstance(V, list):
        dims, vidx = tt_dimscheck(n, X.Dimensions, len(V))
        if len(dims) == 0:
            raise ValueError("ttm: no modes selected for the matrices in V")
        # No deepcopy needed: the single-mode product below only reads its
        # input and always returns a freshly built tensor.
        Y = X
        for k, dim in enumerate(dims):
            Y = ttm(Y, V[vidx[k]], varargin={'n': dim, 'tflag': tflag})
        return Y

    #
    # COMPUTE SINGLE N-MODE PRODUCT
    #
    if ver != 0:
        raise NotImplementedError("Not yet implemented!")

    if np.ndim(n) != 0:
        raise ValueError(
            "ttm: a single matrix V requires a scalar mode index 'n'")
    n = int(n)

    N = X.Dimensions
    # Work on a plain list so slicing/concatenation below is list
    # concatenation regardless of how X.Size is stored.
    sz = list(X.Size)
    rest = sz[0:n] + sz[n + 1:N]

    # Bring mode n to the front, then unfold to a (sz[n], prod(rest)) matrix.
    order = [n] + list(range(0, n)) + list(range(n + 1, N))
    # int() is needed: np.prod([]) is the float 1.0 (1-D tensors), which
    # np.reshape rejects as a dimension.
    unfolded = np.reshape(np.transpose(X.data, order),
                          (sz[n], int(np.prod(rest))))

    if tflag == 't':
        product = np.dot(V.T, unfolded)
        p = V.shape[1]
    else:
        product = np.dot(V, unfolded)
        p = V.shape[0]

    # Fold back with mode n now of size p, then undo the mode permutation.
    Y = TENSOR(np.reshape(product, [p] + rest))
    return ipermute(Y, order)