module Data.EnumMapSetWrapper (EnumMap (..), EnumSet (..), w, w') where
import Prelude
import Control.Applicative
import Control.Arrow
import Data.List (nub)
import Data.IntSet (IntSet, Key)
import Data.IntMap (IntMap)
import Language.Haskell.TH.Syntax
import Control.DeepSeq
import Data.Foldable
import Data.Traversable
import Data.Typeable
import Data.Data
import Data.Monoid
-- | A map whose keys have the phantom 'Enum' type @k@.
--
-- The representation is exactly the wrapped 'IntMap'.  Keys are converted
-- with 'fromEnum' / 'toEnum' by the wrappers this module generates.  All
-- instances are those of the underlying 'IntMap'.
newtype EnumMap k v = EnumMap
  { unEnumMap :: IntMap v
  }
  deriving (Typeable, Data, NFData, Eq, Ord, Functor, Foldable, Traversable, Monoid)
-- | A set whose elements have the phantom 'Enum' type @k@.
--
-- The representation is exactly the wrapped 'IntSet'.  Elements are
-- converted with 'fromEnum' / 'toEnum' by the wrappers this module
-- generates.  All instances are those of the underlying 'IntSet'.
newtype EnumSet k = EnumSet
  { unEnumSet :: IntSet
  }
  deriving (Typeable, Data, NFData, Eq, Ord, Monoid)
infixr 9 `o`

-- | Build the expression @f . g@ as an unresolved infix application
-- ('UInfixE').  GHC resolves the fixity of '.' when the splice is compiled.
o :: Exp -> Exp -> Exp
o f g = UInfixE f (VarE '(.)) g
-- | Right section @(. f)@, i.e. "precompose with @f@".
--
-- The operand is wrapped in 'ParensE' so that an unresolved infix
-- expression inside it stays grouped.
pre :: Exp -> Exp
pre = InfixE Nothing (VarE '(.)) . Just . ParensE
-- | Left section @(g .)@, i.e. "postcompose with @g@".
--
-- The operand is wrapped in 'ParensE' so that an unresolved infix
-- expression inside it stays grouped.
post :: Exp -> Exp
post g = InfixE lhs (VarE '(.)) Nothing
  where
    lhs = Just (ParensE g)
-- | Names of the key type variables used in generated signatures.
--
-- @k@ ('ki') is used for keys passed into a wrapped function's arguments.
-- @k'@ ('ko') is used for keys appearing in its results.
ki, ko :: Name
(ki, ko) = (mkName "k", mkName "k'")
-- | Apply three functions component-wise to a triple.  This is the
-- three-way analogue of '(***)' on pairs, used when wrapping functions
-- that return triples.
--
-- The tuple pattern is irrefutable (@~@), matching the lazy pattern used by
-- '(***)' for functions.  The result triple can therefore be taken apart
-- without first forcing the argument, so wrapped triples are no stricter
-- than wrapped pairs.
xxx :: (a -> a') -> (b -> b') -> (c -> c') -> (a, b, c) -> (a', b', c')
xxx f g h ~(a, b, c) = (f a, g b, h c)
-- | Translate a type that occurs in a /positive/ (output) position.
--
-- The result is a triple of:
--
--   * an expression converting a value of the original type into the
--     translated type ('toEnum' for 'Key', 'EnumMap' / 'EnumSet' for
--     'IntMap' / 'IntSet', lifted through 'Maybe', pairs, triples, lists
--     and applications of a type variable);
--   * the constraints the translated type needs ('Enum' on the key
--     variable);
--   * the translated type, where 'Key', 'IntMap' and 'IntSet' become the key
--     variable @k@, 'EnumMap' and 'EnumSet'.
--
-- Function arguments are handed to 'neg' with key variable 'ki'.  Function
-- results stay in 'pos' with key variable 'ko'.  A 'ForallT' absorbs the
-- constraints of its body.  Anything unrecognised is returned unchanged,
-- paired with 'id'.
pos :: Name -> Type -> (Exp, Cxt, Type)
pos k ty = case ty of
  ForallT binders ctx body ->
    let (conv, bodyCtx, body') = pos k body
    in (conv, [], ForallT binders (nub (bodyCtx ++ ctx)) body')
  AppT (AppT ArrowT arg) res ->
    let (fromArg, argCtx, arg') = neg ki arg
        (toRes, resCtx, res') = pos ko res
    in ( post toRes `o` pre fromArg
       , nub (argCtx ++ resCtx)
       , AppT (AppT ArrowT arg') res' )
  ConT con
    | con == ''Key -> (VarE 'toEnum, [ClassP ''Enum [VarT k]], VarT k)
    | con == ''IntSet -> (ConE 'EnumSet, [], AppT (ConT ''EnumSet) (VarT k))
  AppT (ConT con) elt
    | con == ''IntMap ->
        (ConE 'EnumMap, [], AppT (AppT (ConT ''EnumMap) (VarT k)) elt)
    | con == ''Maybe ->
        let (conv, ctx, elt') = pos k elt
        in (AppE (VarE 'fmap) conv, ctx, AppT (ConT ''Maybe) elt')
  AppT (AppT (TupleT 2) x) y ->
    let (convX, ctxX, x') = pos k x
        (convY, ctxY, y') = pos k y
    in ( UInfixE (ParensE convX) (VarE '(***)) (ParensE convY)
       , nub (ctxX ++ ctxY)
       , AppT (AppT (TupleT 2) x') y' )
  AppT (AppT (AppT (TupleT 3) x) y) z ->
    let (convX, ctxX, x') = pos k x
        (convY, ctxY, y') = pos k y
        (convZ, ctxZ, z') = pos k z
    in ( AppE (AppE (AppE (VarE 'xxx) convX) convY) convZ
       , nub (ctxX ++ ctxY ++ ctxZ)
       , AppT (AppT (AppT (TupleT 3) x') y') z' )
  AppT ListT elt ->
    let (conv, ctx, elt') = pos k elt
    in (AppE (VarE 'map) conv, ctx, AppT ListT elt')
  AppT (VarT f) elt ->
    let (conv, ctx, elt') = pos k elt
    in (AppE (VarE '(<$>)) conv, ctx, AppT (VarT f) elt')
  _ -> (VarE 'id, [], ty)
-- | Translate a type that occurs in a /negative/ (input) position.
--
-- The result is a triple of:
--
--   * an expression converting a value of the translated type back into the
--     original type ('fromEnum' for 'Key', 'unEnumMap' / 'unEnumSet' for
--     'IntMap' / 'IntSet', lifted through pairs and lists);
--   * the constraints the translated type needs ('Enum' on the key
--     variable);
--   * the translated type.
--
-- For a function-typed input, its arguments go through 'pos' with 'ki' and
-- its result through 'neg' with 'ko'.  Unlike 'pos', this function has no
-- cases for 'ForallT', 'Maybe', triples or applications of a type variable.
-- Those fall through to the identity case and are left untranslated.
neg :: Name -> Type -> (Exp, Cxt, Type)
neg k ty = case ty of
  AppT (AppT ArrowT arg) res ->
    let (toArg, argCtx, arg') = pos ki arg
        (fromRes, resCtx, res') = neg ko res
    in ( post fromRes `o` pre toArg
       , nub (argCtx ++ resCtx)
       , AppT (AppT ArrowT arg') res' )
  ConT con
    | con == ''Key -> (VarE 'fromEnum, [ClassP ''Enum [VarT k]], VarT k)
    | con == ''IntSet -> (VarE 'unEnumSet, [], AppT (ConT ''EnumSet) (VarT k))
  AppT (ConT con) elt
    | con == ''IntMap ->
        (VarE 'unEnumMap, [], AppT (AppT (ConT ''EnumMap) (VarT k)) elt)
  AppT (AppT (TupleT 2) x) y ->
    let (fromX, ctxX, x') = neg k x
        (fromY, ctxY, y') = neg k y
    in ( UInfixE (ParensE fromX) (VarE '(***)) (ParensE fromY)
       , nub (ctxX ++ ctxY)
       , AppT (AppT (TupleT 2) x') y' )
  AppT ListT elt ->
    let (conv, ctx, elt') = neg k elt
    in (AppE (VarE 'map) conv, ctx, AppT ListT elt')
  _ -> (VarE 'id, [], ty)
-- | @substT from to ty@ renames the type variable @from@ to @to@.
--
-- The rename is applied to 'VarT' occurrences, inside type applications,
-- and in the binders and 'ClassP' / 'EqualP' constraints of 'ForallT'.
-- Binders and constraints are de-duplicated with 'nub' afterwards, because
-- the renaming can make two of them equal.  All other type forms are
-- returned unchanged without being traversed.
substT :: Name -> Name -> Type -> Type
substT from to = goType
  where
    rename v
      | v == from = to
      | otherwise = v

    goType :: Type -> Type
    goType ty = case ty of
      VarT v -> VarT (rename v)
      AppT f x -> AppT (goType f) (goType x)
      ForallT binders ctx body ->
        ForallT (nub (map goBinder binders)) (nub (map goPred ctx)) (goType body)
      _ -> ty

    goBinder :: TyVarBndr -> TyVarBndr
    goBinder b = case b of
      PlainTV v -> PlainTV (rename v)
      KindedTV v kind -> KindedTV (rename v) kind

    goPred :: Pred -> Pred
    goPred p = case p of
      ClassP cls args -> ClassP cls (map goType args)
      EqualP lhs rhs -> EqualP (goType lhs) (goType rhs)
-- | Generate an 'EnumMap' / 'EnumSet' wrapper for an existing top-level
-- 'IntMap' / 'IntSet' function.
--
-- Given the function's 'Name', the splice reifies its type and emits four
-- declarations under the same base name:
--
--   * a fixity declaration copied from the original;
--   * an @INLINE@ pragma;
--   * a type signature in which 'Key', 'IntMap' and 'IntSet' are replaced by
--     an 'Enum' key variable, 'EnumMap' and 'EnumSet' (computed by 'pos');
--   * a binding that applies the matching conversion expression to the
--     original function.
--
-- @w@ renames @k'@ to @k@, so every key in the signature shares one type
-- variable.  @w'@ keeps keys flowing in (@k@) distinct from keys flowing
-- out (@k'@).
--
-- The splice fails with a descriptive message when the name does not reify
-- to a plain variable (e.g. a data constructor, class method or type).
w, w' :: Name -> Q [Dec]
(w, w') = (wrap True, wrap False)
  where
    wrap :: Bool -> Name -> Q [Dec]
    wrap subst name@(mkName . nameBase -> base) = do
      info <- reify name
      -- Explicit case instead of a do-pattern, so a wrong kind of name
      -- produces a useful error rather than a generic pattern-match failure.
      (e, cxt', typ', fixity) <- case info of
        VarI _name (pos ko -> (e, cxt', typ')) _dec fixity ->
          return (e, cxt', typ', fixity)
        _ -> fail ("Data.EnumMapSetWrapper: cannot wrap " ++ show name
                   ++ ": it is not a plain variable")
      let ks = map PlainTV [ki, ko]
          -- Bind both key variables; for 'w' the substitution then merges
          -- k' into k (substT de-duplicates binders and constraints).
          t' = (if subst then substT ko ki else id) $ case typ' of
            ForallT tvs cxt t ->
              ForallT (ks ++ tvs) (nub $ cxt' ++ cxt) t
            t -> ForallT ks cxt' t
          body = NormalB (e `AppE` VarE name)
      return [ InfixD fixity base
             , PragmaD (InlineP base Inline FunLike AllPhases)
             , SigD base t'
             , ValD (VarP base) body [] ]