{-# LANGUAGE NoMonomorphismRestriction
           , BangPatterns
           , FlexibleInstances #-}
-- | Sequence labelling with a structured perceptron: beam-search decoding
-- ('decode') and training with weight averaging and dev-set based early
-- stopping ('train').
module NLP.Perceptron.Sequence
  ( Model(..)
  , Trace
  , Options(..)
  , YMap
  , train
  , decode
  )
where

import qualified Data.Array.Unsafe as AU
import Data.Array.ST
import Data.Array.Unboxed
import qualified Data.Array as A
import qualified Data.Vector.Unboxed as V
import Control.Monad.ST
import qualified Control.Monad.ST.Lazy as LST
import qualified Control.Monad.ST.Unsafe as ST.Unsafe
import Control.Monad.Writer
import Data.STRef
import Control.Monad
import qualified Data.Map as Map
import qualified Data.IntMap as IntMap
import qualified Data.IntSet as IntSet
import NLP.Perceptron.Vector
import System.IO
import Debug.Trace
--import NLP.Perceptron.Config
import Data.List (inits,foldl',sortBy)
import Data.Ord (comparing)
import Helper.ListZipper
import qualified Data.Binary as Binary
import Helper.Utils (uniq)
import qualified NLP.Scores as Scores
import Text.Printf

-- | A trained model: the options it was trained with and the weight array.
data Model = Model { options :: Options
                   , weights :: UArray I Float
                   }

-- | An input sequence: one feature-index vector per token.
type X = [Xi]
-- | A label sequence, one label per token.
type Y = [Yi]
type Xi = V.Vector Xii
type Xii = Int
type Yi = Int
-- | Scores a local feature vector against some weights.
type Dot = Local -> Float

-- | Training and decoding options.
data Options = Options
  { oYMap          :: YMap                  -- ^ label-history feature vectors used by 'features'
  , oIndexSet      :: IntSet.IntSet         -- ^ feature indices that may select an 'oYDict' entry
  , oYDict         :: IntMap.IntMap [Yi]    -- ^ candidate labels keyed by such a feature index
  , oYs            :: [Yi]                  -- ^ default candidate labels
  , oBeam          :: !Int                  -- ^ beam width (hypotheses kept per step)
  , oRate          :: !Float                -- ^ scale factor applied to every perceptron update
  , oEpochs        :: !Int                  -- ^ maximum number of training epochs
  , oFeatBounds    :: Maybe (Int,Int)       -- ^ optional override of the feature-index range
  , oStopWinSize   :: !Int                  -- ^ number of recent epochs inspected for early stopping
  , oStopThreshold :: !Double               -- ^ stop once the relative dev-error spread is <= this
  } deriving Eq

-- | Feature vectors attached to the label history, @(zero, uni, bi)@:
-- @zero@ is used at the first position, @uni ! y1@ is indexed by the
-- previous label and @bi ! (y1, y2)@ by the two previous labels (an empty
-- @bi@ entry falls back to @uni@); see 'features'.
type YMap = (Xi,A.Array Yi Xi,A.Array (Yi,Yi) Xi)

-- NOTE(review): orphan instance; serialises the vector as a list.
instance Binary.Binary (V.Vector Int) where
  put v = Binary.put $ V.toList v
  get = V.fromList `fmap` Binary.get

-- | Only the non-zero weights are written, together with the array bounds;
-- reading rebuilds the array with zeros elsewhere.
instance Binary.Binary Model where
  put m = do
    Binary.put (options m)
    -- Binary.put (weights m)
    let (lo,hi) = bounds . weights $ m
        xs = filter (\(_,e) -> e /= 0.0) . assocs . weights $ m
    Binary.put (lo,hi)
    Binary.put xs
  -- Each @(v == v) `seq` return ()@ uses structural equality to force the
  -- freshly decoded value completely before moving on.
  get = {-# SCC "get1" #-} do
    os <- Binary.get
    (os == os) `seq` return ()
    ws <- do
      (lo,hi) <- Binary.get
      xs <- Binary.get
      (xs == xs) `seq` return ()
      return $ accumArray (+) 0 (lo,hi) $ xs
    (ws == ws) `seq` return ()
    return $ Model os ws

-- | Field-by-field serialisation; every field is forced right after decoding.
instance Binary.Binary Options where
  put (Options a b c d e f g h i j) =
         Binary.put a >> Binary.put b >> Binary.put c >> Binary.put d
      >> Binary.put e >> Binary.put f >> Binary.put g >> Binary.put h
      >> Binary.put i >> Binary.put j
  get = {-# SCC "get2" #-} do
    a <- Binary.get
    (a == a) `seq` return ()
    b <- Binary.get
    (b == b) `seq` return ()
    c <- Binary.get
    (c == c) `seq` return ()
    d <- Binary.get
    (d == d) `seq` return ()
    e <- Binary.get
    (e == e) `seq` return ()
    f <- Binary.get
    (f == f) `seq` return ()
    g <- Binary.get
    (g == g) `seq` return ()
    h <- Binary.get
    (h == h) `seq` return ()
    i <- Binary.get
    (i == i) `seq` return ()
    j <- Binary.get
    (j == j) `seq` return ()
    return $ Options a b c d e f g h i j

-- | Candidate labels for a token: the 'oYDict' entry of the first feature
-- index that is in 'oIndexSet', or 'oYs' when there is no such index or
-- no entry for it.
yDictFind :: Options -> Xi -> [Yi]
yDictFind opts fs =
  let mk  = V.find (`IntSet.member` oIndexSet opts) fs
      def = oYs opts
  in case mk of
       Just k  -> IntMap.findWithDefault def k . oYDict $ opts
       Nothing -> def

-- Decoding -----------------------------------------------------------------

-- | Label a sequence with the model's weights.
decode :: Model -> X -> Y
decode m = fst . decode' (options m) (weights m `dot`)

-- | One beam hypothesis.
data Cell = Cell { cScore :: !Float            -- accumulated score
                 , cPhi   :: Global            -- accumulated global feature vector
                 , cPath  :: Y                 -- labels so far, most recent first
                 , cStep  :: ListZipper Xi     -- position in the input
                 } deriving (Show,Eq)

-- | Beam search from a single empty hypothesis; returns the best label
-- sequence (in input order) and its global feature vector.
decode' :: Options -> Dot -> X -> (Y,Global)
decode' opts w x = bestPath opts w [Cell { cScore = 0
                                         , cPhi   = Map.empty
                                         , cPath  = []
                                         , cStep  = fromList x
                                         }]

-- | Global feature vector of input @x@ labelled with @y@: the sum of the
-- local feature vectors at every position, each computed from the label
-- history up to that position (most recent label first).
phi :: Options -> X -> Y -> Global
phi opts x y = foldl' step Map.empty (zip x histories)
  where
    -- Reversed prefixes [y1], [y2,y1], [y3,y2,y1], ... built in O(n) with
    -- sharing (same lists as @map reverse . tail . inits@).
    histories = drop 1 (scanl (flip (:)) [] y)
    step acc (xi, ys) = acc `plus` toSV (features (oYMap opts) xi ys)

{-# INLINE features #-}
-- | Local feature vector for the current label @y@ given the token features
-- and the previous labels @ys@ (most recent first): the token features are
-- prefixed with the 'YMap' vector selected by the label history.
features :: YMap -> Xi -> [Yi] -> Local
features (!zero,uni,bi) xi (y:ys) =
  case ys of
    []            -> Local y $ zero V.++ xi
    [y1]          -> Local y $ uni A.! y1 V.++ xi
    (y1 : y2 : _) ->
      let r = bi A.! (y1,y2)
      in if V.null r
           then Local y $ uni A.! y1 V.++ xi   -- no bigram vector: use the unigram one
           else Local y $ r V.++ xi
features _ _ [] = error "NLP.Perceptron.Sequence.features: empty label history"

-- | Advance all hypotheses one token at a time, keeping the 'oBeam' best
-- after each step, until the input is exhausted. Returns the final beam,
-- best first; returns @[]@ if the beam ever becomes empty (previously this
-- looped forever).
beamSearch :: Options -> Dot -> [Cell] -> [Cell]
beamSearch opts w = go
  where
    go [] = []
    go cs
      | any (atEnd . cStep) cs = cs
      | otherwise =
          go . take (oBeam opts) . sortBy (flip $ comparing cScore) $ expand cs

    -- Extend every hypothesis with every candidate label of its current token.
    expand cs =
      [ let fs = features (oYMap opts) xi (y':ys)
        in Cell { cScore = s + w fs
                , cPhi   = ph `plus` toSV fs
                , cPath  = y':ys
                , cStep  = next x
                }
      | Cell { cScore = s, cPhi = ph, cPath = ys, cStep = x } <- cs
      , let Just xi = focus x
      , y' <- yDictFind opts xi
      ]

-- | Best hypothesis of the final beam, as (labels in input order, global
-- features). Fails with a descriptive error if the beam became empty.
bestPath :: Options -> Dot -> [Cell] -> (Y, Global)
bestPath opts w xs =
  let best = case beamSearch opts w xs of
               (cell:_) -> cell
               []       -> error "NLP.Perceptron.Sequence.bestPath: empty beam (oBeam <= 0, or no candidate labels for some token)"
  in (reverse (cPath best), cPhi best)

-- Training -----------------------------------------------------------------

-- | One epoch over @ss@. For every mistake, adds
-- @oRate * (phi gold - phi predicted)@ to @params@ and @c@ times the same
-- update to @params_a@; @c@ is incremented after every example.
iter :: Options -> Int -> [(X,Y)] -> (STRef s Int, WeightsST s, WeightsST s) -> ST s ()
iter opts _ ss (c,params,params_a) =
  for_ ss $ \(x,y) -> do
    params' <- AU.unsafeFreeze params
    let (y',phi_xy') = decode' opts (params' `dot`) x
    when (y' /= y) $ do
      let phi_xy = phi opts x y
          update = (phi_xy `minus` phi_xy') `scale` oRate opts
      params `plus_` update
      c' <- readSTRef c
      params_a `plus_` (update `scale` fromIntegral c')
    -- Counts every processed example (strictly, to avoid a thunk chain);
    -- finalParams divides the accumulator by this count.
    modifySTRef' c (+1)

-- | Per-epoch (training error rate, dev error rate, 'change' value).
type Trace = [(Double, Double, Double)]

-- | Train on @ss@, using @heldout@ for early stopping. Returns the averaged
-- model and the per-epoch trace.
train :: Options -> [(X, Y)] -> [(X,Y)] -> (Model, Trace)
train opts heldout ss = LST.runST (runWriterT (run opts heldout ss))

-- | Training driver: allocates the weight and accumulator arrays, runs
-- epochs while 'continue' allows, emits one 'Trace' entry per epoch, then
-- averages the weights ('finalParams') and returns the model.
run :: Options -> [(X, Y)] -> [(X,Y)] -> WriterT Trace (LST.ST s) Model
run opts heldout ss = do
  let bs = computeBounds opts ss
  --trace ("Param vector bounds: " ++ show bs) () `seq` return ()
  params   <- st $ newArray bs 0
  params_a <- st $ newArray bs 0
  c        <- st $ newSTRef 1
  erref    <- st $ newSTRef []
  let loop i = do
        st $ iter opts i ss (c, params, params_a)
        -- The frozen arrays alias the mutable ones, so the error rates must
        -- be evaluated before the next epoch (or finalParams) mutates them.
        -- Forcing them inside this strict ST action pins that point;
        -- otherwise early stopping and the trace could see later weights.
        (err_train, err_dev) <- st $ do
          c'        <- readSTRef c
          params'   <- AU.unsafeFreeze params
          params_a' <- AU.unsafeFreeze params_a
          let w = (fromIntegral c', params', params_a')
              predict xys = [ fst . decode' opts (w `dot'`) $ x | (x,_) <- xys ]
              e_train = Scores.errorRate (concatMap snd ss) (concat $ predict ss)
              e_dev   = Scores.errorRate (concatMap snd heldout) (concat $ predict heldout)
          e_train `seq` e_dev `seq` return (e_train, e_dev)
        errs <- st $ readSTRef erref
        let errs' = (err_train, err_dev) : errs
        st $ writeSTRef erref errs'
        let ch = change (oStopWinSize opts) errs'
        tell [(err_train, err_dev, ch)]
        when (continue opts i ch) $ loop (i+1)
  loop 1
  st $ finalParams (c, params, params_a)
  arr <- st $ AU.unsafeFreeze params
  return $! Model { options = opts, weights = arr }

-- | Run a strict 'ST' action inside the lazy-ST writer monad.
st :: Monoid w => ST s a -> WriterT w (LST.ST s) a
st = lift . LST.strictToLazyST

-- | Relative spread @(max - min) / max@ of the dev error rates of the most
-- recent @winsize@ epochs (newest first in @errs@). NaN when the maximum is
-- 0 or when the window is empty (@winsize <= 0@); 'continue' treats NaN as
-- "keep training".
change :: Int -> [(Double, Double)] -> Double
change winsize errs =
  case take winsize (map snd errs) of
    []     -> 0 / 0
    window -> let mi = minimum window
                  ma = maximum window
              in (ma - mi) / ma

-- | Whether to run another epoch after epoch @i@ given the 'change' value.
continue :: Options -> Int -> Double -> Bool
continue opts i n
  | i >= oEpochs opts = False
  | i < winsize       = True
  | isNaN n           = True
  | otherwise         = n > threshold
  where
    threshold = oStopThreshold opts
    winsize   = oStopWinSize opts

-- | Replace each weight @e@ by @e - e_a / c@ in place, where @e_a@ is the
-- accumulated, count-scaled update and @c@ the example counter.
finalParams :: (STRef s Int, WeightsST s, WeightsST s) -> ST s ()
finalParams (c,params,params_a) = do
  (l,u) <- getBounds params
  c' <- fmap fromIntegral (readSTRef c)
  let k = 1 / c'   -- hoisted; same value as computing it per element
  for_ (range (l,u)) $ \i -> do
    e   <- readArray params i
    e_a <- readArray params_a i
    writeArray params i (e - (e_a * k))

-- | Bounds of the weight array: labels range over those seen in the training
-- data, and feature indices over those in the training data and in 'oYMap'
-- (or 'oFeatBounds' when given).
computeBounds :: Options -> [(X,Y)] -> (I,I)
computeBounds opts xys =
  let ((yl,xl),(yh,xh)) =
          foldl' f ((maxBound, foldl' min maxBound xis), (minBound, foldl' max minBound xis))
        . (\(xs,ys) -> zip (concat xs) (concat ys))
        . unzip
        $ xys
  in case oFeatBounds opts of
       Just (xl',xh') -> (I yl xl', I yh xh')
       Nothing        -> (I yl xl, I yh xh)
  where
    -- Strict running (min label, min feature) / (max label, max feature);
    -- folds over the token's features without allocating a new vector.
    f ((!miny,!minx),(!maxy,!maxx)) (xs,!y) =
      ( (min miny y, V.foldl' min minx xs)
      , (max maxy y, V.foldl' max maxx xs) )
    xis = let (zero,uni,bi) = oYMap opts
          in uniq . concatMap V.toList $
               [zero] ++ (filter (not . V.null) . A.elems $ bi)
                      ++ (filter (not . V.null) . A.elems $ uni)