{-# LANGUAGE CPP
, ScopedTypeVariables
, GADTs
, DataKinds
, KindSignatures
, GeneralizedNewtypeDeriving
, TypeOperators
, FlexibleContexts
, FlexibleInstances
, OverloadedStrings
, PatternGuards
, Rank2Types
#-}
{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
module Language.Hakaru.Syntax.TypeCheck
(
TypeCheckError
, TypeCheckMonad(), runTCM, unTCM
, TypeCheckMode(..)
, inferable
, mustCheck
, TypedAST(..)
, onTypedAST, onTypedASTM, elimTypedAST
, inferType
, checkType
) where
import Prelude hiding (id, (.))
import Control.Category
import Data.Proxy (KProxy(..))
import Data.Text (pack, Text())
import Data.Either (partitionEithers)
import qualified Data.IntMap as IM
import qualified Data.Traversable as T
import qualified Data.List.NonEmpty as L
import qualified Data.Foldable as F
import qualified Data.Sequence as S
import qualified Data.Vector as V
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative (Applicative(..), (<$>))
import Data.Monoid (Monoid(..))
#endif
import qualified Language.Hakaru.Parser.AST as U
import Data.Number.Nat (fromNat)
import Language.Hakaru.Syntax.TypeCheck.TypeCheckMonad
import Language.Hakaru.Syntax.TypeCheck.Unification
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Types.DataKind (Hakaru(..), HData', HBool)
import Language.Hakaru.Types.Sing
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Types.HClasses
( HEq, hEq_Sing, HOrd, hOrd_Sing, HSemiring, hSemiring_Sing
, hRing_Sing, sing_HRing, hFractional_Sing, sing_HFractional
, sing_NonNegative, hDiscrete_Sing
, HIntegrable(..)
, HRadical(..), HContinuous(..))
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.Reducer
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.AST.Sing
(sing_Literal, sing_MeasureOp)
import Language.Hakaru.Pretty.Concrete (prettyType, prettyTypeT)
import Language.Hakaru.Syntax.TypeOf (typeOf)
import Language.Hakaru.Syntax.Prelude (triv)
-- | A term is inferable exactly when 'mustCheck' does not flag it.
inferable :: U.AST -> Bool
inferable e = not (mustCheck e)
-- | Decide whether a term /must/ be checked, i.e. whether bottom-up
-- inference cannot be expected to succeed on it and a type has to be
-- pushed in from the context instead. Variables are always inferable.
--
-- Every syntactic form for which this function can return 'True' needs a
-- dedicated branch in 'checkType': the generic fallback branch there only
-- handles inferable terms and otherwise calls 'error'.
mustCheck :: U.AST -> Bool
mustCheck e = caseVarSyn e (const False) go
    where
    go :: U.MetaTerm -> Bool
    go (U.Lam_ _ e2)        = mustCheck' e2
    go (U.App_ _ _)         = False
    go (U.Let_ _ e2)        = mustCheck' e2
    go (U.Ann_ _ _)         = False
    go (U.CoerceTo_ _ _)    = False
    go (U.UnsafeTo_ _ _)    = False
    go (U.PrimOp_ _ _)      = False
    -- Array operations are only ever inferred (see 'inferArrayOp') and
    -- 'checkType' has no branch for them, so, like 'U.PrimOp_', they are
    -- never must-check. Answering 'True' here (e.g. for @size([])@, whose
    -- argument is an empty array literal) would send the term to the
    -- 'error' fallback of 'checkType' instead of a proper type error.
    go (U.ArrayOp_ _ _)     = False
    go (U.NaryOp_ _ es)     = F.all mustCheck es
    go (U.Superpose_ pes)   = F.all (mustCheck . snd) pes
    go (U.Literal_ _)       = False
    go (U.Pair_ e1 e2)      = mustCheck e1 && mustCheck e2
    go (U.Array_ _ e1)      = mustCheck' e1
    go (U.ArrayLiteral_ es) = F.all mustCheck es
    go (U.Datum_ _)         = True
    go (U.Case_ _ _)        = True
    go (U.Dirac_ e1)        = mustCheck e1
    go (U.MBind_ _ e2)      = mustCheck' e2
    go (U.Plate_ _ e2)      = mustCheck' e2
    go (U.Chain_ _ e2 e3)   = mustCheck e2 && mustCheck' e3
    go (U.MeasureOp_ _ _)   = False
    go (U.Integrate_ _ _ _) = False
    go (U.Summate_ _ _ _)   = False
    go (U.Product_ _ _ _)   = False
    go (U.Bucket_ _ _ _)    = False
    go U.Reject_            = True
    go (U.Transform_ tr es) =
        case (tr, es) of
            (Expect   , (Nil2, e1) U.:* _ U.:* U.End)         -> mustCheck e1
            (Observe  , (Nil2, e1) U.:* _ U.:* U.End)         -> mustCheck e1
            (MCMC     , (Nil2, e1) U.:* (Nil2, e2) U.:* U.End) ->
                mustCheck e1 && mustCheck e2
            (Disint _ , (Nil2, e1) U.:* U.End)                -> mustCheck e1
            (Simplify , (Nil2, e1) U.:* U.End)                -> mustCheck e1
            (Summarize, (Nil2, e1) U.:* U.End)                -> mustCheck e1
            (Reparam  , (Nil2, e1) U.:* U.End)                -> mustCheck e1
    go U.InjTyped{}         = False

-- | 'mustCheck' for the body of a single-variable binder; the bound
-- variable itself does not influence the answer.
mustCheck'
    :: MetaABT U.SourceSpan U.Term '[ 'U.U ] 'U.U
    -> Bool
mustCheck' e = caseBind e $ \_ e' -> mustCheck e'
-- | Open a single-variable binder, give the bound variable type @typ@,
-- infer the type of the body under that extended context, and hand the
-- body's type together with the re-closed binder to the continuation.
inferBinder
    :: (ABT Term abt)
    => Sing a
    -> MetaABT U.SourceSpan U.Term '[ 'U.U ] 'U.U
    -> (forall b. Sing b -> abt '[ a ] b -> TypeCheckMonad r)
    -> TypeCheckMonad r
inferBinder typ e k =
    caseBind e $ \untypedVar body -> do
        let typedVar = untypedVar { varType = typ }
        TypedAST bodyTyp body' <- pushCtx typedVar (inferType body)
        k bodyTyp (bind typedVar body')
-- | Infer the type of a term under a whole list of binders. Every
-- variable of @xs@ is added to the typing context, the body is inferred,
-- and the continuation receives the body's type and the body re-bound
-- over @xs@.
inferBinders
    :: (ABT Term abt)
    => List1 Variable xs
    -> U.AST
    -> (forall a. Sing a -> abt xs a -> TypeCheckMonad r)
    -> TypeCheckMonad r
inferBinders xs e k = do
    TypedAST bodyTyp body <- withVars xs (inferType e)
    k bodyTyp (binds_ xs body)
    where
    -- Run an action with every variable of the list inserted into the
    -- context; by the composition order, the head of the list is the
    -- last one inserted.
    withVars
        :: List1 Variable (ys :: [Hakaru])
        -> TypeCheckMonad res
        -> TypeCheckMonad res
    withVars Nil1         act = act
    withVars (Cons1 y ys) act =
        withVars ys (TCM (unTCM act . insertVarSet y))
-- | Open a single-variable binder, give the bound variable type @typ@,
-- check the body against @eTyp@ in the extended context, and re-close
-- the binder around the checked body.
checkBinder
    :: (ABT Term abt)
    => Sing a
    -> Sing b
    -> MetaABT U.SourceSpan U.Term '[ 'U.U ] 'U.U
    -> TypeCheckMonad (abt '[ a ] b)
checkBinder typ eTyp e =
    caseBind e $ \untypedVar body ->
        let typedVar = untypedVar { varType = typ }
        in  pushCtx typedVar $ fmap (bind typedVar) (checkType eTyp body)
-- | Check a term against @eTyp@ underneath a list of binders. Each
-- variable is pushed into the context in turn and the checked body is
-- re-bound over it.
checkBinders
    :: (ABT Term abt)
    => List1 Variable xs
    -> Sing a
    -> U.AST
    -> TypeCheckMonad (abt xs a)
checkBinders Nil1           eTyp e = checkType eTyp e
checkBinders (Cons1 x rest) eTyp e =
    pushCtx x $ bind x <$> checkBinders rest eTyp e
-- | Infer the type of an untyped term. The result is a 'TypedAST' that
-- pairs the elaborated term with the singleton of its type. The entry
-- point simply delegates to the local 'inferType_'.
inferType
:: forall abt
. (ABT Term abt)
=> U.AST
-> TypeCheckMonad (TypedAST abt)
inferType = inferType_
where
-- Local aliases of the top-level 'checkType' and 'inferOneCheckOthers'
-- whose @abt@ is fixed to the one bound by the @forall@ in the
-- signature above (ScopedTypeVariables). The local helpers can then
-- call them without extra annotations.
checkType_ :: forall b. Sing b -> U.AST -> TypeCheckMonad (abt '[] b)
checkType_ = checkType
inferOneCheckOthers_ ::
[U.AST] -> TypeCheckMonad (TypedASTs abt)
inferOneCheckOthers_ = inferOneCheckOthers
-- Look a variable up in the typing context by its numeric id and return
-- it at the type recorded there. An id that is not in the context is
-- reported through 'ambiguousFreeVariable'.
inferVariable
    :: Maybe U.SourceSpan
    -> Variable 'U.U
    -> TypeCheckMonad (TypedAST abt)
inferVariable sourceSpan (Variable hintID nameID _) = do
    scope <- getCtx
    let found = IM.lookup (fromNat nameID) (unVarSet scope)
    case found of
        Nothing               -> ambiguousFreeVariable hintID sourceSpan
        Just (SomeVariable v) -> return $ TypedAST (varType v) (var v)
-- | Infer the type of a single untyped term.
--
-- Variables are resolved through 'inferVariable'; every other term is
-- dispatched on its syntactic form by @go@. Forms with no inference rule
-- (the must-check forms, see 'mustCheck') are reported through
-- 'ambiguousMustCheck'.
inferType_ :: U.AST -> TypeCheckMonad (TypedAST abt)
inferType_ e0 =
    let s = getMetadata e0 in
    caseVarSyn e0 (inferVariable s) (go s)
    where
    -- Infer one common type for a list of terms, according to the mode.
    -- In 'StrictMode' one term is inferred and the others are checked
    -- against it; in the other modes 'inferLubType' is used.
    inferAll
        :: Maybe U.SourceSpan -> [U.AST] -> TypeCheckMonad (TypedASTs abt)
    inferAll sourceSpan es = do
        mode <- getMode
        case mode of
            StrictMode -> inferOneCheckOthers_ es
            LaxMode    -> inferLubType sourceSpan es
            UnsafeMode -> inferLubType sourceSpan es

    go :: Maybe U.SourceSpan -> U.MetaTerm -> TypeCheckMonad (TypedAST abt)
    go sourceSpan t =
        case t of
        U.Lam_ (U.SSing typ) e ->
            inferBinder typ e $ \typ2 e2 ->
                return . TypedAST (SFun typ typ2) $ syn (Lam_ :$ e2 :* End)

        U.App_ e1 e2 -> do
            TypedAST typ1 e1' <- inferType_ e1
            unifyFun typ1 sourceSpan $ \typ2 typ3 -> do
                e2' <- checkType_ typ2 e2
                return . TypedAST typ3 $ syn (App_ :$ e1' :* e2' :* End)

        U.Let_ e1 e2 -> do
            TypedAST typ1 e1' <- inferType_ e1
            inferBinder typ1 e2 $ \typ2 e2' ->
                return . TypedAST typ2 $ syn (Let_ :$ e1' :* e2' :* End)

        U.Ann_ (U.SSing typ1) e1 ->
            TypedAST typ1 <$> checkType_ typ1 e1

        U.PrimOp_ op es -> inferPrimOp op es

        U.ArrayOp_ op es -> inferArrayOp op es

        U.NaryOp_ op es -> do
            TypedASTs typ es' <- inferAll sourceSpan es
            op' <- make_NaryOp typ op
            return . TypedAST typ $ syn (NaryOp_ op' $ S.fromList es')

        U.Literal_ (Some1 v) ->
            return . TypedAST (sing_Literal v) $ syn (Literal_ v)

        U.CoerceTo_ (Some2 c) e1 ->
            case singCoerceDomCod c of
            Nothing
                | inferable e1 -> inferType_ e1
                | otherwise    -> ambiguousNullCoercion sourceSpan
            Just (dom, cod) -> do
                e1' <- checkType_ dom e1
                return . TypedAST cod $ syn (CoerceTo_ c :$ e1' :* End)

        U.UnsafeTo_ (Some2 c) e1 ->
            case singCoerceDomCod c of
            Nothing
                | inferable e1 -> inferType_ e1
                | otherwise    -> ambiguousNullCoercion sourceSpan
            Just (dom, cod) -> do
                e1' <- checkType_ cod e1
                return . TypedAST dom $ syn (UnsafeFrom_ c :$ e1' :* End)

        U.MeasureOp_ (U.SomeOp op) es -> do
            let (typs, typ1) = sing_MeasureOp op
            es' <- checkSArgs typs es
            return . TypedAST (SMeasure typ1) $ syn (MeasureOp_ op :$ es')

        U.Pair_ e1 e2 -> do
            TypedAST typ1 e1' <- inferType_ e1
            TypedAST typ2 e2' <- inferType_ e2
            return . TypedAST (sPair typ1 typ2) $
                syn (Datum_ $ dPair_ typ1 typ2 e1' e2')

        U.Array_ e1 e2 -> do
            e1' <- checkType_ SNat e1
            inferBinder SNat e2 $ \typ2 e2' ->
                return . TypedAST (SArray typ2) $ syn (Array_ e1' e2')

        U.ArrayLiteral_ es -> do
            TypedASTs typ es' <- inferAll sourceSpan es
            return . TypedAST (SArray typ) $ syn (ArrayLiteral_ es')

        U.Case_ e1 branches -> do
            TypedAST typ1 e1' <- inferType_ e1
            mode <- getMode
            case mode of
                StrictMode -> inferCaseStrict typ1 e1' branches
                LaxMode    -> inferCaseLax sourceSpan typ1 e1' branches
                UnsafeMode -> inferCaseLax sourceSpan typ1 e1' branches

        U.Dirac_ e1 -> do
            TypedAST typ1 e1' <- inferType_ e1
            return . TypedAST (SMeasure typ1) $ syn (Dirac :$ e1' :* End)

        U.MBind_ e1 e2 ->
            caseBind e2 $ \x e2' -> do
                TypedAST typ1 e1' <- inferType_ e1
                unifyMeasure typ1 sourceSpan $ \typ2 ->
                    let x' = makeVar x typ2 in
                    pushCtx x' $ do
                        TypedAST typ3 e3' <- inferType_ e2'
                        unifyMeasure typ3 sourceSpan $ \_ ->
                            return . TypedAST typ3 $
                                syn (MBind :$ e1' :* bind x' e3' :* End)

        U.Plate_ e1 e2 ->
            caseBind e2 $ \x e2' -> do
                e1' <- checkType_ SNat e1
                let x' = makeVar x SNat
                pushCtx x' $ do
                    TypedAST typ2 e3' <- inferType_ e2'
                    unifyMeasure typ2 sourceSpan $ \typ3 ->
                        return . TypedAST (SMeasure . SArray $ typ3) $
                            syn (Plate :$ e1' :* bind x' e3' :* End)

        U.Chain_ e1 e2 e3 ->
            caseBind e3 $ \x e3' -> do
                e1' <- checkType_ SNat e1
                TypedAST typ2 e2' <- inferType_ e2
                let x' = makeVar x typ2
                pushCtx x' $ do
                    TypedAST typ3 e4' <- inferType_ e3'
                    unifyMeasure typ3 sourceSpan $ \typ4 ->
                        unifyPair typ4 sourceSpan $ \a b ->
                            matchTypes typ2 b sourceSpan () () $
                                return . TypedAST (SMeasure $ sPair (SArray a) typ2) $
                                    syn (Chain :$ e1' :* e2' :* bind x' e4' :* End)

        U.Integrate_ e1 e2 e3 -> do
            e1' <- checkType_ SReal e1
            e2' <- checkType_ SReal e2
            e3' <- checkBinder SReal SProb e3
            return . TypedAST SProb $
                syn (Integrate :$ e1' :* e2' :* e3' :* End)

        U.Summate_ e1 e2 e3 -> do
            TypedAST typ1 e1' <- inferType_ e1
            e2' <- checkType_ typ1 e2
            inferBinder typ1 e3 $ \typ2 e3' ->
                case (hDiscrete_Sing typ1, hSemiring_Sing typ2) of
                (Just h1, Just h2) ->
                    return . TypedAST typ2 $
                        syn (Summate h1 h2 :$ e1' :* e2' :* e3' :* End)
                -- Report which side is at fault: the bounds' type must be
                -- discrete, the body's type must be a semiring.
                (Nothing, _) ->
                    failwith_ "Summate given bounds which are not discrete"
                (_, Nothing) ->
                    failwith_ "Summate given a body which is not a semiring"

        U.Product_ e1 e2 e3 -> do
            TypedAST typ1 e1' <- inferType_ e1
            e2' <- checkType_ typ1 e2
            inferBinder typ1 e3 $ \typ2 e3' ->
                case (hDiscrete_Sing typ1, hSemiring_Sing typ2) of
                (Just h1, Just h2) ->
                    return . TypedAST typ2 $
                        syn (Product h1 h2 :$ e1' :* e2' :* e3' :* End)
                (Nothing, _) ->
                    failwith_ "Product given bounds which are not discrete"
                (_, Nothing) ->
                    failwith_ "Product given a body which is not a semiring"

        U.Bucket_ e1 e2 r1 -> do
            e1' <- checkType_ SNat e1
            e2' <- checkType_ SNat e2
            TypedReducer typ1 Nil1 r1' <- inferReducer r1 Nil1
            return . TypedAST typ1 $ syn (Bucket e1' e2' r1')

        U.Transform_ tr es -> inferTransform sourceSpan tr es

        U.Superpose_ pes -> do
            TypedASTs typ es' <- inferAll sourceSpan (L.toList $ fmap snd pes)
            unifyMeasure typ sourceSpan $ \_ -> do
                ps' <- T.traverse (checkType_ SProb) (fmap fst pes)
                return $ TypedAST typ (syn (Superpose_ (L.zip ps' (L.fromList es'))))

        -- The local binding pins the polymorphic injected term to a single
        -- @abt@, so that 'typeOf' and the result agree on it.
        U.InjTyped t -> let t' = t in return $ TypedAST (typeOf t') t'

        _   | mustCheck e0 -> ambiguousMustCheck sourceSpan
            | otherwise    -> error "inferType: missing an inferable branch!"
-- Infer the type of a program transformation applied to its arguments.
-- Each clause matches one 'Transform' constructor together with the
-- argument spine it expects. Type errors on an argument are located
-- at that argument's own span (e1src / e2src).
-- NOTE(review): the 'sourceSpan' parameter is bound in every clause but
-- never used.
inferTransform
:: Maybe U.SourceSpan
-> Transform as x
-> U.SArgs U.U_ABT as
-> TypeCheckMonad (TypedAST abt)
-- expect: e1 must infer to a measure over some type a; e2 binds one
-- variable of type a and is checked against prob. The result is a prob.
inferTransform sourceSpan
Expect
((Nil2, e1) U.:* (Cons2 U.ToU Nil2, e2) U.:* U.End) = do
let e1src = getMetadata e1
TypedAST typ1 e1' <- inferType_ e1
unifyMeasure typ1 e1src $ \typ2 -> do
e2' <- checkBinder typ2 SProb e2
return . TypedAST SProb $ syn
(Transform_ Expect :$ e1' :* e2' :* End)
-- observe: e1 must infer to a measure over a; e2 is checked against a.
-- The result has e1's (measure) type.
inferTransform sourceSpan
Observe
((Nil2, e1) U.:* (Nil2, e2) U.:* U.End) = do
let e1src = getMetadata e1
TypedAST typ1 e1' <- inferType_ e1
unifyMeasure typ1 e1src $ \typ2 -> do
e2' <- checkType_ typ2 e2
return . TypedAST typ1 $ syn
(Transform_ Observe :$ e1' :* e2' :* End)
-- mcmc: e1 must infer to a function a -> measure b and e2 to a measure
-- over c. Both a ~ b and b ~ c are required, and the result type is
-- a -> measure a.
inferTransform sourceSpan
MCMC
((Nil2, e1) U.:* (Nil2, e2) U.:* U.End) = do
let e1src = getMetadata e1
e2src = getMetadata e2
TypedAST typ1 e1' <- inferType_ e1
TypedAST typ2 e2' <- inferType_ e2
unifyFun typ1 e1src $ \typa typmb ->
unifyMeasure typmb e1src $ \typb ->
unifyMeasure typ2 e2src $ \typc ->
matchTypes typa typb e1src (SFun typa (SMeasure typa)) typ1 $
matchTypes typb typc e2src typmb typ2 $
return $ TypedAST (SFun typa (SMeasure typa))
$ syn $ Transform_ MCMC :$ e1' :* e2' :* End
-- disint: e1 must infer to a measure over a pair (a, b). The result
-- type is a -> measure b.
inferTransform sourceSpan
(Disint k)
((Nil2, e1) U.:* U.End) = do
let e1src = getMetadata e1
TypedAST typ1 e1' <- inferType_ e1
unifyMeasure typ1 e1src $ \typ2 ->
unifyPair typ2 e1src $ \typa typb ->
return $ TypedAST (SFun typa (SMeasure typb)) $
syn $ Transform_ (Disint k) :$ e1' :* End
-- simplify / reparam / summarize: the argument is inferred and the
-- result has exactly the argument's type.
inferTransform sourceSpan
Simplify
((Nil2, e1) U.:* U.End) = do
TypedAST typ1 e1' <- inferType_ e1
return $ TypedAST typ1 $ syn (Transform_ Simplify :$ e1' :* End)
inferTransform sourceSpan
Reparam
((Nil2, e1) U.:* U.End) = do
TypedAST typ1 e1' <- inferType_ e1
return $ TypedAST typ1 $ syn (Transform_ Reparam :$ e1' :* End)
inferTransform sourceSpan
Summarize
((Nil2, e1) U.:* U.End) = do
TypedAST typ1 e1' <- inferType_ e1
return $ TypedAST typ1 $ syn (Transform_ Summarize :$ e1' :* End)
-- Any other transform (or argument shape) is not implemented yet and
-- aborts with an internal error.
inferTransform _ tr _ = error $ "inferTransform{" ++ show tr ++ "}: TODO"
-- | Infer the result type of a primitive operation from its arguments.
--
-- Primitives with a fixed signature check each argument against the
-- required type. The overloaded ones ('U.Equal', 'U.Less', 'U.NatPow',
-- 'U.Negate', 'U.Abs', 'U.Signum', 'U.Recip') infer an argument type
-- first and then look up the matching type-class evidence. A wrong
-- number of arguments is reported with 'argumentNumberError'.
inferPrimOp
    :: U.PrimOp
    -> [U.AST]
    -> TypeCheckMonad (TypedAST abt)
inferPrimOp U.Not es =
    case es of
    [e] -> do
        e' <- checkType_ sBool e
        return . TypedAST sBool $ syn (PrimOp_ Not :$ e' :* End)
    _ -> argumentNumberError
inferPrimOp U.Pi es =
    case es of
    [] -> return . TypedAST SProb $ syn (PrimOp_ Pi :$ End)
    _  -> argumentNumberError
inferPrimOp U.RealPow es =
    case es of
    [e1, e2] -> do
        e1' <- checkType_ SProb e1
        e2' <- checkType_ SReal e2
        return . TypedAST SProb $
            syn (PrimOp_ RealPow :$ e1' :* e2' :* End)
    _ -> argumentNumberError
inferPrimOp U.Choose es =
    case es of
    [e1, e2] -> do
        e1' <- checkType_ SNat e1
        e2' <- checkType_ SNat e2
        return . TypedAST SNat $
            syn (PrimOp_ Choose :$ e1' :* e2' :* End)
    _ -> argumentNumberError
inferPrimOp U.Exp es =
    case es of
    [e] -> do
        e' <- checkType_ SReal e
        return . TypedAST SProb $ syn (PrimOp_ Exp :$ e' :* End)
    _ -> argumentNumberError
inferPrimOp U.Log es =
    case es of
    [e] -> do
        e' <- checkType_ SProb e
        return . TypedAST SReal $ syn (PrimOp_ Log :$ e' :* End)
    _ -> argumentNumberError
inferPrimOp U.Infinity es =
    case es of
    [] -> return . TypedAST SProb $
        syn (PrimOp_ (Infinity HIntegrable_Prob) :$ End)
    _  -> argumentNumberError
inferPrimOp U.GammaFunc es =
    case es of
    [e] -> do
        e' <- checkType_ SReal e
        return . TypedAST SProb $ syn (PrimOp_ GammaFunc :$ e' :* End)
    _ -> argumentNumberError
inferPrimOp U.BetaFunc es =
    case es of
    [e1, e2] -> do
        e1' <- checkType_ SProb e1
        e2' <- checkType_ SProb e2
        return . TypedAST SProb $
            syn (PrimOp_ BetaFunc :$ e1' :* e2' :* End)
    _ -> argumentNumberError
inferPrimOp U.Equal es =
    case es of
    [_, _] -> do
        mode <- getMode
        TypedASTs typ [e1', e2'] <-
            case mode of
            StrictMode -> inferOneCheckOthers_ es
            _          -> inferLubType Nothing es
        primop <- Equal <$> getHEq typ
        return . TypedAST sBool $
            syn (PrimOp_ primop :$ e1' :* e2' :* End)
    _ -> argumentNumberError
inferPrimOp U.Less es =
    case es of
    [_, _] -> do
        mode <- getMode
        TypedASTs typ [e1', e2'] <-
            case mode of
            StrictMode -> inferOneCheckOthers_ es
            _          -> inferLubType Nothing es
        primop <- Less <$> getHOrd typ
        return . TypedAST sBool $
            syn (PrimOp_ primop :$ e1' :* e2' :* End)
    _ -> argumentNumberError
inferPrimOp U.NatPow es =
    case es of
    [e1, e2] -> do
        TypedAST typ e1' <- inferType_ e1
        e2' <- checkType_ SNat e2
        primop <- NatPow <$> getHSemiring typ
        return . TypedAST typ $
            syn (PrimOp_ primop :$ e1' :* e2' :* End)
    _ -> argumentNumberError
inferPrimOp U.Negate es =
    case es of
    [e] -> do
        TypedAST typ e' <- inferType_ e
        mode <- getMode
        SomeRing ring c <- getHRing typ mode
        -- Coerce the argument into the ring type unless no coercion is
        -- needed.
        let e'' = case c of
                CNil -> e'
                c'   -> unLC_ . coerceTo c' $ LC_ e'
        return . TypedAST (sing_HRing ring) $
            syn (PrimOp_ (Negate ring) :$ e'' :* End)
    _ -> argumentNumberError
inferPrimOp U.Abs es =
    case es of
    [e] -> do
        TypedAST typ e' <- inferType_ e
        mode <- getMode
        SomeRing ring c <- getHRing typ mode
        let e'' = case c of
                CNil -> e'
                c'   -> unLC_ . coerceTo c' $ LC_ e'
        return . TypedAST (sing_NonNegative ring) $
            syn (PrimOp_ (Abs ring) :$ e'' :* End)
    _ -> argumentNumberError
inferPrimOp U.Signum es =
    case es of
    [e] -> do
        TypedAST typ e' <- inferType_ e
        mode <- getMode
        SomeRing ring c <- getHRing typ mode
        let e'' = case c of
                CNil -> e'
                c'   -> unLC_ . coerceTo c' $ LC_ e'
        return . TypedAST (sing_HRing ring) $
            syn (PrimOp_ (Signum ring) :$ e'' :* End)
    _ -> argumentNumberError
inferPrimOp U.Recip es =
    case es of
    [e] -> do
        TypedAST typ e' <- inferType_ e
        mode <- getMode
        SomeFractional frac c <- getHFractional typ mode
        let e'' = case c of
                CNil -> e'
                c'   -> unLC_ . coerceTo c' $ LC_ e'
        return . TypedAST (sing_HFractional frac) $
            syn (PrimOp_ (Recip frac) :$ e'' :* End)
    _ -> argumentNumberError
inferPrimOp U.NatRoot es =
    case es of
    [e1, e2] -> do
        e1' <- checkType_ SProb e1
        e2' <- checkType_ SNat e2
        return . TypedAST SProb $
            syn (PrimOp_ (NatRoot HRadical_Prob) :$ e1' :* e2' :* End)
    _ -> argumentNumberError
inferPrimOp U.Erf es =
    case es of
    [e] -> do
        e' <- checkType_ SReal e
        return . TypedAST SReal $
            syn (PrimOp_ (Erf HContinuous_Real) :$ e' :* End)
    _ -> argumentNumberError
-- Unary real -> real functions all share one rule. 'U.Cos' is handled
-- here only; a separate 'U.Cos' clause used to shadow this table entry.
inferPrimOp x es
    | Just y <- lookup x realToRealOps =
        case es of
        [e] -> do
            e' <- checkType_ SReal e
            return . TypedAST SReal $ syn (PrimOp_ y :$ e' :* End)
        _ -> argumentNumberError
    where
    realToRealOps =
        [ (U.Sin  , Sin  )
        , (U.Cos  , Cos  )
        , (U.Tan  , Tan  )
        , (U.Asin , Asin )
        , (U.Acos , Acos )
        , (U.Atan , Atan )
        , (U.Sinh , Sinh )
        , (U.Cosh , Cosh )
        , (U.Tanh , Tanh )
        , (U.Asinh, Asinh)
        , (U.Acosh, Acosh)
        , (U.Atanh, Atanh)
        ]
inferPrimOp U.Floor es =
    case es of
    [e] -> do
        e' <- checkType_ SProb e
        return . TypedAST SNat $ syn (PrimOp_ Floor :$ e' :* End)
    _ -> argumentNumberError
inferPrimOp x _ = error ("TODO: inferPrimOp: " ++ show x)
-- Infer the type of an array primitive.
-- 'U.Index_' and 'U.Size' infer the array argument and require an
-- array type. 'U.Reduce' infers the combining function and requires it
-- to have type @a -> a -> a@. A wrong number of arguments is reported
-- with 'argumentNumberError'.
inferArrayOp
    :: U.ArrayOp
    -> [U.AST]
    -> TypeCheckMonad (TypedAST abt)
inferArrayOp U.Index_ [arr, ix] = do
    TypedAST arrTyp arr' <- inferType_ arr
    unifyArray arrTyp Nothing $ \elemTyp -> do
        ix' <- checkType_ SNat ix
        return . TypedAST elemTyp $
            syn (ArrayOp_ (Index elemTyp) :$ arr' :* ix' :* End)
inferArrayOp U.Index_ _ = argumentNumberError
inferArrayOp U.Size [arr] = do
    TypedAST arrTyp arr' <- inferType_ arr
    unifyArray arrTyp Nothing $ \elemTyp ->
        return . TypedAST SNat $
            syn (ArrayOp_ (Size elemTyp) :$ arr' :* End)
inferArrayOp U.Size _ = argumentNumberError
inferArrayOp U.Reduce [f, z, arr] = do
    TypedAST fTyp f' <- inferType_ f
    unifyFun fTyp Nothing $ \accTyp restTyp -> do
        Refl <- jmEq1_ restTyp (SFun accTyp accTyp)
        z'   <- checkType_ accTyp z
        arr' <- checkType_ (SArray accTyp) arr
        return . TypedAST accTyp $
            syn (ArrayOp_ (Reduce accTyp) :$ f' :* z' :* arr' :* End)
inferArrayOp U.Reduce _ = argumentNumberError
-- Infer the type of a bucket reducer. @xs@ is the list of typed
-- variables that the reducer's inner binders are re-bound over.
-- Every clause returns the same @xs@ it was given.
inferReducer :: U.Reducer xs U.U_ABT 'U.U
-> List1 Variable xs1
-> TypeCheckMonad (TypedReducer abt xs1)
-- fanout: infer both sub-reducers over the same xs; the result type is
-- the pair of their types.
inferReducer (U.R_Fanout_ r1 r2) xs = do
TypedReducer t1 _ r1' <- inferReducer r1 xs
TypedReducer t2 _ r2' <- inferReducer r2 xs
return (TypedReducer (sPair t1 t2) xs (Red_Fanout r1' r2'))
-- index: the sub-reducer is inferred with an extra nat variable (x)
-- consed onto xs. The size n and the index expression (which binds one
-- more nat, i) are both checked against nat.
-- NOTE(review): caseBinds discards the variables bound by n and by the
-- body of ix, and the bodies are re-bound over xs instead. This assumes
-- those binders coincide with xs; TODO confirm against the parser.
inferReducer (U.R_Index_ x n ix r1) xs = do
let (_, n') = caseBinds n
let b = makeVar x SNat
TypedReducer t1 _ r1' <- inferReducer r1 (Cons1 b xs)
n'' <- checkBinders xs SNat n'
caseBind ix $ \i ix1 ->
let i' = makeVar i SNat
(_, ix2) = caseBinds ix1 in do
ix3 <- pushCtx i' (checkBinders xs SNat ix2)
return . TypedReducer (SArray t1) xs $
Red_Index n'' (bind i' ix3) r1'
-- split: b binds one nat and is checked against bool. The two branches
-- are inferred over xs and the result is the pair of their types.
inferReducer (U.R_Split_ b r1 r2) xs = do
TypedReducer t1 _ r1' <- inferReducer r1 xs
TypedReducer t2 _ r2' <- inferReducer r2 xs
caseBind b $ \x b1 ->
let (_, b2) = caseBinds b1
x' = makeVar x SNat in do
b3 <- pushCtx x' (checkBinders xs sBool b2)
return . TypedReducer (sPair t1 t2) xs $
(Red_Split (bind x' b3) r1' r2')
-- nop: contributes nothing; its type is unit.
inferReducer U.R_Nop_ xs = return (TypedReducer sUnit xs Red_Nop)
-- add: e binds one nat. Its body's type is inferred over xs and must
-- have semiring evidence; the reducer has that type.
inferReducer (U.R_Add_ e) xs =
caseBind e $ \x e1 ->
let (_, e2) = caseBinds e1
x' = makeVar x SNat in
pushCtx x' $
inferBinders xs e2 $ \typ e3 -> do
h <- getHSemiring typ
return $ TypedReducer typ xs (Red_Add h (bind x' e3))
-- | Infer one term of a list and check all the others against its type.
--
-- Terms are tried left to right; the first one whose inference succeeds
-- fixes the common type, and the result keeps the terms in their
-- original order. An empty list, or a list in which no term can be
-- inferred, is reported as ambiguous.
inferOneCheckOthers
    :: forall abt
    .  (ABT Term abt)
    => [U.AST]
    -> TypeCheckMonad (TypedASTs abt)
inferOneCheckOthers = search []
    where
    -- @skipped@ holds the terms that failed to infer, most recent first.
    search :: [U.AST] -> [U.AST] -> TypeCheckMonad (TypedASTs abt)
    search skipped [] =
        if null skipped
            then ambiguousEmptyNary Nothing
            else ambiguousMustCheckNary Nothing
    search skipped (e:rest) = do
        attempt <- try (inferType e)
        case attempt of
            Nothing -> search (e : skipped) rest
            Just (TypedAST typ e') -> do
                before <- againstType typ skipped
                after  <- againstType typ rest
                return $ TypedASTs typ (reverse before ++ e' : after)

    againstType
        :: forall a. Sing a -> [U.AST] -> TypeCheckMonad [abt '[] a]
    againstType typ = mapM (checkType typ)
-- | Infer the types of all terms and coerce every term to a common type
-- obtained with 'findLub' (used outside 'StrictMode').
--
-- The span @s@ is attached to both failure reports: an empty list of
-- terms, and a pair of types for which 'findLub' finds nothing.
inferLubType
    :: forall abt
    .  (ABT Term abt)
    => Maybe U.SourceSpan
    -> [U.AST]
    -> TypeCheckMonad (TypedASTs abt)
inferLubType s = start
    where
    start :: [U.AST] -> TypeCheckMonad (TypedASTs abt)
    -- Report the empty case at the caller's span (it used to be reported
    -- without any location), consistent with the 'missingLub' case.
    start []     = ambiguousEmptyNary s
    start (u:us) = do
        TypedAST typ1 e1 <- inferType u
        TypedASTs typ2 es <- F.foldlM step (TypedASTs typ1 [e1]) us
        return (TypedASTs typ2 (reverse es))

    -- The accumulator holds the terms seen so far in reverse order, all
    -- at the current common type @typ1@. Each new term's type is joined
    -- with it through 'findLub', and the returned coercions are applied
    -- to the accumulated terms and to the new term.
    step :: TypedASTs abt -> U.AST -> TypeCheckMonad (TypedASTs abt)
    step (TypedASTs typ1 es) u = do
        TypedAST typ2 e2 <- inferType u
        case findLub typ1 typ2 of
            Nothing -> missingLub typ1 typ2 s
            Just (Lub typ c1 c2) ->
                let es' = map (unLC_ . coerceTo c1 . LC_) es
                    e2' = unLC_ . coerceTo c2 $ LC_ e2
                in  return (TypedASTs typ (e2' : es'))
-- Strict-mode inference for a case expression whose scrutinee e1 has
-- already been typed at typA.
-- Branches are tried in order as the source of the result type. For
-- each attempt, the branch's pattern is checked, its body is inferred,
-- and every other branch is checked against that body type. The first
-- attempt that succeeds builds the Case_ node, keeping branch order.
-- NOTE(review): 'try' wraps the checks of the other branches too, not
-- just the inference. Assuming 'try' turns any failure into Nothing, a
-- genuine type error in another branch is therefore discarded, and if
-- every attempt fails the user sees ambiguousMustCheckNary instead.
inferCaseStrict
:: forall abt a
. (ABT Term abt)
=> Sing a
-> abt '[] a
-> [U.Branch]
-> TypeCheckMonad (TypedAST abt)
inferCaseStrict typA e1 = inferOne []
where
-- ls: branches already tried without success, most recent first.
inferOne :: [U.Branch] -> [U.Branch] -> TypeCheckMonad (TypedAST abt)
inferOne ls []
| null ls = ambiguousEmptyNary Nothing
| otherwise = ambiguousMustCheckNary Nothing
inferOne ls (b@(U.Branch_ pat e):rs) = do
SP pat' vars <- checkPattern typA pat
m <- try $ inferBinders vars e $ \typ e' -> do
ls' <- checkOthers typ ls
rs' <- checkOthers typ rs
return (TypedAST typ $ syn (Case_ e1 (reverse ls' ++ (Branch pat' e') : rs')))
case m of
Nothing -> inferOne (b:ls) rs
Just m' -> return m'
checkOthers
:: forall b. Sing b -> [U.Branch] -> TypeCheckMonad [Branch a abt b]
checkOthers typ = T.traverse (checkBranch typA typ)
-- | Lax/unsafe-mode inference for a case expression whose scrutinee @e1@
-- has already been typed at @typA@.
--
-- Every branch body is inferred under the variables bound by its
-- pattern. Branch types are joined left to right with 'findLub'; bodies
-- seen so far are coerced whenever the joined type changes.
inferCaseLax
    :: forall abt a
    .  (ABT Term abt)
    => Maybe U.SourceSpan
    -> Sing a
    -> abt '[] a
    -> [U.Branch]
    -> TypeCheckMonad (TypedAST abt)
inferCaseLax s typA e1 = start
    where
    start :: [U.Branch] -> TypeCheckMonad (TypedAST abt)
    -- Report an empty case at the case expression's span rather than
    -- with no location, like the 'missingLub' failure below.
    start [] = ambiguousEmptyNary s
    start (U.Branch_ pat e : us) = do
        SP pat' vars <- checkPattern typA pat
        inferBinders vars e $ \typ1 e' -> do
            SomeBranch typ2 bs <-
                F.foldlM step (SomeBranch typ1 [Branch pat' e']) us
            return . TypedAST typ2 . syn . Case_ e1 $ reverse bs

    -- The accumulator holds the branches seen so far in reverse order,
    -- all at the current joined type @typB@.
    step :: SomeBranch a abt
         -> U.Branch
         -> TypeCheckMonad (SomeBranch a abt)
    step (SomeBranch typB bs) (U.Branch_ pat e) = do
        SP pat' vars <- checkPattern typA pat
        inferBinders vars e $ \typE e' ->
            case findLub typB typE of
            Nothing -> missingLub typB typE s
            Just (Lub typLub coeB coeE) ->
                return $ SomeBranch typLub
                    ( Branch pat' (coerceTo_nonLC coeE e')
                    : map (coerceTo coeB) bs
                    )
-- | Check a list of untyped arguments against the corresponding list of
-- argument types, building the typed argument spine. The two lists must
-- have the same length; a length mismatch is an internal error.
checkSArgs
    :: (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => List1 Sing typs
    -> [U.AST]
    -> TypeCheckMonad (SArgs abt args)
checkSArgs (Cons1 typ typs) (e:es) = do
    e'  <- checkType typ e
    es' <- checkSArgs typs es
    return (e' :* es')
checkSArgs Nil1 [] = return End
checkSArgs _ _ =
    error "checkSArgs: the number of types and terms doesn't match up"
-- | Check an untyped term against an expected type, producing the
-- elaborated term at that type. The entry point simply delegates to the
-- local 'checkType_'.
checkType
:: forall abt a
. (ABT Term abt)
=> Sing a
-> U.AST
-> TypeCheckMonad (abt '[] a)
checkType = checkType_
where
-- Local alias of the top-level 'inferType' whose @abt@ is fixed to the
-- one bound by the @forall@ above (ScopedTypeVariables).
inferType_ :: U.AST -> TypeCheckMonad (TypedAST abt)
inferType_ = inferType
-- Check a variable against an expected type. The variable's recorded
-- type is looked up (through inference) and reconciled with @typ0@
-- according to the mode: 'StrictMode' demands equality, 'LaxMode'
-- allows a coercion, and 'UnsafeMode' also allows unsafe coercions.
checkVariable
    :: forall b
    .  Sing b
    -> Maybe U.SourceSpan
    -> Variable 'U.U
    -> TypeCheckMonad (abt '[] b)
checkVariable typ0 sourceSpan x = do
    TypedAST actualTyp term <- inferType_ (var x)
    mode <- getMode
    case mode of
        LaxMode    -> checkOrCoerce sourceSpan term actualTyp typ0
        UnsafeMode -> checkOrUnsafeCoerce sourceSpan term actualTyp typ0
        StrictMode ->
            case jmEq1 typ0 actualTyp of
                Nothing   -> typeMismatch sourceSpan (Right typ0) (Right actualTyp)
                Just Refl -> return term
-- | Check an untyped term against an expected type @typ0@.
--
-- Variables go through 'checkVariable'. Every must-check form has a
-- dedicated branch. Any other inferable term is inferred and then
-- reconciled with @typ0@ according to the mode: equality in
-- 'StrictMode', a coercion in 'LaxMode', a possibly unsafe coercion in
-- 'UnsafeMode'.
checkType_
    :: forall b. Sing b -> U.AST -> TypeCheckMonad (abt '[] b)
checkType_ typ0 e0 =
    let s = getMetadata e0 in
    caseVarSyn e0 (checkVariable typ0 s) (go s)
    where
    go :: Maybe U.SourceSpan -> U.MetaTerm -> TypeCheckMonad (abt '[] b)
    go sourceSpan t =
        case t of
        U.Lam_ (U.SSing typ) e1 ->
            unifyFun typ0 sourceSpan $ \typ1 typ2 ->
                matchTypes typ1 typ sourceSpan () () $ do
                    e1' <- checkBinder typ1 typ2 e1
                    return $ syn (Lam_ :$ e1' :* End)

        U.Let_ e1 e2 -> do
            TypedAST typ1 e1' <- inferType_ e1
            e2' <- checkBinder typ1 typ0 e2
            return $ syn (Let_ :$ e1' :* e2' :* End)

        U.CoerceTo_ (Some2 c) e1 ->
            case singCoerceDomCod c of
            Nothing -> do
                e1' <- checkType_ typ0 e1
                return $ syn (CoerceTo_ CNil :$ e1' :* End)
            Just (dom, cod) ->
                matchTypes typ0 cod sourceSpan () () $ do
                    e1' <- checkType_ dom e1
                    return $ syn (CoerceTo_ c :$ e1' :* End)

        U.UnsafeTo_ (Some2 c) e1 ->
            case singCoerceDomCod c of
            Nothing -> do
                e1' <- checkType_ typ0 e1
                return $ syn (UnsafeFrom_ CNil :$ e1' :* End)
            Just (dom, cod) ->
                matchTypes typ0 dom sourceSpan () () $ do
                    e1' <- checkType_ cod e1
                    return $ syn (UnsafeFrom_ c :$ e1' :* End)

        -- Infinity exists natively at nat and prob; at int and real it is
        -- the nat / prob infinity coerced up.
        U.PrimOp_ U.Infinity [] ->
            case typ0 of
            SNat  -> return $
                syn (PrimOp_ (Infinity HIntegrable_Nat) :$ End)
            SInt  -> checkOrCoerce sourceSpan
                (syn (PrimOp_ (Infinity HIntegrable_Nat) :$ End)) SNat SInt
            SProb -> return $
                syn (PrimOp_ (Infinity HIntegrable_Prob) :$ End)
            SReal -> checkOrCoerce sourceSpan
                (syn (PrimOp_ (Infinity HIntegrable_Prob) :$ End)) SProb SReal
            _     -> failwith =<<
                makeErrMsg
                    "Type Mismatch:"
                    sourceSpan
                    "infinity can only be checked against nat, int, prob or real"

        U.NaryOp_ op es -> do
            mode <- getMode
            case mode of
                StrictMode -> safeNaryOp typ0
                LaxMode    -> safeNaryOp typ0
                -- Check each argument in lax mode; the ones that fail are
                -- inferred together and unsafely coerced to typ0.
                UnsafeMode -> do
                    op' <- make_NaryOp typ0 op
                    (bads, goods) <-
                        fmap partitionEithers . T.forM es $ \e ->
                            fmap (maybe (Left e) Right)
                                 (tryWith LaxMode (checkType_ typ0 e))
                    if null bads
                        then return $ syn (NaryOp_ op' (S.fromList goods))
                        else do
                            let retry = case bads of
                                            [b] -> b
                                            _   -> syn $ U.NaryOp_ op bads
                            TypedAST typ bad <- inferType_ retry
                            bad' <- checkOrUnsafeCoerce sourceSpan bad typ typ0
                            return $ case bad' : goods of
                                [e] -> e
                                es' -> syn $ NaryOp_ op' (S.fromList es')
          where
            safeNaryOp :: forall c. Sing c -> TypeCheckMonad (abt '[] c)
            safeNaryOp typ = do
                op' <- make_NaryOp typ op
                es' <- T.forM es $ checkType_ typ
                return $ syn (NaryOp_ op' (S.fromList es'))

        U.Pair_ e1 e2 ->
            unifyPair typ0 sourceSpan $ \a b -> do
                e1' <- checkType_ a e1
                e2' <- checkType_ b e2
                return $ syn (Datum_ $ dPair_ a b e1' e2')

        U.Array_ e1 e2 ->
            unifyArray typ0 sourceSpan $ \typ1 -> do
                e1' <- checkType_ SNat e1
                e2' <- checkBinder SNat typ1 e2
                return $ syn (Array_ e1' e2')

        U.ArrayLiteral_ es ->
            unifyArray typ0 sourceSpan $ \typ1 ->
                if null es
                    then return $ syn (Empty_ typ0)
                    else do
                        es' <- T.forM es $ checkType_ typ1
                        return $ syn (ArrayLiteral_ es')

        U.Datum_ (U.Datum hint d) ->
            case typ0 of
            SData _ typ2 ->
                (syn . Datum_ . Datum hint typ0)
                    <$> checkDatumCode typ0 typ2 d
            _ -> typeMismatch sourceSpan (Right typ0) (Left "HData")

        U.Case_ e1 branches -> do
            TypedAST typ1 e1' <- inferType_ e1
            branches' <- T.forM branches $ checkBranch typ1 typ0
            return $ syn (Case_ e1' branches')

        U.Dirac_ e1 ->
            unifyMeasure typ0 sourceSpan $ \typ1 -> do
                e1' <- checkType_ typ1 e1
                return $ syn (Dirac :$ e1' :* End)

        U.MBind_ e1 e2 ->
            unifyMeasure typ0 sourceSpan $ \_ -> do
                TypedAST typ1 e1' <- inferType_ e1
                unifyMeasure typ1 (getMetadata e1) $ \typ2 -> do
                    e2' <- checkBinder typ2 typ0 e2
                    return $ syn (MBind :$ e1' :* e2' :* End)

        U.Plate_ e1 e2 ->
            unifyMeasure typ0 sourceSpan $ \typ1 -> do
                e1' <- checkType_ SNat e1
                unifyArray typ1 sourceSpan $ \typ2 -> do
                    e2' <- checkBinder SNat (SMeasure typ2) e2
                    return $ syn (Plate :$ e1' :* e2' :* End)

        U.Chain_ e1 e2 e3 ->
            unifyMeasure typ0 sourceSpan $ \typ1 ->
                unifyPair typ1 sourceSpan $ \aa st ->
                    unifyArray aa sourceSpan $ \a -> do
                        e1' <- checkType_ SNat e1
                        e2' <- checkType_ st e2
                        e3' <- checkBinder st (SMeasure $ sPair a st) e3
                        return $ syn (Chain :$ e1' :* e2' :* e3' :* End)

        U.Transform_ tr es -> checkTransform sourceSpan typ0 tr es

        U.Superpose_ pes ->
            unifyMeasure typ0 sourceSpan $ \_ ->
                fmap (syn . Superpose_) .
                    T.forM pes $ \(p, e) ->
                        (,) <$> checkType_ SProb p <*> checkType_ typ0 e

        U.Reject_ ->
            unifyMeasure typ0 sourceSpan $ \_ ->
                return $ syn (Reject_ typ0)

        -- 'triv' pins the polymorphic injected term to one @abt@ so that
        -- 'typeOf' can be applied to it.
        U.InjTyped t ->
            let typ1 = typeOf $ triv t
            in  case jmEq1 typ0 typ1 of
                    Just Refl -> return t
                    Nothing   -> typeMismatch sourceSpan (Right typ0) (Right typ1)

        _   | inferable e0 -> do
                TypedAST typ' e0' <- inferType_ e0
                mode <- getMode
                case mode of
                    StrictMode ->
                        case jmEq1 typ0 typ' of
                            Just Refl -> return e0'
                            Nothing   -> typeMismatch sourceSpan (Right typ0) (Right typ')
                    LaxMode    -> checkOrCoerce sourceSpan e0' typ' typ0
                    UnsafeMode -> checkOrUnsafeCoerce sourceSpan e0' typ' typ0
            | otherwise -> error "checkType: missing an mustCheck branch!"
-- Check a program transformation applied to its arguments against the
-- expected type @typ0@. Each clause matches one 'Transform' constructor
-- together with the argument spine it expects.
checkTransform
    :: Maybe U.SourceSpan
    -> Sing x'
    -> Transform as x
    -> U.SArgs U.U_ABT as
    -> TypeCheckMonad (abt '[] x')
-- expect: the result must be prob. e1 is inferred as a measure over a,
-- and e2 binds an a and is checked against prob. Errors about e1's own
-- type are reported at e1's span, matching 'inferTransform'.
checkTransform sourceSpan typ0
    Expect
    ((Nil2, e1) U.:* (Cons2 U.ToU Nil2, e2) U.:* U.End) =
    case typ0 of
    SProb -> do
        TypedAST typ1 e1' <- inferType_ e1
        unifyMeasure typ1 (getMetadata e1) $ \typ2 -> do
            e2' <- checkBinder typ2 typ0 e2
            return $ syn (Transform_ Expect :$ e1' :* e2' :* End)
    _ -> typeMismatch sourceSpan (Right typ0) (Left "HProb")
-- observe: the expected type is a measure over a; e1 is checked at it
-- and e2 is checked against a.
checkTransform sourceSpan typ0
    Observe
    ((Nil2, e1) U.:* (Nil2, e2) U.:* U.End) =
    unifyMeasure typ0 sourceSpan $ \typ2 -> do
        e1' <- checkType_ typ0 e1
        e2' <- checkType_ typ2 e2
        return $ syn (Transform_ Observe :$ e1' :* e2' :* End)
-- mcmc: the expected type must be a -> measure a. e1 is checked at that
-- type and e2 at measure a.
checkTransform sourceSpan typ0
    MCMC
    ((Nil2, e1) U.:* (Nil2, e2) U.:* U.End) =
    unifyFun typ0 sourceSpan $ \typa typmb ->
    unifyMeasure typmb sourceSpan $ \typb ->
    matchTypes typa typb sourceSpan (SFun typa (SMeasure typa)) typ0 $ do
        e1' <- checkType (SFun typa (SMeasure typa)) e1
        e2' <- checkType (SMeasure typa) e2
        return $ syn $ Transform_ MCMC :$ e1' :* e2' :* End
-- disint: the expected type must be a -> measure b; e1 is checked at
-- measure (pair a b).
checkTransform sourceSpan typ0
    (Disint k)
    ((Nil2, e1) U.:* U.End) =
    unifyFun typ0 sourceSpan $ \typa typmb ->
    unifyMeasure typmb sourceSpan $ \typb -> do
        e1' <- checkType (SMeasure (sPair typa typb)) e1
        return $ syn $ Transform_ (Disint k) :$ e1' :* End
-- simplify / reparam / summarize: type-preserving, so the argument is
-- checked at the expected type itself.
checkTransform _ typ0
    Simplify
    ((Nil2, e1) U.:* U.End) = do
    e1' <- checkType_ typ0 e1
    return $ syn (Transform_ Simplify :$ e1' :* End)
checkTransform _ typ0
    Reparam
    ((Nil2, e1) U.:* U.End) = do
    e1' <- checkType_ typ0 e1
    return $ syn (Transform_ Reparam :$ e1' :* End)
checkTransform _ typ0
    Summarize
    ((Nil2, e1) U.:* U.End) = do
    e1' <- checkType_ typ0 e1
    return $ syn (Transform_ Summarize :$ e1' :* End)
checkTransform _ _ tr _ = error $ "checkTransform{" ++ show tr ++ "}: TODO"
-- Check the constructor choice of a datum literal against a sum-of-
-- products code. @typA@ is the whole datum type, which recursive
-- ('I') fields are checked against.
checkDatumCode
    :: forall xss t
    .  Sing (HData' t)
    -> Sing xss
    -> U.DCode_
    -> TypeCheckMonad (DatumCode xss (abt '[]) (HData' t))
checkDatumCode typA typ d =
    case (d, typ) of
    (U.Inr d2, SPlus _ typ2) -> Inr <$> checkDatumCode typA typ2 d2
    (U.Inr _ , _           ) -> failwith_ "expected datum of `inr' type"
    (U.Inl d1, SPlus typ1 _) -> Inl <$> checkDatumStruct typA typ1 d1
    (U.Inl _ , _           ) -> failwith_ "expected datum of `inl' type"
-- Check the fields of one datum constructor against its product
-- structure, field by field.
checkDatumStruct
    :: forall xs t
    .  Sing (HData' t)
    -> Sing xs
    -> U.DStruct_
    -> TypeCheckMonad (DatumStruct xs (abt '[]) (HData' t))
checkDatumStruct typA typ d =
    case (d, typ) of
    (U.Et d1 d2, SEt typ1 typ2) ->
        Et <$> checkDatumFun typA typ1 d1 <*> checkDatumStruct typA typ2 d2
    (U.Et _ _  , _            ) -> failwith_ "expected datum of `et' type"
    (U.Done    , SDone        ) -> return Done
    (U.Done    , _            ) -> failwith_ "expected datum of `done' type"
-- Check a single datum field. An 'I' (identity) field is checked
-- against the whole datum type @typA@; a 'K' (constant) field is checked
-- against its own type.
checkDatumFun
    :: forall x t
    .  Sing (HData' t)
    -> Sing x
    -> U.DFun_
    -> TypeCheckMonad (DatumFun x (abt '[]) (HData' t))
checkDatumFun typA typ d =
    case (d, typ) of
    (U.Ident e1, SIdent     ) -> Ident <$> checkType_ typA e1
    (U.Ident _ , _          ) -> failwith_ "expected datum of `I' type"
    (U.Konst e1, SKonst typ1) -> Konst <$> checkType_ typ1 e1
    (U.Konst _ , _          ) -> failwith_ "expected datum of `K' type"
-- | Check one case branch. The pattern is checked against the scrutinee
-- type @patTyp@, and the body is checked against @bodyTyp@ under the
-- variables the pattern binds.
checkBranch
    :: (ABT Term abt)
    => Sing a
    -> Sing b
    -> U.Branch
    -> TypeCheckMonad (Branch a abt b)
checkBranch patTyp bodyTyp (U.Branch_ pat body) = do
    SP checkedPat boundVars <- checkPattern patTyp pat
    checkedBody <- checkBinders boundVars bodyTyp body
    return (Branch checkedPat checkedBody)
-- | Check a pattern against the type of the value it matches. The result
-- is the typed pattern together with the typed variables it binds, in
-- left-to-right order.
-- A variable pattern binds one variable at the scrutinee type, and a
-- wildcard binds none. A datum pattern requires an HData type and is
-- checked structurally by the helpers below.
checkPattern
:: Sing a
-> U.Pattern
-> TypeCheckMonad (SomePattern a)
checkPattern = \typA pat ->
case pat of
U.PVar x -> return $ SP PVar (Cons1 (makeVar (U.nameToVar x) typA) Nil1)
U.PWild -> return $ SP PWild Nil1
U.PDatum hint pat1 ->
case typA of
SData _ typ1 -> do
SPC pat1' xs <- checkPatternCode typA typ1 pat1
return $ SP (PDatum hint pat1') xs
-- No source span is available for patterns here, hence Nothing.
_ -> typeMismatch Nothing (Right typA) (Left "HData")
where
-- Constructor choice: walk the Inl/Inr chain of the sum.
-- NOTE(review): both failures arise when typ is not an SPlus, yet the
-- messages differ ("`sum'" vs "`zero'"). Consider unifying them.
checkPatternCode
:: Sing (HData' t)
-> Sing xss
-> U.PCode
-> TypeCheckMonad (SomePatternCode xss t)
checkPatternCode typA typ pat =
case pat of
U.PInr pat2 ->
case typ of
SPlus _ typ2 -> do
SPC pat2' xs <- checkPatternCode typA typ2 pat2
return $ SPC (PInr pat2') xs
_ -> failwith_ "expected pattern of `sum' type"
U.PInl pat1 ->
case typ of
SPlus typ1 _ -> do
SPS pat1' xs <- checkPatternStruct typA typ1 pat1
return $ SPC (PInl pat1') xs
_ -> failwith_ "expected pattern of `zero' type"
-- Constructor fields: concatenate the variables bound by each field
-- with append1, preserving left-to-right order.
checkPatternStruct
:: Sing (HData' t)
-> Sing xs
-> U.PStruct
-> TypeCheckMonad (SomePatternStruct xs t)
checkPatternStruct typA typ pat =
case pat of
U.PEt pat1 pat2 ->
case typ of
SEt typ1 typ2 -> do
SPF pat1' xs <- checkPatternFun typA typ1 pat1
SPS pat2' ys <- checkPatternStruct typA typ2 pat2
return $ SPS (PEt pat1' pat2') (append1 xs ys)
_ -> failwith_ "expected pattern of `et' type"
U.PDone ->
case typ of
SDone -> return $ SPS PDone Nil1
_ -> failwith_ "expected pattern of `done' type"
-- A single field: an 'I' field is matched at the whole datum type typA
-- and a 'K' field at its own type, both via checkPattern.
checkPatternFun
:: Sing (HData' t)
-> Sing x
-> U.PFun
-> TypeCheckMonad (SomePatternFun x t)
checkPatternFun typA typ pat =
case pat of
U.PIdent pat1 ->
case typ of
SIdent -> do
SP pat1' xs <- checkPattern typA pat1
return $ SPF (PIdent pat1') xs
_ -> failwith_ "expected pattern of `I' type"
U.PKonst pat1 ->
case typ of
SKonst typ1 -> do
SP pat1' xs <- checkPattern typ1 pat1
return $ SPF (PKonst pat1') xs
_ -> failwith_ "expected pattern of `K' type"
-- | Convert a term of type @typA@ to type @typB@. If 'findCoercion'
-- finds a coercion it is applied; otherwise a type mismatch is reported
-- (expected @typB@, found @typA@).
checkOrCoerce
    :: (ABT Term abt)
    => Maybe U.SourceSpan
    -> abt '[] a
    -> Sing a
    -> Sing b
    -> TypeCheckMonad (abt '[] b)
checkOrCoerce s e typA typB =
    maybe (typeMismatch s (Right typB) (Right typA))
          (\c -> return (unLC_ (coerceTo c (LC_ e))))
          (findCoercion typA typB)
-- | Like 'checkOrCoerce', but uses 'findEitherCoercion', so safe, unsafe
-- and mixed coercions are all accepted. If no coercion exists but both
-- types are measures, the term is bound with 'MBind' to a continuation
-- that re-checks @dirac x@ (with @x@ at the source element type) against
-- the target type. Anything else is a type mismatch.
checkOrUnsafeCoerce
    :: (ABT Term abt)
    => Maybe U.SourceSpan
    -> abt '[] a
    -> Sing a
    -> Sing b
    -> TypeCheckMonad (abt '[] b)
checkOrUnsafeCoerce s e typA typB =
    case findEitherCoercion typA typB of
        Just (Safe c) ->
            return . unLC_ . coerceTo c $ LC_ e
        Just (Unsafe c) ->
            return . unLC_ . coerceFrom c $ LC_ e
        Just (Mixed (_, cFrom, cTo)) ->
            return . unLC_ . coerceTo cTo . coerceFrom cFrom $ LC_ e
        Nothing ->
            case (typA, typB) of
                (SMeasure typ1, SMeasure _) -> do
                    -- NOTE(review): the binder is built with a fixed name id
                    -- of 0. The continuation body only mentions this variable,
                    -- but confirm the id cannot clash with the context.
                    let x = Variable (pack "") 0 U.SU
                    k <- checkBinder typ1 typB (bind x $ syn $ U.Dirac_ (var x))
                    return $ syn (MBind :$ e :* k :* End)
                _ -> typeMismatch s (Right typB) (Right typA)