{-# LANGUAGE FlexibleContexts #-}
-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-----------------------------------------------------------------------------

module Util.Expr
   ( substitute, eqComAssoc, isLiteral, isSubExprOf
   , wildcard, wildcardsToVars, restoreWildcards
   ) where

import Control.Applicative
import Control.Monad.State
import Data.Bool
import Domain.Math.Expr
import Ideas.Common.Rewriting.Term
import Ideas.Common.View
import Ideas.Utils.Uniplate
import Util.Monad
import qualified Data.Map as M

-- | Replace every variable that has a binding in the map by its (recursively
-- substituted) definition; variables without a binding are left unchanged.
--
-- A binding is not expanded when its definition mentions the variable itself
-- (e.g. @x := x + 1@), nor when the variable is already being expanded higher
-- up through a chain of bindings (e.g. @x := y@, @y := x@). In both cases the
-- variable is kept as is, so substitution terminates even for cyclic maps.
substitute :: M.Map String Expr -> Expr -> Expr
substitute m = go []
 where
   -- @active@: variables whose definitions are being expanded on the current path
   go active (Var s) =
      case M.lookup s m of
         Just e | s `notElem` active, s `notElem` vars e -> go (s : active) e
         _ -> Var s
   go active e = descend (go active) e

-- | The symbol that tags wildcard nodes (a wildcard is this symbol applied to
-- a single variable holding the wildcard's name).
wildcardSymbol :: Symbol
wildcardSymbol = newSymbol wildcardName
  where
   wildcardName = "wildcard"

-- | Does the given symbol mark a wildcard?
isWildcardSymbol :: Symbol -> Bool
isWildcardSymbol sym = sym == wildcardSymbol

-- | @wildcard name@ builds a wildcard: a pattern that matches any expression.
-- Several wildcards may occur in one expression; they are told apart by their
-- name, which is stored as a variable argument of the wildcard symbol.
wildcard :: String -> Expr
wildcard name = Sym wildcardSymbol (pure (Var name))

-- | Prepend the given character to every variable name, without descending
-- into wildcard subterms (their name variables stay untouched).
prefixVariables :: Char -> Expr -> Expr
prefixVariables c expr =
   case expr of
      Var v -> Var (c : v)
      Sym sym _ | isWildcardSymbol sym -> expr
      _ -> descend (prefixVariables c) expr

-- | Replace every wildcard of the form @Sym wildcardSymbol [Var v]@ by the
-- plain variable @'w' : v@.
instantiateWildcards :: Expr -> Expr
instantiateWildcards expr =
   case expr of
      Sym sym [Var v] | isWildcardSymbol sym -> Var ('w' : v)
      _ -> descend instantiateWildcards expr

-- | Encode wildcards as ordinary variables. Existing variables get the prefix
-- @'v'@ and wildcard names become variables with the prefix @'w'@, so the two
-- cannot clash. 'restoreWildcards' reverses the encoding.
wildcardsToVars :: Expr -> Expr
wildcardsToVars expr = instantiateWildcards (prefixVariables 'v' expr)

-- | Undo 'wildcardsToVars'. A variable starting with @'v'@ loses that prefix;
-- one starting with @'w'@ becomes a wildcard with the rest as its name. Other
-- variables are left alone.
restoreWildcards :: Expr -> Expr
restoreWildcards expr =
   case expr of
      Var ('v' : name) -> Var name
      Var ('w' : name) -> wildcard name
      _                -> descend restoreWildcards expr

-- | A literal is a natural number or a number, possibly under any number of
-- negations.
isLiteral :: Expr -> Bool
isLiteral expr =
   case expr of
      Nat _    -> True
      Number _ -> True
      Negate x -> isLiteral x
      _        -> False

-- | Match two expressions with 'eq''. Wildcards may get bound in the first
-- map of the state. On a failed match the state is reset to what it was
-- before the call, so failed attempts leave no bindings behind.
eq :: MonadState (M.Map String Expr, M.Map String (Expr -> Bool)) m => Expr -> Expr -> m Bool
eq x y = do
   before <- get
   matched <- eq' x y
   if matched
      then return True
      else put before >> return False

-- | Core matcher behind 'eq'. Compares two expressions structurally, where a
-- wildcard matches any expression:
--
-- * two wildcards never match each other;
-- * a wildcard that is already bound (first map of the state) must match the
--   expression it is bound to;
-- * an unbound wildcard whose name has a predicate (second map) fails when
--   the predicate rejects the other expression; otherwise it is bound to it;
-- * the two arguments of binary plus and times may match in either order;
-- * other applications match when symbol, arity and all arguments match.
--
-- Unlike 'eq', a failed match may leave bindings in the state. The
-- commutative case therefore resets the state before trying the swapped
-- pairing, so bindings from the failed pairing cannot influence it.
eq' :: MonadState (M.Map String Expr, M.Map String (Expr -> Bool)) m => Expr -> Expr -> m Bool
eq' (Sym s1 _) (Sym s2 _) | isWildcardSymbol s1 && isWildcardSymbol s2 = return False
eq' a b@(Sym s _) | isWildcardSymbol s = eq' b a   -- keep the wildcard on the left
eq' (Sym s [Var v]) b | isWildcardSymbol s = do
   (me, pe) <- get
   case (M.lookup v me, M.lookup v pe) of
      (Just a, _) -> eq' a b
      (_, Just p) | not (p b) -> return False
      _ -> put (M.insert v b me, pe) >> return True
eq' (Nat a) (Nat b) = return (a == b)
eq' (Number a) (Number b) = return (a == b)
eq' (Var a) (Var b) = return (a == b)
eq' a b | Just (sa, [a1, a2]) <- getFunction a
        , Just (sb, [b1, b2]) <- getFunction b
        , sa == sb
        , sa == plusSymbol || sa == timesSymbol -- commutativity
        = do saved <- get
             straight <- eqPairs [(a1, b1), (a2, b2)]
             if straight
                then return True
                else put saved >> eqPairs [(a1, b2), (a2, b1)]
eq' a b | Just (sa, xs) <- getFunction a
        , Just (sb, ys) <- getFunction b
        , sa == sb
        , length xs == length ys   -- zip would silently drop surplus arguments
        = eqPairs (zip xs ys)
eq' _ _ = return False

-- | Match each pair with 'eq'', left to right, stopping at the first mismatch.
-- Like 'eq'', it does not undo bindings when it returns 'False'.
eqPairs :: MonadState (M.Map String Expr, M.Map String (Expr -> Bool)) m => [(Expr, Expr)] -> m Bool
eqPairs [] = return True
eqPairs ((x, y) : rest) = do
   ok <- eq' x y
   if ok then eqPairs rest else return False

-- | @isSubExprOf norm se e@ is 'True' when @extractSum norm e se@ succeeds,
-- i.e. every sum operand of @e@ can be extracted from the operands of @se@
-- (some operands of @se@ may be left over). On success the state holds the
-- wildcard bindings made during extraction; on failure it is unchanged.
--
-- NOTE(review): going by the name one might expect the roles of @se@ and @e@
-- to be the other way round — confirm against the callers.
isSubExprOf :: (Expr -> Expr) -> Expr -> Expr -> State (M.Map String Expr, M.Map String (Expr -> Bool)) Bool
isSubExprOf norm se e = state $ \st ->
   case runStateT (extractSum norm e se) st of
      Nothing        -> (False, st)
      Just (_, st')  -> (True, st')

-- | @eqComAssoc norm e se@ is 'True' when @extractSum norm e se@ succeeds and
-- leaves no operands of @se@ over. In other words, the sum operands of @e@ and
-- @se@ correspond one-to-one, regardless of order. If extraction fails the
-- state is unchanged; otherwise the state after extraction is kept, even when
-- the result is 'False'.
eqComAssoc :: (Expr -> Expr) -> Expr -> Expr -> State (M.Map String Expr, M.Map String (Expr -> Bool)) Bool
eqComAssoc norm e se = state $ \st ->
   case runStateT (extractSum norm e se) st of
      Nothing           -> (False, st)
      Just (rest, st')  -> (null rest, st')

-- | @extractSum norm e s@ removes the sum operands of @e@, one at a time and
-- in order, from the list of sum operands of @s@ (using 'extractSumMember').
-- It returns the operands of @s@ that are left over, and fails ('empty') as
-- soon as an operand of @e@ cannot be removed. Wildcard bindings made along
-- the way are kept in the state.
--
-- The first alternative compares @e@ and @s@ as a whole with 'eqOrNorm'
-- through 'mIf'; it is meant to succeed with @[]@ (nothing left over) when
-- they match. Because 'StateT' over 'Maybe' backtracks, the operand-wise
-- extraction ('<|>') starts again from the original state.
extractSum :: (Expr -> Expr) -> Expr -> Expr -> StateT (M.Map String Expr, M.Map String (Expr -> Bool)) Maybe [Expr]
extractSum norm e s =
       mIf (eqOrNorm norm e s) (return []) empty
   <|> foldM (extractSumMember norm) (from sumView s) (from sumView e)
-- | @extractSumMember norm eSumMembers sSumMember@ tries to remove the single
-- sum operand @sSumMember@ from the list of operands @eSumMembers@. It returns
-- the operands that remain, and fails ('empty') when nothing was removed.
--
-- When @sSumMember@ is a wildcard, the code is
-- @mIf (eq sSumMember wholeSum) (return []) noWildcardCase@. Matching the
-- wildcard against the entire remaining sum is meant to yield @[]@, with the
-- ordinary case as fallback. A wildcard matching only part of the sum is not
-- considered.
--
-- In the ordinary case, 'deleteByM' removes an operand from which all
-- factors of @sSumMember@ can be extracted ('extractProduct' leaves nothing
-- over).
extractSumMember :: (Expr -> Expr) -> [Expr] -> Expr -> StateT (M.Map String Expr, M.Map String (Expr -> Bool)) Maybe [Expr]
extractSumMember norm eSumMembers sSumMember =
   case sSumMember of
      Sym s [Var _] | isWildcardSymbol s ->
         mIf (eq sSumMember (to sumView eSumMembers)) (return []) noWildcardCase
      _ -> noWildcardCase
 where
   noWildcardCase = do
      xs <- deleteByM (\x y -> null <$> extractProduct norm x y) sSumMember eSumMembers
      -- deleteByM may remove nothing; that counts as failure
      guard (length xs < length eSumMembers)
      return xs


-- | @extractProduct norm e s@ removes the factors of the product @e@, one at a
-- time, from the factors of the product @s@ (using 'extractProductMember').
-- It returns the factors of @s@ that are left over.
--
-- The factor-wise alternative requires the 'Bool' components that
-- 'productView' yields for @e@ and @s@ to be equal (presumably the sign of
-- the product — confirm).
--
-- The second alternative compares @e@ and @s@ as a whole with 'eqOrNorm'
-- through 'mIf'; it is meant to succeed with @[]@ when they match. The first
-- alternative is parenthesized on purpose. Since '>>' is infixl 1 and '<|>'
-- is infixl 3, the unparenthesized form also made this whole-expression
-- comparison depend on the 'Bool' check.
extractProduct :: (Expr -> Expr) -> Expr -> Expr -> StateT (M.Map String Expr, M.Map String (Expr -> Bool)) Maybe [Expr]
extractProduct norm e s =
       (guard (eEven == sEven) >> foldM (extractProductMember norm True) sProductMembers eProductMembers)
   <|> mIf (eqOrNorm norm e s) (return []) empty
  where
   (eEven, eProductMembers) = from productView e
   (sEven, sProductMembers) = from productView s

-- | @extractProductMember norm flag factors r@ removes a factor matching @r@
-- (compared with 'eqOrNorm') from @factors@ and returns the remaining
-- factors. It fails ('empty') when no factor was removed. The 'Bool' argument
-- is currently ignored.
extractProductMember :: (Expr -> Expr) -> Bool -> [Expr] -> Expr -> StateT (M.Map String Expr, M.Map String (Expr -> Bool)) Maybe [Expr]
extractProductMember norm _ factors r = do
   remaining <- deleteByM (eqOrNorm norm) r factors
   if length remaining < length factors
      then return remaining
      else empty


-- | Match two expressions with 'eq'. Only when that fails, retry with both
-- sides normalised by @norm@.
eqOrNorm :: MonadState (M.Map String Expr, M.Map String (Expr -> Bool)) m => (Expr -> Expr) -> Expr -> Expr -> m Bool
eqOrNorm norm x y = eq x y >>= bool (eq (norm x) (norm y)) (return True)

-- | Monadic if-then-else: run the condition, then run only the selected
-- branch.
--
-- This needs 'Monad'. An 'Applicative' version (the previous
-- @(flip . fmap flip) (liftA3 bool)@) ran the effects of both branches, so a
-- failing branch such as 'empty' made the whole expression fail. Its argument
-- rotation also swapped the two branches.
mIf :: Monad m => m Bool -> m a -> m a -> m a
mIf cond thenBranch elseBranch = cond >>= bool elseBranch thenBranch