{-# LANGUAGE CPP
, DataKinds
, KindSignatures
, GADTs
, ScopedTypeVariables
, Rank2Types
, FlexibleContexts
, PolyKinds
, ViewPatterns
#-}
{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
module Language.Hakaru.Syntax.TypeOf
(
typeOf
, typeOfReducer
, getTermSing
) where
import qualified Data.Foldable as F
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative (Applicative(..), (<$>))
#endif
import Language.Hakaru.Syntax.IClasses (Pair2(..), fst2, snd2)
import Language.Hakaru.Syntax.Variable (varType)
import Language.Hakaru.Syntax.ABT (ABT, caseBind, paraABT)
import Language.Hakaru.Types.DataKind (Hakaru())
import Language.Hakaru.Types.HClasses (sing_HSemiring)
import Language.Hakaru.Types.Sing (Sing(..), sUnMeasure, sUnit, sPair)
import Language.Hakaru.Types.Coercion
(singCoerceCod, singCoerceDom, Coerce(..))
import Language.Hakaru.Syntax.Datum (Datum(..), Branch(..))
import Language.Hakaru.Syntax.Reducer
import Language.Hakaru.Syntax.AST (Term(..), SCon(..), SArgs(..)
,typeOfTransform
,getSArgsSing)
import Language.Hakaru.Syntax.AST.Sing
(sing_PrimOp, sing_ArrayOp, sing_MeasureOp, sing_NaryOp, sing_Literal)
-- | Type of a closed term.
--
-- Partial wrapper around 'typeOf_': if inference yields @Left err@, this
-- calls 'error' with the message @"typeOf: " ++ err@.
typeOf :: (ABT Term abt) => abt '[] a -> Sing a
typeOf = either (\err -> error ("typeOf: " ++ err)) id . typeOf_
-- | Total type inference for a closed term, reporting failure as 'Left'.
--
-- Implemented as a 'paraABT' fold whose per-node result is a 'LiftSing':
--
--   * a variable contributes its own type ('varType');
--   * a binder passes its body's result through unchanged;
--   * every other node is typed by 'getTermSing', which reads the results
--     already computed for the node's subterms.
typeOf_ :: (ABT Term abt) => abt '[] a -> Either String (Sing a)
typeOf_ e0 =
    unLiftSing $
        paraABT
            (\v -> LiftSing (Right (varType v)))
            (\_ _ body -> LiftSing (unLiftSing body))
            (\t -> LiftSing (getTermSing unLiftSing t))
            e0
-- | Result type of a 'Reducer', computed structurally from the reducer alone:
--
--   * 'Red_Nop' produces unit;
--   * 'Red_Add' produces the type carried by its semiring witness;
--   * 'Red_Index' wraps its inner reducer's type in an array;
--   * 'Red_Fanout' and 'Red_Split' pair the types of their two sub-reducers.
typeOfReducer
    :: Reducer abt xs a
    -> Sing a
typeOfReducer red =
    case red of
        Red_Nop           -> sUnit
        Red_Add h _       -> sing_HSemiring h
        Red_Index _ _ sub -> SArray (typeOfReducer sub)
        Red_Fanout l r    -> sPair (typeOfReducer l) (typeOfReducer r)
        Red_Split _ l r   -> sPair (typeOfReducer l) (typeOfReducer r)
-- | Per-node result of the 'paraABT' fold in 'typeOf_': an attempted type
-- for a subterm, with failures carried as 'Left'.
--
-- The @xs@ index does not occur in the field (it is phantom). It only gives
-- the wrapper the two-index shape that 'paraABT' results must have.
newtype LiftSing (xs :: [Hakaru]) (a :: Hakaru) where
    LiftSing :: { unLiftSing :: Either String (Sing a) } -> LiftSing xs a
-- | Type of a single 'Term' node whose immediate subterms have already been
-- paired ('Pair2') with some result @r@.
--
-- The supplied function extracts a (possibly failed) type from each
-- subterm's @r@. Many cases read the type off the node itself: operators,
-- literals, explicit type arguments, coercions with a known end-point, and
-- reducers. The remaining cases derive it from subterm types.
--
-- 'ArrayLiteral_', 'Case_' and 'Superpose_' take the type of the first
-- element, branch or component whose type can be determined (see 'tryAll').
--
-- All failures are reported as 'Left', including any 'SCon' application
-- not matched by one of the cases below.
getTermSing
    :: forall abt r
    .  (ABT Term abt)
    => (forall xs a. r xs a -> Either String (Sing a))
    -> forall a
    .  Term (Pair2 abt r) a
    -> Either String (Sing a)
getTermSing singify = go
  where
    -- Type of an annotated subterm, via the caller-supplied extractor.
    getSing :: forall xs a. Pair2 abt r xs a -> Either String (Sing a)
    getSing = singify . snd2
    {-# INLINE getSing #-}

    -- A case branch has the type of its body.
    getBranchSing
        :: forall a b
        .  Branch a (Pair2 abt r) b
        -> Either String (Sing b)
    getBranchSing (Branch _ e) = getSing e
    {-# INLINE getBranchSing #-}

    go :: forall a. Term (Pair2 abt r) a -> Either String (Sing a)
    -- The domain comes from the bound variable's type; the codomain comes
    -- from the annotated body.
    go (Lam_ :$ r1 :* End) =
        caseBind (fst2 r1) $ \x _ ->
            SFun (varType x) <$> getSing r1
    -- An application has the codomain of the applied function's type.
    go (App_ :$ r1 :* _ :* End) = do
        typ1 <- getSing r1
        case typ1 of
            SFun _ typ3 -> return typ3
    go (Let_ :$ _ :* r2 :* End) = getSing r2
    -- Use the type the coercion itself provides when available; otherwise
    -- coerce the subterm's type.
    go (CoerceTo_ c :$ r1 :* End) =
        maybe (coerceTo c <$> getSing r1) return (singCoerceCod c)
    go (UnsafeFrom_ c :$ r1 :* End) =
        maybe (coerceFrom c <$> getSing r1) return (singCoerceDom c)
    -- For the operator nodes, the second component of the operator's
    -- singleton pair is used as the node's type.
    go (PrimOp_ o :$ _) = return . snd $ sing_PrimOp o
    go (ArrayOp_ o :$ _) = return . snd $ sing_ArrayOp o
    go (MeasureOp_ o :$ _) =
        return . SMeasure . snd $ sing_MeasureOp o
    go (Dirac :$ r1 :* End) = SMeasure <$> getSing r1
    go (MBind :$ _ :* r2 :* End) = getSing r2
    -- The body is a measure over elements; the plate is a measure over
    -- arrays of those elements.
    go (Plate :$ _ :* r2 :* End) =
        SMeasure . SArray . sUnMeasure <$> getSing r2
    go (Integrate :$ _) = return SProb
    go (Summate _ h :$ _) = return $ sing_HSemiring h
    go (Product _ h :$ _) = return $ sing_HSemiring h
    go (Transform_ t :$ as) =
        typeOfTransform t <$> getSArgsSing getSing as
    go (NaryOp_ o _) = return $ sing_NaryOp o
    go (Literal_ v) = return $ sing_Literal v
    go (Empty_ typ) = return typ
    go (Array_ _ r2) = SArray <$> getSing r2
    go (ArrayLiteral_ es) = SArray <$> tryAll "ArrayLiteral_" getSing es
    go (Bucket _ _ r) = return (typeOfReducer r)
    go (Datum_ (Datum _ typ _)) = return typ
    go (Case_ _ bs) = tryAll "Case_" getBranchSing bs
    go (Superpose_ pes) = tryAll "Superpose_" (getSing . snd) pes
    go (Reject_ typ) = return typ
    -- Any other 'SCon' application is not handled above. Report it through
    -- the same 'Either' channel as every other failure, so callers can
    -- recover instead of the whole computation dying in 'error'.
    go (_ :$ _) = Left "getTermSing: the impossible happened"
-- | Return the first 'Right' produced by applying @f@ to the elements of a
-- 'F.Foldable', in 'F.foldr' order.
--
-- The fold is lazy in the not-yet-visited elements: once a 'Right' is
-- found, later elements are never examined.
--
-- If no element succeeds, the result is a 'Left' of the form
-- @"no unique type for " ++ name@.
--
-- * For empty input, that is the whole message.
-- * Otherwise the individual failure messages follow, in order, after
--   @": "@ and separated by @"; "@, so the cause of the failure is not lost.
tryAll
    :: F.Foldable f
    => String
    -> (a -> Either String b)
    -> f a
    -> Either String b
tryAll name f xs =
    case F.foldr step (Left []) xs of
        Right b   -> Right b
        Left errs -> Left (report errs)
  where
    -- Only inspects 'rest' when the current element failed, so the first
    -- success short-circuits the remainder of the fold.
    step x rest =
        case f x of
            Right b  -> Right b
            Left err -> either (Left . (err :)) Right rest

    report []           = header
    report (err : errs) = header ++ ": " ++ err ++ concatMap ("; " ++) errs

    header = "no unique type for " ++ name
{-# INLINE tryAll #-}