{-# LANGUAGE Haskell2010
    , TypeFamilies
    , FlexibleContexts
    , Trustworthy
 #-}
{-# OPTIONS -Wall -fno-warn-name-shadowing #-}

-- | Efficient matrix operations in 100% pure Haskell.
--
-- This package uses different internal representations,
-- depending on the type of its components. Typically unboxed
-- arrays will perform best, while boxed arrays give you
-- certain features such as 'Rational' or 'Complex' components.
--
-- The following component types are supported by 'Matrix':
-- 
-- [@Int@] Uses unboxed arrays internally. 'inv' will always
--      return 'Nothing'.
--
-- [@Integer@] Uses boxed arrays internally. 'inv' will always
--      return 'Nothing'.
--
-- [@Double@ and @Float@] Uses unboxed arrays internally.
--      All matrix operations will work as expected.
--      @Matrix Double@ will probably yield the best performance.
--
-- [@Rational@] Best choice if precision is what you aim for.
--      Uses boxed arrays internally. All matrix operations will
--      work as expected.
--
-- [@Complex@] Experimental. Uses boxed arrays internally.
--      The current implementation of 'inv' requires an instance
--      of 'Ord' for the component type, therefore it is currently
--      not possible to calculate the inverse of a complex matrix
--      (on my to do list).
module Numeric.Matrix (

    Matrix,

    MatrixElement (..),

    -- * Matrix property and utility functions.
    (<|>),
    (<->),
    scale,

    -- ** Matrix properties
    isUnit,
    isZero,
    isDiagonal,
    isEmpty,
    isSquare
    
) where


import Control.Applicative ((<$>))
import Control.DeepSeq
import Control.Monad
import Control.Monad.ST

import Data.Function (on)
import Data.Ratio
import Data.Complex
import Data.Maybe
--import Data.Foldable (Foldable)
--import qualified Data.Foldable as F

import qualified Data.List as L
import Data.Array.IArray
import Data.Array.MArray
import Data.Array.Unboxed
import Data.Array.ST
import Data.STRef
import Data.Typeable

import Prelude hiding (any, all, read, map)
import qualified Prelude as P

-- | Matrices are represented by a type which fits best the component type.
-- For example a @Matrix Double@ is represented by unboxed arrays,
-- @Matrix Integer@ by boxed arrays.
--
-- Data instances exist for 'Int', 'Float', 'Double', 'Integer', 'Ratio',
-- and 'Complex'. Certain types do have certain disadvantages, like for
-- example you can not compute the inverse matrix of a @Matrix Int@.
--
-- Every matrix (regardless of the component type) has instances for
-- 'Show', 'Read', 'Num', 'Fractional', 'Eq', 'Typeable', and 'NFData'.
-- This means that you can use arithmetic operations like '+', '*', and
-- '/', as well as functions like 'show', 'read', or 'typeOf'.
--
-- [@Show (Matrix e)@]
-- Note that a Show instance for the component type @e@ must exist.
-- 
-- [@Read (Matrix e)@]
-- You can read a matrix like so:
--
-- > read "1 0\n0 1\n" :: Matrix Double
--
-- [@Num (Matrix e)@]
-- '+', '-', '*', 'negate', 'abs', 'signum', and 'fromInteger'.
--
-- 'signum' will compute the determinant and return the signum
-- of it.
--
-- 'abs' applies @map abs@ on the matrix (that is, it applies
-- @abs@ on every component in the matrix and returns a new
-- matrix without negative components).
--
-- @fromInteger@ yields a 1-x-1-matrix.
--
-- [@Fractional (Matrix e)@]
-- Only available if there exists an instance @Fractional e@
-- (the component type needs to have a @Fractional@ instance, too).
-- Note that while the 'Num' operations are safe, 'recip' and
-- '/' will fail (with an 'error') if the involved matrix is
-- not invertible or not a square matrix.
--
-- [@NFData (Matrix e)@]
-- Matrices have instances for NFData so that you can use a
-- matrix in parallel computations using the @Control.Monad.Par@
-- monad (see the @monad-par@ package for details).
--
-- [@Typeable (Matrix e)@]
-- Allows you to use matrices as 'Data.Dynamic' values.
data family Matrix e

-- Every representation stores the number of rows and the number of
-- columns as strict fields, followed by the rows themselves: a boxed
-- 'Array' indexed by row number whose elements are the rows. The builders
-- in this module ('_matrix', '_fromList') index both levels from 1.

-- Int components: unboxed rows.
data instance Matrix Int
    = IntMatrix !Int !Int (Array Int (UArray Int Int))

-- Float components: unboxed rows.
data instance Matrix Float
    = FloatMatrix !Int !Int (Array Int (UArray Int Float))

-- Double components: unboxed rows.
data instance Matrix Double
    = DoubleMatrix !Int !Int (Array Int (UArray Int Double))

-- Integer components: boxed rows.
data instance Matrix Integer
    = IntegerMatrix !Int !Int (Array Int (Array Int Integer))

-- Ratio components (for every underlying type): boxed rows.
data instance Matrix (Ratio a)
    = RatioMatrix !Int !Int (Array Int (Array Int (Ratio a)))

-- Complex components (for every underlying type): boxed rows.
data instance Matrix (Complex a)
    = ComplexMatrix !Int !Int (Array Int (Array Int (Complex a)))

-- | Identifies @Matrix@ (package @bed-and-breakfast@, module
-- "Numeric.Matrix") applied to the type representation of the component
-- type.
--
-- NOTE(review): newer GHC versions derive 'Typeable' for every type and
-- reject hand-written instances, and 'mkTyCon3' may no longer be exported
-- from "Data.Typeable"; this instance probably needs CPP guarding. Confirm
-- against the supported GHC range.
instance Typeable a => Typeable (Matrix a) where
    typeOf m = mkTyConApp matrixTyCon [typeOf (component m)]
      where
        matrixTyCon = mkTyCon3 "bed-and-breakfast" "Numeric.Matrix" "Matrix"

        -- Only used for its type; the result is never evaluated.
        component :: Matrix a -> a
        component _ = undefined

-- | Renders the matrix one row per line. Every component is prefixed with
-- a space, and the components of a row are joined by another space.
instance (MatrixElement e, Show e) => Show (Matrix e) where
    show m = unlines [ unwords [ ' ' : show x | x <- r ] | r <- toList m ]

-- | Parses whitespace-separated components, one row per line, and always
-- consumes the whole input. Components are parsed with 'P.read', so a
-- malformed component raises an error once it is forced.
instance (Read e, MatrixElement e) => Read (Matrix e) where
    readsPrec _ input = [(fromList rows, "")]
      where
        rows = [ P.map P.read (words line) | line <- lines input ]

-- | '+' and '-' work component-wise, and '*' is the matrix product.
-- 'negate' and 'abs' act on every component. 'signum' yields a 1-x-1
-- matrix holding the signum of the determinant, and 'fromInteger' yields
-- a 1-x-1 matrix.
instance (MatrixElement e) => Num (Matrix e) where
    (+) = plus
    (-) = minus
    (*) = times
    -- Defined explicitly: the class default @negate x = 0 - x@ would
    -- subtract from the 1-x-1 matrix produced by 'fromInteger' and fail
    -- with a dimension mismatch for every other size.
    negate      = map negate
    abs         = map abs
    signum      = matrix (1,1) . const . signum . det
    fromInteger = matrix (1,1) . const . fromInteger
    
-- | 'recip' computes the inverse matrix. '/' (the class default,
-- @x * recip y@) multiplies by the inverse. Both fail with an 'error' when
-- 'inv' returns 'Nothing' (matrix not square or not invertible).
-- 'fromRational' yields a 1-x-1 matrix.
instance (MatrixElement e, Fractional e) => Fractional (Matrix e) where
    recip m      = fromMaybe (error "Matrix.recip: matrix is not square or not invertible.")
                             (inv m)
    fromRational = matrix (1,1) . const . fromRational

-- | Matrices are equal when their dimensions match and they agree on every
-- component. Components are only compared once the dimensions match.
instance (MatrixElement e) => Eq (Matrix e) where
    m == n = dimensions m == dimensions n
          && allWithIndex (\ix e -> m `at` ix == e) n

-- | Evaluates every component of the matrix.
--
-- 'MatrixElement' does not imply 'NFData' for the component type, so
-- components are forced to weak head normal form only. For 'Int',
-- 'Integer', 'Float' and 'Double', and for 'Ratio' / 'Complex' (whose
-- fields are strict), this is full evaluation.
instance (MatrixElement e) => NFData (Matrix e) where
    -- Must not be written via 'deepseq': @m `deepseq` ()@ is
    -- @rnf m `seq` ()@ and would recurse forever.
    rnf m = L.foldl' (flip seq) () (concat (toList m))


-- | Joins two matrices horizontally: the second matrix is placed to the
-- right of the first. If the row counts differ, the shorter one is padded
-- with zeros at the bottom.
--
-- > 1 2 3     1 0 0      1 2 3 1 0 0
-- > 3 4 5 <|> 2 1 0  ->  3 4 5 2 1 0
-- > 5 6 7     3 2 1      5 6 7 3 2 1
(<|>) :: MatrixElement e => Matrix e -> Matrix e -> Matrix e
left <|> right = matrix (max rowsL rowsR, colsL + numCols right) component
  where
    rowsL = numRows left
    rowsR = numRows right
    colsL = numCols left
    component (i, j)
        | j <= colsL = if i <= rowsL then left `at` (i, j) else 0
        | otherwise  = if i <= rowsR then right `at` (i, j - colsL) else 0

-- | Joins two matrices vertically: the second matrix is placed below the
-- first. If the column counts differ, the narrower one is padded with
-- zeros on the right.
--
-- > 1 2 3     1 0 0      1 2 3
-- > 3 4 5 <-> 2 1 0  ->  3 4 5
-- > 5 6 7     3 2 1      5 6 7
-- >                      1 0 0
-- >                      2 1 0
-- >                      3 2 1
(<->) :: MatrixElement e => Matrix e -> Matrix e -> Matrix e
top <-> bottom = matrix (rowsT + numRows bottom, max colsT colsB) component
  where
    rowsT = numRows top
    colsT = numCols top
    colsB = numCols bottom
    component (i, j)
        | i <= rowsT = if j <= colsT then top `at` (i, j) else 0
        | otherwise  = if j <= colsB then bottom `at` (i - rowsT, j) else 0

-- | Multiplies every component of a matrix by the given factor.
--
-- > scale m s == map (*s) m
scale :: MatrixElement e => Matrix e -> e -> Matrix e
scale m factor = map (\x -> x * factor) m


-- | Check whether the matrix consists of all zeros.
--
-- > isZero == all (== 0)
isZero :: MatrixElement e => Matrix e -> Bool
isZero m = all (== 0) m

-- | Check whether the matrix is an identity matrix: square, with ones on
-- the main diagonal and zeros everywhere else.
--
-- > 1 0 0
-- > 0 1 0
-- > 0 0 1 (True)
isUnit :: MatrixElement e => Matrix e -> Bool
isUnit m = isSquare m && allWithIndex expected m
  where
    expected (i, j) e = e == (if i == j then 1 else 0)

-- | Checks whether the matrix has no rows or no columns.
--
-- > isEmpty m == (numRows m == 0 || numCols m == 0)
isEmpty :: MatrixElement e => Matrix e -> Bool
isEmpty m = numRows m == 0 || numCols m == 0

-- | Checks whether the matrix is a diagonal matrix: square, with zeros
-- everywhere off the main diagonal.
--
-- > 4 0 0 0
-- > 0 7 0 0
-- > 0 0 3 0
-- > 0 0 0 9 (True)
isDiagonal :: MatrixElement e => Matrix e -> Bool
isDiagonal m = isSquare m && allWithIndex offDiagonalZero m
  where
    offDiagonalZero (i, j) e = i == j || e == 0

-- | Checks whether the matrix is a square matrix.
--
-- > isSquare == uncurry (==) . dimensions
isSquare :: MatrixElement e => Matrix e -> Bool
isSquare m = uncurry (==) (dimensions m)


-- | The division used by the fraction-free elimination in '_det' and
-- '_rank': truncating division ('quot') for the integral component types
-- and ordinary division ('/') for the fractional ones.
class Division e where
    divide :: e -> e -> e

instance Division Int where
    divide x y = x `quot` y

instance Division Integer where
    divide x y = x `quot` y

instance Division Float where
    divide x y = x / y

instance Division Double where
    divide x y = x / y

instance Integral a => Division (Ratio a) where
    divide x y = x / y

instance RealFloat a => Division (Complex a) where
    divide x y = x / y

-- Instances for Int8 / Int16 / Int32 / Int64 (using 'quot') are not
-- provided.



class (Eq e, Num e) => MatrixElement e where

    -- | Builds a matrix with the given dimensions (rows, columns) by
    -- calling the generator on every position. Indices start at one.
    matrix :: (Int, Int) -> ((Int, Int) -> e) -> Matrix e

    -- | Returns, in row-major order, the components whose positions
    -- satisfy the predicate. Indices start at one.
    select :: ((Int, Int) -> Bool) -> Matrix e -> [e]

    -- | Returns the component at the given position in the matrix.
    -- Note that indices start at one, not at zero.
    at :: Matrix e -> (Int, Int) -> e

    -- | Returns the row at the given index in the matrix.
    -- Note that indices start at one, not at zero.
    row :: Int -> Matrix e -> [e]

    -- | Returns the column at the given index in the matrix.
    -- Note that indices start at one, not at zero.
    col :: Int -> Matrix e -> [e]

    -- | The dimensions of a given matrix as (rows, columns).
    dimensions :: Matrix e -> (Int, Int)

    -- | The number of rows.
    numRows :: Matrix e -> Int

    -- | The number of columns.
    numCols :: Matrix e -> Int


    -- | Builds a matrix from a list of lists.
    --
    -- > fromList [[1,2,3],[2,1,3],[3,2,1]] :: Matrix Rational
    fromList :: [[e]] -> Matrix e

    -- | The rows of the matrix, each as a list of components.
    toList   :: Matrix e -> [[e]]

    -- | An identity square matrix of the given size.
    unit  :: Int -> Matrix e

    -- | A square matrix of the given size consisting of all zeros.
    zero  :: Int -> Matrix e

    -- | A square matrix with the given components on the main diagonal
    -- and zeros everywhere else.
    diag  :: [e] -> Matrix e

    -- | The empty matrix.
    --
    -- > dimensions empty == (0, 0)
    empty :: Matrix e

    -- | Component-wise difference. Fails with an 'error' if the
    -- dimensions differ.
    minus :: Matrix e -> Matrix e -> Matrix e

    -- | Component-wise sum. Fails with an 'error' if the dimensions
    -- differ.
    plus  :: Matrix e -> Matrix e -> Matrix e

    -- | Matrix product. Fails with an 'error' unless the number of columns
    -- of the first matrix equals the number of rows of the second.
    times :: Matrix e -> Matrix e -> Matrix e

    -- | The inverse matrix, if it can be computed. The default always
    -- returns 'Nothing'.
    inv   :: Matrix e -> Maybe (Matrix e)

--    adjugate  :: Matrix e -> Matrix e
--    cofactors :: Matrix e -> Matrix e ; cofactors = undefined

    -- | Applies Bareiss multistep integer-preserving
    -- algorithm for finding the determinant of a matrix.
    -- Returns 0 if the matrix is not a square matrix.
    det       :: Matrix e -> e

    -- | Flips rows and columns.
    --
    -- > 1 8 9                1 2 3
    -- > 2 1 8  --transpose-> 8 1 2
    -- > 3 2 1                9 8 1
    transpose :: Matrix e -> Matrix e

    -- | The rank of the matrix, expressed in the component type.
    rank      :: Matrix e -> e

    -- | The components on the main diagonal, from top-left to
    -- bottom-right (a list, not their sum).
    trace     :: Matrix e -> [e]

    -- | The determinant of the matrix left after deleting the given row
    -- and column.
    minor :: MatrixElement e => Matrix e -> (Int, Int) -> e

    -- | The cofactor matrix: the minor at every position, negated where
    -- the sum of the row and column index is odd.
    cofactors :: MatrixElement e => Matrix e -> Matrix e

    -- | The transpose of the cofactor matrix.
    adjugate :: MatrixElement e => Matrix e -> Matrix e

    -- | The matrix left after deleting the given row and column.
    minorMatrix :: MatrixElement e => Matrix e -> (Int, Int) -> Matrix e

    -- | Applies a function on every component in the matrix.
    map :: MatrixElement f => (e -> f) -> Matrix e -> Matrix f

    -- | Applies a predicate on every component in the matrix
    -- and returns True if all components satisfy it.
    all :: (e -> Bool) -> Matrix e -> Bool

    -- | Applies a predicate on every component in the matrix
    -- and returns True if one or more components satisfy it.
    any :: (e -> Bool) -> Matrix e -> Bool

    -- | Like 'map', but the function also receives the component's position.
    mapWithIndex :: MatrixElement f => ((Int, Int) -> e -> f) -> Matrix e -> Matrix f

    -- | Like 'all', but the predicate also receives the component's position.
    allWithIndex :: ((Int, Int) -> e -> Bool) -> Matrix e -> Bool

    -- | Like 'any', but the predicate also receives the component's position.
    anyWithIndex :: ((Int, Int) -> e -> Bool) -> Matrix e -> Bool

    unit n  = fromList [[ if i == j then 1 else 0 | j <- [1..n]] | i <- [1..n] ]
    zero n  = matrix (n,n) (const 0)
    empty   = fromList []
    diag xs = matrix (n,n) (\(i,j) -> if i == j then xs !! (i-1) else 0)
      where n = length xs

    select p m = [ at m (i,j) | i <- [1..numRows m]
                              , j <- [1..numCols m]
                              , p (i,j) ]

    -- Positions are one-based, while (!!) is zero-based.
    at m (i, j) = (toList m !! (i - 1)) !! (j - 1)

    row i m = ((!! (i-1)) . toList) m
    col i m = (row i . transpose) m

    numRows = fst . dimensions
    numCols = snd . dimensions
    dimensions m = case toList m of [] -> (0, 0)
                                    (x:xs) -> (length xs + 1, length x)

    adjugate = transpose . cofactors
    transpose = fromList . L.transpose . toList
    trace = select (uncurry (==))
    inv _ = Nothing

    minorMatrix m (i,j) = matrix (numRows m - 1, numCols m - 1) $
                \(i',j') -> m `at` (if i' >= i then i' + 1 else i',
                                    if j' >= j then j' + 1 else j')

    minor m = det . minorMatrix m

    cofactors m = matrix (dimensions m) $
       \(i,j) -> fromIntegral ((-1 :: Int)^(i+j)) * minor m (i,j)

    map f = mapWithIndex (const f)
    all f = allWithIndex (const f)
    any f = anyWithIndex (const f)

    mapWithIndex f m = matrix (dimensions m) (\x -> f x (m `at` x))
    allWithIndex f m = P.all id [ f (i, j) (m `at` (i,j))
                                | i <- [1..numRows m], j <- [1..numCols m]]
    anyWithIndex f m = P.any id [ f (i, j) (m `at` (i,j))
                                | i <- [1..numRows m], j <- [1..numCols m]]

    a `plus` b
        | dimensions a /= dimensions b = error "Matrix.plus: dimensions don't match."
        | otherwise = matrix (dimensions a) (\x -> a `at` x + b `at` x)
    a `minus` b
        | dimensions a /= dimensions b = error "Matrix.minus: dimensions don't match."
        | otherwise = matrix (dimensions a) (\x -> a `at` x - b `at` x)
    -- An (r x k) matrix times a (k x c) matrix: the inner dimensions must
    -- agree. (Comparing rows of a with columns of b rejected valid products
    -- and let invalid ones through, truncated silently by the dot products.)
    a `times` b
        | numCols a /= numRows b = error "Matrix.times: `numCols a' and `numRows b' don't match."
        | otherwise = fromList [ [ row i a `dotProd` col j b
                                 | j <- [1..numCols b] ]
                               | i <- [1..numRows a] ]


-- Int components: unboxed rows. 'inv' keeps the class default ('Nothing').
instance MatrixElement Int where
    dimensions (IntMatrix rows cols _) = (rows, cols)

    matrix   = _matrix IntMatrix
    fromList = _fromList IntMatrix
    toList (IntMatrix _ _ rs) = _toList rs

    at    (IntMatrix _ _ rs) = _at rs
    row i (IntMatrix _ _ rs) = _row i rs
    col j (IntMatrix _ _ rs) = _col j rs

    det (IntMatrix rows cols rs)
        | rows /= cols = 0
        | otherwise    = runST (_det thawsUnboxed rs)
    rank (IntMatrix _ _ rs) = runST (_rank thawsBoxed rs)

-- Integer components: boxed rows. 'inv' keeps the class default ('Nothing').
instance MatrixElement Integer where
    dimensions (IntegerMatrix rows cols _) = (rows, cols)

    matrix   = _matrix IntegerMatrix
    fromList = _fromList IntegerMatrix
    toList (IntegerMatrix _ _ rs) = _toList rs

    at    (IntegerMatrix _ _ rs) = _at rs
    row i (IntegerMatrix _ _ rs) = _row i rs
    col j (IntegerMatrix _ _ rs) = _col j rs

    det (IntegerMatrix rows cols rs)
        | rows /= cols = 0
        | otherwise    = runST (_det thawsBoxed rs)
    rank (IntegerMatrix _ _ rs) = runST (_rank thawsBoxed rs)

-- Float components: unboxed rows. 'inv' is available for square matrices.
instance MatrixElement Float where
    dimensions (FloatMatrix rows cols _) = (rows, cols)

    matrix   = _matrix FloatMatrix
    fromList = _fromList FloatMatrix
    toList (FloatMatrix _ _ rs) = _toList rs

    at    (FloatMatrix _ _ rs) = _at rs
    row i (FloatMatrix _ _ rs) = _row i rs
    col j (FloatMatrix _ _ rs) = _col j rs

    det (FloatMatrix rows cols rs)
        | rows /= cols = 0
        | otherwise    = runST (_det thawsUnboxed rs)
    rank (FloatMatrix _ _ rs) = runST (_rank thawsBoxed rs)

    inv (FloatMatrix rows cols rs)
        | rows /= cols = Nothing
        | otherwise    = FloatMatrix rows cols <$> runST (_inv unboxedST rs)

-- Double components: unboxed rows. 'inv' is available for square matrices.
instance MatrixElement Double where
    dimensions (DoubleMatrix rows cols _) = (rows, cols)

    matrix   = _matrix DoubleMatrix
    fromList = _fromList DoubleMatrix
    toList (DoubleMatrix _ _ rs) = _toList rs

    at    (DoubleMatrix _ _ rs) = _at rs
    row i (DoubleMatrix _ _ rs) = _row i rs
    col j (DoubleMatrix _ _ rs) = _col j rs

    inv (DoubleMatrix rows cols rs)
        | rows /= cols = Nothing
        | otherwise    = DoubleMatrix rows cols <$> runST (_inv unboxedST rs)

    det (DoubleMatrix rows cols rs)
        | rows /= cols = 0
        | otherwise    = runST (_det thawsUnboxed rs)
    rank (DoubleMatrix _ _ rs) = runST (_rank thawsBoxed rs)

-- Ratio components: boxed rows. 'inv' is available for square matrices.
instance (Show a, Integral a) => MatrixElement (Ratio a) where
    dimensions (RatioMatrix rows cols _) = (rows, cols)

    matrix   = _matrix RatioMatrix
    fromList = _fromList RatioMatrix
    toList (RatioMatrix _ _ rs) = _toList rs

    at    (RatioMatrix _ _ rs) = _at rs
    row i (RatioMatrix _ _ rs) = _row i rs
    col j (RatioMatrix _ _ rs) = _col j rs

    inv (RatioMatrix rows cols rs)
        | rows /= cols = Nothing
        | otherwise    = RatioMatrix rows cols <$> runST (_inv boxedST rs)

    det (RatioMatrix rows cols rs)
        | rows /= cols = 0
        | otherwise    = runST (_det thawsBoxed rs)
    rank (RatioMatrix _ _ rs) = runST (_rank thawsBoxed rs)

-- Complex components: boxed rows. 'inv' keeps the class default
-- ('Nothing'): '_inv' pivots on the largest absolute value and therefore
-- needs an 'Ord' instance, which 'Complex' lacks.
instance (Show a, RealFloat a) => MatrixElement (Complex a) where
    dimensions (ComplexMatrix rows cols _) = (rows, cols)

    matrix   = _matrix ComplexMatrix
    fromList = _fromList ComplexMatrix
    toList (ComplexMatrix _ _ rs) = _toList rs

    at    (ComplexMatrix _ _ rs) = _at rs
    row i (ComplexMatrix _ _ rs) = _row i rs
    col j (ComplexMatrix _ _ rs) = _col j rs

    det (ComplexMatrix rows cols rs)
        | rows /= cols = 0
        | otherwise    = runST (_det thawsBoxed rs)
    rank (ComplexMatrix _ _ rs) = runST (_rank thawsBoxed rs)


-- Looks up component (i, j) in a row-of-rows array: first row i, then
-- index j within that row.
_at :: (IArray a (u Int e), IArray u e)
    => a Int (u Int e) -> (Int, Int) -> e
_at rows (i, j) = let r = rows ! i in r ! j

-- The components of row i, in column order.
--
-- The index range comes from the row's own bounds. The outer array's
-- bounds give the number of rows, which differs from the row length for
-- non-square matrices.
_row, _col :: (IArray a (u Int e), IArray u e) => Int -> a Int (u Int e) -> [e]
_row i arr = let r = arr ! i in [ r ! j | j <- [1 .. snd (bounds r)] ]
-- The components of column j, one per row, in row order.
_col j arr = [ arr ! i ! j | i <- [1..(snd (bounds arr))] ]

-- Calls the generator at every one-based position of a matrix with the
-- given dimensions (rows, columns). The resulting row-of-rows array is
-- passed to the supplied constructor together with the row and column
-- counts.
_matrix :: (IArray a (u Int e), IArray u e)
        => (Int -> Int -> (a Int (u Int e)) -> matrix e)
        -> (Int, Int)
        -> ((Int, Int) -> e)
        -> matrix e
_matrix mk (nr, nc) gen = mk nr nc rows
  where
    rows    = listArray (1, nr) [ mkRow i | i <- [1 .. nr] ]
    mkRow i = listArray (1, nc) [ gen (i, j) | j <- [1 .. nc] ]

-- Converts a row-of-rows array into a list of rows. Both levels are
-- listed in index order.
_toList :: (IArray a e) => Array Int (a Int e) -> [[e]]
_toList rows = [ elems r | r <- elems rows ]

-- Builds the one-based row-of-rows array from a list of rows and passes
-- it to the supplied constructor together with the row and column counts.
--
-- The column count is the length of the shortest row, and longer rows
-- are truncated. An empty list yields a 0-x-0 matrix, which the default
-- 'empty' relies on; 'foldl1' would fail here.
_fromList :: (IArray a (u Int e), IArray u e)
          => (Int -> Int -> a Int (u Int e) -> matrix e) -> [[e]] -> matrix e
_fromList c xs =
    let lengths = [ length r | r <- xs ]
        numCols = if null lengths then 0 else minimum lengths
        numRows = length lengths

    in  c numRows numCols
          $ array (1, numRows)
          $ zip [1..numRows]
            [ array (1, numCols) (zip [1..numCols] r) | r <- xs ]

-- | Sum of the pairwise products of two lists, accumulated strictly from
-- the left. Surplus elements of the longer list are ignored.
dotProd :: Num a => [a] -> [a] -> a
dotProd xs ys = L.foldl' step 0 (zip xs ys)
  where
    step acc (x, y) = acc + x * y

-- | Copies every row of an immutable row-of-rows array into a fresh boxed
-- mutable 'STArray', keeping the row order.
thawsBoxed :: (IArray a e, MArray (STArray s) e (ST s))
           => Array Int (a Int e)
           -> ST s [STArray s Int e]
thawsBoxed rows = forM (elems rows) thaw

-- | Copies every row of an immutable row-of-rows array into a fresh
-- unboxed mutable 'STUArray', keeping the row order.
thawsUnboxed :: (IArray a e, MArray (STUArray s) e (ST s))
             => Array Int (a Int e)
             -> ST s [STUArray s Int e]
thawsUnboxed rows = sequence [ thaw r | r <- elems rows ]

-- | Collects mutable rows, in list order, into a one-based mutable outer
-- array.
arrays :: [(u s) Int e]
       -> ST s ((STArray s) Int ((u s) Int e))
arrays rows = newListArray (1, rowCount) rows
  where
    rowCount = length rows

-- | Builds the mutable augmented matrix @[A | I]@ used by '_inv'. With
-- @n@ the number of rows, row @i@ holds the @n@ components of row @i@ of
-- @A@ followed by row @i@ of the @n@-x-@n@ identity matrix. The first
-- argument is never called; it only fixes the mutable row type @u@.
augment :: (IArray a e, MArray (u s) e (ST s), Num e)
        => ((Int, Int) -> [e] -> ST s ((u s) Int e))
        -> Array Int (a Int e)
        -> ST s (STArray s Int (u s Int e))
augment _ arr = do
    rows <- forM (zip [1..] (elems arr)) $ \(i, src) ->
        newListArray (1, 2 * n) [ component i src j | j <- [1 .. 2 * n] ]
    newListArray (1, n) rows
  where
    n = snd (bounds arr)
    component i src j
        | j <= n     = src ! j
        | j == i + n = 1
        | otherwise  = 0

-- | Allocates a boxed mutable row ('STArray') with the given bounds and
-- initial components. Passed to '_inv' to select boxed rows.
boxedST :: MArray (STArray s) e (ST s)
        => (Int, Int) -> [e] -> ST s ((STArray s) Int e)
boxedST bnds xs = newListArray bnds xs

-- | Allocates an unboxed mutable row ('STUArray') with the given bounds
-- and initial components. Passed to '_inv' to select unboxed rows.
unboxedST :: MArray (STUArray s) e (ST s)
          => (Int, Int) -> [e] -> ST s ((STUArray s) Int e)
unboxedST bnds xs = newListArray bnds xs


-- | Runs the action on a value for its effect only and returns the value
-- unchanged.
tee :: Monad m => (b -> m a) -> b -> m b
tee f x = do
    _ <- f x
    return x

-- | Reads component (i, j) of a mutable row-of-rows array: first fetches
-- row i, then index j within that row.
read :: (MArray a1 b m, MArray a (a1 Int b) m) =>
                       a Int (a1 Int b) -> Int -> Int -> m b
read rows i j = do
    r <- readArray rows i
    readArray r j


-- | Inverts a square matrix by Gauss-Jordan elimination with partial
-- pivoting on the augmented matrix @[A | I]@ built by 'augment'.
-- Returns 'Nothing' when a zero pivot shows that the matrix is singular.
--
-- The first argument is only handed to 'augment', where it selects the
-- mutable row type (e.g. 'boxedST' or 'unboxedST').
--
-- NOTE(review): assumes the input is square (the instances check this
-- before calling); the size is taken from the row count only. The 'Show'
-- constraint is not used by the body.
_inv :: (IArray a e, MArray (u s) e (ST s), Fractional e, Ord e, Show e)
     => ((Int, Int) -> [e] -> ST s ((u s) Int e))
     -> Array Int (a Int e)
     -> ST s (Maybe (Array Int (a Int e)))
_inv mkArrayST mat = do
    -- m: size of the matrix; n: width of the augmented rows.
    let m = snd $ bounds mat
        n = 2*m

        swap a i j = do
            tmp <- readArray a i
            readArray a j >>= writeArray a i
            writeArray a j tmp

    -- Set to False as soon as a zero pivot is found.
    okay <- newSTRef True

    a <- augment mkArrayST mat

    -- Forward elimination: for every column k, pick the row (among k..m)
    -- with the largest absolute value in column k, swap it into row k and
    -- clear column k in the rows below.
    flip mapM_ [1..m] $ \k -> do
        iPivot <- zip [k..m] <$> mapM (\i -> abs <$> read a i k) [k..m]
                    >>= return . fst . L.maximumBy (compare `on` snd)

        -- A zero maximum means column k is zero from row k down, so the
        -- matrix is singular. The loop keeps running, but its result is
        -- discarded below.
        p <- read a iPivot k
        if p == 0 then writeSTRef okay False else do

            swap a iPivot k

            flip mapM_ [k+1..m] $ \i -> do
                a_i <- readArray a i
                a_k <- readArray a k
                flip mapM_ [k+1..n] $ \j -> do
                    a_ij <- readArray a_i j
                    a_kj <- readArray a_k j
                    a_ik <- readArray a_i k
                    writeArray a_i j (a_ij - a_kj * (a_ik / p))
                -- Column k of row i has been eliminated.
                writeArray a_i k 0

    invertible <- readSTRef okay

    -- Back substitution, from the last row up. Scale row i so that its
    -- pivot becomes 1, then subtract multiples of the already reduced rows
    -- below it to clear columns i+1..m. Afterwards the left half is the
    -- identity and the right half holds the inverse.
    if invertible then
      do
        flip mapM_ [ m - v | v <- [0..m-1] ] $ \i -> do
            a_i <- readArray a i
            p   <- readArray a_i i
            writeArray a_i i 1
            flip mapM_ [i+1..n] $ \j -> do
                readArray a_i j >>= writeArray a_i j . (/ p)

            unless (i == m) $ do
                flip mapM_ [i+1..m] $ \k -> do
                    a_k <- readArray a k
                    p   <- readArray a_i k

                    flip mapM_ [k..n] $ \j -> do
                        a_ij <- readArray a_i j
                        a_kj <- readArray a_k j
                        writeArray a_i j (a_ij - p * a_kj)

        -- Keep only the right half (columns m+1..2m) of every row.
        mapM (\i -> readArray a i >>= getElems
                        >>= return . listArray (1, m) . drop m) [1..m]
            >>= return . Just . listArray (1, m)

      else return Nothing

-- | Computes the rank of a matrix by fraction-free (Bareiss-style)
-- Gaussian elimination on a mutable copy.
--
-- Columns are scanned from left to right. For each column, the first row
-- at or below the current pivot row with a non-zero entry is swapped into
-- the pivot row. Every row below it is then updated as
-- @(pivot * a_ij - a_ik * a_pj) `divide` prev@, where @prev@ is the
-- previous pivot (initially 1). The rank is the number of pivots found,
-- converted to the component type.
_rank :: (IArray a e, MArray (u s) e (ST s), Num e, Division e, Eq e)
      => (Array Int (a Int e) -> ST s [(u s) Int e])
      -> Array Int (a Int e)
      -> ST s e
_rank thaws mat = do
    let m = snd $ bounds mat
        -- The column count comes from the first row; a matrix without
        -- rows has no first row and no columns.
        n = if m < 1 then 0 else snd $ bounds (mat ! 1)

        swap a i j = do
            tmp <- readArray a i
            readArray a j >>= writeArray a i
            writeArray a j tmp

    a <- thaws mat >>= arrays

    ixPivot <- newSTRef 1
    prevR   <- newSTRef 1

    forM_ [1..n] $ \k -> do
        pivotRow <- readSTRef ixPivot

        -- Offset (from pivotRow) of the first row with a non-zero entry
        -- in column k, if any.
        switchRow <- L.findIndex (/= 0) <$> mapM (\i -> read a i k) [pivotRow .. m]

        case switchRow of
            Nothing -> return ()
            Just offset -> do
                let ix = offset + pivotRow
                when (pivotRow /= ix) (swap a pivotRow ix)

                -- The pivot now sits in row pivotRow. This differs from
                -- the column index k once a column without a pivot has
                -- been skipped, and k may even exceed the row count.
                a_p   <- readArray a pivotRow
                pivot <- readArray a_p k
                prev  <- readSTRef prevR

                forM_ [pivotRow+1..m] $ \i -> do
                    a_i  <- readArray a i
                    -- Column k of row i is not written in this step.
                    a_ik <- readArray a_i k
                    forM_ [k+1..n] $ \j -> do
                        a_ij <- readArray a_i j
                        a_pj <- readArray a_p j
                        writeArray a_i j ((pivot * a_ij - a_ik * a_pj)
                                            `divide` prev)

                writeSTRef ixPivot (pivotRow + 1)
                writeSTRef prevR pivot

    subtract 1 . fromIntegral <$> readSTRef ixPivot


-- | Computes the determinant of a square matrix with Bareiss'
-- fraction-free elimination on a mutable copy.
--
-- At step @k@, the first row at or below @k@ with a non-zero entry in
-- column @k@ is swapped into row @k@; each swap flips the sign. Every
-- entry @(i, j)@ with @i, j > k@ is then replaced by
-- @(pivot * a_ij - a_ik * a_kj) `divide` prev@, where @prev@ is the
-- previous pivot (initially 1). The result is the last pivot times the
-- sign, so an empty matrix yields 1.
--
-- If some column has no non-zero pivot, the determinant is 0 and the
-- elimination stops there. Continuing would divide by a zero pivot in the
-- next step: an error for 'Int', 'Integer' and 'Ratio', or @NaN@ for
-- floating-point components.
_det :: (IArray a e, MArray (u s) e (ST s),
         Num e, Eq e, Division e)
     => (Array Int (a Int e) -> ST s [(u s) Int e])
     -> Array Int (a Int e) -> ST s e
_det thaws mat = do

    let size = snd $ bounds mat

    a <- thaws mat >>= arrays

    let swapRows i j = do
            r <- readArray a i
            readArray a j >>= writeArray a i
            writeArray a j r

        -- First row in [from .. size] with a non-zero entry in column k.
        findPivot k from
            | from > size = return Nothing
            | otherwise = do
                x <- read a from k
                if x /= 0 then return (Just from) else findPivot k (from + 1)

        -- Elimination step k, given the running sign and the previous pivot.
        step k sign prev
            | k > size = return (prev * sign)
            | otherwise = do
                found <- findPivot k k
                case found of
                    Nothing -> return 0
                    Just r -> do
                        sign' <- if r /= k
                                    then swapRows k r >> return (negate sign)
                                    else return sign
                        a_k   <- readArray a k
                        pivot <- readArray a_k k
                        forM_ [(k+1)..size] $ \i -> do
                            a_i  <- readArray a i
                            -- Column k of row i is not written in this step.
                            a_ik <- readArray a_i k
                            forM_ [(k+1)..size] $ \j -> do
                                a_ij <- readArray a_i j
                                a_kj <- readArray a_k j
                                writeArray a_i j ((pivot * a_ij - a_ik * a_kj) `divide` prev)
                        step (k + 1) sign' pivot

    step 1 1 1