Module torchtt.solvers
System solvers in the TT format.
Expand source code
"""
System solvers in the TT format.
"""
import torch as tn
import numpy as np
import torchtt
import datetime
from torchtt._decomposition import QR, SVD, lr_orthogonal, rl_orthogonal
from torchtt._iterative_solvers import BiCGSTAB_reset, gmres_restart
import opt_einsum as oe
from .errors import *
try:
    import torchttcpp
    _flag_use_cpp = True
except ImportError:
    import warnings
    warnings.warn("\x1b[33m\nC++ implementation not available. Using pure Python.\n\x1b[0m")
    _flag_use_cpp = False
def cpp_enabled():
"""
Is the C++ backend enabled?
Returns:
bool: the flag
"""
return _flag_use_cpp
def _local_product(Phi_right, Phi_left, coreA, core, shape):
    """
    Compute the local matrix-vector product.
    Args:
        Phi_right (torch.tensor): right interface tensor of shape r x R x r.
        Phi_left (torch.tensor): left interface tensor of shape lp x Rp x lp.
        coreA (torch.tensor): current core of A, shape is Rp x N x N x R.
        core (torch.tensor): current core of x, shape is rp x N x r.
        shape (torch.Size): the shape of x.
    Returns:
        torch.tensor: the result.
    """
    w = oe.contract('lsr,smnS,LSR,rnR->lmL', Phi_left, coreA, Phi_right, core)
    return w
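# Shape sketch for _local_product (illustrative, following the einsum labels above):
#   Phi_left : l x s x r,  coreA : s x m x n x S,  Phi_right : L x S x R,  core : r x n x R
#   result w : l x m x L, i.e. the local core multiplied by the framed operator.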
class _LinearOp():
    """Local linear operator for the iterative solvers, with optional Jacobi-type preconditioning."""

    def __init__(self, Phi_left, Phi_right, coreA, shape, prec):
        self.Phi_left = Phi_left
        self.Phi_right = Phi_right
        self.coreA = coreA
        self.shape = shape
        self.prec = prec

        if prec == 'c':
            # central Jacobi preconditioner: invert the diagonal blocks of the local operator
            Jl = tn.einsum('sd,smnS->dmnS', tn.diagonal(Phi_left, 0, 0, 2), coreA)
            Jr = tn.diagonal(Phi_right, 0, 0, 2)
            J = tn.einsum('dmnS,SD->dDmn', Jl, Jr)
            self.J = tn.linalg.inv(J)
            if shape[0]*shape[1]*shape[2] > 1e5:
                self.contraction = oe.contract_expression('lsr,smnS,LSR,raR,rRna->lmL', Phi_left.shape, coreA.shape, Phi_right.shape, shape, self.J.shape)
            else:
                self.contraction = None
        if prec == 'r':
            # 'r' preconditioner: invert the blocks framed by the full right interface
            Jl = tn.einsum('sd,smnS->dmnS', tn.diagonal(Phi_left, 0, 0, 2), coreA)
            J = tn.einsum('dmnS,LSR->dmLnR', Jl, Phi_right)
            sh = J.shape
            J = tn.reshape(J, [-1, J.shape[1]*J.shape[2], J.shape[3]*J.shape[4]])
            self.J = tn.reshape(tn.linalg.inv(J), sh)
            if shape[0]*shape[1]*shape[2] > 2*1e4:
                self.contraction = oe.contract_expression('lsr,smnS,LSR,rab,rnRab->lmL', Phi_left.shape, coreA.shape, Phi_right.shape, shape, self.J.shape)
            else:
                self.contraction = None

    def apply_prec(self, x):
        if self.prec == 'c':
            return tn.einsum('rnR,rRmn->rmR', x, self.J)  # no improvement using opt_einsum
        elif self.prec == 'r':
            return tn.einsum('rnR,rmLnR->rmL', x, self.J)

    def matvec(self, x, apply_prec = True):
        if self.prec is None or not apply_prec:
            x = tn.reshape(x, self.shape)
            w = tn.tensordot(x, self.Phi_left, ([0], [2]))          # rnR,lsr->nRls
            w = tn.tensordot(w, self.coreA, ([0, 3], [2, 0]))       # nRls,smnS->RlmS
            w = tn.tensordot(w, self.Phi_right, ([0, 3], [2, 1]))   # RlmS,LSR->lmL
        elif self.prec == 'c' or self.prec == 'r':
            x = tn.reshape(x, self.shape)
            if self.contraction is not None:
                w = self.contraction(self.Phi_left, self.coreA, self.Phi_right, x, self.J)
            else:
                x = self.apply_prec(x)
                w = tn.tensordot(x, self.Phi_left, ([0], [2]))          # rnR,lsr->nRls
                w = tn.tensordot(w, self.coreA, ([0, 3], [2, 0]))       # nRls,smnS->RlmS
                w = tn.tensordot(w, self.Phi_right, ([0, 3], [2, 1]))   # RlmS,LSR->lmL
        else:
            raise Exception('Preconditioner ' + str(self.prec) + ' not defined.')
        return tn.reshape(w, [-1, 1])
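# Usage sketch (mirroring the sweep below): the local system
#     B_k @ x_k = rhs_k,  with B_k assembled from Phis[k], A.cores[k] and Phis[k+1],
# is solved directly when rx[k]*N[k]*rx[k+1] < max_full and iteratively otherwise,
# with _LinearOp(...).matvec supplying the (optionally preconditioned) matrix-vector product.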
def amen_solve(A, b, nswp = 22, x0 = None, eps = 1e-10,rmax = 32768, max_full = 500, kickrank = 4, kick2 = 0, trunc_norm = 'res', local_solver = 1, local_iterations = 40, resets = 2, verbose = False, preconditioner = None, use_cpp = True, use_single_precision = False):
"""
Solve a multilinear system \(\\mathsf{Ax} = \\mathsf{b}\) in the Tensor Train format.
This method implements the algorithm from [Sergey V Dolgov, Dmitry V Savostyanov, Alternating minimal energy methods for linear systems in higher dimensions](https://epubs.siam.org/doi/abs/10.1137/140953289).
Example:
```
import torchtt
A = torchtt.random([(4,4),(5,5),(6,6)],[1,2,3,1]) # create random matrix
x = torchtt.random([4,5,6],[1,2,3,1]) # invent a random solution
b = A @ x # compute the rhs
xx = torchtt.solvers.amen_solve(A,b) # solve
print((xx-x).norm()/x.norm()) # error
```
Args:
A (torchtt.TT): the system matrix in TT.
b (torchtt.TT): the right hand side in TT.
nswp (int, optional): number of sweeps. Defaults to 22.
        x0 (torchtt.TT, optional): initial guess. If None is provided, the initial guess is a ones tensor. Defaults to None.
        eps (float, optional): relative residual. Defaults to 1e-10.
        rmax (int, optional): maximum rank. Defaults to 32768.
        max_full (int, optional): maximum size of the local system for which a direct solver is used. Defaults to 500.
        kickrank (int, optional): rank enrichment. Defaults to 4.
        kick2 (int, optional): additional random rank enrichment. Defaults to 0.
        trunc_norm (str, optional): norm used for the rank truncation: 'res' (residual-based) or 'fro'. Defaults to 'res'.
        local_solver (int, optional): choose the local iterative solver: 1 for GMRES and 2 for BiCGSTAB. Defaults to 1.
        local_iterations (int, optional): number of GMRES iterations for the local subproblems. Defaults to 40.
        resets (int, optional): number of resets in the GMRES. Defaults to 2.
        verbose (bool, optional): choose whether to display additional information during the runtime. Defaults to False.
        preconditioner (string, optional): preconditioner for the local systems. Possible values are None, 'c' (central Jacobi preconditioner) and 'r'. No preconditioner is used if None is provided. Defaults to None.
        use_cpp (bool, optional): use the C++ implementation of AMEn if available. Defaults to True.
        use_single_precision (bool, optional): solve the local subproblems in single precision. Defaults to False.
Raises:
InvalidArguments: A and b must be TT instances.
InvalidArguments: Invalid preconditioner.
IncompatibleTypes: A must be TT-matrix and b must be vector.
ShapeMismatch: A is not quadratic.
ShapeMismatch: Dimension mismatch.
Returns:
torchtt.TT: the approximation of the solution in TT format.
"""
# perform checks of the input data
if not (isinstance(A,torchtt.TT) and isinstance(b,torchtt.TT)):
raise InvalidArguments('A and b must be TT instances.')
    if not (A.is_ttm and not b.is_ttm):
raise IncompatibleTypes('A must be TT-matrix and b must be vector.')
if A.M != A.N:
raise ShapeMismatch('A is not quadratic.')
if A.N != b.N:
raise ShapeMismatch('Dimension mismatch.')
if use_cpp and _flag_use_cpp:
        if x0 is None:
x_cores = []
x_R = [1]*(1+len(A.N))
else:
x_cores = x0.cores
x_R = x0.R
        if preconditioner is None:
prec = 0
elif preconditioner == 'c':
prec = 1
elif preconditioner == 'r':
prec = 2
else:
raise InvalidArguments("Invalid preconditioner.")
cores = torchttcpp.amen_solve(A.cores, b.cores, x_cores, b.N, A.R, b.R, x_R, nswp, eps, rmax, max_full, kickrank, kick2, local_iterations, resets, verbose, prec)
return torchtt.TT(list(cores))
else:
return _amen_solve_python(A, b, nswp, x0, eps,rmax, max_full, kickrank, kick2, trunc_norm, local_solver, local_iterations, resets, verbose, preconditioner, use_single_precision)
def _amen_solve_python(A, b, nswp = 22, x0 = None, eps = 1e-10,rmax = 1024, max_full = 500, kickrank = 4, kick2 = 0, trunc_norm = 'res', local_solver = 1, local_iterations = 40, resets = 2, verbose = False, preconditioner = None, use_single_precision = False):
if verbose: time_total = datetime.datetime.now()
dtype = A.cores[0].dtype
device = A.cores[0].device
rank_search = 1 # binary rank search
damp = 2
    if x0 is None:
x = torchtt.ones(b.N, dtype = dtype, device = device)
else:
x = x0
rA = A.R
N = b.N
d = len(N)
x_cores = x.cores.copy()
rx = x.R.copy()
    # if a single int is given, expand rmax into a per-core list
if isinstance(rmax, int):
rmax = [1] + (d-1) * [rmax] + [1]
# z cores
rz = [1]+(d-1)*[kickrank+kick2]+[1]
z_tt = torchtt.random(N,rz,dtype,device = device)
z_cores = z_tt.cores
z_cores, rz = rl_orthogonal(z_cores, rz, False)
norms = np.zeros(d)
Phiz = [tn.ones((1,1,1), dtype = dtype, device = device)] + [None] * (d-1) + [tn.ones((1,1,1), dtype = dtype, device = device)] # size is rzk x Rk x rxk
Phiz_b = [tn.ones((1,1), dtype = dtype, device = device)] + [None] * (d-1) + [tn.ones((1,1), dtype = dtype, device = device)] # size is rzk x rzbk
Phis = [tn.ones((1,1,1), dtype = dtype, device = device)] + [None] * (d-1) + [tn.ones((1,1,1), dtype = dtype, device = device)] # size is rk x Rk x rk
Phis_b = [tn.ones((1,1), dtype = dtype, device = device)] + [None] * (d-1) + [tn.ones((1,1), dtype = dtype, device = device)] # size is rk x rbk
last = False
normA = np.ones((d-1))
normb = np.ones((d-1))
normx = np.ones((d-1))
nrmsc = 1.0
if verbose:
print('Starting AMEn solve with:\n\tepsilon: %g\n\tsweeps: %d\n\tlocal iterations: %d\n\tresets: %d\n\tpreconditioner: %s'%(eps, nswp, local_iterations, resets, str(preconditioner)))
print()
for swp in range(nswp):
# right to left orthogonalization
if verbose:
print()
print('Starting sweep %d %s...'%(swp+1,"(last one) " if last else ""))
tme_sweep = datetime.datetime.now()
tme = datetime.datetime.now()
for k in range(d-1,0,-1):
# update the z part (ALS) update
if not last:
if swp > 0:
czA = _local_product(Phiz[k+1],Phiz[k],A.cores[k],x_cores[k],x_cores[k].shape) # shape rzp x N x rz
czy = tn.einsum('br,bnB,BR->rnR',Phiz_b[k],b.cores[k],Phiz_b[k+1]) # shape is rzp x N x rz
cz_new = czy*nrmsc - czA
_,_,vz = SVD(tn.reshape(cz_new,[cz_new.shape[0],-1]))
cz_new = vz[:min(kickrank,vz.shape[0]),:].t() # truncate to kickrank
if k < d-1: # extend cz_new with random elements
cz_new = tn.cat((cz_new,tn.randn((cz_new.shape[0],kick2), dtype = dtype, device = device)),1)
else:
cz_new = tn.reshape(z_cores[k],[rz[k],-1]).t()
qz, _ = QR(cz_new)
rz[k] = qz.shape[1]
z_cores[k] = tn.reshape(qz.t(),[rz[k],N[k],rz[k+1]])
# norm correction ?
if swp > 0: nrmsc = nrmsc * normA[k-1] * normx[k-1] / normb[k-1]
core = tn.reshape(x_cores[k],[rx[k],N[k]*rx[k+1]]).t()
Qmat, Rmat = QR(core)
core_prev = tn.einsum('ijk,km->ijm',x_cores[k-1],Rmat.T)
rx[k] = Qmat.shape[1]
current_norm = tn.linalg.norm(core_prev)
if current_norm>0:
core_prev = core_prev / current_norm
else:
current_norm = 1.0
normx[k-1] = normx[k-1]*current_norm
x_cores[k] = tn.reshape(Qmat.t(),[rx[k],N[k],rx[k+1]])
x_cores[k-1] = core_prev[:]
            # update phis
Phis[k] = _compute_phi_bck_A(Phis[k+1],x_cores[k],A.cores[k],x_cores[k])
Phis_b[k] = _compute_phi_bck_rhs(Phis_b[k+1],b.cores[k],x_cores[k])
# ... and norms
norm = tn.linalg.norm(Phis[k])
norm = norm if norm>0 else 1.0
normA[k-1] = norm
Phis[k] = Phis[k] / norm
norm = tn.linalg.norm(Phis_b[k])
norm = norm if norm>0 else 1.0
normb[k-1] = norm
Phis_b[k] = Phis_b[k]/norm
# norm correction
nrmsc = nrmsc * normb[k-1]/ (normA[k-1] * normx[k-1])
# compute phis_z
if not last:
Phiz[k] = _compute_phi_bck_A(Phiz[k+1], z_cores[k], A.cores[k], x_cores[k]) / normA[k-1]
Phiz_b[k] = _compute_phi_bck_rhs(Phiz_b[k+1], b.cores[k], z_cores[k]) / normb[k-1]
# start loop
max_res = 0
max_dx = 0
for k in range(d):
if verbose: print('\tCore',k)
previous_solution = tn.reshape(x_cores[k],[-1,1])
# assemble rhs
rhs = tn.einsum('br,bmB,BR->rmR',Phis_b[k] , b.cores[k] * nrmsc, Phis_b[k+1])
rhs = tn.reshape(rhs,[-1,1])
norm_rhs = tn.linalg.norm(rhs)
        # residual tolerance for the local solve: eps spread over the d cores, tightened by damp
        real_tol = (eps/np.sqrt(d))/damp
# solve the local system
use_full = rx[k]*N[k]*rx[k+1] < max_full
if use_full:
# solve the full system
if verbose: print('\t\tChoosing direct solver (local size %d)....'%(rx[k]*N[k]*rx[k+1]))
Bp = tn.einsum('smnS,LSR->smnRL',A.cores[k],Phis[k+1]) # shape is Rp x N x N x r x r
B = tn.einsum('lsr,smnRL->lmLrnR',Phis[k],Bp)
B = tn.reshape(B,[rx[k]*N[k]*rx[k+1],rx[k]*N[k]*rx[k+1]])
solution_now = tn.linalg.solve(B,rhs)
res_old = tn.linalg.norm(B@previous_solution-rhs)/norm_rhs
res_new = tn.linalg.norm(B@solution_now-rhs)/norm_rhs
else:
# iterative solver
if verbose:
print('\t\tChoosing iterative solver %s (local size %d)....'%('GMRES' if local_solver==1 else 'BiCGSTAB_reset', rx[k]*N[k]*rx[k+1]))
time_local = datetime.datetime.now()
shape_now = [rx[k],N[k],rx[k+1]]
if use_single_precision:
Op = _LinearOp(Phis[k].to(tn.float32),Phis[k+1].to(tn.float32),A.cores[k].to(tn.float32),shape_now, preconditioner)
eps_local = real_tol * norm_rhs
drhs = Op.matvec(previous_solution.to(tn.float32), False)
drhs = rhs.to(tn.float32)-drhs
eps_local = eps_local / tn.linalg.norm(drhs)
if local_solver == 1:
solution_now, flag, nit = gmres_restart(Op, drhs, previous_solution.to(tn.float32)*0, rhs.shape[0], local_iterations+1, eps_local, resets)
elif local_solver == 2:
solution_now, flag, nit, _ = BiCGSTAB_reset(Op, drhs, previous_solution.to(tn.float32)*0, eps_local, local_iterations)
else:
raise InvalidArguments('Solver not implemented.')
                if preconditioner is not None:
solution_now = Op.apply_prec(tn.reshape(solution_now,shape_now))
solution_now = tn.reshape(solution_now,[-1,1])
solution_now = previous_solution + solution_now.to(dtype)
res_old = tn.linalg.norm(Op.matvec(previous_solution.to(tn.float32), False).to(dtype)-rhs)/norm_rhs
res_new = tn.linalg.norm(Op.matvec(solution_now.to(tn.float32), False).to(dtype)-rhs)/norm_rhs
else:
Op = _LinearOp(Phis[k],Phis[k+1],A.cores[k],shape_now, preconditioner)
eps_local = real_tol * norm_rhs
drhs = Op.matvec(previous_solution, False)
drhs = rhs-drhs
eps_local = eps_local / tn.linalg.norm(drhs)
if local_solver == 1:
solution_now, flag, nit = gmres_restart(Op, drhs, previous_solution*0, rhs.shape[0], local_iterations+1, eps_local, resets)
elif local_solver == 2:
solution_now, flag, nit, _ = BiCGSTAB_reset(Op, drhs, previous_solution*0, eps_local, local_iterations)
else:
raise InvalidArguments('Solver not implemented.')
                if preconditioner is not None:
solution_now = Op.apply_prec(tn.reshape(solution_now,shape_now))
solution_now = tn.reshape(solution_now,[-1,1])
solution_now = previous_solution + solution_now
res_old = tn.linalg.norm(Op.matvec(previous_solution, False)-rhs)/norm_rhs
res_new = tn.linalg.norm(Op.matvec(solution_now, False)-rhs)/norm_rhs
if verbose:
print('\t\tFinished with flag %d after %d iterations with relres %g (from %g)'%(flag,nit,res_new,eps_local))
time_local = datetime.datetime.now() - time_local
print('\t\tTime needed ',time_local)
# residual damp check
if res_old/res_new < damp and res_new > real_tol:
if verbose: print('WARNING: residual increases. res_old %g, res_new %g, real_tol %g'%(res_old,res_new,real_tol)) # warning (from tt toolbox)
# compute residual and step size
dx = tn.linalg.norm(solution_now-previous_solution)/tn.linalg.norm(solution_now)
if verbose:
print('\t\tdx = %g, res_now = %g, res_old = %g'%(dx,res_new,res_old))
max_dx = max(dx,max_dx)
max_res = max(max_res,res_old)
solution_now = tn.reshape(solution_now,[rx[k]*N[k],rx[k+1]])
# truncation
if k<d-1:
u, s, v = SVD(solution_now)
if trunc_norm == 'fro':
pass
else:
                # search for a rank that gives a small enough residual
# TODO: binary search?
r = 0
for r in range(u.shape[1]-1,0,-1):
solution = u[:,:r] @ tn.diag(s[:r]) @ v[:r,:] # solution has the same size
if use_full:
res = tn.linalg.norm(B@tn.reshape(solution,[-1,1])-rhs)/norm_rhs
else:
res = tn.linalg.norm(Op.matvec(solution.to(tn.float32 if use_single_precision else dtype)).to(dtype)-rhs)/norm_rhs
if res > max(real_tol*damp,res_new):
break
r += 1
r = min([r,tn.numel(s),rmax[k+1]])
else:
u, v = QR(solution_now)
r = u.shape[1]
s = tn.ones(r, dtype = dtype, device = device)
u = u[:,:r]
v = tn.diag(s[:r]) @ v[:r,:]
v = v.t()
if not last:
czA = _local_product(Phiz[k+1], Phiz[k], A.cores[k], tn.reshape(u@v.t(),[rx[k],N[k],rx[k+1]]), [rx[k],N[k],rx[k+1]]) # shape rzp x N x rz
czy = tn.einsum('br,bnB,BR->rnR',Phiz_b[k],b.cores[k]*nrmsc,Phiz_b[k+1]) # shape is rzp x N x rz
cz_new = czy - czA
uz,_,_ = SVD(tn.reshape(cz_new, [rz[k]*N[k],rz[k+1]]))
cz_new = uz[:,:min(kickrank,uz.shape[1])] # truncate to kickrank
if k < d-1: # extend cz_new with random elements
cz_new = tn.cat((cz_new,tn.randn((cz_new.shape[0],kick2), dtype = dtype, device = device)),1)
qz,_ = QR(cz_new)
rz[k+1] = qz.shape[1]
z_cores[k] = tn.reshape(qz,[rz[k],N[k],rz[k+1]])
if k < d-1:
if not last:
left_res = _local_product(Phiz[k+1],Phis[k],A.cores[k],tn.reshape(u@v.t(),[rx[k],N[k],rx[k+1]]),[rx[k],N[k],rx[k+1]])
left_b = tn.einsum('br,bmB,BR->rmR',Phis_b[k],b.cores[k]*nrmsc,Phiz_b[k+1])
uk = left_b - left_res # rx_k x N_k x rz_k+1
u, Rmat = QR(tn.cat((u,tn.reshape(uk,[u.shape[0],-1])),1))
r_add = uk.shape[2]
v = tn.cat((v,tn.zeros([rx[k+1],r_add], dtype = dtype, device = device)), 1)
v = v @ Rmat.t()
r = u.shape[1]
v = tn.einsum('ji,jkl->ikl',v,x_cores[k+1])
# remove norm correction
nrmsc = nrmsc * normA[k] * normx[k] / normb[k]
norm_now = tn.linalg.norm(v)
if norm_now>0:
v = v / norm_now
else:
norm_now = 1.0
normx[k] = normx[k] * norm_now
x_cores[k] = tn.reshape(u, [rx[k],N[k],r])
x_cores[k+1] = tn.reshape(v, [r,N[k+1],rx[k+2]])
rx[k+1] = r
# next phis with norm correction
Phis[k+1] = _compute_phi_fwd_A(Phis[k], x_cores[k], A.cores[k], x_cores[k])
Phis_b[k+1] = _compute_phi_fwd_rhs(Phis_b[k], b.cores[k],x_cores[k])
# ... and norms
norm = tn.linalg.norm(Phis[k+1])
norm = norm if norm>0 else 1.0
normA[k] = norm
Phis[k+1] = Phis[k+1] / norm
norm = tn.linalg.norm(Phis_b[k+1])
norm = norm if norm>0 else 1.0
normb[k] = norm
Phis_b[k+1] = Phis_b[k+1] / norm
# norm correction
nrmsc = nrmsc * normb[k] / ( normA[k] * normx[k] )
# next phiz
if not last:
Phiz[k+1] = _compute_phi_fwd_A(Phiz[k], z_cores[k], A.cores[k], x_cores[k]) / normA[k]
Phiz_b[k+1] = _compute_phi_fwd_rhs(Phiz_b[k], b.cores[k],z_cores[k]) / normb[k]
else:
x_cores[k] = tn.reshape(u@tn.diag(s[:r]) @ v[:r,:].t(),[rx[k],N[k],rx[k+1]])
if verbose:
print('Solution rank is',rx)
print('Maxres ',max_res)
tme_sweep = datetime.datetime.now()-tme_sweep
print('Time ',tme_sweep)
if last:
break
if max_res < eps:
last = True
if verbose:
time_total = datetime.datetime.now() - time_total
print()
        print('Finished after', swp+1, 'sweeps and', time_total)
print()
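    # distribute the accumulated normalization over the cores (geometric-mean factor per core)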
normx = np.exp(np.sum(np.log(normx))/d)
for k in range(d):
x_cores[k] = x_cores[k] * normx
x = torchtt.TT(x_cores)
return x
def _compute_phi_bck_A(Phi_now, core_left, core_A, core_right):
    """
    Compute the backward phi for the form dot(left, A @ right).
    Args:
        Phi_now (torch.tensor): the current phi. Has shape r1_k+1 x R_k+1 x r2_k+1.
        core_left (torch.tensor): the core on the left. Has shape r1_k x N_k x r1_k+1.
        core_A (torch.tensor): the core of the matrix. Has shape R_k x N_k x N_k x R_k+1.
        core_right (torch.tensor): the core on the right. Has shape r2_k x N_k x r2_k+1.
    Returns:
        torch.tensor: the next phi (backward). Has shape r1_k x R_k x r2_k.
    """
    Phi = oe.contract('LSR,lML,sMNS,rNR->lsr', Phi_now, core_left, core_A, core_right)
    return Phi
def _compute_phi_fwd_A(Phi_now, core_left, core_A, core_right):
    """
    Compute the forward phi for the form dot(left, A @ right).
    Args:
        Phi_now (torch.tensor): the current phi. Has shape r1_k x R_k x r2_k.
        core_left (torch.tensor): the core on the left. Has shape r1_k x N_k x r1_k+1.
        core_A (torch.tensor): the core of the matrix. Has shape R_k x N_k x N_k x R_k+1.
        core_right (torch.tensor): the core on the right. Has shape r2_k x N_k x r2_k+1.
    Returns:
        torch.tensor: the next phi (forward). Has shape r1_k+1 x R_k+1 x r2_k+1.
    """
    Phi_next = oe.contract('lsr,lML,sMNS,rNR->LSR', Phi_now, core_left, core_A, core_right)
    return Phi_next
def _compute_phi_bck_rhs(Phi_now, core_b, core):
    """
    Compute the backward phi for the right-hand side.
    Args:
        Phi_now (torch.tensor): the current phi. Has shape rb_k+1 x r_k+1.
        core_b (torch.tensor): the current core of the rhs. Has shape rb_k x N_k x rb_k+1.
        core (torch.tensor): the current core. Has shape r_k x N_k x r_k+1.
    Returns:
        torch.tensor: the backward phi corresponding to the rhs. Has shape rb_k x r_k.
    """
    Phi = oe.contract('BR,bnB,rnR->br', Phi_now, core_b, core)
    return Phi
def _compute_phi_fwd_rhs(Phi_now, core_rhs, core):
    """
    Compute the forward phi for the right-hand side.
    Args:
        Phi_now (torch.tensor): the current phi. Has shape rb_k x r_k.
        core_rhs (torch.tensor): the current core of the rhs. Has shape rb_k x N_k x rb_k+1.
        core (torch.tensor): the current core. Has shape r_k x N_k x r_k+1.
    Returns:
        torch.tensor: the forward phi corresponding to the rhs. Has shape rb_k+1 x r_k+1.
    """
    Phi_next = oe.contract('br,bnB,rnR->BR', Phi_now, core_rhs, core)
    return Phi_next
Functions
def amen_solve(A, b, nswp=22, x0=None, eps=1e-10, rmax=32768, max_full=500, kickrank=4, kick2=0, trunc_norm='res', local_solver=1, local_iterations=40, resets=2, verbose=False, preconditioner=None, use_cpp=True, use_single_precision=False)
Solve a multilinear system Ax = b in the Tensor Train format.
This method implements the algorithm from Sergey V Dolgov, Dmitry V Savostyanov, Alternating minimal energy methods for linear systems in higher dimensions (https://epubs.siam.org/doi/abs/10.1137/140953289).
Example:

    import torchtt
    A = torchtt.random([(4,4),(5,5),(6,6)],[1,2,3,1]) # create random matrix
    x = torchtt.random([4,5,6],[1,2,3,1]) # invent a random solution
    b = A @ x # compute the rhs
    xx = torchtt.solvers.amen_solve(A,b) # solve
    print((xx-x).norm()/x.norm()) # error
Args
    A : torchtt.TT
        the system matrix in TT.
    b : torchtt.TT
        the right hand side in TT.
    nswp : int, optional
        number of sweeps. Defaults to 22.
    x0 : torchtt.TT, optional
        initial guess. If None is provided, the initial guess is a ones tensor. Defaults to None.
    eps : float, optional
        relative residual. Defaults to 1e-10.
    rmax : int, optional
        maximum rank. Defaults to 32768.
    max_full : int, optional
        maximum size of the local system for which a direct solver is used. Defaults to 500.
    kickrank : int, optional
        rank enrichment. Defaults to 4.
    kick2 : int, optional
        additional random rank enrichment. Defaults to 0.
    trunc_norm : str, optional
        norm used for the rank truncation: 'res' (residual-based) or 'fro'. Defaults to 'res'.
    local_solver : int, optional
        choose the local iterative solver: 1 for GMRES and 2 for BiCGSTAB. Defaults to 1.
    local_iterations : int, optional
        number of GMRES iterations for the local subproblems. Defaults to 40.
    resets : int, optional
        number of resets in the GMRES. Defaults to 2.
    verbose : bool, optional
        choose whether to display additional information during the runtime. Defaults to False.
    preconditioner : string, optional
        preconditioner for the local systems. Possible values are None, 'c' (central Jacobi preconditioner) and 'r'. No preconditioner is used if None is provided. Defaults to None.
    use_cpp : bool, optional
        use the C++ implementation of AMEn if available. Defaults to True.
    use_single_precision : bool, optional
        solve the local subproblems in single precision. Defaults to False.
Raises
    InvalidArguments: A and b must be TT instances.
    InvalidArguments: Invalid preconditioner.
    IncompatibleTypes: A must be TT-matrix and b must be vector.
    ShapeMismatch: A is not quadratic.
    ShapeMismatch: Dimension mismatch.
Returns
    torchtt.TT
        the approximation of the solution in TT format.
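A fuller usage sketch (the ranks and mode sizes are arbitrary; all keyword arguments are the documented ones above):

```
import torchtt

A = torchtt.random([(4,4), (5,5), (6,6)], [1, 2, 3, 1])  # random TT-matrix
x = torchtt.random([4, 5, 6], [1, 2, 3, 1])              # random exact solution
b = A @ x                                                # right-hand side in TT
# tighter tolerance, central Jacobi preconditioner, progress output
xx = torchtt.solvers.amen_solve(A, b, eps=1e-8, preconditioner='c', verbose=True)
print((xx - x).norm() / x.norm())                        # relative error
```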
def cpp_enabled()
Is the C++ backend enabled?
Returns
    bool
        the flag
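A minimal sketch: query the backend flag; the pure-Python implementation can always be forced through the documented `use_cpp` argument of `amen_solve`.

```
import torchtt

print(torchtt.solvers.cpp_enabled())  # True if the torchttcpp extension was importable
# xx = torchtt.solvers.amen_solve(A, b, use_cpp=False)  # force the pure-Python AMEn
```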