{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Array.Knead.Simple.Private where

import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import LLVM.DSL.Expression (Exp(Exp))

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Maybe as Maybe
import qualified LLVM.Core as LLVM

import qualified Control.Category as Cat
import qualified Control.Monad.HT as Monad
import Control.Monad ((<=<), )

import Prelude hiding (id, map, zipWith, replicate, )


-- | Shorthand for 'MultiValue.T'.
type Val = MultiValue.T

-- | An LLVM code generator, usable in any function context @r@,
-- that produces a 'Val'.
type Code r a = LLVM.CodeGenFunction r (Val a)

{- |
A symbolic array: a shape expression paired with a code generator
that produces the element stored at a given index.
The generator is polymorphic in @r@,
so it can be emitted into any LLVM function.
-}
data Array sh a = Array (Exp sh) (forall r. Val (Shape.Index sh) -> Code r a)

-- | The shape expression of an array.
shape :: Array sh a -> Exp sh
shape arr = case arr of Array sh _ -> sh

{- |
Element at the given index.
The shape is ignored here, so the index is not checked against it.
-}
(!) ::
   (Shape.C sh, Shape.Index sh ~ ix) =>
   Array sh a -> Exp ix -> Exp a
Array _ code ! Exp ix = Exp (ix >>= code)

-- | The single element of an array with a scalar shape,
-- read at 'Shape.zeroIndex'.
the :: (Shape.Scalar sh) => Array sh a -> Exp a
the (Array sh code) = Exp (code (Shape.zeroIndex sh))

-- | An array of scalar shape whose only element is the given value.
fromScalar :: (Shape.Scalar sh) => Exp a -> Array sh a
fromScalar x = fill Shape.scalar x

-- | An array of the given shape where every index maps to the same value.
fill :: Exp sh -> Exp a -> Array sh a
fill sh (Exp code) = Array sh (const code)

{- |
This class allows implementing functions without parameters
for both simple and parameterized arrays.
-}
class C array where
   -- | Embed a plain 'Array' into @array@.
   lift0 :: Array sh a -> array sh a
   -- | Lift a unary transformation of plain arrays to @array@.
   lift1 :: (Array sha a -> Array shb b) -> array sha a -> array shb b
   -- | Lift a binary transformation of plain arrays to @array@.
   lift2 ::
      (Array sha a -> Array shb b -> Array shc c) ->
      array sha a -> array shb b -> array shc c

-- | For plain arrays there is nothing to lift:
-- every method is the identity.
instance C Array where
   lift0 = Cat.id
   lift1 = Cat.id
   lift2 = Cat.id


{- |
Read the source array at the indices stored in the index array.
The result has the shape of the index array.
-}
gather ::
   (C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    MultiValue.C a) =>
   array sh1 ix0 -> array sh0 a -> array sh1 a
gather =
   lift2 $ \(Array targetShape indexCode) (Array _ elemCode) ->
      Array targetShape (elemCode <=< indexCode)

{- |
Build an array of shape @sh@.
Each result index is projected into an index of each source array.
The two elements found there are combined with @f@.
-}
backpermute2 ::
   (C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Shape.C sh, Shape.Index sh ~ ix) =>
   Exp sh ->
   (Exp ix -> Exp ix0) ->
   (Exp ix -> Exp ix1) ->
   (Exp a -> Exp b -> Exp c) ->
   array sh0 a -> array sh1 b -> array sh c
backpermute2 sh projectIndex0 projectIndex1 f =
   lift2 $ \(Array _sha codeA) (Array _shb codeB) ->
      Array sh
         (\ix ->
            let elemA = codeA =<< Expr.unliftM1 projectIndex0 ix
                elemB = codeB =<< Expr.unliftM1 projectIndex1 ix
            in  Monad.liftJoin2 (Expr.unliftM2 f) elemA elemB)

-- | An array whose element at every index is that index itself.
id :: (C array, Shape.C sh, Shape.Index sh ~ ix) => Exp sh -> array sh ix
id sh = lift0 (Array sh return)

-- | Apply a function to every element, keeping the shape.
map :: (C array, Shape.C sh) => (Exp a -> Exp b) -> array sh a -> array sh b
map f =
   lift1 $ \(Array sh code) ->
      Array sh (\ix -> code ix >>= Expr.unliftM1 f)

-- | Like 'map', but the function also receives the element's index.
mapWithIndex ::
   (C array, Shape.C sh, Shape.Index sh ~ ix) =>
   (Exp ix -> Exp a -> Exp b) -> array sh a -> array sh b
mapWithIndex f =
   lift1 $ \(Array sh code) ->
      Array sh
         (\ix -> do
            x <- code ix
            Expr.unliftM2 f ix x)


{- |
Emit a loop over all indices of the given shape, in 'Shape.loop' order.
The loop combines the generated elements with @f@ from left to right:
@f (f e0 e1) e2 ...@.
The accumulator starts as 'Maybe.nothing' and the final result is unwrapped
with 'Maybe.fromJust'.
NOTE(review): for an empty shape the result is therefore presumably
an undefined value. Confirm this against "LLVM.Extra.Maybe".
-}
fold1Code ::
   (Shape.C sh, Shape.Index sh ~ ix, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) -> Exp sh -> (Val ix -> Code r a) -> Code r a
fold1Code f (Exp nc) code = do
   n <- nc
   Maybe.fromJust <$> Shape.loop step n Maybe.nothing
  where
   -- The first element seeds the accumulator;
   -- every later element is combined as @f acc x@.
   step i macc = do
      x <- code i
      acc <- Maybe.run macc (return x) (\prev -> Expr.unliftM2 f prev x)
      return (Maybe.just acc)

{- |
Fold the @sh1@ part of a @(sh0, sh1)@ shape away with a non-empty fold.
One result is produced per index of @sh0@.
-}
fold1 ::
   (C array, Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) -> array (sh0, sh1) a -> array sh0 a
fold1 f =
   lift1 $ \(Array shs code) ->
      case Expr.unzip shs of
         (outer, inner) ->
            Array outer
               (\ix0 -> fold1Code f inner (MultiValue.curry code ix0))

-- | Fold all elements of an array, in 'Shape.loop' order, into one value.
fold1All ::
   (Shape.C sh, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) -> Array sh a -> Exp a
fold1All f arr =
   case arr of
      Array sh code -> Exp (fold1Code f sh code)


{- |
Emit code that walks the indices of the shape in 'Shape.iterator' order
and generates each element.
The walk stops as soon as an element satisfies the predicate.
The result is that element wrapped in 'MultiValue.just',
or 'MultiValue.nothing' if no element matched.
-}
findAllCode ::
   (Shape.C sh, Shape.Index sh ~ ix, MultiValue.C a) =>
   (Exp a -> Exp Bool) -> Exp sh -> (Val ix -> Code r a) -> Code r (Maybe a)
findAllCode p (Exp sh) code = do
   n <- sh
   result <-
      Iter.mapWhileState_ check
         (Iter.mapM code (Shape.iterator n))
         Maybe.nothing
   Maybe.run result
      (return MultiValue.nothing)
      (\x -> return (MultiValue.just x))
  where
   -- Keep iterating while the predicate fails.
   -- The new state holds the element only when the predicate succeeded.
   check x _previous = do
      MultiValue.Cons found <- Expr.unliftM1 p x
      continue <- LLVM.inv found
      return (continue, Maybe.fromBool found x)

{- |
In principle this could be implemented using 'fold1All',
but it short-circuits: scanning stops at the first matching element.
@All@ means that it scans all dimensions,
not that it finds all occurrences.
To get the index of the found element,
decorate the array elements with their indices before calling 'findAll'.
-}
findAll ::
   (Shape.C sh, MultiValue.C a) =>
   (Exp a -> Exp Bool) -> Array sh a -> Exp (Maybe a)
findAll p arr =
   case arr of
      Array sh code -> Exp (findAllCode p sh code)


-- | Class for dimension manipulators that can be chained with '$:.'.
-- It declares no methods.
class Process proc

infixl 3 $:.

{- |
Use this for combining several dimension manipulators.
E.g.

> apply (passAny $:. pick 3 $:. pass $:. replicate 10) array

The constraint @(Process proc0, Process proc1)@ is a bit weak.
We would like to enforce that the type constructor, like @Slice.T@,
is the same in @proc0@ and @proc1@, and that only the parameters differ.
Currently this coherence holds because we only provide functions
of type @proc0 -> proc1@ that satisfy this condition.
-}
($:.) :: (Process proc0, Process proc1) => proc0 -> (proc0 -> proc1) -> proc1
x $:. f = f x