Source code for pyCP_APR.numpy_backend.tt_dimscheck

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Python implementation of tt_dimscheck utility with Numpy backend from the MATLAB Tensor Toolbox [1].

References
========================================
[1] General software, latest release: Brett W. Bader, Tamara G. Kolda and others, Tensor Toolbox for MATLAB, Version 3.2.1, www.tensortoolbox.org, April 5, 2021.\n
"""
import numpy as np

[docs]def tt_dimscheck(dims, N, M): """ Processes tensor dimensions. Parameters ---------- dims : list or int Dimension indices. N : int tensor order. M : int Multiplicants Returns ------- sdims : list index for M muliplicands vidx : list index for M muliplicands """ if isinstance(dims, list) or isinstance(dims, np.ndarray): if len(dims) > 0: if (max(dims) < 0): dims = np.setdiff1d(np.arange(0, N), -dims) else: dims = np.arange(0, N) else: if dims < 0 or dims == 0: dims = np.setdiff1d(np.arange(0, N), -dims) # Save the number of dimensions in dims P = len(dims) # Reorder dims from smallest to largest (this matters in particular # for the vector multiplicand case, where the order affects the result) sidx = np.argsort(dims) sdims = np.array(dims)[sidx] # Check sizes to determine how to index multiplicands if P == M: # Case 1: Number of items in dims and number of multiplicands # are equal; therefore, index in order of how sdims was sorted. vidx = sidx else: # Case 2: Number of multiplicands is equal to the number of # dimensions in the tensor; therefore, index multiplicands by # dimensions specified in dims argument. vidx = sdims return sdims, vidx