{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Trafo.Sharing
-- Copyright   : [2008..2017] Manuel M T Chakravarty, Gabriele Keller
--               [2009..2017] Trevor L. McDonell
--               [2013..2017] Robert Clifton-Everest
-- License     : BSD3
--
-- Maintainer  : Manuel M T Chakravarty
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- This module implements HOAS to de Bruijn conversion of array expressions
-- while incorporating sharing information.
--

module Data.Array.Accelerate.Trafo.Sharing (

  -- * HOAS -> de Bruijn conversion
  convertAcc, convertAfun, Afunction, AfunctionR,
  convertExp, convertFun, Function, FunctionR,
  -- convertSeq

) where

-- standard library
import Control.Applicative hiding ( Const )
import Control.Monad.Fix
import Data.List
import Data.Maybe
import Data.Hashable
import Data.Typeable
import System.Mem.StableName
import System.IO.Unsafe ( unsafePerformIO )
import Text.Printf
import qualified Data.HashTable.IO as Hash
import qualified Data.IntMap as IntMap
import qualified Data.HashMap.Strict as Map
import qualified Data.HashSet as Set
import Prelude

-- friends
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Array.Sugar as Sugar hiding ( (!!) )
import Data.Array.Accelerate.AST hiding ( PreOpenAcc(..), OpenAcc(..), Acc
                                        , PreOpenExp(..), OpenExp, PreExp, Exp
                                        , PreBoundary(..), Boundary, Stencil(..)
                                        , showPreAccOp, showPreExpOp )
import qualified Data.Array.Accelerate.AST as AST
import qualified Data.Array.Accelerate.Debug as Debug


-- Configuration
-- -------------

-- Switches controlling the conversion. They are threaded explicitly through
-- every conversion function rather than carried in a reader monad.
--
data Config = Config
  {
    recoverAccSharing :: Bool -- ^ Recover sharing of array computations ?
  , recoverExpSharing :: Bool -- ^ Recover sharing of scalar expressions ?
  , recoverSeqSharing :: Bool -- ^ Recover sharing of sequence computations ?
  , floatOutAcc :: Bool -- ^ Always float array computations out of expressions ?
  }


-- Layouts
-- -------

-- A layout mirrors an environment: for every entry of the environment @env'@
-- it stores the de Bruijn index (into @env@) that refers to that entry.
--
data Layout env env' where
  EmptyLayout :: Layout env ()
  PushLayout :: Typeable t
             => Layout env env' -> Idx env t -> Layout env (env', t)

-- Look up the index stored at position @n@ of a layout, counting from the
-- most recently pushed entry (position zero).
--
-- The 'String' argument is only used to give context in the internal error
-- raised when the position is not present in the layout, or when the entry
-- found there does not have the requested type.
--
prjIdx :: Typeable t => String -> Int -> Layout env env' -> Idx env t
prjIdx context = walk
  where
    walk :: forall env env' t. Typeable t => Int -> Layout env env' -> Idx env t
    walk _ EmptyLayout = failWith "environment does not contain index"
    walk 0 (PushLayout _ (ix :: Idx env0 s)) =
      -- The stored index has an existentially hidden element type, so it is
      -- cast to the requested type at runtime.
      case gcast ix of
        Just ix' -> ix'
        Nothing -> failWith $
          printf "couldn't match expected type `%s' with actual type `%s'"
                 (show (typeOf (undefined::t)))
                 (show (typeOf (undefined::s)))
    walk n (PushLayout rest _) = walk (n-1) rest

    failWith :: String -> a
    failWith reason =
      $internalError "prjIdx" (printf "%s\nin the context: %s" reason context)

-- Weaken a layout by one binder: every stored index is wrapped in 'SuccIdx'
-- so that it remains valid in the environment extended by one entry.
--
incLayout :: Layout env env' -> Layout (env, t) env'
incLayout lyt =
  case lyt of
    EmptyLayout -> EmptyLayout
    PushLayout rest ix -> PushLayout (incLayout rest) (SuccIdx ix)

-- Number of entries recorded in a layout.
--
sizeLayout :: Layout env env' -> Int
sizeLayout lyt =
  case lyt of
    EmptyLayout -> 0
    PushLayout rest _ -> 1 + sizeLayout rest


-- Conversion from HOAS to de Bruijn computation AST
-- =================================================

-- Array computations
-- ------------------

-- | Convert a closed array expression to de Bruijn form while also incorporating sharing
-- information.
--
-- Floating of array computations is only enabled when recovery of array
-- sharing is enabled as well.
--
convertAcc
    :: Arrays arrs
    => Bool -- ^ recover sharing of array computations ?
    -> Bool -- ^ recover sharing of scalar expressions ?
    -> Bool -- ^ recover sharing of sequence computations ?
    -> Bool -- ^ always float array computations out of expressions?
    -> Acc arrs
    -> AST.Acc arrs
convertAcc shareAcc shareExp shareSeq floatAcc acc =
  convertOpenAcc config 0 [] EmptyLayout acc
  where
    config = Config
      { recoverAccSharing = shareAcc
      , recoverExpSharing = shareExp
      , recoverSeqSharing = shareSeq
      , floatOutAcc = shareAcc && floatAcc
      }

-- | Convert a closed function over array computations, while incorporating
-- sharing information.
--
-- The four flags have the same meaning, and are given in the same order, as
-- those of 'convertAcc'.
--
convertAfun :: Afunction f => Bool -> Bool -> Bool -> Bool -> f -> AST.Afun (AfunctionR f)
convertAfun shareAcc shareExp shareSeq floatAcc =
  aconvert config EmptyLayout
  where
    config = Config
      { recoverAccSharing = shareAcc
      , recoverExpSharing = shareExp
      , recoverSeqSharing = shareSeq
      , floatOutAcc = shareAcc && floatAcc
      }

-- Convert a HOAS fragment into de Bruijn form, binding variables into the typed
-- environment layout one binder at a time.
--
-- NOTE: Because we convert one binder at a time left-to-right, the bound
--       variables ('vars') will have de Bruijn index _zero_ as the outermost
--       binding, and thus go to the end of the list.
--
class Afunction f where
  -- The de Bruijn level function type corresponding to the HOAS function 'f'.
  type AfunctionR f
  -- Convert 'f' given the layout of the array variables bound so far.
  aconvert :: Config -> Layout aenv aenv -> f -> AST.OpenAfun aenv (AfunctionR f)

-- One more lambda: the new parameter is represented by an 'Atag' whose level
-- is the current size of the layout; the layout is extended by one entry
-- before converting the remainder of the function.
--
instance (Arrays a, Afunction r) => Afunction (Acc a -> r) where
  type AfunctionR (Acc a -> r) = a -> AfunctionR r
  --
  aconvert config alyt f
    = let a     = Acc $ Atag (sizeLayout alyt)
          alyt' = incLayout alyt `PushLayout` ZeroIdx
      in
      Alam $ aconvert config alyt' (f a)

-- The function body: all levels below the layout size are passed to
-- 'convertOpenAcc' as the tags of the bound variables, highest level first.
--
instance Arrays b => Afunction (Acc b) where
  type AfunctionR (Acc b) = b
  --
  aconvert config alyt body
    = let lvl  = sizeLayout alyt
          vars = [lvl-1, lvl-2 .. 0]
      in
      Abody $ convertOpenAcc config lvl vars alyt body

-- | Convert an open array expression to de Bruijn form while also incorporating sharing
-- information.
--
-- Sharing is first recovered with 'recoverSharingAcc'; the annotated term and
-- the initial sharing environment it returns are then handed to
-- 'convertSharingAcc'.
--
convertOpenAcc
    :: Arrays arrs
    => Config
    -> Level
    -> [Level]
    -> Layout aenv aenv
    -> Acc arrs
    -> AST.OpenAcc aenv arrs
convertOpenAcc config lvl fvs alyt acc
  = let (sharingAcc, initialEnv) = recoverSharingAcc config lvl fvs acc
    in
    convertSharingAcc config alyt initialEnv sharingAcc

-- | Convert an array expression with given array environment layout and sharing information into
-- de Bruijn form while recovering sharing at the same time (by introducing appropriate let
-- bindings).  The latter implements the third phase of sharing recovery.
--
-- The sharing environment 'env' keeps track of all currently bound sharing variables, keeping them
-- in reverse chronological order (outermost variable is at the end of the list).
--
convertSharingAcc
    :: forall aenv arrs. Arrays arrs
    => Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> ScopedAcc arrs
    -> AST.OpenAcc aenv arrs
-- A reference to a shared computation. Its position in the sharing
-- environment (the lambda-bound entries 'lams' followed by 'aenv') is used as
-- the position to project from the layout.
convertSharingAcc _ alyt aenv (ScopedAcc lams (AvarSharing sa))
  | Just i <- findIndex (matchStableAcc sa) aenv'
  = AST.OpenAcc $ AST.Avar (prjIdx (ctxt ++ "; i = " ++ show i) i alyt)
  | null aenv'
  = error $ "Cyclic definition of a value of type 'Acc' (sa = " ++ show (hashStableNameHeight sa) ++ ")"
  | otherwise
  = $internalError "convertSharingAcc" err
  where
    aenv' = lams ++ aenv
    ctxt  = "shared 'Acc' tree with stable name " ++ show (hashStableNameHeight sa)
    err   = "inconsistent valuation @ " ++ ctxt ++ ";\n aenv = " ++ show aenv'
-- A let binding introduced by sharing recovery. The bound computation is
-- converted against the current layout; the body is converted against the
-- layout extended by one entry, with 'sa' pushed onto the front of the
-- sharing environment.
convertSharingAcc config alyt aenv (ScopedAcc lams (AletSharing sa@(StableSharingAcc _ boundAcc) bodyAcc))
  = AST.OpenAcc
  $ let alyt' = incLayout alyt `PushLayout` ZeroIdx
        aenv' = lams ++ aenv
    in
    AST.Alet (convertSharingAcc config alyt aenv' (ScopedAcc [] boundAcc))
             (convertSharingAcc config alyt' (sa:aenv') bodyAcc)
-- An ordinary AST node: convert the node itself and recurse into its
-- sub-terms with the helpers below, all of which share the same layout and
-- the sharing environment extended by 'lams'.
convertSharingAcc config alyt aenv (ScopedAcc lams (AccSharing _ preAcc))
  = AST.OpenAcc
  $ let aenv' = lams ++ aenv

        cvtA :: Arrays a => ScopedAcc a -> AST.OpenAcc aenv a
        cvtA = convertSharingAcc config alyt aenv'

        -- Scalar expressions embedded in an array computation start with an
        -- empty scalar layout and an empty scalar sharing environment.
        cvtE :: Elt t => ScopedExp t -> AST.Exp aenv t
        cvtE = convertSharingExp config EmptyLayout alyt [] aenv'

        cvtF1 :: (Elt a, Elt b) => (Exp a -> ScopedExp b) -> AST.Fun aenv (a -> b)
        cvtF1 = convertSharingFun1 config alyt aenv'

        cvtF2 :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> ScopedExp c) -> AST.Fun aenv (a -> b -> c)
        cvtF2 = convertSharingFun2 config alyt aenv'

        cvtAfun1 :: (Arrays a, Arrays b) => (Acc a -> ScopedAcc b) -> AST.OpenAfun aenv (a -> b)
        cvtAfun1 = convertSharingAfun1 config alyt aenv'
    in
    case preAcc of

      Atag i
        -> AST.Avar (prjIdx ("de Bruijn conversion tag " ++ show i) i alyt)

      -- 'Pipe' is expressed as a let binding: the first function applied to
      -- the argument is bound, and the second function is applied to that
      -- variable (index zero). A placeholder entry without a stable name is
      -- pushed onto the sharing environment alongside the extended layout;
      -- presumably this keeps environment positions aligned with the layout.
      Pipe afun1 afun2 acc
        -> let noStableSharing = StableSharingAcc noStableAccName (undefined :: SharingAcc acc exp ())
               alyt'           = incLayout alyt `PushLayout` ZeroIdx
               boundAcc        = cvtAfun1 afun1 `AST.Apply` cvtA acc
               bodyAcc         = convertSharingAfun1 config alyt' (noStableSharing : aenv') afun2
                                 `AST.Apply`
                                 AST.OpenAcc (AST.Avar AST.ZeroIdx)
           in
           AST.Alet (AST.OpenAcc boundAcc) (AST.OpenAcc bodyAcc)

      -- The function accompanying a foreign call is converted as a closed
      -- function via 'convertAfun', reusing the settings of 'config'.
      Aforeign ff afun acc
        -> let a = recoverAccSharing config
               e = recoverExpSharing config
               s = recoverSeqSharing config
               f = floatOutAcc config
           in
           AST.Aforeign ff (convertAfun a e s f afun) (cvtA acc)

      Acond b acc1 acc2           -> AST.Acond (cvtE b) (cvtA acc1) (cvtA acc2)
      Awhile pred iter init       -> AST.Awhile (cvtAfun1 pred) (cvtAfun1 iter) (cvtA init)
      Atuple arrs                 -> AST.Atuple (convertSharingAtuple config alyt aenv' arrs)
      Aprj ix a                   -> AST.Aprj ix (cvtA a)
      Use array                   -> AST.Use (fromArr array)
      Unit e                      -> AST.Unit (cvtE e)
      Generate sh f               -> AST.Generate (cvtE sh) (cvtF1 f)
      Reshape e acc               -> AST.Reshape (cvtE e) (cvtA acc)
      Replicate ix acc            -> mkReplicate (cvtE ix) (cvtA acc)
      Slice acc ix                -> mkIndex (cvtA acc) (cvtE ix)
      Map f acc                   -> AST.Map (cvtF1 f) (cvtA acc)
      ZipWith f acc1 acc2         -> AST.ZipWith (cvtF2 f) (cvtA acc1) (cvtA acc2)
      Fold f e acc                -> AST.Fold (cvtF2 f) (cvtE e) (cvtA acc)
      Fold1 f acc                 -> AST.Fold1 (cvtF2 f) (cvtA acc)
      FoldSeg f e acc1 acc2       -> AST.FoldSeg (cvtF2 f) (cvtE e) (cvtA acc1) (cvtA acc2)
      Fold1Seg f acc1 acc2        -> AST.Fold1Seg (cvtF2 f) (cvtA acc1) (cvtA acc2)
      Scanl f e acc               -> AST.Scanl (cvtF2 f) (cvtE e) (cvtA acc)
      Scanl' f e acc              -> AST.Scanl' (cvtF2 f) (cvtE e) (cvtA acc)
      Scanl1 f acc                -> AST.Scanl1 (cvtF2 f) (cvtA acc)
      Scanr f e acc               -> AST.Scanr (cvtF2 f) (cvtE e) (cvtA acc)
      Scanr' f e acc              -> AST.Scanr' (cvtF2 f) (cvtE e) (cvtA acc)
      Scanr1 f acc                -> AST.Scanr1 (cvtF2 f) (cvtA acc)
      Permute f dftAcc perm acc   -> AST.Permute (cvtF2 f) (cvtA dftAcc) (cvtF1 perm) (cvtA acc)
      Backpermute newDim perm acc -> AST.Backpermute (cvtE newDim) (cvtF1 perm) (cvtA acc)
      -- The array arguments are also passed to the stencil-function
      -- converters, which use them only to fix type variables.
      Stencil stencil boundary acc
        -> AST.Stencil (convertSharingStencilFun1 config acc alyt aenv' stencil)
                       (convertSharingBoundary config alyt aenv' boundary)
                       (cvtA acc)
      Stencil2 stencil bndy1 acc1 bndy2 acc2
        -> AST.Stencil2 (convertSharingStencilFun2 config acc1 acc2 alyt aenv' stencil)
                        (convertSharingBoundary config alyt aenv' bndy1)
                        (cvtA acc1)
                        (convertSharingBoundary config alyt aenv' bndy2)
                        (cvtA acc2)
      -- Collect seq -> AST.Collect (convertSharingSeq config alyt EmptyLayout aenv' [] seq)

{--
-- Sequence expressions
-- ------------------

-- | Convert a closed sequence expression to de Bruijn form while incorporating
-- sharing information.
--
convertSeq
    :: Typeable s
    => Bool -- ^ recover sharing of array computations ?
    -> Bool -- ^ recover sharing of scalar expressions ?
    -> Bool -- ^ recover sharing of sequence computations ?
    -> Bool -- ^ always float array computations out of expressions?
    -> Seq s -- ^ computation to be converted
    -> AST.Seq s
convertSeq shareAcc shareExp shareSeq floatAcc seq
  = let config = Config shareAcc shareExp shareSeq floatAcc
        (sharingSeq, initialEnv) = recoverSharingSeq config seq
    in
    convertSharingSeq config EmptyLayout EmptyLayout [] initialEnv sharingSeq

convertSharingSeq
    :: forall aenv senv arrs.
       Config
    -> Layout aenv aenv
    -> Layout senv senv
    -> [StableSharingAcc]
    -> [StableSharingSeq]
    -> ScopedSeq arrs
    -> AST.PreOpenSeq AST.OpenAcc aenv senv arrs
convertSharingSeq _ _ slyt _ senv (ScopedSeq (SvarSharing sn))
  | Just i <- findIndex (matchStableSeq sn) senv
  = AST.Reify $ prjIdx (ctxt ++ "; i = " ++ show i) i slyt
  | null senv
  = error $ "Cyclic definition of a value of type 'Seq' (sa = " ++ show (hashStableNameHeight sn) ++ ")"
  | otherwise
  = $internalError "convertSharingSeq" err
  where
    ctxt = "shared 'Seq' tree with stable name " ++ show (hashStableNameHeight sn)
    err = "inconsistent valuation @ " ++ ctxt ++ ";\n senv = " ++ show senv
convertSharingSeq config alyt slyt aenv senv (ScopedSeq (SletSharing sa@(StableSharingSeq _ (SeqSharing _ boundSeq)) bodySeq))
  = convSeq boundSeq bodySeq
  where
    convSeq :: forall bnd body.
PreSeq ScopedAcc ScopedSeq ScopedExp bnd -> ScopedSeq body -> AST.PreOpenSeq AST.OpenAcc aenv senv body convSeq bnd body = case bnd of StreamIn arrs -> producer $ AST.StreamIn arrs ToSeq slix acc -> producer $ mkToSeq slix (cvtA acc) MapSeq afun x -> producer $ AST.MapSeq (cvtAF1 afun) (asIdx x) ZipWithSeq afun x y -> producer $ AST.ZipWithSeq (cvtAF2 afun) (asIdx x) (asIdx y) ScanSeq fun e x -> producer $ AST.ScanSeq (cvtF2 fun) (cvtE e) (asIdx x) _ -> $internalError "convertSharingSeq:convSeq" "Consumer appears to have been let bound" where producer :: Arrays a => AST.Producer AST.OpenAcc aenv senv a -> AST.PreOpenSeq AST.OpenAcc aenv senv body producer p = AST.Producer p $ convertSharingSeq config alyt slyt' aenv (sa:senv) body where slyt' = incLayout slyt `PushLayout` ZeroIdx asIdx :: Arrays a => ScopedSeq [a] -> Idx senv a asIdx (ScopedSeq (SvarSharing sn)) | Just i <- findIndex (matchStableSeq sn) senv = prjIdx (ctxt ++ "; i = " ++ show i) i slyt | null senv = error $ "Cyclic definition of a value of type 'Seq' (sa = " ++ show (hashStableNameHeight sn) ++ ")" | otherwise = $internalError "convertSharingSeq" err where ctxt = "shared 'Seq' tree with stable name " ++ show (hashStableNameHeight sn) err = "inconsistent valuation @ " ++ ctxt ++ ";\n senv = " ++ show senv asIdx _ = $internalError "convertSharingSeq:asIdx" "Sequence computation not in A-normal form" cvtA :: forall a. Arrays a => ScopedAcc a -> AST.OpenAcc aenv a cvtA acc = convertSharingAcc config alyt aenv acc cvtE :: forall t. Elt t => ScopedExp t -> AST.Exp aenv t cvtE = convertSharingExp config EmptyLayout alyt [] aenv cvtF2 :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> ScopedExp c) -> AST.Fun aenv (a -> b -> c) cvtF2 = convertSharingFun2 config alyt aenv cvtAF1 :: forall a b. (Arrays a, Arrays b) => (Acc a -> ScopedAcc b) -> OpenAfun aenv (a -> b) cvtAF1 afun = convertSharingAfun1 config alyt aenv afun cvtAF2 :: forall a b c. 
(Arrays a, Arrays b, Arrays c) => (Acc a -> Acc b -> ScopedAcc c) -> OpenAfun aenv (a -> b -> c) cvtAF2 afun = convertSharingAfun2 config alyt aenv afun convertSharingSeq _ _ _ _ _ (ScopedSeq (SletSharing _ _)) = $internalError "convertSharingSeq" "Sequence computation not in A-normal form" convertSharingSeq config alyt slyt aenv senv s = cvtC s where cvtC :: ScopedSeq a -> AST.PreOpenSeq AST.OpenAcc aenv senv a cvtC (ScopedSeq (SeqSharing _ s)) = case s of FoldSeq fun e x -> AST.Consumer $ AST.FoldSeq (cvtF2 fun) (cvtE e) (asIdx x) FoldSeqFlatten afun acc x -> AST.Consumer $ AST.FoldSeqFlatten (cvtAF3 afun) (cvtA acc) (asIdx x) Stuple t -> AST.Consumer $ AST.Stuple (cvtST t) _ -> $internalError "convertSharingSeq" "Producer has not been let bound" cvtC _ = $internalError "convertSharingSeq" "Unreachable" asIdx :: Arrays a => ScopedSeq [a] -> Idx senv a asIdx (ScopedSeq (SvarSharing sn)) | Just i <- findIndex (matchStableSeq sn) senv = prjIdx (ctxt ++ "; i = " ++ show i) i slyt | null senv = error $ "Cyclic definition of a value of type 'Seq' (sa = " ++ show (hashStableNameHeight sn) ++ ")" | otherwise = $internalError "convertSharingSeq" err where ctxt = "shared 'Seq' tree with stable name " ++ show (hashStableNameHeight sn) err = "inconsistent valuation @ " ++ ctxt ++ ";\n senv = " ++ show senv asIdx _ = $internalError "convertSharingSeq:asIdx" "Sequence computation not in A-normal form" cvtA :: forall a. Arrays a => ScopedAcc a -> AST.OpenAcc aenv a cvtA acc = convertSharingAcc config alyt aenv acc cvtE :: forall t. Elt t => ScopedExp t -> AST.Exp aenv t cvtE = convertSharingExp config EmptyLayout alyt [] aenv cvtF2 :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> ScopedExp c) -> AST.Fun aenv (a -> b -> c) cvtF2 = convertSharingFun2 config alyt aenv cvtAF3 :: forall a b c d. 
(Arrays a, Arrays b, Arrays c, Arrays d) => (Acc a -> Acc b -> Acc c -> ScopedAcc d) -> OpenAfun aenv (a -> b -> c -> d)
    cvtAF3 afun = convertSharingAfun3 config alyt aenv afun

    cvtST :: Atuple ScopedSeq t -> Atuple (AST.Consumer AST.OpenAcc aenv senv) t
    cvtST NilAtup = NilAtup
    cvtST (SnocAtup t c)
      | AST.Consumer c' <- cvtC c = SnocAtup (cvtST t) c'
      | otherwise = $internalError "convertSharingSeq" "Unreachable"
--}

-- Convert a unary function over array computations. The layout is extended by
-- one entry for the parameter; the HOAS function is applied to 'undefined'
-- because only the already scope-annotated body it returns is inspected.
--
convertSharingAfun1
    :: forall aenv a b. (Arrays a, Arrays b)
    => Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> (Acc a -> ScopedAcc b)
    -> OpenAfun aenv (a -> b)
convertSharingAfun1 config alyt aenv f =
  let extended = incLayout alyt `PushLayout` ZeroIdx
  in  Alam (Abody (convertSharingAcc config extended aenv (f undefined)))

{--
convertSharingAfun2
    :: forall aenv a b c. (Arrays a, Arrays b, Arrays c)
    => Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> (Acc a -> Acc b -> ScopedAcc c)
    -> OpenAfun aenv (a -> b -> c)
convertSharingAfun2 config alyt aenv f
  = Alam (Alam (Abody (convertSharingAcc config alyt' aenv body)))
  where
    alyt' = incLayout (incLayout alyt `PushLayout` ZeroIdx) `PushLayout` ZeroIdx
    body = f undefined undefined

convertSharingAfun3
    :: forall aenv a b c d. (Arrays a, Arrays b, Arrays c, Arrays d)
    => Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> (Acc a -> Acc b -> Acc c -> ScopedAcc d)
    -> OpenAfun aenv (a -> b -> c -> d)
convertSharingAfun3 config alyt aenv f
  = Alam (Alam (Alam (Abody (convertSharingAcc config alyt' aenv body))))
  where
    alyt' = incLayout (incLayout (incLayout alyt `PushLayout` ZeroIdx) `PushLayout` ZeroIdx) `PushLayout` ZeroIdx
    body = f undefined undefined undefined
--}

-- Convert every component of a tuple of array computations, using the same
-- layout and sharing environment for each component.
--
convertSharingAtuple
    :: forall aenv a.
       Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> Atuple ScopedAcc a
    -> Atuple (AST.OpenAcc aenv) a
convertSharingAtuple config alyt aenv = go
  where
    go :: Atuple ScopedAcc tup -> Atuple (AST.OpenAcc aenv) tup
    go atup =
      case atup of
        NilAtup -> NilAtup
        SnocAtup rest a -> SnocAtup (go rest) (convertSharingAcc config alyt aenv a)

-- | Convert a stencil boundary condition. Only the 'Function' boundary
-- contains a term that needs converting; a 'Constant' is turned into its
-- element representation, and the remaining cases map one-to-one.
--
convertSharingBoundary
    :: forall aenv t.
       Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> PreBoundary ScopedAcc ScopedExp t
    -> AST.PreBoundary AST.OpenAcc aenv t
convertSharingBoundary config alyt aenv bndy =
  case bndy of
    Clamp -> AST.Clamp
    Mirror -> AST.Mirror
    Wrap -> AST.Wrap
    Constant v -> AST.Constant (fromElt v)
    Function f -> AST.Function (convertSharingFun1 config alyt aenv f)


-- Smart constructors to represent AST forms
--

-- Build an 'AST.Slice' node; the slice-index witness is derived from the
-- type @slix@ alone, so no value of that type is ever inspected.
--
mkIndex :: forall slix e aenv. (Slice slix, Elt e)
        => AST.OpenAcc aenv (Array (FullShape slix) e)
        -> AST.Exp aenv slix
        -> AST.PreOpenAcc AST.OpenAcc aenv (Array (SliceShape slix) e)
mkIndex = AST.Slice (sliceIndex (undefined :: slix))

-- Build an 'AST.Replicate' node; as for 'mkIndex', the slice-index witness
-- is derived from the type @slix@ alone.
--
mkReplicate :: forall slix e aenv. (Slice slix, Elt e)
            => AST.Exp aenv slix
            -> AST.OpenAcc aenv (Array (SliceShape slix) e)
            -> AST.PreOpenAcc AST.OpenAcc aenv (Array (FullShape slix) e)
mkReplicate = AST.Replicate (sliceIndex (undefined :: slix))

-- mkToSeq :: forall slsix slix e aenv senv. (Division slsix, DivisionSlice slsix ~ slix, Elt e, Elt slix, Slice slix)
--         => slsix
--         -> AST.OpenAcc aenv (Array (FullShape slix) e)
--         -> AST.Producer AST.OpenAcc aenv senv (Array (SliceShape slix) e)
-- mkToSeq _ = AST.ToSeq (sliceIndex slix) (Proxy :: Proxy slix)
--   where
--     slix = undefined :: slix


-- Scalar functions
-- ----------------

-- | Convert a closed scalar function to de Bruijn form while incorporating
-- sharing information.
--
-- The current design requires all free variables to be bound at the outermost
-- level --- we have no general apply term, and so lambdas are always outermost.
-- In higher-order abstract syntax, this represents an n-ary, polyvariadic
-- function.
--
convertFun :: Function f => Bool -> f -> AST.Fun () (FunctionR f)
convertFun shareExp = convert config EmptyLayout
  where
    -- Only scalar sharing recovery is configurable here; everything that
    -- concerns array or sequence computations is switched off.
    config = Config
      { recoverAccSharing = False
      , recoverExpSharing = shareExp
      , recoverSeqSharing = False
      , floatOutAcc = False
      }

-- HOAS scalar functions of any arity, converted one binder at a time.
--
class Function f where
  type FunctionR f
  convert :: Config -> Layout env env -> f -> AST.OpenFun env () (FunctionR f)

-- One more lambda: the parameter is represented by a 'Tag' whose level is the
-- current size of the layout, and the layout grows by one entry.
--
instance (Elt a, Function r) => Function (Exp a -> r) where
  type FunctionR (Exp a -> r) = a -> FunctionR r
  --
  convert config lyt f = Lam (convert config lyt' (f x))
    where
      x = Exp (Tag (sizeLayout lyt))
      lyt' = incLayout lyt `PushLayout` ZeroIdx

-- The function body: every level below the layout size is passed on as the
-- tag of a bound variable, highest level first.
--
instance Elt b => Function (Exp b) where
  type FunctionR (Exp b) = b
  --
  convert config lyt body = Body (convertOpenExp config lvl vars lyt body)
    where
      lvl = sizeLayout lyt
      vars = [lvl-1, lvl-2 .. 0]


-- Scalar expressions
-- ------------------

-- | Convert a closed scalar expression to de Bruijn form while incorporating
-- sharing information.
--
convertExp
    :: Elt e
    => Bool -- ^ recover sharing of scalar expressions ?
    -> Exp e -- ^ expression to be converted
    -> AST.Exp () e
convertExp shareExp exp = convertOpenExp config 0 [] EmptyLayout exp
  where
    config = Config
      { recoverAccSharing = False
      , recoverExpSharing = shareExp
      , recoverSeqSharing = False
      , floatOutAcc = False
      }

-- Convert an open scalar expression: recover sharing first, then convert the
-- annotated term starting from the sharing environment that recovery returns.
-- No array variables are in scope (empty array layout and environment).
--
convertOpenExp
    :: Elt e
    => Config
    -> Level -- level of currently bound scalar variables
    -> [Level] -- tags of bound scalar variables
    -> Layout env env
    -> Exp e
    -> AST.OpenExp env () e
convertOpenExp config lvl fvar lyt exp =
  convertSharingExp config lyt EmptyLayout initialEnv [] sharingExp
  where
    (sharingExp, initialEnv) = recoverSharingExp config lvl fvar exp

-- | Convert an open expression with given environment layouts and sharing information into
-- de Bruijn form while recovering sharing at the same time (by introducing appropriate let
-- bindings).  The latter implements the third phase of sharing recovery.
-- -- The sharing environments 'env' and 'aenv' keep track of all currently bound sharing variables, -- keeping them in reverse chronological order (outermost variable is at the end of the list). -- convertSharingExp :: forall t env aenv. Elt t => Config -> Layout env env -- scalar environment -> Layout aenv aenv -- array environment -> [StableSharingExp] -- currently bound sharing variables of expressions -> [StableSharingAcc] -- currently bound sharing variables of array computations -> ScopedExp t -- expression to be converted -> AST.OpenExp env aenv t convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp where -- scalar environment with any lambda bound variables this expression is rooted in env' = lams ++ env cvt :: Elt t' => ScopedExp t' -> AST.OpenExp env aenv t' cvt (ScopedExp _ (VarSharing se)) | Just i <- findIndex (matchStableExp se) env' = AST.Var (prjIdx (ctx i) i lyt) | otherwise = $internalError "convertSharingExp" msg where ctx i = printf "shared 'Exp' tree with stable name %d; i=%d" (hashStableNameHeight se) i msg = unlines [ if null env' then printf "cyclic definition of a value of type 'Exp' (sa=%d)" (hashStableNameHeight se) else printf "inconsistent valuation at shared 'Exp' tree (sa=%d; env=%s)" (hashStableNameHeight se) (show env') , "" , "Note that this error usually arises due to the presence of nested data" , "parallelism; when a parallel computation attempts to initiate new parallel" , "work _which depends on_ a scalar variable given by the first computation." , "" , "For example, suppose we wish to sum the columns of a two-dimensional array." , "You might think to do this in the following (incorrect) way: by constructing" , "a vector using 'generate' where at each index we 'slice' out the" , "corresponding column of the matrix and 'sum' it:" , "" , "> sum_columns_ndp :: Num a => Acc (Matrix a) -> Acc (Vector a)" , "> sum_columns_ndp mat =" , "> let Z :. rows :. cols = unlift (shape mat) :: Z :. Exp Int :. 
Exp Int" , "> in generate (index1 cols)" , "> (\\col -> the $ sum (slice mat (lift (Z :. All :. unindex1 col))))" , "" , "However, since both 'generate' and 'slice' are data-parallel operators, and" , "moreover that 'slice' _depends on_ the argument 'col' given to it by the" , "'generate' function, this operation requires nested parallelism and is thus" , "not (at this time) permitted. The clue that this definition is invalid is" , "that in order to create a program which will be accepted by the type checker," , "we had to use the function 'the' to retrieve the result of the parallel" , "'sum', effectively concealing that this is a collective operation in order to" , "match the type expected by 'generate'." , "" , "To solve this particular example, we can make use of the fact that (most)" , "collective operations in Accelerate are _rank polymorphic_. The 'sum'" , "operation reduces along the innermost dimension of an array of arbitrary" , "rank, reducing the dimensionality of the array by one. To reduce the array" , "column-wise then, we first need to simply 'transpose' the array:" , "" , "> sum_columns :: Num a => Acc (Matrix a) -> Acc (Vector a)" , "> sum_columns = sum . transpose" , "" , "If you feel like this is not the cause of your error, or you would like some" , "advice locating the problem and perhaps with a workaround, feel free to" , "submit an issue at the above URL." 
] cvt (ScopedExp _ (LetSharing se@(StableSharingExp _ boundExp) bodyExp)) = let lyt' = incLayout lyt `PushLayout` ZeroIdx in AST.Let (cvt (ScopedExp [] boundExp)) (convertSharingExp config lyt' alyt (se:env') aenv bodyExp) cvt (ScopedExp _ (ExpSharing _ pexp)) = case pexp of Tag i -> AST.Var (prjIdx ("de Bruijn conversion tag " ++ show i) i lyt) Const v -> AST.Const (fromElt v) Undef -> AST.Undef Tuple tup -> AST.Tuple (cvtT tup) Prj idx e -> AST.Prj idx (cvt e) IndexNil -> AST.IndexNil IndexCons ix i -> AST.IndexCons (cvt ix) (cvt i) IndexHead i -> AST.IndexHead (cvt i) IndexTail ix -> AST.IndexTail (cvt ix) IndexAny -> AST.IndexAny ToIndex sh ix -> AST.ToIndex (cvt sh) (cvt ix) FromIndex sh e -> AST.FromIndex (cvt sh) (cvt e) Cond e1 e2 e3 -> AST.Cond (cvt e1) (cvt e2) (cvt e3) While p it i -> AST.While (cvtFun1 p) (cvtFun1 it) (cvt i) PrimConst c -> AST.PrimConst c PrimApp f e -> cvtPrimFun f (cvt e) Index a e -> AST.Index (cvtA a) (cvt e) LinearIndex a i -> AST.LinearIndex (cvtA a) (cvt i) Shape a -> AST.Shape (cvtA a) ShapeSize e -> AST.ShapeSize (cvt e) Intersect sh1 sh2 -> AST.Intersect (cvt sh1) (cvt sh2) Union sh1 sh2 -> AST.Union (cvt sh1) (cvt sh2) Foreign ff f e -> AST.Foreign ff (convertFun (recoverExpSharing config) f) (cvt e) Coerce e -> AST.Coerce (cvt e) cvtA :: Arrays a => ScopedAcc a -> AST.OpenAcc aenv a cvtA = convertSharingAcc config alyt aenv cvtT :: Tuple ScopedExp tup -> Tuple (AST.OpenExp env aenv) tup cvtT = convertSharingTuple config lyt alyt env' aenv cvtFun1 :: (Elt a, Elt b) => (Exp a -> ScopedExp b) -> AST.OpenFun env aenv (a -> b) cvtFun1 f = Lam (Body (convertSharingExp config lyt' alyt env' aenv body)) where lyt' = incLayout lyt `PushLayout` ZeroIdx body = f undefined -- Push primitive function applications down through let bindings so that -- they are adjacent to their arguments. It looks a bit nicer this way. 
-- cvtPrimFun :: (Elt a, Elt r) => AST.PrimFun (a -> r) -> AST.OpenExp env' aenv' a -> AST.OpenExp env' aenv' r cvtPrimFun f e = case e of AST.Let bnd body -> AST.Let bnd (cvtPrimFun f body) x -> AST.PrimApp f x -- | Convert a tuple expression -- convertSharingTuple :: Config -> Layout env env -> Layout aenv aenv -> [StableSharingExp] -- currently bound scalar sharing-variables -> [StableSharingAcc] -- currently bound array sharing-variables -> Tuple ScopedExp t -> Tuple (AST.OpenExp env aenv) t convertSharingTuple config lyt alyt env aenv tup = case tup of NilTup -> NilTup SnocTup t e -> convertSharingTuple config lyt alyt env aenv t `SnocTup` convertSharingExp config lyt alyt env aenv e -- | Convert a unary functions -- convertSharingFun1 :: forall a b aenv. (Elt a, Elt b) => Config -> Layout aenv aenv -> [StableSharingAcc] -- currently bound array sharing-variables -> (Exp a -> ScopedExp b) -> AST.Fun aenv (a -> b) convertSharingFun1 config alyt aenv f = Lam (Body openF) where a = Exp undefined -- the 'tag' was already embedded in Phase 1 lyt = EmptyLayout `PushLayout` (ZeroIdx :: Idx ((), a) a) openF = convertSharingExp config lyt alyt [] aenv (f a) -- | Convert a binary functions -- convertSharingFun2 :: forall a b c aenv. (Elt a, Elt b, Elt c) => Config -> Layout aenv aenv -> [StableSharingAcc] -- currently bound array sharing-variables -> (Exp a -> Exp b -> ScopedExp c) -> AST.Fun aenv (a -> b -> c) convertSharingFun2 config alyt aenv f = Lam (Lam (Body openF)) where a = Exp undefined b = Exp undefined lyt = EmptyLayout `PushLayout` (SuccIdx ZeroIdx :: Idx (((), a), b) a) `PushLayout` (ZeroIdx :: Idx (((), a), b) b) openF = convertSharingExp config lyt alyt [] aenv (f a b) -- | Convert a unary stencil function -- convertSharingStencilFun1 :: forall sh a stencil b aenv. 
(Elt a, Stencil sh a stencil, Elt b) => Config -> ScopedAcc (Array sh a) -- just passed to fix the type variables -> Layout aenv aenv -> [StableSharingAcc] -- currently bound array sharing-variables -> (stencil -> ScopedExp b) -> AST.Fun aenv (StencilRepr sh stencil -> b) convertSharingStencilFun1 config _ alyt aenv stencilFun = Lam (Body openStencilFun) where stencil = Exp undefined :: Exp (StencilRepr sh stencil) lyt = EmptyLayout `PushLayout` (ZeroIdx :: Idx ((), StencilRepr sh stencil) (StencilRepr sh stencil)) body = stencilFun (stencilPrj (undefined::sh) (undefined::a) stencil) openStencilFun = convertSharingExp config lyt alyt [] aenv body -- | Convert a binary stencil function -- convertSharingStencilFun2 :: forall sh a b stencil1 stencil2 c aenv. (Elt a, Stencil sh a stencil1, Elt b, Stencil sh b stencil2, Elt c) => Config -> ScopedAcc (Array sh a) -- just passed to fix the type variables -> ScopedAcc (Array sh b) -- just passed to fix the type variables -> Layout aenv aenv -> [StableSharingAcc] -- currently bound array sharing-variables -> (stencil1 -> stencil2 -> ScopedExp c) -> AST.Fun aenv (StencilRepr sh stencil1 -> StencilRepr sh stencil2 -> c) convertSharingStencilFun2 config _ _ alyt aenv stencilFun = Lam (Lam (Body openStencilFun)) where stencil1 = Exp undefined :: Exp (StencilRepr sh stencil1) stencil2 = Exp undefined :: Exp (StencilRepr sh stencil2) lyt = EmptyLayout `PushLayout` (SuccIdx ZeroIdx :: Idx (((), StencilRepr sh stencil1), StencilRepr sh stencil2) (StencilRepr sh stencil1)) `PushLayout` (ZeroIdx :: Idx (((), StencilRepr sh stencil1), StencilRepr sh stencil2) (StencilRepr sh stencil2)) body = stencilFun (stencilPrj (undefined::sh) (undefined::a) stencil1) (stencilPrj (undefined::sh) (undefined::b) stencil2) openStencilFun = convertSharingExp config lyt alyt [] aenv body -- Sharing recovery -- ================ -- Sharing recovery proceeds in two phases: -- -- /Phase One: build the occurrence map/ -- -- This is a top-down traversal of 
the AST that computes a map from AST nodes to the number of -- occurrences of that AST node in the overall Accelerate program. An occurrences count of two or -- more indicates sharing. -- -- IMPORTANT: To avoid unfolding the sharing, we do not descent into subtrees that we have -- previously encountered. Hence, the complexity is proportional to the number of nodes in the -- tree /with/ sharing. Consequently, the occurrence count is that in the tree with sharing -- as well. -- -- During computation of the occurrences, the tree is annotated with stable names on every node -- using 'AccSharing' constructors and all but the first occurrence of shared subtrees are pruned -- using 'AvarSharing' constructors (see 'SharingAcc' below). This phase is impure as it is based -- on stable names. -- -- We use a hash table (instead of 'Data.Map') as computing stable names forces us to live in IO -- anyway. Once, the computation of occurrence counts is complete, we freeze the hash table into -- a 'Data.Map'. -- -- (Implemented by 'makeOccMap*'.) -- -- /Phase Two: determine scopes and inject sharing information/ -- -- This is a bottom-up traversal that determines the scope for every binding to be introduced -- to share a subterm. It uses the occurrence map to determine, for every shared subtree, the -- lowest AST node at which the binding for that shared subtree can be placed (using a -- 'AletSharing' constructor)— it's the meet of all the shared subtree occurrences. -- -- The second phase is also replacing the first occurrence of each shared subtree with a -- 'AvarSharing' node and floats the shared subtree up to its binding point. -- -- (Implemented by 'determineScopes*'.) -- -- /Sharing recovery for expressions/ -- -- We recover sharing for each expression (including function bodies) independently of any other -- expression — i.e., we cannot share scalar expressions across array computations. 
-- Hence, during
-- Phase One, we mark all scalar expression nodes with a stable name and compute one occurrence map
-- for every scalar expression (including functions) that occurs in an array computation.  These
-- occurrence maps are added to the root of scalar expressions using 'RootExp'.
--
-- NB: We do not need to worry that sharing recovery will try to float a shared subexpression past
--     a binder that occurs in that subexpression.  Why?  Otherwise, the binder would already occur
--     out of scope in the original source program.
--
-- /Lambda bound variables/
--
-- During sharing recovery, lambda bound variables appear in the form of 'Atag' and 'Tag' data
-- constructors.  The tag values are determined during Phase One of sharing recovery by computing
-- the /level/ of each variable at its binding occurrence.  The level at the root of the AST is 0
-- and increases by one with each lambda on each path through the AST.

-- Stable names
-- ------------

-- Opaque stable name for AST nodes — used to key the occurrence map.  The type of the node that
-- the name was taken from is hidden existentially; only the node family 'c' remains visible.
--
data StableASTName c where
  StableASTName :: (Typeable c, Typeable t) => StableName (c t) -> StableASTName c

-- Rendered as the integer hash of the wrapped stable name.
--
instance Show (StableASTName c) where
  show (StableASTName name) = show (hashStableName name)

-- Two names are equal only if the hidden node types coincide ('gcast' succeeds) and the
-- underlying stable names are equal.
--
instance Eq (StableASTName c) where
  StableASTName lhs == StableASTName rhs = maybe False (== rhs) (gcast lhs)

-- Hashing is delegated to the wrapped 'StableName'.
--
instance Hashable (StableASTName c) where
  hashWithSalt salt (StableASTName name) = salt `hashWithSalt` name

-- Take the stable name of an AST node, after first evaluating the node to WHNF.
--
makeStableAST :: c t -> IO (StableName (c t))
makeStableAST !node = makeStableName node

-- Stable name for an AST node including the height of the AST representing the array computation.
--
data StableNameHeight t = StableNameHeight (StableName t) Int

-- Equality looks at the stable name only; the recorded height is ignored.
--
instance Eq (StableNameHeight t) where
  StableNameHeight a _ == StableNameHeight b _ = a == b

-- 'True' iff the first name was recorded with a strictly greater height than the second.
--
higherSNH :: StableNameHeight t1 -> StableNameHeight t2 -> Bool
higherSNH (StableNameHeight _ lhs) (StableNameHeight _ rhs) = lhs > rhs

-- Integer hash of the wrapped stable name (the height does not contribute).
--
hashStableNameHeight :: StableNameHeight t -> Int
hashStableNameHeight (StableNameHeight name _) = hashStableName name

-- Mutable occurrence map
-- ----------------------

-- Hash table keyed on the stable names of array computations.
--
type HashTable key val = Hash.BasicHashTable key val
type ASTHashTable c v  = HashTable (StableASTName c) v

-- Mutable hashtable version of the occurrence map, which associates each AST node with an
-- occurrence count and the height of the AST.
--
type OccMapHash c = ASTHashTable c (Int, Int)

-- Create a new hash table keyed on AST nodes.
--
newASTHashTable :: IO (ASTHashTable c v)
newASTHashTable = Hash.new

-- Enter one AST node occurrence into an occurrence map.  Returns 'Just h' if this is a repeated
-- occurrence and the height of the repeatedly occurring AST is 'h'.
--
-- If this is the first occurrence, the 'height' *argument* must provide the height of the AST;
-- otherwise, the height will be *extracted* from the occurrence map.  In the latter case, this
-- function yields the AST height.
--
-- The 'height' argument is only stored, never inspected, so callers may pass a value that is
-- not yet available (as the 'mfix'-based traversals do).
--
enterOcc :: OccMapHash c -> StableASTName c -> Int -> IO (Maybe Int)
enterOcc occMap sa height = Hash.mutate occMap sa bump
  where
    -- first sighting: record count 1 together with the supplied height
    bump Nothing               = (Just (1, height), Nothing)
    -- repeated sighting: bump the count, keep and report the stored height
    bump (Just (count, known)) = (Just (count+1, known), Just known)

-- Immutable occurrence map
-- ------------------------

-- Immutable version of the occurrence map (storing the occurrence count only, not the height).  We
-- use the 'StableName' hash to index an 'IntMap' and disambiguate 'StableName's with identical
-- hashes explicitly, storing them in a list in the 'IntMap'.
--
type OccMap c = IntMap.IntMap [(StableASTName c, Int)]

-- Turn a mutable into an immutable occurrence map.
--
-- All entries whose stable names share a hash are collected into the same 'IntMap' bucket.
-- The buckets are built with 'IntMap.fromListWith', so entries with colliding hashes are merged
-- regardless of the order in which the hash table enumerates them.  (Grouping only /adjacent/
-- entries and then calling 'IntMap.fromList' would silently drop all but one group per hash.)
--
freezeOccMap :: OccMapHash c -> IO (OccMap c)
freezeOccMap oc
  = do
      ocl <- Hash.toList oc
      traceChunk "OccMap" (show ocl)

      return $ IntMap.fromListWith (++) [ (key kv, [kv]) | kv <- map dropHeight ocl ]
  where
    -- bucket index: the hash of the stable name
    key (StableASTName sn, _) = hashStableName sn
    -- the immutable map stores the occurrence count only
    dropHeight (k, (cnt, _))  = (k, cnt)

-- Look up the occurrence map keyed by array computations using a stable name.  If the key does
-- not exist in the map, return an occurrence count of '1'.
--
-- The bucket is selected by the stable name's hash; the entry inside the bucket is then found
-- by comparing the stable names themselves.
--
lookupWithASTName :: OccMap c -> StableASTName c -> Int
lookupWithASTName oc sa@(StableASTName sn)
  = fromMaybe 1 $ IntMap.lookup (hashStableName sn) oc >>= Prelude.lookup sa

-- Look up the occurrence map keyed by array computations using a sharing array computation.  If
-- the key does not exist in the map, return an occurrence count of '1'.
--
lookupWithSharingAcc :: OccMap Acc -> StableSharingAcc -> Int
lookupWithSharingAcc oc (StableSharingAcc (StableNameHeight sn _) _)
  = lookupWithASTName oc (StableASTName sn)

-- Look up the occurrence map keyed by scalar expressions using a sharing expression.  If
-- the key does not exist in the map, return an occurrence count of '1'.
--
lookupWithSharingExp :: OccMap Exp -> StableSharingExp -> Int
lookupWithSharingExp oc (StableSharingExp (StableNameHeight sn _) _)
  = lookupWithASTName oc (StableASTName sn)

-- Stable 'Acc' nodes
-- ------------------

-- Stable name for 'Acc' nodes including the height of the AST.
--
type StableAccName arrs = StableNameHeight (Acc arrs)

-- Interleave sharing annotations into an array computation AST.  Subtrees can be marked as being
-- represented by variable (binding a shared subtree) using 'AvarSharing' and as being prefixed by
-- a let binding (for a shared subtree) using 'AletSharing'.
--
data SharingAcc acc exp arrs where
  -- reference to a shared subtree, identified by its stable name
  AvarSharing :: Arrays arrs
              => StableAccName arrs
              -> SharingAcc acc exp arrs

  -- binding of a shared subtree in front of a body
  AletSharing :: StableSharingAcc
              -> acc arrs
              -> SharingAcc acc exp arrs

  -- ordinary node, tagged with its stable name
  AccSharing  :: Arrays arrs
              => StableAccName arrs
              -> PreAcc acc exp arrs
              -> SharingAcc acc exp arrs

-- Array expression with sharing but shared values have not been scoped; i.e. no let bindings.  If
-- the expression is rooted in a function, the list contains the tags of the variables bound by the
-- immediate surrounding lambdas.
data UnscopedAcc t = UnscopedAcc [Int] (SharingAcc UnscopedAcc RootExp t)

-- Array expression with sharing.  For expressions rooted in functions the list holds a sorted
-- environment corresponding to the variables bound in the immediate surrounding lambdas.
data ScopedAcc t = ScopedAcc [StableSharingAcc] (SharingAcc ScopedAcc ScopedExp t)

-- Stable name for an array computation associated with its sharing-annotated version.
--
data StableSharingAcc where
  StableSharingAcc :: Arrays arrs
                   => StableAccName arrs
                   -> SharingAcc ScopedAcc ScopedExp arrs
                   -> StableSharingAcc

-- Rendered as the integer hash of the stable name; the annotated tree is not shown.
--
instance Show StableSharingAcc where
  show (StableSharingAcc name _) = show (hashStableNameHeight name)

-- Equal iff the array types agree ('gcast' succeeds) and the stable names are equal.  The
-- attached sharing-annotated computations are not compared.
--
instance Eq StableSharingAcc where
  StableSharingAcc lhs _ == StableSharingAcc rhs _ =
    case gcast lhs of
      Just lhs' -> lhs' == rhs
      Nothing   -> False

-- Compare two shared array computations by the height recorded in their stable names.
--
higherSSA :: StableSharingAcc -> StableSharingAcc -> Bool
higherSSA (StableSharingAcc lhs _) (StableSharingAcc rhs _) = higherSNH lhs rhs

-- Test whether the given stable names matches an array computation with sharing.
--
matchStableAcc :: Typeable arrs => StableAccName arrs -> StableSharingAcc -> Bool
matchStableAcc sn1 (StableSharingAcc sn2 _) = maybe False (== sn2) (gcast sn1)

-- Dummy entry for environments to be used for unused variables.
--
-- The stable name is taken of an 'undefined' thunk ('makeStableName' does not evaluate its
-- argument) and paired with height 0.  The NOINLINE pragma keeps the 'unsafePerformIO' call from
-- being duplicated at use sites, so the dummy name is allocated at most once.
--
{-# NOINLINE noStableAccName #-}
noStableAccName :: StableAccName arrs
noStableAccName = unsafePerformIO $ StableNameHeight <$> makeStableName undefined <*> pure 0

-- Stable 'Exp' nodes
-- ------------------

-- Stable name for 'Exp' nodes including the height of the AST.
--
type StableExpName t = StableNameHeight (Exp t)

-- Interleave sharing annotations into a scalar expressions AST in the same manner as 'SharingAcc'
-- do for array computations.
--
data SharingExp (acc :: * -> *) exp t where
  -- reference to a shared subexpression, identified by its stable name
  VarSharing :: Elt t
             => StableExpName t
             -> SharingExp acc exp t

  -- binding of a shared subexpression in front of a body
  LetSharing :: StableSharingExp
             -> exp t
             -> SharingExp acc exp t

  -- ordinary node, tagged with its stable name
  ExpSharing :: Elt t
             => StableExpName t
             -> PreExp acc exp t
             -> SharingExp acc exp t

-- Specifies a scalar expression AST with sharing annotations but no scoping; i.e. no LetSharing
-- constructors.  If the expression is rooted in a function, the list contains the tags of the
-- variables bound by the immediate surrounding lambdas.
data UnscopedExp t = UnscopedExp [Int] (SharingExp UnscopedAcc UnscopedExp t)

-- Specifies a scalar expression AST with sharing.  For expressions rooted in functions the list
-- holds a sorted environment corresponding to the variables bound in the immediate surrounding
-- lambdas.
data ScopedExp t = ScopedExp [StableSharingExp] (SharingExp ScopedAcc ScopedExp t)

-- Expressions rooted in 'Acc' computations.
--
-- * When counting occurrences, the root of every expression embedded in an 'Acc' is annotated by
--   an occurrence map for that one expression (excluding any subterms that are rooted in embedded
--   'Acc's.)
--
data RootExp t = RootExp (OccMap Exp) (UnscopedExp t)

-- Stable name for an expression associated with its sharing-annotated version.
--
data StableSharingExp where
  StableSharingExp :: Elt t
                   => StableExpName t
                   -> SharingExp ScopedAcc ScopedExp t
                   -> StableSharingExp

-- Rendered as the integer hash of the stable name; the annotated expression is not shown.
--
instance Show StableSharingExp where
  show (StableSharingExp sn _) = show $ hashStableNameHeight sn

-- Equal iff the element types agree ('gcast' succeeds) and the stable names are equal.  The
-- attached sharing-annotated expressions are not compared.
--
instance Eq StableSharingExp where
  StableSharingExp sn1 _ == StableSharingExp sn2 _
    | Just sn1' <- gcast sn1 = sn1' == sn2
    | otherwise              = False

-- Compare two shared expressions by the height recorded in their stable names.
--
higherSSE :: StableSharingExp -> StableSharingExp -> Bool
StableSharingExp sn1 _ `higherSSE` StableSharingExp sn2 _ = sn1 `higherSNH` sn2

-- Test whether the given stable names matches an expression with sharing.
--
matchStableExp :: Typeable t => StableExpName t -> StableSharingExp -> Bool
matchStableExp sn1 (StableSharingExp sn2 _)
  | Just sn1' <- gcast sn1 = sn1' == sn2
  | otherwise              = False

-- Dummy entry for environments to be used for unused variables.
--
-- The stable name is taken of an 'undefined' thunk ('makeStableName' does not evaluate its
-- argument) and paired with height 0.  The NOINLINE pragma keeps the 'unsafePerformIO' call from
-- being duplicated at use sites, so the dummy name is allocated at most once.
--
{-# NOINLINE noStableExpName #-}
noStableExpName :: StableExpName t
noStableExpName = unsafePerformIO $ StableNameHeight <$> makeStableName undefined <*> pure 0

{--
-- Stable 'Seq' nodes
-- ------------------

-- Stable name for 'Seq' nodes including the height of the AST.
--
type StableSeqName arrs = StableNameHeight (Seq arrs)

-- Interleave sharing annotations into an sequence computation AST in the same manner as SharingAcc
-- and SharingExp
--
data SharingSeq acc seq exp arrs where
  SvarSharing :: (Typeable arrs, Arrays arrs) => StableSeqName [arrs] -> SharingSeq acc seq exp [arrs]
  SletSharing :: StableSharingSeq -> seq t -> SharingSeq acc seq exp t
  SeqSharing  :: Typeable arrs => StableSeqName arrs -> PreSeq acc seq exp arrs -> SharingSeq acc seq exp arrs

-- Array expression with sharing but shared values have not been scoped; i.e. no let bindings. If
-- the expression is rooted in a function, the list contains the tags of the variables bound by the
-- immediate surrounding lambdas.
data UnscopedSeq t = UnscopedSeq (SharingSeq UnscopedAcc UnscopedSeq RootExp t)

-- Array expression with sharing.
For expressions rooted in functions the list holds a sorted -- environment corresponding to the variables bound in the immediate surounding lambdas. data ScopedSeq t = ScopedSeq (SharingSeq ScopedAcc ScopedSeq ScopedExp t) -- Sequences rooted in 'Acc' computations. -- -- * When counting occurrences, the root of every sequence embedded in an 'Acc' is annotated by -- an occurrence map for that one expression (excluding any subterms that are rooted in embedded -- 'Acc's.) -- data RootSeq t = RootSeq (OccMap Seq) (UnscopedSeq t) -- Stable name for an array computation associated with its sharing-annotated version. -- data StableSharingSeq where StableSharingSeq :: Typeable arrs => StableSeqName arrs -> SharingSeq ScopedAcc ScopedSeq ScopedExp arrs -> StableSharingSeq instance Show StableSharingSeq where show (StableSharingSeq sn _) = show $ hashStableNameHeight sn instance Eq StableSharingSeq where StableSharingSeq sn1 _ == StableSharingSeq sn2 _ | Just sn1' <- gcast sn1 = sn1' == sn2 | otherwise = False higherSSS :: StableSharingSeq -> StableSharingSeq -> Bool StableSharingSeq sn1 _ `higherSSS` StableSharingSeq sn2 _ = sn1 `higherSNH` sn2 -- Test whether the given stable names matches an array computation with sharing. -- matchStableSeq :: Typeable arrs => StableSeqName arrs -> StableSharingSeq -> Bool matchStableSeq sn1 (StableSharingSeq sn2 _) | Just sn1' <- gcast sn1 = sn1' == sn2 | otherwise = False --} -- Occurrence counting -- =================== -- Compute the 'Acc' occurrence map, marks all nodes (both 'Seq' and 'Exp' nodes) with stable names, -- and drop repeated occurrences of shared 'Acc' and 'Exp' subtrees (Phase One). -- -- We compute a single 'Acc' occurrence map for the whole AST, but one 'Exp' occurrence map for each -- sub-expression rooted in an 'Acc' operation. This is as we cannot float 'Exp' subtrees across -- 'Acc' operations, but we can float 'Acc' subtrees out of 'Exp' expressions. 
-- -- Note [Traversing functions and side effects] -- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -- We need to descend into function bodies to build the 'OccMap' with all occurrences in the -- function bodies. Due to the side effects in the construction of the occurrence map and, more -- importantly, the dependence of the second phase on /global/ occurrence information, we may not -- delay the body traversals by putting them under a lambda. Hence, we apply each function, to -- traverse its body and use a /dummy abstraction/ of the result. -- -- For example, given a function 'f', we traverse 'f (Tag 0)', which yields a transformed body 'e'. -- As the result of the traversal of the overall function, we use 'const e'. Hence, it is crucial -- that the 'Tag' supplied during the initial traversal is already the one required by the HOAS to -- de Bruijn conversion in 'convertSharingAcc' — any subsequent application of 'const e' will only -- yield 'e' with the embedded 'Tag 0' of the original application. During sharing recovery, we -- float /all/ free variables ('Atag' and 'Tag') out to construct the initial environment for -- producing de Bruijn indices, which replaces them by 'AvarSharing' or 'VarSharing' nodes. Hence, -- the tag values only serve the purpose of determining the ordering in that initial environment. -- They are /not/ directly used to compute the de Bruijn indices.
--
-- Entry point of Phase One for a closed array computation: allocate a fresh mutable occurrence
-- table, traverse the computation while filling the table, and return the annotated tree
-- together with the frozen (immutable) occurrence map.  The height computed for the root is
-- discarded here.
--
makeOccMapAcc
    :: Typeable arrs
    => Config
    -> Level
    -> Acc arrs
    -> IO (UnscopedAcc arrs, OccMap Acc)
makeOccMapAcc config lvl acc = do
  traceLine "makeOccMapAcc" "Enter"
  accOccMap             <- newASTHashTable
  (acc', _)             <- makeOccMapSharingAcc config accOccMap lvl acc
  frozenAccOccMap       <- freezeOccMap accOccMap
  traceLine "makeOccMapAcc" "Exit"
  return (acc', frozenAccOccMap)

-- Traverse an array computation, entering every node into the given occurrence table and
-- rebuilding it with sharing annotations.  The second component of the result is the height of
-- the traversed tree.
--
-- Heights: 'Atag' has height 0 and 'Use' height 1; boundary conditions other than 'Function'
-- contribute 0; every other node is one higher than the maximum over its children.
--
makeOccMapSharingAcc
    :: Typeable arrs
    => Config
    -> OccMapHash Acc
    -> Level
    -> Acc arrs
    -> IO (UnscopedAcc arrs, Int)
makeOccMapSharingAcc config accOccMap = traverseAcc
  where
    -- Traversals of embedded scalar functions/expressions and array functions, all sharing the
    -- same 'Acc' occurrence table.
    traverseFun1 :: (Elt a, Typeable b) => Level -> (Exp a -> Exp b) -> IO (Exp a -> RootExp b, Int)
    traverseFun1 = makeOccMapFun1 config accOccMap

    traverseFun2 :: (Elt a, Elt b, Typeable c)
                 => Level
                 -> (Exp a -> Exp b -> Exp c)
                 -> IO (Exp a -> Exp b -> RootExp c, Int)
    traverseFun2 = makeOccMapFun2 config accOccMap

    traverseAfun1 :: (Arrays a, Typeable b) => Level -> (Acc a -> Acc b) -> IO (Acc a -> UnscopedAcc b, Int)
    traverseAfun1 = makeOccMapAfun1 config accOccMap

    traverseExp :: Typeable e => Level -> Exp e -> IO (RootExp e, Int)
    traverseExp = makeOccMapExp config accOccMap

    -- Only the 'Function' boundary contains a subterm that needs traversing.
    traverseBoundary
        :: Level
        -> PreBoundary Acc Exp t
        -> IO (PreBoundary UnscopedAcc RootExp t, Int)
    traverseBoundary lvl bndy =
      case bndy of
        Clamp      -> return (Clamp, 0)
        Mirror     -> return (Mirror, 0)
        Wrap       -> return (Wrap, 0)
        Constant v -> return (Constant v, 0)
        Function f -> do
          (f', h) <- traverseFun1 lvl f
          return (Function f', h)

    -- traverseSeq :: forall arrs. Typeable arrs
    --             => Level -> Seq arrs
    --             -> IO (RootSeq arrs, Int)
    -- traverseSeq = makeOccMapRootSeq config accOccMap

    traverseAcc :: forall arrs. Typeable arrs => Level -> Acc arrs -> IO (UnscopedAcc arrs, Int)
    traverseAcc lvl acc@(Acc pacc)
      -- The node's own height is passed to 'enterOcc' before it has been computed; 'mfix' with a
      -- lazy pattern ties that knot.  'height' must therefore not be forced inside this block.
      = mfix $ \ ~(_, height) -> do
          -- Compute stable name and enter it into the occurrence map
          --
          sn                         <- makeStableAST acc
          heightIfRepeatedOccurrence <- enterOcc accOccMap (StableASTName sn) height

          traceLine (showPreAccOp pacc) $ do
            let hash = show (hashStableName sn)
            case heightIfRepeatedOccurrence of
              Just height -> "REPEATED occurrence (sn = " ++ hash ++ "; height = " ++ show height ++ ")"
              Nothing     -> "first occurrence (sn = " ++ hash ++ ")"

          -- Reconstruct the computation in shared form.
          --
          -- In case of a repeated occurrence, the height comes from the occurrence map; otherwise
          -- it is computed by the traversal function passed in 'newAcc'.  See also 'enterOcc'.
          --
          -- NB: This function can only be used in the case alternatives below; outside of the
          --     case we cannot discharge the 'Arrays arrs' constraint.
          --
          let reconstruct :: Arrays arrs
                          => IO (PreAcc UnscopedAcc RootExp arrs, Int)
                          -> IO (UnscopedAcc arrs, Int)
              reconstruct newAcc
                = case heightIfRepeatedOccurrence of
                    -- repeated occurrence with sharing recovery enabled: prune to a variable
                    Just height | recoverAccSharing config
                      -> return (UnscopedAcc [] (AvarSharing (StableNameHeight sn height)), height)
                    -- first occurrence, or sharing recovery disabled: traverse the children
                    _ -> do (acc, height) <- newAcc
                            return (UnscopedAcc [] (AccSharing (StableNameHeight sn height) acc), height)

          case pacc of
            Atag i                      -> reconstruct $ return (Atag i, 0)           -- height is 0!
            Pipe afun1 afun2 acc        -> reconstruct $ do
                                             (afun1', h1) <- traverseAfun1 lvl afun1
                                             (afun2', h2) <- traverseAfun1 lvl afun2
                                             (acc', h3)   <- traverseAcc lvl acc
                                             return (Pipe afun1' afun2' acc'
                                                    , h1 `max` h2 `max` h3 + 1)
            Aforeign ff afun acc        -> reconstruct $ travA (Aforeign ff afun) acc
            Acond e acc1 acc2           -> reconstruct $ do
                                             (e'   , h1) <- traverseExp lvl e
                                             (acc1', h2) <- traverseAcc lvl acc1
                                             (acc2', h3) <- traverseAcc lvl acc2
                                             return (Acond e' acc1' acc2', h1 `max` h2 `max` h3 + 1)
            Awhile pred iter init       -> reconstruct $ do
                                             (pred', h1) <- traverseAfun1 lvl pred
                                             (iter', h2) <- traverseAfun1 lvl iter
                                             (init', h3) <- traverseAcc lvl init
                                             return (Awhile pred' iter' init'
                                                    , h1 `max` h2 `max` h3 + 1)

            Atuple tup                  -> reconstruct $ do
                                             (tup', h) <- travAtup tup
                                             return (Atuple tup', h)
            Aprj ix a                   -> reconstruct $ travA (Aprj ix) a

            Use arr                     -> reconstruct $ return (Use arr, 1)
            Unit e                      -> reconstruct $ do
                                             (e', h) <- traverseExp lvl e
                                             return (Unit e', h + 1)
            Generate e f                -> reconstruct $ do
                                             (e', h1) <- traverseExp lvl e
                                             (f', h2) <- traverseFun1 lvl f
                                             return (Generate e' f', h1 `max` h2 + 1)
            Reshape e acc               -> reconstruct $ travEA Reshape e acc
            Replicate e acc             -> reconstruct $ travEA Replicate e acc
            Slice acc e                 -> reconstruct $ travEA (flip Slice) e acc
            Map f acc                   -> reconstruct $ do
                                             (f'  , h1) <- traverseFun1 lvl f
                                             (acc', h2) <- traverseAcc lvl acc
                                             return (Map f' acc', h1 `max` h2 + 1)
            ZipWith f acc1 acc2         -> reconstruct $ travF2A2 ZipWith f acc1 acc2
            Fold f e acc                -> reconstruct $ travF2EA Fold f e acc
            Fold1 f acc                 -> reconstruct $ travF2A Fold1 f acc
            FoldSeg f e acc1 acc2       -> reconstruct $ do
                                             (f'   , h1) <- traverseFun2 lvl f
                                             (e'   , h2) <- traverseExp lvl e
                                             (acc1', h3) <- traverseAcc lvl acc1
                                             (acc2', h4) <- traverseAcc lvl acc2
                                             return (FoldSeg f' e' acc1' acc2',
                                                     h1 `max` h2 `max` h3 `max` h4 + 1)
            Fold1Seg f acc1 acc2        -> reconstruct $ travF2A2 Fold1Seg f acc1 acc2
            Scanl f e acc               -> reconstruct $ travF2EA Scanl f e acc
            Scanl' f e acc              -> reconstruct $ travF2EA Scanl' f e acc
            Scanl1 f acc                -> reconstruct $ travF2A Scanl1 f acc
            Scanr f e acc               -> reconstruct $ travF2EA Scanr f e acc
            Scanr' f e acc              -> reconstruct $ travF2EA Scanr' f e acc
            Scanr1 f acc                -> reconstruct $ travF2A Scanr1 f acc
            Permute c acc1 p acc2       -> reconstruct $ do
                                             (c'   , h1) <- traverseFun2 lvl c
                                             (p'   , h2) <- traverseFun1 lvl p
                                             (acc1', h3) <- traverseAcc lvl acc1
                                             (acc2', h4) <- traverseAcc lvl acc2
                                             return (Permute c' acc1' p' acc2',
                                                     h1 `max` h2 `max` h3 `max` h4 + 1)
            Backpermute e p acc         -> reconstruct $ do
                                             (e'  , h1) <- traverseExp lvl e
                                             (p'  , h2) <- traverseFun1 lvl p
                                             (acc', h3) <- traverseAcc lvl acc
                                             return (Backpermute e' p' acc', h1 `max` h2 `max` h3 + 1)
            Stencil s bnd acc           -> reconstruct $ do
                                             (s'  , h1) <- makeOccMapStencil1 config accOccMap acc lvl s
                                             (bnd', h2) <- traverseBoundary lvl bnd
                                             (acc', h3) <- traverseAcc lvl acc
                                             return (Stencil s' bnd' acc', h1 `max` h2 `max` h3 + 1)
            Stencil2 s bnd1 acc1
                       bnd2 acc2        -> reconstruct $ do
                                             (s'   , h1) <- makeOccMapStencil2 config accOccMap acc1 acc2 lvl s
                                             (bnd1', h2) <- traverseBoundary lvl bnd1
                                             (acc1', h3) <- traverseAcc lvl acc1
                                             (bnd2', h4) <- traverseBoundary lvl bnd2
                                             (acc2', h5) <- traverseAcc lvl acc2
                                             return (Stencil2 s' bnd1' acc1' bnd2' acc2',
                                                     h1 `max` h2 `max` h3 `max` h4 `max` h5 + 1)
            -- Collect s                   -> reconstruct $ do
            --                                  (s', h) <- traverseSeq lvl s
            --                                  return (Collect s', h + 1)

      where
        -- Helpers for the common node shapes.  Each traverses its arguments at the current
        -- level 'lvl', applies the constructor 'c' to the results and returns a height one
        -- above the maximum of the argument heights.

        -- one array argument
        travA :: Arrays arrs'
              => (UnscopedAcc arrs' -> PreAcc UnscopedAcc RootExp arrs)
              -> Acc arrs' -> IO (PreAcc UnscopedAcc RootExp arrs, Int)
        travA c acc
          = do
              (acc', h) <- traverseAcc lvl acc
              return (c acc', h + 1)

        -- one scalar expression and one array argument
        travEA :: (Typeable b, Arrays arrs')
               => (RootExp b -> UnscopedAcc arrs' -> PreAcc UnscopedAcc RootExp arrs)
               -> Exp b -> Acc arrs' -> IO (PreAcc UnscopedAcc RootExp arrs, Int)
        travEA c exp acc
          = do
              (exp', h1) <- traverseExp lvl exp
              (acc', h2) <- traverseAcc lvl acc
              return (c exp' acc', h1 `max` h2 + 1)

        -- one binary scalar function and one array argument
        travF2A :: (Elt b, Elt c, Typeable d, Arrays arrs')
                => ((Exp b -> Exp c -> RootExp d) -> UnscopedAcc arrs'
                    -> PreAcc UnscopedAcc RootExp arrs)
                -> (Exp b -> Exp c -> Exp d) -> Acc arrs'
                -> IO (PreAcc UnscopedAcc RootExp arrs, Int)
        travF2A c fun acc
          = do
              (fun', h1) <- traverseFun2 lvl fun
              (acc', h2) <- traverseAcc lvl acc
              return (c fun' acc', h1 `max` h2 + 1)

        -- one binary scalar function, one scalar expression and one array argument
        travF2EA :: (Elt b, Elt c, Typeable d, Typeable e, Arrays arrs')
                 => ((Exp b -> Exp c -> RootExp d) -> RootExp e -> UnscopedAcc arrs' -> PreAcc UnscopedAcc RootExp arrs)
                 -> (Exp b -> Exp c -> Exp d) -> Exp e -> Acc arrs'
                 -> IO (PreAcc UnscopedAcc RootExp arrs, Int)
        travF2EA c fun exp acc
          = do
              (fun', h1) <- traverseFun2 lvl fun
              (exp', h2) <- traverseExp lvl exp
              (acc', h3) <- traverseAcc lvl acc
              return (c fun' exp' acc', h1 `max` h2 `max` h3 + 1)

        -- one binary scalar function and two array arguments
        travF2A2 :: (Elt b, Elt c, Typeable d, Arrays arrs1, Arrays arrs2)
                 => ((Exp b -> Exp c -> RootExp d) -> UnscopedAcc arrs1 -> UnscopedAcc arrs2 -> PreAcc UnscopedAcc RootExp arrs)
                 -> (Exp b -> Exp c -> Exp d) -> Acc arrs1 -> Acc arrs2
                 -> IO (PreAcc UnscopedAcc RootExp arrs, Int)
        travF2A2 c fun acc1 acc2
          = do
              (fun' , h1) <- traverseFun2 lvl fun
              (acc1', h2) <- traverseAcc lvl acc1
              (acc2', h3) <- traverseAcc lvl acc2
              return (c fun' acc1' acc2', h1 `max` h2 `max` h3 + 1)

        -- tuple of array computations; the empty tuple has height 1
        travAtup :: Atuple Acc a
                 -> IO (Atuple UnscopedAcc a, Int)
        travAtup NilAtup          = return (NilAtup, 1)
        travAtup (SnocAtup tup a) = do
          (tup', h1) <- travAtup tup
          (a',   h2) <- traverseAcc lvl a
          return (SnocAtup tup' a', h1 `max` h2 + 1)

-- Traverse a unary array function by applying it to an 'Atag' carrying the current level and
-- traversing the resulting body one level deeper.  The result is a dummy abstraction ('const')
-- of the traversed body, annotated with the tag of the bound variable.
-- See Note [Traversing functions and side effects].
--
makeOccMapAfun1 :: (Arrays a, Typeable b)
                => Config
                -> OccMapHash Acc
                -> Level
                -> (Acc a -> Acc b)
                -> IO (Acc a -> UnscopedAcc b, Int)
makeOccMapAfun1 config accOccMap lvl f = do
  let x = Acc (Atag lvl)
  --
  (UnscopedAcc [] body, height) <- makeOccMapSharingAcc config accOccMap (lvl+1) (f x)
  return (const (UnscopedAcc [lvl] body), height)

{--
makeOccMapAfun2 :: (Arrays a, Arrays b, Typeable c)
                => Config
                -> OccMapHash Acc
                -> Level
                -> (Acc a -> Acc b -> Acc c)
                -> IO (Acc a -> Acc b -> UnscopedAcc c, Int)
makeOccMapAfun2 config accOccMap lvl f = do
  let x = Acc (Atag (lvl + 1))
      y = Acc (Atag (lvl + 0))
  --
  (UnscopedAcc [] body, height) <- makeOccMapSharingAcc config
accOccMap (lvl+2) (f x y)
  return (\ _ _ -> (UnscopedAcc [lvl, lvl+1] body), height)

makeOccMapAfun3 :: (Arrays a, Arrays b, Arrays c, Typeable d)
                => Config
                -> OccMapHash Acc
                -> Level
                -> (Acc a -> Acc b -> Acc c -> Acc d)
                -> IO (Acc a -> Acc b -> Acc c -> UnscopedAcc d, Int)
makeOccMapAfun3 config accOccMap lvl f = do
  let x = Acc (Atag (lvl + 2))
      y = Acc (Atag (lvl + 1))
      z = Acc (Atag (lvl + 0))
  --
  (UnscopedAcc [] body, height) <- makeOccMapSharingAcc config accOccMap (lvl+3) (f x y z)
  return (\ _ _ _ -> (UnscopedAcc [lvl, lvl+1, lvl+2] body), height)
--}

-- Generate occupancy information for scalar functions and expressions. Helper
-- functions wrapping around 'makeOccMapRootExp' with more specific types.
--
-- See Note [Traversing functions and side effects]
--
makeOccMapExp
    :: Typeable e
    => Config
    -> OccMapHash Acc
    -> Level
    -> Exp e
    -> IO (RootExp e, Int)
makeOccMapExp config accOccMap lvl = makeOccMapRootExp config accOccMap lvl []

-- Unary scalar function: the argument is 'Tag lvl', the body is traversed at level 'lvl+1', and
-- the result is a dummy abstraction ('const') of the traversed body.
--
makeOccMapFun1
    :: (Elt a, Typeable b)
    => Config
    -> OccMapHash Acc
    -> Level
    -> (Exp a -> Exp b)
    -> IO (Exp a -> RootExp b, Int)
makeOccMapFun1 config accOccMap lvl f = do
  let x = Exp (Tag lvl)
  --
  (body, height) <- makeOccMapRootExp config accOccMap (lvl+1) [lvl] (f x)
  return (const body, height)

-- Binary scalar function: the first argument receives 'Tag (lvl+1)' and the second 'Tag lvl';
-- the body is traversed at level 'lvl+2'.
--
makeOccMapFun2
    :: (Elt a, Elt b, Typeable c)
    => Config
    -> OccMapHash Acc
    -> Level
    -> (Exp a -> Exp b -> Exp c)
    -> IO (Exp a -> Exp b -> RootExp c, Int)
makeOccMapFun2 config accOccMap lvl f = do
  let x = Exp (Tag (lvl+1))
      y = Exp (Tag lvl)
  --
  (body, height) <- makeOccMapRootExp config accOccMap (lvl+2) [lvl, lvl+1] (f x y)
  return (\_ _ -> body, height)

-- Unary stencil function: as 'makeOccMapFun1', except that the tagged variable is first turned
-- into a stencil pattern via 'stencilPrj'.  The 'Acc' argument is not inspected; it only fixes
-- the type variables 'sh' and 'a'.
--
makeOccMapStencil1
    :: forall sh a b stencil. (Stencil sh a stencil, Typeable b)
    => Config
    -> OccMapHash Acc
    -> Acc (Array sh a)   {- dummy -}
    -> Level
    -> (stencil -> Exp b)
    -> IO (stencil -> RootExp b, Int)
makeOccMapStencil1 config accOccMap _ lvl stencil = do
  let x = Exp (Tag lvl)
      f = stencil . stencilPrj (undefined::sh) (undefined::a)
  --
  (body, height) <- makeOccMapRootExp config accOccMap (lvl+1) [lvl] (f x)
  return (const body, height)

-- Binary stencil function: as 'makeOccMapFun2', with both tagged variables turned into stencil
-- patterns via 'stencilPrj'.  The two 'Acc' arguments are not inspected; they only fix the type
-- variables.
--
makeOccMapStencil2
    :: forall sh a b c stencil1 stencil2. (Stencil sh a stencil1, Stencil sh b stencil2, Typeable c)
    => Config
    -> OccMapHash Acc
    -> Acc (Array sh a)   {- dummy -}
    -> Acc (Array sh b)   {- dummy -}
    -> Level
    -> (stencil1 -> stencil2 -> Exp c)
    -> IO (stencil1 -> stencil2 -> RootExp c, Int)
makeOccMapStencil2 config accOccMap _ _ lvl stencil = do
  let x     = Exp (Tag (lvl+1))
      y     = Exp (Tag lvl)
      f a b = stencil (stencilPrj (undefined::sh) (undefined::a) a)
                      (stencilPrj (undefined::sh) (undefined::b) b)
  --
  (body, height) <- makeOccMapRootExp config accOccMap (lvl+2) [lvl, lvl+1] (f x y)
  return (\_ _ -> body, height)

-- Generate sharing information for expressions embedded in Acc computations.
-- Expressions are annotated with:
--
--  1) the tags of free scalar variables (for scalar functions)
--  2) a local occurrence map for that expression.
--
-- A fresh 'Exp' occurrence table is allocated for every root expression and frozen once the
-- traversal has finished; the 'Acc' occurrence table is shared with the enclosing traversal.
--
makeOccMapRootExp
    :: Typeable e
    => Config
    -> OccMapHash Acc
    -> Level                            -- The level of currently bound scalar variables
    -> [Int]                            -- The tags of newly introduced free scalar variables in this expression
    -> Exp e
    -> IO (RootExp e, Int)
makeOccMapRootExp config accOccMap lvl fvs exp = do
  traceLine "makeOccMapRootExp" "Enter"
  expOccMap                     <- newASTHashTable
  (UnscopedExp [] exp', height) <- makeOccMapSharingExp config accOccMap expOccMap lvl exp
  frozenExpOccMap               <- freezeOccMap expOccMap
  traceLine "makeOccMapRootExp" "Exit"
  return (RootExp frozenExpOccMap (UnscopedExp fvs exp'), height)

-- Generate sharing information for an open scalar expression.
--
-- Heights: 'Tag' has height 0; leaves such as 'Const', 'Undef', 'IndexNil', 'IndexAny' and
-- 'PrimConst' have height 1; every other node is one higher than the maximum over its children.
--
makeOccMapSharingExp
    :: Typeable e
    => Config
    -> OccMapHash Acc
    -> OccMapHash Exp
    -> Level                            -- The level of currently bound variables
    -> Exp e
    -> IO (UnscopedExp e, Int)
makeOccMapSharingExp config accOccMap expOccMap = travE
  where
    travE :: forall a. Typeable a => Level -> Exp a -> IO (UnscopedExp a, Int)
    travE lvl exp@(Exp pexp)
      -- The node's own height is passed to 'enterOcc' before it has been computed; 'mfix' with a
      -- lazy pattern ties that knot.  'height' must therefore not be forced inside this block.
      = mfix $ \ ~(_, height) -> do
          -- Compute stable name and enter it into the occurrence map
          --
          sn                         <- makeStableAST exp
          heightIfRepeatedOccurrence <- enterOcc expOccMap (StableASTName sn) height

          traceLine (showPreExpOp pexp) $ do
            let hash = show (hashStableName sn)
            case heightIfRepeatedOccurrence of
              Just height -> "REPEATED occurrence (sn = " ++ hash ++ "; height = " ++ show height ++ ")"
              Nothing     -> "first occurrence (sn = " ++ hash ++ ")"

          -- Reconstruct the computation in shared form.
          --
          -- In case of a repeated occurrence, the height comes from the occurrence map; otherwise
          -- it is computed by the traversal function passed in 'newExp'.  See also 'enterOcc'.
          --
          -- NB: This function can only be used in the case alternatives below; outside of the
          --     case we cannot discharge the 'Elt a' constraint.
          --
          let reconstruct :: Elt a
                          => IO (PreExp UnscopedAcc UnscopedExp a, Int)
                          -> IO (UnscopedExp a, Int)
              reconstruct newExp
                = case heightIfRepeatedOccurrence of
                    -- repeated occurrence with sharing recovery enabled: prune to a variable
                    Just height | recoverExpSharing config
                      -> return (UnscopedExp [] (VarSharing (StableNameHeight sn height)), height)
                    -- first occurrence, or sharing recovery disabled: traverse the children
                    _ -> do (exp, height) <- newExp
                            return (UnscopedExp [] (ExpSharing (StableNameHeight sn height) exp), height)

          case pexp of
            Tag i               -> reconstruct $ return (Tag i, 0)      -- height is 0!
            Const c             -> reconstruct $ return (Const c, 1)
            Undef               -> reconstruct $ return (Undef, 1)
            Tuple tup           -> reconstruct $ do
                                     (tup', h) <- travTup tup
                                     return (Tuple tup', h)
            Prj i e             -> reconstruct $ travE1 (Prj i) e
            IndexNil            -> reconstruct $ return (IndexNil, 1)
            IndexCons ix i      -> reconstruct $ travE2 IndexCons ix i
            IndexHead i         -> reconstruct $ travE1 IndexHead i
            IndexTail ix        -> reconstruct $ travE1 IndexTail ix
            IndexAny            -> reconstruct $ return (IndexAny, 1)
            ToIndex sh ix       -> reconstruct $ travE2 ToIndex sh ix
            FromIndex sh e      -> reconstruct $ travE2 FromIndex sh e
            Cond e1 e2 e3       -> reconstruct $ travE3 Cond e1 e2 e3
            While p iter init   -> reconstruct $ do
                                     (p'   , h1) <- traverseFun1 lvl p
                                     (iter', h2) <- traverseFun1 lvl iter
                                     (init', h3) <- travE lvl init
                                     return (While p' iter' init', h1 `max` h2 `max` h3 + 1)
            PrimConst c         -> reconstruct $ return (PrimConst c, 1)
            PrimApp p e         -> reconstruct $ travE1 (PrimApp p) e
            Index a e           -> reconstruct $ travAE Index a e
            LinearIndex a i     -> reconstruct $ travAE LinearIndex a i
            Shape a             -> reconstruct $ travA Shape a
            ShapeSize e         -> reconstruct $ travE1 ShapeSize e
            Intersect sh1 sh2   -> reconstruct $ travE2 Intersect sh1 sh2
            Union sh1 sh2       -> reconstruct $ travE2 Union sh1 sh2
            Foreign ff f e      -> reconstruct $ do
                                      (e', h) <- travE lvl e
                                      return  (Foreign ff f e', h+1)
            Coerce e            -> reconstruct $ travE1 Coerce e

      where
        -- Embedded array computations are entered into the shared 'Acc' occurrence table.
        traverseAcc :: Typeable arrs => Level -> Acc arrs -> IO (UnscopedAcc arrs, Int)
        traverseAcc = makeOccMapSharingAcc config accOccMap

        -- Scalar function nested inside this expression: traversed with the same 'Exp'
        -- occurrence table, using 'Tag lvl' for the argument and level 'lvl+1' for the body.
        traverseFun1 :: (Elt a, Typeable b)
                     => Level
                     -> (Exp a -> Exp b)
                     -> IO (Exp a -> UnscopedExp b, Int)
        traverseFun1 lvl f
          = do
              let x = Exp (Tag lvl)
              (UnscopedExp [] body, height) <- travE (lvl+1) (f x)
              return (const (UnscopedExp [lvl] body), height + 1)

        -- Helpers for the common node shapes.  Each traverses its arguments at the current
        -- level 'lvl', applies the constructor 'c' to the results and returns a height one
        -- above the maximum of the argument heights.

        -- one scalar argument
        travE1 :: Typeable b => (UnscopedExp b -> PreExp UnscopedAcc UnscopedExp a) -> Exp b
               -> IO (PreExp UnscopedAcc UnscopedExp a, Int)
        travE1 c e
          = do
              (e', h) <- travE lvl e
              return (c e', h + 1)

        -- two scalar arguments
        travE2 :: (Typeable b, Typeable c)
               => (UnscopedExp b -> UnscopedExp c -> PreExp UnscopedAcc UnscopedExp a)
               -> Exp b -> Exp c
               -> IO (PreExp UnscopedAcc UnscopedExp a, Int)
        travE2 c e1 e2
          = do
              (e1', h1) <- travE lvl e1
              (e2', h2) <- travE lvl e2
              return (c e1' e2', h1 `max` h2 + 1)

        -- three scalar arguments
        travE3 :: (Typeable b, Typeable c, Typeable d)
               => (UnscopedExp b -> UnscopedExp c -> UnscopedExp d -> PreExp UnscopedAcc UnscopedExp a)
               -> Exp b -> Exp c -> Exp d
               -> IO (PreExp UnscopedAcc UnscopedExp a, Int)
        travE3 c e1 e2 e3
          = do
              (e1', h1) <- travE lvl e1
              (e2', h2) <- travE lvl e2
              (e3', h3) <- travE lvl e3
              return (c e1' e2' e3', h1 `max` h2 `max` h3 + 1)

        -- one array argument
        travA :: Typeable b => (UnscopedAcc b -> PreExp UnscopedAcc UnscopedExp a) -> Acc b
              -> IO (PreExp UnscopedAcc UnscopedExp a, Int)
        travA c acc
          = do
              (acc', h) <- traverseAcc lvl acc
              return (c acc', h + 1)

        -- one array argument and one scalar argument
        travAE :: (Typeable b, Typeable c)
               => (UnscopedAcc b -> UnscopedExp c -> PreExp UnscopedAcc UnscopedExp a)
               -> Acc b -> Exp c
               -> IO (PreExp UnscopedAcc UnscopedExp a, Int)
        travAE c acc e
          = do
              (acc', h1) <- traverseAcc lvl acc
              (e'  , h2) <- travE lvl e
              return (c acc' e', h1 `max` h2 + 1)

        -- tuple of scalar expressions; the empty tuple has height 1
        travTup :: Tuple Exp tup -> IO (Tuple UnscopedExp tup, Int)
        travTup NilTup          = return (NilTup, 1)
        travTup (SnocTup tup e) = do
          (tup', h1) <- travTup tup
          (e'  , h2) <- travE lvl e
          return (SnocTup tup' e', h1 `max` h2 + 1)

{--
makeOccMapRootSeq
    :: Typeable arrs
    => Config
    -> OccMapHash Acc
    -> Level
    -> Seq arrs
    -> IO (RootSeq arrs, Int)
makeOccMapRootSeq config accOccMap lvl seq = do
  traceLine "makeOccMapRootSeq" "Enter"
  seqOccMap       <- newASTHashTable
  (seq', height)  <- makeOccMapSharingSeq config accOccMap seqOccMap lvl seq
  frozenSeqOccMap <- freezeOccMap seqOccMap
  traceLine "makeOccMapRootSeq" "Exit"
  return (RootSeq frozenSeqOccMap seq', height)

-- Generate sharing information for an open sequence expression.
-- makeOccMapSharingSeq :: Typeable e => Config -> OccMapHash Acc -> OccMapHash Seq -> Level -- The level of currently bound variables -> Seq e -> IO (UnscopedSeq e, Int) makeOccMapSharingSeq config accOccMap seqOccMap = traverseSeq where traverseAcc :: Typeable arrs => Level -> Acc arrs -> IO (UnscopedAcc arrs, Int) traverseAcc = makeOccMapSharingAcc config accOccMap traverseAfun1 :: (Arrays a, Typeable b) => Level -> (Acc a -> Acc b) -> IO (Acc a -> UnscopedAcc b, Int) traverseAfun1 = makeOccMapAfun1 config accOccMap traverseAfun2 :: (Arrays a, Arrays b, Typeable c) => Level -> (Acc a -> Acc b -> Acc c) -> IO (Acc a -> Acc b -> UnscopedAcc c, Int) traverseAfun2 = makeOccMapAfun2 config accOccMap traverseAfun3 :: (Arrays a, Arrays b, Arrays c, Typeable d) => Level -> (Acc a -> Acc b -> Acc c -> Acc d) -> IO (Acc a -> Acc b -> Acc c -> UnscopedAcc d, Int) traverseAfun3 = makeOccMapAfun3 config accOccMap traverseExp :: Typeable e => Level -> Exp e -> IO (RootExp e, Int) traverseExp = makeOccMapExp config accOccMap traverseFun2 :: (Elt a, Elt b, Typeable c) => Level -> (Exp a -> Exp b -> Exp c) -> IO (Exp a -> Exp b -> RootExp c, Int) traverseFun2 = makeOccMapFun2 config accOccMap traverseTup :: Level -> Atuple Seq tup -> IO (Atuple UnscopedSeq tup, Int) traverseTup _ NilAtup = return (NilAtup, 1) traverseTup lvl (SnocAtup tup s) = do (tup', h1) <- traverseTup lvl tup (s' , h2) <- traverseSeq lvl s return (SnocAtup tup' s', h1 `max` h2 + 1) traverseSeq :: forall arrs. 
Typeable arrs => Level -> Seq arrs -> IO (UnscopedSeq arrs, Int) traverseSeq lvl acc@(Seq seq) = mfix $ \ ~(_, height) -> do -- Compute stable name and enter it into the occurrence map -- sn <- makeStableAST acc heightIfRepeatedOccurrence <- enterOcc seqOccMap (StableASTName sn) height traceLine (showPreSeqOp seq) $ do let hash = show (hashStableName sn) case heightIfRepeatedOccurrence of Just height -> "REPEATED occurrence (sn = " ++ hash ++ "; height = " ++ show height ++ ")" Nothing -> "first occurrence (sn = " ++ hash ++ ")" -- Reconstruct the computation in shared form. -- -- In case of a repeated occurrence, the height comes from the occurrence map; otherwise -- it is computed by the traversal function passed in 'newAcc'. See also 'enterOcc'. -- -- NB: This function can only be used in the case alternatives below; outside of the -- case we cannot discharge the 'Arrays arrs' constraint. -- let producer :: (arrs ~ [a], Arrays a) => IO (PreSeq UnscopedAcc UnscopedSeq RootExp arrs, Int) -> IO (UnscopedSeq arrs, Int) producer newSeq = case heightIfRepeatedOccurrence of Just height | recoverSeqSharing config -> return (UnscopedSeq (SvarSharing (StableNameHeight sn height)), height) _ -> do (seq, height) <- newSeq return (UnscopedSeq (SeqSharing (StableNameHeight sn height) seq), height) let consumer :: IO (PreSeq UnscopedAcc UnscopedSeq RootExp arrs, Int) -> IO (UnscopedSeq arrs, Int) consumer newSeq = do (seq, height) <- newSeq return (UnscopedSeq (SeqSharing (StableNameHeight sn height) seq), height) case seq of StreamIn arrs -> producer $ return (StreamIn arrs, 1) ToSeq sl acc -> producer $ do (acc', h1) <- traverseAcc lvl acc return (ToSeq sl acc', h1 + 1) MapSeq afun s -> producer $ do (afun', h1) <- traverseAfun1 lvl afun (s' , h2) <- traverseSeq lvl s return (MapSeq afun' s', h1 `max` h2 + 1) ZipWithSeq afun s1 s2 -> producer $ do (afun', h1) <- traverseAfun2 lvl afun (s1' , h2) <- traverseSeq lvl s1 (s2' , h3) <- traverseSeq lvl s2 return (ZipWithSeq afun' 
s1' s2', h1 `max` h2 `max` h3 + 1) ScanSeq fun e s -> producer $ do (fun', h1) <- traverseFun2 lvl fun (e', h2) <- traverseExp lvl e (s' , h3) <- traverseSeq lvl s return (ScanSeq fun' e' s', h1 `max` h2 `max` h3 + 1) FoldSeq fun e s -> consumer $ do (fun', h1) <- traverseFun2 lvl fun (e' , h2) <- traverseExp lvl e (s' , h3) <- traverseSeq lvl s return (FoldSeq fun' e' s', h1 `max` h2 `max` h3 + 1) FoldSeqFlatten afun acc s -> consumer $ do (afun', h1) <- traverseAfun3 lvl afun (acc', h2) <- traverseAcc lvl acc (s' , h3) <- traverseSeq lvl s return (FoldSeqFlatten afun' acc' s', h1 `max` h2 `max` h3 + 1) Stuple t -> consumer $ do (t', h1) <- traverseTup lvl t return (Stuple t', h1 + 1) --} -- Type used to maintain how often each shared subterm, so far, occurred during a bottom-up sweep, -- as well as the relation between subterms. It is comprised of a list of terms and a graph giving -- their relation. -- -- Invariants of the list: -- - If one shared term 's' is itself a subterm of another shared term 't', then 's' must occur -- *after* 't' in the list. -- - No shared term occurs twice. -- - A term may have a final occurrence count of only 1 iff it is either a free variable ('Atag' -- or 'Tag') or an array computation lifted out of an expression. -- - All 'Exp' node counts precede all 'Acc' node counts as we don't share 'Exp' nodes across 'Acc' -- nodes. Similarly, all 'Seq' nodes precede 'Acc' nodes and 'Exp' nodes precede 'Seq' nodes. -- -- We determine the subterm property by using the tree height in 'StableNameHeight'. Trees get -- smaller towards the end of a 'NodeCounts' list. The height of free variables ('Atag' or 'Tag') -- is 0, whereas other leaves have height 1. This guarantees that all free variables are at the end -- of the 'NodeCounts' list. -- -- The graph is represented as a map where a stable name 'a' is mapped to a set of stable names 'b' -- such that if there exists an edge from 'a' to 'c' then 'c' is contained within 'b'.
--
-- Properties of the graph:
--   - There exists an edge from 'a' to 'b' if the term named by 'a' is a subterm of the term
--     named by 'b'.
--
-- To ensure the list invariant and the graph properties are preserved over merging node counts from
-- sibling subterms, the function '(+++)' must be used.
--
type NodeCounts = ([NodeCount], Map.HashMap NodeName (Set.HashSet NodeName))

-- Type-erased stable name, used as the key of the subterm graph.
--
data NodeName where
  NodeName :: Typeable a => StableName a -> NodeName

-- Two names are equal when the second can be cast to the type of the first and the underlying
-- stable names then coincide.
--
instance Eq NodeName where
  NodeName a == NodeName b = maybe False (a ==) (gcast b)

instance Hashable NodeName where
  hashWithSalt salt (NodeName sn) = salt + hashStableName sn

instance Show NodeName where
  show (NodeName sn) = show $ hashStableName sn

-- A shared term together with the number of occurrences seen so far.
--
data NodeCount = AccNodeCount StableSharingAcc Int
               | ExpNodeCount StableSharingExp Int
               -- SeqNodeCount StableSharingSeq Int
               deriving Show

-- Empty node counts
--
noNodeCounts :: NodeCounts
noNodeCounts = ([], Map.empty)

-- Insert an Acc node into the node counts, assuming that it is a superterm of all the existing
-- nodes.
--
-- The new node starts with a count of one and an empty adjacency set; every existing node gains an
-- edge to the new node.
--
-- TODO: Perform cycle detection here.
insertAccNode :: StableSharingAcc -> NodeCounts -> NodeCounts
insertAccNode ssa@(StableSharingAcc (StableNameHeight sn _) _) below@(subterms, _)
  = ([AccNodeCount ssa 1], edges) +++ below
  where
    self  = NodeName sn
    edges = Map.fromList
          $ (self, Set.empty)
          : map (\sub -> (nodeName sub, Set.singleton self)) subterms

-- Insert an Exp node into the node counts, assuming that it is a superterm of all the existing
-- nodes.
--
-- The new node starts with a count of one and an empty adjacency set; every existing node gains an
-- edge to the new node.
--
-- TODO: Perform cycle detection here.
insertExpNode :: StableSharingExp -> NodeCounts -> NodeCounts
insertExpNode sse@(StableSharingExp (StableNameHeight sn _) _) below@(subterms, _)
  = ([ExpNodeCount sse 1], edges) +++ below
  where
    self  = NodeName sn
    edges = Map.fromList
          $ (self, Set.empty)
          : map (\sub -> (nodeName sub, Set.singleton self)) subterms

{--
-- Insert an Seq node into the node counts, assuming that it is a superterm of the all the existing
-- nodes.
--
-- TODO: Perform cycle detection here.
insertSeqNode :: StableSharingSeq -> NodeCounts -> NodeCounts
insertSeqNode ssa@(StableSharingSeq (StableNameHeight sn _) _) (subterms,g)
  = ([SeqNodeCount ssa 1], g') +++ (subterms,g)
  where
    k  = NodeName sn
    hs = map nodeName subterms
    g' = Map.fromList $ (k, Set.empty) : [(h, Set.singleton k) | h <- hs]
--}

-- Remove nodes that aren't in the list from the graph.
--
-- The resulting graph has one entry per node of the list, and each adjacency set is restricted to
-- names that still occur in the list. The node list itself is returned unchanged.
--
-- The set of live names is built once and intersected with each adjacency set, instead of testing
-- list membership for every element of every set.
--
-- Every node in the list must have an entry in the graph; the 'Map.!' lookup fails otherwise.
--
-- RCE: This is no longer necessary when NDP is supported.
cleanCounts :: NodeCounts -> NodeCounts
cleanCounts (ns, g) = (ns, Map.fromList [ (h, Set.intersection (g Map.! h) live) | h <- hs ])
  where
    hs   = map nodeName ns
    live = Set.fromList hs

-- The graph key (type-erased stable name) of the term held by a node count.
--
nodeName :: NodeCount -> NodeName
nodeName (AccNodeCount (StableSharingAcc (StableNameHeight sn _) _) _) = NodeName sn
nodeName (ExpNodeCount (StableSharingExp (StableNameHeight sn _) _) _) = NodeName sn
-- nodeName (SeqNodeCount (StableSharingSeq (StableNameHeight sn _) _) _) = NodeName sn


-- Combine node counts that belong to the same node.
--
-- * We assume that the list invariant —subterms follow their parents— holds for both arguments and
--   guarantee that it still holds for the result.
-- * In the same manner, we assume that all 'Exp' node counts precede 'Acc' node counts and
--   guarantee that this also holds for the result.
--
-- RCE: The list combination should be able to be performed as a more efficient merge.
--
(+++) :: NodeCounts -> NodeCounts -> NodeCounts
(ns1, g1) +++ (ns2, g2) = (foldr insert ns1 ns2, Map.unionWith Set.union g1 g2)
  where
    -- Place one node count from the right-hand list into the accumulated list: counts for the same
    -- term are added together, otherwise the new entry goes in front of the first entry it is
    -- higher than, with 'Exp' entries always ahead of 'Acc' entries.
    insert :: NodeCount -> [NodeCount] -> [NodeCount]
    insert new []                 = [new]
    insert new rest@(old : older) =
      case (new, old) of
        (AccNodeCount sa1 count1, AccNodeCount sa2 count2)
          | sa1 == sa2          -> AccNodeCount (sa1 `pickNoneAvar` sa2) (count1 + count2) : older
          | sa1 `higherSSA` sa2 -> new : rest
          | otherwise           -> old : insert new older

        (ExpNodeCount se1 count1, ExpNodeCount se2 count2)
          | se1 == se2          -> ExpNodeCount (se1 `pickNoneVar` se2) (count1 + count2) : older
          | se1 `higherSSE` se2 -> new : rest
          | otherwise           -> old : insert new older

        -- Disabled Seq support (clauses of the original equation-style definition):
        -- insert x@(SeqNodeCount se1 count1) ys@(y@(SeqNodeCount se2 count2) : ys')
        --   | se1 == se2          = SeqNodeCount (se1 `pickNoneSvar` se2) (count1 + count2) : ys'
        --   | se1 `higherSSS` se2 = x : ys
        --   | otherwise           = y : insert x ys'

        -- an 'Acc' count moves behind any 'Exp' count...
        (AccNodeCount _ _, ExpNodeCount _ _) -> old : insert new older

        -- ...and an 'Exp' count goes in front of an 'Acc' count
        (ExpNodeCount _ _, AccNodeCount _ _) -> new : insert old older

        -- insert x@(SeqNodeCount _ _) (y@(ExpNodeCount _ _) : ys')
        --   = y : insert x ys'
        -- insert x@(ExpNodeCount _ _) (y@(SeqNodeCount _ _) : ys')
        --   = x : insert y ys'
        -- insert x@(AccNodeCount _ _) (y@(SeqNodeCount _ _) : ys')
        --   = y : insert x ys'
        -- insert x@(SeqNodeCount _ _) (y@(AccNodeCount _ _) : ys')
        --   = x : insert y ys'

    -- Prefer the second argument when the first is merely a sharing variable.
    pickNoneAvar :: StableSharingAcc -> StableSharingAcc -> StableSharingAcc
    pickNoneAvar sa1 sa2 =
      case sa1 of
        StableSharingAcc _ (AvarSharing _) -> sa2
        _                                  -> sa1

    pickNoneVar :: StableSharingExp -> StableSharingExp -> StableSharingExp
    pickNoneVar se1 se2 =
      case se1 of
        StableSharingExp _ (VarSharing _) -> se2
        _                                 -> se1

    -- pickNoneSvar :: StableSharingSeq -> StableSharingSeq -> StableSharingSeq
    -- (StableSharingSeq _ (SvarSharing _)) `pickNoneSvar` sa2  = sa2
    -- sa1                                  `pickNoneSvar` _sa2 = sa1

-- Build an initial environment for the tag values given in the first argument for traversing an
-- array expression. The 'StableSharingAcc's for all tags /actually used/ in the expressions are
-- in the second argument. (Tags are not used if a bound variable has no usage occurrence.)
--
-- Bail out if any tag occurs multiple times as this indicates that the sharing of an argument
-- variable was not preserved and we cannot build an appropriate initial environment (c.f., comments
-- at 'determineScopesAcc').
--
buildInitialEnvAcc :: [Level] -> [StableSharingAcc] -> [StableSharingAcc]
buildInitialEnvAcc tags sas = [ entryFor tag | tag <- tags ]
  where
    -- The environment entry for a single tag.
    entryFor :: Level -> StableSharingAcc
    entryFor tag1 =
      case filter (hasTag tag1) sas of
        []     -> placeholder     -- tag is not used in the analysed expression
        [only] -> only            -- tag has a unique occurrence
        dups   -> $internalError "buildInitialEnvAcc"
                    ("Encountered duplicate 'ATag's\n " ++ intercalate ", " (map showSA dups))

    -- Every candidate must be a plain 'Atag' node; anything else is an internal error.
    hasTag :: Level -> StableSharingAcc -> Bool
    hasTag tag1 (StableSharingAcc _ (AccSharing _ (Atag tag2))) = tag1 == tag2
    hasTag _    other
      = $internalError "buildInitialEnvAcc"
          ("Encountered a node that is not a plain 'Atag'\n " ++ showSA other)

    -- Entry used for tags without any usage occurrence; its payload is 'undefined'.
    placeholder :: StableSharingAcc
    placeholder = StableSharingAcc noStableAccName (undefined :: SharingAcc acc exp ())

    showSA :: StableSharingAcc -> String
    showSA (StableSharingAcc _ node) =
      case node of
        AccSharing sn acc -> show (hashStableNameHeight sn) ++ ": " ++ showPreAccOp acc
        AvarSharing sn    -> "AvarSharing " ++ show (hashStableNameHeight sn)
        AletSharing sa _  -> "AletSharing " ++ show sa ++ "..."

-- Build an initial environment for the tag values given in the first argument for traversing a
-- scalar expression. The 'StableSharingExp's for all tags /actually used/ in the expressions are
-- in the second argument. (Tags are not used if a bound variable has no usage occurrence.)
--
-- Bail out if any tag occurs multiple times as this indicates that the sharing of an argument
-- variable was not preserved and we cannot build an appropriate initial environment (c.f., comments
-- at 'determineScopesAcc').
--
buildInitialEnvExp :: [Level] -> [StableSharingExp] -> [StableSharingExp]
buildInitialEnvExp tags ses = [ entryFor tag | tag <- tags ]
  where
    -- The environment entry for a single tag.
    entryFor :: Level -> StableSharingExp
    entryFor tag1 =
      case filter (hasTag tag1) ses of
        []     -> placeholder     -- tag is not used in the analysed expression
        [only] -> only            -- tag has a unique occurrence
        dups   -> $internalError "buildInitialEnvExp"
                    ("Encountered a duplicate 'Tag'\n " ++ intercalate ", " (map showSE dups))

    -- Every candidate must be a plain 'Tag' node; anything else is an internal error.
    hasTag :: Level -> StableSharingExp -> Bool
    hasTag tag1 (StableSharingExp _ (ExpSharing _ (Tag tag2))) = tag1 == tag2
    hasTag _    other
      = $internalError "buildInitialEnvExp"
          ("Encountered a node that is not a plain 'Tag'\n " ++ showSE other)

    -- Entry used for tags without any usage occurrence; its payload is 'undefined'.
    placeholder :: StableSharingExp
    placeholder = StableSharingExp noStableExpName (undefined :: SharingExp acc exp ())

    showSE :: StableSharingExp -> String
    showSE (StableSharingExp _ node) =
      case node of
        ExpSharing sn exp -> show (hashStableNameHeight sn) ++ ": " ++ showPreExpOp exp
        VarSharing sn     -> "VarSharing " ++ show (hashStableNameHeight sn)
        LetSharing se _   -> "LetSharing " ++ show se ++ "..."

-- Determine whether a 'NodeCount' is for an 'Atag' or 'Tag', which represent free variables.
--
isFreeVar :: NodeCount -> Bool
isFreeVar nc =
  case nc of
    AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag _))) _ -> True
    ExpNodeCount (StableSharingExp _ (ExpSharing _ (Tag  _))) _ -> True
    _                                                           -> False


-- Determine scope of shared subterms
-- ==================================

-- Determine the scopes of all variables representing shared subterms (Phase Two) in a bottom-up
-- sweep. The first argument determines whether array computations are floated out of expressions
-- irrespective of whether they are shared or not — 'True' implies floating them out.
--
-- In addition to the AST with sharing information, yield the 'StableSharingAcc's for all free
-- variables of 'rootAcc', which are represented by 'Atag' leaves in the tree.
-- They are in order of
-- the tag values — i.e., in the same order that they need to appear in an environment to use the
-- tag for indexing into that environment.
--
-- Precondition: there are only 'AvarSharing' and 'AccSharing' nodes in the argument.
--
-- After the traversal every remaining node count must belong to a free variable; any other
-- remaining count is reported as an internal error ("unbound shared subtrees").
--
determineScopesAcc
    :: Typeable a
    => Config
    -> [Level]
    -> OccMap Acc
    -> UnscopedAcc a
    -> (ScopedAcc a, [StableSharingAcc])
determineScopesAcc config fvs accOccMap rootAcc
  = let (sharingAcc, (counts, _)) = determineScopesSharingAcc config accOccMap rootAcc
        unboundTrees = filter (not . isFreeVar) counts
    in
    if all isFreeVar counts
       then (sharingAcc, buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- counts])
       else $internalError "determineScopesAcc" ("unbound shared subtrees" ++ show unboundTrees)


-- Bottom-up scope determination for an array computation. Returns the scoped term together with
-- the node counts of all shared subterms (and free variables) that have not yet been bound.
--
determineScopesSharingAcc
    :: Config
    -> OccMap Acc
    -> UnscopedAcc a
    -> (ScopedAcc a, NodeCounts)
determineScopesSharingAcc config accOccMap = scopesAcc
  where
    scopesAcc :: forall arrs. UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
    scopesAcc (UnscopedAcc _ (AletSharing _ _))
      = $internalError "determineScopesSharingAcc: scopesAcc" "unexpected 'AletSharing'"

    -- A sharing variable contributes a single occurrence of the node it refers to.
    scopesAcc (UnscopedAcc _ (AvarSharing sn))
      = (ScopedAcc [] (AvarSharing sn), StableSharingAcc sn (AvarSharing sn) `insertAccNode` noNodeCounts)

    -- A proper node: traverse all children, merge their node counts with '(+++)', and let
    -- 'reconstruct' decide whether this node is kept in place or replaced by a variable.
    scopesAcc (UnscopedAcc _ (AccSharing sn pacc))
      = case pacc of
          Atag i ->
            reconstruct (Atag i) noNodeCounts
          Pipe afun1 afun2 acc ->
            let (afun1', accCount1) = scopesAfun1 afun1
                (afun2', accCount2) = scopesAfun1 afun2
                (acc', accCount3) = scopesAcc acc
            in
            reconstruct (Pipe afun1' afun2' acc') (accCount1 +++ accCount2 +++ accCount3)
          Aforeign ff afun acc ->
            let (acc', accCount) = scopesAcc acc
            in
            reconstruct (Aforeign ff afun acc') accCount
          Acond e acc1 acc2 ->
            let (e' , accCount1) = scopesExp e
                (acc1', accCount2) = scopesAcc acc1
                (acc2', accCount3) = scopesAcc acc2
            in
            reconstruct (Acond e' acc1' acc2') (accCount1 +++ accCount2 +++ accCount3)
          Awhile pred iter init ->
            let (pred', accCount1) = scopesAfun1 pred
                (iter', accCount2) = scopesAfun1 iter
                (init', accCount3) = scopesAcc init
            in
            reconstruct (Awhile pred' iter' init') (accCount1 +++ accCount2 +++ accCount3)
          Atuple tup ->
            let (tup', accCount) = travAtup tup
            in
            reconstruct (Atuple tup') accCount
          Aprj ix a -> travA (Aprj ix) a
          Use arr -> reconstruct (Use arr) noNodeCounts
          Unit e ->
            let (e', accCount) = scopesExp e
            in
            reconstruct (Unit e') accCount
          Generate sh f ->
            let (sh', accCount1) = scopesExp sh
                (f' , accCount2) = scopesFun1 f
            in
            reconstruct (Generate sh' f') (accCount1 +++ accCount2)
          Reshape sh acc -> travEA Reshape sh acc
          Replicate n acc -> travEA Replicate n acc
          Slice acc i -> travEA (flip Slice) i acc
          Map f acc ->
            let (f' , accCount1) = scopesFun1 f
                (acc', accCount2) = scopesAcc acc
            in
            reconstruct (Map f' acc') (accCount1 +++ accCount2)
          ZipWith f acc1 acc2 -> travF2A2 ZipWith f acc1 acc2
          Fold f z acc -> travF2EA Fold f z acc
          Fold1 f acc -> travF2A Fold1 f acc
          FoldSeg f z acc1 acc2 ->
            let (f' , accCount1) = scopesFun2 f
                (z' , accCount2) = scopesExp z
                (acc1', accCount3) = scopesAcc acc1
                (acc2', accCount4) = scopesAcc acc2
            in
            reconstruct (FoldSeg f' z' acc1' acc2')
              (accCount1 +++ accCount2 +++ accCount3 +++ accCount4)
          Fold1Seg f acc1 acc2 -> travF2A2 Fold1Seg f acc1 acc2
          Scanl f z acc -> travF2EA Scanl f z acc
          Scanl' f z acc -> travF2EA Scanl' f z acc
          Scanl1 f acc -> travF2A Scanl1 f acc
          Scanr f z acc -> travF2EA Scanr f z acc
          Scanr' f z acc -> travF2EA Scanr' f z acc
          Scanr1 f acc -> travF2A Scanr1 f acc
          Permute fc acc1 fp acc2 ->
            let (fc' , accCount1) = scopesFun2 fc
                (acc1', accCount2) = scopesAcc acc1
                (fp' , accCount3) = scopesFun1 fp
                (acc2', accCount4) = scopesAcc acc2
            in
            reconstruct (Permute fc' acc1' fp' acc2')
              (accCount1 +++ accCount2 +++ accCount3 +++ accCount4)
          Backpermute sh fp acc ->
            let (sh' , accCount1) = scopesExp sh
                (fp' , accCount2) = scopesFun1 fp
                (acc', accCount3) = scopesAcc acc
            in
            reconstruct (Backpermute sh' fp' acc') (accCount1 +++ accCount2 +++ accCount3)
          Stencil st bnd acc ->
            let (st' , accCount1) = scopesStencil1 acc st
                (bnd', accCount2) = scopesBoundary bnd
                (acc', accCount3) = scopesAcc acc
            in
            reconstruct (Stencil st' bnd' acc') (accCount1 +++ accCount2 +++ accCount3)
          Stencil2 st bnd1 acc1 bnd2 acc2 ->
            let (st' , accCount1) = scopesStencil2 acc1 acc2 st
                (bnd1', accCount2) = scopesBoundary bnd1
                (acc1', accCount3) = scopesAcc acc1
                (bnd2', accCount4) = scopesBoundary bnd2
                (acc2', accCount5) = scopesAcc acc2
            in
            reconstruct (Stencil2 st' bnd1' acc1' bnd2' acc2')
              (accCount1 +++ accCount2 +++ accCount3 +++ accCount4 +++ accCount5)
          -- Collect seq -> let
          --                  (seq', accCount1) = scopesSeq seq
          --                in
          --                reconstruct (Collect seq') accCount1
      where
        -- Node with one scalar expression and one array argument.
        travEA :: (ScopedExp e -> ScopedAcc arrs' -> PreAcc ScopedAcc ScopedExp arrs)
               -> RootExp e
               -> UnscopedAcc arrs'
               -> (ScopedAcc arrs, NodeCounts)
        travEA c e acc = reconstruct (c e' acc') (accCount1 +++ accCount2)
          where
            (e' , accCount1) = scopesExp e
            (acc', accCount2) = scopesAcc acc

        -- Node with one binary scalar function and one array argument.
        travF2A :: (Elt a, Elt b)
                => ((Exp a -> Exp b -> ScopedExp c) -> ScopedAcc arrs'
                      -> PreAcc ScopedAcc ScopedExp arrs)
                -> (Exp a -> Exp b -> RootExp c)
                -> UnscopedAcc arrs'
                -> (ScopedAcc arrs, NodeCounts)
        travF2A c f acc = reconstruct (c f' acc') (accCount1 +++ accCount2)
          where
            (f' , accCount1) = scopesFun2 f
            (acc', accCount2) = scopesAcc acc

        -- Node with one binary scalar function, one scalar expression and one array argument.
        travF2EA :: (Elt a, Elt b)
                 => ((Exp a -> Exp b -> ScopedExp c) -> ScopedExp e
                       -> ScopedAcc arrs' -> PreAcc ScopedAcc ScopedExp arrs)
                 -> (Exp a -> Exp b -> RootExp c)
                 -> RootExp e
                 -> UnscopedAcc arrs'
                 -> (ScopedAcc arrs, NodeCounts)
        travF2EA c f e acc = reconstruct (c f' e' acc') (accCount1 +++ accCount2 +++ accCount3)
          where
            (f' , accCount1) = scopesFun2 f
            (e' , accCount2) = scopesExp e
            (acc', accCount3) = scopesAcc acc

        -- Node with one binary scalar function and two array arguments.
        travF2A2 :: (Elt a, Elt b)
                 => ((Exp a -> Exp b -> ScopedExp c) -> ScopedAcc arrs1
                       -> ScopedAcc arrs2 -> PreAcc ScopedAcc ScopedExp arrs)
                 -> (Exp a -> Exp b -> RootExp c)
                 -> UnscopedAcc arrs1
                 -> UnscopedAcc arrs2
                 -> (ScopedAcc arrs, NodeCounts)
        travF2A2 c f acc1 acc2 = reconstruct (c f' acc1' acc2')
                                   (accCount1 +++ accCount2 +++ accCount3)
          where
            (f' , accCount1) = scopesFun2 f
            (acc1', accCount2) = scopesAcc acc1
            (acc2', accCount3) = scopesAcc acc2

        -- Traverse every component of an array tuple, merging the node counts.
        travAtup :: Atuple UnscopedAcc a
                 -> (Atuple ScopedAcc a, NodeCounts)
        travAtup NilAtup = (NilAtup, noNodeCounts)
        travAtup (SnocAtup tup a) =
          let (tup', accCountT) = travAtup tup
              (a', accCountA) = scopesAcc a
          in
          (SnocAtup tup' a', accCountT +++ accCountA)

        -- Node with a single array argument.
        travA :: (ScopedAcc arrs' -> PreAcc ScopedAcc ScopedExp arrs)
              -> UnscopedAcc arrs'
              -> (ScopedAcc arrs, NodeCounts)
        travA c acc = reconstruct (c acc') accCount
          where
            (acc', accCount) = scopesAcc acc

          -- Occurrence count of the currently processed node
        accOccCount = let StableNameHeight sn' _ = sn
                      in
                      lookupWithASTName accOccMap (StableASTName sn')

        -- Reconstruct the current tree node.
        --
        -- * If the current node is being shared ('accOccCount > 1'), replace it by a 'AvarSharing'
        --   node and float the shared subtree out wrapped in a 'NodeCounts' value.
        -- * If the current node is not shared, reconstruct it in place.
        -- * Special case for free variables ('Atag'): Replace the tree by a sharing variable and
        --   float the 'Atag' out in a 'NodeCounts' value. This is independent of the number of
        --   occurrences.
        --
        -- In either case, any completed 'NodeCounts' are injected as bindings using 'AletSharing'
        -- node.
        --
        reconstruct :: PreAcc ScopedAcc ScopedExp arrs -> NodeCounts
                    -> (ScopedAcc arrs, NodeCounts)
        reconstruct newAcc@(Atag _) _subCount
              -- free variable => replace by a sharing variable regardless of the number of
              -- occurrences
          = let thisCount = StableSharingAcc sn (AccSharing sn newAcc) `insertAccNode` noNodeCounts
            in
            tracePure "FREE" (show thisCount)
            (ScopedAcc [] (AvarSharing sn), thisCount)
        reconstruct newAcc subCount
              -- shared subtree => replace by a sharing variable (if 'recoverAccSharing' enabled)
          | accOccCount > 1 && recoverAccSharing config
          = let allCount = (StableSharingAcc sn sharingAcc `insertAccNode` newCount)
            in
            tracePure ("SHARED" ++ completed) (show allCount)
            (ScopedAcc [] (AvarSharing sn), allCount)
              -- neither shared nor free variable => leave it as it is
          | otherwise
          = tracePure ("Normal" ++ completed) (show newCount)
            (ScopedAcc [] sharingAcc, newCount)
          where
              -- Determine the bindings that need to be attached to the current node...
            (newCount, bindHere) = filterCompleted subCount

              -- ...and wrap them in 'AletSharing' constructors
            lets = foldl (flip (.)) id . map (\x y -> AletSharing x (ScopedAcc [] y)) $ bindHere
            sharingAcc = lets $ AccSharing sn newAcc

              -- trace support
            completed | null bindHere = ""
                      | otherwise = "(" ++ show (length bindHere) ++ " lets)"

        -- Extract *leading* nodes that have a complete node count (i.e., their node count is equal
        -- to the number of occurrences of that node in the overall expression).
        --
        -- Nodes with a completed node count should be let bound at the currently processed node.
        --
        -- NB: Only extract leading nodes (i.e., the longest run at the *front* of the list that is
        --     complete). Otherwise, we would let-bind subterms before their parents, which leads
        --     to scope errors.
        --
        -- NOTE(review): the code below does not take a leading run. It computes a per-node
        -- 'bindable' flag and uses 'partition'; a node is bindable when it is completed and every
        -- superterm (per the graph) that is still in the list is itself bindable. 'bindable' is
        -- defined in terms of itself and relies on lazy evaluation.
        --
        filterCompleted :: NodeCounts -> (NodeCounts, [StableSharingAcc])
        filterCompleted (ns, graph)
          = let bindable = map (isBindable bindable (map nodeName ns)) ns
                (bind, rest) = partition fst $ zip bindable ns
            in ((map snd rest, graph), [sa | AccNodeCount sa _ <- map snd bind])
          where
            -- a node is not yet complete while the node count 'n' is below the overall number
            -- of occurrences for that node in the whole program, with the exception that free
            -- variables are never complete
            isCompleted nc@(AccNodeCount sa n) | not . isFreeVar $ nc = lookupWithSharingAcc accOccMap sa == n
            isCompleted _ = False

            -- Only 'Acc' node counts can be bound at an 'Acc' node.
            isBindable :: [Bool] -> [NodeName] -> NodeCount -> Bool
            isBindable bindable nodes nc@(AccNodeCount _ _) =
              let superTerms = Set.toList $ graph Map.! nodeName nc
                  unbound = mapMaybe (`elemIndex` nodes) superTerms
              in isCompleted nc
              && all (bindable !!) unbound
            isBindable _ _ (ExpNodeCount _ _) = False
            -- isBindable _ _ (SeqNodeCount _ _) = False

    -- scopesSeq :: forall arrs. RootSeq arrs -> (ScopedSeq arrs, NodeCounts)
    -- scopesSeq = determineScopesSeq config accOccMap

    scopesExp :: RootExp t -> (ScopedExp t, NodeCounts)
    scopesExp = determineScopesExp config accOccMap

    -- The lambda bound variable is at this point already irrelevant; for details, see
    -- Note [Traversing functions and side effects]
    --
    -- Node counts for 'Atag's whose tag is listed in the body's free variables are bound here: they
    -- are removed from the counts passed upwards and used to build the body's environment.
    --
    scopesAfun1 :: Arrays a1 => (Acc a1 -> UnscopedAcc a2) -> (Acc a1 -> ScopedAcc a2, NodeCounts)
    scopesAfun1 f = (const (ScopedAcc ssa body'), (counts',graph))
      where
        body@(UnscopedAcc fvs _) = f undefined
        ((ScopedAcc [] body'), (counts,graph)) = scopesAcc body
        ssa = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts]
        (freeCounts, counts') = partition isBoundHere counts

        isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag i))) _) = i `elem` fvs
        isBoundHere _ = False

    -- The lambda bound variable is at this point already irrelevant; for details, see
    -- Note [Traversing functions and side effects]
    --
    scopesFun1 :: Elt e1 => (Exp e1 -> RootExp e2) -> (Exp e1 -> ScopedExp e2, NodeCounts)
    scopesFun1 f = (const body, counts)
      where
        (body, counts) = scopesExp (f undefined)

    -- The lambda bound variable is at this point already irrelevant; for details, see
    -- Note [Traversing functions and side effects]
    --
    scopesFun2 :: (Elt e1, Elt e2)
               => (Exp e1 -> Exp e2 -> RootExp e3)
               -> (Exp e1 -> Exp e2 -> ScopedExp e3, NodeCounts)
    scopesFun2 f = (\_ _ -> body, counts)
      where
        (body, counts) = scopesExp (f undefined undefined)

    -- The lambda bound variable is at this point already irrelevant; for details, see
    -- Note [Traversing functions and side effects]
    --
    scopesStencil1 :: forall sh e1 e2 stencil. Stencil sh e1 stencil
                   => UnscopedAcc (Array sh e1){-dummy-}
                   -> (stencil -> RootExp e2)
                   -> (stencil -> ScopedExp e2, NodeCounts)
    scopesStencil1 _ stencilFun = (const body, counts)
      where
        (body, counts) = scopesExp (stencilFun undefined)

    -- The lambda bound variable is at this point already irrelevant; for details, see
    -- Note [Traversing functions and side effects]
    --
    scopesStencil2 :: forall sh e1 e2 e3 stencil1 stencil2.
                      (Stencil sh e1 stencil1, Stencil sh e2 stencil2)
                   => UnscopedAcc (Array sh e1){-dummy-}
                   -> UnscopedAcc (Array sh e2){-dummy-}
                   -> (stencil1 -> stencil2 -> RootExp e3)
                   -> (stencil1 -> stencil2 -> ScopedExp e3, NodeCounts)
    scopesStencil2 _ _ stencilFun = (\_ _ -> body, counts)
      where
        (body, counts) = scopesExp (stencilFun undefined undefined)

    -- Only a 'Function' boundary contains a term to traverse; the other boundary conditions carry
    -- no node counts.
    scopesBoundary
        :: PreBoundary UnscopedAcc RootExp t
        -> (PreBoundary ScopedAcc ScopedExp t, NodeCounts)
    scopesBoundary bndy =
      case bndy of
        Clamp -> (Clamp, noNodeCounts)
        Mirror -> (Mirror, noNodeCounts)
        Wrap -> (Wrap, noNodeCounts)
        Constant v -> (Constant v, noNodeCounts)
        Function f -> let (body, counts) = scopesFun1 f
                      in (Function body, counts)


-- Determine scopes for a root scalar expression. The 'Exp' node counts remaining after the
-- traversal are used to build the expression's initial environment; only the 'Acc' node counts
-- are returned, with the graph restricted to them by 'cleanCounts'.
--
determineScopesExp
    :: Config
    -> OccMap Acc
    -> RootExp t
    -> (ScopedExp t, NodeCounts) -- Root (closed) expression plus Acc node counts
determineScopesExp config accOccMap (RootExp expOccMap exp@(UnscopedExp fvs _))
  = let
        ((ScopedExp [] expWithScopes), (nodeCounts,graph)) = determineScopesSharingExp config accOccMap expOccMap exp
        (expCounts, accCounts) = partition isExpNodeCount nodeCounts

        isExpNodeCount ExpNodeCount{} = True
        isExpNodeCount _ = False
    in
    (ScopedExp (buildInitialEnvExp fvs [se | ExpNodeCount se _ <- expCounts]) expWithScopes, cleanCounts (accCounts,graph))


-- Bottom-up scope determination for a scalar expression. Returns the scoped term together with the
-- node counts of all shared subterms (and free variables) that have not yet been bound.
--
determineScopesSharingExp
    :: Config
    -> OccMap Acc
    -> OccMap Exp
    -> UnscopedExp t
    -> (ScopedExp t, NodeCounts)
determineScopesSharingExp config accOccMap expOccMap = scopesExp
  where
    scopesAcc :: UnscopedAcc a -> (ScopedAcc a, NodeCounts)
    scopesAcc = determineScopesSharingAcc config accOccMap

    -- Node counts for 'Tag's whose tag is listed in the body's free variables are bound here: they
    -- are removed from the counts passed upwards and used to build the body's environment.
    scopesFun1 :: (Exp a -> UnscopedExp b) -> (Exp a -> ScopedExp b, NodeCounts)
    scopesFun1 f = tracePure ("LAMBDA " ++ (show ssa)) (show counts) (const (ScopedExp ssa body'), (counts',graph))
      where
        body@(UnscopedExp fvs _) = f undefined
        ((ScopedExp [] body'), (counts, graph)) = scopesExp body
        ssa = buildInitialEnvExp fvs [se | ExpNodeCount se _ <- freeCounts]
        (freeCounts, counts') = partition isBoundHere counts

        isBoundHere (ExpNodeCount (StableSharingExp _ (ExpSharing _ (Tag i))) _) = i `elem` fvs
        isBoundHere _ = False

    scopesExp :: forall t. UnscopedExp t -> (ScopedExp t, NodeCounts)
    scopesExp (UnscopedExp _ (LetSharing _ _))
      = $internalError "determineScopesSharingExp: scopesExp" "unexpected 'LetSharing'"

    -- A sharing variable contributes a single occurrence of the node it refers to.
    scopesExp (UnscopedExp _ (VarSharing sn))
      = (ScopedExp [] (VarSharing sn), StableSharingExp sn (VarSharing sn) `insertExpNode` noNodeCounts)

    -- A proper node: traverse all children, merge their node counts with '(+++)', and let
    -- 'reconstruct' decide whether this node is kept in place or replaced by a variable.
    scopesExp (UnscopedExp _ (ExpSharing sn pexp))
      = case pexp of
          Tag i -> reconstruct (Tag i) noNodeCounts
          Const c -> reconstruct (Const c) noNodeCounts
          Undef -> reconstruct Undef noNodeCounts
          Tuple tup ->
            let (tup', accCount) = travTup tup
            in
            reconstruct (Tuple tup') accCount
          Prj i e -> travE1 (Prj i) e
          IndexNil -> reconstruct IndexNil noNodeCounts
          IndexCons ix i -> travE2 IndexCons ix i
          IndexHead i -> travE1 IndexHead i
          IndexTail ix -> travE1 IndexTail ix
          IndexAny -> reconstruct IndexAny noNodeCounts
          ToIndex sh ix -> travE2 ToIndex sh ix
          FromIndex sh e -> travE2 FromIndex sh e
          Cond e1 e2 e3 -> travE3 Cond e1 e2 e3
          While p it i ->
            let (p' , accCount1) = scopesFun1 p
                (it', accCount2) = scopesFun1 it
                (i' , accCount3) = scopesExp i
            in
            reconstruct (While p' it' i') (accCount1 +++ accCount2 +++ accCount3)
          PrimConst c -> reconstruct (PrimConst c) noNodeCounts
          PrimApp p e -> travE1 (PrimApp p) e
          Index a e -> travAE Index a e
          LinearIndex a e -> travAE LinearIndex a e
          Shape a -> travA Shape a
          ShapeSize e -> travE1 ShapeSize e
          Intersect sh1 sh2 -> travE2 Intersect sh1 sh2
          Union sh1 sh2 -> travE2 Union sh1 sh2
          Foreign ff f e -> travE1 (Foreign ff f) e
          Coerce e -> travE1 Coerce e
      where
        -- Traverse every component of a scalar tuple, merging the node counts.
        travTup :: Tuple UnscopedExp tup -> (Tuple ScopedExp tup, NodeCounts)
        travTup NilTup = (NilTup, noNodeCounts)
        travTup (SnocTup tup e) =
          let (tup', accCountT) = travTup tup
              (e' , accCountE) = scopesExp e
          in
          (SnocTup tup' e', accCountT +++ accCountE)

        -- Node with one scalar argument.
        travE1 :: (ScopedExp a -> PreExp ScopedAcc ScopedExp t) -> UnscopedExp a
               -> (ScopedExp t, NodeCounts)
        travE1 c e = reconstruct (c e') accCount
          where
            (e', accCount) = scopesExp e

        -- Node with two scalar arguments.
        travE2 :: (ScopedExp a -> ScopedExp b -> PreExp ScopedAcc ScopedExp t)
               -> UnscopedExp a
               -> UnscopedExp b
               -> (ScopedExp t, NodeCounts)
        travE2 c e1 e2 = reconstruct (c e1' e2') (accCount1 +++ accCount2)
          where
            (e1', accCount1) = scopesExp e1
            (e2', accCount2) = scopesExp e2

        -- Node with three scalar arguments.
        travE3 :: (ScopedExp a -> ScopedExp b -> ScopedExp c -> PreExp ScopedAcc ScopedExp t)
               -> UnscopedExp a
               -> UnscopedExp b
               -> UnscopedExp c
               -> (ScopedExp t, NodeCounts)
        travE3 c e1 e2 e3 = reconstruct (c e1' e2' e3') (accCount1 +++ accCount2 +++ accCount3)
          where
            (e1', accCount1) = scopesExp e1
            (e2', accCount2) = scopesExp e2
            (e3', accCount3) = scopesExp e3

        -- Node with one array argument; the array may be floated out (see 'maybeFloatOutAcc').
        travA :: (ScopedAcc a -> PreExp ScopedAcc ScopedExp t) -> UnscopedAcc a
              -> (ScopedExp t, NodeCounts)
        travA c acc = maybeFloatOutAcc c acc' accCount
          where
            (acc', accCount) = scopesAcc acc

        -- Node with one array and one scalar argument; the array may be floated out.
        travAE :: (ScopedAcc a -> ScopedExp b -> PreExp ScopedAcc ScopedExp t)
               -> UnscopedAcc a
               -> UnscopedExp b
               -> (ScopedExp t, NodeCounts)
        travAE c acc e = maybeFloatOutAcc (`c` e') acc' (accCountA +++ accCountE)
          where
            (acc', accCountA) = scopesAcc acc
            (e' , accCountE) = scopesExp e

        -- With 'floatOutAcc' enabled, an array argument that is not already a sharing variable is
        -- replaced by a variable and the array computation is added to the node counts.
        maybeFloatOutAcc :: (ScopedAcc a -> PreExp ScopedAcc ScopedExp t)
                         -> ScopedAcc a
                         -> NodeCounts
                         -> (ScopedExp t, NodeCounts)
        maybeFloatOutAcc c acc@(ScopedAcc _ (AvarSharing _)) accCount -- nothing to float out
          = reconstruct (c acc) accCount
        maybeFloatOutAcc c acc accCount
          | floatOutAcc config = reconstruct (c var) ((stableAcc `insertAccNode` noNodeCounts) +++ accCount)
          | otherwise = reconstruct (c acc) accCount
          where
             (var, stableAcc) = abstract acc (\(ScopedAcc _ s) -> s)

        -- Split an array term into a sharing variable and the named term it stands for. Any
        -- 'AletSharing' bindings wrapped around the 'AccSharing' node are re-applied around the
        -- named term via the continuation 'lets'.
        abstract :: ScopedAcc a -> (ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a)
                 -> (ScopedAcc a, StableSharingAcc)
        abstract (ScopedAcc _ (AvarSharing _)) _ = $internalError "sharingAccToVar" "AvarSharing"
        abstract (ScopedAcc ssa (AletSharing sa acc)) lets = abstract acc (lets . (\x -> ScopedAcc ssa (AletSharing sa x)))
        abstract acc@(ScopedAcc ssa (AccSharing sn _)) lets = (ScopedAcc ssa (AvarSharing sn), StableSharingAcc sn (lets acc))

          -- Occurrence count of the currently processed node
        expOccCount = let StableNameHeight sn' _ = sn
                      in
                      lookupWithASTName expOccMap (StableASTName sn')

        -- Reconstruct the current tree node.
        --
        -- * If the current node is being shared ('expOccCount > 1'), replace it by a 'VarSharing'
        --   node and float the shared subtree out wrapped in a 'NodeCounts' value.
        -- * If the current node is not shared, reconstruct it in place.
        -- * Special case for free variables ('Tag'): Replace the tree by a sharing variable and
        --   float the 'Tag' out in a 'NodeCounts' value. This is independent of the number of
        --   occurrences.
        --
        -- In either case, any completed 'NodeCounts' are injected as bindings using 'LetSharing'
        -- node.
        --
        reconstruct :: PreExp ScopedAcc ScopedExp t -> NodeCounts
                    -> (ScopedExp t, NodeCounts)
        reconstruct newExp@(Tag _) _subCount
              -- free variable => replace by a sharing variable regardless of the number of
              -- occurrences
          = let thisCount = StableSharingExp sn (ExpSharing sn newExp) `insertExpNode` noNodeCounts
            in
            tracePure "FREE" (show thisCount)
            (ScopedExp [] (VarSharing sn), thisCount)
        reconstruct newExp subCount
              -- shared subtree => replace by a sharing variable (if 'recoverExpSharing' enabled)
          | expOccCount > 1 && recoverExpSharing config
          = let allCount = StableSharingExp sn sharingExp `insertExpNode` newCount
            in
            tracePure ("SHARED" ++ completed) (show allCount)
            (ScopedExp [] (VarSharing sn), allCount)
              -- neither shared nor free variable => leave it as it is
          | otherwise
          = tracePure ("Normal" ++ completed) (show newCount)
            (ScopedExp [] sharingExp, newCount)
          where
              -- Determine the bindings that need to be attached to the current node...
            (newCount, bindHere) = filterCompleted subCount

              -- ...and wrap them in 'LetSharing' constructors
            lets = foldl (flip (.)) id . map (\x y -> LetSharing x (ScopedExp [] y)) $ bindHere
            sharingExp = lets $ ExpSharing sn newExp

              -- trace support
            completed | null bindHere = ""
                      | otherwise = "(" ++ show (length bindHere) ++ " lets)"

        -- Extract *leading* nodes that have a complete node count (i.e., their node count is equal
        -- to the number of occurrences of that node in the overall expression).
        --
        -- Nodes with a completed node count should be let bound at the currently processed node.
        --
        -- NB: Only extract leading nodes (i.e., the longest run at the *front* of the list that is
        --     complete). Otherwise, we would let-bind subterms before their parents, which leads
        --     to scope errors.
        --
        -- NOTE(review): the code below does not take a leading run. It computes a per-node
        -- 'bindable' flag and uses 'partition'; a node is bindable when it is completed and every
        -- superterm (per the graph) that is still in the list is itself bindable. 'bindable' is
        -- defined in terms of itself and relies on lazy evaluation.
        --
        filterCompleted :: NodeCounts -> (NodeCounts, [StableSharingExp])
        filterCompleted (ns,graph)
          = let bindable = map (isBindable bindable (map nodeName ns)) ns
                (bind, unbind) = partition fst $ zip bindable ns
            in ((map snd unbind, graph), [se | ExpNodeCount se _ <- map snd bind])
          where
            -- a node is not yet complete while the node count 'n' is below the overall number
            -- of occurrences for that node in the whole program, with the exception that free
            -- variables are never complete
            isCompleted nc@(ExpNodeCount sa n) | not . isFreeVar $ nc = lookupWithSharingExp expOccMap sa == n
            isCompleted _ = False

            -- Only 'Exp' node counts can be bound at an 'Exp' node.
            isBindable :: [Bool] -> [NodeName] -> NodeCount -> Bool
            isBindable bindable nodes nc@(ExpNodeCount _ _) =
              let superTerms = Set.toList $ graph Map.! nodeName nc
                  unbound = mapMaybe (`elemIndex` nodes) superTerms
              in isCompleted nc
              && all (bindable !!) unbound
            isBindable _ _ (AccNodeCount _ _) = False
            -- isBindable _ _ (SeqNodeCount _ _) = False

{--
determineScopesSeq
    :: Config
    -> OccMap Acc
    -> RootSeq t
    -> (ScopedSeq t, NodeCounts) -- Root (closed) expression plus Acc node counts
determineScopesSeq config accOccMap (RootSeq seqOccMap seq)
  = let
        (ScopedSeq seqWithScopes, (nodeCounts,graph)) = determineScopesSharingSeq config accOccMap seqOccMap seq
        binds = [s | SeqNodeCount s _ <- nodeCounts]
        lets = foldl (flip (.)) id .
map (\x y -> SletSharing x (ScopedSeq y)) $ binds sharingSeq = lets seqWithScopes newCounts = filter (not . isSeqCount) nodeCounts isSeqCount SeqNodeCount{} = True isSeqCount _ = False in (ScopedSeq sharingSeq, cleanCounts (newCounts,graph)) determineScopesSharingSeq :: Config -> OccMap Acc -> OccMap Seq -> UnscopedSeq t -> (ScopedSeq t, NodeCounts) determineScopesSharingSeq config accOccMap _seqOccMap = scopesSeq where scopesAcc :: UnscopedAcc a -> (ScopedAcc a, NodeCounts) scopesAcc = determineScopesSharingAcc config accOccMap scopesExp :: RootExp t -> (ScopedExp t, NodeCounts) scopesExp = determineScopesExp config accOccMap scopesFun2 :: (Elt e1, Elt e2) => (Exp e1 -> Exp e2 -> RootExp e3) -> (Exp e1 -> Exp e2 -> ScopedExp e3, NodeCounts) scopesFun2 f = (\_ _ -> body, counts) where (body, counts) = scopesExp (f undefined undefined) -- The lambda bound variable is at this point already irrelevant; for details, see -- Note [Traversing functions and side effects] -- scopesAfun1 :: Arrays a1 => (Acc a1 -> UnscopedAcc a2) -> (Acc a1 -> ScopedAcc a2, NodeCounts) scopesAfun1 f = (const (ScopedAcc ssa body'), (counts',graph)) where body@(UnscopedAcc fvs _) = f undefined ((ScopedAcc [] body'), (counts,graph)) = scopesAcc body ssa = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts] (freeCounts, counts') = partition isBoundHere counts isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag i))) _) = i `elem` fvs isBoundHere _ = False scopesAfun2 :: (Arrays a1, Arrays a2) => (Acc a1 -> Acc a2 -> UnscopedAcc a3) -> (Acc a1 -> Acc a2 -> ScopedAcc a3, NodeCounts) scopesAfun2 f = (\ _ _ -> (ScopedAcc ssa body'), (counts',graph)) where body@(UnscopedAcc fvs _) = f undefined undefined ((ScopedAcc [] body'), (counts,graph)) = scopesAcc body ssa = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts] (freeCounts, counts') = partition isBoundHere counts isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag i))) _) = i `elem` fvs isBoundHere 
_ = False scopesAfun3 :: (Arrays a1, Arrays a2, Arrays a3) => (Acc a1 -> Acc a2 -> Acc a3 -> UnscopedAcc a4) -> (Acc a1 -> Acc a2 -> Acc a3 -> ScopedAcc a4, NodeCounts) scopesAfun3 f = (\ _ _ _ -> (ScopedAcc ssa body'), (counts',graph)) where body@(UnscopedAcc fvs _) = f undefined undefined undefined ((ScopedAcc [] body'), (counts,graph)) = scopesAcc body ssa = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts] (freeCounts, counts') = partition isBoundHere counts isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag i))) _) = i `elem` fvs isBoundHere _ = False scopesTup :: Atuple UnscopedSeq tup -> (Atuple ScopedSeq tup, NodeCounts) scopesTup NilAtup = (NilAtup, noNodeCounts) scopesTup (SnocAtup tup s) = let (tup', accCountT) = scopesTup tup (s' , accCountS) = scopesSeq s in (SnocAtup tup' s', accCountT +++ accCountS) scopesSeq :: forall t. UnscopedSeq t -> (ScopedSeq t, NodeCounts) scopesSeq (UnscopedSeq (SletSharing _ _)) = $internalError "determineScopesSharingSeq: scopesSeq" "unexpected 'LetSharing'" scopesSeq (UnscopedSeq (SvarSharing sn)) = (ScopedSeq (SvarSharing sn), StableSharingSeq sn (SvarSharing sn) `insertSeqNode` noNodeCounts) scopesSeq (UnscopedSeq (SeqSharing sn s)) = case s of StreamIn arrs -> producer (StreamIn arrs) noNodeCounts ToSeq sl acc -> let (acc', accCount1) = scopesAcc acc in producer (ToSeq sl acc') accCount1 MapSeq afun s' -> let (afun', accCount1) = scopesAfun1 afun (s'' , accCount2) = scopesSeq s' in producer (MapSeq afun' s'') (accCount1 +++ accCount2) ZipWithSeq afun s1 s2 -> let (afun', accCount1) = scopesAfun2 afun (s1' , accCount2) = scopesSeq s1 (s2' , accCount3) = scopesSeq s2 in producer (ZipWithSeq afun' s1' s2') (accCount1 +++ accCount2 +++ accCount3) ScanSeq fun e s' -> let (fun', accCount1) = scopesFun2 fun (e' , accCount2) = scopesExp e (s'' , accCount3) = scopesSeq s' in producer (ScanSeq fun' e' s'') (accCount1 +++ accCount2 +++ accCount3) FoldSeq fun e s' -> let (fun', accCount1) = scopesFun2 
fun (e' , accCount2) = scopesExp e (s'' , accCount3) = scopesSeq s' in consumer (FoldSeq fun' e' s'') (accCount1 +++ accCount2 +++ accCount3) FoldSeqFlatten afun acc s' -> let (afun', accCount1) = scopesAfun3 afun (acc' , accCount2) = scopesAcc acc (s'' , accCount3) = scopesSeq s' in consumer (FoldSeqFlatten afun' acc' s'') (accCount1 +++ accCount2 +++ accCount3) Stuple tup -> let (tup', accCount1) = scopesTup tup in consumer (Stuple tup') accCount1 where -- All producers must be replaced by sharing variables -- producer :: (t ~ [a], Arrays a) => PreSeq ScopedAcc ScopedSeq ScopedExp t -> NodeCounts -> (ScopedSeq t, NodeCounts) producer newSeq subCount = let allCount = StableSharingSeq sn (SeqSharing sn newSeq) `insertSeqNode` subCount in tracePure "Producer" (show allCount) (ScopedSeq (SvarSharing sn), allCount) -- Consumers cannot be shared. -- consumer :: PreSeq ScopedAcc ScopedSeq ScopedExp t -> NodeCounts -> (ScopedSeq t, NodeCounts) consumer newSeq subCount = tracePure "Consumer" (show subCount) (ScopedSeq (SeqSharing sn newSeq), subCount) --} -- |Recover sharing information and annotate the HOAS AST with variable and let binding -- annotations. The first argument determines whether array computations are floated out of -- expressions irrespective of whether they are shared or not — 'True' implies floating them out. -- -- Also returns the 'StableSharingAcc's of all 'Atag' leaves in environment order — they represent -- the free variables of the AST. -- -- NB: Strictly speaking, this function is not deterministic, as it uses stable pointers to -- determine the sharing of subterms. The stable pointer API does not guarantee its -- completeness; i.e., it may miss some equalities, which implies that we may fail to discover -- some sharing. However, sharing does not affect the denotational meaning of an array -- computation; hence, we do not compromise denotational correctness. 
--
-- There is one caveat: We currently rely on the 'Atag' and 'Tag' leaves representing free
-- variables to be shared if any of them is used more than once.  If one is duplicated, the
-- environment for de Bruijn conversion will have a duplicate entry, and hence, be of the wrong
-- size, which is fatal. (The 'buildInitialEnv*' functions will already bail out.)
--
{-# NOINLINE recoverSharingAcc #-}
recoverSharingAcc
    :: Typeable a
    => Config
    -> Level            -- The level of currently bound array variables
    -> [Level]          -- The tags of newly introduced free array variables
    -> Acc a
    -> (ScopedAcc a, [StableSharingAcc])
recoverSharingAcc config alvl avars acc = determineScopesAcc config avars occMap taggedAcc
  where
    -- The occurrence map is built in IO to enable stable pointers; escaping from IO here
    -- is safe as explained above.
    (taggedAcc, occMap) = unsafePerformIO (makeOccMapAcc config alvl acc)

-- As 'recoverSharingAcc', but for a scalar expression.  The list of 'StableSharingExp's
-- attached to the root 'ScopedExp' by 'determineScopesExp' is stripped from the resulting
-- expression and returned separately as the second component.
--
{-# NOINLINE recoverSharingExp #-}
recoverSharingExp
    :: Typeable e
    => Config
    -> Level            -- The level of currently bound scalar variables
    -> [Level]          -- The tags of newly introduced free scalar variables
    -> Exp e
    -> (ScopedExp e, [StableSharingExp])
recoverSharingExp config lvl fvar exp = (ScopedExp [] body, freeVars)
  where
    -- Occurrence counting runs against a fresh, mutable table of array nodes, which is
    -- frozen before it is handed to the (pure) scope determination phase.
    (rootExp, frozenOccMap) = unsafePerformIO $ do
      table       <- newASTHashTable
      (tagged, _) <- makeOccMapRootExp config table lvl fvar exp
      frozen      <- freezeOccMap table
      return (tagged, frozen)

    (ScopedExp freeVars body, _) = determineScopesExp config frozenOccMap rootExp

{--
{-# NOINLINE recoverSharingSeq #-}
recoverSharingSeq
    :: Typeable e
    => Config
    -> Seq e
    -> (ScopedSeq e, [StableSharingSeq])
recoverSharingSeq config seq
  = let
        (rootSeq, accOccMap) = unsafePerformIO $ do
          accOccMap       <- newASTHashTable
          (seq', _)       <- makeOccMapRootSeq config accOccMap 0 seq
          frozenAccOccMap <- freezeOccMap accOccMap
          return (seq', frozenAccOccMap)

        (ScopedSeq sharingSeq, (ns, _)) =
          determineScopesSeq config accOccMap rootSeq
    in
    (ScopedSeq sharingSeq, [a | SeqNodeCount a _ <- ns])
--}


-- Debugging
-- ---------

-- Emit "header: msg" on a single line through the 'Debug.dump_sharing' trace channel.
--
traceLine :: String -> String -> IO ()
traceLine header msg = Debug.traceIO Debug.dump_sharing (concat [header, ": ", msg])

-- Emit the header on its own line, followed by the message on the next line, through
-- the 'Debug.dump_sharing' trace channel.
--
traceChunk :: String -> String -> IO ()
traceChunk header msg = Debug.traceIO Debug.dump_sharing (concat [header, "\n ", msg])

-- Pure counterpart of 'traceLine': passes "header: msg" to 'Debug.trace' on the
-- 'Debug.dump_sharing' channel and yields the remaining argument.
--
tracePure :: String -> String -> a -> a
tracePure header msg = Debug.trace Debug.dump_sharing (concat [header, ": ", msg])