00001
00012 #include "rroots.h"
00013
/* Expand a list of quadratic (and at most one linear) factors into roots.
 *
 * a  : n factor coefficients in the layout produced by get_quads(). Working
 *      back from the end, each pair (a[k], a[k+1]) describes the quadratic
 *      x^2 + a[k]*x + a[k+1]; when n is odd, a[0] is the linear factor
 *      x + a[0].
 * n  : number of coefficients (the degree of the factored polynomial).
 * wr : output, n real parts; wr[k] and wr[k+1] are the roots of the factor
 *      stored at a[k], a[k+1].
 * wi : output, n imaginary parts, laid out like wr.
 * Returns the number of roots written to wr/wi (every root is counted,
 * including a double root at zero).
 */
int roots(double *a,int n,double *wr,double *wi)
{
    int m = n;
    int numroots = 0;

    /* Peel quadratic factors off the end of the list. */
    while (m > 1) {
        double b2 = -0.5*a[m-2];   /* roots are b2 +/- sqrt(b2^2 - c) */
        double c = a[m-1];
        double disc = b2*b2-c;
        double sq;

        /* A discriminant that is only round-off relative to its terms is
           treated as exactly zero, yielding a clean double root. */
        if (fabs(disc)/(fabs(b2*b2)+fabs(c)) <= DBL_EPSILON) disc = 0.0;

        if (disc < 0.0) {
            /* Complex-conjugate pair. */
            sq = sqrt(-disc);
            wr[m-2] = b2;
            wi[m-2] = sq;
            wr[m-1] = b2;
            wi[m-1] = -sq;
        }
        else {
            /* Real pair. Form the larger-magnitude root by adding quantities
               of equal sign (no cancellation), then get the other one from
               the product of the roots, which equals c. */
            sq = sqrt(disc);
            wr[m-2] = fabs(b2)+sq;
            if (b2 < 0.0) wr[m-2] = -wr[m-2];
            /* wr[m-2] can only be zero when both coefficients are zero,
               i.e. the factor is x^2: a double root at zero. */
            wr[m-1] = (wr[m-2] == 0.0) ? 0.0 : c/wr[m-2];
            wi[m-2] = 0.0;
            wi[m-1] = 0.0;
        }
        /* Both roots of the factor were stored, so count both. */
        numroots += 2;
        m -= 2;
    }

    /* Odd degree: the remaining linear factor x + a[0] has root -a[0]. */
    if (m == 1) {
        wr[0] = -a[0];
        wi[0] = 0.0;
        numroots++;
    }
    return numroots;
}
00057
/* Divide a monic polynomial by the quadratic x^2 + r*x + s (synthetic
 * division).
 *
 * a    : n+1 coefficients, highest power first; a[0] is taken to be 1 and
 *        is not read.
 * n    : degree of a; must be >= 2.
 * b    : output, n+1 entries. b[0] is set to 1 and b[0..n-2] hold the
 *        quotient; the remainder is b[n-1]*(x + r) + b[n].
 * quad : the divisor, quad[0] = s and quad[1] = r.
 * err  : set to |b[n]| + |b[n-1]|, which is zero when the quadratic divides
 *        a exactly.
 */
void deflate(double *a,int n,double *b,double *quad,double *err)
{
    double r = quad[1];
    double s = quad[0];
    int i;

    /* The quotient of a monic polynomial is monic; set it here rather than
       relying on the caller, since b[0] is read by the recurrence below. */
    b[0] = 1.0;
    b[1] = a[1] - r;
    for (i=2;i<=n;i++) {
        b[i] = a[i] - r * b[i-1] - s * b[i-2];
    }
    *err = fabs(b[n])+fabs(b[n-1]);
}
00079
00084 void find_quad(double *a,int n,double *b,double *quad,double *err, int *iter)
00085 {
00086 double *c,dn,dr,ds,drn,dsn,eps,r,s;
00087 int i;
00088
00089 c = new double [n+1];
00090 c[0] = 1.0;
00091 r = quad[1];
00092 s = quad[0];
00093 dr = 1.0;
00094 ds = 0;
00095 eps = 1e-15;
00096 *iter = 1;
00097
00098 while ((fabs(dr)+fabs(ds)) > eps) {
00099 if (*iter > maxiter) break;
00100 if (((*iter) % 200) == 0) {
00101 eps*=10.0;
00102 }
00103 b[1] = a[1] - r;
00104 c[1] = b[1] - r;
00105
00106 for (i=2;i<=n;i++){
00107 b[i] = a[i] - r * b[i-1] - s * b[i-2];
00108 c[i] = b[i] - r * c[i-1] - s * c[i-2];
00109 }
00110 dn=c[n-1] * c[n-3] - c[n-2] * c[n-2];
00111 drn=b[n] * c[n-3] - b[n-1] * c[n-2];
00112 dsn=b[n-1] * c[n-1] - b[n] * c[n-2];
00113
00114 if (fabs(dn) < 1e-15) {
00115 if (dn < 0.0) dn = -1e-8;
00116 else dn = 1e-8;
00117 }
00118 dr = drn / dn;
00119 ds = dsn / dn;
00120 r += dr;
00121 s += ds;
00122 (*iter)++;
00123 }
00124 quad[0] = s;
00125 quad[1] = r;
00126 *err = fabs(ds)+fabs(dr);
00127 delete [] c;
00128 }
00129
/* Derivative of a monic polynomial, rescaled to be monic again.
 *
 * a : n+1 coefficients of x^n + a[1]*x^(n-1) + ... + a[n]; a[0] is not read.
 * n : degree of a.
 * b : output, n entries: the derivative divided by n, i.e. b[0] = 1 and
 *     b[k] = a[k]*(n-k)/n for 1 <= k < n.
 */
void diff_poly(double *a,int n,double *b)
{
    const double scale = (double)n;
    int k = 1;

    b[0] = 1.0;
    while (k < n) {
        b[k] = a[k] * (double)(n - k) / scale;
        ++k;
    }
}
00142
00162 void recurse(double *a,int n,double *b,int m,double *quad,
00163 double *err,int *iter)
00164 {
00165 double *c,*x,rs[2],tst;
00166
00167 if (fabs(b[m]) < 1e-16) m--;
00168 if (m == 2) {
00169 quad[0] = b[2];
00170 quad[1] = b[1];
00171 *err = 0;
00172 *iter = 0;
00173 return;
00174 }
00175 c = new double [m+1];
00176 x = new double [n+1];
00177 c[0] = x[0] = 1.0;
00178 rs[0] = quad[0];
00179 rs[1] = quad[1];
00180 *iter = 0;
00181 find_quad(b,m,c,rs,err,iter);
00182 tst = fabs(rs[0]-quad[0])+fabs(rs[1]-quad[1]);
00183 if (*err < 1e-12) {
00184 quad[0] = rs[0];
00185 quad[1] = rs[1];
00186 }
00187
00188 if (((*iter > 5) && (tst < 1e-4)) || ((*iter > 20) && (tst < 1e-1))) {
00189 diff_poly(b,m,c);
00190 recurse(a,n,c,m-1,rs,err,iter);
00191 quad[0] = rs[0];
00192 quad[1] = rs[1];
00193 }
00194 delete [] x;
00195 delete [] c;
00196 }
00197
00201 void get_quads(double *a,int n,double *quad,double *x)
00202 {
00203 double *b,*z,err,tmp;
00204 int iter,i,m;
00205
00206 if ((tmp = a[0]) != 1.0) {
00207 a[0] = 1.0;
00208 for (i=1;i<=n;i++) {
00209 a[i] /= tmp;
00210 }
00211 }
00212 if (n == 2) {
00213 x[0] = a[1];
00214 x[1] = a[2];
00215 return;
00216 }
00217 else if (n == 1) {
00218 x[0] = a[1];
00219 return;
00220 }
00221 m = n;
00222 b = new double [n+1];
00223 z = new double [n+1];
00224 b[0] = 1.0;
00225 for (i=0;i<=n;i++) {
00226 z[i] = a[i];
00227 x[i] = 0.0;
00228 }
00229 do {
00230 if (n > m) {
00231 quad[0] = 3.14159e-1;
00232 quad[1] = 2.78127e-1;
00233 }
00234 loop:
00235 find_quad(z,m,b,quad,&err,&iter);
00236 if ((err > 1e-7) || (iter > maxiter)) {
00237 diff_poly(z,m,b);
00238 iter = 0;
00239 recurse(z,m,b,m-1,quad,&err,&iter);
00240 }
00241 deflate(z,m,b,quad,&err);
00242 if (err > 0.01) {
00244
00245
00246
00247
00248
00249 quad[0] = -2.71828e-1;
00250 quad[1] = -3.14159e-1;
00251
00252
00253 }
00254 if (err > 1) goto loop;
00255 x[m-2] = quad[1];
00256 x[m-1] = quad[0];
00257 m -= 2;
00258 for (i=0;i<=m;i++) {
00259 z[i] = b[i];
00260 }
00261 } while (m > 2);
00262 if (m == 2) {
00263 x[0] = b[1];
00264 x[1] = b[2];
00265 }
00266 else x[0] = b[1];
00267 delete [] z;
00268 delete [] b;
00269 }