#ifndef MIN_VERSION_template_haskell
#define MIN_VERSION_template_haskell(x,y,z) (defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 706)
#endif
#ifndef MIN_VERSION_containers
#define MIN_VERSION_containers(x,y,z) 1
#endif
module Lens.Micro.TH
(
Getter,
Fold,
makeLenses,
makeLensesFor,
makeLensesWith,
makeFields,
LensRules,
DefName(..),
lensRules,
lensRulesFor,
defaultFieldRules,
camelCaseFields,
lensField,
simpleLenses,
createClass,
generateSignatures,
generateUpdateableOptics,
generateLazyPatterns,
)
where
import Control.Applicative
import Control.Monad
import Data.Char
import Data.Data
import Data.Either
import Data.Foldable (toList)
import qualified Data.Map as Map
import Data.Map (Map)
import Data.Monoid
import qualified Data.Set as Set
import Data.Set (Set)
import Data.List (nub, findIndices, stripPrefix, isPrefixOf)
import Data.Maybe
import Data.Traversable (traverse, sequenceA)
import Lens.Micro
import Language.Haskell.TH
-- | A read-only optic with exactly one target: a 'Getting' that works for
-- every result type @r@.
type Getter s a = forall r. Getting r s a
-- | A read-only optic with any number of targets: like 'Getter', but it
-- requires @Const r@ to be 'Applicative' so that several targets can be
-- combined.
type Fold s a = forall r. Applicative (Const r) => Getting r s a
-- | Is @x@ one of the targets of @l@ in @s@?
elemOf :: Eq a => Getting (Endo [a]) s a -> a -> s -> Bool
elemOf l x s = x `elem` toListOf l s

-- | How many targets @l@ has in @s@.
lengthOf :: Getting (Endo [a]) s a -> s -> Int
lengthOf l s = length (toListOf l s)

-- | The targets of @l@ in @s@, collected into a 'Set'.
setOf :: Ord a => Getting (Endo [a]) s a -> s -> Set a
setOf l s = Set.fromList (toListOf l s)
-- | Focus on the binders, context and body of a 'ForallT'. Any other 'Type'
-- has no target and is returned unchanged.
_ForallT :: Traversal' Type ([TyVarBndr], Cxt, Type)
_ForallT f ty =
  case ty of
    ForallT bndrs ctx body -> rebuild <$> f (bndrs, ctx, body)
    _ -> pure ty
  where
    rebuild (bndrs', ctx', body') = ForallT bndrs' ctx' body'
-- | Change the phantom second parameter of 'Const'. The wrapped value is
-- passed through untouched. Generated getter clauses refer to this function.
coerce :: Const r a -> Const r b
coerce (Const r) = Const r
-- | Replace the element at index @i@ (0-based). An out-of-range index
-- (negative, or at least the list length) leaves the list unchanged.
setIx :: Int -> a -> [a] -> [a]
setIx i x s
  | i >= 0, i < length s = prefix ++ x : drop 1 suffix
  | otherwise = s
  where
    (prefix, suffix) = splitAt i s
-- | Rewrite every @a@ inside @b@, bottom-up, with a partial rule.
--
-- Children are rewritten first. Whenever the rule fires on a node, its
-- result is rewritten again, so the rule applies nowhere in the output.
-- This matches lens's @rewriteOf template@, and it matters for
-- 'applyTypeSubst': substitutions such as @{b ↦ a, a ↦ Int}@ must resolve
-- @b@ all the way to @Int@.
--
-- The rule must eventually return 'Nothing', or this does not terminate.
rewrite :: (Data a, Data b) => (a -> Maybe a) -> b -> b
rewrite f mbA = case cast mbA of
  Nothing -> gmapT (rewrite f) mbA
  Just a ->
    let a' = gmapT (rewrite f) a
        -- Re-rewrite the rule's output until it no longer fires.
        a'' = maybe a' (rewrite f) (f a')
    -- The first cast succeeded, so a ~ b and this cast cannot fail;
    -- falling back to the input keeps the function total.
    in fromMaybe mbA (cast a'')
-- | Build a 'Map' whose keys are the set's elements and whose values come
-- from applying the function to each key.
fromSet :: (k -> v) -> Set.Set k -> Map.Map k v
#if MIN_VERSION_containers(0,5,0)
fromSet = Map.fromSet
#else
-- Older containers lack Map.fromSet. Set.toAscList yields distinct keys in
-- ascending order, which is exactly what fromDistinctAscList requires.
fromSet f = Map.fromDistinctAscList . map (\k -> (k, f k)) . Set.toAscList
#endif
-- | Apply a function to the first element of a list, if there is one.
overHead :: (a -> a) -> [a] -> [a]
overHead f xs =
  case xs of
    y : ys -> f y : ys
    [] -> []
-- | Generate optics for the named type using 'lensRules'.
makeLenses :: Name -> DecsQ
makeLenses tyName = makeFieldOptics lensRules tyName

-- | Generate optics only for the listed fields, each paired with the name of
-- its optic (see 'lensRulesFor').
makeLensesFor :: [(String, String)] -> Name -> DecsQ
makeLensesFor fields tyName = makeFieldOptics (lensRulesFor fields) tyName

-- | Generate optics for the named type using caller-supplied rules.
makeLensesWith :: LensRules -> Name -> DecsQ
makeLensesWith rules tyName = makeFieldOptics rules tyName

-- | Generate @Has...@ classes and instances using 'camelCaseFields'.
makeFields :: Name -> DecsQ
makeFields tyName = makeFieldOptics camelCaseFields tyName
-- | When set, always generate non-type-changing optics (Lens' / Traversal').
simpleLenses :: Lens' LensRules Bool
simpleLenses f rules = (\new -> rules { _simpleLenses = new }) <$> f (_simpleLenses rules)

-- | Whether to emit type signatures for top-level optics.
generateSignatures :: Lens' LensRules Bool
generateSignatures f rules =
  (\new -> rules { _generateSigs = new }) <$> f (_generateSigs rules)

-- | When unset, only read-only optics (Getter / Fold) are generated.
generateUpdateableOptics :: Lens' LensRules Bool
generateUpdateableOptics f rules =
  (\new -> rules { _allowUpdates = new }) <$> f (_allowUpdates rules)

-- | Use irrefutable patterns in generated lenses for single-constructor types.
generateLazyPatterns :: Lens' LensRules Bool
generateLazyPatterns f rules =
  (\new -> rules { _lazyPatterns = new }) <$> f (_lazyPatterns rules)

-- | Whether to generate the class for method-style definitions.
createClass :: Lens' LensRules Bool
createClass f rules =
  (\new -> rules { _generateClasses = new }) <$> f (_generateClasses rules)

-- | The function that decides which definitions each field produces.
lensField :: Lens' LensRules (Name -> [Name] -> Name -> [DefName])
lensField f rules = (\new -> rules { _fieldToDef = new }) <$> f (_fieldToDef rules)
-- | Default rules for 'makeLenses'. A field @_fooBar@ gets a top-level optic
-- named @fooBar@; fields that do not start with an underscore followed by at
-- least one character get none. Type-changing optics are allowed,
-- signatures are generated, no classes are created, and patterns are strict.
lensRules :: LensRules
lensRules = LensRules
  { _simpleLenses    = False
  , _generateSigs    = True
  , _generateClasses = False
  , _allowUpdates    = True
  , _lazyPatterns    = False
  , _fieldToDef      = underscoreNamer
  }
  where
    underscoreNamer _ _ field =
      case nameBase field of
        '_' : c : cs -> [TopName (mkName (toLower c : cs))]
        _            -> []
-- | Like 'lensRules', but optic names come from an explicit list of
-- (field name, optic name) pairs. Unlisted fields get no optic.
lensRulesFor ::
  [(String, String)] ->
  LensRules
lensRulesFor fields = set lensField (mkNameLookup fields) lensRules
-- | Field namer backed by an association list: every pair whose key equals
-- the field's base name contributes a top-level definition with the paired
-- name, in list order.
mkNameLookup :: [(String,String)] -> Name -> [Name] -> Name -> [DefName]
mkNameLookup kvs _ _ field =
  map (TopName . mkName . snd) (filter ((== nameBase field) . fst) kvs)
-- | Rules used by 'makeFields'; identical to 'defaultFieldRules'.
camelCaseFields :: LensRules
camelCaseFields = defaultFieldRules
-- | Field namer for 'defaultFieldRules'.
--
-- For a type @Foo@, a field named @fooBar@ (or @_fooBar@ when any field of
-- the type starts with an underscore) yields the method @bar@ of class
-- @HasBar@. Fields without that prefix, or whose remainder does not start
-- with an upper-case letter, yield nothing.
camelCaseNamer :: Name -> [Name] -> Name -> [DefName]
camelCaseNamer tyName fields field =
  case stripPrefix expectedPrefix (nameBase field) of
    Just fieldPart@(c : cs)
      | isUpper c -> [MethodName (mkName ("Has" ++ fieldPart)) (mkName (toLower c : cs))]
    _ -> []
  where
    expectedPrefix = optUnderscore ++ overHead toLower (nameBase tyName)
    optUnderscore
      | any (isPrefixOf "_" . nameBase) fields = "_"
      | otherwise = ""
-- | Rules for class-based field optics: names come from 'camelCaseNamer',
-- missing classes are generated, optics are always simple (non
-- type-changing), and patterns are strict.
defaultFieldRules :: LensRules
defaultFieldRules = LensRules
  { _fieldToDef      = camelCaseNamer
  , _generateClasses = True
  , _generateSigs    = True
  , _simpleLenses    = True
  , _allowUpdates    = True
  , _lazyPatterns    = False
  }
-- | Things that carry a principal 'Name'.
class HasName t where
  name :: Lens' t Name

instance HasName TyVarBndr where
  name f tv =
    case tv of
      PlainTV n -> PlainTV <$> f n
      KindedTV n k -> (\n' -> KindedTV n' k) <$> f n

instance HasName Name where
  name = id

-- | The constructor's own name. For 'ForallC', this is the name of the
-- wrapped constructor.
instance HasName Con where
  name f con =
    case con of
      NormalC n tys -> (\n' -> NormalC n' tys) <$> f n
      RecC n tys -> (\n' -> RecC n' tys) <$> f n
      InfixC l n r -> (\n' -> InfixC l n' r) <$> f n
      ForallC bds ctx inner -> ForallC bds ctx <$> name f inner
-- | Things containing type variables.
--
-- @typeVarsEx bound@ traverses the type-variable names in a value. It skips
-- names in @bound@, and names bound by an enclosing forall.
class HasTypeVars t where
  typeVarsEx :: Set Name -> Traversal' t Name

instance HasTypeVars TyVarBndr where
  typeVarsEx bound f tv =
    if (tv ^. name) `Set.member` bound then pure tv else name f tv

instance HasTypeVars Name where
  typeVarsEx bound f n =
    if n `Set.member` bound then pure n else f n

instance HasTypeVars Type where
  typeVarsEx bound f ty =
    case ty of
      VarT n -> VarT <$> typeVarsEx bound f n
      AppT l r -> AppT <$> typeVarsEx bound f l <*> typeVarsEx bound f r
      SigT inner k -> (\inner' -> SigT inner' k) <$> typeVarsEx bound f inner
      ForallT bs ctx body ->
        -- Variables introduced by this forall are not visited inside it.
        let bound' = Set.union bound (Set.fromList (bs ^.. typeVars))
        in ForallT bs <$> typeVarsEx bound' f ctx <*> typeVarsEx bound' f body
      -- Every other Type constructor is treated as having no type variables.
      _ -> pure ty
#if !MIN_VERSION_template_haskell(2,10,0)
-- Only compiled for template-haskell older than 2.10, where predicates are
-- built with ClassP / EqualP rather than being plain Types. Traverses the
-- type variables in the predicate's argument types.
instance HasTypeVars Pred where
  typeVarsEx s f (ClassP n ts) = ClassP n <$> typeVarsEx s f ts
  typeVarsEx s f (EqualP l r) = EqualP <$> typeVarsEx s f l <*> typeVarsEx s f r
#endif
-- | Type variables in a constructor's field types and, for 'ForallC', its
-- context. Variables quantified by the 'ForallC' itself are skipped.
instance HasTypeVars Con where
  typeVarsEx bound f con =
      case con of
        NormalC n bangTys -> NormalC n <$> traverse (_2 visit) bangTys
        RecC n varBangTys -> RecC n <$> traverse (_3 visit) varBangTys
        InfixC (bl, tl) n (br, tr) ->
          (\tl' tr' -> InfixC (bl, tl') n (br, tr')) <$> visit tl <*> visit tr
        ForallC bs ctx inner ->
          let bound' = Set.union bound (Set.fromList (bs ^.. typeVars))
          in ForallC bs <$> typeVarsEx bound' f ctx <*> typeVarsEx bound' f inner
    where
      visit = typeVarsEx bound f
instance HasTypeVars t => HasTypeVars [t] where
  typeVarsEx bound f = traverse (typeVarsEx bound f)

instance HasTypeVars t => HasTypeVars (Maybe t) where
  typeVarsEx bound f = traverse (typeVarsEx bound f)

-- | Every type-variable name not bound by an enclosing forall.
typeVars :: HasTypeVars t => Traversal' t Name
typeVars = typeVarsEx Set.empty

-- | Rename type variables using the map. Names absent from the map are kept.
substTypeVars :: HasTypeVars t => Map Name Name -> t -> t
substTypeVars m = over typeVars (\n -> Map.findWithDefault n n m)
-- | Reify the named type and generate optics for its fields according to
-- the rules. Fails when the name is not a plain type constructor.
makeFieldOptics :: LensRules -> Name -> DecsQ
makeFieldOptics rules tyName = reify tyName >>= dispatch
  where
    dispatch (TyConI dec) = makeFieldOpticsForDec rules dec
    dispatch _ = fail "makeFieldOptics: Expected type constructor name"
-- | Generate optics for a data / newtype declaration or instance.
--
-- For plain declarations the record type is the constructor applied to its
-- type variables. For data/newtype instances it is the family applied to the
-- instance's argument types.
makeFieldOpticsForDec :: LensRules -> Dec -> DecsQ
makeFieldOpticsForDec rules dec =
  case dec of
    DataD _ tyName vars cons _ -> go tyName (fromVars tyName vars) cons
    NewtypeD _ tyName vars con _ -> go tyName (fromVars tyName vars) [con]
    DataInstD _ tyName args cons _ -> go tyName (conAppsT tyName args) cons
    NewtypeInstD _ tyName args con _ -> go tyName (conAppsT tyName args) [con]
    _ -> fail "makeFieldOptics: Expected data or newtype type-constructor"
  where
    go = makeFieldOpticsForDec' rules
    fromVars tyName vars = conAppsT tyName (VarT <$> toListOf typeVars vars)
-- | Generate optics for the given constructors of type @s@.
--
-- Each record field is mapped to its 'DefName's through '_fieldToDef'.
-- Positional fields map to none. One scaffold is built per distinct
-- 'DefName', and the declarations for each are emitted.
makeFieldOpticsForDec' :: LensRules -> Name -> Type -> [Con] -> DecsQ
makeFieldOpticsForDec' rules tyName s cons =
  do normalized <- traverse normalizeConstructor cons
     let fieldNames = toListOf (folded . _2 . folded . _1 . folded) normalized
         labelled   = over normFieldLabels (expandName fieldNames) normalized
         defNames   = setOf (normFieldLabels . folded) labelled
     scaffoldMap <- sequenceA (fromSet (buildScaffold rules s labelled) defNames)
     concat <$> traverse (makeFieldOptic rules) (Map.toList scaffoldMap)
  where
    -- The label slot of every field of every constructor.
    normFieldLabels :: Traversal [(Name,[(a,Type)])] [(Name,[(b,Type)])] a b
    normFieldLabels = traverse . _2 . traverse . _1

    expandName :: [Name] -> Maybe Name -> [DefName]
    expandName names = maybe [] (_fieldToDef rules tyName names)
-- | Flatten a constructor into its name and field list. Each field type is
-- paired with its record selector name when it has one.
--
-- Fields of 'ForallC' (existentially quantified) constructors are reported
-- without names.
normalizeConstructor ::
  Con ->
  Q (Name, [(Maybe Name, Type)])
normalizeConstructor con =
  case con of
    RecC n fields -> return (n, [ (Just sel, ty) | (sel, _, ty) <- fields ])
    NormalC n fields -> return (n, [ (Nothing, ty) | (_, ty) <- fields ])
    InfixC (_, lhs) n (_, rhs) -> return (n, [(Nothing, lhs), (Nothing, rhs)])
    ForallC _ _ inner -> forgetNames <$> normalizeConstructor inner
  where
    forgetNames (n, fields) = (n, [ (Nothing, ty) | (_, ty) <- fields ])
-- | Which clause generator 'makeFieldClauses' uses: 'GetterType' selects
-- 'makeGetterClause' (read-only), 'LensType' selects 'makeFieldOpticClause'.
data OpticType = GetterType | LensType
-- | Compute everything needed to emit one definition.
--
-- Takes the rules, the record type @s@, every constructor with its fields
-- (each field tagged with the 'DefName's it contributes to), and the
-- definition being generated. Returns:
--
--   * the 'OpticType' that selects the clause generator,
--   * the 'OpticStab' describing the optic's type, and
--   * for each constructor: its name, its arity, and the indices of the
--     fields this definition targets.
--
-- A Lens / Getter is chosen when every constructor has exactly one targeted
-- field; otherwise a Traversal / Fold.
buildScaffold ::
  LensRules ->
  Type ->
  [(Name, [([DefName], Type)])] ->
  DefName ->
  Q (OpticType, OpticStab, [(Name, Int, [Int])])
buildScaffold rules s cons defName =
  do (s', t, a, b) <- buildStab s (concatMap snd tagged)
     let pick single multi = if singleTarget then single else multi
         defType = case a ^? _ForallT of
           -- A rank-2 field type can only be read.
           Just (_, cx, a') ->
             OpticSa cx (pick ''Getter ''Fold) s' a'
           Nothing
             | not (_allowUpdates rules) ->
                 OpticSa [] (pick ''Getter ''Fold) s' a
             | _simpleLenses rules || s' == t && a == b ->
                 OpticSa [] (pick ''Lens' ''Traversal') s' a
             | otherwise ->
                 OpticStab (pick ''Lens ''Traversal) s' t a b
         opticType
           | has _ForallT a || not (_allowUpdates rules) = GetterType
           | otherwise = LensType
     return (opticType, defType, layout)
  where
    -- Right marks a field targeted by this definition; Left, any other field.
    tagged :: [(Name, [Either Type Type])]
    tagged = [ (n, map classify fs) | (n, fs) <- cons ]

    classify :: ([DefName], Type) -> Either Type Type
    classify (names, ty)
      | defName `elem` names = Right ty
      | otherwise = Left ty

    layout :: [(Name, Int, [Int])]
    layout = [ (n, length fs, findIndices (has _Right) fs) | (n, fs) <- tagged ]

    singleTarget :: Bool
    singleTarget = all (\(_, fs) -> length (rights fs) == 1) tagged
-- | The type of an optic being generated.
data OpticStab = OpticStab Name Type Type Type Type
                 -- ^ type-changing optic: optic type name, then @s t a b@
               | OpticSa Cxt Name Type Type
                 -- ^ simple optic: context, optic type name, then @s a@

-- | The full signature type, with its type variables quantified.
stabToType :: OpticStab -> Type
stabToType stab =
  case stab of
    OpticStab c s t a b -> quantifyType [] (conAppsT c [s, t, a, b])
    OpticSa cx c s a -> quantifyType cx (conAppsT c [s, a])

-- | The constraint context (always empty for type-changing optics).
stabToContext :: OpticStab -> Cxt
stabToContext stab =
  case stab of
    OpticSa cx _ _ _ -> cx
    OpticStab{} -> []

-- | The optic's type-constructor name.
stabToOptic :: OpticStab -> Name
stabToOptic stab =
  case stab of
    OpticStab c _ _ _ _ -> c
    OpticSa _ c _ _ -> c

-- | The source type @s@.
stabToS :: OpticStab -> Type
stabToS stab =
  case stab of
    OpticStab _ s _ _ _ -> s
    OpticSa _ _ s _ -> s

-- | The focus type @a@.
stabToA :: OpticStab -> Type
stabToA stab =
  case stab of
    OpticStab _ _ _ a _ -> a
    OpticSa _ _ _ a -> a
-- | Work out the @s t a b@ types of an optic.
--
-- All targeted (Right) field types are unified into @a@, and @s'@ is @s@
-- under that substitution. Type variables of @s@ that occur in no
-- non-targeted (Left) field get fresh names, which gives @t@ and @b@.
buildStab :: Type -> [Either Type Type] -> Q (Type,Type,Type,Type)
buildStab s categorizedFields =
  do (subst, a) <- unifyTypes targetTys
     let s' = applyTypeSubst subst s
     fresh <- sequenceA (fromSet (newName . nameBase) changeable)
     let rename = substTypeVars fresh
     return (s', rename s', a, rename a)
  where
    (fixedTys, targetTys) = partitionEithers categorizedFields
    -- Variables mentioned by a non-targeted field must stay fixed.
    changeable = setOf typeVars s `Set.difference` setOf typeVars fixedTys
-- | Emit the declarations for one definition: an optional class, an optional
-- signature, and the definition itself. Method-style definitions are placed
-- inside an instance.
makeFieldOptic ::
  LensRules ->
  (DefName, (OpticType, OpticStab, [(Name, Int, [Int])])) ->
  DecsQ
makeFieldOptic rules (defName, (opticType, defType, cons)) =
  do classDecs <- classDecsQ
     sequenceA (classDecs ++ sigDecs ++ bodyDecs)
  where
    clauses = makeFieldClauses rules opticType cons

    -- The definition plus any inline pragma for it.
    definition n = funD n clauses : inlinePragma n

    -- The class is generated only when requested and no type with that name
    -- is already in scope.
    classDecsQ = case defName of
      MethodName c n
        | _generateClasses rules ->
            do existing <- lookupTypeName (show c)
               return [ makeFieldClass defType c n | isNothing existing ]
      _ -> return []

    sigDecs
      | not (_generateSigs rules) = []
      | TopName n <- defName = [sigD n (return (stabToType defType))]
      | otherwise = []

    bodyDecs = case defName of
      TopName n -> definition n
      MethodName c n -> [makeFieldInstance defType c (definition n)]
-- | Declare @class C s a | s -> a where method :: Optic s a@. Any type
-- variables other than @s@ and @a@ are quantified on the method, together
-- with the optic's context.
makeFieldClass :: OpticStab -> Name -> Name -> DecQ
makeFieldClass defType className methodName =
  classD (cxt []) className binders [FunDep [sVar] [aVar]] [methodSig]
  where
    sVar = mkName "s"
    aVar = mkName "a"
    binders = map PlainTV [sVar, aVar]
    opticApplied = stabToOptic defType `conAppsT` [VarT sVar, VarT aVar]
    methodType =
      quantifyType' (Set.fromList [sVar, aVar]) (stabToContext defType) opticApplied
    methodSig = sigD methodName (return methodType)
-- | Declare @instance C S A@ (with @S@ and @A@ taken from the optic's type)
-- containing the given declarations.
makeFieldInstance :: OpticStab -> Name -> [DecQ] -> DecQ
makeFieldInstance defType className decs =
  instanceD (cxt []) (return instanceHead) decs
  where
    instanceHead = conAppsT className [stabToS defType, stabToA defType]
-- | One clause per constructor. Lazy patterns are used only when the rules
-- request them and the type has exactly one constructor.
makeFieldClauses :: LensRules -> OpticType -> [(Name, Int, [Int])] -> [ClauseQ]
makeFieldClauses rules opticType cons =
  [ mk conName fieldCount fields | (conName, fieldCount, fields) <- cons ]
  where
    mk = case opticType of
      GetterType -> makeGetterClause
      LensType -> \c n fs -> makeFieldOpticClause c n fs irref
    irref = _lazyPatterns rules && length cons == 1
-- | Clause for a constructor with no targeted fields:
--
-- > \_ (C x1 .. xn) -> pure (C x1 .. xn)
makePureClause :: Name -> Int -> ClauseQ
makePureClause conName fieldCount =
  do vars <- replicateM fieldCount (newName "x")
     let rebuilt = appsE (conE conName : map varE vars)
         body = normalB (appE (varE 'pure) rebuilt)
     clause [wildP, conP conName (map varP vars)] body []
-- | Clause for a read-only optic ('Getter' or 'Fold') on one constructor.
--
-- @fieldCount@ is the constructor's arity and @fields@ the ascending indices
-- of the targeted fields. The generated clause is
--
-- > \f (C .. x1 .. x2 ..) -> coerce (f x1) <*> f x2 <*> .. <*> f xn
--
-- with wildcards for non-targeted fields. A constructor with no targeted
-- fields gets 'makePureClause'.
makeGetterClause :: Name -> Int -> [Int] -> ClauseQ
makeGetterClause conName fieldCount [] = makePureClause conName fieldCount
makeGetterClause conName fieldCount fields@(_ : otherFields) =
  do f <- newName "f"
     -- One variable per targeted field. Matching on 'fields' guarantees at
     -- least one, so no partial head/tail is needed.
     x1 <- newName "x"
     xrest <- replicateM (length otherFields) (newName "x")
     let xs = x1 : xrest

         -- Walk every field index: bind the next variable at a targeted
         -- position, and use a wildcard elsewhere.
         pats (i:is) (y:ys)
           | i `elem` fields = varP y : pats is ys
           | otherwise       = wildP  : pats is (y:ys)
         pats is _ = map (const wildP) is

         fx x = appE (varE f) (varE x)
         body = foldl (\acc x -> appsE [varE '(<*>), acc, fx x])
                      (appE (varE 'coerce) (fx x1))
                      xrest

     -- Field indices are 0-based: 0 .. fieldCount - 1.
     clause [varP f, conP conName (pats [0 .. fieldCount - 1] xs)]
            (normalB body)
            []
-- | Clause for a lens or traversal on one constructor.
--
-- For a constructor @C@ of arity @fieldCount@ and targeted field indices
-- @i1, i2, ..@, the generated clause is
--
-- > \f (C x0 .. xn) -> fmap (\y1 y2 .. -> C x0 .. y1 .. y2 .. xn) (f x_i1) <*> f x_i2 <*> ..
--
-- When @irref@ is set, the constructor pattern is made irrefutable (@~@).
-- A constructor with no targeted fields gets 'makePureClause'.
makeFieldOpticClause :: Name -> Int -> [Int] -> Bool -> ClauseQ
makeFieldOpticClause conName fieldCount [] _ =
  makePureClause conName fieldCount
makeFieldOpticClause conName fieldCount (field : fields) irref =
  do f  <- newName "f"
     xs <- replicateM fieldCount (newName "x")
     ys <- replicateM (1 + length fields) (newName "y")
     let targets = field : fields
         -- The field variables, with each targeted position replaced by its y.
         rebuiltArgs = foldr (uncurry setIx) xs (zip targets ys)
         applyF i = appE (varE f) (varE (xs !! i))
         rebuild = lamE (map varP ys) (appsE (conE conName : map varE rebuiltArgs))
         firstStep = appsE [varE 'fmap, rebuild, applyF field]
         body = foldl (\acc i -> appsE [varE '(<*>), acc, applyF i]) firstStep fields
         conPat = conP conName (map varP xs)
         argPat
           | irref = tildeP conPat
           | otherwise = conPat
     clause [varP f, argPat] (normalB body) []
-- | Unify a non-empty list of types from left to right. Returns the
-- accumulated substitution and the unified type.
unifyTypes :: [Type] -> Q (Map Name Type, Type)
unifyTypes [] = fail "unifyTypes: Bug: Unexpected empty list"
unifyTypes (first : rest) = foldM step (Map.empty, first) rest
  where
    step (sub, acc) = unify1 sub acc
-- | Unify two types under an existing substitution. Returns the extended
-- substitution and the unified type.
--
-- Equations are tried in order, so their order is significant.
unify1 :: Map Name Type -> Type -> Type -> Q (Map Name Type, Type)
-- A left-hand variable already bound by the substitution: use its binding.
unify1 sub (VarT x) y
  | Just r <- Map.lookup x sub = unify1 sub r y
-- The same for a right-hand variable.
unify1 sub x (VarT y)
  | Just r <- Map.lookup y sub = unify1 sub x r
-- Syntactically equal types unify trivially.
unify1 sub x y
  | x == y = return (sub, x)
-- Applications unify component-wise. The substitution found while unifying
-- the arguments is applied back to the unified head.
unify1 sub (AppT f1 x1) (AppT f2 x2) =
  do (sub1, f) <- unify1 sub f1 f2
     (sub2, x) <- unify1 sub1 x1 x2
     return (sub2, AppT (applyTypeSubst sub2 f) x)
-- Bind an unbound right-hand variable, after an occurs check.
unify1 sub x (VarT y)
  | elemOf typeVars y (applyTypeSubst sub x) =
      fail "Failed to unify types: occurs check"
  | otherwise = return (Map.insert y x sub, x)
-- An unbound left-hand variable: swap sides so the equation above binds it.
unify1 sub (VarT x) y = unify1 sub y (VarT x)
-- Context-free foralls: unify the bodies, then merge the binder lists after
-- renaming them through 'limitedSubst'.
unify1 sub (ForallT v1 [] t1) (ForallT v2 [] t2) =
  do (sub1,t) <- unify1 sub t1 t2
     v <- fmap nub (traverse (limitedSubst sub1) (v1++v2))
     return (sub1, ForallT v [] t)
-- Anything else does not unify.
unify1 _ x y = fail ("Failed to unify types: " ++ show (x,y))
-- | Rename a forall binder using a substitution from 'unify1'.
--
-- A binder may only be mapped to another type variable; chains of such
-- renamings are followed. Mapping to any other type fails. Binders absent
-- from the substitution are returned unchanged.
limitedSubst :: Map Name Type -> TyVarBndr -> Q TyVarBndr
limitedSubst sub tv =
  case Map.lookup (tv ^. name) sub of
    Nothing -> return tv
    Just (VarT m) -> limitedSubst sub (tv & name .~ m)
    Just _ -> fail "Unable to unify exotic higher-rank type"
-- | Replace type variables bound by the substitution throughout a type,
-- using 'rewrite'.
applyTypeSubst :: Map Name Type -> Type -> Type
applyTypeSubst sub = rewrite lookupVar
  where
    lookupVar ty =
      case ty of
        VarT n -> Map.lookup n sub
        _ -> Nothing
-- | Configuration controlling which optics are generated and how.
data LensRules = LensRules
  { _simpleLenses :: Bool
    -- ^ always generate non-type-changing optics (Lens' / Traversal')
  , _generateSigs :: Bool
    -- ^ emit type signatures for top-level definitions
  , _generateClasses :: Bool
    -- ^ for method definitions, generate the class when no type of that
    -- name is already in scope
  , _allowUpdates :: Bool
    -- ^ when False, only Getter / Fold are generated
  , _lazyPatterns :: Bool
    -- ^ use irrefutable patterns in lens clauses (single-constructor types only)
  , _fieldToDef :: Name -> [Name] -> Name -> [DefName]
    -- ^ given the type name, all field names, and one field name, the
    -- definitions that field contributes to
  }
-- | The name of a definition produced for a field.
data DefName
  = TopName Name
    -- ^ a top-level definition with this name
  | MethodName Name Name
    -- ^ a class method: the class name, then the method name
  deriving (Show, Eq, Ord)
-- | Quantify every type variable of the type, under the given context.
quantifyType :: Cxt -> Type -> Type
quantifyType = quantifyType' Set.empty

-- | Like 'quantifyType', but variables in the exclusion set are left
-- unquantified (e.g. class parameters already bound by the class head).
quantifyType' :: Set Name -> Cxt -> Type -> Type
quantifyType' exclude c t = ForallT binders c t
  where
    binders = [ PlainTV v | v <- Set.toList (setOf typeVars t `Set.difference` exclude) ]
-- | Pragmas to attach to a generated definition. When INLINING is defined
-- this is a single INLINE pragma; otherwise there are none. The inner CPP
-- branches choose the pragma-building call for each template-haskell
-- version.
inlinePragma :: Name -> [DecQ]
#ifdef INLINING
#if MIN_VERSION_template_haskell(2,8,0)
# ifdef OLD_INLINE_PRAGMAS
inlinePragma methodName = [pragInlD methodName (inlineSpecNoPhase Inline False)]
# else
inlinePragma methodName = [pragInlD methodName Inline FunLike AllPhases]
# endif
#else
inlinePragma methodName = [pragInlD methodName (inlineSpecNoPhase True False)]
#endif
#else
inlinePragma _ = []
#endif
-- | Apply a type constructor to arguments, left-nested:
-- @conAppsT T [a, b] = AppT (AppT (ConT T) a) b@.
conAppsT :: Name -> [Type] -> Type
conAppsT conName = go (ConT conName)
  where
    go acc [] = acc
    go acc (arg : args) = go (AppT acc arg) args