{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Knead.Simple.Private where

import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp(Exp), )

import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Monad as Monad
import qualified LLVM.Extra.Maybe as Maybe
import qualified LLVM.Core as LLVM

import qualified Control.Category as Cat
import Control.Monad ((<=<), )

import Prelude hiding (id, map, zipWith, replicate, )


-- | Short name for the LLVM multi-value wrapper used for elements and indices.
type Val = MultiValue.T

-- | An LLVM code generator, parameterised by @r@, that yields a @'Val' a@.
type Code r a = LLVM.CodeGenFunction r (Val a)


-- | An array given by its shape expression and an element generator.
-- The generator receives an index value and emits code that computes
-- the element at that index.  It is polymorphic in @r@, so it can be
-- run inside any code-generation context.
data Array sh a =
   Array
      (Exp sh)
      (forall r. Val (Shape.Index sh) -> Code r a)

-- | The shape expression of an array.
shape :: Array sh a -> Exp sh
shape arr =
   case arr of
      Array sh _ -> sh

-- | Element access: emit the code for the index expression, then run
-- the element generator on the resulting index value.
(!) ::
   (Shape.C sh, Shape.Index sh ~ ix) =>
   Array sh a -> Exp ix -> Exp a
(!) (Array _ code) (Exp ix) = Exp (ix >>= code)

-- | The single element of an array with a scalar shape, read at the
-- index that 'Shape.zeroIndex' derives from the shape expression.
the :: (Shape.Scalar sh) => Array sh a -> Exp a
the (Array z code) = Exp (code (Shape.zeroIndex z))

-- | An array of the scalar shape 'Shape.scalar' holding the given value.
fromScalar :: (Shape.Scalar sh) => Exp a -> Array sh a
fromScalar x = fill Shape.scalar x

-- | An array of the given shape in which every index yields the same
-- element expression; the index value is ignored.
fill :: Exp sh -> Exp a -> Array sh a
fill sh (Exp code) = Array sh (const code)


{- |
This class allows implementing functions without parameters
for both simple and parameterized arrays.
-}
class C array where
   -- | Embed a simple array.
   lift0 :: Array sh a -> array sh a
   -- | Lift a unary transformation of simple arrays.
   lift1 :: (Array sha a -> Array shb b) -> array sha a -> array shb b
   -- | Lift a binary combination of simple arrays.
   lift2 ::
      (Array sha a -> Array shb b -> Array shc c) ->
      array sha a -> array shb b -> array shc c

-- | For simple arrays every lift is the identity.
instance C Array where
   lift0 = Cat.id
   lift1 = Cat.id
   lift2 = Cat.id


-- | @gather indices source@ has the shape of @indices@.  Its element at
-- index @i@ is the element of @source@ at the index that @indices@
-- holds at position @i@.
gather ::
   (C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    MultiValue.C a) =>
   array sh1 ix0 -> array sh0 a -> array sh1 a
gather =
   lift2 (\(Array sh1 fetchIndex) (Array _ fetchElement) ->
      Array sh1 (fetchElement <=< fetchIndex))

-- | Combine two arrays into an array of shape @sh@.  A result index is
-- mapped into the first source with @projectIndex0@ and into the second
-- with @projectIndex1@.  The two fetched elements are combined with @f@,
-- and the first source is read before the second.
backpermute2 ::
   (C array,
    Shape.C sh0, Shape.Index sh0 ~ ix0,
    Shape.C sh1, Shape.Index sh1 ~ ix1,
    Shape.C sh, Shape.Index sh ~ ix) =>
   Exp sh ->
   (Exp ix -> Exp ix0) ->
   (Exp ix -> Exp ix1) ->
   (Exp a -> Exp b -> Exp c) ->
   array sh0 a -> array sh1 b -> array sh c
backpermute2 sh projectIndex0 projectIndex1 f =
   lift2 $ \(Array _ codeA) (Array _ codeB) ->
      Array sh (\ix ->
         let elemA = codeA =<< Expr.unliftM1 projectIndex0 ix
             elemB = codeB =<< Expr.unliftM1 projectIndex1 ix
         in  Monad.liftR2 (Expr.unliftM2 f) elemA elemB)

-- | The array of shape @sh@ whose element at every index is that index.
id :: (Shape.C sh, Shape.Index sh ~ ix) => Exp sh -> Array sh ix
id shp = Array shp return

-- | Apply @f@ to every element.
map ::
   (C array, Shape.C sh) =>
   (Exp a -> Exp b) -> array sh a -> array sh b
map f =
   lift1 (\(Array sh code) ->
      Array sh (\ix -> code ix >>= Expr.unliftM1 f))

-- | Apply @f@ to every element together with that element's index.
mapWithIndex ::
   (C array, Shape.C sh, Shape.Index sh ~ ix) =>
   (Exp ix -> Exp a -> Exp b) -> array sh a -> array sh b
mapWithIndex f =
   lift1 (\(Array sh code) ->
      Array sh (\ix -> code ix >>= Expr.unliftM2 f ix))


{- |
@fold1Code f n code ix@ emits a loop over the indices of the shape @n@,
using 'Shape.loop', and reduces the elements @code ix i@ with @f@,
applied as @f accumulator element@.
The accumulator starts as 'Maybe.nothing'.
The first visited element is taken as is, and each later one is
combined into the accumulator.

NOTE(review): the result is extracted with 'Maybe.fromJust', so the loop
presumably has to run at least once, meaning @n@ must be non-empty.
Confirm what 'Maybe.fromJust' yields for 'Maybe.nothing'.
-}
fold1Code ::
   (Shape.C sh1, Shape.Index sh1 ~ ix1, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) -> Exp sh1 ->
   (Val ix0 -> Val ix1 -> Code r a) ->
   (Val ix0 -> Code r a)
fold1Code f (Exp shapeCode) code ix = do
   innerShape <- shapeCode
   Maybe.fromJust <$>
      Shape.loop
         (\i macc -> do
            x <- code ix i
            Maybe.just <$>
               Maybe.run macc (return x) (\acc -> Expr.unliftM2 f acc x))
         innerShape
         Maybe.nothing

-- | Reduce the inner (second) component of a pair shape with @f@.
-- The result has the outer (first) shape.  See 'fold1Code' for the
-- order in which @f@ receives its arguments.
fold1 ::
   (C array, Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) -> array (sh0, sh1) a -> array sh0 a
fold1 f =
   lift1 $ \(Array shapes code) ->
      case Expr.unzip shapes of
         (outer, inner) ->
            Array outer (fold1Code f inner (MultiValue.curry code))

-- | Reduce all elements of an array to a single one.  The array is
-- viewed as having shape @((), sh)@, and 'fold1' is applied to it.
fold1All ::
   (Shape.C sh, MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) -> Array sh a -> Array () a
fold1All f (Array esh code) =
   fold1 f
      (Array
         (Expr.lift1 (MultiValue.zip (MultiValue.Cons ())) esh)
         (\ix -> code (MultiValue.snd ix)))


-- | Class of the dimension manipulators that are chained with @$:.@.
-- It declares no methods.
class Process proc

infixl 3 $:.

{- |
Use this for combining several dimension manipulators.
E.g.

> apply (passAny $:. pick 3 $:. pass $:. replicate 10) array

The constraint @(Process proc0, Process proc1)@ is a bit weak.
We would like to enforce that the type constructor, like @Slice.T@,
is the same in @proc0@ and @proc1@, and that only the parameters differ.
Currently this coherence is achieved
because we only provide functions of type @proc0 -> proc1@
that satisfy this condition.
-}
($:.) :: (Process proc0, Process proc1) => proc0 -> (proc0 -> proc1) -> proc1
manip $:. step = step manip