Package torchtt

Provides the Tensor-Train (TT) decomposition using pytorch as backend.

Contains routines for computing the TT decomposition and all the basic linear algebra in the TT format. Thanks to the pytorch backend, the computations can also run on the GPU.
It also provides linear solvers in the TT format, cross approximation and automatic differentiation.

What is the Tensor-Train format?

The Tensor-Train (TT) format is a low-rank tensor decomposition format used to fight the curse of dimensionality. A d-dimensional tensor \mathsf{x} \in \mathbb{R} ^{n_1 \times n_2 \times \cdots \times n_d} can be expressed using algebraic operations between d smaller tensors:

\mathsf{x}_{i_1i_2...i_d} = \sum\limits_{s_0=1}^{r_0} \sum\limits_{s_1=1}^{r_1} \cdots \sum\limits_{s_{d-1}=1}^{r_{d-1}} \sum\limits_{s_d=1}^{r_d} \mathsf{g}^{(1)}_{s_0 i_1 s_1} \cdots \mathsf{g}^{(d)}_{s_{d-1} i_d s_d}, where \mathbf{r} = (r_0,r_1,...,r_d) with r_0 = r_d = 1 is the TT rank and \mathsf{g}^{(k)} \in \mathbb{R}^{r_{k-1} \times n_k \times r_k} are the TT cores. If the ranks remain bounded, the storage complexity is \mathcal{O}(dnr^2) instead of \mathcal{O}(n^d), where n = \max_k n_k and r = \max_k r_k. Tensor operators \mathsf{A} \in \mathbb{R} ^{(m_1 \times m_2 \times \cdots \times m_d) \times (n_1 \times n_2 \times \cdots \times n_d)} can be similarly expressed in the TT format as:

\mathsf{A}_{i_1i_2...i_d,j_1j_2...j_d} = \sum\limits_{s_0=1}^{r_0} \sum\limits_{s_1=1}^{r_1} \cdots \sum\limits_{s_{d-1}=1}^{r_{d-1}} \sum\limits_{s_d=1}^{r_d} \mathsf{h}^{(1)}_{s_0 i_1 j_1 s_1} \cdots \mathsf{h}^{(d)}_{s_{d-1} i_d j_d s_d}, where i_k = 1,...,m_k, \: j_k = 1,...,n_k, \; k = 1,...,d, and \mathsf{h}^{(k)} \in \mathbb{R}^{r_{k-1} \times m_k \times n_k \times r_k} are the TT cores of the operator. Tensor operators (also called tensor matrices in this library) generalize the concept of matrix-vector product to the multilinear case.
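The two formulas can be made concrete with a minimal plain-torch sketch. It does not use torchtt, and the helper names tt_full and ttm_matvec are hypothetical. The sketch contracts TT cores into the full tensor and applies a TT operator to a TT tensor core by core. Because the sum over j_k can be carried out inside each core, the ranks of the product are the products of the ranks of the two factors.

```python
import torch


def tt_full(cores):
    """Contract TT cores g^(k) of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = cores[0]  # shape (1, n_1, r_1)
    for core in cores[1:]:
        # sum over the rank index shared by the current result and the next core
        full = torch.tensordot(full, core, dims=([full.dim() - 1], [0]))
    # drop the boundary ranks r_0 = r_d = 1
    return full.squeeze(0).squeeze(-1)


def ttm_matvec(A_cores, x_cores):
    """Apply a TT operator (cores of shape (r_{k-1}, m_k, n_k, r_k)) to a TT tensor
    (cores of shape (q_{k-1}, n_k, q_k)) without forming the full operator."""
    y_cores = []
    for h, g in zip(A_cores, x_cores):
        # contract the column index j_k; the left and right ranks pair up
        c = torch.einsum('amnb,cnd->acmbd', h, g)
        ra, rc, m, rb, rd = c.shape
        y_cores.append(c.reshape(ra * rc, m, rb * rd))
    return y_cores


if __name__ == "__main__":
    torch.manual_seed(0)
    x = [torch.randn(1, 4, 3), torch.randn(3, 6, 1)]
    A = [torch.randn(1, 3, 4, 2), torch.randn(2, 5, 6, 1)]
    y = ttm_matvec(A, x)
    print([tuple(c.shape) for c in y])  # [(1, 3, 6), (6, 5, 1)]
```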

To create a TT object one can simply provide a tensor or the representation in terms of TT-cores. In the first case, the relative accuracy can also be provided such that || \mathsf{x} - \mathsf{y} ||_F^2 < \epsilon || \mathsf{x}||_F^2, where y is the TT tensor returned by the decomposition. In code, this translates to

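import torch
import torchtt

# tens is a full torch.Tensor, here a random 4-dimensional example
tens = torch.randn(4, 5, 6, 7, dtype=torch.float64)

# TT decomposition with relative accuracy 1e-10
tt = torchtt.TT(tens, 1e-10)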
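The following minimal plain-torch sketch shows what the accuracy argument controls. It implements the classical TT-SVD algorithm, which performs sequential truncated SVDs of the unfoldings. This is an illustration under assumptions and not torchtt's implementation. The name tt_svd is hypothetical. The sketch targets the standard bound ||x - y||_F <= eps ||x||_F by truncating each of the d-1 SVDs at eps ||x||_F / sqrt(d-1). The exact criterion torchtt uses is not spelled out on this page.

```python
import torch


def tt_svd(x, eps=1e-10):
    """Sketch of TT-SVD: returns cores of shape (r_{k-1}, n_k, r_k)."""
    shape = x.shape
    d = len(shape)
    # per-SVD truncation threshold; the d-1 errors add up in quadrature
    delta = eps * torch.linalg.norm(x) / max(d - 1, 1) ** 0.5
    cores = []
    r_prev = 1
    rest = x.reshape(1, -1)
    for k in range(d - 1):
        # unfolding: rows combine the previous rank and the current mode
        rest = rest.reshape(r_prev * shape[k], -1)
        U, S, Vh = torch.linalg.svd(rest, full_matrices=False)
        # tail[i] = norm of S[i:], i.e. the error made by keeping rank i
        tail = torch.sqrt(torch.flip(torch.cumsum(torch.flip(S**2, [0]), 0), [0]))
        # smallest rank whose discarded tail stays below delta (at least 1)
        r = max(int((tail > delta).sum()), 1)
        cores.append(U[:, :r].reshape(r_prev, shape[k], r))
        rest = S[:r, None] * Vh[:r]
        r_prev = r
    cores.append(rest.reshape(r_prev, shape[-1], 1))
    return cores


if __name__ == "__main__":
    torch.manual_seed(0)
    # a sum of two rank-one terms has all interior TT ranks equal to 2
    u = [torch.randn(n, dtype=torch.float64) for n in (4, 5, 6)]
    v = [torch.randn(n, dtype=torch.float64) for n in (4, 5, 6)]
    x = torch.einsum('i,j,k->ijk', *u) + torch.einsum('i,j,k->ijk', *v)
    cores = tt_svd(x, 1e-10)
    print([c.shape[0] for c in cores] + [1])  # [1, 2, 2, 1]
```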

The rank of the object tt can be inspected using the print() function or accessed using tt.R. The tensor can be converted back to the full format using tt.full(). The TT class implements both tensors and tensor operators in the TT format. Once in the TT format, linear algebra operations (+, -, *, @, /) can be performed without resorting to the full format. The syntax of these operations is similar to that of torch. As an example, the following code applies algebraic operations to 3 tensors in the TT format:

import torchtt
import torch

# generate 2 random TT tensors and a random TT matrix
a = torchtt.randn([4,5,6,7],[1,2,3,4,1])
b = torchtt.randn([8,4,6,4],[1,2,5,2,1])
A = torchtt.randn([(4,8), (5,4), (6,6), (7,4)],[1,2,3,2,1])

x = a * ( A @ b )
x = x.round(1e-12)
y = x-2*a

# this is equivalent to the following computation with full tensors
yf = a.full()*torch.einsum('ijklabcd,abcd->ijkl', A.full(), b.full()) - 2*a.full()
# the error should be small
print(torch.linalg.norm(y.full()-yf))

The example above uses the round() function, which further compresses a tensor by reducing its rank. Successive linear algebra operations make the rank overshoot. For instance, the ranks add up under + and multiply under *. Rounding operations are therefore required to keep the representation compact.
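The following plain-torch sketch shows why the rank overshoots. The helper names are hypothetical, and the code is independent of torchtt's own implementation. It forms the sum of two TT tensors by stacking their cores in block-diagonal form. The interior ranks of x + y are the sums of the ranks of x and y. This holds even when a much smaller rank would suffice, as for x + x. That redundancy is what rounding removes.

```python
import torch


def tt_full(cores):
    """Contract TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = cores[0]
    for core in cores[1:]:
        full = torch.tensordot(full, core, dims=([full.dim() - 1], [0]))
    return full.squeeze(0).squeeze(-1)


def tt_add(x_cores, y_cores):
    """Cores of x + y: first core concatenated along the right rank,
    last core along the left rank, interior cores block-diagonal."""
    d = len(x_cores)
    if d == 1:
        return [x_cores[0] + y_cores[0]]
    out = []
    for k, (a, b) in enumerate(zip(x_cores, y_cores)):
        if k == 0:
            out.append(torch.cat([a, b], dim=2))
        elif k == d - 1:
            out.append(torch.cat([a, b], dim=0))
        else:
            ra0, n, ra1 = a.shape
            rb0, _, rb1 = b.shape
            c = torch.zeros(ra0 + rb0, n, ra1 + rb1, dtype=a.dtype, device=a.device)
            c[:ra0, :, :ra1] = a
            c[ra0:, :, ra1:] = b
            out.append(c)
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    x = [torch.randn(1, 4, 2, dtype=torch.float64),
         torch.randn(2, 5, 3, dtype=torch.float64),
         torch.randn(3, 6, 1, dtype=torch.float64)]
    s = tt_add(x, x)
    print([c.shape[0] for c in s] + [1])  # [1, 4, 6, 1]: the ranks doubled
    print(torch.allclose(tt_full(s), 2 * tt_full(x)))
```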

About The Package

The class TT is used to create tensors in the TT format. Passing a torch.Tensor to the constructor computes a TT decomposition. The accuracy eps can be provided as an additional argument. In order to recover the original tensor (also called full tensor), the TT.full() method can be used. Tensors can be further compressed using the TT.round() method.

Once in the TT format, linear algebra operations can be performed between compressed tensors without going to the full format. The implemented operations are:

  • Sum and difference between TT objects. Two TT instances can be summed using the + operator. The difference can be implemented using the - operator.
  • Elementwise product (also called Hadamard product) is performed using the * operator. The same operator also implements multiplication by a scalar.
  • The operator @ implements the generalization of the matrix product. It can also be used between a tensor operator and a tensor.
  • The operator / implements the elementwise division of two TT objects. It is computed with the AMEn method.
  • The operator ** implements the Kronecker product.
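The elementwise product has a stronger effect on the ranks than the sum. In the plain-torch sketch below, each core of x * y is the Kronecker product of the corresponding cores over the rank indices, so the ranks multiply. The helper names are hypothetical, and this is not the torchtt code. The rank growth is why a round() after * is usually worthwhile.

```python
import torch


def tt_full(cores):
    """Contract TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = cores[0]
    for core in cores[1:]:
        full = torch.tensordot(full, core, dims=([full.dim() - 1], [0]))
    return full.squeeze(0).squeeze(-1)


def tt_hadamard(x_cores, y_cores):
    """Cores of the elementwise product x * y (identical mode sizes)."""
    out = []
    for a, b in zip(x_cores, y_cores):
        # same mode index i for both factors, rank indices paired up
        c = torch.einsum('aib,cid->acibd', a, b)
        ra, rc, n, rb, rd = c.shape
        out.append(c.reshape(ra * rc, n, rb * rd))
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    x = [torch.randn(1, 3, 2, dtype=torch.float64),
         torch.randn(2, 4, 3, dtype=torch.float64),
         torch.randn(3, 5, 1, dtype=torch.float64)]
    p = tt_hadamard(x, x)
    print([c.shape[0] for c in p] + [1])  # [1, 4, 9, 1]: the ranks multiply
```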

The package also includes more features such as solving multilinear systems, cross approximation and automatic differentiation (with the possibility to define TT layers for neural networks). Working examples that can be used as a tutorial can be found in examples/.

Utilities

  • Example scripts (and ipy notebooks) can be found in the examples/ folder.
  • Tests can be found in the tests/ folder.
r"""
Provides Tensor-Train (TT) decomposition using `pytorch` as backend.

Contains routines for computing the TT decomposition and all the basic linear algebra in the TT format. Additionally, GPU support can be used thanks to the `pytorch` backend.   
It also has linear solvers in TT and cross approximation as well as automatic differentiation.

.. include:: INTRO.md 

"""

from ._tt_base import TT 
from ._extras import eye, zeros, kron, ones, random, randn, reshape, meshgrid , dot, elementwise_divide, numel, rank1TT, bilinear_form, diag, permute, load, save, cat, pad, shape_mn_to_tuple, shape_tuple_to_mn 
# from .torchtt import TT, eye, zeros, kron, ones, random, randn, reshape, meshgrid , dot, elementwise_divide, numel, rank1TT, bilinear_form, diag, permute, load, save, cat, pad 

__all__ = ['TT', 'eye', 'zeros', 'kron', 'ones', 'random', 'randn', 'reshape', 'meshgrid', 'dot', 'elementwise_divide', 'numel', 'rank1TT', 'bilinear_form', 'diag', 'permute', 'load', 'save', 'cat', 'pad', 'shape_mn_to_tuple', 'shape_tuple_to_mn']

from . import solvers
from . import grad
# from .grad import grad, watch, unwatch
from . import manifold
from . import interpolate
from . import nn
from . import cpp
# from .errors import *

Sub-modules

torchtt.cpp

Module for the C++ backend.

torchtt.errors

Contains the errors used in the torchtt package.

torchtt.grad

Adds AD functionality to torchtt.

torchtt.interpolate

Implements the cross approximation methods (DMRG).

torchtt.manifold

Manifold gradient module.

torchtt.nn

Implements a basic TT layer for constructing deep TT networks.

torchtt.solvers

System solvers in the TT format.

Functions

def bilinear_form(x, A, y)

Computes the bilinear form x^T A y for TT tensors.

Args

x : TT
the first tensor (left vector).
A : TT
the operator (must be a TT matrix).
y : TT
the second tensor (right vector).

Raises

InvalidArguments
Inputs must be torchtt.TT instances.
IncompatibleTypes
x and y must be TT tensors and A must be TT matrix.
ShapeMismatch
Check the shapes. Required is x.N == A.M and y.N == A.N.

Returns

torch.tensor
the result of the bilinear form as a tensor with 1 element.
def bilinear_form(x,A,y):
    """
    Computes the bilinear form x^T A y for TT tensors.

    Args:
        x (torchtt.TT): the first tensor (left vector).
        A (torchtt.TT): the operator (must be a TT matrix).
        y (torchtt.TT): the second tensor (right vector).

    Raises:
        InvalidArguments: Inputs must be torchtt.TT instances.
        IncompatibleTypes: x and y must be TT tensors and A must be TT matrix.
        ShapeMismatch: Check the shapes. Required is x.N == A.M and y.N == A.N.

    Returns:
        torch.tensor: the result of the bilinear form as a tensor with 1 element.
    """
    if not isinstance(x,TT) or not isinstance(A,TT) or not isinstance(y,TT):
        raise InvalidArguments("Inputs must be torchtt.TT instances.")
    if x.is_ttm or y.is_ttm or not A.is_ttm:
        raise IncompatibleTypes("x and y must be TT tensors and A must be TT matrix.")
    if x.N != A.M or y.N != A.N:
        raise ShapeMismatch("Check the shapes. Required is x.N == A.M and y.N == A.N.")
    d = len(x.N)
    return bilinear_form_aux(x.cores,A.cores,y.cores,d)
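The helper bilinear_form_aux is not shown on this page. One standard way to evaluate x^T A y without forming full tensors is a left-to-right sweep that keeps a small three-index interface tensor. The plain-torch sketch below illustrates that idea, assuming real-valued cores. The name bilinear_form_sketch is hypothetical. Its cost grows linearly with d.

```python
import torch


def bilinear_form_sketch(x_cores, A_cores, y_cores):
    """x^T A y for TT cores: x (p, m, p'), A (r, m, n, r'), y (q, n, q')."""
    # phi[a, b, c] couples the current ranks of x, A and y
    phi = torch.ones(1, 1, 1, dtype=x_cores[0].dtype, device=x_cores[0].device)
    for g, h, f in zip(x_cores, A_cores, y_cores):
        # sum over the left ranks, the row index i (shared by x and A)
        # and the column index j (shared by A and y)
        phi = torch.einsum('abc,aid,bije,cjf->def', phi, g, h, f)
    return phi.reshape(())


if __name__ == "__main__":
    torch.manual_seed(0)
    x = [torch.randn(1, 3, 2), torch.randn(2, 5, 1)]
    A = [torch.randn(1, 3, 4, 2), torch.randn(2, 5, 6, 1)]
    y = [torch.randn(1, 4, 3), torch.randn(3, 6, 1)]
    print(bilinear_form_sketch(x, A, y))
```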
def cat(tensors, dim=0)

Concatenate tensors in the TT format along a given dimension dim. Only works for TT tensors and not TT matrices.

Examples

import torchtt 
import torch 


a1 = torchtt.randn((3,4,2,6,7), [1,2,3,4,2,1])
a2 = torchtt.randn((3,4,8,6,7), [1,3,1,7,5,1])
a3 = torchtt.randn((3,4,15,6,7), [1,3,10,2,4,1])

a = torchtt.cat((a1,a2,a3),2)

af = torch.cat((a1.full(), a2.full(), a3.full()), 2)
print(torch.linalg.norm(a.full()-af))

Args

tensors : tuple[TT]
the tensors to be concatenated. Their mode sizes must match for all modes except the concatenating dimension.
dim : int, optional
The dimension along which the tensors are concatenated. Defaults to 0.

Raises

InvalidArguments
Not implemented for tensor matrices.
InvalidArguments
The mode sizes must be the same on the nonconcatenated dimensions for all the provided tensors.
InvalidArguments
The tensors must have the same number of dimensions.

Returns

TT
the result.
def cat(tensors, dim = 0):
    """
    Concatenate tensors in the TT format along a given dimension `dim`. Only works for TT tensors and not TT matrices.
    
    Examples:
        ```
        import torchtt 
        import torch 


        a1 = torchtt.randn((3,4,2,6,7), [1,2,3,4,2,1])
        a2 = torchtt.randn((3,4,8,6,7), [1,3,1,7,5,1])
        a3 = torchtt.randn((3,4,15,6,7), [1,3,10,2,4,1])

        a = torchtt.cat((a1,a2,a3),2)

        af = torch.cat((a1.full(), a2.full(), a3.full()), 2)
        print(torch.linalg.norm(a.full()-af))
        ```
        
    Args:
        tensors (tuple[TT]): the tensors to be concatenated. Their mode sizes must match for all modes except the concatenating dimension.
        dim (int, optional): The dimension along which the tensors are concatenated. Defaults to 0.

    Raises:
        InvalidArguments: Not implemented for tensor matrices.
        InvalidArguments: The mode sizes must be the same on the nonconcatenated dimensions for all the provided tensors.
        InvalidArguments: The tensors must have the same number of dimensions.

    Returns:
        torchtt.TT: the result.
    """

    if(len(tensors) == 0):
        return None 

    if tensors[0].is_ttm:
        raise InvalidArguments("Not implemented for tensor matrices.")
    Rs = [tensors[0].R] 

    for i in range(1, len(tensors)):
        if tensors[i].is_ttm:
            raise InvalidArguments("Not implemented for tensor matrices.")
        if tensors[i].N[:dim] != tensors[0].N[:dim] or tensors[i].N[(dim+1):] != tensors[0].N[(dim+1):]:
            raise InvalidArguments("The mode sizes must be the same on the nonconcatenated dimensions for all the provided tensors.")
        if len(tensors[i].N) != len(tensors[0].N):
            raise InvalidArguments("The tensors must have the same number of dimensions.")
        Rs.append(tensors[i].R)
    

    cores = []
    
    
    if tensors[0].is_ttm:
        pass
    else:
        
        r_sum = [1]
        for i in range(1,len(tensors[0].N)):
            r_sum.append(sum([Rs[k][i] for k in range(len(tensors))]))
        r_sum.append(1)
        for i in range(len(tensors[0].N)):
            if i == dim:
                n = sum([t.N[dim] for t in tensors])
                cores.append(tn.zeros((r_sum[i], n, r_sum[i+1]), device = tensors[0].cores[0].device, dtype = tensors[0].cores[0].dtype))
            else:
                cores.append(tn.zeros((r_sum[i], tensors[0].N[i], r_sum[i+1]), device = tensors[0].cores[0].device, dtype = tensors[0].cores[0].dtype))
                
            offset1 = 0
            offset2 = 0
            offset3 = 0
            
            for t in tensors:
                if i==dim:
                    cores[i][offset1:(offset1+t.cores[i].shape[0]),offset2:(offset2+t.cores[i].shape[1]),offset3:(offset3+t.cores[i].shape[2])] = t.cores[i]
                    if i>0: offset1 += t.cores[i].shape[0]
                    offset2 += t.cores[i].shape[1]
                    if i<len(tensors[0].N)-1: offset3 += t.cores[i].shape[2]
                else:
                    cores[i][offset1:(offset1+t.cores[i].shape[0]),:,offset3:(offset3+t.cores[i].shape[2])] = t.cores[i]
                    if i>0: offset1 += t.cores[i].shape[0]
                    if i<len(tensors[0].N)-1: offset3 += t.cores[i].shape[2]
        #for i in range(len(self.__N)):
        #    pad1 = (0,0 if i == len(self.__N)-1 else other.R[i+1] , 0,0 , 0,0 if i==0 else other.R[i])
        #    pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0 if i==0 else self.R[i],0)
        #    cores.append(tnf.pad(self.cores[i],pad1)+tnf.pad(other.cores[i],pad2))
    return TT(cores)
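For two tensors, the block structure built by the loop above can be written more compactly. The plain-torch sketch below places the cores of x and y in disjoint rank blocks. Along dim, it also places them in disjoint ranges of the mode index, so that torch.cat of the full tensors is reproduced. The name tt_cat2 is hypothetical, and the sketch handles TT tensors only, without input checks.

```python
import torch


def tt_full(cores):
    """Contract TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = cores[0]
    for core in cores[1:]:
        full = torch.tensordot(full, core, dims=([full.dim() - 1], [0]))
    return full.squeeze(0).squeeze(-1)


def tt_cat2(x_cores, y_cores, dim):
    """Cores of the concatenation of two TT tensors along mode `dim`."""
    d = len(x_cores)
    out = []
    for k, (a, b) in enumerate(zip(x_cores, y_cores)):
        ra0, na, ra1 = a.shape
        rb0, nb, rb1 = b.shape
        # boundary ranks stay 1, interior ranks add up
        r0 = 1 if k == 0 else ra0 + rb0
        r1 = 1 if k == d - 1 else ra1 + rb1
        n = na + nb if k == dim else na
        c = torch.zeros(r0, n, r1, dtype=a.dtype, device=a.device)
        # offsets of the y-block in the left rank, mode and right rank index
        o0 = 0 if k == 0 else ra0
        om = na if k == dim else 0
        o1 = 0 if k == d - 1 else ra1
        c[:ra0, :na, :ra1] = a
        c[o0:o0 + rb0, om:om + nb, o1:o1 + rb1] = b
        out.append(c)
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    x = [torch.randn(1, 3, 2, dtype=torch.float64), torch.randn(2, 4, 1, dtype=torch.float64)]
    y = [torch.randn(1, 3, 1, dtype=torch.float64), torch.randn(1, 6, 1, dtype=torch.float64)]
    print(tt_full(tt_cat2(x, y, 1)).shape)  # torch.Size([3, 10])
```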
def diag(input)

Creates diagonal TT matrix from TT tensor or extracts the diagonal of a TT matrix:

  • If a TT matrix is provided the result is a TT tensor representing the diagonal \mathsf{x}_{i_1...i_d} = \mathsf{A}_{i_1...i_d,i_1...i_d}

  • If a TT tensor is provided the result is a diagonal TT matrix with the entries \mathsf{A}_{i_1...i_d,j_1...j_d} = \mathsf{x}_{i_1...i_d} \delta_{i_1}^{j_1} \cdots \delta_{i_d}^{j_d}

Args

input : TT
the input.

Raises

InvalidArguments
Input must be a torchtt.TT instance.

Returns

TT
the result.
def diag(input):
    """
    Creates diagonal TT matrix from TT tensor or extracts the diagonal of a TT matrix:

    * If a TT matrix is provided the result is a TT tensor representing the diagonal \( \\mathsf{x}_{i_1...i_d} = \\mathsf{A}_{i_1...i_d,i_1...i_d} \)

    * If a TT tensor is provided the result is a diagonal TT matrix with the entries \( \\mathsf{A}_{i_1...i_d,j_1...j_d} = \\mathsf{x}_{i_1...i_d} \\delta_{i_1}^{j_1} \\cdots \\delta_{i_d}^{j_d} \)

    Args:
        input (TT): the input. 

    Raises:
        InvalidArguments: Input must be a torchtt.TT instance.

    Returns:
        torchtt.TT: the result.
    """

    if not isinstance(input, TT):
        raise InvalidArguments("Input must be a torchtt.TT instance.")

    if input.is_ttm:
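        # torch.diagonal appends the diagonal as the last dimension, so move it back between the rank indices
        return TT([tn.diagonal(c, dim1 = 1, dim2 = 2).permute(0, 2, 1).contiguous() for c in input.cores])
    else:
        return TT([tn.einsum('ijk,jm->ijmk',c,tn.eye(c.shape[1], dtype = c.dtype, device = c.device)) for c in input.cores])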
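Both branches act on each core separately, so the TT ranks are unchanged. The plain-torch sketch below checks the two core-wise maps on small examples. The helper names are hypothetical. Note that torch.diagonal appends the extracted diagonal as the last dimension, so the diagonal has to be moved back between the two rank indices.

```python
import torch


def tt_full(cores):
    """Contract TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = cores[0]
    for core in cores[1:]:
        full = torch.tensordot(full, core, dims=([full.dim() - 1], [0]))
    return full.squeeze(0).squeeze(-1)


def diag_from_tensor_cores(cores):
    """TT tensor cores (r, n, r') -> diagonal TT matrix cores (r, n, n, r')."""
    return [torch.einsum('ijk,jm->ijmk', c, torch.eye(c.shape[1], dtype=c.dtype, device=c.device))
            for c in cores]


def diagonal_of_matrix_cores(cores):
    """TT matrix cores (r, n, n, r') -> TT tensor cores (r, n, r') of the diagonal."""
    # torch.diagonal returns shape (r, r', n); permute to (r, n, r')
    return [torch.diagonal(c, dim1=1, dim2=2).permute(0, 2, 1) for c in cores]


if __name__ == "__main__":
    torch.manual_seed(0)
    x = [torch.randn(1, 3, 2, dtype=torch.float64), torch.randn(2, 4, 1, dtype=torch.float64)]
    D = diag_from_tensor_cores(x)
    print([tuple(c.shape) for c in D])  # [(1, 3, 3, 2), (2, 4, 4, 1)]
```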
def dot(a, b, axis=None)

Computes the dot product between 2 tensors in TT format. If both a and b have identical mode sizes, the result is the dot product. If a and b have unequal mode sizes, the function performs index contraction. The number of dimensions of a must be greater than or equal to that of b. The modes of the tensor a along which the index contraction with b is performed are given in axis. In the complex case, (a,b) = b^H a.

Examples

a = torchtt.randn([3,4,5,6,7],[1,2,2,2,2,1])
b = torchtt.randn([3,4,5,6,7],[1,2,2,2,2,1])
c = torchtt.randn([3,5,6],[1,2,2,1])
print(torchtt.dot(a,b))
print(torchtt.dot(a,c,[0,2,3]))

Args

a : TT
the first tensor.
b : TT
the second tensor.
axis : list[int], optional
the mode indices for index contraction. Defaults to None.

Raises

InvalidArguments
Both operands should be TT instances.
NotImplementedError
Operation not implemented for TT-matrices.
ShapeMismatch
Operands are not the same size.
ShapeMismatch
Number of the modes of the first tensor must be equal with the second.

Returns

float or TT
the result. If no axis index is provided the result is a scalar otherwise a torchtt.TT object.
def dot(a,b,axis=None):
    """
    Computes the dot product between 2 tensors in TT format.
    If both a and b have identical mode sizes, the result is the dot product.
    If a and b have unequal mode sizes, the function performs index contraction. 
    The number of dimensions of a must be greater than or equal to that of b.
    The modes of the tensor a along which the index contraction with b is performed are given in axis.
    In the complex case, (a,b) = b^H a.

    Examples:
        ```
        a = torchtt.randn([3,4,5,6,7],[1,2,2,2,2,1])
        b = torchtt.randn([3,4,5,6,7],[1,2,2,2,2,1])
        c = torchtt.randn([3,5,6],[1,2,2,1])
        print(torchtt.dot(a,b))
        print(torchtt.dot(a,c,[0,2,3]))
        ```

    Args:
        a (torchtt.TT): the first tensor.
        b (torchtt.TT): the second tensor.
        axis (list[int], optional): the mode indices for index contraction. Defaults to None.

    Raises:
        InvalidArguments: Both operands should be TT instances.
        NotImplementedError: Operation not implemented for TT-matrices.
        ShapeMismatch: Operands are not the same size.
        ShapeMismatch: Number of the modes of the first tensor must be equal with the second.

    Returns:
        float or torchtt.TT: the result. If no axis index is provided the result is a scalar otherwise a torchtt.TT object.
    """
    
    if not isinstance(a, TT) or not isinstance(b, TT):
        raise InvalidArguments('Both operands should be TT instances.')
    
    
    if axis is None:
        # treat first the full dot product
        # faster than partial projection
        if a.is_ttm or b.is_ttm:
            raise NotImplementedError('Operation not implemented for TT-matrices.')
        if a.N != b.N:
            raise ShapeMismatch('Operands are not the same size.')
        
        result = tn.tensor([[1.0]],dtype = a.cores[0].dtype, device=a.cores[0].device)
        
        for i in range(len(a.N)):
            result = tn.einsum('ab,aim,bin->mn',result, a.cores[i], tn.conj(b.cores[i]))
        result = tn.squeeze(result)
    else:
        # partial case
        if a.is_ttm or b.is_ttm:
            raise NotImplementedError('Operation not implemented for TT-matrices.')
        if len(a.N)<len(b.N):
            raise ShapeMismatch('Number of the modes of the first tensor must be equal with the second.')
        # if a.N[axis] != b.N:
        #     raise Exception('Dimension mismatch.')
        
        k = 0 # index for the tensor b
        cores_new = []
        rank_left = 1
        for i in range(len(a.N)):
            if i in axis:
                cores_new.append(tn.conj(b.cores[k]))
                rank_left = b.cores[k].shape[2]
                k+=1
            else:
                rank_right = b.cores[k].shape[0] if i+1 in axis else rank_left                
                cores_new.append(tn.conj(tn.einsum('ik,j->ijk',tn.eye(rank_left,rank_right,dtype=a.cores[0].dtype,device=a.cores[0].device),tn.ones([a.N[i]],dtype=a.cores[0].dtype,device=a.cores[0].device))))
        
        result = (a*TT(cores_new)).sum(axis)
    return result
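The full dot product branch above sweeps once over the cores. It carries an r_a x r_b matrix instead of forming the full tensors, so the cost is roughly O(n r^3) per core. The plain-torch sketch below repeats that sweep with the same einsum and checks it against the dense definition, including the conjugation of b in the complex case. The name tt_dot_sketch is hypothetical.

```python
import torch


def tt_dot_sketch(a_cores, b_cores):
    """(a, b) = sum_i a_i * conj(b_i) for TT tensors with identical mode sizes."""
    result = torch.ones(1, 1, dtype=a_cores[0].dtype, device=a_cores[0].device)
    for ga, gb in zip(a_cores, b_cores):
        # absorb one core of each tensor; result keeps shape (r_a, r_b)
        result = torch.einsum('ab,aim,bin->mn', result, ga, torch.conj(gb))
    return result.reshape(())


if __name__ == "__main__":
    torch.manual_seed(0)
    a = [torch.randn(1, 3, 2, dtype=torch.float64), torch.randn(2, 4, 1, dtype=torch.float64)]
    print(tt_dot_sketch(a, a))  # squared Frobenius norm of a
```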
def elementwise_divide(x, y, eps=1e-12, starting_tensor=None, nswp=50, kick=4, local_iterations=40, resets=2, preconditioner=None, verbose=False)

Perform the elementwise division x/y of two tensors in the TT format using the AMEn method. Use this function instead of the / operator when non-default AMEn arguments are needed. This function does not check the validity of the inputs.

Args

x : TT or scalar
first tensor (can also be scalar of type float, int, torch.tensor with shape (1)).
y : TT
second tensor.
eps : float, optional
relative accuracy. Defaults to 1e-12.
starting_tensor : TT or None, optional
initial guess of the result (None for random initial guess). Defaults to None.
nswp : int, optional
number of iterations. Defaults to 50.
kick : int, optional
size of rank enrichment. Defaults to 4.
local_iterations : int, optional
the number of iterations for the local iterative solver. Defaults to 40.
resets : int, optional
the number of restarts in the GMRES solver. Defaults to 2.
preconditioner : string, optional
Use a preconditioner for the local solver (possible values None, 'c'). Defaults to None.
verbose : bool, optional
display debug info. Defaults to False.

Returns

TT
the result
def elementwise_divide(x, y, eps = 1e-12, starting_tensor = None, nswp = 50, kick = 4, local_iterations = 40, resets = 2, preconditioner = None, verbose = False):
    """
    Perform the elementwise division x/y of two tensors in the TT format using the AMEn method.
    Use this function instead of the / operator when non-default AMEn arguments are needed.
    This method does not check the validity of the inputs.
    
    Args:
        x (torchtt.TT or scalar): first tensor (can also be scalar of type float, int, torch.tensor with shape (1)).
        y (torchtt.TT): second tensor.
        eps (float, optional): relative accuracy. Defaults to 1e-12.
        starting_tensor (torchtt.TT or None, optional): initial guess of the result (None for random initial guess). Defaults to None.
        nswp (int, optional): number of iterations. Defaults to 50.
        kick (int, optional): size of rank enrichment. Defaults to 4.
        local_iterations (int, optional): the number of iterations for the local iterative solver. Defaults to 40.
        resets (int, optional): the number of restarts in the GMRES solver. Defaults to 2.
        preconditioner (string, optional): Use a preconditioner for the local solver (possible values None, 'c'). Defaults to None. 
        verbose (bool, optional): display debug info. Defaults to False.

    Returns:
        torchtt.TT: the result
    """

    cores_new = amen_divide(y,x,nswp,starting_tensor,eps,rmax = 1000, kickrank = kick, local_iterations = local_iterations, resets = resets, verbose=verbose, preconditioner = preconditioner)
    return TT(cores_new)
def eye(shape, dtype=torch.float64, device=None)

Construct the TT decomposition of a multidimensional identity matrix. All the TT ranks are 1.

Args

shape : list[int]
the shape.
dtype : torch.dtype, optional
the dtype of the returned tensor. Defaults to tn.float64.
device : torch.device, optional
the device where the TT cores are created (None means CPU). Defaults to None.

Returns

TT
the identity TT matrix.
def eye(shape, dtype=tn.float64, device = None):
    """
    Construct the TT decomposition of a multidimensional identity matrix.
    All the TT ranks are 1.

    Args:
        shape (list[int]): the shape.
        dtype (torch.dtype, optional): the dtype of the returned tensor. Defaults to tn.float64.
        device (torch.device, optional): the device where the TT cores are created (None means CPU). Defaults to None.

    Returns:
        torchtt.TT: the identity TT matrix.
    """
    
    shape = list(shape)
    
    cores = [tn.unsqueeze(tn.unsqueeze(tn.eye(s, dtype=dtype, device = device),0),3) for s in shape]            
    
    return TT(cores)
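Every core is an n_k x n_k identity with trivial ranks, so the full operator is the Kronecker product of the small identities, which is the identity on the flattened index. The plain-torch sketch below builds the same rank-1 cores and checks this after converting the TT matrix to a full (prod m_k) x (prod n_k) matrix. The helper names are hypothetical and not torchtt API.

```python
import math

import torch


def ttm_full(cores):
    """Full tensor of a TT matrix with cores (r_{k-1}, m_k, n_k, r_k), indexed (m_1..m_d, n_1..n_d)."""
    full = cores[0]
    for core in cores[1:]:
        full = torch.tensordot(full, core, dims=([full.dim() - 1], [0]))
    full = full.squeeze(0).squeeze(-1)  # indices ordered m_1, n_1, m_2, n_2, ...
    d = len(cores)
    # move all row indices before all column indices
    return full.permute(list(range(0, 2 * d, 2)) + list(range(1, 2 * d, 2)))


def eye_cores(shape, dtype=torch.float64):
    """Rank-1 cores of the multidimensional identity, mirroring eye() above."""
    return [torch.eye(n, dtype=dtype).reshape(1, n, n, 1) for n in shape]


shape = [2, 3, 4]
size = math.prod(shape)
I_full = ttm_full(eye_cores(shape)).reshape(size, size)
print(torch.equal(I_full, torch.eye(size, dtype=torch.float64)))  # True
```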
def kron(first, second)

Computes the tensor Kronecker product. If None is provided as one of the inputs, the result is the other tensor. If A is N_1 x … x N_d and B is M_1 x … x M_p, then kron(A,B) is N_1 x … x N_d x M_1 x … x M_p.

Args

first : TT or None
first argument.
second : TT or None
second argument.

Raises

IncompatibleTypes
Incompatible data types (make sure both are either TT-matrices or TT-tensors).
InvalidArguments
Invalid arguments.

Returns

TT
the result.
def kron(first, second):
    """
    Computes the tensor Kronecker product.
    If None is provided as one of the inputs, the result is the other tensor.
    If A is N_1 x ... x N_d and B is M_1 x ... x M_p, then kron(A,B) is N_1 x ... x N_d x M_1 x ... x M_p


    Args:
        first (torchtt.TT | None): first argument.
        second (torchtt.TT | None): second argument.

    Raises:
        IncompatibleTypes: Incompatible data types (make sure both are either TT-matrices or TT-tensors).
        InvalidArguments: Invalid arguments.

    Returns:
        torchtt.TT: the result.
    """
    if first is None and isinstance(second,TT):
        cores_new = [c.clone() for c in second.cores]
        result = TT(cores_new)
    elif second is None and isinstance(first,TT): 
        cores_new = [c.clone() for c in first.cores]
        result = TT(cores_new)
    elif isinstance(first,TT) and isinstance(second,TT):
        if first.is_ttm != second.is_ttm:
            raise IncompatibleTypes('Incompatible data types (make sure both are either TT-matrices or TT-tensors).')
    
        # concatenate the result
        cores_new = [c.clone() for c in first.cores] + [c.clone() for c in second.cores]
        result = TT(cores_new)
    else:
        raise InvalidArguments('Invalid arguments.')
    return result
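Concatenating the core lists works because the two factors share no rank index: the last rank of first and the first rank of second are both 1. The plain-torch sketch below checks that the concatenated cores reproduce the outer product of the two full tensors, which is what the shape rule N_1 x … x N_d x M_1 x … x M_p describes. The helper names are hypothetical.

```python
import torch


def tt_full(cores):
    """Contract TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = cores[0]
    for core in cores[1:]:
        full = torch.tensordot(full, core, dims=([full.dim() - 1], [0]))
    return full.squeeze(0).squeeze(-1)


def tt_kron(first, second):
    """Cores of kron(first, second): the two core lists are simply concatenated."""
    return [c.clone() for c in first] + [c.clone() for c in second]


torch.manual_seed(0)
a = [torch.randn(1, 3, 2, dtype=torch.float64), torch.randn(2, 4, 1, dtype=torch.float64)]
b = [torch.randn(1, 5, 3, dtype=torch.float64), torch.randn(3, 6, 1, dtype=torch.float64)]
k = tt_full(tt_kron(a, b))
print(k.shape)  # torch.Size([3, 4, 5, 6])
```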
def load(path)

Load a torchtt.TT object from a file.

Examples

import torchtt
#generate a TT object
A = torchtt.randn([10,20,30,40,4,5],[1,6,5,4,3,2,1])
# save the TT object
torchtt.save(A,"./test.TT")
# load the TT object
B = torchtt.load("./test.TT")
# the loaded should be the same
print((A-B).norm()/A.norm())

Args

path : str
the file name.

Returns

TT
the tensor.
def load(path):
    """
    Load a torchtt.TT object from a file.

    Examples:
        ```
        import torchtt
        #generate a TT object
        A = torchtt.randn([10,20,30,40,4,5],[1,6,5,4,3,2,1])
        # save the TT object
        torchtt.save(A,"./test.TT")
        # load the TT object
        B = torchtt.load("./test.TT")
        # the loaded should be the same
        print((A-B).norm()/A.norm())
        ```
        
    Args:
        path (str): the file name.

    Returns:
        torchtt.TT: the tensor.
    """
    dct = tn.load(path)
    
    return TT(dct['cores'])
def meshgrid(vectors)

Creates a meshgrid of torchtt.TT objects. Similar to numpy.meshgrid or torch.meshgrid. The input is a list of d torch.tensor vectors of sizes N_1, … ,N_d. The result is a list of torchtt.TT instances of shape N_1 x … x N_d.

Args

vectors : list[torch.tensor]
the vectors (1d tensors).

Returns

list[TT]
the resulting meshgrid.
def meshgrid(vectors):
    """
    Creates a meshgrid of torchtt.TT objects. Similar to numpy.meshgrid or torch.meshgrid.
    The input is a list of d torch.tensor vectors of sizes N_1, ... ,N_d.
    The result is a list of torchtt.TT instances of shape N_1 x ... x N_d.
    
    Args:
        vectors (list[torch.tensor]): the vectors (1d tensors).

    Returns:
        list[TT]: the resulting meshgrid.
    """
    
    Xs = []
    dtype = vectors[0].dtype
    for i in range(len(vectors)):
        lst = [tn.ones((1,v.shape[0],1),dtype=dtype,device=v.device) for v in vectors]
        lst[i] = tn.reshape(vectors[i],[1,-1,1])
        Xs.append(TT(lst))
    return Xs
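Each grid is a rank-1 TT tensor, with vector i in core i and ones everywhere else. The plain-torch sketch below builds the same cores and compares the full grids with torch.meshgrid using 'ij' indexing, which matches the mode order of the TT tensors. The helper names are hypothetical.

```python
import torch


def tt_full(cores):
    """Contract TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = cores[0]
    for core in cores[1:]:
        full = torch.tensordot(full, core, dims=([full.dim() - 1], [0]))
    return full.squeeze(0).squeeze(-1)


def meshgrid_cores(vectors):
    """For every i, the rank-1 cores of the i-th grid: vectors[i] in core i, ones elsewhere."""
    grids = []
    for i in range(len(vectors)):
        cores = [torch.ones(1, v.shape[0], 1, dtype=v.dtype, device=v.device) for v in vectors]
        cores[i] = vectors[i].reshape(1, -1, 1)
        grids.append(cores)
    return grids


xs = torch.linspace(0, 1, 3, dtype=torch.float64)
ys = torch.arange(4, dtype=torch.float64)
X, Y = (tt_full(c) for c in meshgrid_cores([xs, ys]))
print(X.shape, Y.shape)  # torch.Size([3, 4]) torch.Size([3, 4])
```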
def numel(tensor)

Return the number of entries needed to store the TT cores for the given tensor.

Args

tensor : TT
the TT representation of the tensor.

Returns

int
number of floats stored for the TT decomposition.
def numel(tensor):
    """
    Return the number of entries needed to store the TT cores for the given tensor.

    Args:
        tensor (torchtt.TT): the TT representation of the tensor.

    Returns:
        int: number of floats stored for the TT decomposition.
    """
    
    return sum([tn.numel(tensor.cores[i]) for i in range(len(tensor.N))])
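Core k holds r_{k-1} n_k r_k entries, so the count follows directly from the shape and the ranks. This is the O(dnr^2) storage estimate from the introduction. The small plain-Python sketch below uses a hypothetical helper tt_storage, which takes the shape and the rank list instead of a TT object.

```python
import math


def tt_storage(shape, ranks):
    """Entries stored by TT cores of the given shape and ranks (len(ranks) == len(shape) + 1)."""
    return sum(ranks[k] * shape[k] * ranks[k + 1] for k in range(len(shape)))


shape = [3, 4, 5, 6, 7]
ranks = [1, 2, 2, 2, 2, 1]
print(tt_storage(shape, ranks), math.prod(shape))  # 80 2520
```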
def ones(shape, dtype=torch.float64, device=None)

Construct a tensor that contains only ones. The shape can be a list of ints or a list of tuples of ints. The second case creates a TT matrix.

Args

shape : list[int] or list[tuple[int]]
the shape.
dtype : torch.dtype, optional
the dtype of the returned tensor. Defaults to tn.float64.
device : torch.device, optional
the device where the TT cores are created (None means CPU). Defaults to None.

Raises

InvalidArguments
Shape must be a list.

Returns

TT
the tensor of ones.
def ones(shape, dtype=tn.float64, device = None):
    """
    Construct a tensor that contains only ones.
    The shape can be a list of ints or a list of tuples of ints. The second case creates a TT matrix.

    Args:
        shape (list[int] or list[tuple[int]]): the shape.
        dtype (torch.dtype, optional): the dtype of the returned tensor. Defaults to tn.float64.
        device (torch.device, optional): the device where the TT cores are created (None means CPU). Defaults to None.

    Raises:
        InvalidArguments: Shape must be a list.

    Returns:
        torchtt.TT: the tensor of ones.
    """
    if isinstance(shape,list):
        d = len(shape)
        if d==0:
            return TT(None)
        else:
            if isinstance(shape[0],tuple):
                # we create a TT-matrix
                cores = [tn.ones([1,shape[i][0],shape[i][1],1],dtype=dtype,device=device) for i in range(d)]            
                
            else:
                # we create a TT-tensor
                cores = [tn.ones([1,shape[i],1],dtype=dtype,device=device) for i in range(d)]
            
    else:
        raise InvalidArguments('Shape must be a list.')
    
    return TT(cores)
def pad(tensor, padding, value=0.0)

Pad a tensor in the TT format. The padding argument is a tuple of tuples ((b1, a1), (b2, a2), … , (bd, ad)). Each dimension is padded with bk entries at the beginning and ak entries at the end. If fewer than d pairs are given, they are applied to the trailing dimensions. The padding value is constant and is given as the argument value. In case of a TT operator, diagonal padding is performed: the provided value is inserted on the diagonal.

Args

tensor : TT
the tensor to be padded.
padding : tuple[tuple[int]]
the paddings.
value : float, optional
the value to pad. Defaults to 0.0.

Raises

InvalidArguments
The number of paddings should not exceed the number of dimensions of the tensor.

Returns

TT
the result.
def pad(tensor, padding, value = 0.0):
    """
    Pad a tensor in the TT format.
    The `padding` argument is a tuple of tuples `((b1, a1), (b2, a2), ... , (bd, ad))`. 
    Each dimension is padded with `bk` entries at the beginning and `ak` entries at the end. If fewer than d pairs are given, they are applied to the trailing dimensions. The padding value is constant and is given as the argument `value`. 
    In case of a TT operator, diagonal padding is performed: the provided `value` is inserted on the diagonal.

    Args:
        tensor (TT): the tensor to be padded.
        padding (tuple(tuple(int))): the paddings.
        value (float, optional): the value to pad. Defaults to 0.0.

    Raises:
        InvalidArguments: The number of paddings should not exceed the number of dimensions of the tensor.

    Returns:
        TT: the result.
    """
    if(len(padding) > len(tensor.N)):
        raise InvalidArguments("The number of paddings should not exceed the number of dimensions of the tensor.")
    
    
    if tensor.is_ttm:
        cores = [c.clone() for c in tensor.cores]
        for pad,k in zip(reversed(padding),reversed(range(len(tensor.N)))):
            cores[k] = tnf.pad(cores[k],(1 if k < len(tensor.N)-1 else 0,1 if k < len(tensor.N)-1 else 0,pad[0],pad[1],pad[0],pad[1],1 if k>0 else 0,1 if k>0 else 0),value = 0)
            cores[k][0,:pad[0],:pad[0],0] = value*tn.eye(pad[0], device = cores[k].device, dtype = cores[k].dtype)
            cores[k][-1,(pad[0]+tensor.M[k]):,(pad[0]+tensor.N[k]):,-1] = value*tn.eye(pad[1], device = cores[k].device, dtype = cores[k].dtype)
            value = 1
    else:
        rprod = np.prod(tensor.R)
        value = value/rprod 

        cores = [c.clone() for c in tensor.cores]
        for pad,k in zip(reversed(padding),reversed(range(len(tensor.N)))):
            cores[k] = tnf.pad(cores[k],(0,0,pad[0],pad[1],0,0),value = value)
            value = 1
            
    return TT(cores)
def permute(input, dims, eps=1e-12)

Permutes the dimensions of the tensor. Works similarly to torch.permute. For both TT tensors and TT matrices, the permutation is carried out like a bubble sort, by repeatedly swapping adjacent cores.

Examples

x_tt = torchtt.random([5,6,7,8,9],[1,2,3,4,2,1])
xp_tt = torchtt.permute(x_tt, [4,3,2,1,0], 1e-10)
print(xp_tt) # the shape of this tensor should be [9,8,7,6,5]

Args

input : TT
the input tensor.
dims : list[int]
the order of the indices in the new tensor.
eps : float, optional
the relative accuracy of the decomposition. Defaults to 1e-12.

Raises

InvalidArguments
The input must be a TT tensor and dims must be a list or a tuple of integers.
ShapeMismatch
dims must be the length of the number of dimensions.
InvalidArguments
Duplicate dims are not allowed.
InvalidArguments
Dims should only contain integers from 0 to d-1.

Returns

TT
the resulting tensor.
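The bookkeeping of the bubble sort can be followed without any TT data. The sketch below uses plain Python and torch, and the helper names are hypothetical. It replays the loop of permute() on the index list. It then shows one standard way to exchange two neighboring cores: merge them, swap the two mode indices and split again with an SVD. How torchtt performs the swap, including how it uses eps, is not part of this excerpt. The second function is therefore only an illustration of the technique, and it drops only numerically zero singular values.

```python
import torch


def bubble_sort_swaps(dims):
    """Replay the index bookkeeping of permute(): returns the final order and the swap positions."""
    indices = list(range(len(dims)))
    swaps = []
    inversions = True
    while inversions:
        inversions = False
        for i in range(len(dims) - 1):
            # an inversion means the cores at positions i and i+1 must be exchanged
            if dims.index(indices[i]) > dims.index(indices[i + 1]):
                indices[i], indices[i + 1] = indices[i + 1], indices[i]
                swaps.append(i)
                inversions = True
    return indices, swaps


def swap_adjacent_cores(c1, c2, rtol=1e-12):
    """Exchange the modes of two neighboring TT cores (r0, n1, r1) and (r1, n2, r2)."""
    r0, n1, _ = c1.shape
    _, n2, r2 = c2.shape
    # merged two-site tensor with the mode indices already exchanged
    w = torch.einsum('aib,bjc->ajic', c1, c2).reshape(r0 * n2, n1 * r2)
    U, S, Vh = torch.linalg.svd(w, full_matrices=False)
    # drop only numerically zero singular values
    r = max(int((S > rtol * S[0]).sum()), 1)
    return U[:, :r].reshape(r0, n2, r), (S[:r, None] * Vh[:r]).reshape(r, n1, r2)


order, swaps = bubble_sort_swaps([4, 3, 2, 1, 0])
print(order, len(swaps))  # [4, 3, 2, 1, 0] 10
```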
def permute(input, dims, eps = 1e-12):
    """
    Permutes the dimensions of the tensor. Works similarly to `torch.permute`.
    For both TT tensors and TT matrices, the permutation is carried out like a bubble sort, by repeatedly swapping adjacent cores.
    
    Examples:
    ```
    import torchtt
    x_tt = torchtt.random([5,6,7,8,9],[1,2,3,4,2,1])
    xp_tt = torchtt.permute(x_tt, [4,3,2,1,0], 1e-10)
    print(xp_tt) # the shape of this tensor should be [9,8,7,6,5]
    ```
    
    Args:
        input (torchtt.TT): the input tensor.
        dims (list[int]): the order of the indices in the new tensor.
        eps (float, optional): the relative accuracy of the decomposition. Defaults to 1e-12.

    Raises:
        InvalidArguments: The input must be a TT tensor and dims must be a list or a tuple of integers.
        ShapeMismatch: The length of `dims` must equal the number of dimensions.
        InvalidArguments: Duplicate dims are not allowed.
        InvalidArguments: Dims should only contain integers from 0 to d-1.
    Returns:
        torchtt.TT: the resulting tensor.
    """
    if not isinstance(input, TT) :
        raise InvalidArguments("The input must be a TT tensor and dims must be a list or a tuple of integers.")
    if len(dims) != len(input.N):
        raise ShapeMismatch("The length of `dims` must equal the number of dimensions.")
    if len(dims) != len(set(dims)):
        raise InvalidArguments("Duplicate dims are not allowed.")
    if min(dims) != 0 or max(dims) != len(input.N)-1:
        raise InvalidArguments("Dims should only contain integers from 0 to d-1.")
    
    cores, R  = rl_orthogonal(input.cores, input.R, input.is_ttm)
    d = len(cores)
    eps = eps/(d**1.5) 
    indices = list(range(d))
    
    last_idx = 0
    
    inversions = True 
    while inversions:
        inversions = False
        for i in range(d-1):
            i1 = indices[i]
            i2 = indices[i+1]
            if dims.index(i1)>dims.index(i2):
                # inversion in the index permutation => the cores must be swapped.
                inversions = True
                indices[i] = i2
                indices[i+1] = i1
                last_idx = i
                if input.is_ttm:
                    #reorthonormalize
                    for k in range(last_idx, i):
                        # Rk is the triangular QR factor (kept separate from the rank list R)
                        Q, Rk = QR(tn.reshape(cores[k],[cores[k].shape[0]*cores[k].shape[1]*cores[k].shape[2], cores[k].shape[3]]))
                        R[k+1] = Q.shape[1]
                        cores[k] = tn.reshape(Q, [cores[k].shape[0], cores[k].shape[1], cores[k].shape[2], -1])
                        cores[k+1] = tn.einsum('ij,jklm->iklm',Rk,cores[k+1])
                    
                    n2 = [cores[i].shape[1], cores[i].shape[2]]
                    core = tn.einsum('ijkl,lmno->ijkmno',cores[i],cores[i+1])
                    core = tn.permute(core, [0,3,4,1,2,5])
                    U,S,V = SVD(tn.reshape(core, [core.shape[0]*core.shape[1]*core.shape[2],-1]))
                    if S.is_cuda:
                        r_now = min([rank_chop(S.cpu().numpy(),tn.linalg.norm(S).cpu().numpy()*eps)])
                    else:
                        r_now = min([rank_chop(S.numpy(),tn.linalg.norm(S).numpy()*eps)])
                
                    US = U[:,:r_now]@tn.diag(S[:r_now])
                    V = V[:r_now,:]
                    
                    cores[i] = tn.reshape(US,[cores[i].shape[0],cores[i+1].shape[1],cores[i+1].shape[2],-1])
                    R[i+1] = cores[i].shape[3]
                    cores[i+1] = tn.reshape(V, [-1]+ n2 +[cores[i+1].shape[3]])
                    
                else:
                    
                    #reorthonormalize
                    for k in range(last_idx, i):
                        # Rk is the triangular QR factor (kept separate from the rank list R)
                        Q, Rk = QR(tn.reshape(cores[k],[cores[k].shape[0]*cores[k].shape[1], cores[k].shape[2]]))
                        R[k+1] = Q.shape[1]
                        cores[k] = tn.reshape(Q, [cores[k].shape[0], cores[k].shape[1],-1])
                        cores[k+1] = tn.einsum('ij,jkl->ikl',Rk,cores[k+1])
                    
                    n2 = cores[i].shape[1]
                    core = tn.einsum('ijk,klm->ijlm',cores[i],cores[i+1])
                    core = tn.permute(core, [0,2,1,3])
                    U,S,V = SVD(tn.reshape(core, [core.shape[0]*core.shape[1],-1]))
                    if S.is_cuda:
                        r_now = min([rank_chop(S.cpu().numpy(),tn.linalg.norm(S).cpu().numpy()*eps)])
                    else:
                        r_now = min([rank_chop(S.numpy(),tn.linalg.norm(S).numpy()*eps)])

                
                    US = U[:,:r_now]@tn.diag(S[:r_now])
                    V = V[:r_now,:]
                    
                    cores[i] = tn.reshape(US,[cores[i].shape[0],cores[i+1].shape[1],-1])
                    R[i+1] = cores[i].shape[2]
                    cores[i+1] = tn.reshape(V, [-1, n2, cores[i+1].shape[2]])
                    
                
    return TT(cores)
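The swap loop above is a bubble sort on the axis order: whenever two neighbouring axes appear in the wrong order relative to `dims`, cores `i` and `i+1` are merged, their mode indices exchanged, and the result split again with a truncated SVD. The sketch below reproduces only the index bookkeeping, on a dense NumPy array, assuming that one adjacent core swap corresponds to `np.swapaxes(x, i, i+1)`. The helpers `adjacent_swaps` and `permute_by_swaps` are hypothetical and not part of torchtt.

```python
import numpy as np


def adjacent_swaps(dims):
    """Positions i at which the bubble sort in permute swaps neighbours i and i+1."""
    dims = list(dims)
    indices = list(range(len(dims)))  # indices[p] = original axis currently at position p
    swaps = []
    inversions = True
    while inversions:
        inversions = False
        for i in range(len(dims) - 1):
            i1, i2 = indices[i], indices[i + 1]
            if dims.index(i1) > dims.index(i2):
                # axis i1 must end up after axis i2: swap the neighbours
                inversions = True
                indices[i], indices[i + 1] = i2, i1
                swaps.append(i)
    return swaps, indices


def permute_by_swaps(x, dims):
    """Dense analogue: apply the same sequence of neighbour swaps to a NumPy array."""
    swaps, _ = adjacent_swaps(dims)
    for i in swaps:
        x = np.swapaxes(x, i, i + 1)
    return x


x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
y = permute_by_swaps(x, [3, 1, 0, 2])
print(y.shape, np.array_equal(y, np.transpose(x, [3, 1, 0, 2])))
```

The number of swaps equals the number of inversions in `dims`, so fully reversing `d` axes (as in the `[4,3,2,1,0]` example) costs `d*(d-1)/2` core swaps, each with its own SVD.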
def randn(N, R, var=1.0, dtype=torch.float64, device=None)

A torchtt.TT tensor of shape N = [N1 x … x Nd] and rank R is returned. The entries of the full tensor are almost normally distributed with variance var.

Args

N : list[int]
the shape.
R : list[int]
the rank.
var : float, optional
the variance. Defaults to 1.0.
dtype : torch.dtype, optional
the dtype of the returned tensor. Defaults to tn.float64.
device : torch.device, optional
the device where the TT cores are created (None means CPU). Defaults to None.

Returns

TT
the result.
def randn(N, R, var = 1.0, dtype = tn.float64, device = None):
    """
    A torchtt.TT tensor of shape N = [N1 x ... x Nd] and rank R is returned. 
    The entries of the full tensor are almost normally distributed with variance var.
    
    Args:
        N (list[int]): the shape.
        R (list[int]): the rank.
        var (float, optional): the variance. Defaults to 1.0.
        dtype (torch.dtype, optional): the dtype of the returned tensor. Defaults to tn.float64.
        device (torch.device, optional): the device where the TT cores are created (None means CPU). Defaults to None.

    Returns:
        torchtt.TT: the result.
    """

    d = len(N)
    v1 = var / np.prod(R)
    v = v1**(1/d)
    cores = [None] * d
    for i in range(d):
        cores[i] = tn.randn([R[i],N[i][0],N[i][1],R[i+1]] if isinstance(N[i],tuple) else [R[i],N[i],R[i+1]], dtype = dtype, device = device)*np.sqrt(v)

    return TT(cores)
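The scaling in `randn` follows from the TT formula: each full entry is a sum of `prod(R)` products of `d` core entries (the boundary ranks are 1), and these products are uncorrelated. If every core entry has variance `v = (var/prod(R))**(1/d)`, each product has variance `var/prod(R)` and the sum has variance `var`; since a sum of products of Gaussians is not exactly Gaussian, the entries are only "almost" normal. The NumPy sketch below re-creates this scaling with hypothetical helpers (`tt_randn_cores`, `tt_full`) and estimates the variance of one entry empirically; it is an illustration, not torchtt code.

```python
import numpy as np


def tt_randn_cores(N, R, var=1.0, rng=None):
    """Random 3D TT cores scaled like torchtt.randn (NumPy re-implementation)."""
    rng = np.random.default_rng() if rng is None else rng
    d = len(N)
    v = (var / np.prod(R)) ** (1.0 / d)  # variance of every core entry
    return [rng.standard_normal((R[i], N[i], R[i + 1])) * np.sqrt(v) for i in range(d)]


def tt_full(cores):
    """Contract 3D TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = cores[0]
    for c in cores[1:]:
        full = np.tensordot(full, c, axes=([-1], [0]))
    return full[0, ..., 0]  # drop the boundary ranks r_0 = r_d = 1


rng = np.random.default_rng(1234)
samples = np.array([
    tt_full(tt_randn_cores([2, 2, 2], [1, 3, 3, 1], var=2.0, rng=rng))[0, 0, 0]
    for _ in range(10000)
])
print(samples.var())  # should be close to 2.0
```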
def random(N, R, dtype=torch.float64, device=None)

Returns a tensor of shape N with random cores of rank R. The entries of each core are normally distributed with mean 0 and variance 1. See also torchtt.randn(), which scales the cores so that the entries of the full tensor have a prescribed variance.

Args

N : list[int] or list[tuple[int]]
the shape of the tensor. If the elements are tuples of integers, we deal with a TT-matrix.
R : list[int] or int
a list if the exact rank is specified, or an integer that is used for all inner ranks (R = [1, R, …, R, 1]).
dtype : torch.dtype, optional
the dtype of the returned tensor. Defaults to tn.float64.
device : torch.device, optional
the device where the TT cores are created (None means CPU). Defaults to None.

Raises

InvalidArguments
Check if N and R are right.

Returns

TT
the result.
def random(N, R, dtype = tn.float64, device = None):
    """
    Returns a tensor of shape N with random cores of rank R.
    The entries of each core are normally distributed with mean 0 and variance 1.
    See also torchtt.randn(), which scales the cores so that the entries of the full tensor have a prescribed variance.

    Args:
        N (list[int] or list[tuple[int]]): the shape of the tensor. If the elements are tuples of integers, we deal with a TT-matrix.
        R (list[int] or int): a list if the exact rank is specified, or an integer that is used for all inner ranks (R = [1, R, ..., R, 1]).
        dtype (torch.dtype, optional): the dtype of the returned tensor. Defaults to tn.float64.
        device (torch.device, optional): the device where the TT cores are created (None means CPU). Defaults to None.

    Raises:
        InvalidArguments: Check if N and R are right.

    Returns:
        torchtt.TT: the result.
    """
    
    if isinstance(R,int):
        R = [1]+[R]*(len(N)-1)+[1]
    elif len(N)+1 != len(R) or R[0] != 1 or R[-1] != 1 or len(N)==0:
        raise InvalidArguments('Check if N and R are right.')
        
    cores = []
    
    for i in range(len(N)):
        cores.append(tn.randn([R[i],N[i][0],N[i][1],R[i+1]] if isinstance(N[i],tuple) else [R[i],N[i],R[i+1]], dtype = dtype, device = device))
        
    T = TT(cores)
    
    return T
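`random` accepts either a full rank list or a single integer, which is expanded to `[1, R, ..., R, 1]`; TT-matrix modes are given as tuples and produce 4D cores. The sketch below mirrors that bookkeeping in plain Python. `normalize_rank` and `core_shapes` are hypothetical names, and `ValueError` stands in for torchtt's `InvalidArguments`. Because the cores have unit variance, each entry of the full tensor is a sum of `prod(R)` uncorrelated products of unit-variance factors and therefore has variance `prod(R)`, which grows quickly with the rank; `randn` rescales the cores to avoid this.

```python
def normalize_rank(N, R):
    """Expand an integer rank and validate a rank list, as done at the top of random()."""
    if isinstance(R, int):
        return [1] + [R] * (len(N) - 1) + [1]
    if len(N) == 0 or len(R) != len(N) + 1 or R[0] != 1 or R[-1] != 1:
        raise ValueError("Check if N and R are right.")
    return list(R)


def core_shapes(N, R):
    """Shapes of the cores created by random(): 4D for tuple modes (TT-matrix), 3D otherwise."""
    R = normalize_rank(N, R)
    shapes = []
    for i, n in enumerate(N):
        if isinstance(n, tuple):
            shapes.append((R[i], n[0], n[1], R[i + 1]))
        else:
            shapes.append((R[i], n, R[i + 1]))
    return shapes


print(core_shapes([5, 6, 7], 4))
print(core_shapes([(2, 3), (4, 5)], [1, 2, 1]))
```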
def rank1TT(elements)

Compute the rank 1 TT from a list of vectors (or matrices).

Args

elements : list[torch.tensor]
the list of vectors (or matrices in case a TT matrix should be created).

Returns

TT
the resulting TT object.
def rank1TT(elements):
    """
    Compute the rank 1 TT from a list of vectors (or matrices).

    Args:
        elements (list[torch.tensor]): the list of vectors (or matrices in case a TT matrix should be created).

    Returns:
        torchtt.TT: the resulting TT object.
    """
    
    return TT([e[None,...,None] for e in elements])
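`rank1TT` only adds a leading and a trailing rank axis of size 1 to every element, so for vectors the resulting TT represents their outer product. A NumPy check of that statement, using a hypothetical contraction helper `tt_full` for 3D cores:

```python
import numpy as np


def rank1_cores(elements):
    """Same construction as rank1TT: shape (n,) -> (1, n, 1), shape (m, n) -> (1, m, n, 1)."""
    return [e[None, ..., None] for e in elements]


def tt_full(cores):
    """Contract 3D TT cores into the full tensor."""
    full = cores[0]
    for c in cores[1:]:
        full = np.tensordot(full, c, axes=([-1], [0]))
    return full[0, ..., 0]


a, b, c = np.arange(1.0, 3.0), np.arange(1.0, 4.0), np.arange(1.0, 5.0)
cores = rank1_cores([a, b, c])
print([core.shape for core in cores])  # [(1, 2, 1), (1, 3, 1), (1, 4, 1)]
print(np.allclose(tt_full(cores), np.einsum("i,j,k->ijk", a, b, c)))
```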
def reshape(tens, shape, eps=1e-16, rmax=9223372036854775807)

Reshapes a torchtt.TT tensor in the TT format. The result is rounded with the relative accuracy eps.

Args

tens : TT
the input tensor.
shape : list[int] or list[tuple[int]]
the desired shape. In the case of a TT operator the shape has to be given as list of tuples of ints [(M1,N1),…,(Md,Nd)].
eps : float, optional
relative accuracy. Defaults to 1e-16.
rmax : int, optional
maximum rank. Defaults to the maximum possible integer.

Raises

ShapeMismatch
The product of modes should remain equal. Check the given shape.

Returns

TT
the resulting tensor.
def reshape(tens, shape, eps = 1e-16, rmax = sys.maxsize):
    """
    Reshapes a torchtt.TT tensor in the TT format.
    The result is rounded with the relative accuracy eps.
    
    Args:
        tens (torchtt.TT): the input tensor.
        shape (list[int] or list[tuple[int]]): the desired shape. In the case of a TT operator the shape has to be given as list of tuples of ints [(M1,N1),...,(Md,Nd)].
        eps (float, optional): relative accuracy. Defaults to 1e-16.
        rmax (int, optional): maximum rank. Defaults to the maximum possible integer.

    Raises:
        ShapeMismatch: The product of modes should remain equal. Check the given shape.

    Returns:
        torchtt.TT: the resulting tensor.
    """
    
    
    if tens.is_ttm:
        M = []
        N = []
        for t in shape:
            M.append(t[0])
            N.append(t[1])
        if np.prod(tens.N)!=np.prod(N) or np.prod(tens.M)!=np.prod(M):
            raise ShapeMismatch('The product of modes should remain equal. Check the given shape.')
        core = tens.cores[0]
        cores_new = []
        
        idx = 0
        idx_shape = 0
        
        while True:
            if core.shape[1] % M[idx_shape] == 0 and core.shape[2] % N[idx_shape] == 0:
                if core.shape[1] // M[idx_shape] > 1 or core.shape[2] // N[idx_shape] > 1:
                    m1 = M[idx_shape]
                    m2 = core.shape[1] // m1
                    n1 = N[idx_shape]
                    n2 = core.shape[2] // n1
                    r1 = core.shape[0]
                    r2 = core.shape[-1]
                    tmp = tn.reshape(core,[r1*m1,m2,n1,n2*r2])
                    
                    crz,_ = mat_to_tt(tmp, [r1*m1,m2], [n1,n2*r2], eps, rmax)
                    
                    cores_new.append(tn.reshape(crz[0],[r1,m1,n1,-1]))
                    
                    core = tn.reshape(crz[1],[-1,m2,n2,r2]) 
                else:
                    cores_new.append(core+0)
                    if idx == len(tens.cores)-1:
                        break
                    else:
                        idx+=1
                        core = tens.cores[idx]
                idx_shape += 1
                if idx_shape == len(shape):
                    break
            else: 
                idx += 1
                if idx>=len(tens.cores):
                    break
                
                core = tn.einsum('ijkl,lmno->ijmkno',core,tens.cores[idx])
                core = tn.reshape(core,[core.shape[0],core.shape[1]*core.shape[2],-1,core.shape[-1]])
                
    else:
        if np.prod(tens.N)!=np.prod(shape):
            raise ShapeMismatch('The product of modes should remain equal. Check the given shape.')
            
        core = tens.cores[0]
        cores_new = []
        
        idx = 0
        idx_shape = 0
        while True:
            if core.shape[1] % shape[idx_shape] == 0:
                if core.shape[1] // shape[idx_shape] > 1:
                    s1 = shape[idx_shape]
                    s2 = core.shape[1] // s1
                    r1 = core.shape[0]
                    r2 = core.shape[2]
                    tmp = tn.reshape(core,[r1*s1,s2*r2])
                    
                    crz,_ = to_tt(tmp,tmp.shape,eps,rmax)
                    
                    cores_new.append(tn.reshape(crz[0],[r1,s1,-1]))
                    
                    core = tn.reshape(crz[1],[-1,s2,r2]) 
                else:
                    cores_new.append(core+0)
                    if idx == len(tens.cores)-1:
                        break
                    else:
                        idx+=1
                        core = tens.cores[idx]
                idx_shape += 1
                if idx_shape == len(shape):
                    break
            else: 
                idx += 1
                if idx>=len(tens.cores):
                    break
                
                core = tn.einsum('ijk,klm->ijlm',core,tens.cores[idx])
                core = tn.reshape(core,[core.shape[0],-1,core.shape[-1]])
                
    return TT(cores_new).round(eps)
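The loop in the tensor (non-TTM) branch walks through the requested shape: when the current mode size is a multiple of the requested size, the core is split with an SVD; otherwise it is merged with the next core until the sizes are compatible. Because modes are grouped in row-major (C) order, the result should represent `numpy.reshape` of the full tensor. The sketch below re-implements only that 3D branch in NumPy under these assumptions, using an SVD that drops only numerically zero singular values; `split_core`, `tt_reshape` and `tt_full` are hypothetical helpers, and the final `round(eps)` call is omitted.

```python
import numpy as np


def tt_full(cores):
    """Contract 3D TT cores into the full tensor."""
    full = cores[0]
    for c in cores[1:]:
        full = np.tensordot(full, c, axes=([-1], [0]))
    return full[0, ..., 0]


def split_core(core, s1):
    """Split a core (r1, s1*s2, r2) into (r1, s1, r) and (r, s2, r2) with an SVD."""
    r1, n, r2 = core.shape
    s2 = n // s1
    U, S, Vt = np.linalg.svd(core.reshape(r1 * s1, s2 * r2), full_matrices=False)
    r = max(1, int(np.sum(S > 1e-12 * S[0])))  # drop numerically zero singular values
    left = (U[:, :r] * S[:r]).reshape(r1, s1, r)
    right = Vt[:r].reshape(r, s2, r2)
    return left, right


def tt_reshape(cores, shape):
    """Split/merge loop of the 3D branch of torchtt.reshape (no final rounding)."""
    new_cores, idx, k = [], 0, 0
    core = cores[0]
    while True:
        n = core.shape[1]
        if n % shape[k] == 0:
            if n // shape[k] > 1:
                left, core = split_core(core, shape[k])
                new_cores.append(left)
            else:
                new_cores.append(core)
                if idx == len(cores) - 1:
                    break
                idx += 1
                core = cores[idx]
            k += 1
            if k == len(shape):
                break
        else:
            # mode too small: merge with the next core, (r1, n1, r) x (r, n2, r2) -> (r1, n1*n2, r2)
            idx += 1
            if idx >= len(cores):
                break
            core = np.tensordot(core, cores[idx], axes=([2], [0]))
            core = core.reshape(core.shape[0], -1, core.shape[-1])
    return new_cores


rng = np.random.default_rng(0)
cores = [rng.standard_normal(s) for s in [(1, 4, 3), (3, 6, 1)]]
reshaped = tt_reshape(cores, [8, 3])
print([c.shape[1] for c in reshaped])
print(np.allclose(tt_full(reshaped), tt_full(cores).reshape(8, 3)))
```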
def save(tensor, path)

Save a TT object in a file.

Examples

import torchtt
#generate a TT object
A = torchtt.randn([10,20,30,40,4,5],[1,6,5,4,3,2,1])
# save the TT object
torchtt.save(A,"./test.TT")
# load the TT object
B = torchtt.load("./test.TT")
# the loaded tensor should be the same (relative error numerically zero)
print((A-B).norm()/A.norm())

Args

tensor : TT
the tensor to be saved.
path : str
the file name.

Raises

InvalidArguments
First argument must be a torchtt.TT instance.
def save(tensor, path):
    """
    Save a `torchtt.TT` object in a file.

    Examples:
        ```
        import torchtt
        #generate a TT object
        A = torchtt.randn([10,20,30,40,4,5],[1,6,5,4,3,2,1])
        # save the TT object
        torchtt.save(A,"./test.TT")
        # load the TT object
        B = torchtt.load("./test.TT")
        # the loaded tensor should be the same (relative error numerically zero)
        print((A-B).norm()/A.norm())
        ```
    
    Args:
        tensor (torchtt.TT): the tensor to be saved.
        path (str): the file name.

    Raises:
        InvalidArguments: First argument must be a torchtt.TT instance.
    """
    if not isinstance(tensor, TT):
        raise InvalidArguments("First argument must be a torchtt.TT instance.")
    
    if tensor.is_ttm:
        dct = {"is_ttm": tensor.is_ttm, "R": tensor.R, "M": tensor.M, "N": tensor.N, "cores": tensor.cores}
        tn.save(dct, path)
    else:
        dct = {"is_ttm": tensor.is_ttm, "R": tensor.R, "N": tensor.N, "cores": tensor.cores}
        tn.save(dct, path)
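`save` writes a plain dictionary with `torch.save`: the flag `is_ttm`, the ranks `R`, the column shape `N`, the row shape `M` (TT matrices only) and the list of cores. Since `torch.save` is built on Python's pickle protocol, the layout can be sketched with the standard `pickle` module and NumPy cores. `tt_to_dict` is a hypothetical helper; real files should be written and read with `torchtt.save` and `torchtt.load`.

```python
import io
import pickle

import numpy as np


def tt_to_dict(cores):
    """Build the same dictionary layout that torchtt.save stores (NumPy cores here)."""
    is_ttm = cores[0].ndim == 4
    dct = {
        "is_ttm": is_ttm,
        "R": [c.shape[0] for c in cores] + [cores[-1].shape[-1]],
        "N": [c.shape[2] if is_ttm else c.shape[1] for c in cores],
        "cores": cores,
    }
    if is_ttm:
        dct["M"] = [c.shape[1] for c in cores]
    return dct


cores = [np.ones((1, 10, 3)), np.ones((3, 20, 1))]
buffer = io.BytesIO()
pickle.dump(tt_to_dict(cores), buffer)  # stands in for torch.save(dct, path)
buffer.seek(0)
loaded = pickle.load(buffer)
print(loaded["is_ttm"], loaded["R"], loaded["N"])
```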
def shape_mn_to_tuple(M, N)

Convert the shape of a TTM from row/column format to tuple format.

Args

M : list[int]
row shapes.
N : list[int]
column shapes.

Returns

list[tuple[int]]
shape.
def shape_mn_to_tuple(M, N):
    """
    Convert the shape of a TTM from row/column format to tuple format.

    Args:
        M (list[int]): row shapes.
        N (list[int]): column shapes.

    Returns:
        list[tuple[int]]: shape.
    """

    return [(m,n) for m,n in zip(M,N)]
def shape_tuple_to_mn(shape)

Convert the shape of a TTM from tuple format to row and column shapes.

Args

shape : list[tuple[int]]
shape.

Returns

tuple[list[int],list[int]]
the row shape M and the column shape N.
def shape_tuple_to_mn(shape):
    """
    Convert the shape of a TTM from tuple format to row and column shapes.

    Args:
        shape (list[tuple[int]]): shape.

    Returns:
        tuple[list[int],list[int]]: the row shape M and the column shape N.
    """
    M = [s[0] for s in shape]
    N = [s[1] for s in shape]
    
    return M, N
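The two helpers are inverses of each other: `shape_mn_to_tuple` zips the row and column shapes into the `[(M1,N1),...,(Md,Nd)]` format used by the `TT` constructor and by `reshape`, and `shape_tuple_to_mn` unzips it again. A round-trip check using local copies of the two functions:

```python
def shape_mn_to_tuple(M, N):
    """Row shape M and column shape N -> list of (M_k, N_k) tuples."""
    return [(m, n) for m, n in zip(M, N)]


def shape_tuple_to_mn(shape):
    """List of (M_k, N_k) tuples -> (row shape M, column shape N)."""
    M = [s[0] for s in shape]
    N = [s[1] for s in shape]
    return M, N


shape = shape_mn_to_tuple([3, 5, 7], [4, 6, 8])
print(shape)                     # [(3, 4), (5, 6), (7, 8)]
print(shape_tuple_to_mn(shape))  # ([3, 5, 7], [4, 6, 8])
```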
def zeros(shape, dtype=torch.float64, device=None)

Construct a tensor that contains only zeros. The shape can be a list of ints or a list of tuples of ints; the second case creates a TT matrix.

Args

shape : list[int] or list[tuple[int]]
the shape.
dtype : torch.dtype, optional
the dtype of the returned tensor. Defaults to tn.float64.
device : torch.device, optional
the device where the TT cores are created (None means CPU). Defaults to None.

Raises

InvalidArguments
Shape must be a list.

Returns

TT
the zero tensor.
def zeros(shape, dtype=tn.float64, device = None):
    """
    Construct a tensor that contains only zeros.
    The shape can be a list of ints or a list of tuples of ints; the second case creates a TT matrix.

    Args:
        shape (list[int] | list[tuple[int]]): the shape.
        dtype (torch.dtype, optional): the dtype of the returned tensor. Defaults to tn.float64.
        device (torch.device, optional): the device where the TT cores are created (None means CPU). Defaults to None.

    Raises:
        InvalidArguments: Shape must be a list.

    Returns:
        torchtt.TT: the zero tensor.
    """
    if isinstance(shape,list):
        d = len(shape)
        if isinstance(shape[0],tuple):
            # we create a TT-matrix
            cores = [tn.zeros([1,shape[i][0],shape[i][1],1],dtype=dtype, device = device) for i in range(d)]            
            
        else:
            # we create a TT-tensor
            cores = [tn.zeros([1,shape[i],1],dtype=dtype, device = device) for i in range(d)]
            
    else:
        raise InvalidArguments('Shape must be a list.')
    
    return TT(cores)
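`zeros` returns a TT whose ranks are all 1 and whose cores are zero, so it stores only `sum(N_k)` entries (or `sum(M_k*N_k)` for a TT matrix). The NumPy sketch below builds cores of the same shapes; `zero_cores` is a hypothetical name, and `ValueError` stands in for `InvalidArguments`.

```python
import numpy as np


def zero_cores(shape):
    """Rank-1 zero cores: 4D for a list of (M_k, N_k) tuples, 3D for a list of ints."""
    if not isinstance(shape, list):
        raise ValueError("Shape must be a list.")
    if isinstance(shape[0], tuple):
        return [np.zeros((1, m, n, 1)) for m, n in shape]
    return [np.zeros((1, n, 1)) for n in shape]


print([c.shape for c in zero_cores([4, 5, 6])])
print([c.shape for c in zero_cores([(2, 3), (4, 5)])])
```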

Classes

class TT (source, shape=None, eps=1e-10, rmax=9223372036854775807)

Constructor of the TT class. Can convert a full tensor (torch.tensor or numpy.array) to the TT format. In the case of tensor operators of full shape M1 x … x Md x N1 x … x Nd, the shape must be specified as a list of tuples [(M1,N1),…,(Md,Nd)]. A TT object can also be constructed from cores if the list of cores is passed as argument. If None is provided, an empty tensor is created.
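The full shape convention M1 x … x Md x N1 x … x Nd is also what `full()` returns for a TT matrix: contracting the 4D cores yields the indices interleaved as M1, N1, M2, N2, …, Md, Nd, and a final permutation `[0, 2, ..., 2d-2, 1, 3, ..., 2d-1]` moves all row indices in front of the column indices. A NumPy sketch of that contraction (`ttm_full` is a hypothetical helper), checked on a rank-1 TT matrix, whose full form reshaped to `prod(M) x prod(N)` is the Kronecker product of its factors:

```python
import numpy as np


def ttm_full(cores):
    """Full tensor of a TT matrix with cores (r_{k-1}, M_k, N_k, r_k); shape M1..Md x N1..Nd."""
    d = len(cores)
    full = cores[0][0]  # (M1, N1, r1)
    for c in cores[1:]:
        full = np.tensordot(full, c, axes=([-1], [0]))  # append (M_k, N_k, r_k)
    full = full[..., 0]  # drop r_d = 1; axes are now M1, N1, ..., Md, Nd
    perm = list(range(0, 2 * d, 2)) + list(range(1, 2 * d, 2))
    return np.transpose(full, perm)


A1 = np.arange(6.0).reshape(2, 3)
A2 = np.arange(20.0).reshape(4, 5)
full = ttm_full([A1[None, :, :, None], A2[None, :, :, None]])
print(full.shape)  # (2, 4, 3, 5)
print(np.allclose(full.reshape(2 * 4, 3 * 5), np.kron(A1, A2)))
```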

The TT decomposition of a tensor is

\mathsf{x}_{i_1i_2\ldots i_d}=\sum\limits_{r_1,\ldots,r_{d-1}=1}^{R_1,\ldots,R_{d-1}} \mathsf{x}^{(1)}_{1i_1r_1}\mathsf{x}^{(2)}_{r_1i_2r_2}\cdots\mathsf{x}^{(d)}_{r_{d-1}i_d1},

where \{\mathsf{x}^{(k)}\}_{k=1}^d are the TT cores and \mathbf{R}=(1,R_1,...,R_{d-1},1) is the TT rank. Using the constructor, a TT decomposition of a tensor can be computed. The TT cores are stored as a list in torchtt.TT.cores.
This class implements basic operators such as +,-,*,/,@,** (addition, subtraction, elementwise multiplication, elementwise division, matrix-vector product and Kronecker product) between TT instances. The examples/ folder serves as a tutorial for all the features of the toolbox.
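Computing the decomposition of a full tensor is delegated to `to_tt` (and `mat_to_tt` for TT matrices), whose source is not part of this section. The textbook way to obtain such a decomposition is the TT-SVD algorithm: sweep from left to right, reshape the remainder into a matrix, keep its leading left singular vectors as the next core, and carry the scaled right factor to the next step. The NumPy sketch below implements that standard algorithm, assuming the accuracy is split evenly over the `d-1` truncations (threshold `eps/sqrt(d-1)*||x||_F`), which bounds the relative Frobenius error by `eps`. It is an illustration, not necessarily what `to_tt` does internally.

```python
import numpy as np


def tt_svd(x, eps=1e-10):
    """Standard TT-SVD: returns 3D cores g_k of shape (r_{k-1}, n_k, r_k)."""
    shape, d = x.shape, x.ndim
    delta = eps / np.sqrt(max(d - 1, 1)) * np.linalg.norm(x)  # per-step truncation threshold
    cores, r = [], 1
    c = x
    for k in range(d - 1):
        c = c.reshape(r * shape[k], -1)  # unfolding of the remainder
        U, S, Vt = np.linalg.svd(c, full_matrices=False)
        tail = np.sqrt(np.cumsum(S[::-1] ** 2))[::-1]  # tail[j] = ||S[j:]||_2
        rank = max(1, int(np.sum(tail > delta)))  # smallest rank with discarded tail <= delta
        cores.append(U[:, :rank].reshape(r, shape[k], rank))
        c = S[:rank, None] * Vt[:rank]  # remainder carried to the next step
        r = rank
    cores.append(c.reshape(r, shape[-1], 1))
    return cores


def tt_full(cores):
    """Contract the cores back into the full tensor."""
    full = cores[0]
    for g in cores[1:]:
        full = np.tensordot(full, g, axes=([-1], [0]))
    return full[0, ..., 0]


x = np.arange(128, dtype=np.float64).reshape(8, 4, 4)
cores = tt_svd(x)
print([g.shape for g in cores])
print(np.allclose(tt_full(cores), x))
```

For the tensor from the example (`x[i,j,k] = 16i + 4j + k`), every unfolding has rank 2, so the TT rank found this way is `[1, 2, 2, 1]`.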

Examples

import torchtt
import torch
x = torch.reshape(torch.arange(0,128,dtype = torch.float64),[8,4,4])
xtt = torchtt.TT(x)
ytt = torchtt.TT(torch.squeeze(x),[8,4,4])
# create a TT matrix
A = torch.reshape(torch.arange(0,20160,dtype = torch.float64),[3,5,7,4,6,8])
Att = torchtt.TT(A,[(3,4),(5,6),(7,8)])
print(Att)        
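`print(Att)` displays the string built by `__repr__`, which includes a compression figure: the number of entries stored in the cores divided by the number of entries of the full tensor. For a TT tensor the cores hold `sum_k R[k]*N[k]*R[k+1]` numbers against `prod(N)` for the full tensor; for a TT matrix, `N[k]` is replaced by `M[k]*N[k]`. A plain-Python version of that count for the tensor case, with hypothetical helper names:

```python
from math import prod


def tt_storage(N, R):
    """Number of entries stored in the cores of a TT tensor with shape N and rank R."""
    return sum(R[k] * N[k] * R[k + 1] for k in range(len(N)))


def tt_compression(N, R):
    """Stored entries relative to the entries of the full tensor, as printed by __repr__."""
    return tt_storage(N, R) / prod(N)


N, R = [5, 6, 7, 8, 9], [1, 2, 3, 4, 2, 1]
print(tt_storage(N, R), prod(N))  # 212 15120
print(tt_compression(N, R))
```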

Args

source : torch.tensor or list[torch.tensor] or numpy.array or None
the input tensor in full format or the cores. If a torch.tensor or numpy.array is provided, its TT decomposition is computed.
shape : list[int] or list[tuple[int]], optional
the shape (if it differs from the one provided). It is mandatory for the TT-matrix case. Defaults to None.
eps : float, optional
tolerance of the TT approximation. Defaults to 1e-10.
rmax : int or list[int], optional
maximum rank (either a list of integer or an integer). Defaults to the maximum possible integer.

Raises

RankMismatch
Ranks of the given cores do not match (the last dimension of each core must equal the first dimension of the next one).
InvalidArguments
Invalid input: TT-cores have to be either 4d or 3d.
InvalidArguments
Check the ranks and the mode size.
NotImplementedError
Function only implemented for torch tensors, numpy arrays, list of cores as torch tensors and None
class TT():

    #cores : list[tn.tensor]
    #""" The TT cores as a list of `torch.tensor` instances."""
    
    @property
    def is_ttm(self):
        """
        Check whether the instance is a TT operator or not.

        Returns:
            bool: the flag.
        """
        return self.__is_ttm

    @property 
    def M(self):
        """
        Return the "row" shape in case of TT matrices.

        Raises:
            IncompatibleTypes: The field M is defined only for TT matrices.

        Returns:
            list[int]: the shape.
        """
        if not self.__is_ttm:
            raise IncompatibleTypes("The field M is defined only for TT matrices.")
        return self.__M.copy()

    @property 
    def N(self):
        """
        Return the shape of a tensor or the "column" shape of a TT operator.

        Returns:
            list[int]: the shape.
        """
        return self.__N.copy()
    
    @property
    def R(self):
        """
        The rank of the TT decomposition.
        Its length is `len(N)+1`.

        Returns:
            list[int]: the rank.
        """
        return self.__R.copy()

    def __init__(self, source, shape=None, eps=1e-10, rmax=sys.maxsize):
        """
        Constructor of the TT class. Can convert a full tensor (`torch.tensor` or `numpy.array`) to the TT format.
        In the case of tensor operators of full shape `M1 x ... x Md x N1 x ... x Nd`, the shape must be specified as a list of tuples `[(M1,N1),...,(Md,Nd)]`.
        A TT-object can also be computed from cores if the list of cores is passed as argument.
        If None is provided, an empty tensor is created.
        
        The TT decomposition of a tensor is
        
        \(\\mathsf{x}_{i_1i_2\\ldots i_d}=\\sum\\limits_{r_1,\\ldots,r_{d-1}=1}^{R_1,\\ldots,R_{d-1}} \\mathsf{x}^{(1)}_{1i_1r_1}\\mathsf{x}^{(2)}_{r_1i_2r_2}\\cdots\\mathsf{x}^{(d)}_{r_{d-1}i_d1}\),
        
        where \(\\{\\mathsf{x}^{(k)}\\}_{k=1}^d\) are the TT cores and \(\\mathbf{R}=(1,R_1,...,R_{d-1},1)\) is the TT rank.
        Using the constructor, a TT decomposition of a tensor can be computed. The TT cores are stored as a list in `torchtt.TT.cores`.   
        This class implements basic operators such as `+,-,*,/,@,**` (add, subtract, elementwise multiplication, elementwise division, matrix vector product and Kronecker product) between TT instances.
        The `examples/` folder serves as a tutorial for all the features of the toolbox.
        
        Examples:
            ```
            import torchtt
            import torch
            x = torch.reshape(torch.arange(0,128,dtype = torch.float64),[8,4,4])
            xtt = torchtt.TT(x)
            ytt = torchtt.TT(torch.squeeze(x),[8,4,4])
            # create a TT matrix
            A = torch.reshape(torch.arange(0,20160,dtype = torch.float64),[3,5,7,4,6,8])
            Att = torchtt.TT(A,[(3,4),(5,6),(7,8)])
            print(Att)        
            ```
            
        Args:
            source (torch.tensor or list[torch.tensor] or numpy.array or None): the input tensor in full format or the cores. If a `torch.tensor` or `numpy.array` is provided, its TT decomposition is computed.
            shape (list[int] or list[tuple[int]], optional): the shape (if it differs from the one provided). It is mandatory for the TT-matrix case. Defaults to None.
            eps (float, optional): tolerance of the TT approximation. Defaults to 1e-10.
            rmax (int or list[int], optional): maximum rank (either a list of integer or an integer). Defaults to the maximum possible integer.

        Raises:
            RankMismatch: Ranks of the given cores do not match (the last dimension of each core must equal the first dimension of the next one).
            InvalidArguments: Invalid input: TT-cores have to be either 4d or 3d.
            InvalidArguments: Check the ranks and the mode size.
            NotImplementedError: Function only implemented for torch tensors, numpy arrays, list of cores as torch tensors and None
   
        """
        
        if source is None:
            # empty TT
            self.cores = []
            self.__M = []
            self.__N = []
            self.__R = [1,1]
            self.__is_ttm = False
            
        elif isinstance(source, list):
            # tt cores were passed directly
            
            # check if sizes are consistent
            prev = 1
            N = []
            M = []
            R = [source[0].shape[0]]
            d = len(source)
            for i in range(len(source)):
                s = source[i].shape
                
                if s[0] != R[-1]:
                    raise RankMismatch("Ranks of the given cores do not match: for core number %d previous rank is %d and current rank is %d."%(i,R[-1],s[0]))
                if len(s) == 3:
                    R.append(s[2])
                    N.append(s[1])
                elif len(s)==4:
                    R.append(s[3])
                    M.append(s[1])
                    N.append(s[2])
                else:
                    raise InvalidArguments("Invalid input: TT-cores have to be either 4d or 3d.")
            
            if len(N) != d or len(R) != d+1 or R[0] != 1 or R[-1] != 1 or (len(M)!=0 and len(M)!=len(N)) :
                raise InvalidArguments("Check the ranks and the mode size.")
            
            self.cores = source
            self.__R = R
            self.__N = N
            if len(M) == len(N):
                self.__M = M
                self.__is_ttm = True
            else:
                self.__is_ttm = False
            self.shape = [ (m,n) for m,n in zip(self.__M,self.__N) ] if self.__is_ttm else [n for n in self.N]     

        elif tn.is_tensor(source):
            if shape == None:
                # no size is given. Deduce it from the tensor. No TT-matrix in this case.
                self.__N = list(source.shape)
                if len(self.__N)>1:
                    self.cores, self.__R = to_tt(source,self.__N,eps,rmax,is_sparse=False)
                else:    
                    self.cores = [tn.reshape(source,[1,self.__N[0],1])]
                    self.__R = [1,1]
                self.__is_ttm = False
            elif isinstance(shape,list) and isinstance(shape[0],tuple):
                # if the size contains tuples, we have a TT-matrix.
                if len(shape) > 1:
                    self.__M = [s[0] for s in shape]
                    self.__N = [s[1] for s in shape]
                    self.cores, self.__R = mat_to_tt(source, self.__M, self.__N, eps, rmax)
                    self.__is_ttm = True
                else:
                    self.__M = [shape[0][0]]
                    self.__N = [shape[0][1]]
                    self.cores, self.__R = [tn.reshape(source,[1,shape[0][0],shape[0][1],1])], [1,1]
                    self.__is_ttm = True
            else:
                # TT-decomposition with prescribed size
                # perform reshape first
                self.__N = shape
                self.cores, self.__R = to_tt(tn.reshape(source,shape),self.__N,eps,rmax,is_sparse=False)
                self.__is_ttm = False
            self.shape = [ (m,n) for m,n in zip(self.__M,self.__N) ] if self.__is_ttm else [n for n in self.N]     

        elif isinstance(source, np.ndarray):
            source = tn.tensor(source) 
                    
            if shape == None:
                # no size is given. Deduce it from the tensor. No TT-matrix in this case.
                self.__N = list(source.shape)
                if len(self.__N)>1:
                    self.cores, self.__R = to_tt(source,self.__N,eps,rmax,is_sparse=False)
                else:    
                    self.cores = [tn.reshape(source,[1,self.__N[0],1])]
                    self.__R = [1,1]
                self.__is_ttm = False
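            elif isinstance(shape,list) and isinstance(shape[0],tuple):
                # if the size contains tuples, we have a TT-matrix.
                if len(shape) > 1:
                    self.__M = [s[0] for s in shape]
                    self.__N = [s[1] for s in shape]
                    self.cores, self.__R = mat_to_tt(source, self.__M, self.__N, eps, rmax)
                else:
                    # single mode: the whole matrix is one core (same as the torch.tensor branch)
                    self.__M = [shape[0][0]]
                    self.__N = [shape[0][1]]
                    self.cores, self.__R = [tn.reshape(source,[1,shape[0][0],shape[0][1],1])], [1,1]
                self.__is_ttm = True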
            else:
                # TT-decomposition with prescribed size
                # perform reshape first
                self.__N = shape
                self.cores, self.__R = to_tt(tn.reshape(source,shape),self.__N,eps,rmax,is_sparse=False)
                self.__is_ttm = False
            self.shape = [ (m,n) for m,n in zip(self.__M,self.__N) ] if self.__is_ttm else [n for n in self.N]     
        else:
            raise NotImplementedError("Function only implemented for torch tensors, numpy arrays, list of cores as torch tensors and None.")

    def cuda(self, device = None):
        """
        Return a torchtt.TT object on the CUDA device by copying all the cores to the GPU.

        Args:
            device (torch.device, optional): The CUDA device (None for the current CUDA device). Defaults to None.

        Returns:
            torchtt.TT: The TT-object. The TT-cores are on CUDA.
        """
         
        
        t = TT([ c.cuda(device) for c in self.cores])

        return t

    def cpu(self):
        """
        Retrieve the cores from the GPU.

        Returns:
            torchtt.TT: The TT-object on CPU.
        """

        
        return TT([ c.cpu() for c in self.cores])

    def is_cuda(self):
        """
        Return True if the tensor is on GPU.

        Returns:
            bool: Is the torchtt.TT on GPU or not.
        """
        return all([c.is_cuda for c in self.cores])

    
    def to(self, device = None, dtype = None):
        """
        Moves the TT instance to the given device with the given dtype.

        Args:
            device (torch.device, optional): The desired device. If None is provided, the device is not changed. Defaults to None.
            dtype (torch.dtype, optional): The desired dtype (torch.float64, torch.float32,...). If None is provided, the dtype is not changed. Defaults to None.

        Returns:
            torchtt.TT: the TT object with the cores on the given device and with the given dtype.
        """
        return TT( [ c.to(device=device,dtype=dtype) for c in self.cores])

    def detach(self):
        """
        Detaches the TT tensor. Similar to torch.Tensor.detach().

        Returns:
            torchtt.TT: the detached tensor.
        """
        return TT([c.detach() for c in self.cores])
        
    def clone(self):
        """
        Clones the torchtt.TT instance. Similar to torch.Tensor.clone().

        Returns:
            torchtt.TT: the cloned TT object.
        """
        return TT([c.clone() for c in self.cores]) 

    def full(self):       
        """
        Return the full tensor.
        In case of a TTM, the result has the shape M1 x M2 x ... x Md x N1 x N2 x ... x Nd.

        Returns:
            torch.tensor: the full tensor.
        """
        if self.__is_ttm:
            # the case of tt-matrix
            tfull = self.cores[0][0,:,:,:]
            for i in  range(1,len(self.cores)-1) :
                tfull = tn.einsum('...i,ijkl->...jkl',tfull,self.cores[i])
            if len(self.__N) != 1:
                tfull = tn.einsum('...i,ijk->...jk',tfull,self.cores[-1][:,:,:,0])
                tfull = tn.permute(tfull,list(np.arange(len(self.__N))*2)+list(np.arange(len(self.N))*2+1))
            else:
                tfull = tfull[:,:,0]
        else:
            # the case of a normal tt
            tfull = self.cores[0][0,:,:]
            for i in  range(1,len(self.cores)-1) :
                tfull = tn.einsum('...i,ijk->...jk',tfull,self.cores[i])
            if len(self.__N) != 1:
                tfull = tn.einsum('...i,ij->...j',tfull,self.cores[-1][:,:,0])
            else:
                tfull = tn.squeeze(tfull)
        return tfull
    
    def numpy(self):
        """
        Return the full tensor as a numpy.array.
        In case of a TTM, the result has the shape M1 x M2 x ... x Md x N1 x N2 x ... x Nd.
        If the cores are part of an autograd graph (they require gradients), an error is raised; call detach() first.
        
        Returns:
            numpy.array: the full tensor in numpy.
        """
        return self.full().cpu().numpy()
    
    def __repr__(self):
        """
        Returns a string with the shape, ranks, device, dtype, number of stored entries and compression ratio of the TT object.

        Returns:
            string: the string representation of a torchtt.TT
        """
        
        if self.__is_ttm:
            output = 'TT-matrix' 
            output += ' with sizes and ranks:\n'
            output += 'M = ' + str(self.__M) + '\nN = ' + str(self.__N) + '\n'
            output += 'R = ' + str(self.__R) + '\n'
            output += 'Device: '+str(self.cores[0].device)+', dtype: '+str(self.cores[0].dtype)+'\n'
            entries = sum([tn.numel(c)  for c in self.cores])
            output += '#entries ' + str(entries) +' compression ' + str(entries/np.prod(np.array(self.__N,dtype=np.float64)*np.array(self.__M,dtype=np.float64))) +  '\n'
        else:
            output = 'TT'
            output += ' with sizes and ranks:\n'
            output += 'N = ' + str(self.__N) + '\n'
            output += 'R = ' + str(self.__R) + '\n\n'
            output += 'Device: '+str(self.cores[0].device)+', dtype: '+str(self.cores[0].dtype)+'\n'
            entries = sum([tn.numel(c) for c in self.cores])
            output += '#entries ' + str(entries) +' compression '  + str(entries/np.prod(np.array(self.__N,dtype=np.float64))) + '\n'
        
        return output
    
    def __radd__(self,other):
        """
        Addition in the TT format. Implements the "+" operator when the left operand is not a torchtt.TT object (e.g. `2.0 + x`).

        Args:
            other (float | int | torch.tensor): the first operand. If a `torch.tensor` is provided, it must have 1 element.

        Returns:
            torchtt.TT: the result.
        """
        
        return self.__add__(other)

    def __add__(self,other):
        """
        Addition in the TT format. Implements the "+" operator. The following type pairs are supported:
            - both operands are TT-tensors.
            - both operands are TT-matrices.
            - first operand is a TT-tensor or a TT-matrix and the second is a scalar (int, float or a torch.tensor with 1 element).
        For TT-tensors, the broadcasting rules from `torch` apply.
        
        Args:
            other (torchtt.TT | float | int | torch.tensor): the second operand. If a `torch.tensor` is provided, it must have 1 element.
            
        Raises:
            ShapeMismatch: Dimension mismatch.
            IncompatibleTypes: Addition between a tensor and a matrix is not defined.
            InvalidArguments: Second term is incompatible.

        Returns:
            torchtt.TT: the result.
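
        Examples:
            The padding in this method builds the standard rank-concatenated cores of a sum: [G_1 H_1] for the first core, block-diagonal middle cores and [G_d; H_d] for the last core, so the ranks of the operands add up. A standalone numpy sketch for two TT-tensors of equal shape (independent of torchtt; `tt_add` and `full3` are hypothetical helpers):
            ```python
            import numpy as np

            def tt_add(x_cores, y_cores):
                d = len(x_cores)
                out = []
                for k, (g, h) in enumerate(zip(x_cores, y_cores)):
                    if d == 1:
                        out.append(g + h)
                    elif k == 0:
                        # first core: concatenate along the right rank
                        out.append(np.concatenate([g, h], axis=2))
                    elif k == d - 1:
                        # last core: concatenate along the left rank
                        out.append(np.concatenate([g, h], axis=0))
                    else:
                        # middle cores: block diagonal in the two rank indices
                        c = np.zeros((g.shape[0] + h.shape[0], g.shape[1], g.shape[2] + h.shape[2]))
                        c[:g.shape[0], :, :g.shape[2]] = g
                        c[g.shape[0]:, :, g.shape[2]:] = h
                        out.append(c)
                return out

            def full3(cores):
                # full tensor of a TT with 3 cores (the boundary ranks of size 1 are summed away)
                return np.einsum('aib,bjc,ckd->ijk', *cores)

            rng = np.random.default_rng(1)
            x_cores = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 2), (2, 5, 1)]]
            y_cores = [rng.standard_normal(s) for s in [(1, 3, 3), (3, 4, 1), (1, 5, 1)]]
            z_cores = tt_add(x_cores, y_cores)
            ```
            Because the ranks add up, the result is typically rounded afterwards to recompress it.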
        """

        if np.isscalar(other) or ( tn.is_tensor(other) and tn.numel(other) == 1):
            # the second term is a scalar
            cores =  []
            
            for i in range(len(self.__N)):
                if self.__is_ttm:
                    pad1 = (0,0 if i == len(self.__N)-1 else 1 , 0,0 , 0,0 , 0,0 if i==0 else 1)
                    pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0,0 , 0 if i==0 else self.R[i],0)
                    othr = tn.ones([1,1,1,1],dtype=self.cores[i].dtype,device=self.cores[i].device) * (other if i ==0 else 1)
                else:
                    pad1 = (0,0 if i == len(self.__N)-1 else 1 , 0,0 , 0,0 if i==0 else 1)
                    pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0 if i==0 else self.R[i],0)
                    othr = tn.ones([1,1,1],dtype=self.cores[i].dtype,device=self.cores[i].device) * (other if i ==0 else 1)
                

                cores.append(tnf.pad(self.cores[i],pad1)+tnf.pad(othr,pad2))

                
            result = TT(cores)
        elif isinstance(other,TT):
            # second term is TT object
            if self.__is_ttm and other.is_ttm:
                # both are TT-matrices
                if self.__M != other.M or self.__N != other.N:
                    raise ShapeMismatch("Shapes are incompatible: first operand is %s x %s, second operand is %s x %s."%(str(self.M), str(self.N), str(other.M), str(other.N)))
                    
                cores = []
                for i in range(len(self.__N)):
                    pad1 = (0,0 if i == len(self.__N)-1 else other.R[i+1], 0,0 , 0,0 , 0,0 if i==0 else other.R[i])
                    pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0,0 , 0 if i==0 else self.R[i],0)
                    cores.append(tnf.pad(self.cores[i],pad1)+tnf.pad(other.cores[i],pad2))
                    
                result = TT(cores)
                
            elif self.__is_ttm==False and other.is_ttm==False:
                # normal tensors in TT format.
                if self.__N == other.N:                  
                    cores = []
                    for i in range(len(self.__N)):
                        pad1 = (0,0 if i == len(self.__N)-1 else other.R[i+1] , 0,0 , 0,0 if i==0 else other.R[i])
                        pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0 if i==0 else self.R[i],0)
                        cores.append(tnf.pad(self.cores[i],pad1)+tnf.pad(other.cores[i],pad2))
                else:
                    if len(self.__N) < len(other.N):
                        raise ShapeMismatch("Shapes are incompatible: first operand is %s, second operand is %s."%(str(self.N), str(other.N)))

                    cores = []
                    for i in range(len(self.cores)-len(other.cores)):
                        pad1 = (0,0 if i == len(self.__N)-1 else 1 , 0,0 , 0,0 if i==0 else 1)
                        pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0 if i==0 else self.R[i],0)
                        cores.append(tnf.pad(self.cores[i],pad1)+tnf.pad(tn.ones((1,self.__N[i],1), dtype = self.cores[i].dtype, device = self.cores[i].device),pad2))
                        
                    for k,i in zip(range(len(other.cores)), range(len(self.cores)-len(other.cores), len(self.cores))):
                        if other.N[k] == self.__N[i]:
                            pad1 = (0,0 if i == len(self.__N)-1 else other.R[k+1] , 0,0 , 0,0 if i==0 else other.R[k])
                            pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0 if i==0 else self.R[i],0)
                            cores.append(tnf.pad(self.cores[i],pad1)+tnf.pad(other.cores[k],pad2))
                            
                        elif other.N[k] == 1:
                            pad1 = (0,0 if i == len(self.__N)-1 else other.R[k+1] , 0,0 , 0,0 if i==0 else other.R[k])
                            pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0 if i==0 else self.R[i],0)
                            cores.append(tnf.pad(self.cores[i],pad1)+tnf.pad(tn.tile(other.cores[k],(1,self.__N[i],1)),pad2))
                        else:
                            raise ShapeMismatch("Shapes are incompatible: first operand is %s, second operand is %s."%(str(self.N), str(other.N)))
                            
                    
                result = TT(cores)
                
                
            else:
                # incompatible types 
                raise IncompatibleTypes('Addition between a tensor and a matrix is not defined.')
        else:
            raise InvalidArguments('Second term is incompatible.')
            
        return result
    
    def __rsub__(self,other):
        """
        Subtraction in the TT format when the left operand is not a torchtt.TT object, i.e. it computes `other - self`. Implements the "-" operator.

        Args:
            other (float | int | torch.tensor): the first operand. If a `torch.tensor` is provided, it must have 1 element.

        Returns:
            torchtt.TT: the result.
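
        Examples:
            Since the full tensor is linear in each core, negating a single core negates the whole tensor; this is why only the first core of `self - other` changes sign here. A standalone numpy sketch (independent of torchtt):
            ```python
            import numpy as np

            rng = np.random.default_rng(2)
            cores = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 1)]]
            x = np.einsum('aib,bjc->ij', *cores)
            neg = [-cores[0]] + cores[1:]  # flip the sign of the first core only
            y = np.einsum('aib,bjc->ij', *neg)
            ```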
        """
        
        T = self.__sub__(other)
        T.cores[0] = -T.cores[0]
        return T
    
    def __sub__(self,other):
        """
        Subtraction in the TT format. Implements the "-" operator.
        Possible second operands are: torchtt.TT, float, int, torch.tensor with 1 element.
        For TT-tensors, the broadcasting rules from `torch` apply as well.
        
        Args:
            other (torchtt.TT | float | int | torch.tensor): the second operand. If a `torch.tensor` is provided, it must have 1 element.

        Raises:
            ShapeMismatch: The shapes of the TT-matrices (both M and N) must be equal.
            ShapeMismatch: Dimension mismatch.
            IncompatibleTypes: Subtraction between a tensor and a matrix is not defined.
            InvalidArguments: Second term is incompatible (must be either torchtt.TT or int or float or torch.tensor with 1 element).

        Returns:
            torchtt.TT: the result.
        """
        if np.isscalar(other) or ( tn.is_tensor(other) and tn.numel(other) == 1):
            # the second term is a scalar
            cores =  []
            
            for i in range(len(self.__N)):
                if self.__is_ttm:
                    pad1 = (0,0 if i == len(self.__N)-1 else 1 , 0,0 , 0,0 , 0,0 if i==0 else 1)
                    pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0,0 , 0 if i==0 else self.R[i],0)
                    othr = tn.ones([1,1,1,1],dtype=self.cores[i].dtype,device=self.cores[i].device) * (-other if i ==0 else 1)
                else:
                    pad1 = (0,0 if i == len(self.__N)-1 else 1 , 0,0 , 0,0 if i==0 else 1)
                    pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0 if i==0 else self.R[i],0)
                    othr = tn.ones([1,1,1],dtype=self.cores[i].dtype,device=self.cores[i].device) * (-other if i ==0 else 1)
                cores.append(tnf.pad(self.cores[i],pad1)+tnf.pad(othr,pad2))
            result = TT(cores)

        elif isinstance(other,TT):
            # second term is TT object
            if self.__is_ttm and other.is_ttm:
                # both are TT-matrices
                if self.__M != other.M or self.__N != other.N:
                    raise ShapeMismatch("Shapes are incompatible: first operand is %s x %s, second operand is %s x %s."%(str(self.M), str(self.N), str(other.M), str(other.N)))
                
                cores = []
                for i in range(len(self.__N)):
                    pad1 = (0,0 if i == len(self.__N)-1 else other.R[i+1] , 0,0 , 0,0 , 0,0 if i==0 else other.R[i])
                    pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0,0 , 0 if i==0 else self.R[i],0)
                    cores.append(tnf.pad(self.cores[i],pad1)+tnf.pad(-other.cores[i] if i==0 else other.cores[i],pad2))
                    
                result = TT(cores)
                
            elif self.__is_ttm==False and other.is_ttm==False:
                # normal tensors in TT format.
                if self.__N == other.N:                  
                    cores = []
                    for i in range(len(self.__N)):
                        pad1 = (0,0 if i == len(self.__N)-1 else other.R[i+1] , 0,0 , 0,0 if i==0 else other.R[i])
                        pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0 if i==0 else self.R[i],0)
                        cores.append(tnf.pad(self.cores[i], pad1)+tnf.pad(-other.cores[i] if i==0 else other.cores[i],pad2))
                else:
                    if len(self.__N) < len(other.N):
                        raise ShapeMismatch("Shapes are incompatible: first operand is %s, second operand is %s."%(str(self.N), str(other.N)))

                    cores = []
                    for i in range(len(self.cores)-len(other.cores)):
                        pad1 = (0,0 if i == len(self.__N)-1 else 1 , 0,0 , 0,0 if i==0 else 1)
                        pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0 if i==0 else self.R[i],0)
                        cores.append(tnf.pad(self.cores[i],pad1)+tnf.pad((-1 if i==0 else 1)*tn.ones((1,self.__N[i],1), dtype = self.cores[i].dtype, device = self.cores[i].device),pad2))
                        
                    for k,i in zip(range(len(other.cores)), range(len(self.cores)-len(other.cores), len(self.cores))):
                        if other.N[k] == self.__N[i]:
                            pad1 = (0,0 if i == len(self.__N)-1 else other.R[k+1] , 0,0 , 0,0 if i==0 else other.R[k])
                            pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0 if i==0 else self.R[i],0)
                            cores.append(tnf.pad(self.cores[i],pad1)+tnf.pad(-other.cores[k] if i==0 else other.cores[k],pad2))
                            
                        elif other.N[k] == 1:
                            pad1 = (0,0 if i == len(self.__N)-1 else other.R[k+1] , 0,0 , 0,0 if i==0 else other.R[k])
                            pad2 = (0 if i == len(self.__N)-1 else self.__R[i+1],0 , 0,0 , 0 if i==0 else self.R[i],0)
                            cores.append(tnf.pad(self.cores[i],pad1)+tnf.pad(tn.tile(-other.cores[k] if i==0 else other.cores[k],(1,self.__N[i],1)),pad2))
                        else:
                            raise ShapeMismatch("Shapes are incompatible: first operand is %s, second operand is %s."%(str(self.N), str(other.N)))
                            
                    
                result = TT(cores)
                
                
            else:
                # incompatible types 
                raise IncompatibleTypes('Subtraction between a tensor and a matrix is not defined.')
        else:
            raise InvalidArguments('Second term is incompatible (must be either torchtt.TT or int or float or torch.tensor with 1 element).')
            
        return result
    
    def __rmul__(self,other):
        """
        Elementwise multiplication in the TT format.
        This implements the "*" operator when the left operand is not a torchtt.TT object (e.g. `2.0 * x`).
        The following operand pairs are supported:

         * TT tensor and TT tensor
         * TT matrix and TT matrix
         * TT tensor or TT matrix and scalar (int, float or torch.tensor with 1 element)

        Args:
            other (float | int | torch.tensor): the first operand. If a `torch.tensor` is provided, it must have 1 element.

        Raises:
            ShapeMismatch: Shapes must be equal.
            IncompatibleTypes: Second operand must be the same type as the first (both should be either TT matrices or TT tensors).
            InvalidArguments: Second operand must be of type: torchtt.TT, float, int or torch.tensor.

        Returns:
            torchtt.TT: the result.
        """
        
        return self.__mul__(other)
        
    def __mul__(self,other):
        """
        Elementwise multiplication in the TT format.
        This implements the "*" operator.
        The following operand pairs are supported:
         - TT tensor and TT tensor
         - TT matrix and TT matrix
         - TT tensor or TT matrix and scalar (int, float or torch.tensor with 1 element)
        For TT tensors, the broadcasting rules are the same as in torch (see [here](https://pytorch.org/docs/stable/notes/broadcasting.html)).
        
        Args:
            other (torchtt.TT | float | int | torch.tensor): the second operand. If a `torch.tensor` is provided, it must have 1 element.

        Raises:
            ShapeMismatch: Shapes are incompatible (see the broadcasting rules).
            IncompatibleTypes: Second operand must be the same type as the first (both should be either TT matrices or TT tensors).
            InvalidArguments: Second operand must be of type: torchtt.TT, float, int or torch.tensor.

        Returns:
            torchtt.TT: the result.
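
        Examples:
            For TT tensors, every slice of a result core is the Kronecker product of the corresponding slices of the operand cores, so the ranks multiply. A standalone numpy sketch mirroring the einsum used here (independent of torchtt; `tt_hadamard` and `full3` are hypothetical helpers):
            ```python
            import numpy as np

            def tt_hadamard(x_cores, y_cores):
                out = []
                for g, h in zip(x_cores, y_cores):
                    # out[(a,m), i, (b,n)] = g[a, i, b] * h[m, i, n]
                    c = np.einsum('aib,min->amibn', g, h)
                    out.append(c.reshape(g.shape[0] * h.shape[0], g.shape[1], g.shape[2] * h.shape[2]))
                return out

            def full3(cores):
                # full tensor of a TT with 3 cores
                return np.einsum('aib,bjc,ckd->ijk', *cores)

            rng = np.random.default_rng(3)
            x_cores = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 2), (2, 5, 1)]]
            y_cores = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 3), (3, 5, 1)]]
            z_cores = tt_hadamard(x_cores, y_cores)
            ```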
        """
       
        # elementwise multiplication
        if isinstance(other, TT):
            if self.__is_ttm and other.is_ttm:
                if self.__N == other.N and self.__M == other.M:
                    # raise ShapeMismatch('Shapes must be equal.') 
                    
                    cores_new = []
                    
                    for i in range(len(self.cores)):
                        core = tn.reshape(tn.einsum('aijb,mijn->amijbn',self.cores[i],other.cores[i]),[self.__R[i]*other.R[i],self.__M[i],self.__N[i],self.R[i+1]*other.R[i+1]])
                        cores_new.append(core)
                        
                else:
                    raise ShapeMismatch("Shapes are incompatible: first operand is %s x %s, second operand is %s x %s."%(str(self.M), str(self.N), str(other.M), str(other.N)))
                    # if len(self.__N) < len(other.N):
                    #     raise ShapeMismatch("Shapes are incompatible: first operand is %s x %s, second operand is %s x %s."%(str(self.M), str(self.N), str(other.M), str(other.N)))
                    
                    # cores_new = []
                    # raise NotImplementedError("Not yet implemented.")
                    
            elif self.__is_ttm == False and other.is_ttm == False:
                # broadcasting rules have to be applied. Separate if/else to make the non-broadcasting case the fastest.
                if self.__N == other.N:  
                    cores_new = []
                    
                    for i in range(len(self.cores)):
                        core = tn.reshape(tn.einsum('aib,min->amibn',self.cores[i],other.cores[i]),[self.__R[i]*other.R[i],self.__N[i],self.R[i+1]*other.R[i+1]])
                        cores_new.append(core)
                else:
                    if len(self.__N) < len(other.N):
                        raise ShapeMismatch("Shapes are incompatible: first operand is %s, second operand is %s."%(str(self.N), str(other.N)))

                    cores_new = []
                    for i in range(len(self.cores)-len(other.cores)):
                        cores_new.append(self.cores[i]*1)
                        
                    for k,i in zip(range(len(other.cores)), range(len(self.cores)-len(other.cores), len(self.cores))):
                        if other.N[k] == self.__N[i]:
                            core = tn.reshape(tn.einsum('aib,min->amibn',self.cores[i],other.cores[k]),[self.__R[i]*other.R[k],self.__N[i],self.R[i+1]*other.R[k+1]])
                        elif other.N[k] == 1:
                            core = tn.reshape(tn.einsum('aib,mn->amibn',self.cores[i],other.cores[k][:,0,:]),[self.__R[i]*other.R[k],self.__N[i],self.R[i+1]*other.R[k+1]])
                        else:
                            raise ShapeMismatch("Shapes are incompatible: first operand is %s, second operand is %s."%(str(self.N), str(other.N)))
                            
                        cores_new.append(core)
                    
            else:
                raise IncompatibleTypes('Second operand must be the same type as the first (both should be either TT matrices or TT tensors).')
            result = TT(cores_new)

        elif isinstance(other,int) or isinstance(other,float) or tn.is_tensor(other):
            if other != 0:
                cores_new = [c+0 for c in self.cores]
                cores_new[0] *= other
                result = TT(cores_new)
            else:
                result = TT([tn.zeros((1,self.M[i],self.N[i],1) if self.is_ttm else (1,self.N[i],1), device = self.cores[0].device, dtype = self.cores[0].dtype) for i in range(len(self.N))])
                # result = zeros([(m,n) for m,n in zip(self.M,self.N)] if self.is_ttm else self.N, device=self.cores[0].device)
        else:
            raise InvalidArguments('Second operand must be of type: TT, float, int or torch.tensor.')
                    
        return result
        
    def __matmul__(self,other):
        """
        Matrix-vector and matrix-matrix multiplication in the TT format. Implements the "@" operator.
        Supported operands (summation over repeated indices is implied):
            - TT-matrix @ TT-tensor -> TT-tensor: y_i = A_ij * x_j
            - TT-tensor @ TT-matrix -> TT-tensor: y_j = x_i * A_ij 
            - TT-matrix @ TT-matrix -> TT-matrix: Y_ij = A_ik * B_kj
            - TT-matrix @ torch.tensor -> torch.tensor: y_bi = A_ij * x_bj 
        In the last case, the multiplication is performed along the last modes and a full torch.tensor is returned.

        Args:
            other (torchtt.TT | torch.tensor): the second operand.

        Raises:
            ShapeMismatch: Shapes do not match.
            InvalidArguments: Wrong arguments.

        Returns:
            torchtt.TT | torch.tensor: the result. Can be full tensor if the second operand is full tensor.
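
        Examples:
            A standalone numpy sketch of the TT-matrix @ TT-tensor case (independent of torchtt; `ttm_matvec` is a hypothetical helper). Each new core contracts the column mode of the matrix core with the mode of the vector core, and the ranks multiply:
            ```python
            import numpy as np

            def ttm_matvec(a_cores, x_cores):
                out = []
                for a, x in zip(a_cores, x_cores):
                    # a: (ra, m, n, ra'), x: (rx, n, rx'); sum over the shared index n
                    c = np.einsum('ijkl,mkp->imjlp', a, x)
                    out.append(c.reshape(a.shape[0] * x.shape[0], a.shape[1], a.shape[3] * x.shape[2]))
                return out

            rng = np.random.default_rng(4)
            a_cores = [rng.standard_normal(s) for s in [(1, 2, 3, 2), (2, 4, 5, 1)]]
            x_cores = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 5, 1)]]
            y_cores = ttm_matvec(a_cores, x_cores)

            # dense check: A has shape M1 x M2 x N1 x N2, x has shape N1 x N2
            A = np.einsum('aijb,bklc->ikjl', *a_cores)
            x_full = np.einsum('aib,bjc->ij', *x_cores)
            y_full = np.einsum('aib,bjc->ij', *y_cores)
            ```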
        """
     
        if self.__is_ttm and tn.is_tensor(other):
            if self.__N != list(other.shape)[-len(self.N):]:
                raise ShapeMismatch("Shapes do not match.")
            result = dense_matvec(self.cores,other) 
            return result

        elif not isinstance(other,TT):
            raise InvalidArguments("Wrong arguments.")
        elif self.__is_ttm and other.is_ttm == False:
            # matrix-vector multiplication
            if self.__N != other.N:
                raise ShapeMismatch("Shapes do not match.")
                
            cores_new = []
            
            for i in range(len(self.cores)):
                core = tn.reshape(tn.einsum('ijkl,mkp->imjlp',self.cores[i],other.cores[i]),[self.cores[i].shape[0]*other.cores[i].shape[0],self.cores[i].shape[1],self.cores[i].shape[3]*other.cores[i].shape[2]])
                cores_new.append(core)
            
            
        elif self.__is_ttm and other.is_ttm:
            # multiplication between 2 TT-matrices
            if self.__N != other.M:
                raise ShapeMismatch("Shapes do not match.")
                
            cores_new = []
            
            for i in range(len(self.cores)):
                core = tn.reshape(tn.einsum('ijkl,mknp->imjnlp',self.cores[i],other.cores[i]),[self.cores[i].shape[0]*other.cores[i].shape[0],self.cores[i].shape[1],other.cores[i].shape[2],self.cores[i].shape[3]*other.cores[i].shape[3]])
                cores_new.append(core)
        elif self.__is_ttm == False and other.is_ttm:
            # vector-matrix multiplication
            if self.__N != other.M:
                raise ShapeMismatch("Shapes do not match.")
                
            cores_new = []
            
            for i in range(len(self.cores)):
                core = tn.reshape(tn.einsum('mkp,ikjl->imjlp',self.cores[i],other.cores[i]),[self.cores[i].shape[0]*other.cores[i].shape[0],other.cores[i].shape[2],self.cores[i].shape[2]*other.cores[i].shape[3]])
                cores_new.append(core)
        else:
            raise InvalidArguments("Wrong arguments.")
            
        result = TT(cores_new)
        return result

    def fast_matvec(self,other, eps = 1e-12, initial = None, nswp = 20, verb = False, use_cpp = True):
        """
        Fast matrix-vector multiplication A @ x using DMRG iterations. Usually faster than computing the exact product A @ x and rounding it afterwards.

        Args:
            other (torchtt.TT): the TT tensor.
            eps (float, optional): relative accuracy for DMRG. Defaults to 1e-12.
            initial (None|torchtt.TT, optional): an approximation of the product (None means random initial guess). Defaults to None.
            nswp (int, optional): number of DMRG iterations. Defaults to 20.
            verb (bool, optional): show info for debug. Defaults to False.
            use_cpp (bool, optional): use the C++ implementation if available. Defaults to True.

        Raises:
            InvalidArguments: Second operand has to be TT object.
            IncompatibleTypes: First operand should be a TT matrix and second a TT vector.

        Returns:
            torchtt.TT: the result.
        """
        
        if not isinstance(other,TT):
            raise InvalidArguments('Second operand has to be TT object.')
        if not self.__is_ttm or other.is_ttm:
            raise IncompatibleTypes('First operand should be a TT matrix and second a TT vector.')
            
        return dmrg_matvec(self, other, y0 = initial, eps = eps, verb = verb, nswp = nswp, use_cpp = use_cpp)

    def apply_mask(self,indices):
        """
        Evaluates the tensor at the given list of multi-indices.

        Examples:
            ```
            x = torchtt.random([10,12,14],[1,4,5,1])
            indices = torch.tensor([[0,0,0],[1,2,3],[1,1,1]])
            val = x.apply_mask(indices)
            ```
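            A standalone numpy sketch of what is computed (independent of torchtt; `tt_entries` is a hypothetical helper that handles one multi-index at a time for clarity): each entry is the product of the core slices selected by its indices.
            ```python
            import numpy as np

            def tt_entries(cores, indices):
                values = []
                for idx in indices:
                    v = np.ones(1)
                    for core, i in zip(cores, idx):
                        # multiply the running row vector by the i-th slice of the core
                        v = v @ core[:, i, :]
                    values.append(v[0])
                return np.array(values)

            rng = np.random.default_rng(5)
            cores = [rng.standard_normal(s) for s in [(1, 10, 4), (4, 12, 5), (5, 14, 1)]]
            indices = [[0, 0, 0], [1, 2, 3], [1, 1, 1]]
            vals = tt_entries(cores, indices)
            ```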
            
        Args:
            indices (list[list[int]] | torch.tensor): the M multi-indices (each of length d) at which the tensor is evaluated.

        Returns:
            torch.tensor: the M values of the tensor at the given indices.

        """
        result = apply_mask(self.cores,self.__R,indices)
        return result

    def __truediv__(self,other):
        """
        This function implements the "/" operator.
        Division by a scalar is exact. Elementwise division by another TT object is performed with the AMEn solver, using a fixed number of sweeps and a fixed relative accuracy.
        This is sufficient in most cases, but it can sometimes fail.
        Use torchtt.elementwise_divide() if you need to change the arguments of the AMEn solver.
        

        Args:
            other (torchtt.TT | float | int | torch.tensor): the second operand. If a `torch.tensor` is provided, it must have 1 element.

        Raises:
            IncompatibleTypes: Operands should be either TT or TTM.
            ShapeMismatch: Both operands should have the same shape.
            InvalidArguments: Operand not permitted. A TT-object can be divided only by scalars or TT objects.
            
        Returns:
            torchtt.TT: the result.
        """
        if isinstance(other,int) or isinstance(other,float) or tn.is_tensor(other):
            # divide by a scalar (out of place, so that the cores of self are not modified)
            cores_new = self.cores.copy()
            cores_new[0] = cores_new[0] / other
            result = TT(cores_new)
        elif isinstance(other,TT):
            if self.__is_ttm != other.is_ttm:
                raise IncompatibleTypes('Operands should be either TT or TTM.')
            if self.__N != other.N or (self.__is_ttm and self.__M != other.M):
                raise ShapeMismatch("Both operands should have the same shape.")
            result = TT(amen_divide(other,self,50,None,1e-12,500,verbose=False))       
        else:
            raise InvalidArguments('Operand not permitted. A TT-object can be divided only by scalars or TT objects.')
            
       
        return result
    
    def __rtruediv__(self,other):
        """
        Right true division. This function is called when a non-TT object is divided by a TT object.
        This operation is performed using the AMEn solver. The number of sweeps and the relative accuracy are fixed.
        This is sufficient in most cases, but it can sometimes fail.
        Use torchtt.elementwise_divide() if you need to change the arguments of the AMEn solver.
        
        Example: 
            ```
            z = 1.0/x # x is TT instance
            ```
            
        Args:
            other (torchtt.TT | float | int | torch.tensor): the first operand. If a `torch.tensor` is provided, it must have 1 element.

        Raises:
            InvalidArguments: The first operand must be int, float or torch.tensor with 1 element.
            
        Returns:
            torchtt.TT: the result.
        """
        if isinstance(other,int) or isinstance(other,float) or ( tn.is_tensor(other) and other.numel()==1):
            o = TT([tn.ones((1,n,1),dtype=self.cores[0].dtype,device = self.cores[0].device) for n in self.__N])# ones(self.__N,dtype=self.cores[0].dtype,device = self.cores[0].device)
            o.cores[0] *= other
            cores_new = amen_divide(self,o,50,None,1e-12,500,verbose=False)
        else:
            raise InvalidArguments("The first operand must be int, float or torch.tensor with 1 element.")
         
        return TT(cores_new)

    
    
    def t(self):
        """
        Returns the transpose of a given TT matrix, obtained by swapping the row and column modes of every core.

        Returns:
            torchtt.TT: the transpose.
            
        Raises:
            InvalidArguments: Has to be TT matrix.
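
        Examples:
            A standalone numpy sketch (independent of torchtt) checking that swapping the two mode axes of every core, as done here with permute([0,2,1,3]), transposes the full TT matrix:
            ```python
            import numpy as np

            rng = np.random.default_rng(10)
            cores = [rng.standard_normal(s) for s in [(1, 2, 3, 2), (2, 4, 5, 1)]]
            # swap the row and column modes of every core
            cores_t = [np.transpose(c, (0, 2, 1, 3)) for c in cores]
            A = np.einsum('aijb,bklc->ikjl', *cores)      # shape M1 x M2 x N1 x N2
            A_t = np.einsum('aijb,bklc->ikjl', *cores_t)  # shape N1 x N2 x M1 x M2
            ```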
        """ 
        if not self.__is_ttm:
            raise InvalidArguments('Has to be TT matrix.')
            
        cores_new = [tn.permute(c,[0,2,1,3]) for c in self.cores]
        
        return TT(cores_new)
        
    
    def norm(self,squared=False):
        """
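        Computes the Frobenius norm of a TT object.
        If any core requires gradients, the norm is computed by a direct (differentiable) contraction of the cores; otherwise, a left-to-right sweep of QR decompositions is used.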

        Args:
            squared (bool, optional): returns the square of the norm if True. Defaults to False.

        Returns:
            torch.tensor: the norm.
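
        Examples:
            A standalone numpy sketch of the QR sweep used when no gradients are required (independent of torchtt; `tt_norm_qr` is a hypothetical helper). The discarded Q factors have orthonormal columns and therefore preserve the Frobenius norm, so only the last core needs a dense norm:
            ```python
            import numpy as np

            def tt_norm_qr(cores):
                core = cores[0]
                for nxt in cores[1:]:
                    r0, n, r1 = core.shape
                    # thin QR of the current core, unfolded as a (r0*n) x r1 matrix
                    _, R = np.linalg.qr(core.reshape(r0 * n, r1))
                    # push the triangular factor into the next core
                    core = np.einsum('ij,jkl->ikl', R, nxt)
                # the norm of the whole tensor is now carried by the last core
                return np.linalg.norm(core)

            rng = np.random.default_rng(6)
            cores = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 3), (3, 5, 1)]]
            full = np.einsum('aib,bjc,ckd->ijk', *cores)
            ```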
        """
        
        if any([c.requires_grad or c.grad_fn is not None for c in self.cores]):
            norm = tn.tensor([[1.0]],dtype = self.cores[0].dtype, device=self.cores[0].device)
            
            if self.__is_ttm:
                for i in range(len(self.__N)):
                    norm = tn.einsum('ab,aijm,bijn->mn',norm, self.cores[i], tn.conj(self.cores[i]))
                norm = tn.squeeze(norm)
            else:
                           
                for i in range(len(self.__N)):
                    norm = tn.einsum('ab,aim,bin->mn',norm, self.cores[i], tn.conj(self.cores[i]))
                norm = tn.squeeze(norm)
            if squared:
                return norm
            else:
                return tn.sqrt(tn.abs(norm))
 
        else:        
            d = len(self.cores)

            core_now = self.cores[0]
            for i in range(d-1):
                if self.__is_ttm:
                    mode_shape = [core_now.shape[1],core_now.shape[2]]
                    core_now = tn.reshape(core_now,[core_now.shape[0]*core_now.shape[1]*core_now.shape[2],-1])
                else:
                    mode_shape = [core_now.shape[1]]
                    core_now = tn.reshape(core_now,[core_now.shape[0]*core_now.shape[1],-1])
                    
                # perform QR
                Qmat, Rmat = QR(core_now)
                     
                # take next core
                core_next = self.cores[i+1]
                shape_next = list(core_next.shape[1:])
                core_next = tn.reshape(core_next,[core_next.shape[0],-1])
                core_next = Rmat @ core_next
                core_next = tn.reshape(core_next,[Qmat.shape[1]]+shape_next)
                
                # update the cores
                
                core_now = core_next
            # core_now carries the norm after the sweep (also when the tensor has a single core)
            if squared:
                return tn.linalg.norm(core_now)**2
            else:
                return tn.linalg.norm(core_now)

    def sum(self,index = None):
        """
        Sums a tensor in the TT format along the given indices and returns the resulting tensor in the TT format.
        If no index list is given, the sum over all indices is performed.

        Examples:
            ```
            a = torchtt.ones([3,4,5,6,7])
            print(a.sum()) 
            print(a.sum([0,2,4]))
            print(a.sum([1,2]))
            print(a.sum([0,1,2,3,4]))
            ```
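            A standalone numpy sketch of the underlying idea (independent of torchtt): summing over a mode amounts to summing the corresponding core over its mode index.
            ```python
            import numpy as np

            rng = np.random.default_rng(7)
            cores = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 3), (3, 5, 1)]]
            full = np.einsum('aib,bjc,ckd->ijk', *cores)

            # sum over all modes: sum every core over its mode index and chain the resulting matrices
            total = np.ones((1, 1))
            for c in cores:
                total = total @ c.sum(axis=1)
            total = total[0, 0]

            # sum over mode 1 only: that core is summed with keepdims, so the mode has size 1
            partial = [cores[0], cores[1].sum(axis=1, keepdims=True), cores[2]]
            partial_full = np.einsum('aib,bjc,ckd->ijk', *partial)[:, 0, :]
            ```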
            
        Args:
            index (int | list[int] | None, optional): the indices along which the summation is performed. None selects all of them. Defaults to None.

        Raises:
            InvalidArguments: Invalid index.

        Returns:
            torchtt.TT | torch.tensor: the result (a torch.tensor scalar if all modes are summed).
        """
        
        if isinstance(index,int):
            index = [index]
        if index is not None and not isinstance(index,list):
            raise InvalidArguments('Invalid index.')

        if index is None: 
            # the case we need to sum over all modes
            if self.__is_ttm:
                C = tn.sum(self.cores[0],[0,1,2])
                for i in range(1,len(self.__N)):
                    C = tn.sum(tn.einsum('i,ijkl->jkl',C,self.cores[i]),[0,1])
                S = tn.sum(C)
            else:
                C = tn.sum(self.cores[0],[0,1])
                for i in range(1,len(self.__N)):
                    C = tn.sum(tn.einsum('i,ijk->jk',C,self.cores[i]),0)
                S = tn.sum(C)
        else:
            # we return the TT-tensor with summed indices
            cores = []
            
            if self.__is_ttm:
                tmp = [1,2]
            else:
                tmp = [1]
                
            for i in range(len(self.__N)):
                if i in index:
                    C = tn.sum(self.cores[i], tmp, keepdim = True)
                    cores.append(C)
                else:
                    cores.append(self.cores[i])
                        
            S = TT(cores)
            S.reduce_dims()
            if len(S.cores)==1 and tn.numel(S.cores[0])==1:
                S = tn.squeeze(S.cores[0])
        return S

    def to_ttm(self):
        """
        Converts a TT-tensor to the TT-matrix format. If the tensor has the shape N1 x ... x Nd, the result has the shape 
        N1 x ... x Nd x 1 x ... x 1.
    
        Returns:
            torchtt.TT: the result.
        """

        cores_new = [tn.reshape(c,(c.shape[0],c.shape[1],1,c.shape[2])) for c in self.cores]
        return TT(cores_new)

    def reduce_dims(self, exclude = []):
        """
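        Removes the modes of size 1 from the TT-object in place, absorbing each such core into a neighbouring core.
        At least one mode should be larger than 1.

        Args:
            exclude (list, optional): Indices of modes that are kept even if their size is 1. Defaults to [].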
        """
        
        # TODO: implement a version that also reduces the rank by splitting the cores with modes of size 1 into 2 using the SVD.
        
        if self.__is_ttm:
            cores_new = []
            
            for i in range(len(self.__N)):
                
                if self.cores[i].shape[1] == 1 and self.cores[i].shape[2] == 1 and not i in exclude:
                    if self.cores[i].shape[0] > self.cores[i].shape[3] or i == len(self.__N)-1:
                        # multiply to the left
                        if len(cores_new) > 0:
                            cores_new[-1] = tn.einsum('ijok,kl->ijol',cores_new[-1], self.cores[i][:,0,0,:])
                        else: 
                            # there is no core to the left. Multiply right.
                            if i != len(self.__N)-1:
                                self.cores[i+1] = tn.einsum('ij,jkml->ikml', self.cores[i][:,0,0,:],self.cores[i+1])
                            else:
                                cores_new.append(self.cores[i])
                            
                    else:
                        # multiply to the right. Set the carry 
                        self.cores[i+1] = tn.einsum('ij,jkml->ikml',self.cores[i][:,0,0,:],self.cores[i+1])
                        
                else:
                    cores_new.append(self.cores[i])
                    
            # update the cores and ranks and shape
            self.__N = []
            self.__M = []
            self.__R = [1]
            for i in range(len(cores_new)):
                self.__N.append(cores_new[i].shape[2])
                self.__M.append(cores_new[i].shape[1])
                self.__R.append(cores_new[i].shape[3])
            self.cores = cores_new
        else:
            cores_new = []
            
            for i in range(len(self.__N)):
                
                if self.cores[i].shape[1] == 1 and not i in exclude:
                    if self.cores[i].shape[0] > self.cores[i].shape[2] or i == len(self.__N)-1:
                        # multiply to the left
                        if len(cores_new) > 0:
                            cores_new[-1] = tn.einsum('ijk,kl->ijl',cores_new[-1], self.cores[i][:,0,:])
                        else: 
                            # there is no core to the left. Multiply right.
                            if i != len(self.__N)-1:
                                self.cores[i+1] = tn.einsum('ij,jkl->ikl', self.cores[i][:,0,:],self.cores[i+1])
                            else:
                                cores_new.append(self.cores[i])
                                
                            
                    else:
                        # multiply to the right. Set the carry 
                        self.cores[i+1] = tn.einsum('ij,jkl->ikl',self.cores[i][:,0,:],self.cores[i+1])
                        
                else:
                    cores_new.append(self.cores[i])
            
            
            # update the cores and ranks and shape
            self.__N = []
            self.__R = [1]
            for i in range(len(cores_new)):
                self.__N.append(cores_new[i].shape[1])
                self.__R.append(cores_new[i].shape[2])
            self.cores = cores_new
                    
                    
        self.shape = [ (m,n) for m,n in zip(self.__M,self.__N) ] if self.__is_ttm else [n for n in self.N] 
        
    def __getitem__(self,index):
        """
        Performs slicing of a TT object.
        Both TT matrix and TT tensor are supported.
        Similar to pytorch or numpy slicing.

        Args:
            index (tuple[slice] | tuple[int] | int | Ellipsis | slice): the slicing.

        Raises:
            NotImplementedError: Ellipsis are not supported.
            InvalidArguments: Slice size is invalid.
            InvalidArguments: Slice arguments not valid. They have to be either int, slice or None.
            InvalidArguments: Invalid slice. Tensor is not 1d.


        Returns:
            torchtt.TT | torch.tensor: the result. If all the indices are fixed, a scalar torch.tensor is returned otherwise a torchtt.TT.
        """
        
        
        # slicing function
        
        ##### TODO: include Ellipsis support for tensor operators.
        
        # if a slice containing only integers is passed, an element is returned
        # if ranged slices are used, a TT-object has to be returned.

        exclude = []
        
        if isinstance(index,tuple):
            # Ellipsis may appear at most once and is not supported for tensor operators.
            if index.count(Ellipsis) > 1 or (self.is_ttm and index.count(Ellipsis) > 0):
                raise NotImplementedError('Ellipsis are not supported more than once or for tensor operators.')
            
            if self.__is_ttm:
                    
                    
                cores_new = []
                k=0
                for i in range(len(index)//2):
                    idx1 = index[i]
                    idx2 = index[i+len(index)//2]
                    if isinstance(idx1,slice) and isinstance(idx2,slice):
                        cores_new.append(self.cores[k][:,idx1,idx2,:])
                        k+=1
                    elif idx1 is None and idx2 is None:
                        # extend the tensor
                        tmp = tn.eye(cores_new[-1].shape[-1] if len(cores_new)!=0 else 1, device = self.cores[0].device, dtype = self.cores[0].dtype)[:,None,None,:]
                        cores_new.append(tmp)
                        exclude.append(i)
                    elif isinstance(idx1, int) and isinstance(idx2,int):
                        cores_new.append(tn.reshape(self.cores[k][:,idx1,idx2,:],[self.__R[k],1,1,self.R[k+1]]))
                        k+=1
                    else:
                        raise InvalidArguments("Slice arguments not valid. They have to be either int, slice or None.")
                if k<len(self.cores):
                    raise InvalidArguments('Slice size is invalid.')
                
            else:
                # if len(index) != len(self.__N):
                #    raise InvalidArguments('Slice size is invalid.')
                num_none = sum([i is None for i in index])
                
                if index[0] == Ellipsis:
                    index = (slice(None, None, None),)*(len(self.__N)-len(index)+1+num_none) + index[1:]
                elif index[-1] == Ellipsis:
                    index = index[:-1] + (slice(None, None, None),)*(len(self.__N)-len(index)+1+num_none)
                cores_new = []
                k = 0
                for i,idx in enumerate(index):
                    if isinstance(idx,slice):
                        cores_new.append(self.cores[k][:,idx,:])
                        k+=1
                    elif idx is None:
                        # extend the tensor
                        tmp = tn.eye(cores_new[-1].shape[-1] if len(cores_new)!=0 else 1, device = self.cores[0].device, dtype = self.cores[0].dtype)[:,None,:]
                        cores_new.append(tmp)
                        exclude.append(i)
                    elif isinstance(idx, int):
                        cores_new.append(tn.reshape(self.cores[k][:,idx,:],[self.__R[k],-1,self.R[k+1]]))
                        k+=1
                    else:
                        raise InvalidArguments("Slice arguments not valid. They have to be either int, slice or None.")
                if k<len(self.cores):
                    raise InvalidArguments('Slice size is invalid.')
                        
                
            sliced = TT(cores_new)
            sliced.reduce_dims(exclude)
            if (sliced.is_ttm == False and sliced.N == [1]) or (sliced.is_ttm and sliced.N == [1] and sliced.M == [1]):
                sliced = tn.squeeze(sliced.cores[0])
                
                
            # cores = None
            
            
        elif isinstance(index,int):
            # tensor is 1d and one element is retrieved
            if len(self.__N) == 1:
                sliced = self.cores[0][0,index,0]
            else:
                raise InvalidArguments('Invalid slice. Tensor is not 1d.')
                
            ## TODO
        elif index == Ellipsis:
            # return a copy of the tensor
            sliced = TT([c.clone() for c in self.cores])
            
        elif isinstance(index,slice):
            # tensor is 1d and one slice is extracted
            if len(self.__N) == 1:
                sliced = TT(self.cores[0][:,index,:])
            else:
                raise InvalidArguments('Invalid slice. Tensor is not 1d.')
            ## TODO
        else:
            raise InvalidArguments('Invalid slice.')
            
        
        return sliced
    
    def __pow__(self, other):
        """
        Computes the tensor Kronecker product.
        This implements the "**" operator.
        If None is provided as input, the result is a copy of this tensor.
        If A is N_1 x ... x N_d and B is M_1 x ... x M_p, then kron(A,B) is N_1 x ... x N_d x M_1 x ... x M_p.


        Args:
            other (torchtt.TT or None): the second operand.

        Raises:
            IncompatibleTypes: Incompatible data types (make sure both are either TT-matrices or TT-tensors).
            InvalidArguments: Invalid arguments.

        Returns:
            torchtt.TT: the result.
        """
        
     
        if other is None: 
            cores_new = [c.clone() for c in self.cores]
            result = TT(cores_new)
        elif isinstance(other,TT):
            if self.is_ttm != other.is_ttm:
                raise IncompatibleTypes('Incompatible data types (make sure both are either TT-matrices or TT-tensors).')
        
            # concatenate the result
            cores_new = [c.clone() for c in self.cores] + [c.clone() for c in other.cores]
            result = TT(cores_new)
        else:
            raise InvalidArguments('Invalid arguments.')
        
        return result
    
    def __rpow__(self,other):
        """
        Computes the tensor Kronecker product.
        This implements the reflected "**" operator, used when the left operand is not a torchtt.TT (e.g. None ** x).
        If None is provided as input, the result is this tensor.
        If A is N_1 x ... x N_d and B is M_1 x ... x M_p, then kron(A,B) is N_1 x ... x N_d x M_1 x ... x M_p.


        Args:
            other (torchtt.TT or None): the left operand.

        Raises:
            IncompatibleTypes: Incompatible data types (make sure both are either TT-matrices or TT-tensors).
            InvalidArguments: Invalid arguments.

        Returns:
            torchtt.TT: the result.
        """
        
        result = kron(self,other)
        
        return result
    
    def __neg__(self):
        """
        Returns the negative of a given TT tensor.
        This implements the unary operator "-".

        Returns:
            torchtt.TT: the negated tensor.
        """
    
        cores_new = [c.clone() for c in self.cores]
        cores_new[0] = -cores_new[0]
        return TT(cores_new)
    
    def __pos__(self):
        """
        Implements the unary "+" operator, returning a copy of the tensor.

        Returns:
            torchtt.TT: the tensor clone.
        """
        
        cores_new = [c.clone() for c in self.cores]

        return TT(cores_new)
    
    def round(self, eps=1e-12, rmax = sys.maxsize): 
        """
        Performs the TT rounding (rank truncation) of the tensor to the relative accuracy eps.
        The TT ranks can additionally be capped by rmax. A new TT object is returned.

        Args:
            eps (float, optional): the relative accuracy. Defaults to 1e-12.
            rmax (int, optional): the maximum rank. Defaults to the maximum possible integer.

        Returns:
            torchtt.TT: the result.
        """
        
        # rmax is not list
        if not isinstance(rmax,list):
            rmax = [1] + len(self.__N)*[rmax] + [1]
            
        # call the round function
        tt_cores, R = round_tt(self.cores, self.__R.copy(), eps, rmax,self.__is_ttm)
        # creates a new TT and return it
        T = TT(tt_cores)
               
        return T
    
    def to_qtt(self, eps = 1e-12, mode_size = 2, rmax = sys.maxsize):
        """
        Converts a tensor to the QTT format: N1 x N2 x ... x Nd -> mode_size x mode_size x ... x mode_size.
        Every mode size should be a power of mode_size.
        The tensor in QTT can be converted back using the qtt_to_tens() method.

        Examples:
            ```
            x = torchtt.random([16,8,64,128],[1,2,10,12,1])
            x_qtt = x.to_qtt()
            print(x_qtt)
            xf = x_qtt.qtt_to_tens(x.N) # a TT-rounding is recommended.
            ```
            
        Args:
            eps (float,optional): the accuracy. Defaults to 1e-12.
            mode_size (int, optional): the size of the modes. Defaults to 2.
            rmax (int, optional): the maximum rank. Defaults to the maximum possible integer.
            

        Raises:
            ShapeMismatch: Only square TTM can be transformed to QTT.
            ShapeMismatch: Reshaping error: check if the dimensions are powers of the desired mode size.

        Returns:
            torchtt.TT: the resulting reshaped tensor.
                       
        """
       
        cores_new = []
        if self.__is_ttm:
            shape_new = []
            for i in range(len(self.__N)):
                if self.__N[i]!=self.__M[i]:
                    raise ShapeMismatch('Only square TTM can be transformed to QTT.')
                if self.__N[i]==mode_size**int(math.log(self.N[i],mode_size)):
                    shape_new += [(mode_size,mode_size)]*int(math.log(self.__N[i],mode_size))
                else:
                    raise ShapeMismatch('Reshaping error: check if the dimensions are powers of the desired mode size:\r\ncore size '+str(list(self.cores[i].shape))+' cannot be reshaped.')
                
            result = reshape(self, shape_new, eps, rmax)
        else:
            for core in self.cores:
                if int(math.log(core.shape[1],mode_size))>1:
                    Nnew = [core.shape[0]*mode_size]+[mode_size]*(int(math.log(core.shape[1],mode_size))-2)+[core.shape[2]*mode_size]
                    try:
                        core = tn.reshape(core,Nnew)
                    except:
                        raise ShapeMismatch('Reshaping error: check if the dimensions are powers of the desired mode size:\r\ncore size '+str(list(core.shape))+' cannot be reshaped to '+str(Nnew))
                    cores,_ = to_tt(core,Nnew,eps,rmax,is_sparse=False)
                    cores_new.append(tn.reshape(cores[0],[-1,mode_size,cores[0].shape[-1]]))
                    cores_new += cores[1:-1]
                    cores_new.append(tn.reshape(cores[-1],[cores[-1].shape[0],mode_size,-1]))
                else: 
                    cores_new.append(core)
            result = TT(cores_new)
            
        return result
               
    def qtt_to_tens(self, original_shape):
        """
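        Transforms a tensor back from the QTT format by merging consecutive cores until the product of their mode sizes matches the next entry of original_shape.
        Only TT-tensors are handled; for TT-matrices no conversion is implemented yet.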

        Args:
            original_shape (list): the original shape.

        Raises:
            InvalidArguments: Original shape must be a list.
            ShapeMismatch: Mode sizes do not match.

        Returns:
            torchtt.TT: the folded tensor.
        """
        
        if not isinstance(original_shape,list):
            raise InvalidArguments("Original shape must be a list.")

        core = None
        cores_new = []
        
        if self.__is_ttm:
            pass
        else:
            k = 0
            for c in self.cores:
                if core is None:
                    core = c
                    so_far = core.shape[1]
                else:
                    core = tn.einsum('...i,ijk->...jk',core,c)
                    so_far *= c.shape[1]
                if so_far==original_shape[k]:
                    core = tn.reshape(core,[core.shape[0],-1,core.shape[-1]])
                    cores_new.append(core)
                    core = None
                    k += 1
            if k!= len(original_shape):
                raise ShapeMismatch('Mode sizes do not match.')
        return TT(cores_new)
    
    def mprod(self, factor_matrices, mode):
        """
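        Computes the n-mode product: the tensor is multiplied along a mode of size N_n by a matrix of shape K x N_n, so that this mode has size K in the result.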

        Args:
            factor_matrices (torch.tensor or list[torch.tensor]): either a single matrix is directly provided or a list of matrices for product along multiple modes.
            mode (int or list[int]): the mode for the product. If factor_matrices is a torch.tensor then mode is an integer and the multiplication will be performed along a single mode.
                                     If factor_matrices is a list, the mode has to be list[int] of equal size.

        Raises:
            InvalidArguments: Invalid arguments.
            ShapeMismatch: The n-th mode of the tensor must be equal with the 2nd mode of the matrix.
            IncompatibleTypes: n-mode product works only with TT-tensors and not TT matrices.
            
        Returns:
            torchtt.TT: the result
        """
        if self.__is_ttm:
            raise IncompatibleTypes("n-mode product works only with TT-tensors and not TT matrices.")
    
        if isinstance(factor_matrices,list) and isinstance(mode, list):
            cores_new = [c.clone() for c in self.cores]
            for i in range(len(factor_matrices)):
                if cores_new[mode[i]].shape[1] != factor_matrices[i].shape[1]:
                    raise ShapeMismatch("The n-th mode of the tensor must be equal with the 2nd mode of the matrix.")
                cores_new[mode[i]] = tn.einsum('ijk,lj->ilk',cores_new[mode[i]],factor_matrices[i])
        elif isinstance(mode, int) and tn.is_tensor(factor_matrices):
            cores_new = [c.clone() for c in self.cores]
            if cores_new[mode].shape[1] != factor_matrices.shape[1]:
                raise ShapeMismatch("The n-th mode of the tensor must be equal with the 2nd mode of the matrix.")
            cores_new[mode] = tn.einsum('ijk,lj->ilk',cores_new[mode],factor_matrices)
        else:
            raise InvalidArguments('Invalid arguments.')
        
        return TT(cores_new)        
        
    def conj(self):
        """
        Return the complex conjugate of a tensor in TT format.

        Returns:
            torchtt.TT: the complex conjugated tensor.
        """
        return TT([tn.conj(c) for c in self.cores])
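The summation in `sum` above works core by core: each mode index appears in exactly one core, so summing that core over its mode index (keeping a size-1 mode, as `keepdim = True` does) sums the whole tensor over that index. The following is a minimal NumPy sketch of this idea, assuming 3D cores of shape (r_{k-1}, n_k, r_k); NumPy stands in for the PyTorch backend, and `tt_full`, `tt_sum_all` and `tt_sum_modes` are hypothetical helpers written for illustration, not torchtt functions.

```python
import numpy as np

def tt_full(cores):
    """Contract 3D TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = cores[0]
    for core in cores[1:]:
        # contract the trailing rank index with the leading rank index of the next core
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full[0, ..., 0]  # drop the boundary ranks r_0 = r_d = 1

def tt_sum_all(cores):
    """Sum over all indices: sum every core over its mode index, then chain the r x r' matrices."""
    v = np.ones((1,))
    for core in cores:
        v = v @ core.sum(axis=1)  # (r_{k-1},) @ (r_{k-1}, r_k) -> (r_k,)
    return v[0]

def tt_sum_modes(cores, modes):
    """Sum over selected modes only; keepdims leaves a size-1 mode in the summed cores."""
    return [c.sum(axis=1, keepdims=True) if k in modes else c for k, c in enumerate(cores)]

rng = np.random.default_rng(0)
shape, ranks = [3, 4, 5], [1, 2, 3, 1]
cores = [rng.standard_normal((ranks[k], shape[k], ranks[k + 1])) for k in range(3)]

print(np.isclose(tt_sum_all(cores), tt_full(cores).sum()))  # True
partial = tt_full(tt_sum_modes(cores, [0, 2]))
print(np.allclose(partial, tt_full(cores).sum(axis=(0, 2), keepdims=True)))  # True
```

The full sum never forms the full tensor: it costs one small vector-matrix product per core, which is why the index-free branch of `sum` returns a scalar directly.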

Instance variables

var M

Return the "row" shape in case of TT matrices.

Raises

IncompatibleTypes
The field M is defined only for TT matrices.

Returns

list[int]
the shape.
@property 
def M(self):
    """
    Return the "row" shape in case of TT matrices.

    Raises:
        IncompatibleTypes: The field M is defined only for TT matrices.

    Returns:
        list[int]: the shape.
    """
    if not self.__is_ttm:
        raise IncompatibleTypes("The field M is defined only for TT matrices.")
    return self.__M.copy()
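For a TT matrix every core carries two mode indices; in the 4D core layout used by `reduce_dims` above, dimension 1 holds the row size M_k and dimension 2 the column size N_k. The NumPy sketch below (random hypothetical cores, NumPy standing in for PyTorch) reads both shapes off the cores and assembles the equivalent prod(M) x prod(N) matrix.

```python
import numpy as np

# Hypothetical TT-matrix cores of shape (r_{k-1}, M_k, N_k, r_k).
rng = np.random.default_rng(1)
M, N, R = [2, 3], [4, 5], [1, 3, 1]
cores = [rng.standard_normal((R[k], M[k], N[k], R[k + 1])) for k in range(2)]

# The row shape M and the column shape N are dimensions 1 and 2 of every core.
print([c.shape[1] for c in cores], [c.shape[2] for c in cores])  # [2, 3] [4, 5]

# Contract the cores over the rank index: the result has indices (m1, n1, m2, n2).
full = np.einsum('aijb,bklc->ijkl', cores[0], cores[1])
# Reorder to M1 x M2 x N1 x N2 and flatten into a prod(M) x prod(N) matrix.
A = full.transpose(0, 2, 1, 3).reshape(2 * 3, 4 * 5)
print(A.shape)  # (6, 20)
```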
var N

Return the shape of a tensor or the "column" shape of a TT operator.

Returns

list[int]
the shape.
@property 
def N(self):
    """
    Return the shape of a tensor or the "column" shape of a TT operator.

    Returns:
        list[int]: the shape.
    """
    return self.__N.copy()
var R

The rank of the TT decomposition. Its length is len(R)==len(N)+1.

Returns

list[int]
the rank.
@property
def R(self):
    """
    The rank of the TT decomposition.
    Its length is `len(R)==len(N)+1`.

    Returns:
        list[int]: the rank.
    """
    return self.__R.copy()
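The rank vector has one more entry than the shape because core k has shape R[k] x N[k] x R[k+1], with R[0] = R[d] = 1; `reduce_dims` above rebuilds the rank from the core shapes in the same way. The small NumPy sketch below uses hypothetical shapes and ranks (NumPy in place of PyTorch) to recover R from the cores and to compare the TT storage with that of the full tensor.

```python
import numpy as np

# Hypothetical TT-tensor with shape N and rank R.
N = [4, 4, 4, 4, 4]
R = [1, 3, 3, 3, 3, 1]
assert len(R) == len(N) + 1  # one rank per bond between cores, plus the two boundary ranks

cores = [np.zeros((R[k], N[k], R[k + 1])) for k in range(len(N))]

# The rank is recovered from the core shapes: the first left rank, then every right rank.
recovered = [cores[0].shape[0]] + [c.shape[2] for c in cores]
print(recovered == R)  # True

tt_storage = sum(c.size for c in cores)  # sum over k of R[k] * N[k] * R[k+1]
full_storage = int(np.prod(N))
print(tt_storage, full_storage)  # 132 1024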
var is_ttm

Check whether the instance is a TT operator or not.

Returns

bool
the flag.
@property
def is_ttm(self):
    """
    Check whether the instance is a TT operator or not.

    Returns:
        bool: the flag.
    """
    return self.__is_ttm

Methods

def apply_mask(self, indices)

Evaluate the tensor on the given index list.

Examples

x = torchtt.random([10,12,14],[1,4,5,1])
indices = torch.tensor([[0,0,0],[1,2,3],[1,1,1]])
val = x.apply_mask(indices)

Args

indices : list[list[int]]
the M multi-indices (each of length d) at which the tensor is evaluated.

Returns

torch.tensor
the M values of the tensor at the given indices.
def apply_mask(self,indices):
    """
    Evaluate the tensor on the given index list.

    Examples:
        ```
        x = torchtt.random([10,12,14],[1,4,5,1])
        indices = torch.tensor([[0,0,0],[1,2,3],[1,1,1]])
        val = x.apply_mask(indices)
        ```
        
    Args:
        indices (list[list[int]]): the M multi-indices (each of length d) at which the tensor is evaluated.

    Returns:
        torch.tensor: the M values of the tensor at the given indices.

    """
    result = apply_mask(self.cores,self.__R,indices)
    return result
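Evaluating a TT tensor at one multi-index only needs one matrix slice per core: x[i_1,...,i_d] = G1[:, i_1, :] G2[:, i_2, :] ... Gd[:, i_d, :]. The NumPy sketch below vectorizes this product over M points and reuses the shapes of the example above. `tt_eval` and `tt_full` are hypothetical helpers for illustration; this is not the internal `apply_mask` routine that torchtt calls.

```python
import numpy as np

def tt_full(cores):
    """Contract 3D TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = cores[0]
    for core in cores[1:]:
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full[0, ..., 0]

def tt_eval(cores, indices):
    """Evaluate a TT tensor at M multi-indices given as an M x d integer array."""
    indices = np.asarray(indices)
    values = np.ones((indices.shape[0], 1))  # one running 1 x r_k row vector per point
    for k, core in enumerate(cores):
        slices = core[:, indices[:, k], :].transpose(1, 0, 2)  # (M, r_{k-1}, r_k)
        values = np.einsum('mr,mrs->ms', values, slices)       # batched vector-matrix product
    return values[:, 0]  # the last rank is r_d = 1

rng = np.random.default_rng(2)
shape, ranks = [10, 12, 14], [1, 4, 5, 1]
cores = [rng.standard_normal((ranks[k], shape[k], ranks[k + 1])) for k in range(3)]
indices = np.array([[0, 0, 0], [1, 2, 3], [1, 1, 1]])

vals = tt_eval(cores, indices)
full = tt_full(cores)
print(np.allclose(vals, full[indices[:, 0], indices[:, 1], indices[:, 2]]))  # True
```

The cost grows linearly with M and d, so no full tensor has to be formed; the comparison with `tt_full` is only there to check the result.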
def clone(self)

Clones the torchtt.TT instance. Similar to torch.tensor.clone().

Returns

TT
the cloned TT object.
def clone(self):
    """
    Clones the torchtt.TT instance. Similar to torch.tensor.clone().

    Returns:
        torchtt.TT: the cloned TT object.
    """
    return TT([c.clone() for c in self.cores]) 
def conj(self)

Return the complex conjugate of a tensor in TT format.

Returns

TT
the complex conjugated tensor.
def conj(self):
    """
    Return the complex conjugate of a tensor in TT format.

    Returns:
        torchtt.TT: the complex conjugated tensor.
    """
    return TT([tn.conj(c) for c in self.cores])
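Conjugating every core is enough because each entry of the tensor is a sum of products of core entries, and conj(ab) = conj(a) conj(b). A short NumPy check with random complex cores (NumPy standing in for PyTorch; `tt_full` is a hypothetical helper):

```python
import numpy as np

def tt_full(cores):
    """Contract 3D TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = cores[0]
    for core in cores[1:]:
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full[0, ..., 0]

rng = np.random.default_rng(3)
shape, ranks = [3, 4, 2], [1, 2, 2, 1]
sizes = [(ranks[k], shape[k], ranks[k + 1]) for k in range(3)]
cores = [rng.standard_normal(s) + 1j * rng.standard_normal(s) for s in sizes]

# conjugate core by core, then compare with the conjugate of the full tensor
conj_full = tt_full([np.conj(c) for c in cores])
print(np.allclose(conj_full, np.conj(tt_full(cores))))  # True
```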
def cpu(self)

Returns a TT object whose cores are moved to the CPU.

Returns

TT
The TT-object on CPU.
def cpu(self):
    """
    Returns a TT object whose cores are moved to the CPU.

    Returns:
        torchtt.TT: The TT-object on CPU.
    """

    
    return TT([ c.cpu() for c in self.cores])
def cuda(self, device=None)

Returns a torchtt.TT object whose cores are copied to the CUDA device.

Args

device : torch.device, optional
The CUDA device (None selects the current CUDA device). Defaults to None.

Returns

TT
The TT-object. The TT-cores are on CUDA.
def cuda(self, device = None):
    """
    Returns a torchtt.TT object whose cores are copied to the CUDA device.

    Args:
        device (torch.device, optional): The CUDA device (None selects the current CUDA device). Defaults to None.

    Returns:
        torchtt.TT: The TT-object. The TT-cores are on CUDA.
    """
     
    
    t = TT([ c.cuda(device) for c in self.cores])

    return t
def detach(self)

Detaches the TT tensor. Similar to torch.tensor.detach().

Returns

TT
the detached tensor.
def detach(self):
    """
    Detaches the TT tensor. Similar to torch.tensor.detach().

    Returns:
        torchtt.TT: the detached tensor.
    """
    return TT([c.detach() for c in self.cores])
def fast_matvec(self, other, eps=1e-12, initial=None, nswp=20, verb=False, use_cpp=True)

Fast matrix vector multiplication A@x using DMRG iterations. Faster than traditional matvec + rounding.

Args

other : TT
the TT tensor.
eps : float, optional
relative accuracy for DMRG. Defaults to 1e-12.
initial : None | TT, optional
an approximation of the product (None means a random initial guess). Defaults to None.
nswp : int, optional
number of DMRG iterations. Defaults to 20.
verb : bool, optional
show info for debug. Defaults to False.
use_cpp : bool, optional
use the C++ implementation if available. Defaults to True.

Raises

InvalidArguments
Second operand has to be TT object.
IncompatibleTypes
First operand should be a TT matrix and second a TT vector.

Returns

TT
the result.
def fast_matvec(self,other, eps = 1e-12, initial = None, nswp = 20, verb = False, use_cpp = True):
    """
    Fast matrix vector multiplication A@x using DMRG iterations. Faster than traditional matvec + rounding.

    Args:
        other (torchtt.TT): the TT tensor.
        eps (float, optional): relative accuracy for DMRG. Defaults to 1e-12.
        initial (None|torchtt.TT, optional): an approximation of the product (None means random initial guess). Defaults to None.
        nswp (int, optional): number of DMRG iterations. Defaults to 20.
        verb (bool, optional): show info for debug. Defaults to False.
        use_cpp (bool, optional): use the C++ implementation if available. Defaults to True.

    Raises:
        InvalidArguments: Second operand has to be TT object.
        IncompatibleTypes: First operand should be a TT matrix and second a TT vector.

    Returns:
        torchtt.TT: the result.
    """
    
    if not isinstance(other,TT):
        raise InvalidArguments('Second operand has to be TT object.')
    if not self.__is_ttm or other.is_ttm:
        raise IncompatibleTypes('First operand should be a TT matrix and second a TT vector.')
        
    return dmrg_matvec(self, other, y0 = initial, eps = eps, verb = verb, nswp = nswp, use_cpp = use_cpp)
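For comparison, the traditional product that `fast_matvec` is meant to outperform multiplies the cores pairwise, so the ranks of the result are the products of the ranks of A and x and usually have to be reduced afterwards by rounding. The NumPy sketch below shows this exact core-wise product. It assumes the 4D core layout (r, M_k, N_k, r'), it is not the DMRG algorithm, and `ttm_matvec` and `tt_full` are hypothetical helpers.

```python
import numpy as np

def tt_full(cores):
    """Contract 3D TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = cores[0]
    for core in cores[1:]:
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full[0, ..., 0]

def ttm_matvec(A_cores, x_cores):
    """Exact TT matrix-vector product, core by core (no rank reduction)."""
    y_cores = []
    for A, x in zip(A_cores, x_cores):
        # A: (ra, m, n, rb), x: (rx, n, rd) -> y: (ra*rx, m, rb*rd)
        y = np.einsum('aijb,cjd->acibd', A, x)
        ra, rx, m, rb, rd = y.shape
        y_cores.append(y.reshape(ra * rx, m, rb * rd))
    return y_cores

rng = np.random.default_rng(4)
A_cores = [rng.standard_normal(s) for s in [(1, 2, 4, 3), (3, 3, 5, 1)]]  # M = [2, 3], N = [4, 5]
x_cores = [rng.standard_normal(s) for s in [(1, 4, 2), (2, 5, 1)]]        # shape [4, 5]

y_cores = ttm_matvec(A_cores, x_cores)
print([1] + [c.shape[2] for c in y_cores])  # [1, 6, 1]: the ranks 3 and 2 multiply

# dense check: assemble A as a 6 x 20 matrix (rows M1 x M2, columns N1 x N2)
A_dense = np.einsum('aijb,bklc->ikjl', A_cores[0], A_cores[1]).reshape(6, 20)
print(np.allclose(tt_full(y_cores).ravel(), A_dense @ tt_full(x_cores).ravel()))  # True
```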
def full(self)

Return the full tensor. In case of a TTM, the result has the shape M1 x M2 x … x Md x N1 x N2 x … x Nd.

Returns

torch.tensor
the full tensor.
def full(self):       
    """
    Return the full tensor.
    In case of a TTM, the result has the shape M1 x M2 x ... x Md x N1 x N2 x ... x Nd.

    Returns:
        torch.tensor: the full tensor.
    """
    if self.__is_ttm:
        # the case of tt-matrix
        tfull = self.cores[0][0,:,:,:]
        for i in  range(1,len(self.cores)-1) :
            tfull = tn.einsum('...i,ijkl->...jkl',tfull,self.cores[i])
        if len(self.__N) != 1:
            tfull = tn.einsum('...i,ijk->...jk',tfull,self.cores[-1][:,:,:,0])
            tfull = tn.permute(tfull,list(np.arange(len(self.__N))*2)+list(np.arange(len(self.N))*2+1))
        else:
            tfull = tfull[:,:,0]
    else:
        # the case of a normal tt
        tfull = self.cores[0][0,:,:]
        for i in  range(1,len(self.cores)-1) :
            tfull = tn.einsum('...i,ijk->...jk',tfull,self.cores[i])
        if len(self.__N) != 1:
            tfull = tn.einsum('...i,ij->...j',tfull,self.cores[-1][:,:,0])
        else:
            tfull = tn.squeeze(tfull)
    return tfull
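For a TT matrix, contracting the cores in order produces the mode indices interleaved as M1, N1, M2, N2, ...; the permutation in `full` then moves all row indices in front of the column indices. The NumPy sketch below reproduces that permutation for d = 3. NumPy stands in for PyTorch and `ttm_full` is a hypothetical helper.

```python
import numpy as np

def ttm_full(cores):
    """Full tensor of a TT matrix with cores (r, M_k, N_k, r'), returned as M1 x ... x Md x N1 x ... x Nd."""
    d = len(cores)
    full = cores[0]
    for core in cores[1:]:
        full = np.tensordot(full, core, axes=([-1], [0]))
    full = full[0, ..., 0]  # M1 x N1 x M2 x N2 x ... x Md x Nd
    perm = list(np.arange(d) * 2) + list(np.arange(d) * 2 + 1)  # same permutation as in full()
    return full.transpose(perm)

rng = np.random.default_rng(6)
M, N, R = [2, 3, 2], [3, 1, 4], [1, 2, 3, 1]
cores = [rng.standard_normal((R[k], M[k], N[k], R[k + 1])) for k in range(3)]

T = ttm_full(cores)
print(T.shape)  # (2, 3, 2, 3, 1, 4)
```

An entry T[m1, m2, m3, n1, n2, n3] is the product of the slices cores[k][:, m_k, n_k, :], which is the TT-matrix formula in the package introduction.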
def is_cuda(self)

Return True if the tensor is on GPU.

Returns

bool
Is the torchtt.TT on GPU or not.
def is_cuda(self):
    """
    Return True if the tensor is on GPU.

    Returns:
        bool: Is the torchtt.TT on GPU or not.
    """
    return all([c.is_cuda for c in self.cores])
def mprod(self, factor_matrices, mode)

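Computes the n-mode product: the tensor is multiplied along a mode of size N_n by a matrix of shape K x N_n, so that this mode has size K in the result.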

Args

factor_matrices : torch.tensor or list[torch.tensor]
either a single matrix is directly provided or a list of matrices for product along multiple modes.
mode : int or list[int]
the mode for the product. If factor_matrices is a torch.tensor then mode is an integer and the multiplication will be performed along a single mode. If factor_matrices is a list, the mode has to be list[int] of equal size.

Raises

InvalidArguments
Invalid arguments.
ShapeMismatch
The n-th mode of the tensor must be equal with the 2nd mode of the matrix.
IncompatibleTypes
n-mode product works only with TT-tensors and not TT matrices.

Returns

TT
the result
def mprod(self, factor_matrices, mode):
    """
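    Computes the n-mode product: the tensor is multiplied along a mode of size N_n by a matrix of shape K x N_n, so that this mode has size K in the result.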

    Args:
        factor_matrices (torch.tensor or list[torch.tensor]): either a single matrix is directly provided or a list of matrices for product along multiple modes.
        mode (int or list[int]): the mode for the product. If factor_matrices is a torch.tensor then mode is an integer and the multiplication will be performed along a single mode.
                                 If factor_matrices is a list, the mode has to be list[int] of equal size.

    Raises:
        InvalidArguments: Invalid arguments.
        ShapeMismatch: The n-th mode of the tensor must be equal with the 2nd mode of the matrix.
        IncompatibleTypes: n-mode product works only with TT-tensors and not TT matrices.
        
    Returns:
        torchtt.TT: the result
    """
    if self.__is_ttm:
        raise IncompatibleTypes("n-mode product works only with TT-tensors and not TT matrices.")

    if isinstance(factor_matrices,list) and isinstance(mode, list):
        cores_new = [c.clone() for c in self.cores]
        for i in range(len(factor_matrices)):
            if cores_new[mode[i]].shape[1] != factor_matrices[i].shape[1]:
                raise ShapeMismatch("The n-th mode of the tensor must be equal with the 2nd mode of the matrix.")
            cores_new[mode[i]] = tn.einsum('ijk,lj->ilk',cores_new[mode[i]],factor_matrices[i])
    elif isinstance(mode, int) and tn.is_tensor(factor_matrices):
        cores_new = [c.clone() for c in self.cores]
        if cores_new[mode].shape[1] != factor_matrices.shape[1]:
            raise ShapeMismatch("The n-th mode of the tensor must be equal with the 2nd mode of the matrix.")
        cores_new[mode] = tn.einsum('ijk,lj->ilk',cores_new[mode],factor_matrices)
    else:
        raise InvalidArguments('Invalid arguments.')
    
    return TT(cores_new)        
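Because mode n lives in a single core, the n-mode product only changes that core, using the same contraction 'ijk,lj->ilk' as in `mprod`. The NumPy sketch below (NumPy in place of PyTorch, hypothetical helper `tt_full`) compares the result with the n-mode product of the full tensor.

```python
import numpy as np

def tt_full(cores):
    """Contract 3D TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = cores[0]
    for core in cores[1:]:
        full = np.tensordot(full, core, axes=([-1], [0]))
    return full[0, ..., 0]

rng = np.random.default_rng(5)
shape, ranks = [3, 4, 5], [1, 2, 2, 1]
cores = [rng.standard_normal((ranks[k], shape[k], ranks[k + 1])) for k in range(3)]
U = rng.standard_normal((6, 4))  # factor matrix of shape K x N_1 = 6 x 4, applied to mode 1

# TT version: only core 1 changes
new_cores = list(cores)
new_cores[1] = np.einsum('ijk,lj->ilk', cores[1], U)

# reference: n-mode product of the full tensor along axis 1
ref = np.moveaxis(np.tensordot(U, tt_full(cores), axes=([1], [1])), 0, 1)
print(tt_full(new_cores).shape, np.allclose(tt_full(new_cores), ref))  # (3, 6, 5) True
```

The TT ranks are unchanged by the product, so no rounding is needed afterwards.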
def norm(self, squared=False)

Computes the Frobenius norm of a TT object.

Args

squared : bool, optional
returns the square of the norm if True. Defaults to False.

Returns

torch.tensor
the norm.
def norm(self,squared=False):
    """
    Computes the Frobenius norm of a TT object.

    Args:
        squared (bool, optional): returns the square of the norm if True. Defaults to False.

    Returns:
        torch.tensor: the norm.
    """
    
    if any([c.requires_grad or c.grad_fn is not None for c in self.cores]):
        norm = tn.tensor([[1.0]],dtype = self.cores[0].dtype, device=self.cores[0].device)
        
        if self.__is_ttm:
            for i in range(len(self.__N)):
                norm = tn.einsum('ab,aijm,bijn->mn',norm, self.cores[i], tn.conj(self.cores[i]))
            norm = tn.squeeze(norm)
        else:
                       
            for i in range(len(self.__N)):
                norm = tn.einsum('ab,aim,bin->mn',norm, self.cores[i], tn.conj(self.cores[i]))
            norm = tn.squeeze(norm)
        if squared:
            return norm
        else:
            return tn.sqrt(tn.abs(norm))

    else:        
        d = len(self.cores)

        core_now = self.cores[0]
        for i in range(d-1):
            if self.__is_ttm:
                mode_shape = [core_now.shape[1],core_now.shape[2]]
                core_now = tn.reshape(core_now,[core_now.shape[0]*core_now.shape[1]*core_now.shape[2],-1])
            else:
                mode_shape = [core_now.shape[1]]
                core_now = tn.reshape(core_now,[core_now.shape[0]*core_now.shape[1],-1])
                
            # perform QR
            Qmat, Rmat = QR(core_now)
                 
            # take next core
            core_next = self.cores[i+1]
            shape_next = list(core_next.shape[1:])
            core_next = tn.reshape(core_next,[core_next.shape[0],-1])
            core_next = Rmat @ core_next
            core_next = tn.reshape(core_next,[Qmat.shape[1]]+shape_next)
            
            # update the cores
            
            core_now = core_next
        # core_now holds the last core after the sweep (and the only core if d == 1)
        if squared:
            return tn.linalg.norm(core_now)**2
        else:
            return tn.linalg.norm(core_now)
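Both branches of `norm` rely on standard identities. The autograd branch accumulates <x, x> core by core in an r_k x r_k matrix. The other branch sweeps left to right with QR decompositions: every core except the last becomes left-orthogonal, so the norm of the whole tensor equals the Frobenius norm of the last core. The NumPy sketch below compares both with the norm of the full tensor, including the single-core case. Helper names are hypothetical, and real-valued cores are assumed for the check.

```python
import numpy as np


def tt_full(cores):
    """Contract TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = np.ones((1, 1))
    shape = []
    for c in cores:
        r0, n, r1 = c.shape
        full = (full @ c.reshape(r0, n * r1)).reshape(-1, r1)
        shape.append(n)
    return full.reshape(shape)


def norm_by_contraction(cores):
    """sqrt(<x, x>) accumulated core by core, as in the autograd branch."""
    gram = np.ones((1, 1))
    for c in cores:
        gram = np.einsum('ab,aim,bin->mn', gram, c, c.conj())
    return np.sqrt(abs(gram[0, 0]))


def norm_by_qr_sweep(cores):
    """Left-to-right QR sweep: the norm ends up in the last (non-orthogonal) core."""
    core = cores[0]
    for nxt in cores[1:]:
        r0, n, r1 = core.shape
        _, R = np.linalg.qr(core.reshape(r0 * n, r1))
        core = np.einsum('ij,jkl->ikl', R, nxt)  # push R into the next core
    return np.linalg.norm(core)


rng = np.random.default_rng(1)
cores = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 3), (3, 5, 1)]]
ref = np.linalg.norm(tt_full(cores))
print(np.isclose(norm_by_contraction(cores), ref), np.isclose(norm_by_qr_sweep(cores), ref))

single = [rng.standard_normal((1, 7, 1))]  # d = 1: the loop body never runs
print(np.isclose(norm_by_qr_sweep(single), np.linalg.norm(single[0])))
```

The QR sweep never forms the full tensor. Its cost grows linearly with d, whereas the full array grows exponentially.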
def numpy(self)

Return the full tensor as a numpy.array. In case of a TTM, the result has the shape M1 x M2 x … x Md x N1 x N2 x … x Nd. If the tensor is part of an autograd graph, the conversion raises an error.

Returns

numpy.array
the full tensor in numpy.
def numpy(self):
    """
    Return the full tensor as a numpy.array.
    In case of a TTM, the result has the shape M1 x M2 x ... x Md x N1 x N2 x ... x Nd.
    If it is involved in an AD graph, an error will occur.
    
    Returns:
        numpy.array: the full tensor in numpy.
    """
    return self.full().cpu().numpy()
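The M1 x … x Md x N1 x … x Nd ordering of a TTM arises from permuting the indices after the cores are contracted. The NumPy sketch below assumes TT-matrix cores of shape r_{k-1} x m_k x n_k x r_k, the layout used by `t()` and `to_ttm()` in this module. The helper `ttm_full` is hypothetical and illustrates only one way the full array can be assembled. For rank-1 cores, the result reshaped to a matrix is the Kronecker product of the core matrices.

```python
import numpy as np


def ttm_full(cores):
    """Full array of a TT matrix with cores (r_{k-1}, m_k, n_k, r_k)."""
    full = np.ones((1, 1))
    ms, ns = [], []
    for c in cores:
        r0, m, n, r1 = c.shape
        full = (full @ c.reshape(r0, m * n * r1)).reshape(-1, r1)
        ms.append(m)
        ns.append(n)
    d = len(cores)
    # after the contraction the indices are interleaved: m1, n1, m2, n2, ...
    full = full.reshape([s for pair in zip(ms, ns) for s in pair])
    # reorder to m1, ..., md, n1, ..., nd
    return full.transpose(list(range(0, 2 * d, 2)) + list(range(1, 2 * d, 2)))


rng = np.random.default_rng(2)
A = rng.standard_normal((2, 3))
B = rng.standard_normal((4, 5))
X = ttm_full([A.reshape(1, 2, 3, 1), B.reshape(1, 4, 5, 1)])
print(X.shape)  # (2, 4, 3, 5)
print(np.allclose(X.reshape(8, 15), np.kron(A, B)))  # True
```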
def qtt_to_tens(self, original_shape)

Transforms a tensor from the QTT format back to its original shape. Folding TT matrices is not implemented.

Args

original_shape : list
the original shape.

Raises

InvalidArguments
Original shape must be a list.
ShapeMismatch
Mode sizes do not match.

Returns

TT
the folded tensor.
def qtt_to_tens(self, original_shape):
    """
    Transform a tensor back from QTT.

    Args:
        original_shape (list): the original shape.

    Raises:
        InvalidArguments: Original shape must be a list.
        ShapeMismatch: Mode sizes do not match.

    Returns:
        torchtt.TT: the folded tensor.
    """
    
    if not isinstance(original_shape,list):
        raise InvalidArguments("Original shape must be a list.")

    core = None
    cores_new = []
    
    if self.__is_ttm:
        pass
    else:
        k = 0
        for c in self.cores:
            if core is None:
                core = c
                so_far = core.shape[1]
            else:
                core = tn.einsum('...i,ijk->...jk',core,c)
                so_far *= c.shape[1]
            if so_far==original_shape[k]:
                core = tn.reshape(core,[core.shape[0],-1,core.shape[-1]])
                cores_new.append(core)
                core = None
                k += 1
        if k!= len(original_shape):
            raise ShapeMismatch('Mode sizes do not match.')
    return TT(cores_new)
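Folding back from QTT merges consecutive cores until the product of their mode sizes equals the next entry of `original_shape`. The merged block is then reshaped into a single core. The NumPy sketch below implements the same loop. Its helper names are hypothetical, it raises `ValueError` instead of the module's `ShapeMismatch`, and it adds a guard against having more cores than original modes.

```python
import numpy as np


def tt_full(cores):
    """Contract TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = np.ones((1, 1))
    shape = []
    for c in cores:
        r0, n, r1 = c.shape
        full = (full @ c.reshape(r0, n * r1)).reshape(-1, r1)
        shape.append(n)
    return full.reshape(shape)


def qtt_fold(cores, original_shape):
    """Merge consecutive QTT cores into cores with the original mode sizes."""
    folded, merged, k = [], None, 0
    for c in cores:
        if k == len(original_shape):
            raise ValueError("more cores than original modes")
        merged = c if merged is None else np.einsum('...i,ijk->...jk', merged, c)
        if np.prod(merged.shape[1:-1]) == original_shape[k]:
            # collapse the merged physical indices into one mode
            folded.append(merged.reshape(merged.shape[0], -1, merged.shape[-1]))
            merged, k = None, k + 1
    if k != len(original_shape):
        raise ValueError("mode sizes do not match")
    return folded


rng = np.random.default_rng(3)
ranks = [1, 2, 3, 3, 2, 1]
qtt = [rng.standard_normal((ranks[i], 2, ranks[i + 1])) for i in range(5)]  # 2^5 = 4 * 8
tens = qtt_fold(qtt, [4, 8])
print([c.shape for c in tens])  # [(1, 4, 3), (3, 8, 1)]
print(np.allclose(tt_full(tens), tt_full(qtt).reshape(4, 8)))  # True
```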
def reduce_dims(self, exclude=[])

Removes the modes of size 1 from the TT object in place by absorbing each corresponding core into a neighbouring core. At least one mode should be larger than 1.

Args

exclude : list, optional
Indices to exclude. Defaults to [].
def reduce_dims(self, exclude = []):
    """
    Reduces the size 1 modes of the TT-object.
    At least one mode should be larger than 1.

    Args:
        exclude (list, optional): Indices to exclude. Defaults to [].
    """
    
    # TODO: implement a version that also reduces the rank, by splitting the cores with mode size 1 into two using the SVD.
    
    if self.__is_ttm:
        cores_new = []
        
        for i in range(len(self.__N)):
            
            if self.cores[i].shape[1] == 1 and self.cores[i].shape[2] == 1 and not i in exclude:
                if self.cores[i].shape[0] > self.cores[i].shape[3] or i == len(self.__N)-1:
                    # multiply to the left
                    if len(cores_new) > 0:
                        cores_new[-1] = tn.einsum('ijok,kl->ijol',cores_new[-1], self.cores[i][:,0,0,:])
                    else: 
                        # there is no core to the left. Multiply right.
                        if i != len(self.__N)-1:
                            self.cores[i+1] = tn.einsum('ij,jkml->ikml', self.cores[i][:,0,0,:],self.cores[i+1])
                        else:
                            cores_new.append(self.cores[i])
                        
                else:
                    # multiply to the right. Set the carry 
                    self.cores[i+1] = tn.einsum('ij,jkml->ikml',self.cores[i][:,0,0,:],self.cores[i+1])
                    
            else:
                cores_new.append(self.cores[i])
                
        # update the cores and ranks and shape
        self.__N = []
        self.__M = []
        self.__R = [1]
        for i in range(len(cores_new)):
            self.__N.append(cores_new[i].shape[2])
            self.__M.append(cores_new[i].shape[1])
            self.__R.append(cores_new[i].shape[3])
        self.cores = cores_new
    else:
        cores_new = []
        
        for i in range(len(self.__N)):
            
            if self.cores[i].shape[1] == 1 and not i in exclude:
                if self.cores[i].shape[0] > self.cores[i].shape[2] or i == len(self.__N)-1:
                    # multiply to the left
                    if len(cores_new) > 0:
                        cores_new[-1] = tn.einsum('ijk,kl->ijl',cores_new[-1], self.cores[i][:,0,:])
                    else: 
                        # there is no core to the left. Multiply right.
                        if i != len(self.__N)-1:
                            self.cores[i+1] = tn.einsum('ij,jkl->ikl', self.cores[i][:,0,:],self.cores[i+1])
                        else:
                            cores_new.append(self.cores[i])
                            
                        
                else:
                    # multiply to the right. Set the carry 
                    self.cores[i+1] = tn.einsum('ij,jkl->ikl',self.cores[i][:,0,:],self.cores[i+1])
                    
            else:
                cores_new.append(self.cores[i])
        
        
        # update the cores and ranks and shape
        self.__N = []
        self.__R = [1]
        for i in range(len(cores_new)):
            self.__N.append(cores_new[i].shape[1])
            self.__R.append(cores_new[i].shape[2])
        self.cores = cores_new
                
                
    self.shape = [ (m,n) for m,n in zip(self.__M,self.__N) ] if self.__is_ttm else [n for n in self.N] 
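A core whose mode size is 1 is just an r_{k-1} x r_k matrix, so it can be multiplied into a neighbouring core without changing the represented tensor. The NumPy sketch below uses a hypothetical helper. It always absorbs into the left neighbour when one exists and otherwise carries the matrix to the right. `reduce_dims` additionally chooses the side by comparing the two ranks, which this simplification omits.

```python
import numpy as np


def tt_full(cores):
    """Contract TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = np.ones((1, 1))
    shape = []
    for c in cores:
        r0, n, r1 = c.shape
        full = (full @ c.reshape(r0, n * r1)).reshape(-1, r1)
        shape.append(n)
    return full.reshape(shape)


def drop_unit_modes(cores, exclude=()):
    """Absorb cores with mode size 1 (indices not in `exclude`) into neighbours."""
    out, carry = [], None
    for i, c in enumerate(cores):
        if carry is not None:  # matrix left over from a dropped leading core
            c = np.einsum('ij,jkl->ikl', carry, c)
            carry = None
        if c.shape[1] == 1 and i not in exclude:
            mat = c[:, 0, :]
            if out:
                out[-1] = np.einsum('ijk,kl->ijl', out[-1], mat)
            elif i < len(cores) - 1:
                carry = mat
            else:
                out.append(c)  # nothing left to absorb it into
        else:
            out.append(c)
    return out


rng = np.random.default_rng(4)
cores = [rng.standard_normal(s) for s in [(1, 1, 2), (2, 3, 2), (2, 1, 3), (3, 4, 1)]]
reduced = drop_unit_modes(cores)
print([c.shape for c in reduced])  # [(1, 3, 3), (3, 4, 1)]
print(np.allclose(tt_full(reduced), tt_full(cores).reshape(3, 4)))  # True
```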
def round(self, eps=1e-12, rmax=9223372036854775807)

Rounds (recompresses) the TT object to the relative accuracy eps, with the TT ranks capped at rmax.

Args

eps : float, optional
the relative accuracy. Defaults to 1e-12.
rmax : int, optional
the maximum rank. Defaults to the maximum possible integer.

Returns

TT
the result.
def round(self, eps=1e-12, rmax = sys.maxsize): 
    """
    Implements the rounding operations within a given tolerance epsilon.
    The maximum rank is also provided.

    Args:
        eps (float, optional): the relative accuracy. Defaults to 1e-12.
        rmax (int, optional): the maximum rank. Defaults to the maximum possible integer.

    Returns:
        torchtt.TT: the result.
    """
    
    # rmax is not list
    if not isinstance(rmax,list):
        rmax = [1] + len(self.__N)*[rmax] + [1]
        
    # call the round function
    tt_cores, R = round_tt(self.cores, self.__R.copy(), eps, rmax,self.__is_ttm)
    # creates a new TT and return it
    T = TT(tt_cores)
           
    return T
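A common way to implement TT-rounding is to first orthogonalize the cores with QR decompositions and then truncate one SVD per bond, splitting the allowed error eps·||x|| evenly over the d-1 truncations. The NumPy sketch below follows that textbook scheme. It is not taken from `round_tt`, whose internals are not shown here, and the helper names are hypothetical. Rounding the exact sum x + x, whose ranks are doubled, recovers the ranks of x.

```python
import numpy as np


def tt_full(cores):
    """Contract TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = np.ones((1, 1))
    shape = []
    for c in cores:
        r0, n, r1 = c.shape
        full = (full @ c.reshape(r0, n * r1)).reshape(-1, r1)
        shape.append(n)
    return full.reshape(shape)


def tt_round(cores, eps=1e-12, rmax=None):
    """TT-rounding: QR sweep right-to-left, then truncated SVDs left-to-right."""
    cores = [c.copy() for c in cores]
    d = len(cores)
    # 1) make cores 1..d-1 right-orthogonal; the whole norm moves into cores[0]
    for k in range(d - 1, 0, -1):
        r0, n, r1 = cores[k].shape
        Q, R = np.linalg.qr(cores[k].reshape(r0, n * r1).T)
        cores[k] = Q.T.reshape(-1, n, r1)
        cores[k - 1] = np.einsum('ijk,kl->ijl', cores[k - 1], R.T)
    # 2) split the allowed error evenly over the d-1 truncations
    delta = eps / np.sqrt(max(d - 1, 1)) * np.linalg.norm(cores[0])
    for k in range(d - 1):
        r0, n, r1 = cores[k].shape
        U, s, Vt = np.linalg.svd(cores[k].reshape(r0 * n, r1), full_matrices=False)
        tail = np.sqrt(np.cumsum(s[::-1] ** 2))[::-1]  # tail[j] = ||s[j:]||
        r = max(1, int(np.sum(tail > delta)))  # smallest rank with error <= delta
        if rmax is not None:
            r = min(r, rmax)
        cores[k] = U[:, :r].reshape(r0, n, r)
        cores[k + 1] = np.einsum('ij,jkl->ikl', s[:r, None] * Vt[:r], cores[k + 1])
    return cores


def tt_add(a, b):
    """Exact TT sum: ranks add up, so the result is usually over-parameterized."""
    out = []
    for k, (p, q) in enumerate(zip(a, b)):
        if k == 0:
            out.append(np.concatenate([p, q], axis=2))
        elif k == len(a) - 1:
            out.append(np.concatenate([p, q], axis=0))
        else:
            c = np.zeros((p.shape[0] + q.shape[0], p.shape[1], p.shape[2] + q.shape[2]))
            c[:p.shape[0], :, :p.shape[2]] = p
            c[p.shape[0]:, :, p.shape[2]:] = q
            out.append(c)
    return out


rng = np.random.default_rng(5)
x = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 3), (3, 5, 1)]]
y = tt_add(x, x)  # represents 2x with ranks [1, 4, 6, 1]
z = tt_round(y, eps=1e-10)
print([c.shape[2] for c in z[:-1]])  # [2, 3]
print(np.allclose(tt_full(z), 2 * tt_full(x)))  # True
```

The QR sweep is what makes each truncation error exact in the Frobenius norm. Without it, discarding singular values of one unfolding would not bound the error of the whole tensor.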
def sum(self, index=None)

Contracts a tensor in the TT format along the given indices (i.e. sums over them) and returns the result in the TT format. If no index list is given, the sum over all indices is performed and a scalar torch.tensor is returned.

Examples

a = torchtt.ones([3,4,5,6,7])
print(a.sum()) 
print(a.sum([0,2,4]))
print(a.sum([1,2]))
print(a.sum([0,1,2,3,4]))

Args

index : int | list[int] | None, optional
the indices along which the summation is performed. None selects all of them. Defaults to None.

Raises

InvalidArguments
Invalid index.

Returns

TT or torch.tensor
the result.

def sum(self,index = None):
    """
    Contracts a tensor in the TT format along the given indices and returns the resulting tensor in the TT format.
    If no index list is given, the sum over all indices is performed.

    Examples:
        ```
        a = torchtt.ones([3,4,5,6,7])
        print(a.sum()) 
        print(a.sum([0,2,4]))
        print(a.sum([1,2]))
        print(a.sum([0,1,2,3,4]))
        ```
        
    Args:
        index (int | list[int] | None, optional): the indices along which the summation is performed. None selects all of them. Defaults to None.

    Raises:
        InvalidArguments: Invalid index.

    Returns:
        torchtt.TT/torch.tensor: the result.
    """
    
    if index is not None and isinstance(index,int):
        index = [index]
    if not isinstance(index,list) and index is not None:
        raise InvalidArguments('Invalid index.')

    if index is None: 
        # the case we need to sum over all modes
        if self.__is_ttm:
            C = tn.sum(self.cores[0],[0,1,2])
            for i in range(1,len(self.__N)):
                C = tn.sum(tn.einsum('i,ijkl->jkl',C,self.cores[i]),[0,1])
            S = tn.sum(C)
        else:
            C = tn.sum(self.cores[0],[0,1])
            for i in range(1,len(self.__N)):
                C = tn.sum(tn.einsum('i,ijk->jk',C,self.cores[i]),0)
            S = tn.sum(C)
    else:
        # we return the TT-tensor with summed indices
        cores = []
        
        if self.__is_ttm:
            tmp = [1,2]
        else:
            tmp = [1]
            
        for i in range(len(self.__N)):
            if i in index:
                C = tn.sum(self.cores[i], tmp, keepdim = True)
                cores.append(C)
            else:
                cores.append(self.cores[i])
                    
        S = TT(cores)
        S.reduce_dims()
        if len(S.cores)==1 and tn.numel(S.cores[0])==1:
            S = tn.squeeze(S.cores[0])
    return S
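Summing over mode k of a TT tensor amounts to summing its k-th core over the mode index. This is why the method sums each selected core with `keepdim=True` and then calls `reduce_dims()`. Summing over all modes collapses the chain of rank matrices into a scalar. The NumPy sketch below uses hypothetical helper names and reproduces the first example above.

```python
import numpy as np


def tt_full(cores):
    """Contract TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = np.ones((1, 1))
    shape = []
    for c in cores:
        r0, n, r1 = c.shape
        full = (full @ c.reshape(r0, n * r1)).reshape(-1, r1)
        shape.append(n)
    return full.reshape(shape)


def tt_sum_all(cores):
    """Sum over every index: sum each core over its mode, then contract the chain."""
    v = np.ones(1)
    for c in cores:
        v = v @ c.sum(axis=1)  # (r_{k-1},) @ (r_{k-1}, r_k) -> (r_k,)
    return v[0]


def tt_sum_modes(cores, index):
    """Sum over the modes in `index` only; those cores keep a mode of size 1."""
    return [c.sum(axis=1, keepdims=True) if k in index else c for k, c in enumerate(cores)]


ones = [np.ones((1, n, 1)) for n in [3, 4, 5, 6, 7]]
print(tt_sum_all(ones))  # 2520.0 = 3*4*5*6*7

rng = np.random.default_rng(6)
cores = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 3), (3, 5, 1)]]
partial = tt_sum_modes(cores, [0, 2])
print(np.allclose(tt_full(partial), tt_full(cores).sum(axis=(0, 2), keepdims=True)))  # True
```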
def t(self)

Returns the transpose of a given TT matrix.

Returns

TT
the transpose.

Raises

InvalidArguments
Has to be TT matrix.
def t(self):
    """
    Returns the transpose of a given TT matrix.
            
                
    Returns:
        torchtt.TT: the transpose.
        
    Raises:
        InvalidArguments: Has to be TT matrix.
    """ 
    if not self.__is_ttm:
        raise InvalidArguments('Has to be TT matrix.')
        
    cores_new = [tn.permute(c,[0,2,1,3]) for c in self.cores]
    
    return TT(cores_new)
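The row and column indices of a TT matrix live in separate core dimensions, so transposition needs no arithmetic. Permuting every core with [0, 2, 1, 3] transposes the whole operator and leaves the ranks unchanged. The NumPy check below assumes cores of shape r_{k-1} x m_k x n_k x r_k and uses hypothetical helper names.

```python
import numpy as np


def ttm_full(cores):
    """Full array (m1..md, n1..nd) of a TT matrix with cores (r_{k-1}, m_k, n_k, r_k)."""
    full = np.ones((1, 1))
    ms, ns = [], []
    for c in cores:
        r0, m, n, r1 = c.shape
        full = (full @ c.reshape(r0, m * n * r1)).reshape(-1, r1)
        ms.append(m)
        ns.append(n)
    d = len(cores)
    full = full.reshape([s for pair in zip(ms, ns) for s in pair])  # m1, n1, m2, n2, ...
    return full.transpose(list(range(0, 2 * d, 2)) + list(range(1, 2 * d, 2)))


def ttm_transpose(cores):
    """Swap the row and column index of every core."""
    return [c.transpose(0, 2, 1, 3) for c in cores]


rng = np.random.default_rng(7)
cores = [rng.standard_normal(s) for s in [(1, 2, 3, 2), (2, 4, 5, 1)]]
A = ttm_full(cores).reshape(2 * 4, 3 * 5)  # rows (m1, m2), columns (n1, n2)
At = ttm_full(ttm_transpose(cores)).reshape(3 * 5, 2 * 4)
print(np.allclose(At, A.T))  # True
```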
def to(self, device=None, dtype=None)

Returns a copy of the TT instance on the given device and with the given dtype.

Args

device : torch.device, optional
The desired device. If None is provided, the device is not changed. Defaults to None.
dtype : torch.dtype, optional
The desired dtype (torch.float64, torch.float32,…). If None is provided the dtype is not changed. Defaults to None.
def to(self, device = None, dtype = None):
    """
    Returns a copy of the TT instance on the given device and with the given dtype.

    Args:
        device (torch.device, optional): The desired device. If None is provided, the device is not changed. Defaults to None.
        dtype (torch.dtype, optional): The desired dtype (torch.float64, torch.float32,...). If None is provided the dtype is not changed. Defaults to None.
    """
    return TT( [ c.to(device=device,dtype=dtype) for c in self.cores])
def to_qtt(self, eps=1e-12, mode_size=2, rmax=9223372036854775807)

Converts a tensor to the QTT format: N1 x N2 x … x Nd -> mode_size x mode_size x … x mode_size. Each mode size should be a power of mode_size. The tensor in QTT can be converted back using the qtt_to_tens() method.

Examples

x = torchtt.random([16,8,64,128],[1,2,10,12,1])
x_qtt = x.to_qtt()
print(x_qtt)
xf = x_qtt.qtt_to_tens(x.N) # a TT-rounding is recommended.

Args

eps : float, optional
the accuracy. Defaults to 1e-12.
mode_size : int, optional
the size of the modes. Defaults to 2.
rmax : int, optional
the maximum rank. Defaults to the maximum possible integer.

Raises

ShapeMismatch
Only square TTM can be transformed to QTT.
ShapeMismatch
Reshaping error: check if the dimensions are powers of the desired mode size.

Returns

TT
the resulting reshaped tensor.
def to_qtt(self, eps = 1e-12, mode_size = 2, rmax = sys.maxsize):
    """
    Converts a tensor to the QTT format: N1 x N2 x ... x Nd -> mode_size x mode_size x ... x mode_size.
    Each mode size should be a power of mode_size.
    The tensor in QTT can be converted back using the qtt_to_tens() method.

    Examples:
        ```
        x = torchtt.random([16,8,64,128],[1,2,10,12,1])
        x_qtt = x.to_qtt()
        print(x_qtt)
        xf = x_qtt.qtt_to_tens(x.N) # a TT-rounding is recommended.
        ```
        
    Args:
        eps (float, optional): the accuracy. Defaults to 1e-12.
        mode_size (int, optional): the size of the modes. Defaults to 2.
        rmax (int, optional): the maximum rank. Defaults to the maximum possible integer.
        

    Raises:
        ShapeMismatch: Only square TTM can be transformed to QTT.
        ShapeMismatch: Reshaping error: check if the dimensions are powers of the desired mode size.

    Returns:
        torchtt.TT: the resulting reshaped tensor.
                   
    """
   
    cores_new = []
    if self.__is_ttm:
        shape_new = []
        for i in range(len(self.__N)):
            if self.__N[i]!=self.__M[i]:
                raise ShapeMismatch('Only square TTM can be transformed to QTT.')
            if self.__N[i]==mode_size**int(math.log(self.N[i],mode_size)):
                shape_new += [(mode_size,mode_size)]*int(math.log(self.__N[i],mode_size))
            else:
                raise ShapeMismatch('Reshaping error: check if the dimensions are powers of the desired mode size:\r\ncore size '+str(list(self.cores[i].shape))+' cannot be reshaped.')
            
        result = reshape(self, shape_new, eps, rmax)
    else:
        for core in self.cores:
            if int(math.log(core.shape[1],mode_size))>1:
                Nnew = [core.shape[0]*mode_size]+[mode_size]*(int(math.log(core.shape[1],mode_size))-2)+[core.shape[2]*mode_size]
                try:
                    core = tn.reshape(core,Nnew)
                except RuntimeError:
                    raise ShapeMismatch('Reshaping error: check if the dimensions are powers of the desired mode size:\r\ncore size '+str(list(core.shape))+' cannot be reshaped to '+str(Nnew))
                cores,_ = to_tt(core,Nnew,eps,rmax,is_sparse=False)
                cores_new.append(tn.reshape(cores[0],[-1,mode_size,cores[0].shape[-1]]))
                cores_new += cores[1:-1]
                cores_new.append(tn.reshape(cores[-1],[cores[-1].shape[0],mode_size,-1]))
            else: 
                cores_new.append(core)
        result = TT(cores_new)
        
    return result
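For a TT tensor, `to_qtt` reshapes every core whose mode size is mode_size**L (L > 1) into an L-way array and decomposes it with `to_tt`, so each such core is replaced by L cores of mode size `mode_size`. The NumPy sketch below performs the same split with successive truncated SVDs (a plain TT-SVD). `split_core` is a hypothetical helper, and its relative truncation rule is a simplification, not the rule used by `to_tt`. It applies round() to the logarithm rather than int(), because a floating-point logarithm of an exact power can land just below the integer (for example, `math.log(1000, 10)` evaluates to 2.9999999999999996).

```python
import math

import numpy as np


def split_core(core, mode_size=2, eps=1e-12):
    """Split a core (r0, mode_size**L, r1) into L cores with mode size `mode_size`."""
    r0, n, r1 = core.shape
    L = round(math.log(n, mode_size))  # round() avoids truncating e.g. 2.999... to 2
    if mode_size ** L != n:
        raise ValueError(f"mode size {n} is not a power of {mode_size}")
    if L <= 1:
        return [core.copy()]
    out, rest, r_left = [], core.reshape(r0, n * r1), r0
    for _ in range(L - 1):
        # peel off one factor of the mode: rows (r_left, i_j), columns (remaining modes, r1)
        U, s, Vt = np.linalg.svd(rest.reshape(r_left * mode_size, -1), full_matrices=False)
        r = max(1, int(np.sum(s > eps * s[0])))  # simple relative truncation
        out.append(U[:, :r].reshape(r_left, mode_size, r))
        rest, r_left = s[:r, None] * Vt[:r], r
    out.append(rest.reshape(r_left, mode_size, r1))
    return out


rng = np.random.default_rng(8)
core = rng.standard_normal((2, 8, 3))
parts = split_core(core)
print([p.shape[1] for p in parts])  # [2, 2, 2]

merged = parts[0]
for p in parts[1:]:
    merged = np.einsum('...i,ijk->...jk', merged, p)
print(np.allclose(merged.reshape(2, 8, 3), core))  # True
```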
def to_ttm(self)

Converts a TT-tensor to the TT-matrix format. If the tensor has the shape N1 x … x Nd, the result has the shape N1 x … x Nd x 1 x … x 1.

Returns

TT
the result
def to_ttm(self):
    """
    Converts a TT-tensor to the TT-matrix format. If the tensor has the shape N1 x ... x Nd, the result has the shape 
    N1 x ... x Nd x 1 x ... x 1.

    Returns:
        torchtt.TT: the result
    """

    cores_new = [tn.reshape(c,(c.shape[0],c.shape[1],1,c.shape[2])) for c in self.cores]
    return TT(cores_new)
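Converting to a TT matrix only inserts a column index of size 1 into every core, so no arithmetic or recompression is needed. The NumPy sketch below uses hypothetical helper names and takes the TT-matrix core layout r_{k-1} x m_k x n_k x r_k from the reshape in `to_ttm`. It checks that the result has shape N1 x … x Nd x 1 x … x 1 and the same entries as the original tensor.

```python
import numpy as np


def tt_full(cores):
    """Contract TT cores of shape (r_{k-1}, n_k, r_k) into the full tensor."""
    full = np.ones((1, 1))
    shape = []
    for c in cores:
        r0, n, r1 = c.shape
        full = (full @ c.reshape(r0, n * r1)).reshape(-1, r1)
        shape.append(n)
    return full.reshape(shape)


def ttm_full(cores):
    """Full array (m1..md, n1..nd) of a TT matrix with cores (r_{k-1}, m_k, n_k, r_k)."""
    full = np.ones((1, 1))
    ms, ns = [], []
    for c in cores:
        r0, m, n, r1 = c.shape
        full = (full @ c.reshape(r0, m * n * r1)).reshape(-1, r1)
        ms.append(m)
        ns.append(n)
    d = len(cores)
    full = full.reshape([s for pair in zip(ms, ns) for s in pair])  # m1, n1, m2, n2, ...
    return full.transpose(list(range(0, 2 * d, 2)) + list(range(1, 2 * d, 2)))


def tt_to_ttm(cores):
    """View each core (r, n, r') as a TT-matrix core (r, n, 1, r')."""
    return [c.reshape(c.shape[0], c.shape[1], 1, c.shape[2]) for c in cores]


rng = np.random.default_rng(9)
cores = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 1)]]
X = ttm_full(tt_to_ttm(cores))
print(X.shape)  # (3, 4, 1, 1)
print(np.allclose(X.reshape(3, 4), tt_full(cores)))  # True
```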