{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE StandaloneDeriving #-}
module THUtils where

import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import Language.Haskell.TH.ExpandSyns
import Control.Monad.Error
import Control.Applicative
import Data.Generics
import Control.Exception
    
    
-- Orphan Ord instances for Template Haskell syntax types. They are needed
-- because AppliedTyCon (declared below) derives Ord and has a [Type] field.
-- NOTE(review): newer template-haskell releases may already provide these
-- instances, in which case these declarations would clash - TODO confirm
-- against the template-haskell versions this module is meant to build with.
deriving instance Ord Type
#if __GLASGOW_HASKELL__ >= 611
-- Presumably these types can occur inside a Type, so the derived Ord Type
-- instance needs Ord for them as well. They are only derived for
-- GHC >= 6.11, presumably because older TH lacks them - TODO confirm.
deriving instance Ord Kind
deriving instance Ord Pred
deriving instance Ord TyVarBndr
#endif
    
-- name2constructors :: Name -> Q [Con]
-- name2constructors n = do
--   i <- reify n
--   case i of
--       TyConI d -> dec2constructors d
--       _ -> fail ("Not a type name: "++show n++"\nInfo for this name was: "++show i)

-- dec2constructors :: Dec -> Q [Con]  
-- dec2constructors (DataD _ _ _ cs _) = return cs
-- dec2constructors (NewtypeD _ _ _ c _) = return [c]
-- -- dec2constructors (TySynD _ _ t) = type2constructors t
-- dec2constructors x = fail ("Don't know how to extract constructors from this Dec: "
--                            ++ show x)
  
-- type2constructors :: Type -> Q [Con]
-- type2constructors (ForallT _ _ t) = type2constructors t
-- type2constructors (ConT n) = name2constructors n
-- type2constructors (AppT t _) = type2constructors t
-- type2constructors x = fail ("Don't know how to extract constructors from this Type: "
--                             ++ show x)

-- | Operator form of 'AppT' (type application). Declared @infixl 9@, so a
-- chain of uses applies a type to several arguments from left to right.
(@@) :: Type -> Type -> Type
(@@) = AppT

-- | Operator form of 'AppE' (expression application). Declared @infixl 9@,
-- so a chain of uses applies a function expression to several arguments
-- from left to right.
(@@@) :: Exp -> Exp -> Exp
(@@@) = AppE

infixl 9 @@
infixl 9 @@@


-- | A type viewed as a type-constructor name applied to a list of argument
-- types, i.e. @T a1 ... an@ represented as the head @T@ plus @[a1, ..., an]@.
data AppliedTyCon = AppliedTyCon
    { atcHead :: Name    -- ^ Name of the type constructor at the head.
    , atcArgs :: [Type]  -- ^ Argument types, in application order.
    }
    deriving (Eq, Ord, Show, Data, Typeable)
                           
-- | Rewrite Template Haskell's special type constructors into ordinary
-- 'ConT' heads everywhere inside the given value:
--
-- * 'ListT' becomes the list type constructor,
--
-- * @'TupleT' n@ becomes @'ConT' ('tupleTypeName' n)@,
--
-- * 'ArrowT' becomes the function-arrow type constructor.
--
-- Every other 'Type' node is left unchanged. After this pass, every
-- type-constructor head can be handled uniformly as a 'ConT'.
normaliseSpecialTyCons :: Data a => a -> a
normaliseSpecialTyCons = everywhere (mkT f)
    where
      f :: Type -> Type
      f ListT      = ConT ''[]
      f (TupleT n) = ConT (tupleTypeName n)
      f ArrowT     = ConT ''(->)
      f x          = x
                  

-- | Expand type synonyms in a 'Type', then view the result as a
-- type-constructor name applied to a (possibly empty) list of arguments.
--
-- 'normaliseSpecialTyCons' first turns 'ListT', 'TupleT' and 'ArrowT' into
-- ordinary 'ConT' heads, so no special cases are needed here. Any other
-- head (a type variable, a @forall@, ...) is reported via 'throwError' in
-- the result monad @m@.
toAppliedTyCon :: (MonadError String m) => Type -> Q (m AppliedTyCon)
toAppliedTyCon t = fmap (peel [] . normaliseSpecialTyCons) (expandSyns t)
    where
      -- Walk down the left spine of nested 'AppT's, collecting the
      -- arguments so they end up in application order.
      peel args ty = case ty of
          AppT f x -> peel (x : args) f
          ConT hd  -> return (AppliedTyCon hd args)
          _        -> throwError ("Expected applied type constructor, got: "
                                  ++ show (foldl AppT ty args))

-- | Rebuild a 'Type' from an 'AppliedTyCon' by applying its head to each
-- argument in turn. A list-constructor head is emitted as 'ListT'; every
-- other head becomes a plain 'ConT'.
fromAppliedTyCon :: AppliedTyCon -> Type
fromAppliedTyCon atc = foldl AppT headTy (atcArgs atc)
    where
      name   = atcHead atc
      headTy = if name == ''[] then ListT else ConT name


-- | Get the constructors of the @data@ or @newtype@ named by the
-- 'AppliedTyCon', with each declared type parameter substituted by the
-- corresponding argument type.
--
-- Calls 'fail' in 'Q' when the name does not refer to a @data@ or
-- @newtype@ declaration, or when the number of arguments differs from the
-- number of declared type parameters. The arity check is explicit rather
-- than an 'assert' because GHC removes assertions under
-- @-fignore-asserts@ (implied by @-O@). Without the check, 'zip' would
-- silently truncate and leave some parameters unsubstituted.
atc2constructors :: AppliedTyCon -> Q [Con]
atc2constructors (AppliedTyCon n args) = do
  i <- reify n
  (params,cs) <-
      case i of
        -- Note: Synonyms should already be expanded at this point by
        -- toAppliedTyCon, so only data and newtype declarations are accepted.
        TyConI (DataD _ _ ps cs0 _) -> return (ps,cs0)
        TyConI (NewtypeD _ _ ps c0 _) -> return (ps,[c0])

        _ -> fail ("Expected this name to refer to a data or newtype: "
                  ++show n
                  ++"\nBut info for this name was: "++show i)

  -- Pair every declared parameter with its argument. Reject partial or
  -- over-saturated applications instead of substituting only a prefix.
  substs <-
      if length params == length args
         then return (zip params args)
         else fail ("atc2constructors: " ++ show n ++ " has "
                    ++ show (length params) ++ " type parameter(s) but was given "
                    ++ show (length args) ++ " argument(s)")

  let doSubsts x = foldr substInCon x substs

  return (doSubsts <$> cs)

  

-- | Replace every 'Name' inside the given value with @'mkName' ('nameBase' n)@.
-- This drops module qualifiers and anything else 'nameBase' discards.
-- Distinct names that share a base name become equal afterwards, so the
-- result is presumably meant for display (as in 'pprintUnqual') rather
-- than for splicing back into code.
cutNames :: Data a => a -> a
cutNames = everywhere (mkT cutName)
    where
      cutName :: Name -> Name
      cutName = mkName . nameBase
                
-- | Pretty-print a Template Haskell value with every name reduced to its
-- unqualified base name (via 'cutNames'), i.e. without module qualifiers.
pprintUnqual :: (Data a, Ppr a) => a -> String
pprintUnqual = pprint . cutNames


-- CPP debugging macro. Used at top level, showQ(X) expands to a declaration
-- splice that runs the Q action X at compile time and prints its result with
-- print (so the result type needs a Show instance). The splice also declares
-- the fixed binding showQ_dummy______, so the macro can be used at most once
-- per module without a duplicate-definition error.
#define showQ(X)\
            $( (runIO . print =<< (X)) >> [d| showQ_dummy______ = ()|])
        

-- showQ( liftM2 (==) (ConE ''[]) [| [] |] )


-- | Pretty-print an 'AppliedTyCon' as the 'Type' built by 'fromAppliedTyCon'.
-- Going through 'fromAppliedTyCon' keeps printing consistent with the rest
-- of the module: a list-constructor head is printed from 'ListT' rather than
-- as a bare 'ConT' of the list name.
instance Ppr AppliedTyCon where
    ppr = ppr . fromAppliedTyCon


-- | Build a 'Match' (one case alternative) from a pattern and a body
-- expression, using an unguarded ('NormalB') body and no @where@ clause.
sMatch :: Pat -> Exp -> Match
sMatch p b = Match p (NormalB b) []

-- | Build a function 'Clause' from its argument patterns and a body
-- expression, using an unguarded ('NormalB') body and no @where@ clause.
sClause :: [Pat] -> Exp -> Clause
sClause ps b = Clause ps (NormalB b) []