00001 #ifndef O2SCL_ODE_BV_MULTISHOOT_H 00002 #define O2SCL_ODE_BV_MULTISHOOT_H 00003 00004 #include <string> 00005 #include <o2scl/collection.h> 00006 #include <o2scl/ovector_tlate.h> 00007 #include <o2scl/adapt_step.h> 00008 #include <o2scl/gsl_astep.h> 00009 #include <o2scl/ode_iv_solve.h> 00010 #include <o2scl/gsl_mroot_hybrids.h> 00011 00012 using namespace std; 00013 00014 /* 00015 00016 Here is the multishooting method 00017 */ 00018 00019 00020 namespace o2scl { 00021 00022 00023 /* Multishooting */ 00024 00025 00026 template < class param_t, 00027 class func_t=ode_funct<param_t>, 00028 class vec_t=ovector_view, 00029 class alloc_vec_t=ovector, class alloc_t=ovector_alloc, 00030 class vec_int_t=ovector_int_view, class mat_t=omatrix> class ode_bv_multishoot { 00031 00032 public : 00033 00034 ode_bv_multishoot(){ 00035 oisp=&def_ois; 00036 mrootp=&def_mroot; 00037 } 00038 virtual ~ode_bv_multishoot(){} 00039 00040 00041 /* Solve the boundary value problem on a mesh */ 00042 virtual int solve(vec_t &mesh,int &n_func,vec_t &y_start, 00043 param_t ¶m,func_t &left_b,func_t &right_b, 00044 func_t &extra_b,func_t &derivs,vec_t &x_save,mat_t &y_save){ 00045 00046 /* Make copies of the input for later access */ 00047 this->l_mesh = &mesh; 00048 this->l_n_func = &n_func; 00049 this->l_y_start = &y_start; 00050 this->l_param = ¶m; 00051 this->l_left_b = &left_b; 00052 this->l_right_b = &right_b; 00053 this->l_extra_b = &extra_b; 00054 this->l_derivs = &derivs; 00055 this->l_x_save = &x_save; 00056 this->l_y_save = &y_save; 00057 00058 /* vector of variables */ 00059 int nsolve=y_start.size(); 00060 ovector sx(nsolve),sy(nsolve); 00061 sx = y_start; 00062 00063 /* Equation solver */ 00064 mm_funct_mfptr<ode_bv_multishoot<param_t,func_t,vec_t, 00065 alloc_vec_t,alloc_t,vec_int_t>,param_t> 00066 mfm(this,&ode_bv_multishoot<param_t,func_t,vec_t, 00067 alloc_vec_t,alloc_t,vec_int_t>::solve_fun); 00068 00069 /* Run multishooting and save at the last step */ 00070 
this->save=false; 00071 int ret=this->mrootp->msolve(nsolve,sx,param,mfm); 00072 this->save= true; 00073 solve_fun(nsolve,sx,sy,param); 00074 00075 return ret; 00076 } 00077 00078 /* Set initial value solver */ 00079 int set_iv(ode_iv_solve<param_t,func_t,vec_t,alloc_vec_t,alloc_t> &ois){ 00080 oisp=&ois; 00081 return 0; 00082 } 00083 /* Set equation solver */ 00084 int set_mroot(mroot<param_t,mm_funct<param_t> > &root){ 00085 mrootp=&root; 00086 return 0; 00087 } 00088 /* Set default initial value solver */ 00089 ode_iv_solve<param_t,func_t,vec_t,alloc_vec_t,alloc_t> def_ois; 00090 00091 /* Set default equation solver */ 00092 gsl_mroot_hybrids<param_t,mm_funct<param_t> > def_mroot; 00093 00094 00095 protected : 00096 00097 ode_iv_solve<param_t,func_t,vec_t,alloc_vec_t,alloc_t> *oisp; 00098 gsl_mroot_hybrids<param_t,mm_funct<param_t> > *mrootp; 00099 vec_t *l_mesh; 00100 vec_t *l_y_start; 00101 param_t *l_param; 00102 func_t *l_left_b; 00103 func_t *l_right_b; 00104 func_t *l_extra_b; 00105 func_t *l_derivs; 00106 int *l_n_func; 00107 vec_t *l_x_save; 00108 mat_t *l_y_save; 00109 bool save; 00110 00111 int solve_fun(size_t nv,const vec_t &sx,vec_t &sy,param_t &pa){ 00112 00113 double xa,xb=0.0,h; 00114 ovector y((*this->l_n_func)),y2((*this->l_n_func)); 00115 00116 00117 /* We update y_start in order that derivs know all the values of parameters */ 00118 for(size_t i=0;i<(*this->l_y_start).size();i++) 00119 (*this->l_y_start)[i]=sx[i]; 00120 00121 00122 /* A loop on each subinterval */ 00123 for(size_t k=0;k<(*this->l_mesh).size()-1;k++){ 00124 00125 xa=(*this->l_mesh)[k]; 00126 xb=(*this->l_mesh)[k+1]; 00127 h=(xb-xa)/100.0; 00128 00129 /* We load function's value at the left point of the sub-interval */ 00130 if(k==0) 00131 (*this->l_left_b)(xa,(*this->l_n_func),sx,y,pa); 00132 else 00133 for(int i=0;i<(*this->l_n_func);i++) 00134 y[i]=sx[i+(*this->l_n_func)*(1+k)]; 00135 00136 /* iv_solver if we save */ 00137 if(this->save){ 00138 int 
ngrid=((*this->l_x_save).size()-1)/((*this->l_mesh).size()-1)+1; 00139 ovector xxsave(ngrid); 00140 omatrix yysave(ngrid,(*this->l_n_func)); 00141 00142 if(k!=((*this->l_mesh).size()-2)) 00143 xb=(*this->l_mesh)[k+1]-((*this->l_mesh)[k+1]-(*this->l_mesh)[k])/ngrid; 00144 00145 this->oisp->solve_grid(xa,xb,h,(*this->l_n_func),y,ngrid, 00146 xxsave,yysave,pa,(*this->l_derivs)); 00147 for(int i=0;i<ngrid;i++){ 00148 (*this->l_x_save)[i+k*(ngrid)] = xxsave[i]; 00149 for(int j=0;j<(*this->l_n_func);j++) 00150 (*this->l_y_save)[i+k*(ngrid)][j] = yysave[i][j]; 00151 } 00152 00153 } 00154 /* iv_solver if we don't save */ 00155 else 00156 this->oisp->solve_final_value(xa,xb,h,(*this->l_n_func),y,y2,pa,(*this->l_derivs)); 00157 00158 /* Then we load values at the end of sub-interval */ 00159 if(k==(*this->l_mesh).size()-2) 00160 (*this->l_right_b)(xb,(*this->l_n_func),sx,y,pa); 00161 else 00162 for(int i=0;i<(*this->l_n_func);i++) 00163 y[i]=sx[i+(*this->l_n_func)*(2+k)]; 00164 00165 /* Now we take the difference */ 00166 for(int i=0;i<(*this->l_n_func);i++) 00167 //sy[i+k*(*this->l_n_func)]=(y2[i]-y[i])/y[i]; 00168 sy[i+k*(*this->l_n_func)]= y2[i]-y[i]; 00169 00170 } 00171 00172 /* Then load Extra boundary condition */ 00173 (*this->l_extra_b)(xb,(*this->l_n_func),sx,y,pa); 00174 for(int i=0;i<(*this->l_n_func);i++) 00175 sy[i+(int((*this->l_mesh).size()-1))*(*this->l_n_func)]=y[i]; 00176 00177 00178 return 0; 00179 } 00180 00181 }; 00182 } 00183 00184 00185 00186 #endif
Documentation generated with Doxygen and provided under the GNU Free Documentation License. See License Information for details.
Project hosting provided by SourceForge; see the O2scl Sourceforge Project Page.