(* ::Package:: *)

(* :Title: Tridiagonal Matrix Routines *)

(* :Context: LinearAlgebra`Tridiagonal` *)

(* :Author: Jerry B.Keiper with Rob Knapp (Wolfram Research,Inc.) *)

(* :Summary:
	TridiagonalSolve solves A.x == r for x,
	where A is an n x n tridiagonal matrix and r is a vector of length n.
*)

(* :Copyright: Copyright 1992-2007, Wolfram Research,Inc. *)

(* :Package Version: 2.0 *)

(* :Mathematica Version: 4.1 *)

(* :History:
	V1.0, by Jerry Keiper, 1991.
	V2.0, by Rob Knapp, April 2000. Extends the functionality to better
		handle packed arrays.
*)

(* :Keywords: tridiagonal matrix, diagonals *)

(* :Discussion:
TridiagonalSolve solves A.x == r for x,
where A is an n x n tridiagonal matrix and r is a vector of length n.
A is assumed to have the form

b[[1]]  c[[1]]  0   ...                                    0
a[[1]]  b[[2]]  c[[2]]  0   ...                            0
0       a[[2]]  b[[3]]  c[[3]]  0   ...                    0
0       .       .       .       .   .    .        .        0
0       .       .       .       .   .    .        .        0
0       .       .       .       .   .    .        .        0
0       ...                              a[[n-2]] b[[n-1]] c[[n-1]] 
0       ...                              0        a[[n-1]] b[[n]]

To meet this, the following requirements are imposed:
Length[a] >= Length[r] - 1 &&
Length[b] >= Length[r] &&
Length[c] >= Length[r] - 1

Regular Gaussian elimination is employed, but without any pivoting.
*)

(* :Warnings: Infinity introduced if pivot becomes zero. *)

(* :Limitations: No pivoting done. *)

(* :Examples:

TridiagonalSolve[ Table[-1,{i,9}],Table[2,{i,10}],
    Table[-1,{i,9}],{1,0,0,0,0,0,0,0,0,0}]

(* identical to:
LinearSolve[Table[Switch[Abs[i-j], 0, 2, 1, -1, _, 0], {i, 10}, {j, 10}],
    {1,0,0,0,0,0,0,0,0,0}]
*)

TridiagonalSolve[{2,3,4,5}, {0,1,2,3,4}, {9,8,7,6}, {9,9,9,9,9}]

(* Because TridiagonalSolve does no pivoting, it fails on the above
   example, while the corresponding LinearSolve succeeds:
LinearSolve[ { { 0,9,0,0,0},{2,1,8,0,0},{0,3,2,7,0},{0,0,4,3,6},{0,0,0,5,4}},
     {9,9,9,9,9}]
*)

*)

(* Announce that this package is obsolete; issued before the package
   context is opened. *)
Message[General::obspkg, "LinearAlgebra`Tridiagonal`"]
(* Open the package context; TridiagonalSolve, first mentioned below,
   becomes the public symbol of this package. *)
BeginPackage["LinearAlgebra`Tridiagonal`"]

(* Allow (re)definition of TridiagonalSolve; it is protected again at the
   end of the package. *)
Unprotect[TridiagonalSolve];


(* Public usage string for TridiagonalSolve. *)
TridiagonalSolve::usage=
"TridiagonalSolve[ a, b, c, r ] solves A . x == r for x, \
where A is a tridiagonal matrix. The three diagonals of A are given \
by: a11 = b[[1]], a12 = c[[1]], a21 = a[[1]], a22 = b[[2]], \
a23 = c[[2]], ... No pivoting is done, so the algorithm may fail even \
though a solution exists. This cannot happen for a symmetric positive \
definite matrix.";

(* Error: an argument is not a usable (non-empty) vector. *)
TridiagonalSolve::tsv = "Argument `1` is not a non-empty vector";

(* Error from LengthCheck: the diagonal b is shorter than the system
   size `1` (the length of r). *)
TridiagonalSolve::dlen =
"The vector containing the diagonal elements should have \
length at least the system size of `1`.";

(* Error from LengthCheck: an off-diagonal (a or c) has fewer than
   `1` - 1 entries. *)
TridiagonalSolve::odlen =
"The vectors containing the off-diagonal elements should have length at \
least one less than the system size of `1`.";

(* Warning from LengthCheck: the inputs are longer than required; the
   surplus entries are not used by the solvers. *)
TridiagonalSolve::lenwarn =
"The diagonal and solution vectors should be the same length, while \
the off-diagonal vectors should be one element shorter. The extra \
elements are being ignored.";

(* Everything from here to End[] lives in the private context. *)
Begin["`Private`"]

(*
	Emit the "obsolete package function" message fun::obspkgfn for the
	function fun living in the package context `context`.  Returns Null.
*)
issueObsoleteFunMessage[fun_, context_] :=
	Message[MessageName[fun, "obspkgfn"], fun, context]

(*
	Main entry point.  The obsolescence message is always issued; the
	actual work is delegated to TridiagonalSolveCheck, whose validation
	failures Throw[$Failed].  The definition only applies when a list
	(the solution) comes back, so on failure the call stays unevaluated.
*)
TridiagonalSolve[a_List, b_List, c_List, r_List] :=
Module[{result},
	issueObsoleteFunMessage[TridiagonalSolve, "LinearAlgebra`Tridiagonal`"];
	result = Catch[TridiagonalSolveCheck[a, b, c, r]];
	result /; ListQ[result]]

(*
	Catch-all definition that never applies (its condition is always
	False).  Its only purpose is to issue TridiagonalSolve::argrx when the
	number of arguments is not 4; the call is then left unevaluated.
*)
TridiagonalSolve[args___] :=
	Null /; With[{count = Length[{args}]},
		count =!= 4 &&
			(Message[TridiagonalSolve::argrx, TridiagonalSolve, count, 4]; False)]

(*
	First check to see if the input can be represented as Real 
	or Complex packed arrays.
	If so, use specialized compiled versions of the solver.  
	Otherwise, use the general version.
*)
TridiagonalSolveCheck[a_List, b_List, c_List, r_List]:=
Module[{pack},
	(* An empty right-hand side leaves no system to solve: the solvers
	   would try to assign solution[[1]] of an empty list.  Report it with
	   ::tsv and abort to the enclosing Catch, like the other validation
	   failures (LengthCheck also uses Throw[$Failed]). *)
	If[r === {},
		Message[TridiagonalSolve::tsv, r];
		Throw[$Failed]];
	pack = SeeAboutPacking[{a,b,c,r}];
	If[ListQ[pack],
		(* Packed machine data: pick the compiled Complex solver if any of
		   the four vectors has a complex first element (a packed array has
		   a single element type), otherwise the compiled Real solver. *)
		Apply[
			If[MemberQ[Map[Head,pack[[All,1]]],Complex],
				ComplexTridiagonalSolve,
				RealTridiagonalSolve],
			pack],
	(* else: exact, symbolic or otherwise unpackable input *)
		LengthCheck[{a,b,c,r}];
		GeneralTridiagonalSolve[a,b,c,r]]]

(*
	Check that the lengths are compatible.  System size is based on the
	rhs or last vector input.
*)
(*
	vin is {a, b, c, r}.  Throws $Failed (after a message) when b is
	shorter than r or when a or c has fewer than Length[r] - 1 entries.
	Issues ::lenwarn when any vector is longer than strictly needed.
	Otherwise returns Null.
*)
LengthCheck[vin_] :=
Module[{na, nb, nc, n},
	{na, nb, nc, n} = Length /@ vin;
	(* the diagonal must cover every row *)
	If[nb < n,
		Message[TridiagonalSolve::dlen, n];
		Throw[$Failed]];
	(* both off-diagonals need at least n - 1 entries *)
	If[Min[na, nc] < n - 1,
		Message[TridiagonalSolve::odlen, n];
		Throw[$Failed]];
	(* surplus entries are tolerated, but reported *)
	If[Not[na === n - 1 && nc === n - 1 && nb === n],
		Message[TridiagonalSolve::lenwarn]]
]

(*
	The three following functions do the work.  They are all really the same 
	code, but with the first two cases compiled for Real and Complex types.
*)
(*
	Compiled Thomas-algorithm solver for packed machine-real vectors.
	w holds the eliminated superdiagonal, x the solution and piv the
	reciprocal of the current pivot.  Row k (k >= 2) has subdiagonal
	entry a[[k-1]].  No pivoting is performed.
*)
RealTridiagonalSolve = 
Compile[{{a, _Real, 1}, {b, _Real, 1}, {c, _Real, 1}, {r, _Real, 1}},
	Module[{
			n = Length[r],
			x = Table[0., {Length[r]}],
			w = Table[0., {Length[r]}],
			piv = 0.,
			k = 0},
		(* forward elimination *)
		piv = 1/b[[1]];
		x[[1]] = r[[1]] piv;
		Do[
			w[[k]] = c[[k - 1]] piv;
			piv = 1/(b[[k]] - a[[k - 1]] w[[k]]);
			x[[k]] = (r[[k]] - a[[k - 1]] x[[k - 1]]) piv,
			{k, 2, n}];
		(* back substitution *)
		Do[x[[k]] -= w[[k + 1]] x[[k + 1]], {k, n - 1, 1, -1}];
		x]];

(*
	Compiled Thomas-algorithm solver for packed machine-complex vectors.
	Identical in structure to RealTridiagonalSolve: w holds the eliminated
	superdiagonal, x the solution, piv the reciprocal of the current
	pivot, and row k (k >= 2) has subdiagonal entry a[[k-1]].
*)
ComplexTridiagonalSolve = 
Compile[{{a, _Complex, 1}, {b, _Complex, 1}, {c, _Complex, 1}, {r, _Complex, 1}},
	Module[{
			n = Length[r],
			x = Table[0. + 0. I, {Length[r]}],
			w = Table[0. + 0. I, {Length[r]}],
			piv = 0. + 0. I,
			k = 0},
		(* forward elimination *)
		piv = 1/b[[1]];
		x[[1]] = r[[1]] piv;
		Do[
			w[[k]] = c[[k - 1]] piv;
			piv = 1/(b[[k]] - a[[k - 1]] w[[k]]);
			x[[k]] = (r[[k]] - a[[k - 1]] x[[k - 1]]) piv,
			{k, 2, n}];
		(* back substitution *)
		Do[x[[k]] -= w[[k + 1]] x[[k + 1]], {k, n - 1, 1, -1}];
		x]];

(*
	Uncompiled solver used for exact, symbolic or otherwise unpackable
	input.  Forward elimination fills w (eliminated superdiagonal) and x
	(intermediate solution) while piv tracks the reciprocal pivot; back
	substitution then corrects x in place.  No pivoting is performed.
	Row k (k >= 2) has subdiagonal entry a[[k-1]].
*)
GeneralTridiagonalSolve[a_List, b_List, c_List, r_List] :=
Module[{n = Length[r], x = Table[0, {Length[r]}], w = Table[0, {Length[r]}], piv, k},
	piv = 1/b[[1]];
	x[[1]] = r[[1]] piv;
	For[k = 2, k <= n, k++,
		w[[k]] = c[[k - 1]] piv;
		piv = 1/(b[[k]] - a[[k - 1]] w[[k]]);
		x[[k]] = (r[[k]] - a[[k - 1]] x[[k - 1]]) piv];
	For[k = n - 1, k >= 1, k--,
		x[[k]] -= w[[k + 1]] x[[k + 1]]];
	x]

(*
	Check to see if the input arrays can be represented as PackedArrays 
	of reals or complexes.  If so, return a list of the four input
	arrays which have been packed, otherwise return False.
*)
SeeAboutPacking[input_] := 
Module[{pin},
	(* pin is localized so the packed copies of the inputs do not persist
	   in a global symbol after the call (the former Block localized only
	   unused names and leaked pin). *)
	(* Only attempt packing when some vector holds a machine number;
	   HasPackableMachineNumber answers via Throw, hence the Catch. *)
	If[Not[TrueQ[Catch[HasPackableMachineNumber[input]]]], Return[False]];
	(* Numericize and try to pack each of the four vectors separately.
	   NOTE(review): Union is not one of the element types (Integer, Real,
	   Complex) documented for Developer`ToPackedArray; confirm intent. *)
	pin = Map[Developer`ToPackedArray[N[#], Union]&, input];
	(* Use the packed data only if every vector ended up as a packed
	   rank-1 array; LengthCheck may Throw[$Failed] out of this function. *)
	If[TrueQ[Apply[
			And, 
			Map[
				And[Developer`PackedArrayQ[#],TensorRank[#] == 1]&, 
				pin]]],
		LengthCheck[pin];
		pin,
		False]]

(*
	HasPackableMachineNumber checks to see if the input contains a 
	machine number inside a list which may be packable.  This ignores
	machine numbers in things like {x[1.], x[2.]} since that could not 
	be represented as a packed array.  An enclosing Catch is needed
	since Throw is used for as quick an exit as possible.

	The order of definitions is of the UTMOST importance here.  
	If you were, for example, to put the MachineNumberQ test first, 
	MachineNumberQ would unpack a packed array, which would be a disaster 
	for the execution speed.
*)  

(*
	For PackedArrays, we need look at only a single number element
	since all numbers are of the same type.  The element at position
	{1, ..., 1} (one index per tensor level) is passed back in: a machine
	real/complex element reaches the MachineNumberQ rule and Throws True,
	while an integer element reaches the NumberQ rule and yields False.
*)
HasPackableMachineNumber[x_?Developer`PackedArrayQ] := HasPackableMachineNumber[Part[x, Apply[Sequence, Table[1, {TensorRank[x]}]]]]

(*
	For Lists, each element is examined in turn.  A machine number Throws
	True and an element that is neither a list nor a number Throws False;
	either way the Scan is abandoned.  If the Scan completes, no machine
	number was found, so this case returns False.
*)
HasPackableMachineNumber[x_List] := (Scan[HasPackableMachineNumber, x]; False)

(* A machine-precision number: packing is worthwhile, stop searching. *)
HasPackableMachineNumber[x_?MachineNumberQ] := Throw[True]
 
(* Any other number (exact or arbitrary precision) does not by itself
   justify packing; keep searching. *)
HasPackableMachineNumber[x_?NumberQ] := False
(*
	Anything else must be an unpackable element, so we jump out.
*)
HasPackableMachineNumber[x_] := Throw[False]


End[]  (*`Private`*)

(* Re-protect the public symbol that was unprotected at the start of the
   package. *)
Protect[TridiagonalSolve];

(* Close the package context opened by BeginPackage. *)
EndPackage[] (*LinearAlgebra`Tridiagonal`*)
