```{-# LANGUAGE TypeSynonymInstances #-}
-- | A small simple matrix library.
module Algebra.Matrix
( Vector(Vec)
, unVec, lengthVec
, Matrix(M), matrix
, matrixToVector, vectorToMatrix, unMVec, unM, (!!!)
, identity, propLeftIdentity, propRightIdentity
, mulM, addM, transpose, isSquareMatrix, dimension
, scale, swap, pivot
, findPivot, forwardElim, gaussElim, gaussElimCorrect
) where

import qualified Data.List as L
import Data.Function (on)
import Control.Arrow hiding ((<+>))
import Test.QuickCheck

import Algebra.Structures.Field

import Debug.Trace

-------------------------------------------------------------------------------
-- | Row vectors

-- A row vector: a thin wrapper around the list of its entries.
newtype Vector r = Vec [r]
  deriving (Eq)

-- A vector is shown exactly like the list of its entries.
instance Show r => Show (Vector r) where
  show = show . unVec

-- Generate vectors with between 1 and 10 entries.
instance Arbitrary r => Arbitrary (Vector r) where
  arbitrary = do
      len     <- choose (1, 10) :: Gen Int
      entries <- entriesOf len
      return (Vec entries)
    where
      -- Draw k arbitrary entries, one after the other.
      entriesOf 0 = return []
      entriesOf k = do
        e  <- arbitrary
        es <- entriesOf (k - 1)
        return (e : es)

instance Functor Vector where
  -- Apply the function to every entry.
  fmap f (Vec xs) = Vec (map f xs)

{-
instance Ring r => Ring (Vector r) where
(Vec xs) <+> (Vec ys) | length xs == length ys = Vec (zipWith (<+>) xs ys)
| otherwise = error "Bad dimensions in vector addition"
(Vec xs) <*> (Vec ys) | length xs == length ys = Vec (zipWith (<*>) xs ys)
| otherwise = error "Bad dimensions in vector multiplication"
-- In order to do these we need to know the length of the vector in advance...
-- Give me dependent types!
one  = ?
zero = ?
-}

-- | The entries of a vector, in order.
unVec :: Vector r -> [r]
unVec v = case v of
  Vec entries -> entries

-- | Number of entries in a vector.
lengthVec :: Vector r -> Int
lengthVec (Vec xs) = length xs

-------------------------------------------------------------------------------
-- | Matrices

-- A matrix, stored as a list of row vectors. The M constructor itself does
-- not check that all rows have the same length; the 'matrix' function does.
newtype Matrix r = M [Vector r]
  deriving (Eq)

instance Show r => Show (Matrix r) where
  -- One shown row per line, each line terminated by a newline; the empty
  -- matrix is shown as "[]".
  show m
    | null rows = "[]"
    | otherwise = unlines (map show rows)
    where
      rows = unMVec m

-- Generate matrices with 1-10 rows and 1-10 columns.
instance Arbitrary r => Arbitrary (Matrix r) where
  arbitrary = do
      cols <- choose (1, 10) :: Gen Int
      rows <- choose (1, 10) :: Gen Int
      vs   <- sequence (replicate rows (rowOf cols))
      return (M vs)
    where
      -- One row vector with k arbitrary entries.
      rowOf k = do
        es <- entriesOf k
        return (Vec es)

      -- Draw k arbitrary entries, one after the other.
      entriesOf 0 = return []
      entriesOf k = do
        e  <- arbitrary
        es <- entriesOf (k - 1)
        return (e : es)

instance Functor Matrix where
  -- Apply the function to every entry of every row.
  fmap f (M rows) = M [ fmap f row | row <- rows ]

-- | Construct an m x n matrix from a list of m rows, each of length n.
--
-- The empty list gives the empty matrix. Raises an error if the rows do not
-- all have the same length.
matrix :: [[r]] -> Matrix r
matrix [] = M []
matrix rows@(firstRow : _)
  | all ((== n) . length) rows = M (map Vec rows)
  | otherwise                  = error "matrix: Bad dimensions"
  where
    -- Every row must be as long as the first one.
    n = length firstRow

-- | The rows of a matrix, as vectors.
unM :: Matrix r -> [Vector r]
unM m = case m of
  M rows -> rows

-- | The rows of a matrix, as plain lists.
unMVec :: Matrix r -> [[r]]
unMVec (M rows) = [ xs | Vec xs <- rows ]

-- | View a vector as a matrix with a single row.
vectorToMatrix :: Vector r -> Matrix r
vectorToMatrix (Vec xs) = matrix [xs]

-- | Extract the only row of a one-row matrix; any other number of rows is
-- an error.
matrixToVector :: Matrix r -> Vector r
matrixToVector (M [row]) = row
matrixToVector _         = error "matrixToVector: Bad dimension"

-- | The entry at (row, column), both 0-based. Raises an error when either
-- index is out of range.
(!!!) :: Matrix a -> (Int,Int) -> a
m !!! (i, j)
  | not (within i rows && within j cols) = error "!!!: Out of bounds"
  | otherwise                            = (unMVec m !! i) !! j
  where
    (rows, cols)   = dimension m
    within k bound = 0 <= k && k < bound

-- | The dimension of a matrix as (rows, columns). The column count is read
-- from the first row; the empty matrix has dimension (0,0).
dimension :: Matrix r -> (Int, Int)
dimension (M [])                   = (0, 0)
dimension (M rows@(Vec first : _)) = (length rows, length first)

-- | True when every row has as many entries as there are rows. The empty
-- matrix counts as square.
isSquareMatrix :: Matrix r -> Bool
isSquareMatrix (M rows) = and [ lengthVec row == n | row <- rows ]
  where
    n = length rows

-- | Transpose a matrix: row i of the result is column i of the input.
transpose :: Matrix r -> Matrix r
transpose = matrix . L.transpose . unMVec

-- | Entrywise matrix addition. Both matrices must have the same dimension.
addM :: Ring r => Matrix r -> Matrix r -> Matrix r
addM a b
  | dimension a /= dimension b = error "Bad dimensions in matrix addition"
  | otherwise                  =
      matrix (zipWith (zipWith (<+>)) (unMVec a) (unMVec b))

-- | Matrix multiplication. An (m x k) matrix times a (k x n) matrix gives an
-- (m x n) matrix; entry (i,j) is the sum of products of row i of the first
-- matrix with column j of the second.
mulM :: Ring r => Matrix r -> Matrix r -> Matrix r
mulM a b
  | colsA /= rowsB = error "Bad dimensions in matrix multiplication"
  | otherwise      = matrix [ [ dot row col | col <- colsOfB ] | row <- unMVec a ]
  where
    (_, colsA) = dimension a
    (rowsB, _) = dimension b
    colsOfB    = L.transpose (unMVec b)

    -- Sum of the pairwise products of two equally long lists.
    dot xs ys
      | length xs /= length ys = error "mulVec: Bad dimension"
      | otherwise              = foldr (<+>) zero (zipWith (<*>) xs ys)

{-
-- In order to do this, the size of the matrix needs to be encoded in the type.
-- There is also the problem that matrices with mismatched dimensions cannot be
-- added or multiplied, so the generation of matrices would have to be smarter...
instance Ring r => Ring (Matrix r) where
(<*>) = mul
neg (Vec xs d) = Vec [ map neg x | x <- xs ] d
zero  = undefined
-}

-- | Construct the n x n identity matrix: 'one' on the diagonal, 'zero'
-- elsewhere. A size n <= 0 gives the empty matrix.
identity :: IntegralDomain r => Int -> Matrix r
identity n =
  matrix [ [ if i == j then one else zero | j <- indices ] | i <- indices ]
  where
    -- Bounded range, so a negative n yields no rows instead of looping.
    indices = [0 .. n - 1]

-- Specification of identity: multiplying on the left by the identity matrix
-- of matching size leaves a matrix unchanged.
propLeftIdentity :: (IntegralDomain r, Eq r) => Matrix r -> Bool
propLeftIdentity a = a == (identity rows `mulM` a)
  where
    (rows, _) = dimension a

-- Multiplying on the right by the identity matrix of matching size leaves a
-- matrix unchanged.
propRightIdentity :: (IntegralDomain r, Eq r) => Matrix r -> Bool
propRightIdentity a = a == (a `mulM` identity cols)
  where
    (_, cols) = dimension a

-------------------------------------------------------------------------------
-- Operations on matrices.

-- | Scale row r (0-based) of a matrix: every entry x of that row becomes
-- s <*> x. Raises an error if r is not a valid row index.
scale :: CommutativeRing a => Matrix a -> Int -> a -> Matrix a
scale m r s
  | r < 0 || r >= rows = error "scale: Index out of bounds"
  | otherwise          = matrix (zipWith scaleRow [0 ..] (unMVec m))
  where
    (rows, _) = dimension m
    scaleRow i row
      | i == r    = map (s <*>) row
      | otherwise = row

-- Scaling a row never changes the dimension of the matrix.
propScaleDimension :: (Arbitrary r, CommutativeRing r) => Matrix r -> Int -> r -> Bool
propScaleDimension m r s = before == dimension (scale m (r `mod` rows) s)
  where
    before@(rows, _) = dimension m

-- | Swap rows i and j (both 0-based) of a matrix. Raises an error unless both
-- indices are valid row indices; swapping a row with itself is a no-op.
swap :: Matrix a -> Int -> Int -> Matrix a
swap m i j
  -- Valid row indices are 0 .. r-1; index r itself is out of bounds.
  | valid i && valid j = matrix (zipWith pick [0 ..] m')
  | otherwise          = error "swap: Index out of bounds"
  where
    (r,_)   = dimension m
    m'      = unMVec m
    valid k = 0 <= k && k < r

    -- Row k of the result: rows i and j trade places, all others stay.
    pick k row
      | k == i    = m' !! j
      | k == j    = m' !! i
      | otherwise = row

-- Swapping two rows never changes the dimension of the matrix.
propSwapDimension :: Matrix () -> Int -> Int -> Bool
propSwapDimension m i j = before == dimension (swap m (i `mod` rows) (j `mod` rows))
  where
    before@(rows, _) = dimension m

-- Swapping the same two rows twice gives back the original matrix.
propSwapIdentity :: Matrix () -> Int -> Int -> Bool
propSwapIdentity m i j = m == twice
  where
    (rows, _) = dimension m
    i'        = i `mod` rows
    j'        = j `mod` rows
    twice     = swap (swap m i' j') i' j'

-- Add a vector entrywise to row x (0-based) of the matrix. The vector must
-- have exactly one entry per column of the matrix.
addRow :: CommutativeRing a => Matrix a -> Vector a -> Int -> Matrix a
addRow m (Vec xs) x
  -- Check the length first: zipWith below would otherwise silently truncate
  -- a vector that is too long.
  | c /= length xs  = error "addRow: Bad length of row"
  | 0 <= x && x < r = matrix (take x m' ++ zipWith (<+>) (m' !! x) xs : drop (x + 1) m')
  | otherwise       = error "addRow: Bad row number"
  where
    (r,c) = dimension m
    m'    = unMVec m

-- Adding a vector of the right length to a row never changes the dimension.
propAddRowDimension :: (CommutativeRing a, Arbitrary a)
                    => Matrix a -> Vector a -> Int -> Property
propAddRowDimension m row r =
  lengthVec row == cols ==> before == dimension (addRow m row (r `mod` rows))
  where
    before@(rows, cols) = dimension m

-- Add a vector entrywise to column x (0-based), by adding it to row x of the
-- transpose and transposing back.
addCol :: CommutativeRing a => Matrix a -> Vector a -> Int -> Matrix a
addCol m col x = transpose (addRow (transpose m) col x)

-- Subtract a vector entrywise from a row (subRow) or a column (subCol) by
-- adding its entrywise negation.
subRow, subCol :: CommutativeRing a => Matrix a -> Vector a -> Int -> Matrix a
subRow m v = addRow m (fmap neg v)
subCol m v = addCol m (fmap neg v)

-- | Add s times the pivot row p to the target row t (both 0-based); each
-- entry x of row p contributes s <*> x.
pivot :: CommutativeRing a => Matrix a -> a -> Int -> Int -> Matrix a
pivot m s p t = addRow m scaled t
  where
    scaled = fmap (s <*>) (unM m !! p)

-- | Find the first row strictly below row r whose entry in column c is
-- nonzero. Returns that entry together with its row index, or Nothing if
-- every entry below (r,c) in column c is zero.
findPivot :: (CommutativeRing a, Eq a) => Matrix a -> (Int,Int) -> Maybe (a,Int)
findPivot m (r,c) = L.find ((/= zero) . fst) (drop (r + 1) (zip column [0 ..]))
  where
    -- Column c of m, read off as row c of the transpose.
    column = head (drop c (unMVec (transpose m)))

-- Forward elimination on a bare matrix (no right-hand side).
--
-- Takes the first row whose leading entry is nonzero. Uses it to cancel the
-- leading entry of every later row whose leading entry is nonzero. Then
-- recurses on the remaining rows with the first column removed.
--
-- Pivot rows are kept as they are (they are not scaled to 'one'). Once no
-- columns remain, the leftover rows are discarded, so rows that never supply
-- a pivot are dropped from the result.
--
-- NOTE(review): fE is not exported, and 'forwardElim' defines its own local
-- fE, so this top-level definition appears unused in this module.
fE :: (Field a, Eq a) => Matrix a -> Matrix a
fE (M [])           = M []
fE (M (Vec [] : _)) = M []
fE m =
  case L.findIndices (/= zero) leading of
    [] ->
      -- The first column is all zero: reduce the rest and re-attach it.
      let rest = fE (matrix (map tail rows))
      in  matrix (map (zero :) (unMVec rest))
    (p : below) ->
      let rest = fE (eliminate m [ (t, leading !! t) | t <- below ] (p, leading !! p))
      in  matrix (rows !! p : map (zero :) (unMVec rest))
  where
    rows    = unMVec m
    leading = map head rows

    -- Use row p (leading entry v) to cancel the leading entry x of each listed
    -- row t, then drop row p and the first column from what remains.
    eliminate :: (Field a, Eq a) => Matrix a -> [(Int,a)] -> (Int,a) -> Matrix a
    eliminate acc [] (p, _) =
      let rs = unMVec acc
      in  matrix (map tail (L.delete (rs !! p) rs))
    eliminate acc ((t, x) : more) (p, v) =
      eliminate (pivot acc (neg (x </> v)) p t) more (p, v)

-- | Compute a row echelon form of the system Ax=b, given as the pair (A, b).
--
-- A and b are combined into the augmented matrix [A | b]. Pivots are
-- searched only among the columns of A, never in the b column. Each pivot is
-- scaled to 'one' and the entries below it are eliminated. The result is the
-- augmented matrix split back into (A', b').
--
-- Raises an error if b does not have exactly one entry per row of A, or if
-- A has no rows.
forwardElim :: (Field a, Eq a) => (Matrix a,Vector a) -> (Matrix a,Vector a)
forwardElim (m,v)
  -- zip below would silently truncate to the shorter of A's rows and b, so
  -- reject mismatched sizes up front.
  | lengthVec v /= mr = error "forwardElim: Bad dimensions"
  | otherwise         = fE m' (0,0)
  where
    -- fE takes the augmented matrix and the current (row, column) position.
    fE :: (Field a, Eq a) => Matrix a -> (Int,Int) -> (Matrix a,Vector a)
    fE (M []) _  = error "forwardElim: Empty input matrix"
    fE m rc@(r,c)
      -- Done once we move past the last column of A or the last row.
      | c == mc || r == mr =
          -- Split the augmented matrix back into A and b.
          (matrix *** Vec) (unzip (map (init &&& last) (unMVec m)))

      | m !!! rc == zero = case findPivot m rc of
          -- The pivot is zero: swap the pivot row with the first row below
          -- it that has a nonzero entry in the pivot column.
          Just (_,r') -> fE (swap m r r') rc
          -- Every entry below in the pivot column is zero: move right.
          Nothing     -> fE m (r,c+1)

      | m !!! rc /= one =
          -- Scale the pivot row so that the pivot becomes one.
          fE (scale m r (inv (m !!! rc))) rc

      | otherwise = case findPivot m rc of
          -- Eliminate the first nonzero entry below the pivot.
          Just (x,r') -> fE (pivot m (neg x) r r') (r,c)
          -- Nothing left to eliminate below the pivot: move down and right.
          Nothing     -> fE m (r+1,c+1)

    -- Dimensions of A itself (not of the augmented matrix).
    (mr,mc) = dimension m

    -- Append b to A as an extra last column.
    m' = matrix [ row ++ [x] | (row,x) <- zip (unMVec m) (unVec v) ]

-- | Perform the "Jordan" step of Gauss-Jordan elimination: make every entry
-- above the diagonal zero, updating b accordingly. In other words, compute
-- the reduced row echelon form of a system that is already in row echelon
-- form.
--
-- NOTE(review): column k is cleared using row k, for k = rows-1 down to 0.
-- This assumes every pivot lies on the diagonal and equals 'one' (a square,
-- invertible A after 'forwardElim'). Only column k and the b entry of row k
-- are used when updating the rows above, so other inputs (e.g. singular
-- systems with zero rows) are not reduced correctly.
jordan :: (Field a, Eq a) => (Matrix a, Vector a) -> (Matrix a, Vector a)
-- Pair each row of A with its entry of b, reduce, then split them apart again.
jordan (m, Vec ys) = case L.unzip (jordan' (zip (unMVec m) ys) (r-1)) of
    (a,b) -> (matrix a, Vec b)
  where
    -- Number of rows; the first column to clear is the last row's index.
    (r,_) = dimension m

    -- jordan' rows c: the last element of rows is the pivot row for column c.
    -- Every row above it gets column c set to zero, and its b entry v
    -- updated from the pivot row's b entry. Presumably this parses as
    -- v <-> ((x !! c) <*> b), i.e. it relies on <*> binding tighter than
    -- <-> -- confirm the fixities in Algebra.Structures. The step then
    -- recurses on those rows with column c-1.
    jordan' [] _ = []
    jordan' xs c =
      jordan' [ (take c x ++ zero : drop (c+1) x, v <-> x !! c <*> snd (last xs))
              | (x,v) <- init xs ] (c-1) ++ [last xs]

-- | Gauss-Jordan elimination: given A and b, attempt to solve Ax=b. First
-- 'forwardElim' computes a row echelon form, then 'jordan' clears the
-- entries above the diagonal.
gaussElim :: (Field a, Eq a, Show a) => (Matrix a, Vector a) -> (Matrix a, Vector a)
gaussElim system = jordan (forwardElim system)

-- | QuickCheck property: for a square system Ax=b whose matrix 'gaussElim'
-- reduces to the identity (i.e. A is invertible), the returned vector x
-- satisfies Ax = b.
--
-- Cases that do not reduce to the identity are discarded. For a singular A,
-- 'gaussElim' does not in general produce a solution (e.g. A = [[0]],
-- b = [1]), so checking Ax = b there would report spurious failures.
gaussElimCorrect :: (Field a, Eq a, Arbitrary a, Show a) => (Matrix a, Vector a) -> Property
gaussElimCorrect m@(a,b) =
  fst (dimension a) == lengthVec b && isSquareMatrix a && isUnit reduced ==>
    matrixToVector (transpose (a `mulM` transpose (M [x]))) == b
  where
    (reduced, x) = gaussElim m
    (n, _)       = dimension reduced

    -- True when the n x n matrix has 'one' on the diagonal and 'zero'
    -- everywhere else.
    isUnit red = and [ red !!! (i, j) == (if i == j then one else zero)
                     | i <- [0 .. n - 1], j <- [0 .. n - 1] ]
```