{-
Copyright 2010 Mario Blazevic
This file is part of the Streaming Component Combinators (SCC) project.
The SCC project is free software: you can redistribute it and/or modify it under the terms of the GNU General Public
License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
version.
SCC is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
You should have received a copy of the GNU General Public License along with SCC. If not, see
<http://www.gnu.org/licenses/>.
-}
-- | Tests for the "Control.Monad.Coroutine.Iteratee" conversions between coroutines, iteratees, and enumerators.
module Main where
import Control.Exception (assert)
import Control.Exception.Base (SomeException)
import Control.Monad (liftM)
import qualified Data.List as List
import Data.Maybe (fromJust)
import System.Environment (getArgs)
import Debug.Trace
import Data.Iteratee (Enumerator, Iteratee(..), Stream(..))
import Data.Iteratee (enumEof, (>>>))
import qualified Data.Iteratee.ListLike as Iteratee
import Data.Functor.Identity (runIdentity)
import Control.Monad.Coroutine
import Control.Monad.Coroutine.Iteratee
import Control.Monad.Coroutine.SuspensionFunctors (Await(Await), Yield(Yield), await, yield, awaitYieldResolver)
import Control.Monad.Parallel
-- | A coroutine that repeatedly awaits a list of chunks and adds up every number in them.
--
-- It finishes as soon as it is resumed with an empty list, returning the total together with
-- an empty leftover list. It always returns 'Right'.
sumCoroutine :: Monad m => Coroutine (Await [[Integer]]) m (Either SomeException (Integer, [[Integer]]))
sumCoroutine = sum' 0
   where -- The running total is forced on every step so it cannot grow into a chain of
         -- unevaluated additions, one per received chunk.
         sum' s = s `seq` do ns <- await
                             if null ns
                                then return (Right (s, []))
                                else sum' (s + List.sum (List.concat ns))
-- | Yields the numbers one group at a time, each group wrapped in a singleton list.
--
-- A group is a run of adjacent numbers whose remainder modulo 10 equals that of the run's
-- first element (the grouping rule of 'List.groupBy').
yieldAll :: Monad m => [Integer] -> Coroutine (Yield [[Integer]]) m ()
yieldAll numbers = mapM_ yieldGroup (List.groupBy sameLastDigit numbers)
   where yieldGroup grp = yield [grp]
         sameLastDigit m n = m `mod` 10 == n `mod` 10
-- | 'awaitYieldResolver' with 'resumeLeft' overridden so that a suspended 'Await' on the left
-- side is resumed with an empty list, which 'sumCoroutine' treats as the end of its input.
-- NOTE(review): presumably 'resumeLeft' is used when only the left coroutine is still
-- suspended (its partner has finished); confirm against the seesaw documentation.
listResolver = awaitYieldResolver { resumeLeft = \(Await continuation) -> continuation [] }
-- | Sums the list by feeding it (via @enumPureNChunk list 10@, followed by 'enumEof') to
-- 'sumCoroutine' converted into an iteratee with 'coroutineIteratee', then extracting the
-- iteratee's final result.
testSumCI :: Monad m => [Integer] -> m Integer
testSumCI list = do i <- (Iteratee.enumPureNChunk list 10 >>> enumEof) $ coroutineIteratee sumCoroutine
                    runIter i (\total _ -> return total)
                              -- Reached only if the iteratee has not produced a result after EOF.
                              (error "testSumCI: the iteratee did not produce a result")
-- | Sums the list by running 'sumCoroutine' in a seesaw (resolved by 'listResolver', bound with
-- 'bindM2') against the chunking enumerator (@enumPureNChunk list 10@ followed by 'enumEof')
-- converted into a coroutine with 'enumeratorCoroutine'.
testSumEC :: MonadParallel m => [Integer] -> m Integer
testSumEC list = liftM extractSum $
                 seesaw bindM2 listResolver sumCoroutine
                        (enumeratorCoroutine (Iteratee.enumPureNChunk list 10 >>> enumEof))
   where -- Takes the total from the summing side; a failure is reported instead of a bare
         -- pattern-match error.
         extractSum (Right (s, _), _) = s
         extractSum (Left e, _) = error ("testSumEC: the summing coroutine failed: " ++ show e)
-- | Sums the list by turning 'yieldAll' into an enumerator with 'coroutineEnumerator', running
-- it (followed by 'enumEof') over 'Iteratee.sum', and extracting the iteratee's final result.
testSumCE :: Monad m => [Integer] -> m Integer
testSumCE list = do i <- (coroutineEnumerator (yieldAll list) >>> enumEof) $ Iteratee.sum
                    runIter i (\total _ -> return total)
                              -- Reached only if the iteratee has not produced a result after EOF.
                              (error "testSumCE: the iteratee did not produce a result")
-- | Sums the list by running 'Iteratee.sum', converted into a coroutine with
-- 'iterateeCoroutine', in a seesaw (resolved by 'listResolver', bound with 'bindM2') against
-- the 'yieldAll' producer.
testSumIC :: MonadParallel m => [Integer] -> m Integer
testSumIC list = liftM extractSum $
                 seesaw bindM2 listResolver (iterateeCoroutine Iteratee.sum) (yieldAll list)
   where -- Takes the total from the iteratee side; if the iteratee failed, report its exception
         -- instead of crashing with a bare pattern-match error.
         extractSum (Right (s, _), _) = s
         extractSum (Left e, _) = error ("testSumIC: the sum iteratee failed: " ++ show e)
-- | Computes the sum of the list in four different ways and returns it if all four agree.
--
-- The agreement check is an explicit 'error' rather than 'assert': GHC discards 'assert' when
-- optimisation is enabled (unless @-fno-ignore-asserts@ is given), which would silently turn
-- this test into a no-op.
testSum :: MonadParallel m => [Integer] -> m Integer
testSum list = do s1 <- testSumCI list
                  s2 <- testSumEC list
                  s3 <- testSumCE list
                  s4 <- testSumIC list
                  if s1 == s2 && s2 == s3 && s3 == s4
                     then return s4
                     else error ("testSum: results disagree: " ++ show [s1, s2, s3, s4])
-- | Runs the task named on the command line in the requested monad and prints its result.
--
-- Expects exactly four arguments: the task name, the monad name, the task size, and the
-- coroutine count. With any other number of arguments it prints the usage text instead.
-- NOTE(review): the coroutine count is accepted but not used by any task in this file.
main :: IO ()
main = do
   args <- getArgs
   case args of
      [taskName, monad, size, _coroutineCount] -> do
         -- The size is parsed with the same grammar as 'read' (trailing whitespace allowed),
         -- but malformed input gets the usage text instead of "Prelude.read: no parse".
         let n :: Integer
             n = case [value | (value, rest) <- reads size, ("", "") <- lex rest] of
                    [value] -> value
                    _ -> error (help ++ "Bad size.")
             task :: MonadParallel m => m Integer
             task = case taskName of
                       "sum" -> testSum [1 .. n]
                       _ -> error (help ++ "Bad task.")
         result <- case monad of
                      "Maybe" -> return $ fromJust task
                      "[]" -> return $ List.head task
                      "Identity" -> return $ runIdentity task
                      "IO" -> task
                      _ -> error (help ++ "Bad monad.")
         print result
      _ -> putStr help
-- | Usage text printed for a wrong argument count and prefixed to argument-error messages.
-- All four arguments are required, matching the argument check in 'main'.
help :: String
help = "Usage: test-iteratee <task> <monad> <size> <coroutine count>\n"
       ++ "  where <task> is 'sum',\n"
       ++ "        <monad> is 'Identity', 'Maybe', '[]', or 'IO',\n"
       ++ "        <size> is the size of the task,\n"
       ++ "    and <coroutine count> is the number of coroutines to employ.\n"