-- GenI surface realiser
-- Copyright (C) 2009 Eric Kow
--
-- This program is free software; you can redistribute it and/or
-- modify it under the terms of the GNU General Public License
-- as published by the Free Software Foundation; either version 2
-- of the License, or (at your option) any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program; if not, write to the Free Software
-- Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

{-# LANGUAGE OverloadedStrings #-}
module NLP.GenI.Polarity(
                module NLP.GenI.Polarity.Types,

                -- * Entry point
                PolAut, PolState(PolSt), AutDebug, PolResult(..),
                buildAutomaton,

                -- * Polarity paths
                PolPathSet, detectPolPaths, hasSharedPolPaths,
                polPathsToList,
                singletonPolPath,
                emptyPolPaths, polPathsNull,
                intersectPolPaths, unionPolPaths,

                -- * Inner stuff (exported for debugging?)
                makePolAut,
                fixPronouns,
                detectSansIdx, suggestPolFeatures, detectPols,
                declareIdxConstraints, detectIdxConstraints,
                prettyPolPaths,

                -- re-exported from Automaton
                automatonPaths, finalSt,
                NFA(states, transitions),
                )
where

import           Data.IntSet                (IntSet)
import qualified Data.IntSet                as IntSet
import           Data.List
import qualified Data.Map                   as Map
import qualified Data.Set                   as Set
import           Data.Text                  (Text)
import qualified Data.Text                  as T

import           Control.Error

import           NLP.GenI.Automaton
import           NLP.GenI.FeatureStructure  (AvPair (..), FeatStruct, Flist,
                                             unifyFeat)
import           NLP.GenI.General           (Interval,
                                             isEmptyIntersect, ival, thd3,
                                             (!+!))
import           NLP.GenI.GeniVal           (GeniVal (gConstraints), isAnon,
                                             mkGAnon, replace)
import           NLP.GenI.Polarity.Internal
import           NLP.GenI.Polarity.Types
import           NLP.GenI.Pretty
import           NLP.GenI.Semantics         (Literal (..), Sem, SemInput,
                                             emptyLiteral, sortSem)
import           NLP.GenI.Tag               (TagElem (..), TagItem (..),
                                             setTidnums)
import           NLP.GenI.TreeSchema        (GNode, Ptype (Initial), gdown, gup,
                                             root, rootUpd)

-- ----------------------------------------------------------------------
-- Interface
-- ----------------------------------------------------------------------

-- | The outcome of building a polarity automaton: intermediate auts,
--   seed aut, final aut, potentially modified sem.
--
--   * 'prIntermediate': one entry per construction step, starting with
--     the seed step
--
--   * 'prInitial': the seed automaton
--
--   * 'prFinal': the pruned automaton of the last construction step
--
--   * 'prSem': the input semantics as returned by the index counting
--     mechanism
data PolResult = PolResult { prIntermediate :: [AutDebug]
                           , prInitial      :: PolAut
                           , prFinal        :: PolAut
                           , prSem          :: Sem }
-- | One construction step: the polarity key that was processed, the
--   automaton built for it, and that same automaton after pruning.
type AutDebug  = (PolarityKey, PolAut, PolAut)

-- | Constructs a polarity automaton.  For debugging purposes, the result
--   also carries all the intermediate automata produced by the
--   construction algorithm (see 'PolResult').
--
--   The polarities are prepared in three stages before the automaton
--   itself is built by 'makePolAut': index constraints are charged to the
--   candidate trees, polarities are detected on the trees, and finally
--   unconstrained polarities are expanded with respect to the keys
--   detected on the trees.
buildAutomaton :: Set.Set PolarityAttr -- ^ polarities to detect (eg. "cat")
               -> FeatStruct GeniVal   -- ^ root features to compensate for
               -> PolMap               -- ^ explicit extra polarities
               -> SemInput             -- ^ input semantics
               -> [TagElem]            -- ^ lexical selection
               -> PolResult
buildAutomaton polarityAttrs rootFeat extrapol (tsem,tres,_) candRaw =
    makePolAut expandedCands tsem expandedExtras allKeys
  where
    -- index constraints on candidate trees: combine what a tree already
    -- declares with what its interface says about the input constraints
    withIdxCharges te = te
        { tpolarities = Map.unionWith (!+!) (tpolarities te) idxCharges }
      where
        idxCharges = detectIdxConstraints tres (tinterface te)
    -- polarity detection on the (index-charged) trees
    detected = map (detectPols polarityAttrs . withIdxCharges) candRaw
    -- external polarities, index constraints of the input, and
    -- root category compensation
    rawExtras = Map.unionsWith (!+!)
        [ extrapol
        , declareIdxConstraints tres
        , detectRootCompensation polarityAttrs rootFeat
        ]
    detectedKeys = polarityKeys detected Map.empty
    -- expanding unconstrained polarities
    expand         = convertUnconstrainedPolarities detectedKeys
    expandedCands  = [ te { tpolarities = expand (tpolarities te) }
                     | te <- detected ]
    expandedExtras = expand rawExtras
    allKeys        = polarityKeys expandedCands expandedExtras

-- | Builds the seed automaton and then one automaton per polarity key,
--   each one constructed on top of the pruned result of the previous step.
--
--   The keys are consumed from the last element of @ks@ to the first.
--   Keys without an entry in @extraPol@ start from @ival 0@.
makePolAut :: [TagElem] -> Sem -> PolMap -> [PolarityKey] -> PolResult
makePolAut candsRaw tsemRaw extraPol ks =
    PolResult { prIntermediate = reverse steps
              , prInitial      = seed
              , prFinal        = thd3 (head steps)
              , prSem          = tsem }
  where
    -- perform index counting
    (tsem, countedCands) = fixPronouns (tsemRaw, candsRaw)
    cands = setTidnums countedCands
    -- sorted semantics (for more efficient construction)
    sortedSem = sortSemByFreq tsem cands
    -- the seed automaton
    seed     = buildSeedAut (buildColumns cands sortedSem) sortedSem
    seedStep = (PolarityKeyStr "(seed)", seed, prune seed)
    -- building and remembering the automata; the head of 'steps' is the
    -- most recently built one and the seed step always comes last, so
    -- the list is never empty
    steps = scanr step seedStep ks
    step k (_, _, prevPruned) = (k, aut, prune aut)
      where
        aut = buildPolAut k (Map.findWithDefault (ival 0) k extraPol) prevPruned

-- ====================================================================
-- Polarity automaton
-- ====================================================================

-- | Groups the candidates into one column per literal of the input
--   semantics.
--
--   Note: this is not the same function as 'NLP.GenI.Tags.mapBySem'!
--   The literals are visited in the order given, and a candidate is filed
--   under the first literal that occurs in its semantics; it is not
--   considered again for later literals.  Respecting the order of the
--   input semantics is important for our handling of multi-literal
--   semantics and for semantic frequency sorting.
--
--   Once the literals run out, whatever candidates are left and have an
--   empty semantics are filed under 'emptyLiteral'.
buildColumns :: (TagItem t) => [t]
             -> Sem
             -> Map.Map (Literal GeniVal) [t]
buildColumns cands sem =
  case sem of
    [] ->
      Map.singleton emptyLiteral (filter (null . tgSemantics) cands)
    (lit:rest) ->
      let (hits, others) = partition ((lit `elem`) . tgSemantics) cands
      in  Map.insert lit hits (buildColumns others rest)

-- ----------------------------------------------------------------------
-- Initial Automaton
-- ----------------------------------------------------------------------

-- | Builds the seed automaton: the automaton that only tracks how far
--   along the input semantics we are, before any polarity key has been
--   taken into account.
--
--   A state is accepted as final when its position equals the number of
--   literals in the input semantics and every interval it carries
--   contains zero.
buildSeedAut :: SemMap -> Sem -> PolAut
buildSeedAut cands tsem = nubAut (buildSeedAut' cands tsem 1 emptyAut)
  where
    start    = polstart []
    emptyAut = NFA
      { startSt     = start
      , isFinalSt   = Just accepting
      , finalStList = []
      , states      = [[start]]
      , transitions = Map.empty }
    numLits = length tsem
    accepting (PolSt pos _ pols) = pos == numLits && all spansZero pols
    spansZero (lo, hi) = lo <= 0 && hi >= 0

-- for each literal: extend the automaton by one layer of states, using
-- the most recently created layer (the head of 'states') as the frontier
buildSeedAut' :: SemMap -> Sem -> Int -> PolAut -> PolAut
buildSeedAut' _ [] _ aut = aut
buildSeedAut' cands (l:ls) i aut =
    -- recursive step to the next literal
    buildSeedAut' cands ls (i+1) (grown { states = nub fresh : states aut })
  where
    -- previously created states
    frontier = head (states aut)
    -- candidates filed under the target literal
    matching = Map.findWithDefault [] l cands
    -- the next batch of states, and the automaton with transitions to them
    (grown, fresh) = foldr (buildSeedAutHelper matching l i) (aut, []) frontier

-- for each candidate corresponding to literal l: add a transition out of
-- the state @st@ and remember the state it leads to.
--
-- If the literal is already part of the extra semantics carried by @st@,
-- a single transition labelled 'Nothing' is added instead, and no
-- candidate is consumed.
buildSeedAutHelper :: [TagElem]
                   -> Literal GeniVal
                   -> Int
                   -> PolState
                   -> (PolAut,[PolState])
                   -> (PolAut,[PolState])
buildSeedAutHelper cs l i st (aut,prev)
  | l `elem` carried = step Nothing (aut, prev)
  | otherwise        = foldr (step . Just) (aut, prev) usable
  where
    -- the extra semantics from the last state
    PolSt _ carried _ = st
    -- candidates which do not overlap the extra baggage semantics
    usable = [ t | t <- cs, isEmptyIntersect carried (tsemantics t) ]
    -- add one transition out of the current state
    step label (a, acc) = (addTrans a st label target, target : acc)
      where
        brought = maybe [] tsemantics label
        target  = PolSt i (delete l (carried ++ brought)) []

-- ----------------------------------------------------------------------
-- Construction
-- ----------------------------------------------------------------------

-- | Builds the automaton for one polarity key on top of a skeleton
--   automaton.  The new start state is the skeleton's start state with
--   the initial interval for this key pushed onto its interval list; the
--   rest of the automaton is regrown from the skeleton's transitions.
buildPolAut :: PolarityKey -> Interval -> PolAut -> PolAut
buildPolAut k initK skelAut =
    nubAut (buildPolAut' k (transitions skelAut) seeded)
  where
    PolSt pos extra pols = startSt skelAut
    newStart = PolSt pos extra (initK : pols)
    seeded   = skelAut
      { startSt     = newStart
      , states      = [[newStart]]
      , transitions = Map.empty }

{-
Our helper function looks at a single state in the skeleton automaton
and at one of the states in the new automaton which corresponds to it.
We use the transitions from the old automaton to determine which states
to construct.  Note: there can be more than one state in the new
automaton which corresponds to a state in the old automaton.  This is
because we are looking at a different polarity key, so that whereas two
candidates may transition to the same state in the old automaton, their
polarity effects for the new key will make them diverge in the new
automaton.
-}

-- | Grows the new automaton one layer at a time until a layer produces
--   no new states.
buildPolAut' :: PolarityKey -> PolTransFn -> PolAut -> PolAut
-- for each literal... (this is implicit in the automaton state grouping)
buildPolAut' fk skeleton aut
  | Set.null fresh = aut
    -- recursive step to the next literal
  | otherwise      =
      buildPolAut' fk skeleton (grown { states = Set.toList fresh : states aut })
  where
    -- previously created states
    frontier = head (states aut)
    -- the next batch of states, and the automaton with transitions to them
    (grown, fresh) = foldr (buildPolAutHelper fk skeleton) (aut, Set.empty) frontier

-- given a previously created state: follow every transition that the
-- corresponding skeleton state has, and add the matching transition (and
-- target state) to the new automaton
buildPolAutHelper :: PolarityKey -> PolTransFn -> PolState -> (PolAut,Set.Set PolState) -> (PolAut,Set.Set PolState)
buildPolAutHelper fk skeleton st (aut,prev) =
    foldr perTarget (aut, prev) outgoing
  where
    -- reconstruct the skeleton state used to build st: it is st without
    -- the interval for the current key
    PolSt pr ex (po1:skelpo1) = st
    skelSt = PolSt pr ex skelpo1
    -- nb: a transition is (next state, [labels to that state])
    outgoing = Map.toList (Map.findWithDefault Map.empty skelSt skeleton)
    -- for each transition out of the skeleton state...
    perTarget (skelNext, labels) acc = foldr (perLabel skelNext) acc labels
    -- ...and for each label on it, calculate a new state and add a
    -- transition to it
    perLabel skelNext label (a, seen) =
        (addTrans a st label next, Set.insert next seen)
      where
        next = advance label skelNext
    -- the skeleton's target state, with the updated interval for the
    -- current key pushed on top
    advance :: Maybe TagElem -> PolState -> PolState
    advance label skelNext = PolSt pr2 ex2 (po2 : skelPo2)
      where
        PolSt pr2 ex2 skelPo2 = skelNext
        effects = maybe Map.empty tpolarities label
        po2     = po1 !+! Map.findWithDefault (ival 0) fk effects

-- ----------------------------------------------------------------------
-- Pruning
-- ----------------------------------------------------------------------

{-|
Removes dead ends from an automaton.

The pruning algorithm works through the layers of states, starting from
the most recently created one.  Within a layer, any state which does not
have outgoing transitions is placed on the blacklist.  We remove all
transitions to the blacklist and all states that only transition to the
blacklist, and then we repeat pruning with the next batch of states.

Note: in order for this to work, it is essential that the final states
are *not* included in the list of states to process.  They are taken out
of the last layer beforehand and put back once pruning is done.

An automaton without any layer of states is returned unchanged.
-}
prune :: PolAut -> PolAut
prune aut =
  case states aut of
    [] -> aut
    -- (remember that states is a list of lists)
    (lastStates : earlier) ->
      let final   = finalSt aut
          toCheck = (lastStates \\ final) : earlier
          -- the helper function will rebuild the state list
          pruned  = prune' toCheck (aut { states = [] })
          rebuilt = states pruned
          -- re-add the final state!
      in  pruned { states = (head rebuilt ++ final) : tail rebuilt }

-- | Worker for 'prune'.
--
--   The first argument holds the layers of states still to be examined,
--   most recently created layer first.  The states of the automaton
--   being passed along act as the accumulator for the layers already
--   examined (in reverse order); the transitions are trimmed as we go.
--
--   For each layer, the states without outgoing transitions form the
--   blacklist.  If the blacklist is empty, nothing further can become a
--   dead end, so the remaining layers are restored untouched and we stop.
--   Otherwise the blacklisted states and every transition into them are
--   removed, along with any source state that is left without
--   transitions, and we carry on with the next layer.
prune' :: [[PolState]] -> PolAut -> PolAut
prune' [] oldAut = oldAut { states = reverse $ states oldAut }
prune' (sts:next) oldAut
  | null blacklist = oldAut { states = reverse oldSt ++ (sts:next) }
    -- recursive step
  | otherwise      = prune' next newAut
  where
    oldT  = transitions oldAut
    oldSt = states oldAut
    -- calculate the blacklist: a single pass splits the layer into the
    -- states that still have outgoing transitions and those that do not
    (survivors, blacklist) = partition (`Map.member` oldT) sts
    -- membership in the blacklist is tested once per transition, so use
    -- a set rather than scanning the list each time
    dead     = Set.fromList blacklist
    alive st = st `Set.notMember` dead
    -- delete all transitions to the blacklist
    keepLive = Map.filterWithKey (\target _ -> alive target)
    -- extra cleanup: delete from the map the states that only
    -- transitioned to the blacklist
    newT = Map.filterWithKey (\src outs -> alive src && not (Map.null outs))
         $ Map.map keepLive oldT
    -- new list of states and new automaton
    newAut = oldAut { transitions = newT
                    , states      = survivors : oldSt }

-- ====================================================================
-- Zero-literal semantics
-- ====================================================================

-- | A lightweight stand-in for a literal, used as a map key by
--   'fixPronouns': the printed form of the predicate, paired with the
--   handle followed by the arguments.
type PredLite = (String,[GeniVal]) -- handle is head of arg list
-- | For each literal, the semantic polarities recorded against it.
--   'fixPronouns' lines these up position by position with the second
--   component of the 'PredLite' key.
type SemWeightMap = Map.Map PredLite SemPols

-- | Returns a modified input semantics and lexical selection in which pronouns
--   are properly accounted for.
--
--   The work is done in four parts, marked in the body:
--
--   1. for every literal (keyed by its printed predicate together with its
--      handle and arguments), take the position-wise minimum of the
--      semantic polarities that the candidates attach to it;
--
--   2. sum those minima per index to obtain a charge for each index;
--
--   3. for every index with a negative charge @c@, add @negate c@ extra
--      literals (see 'indexLiteral') to the input semantics, and replace
--      the candidates that have an empty semantics by one copy per extra
--      literal, each copy being given that literal as its semantics and
--      passed through 'assignIndex';
--
--   4. pad the semantics of any candidate whose own polarity for an index
--      is higher than the minimum found in part 1.
fixPronouns :: (Sem,[TagElem]) -> (Sem,[TagElem])
fixPronouns (tsem,cands) =
  let -- part 1 (for each literal get smallest charge for each idx)
      -- pair each literal of a tree with the matching entry of its
      -- semantic polarities
      getpols :: TagElem -> [ (PredLite,SemPols) ]
      getpols x = zip [ (prettyStr p, h:as) | Literal h p as <- tsemantics x ] (tsempols x)
      sempols :: [ (PredLite,SemPols) ]
      sempols = concatMap getpols cands
      -- where several candidates mention the same literal, keep the
      -- smaller polarity at each position
      usagemap :: SemWeightMap
      usagemap = Map.fromListWith (zipWith min) sempols
      -- part 2 (cancel sem polarities)
      chargemap :: Map.Map GeniVal Int -- index to charge
      chargemap =  Map.fromListWith (+) $ concatMap clump $ Map.toList usagemap
        where clump ((_,is),ps) = zip is ps
      -- part 3 (adding extra semantics)
      -- one occurrence of an index per unit of negative charge; indices
      -- whose charge is zero or positive contribute nothing
      indices = concatMap fn (Map.toList chargemap)
        where fn (i,c) = replicate (negate c) i
      -- the extra columns
      extraSem = map indexLiteral indices
      tsem2    = sortSem (tsem ++ extraSem)
      -- zero-literal semantic items to realise the extra columns
      zlit = filter (null.tsemantics) cands
      -- the original zero-literal items are taken out and replaced by
      -- one indexed copy per entry of 'indices'
      cands2 = (cands \\ zlit) ++ concatMap fn indices
        where fn i = map (tweak i) zlit
              tweak i x = assignIndex i $ x { tsemantics = [indexLiteral i] }
      -- part 4 (insert excess pronouns in tree sem)
      -- i: the index; ct: the polarity this tree has for it; cm: the
      -- minimum over all candidates (from part 1)
      comparefn :: GeniVal -> Int -> Int -> [GeniVal]
      comparefn i ct cm = if cm < ct then extra else []
        where maxNeeded = Map.findWithDefault 0 i chargemap -- cap the number added
              extra = replicate (min (negate maxNeeded) (ct - cm)) i
      comparePron :: (PredLite,SemPols) -> [GeniVal]
      comparePron (lit,c1) = concat $ zipWith3 comparefn idxs c1 c2
        where idxs = snd lit
              c2   = Map.findWithDefault [] lit usagemap
      addextra :: TagElem -> TagElem
      addextra c = c { tsemantics = sortSem (sem ++ extra) }
        where sem   = tsemantics c
              extra = map indexLiteral $ concatMap comparePron (getpols c)
      cands3 = map addextra cands2
  in (tsem2, cands3)

-- | Builds the fake literal that the index counting mechanism uses to
--   represent an extra column: the index is used as the handle, the
--   predicate is anonymous, and there are no arguments.
indexLiteral :: GeniVal -> Literal GeniVal
indexLiteral idx = Literal idx mkGAnon []

-- Returns True if the given literal has the shape of those introduced by
-- the index counting mechanism (see indexLiteral): no arguments and an
-- anonymous predicate
isExtraCol :: Literal GeniVal -> Bool
isExtraCol lit =
  case lit of
    Literal _ p [] -> isAnon p
    _              -> False

-- | 'assignIndex' is a useful way to restrict the behaviour of
-- null semantic items like pronouns using the information generated by
-- the index counting mechanism.  The problem with null semantic items
-- is that their indices are not set, which means that they could
-- potentially combine with any other tree.  To make things more
-- efficient, we can set the index of these items and thus reduce the
-- number of spurious combinations.
--
-- Notes
--
-- * These combinations could produce false results if the
--   input has to use multiple pronouns.  For example, if you wanted to say
--   something like “John promises Mary to convince Paul to give her
--   his book”, these combinations could instead produce “give him
--   *her* book”.
--
-- * This function works by FS unification of the top features of the
--   tree's root node with *[idx:i]*.  If unification is not possible,
--   we simply return the tree as is.  Otherwise the root node gets the
--   unified features and the resulting substitution is applied to the
--   whole tree.
--
-- * NOTE(review): earlier documentation said that the tree is renamed by
--   appending the index to its name; nothing in this definition touches
--   the name of the tree.
assignIndex :: GeniVal -> TagElem -> TagElem
assignIndex i te =
    maybe te withIndex (hush (unifyFeat (gup oldRoot) [ AvPair __idx__ i ]))
  where
    tree    = ttree te
    oldRoot = root tree
    withIndex (newUp, sub) =
        replace sub (te { ttree = rootUpd tree (oldRoot { gup = newUp }) })

-- ====================================================================
-- Further optimisations
-- ====================================================================

-- Index constraints
-- -----------------
-- | Polarities earned by a tree for the index constraints it satisfies:
--   every input constraint (first argument) that also occurs in the
--   tree's interface (second argument) is mapped to @ival 1@.
detectIdxConstraints :: Flist GeniVal -> Flist GeniVal -> PolMap
detectIdxConstraints cs interface =
    Map.fromList [ (idxConstraintKey av, ival 1)
                 | av <- cs `intersect` interface ]

-- | Polarities declared by the input for its index constraints: every
--   constraint is mapped to @ival (-1)@.
declareIdxConstraints :: Flist GeniVal -> PolMap
declareIdxConstraints cs =
    Map.fromList [ (idxConstraintKey c, ival (-1)) | c <- cs ]

-- | The polarity key for an index constraint: the printed form of the
--   attribute-value pair, prefixed with a dot.
--
-- TODO: test that index constraints come first
idxConstraintKey :: AvPair GeniVal -> PolarityKey
idxConstraintKey av = PolarityKeyStr ("." <> pretty av)

-- Automatic polarity detection
-- ----------------------------
-- | Suggests attributes that could be used for polarity filtering: the
--   attributes that have a constrained value in every feature structure
--   examined.  The feature structures examined are the bottom features of
--   the root of each initial tree, and, for each tree that has any, the
--   concatenated results of 'substTops'.
--
--   Returns the empty list if there is nothing to examine.
suggestPolFeatures :: [TagElem] -> [Text]
suggestPolFeatures tes
  | null attrLists = []
  | otherwise      = foldr1 intersect attrLists
  where
    -- only initial trees need be counted; in aux trees, the
    -- root node is implicitly canceled by the foot node
    rootFeats, substFeats :: [Flist GeniVal]
    rootFeats  = [ gdown (root (ttree t)) | t <- tes, ttype t == Initial ]
    substFeats = [ concat tops | tops <- map substTops tes, not (null tops) ]
    --
    constrained :: Flist GeniVal -> [Text]
    constrained avs = [ a | AvPair a v <- avs, isJust (gConstraints v) ]
    attrLists = map constrained (rootFeats ++ substFeats)

-- FIXME: temporary HACKY code - delete me as soon as possible (written
-- 2006-03-30)
--
-- Keeps the trees that are not considered to have an index.  A tree is
-- considered to have one if the idx attribute has a constrained value
-- among its relevant features, or if it is a non-initial tree for which
-- substTops returns nothing.
--
-- only initial trees need be counted; in aux trees, the
-- root node is implicitly canceled by the foot node
detectSansIdx :: [TagElem] -> [TagElem]
detectSansIdx = filter (not . hasIdx)
  where
    feats t
      | ttype t == Initial = concat (gdown (root (ttree t)) : substTops t)
      | otherwise          = concat (substTops t)
    constrainedAttrs avs = [ a | AvPair a v <- avs, isJust (gConstraints v) ]
    hasIdx t = __idx__ `elem` constrainedAttrs (feats t)
            || (ttype t /= Initial && null (substTops t))

-- | Adds the polarities detected for the given attributes to the
--   polarities that the tree already has.
detectPols :: Set.Set PolarityAttr -> TagElem -> TagElem
detectPols attrs t = t { tpolarities = addPols detected existing }
  where
    detected = detectPolsH attrs t
    existing = tpolarities t

-- Chart sharing
-- -------------

-- | A set of polarity automaton paths, each path being identified by a
--   number (see 'detectPolPaths').
type PolPathSet = IntSet
-- | For each tree, the set of paths it belongs to.
type PolPathMap = Map.Map TagElem IntSet

-- | Given a list of paths (i.e. a list of list of trees)
--   return a list of trees such that each tree is annotated with the paths it
--   belongs to.
--
--   Paths are numbered from 1, in the order in which they are given.
detectPolPaths :: [[TagElem]] -> [(TagElem,PolPathSet)]
detectPolPaths paths =
    Map.toList (foldl' addPath Map.empty (zip [1..] paths))
  where
    -- record every tree of this path as belonging to path number n
    addPath :: PolPathMap -> (Int, [TagElem]) -> PolPathMap
    addPath acc (n, path) = foldl' (tag n) acc path
    tag n m t = Map.insertWith IntSet.union t (singletonPolPath n) m

-- | The set containing no polarity paths.
emptyPolPaths :: PolPathSet
emptyPolPaths = IntSet.empty

-- | True if the set contains no polarity paths.
polPathsNull :: PolPathSet -> Bool
polPathsNull paths = IntSet.null paths

-- | The path numbers in the set, in ascending order.
polPathsToList :: PolPathSet -> [Int]
polPathsToList paths = IntSet.toAscList paths

-- | The paths that are in either set.
unionPolPaths :: PolPathSet -> PolPathSet -> PolPathSet
unionPolPaths xs ys = IntSet.union xs ys

-- | The paths that are in both sets.
intersectPolPaths :: PolPathSet -> PolPathSet -> PolPathSet
intersectPolPaths xs ys = IntSet.intersection xs ys

-- | True if the two sets have at least one polarity path in common.
hasSharedPolPaths :: PolPathSet -> PolPathSet -> Bool
hasSharedPolPaths x y = not (polPathsNull (intersectPolPaths x y))

-- | Render a set of polarity automaton paths as text: the path numbers
--   in ascending order, separated by commas
prettyPolPaths :: PolPathSet -> Text
prettyPolPaths paths =
    T.intercalate ", " [ pretty n | n <- IntSet.toAscList paths ]

-- | A (trivially) packed representation of the singleton
--   set containing a single polarity path
singletonPolPath :: Int -> PolPathSet
singletonPolPath n = IntSet.singleton n

-- Semantic sorting
-- ----------------

-- | Sorts the literals of the input semantics by the number of
--   candidates that mention them, least frequent first.  Literals with
--   the same count keep their relative order.
sortSemByFreq :: Sem -> [TagElem] -> Sem
sortSemByFreq tsem cands = map fst (sortBy byColumn (zip tsem freqs))
  where
    -- for each literal, how many candidates have it in their semantics
    freqs = [ length [ () | c <- cands, lit `elem` tsemantics c ]
            | lit <- tsem ]
    -- note: we introduce an extra hack to push
    -- index-counted extra columns to the end; just for UI reasons
    byColumn (l1, n1) (l2, n2) =
      case (isExtraCol l1, isExtraCol l2) of
        (True,  False) -> GT
        (False, True)  -> LT
        _              -> compare n1 n2

-- ----------------------------------------------------------------------
-- Types
-- ----------------------------------------------------------------------

-- Polarity NFA

-- | A state of the polarity automaton.
data PolState = PolSt Int [Literal GeniVal] [(Int,Int)]
                -- ^ position in the input semantics, extra semantics,
                --   polarity intervals (the construction pushes one
                --   interval per polarity key onto the front of this
                --   list; states of the seed automaton have none)
     deriving (Eq)
-- | Transitions are labelled with the tree that is chosen.
type PolTrans = TagElem
type PolAut   = NFA PolState PolTrans
-- | The transitions of a polarity automaton: from a source state, to a
--   target state, to the labels leading there.  A label of 'Nothing' is
--   a transition on which no tree is chosen.
type PolTransFn = Map.Map PolState (Map.Map PolState [Maybe PolTrans])

-- | Shows the position, a space, the pretty-printed extra semantics and
--   then the intervals.
instance Show PolState where
  show (PolSt pos extra pols) =
    concat [ show pos, " ", prettyStr extra, show pols ]

-- | States are ordered by position first; ties are broken by the extra
--   semantics and then by the intervals.
instance Ord PolState where
  compare (PolSt pr1 ex1 po1) (PolSt pr2 ex2 po2) =
    case compare pr1 pr2 of
      EQ    -> compare (ex1, po1) (ex2, po2)
      other -> other

-- We also include some fake states which are useful for general
-- housekeeping during the main algorithms.  A fake state has the given
-- position and intervals, and no extra semantics.
fakestate :: Int -> [Interval] -> PolState
fakestate pos = PolSt pos []

-- | an initial state for polarity automata: position 0, no extra
--   semantics, and the given intervals
polstart :: [Interval] -> PolState
polstart = fakestate 0