{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE Rank2Types            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeFamilies          #-}
-- |
-- Type classes for vectors which are implemented on top of arrays
-- and support in-place mutation. The API is similar to the one used
-- in the @vector@ package.
module Data.Vector.Fixed.Mutable (
    -- * Mutable vectors
    Arity
  , arity
  , Mutable
  , DimM
  , MVector(..)
  , lengthM
  , read
  , write
  , clone
    -- * Creation
  , replicate
  , replicateM
  , generate
  , generateM
    -- * Loops
  , forI
    -- * Immutable vectors
  , IVector(..)
  , index
  , freeze
  , thaw
    -- * Vector API
  , constructVec
  , inspectVec
  ) where

import Control.Applicative  (Const(..))
import Control.Monad.ST
import Control.Monad.Primitive
import Data.Typeable  (Proxy(..))
import GHC.TypeLits
import Data.Vector.Fixed.Cont (Dim,PeanoNum(..),Peano,Arity,Fun(..),Vector(..),ContVec,arity,apply,accum,length)
import Prelude hiding (read,length,replicate)


----------------------------------------------------------------
-- Type classes
----------------------------------------------------------------

-- | Maps an immutable fixed-length vector type @v@ to the type
--   constructor of its mutable counterpart. That constructor takes the
--   state token type (e.g. @'PrimState' m@) followed by the element type.
type family Mutable (v :: * -> *) :: (* -> * -> *)

-- | Type-level length of a mutable vector type.
type family DimM (v :: * -> * -> *) :: Nat

-- | Mutable fixed-length vectors with elements of type @a@. The first
--   type argument of @v@ is the state token ('PrimState' of the monad
--   performing the mutation). Because the length is part of the type
--   ('DimM'), none of these operations needs a runtime length check.
class (Arity (DimM v)) => MVector v a where
  -- | Copy the source vector into the target vector. The two vectors
  --   must not overlap; use 'move' when they might.
  copy
    :: PrimMonad m
    => v (PrimState m) a  -- ^ Target
    -> v (PrimState m) a  -- ^ Source
    -> m ()
  -- | Copy the source vector into the target vector. Unlike 'copy',
  --   the two vectors are allowed to overlap.
  move
    :: PrimMonad m
    => v (PrimState m) a  -- ^ Target
    -> v (PrimState m) a  -- ^ Source
    -> m ()
  -- | Allocate a new vector.
  new :: PrimMonad m => m (v (PrimState m) a)
  -- | Read the element at the given index without a bounds check.
  unsafeRead :: PrimMonad m => v (PrimState m) a -> Int -> m a
  -- | Write the element at the given index without a bounds check.
  unsafeWrite :: PrimMonad m => v (PrimState m) a -> Int -> a -> m ()


-- | Length of a mutable vector. The length is computed from the
--   vector's type ('DimM'), so the argument itself is never evaluated.
lengthM :: forall v s a. (Arity (DimM v)) => v s a -> Int
lengthM _ = arity (Proxy :: Proxy (DimM v))

-- | Create a copy of a mutable vector. The copy is a newly allocated
--   vector, independent of the argument.
--
--   Examples:
--
--   >>> import Control.Monad.ST (runST)
--   >>> import Data.Vector.Fixed (mk3)
--   >>> import Data.Vector.Fixed.Boxed (Vec3)
--   >>> import qualified Data.Vector.Fixed.Mutable as M
--   >>> let x = runST (do { v <- M.replicate 100; v' <- M.clone v; M.write v' 0 2; M.unsafeFreeze v' }) :: Vec3 Int
--   >>> x
--   fromList [2,100,100]
clone :: (PrimMonad m, MVector v a) => v (PrimState m) a -> m (v (PrimState m) a)
{-# INLINE clone #-}
clone v = do
  u <- new
  -- 'move' takes the target first: copy the contents of v into u.
  move u v
  return u

-- | Read the value at the given index, with a bounds check.
--   Calls 'error' if the index is negative or not less than the
--   vector's length.
read :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> m a
{-# INLINE read #-}
read v i
  | i < 0 || i >= lengthM v = error "Data.Vector.Fixed.Mutable.read: index out of range"
  | otherwise               = unsafeRead v i

-- | Write a value at the given index, with a bounds check.
--   Calls 'error' if the index is negative or not less than the
--   vector's length.
write :: (PrimMonad m, MVector v a) => v (PrimState m) a -> Int -> a -> m ()
{-# INLINE write #-}
write v i x
  | i < 0 || i >= lengthM v = error "Data.Vector.Fixed.Mutable.write: index out of range"
  | otherwise               = unsafeWrite v i x


-- | Create a new vector with every element set to the given value.
replicate :: (PrimMonad m, MVector v a) => a -> m (v (PrimState m) a)
{-# INLINE replicate #-}
replicate a = do
  v <- new
  forI v $ \i -> unsafeWrite v i a
  pure v

-- | Create a new vector whose elements are produced by running the
--   given monadic action once per element, in order of increasing
--   index.
replicateM :: (PrimMonad m, MVector v a) => m a -> m (v (PrimState m) a)
{-# INLINE replicateM #-}
replicateM m = do
  v <- new
  forI v $ \i -> unsafeWrite v i =<< m
  pure v

-- | Create a new vector using a function from index to value.
generate :: (PrimMonad m, MVector v a) => (Int -> a) -> m (v (PrimState m) a)
{-# INLINE generate #-}
generate f = do
  v <- new
  forI v $ \i -> unsafeWrite v i $ f i
  pure v

-- | Create a new vector using a monadic function from index to value.
--   The function is run once per index, in order of increasing index.
generateM :: (PrimMonad m, MVector v a) => (Int -> m a) -> m (v (PrimState m) a)
{-# INLINE generateM #-}
generateM f = do
  v <- new
  forI v $ \i -> unsafeWrite v i =<< f i
  pure v

-- | Loop which calls the given action for each index of the vector,
--   from @0@ up to (but not including) its length. Only the vector's
--   length is used; its contents are not read.
forI :: (PrimMonad m, MVector v a) => v (PrimState m) a -> (Int -> m ()) -> m ()
{-# INLINE forI #-}
forI v f = go 0
  where
    go i | i >= n    = pure ()
         | otherwise = f i >> go (i+1)
    n = lengthM v


----------------------------------------------------------------
-- Immutable
----------------------------------------------------------------

-- | Immutable fixed-length vectors backed by a mutable counterpart
--   @'Mutable' v@. The superclass constraints require that the mutable
--   counterpart has the same length as @v@ and is itself an 'MVector'.
class (Dim v ~ DimM (Mutable v), MVector (Mutable v) a) => IVector v a where
  -- | Convert a mutable vector to an immutable one. The mutable vector
  --   must not be modified afterwards.
  unsafeFreeze
    :: PrimMonad m
    => Mutable v (PrimState m) a
    -> m (v a)
  -- | /O(1)/ Unsafely convert an immutable vector to a mutable one
  --   without copying. This is a very dangerous function: in general it
  --   is only safe to read from the resulting vector, in which case the
  --   immutable vector can safely be used as well.
  --
  -- Problems with mutation arise because GHC has a lot of freedom to
  -- introduce sharing. As a result, mutable vectors produced by
  -- @unsafeThaw@ may or may not share the same underlying buffer. For
  -- example:
  --
  -- > foo = do
  -- >   let vec = F.generate 10 id
  -- >   mvec <- M.unsafeThaw vec
  -- >   do_something mvec
  --
  -- Here GHC could lift @vec@ outside of @foo@, which means that all
  -- calls to @do_something@ would use the same buffer, with possibly
  -- disastrous results. Whether such aliasing happens depends on the
  -- program in question, the optimization level, and GHC flags.
  --
  -- All in all, attempts to modify a vector produced by @unsafeThaw@
  -- fall out of the domain of software engineering and into the realm
  -- of black magic, dark rituals, and unspeakable horrors. The only
  -- advice that can be given is: "Don't attempt to mutate a vector
  -- produced by @unsafeThaw@ unless you know how to prevent GHC from
  -- aliasing buffers accidentally. We don't."
  unsafeThaw
    :: PrimMonad m
    => v a
    -> m (Mutable v (PrimState m) a)
  -- | Get the element at the given index without a bounds check.
  unsafeIndex :: v a -> Int -> a

-- | Get the element at the given index, with a bounds check.
--   Calls 'error' if the index is negative or not less than the
--   vector's length.
index :: IVector v a => v a -> Int -> a
{-# INLINE index #-}
index v i
  | i < 0 || i >= length v = error "Data.Vector.Fixed.Mutable.!: index out of bounds"
  | otherwise              = unsafeIndex v i


-- | Safely convert a mutable vector to an immutable one. The argument is
--   copied with 'clone' before being frozen, so it may still be modified
--   afterwards without affecting the result.
freeze :: (PrimMonad m, IVector v a) => Mutable v (PrimState m) a -> m (v a)
{-# INLINE freeze #-}
freeze v = unsafeFreeze =<< clone v

-- | Safely convert an immutable vector to a mutable one. The result of
--   'unsafeThaw' is immediately copied with 'clone', so the returned
--   vector can be modified without affecting the argument.
thaw :: (PrimMonad m, IVector v a) => v a -> m (Mutable v (PrimState m) a)
{-# INLINE thaw #-}
thaw v = clone =<< unsafeThaw v



----------------------------------------------------------------
-- Vector API
----------------------------------------------------------------

-- | Generic 'inspect' implementation for array-based vectors. The
--   elements are passed to the function in order of increasing index.
inspectVec :: forall v a b. (Arity (Dim v), IVector v a) => v a -> Fun (Peano (Dim v)) a b -> b
{-# INLINE inspectVec #-}
inspectVec v
  = inspect cv
  where
    -- Continuation-based view of v. The index of the next element to
    -- read is threaded through 'apply' in a 'Const', starting at 0;
    -- Const's phantom parameter tracks how many elements remain.
    cv :: ContVec (Dim v) a
    cv = apply (\(Const i) -> (unsafeIndex v i, Const (i+1)))
               (Const 0 :: Const Int (Peano (Dim v)))

-- | Generic 'construct' implementation for array-based vectors. A new
--   mutable vector is allocated, each supplied element is written at the
--   next index (starting from 0), and once all elements have been given
--   the vector is frozen with 'unsafeFreeze'.
constructVec :: forall v a. (Arity (Dim v), IVector v a) => Fun (Peano (Dim v)) a (v a)
{-# INLINE constructVec #-}
constructVec =
  accum step
        (\(T_new _ st) -> runST $ unsafeFreeze =<< st :: v a)
        (T_new 0 new :: T_new v a (Peano (Dim v)))

-- | Accumulator used by 'constructVec'. It holds the index at which the
--   next element will be written and an 'ST' action that produces the
--   mutable vector with all writes so far applied. The phantom parameter
--   @n@ counts the elements still to be supplied.
data T_new v a n = T_new Int (forall s. ST s (Mutable v s a))

-- | Consume one element: extend the accumulated action with a write of
--   that element at the current index, then advance the index.
step :: (IVector v a) => T_new v a ('S n) -> a -> T_new v a n
step (T_new i st) x = T_new (i+1) $ do
  mv <- st
  unsafeWrite mv i x
  return mv