{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE Safe #-}
module Futhark.Internalise.Bindings
  (
  -- * Internalising bindings
    bindingParams
  , bindingLambdaParams
  , stmPattern
  , MatchPattern
  )
  where

import Control.Monad.State  hiding (mapM)
import Control.Monad.Reader hiding (mapM)
import Control.Monad.Writer hiding (mapM)

import qualified Data.Map.Strict as M
import qualified Data.Set as S

import Language.Futhark as E hiding (matchDims)
import qualified Futhark.IR.SOACS as I
import Futhark.MonadFreshNames
import Futhark.Internalise.Monad
import Futhark.Internalise.TypesValues
import Futhark.Internalise.AccurateSizes
import Futhark.Util

-- | Bind the parameters of a function definition.
--
-- The source parameter patterns are flattened into identifiers, and their
-- types are internalised with the names bound by @tparams@ in scope.  Each
-- internalised parameter type then goes through 'instantiateShapesWithDecls'.
-- The continuation receives:
--
--   * the shape parameters: first one @int32@ parameter per
--     'E.TypeParamDim' in @tparams@, then the unnamed shape parameters
--     produced by 'instantiateShapesWithDecls';
--
--   * the value parameters, regrouped into one group per element of
--     @params@.
--
-- The continuation runs under the following environment changes:
--
--   * all shape and value parameters are in scope;
--   * 'substitutingVars' is applied with a map from each shape parameter
--     name to itself;
--   * the substitutions that 'bindingFlatPattern' installs for the pattern
--     identifiers are active.
bindingParams :: [E.TypeParam] -> [E.Pattern]
              -> ([I.FParam] -> [[I.FParam]] -> InternaliseM a)
              -> InternaliseM a
bindingParams tparams params m = do
  flattened_params <- mapM flattenPattern params
  let (params_idents, params_types) = unzip $ concat flattened_params
      bound = boundInTypes tparams
      param_names = M.fromList [ (E.identName x, y) | (x, y) <- params_idents ]
  params_ts <- internaliseParamTypes bound param_names params_types

  -- For each source parameter: how many flat identifiers it produced, and
  -- how many internal value parameters those identifiers expand to.  The
  -- latter is used to regroup the flat value parameters per source
  -- parameter when calling the continuation.
  let num_param_idents = map length flattened_params
      num_param_ts = map (sum . map length) $ chunks num_param_idents params_ts

  (params_ts', unnamed_shape_params) <-
    unzip <$> mapM (instantiateShapesWithDecls mempty) params_ts

  let named_shape_params =
        [ I.Param v $ I.Prim I.int32 | E.TypeParamDim v _ <- tparams ]
      shape_params = named_shape_params ++ concat unnamed_shape_params
      shape_subst =
        M.fromList [ (I.paramName p, [I.Var $ I.paramName p]) | p <- shape_params ]

  bindingFlatPattern params_idents (concat params_ts') $ \valueparams ->
    let flat_valueparams = concat valueparams
    in I.localScope (I.scopeOfFParams $ shape_params ++ flat_valueparams) $
         substitutingVars shape_subst $
           m shape_params $ chunks num_param_ts flat_valueparams

-- | Bind the parameters of an anonymous function whose internal parameter
-- types are already known.
--
-- The identifiers obtained by flattening @params@ are paired, in order,
-- with the types in @ts@ by 'bindingFlatPattern'.  The source parameter
-- types are also internalised, and 'lambdaShapeSubstitutions' is applied
-- to them together with @ts@.  The resulting substitutions are added to
-- 'envSubsts' while the continuation runs.  They take precedence over the
-- existing entries, including those just added for the pattern
-- identifiers.
--
-- The continuation receives the flat list of lambda parameters, which are
-- also brought into scope.
bindingLambdaParams :: [E.Pattern] -> [I.Type]
                    -> ([I.LParam] -> InternaliseM a)
                    -> InternaliseM a
bindingLambdaParams params ts m = do
  (params_idents, params_types) <- unzip . concat <$> mapM flattenPattern params
  let param_names = M.fromList [ (E.identName x, y) | (x, y) <- params_idents ]
  params_ts <- internaliseParamTypes mempty param_names params_types

  let ascript_substs = lambdaShapeSubstitutions (concat params_ts) ts

  bindingFlatPattern params_idents ts $ \params' ->
    let lparams = concat params'
        -- M.union is left-biased, so these entries win over existing ones.
        addSubsts env = env { envSubsts = ascript_substs `M.union` envSubsts env }
    in local addSubsts $
         I.localScope (I.scopeOfLParams lparams) $
           m lparams

-- | Turn flattened pattern identifiers into internal parameters.
--
-- Each @(ident, name)@ pair is processed in order:
--
--   * The source type of @ident@ is internalised.  If it yields exactly one
--     component, the single parameter is named @name@.  Otherwise each
--     component gets a fresh name derived from @baseString name@.
--
--   * Each resulting name is paired, in order, with the next element of
--     @ts@, which becomes that parameter's decoration.  The internalised
--     type itself only determines how many parameters are made.
--
-- Returns the parameters grouped per identifier, in input order.  It also
-- returns a substitution mapping each source identifier's name to the
-- variables of its parameters.
--
-- It is an error if @ts@ runs out before every parameter has a decoration.
-- Elements of @ts@ left over after the last identifier are ignored.
processFlatPattern :: Show t => [(E.Ident,VName)] -> [t]
                   -> InternaliseM ([[I.Param t]], VarSubstitutions)
processFlatPattern x y = processFlatPattern' [] x y
  where
    -- The accumulator holds (parameters, substitution entry) per
    -- identifier, most recent first.
    processFlatPattern' acc [] _ = do
      let (groups, substs) = unzip acc
      return (reverse groups, M.fromList substs)

    processFlatPattern' acc ((p, name) : rest) ts = do
      (ps, rest_ts) <- handleMapping ts <$> internaliseBindee (p, name)
      let subst = (E.identName p, map (I.Var . I.paramName) ps)
      processFlatPattern' ((ps, subst) : acc) rest rest_ts

    -- Give each bound name the next decoration from the list.  Returns
    -- the new parameters and the decorations that were not consumed.
    handleMapping ts [] = ([], ts)
    handleMapping ts ((vname, _) : rs) =
      let (t, ts') = case ts of
                       t' : more -> (t', more)
                       -- NOTE: despite the wording, this fires when the
                       -- decorations (not the identifiers) are exhausted.
                       [] -> error $ "processFlatPattern: insufficient identifiers in pattern." ++ show (x, y)
          (ps, ts'') = handleMapping ts' rs
      in (I.Param vname t : ps, ts'')

    internaliseBindee :: (E.Ident, VName) -> InternaliseM [(VName, I.DeclExtType)]
    internaliseBindee (bindee, name) = do
      tss <- internaliseParamTypes nothing_bound mempty
             [E.setAliases (E.unInfo $ E.identType bindee) ()]
      case concat tss of
        [t] -> return [(name, t)]
        tss' -> forM tss' $ \t -> do
          name' <- newVName $ baseString name
          return (name', t)

    -- Fixed up later.
    nothing_bound = boundInTypes []

-- | Internalise flattened pattern identifiers with 'processFlatPattern' and
-- run the continuation on the resulting parameter groups (one group per
-- identifier).
--
-- While the continuation runs, each source identifier is substituted by
-- the variables of its parameters.  These substitutions take precedence
-- over entries already in 'envSubsts'.
bindingFlatPattern :: Show t => [(E.Ident, VName)] -> [t]
                   -> ([[I.Param t]] -> InternaliseM a)
                   -> InternaliseM a
bindingFlatPattern idents ts m = do
  (ps, substs) <- processFlatPattern idents ts
  -- M.union is left-biased, so the new substitutions shadow older ones.
  let addSubsts env = env { envSubsts = substs `M.union` envSubsts env }
  local addSubsts $ m ps

-- | Flatten a pattern.  Returns a list of identifiers, each paired with a
-- fresh name derived from the identifier's base name.  The structural type
-- of each identifier is returned separately.
--
-- The cases are handled as follows:
--
--   * Parentheses and type ascriptions are looked through.
--   * Wildcards and literal patterns are bound to a fresh identifier named
--     after @"nameless"@.
--   * Empty tuple and record patterns are treated as a @bool@-typed
--     wildcard.
--   * Record fields are flattened in the order given by 'sortFields'.
--   * Constructor patterns contribute the identifiers of their payload
--     patterns.
flattenPattern :: MonadFreshNames m => E.Pattern -> m [((E.Ident, VName), E.StructType)]
flattenPattern = flatten
  where
    flatten (E.PatternParens p _) = flatten p
    flatten (E.PatternAscription p _ _) = flatten p
    flatten (E.Wildcard t loc) = do
      name <- newVName "nameless"
      flatten $ E.Id name t loc
    flatten (E.Id v (Info t) loc) = do
      new_name <- newVName $ baseString v
      return [((E.Ident v (Info t) loc, new_name), t `E.setAliases` ())]
    -- XXX: treat empty tuples and records as bool.
    flatten (E.TuplePattern [] loc) = flatten $ boolWildcard loc
    flatten (E.RecordPattern [] loc) = flatten $ boolWildcard loc
    flatten (E.TuplePattern pats _) = concat <$> mapM flatten pats
    flatten (E.RecordPattern fs loc) =
      flatten $ E.TuplePattern (map snd $ sortFields $ M.fromList fs) loc
    flatten (E.PatternLit _ t loc) = flatten $ E.Wildcard t loc
    flatten (E.PatternConstr _ _ ps _) = concat <$> mapM flatten ps

    -- Stand-in pattern used for empty tuples and records.
    boolWildcard loc = E.Wildcard (Info $ E.Scalar $ E.Prim E.Bool) loc

-- | The action that 'stmPattern' passes to its continuation: given the
-- source location to report and a list of values, produce (in
-- 'InternaliseM') a list of values.  See 'matchPattern' for how such an
-- action is built from the pattern's types.
type MatchPattern = SrcLoc -> [I.SubExp] -> InternaliseM [I.SubExp]

-- | Bind the variables of a source-language pattern and run a
-- continuation in the scope of the bindings.
--
-- The pattern is flattened with 'flattenPattern'.  Its variables are then
-- bound by 'bindingFlatPattern' against @ts@, after the existential
-- shapes in @ts@ have been instantiated with 'instantiateShapes''.
--
-- The continuation receives:
--
--   * the names of all parameters that 'bindingFlatPattern' created, and
--
--   * a 'MatchPattern' built by 'matchPattern' from the pattern's own
--     declared types, internalised with 'internaliseParamTypes'.
stmPattern :: E.Pattern -> [I.ExtType]
           -> ([VName] -> MatchPattern -> InternaliseM a)
           -> InternaliseM a
stmPattern pat ts m = do
  (pat', pat_types) <- unzip <$> flattenPattern pat
  -- Only the instantiated types are needed; the identifiers that
  -- 'instantiateShapes'' also returns are discarded.
  (ts', _) <- instantiateShapes' ts
  pat_types' <- internaliseParamTypes mempty mempty pat_types
  -- 'matchPattern' works on non-declaration types, so uniqueness is dropped.
  let pat_types'' = map I.fromDecl $ concat pat_types'
      withParams params =
        m (map I.paramName $ concat params) (matchPattern pat_types'')
  bindingFlatPattern pat' ts' withParams

-- | Build a 'MatchPattern' from the types the pattern expects.
--
-- Values are paired positionally with the expected types.  Because this
-- uses 'zip', surplus elements of the longer list are dropped.
--
-- For each pair, the value's type is obtained with 'I.subExpType'.
-- 'unExistentialise' is called with an empty set of names, so it binds
-- nothing; it only re-attaches the expected dimensions to the value's
-- type.  The value is then passed to 'ensureExtShape' together with the
-- error message @"value cannot match pattern"@ and the given location.
-- The results of 'ensureExtShape' are returned in order.
matchPattern :: [I.ExtType] -> MatchPattern
matchPattern exts loc ses =
  forM (zip exts ses) $ \(et, se) -> do
    se_t <- I.subExpType se
    et' <- unExistentialise mempty et se_t
    ensureExtShape (I.ErrorMsg [I.ErrorString "value cannot match pattern"])
      loc et' "correct_shape" se

-- | Give the type @t@ the dimensions of @et@.
--
-- The dimensions of @et@ and @t@ are paired positionally with
-- 'zipWithM'.  The result therefore has as many dimensions as the shorter
-- of the two shapes.
--
-- Consider a dimension of @et@ of the form @Free (Var v)@ where @v@ is a
-- member of @tparam_names@.  For such a dimension, @v@ is bound with
-- 'letBindNames' to the corresponding dimension of @t@.
--
-- In every case the dimension of @et@ itself is kept unchanged in the
-- result, which uses the rest of @t@ (via 'I.setArrayShape').
unExistentialise :: S.Set VName -> I.ExtType -> I.Type -> InternaliseM I.ExtType
unExistentialise tparam_names et t = do
  new_dims <- zipWithM inspectDim (I.shapeDims $ I.arrayShape et) (I.arrayDims t)
  return $ t `I.setArrayShape` I.Shape new_dims
  where
    -- Bind a named size from 'tparam_names' to the concrete size @d@ it
    -- takes here.  If the guard fails, the next equation keeps @ed@ as is.
    inspectDim (I.Free (I.Var v)) d
      | v `S.member` tparam_names = do
          letBindNames [v] $ I.BasicOp $ I.SubExp d
          return $ I.Free $ I.Var v
    inspectDim ed _ = return ed

-- | Instantiate the existential dimensions of @ts@ using
-- 'instantiateShapes'.
--
-- An existential index that is present in @ctx@ becomes a variable
-- referring to that identifier.  Any other index is handled as follows:
--
--   * a fresh identifier named @size@ of type @i32@ is created;
--
--   * it is turned into a non-unique parameter with
--     'nonuniqueParamFromIdent';
--
--   * that parameter is recorded and used as the dimension.
--
-- All recorded parameters are returned alongside the instantiated types,
-- in the order they were created.
--
-- NOTE(review): whether repeated occurrences of the same index share one
-- parameter depends on 'instantiateShapes' (not visible here); confirm.
instantiateShapesWithDecls :: MonadFreshNames m =>
                              M.Map Int I.Ident
                           -> [I.DeclExtType]
                           -> m ([I.DeclType], [I.FParam])
instantiateShapesWithDecls ctx ts =
  runWriterT $ instantiateShapes instantiate ts
  where
    instantiate x
      | Just v <- M.lookup x ctx =
          return $ I.Var $ I.identName v
      | otherwise = do
          v <- lift $ nonuniqueParamFromIdent <$> newIdent "size" (I.Prim I.int32)
          -- Collect the new size parameter so callers can bind it.
          tell [v]
          return $ I.Var $ I.paramName v

-- | Compute variable substitutions that map size variables of the
-- parameter types @param_ts@ to the concrete dimensions of @ts@.
--
-- Types are paired positionally, and so are dimensions within each pair
-- (both via 'zipWith', so surplus elements are ignored).  Each dimension
-- of the form @Free (Var v)@ in a parameter type yields the substitution
-- @v -> [d]@, where @d@ is the corresponding dimension of the concrete
-- type.  Other dimensions yield nothing.
--
-- The individual maps are combined with 'mconcat'.  'Data.Map''s
-- 'Monoid' instance is a left-biased union, so if a variable occurs more
-- than once, its first occurrence wins.
lambdaShapeSubstitutions :: [I.TypeBase I.ExtShape Uniqueness]
                         -> [I.Type]
                         -> VarSubstitutions
lambdaShapeSubstitutions param_ts ts =
  mconcat $ zipWith matchTypes param_ts ts
  where
    matchTypes pt t =
      mconcat $ zipWith matchDims (I.shapeDims $ I.arrayShape pt) (I.arrayDims t)
    -- Only dimensions that are free variables give rise to a substitution.
    matchDims (I.Free (I.Var v)) d = M.singleton v [d]
    matchDims _ _ = mempty

-- | Turn an identifier into a function parameter with the same name and
-- type, marked as 'Nonunique'.
nonuniqueParamFromIdent :: I.Ident -> I.FParam
nonuniqueParamFromIdent (I.Ident name t) =
  I.Param name $ I.toDecl t Nonunique