Module torchtt.interpolate
Implements the cross approximation methods (DMRG).
Expand source code
"""
Implements the cross approximation methods (DMRG).
"""
import torch as tn
import numpy as np
import torchtt
import datetime
from torchtt._decomposition import QR, SVD, rank_chop, lr_orthogonal
from torchtt._iterative_solvers import BiCGSTAB_reset, gmres_restart
import opt_einsum as oe
def _LU(M):
"""
Performs an LU decomposition and returns L, U and a permutation vector P.
Args:
M (torch.tensor): input matrix.
Returns:
tuple[torch.tensor,torch.tensor,torch.tensor]: L, U, P
"""
LU,P = tn.lu(M)
P,L,U = tn.lu_unpack(LU,P) # P transpose or not transpose?
P = P@tn.reshape(tn.arange(P.shape[1],dtype=P.dtype,device=P.device),[-1,1])
# P = tn.reshape(tn.arange(P.shape[1],dtype=P.dtype,device=P.device),[1,-1]) @ P
return L, U, tn.squeeze(P).to(tn.int64)
def _max_matrix(M):
values, indices = M.flatten().topk(1)
indices = [np.unravel_index(i, M.shape) for i in indices]
return values, indices
def _maxvol(M):
"""
Finds the indices of a quasi-maximum-volume square submatrix (maxvol algorithm).
Args:
M (torch.tensor): input matrix.
Returns:
torch.tensor: indices of the maxvol submatrix.
"""
if M.shape[1] >= M.shape[0]:
# more columns than rows -> return all the row indices
idx = tn.tensor(range(M.shape[0]),dtype = tn.int64)
return idx
else:
L, U, P = _LU(M)
idx = P[:M.shape[1]]
Msub = M[idx,:]
Mat = tn.linalg.solve(Msub.T,M.T).t()
for i in range(100):
val_max, idx_max = _max_matrix(tn.abs(Mat))
idx_max = idx_max[0]
if val_max<=1+5e-2:
idx = tn.sort(idx)[0]
return idx
Mat += tn.outer(Mat[:,idx_max[1]],Mat[idx[idx_max[1]]]-Mat[idx_max[0],:])/Mat[idx_max[0],idx_max[1]]
idx[idx_max[1]]=idx_max[0]
return idx
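# _maxvol returns row indices idx such that the square submatrix M[idx, :] has quasi-maximal
# volume; the entries of M @ inv(M[idx, :]) are then bounded by roughly 1 + 5e-2, which is the
# stopping tolerance used in the iteration above.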
def function_interpolate(function, x, eps = 1e-9, start_tens = None, nswp = 20, kick = 2, dtype = tn.float64, verbose = False):
"""
Application of a nonlinear function to a tensor in the TT format (using DMRG). Two cases are distinguished:
* Univariate interpolation:
Let \(f:\\mathbb{R}\\rightarrow\\mathbb{R}\) be a function and \(\\mathsf{x}\\in\\mathbb{R}^{N_1\\times\\cdots\\times N_d}\) be a tensor with a known TT approximation.
The goal is to determine the TT approximation of \(\\mathsf{y}_{i_1...i_d}=f(\\mathsf{x}_{i_1...i_d})\) within a prescribed relative accuracy `eps`.
* Multivariate interpolation:
Let \(f:\\mathbb{R}^d\\rightarrow\\mathbb{R}\) be a function and \(\\mathsf{x}^{(1)},...,\\mathsf{x}^{(d)}\\in\\mathbb{R}^{N_1\\times\\cdots\\times N_d}\) be tensors with known TT approximations. The goal is to determine the TT approximation of \(\\mathsf{y}_{i_1...i_d}=f(\\mathsf{x}^{(1)}_{i_1...i_d},...,\\mathsf{x}^{(d)}_{i_1...i_d})\) within a prescribed relative accuracy `eps`.
Example:
* Univariate interpolation:
```
func = lambda t: torch.log(t)
y = tntt.interpolate.function_interpolate(func, x, 1e-9) # the tensor x is chosen such that y has an affordable low rank structure
```
* Multivariate interpolation:
```
xs = tntt.meshgrid([tn.arange(0,n,dtype=torch.float64) for n in N])
func = lambda x: 1/(2+tn.sum(x,1).to(dtype=torch.float64))
z = tntt.interpolate.function_interpolate(func, xs)
```
Args:
function (Callable): function handle. If the argument `x` is a `torchtt.TT` instance, then the function handle has to be applicable elementwise on torch tensors.
If a list is passed as `x`, the function handle takes as argument a $M\times d$ torch.tensor and each of the $M$ rows corresponds to an evaluation of the function \(f\) at a certain tensor entry. The function handle returns a torch tensor of length $M$.
x (torchtt.TT or list[torchtt.TT]): the argument/arguments of the function.
eps (float, optional): the relative accuracy. Defaults to 1e-9.
start_tens (torchtt.TT, optional): initial approximation of the output tensor (None corresponds to random initialization). Defaults to None.
nswp (int, optional): number of iterations. Defaults to 20.
kick (int, optional): enrichment rank. Defaults to 2.
dtype (torch.dtype, optional): the dtype of the result. Defaults to tn.float64.
verbose (bool, optional): display debug information to the console. Defaults to False.
Returns:
torchtt.TT: the result.
"""
if isinstance(x,list) or isinstance(x,tuple):
eval_mv = True
N = x[0].N
else:
eval_mv = False
N = x.N
device = None
if not eval_mv and len(N)==1:
return torchtt.TT(function(x.full())).to(device)
if eval_mv and len(N)==1:
return torchtt.TT(function(x[0].full())).to(device)
d = len(N)
#random init of the tensor
if start_tens == None:
rank_init = 2
cores = torchtt.random(N,rank_init, dtype, device).cores
rank = [1]+[rank_init]*(d-1)+[1]
else:
rank = start_tens.R.copy()
cores = [c+0 for c in start_tens.cores]
# cores = (ones(N,dtype=dtype)).cores
cores, rank = lr_orthogonal(cores,rank,False)
Mats = []*(d+1)
Ps = [tn.ones((1,1),dtype=dtype,device=device)]+(d-1)*[None] + [tn.ones((1,1),dtype=dtype,device=device)]
# ortho
Rm = tn.ones((1,1),dtype=dtype,device=device)
Idx = [tn.zeros((1,0),dtype=tn.int64)]+(d-1)*[None] + [tn.zeros((0,1),dtype=tn.int64)]
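# right-to-left pass: QR-orthogonalize each core and pick maxvol rows to build the index sets
# Idx[k] (multi-indices selecting fibers to the right of position k) and the matrices Ps[k]
# used later to precondition the supercores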
for k in range(d-1,0,-1):
tmp = tn.einsum('ijk,kl->ijl',cores[k],Rm)
tmp = tn.reshape(tmp,[rank[k],-1]).t()
core, Rmat = QR(tmp)
rnew = min(N[k]*rank[k+1], rank[k])
Jk = _maxvol(core)
# print(Jk)
tmp = np.unravel_index(Jk[:rnew],(rank[k+1],N[k]))
#if k==d-1:
# idx_new = tn.tensor(tmp[1].reshape([1,-1]))
# else:
idx_new = tn.tensor(np.vstack( ( tmp[1].reshape([1,-1]),Idx[k+1][:,tmp[0]] ) ))
Idx[k] = idx_new+0
Rm = core[Jk,:]
core = tn.linalg.solve(Rm.T,core.T)
Rm = (Rm@Rmat).t()
cores[k] = tn.reshape(core,[rnew,N[k],rank[k+1]])
core = tn.reshape(core,[-1,rank[k+1]]) @ Ps[k+1]
core = tn.reshape(core,[rank[k],-1]).t()
_,Ps[k] = QR(core)
cores[0] = tn.einsum('ijk,kl->ijl',cores[0],Rm)
# for p in Ps:
# print(p)
# for i in Idx:
# print(i)
# return
n_eval = 0
for swp in range(nswp):
max_err = 0.0
if verbose:
print('Sweep %d: '%(swp+1))
#left to right
for k in range(d-1):
if verbose: print('\tLR supercore %d,%d'%(k+1,k+2))
I1 = tn.reshape(tn.kron(tn.kron(tn.ones(rank[k],dtype=tn.int64), tn.arange(N[k],dtype=tn.int64)), tn.kron(tn.ones(N[k+1],dtype=tn.int64), tn.ones(rank[k+2],dtype=tn.int64))),[-1,1])
I2 = tn.reshape(tn.kron(tn.kron(tn.ones(rank[k],dtype=tn.int64), tn.ones(N[k],dtype=tn.int64)), tn.kron(tn.arange(N[k+1],dtype=tn.int64), tn.ones(rank[k+2],dtype=tn.int64))),[-1,1])
I3 = Idx[k][tn.kron(tn.kron(tn.arange(rank[k],dtype=tn.int64), tn.ones(N[k],dtype=tn.int64)), tn.kron(tn.ones(N[k+1],dtype=tn.int64), tn.ones(rank[k+2],dtype=tn.int64))),:]
I4 = Idx[k+2][:,tn.kron(tn.kron(tn.ones(rank[k],dtype=tn.int64), tn.ones(N[k],dtype=tn.int64)), tn.kron(tn.ones(N[k+1],dtype=tn.int64), tn.arange(rank[k+2],dtype=tn.int64)))].t()
eval_index = tn.concat((I3, I1, I2, I4),1)
eval_index = tn.reshape(eval_index,[-1,d]).to(dtype=tn.int64)
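# each row of eval_index is a full d-dimensional index: a left multi-index from Idx[k], the free
# indices i_k and i_{k+1}, and a right multi-index from Idx[k+2]; the function is thus evaluated
# exactly on the cross needed to form the rank[k] x N[k] x N[k+1] x rank[k+2] supercore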
if verbose: print('\t\tnumber evaluations',eval_index.shape[0])
if eval_mv:
ev = tn.zeros((eval_index.shape[0],0),dtype = dtype)
for j in range(d):
core = x[j].cores[0][0,eval_index[:,0],:]
for i in range(1,d):
core = tn.einsum('ij,jil->il',core,x[j].cores[i][:,eval_index[:,i],:])
core = tn.reshape(core[...,0],[-1,1])
ev = tn.hstack((ev,core))
supercore = tn.reshape(function(ev),[rank[k],N[k],N[k+1],rank[k+2]])
n_eval += core.shape[0]
else:
core = x.cores[0][0,eval_index[:,0],:]
for i in range(1,d):
core = tn.einsum('ij,jil->il',core,x.cores[i][:,eval_index[:,i],:])
core = core[...,0]
supercore = tn.reshape(function(core),[rank[k],N[k],N[k+1],rank[k+2]])
n_eval += core.shape[0]
# multiply with P_k left and right
supercore = tn.einsum('ij,jklm,mn->ikln',Ps[k],supercore.to(dtype=dtype),Ps[k+2])
rank[k] = supercore.shape[0]
rank[k+2] = supercore.shape[3]
supercore = tn.reshape(supercore,[supercore.shape[0]*supercore.shape[1],-1])
# split the super core with svd
U,S,V = SVD(supercore)
rnew = rank_chop(S.cpu().numpy(),tn.linalg.norm(S).cpu().numpy()*eps/np.sqrt(d-1))+1
rnew = min(S.shape[0],rnew)
U = U[:,:rnew]
S = S[:rnew]
V = V[:rnew,:]
# print('kkt new',tn.linalg.norm(supercore-U@tn.diag(S)@V))
# kick the rank
V = tn.diag(S) @ V
UK = tn.randn((U.shape[0],kick), dtype = dtype, device = device)
U, Rtemp = QR( tn.cat( (U,UK) , 1) )
radd = Rtemp.shape[1] - rnew
if radd>0:
V = tn.cat( (V,tn.zeros((radd,V.shape[1]), dtype = dtype, device = device)) , 0 )
V = Rtemp @ V
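# after the QR of [U, UK], V has absorbed both S and Rtemp, so U @ V still reproduces the
# truncated supercore while the column space of U is enriched by up to `kick` random directions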
# print('kkt new',tn.linalg.norm(supercore-U@V))
# compute err (dx)
super_prev = tn.einsum('ijk,kmn->ijmn',cores[k],cores[k+1])
super_prev = tn.einsum('ij,jklm,mn->ikln',Ps[k],super_prev,Ps[k+2])
err = tn.linalg.norm(supercore.flatten()-super_prev.flatten())/tn.linalg.norm(supercore)
max_err = max(max_err,err)
# update the rank
if verbose:
print('\t\trank updated %d -> %d, local error %e'%(rank[k+1],U.shape[1],err))
rank[k+1] = U.shape[1]
U = tn.linalg.solve(Ps[k],tn.reshape(U,[rank[k],-1]))
V = tn.linalg.solve(Ps[k+2].t(),tn.reshape(V,[rank[k+1]*N[k+1],rank[k+2]]).t()).t()
# U = tn.einsum('ij,jkl->ikl',tn.linalg.inv(Ps[k]),tn.reshape(U,[rank[k],N[k],-1]))
# V = tn.einsum('ijk,kl->ijl',tn.reshape(V,[-1,N[k+1],rank[k+2]]),tn.linalg.inv(Ps[k+2]))
V = tn.reshape(V,[rank[k+1],-1])
U = tn.reshape(U,[-1,rank[k+1]])
# split cores
Qmat, Rmat = QR(U)
idx = _maxvol(Qmat)
Sub = Qmat[idx,:]
core = tn.linalg.solve(Sub.T,Qmat.T).t()
core_next = Sub@Rmat@V
cores[k] = tn.reshape(core,[rank[k],N[k],rank[k+1]])
cores[k+1] = tn.reshape(core_next,[rank[k+1],N[k+1],rank[k+2]])
# calc Ps
tmp = tn.einsum('ij,jkl->ikl',Ps[k],cores[k])
_,Ps[k+1] = QR(tn.reshape(tmp,[rank[k]*N[k],rank[k+1]]))
# calc Idx
tmp = np.unravel_index(idx[:rank[k+1]],(rank[k],N[k]))
idx_new = tn.tensor(np.hstack( ( Idx[k][tmp[0],:] , tmp[1].reshape([-1,1]) ) ))
Idx[k+1] = idx_new+0
#right to left
for k in range(d-2,-1,-1):
if verbose: print('\tRL supercore %d,%d'%(k+1,k+2))
I1 = tn.reshape(tn.kron(tn.kron(tn.ones(rank[k],dtype=tn.int64), tn.arange(N[k],dtype=tn.int64)), tn.kron(tn.ones(N[k+1],dtype=tn.int64), tn.ones(rank[k+2],dtype=tn.int64))),[-1,1])
I2 = tn.reshape(tn.kron(tn.kron(tn.ones(rank[k],dtype=tn.int64), tn.ones(N[k],dtype=tn.int64)), tn.kron(tn.arange(N[k+1],dtype=tn.int64), tn.ones(rank[k+2],dtype=tn.int64))),[-1,1])
I3 = Idx[k][tn.kron(tn.kron(tn.arange(rank[k],dtype=tn.int64), tn.ones(N[k],dtype=tn.int64)), tn.kron(tn.ones(N[k+1],dtype=tn.int64), tn.ones(rank[k+2],dtype=tn.int64))),:]
I4 = Idx[k+2][:,tn.kron(tn.kron(tn.ones(rank[k],dtype=tn.int64), tn.ones(N[k],dtype=tn.int64)), tn.kron(tn.ones(N[k+1],dtype=tn.int64), tn.arange(rank[k+2],dtype=tn.int64)))].t()
eval_index = tn.concat((I3, I1, I2, I4),1)
eval_index = tn.reshape(eval_index,[-1,d]).to(dtype=tn.int64)
if verbose: print('\t\tnumber evaluations',eval_index.shape[0])
if eval_mv:
ev = tn.zeros((eval_index.shape[0],0),dtype = dtype)
for j in range(d):
core = x[j].cores[0][0,eval_index[:,0],:]
for i in range(1,d):
core = tn.einsum('ij,jil->il',core,x[j].cores[i][:,eval_index[:,i],:])
core = tn.reshape(core[...,0],[-1,1])
ev = tn.hstack((ev,core))
supercore = tn.reshape(function(ev),[rank[k],N[k],N[k+1],rank[k+2]])
n_eval += core.shape[0]
else:
core = x.cores[0][0,eval_index[:,0],:]
for i in range(1,d):
core = tn.einsum('ij,jil->il',core,x.cores[i][:,eval_index[:,i],:])
core = core[...,0]
supercore = tn.reshape(function(core),[rank[k],N[k],N[k+1],rank[k+2]])
n_eval +=core.shape[0]
# multiply with P_k left and right
supercore = tn.einsum('ij,jklm,mn->ikln',Ps[k],supercore.to(dtype=dtype),Ps[k+2])
rank[k] = supercore.shape[0]
rank[k+2] = supercore.shape[3]
supercore = tn.reshape(supercore,[supercore.shape[0]*supercore.shape[1],-1])
# split the super core with svd
U,S,V = SVD(supercore)
rnew = rank_chop(S.cpu().numpy(),tn.linalg.norm(S).cpu().numpy()*eps/np.sqrt(d-1))+1
rnew = min(S.shape[0],rnew)
U = U[:,:rnew]
S = S[:rnew]
V = V[:rnew,:]
# print('kkt new',tn.linalg.norm(supercore-U@tn.diag(S)@V))
#kick the rank
# print('u before', U.shape)
U = U @ tn.diag(S)
VK = tn.randn((kick,V.shape[1]) , dtype=dtype, device = device)
# print('V enrich', V.shape)
V, Rtemp = QR( tn.cat( (V,VK) , 0).t() )
radd = Rtemp.shape[1] - rnew
# print('V after QR',V.shape,Rtemp.shape,radd)
if radd>0:
U = tn.cat( (U,tn.zeros((U.shape[0],radd), dtype = dtype, device = device)) , 1 )
U = U @ Rtemp.T
V = V.t()
# print('kkt new',tn.linalg.norm(supercore-U@V))
# compute err (dx)
super_prev = tn.einsum('ijk,kmn->ijmn',cores[k],cores[k+1])
super_prev = tn.einsum('ij,jklm,mn->ikln',Ps[k],super_prev,Ps[k+2])
err = tn.linalg.norm(supercore.flatten()-super_prev.flatten())/tn.linalg.norm(supercore)
max_err = max(max_err,err)
# update the rank
if verbose:
print('\t\trank updated %d -> %d, local error %e'%(rank[k+1],U.shape[1],err))
rank[k+1] = U.shape[1]
U = tn.linalg.solve(Ps[k],tn.reshape(U,[rank[k],-1]))
V = tn.linalg.solve(Ps[k+2].t(),tn.reshape(V,[rank[k+1]*N[k+1],rank[k+2]]).t()).t()
# U = tn.einsum('ij,jkl->ikl',tn.linalg.inv(Ps[k]),tn.reshape(U,[rank[k],N[k],-1]))
# V = tn.einsum('ijk,kl->ijl',tn.reshape(V,[-1,N[k+1],rank[k+2]]),tn.linalg.inv(Ps[k+2]))
V = tn.reshape(V,[rank[k+1],-1])
U = tn.reshape(U,[-1,rank[k+1]])
# split cores
Qmat, Rmat = QR(V.T)
idx = _maxvol(Qmat)
Sub = Qmat[idx,:]
core_next = tn.linalg.solve(Sub.T,Qmat.T)
core =U@(Sub@Rmat).t()
cores[k] = tn.reshape(core,[rank[k],N[k],-1])
cores[k+1] = tn.reshape(core_next,[-1,N[k+1],rank[k+2]])
# calc Ps
tmp = tn.einsum('ijk,kl->ijl',cores[k+1],Ps[k+2])
_,tmp = QR(tn.reshape(tmp,[rank[k+1],-1]).t())
Ps[k+1] = tmp
# calc Idx
tmp = np.unravel_index(idx[:rank[k+1]],(N[k+1],rank[k+2]))
idx_new = tn.tensor(np.vstack( ( tmp[0].reshape([1,-1]),Idx[k+2][:,tmp[1]] ) ))
Idx[k+1] = idx_new+0
#xxx = TT(cores)
#print('# ',xxx[1,2,3,4])
# exit condition
if max_err<eps:
if verbose: print('Max error %e < %e ----> DONE'%(max_err,eps))
break
else:
if verbose: print('Max error %g'%(max_err))
if verbose:
print('number of function calls ',n_eval)
print()
return torchtt.TT(cores)
def dmrg_cross(function, N, eps = 1e-9, nswp = 10, x_start = None, kick = 2, dtype = tn.float64, device = None, eval_vect = True, verbose = False):
"""
Approximate a tensor in the TT format when its individual entries are given by a function.
The function is given as a function handle taking as argument a matrix of integer indices and returning the corresponding tensor entries.
Example:
```
func = lambda I: 1/(2+I[:,0]+I[:,1]+I[:,2]+I[:,3]).to(dtype=torch.float64)
N = [20]*4
x = torchtt.interpolate.dmrg_cross(func, N, eps = 1e-7)
```
Args:
function (Callable): function handle.
N (list[int]): the shape of the tensor.
eps (float, optional): the relative accuracy. Defaults to 1e-9.
nswp (int, optional): number of iterations. Defaults to 10.
x_start (torchtt.TT, optional): initial approximation of the output tensor (None corresponds to random initialization). Defaults to None.
kick (int, optional): enrichment rank. Defaults to 2.
dtype (torch.dtype, optional): the dtype of the result. Defaults to tn.float64.
device (torch.device, optional): the device where the approximation will be stored. Defaults to None.
eval_vect (bool, optional): not yet implemented. Defaults to True.
verbose (bool, optional): display debug information to the console. Defaults to False.
Returns:
torchtt.TT: the result.
"""
# store the computed values
computed_vals = dict()
d = len(N)
#random init of the tensor
if x_start == None:
rank_init = 2
cores = torchtt.random(N,rank_init, dtype, device).cores
rank = [1]+[rank_init]*(d-1)+[1]
else:
rank = x_start.R.copy()
cores = [c+0 for c in x_start.cores]
# cores = (ones(N,dtype=dtype)).cores
cores, rank = lr_orthogonal(cores,rank,False)
Mats = []*(d+1)
Ps = [tn.ones((1,1),dtype=dtype,device=device)]+(d-1)*[None] + [tn.ones((1,1),dtype=dtype,device=device)]
# ortho
Rm = tn.ones((1,1),dtype=dtype,device=device)
Idx = [tn.zeros((1,0),dtype=tn.int64)]+(d-1)*[None] + [tn.zeros((0,1),dtype=tn.int64)]
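# right-to-left pass building the maxvol index sets Idx and the interface matrices Ps,
# analogous to the initialization in function_interpolate above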
for k in range(d-1,0,-1):
tmp = tn.einsum('ijk,kl->ijl',cores[k],Rm)
tmp = tn.reshape(tmp,[rank[k],-1]).t()
core, Rmat = QR(tmp)
rnew = min(N[k]*rank[k+1], rank[k])
Jk = _maxvol(core)
# print(Jk)
tmp = np.unravel_index(Jk[:rnew],(rank[k+1],N[k]))
#if k==d-1:
# idx_new = tn.tensor(tmp[1].reshape([1,-1]))
# else:
idx_new = tn.tensor(np.vstack( ( tmp[1].reshape([1,-1]),Idx[k+1][:,tmp[0]] ) ))
Idx[k] = idx_new+0
Rm = core[Jk,:]
core = tn.linalg.solve(Rm.T,core.T)
# core = tn.linalg.solve(Rm,core.T)
Rm = (Rm@Rmat).t()
# core = core.t()
cores[k] = tn.reshape(core,[rnew,N[k],rank[k+1]])
core = tn.reshape(core,[-1,rank[k+1]]) @ Ps[k+1]
core = tn.reshape(core,[rank[k],-1]).t()
_,Ps[k] = QR(core)
cores[0] = tn.einsum('ijk,kl->ijl',cores[0],Rm)
# for p in Ps:
# print(p)
# for i in Idx:
# print(i)
# return
n_eval = 0
for swp in range(nswp):
max_err = 0.0
if verbose:
print('Sweep %d: '%(swp+1))
#left to right
for k in range(d-1):
if verbose: print('\tLR supercore %d,%d'%(k+1,k+2))
I1 = tn.reshape(tn.kron(tn.kron(tn.ones(rank[k],dtype=tn.int64), tn.arange(N[k],dtype=tn.int64)), tn.kron(tn.ones(N[k+1],dtype=tn.int64), tn.ones(rank[k+2],dtype=tn.int64))),[-1,1])
I2 = tn.reshape(tn.kron(tn.kron(tn.ones(rank[k],dtype=tn.int64), tn.ones(N[k],dtype=tn.int64)), tn.kron(tn.arange(N[k+1],dtype=tn.int64), tn.ones(rank[k+2],dtype=tn.int64))),[-1,1])
I3 = Idx[k][tn.kron(tn.kron(tn.arange(rank[k],dtype=tn.int64), tn.ones(N[k],dtype=tn.int64)), tn.kron(tn.ones(N[k+1],dtype=tn.int64), tn.ones(rank[k+2],dtype=tn.int64))),:]
I4 = Idx[k+2][:,tn.kron(tn.kron(tn.ones(rank[k],dtype=tn.int64), tn.ones(N[k],dtype=tn.int64)), tn.kron(tn.ones(N[k+1],dtype=tn.int64), tn.arange(rank[k+2],dtype=tn.int64)))].t()
eval_index = tn.concat((I3, I1, I2, I4),1)
eval_index = tn.reshape(eval_index,[-1,d]).to(dtype=tn.int64)
if verbose: print('\t\tnumber evaluations',eval_index.shape[0])
if eval_vect:
supercore = tn.reshape(function(eval_index),[rank[k],N[k],N[k+1],rank[k+2]])
n_eval += eval_index.shape[0]
# multiply with P_k left and right
supercore = tn.einsum('ij,jklm,mn->ikln',Ps[k],supercore.to(dtype=dtype),Ps[k+2])
rank[k] = supercore.shape[0]
rank[k+2] = supercore.shape[3]
supercore = tn.reshape(supercore,[supercore.shape[0]*supercore.shape[1],-1])
# split the super core with svd
U,S,V = SVD(supercore)
rnew = rank_chop(S.cpu().numpy(),tn.linalg.norm(S).cpu().numpy()*eps/np.sqrt(d-1))+1
rnew = min(S.shape[0],rnew)
U = U[:,:rnew]
S = S[:rnew]
V = V[:rnew,:]
# print('kkt new',tn.linalg.norm(supercore-U@tn.diag(S)@V))
# kick the rank
V = tn.diag(S) @ V
UK = tn.randn((U.shape[0],kick), dtype = dtype, device = device)
U, Rtemp = QR( tn.cat( (U,UK) , 1) )
radd = U.shape[1] - rnew
if radd>0:
V = tn.cat( (V,tn.zeros((radd,V.shape[1]), dtype = dtype, device = device)) , 0 )
V = Rtemp @ V
# print('kkt new',tn.linalg.norm(supercore-U@V))
# compute err (dx)
super_prev = tn.einsum('ijk,kmn->ijmn',cores[k],cores[k+1])
super_prev = tn.einsum('ij,jklm,mn->ikln',Ps[k],super_prev,Ps[k+2])
err = tn.linalg.norm(supercore.flatten()-super_prev.flatten())/tn.linalg.norm(supercore)
max_err = max(max_err,err)
# update the rank
if verbose:
print('\t\trank updated %d -> %d, local error %e'%(rank[k+1],U.shape[1],err))
rank[k+1] = U.shape[1]
U = tn.linalg.solve(Ps[k],tn.reshape(U,[rank[k],-1]))
V = tn.linalg.solve(Ps[k+2].t(),tn.reshape(V,[rank[k+1]*N[k+1],rank[k+2]]).t()).t()
# U = tn.einsum('ij,jkl->ikl',tn.linalg.inv(Ps[k]),tn.reshape(U,[rank[k],N[k],-1]))
# V = tn.einsum('ijk,kl->ijl',tn.reshape(V,[-1,N[k+1],rank[k+2]]),tn.linalg.inv(Ps[k+2]))
V = tn.reshape(V,[rank[k+1],-1])
U = tn.reshape(U,[-1,rank[k+1]])
# split cores
Qmat, Rmat = QR(U)
idx = _maxvol(Qmat)
Sub = Qmat[idx,:]
core = tn.linalg.solve(Sub.T,Qmat.T).t()
core_next = Sub@Rmat@V
cores[k] = tn.reshape(core,[rank[k],N[k],rank[k+1]])
cores[k+1] = tn.reshape(core_next,[rank[k+1],N[k+1],rank[k+2]])
# calc Ps
tmp = tn.einsum('ij,jkl->ikl',Ps[k],cores[k])
_,Ps[k+1] = QR(tn.reshape(tmp,[rank[k]*N[k],rank[k+1]]))
# calc Idx
tmp = np.unravel_index(idx[:rank[k+1]],(rank[k],N[k]))
idx_new = tn.tensor(np.hstack( ( Idx[k][tmp[0],:] , tmp[1].reshape([-1,1]) ) ))
Idx[k+1] = idx_new+0
#right to left
for k in range(d-2,-1,-1):
if verbose: print('\tRL supercore %d,%d'%(k+1,k+2))
I1 = tn.reshape(tn.kron(tn.kron(tn.ones(rank[k],dtype=tn.int64), tn.arange(N[k],dtype=tn.int64)), tn.kron(tn.ones(N[k+1],dtype=tn.int64), tn.ones(rank[k+2],dtype=tn.int64))),[-1,1])
I2 = tn.reshape(tn.kron(tn.kron(tn.ones(rank[k],dtype=tn.int64), tn.ones(N[k],dtype=tn.int64)), tn.kron(tn.arange(N[k+1],dtype=tn.int64), tn.ones(rank[k+2],dtype=tn.int64))),[-1,1])
I3 = Idx[k][tn.kron(tn.kron(tn.arange(rank[k],dtype=tn.int64), tn.ones(N[k],dtype=tn.int64)), tn.kron(tn.ones(N[k+1],dtype=tn.int64), tn.ones(rank[k+2],dtype=tn.int64))),:]
I4 = Idx[k+2][:,tn.kron(tn.kron(tn.ones(rank[k],dtype=tn.int64), tn.ones(N[k],dtype=tn.int64)), tn.kron(tn.ones(N[k+1],dtype=tn.int64), tn.arange(rank[k+2],dtype=tn.int64)))].t()
eval_index = tn.concat((I3, I1, I2, I4),1)
eval_index = tn.reshape(eval_index,[-1,d]).to(dtype=tn.int64)
if verbose: print('\t\tnumber evaluations',eval_index.shape[0])
if eval_vect:
supercore = tn.reshape(function(eval_index).to(dtype=dtype),[rank[k],N[k],N[k+1],rank[k+2]])
n_eval += eval_index.shape[0]
# multiply with P_k left and right
supercore = tn.einsum('ij,jklm,mn->ikln',Ps[k],supercore.to(dtype=dtype),Ps[k+2])
rank[k] = supercore.shape[0]
rank[k+2] = supercore.shape[3]
supercore = tn.reshape(supercore,[supercore.shape[0]*supercore.shape[1],-1])
# split the super core with svd
U,S,V = SVD(supercore)
rnew = rank_chop(S.cpu().numpy(),tn.linalg.norm(S).cpu().numpy()*eps/np.sqrt(d-1))+1
rnew = min(S.shape[0],rnew)
U = U[:,:rnew]
S = S[:rnew]
V = V[:rnew,:]
# print('kkt new',tn.linalg.norm(supercore-U@tn.diag(S)@V))
#kick the rank
U = U @ tn.diag(S)
VK = tn.randn((kick,V.shape[1]) , dtype=dtype, device = device)
V, Rtemp = QR( tn.cat( (V,VK) , 0).t() )
radd = V.shape[1] - rnew
if radd>0:
U = tn.cat( (U,tn.zeros((U.shape[0],radd), dtype = dtype, device = device)) , 1 )
U = U @ Rtemp.T
V = V.t()
# print('kkt new',tn.linalg.norm(supercore-U@V))
# compute err (dx)
super_prev = tn.einsum('ijk,kmn->ijmn',cores[k],cores[k+1])
super_prev = tn.einsum('ij,jklm,mn->ikln',Ps[k],super_prev,Ps[k+2])
err = tn.linalg.norm(supercore.flatten()-super_prev.flatten())/tn.linalg.norm(supercore)
max_err = max(max_err,err)
# update the rank
if verbose:
print('\t\trank updated %d -> %d, local error %e'%(rank[k+1],U.shape[1],err))
rank[k+1] = U.shape[1]
U = tn.linalg.solve(Ps[k],tn.reshape(U,[rank[k],-1]))
V = tn.linalg.solve(Ps[k+2].t(),tn.reshape(V,[rank[k+1]*N[k+1],rank[k+2]]).t()).t()
# U = tn.einsum('ij,jkl->ikl',tn.linalg.inv(Ps[k]),tn.reshape(U,[rank[k],N[k],-1]))
# V = tn.einsum('ijk,kl->ijl',tn.reshape(V,[-1,N[k+1],rank[k+2]]),tn.linalg.inv(Ps[k+2]))
V = tn.reshape(V,[rank[k+1],-1])
U = tn.reshape(U,[-1,rank[k+1]])
# split cores
Qmat, Rmat = QR(V.T)
idx = _maxvol(Qmat)
Sub = Qmat[idx,:]
core_next = tn.linalg.solve(Sub.T,Qmat.T)
core =U@(Sub@Rmat).t()
cores[k] = tn.reshape(core,[rank[k],N[k],-1])
cores[k+1] = tn.reshape(core_next,[-1,N[k+1],rank[k+2]])
# calc Ps
tmp = tn.einsum('ijk,kl->ijl',cores[k+1],Ps[k+2])
_,tmp = QR(tn.reshape(tmp,[rank[k+1],-1]).t())
Ps[k+1] = tmp
# calc Idx
tmp = np.unravel_index(idx[:rank[k+1]],(N[k+1],rank[k+2]))
idx_new = tn.tensor(np.vstack( ( tmp[0].reshape([1,-1]),Idx[k+2][:,tmp[1]] ) ))
Idx[k+1] = idx_new+0
#xxx = TT(cores)
#print('# ',xxx[1,2,3,4])
# exit condition
if max_err<eps:
if verbose: print('Max error %e < %e ----> DONE'%(max_err,eps))
break
else:
if verbose: print('Max error %g'%(max_err))
if verbose:
print('number of function calls ',n_eval)
print()
return torchtt.TT(cores)
Functions
def dmrg_cross(function, N, eps=1e-09, nswp=10, x_start=None, kick=2, dtype=torch.float64, device=None, eval_vect=True, verbose=False)
Approximate a tensor in the TT format when its individual entries are given by a function. The function is given as a function handle taking as argument a matrix of integer indices and returning the corresponding tensor entries.

Example:

    func = lambda I: 1/(2+I[:,0]+I[:,1]+I[:,2]+I[:,3]).to(dtype=torch.float64)
    N = [20]*4
    x = torchtt.interpolate.dmrg_cross(func, N, eps = 1e-7)
Args:
    function (Callable): function handle.
    N (list[int]): the shape of the tensor.
    eps (float, optional): the relative accuracy. Defaults to 1e-9.
    nswp (int, optional): number of iterations. Defaults to 10.
    x_start (torchtt.TT, optional): initial approximation of the output tensor (None corresponds to random initialization). Defaults to None.
    kick (int, optional): enrichment rank. Defaults to 2.
    dtype (torch.dtype, optional): the dtype of the result. Defaults to tn.float64.
    device (torch.device, optional): the device where the approximation will be stored. Defaults to None.
    eval_vect (bool, optional): not yet implemented. Defaults to True.
    verbose (bool, optional): display debug information to the console. Defaults to False.

Returns:
    torchtt.TT: the result.
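For small mode sizes the result can be checked against a dense evaluation of the same function. The sketch below reuses the example above; it only assumes `torch` and the `full()` method of `torchtt.TT` that also appears in this module's source, and names such as `full_ref` are purely illustrative.

```
import torch
import torchtt

func = lambda I: 1/(2 + I[:, 0] + I[:, 1] + I[:, 2] + I[:, 3]).to(dtype=torch.float64)
N = [20] * 4
x = torchtt.interpolate.dmrg_cross(func, N, eps=1e-7)

# dense reference built by evaluating func on all multi-indices (feasible only for small N)
grids = torch.meshgrid(*[torch.arange(n) for n in N], indexing='ij')
I_all = torch.stack([g.flatten() for g in grids], dim=1)
full_ref = func(I_all).reshape(N)

rel_err = torch.linalg.norm(x.full() - full_ref) / torch.linalg.norm(full_ref)
print(rel_err)  # expected to be of the order of the prescribed accuracy
```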
def function_interpolate(function, x, eps=1e-09, start_tens=None, nswp=20, kick=2, dtype=torch.float64, verbose=False)
Application of a nonlinear function to a tensor in the TT format (using DMRG). Two cases are distinguished:

- Univariate interpolation:
  Let \(f:\mathbb{R}\rightarrow\mathbb{R}\) be a function and \(\mathsf{x}\in\mathbb{R}^{N_1\times\cdots\times N_d}\) be a tensor with a known TT approximation. The goal is to determine the TT approximation of \(\mathsf{y}_{i_1...i_d}=f(\mathsf{x}_{i_1...i_d})\) within a prescribed relative accuracy `eps`.

- Multivariate interpolation:
  Let \(f:\mathbb{R}^d\rightarrow\mathbb{R}\) be a function and \(\mathsf{x}^{(1)},...,\mathsf{x}^{(d)}\in\mathbb{R}^{N_1\times\cdots\times N_d}\) be tensors with known TT approximations. The goal is to determine the TT approximation of \(\mathsf{y}_{i_1...i_d}=f(\mathsf{x}^{(1)}_{i_1...i_d},...,\mathsf{x}^{(d)}_{i_1...i_d})\) within a prescribed relative accuracy `eps`.

Example:

- Univariate interpolation:

    func = lambda t: torch.log(t)
    y = tntt.interpolate.function_interpolate(func, x, 1e-9) # the tensor x is chosen such that y has an affordable low rank structure

- Multivariate interpolation:

    xs = tntt.meshgrid([tn.arange(0,n,dtype=torch.float64) for n in N])
    func = lambda x: 1/(2+tn.sum(x,1).to(dtype=torch.float64))
    z = tntt.interpolate.function_interpolate(func, xs)
Args:
    function (Callable): function handle. If the argument `x` is a `torchtt.TT` instance, then the function handle has to be applicable elementwise on torch tensors. If a list is passed as `x`, the function handle takes as argument a $M\times d$ torch.tensor and each of the $M$ rows corresponds to an evaluation of the function \(f\) at a certain tensor entry. The function handle returns a torch tensor of length $M$.
    x (torchtt.TT or list[torchtt.TT]): the argument/arguments of the function.
    eps (float, optional): the relative accuracy. Defaults to 1e-9.
    start_tens (torchtt.TT, optional): initial approximation of the output tensor (None corresponds to random initialization). Defaults to None.
    nswp (int, optional): number of iterations. Defaults to 20.
    kick (int, optional): enrichment rank. Defaults to 2.
    dtype (torch.dtype, optional): the dtype of the result. Defaults to tn.float64.
    verbose (bool, optional): display debug information to the console. Defaults to False.

Returns:
    torchtt.TT: the result.
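A corresponding sketch for the multivariate case, following the docstring example above; the mode sizes are illustrative and `tntt.meshgrid` is assumed to behave as in that example (returning one TT tensor per coordinate).

```
import torch
import torchtt as tntt

N = [32, 32, 32]
# TT meshgrid of the index ranges, as in the docstring example above
xs = tntt.meshgrid([torch.arange(0, n, dtype=torch.float64) for n in N])

# the handle is applied row-wise: each row of its argument holds (x1, x2, x3) at one tensor entry
func = lambda x: 1 / (2 + torch.sum(x, 1))
z = tntt.interpolate.function_interpolate(func, xs, eps=1e-9)
print(z.R)  # TT ranks of the interpolated result
```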