{-# LANGUAGE TemplateHaskell #-}

module Database.Algebra.Rewrite.PatternConstruction
    ( dagPatMatch
    , v
    ) where

import           Control.Monad.Writer
import           Data.Maybe
import           Language.Haskell.TH

import           Database.Algebra.Dag
import           Database.Algebra.Dag.Common
import qualified Database.Algebra.Rewrite.DagRewrite    as R
import qualified Database.Algebra.Rewrite.Match         as M
import           Database.Algebra.Rewrite.PatternSyntax

-- | Code generation monad: a 'Q' computation that additionally accumulates
-- the (not yet executed) statements of the generated do-block in a writer log.
type Code a = WriterT [Q Stmt] Q a

-- | Append a single generated statement to the writer log.
emit :: Q Stmt -> Code ()
emit = tell . pure

-- Names referenced by the generated code. They are all created with 'mkName',
-- i.e. they are plain unqualified names without any binding information.

-- | The function that the generated code applies to a node and a matching
-- lambda (see 'instMatchExp').
matchOp :: Name
matchOp = mkName "matchOp"

-- | The variable bound by the generated lambda and scrutinized by the
-- generated case expression.
opName :: Name
opName = mkName "op__internal"

-- | Node constructor used for patterns with three children.
terOpName :: Name
terOpName = mkName "TerOp"

-- | Node constructor used for patterns with two children.
binOpName :: Name
binOpName = mkName "BinOp"

-- | Node constructor used for patterns with one child.
unOpName :: Name
unOpName = mkName "UnOp"

-- | Node constructor used for patterns without children.
nullOpName :: Name
nullOpName = mkName "NullaryOp"

-- | The function called in the catch-all alternative of the generated case.
failName :: Name
failName = mkName "fail"

-- | The final alternative of every generated case expression:
--
-- > _ -> fail ""
catchAllCase :: Q Match
catchAllCase = match wildP failBody []
  where
    failBody = normalB (varE failName `appE` litE (stringL ""))

-- | How the semantic argument of an operator constructor is treated by the
-- generated code.
data SemPattern = Bind (Q Pat, Name) -- ^ Match the argument with the given pattern and return the given variable
                | NoBind             -- ^ Match the argument with a wildcard, return nothing for it
                | NoSemantics        -- ^ The operator constructor is matched without any argument pattern

-- | Build the case expression that inspects the operator:
--
-- > case op__internal of
-- >   Node (OpCon1 sem) c1 .. cn -> return (..)
-- >   Node (OpCon2 sem) c1 .. cn -> return (..)
-- >   _                          -> fail ""
--
-- One alternative is generated per operator constructor. Each alternative
-- returns a tuple consisting of (in this order) the operator constructor
-- itself if requested, the variable bound to the semantic information (if
-- any) and the variables bound to the children.
instMatchCase :: Name           -- ^ The name of the node constructor (TerOp, BinOp, UnOp, NullaryOp)
                 -> [Name]      -- ^ The names of the operator constructors
                 -> SemPattern  -- ^ How to match (and possibly bind) the semantic information
                 -> [Q Pat]     -- ^ The list of patterns matching the node children
                 -> [Name]      -- ^ The list of variables for the children (may be empty)
                 -> Bool        -- ^ Return the operator constructor (or don't)
                 -> Q Exp       -- ^ Returns the case expression
instMatchCase nodeConstructor opConstructors semantics childMatchPatterns childNames bindOp =
  caseE (varE opName) (alternatives ++ [catchAllCase])
  where
    alternatives = [ match (altPattern c) (altBody c) [] | c <- opConstructors ]

    -- Argument patterns for the operator constructor and the variables
    -- they bind.
    (semArgPats, semVars) =
      case semantics of
        Bind (p, n) -> ([p], [n])
        NoBind      -> ([wildP], [])
        NoSemantics -> ([], [])

    altPattern c = conP nodeConstructor (conP c semArgPats : childMatchPatterns)

    returnedVars = map varE (semVars ++ childNames)

    altBody c = normalB $
      varE (mkName "return") `appE` tupE ([ conE c | bindOp ] ++ returnedVars)

-- | Wrap the case expression into a lambda over the operator variable:
--
-- > \op__internal -> case op__internal of ...
instMatchLambda :: Q Exp -> Q Exp
instMatchLambda = lam1E (varP opName)

-- | Apply 'matchOp' to the node variable and the matching lambda:
--
-- > matchOp node (\op__internal -> ...)
instMatchExp :: Name -> Q Exp -> Q Exp
instMatchExp nodeName matchLambda =
  varE matchOp `appE` varE nodeName `appE` matchLambda

-- | Build the tuple pattern on the left-hand side of a matching statement:
--
-- > (a, b, c) <- ...
--
-- The components are, in this order: the operator binder (if present), the
-- pattern for the semantic information (only if it is actually bound) and
-- the patterns binding the children.
instBindingPattern :: Maybe (Q Pat) -> SemPattern -> [Q Pat] -> Q Pat
instBindingPattern mOpConstPat semPat childPats =
  tupP (maybeToList mOpConstPat ++ semanticPats ++ childPats)
  where
    semanticPats =
      case semPat of
        Bind (pat, _) -> [pat]
        NoBind        -> []
        NoSemantics   -> []

-- | Build the do-statement that runs a match expression and binds its
-- results:
--
-- > (a, b, c) <- matchOp q (\op -> case op of ...)
--
-- A plain, non-binding statement is emitted only if the match returns
-- nothing of interest, i.e. if there is neither an operator binder, nor a
-- named semantic binder, nor any child binder. In every other case the
-- returned tuple is bound with the pattern built by 'instBindingPattern'.
instStatement :: Maybe (Q Pat) -> SemPattern -> [Q Pat] -> Q Exp -> Q Stmt
instStatement mOpConstPat semPat childPats matchExp
  | bindsNothing = noBindS matchExp
  | otherwise    = bindS (instBindingPattern mOpConstPat semPat childPats) matchExp
  where
    -- The operator binder has to be considered here as well: a pattern with
    -- a named operator but without children and without named semantics
    -- would otherwise silently lose its binder, although the generated case
    -- expression returns the operator constructor for it.
    bindsNothing =
      isNothing mOpConstPat && null childPats && not (bindsSemantics semPat)

    bindsSemantics (Bind _)    = True
    bindsSemantics NoBind      = False
    bindsSemantics NoSemantics = False

-- | Translate the semantic part of a pattern into a 'SemPattern': a named
-- binder becomes a variable pattern together with its name, a wildcard is
-- not bound, and a missing semantic part means there is nothing to match.
semPatternName :: Maybe Sem -> SemPattern
semPatternName sem =
  case sem of
    Just WildS          -> NoBind
    Just (NamedS ident) -> let binder = mkName ident
                           in Bind (varP binder, binder)
    Nothing             -> NoSemantics

-- | Assemble the complete matching statement for one pattern node: the case
-- expression is wrapped into a lambda, the lambda is passed to 'matchOp'
-- together with the node, and the result is bound (or just executed).
instStmtWrapper :: Name              -- ^ The name of the node on which to match
                   -> Name           -- ^ The name of the node constructor (TerOp, BinOp, UnOp, NullaryOp)
                   -> [Name]         -- ^ The names of the operator constructors
                   -> Maybe (Q Pat)  -- ^ The binding pattern for the operator constructor
                   -> SemPattern     -- ^ Pattern binding the semantical information (or wildcard)
                   -> [Q Pat]        -- ^ The list of patterns matching the node children
                   -> [Q Pat]        -- ^ The list of patterns binding the node children (may be empty)
                   -> [Name]         -- ^ The list of variables for the children (may be empty)
                   -> Q Stmt         -- ^ Returns the matching statement
instStmtWrapper nodeName nodeKind operNames mOpConstPat semantics childMatchPats childPats childNames =
  instStatement mOpConstPat semantics childPats matchExp
  where
    -- The operator constructor is only returned if there is a binder for it.
    returnOperator = isJust mOpConstPat

    matchExp = instMatchExp nodeName . instMatchLambda $
      instMatchCase nodeKind operNames semantics childMatchPats childNames returnOperator

-- | Extract the optional binding pattern for the operator and the names of
-- the admissible operator constructors from the operator part of a pattern.
opInfo :: Op -> (Maybe (Q Pat), [Name])
opInfo op =
  case op of
    NamedOp bindingName opNames -> (Just (varP (mkName bindingName)), toNames opNames)
    UnnamedOp opNames           -> (Nothing, toNames opNames)
  where
    toNames = map mkName

-- | Generate the list of node matching statements for a pattern (tree),
-- starting at the node bound to the given name.
gen :: Name -> Node -> Code ()
gen nodeName pat =
  case pat of
    NullP op semBinding                     -> genOperatorNode nodeName nullOpName op semBinding []
    UnP op semBinding child                 -> genOperatorNode nodeName unOpName op semBinding [child]
    BinP op semBinding child1 child2        -> genOperatorNode nodeName binOpName op semBinding [child1, child2]
    TerP op semBinding child1 child2 child3 -> genOperatorNode nodeName terOpName op semBinding [child1, child2, child3]

    HoleP holeStart subHolePat -> do
      -- All binders of the sub-hole pattern, in a canonical order.
      let binders = map mkName (collectBinders subHolePat)

      -- A local function that tries to match the sub-hole pattern at a
      -- given node.
      (matcherName, matcherDef) <- lift (genSubHoleMatch binders subHolePat)
      emit matcherDef

      -- (holeStart, [binders]) <- searchHolePat matcher nodeName
      -- Search below the current node for a node at which the sub-hole
      -- pattern matches.
      let search = varE 'searchHolePat `appE` varE matcherName `appE` varE nodeName
          result = tupP [varP (mkName holeStart), listP (map varP binders)]
      emit (bindS result search)

    HoleEq eqNode ->
      -- searchHoleEq nodeName eqNode
      emit . noBindS $
        varE 'searchHoleEq `appE` varE nodeName `appE` varE (mkName eqNode)

-- | Generate the matching statement for an operator node with the given
-- node constructor and children, then descend into those children that
-- carry a nested pattern.
genOperatorNode :: Name -> Name -> Op -> Maybe Sem -> [Child] -> Code ()
genOperatorNode nodeName nodeKind op semBinding children = do
  -- Matching pattern and (optional) binding name for every child, in order.
  childInfos <- mapM (lift . childMatchPattern) children

  let (mOpConstPat, opNames)                   = opInfo op
      (matchPatterns, bindNames, bindPatterns) = splitMatchAndBind childInfos

  emit $ instStmtWrapper nodeName nodeKind opNames mOpConstPat
                         (semPatternName semBinding)
                         matchPatterns bindPatterns bindNames

  mapM_ (uncurry maybeDescend) (zip children (map snd childInfos))

-- Search a DAG (depth-first, preorder) for a node at which the given pattern
-- applies. The pattern is tried at the start node first; if it does not
-- match there, the search continues with the node's children.
-- Returns the matching node and the list of values for the pattern's binders.
searchHolePat :: Operator o
                 => (AlgNode -> M.Match o p e [AlgNode])
                 -> AlgNode
                 -> M.Match o p e (AlgNode, [AlgNode])
searchHolePat patMatch q = do
  (d, p, e) <- M.exposeEnv
  maybe descend (\nodes -> return (q, nodes)) (M.runMatch e d p (patMatch q))
  where
    descend = M.getOperator q >>= searchChildren patMatch . opChildren

-- Run searchHolePat on each node of a list in turn and return the result of
-- the first successful search. Fails if no node yields a match.
searchChildren :: Operator o
                  => (AlgNode -> M.Match o p e [AlgNode])
                  -> [AlgNode]
                  -> M.Match o p e (AlgNode, [AlgNode])
searchChildren patMatch candidates =
  case candidates of
    []       -> fail "no match"
    q : rest -> do
      (d, p, e) <- M.exposeEnv
      maybe (searchChildren patMatch rest) return
            (M.runMatch e d p (searchHolePat patMatch q))

-- Search for an occurrence of the node 'eqNode', starting at 'startNode'.
-- Succeeds if both nodes are equal or if 'eqNode' occurs somewhere below
-- 'startNode'; fails otherwise.
searchHoleEq :: Operator o => AlgNode -> AlgNode -> M.Match o p e ()
searchHoleEq startNode eqNode
  | startNode == eqNode = return ()
  | otherwise           = do
      (d, _, _) <- M.exposeEnv
      children  <- fmap opChildren (M.getOperator startNode)
      if nodeOccurs d eqNode children
        then return ()
        else fail "no occurence"

-- Check whether 'eqNode' is one of the start nodes or occurs below one of
-- them. Only occurrences of a particular node are searched and no pattern
-- matching takes place, so this is a pure function on the DAG and does not
-- need the Match monad.
nodeOccurs :: Operator o => AlgebraDag o -> AlgNode -> [AlgNode] -> Bool
nodeOccurs dag eqNode startNodes =
  eqNode `elem` startNodes || any occursBelow startNodes
  where
    occursBelow n = nodeOccurs dag eqNode (opChildren (operator n dag))

-- | Generate a local function which matches a pattern on a given node.
-- The generated function returns the values of all binders of the pattern
-- as a list, in the canonical order given by 'binderNames'.
-- Type of the generated function:
--      subhole_xy :: AlgNode -> Match o [AlgNode]
--
-- Returns the name of the generated function together with the
-- let-statement that defines it.
genSubHoleMatch :: [Name] -> Pattern -> Q (Name, Q Stmt)
genSubHoleMatch binderNames pat = do
  -- The statements that match the pattern, starting at the function argument.
  rootName   <- newName "subNode"
  matchStmts <- execWriterT (gen rootName pat)

  funName <- newName "subhole"

  let -- return [binder1, binder2, ...]
      returnBinders = noBindS (varE 'return `appE` listE (map varE binderNames))
      funBody       = doE (matchStmts ++ [returnBinders])
      funDef        = funD funName [clause [varP rootName] (normalB funBody) []]

  return (funName, letS [funDef])

{-
semBinder :: Maybe Sem -> [Ident]
semBinder (Just (NamedS i)) = [i]
semBinder (Just WildS)      = []
semBinder Nothing           = []
-}

-- | The binder introduced by the operator part of a pattern (if it is named).
opBinder :: Op -> [Ident]
opBinder op =
  case op of
    NamedOp ident _ -> [ident]
    UnnamedOp _     -> []

-- | The binders introduced by a child pattern: its own name (if any),
-- followed by the binders of its nested pattern (if any).
childBinders :: Child -> [Ident]
childBinders child =
  case child of
    WildC               -> []
    NameC ident         -> [ident]
    NodeC node          -> collectBinders node
    NamedNodeC ident node -> ident : collectBinders node

-- Collect the binders of a pattern tree in pre-order: first the operator
-- binder of a node, then the binders of its children from left to right.
-- Binders for semantic information are not collected.
-- TODO: so far, only binders for nodes (type AlgNode) are collected. This is
-- necessary so that we can return values for them in a list without type-specific
-- wrappers
collectBinders :: Node -> [Ident]
collectBinders pat =
  case pat of
    TerP op _ c1 c2 c3 -> operatorBinders op [c1, c2, c3]
    BinP op _ c1 c2    -> operatorBinders op [c1, c2]
    UnP op _ c         -> operatorBinders op [c]
    NullP op _         -> operatorBinders op []
    HoleP _ _          -> error "collectBinders: Holes in sub-hole patterns not supported"
    HoleEq _           -> []
  where
    operatorBinders op children = opBinder op ++ concatMap childBinders children

{-
Separate the per-child information into the list of matching patterns (one
per child), the names of those children that are actually bound, and
variable patterns for exactly these names.
-}
splitMatchAndBind :: [(Q Pat, Maybe Name)] -> ([Q Pat], [Name], [Q Pat])
splitMatchAndBind ps = (map fst ps, boundNames, map varP boundNames)
  where
    boundNames = [ n | (_, Just n) <- ps ]

{-
For every child, generate the matching pattern and - if the child
is to be bound, either under a user-supplied name or because a nested
pattern has to be matched on it - the name to which it is bound.

A child that is not bound still needs a (wildcard) matching pattern, because
the node constructor in the generated match must receive a pattern for
every one of its arguments.

Children with a nested pattern but without a user-supplied name are bound
to a fresh name.
-}
childMatchPattern :: Child -> Q (Q Pat, Maybe Name)
childMatchPattern child =
  case child of
    WildC          -> return (wildP, Nothing)
    NameC s        -> return (boundTo (mkName s))
    NodeC _        -> fmap boundTo (newName "child")
    NamedNodeC s _ -> return (boundTo (mkName s))
  where
    boundTo n = (varP n, Just n)

-- | The nested pattern of a child, if it has one.
recurse :: Child -> Maybe Node
recurse child =
  case child of
    NodeC o        -> Just o
    NamedNodeC _ o -> Just o
    WildC          -> Nothing
    NameC _        -> Nothing

-- | Generate the matching statements for the nested pattern of a child (if
-- it has one), starting at the name the child is bound to. A child with a
-- nested pattern but without a name is an internal error.
maybeDescend :: Child -> Maybe Name -> Code ()
maybeDescend c ns =
  case (recurse c, ns) of
    (Nothing, _)      -> return ()
    (Just o, Just n)  -> gen n o
    (Just _, Nothing) -> error "PatternConstruction.gen: no name for child pattern"

-- | Combine the generated pattern-matching statements with the statements of
-- the user-supplied match body into a single do-block.
--
-- The user expression has to be a do-block whose last statement has the
-- form @return $ rewriteExpr@. That statement is patched so that a call to
-- 'R.collect' is appended to the returned rewrite action:
--
-- > return $ do { rewriteExpr; collect }
--
-- Any other shape of the user expression is reported via 'error' (i.e. as a
-- failure while the splice is being evaluated).
assembleStatements :: Q [Stmt] -> Q Exp -> Q Exp
assembleStatements patternStatements userExpr = do
  ps <- patternStatements
  e <- userExpr

  let us =
        case e of
          DoE userStatements -> userStatements
          _ -> error "PatternConstruction.assembleStatements: no do-block supplied"

      -- The call to collect
      collectStmt = NoBindS $ VarE 'R.collect

      -- Extract the returned sequence of rewrite actions and patch the
      -- call to collect at the end
      patchReturn stmt =
        case stmt of
          NoBindS (InfixE (Just (VarE returnName)) (VarE dollarName) (Just rewriteExpr))
            | dollarName == '($) && returnName == 'return ->
              let rewriteExpr' = DoE [NoBindS rewriteExpr, collectStmt]
              in NoBindS (InfixE (Just (VarE returnName)) (VarE dollarName) (Just rewriteExpr'))
          s -> error $ "PatternConstruction.assembleStatements: the last statement "
                       ++ "of the match body must have the form 'return $ ...', but is: "
                       ++ show s

      -- Reassemble the user statements. The list is split from its end
      -- without the partial functions 'last' and 'init', so that an empty
      -- statement list yields a descriptive error.
      us' =
        case reverse us of
          []                 -> error "PatternConstruction.assembleStatements: empty do-block supplied"
          lastStmt : revInit -> reverse revInit ++ [patchReturn lastStmt]

  -- Return a do block consisting of the pattern statements and the user statements.
  return $ DoE $ ps ++ us'

-- | Generate a complete match from a pattern description.
--
-- Takes the quoted name of the variable holding the root node on which the
-- pattern is applied, a string description of the pattern and the body of
-- the match. The body has to be a quoted ([| ...|]) do-block. The result is
-- a do-block consisting of the generated pattern-matching statements
-- followed by the statements of the body.
dagPatMatch :: Name -> String -> Q Exp -> Q Exp
dagPatMatch rootName patternString userExpr = do
  -- The (not yet executed) statements that match the parsed pattern.
  matchStmts <- execWriterT (gen rootName (parsePattern patternString))

  -- Prepend them to the user-supplied statements.
  assembleStatements (sequence matchStmts) userExpr

-- | Reference a variable that is bound by a pattern in a quoted match body.
-- The variable is looked up by its plain name (see 'dyn').
v :: String -> Q Exp
v name = dyn name