{-# LANGUAGE FlexibleContexts #-}
-- | Perform range analysis of a program or other fragment.
module Futhark.Analysis.Range
       ( rangeAnalysis
       , runRangeM
       , RangeM
       , analyseLambda
       , analyseStms
       )
       where

import Control.Monad.Reader
import Data.List (foldl', nub)
import qualified Data.Map.Strict as M

import Futhark.Analysis.AlgSimplify as AS
import qualified Futhark.Analysis.ScalExp as SE
import Futhark.IR.Ranges

-- Entry point

-- | Perform variable range analysis on the given program, returning a
-- program with embedded range annotations.
--
-- Top-level constants are analysed starting from an empty range
-- environment, and each function is analysed independently by
-- 'analyseFun'.
rangeAnalysis :: (ASTLore lore, CanBeRanged (Op lore)) =>
                 Prog lore -> Prog (Ranges lore)
rangeAnalysis (Prog consts funs) =
  -- NOTE(review): the constants are annotated one at a time with
  -- 'analyseStm' rather than with 'analyseStms'.  As a result, the
  -- ranges of earlier constants are not in scope when later ones are
  -- analysed, and none of them are visible inside the functions.
  -- Confirm that this is intended.
  Prog (runRangeM $ mapM analyseStm consts) (map analyseFun funs)

-- Implementation

-- | Annotate a function definition with ranges.
--
-- The parameters are brought into scope with 'bindFunParams' before the
-- body is analysed.  That gives them 'unknownRange' and marks the
-- variables used as their array dimensions as non-negative.  The
-- parameters and return types themselves are carried over unchanged.
analyseFun :: (ASTLore lore, CanBeRanged (Op lore)) =>
              FunDef lore -> FunDef (Ranges lore)
analyseFun (FunDef entry fname restype params body) =
  runRangeM $ bindFunParams params $
  FunDef entry fname restype params <$> analyseBody body

-- | Annotate a body with ranges.
--
-- The statements are analysed in order by 'analyseStms', so each one
-- sees the ranges bound by those before it.  The body is then rebuilt
-- with 'mkRangedBody' from the original body decoration, the
-- annotated statements and the unchanged result.
analyseBody :: (ASTLore lore, CanBeRanged (Op lore)) =>
               Body lore
            -> RangeM (Body (Ranges lore))
analyseBody (Body lore origbnds result) =
  analyseStms origbnds $ \bnds' ->
    return $ mkRangedBody lore bnds' result

-- | Perform range analysis on some statements, taking a continuation
-- where the ranges of the variables bound by the statements are
-- in scope.
--
-- The continuation receives the annotated statements in their original
-- order.
analyseStms :: (ASTLore lore, CanBeRanged (Op lore)) =>
               Stms lore
            -> (Stms (Ranges lore) -> RangeM a)
            -> RangeM a
analyseStms = analyseStms' mempty . stmsToList
  where -- Walk the statements left to right.  The annotated statements
        -- are accumulated in 'acc'.  Each statement's pattern is bound
        -- before the remaining statements (and finally the
        -- continuation 'm') are analysed.
        analyseStms' acc [] m =
          m acc
        analyseStms' acc (bnd:bnds) m = do
          bnd' <- analyseStm bnd
          bindPattern (stmPattern bnd') $
            analyseStms' (acc <> oneStm bnd') bnds m

-- | Annotate a single statement with ranges.
--
-- The expression is analysed first.  The ranges of the pattern are then
-- derived from the annotated expression with 'addRangesToPattern' and
-- simplified against the current environment.  The statement's
-- auxiliary decoration is kept as-is.
analyseStm :: (ASTLore lore, CanBeRanged (Op lore)) =>
              Stm lore -> RangeM (Stm (Ranges lore))
analyseStm (Let pat lore e) = do
  e' <- analyseExp e
  pat' <- simplifyPatRanges $ addRangesToPattern pat e'
  return $ Let pat' lore e'

-- | Annotate an expression with ranges.
--
-- Nested bodies are analysed recursively with 'analyseBody', and
-- operations are wrapped with 'addOpRanges'.  Every other component is
-- passed through unchanged.
analyseExp :: (ASTLore lore, CanBeRanged (Op lore)) =>
              Exp lore
           -> RangeM (Exp (Ranges lore))
analyseExp = mapExpM analyse
  where analyse =
          Mapper { mapOnSubExp = return
                 , mapOnVName = return
                   -- The 'Scope' supplied by 'mapExpM' is discarded.
                   -- Names bound by the enclosing expression are
                   -- therefore not added to the range environment, and
                   -- lookups of them fall back to 'unknownRange'.
                 , mapOnBody = const analyseBody
                 , mapOnRetType = return
                 , mapOnBranchType = return
                 , mapOnFParam = return
                 , mapOnLParam = return
                 , mapOnOp = return . addOpRanges
                 }

-- | Perform range analysis on a lambda.
--
-- Only the body is analysed.  The lambda parameters are carried over
-- unchanged and are not added to the range environment.
analyseLambda :: (ASTLore lore, CanBeRanged (Op lore)) =>
                 Lambda lore
              -> RangeM (Lambda (Ranges lore))
analyseLambda lam = do
  body <- analyseBody $ lambdaBody lam
  -- 'lambdaParams' is re-assigned to its own value.  Presumably this is
  -- required so that the record update is allowed to change the lore
  -- type.
  return $ lam { lambdaBody = body
               , lambdaParams = lambdaParams lam
               }

-- Monad and utility definitions

-- | The known range of each variable in scope.  Variables missing from
-- the map are treated as having 'unknownRange' (see 'lookupRange').
type RangeEnv = M.Map VName Range

-- | The environment analysis starts from: no ranges are known.
emptyRangeEnv :: RangeEnv
emptyRangeEnv = M.empty

-- | The range analysis monad.
type RangeM = Reader RangeEnv

-- | Run a 'RangeM' action, starting from an empty range environment.
runRangeM :: RangeM a -> a
runRangeM action = runReader action emptyRangeEnv

-- | Bring function parameters into scope for an action.
--
-- Each parameter is bound to 'unknownRange'.  Every variable used as
-- one of its array dimensions is then refined with the fact that
-- dimensions are non-negative.  The simplifier view of the environment
-- ('rangesRep') is taken before the parameter is bound.
bindFunParams :: Typed dec => [Param dec] -> RangeM a -> RangeM a
bindFunParams [] m =
  m
bindFunParams (param:params) m = do
  ranges <- rangesRep
  local bindFunParam $
    local (refineDimensionRanges ranges dims) $
    bindFunParams params m
  where bindFunParam = M.insert (paramName param) unknownRange
        dims = arrayDims $ paramType param

-- | Bring the variables bound by a ranged pattern into scope for an
-- action.
--
-- Each variable is given the range stored in its pattern element.
-- Every variable used as an array dimension in the pattern's types is
-- additionally refined with the fact that dimensions are non-negative.
bindPattern :: Typed dec => PatternT (Range, dec) -> RangeM a -> RangeM a
bindPattern pat m = do
  ranges <- rangesRep
  local bindPatElems $
    local (refineDimensionRanges ranges dims)
    m
  where -- Strict left fold, so no chain of 'M.insert' thunks builds up.
        bindPatElems env =
          foldl' bindPatElem env $ patternElements pat
        bindPatElem env patElem =
          M.insert (patElemName patElem) (fst $ patElemDec patElem) env
        -- 'nub' avoids refining the same dimension more than once.
        dims = nub $ concatMap arrayDims $ patternTypes pat

-- | Refine the range of every variable among the given dimensions with
-- a lower bound of 0, because array dimensions are never negative.
-- Dimensions that are not variables are left alone.
refineDimensionRanges :: AS.RangesRep -> [SubExp]
                      -> RangeEnv -> RangeEnv
refineDimensionRanges ranges dims env = foldl' refineShape env dims
  where refineShape env' (Var dim) =
          refineRange ranges dim dimBound env'
        refineShape env' _ =
          env'
        -- A dimension is never negative.
        dimBound :: Range
        dimBound = (Just $ ScalarBound 0,
                    Nothing)

-- | Record a new range for a variable.
--
-- If the variable already has a range, the new and old ranges are
-- combined with 'refinedRange'.  Otherwise the new range is stored
-- as-is.
refineRange :: AS.RangesRep -> VName -> Range -> RangeEnv
            -> RangeEnv
refineRange ranges = M.insertWith (refinedRange ranges)

-- | Combine a new range with an old one.  The arguments are the new
-- range and then the old range, and the combined range is returned.
--
-- Lower bounds are combined with 'refineLowerBound' and upper bounds
-- with 'refineUpperBound'.  Both results are then simplified.
refinedRange :: AS.RangesRep -> Range -> Range -> Range
refinedRange ranges (new_lower, new_upper) (old_lower, old_upper) =
  (simplifyBound ranges $ refineLowerBound new_lower old_lower,
   simplifyBound ranges $ refineUpperBound new_upper old_upper)

-- | Combine a new lower bound with an old one (new bound first) using
-- 'maximumBound'.
refineLowerBound :: Bound -> Bound -> Bound
refineLowerBound new old = maximumBound old new

-- | Combine a new upper bound with an old one (new bound first) using
-- 'minimumBound'.
refineUpperBound :: Bound -> Bound -> Bound
refineUpperBound new old = minimumBound old new

-- | Look up the range of a variable.  Variables not present in the
-- environment get 'unknownRange'.
lookupRange :: VName -> RangeM Range
lookupRange v = asks $ M.findWithDefault unknownRange v

-- | Simplify the range stored in every element of a pattern, covering
-- both the context and the value elements, using 'simplifyRange'.  The
-- inner decoration of each element is kept unchanged.
simplifyPatRanges :: PatternT (Range, dec)
                  -> RangeM (PatternT (Range, dec))
simplifyPatRanges (Pattern context values) =
  Pattern <$> mapM simplifyPatElemRange context
          <*> mapM simplifyPatElemRange values
  where simplifyPatElemRange patElem = do
          let (range, innerdec) = patElemDec patElem
          range' <- simplifyRange range
          return $ setPatElemLore patElem (range', innerdec)

-- | Simplify a range against the current environment.
--
-- A bound that is a single variable is first replaced by that
-- variable's own bound on the same side, when one is known (see
-- 'betterLowerBound' and 'betterUpperBound').  Both bounds are then
-- passed through the algebraic simplifier.
simplifyRange :: Range -> RangeM Range
simplifyRange (lower, upper) = do
  ranges <- rangesRep
  lower' <- simplifyBound ranges <$> betterLowerBound lower
  upper' <- simplifyBound ranges <$> betterUpperBound upper
  return (lower', upper')

-- | Simplify a bound, if one is present, with 'simplifyKnownBound'.
-- An absent bound stays absent.
simplifyBound :: AS.RangesRep -> Bound -> Bound
simplifyBound ranges = fmap (simplifyKnownBound ranges)

-- | Simplify a known bound.
--
-- A bound that can be converted to a 'ScalExp' is simplified
-- algebraically with 'AS.simplify'.  Minimum and maximum bounds that
-- cannot be converted are simplified component-wise.  Any other bound
-- is returned unchanged.
simplifyKnownBound :: AS.RangesRep -> KnownBound -> KnownBound
simplifyKnownBound ranges bound
  | Just se <- boundToScalExp bound =
    ScalarBound $ AS.simplify se ranges
simplifyKnownBound ranges (MinimumBound b1 b2) =
  MinimumBound (simplifyKnownBound ranges b1) (simplifyKnownBound ranges b2)
simplifyKnownBound ranges (MaximumBound b1 b2) =
  MaximumBound (simplifyKnownBound ranges b1) (simplifyKnownBound ranges b2)
simplifyKnownBound _ bound =
  bound

-- | If a lower bound is just a variable @v@, replace it by @v@'s own
-- lower bound when the environment knows one.  Otherwise the bound is
-- kept as @v@.  Any other bound is returned unchanged.
betterLowerBound :: Bound -> RangeM Bound
betterLowerBound (Just (ScalarBound (SE.Id v t))) = do
  range <- lookupRange v
  return $ Just $ case range of
    (Just lower, _) -> lower
    _               -> ScalarBound $ SE.Id v t
betterLowerBound bound =
  return bound

-- | If an upper bound is just a variable @v@, replace it by @v@'s own
-- upper bound when the environment knows one.  Otherwise the bound is
-- kept as @v@.  Any other bound is returned unchanged.
betterUpperBound :: Bound -> RangeM Bound
betterUpperBound (Just (ScalarBound (SE.Id v t))) = do
  range <- lookupRange v
  return $ Just $ case range of
    (_, Just upper) -> upper
    _               -> ScalarBound $ SE.Id v t
betterUpperBound bound =
  return bound

-- | The current range environment, in the representation expected by
-- the algebraic simplifier.  A bound that cannot be converted with
-- 'boundToScalExp' becomes 'Nothing'.
--
-- The algebraic simplifier requires a loop nesting level for each
-- range.  We just put a zero, because the level does not appear to be
-- used for anything in this case.
rangesRep :: RangeM AS.RangesRep
rangesRep = asks $ M.map addLeadingZero
  where addLeadingZero (x, y) =
          (0, boundToScalExp =<< x, boundToScalExp =<< y)