-- | Type inference of @loop@.  This is complicated because of the
-- uniqueness and size inference, so the implementation is separate
-- from the main type checker.
module Language.Futhark.TypeChecker.Terms.DoLoop
  ( UncheckedLoop,
    CheckedLoop,
    checkDoLoop,
  )
where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.Bitraversable
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.Util (nubOrd)
import Futhark.Util.Pretty hiding (group, space)
import Language.Futhark
import Language.Futhark.TypeChecker.Monad hiding (BoundV)
import Language.Futhark.TypeChecker.Terms.Monad
import Language.Futhark.TypeChecker.Terms.Pat
import Language.Futhark.TypeChecker.Types
import Language.Futhark.TypeChecker.Unify
import Prelude hiding (mod)

-- | Retrieve an oracle that can be used to decide whether two names
-- are in the same equivalence class (i.e. have been unified).  This
-- is an exotic operation.
--
-- The oracle resolves the names by following constraints of the form
-- @Size (Just (Var ...))@ (first on the left name, then on the right
-- name) until neither is bound to another variable, and then compares
-- the resulting names for equality.
--
-- NOTE(review): this assumes such constraint chains are acyclic; a
-- cycle would make the oracle loop forever.
getAreSame :: MonadUnify m => m (VName -> VName -> Bool)
getAreSame = check <$> getConstraints
  where
    check constraints x y =
      case (M.lookup x constraints, M.lookup y constraints) of
        (Just (_, Size (Just (Var x' _ _)) _), _) ->
          check constraints (qualLeaf x') y
        (_, Just (_, Size (Just (Var y' _ _)) _)) ->
          check constraints x (qualLeaf y')
        _ ->
          x == y

-- | Replace specified sizes with distinct fresh size variables.
--
-- A size in the type is replaced when it is a size variable that is
-- the same (according to 'getAreSame') as any of the given names.
-- Every such occurrence receives its own fresh flexible size variable,
-- created with the given description and a usage at the given
-- location.  All other sizes are left untouched.
someDimsFreshInType ::
  SrcLoc ->
  Name ->
  [VName] ->
  TypeBase Size als ->
  TermTypeM (TypeBase Size als)
someDimsFreshInType loc desc fresh t = do
  areSameSize <- getAreSame
  let shouldFreshen v = any (areSameSize v) fresh
  bitraverse (onDim shouldFreshen) pure t
  where
    -- Only size variables are candidates for replacement; no
    -- memoisation, so repeated occurrences get different variables.
    onDim shouldFreshen (Var d _ _)
      | shouldFreshen $ qualLeaf d = do
          v <- newFlexibleDim (mkUsage' loc) desc
          pure $ sizeFromName (qualName v) loc
    onDim _ d = pure d

-- | Replace the specified sizes with fresh size variables of the
-- specified rigidity.  Returns the new fresh size variables.
--
-- A size is replaced when it is a size variable that is the same
-- (according to 'getAreSame') as any of the given names.  A size that
-- is the same as one already replaced reuses that replacement, so
-- each equivalence class encountered gets a single new variable.  The
-- returned list of fresh variables is in reverse order of creation.
freshDimsInType ::
  Usage ->
  Rigidity ->
  Name ->
  [VName] ->
  TypeBase Size u ->
  TermTypeM (TypeBase Size u, [VName])
freshDimsInType usage r desc fresh t = do
  areSameSize <- getAreSame
  -- The state is a list of (replaced size, fresh variable) pairs;
  -- only the fresh variables are returned to the caller.
  second (map snd) <$> runStateT (bitraverse (onDim areSameSize) pure t) mempty
  where
    onDim areSameSize (Var (QualName _ d) _ _)
      | any (areSameSize d) fresh = do
          prev_subst <- gets $ L.find (areSameSize d . fst)
          case prev_subst of
            Just (_, d') -> pure $ sizeFromName (qualName d') $ srclocOf usage
            Nothing -> do
              v <- lift $ newDimVar usage r desc
              modify ((d, v) :)
              pure $ sizeFromName (qualName v) $ srclocOf usage
    onDim _ d = pure d

-- | Where a value that is matched against the loop parameter pattern
-- comes from: the initial loop values ('Initial') or the result of
-- the loop body ('BodyResult').  In this module it only selects the
-- error context and usage description used by 'wellTypedLoopArg'.
data ArgSource = Initial | BodyResult

-- | Check that an expression can be passed as the value of a loop
-- parameter pattern, much like checking an argument against a
-- function parameter.
--
-- Before unifying, every size in the pattern type that is the same as
-- one of the given size parameters is replaced with a fresh nonrigid
-- size variable (see 'freshDimsInType'), so those sizes may be
-- instantiated by this particular argument.  The 'ArgSource' only
-- determines the failure context and the usage description.
wellTypedLoopArg :: ArgSource -> [VName] -> Pat ParamType -> Exp -> TermTypeM ()
wellTypedLoopArg src sparams pat arg = do
  (merge_t, _) <-
    freshDimsInType (mkUsage arg desc) Nonrigid "loop" sparams $
      toStruct (patternType pat)
  arg_t <- toStruct <$> expTypeFully arg
  onFailure (checking merge_t arg_t) $
    unify (mkUsage arg desc) merge_t arg_t
  where
    (checking, desc) =
      case src of
        Initial -> (CheckingLoopInitial, "matching initial loop values to pattern")
        BodyResult -> (CheckingLoopBody, "matching loop body to pattern")

-- | An un-checked loop: the loop parameter (merge) pattern, the
-- expression providing the initial loop values, the loop form, and
-- the loop body.
type UncheckedLoop =
  (UncheckedPat ParamType, UncheckedExp, LoopFormBase NoInfo Name, UncheckedExp)

-- | A loop that has been type-checked: the size parameters inferred
-- for the loop parameter pattern, the checked loop parameter pattern,
-- the initial-value expression, the loop form, and the loop body.
type CheckedLoop =
  ([VName], Pat ParamType, Exp, LoopFormBase Info VName, Exp)

-- | Type-check a @loop@ expression, passing in a function for
-- type-checking subexpressions.
--
-- The result contains the checked loop, prefixed with the size
-- parameters inferred for the loop parameter pattern, together with
-- an 'AppRes' giving the type of the loop as a whole and the fresh
-- rigid sizes that were invented for those size parameters.
checkDoLoop ::
  (UncheckedExp -> TermTypeM Exp) ->
  UncheckedLoop ->
  SrcLoc ->
  TermTypeM (CheckedLoop, AppRes)
checkDoLoop checkExp (mergepat, mergeexp, form, loopbody) loc = do
  mergeexp' <- checkExp mergeexp
  -- Variables that were already constrained before we started on the
  -- loop; these are never turned into loop size parameters below.
  known_before <- M.keysSet <$> getConstraints
  zeroOrderType
    (mkUsage mergeexp "use as loop variable")
    "type used as loop variable"
    . toStruct
    =<< expTypeFully mergeexp'

  -- The handling of dimension sizes is a bit intricate, but very
  -- similar to checking a function, followed by checking a call to
  -- it.  The overall procedure is as follows:
  --
  -- (1) All empty dimensions in the merge pattern are instantiated
  -- with nonrigid size variables.  All explicitly specified
  -- dimensions are preserved.
  --
  -- (2) The body of the loop is type-checked.  The result type is
  -- combined with the merge pattern type to determine which sizes are
  -- variant, and these are turned into size parameters for the merge
  -- pattern.
  --
  -- (3) We now conceptually have a function parameter type and
  -- return type.  We check that it can be called with the body type
  -- as argument.
  --
  -- (4) Similarly to (3), we check that the "function" can be
  -- called with the initial merge values as argument.  The result
  -- of this is the type of the loop as a whole.

  (merge_t, new_dims_map) <-
    -- dim handling (1)
    allDimsFreshInType (mkUsage loc "loop parameter type inference") Nonrigid "loop_d"
      =<< expTypeFully mergeexp'
  let new_dims_to_initial_dim = M.toList new_dims_map
      new_dims = map fst new_dims_to_initial_dim

  -- dim handling (2)
  let checkLoopReturnSize :: Pat ParamType -> Exp -> TermTypeM ([VName], Pat ParamType)
      checkLoopReturnSize mergepat' loopbody' = do
        loopbody_t <- expTypeFully loopbody'
        pat_t <-
          someDimsFreshInType loc "loop" new_dims
            =<< normTypeFully (patternType mergepat')

        -- We are ignoring the dimensions here, because any mismatches
        -- should be turned into fresh size variables.
        onFailure (CheckingLoopBody (toStruct pat_t) (toStruct loopbody_t)) $
          unify
            (mkUsage loopbody "matching loop body to loop pattern")
            (toStruct pat_t)
            (toStruct loopbody_t)

        -- Figure out which of the 'new_dims' dimensions are variant.
        -- This works because we know that each dimension from
        -- new_dims in the pattern is unique and distinct.
        --
        -- For every pair of differing sizes (merge type, body type),
        -- each free variable of the merge size that belongs to
        -- new_dims is either recorded as invariant (substituted by its
        -- initial size) when the body size equals that initial size,
        -- or else becomes a size parameter unless it is in
        -- known_before.
        areSameSize <- getAreSame
        let onDims _ x y
              | x == y = pure x
            onDims _ e d = do
              forM_ (fvVars $ freeInExp e) $ \v ->
                case L.find (areSameSize v . fst) new_dims_to_initial_dim of
                  Just (_, e') ->
                    if e' == d
                      then modify $ first $ M.insert v $ ExpSubst e'
                      else
                        unless (v `S.member` known_before) $
                          modify (second (v :))
                  _ ->
                    pure ()
              pure e
        loopbody_t' <- normTypeFully loopbody_t
        merge_t' <- normTypeFully merge_t

        let (init_substs, sparams) =
              execState (matchDims onDims merge_t' loopbody_t') mempty

        -- Make sure that any of new_dims that are invariant will be
        -- replaced with the invariant size in the loop body.  Failure
        -- to do this can cause type annotations to still refer to
        -- new_dims.
        let dimToInit (v, ExpSubst e) =
              constrain v $ Size (Just e) (mkUsage loc "size of loop parameter")
            dimToInit _ =
              pure ()
        mapM_ dimToInit $ M.toList init_substs

        mergepat'' <- applySubst (`M.lookup` init_substs) <$> updateTypes mergepat'

        -- Eliminate those new_dims that turned into sparams so it won't
        -- look like we have ambiguous sizes lying around.  Removing the
        -- keys via a set avoids a linear list scan per constraint.
        modifyConstraints $ \constraints ->
          M.withoutKeys constraints (S.fromList sparams)

        -- dim handling (3)
        --
        -- The only trick here is that we have to turn any instances
        -- of loop parameters in the type of loopbody' rigid,
        -- because we are no longer in a position to change them,
        -- really.
        wellTypedLoopArg BodyResult sparams mergepat'' loopbody'

        pure (nubOrd sparams, mergepat'')

  -- Shared by all loop forms: check the body with the merge pattern
  -- already in scope, infer the size parameters, and package the
  -- results together with the already-checked loop form.
  let checkLoopBody ::
        Pat ParamType ->
        LoopFormBase Info VName ->
        TermTypeM ([VName], Pat ParamType, LoopFormBase Info VName, Exp)
      checkLoopBody mergepat' form' = do
        loopbody' <- checkExp loopbody
        (sparams, mergepat'') <- checkLoopReturnSize mergepat' loopbody'
        pure (sparams, mergepat'', form', loopbody')

  (sparams, mergepat', form', loopbody') <-
    case form of
      For i uboundexp -> do
        uboundexp' <-
          require "being the bound in a 'for' loop" anySignedType
            =<< checkExp uboundexp
        bound_t <- expTypeFully uboundexp'
        bindingIdent i bound_t $ \i' ->
          bindingPat [] mergepat merge_t $ \mergepat' ->
            incLevel $ checkLoopBody mergepat' (For i' uboundexp')
      ForIn xpat e -> do
        (arr_t, _) <- newArrayType (mkUsage' (srclocOf e)) "e" 1
        e' <- unifies "being iterated in a 'for-in' loop" arr_t =<< checkExp e
        t <- expTypeFully e'
        case peelArray 1 t of
          Just t' ->
            bindingPat [] xpat t' $ \xpat' ->
              bindingPat [] mergepat merge_t $ \mergepat' ->
                incLevel $
                  checkLoopBody mergepat' (ForIn (fmap toStruct xpat') e')
          Nothing ->
            typeError (srclocOf e) mempty $
              "Iteratee of a for-in loop must be an array, but expression has type"
                <+> pretty t
      While cond ->
        bindingPat [] mergepat merge_t $ \mergepat' ->
          incLevel $ do
            cond' <-
              checkExp cond
                >>= unifies "being the condition of a 'while' loop" (Scalar $ Prim Bool)
            checkLoopBody mergepat' (While cond')

  -- dim handling (4)
  wellTypedLoopArg Initial sparams mergepat' mergeexp'

  (loopt, retext) <-
    freshDimsInType
      (mkUsage loc "inference of loop result type")
      (Rigid RigidLoop)
      "loop"
      sparams
      (patternType mergepat')
  pure
    ( (sparams, mergepat', mergeexp', form', loopbody'),
      AppRes (toStruct loopt) retext
    )