module Data.Array.Accelerate.Trafo (

  -- * Conversion pipeline

  -- Pipeline configuration ('Phase', with its default value 'phases') and the
  -- entry points defined in this module for array computations and for unary
  -- array functions.
  Phase(..), phases,
  convertAcc, convertAccWith,
  convertAccFun1, convertAccFun1With,

  -- * Re-exports

  -- The fusion module is imported unqualified with 'convertAcc' and
  -- 'convertAfun' hidden, so those two names are not part of this re-export;
  -- the 'convertAcc' exported above is the wrapper defined in this module.
  module Data.Array.Accelerate.Trafo.Fusion,
  rebuildAcc,
  module Data.Array.Accelerate.Trafo.Substitution,
) where
import System.IO.Unsafe
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Debug
import Data.Array.Accelerate.Pretty ( )
import Data.Array.Accelerate.Array.Sugar ( Arrays, Elt )
import Data.Array.Accelerate.Trafo.Base
import Data.Array.Accelerate.Trafo.Fusion hiding ( convertAcc, convertAfun )
import Data.Array.Accelerate.Trafo.Substitution
import qualified Data.Array.Accelerate.AST as AST
import qualified Data.Array.Accelerate.Trafo.Fusion as Fusion
import qualified Data.Array.Accelerate.Trafo.Rewrite as Rewrite
import qualified Data.Array.Accelerate.Trafo.Simplify as Rewrite
import qualified Data.Array.Accelerate.Trafo.Sharing as Sharing
-- | Switches controlling which stages of the conversion pipeline in this
-- module are applied. See 'convertAccWith' and 'convertAccFun1With' for where
-- each field is read.
data Phase = Phase
  {
    -- | Passed as the first flag to @Sharing.convertAcc@ and
    -- @Sharing.convertAccFun1@ (presumably: recover sharing of array
    -- computations).
    recoverAccSharing :: Bool

    -- | Passed as the second flag to @Sharing.convertAcc@ and
    -- @Sharing.convertAccFun1@, and as the only flag to
    -- @Sharing.convertExp@, @Sharing.convertFun1@ and @Sharing.convertFun2@
    -- (presumably: recover sharing of scalar expressions).
  , recoverExpSharing :: Bool

    -- | Passed as the third flag to @Sharing.convertAcc@ and
    -- @Sharing.convertAccFun1@ (presumably: float array computations out of
    -- scalar expressions).
  , floatOutAccFromExp :: Bool

    -- | NOTE(review): this field is not read by any code visible in this
    -- module; the fusion stage is always applied. Confirm whether it is
    -- consumed elsewhere or should gate fusion.
  , enableAccFusion :: Bool

    -- | When 'True', 'convertAccWith' applies @Rewrite.convertSegments@ and
    -- 'convertAccFun1With' applies @Rewrite.convertSegmentsAfun@; when
    -- 'False' that stage is the identity.
  , convertOffsetOfSegment :: Bool
  }
-- | The default pipeline configuration, used by 'convertAcc',
-- 'convertAccFun1' and the scalar conversions in this module. Every flag is
-- switched on except 'convertOffsetOfSegment'.
phases :: Phase
phases = Phase
  { recoverAccSharing = True
  , recoverExpSharing = True
  , floatOutAccFromExp = True
  , enableAccFusion = True
  , convertOffsetOfSegment = False   -- the segment rewrite is opt-in
  }
-- | Convert an array computation using the default configuration 'phases'.
-- Equivalent to @'convertAccWith' 'phases'@.
convertAcc :: Arrays arrs => Acc arrs -> DelayedAcc arrs
convertAcc acc = convertAccWith phases acc
-- | Convert an array computation under an explicit 'Phase' configuration.
--
-- Stages, innermost first:
--
--   1. @Sharing.convertAcc@, driven by the three sharing/float-out flags;
--   2. @Rewrite.convertSegments@, only if 'convertOffsetOfSegment' is set;
--   3. @Fusion.convertAcc@, always.
--
-- NOTE(review): 'enableAccFusion' is not consulted here; the fusion stage is
-- applied unconditionally.
convertAccWith :: Arrays arrs => Phase -> Acc arrs -> DelayedAcc arrs
convertAccWith ok acc = Fusion.convertAcc (segmentPass sharingRecovered)
  where
    -- Output of the sharing stage, configured from the supplied phases.
    sharingRecovered =
      Sharing.convertAcc
        (recoverAccSharing ok)
        (recoverExpSharing ok)
        (floatOutAccFromExp ok)
        acc

    -- The optional segment rewrite; the identity when the flag is off.
    segmentPass
      | convertOffsetOfSegment ok = Rewrite.convertSegments
      | otherwise                 = id
-- | Convert a unary function over array computations using the default
-- configuration 'phases'. Equivalent to @'convertAccFun1With' 'phases'@.
convertAccFun1 :: (Arrays a, Arrays b) => (Acc a -> Acc b) -> DelayedAfun (a -> b)
convertAccFun1 f = convertAccFun1With phases f
-- | Convert a unary function over array computations under an explicit
-- 'Phase' configuration.
--
-- Stages, innermost first:
--
--   1. @Sharing.convertAccFun1@, driven by the three sharing/float-out flags;
--   2. @Rewrite.convertSegmentsAfun@, only if 'convertOffsetOfSegment' is set;
--   3. @Fusion.convertAfun@, always.
--
-- NOTE(review): 'enableAccFusion' is not consulted here; the fusion stage is
-- applied unconditionally.
convertAccFun1With :: (Arrays a, Arrays b) => Phase -> (Acc a -> Acc b) -> DelayedAfun (a -> b)
convertAccFun1With ok f = Fusion.convertAfun (segmentPass sharingRecovered)
  where
    -- Output of the sharing stage, configured from the supplied phases.
    sharingRecovered =
      Sharing.convertAccFun1
        (recoverAccSharing ok)
        (recoverExpSharing ok)
        (floatOutAccFromExp ok)
        f

    -- The optional segment rewrite; the identity when the flag is off.
    segmentPass
      | convertOffsetOfSegment ok = Rewrite.convertSegmentsAfun
      | otherwise                 = id
-- | Convert a closed scalar expression: run @Sharing.convertExp@ with the
-- 'recoverExpSharing' setting of the default 'phases', then apply
-- @Rewrite.simplify@ to the result.
convertExp :: Elt e => Exp e -> AST.Exp () e
convertExp e =
  Rewrite.simplify (Sharing.convertExp (recoverExpSharing phases) e)
-- | Convert a unary scalar function: run @Sharing.convertFun1@ with the
-- 'recoverExpSharing' setting of the default 'phases', then apply
-- @Rewrite.simplify@ to the result.
convertFun1 :: (Elt a, Elt b) => (Exp a -> Exp b) -> AST.Fun () (a -> b)
convertFun1 f =
  Rewrite.simplify (Sharing.convertFun1 (recoverExpSharing phases) f)
-- | Convert a binary scalar function: run @Sharing.convertFun2@ with the
-- 'recoverExpSharing' setting of the default 'phases', then apply
-- @Rewrite.simplify@ to the result.
convertFun2 :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> Exp c) -> AST.Fun () (a -> b -> c)
convertFun2 f =
  Rewrite.simplify (Sharing.convertFun2 (recoverExpSharing phases) f)
-- | An array computation is shown by running it through 'convertAcc' (the
-- default pipeline) and showing the converted term; the resulting string is
-- passed through 'withSimplStats'.
instance Arrays arrs => Show (Acc arrs) where
  show acc = withSimplStats (show (convertAcc acc))
-- | A unary array function is shown by running it through 'convertAccFun1'
-- (the default pipeline) and showing the converted term; the resulting string
-- is passed through 'withSimplStats'.
instance (Arrays a, Arrays b) => Show (Acc a -> Acc b) where
  show f = withSimplStats (show (convertAccFun1 f))
-- | A scalar expression is shown by converting it with 'convertExp' and
-- showing the converted term; the resulting string is passed through
-- 'withSimplStats'.
instance Elt e => Show (Exp e) where
  show e = withSimplStats (show (convertExp e))
-- | A unary scalar function is shown by converting it with 'convertFun1' and
-- showing the converted term; the resulting string is passed through
-- 'withSimplStats'.
instance (Elt a, Elt b) => Show (Exp a -> Exp b) where
  show f = withSimplStats (show (convertFun1 f))
-- | A binary scalar function is shown by converting it with 'convertFun2' and
-- showing the converted term; the resulting string is passed through
-- 'withSimplStats'.
instance (Elt a, Elt b, Elt c) => Show (Exp a -> Exp b -> Exp c) where
  show f = withSimplStats (show (convertFun2 f))
-- | Return the given string unchanged, optionally reporting simplifier
-- counters as a side effect.
--
-- Without @ACCELERATE_DEBUG@ this is the identity. With it, and when the
-- @dump_simpl_stats@ flag is enabled, the counters are reset, the string is
-- forced, and the counters read afterwards are emitted via 'traceMessage'.
withSimplStats :: String -> String
#ifdef ACCELERATE_DEBUG
-- 'unsafePerformIO' is used only to query the flag and emit the diagnostic;
-- the value returned is always the argument itself.
withSimplStats str = unsafePerformIO $ do
  dumping <- queryFlag dump_simpl_stats
  if dumping
    then do
      resetSimplCount
      -- Walk the spine of the string before reading the counters, so that
      -- whatever work produces it has happened by the time they are sampled.
      stats <- seq (length str) simplCount
      traceMessage dump_simpl_stats (show stats)
      return str
    else return str
#else
withSimplStats str = str
#endif