Memosa-FVM  0.2
SquareMatrix.h
Go to the documentation of this file.
1 // This file os part of FVM
2 // Copyright (c) 2012 FVM Authors
3 // See LICENSE file for terms.
4 
5 #ifndef _SQUAREMATRIX_H_
6 #define _SQUAREMATRIX_H_
7 
8 #include "Array.h"
9 #include "MatrixJML.h"
10 #include <math.h>
11 
12 template<class T>
13 class SquareMatrix : public MatrixJML<T>
14 {
15  public:
16  typedef Array<T> TArray;
17  typedef shared_ptr<TArray> TArrPtr;
20 
21  SquareMatrix(const int N):
22  _order(N),
23  _elements(N*N),
24  _sorted(false),
25  _pivotRows(N),
26  _maxVals(N),
28  {_values.zero();}
29 
30  T& getElement(const int i, const int j) {return _values[(i-1)*_order+j-1];}
31  T& operator()(const int i, const int j) {return _values[(i-1)*_order+j-1];}
32  void zero() {_values.zero();}
33 
34  void Solve(TArray& bVec)
35  {//Gaussian Elimination w/ scaled partial pivoting
36  //replaces bVec with the solution vector.
37 
39  (*this).makeCopy(LU);
40  TArray bCpy(_order);
41  IntArray l(_order);
42  TArray s(_order);
43 
44  //find max values in each row if not done yet
45  if(!_sorted)
46  {
47  for(int i=1;i<_order+1;i++)
48  {
49  l[i-1]=i;
50  s[i-1]=fabs((*this)(i,1));
51  for(int j=2;j<_order+1;j++)
52  {
53  if(s[i-1]<fabs((*this)(i,j)))
54  s[i-1]=fabs((*this)(i,j));
55  }
56  if(s[i-1]==0)
57  {
58  cout<<"Row: "<<i<<endl;
59  throw CException("Matrix has row of zeros!");
60  }
61  }
62  _maxVals=s;
63  //cout<<"Max vals:"<<endl;
64  //for(int i=0;i<_order;i++)
65  //cout<<s[i]<<endl;
66  }
67 
68  //Forward sweep
69  if(!_sorted)
70  {
71  for(int i=1;i<_order;i++)
72  {
73  T rmax=fabs((*this)(l[i-1],i)/s[l[i-1]-1]);
74  // cout<<"rmax: "<<rmax<<endl;
75  int newMax=-1;
76  for(int j=i+1;j<_order+1;j++)
77  {
78  T r=fabs((*this)(l[j-1],i)/s[l[j-1]-1]);
79  if(r>rmax)
80  {
81  //cout<<"switching,i,j,r:"<<i<<","<<j<<","<<r<<endl;
82  //cout<<"factors "<<(*this)(l[j-1],i)<<" "<<s[l[j-1]-1]<<endl;
83  //cout<<"mapped j "<<l[j-1]<<endl;
84  rmax=r;
85  newMax=j;
86  }
87  }
88 
89  if(newMax!=-1)
90  {
91  int temp=l[i-1];
92  l[i-1]=newMax;
93  l[newMax-1]=temp;
94  }
95 
96  for(int j=i+1;j<_order+1;j++)
97  {
98  T factor=LU(l[j-1],i)/LU(l[i-1],i);
99  LU(l[j-1],i)=factor;
100  bVec[l[j-1]-1]-=factor*bVec[l[i-1]-1];
101  for(int k=i+1;k<_order+1;k++)
102  {
103  LU(l[j-1],k)-=LU(l[i-1],k)*factor;
104  T test=LU(l[j-1],k);
105  if(isnan(test)||isinf(test))
106  {
107  cout<<"Denom: "<<LU(l[i-1],i)<<endl;
108  cout<<"Num: "<<LU(l[j-1],i)<<endl;
109  cout<<"first: "<<LU(l[j-1],k)<<endl;
110  cout<<"second: "<<LU(l[i-1],k)<<endl;
111  cout<<"factor: "<<factor<<endl;
112  cout<<"test: "<<test<<endl;
113  throw CException("test is nan");
114  }
115  }
116  }
117  }
118  _pivotRows=l;
119  _sorted=true;
120  //cout<<"Order"<<endl;
121  //for(int i=0;i<_order;i++)
122  // cout<<_pivotRows[i]<<endl;
123  //cout<<endl;
124  }
125  else
126  {
127  for(int i=1;i<_order;i++)
128  {
129  for(int j=i+1;j<_order+1;j++)
130  {
131  T factor=LU(_pivotRows[j-1],i)/LU(_pivotRows[i-1],i);
132  LU(_pivotRows[j-1],i)=factor;
133  bVec[_pivotRows[j-1]-1]-=factor*bVec[_pivotRows[i-1]-1];
134  for(int k=i+1;k<_order+1;k++)
135  {
136  LU(_pivotRows[j-1],k)=LU(_pivotRows[j-1],k)
137  -LU(_pivotRows[i-1],k)*factor;
138  }
139  }
140  }
141  }
142 
143  //back solve
144  bVec[_pivotRows[_order-1]-1]=
145  bVec[_pivotRows[_order-1]-1]/LU(_pivotRows[_order-1],_order);
146 
147  T sum=0.;
148  for(int i=_order-1;i>0;i--)
149  {
150  sum=0.;
151  for(int j=i+1;j<_order+1;j++)
152  sum-=LU(_pivotRows[i-1],j)*bVec[_pivotRows[j-1]-1];
153  bVec[_pivotRows[i-1]-1]+=sum;
154  bVec[_pivotRows[i-1]-1]=bVec[_pivotRows[i-1]-1]/LU(_pivotRows[i-1],i);
155  }
156 
157  //reorder
158  bCpy=bVec;
159  for(int i=0;i<_order;i++)
160  bVec[i]=bCpy[_pivotRows[i]-1];
161 
162  }
163 
165  {
166  if(o._order!=this->_order)
167  throw CException("Cannot copy matrices of different sizes!");
168 
169  o._sorted=this->_sorted;
170  o._pivotRows=this->_pivotRows;
171  o._maxVals=this->_maxVals;
172  o._values=this->_values;
173  }
174 
175  void printMatrix()
176  {
177  for(int i=1;i<_order+1;i++)
178  {
179  for(int j=1;j<_order+1;j++)
180  cout<<(*this)(i,j)<<" ";
181  cout<<endl;
182  }
183  cout<<endl;
184  }
185 
187  {
188  T trace=0.;
189  for(int i=1;i<_order+1;i++)
190  trace+=fabs((*this)(i,i));
191  return trace;
192  }
193 
194  void multiply(const TArray& x, TArray& b)
195  {
196  int lenx=x.getLength();
197  int lenb=b.getLength();
198  if(lenx==_order && lenb==_order)
199  {
200  b.zero();
201  for(int i=1;i<_order+1;i++)
202  for(int j=1;j<_order+1;j++)
203  b[i-1]+=(*this)(i,j)*x[j-1];
204  }
205  else
206  throw CException("Array length does not match matrix order!");
207  }
208 
209  void testSolve()
210  {
211  TArray x(_order);
212  TArray b(_order);
213 
214  cout<<"Correct Solution:"<<endl;
215  for(int i=0;i<_order;i++)
216  {
217  x[i]=rand()/1.e9;
218  cout<<"x["<<i<<"]="<<x[i]<<endl;
219  }
220 
221  multiply(x,b);
222 
223  cout<<"Before"<<endl;
224  for(int i=0;i<_order;i++)
225  cout<<"b["<<i<<"]="<<b[i]<<endl;
226 
227  Solve(b);
228 
229  cout<<"After"<<endl;
230  for(int i=0;i<_order;i++)
231  cout<<"b["<<i<<"]="<<b[i]<<endl;
232  cout<<endl;
233 
234  cout<<"Error"<<endl;
235  for(int i=0;i<_order;i++)
236  cout<<"b["<<i<<"]="<<b[i]-x[i]<<endl;
237  cout<<endl;
238 
239 
240  throw CException("finished with test");
241 
242  }
243 
244  private:
245  const int _order;
246  const int _elements;
247  bool _sorted;
251 
252 
253 };
254 
255 
256 #endif
void multiply(const TArray &x, TArray &b)
Definition: SquareMatrix.h:194
virtual void zero()
Definition: Array.h:281
void makeCopy(SquareMatrix< T > &o)
Definition: SquareMatrix.h:164
void testSolve()
Definition: SquareMatrix.h:209
void printMatrix()
Definition: SquareMatrix.h:175
Array< int > IntArray
Definition: SquareMatrix.h:18
TArray _maxVals
Definition: SquareMatrix.h:249
NumTypeTraits< T >::T_Scalar T_Scalar
Definition: SquareMatrix.h:19
shared_ptr< TArray > TArrPtr
Definition: SquareMatrix.h:17
T & operator()(const int i, const int j)
Definition: SquareMatrix.h:31
const int _order
Definition: SquareMatrix.h:245
IntArray _pivotRows
Definition: SquareMatrix.h:248
T & getElement(const int i, const int j)
Definition: SquareMatrix.h:30
const int _elements
Definition: SquareMatrix.h:246
Array< T > TArray
Definition: SquareMatrix.h:16
TArray _values
Definition: SquareMatrix.h:250
Tangent fabs(const Tangent &a)
Definition: Tangent.h:312
Definition: Array.h:14
void Solve(TArray &bVec)
Definition: SquareMatrix.h:34
SquareMatrix(const int N)
Definition: SquareMatrix.h:21
int getLength() const
Definition: Array.h:87