module Data.Array.Accelerate.Trafo.Sharing (
convertAcc, convertAfun, Afunction, AfunctionR,
convertExp, convertFun, Function, FunctionR
) where
import Control.Applicative hiding ( Const )
import Control.Monad.Fix
import Data.List
import Data.Maybe
import Data.Hashable
import Data.Typeable
import qualified Data.HashTable.IO as Hash
import qualified Data.IntMap as IntMap
import System.IO.Unsafe ( unsafePerformIO )
import System.Mem.StableName
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Array.Sugar as Sugar
import Data.Array.Accelerate.Tuple hiding ( Tuple )
import Data.Array.Accelerate.AST hiding (
PreOpenAcc(..), OpenAcc(..), Acc, Stencil(..), PreOpenExp(..), OpenExp, PreExp, Exp,
showPreAccOp, showPreExpOp )
import qualified Data.Array.Accelerate.AST as AST
import qualified Data.Array.Accelerate.Tuple as Tuple
import qualified Data.Array.Accelerate.Debug as Debug
#include "accelerate.h"
-- | Switches controlling the sharing-recovery conversion.
--
-- The constructor is applied positionally elsewhere in this module, so the
-- order of the fields must not change.
data Config = Config
  { recoverAccSharing :: Bool
    -- ^ When set, a repeated occurrence of an 'Acc' node is replaced by a
    -- variable reference instead of being traversed again.
  , recoverExpSharing :: Bool
    -- ^ When set, a repeated occurrence of an 'Exp' node is replaced by a
    -- variable reference instead of being traversed again.
  , floatOutAcc       :: Bool
    -- ^ Only ever enabled together with 'recoverAccSharing' by the entry
    -- points of this module; the consumer of this flag is not in view here.
  }
-- | A sequence of typed de Bruijn indices into the environment @env@, with
-- one index for every entry of @env'@.
data Layout env env' where
  -- | No indices at all.
  EmptyLayout :: Layout env ()
  -- | Extend a layout with one more index into @env@.
  PushLayout  :: Typeable t
              => Layout env env'
              -> Idx env t
              -> Layout env (env', t)
-- | Project the @n@-th index (counting from the most recently pushed entry)
-- out of a layout, checking at runtime that it has the expected type @t@.
--
-- The first argument is a context description that is only used in error
-- messages. The function fails via 'possiblyNestedErr' when the stored index
-- has a different type than requested, or when the layout holds fewer than
-- @n + 1@ entries.
prjIdx :: forall t env env'. Typeable t => String -> Int -> Layout env env' -> Idx env t
prjIdx ctxt 0 (PushLayout _ (ix :: Idx env0 t0))
  = flip fromMaybe (gcast ix)
  $ possiblyNestedErr ctxt $
      "Couldn't match expected type `" ++ show (typeOf (undefined::t)) ++
      "' with actual type `" ++ show (typeOf (undefined::t0)) ++ "'" ++
      "\n Type mismatch"
-- Not the entry we are after: step one position down the layout.
prjIdx ctxt n (PushLayout l _) = prjIdx ctxt (n - 1) l
prjIdx ctxt _ EmptyLayout      = possiblyNestedErr ctxt "Environment doesn't contain index"
-- | Abort with a fatal error raised on behalf of 'prjIdx'.
--
-- The message consists of the failure reason, the context in which it
-- occurred, and a hint that nested data parallelism is a likely cause.
possiblyNestedErr :: String -> String -> a
possiblyNestedErr ctxt failreason = error (concat messageParts)
  where
    messageParts =
      [ "Fatal error in Sharing.prjIdx:"
      , "\n ", failreason, " at ", ctxt
      , "\n Possible reason: nested data parallelism — array computation that depends on a"
      , "\n scalar variable of type 'Exp a'"
      ]
-- | Weaken a layout by one binder: every index it contains is wrapped in
-- 'SuccIdx', so that it points into an environment with one extra entry.
incLayout :: Layout env env' -> Layout (env, t) env'
incLayout layout = case layout of
  EmptyLayout         -> EmptyLayout
  PushLayout rest idx -> PushLayout (incLayout rest) (SuccIdx idx)
-- | Number of indices stored in a layout.
sizeLayout :: Layout env env' -> Int
sizeLayout layout = case layout of
  EmptyLayout       -> 0
  PushLayout rest _ -> 1 + sizeLayout rest
-- | Convert a closed array computation into its de Bruijn representation.
--
-- The three flags are, in order: recover sharing of array computations,
-- recover sharing of scalar expressions, and float array computations out.
-- Floating is only enabled when array sharing recovery is enabled as well.
convertAcc
    :: Arrays arrs
    => Bool
    -> Bool
    -> Bool
    -> Acc arrs
    -> AST.Acc arrs
convertAcc shareAcc shareExp floatAcc acc =
    convertOpenAcc config 0 [] EmptyLayout acc
  where
    config = Config
      { recoverAccSharing = shareAcc
      , recoverExpSharing = shareExp
      , floatOutAcc       = shareAcc && floatAcc
      }
-- | Convert a closed function over array computations into its de Bruijn
-- representation. The flags have the same meaning as for 'convertAcc'.
convertAfun :: Afunction f => Bool -> Bool -> Bool -> f -> AST.Afun (AfunctionR f)
convertAfun shareAcc shareExp floatAcc = aconvert config EmptyLayout
  where
    config = Config
      { recoverAccSharing = shareAcc
      , recoverExpSharing = shareExp
      , floatOutAcc       = shareAcc && floatAcc
      }
-- | Functions over array computations that can be converted to de Bruijn form.
class Afunction f where
  -- | The type of the function once the 'Acc' wrappers have been removed.
  type AfunctionR f

  -- | Convert a function, given the layout of the array environment that is
  -- in scope at this point.
  aconvert
      :: Config
      -> Layout aenv aenv
      -> f
      -> AST.OpenAfun aenv (AfunctionR f)
-- | One more argument: apply the function to a fresh tag, numbered by the
-- current size of the layout, and convert the result under an extended
-- layout.
instance (Arrays a, Afunction r) => Afunction (Acc a -> r) where
  type AfunctionR (Acc a -> r) = a -> AfunctionR r

  aconvert config alyt f = Alam (aconvert config extended (f argument))
    where
      argument = Acc (Atag (sizeLayout alyt))
      extended = PushLayout (incLayout alyt) ZeroIdx
-- | Base case: all arguments have been consumed, so convert the body.
instance Arrays b => Afunction (Acc b) where
  type AfunctionR (Acc b) = b

  aconvert config alyt body
    = let lvl  = sizeLayout alyt
          -- The tags handed out for the enclosing lambdas are 0 .. lvl-1
          -- (see the function instance); list them most recent first.
          -- For lvl == 0 this is the empty list.
          vars = [lvl-1, lvl-2 .. 0]
      in
      Abody $ convertOpenAcc config lvl vars alyt body
-- | Convert an array computation that may mention the tags listed in @fvs@.
--
-- Sharing is recovered first; the resulting tree and initial environment are
-- then turned into de Bruijn form against the given layout.
convertOpenAcc
    :: Arrays arrs
    => Config
    -> Level
    -> [Level]
    -> Layout aenv aenv
    -> Acc arrs
    -> AST.OpenAcc aenv arrs
convertOpenAcc config lvl fvs alyt acc =
    convertSharingAcc config alyt initialEnv sharingAcc
  where
    (sharingAcc, initialEnv) = recoverSharingAcc config lvl fvs acc
-- Convert an array computation annotated with sharing information into its
-- de Bruijn representation.
--
-- 'alyt' is the layout of the array environment; 'aenv' lists the shared
-- subterms currently bound, most recently bound first. The position of a
-- subterm in 'aenv' is the position used to project its index from 'alyt'.
convertSharingAcc
:: forall aenv arrs. Arrays arrs
=> Config
-> Layout aenv aenv
-> [StableSharingAcc]
-> SharingAcc arrs
-> AST.OpenAcc aenv arrs
-- A reference to a shared subterm: look it up in the environment by stable name.
convertSharingAcc _ alyt aenv (AvarSharing sa)
| Just i <- findIndex (matchStableAcc sa) aenv
= AST.OpenAcc $ AST.Avar (prjIdx (ctxt ++ "; i = " ++ show i) i alyt)
-- Not found and nothing is bound at all: reported as a cyclic definition.
| null aenv
= error $ "Cyclic definition of a value of type 'Acc' (sa = " ++
show (hashStableNameHeight sa) ++ ")"
-- Not found although bindings exist: the environment is inconsistent.
| otherwise
= INTERNAL_ERROR(error) "convertSharingAcc" err
where
ctxt = "shared 'Acc' tree with stable name " ++ show (hashStableNameHeight sa)
err = "inconsistent valuation @ " ++ ctxt ++ ";\n aenv = " ++ show aenv
-- A let binding: the bound term is converted in the current environment, the
-- body in the environment extended by the new binding.
convertSharingAcc config alyt aenv (AletSharing sa@(StableSharingAcc _ boundAcc) bodyAcc)
= AST.OpenAcc
$ let alyt' = incLayout alyt `PushLayout` ZeroIdx
in
AST.Alet (convertSharingAcc config alyt aenv boundAcc)
(convertSharingAcc config alyt' (sa:aenv) bodyAcc)
-- An ordinary node: convert constructor by constructor.
convertSharingAcc config alyt aenv (AccSharing _ preAcc)
= AST.OpenAcc
-- Local shorthands that fix the configuration, layout and environment.
$ let cvtA :: Arrays a => SharingAcc a -> AST.OpenAcc aenv a
cvtA = convertSharingAcc config alyt aenv
cvtE :: Elt t => RootExp t -> AST.Exp aenv t
cvtE = convertRootExp config alyt aenv
cvtF1 :: (Elt a, Elt b) => (Exp a -> RootExp b) -> AST.Fun aenv (a -> b)
cvtF1 = convertSharingFun1 config alyt aenv
cvtF2 :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> RootExp c) -> AST.Fun aenv (a -> b -> c)
cvtF2 = convertSharingFun2 config alyt aenv
in
case preAcc of
Atag i
-> AST.Avar (prjIdx ("de Bruijn conversion tag " ++ show i) i alyt)
-- A pipe becomes a let: the first function applied to the argument is bound,
-- and the second function is applied to that binding (index zero).
Pipe afun1 afun2 acc
-> let alyt' = incLayout alyt `PushLayout` ZeroIdx
boundAcc = aconvert config alyt afun1 `AST.Apply` convertSharingAcc config alyt aenv acc
bodyAcc = aconvert config alyt' afun2 `AST.Apply` AST.OpenAcc (AST.Avar AST.ZeroIdx)
in
AST.Alet (AST.OpenAcc boundAcc) (AST.OpenAcc bodyAcc)
-- The fallback function of a foreign call is converted as a closed function,
-- using the same three flags as the current configuration.
Aforeign ff afun acc
-> let a = recoverAccSharing config
e = recoverExpSharing config
f = floatOutAcc config
in
AST.Aforeign ff (convertAfun a e f afun) (cvtA acc)
Acond b acc1 acc2 -> AST.Acond (cvtE b) (cvtA acc1) (cvtA acc2)
Atuple arrs -> AST.Atuple (convertSharingAtuple config alyt aenv arrs)
Aprj ix a -> AST.Aprj ix (cvtA a)
Use array -> AST.Use (fromArr array)
Unit e -> AST.Unit (cvtE e)
Generate sh f -> AST.Generate (cvtE sh) (cvtF1 f)
Reshape e acc -> AST.Reshape (cvtE e) (cvtA acc)
Replicate ix acc -> mkReplicate (cvtE ix) (cvtA acc)
Slice acc ix -> mkIndex (cvtA acc) (cvtE ix)
Map f acc -> AST.Map (cvtF1 f) (cvtA acc)
ZipWith f acc1 acc2 -> AST.ZipWith (cvtF2 f) (cvtA acc1) (cvtA acc2)
Fold f e acc -> AST.Fold (cvtF2 f) (cvtE e) (cvtA acc)
Fold1 f acc -> AST.Fold1 (cvtF2 f) (cvtA acc)
FoldSeg f e acc1 acc2 -> AST.FoldSeg (cvtF2 f) (cvtE e) (cvtA acc1) (cvtA acc2)
Fold1Seg f acc1 acc2 -> AST.Fold1Seg (cvtF2 f) (cvtA acc1) (cvtA acc2)
Scanl f e acc -> AST.Scanl (cvtF2 f) (cvtE e) (cvtA acc)
Scanl' f e acc -> AST.Scanl' (cvtF2 f) (cvtE e) (cvtA acc)
Scanl1 f acc -> AST.Scanl1 (cvtF2 f) (cvtA acc)
Scanr f e acc -> AST.Scanr (cvtF2 f) (cvtE e) (cvtA acc)
Scanr' f e acc -> AST.Scanr' (cvtF2 f) (cvtE e) (cvtA acc)
Scanr1 f acc -> AST.Scanr1 (cvtF2 f) (cvtA acc)
Permute f dftAcc perm acc -> AST.Permute (cvtF2 f) (cvtA dftAcc) (cvtF1 perm) (cvtA acc)
Backpermute newDim perm acc -> AST.Backpermute (cvtE newDim) (cvtF1 perm) (cvtA acc)
-- The array arguments handed to the stencil converters are ignored by them
-- as values; presumably they only serve to fix types.
Stencil stencil boundary acc
-> AST.Stencil (convertSharingStencilFun1 config acc alyt aenv stencil)
(convertBoundary boundary)
(cvtA acc)
Stencil2 stencil bndy1 acc1 bndy2 acc2
-> AST.Stencil2 (convertSharingStencilFun2 config acc1 acc2 alyt aenv stencil)
(convertBoundary bndy1)
(cvtA acc1)
(convertBoundary bndy2)
(cvtA acc2)
-- | Convert every component of a tuple of array computations, keeping the
-- tuple structure intact.
convertSharingAtuple
    :: forall aenv a.
       Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> Tuple.Atuple SharingAcc a
    -> Tuple.Atuple (AST.OpenAcc aenv) a
convertSharingAtuple config alyt aenv = go
  where
    go :: Tuple.Atuple SharingAcc a' -> Tuple.Atuple (AST.OpenAcc aenv) a'
    go atup = case atup of
      NilAtup            -> NilAtup
      SnocAtup rest comp -> SnocAtup (go rest) (convertSharingAcc config alyt aenv comp)
-- | Convert a stencil boundary condition to its representation type. Only a
-- constant boundary carries a value that needs converting.
convertBoundary :: Elt e => Boundary e -> Boundary (EltRepr e)
convertBoundary boundary = case boundary of
  Clamp          -> Clamp
  Mirror         -> Mirror
  Wrap           -> Wrap
  Constant value -> Constant (fromElt value)
-- | Build a slice node. The slice index descriptor is derived from the type
-- @slix@ alone, so an undefined proxy value of that type is passed.
mkIndex :: forall slix e aenv. (Slice slix, Elt e)
        => AST.OpenAcc aenv (Array (FullShape slix) e)
        -> AST.Exp aenv slix
        -> AST.PreOpenAcc AST.OpenAcc aenv (Array (SliceShape slix) e)
mkIndex = AST.Slice (sliceIndex (undefined :: slix))
-- | Build a replicate node. The slice index descriptor is derived from the
-- type @slix@ alone, so an undefined proxy value of that type is passed.
mkReplicate :: forall slix e aenv. (Slice slix, Elt e)
            => AST.Exp aenv slix
            -> AST.OpenAcc aenv (Array (SliceShape slix) e)
            -> AST.PreOpenAcc AST.OpenAcc aenv (Array (FullShape slix) e)
mkReplicate = AST.Replicate (sliceIndex (undefined :: slix))
-- | Convert a closed scalar function into its de Bruijn representation.
--
-- The flag selects sharing recovery for scalar expressions; array sharing
-- recovery and floating are always switched off here.
convertFun :: Function f => Bool -> f -> AST.Fun () (FunctionR f)
convertFun shareExp = convert config EmptyLayout
  where
    config = Config
      { recoverAccSharing = False
      , recoverExpSharing = shareExp
      , floatOutAcc       = False
      }
-- | Scalar functions that can be converted to de Bruijn form.
class Function f where
  -- | The type of the function once the 'Exp' wrappers have been removed.
  type FunctionR f

  -- | Convert a function, given the layout of the scalar environment that is
  -- in scope at this point. The array environment of the result is empty.
  convert
      :: Config
      -> Layout env env
      -> f
      -> AST.OpenFun env () (FunctionR f)
-- | One more argument: apply the function to a fresh tag, numbered by the
-- current size of the layout, and convert the result under an extended
-- layout.
instance (Elt a, Function r) => Function (Exp a -> r) where
  type FunctionR (Exp a -> r) = a -> FunctionR r

  convert config lyt f = Lam (convert config extended (f argument))
    where
      argument = Exp (Tag (sizeLayout lyt))
      extended = PushLayout (incLayout lyt) ZeroIdx
-- | Base case: all arguments have been consumed, so convert the body.
instance Elt b => Function (Exp b) where
  type FunctionR (Exp b) = b

  convert config lyt body
    = let lvl  = sizeLayout lyt
          -- The tags handed out for the enclosing lambdas are 0 .. lvl-1
          -- (see the function instance); list them most recent first.
          -- For lvl == 0 this is the empty list.
          vars = [lvl-1, lvl-2 .. 0]
      in
      Body $ convertOpenExp config lvl vars lyt body
-- | Convert a closed scalar expression into its de Bruijn representation.
--
-- The flag selects sharing recovery for scalar expressions; array sharing
-- recovery and floating are always switched off here.
convertExp
    :: Elt e
    => Bool
    -> Exp e
    -> AST.Exp () e
convertExp shareExp exp =
    convertOpenExp config 0 [] EmptyLayout exp
  where
    config = Config
      { recoverAccSharing = False
      , recoverExpSharing = shareExp
      , floatOutAcc       = False
      }
-- | Convert a scalar expression that may mention the tags listed in @fvar@.
--
-- Sharing is recovered first; the result is then converted with an empty
-- array layout and an empty array environment.
convertOpenExp
    :: Elt e
    => Config
    -> Level
    -> [Level]
    -> Layout env env
    -> Exp e
    -> AST.OpenExp env () e
convertOpenExp config lvl fvar lyt exp =
    convertSharingExp config lyt EmptyLayout initialEnv [] sharingExp
  where
    (sharingExp, initialEnv) = recoverSharingExp config lvl fvar exp
-- Convert a scalar expression annotated with sharing information into its
-- de Bruijn representation.
--
-- 'lyt' and 'env' describe the scalar environment, 'alyt' and 'aenv' the
-- array environment. The environment lists hold the shared subterms that are
-- currently bound, most recently bound first.
convertSharingExp
:: forall t env aenv. Elt t
=> Config
-> Layout env env
-> Layout aenv aenv
-> [StableSharingExp]
-> [StableSharingAcc]
-> SharingExp t
-> AST.OpenExp env aenv t
convertSharingExp config lyt alyt env aenv = cvt
where
cvt :: Elt t' => SharingExp t' -> AST.OpenExp env aenv t'
-- A reference to a shared subterm: look it up in the environment by stable name.
cvt (VarSharing se)
| Just i <- findIndex (matchStableExp se) env
= AST.Var (prjIdx (ctxt ++ "; i = " ++ show i) i lyt)
-- Not found and nothing is bound at all: reported as a cyclic definition.
| null env
= error $ "Cyclic definition of a value of type 'Exp' (sa = " ++ show (hashStableNameHeight se) ++ ")"
-- Not found although bindings exist: the environment is inconsistent.
| otherwise
= INTERNAL_ERROR(error) "convertSharingExp" err
where
ctxt = "shared 'Exp' tree with stable name " ++ show (hashStableNameHeight se)
err = "inconsistent valuation @ " ++ ctxt ++ ";\n env = " ++ show env
-- A let binding: the bound term is converted in the current environment, the
-- body in the scalar environment extended by the new binding.
cvt (LetSharing se@(StableSharingExp _ boundExp) bodyExp)
= let lyt' = incLayout lyt `PushLayout` ZeroIdx
in
AST.Let (cvt boundExp) (convertSharingExp config lyt' alyt (se:env) aenv bodyExp)
-- An ordinary node: convert constructor by constructor.
cvt (ExpSharing _ pexp)
= case pexp of
Tag i -> AST.Var (prjIdx ("de Bruijn conversion tag " ++ show i) i lyt)
Const v -> AST.Const (fromElt v)
Tuple tup -> AST.Tuple (cvtT tup)
Prj idx e -> AST.Prj idx (cvt e)
IndexNil -> AST.IndexNil
IndexCons ix i -> AST.IndexCons (cvt ix) (cvt i)
IndexHead i -> AST.IndexHead (cvt i)
IndexTail ix -> AST.IndexTail (cvt ix)
IndexAny -> AST.IndexAny
ToIndex sh ix -> AST.ToIndex (cvt sh) (cvt ix)
FromIndex sh e -> AST.FromIndex (cvt sh) (cvt e)
Cond e1 e2 e3 -> AST.Cond (cvt e1) (cvt e2) (cvt e3)
PrimConst c -> AST.PrimConst c
PrimApp f e -> cvtPrimFun f (cvt e)
Index a e -> AST.Index (cvtA a) (cvt e)
LinearIndex a i -> AST.LinearIndex (cvtA a) (cvt i)
Shape a -> AST.Shape (cvtA a)
ShapeSize e -> AST.ShapeSize (cvt e)
-- The fallback function of a foreign call is converted as a closed function,
-- with scalar sharing recovery as configured.
Foreign ff f e -> AST.Foreign ff (convertFun (recoverExpSharing config) f) (cvt e)
-- Local shorthands that fix the configuration, layouts and environments.
cvtA :: Arrays a => SharingAcc a -> AST.OpenAcc aenv a
cvtA = convertSharingAcc config alyt aenv
cvtT :: Tuple.Tuple SharingExp tup -> Tuple.Tuple (AST.OpenExp env aenv) tup
cvtT = convertSharingTuple config lyt alyt env aenv
-- Apply a primitive function to an argument. If the argument is a let
-- binding, the application is pushed into the body of that binding, so the
-- let ends up outside the primitive application.
cvtPrimFun :: (Elt a, Elt r)
=> AST.PrimFun (a -> r) -> AST.OpenExp env' aenv' a -> AST.OpenExp env' aenv' r
cvtPrimFun f e = case e of
AST.Let bnd body -> AST.Let bnd (cvtPrimFun f body)
x -> AST.PrimApp f x
-- | Convert every component of a tuple of scalar expressions, keeping the
-- tuple structure intact.
convertSharingTuple
    :: Config
    -> Layout env env
    -> Layout aenv aenv
    -> [StableSharingExp]
    -> [StableSharingAcc]
    -> Tuple.Tuple SharingExp t
    -> Tuple.Tuple (AST.OpenExp env aenv) t
convertSharingTuple _ _ _ _ _ NilTup = NilTup
convertSharingTuple config lyt alyt env aenv (SnocTup rest comp) =
    SnocTup (convertSharingTuple config lyt alyt env aenv rest)
            (convertSharingExp   config lyt alyt env aenv comp)
-- | Convert a root scalar expression embedded in an array computation.
--
-- The root must already carry its environment of shared subterms; any other
-- form is an internal error. The scalar layout starts out empty.
convertRootExp
    :: Elt t
    => Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> RootExp t
    -> AST.Exp aenv t
convertRootExp config alyt aenv (EnvExp env body)
    = convertSharingExp config EmptyLayout alyt env aenv body
convertRootExp _ _ _ _
    = INTERNAL_ERROR(error) "convertRootExp" "not an 'EnvExp'"
-- | Convert a unary scalar function whose body is a root expression.
--
-- The function is applied to a dummy argument only to get at its body; the
-- body is then converted under a layout holding the single lambda-bound
-- variable.
convertSharingFun1
    :: forall a b aenv. (Elt a, Elt b)
    => Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> (Exp a -> RootExp b)
    -> AST.Fun aenv (a -> b)
convertSharingFun1 config alyt aenv f =
  let dummyArg        = Exp undefined
      argLayout       = PushLayout EmptyLayout (ZeroIdx :: Idx ((), a) a)
      -- Kept as a lazy pattern binding, so that the match is only forced
      -- when the converted body is demanded.
      EnvExp env body = f dummyArg
      convertedBody   = convertSharingExp config argLayout alyt env aenv body
  in
  Lam (Body convertedBody)
-- | Convert a binary scalar function whose body is a root expression.
--
-- The function is applied to dummy arguments only to get at its body; the
-- body is then converted under a layout holding the two lambda-bound
-- variables, the first argument being the outer one.
convertSharingFun2
    :: forall a b c aenv. (Elt a, Elt b, Elt c)
    => Config
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> (Exp a -> Exp b -> RootExp c)
    -> AST.Fun aenv (a -> b -> c)
convertSharingFun2 config alyt aenv f =
  let dummyA          = Exp undefined
      dummyB          = Exp undefined
      outerLayout     = PushLayout EmptyLayout (SuccIdx ZeroIdx :: Idx (((), a), b) a)
      argLayout       = PushLayout outerLayout (ZeroIdx :: Idx (((), a), b) b)
      -- Kept as a lazy pattern binding, so that the match is only forced
      -- when the converted body is demanded.
      EnvExp env body = f dummyA dummyB
      convertedBody   = convertSharingExp config argLayout alyt env aenv body
  in
  Lam (Lam (Body convertedBody))
-- | Convert a unary stencil function whose body is a root expression.
--
-- The array argument is ignored as a value. The stencil function is applied
-- to the projection of a dummy stencil only to get at its body.
convertSharingStencilFun1
    :: forall sh a stencil b aenv. (Elt a, Stencil sh a stencil, Elt b)
    => Config
    -> SharingAcc (Array sh a)
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> (stencil -> RootExp b)
    -> AST.Fun aenv (StencilRepr sh stencil -> b)
convertSharingStencilFun1 config _ alyt aenv stencilFun =
  let dummy           = Exp undefined :: Exp (StencilRepr sh stencil)
      argLayout       = PushLayout EmptyLayout
                          (ZeroIdx :: Idx ((), StencilRepr sh stencil)
                                          (StencilRepr sh stencil))
      window          = stencilPrj (undefined::sh) (undefined::a) dummy
      -- Kept as a lazy pattern binding, so that the match is only forced
      -- when the converted body is demanded.
      EnvExp env body = stencilFun window
      convertedBody   = convertSharingExp config argLayout alyt env aenv body
  in
  Lam (Body convertedBody)
-- | Convert a binary stencil function whose body is a root expression.
--
-- Both array arguments are ignored as values. The stencil function is
-- applied to the projections of two dummy stencils only to get at its body.
convertSharingStencilFun2
    :: forall sh a b stencil1 stencil2 c aenv.
       (Elt a, Stencil sh a stencil1,
        Elt b, Stencil sh b stencil2,
        Elt c)
    => Config
    -> SharingAcc (Array sh a)
    -> SharingAcc (Array sh b)
    -> Layout aenv aenv
    -> [StableSharingAcc]
    -> (stencil1 -> stencil2 -> RootExp c)
    -> AST.Fun aenv (StencilRepr sh stencil1 -> StencilRepr sh stencil2 -> c)
convertSharingStencilFun2 config _ _ alyt aenv stencilFun =
  let dummy1          = Exp undefined :: Exp (StencilRepr sh stencil1)
      dummy2          = Exp undefined :: Exp (StencilRepr sh stencil2)
      outerLayout     = PushLayout EmptyLayout
                          (SuccIdx ZeroIdx :: Idx (((), StencilRepr sh stencil1),
                                                        StencilRepr sh stencil2)
                                                  (StencilRepr sh stencil1))
      argLayout       = PushLayout outerLayout
                          (ZeroIdx :: Idx (((), StencilRepr sh stencil1),
                                                StencilRepr sh stencil2)
                                          (StencilRepr sh stencil2))
      window1         = stencilPrj (undefined::sh) (undefined::a) dummy1
      window2         = stencilPrj (undefined::sh) (undefined::b) dummy2
      -- Kept as a lazy pattern binding, so that the match is only forced
      -- when the converted body is demanded.
      EnvExp env body = stencilFun window1 window2
      convertedBody   = convertSharingExp config argLayout alyt env aenv body
  in
  Lam (Lam (Body convertedBody))
-- | The stable name of an AST node of kind @c@, with the type index of the
-- node hidden.
data StableASTName c where
  StableASTName :: (Typeable1 c, Typeable t)
                => StableName (c t)
                -> StableASTName c

-- | Shows the hash of the underlying stable name.
instance Show (StableASTName c) where
  show (StableASTName sn) = show (hashStableName sn)

-- | Two names are equal when their hidden types agree and the underlying
-- stable names are equal.
instance Eq (StableASTName c) where
  StableASTName sn1 == StableASTName sn2 = maybe False (== sn2) (gcast sn1)

-- | Hashes the underlying stable name.
instance Hashable (StableASTName c) where
  hashWithSalt salt (StableASTName sn) = hashWithSalt salt sn
-- | Create the stable name of an AST node. The node is evaluated to weak
-- head normal form before the name is taken.
makeStableAST :: c t -> IO (StableName (c t))
makeStableAST node = makeStableName $! node
-- | A stable name paired with the height of the tree it names.
data StableNameHeight t = StableNameHeight (StableName t) Int

-- | Equality looks at the stable name only; the height is ignored.
instance Eq (StableNameHeight t) where
  lhs == rhs = nameOf lhs == nameOf rhs
    where
      nameOf (StableNameHeight sn _) = sn
-- | True when the first name has a strictly greater height than the second.
higherSNH :: StableNameHeight t1 -> StableNameHeight t2 -> Bool
higherSNH (StableNameHeight _ height1) (StableNameHeight _ height2) = height2 < height1
-- | Hash of the stable name; the height does not contribute.
hashStableNameHeight :: StableNameHeight t -> Int
hashStableNameHeight snh = case snh of
  StableNameHeight sn _ -> hashStableName sn
-- The hash table implementation used for the occurrence maps in this module.
type HashTable key val = Hash.BasicHashTable key val
-- A hash table keyed by the stable names of AST nodes of kind 'c'.
type ASTHashTable c v = HashTable (StableASTName c) v
-- Mutable occurrence map. Each node maps to a pair of its occurrence count
-- and the height recorded at its first occurrence (see 'enterOcc').
type OccMapHash c = ASTHashTable c (Int, Int)
-- Create a new, empty hash table keyed by AST stable names.
newASTHashTable :: IO (ASTHashTable c v)
newASTHashTable = Hash.new
-- | Record one occurrence of a node in the occurrence map.
--
-- On the first occurrence the node is entered with a count of one and the
-- given height, and 'Nothing' is returned. On any later occurrence the count
-- is incremented, the stored height is kept, and that stored height is
-- returned.
enterOcc :: OccMapHash c -> StableASTName c -> Int -> IO (Maybe Int)
enterOcc occMap sa height = Hash.lookup occMap sa >>= record
  where
    record previous = case previous of
      Nothing -> do
        Hash.insert occMap sa (1, height)
        return Nothing
      Just (count, storedHeight) -> do
        Hash.insert occMap sa (count + 1, storedHeight)
        return (Just storedHeight)
-- Immutable occurrence map. Keys are stable name hashes; each bucket lists
-- the nodes with that hash together with their occurrence counts.
type OccMap c = IntMap.IntMap [(StableASTName c, Int)]
-- | Turn a mutable occurrence map into an immutable one, dropping the
-- recorded heights and keeping only the occurrence counts.
--
-- Entries are bucketed by the hash of their stable name. The buckets are
-- merged with 'IntMap.fromListWith', so that entries sharing a hash end up in
-- the same bucket regardless of the order in which the hash table lists
-- them. Grouping adjacent entries only would silently lose entries whenever
-- two colliding names are not neighbours in that listing.
freezeOccMap :: OccMapHash c -> IO (OccMap c)
freezeOccMap oc
  = do
      ocl <- Hash.toList oc
      traceChunk "OccMap" (show ocl)

      return $ IntMap.fromListWith (++)
                 [ (key kv, [kv]) | kv <- map dropHeight ocl ]
  where
    key (StableASTName sn, _) = hashStableName sn
    dropHeight (k, (cnt, _))  = (k, cnt)
-- | Occurrence count of a node. A node that is not in the map is reported
-- with a count of one.
lookupWithASTName :: OccMap c -> StableASTName c -> Int
lookupWithASTName oc sa@(StableASTName sn) = fromMaybe 1 recorded
  where
    recorded = do
      bucket <- IntMap.lookup (hashStableName sn) oc
      Prelude.lookup sa bucket
-- | Occurrence count of a shared array computation, looked up by its stable
-- name.
lookupWithSharingAcc :: OccMap Acc -> StableSharingAcc -> Int
lookupWithSharingAcc oc (StableSharingAcc snh _) = case snh of
  StableNameHeight sn _ -> lookupWithASTName oc (StableASTName sn)
-- | Occurrence count of a shared scalar expression, looked up by its stable
-- name.
lookupWithSharingExp :: OccMap Exp -> StableSharingExp -> Int
lookupWithSharingExp oc (StableSharingExp snh _) = case snh of
  StableNameHeight sn _ -> lookupWithASTName oc (StableASTName sn)
-- Stable name of an array computation, paired with the height of its tree.
type StableAccName arrs = StableNameHeight (Acc arrs)
-- | Array computations annotated with sharing information.
data SharingAcc arrs where
  -- | A reference to a shared subterm, identified by its stable name.
  AvarSharing :: Arrays arrs
              => StableAccName arrs
              -> SharingAcc arrs
  -- | A binding of a shared subterm, scoped over a body.
  AletSharing :: StableSharingAcc
              -> SharingAcc arrs
              -> SharingAcc arrs
  -- | An ordinary node, labelled with its stable name.
  AccSharing  :: Arrays arrs
              => StableAccName arrs
              -> PreAcc SharingAcc RootExp arrs
              -> SharingAcc arrs
-- | A sharing-annotated array computation together with its stable name,
-- with the result type hidden.
data StableSharingAcc where
  StableSharingAcc :: Arrays arrs
                   => StableAccName arrs
                   -> SharingAcc arrs
                   -> StableSharingAcc

-- | Shows the hash of the stable name only.
instance Show StableSharingAcc where
  show (StableSharingAcc sn _) = show (hashStableNameHeight sn)

-- | Two values are equal when their hidden types agree and their stable
-- names are equal; the computations themselves are not compared.
instance Eq StableSharingAcc where
  StableSharingAcc sn1 _ == StableSharingAcc sn2 _ = maybe False (== sn2) (gcast sn1)
-- | True when the first computation has a strictly greater recorded height
-- than the second.
higherSSA :: StableSharingAcc -> StableSharingAcc -> Bool
higherSSA (StableSharingAcc sn1 _) (StableSharingAcc sn2 _) = higherSNH sn1 sn2
-- | Does the given stable name belong to the given shared computation? The
-- answer is 'False' whenever the two have different types.
matchStableAcc :: Typeable arrs => StableAccName arrs -> StableSharingAcc -> Bool
matchStableAcc sn1 (StableSharingAcc sn2 _) = maybe False (== sn2) (gcast sn1)
-- | A placeholder stable name of height zero, made from an undefined value.
-- The value is never evaluated: taking a stable name does not force it.
noStableAccName :: StableAccName arrs
noStableAccName = unsafePerformIO $ do
  sn <- makeStableName undefined
  return (StableNameHeight sn 0)
-- Stable name of a scalar expression, paired with the height of its tree.
type StableExpName t = StableNameHeight (Exp t)
-- | Scalar expressions annotated with sharing information.
data SharingExp t where
  -- | A reference to a shared subterm, identified by its stable name.
  VarSharing :: Elt t
             => StableExpName t
             -> SharingExp t
  -- | A binding of a shared subterm, scoped over a body.
  LetSharing :: StableSharingExp
             -> SharingExp t
             -> SharingExp t
  -- | An ordinary node, labelled with its stable name.
  ExpSharing :: Elt t
             => StableExpName t
             -> PreExp SharingAcc SharingExp t
             -> SharingExp t
-- | A scalar expression at the point where it is embedded in an array
-- computation.
data RootExp t where
  -- | The form produced by the occurrence counting pass: a list of levels,
  -- the frozen occurrence map of the expression, and the expression itself.
  OccMapExp :: [Int]
            -> OccMap Exp
            -> SharingExp t
            -> RootExp t
  -- | The form consumed by the de Bruijn conversion: the environment of
  -- shared subterms, and the expression itself.
  EnvExp    :: [StableSharingExp]
            -> SharingExp t
            -> RootExp t
-- | A sharing-annotated scalar expression together with its stable name,
-- with the result type hidden.
data StableSharingExp where
  StableSharingExp :: Elt t
                   => StableExpName t
                   -> SharingExp t
                   -> StableSharingExp

-- | Shows the hash of the stable name only.
instance Show StableSharingExp where
  show (StableSharingExp sn _) = show (hashStableNameHeight sn)

-- | Two values are equal when their hidden types agree and their stable
-- names are equal; the expressions themselves are not compared.
instance Eq StableSharingExp where
  StableSharingExp sn1 _ == StableSharingExp sn2 _ = maybe False (== sn2) (gcast sn1)
-- | True when the first expression has a strictly greater recorded height
-- than the second.
higherSSE :: StableSharingExp -> StableSharingExp -> Bool
higherSSE (StableSharingExp sn1 _) (StableSharingExp sn2 _) = higherSNH sn1 sn2
-- | Does the given stable name belong to the given shared expression? The
-- answer is 'False' whenever the two have different types.
matchStableExp :: Typeable t => StableExpName t -> StableSharingExp -> Bool
matchStableExp sn1 (StableSharingExp sn2 _) = maybe False (== sn2) (gcast sn1)
-- | A placeholder stable name of height zero, made from an undefined value.
-- The value is never evaluated: taking a stable name does not force it.
noStableExpName :: StableExpName t
noStableExpName = unsafePerformIO $ do
  sn <- makeStableName undefined
  return (StableNameHeight sn 0)
-- | Count the occurrences of every node of an array computation.
--
-- Returns the computation annotated with stable names, together with the
-- frozen occurrence map of its array nodes. The height computed for the
-- root is discarded.
makeOccMapAcc
    :: Typeable arrs
    => Config
    -> Level
    -> Acc arrs
    -> IO (SharingAcc arrs, OccMap Acc)
makeOccMapAcc config lvl acc = do
  traceLine "makeOccMapAcc" "Enter"
  occTable        <- newASTHashTable
  (annotated, _)  <- makeOccMapSharingAcc config occTable lvl acc
  frozen          <- freezeOccMap occTable
  traceLine "makeOccMapAcc" "Exit"
  return (annotated, frozen)
-- Traverse an array computation, entering every node into the occurrence map
-- 'accOccMap' and annotating it with its stable name and height.
--
-- The result pairs the annotated computation with the height of its tree.
makeOccMapSharingAcc
:: Typeable arrs
=> Config
-> OccMapHash Acc
-> Level
-> Acc arrs
-> IO (SharingAcc arrs, Int)
makeOccMapSharingAcc config accOccMap = traverseAcc
where
-- Shorthands for traversing embedded scalar functions and expressions; they
-- share the array occurrence map.
traverseFun1 :: (Elt a, Typeable b) => Level -> (Exp a -> Exp b) -> IO (Exp a -> RootExp b, Int)
traverseFun1 = makeOccMapFun1 config accOccMap
traverseFun2 :: (Elt a, Elt b, Typeable c)
=> Level
-> (Exp a -> Exp b -> Exp c)
-> IO (Exp a -> Exp b -> RootExp c, Int)
traverseFun2 = makeOccMapFun2 config accOccMap
traverseExp :: Typeable e => Level -> Exp e -> IO (RootExp e, Int)
traverseExp = makeOccMapExp config accOccMap
traverseAcc :: forall arrs. Typeable arrs => Level -> Acc arrs -> IO (SharingAcc arrs, Int)
traverseAcc lvl acc@(Acc pacc)
-- The height of this node is only known once its children have been
-- traversed, yet it has to be entered into the occurrence map beforehand.
-- 'mfix' with a lazy pattern makes the final height available up front.
= mfix $ \ ~(_, height) -> do
sn <- makeStableAST acc
heightIfRepeatedOccurrence <- enterOcc accOccMap (StableASTName sn) height
traceLine (showPreAccOp pacc) $ do
let hash = show (hashStableName sn)
case heightIfRepeatedOccurrence of
Just height -> "REPEATED occurrence (sn = " ++ hash ++ "; height = " ++ show height ++ ")"
Nothing -> "first occurrence (sn = " ++ hash ++ ")"
-- Build the result for this node. A repeated occurrence with array sharing
-- recovery enabled becomes a variable reference carrying the stored height,
-- and the given traversal action is not run. In every other case the action
-- is run and its result is labelled with the stable name and height.
let reconstruct :: Arrays arrs
=> IO (PreAcc SharingAcc RootExp arrs, Int)
-> IO (SharingAcc arrs, Int)
reconstruct newAcc
= case heightIfRepeatedOccurrence of
Just height | recoverAccSharing config
-> return (AvarSharing (StableNameHeight sn height), height)
_ -> do (acc, height) <- newAcc
return (AccSharing (StableNameHeight sn height) acc, height)
-- Heights: a tag has height 0, 'Use' has height 1, and most other nodes
-- are one higher than their highest child.
case pacc of
Atag i -> reconstruct $ return (Atag i, 0)
-- The function arguments of 'Pipe' and 'Aforeign' are passed on unchanged;
-- only the array argument is traversed.
Pipe afun1 afun2 acc -> reconstruct $ travA (Pipe afun1 afun2) acc
Aforeign ff afun acc -> reconstruct $ travA (Aforeign ff afun) acc
Acond e acc1 acc2 -> reconstruct $ do
(e' , h1) <- traverseExp lvl e
(acc1', h2) <- traverseAcc lvl acc1
(acc2', h3) <- traverseAcc lvl acc2
return (Acond e' acc1' acc2', h1 `max` h2 `max` h3 + 1)
Atuple tup -> reconstruct $ do
(tup', h) <- travAtup tup
return (Atuple tup', h)
Aprj ix a -> reconstruct $ travA (Aprj ix) a
Use arr -> reconstruct $ return (Use arr, 1)
Unit e -> reconstruct $ do
(e', h) <- traverseExp lvl e
return (Unit e', h + 1)
Generate e f -> reconstruct $ do
(e', h1) <- traverseExp lvl e
(f', h2) <- traverseFun1 lvl f
return (Generate e' f', h1 `max` h2 + 1)
Reshape e acc -> reconstruct $ travEA Reshape e acc
Replicate e acc -> reconstruct $ travEA Replicate e acc
Slice acc e -> reconstruct $ travEA (flip Slice) e acc
Map f acc -> reconstruct $ do
(f' , h1) <- traverseFun1 lvl f
(acc', h2) <- traverseAcc lvl acc
return (Map f' acc', h1 `max` h2 + 1)
ZipWith f acc1 acc2 -> reconstruct $ travF2A2 ZipWith f acc1 acc2
Fold f e acc -> reconstruct $ travF2EA Fold f e acc
Fold1 f acc -> reconstruct $ travF2A Fold1 f acc
FoldSeg f e acc1 acc2 -> reconstruct $ do
(f' , h1) <- traverseFun2 lvl f
(e' , h2) <- traverseExp lvl e
(acc1', h3) <- traverseAcc lvl acc1
(acc2', h4) <- traverseAcc lvl acc2
return (FoldSeg f' e' acc1' acc2',
h1 `max` h2 `max` h3 `max` h4 + 1)
Fold1Seg f acc1 acc2 -> reconstruct $ travF2A2 Fold1Seg f acc1 acc2
Scanl f e acc -> reconstruct $ travF2EA Scanl f e acc
Scanl' f e acc -> reconstruct $ travF2EA Scanl' f e acc
Scanl1 f acc -> reconstruct $ travF2A Scanl1 f acc
Scanr f e acc -> reconstruct $ travF2EA Scanr f e acc
Scanr' f e acc -> reconstruct $ travF2EA Scanr' f e acc
Scanr1 f acc -> reconstruct $ travF2A Scanr1 f acc
Permute c acc1 p acc2 -> reconstruct $ do
(c' , h1) <- traverseFun2 lvl c
(p' , h2) <- traverseFun1 lvl p
(acc1', h3) <- traverseAcc lvl acc1
(acc2', h4) <- traverseAcc lvl acc2
return (Permute c' acc1' p' acc2',
h1 `max` h2 `max` h3 `max` h4 + 1)
Backpermute e p acc -> reconstruct $ do
(e' , h1) <- traverseExp lvl e
(p' , h2) <- traverseFun1 lvl p
(acc', h3) <- traverseAcc lvl acc
return (Backpermute e' p' acc', h1 `max` h2 `max` h3 + 1)
Stencil s bnd acc -> reconstruct $ do
(s' , h1) <- makeOccMapStencil1 config accOccMap acc lvl s
(acc', h2) <- traverseAcc lvl acc
return (Stencil s' bnd acc', h1 `max` h2 + 1)
Stencil2 s bnd1 acc1
bnd2 acc2 -> reconstruct $ do
(s' , h1) <- makeOccMapStencil2 config accOccMap acc1 acc2 lvl s
(acc1', h2) <- traverseAcc lvl acc1
(acc2', h3) <- traverseAcc lvl acc2
return (Stencil2 s' bnd1 acc1' bnd2 acc2',
h1 `max` h2 `max` h3 + 1)
where
-- Helpers for the common argument shapes. Each traverses the arguments in
-- order, rebuilds the node with the given constructor, and returns the
-- maximum child height plus one.
travA :: Arrays arrs'
=> (SharingAcc arrs' -> PreAcc SharingAcc RootExp arrs)
-> Acc arrs' -> IO (PreAcc SharingAcc RootExp arrs, Int)
travA c acc
= do
(acc', h) <- traverseAcc lvl acc
return (c acc', h + 1)
travEA :: (Typeable b, Arrays arrs')
=> (RootExp b -> SharingAcc arrs' -> PreAcc SharingAcc RootExp arrs)
-> Exp b -> Acc arrs' -> IO (PreAcc SharingAcc RootExp arrs, Int)
travEA c exp acc
= do
(exp', h1) <- traverseExp lvl exp
(acc', h2) <- traverseAcc lvl acc
return (c exp' acc', h1 `max` h2 + 1)
travF2A :: (Elt b, Elt c, Typeable d, Arrays arrs')
=> ((Exp b -> Exp c -> RootExp d) -> SharingAcc arrs'
-> PreAcc SharingAcc RootExp arrs)
-> (Exp b -> Exp c -> Exp d) -> Acc arrs'
-> IO (PreAcc SharingAcc RootExp arrs, Int)
travF2A c fun acc
= do
(fun', h1) <- traverseFun2 lvl fun
(acc', h2) <- traverseAcc lvl acc
return (c fun' acc', h1 `max` h2 + 1)
travF2EA :: (Elt b, Elt c, Typeable d, Typeable e, Arrays arrs')
=> ((Exp b -> Exp c -> RootExp d) -> RootExp e -> SharingAcc arrs' -> PreAcc SharingAcc RootExp arrs)
-> (Exp b -> Exp c -> Exp d) -> Exp e -> Acc arrs'
-> IO (PreAcc SharingAcc RootExp arrs, Int)
travF2EA c fun exp acc
= do
(fun', h1) <- traverseFun2 lvl fun
(exp', h2) <- traverseExp lvl exp
(acc', h3) <- traverseAcc lvl acc
return (c fun' exp' acc', h1 `max` h2 `max` h3 + 1)
travF2A2 :: (Elt b, Elt c, Typeable d, Arrays arrs1, Arrays arrs2)
=> ((Exp b -> Exp c -> RootExp d) -> SharingAcc arrs1 -> SharingAcc arrs2 -> PreAcc SharingAcc RootExp arrs)
-> (Exp b -> Exp c -> Exp d) -> Acc arrs1 -> Acc arrs2
-> IO (PreAcc SharingAcc RootExp arrs, Int)
travF2A2 c fun acc1 acc2
= do
(fun' , h1) <- traverseFun2 lvl fun
(acc1', h2) <- traverseAcc lvl acc1
(acc2', h3) <- traverseAcc lvl acc2
return (c fun' acc1' acc2', h1 `max` h2 `max` h3 + 1)
-- Traverse a tuple of array computations. The empty tuple has height 1 and
-- every further component adds one level.
travAtup :: Tuple.Atuple Acc a
-> IO (Tuple.Atuple SharingAcc a, Int)
travAtup NilAtup = return (NilAtup, 1)
travAtup (SnocAtup tup a) = do
(tup', h1) <- travAtup tup
(a', h2) <- traverseAcc lvl a
return (SnocAtup tup' a', h1 `max` h2 + 1)
-- | Count occurrences in a scalar expression that is embedded in an array
-- computation and binds no variables of its own.
makeOccMapExp
    :: Typeable e
    => Config
    -> OccMapHash Acc
    -> Level
    -> Exp e
    -> IO (RootExp e, Int)
makeOccMapExp config accOccMap lvl exp =
    makeOccMapRootExp config accOccMap lvl [] exp
-- | Count occurrences in the body of a unary scalar function.
--
-- The function is applied to a tag carrying the current level, and the body
-- is traversed one level deeper. The returned function ignores its argument
-- and always yields the traversed body.
makeOccMapFun1
    :: (Elt a, Typeable b)
    => Config
    -> OccMapHash Acc
    -> Level
    -> (Exp a -> Exp b)
    -> IO (Exp a -> RootExp b, Int)
makeOccMapFun1 config accOccMap lvl f = do
  let argument = Exp (Tag lvl)
  (body, height) <- makeOccMapRootExp config accOccMap (lvl + 1) [lvl] (f argument)
  return (\_ -> body, height)
-- | Count occurrences in the body of a binary scalar function.
--
-- The first argument is tagged with the current level plus one, the second
-- with the current level, and the body is traversed two levels deeper. The
-- returned function ignores its arguments and always yields the traversed
-- body.
makeOccMapFun2
    :: (Elt a, Elt b, Typeable c)
    => Config
    -> OccMapHash Acc
    -> Level
    -> (Exp a -> Exp b -> Exp c)
    -> IO (Exp a -> Exp b -> RootExp c, Int)
makeOccMapFun2 config accOccMap lvl f = do
  let firstArg  = Exp (Tag (lvl + 1))
      secondArg = Exp (Tag lvl)
  (body, height) <- makeOccMapRootExp config accOccMap (lvl + 2) [lvl, lvl + 1] (f firstArg secondArg)
  return (\_ _ -> body, height)
-- | Count occurrences in the body of a unary stencil function.
--
-- The array argument is ignored as a value. The stencil function is applied
-- to the projection of a tag carrying the current level, and the body is
-- traversed one level deeper.
makeOccMapStencil1
    :: forall sh a b stencil. (Stencil sh a stencil, Typeable b)
    => Config
    -> OccMapHash Acc
    -> Acc (Array sh a)
    -> Level
    -> (stencil -> Exp b)
    -> IO (stencil -> RootExp b, Int)
makeOccMapStencil1 config accOccMap _ lvl stencil = do
  let window = stencilPrj (undefined::sh) (undefined::a) (Exp (Tag lvl))
  (body, height) <- makeOccMapRootExp config accOccMap (lvl + 1) [lvl] (stencil window)
  return (\_ -> body, height)
-- | Count occurrences in the body of a binary stencil function.
--
-- Both array arguments are ignored as values. The first stencil is tagged
-- with the current level plus one, the second with the current level, and
-- the body is traversed two levels deeper.
makeOccMapStencil2
    :: forall sh a b c stencil1 stencil2. (Stencil sh a stencil1, Stencil sh b stencil2, Typeable c)
    => Config
    -> OccMapHash Acc
    -> Acc (Array sh a)
    -> Acc (Array sh b)
    -> Level
    -> (stencil1 -> stencil2 -> Exp c)
    -> IO (stencil1 -> stencil2 -> RootExp c, Int)
makeOccMapStencil2 config accOccMap _ _ lvl stencil = do
  let window1 = stencilPrj (undefined::sh) (undefined::a) (Exp (Tag (lvl + 1)))
      window2 = stencilPrj (undefined::sh) (undefined::b) (Exp (Tag lvl))
  (body, height) <- makeOccMapRootExp config accOccMap (lvl + 2) [lvl, lvl + 1] (stencil window1 window2)
  return (\_ _ -> body, height)
-- | Count occurrences in a root scalar expression.
--
-- A fresh occurrence map is used for the scalar nodes of this expression,
-- while array nodes go into the shared array occurrence map. The result
-- packages the given level list, the frozen scalar occurrence map and the
-- annotated expression.
makeOccMapRootExp
    :: Typeable e
    => Config
    -> OccMapHash Acc
    -> Level
    -> [Int]
    -> Exp e
    -> IO (RootExp e, Int)
makeOccMapRootExp config accOccMap lvl fvs exp = do
  traceLine "makeOccMapRootExp" "Enter"
  occTable            <- newASTHashTable
  (annotated, height) <- makeOccMapSharingExp config accOccMap occTable lvl exp
  frozen              <- freezeOccMap occTable
  traceLine "makeOccMapRootExp" "Exit"
  return (OccMapExp fvs frozen annotated, height)
-- Traverse a scalar expression, entering every scalar node into the
-- occurrence map 'expOccMap' and annotating it with its stable name and
-- height. Embedded array computations are entered into 'accOccMap'.
--
-- The result pairs the annotated expression with the height of its tree.
makeOccMapSharingExp
:: Typeable e
=> Config
-> OccMapHash Acc
-> OccMapHash Exp
-> Level
-> Exp e
-> IO (SharingExp e, Int)
makeOccMapSharingExp config accOccMap expOccMap = travE
where
travE :: forall a. Typeable a => Level -> Exp a -> IO (SharingExp a, Int)
travE lvl exp@(Exp pexp)
-- The height of this node is only known once its children have been
-- traversed, yet it has to be entered into the occurrence map beforehand.
-- 'mfix' with a lazy pattern makes the final height available up front.
= mfix $ \ ~(_, height) -> do
sn <- makeStableAST exp
heightIfRepeatedOccurrence <- enterOcc expOccMap (StableASTName sn) height
traceLine (showPreExpOp pexp) $ do
let hash = show (hashStableName sn)
case heightIfRepeatedOccurrence of
Just height -> "REPEATED occurrence (sn = " ++ hash ++ "; height = " ++ show height ++ ")"
Nothing -> "first occurrence (sn = " ++ hash ++ ")"
-- Build the result for this node. A repeated occurrence with scalar sharing
-- recovery enabled becomes a variable reference carrying the stored height,
-- and the given traversal action is not run. In every other case the action
-- is run and its result is labelled with the stable name and height.
let reconstruct :: Elt a
=> IO (PreExp SharingAcc SharingExp a, Int)
-> IO (SharingExp a, Int)
reconstruct newExp
= case heightIfRepeatedOccurrence of
Just height | recoverExpSharing config
-> return (VarSharing (StableNameHeight sn height), height)
_ -> do (exp, height) <- newExp
return (ExpSharing (StableNameHeight sn height) exp, height)
-- Heights: a tag has height 0, the other leaves have height 1, and most
-- other nodes are one higher than their highest child.
case pexp of
Tag i -> reconstruct $ return (Tag i, 0)
Const c -> reconstruct $ return (Const c, 1)
Tuple tup -> reconstruct $ do
(tup', h) <- travTup tup
return (Tuple tup', h)
Prj i e -> reconstruct $ travE1 (Prj i) e
IndexNil -> reconstruct $ return (IndexNil, 1)
IndexCons ix i -> reconstruct $ travE2 IndexCons ix i
IndexHead i -> reconstruct $ travE1 IndexHead i
IndexTail ix -> reconstruct $ travE1 IndexTail ix
IndexAny -> reconstruct $ return (IndexAny, 1)
ToIndex sh ix -> reconstruct $ travE2 ToIndex sh ix
FromIndex sh e -> reconstruct $ travE2 FromIndex sh e
Cond e1 e2 e3 -> reconstruct $ travE3 Cond e1 e2 e3
PrimConst c -> reconstruct $ return (PrimConst c, 1)
PrimApp p e -> reconstruct $ travE1 (PrimApp p) e
Index a e -> reconstruct $ travAE Index a e
LinearIndex a i -> reconstruct $ travAE LinearIndex a i
Shape a -> reconstruct $ travA Shape a
ShapeSize e -> reconstruct $ travE1 ShapeSize e
-- The fallback function of a foreign call is passed on unchanged; only the
-- argument is traversed.
Foreign ff f e -> reconstruct $ do
(e', h) <- travE lvl e
return (Foreign ff f e', h+1)
where
-- Embedded array computations are traversed with the array occurrence map.
traverseAcc :: Typeable arrs => Level -> Acc arrs -> IO (SharingAcc arrs, Int)
traverseAcc = makeOccMapSharingAcc config accOccMap
-- Helpers for the common argument shapes. Each traverses the arguments in
-- order, rebuilds the node with the given constructor, and returns the
-- maximum child height plus one.
travE1 :: Typeable b => (SharingExp b -> PreExp SharingAcc SharingExp a) -> Exp b
-> IO (PreExp SharingAcc SharingExp a, Int)
travE1 c e
= do
(e', h) <- travE lvl e
return (c e', h + 1)
travE2 :: (Typeable b, Typeable c)
=> (SharingExp b -> SharingExp c -> PreExp SharingAcc SharingExp a)
-> Exp b -> Exp c
-> IO (PreExp SharingAcc SharingExp a, Int)
travE2 c e1 e2
= do
(e1', h1) <- travE lvl e1
(e2', h2) <- travE lvl e2
return (c e1' e2', h1 `max` h2 + 1)
travE3 :: (Typeable b, Typeable c, Typeable d)
=> (SharingExp b -> SharingExp c -> SharingExp d -> PreExp SharingAcc SharingExp a)
-> Exp b -> Exp c -> Exp d
-> IO (PreExp SharingAcc SharingExp a, Int)
travE3 c e1 e2 e3
= do
(e1', h1) <- travE lvl e1
(e2', h2) <- travE lvl e2
(e3', h3) <- travE lvl e3
return (c e1' e2' e3', h1 `max` h2 `max` h3 + 1)
travA :: Typeable b => (SharingAcc b -> PreExp SharingAcc SharingExp a) -> Acc b
-> IO (PreExp SharingAcc SharingExp a, Int)
travA c acc
= do
(acc', h) <- traverseAcc lvl acc
return (c acc', h + 1)
travAE :: (Typeable b, Typeable c)
=> (SharingAcc b -> SharingExp c -> PreExp SharingAcc SharingExp a)
-> Acc b -> Exp c
-> IO (PreExp SharingAcc SharingExp a, Int)
travAE c acc e
= do
(acc', h1) <- traverseAcc lvl acc
(e' , h2) <- travE lvl e
return (c acc' e', h1 `max` h2 + 1)
-- Traverse a tuple of scalar expressions. The empty tuple has height 1 and
-- every further component adds one level.
travTup :: Tuple.Tuple Exp tup -> IO (Tuple.Tuple SharingExp tup, Int)
travTup NilTup = return (NilTup, 1)
travTup (SnocTup tup e) = do
(tup', h1) <- travTup tup
(e' , h2) <- travE lvl e
return (SnocTup tup' e', h1 `max` h2 + 1)
-- | A list of node counts, as produced and consumed by the scope determination
-- functions of this module. Two such lists are combined with '(+++)'.
--
type NodeCounts = [NodeCount]

-- | A stable sharing node paired with an 'Int' count. There is one constructor for
-- array computations and one for scalar expressions.
--
data NodeCount = AccNodeCount StableSharingAcc Int
               | ExpNodeCount StableSharingExp Int
               deriving Show
-- | The empty collection of node counts.
--
noNodeCounts :: NodeCounts
noNodeCounts = []
-- | A node count collection holding exactly one array computation node together with
-- the given count.
--
accNodeCount :: StableSharingAcc -> Int -> NodeCounts
accNodeCount ssa n = AccNodeCount ssa n : noNodeCounts
-- | A node count collection holding exactly one scalar expression node together with
-- the given count.
--
expNodeCount :: StableSharingExp -> Int -> NodeCounts
expNodeCount sse n = ExpNodeCount sse n : noNodeCounts
-- | Combine two node count collections.
--
-- Every entry of the right operand is inserted, one at a time, into the left operand:
--
--  * an entry for a node that is already present has its count added to the existing
--    entry; of the two stable nodes, one that is not a plain variable is retained;
--
--  * otherwise the entry is placed in front of the first entry of the same kind that it
--    is higher than (as decided by 'higherSSA' and 'higherSSE');
--
--  * an array entry is moved past expression entries, whereas an expression entry is
--    placed in front of an array entry.
--
(+++) :: NodeCounts -> NodeCounts -> NodeCounts
us +++ vs = foldr place us vs
  where
    place new []                 = [new]
    place new rest@(old : older)
      = case (new, old) of
          (AccNodeCount sa1 n1, AccNodeCount sa2 n2)
            | sa1 == sa2          -> AccNodeCount (preferNonAvar sa1 sa2) (n1 + n2) : older
            | sa1 `higherSSA` sa2 -> new : rest
            | otherwise           -> old : place new older
          (ExpNodeCount se1 n1, ExpNodeCount se2 n2)
            | se1 == se2          -> ExpNodeCount (preferNonVar se1 se2) (n1 + n2) : older
            | se1 `higherSSE` se2 -> new : rest
            | otherwise           -> old : place new older
          (AccNodeCount _ _, ExpNodeCount _ _)
                                  -> old : place new older
          (ExpNodeCount _ _, AccNodeCount _ _)
                                  -> new : place old older

    -- Take the second node only if the first one is merely a variable.
    preferNonAvar (StableSharingAcc _ (AvarSharing _)) other = other
    preferNonAvar chosen                               _     = chosen

    preferNonVar (StableSharingExp _ (VarSharing _)) other = other
    preferNonVar chosen                              _     = chosen
-- | Build the initial array environment: one stable sharing node per requested tag, in
-- the order in which the tags are given.
--
-- A tag for which no node is supplied yields a placeholder whose sharing node is
-- 'undefined'. It is an internal error if several nodes carry the same tag, or if any of
-- the inspected nodes is not a plain 'Atag'.
--
buildInitialEnvAcc :: [Level] -> [StableSharingAcc] -> [StableSharingAcc]
buildInitialEnvAcc tags sas = [ entryFor tag | tag <- tags ]
  where
    entryFor wanted
      = case filter carriesTag sas of
          []      -> placeholder
          [found] -> found
          dups    -> INTERNAL_ERROR(error) "buildInitialEnvAcc"
                       $ "Encountered duplicate 'ATag's\n " ++ intercalate ", " (map describe dups)
      where
        carriesTag (StableSharingAcc _ (AccSharing _ (Atag tag))) = wanted == tag
        carriesTag other
          = INTERNAL_ERROR(error) "buildInitialEnvAcc"
              $ "Encountered a node that is not a plain 'Atag'\n " ++ describe other

    -- Stands in for a tag that has no matching node.
    placeholder :: StableSharingAcc
    placeholder = StableSharingAcc noStableAccName (undefined :: SharingAcc ())

    -- Short rendering of a node for the error messages above.
    describe (StableSharingAcc _ node)
      = case node of
          AccSharing sn acc -> show (hashStableNameHeight sn) ++ ": " ++ showPreAccOp acc
          AvarSharing sn    -> "AvarSharing " ++ show (hashStableNameHeight sn)
          AletSharing sa _  -> "AletSharing " ++ show sa ++ "..."
-- | Build the initial scalar environment: one stable sharing node per requested tag, in
-- the order in which the tags are given.
--
-- A tag for which no node is supplied yields a placeholder whose sharing node is
-- 'undefined'. It is an internal error if several nodes carry the same tag, or if any of
-- the inspected nodes is not a plain 'Tag'.
--
buildInitialEnvExp :: [Level] -> [StableSharingExp] -> [StableSharingExp]
buildInitialEnvExp tags ses = [ entryFor tag | tag <- tags ]
  where
    entryFor wanted
      = case filter carriesTag ses of
          []      -> placeholder
          [found] -> found
          dups    -> INTERNAL_ERROR(error) "buildInitialEnvExp"
                       ("Encountered a duplicate 'Tag'\n " ++ intercalate ", " (map describe dups))
      where
        carriesTag (StableSharingExp _ (ExpSharing _ (Tag tag))) = wanted == tag
        carriesTag other
          = INTERNAL_ERROR(error) "buildInitialEnvExp"
              ("Encountered a node that is not a plain 'Tag'\n " ++ describe other)

    -- Stands in for a tag that has no matching node.
    placeholder :: StableSharingExp
    placeholder = StableSharingExp noStableExpName (undefined :: SharingExp ())

    -- Short rendering of a node for the error messages above.
    describe (StableSharingExp _ node)
      = case node of
          ExpSharing sn e  -> show (hashStableNameHeight sn) ++ ": " ++ showPreExpOp e
          VarSharing sn    -> "VarSharing " ++ show (hashStableNameHeight sn)
          LetSharing se _  -> "LetSharing " ++ show se ++ "..."
-- | Does the node count refer to a free variable, that is, to a node that is a plain
-- 'Atag' (array level) or a plain 'Tag' (scalar level)?
--
isFreeVar :: NodeCount -> Bool
isFreeVar nodeCount
  = case nodeCount of
      AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag _))) _ -> True
      ExpNodeCount (StableSharingExp _ (ExpSharing _ (Tag _)))  _ -> True
      _                                                           -> False
-- | Determine the scopes of all shared subterms of a complete array computation.
--
-- The arguments are the sharing recovery configuration, the tags of the free array
-- variables, the occurrence map of the computation, and the computation itself.
--
-- The result is the computation with its sharing made explicit, paired with the initial
-- environment for the free variables (built by 'buildInitialEnvAcc' from the array node
-- counts that remain at the root).
--
-- Once the root has been processed only free variables may remain in the node counts;
-- anything else is a shared subtree that was never bound, which is reported as an
-- internal error.
--
determineScopesAcc
  :: Typeable a
  => Config
  -> [Level]
  -> OccMap Acc
  -> SharingAcc a
  -> (SharingAcc a, [StableSharingAcc])
determineScopesAcc config fvs accOccMap rootAcc
  | null unboundTrees
  = (sharingAcc, buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- counts])
  | otherwise
  -- Keep the label and the offending nodes apart in the diagnostic.
  = INTERNAL_ERROR(error) "determineScopesAcc" ("unbound shared subtrees: " ++ show unboundTrees)
  where
    (sharingAcc, counts) = determineScopesSharingAcc config accOccMap rootAcc
    unboundTrees         = filter (not . isFreeVar) counts
-- | Determine the scopes of the shared subterms of a sharing-annotated array computation.
--
-- The result is the rewritten computation together with the node counts of those
-- subterms that have not been bound yet. A node that occurs more than once is replaced
-- by an 'AvarSharing' variable (when array sharing recovery is enabled) and is bound by
-- an 'AletSharing' at the node where its count is found to be complete; see the local
-- functions 'reconstruct' and 'filterCompleted'.
--
determineScopesSharingAcc
  :: Config
  -> OccMap Acc
  -> SharingAcc a
  -> (SharingAcc a, NodeCounts)
determineScopesSharingAcc config accOccMap = scopesAcc
  where
    scopesAcc :: forall arrs. SharingAcc arrs -> (SharingAcc arrs, NodeCounts)
    -- Let bindings are only introduced by this pass, so none may be present in the input.
    scopesAcc (AletSharing _ _)
      = INTERNAL_ERROR(error) "determineScopesSharingAcc: scopesAcc" "unexpected 'AletSharing'"
    -- A variable is returned unchanged and contributes one occurrence of its node.
    scopesAcc sharingAcc@(AvarSharing sn)
      = (sharingAcc, StableSharingAcc sn sharingAcc `accNodeCount` 1)
    -- Process all subterms, merge their node counts with '(+++)', and rebuild this node
    -- through 'reconstruct'.
    scopesAcc (AccSharing sn pacc)
      = case pacc of
          Atag i -> reconstruct (Atag i) noNodeCounts
          -- Only the array argument is traversed; the two array functions are passed on as is.
          Pipe afun1 afun2 acc -> travA (Pipe afun1 afun2) acc
          Aforeign ff afun acc -> let
              (acc', accCount) = scopesAcc acc
            in
              reconstruct (Aforeign ff afun acc') accCount
          Acond e acc1 acc2 -> let
              (e' , accCount1) = scopesExp e
              (acc1', accCount2) = scopesAcc acc1
              (acc2', accCount3) = scopesAcc acc2
            in
              reconstruct (Acond e' acc1' acc2')
                (accCount1 +++ accCount2 +++ accCount3)
          Atuple tup -> let (tup', accCount) = travAtup tup
                        in reconstruct (Atuple tup') accCount
          Aprj ix a -> travA (Aprj ix) a
          Use arr -> reconstruct (Use arr) noNodeCounts
          Unit e -> let
              (e', accCount) = scopesExp e
            in
              reconstruct (Unit e') accCount
          Generate sh f -> let
              (sh', accCount1) = scopesExp sh
              (f' , accCount2) = scopesFun1 f
            in
              reconstruct (Generate sh' f') (accCount1 +++ accCount2)
          Reshape sh acc -> travEA Reshape sh acc
          Replicate n acc -> travEA Replicate n acc
          -- 'Slice' takes the array first, hence the flip.
          Slice acc i -> travEA (flip Slice) i acc
          Map f acc -> let
              (f' , accCount1) = scopesFun1 f
              (acc', accCount2) = scopesAcc acc
            in
              reconstruct (Map f' acc') (accCount1 +++ accCount2)
          ZipWith f acc1 acc2 -> travF2A2 ZipWith f acc1 acc2
          Fold f z acc -> travF2EA Fold f z acc
          Fold1 f acc -> travF2A Fold1 f acc
          FoldSeg f z acc1 acc2 -> let
              (f' , accCount1) = scopesFun2 f
              (z' , accCount2) = scopesExp z
              (acc1', accCount3) = scopesAcc acc1
              (acc2', accCount4) = scopesAcc acc2
            in
              reconstruct (FoldSeg f' z' acc1' acc2')
                (accCount1 +++ accCount2 +++ accCount3 +++ accCount4)
          Fold1Seg f acc1 acc2 -> travF2A2 Fold1Seg f acc1 acc2
          Scanl f z acc -> travF2EA Scanl f z acc
          Scanl' f z acc -> travF2EA Scanl' f z acc
          Scanl1 f acc -> travF2A Scanl1 f acc
          Scanr f z acc -> travF2EA Scanr f z acc
          Scanr' f z acc -> travF2EA Scanr' f z acc
          Scanr1 f acc -> travF2A Scanr1 f acc
          Permute fc acc1 fp acc2 -> let
              (fc' , accCount1) = scopesFun2 fc
              (acc1', accCount2) = scopesAcc acc1
              (fp' , accCount3) = scopesFun1 fp
              (acc2', accCount4) = scopesAcc acc2
            in
              reconstruct (Permute fc' acc1' fp' acc2')
                (accCount1 +++ accCount2 +++ accCount3 +++ accCount4)
          Backpermute sh fp acc -> let
              (sh' , accCount1) = scopesExp sh
              (fp' , accCount2) = scopesFun1 fp
              (acc', accCount3) = scopesAcc acc
            in
              reconstruct (Backpermute sh' fp' acc')
                (accCount1 +++ accCount2 +++ accCount3)
          Stencil st bnd acc -> let
              (st' , accCount1) = scopesStencil1 acc st
              (acc', accCount2) = scopesAcc acc
            in
              reconstruct (Stencil st' bnd acc') (accCount1 +++ accCount2)
          Stencil2 st bnd1 acc1 bnd2 acc2
            -> let
                (st' , accCount1) = scopesStencil2 acc1 acc2 st
                (acc1', accCount2) = scopesAcc acc1
                (acc2', accCount3) = scopesAcc acc2
              in
                reconstruct (Stencil2 st' bnd1 acc1' bnd2 acc2')
                  (accCount1 +++ accCount2 +++ accCount3)
      where
        -- Helpers for the common node shapes. The name lists the arguments in order:
        -- E = scalar expression, A = array computation, F2 = binary scalar function.
        travEA :: Arrays arrs
               => (RootExp e -> SharingAcc arrs' -> PreAcc SharingAcc RootExp arrs)
               -> RootExp e
               -> SharingAcc arrs'
               -> (SharingAcc arrs, NodeCounts)
        travEA c e acc = reconstruct (c e' acc') (accCount1 +++ accCount2)
          where
            (e' , accCount1) = scopesExp e
            (acc', accCount2) = scopesAcc acc
        travF2A :: (Elt a, Elt b, Arrays arrs)
                => ((Exp a -> Exp b -> RootExp c) -> SharingAcc arrs'
                    -> PreAcc SharingAcc RootExp arrs)
                -> (Exp a -> Exp b -> RootExp c)
                -> SharingAcc arrs'
                -> (SharingAcc arrs, NodeCounts)
        travF2A c f acc = reconstruct (c f' acc') (accCount1 +++ accCount2)
          where
            (f' , accCount1) = scopesFun2 f
            (acc', accCount2) = scopesAcc acc
        travF2EA :: (Elt a, Elt b, Arrays arrs)
                 => ((Exp a -> Exp b -> RootExp c) -> RootExp e
                     -> SharingAcc arrs' -> PreAcc SharingAcc RootExp arrs)
                 -> (Exp a -> Exp b -> RootExp c)
                 -> RootExp e
                 -> SharingAcc arrs'
                 -> (SharingAcc arrs, NodeCounts)
        travF2EA c f e acc = reconstruct (c f' e' acc') (accCount1 +++ accCount2 +++ accCount3)
          where
            (f' , accCount1) = scopesFun2 f
            (e' , accCount2) = scopesExp e
            (acc', accCount3) = scopesAcc acc
        travF2A2 :: (Elt a, Elt b, Arrays arrs)
                 => ((Exp a -> Exp b -> RootExp c) -> SharingAcc arrs1
                     -> SharingAcc arrs2 -> PreAcc SharingAcc RootExp arrs)
                 -> (Exp a -> Exp b -> RootExp c)
                 -> SharingAcc arrs1
                 -> SharingAcc arrs2
                 -> (SharingAcc arrs, NodeCounts)
        travF2A2 c f acc1 acc2 = reconstruct (c f' acc1' acc2')
                                   (accCount1 +++ accCount2 +++ accCount3)
          where
            (f' , accCount1) = scopesFun2 f
            (acc1', accCount2) = scopesAcc acc1
            (acc2', accCount3) = scopesAcc acc2
        -- Process every component of an array tuple and merge the node counts.
        travAtup :: Tuple.Atuple SharingAcc a
                 -> (Tuple.Atuple SharingAcc a, NodeCounts)
        travAtup NilAtup = (NilAtup, noNodeCounts)
        travAtup (SnocAtup tup a) = let (tup', accCountT) = travAtup tup
                                        (a', accCountA) = scopesAcc a
                                    in
                                    (SnocAtup tup' a', accCountT +++ accCountA)
        travA :: Arrays arrs
              => (SharingAcc arrs' -> PreAcc SharingAcc RootExp arrs)
              -> SharingAcc arrs'
              -> (SharingAcc arrs, NodeCounts)
        travA c acc = reconstruct (c acc') accCount
          where
            (acc', accCount) = scopesAcc acc
        -- Occurrence count recorded in the occurrence map for the node being processed.
        accOccCount = let StableNameHeight sn' _ = sn
                      in
                      lookupWithASTName accOccMap (StableASTName sn')
        -- Rebuild the current node from its processed subterms and their merged node counts.
        --
        --  * A free variable ('Atag') becomes a variable whose only node count is itself;
        --    the node counts of the subterms are ignored.
        --
        --  * A node occurring more than once (with array sharing recovery enabled) becomes a
        --    variable; the node itself, wrapped in the bindings completed here, is added to
        --    the node counts with a count of one.
        --
        --  * Any other node is returned in place, wrapped in the bindings completed here.
        --
        reconstruct :: Arrays arrs
                    => PreAcc SharingAcc RootExp arrs -> NodeCounts
                    -> (SharingAcc arrs, NodeCounts)
        reconstruct newAcc@(Atag _) _subCount
          = let thisCount = StableSharingAcc sn (AccSharing sn newAcc) `accNodeCount` 1
            in
            tracePure "FREE" (show thisCount)
            (AvarSharing sn, thisCount)
        reconstruct newAcc subCount
          | accOccCount > 1 && recoverAccSharing config
          = let allCount = (StableSharingAcc sn sharingAcc `accNodeCount` 1) +++ newCount
            in
            tracePure ("SHARED" ++ completed) (show allCount)
            (AvarSharing sn, allCount)
          | otherwise
          = tracePure ("Normal" ++ completed) (show newCount)
            (sharingAcc, newCount)
          where
            (newCount, bindHere) = filterCompleted subCount
            -- Composes the bindings so that the first element of 'bindHere' ends up outermost.
            lets = foldl (flip (.)) id . map AletSharing $ bindHere
            sharingAcc = lets $ AccSharing sn newAcc
            -- Suffix for the trace message; empty when nothing is bound at this node.
            completed | null bindHere = ""
                      | otherwise = "(" ++ show (length bindHere) ++ " lets)"
        -- Split off the leading run of completed node counts. The first component holds the
        -- remaining node counts, the second the nodes that are to be bound at the current node.
        filterCompleted :: NodeCounts -> (NodeCounts, [StableSharingAcc])
        filterCompleted counts
          = let (completed, counts') = break notComplete counts
            in (counts', [sa | AccNodeCount sa _ <- completed])
          where
            -- An array node that is not a free variable is incomplete while the occurrence map
            -- reports more occurrences than have been counted so far. Every other entry (free
            -- variables, scalar expression nodes) is treated as incomplete and ends the run.
            notComplete nc@(AccNodeCount sa n) | not . isFreeVar $ nc = lookupWithSharingAcc accOccMap sa > n
            notComplete _ = True
    scopesExp :: RootExp t -> (RootExp t, NodeCounts)
    scopesExp = determineScopesExp config accOccMap
    -- The function body is obtained by applying the function to 'undefined'; the function
    -- that is returned ignores its argument and always yields that processed body.
    -- NOTE(review): presumably the body never inspects the argument because bound variables
    -- were already replaced by tags in an earlier phase -- confirm against the phase that
    -- builds the 'RootExp' functions.
    scopesFun1 :: Elt e1 => (Exp e1 -> RootExp e2) -> (Exp e1 -> RootExp e2, NodeCounts)
    scopesFun1 f = (const body, counts)
      where
        (body, counts) = scopesExp (f undefined)
    -- As 'scopesFun1', for functions of two arguments.
    scopesFun2 :: (Elt e1, Elt e2)
               => (Exp e1 -> Exp e2 -> RootExp e3)
               -> (Exp e1 -> Exp e2 -> RootExp e3, NodeCounts)
    scopesFun2 f = (\_ _ -> body, counts)
      where
        (body, counts) = scopesExp (f undefined undefined)
    -- As 'scopesFun1', for stencil functions. The array argument is not inspected; it only
    -- appears in the type signature.
    scopesStencil1 :: forall sh e1 e2 stencil. Stencil sh e1 stencil
                   => SharingAcc (Array sh e1)
                   -> (stencil -> RootExp e2)
                   -> (stencil -> RootExp e2, NodeCounts)
    scopesStencil1 _ stencilFun = (const body, counts)
      where
        (body, counts) = scopesExp (stencilFun undefined)
    -- As 'scopesStencil1', for binary stencil functions.
    scopesStencil2 :: forall sh e1 e2 e3 stencil1 stencil2.
                      (Stencil sh e1 stencil1, Stencil sh e2 stencil2)
                   => SharingAcc (Array sh e1)
                   -> SharingAcc (Array sh e2)
                   -> (stencil1 -> stencil2 -> RootExp e3)
                   -> (stencil1 -> stencil2 -> RootExp e3, NodeCounts)
    scopesStencil2 _ _ stencilFun = (\_ _ -> body, counts)
      where
        (body, counts) = scopesExp (stencilFun undefined undefined)
-- | Determine the scopes of the shared subterms of a root scalar expression.
--
-- The expression has to be an 'OccMapExp'; anything else is an internal error. The
-- result is an 'EnvExp' that pairs the scoped expression with the initial environment
-- for its free scalar variables, together with the array node counts that remain.
--
-- The node counts coming back from the expression are split at the first array node
-- count: the leading scalar node counts are used to build the environment, everything
-- from the first array node count onwards is passed back to the caller.
--
determineScopesExp
  :: Config
  -> OccMap Acc
  -> RootExp t
  -> (RootExp t, NodeCounts)
determineScopesExp config accOccMap (OccMapExp fvs expOccMap exp)
  = (EnvExp initialEnv scopedExp, accCounts)
  where
    (scopedExp, allCounts) = determineScopesSharingExp config accOccMap expOccMap exp
    (expCounts, accCounts) = span isExpNodeCount allCounts
    initialEnv             = buildInitialEnvExp fvs [se | ExpNodeCount se _ <- expCounts]

    isExpNodeCount ExpNodeCount{} = True
    isExpNodeCount _              = False
determineScopesExp _ _ _ = INTERNAL_ERROR(error) "determineScopesExp" "not an 'OccMapExp'"
-- | Determine the scopes of the shared subterms of a sharing-annotated scalar expression.
--
-- The result is the rewritten expression together with the node counts of those
-- subterms (scalar and array) that have not been bound yet. A scalar node that occurs
-- more than once is replaced by a 'VarSharing' variable (when scalar sharing recovery is
-- enabled) and is bound by a 'LetSharing' at the node where its count is found to be
-- complete; see the local functions 'reconstruct' and 'filterCompleted'.
--
determineScopesSharingExp
  :: Config
  -> OccMap Acc
  -> OccMap Exp
  -> SharingExp t
  -> (SharingExp t, NodeCounts)
determineScopesSharingExp config accOccMap expOccMap = scopesExp
  where
    -- Array computations embedded in the expression are handled by the array-level pass.
    scopesAcc :: SharingAcc a -> (SharingAcc a, NodeCounts)
    scopesAcc = determineScopesSharingAcc config accOccMap
    scopesExp :: forall t. SharingExp t -> (SharingExp t, NodeCounts)
    -- Let bindings are only introduced by this pass, so none may be present in the input.
    scopesExp (LetSharing _ _)
      = INTERNAL_ERROR(error) "determineScopesSharingExp: scopesExp" "unexpected 'LetSharing'"
    -- A variable is returned unchanged and contributes one occurrence of its node.
    scopesExp sharingExp@(VarSharing sn)
      = (sharingExp, StableSharingExp sn sharingExp `expNodeCount` 1)
    -- Process all subterms, merge their node counts with '(+++)', and rebuild this node
    -- through 'reconstruct'.
    scopesExp (ExpSharing sn pexp)
      = case pexp of
          Tag i -> reconstruct (Tag i) noNodeCounts
          Const c -> reconstruct (Const c) noNodeCounts
          Tuple tup -> let (tup', accCount) = travTup tup
                       in
                       reconstruct (Tuple tup') accCount
          Prj i e -> travE1 (Prj i) e
          IndexNil -> reconstruct IndexNil noNodeCounts
          IndexCons ix i -> travE2 IndexCons ix i
          IndexHead i -> travE1 IndexHead i
          IndexTail ix -> travE1 IndexTail ix
          IndexAny -> reconstruct IndexAny noNodeCounts
          ToIndex sh ix -> travE2 ToIndex sh ix
          FromIndex sh e -> travE2 FromIndex sh e
          Cond e1 e2 e3 -> travE3 Cond e1 e2 e3
          PrimConst c -> reconstruct (PrimConst c) noNodeCounts
          PrimApp p e -> travE1 (PrimApp p) e
          Index a e -> travAE Index a e
          LinearIndex a e -> travAE LinearIndex a e
          Shape a -> travA Shape a
          ShapeSize e -> travE1 ShapeSize e
          Foreign ff f e -> travE1 (Foreign ff f) e
      where
        -- Process every component of a scalar tuple and merge the node counts.
        travTup :: Tuple.Tuple SharingExp tup -> (Tuple.Tuple SharingExp tup, NodeCounts)
        travTup NilTup = (NilTup, noNodeCounts)
        travTup (SnocTup tup e) = let
            (tup', accCountT) = travTup tup
            (e' , accCountE) = scopesExp e
          in
            (SnocTup tup' e', accCountT +++ accCountE)
        -- Helpers for nodes with one, two or three scalar subterms.
        travE1 :: (SharingExp a -> PreExp SharingAcc SharingExp t) -> SharingExp a
               -> (SharingExp t, NodeCounts)
        travE1 c e = reconstruct (c e') accCount
          where
            (e', accCount) = scopesExp e
        travE2 :: (SharingExp a -> SharingExp b -> PreExp SharingAcc SharingExp t)
               -> SharingExp a
               -> SharingExp b
               -> (SharingExp t, NodeCounts)
        travE2 c e1 e2 = reconstruct (c e1' e2') (accCount1 +++ accCount2)
          where
            (e1', accCount1) = scopesExp e1
            (e2', accCount2) = scopesExp e2
        travE3 :: (SharingExp a -> SharingExp b -> SharingExp c -> PreExp SharingAcc SharingExp t)
               -> SharingExp a
               -> SharingExp b
               -> SharingExp c
               -> (SharingExp t, NodeCounts)
        travE3 c e1 e2 e3 = reconstruct (c e1' e2' e3') (accCount1 +++ accCount2 +++ accCount3)
          where
            (e1', accCount1) = scopesExp e1
            (e2', accCount2) = scopesExp e2
            (e3', accCount3) = scopesExp e3
        -- Helpers for nodes with an array subterm; these go through 'maybeFloatOutAcc'.
        travA :: (SharingAcc a -> PreExp SharingAcc SharingExp t) -> SharingAcc a
              -> (SharingExp t, NodeCounts)
        travA c acc = maybeFloatOutAcc c acc' accCount
          where
            (acc', accCount) = scopesAcc acc
        travAE :: (SharingAcc a -> SharingExp b -> PreExp SharingAcc SharingExp t)
               -> SharingAcc a
               -> SharingExp b
               -> (SharingExp t, NodeCounts)
        travAE c acc e = maybeFloatOutAcc (`c` e') acc' (accCountA +++ accCountE)
          where
            (acc', accCountA) = scopesAcc acc
            (e' , accCountE) = scopesExp e
        -- Rebuild a node that has an array subterm. An array subterm that is already a
        -- variable is used as is. Otherwise, when 'floatOutAcc' is set, the array subterm is
        -- replaced by a variable and the computation itself is added to the node counts with
        -- a count of one; when it is not set, the array subterm stays in place.
        -- NOTE(review): presumably the added node count causes an enclosing array-level node
        -- to bind the computation -- confirm against 'determineScopesSharingAcc'.
        maybeFloatOutAcc :: (SharingAcc a -> PreExp SharingAcc SharingExp t)
                         -> SharingAcc a
                         -> NodeCounts
                         -> (SharingExp t, NodeCounts)
        maybeFloatOutAcc c acc@(AvarSharing _) accCount
          = reconstruct (c acc) accCount
        maybeFloatOutAcc c acc accCount
          | floatOutAcc config = reconstruct (c var) ((stableAcc `accNodeCount` 1) +++ accCount)
          | otherwise = reconstruct (c acc) accCount
          where
            (var, stableAcc) = abstract acc id
        -- Turn an array computation into a variable plus the stable node that the variable
        -- stands for. Leading let bindings are skipped while looking for the underlying node
        -- and are re-applied, in their original order, around the stable node.
        abstract :: SharingAcc a -> (SharingAcc a -> SharingAcc a)
                 -> (SharingAcc a, StableSharingAcc)
        abstract (AvarSharing _) _ = INTERNAL_ERROR(error) "sharingAccToVar" "AvarSharing"
        abstract (AletSharing sa acc) lets = abstract acc (lets . AletSharing sa)
        abstract acc@(AccSharing sn _) lets = (AvarSharing sn, StableSharingAcc sn (lets acc))
        -- Occurrence count recorded in the occurrence map for the node being processed.
        expOccCount = let StableNameHeight sn' _ = sn
                      in
                      lookupWithASTName expOccMap (StableASTName sn')
        -- Rebuild the current node from its processed subterms and their merged node counts.
        --
        --  * A free variable ('Tag') becomes a variable whose only node count is itself; the
        --    node counts of the subterms are ignored.
        --
        --  * A node occurring more than once (with scalar sharing recovery enabled) becomes a
        --    variable; the node itself, wrapped in the bindings completed here, is added to
        --    the node counts with a count of one.
        --
        --  * Any other node is returned in place, wrapped in the bindings completed here.
        --
        reconstruct :: PreExp SharingAcc SharingExp t -> NodeCounts
                    -> (SharingExp t, NodeCounts)
        reconstruct newExp@(Tag _) _subCount
          = let thisCount = StableSharingExp sn (ExpSharing sn newExp) `expNodeCount` 1
            in
            tracePure "FREE" (show thisCount)
            (VarSharing sn, thisCount)
        reconstruct newExp subCount
          | expOccCount > 1 && recoverExpSharing config
          = let allCount = (StableSharingExp sn sharingExp `expNodeCount` 1) +++ newCount
            in
            tracePure ("SHARED" ++ completed) (show allCount)
            (VarSharing sn, allCount)
          | otherwise
          = tracePure ("Normal" ++ completed) (show newCount)
            (sharingExp, newCount)
          where
            (newCount, bindHere) = filterCompleted subCount
            -- Composes the bindings so that the first element of 'bindHere' ends up outermost.
            lets = foldl (flip (.)) id . map LetSharing $ bindHere
            sharingExp = lets $ ExpSharing sn newExp
            -- Suffix for the trace message; empty when nothing is bound at this node.
            completed | null bindHere = ""
                      | otherwise = "(" ++ show (length bindHere) ++ " lets)"
        -- Split off the leading run of completed node counts. The first component holds the
        -- remaining node counts, the second the nodes that are to be bound at the current node.
        filterCompleted :: NodeCounts -> (NodeCounts, [StableSharingExp])
        filterCompleted counts
          = let (completed, counts') = break notComplete counts
            in (counts', [sa | ExpNodeCount sa _ <- completed])
          where
            -- A scalar node that is not a free variable is incomplete while the occurrence map
            -- reports more occurrences than have been counted so far. Every other entry (free
            -- variables, array nodes) is treated as incomplete and ends the run.
            notComplete nc@(ExpNodeCount sa n) | not . isFreeVar $ nc = lookupWithSharingExp expOccMap sa > n
            notComplete _ = True
-- | Recover the sharing of an array computation.
--
-- First the occurrence map of the computation is built with 'makeOccMapAcc' (an 'IO'
-- action that is run through 'unsafePerformIO'), then 'determineScopesAcc' places the
-- bindings. The result is the computation with explicit sharing together with the
-- initial environment for the free array variables given by the tag list.
--
recoverSharingAcc
  :: Typeable a
  => Config
  -> Level
  -> [Level]
  -> Acc a
  -> (SharingAcc a, [StableSharingAcc])
recoverSharingAcc config alvl avars acc
  = determineScopesAcc config avars occMap annotatedAcc
  where
    (annotatedAcc, occMap) = unsafePerformIO (makeOccMapAcc config alvl acc)
-- | Recover the sharing of a scalar expression.
--
-- A fresh array occurrence map is created, filled while the occurrence map of the
-- expression is built by 'makeOccMapRootExp', and then frozen; all of this happens in
-- 'IO' and is run through 'unsafePerformIO'. Afterwards 'determineScopesExp' places the
-- bindings. The result is the expression with explicit sharing together with the
-- initial environment for its free scalar variables; the array node counts returned by
-- 'determineScopesExp' are discarded.
--
recoverSharingExp
  :: Typeable e
  => Config
  -> Level
  -> [Level]
  -> Exp e
  -> (SharingExp e, [StableSharingExp])
recoverSharingExp config lvl fvar exp = (scopedExp, initialEnv)
  where
    (rootExp, frozenOccMap)
      = unsafePerformIO $ do
          table     <- newASTHashTable
          (root, _) <- makeOccMapRootExp config table lvl fvar exp
          frozen    <- freezeOccMap table
          return (root, frozen)

    (EnvExp initialEnv scopedExp, _) = determineScopesExp config frozenOccMap rootExp
-- | Emit a one-line message of the form @header: msg@ through 'Debug.traceMessage',
-- under the 'Debug.dump_sharing' flag.
--
traceLine :: String -> String -> IO ()
traceLine header msg
  = Debug.traceMessage Debug.dump_sharing (concat [header, ": ", msg])
-- | Emit a message with the header on a line of its own, followed by the message text
-- on the next line, through 'Debug.traceMessage' under the 'Debug.dump_sharing' flag.
--
traceChunk :: String -> String -> IO ()
traceChunk header msg
  = Debug.traceMessage Debug.dump_sharing (concat [header, "\n ", msg])
-- | Pure counterpart of 'traceLine': attach a message of the form @header: msg@ to a
-- value through 'Debug.tracePure', under the 'Debug.dump_sharing' flag.
--
tracePure :: String -> String -> a -> a
tracePure header msg
  = Debug.tracePure Debug.dump_sharing (concat [header, ": ", msg])
-- | Short description of the outermost operation of a sharing-annotated array
-- computation: variables are shown with the hash of their stable name, let bindings are
-- shown as a prefix in front of the description of their body.
--
_showSharingAccOp :: SharingAcc arrs -> String
_showSharingAccOp sharingAcc
  = case sharingAcc of
      AvarSharing sn      -> "AVAR " ++ show (hashStableNameHeight sn)
      AletSharing _ body  -> "ALET " ++ _showSharingAccOp body
      AccSharing _ pacc   -> showPreAccOp pacc