00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 #ifndef O2SCL_ODE_IT_SOLVE_H
00024 #define O2SCL_ODE_IT_SOLVE_H
00025
#include <cmath>
#include <iostream>

#include <gsl/gsl_errno.h>
#include <gsl/gsl_linalg.h>
#include <gsl/gsl_ieee_utils.h>

#include <o2scl/array.h>
#include <o2scl/ovector_tlate.h>
#include <o2scl/omatrix_tlate.h>
#include <o2scl/uvector_tlate.h>
#include <o2scl/umatrix_tlate.h>
#include <o2scl/test_mgr.h>

#include <o2scl/coord.h>
#include <o2scl/bicg.h>
#include <o2scl/bicgstab.h>
#include <o2scl/gmres.h>
#include <o2scl/ir.h>
00043
00044 #ifndef DOXYGENP
00045 namespace o2scl {
00046 #endif
00047
00048
00049 template<class vec_t=o2scl::ovector_view> class ode_it_funct {
00050
00051 public:
00052
00053 ode_it_funct() {}
00054
00055 virtual ~ode_it_funct() {}
00056
00057
00058 virtual double operator()(size_t ieq, double x, vec_t &y) {
00059 return 0.0;
00060 }
00061
00062 };
00063
00064
00065 template<class vec_t=o2scl::ovector_view>
00066 class ode_it_funct_fptr : public ode_it_funct<vec_t> {
00067
00068 public:
00069
00070 virtual ~ode_it_funct_fptr() {}
00071
00072
00073 ode_it_funct_fptr(double (*fp)(size_t, double, vec_t &)) {
00074 fptr=fp;
00075 }
00076
00077
00078 virtual double operator()(size_t ieq, double x, vec_t &y) {
00079 return fptr(ieq,x,y);
00080 }
00081
00082 protected:
00083
00084 ode_it_funct_fptr() {}
00085
00086
00087 double (*fptr)(size_t ieq, double x, vec_t &y);
00088 };
00089
00090
00091 template<class tclass, class vec_t=o2scl::ovector_view>
00092 class ode_it_funct_mfptr : public ode_it_funct<vec_t> {
00093 public:
00094
00095 virtual ~ode_it_funct_mfptr() {}
00096
00097
00098 ode_it_funct_mfptr(tclass *tp,
00099 double (tclass::*fp)(size_t, double, vec_t &)) {
00100 tptr=tp;
00101 fptr=fp;
00102 }
00103
00104
00105 virtual double operator()(size_t ieq, double x, vec_t &y) {
00106 return (*tptr.*fptr)(ieq,x,y);
00107 }
00108
00109 protected:
00110
00111 ode_it_funct_mfptr() {}
00112
00113
00114 tclass *tptr;
00115
00116
00117 double (tclass::*fptr)(size_t ieq, double x, vec_t &y);
00118
00119 };
00120
00121
00122
00123
/** \brief Interface for objects which solve the linear system
    \c a \c x = \c b

    The implementation given here performs no work; derived classes
    override solve().
*/
template<class vec_t, class mat_t> class linear_solver {

public:

  /// Virtual so that derived solvers can be destroyed through the base
  virtual ~linear_solver() {
  }

  /** \brief Solve a system of size \c n

      This base version leaves all of its arguments untouched and
      reports success by returning zero.
  */
  virtual int solve(size_t n, mat_t &a, vec_t &b, vec_t &x) {
    const int status=0;
    return status;
  }

};
00131
00132
00133 class gsl_LU_solver : public linear_solver<o2scl::ovector_view,
00134 o2scl::omatrix_view> {
00135
00136 public:
00137
00138
00139 virtual int solve(size_t n, o2scl::omatrix_view &A, o2scl::ovector_view &b,
00140 o2scl::ovector_view &x) {
00141 gsl_permutation *P=gsl_permutation_alloc(n);
00142 int s;
00143 gsl_linalg_LU_decomp((gsl_matrix *)&A,P,&s);
00144 gsl_linalg_LU_solve((gsl_matrix *)&A,P,(gsl_vector *)&b,
00145 (gsl_vector *)&x);
00146 gsl_permutation_free(P);
00147
00148 return 0;
00149 };
00150
00151 virtual ~gsl_LU_solver() {}
00152
00153 };
00154
00155
00156 class gsl_QR_solver : public linear_solver<o2scl::ovector_view,
00157 o2scl::omatrix_view> {
00158
00159 public:
00160
00161
00162 virtual int solve(size_t n, o2scl::omatrix_view &A, o2scl::ovector_view &b,
00163 o2scl::ovector_view &x) {
00164 o2scl::ovector tau(n);
00165 gsl_linalg_QR_decomp((gsl_matrix *)&A,(gsl_vector *)&tau);
00166 gsl_linalg_QR_solve((gsl_matrix *)&A,(gsl_vector *)&tau,
00167 (gsl_vector *)&b,(gsl_vector *)&x);
00168 return 0;
00169 };
00170
00171 virtual ~gsl_QR_solver() {}
00172
00173 };
00174
00175
00176 class gsl_HH_solver : public linear_solver<o2scl::ovector_view,
00177 o2scl::omatrix_view> {
00178
00179 public:
00180
00181
00182 virtual int solve(size_t n, o2scl::omatrix_view &A, o2scl::ovector_view &b,
00183 o2scl::ovector_view &x) {
00184 return gsl_linalg_HH_solve((gsl_matrix *)(&A),(gsl_vector *)(&b),
00185 (gsl_vector *)(&x));
00186 };
00187
00188 virtual ~gsl_HH_solver() {}
00189
00190 };
00191
00192
00193
00194
00195
00196
00197 class ode_it_make_Coord {
00198
00199 public:
00200
00201
00202 o2scl::uvector_int *r;
00203
00204
00205 o2scl::uvector_int *c;
00206
00207
00208 o2scl::uvector *vals;
00209
00210
00211 o2scl::Coord_Mat *make(size_t ngrid, size_t neq, size_t nbleft) {
00212 size_t N, nz, nbright=neq-nbleft;
00213 N=ngrid*neq;
00214
00215 nz=neq*neq*2*(ngrid-1)+neq*neq;
00216 r=new o2scl::uvector_int(nz);
00217 c=new o2scl::uvector_int(nz);
00218 vals=new o2scl::uvector(nz);
00219 vals->set_all(1.0);
00220 size_t row=0;
00221 size_t ix=0;
00222 for(size_t i=0;i<nbleft;i++) {
00223 for(size_t j=0;j<neq;j++) {
00224 (*r)[ix]=row;
00225 (*c)[ix]=j;
00226 (*vals)[ix]=0.0;
00227 ix++;
00228 }
00229 row++;
00230 }
00231 for(size_t k=0;k<ngrid-1;k++) {
00232 size_t lhs=k*neq;
00233 for(size_t i=0;i<neq;i++) {
00234 for(size_t j=0;j<neq;j++) {
00235 (*r)[ix]=row;
00236 (*c)[ix]=lhs+j;
00237 ix++;
00238 (*r)[ix]=row;
00239 (*c)[ix]=lhs+j+neq;
00240 ix++;
00241 }
00242 row++;
00243 }
00244 }
00245 for(size_t i=0;i<nbright;i++) {
00246 size_t lhs=neq*(ngrid-1);
00247 for(size_t j=0;j<neq;j++) {
00248 (*r)[ix]=row;
00249 (*c)[ix]=lhs+j;
00250 ix++;
00251 }
00252 row++;
00253 }
00254 o2scl::Coord_Mat *M=new o2scl::Coord_Mat(N,N,nz,*vals,*r,*c,0);
00255 return M;
00256 }
00257
00258 };
00259
00260
00261
00262
00263
00264
00265
00266
00267
00268 template <class func_t, class vec_t, class mat_t, class matrix_row_t,
00269 class solver_vec_t, class solver_mat_t>
00270 class ode_it_solve {
00271
00272 public:
00273
00274 ode_it_solve() {
00275 h=1.0e-4;
00276 niter=30;
00277 tolf=1.0e-8;
00278 verbose=0;
00279 }
00280
00281 virtual ~ode_it_solve() {}
00282
00283
00284 int verbose;
00285
00286
00287 double h;
00288
00289
00290 double tolf;
00291
00292
00293 size_t niter;
00294
00295
00296 int set_solver(linear_solver<solver_vec_t,solver_mat_t> &ls) {
00297 solver=&ls;
00298 return 0;
00299 }
00300
00301
00302 int solve(size_t ngrid, size_t neq, size_t nbleft, vec_t &x, mat_t &y,
00303 func_t &derivs, func_t &left, func_t &right,
00304 solver_mat_t &mat, solver_vec_t &rhs, solver_vec_t &dy) {
00305
00306
00307 fd=&derivs;
00308 fl=&left;
00309 fr=&right;
00310
00311
00312 size_t ix;
00313
00314
00315 size_t nbright=neq-nbleft;
00316
00317
00318 size_t nvars=ngrid*neq;
00319
00320 bool done=false;
00321 for(size_t it=0;done==false && it<niter;it++) {
00322
00323 ix=0;
00324
00325 mat.set_all(0.0);
00326
00327 for(size_t i=0;i<nbleft;i++) {
00328 matrix_row_t yk(y,0);
00329 rhs[ix]=-left(i,x[0],yk);
00330 for(size_t j=0;j<neq;j++) {
00331 int rxa=mat.set(ix,j,fd_left(i,j,x[0],yk));
00332 }
00333 ix++;
00334 }
00335
00336 for(size_t k=0;k<ngrid-1;k++) {
00337 size_t kp1=k+1;
00338 double tx=(x[kp1]+x[k])/2.0;
00339 double dx=x[kp1]-x[k];
00340 matrix_row_t yk(y,k);
00341 matrix_row_t ykp1(y,k+1);
00342
00343 for(size_t i=0;i<neq;i++) {
00344
00345 rhs[ix]=y[k][i]-y[kp1][i]+(x[kp1]-x[k])*
00346 (derivs(i,tx,ykp1)+derivs(i,tx,yk))/2.0;
00347
00348 size_t lhs=k*neq;
00349 for(size_t j=0;j<neq;j++) {
00350 int rxb=mat.set(ix,lhs+j,-fd_derivs(i,j,tx,yk)*dx/2.0);
00351 int rxc=mat.set(ix,lhs+j+neq,-fd_derivs(i,j,tx,ykp1)*dx/2.0);
00352 if (i==j) {
00353 int rxd=mat.set(ix,lhs+j,mat(ix,lhs+j)-1.0);
00354 int rxe=mat.set(ix,lhs+j+neq,mat(ix,lhs+j+neq)+1.0);
00355 }
00356 }
00357
00358 ix++;
00359
00360 }
00361 }
00362
00363 for(size_t i=0;i<nbright;i++) {
00364 matrix_row_t ylast(y,ngrid-1);
00365 size_t lhs=neq*(ngrid-1);
00366
00367 rhs[ix]=-right(i,x[ngrid-1],ylast);
00368
00369 for(size_t j=0;j<neq;j++) {
00370 int rxf=mat.set(ix,lhs+j,fd_right(i,j,x[ngrid-1],ylast));
00371 }
00372
00373 ix++;
00374
00375 }
00376
00377
00378
00379 int ret=solver->solve(ix,mat,rhs,dy);
00380
00381
00382
00383 double res=0.0;
00384 ix=0;
00385
00386 for(size_t igrid=0;igrid<ngrid;igrid++) {
00387 for(size_t ieq=0;ieq<neq;ieq++) {
00388 y[igrid][ieq]+=dy[ix];
00389 res+=dy[ix]*dy[ix];
00390 ix++;
00391 }
00392 }
00393
00394 std::cout << it << " " << sqrt(res) << " " << tolf << std::endl;
00395
00396 if (sqrt(res)<=tolf) done=true;
00397 }
00398
00399 if (done==false) {
00400 set_err_ret("Exceeded number of iterations in solve().",
00401 o2scl::gsl_emaxiter);
00402 }
00403
00404 return 0;
00405 }
00406
00407 protected:
00408
00409
00410
00411 ode_it_funct<vec_t> *fl, *fr, *fd;
00412
00413
00414
00415 linear_solver<solver_vec_t,solver_mat_t> *solver;
00416
00417
00418 double fd_left(size_t ieq, size_t ivar, double x, vec_t &y) {
00419 double ret, dydx;
00420
00421 y[ivar]+=h;
00422 ret=(*fl)(ieq,x,y);
00423
00424 y[ivar]-=h;
00425 ret-=(*fl)(ieq,x,y);
00426
00427 ret/=h;
00428 return ret;
00429 }
00430
00431
00432 double fd_right(size_t ieq, size_t ivar, double x, vec_t &y) {
00433 double ret, dydx;
00434
00435 y[ivar]+=h;
00436 ret=(*fr)(ieq,x,y);
00437
00438 y[ivar]-=h;
00439 ret-=(*fr)(ieq,x,y);
00440
00441 ret/=h;
00442 return ret;
00443 }
00444
00445
00446 double fd_derivs(size_t ieq, size_t ivar, double x, vec_t &y) {
00447 double ret, dydx;
00448
00449 y[ivar]+=h;
00450 ret=(*fd)(ieq,x,y);
00451
00452 y[ivar]-=h;
00453 ret-=(*fd)(ieq,x,y);
00454
00455 ret/=h;
00456
00457 return ret;
00458 }
00459
00460 };
00461
00462 #ifndef DOXYGENP
00463 }
00464 #endif
00465
00466 #endif