ode_it_solve.h

00001 /*
00002   -------------------------------------------------------------------
00003   
00004   Copyright (C) 2006, 2007, Andrew W. Steiner
00005   
00006   This file is part of O2scl.
00007   
00008   O2scl is free software; you can redistribute it and/or modify
00009   it under the terms of the GNU General Public License as published by
00010   the Free Software Foundation; either version 3 of the License, or
00011   (at your option) any later version.
00012   
00013   O2scl is distributed in the hope that it will be useful,
00014   but WITHOUT ANY WARRANTY; without even the implied warranty of
00015   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00016   GNU General Public License for more details.
00017   
00018   You should have received a copy of the GNU General Public License
00019   along with O2scl. If not, see <http://www.gnu.org/licenses/>.
00020 
00021   -------------------------------------------------------------------
00022 */
00023 #ifndef O2SCL_ODE_IT_SOLVE_H
00024 #define O2SCL_ODE_IT_SOLVE_H
00025 
00026 #include <iostream>
00027 
00028 #include <gsl/gsl_linalg.h>
00029 #include <gsl/gsl_ieee_utils.h>
00030 
00031 #include <o2scl/array.h>
00032 #include <o2scl/ovector_tlate.h>
00033 #include <o2scl/omatrix_tlate.h>
00034 #include <o2scl/uvector_tlate.h>
00035 #include <o2scl/umatrix_tlate.h>
00036 #include <o2scl/test_mgr.h>
00037 
00038 #include <o2scl/coord.h>
00039 #include <o2scl/bicg.h>
00040 #include <o2scl/bicgstab.h>
00041 #include <o2scl/gmres.h>
00042 #include <o2scl/ir.h>
00043 
00044 #ifndef DOXYGENP
00045 namespace o2scl {
00046 #endif
00047 
00048   /// Function class for ode_it_solve
00049   template<class vec_t=o2scl::ovector_view> class ode_it_funct {
00050   
00051     public:
00052   
00053     ode_it_funct() {}
00054   
00055     virtual ~ode_it_funct() {}
00056   
00057     /// Using \c x and \c y, return the value of function number \c ieq 
00058     virtual double operator()(size_t ieq, double x, vec_t &y) {
00059       return 0.0;
00060     }
00061 
00062   };
00063 
00064   /// Function pointer for ode_it_solve
00065   template<class vec_t=o2scl::ovector_view> 
00066     class ode_it_funct_fptr : public ode_it_funct<vec_t> {
00067 
00068     public:
00069 
00070     virtual ~ode_it_funct_fptr() {}
00071 
00072     /// Create using a function pointer
00073     ode_it_funct_fptr(double (*fp)(size_t, double, vec_t &)) {
00074       fptr=fp;
00075     }
00076 
00077     /// Using \c x and \c y, return the value of function number \c ieq 
00078     virtual double operator()(size_t ieq, double x, vec_t &y) {
00079       return fptr(ieq,x,y);
00080     }
00081 
00082     protected:
00083 
00084     ode_it_funct_fptr() {}
00085 
00086     /// The function pointer
00087     double (*fptr)(size_t ieq, double x, vec_t &y);
00088   };
00089 
00090   /// Member function pointer for ode_it_solve
00091   template<class tclass, class vec_t=o2scl::ovector_view>
00092     class ode_it_funct_mfptr : public ode_it_funct<vec_t> {
00093     public:
00094 
00095     virtual ~ode_it_funct_mfptr() {}
00096 
00097     /// Create using a class instance and member function
00098     ode_it_funct_mfptr(tclass *tp, 
00099                        double (tclass::*fp)(size_t, double, vec_t &)) {
00100       tptr=tp;
00101       fptr=fp;
00102     }
00103 
00104     /// Using \c x and \c y, return the value of function number \c ieq 
00105     virtual double operator()(size_t ieq, double x, vec_t &y) {
00106       return (*tptr.*fptr)(ieq,x,y);
00107     }
00108 
00109     protected:
00110 
00111     ode_it_funct_mfptr() {}
00112 
00113     /// The class pointer
00114     tclass *tptr;
00115 
00116     /// The member function pointer
00117     double (tclass::*fptr)(size_t ieq, double x, vec_t &y);
00118 
00119   };
00120 
00121   /** 
00122       \brief A generic solver for the linear system \f$ A x = b \f$
00123   */
00124   template<class vec_t, class mat_t> class linear_solver {
00125   public:
00126     virtual ~linear_solver() {}
00127     virtual int solve(size_t n, mat_t &a, vec_t &b, vec_t &x) {
00128       return 0;
00129     }
00130   };
00131 
00132   /// GSL solver by LU decomposition
00133   class gsl_LU_solver : public linear_solver<o2scl::ovector_view,
00134     o2scl::omatrix_view> {
00135     
00136   public:
00137     
00138     /// Solve square linear system \f$ A x = b \f$ of size \c n
00139     virtual int solve(size_t n, o2scl::omatrix_view &A, o2scl::ovector_view &b,
00140                       o2scl::ovector_view &x) {
00141       gsl_permutation *P=gsl_permutation_alloc(n);
00142       int s;
00143       gsl_linalg_LU_decomp((gsl_matrix *)&A,P,&s);
00144       gsl_linalg_LU_solve((gsl_matrix *)&A,P,(gsl_vector *)&b,
00145                           (gsl_vector *)&x);
00146       gsl_permutation_free(P);
00147       
00148       return 0;
00149     };
00150     
00151     virtual ~gsl_LU_solver() {}
00152     
00153   };
00154 
00155   /// GSL solver by QR decomposition
00156   class gsl_QR_solver : public linear_solver<o2scl::ovector_view,
00157     o2scl::omatrix_view> {
00158     
00159   public:
00160     
00161     /// Solve square linear system \f$ A x = b \f$ of size \c n
00162     virtual int solve(size_t n, o2scl::omatrix_view &A, o2scl::ovector_view &b,
00163                       o2scl::ovector_view &x) {
00164       o2scl::ovector tau(n);
00165       gsl_linalg_QR_decomp((gsl_matrix *)&A,(gsl_vector *)&tau);
00166       gsl_linalg_QR_solve((gsl_matrix *)&A,(gsl_vector *)&tau,
00167                           (gsl_vector *)&b,(gsl_vector *)&x);
00168       return 0;
00169     };
00170     
00171     virtual ~gsl_QR_solver() {}
00172     
00173   };
00174 
00175   /// GSL Householder solver
00176   class gsl_HH_solver : public linear_solver<o2scl::ovector_view,
00177     o2scl::omatrix_view> {
00178     
00179   public:
00180     
00181     /// Solve square linear system \f$ A x = b \f$ of size \c n
00182     virtual int solve(size_t n, o2scl::omatrix_view &A, o2scl::ovector_view &b,
00183                       o2scl::ovector_view &x) {
00184       return gsl_linalg_HH_solve((gsl_matrix *)(&A),(gsl_vector *)(&b),
00185                                  (gsl_vector *)(&x));
00186     };
00187     
00188     virtual ~gsl_HH_solver() {}
00189     
00190   };
00191 
00192   /** 
00193       \brief Make a coordinate matrix for ode_it_solve
00194       
00195       
00196   */
00197   class ode_it_make_Coord {
00198 
00199   public:
00200 
00201     /// The row index
00202     o2scl::uvector_int *r;
00203 
00204     /// The column pointer
00205     o2scl::uvector_int *c;
00206 
00207     /// The matrix entries
00208     o2scl::uvector *vals;
00209   
00210     /// Create a compressed-column format matrix for \ref ode_it_solve
00211     o2scl::Coord_Mat *make(size_t ngrid, size_t neq, size_t nbleft) {
00212       size_t N, nz, nbright=neq-nbleft;
00213       N=ngrid*neq;
00214       
00215       nz=neq*neq*2*(ngrid-1)+neq*neq;
00216       r=new o2scl::uvector_int(nz);
00217       c=new o2scl::uvector_int(nz);
00218       vals=new o2scl::uvector(nz);
00219       vals->set_all(1.0);
00220       size_t row=0;
00221       size_t ix=0;
00222       for(size_t i=0;i<nbleft;i++) {
00223         for(size_t j=0;j<neq;j++) {
00224           (*r)[ix]=row;
00225           (*c)[ix]=j;
00226           (*vals)[ix]=0.0;
00227           ix++;
00228         }
00229         row++;
00230       }
00231       for(size_t k=0;k<ngrid-1;k++) {
00232         size_t lhs=k*neq;
00233         for(size_t i=0;i<neq;i++) {
00234           for(size_t j=0;j<neq;j++) {
00235             (*r)[ix]=row;
00236             (*c)[ix]=lhs+j;
00237             ix++;
00238             (*r)[ix]=row;
00239             (*c)[ix]=lhs+j+neq;
00240             ix++;
00241           }
00242           row++;
00243         }
00244       }
00245       for(size_t i=0;i<nbright;i++) {
00246         size_t lhs=neq*(ngrid-1);
00247         for(size_t j=0;j<neq;j++) {
00248           (*r)[ix]=row;
00249           (*c)[ix]=lhs+j;
00250           ix++;
00251         }
00252         row++;
00253       }
00254       o2scl::Coord_Mat *M=new o2scl::Coord_Mat(N,N,nz,*vals,*r,*c,0);
00255       return M;
00256     }
00257 
00258   };
00259 
00260 
00261   /** \brief ODE solver using a generic linear solver to solve 
00262       finite-difference equations
00263       
00264       
00265       \todo Max and average tolerance?
00266       \todo partial correction option?
00267   */
00268   template <class func_t, class vec_t, class mat_t, class matrix_row_t,
00269     class solver_vec_t, class solver_mat_t> 
00270     class ode_it_solve {
00271   
00272   public:
00273   
00274     ode_it_solve() {
00275       h=1.0e-4;
00276       niter=30;
00277       tolf=1.0e-8;
00278       verbose=0;
00279     }
00280 
00281     virtual ~ode_it_solve() {}
00282 
00283     /// Set level of output (default 0)
00284     int verbose;
00285 
00286     /// Stepsize for finite differencing (default \f$ 10^{-4} \f$)
00287     double h;
00288 
00289     /// Tolerance (default \f$ 10^{-8} \f$)
00290     double tolf;
00291   
00292     /// Maximum number of iterations (default 30)
00293     size_t niter;
00294   
00295     /// Set the linear solver
00296     int set_solver(linear_solver<solver_vec_t,solver_mat_t> &ls) {
00297       solver=&ls;
00298       return 0;
00299     }
00300     
00301     /// Solve \c derivs with boundary conditions \c left and \c right
00302     int solve(size_t ngrid, size_t neq, size_t nbleft, vec_t &x, mat_t &y,
00303               func_t &derivs, func_t &left, func_t &right,
00304               solver_mat_t &mat, solver_vec_t &rhs, solver_vec_t &dy) {
00305 
00306       // Store the functions for simple derivatives
00307       fd=&derivs;
00308       fl=&left;
00309       fr=&right;
00310     
00311       // Variable index
00312       size_t ix;
00313 
00314       // Number of RHS boundary conditions
00315       size_t nbright=neq-nbleft;
00316 
00317       // Number of variables
00318       size_t nvars=ngrid*neq;
00319     
00320       bool done=false;
00321       for(size_t it=0;done==false && it<niter;it++) {
00322       
00323         ix=0;
00324       
00325         mat.set_all(0.0);
00326 
00327         for(size_t i=0;i<nbleft;i++) {
00328           matrix_row_t yk(y,0);
00329           rhs[ix]=-left(i,x[0],yk);
00330           for(size_t j=0;j<neq;j++) {
00331             int rxa=mat.set(ix,j,fd_left(i,j,x[0],yk));
00332           }
00333           ix++;
00334         }
00335 
00336         for(size_t k=0;k<ngrid-1;k++) {
00337           size_t kp1=k+1;
00338           double tx=(x[kp1]+x[k])/2.0;
00339           double dx=x[kp1]-x[k];
00340           matrix_row_t yk(y,k);
00341           matrix_row_t ykp1(y,k+1);
00342         
00343           for(size_t i=0;i<neq;i++) {
00344           
00345             rhs[ix]=y[k][i]-y[kp1][i]+(x[kp1]-x[k])*
00346               (derivs(i,tx,ykp1)+derivs(i,tx,yk))/2.0;
00347           
00348             size_t lhs=k*neq;
00349             for(size_t j=0;j<neq;j++) {
00350               int rxb=mat.set(ix,lhs+j,-fd_derivs(i,j,tx,yk)*dx/2.0);
00351               int rxc=mat.set(ix,lhs+j+neq,-fd_derivs(i,j,tx,ykp1)*dx/2.0);
00352               if (i==j) {
00353                 int rxd=mat.set(ix,lhs+j,mat(ix,lhs+j)-1.0);
00354                 int rxe=mat.set(ix,lhs+j+neq,mat(ix,lhs+j+neq)+1.0);
00355               }
00356             }
00357 
00358             ix++;
00359 
00360           }
00361         }
00362       
00363         for(size_t i=0;i<nbright;i++) {
00364           matrix_row_t ylast(y,ngrid-1);
00365           size_t lhs=neq*(ngrid-1);
00366         
00367           rhs[ix]=-right(i,x[ngrid-1],ylast);
00368         
00369           for(size_t j=0;j<neq;j++) {
00370             int rxf=mat.set(ix,lhs+j,fd_right(i,j,x[ngrid-1],ylast));
00371           }
00372           
00373           ix++;
00374 
00375         }
00376       
00377         // Compute correction
00378       
00379         int ret=solver->solve(ix,mat,rhs,dy);
00380       
00381         // Apply correction and compute its size
00382 
00383         double res=0.0;
00384         ix=0;
00385 
00386         for(size_t igrid=0;igrid<ngrid;igrid++) {
00387           for(size_t ieq=0;ieq<neq;ieq++) {
00388             y[igrid][ieq]+=dy[ix];
00389             res+=dy[ix]*dy[ix];
00390             ix++;
00391           }
00392         }
00393         
00394         std::cout << it << " " << sqrt(res) << " " << tolf << std::endl;
00395       
00396         if (sqrt(res)<=tolf) done=true;
00397       }
00398 
00399       if (done==false) {
00400         set_err_ret("Exceeded number of iterations in solve().",
00401                     o2scl::gsl_emaxiter);
00402       }
00403 
00404       return 0;
00405     }
00406 
00407   protected:
00408   
00409     /// \name Storage for functions
00410     //@{
00411     ode_it_funct<vec_t> *fl, *fr, *fd;
00412     //@}
00413 
00414     /// Solver
00415     linear_solver<solver_vec_t,solver_mat_t> *solver;
00416   
00417     /// Compute the derivatives of the LHS boundary conditions
00418     double fd_left(size_t ieq, size_t ivar, double x, vec_t &y) {
00419       double ret, dydx;
00420     
00421       y[ivar]+=h;
00422       ret=(*fl)(ieq,x,y);
00423     
00424       y[ivar]-=h;
00425       ret-=(*fl)(ieq,x,y);
00426     
00427       ret/=h;
00428       return ret;
00429     }
00430   
00431     /// Compute the derivatives of the RHS boundary conditions
00432     double fd_right(size_t ieq, size_t ivar, double x, vec_t &y) {
00433       double ret, dydx;
00434     
00435       y[ivar]+=h;
00436       ret=(*fr)(ieq,x,y);
00437     
00438       y[ivar]-=h;
00439       ret-=(*fr)(ieq,x,y);
00440     
00441       ret/=h;
00442       return ret;
00443     }
00444   
00445     /// Compute the finite-differenced part of the differential equations
00446     double fd_derivs(size_t ieq, size_t ivar, double x, vec_t &y) {
00447       double ret, dydx;
00448     
00449       y[ivar]+=h;
00450       ret=(*fd)(ieq,x,y);
00451     
00452       y[ivar]-=h;
00453       ret-=(*fd)(ieq,x,y);
00454     
00455       ret/=h;
00456     
00457       return ret;
00458     }
00459   
00460   };
00461 
00462 #ifndef DOXYGENP
00463 }
00464 #endif
00465 
00466 #endif

Documentation generated with Doxygen and provided under the GNU Free Documentation License. See License Information for details.