{-| hback is a dual n-back memory test based primarily on the work of:

      Jaeggi, Buschkuehl, et al. (2008) Improving Fluid Intelligence With
      Training on Working Memory. Proceedings of the National Academy of
      Sciences of the United States of America, 105(19), 6829-6833

    Any reference in the comments to [Paper] refers to the above work.
-}
{-# OPTIONS -fbang-patterns #-}
module Main where

import Debug.Trace
import System.Exit
import IO
import System.Cmd (system)
import Directory (getDirectoryContents)
import System.Environment (getArgs, getEnv)
import System.FilePath (joinPath)
import GHC.Conc (threadDelay)
import System.Posix.Unistd (usleep)
import Data.Time.Clock.POSIX (getPOSIXTime)
import Data.List (intersperse)
import Text.Printf
import Control.Monad
import Graphics.UI.Gtk hiding (fill)
import Graphics.UI.Gtk.General.General
import Graphics.UI.Gtk.Glade
import Graphics.Rendering.Cairo
import Graphics.Rendering.Cairo.SVG
import Data.IORef
import Random
import Paths_hback

-- ========== Data ==========

-- | Position of the visual stimulus: (x, y) cell coordinates on the 3x3
-- board, each coordinate in 0..2.
type Visual = (Int, Int)

-- | An auditory stimulus, given as the path of the sound file to play.
type Audio = FilePath

-- | Outcome of one trial for one modality.
--
-- 'None' marks a trial that is not scored (the first n trials of a game,
-- which have nothing n steps back to match). The other four compare
-- "the stimulus matched the one n trials back" with "the player flagged a
-- match". The constructor order matters: 'fromEnum' of a prediction is
-- what gets written to the scores log.
data Prediction = None | TruePositive | FalsePositive | FalseNegative | TrueNegative
  deriving (Eq, Show, Enum)

-- | The n of the n-back task.
type Level = Int

-- | Trial clock: the current tick and the last tick of a trial; ticking
-- past the last tick wraps back to 0.
data Timer = Timer Frac Total

-- | Current tick within a trial.
type Frac = Int

-- | Last tick of a trial.
type Total = Int

-- | One block of trials played at a fixed level.
data Game = Game
  { gameLevel   :: Level
    -- | Visual stimulus for each trial with its match flag
    -- (Nothing for the first n trials, Just True when it repeats the
    -- stimulus n trials back, Just False otherwise).
  , gameVisuals :: [(Visual, Maybe Bool)]
    -- | Audio stimulus for each trial, flagged like 'gameVisuals'.
  , gameAudios  :: [(Audio, Maybe Bool)]
    -- | Scored visual responses, one appended per completed trial.
  , gameVPreds  :: [Prediction]
    -- | Scored audio responses, one appended per completed trial.
  , gameAPreds  :: [Prediction]
  } deriving Show

-- | Whole-program mutable state, kept in an 'IORef'.
data State = State
  { stateGames  :: [Game]  -- ^ games played so far, most recent first
  , stateTimer  :: Timer   -- ^ clock of the current trial
  , stateGUI    :: GUI     -- ^ widgets loaded from the glade file
  , totalGames  :: Int     -- ^ number of games to play before quitting
  , statePaused :: Bool    -- ^ True while the player has paused
  }

-- | Widgets the game reads from and draws to.
data GUI = GUI
  { guiWindow     :: Window
  , guiLevelLabel :: Label
  , guiScoreLabel :: Label
  , guiDrawArea   :: DrawingArea
  , guiVButton    :: ToggleButton  -- ^ "visual match" toggle
  , guiAButton    :: ToggleButton  -- ^ "audio match" toggle
  }

-- | Set to True to stop appending results to the scores file.
turnOffLogging :: Bool
turnOffLogging = False

-- |blockSize + n determines how many iterations each "game" takes;
-- 20 + n is what [Paper] used.
blockSize :: Int
blockSize = 20

-- |A trial lasts (timerFrac + 1) ticks of timerFreq ms: the stimulus is
-- shown during the first tick (500 ms) and a blank during the remaining
-- ticks (2.5 s), ~3 s per iteration (from [Paper]).
timerFrac :: Int
timerFrac = 5 -- [0..5] = 6 loops
-- | Length of one timer tick, in milliseconds.
timerFreq :: Int
timerFreq = 500 -- 500 ms

-- |totalNumGames * (blockSize + n) * 3s / 60 =~ Total Gametime
-- Based on [Paper], default is 20 to give about 20 minutes of memory training.
totalNumGames :: Int
totalNumGames = 20

-- |initLevel determines what N-Back level the game begins with; defaults to 1.
initLevel :: Level
initLevel = 1

-- | Number of trials of a game already scored (one visual prediction is
-- appended per completed trial).
gameTurn :: Game -> Int
gameTurn = length . gameVPreds

-- | A trial clock at tick 0 whose last tick is 'timerFrac'.
newTimer :: Timer
newTimer = Timer 0 timerFrac

-- | The eight outer cells of the 3x3 board. The centre cell (1,1) is left
-- out, presumably because the fixation cross is drawn there.
imageList :: IO [Visual]
imageList = return $ remove (1, 1) [(a, b) | a <- [0 .. 2], b <- [0 .. 2]]

-- | Full paths of the sound stimuli in the package's "sounds/" data
-- directory. The filter compares everything after the first character of
-- each entry with ".wav". It therefore keeps only single-character names
-- such as "a.wav".
-- NOTE(review): if any *.wav file was meant, this test is too strict.
soundList :: IO [Audio]
soundList = do
  d <- getDataDir
  let dir = joinPath [d, "sounds/"]
  l <- getDirectoryContents dir
  return $ map (dir ++) $ filter (\f -> tail f == ".wav") l

-- | Integer code of a prediction, as written to the scores log.
predictionToInt :: Prediction -> Int
predictionToInt = fromEnum

-- ========== Scoring ==========

-- | Append one visual and one audio prediction (the result of a trial).
addPredictions :: Game -> Prediction -> Prediction -> Game
addPredictions (Game l v a vp ap) vp' ap' = Game l v a (vp ++ [vp']) (ap ++ [ap'])

-- | Drop the unscored ('None') trials.
realPredictions :: [Prediction] -> [Prediction]
realPredictions = filter (/= None)

-- |gameScore vs as scores one game from its visual and audio predictions.
-- For now the score is naive: correct answers (TruePositive + TrueNegative)
-- divided by the number of scored answers over both modalities. A game
-- with no scored answers scores 0 instead of NaN.
gameScore :: [Prediction] -> [Prediction] -> Double
gameScore v' a'
  | den == 0  = 0
  | otherwise = num / den
  where
    v = realPredictions v'
    a = realPredictions a'
    correct xs = fromIntegral (length (filter (`elem` [TruePositive, TrueNegative]) xs)) :: Double
    num = correct v + correct a
    -- Counting both lists (rather than doubling one) stays right even if
    -- they ever differ in length.
    den = fromIntegral (length v + length a) :: Double

-- |chooseNextLevel old vPredictions aPredictions returns the next game level
-- based on performance in the previous game; same protocol as [Paper]:
-- fewer than 3 mistakes in each modality moves up a level, more than 5
-- mistakes in total moves down (never below 1), otherwise the level stays.
chooseNextLevel :: Int -> [Prediction] -> [Prediction] -> Int
chooseNextLevel n v a
  | m1 < 3 && m2 < 3 = inc n
  | m1 + m2 > 5      = max 1 $ dec n
  | otherwise        = n
  where
    m1 = miss v
    m2 = miss a
    miss xs = length $ filter (\x -> x == FalseNegative || x == FalsePositive) xs

-- |score trueValue guessValue classifies a response: trueValue says whether
-- the stimulus was a match, guessValue whether the player flagged one.
score :: Bool -> Bool -> Prediction
score True  True  = TruePositive
score False True  = FalsePositive
score True  False = FalseNegative
score False False = TrueNegative

-- ========== Main ==========

printUsage :: IO ()
printUsage = putStrLn "hback b n\n b is the number of tests [default=20]\n n determines the starting n-back test [default=1]"

-- | Entry point: parse arguments, load the glade UI, start the first game
-- and hand control to the GTK main loop.
main :: IO ()
main = do
  args <- getArgs
  printUsage
  (!totalNumGames', !initLevel') <- case parseArgs args of
    Just settings -> return settings
    Nothing -> do
      hPutStrLn stderr "hback: expected at most two integer arguments, each >= 1"
      exitWith (ExitFailure 1)
  initGUI
  gFile <- getDataFileName "hback.glade"
  windowXmlM <- xmlNew gFile
  let windowXml = case windowXmlM of
        Just xml -> xml
        Nothing  -> error "Can't find the glade file \"hback.glade\" in the current directory"
  window <- xmlGetWidget windowXml castToWindow "hback"
  onDestroy window mainQuit
  label <- xmlGetWidget windowXml castToLabel "testLabel"
  scLabel <- xmlGetWidget windowXml castToLabel "scoreLabel"
  img <- xmlGetWidget windowXml castToDrawingArea "drawArea"
  visualBtn <- xmlGetWidget windowXml castToToggleButton "visualBtn"
  audioBtn <- xmlGetWidget windowXml castToToggleButton "audioBtn"
  stateRef <- newIORef $
    State [] newTimer (GUI window label scLabel img visualBtn audioBtn) totalNumGames' False
  onKeyPress window (processEvent stateRef)
  widgetShowAll window
  logInitGame
  startNewGame stateRef initLevel'
  mainGUI

-- | Interpret the command line. No arguments uses the defaults, one
-- argument sets the number of games, and two set the number of games and
-- the starting level. Returns Nothing for extra arguments or for values
-- that are not whole integers >= 1. A level of 0 would make 'matchStim'
-- look back before the start of the sequence.
parseArgs :: [String] -> Maybe (Int, Level)
parseArgs [] = Just (totalNumGames, initLevel)
parseArgs [a] = do
  g <- readPositive a
  return (g, initLevel)
parseArgs [a, b] = do
  g <- readPositive a
  l <- readPositive b
  return (g, l)
parseArgs _ = Nothing

-- | Parse a whole string as an Int that is at least 1.
readPositive :: String -> Maybe Int
readPositive s = case reads s of
  [(n, "")] | n >= 1 -> Just n
  _                  -> Nothing

-- | Build a fresh game at the given level and make it the current one. This
-- resets the trial clock, updates the level label and starts the intro
-- countdown ('timerInit').
startNewGame :: IORef State -> Level -> IO ()
startNewGame stateRef level = do
  imgList <- imageList
  sndList <- soundList
  preds <- shuffledPredictions level
  visuals <- matchStim imgList level (map fst preds) []
  audios <- matchStim sndList level (map snd preds) []
  let game = Game level visuals audios [] []
  (State games _ gui tL p) <- readIORef stateRef
  writeIORef stateRef $ State (game : games) newTimer gui tL p
  labelSetText (guiLevelLabel gui) $ show level ++ "-Back Test"
  _ <- timeoutAdd (timerInit stateRef) timerFreq
  return ()

-- | Match schedule for the scored trials of one game, as (visual, audio)
-- flags: 2 dual matches, 4 visual-only, 4 audio-only, and the remaining
-- blockSize - 10 trials with no match (none when blockSize < 10).
makePredictions :: [(Maybe Bool, Maybe Bool)]
makePredictions =
     replicate 2 (Just True, Just True)
  ++ replicate 4 (Just True, Just False)
  ++ replicate 4 (Just False, Just True)
  ++ replicate (blockSize - 10) (Just False, Just False)

-- | 'level' unscored warm-up trials followed by a random permutation of
-- 'makePredictions'.
shuffledPredictions :: Level -> IO [(Maybe Bool, Maybe Bool)]
shuffledPredictions level = do
  let preds = makePredictions
  rands <- getRandomDecList (length preds - 1)
  return $ replicate level (Nothing, Nothing) ++ shuffle preds rands

-- | Pick a concrete stimulus for each match flag. The accumulator holds the
-- stimuli chosen so far, most recent first, so element (level - 1) of it is
-- the stimulus n trials back.
--   Nothing    -> any stimulus
--   Just True  -> repeat the stimulus n trials back
--   Just False -> any stimulus except the one n trials back
-- Assumes level >= 1 and that every Just flag is preceded by at least
-- 'level' entries ('shuffledPredictions' ensures this).
matchStim :: Ord a => [a] -> Int -> [Maybe Bool] -> [(a, Maybe Bool)] -> IO [(a, Maybe Bool)]
matchStim _ _ [] acc = return $ reverse acc
matchStim orig level (p : ps) acc = do
  e <- case p of
    Nothing -> do
      e' <- randomElem orig
      return (e', Nothing)
    Just True -> do
      let e = fst $ head $ drop (dec level) acc
      return (e, Just True)
    Just False -> do
      let e = fst $ head $ drop (dec level) acc
      e' <- randomElem (remove e orig)
      return (e', Just False)
  matchStim orig level ps (e : acc)

-- | Print each game's level and score in play order, then quit.
endGame :: IORef State -> IO ()
endGame stateRef = do
  state <- readIORef stateRef
  putStrLn "Game finished"
  mapM_ report (reverse (stateGames state))
  mainQuit
  exitWith ExitSuccess
  where
    report (Game level _ _ vp ap) =
      putStrLn ("Level " ++ show level ++ " : " ++ show (gameScore vp ap))

-- ========== Timers and Events ==========

-- | Intro countdown run before each game. Tick 0 draws the "Ready?"
-- screen. On the last tick it installs the trial timer ('timer') and
-- removes itself by returning False. Nothing advances while paused.
timerInit :: IORef State -> IO Bool
timerInit stateRef = readIORef stateRef >>= step
  where
    step state@(State _ (Timer t tt) gui _ _)
      | statePaused state = -- game is paused
          return True
      | t == 0 = do
          renderImage (guiDrawArea gui) renderNewGame
          stateTick stateRef
          return True
      | t == tt = do
          _ <- timeoutAdd (timer stateRef) timerFreq
          stateTick stateRef
          return False
      | otherwise = do
          stateTick stateRef
          return True

-- | Trial timer callback; see timer'.
timer :: IORef State -> IO Bool
timer stateRef = readIORef stateRef >>= timer' stateRef

-- | One tick of the trial loop for the current (head) game.
--   tick 0        : show the square and play the sound, clear both toggles
--   tick 1        : blank the board
--   last tick     : score the toggles against the match flags
-- When all trials are done it logs the game and either ends the program
-- or starts the next game at the level chosen by 'chooseNextLevel'. It
-- then returns False to remove this timer. Requires at least one game in
-- the state, which 'startNewGame' guarantees.
timer' :: IORef State -> State -> IO Bool
timer' stateRef state@(State games@(game : prevGames) tm@(Timer t tt) gui total paused)
  | statePaused state = -- game is paused
      return True
  | turn >= blockSize + gameLevel game = do -- current game finished
      logGame stateRef
      if length games >= totalGames state
        then endGame stateRef
        else startNewGame stateRef
               (chooseNextLevel (gameLevel game) (gameVPreds game) (gameAPreds game))
      return False
  | otherwise = do
      let (vZ, vB) = gameVisuals game !! turn
          (aZ, aB) = gameAudios game !! turn
      case t of
        0 -> do
          renderImage (guiDrawArea gui) $ renderRect vZ
          playSound aZ
          toggleButtonSetActive (guiVButton gui) False
          toggleButtonSetActive (guiAButton gui) False
        1 -> renderImage (guiDrawArea gui) renderBlank
        _ -> when (t == tt) $ do
          (vs', as') <- case (vB, aB) of
            (Just vB', Just aB') -> do
              b1 <- toggleButtonGetActive (guiVButton gui)
              b2 <- toggleButtonGetActive (guiAButton gui)
              return (score vB' b1, score aB' b2)
            _ -> return (None, None)
          writeIORef stateRef $
            State (addPredictions game vs' as' : prevGames) tm gui total paused
      stateTick stateRef
      return True
  where
    turn = gameTurn game

-- | Advance the stored trial clock by one tick.
stateTick :: IORef State -> IO ()
stateTick stateRef = do
  (State g' t' gui' tg' p') <- readIORef stateRef
  writeIORef stateRef $ State g' (tick t') gui' tg' p'

-- |processEvent stateRef event handles key presses:
--   'p'                 toggles pause (drawing the pause or blank screen)
--   'l' or Left arrow   toggles the audio ("Sound") button
--   'a' or Right arrow  toggles the visual ("Graphic") button
-- The arrow keys are the ones advertised on the "Ready?" screen. They
-- carry no character, so they are matched by key name.
-- Returns True when the key was handled.
processEvent :: IORef State -> Event -> IO Bool
processEvent stateRef (Key {eventKeyName = keyName, eventKeyChar = char}) = do
  (State g t gui tt p) <- readIORef stateRef
  case (char, keyName) of
    (Just 'p', _) -> do
      if p
        then renderImage (guiDrawArea gui) renderBlank
        else renderImage (guiDrawArea gui) renderPause
      writeIORef stateRef $ State g t gui tt $ not p
      return True
    (Just 'l', _) -> flipToggle (guiAButton gui)
    (_, "Left")   -> flipToggle (guiAButton gui)
    (Just 'a', _) -> flipToggle (guiVButton gui)
    (_, "Right")  -> flipToggle (guiVButton gui)
    _             -> return False
  where
    flipToggle btn = do
      active <- toggleButtonGetActive btn
      toggleButtonSetActive btn (not active)
      return True
processEvent _ _ = return False

-- | Advance a clock by one tick, wrapping to 0 after its last tick.
tick :: Timer -> Timer
tick (Timer t tt)
  | t' > tt   = Timer 0 tt
  | otherwise = Timer t' tt
  where
    t' = inc t

-- ========== Rendering ==========

-- | "Ready?" screen with the key bindings, on a black background.
renderNewGame :: Int -> Int -> Render ()
renderNewGame w' h' = do
  setSourceRGB 0 0 0
  paint
  setSourceRGB 1 1 1
  setFontSize 30
  moveTo (w / 2 - 70) (h / 4)
  showText "Ready?"
  setFontSize 20
  moveTo (w / 2 - 130) (h / 6 * 3)
  showText "LeftArrow -> Sound"
  moveTo (w / 2 - 130) (h / 6 * 4)
  showText "RightArrow -> Graphic"
  moveTo (w / 2 - 130) (h / 6 * 5)
  showText " 'p' -> Pause"
  where
    w = fromIntegral w' :: Double
    h = fromIntegral h' :: Double

-- | Pause screen: an SVG containing "[Paused]", sized to the largest square
-- that fits the drawing area; c * 6 (presumably the font size) grows with
-- that square.
-- NOTE(review): the printf format strings here are empty although
-- arguments are passed to them; Text.Printf raises "formatting string ended
-- prematurely" in that case. The SVG markup looks stripped from this copy
-- of the source and should be restored from upstream hback.
renderPause :: Int -> Int -> Render ()
renderPause wU hU = do
  svgRenderFromString s
  where
    s = (printf "" w h) ++ (printf "" w h) ++ (printf "" (c * 6)) ++ "[Paused]" ++ ""
    w = min wU hU
    h = w -- make sure w and h make square
    c = (w `div` 150) * 5 -- a multiplier when window gets resized

-- | Empty board (fixation cross only).
renderBlank :: Int -> Int -> Render ()
renderBlank w h = renderCross w h Nothing

-- | Board with the stimulus square in cell (x, y).
renderRect :: Visual -> Int -> Int -> Render ()
renderRect (x, y) w h = renderCross w h $ Just (x, y)

-- | Draw the 3x3 board as SVG in the largest square that fits: the cross in
-- the centre cell and, if given, a square in cell (x, y). Each cell is
-- sq pixels wide with a margin of sq/10. Shapes are line = 8 * marg long.
-- NOTE(review): the printf format strings are empty in this copy while
-- arguments are still passed, which fails at runtime. From the argument
-- order they look like <rect> elements taking width, height, x, y. Restore
-- the markup from upstream hback.
renderCross :: Int -> Int -> Maybe Visual -> Render ()
renderCross wU hU m = svgRenderFromString s
  where
    s = (printf "" w h) ++ (printf "" w h) ++ (printf "" marg line x2 y2)
        ++ squareString ++ (printf "" line marg y2 x2)
    w = min wU hU
    h = w -- make sure w and h make square
    sq = w `div` 3
    marg = sq `div` 10
    line = marg * 8
    x2 = sq + (sq `div` 2) - (marg `div` 2)
    y2 = sq + marg
    squareString = case m of
      Nothing -> ""
      Just (x, y) ->
        let x1 = marg + (sq * x)
            y1 = marg + (sq * y)
        in printf ("") line line x1 y1

-- | Run a size-dependent Cairo drawing on the drawing area's window.
renderImage :: DrawingArea -> (Int -> Int -> Render ()) -> IO ()
renderImage drawArea img = do
  (w, h) <- widgetGetSize drawArea
  drawin <- widgetGetDrawWindow drawArea
  renderWithDrawable drawin $ img w h
  return ()

-- | Play a sound file with mplayer in the background (stdout discarded).
-- The path is single-quoted so data directories containing spaces or
-- shell metacharacters still work.
playSound :: Audio -> IO ()
playSound f = do
  _ <- system $ "mplayer " ++ shellQuote f ++ "> /dev/null &"
  return ()

-- | Quote a string as a single /bin/sh word: wrap it in single quotes and
-- turn each embedded ' into '\''.
shellQuote :: String -> String
shellQuote s = "'" ++ concatMap escape s ++ "'"
  where
    escape '\'' = "'\\''"
    escape c    = [c]

-- ========== Utils ==========

-- | Path of the scores log: $HOME/.hback.scores.db.
getLogFile :: IO FilePath
getLogFile = do
  l <- getEnv "HOME"
  return $ joinPath [l, ".hback.scores.db"]

-- | Append the current POSIX time to the scores log (session marker).
logInitGame :: IO ()
logInitGame =
  unless turnOffLogging $ do
    t <- getPOSIXTime
    f <- getLogFile
    bracket (openFile f AppendMode) hClose (\h -> hPrintf h "%s\n" $ show t)

-- | Append the current game's level and its visual and audio prediction
-- codes (space-separated 'predictionToInt' values) to the scores log.
logGame :: IORef State -> IO ()
logGame stateRef =
  unless turnOffLogging $ do
    state <- readIORef stateRef
    f <- getLogFile
    let game = head $ stateGames state
        showPreds = unwords . map (show . predictionToInt)
    bracket (openFile f AppendMode) hClose $ \h ->
      hPrintf h "Level %d\n%s\n%s\n" (gameLevel game)
        (showPreds (gameVPreds game)) (showPreds (gameAPreds game))

inc :: Int -> Int
inc = (1 +)

dec :: Int -> Int
dec n = n - 1

-- | Naive list shuffle: shuffle elems choices repeatedly takes out the
-- element at the next index. It requires length choices == length elems - 1
-- and each index to be smaller than the number of elements left.
shuffle :: [a] -> [Int] -> [a]
shuffle [] _ = []
shuffle [e] [] = [e]
shuffle elems (x : xs)
  | x >= length elems = error "shuffle: index too large"
  | otherwise =
      let (a, e : rest) = splitAt x elems
      in e : shuffle (a ++ rest) xs

-- | Random indices for 'shuffle': the k-th element is drawn from [0, n - k],
-- giving n elements in total. Non-positive n yields [].
getRandomDecList :: Int -> IO [Int]
getRandomDecList n
  | n <= 0 = return []
  | otherwise = do
      r <- getStdRandom $ randomR (0, n)
      rs <- getRandomDecList $ dec n
      return $ r : rs

-- | Uniformly random element of a non-empty list.
randomElem :: [a] -> IO a
randomElem lst = do
  i <- getStdRandom (randomR (0, dec (length lst)))
  return $ lst !! i

-- | Remove every occurrence of a value.
remove :: Ord a => a -> [a] -> [a]
remove a = filter (/= a)