{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- | Estimation driver: holds the estimation 'State', performs single
-- estimation rounds (JML or the 'LBFGSB' algorithm) and repeats them in
-- 'runEngine' until the round limit or the precision target is reached.
module Engine where

import Irt
import Statistics
import Types
import qualified Listable as L

import Data.List
import qualified Data.Map as M
import qualified Data.Vector.Generic as V
import Data.VectorSpace
import Text.Printf

-- | Complete estimation state carried from one round to the next.
data State = State
  { responses :: Responses
    -- ^ Observed responses; also fixes the task and contestant orderings.
  , params :: TaskParams
    -- ^ Current task parameters, positionally aligned with @tasks responses@.
  , thetas :: Thetas
    -- ^ Current thetas, positionally aligned with @contestants responses@.
  , logLikelihood :: Double
    -- ^ Log-likelihood as computed by 'updateOne' (@-infinity@ initially).
  , dparam :: TaskParam
    -- ^ Per-component maximum absolute task-parameter change in the last round.
  , dtheta :: Double
    -- ^ Maximum absolute theta change in the last round.
  , updatedTheta :: Bool
    -- ^ Whether the last step re-estimated the thetas.
  , updatedParams :: Bool
    -- ^ Whether the last step re-estimated the task parameters.
  }
  deriving (Eq, Ord, Show)

-- | Recompute 'logLikelihood' from the state's responses, task parameters
-- and thetas using 'totalLogLikelihood'.
updateOne :: State -> State
updateOne s@State {..} = s { logLikelihood = lh }
  where
    lh = totalLogLikelihood responses params thetas

-- | Fill in the convergence deltas of @new@ by comparing it with its
-- predecessor @old@, then refresh its log-likelihood.
--
-- * 'dparam': for every component of a task parameter (as exposed by
--   'L.toList'), the largest absolute change over all tasks.
-- * 'dtheta': the largest absolute change over all contestants, or @0@ when
--   there are no contestants (previously this crashed on an irrefutable
--   @[dt]@ pattern).
updateTwo :: State -> State -> State
updateTwo old new =
  updateOne $ new { dparam = L.fromList dp, dtheta = dt }
  where
    dt = case diff (thetas old) (thetas new) of
      [] -> 0 -- no contestants: nothing changed
      [x] -> x
      _ -> error "Engine.updateTwo: expected one-dimensional thetas"
    dp = diff (params old) (params new)
    -- Element-wise difference, then per-component maximum absolute value.
    -- 'transpose' never yields empty rows, so 'maximum' is safe here.
    diff xs ys = (max' . trans) (sub xs ys)
      where
        max' :: [[Double]] -> [Double]
        max' = map (maximum . map abs)
        trans = transpose . map L.toList . V.toList
        sub = V.zipWith (^-^)

-- | Pair every task of the responses with its current parameter, by position.
getTaskParamsList :: State -> [(Task, TaskParam)]
getTaskParamsList State {..} = zip (V.toList $ tasks responses) (V.toList params)

-- | Pair every contestant of the responses with its current theta, by position.
getThetasList :: State -> [(Contestant, Theta)]
getThetasList State {..} = zip (V.toList $ contestants responses) (V.toList thetas)

-- | Positive floating-point infinity.
infinity :: Double
infinity = 1.0 / 0.0

-- | Initial state for the given responses: every task starts at
-- 'defaultTask', every theta at @0@, the log-likelihood at @-infinity@ and
-- both deltas at 'infinity'. The infinite deltas presumably keep the first
-- convergence check in 'runEngine' from succeeding. Both update flags start
-- as 'False', so 'oneJML' begins by re-estimating thetas.
init :: Responses -> State
init (responses@Responses {..}) =
  State
    { responses = responses
    , params = V.replicate (V.length tasks) defaultTask
    , thetas = V.replicate (V.length contestants) 0
    , logLikelihood = -infinity
      -- NOTE(review): relies on 'L.fromList' taking only a finite prefix of
      -- this infinite list — confirm against the Listable instance.
    , dparam = L.fromList $ repeat infinity
    , dtheta = infinity
    , updatedTheta = False
    , updatedParams = False
    }

-- | Settings controlling 'runEngine'.
data EngineParams = EngineParams
  { algorithm :: Algorithm
    -- ^ Which round implementation 'oneRound' dispatches to.
  , maxRounds :: Int
    -- ^ Maximum number of rounds 'runEngine' performs.
  , precision :: Double
    -- ^ Stopping threshold for @magnitude dparam + dtheta@; also passed
    -- to 'estimateBfgs'.
  , intPrec :: Double
    -- ^ Not read in this module.
  }

-- | Replace all thetas with the values given per contestant, in the
-- contestant order of the responses, and set 'updatedTheta'.
-- Fails (via 'M.!') if some contestant of the responses is missing from @xs@.
updateThetas :: State -> [(Contestant, Theta)] -> State
updateThetas s@State {..} xs =
  s
    { thetas = V.fromList . map (ts M.!) . V.toList $ contestants responses
    , updatedTheta = True
    }
  where
    ts = M.fromList xs

-- | Replace all task parameters with the values given per task, in the task
-- order of the responses, and set 'updatedParams'.
-- Fails (via 'M.!') if some task of the responses is missing from @xs@.
updateParams :: State -> [(Task, TaskParam)] -> State
updateParams s@State {..} xs =
  s
    { params = V.fromList . map (ps M.!) . V.toList $ tasks responses
    , updatedParams = True
    }
  where
    ps = M.fromList xs

-- | Re-estimate the task parameters with 'estimateAB' on the responses
-- grouped by task, keeping thetas fixed; marks the parameters as updated.
enhanceTaskParams :: EngineParams -> State -> State
enhanceTaskParams EngineParams {..} s@State {..} =
  s
    { params = estimateAB $ groupByTask responses thetas params
    , updatedTheta = False
    , updatedParams = True
    }

-- | Re-estimate the thetas with 'estimateTheta' on the responses grouped by
-- contestant, keeping task parameters fixed; marks the thetas as updated.
enhanceThetas :: EngineParams -> State -> State
enhanceThetas EngineParams {..} s@State {..} =
  s
    { thetas = estimateTheta $ groupByContestant responses thetas params
    , updatedTheta = True
    , updatedParams = False
    }

-- | One JML half-step: alternates between re-estimating task parameters
-- (after a theta step) and re-estimating thetas (otherwise).
oneJML :: EngineParams -> State -> State
oneJML engine s@State { updatedTheta }
  | updatedTheta = enhanceTaskParams engine s
  | otherwise = enhanceThetas engine s

-- | One joint step: re-estimates thetas and task parameters together via
-- 'estimateBfgs' with the configured precision.
oneBFGS :: EngineParams -> State -> State
oneBFGS EngineParams {..} s@State {..} =
  s
    { thetas = ts
    , params = ps
    , updatedTheta = True
    , updatedParams = True
    }
  where
    (ts, ps) = estimateBfgs precision responses thetas params

-- | One round of the configured 'algorithm'.
oneRound :: EngineParams -> State -> State
oneRound e@EngineParams { algorithm = JML } = oneJML e
oneRound e@EngineParams { algorithm = LBFGSB } = oneBFGS e

-- | The \"C\" value reported for a parameter delta in 'runEngine':
-- @paramC (a ^+^ a) - paramC a@ rather than @paramC a@ itself.
-- NOTE(review): presumably because 'paramC' is not linear in the stored
-- representation — confirm the intended meaning.
dparamC :: TaskParam -> Double
dparamC a = paramC (a ^+^ a) - paramC a

-- | Repeatedly apply 'oneRound' (followed by 'updateTwo') to the given state.
--
-- Stops and returns the current state once more than 'maxRounds' rounds
-- would be needed, or when @magnitude dparam + dtheta@ falls below
-- 'precision'. Before each round it prints the log-likelihood and the
-- current deltas to stdout.
--
-- NOTE(review): under JML each round updates only thetas or only task
-- parameters, so one of the two deltas is typically zero in any given
-- round — verify this stopping criterion is intended.
runEngine :: EngineParams -> State -> IO State
runEngine engine@EngineParams {..} s = loop 1 s
  where
    -- The round counter is the only loop state; earlier states are not kept,
    -- so each one becomes garbage as soon as its successor is computed.
    loop :: Int -> State -> IO State
    loop n state@State {..}
      | n > maxRounds = return state
      | magnitude dparam + dtheta < precision = return state
      | otherwise = do
          printf "L:%+14.8f " logLikelihood :: IO ()
          printf "dA:%11.8f dB:%11.8f dC:%11.8f "
            (paramA dparam) (paramB dparam) (dparamC dparam) :: IO ()
          printf "dT:%11.8f\n" dtheta :: IO ()
          loop (n + 1) $ updateTwo state . oneRound engine $ state

-- | Compute the requested statistic over the responses grouped by contestant.
thetaStatistic :: StatisticType -> State -> IO Statistic
thetaStatistic t State {..} = statTheta t $ groupByContestant responses thetas params

-- | Compute the requested statistic over the responses grouped by task.
taskStatistic :: StatisticType -> State -> IO Statistic
taskStatistic t State {..} = statTask t $ groupByTask responses thetas params