{-# LANGUAGE DoAndIfThenElse #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE EmptyCase #-}
module Data.Parameterized.TH.GADT
  ( 
    
  structuralEquality
  , structuralTypeEquality
  , structuralTypeOrd
  , structuralTraversal
  , structuralShowsPrec
  , structuralHash
  , PolyEq(..)
    
  , DataD
  , lookupDataType'
  , asTypeCon
  , conPat
  , TypePat(..)
  , dataParamTypes
  , assocTypePats
  ) where
import Control.Monad
import Data.Hashable (hashWithSalt)
import Data.Maybe
import Data.Set (Set)
import qualified Data.Set as Set
import Language.Haskell.TH
import Language.Haskell.TH.Datatype
import Data.Parameterized.Classes
-- | Exported name for th-abstraction's 'DatatypeInfo' (the reified
-- description of a datatype).
type DataD = DatatypeInfo
-- | Reify the datatype with the given name. Equivalent to 'reifyDatatype'.
lookupDataType' :: Name -> Q DatatypeInfo
lookupDataType' tyName = reifyDatatype tyName
-- | Build a pattern matching constructor @con@ that binds each of its
-- fields to a fresh variable. The variable names start with @pre@.
conPat
  :: ConstructorInfo -- ^ constructor whose fields are bound
  -> String          -- ^ prefix for the generated variable names
  -> Q (Pat, [Name]) -- ^ the pattern, and the bound names in field order
conPat con pre = do
  let arity = length (constructorFields con)
  fieldVars <- newNames pre arity
  let pat = ConP (constructorName con) (map VarP fieldVars)
  pure (pat, fieldVars)
-- | The constructor of @con@ as a Template Haskell expression ('ConE').
conExpr :: ConstructorInfo -> Exp
conExpr con = ConE (constructorName con)
-- | A pattern over Template Haskell types, used to select the constructor
-- fields that receive user-supplied handling.
data TypePat
   = TypeApp TypePat TypePat -- ^ Match a type application @f x@: the first pattern matches @f@, the second @x@.
   | AnyType       -- ^ Match any type.
   | DataArg Int   -- ^ Match the datatype parameter at this (0-based) index.
   | ConType TypeQ -- ^ Match a type syntactically equal to the one produced by the 'TypeQ'.

-- | Decide whether a type matches a 'TypePat'.
--
-- The first argument lists the datatype's parameter types; a 'DataArg'
-- index refers into that list. An index outside the list is reported with
-- 'error' ("Illegal type pattern index").
matchTypePat :: [Type] -> TypePat -> Type -> Q Bool
matchTypePat d (TypeApp p q) (AppT x y) = do
  r <- matchTypePat d p x
  if r then matchTypePat d q y else return False
matchTypePat _ AnyType _ = return True
matchTypePat tps (DataArg i) tp
  -- Valid indices are 0 .. length tps - 1; @i == length tps@ must be
  -- rejected here too, otherwise '!!' fails later with an opaque error.
  | i < 0 || i >= length tps = error $ "Illegal type pattern index " ++ show i
  | otherwise = return $ stripSigT (tps !! i) == tp
  where
    -- A parameter may carry a kind annotation (@SigT t k@); compare
    -- against the bare type underneath it.
    stripSigT (SigT t _) = t
    stripSigT t          = t
matchTypePat _ (ConType tpq) tp = do
  tp' <- tpq
  return (tp' == tp)
matchTypePat _ _ _ = return False
-- | The parameters of a reified datatype, as recorded in 'datatypeVars'.
dataParamTypes :: DatatypeInfo -> [Type]
dataParamTypes info = datatypeVars info
-- | Return the value paired with the first pattern (in list order) that
-- matches the given type, or 'Nothing' if none does. The first argument
-- is the datatype's parameter list, used to resolve 'DataArg' patterns.
assocTypePats :: [Type] -> [(TypePat,v)] -> Type -> Q (Maybe v)
assocTypePats dTypes assocs tp = go assocs
  where
    go [] = return Nothing
    go ((pat, val) : rest) = do
      hit <- matchTypePat dTypes pat tp
      if hit then return (Just val) else go rest
-- | The set of free type variables occurring in the argument.
typeVars :: TypeSubstitution a => a -> Set Name
typeVars x = Set.fromList (freeVariables x)
-- | @structuralEquality tp pats@ generates a two-argument predicate that is
-- 'True' exactly when the function generated by 'structuralTypeEquality'
-- for the same arguments returns 'Just'.
structuralEquality :: TypeQ -> [(TypePat,ExpQ)] -> ExpQ
structuralEquality tpq pats = do
  let testEq = structuralTypeEquality tpq pats
  [| \x y -> isJust ($(testEq) x y) |]
-- | Generate code that compares the variables @x@ and @y@ with '==' and
-- continues with @r@ when they are equal, producing 'Nothing' otherwise.
joinEqMaybe :: Name -> Name -> ExpQ -> ExpQ
joinEqMaybe x y r =
  let lhs = varE x
      rhs = varE y
  in [| if $(lhs) == $(rhs) then $(r) else Nothing |]
-- | Generate code that applies the equality test @f@ to @x@ and @y@. On
-- @Just Refl@ it continues with @r@ (the 'Refl' match is in scope there);
-- on 'Nothing' it produces 'Nothing'.
joinTestEquality :: ExpQ -> Name -> Name -> ExpQ -> ExpQ
joinTestEquality f x y r =
  [| case $(f) $(varE x) $(varE y) of
       Just Refl -> $(r)
       Nothing -> Nothing
   |]
-- | Generate the body of an equality test for two values built with the
-- same constructor, comparing corresponding fields from left to right.
--
-- For each field:
--
-- * if its type matches a user pattern, the paired test is applied and must
--   return @Just Refl@ to continue;
-- * otherwise, if all of its free type variables are bound, the fields are
--   compared with '==';
-- * otherwise generation fails in 'Q' with "Unsupported argument type".
--
-- When every field agrees the generated code returns @Just Refl@.
matchEqArguments :: [Type]
                    -- ^ Types of the datatype's parameters (indexed by 'DataArg').
                 -> [(TypePat,ExpQ)]
                    -- ^ Patterns paired with the equality test for fields whose type matches.
                 -> Name
                    -- ^ Constructor name, used only in error messages.
                 -> Set Name
                    -- ^ Bound type variables: a field whose free type variables all lie here is compared with '=='.
                 -> [Type]
                    -- ^ Types of the remaining fields.
                 -> [Name]
                    -- ^ Variables bound to the remaining fields of the first value.
                 -> [Name]
                    -- ^ Variables bound to the remaining fields of the second value.
                 -> ExpQ
matchEqArguments dTypes pats cnm bnd (tp:tpl) (x:xl) (y:yl) = do
  doesMatch <- assocTypePats dTypes pats tp
  case doesMatch of
    Just q -> do
      -- The generated code only continues after @q x y@ matches
      -- @Just Refl@, so for a field of type @f v@ the variable @v@ is
      -- treated as bound for the remaining fields.
      let bnd' =
            case tp of
              AppT _ (VarT nm) -> Set.insert nm bnd
              _ -> bnd
      joinTestEquality q x y (matchEqArguments dTypes pats cnm bnd' tpl xl yl)
    Nothing | typeVars tp `Set.isSubsetOf` bnd ->
      joinEqMaybe x y (matchEqArguments dTypes pats cnm bnd tpl xl yl)
    Nothing ->
      -- Pretty-print the type (as 'matchOrdArguments' does) rather than
      -- dumping the raw Template Haskell AST.
      fail $ "Unsupported argument type " ++ show (ppr tp)
          ++ " in " ++ show (ppr cnm) ++ "."
matchEqArguments _ _ _ _ [] [] [] = [| Just Refl |]
matchEqArguments _ _ _ _ [] _  _  = error "Unexpected end of types."
matchEqArguments _ _ _ _ _  [] _  = error "Unexpected end of names."
matchEqArguments _ _ _ _ _  _  [] = error "Unexpected end of names."
-- | Generate a @case@ on @yQ@ for an @x@ already matched against
-- constructor @con@ (with its fields bound to @xv@).
--
-- The first alternative matches @y@ against the same constructor and
-- compares fields with 'matchEqArguments'. When @multipleCases@ is 'True',
-- a trailing wildcard alternative produces 'Nothing'.
mkSimpleEqF :: [Type]            -- ^ datatype parameter types
            -> Set Name          -- ^ type variables treated as bound
            -> [(TypePat,ExpQ)]  -- ^ user patterns and their equality tests
            -> ConstructorInfo   -- ^ constructor of @x@
            -> [Name]            -- ^ variables bound to @x@'s fields
            -> ExpQ              -- ^ the expression @y@ to scrutinise
            -> Bool              -- ^ whether to add the wildcard alternative
            -> ExpQ
mkSimpleEqF dTypes bnd pats con xv yQ multipleCases = do
  (yp, yv) <- conPat con "y"
  let body     = matchEqArguments dTypes pats (constructorName con) bnd (constructorFields con) xv yv
      sameCon  = match (pure yp) (normalB body) []
      otherCon = match wildP (normalB [| Nothing |]) []
  caseE yQ (if multipleCases then [sameCon, otherCon] else [sameCon])
-- | 'mkSimpleEqF' specialised to a reified datatype. The free type variables
-- of every datatype parameter except the last are treated as bound. The
-- last parameter is presumably the index refined by the equality test;
-- confirm with callers.
mkEqF :: DatatypeInfo      -- ^ the reified datatype
      -> [(TypePat,ExpQ)]  -- ^ user patterns and their equality tests
      -> ConstructorInfo   -- ^ constructor of @x@
      -> [Name]            -- ^ variables bound to @x@'s fields
      -> ExpQ              -- ^ the expression @y@ to scrutinise
      -> Bool              -- ^ whether the datatype has several constructors
      -> ExpQ
mkEqF d pats con = mkSimpleEqF dVars bound pats con
  where
    dVars = datatypeVars d
    -- All parameters but the last. @zipWith const xs (drop 1 xs)@ is a
    -- total 'init' that yields @[]@ for an empty list.
    bound = typeVars (zipWith const dVars (drop 1 dVars))
-- | @structuralTypeEquality tp pats@ generates a two-argument function over
-- the datatype named by @tp@. It returns @Just Refl@ when both arguments
-- use the same constructor and every field agrees (see 'matchEqArguments'),
-- and 'Nothing' otherwise.
--
-- For a datatype with no constructors it generates @\\x -> case x of {}@.
-- Fails in 'Q' if @tp@ is not a type constructor.
structuralTypeEquality :: TypeQ -> [(TypePat,ExpQ)] -> ExpQ
structuralTypeEquality tpq pats = do
  tyName <- asTypeCon "structuralTypeEquality" =<< tpq
  d <- reifyDatatype tyName
  let cons = datatypeCons d
      severalCons = case cons of
                      (_ : _ : _) -> True
                      _           -> False
      -- One alternative per constructor of @x@; each scrutinises @y@.
      xAlt yQ con = do
        (xp, xv) <- conPat con "x"
        match (pure xp) (normalB (mkEqF d pats con xv yQ severalCons)) []
  case cons of
    [] -> [| \x -> case x of {} |]
    _  -> [| \x y -> $(caseE [| x |] [ xAlt [| y |] con | con <- cons ]) |]
-- | @structuralTypeOrd tp pats@ generates a two-argument comparison over the
-- datatype named by @tp@, producing 'LTF', 'EQF' or 'GTF'.
--
-- Values built with different constructors are ordered by constructor
-- position (an earlier constructor compares as smaller). Values built with
-- the same constructor are compared field by field, left to right (see
-- 'matchOrdArguments').
--
-- For a datatype with no constructors it generates @\\x -> case x of {}@.
-- Fails in 'Q' if @tp@ is not a type constructor or a field type is not
-- supported.
structuralTypeOrd ::
  TypeQ ->
  [(TypePat,ExpQ)]  ->
  ExpQ
structuralTypeOrd tpq l = do
  -- Report failures under this function's own name.
  d <- reifyDatatype =<< asTypeCon "structuralTypeOrd" =<< tpq
  let withNumber :: ExpQ -> (Maybe ExpQ -> ExpQ) -> ExpQ
      withNumber yQ k
        -- With at most one constructor there is nothing to rank.
        | null (drop 1 (datatypeCons d)) = k Nothing
        -- Otherwise bind @yn@ to the index of @y@'s constructor so that
        -- alternatives for mismatched constructors can order by index.
        | otherwise =  [| let yn :: Int
                              yn = $(caseE yQ (constructorNumberMatches (datatypeCons d)))
                          in $(k (Just [| yn |])) |]
  if null (datatypeCons d)
    then [| \x -> case x of {} |]
    else [| \x y -> $(withNumber [|y|] $ \mbYn -> caseE [| x |] (outerOrdMatches d [|y|] mbYn)) |]
  where
    -- Map each constructor (matched with a field-agnostic record pattern)
    -- to its position in the declaration.
    constructorNumberMatches :: [ConstructorInfo] -> [MatchQ]
    constructorNumberMatches cons =
      [ match (recP (constructorName con) [])
              (normalB (litE (integerL i)))
              []
      | (i,con) <- zip [0..] cons ]
    -- One alternative per constructor of @x@; each builds a nested @case@
    -- on @y@ via 'mkOrdF'.
    outerOrdMatches :: DatatypeInfo -> ExpQ -> Maybe ExpQ -> [MatchQ]
    outerOrdMatches d yExp mbYn =
      [ do (pat,xv) <- conPat con "x"
           match (pure pat)
                 (normalB (do xs <- mkOrdF d l con i mbYn xv
                              caseE yExp xs))
                 []
      | (i,con) <- zip [0..] (datatypeCons d) ]
-- | Create @n@ fresh names, @base1@ through @baseN@, in order. Produces no
-- names when @n <= 0@.
newNames
  :: String   -- ^ common prefix of the names
  -> Int      -- ^ how many names to create
  -> Q [Name]
newNames base n = forM [1 .. n] $ \i -> newName (base ++ show i)
-- | Generate code that applies the comparison @f@ to @x@ and @y@ and
-- continues with @r@ only when the result is 'EQF'. 'LTF' and 'GTF' are
-- returned as the overall result.
joinCompareF :: ExpQ -> Name -> Name -> ExpQ -> ExpQ
joinCompareF f x y r =
  [| case $(f) $(varE x) $(varE y) of
       EQF -> $(r)
       GTF -> GTF
       LTF -> LTF
   |]
-- | Generate code that compares @x@ and @y@ with 'compare', translating
-- 'LT'/'GT' to 'LTF'/'GTF' and continuing with @r@ on 'EQ'.
joinCompareToOrdF :: Name -> Name -> ExpQ -> ExpQ
joinCompareToOrdF x y r =
  [| case compare $(varE x) $(varE y) of
       EQ -> $(r)
       GT -> GTF
       LT -> LTF
   |]
  
-- | Generate the body of a comparison for two values built with the same
-- constructor, comparing corresponding fields from left to right and
-- stopping at the first field that is not 'EQF'.
--
-- A field whose type matches a user pattern is compared with the paired
-- function. A field whose free type variables are all bound is compared
-- with 'compare'. Any other field makes generation fail in 'Q'. When all
-- fields are equal the generated code returns 'EQF'.
matchOrdArguments :: [Type]
                     -- ^ Types of the datatype's parameters (indexed by 'DataArg').
                  -> [(TypePat,ExpQ)]
                     -- ^ Patterns paired with the comparison for fields whose type matches.
                  -> Name
                     -- ^ Constructor name, used only in error messages.
                  -> Set Name
                     -- ^ Bound type variables: fields mentioning only these use 'compare'.
                  -> [Type]
                     -- ^ Types of the remaining fields.
                  -> [Name]
                     -- ^ Variables bound to the remaining fields of the first value.
                  -> [Name]
                     -- ^ Variables bound to the remaining fields of the second value.
                  -> ExpQ
matchOrdArguments dTypes pats cnm bnd (tp : tpl) (x:xl) (y:yl) = do
  handler <- assocTypePats dTypes pats tp
  let continueWith bound = matchOrdArguments dTypes pats cnm bound tpl xl yl
  case handler of
    Just f ->
      joinCompareF f x y (continueWith (refineBound tp))
    Nothing
      | typeVars tp `Set.isSubsetOf` bnd ->
          joinCompareToOrdF x y (continueWith bnd)
      | otherwise ->
          fail $ "Unsupported argument type " ++ show (ppr tp)
                 ++ " in " ++ show (ppr cnm) ++ "."
  where
    -- A field of type @f v@ handled by a user comparison makes @v@ bound
    -- for the fields that follow it.
    refineBound (AppT _ (VarT nm)) = Set.insert nm bnd
    refineBound _                  = bnd
matchOrdArguments _ _ _ _ [] [] [] = [| EQF |]
matchOrdArguments _ _ _ _ [] _  _  = error "Unexpected end of types."
matchOrdArguments _ _ _ _ _  [] _  = error "Unexpected end of names."
matchOrdArguments _ _ _ _ _  _  [] = error "Unexpected end of names."
-- | Generate the alternatives of a @case@ on @y@ for an @x@ already matched
-- against constructor @con@ (fields bound to @xv@, constructor index
-- @xnum@).
--
-- The first alternative matches @y@ against the same constructor and
-- compares fields with 'matchOrdArguments'. When @mbYn@ is @Just yn@, a
-- wildcard alternative orders differing constructors. It yields 'LTF' when
-- @xnum < yn@ and 'GTF' otherwise.
--
-- NOTE(review): unlike the equality generator, the bound set here is
-- 'Set.empty'. Any field whose type mentions a type variable must
-- therefore be handled by a user pattern. Confirm this asymmetry is
-- intended.
mkSimpleOrdF :: [Type]            -- ^ datatype parameter types
             -> [(TypePat,ExpQ)]  -- ^ user patterns and their comparisons
             -> ConstructorInfo   -- ^ constructor of @x@
             -> Integer           -- ^ index of @x@'s constructor
             -> Maybe ExpQ        -- ^ index of @y@'s constructor, if needed
             -> [Name]            -- ^ variables bound to @x@'s fields
             -> Q [MatchQ]
mkSimpleOrdF dTypes pats con xnum mbYn xv = do
  (yp, yv) <- conPat con "y"
  let body      = matchOrdArguments dTypes pats (constructorName con) Set.empty (constructorFields con) xv yv
      sameCon   = match (pure yp) (normalB body) []
      otherCons = maybe []
                        (\yn -> [match wildP (normalB [| if xnum < $(yn) then LTF else GTF |]) []])
                        mbYn
  return (sameCon : otherCons)
-- | 'mkSimpleOrdF' specialised to a reified datatype, taking the parameter
-- types from 'datatypeVars'.
mkOrdF :: DatatypeInfo      -- ^ the reified datatype
       -> [(TypePat,ExpQ)]  -- ^ user patterns and their comparisons
       -> ConstructorInfo   -- ^ constructor of @x@
       -> Integer           -- ^ index of @x@'s constructor
       -> Maybe ExpQ        -- ^ index of @y@'s constructor, if needed
       -> [Name]            -- ^ variables bound to @x@'s fields
       -> Q [MatchQ]
mkOrdF d = mkSimpleOrdF (datatypeVars d)
-- | Generate the traversal action for one constructor field, or 'Nothing'
-- if the field is left untouched.
--
-- A function found by the lookup @m@ takes priority and is applied to
-- @f@ and @v@. Otherwise a field whose type has the shape
-- @C (a t)@ (type constructor over a type-variable application) becomes
-- @traverse f v@, and a field of shape @a t@ becomes @f v@.
recurseArg :: (Type -> Q (Maybe ExpQ)) -- ^ lookup of user-supplied traversals by field type
           -> ExpQ                     -- ^ the traversal function @f@
           -> ExpQ                     -- ^ the field value @v@
           -> Type                     -- ^ the field's type
           -> Q (Maybe Exp)
recurseArg m f v tp = do
  custom <- m tp
  case (custom, tp) of
    (Just g, _) ->
      Just <$> [| $(g) $(f) $(v) |]
    (Nothing, AppT (ConT _) (AppT (VarT _) _)) ->
      Just <$> [| traverse $(f) $(v) |]
    (Nothing, AppT (VarT _) _) ->
      Just <$> [| $(f) $(v) |]
    (Nothing, _) ->
      return Nothing
-- | Generate a @case@ alternative that rebuilds constructor @c0@
-- applicatively. Fields for which 'recurseArg' yields an action are
-- replaced by that action's result; all other fields are passed through
-- unchanged.
traverseAppMatch :: (Type -> Q (Maybe ExpQ)) -- ^ lookup of user-supplied traversals by field type
                 -> ExpQ                     -- ^ the traversal function
                 -> ConstructorInfo          -- ^ the constructor to rebuild
                 -> MatchQ
traverseAppMatch pats fv c0 = do
  (pat, fieldVars) <- conPat c0 "p"
  steps <- zipWithM (recurseArg pats fv) (map varE fieldVars) (constructorFields c0)
  let -- Partially apply the constructor. Untouched fields are applied
      -- directly; each traversed field becomes a lambda parameter, to be
      -- filled in by the applicative chain below.
      buildCon :: ExpQ -> [(Name, Maybe Exp)] -> ExpQ
      buildCon acc [] = acc
      buildCon acc ((fieldVar, Nothing) : rest) =
        buildCon (appE acc (varE fieldVar)) rest
      buildCon acc ((_, Just _) : rest) = do
        r <- newName "r"
        lamE [varP r] (buildCon (appE acc (varE r)) rest)
      -- Produce @con <$> a1 <*> a2 <*> ...@, or @pure con@ when no field
      -- is traversed.
      chain :: ExpQ -> [Exp] -> ExpQ
      chain con effects =
        case effects of
          [] -> [| pure $(con) |]
          (a : as) ->
            foldl (\acc e -> [| $(acc) <*> $(pure e) |]) [| $(con) <$> $(pure a) |] as
      rhs = chain (buildCon (pure (conExpr c0)) (zip fieldVars steps)) (catMaybes steps)
  match (pure pat) (normalB rhs) []
-- | @structuralTraversal tp pats@ generates @\\f a -> ...@, which rebuilds
-- @a@ (a value of the datatype named by @tp@) applicatively. It applies
-- @f@, or a user-supplied traversal from @pats@, to the fields selected by
-- 'recurseArg'. Fails in 'Q' if @tp@ is not a type constructor.
structuralTraversal :: TypeQ -> [(TypePat, ExpQ)] -> ExpQ
structuralTraversal tpq pats0 = do
  d <- reifyDatatype =<< asTypeCon "structuralTraversal" =<< tpq
  fnName <- newName "f"
  argName <- newName "a"
  let lookupPat = assocTypePats (datatypeVars d) pats0
      alts = [ traverseAppMatch lookupPat (varE fnName) c | c <- datatypeCons d ]
  lamE [varP fnName, varP argName] (caseE (varE argName) alts)
-- | Extract the name from a bare type constructor ('ConT'). Any other type
-- fails with "<fn> expected type constructor.", where @fn@ names the
-- calling generator.
asTypeCon :: Monad m => String -> Type -> m Name
asTypeCon fn tp =
  case tp of
    ConT nm -> return nm
    _       -> fail (fn ++ " expected type constructor.")
-- | @structuralHash tp@ generates @\\s a -> ...@, which hashes @a@ (a value
-- of the datatype named by @tp@) starting from salt @s@. The hash mixes in
-- the index of @a@'s constructor and each of its fields (see
-- 'matchHashCtor'). Fails in 'Q' if @tp@ is not a type constructor.
structuralHash :: TypeQ -> ExpQ
structuralHash tpq = do
  d <- reifyDatatype =<< asTypeCon "structuralHash" =<< tpq
  salt <- newName "s"
  val <- newName "a"
  let alts = zipWith (matchHashCtor (varE salt)) [0 ..] (datatypeCons d)
  lamE [varP salt, varP val] (caseE (varE val) alts)
-- | Generate a @case@ alternative for constructor @c@ that folds
-- 'hashWithSalt' left to right over the salt @s0@. It first mixes in the
-- constructor index @i@ (as an 'Int'), then each field in order.
matchHashCtor :: ExpQ            -- ^ the initial salt
              -> Integer         -- ^ index of the constructor
              -> ConstructorInfo -- ^ the constructor to match
              -> MatchQ
matchHashCtor s0 i c = do
  (pat, fieldVars) <- conPat c "x"
  let tag = [| $(litE (IntegerL i)) :: Int |]
      mix acc e = [| hashWithSalt $(acc) $(e) |]
      rhs = foldl mix s0 (tag : map varE fieldVars)
  match (pure pat) (normalB rhs) []
-- | @structuralShowsPrec tp@ generates @\\p a -> ...@, a 'showsPrec'-style
-- function for the datatype named by @tp@. There is one alternative per
-- constructor, each produced by 'matchShowCtor' with precedence @p@.
-- Fails in 'Q' if @tp@ is not a type constructor.
structuralShowsPrec :: TypeQ -> ExpQ
structuralShowsPrec tpq = do
  -- Report failures under this function's exact name.
  d <- reifyDatatype =<< asTypeCon "structuralShowsPrec" =<< tpq
  -- Presumably named with a leading underscore so that an unused precedence
  -- (all constructors nullary) does not trigger a warning.
  p <- newName "_p"
  a <- newName "a"
  lamE [varP p, varP a] $
    caseE (varE a) (matchShowCtor (varE p) <$> datatypeCons d)
-- | Generate a @case@ alternative that shows an application of constructor
-- @nm@ (with @n@ fields) in the style of derived 'Show' instances. It
-- emits the constructor's base name, then each field shown at precedence
-- 11, separated by spaces. The result is parenthesised when the
-- surrounding precedence @p@ exceeds 10 (application precedence). Nullary
-- constructors are never parenthesised.
showCon :: ExpQ -> Name -> Int -> MatchQ
showCon p nm n = do
  vars <- newNames "x" n
  let pat = ConP nm (VarP <$> vars)
  -- Fields must be shown at application precedence + 1. Derived Show
  -- instances only parenthesise above 10, so showing at 10 would print
  -- e.g. @Foo Just 1@ instead of @Foo (Just 1)@.
  let go s e = [| $(s) . showChar ' ' . showsPrec 11 $(varE e) |]
  let ctor = [| showString $(return (LitE (StringL (nameBase nm)))) |]
  let rhs | null vars = ctor
          | otherwise = [| showParen ($(p) >= 11) $(foldl go ctor vars) |]
  match (pure pat) (normalB rhs) []
-- | The 'showCon' alternative for a reified constructor, using its name and
-- field count.
matchShowCtor :: ExpQ -> ConstructorInfo -> MatchQ
matchShowCtor p con =
  let arity = length (constructorFields con)
  in showCon p (constructorName con) arity