Listing 5: A class for tridiagonal matrices
#ifndef TRIDIAG_H #define TRIDIAG_H #include "bandstor.h" template <class T> class tridiagonalMatrix : public bandStorage<T> { private: vector<T> c; vector<T> d; vector<T> e; public: tridiagonalMatrix() {} tridiagonalMatrix(const int k) : bandStorage<T>(k) {lowerBandWidth() = -1; upperBandWidth() = 1;} tridiagonalMatrix& operator=(const tridiagonalMatrix&); vector<T>& subDiagonal(void); vector<T>& superDiagonal(void); vector<T>& mainDiagonal(void); vector<T> solve(const vector<T>&); }; template <class T> tridiagonalMatrix<T>& tridiagonalMatrix<T>::operator=(const tridiagonalMatrix<T>& M) { return operator=(M); } template <class T> vector<T>& tridiagonalMatrix<T>::subDiagonal(void) { return bandStorage<T>::diag(-1); } template <class T> vector<T>& tridiagonalMatrix<T>::superDiagonal(void) { return bandStorage<T>::diag(1); } template <class T> vector<T>& tridiagonalMatrix<T>::mainDiagonal(void) { return bandStorage<T>::diag(0); } template <class T> vector<T> tridiagonalMatrix<T>::solve(const vector<T>& b) { vector<T> x; int n=order(); int info; int k, kb, kp1, nm1, nm2; T t; x = b; c = vector<T>(n, 0.0); e = vector<T>(n, 0.0); for (k=0; k<n-1; k++) { c[k+1] = subDiagonal()[k]; e[k] = superDiagonal()[k]; } d = mainDiagonal(); info = 0; c[0] = d[0]; nm1 = n-1; if (nm1 >= 1) { d[0] = e[0]; e[0] = 0.0; e[n-1] = 0.0; for (k=1; k<=nm1; k++) { kp1 = k+1; // find largest of two rows if (fabs(c[kp1-1]) > fabs(c[k-1])) { // interchange rows swap(c[kp1-1], c[k-1]); swap(d[kp1-1], d[k-1]); swap(e[kp1-1], e[k-1]); swap(x[kp1-1], x[k-1]); } if (c[k-1] == 0.0) throw ("zero diagonal encoutered in factorization"); t = -c[kp1-1]/c[k-1]; c[kp1-1] = d[kp1-1] + t*d[k-1]; d[kp1-1] = e[kp1-1] + t*e[k-1]; e[kp1-1] = 0.0; x[kp1-1] = x[kp1-1] + t*x[k-1]; } } if (c[n-1] == 0.0) throw ("zero diagonal encoutered in factorization"); // Back solve nm2 = n-2; x[n-1] = x[n-1]/c[n-1]; if (n > 1) { x[nm1-1] = (x[nm1-1] - d[nm1-1]*x[n-1])/c[nm1-1]; if (nm2 > 1) { for (kb = 1; kb <= nm2; kb++) { k = nm2 - kb + 1; x[k-1] = (x[k-1] - d[k-1]*x[k+1-1] - e[k-1]*x[k+2-1])/c[k-1]; } } } return x; } #endif //End of File