{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Futhark.Internalise.Defunctionalise
( transformProg ) where
import Control.Arrow (first, second)
import Control.Monad.RWS hiding (Sum)
import Data.Bifunctor hiding (first, second)
import Data.Foldable
import Data.List
import qualified Data.List.NonEmpty as NE
import Data.Loc
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified Data.Sequence as Seq
import Futhark.MonadFreshNames
import Language.Futhark
import Futhark.Representation.AST.Pretty ()
-- | The body of a static lambda: either an ordinary expression, or a
-- not-yet-defunctionalised lambda over the remaining parameters (this
-- is what 'defuncFun' produces when a lambda has more than one
-- parameter).
data ExtExp = ExtLambda [TypeParam] [Pattern] Exp (Aliasing, StructType) SrcLoc
-- ^ Type parameters, parameter patterns, body, (closure aliasing,
-- return type) and source location of the remaining lambda.
| ExtExp Exp
-- ^ A plain expression.
deriving (Show)
-- | A static value: what the defunctionaliser knows at compile time
-- about the value an expression evaluates to.
data StaticVal = Dynamic PatternType
-- ^ Nothing is known statically beyond the type.
| LambdaSV [VName] Pattern StructType ExtExp Env
-- ^ A statically known function: dimension names, the (first)
-- parameter pattern, the type annotated for the body, the body
-- itself, and the closure environment.
| RecordSV [(Name, StaticVal)]
-- ^ A record (or tuple) with a static value for each field.
| SumSV Name [StaticVal] [(Name, [PatternType])]
-- ^ A known constructor with the static values of its payload,
-- plus the payload types of the other constructors of the sum type.
| DynamicFun (Exp, StaticVal) StaticVal
-- ^ A function that also exists as a top-level definition: the
-- closure representation (expression and its static value), and the
-- static value describing the result of applying it.
| IntrinsicSV
-- ^ Returned by 'lookupVar' for names not in the environment whose
-- tag is at most 'maxIntrinsicTag'.
deriving (Show)
-- | Environment mapping variable names to their static values.
type Env = M.Map VName StaticVal
-- | Run an action with extra bindings added to the local environment.
-- When a name is bound in both, the new binding shadows the old one.
localEnv :: Env -> DefM a -> DefM a
localEnv env = local $ \(globals, old_env) -> (globals, env <> old_env)
-- | Run an action in an environment consisting of the given bindings
-- plus only those existing bindings whose names are registered as
-- global.  Retained global bindings take precedence on a name clash.
localNewEnv :: Env -> DefM a -> DefM a
localNewEnv env = local replaceEnv
  where
    replaceEnv (globals, old_env) =
      (globals, M.restrictKeys old_env globals <> env)
-- | Run an action with a single additional binding in the environment.
extendEnv :: VName -> StaticVal -> DefM a -> DefM a
extendEnv vn = localEnv . M.singleton vn
-- | Fetch the current environment (the second component of the reader
-- state).
askEnv :: DefM Env
askEnv = snd <$> ask
-- | Run an action with the given name added to the set of global names.
isGlobal :: VName -> DefM a -> DefM a
isGlobal v = local $ \(globals, env) -> (S.insert v globals, env)
-- | Restrict the current environment to the names in the given set,
-- leaving out any name registered as global.  The uniqueness recorded
-- for each name is pushed into its static value: a 'Dynamic' value
-- that is used non-uniquely has its type made non-unique.
restrictEnvTo :: NameSet -> DefM Env
restrictEnvTo (NameSet m) = asks restrict
  where
    restrict (globals, env) = M.mapMaybeWithKey (keepLocal globals) env

    -- Globals are never captured; other names are kept only when they
    -- occur in the requested name set.
    keepLocal globals k sv
      | k `S.member` globals = Nothing
      | otherwise = fmap (`restrict'` sv) (M.lookup k m)

    restrict' u sv = case sv of
      Dynamic t
        | Nonunique <- u -> Dynamic $ t `setUniqueness` Nonunique
        | otherwise -> Dynamic t
      LambdaSV dims pat t e env ->
        LambdaSV dims pat t e $ M.map (restrict' u) env
      RecordSV fields ->
        RecordSV [(f, restrict' u fsv) | (f, fsv) <- fields]
      SumSV c svs fields ->
        SumSV c (map (restrict' u) svs) fields
      DynamicFun (e, sv1) sv2 ->
        DynamicFun (e, restrict' u sv1) $ restrict' u sv2
      IntrinsicSV -> IntrinsicSV
-- | The defunctionalisation monad.  The reader component is the set of
-- global names together with the current environment; the writer
-- accumulates value bindings (emitted by 'liftValDec'); the state is
-- the name source used for generating fresh names.
newtype DefM a = DefM (RWS (S.Set VName, Env) (Seq.Seq ValBind) VNameSource a)
deriving (Functor, Applicative, Monad,
MonadReader (S.Set VName, Env),
MonadWriter (Seq.Seq ValBind),
MonadFreshNames)
-- | Run a 'DefM' action from the given name source, starting with no
-- global names and an empty environment.  Returns the result, the
-- final name source, and the value bindings written.
runDefM :: VNameSource -> DefM a -> (a, VNameSource, Seq.Seq ValBind)
runDefM src (DefM m) = runRWS m (S.empty, M.empty) src
-- | Run an action and return the value bindings it wrote alongside its
-- result, removing them from the writer output of the enclosing
-- computation.
collectFuns :: DefM a -> DefM (a, Seq.Seq ValBind)
collectFuns m = censor (const mempty) (listen m)
-- | Look up the static value of a variable.  A name missing from the
-- environment is treated as an intrinsic if its tag is at most
-- 'maxIntrinsicTag'; otherwise it is an internal error.
lookupVar :: SrcLoc -> VName -> DefM StaticVal
lookupVar loc x = askEnv >>= maybe notInEnv return . M.lookup x
  where
    notInEnv
      | baseTag x <= maxIntrinsicTag = return IntrinsicSV
      | otherwise =
          error $ "Variable " ++ pretty x ++ " at "
               ++ locStr loc ++ " is out of scope."
-- | Defunctionalise a function given by its type parameters, parameter
-- patterns, body, (closure aliasing, return type) and location.  The
-- result is a record literal holding the free variables of the
-- function (its closure) and a 'LambdaSV' over the first parameter.
-- When there are several parameters, the body of the 'LambdaSV' is an
-- 'ExtLambda' over the remaining ones.
defuncFun :: [TypeParam] -> [Pattern] -> Exp -> (Aliasing, StructType) -> SrcLoc
          -> DefM (Exp, StaticVal)
defuncFun tparams pats e0 (closure, ret) loc = do
  when (any isTypeParam tparams) $
    error $ "Received a lambda with type parameters at " ++ locStr loc
         ++ ", but the defunctionalizer expects a monomorphic input program."

  -- Split off the first parameter; dimensions bound by it stay with
  -- it, the rest go with the remaining parameters.
  let (dims, pat, ret', body) = case pats of
        [] -> error "Received a lambda with no parameters."
        [only_pat] -> (map typeParamName tparams, only_pat, ret, ExtExp e0)
        first_pat : more_pats ->
          let boundByFirst = (`S.member` patternDimNames first_pat) . typeParamName
              (first_dims, other_dims) = partition boundByFirst tparams
          in ( map typeParamName first_dims
             , first_pat
             , foldFunType (map (toStruct . patternType) more_pats) ret
             , ExtLambda other_dims more_pats e0 (closure, ret) loc )

  -- The closure consists of the free variables of the whole lambda,
  -- minus its own type parameters.
  let whole_lambda = Lambda pats e0 Nothing (Info (closure, ret)) loc
      own_tparams = foldMap (oneName . typeParamName) tparams
  env <- restrictEnvTo (freeVars whole_lambda `without` own_tparams)

  let (fields, env') = unzip $ map closureFromDynamicFun $ M.toList env
  return (RecordLit fields loc, LambdaSV dims pat ret' body $ M.fromList env')
  where
    -- A captured dynamic function is stored via its closure
    -- expression; anything else is stored as a plain variable.
    closureFromDynamicFun (vn, sv) = case sv of
      DynamicFun (clsr_env, clsr_sv) _ ->
        (closureField clsr_env, (vn, clsr_sv))
      _ ->
        (closureField (Var (qualName vn) (Info (typeFromSV sv)) noLoc), (vn, sv))
      where
        closureField e = RecordFieldExplicit (nameFromString $ pretty vn) e noLoc
-- | Defunctionalise an expression.  Returns the transformed expression
-- together with a static value describing what is known about it.
defuncExp :: Exp -> DefM (Exp, StaticVal)
-- Literals are unchanged and fully dynamic.
defuncExp e@Literal{} =
return (e, Dynamic $ typeOf e)
defuncExp e@IntLit{} =
return (e, Dynamic $ typeOf e)
defuncExp e@FloatLit{} =
return (e, Dynamic $ typeOf e)
defuncExp (Parens e loc) = do
(e', sv) <- defuncExp e
return (Parens e' loc, sv)
defuncExp (QualParens qn e loc) = do
(e', sv) <- defuncExp e
return (QualParens qn e' loc, sv)
-- A tuple becomes a record static value with fields named "1", "2", ...
defuncExp (TupLit es loc) = do
(es', svs) <- unzip <$> mapM defuncExp es
return (TupLit es' loc, RecordSV $ zip fields svs)
where fields = map (nameFromString . show) [(1 :: Int) ..]
defuncExp (RecordLit fs loc) = do
(fs', names_svs) <- unzip <$> mapM defuncField fs
return (RecordLit fs' loc, RecordSV names_svs)
where defuncField (RecordFieldExplicit vn e loc') = do
(e', sv) <- defuncExp e
return (RecordFieldExplicit vn e' loc', (vn, sv))
defuncField (RecordFieldImplicit vn _ loc') = do
sv <- lookupVar loc' vn
case sv of
-- An implicit field naming a dynamic function is replaced by an
-- explicit field holding the closure expression.
DynamicFun (e, sv') _ -> let vn' = baseName vn
in return (RecordFieldExplicit vn' e loc',
(vn', sv'))
_ -> let tp = Info $ typeFromSV sv
in return (RecordFieldImplicit vn tp loc', (baseName vn, sv))
-- Array literals and ranges are always dynamic; only the transformed
-- subexpressions are kept.
defuncExp (ArrayLit es t@(Info t') loc) = do
es' <- mapM defuncExp' es
return (ArrayLit es' t loc, Dynamic t')
defuncExp (Range e1 me incl t@(Info t') loc) = do
e1' <- defuncExp' e1
me' <- mapM defuncExp' me
incl' <- mapM defuncExp' incl
return (Range e1' me' incl' t loc, Dynamic t')
defuncExp e@(Var qn _ loc) = do
sv <- lookupVar loc (qualLeaf qn)
case sv of
-- A dynamic function used as a value is replaced by its closure.
DynamicFun closure _ -> return closure
-- An intrinsic used as a value is eta-expanded into a lambda, which
-- is then defunctionalised.
IntrinsicSV -> do
(pats, body, tp) <- etaExpand e
defuncExp $ Lambda pats body Nothing (Info (mempty, tp)) noLoc
-- Otherwise the variable is retyped according to its static value.
_ -> let tp = typeFromSV sv
in return (Var qn (Info tp) loc, sv)
-- A type ascription is kept only when the ascribed expression is of
-- order zero; otherwise it is dropped.
defuncExp (Ascript e0 tydecl t loc)
| orderZero (typeOf e0) = do (e0', sv) <- defuncExp e0
return (Ascript e0' tydecl t loc, sv)
| otherwise = defuncExp e0
defuncExp (LetPat pat e1 e2 _ loc) = do
(e1', sv1) <- defuncExp e1
let env = matchPatternSV pat sv1
pat' = updatePattern pat sv1
(e2', sv2) <- localEnv env $ defuncExp e2
return (LetPat pat' e1' e2' (Info $ typeOf e2') loc, sv2)
-- A local function becomes a let-binding of its closure.
defuncExp (LetFun vn (dims, pats, _, Info ret, e1) e2 loc) = do
(e1', sv1) <- defuncFun dims pats e1 (mempty, ret) loc
(e2', sv2) <- localEnv (M.singleton vn sv1) $ defuncExp e2
return (LetPat (Id vn (Info (typeOf e1')) loc) e1' e2' (Info $ typeOf e2') loc,
sv2)
-- The static value of a conditional is taken from the 'then' branch;
-- those of the condition and the 'else' branch are discarded.
defuncExp (If e1 e2 e3 tp loc) = do
(e1', _ ) <- defuncExp e1
(e2', sv) <- defuncExp e2
(e3', _ ) <- defuncExp e3
return (If e1' e2' e3' tp loc, sv)
-- Application of an intrinsic to a tuple literal: the tuple components
-- are processed with 'defuncSoacExp' and the result is dynamic.
defuncExp e@(Apply f@(Var f' _ _) arg d t loc)
| baseTag (qualLeaf f') <= maxIntrinsicTag,
TupLit es tuploc <- arg = do
es' <- mapM defuncSoacExp es
return (Apply f (TupLit es' tuploc) d t loc,
Dynamic $ typeOf e)
-- Any other application starts at depth 0.
defuncExp e@Apply{} = defuncApply 0 e
defuncExp (Negate e0 loc) = do
(e0', sv) <- defuncExp e0
return (Negate e0' loc, sv)
defuncExp (Lambda pats e0 _ (Info (closure, ret)) loc) =
defuncFun [] pats e0 (closure, ret) loc
-- Sections are not handled here and raise an error.
defuncExp OpSection{} = error "defuncExp: unexpected operator section."
defuncExp OpSectionLeft{} = error "defuncExp: unexpected operator section."
defuncExp OpSectionRight{} = error "defuncExp: unexpected operator section."
defuncExp ProjectSection{} = error "defuncExp: unexpected projection section."
-- NOTE(review): the message below says "projection section" although
-- the constructor is IndexSection -- likely a copy-paste slip.
defuncExp IndexSection{} = error "defuncExp: unexpected projection section."
defuncExp (DoLoop pat e1 form e3 loc) = do
(e1', sv1) <- defuncExp e1
let env1 = matchPatternSV pat sv1
-- Only a 'While' condition is processed with the loop pattern's
-- variables (env1) in scope.
(form', env2) <- case form of
For v e2 -> do e2' <- defuncExp' e2
return (For v e2', envFromIdent v)
ForIn pat2 e2 -> do e2' <- defuncExp' e2
return (ForIn pat2 e2', envFromPattern pat2)
While e2 -> do e2' <- localEnv env1 $ defuncExp' e2
return (While e2', mempty)
(e3', sv) <- localEnv (env1 <> env2) $ defuncExp e3
return (DoLoop pat e1' form' e3' loc, sv)
where envFromIdent (Ident vn (Info tp) _) =
M.singleton vn $ Dynamic tp
-- A binary operator is rewritten as two nested applications of the
-- operator and handled as an ordinary application.
defuncExp (BinOp qn (Info t) (e1, Info pt1) (e2, Info pt2) (Info ret) loc) =
defuncExp $ Apply (Apply (Var qn (Info t) loc)
e1 (Info (diet pt1)) (Info (Scalar $ Arrow mempty Unnamed (fromStruct pt2) ret)) loc)
e2 (Info (diet pt2)) (Info ret) loc
defuncExp (Project vn e0 tp@(Info tp') loc) = do
(e0', sv0) <- defuncExp e0
case sv0 of
-- Projection from a static record selects the field's static value.
RecordSV svs -> case lookup vn svs of
Just sv -> return (Project vn e0' (Info $ typeFromSV sv) loc, sv)
Nothing -> error "Invalid record projection."
Dynamic _ -> return (Project vn e0' tp loc, Dynamic tp')
_ -> error $ "Projection of an expression with static value " ++ show sv0
defuncExp (LetWith id1 id2 idxs e1 body t loc) = do
e1' <- defuncExp' e1
sv1 <- lookupVar (identSrcLoc id2) $ identName id2
idxs' <- mapM defuncDimIndex idxs
-- The new name id1 gets the static value of the source name id2.
(body', sv) <- extendEnv (identName id1) sv1 $ defuncExp body
return (LetWith id1 id2 idxs' e1' body' t loc, sv)
defuncExp expr@(Index e0 idxs info loc) = do
e0' <- defuncExp' e0
idxs' <- mapM defuncDimIndex idxs
return (Index e0' idxs' info loc, Dynamic $ typeOf expr)
defuncExp (Update e1 idxs e2 loc) = do
(e1', sv) <- defuncExp e1
idxs' <- mapM defuncDimIndex idxs
e2' <- defuncExp' e2
return (Update e1' idxs' e2' loc, sv)
defuncExp (RecordUpdate e1 fs e2 _ loc) = do
(e1', sv1) <- defuncExp e1
(e2', sv2) <- defuncExp e2
let sv = staticField sv1 sv2 fs
return (RecordUpdate e1' fs e2' (Info $ typeFromSV sv1) loc,
sv)
-- staticField follows the field path into the record's static value
-- and replaces the static value at the end of the path with sv2.
where staticField (RecordSV svs) sv2 (f:fs') =
case lookup f svs of
Just sv -> RecordSV $
(f, staticField sv sv2 fs') : filter ((/=f) . fst) svs
Nothing -> error "Invalid record projection."
staticField (Dynamic t@(Scalar Record{})) sv2 fs'@(_:_) =
staticField (svFromType t) sv2 fs'
staticField _ sv2 _ = sv2
defuncExp (Unsafe e1 loc) = do
(e1', sv) <- defuncExp e1
return (Unsafe e1' loc, sv)
defuncExp (Assert e1 e2 desc loc) = do
(e1', _) <- defuncExp e1
(e2', sv) <- defuncExp e2
return (Assert e1' e2' desc loc, sv)
defuncExp (Constr name es (Info (Scalar (Sum all_fs))) loc) = do
(es', svs) <- unzip <$> mapM defuncExp es
-- The other constructors' payload types have every function type
-- replaced by the empty record (see defuncScalar below).
let sv = SumSV name svs $ M.toList $
name `M.delete` M.map (map defuncType) all_fs
return (Constr name es' (Info (typeFromSV sv)) loc, sv)
where defuncType :: Monoid als =>
TypeBase (DimDecl VName) als
-> TypeBase (DimDecl VName) als
defuncType (Array as u t shape) = Array as u (defuncScalar t) shape
defuncType (Scalar t) = Scalar $ defuncScalar t
defuncScalar :: Monoid als =>
ScalarTypeBase (DimDecl VName) als
-> ScalarTypeBase (DimDecl VName) als
defuncScalar (Record fs) = Record $ M.map defuncType fs
defuncScalar Arrow{} = Record mempty
defuncScalar (Sum fs) = Sum $ M.map (map defuncType) fs
defuncScalar (Prim t) = Prim t
defuncScalar (TypeVar as u tn targs) = TypeVar as u tn targs
-- A constructor whose annotated type is not a sum type is an error.
defuncExp (Constr name _ (Info t) loc) =
error $ "Constructor " ++ pretty name ++ " given type " ++
pretty t ++ " at " ++ locStr loc
defuncExp (Match e cs t loc) = do
(e', sv) <- defuncExp e
csPairs <- mapM (defuncCase sv) cs
-- The static value of the match is that of its first case.
let cs' = fmap fst csPairs
sv' = snd $ NE.head csPairs
return (Match e' cs' t loc, sv')
-- | Like 'defuncExp', but discards the static value.
defuncExp' :: Exp -> DefM Exp
defuncExp' e = fst <$> defuncExp e
-- | Defunctionalise the body of a static lambda: a plain expression
-- goes through 'defuncExp', a pending lambda through 'defuncFun'.
defuncExtExp :: ExtExp -> DefM (Exp, StaticVal)
defuncExtExp ext = case ext of
  ExtExp e ->
    defuncExp e
  ExtLambda tparams pats e0 (closure, ret) loc ->
    defuncFun tparams pats e0 (closure, ret) loc
-- | Defunctionalise one case of a match, given the static value of the
-- scrutinee.  The case body is processed with the pattern's variables
-- bound according to that static value.
defuncCase :: StaticVal -> Case -> DefM (Case, StaticVal)
defuncCase sv (CasePat p e loc) = do
  (e', body_sv) <- localEnv (matchPatternSV p sv) $ defuncExp e
  return (CasePat (updatePattern p sv) e' loc, body_sv)
-- | Defunctionalise a functional argument to an intrinsic (see the
-- intrinsic-application clause of 'defuncExp').  Sections are left
-- untouched and lambdas are kept as lambdas, with only their bodies
-- transformed.  Any other expression of function type is eta-expanded
-- into a lambda first.
defuncSoacExp :: Exp -> DefM Exp
defuncSoacExp expr = case expr of
  OpSection{} -> return expr
  OpSectionLeft{} -> return expr
  OpSectionRight{} -> return expr
  ProjectSection{} -> return expr
  Parens e loc -> do
    e' <- defuncSoacExp e
    return $ Parens e' loc
  Lambda params e0 decl tp loc -> do
    e0' <- localEnv (foldMap envFromPattern params) $ defuncSoacExp e0
    return $ Lambda params e0' decl tp loc
  _ | Scalar Arrow{} <- typeOf expr -> do
        (pats, body, tp) <- etaExpand expr
        body' <- localEnv (foldMap envFromPattern pats) $ defuncExp' body
        return $ Lambda pats body' Nothing (Info (mempty, tp)) noLoc
    | otherwise -> defuncExp' expr
-- | Eta-expand an expression of function type.  Returns one fresh
-- parameter pattern per argument of the type, the expression applied
-- to all of those parameters, and the final return type.
etaExpand :: Exp -> DefM ([Pattern], Exp, StructType)
etaExpand e = do
  (pats, vars) <- unzip <$> mapM freshParam argtypes
  let applied =
        foldl' applyArg e $ zip3 vars argtypes (drop 1 $ tails argtypes)
  return (pats, applied, toStruct ret)
  where
    (argtypes, ret) = getType $ typeOf e

    -- A fresh parameter, as both a binding pattern and a use.
    freshParam t = do
      x <- newNameFromString "x"
      return (Id x (Info t) noLoc, Var (qualName x) (Info t) noLoc)

    -- Apply to one more argument; the type of the application is the
    -- function type over the arguments still to come.
    applyArg f (arg, argtype, remaining) =
      Apply f arg (Info $ diet argtype) (Info (foldFunType remaining ret)) noLoc

    -- Split a function type into argument types and return type.
    getType t = case t of
      Scalar (Arrow _ _ t1 t2) ->
        let (ps, r) = getType t2
        in (t1 : ps, r)
      _ -> ([], t)
-- | Defunctionalise the expressions inside an index or slice,
-- discarding their static values.
defuncDimIndex :: DimIndexBase Info VName -> DefM (DimIndexBase Info VName)
defuncDimIndex idx = case idx of
  DimFix e1 ->
    DimFix <$> defuncExp' e1
  DimSlice me1 me2 me3 -> do
    me1' <- mapM defuncExp' me1
    me2' <- mapM defuncExp' me2
    me3' <- mapM defuncExp' me3
    return $ DimSlice me1' me2' me3'
-- | Defunctionalise the definition of a top-level function.  Leading
-- order-zero parameters are kept as real parameters and the result is
-- a 'DynamicFun'; from the first higher-order parameter onwards the
-- rest of the function is turned into a closure via 'defuncFun'.
-- Returns the retained parameters, the new body and its static value.
defuncLet :: [TypeParam] -> [Pattern] -> Exp -> Info StructType
          -> DefM ([Pattern], Exp, StaticVal)
defuncLet dims params body (Info rettype) = case params of
  [] -> do
    (body', sv) <- defuncExp body
    return ([], body', imposeType sv rettype)
  pat : pats
    | patternOrderZero pat -> do
        let boundByPat = (`S.member` patternDimNames pat) . typeParamName
            rest_dims = filter (not . boundByPat) dims
        (pats', body', sv) <-
          localEnv (envFromPattern pat) $
            defuncLet rest_dims pats body (Info rettype)
        closure <- defuncFun dims params body (mempty, rettype) noLoc
        return (pat : pats', body', DynamicFun closure sv)
    | otherwise -> do
        (e, sv) <- defuncFun dims params body (mempty, rettype) noLoc
        return ([], e, sv)
  where
    -- Make the static value of the body agree with the declared
    -- return type where the static value is dynamic or a record.
    imposeType Dynamic{} t =
      Dynamic $ fromStruct t
    imposeType (RecordSV fs1) (Scalar (Record fs2)) =
      RecordSV $ M.toList $ M.intersectionWith imposeType (M.fromList fs1) fs2
    imposeType sv _ = sv
-- | Defunctionalise an application.  The integer is the application
-- depth: it is 0 for the outermost application and is incremented on
-- each descent into the function position.
defuncApply :: Int -> Exp -> DefM (Exp, StaticVal)
defuncApply depth e@(Apply e1 e2 d t@(Info ret) loc) = do
let (argtypes, _) = unfoldFunType ret
(e1', sv1) <- defuncApply (depth+1) e1
(e2', sv2) <- defuncExp e2
let e' = Apply e1' e2' d t loc
case sv1 of
-- Applying a statically known function: its body is defunctionalised
-- with the argument and the closure bound, emitted as a new
-- top-level function taking the closure and the argument, and the
-- application is replaced by a call to that function.
LambdaSV dims pat e0_t e0 closure_env -> do
let env' = matchPatternSV pat sv2
env_dim = envFromDimNames dims
(e0', sv) <- localNewEnv (env' <> closure_env <> env_dim) $ defuncExtExp e0
let closure_pat = buildEnvPattern closure_env
pat' = updatePattern pat sv2
-- The patterns of lambda-valued function and argument are also
-- consulted when building the return type.
let params = [closure_pat, pat']
params_for_rettype = params ++ svParams sv1 ++ svParams sv2
svParams (LambdaSV _ sv_pat _ _ _) = [sv_pat]
svParams _ = []
rettype = buildRetType closure_env params_for_rettype e0_t $
anyDimShapeAnnotations $ typeOf e0'
-- The lifted function's name is derived from the variable at the
-- head of the application, with the count of enclosing
-- applications in the name.
liftedName i (Var f _ _) =
"lifted_" ++ show i ++ "_" ++ baseString (qualLeaf f)
liftedName i (Apply f _ _ _ _) =
liftedName (i+1) f
liftedName _ _ = "lifted"
fname <- newNameFromString $ liftedName (0::Int) e1
liftValDec fname rettype dims params e0'
let t1 = toStruct $ typeOf e1'
t2 = toStruct $ typeOf e2'
fname' = qualName fname
return (Parens (Apply (Apply (Var fname' (Info (Scalar $ Arrow mempty Unnamed (fromStruct t1) $
Scalar $ Arrow mempty Unnamed (fromStruct t2) rettype)) loc)
e1' (Info Observe) (Info $ Scalar $ Arrow mempty Unnamed (fromStruct t2) rettype) loc)
e2' d (Info rettype) loc) noLoc, sv)
-- Applying a dynamic function: the application is kept, with its
-- type recomputed from the static value of the result.
DynamicFun _ sv ->
let (argtypes', rettype) = dynamicFunType sv argtypes
apply_e = Apply e1' e2' d (Info $ foldFunType argtypes' rettype
`setAliases` aliases ret) loc
in return (apply_e, sv)
-- Applying an intrinsic: only the outermost application (depth 0)
-- yields a dynamic value.
IntrinsicSV
| depth == 0 -> return (e', Dynamic $ typeOf e)
| otherwise -> return (e', IntrinsicSV)
_ -> error $ "Application of an expression that is neither a static lambda "
++ "nor a dynamic function, but has static value: " ++ show sv1
-- The variable at the head of an application.
defuncApply depth e@(Var qn (Info t) loc) = do
let (argtypes, _) = unfoldFunType t
sv <- lookupVar loc (qualLeaf qn)
case sv of
DynamicFun _ _
-- Enough arguments: refer to the function directly.
| fullyApplied sv depth ->
let (argtypes', rettype) = dynamicFunType sv argtypes
in return (Var qn (Info (foldFunType argtypes' rettype)) loc, sv)
-- Too few arguments: emit a new function (see 'liftDynFun')
-- taking only the arguments supplied, and refer to that.
| otherwise -> do
fname <- newName $ qualLeaf qn
let (dims, pats, e0, sv') = liftDynFun sv depth
(argtypes', rettype) = dynamicFunType sv' argtypes
liftValDec fname (fromStruct rettype) dims pats e0
return (Var (qualName fname)
(Info (foldFunType argtypes' $ fromStruct rettype)) loc, sv')
IntrinsicSV -> return (e, IntrinsicSV)
_ -> return (Var qn (Info (typeFromSV sv)) loc, sv)
-- Parentheses in function position are dropped.
defuncApply depth (Parens e _) = defuncApply depth e
defuncApply _ expr = defuncExp expr
-- | Whether a static value is applied to enough arguments, given the
-- number of arguments supplied.  A 'DynamicFun' consumes one argument
-- per nesting level; a 'DynamicFun' with none left is not fully
-- applied.  Any other static value counts as fully applied.
fullyApplied :: StaticVal -> Int -> Bool
fullyApplied sv depth = case sv of
  DynamicFun _ result_sv
    | depth == 0 -> False
    | depth > 0 -> fullyApplied result_sv (depth - 1)
  _ -> True
-- | Convert a dynamic function into the pieces of a new function that
-- takes only the given number of parameters: the dimension names and
-- patterns of those parameters, the body (the closure expression
-- reached at depth 0), and the static value of the new function.
liftDynFun :: StaticVal -> Int -> ([VName], [Pattern], Exp, StaticVal)
liftDynFun sv depth = case (sv, depth) of
  (DynamicFun (e, clsr_sv) _, 0) ->
    ([], [], e, clsr_sv)
  (DynamicFun clsr@(_, LambdaSV dims pat _ _ _) result_sv, d)
    | d > 0 ->
        let (dims', pats, e', sv') = liftDynFun result_sv (d - 1)
        in (dims ++ dims', pat : pats, e', DynamicFun clsr sv')
  _ ->
    error $ "Tried to lift a StaticVal " ++ show sv
         ++ ", but expected a dynamic function."
-- | The environment binding every variable of a pattern to a
-- 'Dynamic' static value of the variable's type.
envFromPattern :: Pattern -> Env
envFromPattern (TuplePattern ps _) = foldMap envFromPattern ps
envFromPattern (RecordPattern fs _) = foldMap (envFromPattern . snd) fs
envFromPattern (PatternParens p _) = envFromPattern p
envFromPattern (Id vn (Info t) _) = M.singleton vn $ Dynamic t
envFromPattern Wildcard{} = mempty
envFromPattern (PatternAscription p _ _) = envFromPattern p
envFromPattern PatternLit{} = mempty
envFromPattern (PatternConstr _ _ ps _) = foldMap envFromPattern ps
-- | The environment binding each dimension parameter as a dynamic
-- value.  Any type parameter that is not a dimension is an error,
-- since the input is expected to be monomorphic.
envFromShapeParams :: [TypeParamBase VName] -> Env
envFromShapeParams tparams = envFromDimNames [dimName tp | tp <- tparams]
  where
    dimName (TypeParamDim vn _) = vn
    dimName tparam =
      error $
        "The defunctionalizer expects a monomorphic input program,\n" ++
        "but it received a type parameter " ++ pretty tparam ++
        " at " ++ locStr (srclocOf tparam) ++ "."
-- | The environment binding each given name to a dynamic value of
-- type signed 32-bit integer.
envFromDimNames :: [VName] -> Env
envFromDimNames dims = M.fromList [(dim, i32) | dim <- dims]
  where
    i32 = Dynamic $ Scalar $ Prim $ Signed Int32
-- | Emit a new top-level value binding with the given name, return
-- type, dimension parameters, parameter patterns and body.  The
-- binding is not an entry point and carries no documentation, return
-- type declaration or source location.
liftValDec :: VName -> PatternType -> [VName] -> [Pattern] -> Exp -> DefM ()
liftValDec fname rettype dims pats body =
  let dec =
        ValBind
          { valBindEntryPoint = Nothing
          , valBindName = fname
          , valBindRetDecl = Nothing
          , valBindRetType = Info $ anyDimShapeAnnotations $ toStruct rettype
          , valBindTypeParams = [TypeParamDim dim noLoc | dim <- dims]
          , valBindParams = pats
          , valBindBody = body
          , valBindDoc = Nothing
          , valBindLocation = noLoc
          }
  in tell (Seq.singleton dec)
-- | Build a record pattern that binds every variable of a closure
-- environment.  Each field is named after the pretty-printed variable
-- and typed according to the variable's static value.
buildEnvPattern :: Env -> Pattern
buildEnvPattern env = RecordPattern fields noLoc
  where
    fields =
      [ (nameFromString (pretty vn), Id vn (Info var_t) noLoc)
      | (vn, sv) <- M.toList env
      , let var_t = anyDimShapeAnnotations $ typeFromSV sv
      ]
-- | Compute the return type of a lifted function from the closure
-- environment, the parameter patterns, the annotated type and the
-- type of the transformed body (in that argument order; see the call
-- in 'defuncApply').
buildRetType :: Env -> [Pattern] -> StructType -> PatternType -> PatternType
buildRetType env pats = comb
-- Names bound by the closure environment or by any of the patterns.
where bound = foldMap oneName (M.keys env) <> foldMap patternVars pats
-- Whether the name is bound by one of the patterns with a unique type.
boundAsUnique v =
maybe False (unique . unInfo . identType) $
find ((==v) . identName) $ S.toList $ foldMap patternIdents pats
-- A bound name that was not bound as unique.
problematic v = (v `member` bound) && not (boundAsUnique v)
-- Records are combined field by field, keeping only fields present in
-- both.  If the first type is a function type the second is used;
-- otherwise the first is used with the uniqueness and aliases of the
-- second.
comb (Scalar (Record fs_annot)) (Scalar (Record fs_got)) =
Scalar $ Record $ M.intersectionWith comb fs_annot fs_got
comb (Scalar Arrow{}) t = descend t
comb got et = descend $ fromStruct got `setUniqueness` uniqueness et `setAliases` aliases et
-- An array that aliases a problematic name is made non-unique; this
-- is applied recursively through record fields.
descend t@Array{}
| any (problematic . aliasVar) (aliases t) = t `setUniqueness` Nonunique
descend (Scalar (Record t)) = Scalar $ Record $ fmap descend t
descend t = t
-- | The type of the run-time representation of a static value.  A
-- static lambda is represented by its closure record, and a dynamic
-- function by the type of its closure.  An intrinsic has no such type
-- and asking for it is an error.
typeFromSV :: StaticVal -> PatternType
typeFromSV sv = case sv of
  Dynamic tp ->
    anyDimShapeAnnotations tp
  LambdaSV _ _ _ _ env ->
    typeFromEnv env
  RecordSV ls ->
    Scalar $ Record $ M.fromList [(f, typeFromSV fsv) | (f, fsv) <- ls]
  DynamicFun (_, clsr_sv) _ ->
    typeFromSV clsr_sv
  SumSV name svs fields ->
    Scalar $ Sum $ M.insert name (map typeFromSV svs) $ M.fromList fields
  IntrinsicSV ->
    error $ "Tried to get the type from the "
         ++ "static value of an intrinsic."
-- | The record type representing a closure environment: one field per
-- variable, named after the pretty-printed variable name.
typeFromEnv :: Env -> PatternType
typeFromEnv env =
  Scalar $ Record $ M.fromList
    [ (nameFromString (pretty vn), typeFromSV sv)
    | (vn, sv) <- M.toList env
    ]
-- | Given the static value of a dynamic function and a list of
-- argument types, keep one argument type per 'DynamicFun' layer and
-- return them with the type of the static value left at the end.
dynamicFunType :: StaticVal -> [PatternType] -> ([PatternType], PatternType)
dynamicFunType sv argtypes = case (sv, argtypes) of
  (DynamicFun _ result_sv, p : ps) ->
    let (ps', ret) = dynamicFunType result_sv ps
    in (p : ps', ret)
  _ -> ([], typeFromSV sv)
-- | Match a pattern against a static value, producing an environment
-- that binds the pattern's variables to the corresponding static
-- values.  A combination of pattern and static value not covered by
-- any clause is an internal error.
matchPatternSV :: PatternBase Info VName -> StaticVal -> Env
-- Tuple components are matched positionally against the record fields.
matchPatternSV (TuplePattern ps _) (RecordSV ls) =
mconcat $ zipWith (\p (_, sv) -> matchPatternSV p sv) ps ls
-- Record patterns are matched by field name, and only when both sides
-- have exactly the same set of field names.
matchPatternSV (RecordPattern ps _) (RecordSV ls)
| ps' <- sortOn fst ps, ls' <- sortOn fst ls,
map fst ps' == map fst ls' =
mconcat $ zipWith (\(_, p) (_, sv) -> matchPatternSV p sv) ps' ls'
matchPatternSV (PatternParens pat _) sv = matchPatternSV pat sv
-- A variable bound to an order-zero static value is recorded as
-- dynamic with the pattern's own type.
matchPatternSV (Id vn (Info t) _) sv =
if orderZeroSV sv
then M.singleton vn $ Dynamic t
else M.singleton vn sv
matchPatternSV (Wildcard _ _) _ = mempty
matchPatternSV (PatternAscription pat _ _) sv = matchPatternSV pat sv
matchPatternSV PatternLit{} _ = mempty
-- For the statically known constructor, the payload's static values
-- are used; for another constructor, the sub-patterns are matched
-- against static values derived from its payload types.
matchPatternSV (PatternConstr c1 _ ps _) (SumSV c2 ls fs)
| c1 == c2 =
mconcat $ zipWith matchPatternSV ps ls
| Just ts <- lookup c1 fs =
mconcat $ zipWith matchPatternSV ps $ map svFromType ts
| otherwise =
error $ "matchPatternSV: missing constructor in type: " ++ pretty c1
matchPatternSV (PatternConstr c1 _ ps _) (Dynamic (Scalar (Sum fs)))
| Just ts <- M.lookup c1 fs =
mconcat $ zipWith matchPatternSV ps $ map svFromType ts
| otherwise =
error $ "matchPatternSV: missing constructor in type: " ++ pretty c1
-- Any other pattern against a dynamic value: retry with the static
-- value that 'svFromType' derives from the type.
matchPatternSV pat (Dynamic t) = matchPatternSV pat $ svFromType t
matchPatternSV pat sv = error $ "Tried to match pattern " ++ pretty pat
++ " with static value " ++ show sv ++ "."
-- | Whether a static value consists only of dynamic values, possibly
-- nested inside records.
orderZeroSV :: StaticVal -> Bool
orderZeroSV sv = case sv of
  Dynamic{} -> True
  RecordSV fields -> all (orderZeroSV . snd) fields
  _ -> False
-- | Update the type annotations of a pattern so that they agree with
-- the static value the pattern is matched against.  Sub-patterns whose
-- annotated type is already of order zero are left as they are;
-- others are retyped with 'typeFromSV'.  A type ascription is kept
-- only when the ascribed type is of order zero.  A combination of
-- pattern and static value not covered by any clause is an internal
-- error.
updatePattern :: Pattern -> StaticVal -> Pattern
updatePattern (TuplePattern ps loc) (RecordSV svs) =
  TuplePattern (zipWith updatePattern ps $ map snd svs) loc
updatePattern (RecordPattern ps loc) (RecordSV svs)
  | ps' <- sortOn fst ps, svs' <- sortOn fst svs =
      RecordPattern (zipWith (\(n, p) (_, sv) ->
                                (n, updatePattern p sv)) ps' svs') loc
updatePattern (PatternParens pat loc) sv =
  PatternParens (updatePattern pat sv) loc
updatePattern pat@(Id vn (Info tp) loc) sv
  | orderZero tp = pat
  | otherwise = Id vn (Info $ typeFromSV sv `setUniqueness` Nonunique) loc
updatePattern pat@(Wildcard (Info tp) loc) sv
  | orderZero tp = pat
  | otherwise = Wildcard (Info $ typeFromSV sv) loc
updatePattern (PatternAscription pat tydecl loc) sv
  | orderZero . unInfo $ expandedType tydecl =
      PatternAscription (updatePattern pat sv) tydecl loc
  | otherwise = updatePattern pat sv
updatePattern p@PatternLit{} _ = p
updatePattern pat@(PatternConstr c1 (Info t) ps loc) sv@(SumSV _ svs _)
  | orderZero t = pat
  | otherwise = PatternConstr c1 (Info t') ps' loc
  where t' = typeFromSV sv `setUniqueness` Nonunique
        ps' = zipWith updatePattern ps svs
updatePattern (PatternConstr c1 _ ps loc) (Dynamic t) =
  PatternConstr c1 (Info t) ps loc
-- Any other pattern against a dynamic value: retry with the static
-- value that 'svFromType' derives from the type.
updatePattern pat (Dynamic t) = updatePattern pat (svFromType t)
updatePattern pat sv =
  -- The leading space in " to reflect" keeps the pretty-printed
  -- pattern from running into the rest of the message.
  error $ "Tried to update pattern " ++ pretty pat
       ++ " to reflect the static value " ++ show sv
-- | Derive a static value from a type: record types become 'RecordSV'
-- field by field, and everything else is 'Dynamic'.
svFromType :: PatternType -> StaticVal
svFromType t = case t of
  Scalar (Record fs) -> RecordSV $ M.toList $ M.map svFromType fs
  _ -> Dynamic t
-- | A set of names, each tagged with a uniqueness.  When two sets are
-- combined, a name present in both gets the greater ('max') of its
-- two uniqueness values.
newtype NameSet = NameSet (M.Map VName Uniqueness)

instance Semigroup NameSet where
  NameSet xs <> NameSet ys = NameSet (M.unionWith max xs ys)

instance Monoid NameSet where
  mempty = NameSet M.empty
-- | Remove from the first set every name that occurs in the second.
without :: NameSet -> NameSet -> NameSet
without (NameSet xs) (NameSet ys) = NameSet (xs M.\\ ys)
-- | Whether a name occurs in a name set.
member :: VName -> NameSet -> Bool
member v (NameSet m) = M.member v m
-- | The singleton name set for an identifier, tagged with the
-- uniqueness of the identifier's type.
ident :: Ident -> NameSet
ident v = NameSet $ M.singleton name u
  where
    name = identName v
    u = uniqueness $ unInfo $ identType v
-- | The singleton name set for a name, tagged as non-unique.
oneName :: VName -> NameSet
oneName v = NameSet (M.singleton v Nonunique)
-- | Convert a plain set of names into a name set in which every name
-- is tagged as non-unique.
names :: S.Set VName -> NameSet
names vs = mconcat [oneName v | v <- S.toList vs]
-- | Compute the set of free variables of an expression, each tagged
-- with a uniqueness.  Names bound inside the expression (by
-- let-patterns, lambda parameters, loop patterns, case patterns and
-- so on) are removed from the result for the parts of the expression
-- where they are in scope.  Dimension names mentioned in binding
-- patterns and type ascriptions count as used.
freeVars :: Exp -> NameSet
freeVars expr = case expr of
  Literal{}            -> mempty
  IntLit{}             -> mempty
  FloatLit{}           -> mempty
  Parens e _           -> freeVars e
  QualParens _ e _     -> freeVars e
  TupLit es _          -> foldMap freeVars es

  RecordLit fs _       -> foldMap freeVarsField fs
    where freeVarsField (RecordFieldExplicit _ e _)  = freeVars e
          freeVarsField (RecordFieldImplicit vn t _) = ident $ Ident vn t noLoc

  ArrayLit es _ _      -> foldMap freeVars es
  Range e me incl _ _  -> freeVars e <> foldMap freeVars me <>
                          foldMap freeVars incl
  Var qn (Info t) _    -> NameSet $ M.singleton (qualLeaf qn) $ uniqueness t
  Ascript e t _ _      -> freeVars e <> names (typeDimNames $ unInfo $ expandedType t)
  LetPat pat e1 e2 _ _ -> freeVars e1 <> ((names (patternDimNames pat) <> freeVars e2)
                                          `without` patternVars pat)

  LetFun vn (_, pats, _, _, e1) e2 _ ->
    ((freeVars e1 <> names (foldMap patternDimNames pats))
     `without` foldMap patternVars pats) <>
    (freeVars e2 `without` oneName vn)

  If e1 e2 e3 _ _           -> freeVars e1 <> freeVars e2 <> freeVars e3
  Apply e1 e2 _ _ _         -> freeVars e1 <> freeVars e2
  Negate e _                -> freeVars e
  Lambda pats e0 _ _ _      -> (names (foldMap patternDimNames pats) <> freeVars e0)
                               `without` foldMap patternVars pats
  OpSection{}                -> mempty
  OpSectionLeft _ _ e _ _ _  -> freeVars e
  OpSectionRight _ _ e _ _ _ -> freeVars e
  ProjectSection{}           -> mempty
  IndexSection idxs _ _      -> foldMap freeDimIndex idxs

  DoLoop pat e1 form e3 _ ->
    let (e2fv, e2ident) = formVars form
    in freeVars e1 <> e2fv <>
       (freeVars e3 `without` (patternVars pat <> e2ident))
    where formVars (For v e2)   = (freeVars e2, ident v)
          formVars (ForIn p e2) = (freeVars e2, patternVars p)
          -- A 'While' condition is evaluated with the loop pattern's
          -- variables in scope (see the 'DoLoop' case of 'defuncExp'),
          -- so those variables are bound inside it, not free.
          formVars (While e2)   = (freeVars e2 `without` patternVars pat, mempty)

  BinOp qn _ (e1, _) (e2, _) _ _ -> oneName (qualLeaf qn) <>
                                    freeVars e1 <> freeVars e2
  Project _ e _ _                -> freeVars e

  LetWith id1 id2 idxs e1 e2 _ _ ->
    ident id2 <> foldMap freeDimIndex idxs <> freeVars e1 <>
    (freeVars e2 `without` ident id1)

  Index e idxs _ _    -> freeVars e <> foldMap freeDimIndex idxs
  Update e1 idxs e2 _ -> freeVars e1 <> foldMap freeDimIndex idxs <> freeVars e2
  RecordUpdate e1 _ e2 _ _ -> freeVars e1 <> freeVars e2

  Unsafe e _          -> freeVars e
  Assert e1 e2 _ _    -> freeVars e1 <> freeVars e2
  Constr _ es _ _     -> foldMap freeVars es
  Match e cs _ _      -> freeVars e <> foldMap caseFV cs
    where caseFV (CasePat p eCase _) = (names (patternDimNames p) <> freeVars eCase)
                                       `without` patternVars p
-- | The free variables of the expressions inside an index or slice.
freeDimIndex :: DimIndexBase Info VName -> NameSet
freeDimIndex idx = case idx of
  DimFix e -> freeVars e
  DimSlice me1 me2 me3 ->
    foldMap freeVars me1 <> foldMap freeVars me2 <> foldMap freeVars me3
-- | The variables bound by a pattern, as a name set tagged with the
-- uniqueness of each variable's type.
patternVars :: Pattern -> NameSet
patternVars pat = foldMap ident (patternIdents pat)
-- | Defunctionalise a top-level value binding.  Returns the new
-- binding, an environment binding its name to its static value, and
-- whether that static value is a 'DynamicFun'.
--
-- An entry point whose return type is a function type is first
-- eta-expanded so that it takes all of its arguments as parameters.
defuncValBind :: ValBind -> DefM (ValBind, Env, Bool)
defuncValBind (ValBind entry@Just{} name _ (Info rettype) tparams params body _ loc)
  | (_ : _, rettype') <- unfoldFunType rettype = do
      (body_pats, body', _) <- etaExpand body
      -- Only constant dimensions are kept in the new return type.
      let keepConstDim dim = case dim of
            ConstDim x -> ConstDim x
            _ -> AnyDim
      defuncValBind $
        ValBind entry name Nothing
                (Info $ bimap keepConstDim id rettype')
                tparams (params <> body_pats) body' Nothing loc
defuncValBind valbind@(ValBind _ name _ rettype tparams params body _ _) = do
  (params', body', sv) <-
    localEnv (envFromShapeParams tparams) $
      defuncLet tparams params body rettype
  -- Keep only the type parameters still mentioned by the parameters.
  let used_dims = foldMap patternDimNames params'
      tparams' = [tp | tp <- tparams, typeParamName tp `S.member` used_dims]
      body_t = anyDimShapeAnnotations $ toStruct $ typeOf body'
      valbind' =
        valbind
          { valBindRetType = Info $ combineTypeShapes (unInfo rettype) body_t
          , valBindTypeParams = tparams'
          , valBindParams = params'
          , valBindBody = body'
          }
  return (valbind', M.singleton name sv, isDynamicFun sv)
  where
    isDynamicFun DynamicFun{} = True
    isDynamicFun _ = False
-- | Defunctionalise a list of value bindings in order.  Each binding
-- is preceded in the result by the functions lifted while processing
-- it, and the bindings after it are processed with its static value
-- in scope.  A binding that turns out to be a dynamic function is
-- also registered as global for the bindings after it.
defuncVals :: [ValBind] -> DefM (Seq.Seq ValBind)
defuncVals [] = return mempty
defuncVals (valbind : rest) = do
  ((valbind', env, dyn), lifted) <- collectFuns $ defuncValBind valbind
  let markGlobal
        | dyn = isGlobal (valBindName valbind')
        | otherwise = id
  rest' <- localEnv env $ markGlobal $ defuncVals rest
  return $ (lifted Seq.|> valbind') <> rest'
-- | Defunctionalise a list of top-level value bindings, drawing fresh
-- names from the surrounding name source.
transformProg :: MonadFreshNames m => [ValBind] -> m [ValBind]
transformProg decs = modifyNameSource run
  where
    run namesrc =
      let (decs', namesrc', liftedDecs) = runDefM namesrc (defuncVals decs)
      in (toList (liftedDecs <> decs'), namesrc')