-------------------------------------------------------------
-- |
-- Module      : Control.Imperative.Vector.Static
-- Copyright   : (C) 2015, Yu Fukuzawa
-- License     : BSD3
-- Maintainer  : minpou.primer@email.com
-- Stability   : experimental
-- Portability : portable
--
-----------------------------------------------------------

{-# LANGUAGE ConstraintKinds           #-}
{-# LANGUAGE DataKinds                 #-}
{-# LANGUAGE FlexibleContexts          #-}
{-# LANGUAGE FlexibleInstances         #-}
{-# LANGUAGE FunctionalDependencies    #-}
{-# LANGUAGE GADTs                     #-}
{-# LANGUAGE MultiParamTypeClasses     #-}
{-# LANGUAGE ScopedTypeVariables       #-}
{-# LANGUAGE TypeFamilies              #-}
{-# LANGUAGE TypeOperators             #-}
{-# LANGUAGE UndecidableInstances      #-}

module Control.Imperative.Vector.Static
( -- $doc

  -- * Types
  Vector
, MonadVector
, VectorElem
, VectorEntity
, HasVector
, NestedList
, Size(..)
, Dim(..)
, dim1
, dim2
, dim3
  -- * Operations
, newSized
, newSized'
, Control.Imperative.Vector.Static.length
, size
, fromListN
, toList
) where
import           Control.Imperative.Internal
import           Control.Imperative.Vector.Base
import           Control.Monad                  (liftM)
import qualified Control.Monad                  as M
import           Control.Monad.Base
import           Control.Monad.Primitive        (PrimMonad, PrimState)
import           Data.Nat
import qualified Data.Vector.Generic.Mutable    as GMV
import qualified Data.Vector.Mutable            as MV

-- $doc
-- An efficient array with two features:
--
-- * Automatic switching between unboxed and boxed arrays.
-- * Multi-dimensional support.
--
-- There are two basic operations, both exported from the "Control.Imperative" module:
--
-- [@ref@] /O(1)/. Return the element of a vector at the given index.
-- [@assign@] /O(1)/. Replace the element at the given index.

-- | A mutable vector in monad @m@ with @n@ dimensions (a type-level
-- number, @S Z@ for a flat vector) and element type @a@.
--
-- It is a thin wrapper around the 'MultiDim' representation.
newtype Vector m n a = V (MultiDim m n a)

-- | Things from which a vector @v@ can be obtained in monad @m@.
--
-- The functional dependencies mean the source type @s@ alone determines
-- both the vector type @v@ and the monad @m@. This lets operations accept
-- either a vector or something that holds one (e.g. a 'Ref').
class Monad m => HasVector s v m | s -> v, s -> m where
  -- | Obtain the vector described by the source.
  getVector :: s -> m v

-- | A 'Vector' is its own source: 'getVector' hands it back unchanged.
instance Monad m => HasVector (Vector m n a) (Vector m n a) m where
  getVector vec = return vec
  {-# INLINE getVector #-}

-- | A reference to a 'Vector' is read with 'get' to obtain the vector it
-- currently holds.
instance Monad m => HasVector (Ref m (Vector m n a)) (Vector m n a) m where
  getVector ref = get ref
  {-# INLINE getVector #-}

-- | Storage for a vector with @n@ dimensions, indexed by a type-level 'Nat'.
--
-- The innermost dimension ('D1') holds the elements in a 'VectorEntity'
-- determined by the element type. Presumably this is where the
-- boxed/unboxed switch described in the module overview happens;
-- 'VectorEntity' is defined elsewhere, so confirm there. Each outer
-- dimension ('DN') is a boxed mutable vector whose slots are
-- one-dimension-smaller 'MultiDim' values.
data MultiDim m (n :: Nat) a where
  -- | One dimension: the element storage itself.
  D1 :: VectorEntity a (PrimState m) a -> MultiDim m (S Z) a
  -- | Two or more dimensions: a boxed vector of sub-vectors.
  DN :: MV.MVector (PrimState m) (MultiDim m (S n) a) -> MultiDim m (S (S n)) a

-- | Indexing a one-dimensional vector yields a reference to a single
-- element; reads and writes go directly to the element storage.
instance (VectorElem a, PrimMonad m) => Indexable (Vector m (S Z) a) where
  type Element (Vector m (S Z) a)   = Ref m a
  type IndexType (Vector m (S Z) a) = Int
  V (D1 store) ! ix = Ref
    { get = GMV.read store ix
    , set = GMV.write store ix
    }
  {-# INLINE (!) #-}

-- | Indexing an outer dimension yields a reference to the sub-vector (one
-- dimension smaller) stored in that slot. Setting replaces the whole slot.
instance PrimMonad m => Indexable (Vector m (S (S n)) a) where
  type Element (Vector m (S (S n)) a) = Ref m (Vector m (S n) a)
  type IndexType (Vector m (S (S n)) a) = Int
  V (DN rows) ! ix = Ref
    { get = MV.read rows ix >>= return . V
    , set = \(V inner) -> MV.write rows ix inner
    }
  {-# INLINE (!) #-}

-- | Indexing through a reference to a one-dimensional vector. The outer
-- reference is read again on every 'get' and 'set' of the element
-- reference, so the element reference always addresses the vector the
-- outer reference holds at that moment.
instance (VectorElem a, PrimMonad m) => Indexable (Ref m (Vector m (S Z) a)) where
  type Element (Ref m (Vector m (S Z) a)) = Ref m a
  type IndexType (Ref m (Vector m (S Z) a)) = Int
  ref ! ix = Ref
    { get = withStore (\store -> GMV.read store ix)
    , set = \x -> withStore (\store -> GMV.write store ix x)
    }
    where
      -- Read the outer reference and pass its element storage to the continuation.
      withStore :: (VectorEntity a (PrimState m) a -> m b) -> m b
      withStore k = get ref >>= \(V (D1 store)) -> k store
  {-# INLINE (!) #-}

-- | Indexing through a reference to a multi-dimensional vector yields a
-- reference to a sub-vector. The outer reference is read again on every
-- 'get' and 'set'.
instance PrimMonad m => Indexable (Ref m (Vector m (S (S n)) a)) where
  type Element (Ref m (Vector m (S (S n)) a)) = Ref m (Vector m (S n) a)
  type IndexType (Ref m (Vector m (S (S n)) a)) = Int
  ref ! ix = Ref
    { get = withRows (\rows -> MV.read rows ix >>= return . V)
    , set = \(V inner) -> withRows (\rows -> MV.write rows ix inner)
    }
    where
      -- Read the outer reference and pass its boxed row vector to the continuation.
      withRows :: (MV.MVector (PrimState m) (MultiDim m (S n) a) -> m b) -> m b
      withRows k = get ref >>= \(V (DN rows)) -> k rows
  {-# INLINE (!) #-}

-- | /O(n)/. Allocate a vector with the given extent for each dimension.
--
-- Every slot of every outer dimension gets its own freshly allocated
-- sub-vector. The innermost element storage comes from 'GMV.new'; no
-- values are written into it.
newSized :: (VectorElem a, MonadVector m) => Size (S n) -> m (Vector (BaseEff m) (S n) a)
newSized extents = liftBase (liftM V (build extents))
  where
    -- Allocate one dimension, recursing into the inner dimensions.
    build :: (VectorElem a, PrimMonad m) => Size (S n) -> m (MultiDim m (S n) a)
    build (len :*: One) = liftM D1 (GMV.new len)
    build (len :*: inner@(_ :*: _)) = do
      rows <- MV.new len
      M.forM_ [0 .. len - 1] $ \ix ->
        build inner >>= GMV.write rows ix
      return (DN rows)
{-# INLINE newSized #-}

-- | /O(n)/. Allocate a vector with the given extents and set every element
-- to the supplied value.
--
-- Each slot of each outer dimension gets its own freshly allocated
-- sub-vector, so rows are not shared.
newSized' :: (VectorElem a, MonadVector m) => Size (S n) -> a -> m (Vector (BaseEff m) (S n) a)
newSized' extents seed = liftBase (liftM V (fill extents seed))
  where
    -- Allocate one dimension filled with the value, recursing inward.
    fill :: (VectorElem a, PrimMonad m) => Size (S n) -> a -> m (MultiDim m (S n) a)
    fill (len :*: One) x = liftM D1 (GMV.replicate len x)
    fill (len :*: inner@(_ :*: _)) x = do
      rows <- MV.new len
      M.forM_ [0 .. len - 1] $ \ix ->
        fill inner x >>= GMV.write rows ix
      return (DN rows)
{-# INLINE newSized' #-}

-- | /O(1)/. The length of the outermost dimension of the vector. For a
-- one-dimensional vector this is its number of elements.
--
-- The vector may be passed directly or through anything with a
-- 'HasVector' instance, such as a 'Ref'.
-- 'Control.Imperative.Vector.Static.length' is an alias of this function.
size :: (VectorElem a, HasVector s (Vector (BaseEff m) (S n) a) (BaseEff m), MonadVector m) => s -> m Int
size s = liftBase (getVector s >>= \(V dv) -> return (outerLength dv))
  where
    -- Length of the top level of the representation, whichever constructor it is.
    outerLength :: VectorElem a => MultiDim m n a -> Int
    outerLength (D1 store) = GMV.length store
    outerLength (DN rows)  = MV.length rows
{-# INLINE size #-}

-- | /O(1)/. The length of the outermost dimension of the vector. For a
-- one-dimensional vector this is its number of elements. Same as 'size'.
length :: (VectorElem a, HasVector s (Vector (BaseEff m) (S n) a) (BaseEff m), MonadVector m) => s -> m Int
length s = size s
{-# INLINE length #-}

-- | /O(n)/. Build a vector from a nested list.
--
-- At each dimension only the first @k@ items of the list are used, where
-- @k@ is that dimension's extent; surplus items are ignored. If an outer
-- list is shorter than its extent, the missing rows are still allocated
-- (just as 'newSized' does), so every slot of every outer dimension holds
-- a valid sub-vector. Element slots that receive no value from the list
-- are left as produced by 'GMV.new'; nothing is written to them.
fromListN
  :: (VectorElem a, MonadVector m)
  => Size (S n) -- ^ sizes of vector
  -> NestedList (S n) a -- ^ nested list
  -> m (Vector (BaseEff m) (S n) a)
fromListN r = liftBase . liftM V . go r
  where
    -- Build one dimension from its list, recursing into the inner dimensions.
    go :: (VectorElem a, PrimMonad m) => Size (S n) -> NestedList (S n) a -> m (MultiDim m (S n) a)
    go (n :*: One) xs = do
      v <- GMV.new n
      M.forM_ (zip [0..n-1] xs) $ \(i, x) -> GMV.write v i x
      return $ D1 v
    go (n :*: rest@(_ :*: _)) xs = do
      v <- GMV.new n
      -- Pair each index with the supplied row, or with Nothing once the
      -- list runs out. A boxed outer slot must never be left unwritten;
      -- otherwise reading or assigning through it fails later.
      M.forM_ (zip [0..n-1] (map Just xs ++ repeat Nothing)) $ \(i, row) -> do
        w <- maybe (alloc rest) (go rest) row
        GMV.write v i w
      return $ DN v

    -- Allocate a sub-vector with the given extents and no list data.
    -- Outer dimensions are filled recursively; element storage comes from 'GMV.new'.
    alloc :: (VectorElem a, PrimMonad m) => Size (S n) -> m (MultiDim m (S n) a)
    alloc (n :*: One) = liftM D1 $ GMV.new n
    alloc (n :*: rest@(_ :*: _)) = do
      v <- MV.new n
      M.forM_ [0..n-1] $ \i -> alloc rest >>= GMV.write v i
      return $ DN v
{-# INLINE fromListN #-}

-- | /O(n)/. Read the whole vector back into a nested list, outermost
-- dimension first.
--
-- The vector may be passed directly or through anything with a
-- 'HasVector' instance, such as a 'Ref'.
toList :: (VectorElem a, HasVector s (Vector (BaseEff m) (S n) a) (BaseEff m), MonadVector m) => s -> m (NestedList (S n) a)
toList s = liftBase (getVector s >>= \(V dv) -> collect dv)
  where
    -- Read every slot of this dimension, then convert each sub-vector in turn.
    collect :: (VectorElem a, PrimMonad m) => MultiDim m n a -> m (NestedList n a)
    collect (D1 store) = M.mapM (GMV.read store) [0 .. GMV.length store - 1]
    collect (DN rows)  = M.mapM collect =<< M.mapM (MV.read rows) [0 .. MV.length rows - 1]
{-# INLINE toList #-}