{-# LANGUAGE TypeSynonymInstances #-}
-- | A small simple matrix library.
module Algebra.Matrix
( Vector(Vec)
, unVec, lengthVec
, Matrix(M), matrix
, matrixToVector, vectorToMatrix, unMVec, unM, (!!!)
, identity, propLeftIdentity, propRightIdentity
, mulM, addM, transpose, isSquareMatrix, dimension
, scale, swap, pivot
, addRow, subRow, addCol, subCol
, findPivot, forwardElim, gaussElim, gaussElimCorrect
) where
import qualified Data.List as L
import Data.Function (on)
import Control.Monad (liftM)
import Control.Arrow hiding ((<+>))
import Test.QuickCheck
import Algebra.Structures.Field
import Debug.Trace
-------------------------------------------------------------------------------
-- | A row vector: a thin wrapper around the list of its entries.
newtype Vector r = Vec [r]
  deriving (Eq)
-- Vectors are displayed exactly like the underlying list of entries.
instance Show r => Show (Vector r) where
  show = show . unVec
-- Random vectors have between 1 and 10 entries, each produced by the entry
-- type's own 'arbitrary' generator.
instance Arbitrary r => Arbitrary (Vector r) where
  arbitrary = do
    len <- choose (1, 10) :: Gen Int
    liftM Vec (vectorOf len arbitrary)
-- Mapping over a vector applies the function to every entry.
instance Functor Vector where
  fmap f (Vec entries) = Vec (map f entries)
{-
instance Ring r => Ring (Vector r) where
(Vec xs) <+> (Vec ys) | length xs == length ys = Vec (zipWith (<+>) xs ys)
| otherwise = error "Bad dimensions in vector addition"
(Vec xs) <*> (Vec ys) | length xs == length ys = Vec (zipWith (<*>) xs ys)
| otherwise = error "Bad dimensions in vector multiplication"
-- In order to do these we need to know the length of the vector in advance...
-- Give me dependent types!
one = ?
zero = ?
-}
-- | The list of entries of a vector.
unVec :: Vector r -> [r]
unVec v = case v of
  Vec entries -> entries
-- | Number of entries in a vector.
lengthVec :: Vector r -> Int
lengthVec (Vec entries) = length entries
-------------------------------------------------------------------------------
-- | Matrices, stored as a list of row vectors (top to bottom).
--
-- The 'matrix' smart constructor checks that all rows have the same length;
-- the raw 'M' constructor performs no such check.
newtype Matrix r = M [Vector r]
  deriving (Eq)
-- A matrix is shown one row per line (each line ending in a newline); the
-- matrix without rows is shown as "[]".
instance Show r => Show (Matrix r) where
  show (M []) = "[]"
  show m      = unlines (map show (unMVec m))
-- Random matrices have between 1 and 10 rows and between 1 and 10 columns.
-- All rows share one length, so every generated value satisfies the
-- invariant checked by 'matrix'.
instance Arbitrary r => Arbitrary (Matrix r) where
  arbitrary = do
    cols <- choose (1, 10) :: Gen Int
    rows <- choose (1, 10) :: Gen Int
    liftM M (vectorOf rows (liftM Vec (vectorOf cols arbitrary)))
-- Mapping over a matrix applies the function to every entry.
instance Functor Matrix where
  fmap f (M rows) = M (map (fmap f) rows)
-- | Construct an m-by-n matrix from a list of rows.
--
-- Every row must have the same length as the first one, otherwise the call
-- fails with @"matrix: Bad dimensions"@. The empty list gives the empty
-- (0-by-0) matrix.
matrix :: [[r]] -> Matrix r
matrix [] = M []
matrix rows@(topRow : _)
  | all ((== width) . length) rows = M (map Vec rows)
  | otherwise                      = error "matrix: Bad dimensions"
  where
    width = length topRow
-- | The row vectors of a matrix, top to bottom.
unM :: Matrix r -> [Vector r]
unM m = case m of
  M rows -> rows
-- | The rows of a matrix as plain lists, top to bottom.
unMVec :: Matrix r -> [[r]]
unMVec (M rows) = map unVec rows
-- | View a vector as a matrix with a single row.
vectorToMatrix :: Vector r -> Matrix r
vectorToMatrix (Vec entries) = matrix [entries]
-- | Extract the only row of a one-row matrix; any other number of rows is
-- an error.
matrixToVector :: Matrix r -> Vector r
matrixToVector m = case unM m of
  [row] -> row
  _     -> error "matrixToVector: Bad dimension"
-- | Entry at (row, column), both 0-based. Out-of-range indices are an error.
(!!!) :: Matrix a -> (Int,Int) -> a
m !!! (r, c)
  | inRange r rows && inRange c cols = (unMVec m !! r) !! c
  | otherwise                        = error "!!!: Out of bounds"
  where
    (rows, cols) = dimension m
    inRange i n = 0 <= i && i < n
-- | Compute the dimension of a matrix as (rows, columns).
--
-- The column count is taken from the first row; the matrix without rows
-- has dimension (0,0).
dimension :: Matrix r -> (Int, Int)
dimension (M [])                = (0, 0)
dimension (M rows@(topRow : _)) = (length rows, lengthVec topRow)
-- | A matrix is square when every row has as many entries as there are
-- rows. The matrix without rows counts as square.
isSquareMatrix :: Matrix r -> Bool
isSquareMatrix (M rows) = all ((== n) . lengthVec) rows
  where
    n = length rows
-- | Transpose a matrix: row i of the result is column i of the input.
transpose :: Matrix r -> Matrix r
transpose = matrix . L.transpose . unMVec
-- | Matrix addition: entry-wise '<+>'. Both operands must have the same
-- dimension.
addM :: Ring r => Matrix r -> Matrix r -> Matrix r
addM a b
  | dimension a /= dimension b = error "Bad dimensions in matrix addition"
  | otherwise                  = matrix (zipWith (zipWith (<+>)) (unMVec a) (unMVec b))
-- | Matrix multiplication. The column count of the left operand must equal
-- the row count of the right operand; entry (i,j) of the result is the dot
-- product (folded with '<+>' from 'zero') of row i and column j.
mulM :: Ring r => Matrix r -> Matrix r -> Matrix r
mulM a b
  | snd (dimension a) /= fst (dimension b) = error "Bad dimensions in matrix multiplication"
  | otherwise = matrix [ [ dot row col | col <- cols ] | row <- unMVec a ]
  where
    cols = L.transpose (unMVec b)
    dot xs ys
      | length xs == length ys = foldr (<+>) zero (zipWith (<*>) xs ys)
      | otherwise              = error "mulVec: Bad dimension"
{-
-- To do this, the size of the matrix would need to be encoded in its type.
-- There is also the problem that matrices with mismatched dimensions cannot
-- be added or multiplied, so the generation of matrices would have to be smarter...
instance Ring r => Ring (Matrix r) where
(<+>) = add
(<*>) = mul
neg (Vec xs d) = Vec [ map neg x | x <- xs ] d
zero = undefined
-}
-- | Construct the n-by-n identity matrix: 'one' on the diagonal and 'zero'
-- everywhere else.
--
-- @identity 0@ is the matrix without rows. A negative size is rejected with
-- an error instead of building an infinite list of rows, which is what the
-- previous recursive version did.
identity :: IntegralDomain r => Int -> Matrix r
identity n
  | n < 0     = error "identity: negative dimension"
  | otherwise = matrix [ [ if i == j then one else zero | j <- idx ] | i <- idx ]
  where
    idx = [0 .. n - 1]
-- Specification of 'identity': multiplying on the left by the identity of
-- matching size leaves a matrix unchanged.
propLeftIdentity :: (IntegralDomain r, Eq r) => Matrix r -> Bool
propLeftIdentity a = a == (identity rows `mulM` a)
  where
    (rows, _) = dimension a
-- Multiplying on the right by the identity of matching size leaves a matrix
-- unchanged.
propRightIdentity :: (IntegralDomain r, Eq r) => Matrix r -> Bool
propRightIdentity a = a == (a `mulM` identity cols)
  where
    (_, cols) = dimension a
-------------------------------------------------------------------------------
-- Operations on matrices.
-- | Scale a row of a matrix: every entry of row @r@ (0-based) is multiplied
-- on the left by @s@. Fails with @"scale: Index out of bounds"@ unless
-- @0 <= r < rows@.
scale :: CommutativeRing a => Matrix a -> Int -> a -> Matrix a
scale m r s
  | r < 0 || r >= rows = error "scale: Index out of bounds"
  | otherwise          = matrix [ if i == r then map (s <*>) row else row
                                | (i, row) <- zip [0 ..] (unMVec m) ]
  where
    (rows, _) = dimension m
-- Scaling a row never changes the dimension of the matrix. The row index is
-- reduced modulo the row count so that it is always in range.
propScaleDimension :: (Arbitrary r, CommutativeRing r) => Matrix r -> Int -> r -> Bool
propScaleDimension m r s = dims == dimension (scale m (r `mod` rows) s)
  where
    dims@(rows, _) = dimension m
-- | Swap two rows of a matrix (0-based indices).
--
-- Both indices must satisfy @0 <= k < rows@; otherwise the call fails with
-- @"swap: Index out of bounds"@. (The previous check used @<= rows@, so an
-- index equal to the row count slipped through and crashed inside the
-- list manipulation instead.) Swapping a row with itself is a no-op.
swap :: Matrix a -> Int -> Int -> Matrix a
swap m i j
  | valid i && valid j = matrix (zipWith pick [0 ..] rows)
  | otherwise          = error "swap: Index out of bounds"
  where
    (r, _) = dimension m
    rows   = unMVec m
    valid k = 0 <= k && k < r
    -- Row k of the result: rows i and j trade places, all others stay put.
    pick k row
      | k == i    = rows !! j
      | k == j    = rows !! i
      | otherwise = row
-- Swapping two rows never changes the dimension of the matrix. The row
-- indices are reduced modulo the row count so that they are always in range.
propSwapDimension :: Matrix () -> Int -> Int -> Bool
propSwapDimension m i j = dims == dimension (swap m (i `mod` rows) (j `mod` rows))
  where
    dims@(rows, _) = dimension m
-- Swap is its own inverse: swapping the same two rows twice gives back the
-- original matrix.
propSwapIdentity :: Matrix () -> Int -> Int -> Bool
propSwapIdentity m i j = m == swap (swap m i' j') i' j'
  where
    rows = fst (dimension m)
    i'   = i `mod` rows
    j'   = j `mod` rows
-- | Add a row vector, entry-wise with '<+>', to row @x@ (0-based) of a matrix.
--
-- The vector must have exactly as many entries as the matrix has columns
-- (@"addRow: Bad length of row"@), and the row index must satisfy
-- @0 <= x < rows@ (@"addRow: Bad row number"@). The length is checked
-- first: previously that guard was only reached for an invalid row index,
-- so a too-long vector was silently truncated by 'zipWith'.
addRow :: CommutativeRing a => Matrix a -> Vector a -> Int -> Matrix a
addRow m (Vec xs) x
  | c /= length xs  = error "addRow: Bad length of row"
  | x < 0 || x >= r = error "addRow: Bad row number"
  | otherwise       = matrix [ if i == x then zipWith (<+>) row xs else row
                             | (i, row) <- zip [0 ..] (unMVec m) ]
  where
    (r, c) = dimension m
-- Adding a row vector of matching length never changes the dimension of
-- the matrix. The row index is reduced modulo the row count.
propAddRowDimension :: (CommutativeRing a, Arbitrary a)
                    => Matrix a -> Vector a -> Int -> Property
propAddRowDimension m row r =
  lengthVec row == cols ==> dims == dimension (addRow m row (r `mod` rows))
  where
    dims@(rows, cols) = dimension m
-- | Add a column vector to column @k@ (0-based) of a matrix. Implemented by
-- transposing, adding a row with 'addRow' and transposing back, so failures
-- are reported with 'addRow''s messages.
addCol :: CommutativeRing a => Matrix a -> Vector a -> Int -> Matrix a
addCol m col k = transpose (addRow (transpose m) col k)
-- | Subtract a vector from a row ('subRow') or a column ('subCol') by adding
-- its entry-wise negation.
subRow, subCol :: CommutativeRing a => Matrix a -> Vector a -> Int -> Matrix a
subRow m v k = addRow m (fmap neg v) k
subCol m v k = addCol m (fmap neg v) k
-- | @pivot m s p t@ multiplies row @p@ by @s@ (on the left) and adds the
-- result to row @t@. Both indices are 0-based.
pivot :: CommutativeRing a => Matrix a -> a -> Int -> Int -> Matrix a
pivot m s p t = addRow m scaledPivotRow t
  where
    scaledPivotRow = Vec (map (s <*>) (unMVec m !! p))
-- | Find the first nonzero entry strictly below row @r@ in column @c@.
--
-- Returns that entry together with its (0-based) row index, or 'Nothing'
-- when every entry below row @r@ in column @c@ is 'zero'.
findPivot :: (CommutativeRing a, Eq a) => Matrix a -> (Int,Int) -> Maybe (a,Int)
findPivot m (r, c) = L.find ((/= zero) . fst) (drop (r + 1) (zip column [0 ..]))
  where
    -- Column c, read directly from the rows instead of transposing the
    -- whole matrix just to take one of its rows.
    column = [ row !! c | row <- unMVec m ]
-- | Recursive forward elimination to row echelon form (not exported).
--
-- The first row with a nonzero leading entry becomes the pivot row. The
-- leading entries of all later nonzero rows are cancelled against it, the
-- pivot row is set aside, the first column is dropped, and the remainder is
-- processed recursively. A leading column that is entirely zero is skipped.
-- Pivots are not normalised to 'one'. The result has the same dimension as
-- the input: once the columns run out, the remaining (empty) rows are kept,
-- so that zero rows are not lost when the columns are re-prefixed.
--
-- NOTE(review): 'forwardElim' defines its own local @fE@ that shadows this
-- one, so nothing visible in this module calls it.
fE :: (Field a, Eq a) => Matrix a -> Matrix a
fE (M []) = M []
fE m@(M (Vec [] : _)) = m
fE m = case L.findIndices (/= zero) leading of
    (i : is) ->
      let pivotRow = rows !! i
          p        = leading !! i
          reduced  = cancelOut m [ (t, leading !! t) | t <- is ] (i, p)
      in matrix (pivotRow : map (zero :) (unMVec (fE reduced)))
    [] -> matrix (map (zero :) (unMVec (fE (matrix (map tail rows)))))
  where
    rows    = unMVec m
    leading = map head rows
    -- Cancel the leading entry @x@ of every target row @t@ by adding
    -- @neg (x / p)@ times the pivot row @i@ (whose leading entry is @p@);
    -- then remove the pivot row and drop the first column.
    cancelOut :: (Field a, Eq a) => Matrix a -> [(Int, a)] -> (Int, a) -> Matrix a
    cancelOut acc [] (i, _) =
      let accRows = unMVec acc
      in matrix (map tail (take i accRows ++ drop (i + 1) accRows))
    cancelOut acc ((t, x) : targets) (i, p) =
      cancelOut (pivot acc (neg (x <*> inv p)) i t) targets (i, p)
-- | Compute the row echelon form of a system Ax=b.
--
-- A and b are combined into the augmented matrix [A|b]. Working from the
-- top-left, each pivot is made 'one' (first swapping in a lower row if the
-- pivot is 'zero'), and every entry below it is eliminated. Only the columns
-- of A are used as pivot columns; b is transformed along with its row. The
-- result is split back into (A', b').
--
-- Errors: @"forwardElim: Bad dimensions"@ when b does not have exactly one
-- entry per row of A (previously such a b was silently truncated by 'zip');
-- @"forwardElim: Empty input matrix"@ when A has no rows.
forwardElim :: (Field a, Eq a) => (Matrix a,Vector a) -> (Matrix a,Vector a)
forwardElim (m,v)
  | lengthVec v /= mr = error "forwardElim: Bad dimensions"
  | otherwise         = fE m' (0,0)
  where
    -- fE takes the augmented matrix to eliminate and the current pivot
    -- position (row, column).
    fE :: (Field a, Eq a) => Matrix a -> (Int,Int) -> (Matrix a,Vector a)
    fE (M []) _ = error "forwardElim: Empty input matrix"
    fE aug rc@(r,c)
      -- The algorithm is done once it has passed the last row or the last
      -- column of A.
      | c == mc || r == mr =
          -- Decompose the augmented matrix into A and b again.
          (matrix *** Vec) $ unzip $ map (init &&& last) $ unMVec aug
      | aug !!! rc == zero = case findPivot aug rc of
          -- The pivot is zero: swap in the first row below it that has a
          -- nonzero entry in the pivot column.
          Just (_,r') -> fE (swap aug r r') rc
          -- Every entry at or below the pivot is zero: move one column right.
          Nothing     -> fE aug (r,c+1)
      | aug !!! rc /= one =
          -- Scale the pivot row so that the pivot becomes one.
          fE (scale aug r (inv (aug !!! rc))) rc
      | otherwise = case findPivot aug rc of
          -- Cancel the first nonzero entry x below the pivot by adding
          -- (neg x) times the pivot row to its row.
          Just (x,r') -> fE (pivot aug (neg x) r r') rc
          -- Every entry below the pivot is zero: move down and right.
          Nothing     -> fE aug (r+1,c+1)
    -- Dimension of A itself (not of the augmented matrix).
    (mr,mc) = dimension m
    -- Combine A and b into a matrix whose last column is b.
    m' = matrix [ row ++ [x] | (row,x) <- zip (unMVec m) (unVec v) ]
-- | Perform the "Jordan" step of Gauss-Jordan elimination: make every
-- element above the diagonal zero. In other words, compute the reduced row
-- echelon form of a system (A, b) given that it is already in row echelon
-- form.
--
-- Rows are processed bottom-up. With @xs@ the remaining (row, rhs) pairs and
-- @c@ the current column, every pair @(x, v)@ above the last one gets its
-- column-@c@ entry replaced by 'zero' and its right-hand side replaced by
-- @v <-> x !! c <*> snd (last xs)@. The procedure then recurses on those
-- earlier pairs with column @c - 1@ and appends the last pair unchanged.
--
-- NOTE(review): the first call uses column @rows - 1@, so the pivot of row
-- @k@ is assumed to be 'one' in column @k@, and only column @c@ of the
-- earlier rows is touched. This appears to match what 'forwardElim' produces
-- for a square, invertible A; for singular or non-square systems the result
-- is presumably not a reduced echelon form. Confirm before relying on it.
-- The right-hand-side update also assumes '<*>' binds tighter than '<->'.
jordan :: (Field a, Eq a) => (Matrix a, Vector a) -> (Matrix a, Vector a)
jordan (m, Vec ys) = case L.unzip (jordan' (zip (unMVec m) ys) (r-1)) of
    (a,b) -> (matrix a, Vec b)
  where
    -- Number of rows of A; back-substitution starts at column r-1.
    (r,_) = dimension m
    jordan' [] _ = []
    jordan' xs c =
      -- Clear column c in every row above the last one, recurse on those
      -- rows with column c-1, then append the untouched last row.
      jordan' [ (take c x ++ zero : drop (c+1) x, v <-> x !! c <*> snd (last xs))
              | (x,v) <- init xs ] (c-1) ++ [last xs]
-- | Gauss-Jordan elimination: given A and b, solve Ax=b.
--
-- Runs 'forwardElim' to reach row echelon form and then 'jordan' to clear
-- the entries above the diagonal. When A reduces to the identity, the
-- returned vector is the solution x.
gaussElim :: (Field a, Eq a, Show a) => (Matrix a, Vector a) -> (Matrix a, Vector a)
gaussElim system = jordan (forwardElim system)
-- | QuickCheck property: for a square A with one entry of b per row, the
-- vector x returned by 'gaussElim' satisfies A x = b.
--
-- NOTE(review): the precondition does not exclude singular A, for which
-- Ax=b presumably has no solution for most b; such inputs would be
-- reported as failures.
gaussElimCorrect :: (Field a, Eq a, Arbitrary a, Show a) => (Matrix a, Vector a) -> Property
gaussElimCorrect system@(a, b) =
  (rowsA == lengthVec b && isSquareMatrix a) ==> (ax == b)
  where
    rowsA = fst (dimension a)
    -- The computed solution as a column (n-by-1) matrix.
    xCol  = transpose (M [snd (gaussElim system)])
    -- A times x, turned back into a row vector for comparison with b.
    ax    = matrixToVector (transpose (a `mulM` xCol))