00001 /* 00002 ------------------------------------------------------------------- 00003 00004 Copyright (C) 2006, 2007, Andrew W. Steiner 00005 00006 This file is part of O2scl. 00007 00008 O2scl is free software; you can redistribute it and/or modify 00009 it under the terms of the GNU General Public License as published by 00010 the Free Software Foundation; either version 3 of the License, or 00011 (at your option) any later version. 00012 00013 O2scl is distributed in the hope that it will be useful, 00014 but WITHOUT ANY WARRANTY; without even the implied warranty of 00015 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00016 GNU General Public License for more details. 00017 00018 You should have received a copy of the GNU General Public License 00019 along with O2scl. If not, see <http://www.gnu.org/licenses/>. 00020 00021 ------------------------------------------------------------------- 00022 */ 00023 #ifndef O2SCL_ODE_IT_SOLVE_H 00024 #define O2SCL_ODE_IT_SOLVE_H 00025 00026 #include <iostream> 00027 00028 #include <gsl/gsl_linalg.h> 00029 #include <gsl/gsl_ieee_utils.h> 00030 00031 #include <o2scl/array.h> 00032 #include <o2scl/ovector_tlate.h> 00033 #include <o2scl/omatrix_tlate.h> 00034 #include <o2scl/uvector_tlate.h> 00035 #include <o2scl/umatrix_tlate.h> 00036 #include <o2scl/test_mgr.h> 00037 00038 #include <o2scl/coord.h> 00039 #include <o2scl/bicg.h> 00040 #include <o2scl/bicgstab.h> 00041 #include <o2scl/gmres.h> 00042 #include <o2scl/ir.h> 00043 00044 #ifndef DOXYGENP 00045 namespace o2scl { 00046 #endif 00047 00048 /// Function class for ode_it_solve 00049 template<class vec_t=o2scl::ovector_view> class ode_it_funct { 00050 00051 public: 00052 00053 ode_it_funct() {} 00054 00055 virtual ~ode_it_funct() {} 00056 00057 /// Using \c x and \c y, return the value of function number \c ieq 00058 virtual double operator()(size_t ieq, double x, vec_t &y) { 00059 return 0.0; 00060 } 00061 00062 }; 00063 00064 /// Function pointer for 
ode_it_solve 00065 template<class vec_t=o2scl::ovector_view> 00066 class ode_it_funct_fptr : public ode_it_funct<vec_t> { 00067 00068 public: 00069 00070 virtual ~ode_it_funct_fptr() {} 00071 00072 /// Create using a function pointer 00073 ode_it_funct_fptr(double (*fp)(size_t, double, vec_t &)) { 00074 fptr=fp; 00075 } 00076 00077 /// Using \c x and \c y, return the value of function number \c ieq 00078 virtual double operator()(size_t ieq, double x, vec_t &y) { 00079 return fptr(ieq,x,y); 00080 } 00081 00082 protected: 00083 00084 ode_it_funct_fptr() {} 00085 00086 /// The function pointer 00087 double (*fptr)(size_t ieq, double x, vec_t &y); 00088 }; 00089 00090 /// Member function pointer for ode_it_solve 00091 template<class tclass, class vec_t=o2scl::ovector_view> 00092 class ode_it_funct_mfptr : public ode_it_funct<vec_t> { 00093 public: 00094 00095 virtual ~ode_it_funct_mfptr() {} 00096 00097 /// Create using a class instance and member function 00098 ode_it_funct_mfptr(tclass *tp, 00099 double (tclass::*fp)(size_t, double, vec_t &)) { 00100 tptr=tp; 00101 fptr=fp; 00102 } 00103 00104 /// Using \c x and \c y, return the value of function number \c ieq 00105 virtual double operator()(size_t ieq, double x, vec_t &y) { 00106 return (*tptr.*fptr)(ieq,x,y); 00107 } 00108 00109 protected: 00110 00111 ode_it_funct_mfptr() {} 00112 00113 /// The class pointer 00114 tclass *tptr; 00115 00116 /// The member function pointer 00117 double (tclass::*fptr)(size_t ieq, double x, vec_t &y); 00118 00119 }; 00120 00121 /** 00122 \brief A generic solver for the linear system \f$ A x = b \f$ 00123 */ 00124 template<class vec_t, class mat_t> class linear_solver { 00125 public: 00126 virtual ~linear_solver() {} 00127 virtual int solve(size_t n, mat_t &a, vec_t &b, vec_t &x) { 00128 return 0; 00129 } 00130 }; 00131 00132 /// GSL solver by LU decomposition 00133 class gsl_LU_solver : public linear_solver<o2scl::ovector_view, 00134 o2scl::omatrix_view> { 00135 00136 public: 00137 
00138 /// Solve square linear system \f$ A x = b \f$ of size \c n 00139 virtual int solve(size_t n, o2scl::omatrix_view &A, o2scl::ovector_view &b, 00140 o2scl::ovector_view &x) { 00141 gsl_permutation *P=gsl_permutation_alloc(n); 00142 int s; 00143 gsl_linalg_LU_decomp((gsl_matrix *)&A,P,&s); 00144 gsl_linalg_LU_solve((gsl_matrix *)&A,P,(gsl_vector *)&b, 00145 (gsl_vector *)&x); 00146 gsl_permutation_free(P); 00147 00148 return 0; 00149 }; 00150 00151 virtual ~gsl_LU_solver() {} 00152 00153 }; 00154 00155 /// GSL solver by QR decomposition 00156 class gsl_QR_solver : public linear_solver<o2scl::ovector_view, 00157 o2scl::omatrix_view> { 00158 00159 public: 00160 00161 /// Solve square linear system \f$ A x = b \f$ of size \c n 00162 virtual int solve(size_t n, o2scl::omatrix_view &A, o2scl::ovector_view &b, 00163 o2scl::ovector_view &x) { 00164 o2scl::ovector tau(n); 00165 gsl_linalg_QR_decomp((gsl_matrix *)&A,(gsl_vector *)&tau); 00166 gsl_linalg_QR_solve((gsl_matrix *)&A,(gsl_vector *)&tau, 00167 (gsl_vector *)&b,(gsl_vector *)&x); 00168 return 0; 00169 }; 00170 00171 virtual ~gsl_QR_solver() {} 00172 00173 }; 00174 00175 /// GSL Householder solver 00176 class gsl_HH_solver : public linear_solver<o2scl::ovector_view, 00177 o2scl::omatrix_view> { 00178 00179 public: 00180 00181 /// Solve square linear system \f$ A x = b \f$ of size \c n 00182 virtual int solve(size_t n, o2scl::omatrix_view &A, o2scl::ovector_view &b, 00183 o2scl::ovector_view &x) { 00184 return gsl_linalg_HH_solve((gsl_matrix *)(&A),(gsl_vector *)(&b), 00185 (gsl_vector *)(&x)); 00186 }; 00187 00188 virtual ~gsl_HH_solver() {} 00189 00190 }; 00191 00192 /** 00193 \brief Make a coordinate matrix for ode_it_solve 00194 00195 00196 */ 00197 class ode_it_make_Coord { 00198 00199 public: 00200 00201 /// The row index 00202 o2scl::uvector_int *r; 00203 00204 /// The column pointer 00205 o2scl::uvector_int *c; 00206 00207 /// The matrix entries 00208 o2scl::uvector *vals; 00209 00210 /// Create a 
compressed-column format matrix for \ref ode_it_solve 00211 o2scl::Coord_Mat *make(size_t ngrid, size_t neq, size_t nbleft) { 00212 size_t N, nz, nbright=neq-nbleft; 00213 N=ngrid*neq; 00214 00215 nz=neq*neq*2*(ngrid-1)+neq*neq; 00216 r=new o2scl::uvector_int(nz); 00217 c=new o2scl::uvector_int(nz); 00218 vals=new o2scl::uvector(nz); 00219 vals->set_all(1.0); 00220 size_t row=0; 00221 size_t ix=0; 00222 for(size_t i=0;i<nbleft;i++) { 00223 for(size_t j=0;j<neq;j++) { 00224 (*r)[ix]=row; 00225 (*c)[ix]=j; 00226 (*vals)[ix]=0.0; 00227 ix++; 00228 } 00229 row++; 00230 } 00231 for(size_t k=0;k<ngrid-1;k++) { 00232 size_t lhs=k*neq; 00233 for(size_t i=0;i<neq;i++) { 00234 for(size_t j=0;j<neq;j++) { 00235 (*r)[ix]=row; 00236 (*c)[ix]=lhs+j; 00237 ix++; 00238 (*r)[ix]=row; 00239 (*c)[ix]=lhs+j+neq; 00240 ix++; 00241 } 00242 row++; 00243 } 00244 } 00245 for(size_t i=0;i<nbright;i++) { 00246 size_t lhs=neq*(ngrid-1); 00247 for(size_t j=0;j<neq;j++) { 00248 (*r)[ix]=row; 00249 (*c)[ix]=lhs+j; 00250 ix++; 00251 } 00252 row++; 00253 } 00254 o2scl::Coord_Mat *M=new o2scl::Coord_Mat(N,N,nz,*vals,*r,*c,0); 00255 return M; 00256 } 00257 00258 }; 00259 00260 00261 /** \brief ODE solver using a generic linear solver to solve 00262 finite-difference equations 00263 00264 00265 \todo Max and average tolerance? 00266 \todo partial correction option? 
00267 */ 00268 template <class func_t, class vec_t, class mat_t, class matrix_row_t, 00269 class solver_vec_t, class solver_mat_t> 00270 class ode_it_solve { 00271 00272 public: 00273 00274 ode_it_solve() { 00275 h=1.0e-4; 00276 niter=30; 00277 tolf=1.0e-8; 00278 verbose=0; 00279 } 00280 00281 virtual ~ode_it_solve() {} 00282 00283 /// Set level of output (default 0) 00284 int verbose; 00285 00286 /// Stepsize for finite differencing (default \f$ 10^{-4} \f$) 00287 double h; 00288 00289 /// Tolerance (default \f$ 10^{-8} \f$) 00290 double tolf; 00291 00292 /// Maximum number of iterations (default 30) 00293 size_t niter; 00294 00295 /// Set the linear solver 00296 int set_solver(linear_solver<solver_vec_t,solver_mat_t> &ls) { 00297 solver=&ls; 00298 return 0; 00299 } 00300 00301 /// Solve \c derivs with boundary conditions \c left and \c right 00302 int solve(size_t ngrid, size_t neq, size_t nbleft, vec_t &x, mat_t &y, 00303 func_t &derivs, func_t &left, func_t &right, 00304 solver_mat_t &mat, solver_vec_t &rhs, solver_vec_t &dy) { 00305 00306 // Store the functions for simple derivatives 00307 fd=&derivs; 00308 fl=&left; 00309 fr=&right; 00310 00311 // Variable index 00312 size_t ix; 00313 00314 // Number of RHS boundary conditions 00315 size_t nbright=neq-nbleft; 00316 00317 // Number of variables 00318 size_t nvars=ngrid*neq; 00319 00320 bool done=false; 00321 for(size_t it=0;done==false && it<niter;it++) { 00322 00323 ix=0; 00324 00325 mat.set_all(0.0); 00326 00327 for(size_t i=0;i<nbleft;i++) { 00328 matrix_row_t yk(y,0); 00329 rhs[ix]=-left(i,x[0],yk); 00330 for(size_t j=0;j<neq;j++) { 00331 int rxa=mat.set(ix,j,fd_left(i,j,x[0],yk)); 00332 } 00333 ix++; 00334 } 00335 00336 for(size_t k=0;k<ngrid-1;k++) { 00337 size_t kp1=k+1; 00338 double tx=(x[kp1]+x[k])/2.0; 00339 double dx=x[kp1]-x[k]; 00340 matrix_row_t yk(y,k); 00341 matrix_row_t ykp1(y,k+1); 00342 00343 for(size_t i=0;i<neq;i++) { 00344 00345 rhs[ix]=y[k][i]-y[kp1][i]+(x[kp1]-x[k])* 00346 
(derivs(i,tx,ykp1)+derivs(i,tx,yk))/2.0; 00347 00348 size_t lhs=k*neq; 00349 for(size_t j=0;j<neq;j++) { 00350 int rxb=mat.set(ix,lhs+j,-fd_derivs(i,j,tx,yk)*dx/2.0); 00351 int rxc=mat.set(ix,lhs+j+neq,-fd_derivs(i,j,tx,ykp1)*dx/2.0); 00352 if (i==j) { 00353 int rxd=mat.set(ix,lhs+j,mat(ix,lhs+j)-1.0); 00354 int rxe=mat.set(ix,lhs+j+neq,mat(ix,lhs+j+neq)+1.0); 00355 } 00356 } 00357 00358 ix++; 00359 00360 } 00361 } 00362 00363 for(size_t i=0;i<nbright;i++) { 00364 matrix_row_t ylast(y,ngrid-1); 00365 size_t lhs=neq*(ngrid-1); 00366 00367 rhs[ix]=-right(i,x[ngrid-1],ylast); 00368 00369 for(size_t j=0;j<neq;j++) { 00370 int rxf=mat.set(ix,lhs+j,fd_right(i,j,x[ngrid-1],ylast)); 00371 } 00372 00373 ix++; 00374 00375 } 00376 00377 // Compute correction 00378 00379 int ret=solver->solve(ix,mat,rhs,dy); 00380 00381 // Apply correction and compute its size 00382 00383 double res=0.0; 00384 ix=0; 00385 00386 for(size_t igrid=0;igrid<ngrid;igrid++) { 00387 for(size_t ieq=0;ieq<neq;ieq++) { 00388 y[igrid][ieq]+=dy[ix]; 00389 res+=dy[ix]*dy[ix]; 00390 ix++; 00391 } 00392 } 00393 00394 std::cout << it << " " << sqrt(res) << " " << tolf << std::endl; 00395 00396 if (sqrt(res)<=tolf) done=true; 00397 } 00398 00399 if (done==false) { 00400 set_err_ret("Exceeded number of iterations in solve().", 00401 o2scl::gsl_emaxiter); 00402 } 00403 00404 return 0; 00405 } 00406 00407 protected: 00408 00409 /// \name Storage for functions 00410 //@{ 00411 ode_it_funct<vec_t> *fl, *fr, *fd; 00412 //@} 00413 00414 /// Solver 00415 linear_solver<solver_vec_t,solver_mat_t> *solver; 00416 00417 /// Compute the derivatives of the LHS boundary conditions 00418 double fd_left(size_t ieq, size_t ivar, double x, vec_t &y) { 00419 double ret, dydx; 00420 00421 y[ivar]+=h; 00422 ret=(*fl)(ieq,x,y); 00423 00424 y[ivar]-=h; 00425 ret-=(*fl)(ieq,x,y); 00426 00427 ret/=h; 00428 return ret; 00429 } 00430 00431 /// Compute the derivatives of the RHS boundary conditions 00432 double fd_right(size_t ieq, size_t 
ivar, double x, vec_t &y) { 00433 double ret, dydx; 00434 00435 y[ivar]+=h; 00436 ret=(*fr)(ieq,x,y); 00437 00438 y[ivar]-=h; 00439 ret-=(*fr)(ieq,x,y); 00440 00441 ret/=h; 00442 return ret; 00443 } 00444 00445 /// Compute the finite-differenced part of the differential equations 00446 double fd_derivs(size_t ieq, size_t ivar, double x, vec_t &y) { 00447 double ret, dydx; 00448 00449 y[ivar]+=h; 00450 ret=(*fd)(ieq,x,y); 00451 00452 y[ivar]-=h; 00453 ret-=(*fd)(ieq,x,y); 00454 00455 ret/=h; 00456 00457 return ret; 00458 } 00459 00460 }; 00461 00462 #ifndef DOXYGENP 00463 } 00464 #endif 00465 00466 #endif
Documentation generated with Doxygen and provided under the GNU Free Documentation License. See License Information for details.
Project hosting provided by SourceForge; see the O2scl SourceForge Project Page.