module Math.Combinatorics.Species.TH
( deriveDefaultSpecies
, deriveSpecies
) where
#if MIN_VERSION_numeric_prelude(0,2,0)
import NumericPrelude hiding (cycle)
#else
import NumericPrelude
import PreludeBase hiding (cycle)
#endif
import Math.Combinatorics.Species.AST
import Math.Combinatorics.Species.AST.Instances ()
import Math.Combinatorics.Species.Class
import Math.Combinatorics.Species.Enumerate
import Math.Combinatorics.Species.Structures
import Control.Applicative (Applicative (..),
(<$>), (<*>))
import Control.Arrow (first, (***))
import Control.Monad (zipWithM)
import Data.Char (toLower)
import Data.Maybe (isJust)
import Data.Typeable
import Language.Haskell.TH
import Language.Haskell.TH.Syntax (lift)
-- | Report a compile-time error and abort the current splice.
--
--   'reportError' hands the message to GHC as a compile error; the trailing
--   'error' then stops the computation, which is why the result can be
--   used at any type @a@.
errorQ :: String -> Q a
errorQ msg = do
  reportError msg
  error msg
-- | Intermediate description of the shape of a @* -> *@ type constructor.
--   It is produced from the reified declaration ('nameToStruct') and later
--   turned into a species ('structToSp') and into the clauses of 'iso'
--   ('mkIsoMatches').
data Struct
  = SId                          -- ^ the type parameter itself
  | SList                        -- ^ the bare list type constructor
  | SConst Type                  -- ^ a field type treated as a constant
  | SEnum Type                   -- ^ some type @t@ applied directly to the parameter
  | SSumProd [(Name, [Struct])]  -- ^ data constructors with the shapes of their fields
  | SComp Struct Struct          -- ^ a type applied to a compound type argument
  | SSelf                        -- ^ a recursive occurrence of the type being processed
  deriving Show
-- | Reify a type name and describe its shape as a 'Struct'. Anything that
--   does not reify to a type constructor is reported as a compile error.
nameToStruct :: Name -> Q Struct
nameToStruct nm = do
  info <- reify nm
  case info of
    TyConI dec -> decToStruct nm dec
    _          -> errorQ (show nm ++ " is not a type constructor.")
-- | Describe a reified declaration as a 'Struct'.
--
--   Only @data@, @newtype@ and @type@ declarations with exactly one type
--   parameter are accepted. The first argument is used only in the error
--   message; the declaration's own name is what gets threaded through.
decToStruct :: Name -> Dec -> Q Struct
decToStruct outer dec = case dec of
  DataD _ tyNm [bndr] cons _ ->
    SSumProd <$> mapM (conToStruct tyNm (tyVarNm bndr)) cons
  NewtypeD _ tyNm [bndr] con _ ->
    (\c -> SSumProd [c]) <$> conToStruct tyNm (tyVarNm bndr) con
  TySynD tyNm [bndr] ty ->
    tyToStruct tyNm (tyVarNm bndr) ty
  _ ->
    errorQ $ "Processing " ++ show outer ++ ": Only type constructors of kind * -> * are supported."
-- | The name bound by a type-variable binder, whether or not it carries a
--   kind annotation.
tyVarNm :: TyVarBndr -> Name
tyVarNm bndr = case bndr of
  KindedTV n _ -> n
  PlainTV n    -> n
-- | Describe one data constructor as its name paired with the 'Struct' of
--   each of its fields, in declaration order.
--
--   @nm@ is the type being processed (used to detect recursive occurrences)
--   and @var@ is its single type parameter. Ordinary, record and infix
--   constructors are supported; any other form (e.g. existential
--   constructors) is reported as a compile error instead of crashing the
--   splice with a pattern-match failure.
conToStruct :: Name -> Name -> Con -> Q (Name, [Struct])
conToStruct nm var (NormalC cnm tys)
  = (,) cnm <$> mapM (tyToStruct nm var . snd) tys
conToStruct nm var (RecC cnm tys)
  = (,) cnm <$> mapM (tyToStruct nm var . (\(_, _, t) -> t)) tys
conToStruct nm var (InfixC ty1 cnm ty2)
  = (,) cnm <$> mapM (tyToStruct nm var . snd) [ty1, ty2]
conToStruct nm _ con
  = errorQ $ "Processing " ++ show nm ++ ": unsupported data constructor " ++ show con
-- | Describe the type of one constructor field as a 'Struct'.
--
--   @nm@ is the type being processed and @var@ its type parameter:
--
--   * the parameter itself becomes 'SId';
--
--   * the list constructor becomes 'SList';
--
--   * @nm var@ is a recursive occurrence ('SSelf');
--
--   * any other @t var@ becomes 'SEnum' @t@;
--
--   * @t1 (t2 ...)@ becomes a composition;
--
--   * remaining constructors and applications are treated as constants.
--
--   Type forms not listed above (foralls, kind signatures, tuple or arrow
--   constructors, ...) are reported as a compile error rather than
--   failing with a non-exhaustive pattern.
--
--   NOTE(review): a field of type @[a]@ arrives as @AppT ListT (VarT a)@
--   and so is classified as 'SEnum' 'ListT', not 'SList' — confirm intended.
tyToStruct :: Name -> Name -> Type -> Q Struct
tyToStruct _ var (VarT v)
  | v == var  = return SId
  | otherwise = errorQ $ "Unknown variable " ++ show v
tyToStruct _ _ ListT = return SList
tyToStruct _ _ t@(ConT b)
  | b == ''[] = return SList
  | otherwise = return $ SConst t
tyToStruct nm var (AppT t (VarT v))
  | v == var && t == ConT nm = return SSelf
  | v == var                 = return $ SEnum t
  | otherwise                = errorQ $ "Unknown variable " ++ show v
tyToStruct nm var (AppT t1 t2@(AppT _ _))
  = SComp <$> tyToStruct nm var t1 <*> tyToStruct nm var t2
tyToStruct _ _ t@(AppT _ _)
  = return $ SConst t
tyToStruct nm _ t
  = errorQ $ "Processing " ++ show nm ++ ": unsupported field type " ++ show t
-- | Does the structure mention the type being processed ('SSelf')
--   anywhere, including inside constructor fields and compositions?
isRecursive :: Struct -> Bool
isRecursive s = case s of
  SSelf         -> True
  SSumProd cons -> any (any isRecursive . snd) cons
  SComp a b     -> isRecursive a || isRecursive b
  _             -> False
-- | The default species for a structure.
--
--   Constructors become a left-nested sum, via @foldl1 (+)@. Each
--   constructor is a product of its fields ('conToSp'). Recursive
--   occurrences become 'Omega'. The only supported constant field type is
--   'Bool', which becomes @N 2@.
--
--   Any other constant type, including applied types such as @Maybe Int@
--   that previously fell through to a non-exhaustive-pattern crash,
--   raises an explicit error. 'SEnum' is also rejected.
structToSp :: Struct -> SpeciesAST
structToSp SId = X
structToSp SList = L
structToSp (SConst (ConT t))
  | t == ''Bool = N 2
  | otherwise   = error $ "structToSp: unrecognized type " ++ show t ++ " in SConst"
structToSp (SConst t) = error $ "structToSp: unrecognized type " ++ show t ++ " in SConst"
structToSp (SEnum _) = error "SEnum in structToSp"
structToSp (SSumProd []) = Zero
structToSp (SSumProd ss) = foldl1 (+) $ map conToSp ss
structToSp (SComp s1 s2) = structToSp s1 `o` structToSp s2
structToSp SSelf = Omega
-- | Species of a single constructor. A nullary constructor is 'One';
--   otherwise the result is the product of its field species, nested to
--   the left. 'mkIsoConMatches' relies on that left nesting.
conToSp :: (Name, [Struct]) -> SpeciesAST
conToSp (_, fields)
  | null fields = One
  | otherwise   = foldl1 (*) (map structToSp fields)
-- | Generate the expression denoting a species AST, using the 'Species'
--   combinators.
--
--   @self@ is the variable standing for the recursive occurrence: both
--   'Rec' and 'Omega' become @wrap self@, so it is only inspected when one
--   of those appears. General 'OfSize' predicates cannot be turned into
--   code.
spToExp :: Name -> SpeciesAST -> Q Exp
spToExp self = go
  where
    recursive = [| wrap $(varE self) |]

    go ast = case ast of
      Zero              -> [| 0 |]
      One               -> [| 1 |]
      N n               -> lift n
      X                 -> [| singleton |]
      E                 -> [| set |]
      C                 -> [| cycle |]
      L                 -> [| linOrd |]
      Subset            -> [| subset |]
      KSubset k         -> [| ksubset $(lift k) |]
      Elt               -> [| element |]
      f :+ g            -> [| $(go f) + $(go g) |]
      f :* g            -> [| $(go f) * $(go g) |]
      f :. g            -> [| $(go f) `o` $(go g) |]
      f :>< g           -> [| $(go f) >< $(go g) |]
      f :@ g            -> [| $(go f) @@ $(go g) |]
      Der f             -> [| oneHole $(go f) |]
      OfSize _ _        -> error "Can't reify general size predicate into code"
      OfSizeExactly f k -> [| $(go f) `ofSizeExactly` $(lift k) |]
      NonEmpty f        -> [| nonEmpty $(go f) |]
      Rec _             -> recursive
      Omega             -> recursive
-- | The structure type corresponding to a species AST.
--
--   Size restrictions ('OfSize', 'OfSizeExactly', 'NonEmpty') do not
--   change the type. Cartesian product and functor composition reuse the
--   ':*:' and ':.:' type constructors. Recursive occurrences ('Rec',
--   'Omega') become the type variable @self@, which is only inspected when
--   one of those appears.
spToTy :: Name -> SpeciesAST -> Q Type
spToTy self = go
  where
    selfTy = varT self

    go ast = case ast of
      Zero              -> [t| Void |]
      One               -> [t| Unit |]
      N _               -> [t| Const Integer |]
      X                 -> [t| Id |]
      E                 -> [t| Set |]
      C                 -> [t| Cycle |]
      L                 -> [t| [] |]
      Subset            -> [t| Set |]
      KSubset _         -> [t| Set |]
      Elt               -> [t| Id |]
      f :+ g            -> [t| $(go f) :+: $(go g) |]
      f :* g            -> [t| $(go f) :*: $(go g) |]
      f :. g            -> [t| $(go f) :.: $(go g) |]
      f :>< g           -> [t| $(go f) :*: $(go g) |]
      f :@ g            -> [t| $(go f) :.: $(go g) |]
      Der f             -> [t| Star $(go f) |]
      OfSize f _        -> go f
      OfSizeExactly f _ -> go f
      NonEmpty f        -> go f
      Rec _             -> selfTy
      Omega             -> selfTy
-- | Build @instance Enumerable T@ for the type named @nm@. The instance
--   sets 'StructTy' and defines 'iso' from the generated clauses.
--
--   With a recursion code @cd@ the structure type is @Mu cd@, and the 'iso'
--   patterns are wrapped in 'Mu'. Without one, it is the species' own
--   structure type; the @self@ name given to 'spToTy' is then never
--   inspected, because a non-recursive species has no 'Rec'/'Omega'.
mkEnumerableInst :: Name -> SpeciesAST -> Struct -> Maybe Name -> Q Dec
mkEnumerableInst nm sp st code = do
  clauses <- mkIsoClauses (isJust code) sp st
  let structTy = maybe (spToTy undefined sp) (\cd -> [t| Mu $(conT cd) |]) code
      instHead = appT (conT ''Enumerable) (conT nm)
      members  = [ tySynInstD ''StructTy (tySynEqn [conT nm] structTy)
                 , return (FunD 'iso clauses)
                 ]
  instanceD (return []) instHead members
-- | One single-argument 'iso' clause per pattern/expression pair. For a
--   recursive species the argument pattern is additionally wrapped in 'Mu'.
mkIsoClauses :: Bool -> SpeciesAST -> Struct -> Q [Clause]
mkIsoClauses isRec sp st = map toClause <$> mkIsoMatches sp st
  where
    toClause (pat, body) = Clause [outerPat pat] (NormalB body) []
    outerPat p
      | isRec     = ConP 'Mu [p]
      | otherwise = p
-- | Pattern/expression pairs for the clauses of 'iso'. Each pattern matches
--   a value of the generated structure type; the paired expression rebuilds
--   the corresponding value of the user's type.
--
--   The 'SpeciesAST' must have the left-nested sum shape that 'structToSp'
--   produces with @foldl1 (+)@: 'terms' peels summands off the left. The
--   injections must therefore nest to the left too. With summands
--   @((g0 :+: g1) :+: g2)@, @g0@ sits under @Inl (Inl _)@, @g1@ under
--   @Inl (Inr _)@ and @g2@ under @Inr _@.
--
--   An empty data type yields no pairs.
mkIsoMatches :: SpeciesAST -> Struct -> Q [(Pat, Exp)]
mkIsoMatches _ SId = newName "x" >>= \x ->
  return [(ConP 'Id [VarP x], VarE x)]
mkIsoMatches _ (SConst t)
  | t == ConT ''Bool = return [ (ConP 'Const [LitP $ IntegerL 1], ConE 'False)
                              , (ConP 'Const [LitP $ IntegerL 2], ConE 'True)
                              ]
  | otherwise = error "mkIsoMatches: unrecognized type in SConst case"
mkIsoMatches _ (SEnum _) = newName "x" >>= \x ->
  return [(VarP x, AppE (VarE 'iso) (VarE x))]
mkIsoMatches _ (SSumProd []) = return []
mkIsoMatches sp (SSumProd [con]) = mkIsoConMatches sp con
mkIsoMatches sp (SSumProd cons) = addInjs <$> zipWithM mkIsoConMatches (terms sp) cons
  where terms (f :+ g) = terms f ++ [g]
        terms f = [f]

        -- Each newly appended summand goes under 'Inr'; everything
        -- accumulated so far moves one level deeper under 'Inl'.
        addInjs :: [[(Pat, Exp)]] -> [(Pat, Exp)]
        addInjs []       = []
        addInjs (g0 : gs) = foldl step g0 gs
          where step acc grp = map (wrap 'Inl) acc ++ map (wrap 'Inr) grp

        wrap :: Name -> (Pat, Exp) -> (Pat, Exp)
        wrap c = first (\p -> ConP c [p])
mkIsoMatches _ (SComp s1 s2) = newName "x" >>= \x ->
  return [ (ConP 'Comp [VarP x]
         , AppE (VarE 'iso) (AppE (AppE (VarE 'fmap) (VarE 'iso)) (VarE x))) ]
mkIsoMatches _ SSelf = newName "s" >>= \s ->
  return [(VarP s, AppE (VarE 'iso) (VarE s))]
mkIsoMatches _ s = errorQ $ "mkIsoMatches: unsupported structure " ++ show s
-- | Pattern/expression pairs for one constructor.
--
--   A nullary constructor matches 'Unit'. Otherwise the matches of the
--   fields are combined in every possible way (the list-monad 'sequence').
--   Each combination yields a left-nested ':*:' pattern and an application
--   of the constructor to the field expressions. The factors of the
--   species are peeled from the left, mirroring the @foldl1 (*)@ in
--   'conToSp'.
mkIsoConMatches :: SpeciesAST -> (Name, [Struct]) -> Q [(Pat, Exp)]
mkIsoConMatches _ (cnm, []) = return [(ConP 'Unit [], ConE cnm)]
mkIsoConMatches sp (cnm, ps) = do
  perField <- zipWithM mkIsoMatches (factorsOf sp) ps
  return [ mkProd combo | combo <- sequence perField ]
  where
    factorsOf (f :* g) = factorsOf f ++ [g]
    factorsOf f        = [f]

    mkProd :: [(Pat, Exp)] -> (Pat, Exp)
    mkProd = (foldl1 joinPat *** foldl AppE (ConE cnm)) . unzip

    joinPat l r = ConP '(:*:) [l, r]
-- | The type signature @name :: Species s => s@ for a generated species.
mkSpeciesSig :: Name -> Q Dec
mkSpeciesSig nm = sigD nm speciesTy
  where speciesTy = [t| Species s => s |]
-- | The definition of a generated species value.
--
--   When a recursion code is given the value is @rec Code@. Otherwise it is
--   the species expression itself; the @self@ name given to 'spToExp' is
--   then never inspected, because a non-recursive species has no
--   'Rec'/'Omega'.
mkSpecies :: Name -> SpeciesAST -> Maybe Name -> Q Dec
mkSpecies nm sp code = valD (varP nm) (normalB body) []
  where
    body = case code of
      Just cd -> appE (varE 'rec) (conE cd)
      Nothing -> spToExp undefined sp
-- | Like 'deriveSpecies', but the species is derived from the shape of the
--   type itself ('structToSp'): sums of products, with 'Omega' at the
--   recursive positions.
deriveDefaultSpecies :: Name -> Q [Dec]
deriveDefaultSpecies nm = nameToStruct nm >>= deriveSpecies nm . structToSp
-- | Generate species declarations for a type constructor of kind @* -> *@,
--   using the supplied species.
--
--   Always generated:
--
--   * an 'Enumerable' instance for the type (its 'StructTy' and 'iso');
--
--   * the signature @Species s => s@ and a definition for a value named
--     after the type in lower case (e.g. @Foo@ becomes @foo@).
--
--   If the type refers to itself, extra declarations (a fresh code type
--   plus its 'Show', 'Interp', 'ASTFunctor' and @Show (Mu ...)@ instances)
--   are emitted so that the species can be defined with 'rec'.
--
--   NOTE(review): the supplied 'SpeciesAST' is assumed to have the
--   left-nested sum/product shape that 'structToSp' builds, since that is
--   how the 'iso' clauses are decomposed — confirm for hand-written species.
deriveSpecies :: Name -> SpeciesAST -> Q [Dec]
deriveSpecies nm sp = do
  st <- nameToStruct nm
  -- The species value is named after the type, lower-cased.
  let spNm = mkName . map toLower . nameBase $ nm
  if (isRecursive st)
    then mkEnumerableRec nm spNm st sp
    else mkEnumerableNonrec nm spNm st sp
  where
    -- Recursive case: a fresh nullary "code" type, reusing the type's base
    -- name, identifies this recursive species for 'Interp'/'ASTFunctor'/'Mu'.
    mkEnumerableRec nm spNm st sp = do
      codeNm <- newName (nameBase nm)
      -- Used both as a type variable (in the Interp instance) and as a term
      -- variable (the second argument of 'apply').
      self <- newName "self"
      -- data <Code> = <Code> deriving Typeable
      let declCode = DataD [] codeNm [] [NormalC codeNm []] [''Typeable]
      -- The code shows as the name of the original type.
      [showCode] <- [d|
          instance Show $(conT codeNm) where
            show _ = $(lift (nameBase nm))
        |]
      -- One unfolding of the structure, with 'self' at recursive positions.
      [interpCode] <- [d|
          type instance Interp $(conT codeNm) $(varT self)
            = $(spToTy self sp)
        |]
      -- apply _ self = unwrap <species expression, 'wrap self' at recursive positions>
      applyBody <- NormalB <$> [| unwrap $(spToExp self sp) |]
      let astFunctorInst = InstanceD [] (AppT (ConT ''ASTFunctor) (ConT codeNm))
                             [FunD 'apply [Clause [WildP, VarP self] applyBody []]]
      -- A Mu-wrapped value shows as its unwrapped contents.
      [showMu] <- [d|
          instance Show a => Show (Mu $(conT codeNm) a) where
            show = show . unMu
        |]
      enum <- mkEnumerableInst nm sp st (Just codeNm)
      sig <- mkSpeciesSig spNm
      spD <- mkSpecies spNm sp (Just codeNm)
      return $ [ declCode
               , showCode
               , interpCode
               , astFunctorInst
               , showMu
               , enum
               , sig
               , spD
               ]

    -- Non-recursive case: only the Enumerable instance, signature and value.
    mkEnumerableNonrec nm spNm st sp =
      sequence
        [ mkEnumerableInst nm sp st Nothing
        , mkSpeciesSig spNm
        , mkSpecies spNm sp Nothing
        ]