module Control.Reference.TH.Records (makeReferences, debugTH) where
import Language.Haskell.TH hiding (ListT)
import qualified Data.Map as M
import Data.List
import Data.Maybe
import Data.Function (on)
import Control.Monad
import Control.Monad.Writer
import Control.Monad.Trans.State
import Control.Instances.Morph
import Control.Reference.InternalInterface
import Control.Reference.Examples.TH
import Control.Reference.TupleInstances
-- | Run a declaration splice, pretty-print the resulting declarations to
-- standard output at compile time, and then discard them: the wrapped splice
-- contributes no declarations to the module.
debugTH :: Q [Dec] -> Q [Dec]
debugTH decls = do
  generated <- decls
  runIO (putStrLn (pprint generated))
  return []
-- | Reify the named data type or newtype and generate references for its
-- record fields via 'createReferences'.
--
-- A newtype is first converted to an equivalent single-constructor data
-- declaration with 'newtypeToData'. The splice fails when the name does not
-- refer to a type constructor declaration, or when that declaration is not a
-- data type or newtype.
makeReferences :: Name -> Q [Dec]
makeReferences n = reify n >>= fromInfo
  where
    fromInfo (TyConI decl) = fromDecl (newtypeToData decl)
    fromInfo _ = fail "makeReferences: Expected the name of a data type or newtype"

    fromDecl (DataD _ tyConName tyVars _ cons _) =
      createReferences tyConName (tyVars ^? traversal & typeVarName) cons
    fromDecl _ = fail "makeReferences: Unsupported data type"
-- | Generate references for the record fields of a data type.
--
-- Fields present in every constructor get a total lens
-- ('createLensForField'). Fields present in only some constructors get a
-- partial lens ('createPartialLensForField').
--
-- A type argument is treated as mutable, meaning it may change type through a
-- generated reference, only if it occurs at most once among the field types
-- of the data type. The count starts from two copies of every argument and
-- removes one copy per occurrence, so an argument occurring twice or more
-- disappears from the list.
--
-- Each record field is counted once, even when several constructors share
-- it. The positional fields of non-record constructors are counted as well,
-- because the generated partial lenses rebuild those constructors unchanged.
createReferences :: Name -> [Name] -> [Con] -> Q [Dec]
createReferences tyConName args cons
  = do comps <- mapM (createLensForField tyConName args mutableVars) complete
       parts <- mapM (createPartialLensForField tyConName args mutableVars cons) partials
       return $ concat (comps ++ parts)
  where
    -- One group per distinct record field. A field shared by several
    -- constructors has the same type in each, so its (name, type) pairs are
    -- equal and land in the same group. 'group' never returns an empty
    -- group, so the uses of 'head' below are safe.
    toGenerate = group $ sortBy (compare `on` fst) $ concatMap recordFieldsOf cons

    -- A field is complete when it occurs in every constructor.
    (completeGroups, partialGroups)
      = partition ((length cons ==) . length) toGenerate
    complete = map head completeGroups
    partials = map head partialGroups

    -- Every field type that constrains mutability: one entry per record
    -- field, plus all positional fields of non-record constructors.
    fieldTypes = map (snd . head) toGenerate ++ concatMap positionalTypesOf cons

    mutableVars
      = foldl' (flip delete) (args ++ args)
               (concatMap (\t -> (t ^? typeVariableNames :: [Name])) fieldTypes)

    recordFieldsOf (RecC _ fields) = [ (fld, ty) | (fld, _, ty) <- fields ]
    recordFieldsOf _ = []

    positionalTypesOf (NormalC _ fields) = map snd fields
    positionalTypesOf (InfixC (_, lhs) _ (_, rhs)) = [lhs, rhs]
    positionalTypesOf _ = []
-- | Generate a total lens for a record field that exists in every
-- constructor.
--
-- The result is a type signature, built by 'referenceType' with 'Lens', and a
-- value binding. The value applies 'lens' to the field selector and to a
-- two-argument setter lambda that performs a record update of the field.
-- The reference is named by 'refName'.
createLensForField :: Name -> [Name] -> [Name] -> (Name,Type) -> Q [Dec]
createLensForField typName typArgs mutArgs (fldName, fldTyp) = do
  sig <- referenceType (ConT ''Lens) typName typArgs mutArgs fldTyp
  newVal <- newName "b"
  record <- newName "s"
  let setter = LamE [VarP newVal, VarP record]
                    (RecUpdE (VarE record) [(fldName, VarE newVal)])
      getterAndSetter = VarE 'lens `AppE` VarE fldName `AppE` setter
      refN = refName fldName
  return [ SigD refN sig
         , ValD (VarP refN) (NormalB getterAndSetter) []
         ]
-- | Generate a partial lens for a record field that is present in only some
-- of the data type's constructors.
--
-- The signature is built by 'referenceType' with 'Partial'. The body applies
-- 'partial' to a lambda that cases on its argument, with one alternative per
-- constructor:
--
-- * a constructor that has the field yields @Right@ of a pair of the field
--   value and a setter; the setter re-applies the same constructor with that
--   argument replaced.
--
-- * a constructor without the field yields @Left@ of the value, rebuilt from
--   its freshly bound components.
createPartialLensForField :: Name -> [Name] -> [Name] -> [Con] -> (Name,Type) -> Q [Dec]
createPartialLensForField typName typArgs mutArgs cons (fldName,fldTyp)
  = do lTyp <- referenceType (ConT ''Partial) typName typArgs mutArgs fldTyp
       lensBody <- genLensBody
       return [ SigD lensName lTyp
              , ValD (VarP lensName) (NormalB lensBody) []
              ]
  where
    lensName = refName fldName

    -- Split the constructors by whether they declare the field.
    (consWithField, consWithoutField)
      = partition (hasField fldName) cons

    genLensBody :: Q Exp
    genLensBody
      = do matchesWithField <- mapM matchWithField consWithField
           matchesWithoutField <- mapM matchWithoutField consWithoutField
           name <- newName "x"
           return $ VarE 'partial
                      `AppE` LamE [VarP name]
                                  (CaseE (VarE name)
                                         (matchesWithField ++ matchesWithoutField))

    matchWithField :: Con -> Q Match
    matchWithField con
      = do (bind, rebuild, vars) <- bindAndRebuild con
           setVar <- newName "b"
           -- 'consWithField' only holds constructors accepted by 'hasField',
           -- so this lookup is expected to succeed. If it does not, report a
           -- descriptive splice error rather than crashing on an irrefutable
           -- pattern.
           bindInd <- maybe (fail missingFieldMsg) return (fieldIndex fldName con)
           let bindRight
                 = ConE 'Right
                     `AppE` TupE [ VarE (vars !! bindInd)
                                 , LamE [VarP setVar]
                                        (funApplication & element (bindInd+1)
                                           .= VarE setVar $ rebuild)
                                 ]
           return $ Match bind (NormalB bindRight) []

    missingFieldMsg
      = "createPartialLensForField: field " ++ nameBase fldName
          ++ " not found in constructor"

    -- The value is rebuilt from its components instead of being returned
    -- as-is.
    matchWithoutField :: Con -> Q Match
    matchWithoutField con
      = do (bind, rebuild, _) <- bindAndRebuild con
           return $ Match bind (NormalB (ConE 'Left `AppE` rebuild)) []
-- | Build the type signature of a generated reference.
--
-- The result has the form
-- @forall vs. refType (T args) (T args') fldTyp fldTyp'@, where:
--
-- * @fldTyp'@ is @fldTyp@ with the mutable variables renamed by 'makePoly';
--
-- * @args'@ applies the same renaming to the type arguments;
--
-- * @vs@ are the original arguments together with the fresh names,
--   deduplicated and sorted.
referenceType :: Type -> Name -> [Name] -> [Name] -> Type -> Q Type
referenceType refType name args mutArgs fldTyp = do
  (newFldTyp, renaming) <- makePoly mutArgs fldTyp
  let renameVar a = fromMaybe a (renaming ^? element a)
      newArgs = traversal .- renameVar $ args
      binders = map PlainTV (sort (nub (M.elems renaming ++ args)))
      applied = foldl AppT refType
                  [ addTypeArgs name args
                  , addTypeArgs name newArgs
                  , fldTyp
                  , newFldTyp
                  ]
  return (ForallT binders [] applied)
-- | Rename the listed type variables inside a field type.
--
-- Every occurrence in the type of a variable listed in the first argument is
-- replaced by a fresh name, made from its base name plus a trailing prime.
-- Returns the rewritten type and a map from each original name to the fresh
-- name inserted for it. If a variable occurred more than once, the map keeps
-- the name inserted last.
makePoly :: [Name] -> Type -> Q (Type, M.Map Name Name)
makePoly typArgs fldTyp = runStateT (selectedVars !~ freshen $ fldTyp) M.empty
  where
    selectedVars :: Simple Traversal Type Name
    selectedVars = typeVariableNames & filtered (`elem` typArgs)

    freshen old = do
      new <- lift (newName (nameBase old ++ "'"))
      modify (M.insert old new)
      return new
-- | Name of the generated reference for a field.
--
-- The base name's leading underscore is toggled: it is stripped if present
-- (@_foo@ becomes @foo@) and added otherwise (@foo@ becomes @_foo@).
refName :: Name -> Name
refName = nameBaseStr .- toggleUnderscore
  where
    toggleUnderscore ('_' : rest) = rest
    toggleUnderscore base = '_' : base
-- | Whether the constructor declares a record field with the given name.
hasField :: Name -> Con -> Bool
hasField n c = n `elem` (c ^? recFields & traversal & _1 :: [Name])
-- | Zero-based position of the named field among the constructor's record
-- fields, or 'Nothing' when no such field is found.
fieldIndex :: Name -> Con -> Maybe Int
fieldIndex n con = do
  fields <- con ^? recFields
  findIndex (\f -> (f ^. _1) == n) fields
-- | Apply a type constructor to type variables, left to right:
-- @addTypeArgs T [a, b]@ is the type @T a b@.
addTypeArgs :: Name -> [Name] -> Type
addTypeArgs n vars = foldl (\acc v -> AppT acc (VarT v)) (ConT n) vars
-- | Turn a newtype declaration into the equivalent single-constructor data
-- declaration; any other declaration is returned unchanged.
newtypeToData :: Dec -> Dec
newtypeToData dec = case dec of
  NewtypeD ctx name tvars kind con derives -> DataD ctx name tvars kind [con] derives
  other -> other
-- | Create one fresh variable per constructor field and return three things:
--
-- * a constructor pattern binding those variables;
--
-- * an expression built from the constructor followed by the variables via
--   'turn' 'funApplication' (i.e. the constructor applied to them);
--
-- * the variables themselves, in field order.
bindAndRebuild :: Con -> Q (Pat, Exp, [Name])
bindAndRebuild con = do
  let cName = con ^. conName
      arity = length (con ^. conFields)
  vars <- replicateM arity (newName "fld")
  let pat = ConP cName (map VarP vars)
      expr = (ConE cName : map VarE vars) ^. turn funApplication
  return (pat, expr, vars)
-- | Identity morphism from a 'StateT' computation to the same 'StateT' type.
--
-- NOTE(review): this is an orphan instance, since neither 'Morph' nor
-- 'StateT' is defined in this module. It is presumably required by the
-- monadic update used in 'makePoly'; confirm before moving or removing it.
instance Morph (StateT s m) (StateT s m) where
  morph m = m