module Math.LinearRecursive.Internal.Matrix
( Matrix
, matrix
, diagonal
, unMatrix
, unDiagonal
, toMatrix
, unMatrix'
, matrixSize
, inverseMatrixDiag1
) where
import Data.List (transpose, foldl1')
-- | A matrix, kept either in dense row form or as a scalar placed on the
-- main diagonal of a matrix with no fixed size.
--
-- The derived 'Eq' compares representations, so a dense matrix never
-- compares equal to a 'Diagonal', even when they denote the same values.
data Matrix a
  = -- | Dense form: a list of rows, each row a list of entries.
    Matrix {unMatrix :: [[a]]}
  | -- | The given value on every diagonal entry and zero elsewhere; its
    -- size is left unspecified (see 'toMatrix' and 'matrixSize').
    Diagonal {unDiagonal :: a}
  deriving (Eq)
-- | Renders a matrix as its constructor name followed by its payload, e.g.
-- @Matrix [[1,2],[3,4]]@ or @Diagonal 5@.
--
-- The payload is shown at application precedence (11). The whole value is
-- parenthesised when it appears as an argument (precedence above 10). This
-- keeps the output a valid Haskell expression: negative scalars render as
-- @Diagonal (-1)@ rather than @Diagonal -1@, and nested values render as
-- @Just (Matrix [[1]])@.
instance Show a => Show (Matrix a) where
  showsPrec d (Matrix rows) =
    showParen (d > 10) $ showString "Matrix " . showsPrec 11 rows
  showsPrec d (Diagonal x) =
    showParen (d > 10) $ showString "Diagonal " . showsPrec 11 x
-- | Converts any 'Matrix' to its dense 'Matrix' form.
--
-- A dense matrix is returned as is. A 'Diagonal' has no fixed size, so it
-- expands to an infinite list of infinite rows. Row @k@ holds @k@ zeros,
-- then the diagonal value, then zeros forever. Consume the result lazily
-- (e.g. with 'take'), or evaluation will not terminate.
toMatrix :: Num a => Matrix a -> Matrix a
toMatrix m@(Matrix _) = m
toMatrix (Diagonal d) = Matrix (map rowAt [0 ..])
  where
    rowAt k = replicate k 0 ++ d : repeat 0
-- | The rows of a matrix after converting it with 'toMatrix'. For a
-- 'Diagonal' this is an infinite list of infinite rows.
unMatrix' :: Num a => Matrix a -> [[a]]
unMatrix' m = unMatrix (toMatrix m)
-- | Number of rows of a dense matrix, or 'Nothing' for a 'Diagonal', whose
-- size is unspecified.
matrixSize :: Num a => Matrix a -> Maybe Int
matrixSize m = case m of
  Matrix rows -> Just (length rows)
  Diagonal _ -> Nothing
-- | Builds a dense matrix from a list of rows. The shape is not checked.
matrix :: [[a]] -> Matrix a
matrix rows = Matrix rows
-- | Builds a matrix of unspecified size that has the given value on every
-- diagonal entry and zero elsewhere.
diagonal :: a -> Matrix a
diagonal x = Diagonal x
-- | Matrix arithmetic. A 'Diagonal' behaves like its value times an identity
-- matrix of whatever size the other operand has.
--
-- 'fromInteger' produces a 'Diagonal', so a literal such as @1@ acts as the
-- identity. Subtraction uses the class default (@x + negate y@). Dense-dense
-- operations use 'zipWith', so mismatched dimensions are silently truncated
-- rather than rejected. 'abs' and 'signum' are not meaningful for matrices
-- and raise an error.
instance Num a => Num (Matrix a) where
  Diagonal a + Diagonal b = Diagonal (a + b)
  Matrix a + Matrix b = Matrix (zipWith (zipWith (+)) a b)
  -- Adding a diagonal to a dense matrix only changes entries with i == j.
  -- Each clause keeps the operands in the order they were written.
  Matrix a + Diagonal b =
    Matrix
      [ [if (i :: Int) == j then aij + b else aij | (j, aij) <- zip [0 ..] ai]
      | (i, ai) <- zip [0 ..] a
      ]
  Diagonal b + Matrix a =
    Matrix
      [ [if (i :: Int) == j then b + aij else aij | (j, aij) <- zip [0 ..] ai]
      | (i, ai) <- zip [0 ..] a
      ]

  negate (Matrix a) = Matrix (map (map negate) a)
  negate (Diagonal a) = Diagonal (negate a)

  fromInteger = Diagonal . fromInteger

  -- Dense product: each left row is dotted with each right column (the rows
  -- of the transposed right operand). foldl1' adds no extra zero term and
  -- keeps the accumulator strict. It fails on an empty dot product, which
  -- only arises from mismatched dimensions.
  Matrix a * Matrix b =
    let cols = transpose b
     in Matrix [[foldl1' (+) (zipWith (*) row col) | col <- cols] | row <- a]
  Diagonal a * Diagonal b = Diagonal (a * b)
  -- Scalar on the left or the right is applied on that same side, which
  -- matters when the element type's multiplication is not commutative.
  Diagonal a * Matrix b = Matrix ((map . map) (a *) b)
  Matrix a * Diagonal b = Matrix ((map . map) (* b) a)

  abs = error "Matrix: abs undefined"
  signum = error "Matrix: signum undefined"
-- | Gauss-Jordan elimination on an augmented matrix, without division.
--
-- Rows are consumed top to bottom. The leading entry of the current row is
-- taken as the pivot. That column is then eliminated from every other row,
-- both the rows already processed and the rows still to come.
--
-- Because nothing is divided, the pivot is assumed to be 1. The result is
-- only meaningful when every pivot met during elimination equals 1.
--
-- Every step drops one leading column from all rows, so an input of @n@ rows
-- loses its first @n@ columns. For an augmented @[A | I]@ of size
-- @n x 2n@, the result is the @n x n@ right-hand block, which is @A^-1@
-- under the unit-pivot assumption.
--
-- >>> gauss [[1, 2, 1, 0], [0, 1, 0, 1]]
-- [[1,-2],[0,1]]
gauss :: Num a => [[a]] -> [[a]]
gauss = go []
  where
    -- done: reduced pivot rows, most recent first; the rest are unprocessed.
    go done [] = reverse done
    go done ((_pivot : row) : rest) =
      go (row : map (eliminate row) done) (map (eliminate row) rest)
    go _ ([] : _) = error "gauss: internal error"

    -- Cancel the leading entry r of a row by subtracting r times the pivot
    -- row (whose own leading 1 was already dropped), then drop that column.
    eliminate row (r : rs) = zipWith (\x y -> x - y * r) rs row
    eliminate _ [] = error "gauss: internal error"
-- | Inverts a matrix by running 'gauss' on the augmented matrix @[A | I]@.
--
-- 'gauss' never divides, so for a dense matrix the result is only correct
-- when every pivot met during elimination is 1 (e.g. a unit triangular
-- matrix). Other inputs silently give a wrong answer.
-- NOTE(review): the input is assumed to be square; this is not checked.
--
-- For a 'Diagonal', only the self-inverse scalars @1@ and @-1@ are handled.
-- Any other scalar raises an error.
inverseMatrixDiag1 :: (Eq a, Num a) => Matrix a -> Matrix a
inverseMatrixDiag1 (Diagonal 1) = Diagonal 1
inverseMatrixDiag1 (Diagonal (-1)) = Diagonal (-1)
inverseMatrixDiag1 (Diagonal _) = error "inverseMatrixDiag1: Diagonal"
inverseMatrixDiag1 (Matrix ma) = matrix (gauss augmented)
  where
    n = length ma
    -- Append row i of the n x n identity to row i of the input.
    augmented =
      [ row ++ [if i == j then 1 else 0 | j <- [0 .. n - 1]]
      | (i, row) <- zip [0 ..] ma
      ]