{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE Rank2Types            #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators #-}
module Data.Matrix.Static.Generic
    ( Mutable
    , Matrix(..)
    , MatrixKind

    -- * Matrix query
    , rows
    , cols
    , (!)
    , takeColumn
    , takeRow
    , toRows
    , toColumns

    -- * Matrix Construction
    , empty
    , matrix
    , fromRows
    , fromColumns
    , fromVector
    , fromList

    , toList
    , create
    , convertAny
    , mapM
    , imapM
    ) where

import           Control.Monad.Primitive     (PrimMonad, PrimState)
import           Control.Monad.ST            (ST, runST)
import qualified Data.Vector.Generic         as G
import Text.Printf (printf)
import Prelude hiding (map, mapM, mapM_, sequence, sequence_)
import qualified Data.List as L
import Data.Tuple (swap)
import Data.Kind (Type)
import GHC.TypeLits (Nat, type (<=))
import Data.Singletons (SingI, Sing, fromSing, sing)

import Data.Matrix.Static.Generic.Mutable (MMatrix, MMatrixKind)

-- | The kind shared by every immutable matrix type constructor: a type-level
-- row count, a type-level column count, the underlying vector type
-- constructor, and the element type.
type MatrixKind = Nat -> Nat -> (Type -> Type) -> Type -> Type

-- | Maps an immutable matrix type constructor to its mutable counterpart.
-- The family is injective (@mmat -> mat@), so the immutable type can be
-- recovered from the mutable one.
type family Mutable (mat :: MatrixKind) = (mmat :: MMatrixKind) | mmat -> mat

-- | Immutable matrices with type-level dimensions @r@ (rows) and @c@
-- (columns), backed by a generic vector @v@ holding elements of type @a@.
--
-- Instances must supply 'dim', 'unsafeIndex', 'unsafeFromVector', the
-- thaw/freeze family, 'map', 'imap', 'imapM_', 'sequence' and 'sequence_'.
-- Every other method has a default written in terms of 'dim' and
-- 'unsafeIndex'.
class (MMatrix (Mutable mat) (G.Mutable v) a, G.Vector v a) => Matrix (mat :: MatrixKind) v a where
    -- | Run-time dimensions as @(rows, columns)@.
    dim :: mat r c v a -> (Int, Int)

    -- | Element at a zero-based @(row, column)@ position.
    -- NOTE(review): presumably performs no bounds checking, as the name
    -- suggests -- confirm against the instances.
    unsafeIndex :: mat r c v a -> (Int, Int) -> a

    -- | Build a matrix from a vector without validating its length (the
    -- checked variant is 'fromVector'). The defaults in this class treat the
    -- vector as being in column order.
    unsafeFromVector :: (SingI r, SingI c) => v a -> mat r c v a

    -- | Flatten the matrix into a vector, column after column.
    -- The default performs one 'unsafeIndex' call per element, m*n in total.
    flatten :: mat r c v a -> v a
    flatten mat = G.generate (nrow * ncol) (unsafeIndex mat . swap . (`divMod` nrow))
      where
        -- A flat column-order position k splits into (column, row) via
        -- divMod by the row count; 'swap' restores (row, column).
        (nrow, ncol) = dim mat
    {-# INLINE flatten #-}

    -- | Extract the row with the given zero-based index.
    -- The default performs one 'unsafeIndex' call per column.
    unsafeTakeRow :: mat r c v a -> Int -> v a
    unsafeTakeRow mat i = G.generate (snd (dim mat)) pick
      where
        pick j = unsafeIndex mat (i, j)
    {-# INLINE unsafeTakeRow #-}

    -- | Extract the column with the given zero-based index.
    -- The default performs one 'unsafeIndex' call per row.
    unsafeTakeColumn :: mat r c v a -> Int -> v a
    unsafeTakeColumn mat j = G.generate (fst (dim mat)) pick
      where
        pick i = unsafeIndex mat (i, j)
    {-# INLINE unsafeTakeColumn #-}

    -- | Extract the main diagonal.
    -- The default performs min(m,n) 'unsafeIndex' calls.
    takeDiag :: mat r c v a -> v a
    takeDiag mat = G.generate (min nrow ncol) (\k -> unsafeIndex mat (k, k))
      where
        (nrow, ncol) = dim mat
    {-# INLINE takeDiag #-}

    -- | Swap rows and columns.
    -- The default rebuilds the matrix with one 'unsafeIndex' call per element.
    transpose :: (SingI r, SingI c) => mat r c v a -> mat c r v a
    transpose mat =
        unsafeFromVector (G.generate (nrow * ncol) (unsafeIndex mat . (`divMod` ncol)))
      where
        -- The result has ncol rows, so a column-order position k of the
        -- result splits into (result column, result row), which is exactly
        -- the (row, column) of the source element.
        (nrow, ncol) = dim mat
    {-# INLINE transpose #-}

    -- | Turn an immutable matrix into a mutable one inside a 'PrimMonad'.
    -- NOTE(review): presumably copies the data, by analogy with
    -- "Data.Vector.Generic" -- confirm against the instances.
    thaw :: PrimMonad s
         => mat r c v a
         -> s ((Mutable mat) r c (G.Mutable v) (PrimState s) a)

    -- | Like 'thaw'.
    -- NOTE(review): presumably avoids the copy, so the immutable matrix must
    -- not be used afterwards -- confirm against the instances.
    unsafeThaw :: PrimMonad s
               => mat r c v a
               -> s ((Mutable mat) r c (G.Mutable v) (PrimState s) a)

    -- | Turn a mutable matrix into an immutable one inside a 'PrimMonad'.
    -- NOTE(review): presumably copies the data -- confirm against the
    -- instances.
    freeze :: PrimMonad s
           => (Mutable mat) r c (G.Mutable v) (PrimState s) a
           -> s (mat r c v a)

    -- | Like 'freeze'.
    -- NOTE(review): presumably avoids the copy, so the mutable matrix must
    -- not be modified afterwards -- confirm against the instances.
    unsafeFreeze :: PrimMonad s
                 => (Mutable mat) r c (G.Mutable v) (PrimState s) a
                 -> s (mat r c v a)

    -- | Apply a function to every element.
    map :: G.Vector v b => (a -> b) -> mat r c v a -> mat r c v b

    -- | Apply a function to every element together with an @(Int, Int)@
    -- index (presumably the zero-based @(row, column)@ position).
    imap :: G.Vector v b => ((Int, Int) -> a -> b) -> mat r c v a -> mat r c v b

    -- | Run a monadic action on every element and its index, discarding the
    -- results.
    imapM_ :: (Monad monad, Matrix mat v a)
           => ((Int, Int) -> a -> monad b) -> mat r c v a -> monad ()

    -- | Run all the actions stored in the matrix and collect their results.
    sequence :: (G.Vector v (monad a), Monad monad)
             => mat r c v (monad a) -> monad (mat r c v a)

    -- | Run all the actions stored in the matrix, discarding their results.
    sequence_ :: (G.Vector v (monad a), Monad monad) => mat r c v (monad a) -> monad ()

-- Derived methods

-- | Number of rows, i.e. the first component of 'dim'.
rows :: Matrix m v a => m r c v a -> Int
rows mat = nrow
  where
    (nrow, _) = dim mat
{-# INLINE rows #-}

-- | Number of columns, i.e. the second component of 'dim'.
cols :: Matrix m v a => m r c v a -> Int
cols mat = ncol
  where
    (_, ncol) = dim mat
{-# INLINE cols #-}

-- | Look up the element at the zero-based position @(i, j)@ given by two
-- type-level naturals.
--
-- The constraints @i <= r@ and @j <= c@ still admit @i == r@ or @j == c@,
-- which lie one past the last row or column. The position is therefore also
-- checked at run time against 'dim', and an out-of-range position raises an
-- error instead of being handed to 'unsafeIndex'.
(!) :: forall m r c v a i j. (Matrix m v a, i <= r, j <= c)
    => m r c v a -> (Sing i, Sing j) -> a
(!) m (si, sj)
    | i < nr && j < nc = unsafeIndex m (i,j)
    | otherwise = error $
        printf "(!): index (%d, %d) out of bounds (%d, %d)" i j nr nc
  where
    (nr, nc) = dim m
    i = fromIntegral $ fromSing si
    j = fromIntegral $ fromSing sj
{-# INLINE (!) #-}

-- | Construct a matrix from a vector containing its columns, one after
-- another. Calls 'error' when the vector length differs from @r * c@.
fromVector :: forall m r c v a. (SingI r, SingI c, Matrix m v a)
           => v a -> m r c v a
fromVector vec =
    if len == nrow * ncol
        then unsafeFromVector vec
        else error (printf "fromVector: incorrect length (%d * %d != %d)" nrow ncol len)
  where
    len = G.length vec
    -- Demote the type-level dimensions to run-time values.
    nrow = fromIntegral (fromSing (sing :: Sing r))
    ncol = fromIntegral (fromSing (sing :: Sing c))
{-# INLINE fromVector #-}

-- | Construct a matrix from a list of rows. The rows are transposed into
-- columns and concatenated, then handed to 'fromList', which expects the
-- elements in column order.
matrix :: (SingI r, SingI c, Matrix m v a)
       => [[a]] -> m r c v a
matrix rowLists = fromList (concat (L.transpose rowLists))
{-# INLINE matrix #-}

-- | Construct a matrix from a list containing its columns, one after
-- another. The length is validated by 'fromVector'.
fromList :: (SingI r, SingI c, Matrix m v a)
         => [a] -> m r c v a
fromList xs = fromVector (G.fromList xs)
{-# INLINE fromList #-}

-- | O(m*n) Create a matrix from its rows. The rows are first assembled as
-- the columns of the transposed matrix, which is then transposed back.
fromRows :: (Matrix m v a, SingI r, SingI c) => [v a] -> m r c v a
fromRows rowVecs = transpose (fromColumns rowVecs)
{-# INLINE fromRows #-}

-- | O(m*n) Create a matrix from its columns by concatenating them and
-- passing the result to 'fromVector'.
fromColumns :: (Matrix m v a, SingI r, SingI c)
            => [v a] -> m r c v a
fromColumns colVecs = fromVector (G.concat colVecs)
{-# INLINE fromColumns #-}

-- | O(m*n) List all elements in column order, i.e. the list form of
-- 'flatten'.
toList :: Matrix m v a => m r c v a -> [a]
toList mat = G.toList (flatten mat)
{-# INLINE toList #-}

-- | The 0-by-0 matrix, built from an empty vector.
empty :: Matrix m v a => m 0 0 v a
empty = unsafeFromVector noElements
  where
    noElements = G.empty
{-# INLINE empty #-}

-- | Run an 'ST' action that yields a mutable matrix and return the result
-- as an immutable matrix, converting it with 'unsafeFreeze'.
create :: Matrix m v a
       => (forall s . ST s ((Mutable m) r c (G.Mutable v) s a)) -> m r c v a
create build = runST (build >>= unsafeFreeze)
{-# INLINE create #-}

-- | O(m*n) Convert between any two matrix types with the same dimensions
-- and element type: flatten the source, convert the vector, and rebuild.
convertAny :: (Matrix m1 v1 a, Matrix m2 v2 a, SingI r, SingI c)
           => m1 r c v1 a -> m2 r c v2 a
convertAny src = unsafeFromVector (G.convert (flatten src))
{-# INLINE convertAny #-}

-- | Extract the row with the zero-based index @i@.
--
-- The constraint @i <= r@ still admits @i == r@, which is one past the last
-- row. The index is therefore also checked at run time against 'dim', and an
-- out-of-range index raises an error instead of reaching 'unsafeTakeRow'.
takeRow :: forall m r c v a i. (i <= r, SingI i, Matrix m v a)
        => m r c v a -> Sing i -> v a
takeRow mat _
    | i < nr = unsafeTakeRow mat i
    | otherwise = error $
        printf "takeRow: row index %d out of bounds (%d rows)" i nr
  where
    (nr, _) = dim mat
    -- The index comes from the type, so the 'Sing' argument is only a proxy.
    i = fromIntegral $ fromSing (sing :: Sing i)
{-# INLINE takeRow #-}

-- | Return the rows, top to bottom. Each row is produced by
-- 'unsafeTakeRow', so with its default implementation the whole list costs
-- O(m*n) index operations.
toRows :: Matrix m v a => m r c v a -> [v a]
toRows mat = [unsafeTakeRow mat i | i <- [0 .. nrow - 1]]
  where
    nrow = fst (dim mat)
{-# INLINE toRows #-}

-- | Extract the column with the zero-based index @j@.
--
-- The constraint @j <= c@ still admits @j == c@, which is one past the last
-- column. The index is therefore also checked at run time against 'dim', and
-- an out-of-range index raises an error instead of reaching
-- 'unsafeTakeColumn'.
takeColumn :: forall m r c v a j. (j <= c, SingI j, Matrix m v a)
           => m r c v a -> Sing j -> v a
takeColumn mat _
    | j < nc = unsafeTakeColumn mat j
    | otherwise = error $
        printf "takeColumn: column index %d out of bounds (%d columns)" j nc
  where
    (_, nc) = dim mat
    -- The index comes from the type, so the 'Sing' argument is only a proxy.
    j = fromIntegral $ fromSing (sing :: Sing j)
{-# INLINE takeColumn #-}

-- | O(m*n) Return the columns, left to right, each one produced by
-- 'unsafeTakeColumn'.
toColumns :: Matrix m v a => m r c v a -> [v a]
toColumns mat = [unsafeTakeColumn mat j | j <- [0 .. ncol - 1]]
  where
    ncol = snd (dim mat)
{-# INLINE toColumns #-}

-- | Apply a monadic function to every element and collect the results in a
-- matrix of the same shape. Implemented as the class method 'map' followed
-- by the class method 'sequence'.
mapM :: (G.Vector v (monad b), Monad monad, Matrix mat v a, Matrix mat v b)
     => (a -> monad b) -> mat r c v a -> monad (mat r c v b)
mapM f mat = sequence (map f mat)
{-# INLINE mapM #-}

-- | Like 'mapM', but the function also receives the index of each element.
-- Implemented as the class method 'imap' followed by the class method
-- 'sequence'.
imapM :: (G.Vector v (monad b), Monad monad, Matrix mat v a, Matrix mat v b)
      => ((Int, Int) -> a -> monad b)
      -> mat r c v a -> monad (mat r c v b)
imapM f mat = sequence (imap f mat)
{-# INLINE imapM #-}