-----------------------------------------------------------
-- |
-- Module      : Control.Imperative.Vector.Dynamic
-- Copyright   : (C) 2015, Yu Fukuzawa
-- License     : BSD3
-- Maintainer  : minpou.primer@email.com
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-----------------------------------------------------------

{-# LANGUAGE ConstraintKinds           #-}
{-# LANGUAGE DataKinds                 #-}
{-# LANGUAGE FlexibleContexts          #-}
{-# LANGUAGE FlexibleInstances         #-}
{-# LANGUAGE FunctionalDependencies    #-}
{-# LANGUAGE GADTs                     #-}
{-# LANGUAGE MultiParamTypeClasses     #-}
{-# LANGUAGE ScopedTypeVariables       #-}
{-# LANGUAGE TypeFamilies              #-}
{-# LANGUAGE TypeOperators             #-}
{-# LANGUAGE UndecidableInstances      #-}

module Control.Imperative.Vector.Dynamic
( -- $doc

  -- * Types
  Vector
, MonadVector
, VectorElem
, VectorEntity
, HasVector
, Item
, NestedList
, Size(..)
, Dim(..)
, dim1
, dim2
, dim3
  -- * Operations
, new
, newSized
, newSized'
, Control.Imperative.Vector.Dynamic.length
, size
, fromList
, toList
, push
, pop
, unshift
, shift
) where
import           Control.Imperative.Internal
import           Control.Imperative.Vector.Base
import           Control.Monad                  (liftM)
import qualified Control.Monad                  as M
import           Control.Monad.Base
import           Control.Monad.Primitive        (PrimMonad)
import           Data.Nat
import           Data.Vector.Dynamic

-- $doc
-- An efficient array type with three features:
--
-- * Automatic switching between unboxed and boxed arrays
-- * Multi-dimensional support
-- * Dynamic resizing
--
-- Two basic operations are exported from the "Control.Imperative" module:
--
-- [@ref@] /O(1)/. Return the element of a vector at the given index.
-- [@assign@] /O(1)/ amortized. Replace the element at the given index.
--

-- | A dynamically resizable vector of dimension @n@ whose elements have
-- type @a@ and which is read and modified in the monad @m@.
--
-- It is a thin wrapper around the 'MultiDim' representation.
newtype Vector m n a where
  V :: MultiDim m n a -> Vector m n a

-- | Sources from which a 'Vector' can be obtained in the monad @m@.
--
-- The source type @s@ alone determines both the vector type @v@ and the
-- monad @m@.
class Monad m => HasVector s v m | s -> v m where
  -- | Retrieve the vector represented by the source.
  getVector :: s -> m v

-- | A vector is its own source: 'getVector' hands it back unchanged.
instance Monad m => HasVector (Vector m n a) (Vector m n a) m where
  getVector vec = return vec
  {-# INLINE getVector #-}

-- | A reference to a vector yields the vector it holds at the time
-- 'getVector' runs.
instance Monad m => HasVector (Ref m (Vector m n a)) (Vector m n a) m where
  getVector ref = get ref
  {-# INLINE getVector #-}

-- | Backing representation of a 'Vector' of dimension @n@.
--
-- * 'D1': a one-dimensional vector, storing its elements directly in a
--   'DynamicVector'.
-- * 'DN': a vector of dimension two or more, storing a 'DynamicVector'
--   whose items are 'MultiDim' values one dimension lower.
data MultiDim m (n :: Nat) a where
  D1 :: forall m a. DynamicVector m a -> MultiDim m (S Z) a
  DN :: forall m n a. DynamicVector m (MultiDim m (S n) a) -> MultiDim m (S (S n)) a

-- | Indexing a one-dimensional vector yields a 'Ref' to the element at
-- the given position. Reads use 'readDyn' and writes use 'writeDyn' on
-- the backing 'DynamicVector'.
instance (VectorElem a, PrimMonad m) => Indexable (Vector m (S Z) a) where
  type Element (Vector m (S Z) a)   = Ref m a
  type IndexType (Vector m (S Z) a) = Int
  V (D1 vec) ! ix = Ref
    { set = writeDyn vec ix
    , get = readDyn vec ix
    }
  {-# INLINE (!) #-}

-- | Indexing a vector of dimension two or more yields a 'Ref' to the
-- sub-vector (one dimension lower) stored at the given position.
-- Writing through the 'Ref' replaces that sub-vector in the outer one.
instance PrimMonad m => Indexable (Vector m (S (S n)) a) where
  type Element (Vector m (S (S n)) a) = Ref m (Vector m (S n) a)
  type IndexType (Vector m (S (S n)) a) = Int
  V (DN vec) ! ix = Ref
    { set = \(V row) -> writeDyn vec ix row
    , get = V `liftM` readDyn vec ix
    }
  {-# INLINE (!) #-}

-- | Indexing a reference to a one-dimensional vector yields a 'Ref' to
-- an element. Every read and every write dereferences the outer
-- reference first, so it acts on whichever vector the reference holds
-- at that moment.
instance (VectorElem a, PrimMonad m) => Indexable (Ref m (Vector m (S Z) a)) where
  type Element (Ref m (Vector m (S Z) a)) = Ref m a
  type IndexType (Ref m (Vector m (S Z) a)) = Int
  ref ! ix = Ref
    { set = \val -> get ref >>= \cur -> case cur of V (D1 vec) -> writeDyn vec ix val
    , get = get ref >>= \cur -> case cur of V (D1 vec) -> readDyn vec ix
    }
  {-# INLINE (!) #-}

-- | Indexing a reference to a vector of dimension two or more yields a
-- 'Ref' to a sub-vector. Every read and every write dereferences the
-- outer reference first, so it acts on whichever vector the reference
-- holds at that moment.
instance PrimMonad m => Indexable (Ref m (Vector m (S (S n)) a)) where
  type Element (Ref m (Vector m (S (S n)) a)) = Ref m (Vector m (S n) a)
  type IndexType (Ref m (Vector m (S (S n)) a)) = Int
  ref ! ix = Ref
    { set = \(V row) ->
        get ref >>= \cur -> case cur of V (DN vec) -> writeDyn vec ix row
    , get =
        get ref >>= \cur -> case cur of V (DN vec) -> V `liftM` readDyn vec ix
    }
  {-# INLINE (!) #-}

-- | /O(1)/. Create an empty vector of the given dimension.
--
-- Only the outermost level is allocated, and it starts with length zero.
new
  :: (VectorElem a, MonadVector m, SingNat (S n))
  => proxy (S n) -- ^ dimension
  -> m (Vector (BaseEff m) (S n) a)
new (_ :: proxy (S n)) = liftBase . liftM V $
  case (singNat :: SNat (S n)) of
    SS SZ     -> D1 `liftM` newDyn 0
    SS (SS _) -> DN `liftM` newDyn 0
{-# INLINE new #-}

-- | /O(n)/. Create a vector whose extent in every dimension is given by
-- the 'Size'.
--
-- The head of the 'Size' is the length of the outermost level. For
-- dimension two or more, each outer slot receives its own freshly
-- allocated inner vector, built from the remainder of the 'Size'.
-- Use 'newSized'' to fill the innermost level with an initial value.
newSized :: (VectorElem a, MonadVector m) => Size (S n) -> m (Vector (BaseEff m) (S n) a)
newSized sz = liftBase (liftM V (alloc sz))
  where
    alloc :: (VectorElem a, PrimMonad m) => Size (S n) -> m (MultiDim m (S n) a)
    alloc (len :*: One) = D1 `liftM` newDyn len
    alloc (len :*: rest@(_ :*: _)) = do
      outer <- newDyn len
      M.forM_ [0 .. len - 1] $ \ix -> alloc rest >>= writeDyn outer ix
      return (DN outer)
{-# INLINE newSized #-}

-- | /O(n)/. Create a vector shaped by the 'Size' whose innermost elements
-- all start out as the given value.
--
-- For dimension two or more, each outer slot receives its own freshly
-- allocated inner vector, built from the remainder of the 'Size'.
newSized' :: (VectorElem a, MonadVector m) => Size (S n) -> a -> m (Vector (BaseEff m) (S n) a)
newSized' sz x0 = liftBase (liftM V (fill sz x0))
  where
    fill :: (VectorElem a, PrimMonad m) => Size (S n) -> a -> m (MultiDim m (S n) a)
    fill (len :*: One) x = D1 `liftM` newDyn' len x
    fill (len :*: rest@(_ :*: _)) x = do
      outer <- newDyn len
      M.forM_ [0 .. len - 1] $ \ix -> fill rest x >>= writeDyn outer ix
      return (DN outer)
{-# INLINE newSized' #-}

-- | /O(n)/. Build a vector from a nested list.
--
-- Each level starts empty and is populated by pushing the list items in
-- order; nested lists are converted recursively into inner vectors.
fromList
  :: (VectorElem a, MonadVector m, SingNat (S d))
  => proxy (S d) -- ^ dimension
  -> NestedList (S d) a -- ^ nested list
  -> m (Vector (BaseEff m) (S d) a)
fromList (_ :: proxy (S d)) = liftBase . liftM V . build (singNat :: SNat (S d))
  where
    build :: (PrimMonad f, VectorElem b) => SNat (S n) -> NestedList (S n) b -> f (MultiDim f (S n) b)
    build (SS SZ) elems = do
      vec <- newDyn 0
      M.mapM_ (pushDyn vec) elems
      return (D1 vec)
    build (SS k@(SS _)) rows = do
      vec <- newDyn 0
      M.forM_ rows $ \row -> build k row >>= pushDyn vec
      return (DN vec)
{-# INLINE fromList #-}

-- | /O(n)/. Convert the vector (or the vector a source refers to) into a
-- nested list, with one list level per dimension.
toList :: (VectorElem a, HasVector s (Vector (BaseEff m) (S n) a) (BaseEff m), MonadVector m) => s -> m (NestedList (S n) a)
toList s = liftBase (getVector s >>= \(V dv) -> flatten dv)
  where
    flatten :: (VectorElem a, PrimMonad m) => MultiDim m n a -> m (NestedList n a)
    flatten (D1 v) = toListDyn v
    flatten (DN v) = M.mapM flatten =<< toListDyn v
{-# INLINE toList #-}

-- | /O(1)/. The number of items in the outermost dimension of the vector.
--
-- 'Control.Imperative.Vector.Dynamic.length' is a synonym.
size :: (VectorElem a, HasVector s (Vector (BaseEff m) (S n) a) (BaseEff m), MonadVector m) => s -> m Int
size s = liftBase $ do
  V dv <- getVector s
  case dv of
    D1 v -> sizeDyn v
    DN v -> sizeDyn v
{-# INLINE size #-}

-- | /O(1)/. The number of items in the outermost dimension of the vector.
-- This is a synonym for 'size'.
length :: (VectorElem a, HasVector s (Vector (BaseEff m) (S n) a) (BaseEff m), MonadVector m) => s -> m Int
length s = size s
{-# INLINE length #-}

-- | The type of a single item of a vector. For a one-dimensional vector it
-- is the element type; otherwise it is a sub-vector one dimension lower.
type family Item a where
  Item (Vector f (S Z) e) = e
  Item (Vector f (S (S k)) e) = Vector f (S k) e

-- | /O(1)/. Append an item to the rear of the vector's outermost
-- dimension. For multi-dimensional vectors the item is a sub-vector.
push :: (VectorElem a, HasVector s (Vector (BaseEff m) (S n) a) (BaseEff m), MonadVector m) => s -> Item (Vector (BaseEff m) (S n) a) -> m ()
push s x = liftBase $ do
  V dv <- getVector s
  case dv of
    D1 v -> pushDyn v x
    DN v -> case x of V w -> pushDyn v w
{-# INLINE push #-}

-- | /O(1)/. Remove and return the item at the rear of the vector's
-- outermost dimension. For multi-dimensional vectors the item is a
-- sub-vector.
pop :: (VectorElem a, HasVector s (Vector (BaseEff m) (S n) a) (BaseEff m), MonadVector m) => s -> m (Item (Vector (BaseEff m) (S n) a))
pop s = liftBase $ do
  V dv <- getVector s
  case dv of
    D1 v -> popDyn v
    DN v -> V `liftM` popDyn v
{-# INLINE pop #-}

-- | /O(1)/. Prepend an item to the front of the vector's outermost
-- dimension. For multi-dimensional vectors the item is a sub-vector.
unshift :: (VectorElem a, HasVector s (Vector (BaseEff m) (S n) a) (BaseEff m), MonadVector m) => s -> Item (Vector (BaseEff m) (S n) a) -> m ()
unshift s x = liftBase $ do
  V dv <- getVector s
  case dv of
    D1 v -> unshiftDyn v x
    DN v -> case x of V w -> unshiftDyn v w
{-# INLINE unshift #-}

-- | /O(1)/. Remove and return the item at the front of the vector's
-- outermost dimension. For multi-dimensional vectors the item is a
-- sub-vector.
shift :: (VectorElem a, HasVector s (Vector (BaseEff m) (S n) a) (BaseEff m), MonadVector m) => s -> m (Item (Vector (BaseEff m) (S n) a))
shift s = liftBase $ do
  V dv <- getVector s
  case dv of
    D1 v -> shiftDyn v
    DN v -> V `liftM` shiftDyn v
{-# INLINE shift #-}