{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Utility declarations for performing range analysis.
module Futhark.Representation.AST.Attributes.Ranges
       ( Bound
       , KnownBound (..)
       , boundToScalExp
       , minimumBound
       , maximumBound
       , Range
       , unknownRange
       , ScalExpRange
       , Ranged
       , RangeOf (..)
       , RangesOf (..)
       , expRanges
       , RangedOp (..)
       , CanBeRanged (..)
       )
       where

import Data.Monoid ((<>))
import qualified Data.Set as S
import qualified Data.Map.Strict as M

import Futhark.Representation.AST.Attributes
import Futhark.Representation.AST.Syntax
import qualified Futhark.Analysis.ScalExp as SE
import qualified Futhark.Analysis.AlgSimplify as AS
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import qualified Futhark.Util.Pretty as PP

-- | A known bound on a value.
data KnownBound = VarBound VName
                  -- ^ Has the same bounds as this variable.  VERY
                  -- IMPORTANT: this variable may be an array, so it
                  -- cannot be immediately translated to a 'ScalExp'.
                | MinimumBound KnownBound KnownBound
                  -- ^ Bounded by the minimum of these two bounds.
                | MaximumBound KnownBound KnownBound
                  -- ^ Bounded by the maximum of these two bounds.
                | ScalarBound SE.ScalExp
                  -- ^ Bounded by this scalar expression.
                deriving (Eq, Ord, Show)

-- Substitution is purely structural: it is pushed down to the
-- variable name or scalar expression found at each leaf.
instance Substitute KnownBound where
  substituteNames substs bound =
    case bound of
      VarBound name ->
        VarBound $ substituteNames substs name
      MinimumBound b1 b2 ->
        MinimumBound (substituteNames substs b1) (substituteNames substs b2)
      MaximumBound b1 b2 ->
        MaximumBound (substituteNames substs b1) (substituteNames substs b2)
      ScalarBound se ->
        ScalarBound $ substituteNames substs se

instance Rename KnownBound where
  rename = substituteRename

-- The free names of a bound are the free names of all of its leaves.
instance FreeIn KnownBound where
  freeIn bound =
    case bound of
      VarBound v -> freeIn v
      MinimumBound b1 b2 -> freeIn b1 <> freeIn b2
      MaximumBound b1 b2 -> freeIn b1 <> freeIn b2
      ScalarBound e -> freeIn e

instance FreeAttr KnownBound where
  precomputed _ = id

-- Minimum and maximum bounds are printed in function-call notation,
-- e.g. @min(b1, b2)@.
instance PP.Pretty KnownBound where
  ppr bound =
    case bound of
      VarBound v -> PP.text "variable " <> PP.ppr v
      MinimumBound b1 b2 -> call "min" b1 b2
      MaximumBound b1 b2 -> call "max" b1 b2
      ScalarBound e -> PP.ppr e
    where call fname b1 b2 =
            PP.text fname <> PP.parens (PP.ppr b1 <> PP.comma PP.<+> PP.ppr b2)

-- | Convert the bound to a scalar expression if possible.  This is
-- possible for all bounds that do not contain 'VarBound's.  A
-- 'MinimumBound' becomes an 'SE.MaxMin' with the flag set to 'True',
-- a 'MaximumBound' one with the flag set to 'False'.
boundToScalExp :: KnownBound -> Maybe SE.ScalExp
boundToScalExp bound =
  case bound of
    VarBound _ -> Nothing
    ScalarBound se -> Just se
    MinimumBound b1 b2 -> combine True b1 b2
    MaximumBound b1 b2 -> combine False b1 b2
  where combine is_min lhs rhs = do
          -- Fails as soon as either side cannot be converted.
          lhs' <- boundToScalExp lhs
          rhs' <- boundToScalExp rhs
          return $ SE.MaxMin is_min [lhs', rhs']

-- | A possibly undefined bound on a value.
type Bound = Maybe KnownBound

-- | Construct a 'MinimumBound' from two possibly known bounds.  The
-- result is known only if /both/ arguments are known.  That can look
-- surprising, but consider merging the lower bounds of two different
-- flows of execution (such as the branches of an @if@ expression):
-- if we know nothing about one of the branches, then that branch may
-- produce any value, so nothing useful can be said about the
-- combined lower bound either.
minimumBound :: Bound -> Bound -> Bound
minimumBound b1 b2 =
  case (b1, b2) of
    (Just x, Just y) -> Just $ MinimumBound x y
    _ -> Nothing

-- | Like 'minimumBound', but constructs a 'MaximumBound'.
maximumBound :: Bound -> Bound -> Bound
maximumBound b1 b2 =
  case (b1, b2) of
    (Just x, Just y) -> Just $ MaximumBound x y
    _ -> Nothing

-- | Lower and upper bound (in that order), both inclusive.
type Range = (Bound, Bound)

-- | A range in which both upper and lower bounds are 'Nothing'.
unknownRange :: Range
unknownRange = (Nothing, Nothing)

-- | The range as a pair of scalar expressions.
type ScalExpRange = (Maybe SE.ScalExp, Maybe SE.ScalExp)

-- | The lore has embedded range information.  Note that it may not be
-- up to date, unless whatever maintains the syntax tree is careful.
type Ranged lore = (Attributes lore,
                    RangedOp (Op lore),
                    RangeOf (LetAttr lore),
                    RangesOf (BodyAttr lore))

-- | Something that contains range information.
class RangeOf a where
  -- | The range of the argument element.
  rangeOf :: a -> Range

instance RangeOf Range where
  rangeOf range = range

instance RangeOf attr => RangeOf (PatElemT attr) where
  rangeOf pat_elem = rangeOf (patElemAttr pat_elem)

-- Both bounds of a subexpression are always known; see
-- 'subExpKnownRange'.
instance RangeOf SubExp where
  rangeOf se = (Just $ fst known, Just $ snd known)
    where known = subExpKnownRange se

-- | Something that contains range information for several things,
-- most notably 'Body' or 'Pattern'.
class RangesOf a where
  -- | The ranges of the argument.
  rangesOf :: a -> [Range]

instance RangeOf a => RangesOf [a] where
  rangesOf xs = [ rangeOf x | x <- xs ]

instance RangeOf attr => RangesOf (PatternT attr) where
  rangesOf pat = map rangeOf (patternElements pat)

instance Ranged lore => RangesOf (Body lore) where
  rangesOf body = rangesOf (bodyAttr body)

-- | The (lower, upper) bounds of a subexpression.  Both components
-- are the same bound: a 'VarBound' for a variable, and a
-- 'ScalarBound' holding the value itself for a constant.
subExpKnownRange :: SubExp -> (KnownBound, KnownBound)
subExpKnownRange se =
  case se of
    Var v ->
      let bound = VarBound v
      in (bound, bound)
    Constant val ->
      let bound = ScalarBound $ SE.Val val
      in (bound, bound)

-- | The range of a scalar expression: the expression itself serves
-- as both lower and upper bound.
scalExpRange :: SE.ScalExp -> Range
scalExpRange se = (bound, bound)
  where bound = Just $ ScalarBound se

-- | The ranges of the result of a basic operation.  Operations that
-- are not listed here get a single 'unknownRange'.
primOpRanges :: BasicOp lore -> [Range]
primOpRanges op =
  case op of
    SubExp se ->
      [rangeOf se]

    -- Integer arithmetic is represented symbolically.
    BinOp (Add t) x y ->
      [arith SE.SPlus t x y]
    BinOp (Sub t) x y ->
      [arith SE.SMinus t x y]
    BinOp (Mul t) x y ->
      [arith SE.STimes t x y]
    BinOp (SDiv t) x y ->
      [arith SE.SDiv t x y]

    -- The operand's range is reused only when the guard holds
    -- (presumably: the conversion widens the type); otherwise we
    -- fall through to the catch-all case.
    ConvOp (SExt from to) x
      | from < to ->
          [rangeOf x]

    -- Lower bound is @x@, upper bound is @x + (n - 1) * s@.
    Iota n x s Int32 ->
      let count = int32Exp n
          start = int32Exp x
          stride = int32Exp s
      in [(Just $ ScalarBound start,
           Just $ ScalarBound $ start + (count - 1) * stride)]

    -- These operations inherit the range of their operand.
    Replicate _ v ->
      [rangeOf v]
    Rearrange _ v ->
      [rangeOf $ Var v]
    Copy se ->
      [rangeOf $ Var se]
    Index v _ ->
      [rangeOf $ Var v]

    -- A non-empty array literal is bounded by the minimum and
    -- maximum over the bounds of all of its elements.
    ArrayLit (e:es) _ ->
      let (first_lower, first_upper) = subExpKnownRange e
          (rest_lowers, rest_uppers) = unzip $ map subExpKnownRange es
          lower = foldl MinimumBound first_lower rest_lowers
          upper = foldl MaximumBound first_upper rest_uppers
      in [(Just lower, Just upper)]

    _ ->
      [unknownRange]
  where -- Range of a binary integer operation on two subexpressions
        -- of integer type @t@.
        arith combine t x y =
          scalExpRange $ combine (SE.subExpToScalExp x $ IntType t)
                                 (SE.subExpToScalExp y $ IntType t)

        -- A subexpression viewed as an 'Int32' scalar expression.
        int32Exp se =
          case se of
            Var v -> SE.Id v $ IntType Int32
            Constant val -> SE.Val val

-- | Ranges of the value parts of the expression.
--
-- 'BasicOp's are handled per operation, the branches of an @if@ are
-- merged pointwise, @for@-loops over 'Int32' are extrapolated from
-- the ranges of the loop body, and an 'Op' is handled by 'opRanges'.
-- Any other expression gets as many 'unknownRange's as
-- 'expExtTypeSize' reports.
expRanges :: Ranged lore =>
             Exp lore -> [Range]

expRanges (BasicOp op) =
  primOpRanges op

expRanges (If _ tbranch fbranch _) =
  zipWith mergeBranches (rangesOf tbranch) (rangesOf fbranch)
  where -- The lower bound is the minimum of the branches' lower
        -- bounds, the upper bound the maximum of their upper bounds.
        mergeBranches t f =
          (minimumBound (fst t) (fst f),
           maximumBound (snd t) (snd f))

expRanges (DoLoop ctxmerge valmerge (ForLoop i Int32 iterations _) body) =
  zipWith returnedRange valmerge $ rangesOf body
  where int32 = IntType Int32

        -- Every name that is bound inside the loop: the loop index,
        -- all merge parameters, and whatever the statements of the
        -- loop body bind.
        merge_names = map (paramName . fst) (ctxmerge++valmerge)
        stm_names = concatMap (patternNames . stmPattern) (bodyStms body)
        bound_in_loop = S.fromList $ i : merge_names ++ stm_names

        -- The number of iterations, but never less than zero.
        trip_count = SE.MaxMin False [SE.subExpToScalExp iterations int32, 0]

        returnedRange merge (lower, upper) =
          (returnedBound merge lower,
           returnedBound merge upper)

        -- A bound on the loop result is produced only if the merge
        -- parameter is an 'Int32' scalar, the bound from the loop
        -- body can be turned into a scalar expression, and the
        -- simplified difference between the parameter and that bound
        -- mentions no name bound inside the loop.  The result is the
        -- initial value plus that difference multiplied by the trip
        -- count.
        --
        -- NOTE(review): the difference is taken as parameter minus
        -- bound, not bound minus parameter -- confirm that this is
        -- the intended sign.
        returnedBound (param, mergeinit) (Just bound)
          | paramType param == Prim int32,
            Just bound' <- boundToScalExp bound,
            let param_exp = SE.Id (paramName param) int32
                se_diff = AS.simplify (SE.SMinus param_exp bound') M.empty,
            S.null $ S.intersection bound_in_loop $ freeIn se_diff =
              Just $ ScalarBound $
              SE.SPlus (SE.subExpToScalExp mergeinit int32) $
              SE.STimes se_diff trip_count
        returnedBound _ _ =
          Nothing

expRanges (Op ranges) =
  opRanges ranges

expRanges e =
  replicate (expExtTypeSize e) unknownRange

-- | An operation from which ranges can be extracted.
class IsOp op => RangedOp op where
  -- | The ranges produced by the operation.
  opRanges :: op -> [Range]

instance RangedOp () where
  opRanges () = []

-- | An operation type with a counterpart, 'OpWithRanges', that is a
-- 'RangedOp', together with conversions in both directions.
class RangedOp (OpWithRanges op) => CanBeRanged op where
  type OpWithRanges op :: *
  removeOpRanges :: OpWithRanges op -> op
  addOpRanges :: op -> OpWithRanges op

instance CanBeRanged () where
  type OpWithRanges () = ()
  removeOpRanges unit = unit
  addOpRanges unit = unit