{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE TupleSections     #-}

module Text.Liquid.VariableFinder where

import           Control.Arrow
import           Control.Monad.State.Lazy
import qualified Data.List                as L
import qualified Data.List.NonEmpty       as NEL
import           Text.Liquid.Types

-- | Allowed types
--
--   The type tag that 'findVariables' carries as its state while walking an
--   expression, and that 'findAllVariables' pairs with each variable found.
data VType
  = VString
  -- ^ Set by 'findVariables' when it enters a 'Filter' expression
  | VNumber
  -- ^ Set by 'findVariables' when it enters an ordering comparison
  --   ('GtEqual', 'LtEqual', 'Gt' or 'Lt')
  | VStringOrNumber
  -- ^ Initial state supplied by 'findAllVariables' for every top-level expression
  | VBool
  -- ^ Set by 'findVariables' when it enters a 'Truthy' expression
  | VAny
  -- ^ Never assigned by 'findVariables' or 'findAllVariables'
  deriving (Eq, Show)

-- | Collect every variable referenced by a parsed template, paired with the
--   type inferred for it by 'findVariables'.
--
--   Each top-level expression is walked on its own, starting from the
--   'VStringOrNumber' state. Repeated (path, type) pairs are removed, keeping
--   the first occurrence.
findAllVariables :: [Expr]
                 -> [(JsonVarPath, VType)]
findAllVariables exps =
  L.nub (concatMap typedVars exps)
  where
    -- Variables of one expression together with the state each one ended in
    typedVars expr = runStateT (findVariables mzero expr) VStringOrNumber

-- | Walk an expression and yield every variable path it references.
--
--   The 'VType' state attached to each result is the last one set on the way
--   down to that variable: 'VNumber' under ordering comparisons, 'VString'
--   under a 'Filter', 'VBool' under 'Truthy', otherwise whatever state the
--   walk was started with.
--
--   The first argument is a set of already-collected results; it is appended
--   after anything found in the expression, or handed on unchanged to the
--   single child of wrapper nodes.
findVariables :: StateT VType [] JsonVarPath
              -> Expr
              -> StateT VType [] JsonVarPath
findVariables vs expr = case expr of
  Variable v        -> return v `mplus` vs
  -- Comparisons and connectives that do not constrain the type
  Equal l r         -> untyped l r
  NotEqual l r      -> untyped l r
  Or l r            -> untyped l r
  And l r           -> untyped l r
  Contains l r      -> untyped l r
  IfLogic l r       -> untyped l r
  -- Ordering comparisons: everything below is treated as a number
  GtEqual l r       -> numeric l r
  LtEqual l r       -> numeric l r
  Gt l r            -> numeric l r
  Lt l r            -> numeric l r
  -- Single-child wrappers pass the accumulator straight through
  Truthy t          -> put VBool >> findVariables vs t
  IfClause i        -> findVariables vs i
  IfKeyClause i     -> findVariables vs i
  ElsIfClause i     -> findVariables vs i
  Output o          -> findVariables vs o
  -- Nodes with a list of children
  Filter f xs       -> put VString >> (scan f `mplus` scanAll xs `mplus` vs)
  TrueStatements xs -> scanAll xs `mplus` vs
  CaseLogic c xts   ->
    let (ls, rs) = unzip xts
    in scan c `mplus` scanAll ls `mplus` scanAll rs `mplus` vs
  -- Any other node contributes no variables
  _                 -> vs
  where
    -- Walk a child with an empty accumulator
    scan = findVariables mzero
    -- Walk a list of children and combine their results in order
    scanAll = msum . map scan
    -- Both operands, then the accumulator, under the current state
    untyped l r = scan l `mplus` scan r `mplus` vs
    -- Same as 'untyped', but with the state switched to 'VNumber' first
    -- (the accumulator is evaluated under 'VNumber' as well)
    numeric l r = put VNumber >> untyped l r

-- | Rewrite a parsed template into `aggregate` style by applying
--   'aggregateElem' to each top-level expression in turn.
--   Designed to simulate required templates for aggregate contexts - see JsonTools
makeAggregate :: Expr
              -- ^ Aggregate sample filter to add
              -> VarIndex
              -- ^ Prefix to filter on - do not aggregate this prefix
              -> [Expr]
              -- ^ Parsed template to make `aggregate` style
              -> [Expr]
makeAggregate af pf xs =
  [ aggregateElem af pf x | x <- xs ]

-- | Rewrite one expression into `aggregate` style.
--
--   A variable whose path does not start with the excluded prefix receives the
--   aggregate sample filter at the point where it is rendered or filtered:
--
--   * a bare variable inside 'Output' is wrapped in a 'Filter' holding only
--     the sample filter;
--   * a variable that already heads a 'Filter' has its filter list passed
--     through 'updateFs'.
--
--   Compound nodes are rebuilt with their children rewritten; every other
--   node is returned as it is.
aggregateElem :: Expr     -- ^ Aggregate sample filter to add
              -> VarIndex -- ^ Prefix to filter on - do not aggregate this prefix
              -> Expr     -- ^ Expression under modification
              -> Expr
aggregateElem af pf expr = case expr of
  -- Leaves are left untouched
  Noop                     -> expr
  RawText _                -> expr
  Num _                    -> expr
  Variable _               -> expr
  QuoteString _            -> expr
  Else                     -> expr
  -- Two-child nodes
  Equal l r                -> binary Equal l r
  NotEqual l r             -> binary NotEqual l r
  GtEqual l r              -> binary GtEqual l r
  LtEqual l r              -> binary LtEqual l r
  Gt l r                   -> binary Gt l r
  Lt l r                   -> binary Lt l r
  Or l r                   -> binary Or l r
  And l r                  -> binary And l r
  Contains l r             -> binary Contains l r
  IfLogic l r              -> binary IfLogic l r
  -- One-child nodes
  Truthy x                 -> Truthy (go x)
  IfClause x               -> IfClause (go x)
  IfKeyClause x            -> IfKeyClause (go x)
  ElsIfClause x            -> ElsIfClause (go x)
  -- Filters: only a variable outside the excluded prefix gets the sample filter
  FilterCell x ys          -> FilterCell x (map go ys)
  Filter v@(Variable path) fs
    | NEL.head path == pf  -> Filter v (map go fs)
    | otherwise            -> Filter v (map go (updateFs af fs))
  -- Any other filter subject (quoted strings included) keeps its filter list
  Filter x fs              -> Filter x (map go fs)
  -- Outputs: a bare variable outside the excluded prefix is wrapped in a filter
  Output q@(QuoteString _) -> Output q
  Output v@(Variable path)
    | NEL.head path == pf  -> Output v
    | otherwise            -> Output (Filter v [af])
  Output f                 -> Output (go f)
  -- Containers
  TrueStatements xs        -> TrueStatements (map go xs)
  CaseLogic x ys           -> CaseLogic (go x) (map (go *** go) ys)
  -- Anything not listed above passes through unchanged
  _                        -> expr
  where
    -- Recurse with the same sample filter and excluded prefix
    go = aggregateElem af pf
    -- Rebuild a two-child node from its rewritten children
    binary con l r = con (go l) (go r)

-- | Put the aggregate sample filter at the front of a list of filter cells.
--
--   The list is returned unchanged when it is empty, or when its first element
--   is a 'FilterCell' whose name is one of a fixed set of filter names.
updateFs :: Expr   -- ^ Aggregate sample filter to prepend
         -> [Expr] -- ^ List of filter cells
         -> [Expr]
updateFs af cells = case cells of
  []                      -> []
  FilterCell name _ : _
    | name `elem` leading -> cells
  _                       -> af : cells
  where
    -- Filter names that suppress the prepend when they lead the list
    leading =
      [ "first"
      , "firstOrDefault"
      , "last"
      , "lastOrDefault"
      , "countElements"
      , "renderWithSeparator"
      , "toSentenceWithSeparator"
      ]