{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{- |
Loading and storing Haskell values through pointers in generated LLVM code.

Element sizes, pointer strides and record field offsets
are taken from the 'Store.Storable' instances
('Store.sizeOf', 'Store.alignment').
-}
module LLVM.Extra.Storable.Private where

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.ArithmeticPrivate as A
import qualified LLVM.Util.Proxy as LP
import qualified LLVM.Core as LLVM
import LLVM.Core (CodeGenFunction, Value)

import qualified Type.Data.Num.Decimal as TypeNum

import qualified Control.Monad.Trans.Class as MT
import qualified Control.Monad.Trans.Reader as MR
import qualified Control.Monad.Trans.State as MS
import qualified Control.Applicative.HT as App
import qualified Control.Functor.HT as FuncHT
import Control.Monad (foldM, replicateM, replicateM_, (<=<))
import Control.Applicative (Applicative, pure)

import qualified Foreign.Storable.Record.Tuple as StoreTuple
import qualified Foreign.Storable as Store
import Foreign.Storable.FixedArray (roundUp)
import Foreign.Ptr (Ptr)

import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import Data.Orphans ()
import Data.Complex (Complex)
import Data.Word (Word8, Word16, Word32, Word64, Word)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Bool8 (Bool8)


{- |
Types whose values can be loaded from and stored to memory
by generated code.
-}
class
   (Store.Storable a, Tuple.Value a,
    Tuple.Phi (Tuple.ValueOf a), Tuple.Undefined (Tuple.ValueOf a)) =>
      C a where
   {-
   Not all Storable types have a compatible LLVM type,
   let alone one LLVM type that is compatible on all platforms.
   -}
   -- | Load the value the pointer points to.
   load :: Value (Ptr a) -> CodeGenFunction r (Tuple.ValueOf a)
   -- | Store a value at the location the pointer points to.
   store :: Tuple.ValueOf a -> Value (Ptr a) -> CodeGenFunction r ()

-- | Store a value and return the pointer advanced by one element.
storeNext ::
   (C a, Tuple.ValueOf a ~ al, Value (Ptr a) ~ ptr) =>
   al -> ptr -> CodeGenFunction r ptr
storeNext a ptr = store a ptr >> incrementPtr ptr

{- |
Load a value, transform it with the given code generator
and store the result back to the same location.
-}
modify ::
   (C a, Tuple.ValueOf a ~ al) =>
   (al -> CodeGenFunction r al) -> Value (Ptr a) -> CodeGenFunction r ()
modify f ptr = flip store ptr =<< f =<< load ptr

-- | Like 'load' but wraps the result in a 'MultiValue.T'.
loadMultiValue :: (C a) => Value (Ptr a) -> CodeGenFunction r (MultiValue.T a)
loadMultiValue ptr = fmap MultiValue.Cons $ load ptr

-- | Like 'store' but takes a 'MultiValue.T'.
storeMultiValue ::
   (C a) => MultiValue.T a -> Value (Ptr a) -> CodeGenFunction r ()
storeMultiValue (MultiValue.Cons a) ptr = store a ptr

-- | Like 'storeNext' but takes a 'MultiValue.T'.
storeNextMultiValue ::
   (C a) =>
   MultiValue.T a -> Value (Ptr a) -> CodeGenFunction r (Value (Ptr a))
storeNextMultiValue (MultiValue.Cons a) ptr = store a ptr >> incrementPtr ptr

-- | Like 'modify' but the transformation works on 'MultiValue.T'.
modifyMultiValue ::
   (C a) =>
   (MultiValue.T a -> CodeGenFunction r (MultiValue.T a)) ->
   Value (Ptr a) -> CodeGenFunction r ()
modifyMultiValue f ptr = flip storeMultiValue ptr =<< f =<< loadMultiValue ptr


{- |
Load a value whose LLVM representation is a single primitive value.
The 'LLVM.bitcast' only adapts the pointer type to the one 'LLVM.load' expects.
-}
loadPrimitive ::
   (LLVM.Storable a) => Value (Ptr a) -> CodeGenFunction r (Value a)
loadPrimitive ptr = LLVM.load =<< LLVM.bitcast ptr

-- | Counterpart to 'loadPrimitive'.
storePrimitive ::
   (LLVM.Storable a) => Value a -> Value (Ptr a) -> CodeGenFunction r ()
storePrimitive a ptr = LLVM.store a =<< LLVM.bitcast ptr

instance C Float  where load = loadPrimitive; store = storePrimitive
instance C Double where load = loadPrimitive; store = storePrimitive
instance C Word   where load = loadPrimitive; store = storePrimitive
instance C Word8  where load = loadPrimitive; store = storePrimitive
instance C Word16 where load = loadPrimitive; store = storePrimitive
instance C Word32 where load = loadPrimitive; store = storePrimitive
instance C Word64 where load = loadPrimitive; store = storePrimitive
instance C Int    where load = loadPrimitive; store = storePrimitive
instance C Int8   where load = loadPrimitive; store = storePrimitive
instance C Int16  where load = loadPrimitive; store = storePrimitive
instance C Int32  where load = loadPrimitive; store = storePrimitive
instance C Int64  where load = loadPrimitive; store = storePrimitive

{- |
Not a very efficient implementation,
because we want to adapt to @sizeOf Bool@ dynamically.
Unfortunately, LLVM-9's optimizer does not recognize the instruction pattern.
Better use 'Bool8' for booleans.

Loading reads all @sizeOf Bool@ bytes, ORs them together
and yields whether the result is non-zero.
Storing sign-extends the boolean to a byte
and writes that byte to each of the @sizeOf Bool@ bytes.
-}
instance C Bool where
   load ptr = do
      bytePtr <- castToBytePtr ptr
      bytes <-
         flip MS.evalStateT bytePtr $
            replicateM (Store.sizeOf (False :: Bool))
               (MT.lift . LLVM.load =<< incPtrState)
      let zero = LLVM.valueOf 0
      mask <- foldM A.or zero bytes
      A.cmp LLVM.CmpNE mask zero
   store b ptr = do
      bytePtr <- castToBytePtr ptr
      byte <- LLVM.sext b
      flip MS.evalStateT bytePtr $
         replicateM_ (Store.sizeOf (False :: Bool))
            (MT.lift . LLVM.store byte =<< incPtrState)

-- | Return the current byte pointer and advance the state by one array element.
incPtrState :: MS.StateT BytePtr (CodeGenFunction r) BytePtr
incPtrState = update A.advanceArrayElementPtr

{- |
A 'Bool8' occupies one byte.
Loading compares that byte with zero; storing zero-extends the boolean to a byte.
-}
instance C Bool8 where
   load ptr =
      A.cmp LLVM.CmpNE (LLVM.valueOf 0) =<< LLVM.load =<< castToBytePtr ptr
   store b ptr = do
      byte <- LLVM.zext b
      LLVM.store byte =<< castToBytePtr ptr

{- |
The components are accessed as consecutive elements of type @a@
using 'loadApplicative' and 'storeFoldable'.
-}
instance (C a) => C (Complex a) where
   load = loadApplicative
   store = storeFoldable

instance (Tuple tuple) => C (StoreTuple.Tuple tuple) where
   load = loadTuple
   store = storeTuple

{- |
Tuples whose fields are laid out one after another,
each field placed at the next offset that satisfies its alignment
(see 'elementOffset').
-}
class
   (StoreTuple.Storable tuple, Tuple.Value tuple,
    Tuple.Phi (Tuple.ValueOf tuple), Tuple.Undefined (Tuple.ValueOf tuple)) =>
      Tuple tuple where
   loadTuple ::
      Value (Ptr (StoreTuple.Tuple tuple)) ->
      CodeGenFunction r (Tuple.ValueOf tuple)
   storeTuple ::
      Tuple.ValueOf tuple -> Value (Ptr (StoreTuple.Tuple tuple)) ->
      CodeGenFunction r ()

instance (C a, C b) => Tuple (a,b) where
   loadTuple ptr =
      runElements ptr $
      App.mapPair (loadElement, loadElement) $
      FuncHT.unzip $ proxyFromElement3 ptr
   storeTuple (a,b) ptr =
      case FuncHT.unzip $ proxyFromElement3 ptr of
         (pa,pb) -> runElements ptr $ storeElement pa a >> storeElement pb b

instance (C a, C b, C c) => Tuple (a,b,c) where
   loadTuple ptr =
      runElements ptr $
      App.mapTriple (loadElement, loadElement, loadElement) $
      FuncHT.unzip3 $ proxyFromElement3 ptr
   storeTuple (a,b,c) ptr =
      case FuncHT.unzip3 $ proxyFromElement3 ptr of
         (pa,pb,pc) ->
            runElements ptr $
            storeElement pa a >> storeElement pb b >> storeElement pc c

{- |
Run a sequence of field accesses on the record the pointer points to.
The reader environment is the record's byte pointer,
the state is the running byte offset, starting at 0.
-}
runElements ::
   Value (Ptr a) ->
   MR.ReaderT BytePtr (MS.StateT Int (CodeGenFunction r)) c ->
   CodeGenFunction r c
runElements ptr act = do
   bytePtr <- castToBytePtr ptr
   flip MS.evalStateT 0 $ flip MR.runReaderT bytePtr act

-- | Load the next field of a record, see 'runElements'.
loadElement ::
   (C a) =>
   LP.Proxy a ->
   MR.ReaderT BytePtr (MS.StateT Int (CodeGenFunction r)) (Tuple.ValueOf a)
loadElement proxy = MT.lift . MT.lift . load =<< elementPtr proxy

-- | Store the next field of a record, see 'runElements'.
storeElement ::
   (C a) =>
   LP.Proxy a -> Tuple.ValueOf a ->
   MR.ReaderT BytePtr (MS.StateT Int (CodeGenFunction r)) ()
storeElement proxy a = MT.lift . MT.lift . store a =<< elementPtr proxy

-- | Pointer to the next field of a record, advancing the offset past it.
elementPtr ::
   (C a) =>
   LP.Proxy a ->
   MR.ReaderT BytePtr (MS.StateT Int (CodeGenFunction r)) (LLVM.Value (Ptr a))
elementPtr proxy = do
   ptr <- MR.ask
   MT.lift $ do
      offset <- elementOffset proxy
      MT.lift $ castFromBytePtr =<< LLVM.getElementPtr ptr (offset, ())

{- |
Reserve space for one element:
round the current byte offset up to the element's alignment,
return that offset and advance the state past the element.
-}
elementOffset ::
   (Monad m, Store.Storable a) => LP.Proxy a -> MS.StateT Int m Int
elementOffset proxy = do
   let dummy = elementFromProxy proxy
   MS.modify (roundUp $ Store.alignment dummy)
   offset <- MS.get
   MS.modify (+ Store.sizeOf dummy)
   return offset


{- |
A vector is accessed as consecutive elements of type @a@.
These are assembled into, or disassembled from, the vector value.
-}
instance
   (TypeNum.Positive n, Vector a, Tuple.VectorValue n a,
    Tuple.Phi (Tuple.VectorValueOf n a)) =>
      C (LLVM.Vector n a) where
   load ptr = assembleVector (proxyFromElement3 ptr) =<< loadApplicative ptr
   store a ptr =
      flip storeFoldable ptr =<< disassembleVector (proxyFromElement3 ptr) a

-- | Element types that can be combined into vector values.
class (C a) => Vector a where
   assembleVector ::
      (TypeNum.Positive n) =>
      LP.Proxy a -> LLVM.Vector n (Tuple.ValueOf a) ->
      CodeGenFunction r (Tuple.VectorValueOf n a)
   disassembleVector ::
      (TypeNum.Positive n) =>
      LP.Proxy a -> Tuple.VectorValueOf n a ->
      CodeGenFunction r (LLVM.Vector n (Tuple.ValueOf a))

{- |
Build an LLVM vector by inserting the elements one by one,
starting from an undefined vector, at indices 0, 1, 2, ...
-}
assemblePrimitive ::
   (TypeNum.Positive n, LLVM.IsPrimitive a) =>
   LLVM.Vector n (Value a) -> CodeGenFunction r (Value (LLVM.Vector n a))
assemblePrimitive =
   foldM (\v (i,x) -> LLVM.insertelement v x (LLVM.valueOf i))
      (LLVM.value LLVM.undef) .
   zip [0..] . Fold.toList

-- | Extract every element of an LLVM vector, see 'indices'.
disassemblePrimitive ::
   (TypeNum.Positive n, LLVM.IsPrimitive a) =>
   Value (LLVM.Vector n a) -> CodeGenFunction r (LLVM.Vector n (Value a))
disassemblePrimitive v =
   Trav.mapM (LLVM.extractelement v . LLVM.valueOf) indices

{- |
The numbers 0, 1, 2, ... arranged in the shape that 'pure' gives,
numbered in traversal order.
-}
indices :: (Applicative f, Trav.Traversable f) => f Word32
indices =
   flip MS.evalState 0 $ Trav.sequenceA $ pure $ MS.state (\k -> (k,k+1))

instance Vector Float where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Double where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Word where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Word8 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Word16 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Word32 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Word64 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Int where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Int8 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Int16 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Int32 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Int64 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Bool where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance Vector Bool8 where
   assembleVector LP.Proxy = assemblePrimitive
   disassembleVector LP.Proxy = disassemblePrimitive

instance
   (Tuple tuple, TupleVector tuple) =>
      Vector (StoreTuple.Tuple tuple) where
   assembleVector = deinterleave . fmap StoreTuple.getTuple
   disassembleVector = interleave . fmap StoreTuple.getTuple

{- |
Conversion between a vector of tuples
and the per-field vector representation of a tuple type.
-}
class TupleVector a where
   deinterleave ::
      (TypeNum.Positive n) =>
      LP.Proxy a -> LLVM.Vector n (Tuple.ValueOf a) ->
      CodeGenFunction r (Tuple.VectorValueOf n a)
   interleave ::
      (TypeNum.Positive n) =>
      LP.Proxy a -> Tuple.VectorValueOf n a ->
      CodeGenFunction r (LLVM.Vector n (Tuple.ValueOf a))

instance (Vector a, Vector b) => TupleVector (a,b) where
   deinterleave = FuncHT.uncurry $ \pa pb ->
      FuncHT.uncurry $ \a b ->
         App.lift2 (,) (assembleVector pa a) (assembleVector pb b)
   interleave = FuncHT.uncurry $ \pa pb (a,b) ->
      App.lift2 (App.lift2 (,))
         (disassembleVector pa a) (disassembleVector pb b)

instance (Vector a, Vector b, Vector c) => TupleVector (a,b,c) where
   deinterleave = FuncHT.uncurry3 $ \pa pb pc ->
      FuncHT.uncurry3 $ \a b c ->
         App.lift3 (,,)
            (assembleVector pa a) (assembleVector pb b) (assembleVector pc c)
   interleave = FuncHT.uncurry3 $ \pa pb pc (a,b,c) ->
      App.lift3 (App.lift3 (,,))
         (disassembleVector pa a)
         (disassembleVector pb b)
         (disassembleVector pc c)


{-
instance Storable () is available since base-4.9/GHC-8.0.
For older versions we need the orphan instance from Data.Orphans.
-}
-- | The unit value occupies no data: loading and storing generate no code.
instance C () where
   load _ptr = return ()
   store () _ptr = return ()


{- |
Load a value of a wrapped type by loading the underlying value
and converting it with @wrapl@.
The function @wrap@ is never called; it only fixes the types.
-}
loadNewtype ::
   (C a, Tuple.ValueOf a ~ al) =>
   (a -> wrapped) ->
   (al -> wrappedl) ->
   Value (Ptr wrapped) -> CodeGenFunction r wrappedl
loadNewtype wrap wrapl = fmap wrapl . load <=< rmapPtr wrap

{- |
Store a value of a wrapped type by converting it with @unwrapl@
and storing the underlying value.
The function @wrap@ is never called; it only fixes the types.
-}
storeNewtype ::
   (C a, Tuple.ValueOf a ~ al) =>
   (a -> wrapped) ->
   (wrappedl -> al) ->
   wrappedl -> Value (Ptr wrapped) -> CodeGenFunction r ()
storeNewtype wrap unwrapl y = store (unwrapl y) <=< rmapPtr wrap

{- |
Reinterpret a pointer to @b@ as a pointer to @a@.
The function argument is only used for type inference.
-}
rmapPtr :: (a -> b) -> Value (Ptr b) -> CodeGenFunction r (Value (Ptr a))
rmapPtr _f = LLVM.bitcast


{- |
Load consecutive elements into a structure
whose shape is given by 'NonEmptyC.repeat'.
-}
loadTraversable ::
   (NonEmptyC.Repeat f, Trav.Traversable f, C a, Tuple.ValueOf a ~ al) =>
   Value (Ptr (f a)) -> CodeGenFunction r (f al)
loadTraversable =
   (MS.evalStateT $ Trav.sequence $ NonEmptyC.repeat $ loadState)
   <=<
   castElementPtr

-- | Load consecutive elements into a structure whose shape is given by 'pure'.
loadApplicative ::
   (Applicative f, Trav.Traversable f, C a, Tuple.ValueOf a ~ al) =>
   Value (Ptr (f a)) -> CodeGenFunction r (f al)
loadApplicative =
   (MS.evalStateT $ Trav.sequence $ pure loadState)
   <=<
   castElementPtr

-- | Load the element at the current pointer and advance the pointer.
loadState ::
   (C a, Tuple.ValueOf a ~ al) =>
   MS.StateT (Value (Ptr a)) (CodeGenFunction r) al
loadState = MT.lift . load =<< advancePtrState

-- | Store all elements of a structure to consecutive memory locations.
storeFoldable ::
   (Fold.Foldable f, C a, Tuple.ValueOf a ~ al) =>
   f al -> Value (Ptr (f a)) -> CodeGenFunction r ()
storeFoldable xs = MS.evalStateT (Fold.mapM_ storeState xs) <=< castElementPtr

-- | Store an element at the current pointer and advance the pointer.
storeState ::
   (C a, Tuple.ValueOf a ~ al) =>
   al -> MS.StateT (Value (Ptr a)) (CodeGenFunction r) ()
storeState a = MT.lift . store a =<< advancePtrState

-- | Replace the state by the result of a monadic function and return the old state.
update :: (Monad m) => (a -> m a) -> MS.StateT a m a
update f = MS.StateT $ \a0 -> do a1 <- f a0; return (a0,a1)

-- | Return the current pointer and advance the state by one element.
advancePtrState ::
   (C a, Value (Ptr a) ~ ptr) => MS.StateT ptr (CodeGenFunction r) ptr
advancePtrState = update $ advancePtrStatic 1

-- | Advance a pointer by a number of elements that is only known at runtime.
advancePtr ::
   (Store.Storable a, Value (Ptr a) ~ ptr) =>
   Value Int -> ptr -> CodeGenFunction r ptr
advancePtr n ptr = do
   size <- A.mul n $ LLVM.valueOf $ Store.sizeOf (elementFromPtr ptr)
   addPointer size ptr

-- | Advance a pointer by a number of elements that is known at compile time.
advancePtrStatic ::
   (Store.Storable a, Value (Ptr a) ~ ptr) =>
   Int -> ptr -> CodeGenFunction r ptr
advancePtrStatic n ptr =
   addPointer (LLVM.valueOf (Store.sizeOf (elementFromPtr ptr) * n)) ptr

-- | Advance a pointer by one element.
incrementPtr ::
   (Store.Storable a, Value (Ptr a) ~ ptr) => ptr -> CodeGenFunction r ptr
incrementPtr = advancePtrStatic 1

-- | Move a pointer back by one element.
decrementPtr ::
   (Store.Storable a, Value (Ptr a) ~ ptr) => ptr -> CodeGenFunction r ptr
decrementPtr = advancePtrStatic (-1)

-- | Add a byte offset to a pointer.
addPointer :: Value Int -> Value (Ptr a) -> CodeGenFunction r (Value (Ptr a))
addPointer k ptr = do
   bytePtr <- castToBytePtr ptr
   castFromBytePtr =<< LLVM.getElementPtr bytePtr (k, ())


-- | Pointer used for byte-granular address arithmetic.
type BytePtr = Value (LLVM.Ptr Word8)

castToBytePtr :: Value (Ptr a) -> CodeGenFunction r BytePtr
castToBytePtr = LLVM.bitcast

castFromBytePtr :: BytePtr -> CodeGenFunction r (Value (Ptr a))
castFromBytePtr = LLVM.bitcast

-- | Reinterpret a pointer to a container as a pointer to its first element.
castElementPtr :: Value (Ptr (f a)) -> CodeGenFunction r (Value (Ptr a))
castElementPtr = LLVM.bitcast


-- | Size in bytes of the type designated by the proxy.
sizeOf :: (Store.Storable a) => LP.Proxy a -> Int
sizeOf = Store.sizeOf . elementFromProxy

{- |
Placeholder element, used only as a type witness
for 'Store.sizeOf' and 'Store.alignment'.
It must never be evaluated.
-}
elementFromPtr :: LLVM.Value (Ptr a) -> a
elementFromPtr _ = error "elementFromPtr"

{- |
Placeholder element, used only as a type witness
for 'Store.sizeOf' and 'Store.alignment'.
It must never be evaluated.
-}
elementFromProxy :: LP.Proxy a -> a
elementFromProxy LP.Proxy = error "elementFromProxy"

-- | Proxy for the element type nested two type constructors deep.
proxyFromElement2 :: f (g a) -> LP.Proxy a
proxyFromElement2 _ = LP.Proxy

{- |
Proxy for the element type nested three type constructors deep,
e.g. @a@ in @Value (Ptr (f a))@.
-}
proxyFromElement3 :: f (g (h a)) -> LP.Proxy a
proxyFromElement3 _ = LP.Proxy