{-# LANGUAGE OverloadedStrings #-}

module NLP.Brillig.Brill where

import Control.Arrow ( first, second, (***) )
import Data.Function ( on )
import Data.List ( isPrefixOf, delete, foldl', maximumBy, sort, sortBy )
import Data.Ord ( compare )

import Data.Map ( Map )
import qualified Data.Map as Map 
import Data.Set ( Set ) 
import qualified Data.Set as Set
import Data.Text ( Text )
import qualified Data.Text as T

import Data.List.Zipper ( Zipper )
import qualified Data.List.Zipper as Z

import NLP.Brillig
import NLP.Brillig.Util

-- | A learned contextual rule: rewrite a tag when the tags to its left match.
data Transform = Transform
  { context :: [Tag]        -- ^ required left context, nearest tag first (stored backwards)
  , replace :: Replacement  -- ^ the tag substitution this rule performs
  , tscore  :: Int          -- ^ score the rule was given when it was learned
  }
  deriving (Eq, Ord)

-- | Render a rule as whitespace-separated words. The order is: the context
-- tags left to right (farthest first, i.e. the stored list reversed), the
-- @from@ tag, the @to@ tag, and finally the score.
-- NOTE: a tag containing whitespace will not survive a round trip through
-- the 'Read' instance, which splits on 'words'.
instance Show Transform where
  show (Transform ctx (Replacement src dst) n) =
    unwords ([ T.unpack (fromTag tag) | tag <- reverse ctx ++ [src, dst] ] ++ [show n])

-- | Parse the layout produced by 'show': context tags (farthest first), the
-- @from@ tag, the @to@ tag, then the score. The whole input is treated as a
-- single rule. On success the remainder is always empty. With fewer than
-- three words, or an unreadable score, the result is no parse.
instance Read Transform where
  readsPrec p input =
    case reverse (words input) of
      (scoreTok : toTok : fromTok : ctxToks) ->
        -- ctxToks comes out reversed, which is exactly the backwards
        -- (nearest-first) order 'context' uses
        case readsPrec p scoreTok of
          [(n, "")] -> [(mkTransform ctxToks fromTok toTok n, "")]
          _         -> []
      _ -> []
    where
      mkTag = Tag . T.pack
      mkTransform ctxToks srcTok dstTok n =
        Transform (map mkTag ctxToks) (Replacement (mkTag srcTok) (mkTag dstTok)) n

-- | A substitution of one tag for another.
data Replacement = Replacement
  { from :: Tag  -- ^ tag to be rewritten
  , to   :: Tag  -- ^ tag it is rewritten into
  }
  deriving (Eq, Ord, Show)

-- | An integer per tag, presumably an occurrence count.
-- NOTE(review): not referenced in the visible part of this module; verify
-- whether external callers still use it.
type TCount = Map.Map Tag Int

-- | The tag the tagger currently proposes at a position, paired with the
-- tag the training corpus has at that same position.
data TagPair = TagPair
  { proposed :: Tag  -- ^ tag from the current best guess
  , actual   :: Tag  -- ^ tag from the reference corpus
  }
  deriving (Eq, Ord)

-- ----------------------------------------------------------------------
-- tag
-- ----------------------------------------------------------------------

-- | Tag a text by running every rule over the tag sequence, highest score first.
--
-- The order matters. 'learnConverge' learns each rule on a guess that already
-- has every earlier (higher-scoring) rule applied, so tagging must replay the
-- rules in that same order. A left fold over the descending-sorted list does
-- this: the head of the list is applied first.
--
-- NOTE(review): 'retag' is defined elsewhere. It presumably lifts the
-- @[Tag] -> [Tag]@ function over the tagged tokens; confirm.
brilltag :: [Transform] -> [Tagged Text] -> [Tagged Text]
brilltag rules_ = retag (\ts -> foldl' (flip tagOne) ts rules)
 where
  -- 'sortBy' is stable, so rules with equal scores keep their input order
  rules = sortBy (flip compare `on` tscore) rules_

-- | Apply a single transformation to a tag sequence, scanning left to right.
--
-- A position is rewritten from @from@ to @to@ when two conditions hold: its
-- tag equals @from@, and the tags to its left (nearest first) start with the
-- rule's 'context'. The scan checks context against the already-rewritten tags
-- on its left, so one rewrite can enable or block a match at the next position.
tagOne :: Transform -> [Tag] -> [Tag]
tagOne x = Z.toList . go . Z.fromList
 where
  Replacement src dst = replace x
  go z = case z of
    Z.Zip _    []        -> z
    Z.Zip left (cur : _) ->
      let z' | context x `isPrefixOf` left && cur == src = Z.replace dst z
             | otherwise                                 = z
      in go (Z.right z')

-- ----------------------------------------------------------------------
-- train
-- ----------------------------------------------------------------------

-- | Greedily learn rules until no candidate beats the threshold.
--
-- Each round does three things: it learns the best rule against the current
-- guess, keeps the rule if its score is strictly greater than the threshold,
-- and applies the rule to the guess before the next round. Rules are returned
-- in the order they were learned.
learnConverge :: Int   -- ^ threshold; a rule is kept only if its score exceeds it
              -> [Tag] -- ^ corpus (reference tags)
              -> [Tag] -- ^ best guess (currently proposed tags)
              -> [Transform]
learnConverge threshold corpus = go
  where
    go guess
      | tscore rule > threshold = rule : go (tagOne rule guess)
      | otherwise               = []
      where
        rule = learnOne guess corpus

-- | Learn exactly @n@ rules greedily. Each rule is applied to the guess
-- before the next one is learned.
--
-- The result is in reverse learning order: the most recently learned rule
-- comes first. A non-positive @n@ yields no rules. (Previously a negative
-- @n@ never reached the base case and recursed forever.)
learnN :: Int   -- ^ number of rules to learn
       -> [Tag] -- ^ corpus (reference tags)
       -> [Tag] -- ^ best guess (currently proposed tags)
       -> [Transform]
learnN = go []
 where
  go acc n cs bs
    | n <= 0    = acc
    | otherwise =
        let r   = learnOne bs cs
            bs2 = tagOne r bs
        in go (r:acc) (n - 1) cs bs2

-- | Learn the single best-scoring transformation for the current guess.
--
-- Positions are aligned with 'zipWith', pairing each proposed tag with the
-- corpus tag ('TagPair' proposed actual), so any excess in the longer list is
-- not paired. Every replacement between two distinct tags seen in either list
-- is then scored by 'bestTransform'. This function neither applies the rule
-- nor iterates; 'learnConverge' and 'learnN' do that.
--
-- Partial: with fewer than two distinct tags, or too few aligned positions
-- to form a context, the underlying 'maximumBy' calls have no candidates and
-- evaluation fails.
learnOne :: [Tag] -- ^ best guess (currently proposed tags)
         -> [Tag] -- ^ corpus (reference tags)
         -> Transform
learnOne best corpus =
   bestTransform tags pairs
  where
   tags   = Set.fromList best `Set.union` Set.fromList corpus
   pairs  = zipWith TagPair best corpus

-- | Highest-scoring rule over every possible replacement between the given
-- tags. Each replacement is paired with its best single-tag context.
--
-- When scores tie, 'maximumBy' keeps the last candidate in replacement order.
-- Partial: fails when no replacement exists (fewer than two tags).
bestTransform :: Set Tag -> [TagPair] -> Transform
bestTransform tags xs =
  maximumBy (compare `on` tscore)
    [ bestInstance r hist | r <- Set.toList (replacements tags) ]
 where
  -- built once and shared by every candidate replacement
  hist = deltaHistogram xs

-- | Every ordered pair of distinct tags, taken as a candidate replacement.
replacements :: Set Tag -> Set Replacement
replacements tags =
  Set.fromList
    [ Replacement src dst | src <- tagList, dst <- tagList, src /= dst ]
 where
  -- elements of a Set are distinct, so @src /= dst@ drops exactly the self-pair
  tagList = Set.toAscList tags

-- | Pick the preceding-tag context in which @repl@ scores highest, and wrap
-- the result as a single-tag 'Transform'.
--
-- Candidates are compared on the tuple @(score, context tag)@. A tie in score
-- is therefore broken deterministically by the larger tag rather than by the
-- map's traversal order. Partial: fails if @dhist@ is empty.
bestInstance :: Replacement -> DeltaHistogram -> Transform
bestInstance repl dhist =
  case maximumBy compare candidates of
    (best, ctx) -> Transform [ctx] repl best
 where
  candidates = [ (score repl pairs, tag) | (tag, pairs) <- Map.toList dhist ]

-- | For each proposed tag, counts of the 'TagPair's at the position right
-- after it (see 'deltaHistogram').
type DeltaHistogram = Map.Map Tag (Map.Map TagPair Int)

-- | Group each 'TagPair' under the proposed tag of the position just before
-- it, then histogram each group. 'histogram' comes from elsewhere and
-- presumably counts occurrences. The first pair has no predecessor and is
-- never counted.
deltaHistogram :: [TagPair] -> DeltaHistogram
deltaHistogram xs = histogram <$> followers
 where
  followers = Map.fromListWith (++)
    [ (proposed prev, [cur]) | (prev, cur) <- zip xs (drop 1 xs) ]

-- | Net benefit of rewriting @from@ to @to@ over the counted positions.
-- It counts the positions the rewrite would fix (proposed @from@, actual
-- @to@) and subtracts the positions it would break (proposed @from@, actual
-- @from@).
score :: Replacement -> Map TagPair Int -> Int
score r m = occurrences (to r) - occurrences (from r)
 where
  -- number of positions proposed as @from r@ whose actual tag is @actualTag@
  occurrences actualTag = Map.findWithDefault 0 (TagPair (from r) actualTag) m

-- | Add two pairs component-wise.
plusPair :: Num a => (a,a) -> (a,a) -> (a,a)
plusPair (a, b) (c, d) = (a + c, b + d)