{-# LANGUAGE NoMonomorphismRestriction , BangPatterns #-}
module Entropy.Algorithm ( cluster
                         , clusterBeam
                         , clusterToken
                         , labelToken
                         , clusterWords
                         , ClusterSet (..)
                         , weightedhXY
                         , empty
                         , makeClusterSet
                         , X
                         , Y
                         , Count
                         , featIDs
                         , predictX0
                         , predictX0Full
                         , clusterLabelToX0
                         , defocus
                         , getX0
                         , examples
                         , display
                         , getLabeler
                         )
where
  
import qualified Entropy.Features as F
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.Map ((!))
import ListZipper 
import Data.List (sort,sortBy,foldl')
import Data.Ord (comparing)
import Reader (Token,readcorpus)
import Debug.Trace
import Prelude hiding (sum)
import SparseVector (plus,dot)
import Data.Binary (encode,decode,put,get,Binary)
import Control.Monad (ap)
import qualified Data.ByteString.Lazy as BS
import Counts (rankNormalize)

-- | A (possibly fractional) count of features or tokens.
type Count = Double
-- | An entropy value; 'entropy' computes it with 'logBase' 2, i.e. in bits.
type H = Double

-- | The ID a brand-new cluster would receive: one past 'lastID'.
nextID :: ClusterSet x -> Y
nextID c = 1 + lastID c
-- | IDs of all clusters that currently hold counts, in ascending order.
ids :: ClusterSet x -> [Y]
ids c = Map.keys (countY c)

-- | Collapse a two-level map into a single map keyed by the pair of keys.
flatten :: (Ord k, Ord k1) => Map.Map k (Map.Map k1 a) -> Map.Map (k,k1) a
flatten m = Map.fromList (concatMap pairUp (Map.toList m))
  where
    -- tag every inner entry with its outer key
    pairUp (k, inner) = [ ((k, k1), v) | (k1, v) <- Map.toList inner ]


-- | Extract the features of the zipper's current position, keyed by
-- (feature index, feature value).
features :: ListZipper String -> X (Int,String)
features z = flatten (F.features z)

-- | Counts and cached entropies describing a clustering.
--
-- * 'countY'  : per cluster, the sum of the feature counts assigned to it.
-- * 'hY'      : entropy of the 'countY' distribution.
-- * 'countXY' : per cluster, the counts of each feature assigned to it.
-- * 'hXY'     : per cluster, the entropy of its 'countXY' feature counts.
-- * 'countN'  : total count over all clusters.
-- * 'lastID'  : ID counter; a fresh cluster is given @lastID + 1@
--               (see 'nextID') and 'update' advances it when a previously
--               unseen cluster ID is used.
--
-- The field order matters: the 'Binary' instance serialises the fields in
-- this order.
data ClusterSet x = CS { countY   :: Map.Map Y Count 
                       , hY       :: !H
                       , countXY  :: Map.Map Y (Map.Map x Count) 
                       , hXY      :: Map.Map Y H
                       , countN   :: !Count 
                       , lastID   :: !Y }
                    deriving (Eq,Ord,Show)  

-- | Serialisation of a 'ClusterSet'.  The fields are written in declaration
-- order and read back in exactly the same order.
instance (Ord x,Binary x) => Binary (ClusterSet x) where
    put cs = do
        put (countY cs)
        put (hY cs)
        put (countXY cs)
        put (hXY cs)
        put (countN cs)
        put (lastID cs)

    get = do
        cy    <- get
        hy    <- get
        cxy   <- get
        hxy   <- get
        n     <- get
        lastY <- get
        return (CS cy hy cxy hxy n lastY)
             
    
-- | A cluster set holding no clusters and no counts.  Its entropies and
-- totals are all zero and 'lastID' is 0, so the first fresh cluster is 1.
empty :: ClusterSet x
empty = CS Map.empty 0 Map.empty Map.empty 0 0

-- | The distinct feature indices (first key component) occurring in any
-- cluster, in ascending order.
featIDs :: (Ord a, Ord b) => ClusterSet (a,b) -> [a]
featIDs cs = uniq [ fid | feats <- Map.elems (countXY cs)
                        , (fid, _) <- Map.keys feats ]

-- | Build a 'ClusterSet' from per-cluster feature counts.
--
-- Cluster totals, the grand total and all entropies are derived from the
-- supplied counts.  'lastID' is set to the largest cluster ID present, or
-- 0 for an empty map (matching 'empty').  This way 'nextID' yields the first
-- unused ID, exactly as 'update' maintains it.
makeClusterSet ::  Map.Map Y (Map.Map x Count) -> ClusterSet x
makeClusterSet cxy =
    let cy  = Map.map (Map.foldr (+) 0) cxy
        hy  = entropy cy
        hxy = Map.map entropy cxy
        n   = Map.foldr (+) 0 cy
        -- 'maximum' would crash on an empty map; use 0 like 'empty' does
        lastY | Map.null cy = 0
              | otherwise   = fst (Map.findMax cy)
    in CS { countY = cy , hY = hy
          , countXY = cxy
          , hXY = hxy
          , countN = n
          , lastID = lastY }

-- | A sparse feature vector: a map from feature key to its count.
type X k = Map.Map k Count
-- | A cluster identifier.
type Y = Int

-- | Strict left-to-right sum; the accumulator is forced at every step so no
-- chain of thunks builds up on long lists.
sum :: [Double] -> Double
sum = go 0
  where
    go acc []     = acc
    go acc (d:ds) = let acc' = acc + d in acc' `seq` go acc' ds

-- | Shannon entropy, in bits, of the distribution obtained by normalising
-- the values of a count map.
--
-- Zero counts contribute nothing, following the usual convention
-- @0 * log 0 = 0@.  Without that guard a single zero count would evaluate
-- @0 * (-Infinity)@ and turn the whole result into NaN.  An empty map has
-- entropy 0.
entropy :: Map.Map k Double -> Double
entropy cs =
    let xs     = Map.elems cs
        n      = foldl' (+) 0 xs
        logn   = logBase 2 n
        step s x
          | x == 0    = s
          | otherwise = s - x / n * (logBase 2 x - logn)
    in foldl' step 0 xs

-- | Default pseudo-count value.
pseudoC = 0.00001
-- | Default number of pseudo-counted events.
pseudoN = 1
-- | Entropy (in bits) of the counts in @m@ with @pn@ additional pseudo-events
-- of count @pc@ each.  The pseudo-events enlarge the normaliser to
-- @sum m + pn@ and add their own @-p log p@ terms, with @p = pc / n@.
entropy' pc pn m =
    let cs = Map.elems m
        n  = sum cs + pn
        -- treat a zero pseudo-count as contributing 0 (0 log 0 = 0)
        pseudo = if pc == 0 then 0 else pc / n * logBase 2 (pc / n)
    in - pn * pseudo
       - sum [ c / n * logBase 2 (c / n) | c <- cs ]

-- | Objective used to rank candidate clusterings: H(X|Y) + H(Y).
score cs = weightedhXY cs + hY cs

-- | Conditional entropy H(X|Y): each cluster's cached entropy weighted by
-- the cluster's share of the total count.
weightedhXY cs =
    let total     = countN cs
        entropies = hXY cs
    in sum [ (c / total) * (entropies ! y) | (y, c) <- Map.toList (countY cs) ]
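{- Incremental update of the weighted conditional entropy.

 Suppose one unit of count is added to cluster y^n at time t, so the total
 count grows from N to N+1. Every other cluster's weight
 p(y) = count(y)/N then shrinks by count(y)/(N(N+1)), which gives

 H^t(X|Y) - H^{t-1}(X|Y)
        =   - \sum_{y \neq y^n} \frac{count(y)}{N(N+1)} H^{t-1}(X|Y=y)
            + \left[ p^t(y^n) H^t(X|Y=y^n) - p^{t-1}(y^n) H^{t-1}(X|Y=y^n) \right]

 where N is the total feature count at time t-1.  When an example adds n_x
 units of count instead of one, replace N+1 by N+n_x and multiply the first
 term by n_x.
-}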

-- | Add the feature counts of example @x@ to cluster @y@ and refresh the
-- affected statistics.
--
-- The totals are updated, along with the entropy of @y@'s feature counts
-- and H(Y).  'lastID' advances only when @y@ is a cluster ID not yet present
-- in 'countY'.
update :: (Ord k) => ClusterSet k -> X k -> Y -> ClusterSet k
update cs x y =
    CS { countY  = countY'
       , hY      = entropy countY'
       , countXY = countXY'
       , hXY     = hXY'
       , countN  = countN'
       , lastID  = if Map.member y (countY cs)
                   then lastID cs
                   else lastID cs + 1
       }
    where nx       = sum . Map.elems $ x
          countN'  = countN cs + nx
          countXY' = countXY cs `plus` Map.singleton y x
          -- Strict insert-with: the new total is forced before insertion so
          -- repeated updates do not pile up thunks.  Written without the
          -- deprecated 'Map.insertWith'' (removed in containers 0.6).
          cy       = maybe nx (nx +) (Map.lookup y (countY cs))
          countY'  = cy `seq` Map.insert y cy (countY cs)
          hy_new   = entropy (countXY' ! y)
          hXY'     = Map.insert y hy_new . hXY $ cs
          
-- | All admissible assignments of example @x@, best first.  Each assignment
-- is the chosen cluster ID paired with the updated cluster set; the scores
-- computed by 'clusterTok' are dropped.
clusterToken :: Bool 
             -> ClusterSet (Int,String) 
             -> X (Int,String)
             -> [(Y,ClusterSet (Int,String))]
clusterToken freeze cs x = [ assignment | (_, assignment) <- clusterTok freeze cs x ]

-- | Score every admissible assignment of example @x@ and rank them with
-- 'rank' (lowest score first).
--
-- Admissible clusters are the fresh cluster ('nextID') plus every existing
-- cluster that shares at least one feature key with @x@.  NOTE(review): the
-- @freeze@ flag is not consulted here.
clusterTok :: Bool 
             -> ClusterSet (Int,String) 
             -> X (Int,String)
             -> [((Double,Y), (Y,ClusterSet (Int,String)))]
clusterTok freeze cs x = rank (map candidate (filter admissible (fresh : ids cs)))
  where
    fresh = nextID cs
    admissible y = y == fresh
                   || not (Map.null ((countXY cs ! y) `Map.intersection` x))
    candidate y = let cs' = update cs x y
                  in ((score cs', y), (y, cs'))

-- | Stable sort by ascending score.  Scores are compared at 'Float'
-- precision so that tiny rounding differences count as ties; ties are
-- broken in favour of the larger cluster ID.
rank :: [((Double,Y),a)] -> [((Double,Y),a)]
rank = sortBy orderKey
  where
    orderKey ((s1, y1), _) ((s2, y2), _) =
        compare (asFloat s1, negate y1) (asFloat s2, negate y2)
    asFloat :: Double -> Float
    asFloat = realToFrac

-- | labelToken: assign an example to an existing cluster, i.e. output a
-- single label from a closed set.
--
-- The best-ranked candidate from 'clusterToken' is used unless it is the
-- fresh cluster ('nextID').  In that case the cluster with the largest total
-- count is returned instead, with ties going to the smallest cluster ID.  A
-- cluster set with no clusters has no closed-set label, so that case raises
-- a descriptive error instead of a bare @head []@ failure.
labelToken :: ClusterSet (Int,String)
           -> X (Int,String)
           -> Y
labelToken cs x =
    case clusterToken True cs x of
      ((y,_):_) | y /= nextID cs -> y
      _                          -> mostFrequent
  where
    mostFrequent =
      case sortBy (flip $ comparing snd) . Map.toList . countY $ cs of
        ((y,_):_) -> y
        []        -> error "Entropy.Algorithm.labelToken: empty cluster set"

-- | Load a serialised 'ClusterSet' from @file@ and return a function that
-- labels every word of a sentence with a cluster.
--
-- * @nonew@: when 'True', labels come from 'labelToken' (existing clusters
--   only).  Otherwise the best-ranked candidate is used, which may be the
--   fresh cluster ID.
-- * @md@: optional dictionary file with one @<clusterID> <name>@ pair per
--   line, used to render cluster IDs.  Without it, IDs are rendered with
--   'show'.  Blank lines are skipped; any other malformed line raises an
--   error that quotes the offending line.
getLabeler :: Bool -> Maybe FilePath -> FilePath -> IO ([String] -> [String])
getLabeler nonew md file = do
  cs <- fmap decode $ BS.readFile file
  conv <- case md of
         Nothing -> return show
         Just fp -> do ls <- fmap lines $ readFile fp
                       let parseEntry ln = case words ln of
                             [k,v] | [(n,"")] <- reads k -> (n,v)
                             _ -> error $ "getLabeler: malformed dictionary line: "
                                          ++ show ln
                           dict = Map.fromList
                                  . map parseEntry
                                  . filter (not . null . words)
                                  $ ls
                       return $
                        (\k -> Map.findWithDefault (error
                                              $ "getLabeler:Not found: " ++
                                                show k) k dict)
  let fids = featIDs cs
      label = if nonew
              then labelToken cs
              else fst . head . clusterToken True cs
  return $ \ws -> let toks = zip ws ws
                      xs = examples fids [toks]
                  in map (conv . label) . head $ xs

-- | Cluster each word of a sentence in context, without committing the
-- assignments, and render the chosen cluster IDs with 'show'.
clusterWords :: [Int] -> ClusterSet (Int,String) -> [String] -> [String]
clusterWords fids cs ws = map labelOf (head (examples fids [zip ws ws]))
  where
    -- ID of the best-ranked candidate cluster for one example
    labelOf = show . fst . head . clusterToken True cs

-- | Greedy online clustering: each example in turn is committed to its
-- best-ranked assignment, threading the updated cluster set through.
cluster :: Bool 
        -> ClusterSet (Int,String) 
        -> [X (Int, String)] 
        -> ClusterSet (Int, String)
cluster freeze start xs = foldl' assign start xs
  where
    assign cs x = snd (head (clusterToken freeze cs x))
                        
-- | Beam-search clustering.  For each example, every hypothesis in the beam
-- is extended with its @sz@ best assignments.  All extensions are then
-- re-ranked together and the best @sz@ are kept.  The best final hypothesis
-- is returned.
--
-- A beam width below 1 would empty the beam and make the final 'head' fail,
-- so it is clamped to 1 (equivalent to greedy 'cluster').
clusterBeam :: Int 
            -> Bool 
            -> ClusterSet (Int,String)
            -> [X (Int, String)] 
            -> ClusterSet (Int,String)
clusterBeam sz freeze cs0 = head . foldl' step [cs0]
  where
    width = max 1 sz
    step beam x = map (snd . snd)
                  . take width
                  . rank
                  . concat
                  $ [ take width (clusterTok freeze hyp x) | hyp <- beam ]

-- | Scale the values of a map so that they sum to 1.  Uses 'Map.foldr'
-- rather than the deprecated 'Map.fold', which was removed in containers
-- 0.6.
normalize :: (Ord k) => Map.Map k Double -> Map.Map k Double
normalize x = let s = Map.foldr (+) 0 x in Map.map (/ s) x

-- | Shift a list so that its minimum becomes 0, then scale it to sum to 1.
-- An empty list is returned unchanged; the original crashed in 'minimum'.
rescale [] = []
rescale xs = let a   = negate (minimum xs)
                 xs' = map (+ a) xs
                 s   = sum xs'
             in map (/ s) xs'


-- | Rank candidate strings for an example's index-0 feature by summing over
-- the clusters the (defocused) example could be assigned to.
--
-- Each candidate @w@ receives a weighted sum over candidate clusters @y@:
--
-- * For an existing cluster, the term is the cluster's normalised count of
--   feature @(0,w)@ times a weight for @y@.  The weight is derived by
--   'rankNormalize' from the score differences @score0 - score cs'@ of each
--   candidate assignment.  NOTE(review): 'rankNormalize' lives in Counts;
--   that its output acts as a normalised weight is an assumption to confirm.
-- * For the fresh cluster ('nextID'), the term is the pooled index-0 unigram
--   probability of @w@ times a uniform prior @1 / (#clusters + 1)@.
--
-- Candidates are returned best first.  The cluster-set argument is
-- preprocessed once; partially apply it and reuse the resulting function
-- for many examples.
predictX0Full :: ClusterSet (Int,String)
              -> X (Int,String)
              -> [String]
predictX0Full cs' = 
    -- per-cluster feature distributions (each cluster's counts sum to 1)
    let xy = Map.map normalize . countXY $ cs'
        -- distribution of index-0 features pooled over all clusters
        unigram = normalize . Map.unionsWith (+) 
                  . map (Map.filterWithKey (\(i,_) _ -> i == 0))
                  . Map.elems 
                  . countXY 
                  $ cs'
        -- uniform prior over the existing clusters plus one fresh cluster
        uniform = (1/) . fromIntegral . succ . Map.size . countY $ cs'
        -- candidate strings: every distinct value seen at index 0
        ws = uniq 
             . map snd
             . filter ((==0) . fst)
             . concat 
             . Map.elems 
             . Map.map Map.keys 
             . countXY 
             $ cs'
        score0 = score cs'
    in \x ->
        -- index-0 features are removed before clustering, so the prediction
        -- is based only on the context features
        let ycs = clusterToken True cs' . defocus $ x
            cs  = map snd ycs
            es = map ((score0 -) . score) cs
            ys' = map fst ycs
            (ys,ps) = unzip . rankNormalize . zip ys' $ es
            pyxs = Map.fromList . zip ys $ ps
        in map fst
               . sortBy (flip $ comparing snd)
                     $ [ (w, sum [ let px0y =   Map.findWithDefault 0 (0,w) 
                                             . Map.findWithDefault Map.empty y
                                             $ xy
                                       pyx =  Map.findWithDefault 0 y pyxs
                                   in  if y == nextID cs' then
                                           -- NOTE(review): this term is not
                                           -- multiplied by pyx, unlike the
                                           -- existing-cluster term; confirm
                                           -- that this is intended.
                                           (unigram ! (0,w)) * uniform
                                       else 
                                       px0y * pyx 
                                   | y <- ys ]) 
                         | w <- ws ]

-- | Rank candidate strings for an example's index-0 feature using only its
-- single best cluster.
--
-- The example is defocused (index-0 features removed) and assigned to its
-- best-ranked cluster @y@.  The result is the index-0 features stored for
-- @y@ in the updated cluster set, ordered by decreasing count.  If @y@ has
-- none, the index-0 counts pooled over all clusters are used instead.  This
-- is typically the case for the fresh cluster, since the example was
-- defocused.  A missing entry for @y@ is treated as empty rather than
-- crashing in 'Map.!'.
predictX0 :: ClusterSet (Int,String) 
          -> X (Int,String) 
          -> [String]
predictX0 cs =
    let isFocus (i,_) _ = i == 0
        unigram = Map.filterWithKey isFocus
                  . Map.unionsWith (+)
                  . Map.elems
                  . countXY
                  $ cs
    in \x -> let ((y, CS { countXY = xy }) : _) =
                     clusterToken True cs . defocus $ x
                 focus = Map.filterWithKey isFocus (Map.findWithDefault Map.empty y xy)
             in map fst
                . sortBy (flip $ comparing snd)
                . map (\((_,w),c) -> (w,c))
                . Map.toList
                $ if Map.null focus then unigram else focus

-- | For a cluster ID, the strings seen as its index-0 feature, most frequent
-- first.  The lookup table is built once per cluster set; an unknown ID
-- fails in 'Map.!'.
clusterLabelToX0 :: ClusterSet (Int,String)
                 -> Y
                 -> [String]
clusterLabelToX0 cs = (table !)
  where
    table = Map.map focusWords (countXY cs)
    focusWords feats =
        map fst . sortBy (flip $ comparing snd) $
            [ (w, c) | ((i, w), c) <- Map.toList feats, i == 0 ]

-- | Turn each sentence into one feature vector per token position, keeping
-- only the features whose index is listed in @fids@.  Only the first
-- component of each 'Token' is used.
examples :: [Int] -> [[Token]] -> [[X (Int, String)]]
examples fids train = map sentenceExamples train
  where
    keep (fid,_) _ = fid `elem` fids
    sentenceExamples s =
        let zippers = take (length s) (iterate next (fromList (map fst s)))
        in [ Map.filterWithKey keep (features z) | z <- zippers ]

-- | Render one cluster as text.  The first line is the cluster ID.  Then
-- comes one line per feature index, holding the index padded to three
-- characters and up to 20 @value:count@ pairs, by decreasing count (counts
-- rounded up).
display :: (Y,Map.Map (Int,String) Double) -> String
display (y,c) = show y ++ "\n" ++ unlines (map renderGroup groups)
  where
    entries = Map.toList c
    groups  = [ (fid, sortBy (flip $ comparing snd)
                        [ (val, cnt) | ((fid', val), cnt) <- entries, fid' == fid ])
              | fid <- uniq (map (fst . fst) entries) ]
    renderGroup (fid, vals) =
        unwords (take 3 (show fid ++ repeat ' ')
                 : [ val ++ ":" ++ show (ceiling cnt) | (val, cnt) <- take 20 vals ])

-- | Remove duplicates and sort ascending.
uniq :: Ord a => [a] -> [a]
uniq = Set.elems . Set.fromList

-- | Drop every feature with index 0, keeping only the context features.
defocus :: X (Int,String) -> X (Int,String) 
defocus x = Map.filterWithKey keep x
  where keep (i, _) _ = i /= 0

-- | The value of the example's first index-0 feature in key order.  Raises
-- an error if the example has no index-0 feature.
getX0 :: X (Int,String) -> String
getX0 x =
    case [ w | (0, w) <- Map.keys x ] of
      (w:_) -> w
      []    -> error "Entropy.Algorithm.getX0: not found"