{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Type inference of @loop@.  This is complicated because of the
-- uniqueness and size inference, so the implementation is separate
-- from the main type checker.
module Language.Futhark.TypeChecker.Terms.DoLoop
  ( UncheckedLoop,
    CheckedLoop,
    checkDoLoop,
  )
where

import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.Bitraversable
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Futhark.Util (nubOrd)
import Futhark.Util.Pretty hiding (bool, group, space)
import Language.Futhark
import Language.Futhark.TypeChecker.Monad hiding (BoundV)
import Language.Futhark.TypeChecker.Terms.Monad hiding (consumed)
import Language.Futhark.TypeChecker.Terms.Pat
import Language.Futhark.TypeChecker.Types
import Language.Futhark.TypeChecker.Unify
import Prelude hiding (mod)

-- | Replace every occurrence of the specified sizes with a fresh size
-- variable of the given rigidity.  Each occurrence gets its own,
-- distinct variable (see 'freshDimsInType' for a variant that shares
-- one variable per size name).  Only dimensions are rewritten; the
-- aliasing component of the type is left untouched.
--
-- @desc@ is passed on to 'newDimVar' for every fresh variable.
someDimsFreshInType ::
  SrcLoc ->
  Rigidity ->
  Name ->
  S.Set VName ->
  TypeBase (DimDecl VName) als ->
  TermTypeM (TypeBase (DimDecl VName) als)
someDimsFreshInType loc r desc sizes = bitraverse onDim pure
  where
    -- Named dimensions mentioned in 'sizes' are replaced by a brand
    -- new variable; every other dimension is kept as-is.
    onDim :: DimDecl VName -> TermTypeM (DimDecl VName)
    onDim (NamedDim d)
      | qualLeaf d `S.member` sizes = do
          v <- newDimVar loc r desc
          pure $ NamedDim $ qualName v
    onDim d = pure d

-- | Replace the specified sizes with fresh size variables of the
-- specified rigidity.  All occurrences of the same size name are
-- replaced by the same fresh variable.  Returns the rewritten type
-- together with the fresh size variables that were created, in the
-- key order of the size names they replace.
--
-- @desc@ is passed on to 'newDimVar' for every fresh variable.
freshDimsInType ::
  SrcLoc ->
  Rigidity ->
  Name ->
  S.Set VName ->
  TypeBase (DimDecl VName) als ->
  TermTypeM (TypeBase (DimDecl VName) als, [VName])
freshDimsInType loc r desc sizes t =
  second M.elems <$> runStateT (bitraverse onDim pure t) mempty
  where
    -- The state maps each replaced size name to its fresh variable,
    -- so that repeated occurrences reuse the same substitution.
    onDim :: DimDecl VName -> StateT (M.Map VName VName) TermTypeM (DimDecl VName)
    onDim (NamedDim d)
      | qualLeaf d `S.member` sizes = do
          prev_subst <- gets $ M.lookup $ qualLeaf d
          case prev_subst of
            Just d' -> pure $ NamedDim $ qualName d'
            Nothing -> do
              v <- lift $ newDimVar loc r desc
              modify $ M.insert (qualLeaf d) v
              pure $ NamedDim $ qualName v
    onDim d = pure d

-- | Reset the uniqueness of the names bound by a pattern.  Names in
-- @consumed@ are marked 'Unique' and their aliases are cleared; all
-- other names, as well as wildcards, are marked 'Nonunique'.
uniquePat :: Names -> Pat -> Pat
uniquePat consumed = recurse
  where
    recurse :: Pat -> Pat
    recurse (Wildcard (Info t) wloc) =
      Wildcard (Info $ t `setUniqueness` Nonunique) wloc
    recurse (PatParens p ploc) =
      PatParens (recurse p) ploc
    recurse (PatAttr attr p ploc) =
      PatAttr attr (recurse p) ploc
    recurse (Id name (Info t) iloc)
      | name `S.member` consumed =
          let t' = t `setUniqueness` Unique `setAliases` mempty
           in Id name (Info t') iloc
      | otherwise =
          let t' = t `setUniqueness` Nonunique
           in Id name (Info t') iloc
    recurse (TuplePat pats ploc) =
      TuplePat (map recurse pats) ploc
    recurse (RecordPat fs ploc) =
      RecordPat (map (fmap recurse) fs) ploc
    -- NOTE(review): the sub-pattern under an ascription is not visited,
    -- so its uniqueness is left unchanged; presumably the ascribed type
    -- is meant to be authoritative -- confirm.
    recurse (PatAscription p t ploc) =
      PatAscription p t ploc
    recurse p@PatLit {} = p
    recurse (PatConstr n t ps ploc) =
      PatConstr n t (map recurse ps) ploc

-- | Compute the final loop parameter pattern by iterating to a fixed
-- point.  Each round does three things:
--
-- 1. It marks the pattern names that occur in @body_cons@ (presumably
--    the names consumed by the loop body) as unique, and all others as
--    nonunique.
--
-- 2. It checks that the body result type @body_t@ is a structural
--    subtype of the pattern type.
--
-- 3. It walks the pattern alongside @body_t@, checking alias rules
--    and adding the result aliases to the pattern types.
--
-- Variables aliased by values returned for unique parameters are then
-- added to the consumed set.  The function recurses until neither that
-- set nor the pattern type changes.
--
-- A type error is raised at @loop_loc@ in three cases:
--
-- * a value returned for a unique parameter aliases something bound
--   outside the loop;
--
-- * a returned value aliases a value already returned for a unique
--   parameter;
--
-- * a value returned for a unique parameter aliases any previously
--   returned value.
convergePat :: SrcLoc -> Pat -> Names -> PatType -> Usage -> TermTypeM Pat
convergePat loop_loc pat body_cons body_t body_loc = do
  let -- Make the pattern unique where needed.
      pat' = uniquePat (patNames pat `S.intersection` body_cons) pat

  -- The body result must structurally fit the (normalised) pattern type.
  pat_t <- normTypeFully $ patternType pat'
  unless (toStructural body_t `subtypeOf` toStructural pat_t) $
    unexpectedType (srclocOf body_loc) (toStruct body_t) [toStruct pat_t]

  -- Check that the new values of consumed merge parameters do not
  -- alias something bound outside the loop, AND that anything
  -- returned for a unique merge parameter does not alias anything
  -- else returned.  We also update the aliases for the pattern.
  bound_outside <- asks $ S.fromList . M.keys . scopeVtable . termScope
  let -- Add the aliases of t2 to t1.  Record types are returned
      -- unchanged (presumably their aliases live in the field types).
      combAliases t1 t2 =
        case t1 of
          Scalar Record {} -> t1
          _ -> t1 `addAliases` (<> aliases t2)

      -- Walk the pattern alongside the result type.  The state is
      -- (aliases of values returned for unique parameters,
      --  aliases of values returned for nonunique parameters).
      checkMergeReturn (Id pat_v (Info pat_v_t) patloc) t
        | unique pat_v_t,
          v : _ <-
            S.toList $
              S.map aliasVar (aliases t) `S.intersection` bound_outside =
            lift . typeError loop_loc mempty $
              "Return value for loop parameter"
                <+> pquote (pprName pat_v)
                <+> "aliases"
                <+> pprName v <> "."
        | otherwise = do
            (cons, obs) <- get
            unless (S.null $ aliases t `S.intersection` cons) $
              lift . typeError loop_loc mempty $
                "Return value for loop parameter"
                  <+> pquote (pprName pat_v)
                  <+> "aliases other consumed loop parameter."
            when
              ( unique pat_v_t
                  && not (S.null (aliases t `S.intersection` (cons <> obs)))
              )
              $ lift . typeError loop_loc mempty $
                "Return value for consuming loop parameter"
                  <+> pquote (pprName pat_v)
                  <+> "aliases previously returned value."
            if unique pat_v_t
              then put (cons <> aliases t, obs)
              else put (cons, obs <> aliases t)

            pure $ Id pat_v (Info (combAliases pat_v_t t)) patloc
      checkMergeReturn (Wildcard (Info pat_v_t) patloc) t =
        pure $ Wildcard (Info (combAliases pat_v_t t)) patloc
      checkMergeReturn (PatParens p _) t =
        checkMergeReturn p t
      checkMergeReturn (PatAscription p _ _) t =
        checkMergeReturn p t
      checkMergeReturn (RecordPat pfs patloc) (Scalar (Record tfs)) =
        RecordPat . M.toList <$> sequence pfs' <*> pure patloc
        where
          pfs' = M.intersectionWith checkMergeReturn (M.fromList pfs) tfs
      checkMergeReturn (TuplePat pats patloc) t
        | Just ts <- isTupleRecord t =
            TuplePat <$> zipWithM checkMergeReturn pats ts <*> pure patloc
      checkMergeReturn p _ =
        pure p

  (pat'', (pat_cons, _)) <-
    runStateT (checkMergeReturn pat' body_t) (mempty, mempty)

  -- Whatever is aliased by a value returned for a unique parameter is
  -- now also considered consumed; repeat until nothing changes.
  let body_cons' = body_cons <> S.map aliasVar pat_cons
  if body_cons' == body_cons && patternType pat'' == patternType pat
    then pure pat'
    else convergePat loop_loc pat'' body_cons' body_t body_loc

-- | Where a value matched against the loop parameter pattern comes
-- from.  In 'wellTypedLoopArg' it only selects which failure context
-- and usage description are reported.
data ArgSource = Initial | BodyResult

-- | Check that the type of @arg@ unifies with the type of the loop
-- parameter pattern @pat@.  Before unification, the sizes in
-- @sparams@ are replaced in the pattern type by fresh nonrigid size
-- variables.  @src@ only affects the diagnostics on failure.
wellTypedLoopArg :: ArgSource -> [VName] -> Pat -> Exp -> TermTypeM ()
wellTypedLoopArg src sparams pat arg = do
  (merge_t, _) <-
    freshDimsInType (srclocOf arg) Nonrigid "loop" (S.fromList sparams) $
      toStruct $ patternType pat
  arg_t <- toStruct <$> expTypeFully arg
  onFailure (checking merge_t arg_t) $
    unify
      (mkUsage (srclocOf arg) desc)
      merge_t
      arg_t
  where
    -- Failure context and usage description, chosen by argument source.
    (checking, desc) =
      case src of
        Initial -> (CheckingLoopInitial, "matching initial loop values to pattern")
        BodyResult -> (CheckingLoopBody, "matching loop body to pattern")

-- | A loop that has not yet been type-checked, as the tuple
-- (loop parameter pattern, initial value expression, loop form,
-- loop body).
type UncheckedLoop =
  ( UncheckedPat,
    UncheckedExp,
    LoopFormBase NoInfo Name,
    UncheckedExp
  )

-- | A type-checked loop.  It has the same components as
-- 'UncheckedLoop', in checked form, preceded by a list of names
-- (presumably the size parameters of the loop pattern -- confirm).
type CheckedLoop =
  ( [VName],
    Pat,
    Exp,
    LoopFormBase Info VName,
    Exp
  )

-- | Type-check a @loop@ expression, passing in a function for
-- type-checking subexpressions.
checkDoLoop ::
  (UncheckedExp -> TermTypeM Exp) ->
  UncheckedLoop ->
  SrcLoc ->
  TermTypeM (CheckedLoop, AppRes)
checkDoLoop :: (UncheckedExp -> TermTypeM Exp)
-> UncheckedLoop -> SrcLoc -> TermTypeM (CheckedLoop, AppRes)
checkDoLoop UncheckedExp -> TermTypeM Exp
checkExp (UncheckedPat
mergepat, UncheckedExp
mergeexp, LoopFormBase NoInfo Name
form, UncheckedExp
loopbody) SrcLoc
loc =
  TermTypeM Exp
-> (Exp -> Occurrences -> TermTypeM (CheckedLoop, AppRes))
-> TermTypeM (CheckedLoop, AppRes)
forall a b.
TermTypeM a -> (a -> Occurrences -> TermTypeM b) -> TermTypeM b
sequentially (UncheckedExp -> TermTypeM Exp
checkExp UncheckedExp
mergeexp) ((Exp -> Occurrences -> TermTypeM (CheckedLoop, AppRes))
 -> TermTypeM (CheckedLoop, AppRes))
-> (Exp -> Occurrences -> TermTypeM (CheckedLoop, AppRes))
-> TermTypeM (CheckedLoop, AppRes)
forall a b. (a -> b) -> a -> b
$ \Exp
mergeexp' Occurrences
_ -> do
    Usage -> String -> PatType -> TermTypeM ()
forall (m :: * -> *) dim as.
(MonadUnify m, Pretty (ShapeDecl dim), Monoid as) =>
Usage -> String -> TypeBase dim as -> m ()
zeroOrderType
      (SrcLoc -> String -> Usage
mkUsage (UncheckedExp -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf UncheckedExp
mergeexp) String
"use as loop variable")
      String
"type used as loop variable"
      (PatType -> TermTypeM ()) -> TermTypeM PatType -> TermTypeM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Exp -> TermTypeM PatType
expTypeFully Exp
mergeexp'

    -- The handling of dimension sizes is a bit intricate, but very
    -- similar to checking a function, followed by checking a call to
    -- it.  The overall procedure is as follows:
    --
    -- (1) All empty dimensions in the merge pattern are instantiated
    -- with nonrigid size variables.  All explicitly specified
    -- dimensions are preserved.
    --
    -- (2) The body of the loop is type-checked.  The result type is
    -- combined with the merge pattern type to determine which sizes are
    -- variant, and these are turned into size parameters for the merge
    -- pattern.
    --
    -- (3) We now conceptually have a function parameter type and
    -- return type.  We check that it can be called with the body type
    -- as argument.
    --
    -- (4) Similarly to (3), we check that the "function" can be
    -- called with the initial merge values as argument.  The result
    -- of this is the type of the loop as a whole.
    --
    -- (There is also a convergence loop for inferring uniqueness, but
    -- that's orthogonal to the size handling.)

    (PatType
merge_t, Map VName (DimDecl VName)
new_dims_to_initial_dim) <-
      -- dim handling (1)
      SrcLoc
-> Rigidity
-> Name
-> PatType
-> TermTypeM (PatType, Map VName (DimDecl VName))
forall als.
SrcLoc
-> Rigidity
-> Name
-> TypeBase (DimDecl VName) als
-> TermTypeM
     (TypeBase (DimDecl VName) als, Map VName (DimDecl VName))
allDimsFreshInType SrcLoc
loc Rigidity
Nonrigid Name
"loop" (PatType -> TermTypeM (PatType, Map VName (DimDecl VName)))
-> TermTypeM PatType
-> TermTypeM (PatType, Map VName (DimDecl VName))
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Exp -> TermTypeM PatType
expTypeFully Exp
mergeexp'
    let new_dims :: [VName]
new_dims = Map VName (DimDecl VName) -> [VName]
forall k a. Map k a -> [k]
M.keys Map VName (DimDecl VName)
new_dims_to_initial_dim

    -- dim handling (2)
    let checkLoopReturnSize :: Pat -> Exp -> TermTypeM ([VName], Pat)
checkLoopReturnSize Pat
mergepat' Exp
loopbody' = do
          PatType
loopbody_t <- Exp -> TermTypeM PatType
expTypeFully Exp
loopbody'
          PatType
pat_t <-
            SrcLoc
-> Rigidity -> Name -> Set VName -> PatType -> TermTypeM PatType
forall als.
SrcLoc
-> Rigidity
-> Name
-> Set VName
-> TypeBase (DimDecl VName) als
-> TermTypeM (TypeBase (DimDecl VName) als)
someDimsFreshInType SrcLoc
loc Rigidity
Nonrigid Name
"loop" ([VName] -> Set VName
forall a. Ord a => [a] -> Set a
S.fromList [VName]
new_dims)
              (PatType -> TermTypeM PatType)
-> TermTypeM PatType -> TermTypeM PatType
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PatType -> TermTypeM PatType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully (Pat -> PatType
patternType Pat
mergepat')

          -- We are ignoring the dimensions here, because any mismatches
          -- should be turned into fresh size variables.
          Checking -> TermTypeM () -> TermTypeM ()
forall a. Checking -> TermTypeM a -> TermTypeM a
onFailure (StructType -> StructType -> Checking
CheckingLoopBody (PatType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatType
pat_t) (PatType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatType
loopbody_t)) (TermTypeM () -> TermTypeM ()) -> TermTypeM () -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$
            Usage -> StructType -> StructType -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
Usage -> StructType -> StructType -> m ()
unify
              (SrcLoc -> String -> Usage
mkUsage (UncheckedExp -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf UncheckedExp
loopbody) String
"matching loop body to loop pattern")
              (PatType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatType
pat_t)
              (PatType -> StructType
forall dim as. TypeBase dim as -> TypeBase dim ()
toStruct PatType
loopbody_t)

          -- Figure out which of the 'new_dims' dimensions are variant.
          -- This works because we know that each dimension from
          -- new_dims in the pattern is unique and distinct.
          let onDims :: p -> DimDecl VName -> DimDecl VName -> f (DimDecl VName)
onDims p
_ DimDecl VName
x DimDecl VName
y
                | DimDecl VName
x DimDecl VName -> DimDecl VName -> Bool
forall a. Eq a => a -> a -> Bool
== DimDecl VName
y = DimDecl VName -> f (DimDecl VName)
forall (f :: * -> *) a. Applicative f => a -> f a
pure DimDecl VName
x
              onDims p
_ (NamedDim QualName VName
v) DimDecl VName
d
                | QualName VName -> VName
forall vn. QualName vn -> vn
qualLeaf QualName VName
v VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
new_dims = do
                  case VName -> Map VName (DimDecl VName) -> Maybe (DimDecl VName)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (QualName VName -> VName
forall vn. QualName vn -> vn
qualLeaf QualName VName
v) Map VName (DimDecl VName)
new_dims_to_initial_dim of
                    Just DimDecl VName
d'
                      | DimDecl VName
d' DimDecl VName -> DimDecl VName -> Bool
forall a. Eq a => a -> a -> Bool
== DimDecl VName
d ->
                        (p (Map VName (Subst t)) [VName]
 -> p (Map VName (Subst t)) [VName])
-> f ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((p (Map VName (Subst t)) [VName]
  -> p (Map VName (Subst t)) [VName])
 -> f ())
-> (p (Map VName (Subst t)) [VName]
    -> p (Map VName (Subst t)) [VName])
-> f ()
forall a b. (a -> b) -> a -> b
$ (Map VName (Subst t) -> Map VName (Subst t))
-> p (Map VName (Subst t)) [VName]
-> p (Map VName (Subst t)) [VName]
forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first ((Map VName (Subst t) -> Map VName (Subst t))
 -> p (Map VName (Subst t)) [VName]
 -> p (Map VName (Subst t)) [VName])
-> (Map VName (Subst t) -> Map VName (Subst t))
-> p (Map VName (Subst t)) [VName]
-> p (Map VName (Subst t)) [VName]
forall a b. (a -> b) -> a -> b
$ VName -> Subst t -> Map VName (Subst t) -> Map VName (Subst t)
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert (QualName VName -> VName
forall vn. QualName vn -> vn
qualLeaf QualName VName
v) (DimDecl VName -> Subst t
forall t. DimDecl VName -> Subst t
SizeSubst DimDecl VName
d)
                    Maybe (DimDecl VName)
_ ->
                      (p (Map VName (Subst t)) [VName]
 -> p (Map VName (Subst t)) [VName])
-> f ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((p (Map VName (Subst t)) [VName]
  -> p (Map VName (Subst t)) [VName])
 -> f ())
-> (p (Map VName (Subst t)) [VName]
    -> p (Map VName (Subst t)) [VName])
-> f ()
forall a b. (a -> b) -> a -> b
$ ([VName] -> [VName])
-> p (Map VName (Subst t)) [VName]
-> p (Map VName (Subst t)) [VName]
forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second (QualName VName -> VName
forall vn. QualName vn -> vn
qualLeaf QualName VName
v VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
:)
                  DimDecl VName -> f (DimDecl VName)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (DimDecl VName -> f (DimDecl VName))
-> DimDecl VName -> f (DimDecl VName)
forall a b. (a -> b) -> a -> b
$ QualName VName -> DimDecl VName
forall vn. QualName vn -> DimDecl vn
NamedDim QualName VName
v
              onDims p
_ DimDecl VName
x DimDecl VName
_ = DimDecl VName -> f (DimDecl VName)
forall (f :: * -> *) a. Applicative f => a -> f a
pure DimDecl VName
x
          PatType
loopbody_t' <- PatType -> TermTypeM PatType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatType
loopbody_t
          PatType
merge_t' <- PatType -> TermTypeM PatType
forall a (m :: * -> *). (Substitutable a, MonadUnify m) => a -> m a
normTypeFully PatType
merge_t
          let (Map VName (Subst t)
init_substs, [VName]
sparams) =
                State (Map VName (Subst t), [VName]) PatType
-> (Map VName (Subst t), [VName]) -> (Map VName (Subst t), [VName])
forall s a. State s a -> s -> s
execState (([VName]
 -> DimDecl VName
 -> DimDecl VName
 -> StateT (Map VName (Subst t), [VName]) Identity (DimDecl VName))
-> PatType
-> PatType
-> State (Map VName (Subst t), [VName]) PatType
forall as (m :: * -> *) d1 d2.
(Monoid as, Monad m) =>
([VName] -> d1 -> d2 -> m d1)
-> TypeBase d1 as -> TypeBase d2 as -> m (TypeBase d1 as)
matchDims [VName]
-> DimDecl VName
-> DimDecl VName
-> StateT (Map VName (Subst t), [VName]) Identity (DimDecl VName)
forall (f :: * -> *) (p :: * -> * -> *) t p.
(Bifunctor p, MonadState (p (Map VName (Subst t)) [VName]) f) =>
p -> DimDecl VName -> DimDecl VName -> f (DimDecl VName)
onDims PatType
merge_t' PatType
loopbody_t') (Map VName (Subst t), [VName])
forall a. Monoid a => a
mempty

          -- Make sure that any of new_dims that are invariant will be
          -- replaced with the invariant size in the loop body.  Failure
          -- to do this can cause type annotations to still refer to
          -- new_dims.
          let dimToInit :: (VName, Subst t) -> TermTypeM ()
dimToInit (VName
v, SizeSubst DimDecl VName
d) =
                VName -> Constraint -> TermTypeM ()
constrain VName
v (Constraint -> TermTypeM ()) -> Constraint -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ Maybe (DimDecl VName) -> Usage -> Constraint
Size (DimDecl VName -> Maybe (DimDecl VName)
forall a. a -> Maybe a
Just DimDecl VName
d) (SrcLoc -> String -> Usage
mkUsage SrcLoc
loc String
"size of loop parameter")
              dimToInit (VName, Subst t)
_ =
                () -> TermTypeM ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
          ((VName, Subst Any) -> TermTypeM ())
-> [(VName, Subst Any)] -> TermTypeM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (VName, Subst Any) -> TermTypeM ()
forall t. (VName, Subst t) -> TermTypeM ()
dimToInit ([(VName, Subst Any)] -> TermTypeM ())
-> [(VName, Subst Any)] -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ Map VName (Subst Any) -> [(VName, Subst Any)]
forall k a. Map k a -> [(k, a)]
M.toList Map VName (Subst Any)
forall t. Map VName (Subst t)
init_substs

          Pat
mergepat'' <- TypeSubs -> Pat -> Pat
forall a. Substitutable a => TypeSubs -> a -> a
applySubst (VName
-> Map VName (Subst StructRetType) -> Maybe (Subst StructRetType)
forall k a. Ord k => k -> Map k a -> Maybe a
`M.lookup` Map VName (Subst StructRetType)
forall t. Map VName (Subst t)
init_substs) (Pat -> Pat) -> TermTypeM Pat -> TermTypeM Pat
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Pat -> TermTypeM Pat
forall e. ASTMappable e => e -> TermTypeM e
updateTypes Pat
mergepat'

          -- Eliminate those new_dims that turned into sparams so it won't
          -- look like we have ambiguous sizes lying around.
          (Constraints -> Constraints) -> TermTypeM ()
forall (m :: * -> *).
MonadUnify m =>
(Constraints -> Constraints) -> m ()
modifyConstraints ((Constraints -> Constraints) -> TermTypeM ())
-> (Constraints -> Constraints) -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ (VName -> (Level, Constraint) -> Bool)
-> Constraints -> Constraints
forall k a. (k -> a -> Bool) -> Map k a -> Map k a
M.filterWithKey ((VName -> (Level, Constraint) -> Bool)
 -> Constraints -> Constraints)
-> (VName -> (Level, Constraint) -> Bool)
-> Constraints
-> Constraints
forall a b. (a -> b) -> a -> b
$ \VName
k (Level, Constraint)
_ -> VName
k VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
sparams

          -- dim handling (3)
          --
          -- The only trick here is that we have to turn any instances
          -- of loop parameters in the type of loopbody' rigid,
          -- because we are no longer in a position to change them,
          -- really.
          ArgSource -> [VName] -> Pat -> Exp -> TermTypeM ()
wellTypedLoopArg ArgSource
BodyResult [VName]
sparams Pat
mergepat'' Exp
loopbody'

          ([VName], Pat) -> TermTypeM ([VName], Pat)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([VName] -> [VName]
forall a. Ord a => [a] -> [a]
nubOrd [VName]
sparams, Pat
mergepat'')

    -- First we do a basic check of the loop body to figure out which of
    -- the merge parameters are being consumed.  For this, we first need
    -- to check the merge pattern, which requires the (initial) merge
    -- expression.
    --
    -- Play a little with occurences to ensure it does not look like
    -- none of the merge variables are being used.
    (([VName]
sparams, Pat
mergepat', LoopFormBase Info VName
form', Exp
loopbody'), Occurrences
bodyflow) <-
      case LoopFormBase NoInfo Name
form of
        For IdentBase NoInfo Name
i UncheckedExp
uboundexp -> do
          Exp
uboundexp' <-
            String -> [PrimType] -> Exp -> TermTypeM Exp
require String
"being the bound in a 'for' loop" [PrimType]
anySignedType
              (Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< UncheckedExp -> TermTypeM Exp
checkExp UncheckedExp
uboundexp
          PatType
bound_t <- Exp -> TermTypeM PatType
expTypeFully Exp
uboundexp'
          IdentBase NoInfo Name
-> PatType
-> (Ident
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a.
IdentBase NoInfo Name
-> PatType -> (Ident -> TermTypeM a) -> TermTypeM a
bindingIdent IdentBase NoInfo Name
i PatType
bound_t ((Ident
  -> TermTypeM
       (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> (Ident
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a b. (a -> b) -> a -> b
$ \Ident
i' ->
            TermTypeM
  (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a. TermTypeM a -> TermTypeM a
noUnique (TermTypeM
   (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> ((Pat
     -> TermTypeM
          (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> (Pat
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [SizeBinder VName]
-> UncheckedPat
-> InferredType
-> (Pat
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a.
[SizeBinder VName]
-> UncheckedPat
-> InferredType
-> (Pat -> TermTypeM a)
-> TermTypeM a
bindingPat [] UncheckedPat
mergepat (PatType -> InferredType
Ascribed PatType
merge_t) ((Pat
  -> TermTypeM
       (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> (Pat
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a b. (a -> b) -> a -> b
$
              \Pat
mergepat' -> TermTypeM
  (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a. TermTypeM a -> TermTypeM a
onlySelfAliasing (TermTypeM
   (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> (TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a. TermTypeM a -> TermTypeM (a, Occurrences)
tapOccurrences (TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a b. (a -> b) -> a -> b
$ do
                Exp
loopbody' <- TermTypeM Exp -> TermTypeM Exp
forall a. TermTypeM a -> TermTypeM a
noSizeEscape (TermTypeM Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ UncheckedExp -> TermTypeM Exp
checkExp UncheckedExp
loopbody
                ([VName]
sparams, Pat
mergepat'') <- Pat -> Exp -> TermTypeM ([VName], Pat)
checkLoopReturnSize Pat
mergepat' Exp
loopbody'
                ([VName], Pat, LoopFormBase Info VName, Exp)
-> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
forall (m :: * -> *) a. Monad m => a -> m a
return
                  ( [VName]
sparams,
                    Pat
mergepat'',
                    Ident -> Exp -> LoopFormBase Info VName
forall (f :: * -> *) vn.
IdentBase f vn -> ExpBase f vn -> LoopFormBase f vn
For Ident
i' Exp
uboundexp',
                    Exp
loopbody'
                  )
        ForIn UncheckedPat
xpat UncheckedExp
e -> do
          (StructType
arr_t, StructType
_) <- SrcLoc -> Name -> Level -> TermTypeM (StructType, StructType)
newArrayType (UncheckedExp -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf UncheckedExp
e) Name
"e" Level
1
          Exp
e' <- String -> StructType -> Exp -> TermTypeM Exp
unifies String
"being iterated in a 'for-in' loop" StructType
arr_t (Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< UncheckedExp -> TermTypeM Exp
checkExp UncheckedExp
e
          PatType
t <- Exp -> TermTypeM PatType
expTypeFully Exp
e'
          case PatType
t of
            PatType
_
              | Just PatType
t' <- Level -> PatType -> Maybe PatType
forall dim as. Level -> TypeBase dim as -> Maybe (TypeBase dim as)
peelArray Level
1 PatType
t ->
                [SizeBinder VName]
-> UncheckedPat
-> InferredType
-> (Pat
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a.
[SizeBinder VName]
-> UncheckedPat
-> InferredType
-> (Pat -> TermTypeM a)
-> TermTypeM a
bindingPat [] UncheckedPat
xpat (PatType -> InferredType
Ascribed PatType
t') ((Pat
  -> TermTypeM
       (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> (Pat
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a b. (a -> b) -> a -> b
$ \Pat
xpat' ->
                  TermTypeM
  (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a. TermTypeM a -> TermTypeM a
noUnique (TermTypeM
   (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> ((Pat
     -> TermTypeM
          (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> (Pat
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [SizeBinder VName]
-> UncheckedPat
-> InferredType
-> (Pat
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a.
[SizeBinder VName]
-> UncheckedPat
-> InferredType
-> (Pat -> TermTypeM a)
-> TermTypeM a
bindingPat [] UncheckedPat
mergepat (PatType -> InferredType
Ascribed PatType
merge_t) ((Pat
  -> TermTypeM
       (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> (Pat
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a b. (a -> b) -> a -> b
$
                    \Pat
mergepat' -> TermTypeM
  (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a. TermTypeM a -> TermTypeM a
onlySelfAliasing (TermTypeM
   (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> (TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a. TermTypeM a -> TermTypeM (a, Occurrences)
tapOccurrences (TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a b. (a -> b) -> a -> b
$ do
                      Exp
loopbody' <- TermTypeM Exp -> TermTypeM Exp
forall a. TermTypeM a -> TermTypeM a
noSizeEscape (TermTypeM Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ UncheckedExp -> TermTypeM Exp
checkExp UncheckedExp
loopbody
                      ([VName]
sparams, Pat
mergepat'') <- Pat -> Exp -> TermTypeM ([VName], Pat)
checkLoopReturnSize Pat
mergepat' Exp
loopbody'
                      ([VName], Pat, LoopFormBase Info VName, Exp)
-> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
forall (m :: * -> *) a. Monad m => a -> m a
return
                        ( [VName]
sparams,
                          Pat
mergepat'',
                          Pat -> Exp -> LoopFormBase Info VName
forall (f :: * -> *) vn.
PatBase f vn -> ExpBase f vn -> LoopFormBase f vn
ForIn Pat
xpat' Exp
e',
                          Exp
loopbody'
                        )
              | Bool
otherwise ->
                SrcLoc
-> Notes
-> Doc
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall (m :: * -> *) loc a.
(MonadTypeChecker m, Located loc) =>
loc -> Notes -> Doc -> m a
typeError (UncheckedExp -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf UncheckedExp
e) Notes
forall a. Monoid a => a
mempty (Doc
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> Doc
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a b. (a -> b) -> a -> b
$
                  Doc
"Iteratee of a for-in loop must be an array, but expression has type"
                    Doc -> Doc -> Doc
<+> PatType -> Doc
forall a. Pretty a => a -> Doc
ppr PatType
t
        While UncheckedExp
cond ->
          TermTypeM
  (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a. TermTypeM a -> TermTypeM a
noUnique (TermTypeM
   (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> ((Pat
     -> TermTypeM
          (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> (Pat
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [SizeBinder VName]
-> UncheckedPat
-> InferredType
-> (Pat
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a.
[SizeBinder VName]
-> UncheckedPat
-> InferredType
-> (Pat -> TermTypeM a)
-> TermTypeM a
bindingPat [] UncheckedPat
mergepat (PatType -> InferredType
Ascribed PatType
merge_t) ((Pat
  -> TermTypeM
       (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> (Pat
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a b. (a -> b) -> a -> b
$ \Pat
mergepat' ->
            TermTypeM
  (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a. TermTypeM a -> TermTypeM a
onlySelfAliasing (TermTypeM
   (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> ((Exp
     -> Occurrences
     -> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp))
    -> TermTypeM
         (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> (Exp
    -> Occurrences
    -> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a. TermTypeM a -> TermTypeM (a, Occurrences)
tapOccurrences
              (TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> ((Exp
     -> Occurrences
     -> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp))
    -> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp))
-> (Exp
    -> Occurrences
    -> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TermTypeM Exp
-> (Exp
    -> Occurrences
    -> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp))
-> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
forall a b.
TermTypeM a -> (a -> Occurrences -> TermTypeM b) -> TermTypeM b
sequentially
                ( UncheckedExp -> TermTypeM Exp
checkExp UncheckedExp
cond
                    TermTypeM Exp -> (Exp -> TermTypeM Exp) -> TermTypeM Exp
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= String -> StructType -> Exp -> TermTypeM Exp
unifies String
"being the condition of a 'while' loop" (ScalarTypeBase (DimDecl VName) () -> StructType
forall dim as. ScalarTypeBase dim as -> TypeBase dim as
Scalar (ScalarTypeBase (DimDecl VName) () -> StructType)
-> ScalarTypeBase (DimDecl VName) () -> StructType
forall a b. (a -> b) -> a -> b
$ PrimType -> ScalarTypeBase (DimDecl VName) ()
forall dim as. PrimType -> ScalarTypeBase dim as
Prim PrimType
Bool)
                )
              ((Exp
  -> Occurrences
  -> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp))
 -> TermTypeM
      (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences))
-> (Exp
    -> Occurrences
    -> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp))
-> TermTypeM
     (([VName], Pat, LoopFormBase Info VName, Exp), Occurrences)
forall a b. (a -> b) -> a -> b
$ \Exp
cond' Occurrences
_ -> do
                Exp
loopbody' <- TermTypeM Exp -> TermTypeM Exp
forall a. TermTypeM a -> TermTypeM a
noSizeEscape (TermTypeM Exp -> TermTypeM Exp) -> TermTypeM Exp -> TermTypeM Exp
forall a b. (a -> b) -> a -> b
$ UncheckedExp -> TermTypeM Exp
checkExp UncheckedExp
loopbody
                ([VName]
sparams, Pat
mergepat'') <- Pat -> Exp -> TermTypeM ([VName], Pat)
checkLoopReturnSize Pat
mergepat' Exp
loopbody'
                ([VName], Pat, LoopFormBase Info VName, Exp)
-> TermTypeM ([VName], Pat, LoopFormBase Info VName, Exp)
forall (m :: * -> *) a. Monad m => a -> m a
return
                  ( [VName]
sparams,
                    Pat
mergepat'',
                    Exp -> LoopFormBase Info VName
forall (f :: * -> *) vn. ExpBase f vn -> LoopFormBase f vn
While Exp
cond',
                    Exp
loopbody'
                  )

    Pat
mergepat'' <- do
      PatType
loopbody_t <- Exp -> TermTypeM PatType
expTypeFully Exp
loopbody'
      SrcLoc -> Pat -> Set VName -> PatType -> Usage -> TermTypeM Pat
convergePat SrcLoc
loc Pat
mergepat' (Occurrences -> Set VName
allConsumed Occurrences
bodyflow) PatType
loopbody_t (Usage -> TermTypeM Pat) -> Usage -> TermTypeM Pat
forall a b. (a -> b) -> a -> b
$
        SrcLoc -> String -> Usage
mkUsage (Exp -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf Exp
loopbody') String
"being (part of) the result of the loop body"

    let consumeMerge :: PatBase Info vn -> TypeBase dim Aliasing -> TermTypeM ()
consumeMerge (Id vn
_ (Info PatType
pt) SrcLoc
ploc) TypeBase dim Aliasing
mt
          | PatType -> Bool
forall shape as. TypeBase shape as -> Bool
unique PatType
pt = SrcLoc -> Aliasing -> TermTypeM ()
consume SrcLoc
ploc (Aliasing -> TermTypeM ()) -> Aliasing -> TermTypeM ()
forall a b. (a -> b) -> a -> b
$ TypeBase dim Aliasing -> Aliasing
forall as shape. Monoid as => TypeBase shape as -> as
aliases TypeBase dim Aliasing
mt
        consumeMerge (TuplePat [PatBase Info vn]
pats SrcLoc
_) TypeBase dim Aliasing
t
          | Just [TypeBase dim Aliasing]
ts <- TypeBase dim Aliasing -> Maybe [TypeBase dim Aliasing]
forall dim as. TypeBase dim as -> Maybe [TypeBase dim as]
isTupleRecord TypeBase dim Aliasing
t =
            (PatBase Info vn -> TypeBase dim Aliasing -> TermTypeM ())
-> [PatBase Info vn] -> [TypeBase dim Aliasing] -> TermTypeM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ PatBase Info vn -> TypeBase dim Aliasing -> TermTypeM ()
consumeMerge [PatBase Info vn]
pats [TypeBase dim Aliasing]
ts
        consumeMerge (PatParens PatBase Info vn
pat SrcLoc
_) TypeBase dim Aliasing
t =
          PatBase Info vn -> TypeBase dim Aliasing -> TermTypeM ()
consumeMerge PatBase Info vn
pat TypeBase dim Aliasing
t
        consumeMerge (PatAscription PatBase Info vn
pat TypeDeclBase Info vn
_ SrcLoc
_) TypeBase dim Aliasing
t =
          PatBase Info vn -> TypeBase dim Aliasing -> TermTypeM ()
consumeMerge PatBase Info vn
pat TypeBase dim Aliasing
t
        consumeMerge PatBase Info vn
_ TypeBase dim Aliasing
_ =
          () -> TermTypeM ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
    Pat -> PatType -> TermTypeM ()
forall vn dim.
PatBase Info vn -> TypeBase dim Aliasing -> TermTypeM ()
consumeMerge Pat
mergepat'' (PatType -> TermTypeM ()) -> TermTypeM PatType -> TermTypeM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Exp -> TermTypeM PatType
expTypeFully Exp
mergeexp'

    -- dim handling (4)
    ArgSource -> [VName] -> Pat -> Exp -> TermTypeM ()
wellTypedLoopArg ArgSource
Initial [VName]
sparams Pat
mergepat'' Exp
mergeexp'

    (PatType
loopt, [VName]
retext) <-
      SrcLoc
-> Rigidity
-> Name
-> Set VName
-> PatType
-> TermTypeM (PatType, [VName])
forall als.
SrcLoc
-> Rigidity
-> Name
-> Set VName
-> TypeBase (DimDecl VName) als
-> TermTypeM (TypeBase (DimDecl VName) als, [VName])
freshDimsInType SrcLoc
loc (RigidSource -> Rigidity
Rigid RigidSource
RigidLoop) Name
"loop" ([VName] -> Set VName
forall a. Ord a => [a] -> Set a
S.fromList [VName]
sparams) (PatType -> TermTypeM (PatType, [VName]))
-> PatType -> TermTypeM (PatType, [VName])
forall a b. (a -> b) -> a -> b
$
        Pat -> PatType
patternType Pat
mergepat''
    -- We set all of the uniqueness to be unique.  This is intentional,
    -- and matches what happens for function calls.  Those arrays that
    -- really *cannot* be consumed will alias something unconsumable,
    -- and will be caught that way.
    let bound_here :: Set VName
bound_here = Pat -> Set VName
forall (f :: * -> *) vn.
(Functor f, Ord vn) =>
PatBase f vn -> Set vn
patNames Pat
mergepat'' Set VName -> Set VName -> Set VName
forall a. Semigroup a => a -> a -> a
<> [VName] -> Set VName
forall a. Ord a => [a] -> Set a
S.fromList [VName]
sparams Set VName -> Set VName -> Set VName
forall a. Semigroup a => a -> a -> a
<> Set VName
form_bound
        form_bound :: Set VName
form_bound =
          case LoopFormBase Info VName
form' of
            For Ident
v Exp
_ -> VName -> Set VName
forall a. a -> Set a
S.singleton (VName -> Set VName) -> VName -> Set VName
forall a b. (a -> b) -> a -> b
$ Ident -> VName
forall (f :: * -> *) vn. IdentBase f vn -> vn
identName Ident
v
            ForIn Pat
forpat Exp
_ -> Pat -> Set VName
forall (f :: * -> *) vn.
(Functor f, Ord vn) =>
PatBase f vn -> Set vn
patNames Pat
forpat
            While {} -> Set VName
forall a. Monoid a => a
mempty
        loopt' :: PatType
loopt' =
          (Aliasing -> Aliasing) -> PatType -> PatType
forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second (Aliasing -> Aliasing -> Aliasing
forall a. Ord a => Set a -> Set a -> Set a
`S.difference` (VName -> Alias) -> Set VName -> Aliasing
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map VName -> Alias
AliasBound Set VName
bound_here) (PatType -> PatType) -> PatType -> PatType
forall a b. (a -> b) -> a -> b
$
            PatType
loopt PatType -> Uniqueness -> PatType
forall dim as. TypeBase dim as -> Uniqueness -> TypeBase dim as
`setUniqueness` Uniqueness
Unique

    (CheckedLoop, AppRes) -> TermTypeM (CheckedLoop, AppRes)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (([VName]
sparams, Pat
mergepat'', Exp
mergeexp', LoopFormBase Info VName
form', Exp
loopbody'), PatType -> [VName] -> AppRes
AppRes PatType
loopt' [VName]
retext)