{-# LANGUAGE TypeFamilies #-}
module Data.Array.Knead.Simple.PhysicalPrivate where

import qualified Data.Array.Knead.Simple.Private as Sym
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Code (getElementPtr)

import LLVM.DSL.Expression (Exp, unExp)

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Storable as Storable
import qualified LLVM.Extra.Control as C
import qualified LLVM.Core as LLVM

import Foreign.Ptr (Ptr, )

import qualified Control.Applicative.HT as App
import Control.Monad.HT (void, )
import Control.Applicative ((<$>), )
import Data.Tuple.HT (mapSnd, )

import Prelude2010
import Prelude ()


{- |
LLVM pointer to the marshalled structure ('Marshal.Struct')
of a value of type @a@.
-}
type MarshalPtr a = LLVM.Ptr (Marshal.Struct a)


{- |
Store the element that @code@ produces for every index of @sh@
consecutively, starting at @ptr@, in the iteration order of 'Shape.loop'.
The result is the pointer returned by the last
'Storable.storeNextMultiValue' step
(presumably the position right after the written elements).
-}
writeArray ::
   (Shape.C sh, Shape.Index sh ~ ix, Storable.C a) =>
   MultiValue.T sh ->
   (MultiValue.T ix -> LLVM.CodeGenFunction r (MultiValue.T a)) ->
   LLVM.Value (Ptr a) ->
   LLVM.CodeGenFunction r (LLVM.Value (Ptr a))
writeArray sh code ptr =
   Shape.loop storeElem sh ptr
   where
      storeElem ix dst = do
         x <- code ix
         Storable.storeNextMultiValue x dst

{- |
Accumulating map over all indices of @n@.
For every index an element is fetched with @code@,
@f@ is applied to the current accumulator and that element,
the second component of the result is stored at the output pointer
and the first component becomes the next accumulator.

Returns the advanced output pointer together with the final accumulator.
-}
mapAccumLLoop ::
   (MultiValue.C acc, Storable.C b, Shape.C sh, Shape.Index sh ~ ix) =>
   (MultiValue.T ix -> LLVM.CodeGenFunction r (MultiValue.T a)) ->
   (Exp acc -> Exp a -> Exp (acc, b)) ->
   MultiValue.T sh ->
   LLVM.Value (Ptr b) ->
   MultiValue.T acc ->
   LLVM.CodeGenFunction r (LLVM.Value (Ptr b), MultiValue.T acc)
mapAccumLLoop code f n yPtr accInit =
   Shape.loop step n (yPtr, accInit)
   where
      step k (dst, acc) = do
         x <- code k
         accY <- Expr.unliftM2 f acc x
         let (accNext, y) = MultiValue.unzip accY
         dstNext <- Storable.storeNextMultiValue y dst
         return (dstNext, accNext)

{- |
Row-wise accumulating map over an array of shape @(sh, n)@.
For every index of @sh@, the accumulator starts from the corresponding
element of the initial-accumulator array and is threaded along @n@.
The mapped values are written consecutively to @ptr@.
The accumulators left over at the end of each row are discarded.
The shape is read from @sptr@ via 'Shape.load'.
-}
mapAccumLSimple ::
   (Shape.C sh, Marshal.MV sh, Shape.C n, Marshal.MV n,
    MultiValue.C acc, Storable.C x, Storable.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   Sym.Array sh acc ->
   Sym.Array (sh, n) x ->
   LLVM.Value (MarshalPtr (sh,n)) ->
   LLVM.Value (Ptr y) ->
   LLVM.CodeGenFunction r ()
mapAccumLSimple f (Sym.Array _ initCode) (Sym.Array esh code) sptr ptr = do
   (sh, n) <- MultiValue.unzip <$> Shape.load esh sptr
   let processRow ix rowPtr = do
          acc0 <- initCode ix
          fst <$> mapAccumLLoop (code . MultiValue.zip ix) f n rowPtr acc0
   void $ Shape.loop processRow sh ptr

{- |
Accumulating map over a single array of shape @n@,
starting with the accumulator @initExp@.
Mapped values go to @yPtr@.
@final@ applied to the last accumulator is stored at @accPtr@.
-}
mapAccumLSequence ::
   (Shape.C n, Marshal.MV n,
    MultiValue.C acc, Storable.C final, Storable.C x, Storable.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   (Exp acc -> Exp final) ->
   Exp acc ->
   Sym.Array n x ->
   LLVM.Value (Ptr final) ->
   LLVM.Value (MarshalPtr n) ->
   LLVM.Value (Ptr y) ->
   LLVM.CodeGenFunction r ()
mapAccumLSequence f final initExp (Sym.Array esh code) accPtr sptr yPtr = do
   n <- Shape.load esh sptr
   accStart <- unExp initExp
   accEnd <- snd <$> mapAccumLLoop code f n yPtr accStart
   result <- Expr.unliftM1 final accEnd
   Storable.storeMultiValue result accPtr

{- |
Like 'mapAccumLSimple', but the last accumulator of every row is kept.
@final@ is applied to it, and the results are stored consecutively at the
data pointer of the first pointer pair, one per index of @sh@.
The shape pointer of that first pair is not used here.
-}
mapAccumL ::
   (Shape.C sh, Marshal.MV sh, Shape.C n, Marshal.MV n,
    MultiValue.C acc, Storable.C final, Storable.C x, Storable.C y) =>
   (Exp acc -> Exp x -> Exp (acc,y)) ->
   (Exp acc -> Exp final) ->
   Sym.Array sh acc ->
   Sym.Array (sh, n) x ->
   (LLVM.Value (MarshalPtr sh), LLVM.Value (Ptr final)) ->
   (LLVM.Value (MarshalPtr (sh,n)), LLVM.Value (Ptr y)) ->
   LLVM.CodeGenFunction r ()
mapAccumL f final (Sym.Array _ initCode) (Sym.Array esh code)
      (_, accPtr) (sptr, yPtr) = do
   (sh, n) <- MultiValue.unzip <$> Shape.load esh sptr
   let processRow ix (accDst, yDst) = do
          acc0 <- initCode ix
          (yDstNext, accEnd) <-
             mapAccumLLoop (code . MultiValue.zip ix) f n yDst acc0
          result <- Expr.unliftM1 final accEnd
          accDstNext <- Storable.storeNextMultiValue result accDst
          return (accDstNext, yDstNext)
   void $ Shape.loop processRow sh (accPtr, yPtr)

{- |
Left fold along the outer dimension @n@ of an @(n, sh)@ array.

First the result buffer at @ptr@ is filled from the initial array.
Then, for each outer index in turn, every stored element is replaced
in place by @f@ applied to it and the matching input element.

The inner shape is read from @sptr@.
The outer extent is obtained by evaluating the input's shape expression.
-}
foldOuterL ::
   (Shape.C sh, Marshal.MV sh, Shape.C n, Marshal.MV n, Storable.C a) =>
   (Exp a -> Exp b -> Exp a) ->
   Sym.Array sh a ->
   Sym.Array (n,sh) b ->
   LLVM.Value (MarshalPtr sh) ->
   LLVM.Value (Ptr a) ->
   LLVM.CodeGenFunction r ()
foldOuterL f (Sym.Array _ initCode) (Sym.Array esh code) sptr ptr = do
   sh <- Shape.load (Expr.snd esh) sptr
   n <- MultiValue.fst <$> unExp esh
   void $ writeArray sh initCode ptr
   let update k ix dst = do
          b <- code (MultiValue.zip k ix)
          a0 <- Storable.loadMultiValue dst
          a1 <- Expr.unliftM2 f a0 b
          Storable.storeNextMultiValue a1 dst
       -- one full in-place pass over the result buffer per outer index
       sweep k () = void $ Shape.loop (update k) sh ptr
   void $ Shape.loop sweep n ()

{- |
Keep the elements that satisfy @p@, transform them with @f@,
and store them contiguously starting at @ptr@.
Returns the shape of the resulting sequence,
derived from the number of kept elements.

We need a scalar Shape type @n@.
Scalar Shape types can be distinguished from other Shape types
by the fact that any Index can be converted into a Shape.
-}
mapFilter ::
   (Shape.Sequence n, Marshal.MV n, Storable.C b) =>
   (Exp a -> Exp b) ->
   (Exp a -> Exp Bool) ->
   Sym.Array n a ->
   LLVM.Value (MarshalPtr n) ->
   LLVM.Value (Ptr b) ->
   LLVM.CodeGenFunction r (MultiValue.T n)
mapFilter f p (Sym.Array esh code) sptr ptr = do
   n <- Shape.load esh sptr
   let visit ix (dst, cnt) = do
          a <- code ix
          MultiValue.Cons keep <- Expr.unliftM1 p a
          -- only advance the output pointer and the counter on a hit
          C.ifThen keep (dst, cnt) $
             App.lift2 (,)
                (Expr.unliftM1 f a >>= flip Storable.storeNextMultiValue dst)
                (MultiValue.inc cnt)
   result <- Shape.loop visit n (ptr, MultiValue.zero)
   Shape.sequenceShapeFromIndex (snd result)

{- |
Keep the slices of an @(n, sh)@ array whose outer index is selected
by the Boolean array, and write them consecutively starting at @ptr@.
Returns the result shape: the number of kept slices paired with the
unchanged inner shape.
The shape of the selection array is not consulted.
-}
filterOuter ::
   (Shape.Sequence n, Marshal.MV n, Shape.C sh, Marshal.MV sh,
    Storable.C a) =>
   Sym.Array n Bool ->
   Sym.Array (n,sh) a ->
   LLVM.Value (MarshalPtr (n,sh)) ->
   LLVM.Value (Ptr a) ->
   LLVM.CodeGenFunction r (MultiValue.T (n,sh))
filterOuter (Sym.Array _ selectCode) (Sym.Array esh code) sptr ptr = do
   (n, sh) <- MultiValue.unzip <$> Shape.load esh sptr
   let visit k (dst, cnt) = do
          MultiValue.Cons selected <- selectCode k
          C.ifThen selected (dst, cnt) $ do
             dstNext <- writeArray sh (code . MultiValue.zip k) dst
             cntNext <- MultiValue.inc cnt
             return (dstNext, cntNext)
   result <- Shape.loop visit n (ptr, MultiValue.zero)
   finalN <- Shape.sequenceShapeFromIndex (snd result)
   return (MultiValue.zip finalN sh)

{- |
Fill the target array from @codeInit@.
Then, for every source element that is present (the flag from
'MultiValue.splitMaybe' is set), combine its value into the target
element at the given index.
@accum@ receives the existing target value first and the incoming
value second.
This code performs no range check on the target index itself.
-}
scatterMaybe ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.MV sh1,
    Storable.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (Maybe (ix1, a)) ->
   LLVM.Value (MarshalPtr sh1) ->
   LLVM.Value (Ptr a) ->
   LLVM.CodeGenFunction r ()
scatterMaybe accum (Sym.Array esh codeInit) (Sym.Array eish codeMap) sptr ptr = do
   sh <- Shape.load esh sptr
   void $ writeArray sh codeInit ptr
   ish <- unExp eish
   let update ix () = do
          (MultiValue.Cons present, (jx, incoming)) <-
             mapSnd MultiValue.unzip . MultiValue.splitMaybe <$> codeMap ix
          C.ifThen present () $ do
             dst <- getElementPtr sh ptr jx
             old <- Storable.loadMultiValue dst
             combined <- Expr.unliftM2 accum old incoming
             Storable.storeMultiValue combined dst
   Shape.loop update ish ()

{- |
Fill the target array from @codeInit@.
Then, for every source element, combine its value into the target
element at the given index.
@accum@ receives the existing target value first and the incoming
value second.
This code performs no range check on the target index itself.
-}
scatter ::
   (Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1, Marshal.MV sh1,
    Storable.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Sym.Array sh1 a ->
   Sym.Array sh0 (Shape.Index sh1, a) ->
   LLVM.Value (MarshalPtr sh1) ->
   LLVM.Value (Ptr a) ->
   LLVM.CodeGenFunction r ()
scatter accum (Sym.Array esh codeInit) (Sym.Array eish codeMap) sptr ptr = do
   sh <- Shape.load esh sptr
   void $ writeArray sh codeInit ptr
   ish <- unExp eish
   let update ix () = do
          (jx, incoming) <- MultiValue.unzip <$> codeMap ix
          dst <- getElementPtr sh ptr jx
          old <- Storable.loadMultiValue dst
          combined <- Expr.unliftM2 accum old incoming
          Storable.storeMultiValue combined dst
   Shape.loop update ish ()

{- |
Append a dimension @n@ that varies fastest.
For each index of @sh@, the source element is expanded into @n@
consecutive elements: @select k a@ for every index @k@ of @n@.
The combined shape @(sh, n)@ is read from @sptr@.
-}
addDimension ::
   (Shape.C n, Marshal.MV n, Shape.Index n ~ k,
    Shape.C sh, Marshal.MV sh, Storable.C b) =>
   Exp n ->
   (Exp k -> Exp a -> Exp b) ->
   Sym.Array sh a ->
   LLVM.Value (MarshalPtr (sh,n)) ->
   LLVM.Value (Ptr b) ->
   LLVM.CodeGenFunction r ()
addDimension en select (Sym.Array esh code) sptr ptr = do
   (sh, n) <- MultiValue.unzip <$> Shape.load (Expr.zip esh en) sptr
   let expand ix dst = do
          a <- code ix
          writeArray n (\k -> Expr.unliftM2 select k a) dst
   void $ Shape.loop expand sh ptr