{-# LANGUAGE CPP             #-}
{-# LANGUAGE MultiWayIf      #-}
{-# LANGUAGE TemplateHaskell #-}
module Data.Matchable.TH (
  deriveMatchable, makeZipMatchWith,
  deriveBimatchable, makeBizipMatchWith
) where

import           Data.Bimatchable             (Bimatchable (..))
import           Data.Matchable               (Matchable (..))

import           Data.Monoid                  (Monoid (..))
import           Data.Semigroup               (Semigroup (..))

import           Language.Haskell.TH hiding (TyVarBndr(..))
import           Language.Haskell.TH.Datatype (ConstructorInfo (..),
                                               DatatypeInfo (..), reifyDatatype)
import           Language.Haskell.TH.Datatype.TyVarBndr

-- | Derive a 'Matchable' instance for the data type called @name@,
-- matching over its last type parameter.
--
-- /e.g./ for
--
-- @
-- data Exp a = Plus a a | Times a a
-- 'deriveMatchable' ''Exp
-- @
--
-- the generated instance behaves like
--
-- @
-- instance Matchable Exp where
--   zipMatchWith f (Plus  l1 l2) (Plus  r1 r2) = pure Plus  <*> f l1 r1 <*> f l2 r2
--   zipMatchWith f (Times l1 l2) (Times r1 r2) = pure Times <*> f l1 r1 <*> f l2 r2
--   zipMatchWith _ _ _ = Nothing
-- @
deriveMatchable :: Name -> Q [Dec]
deriveMatchable name = do
  ((instanceCtx, functorTy), implE) <- makeZipMatchWith' name
  let headTy = conT ''Matchable `appT` pure functorTy
      method = funD 'zipMatchWith [clause [] (normalB implE) []]
  -- A single instance declaration, returned as a one-element list.
  pure <$> instanceD instanceCtx headTy [method]

-- | Generate only the 'zipMatchWith' implementation for the data type
-- called @name@, as an expression, without declaring an instance.
-- Useful for writing the instance by hand, e.g.
--
-- @
-- instance Matchable Exp where
--   zipMatchWith = $('makeZipMatchWith' ''Exp)
-- @
makeZipMatchWith :: Name -> ExpQ
makeZipMatchWith name = do
  (_, zipMatchWithE) <- makeZipMatchWith' name
  zipMatchWithE

-- | Shared worker behind 'deriveMatchable' and 'makeZipMatchWith'.
--
-- Reifies the data type @name@; its /last/ type variable is the one being
-- matched over. Returns
--
--   * the instance context (a 'Q' action producing every constraint the
--     generated code needs) paired with the instance head type, i.e.
--     @name@ applied to all of its type variables except the last; and
--
--   * an expression of the form @let zmw ... in zmw@ implementing
--     'zipMatchWith'.
--
-- Fails in 'Q' if @name@ has no type variables.
makeZipMatchWith' :: Name -> Q ((Q Cxt, Type), ExpQ)
makeZipMatchWith' name = do
  info <- reifyDatatype name
  let DatatypeInfo { datatypeVars = dtVars , datatypeCons = cons } = info
  -- Report a proper error (rather than an irrefutable-pattern crash) when
  -- the data type has no type parameter to match over.
  (tyA, rest') <- case reverse (VarT . tvName <$> dtVars) of
    lastVar : otherVars -> pure (lastVar, otherVars)
    []                  -> fail $
      "cannot derive Matchable for " ++ show name ++
      ": it needs at least one type parameter"
  let dtFunctor = foldr (flip AppT) (ConT name) rest'

  f <- newName "f"

  -- One clause per constructor: both arguments must use that constructor,
  -- each pair of fields is handled by a 'Matcher', and the per-field
  -- results are recombined with 'pure' and '<*>'. The function argument is
  -- bound only when some field uses it; otherwise it is a wildcard.
  let mkMatchClause (ConstructorInfo ctrName _ _ fields _ _) =
        do matchers <- mapM (dMatchField tyA f) fields
           let lFieldsP = leftPat <$> matchers
               rFieldsP = rightPat <$> matchers
               bodyUsesF = any additionalInfo matchers
               body = foldl (\x y -> [| $x <*> $y |])
                            [| pure $(conE ctrName) |]
                            (bodyExp <$> matchers)
               ctx = concatMap requiredCtx matchers
               fPat = if bodyUsesF then varP f else wildP
               lPat = conP ctrName lFieldsP
               rPat = conP ctrName rFieldsP
           return (clause [fPat, lPat, rPat] (normalB body) [], ctx)

  matchClausesAndCtxs <- mapM mkMatchClause cons

  let matchClauses = map fst matchClausesAndCtxs
      ctx = concatMap snd matchClausesAndCtxs
      mismatchClause = clause [ wildP, wildP, wildP ] (normalB [| Nothing |]) []
      finalClauses = case cons of
        -- Template Haskell rejects a function binding with no equations,
        -- so a constructor-less data type gets just the catch-all clause.
        []  -> [mismatchClause]
        -- With a single constructor the match clause already covers every
        -- pair of values, so a catch-all would be unreachable.
        [_] -> matchClauses
        _   -> matchClauses ++ [mismatchClause]

  zmw <- newName "zmw"
  return ((sequenceA ctx, dtFunctor), letE [ funD zmw finalClauses ] (varE zmw))

-- | The generated pieces needed to match one field (or sub-term) of a
-- constructor inside a derived 'zipMatchWith' / 'bizipMatchWith' clause.
-- The type parameter @u@ records which user-supplied functions the
-- generated expression refers to.
data Matcher u =
  Matcher
    { leftPat        :: PatQ    -- ^ pattern binding the field of the left argument
    , rightPat       :: PatQ    -- ^ pattern binding the field of the right argument
    , bodyExp        :: ExpQ    -- ^ expression combining the two bound values (yields 'Just' on success, 'Nothing' on mismatch)
    , requiredCtx    :: [TypeQ] -- ^ constraints the generated instance needs for this field
    , additionalInfo :: u       -- ^ which user-supplied functions 'bodyExp' refers to
    }

-- | Build a 'Matcher' for one constructor field of type @ty@, for use in a
-- generated 'zipMatchWith' clause.
--
-- @tyA@ is the type variable being matched over, and @fName@ is the name
-- bound to the user-supplied matching function. The result's
-- 'additionalInfo' is 'True' exactly when the generated body refers to
-- @fName@.
--
-- The cases are tried in order:
--
--   * the field is exactly @tyA@: apply @fName@ to the two values;
--
--   * the field does not mention @tyA@: compare the two values with '=='
--     (adding an 'Eq' constraint when the field type has type variables);
--
--   * lists: 'zipMatchWith' over the elements; tuples: handle each
--     component recursively and rebuild the tuple;
--
--   * a type constructor or type variable where only the last argument
--     mentions @tyA@: 'zipMatchWith' (adds a 'Matchable' constraint when
--     needed);
--
--   * a type constructor or type variable where only the last two
--     arguments mention @tyA@: 'bizipMatchWith' (adds a 'Bimatchable'
--     constraint when needed);
--
--   * anything else is rejected via 'unexpectedType'.
dMatchField :: Type -> Name -> Type -> Q (Matcher Bool)
dMatchField tyA fName ty = case spine ty of
  -- The field is the matched type variable itself.
  _ | ty == tyA -> do
        l <- newName "l"
        r <- newName "r"
        return $ Matcher
          { leftPat = varP l
          , rightPat = varP r
          , additionalInfo = True
          , bodyExp = [| $(varE fName) $(varE l) $(varE r) |]
          , requiredCtx = [] }
    -- The field does not involve tyA at all: the two sides must be equal.
    | not (occurs tyA ty) -> do
        l <- newName "l"
        r <- newName "r"
        let ctx = [ pure (AppT (ConT ''Eq) ty) | hasTyVar ty ]
        return $ Matcher
          { leftPat = varP l
          , rightPat = varP r
          , additionalInfo = False
          , bodyExp = [| if $(varE l) == $(varE r)
                           then Just $(varE l)
                           else Nothing |]
          , requiredCtx = ctx }
  (ListT, ty':_) -> dWrapped ty'
  -- Tuples are destructured in the pattern. 'spine' lists the component
  -- types in reverse, hence the 'reverse'.
  (TupleT n, subtys) -> do
     matchers <- mapM (dMatchField tyA fName) (reverse subtys)
     let lP = tupP (leftPat <$> matchers)
         rP = tupP (rightPat <$> matchers)
         tupcon = [| pure $(conE (tupleDataName n)) |]
         anyUsesF = any additionalInfo matchers
         body = foldl (\x y -> [| $x <*> $y |]) tupcon (bodyExp <$> matchers)
         ctx = concatMap requiredCtx matchers
     return $ Matcher
       { leftPat = lP
       , rightPat = rP
       , additionalInfo = anyUsesF
       , bodyExp = body
       , requiredCtx = ctx }
  -- `T x1 .. xk y` where only y mentions tyA: needs Matchable (T x1 .. xk).
  -- The foldr rebuilds `T x1 .. xk` from the reversed argument list.
  (ConT tcon, ty' : rest) | all (not . occurs tyA) rest -> do
     let g = foldr (flip AppT) (ConT tcon) rest
         ctxG = [ pure (AppT (ConT ''Matchable) g) | hasTyVar g ]
     matcher <- dWrapped ty'
     return $ matcher{ requiredCtx = ctxG ++ requiredCtx matcher }
  -- `T x1 .. xk y z` where only y and z mention tyA: needs Bimatchable.
  (ConT tcon, ty1' : ty2' : rest) | all (not . occurs tyA) rest -> do
     let g = foldr (flip AppT) (ConT tcon) rest
         ctxG = [ pure (AppT (ConT ''Bimatchable) g) | hasTyVar g ]
     -- Note that since @spine@ reverses argument order,
     -- it must be dWrappedBi ty2 ty1.
     matcher <- dWrappedBi ty2' ty1'
     return $ matcher{ requiredCtx = ctxG ++ requiredCtx matcher }
  -- Same as the ConT cases, but headed by a type variable, so the
  -- Matchable constraint is always required.
  (VarT t, ty' : rest) | all (not . occurs tyA) rest -> do
     let g = foldr (flip AppT) (VarT t) rest
         ctxG = [ pure (AppT (ConT ''Matchable) g) ]
     matcher <- dWrapped ty'
     return $ matcher{ requiredCtx = ctxG ++ requiredCtx matcher }
  (VarT t, ty1' : ty2' : rest) | all (not . occurs tyA) rest -> do
     let g = foldr (flip AppT) (VarT t) rest
         ctxG = [ pure (AppT (ConT ''Bimatchable) g) | hasTyVar g ]
     matcher <- dWrappedBi ty2' ty1'
     return $ matcher{ requiredCtx = ctxG ++ requiredCtx matcher }
  (ForallT _ _ _, _) -> unexpectedType ty "Matchable"
  -- 'spine' strips ParensT, AppT and SigT, so they never appear as the head.
  (ParensT _, _) -> error "Never reach here"
  (AppT _ _, _) -> error "Never reach here"
  (SigT _ _, _) -> error "Never reach here"
  _ -> unexpectedType ty "Matchable"

  where
    -- Match a container whose element type is ty' via 'zipMatchWith',
    -- using a lambda built from the element's own Matcher.
    dWrapped :: Type -> Q (Matcher Bool)
    dWrapped ty' =do
      l <- newName "l"
      r <- newName "r"
      (usesF', ctx, fun) <- do
         matcher <- dMatchField tyA fName ty'
         let fun = lamE [leftPat matcher, rightPat matcher] (bodyExp matcher)
         return (additionalInfo matcher, requiredCtx matcher, fun)
      return $ Matcher
        { leftPat = varP l
        , rightPat = varP r
        , additionalInfo = usesF'
        , bodyExp = [| zipMatchWith $fun $(varE l) $(varE r) |]
        , requiredCtx = ctx }

    -- Match a two-parameter container via 'bizipMatchWith'; ty1 and ty2
    -- are the first and second matched arguments respectively.
    dWrappedBi :: Type -> Type -> Q (Matcher Bool)
    dWrappedBi ty1 ty2 = do
      l <- newName "l"
      r <- newName "r"
      (usesF', ctx, fun1, fun2) <- do
         matcher1 <- dMatchField tyA fName ty1
         matcher2 <- dMatchField tyA fName ty2
         let fun1 = lamE [leftPat matcher1, rightPat matcher1] (bodyExp matcher1)
             fun2 = lamE [leftPat matcher2, rightPat matcher2] (bodyExp matcher2)
             usesF' = additionalInfo matcher1 || additionalInfo matcher2
             ctx = requiredCtx matcher1 ++ requiredCtx matcher2
         return (usesF', ctx, fun1, fun2)
      return $ Matcher
        { leftPat = varP l
        , rightPat = varP r
        , additionalInfo = usesF'
        , bodyExp = [| bizipMatchWith $fun1 $fun2 $(varE l) $(varE r) |]
        , requiredCtx = ctx }

-- | Build an instance of 'Bimatchable' for a data type, matching over its
-- last two type parameters.
--
-- /e.g./
--
-- @
-- data Sum a b = InL a | InR b
-- 'deriveBimatchable' ''Sum
-- @
--
-- will create
--
-- @
-- instance Bimatchable Sum where
--   bizipMatchWith f _ (InL l1) (InL r1) = pure InL <*> f l1 r1
--   bizipMatchWith _ g (InR l1) (InR r1) = pure InR <*> g l1 r1
--   bizipMatchWith _ _ _ _ = Nothing
-- @
deriveBimatchable :: Name -> Q [Dec]
deriveBimatchable name = do
  ((ctx, f), bizipMatchWithE) <- makeBizipMatchWith' name

  dec <- instanceD ctx (appT (conT ''Bimatchable) (pure f))
           [ funD 'bizipMatchWith [clause [] (normalB bizipMatchWithE) []] ]

  pure [dec]

-- | Generate only the 'bizipMatchWith' implementation for the data type
-- called @name@, as an expression, without declaring an instance.
-- Useful for writing the instance by hand, e.g.
--
-- @
-- instance Bimatchable Sum where
--   bizipMatchWith = $('makeBizipMatchWith' ''Sum)
-- @
makeBizipMatchWith :: Name -> ExpQ
makeBizipMatchWith name = do
  (_, bizipMatchWithE) <- makeBizipMatchWith' name
  bizipMatchWithE

-- | Shared worker behind 'deriveBimatchable' and 'makeBizipMatchWith'.
--
-- Reifies the data type @name@; its last two type variables are the ones
-- being matched over (the second-to-last is handled by the first function
-- argument, the last by the second). Returns
--
--   * the instance context (a 'Q' action producing every constraint the
--     generated code needs) paired with the instance head type, i.e.
--     @name@ applied to all of its type variables except the last two; and
--
--   * an expression of the form @let bzmw ... in bzmw@ implementing
--     'bizipMatchWith'.
--
-- Fails in 'Q' if @name@ has fewer than two type variables.
makeBizipMatchWith' :: Name -> Q ((Q Cxt, Type), ExpQ)
makeBizipMatchWith' name = do
  info <- reifyDatatype name
  let DatatypeInfo { datatypeVars = dtVars , datatypeCons = cons } = info
  -- Report a proper error (rather than an irrefutable-pattern crash) when
  -- the data type does not have two type parameters to match over.
  (tyB, tyA, rest') <- case reverse (VarT . tvName <$> dtVars) of
    lastVar : prevVar : otherVars -> pure (lastVar, prevVar, otherVars)
    _                             -> fail $
      "cannot derive Bimatchable for " ++ show name ++
      ": it needs at least two type parameters"
  let dtFunctor = foldr (flip AppT) (ConT name) rest'

  f <- newName "f"
  g <- newName "g"

  -- One clause per constructor, as for 'zipMatchWith'. Each function
  -- argument is bound only when some field uses it; otherwise it is a
  -- wildcard.
  let mkMatchClause (ConstructorInfo ctrName _ _ fields _ _) =
        do matchers <- mapM (dBimatchField tyA f tyB g) fields
           let lFieldsP = leftPat <$> matchers
               rFieldsP = rightPat <$> matchers
               Usage2 usesF usesG = foldMap additionalInfo matchers
               body = foldl (\x y -> [| $x <*> $y |])
                            [| pure $(conE ctrName) |]
                            (bodyExp <$> matchers)
               ctx = concatMap requiredCtx matchers
               fPat = if usesF then varP f else wildP
               gPat = if usesG then varP g else wildP
               lPat = conP ctrName lFieldsP
               rPat = conP ctrName rFieldsP
           return (clause [fPat, gPat, lPat, rPat] (normalB body) [], ctx)

  matchClausesAndCtxs <- mapM mkMatchClause cons

  let matchClauses = map fst matchClausesAndCtxs
      ctx = concatMap snd matchClausesAndCtxs
      mismatchClause = clause [ wildP, wildP, wildP, wildP ] (normalB [| Nothing |]) []
      finalClauses = case cons of
        -- Template Haskell rejects a function binding with no equations,
        -- so a constructor-less data type gets just the catch-all clause.
        []  -> [mismatchClause]
        -- With a single constructor the match clause already covers every
        -- pair of values, so a catch-all would be unreachable.
        [_] -> matchClauses
        _   -> matchClauses ++ [mismatchClause]

  bzmw <- newName "bzmw"
  return ((sequenceA ctx, dtFunctor), letE [ funD bzmw finalClauses ] (varE bzmw))

-- | Which of the two user-supplied matching functions a generated
-- 'bizipMatchWith' body refers to: the first flag is for the function
-- handling the first matched type variable, the second flag for the one
-- handling the second. Combining two usages ORs each flag.
data FunUsage2 = Usage2 Bool Bool

instance Semigroup FunUsage2 where
  (<>) (Usage2 usesF usesG) (Usage2 usesF' usesG') =
    Usage2 (usesF || usesF') (usesG || usesG')

-- | The empty usage: neither function is referenced.
instance Monoid FunUsage2 where
  mappend = (<>)
  mempty = Usage2 False False

-- | Build a 'Matcher' for one constructor field of type @ty@, for use in a
-- generated 'bizipMatchWith' clause.
--
-- @tyA@ / @fName@ are the first matched type variable and the name bound
-- to its matching function; @tyB@ / @gName@ are the second. The
-- 'FunUsage2' stored in 'additionalInfo' records which of @fName@ /
-- @gName@ the generated body refers to.
--
-- The cases are tried in order:
--
--   * the field is exactly @tyA@ (apply @fName@) or exactly @tyB@ (apply
--     @gName@);
--
--   * the field mentions neither: compare with '==' (adding an 'Eq'
--     constraint when the field type has type variables);
--
--   * lists and tuples: handle the elements / components recursively;
--
--   * a type constructor or type variable where only the last argument
--     mentions @tyA@ or @tyB@: 'zipMatchWith' (needs 'Matchable');
--
--   * where only the last two arguments do: 'bizipMatchWith' (needs
--     'Bimatchable');
--
--   * anything else is rejected via 'unexpectedType'.
dBimatchField :: Type -> Name -> Type -> Name -> Type -> Q (Matcher FunUsage2)
dBimatchField tyA fName tyB gName ty = case spine ty of
  -- The field is the first matched type variable: use fName.
  _ | ty == tyA -> do
        l <- newName "l"
        r <- newName "r"
        return $ Matcher
          { leftPat = varP l
          , rightPat = varP r
          , additionalInfo = Usage2 True False
          , bodyExp = [| $(varE fName) $(varE l) $(varE r) |]
          , requiredCtx = [] }
    -- The field is the second matched type variable: use gName.
    | ty == tyB -> do
        l <- newName "l"
        r <- newName "r"
        return $ Matcher
          { leftPat = varP l
          , rightPat = varP r
          , additionalInfo = Usage2 False True
          , bodyExp = [| $(varE gName) $(varE l) $(varE r) |]
          , requiredCtx = [] }
    -- The field mentions neither variable: the two sides must be equal.
    | isConst ty -> do
        l <- newName "l"
        r <- newName "r"
        let ctx = [ pure (AppT (ConT ''Eq) ty) | hasTyVar ty ]
        return $ Matcher
          { leftPat = varP l
          , rightPat = varP r
          , additionalInfo = Usage2 False False
          , bodyExp = [| if $(varE l) == $(varE r)
                           then Just $(varE l)
                           else Nothing |]
          , requiredCtx = ctx }
  (ListT, ty':_) -> dWrapped ty'
  -- Tuples are destructured in the pattern. 'spine' lists the component
  -- types in reverse, hence the 'reverse'.
  (TupleT n, subtys) -> do
     matchers <- mapM (dBimatchField tyA fName tyB gName) (reverse subtys)
     let lP = tupP (leftPat <$> matchers)
         rP = tupP (rightPat <$> matchers)
         tupcon = [| pure $(conE (tupleDataName n)) |]
         anyUsesF = foldMap additionalInfo matchers
         body = foldl (\x y -> [| $x <*> $y |]) tupcon (bodyExp <$> matchers)
         ctx = concatMap requiredCtx matchers
     return $ Matcher
       { leftPat = lP
       , rightPat = rP
       , additionalInfo = anyUsesF
       , bodyExp = body
       , requiredCtx = ctx }
  -- `T x1 .. xk y` where only y mentions tyA/tyB: needs Matchable.
  -- The foldr rebuilds `T x1 .. xk` from the reversed argument list.
  (ConT tcon, ty' : rest) | all isConst rest -> do
     let g = foldr (flip AppT) (ConT tcon) rest
         ctxG = [ pure (AppT (ConT ''Matchable) g) | hasTyVar g ]
     matcher <- dWrapped ty'
     return $ matcher{ requiredCtx = ctxG ++ requiredCtx matcher }
  -- `T x1 .. xk y z` where only y and z mention tyA/tyB: needs Bimatchable.
  (ConT tcon, ty1' : ty2' : rest) | all isConst rest -> do
     let g = foldr (flip AppT) (ConT tcon) rest
         ctxG = [ pure (AppT (ConT ''Bimatchable) g) | hasTyVar g ]
     -- Note that since @spine@ reverses argument order,
     -- it must be dWrappedBi ty2 ty1.
     matcher <- dWrappedBi ty2' ty1'
     return $ matcher{ requiredCtx = ctxG ++ requiredCtx matcher }
  -- Same as the ConT cases, but headed by a type variable.
  (VarT t, ty' : rest) | all isConst rest -> do
     let g = foldr (flip AppT) (VarT t) rest
         ctxG = [ pure (AppT (ConT ''Matchable) g) ]
     matcher <- dWrapped ty'
     return $ matcher{ requiredCtx = ctxG ++ requiredCtx matcher }
  (VarT t, ty1' : ty2' : rest) | all isConst rest -> do
     let g = foldr (flip AppT) (VarT t) rest
         ctxG = [ pure (AppT (ConT ''Bimatchable) g) | hasTyVar g ]
     matcher <- dWrappedBi ty2' ty1'
     return $ matcher{ requiredCtx = ctxG ++ requiredCtx matcher }
  (ForallT _ _ _, _) -> unexpectedType ty "Bimatchable"
  -- 'spine' strips ParensT, AppT and SigT, so they never appear as the head.
  (ParensT _, _) -> error "Never reach here"
  (AppT _ _, _) -> error "Never reach here"
  (SigT _ _, _) -> error "Never reach here"
  _ -> unexpectedType ty "Bimatchable"

  where
    -- True when the type mentions neither matched type variable.
    isConst :: Type -> Bool
    isConst t = not (occurs tyA t || occurs tyB t)

    -- Match a container whose element type is ty' via 'zipMatchWith',
    -- using a lambda built from the element's own Matcher.
    dWrapped :: Type -> Q (Matcher FunUsage2)
    dWrapped ty' = do
      l <- newName "l"
      r <- newName "r"
      (usesF', ctx, fun) <- do
         matcher <- dBimatchField tyA fName tyB gName ty'
         let fun = lamE [leftPat matcher, rightPat matcher] (bodyExp matcher)
         return (additionalInfo matcher, requiredCtx matcher, fun)
      return $ Matcher
        { leftPat = varP l
        , rightPat = varP r
        , additionalInfo = usesF'
        , bodyExp = [| zipMatchWith $fun $(varE l) $(varE r) |]
        , requiredCtx = ctx }

    -- Match a two-parameter container via 'bizipMatchWith'; ty1 and ty2
    -- are the first and second matched arguments respectively.
    dWrappedBi :: Type -> Type -> Q (Matcher FunUsage2)
    dWrappedBi ty1 ty2 = do
      l <- newName "l"
      r <- newName "r"
      (usesF', ctx, fun1, fun2) <- do
         matcher1 <- dBimatchField tyA fName tyB gName ty1
         matcher2 <- dBimatchField tyA fName tyB gName ty2
         let fun1 = lamE [leftPat matcher1, rightPat matcher1] (bodyExp matcher1)
             fun2 = lamE [leftPat matcher2, rightPat matcher2] (bodyExp matcher2)
             usesF' = additionalInfo matcher1 <> additionalInfo matcher2
             ctx = requiredCtx matcher1 ++ requiredCtx matcher2
         return (usesF', ctx, fun1, fun2)
      return $ Matcher
        { leftPat = varP l
        , rightPat = varP r
        , additionalInfo = usesF'
        , bodyExp = [| bizipMatchWith $fun1 $fun2 $(varE l) $(varE r) |]
        , requiredCtx = ctx }

-----------------------------

-- | Abort the derivation of class @cls@ because a field of type @ty@ is
-- not supported by the generator.
unexpectedType :: Type -> String -> Q a
unexpectedType ty cls = fail (unexpectedTypeMessage ty cls)

-- | The error message reported by 'unexpectedType'.
unexpectedTypeMessage :: Type -> String -> String
unexpectedTypeMessage ty cls =
  "unexpected type " ++ show ty ++ " in derivation of " ++ cls ++
  " (it's only possible to implement " ++ cls ++
  " generically when all subterms are traversable)"

-- | Split a type application into its head and its arguments, looking
-- through parentheses and kind signatures. The arguments are returned in
-- /reverse/ order: @spine (T x y)@ is @(T, [y, x])@.
spine :: Type -> (Type, [Type])
spine ty = case ty of
  ParensT inner -> spine inner
  SigT inner _  -> spine inner
  AppT fun arg  -> case spine fun of
    ~(hd, args) -> (hd, arg : args)
  _             -> (ty, [])

-- | @occurs t u@ is 'True' when @t@ appears as a sub-term of @u@ (or
-- equals it), looking through applications, parentheses and kind
-- signatures.
occurs :: Type -> Type -> Bool
occurs needle = go
  where
    go hay
      | hay == needle = True
      | otherwise = case hay of
          AppT lhs rhs -> go lhs || go rhs
          ParensT hay' -> go hay'
          SigT hay' _  -> go hay'
          _            -> False

-- | Whether a type mentions any type variable, looking through
-- applications, parentheses and kind signatures.
hasTyVar :: Type -> Bool
hasTyVar ty = case ty of
  VarT _       -> True
  AppT fun arg -> any hasTyVar [fun, arg]
  ParensT ty'  -> hasTyVar ty'
  SigT ty' _   -> hasTyVar ty'
  _            -> False