{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
#if __GLASGOW_HASKELL__ <= 708
{-# LANGUAGE IncoherentInstances #-}
{-# LANGUAGE OverlappingInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-unrecognised-pragmas #-}
#endif
module Data.Array.Accelerate.Trafo.Base (
Kit(..), Match(..), (:~:)(..),
avarIn, kmap, fromOpenAfun,
DelayedAcc, DelayedOpenAcc(..),
DelayedAfun, DelayedOpenAfun,
DelayedExp, DelayedOpenExp,
DelayedFun, DelayedOpenFun,
matchDelayedOpenAcc, encodeDelayedOpenAcc, hashDelayedOpenAcc,
Gamma(..), incExp, prjExp, pushExp,
Extend(..), append, bind,
Sink(..), sink, sink1,
Supplement(..), bindExps,
) where
import Control.Applicative
import Control.DeepSeq
import Crypto.Hash
import Data.ByteString.Builder
import Data.ByteString.Builder.Extra
import Data.Monoid
import Data.Type.Equality
import Text.PrettyPrint.ANSI.Leijen hiding ( (<$>), (<>) )
import Prelude hiding ( until )
import Data.Array.Accelerate.AST hiding ( Val(..) )
import Data.Array.Accelerate.Analysis.Hash
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Sugar ( Array, Arrays, Shape, Elt )
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Pretty.Print
import Data.Array.Accelerate.Trafo.Substitution
import Data.Array.Accelerate.Debug.Stats as Stats
-- | Operations the transformation passes need from an array-term
-- representation @acc@. This module provides instances for 'OpenAcc' and
-- 'DelayedOpenAcc'.
--
-- The representation must also be rebuildable ('RebuildableAcc') and support
-- weakening of its array environment ('Sink').
class (RebuildableAcc acc, Sink acc) => Kit acc where
  -- | Wrap a single array operation as a term of this representation.
  inject        :: PreOpenAcc acc aenv a -> acc aenv a
  -- | Expose the outermost array operation of a term. This may be partial
  -- for representations with constructors that do not hold a 'PreOpenAcc'.
  extract       :: acc aenv a -> PreOpenAcc acc aenv a
  -- | Convert a plain 'OpenAcc' term into this representation (may be
  -- unsupported by some instances).
  fromOpenAcc   :: OpenAcc aenv a -> acc aenv a
  -- | Compare two terms; on success returns a type-equality witness.
  matchAcc      :: MatchAcc acc
  -- | Encode a term as a 'Builder'. The matchers are also passed this
  -- encoder.
  encodeAcc     :: EncodeAcc acc
  -- | Pretty printer for terms of this representation.
  prettyAcc     :: PrettyAcc acc
-- | The plain representation: every node is one 'PreOpenAcc' wrapped in the
-- 'OpenAcc' constructor, so injection and extraction are just (un)wrapping,
-- and converting from 'OpenAcc' is the identity.
instance Kit OpenAcc where
  inject              = OpenAcc
  extract (OpenAcc a) = a
  fromOpenAcc a       = a

  -- Encoding and matching recurse through the generic 'PreOpenAcc'
  -- traversals, using this instance for the nested array terms.
  {-# INLINEABLE encodeAcc #-}
  encodeAcc (OpenAcc a) = encodePreOpenAcc encodeAcc a

  {-# INLINEABLE matchAcc #-}
  matchAcc (OpenAcc x) (OpenAcc y) = matchPreOpenAcc matchAcc encodeAcc x y

  {-# INLINEABLE prettyAcc #-}
  prettyAcc = prettyOpenAcc
-- | Build a term that refers to the array bound at the given environment
-- index.
avarIn :: (Kit acc, Arrays arrs) => Idx aenv arrs -> acc aenv arrs
avarIn ix = inject (Avar ix)
-- | Rewrite the outermost array operation of a term: 'extract' it, apply the
-- given transformation, and 'inject' the result back.
kmap :: Kit acc => (PreOpenAcc acc aenv a -> PreOpenAcc acc aenv b) -> acc aenv a -> acc aenv b
kmap f acc = inject (f (extract acc))
-- | Convert an array function over 'OpenAcc' into one over any 'Kit'
-- representation. The lambda structure is kept as-is; only the body is
-- converted, with 'fromOpenAcc'.
fromOpenAfun :: Kit acc => OpenAfun aenv f -> PreOpenAfun acc aenv f
fromOpenAfun afun =
  case afun of
    Abody body -> Abody (fromOpenAcc body)
    Alam  lam  -> Alam  (fromOpenAfun lam)
-- | Equality for indexed terms whose type indices may differ. A successful
-- match returns a proof ('Refl') that the two indices are the same type.
class Match f where
  match :: f s -> f t -> Maybe (s :~: t)
-- | Environment indices are compared with 'matchIdx'.
instance Match (Idx env) where
  {-# INLINEABLE match #-}
  match ix1 ix2 = matchIdx ix1 ix2
-- | Scalar expressions are compared with 'matchPreOpenExp'. Any embedded
-- array terms are matched and encoded with the representation's own
-- 'matchAcc' and 'encodeAcc'.
instance Kit acc => Match (PreOpenExp acc env aenv) where
  {-# INLINEABLE match #-}
  match e1 e2 = matchPreOpenExp matchAcc encodeAcc e1 e2
-- | Scalar functions are compared with 'matchPreOpenFun'. Embedded array
-- terms are handled by the representation's 'matchAcc' and 'encodeAcc'.
instance Kit acc => Match (PreOpenFun acc env aenv) where
  {-# INLINEABLE match #-}
  match f1 f2 = matchPreOpenFun matchAcc encodeAcc f1 f2
-- | Array operations are compared with 'matchPreOpenAcc'. Nested array terms
-- are handled by the representation's 'matchAcc' and 'encodeAcc'.
instance Kit acc => Match (PreOpenAcc acc aenv) where
  {-# INLINEABLE match #-}
  match p1 p2 = matchPreOpenAcc matchAcc encodeAcc p1 p2
-- | Fallback: array terms of any 'Kit' representation are compared with that
-- representation's 'matchAcc'.
--
-- The head @acc aenv@ overlaps with the more specific 'Match' instances for
-- 'Idx', 'PreOpenExp', 'PreOpenFun' and 'PreOpenAcc'. It is marked INCOHERENT
-- so that GHC can still resolve those.
instance {-# INCOHERENT #-} Kit acc => Match (acc aenv) where
  {-# INLINEABLE match #-}
  match a1 a2 = matchAcc a1 a2
-- Open forms over the delayed representation: the leading environment
-- parameter is left open.
type DelayedOpenAfun = PreOpenAfun DelayedOpenAcc
type DelayedOpenExp  = PreOpenExp  DelayedOpenAcc
type DelayedOpenFun  = PreOpenFun  DelayedOpenAcc

-- Closed forms: the leading environment parameter is instantiated to @()@.
type DelayedAcc      = DelayedOpenAcc  ()
type DelayedAfun     = DelayedOpenAfun ()
type DelayedExp      = DelayedOpenExp  ()
type DelayedFun      = DelayedOpenFun  ()
-- | Array terms in which an array is either an ordinary array operation
-- ('Manifest') or is described only implicitly ('Delayed'). A delayed array
-- is given by an expression for its extent plus functions that compute an
-- element on demand.
data DelayedOpenAcc aenv a where
  -- | An ordinary array operation.
  Manifest :: PreOpenAcc DelayedOpenAcc aenv a -> DelayedOpenAcc aenv a

  -- | An array given only by how to compute its shape and its elements.
  Delayed  :: (Shape sh, Elt e) =>
    { extentD       :: PreExp DelayedOpenAcc aenv sh
      -- ^ expression computing the shape (extent) of the array
    , indexD        :: PreFun DelayedOpenAcc aenv (sh -> e)
      -- ^ element at a multidimensional index
    , linearIndexD  :: PreFun DelayedOpenAcc aenv (Int -> e)
      -- ^ element at a linear ('Int') index
    } -> DelayedOpenAcc aenv (Array sh e)
-- | Rebuilding a delayed term under an array-variable substitution @v@.
-- 'Manifest' nodes defer to the 'PreOpenAcc' instance. For 'Delayed' nodes,
-- the extent, index function and linear-index function are rebuilt, in that
-- order, and the node is reassembled in the same 'Applicative'.
instance Rebuildable DelayedOpenAcc where
  type AccClo DelayedOpenAcc = DelayedOpenAcc

  {-# INLINEABLE rebuildPartial #-}
  rebuildPartial v (Manifest pacc) =
    Manifest <$> rebuildPartial v pacc
  rebuildPartial v (Delayed sh ix lx) =
    Delayed <$> rebuildPartial v sh
            <*> rebuildPartial v ix
            <*> rebuildPartial v lx
-- | Weaken a delayed term along @k@. Every array variable is rebuilt as
-- @'Avar' (k ix)@, and the result is tagged with 'Stats.substitution' under
-- the label "weaken" (a debug statistics hook).
instance Sink DelayedOpenAcc where
  weaken k acc = Stats.substitution "weaken" (rebuildA (\ix -> Avar (k ix)) acc)
-- | The delayed representation.
--
--  * 'inject' always produces a 'Manifest' node.
--  * 'extract' is only defined on 'Manifest' nodes. A 'Delayed' node has no
--    underlying 'PreOpenAcc' to expose.
--  * 'fromOpenAcc' is not supported for this representation.
--
-- Both unsupported cases are treated as internal errors and reported through
-- '$internalError' (the same mechanism 'prjExp' uses), so the failure carries
-- a source location rather than a bare 'error' string.
instance Kit DelayedOpenAcc where
  inject                  = Manifest
  extract (Manifest pacc) = pacc
  extract Delayed{}       = $internalError "extract" "expected a manifest array, but found a delayed one"
  fromOpenAcc             = $internalError "fromOpenAcc" "OpenAcc terms can not be converted to DelayedOpenAcc"

  {-# INLINEABLE encodeAcc #-}
  {-# INLINEABLE matchAcc  #-}
  {-# INLINEABLE prettyAcc #-}
  encodeAcc = encodeDelayedOpenAcc
  matchAcc  = matchDelayedOpenAcc
  prettyAcc = prettyDelayedOpenAcc
-- | Full evaluation of a delayed array function. Each array term inside it is
-- forced with 'rnfDelayedOpenAcc'.
instance NFData (DelayedOpenAfun aenv t) where
  rnf afun = rnfPreOpenAfun rnfDelayedOpenAcc afun
-- | Full evaluation of a delayed array term, via 'rnfDelayedOpenAcc'.
instance NFData (DelayedOpenAcc aenv t) where
  rnf acc = rnfDelayedOpenAcc acc
-- | Hash a delayed array term. The term is encoded with
-- 'encodeDelayedOpenAcc', materialised as a lazy 'ByteString', and hashed
-- with 'hashlazy'.
hashDelayedOpenAcc :: DelayedOpenAcc aenv a -> Hash
hashDelayedOpenAcc acc = hashlazy (toLazyByteString (encodeDelayedOpenAcc acc))
-- | Encode a delayed array term as a 'Builder'.
--
-- Each constructor is introduced by a host-endian integer tag
-- (@$(hashQ "Manifest")@ or @$(hashQ "Delayed")@), so the two forms encode
-- differently. The tag is followed by the encoding of the constructor's
-- contents. For 'Delayed' those contents are the extent, the index function,
-- and the linear-index function, in that order.
{-# INLINEABLE encodeDelayedOpenAcc #-}
encodeDelayedOpenAcc :: EncodeAcc DelayedOpenAcc
encodeDelayedOpenAcc acc =
  case acc of
    Manifest pacc ->
      intHost $(hashQ "Manifest") <> encodePreOpenAcc encodeDelayedOpenAcc pacc
    Delayed sh ix lx ->
      intHost $(hashQ "Delayed") <> goExp sh <> goFun ix <> goFun lx
  where
    -- Scalar expressions and functions embed array terms, which are encoded
    -- recursively with this same function.
    {-# INLINE goExp #-}
    goExp :: DelayedExp aenv sh -> Builder
    goExp = encodePreOpenExp encodeDelayedOpenAcc

    {-# INLINE goFun #-}
    goFun :: DelayedFun aenv f -> Builder
    goFun = encodePreOpenFun encodeDelayedOpenAcc
-- | Match two delayed array terms.
--
-- Two 'Manifest' nodes match when their array operations match. Two 'Delayed'
-- nodes match when their extents, index functions and linear-index functions
-- all match; these are checked in that order, stopping at the first failure.
-- A 'Manifest' node never matches a 'Delayed' one.
{-# INLINEABLE matchDelayedOpenAcc #-}
matchDelayedOpenAcc :: MatchAcc DelayedOpenAcc
matchDelayedOpenAcc (Manifest pacc1) (Manifest pacc2) =
  matchPreOpenAcc matchDelayedOpenAcc encodeDelayedOpenAcc pacc1 pacc2
matchDelayedOpenAcc (Delayed sh1 ix1 lx1) (Delayed sh2 ix2 lx2) = do
  -- Each successful match refines the types needed by the next one.
  Refl <- matchPreOpenExp matchDelayedOpenAcc encodeDelayedOpenAcc sh1 sh2
  Refl <- matchPreOpenFun matchDelayedOpenAcc encodeDelayedOpenAcc ix1 ix2
  Refl <- matchPreOpenFun matchDelayedOpenAcc encodeDelayedOpenAcc lx1 lx2
  return Refl
matchDelayedOpenAcc _ _ =
  Nothing
-- | Reduce a delayed array term to normal form. For a 'Delayed' node, the
-- extent, index function and linear-index function are forced in that order.
rnfDelayedOpenAcc :: DelayedOpenAcc aenv t -> ()
rnfDelayedOpenAcc acc =
  case acc of
    Manifest pacc ->
      rnfPreOpenAcc rnfDelayedOpenAcc pacc
    Delayed sh ix lx ->
      rnfPreOpenExp rnfDelayedOpenAcc sh `seq`
      rnfPreOpenFun rnfDelayedOpenAcc ix `seq`
      rnfPreOpenFun rnfDelayedOpenAcc lx
-- | Pretty printer for delayed array terms.
--
-- 'Manifest' nodes are printed with 'prettyPreOpenAcc'.
--
-- A 'Delayed' node is printed as @a@ itself when both of these hold:
--
--  * its extent is @'Shape' a@;
--  * its index function is exactly @Lam (Body (Index a (Var ZeroIdx)))@,
--    i.e. \"index @a@ at the argument\".
--
-- Every other 'Delayed' node is printed as @delayed@, followed by its extent
-- and its index function. The linear-index function is not printed.
prettyDelayedOpenAcc :: PrettyAcc DelayedOpenAcc
prettyDelayedOpenAcc wrap aenv acc = case acc of
  Manifest pacc -> prettyPreOpenAcc prettyDelayedOpenAcc wrap aenv pacc
  Delayed sh f _
    -- The delayed array just indexes an existing array of the same shape,
    -- so show that array rather than the delayed wrapper.
    | Shape a <- sh
    , Just Refl <- match f (Lam (Body (Index a (Var ZeroIdx))))
    -> prettyDelayedOpenAcc wrap aenv a

    | otherwise
    -> wrap $ hang 2 (sep [ green (text "delayed")
                          , parens (align (prettyPreExp prettyDelayedOpenAcc (parens . align) aenv sh))
                          , parens (align (prettyPreFun prettyDelayedOpenAcc aenv f))
                          ])
-- | An environment of scalar expressions. It holds one entry per 'PushExp'
-- for the variables of @env'@, and each entry is an expression in the
-- environment @env@ (see 'prjExp').
--
-- Note that 'EmptyExp' is typed at any @env'@. Looking up an index that has
-- no corresponding 'PushExp' is therefore only caught at runtime, by
-- 'prjExp'.
data Gamma acc env env' aenv where
  EmptyExp :: Gamma acc env env' aenv

  PushExp  :: Elt t
           => Gamma acc env env' aenv
           -> WeakPreOpenExp acc env aenv t
           -> Gamma acc env (env', t) aenv
-- | An expression paired with a weakened copy of it. 'Subst' stores three
-- things:
--
--  * the weakening from the expression's original environment into @env'@;
--  * the original expression;
--  * the expression already weakened into @env'@.
--
-- Keeping the original lets the weakening be re-applied from scratch after
-- it has been extended (see 'incExp').
data WeakPreOpenExp acc env aenv t where
  Subst    :: env :> env'
           -> PreOpenExp acc env  aenv t
           -> PreOpenExp acc env' aenv t
           -> WeakPreOpenExp acc env' aenv t
-- | Extend the environment every 'Gamma' entry lives in by one scalar binding
-- @s@.
--
-- Each entry's stored weakening is composed with 'SuccIdx'. Its weakened
-- copy is then recomputed from the original expression with 'weakenE'.
incExp
    :: Kit acc
    => Gamma acc env     env' aenv
    -> Gamma acc (env,s) env' aenv
incExp gamma =
  case gamma of
    EmptyExp       -> EmptyExp
    PushExp rest w -> PushExp (incExp rest) (shiftW w)
  where
    shiftW :: forall acc env aenv s t. Kit acc => WeakPreOpenExp acc env aenv t -> WeakPreOpenExp acc (env,s) aenv t
    shiftW (Subst k (e :: PreOpenExp acc env0 aenv t) _) = Subst k1 e (weakenE k1 e)
      where
        -- Original weakening, then step past the new binding.
        k1 :: env0 :> (env,s)
        k1 = SuccIdx . k
-- | Look up the expression for a variable in a 'Gamma' environment. The
-- weakened copy is returned. An index with no matching 'PushExp' entry is
-- reported as an internal error.
prjExp :: Idx env' t -> Gamma acc env env' aenv -> PreOpenExp acc env aenv t
prjExp ix gamma =
  case (ix, gamma) of
    (ZeroIdx,     PushExp _    (Subst _ _ e)) -> e
    (SuccIdx ix', PushExp rest _            ) -> prjExp ix' rest
    _                                          -> $internalError "prjExp" "inconsistent valuation"
-- | Add an expression to a 'Gamma' environment. Since it is already in the
-- target environment, the weakening is 'id' and the original and weakened
-- copies are the same expression.
pushExp :: Elt t => Gamma acc env env' aenv -> PreOpenExp acc env aenv t -> Gamma acc env (env',t) aenv
pushExp gamma e = PushExp gamma (Subst id e e)
-- | A sequence of array bindings that extends the array environment @aenv@ to
-- @aenv'@.
--
-- Each 'PushEnv' adds one bound array term, typed in the environment built
-- so far. 'bind' wraps a body in the corresponding nested 'Alet's, and
-- 'sink' weakens terms across the sequence.
data Extend acc aenv aenv' where
  BaseEnv :: Extend acc aenv aenv

  PushEnv :: Arrays a
          => Extend acc aenv aenv' -> acc aenv' a -> Extend acc aenv (aenv', a)
-- | Concatenate two binding sequences: the bindings of the second are pushed
-- after those of the first.
append :: Extend acc env env' -> Extend acc env' env'' -> Extend acc env env''
append prefix suffix =
  case suffix of
    BaseEnv          -> prefix
    PushEnv rest acc -> PushEnv (append prefix rest) acc
-- | Wrap a body in the array let-bindings of an 'Extend'. The most recently
-- pushed binding becomes the innermost 'Alet'.
bind :: (Kit acc, Arrays a)
     => Extend acc aenv aenv'
     -> PreOpenAcc acc aenv' a
     -> PreOpenAcc acc aenv  a
bind BaseEnv           body = body
bind (PushEnv env acc) body = bind env (Alet acc (inject body))
-- | Weaken a term across the bindings of an 'Extend'. Each 'PushEnv' shifts
-- every index by one ('SuccIdx'). The base case is tagged with
-- 'Stats.substitution' under the label "sink".
sink :: Sink f => Extend acc env env' -> f env t -> f env' t
sink ext = weaken (liftIdx ext)
  where
    liftIdx :: Extend acc env env' -> Idx env t -> Idx env' t
    liftIdx e =
      case e of
        BaseEnv         -> Stats.substitution "sink" id
        PushEnv inner _ -> SuccIdx . liftIdx inner
-- | Like 'sink', but for a term under one extra binding @s@ at the top of the
-- environment. The innermost variable ('ZeroIdx') stays in place, and every
-- other index is shifted past the bindings of the 'Extend'. The base case is
-- tagged with 'Stats.substitution' under the label "sink1".
sink1 :: Sink f => Extend acc env env' -> f (env,s) t -> f (env',s) t
sink1 ext = weaken (liftIdx ext)
  where
    liftIdx :: Extend acc env env' -> Idx (env,s) t -> Idx (env',s) t
    liftIdx e =
      case e of
        BaseEnv         -> Stats.substitution "sink1" id
        PushEnv inner _ -> skipTop . liftIdx inner

    -- Keep the top variable fixed; move everything else past one new binding.
    skipTop :: Idx (env,s) t -> Idx ((env,u),s) t
    skipTop ix =
      case ix of
        ZeroIdx    -> ZeroIdx
        SuccIdx ix' -> SuccIdx (SuccIdx ix')
-- | A sequence of scalar bindings that extends the scalar environment @env@
-- to @env'@.
--
-- Each 'PushSup' holds one bound expression, typed in the environment built
-- so far. 'bindExps' wraps a body in the corresponding nested 'Let's.
data Supplement acc env env' aenv where
  BaseSup :: Supplement acc env env aenv

  PushSup :: Elt e
          => Supplement acc env env' aenv
          -> PreOpenExp acc env' aenv e
          -> Supplement acc env (env', e) aenv
-- | Wrap a scalar body in the let-bindings of a 'Supplement'. The most
-- recently pushed binding becomes the innermost 'Let'.
bindExps :: (Kit acc, Elt e)
         => Supplement acc env env' aenv
         -> PreOpenExp acc env' aenv e
         -> PreOpenExp acc env  aenv e
bindExps BaseSup         body = body
bindExps (PushSup sup b) body = bindExps sup (Let b body)