module Data.Derive.Internal.Traversal(
TraveralType(..), defaultTraversalType,
traversalDerivation1,
traversalInstance, traversalInstance1,
deriveTraversal
) where
import Language.Haskell
import Data.Derive.Internal.Derivation
import Data.List
import qualified Data.Set as S
import Control.Monad.Writer
import Control.Applicative
import Data.Generics.PlateData
import Data.Maybe
-- | 'Applicative' instance for 'Writer', defined from its 'Monad' instance
-- ('pure' is 'return', '<*>' is 'ap').
--
-- NOTE(review): presumably a compatibility shim for an mtl whose 'Writer' had
-- no 'Applicative' instance. With transformers-based mtl, 'Writer w' is
-- 'WriterT w Identity', which already has one, so this orphan would clash as
-- a duplicate instance — confirm the targeted mtl version.
instance Monoid w => Applicative (Writer w) where
    pure = return
    (<*>) = ap
-- | Expressions that implement (part of) a traversal.
type Trav = Exp

-- | Description of one traversal-style class. (The misspelt type name is part
-- of the exported interface.)
data TraveralType = TraveralType
    { traversalArg :: Int
      -- ^ Position of the traversed type variable, counted from the right
      -- (1 is the last type variable). When greater than 1 it is also
      -- appended to the class and method names.
    , traversalCo :: Bool
      -- ^ Set while deriving inside the argument of a function type (it is
      -- flipped at every function argument). An occurrence of the traversed
      -- variable while it is set is rejected.
    , traversalName :: QName
      -- ^ Base method name; 'traversalNameN' appends the position number for
      -- positions greater than 1.
    , traversalId :: Trav
      -- ^ Traversal for components that do not mention the traversed
      -- variable; such components are skipped when traversals are combined.
    , traversalDirect :: Trav
      -- ^ Traversal for a component whose type is exactly the traversed
      -- variable.
    , traversalFunc :: QName -> Trav -> Trav
      -- ^ Given a (numbered) method name and the traversal of one argument of
      -- an applied type, build the traversal of that application.
    , traversalPlus :: Trav -> Trav -> Trav
      -- ^ Combine the traversals of several arguments of one type
      -- application (folded from the left over them).
    , traverseArrow :: Maybe (Trav -> Trav -> Trav)
      -- ^ Combine the traversals of a function type's argument and result;
      -- 'Nothing' means types containing function arrows are rejected.
    , traverseTuple :: [Exp] -> Exp
      -- ^ Rebuild a tuple from its per-component traversal applications.
    , traverseCtor :: String -> [Exp] -> Exp
      -- ^ Build an equation's right-hand side from the constructor name and
      -- the per-field traversal applications.
    , traverseFunc :: Pat -> Exp -> Match
      -- ^ Build one method equation from the constructor pattern and the
      -- right-hand side; the equation's name is overwritten afterwards.
    }
-- | Baseline 'TraveralType' that concrete traversals override.
--
-- It traverses the last type argument, using @id@ for components that do not
-- mention it and @_f@ for direct occurrences. Argument traversals are built by
-- applying the numbered method name with 'appP', and combined into an
-- application of the list constructor @(:)@. Function types are rejected,
-- and tuples and constructors are rebuilt directly.
--
-- 'traversalName' and 'traverseFunc' have no sensible default and must be
-- overridden; forcing them here raises a descriptive error rather than the
-- anonymous failure of 'undefined'.
defaultTraversalType :: TraveralType
defaultTraversalType = TraveralType
    { traversalArg = 1
    , traversalCo = False
    , traversalName = error "defaultTraversalType: traversalName must be overridden"
    , traversalId = var "id"
    , traversalDirect = var "_f"
    , traversalFunc = \x y -> appP (Var x) y
    , traversalPlus = \x y -> apps (Con $ Special Cons) [paren x, paren y]
    , traverseArrow = Nothing
    , traverseTuple = Tuple
    , traverseCtor = \x y -> apps (con x) (map paren y)
    , traverseFunc = error "defaultTraversalType: traverseFunc must be overridden"
    }
-- | A class constraint the generated instance needs: the type variable
-- 'requiredDataArg' must be an instance of the traversal class numbered
-- 'requiredPosition'. Collected while deriving the method body and turned
-- into the instance context by 'traversalInstance'.
data RequiredInstance = RequiredInstance
    { requiredDataArg :: String
      -- ^ Name of a type variable that is applied to other types.
    , requiredPosition :: Int
      -- ^ Argument position it is traversed at, counted from its last
      -- argument (1).
    }
    deriving (Eq, Ord)

-- | Writer that accumulates the constraints required by the instance being
-- derived; a 'S.Set' so each constraint is recorded only once.
type WithInstances a = Writer (S.Set RequiredInstance) a
-- | @vars f c n@ produces the numbered names @c1@ .. @cn@, each wrapped with
-- @f@ (used with 'pVar', 'var' and 'tyVar' to build patterns, expressions
-- and types). Returns the empty list when @n <= 0@.
vars :: (String -> a) -> Char -> Int -> [a]
vars f c n = [f (c : show i) | i <- [1 .. n]]
-- | Package a traversal as a 'Derivation'. It is registered under @nm@, with
-- 'traversalArg' appended when that is greater than 1 (so base name
-- @\"Functor\"@ at position 2 is registered as @\"Functor2\"@). Instance
-- generation is delegated to 'traversalInstance1'.
traversalDerivation1 :: TraveralType -> String -> Derivation
traversalDerivation1 tt nm = Derivation registeredName generate
    where
        position = traversalArg tt
        suffix
            | position > 1 = show position
            | otherwise = ""
        registeredName = nm ++ suffix
        generate = traversalInstance1 tt nm
-- | Check that the data type is supported and, if so, generate the instance.
--
-- Returns 'Left' when the type mentions a function arrow but the traversal
-- defines no 'traverseArrow', or when the data type has no type parameters.
traversalInstance1 :: TraveralType -> String -> FullDataDecl -> Either String [Decl]
traversalInstance1 tt nm (_, dat)
    | arrowsRejected && mentionsArrow =
        Left $ "Can't derive " ++ prettyPrint (traversalName tt) ++ " for types with arrow"
    | nullary = Left "Cannot derive class for data type arity == 0"
    | otherwise = Right $ traversalInstance tt nm dat [method]
    where
        arrowsRejected = isNothing (traverseArrow tt)
        mentionsArrow = any isTyFun (universeBi dat)
        nullary = dataDeclArity dat == 0
        method = deriveTraversal tt dat
-- | Assemble the instance declaration for a traversal class.
--
-- The class is @nameBase@, suffixed with 'traversalArg' when that is greater
-- than 1. The instance head applies the data type to the type variables that
-- precede the traversed one (the 'traversalArg'-th from the right). The
-- variables that follow it are passed as extra class arguments. The context
-- has one constraint per 'RequiredInstance' produced while running the method
-- bodies @bodyM@.
traversalInstance :: TraveralType -> String -> DataDecl -> [WithInstances Decl] -> [Decl]
traversalInstance tt nameBase dat bodyM = [simplify $ InstDecl sl ctx nam args (map InsDecl body)]
    where
        (body, required) = runWriter (sequence bodyM)
        -- A requirement at position p needs the p-numbered class, applied to
        -- the type variable plus p - 1 fresh variables s1 .. s(p-1), mirroring
        -- the trailing arguments of the instance head.
        ctx = [ ClassA (qname $ className p) (tyVar n : vars tyVar 's' (p - 1))
              | RequiredInstance n p <- S.toList required
              ]
        vrs = vars tyVar 't' (dataDeclArity dat)
        -- Remove the traversed variable (counted from the right); the ones
        -- before it stay applied to the data type, the ones after it become
        -- additional class arguments.
        (vrsBefore, vrsAfter) = case splitAt (length vrs - traversalArg tt) vrs of
            (before, _ : after) -> (before, after)
            _ -> error "traversalInstance: traversalArg is out of range for the data type's arity"
        className n = nameBase ++ (if n > 1 then show n else "")
        nam = qname (className (traversalArg tt))
        args = TyParen (tyApps (tyCon $ dataDeclName dat) vrsBefore) : vrsAfter
-- | Derive the method binding: one equation per constructor of @dat@, each
-- renamed to the unqualified method name for 'traversalArg'. Constraints
-- needed by the equations are accumulated in the writer.
deriveTraversal :: TraveralType -> DataDecl -> WithInstances Decl
deriveTraversal tt dat = fmap (FunBind . relabel) equations
    where
        positions = argPositions dat
        equations = mapM (deriveTraversalCtor tt positions) (dataDeclCtors dat)
        methodName = unqual (traversalNameN tt (traversalArg tt))
        -- Give every equation the method's name, keeping everything else.
        relabel ms = [Match sl methodName ps ty rhs bs | Match _ _ ps ty rhs bs <- ms]
        unqual (Qual _ x) = x
        unqual (UnQual x) = x
-- | Derive the method equation for one constructor. The left-hand side
-- matches the constructor with fields @a1@ .. @an@. The right-hand side is
-- built by 'traverseCtor' from each field's traversal applied to its variable.
deriveTraversalCtor :: TraveralType -> ArgPositions -> CtorDecl -> WithInstances Match
deriveTraversalCtor tt ap ctor =
    equation <$> mapM (deriveTraversalType tt ap . fromBangType . snd) (ctorDeclFields ctor)
    where
        cName = ctorDeclName ctor
        fieldCount = ctorDeclArity ctor
        lhs = PParen $ PApp (qname cName) (vars pVar 'a' fieldCount)
        equation fieldTravs =
            traverseFunc tt lhs (traverseCtor tt cName (zipWith App fieldTravs (vars var 'a' fieldCount)))
-- | Derive the traversal for a single field type.
--
-- * Types that do not mention the traversed variable become 'traversalId'.
-- * The traversed variable itself becomes 'traversalDirect'. An occurrence
--   inside a function argument ('traversalCo' set) is a contravariant
--   occurrence and is rejected.
-- * Function types use 'traverseArrow', flipping 'traversalCo' for the
--   argument.
-- * List and tuple types are rewritten as applications of their type
--   constructors; applications are handled by 'deriveTraversalApp'.
-- * Any type form not matched above is rejected with a message naming it,
--   instead of failing with a pattern-match error.
deriveTraversalType :: TraveralType -> ArgPositions -> Type -> WithInstances Trav
deriveTraversalType tt ap (TyParen x) = deriveTraversalType tt ap x
deriveTraversalType _ _ TyForall{} = fail "forall not supported in traversal deriving"
deriveTraversalType tt ap (TyFun a b) =
    arrow
        <$> deriveTraversalType tt{traversalCo = not $ traversalCo tt} ap a
        <*> deriveTraversalType tt ap b
    where
        -- traversalInstance1 rejects arrows up front when this is Nothing; the
        -- message matters for direct callers of deriveTraversal. Same laziness
        -- as the previous fromJust: it only fires if the result is forced.
        arrow = fromMaybe (error "deriveTraversalType: function types not supported by this traversal")
                          (traverseArrow tt)
deriveTraversalType tt ap (TyApp a b) = deriveTraversalApp tt ap a [b]
deriveTraversalType tt ap (TyList a) = deriveTraversalType tt ap $ TyApp (TyCon $ Special ListCon) a
deriveTraversalType tt ap (TyTuple b a) = deriveTraversalType tt ap $ tyApps (TyCon $ Special $ TupleCon b $ length a) a
deriveTraversalType tt _ (TyCon _) = return $ traversalId tt
deriveTraversalType tt ap (TyVar (Ident n))
    | ap n /= traversalArg tt = return $ traversalId tt
    | traversalCo tt = fail "tyvar used in contravariant position"
    | otherwise = return $ traversalDirect tt
deriveTraversalType _ _ t = fail $ "type not supported in traversal deriving: " ++ prettyPrint t
-- | Derive the traversal for a type application. Nested 'TyApp's are unwound
-- first: the 'Type' argument is the head found so far and the list holds the
-- arguments to its right.
--
-- For a general head, every argument is traversed. Arguments whose traversal
-- is not 'traversalId' become numbered method calls via 'traverseArg' (the
-- last argument is position 1) and are combined with 'traversalPlus'. When the
-- head is a type variable other than the traversed one, a 'RequiredInstance'
-- is recorded for each such position so the instance context demands it.
deriveTraversalApp :: TraveralType -> ArgPositions -> Type -> [Type] -> WithInstances Trav
deriveTraversalApp tt ap (TyApp a b) args = deriveTraversalApp tt ap a (b : args)
-- NOTE(review): deriveTraversalType rewrites TyTuple into an application of
-- TyCon (Special (TupleCon ..)), so a TyTuple head probably never reaches this
-- clause and tuples take the generic path below — confirm whether this clause
-- was meant to match the tuple type constructor instead.
deriveTraversalApp tt ap tycon@TyTuple{} args = do
    tArgs <- mapM (deriveTraversalType tt ap) args
    return $
        -- No component mentions the traversed variable: nothing to do.
        -- Otherwise rebuild the tuple with a lambda traversing each slot.
        if (all (== traversalId tt) tArgs) then
            traversalId tt
        else
            Lambda sl [PTuple (vars pVar 't' (length args))]
                (traverseTuple tt $ zipWith App tArgs (vars var 't' (length args)))
deriveTraversalApp tt ap tycon args = do
    -- The result is never used; only its writer output (if any) is kept.
    -- NOTE(review): presumably also meant to reject invalid uses of the head.
    tCon <- deriveTraversalType tt ap tycon
    tArgs <- mapM (deriveTraversalType tt ap) args
    case tycon of
        -- The traversed variable itself must not be applied to arguments.
        TyVar (Ident n) | ap n == traversalArg tt -> fail "kind error: type used type constructor"
                        | otherwise -> tell $ S.fromList
                            [ RequiredInstance n i
                            | (t,i) <- zip (reverse tArgs) [1..]
                            , t /= traversalId tt
                            ]
        _ -> return ()
    -- Positions are numbered from the last argument (1) leftwards.
    let nonId = [ traverseArg tt i t
                | (t,i) <- zip (reverse tArgs) [1..]
                , t /= traversalId tt
                ]
    return $ case nonId of
        [] -> traversalId tt
        _ -> foldl1 (traversalPlus tt) nonId
-- | Traversal of the argument at position @n@ (1 = last) of an applied type.
-- 'traversalFunc' receives the method name numbered for @n@ together with
-- that argument's traversal.
traverseArg :: TraveralType -> Int -> Trav -> Trav
traverseArg tt = traversalFunc tt . traversalNameN tt
-- | Method name for argument position @n@. For @n <= 1@ it is 'traversalName'
-- itself; otherwise @n@ is appended to it (e.g. @fmap@ becomes @fmap2@).
-- Only identifier names can be numbered. Operator and special names are
-- rejected with a descriptive error instead of a pattern-match failure.
traversalNameN :: TraveralType -> Int -> QName
traversalNameN tt n
    | n <= 1 = base
    | otherwise = withSuffix base
    where
        base = traversalName tt
        suffix = show n
        withSuffix (Qual m x) = Qual m (extend x)
        withSuffix (UnQual x) = UnQual (extend x)
        withSuffix _ = error $ "traversalNameN: cannot number special name " ++ prettyPrint base
        extend (Ident x) = Ident (x ++ suffix)
        extend _ = error $ "traversalNameN: cannot number operator name " ++ prettyPrint base
-- | Maps a type variable name to its position counted from the right.
type ArgPositions = String -> Int

-- | Position of each of the data type's type variables, counted from the
-- right starting at 1: for @data T a b c@, @c@ is 1, @b@ is 2 and @a@ is 3.
-- This is the numbering 'traversalArg' uses (1 = last type argument).
-- Raises an error for a name that is not one of the declaration's type
-- variables.
argPositions :: DataDecl -> ArgPositions
argPositions dat = \nm -> case elemIndex nm args of
        Nothing -> error "impossible: tyvar not in scope"
        Just k -> length args - k
    where
        -- Computed once per declaration and shared by every lookup.
        args = dataDeclVars dat