Module TTCME.TimeIntegrator
Tensor train integrator for linear ODEs in the TT format (implements tAMEn)
Expand source code
"""
Tensor train integrator for linear ODEs in the TT format (implements [tAMEn](https://arxiv.org/abs/1403.8085))
"""
import torch as tn
import torchtt as tntt
import numpy as np
from .basis import *
class TTInt():

    def __init__(self, Operator, epsilon = 1e-6, N_max = 64, dt_max = 1e-1, method = 'implicit-euler'):
        r"""
        The ODE integrator in the TT format for linear ODEs of the form

        $$ \frac{\text{d} \mathsf{p}(t)}{\text{d} t} = \mathsf{A}\mathsf{p}(t) $$

        Args:
            Operator (torchtt.TT): the linear operator \(\mathsf{A}\) of the ODE in the TT format.
            epsilon (float, optional): solver accuracy. Defaults to 1e-6.
            N_max (int, optional): mode size used for the time dimension of every subinterval. Defaults to 64.
            dt_max (float, optional): maximum timestep. It is stored but currently not used by `solve()`, where the length of a subinterval is `T/intervals`. Defaults to 1e-1.
            method (str, optional): the discretization method of choice. Possibilities are `'implicit-euler'`, `'crank–nicolson'` (the ASCII spelling `'crank-nicolson'` is accepted as well), `'cheby'` and `'legendre'`. Defaults to 'implicit-euler'.
        """
        self.__A_tt = Operator
        self.__epsilon = epsilon
        self.__N_max = N_max
        self.__dt_max = dt_max
        self.__method = method

    def _get_SP(self, T, N):
        """
        Assemble the matrices of the time discretization for one subinterval.

        `solve()` uses the result to build the linear system
        `(I ⊗ S - (I ⊗ P)(A ⊗ I)) x = x0 ⊗ ev`, where `x0` is the value at the beginning of the subinterval.

        Args:
            T (float): length of the subinterval.
            N (int): number of degrees of freedom in time.

        Raises:
            ValueError: if the method chosen in the constructor is unknown or if `N < 2` for `'crank–nicolson'`.

        Returns:
            tuple: `(S, P, ev, basis)`. `S` and `P` are `N x N` numpy arrays, `ev` is a numpy vector of length `N` and `basis` is the basis object for `'cheby'` and `'legendre'` or `None` for the finite difference methods.
        """
        method = self.__method

        if method in ('implicit-euler', 'crank–nicolson', 'crank-nicolson'):
            # both finite difference schemes share the backward difference matrix
            S = np.eye(N) - np.diag(np.ones(N - 1), -1)
            if method == 'implicit-euler':
                P = T * np.eye(N) / N
            else:
                if N < 2:
                    raise ValueError("The 'crank–nicolson' method needs at least 2 time points.")
                P = np.eye(N) + np.diag(np.ones(N - 1), -1)
                # the first row only imposes the initial value
                P[0, :] = 0
                P = P * T / (2 * (N - 1))
            ev = (S @ np.ones((N, 1))).flatten()
            basis = None
        elif method in ('cheby', 'legendre'):
            basis = ChebyBasis(N, [0, T]) if method == 'cheby' else LegendreBasis(N, [0, T])
            # values of the basis functions at the beginning of the subinterval
            ev = basis(np.array([0])).flatten()
            S = basis.stiff + np.outer(ev, ev)
            P = basis.mass
        else:
            raise ValueError("Unknown method %r. Possibilities are 'implicit-euler', 'crank–nicolson', 'cheby' and 'legendre'." % (method,))

        return S, P, ev, basis

    def solve(self, initial_tt, T, intervals = None, return_all = False, nswp = 40, qtt = False, verb = False, rounding = True, device = 'cpu'):
        """
        Solve for the time interval of length `T` for the initial value `initial_tt`.

        The interval is split into `intervals` subintervals of length `T/intervals`. For every subinterval one linear
        system in the TT format is solved with `torchtt.solvers.amen_solve` and the value at its end is used as
        initial value for the next one. The last computed solution is kept in `self.xs_tt` and passed as initial
        guess to the following solve; if the solve with that guess fails, it is repeated without it.

        Args:
            initial_tt (torchtt.TT): the initial value in the TT format.
            T (float): interval length.
            intervals (int): number of equally sized subintervals. It has to be provided (the default `None` is rejected).
            return_all (bool, optional): return a list with one entry per subinterval instead of the final value only. For `'implicit-euler'` and `'crank–nicolson'` the entry is the entire solution of the subinterval (time occupies the trailing dimension(s)); for `'cheby'` and `'legendre'` the entry contains the values at the beginning and at the end of the subinterval along an additional last dimension of size 2. Defaults to False.
            nswp (int, optional): number of sweeps for the AMEn solver. Defaults to 40.
            qtt (bool, optional): use QTT for the time dimension or not (requires `N_max` to be a power of 2). Defaults to False.
            verb (bool, optional): display additional information during runtime. Defaults to False.
            rounding (bool, optional): perform rounding after the individual subintervals. Defaults to True.
            device (str, optional): the device the linear systems are solved on (set to `'cuda:0'` if a GPU is available). The solution is moved back to the CPU afterwards. If `None`, nothing is moved. Defaults to 'cpu'.

        Raises:
            ValueError: if `intervals` is `None` or smaller than 1, or if `qtt` is True and `N_max` is not a power of 2.

        Returns:
            torchtt.TT or list[torchtt.TT]: the value at the end of the interval, or the list described for `return_all`.
        """
        if intervals is None:
            raise ValueError("`intervals` (the number of subintervals) must be provided.")
        if intervals < 1:
            raise ValueError("`intervals` must be a positive integer.")

        Nt = self.__N_max
        # number of trailing dimensions that represent the time
        nqtt = int(np.log2(Nt)) if qtt else 1
        if qtt and 2 ** nqtt != Nt:
            raise ValueError("`N_max` must be a power of 2 when `qtt` is True.")

        A_tt = self.__A_tt
        dev = A_tt.cores[0].device
        dT = T / intervals
        S, P, ev, basis = self._get_SP(dT, Nt)

        def time_factor(factor_tt):
            # only in the QTT case the time dimension is split by `to_qtt()`
            return factor_tt.to_qtt() if qtt else factor_tt

        def contract_time(solution_tt, weights_tt):
            # weight the time coefficients and sum over the trailing time dimension(s)
            result = solution_tt * weights_tt
            for _ in range(nqtt):
                result = result.sum(len(result.N) - 1)
            return result

        # the system matrix and the time factor of the right-hand side are identical for all subintervals
        S_tt = time_factor(tntt.rank1TT([tn.tensor(S).to(dev)]))
        P_tt = time_factor(tntt.rank1TT([tn.tensor(P).to(dev)]))
        It_tt = time_factor(tntt.eye([Nt]).to(dev))
        ev_tt = time_factor(tntt.TT(tn.tensor(ev).to(dev)).to(dev))
        I_tt = tntt.eye(A_tt.N).to(dev)
        B_tt = I_tt ** S_tt - (I_tt ** P_tt) @ (A_tt ** It_tt)
        if device is not None:
            B_tt = B_tt.to(device)

        if basis is not None:
            # basis functions evaluated at both ends of a subinterval, extended by ones over the other dimensions
            ones_tt = tntt.ones(A_tt.N)
            w_begin = ones_tt ** time_factor(tntt.TT(basis(np.array([0])).flatten()))
            w_end = ones_tt ** time_factor(tntt.TT(basis(np.array([dT])).flatten()))
            if device is None:
                # without a solve device the solution stays on the device of the operator
                w_begin = w_begin.to(dev)
                w_end = w_end.to(dev)

        solver_args = dict(eps = self.__epsilon, verbose = verb, nswp = nswp, kickrank = 8, rmax = 2000, preconditioner = None)

        returns = []
        x_tt = initial_tt
        for _ in range(intervals):
            f_tt = x_tt ** ev_tt
            if device is not None:
                f_tt = f_tt.to(device)

            xs_tt = None
            guess = getattr(self, 'xs_tt', None)
            if guess is not None:
                try:
                    if device is not None:
                        guess = guess.to(device)
                    xs_tt = tntt.solvers.amen_solve(B_tt, f_tt, x0 = guess, **solver_args)
                except Exception:
                    # the stored guess may stem from an incompatible previous call: solve again without it
                    xs_tt = None
            if xs_tt is None:
                xs_tt = tntt.solvers.amen_solve(B_tt, f_tt, **solver_args)
            if device is not None:
                xs_tt = xs_tt.cpu()
            self.xs_tt = xs_tt

            if basis is None:
                if return_all:
                    returns.append(xs_tt)
                # the last time point is the value at the end of the subinterval
                x_tt = xs_tt[tuple([slice(None, None, None)] * len(A_tt.N) + [-1] * nqtt)]
            else:
                x_tt = contract_time(xs_tt, w_end)
                if return_all:
                    x_begin = contract_time(xs_tt, w_begin)
                    returns.append((x_begin ** tntt.TT(np.array([1.0, 0.0]))) + (x_tt ** tntt.TT(np.array([0.0, 1.0]))))
            if rounding:
                x_tt = x_tt.round(self.__epsilon / 10)

        return returns if return_all else x_tt
Classes
class TTInt (Operator, epsilon=1e-06, N_max=64, dt_max=0.1, method='implicit-euler')-
The ODE integrator in the TT format for linear ODEs of the form $$ \frac{\text{d} \mathsf{p}(t)}{\text{d} t} = \mathsf{A}\mathsf{p}(t) $$
Args
Operator:torchtt.TT- the linear operator \mathsf{A} of the ODE in the TT format.
epsilon:float, optional- solver accuracy. Defaults to 1e-6.
N_max:int, optional- maximum mode size for the time dimension. Defaults to 64.
dt_max:float, optional- maximum timestep. Defaults to 1e-1.
method:str, optional- the discretization method of choice. Possibilities are
'implicit-euler', 'crank–nicolson', 'cheby' and 'legendre'. Defaults to 'implicit-euler'.
Expand source code
class TTInt():

    def __init__(self, Operator, epsilon = 1e-6, N_max = 64, dt_max = 1e-1, method = 'implicit-euler'):
        r"""
        The ODE integrator in the TT format for linear ODEs of the form

        $$ \frac{\text{d} \mathsf{p}(t)}{\text{d} t} = \mathsf{A}\mathsf{p}(t) $$

        Args:
            Operator (torchtt.TT): the linear operator \(\mathsf{A}\) of the ODE in the TT format.
            epsilon (float, optional): solver accuracy. Defaults to 1e-6.
            N_max (int, optional): mode size used for the time dimension of every subinterval. Defaults to 64.
            dt_max (float, optional): maximum timestep. It is stored but currently not used by `solve()`, where the length of a subinterval is `T/intervals`. Defaults to 1e-1.
            method (str, optional): the discretization method of choice. Possibilities are `'implicit-euler'`, `'crank–nicolson'` (the ASCII spelling `'crank-nicolson'` is accepted as well), `'cheby'` and `'legendre'`. Defaults to 'implicit-euler'.
        """
        self.__A_tt = Operator
        self.__epsilon = epsilon
        self.__N_max = N_max
        self.__dt_max = dt_max
        self.__method = method

    def _get_SP(self, T, N):
        """
        Assemble the matrices of the time discretization for one subinterval.

        `solve()` uses the result to build the linear system
        `(I ⊗ S - (I ⊗ P)(A ⊗ I)) x = x0 ⊗ ev`, where `x0` is the value at the beginning of the subinterval.

        Args:
            T (float): length of the subinterval.
            N (int): number of degrees of freedom in time.

        Raises:
            ValueError: if the method chosen in the constructor is unknown or if `N < 2` for `'crank–nicolson'`.

        Returns:
            tuple: `(S, P, ev, basis)`. `S` and `P` are `N x N` numpy arrays, `ev` is a numpy vector of length `N` and `basis` is the basis object for `'cheby'` and `'legendre'` or `None` for the finite difference methods.
        """
        method = self.__method

        if method in ('implicit-euler', 'crank–nicolson', 'crank-nicolson'):
            # both finite difference schemes share the backward difference matrix
            S = np.eye(N) - np.diag(np.ones(N - 1), -1)
            if method == 'implicit-euler':
                P = T * np.eye(N) / N
            else:
                if N < 2:
                    raise ValueError("The 'crank–nicolson' method needs at least 2 time points.")
                P = np.eye(N) + np.diag(np.ones(N - 1), -1)
                # the first row only imposes the initial value
                P[0, :] = 0
                P = P * T / (2 * (N - 1))
            ev = (S @ np.ones((N, 1))).flatten()
            basis = None
        elif method in ('cheby', 'legendre'):
            basis = ChebyBasis(N, [0, T]) if method == 'cheby' else LegendreBasis(N, [0, T])
            # values of the basis functions at the beginning of the subinterval
            ev = basis(np.array([0])).flatten()
            S = basis.stiff + np.outer(ev, ev)
            P = basis.mass
        else:
            raise ValueError("Unknown method %r. Possibilities are 'implicit-euler', 'crank–nicolson', 'cheby' and 'legendre'." % (method,))

        return S, P, ev, basis

    def solve(self, initial_tt, T, intervals = None, return_all = False, nswp = 40, qtt = False, verb = False, rounding = True, device = 'cpu'):
        """
        Solve for the time interval of length `T` for the initial value `initial_tt`.

        The interval is split into `intervals` subintervals of length `T/intervals`. For every subinterval one linear
        system in the TT format is solved with `torchtt.solvers.amen_solve` and the value at its end is used as
        initial value for the next one. The last computed solution is kept in `self.xs_tt` and passed as initial
        guess to the following solve; if the solve with that guess fails, it is repeated without it.

        Args:
            initial_tt (torchtt.TT): the initial value in the TT format.
            T (float): interval length.
            intervals (int): number of equally sized subintervals. It has to be provided (the default `None` is rejected).
            return_all (bool, optional): return a list with one entry per subinterval instead of the final value only. For `'implicit-euler'` and `'crank–nicolson'` the entry is the entire solution of the subinterval (time occupies the trailing dimension(s)); for `'cheby'` and `'legendre'` the entry contains the values at the beginning and at the end of the subinterval along an additional last dimension of size 2. Defaults to False.
            nswp (int, optional): number of sweeps for the AMEn solver. Defaults to 40.
            qtt (bool, optional): use QTT for the time dimension or not (requires `N_max` to be a power of 2). Defaults to False.
            verb (bool, optional): display additional information during runtime. Defaults to False.
            rounding (bool, optional): perform rounding after the individual subintervals. Defaults to True.
            device (str, optional): the device the linear systems are solved on (set to `'cuda:0'` if a GPU is available). The solution is moved back to the CPU afterwards. If `None`, nothing is moved. Defaults to 'cpu'.

        Raises:
            ValueError: if `intervals` is `None` or smaller than 1, or if `qtt` is True and `N_max` is not a power of 2.

        Returns:
            torchtt.TT or list[torchtt.TT]: the value at the end of the interval, or the list described for `return_all`.
        """
        if intervals is None:
            raise ValueError("`intervals` (the number of subintervals) must be provided.")
        if intervals < 1:
            raise ValueError("`intervals` must be a positive integer.")

        Nt = self.__N_max
        # number of trailing dimensions that represent the time
        nqtt = int(np.log2(Nt)) if qtt else 1
        if qtt and 2 ** nqtt != Nt:
            raise ValueError("`N_max` must be a power of 2 when `qtt` is True.")

        A_tt = self.__A_tt
        dev = A_tt.cores[0].device
        dT = T / intervals
        S, P, ev, basis = self._get_SP(dT, Nt)

        def time_factor(factor_tt):
            # only in the QTT case the time dimension is split by `to_qtt()`
            return factor_tt.to_qtt() if qtt else factor_tt

        def contract_time(solution_tt, weights_tt):
            # weight the time coefficients and sum over the trailing time dimension(s)
            result = solution_tt * weights_tt
            for _ in range(nqtt):
                result = result.sum(len(result.N) - 1)
            return result

        # the system matrix and the time factor of the right-hand side are identical for all subintervals
        S_tt = time_factor(tntt.rank1TT([tn.tensor(S).to(dev)]))
        P_tt = time_factor(tntt.rank1TT([tn.tensor(P).to(dev)]))
        It_tt = time_factor(tntt.eye([Nt]).to(dev))
        ev_tt = time_factor(tntt.TT(tn.tensor(ev).to(dev)).to(dev))
        I_tt = tntt.eye(A_tt.N).to(dev)
        B_tt = I_tt ** S_tt - (I_tt ** P_tt) @ (A_tt ** It_tt)
        if device is not None:
            B_tt = B_tt.to(device)

        if basis is not None:
            # basis functions evaluated at both ends of a subinterval, extended by ones over the other dimensions
            ones_tt = tntt.ones(A_tt.N)
            w_begin = ones_tt ** time_factor(tntt.TT(basis(np.array([0])).flatten()))
            w_end = ones_tt ** time_factor(tntt.TT(basis(np.array([dT])).flatten()))
            if device is None:
                # without a solve device the solution stays on the device of the operator
                w_begin = w_begin.to(dev)
                w_end = w_end.to(dev)

        solver_args = dict(eps = self.__epsilon, verbose = verb, nswp = nswp, kickrank = 8, rmax = 2000, preconditioner = None)

        returns = []
        x_tt = initial_tt
        for _ in range(intervals):
            f_tt = x_tt ** ev_tt
            if device is not None:
                f_tt = f_tt.to(device)

            xs_tt = None
            guess = getattr(self, 'xs_tt', None)
            if guess is not None:
                try:
                    if device is not None:
                        guess = guess.to(device)
                    xs_tt = tntt.solvers.amen_solve(B_tt, f_tt, x0 = guess, **solver_args)
                except Exception:
                    # the stored guess may stem from an incompatible previous call: solve again without it
                    xs_tt = None
            if xs_tt is None:
                xs_tt = tntt.solvers.amen_solve(B_tt, f_tt, **solver_args)
            if device is not None:
                xs_tt = xs_tt.cpu()
            self.xs_tt = xs_tt

            if basis is None:
                if return_all:
                    returns.append(xs_tt)
                # the last time point is the value at the end of the subinterval
                x_tt = xs_tt[tuple([slice(None, None, None)] * len(A_tt.N) + [-1] * nqtt)]
            else:
                x_tt = contract_time(xs_tt, w_end)
                if return_all:
                    x_begin = contract_time(xs_tt, w_begin)
                    returns.append((x_begin ** tntt.TT(np.array([1.0, 0.0]))) + (x_tt ** tntt.TT(np.array([0.0, 1.0]))))
            if rounding:
                x_tt = x_tt.round(self.__epsilon / 10)

        return returns if return_all else x_tt
# Methods
def solve(self, initial_tt, T, intervals=None, return_all=False, nswp=40, qtt=False, verb=False, rounding=True, device='cpu')-
Solve for the time interval of length T for the initial value initial_tt.
Args
initial_tt:torchtt.TT- the initial value in the TT format.
T:float- interval length.
intervals:int, optional- the number of equally sized subintervals (each of length T/intervals) the time interval is split into. Defaults to None.
return_all:bool, optional- return the solution after all subintervals. Defaults to False.
nswp:int, optional- number of sweeps for the AMEn solver. Defaults to 40.
qtt:bool, optional- use QTT or not. Defaults to False.
verb:bool, optional- display additional information during runtime. Defaults to False.
rounding:bool, optional- perform rounding after the individual subintervals. Defaults to True.
device:str, optional- the device the linear systems are solved on (set to 'cuda:0' if a GPU is available); the solution is moved back to the CPU afterwards. Defaults to 'cpu'.
Returns
torchtt.TT- the result.
Expand source code
def solve(self, initial_tt, T, intervals = None, return_all = False, nswp = 40, qtt = False, verb = False, rounding = True, device = 'cpu'):
    """
    Solve for the time interval of length `T` for the initial value `initial_tt`.

    The interval is split into `intervals` subintervals of length `T/intervals`. For every subinterval one linear
    system in the TT format is solved with `torchtt.solvers.amen_solve` and the value at its end is used as
    initial value for the next one. The last computed solution is kept in `self.xs_tt` and passed as initial
    guess to the following solve; if the solve with that guess fails, it is repeated without it.

    Args:
        initial_tt (torchtt.TT): the initial value in the TT format.
        T (float): interval length.
        intervals (int): number of equally sized subintervals. It has to be provided (the default `None` is rejected).
        return_all (bool, optional): return a list with one entry per subinterval instead of the final value only. For the methods without a basis the entry is the entire solution of the subinterval (time occupies the trailing dimension(s)); for the basis methods the entry contains the values at the beginning and at the end of the subinterval along an additional last dimension of size 2. Defaults to False.
        nswp (int, optional): number of sweeps for the AMEn solver. Defaults to 40.
        qtt (bool, optional): use QTT for the time dimension or not (requires `N_max` to be a power of 2). Defaults to False.
        verb (bool, optional): display additional information during runtime. Defaults to False.
        rounding (bool, optional): perform rounding after the individual subintervals. Defaults to True.
        device (str, optional): the device the linear systems are solved on (set to `'cuda:0'` if a GPU is available). The solution is moved back to the CPU afterwards. If `None`, nothing is moved. Defaults to 'cpu'.

    Raises:
        ValueError: if `intervals` is `None` or smaller than 1, or if `qtt` is True and `N_max` is not a power of 2.

    Returns:
        torchtt.TT or list[torchtt.TT]: the value at the end of the interval, or the list described for `return_all`.
    """
    if intervals is None:
        raise ValueError("`intervals` (the number of subintervals) must be provided.")
    if intervals < 1:
        raise ValueError("`intervals` must be a positive integer.")

    Nt = self.__N_max
    # number of trailing dimensions that represent the time
    nqtt = int(np.log2(Nt)) if qtt else 1
    if qtt and 2 ** nqtt != Nt:
        raise ValueError("`N_max` must be a power of 2 when `qtt` is True.")

    A_tt = self.__A_tt
    dev = A_tt.cores[0].device
    dT = T / intervals
    S, P, ev, basis = self._get_SP(dT, Nt)

    def time_factor(factor_tt):
        # only in the QTT case the time dimension is split by `to_qtt()`
        return factor_tt.to_qtt() if qtt else factor_tt

    def contract_time(solution_tt, weights_tt):
        # weight the time coefficients and sum over the trailing time dimension(s)
        result = solution_tt * weights_tt
        for _ in range(nqtt):
            result = result.sum(len(result.N) - 1)
        return result

    # the system matrix and the time factor of the right-hand side are identical for all subintervals
    S_tt = time_factor(tntt.rank1TT([tn.tensor(S).to(dev)]))
    P_tt = time_factor(tntt.rank1TT([tn.tensor(P).to(dev)]))
    It_tt = time_factor(tntt.eye([Nt]).to(dev))
    ev_tt = time_factor(tntt.TT(tn.tensor(ev).to(dev)).to(dev))
    I_tt = tntt.eye(A_tt.N).to(dev)
    B_tt = I_tt ** S_tt - (I_tt ** P_tt) @ (A_tt ** It_tt)
    if device is not None:
        B_tt = B_tt.to(device)

    if basis is not None:
        # basis functions evaluated at both ends of a subinterval, extended by ones over the other dimensions
        ones_tt = tntt.ones(A_tt.N)
        w_begin = ones_tt ** time_factor(tntt.TT(basis(np.array([0])).flatten()))
        w_end = ones_tt ** time_factor(tntt.TT(basis(np.array([dT])).flatten()))
        if device is None:
            # without a solve device the solution stays on the device of the operator
            w_begin = w_begin.to(dev)
            w_end = w_end.to(dev)

    solver_args = dict(eps = self.__epsilon, verbose = verb, nswp = nswp, kickrank = 8, rmax = 2000, preconditioner = None)

    returns = []
    x_tt = initial_tt
    for _ in range(intervals):
        f_tt = x_tt ** ev_tt
        if device is not None:
            f_tt = f_tt.to(device)

        xs_tt = None
        guess = getattr(self, 'xs_tt', None)
        if guess is not None:
            try:
                if device is not None:
                    guess = guess.to(device)
                xs_tt = tntt.solvers.amen_solve(B_tt, f_tt, x0 = guess, **solver_args)
            except Exception:
                # the stored guess may stem from an incompatible previous call: solve again without it
                xs_tt = None
        if xs_tt is None:
            xs_tt = tntt.solvers.amen_solve(B_tt, f_tt, **solver_args)
        if device is not None:
            xs_tt = xs_tt.cpu()
        self.xs_tt = xs_tt

        if basis is None:
            if return_all:
                returns.append(xs_tt)
            # the last time point is the value at the end of the subinterval
            x_tt = xs_tt[tuple([slice(None, None, None)] * len(A_tt.N) + [-1] * nqtt)]
        else:
            x_tt = contract_time(xs_tt, w_end)
            if return_all:
                x_begin = contract_time(xs_tt, w_begin)
                returns.append((x_begin ** tntt.TT(np.array([1.0, 0.0]))) + (x_tt ** tntt.TT(np.array([0.0, 1.0]))))
        if rounding:
            x_tt = x_tt.round(self.__epsilon / 10)

    return returns if return_all else x_tt