{-# LANGUAGE RecordWildCards #-}
module NLP.Concraft.Disamb
(
  Disamb (..)
, P.Tier (..)
, P.Atom (..)
, marginals
, disamb
, include
, disambSent
, TrainConf (..)
, train
, prune
) where
import Control.Applicative ((<$>), (<*>))
import Data.List (find)
import Data.Binary (Binary, put, get)
import qualified Data.Set as S
import qualified Data.Map as M
import qualified Data.Vector as V
import qualified Control.Monad.Ox as Ox
import qualified Data.CRF.Chain2.Tiers as CRF
import NLP.Concraft.Schema hiding (schematize)
import qualified NLP.Concraft.Morphosyntax as X
import qualified NLP.Concraft.Disamb.Positional as P
import qualified Data.Tagset.Positional as T
import qualified Numeric.SGD as SGD
-- | Convert a sentence whose tags have already been split into atom lists
-- into the CRF input format.  The word at position @k@ gets as observations
-- the set of values collected by running @schema@ on the whole sentence
-- (as a vector) at position @k@, and as candidate labels the set of its
-- (split) interpretations.
schematize :: Schema w [t] a -> X.Sent w [t] -> CRF.Sent Ob t
schematize schema sent = map wordAt [0 .. V.length segs - 1]
  where
    segs = V.fromList sent
    wordAt k = CRF.mkWord observations labels
      where
        observations = S.fromList (Ox.execOx (schema segs k))
        labels       = X.interpsSet (segs V.! k)
-- | A disambiguation model.
--
-- * 'tiers' is passed to 'P.split' to turn every positional tag into a
--   list of atoms (one list element per tier).
-- * 'schemaConf' is turned into an observation schema with 'fromConf'.
-- * 'crf' is the CRF over observations 'Ob' and atom labels used for
--   tagging and for computing marginals.
data Disamb = Disamb
    { tiers         :: [P.Tier]
    , schemaConf    :: SchemaConf
    , crf           :: CRF.CRF Ob P.Atom }
-- | Binary serialisation of a model: the tiers, then the schema
-- configuration, then the CRF, read back in exactly the same order.
instance Binary Disamb where
    put (Disamb tierList conf model) = do
        put tierList
        put conf
        put model
    get = do
        tierList <- get
        conf     <- get
        model    <- get
        return (Disamb tierList conf model)
-- | Map a value in the split representation back to an interpretation of
-- the given segment: the first element of 'X.interps' whose image under
-- @split'@ equals @x@.  Calls 'error' when no interpretation matches.
unSplit :: Eq t => (r -> t) -> X.Seg w r -> t -> r
unSplit split' word x =
    maybe (error "unSplit: no such interpretation") id
        (find sameSplit (X.interps word))
  where
    sameSplit r = split' r == x
-- | Choose one tag per word.  Every tag is split into tier atoms and the
-- CRF tags the schematized sentence.  Each resulting atom list is then
-- mapped back to the first interpretation of the corresponding word that
-- splits to it.
disamb :: X.Word w => Disamb -> X.Sent w T.Tag -> [T.Tag]
disamb Disamb{..} sent = zipWith (unSplit splitTag) sent chosen
  where
    splitTag t = P.split tiers t Nothing
    schema     = fromConf schemaConf
    chosen     = CRF.tag crf (schematize schema (X.mapSent splitTag sent))
-- | Record the tags chosen by @f@ in the sentence itself.  Every
-- interpretation of a word is kept.  The one chosen for that word gets
-- weight 1 and all the others get weight 0.
include :: (X.Sent w T.Tag -> [T.Tag]) -> X.Sent w T.Tag -> X.Sent w T.Tag
include f sent = zipWith mark sent (f sent)
  where
    mark seg chosen =
        seg { X.tags = X.mkWMap (map (weigh chosen) (X.interps seg)) }
    weigh chosen y = (y, if chosen == y then 1 else 0)
-- | Disambiguate a sentence with 'disamb' and store the outcome in the tag
-- weights.  The chosen tag of each word gets weight 1 and every other
-- interpretation gets weight 0.
disambSent :: X.Word w => Disamb -> X.Sent w T.Tag -> X.Sent w T.Tag
disambSent dmb = include (disamb dmb)
-- | Marginal probabilities of the interpretations of every word.  They are
-- computed by the CRF on the split (tier-atom) form of the sentence.
--
-- NOTE(review): for each word, the list returned by 'CRF.marginals' is
-- zipped positionally with 'X.interps' of the original, unsplit word.  This
-- presumably assumes the CRF reports probabilities in the same order as
-- those interpretations, and that no two tags of a word split to the same
-- atom list.  If they do, the split interpretation set is smaller and
-- 'zip' silently drops trailing tags.  Confirm against the CRF library.
marginals :: X.Word w => Disamb -> X.Sent w T.Tag -> [X.WMap T.Tag]
marginals Disamb{..} sent = zipWith toWMap sent probs
  where
    splitTag t    = P.split tiers t Nothing
    schema        = fromConf schemaConf
    probs         = CRF.marginals crf (schematize schema (X.mapSent splitTag sent))
    toWMap seg ps = X.mkWMap (zip (X.interps seg) ps)
-- | Prune the model's CRF with 'CRF.prune', passing the threshold @x@
-- through unchanged.  Tiers and schema configuration are left untouched.
prune :: Double -> Disamb -> Disamb
prune x dmb@Disamb{crf = model} = dmb { crf = CRF.prune x model }
-- | Training configuration for 'train'.
--
-- * 'TrainConf' trains a fresh model.  It carries the tiers ('tiersT'),
--   the schema configuration ('schemaConfT'), SGD arguments ('sgdArgsT')
--   and the 'onDiskT' flag.  The flag is forwarded to the CRF trainer and
--   presumably selects on-disk storage of training data (TODO confirm).
-- * 'ReTrainConf' continues training the CRF of an existing model
--   ('initDmb'), reusing that model's tiers and schema configuration.
--
-- Note that 'tiersT', 'schemaConfT' and 'initDmb' exist on only one of
-- the two constructors.  Applying them to the other constructor fails at
-- runtime.
data TrainConf
    = TrainConf
        { tiersT        :: [P.Tier]
        , schemaConfT   :: SchemaConf
        , sgdArgsT      :: SGD.SgdArgs
        , onDiskT       :: Bool }
    | ReTrainConf
        { initDmb       :: Disamb
        , sgdArgsT      :: SGD.SgdArgs
        , onDiskT       :: Bool }
-- | Train a disambiguation model.
--
-- With 'TrainConf', a new CRF is trained over the configured tiers and
-- schema.  With 'ReTrainConf', the CRF of the initial model is trained
-- further, and that model's tiers and schema configuration are reused.
-- In both cases the number of features of the resulting CRF is printed to
-- standard output.
train
    :: X.Word w
    => TrainConf                -- ^ Training configuration
    -> IO [X.Sent w T.Tag]      -- ^ Training data
    -> IO [X.Sent w T.Tag]      -- ^ Evaluation data (passed to the CRF trainer)
    -> IO Disamb                -- ^ Resulting model
train TrainConf{..} trainData evalData = do
    let schema     = fromConf schemaConfT
        splitTag t = P.split tiersT t Nothing
    model <- CRF.train (length tiersT) CRF.selectHidden sgdArgsT onDiskT
        (schemed schema splitTag <$> trainData)
        (schemed schema splitTag <$> evalData)
    putStr "\nNumber of features: "
    print (CRF.size model)
    return (Disamb tiersT schemaConfT model)
train ReTrainConf{..} trainData evalData = do
    let Disamb{..} = initDmb
        schema     = fromConf schemaConf
        splitTag t = P.split tiers t Nothing
    model <- CRF.reTrain crf sgdArgsT onDiskT
        (schemed schema splitTag <$> trainData)
        (schemed schema splitTag <$> evalData)
    putStr "\nNumber of features: "
    print (CRF.size model)
    return (initDmb { crf = model })
-- | Turn training sentences into CRF training examples.  The tags of every
-- word are first mapped with @split@.  Each schematized word is then
-- paired with a probability distribution over its split interpretations,
-- built from the word's tag weights.
schemed :: Ord t => Schema w [t] a -> (T.Tag -> [t])
        -> [X.Sent w T.Tag] -> [CRF.SentL Ob t]
schemed schema split = map toExample
  where
    toExample sent =
        let splitSent = map (X.mapSeg split) sent
        in  zip (schematize schema splitSent) (map labelDist splitSent)
    -- Independent of the sentence, hence defined once here.
    labelDist seg = CRF.mkProb (M.toList (X.unWMap (X.tags seg)))