{-# LANGUAGE Haskell2010
    , TypeFamilies
    , FlexibleContexts
    , Trustworthy
 #-}
--{-# OPTIONS -Wall -fno-warn-name-shadowing #-}

module Numeric.Matrix (
    Matrix,
    MatrixElement (
        matrix,
        fromList,

        unit,
        zero,
        diag,
        empty,

        at,
        row,
        col,

        select,
        toList,

        dimensions,
        numRows,
        numCols,

        isUnit,
        isZero,
        isDiagonal,
        isEmpty,
        isSquare,
        
        det,
        rank,
        transpose,
        trace,

        minus,
        plus,
        times,
        inv,

        map,
        all,
        any,

        mapWithIndex,
        allWithIndex,
        anyWithIndex
    )
) where


import Control.Applicative
import Control.Monad
import Control.Monad.ST

import Data.Function
import Data.Ratio
import Data.Complex
import Data.Maybe

import qualified Data.List as L
import Data.Array.IArray
import Data.Array.MArray
import Data.Array.Unboxed
import Data.Array.ST
import Data.STRef

import Prelude hiding (any, all, read)
import qualified Prelude as P

import qualified Debug.Trace as D

-- | A dense matrix, stored as a 1-based array of rows where each row is
-- itself a 1-based array.  In every constructor the two 'Int' fields are
-- the number of rows and the number of columns, in that order.
--
-- 'Int', 'Float' and 'Double' rows are unboxed ('UArray'); 'Integer',
-- 'Ratio' and 'Complex' rows are boxed ('Array').
data family Matrix e

-- Unboxed rows.
data instance Matrix Int
    = IntMatrix Int Int (Array Int (UArray Int Int))
data instance Matrix Float
    = FloatMatrix Int Int (Array Int (UArray Int Float))
data instance Matrix Double
    = DoubleMatrix Int Int (Array Int (UArray Int Double))

-- Boxed rows.
data instance Matrix Integer
    = IntegerMatrix Int Int (Array Int (Array Int Integer))
data instance Matrix (Ratio a)
    = RatioMatrix Int Int (Array Int (Array Int (Ratio a)))
data instance Matrix (Complex a)
    = ComplexMatrix Int Int (Array Int (Array Int (Complex a)))

-- | One line per row.  Each entry is rendered with 'show' and prefixed by a
-- space, and neighbouring entries are separated by a further space.
instance (MatrixElement e, Show e) => Show (Matrix e) where
    show m = unlines [ unwords [ ' ' : show x | x <- r ] | r <- toList m ]

-- | Entry-wise '+' and '-', matrix product for '*'.
--
-- 'abs' and 'signum' are applied to the determinant and returned as a 1x1
-- matrix; 'fromInteger' likewise yields a 1x1 matrix.
instance (MatrixElement e) => Num (Matrix e) where
    (+) = plus
    (-) = minus
    (*) = times
    -- Without this definition 'negate' defaults to @0 - m@, where @0@ is the
    -- 1x1 matrix built by 'fromInteger', so negating any matrix that is not
    -- 1x1 fails the dimension check in 'minus'.
    negate      = mapWithIndex (const negate)
    abs         = matrix (1,1) . const . abs . det
    signum      = matrix (1,1) . const . signum . det
    fromInteger = matrix (1,1) . const . fromInteger
    
-- | 'recip' is the matrix inverse; 'fromRational' yields a 1x1 matrix.
instance (MatrixElement e, Fractional e) => Fractional (Matrix e) where
    -- 'inv' reports a missing inverse with 'Nothing'; turn that into an
    -- error that names the cause instead of the generic 'fromJust' failure.
    recip m      = fromMaybe (error "Matrix.recip: matrix is not invertible") (inv m)
    fromRational = matrix (1,1) . const . fromRational

-- | Matrices are equal when their dimensions agree and every entry of the
-- second equals the entry at the same index of the first.
instance (MatrixElement e) => Eq (Matrix e) where
    m == n = sameShape && sameEntries
      where
        sameShape   = dimensions m == dimensions n
        sameEntries = allWithIndex (\ix x -> m `at` ix == x) n


-- | Division suited to the element type, used by '_det': truncating
-- division ('quot') for integral types and '/' for fractional ones.
class Division e where
    divide :: e -> e -> e

-- Integral element types: truncate toward zero.
instance Division Int     where divide x y = x `quot` y
instance Division Integer where divide x y = x `quot` y
-- The fixed-width types from Data.Int (Int8 .. Int64) could be given the
-- same 'quot'-based instance once that module is imported.

-- Fractional element types: ordinary division.
instance Division Float  where divide x y = x / y
instance Division Double where divide x y = x / y
instance Integral a => Division (Ratio a) where divide x y = x / y
instance RealFloat a => Division (Complex a) where divide x y = x / y



-- | Element types that can be stored in a 'Matrix', together with the
-- operations on matrices of that element type.
--
-- Indices are 1-based @(row, column)@ pairs throughout.  Instances must
-- define 'matrix', 'fromList', 'toList', 'det' and 'rank'; every other
-- method has a list-based default built from those, which instances may
-- override with array-based versions.
class (Eq e, Num e) => MatrixElement e where

    -- | @matrix (r, c) f@ is the @r@ x @c@ matrix whose entry at @(i, j)@
    -- is @f (i, j)@.
    matrix :: (Int, Int) -> ((Int, Int) -> e) -> Matrix e
    -- | Entries whose index satisfies the predicate, in row-major order.
    select :: ((Int, Int) -> Bool) -> Matrix e -> [e]
    -- | The entry at a 1-based @(row, column)@ index.
    at :: Matrix e -> (Int, Int) -> e

    -- | Row @i@ (1-based), left to right.
    row :: Int -> Matrix e -> [e]
    -- | Column @j@ (1-based), top to bottom.
    col :: Int -> Matrix e -> [e]

    -- | @(number of rows, number of columns)@.
    dimensions :: Matrix e -> (Int, Int)
    numRows :: Matrix e -> Int
    numCols :: Matrix e -> Int

    -- | Build a matrix from a list of rows.
    fromList :: [[e]] -> Matrix e
    -- | The rows of a matrix, top to bottom.
    toList   :: Matrix e -> [[e]]

    -- | The @n@ x @n@ identity matrix.
    unit  :: Int -> Matrix e
    -- | The @n@ x @n@ zero matrix.
    zero  :: Int -> Matrix e
    -- | A square matrix with the given entries on its main diagonal.
    diag  :: [e] -> Matrix e
    -- | The matrix with no rows.
    empty :: Matrix e

    -- | Entry-wise difference; calls 'error' if the dimensions differ.
    minus :: Matrix e -> Matrix e -> Matrix e
    -- | Entry-wise sum; calls 'error' if the dimensions differ.
    plus  :: Matrix e -> Matrix e -> Matrix e
    -- | Matrix product; calls 'error' unless @numCols a == numRows b@.
    times :: Matrix e -> Matrix e -> Matrix e
    -- | The inverse, if available.  The default is always 'Nothing'.
    inv   :: Matrix e -> Maybe (Matrix e)

--    adjugate  :: Matrix e -> Matrix e
--    cofactors :: Matrix e -> Matrix e ; cofactors = undefined
    -- | The determinant.
    det       :: Matrix e -> e
    -- | Exchange rows and columns.
    transpose :: Matrix e -> Matrix e
    -- | The rank, returned in the element type.
    rank      :: Matrix e -> e
    -- | The main-diagonal entries as a list (not their sum).
    trace     :: Matrix e -> [e]

    isUnit     :: Matrix e -> Bool
    isDiagonal :: Matrix e -> Bool
    isZero     :: Matrix e -> Bool
    isEmpty    :: Matrix e -> Bool
    isSquare   :: Matrix e -> Bool

    -- | Apply a function to every entry.
    map :: MatrixElement f => (e -> f) -> Matrix e -> Matrix f
    all :: (e -> Bool) -> Matrix e -> Bool
    any :: (e -> Bool) -> Matrix e -> Bool

    -- | Index-aware variants: the function also receives the entry's
    -- 1-based @(row, column)@ index.
    mapWithIndex :: MatrixElement f => ((Int, Int) -> e -> f) -> Matrix e -> Matrix f
    allWithIndex :: ((Int, Int) -> e -> Bool) -> Matrix e -> Bool
    anyWithIndex :: ((Int, Int) -> e -> Bool) -> Matrix e -> Bool

    select p m = [ at m (i,j) | i <- [1..numRows m]
                              , j <- [1..numCols m]
                              , p (i,j) ]

    -- List-based fallback.  Indices are 1-based like everywhere else in
    -- this class, while '!!' counts from 0.
    at m (i, j) = toList m !! (i - 1) !! (j - 1)

    unit n  = fromList [[ if i == j then 1 else 0 | j <- [1..n]] | i <- [1..n] ]
    zero n  = matrix (n,n) (const 0)
    empty   = fromList []
    diag xs = matrix (n,n) (\(i,j) -> if i == j then xs !! (i-1) else 0)
      where n = length xs

    row i m = ((!! (i-1)) . toList) m
    col i m = (row i . transpose) m

    numRows = fst . dimensions
    numCols = snd . dimensions
    dimensions m = case toList m of
        []     -> (0, 0)
        (x:xs) -> (length xs + 1, length x)

--    adjugate = transpose . cofactors
    transpose = fromList . L.transpose . toList
    trace = select (uncurry (==))
    inv _ = Nothing

    isZero = all (== 0)
    -- Ones on the diagonal are not enough: the off-diagonal entries must
    -- also be zero.
    isUnit m = isSquare m && isDiagonal m && P.all (== 1) (trace m)
    isEmpty m = numRows m == 0 || numCols m == 0
    isDiagonal = allWithIndex (uncurry $ \x y z -> if x /= y then z == 0 else True)
    isSquare m = let (a, b) = dimensions m in a == b

    map f = mapWithIndex (const f)
    all f = allWithIndex (const f)
    any f = anyWithIndex (const f)

    mapWithIndex f m = matrix (dimensions m) (\x -> f x (m `at` x))
    allWithIndex f m = P.all id [ f (i, j) (m `at` (i,j))
                                | i <- [1..numRows m], j <- [1..numCols m]]
    anyWithIndex f m = P.any id [ f (i, j) (m `at` (i,j))
                                | i <- [1..numRows m], j <- [1..numCols m]]

    a `plus` b
        | dimensions a /= dimensions b = error "Matrix.plus: dimensions don't match."
        | otherwise = matrix (dimensions a) (\x -> a `at` x + b `at` x)
    a `minus` b
        | dimensions a /= dimensions b = error "Matrix.minus: dimensions don't match."
        | otherwise = matrix (dimensions a) (\x -> a `at` x - b `at` x)
    -- An (r x k) matrix times a (k x c) matrix: the inner dimensions,
    -- numCols a and numRows b, are the ones that must agree.
    a `times` b
        | numCols a /= numRows b = error "Matrix.times: `numCols a' and `numRows b' don't match."
        | otherwise = fromList [ [ row i a `dotProd` col j b
                                 | j <- [1..numCols b] ]
                               | i <- [1..numRows a] ]


-- | 'Int' matrices with unboxed rows.  'inv' and 'rank' go through
-- 'Rational' so they can reuse the fractional elimination routines.
instance MatrixElement Int where
    matrix   = _matrix IntMatrix
    fromList = _fromList IntMatrix

    at         (IntMatrix _ _ arr) = _at arr
    dimensions (IntMatrix m n _) = (m, n)
    row i      (IntMatrix _ _ arr) = _row i arr
    col j      (IntMatrix _ _ arr) = _col j arr
    toList     (IntMatrix _ _ arr) = _toList arr
    -- An integer matrix has an integer inverse exactly when its determinant
    -- is 1 or -1.  In that case the inverse computed over the rationals has
    -- only integral entries, so taking numerators is exact.
    inv mat@(IntMatrix m n arr)
        | m /= n || abs (det mat) /= 1 = Nothing
        | otherwise = Just $ IntMatrix m n $ fmap fromRationals
                           $ runST (_inv boxedST (fmap toRationals arr))
      where
        toRationals :: UArray Int Int -> Array Int Rational
        toRationals r = listArray (bounds r) (P.map toRational (elems r))
        fromRationals :: Array Int Rational -> UArray Int Int
        fromRationals r = listArray (bounds r) (P.map (fromInteger . numerator) (elems r))
    det        (IntMatrix m n arr) = if m /= n then 0 else runST (_det thawsUnboxed arr)
    -- The rank of an integer matrix equals its rank over the rationals.
    rank (IntMatrix _ _ arr) = truncate (runST (_rank thawsBoxed (fmap toRationals arr)))
      where
        toRationals :: UArray Int Int -> Array Int Rational
        toRationals r = listArray (bounds r) (P.map toRational (elems r))

-- | 'Integer' matrices with boxed rows.  'inv' and 'rank' go through
-- 'Rational' so they can reuse the fractional elimination routines.
instance MatrixElement Integer where
    matrix   = _matrix IntegerMatrix
    fromList = _fromList IntegerMatrix

    at         (IntegerMatrix _ _ arr) = _at arr
    dimensions (IntegerMatrix m n _) = (m, n)
    row i      (IntegerMatrix _ _ arr) = _row i arr
    col j      (IntegerMatrix _ _ arr) = _col j arr
    toList     (IntegerMatrix _ _ arr) = _toList arr
    -- An integer matrix has an integer inverse exactly when its determinant
    -- is 1 or -1.  In that case the inverse computed over the rationals has
    -- only integral entries, so taking numerators is exact.
    inv mat@(IntegerMatrix m n arr)
        | m /= n || abs (det mat) /= 1 = Nothing
        | otherwise = Just $ IntegerMatrix m n $ fmap (fmap numerator)
                           $ runST (_inv boxedST (fmap (fmap toRational) arr))
    det        (IntegerMatrix m n arr) = if m /= n then 0 else runST (_det thawsBoxed arr)
    -- The rank of an integer matrix equals its rank over the rationals.
    rank       (IntegerMatrix _ _ arr) =
        truncate (runST (_rank thawsBoxed (fmap (fmap toRational) arr)))

-- | 'Float' matrices with unboxed rows; all mutable work uses unboxed rows.
instance MatrixElement Float where
    matrix   = _matrix FloatMatrix
    fromList = _fromList FloatMatrix

    at         (FloatMatrix _ _ arr) = _at arr
    dimensions (FloatMatrix m n _  ) = (m, n)
    row i      (FloatMatrix _ _ arr) = _row i arr
    col j      (FloatMatrix _ _ arr) = _col j arr
    toList     (FloatMatrix _ _ arr) = _toList arr
    -- NOTE: if elimination meets an exactly-zero pivot column, '_inv'
    -- calls 'fail', so forcing the 'Just' payload raises an error rather
    -- than this returning 'Nothing'.
    inv        (FloatMatrix m n arr) = if m /= n then Nothing else
                                        Just $ FloatMatrix m n $ runST (_inv unboxedST arr)
    det        (FloatMatrix m n arr) = if m /= n then 0 else runST (_det thawsUnboxed arr)
    -- Thaw into unboxed rows, matching 'det' and 'inv'.
    rank       (FloatMatrix _ _ arr) = runST (_rank thawsUnboxed arr)

-- | 'Double' matrices with unboxed rows; all mutable work uses unboxed rows.
instance MatrixElement Double where
    matrix   = _matrix DoubleMatrix
    fromList = _fromList DoubleMatrix

    at         (DoubleMatrix _ _ arr) = _at arr
    dimensions (DoubleMatrix m n _  ) = (m, n)
    row i      (DoubleMatrix _ _ arr) = _row i arr
    col j      (DoubleMatrix _ _ arr) = _col j arr
    toList     (DoubleMatrix _ _ arr) = _toList arr
    -- NOTE: if elimination meets an exactly-zero pivot column, '_inv'
    -- calls 'fail', so forcing the 'Just' payload raises an error rather
    -- than this returning 'Nothing'.
    inv        (DoubleMatrix m n arr) = if m /= n then Nothing else
                                         Just $ DoubleMatrix m n $ runST (_inv unboxedST arr)
    det        (DoubleMatrix m n arr) = if m /= n then 0 else runST (_det thawsUnboxed arr)
    -- Thaw into unboxed rows, matching 'det' and 'inv'.
    rank       (DoubleMatrix _ _ arr) = runST (_rank thawsUnboxed arr)

-- | Exact rational matrices with boxed rows.
instance (Show a, Integral a) => MatrixElement (Ratio a) where
    matrix   = _matrix RatioMatrix
    fromList = _fromList RatioMatrix

    at         (RatioMatrix _ _ arr) = _at arr
    dimensions (RatioMatrix m n _  ) = (m, n)
    row i      (RatioMatrix _ _ arr) = _row i arr
    col j      (RatioMatrix _ _ arr) = _col j arr
    toList     (RatioMatrix _ _ arr) = _toList arr
    -- 'Ratio' arithmetic is exact, so a zero determinant is precisely the
    -- singular case.  Checking it first means a singular matrix yields
    -- 'Nothing' instead of '_inv' failing on a zero pivot.
    inv mat@(RatioMatrix m n arr)
        | m /= n || det mat == 0 = Nothing
        | otherwise              = Just $ RatioMatrix m n $ runST (_inv boxedST arr)
    det        (RatioMatrix m n arr) = if m /= n then 0 else  runST (_det thawsBoxed arr)
    rank       (RatioMatrix _ _ arr) = runST (_rank thawsBoxed arr)

-- | Complex matrices with boxed rows.  'inv' is not supported and always
-- returns 'Nothing': '_inv' requires an 'Ord' instance for its pivot
-- choice, and 'Complex' has none.
instance (Show a, RealFloat a) => MatrixElement (Complex a) where
    matrix   = _matrix ComplexMatrix
    fromList = _fromList ComplexMatrix

    dimensions (ComplexMatrix r c _   ) = (r, c)
    toList     (ComplexMatrix _ _ rows) = _toList rows
    at         (ComplexMatrix _ _ rows) = _at rows
    row i      (ComplexMatrix _ _ rows) = _row i rows
    col j      (ComplexMatrix _ _ rows) = _col j rows

    inv (ComplexMatrix _ _ _) = Nothing
    det (ComplexMatrix r c rows)
        | r /= c    = 0
        | otherwise = runST (_det thawsBoxed rows)
    rank (ComplexMatrix _ _ rows) = runST (_rank thawsBoxed rows)


-- | Entry @(i, j)@ of an array of rows: row @i@ of the outer array, then
-- position @j@ within that row.
_at :: (IArray a (u Int e), IArray u e)
    => a Int (u Int e) -> (Int, Int) -> e
_at rows (i, j) = let r = rows ! i in r ! j

-- | Row @i@ / column @j@ of an array of rows, in index order.
_row, _col :: (IArray a (u Int e), IArray u e) => Int -> a Int (u Int e) -> [e]
-- Take the row's own entries.  The outer bounds give the row count, which
-- differs from the row length whenever the matrix is not square.
_row i arr = elems (arr ! i)
-- Walk every row (the outer bounds) and pick entry @j@ from each.
_col j arr = [ arr ! i ! j | i <- [1..(snd (bounds arr))] ]

-- | Build a matrix of the given dimensions from an index function.
-- Rows and columns are indexed from 1, and @c@ receives the row count,
-- the column count and the array of rows.
_matrix :: (IArray a (u Int e), IArray u e)
        => (Int -> Int -> (a Int (u Int e)) -> matrix e)
        -> (Int, Int)
        -> ((Int, Int) -> e)
        -> matrix e
_matrix c (rowCount, colCount) generator =
    c rowCount colCount (listArray (1, rowCount) [ mkRow i | i <- [1..rowCount] ])
  where
    mkRow i = listArray (1, colCount) [ generator (i, j) | j <- [1..colCount] ]

-- | The rows of an array of rows as nested lists, in index order.
_toList :: (IArray a e) => Array Int (a Int e) -> [[e]]
_toList rows = [ elems r | r <- elems rows ]

-- | Build a matrix from a list of rows.  Rows longer than the shortest one
-- are truncated to its length.  @c@ receives the row count, the column
-- count and the 1-based array of rows.
_fromList :: (IArray a (u Int e), IArray u e)
          => (Int -> Int -> a Int (u Int e) -> matrix e) -> [[e]] -> matrix e
_fromList c xs =
    let lengths  = P.map length xs
        -- An empty input has no rows and therefore no columns.  'foldl1'
        -- would crash here as soon as the column count was demanded.
        colCount = if null lengths then 0 else minimum lengths
        rowCount = length lengths
    in  c rowCount colCount
          $ array (1, rowCount)
          $ zip [1..rowCount]
          $ P.map (array (1, colCount) . zip [1..colCount]) xs

-- | Sum of pairwise products, accumulated strictly from the left; extra
-- elements of the longer list are ignored.
dotProd :: Num a => [a] -> [a] -> a
dotProd xs ys = L.foldl' (\acc (x, y) -> acc + x * y) 0 (zip xs ys)

-- | Thaw each row into a fresh mutable boxed row ('STArray'), in row order.
thawsBoxed :: (IArray a e, MArray (STArray s) e (ST s))
           => Array Int (a Int e)
           -> ST s [STArray s Int e]
thawsBoxed rows = traverse thaw (elems rows)

-- | Thaw each row into a fresh mutable unboxed row ('STUArray'), in row order.
thawsUnboxed :: (IArray a e, MArray (STUArray s) e (ST s))
             => Array Int (a Int e)
             -> ST s [STUArray s Int e]
thawsUnboxed rows = sequence [ thaw r | r <- elems rows ]

-- | Collect mutable rows into a mutable array of rows indexed from 1.
arrays :: [(u s) Int e]
       -> ST s ((STArray s) Int ((u s) Int e))
arrays rows = newListArray (1, length rows) rows

-- | Build the augmented matrix @[A | I]@ as a mutable array of mutable rows.
--
-- @mkRow@ allocates each row of width @2n@: the first @n@ entries are row
-- @i@ of @arr@, and the identity block has a 1 in column @n + i@ and 0
-- elsewhere.  @n@ is the number of rows of @arr@, and each row is read at
-- columns @1..n@, so @arr@ is expected to be square.
augment :: (IArray a e, MArray (u s) e (ST s), Num e)
        => ((Int, Int) -> [e] -> ST s ((u s) Int e))
        -> Array Int (a Int e)
        -> ST s (STArray s Int (u s Int e))
augment mkRow arr = do
    let (_, n) = bounds arr
        -- Use the supplied row constructor rather than ignoring it, so the
        -- caller decides how rows are allocated.
        augmentedRow (r, i) = mkRow (1, 2*n)
                                    [ if j > n then (if j == i + n then 1 else 0)
                                               else r ! j
                                    | j <- [1..2*n] ]

    mapM augmentedRow (zip (elems arr) [1..]) >>= newListArray (1, n)

-- | Allocate a boxed mutable row ('STArray') with the given bounds and contents.
boxedST :: MArray (STArray s) e (ST s)
        => (Int, Int) -> [e] -> ST s ((STArray s) Int e)
boxedST bnds xs = newListArray bnds xs

-- | Allocate an unboxed mutable row ('STUArray') with the given bounds and contents.
unboxedST :: MArray (STUArray s) e (ST s)
          => (Int, Int) -> [e] -> ST s ((STUArray s) Int e)
unboxedST bnds xs = newListArray bnds xs


-- | Run @f x@ for its effect, discard its result and return @x@ unchanged.
tee :: Monad m => (b -> m a) -> b -> m b
tee f x = do
    _ <- f x
    return x

-- | @read outer i j@ reads entry @j@ of the mutable row stored at index @i@
-- of a mutable array of mutable rows.
read :: (MArray a1 b m, MArray a (a1 Int b) m) =>
                       a Int (a1 Int b) -> Int -> Int -> m b
read outer i j = do
    inner <- readArray outer i
    readArray inner j


-- | Invert a matrix by elimination on the augmented matrix @[A | I]@,
-- returning the right-hand half once the left half is the identity.
--
-- The forward pass uses partial pivoting: in each column it picks the row
-- with the largest 'abs' value.  The backward pass normalises each pivot to
-- 1 and clears the entries above it.
--
-- @mkArrayST@ is handed to 'augment'.  The row count of @mat@ is used as
-- the size, so callers are expected to pass a square matrix (the instances
-- in this module check @m /= n@ before calling).
--
-- If every candidate pivot in a column is exactly zero, this calls
-- @fail "not invertible"@, which raises an error in 'ST' when the result
-- is forced.
--
-- NOTE(review): the 'Show' constraint is not used by this body; possibly a
-- leftover from debugging.
_inv :: (IArray a e, MArray (u s) e (ST s), Fractional e, Ord e, Show e)
     => ((Int, Int) -> [e] -> ST s ((u s) Int e))
     -> Array Int (a Int e)
     -> ST s (Array Int (a Int e))
_inv mkArrayST mat = do
    -- m: matrix size; n: width of the augmented matrix [A | I].
    let m = snd $ bounds mat
        n = 2*m

        -- Exchange two entries of the outer array (i.e. two whole rows).
        swap a i j = do
            tmp <- readArray a i
            readArray a j >>= writeArray a i
            writeArray a j tmp

    a <- augment mkArrayST mat

    -- Forward pass: for each column k, move the row (at or below k) with
    -- the largest |a_ik| into row k, then clear column k below the diagonal.
    flip mapM_ [1..m] $ \k -> do
        iPivot <- zip [k..m] <$> mapM (\i -> abs <$> read a i k) [k..m]
                    >>= return . fst . L.maximumBy (compare `on` snd)

        p <- read a iPivot k
        when (p == 0) (fail "not invertible")
        swap a iPivot k

        flip mapM_ [k+1..m] $ \i -> do
            a_i <- readArray a i
            a_k <- readArray a k
            flip mapM_ [k+1..n] $ \j -> do
                a_ij <- readArray a_i j
                a_kj <- readArray a_k j
                a_ik <- readArray a_i k
                writeArray a_i j (a_ij - a_kj * (a_ik / p))
            -- Column k is skipped by the loop above, so zero it explicitly.
            writeArray a_i k 0

    -- Backward pass, bottom row first: scale row i so its pivot is 1, then
    -- clear its entries in columns i+1..m using the rows below, which are
    -- already reduced.
    flip mapM_ [ m - v | v <- [0..m-1] ] $ \i -> do
        a_i <- readArray a i
        p   <- readArray a_i i
        writeArray a_i i 1
        flip mapM_ [i+1..n] $ \j -> do
            readArray a_i j >>= writeArray a_i j . (/ p)

        unless (i == m) $ do
            flip mapM_ [i+1..m] $ \k -> do
                a_k <- readArray a k
                p   <- readArray a_i k

                flip mapM_ [k..n] $ \j -> do
                    a_ij <- readArray a_i j
                    a_kj <- readArray a_k j
                    writeArray a_i j (a_ij - p * a_kj)

    -- Keep only the right half (columns m+1..2m) of each row: the inverse.
    mapM (\i -> readArray a i >>= getElems
                    >>= return . listArray (1, m) . drop m) [1..m]
        >>= return . listArray (1, m)


-- | Rank by Gaussian elimination to row-echelon form: the number of pivots
-- found, returned in the element type.
--
-- The pivot in each column is the first row at or below the current one
-- with a nonzero entry, so a zero on the diagonal never causes a division
-- by zero.  Zero tests are exact, so for floating-point elements rounding
-- residue can count as nonzero.
_rank :: (IArray a e, MArray (u s) e (ST s), Fractional e, Eq e)
      => (Array Int (a Int e) -> ST s [(u s) Int e])
      -> Array Int (a Int e)
      -> ST s e
_rank thaws mat = do
    let m = snd $ bounds mat
        -- A matrix with no rows has no row 1 to measure.
        n = if m < 1 then 0 else snd $ bounds (mat ! 1)

    a <- thaws mat >>= arrays

    -- First row, counting up from r, whose entry in column c is nonzero.
    let findPivotRow c r
            | r > m     = return Nothing
            | otherwise = do
                a_rc <- read a r c
                if a_rc /= 0 then return (Just r) else findPivotRow c (r + 1)

        -- Exchange rows i and j of the outer array.
        swapRows i j = do
            row_i <- readArray a i
            readArray a j >>= writeArray a i
            writeArray a j row_i

        -- r is the next row to receive a pivot and c the column being
        -- examined.  Returns the number of pivots placed.
        eliminate r c
            | r > m || c > n = return (r - 1)
            | otherwise = do
                found <- findPivotRow c r
                case found of
                    -- Column c has no pivot: move on to the next column
                    -- without consuming a row.
                    Nothing -> eliminate r (c + 1)
                    Just p  -> do
                        when (p /= r) (swapRows r p)
                        a_r  <- readArray a r
                        a_rc <- readArray a_r c
                        forM_ [r+1 .. m] $ \i -> do
                            a_i  <- readArray a i
                            a_ic <- readArray a_i c
                            let factor = a_ic / a_rc
                            forM_ [c+1 .. n] $ \j -> do
                                a_ij <- readArray a_i j
                                a_rj <- readArray a_r j
                                writeArray a_i j (a_ij - a_rj * factor)
                            -- Column c is skipped above; store an exact zero.
                            writeArray a_i c 0
                        eliminate (r + 1) (c + 1)

    fromIntegral <$> eliminate 1 1


-- | Determinant by Bareiss-style elimination: at step k every entry (i, j)
-- with i, j > k becomes @(pivot * a_ij - a_ik * a_kj) `divide` prev@,
-- where @prev@ is the previous pivot (1 before the first step).  The last
-- pivot, multiplied by the sign accumulated from row exchanges, is the
-- determinant.
--
-- When the pivot is zero, the first lower row with a nonzero entry in that
-- column is swapped in and the sign is flipped.  If there is no such row,
-- the matrix is singular and the result is 0.
_det :: (IArray a e, MArray (u s) e (ST s),
         Num e, Eq e, Division e)
     => (Array Int (a Int e) -> ST s [(u s) Int e])
     -> Array Int (a Int e) -> ST s e
_det thaws mat = do

    let size = snd $ bounds mat

    a <- thaws mat >>= arrays

    -- First row, counting up from r, whose entry in column k is nonzero.
    let findPivotRow k r
            | r > size  = return Nothing
            | otherwise = do
                a_rk <- read a r k
                if a_rk /= 0 then return (Just r) else findPivotRow k (r + 1)

        -- Exchange rows i and j of the outer array.
        swapRows i j = do
            row_i <- readArray a i
            readArray a j >>= writeArray a i
            writeArray a j row_i

        -- Elimination step k, given the previous pivot and the current sign.
        step k prev sign
            | k > size  = return (prev * sign)
            | otherwise = do
                found <- findPivotRow k k
                case found of
                    -- Column k is zero from row k down, so the matrix is
                    -- singular.  Stopping here also avoids dividing by this
                    -- zero pivot in the next step, which crashes for
                    -- integral types and yields NaN for floating ones.
                    Nothing -> return 0
                    Just r  -> do
                        sign' <- if r == k then return sign else swapRows k r >> return (negate sign)
                        pivot <- read a k k
                        forM_ [k+1 .. size] $ \i -> do
                            a_i  <- readArray a i
                            a_ik <- readArray a_i k
                            forM_ [k+1 .. size] $ \j -> do
                                a_ij <- readArray a_i j
                                a_kj <- read a k j
                                writeArray a_i j ((pivot * a_ij - a_ik * a_kj) `divide` prev)
                        step (k + 1) pivot sign'

    step 1 1 1