{-# LANGUAGE FlexibleContexts, FlexibleInstances #-}
module Language.Futhark.TypeChecker.Types
( checkTypeExp
, checkTypeDecl
, unifyTypesU
, subtypeOf
, subuniqueOf
, checkForDuplicateNames
, checkForDuplicateNamesInType
, checkTypeParams
, typeParamToArg
, typeExpUses
, checkShapeParamUses
, TypeSub(..)
, TypeSubs
, substituteTypes
, Subst(..)
, Substitutable(..)
, substTypesAny
)
where
import Control.Monad.Reader
import Control.Monad.Except
import Control.Monad.State
import Data.Bifunctor
import Data.List
import Data.Loc
import Data.Maybe
import Data.Monoid ((<>))
import qualified Data.Map.Strict as M
import Language.Futhark
import Language.Futhark.TypeChecker.Monad
unifyTypesU :: (Monoid als, ArrayDim dim) =>
(Uniqueness -> Uniqueness -> Maybe Uniqueness)
-> TypeBase dim als -> TypeBase dim als -> Maybe (TypeBase dim als)
unifyTypesU uf (Array als1 u1 et1 shape1) (Array als2 u2 et2 shape2) =
Array (als1 <> als2) <$> uf u1 u2
<*> unifyScalarTypes uf et1 et2 <*> unifyShapes shape1 shape2
unifyTypesU uf (Scalar t1) (Scalar t2) = Scalar <$> unifyScalarTypes uf t1 t2
unifyTypesU _ _ _ = Nothing
unifyScalarTypes :: (Monoid als, ArrayDim dim) =>
(Uniqueness -> Uniqueness -> Maybe Uniqueness)
-> ScalarTypeBase dim als -> ScalarTypeBase dim als -> Maybe (ScalarTypeBase dim als)
unifyScalarTypes _ (Prim t1) (Prim t2)
| t1 == t2 = Just $ Prim t1
| otherwise = Nothing
unifyScalarTypes uf (TypeVar als1 u1 t1 targs1) (TypeVar als2 u2 t2 targs2)
| t1 == t2 = do
u3 <- uf u1 u2
targs3 <- zipWithM (unifyTypeArgs uf) targs1 targs2
Just $ TypeVar (als1 <> als2) u3 t1 targs3
| otherwise = Nothing
unifyScalarTypes uf (Record ts1) (Record ts2)
| length ts1 == length ts2,
sort (M.keys ts1) == sort (M.keys ts2) =
Record <$> traverse (uncurry (unifyTypesU uf))
(M.intersectionWith (,) ts1 ts2)
unifyScalarTypes uf (Arrow as1 mn1 t1 t1') (Arrow as2 _ t2 t2') =
Arrow (as1 <> as2) mn1 <$> unifyTypesU (flip uf) t1 t2 <*> unifyTypesU uf t1' t2'
unifyScalarTypes uf (Sum cs1) (Sum cs2)
| length cs1 == length cs2,
sort (M.keys cs1) == sort (M.keys cs2) =
Sum <$> traverse (uncurry (zipWithM (unifyTypesU uf)))
(M.intersectionWith (,) cs1 cs2)
unifyScalarTypes _ _ _ = Nothing
unifyTypeArgs :: (ArrayDim dim) =>
(Uniqueness -> Uniqueness -> Maybe Uniqueness)
-> TypeArg dim -> TypeArg dim -> Maybe (TypeArg dim)
unifyTypeArgs _ (TypeArgDim d1 loc) (TypeArgDim d2 _) =
TypeArgDim <$> unifyDims d1 d2 <*> pure loc
unifyTypeArgs uf (TypeArgType t1 loc) (TypeArgType t2 _) =
TypeArgType <$> unifyTypesU uf t1 t2 <*> pure loc
unifyTypeArgs _ _ _ =
Nothing
subtypeOf :: ArrayDim dim =>
TypeBase dim as1 -> TypeBase dim as2 -> Bool
subtypeOf t1 t2 = isJust $ unifyTypesU unifyUniqueness (toStruct t1) (toStruct t2)
where unifyUniqueness u2 u1 = if u2 `subuniqueOf` u1 then Just u1 else Nothing
subuniqueOf :: Uniqueness -> Uniqueness -> Bool
subuniqueOf Nonunique Unique = False
subuniqueOf _ _ = True
checkTypeDecl :: MonadTypeChecker m =>
[TypeParam]
-> TypeDeclBase NoInfo Name
-> m (TypeDeclBase Info VName, Liftedness)
checkTypeDecl tps (TypeDecl t NoInfo) = do
checkForDuplicateNamesInType t
(t', st, l) <- checkTypeExp t
checkShapeParamUses typeExpUses tps $ unfoldTypeExp t'
return (TypeDecl t' $ Info st, l)
unfoldTypeExp :: TypeExp VName -> [TypeExp VName]
unfoldTypeExp (TEArrow _ t1 t2 _) = t1 : unfoldTypeExp t2
unfoldTypeExp t = [t]
checkTypeExp :: MonadTypeChecker m =>
TypeExp Name
-> m (TypeExp VName, StructType, Liftedness)
checkTypeExp (TEVar name loc) = do
(name', ps, t, l) <- lookupType loc name
case ps of
[] -> return (TEVar name' loc, t, l)
_ -> throwError $ TypeError loc $
"Type constructor " ++ pretty name ++ " used without any arguments."
checkTypeExp (TETuple ts loc) = do
(ts', ts_s, ls) <- unzip3 <$> mapM checkTypeExp ts
return (TETuple ts' loc, tupleRecord ts_s, foldl' max Unlifted ls)
checkTypeExp t@(TERecord fs loc) = do
let field_names = map fst fs
unless (sort field_names == sort (nub field_names)) $
throwError $ TypeError loc $ "Duplicate record fields in " ++ pretty t
fs_ts_ls <- traverse checkTypeExp $ M.fromList fs
let fs' = fmap (\(x,_,_) -> x) fs_ts_ls
ts_s = fmap (\(_,y,_) -> y) fs_ts_ls
ls = fmap (\(_,_,z) -> z) fs_ts_ls
return (TERecord (M.toList fs') loc,
Scalar $ Record ts_s,
foldl' max Unlifted ls)
checkTypeExp (TEArray t d loc) = do
(t', st, l) <- checkTypeExp t
d' <- checkDimDecl d
case (l, arrayOf st (ShapeDecl [d']) Nonunique) of
(Unlifted, st') -> return (TEArray t' d' loc, st', Unlifted)
_ -> throwError $ TypeError loc $
"Cannot create array with elements of type `" ++ pretty st ++ "` (might be functional)."
where checkDimDecl AnyDim =
return AnyDim
checkDimDecl (ConstDim k) =
return $ ConstDim k
checkDimDecl (NamedDim v) =
NamedDim <$> checkNamedDim loc v
checkTypeExp (TEUnique t loc) = do
(t', st, l) <- checkTypeExp t
unless (mayContainArray st) $
warn loc $ "Declaring `" <> pretty st <> "` as unique has no effect."
return (TEUnique t' loc, st `setUniqueness` Unique, l)
where mayContainArray (Scalar Prim{}) = False
mayContainArray Array{} = True
mayContainArray (Scalar (Record fs)) = any mayContainArray fs
mayContainArray (Scalar TypeVar{}) = True
mayContainArray (Scalar Arrow{}) = False
mayContainArray (Scalar (Sum cs)) = (any . any) mayContainArray cs
checkTypeExp (TEArrow (Just v) t1 t2 loc) = do
(t1', st1, _) <- checkTypeExp t1
bindSpaced [(Term, v)] $ do
v' <- checkName Term v loc
let env = mempty { envVtable = M.singleton v' $ BoundV [] st1 }
localEnv env $ do
(t2', st2, _) <- checkTypeExp t2
return (TEArrow (Just v') t1' t2' loc,
Scalar $ Arrow mempty (Named v') st1 st2,
Lifted)
checkTypeExp (TEArrow Nothing t1 t2 loc) = do
(t1', st1, _) <- checkTypeExp t1
(t2', st2, _) <- checkTypeExp t2
return (TEArrow Nothing t1' t2' loc,
Scalar $ Arrow mempty Unnamed st1 st2,
Lifted)
checkTypeExp ote@TEApply{} = do
(tname, tname_loc, targs) <- rootAndArgs ote
(tname', ps, t, l) <- lookupType tloc tname
if length ps /= length targs
then throwError $ TypeError tloc $
"Type constructor " ++ pretty tname ++ " requires " ++ show (length ps) ++
" arguments, but application at " ++ locStr tloc ++ " provides " ++ show (length targs)
else do
(targs', substs) <- unzip <$> zipWithM checkArgApply ps targs
return (foldl (\x y -> TEApply x y tloc) (TEVar tname' tname_loc) targs',
substituteTypes (mconcat substs) t,
l)
where tloc = srclocOf ote
rootAndArgs :: MonadTypeChecker m => TypeExp Name -> m (QualName Name, SrcLoc, [TypeArgExp Name])
rootAndArgs (TEVar qn loc) = return (qn, loc, [])
rootAndArgs (TEApply op arg _) = do (op', loc, args) <- rootAndArgs op
return (op', loc, args++[arg])
rootAndArgs te' = throwError $ TypeError (srclocOf te') $
"Type '" ++ pretty te' ++ "' is not a type constructor."
checkArgApply (TypeParamDim pv _) (TypeArgExpDim (NamedDim v) loc) = do
v' <- checkNamedDim loc v
return (TypeArgExpDim (NamedDim v') loc,
M.singleton pv $ DimSub $ NamedDim v')
checkArgApply (TypeParamDim pv _) (TypeArgExpDim (ConstDim x) loc) =
return (TypeArgExpDim (ConstDim x) loc,
M.singleton pv $ DimSub $ ConstDim x)
checkArgApply (TypeParamDim pv _) (TypeArgExpDim AnyDim loc) =
return (TypeArgExpDim AnyDim loc,
M.singleton pv $ DimSub AnyDim)
checkArgApply (TypeParamType l pv _) (TypeArgExpType te) = do
(te', st, _) <- checkTypeExp te
return (TypeArgExpType te',
M.singleton pv $ TypeSub $ TypeAbbr l [] st)
checkArgApply p a =
throwError $ TypeError tloc $ "Type argument " ++ pretty a ++
" not valid for a type parameter " ++ pretty p
checkTypeExp t@(TESum cs loc) = do
let constructors = map fst cs
unless (sort constructors == sort (nub constructors)) $
throwError $ TypeError loc $ "Duplicate constructors in " ++ pretty t
unless (length constructors <= 256) $
throwError $ TypeError loc "Sum types must have 256 or fewer constructors."
cs_ts_ls <- (traverse . traverse) checkTypeExp $ M.fromList cs
let cs' = (fmap . fmap) (\(x,_,_) -> x) cs_ts_ls
ts_s = (fmap . fmap) (\(_, y, _) -> y) cs_ts_ls
ls = (concatMap . fmap) (\(_, _, z) -> z) cs_ts_ls
return (TESum (M.toList cs') loc,
Scalar $ Sum ts_s,
foldl' max Unlifted ls)
checkForDuplicateNames :: MonadTypeChecker m =>
[UncheckedPattern] -> m ()
checkForDuplicateNames = (`evalStateT` mempty) . mapM_ check
where check (Id v _ loc) = seen v loc
check (PatternParens p _) = check p
check Wildcard{} = return ()
check (TuplePattern ps _) = mapM_ check ps
check (RecordPattern fs _) = mapM_ (check . snd) fs
check (PatternAscription p _ _) = check p
check PatternLit{} = return ()
check (PatternConstr _ _ ps _) = mapM_ check ps
seen v loc = do
already <- gets $ M.lookup v
case already of
Just prev_loc ->
lift $ throwError $ TypeError loc $
"Name " ++ quote (pretty v) ++ " also bound at " ++ locStr prev_loc
Nothing ->
modify $ M.insert v loc
checkForDuplicateNamesInType :: MonadTypeChecker m =>
TypeExp Name -> m ()
checkForDuplicateNamesInType = checkForDuplicateNames . pats
where pats (TEArrow (Just v) t1 t2 loc) = Id v NoInfo loc : pats t1 ++ pats t2
pats (TEArrow Nothing t1 t2 _) = pats t1 ++ pats t2
pats (TETuple ts _) = concatMap pats ts
pats (TERecord fs _) = concatMap (pats . snd) fs
pats (TEArray t _ _) = pats t
pats (TEUnique t _) = pats t
pats (TEApply t1 (TypeArgExpType t2) _) = pats t1 ++ pats t2
pats (TEApply t1 TypeArgExpDim{} _) = pats t1
pats TEVar{} = []
pats (TESum cs _) = concatMap (concatMap pats . snd) cs
checkShapeParamUses :: (MonadTypeChecker m, Located a) =>
(a -> ([VName], [VName])) -> [TypeParam] -> [a]
-> m ()
checkShapeParamUses getUses tps ps = do
pos_uses <- foldM checkShapePositions [] ps
mapM_ (checkUsed pos_uses) tps
where tp_names = map typeParamName tps
checkShapePositions pos_uses p = do
let (pos, neg) = getUses p
pos_uses' = pos <> pos_uses
forM_ neg $ \pv ->
unless ((pv `notElem` tp_names) || (pv `elem` pos_uses')) $
throwError $ TypeError (srclocOf p) $ "Shape parameter `" ++
pretty (baseName pv) ++ "` must first be given in " ++
"a positive position (non-functional parameter)."
return pos_uses'
checkUsed uses (TypeParamDim pv loc)
| pv `elem` uses = return ()
| otherwise =
throwError $ TypeError loc $ "Size parameter `" ++
pretty (baseName pv) ++ "` unused."
checkUsed _ _ = return ()
checkTypeParams :: MonadTypeChecker m =>
[TypeParamBase Name]
-> ([TypeParamBase VName] -> m a)
-> m a
checkTypeParams ps m =
bindSpaced (map typeParamSpace ps) $
m =<< evalStateT (mapM checkTypeParam ps) mempty
where typeParamSpace (TypeParamDim pv _) = (Term, pv)
typeParamSpace (TypeParamType _ pv _) = (Type, pv)
checkParamName ns v loc = do
seen <- gets $ M.lookup (ns,v)
case seen of
Just prev ->
throwError $ TypeError loc $
"Type parameter " ++ quote (pretty v) ++
" previously defined at " ++ locStr prev ++ "."
Nothing -> do
modify $ M.insert (ns,v) loc
lift $ checkName ns v loc
checkTypeParam (TypeParamDim pv loc) =
TypeParamDim <$> checkParamName Term pv loc <*> pure loc
checkTypeParam (TypeParamType l pv loc) =
TypeParamType l <$> checkParamName Type pv loc <*> pure loc
typeParamToArg :: TypeParam -> StructTypeArg
typeParamToArg (TypeParamDim v ploc) =
TypeArgDim (NamedDim $ qualName v) ploc
typeParamToArg (TypeParamType _ v ploc) =
TypeArgType (Scalar (TypeVar () Nonunique (typeName v) [])) ploc
typeExpUses :: TypeExp VName -> ([VName], [VName])
typeExpUses (TEVar _ _) = mempty
typeExpUses (TETuple tes _) = foldMap typeExpUses tes
typeExpUses (TERecord fs _) = foldMap (typeExpUses . snd) fs
typeExpUses (TEArray te d _) = typeExpUses te <> dimDeclUses d
typeExpUses (TEUnique te _) = typeExpUses te
typeExpUses (TEApply te targ _) = typeExpUses te <> typeArgUses targ
where typeArgUses (TypeArgExpDim d _) = dimDeclUses d
typeArgUses (TypeArgExpType tae) = typeExpUses tae
typeExpUses (TEArrow _ t1 t2 _) =
let (pos, neg) = typeExpUses t1 <> typeExpUses t2
in (mempty, pos <> neg)
typeExpUses (TESum cs _) = foldMap (mconcat . fmap typeExpUses . snd) cs
dimDeclUses :: DimDecl VName -> ([VName], [VName])
dimDeclUses (NamedDim v) = ([qualLeaf v], [])
dimDeclUses _ = mempty
data TypeSub = TypeSub TypeBinding
| DimSub (DimDecl VName)
deriving (Show)
type TypeSubs = M.Map VName TypeSub
substituteTypes :: Monoid als => TypeSubs -> TypeBase (DimDecl VName) als -> TypeBase (DimDecl VName) als
substituteTypes substs ot = case ot of
Array als u at shape ->
arrayOf (substituteTypes substs (Scalar at) `setAliases` mempty)
(substituteInShape shape) u `addAliases` (<>als)
Scalar (Prim t) -> Scalar $ Prim t
Scalar (TypeVar als u v targs)
| Just (TypeSub (TypeAbbr _ ps t)) <-
M.lookup (qualLeaf (qualNameFromTypeName v)) substs ->
applyType ps (t `setAliases` mempty) (map substituteInTypeArg targs)
`setUniqueness` u `addAliases` (<>als)
| otherwise -> Scalar $ TypeVar als u v $ map substituteInTypeArg targs
Scalar (Record ts) ->
Scalar $ Record $ fmap (substituteTypes substs) ts
Scalar (Arrow als v t1 t2) ->
Scalar $ Arrow als v (substituteTypes substs t1) (substituteTypes substs t2)
Scalar (Sum cs) ->
Scalar $ Sum $ (fmap . fmap) (substituteTypes substs) cs
where substituteInTypeArg (TypeArgDim d loc) =
TypeArgDim (substituteInDim d) loc
substituteInTypeArg (TypeArgType t loc) =
TypeArgType (substituteTypes substs t) loc
substituteInShape (ShapeDecl ds) =
ShapeDecl $ map substituteInDim ds
substituteInDim (NamedDim v)
| Just (DimSub d) <- M.lookup (qualLeaf v) substs = d
substituteInDim d = d
applyType :: Monoid als =>
[TypeParam] -> TypeBase (DimDecl VName) als -> [StructTypeArg] -> TypeBase (DimDecl VName) als
applyType ps t args =
substituteTypes substs t
where substs = M.fromList $ zipWith mkSubst ps args
mkSubst (TypeParamDim pv _) (TypeArgDim (NamedDim v) _) =
(pv, DimSub $ NamedDim v)
mkSubst (TypeParamDim pv _) (TypeArgDim (ConstDim x) _) =
(pv, DimSub $ ConstDim x)
mkSubst (TypeParamDim pv _) (TypeArgDim AnyDim _) =
(pv, DimSub AnyDim)
mkSubst (TypeParamType l pv _) (TypeArgType at _) =
(pv, TypeSub $ TypeAbbr l [] at)
mkSubst p a =
error $ "applyType mkSubst: cannot substitute " ++ pretty a ++ " for " ++ pretty p
data Subst t = Subst t | PrimSubst
instance Functor Subst where
fmap f (Subst t) = Subst $ f t
fmap _ PrimSubst = PrimSubst
class Substitutable a where
applySubst :: (VName -> Maybe (Subst (TypeBase () ()))) -> a -> a
instance Substitutable (TypeBase () ()) where
applySubst = substTypesAny
instance Substitutable (TypeBase () Aliasing) where
applySubst = substTypesAny . (fmap (fmap fromStruct).)
instance Substitutable (TypeBase (DimDecl VName) ()) where
applySubst = substTypesAny . (fmap (fmap vacuousShapeAnnotations).)
instance Substitutable (TypeBase (DimDecl VName) Aliasing) where
applySubst = substTypesAny . (fmap (fmap (vacuousShapeAnnotations . fromStruct)).)
substTypesAny :: (ArrayDim dim, Monoid as) =>
(VName -> Maybe (Subst (TypeBase dim as)))
-> TypeBase dim as -> TypeBase dim as
substTypesAny lookupSubst ot = case ot of
Array als u et shape ->
arrayOf (substTypesAny lookupSubst' (Scalar et)) shape u `setAliases` als
Scalar (Prim t) -> Scalar $ Prim t
Scalar (TypeVar als u v targs) ->
case lookupSubst $ qualLeaf (qualNameFromTypeName v) of
Just (Subst t) -> t `setUniqueness` u `addAliases` (<>als)
Just PrimSubst -> Scalar $ TypeVar mempty u v $ map subsTypeArg targs
Nothing -> Scalar $ TypeVar als u v $ map subsTypeArg targs
Scalar (Record ts) -> Scalar $ Record $ fmap (substTypesAny lookupSubst) ts
Scalar (Arrow als v t1 t2) ->
Scalar $ Arrow als v (substTypesAny lookupSubst t1) (substTypesAny lookupSubst t2)
Scalar (Sum ts) ->
Scalar $ Sum $ (fmap . fmap) (substTypesAny lookupSubst) ts
where subsTypeArg (TypeArgType t loc) =
TypeArgType (substTypesAny lookupSubst' t) loc
subsTypeArg t = t
lookupSubst' = fmap (fmap $ bimap id (const ())) . lookupSubst