Module torchtt.manifold

Manifold gradient module.

Expand source code
"""
Manifold gradient module.

"""

import torch as tn
from torchtt._decomposition import mat_to_tt, to_tt, lr_orthogonal, round_tt, rl_orthogonal
from . import TT
from torchtt.errors import *

def _delta2cores(tt_cores, R, Sds, is_ttm = False, ortho = None):
    """
    Convert the delta notation to TT.
    Implements Algorithm 5.1 from "AUTOMATIC DIFFERENTIATION FOR RIEMANNIAN OPTIMIZATION ON LOW-RANK MATRIX AND TENSOR-TRAIN MANIFOLDS".

    With U_k = l_cores[k] and V_k = r_cores[k], the returned cores are assembled blockwise as

        first core : [ Sds[0]  U_0 ]      (joined along the last axis)

        k-th core  : [ V_k     0   ]      (rows joined along the first axis,
                     [ Sds[k]  U_k ]       columns along the last axis)

        last core  : [ V_last   ]         (joined along the first axis)
                     [ Sds[-1]  ]

    Args:
        tt_cores (list[torch.tensor]): the TT cores.
        R (list[int]): the rank of the tensor. Only used when `ortho` is None.
        Sds (list[torch.tensor]): deltas.
        is_ttm (bool, optional): is TT matrix or not. Defaults to False.
        ortho (list[list[torch.tensor]], optional): the left and right orthogonal cores of tt_cores. If None, they are computed from tt_cores. Defaults to None.

    Returns:
        list[torch.tensor]: the resulting TT cores.
    """

    if ortho is None:
        l_cores, _ = lr_orthogonal(tt_cores, R, is_ttm)
        r_cores, _ = rl_orthogonal(tt_cores, R, is_ttm)
    else:
        l_cores = ortho[0]
        r_cores = ortho[1]

    # the trailing (right rank) axis is the last one: index 2 for TT tensors, 3 for TT matrices
    rank_axis = 3 if is_ttm else 2

    # first
    cores_new = [tn.cat((Sds[0], l_cores[0]), rank_axis)]
    # 2...d-1
    for k in range(1, len(tt_cores) - 1):
        zero_block = tn.zeros(r_cores[k].shape, dtype = l_cores[0].dtype, device = l_cores[0].device)
        up = tn.cat((r_cores[k], zero_block), rank_axis)
        down = tn.cat((Sds[k], l_cores[k]), rank_axis)
        cores_new.append(tn.cat((up, down), 0))
    # last
    cores_new.append(tn.cat((r_cores[-1], Sds[-1]), 0))

    return cores_new

def riemannian_gradient(x,func):
    """
    Compute the Riemannian gradient using AD.

    The point x is written in delta notation (first delta equal to the first right-orthogonal core,
    all other deltas zero), `func` is evaluated on the resulting TT and differentiated with respect
    to the deltas. The first d-1 gradients are then made orthogonal to the corresponding
    left-orthogonal cores before being converted back to TT cores.

    Args:
        x (torchtt.TT): the point on the manifold where the gradient is computed.
        func (callable): function that has to be differentiated. The function takes as only argument `torchtt.TT` instances. Its result must support `backward()` without arguments.

    Returns:
        torchtt.TT: the gradient projected on the tangent space of x.
    """

    is_ttm = x.is_ttm
    d = len(x.N)

    # use the ranks returned by the orthogonalization (same convention as riemannian_projection)
    l_cores, R = lr_orthogonal(x.cores, x.R, is_ttm)
    r_cores, _ = rl_orthogonal(l_cores, R, is_ttm)

    # the orthogonal cores are constants of the differentiation: cut them from any autograd
    # graph attached to x.cores so that backward() only populates the deltas below
    l_cores = [core.detach() for core in l_cores]
    r_cores = [core.detach() for core in r_cores]

    # deltas representing x itself; they are fresh leaf tensors so that `.grad` is populated
    # and r_cores[0] is not modified in place
    Rs = [r_cores[0].clone().requires_grad_(True)]
    Rs += [tn.zeros_like(l_cores[i]).requires_grad_(True) for i in range(1, d)]

    # AD part
    Ghats = _delta2cores(x.cores, R, Rs, is_ttm = is_ttm, ortho = [l_cores, r_cores])
    fval = func(TT(Ghats))
    fval.backward()

    Sds = [delta.grad for delta in Rs]

    # compute Sdeltas: remove from each of the first d-1 gradients its component
    # in the column space of the matricized left-orthogonal core
    for k in range(d-1):
        rank_right = l_cores[k].shape[-1]
        UL = tn.reshape(l_cores[k], [-1, rank_right])
        D = tn.reshape(Sds[k], [-1, rank_right])
        D = D - UL @ (UL.T @ D)
        Sds[k] = tn.reshape(D, l_cores[k].shape)

    # delta to TT
    grad_cores = _delta2cores(x.cores, R, Sds, is_ttm, ortho = [l_cores, r_cores])
    return TT(grad_cores)
        
def riemannian_projection(Xspace,z):
    """
    Project the tensor z onto the tangent space defined at Xspace.

    Args:
        Xspace (torchtt.TT): the point defining the tangent space onto which z is projected.
        z (torchtt.TT): the tensor that should be projected.

    Raises:
        IncompatibleTypes: Both must be of same type.

    Returns:
        torchtt.TT: the projection.
    """

    if Xspace.is_ttm != z.is_ttm:
        raise IncompatibleTypes('Both must be of same type.')

    is_ttm = Xspace.is_ttm
    d = len(Xspace.N)

    l_cores, R = lr_orthogonal(Xspace.cores, Xspace.R, is_ttm)
    r_cores, _ = rl_orthogonal(l_cores, R, is_ttm)

    # einsum signatures; a TT matrix core carries two mode indices (i,j), a TT tensor core one (i)
    if is_ttm:
        sweep_left = 'rs,rijR,sijS->RS'
        sweep_right = 'RS,rijR,sijS->rs'
        apply_left = 'rs,sijS->rijS'
        apply_core = 'rijR,RS->rijS'
        apply_right = 'rijS,RS->rijR'
    else:
        sweep_left = 'rs,riR,siS->RS'
        sweep_right = 'RS,riR,siS->rs'
        apply_left = 'rs,siS->riS'
        apply_core = 'riR,RS->riS'
        apply_right = 'riS,RS->riR'

    # 1 x 1 seed for both sweeps (only read, never modified)
    one = tn.ones((1,1), dtype = Xspace.cores[0].dtype, device = Xspace.cores[0].device)

    # Pleft[k]: contraction of the left-orthogonal cores 0..k with the cores 0..k of z, size rk x sk
    Pleft = []
    tmp = one
    for k in range(d-1):
        tmp = tn.einsum(sweep_left, tmp, l_cores[k], z.cores[k])
        Pleft.append(tmp)

    # Pright[k]: contraction of the right-orthogonal cores k+1..d-1 with the cores k+1..d-1 of z, size rk x sk
    Pright = []
    tmp = one
    for k in range(d-1,0,-1):
        tmp = tn.einsum(sweep_right, tmp, r_cores[k], z.cores[k])
        Pright.append(tmp)
    Pright = Pright[::-1]

    # compute elements of the tangent space
    Sds = []
    for k in range(d):
        L = one if k == 0 else Pleft[k-1]
        left_applied = tn.einsum(apply_left, L, z.cores[k])
        if k == d-1:
            Sds.append(left_applied)
        else:
            # Pleft[k] is exactly the contraction of l_cores[k] with left_applied,
            # so the stored value is reused instead of being recomputed
            along_core = tn.einsum(apply_core, l_cores[k], Pleft[k])
            Sds.append(tn.einsum(apply_right, left_applied - along_core, Pright[k]))

    # convert Sds to TT
    grad_cores = _delta2cores(Xspace.cores, R, Sds, is_ttm, ortho = [l_cores, r_cores])

    return TT(grad_cores)

Functions

def riemannian_gradient(x, func)

Compute the Riemannian gradient using AD.

Args

x : TT
the point on the manifold where the gradient is computed.
func : callable
the function to be differentiated. It takes a single TT instance as its only argument.

Returns

TT
the gradient projected on the tangent space of x.
Expand source code
def riemannian_gradient(x,func):
    """
    Compute the Riemannian gradient using AD.

    The point x is written in delta notation (first delta equal to the first right-orthogonal core,
    all other deltas zero), `func` is evaluated on the resulting TT and differentiated with respect
    to the deltas. The first d-1 gradients are then made orthogonal to the corresponding
    left-orthogonal cores before being converted back to TT cores.

    Args:
        x (torchtt.TT): the point on the manifold where the gradient is computed.
        func (callable): function that has to be differentiated. The function takes as only argument `torchtt.TT` instances. Its result must support `backward()` without arguments.

    Returns:
        torchtt.TT: the gradient projected on the tangent space of x.
    """

    is_ttm = x.is_ttm
    d = len(x.N)

    # use the ranks returned by the orthogonalization (same convention as riemannian_projection)
    l_cores, R = lr_orthogonal(x.cores, x.R, is_ttm)
    r_cores, _ = rl_orthogonal(l_cores, R, is_ttm)

    # the orthogonal cores are constants of the differentiation: cut them from any autograd
    # graph attached to x.cores so that backward() only populates the deltas below
    l_cores = [core.detach() for core in l_cores]
    r_cores = [core.detach() for core in r_cores]

    # deltas representing x itself; they are fresh leaf tensors so that `.grad` is populated
    # and r_cores[0] is not modified in place
    Rs = [r_cores[0].clone().requires_grad_(True)]
    Rs += [tn.zeros_like(l_cores[i]).requires_grad_(True) for i in range(1, d)]

    # AD part
    Ghats = _delta2cores(x.cores, R, Rs, is_ttm = is_ttm, ortho = [l_cores, r_cores])
    fval = func(TT(Ghats))
    fval.backward()

    Sds = [delta.grad for delta in Rs]

    # compute Sdeltas: remove from each of the first d-1 gradients its component
    # in the column space of the matricized left-orthogonal core
    for k in range(d-1):
        rank_right = l_cores[k].shape[-1]
        UL = tn.reshape(l_cores[k], [-1, rank_right])
        D = tn.reshape(Sds[k], [-1, rank_right])
        D = D - UL @ (UL.T @ D)
        Sds[k] = tn.reshape(D, l_cores[k].shape)

    # delta to TT
    grad_cores = _delta2cores(x.cores, R, Sds, is_ttm, ortho = [l_cores, r_cores])
    return TT(grad_cores)
def riemannian_projection(Xspace, z)

Project the tensor z onto the tangent space defined at Xspace.

Args

Xspace : TT
the point defining the tangent space onto which the tensor is projected.
z : TT
the tensor that should be projected.

Raises

IncompatibleTypes
Both must be of same type.

Returns

TT
the projection.
Expand source code
def riemannian_projection(Xspace,z):
    """
    Project the tensor z onto the tangent space defined at Xspace.

    Args:
        Xspace (torchtt.TT): the point defining the tangent space onto which z is projected.
        z (torchtt.TT): the tensor that should be projected.

    Raises:
        IncompatibleTypes: Both must be of same type.

    Returns:
        torchtt.TT: the projection.
    """

    if Xspace.is_ttm != z.is_ttm:
        raise IncompatibleTypes('Both must be of same type.')

    is_ttm = Xspace.is_ttm
    d = len(Xspace.N)

    l_cores, R = lr_orthogonal(Xspace.cores, Xspace.R, is_ttm)
    r_cores, _ = rl_orthogonal(l_cores, R, is_ttm)

    # einsum signatures; a TT matrix core carries two mode indices (i,j), a TT tensor core one (i)
    if is_ttm:
        sweep_left = 'rs,rijR,sijS->RS'
        sweep_right = 'RS,rijR,sijS->rs'
        apply_left = 'rs,sijS->rijS'
        apply_core = 'rijR,RS->rijS'
        apply_right = 'rijS,RS->rijR'
    else:
        sweep_left = 'rs,riR,siS->RS'
        sweep_right = 'RS,riR,siS->rs'
        apply_left = 'rs,siS->riS'
        apply_core = 'riR,RS->riS'
        apply_right = 'riS,RS->riR'

    # 1 x 1 seed for both sweeps (only read, never modified)
    one = tn.ones((1,1), dtype = Xspace.cores[0].dtype, device = Xspace.cores[0].device)

    # Pleft[k]: contraction of the left-orthogonal cores 0..k with the cores 0..k of z, size rk x sk
    Pleft = []
    tmp = one
    for k in range(d-1):
        tmp = tn.einsum(sweep_left, tmp, l_cores[k], z.cores[k])
        Pleft.append(tmp)

    # Pright[k]: contraction of the right-orthogonal cores k+1..d-1 with the cores k+1..d-1 of z, size rk x sk
    Pright = []
    tmp = one
    for k in range(d-1,0,-1):
        tmp = tn.einsum(sweep_right, tmp, r_cores[k], z.cores[k])
        Pright.append(tmp)
    Pright = Pright[::-1]

    # compute elements of the tangent space
    Sds = []
    for k in range(d):
        L = one if k == 0 else Pleft[k-1]
        left_applied = tn.einsum(apply_left, L, z.cores[k])
        if k == d-1:
            Sds.append(left_applied)
        else:
            # Pleft[k] is exactly the contraction of l_cores[k] with left_applied,
            # so the stored value is reused instead of being recomputed
            along_core = tn.einsum(apply_core, l_cores[k], Pleft[k])
            Sds.append(tn.einsum(apply_right, left_applied - along_core, Pright[k]))

    # convert Sds to TT
    grad_cores = _delta2cores(Xspace.cores, R, Sds, is_ttm, ortho = [l_cores, r_cores])

    return TT(grad_cores)