{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE CPP #-}

-- {-
module Test.AgataTH (
      agatath
    , derive, deriveall
    , DerivOption(..), (<++>)
    , echoAgata
    , module Test.Agata
    , module Test.QuickCheck
    ) where
-- }-
-- module Test.AgataTH where

import Language.Haskell.TH.Syntax hiding (lift)
import qualified Language.Haskell.TH.Syntax as TH (lift)
import Language.Haskell.TH
import Control.Monad

import Test.Agata
import Test.QuickCheck(Arbitrary(..))

import Data.List(nub, union)
import Data.Maybe(fromMaybe)
import qualified Data.Map as Map
import qualified Data.Set as Set

import Control.Monad.State.Lazy


---------------------------------------------------------------------
-- Some day this file might be tidied up into a presentable state...
--   
-- | A request to generate instances for one or more types, together with
-- the options that modify the generated code.
data Derivation = Derivation {
    derivNames :: [Name]                  -- ^ type constructors to derive instances for
  , derivOptions :: Set.Set DerivOption   -- ^ options attached with '<++>'
  }

-- | Options that change what 'agatath' generates.
data DerivOption = 
    Inline Name   -- ^ build this constructor with 'inline' instead of field by field
  | NoArbitrary   -- ^ do not emit an 'Arbitrary' instance
    deriving (Show,Eq,Ord)

-- | Request instances for every type name in the list, with no options set.
-- Options can be added afterwards with '<++>'.
deriveall :: [Name] -> Derivation
deriveall names = Derivation { derivNames = names, derivOptions = Set.empty }

-- | Shorthand for 'deriveall' on a single type name.
derive :: Name -> Derivation
derive = deriveall . (: [])





-- | Attach an option to a derivation request. Options are held in a set,
-- so attaching the same option twice has no further effect.
(<++>) :: Derivation -> DerivOption -> Derivation
deriv <++> opt = deriv { derivOptions = Set.insert opt (derivOptions deriv) }


-- | Debugging helper: run 'agatath' on a single type and, instead of
-- splicing the generated instances, produce one top-level binding named by
-- the first argument whose value is the pretty-printed source of those
-- instances (a string literal).
echoAgata :: String -> Name -> Q [Dec]
echoAgata s n = do
  decs <- agatath (derive n)
  let text = dump decs
  return [FunD (mkName s) [Clause [] (NormalB (LitE (StringL text))) []]]

-- | Generate instances for every type named in the 'Derivation'.
--
-- For each type name this emits a 'Buildable' instance (defining 'improve',
-- 'build' and 'dimension') and, unless 'NoArbitrary' is among the options,
-- an 'Arbitrary' instance whose 'arbitrary' is 'agata'. Both instances
-- require every type parameter of the target type to be 'Buildable'.
--
-- NOTE(review): the result of 'reify' is matched against 'TyConI', so each
-- name is expected to be a type constructor; anything else fails in 'Q'.
agatath :: Derivation -> Q [Dec]
agatath der@(Derivation ts ss) = fmap concat $ mapM deriveAgata ts where
  -- True when the given option was attached with '<++>'.
  isSet o = o `Set.member` ss
  -- All instance declarations for the single type constructor n.
  deriveAgata n = do
    i@(TyConI d)  <-  reify n
  
    -- One fresh type variable per type parameter: nns for the instance
    -- heads, nns1 for the quantified signature given to retag below.
    nns <- replicateM (length $ dParams d) (newName "a")
    nns1 <- replicateM (length $ dParams d) (newName "b") -- >>= mapM unVarBndr
    let vs = map VarT nns
    -- Component types of n (see expand), re-nested as pairs.
    expanded <- fmap reTuple $ expand n nns1

    -- Quote a template instance for the placeholder T1 and pattern-match it
    -- to recover the TH representations of Buildable, improve, build,
    -- dimension, retag and Dimension instead of building them by hand.
    m@[InstanceD [] (AppT (ConT cBuildable_) _) [ValD (VarP improve_) _ _,ValD (VarP build_) _ _,ValD (VarP dimension_) (NormalB (SigE (AppE rerelate_ _) (AppT tDimension_ _))) []]] <-
       [d| instance Buildable T1 where
             improve = undefined
             build = undefined
             dimension = retag dimension :: Dimension T1
       |]

    -- One improve clause per constructor, and the single build body.
    impbody <- mapM impClause (dConsts d)
    buildbody <- fmap NormalB $ bldClauses (dConsts d) -- mapM (bldClause t) (dConsts d) >>= return . NormalB . ListE

    -- Recursion markers (Mut/Auto) of all fields of all constructors.
    allTypesT_t <- fmap (nub . concat) $ mapM (recs n . cFields) (dConsts d)
  
  
    let 
      -- Recursive when some field type (transitively) mentions n.
      isRecursive = Mut `elem` allTypesT_t
      -- The expression: dimension + 1
      dimplus = InfixE (Just $ VarE dimension_)  (VarE $ mkName "+") (Just (LitE (IntegerL 1)))
      -- forall b1 .. bk. Dimension expanded -> Dimension (n applied to the bi)
      dimtyp = ForallT (map varBndr nns1) [] $ AppT (AppT ArrowT (AppT tDimension_ expanded)) (AppT tDimension_ (getType n nns1))
      -- dimension of the expanded component type (plus one if recursive),
      -- retagged to the target type through the signature above.
      dimbody = NormalB $ AppE (SigE rerelate_ dimtyp) (if isRecursive then dimplus else VarE dimension_)

    -- Instance context: every type parameter must be Buildable.
    let preqs = allInClass cBuildable_ vs

    arb <- arbInstance preqs vs
  
    return $ [
      InstanceD preqs (AppT (ConT cBuildable_) (rt vs n)) 
        [FunD improve_ impbody
        , ValD (VarP build_) buildbody []
        , ValD (VarP dimension_) dimbody []
        ]] ++ if isSet NoArbitrary then [] else [arb]
  

    where
      -- Apply type constructor n to the variables, in reverse list order:
      -- rt [a, b] n is (n b a).
      rt :: [Type] -> Name -> Type
      rt [] n = ConT n
      rt (v:vs) n = AppT (rt vs n) v
      -- n fresh variables, both as patterns and as expressions.
      genPE n = do
        ids <- replicateM n (newName "x")
        return (map varP ids, map varE ids)
      
      -- Concatenate the build expressions of all constructors with (++).
      -- NOTE(review): there is no case for an empty constructor list, so a
      -- type without constructors fails here.
      bldClauses [c]     = bldClause c
      bldClauses (c:cs)  = [| $(bldClause c) ++ $(bldClauses cs) |]
     
      -- Build expression for one constructor. With the Inline option for
      -- this constructor it is inline applied to the constructor; otherwise
      -- the constructor is combined via ($>) with a chain, joined by (.>), of
      -- automutrec for fields mentioning n and autorec for the others.
      bldClause :: Con -> Q Exp
      bldClause c 
        | isSet $ Inline $ cName c =
          [| inline $(conE $ cName c) |]
        | otherwise               = do
          let ts   = cFields c
              name = cName c
              f [] = [| id |]
              f (Auto:vars) = [| autorec .> ($(f vars)) |]
              f (Mut:vars) = [| automutrec .> ($(f vars)) |]
          [| $(conE name) $> $(recs n ts >>= f) |]
      
      -- Clause of improve for one constructor: match all of its fields and
      -- call rebuild with the constructor and rb applied to each field in
      -- order, chained with (*>) and ending in (return . id).
      impClause c = do
        let fields = cFields c
        let name   = cName c
        let idExp  = cId c
        (pats,vars) <- genPE (length fields)
        let f []       = [| return . id |]
            f (v:vars) = [| rb $v *> $(f vars) |]
        clause [conP name pats]                                 -- (A x1 x2)
               (normalB [| rebuild $(idExp) $(f vars) |]) []  -- "A "++show x1++" "++show x2

      -- Arbitrary instance (arbitrary = agata): taken from a quoted template
      -- for T1, then given the preqs context and the real instance head.
      arbInstance preqs vs = do
        m@[InstanceD [] (AppT cArbitrary_ _) body_] <-
          [d| instance Arbitrary T1 where
                arbitrary = agata
          |]
        return $ InstanceD preqs (AppT cArbitrary_ (rt vs n)) body_

-- | Recursion marker for a constructor field: 'Mut' when the field type
-- (transitively) mentions the type being derived, 'Auto' otherwise.
data Recu = Mut | Auto deriving (Eq,Show)
-- | Classify each field type, in order: 'Mut' when the type constructor
-- @n@ is reachable from it (see 'allTypesT'), 'Auto' otherwise.
recs :: Name -> [Type] -> Q [Recu]
recs n = mapM classify
  where
    classify t = do
      reachable <- allTypesT t
      return (if Set.member n reachable then Mut else Auto)


-- | Names of every type constructor reachable from a type. Each constructor
-- found is reified and its field types are walked in turn. A name is only
-- visited the first time it is seen, so recursive types terminate.
allTypesT :: Type -> Q (Set.Set Name)
allTypesT ty = getCollected (walk ty)
  where
    visit name = do
      info <- lift (reify name)
      mapM_ walk (iTypes info)
    walk :: Type -> Collecting Name ()
    walk (ConT name)  = collectIf name (visit name)
    walk (AppT t1 t2) = walk t1 >> walk t2
    walk (VarT _)     = return ()
    walk (TupleT _)   = return ()
    walk ArrowT       = return ()
    walk ListT        = return ()



-- | Is the type constructor reachable from the type (see 'allTypesT')?
contains :: Type -> Name -> Q Bool
contains t n = do
  reachable <- allTypesT t
  return (n `Set.member` reachable)

-- | Split a type application into its head and its arguments, in
-- application order: @flat (f a b)@ is @(f, [a, b])@.
flat :: Type -> (Type,[Type])
flat = go []
  where
    go args (AppT fun arg) = go (arg : args) fun
    go args headTy         = (headTy, args)


-- | Apply type constructor @n@ to type variables. The variables are applied
-- in reverse list order: @getType t [a, b]@ is @t b a@.
getType :: Name -> [Name] -> Type
getType n = foldr (\v acc -> AppT acc (VarT v)) (ConT n)



-- | Component types of @n0@ applied to the variables @ns@ (applied as
-- 'getType' does).
--
-- Any type from which @n0@ is reachable (see 'contains') is unfolded.
-- Tuples are unfolded component-wise. A type constructor application is
-- replaced by the tuple of the field types of that constructor (deduplicated),
-- with its parameters substituted by the actual arguments, and unfolded
-- again. An application that is already being unfolded becomes @n0@ itself,
-- which stops the recursion. Types from which @n0@ is not reachable are
-- kept unchanged. Finally the result is flattened into its tuple components,
-- deduplicated, stripped of every component that syntactically contains
-- @n0@, and rebuilt as a tuple.
--
-- NOTE(review): only tuple and type-constructor heads are handled. A type
-- such as @[n0]@ or @a -> n0@ reaches the @case@ below with no matching
-- alternative and fails.
expand :: Name -> [Name] -> Q Type
expand n0 ns = fmap simplify $ applic [] (getType n0 ns) where
  -- nts: the constructor applications currently being unfolded.
  applic :: [(Type,[Type])] -> Type -> Q Type
  applic nts t0 = do
    b <- t0 `contains` n0
    if not b then return t0 else case flat t0 of
      (TupleT _,ts) -> fmap toTuple $ mapM (applic nts) ts
      (ConT n, ts)  ->
        if (ConT n,ts) `elem` nts then return (ConT n0) else do
          let rec = applic $ (ConT n,ts) : nts
          i <- reify n
          -- All field types of n, with its parameters replaced by ts.
          let fs = toTuple $ nub $ iTypes i
          rec $ subst (zip (iParams i) ts) fs

    where
            -- Replace type variables according to nmap.
            subst nmap t1 = case t1 of
              AppT t2 t3  -> AppT (subst nmap t2) (subst nmap t3)
              VarT n1	-> fromMaybe t1 $ lookup n1 nmap
              _		-> t1
  -- Flatten tuples, deduplicate, drop components containing n0, re-tuple.
  simplify :: Type -> Type
  simplify = toTuple . filter filt . nub . toList

  -- False for a type whose application spine contains ConT n0.
  filt t = case t of
   ConT n -> n0/=n
   AppT t1 t2 -> filt t1 && filt t2
   _ -> True

-- | Flatten nested tuple types into their components. Any type whose head
-- is not a tuple constructor is returned as a singleton, and the unit type
-- (an unapplied tuple constructor) flattens to the empty list.
toList :: Type -> [Type]
toList t = case flat t of
  (TupleT _, parts) -> concatMap toList parts
  _                 -> [t]

-- | Build a tuple type from component types. A singleton list yields the
-- component itself, and the empty list yields the unit type @TupleT 0@.
-- Components are applied to the tuple constructor in reverse list order:
-- @toTuple [a, b]@ is @(b, a)@.
toTuple :: [Type] -> Type
toTuple [t] = t
toTuple ts  = foldr (flip AppT) (TupleT (length ts)) ts
  
-- | Re-encode a possibly nested tuple type as right-nested pairs. The
-- components found by 'toList', e.g. @[a, b, c]@, become @(a, (b, c))@.
-- A single component is returned unchanged, and no components give unit.
reTuple :: Type -> Type
reTuple ty = case toList ty of
  []    -> TupleT 0
  parts -> foldr1 pair parts
  where
    pair a b = AppT (AppT (TupleT 2) a) b


-- | Name of the declaration behind a reified type constructor.
iName :: Info -> Name
iName (TyConI dec) = dName dec
-- | Field types behind a reified type constructor (see 'dTypes'). A
-- primitive type constructor yields just itself; any other kind of 'Info'
-- is an error that shows the offending value.
iTypes :: Info -> [Type]
iTypes (TyConI dec)           = dTypes dec
iTypes (PrimTyConI name _ _)  = [ConT name]
iTypes other                  = error (show other)
-- | Type parameter names of a reified type constructor (see 'dParams').
-- Any other kind of 'Info' is reported with an explicit error that shows
-- the offending value, matching the fallback in 'iTypes', rather than an
-- anonymous pattern-match failure.
iParams :: Info -> [Name]
iParams i = case i of
  TyConI d -> dParams d
  _        -> error ("iParams: unsupported Info: " ++ show i)
  

-- | Name declared by a data, newtype or type synonym declaration. This
-- covers the same declaration forms as 'dTypes'.
dName :: Dec -> Name
dName d = case d of
  DataD _ n _ _ _    -> n
  NewtypeD _ n _ _ _ -> n
  TySynD n _ _       -> n
-- | All field types of a declaration: the fields of every constructor of a
-- data type, the field of a newtype, or the right-hand side of a type
-- synonym.
dTypes :: Dec -> [Type]
dTypes d = case d of
  DataD _ _ _ cs _   -> concatMap cFields cs
  NewtypeD _ _ _ c _ -> cFields c
  TySynD _ _ t       -> [t]
-- | Type parameter names of a data, newtype or type synonym declaration.
-- Synonyms are included because field types may name a synonym. 'expand'
-- reifies such a name and substitutes its parameters via 'iParams';
-- without this case that path fails with a pattern-match error.
dParams :: Dec -> [Name]
dParams d = case d of
  DataD _ _ ns _ _    -> map unVarBndr ns
  NewtypeD _ _ ns _ _ -> map unVarBndr ns
  TySynD _ ns _       -> map unVarBndr ns
-- | Constructors of a data or newtype declaration.
dConsts :: Dec -> [Con]
dConsts (DataD _ _ _ cons _)   = cons
dConsts (NewtypeD _ _ _ con _) = [con]

-- | Name of a data constructor. A constructor with an explicit forall or
-- context ('ForallC') is looked through to the constructor it wraps.
cName :: Con -> Name
cName (NormalC name _)   = name
cName (RecC name _)      = name
cName (InfixC _ name _)  = name
cName (ForallC _ _ con)  = cName con
-- | Expression referring to a data constructor, for use in generated code.
cId :: Con -> ExpQ
cId = conE . cName
-- | Field types of a data constructor, in declaration order.
--
-- Record constructors contribute the type of each field, and a constructor
-- with an explicit forall or context ('ForallC') is looked through. These
-- cases mirror 'cName'; without them, deriving for any record type fails
-- with a pattern-match error.
cFields :: Con -> [Type]
cFields c = case c of
  NormalC _ sts   -> map snd sts
  RecC _ vsts     -> [t | (_, _, t) <- vsts]
  InfixC st _ st' -> [snd st, snd st']
  ForallC _ _ c1  -> cFields c1




-- | Placeholder type used only in the instance templates quoted by
-- 'agatath'. Those templates are pattern-matched to recover TH names, and
-- the T1 instance head is replaced by the real target type.
data T1 = T1


-- | Render a TH syntax tree as Haskell source text.
dump :: Ppr a => a -> String
dump x = show (ppr x)


-- | A 'Q' computation that also carries the set of values collected so far.
-- 'allTypesT' uses it to visit each type constructor at most once.
type Collecting b a = StateT (Set.Set b) Q a
-- | Has this value already been collected?
collected :: (Ord b) => b -> Collecting b Bool
collected b = gets (Set.member b)

-- | Mark a value as collected.
collect :: (Ord b) => b -> Collecting b ()
collect = modify . Set.insert

-- | Run a collecting computation starting from an empty set and return the
-- final set. The result of the computation itself is discarded.
getCollected :: Collecting b a -> Q (Set.Set b)
getCollected action = execStateT action Set.empty

-- | @collectIf b act@ runs @act@ only the first time @b@ is seen. It marks
-- @b@ as collected before running @act@, so recursion back to @b@ stops.
collectIf :: Ord b => b -> Collecting b () -> Collecting b ()
collectIf b act = do
  seen <- collected b
  if seen then return () else collect b >> act



-- template-haskell 2.4 compatibility: from 2.4 on, type variable binders
-- are TyVarBndr values and a class constraint is built with ClassP; the
-- older branch below uses plain Names and applies the class to the type.
-- #if __GLASGOW_HASKELL__ >= 611
#if MIN_VERSION_template_haskell(2,4,0)
-- | Name bound by a type variable binder (any kind annotation is dropped).
unVarBndr :: TyVarBndr -> Name
unVarBndr (PlainTV n) = n
unVarBndr (KindedTV n _) = n

-- | Binder without a kind annotation for a type variable name.
varBndr :: Name -> TyVarBndr
varBndr n = (PlainTV n)

-- | One constraint of class n for each of the given types.
-- NOTE(review): ClassP is not available in newer template-haskell
-- releases; confirm the supported version range.
allInClass :: Name -> [Type] -> [Pred]
allInClass n vs = map (ClassP n) (map (:[]) vs)

#else
-- Older template-haskell: binders are plain Names, and a constraint is the
-- class constructor applied to the type.
unVarBndr = id
varBndr = id
allInClass n vs = map (AppT (ConT n)) vs 
#endif





-- DEBUG
-- | Debugging aid: expand the component types of a data type as 'agatath'
-- does, then throw an 'error' whose message is the pretty-printed result.
topApp :: Name -> Q [Dec]
topApp n = do
  TyConI (DataD _ _ params _ _) <- reify n
  fresh <- replicateM (length params) (newName "b")
  expandedTy <- expand n fresh
  error (dump expandedTy)
-- | Placeholder that ignores its argument and generates no declarations.
testDimVal :: Name -> Q [Dec]
testDimVal _ = return []