{-
    Copyright 2010 Mario Blazevic

    This file is part of the Streaming Component Combinators (SCC) project.

    The SCC project is free software: you can redistribute it and/or modify it under the terms of the GNU General
    Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option)
    any later version.

    SCC is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
    of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along with SCC.  If not, see
    <http://www.gnu.org/licenses/>.
-}

-- | The "Control.Monad.Coroutine" tests.
--
-- Each task runs coroutines combined with one of the library's combinators over
-- a naive Fibonacci workload, in a 'MonadParallel' monad chosen on the command
-- line, and prints the resulting number.
module Main where

import Prelude hiding (sequence)

import Control.Exception (assert)
import Control.Monad (liftM, mapM, when)
import Control.Parallel (pseq)
import Data.Functor.Compose (Compose(..))
import Data.Functor.Identity (runIdentity)
import Data.List (find)
import Data.Maybe (fromJust)
import System.Environment (getArgs)

import Control.Monad.Coroutine
import Control.Monad.Coroutine.SuspensionFunctors
import Control.Monad.Coroutine.Nested
import Control.Monad.Parallel (MonadParallel, bindM2, liftM2, sequence)

-- | Factorization by trial division, smallest factor first,
-- e.g. @factors 12 == [2, 2, 3]@. Primes and any @n < 2@ yield @[n]@.
-- Currently unused: its call in 'factorFibs' is commented out.
factors :: Integral a => a -> [a]
factors n = maybe [n] (\k -> k : factors (n `div` k)) (find (\k -> n `mod` k == 0) [2 .. n - 1])

-- | Naive doubly-recursive Fibonacci numbers with @fib x 0 == fib x 1 == 1@.
--
-- For a non-negative @x@ the first argument does not affect the result. It only
-- makes calls with different @x@ distinct expressions. NOTE(review): presumably
-- this stops the work being shared between coroutines; confirm intent.
-- With a negative @x@, every call except @fib x 1@ recurses forever, as does any
-- negative @n@.
fib :: (Ord a, Num a, Eq b, Num b, Num c) => a -> b -> c
fib x 0 | x >= 0 = 1
fib _ 1 = 1
fib x n = fib x (n - 2) + fib x (n - 1)

-- | Sums @fib 0 n@ over @nums@ with two coroutines driven by 'seesaw'.
-- A producer yields each number. A consumer awaits and accumulates them until
-- it receives 'Nothing'. The consumer's total is the result. Despite the name,
-- the factorization step is currently disabled (see the comment in @factorize@).
factorFibs :: MonadParallel m => [Int] -> m Integer
factorFibs nums =
    liftM snd $ seesaw bindM2 (SeesawResolver onLeft onRight onBoth)
                              (mapM_ (yieldApply (fib 0)) nums)
                              (factorize 0)
  where
    factorize :: MonadParallel m => Integer -> Coroutine (Await (Maybe Integer)) m Integer
    factorize acc = await >>= maybe (return acc) (\n -> factorize (acc + n {-product (factors n)-}))
    -- Only the producer is suspended: its yielded value is dropped.
    onLeft (Yield _ c) = c
    -- Only the consumer is suspended: tell it there is no more input.
    onRight (Await c) = c Nothing
    -- Both are suspended: hand the yielded value to the awaiting consumer.
    onBoth c (Yield x c1) (Await c2) = c c1 (c2 (Just x))

-- | Runs two Fibonacci coroutines in lock-step with 'couple' and 'pogoStick'.
-- At every step it asserts that both yielded the same number. The result is the
-- sum of the two coroutines' final values, each @fib ix (last nums)@.
-- NOTE(review): 'last' fails on an empty @nums@. @resume@ handles only the
-- 'Both' case, so it relies on both coroutines suspending equally often.
twoFibs :: MonadParallel m => [Int] -> m Integer
twoFibs nums = pogoStick resume (couple bindM2 (fibWorker 1) (fibWorker 2)) >>= \(x, y) -> return (x + y)
  where
    resume :: SomeFunctor (Yield Integer) (Yield Integer) c -> c
    resume (Both (Compose (Yield n1 (Yield n2 c)))) = assert (n1 == n2) c
    fibWorker ix = mapM_ (yieldApply (fib ix)) nums >> applyM (fib ix) (last nums)

-- | Same computation as 'twoFibs', but the two coroutines are driven by
-- 'seesaw' instead of 'couple'. Both yield once per element of @nums@, so
-- they are always suspended together and only @resumeBoth@ is ever needed.
twoFibsSeesaw :: MonadParallel m => [Int] -> m Integer
twoFibsSeesaw nums = liftM (uncurry (+)) $ seesaw bindM2 resolver (fibWorker 1) (fibWorker 2)
  where
    resolver = SeesawResolver
      { resumeLeft  = \_ -> error "twoFibsSeesaw: only the left coroutine suspended"
      , resumeRight = \_ -> error "twoFibsSeesaw: only the right coroutine suspended"
      , resumeBoth  = \cont (Yield left c1) (Yield right c2) -> assert (left == right) $ cont c1 c2
      }
    fibWorker ix = mapM_ (yieldApply (fib ix)) nums >> applyM (fib ix) (last nums)

-- | Runs @coroutineCount@ Fibonacci coroutines together with 'merge'.
-- Each round of their yields is joined into one list by 'appendYields', and
-- every value in the round is asserted to agree. The result is the sum of all
-- coroutines' final values.
fibs :: MonadParallel m => Int -> [Int] -> m Integer
fibs coroutineCount nums =
    liftM sum $ pogoStick resume (merge sequence appendYields $ replicateIx coroutineCount fibWorker)
  where
    resume :: Yield [Integer] (Coroutine (Yield [Integer]) m [Integer]) -> Coroutine (Yield [Integer]) m [Integer]
    resume (Yield (x:xs) c) = assert (all (== x) xs) c
    fibWorker ix = mapM_ (yieldApply (forcedSingleton . fib ix)) nums >> applyM (fib ix) (last nums)
    -- Evaluate the number before wrapping it in a list. 'yieldApply' forces only
    -- weak head normal form, which for @[r]@ is just the cons cell. Otherwise the
    -- value would be computed lazily by the 'assert' in @resume@, or never when
    -- asserts are compiled out.
    forcedSingleton r = r `pseq` [r]

-- | Merges one round of yields into a single yield. The yielded lists are
-- concatenated in order, and the continuations are collected into a list in
-- the same order.
appendYields :: [Yield [s] x] -> Yield [s] [x]
appendYields yields = uncurry Yield $ foldr (\(Yield s x) (ss, xs) -> (s ++ ss, x : xs)) ([], []) yields

-- | Yields @f n@, first forcing it to weak head normal form with 'pseq' so it is
-- not handed on as an unevaluated thunk.
yieldApply :: Monad m => (a -> x) -> a -> Coroutine (Yield x) m ()
yieldApply f n = let result = f n in result `pseq` yield result

-- | Returns @f n@, first forcing it to weak head normal form with 'pseq'.
applyM :: Monad m => (a -> b) -> a -> m b
applyM f n = let result = f n in result `pseq` return result

-- | @replicateIx n f == [f 1, f 2, .., f n]@. Empty when @n < 1@.
replicateIx :: Int -> (Int -> x) -> [x]
replicateIx n f = map f [1 .. n]

-- | A coroutine with an own 'Yield' channel (the 'RightF' side) and a parent
-- channel @p@. Each level:
--
-- * yields 1 on its own channel;
-- * passes 2 to @suspendParent@, lifted into the parent side;
-- * while @level > 0@, runs a copy of itself one level deeper.
--
-- The deeper copy's own yields are resumed by 'pogoStickNested' and their
-- values are discarded. Its parent requests are lifted one more level out.
nested :: (Monad m, Functor p) => Int -> (Integer -> Coroutine p m ()) -> Coroutine (EitherFunctor p (Yield Integer)) m ()
nested level suspendParent = do
    mapSuspension RightF (yield 1)
    liftAncestor (suspendParent 2)
    when (level > 0) (pogoStickNested cont $ nested (pred level) (liftAncestor . suspendParent))
  where
    cont (Yield _ c) = c

-- | Entry point: @test-coroutine <task> <monad> <size> <coroutines>@.
-- Prints 'help' unless exactly four arguments are given.
main :: IO ()
main = do
  args <- getArgs
  case args of
    [taskName, monad, size, coroutineCount] -> do
      -- Polymorphic in the monad so that the same task can be run in each of
      -- the monads selected below.
      let task :: MonadParallel m => m Integer
          task = case taskName of
            "fib-factor"  -> factorFibs [1 .. read size]
            "2fibs"       -> twoFibs [1 .. read size]
            "2fibsSeesaw" -> twoFibsSeesaw [1 .. read size]
            "fibs"        -> fibs (read coroutineCount) [1 .. read size]
            "nested"      -> liftM fst $ foldRun add 0 (nested (read size) yield)
              where add s (LeftF (Yield n c)) = (s + n, c)
                    add s (RightF (Yield n c)) = (s + 10 * n, c)
            _             -> error (help ++ "Bad task.")
      result <- case monad of
        "Maybe"    -> return $ fromJust task
        "[]"       -> return $ head task
        "Identity" -> return $ runIdentity task
        "IO"       -> task
        _          -> error (help ++ "Bad monad.")
      print result
    _ -> putStr help

-- | Command-line usage text. It is also prefixed to the error raised for an
-- unknown task or monad.
help :: String
help = "Usage: test-coroutine <task> <monad> <size> <coroutines>\n"
       ++ "  where <task> is 'fib-factor', '2fibs', '2fibsSeesaw', 'fibs', or 'nested',\n"
       ++ "        <monad> is 'Identity', 'Maybe', '[]', or 'IO',\n"
       ++ "        <size> is the size of the task,\n"
       ++ "    and <coroutines> is the number of coroutines to employ (used only by the 'fibs' task).\n"