{-# LANGUAGE TemplateHaskell            #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE UndecidableInstances       #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE PatternGuards              #-}

module Generics.MultiRec.Transformations.TH ( 
  deriveRefRep, prefix, postfix
  ) where

import Generics.MultiRec hiding (show)
import Generics.MultiRec.TH
import Language.Haskell.TH hiding (Stmt ())
import Generics.MultiRec.Transformations.Explicit
import Control.Monad
import Control.Applicative
import Debug.Trace

-- | Derive the data types with references and the 'HasRef' instance for a
--   family. For a data type N the name of the constructor for a reference
--   is RefN, and the given function is used to rename the remaining
--   constructors and the data type itself. For example, for the following
--   definition:
--
-- > data Tree = Leaf Int | Bin Tree Tree
-- > data TreeAST :: * -> * where
-- >    Tree :: TreeAST Tree
-- > $(deriveRefRep ''TreeAST (postfix "R"))
--
-- The following data type is generated (the reference constructor comes
-- first):
--
-- > data TreeR = RefTree Path | LeafR Int | BinR TreeR TreeR
-- > instance HasRef TreeAST
--
-- The first argument is the name of the index GADT; it must have at least
-- one type parameter, the last of which (presumably the family index) is
-- dropped before building the instance head. If the reified name has no
-- type parameters at all, the splice fails with a descriptive message
-- instead of an opaque @Prelude.init@ error.
deriveRefRep :: Name -> (Name -> Name) -> Q [Dec]
deriveRefRep n namef = do
  info <- reify n
  ps <- case extractParameters info of
    []     -> fail $ "deriveRefRep: " ++ show n
                  ++ " is not a type constructor with an index parameter"
    params -> return (init params)
  -- Only the (unqualified) constructor names are needed further on; the
  -- parameter lists returned alongside them are discarded.
  let ns = [remakeName cn | (cn, _) <- extractConstructorNames ps info]
  d <- deriveDatas n namef ps ns
  r <- deriveHasRef n namef ps ns
  return (d ++ r)

-- | Build a fresh, unqualified name by putting the given string in front
--   of the base of a name.
prefix :: String -> Name -> Name
prefix pref = mkName . (pref ++) . nameBase

-- | Build a fresh, unqualified name by appending the given string to the
--   base of a name.
postfix :: String -> Name -> Name
postfix post = mkName . (++ post) . nameBase

-- | Forget the field labels of a record constructor, yielding the
-- equivalent positional constructor. Any other constructor is returned
-- untouched.
stripRecordNames :: Con -> Con
stripRecordNames con = case con of
  RecC name fields -> NormalC name [(s, t) | (_, s, t) <- fields]
  _                -> con

-- | Flatten a left-nested type application into its head followed by the
--   arguments in order, e.g. @T a b@ becomes @[T, a, b]@. A type that is
--   not an application yields a singleton list.
unApp :: Type -> [Type]
unApp = go []
  where
    -- Walk down the spine, pushing arguments onto the accumulator so they
    -- end up in left-to-right order.
    go acc (AppT f a) = go (a : acc) f
    go acc t          = t : acc

-- | Inspect the reified index GADT and collect the names of its
-- constructors (which double as the names of the datatypes in the
-- family). Each name is paired with those of the given parameters that
-- occur in the equality constraints of the constructor. Anything that is
-- not a data declaration yields the empty list.
extractConstructorNames :: [Name] -> Info -> [(Name, [Name])]
extractConstructorNames ps info = case info of
  TyConI (DataD _ _ _ cs _) -> concatMap conNames cs
  _                         -> []
  where
    -- Names contributed by a single constructor; only nullary ordinary
    -- constructors, records and infix constructors are recognised.
    conNames :: Con -> [(Name, [Name])]
    conNames con = case con of
      ForallC _ eqs inner -> [ (cn, concatMap eqParams eqs) | (cn, _) <- conNames inner ]
      InfixC _ cn _       -> [(cn, [])]
      RecC cn _           -> [(cn, [])]
      NormalC cn []       -> [(cn, [])]
      _                   -> []

    -- Family parameters mentioned on either side of an equality constraint.
    eqParams :: Pred -> [Name]
    eqParams (EqualP lhs rhs) = [ v | v <- spineVars lhs ++ spineVars rhs, v `elem` ps ]
    eqParams _                = []

    -- Variables applied along the spine of a type, left to right.
    spineVars :: Type -> [Name]
    spineVars ty = case ty of
      AppT x (VarT v) -> spineVars x ++ [v]
      VarT v          -> [v]
      _               -> []

-- | Inspect the reified index GADT (or type synonym) and return the names
-- of its type parameters in declaration order. Any other kind of 'Info'
-- yields the empty list.
extractParameters :: Info -> [Name]
extractParameters info = case info of
  TyConI (DataD _ _ bndrs _ _) -> binderNames bndrs
  TyConI (TySynD _ bndrs _)    -> binderNames bndrs
  _                            -> []
  where
    binderNames = concatMap extractFromBndr

-- | The name bound by a type variable binder, as a singleton list; a kind
--   annotation, if present, is ignored.
extractFromBndr :: TyVarBndr -> [Name]
extractFromBndr bndr = case bndr of
  PlainTV v    -> [v]
  KindedTV v _ -> [v]

-- | Generate one data declaration (via 'deriveData') for every member of
--   the family, passing each member its position in the list.
deriveDatas :: Name -> (Name -> Name) -> [Name] -> [Name] -> Q [Dec]
deriveDatas s namef ps ns =
  mapM (uncurry (deriveData s namef ps ns)) (zip [0 ..] ns)

-- | Generate the data type with references for a single family member.
--
--   The member @n@ is reified; its constructors are renamed with @namef@
--   (see 'mkCon') and an extra constructor @Ref\<n\>@ holding a 'Path' is
--   placed in front of them. The new type is named @namef n@ and keeps the
--   type variables of the original declaration.
--
--   The family name, the family parameters and the member index are
--   accepted for uniformity with the other derivers but are not used here.
--
--   Fails with a descriptive message when @n@ does not reify to a @data@
--   declaration (previously this was an incomplete pattern match).
deriveData :: Name -> (Name -> Name) -> [Name] -> [Name] -> Int -> Name -> Q Dec
deriveData _s namef _ps ns _ix n = do
  info <- reify n
  cons <- case info of
    TyConI (DataD _ _ _ cs _) -> mapM (mkCon n namef ns) cs
    _ -> fail $ "deriveData: " ++ show n
              ++ " is not a data type; cannot derive its reference representation"
  ref <- normalC (prefix "Ref" n) [return (NotStrict, ConT ''Path)]
  -- The reference constructor goes first, followed by the renamed originals.
  dataD (cxt []) (namef n) (typeVariables info) (map return (ref : cons)) []

-- | Rename a constructor with @namef@ and redirect its fields to the
--   renamed family types.
--
--   A field whose type is a family member (compared by unqualified name
--   against @ns@) is replaced by the renamed type; for a type application
--   only the last argument is rewritten. All other types are kept as they
--   are.
--
--   Record, infix and quantified constructors are handled in the same way
--   as 'fromCon' and 'toCon' treat them, so that the generated conversion
--   clauses match the generated constructors: record labels are dropped,
--   an infix constructor becomes a two-field prefix constructor, and the
--   quantifier and context of a quantified constructor are kept around the
--   rewritten constructor. Note that for an infix constructor @namef@ is
--   applied to the operator name, so it has to produce a usable
--   constructor name for it.
--
--   The first argument (the type being processed) is not used.
mkCon :: Name -> (Name -> Name) -> [Name] -> Con -> Q Con
mkCon _ namef ns (NormalC a b) = normalC (namef a) (map conv b)
  where
    conv :: (Strict, Type) -> Q (Strict, Type)
    conv (s, ty) = fmap ((,) s) (retarget ty)

    retarget :: Type -> Q Type
    retarget (ConT n) | remakeName n `elem` ns = return (ConT (namef n))
    retarget (AppT f a') = fmap (AppT f) (retarget a')
    retarget x           = return x
mkCon t namef ns r@(RecC _ _) =
  mkCon t namef ns (stripRecordNames r)
mkCon t namef ns (InfixC t1 cn t2) =
  mkCon t namef ns (NormalC cn [t1, t2])
mkCon t namef ns (ForallC bs ctx c) =
  fmap (ForallC bs ctx) (mkCon t namef ns c)

-- | The type variable binders of a reified @data@ or @newtype@
--   declaration; the empty list for anything else.
typeVariables :: Info -> [TyVarBndr]
typeVariables info = case info of
  TyConI (DataD    _ _ bndrs _ _) -> bndrs
  TyConI (NewtypeD _ _ bndrs _ _) -> bndrs
  _                               -> []

-- | Generate the 'HasRef' instance for the family @s@ applied to the
--   parameters @ps@. The instance contains one 'RefRep' type instance per
--   family member (mapping it to its renamed type) and the definitions of
--   'toRef' and 'fromRef', whose clauses come from 'mkTo' and 'mkFrom'.
deriveHasRef :: Name -> (Name -> Name) -> [Name] -> [Name] -> Q [Dec]
deriveHasRef s namef ps ns = do
  let total       = length ns
      indexed     = zip [0 ..] ns
      refRepInsts = map (\n -> tySynInstD ''RefRep [conT s, conT n] (conT (namef n))) ns
      familyTy    = foldl appT (conT s) (map varT ps)
  fromClauses <- concat <$> mapM (uncurry (mkFrom ns namef total)) indexed
  toClauses   <- concat <$> mapM (uncurry (mkTo   ns namef total)) indexed
  inst <- instanceD (cxt []) (conT ''HasRef `appT` familyTy)
            (refRepInsts ++ [funD 'toRef toClauses, funD 'fromRef fromClauses])
  return [inst]

-- | Build the clauses (used for 'fromRef' by 'deriveHasRef') of the family
--   member @n@, which sits at position @i@ among @m@ members.
--
--   The first clause turns the member's reference constructor into a
--   'Ref'. For a data type one clause per constructor follows (see
--   'fromCon'); for a type synonym a single clause wraps the value in 'K'.
--   Any other declaration results in an error once the clauses are used.
mkFrom :: [Name] -> (Name -> Name) -> Int -> Int -> Name -> Q [Q Clause]
mkFrom ns namef m i n = do
  info <- reify n
  let dn = remakeName n
      x  = field 0
      -- Inject a constructor representation at this member's position in
      -- the family sum.
      inject e = conE 'HIn `appE` (conE 'InR `appE` lrE m i (conE 'Tag `appE` e))
      refClause =
        clause [conP dn [], conP (prefix "Ref" dn) [varP x]]
               (normalB (conE 'HIn `appE` (conE 'Ref `appE` varE x))) []
      conClauses = case info of
        TyConI (DataD _ _ _ cs _) ->
          zipWith (fromCon inject ns dn namef (length cs)) [0 ..] cs
        TyConI (TySynD _ _ _) ->
          [clause [conP dn [], varP x] (normalB (inject (conE 'K `appE` varE x))) []]
        _ -> error "unknown construct"
  return (refClause : conClauses)

-- | Build the clauses (used for 'toRef' by 'deriveHasRef') of the family
--   member @n@, which sits at position @i@ among @m@ members.
--
--   The first clause turns a 'Ref' into the member's reference
--   constructor. For a data type one clause per constructor follows (see
--   'toCon'); for a type synonym a single clause unwraps a 'K'. Any other
--   declaration results in an error once the clauses are used.
mkTo :: [Name] -> (Name -> Name) -> Int -> Int -> Name -> Q [Q Clause]
mkTo ns namef m i n = do
  info <- reify n
  let dn = remakeName n
      x  = field 0
      -- Match a constructor representation at this member's position in
      -- the family sum.
      select p = conP 'HIn [conP 'InR [lrP m i (conP 'Tag [p])]]
      refClause =
        clause [conP dn [], conP 'HIn [conP 'Ref [varP x]]]
               (normalB (conE (prefix "Ref" dn) `appE` varE x)) []
      conClauses = case info of
        TyConI (DataD _ _ _ cs _) ->
          zipWith (toCon select ns dn namef (length cs)) [0 ..] cs
        TyConI (TySynD _ _ _) ->
          [clause [conP dn [], select (conP 'K [varP x])] (normalB (varE x)) []]
        _ -> error "unknown construct"
  return (refClause : conClauses)


-- | Build the clause that converts one (renamed) constructor into its
--   generic representation. The constructor is number @i@ out of @m@; the
--   result is placed in the constructor sum with 'lrE' and then passed to
--   @wrap@. A constructor without fields is represented by 'U', otherwise
--   the converted fields (see 'fromField') are combined with '(:*:)'.
--   Record, infix and quantified constructors are reduced to the ordinary
--   case first.
fromCon :: (Q Exp -> Q Exp) -> [Name] -> Name -> (Name -> Name) -> Int -> Int -> Con -> Q Clause
fromCon wrap ns n namef m i con = case con of
  NormalC cn [] ->
    clause
      [conP n [], conP (namef cn) []]
      (normalB (wrap (lrE m i (conE 'C `appE` conE 'U))))
      []
  NormalC cn fs ->
    let vars   = [varP (field k) | k <- [0 .. length fs - 1]]
        fields = zipWith (fromField ns) [0 ..] (map snd fs)
        pair x y = conE '(:*:) `appE` x `appE` y
    in clause
         [conP n [], conP (namef cn) vars]
         (normalB (wrap (lrE m i (conE 'C `appE` foldr1 pair fields))))
         []
  RecC _ _          -> fromCon wrap ns n namef m i (stripRecordNames con)
  InfixC t1 cn t2   -> fromCon wrap ns n namef m i (NormalC cn [t1, t2])
  ForallC _ _ inner -> fromCon wrap ns n namef m i inner

-- | Build the clause that converts the generic representation of one
--   constructor back into the (renamed) constructor. The constructor is
--   number @i@ out of @m@; the pattern is placed in the constructor sum
--   with 'lrP' and then passed to @wrap@. A constructor without fields is
--   matched by 'U', otherwise the fields are matched with '(:*:)' and
--   converted back with 'toField'. Record, infix and quantified
--   constructors are reduced to the ordinary case first.
toCon :: (Q Pat -> Q Pat) -> [Name] -> Name -> (Name -> Name) -> Int -> Int -> Con -> Q Clause
toCon wrap ns n namef m i con = case con of
  NormalC cn [] ->
    clause
      [conP n [], wrap (lrP m i (conP 'C [conP 'U []]))]
      (normalB (conE (namef cn)))
      []
  NormalC cn fs ->
    let vars = [varP (field k) | k <- [0 .. length fs - 1]]
        args = zipWith (toField ns) [0 ..] (map snd fs)
        pair x y = conP '(:*:) [x, y]
    in clause
         [conP n [], wrap (lrP m i (conP 'C [foldr1 pair vars]))]
         (normalB (foldl appE (conE (namef cn)) args))
         []
  RecC _ _          -> toCon wrap ns n namef m i (stripRecordNames con)
  InfixC t1 cn t2   -> toCon wrap ns n namef m i (NormalC cn [t1, t2])
  ForallC _ _ inner -> toCon wrap ns n namef m i inner

-- | Expression converting the field bound to variable number @nr@ (see
--   'field') of the given type into its representation: the conversion
--   function chosen by 'fromFieldFun' applied to that variable.
fromField :: [Name] -> Int -> Type -> Q Exp
fromField ns nr t = fromFieldFun ns t `appE` varE (field nr)

-- | Choose the function that converts a field of the given type into its
--   representation:
--
--   * a type whose head is a family member (compared by unqualified name)
--     is a recursive position, converted with 'fromRef' and wrapped in 'I';
--
--   * any other type application is mapped over with the conversion for
--     its last argument and wrapped in 'D';
--
--   * everything else is a constant, wrapped in 'K'.
fromFieldFun :: [Name] -> Type -> Q Exp
fromFieldFun ns t
  | ConT n : _ <- unApp t, remakeName n `elem` ns
              = [| I . fromRef $(conE $ remakeName n) |]
  | AppT _ a <- t
              = [| D . fmap $(fromFieldFun ns a) |]
  | otherwise = [| K |]

-- | Expression converting the representation bound to variable number
--   @nr@ (see 'field') back into a field of the given type: the conversion
--   function chosen by 'toFieldFun' applied to that variable.
toField :: [Name] -> Int -> Type -> Q Exp
toField ns nr t = toFieldFun ns t `appE` varE (field nr)

-- | Choose the function that converts the representation of a field of
--   the given type back into the field:
--
--   * a type whose head is a family member (compared by unqualified name)
--     is a recursive position, unwrapped with 'unI' and converted with
--     'toRef';
--
--   * any other type application is unwrapped with 'unD' and mapped over
--     with the conversion for its last argument;
--
--   * everything else is a constant, unwrapped with 'unK'.
toFieldFun :: [Name] -> Type -> Q Exp
toFieldFun ns t
  | ConT n : _ <- unApp t, remakeName n `elem` ns
              = [| toRef $(conE $ remakeName n) . unI |]
  | AppT _ a <- t
              = [| fmap $(toFieldFun ns a) . unD |]
  | otherwise = [| unK |]

-- | Name of the variable used for field number @n@ in generated clauses:
--   the letter @f@ followed by the number.
field :: Int -> Name
field = mkName . ('f' :) . show

-- | Wrap a pattern so that it matches alternative number @i@ of a
--   right-nested sum with @m@ alternatives: one 'R' per alternative
--   skipped, then an 'L' unless the alternative is the only one left.
lrP :: Int -> Int -> (Q Pat -> Q Pat)
lrP m i p
  | i == 0 && m == 1 = p
  | i == 0           = conP 'L [p]
  | otherwise        = conP 'R [lrP (m - 1) (i - 1) p]

-- | Wrap an expression so that it becomes alternative number @i@ of a
--   right-nested sum with @m@ alternatives: one 'R' per alternative
--   skipped, then an 'L' unless the alternative is the only one left.
lrE :: Int -> Int -> (Q Exp -> Q Exp)
lrE m i e
  | i == 0 && m == 1 = e
  | i == 0           = conE 'L `appE` e
  | otherwise        = conE 'R `appE` lrE (m - 1) (i - 1) e

-- | Rebuild a name from its base only, discarding any module qualifier.
--
-- Should we, under certain circumstances, maintain the module name?
remakeName :: Name -> Name
remakeName = mkName . nameBase