{-# LANGUAGE OverloadedStrings #-}
-- | Training and applying sequence-perceptron labelling models.
module NLP.Sequor 
    ( ModelData
    , P.Trace
    , Template.Feature
    , Config
    , Token
    , Label
    , Sentence
    , train
    , predict
    , parseTemplate
    , defaultFlags
    ) where

import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Data.IntMap as IntMap
import qualified Data.IntSet as IntSet
import Data.List (foldl',tails)
import Data.Maybe (fromMaybe)
import Helper.ListZipper
import qualified NLP.Perceptron.Sequence as P
import NLP.Perceptron.Sequence (Options(..))
import NLP.Sequor.CoNLL
import Helper.Utils (splitWith,uniq)
import Helper.Atom
import Control.Monad.RWS
import NLP.Sequor.Features (inputFeatures,features,maybeFeatures,outputFeatures,indexFeatures)
import qualified NLP.Sequor.FeatureTemplate as Template
import qualified Data.Array as A
import qualified Data.Vector.Unboxed as V
import qualified Data.Binary as Binary
import qualified Helper.Text as Text
import Helper.Text(Txt)
import qualified Data.Text.Lazy as Text
import Data.Char
import Data.Maybe (catMaybes)
import NLP.Sequor.Config 
import Text.Printf
import Debug.Trace

-- | A trained model bundled with the configuration needed to apply it.
data ModelData = ModelData { model :: P.Model -- ^ Sequence perceptron model
                           , config :: Config -- ^ Model configuration options
                           }

-- | Serialises 'model' first, then 'config'; 'Binary.get' reads them back
-- in the same order.
instance Binary.Binary ModelData where
    get = do
      m <- Binary.get
      c <- Binary.get
      return (ModelData m c)
    put (ModelData m c) = Binary.put m >> Binary.put c

-- | @predict model sentences@ returns the best label sequence for each of
-- @sentences@. A sentence is a sequence of 'Token's. Each token is truncated
-- to the number of fields the model was trained with ('fieldNumber') before
-- features are extracted. Fails if the model carries no atom table.
predict :: ModelData -> [[Token]] -> [[Label]]
predict m testdat = fst (runAtoms (mapM labelSentence sentences) atoms)
  where
    atoms = fromMaybe (error "NLP.Sequor.predict:Nothing") (atomTable (config m))
    bounds = oFeatBounds (P.options (model m))
    sentences = [ toZippers (map (take (fieldNumber m)) s) | s <- testdat ]
    -- Extract features for every position, then decode and map atoms back
    -- to label text.
    labelSentence zs =
      mapM (maybeFeatures bounds (config m)) zs >>= predict' (P.decode (model m))

-- | @train flags template training heldout@ trains a model on the @training@
-- sentences using the given flags and feature template. It returns the model
-- and a trace giving, for each iteration, the error rate on the training and
-- held-out (development) sentences.
--
-- The number of fields kept per token is the length of the first token of
-- the first training sentence. If there is no training data, or the first
-- training sentence is empty, that number is undefined and forcing it raises
-- a descriptive error instead of a bare pattern-match failure.
train :: Flags 
      -> Template.Feature
      -> [(Sentence, [Label])]
      -> [(Sentence, [Label])]
      -> (ModelData, P.Trace)
train fs template traindat heldout =
        let len = case traindat of
                    ((tok:_, _):_) -> length tok
                    _ -> error "NLP.Sequor.train: the first training sentence must be non-empty"
            conf = Config { featureTemplate = template
                          , atomTable = Nothing
                          , flags = fs
                          , fieldNum = len }
            ((m, _predicted, info), _atoms) =
                 runAtoms (run conf (zippify traindat) (zippify heldout)) empty
        in (m, info)

-- | @parseTemplate s@ parses the feature template in @s@ and returns the
-- result. This is a thin wrapper around 'Template.parse'.
parseTemplate :: Text.Text -> Template.Feature
parseTemplate = Template.parse

-- | Default 'Flags' for 'train'. Feature hashing is off ('flagHash' is
-- 'False'), so 'flagHashSample' and 'flagHashMaxSize' have no effect unless
-- it is switched on.
defaultFlags :: Flags
defaultFlags = Flags { flagRate          = 0.01
                     , flagBeam          = 10
                     , flagIter          = 10
                     , flagMinFeatCount  = 100
                     , flagHeldout       = Nothing
                     , flagHash          = False
                     , flagHashSample    = 1000
                     , flagHashMaxSize   = Nothing
                     , flagStopWinSize   = 5
                     , flagStopThreshold = 0.05
                     }

-- Implementation

-- | Number of fields per token recorded in the model's configuration
-- ('fieldNum').
fieldNumber :: ModelData -> Int
fieldNumber m = fieldNum (config m)

-- | Atom id of a label (see 'mkfs', which interns labels with 'toAtom').
type Tag = Int
-- | Integer feature identifier.
type F = Int

-- | Turns each sentence's token list into zippers with 'toZippers', leaving
-- its labels untouched.
zippify :: [([Token], [Txt])] -> [([ListZipper Token], [Txt])]
zippify sents = [ (toZippers toks, labels) | (toks, labels) <- sents ]

-- | @tagDictionary indexFeatureSet wmin trainset@ builds a tag dictionary.
--
-- For every token, its index feature is the first feature in its vector that
-- is a member of @indexFeatureSet@. Tokens without one are skipped. Each index
-- feature seen at least @wmin@ times is mapped to the ascending list of
-- distinct tags it co-occurred with.
tagDictionary ::  IntSet.IntSet 
              -> Int 
              -> [([V.Vector Int], [F])] 
              -> IntMap.IntMap [Tag]
tagDictionary indexFeatureSet wmin trainset =
    let -- Pair each token's index feature with that same token's tag *before*
        -- dropping tokens that lack one; filtering first (as a flat
        -- catMaybes followed by zip) would shift every later tag onto the
        -- wrong word.
        pairs = [ (w, t)
                | (xs, ts) <- trainset
                , (x, t) <- zip xs ts
                , Just w <- [V.find (`IntSet.member` indexFeatureSet) x] ]
        count_ws = IntMap.fromListWith (+) [ (w, 1 :: Int) | (w, _) <- pairs ]
        dict =   IntMap.map Set.toList
               . IntMap.fromListWith Set.union
               $ [ (w, Set.singleton t) | (w, t) <- pairs
                 , count_ws IntMap.! w >= wmin ]
    -- `==` binds tighter than `seq`, so this is (dict == dict) `seq` dict:
    -- comparing the map with itself inspects every key and tag list, forcing
    -- the dictionary to be built here rather than left as a thunk.
    in dict == dict `seq` dict

-- | @pruneLabels lim xys@ replaces every label that occurs fewer than @lim@
-- times across all of @xys@ with the placeholder label @UNDETERMINED@. The
-- inputs (first components) pass through unchanged.
pruneLabels :: Int -> [(x,[Txt])] -> [(x,[Txt])]
pruneLabels lim xys = map relabel xys
  where
    counts = Map.fromListWith (+) [ (lbl, 1) | (_, lbls) <- xys, lbl <- lbls ]
    isRare lbl = counts Map.! lbl < lim
    relabel (inp, lbls) =
      (inp, [ if isRare lbl then "UNDETERMINED" else lbl | lbl <- lbls ])

-- | @run conf trainset heldout@ does the work behind 'train' inside the atom
-- monad:
--
--   1. interns the training labels and builds the output-feature map;
--   2. fixes the feature bounds (@Just (0, size)@ when 'flagHash' is set,
--      otherwise 'Nothing');
--   3. extracts features for the training and held-out sentences;
--   4. trains with 'P.train' and decodes the held-out sentences.
--
-- Returns the model (whose config now carries the atom table), the predicted
-- label sequences for @heldout@, and the training trace.
run :: (Functor m, MonadAtoms m) =>
       Config
    ->  [([ListZipper Token], [Txt])]
    ->  [([ListZipper Token], [Txt])]
    -> m (ModelData, [[Txt]], P.Trace)
run conf trainset_in testset_in = do
  let fl = flags conf
      -- NOTE(review): label pruning ('pruneLabels') is currently not applied.
      ys = uniq . concatMap snd $ trainset_in :: [Txt]
  ys' <- mapM toAtom ys
  outm <- mkOutputFeatureAtoms . map snd $ trainset_in
  let -- Upper end of the hashed feature range: the output features plus
      -- either the configured cap or an estimate from a training sample.
      -- Only demanded when hashing is enabled.
      size = outputFeatureCount outm
           + fromMaybe (estimateFeatureCount conf . map fst $ trainset_in)
                       (flagHashMaxSize fl)
      bounds = if flagHash fl
               then Just (0, size)
               else Nothing
  trainset <- mapM (mkfs $ features bounds conf) trainset_in
  testset <- mapM (mkfs $ maybeFeatures bounds conf) testset_in
  -- Snapshot of the atom table after training and held-out features have
  -- been interned; it is stored in the returned config for 'predict'.
  tab <- table
  let indexFeatureSet = indexFeatures tab
      conf' = conf { atomTable = Just tab }
      opts = Options { oYMap = outm
                     , oIndexSet = indexFeatureSet
                     -- Index features seen at least flagMinFeatCount times
                     -- get a tag dictionary entry.
                     , oYDict = tagDictionary indexFeatureSet
                                    (flagMinFeatCount fl) trainset
                     , oYs = ys'
                     , oBeam = flagBeam fl
                     , oRate = flagRate fl
                     , oEpochs = flagIter fl
                     , oFeatBounds = bounds
                     , oStopWinSize = flagStopWinSize fl
                     , oStopThreshold = flagStopThreshold fl
                     }
      (m, info) = P.train opts testset trainset
  ps <- mapM (predict' (P.decode m . fst)) testset
  return (ModelData { model = m, config = conf' }, ps, info)

-- | Apply the decoder @dec@ to @x@ and translate each resulting atom id back
-- to its text with 'fromAtom'.
predict' :: (MonadAtoms m) =>
            (t -> [Int]) -> t -> m [Txt]
predict' dec = mapM fromAtom . dec

-- | Interns the labels of @yss@, their adjacent-label bigrams and the output
-- features of both as atoms, and assembles the output-feature map:
--
--   * the atoms of @outputFeatures []@,
--   * an array from each label atom to its unigram output-feature atoms,
--   * an array from each pair of label atoms to its bigram output-feature
--     atoms.
--
-- Atoms are requested in a fixed order (labels, bigrams, unigram features,
-- bigram features, @outputFeatures []@), which determines the ids assigned.
-- The array bounds use 'minimum' and 'maximum' of the label atoms, so @yss@
-- must contain at least one label.
mkOutputFeatureAtoms :: (MonadAtoms m) => [[Txt]] -> m P.YMap
mkOutputFeatureAtoms yss = do
  let unigrams = [ [y] | y <- uniq (concat yss) ]
      bigrams  = uniq [ [a, b] | ys <- yss, (a:b:_) <- tails ys ]
  unigramis <- mapM (mapM toAtom) unigrams
  bigramis  <- mapM (mapM toAtom) bigrams
  unigramfs <- mapM (mapM toAtom . outputFeatures) unigrams
  bigramfs  <- mapM (mapM toAtom . outputFeatures) bigrams
  zerofs    <- mapM toAtom (outputFeatures [])
  let -- every unigram is a singleton, so this is one atom per label
      labelIds = concat unigramis
      lo = minimum labelIds
      hi = maximum labelIds
      uniEntries = zip labelIds (map V.fromList unigramfs)
      biEntries  = zipWith (\[a, b] fs -> ((a, b), V.fromList fs)) bigramis bigramfs
      ymap1 = A.accumArray (V.++) V.empty (lo, hi) uniEntries
      ymap2 = A.accumArray (V.++) V.empty ((lo, lo), (hi, hi)) biEntries
  return (V.fromList zerofs, ymap1, ymap2)

-- | Largest feature id stored anywhere in an output-feature map (the
-- zero-order vector plus every vector in the unigram and bigram arrays).
-- Errors if the map holds no ids at all.
-- NOTE(review): this is the maximum id, not max+1; if ids start at 0 the
-- actual count is one larger — confirm whether callers rely on this.
outputFeatureCount :: P.YMap -> Int
outputFeatureCount (zero, uni, bi) = maximum allIds
  where
    allIds = V.toList zero ++ idsIn uni ++ idsIn bi
    idsIn arr = concatMap V.toList (A.elems arr)
-- | @mkfs f (x, y)@ computes a feature vector for every token zipper in @x@
-- using @f@, interns the labels @y@ as atoms, and returns both.
mkfs :: (MonadAtoms m) => 
        (ListZipper Token -> m (V.Vector F))
     ->   ([ListZipper Token], [Txt]) 
     ->   m ([V.Vector F], [Tag])
mkfs f (x,y) = do
  fs <- mapM f x
  -- `==` binds tighter than `seq`: comparing the result with itself inspects
  -- every element, forcing it instead of leaving lazy thunks behind.
  -- NOTE(review): Control.DeepSeq would express this more directly.
  fs == fs `seq` return ()
  y' <- mapM toAtom y
  -- Same forcing trick for the label atoms.
  y' == y' `seq` return ()
  return $ (fs,y')

-- | @estimateFeatureCount conf xs@ estimates the number of input features in
-- the sentences @xs@. It counts the features of the first @flagHashSample@
-- sentences (after 'uniq') and scales that count by
-- @length xs `div` sampleSize@.
--
-- An empty @xs@ yields 0. A non-positive 'flagHashSample' is rejected with a
-- descriptive error (previously both cases divided by zero).
estimateFeatureCount :: Config -> [[ListZipper Token]] -> Int
estimateFeatureCount conf xs
    | len == 0        = 0
    | sampleSize <= 0 =
        error "NLP.Sequor.estimateFeatureCount: flagHashSample must be positive"
    | otherwise       = factor * sampleFeatures
  where
    len        = length xs
    sampleSize = min len . flagHashSample . flags $ conf
    -- Integer scale-up; truncation can underestimate when len is not a
    -- multiple of the sample size.
    factor     = len `div` sampleSize
    sampleFeatures = length
                   . uniq
                   . concatMap (concatMap (inputFeatures conf))
                   . take sampleSize
                   $ xs