{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fno-warn-unused-matches #-}

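-- | Haskell bindings to C++ @std::vector@, with per-element-type instances
-- generated through inline-c and Template Haskell.
--
-- Original author \@chpatrick: https://github.com/fpco/inline-c/blob/1ba35141e330981fef0457a1619701b8acc32f0b/inline-c-cpp/test/StdVector.hs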
module Hercules.CNix.Std.Vector
  ( stdVectorCtx,
    instanceStdVector,
    instanceStdVectorCopyable,
    CStdVector,
    StdVector (StdVector),
    Hercules.CNix.Std.Vector.new,
    size,
    toVector,
    toVectorP,
    toListP,
    toListFP,
    Hercules.CNix.Std.Vector.toList,
    Hercules.CNix.Std.Vector.fromList,
    fromListFP,
    pushBack,
    pushBackP,
    pushBackFP,
  )
where

import Control.Exception (mask_)
import Data.Coerce (Coercible, coerce)
import Data.Foldable
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import Foreign
import Foreign.C
import Hercules.CNix.Encapsulation
import qualified Language.C.Inline as C
import qualified Language.C.Inline.Cpp as C
import qualified Language.C.Inline.Unsafe as CU
import Language.Haskell.TH
import Prelude

-- | Phantom Haskell stand-in for a C++ @std::vector@ whose elements
-- correspond to the Haskell type @a@. It has no constructors; in this module
-- it only ever appears behind 'Ptr' / 'ForeignPtr'.
data CStdVector a

-- | inline-c context for this module: the standard C++ context extended
-- with a type pair that maps the C++ name @std::vector@ to the phantom
-- Haskell type 'CStdVector'.
stdVectorCtx :: C.Context
stdVectorCtx = mappend C.cppCtx vectorTypePairs
  where
    vectorTypePairs = C.cppTypePairs [("std::vector", [t|CStdVector|])]

-- | Haskell handle to a heap-allocated C++ @std::vector@.
--
-- Values produced by 'new' and by 'moveToForeignPtrWrapper' in this module
-- carry 'cDelete' as the 'ForeignPtr' finalizer, so the vector is released
-- when the handle is garbage collected. The constructor is exported;
-- NOTE(review): values built by hand with it may use a different finalizer.
newtype StdVector a = StdVector (ForeignPtr (CStdVector a))

-- | Take ownership of a raw vector pointer by attaching 'cDelete' as its
-- finalizer. With instances from 'instanceStdVector', that finalizer runs
-- C++ @delete@ on the vector.
instance HasStdVector a => HasEncapsulation (CStdVector a) (StdVector a) where
  moveToForeignPtrWrapper rawVec = do
    owned <- newForeignPtr cDelete rawVec
    pure (StdVector owned)

-- | Element types @a@ for which basic @std::vector@ operations exist.
-- Instances are meant to be generated with 'instanceStdVector'; the method
-- descriptions below follow what that generator produces.
class HasStdVector a where
  -- | Heap-allocate a new, empty vector.
  cNew :: IO (Ptr (CStdVector a))
  -- | Function pointer that deletes a vector; used as the 'ForeignPtr'
  -- finalizer for 'StdVector'.
  cDelete :: FunPtr (Ptr (CStdVector a) -> IO ())
  -- | Number of elements in the vector.
  cSize :: Ptr (CStdVector a) -> IO CSize
  -- | For each element, in order, write a pointer to a fresh heap copy
  -- (@new T(item)@) into the destination buffer. The buffer must have room
  -- for 'cSize' pointers; the caller takes ownership of the copies.
  cCopies :: Ptr (CStdVector a) -> Ptr (Ptr a) -> IO ()
  -- | Append a copy of the element the first argument points to.
  cPushBackByPtr :: Ptr a -> Ptr (CStdVector a) -> IO ()

-- | Element types that can be moved between Haskell and the vector by
-- value. Instances are meant to be generated with
-- 'instanceStdVectorCopyable'.
-- NOTE(review): 'toVector' copies elements straight into a 'Storable'
-- buffer, which presumably requires the Haskell 'Storable' layout to match
-- the C++ element type — confirm when adding instances.
class HasStdVector a => HasStdVectorCopyable a where
  -- | Copy every element (via @std::copy@) into the destination buffer,
  -- which must have room for 'cSize' elements.
  cCopyTo :: Ptr (CStdVector a) -> Ptr a -> IO ()
  -- | Append a value, passed by value to @push_back@.
  cPushBack :: a -> Ptr (CStdVector a) -> IO ()

-- | Helper for defining templated instances.
--
-- Emits @#include <vector>@ and @#include <algorithm>@, then runs the
-- declaration quote @d@ with two inline-c substitutions in effect:
--
--   * @\@T()@ expands to the C++ element type @cType@;
--   * @\@VEC(x)@ expands to the antiquote @$(std::vector<cType>* x)@,
--     i.e. the Haskell variable @x@ viewed as a vector pointer.
roll :: String -> Q [Dec] -> Q [Dec]
roll cType d = do
  includeDecs <- traverse C.include ["<vector>", "<algorithm>"]
  templatedDecs <- C.substitute substitutions d
  pure (concat includeDecs ++ templatedDecs)
  where
    substitutions =
      [ ("T", const cType),
        ("VEC", \var -> "$(std::vector<" ++ cType ++ ">* " ++ var ++ ")")
      ]

-- | Generate a 'HasStdVector' instance for the C++ element type @cType@.
-- The matching Haskell type is looked up in the inline-c context with
-- 'C.getHaskellType'.
--
-- Generated methods:
--
--   * 'cNew' allocates an empty vector with @new@;
--   * 'cDelete' is a function pointer that @delete@s the vector;
--   * 'cSize' returns @size()@;
--   * 'cCopies' writes a pointer to a fresh @new T(item)@ copy of each
--     element into the destination buffer (caller owns the copies);
--   * 'cPushBackByPtr' appends a copy of the pointed-to element.
instanceStdVector :: String -> DecsQ
instanceStdVector cType =
  roll
    cType
    [d|
      instance HasStdVector $(C.getHaskellType False cType) where
        cNew = [CU.exp| std::vector<@T()>* { new std::vector<@T()>() } |]
        cDelete = [C.funPtr| void deleteStdVector(std::vector<@T()>* vec) { delete vec; } |]
        cSize vec = [CU.exp| size_t { @VEC(vec)->size() } |]

        -- Iterate by const reference. Each element is then copied exactly
        -- once, into its new heap allocation. Iterating with `auto item`
        -- would first copy it into a loop temporary as well.
        cCopies vec dstPtr =
          [CU.block| void {
          const std::vector<@T()>& vec = *@VEC(vec);
          @T()** aim = $(@T()** dstPtr);
          for (const auto& item : vec) {
            *aim = new @T()(item);
            aim++;
          }
        }|]
        cPushBackByPtr ptr vec = [CU.exp| void { @VEC(vec)->push_back(*$(@T() *ptr)) } |]
      |]

-- | Generate a 'HasStdVectorCopyable' instance for the C++ element type
-- @cType@. The Haskell counterpart is looked up with 'C.getHaskellType'. A
-- 'HasStdVector' instance for the same type is required (superclass), and is
-- normally produced by 'instanceStdVector'.
instanceStdVectorCopyable :: String -> DecsQ
instanceStdVectorCopyable cType = roll cType copyableDecls
  where
    -- cCopyTo: std::copy every element into the destination buffer.
    -- cPushBack: push_back a value passed by value.
    copyableDecls =
      [d|
        instance HasStdVectorCopyable $(C.getHaskellType False cType) where
          cCopyTo vec dstPtr =
            [CU.block| void {
            const std::vector<@T()>* vec = @VEC(vec);
            std::copy(vec->begin(), vec->end(), $(@T()* dstPtr));
            } |]
          cPushBack value vec = [CU.exp| void { @VEC(vec)->push_back($(@T() value)) } |]
        |]

-- | Allocate a new, empty vector owned by the Haskell garbage collector.
--
-- Allocation and finalizer registration both happen under 'mask_'. An
-- asynchronous exception therefore cannot arrive between them and leak the
-- raw pointer.
new :: forall a. HasStdVector a => IO (StdVector a)
new = mask_ (cNew @a >>= fmap StdVector . newForeignPtr cDelete)

-- | Current number of elements, converted from the C++ @size_t@.
size :: HasStdVector a => StdVector a -> IO Int
size (StdVector fptr) = do
  count <- withForeignPtr fptr cSize
  pure (fromIntegral count)

-- | Copy all elements, by value, into a freshly allocated storable vector
-- using 'cCopyTo'.
toVector :: (HasStdVectorCopyable a, Storable a) => StdVector a -> IO (VS.Vector a)
toVector stdVec@(StdVector stdVecFPtr) = do
  len <- size stdVec
  buf <- VSM.new len
  withForeignPtr stdVecFPtr $ \src ->
    VSM.unsafeWith buf (cCopyTo src)
  -- buf is never touched again, so freezing without a copy is safe.
  VS.unsafeFreeze buf

-- | For every element, in order, return a pointer to a fresh heap copy made
-- by 'cCopies'.
--
-- The caller owns these copies and is responsible for releasing them.
-- 'toListFP' is the variant that hands them to 'moveToForeignPtrWrapper'.
toVectorP :: HasStdVector a => StdVector a -> IO (VS.Vector (Ptr a))
toVectorP stdVec@(StdVector stdVecFPtr) = do
  len <- size stdVec
  ptrBuf <- VSM.new len
  withForeignPtr stdVecFPtr $ \src ->
    VSM.unsafeWith ptrBuf $ \dst -> cCopies src dst
  VS.unsafeFreeze ptrBuf

-- | Build a new vector holding the list's elements in order, each appended
-- by value with 'pushBack'.
fromList :: HasStdVectorCopyable a => [a] -> IO (StdVector a)
fromList xs = do
  vec <- Hercules.CNix.Std.Vector.new
  traverse_ (pushBack vec) xs
  pure vec

-- | Build a new vector from values that are coercible to @'ForeignPtr' a@,
-- such as newtype wrappers. Each pointee is appended as a copy via
-- 'pushBackFP'; the inputs keep ownership of their own objects.
fromListFP :: (Coercible a' (ForeignPtr a), HasStdVector a) => [a'] -> IO (StdVector a)
fromListFP wrapped = do
  vec <- Hercules.CNix.Std.Vector.new
  mapM_ (pushBackFP vec) wrapped
  pure vec

-- | All elements as a Haskell list, copied by value through 'toVector'.
toList :: (HasStdVectorCopyable a, Storable a) => StdVector a -> IO [a]
toList = fmap VS.toList . toVector

-- | List form of 'toVectorP': pointers to fresh heap copies of every
-- element. The caller owns the pointees.
toListP :: (HasStdVector a) => StdVector a -> IO [Ptr a]
toListP = fmap VS.toList . toVectorP

-- | Copy every element to the heap (via 'toListP' / 'cCopies'), then hand
-- each copy to 'moveToForeignPtrWrapper' so it becomes a managed value.
-- Presumably that attaches a finalizer, as the 'StdVector' instance in this
-- module does; confirm for other 'HasEncapsulation' instances.
--
-- The whole sequence runs under 'mask_', matching 'new'. Without it, an
-- asynchronous exception arriving after the copies exist but before they
-- are all wrapped would leak the remaining raw copies.
toListFP :: (HasEncapsulation a b, HasStdVector a) => StdVector a -> IO [b]
toListFP vec = mask_ $ traverse moveToForeignPtrWrapper =<< toListP vec

-- | Append a value to the end of the vector, passing it by value.
pushBack :: HasStdVectorCopyable a => StdVector a -> a -> IO ()
pushBack (StdVector fptr) value =
  withForeignPtr fptr $ \vecPtr -> cPushBack value vecPtr

-- | Append a copy of the element that the pointer refers to. The pointee
-- itself is not taken over by the vector.
pushBackP :: HasStdVector a => StdVector a -> Ptr a -> IO ()
pushBackP (StdVector fptr) valueP =
  withForeignPtr fptr $ \vecPtr -> cPushBackByPtr valueP vecPtr

-- | Append a copy of the object behind a value coercible to
-- @'ForeignPtr' a@. The foreign pointer is kept alive for the duration of
-- the copy.
pushBackFP :: (Coercible a' (ForeignPtr a), HasStdVector a) => StdVector a -> a' -> IO ()
pushBackFP vec wrapped =
  withForeignPtr (coerce wrapped) $ \elemPtr -> pushBackP vec elemPtr