{- |
Reduce selected dimensions.

Alternatively you may reorder dimensions with 'ShapeDep.backpermute'
and fold once along multiple dimensions.
-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Knead.Simple.Fold (
   T,
   Linear,
   apply,
   passAny,
   pass,
   fold,
   (Core.$:.),
   ) where

import qualified Data.Array.Knead.Simple.Private as Core
import Data.Array.Knead.Simple.Private (Array(Array), Code, Val, )

import qualified Data.Array.Knead.Shape.Cubic as Linear
import qualified Data.Array.Knead.Shape.Cubic.Int as IndexInt
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, unExp, )
import Data.Array.Knead.Shape.Cubic ((#:.), (:.)((:.)), )

import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Value (atom, )

import Prelude hiding (zipWith, zipWith3, zip, zip3, replicate, )


{- |
Description of a reduction that turns an array of shape @sh0@
into an array of shape @sh1@, both with elements of type @a@.

The index types of both shapes are hidden in the constructor
and tied to the shapes via 'Shape.Index'.
-}
data T sh0 sh1 a =
   forall ix0 ix1.
   (Shape.Index sh0 ~ ix0, Shape.Index sh1 ~ ix1) =>
   Cons
      -- Computes the result shape from the source shape.
      (Exp sh0 -> Exp sh1)
      -- Given the source shape and an element generator over source indices,
      -- yields an element generator over result indices.
      (forall r.
       Val sh0 ->
       (Val ix0 -> Code r a) ->
       (Val ix1 -> Code r a))

{- |
Run a reduction on an array.

The result shape is obtained from the shape function of the reduction.
Every element of the result is generated by first materializing
the source shape with 'unExp' and then handing it,
together with the source element generator and the requested index,
to the code transformer of the reduction.
-}
apply ::
   (Core.C array, Shape.C sh0, Shape.C sh1, MultiValue.C a) =>
   T sh0 sh1 a ->
   array sh0 a ->
   array sh1 a
apply (Cons shapeFn reducer) =
   Core.lift1
      (\(Array srcShape srcElem) ->
         Array
            (shapeFn srcShape)
            (\dstIx ->
               unExp srcShape >>= \srcShapeVal ->
                  reducer srcShapeVal srcElem dstIx))

{- |
Reductions between cubic shapes,
i.e. shapes wrapped in 'Linear.Shape'.
-}
type Linear sh0 sh1 = T (Linear.Shape sh0) (Linear.Shape sh1)

{- |
The trivial reduction:
the shape is left as it is
and the element generator is passed through unchanged.
-}
passAny :: Linear sh sh a
passAny = Cons id (\_ elemCode -> elemCode)

{- |
Keep the dimension to the right of '(:.)'
and apply the given reduction to the dimensions left of it.

The extent of the kept dimension is copied to the result shape
and its index component is forwarded unchanged
to the source element generator.
-}
pass ::
   Linear sh0 sh1 a ->
   Linear (sh0:.i) (sh1:.i) a
pass (Cons shapeFn reducer) =
   Cons
      (Expr.modify (Linear.shape (atom :. atom))
         (\(leading :. extent) -> shapeFn leading :. extent))
      (\srcShape srcElem ->
         Linear.switchR
            (\dstLeading kept ->
               reducer
                  (Linear.tail srcShape)
                  (\srcLeading -> srcElem (srcLeading #:. kept))
                  dstLeading))

{- |
Helper for 'fold':
generate code that reduces along the dimension right of '(:.)'.

The given extent is turned into a 'Shape.ZeroBased' shape
and passed to 'Core.fold1Code' along with the combining function.
For every position @j@ that 'Core.fold1Code' requests,
the source element at the index @ix #:. IndexInt.cons j@ is produced.

NOTE(review): the name 'Core.fold1Code' suggests
that the extent must be non-zero - confirm against its definition.
-}
fold1CodeLinear ::
   (MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Exp IndexInt.Int ->
   (Val (Linear.Index (sh :. IndexInt.Int)) -> Code r a) ->
   (Val (Linear.Index sh) -> Code r a)
fold1CodeLinear combine extent srcElem leadingIx =
   Core.fold1Code combine
      (Expr.lift1
         (\n -> MultiValue.compose (Shape.ZeroBased n))
         (IndexInt.decons extent))
      (\pos -> srcElem (leadingIx #:. IndexInt.cons pos))

{- |
Reduce along the dimension right of '(:.)'
using the given combining function
and apply the given reduction to the remaining dimensions.

The folded dimension is removed from the shape via 'Linear.tail'
before the shape function of the inner reduction is applied.
Its extent is taken from the source shape with 'Linear.head'.
-}
fold ::
   (MultiValue.C a) =>
   (Exp a -> Exp a -> Exp a) ->
   Linear sh0 sh1 a ->
   Linear (sh0:.IndexInt.Int) sh1 a
fold combine (Cons shapeFn reducer) =
   Cons
      (\srcShape -> shapeFn (Linear.tail srcShape))
      (\srcShape srcElem dstIx ->
         reducer
            (Linear.tail srcShape)
            (fold1CodeLinear combine
               (Expr.lift0 (Linear.head srcShape))
               srcElem)
            dstIx)

-- Registers reductions with 'Core.Process';
-- the operator '(Core.$:.)' is re-exported from this module.
instance Core.Process (T sh0 sh1 a) where