module Math.LinearRecursive.Internal.Matrix ( Matrix , matrix , diagonal , unMatrix , unDiagonal , toMatrix , unMatrix' , matrixSize , inverseMatrixDiag1 ) where import Data.List (transpose, foldl1') data Matrix a = Matrix { unMatrix :: [[a]] } | Diagonal { unDiagonal :: a } deriving Eq instance (Num a, Show a) => Show (Matrix a) where show (Matrix a) = "Matrix " ++ show a show (Diagonal a) = "Diagonal " ++ show a toMatrix :: Num a => Matrix a -> Matrix a toMatrix (Matrix a) = Matrix a toMatrix (Diagonal a) = Matrix [replicate i 0 ++ [aii] ++ repeat 0 | (i, aii) <- zip [0..] (repeat a)] unMatrix' :: Num a => Matrix a -> [[a]] unMatrix' = unMatrix . toMatrix matrixSize :: Num a => Matrix a -> Maybe Int matrixSize (Diagonal _) = Nothing matrixSize (Matrix a) = Just (length a) matrix :: [[a]] -> Matrix a matrix = Matrix diagonal :: a -> Matrix a diagonal = Diagonal instance Num a => Num (Matrix a) where Diagonal a + Diagonal b = Diagonal (a + b) Matrix a + Matrix b = Matrix (zipWith (zipWith (+)) a b) Matrix a + Diagonal b = Matrix [ [if (i :: Int) == j then aij + b else aij | (j, aij) <- zip [0..] ai] | (i, ai) <- zip [0..] a ] Diagonal b + Matrix a = Matrix [ [if (i :: Int) == j then b + aij else aij | (j, aij) <- zip [0..] ai] | (i, ai) <- zip [0..] a ] negate (Matrix a) = Matrix (map (map negate) a) negate (Diagonal a) = Diagonal (negate a) fromInteger = Diagonal . fromInteger Matrix a * Matrix b = let tb = transpose b c = [[foldl1' (+) (zipWith (*) ra cb) | cb <- tb] | ra <- a] in Matrix c Diagonal a * Diagonal b = Diagonal (a * b) Diagonal a * Matrix b = Matrix ((map.map) (a*) b) Matrix a * Diagonal b = Matrix ((map.map) (*b) a) abs = error "Matrix: abs undefined" signum = error "Matrix: abs undefined" gauss :: Num a => [[a]] -> [[a]] gauss ma = go [] ma where go xs [] = reverse xs go xs ys = go (row : map handle xs) (map handle (prefix ++ suffix)) where pivot = 0 (prefix, (_:row):suffix) = splitAt pivot ys handle (r:rs) = zipWith (\x y -> x - y * r) rs row handle [] = error "gauss: internal error" inverseMatrixDiag1 :: (Eq a, Num a) => Matrix a -> Matrix a inverseMatrixDiag1 (Diagonal 1) = Diagonal 1 inverseMatrixDiag1 (Diagonal (-1)) = Diagonal (-1) inverseMatrixDiag1 (Diagonal n) = error "inverseMatrixDet1: Diagonal"; inverseMatrixDiag1 (Matrix ma) = matrix (gauss ma') where n = length ma ma' = [ri ++ [if i == j then 1 else 0 | j <- [0..n-1]] | (i, ri) <- zip [0..] ma]