{-# LANGUAGE OverloadedStrings #-}
module NLP.Brillig.Brill where

import Control.Arrow ( first, second, (***) )
import Data.Function ( on )
import qualified Data.Text as T
import qualified Data.Map as Map
import Data.List ( isPrefixOf, delete, foldl', maximumBy, sort, sortBy )
import Data.Ord ( compare )
import Data.Map ( Map )
import Data.Text ( Text )
import qualified Data.List.Zipper as Z
import Data.List.Zipper ( Zipper )
import qualified Data.Set as Set
import Data.Set ( Set )

import NLP.Brillig
import NLP.Brillig.Util

-- | A learned rule: rewrite one tag to another when it is preceded by a
-- given sequence of tags.
data Transform = Transform
  { context :: [Tag]       -- ^ preceding tags, nearest first (a backwards list!)
  , replace :: Replacement -- ^ the tag rewrite to perform
  , tscore  :: Int         -- ^ rule score; 'brilltag' applies higher scores first
  } deriving (Ord, Eq)

-- | Space-separated words: the context tags (farthest first), the from tag,
-- the to tag, and finally the score.
instance Show Transform where
  show (Transform c (Replacement f t) i) = unwords (twords ++ [show i])
    where
      twords = map (T.unpack . fromTag) (reverse c ++ [f, t])

-- | Inverse of the 'Show' instance: the last three words are the from tag,
-- the to tag and the score; any words before them form the context.
-- Succeeds only when the score word parses completely.
instance Read Transform where
  readsPrec p s =
    case reverse (words s) of
      (i:t:f:ctx) ->
        case readsPrec p i of
          [(n, "")] -> [ (Transform (map toTag ctx) (repl f t) n, "") ]
          _         -> []
      _ -> []
    where
      repl f t = Replacement (toTag f) (toTag t)
      toTag    = Tag . T.pack

-- | Rewrite tag 'from' to tag 'to'.
data Replacement = Replacement
  { from :: Tag
  , to   :: Tag
  } deriving (Show, Ord, Eq)

-- | A count per tag.
type TCount = Map Tag Int

-- | A token's tag in the current best guess alongside its correct tag.
data TagPair = TagPair
  { proposed :: Tag -- ^ tag from the current best guess
  , actual   :: Tag -- ^ correct tag from the training corpus
  } deriving (Ord, Eq)

-- ----------------------------------------------------------------------
-- tag
-- ----------------------------------------------------------------------

-- | Retag a sentence by applying every rule in turn, highest 'tscore' first
-- (the order in which greedy learning discovers them).
brilltag :: [Transform] -> [Tagged Text] -> [Tagged Text]
brilltag rules_ = retag applyAll
  where
    -- descending score; sortBy is stable, so equal scores keep input order
    rules = sortBy (flip compare `on` tscore) rules_
    -- a left fold applies the head of 'rules' first; a foldr here would
    -- apply the lowest-scoring rule first instead
    applyAll ts = foldl' (flip tagOne) ts rules

-- | Apply a single transformation to a tag sequence, scanning left to right.
--
-- A tag equal to the rule's 'from' tag is rewritten to its 'to' tag when the
-- rule's 'context' is a prefix of the tags to its left (nearest first).
-- Rewrites take effect immediately, so positions further right see already
-- rewritten tags as their context.
tagOne :: Transform -> [Tag] -> [Tag]
tagOne x = Z.toList . walk . Z.fromList
  where
    from_ = from (replace x)
    to_   = to (replace x)
    walk z@(Z.Zip _ [])     = z
    walk z@(Z.Zip ls (r:_)) = walk (Z.right next)
      where
        next
          | context x `isPrefixOf` ls && r == from_ = Z.replace to_ z
          | otherwise                               = z

-- ----------------------------------------------------------------------
-- train
-- ----------------------------------------------------------------------

-- | Greedily learn rules until the best candidate no longer scores above the
-- floor. After each rule is learned, it is applied to the best guess before
-- the next one is searched for. Rules are returned in the order learned.
-- Stops early (rather than crashing) when 'canLearn' reports that no
-- candidate rule exists.
learnConverge :: Int   -- ^ floor: stop once the best rule scores at or below this
              -> [Tag] -- ^ corpus
              -> [Tag] -- ^ best guess
              -> [Transform]
learnConverge minScore cs bs
  | not (canLearn bs cs) = []
  | tscore r > minScore  = r : learnConverge minScore cs (tagOne r bs)
  | otherwise            = []
  where
    r = learnOne bs cs

-- | Greedily learn (at most) @n@ rules, applying each one to the best guess
-- before learning the next. The result lists the most recently learned rule
-- first. A non-positive @n@ yields no rules, and learning stops early if
-- 'canLearn' reports that no candidate rule exists.
learnN :: Int   -- ^ number of rules to learn
       -> [Tag] -- ^ corpus
       -> [Tag] -- ^ best guess
       -> [Transform]
learnN = go []
  where
    go acc n cs bs
      | n <= 0 || not (canLearn bs cs) = acc
      | otherwise =
          let r   = learnOne bs cs
              bs2 = tagOne r bs
          in go (r : acc) (n - 1) cs bs2

-- | Whether 'learnOne' has at least one candidate to score: it needs two
-- distinct tags (across both lists) and at least two aligned tokens.
canLearn :: [Tag] -- ^ best guess
         -> [Tag] -- ^ corpus
         -> Bool
canLearn best corpus =
  Set.size (Set.fromList best `Set.union` Set.fromList corpus) >= 2
    && not (null (drop 1 (zip best corpus)))

-- | Find the single best-scoring transformation for the current guess.
--
-- This is one greedy step: it neither applies the rule nor relearns; see
-- 'learnConverge' and 'learnN' for the iterated versions.
--
-- Partial: errors (via 'maximumBy') when fewer than two distinct tags occur
-- or there are fewer than two aligned tokens; 'canLearn' checks this.
learnOne :: [Tag] -- ^ best guess
         -> [Tag] -- ^ corpus
         -> Transform
learnOne best corpus = bestTransform tags pairs
  where
    tags  = Set.fromList best `Set.union` Set.fromList corpus
    pairs = zipWith TagPair best corpus

-- | Best transformation over every possible replacement between the given
-- tags, each paired with its best single-tag context.
bestTransform :: Set Tag -> [TagPair] -> Transform
bestTransform tags xs = answer
  where
    answer = maximumBy (compare `on` tscore) (map best repls)
    best r = bestInstance r dhist
    repls  = Set.toList (replacements tags)
    dhist  = deltaHistogram xs

-- | Every replacement between two distinct tags of the set.
replacements :: Set Tag -> Set Replacement
replacements tags =
  Set.fromList [ Replacement f t | f <- xs, t <- delete f xs ]
  where
    xs = Set.toList tags

-- | The best context for a fixed replacement: try each preceding tag recorded
-- in the histogram and keep the one where the replacement scores highest.
bestInstance :: Replacement -> DeltaHistogram -> Transform
bestInstance repl dhist = answer
  where
    -- We compare (score, tag) rather than the score alone so that ties are
    -- broken consistently by tag; not that it matters much, just trying to
    -- future-proof against changes in the Set implementation.
    answer = toTransform $ maximumBy (compare `on` swap) tmap
    toTransform (t, i) = Transform [t] repl i
    swap (x, y) = (y, x)
    tmap = Map.toList $ Map.map (score repl) dhist

-- | Keyed by the proposed tag of a token; the inner map counts the
-- ('proposed', 'actual') pairs of the tokens that immediately follow it.
type DeltaHistogram = Map Tag (Map TagPair Int)
-- | How many times each ('proposed', 'actual') pair follows each context,
-- where the context is the proposed tag of the immediately preceding token.
deltaHistogram :: [TagPair] -> DeltaHistogram
deltaHistogram xs = Map.map histogram (Map.fromListWith (++) neighbours)
  where
    -- every token after the first, keyed by its predecessor's proposed tag
    neighbours = [ (proposed prev, [cur]) | (prev, cur) <- zip xs (drop 1 xs) ]

-- | Net benefit of a replacement within one context: the number of tokens it
-- would fix (proposed 'from', actually 'to') minus the number it would break
-- (proposed 'from', and 'from' is already correct).
score :: Replacement -> Map TagPair Int -> Int
score r counts = fixed - broken
  where
    lookupCount k = Map.findWithDefault 0 k counts
    fixed  = lookupCount (TagPair (from r) (to r))
    broken = lookupCount (TagPair (from r) (from r))

-- | Component-wise sum of two pairs.
plusPair :: Num a => (a, a) -> (a, a) -> (a, a)
plusPair (a, b) (c, d) = (a + c, b + d)