{-
    Copyright 2010 Mario Blazevic

    This file is part of the Streaming Component Combinators (SCC) project.

    The SCC project is free software: you can redistribute it and/or modify it under the terms of the GNU General Public
    License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
    version.

    SCC is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
    of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along with SCC.  If not, see
    <http://www.gnu.org/licenses/>.
-}

-- | Tests for "Control.Monad.Coroutine.Iteratee": the same list is summed four ways, converting coroutines to
-- iteratees/enumerators and back, and the four results are checked to agree.

module Main where

import Control.Exception (assert)
import Control.Exception.Base (SomeException)
import Control.Monad (liftM)
import qualified Data.List as List
import Data.Maybe (fromJust)
import System.Environment (getArgs)

import Debug.Trace

import Data.Iteratee (Enumerator, Iteratee(..), Stream(..))
import Data.Iteratee (enumEof, (>>>))
import qualified Data.Iteratee.ListLike as Iteratee
import Data.Functor.Identity (runIdentity)

import Control.Monad.Coroutine
import Control.Monad.Coroutine.Iteratee
import Control.Monad.Coroutine.SuspensionFunctors (Await(Await), Yield(Yield), await, yield, awaitYieldResolver)
import Control.Monad.Parallel

-- | A coroutine that repeatedly 'await's a list of integer groups and adds all of their elements to a running total.
-- When it receives an empty list it stops and returns @Right (total, [])@.
sumCoroutine :: Monad m => Coroutine (Await [[Integer]]) m (Either SomeException (Integer, [[Integer]]))
sumCoroutine = sum' 0
   where sum' s = do
            ns <- await
            if null ns
               then return (Right (s, []))
               -- Force the running total on every step; otherwise a long input builds a chain of (+) thunks.
               else sum' $! s + List.sum (List.concat ns)

-- | Yields the given integers group by group, each group wrapped in a singleton list.
--
-- Groups are maximal runs of adjacent elements with the same last digit (@`mod` 10@).
-- NOTE(review): for consecutive input such as @[1 .. n]@ adjacent elements never share a last digit, so every group
-- is a singleton. If chunks of ten were intended (matching the @enumPureNChunk list 10@ used below), @div@ rather than
-- @mod@ would be needed — confirm. The resulting sum is the same either way.
yieldAll :: Monad m => [Integer] -> Coroutine (Yield [[Integer]]) m ()
yieldAll = mapM_ (yield . (:[])) . List.groupBy (\m n -> m `mod` 10 == n `mod` 10)

-- | 'awaitYieldResolver' with 'resumeLeft' overridden so that a suspended 'Await' on the left is continued with an
-- empty list, which 'sumCoroutine' treats as the end of its input. Presumably this fires once the right-hand producer
-- has finished — see the resolver documentation to confirm.
listResolver = awaitYieldResolver{resumeLeft= \(Await c) -> c []}

-- | Extracts the sum from a 'seesaw' result whose first component is the summing side's @Either@ outcome.
-- Fails with a message prefixed by the given label if that outcome is a 'Left'.
extractSum :: String -> (Either e (Integer, r), x) -> Integer
extractSum _ (Right (s, _), _) = s
extractSum label (Left _, _) = error (label ++ ": the summing coroutine returned Left")

-- | Coroutine -> Iteratee: feeds the list in chunks of 10, then EOF, to 'sumCoroutine' converted into an iteratee.
testSumCI :: Monad m => [Integer] -> m Integer
-- Earlier formulation against the old Enumerator API, kept for reference:
-- testSumCI list = liftM (\(Enumerator.Yield s _)-> s) $
--    runIter =<< ((Iteratee.enumPureNChunk list 10 >>> enumEof) $ coroutineIteratee sumCoroutine)
testSumCI list = do
   i <- (Iteratee.enumPureNChunk list 10 >>> enumEof) $ coroutineIteratee sumCoroutine
   runIter i (\total _ -> return total)
             (error "testSumCI: iteratee did not complete")

-- | Enumerator -> Coroutine: converts the chunking enumerator (chunks of 10, then EOF) into a coroutine and runs
-- 'sumCoroutine' against it with 'seesaw'.
testSumEC :: MonadParallel m => [Integer] -> m Integer
testSumEC list = liftM (extractSum "testSumEC") $
                 seesaw bindM2 listResolver sumCoroutine
                        (enumeratorCoroutine (Iteratee.enumPureNChunk list 10 >>> enumEof))

-- | Coroutine -> Enumerator: converts 'yieldAll' into an enumerator and runs it, followed by EOF, into 'Iteratee.sum'.
testSumCE :: Monad m => [Integer] -> m Integer
-- Earlier formulation against the old Enumerator API, kept for reference:
-- testSumCE list = liftM (\(Enumerator.Yield s _)-> s) $
--    runIter =<< ((coroutineEnumerator (yieldAll list) >>> enumEof) $ Iteratee.sum)
testSumCE list = do
   i <- (coroutineEnumerator (yieldAll list) >>> enumEof) $ Iteratee.sum
   runIter i (\total _ -> return total)
             (error "testSumCE: iteratee did not complete")

-- | Iteratee -> Coroutine: converts 'Iteratee.sum' into a coroutine and runs it against 'yieldAll' with 'seesaw'.
testSumIC :: MonadParallel m => [Integer] -> m Integer
testSumIC list = liftM (extractSum "testSumIC") $
                 seesaw bindM2 listResolver (iterateeCoroutine Iteratee.sum) (yieldAll list)

-- | Runs all four conversion tests on the same list and returns the common sum, failing if any of them disagree.
--
-- This is an explicit check rather than 'assert', because GHC drops 'assert' when optimising unless
-- @-fno-ignore-asserts@ is given, which would silently disable the cross-check.
testSum :: MonadParallel m => [Integer] -> m Integer
testSum list = do
   s1 <- testSumCI list
   s2 <- testSumEC list
   s3 <- testSumCE list
   s4 <- testSumIC list
   if s1 == s2 && s2 == s3 && s3 == s4
      then return s4
      else error ("testSum: results disagree: " ++ show [s1, s2, s3, s4])

-- | Command-line driver. It expects exactly four arguments (see 'help') and prints the task's result computed in the
-- chosen monad. With any other number of arguments it prints 'help' instead.
main :: IO ()
main = do
   args <- getArgs
   if List.length args /= 4
      then putStr help
      else do
         -- The coroutine count is accepted for interface compatibility but is not used by any task here.
         let [taskName, monad, size, _coroutineCount] = args
             task :: MonadParallel m => m Integer
             task = case taskName of
                       "sum" -> testSum [1 .. read size]
                       _ -> error (help ++ "Bad task.")
         result <- case monad of
                      "Maybe" -> return $ fromJust task
                      "[]" -> return $ List.head task
                      "Identity" -> return $ runIdentity task
                      "IO" -> task
                      _ -> error (help ++ "Bad monad.")
         print result

-- | Usage text. It is printed for a wrong argument count and prefixed to the errors for an unknown task or monad.
help :: String
help = "Usage: test-iteratee <task> <monad> <size> <coroutine count>?\n"
       ++ "  where <task> is 'sum',\n"
       ++ "        <monad> is 'Identity', 'Maybe', '[]', or 'IO',\n"
       ++ "        <size> is the size of the task,\n"
       ++ "    and <coroutine count> is the number of coroutines to employ.\n"