module Test.Extrapolate.Derive
( deriveGeneralizable
, deriveGeneralizableIfNeeded
, deriveGeneralizableCascading
)
where
import Test.Extrapolate.Core hiding (isInstanceOf)
import Test.Extrapolate.TypeBinding
import Language.Haskell.TH
import Test.LeanCheck.Basic
import Test.LeanCheck.Utils.TypeBinding
import Control.Monad (unless, liftM, liftM2, filterM)
import Data.List (delete,nub,sort)
import Data.Char (toLower)
import Data.Functor ((<$>))
import Data.Typeable
import Test.Extrapolate.Utils (foldr0)
-- | Derives a 'Generalizable' instance for the type with the given 'Name'.
--
--   When an instance already exists a warning is reported and no
--   declarations are generated.  Argument types are not cascaded into.
deriveGeneralizable :: Name -> DecsQ
deriveGeneralizable t = deriveGeneralizableX warnExisting cascade t
  where
    warnExisting = True
    cascade      = False
-- | Same as 'deriveGeneralizable', except that an already existing
--   instance is skipped silently, without reporting a warning.
deriveGeneralizableIfNeeded :: Name -> DecsQ
deriveGeneralizableIfNeeded t = deriveGeneralizableX warnExisting cascade t
  where
    warnExisting = False
    cascade      = False
-- | Derives a 'Generalizable' instance for the type with the given 'Name',
--   cascading the derivation to argument types that lack an instance.
--
--   When an instance for the given type already exists a warning is
--   reported and no declarations are generated.
deriveGeneralizableCascading :: Name -> DecsQ
deriveGeneralizableCascading t = deriveGeneralizableX warnExisting cascade t
  where
    warnExisting = True
    cascade      = True
-- | Worker shared by the exported derivation functions.
--
--   * @warnExisting@: report a warning when the instance already exists;
--   * @cascade@: derive through 'reallyDeriveGeneralizableCascading'
--     instead of 'reallyDeriveGeneralizable'.
--
--   Nothing is generated when the instance already exists.
deriveGeneralizableX :: Bool -> Bool -> Name -> DecsQ
deriveGeneralizableX warnExisting cascade t = do
  alreadyThere <- t `isInstanceOf` ''Generalizable
  case (alreadyThere, cascade) of
    (True, _) -> do
      if warnExisting
        then reportWarning alreadyExistsMsg
        else return ()
      return []
    (False, True)  -> reallyDeriveGeneralizableCascading t
    (False, False) -> reallyDeriveGeneralizable t
  where
    alreadyExistsMsg = "Instance Generalizable " ++ show t
                    ++ " already exists, skipping derivation"
-- | Builds the 'Generalizable' instance declaration for type @t@ without
--   first checking whether such an instance already exists.
--
--   Four partial instance declarations are generated, one per method
--   (@name@, @expr@, @background@ and @instances@), and merged into a
--   single instance with 'mergeI'.  The instance context is attached last
--   with '|=>|'.
reallyDeriveGeneralizable :: Name -> DecsQ
reallyDeriveGeneralizable t = do
  -- Eq and Ord membership is tested on the type with every type argument
  -- replaced by unit (see isInstanceOf and normalizeTypeUnits).
  isEq <- t `isInstanceOf` ''Eq
  isOrd <- t `isInstanceOf` ''Ord
  -- nt is the type applied to fresh type variables, vs are those variables.
  (nt,vs) <- normalizeType t
  -- Instance context: one constraint per (class, type variable) pair.
  -- Generalizable is always required; Eq and Ord only when the type
  -- itself has those instances.
#if __GLASGOW_HASKELL__ >= 710
  cxt <- sequence [ [t| $(conT c) $(return v) |]
#else
  cxt <- sequence [ classP c [return v]
#endif
                  | c <- ''Generalizable:([''Eq | isEq] ++ [''Ord | isOrd])
                  , v <- vs]
  -- Constructors paired with fresh names for their arguments.
  cs <- typeConstructorsArgNames t
  -- Name bound by the as-pattern to the whole argument of expr.
  asName <- newName "x"
  -- One single-clause instance per constructor; mergeI joins them into one
  -- instance and mergeIFns then joins the clauses into one function.
  -- NOTE(review): foldr1 fails when the type has no constructors.
  let generalizableExpr = mergeIFns $ foldr1 mergeI
        -- The operator looked up is "-", then one '>' per constructor
        -- argument, then ":" (so "-:", "->:", "->>:", ...).
        -- Presumably these are the type-binding operators from the imported
        -- TypeBinding modules -- TODO confirm.
        [ do retTypeOf <- lookupValN $ "-" ++ replicate (length ns) '>' ++ ":"
             let exprs = [[| expr $(varE n) |] | n <- ns]
             let conex = [| $(varE retTypeOf) $(conE c) $(varE asName) |]
             let root = [| constant $(stringE $ showJustName c) $(conex) |]
             -- Applies the constructor constant to each argument expression
             -- from left to right using (:$).
             let rhs = foldl (\e1 e2 -> [| $e1 :$ $e2 |]) root exprs
             [d| instance Generalizable $(return nt) where
                   expr $(asP asName $ conP c (map varP ns)) = $rhs |]
        | (c,ns) <- cs
        ]
  -- Background constants: equality operators when the type has Eq,
  -- plus ordering operators when it also has Ord.
  let generalizableBackground = do
        n <- newName "x"
        case (isEq, isOrd) of
          (True, True) ->
            [d| instance Generalizable $(return nt) where
                  background $(varP n) = [ constant "==" ((==) -:> $(varE n))
                                         , constant "/=" ((/=) -:> $(varE n))
                                         , constant "<" ((<) -:> $(varE n))
                                         , constant "<=" ((<=) -:> $(varE n)) ] |]
          (True, False) ->
            [d| instance Generalizable $(return nt) where
                  background $(varP n) = [ constant "==" ((==) -:> $(varE n))
                                         , constant "/=" ((/=) -:> $(varE n)) ] |]
          (False, False) ->
            [d| instance Generalizable $(return nt) where
                  background $(varP n) = [] |]
          -- The remaining combination is Ord without Eq.
          _ -> error $ "reallyDeriveGeneralizable " ++ show t ++ ": the impossible happened"
  -- instances x = this x $ f1 . f2 . ... with one letin expression per
  -- constructor that has at least one argument.
  -- foldr0 presumably falls back to id when the list is empty
  -- -- TODO confirm against Test.Extrapolate.Utils.
  let generalizableInstances = do
        n <- newName "x"
        let lets = [letin n c ns | (c,ns) <- cs, not (null ns)]
        let rhs = foldr0 (\e1 e2 -> [| $e1 . $e2 |]) [|id|] lets
        [d| instance Generalizable $(return nt) where
              instances $(varP n) = this $(varE n) $ $rhs |]
  -- name ignores its argument and returns vname.
  let generalizableName = do
        [d| instance Generalizable $(return nt) where
              name _ = $(stringE vname) |]
  cxt |=>| (generalizableName `mergeI` generalizableExpr
                              `mergeI` generalizableBackground
                              `mergeI` generalizableInstances)
  where
    -- Keeps only the characters after the last dot of the shown name.
    showJustName = reverse . takeWhile (/= '.') . reverse . show
    -- First character of the unqualified type name, in lower case.
    vname = map toLower . take 1 $ showJustName t
    -- Builds the expression
    --   let C x1 .. xn = C undefined .. undefined `asTypeOf` x
    --   in instances x1 . ... . instances xn
    -- foldl1 requires ns to be non-empty; the only call site filters out
    -- constructors without arguments.
    letin :: Name -> Name -> [Name] -> ExpQ
    letin x c ns = do
      und <- VarE <$> lookupValN "undefined"
      let lhs = conP c (map varP ns)
      let rhs = return $ foldl AppE (ConE c) [und | _ <- ns]
      let bot = foldl1 (\e1 e2 -> [| $e1 . $e2 |])
                       [ [| instances $(varE n) |] | n <- ns ]
      [| let $lhs = $rhs `asTypeOf` $(varE x) in $bot |]
-- | Lists the constructors of a type, each paired with one freshly
--   generated variable name per constructor argument.
typeConstructorsArgNames :: Name -> Q [(Name,[Name])]
typeConstructorsArgNames t = typeConstructors t >>= mapM withFreshArgs
  where
    withFreshArgs (con, argTypes) = do
      argNames <- mapM (const (newName "x")) argTypes
      return (con, argNames)
-- | Looks up a value name in the current splice scope,
--   failing in 'Q' when the name cannot be found.
lookupValN :: String -> Q Name
lookupValN s = lookupValueName s >>= maybe notFound return
  where
    notFound = fail ("lookupValN: cannot find " ++ s)
-- NOTE(review): this type is not exported and nothing in the visible part
-- of this module refers to it; presumably a leftover example -- confirm
-- before removing.
data Bla
  = Bla Int Int
  | Ble Char
  deriving (Eq, Ord, Show)
-- | Derives 'Generalizable' for @t@ and, transitively, for every type
--   reachable through its constructor arguments that is not yet an
--   instance.  @t@ itself always comes first; type synonyms are dropped
--   from the list of types to derive.
reallyDeriveGeneralizableCascading :: Name -> DecsQ
reallyDeriveGeneralizableCascading t = do
  missing <- t `typeConCascadingArgsThat` (`isntInstanceOf` ''Generalizable)
  targets <- filterM (liftM not . isTypeSynonym) (t : delete t missing)
  decss   <- mapM reallyDeriveGeneralizable targets
  return (concat decss)
-- | Lists the type constructor names mentioned by a type:
--   in the right-hand side for a type synonym, or in the constructor
--   argument types for a data or newtype declaration.
--   Results are combined with 'nubMerge' and 'nubMerges'.
typeConArgs :: Name -> Q [Name]
typeConArgs t = do
  synonym <- isTypeSynonym t
  if synonym
    then namesIn <$> typeSynonymType t
    else (nubMerges . map namesIn . concatMap snd) <$> typeConstructors t
  where
    -- Type constructor names occurring in a type;
    -- forms not listed below contribute nothing.
    namesIn :: Type -> [Name]
    namesIn (AppT l r)      = namesIn l `nubMerge` namesIn r
    namesIn (SigT ty _)     = namesIn ty
    namesIn (VarT _)        = []
    namesIn (ConT n)        = [n]
#if __GLASGOW_HASKELL__ >= 800
    namesIn (InfixT l _ r)  = namesIn l `nubMerge` namesIn r
    namesIn (UInfixT l _ r) = namesIn l `nubMerge` namesIn r
    namesIn (ParensT ty)    = namesIn ty
#endif
    namesIn _               = []
-- | The names returned by 'typeConArgs' for which the monadic
--   predicate @p@ holds.
typeConArgsThat :: Name -> (Name -> Q Bool) -> Q [Name]
typeConArgsThat t p = typeConArgs t >>= filterM p
typeConCascadingArgsThat :: Name -> (Name -> Q Bool) -> Q [Name]
t `typeConCascadingArgsThat` p = do
ts <- t `typeConArgsThat` p
let p' t' = do is <- p t'; return $ t' `notElem` (t:ts) && is
tss <- mapM (`typeConCascadingArgsThat` p') ts
return $ nubMerges (ts:tss)
-- | Applies a type constructor to as many fresh type variables as its
--   arity, returning the applied type together with those variables.
--   Variable name hints are the single letters @a@ to @z@, cycled.
normalizeType :: Name -> Q (Type, [Type])
normalizeType t = do
  arity  <- typeArity t
  tyVars <- mapM (liftM VarT . newName) (take arity varNames)
  return (foldl AppT (ConT t) tyVars, tyVars)
  where
    varNames :: [String]
    varNames = [[c] | c <- cycle ['a'..'z']]
-- | Applies a type constructor to as many unit types as its arity.
normalizeTypeUnits :: Name -> Q Type
normalizeTypeUnits t = applyToUnits <$> typeArity t
  where
    applyToUnits arity = foldl AppT (ConT t) (replicate arity unit)
    unit = TupleT 0
-- | Tells whether the type named @tn@, with all of its type arguments
--   set to unit, is an instance of the class named @cl@.
isInstanceOf :: Name -> Name -> Q Bool
isInstanceOf tn cl = normalizeTypeUnits tn >>= \ty -> isInstance cl [ty]
-- | Negation of 'isInstanceOf'.
isntInstanceOf :: Name -> Name -> Q Bool
isntInstanceOf tn cl = not <$> (tn `isInstanceOf` cl)
-- | Number of type variables bound by a data, newtype or type synonym
--   declaration.  Any other kind of name results in an 'error' once the
--   returned number is evaluated.
typeArity :: Name -> Q Int
typeArity t = do
  info <- reify t
  return (length (binders info))
  where
#if __GLASGOW_HASKELL__ < 800
    binders (TyConI (DataD    _ _ ks _ _)) = ks
    binders (TyConI (NewtypeD _ _ ks _ _)) = ks
#else
    binders (TyConI (DataD    _ _ ks _ _ _)) = ks
    binders (TyConI (NewtypeD _ _ ks _ _ _)) = ks
#endif
    binders (TyConI (TySynD _ ks _)) = ks
    binders _ = error $ "error (typeArity): symbol " ++ show t
                     ++ " is not a newtype, data or type synonym"
-- | Lists the data constructors of a @data@ or @newtype@ declaration,
--   each paired with the types of its fields.
--
--   Only normal, record and infix constructors are understood.  Any other
--   constructor form (for instance one wrapped in 'ForallC') raises a
--   descriptive 'error' rather than an anonymous pattern-match failure.
--   A name that is neither data nor newtype also raises an 'error'.
typeConstructors :: Name -> Q [(Name,[Type])]
typeConstructors t = do
  ti <- reify t
  return . map simplify $ case ti of
#if __GLASGOW_HASKELL__ < 800
    TyConI (DataD    _ _ _ cs _) -> cs
    TyConI (NewtypeD _ _ _ c  _) -> [c]
#else
    TyConI (DataD    _ _ _ _ cs _) -> cs
    TyConI (NewtypeD _ _ _ _ c  _) -> [c]
#endif
    _ -> error $ "error (typeConstructors): symbol " ++ show t
              ++ " is neither newtype nor data"
  where
    -- Reduces a constructor to its name and the list of its field types.
    simplify (NormalC n ts)   = (n, map snd ts)
    simplify (RecC n ts)      = (n, map trd ts)
    simplify (InfixC t1 n t2) = (n, [snd t1, snd t2])
    -- Catch-all so that unsupported constructor forms are reported clearly.
    simplify _ = error $ "error (typeConstructors): symbol " ++ show t
                      ++ " has a constructor that is not a normal,"
                      ++ " record or infix constructor"
    trd (_,_,z) = z
-- | Tells whether a name reifies to a type synonym declaration.
isTypeSynonym :: Name -> Q Bool
isTypeSynonym t = isSyn <$> reify t
  where
    isSyn (TyConI (TySynD _ _ _)) = True
    isSyn _                       = False
-- | Right-hand side of a type synonym declaration.  For any other kind of
--   name, evaluating the returned type results in an 'error'.
typeSynonymType :: Name -> Q Type
typeSynonymType t = synonymRhs <$> reify t
  where
    synonymRhs (TyConI (TySynD _ _ t')) = t'
    synonymRhs _ = error $ "error (typeSynonymType): symbol " ++ show t
                        ++ " is not a type synonym"
-- | Appends the given context to the context of every instance
--   declaration produced by the second argument.  Declarations that are
--   not instances pass through unchanged.
(|=>|) :: Cxt -> DecsQ -> DecsQ
extra |=>| qds = map addCxt <$> qds
  where
#if __GLASGOW_HASKELL__ < 800
    addCxt (InstanceD own hd body) = InstanceD (own ++ extra) hd body
#else
    addCxt (InstanceD ovl own hd body) = InstanceD ovl (own ++ extra) hd body
#endif
    addCxt other = other
-- | Within each instance declaration, joins all function declarations
--   into a single one by concatenating their clauses; the function name
--   is taken from the first declaration.
--
--   Partial: only matches instance declarations whose body is a non-empty
--   list of function declarations.
mergeIFns :: DecsQ -> DecsQ
mergeIFns qds = map collapse <$> qds
  where
#if __GLASGOW_HASKELL__ < 800
    collapse (InstanceD cx hd fns) = InstanceD cx hd [foldr1 joinClauses fns]
#else
    collapse (InstanceD ovl cx hd fns) = InstanceD ovl cx hd [foldr1 joinClauses fns]
#endif
    joinClauses (FunD n cs1) (FunD _ cs2) = FunD n (cs1 ++ cs2)
-- | Merges two single-instance declaration lists into one instance.
--   The head and context come from the first instance; the body is the
--   first body followed by the second.
--
--   Partial: both arguments must produce exactly one instance declaration.
mergeI :: DecsQ -> DecsQ -> DecsQ
qds1 `mergeI` qds2 = liftM2 combine qds1 qds2
  where
#if __GLASGOW_HASKELL__ < 800
    combine [InstanceD c ts ds1] [InstanceD _ _ ds2] = [InstanceD c ts (ds1 ++ ds2)]
#else
    combine [InstanceD o c ts ds1] [InstanceD _ _ _ ds2] = [InstanceD o c ts (ds1 ++ ds2)]
#endif
-- | Appends the given declarations to the body of every instance
--   declaration produced by the first argument.  Declarations that are
--   not instances pass through unchanged.
whereI :: DecsQ -> [Dec] -> DecsQ
qds `whereI` w = map extend <$> qds
  where
#if __GLASGOW_HASKELL__ < 800
    extend (InstanceD c ts ds) = InstanceD c ts (ds ++ w)
#else
    extend (InstanceD o c ts ds) = InstanceD o c ts (ds ++ w)
#endif
    extend d = d
-- | Merges two lists, keeping a single copy of elements that meet at the
--   heads of both.  For strictly ascending inputs the result is their
--   strictly ascending union.
nubMerge :: Ord a => [a] -> [a] -> [a]
nubMerge as bs = case (as, bs) of
  ([], _) -> bs
  (_, []) -> as
  (x:xs, y:ys)
    | x < y     -> x : nubMerge xs bs
    | x > y     -> y : nubMerge as ys
    | otherwise -> x : nubMerge xs ys
-- | Right-nested 'nubMerge' of a list of lists; the empty list of lists
--   gives the empty list.
nubMerges :: Ord a => [[a]] -> [a]
nubMerges []       = []
nubMerges (xs:xss) = xs `nubMerge` nubMerges xss