{-# LANGUAGE TypeFamilies #-}
-- | LLVM code generators that write the elements of symbolic arrays
-- into a destination buffer.
module Data.Array.Knead.Simple.PhysicalPrivate where

import qualified Data.Array.Knead.Simple.Private as Sym
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, unExp)
import Data.Array.Knead.Code (getElementPtr)

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import qualified LLVM.Extra.Memory as Memory

import qualified LLVM.Core as LLVM

import Foreign.Storable (Storable, )
import Foreign.Ptr (Ptr, )

import Control.Monad.HT (void, )
import Control.Applicative ((<$>), )
import Data.Tuple.HT (mapSnd, )


-- | Emit code for a left-to-right accumulating map along the second
-- shape component.
--
-- The shape @(sh, n)@ is obtained with 'Shape.load' from the shape
-- expression of the input array and @sptr@.  For every index of @sh@ the
-- accumulator starts with the corresponding element of the first array;
-- then, for every index of @n@, the function is applied to the current
-- accumulator and the input element, the @y@ part of the result is stored
-- at the current destination pointer and the pointer is advanced by one
-- element.  The accumulator left over at the end of each inner loop is
-- discarded.
mapAccumL ::
  (Shape.C sh, Storable sh, MultiValueMemory.C sh,
   Shape.C n, Storable n, MultiValueMemory.C n,
   MultiValue.C acc,
   Storable x, MultiValueMemory.C x,
   Storable y, MultiValueMemory.C y) =>
  (Exp acc -> Exp x -> Exp (acc,y)) ->
  Sym.Array sh acc -> Sym.Array (sh, n) x ->
  LLVM.Value (Ptr (MultiValueMemory.Struct (sh,n))) ->
  LLVM.Value (Ptr (MultiValueMemory.Struct y)) ->
  LLVM.CodeGenFunction r ()
mapAccumL f (Sym.Array _ initCode) (Sym.Array esh code) sptr ptr = do
  (outer, inner) <- MultiValue.unzip <$> Shape.load esh sptr
  let runRow ix rowPtr = do
        accStart <- initCode ix
        -- One inner step: compute the element, store it, move on.
        let emit k (dst, accPrev) = do
              x <- code (MultiValue.zip ix k)
              (accNext, y) <- MultiValue.unzip <$> Expr.unliftM2 f accPrev x
              Memory.store y dst
              dstNext <- A.advanceArrayElementPtr dst
              return (dstNext, accNext)
        -- Only the advanced pointer is threaded into the outer loop.
        fst <$> Shape.loop emit inner (rowPtr, accStart)
  void $ Shape.loop runRow outer ptr


-- | Emit code for a left fold over the outer (first) shape component.
--
-- The shape @(n, sh)@ is taken from the shape expression of the input
-- array via 'unExp'; the shape pointer argument is not read.  First the
-- destination buffer is filled with the elements of the initial array,
-- one per index of @sh@.  Then, for every index of @n@, the buffer is
-- traversed again from its start and every slot is replaced by the
-- function applied to the slot's current content and the matching input
-- element.
foldOuterL ::
  (Shape.C sh, Storable sh, MultiValueMemory.C sh,
   Shape.C n, Storable n, MultiValueMemory.C n,
   MultiValueMemory.C a) =>
  (Exp a -> Exp b -> Exp a) ->
  Sym.Array sh a -> Sym.Array (n,sh) b ->
  LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
  LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
  LLVM.CodeGenFunction r ()
foldOuterL f (Sym.Array _ initCode) (Sym.Array esh code) _sptr ptr = do
  (n, sh) <- MultiValue.unzip <$> unExp esh
  let -- Write one initial element and step to the next slot.
      writeInit ix dst = do
        flip Memory.store dst =<< initCode ix
        A.advanceArrayElementPtr dst
      -- Update one slot in place with the input element at (k, ix).
      accumulate k ix dst = do
        b <- code (MultiValue.zip k ix)
        old <- Memory.load dst
        new <- Expr.unliftM2 f old b
        Memory.store new dst
        A.advanceArrayElementPtr dst
      -- One full pass over the buffer for a fixed outer index.
      sweep k () = void $ Shape.loop (accumulate k) sh ptr
  void $ Shape.loop writeInit sh ptr
  void $ Shape.loop sweep n ()


-- | Emit code that scatters optional indexed values into a destination
-- buffer.
--
-- The destination shape is obtained with 'Shape.load' and the buffer is
-- first filled with the elements of the initial array.  Then the index
-- space of the map array (its shape is taken via 'unExp') is traversed.
-- Each map element is split with 'MultiValue.splitMaybe' into a flag and
-- a payload @(jx, a)@; only when the flag is set, the slot addressed by
-- @jx@ is loaded, combined as @accum stored a@ and written back.
scatterMaybe ::
  (Shape.C sh0, Shape.Index sh0 ~ ix0,
   Shape.C sh1, Shape.Index sh1 ~ ix1,
   Storable sh1, MultiValueMemory.C sh1,
   Storable a, MultiValueMemory.C a) =>
  (Exp a -> Exp a -> Exp a) ->
  Sym.Array sh1 a -> Sym.Array sh0 (Maybe (ix1, a)) ->
  LLVM.Value (Ptr (MultiValueMemory.Struct sh1)) ->
  LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
  LLVM.CodeGenFunction r ()
scatterMaybe accum (Sym.Array esh codeInit) (Sym.Array eish codeMap) sptr ptr = do
  sh <- Shape.load esh sptr
  let initialise ix dst = do
        start <- codeInit ix
        Memory.store start dst
        A.advanceArrayElementPtr dst
  void $ Shape.loop initialise sh ptr
  ish <- unExp eish
  let deposit ix () = do
        (MultiValue.Cons valid, (jx, a)) <-
          mapSnd MultiValue.unzip . MultiValue.splitMaybe <$> codeMap ix
        C.ifThen valid () $ do
          slot <- getElementPtr sh ptr jx
          stored <- Memory.load slot
          -- 'flip' makes the stored value the first argument of 'accum'.
          combined <- Expr.unliftM2 (flip accum) a stored
          Memory.store combined slot
  Shape.loop deposit ish ()


-- | Emit code that scatters indexed values into a destination buffer.
--
-- The destination shape is obtained with 'Shape.load' and the buffer is
-- first filled with the elements of the initial array.  Then, for every
-- index of the map array (its shape is taken via 'unExp'), the element
-- @(jx, a)@ is computed, the slot addressed by @jx@ is loaded, combined
-- as @accum stored a@ and written back.
scatter ::
  (Shape.C sh0, Shape.Index sh0 ~ ix0,
   Shape.C sh1, Shape.Index sh1 ~ ix1,
   Storable sh1, MultiValueMemory.C sh1,
   Storable a, MultiValueMemory.C a) =>
  (Exp a -> Exp a -> Exp a) ->
  Sym.Array sh1 a -> Sym.Array sh0 (Shape.Index sh1, a) ->
  LLVM.Value (Ptr (MultiValueMemory.Struct sh1)) ->
  LLVM.Value (Ptr (MultiValueMemory.Struct a)) ->
  LLVM.CodeGenFunction r ()
scatter accum (Sym.Array esh codeInit) (Sym.Array eish codeMap) sptr ptr = do
  sh <- Shape.load esh sptr
  let initialise ix dst = do
        start <- codeInit ix
        Memory.store start dst
        A.advanceArrayElementPtr dst
  void $ Shape.loop initialise sh ptr
  ish <- unExp eish
  let deposit ix () = do
        (jx, a) <- MultiValue.unzip <$> codeMap ix
        slot <- getElementPtr sh ptr jx
        stored <- Memory.load slot
        -- 'flip' makes the stored value the first argument of 'accum'.
        combined <- Expr.unliftM2 (flip accum) a stored
        Memory.store combined slot
  Shape.loop deposit ish ()