module NLP.Punkt where
import qualified Data.Text as Text
import Data.Text (Text)
import Data.Maybe (catMaybes, fromMaybe)
import Data.HashMap.Strict (HashMap)
import Data.Char (isLower, isAlpha, isSpace)
import qualified Data.HashMap.Strict as Map
import qualified Data.List as List
import Control.Applicative ((<$>), (<*>), (<|>))
import qualified Control.Monad.Reader as Reader
import NLP.Punkt.Match (re_split_pos, word_seps)
-- | Orthographic statistics for one normalised word type, as tallied by
-- 'build_ortho_count'.
--
-- "Upper" below means "first character is not lowercase", which is how
-- 'build_ortho_count' classifies occurrences.
--
-- The fields are strict. These records are rebuilt once per occurrence
-- while folding over a whole corpus, and a strict HashMap only forces the
-- record to WHNF. Lazy fields would therefore accumulate a chain of
-- unevaluated increments per counter.
data OrthoFreq = OrthoFreq
    { freq_lower :: !Int
      -- ^ occurrences whose first character is lowercase
    , freq_upper :: !Int
      -- ^ occurrences whose first character is not lowercase
    , freq_first_lower :: !Int
      -- ^ lowercase occurrences directly after a sentence end
    , freq_internal_upper :: !Int
      -- ^ non-lowercase occurrences in sentence-internal position
    , freq_after_ender :: !Int
      -- ^ occurrences directly after a sentence end
    }
    deriving Show
-- | Corpus statistics read by the classifiers through 'Punkt'.
data PunktData = PunktData
    { type_count :: HashMap Text Int
      -- ^ Counts of normalised types as built by 'build_type_count': a word
      -- seen with a trailing period is counted under its dotted form, and
      -- the key @"."@ counts how many period-final words were seen.
    , ortho_count :: HashMap Text OrthoFreq
      -- ^ Orthographic statistics keyed by normalised word.
    , collocations :: HashMap (Text, Text) Int
      -- ^ Counts of adjacent (word, word) pairs, both normalised.
    , total_enders :: Int
      -- ^ Sentence-ender total; 'build_punkt_data' stores the number of
      -- tokens whose predecessor ended a sentence.
    , total_toks :: Int
      -- ^ Token total used as the sample size in the likelihood ratios.
    }
    deriving Show
-- | What kind of thing a token is.
data Entity a
    = Word a Bool
      -- ^ a word; the flag is set when a trailing period was stripped from it
    | Punct a
      -- ^ a delimiter beginning with @;@, @!@ or @?@
    | ParaStart
      -- ^ paragraph start (not produced by 'to_tokens')
    | Ellipsis
      -- ^ a delimiter beginning with @.@ or @…@
    | Dash
      -- ^ a delimiter beginning with an em dash or a hyphen
    deriving (Eq, Show)
-- | A token located in the source text, together with its classification.
data Token = Token
    { offset :: Int
      -- ^ start position in the corpus (used as a character index by 'substring')
    , toklen :: Int
      -- ^ length in characters of the (trimmed) token text
    , entity :: Entity Text
      -- ^ the token's kind and text
    , sentend :: Bool
      -- ^ whether the token is classified as ending a sentence
    , abbrev :: Bool
      -- ^ whether the token is classified as an abbreviation
    }
    deriving Show
-- | Computations that read the corpus statistics in a 'PunktData' value.
-- Kept eta-reduced so that @Punkt@ can also be used unapplied.
type Punkt = Reader.Reader PunktData
-- | Normalise a word for counting purposes (lowercases it).
norm :: Text -> Text
norm w = Text.toLower w
-- | Is this token an initial, i.e. a single alphabetic character that
-- carried a trailing period?
is_initial :: Token -> Bool
is_initial tok = case entity tok of
    Word w True -> case Text.uncons w of
        Just (c, rest) -> Text.null rest && isAlpha c
        Nothing -> False
    _ -> False
-- | Does this token carry a 'Word' entity?
is_word :: Token -> Bool
is_word Token {entity = Word _ _} = True
is_word _ = False
-- | Modified log-likelihood ratio used to score abbreviation candidates.
--
-- As called from 'prob_abbr':
--
-- * @a@ is the count of the type with or without a trailing period;
-- * @b@ is the count of period-final words;
-- * @ab@ is the count of the type with a trailing period;
-- * @n@ is the token total.
--
-- The null hypothesis uses probability @b / n@. The alternative uses a
-- fixed probability of 0.99. The score is @-2 * (logL_null - logL_alt)@,
-- which is large when the period almost always accompanies the type.
strunk_log :: Double -> Double -> Double -> Double -> Double
strunk_log a b ab n = -2 * (null_hyp - alt_hyp)
    where
    -- Treat 0 * log 0 as 0, so a probability of exactly 0 or 1 gives a
    -- finite score instead of NaN.
    xlogy x y = if x == 0 then 0 else x * log y
    null_hyp = xlogy ab p1 + xlogy (a - ab) (1 - p1)
    alt_hyp = xlogy ab p2 + xlogy (a - ab) (1 - p2)
    (p1, p2) = (b / n, 0.99)
-- | Dunning log-likelihood ratio for the association of two events.
--
-- Arguments:
--
-- * @a@ and @b@ are the counts of the two events;
-- * @ab@ is their joint count;
-- * @n@ is the sample size.
--
-- The null hypothesis gives both cells the probability @b / n@. The
-- alternative uses @ab / a@ inside @a@ and @(b - ab) / (n - a)@ outside it.
-- The score is @-2 * (logL_null - logL_alt)@, and is 0 when @b@ or @ab@
-- is 0.
dunning_log :: Double -> Double -> Double -> Double -> Double
dunning_log a b ab n | b == 0 || ab == 0 = 0
                     | otherwise = -2 * (s1 + s2 - s3 - s4)
    where
    -- Treat 0 * log 0 as 0. This covers probabilities of exactly 0 or 1,
    -- e.g. a == ab (p1 == 1), b == ab (p2 == 0) or b == n (p0 == 1),
    -- which would otherwise produce NaN.
    xlogy x y = if x == 0 then 0 else x * log y
    (p0, p1, p2) = (b / n, ab / a, (b - ab) / (n - a))
    -- null hypothesis: one probability p0 everywhere
    s1 = xlogy ab p0 + xlogy (a - ab) (1 - p0)
    s2 = xlogy (b - ab) p0 + xlogy (n - a - b + ab) (1 - p0)
    -- alternative: p1 within occurrences of a, p2 outside them
    s3 = xlogy ab p1 + xlogy (a - ab) (1 - p1)
    s4 = xlogy (b - ab) p2 + xlogy (n - a - b + ab) (1 - p2)
-- | The table of normalised type counts.
ask_type_count :: Punkt (HashMap Text Int)
ask_type_count = Reader.asks type_count
-- | The token total, converted to any numeric type.
ask_total_toks :: Num a => Punkt a
ask_total_toks = Reader.asks (fromIntegral . total_toks)
-- | The sentence-ender total, converted to any numeric type.
ask_total_enders :: Num a => Punkt a
ask_total_enders = Reader.asks (fromIntegral . total_enders)
-- | Orthographic statistics for a word (looked up by its normalised form),
-- or all zeros if the word was never seen.
ask_ortho :: Text -> Punkt OrthoFreq
ask_ortho w_ = Reader.asks (Map.lookupDefault unseen key . ortho_count)
    where
    key = norm w_
    unseen = OrthoFreq 0 0 0 0 0
-- | How often the two words (normalised) occurred adjacently, in this order;
-- 0 if never.
ask_colloc :: Text -> Text -> Punkt Double
ask_colloc w0_ w1_ = do
    table <- Reader.asks collocations
    let hits = Map.lookupDefault 0 (norm w0_, norm w1_) table
    return (fromIntegral hits)
-- | Count of the normalised form of a type in the type table (0 if absent).
freq :: Text -> Punkt Double
freq w_ = do
    counts <- ask_type_count
    return . fromIntegral $ Map.lookupDefault 0 (norm w_) counts
-- | Count of the type written with a trailing period.
freq_snoc_dot :: Text -> Punkt Double
freq_snoc_dot w_ = freq (Text.snoc w_ '.')
-- | Count of the type with or without a trailing period.
freq_type :: Text -> Punkt Double
freq_type w_ = do
    bare <- freq w_
    dotted <- freq_snoc_dot w_
    return (bare + dotted)
-- | Length of a text in characters, as a 'Double'.
dlen :: Text -> Double
dlen t = fromIntegral (Text.length t)
-- | Abbreviation score of a type: the 'strunk_log' likelihood ratio scaled
-- by three factors.
--
-- * length factor: @exp (-L)@, where @L@ is the number of non-period
--   characters;
-- * period factor: @1 +@ the number of periods in the word;
-- * penalty: @L ** (-k)@, where @k@ is the count of the type without a
--   trailing period.
prob_abbr :: Text -> Punkt Double
prob_abbr w_ = do
    with_or_without <- freq_type w_
    period_words <- freq "."
    with_period <- freq_snoc_dot w_
    total <- ask_total_toks
    without_period <- freq w_
    let loglike = strunk_log with_or_without period_words with_period total
        stem_len = dlen (Text.filter (/= '.') w_)
        f_len = 1 / exp stem_len
        f_periods = 1 + dlen (Text.filter (== '.') w_)
        f_penalty = 1 / stem_len ** without_period
    return $ loglike * f_len * f_periods * f_penalty
-- | Orthographic evidence on whether a word starts a sentence.
--
-- Here "title" means the first character is not lowercase.
--
-- * @Just True@: the word is title, has been seen lowercase, and has never
--   been seen title sentence-internally.
-- * @Just False@: the word is lowercase, and has either been seen title or
--   never been seen lowercase after a sentence end.
-- * 'Nothing': no decision. An empty word also gets 'Nothing', because it
--   has no first character to inspect (previously this crashed in
--   'Text.head').
decide_ortho :: Text -> Punkt (Maybe Bool)
decide_ortho w_ = decide' <$> ask_ortho w_
    where
    decide' wortho = case Text.uncons w_ of
        Nothing -> Nothing
        Just (c, _) -> verdict (isLower c) wortho
    verdict lower wortho
        | title && ever_lower && never_title_internal = Just True
        | lower && (ever_title || never_lower_start) = Just False
        | otherwise = Nothing
        where
        title = not lower
        ever_lower = freq_lower wortho > 0
        never_title_internal = freq_internal_upper wortho == 0
        ever_title = freq_upper wortho > 0
        never_lower_start = freq_first_lower wortho == 0
-- | Orthographic evidence for the word following an initial.
--
-- Uses 'decide_ortho' when it reaches a verdict. Otherwise answers
-- @Just False@ if the word has never been seen lowercase, and 'Nothing'
-- if it has.
decide_initial_ortho :: Text -> Punkt (Maybe Bool)
decide_initial_ortho w_ = do
    orthosays <- decide_ortho w_
    wortho <- ask_ortho w_
    let fallback
            | freq_lower wortho == 0 = Just False
            | otherwise = Nothing
    return (orthosays <|> fallback)
-- | Sentence-starter score of a word: the 'dunning_log' association between
-- "follows a sentence end" and the word's type.
prob_starter :: Text -> Punkt Double
prob_starter w_ = do
    enders <- ask_total_enders
    wfreq <- freq_type w_
    after_ender <- fromIntegral . freq_after_ender <$> ask_ortho w_
    total <- ask_total_toks
    return $ dunning_log enders wfreq after_ender total
-- | Collocation score of the ordered word pair (w, x): the 'dunning_log'
-- association between the two types, given their adjacency count.
prob_colloc :: Text -> Text -> Punkt Double
prob_colloc w_ x_ = do
    fw <- freq_type w_
    fx <- freq_type x_
    together <- ask_colloc w_ x_
    total <- ask_total_toks
    return $ dunning_log fw fx together total
-- | Count normalised word types.
--
-- A word that carried a trailing period is counted under its dotted form
-- and also bumps the @"."@ entry. That entry starts at 0, so it is always
-- present. Non-word tokens are ignored.
build_type_count :: [Token] -> HashMap Text Int
build_type_count = List.foldl' step (Map.singleton "." 0)
    where
    step counts tok = case entity tok of
        Word w True -> Map.adjust (+ 1) "." (bump (w `Text.snoc` '.') counts)
        Word w False -> bump w counts
        _ -> counts
    bump w counts = Map.insertWith (+) (norm w) 1 counts
    
    
-- | Tally orthographic statistics ('OrthoFreq') for each normalised word.
--
-- Each word occurrence is classified by whether its first character is
-- lowercase, and by its position relative to the previous token:
--
-- * "first": the previous token is a sentence end and not an initial;
-- * "internal": the previous token is neither a sentence end, an
--   abbreviation, nor an initial.
--
-- The first token is treated as following a sentence end. Empty words are
-- skipped: they have no first character (previously this crashed in
-- 'Text.head').
build_ortho_count :: [Token] -> HashMap Text OrthoFreq
build_ortho_count toks = List.foldl' update Map.empty $
                            zip (dummy : toks) toks
    where
    -- Stand-in predecessor for the first token, flagged as a sentence end.
    dummy = Token 0 0 (Word " " False) True False

    update :: HashMap Text OrthoFreq -> (Token, Token) -> HashMap Text OrthoFreq
    update ctr (prev, Token {entity=(Word w _)})
        | Just (c, _) <- Text.uncons w =
            let lower = isLower c
                wnorm = norm w
                z = Map.lookupDefault (OrthoFreq 0 0 0 0 0) wnorm ctr
                wortho = upd z lower (not lower) (first && lower)
                               (internal && not lower) first
            in Map.insert wnorm wortho ctr
        where
        first = sentend prev && not (is_initial prev)
        internal = not (sentend prev) && not (abbrev prev)
                   && not (is_initial prev)
    -- Non-words, and empty words rejected by the guard above.
    update ctr _ = ctr

    -- Add 1 to each counter whose corresponding flag is set.
    upd (OrthoFreq a b c d e) a' b' c' d' e' =
        OrthoFreq (a |+ a') (b |+ b') (c |+ c') (d |+ d') (e |+ e')
        where int |+ bool = if bool then int + 1 else int
-- | Count adjacent (word, word) pairs, both normalised. Pairs in which
-- either token is not a word are skipped.
build_collocs :: [Token] -> HashMap (Text, Text) Int
build_collocs toks = List.foldl' bump Map.empty bigrams
    where
    bigrams =
        [ (norm u, norm v)
        | (Token {entity = Word u _}, Token {entity = Word v _}) <- zip toks (drop 1 toks)
        ]
    bump counts key = Map.insertWith (+) key 1 counts
-- | Split a corpus into tokens using 're_split_pos' with 'word_seps'.
--
-- Word pieces: surrounding characters from the trim set (commas, colons,
-- brackets, parentheses and some quote marks) are removed. A trailing
-- period is stripped and recorded in the 'Word' flag. Pieces that trim to
-- nothing are dropped.
--
-- Delimiter pieces are classified by their first character into 'Dash',
-- 'Ellipsis' or 'Punct'; 'Punct' is flagged as a sentence end. Any other
-- delimiter is dropped.
to_tokens :: Text -> [Token]
to_tokens corpus =
    [ tok | Just tok <- map (either tok_word add_delim) pieces ]
    where
    pieces = re_split_pos word_seps corpus

    tok_word (w, pos)
        | stripped == "" = Nothing
        | otherwise =
            Just $ Token pos (Text.length stripped) (Word stem has_period) False False
        where
        stripped = Text.dropAround (`elem` ",:()[]{}“”’\"\')") w
        has_period = Text.last stripped == '.'
        stem | has_period = Text.init stripped
             | otherwise = stripped

    add_delim (delim, pos)
        | lead `elem` "—-" = Just $ mk Dash False
        | lead `elem` ".…" = Just $ mk Ellipsis False
        | lead `elem` ";!?" = Just $ mk (Punct delim) True
        | otherwise = Nothing
        where
        lead = Text.head delim
        mk ent ender = Token pos (Text.length delim) ent ender False
-- | Gather corpus statistics from a token stream.
--
-- 1. Type counts come from the raw tokens.
-- 2. Tokens get a first-pass 'classify_by_type' against a provisional model
--    that has only those type counts, with the total of all tokens as its
--    token total.
-- 3. Orthographic counts, collocations and the sentence-ender count are
--    then taken from the refined stream.
--
-- The stored 'total_toks' is the number of word tokens.
-- NOTE(review): the provisional total counts all tokens while the final one
-- counts only words — confirm this asymmetry is intended.
build_punkt_data :: [Token] -> PunktData
build_punkt_data toks = PunktData
    { type_count = typecnt
    , ortho_count = build_ortho_count refined
    , collocations = build_collocs refined
    , total_enders = length [ () | (prev, _) <- zip (start : refined) refined, sentend prev ]
    , total_toks = length (filter is_word toks)
    }
    where
    typecnt = build_type_count toks
    provisional = PunktData typecnt Map.empty Map.empty 0 (length toks)
    refined = runPunkt provisional (mapM classify_by_type toks)
    -- Predecessor of the first token, flagged as a sentence end so the first
    -- token counts as following one.
    start = Token 0 0 (Word " " False) True False
-- | First-pass classification of a token from its own type statistics.
--
-- A word that carried a trailing period is marked as an abbreviation when
-- its 'prob_abbr' score is at least 0.3, and as a sentence end when the
-- score is below 0.3. Any other token is returned unchanged.
classify_by_type :: Token -> Punkt Token
classify_by_type tok = case entity tok of
    Word w True -> do
        score <- prob_abbr w
        return tok { abbrev = score >= 0.3, sentend = score < 0.3 }
    _ -> return tok
-- | Second-pass classification of a token using the token that follows it.
--
-- The first equation applies only when the following token is a 'Word'.
-- Its guards are tried in order:
--
-- * If this token is an initial, it becomes an abbreviation that does not
--   end a sentence when either:
--   the (initial, next) collocation score is >= 7.88 and the next word's
--   starter score is < 30; or
--   'decide_initial_ortho' says the next word does not start a sentence.
--   Otherwise it is left as is.
-- * Else, if this token is an ellipsis or an abbreviation, 'sentend' comes
--   from 'decide_ortho' on the next word. Without a verdict, it falls back
--   to "starter score >= 30".
--
-- If no guard holds, or the next token is not a word, evaluation falls
-- through to the second equation and the token is returned unchanged.
classify_by_next :: Token -> Token -> Punkt Token
classify_by_next this (Token _ _ (Word next _) _ _)
    | is_initial this = do
        -- is_initial only holds for Word entities, so this lazy pattern
        -- cannot fail
        let Word thisinitial _ = entity this
        colo <- prob_colloc thisinitial next
        startnext <- prob_starter next
        orthonext <- decide_initial_ortho next
        return $ if (colo >= 7.88 && startnext < 30) || orthonext == Just False
            then this { abbrev = True, sentend = False}
            else this  -- keep the first-pass flags
    | entity this == Ellipsis || abbrev this = do
        ortho_says <- decide_ortho next
        prob_says <- prob_starter next
        return $ case ortho_says of
            Nothing -> this { sentend = prob_says >= 30 }
            Just bool -> this { sentend = bool }
classify_by_next this _ = return this
-- | Tokenise a corpus and classify every token.
--
-- Statistics are gathered from the corpus itself. Each token then gets the
-- first-pass 'classify_by_type'. Every token except the last also gets the
-- second-pass 'classify_by_next' against its successor. The last token has
-- no successor and keeps its first-pass classification.
--
-- An empty corpus yields no tokens; previously @last []@ crashed here.
classify_punkt :: Text -> [Token]
classify_punkt corpus = runPunkt (build_punkt_data toks) $ do
    abbrd <- mapM classify_by_type toks
    final <- Reader.zipWithM classify_by_next abbrd (drop 1 abbrd)
    -- At most one element: the last token (already type-classified), or
    -- nothing for an empty stream.
    return $ final ++ take 1 (reverse abbrd)
    where toks = to_tokens corpus
-- | Character spans @(start, end)@ of the sentences in a corpus.
--
-- Each token classified as a sentence end is paired with the token after
-- it, and the cut is placed in the gap between them:
--
-- * The sentence ends at the first whitespace in the gap, and the next one
--   starts after that whitespace run.
-- * If the gap contains no whitespace, both fall at the start of the next
--   token, so no characters are dropped.
--
-- The last span runs to the end of the corpus.
find_breaks :: Text -> [(Int, Int)]
find_breaks corpus = slices_from endpairs 0
    where
    pairs_of xs = zip xs $ drop 1 xs
    endpairs = filter (sentend . fst) . pairs_of $ classify_punkt corpus

    slices_from [] n = [(n, Text.length corpus)]
    slices_from ((endtok, nexttok):pairs) n =
        (n, endpos + end) : slices_from pairs (endpos + n')
        where
        endpos = offset endtok + toklen endtok
        gap = substring corpus endpos (offset nexttok)
        -- match_spaces reports offsets relative to the gap, so the fallback
        -- must be gap-relative too. The old absolute (endpos, endpos + 1)
        -- fallback was added to endpos a second time.
        (end, n') = fromMaybe (Text.length gap, Text.length gap) $
            match_spaces gap
-- | Characters of @c@ in the half-open range @[s, e)@. Returns an empty
-- text when @e <= s@.
substring :: Text -> Int -> Int -> Text
substring c s e = Text.take (e - s) $ Text.drop s c
-- | Locate the first run of whitespace in a text.
--
-- Returns @Just (start, end)@:
--
-- * @start@ is the index of the first whitespace character;
-- * @end@ is one past the last character of the whitespace run that
--   begins at @start@.
--
-- Returns 'Nothing' if the text contains no whitespace.
match_spaces :: Text -> Maybe (Int, Int)
match_spaces w = do
    start <- Text.findIndex isSpace w
    let run = Text.takeWhile isSpace (Text.drop start w)
    return (start, start + Text.length run)
-- | Split a corpus into sentences, using the spans from 'find_breaks'.
split_sentences :: Text -> [Text]
split_sentences corpus =
    [ substring corpus start end | (start, end) <- find_breaks corpus ]
-- | Evaluate a 'Punkt' computation against the given statistics.
runPunkt :: PunktData -> Punkt a -> a
runPunkt model action = Reader.runReader action model