module Control.Imperative.Vector.Static
(
Vector
, MonadVector
, VectorElem
, VectorEntity
, HasVector
, NestedList
, Size(..)
, Dim(..)
, dim1
, dim2
, dim3
, newSized
, newSized'
, Control.Imperative.Vector.Static.length
, size
, fromListN
, toList
) where
import Control.Imperative.Internal
import Control.Imperative.Vector.Base
import Control.Monad (liftM)
import qualified Control.Monad as M
import Control.Monad.Base
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Nat
import qualified Data.Vector.Generic.Mutable as GMV
import qualified Data.Vector.Mutable as MV
-- | A vector in monad @m@ indexed by its number of dimensions @n@ and its
-- element type @a@. It is a newtype wrapper around 'MultiDim'.
newtype Vector m n a = V (MultiDim m n a)
-- | Sources @s@ from which a vector @v@ can be obtained inside the monad @m@.
--
-- The functional dependencies state that the source type alone determines
-- both the vector type and the monad.
class Monad m => HasVector s v m | s -> v, s -> m where
  -- | Obtain the vector held by (or equal to) the given source.
  getVector :: s -> m v
-- | A vector is trivially its own source: it is returned unchanged.
instance Monad m => HasVector (Vector m n a) (Vector m n a) m where
  getVector vec = return vec
-- | A reference to a vector is a source: the vector is fetched with the
-- reference's 'get' action.
instance Monad m => HasVector (Ref m (Vector m n a)) (Vector m n a) m where
  getVector ref = get ref
-- | Storage behind a 'Vector', indexed by its number of dimensions.
data MultiDim m (n :: Nat) a where
  -- | One dimension: a single mutable vector of elements, whose concrete
  -- representation is chosen by 'VectorEntity'.
  D1
    :: VectorEntity a (PrimState m) a
    -> MultiDim m (S Z) a
  -- | Two or more dimensions: a boxed mutable vector whose entries are
  -- themselves 'MultiDim' values with one dimension fewer.
  DN
    :: MV.MVector (PrimState m) (MultiDim m (S n) a)
    -> MultiDim m (S (S n)) a
-- | Indexing a one-dimensional vector with an 'Int' yields a 'Ref' to the
-- element at that position: 'get' reads the slot and 'set' overwrites it.
instance (VectorElem a, PrimMonad m) => Indexable (Vector m (S Z) a) where
  type Element (Vector m (S Z) a) = Ref m a
  type IndexType (Vector m (S Z) a) = Int
  (!) (V (D1 cells)) ix =
    Ref
      { get = GMV.read cells ix
      , set = \x -> GMV.write cells ix x
      }
-- | Indexing a vector of two or more dimensions yields a 'Ref' to the
-- sub-vector (one dimension fewer) stored at that position. 'get' reads the
-- stored sub-vector and rewraps it; 'set' stores the given sub-vector in the
-- slot.
instance PrimMonad m => Indexable (Vector m (S (S n)) a) where
  type Element (Vector m (S (S n)) a) = Ref m (Vector m (S n) a)
  type IndexType (Vector m (S (S n)) a) = Int
  (!) (V (DN rows)) ix =
    Ref
      { get = MV.read rows ix >>= return . V
      , set = \(V sub) -> MV.write rows ix sub
      }
-- | Indexing a reference to a one-dimensional vector. The reference is
-- dereferenced each time the resulting 'Ref' is read or written, so the
-- access always goes to whatever vector the reference holds at that moment.
instance (VectorElem a, PrimMonad m) => Indexable (Ref m (Vector m (S Z) a)) where
  type Element (Ref m (Vector m (S Z) a)) = Ref m a
  type IndexType (Ref m (Vector m (S Z) a)) = Int
  (!) ref ix =
    Ref
      { get = get ref >>= \(V (D1 cells)) -> GMV.read cells ix
      , set = \new -> get ref >>= \(V (D1 cells)) -> GMV.write cells ix new
      }
-- | Indexing a reference to a vector of two or more dimensions. The
-- reference is dereferenced on every 'get' and 'set' of the resulting 'Ref';
-- the element accessed is the sub-vector stored at the given position.
instance PrimMonad m => Indexable (Ref m (Vector m (S (S n)) a)) where
  type Element (Ref m (Vector m (S (S n)) a)) = Ref m (Vector m (S n) a)
  type IndexType (Ref m (Vector m (S (S n)) a)) = Int
  (!) ref ix =
    Ref
      { get = get ref >>= \(V (DN rows)) -> MV.read rows ix >>= return . V
      , set = \(V sub) -> get ref >>= \(V (DN rows)) -> MV.write rows ix sub
      }
-- | Allocate a vector with the given size in every dimension.
--
-- No initial element value is supplied: leaf vectors are allocated with
-- 'GMV.new'. Use 'newSized'' to fill the vector with a value instead.
newSized :: (VectorElem a, MonadVector m) => Size (S n) -> m (Vector (BaseEff m) (S n) a)
newSized = liftBase . liftM V . go
  where
    go :: (VectorElem a, PrimMonad m) => Size (S n) -> m (MultiDim m (S n) a)
    -- Innermost dimension: a single leaf vector of @n@ elements.
    go (n :*: One) = liftM D1 $ GMV.new n
    -- Outer dimension: a boxed vector of @n@ slots, each filled with a
    -- freshly allocated sub-vector of the remaining dimensions.
    go (n :*: r@(_ :*: _)) = do
      v <- MV.new n
      -- Valid indices are 0 .. n - 1 (the range is empty when n <= 0).
      M.forM_ [0 .. n - 1] $ \i -> do
        w <- go r
        GMV.write v i w
      return $ DN v
-- | Allocate a vector with the given size in every dimension, with every
-- element initialised to the supplied value.
newSized' :: (VectorElem a, MonadVector m) => Size (S n) -> a -> m (Vector (BaseEff m) (S n) a)
newSized' r = liftBase . liftM V . go r
  where
    go :: (VectorElem a, PrimMonad m) => Size (S n) -> a -> m (MultiDim m (S n) a)
    -- Innermost dimension: a leaf vector holding @n@ copies of @x@.
    go (n :*: One) x = liftM D1 $ GMV.replicate n x
    -- Outer dimension: a boxed vector of @n@ slots, each filled with its own
    -- freshly allocated sub-vector (sub-vectors are not shared between slots).
    go (n :*: rest@(_ :*: _)) x = do
      v <- MV.new n
      -- Valid indices are 0 .. n - 1 (the range is empty when n <= 0).
      M.forM_ [0 .. n - 1] $ \i -> do
        w <- go rest x
        GMV.write v i w
      return $ DN v
-- | Number of entries along the outermost dimension of a vector.
--
-- For a one-dimensional vector this is its element count; for a vector with
-- more dimensions it is the number of sub-vectors it holds.
size
  :: (VectorElem a, HasVector s (Vector (BaseEff m) (S n) a) (BaseEff m), MonadVector m)
  => s
  -> m Int
size s = liftBase $ do
  V dims <- getVector s
  return $ case dims of
    D1 leaf -> GMV.length leaf
    DN rows -> MV.length rows
-- | Synonym for 'size': the number of entries along the outermost dimension.
length
  :: (VectorElem a, HasVector s (Vector (BaseEff m) (S n) a) (BaseEff m), MonadVector m)
  => s
  -> m Int
length s = size s
-- | Build a vector of the given size from a nested list.
--
-- At every level the list is paired with the indices @0 .. n - 1@ using
-- 'zip', so list entries beyond the declared size are ignored, and when the
-- list is shorter than the declared size the remaining slots are never
-- written and keep whatever the allocation left there.
fromListN
  :: (VectorElem a, MonadVector m)
  => Size (S n)
  -> NestedList (S n) a
  -> m (Vector (BaseEff m) (S n) a)
fromListN r = liftBase . liftM V . go r
  where
    go :: (VectorElem a, PrimMonad m) => Size (S n) -> NestedList (S n) a -> m (MultiDim m (S n) a)
    -- Innermost dimension: copy the list elements into a leaf vector.
    go (n :*: One) xs = do
      v <- GMV.new n
      M.forM_ (zip [0 .. n - 1] xs) $ \(i, x) -> GMV.write v i x
      return $ D1 v
    -- Outer dimension: build one sub-vector per sub-list and store it.
    go (n :*: rest@(_ :*: _)) xs = do
      v <- MV.new n
      M.forM_ (zip [0 .. n - 1] xs) $ \(i, ys) -> do
        w <- go rest ys
        GMV.write v i w
      return $ DN v
-- | Read the whole vector back into a nested list, one list level per
-- dimension, with entries in index order.
toList
  :: (VectorElem a, HasVector s (Vector (BaseEff m) (S n) a) (BaseEff m), MonadVector m)
  => s
  -> m (NestedList (S n) a)
toList s = liftBase $ getVector s >>= \(V dv) -> go dv
  where
    go :: (VectorElem a, PrimMonad m) => MultiDim m n a -> m (NestedList n a)
    -- Leaf: read every element at indices 0 .. length - 1 (empty range for
    -- an empty vector).
    go (D1 v) = M.forM [0 .. GMV.length v - 1] (GMV.read v)
    -- Outer dimension: read every sub-vector, then convert each recursively.
    go (DN v) = M.forM [0 .. MV.length v - 1] (MV.read v) >>= M.mapM go