{- 
    Copyright 2008 Mario Blazevic

    This file is part of the Streaming Component Combinators (SCC) project.

    The SCC project is free software: you can redistribute it and/or modify it under the terms of the GNU General Public
    License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
    version.

    SCC is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
    of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more details.

    You should have received a copy of the GNU General Public License along with SCC.  If not, see
    <http://www.gnu.org/licenses/>.
-}

{-# LANGUAGE ScopedTypeVariables, MultiParamTypeClasses, FlexibleInstances, FunctionalDependencies,
             ExistentialQuantification, KindSignatures, Rank2Types, PatternSignatures #-}

module Control.Concurrent.SCC.ComponentTypes
   (-- * Classes
    Component (..), BranchComponent (combineBranches),
    -- * Types
    AnyComponent (AnyComponent), Performer (..), Consumer (..), Producer(..), Splitter(..), Transducer(..),
    ComponentConfiguration(..),
    -- * Lifting functions
    liftPerformer, liftConsumer, liftAtomicConsumer, liftProducer, liftAtomicProducer,
    liftTransducer, liftAtomicTransducer, lift121Transducer, liftStatelessTransducer, liftFoldTransducer, liftStatefulTransducer,
    liftSimpleSplitter, liftSectionSplitter, liftAtomicSimpleSplitter, liftAtomicSectionSplitter, liftStatelessSplitter,
    -- * Utility functions
    showComponentTree, optimalTwoParallelConfigurations, optimalTwoSequentialConfigurations, optimalThreeParallelConfigurations
   )
where

import Control.Concurrent.SCC.Foundation

import Control.Monad (liftM, when)
import Data.List (minimumBy)
import Data.Maybe
import Data.Typeable (Typeable, cast)

-- | 'AnyComponent' hides the concrete type of a 'Component' behind an existential wrapper, so that components of
-- different types can be kept together, for example in the list returned by 'subComponents'.
data AnyComponent = forall component. Component component => AnyComponent component

-- | Every type in the 'Component' class carries metadata (a name, children, thread usage, and cost) and can be
-- reconfigured to use a specific number of threads.
class Component c where
   -- | The name of the component.
   name :: c -> String
   -- | The children of the component, each wrapped in 'AnyComponent'.
   subComponents :: c -> [AnyComponent]
   -- | The largest number of threads the component is able to use.
   maxUsableThreads :: c -> Int
   -- | Reconfigures the component to use the given number of threads. The 'usedThreads', 'cost', and 'subComponents'
   -- of the result reflect the new configuration, while its 'name' and 'maxUsableThreads' stay the same.
   usingThreads :: Int -> c -> c
   -- | The number of threads the component is configured to use; usually 1 unless it has been reconfigured.
   usedThreads :: c -> Int
   -- | The cost of running the component in its current configuration. The default is one plus the summed cost of
   -- all children.
   cost :: c -> Int
   cost = (1 +) . sum . map cost . subComponents

-- The wrapper forwards every method to the component it hides; 'usingThreads' re-wraps the reconfigured component.
instance Component AnyComponent where
   name wrapped = case wrapped of AnyComponent inner -> name inner
   maxUsableThreads wrapped = case wrapped of AnyComponent inner -> maxUsableThreads inner
   usedThreads wrapped = case wrapped of AnyComponent inner -> usedThreads inner
   cost wrapped = case wrapped of AnyComponent inner -> cost inner
   subComponents wrapped = case wrapped of AnyComponent inner -> subComponents inner
   usingThreads threadCount wrapped = case wrapped of AnyComponent inner -> AnyComponent (usingThreads threadCount inner)

-- | Renders the given component and, recursively, all of its children, one line per component. Each line holds the
-- component's cost and used thread count, both right-aligned, followed by its name indented by its depth in the tree.
showComponentTree :: forall c. Component c => c -> String
showComponentTree = showIndentedComponent 1

-- | Renders one line for the component (cost in a 4-wide column, used threads in a 3-wide column, then the name
-- preceded by @depth@ spaces), followed by the lines of all its children at the next depth.
showIndentedComponent :: forall c. Component c => Int -> c -> String
showIndentedComponent depth c = concat [showRightAligned 4 (cost c),
                                        showRightAligned 3 (usedThreads c),
                                        replicate depth ' ',
                                        name c,
                                        "\n"]
                                ++ concatMap (showIndentedComponent (succ depth)) (subComponents c)

-- | Renders a value with 'show' and pads it on the left with spaces up to the given width. A rendering that is
-- already at least as wide as requested is returned unchanged.
showRightAligned :: Show x => Int -> x -> String
showRightAligned width x = padding ++ rendered
   where rendered = show x
         padding = replicate (width - length rendered) ' '

-- | The configuration-dependent part of a component's metadata, as reported through the 'Component' methods.
data ComponentConfiguration = ComponentConfiguration
   { componentChildren :: [AnyComponent]
     -- ^ the children of the component, reported by 'subComponents'
   , componentThreads :: Int
     -- ^ the number of threads the configuration uses, reported by 'usedThreads'
   , componentCost :: Int
     -- ^ the cost of the configuration, reported by 'cost'
   }

-- | A 'Performer' is a component running a computation that has neither inputs nor outputs, only a result.
data Performer m r = Performer
   { performerName :: String
     -- ^ the component name, reported by 'name'
   , performerMaxThreads :: Int
     -- ^ the largest usable thread count, reported by 'maxUsableThreads'
   , performerConfiguration :: ComponentConfiguration
     -- ^ the current configuration
   , performerUsingThreads :: Int -> (ComponentConfiguration, forall c. Pipe c m r)
     -- ^ given a thread count, yields the matching configuration and 'perform' computation
   , perform :: forall c. Pipe c m r
     -- ^ the computation itself, as currently configured
   }

-- | A 'Consumer' is a component that reads values from a 'Source' and produces a result.
data Consumer m x r = Consumer
   { consumerName :: String
     -- ^ the component name, reported by 'name'
   , consumerMaxThreads :: Int
     -- ^ the largest usable thread count, reported by 'maxUsableThreads'
   , consumerConfiguration :: ComponentConfiguration
     -- ^ the current configuration
   , consumerUsingThreads :: Int -> (ComponentConfiguration, forall c. Source c x -> Pipe c m r)
     -- ^ given a thread count, yields the matching configuration and 'consume' function
   , consume :: forall c. Source c x -> Pipe c m r
     -- ^ the consuming computation, as currently configured
   }

-- | A 'Producer' is a component that puts values into a 'Sink' and produces a result.
data Producer m x r = Producer
   { producerName :: String
     -- ^ the component name, reported by 'name'
   , producerMaxThreads :: Int
     -- ^ the largest usable thread count, reported by 'maxUsableThreads'
   , producerConfiguration :: ComponentConfiguration
     -- ^ the current configuration
   , producerUsingThreads :: Int -> (ComponentConfiguration, forall c. Sink c x -> Pipe c m r)
     -- ^ given a thread count, yields the matching configuration and 'produce' function
   , produce :: forall c. Sink c x -> Pipe c m r
     -- ^ the producing computation, as currently configured
   }

-- | The 'Transducer' type represents computations that transform data read from a source into data written to a sink.
-- A transducer must keep consuming the given source and feeding the sink for as long as there is data. Its only
-- result is a list of input values; the transducers lifted by the helper functions in this module return @[]@.
data Transducer m x y = Transducer
   { transducerName :: String
     -- ^ the component name, reported by 'name'
   , transducerMaxThreads :: Int
     -- ^ the largest usable thread count, reported by 'maxUsableThreads'
   , transducerConfiguration :: ComponentConfiguration
     -- ^ the current configuration
   , transducerUsingThreads :: Int -> (ComponentConfiguration,
                                       forall c. Source c x -> Sink c y -> Pipe c m [x])
     -- ^ given a thread count, yields the matching configuration and 'transduce' function
   , transduce :: forall c. Source c x -> Sink c y -> Pipe c m [x]
     -- ^ the transforming computation, as currently configured
   }

-- | The 'Splitter' type represents computations that distribute data according to some criteria. A splitter should
-- distribute only the original input data, feeding it into the sinks in the same order it was read from the source.
-- If the two sink arguments of a splitter are the same, the splitter must act as an identity transform.
data Splitter m x = Splitter
   { splitterName :: String
     -- ^ the component name, reported by 'name'
   , splitterMaxThreads :: Int
     -- ^ the largest usable thread count, reported by 'maxUsableThreads'
   , splitterConfiguration :: ComponentConfiguration
     -- ^ the current configuration
   , splitterUsingThreads :: Int -> (ComponentConfiguration,
                                     forall c. Source c x -> Sink c x -> Sink c x -> Pipe c m [x],
                                     forall c. Source c x -> Sink c (Maybe x) -> Sink c (Maybe x) -> Pipe c m [x])
     -- ^ given a thread count, yields the matching configuration, 'split', and 'splitSections' functions
   , split :: forall c. Source c x -> Sink c x -> Sink c x -> Pipe c m [x]
     -- ^ distributes the source items between the two sinks
   , splitSections :: forall c. Source c x -> Sink c (Maybe x) -> Sink c (Maybe x) -> Pipe c m [x]
     -- ^ like 'split', but the sinks receive 'Maybe'-wrapped items ('Nothing' presumably marks a section
     -- boundary, since 'liftSectionSplitter' drops it when deriving 'split' -- TODO confirm)
   }

-- A performer's metadata is read straight from its record; reconfiguration replaces both the stored configuration and
-- the 'perform' computation with what 'performerUsingThreads' returns.
instance Component (Performer m r) where
   name p = performerName p
   maxUsableThreads p = performerMaxThreads p
   subComponents p = componentChildren (performerConfiguration p)
   usedThreads p = componentThreads (performerConfiguration p)
   cost p = componentCost (performerConfiguration p)
   usingThreads threadCount performer
      = let (newConfig, newPerform :: forall c. Pipe c m r) = performerUsingThreads performer threadCount
        in performer{performerConfiguration= newConfig, perform= newPerform}

-- A consumer's metadata is read straight from its record; reconfiguration replaces both the stored configuration and
-- the 'consume' function with what 'consumerUsingThreads' returns.
instance Component (Consumer m x r) where
   name c = consumerName c
   maxUsableThreads c = consumerMaxThreads c
   subComponents c = componentChildren (consumerConfiguration c)
   usedThreads c = componentThreads (consumerConfiguration c)
   cost c = componentCost (consumerConfiguration c)
   usingThreads threadCount consumer
      = let (newConfig,
             newConsume :: forall c. Source c x -> Pipe c m r)
               = consumerUsingThreads consumer threadCount
        in consumer{consumerConfiguration= newConfig, consume= newConsume}

-- A producer's metadata is read straight from its record; reconfiguration replaces both the stored configuration and
-- the 'produce' function with what 'producerUsingThreads' returns.
instance Component (Producer m x r) where
   name p = producerName p
   maxUsableThreads p = producerMaxThreads p
   subComponents p = componentChildren (producerConfiguration p)
   usedThreads p = componentThreads (producerConfiguration p)
   cost p = componentCost (producerConfiguration p)
   usingThreads threadCount producer
      = let (newConfig,
             newProduce :: forall c. Sink c x -> Pipe c m r)
               = producerUsingThreads producer threadCount
        in producer{producerConfiguration= newConfig, produce= newProduce}

-- A transducer's metadata is read straight from its record; reconfiguration replaces both the stored configuration
-- and the 'transduce' function with what 'transducerUsingThreads' returns.
instance Component (Transducer m x y) where
   name t = transducerName t
   maxUsableThreads t = transducerMaxThreads t
   subComponents t = componentChildren (transducerConfiguration t)
   usedThreads t = componentThreads (transducerConfiguration t)
   cost t = componentCost (transducerConfiguration t)
   usingThreads threadCount transducer
      = let (newConfig,
             newTransduce :: forall c. Source c x -> Sink c y -> Pipe c m [x])
               = transducerUsingThreads transducer threadCount
        in transducer{transducerConfiguration= newConfig, transduce= newTransduce}

-- A splitter's metadata is read straight from its record; reconfiguration replaces the stored configuration and both
-- the 'split' and 'splitSections' functions with what 'splitterUsingThreads' returns.
instance Component (Splitter m x) where
   name s = splitterName s
   maxUsableThreads s = splitterMaxThreads s
   subComponents s = componentChildren (splitterConfiguration s)
   usedThreads s = componentThreads (splitterConfiguration s)
   cost s = componentCost (splitterConfiguration s)
   usingThreads threadCount splitter
      = let (newConfig,
             newSplit :: forall c. Source c x -> Sink c x -> Sink c x -> Pipe c m [x],
             newSplitSections :: forall c. Source c x -> Sink c (Maybe x) -> Sink c (Maybe x)
                              -> Pipe c m [x])
               = splitterUsingThreads splitter threadCount
        in splitter{splitterConfiguration= newConfig,
                    split= newSplit, splitSections= newSplitSections}


-- | 'BranchComponent' is the class of components that can act as consumers, namely 'Consumer', 'Transducer', and
-- 'Splitter', so that two of them can be joined as branches under a single consumer combinator.
class BranchComponent cc m x r | cc -> m x where
   -- | 'combineBranches' joins two components of the same type into one. It takes the name of the combined
   -- component, a cost for it (the 'Consumer' instances record it in the configuration), a binary combinator over
   -- consumer functions whose 'Bool' argument says whether the branches should run in parallel, and the two
   -- components.
   combineBranches :: String
                   -> Int
                   -> (forall c. Bool -> (Source c x -> Pipe c m r) -> (Source c x -> Pipe c m r) -> (Source c x -> Pipe c m r))
                   -> cc
                   -> cc
                   -> cc

-- Combined consumers always run single-threaded: the combinator is told not to parallelize, and the resulting
-- configuration lists both (unconfigured) branches with one thread and the given cost.
instance forall m x r. Monad m => BranchComponent (Consumer m x r) m x r where
   combineBranches label combinedCost combinator c1 c2
      = liftConsumer label 1 $
        \_threads-> (ComponentConfiguration [AnyComponent c1, AnyComponent c2] 1 combinedCost,
                     combinator False (consume c1) (consume c2))

-- Unit-returning consumers are adapted to the list-returning combinator by making each branch return an empty list,
-- and the combined result is then discarded. Like the generic consumer instance, this runs single-threaded.
instance forall m x. Monad m => BranchComponent (Consumer m x ()) m x [x] where
   combineBranches label combinedCost combinator c1 c2
      = liftConsumer label 1 $
        \_threads-> (ComponentConfiguration [AnyComponent c1, AnyComponent c2] 1 combinedCost,
                     liftM (const ())
                     . combinator False
                          (\input-> consume c1 input >> return [])
                          (\input-> consume c2 input >> return []))

-- Two transducers are combined by configuring both with 'optimalTwoParallelConfigurations' for the requested thread
-- count. The combinator then receives the configured children, plus the flag that says whether they should run in
-- parallel. The cost argument is not used: the combined cost comes from the computed configuration.
instance forall m x y. BranchComponent (Transducer m x y) m x [x] where
   combineBranches name _cost combinator t1 t2
      = liftTransducer name (maxUsableThreads t1 + maxUsableThreads t2) $
        \threads-> let (configuration, t1', t2', parallel) = optimalTwoParallelConfigurations threads t1 t2
                       -- Run the configured children t1' and t2'. The original t1 and t2 would ignore the thread
                       -- allotment recorded in 'configuration'.
                       transduce' source sink = combinator parallel
                                                   (\source'-> transduce t1' source' sink)
                                                   (\source'-> transduce t2' source' sink)
                                                   source
                   in (configuration, transduce')

-- Two splitters are combined by configuring both with 'optimalTwoParallelConfigurations' for the requested thread
-- count. The combinator then receives the configured children's 'split' functions, plus the flag that says whether
-- they should run in parallel. Only 'split' is combined; 'liftSimpleSplitter' derives 'splitSections' from it. The
-- cost argument is not used: the combined cost comes from the computed configuration.
instance forall m x. (ParallelizableMonad m, Typeable x) => BranchComponent (Splitter m x) m x [x] where
   combineBranches name _cost combinator s1 s2
      = liftSimpleSplitter name (maxUsableThreads s1 + maxUsableThreads s2) $
        \threads-> let (configuration, s1', s2', parallel) = optimalTwoParallelConfigurations threads s1 s2
                       -- Run the configured children s1' and s2'. The original s1 and s2 would ignore the thread
                       -- allotment recorded in 'configuration'.
                       split' source true false = combinator parallel
                                                     (\source'-> split s1' source' true false)
                                                     (\source'-> split s2' source' true false)
                                                     source
                   in (configuration, split')

-- | Function 'liftPerformer' builds a 'Performer' from a component name, the maximum number of threads it can use,
-- and its reconfiguration function. The component starts out in the configuration that the reconfiguration function
-- returns for a single thread.
liftPerformer :: String -> Int -> (Int -> (ComponentConfiguration, forall c. Pipe c m r)) -> Performer m r
liftPerformer name maxThreads usingThreads
   = case usingThreads 1
     of (initialConfiguration, initialPerform) -> Performer {performerName= name,
                                                             performerMaxThreads= maxThreads,
                                                             performerConfiguration= initialConfiguration,
                                                             performerUsingThreads= usingThreads,
                                                             perform= initialPerform}

-- | Function 'liftConsumer' builds a 'Consumer' from a component name, the maximum number of threads it can use, and
-- its reconfiguration function. The component starts out in the configuration that the reconfiguration function
-- returns for a single thread.
liftConsumer :: String -> Int -> (Int -> (ComponentConfiguration, forall c. Source c x -> Pipe c m r)) -> Consumer m x r
liftConsumer name maxThreads usingThreads
   = case usingThreads 1
     of (initialConfiguration, initialConsume) -> Consumer {consumerName= name,
                                                            consumerMaxThreads= maxThreads,
                                                            consumerConfiguration= initialConfiguration,
                                                            consumerUsingThreads= usingThreads,
                                                            consume= initialConsume}

-- | Function 'liftProducer' builds a 'Producer' from a component name, the maximum number of threads it can use, and
-- its reconfiguration function. The component starts out in the configuration that the reconfiguration function
-- returns for a single thread.
liftProducer :: String -> Int -> (Int -> (ComponentConfiguration, forall c. Sink c x -> Pipe c m r)) -> Producer m x r
liftProducer name maxThreads usingThreads
   = case usingThreads 1
     of (initialConfiguration, initialProduce) -> Producer {producerName= name,
                                                            producerMaxThreads= maxThreads,
                                                            producerConfiguration= initialConfiguration,
                                                            producerUsingThreads= usingThreads,
                                                            produce= initialProduce}

-- | Function 'liftTransducer' builds a 'Transducer' from a component name, the maximum number of threads it can use,
-- and its reconfiguration function. The component starts out in the configuration that the reconfiguration function
-- returns for a single thread.
liftTransducer :: String -> Int -> (Int -> (ComponentConfiguration, forall c. Source c x -> Sink c y -> Pipe c m [x]))
               -> Transducer m x y
liftTransducer name maxThreads usingThreads
   = case usingThreads 1
     of (initialConfiguration, initialTransduce) -> Transducer {transducerName= name,
                                                                transducerMaxThreads= maxThreads,
                                                                transducerConfiguration= initialConfiguration,
                                                                transducerUsingThreads= usingThreads,
                                                                transduce= initialTransduce}

-- | Function 'liftAtomicConsumer' wraps a single-threaded 'consume' function, with the given cost, into a childless
-- 'Consumer' component that ignores any requested thread count.
liftAtomicConsumer :: String -> Int -> (forall c. Source c x -> Pipe c m r) -> Consumer m x r
liftAtomicConsumer name cost consume
   = liftConsumer name 1 $ \_ -> (ComponentConfiguration [] 1 cost, consume)

-- | Function 'liftAtomicProducer' wraps a single-threaded 'produce' function, with the given cost, into a childless
-- 'Producer' component that ignores any requested thread count.
liftAtomicProducer :: String -> Int -> (forall c. Sink c x -> Pipe c m r) -> Producer m x r
liftAtomicProducer name cost produce
   = liftProducer name 1 $ \_ -> (ComponentConfiguration [] 1 cost, produce)

-- | Function 'liftAtomicTransducer' wraps a single-threaded 'transduce' function, with the given cost, into a
-- childless 'Transducer' component that ignores any requested thread count.
liftAtomicTransducer :: String -> Int -> (forall c. Source c x -> Sink c y -> Pipe c m [x]) -> Transducer m x y
liftAtomicTransducer name cost transduce
   = liftTransducer name 1 $ \_ -> (ComponentConfiguration [] 1 cost, transduce)

-- | Function 'lift121Transducer' lifts a function that maps each input value to exactly one output value into a
-- 'Transducer'. The transducer stops once the sink can no longer accept values or the source runs dry, and it always
-- returns an empty list.
lift121Transducer :: (Monad m, Typeable x, Typeable y) => String -> (x -> y) -> Transducer m x y
lift121Transducer name f = liftAtomicTransducer name 1 $ \source sink->
   let pump = canPut sink >>= \room-> when room (getSuccess source (\x-> put sink (f x) >> pump))
   in pump >> return []

-- | Function 'liftStatelessTransducer' lifts a function that maps each input value to a list of output values into a
-- 'Transducer'. The transducer stops once the sink can no longer accept values or the source runs dry, and it always
-- returns an empty list.
liftStatelessTransducer :: (Monad m, Typeable x, Typeable y) => String -> (x -> [y]) -> Transducer m x y
liftStatelessTransducer name f = liftAtomicTransducer name 1 $ \source sink->
   let pump = canPut sink >>= \room-> when room (getSuccess source (\x-> putList (f x) sink >> pump))
   in pump >> return []

-- | Function 'liftFoldTransducer' creates a stateful transducer that folds the whole input with @f@, starting from
-- @s0@, and emits the single value @w@ of the final state once the source is exhausted; similar to 'Data.List.foldl'.
-- If the sink stops accepting values first, nothing more is read or written. It always returns an empty list.
liftFoldTransducer :: (Monad m, Typeable x, Typeable y) => String -> (s -> x -> s) -> s -> (s -> y) -> Transducer m x y
liftFoldTransducer name f s0 w = liftAtomicTransducer name 1 $ \source sink->
   let pump acc = canPut sink >>= \room-> when room (get source >>= continue acc)
       -- End of input: emit the final value. Otherwise fold the item into the state and keep going.
       continue acc Nothing = put sink (w acc) >> return ()
       continue acc (Just x) = pump (f acc x)
   in pump s0 >> return []

-- | Function 'liftStatefulTransducer' builds a 'Transducer' from a state-transition function and an initial state.
-- Each input value moves the state forward and may emit any list of output values. The transducer stops once the sink
-- can no longer accept values or the source runs dry, and it always returns an empty list.
liftStatefulTransducer :: (Monad m, Typeable x, Typeable y) => String -> (state -> x -> (state, [y])) -> state -> Transducer m x y
liftStatefulTransducer name f s0 = liftAtomicTransducer name 1 $ \source sink->
   let pump current = canPut sink >>= \room-> when room (getSuccess source (emit current))
       emit current x = let (next, ys) = f current x
                        in putList ys sink >> pump next
   in pump s0 >> return []

-- | Function 'liftStatelessSplitter' lifts a predicate into a 'Splitter'. Each input item goes to the first sink when
-- the predicate holds and to the second sink otherwise. The splitter finishes with @[]@ when the source is exhausted;
-- the result of each 'put' is passed to 'cond', which picks between continuing and returning @[x]@ with the current
-- item.
liftStatelessSplitter :: (ParallelizableMonad m, Typeable x) => String -> (x -> Bool) -> Splitter m x
liftStatelessSplitter name f = liftAtomicSimpleSplitter name 1 $ \source true false->
   let pump = get source >>= maybe (return []) route
       route x = put (if f x then true else false) x >>= cond pump (return [x])
   in pump

-- | Function 'liftSimpleSplitter' lifts a simple, non-sectioning splitter function into a full 'Splitter'.
-- The reconfiguration argument maps a thread count to a configuration and a plain 'split' function. The component
-- starts out in its single-thread configuration. Its 'splitSections' is derived from 'split' by wrapping every item
-- the plain splitter emits in 'Just', so the derived sectioning function never emits 'Nothing'.
liftSimpleSplitter :: forall m x. (ParallelizableMonad m, Typeable x) =>
                      String -> Int
                             -> (Int -> (ComponentConfiguration, forall c. Source c x -> Sink c x -> Sink c x -> Pipe c m [x]))
                             -> Splitter m x
liftSimpleSplitter name maxThreads usingThreads
   = case usingThreads 1
     of (configuration, split) -> Splitter name maxThreads configuration usingThreads' (splitSections split)
   where -- Reconfiguration delegates to the supplied function and re-derives the sectioning function each time.
         usingThreads' :: Int -> (ComponentConfiguration,
                                  forall c. Source c x -> Sink c x -> Sink c x -> Pipe c m [x],
                                  forall c. Source c x -> Sink c (Maybe x) -> Sink c (Maybe x) -> Pipe c m [x])
         usingThreads' threads = case usingThreads threads
                                 of (configuration, splitValues) -> (configuration, splitValues, splitSections splitValues)
         -- Runs the plain splitter against two intermediate sinks whose contents are piped through 'decorate' into
         -- the real section sinks. Only the plain splitter's own result is kept (assumes 'pipeD' pairs the
         -- producer's result with the consumer's, so @fst . fst@ selects the innermost producer -- TODO confirm).
         splitSections split source true false
            = liftM (fst . fst) $
              pipeD "liftSimpleSplitter true"
                    (\true'-> pipeD "liftSimpleSplitter false"
                                    (\false'-> split source true' false')
                                    (decorate false))
                    (decorate true)
         -- Forwards every item from the source into the sink wrapped in 'Just'.
         decorate sink source = transduce (lift121Transducer "Just" Just) source sink


-- | Function 'liftSectionSplitter' lifts a sectioning splitter function into a full 'Splitter'.
-- The reconfiguration argument maps a thread count to a configuration and a 'splitSections' function. The component
-- starts out in its single-thread configuration and reports the given maximum thread count. Its plain 'split' is
-- derived from 'splitSections' by dropping every 'Nothing' and unwrapping every 'Just' on the way to the real sinks.
liftSectionSplitter :: forall m x. (ParallelizableMonad m, Typeable x) =>
                       String -> Int -> (Int -> (ComponentConfiguration,
                                                 forall c. Source c x -> Sink c (Maybe x) -> Sink c (Maybe x) -> Pipe c m [x]))
                              -> Splitter m x
liftSectionSplitter name maxThreads usingThreads
   -- Record the caller's maxThreads (as 'liftSimpleSplitter' does) instead of a hard-coded 1, so that
   -- 'maxUsableThreads' reports what the caller declared.
   = case usingThreads 1
     of (configuration, splitSections) -> Splitter name maxThreads configuration usingThreads' (splitValues splitSections)
                                                  splitSections
   where -- Reconfiguration delegates to the supplied function and re-derives the plain split function each time.
         usingThreads' :: Int -> (ComponentConfiguration,
                                  forall c. Source c x -> Sink c x -> Sink c x -> Pipe c m [x],
                                  forall c. Source c x -> Sink c (Maybe x) -> Sink c (Maybe x) -> Pipe c m [x])
         usingThreads' threads = case usingThreads threads
                                 of (configuration, splitSections) -> (configuration, splitValues splitSections, splitSections)
         -- Runs the sectioning splitter against two intermediate sinks whose contents are stripped into the real
         -- sinks. Only the sectioning splitter's own result is kept (assumes 'pipeD' pairs the producer's result with
         -- the consumer's -- TODO confirm).
         splitValues splitSections source true false
            = liftM (fst . fst) $
              pipeD "liftSectionSplitter true"
                    (\true'-> pipeD "liftSectionSplitter false" (\false'-> splitSections source true' false') (strip false))
                    (strip true)
         -- Copies 'Just' payloads into the sink and skips 'Nothing' markers, while the sink accepts values.
         strip sink source = canPut sink
                             >>= flip when (getSuccess source (\x-> maybe (return False) (put sink) x >> strip sink source))

-- | Function 'liftAtomicSimpleSplitter' wraps a single-threaded 'split' function, with the given cost, into a
-- childless 'Splitter' component that ignores any requested thread count.
liftAtomicSimpleSplitter :: forall m x. (ParallelizableMonad m, Typeable x) =>
                      String -> Int -> (forall c. Source c x -> Sink c x -> Sink c x -> Pipe c m [x]) -> Splitter m x
liftAtomicSimpleSplitter name cost split
   = liftSimpleSplitter name 1 $ \_ -> (ComponentConfiguration [] 1 cost, split)

-- | Function 'liftAtomicSectionSplitter' wraps a single-threaded 'splitSections' function, with the given cost, into
-- a childless 'Splitter' component that ignores any requested thread count. The plain 'split' function is derived
-- by 'liftSectionSplitter'.
liftAtomicSectionSplitter :: forall m x. (ParallelizableMonad m, Typeable x) =>
                             String -> Int -> (forall c. Source c x -> Sink c (Maybe x) -> Sink c (Maybe x) -> Pipe c m [x])
                                    -> Splitter m x
liftAtomicSectionSplitter name cost splitSections = liftSectionSplitter name 1 $
                                                    \_threads-> (ComponentConfiguration [] 1 cost, splitSections)

-- | Function 'optimalTwoSequentialConfigurations' configures two components that are to be run one after the other,
-- giving each of them the full thread count. It returns the configured components, together with a
-- 'ComponentConfiguration' that can be used to build a new component from them. That configuration uses as many
-- threads as the busier child and costs the sum of both children's costs.
optimalTwoSequentialConfigurations :: (Component c1, Component c2) => Int -> c1 -> c2 -> (ComponentConfiguration, c1, c2)
optimalTwoSequentialConfigurations threads c1 c2 = (configuration, c1', c2')
   where configuration = ComponentConfiguration
                            [AnyComponent c1', AnyComponent c2']
                            (usedThreads c1' `max` usedThreads c2')
                            (cost c1' + cost c2')
         c1' = usingThreads threads c1
         c2' = usingThreads threads c2

-- | Function 'optimalTwoParallelConfigurations' configures two components assuming they can be run in parallel. It
-- returns the configured components, a 'ComponentConfiguration' that can be used to build a new component from them,
-- and a flag that says whether they should run in parallel or sequentially for optimal resource usage.
--
-- The parallel candidates split the thread count between the two components: the first gets between 1 and
-- @min (threads - 1) (maxUsableThreads c1)@ threads, and the second gets the rest, capped at its
-- 'maxUsableThreads'. A candidate costs the larger of its two children's costs. The cheapest candidate, plus one, is
-- compared with the sequential cost, where each component gets all the threads and the costs add up.
optimalTwoParallelConfigurations :: (Component c1, Component c2) => Int -> c1 -> c2 -> (ComponentConfiguration, c1, c2, Bool)
optimalTwoParallelConfigurations threads c1 c2 = (configuration, c1', c2', parallelize)
   where -- The emptiness check keeps 'minimumBy' from being applied to an empty candidate list.
         parallelize = threads > 1 && not (null parallelCandidates) && parallelCost + 1 < sequentialCost
         configuration = ComponentConfiguration
                            [AnyComponent c1', AnyComponent c2']
                            (if parallelize then usedThreads c1' + usedThreads c2' else usedThreads c1' `max` usedThreads c2')
                            (if parallelize then parallelCost + 1 else sequentialCost)
         (c1', c2') = if parallelize then (c1p, c2p) else (c1s, c2s)
         -- A backquoted `min` binds tighter (infixl 9) than `-` (infixl 6), so both subtractions must be
         -- parenthesized; otherwise the thread caps are applied to the wrong operand.
         parallelCandidates = [let c2threads = (threads - c1threads) `min` maxUsableThreads c2
                                   c1i = usingThreads c1threads c1
                                   c2i = usingThreads c2threads c2
                               in (c1i, c2i, cost c1i `max` cost c2i)
                              | c1threads <- [1 .. (threads - 1) `min` maxUsableThreads c1]]
         -- This lazy pattern binding is only forced when 'parallelize' has found candidates.
         (c1p, c2p, parallelCost) = minimumBy
                                       (\(_, _, cost1) (_, _, cost2)-> compare cost1 cost2)
                                       parallelCandidates
         c1s = usingThreads threads c1
         c2s = usingThreads threads c2
         sequentialCost = cost c1s + cost c2s

-- | Function 'optimalThreeParallelConfigurations' is meant to configure three components assuming they can be run in
-- parallel, splitting the given thread count between them. It would return the components, a 'ComponentConfiguration'
-- that can be used to build a new component from them, and a flag per component that says whether it should run in
-- parallel or sequentially for optimal resource usage.
--
-- NOTE: not implemented yet. Evaluating the result raises an error that names this function, instead of the
-- anonymous @Prelude.undefined@ failure it used to produce.
optimalThreeParallelConfigurations :: (Component c1, Component c2, Component c3) =>
                                      Int -> c1 -> c2 -> c3 -> (ComponentConfiguration, (c1, Bool), (c2, Bool), (c3, Bool))
optimalThreeParallelConfigurations _threadCount _c1 _c2 _c3
   = error "Control.Concurrent.SCC.ComponentTypes.optimalThreeParallelConfigurations: not implemented"