module Language.Haskell.Djinn (
djinn,
djinns,
djinnD,
djinnsD
) where
import Data.List (nub, sortBy)
import Data.Ord (comparing)
import Data.Ratio ((%))
import Data.Set (Set, empty, singleton, union, toList)
import Language.Haskell.TH (
Name, Type(..), Dec(..), Pat(..), Exp(..), Body(..), Clause(..),
Match(..), Info(..), Con(..), TyVarBndr(..), Q,
newName, mkName, tupleTypeName, tupleDataName, reify, pprint, report)
import Control.Monad (forM)
import Language.Haskell.Djinn.HTypes (
HType(..), HPat(..), HExpr(..), HClause(..), HEnvironment,
termToHClause, hTypeToFormula, getBinderVars)
import Language.Haskell.Djinn.LJT (prove)
-- | Collect the names of the type constructors that occur in a type.
--
-- Tuple types contribute the name of their tuple type constructor
-- ('tupleTypeName'). Only the body of a @forall@ is inspected (its context is
-- ignored). Type variables, arrows and any other forms contribute nothing.
getConTs :: Type -> Set Name
getConTs ty = case ty of
  ConT n -> singleton n
  TupleT arity -> singleton (tupleTypeName arity)
  AppT fun arg -> union (getConTs fun) (getConTs arg)
  ForallT _ _ body -> getConTs body
  _ -> empty
-- | Translate a Template Haskell 'Type' into Djinn's 'HType'.
--
-- * @a -> b@ becomes 'HTArrow'.
-- * A tuple constructor of arity >= 2 applied to exactly that many arguments
--   becomes 'HTTuple'. The constructor may be written as 'TupleT' or as
--   @ConT@ of the tuple type name. Any arity is accepted, not just 2..5.
-- * The unit type @TupleT 0@ becomes the empty 'HTTuple'.
-- * Any other application becomes 'HTApp'; @forall@ binders are dropped.
--
-- Calls 'error' for the 1-tuple and for forms Djinn does not model, such as
-- list types or unapplied tuple/arrow constructors.
hType :: Type -> HType
hType (TupleT 0) = HTTuple []
hType (TupleT 1) = error $ "djinn: 1-tuple should not exist"
hType (AppT (AppT ArrowT t1) t2) = HTArrow (hType t1) (hType t2)
hType t@(AppT t1 t2) =
  case tupleComponents t of
    Just ts -> HTTuple (map hType ts)
    -- Partial or over-saturated tuples, and ordinary applications.
    Nothing -> HTApp (hType t1) (hType t2)
hType (ForallT _ _ t) = hType t
hType (VarT v) = HTVar v
hType (ConT n) = HTCon n
hType t = error $ "djinn: unimplemented in hType: " ++ pprint t

-- | If a type is a tuple constructor of arity @n >= 2@ applied to exactly @n@
-- arguments, return those arguments left to right; otherwise 'Nothing'.
tupleComponents :: Type -> Maybe [Type]
tupleComponents = go []
  where
    -- Walk down the left spine, accumulating arguments in source order.
    go args (AppT f x) = go (x : args) f
    go args (TupleT n)
      | n >= 2 && length args == n = Just args
    go args (ConT c)
      | n >= 2 && c == tupleTypeName n = Just args
      where n = length args
    go _ _ = Nothing
-- | Build the Djinn environment for a type.
--
-- Every type constructor that the type mentions (per 'getConTs') is reified
-- with 'environment1', and the per-name results are concatenated.
environment :: Type -> Q HEnvironment
environment ty = do
  perName <- mapM environment1 (toList (getConTs ty))
  return (concat perName)
-- | Reify one type-constructor name and translate its definition into Djinn
-- environment entries.
--
-- The result is the entry for the name itself, followed by the entries for
-- every type mentioned in its constructor fields (data types and newtypes) or
-- in its expansion (type synonyms).
--
-- * Data types and newtypes become an 'HTUnion' of
--   @(constructor name, field types)@. Newtypes are treated as one-constructor
--   data types.
-- * Record and infix constructors contribute their fields positionally.
-- * Other constructor forms, such as existential 'ForallC', are rejected with
--   'fail'.
-- * Classes, class methods, primitive type constructors, data constructors,
--   variables and type variables are rejected with 'fail'.
--
-- NOTE(review): the recursion through 'environment' keeps no visited set. A
-- recursive type (e.g. @data L a = N | C a (L a)@) therefore appears to make
-- this loop at compile time. Confirm whether such types can reach here.
environment1 :: Name -> Q HEnvironment
environment1 name = do
  info <- reify name
  case info of
    ClassI _dec -> fail $ "djinn: unexpected ClassI"
    ClassOpI _n _t _c _fx -> fail $ "djinn: unexpected ClassOpI"
    TyConI dec ->
      case dec of
        DataD _cxt dName dVars dCtors _derivs -> dataEnv dName dVars dCtors
        -- For Djinn a newtype is just a data type with a single constructor.
        NewtypeD _cxt dName dVars dCtor _derivs -> dataEnv dName dVars [dCtor]
        TySynD tName tVars tType -> do
          es <- environment tType
          return $ [(tName, (map binderName tVars, hType tType))] ++ es
        x -> fail $ "djinn: unexpected TyConI " ++ show x
    PrimTyConI n _ar _l -> fail $ "djinn: unexpected PrimTyConI " ++ show n
    DataConI _n _t _tn _fx -> fail $ "djinn: unexpected DataConI"
    VarI _n _t _mdec _fx -> fail $ "djinn: unexpected VarI"
    TyVarI _tvName _tvType -> fail $ "djinn: unexpected TyVarI"
  where
    -- Entry for a data type (its constructors as a union), followed by the
    -- environments of all field types.
    dataEnv dName dVars dCtors = do
      dTypes <- forM dCtors $ \con -> do
        (cName, fieldTypes) <- conSignature con
        cEnv <- mapM environment fieldTypes
        return ((cName, map hType fieldTypes), cEnv)
      return $ [(dName, (map binderName dVars, HTUnion (map fst dTypes)))]
            ++ (concat . concatMap snd $ dTypes)

    -- Constructor name and its field types in positional order.
    conSignature (NormalC cName cFields) = return (cName, map snd cFields)
    conSignature (RecC cName cFields) = return (cName, [t | (_, _, t) <- cFields])
    conSignature (InfixC (_, t1) cName (_, t2)) = return (cName, [t1, t2])
    conSignature c = fail $ "djinn: unsupported constructor " ++ show c
-- | The variable bound by a type-variable binder; any kind annotation is
-- discarded.
binderName :: TyVarBndr -> Name
binderName bndr = case bndr of
  KindedTV v _kind -> v
  PlainTV v -> v
-- | Translate a Djinn pattern into a Template Haskell pattern.
--
-- 'HPApply' appends one argument to a constructor pattern. A chain of
-- applications headed by 'HPCon' therefore becomes a single 'ConP' whose
-- arguments are in application order.
--
-- Applying an argument to a pattern that does not translate to a 'ConP' is
-- malformed input. It is reported with a descriptive 'error' instead of an
-- opaque irrefutable-pattern failure.
pat :: HPat -> Pat
pat (HPVar s) = VarP s
pat (HPTuple ps) = TupP (map pat ps)
pat (HPAt s p) = AsP s (pat p)
pat (HPCon c) = ConP c []
pat (HPApply p q) =
  case pat p of
    ConP c ps -> ConP c (ps ++ [pat q])
    other -> error $ "djinn: pattern applied to a non-constructor pattern: "
                     ++ pprint other
-- | Translate a Djinn expression into a Template Haskell expression.
--
-- A tuple is built by applying the tuple data constructor of matching arity
-- ('tupleDataName') to its translated components. Each @case@ alternative
-- becomes a 'Match' with a plain body and no @where@ declarations.
expr :: HExpr -> Exp
expr hexp = case hexp of
  HELam ps body -> LamE (map pat ps) (expr body)
  HEApply fn arg -> AppE (expr fn) (expr arg)
  HECon c -> ConE c
  HEVar v -> VarE v
  HETuple es ->
    let tupleCon = ConE (tupleDataName (length es))
    in foldl AppE tupleCon (map expr es)
  HECase scrut alts ->
    CaseE (expr scrut) [ Match (pat p) (NormalB (expr rhs)) [] | (p, rhs) <- alts ]
-- | Shared implementation behind 'djinn', 'djinns', 'djinnD' and 'djinnsD'.
--
-- Arguments:
--
-- * @multi@ — passed to 'prove'. It also selects the result shape: a list of
--   all candidates ('ListE') or a single expression.
-- * @mStr@ — the function name used for 'termToHClause' and in messages. A
--   fresh name @djinn@ is used when 'Nothing'.
-- * @typ@ — the type to realise.
--
-- Steps:
--
-- 1. Build the environment from the constructors in @typ@.
-- 2. Convert the type to a formula and ask 'prove' for proofs.
-- 3. Turn each proof into a clause.
-- 4. Rank the clauses: first by the fraction of binders named @_@, then by
--    binder count, both ascending.
-- 5. Drop duplicates with 'nub', which keeps the best-ranked copy.
--
-- In single mode, several candidates trigger a warning and the best one is
-- used. No candidates triggers a compile error, and the placeholder
-- @let x = x in x@ is returned.
djinn0 :: Bool -> Maybe String -> Type -> Q Exp
djinn0 multi mStr typ = do
  env <- environment typ
  fname <- maybe (newName "djinn") (return . mkName) mStr
  let formula = hTypeToFormula env (hType typ)
  proofs <- prove multi [] formula
  let ranked = sortBy (comparing fst) (map (rank fname) proofs)
      candidates = nub (map snd ranked)
  if multi
    then return (ListE (map toExp candidates))
    else pickOne fname candidates
  where
    -- Pair a proof's clause with its ranking key.
    rank fname proof =
      let clause = termToHClause fname proof
          binders = getBinderVars clause
          nBinders = length binders
          nWild = length (filter (== underscore) binders)
          score | null binders = (0, 0)
                | otherwise = (nWild % nBinders, nBinders)
      in (score, clause)

    -- Render a clause as a lambda with unused binders turned into wildcards.
    toExp (HClause _ pats body) = wilderE (expr (HELam pats body))

    -- Single-solution mode: choose the best candidate, or report failure.
    pickOne fname candidates = case candidates of
      [] -> do
        report True $ "djinn: cannot realize: " ++ show fname ++ " :: " ++ pprint typ
        x <- newName "djinnError"
        return $ LetE [ValD (VarP x) (NormalB (VarE x)) [] ] (VarE x)
      [only] -> return (toExp only)
      (best:_) -> do
        report False $ "djinn: " ++ show (length candidates) ++ " options for: " ++ show fname ++ " :: " ++ pprint typ
        return (toExp best)
-- | The variable name @_@.
--
-- 'wilder' rewrites variable binders with this name into wildcard patterns
-- ('WildP') and strips it from as-patterns. 'djinn0' counts how many binders
-- carry it when ranking candidate solutions.
underscore :: Name
underscore = mkName "_"
-- | Replace binders named @_@ by wildcards throughout a pattern.
--
-- * @VarP _@ becomes 'WildP'.
-- * @_\@p@ becomes just @p@ (rewritten recursively).
-- * Tuple, constructor, infix, lazy, list and signature patterns are
--   traversed.
-- * Every other pattern (literals, other variables, ...) is returned as is.
wilder :: Pat -> Pat
wilder pt = case pt of
  VarP v | v == underscore -> WildP
  TupP subs -> TupP (map wilder subs)
  ConP con subs -> ConP con (map wilder subs)
  InfixP lhs op rhs -> InfixP (wilder lhs) op (wilder rhs)
  TildeP sub -> TildeP (wilder sub)
  AsP v sub
    | v == underscore -> wilder sub
    | otherwise -> AsP v (wilder sub)
  ListP subs -> ListP (map wilder subs)
  SigP sub ty -> SigP (wilder sub) ty
  _ -> pt
-- | Apply wildcard rewriting ('wilder') to every pattern reachable through
-- the supported expression forms.
--
-- The supported forms are: lambdas, applications, infix and section
-- expressions, tuples, conditionals, @let@, @case@, lists and type
-- signatures. Declarations go through 'wilderD' and alternatives through
-- 'wilderM'. Other expressions are returned unchanged.
wilderE :: Exp -> Exp
wilderE ex = case ex of
  LamE ps body -> LamE (map wilder ps) (wilderE body)
  AppE fn arg -> AppE (wilderE fn) (wilderE arg)
  InfixE lhs op rhs -> InfixE (fmap wilderE lhs) (wilderE op) (fmap wilderE rhs)
  TupE es -> TupE (map wilderE es)
  CondE c t e -> CondE (wilderE c) (wilderE t) (wilderE e)
  LetE decls body -> LetE (map wilderD decls) (wilderE body)
  CaseE scrut alts -> CaseE (wilderE scrut) (map wilderM alts)
  ListE es -> ListE (map wilderE es)
  SigE e ty -> SigE (wilderE e) ty
  _ -> ex
-- | Apply wildcard rewriting to one @case@ alternative.
--
-- The pattern goes through 'wilder', the body through 'wilderB', and the
-- local declarations through 'wilderD'.
wilderM :: Match -> Match
wilderM alt = case alt of
  Match p body decls ->
    let p' = wilder p
        body' = wilderB body
        decls' = map wilderD decls
    in Match p' body' decls'
-- | Apply wildcard rewriting inside a declaration, consistently with 'wilderE'
-- and 'wilderM'.
--
-- * Value bindings have their pattern rewritten by 'wilder'.
-- * Function clauses have their argument patterns rewritten by 'wilder'.
-- * Bodies go through 'wilderB'; nested @where@ declarations are handled
--   recursively.
-- * Any other declaration is returned unchanged.
wilderD :: Dec -> Dec
wilderD (ValD p b ds) = ValD (wilder p) (wilderB b) (map wilderD ds)
wilderD (FunD n cs) = FunD n (map wilderC cs)
  where
    wilderC (Clause ps b ds) = Clause (map wilder ps) (wilderB b) (map wilderD ds)
wilderD d = d
-- | Apply wildcard rewriting to the right-hand side(s) of a body.
--
-- Without this, the bodies of @case@ alternatives (produced by 'expr' as
-- 'NormalB') were never traversed. Lambda binders named @_@ inside a branch
-- then escaped 'wilder'. Guards are left untouched; only the guarded
-- expressions are rewritten.
wilderB :: Body -> Body
wilderB (NormalB e) = NormalB (wilderE e)
wilderB (GuardedB gs) = GuardedB [ (g, wilderE e) | (g, e) <- gs ]
-- | Splice a single expression of the quoted type, synthesised by Djinn.
--
-- When several candidates exist, the best-ranked one is used and a warning
-- is reported. When none exists, a compile error is reported.
djinn :: Q Type -> Q Exp
djinn qtyp = qtyp >>= djinn0 False Nothing
-- | Splice a list expression holding every distinct candidate that Djinn
-- finds for the quoted type. The list is ordered by 'djinn0''s ranking,
-- best first.
djinns :: Q Type -> Q Exp
djinns qtyp = qtyp >>= djinn0 True Nothing
-- | Generate a top-level signature @str :: typ@ and a definition of @str@
-- synthesised by Djinn, in single-solution mode (as for 'djinn').
djinnD :: String -> Q Type -> Q [Dec]
djinnD str qtyp = do
  typ <- qtyp
  body <- djinn0 False (Just str) typ
  let fname = mkName str
      clause = Clause [] (NormalB body) []
  return [SigD fname typ, FunD fname [clause]]
-- | Like 'djinnD', but defines @str@ as the list of all candidates (as for
-- 'djinns'), so the declared type is a list of @typ@.
--
-- A leading @forall@ and its context stay outside the list:
-- @forall a. ctx => t@ becomes @forall a. ctx => [t]@.
djinnsD :: String -> Q Type -> Q [Dec]
djinnsD str qtyp = do
  let name = mkName str
  typ <- qtyp
  exp' <- djinn0 True (Just str) typ
  -- A monomorphic type has no 'ForallT' wrapper, so wrap it directly instead
  -- of failing an irrefutable pattern match.
  let listTyp = case typ of
        ForallT vs ctx t -> ForallT vs ctx (AppT ListT t)
        t -> AppT ListT t
  return
    [ SigD name listTyp
    , FunD name [ Clause [] (NormalB exp') [] ] ]