--
-- Copyright (c) 2009-2010, ERICSSON AB All rights reserved.
-- 
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
-- 
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
-- 
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-- ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-- BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
-- OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-- SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-- INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-- CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-- ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-- THE POSSIBILITY OF SUCH DAMAGE.
--

-- | A graph representation of core programs. A graph is a flat structure that
-- can be viewed as a program with a global scope. For example, the Haskell
-- program
--
-- > main x = f 1
-- >   where
-- >     f y = g 2
-- >       where
-- >         g z = x + z
--
-- might be represented by the following flat graph:
--
-- > graph = Graph
-- >   { graphNodes =
-- >       [ Node
-- >           { nodeId     = 0
-- >           , function   = Input
-- >           , input      = Tup []
-- >           , inputType  = Tup []
-- >           , outputType = intType
-- >           }
-- >       , Node
-- >           { nodeId     = 1
-- >           , function   = Input
-- >           , input      = Tup []
-- >           , inputType  = Tup []
-- >           , outputType = intType
-- >           }
-- >       , Node
-- >           { nodeId     = 2
-- >           , function   = Input
-- >           , input      = Tup []
-- >           , inputType  = Tup []
-- >           , outputType = intType
-- >           }
-- >       , Node
-- >           { nodeId     = 3
-- >           , function   = Function "(+)"
-- >           , input      = Tup [One (Variable (0,[])), One (Variable (2,[]))]
-- >           , inputType  = intPairType
-- >           , outputType = intType
-- >           }
-- >       , Node
-- >           { nodeId     = 4
-- >           , function   = NoInline "f" (Interface 1 (One (Variable (5,[]))) intType intType)
-- >           , input      = One (Constant (IntData 1))
-- >           , inputType  = intType
-- >           , outputType = intType
-- >           }
-- >       , Node
-- >           { nodeId     = 5
-- >           , function   = NoInline "g" (Interface 2 (One (Variable (3,[]))) intType intType)
-- >           , input      = One (Constant (IntData 2))
-- >           , inputType  = intType
-- >           , outputType = intType
-- >           }
-- >       ]
-- >
-- >   , graphInterface = Interface
-- >       { interfaceInput      = 0
-- >       , interfaceOutput     = One (Variable (4,[]))
-- >       , interfaceInputType  = intType
-- >       , interfaceOutputType = intType
-- >       }
-- >   }
-- >   where
-- >     intType     = result (typeOf :: Res [[[Int]]] (Tuple StorableType))
-- >     intPairType = result (typeOf :: Res (Int,Int) (Tuple StorableType))
--
-- XXX Check above code again
--
-- which corresponds to the following flat program
--
-- > main v0 = v4
-- > f v1    = v5
-- > g v2    = v3
-- > v3      = v0 + v2
-- > v4      = f 1
-- > v5      = g 2
--
-- There are a few assumptions on graphs:
--
-- * All nodes have unique identifiers.
--
-- * There are no cycles.
--
-- * The 'input' and 'inputType' tuples of each node should have the same shape.
--
-- * Each 'interfaceInput' (including the top-level one) refers to an 'Input'
-- node not referred to by any other interface.
--
-- * All 'Variable' references are valid (i.e. refer only to those variables
-- implicitly defined by each node).
--
-- * There should not be any cycles in the constraints introduced by
-- 'findLocalities'. (XXX Is this even possible?)
--
-- * Sub-function interfaces should be \"consistent\" with the input/output type
-- of the node. For example, the body of a while loop should have the same type
-- as the whole loop.
--
-- In the original program, @g@ was defined locally to @f@, and the addition was
-- done locally in @g@. But in the flat program, this hierarchy (called
-- /definition hierarchy/) is not represented. The flat program is of course not
-- valid Haskell (@v0@ and @v2@ are used outside of their scopes). The function
-- 'makeHierarchical' turns a flat graph into a hierarchical one that
-- corresponds to syntactically valid Haskell.
--
-- 'makeHierarchical' requires some explanation. First a few definitions:
--
-- * Nodes that have associated interfaces ('NoInline', 'IfThenElse', 'While'
-- and 'Parallel') are said to contain /sub-functions/. These nodes are called
-- /super nodes/. In the above program, the super node @v4@ contains the
-- sub-function @f@, and @v5@ contains the sub-function @g@.
--
-- * A definition @d@ is /local/ to a definition @e@ iff. @d@ is placed
-- somewhere within the definition of @e@ (i.e. inside an arbitrarily deeply
-- nested @where@ clause).
--
-- * A definition @d@ is /owned/ by a definition @e@ iff. @d@ is placed
-- immediately under the top-most @where@ clause of @e@. A definition may have
-- at most one owner.
--
-- The definition hierarchy thus specifies ownership between the definitions in
-- the program. There are two types of ownership:
--
-- * A super node is always the owner of its sub-functions.
--
-- * A sub-function may be the owner of some node definitions.
--
-- Assigning nodes to sub-functions in a useful way takes some work. It is done
-- by first finding out for each node which sub-functions it must be local to.
-- Each locality constraint gives an upper bound on where in the definition
-- hierarchy the node may be placed. There is one principle for introducing a
-- locality constraint:
--
-- * If node @v@ depends on the input of sub-function @f@, then @v@ must be
-- local to @f@.
--
-- The locality constraints for a graph can thus be found by tracing each
-- sub-function input in order to find the nodes that depend on it (see function
-- 'findLocalities'). In the above program, we have the sub-functions @f@ and
-- @g@ with the inputs @v1@ and @v2@ respectively. We can see immediately that
-- no node depends on @v1@, so we get no locality constraints for @f@. The only
-- node that depends on @v2@ is @v3@, so the program has a single locality
-- constraint: @v3@ is local to @g@. Nodes without constraints are simply taken
-- to be local to @main@. With this information, we can now rewrite the flat
-- program as
--
-- > main v0 = v4
-- >   where
-- >     v4 = f 1
-- >       where
-- >         f v1 = v5
-- >     v5 = g 2
-- >       where
-- >         g v2 = v3
-- >           where
-- >             v3 = v0 + v2
--
-- which is syntactically valid Haskell. Note that this program is slightly
-- different from the original which defined @g@ locally to @f@. However, in
-- general, we want definitions to be as \"global\" as possible in order to
-- maximize sharing. For example, we don't want to put definitions in the body
-- of a while loop unless they really depend on the loop state, because then
-- they will (probably, depending on implementation) be recomputed in every
-- iteration. Also note that in this program, it is not strictly necessary to
-- have the sub-functions owned by their super nodes -- @f@ and @g@ could have
-- been owned by @main@ instead. However, this would cause clashes if two
-- sub-functions have the same name. Having sub-functions owned by their super
-- nodes is also a way of keeping related definitions together in the program.
--
-- There is one caveat with the above method. Consider the following flat
-- program:
--
-- > main v0 = v4
-- > f v1    = v5
-- > g v2    = v3
-- > v3      = v1 + 2
-- > v4      = f 0
-- > v5      = g 1
--
-- Here, we get the locality constraint: @v3@ is local to @f@. However, to get a
-- valid definition hierarchy, we also need @v5@ to be local to @f@. This is
-- because @v5@ is the owner of @g@, and the output of @g@ is local to @f@. So
-- when looking for dependencies, we should let each super node depend on its
-- sub-function output, /except/ for the owner of the very sub-function that is
-- being traced (a function cannot be owned by itself).

module Feldspar.Core.Graph where



import qualified Data.Foldable as Fold
import Data.Function
import Data.List
import Data.Map (Map)
import qualified Data.Map as Map

import Feldspar.Utils
import Feldspar.Core.Types



-- | Node identifier. The nodes of a 'Graph' are expected to carry distinct
-- identifiers; a 'Variable' refers to a node through this id.
type NodeId = Int

-- | Variable represented by a node id and a tuple path. The path lists the
-- component indices to follow, outermost first, inside the output tuple of
-- the node. For example, in a definition (given in Haskell syntax)
--
-- > ((a,b),c) = f x
--
-- the variable @b@ would be represented as @(i,[0,1])@ (where @i@ is the id of
-- the @f@ node): component 0 of the result is @(a,b)@, and component 1 of
-- that pair is @b@.
type Variable = (NodeId, [Int])

-- | The source of a value is either constant data or a variable.
data Source
  = Constant PrimitiveData
      -- ^ Literal data given directly at the use site
  | Variable Variable
      -- ^ Reference to a variable defined by some node (node id and tuple
      -- path)
    deriving (Eq, Show)

-- | A node in the program graph. The input is given as a 'Source' tuple. The
-- output is implicitly defined by the 'nodeId' and the 'outputType'. For
-- example, a node with id @i@ and output type
--
-- > Tup [One ..., One ...]
--
-- has the implicit output
--
-- > Tup [One (i,[0]), One (i,[1])]
data Node = Node
  { nodeId     :: NodeId
      -- ^ Identifier of the node; expected to be unique within a graph
  , function   :: Function
      -- ^ What the node computes
  , input      :: Tuple Source
      -- ^ Where the input values come from
  , inputType  :: Tuple StorableType
      -- ^ Type of the input; should have the same shape as 'input'
  , outputType :: Tuple StorableType
      -- ^ Type of the output; its shape determines the variables that the
      -- node implicitly defines
  }
    deriving (Eq, Show)

-- | The interface of a (sub-)graph. The input is conceptually a
-- @Tuple Variable@, but all these variables refer to the same 'Input' node, so
-- it is sufficient to track the node id (the tuple shape can be inferred from
-- the 'interfaceInputType').
data Interface = Interface
  { interfaceInput      :: NodeId
      -- ^ Id of the 'Input' node that stands for the argument
  , interfaceOutput     :: Tuple Source
      -- ^ The sources that make up the result
  , interfaceInputType  :: Tuple StorableType
      -- ^ Type of the argument
  , interfaceOutputType :: Tuple StorableType
      -- ^ Type of the result
  }
    deriving (Eq, Show)

-- | Node functionality. The constructors that carry an 'Interface' are the
-- ones for which 'subFunctions' reports sub-functions.
data Function
  =
    -- | Input node: stands for the argument of the graph or of a sub-function
    -- (it is what 'interfaceInput' refers to)
    Input
    -- | Constant array
  | Array StorableData
    -- | Primitive function, identified by its name
  | Function String
    -- | Non-inlined function: its name and its interface (branch 0)
  | NoInline String Interface
    -- | Conditional. 'subFunctions' numbers the first interface as branch 0
    -- and the second as branch 1.
  | IfThenElse Interface Interface
    -- | While-loop: the continue condition (branch 0) followed by the body
    -- (branch 1)
  | While Interface Interface
    -- | Parallel tiling, with a single sub-function (branch 0)
  | Parallel Interface
    deriving (Eq, Show)

-- | A graph is a list of unique nodes with an interface. The order of the
-- node list carries no meaning: the 'Eq' instance ignores it.
data Graph = Graph
  { graphNodes     :: [Node]
      -- ^ All nodes of the program, including those of sub-functions
  , graphInterface :: Interface
      -- ^ Input and output of the program as a whole
  }

-- | Two graphs are equal when they have the same nodes and the same
-- interface. The order in which the nodes are listed is irrelevant.
instance Eq Graph
  where
    g1 == g2 = canonical g1 == canonical g2
      where
        -- Nodes sorted by id, so that list order cannot influence the result;
        -- the node lists are compared before the interfaces.
        canonical g =
            ( sortBy (compare `on` nodeId) (graphNodes g)
            , graphInterface g
            )

-- | A definition hierarchy. A hierarchy consists of a number of top-level
-- nodes, each one associated with its sub-functions, represented as
-- hierarchies (one per sub-function of the node; empty for nodes without
-- sub-functions). The nodes owned by a sub-function appear as the top-level
-- nodes in the corresponding hierarchy.
data Hierarchy = Hierarchy [(Node, [Hierarchy])]

-- | A graph with a hierarchical ordering of the nodes. If the hierarchy is
-- flattened it should result in a valid 'Graph'.
data HierarchicalGraph = HierGraph
  { graphHierarchy     :: Hierarchy
      -- ^ The nodes, arranged by ownership
  , hierGraphInterface :: Interface
      -- ^ Input and output of the program as a whole
  }

-- | The id of a node that contains a sub-function, i.e. a node whose
-- 'function' carries at least one 'Interface'.
type SuperNode = NodeId

-- | A sub-function, identified by its super node and a branch number. The
-- branch is used to distinguish between different sub-functions of the
-- same super node. For example, the continue condition of a while-loop has
-- branch number 0, and the body has number 1 (see 'subFunctions').
--
-- Note that the derived 'Eq' instance compares all four fields, whereas the
-- 'Ord' instance only looks at 'sfSuper' and 'sfBranch'.
data SubFunction = SubFunction
       { sfSuper  :: SuperNode
           -- ^ The node that contains the sub-function
       , sfBranch :: Int
           -- ^ Which of the super node's sub-functions this is
       , sfInput  :: NodeId
           -- ^ The input node of the sub-function's interface
       , sfOutput :: [NodeId]
           -- ^ The nodes referred to by the sub-function's interface output
       }
    deriving (Eq, Show)

-- | Sub-functions are ordered by super node first and by branch number
-- second. The input and output fields are never inspected, since these should
-- be equal anyway if the super and branch fields are equal.
instance Ord SubFunction
  where
    compare = compare `on` identity
      where
        -- Only these two fields take part in the ordering.
        identity sf = (sfSuper sf, sfBranch sf)

-- | Locality constraint: @Local f v@ states that the node with id @v@ must be
-- local to the sub-function @f@ (see 'findLocalities').
data Local = Local SubFunction NodeId
    deriving (Eq, Show)



-- | Collects, in traversal order, the ids of the nodes that the variables of
-- a source tuple refer to. Constants contribute nothing, and tuple paths are
-- dropped.
sourceNodes :: Tuple Source -> [NodeId]
sourceNodes = Fold.foldr collect []
  where
    collect (Variable (i, _)) ids = i : ids
    collect _                 ids = ids

-- | Maps each node id to the ids of the nodes that take (part of) its output
-- as input. A consumer is listed once for every time it refers to the node.
-- Nodes that do not appear as keys have no fanout.
fanout :: Graph -> Map NodeId [NodeId]
fanout graph = Map.fromListWith (++) (concatMap edgesInto (graphNodes graph))
  where
    -- One (producer, [consumer]) entry per variable in the consumer's input.
    edgesInto consumer =
        [ (producer, [nodeId consumer])
          | producer <- sourceNodes (input consumer)
        ]

-- | Builds a lookup function from node id to node. The underlying map is
-- constructed once per graph, so the result of @nodeMap graph@ should be
-- bound and reused rather than recomputed for each lookup.
--
-- Looking up an id that does not belong to any node of the graph is an
-- error; the error message names the offending id.
nodeMap :: Graph -> (NodeId -> Node)
nodeMap graph = lookupNode
  where
    nodesById = Map.fromList [(nodeId node, node) | node <- graphNodes graph]

    lookupNode i = case Map.lookup i nodesById of
        Just node -> node
        Nothing   -> error $
            "Feldspar.Core.Graph.nodeMap: unknown node id " ++ show i



-- | Lists all sub-functions in the graph, node by node in the order of
-- 'graphNodes'. The sub-functions of a single node are numbered from 0 in the
-- order in which their interfaces appear in the node's 'Function'.
subFunctions :: Graph -> [SubFunction]
subFunctions graph = concatMap ofNode (graphNodes graph)
  where
    ofNode node =
        zipWith (describe (nodeId node)) [0 ..] (interfaces (function node))

    describe super branch iface = SubFunction
        super
        branch
        (interfaceInput iface)
        (sourceNodes (interfaceOutput iface))

    -- The interfaces carried by each kind of node, in branch order.
    interfaces (NoInline _ f)    = [f]
    interfaces (IfThenElse t e)  = [t, e]
    interfaces (While cont body) = [cont, body]
    interfaces (Parallel ixf)    = [ixf]
    interfaces _                 = []



-- | Lists all locality constraints of the graph. For every sub-function, the
-- graph is traversed starting from the sub-function's input node, and each
-- node visited yields a constraint saying that it is local to that
-- sub-function (this includes the input node itself).
--
-- No record is kept of visited nodes, so a node that can be reached along
-- several paths is reported once per path.
findLocalities :: Graph -> [Local]
findLocalities graph = concatMap localitiesOf subFuns
  where
    subFuns   = subFunctions graph
    consumers = fanout graph

    -- Edges from each sub-function output node to the super node that
    -- contains the sub-function.
    outputToSuper = Map.fromListWith (++)
      [(out, [sfSuper sf]) | sf <- subFuns, out <- sfOutput sf]

    -- All constraints for one sub-function, found by following the
    -- dependencies of its input.
    localitiesOf sf = visit (sfInput sf)
      where
        ownOutputs = sfOutput sf

        visit n = Local sf n : concatMap visit successors
          where
            -- A super node depends on the outputs of its sub-functions, but
            -- no such edge is followed from an output of the sub-function
            -- being traced.
            -- NOTE(review): '!!!' comes from Feldspar.Utils; presumably a
            -- lookup that yields [] for missing keys -- confirm there.
            viaSuper
              | n `elem` ownOutputs = []
              | otherwise           = (outputToSuper !!! n)
            successors = (consumers !!! n) ++ viaSuper



-- | Returns a total ordering between all super nodes in a graph, such that if
-- node @v@ is local to sub-function @f@, then @v@ maps to a lower number than
-- the owner of @f@. The converse is not necessarily true. The second argument
-- gives the locality constraints for each node in the graph (top-level nodes
-- may be left undefined).
orderSuperNodes :: Graph -> Map NodeId [SubFunction] -> Map SuperNode Int
orderSuperNodes graph locals = Map.fromList (topSort enclosing `zip` [0 ..])
  where
    -- For every super node, the super nodes of the sub-functions it is
    -- constrained to be local to. A super node with several sub-functions
    -- contributes one (identical) entry per sub-function. The actual ordering
    -- is the transitive closure of these edges, which should form a dag.
    -- NOTE(review): 'topSort' and '!!!' come from Feldspar.Utils; the
    -- direction of the resulting numbering depends on 'topSort' -- confirm
    -- there.
    enclosing = Map.fromListWith (++)
        [ (super, map sfSuper (locals !!! super))
          | super <- map sfSuper (subFunctions graph)
        ]

-- | Returns the minimal sub-function according to the given owner ordering,
-- i.e. the one whose super node maps to the lowest number. If several
-- sub-functions are minimal, the first one in the list is returned.
--
-- The list must not be empty, and every super node occurring in it must be a
-- key of the ordering; violating either condition is an error.
minimalSubFun :: Map SuperNode Int -> [SubFunction] -> SubFunction
minimalSubFun _      []  = error
    "Feldspar.Core.Graph.minimalSubFun: empty list of sub-functions"
minimalSubFun ownOrd sfs = minimumBy (compare `on` rank) sfs
  where
    -- Position of the sub-function's super node in the owner ordering.
    rank = (ownOrd Map.!) . sfSuper

-- | Arranges the nodes in ascending order of node id.
sortNodes :: [Node] -> [Node]
sortNodes = sortBy byId
  where
    byId a b = compare (nodeId a) (nodeId b)



-- | Makes a hierarchical graph from a flat one. The node lists in the hierarchy
-- are always sorted according to node id.
--
-- Every node that has locality constraints is assigned to (owned by) the
-- minimal sub-function among those it must be local to; nodes without
-- constraints end up at the top level. The interface is passed through
-- unchanged.
makeHierarchical :: Graph -> HierarchicalGraph
makeHierarchical graph@(Graph allNodes iface) =
    HierGraph (mkHierarchy topLevel) iface
  where
    locals :: Map NodeId [SubFunction]
    locals = Map.fromListWith (++)
        [(i, [sf]) | Local sf i <- findLocalities graph]
      -- The locality constraints for each node. Nodes that are not in the map
      -- have no constraints.

    owner :: Map NodeId SubFunction
    owner = fmap (minimalSubFun (orderSuperNodes graph locals)) locals
      -- The owner of each node. Nodes that are not in the map have no owner.

    nodeLookup :: NodeId -> Node
    nodeLookup = nodeMap graph

    ownedBy :: Map SubFunction [Node]
    ownedBy = fmap (sortNodes . map nodeLookup) (invertMap owner)
      -- The nodes owned by each sub-function, sorted by id. Defined here
      -- rather than inside 'subFunHier', so that the inversion is shared by
      -- all sub-functions instead of being redone for each of them.

    mkHierarchy :: [Node] -> Hierarchy
    mkHierarchy ns = Hierarchy (ns `zip` map subHierarchies ns)

    subFunHier :: SuperNode -> Int -> Hierarchy
    subFunHier super branch =
        case Map.lookup key ownedBy of
          Just owned -> mkHierarchy owned
          Nothing    -> error $
              "Feldspar.Core.Graph.makeHierarchical: no nodes owned by "
                ++ "sub-function " ++ show (super, branch)
            -- Should not happen: each sub-function contains at least one
            -- node (the input).
      where
        key = SubFunction super branch undefined undefined
          -- Only the super node and the branch take part in the 'Ord'
          -- comparison used for the lookup, so the remaining fields are
          -- never evaluated.

    subHierarchies :: Node -> [Hierarchy]
    subHierarchies node = map (subFunHier (nodeId node)) branches
      where
        -- Branch numbers as assigned by 'subFunctions'.
        branches = case function node of
            NoInline _ _   -> [0]
            IfThenElse _ _ -> [0,1]
            While _ _      -> [0,1]
            Parallel _     -> [0]
            _              -> []

    topLevel :: [Node]
    topLevel = sortNodes
        [ node
          | node <- allNodes
          , nodeId node `Map.notMember` owner
        ]
      -- The nodes that don't have any owner

-------------------- 
-- show function 
-------------------- 

-- | A graph is shown in the indented layout produced by 'prP', starting at
-- indentation level 0.
instance Show Graph where
  show = prP 0

-- | A hierarchical graph is shown in the indented layout produced by 'prP',
-- starting at indentation level 0.
instance Show HierarchicalGraph where
  show = prP 0


-- | Pretty-printing of graph structures. The 'Int' argument is the current
-- indentation level, measured in spaces (see 'tab'); how it is used is up to
-- each instance.
class PrP a where 
    prP :: Int -> a -> String 

-- | Indentation string: the given number of spaces (empty if the number is
-- zero or negative).
tab :: Int -> String
tab sc = replicate sc ' '

-- | Renders every element with the given function and joins the results,
-- placing the separator string between consecutive elements. An empty list
-- gives the empty string.
listprint :: (a->String) -> String -> [a] -> String
listprint render sep = intercalate sep . map render

-- | Prints the node list followed by the interface. The nodes are indented
-- two levels deeper than the graph itself; the interface is printed with
-- 'show'.
instance PrP Graph where
  prP sc gr = concat
    [ tab sc, "Graph {\n"
    , tab (sc + 1), "graphNodes = [\n"
    , prP (sc + 2) (graphNodes gr)
    , "],\n"
    , tab (sc + 1), "graphInterface = \n"
    , tab (sc + 3), show (graphInterface gr)
    , "\n}"
    ]

-- | Prints the nodes one after another, each one indented by @sc@ spaces,
-- with a comma and a line break between consecutive nodes.
instance PrP [Node] where
  prP sc = intercalate ",\n" . map indented
    where
      indented node = tab sc ++ prP sc node

-- | Prints a node as a record with one field per line. The first line is not
-- indented here (the caller is expected to do that); the remaining fields are
-- indented so that they line up under the first one.
instance PrP Node where
  prP sc node = "Node {" ++ intercalate ",\n" fields ++ "}"
    where
      fields =
        [ "nodeId = " ++ show (nodeId node)
        , field "function" (prP (sc + 8) (function node))
        , field "input" (show (input node))
        , field "inputType" (show (inputType node))
        , field "outputType" (show (outputType node))
        ]

      -- A continuation line: indentation, field name and rendered value.
      field label rendered = tab (sc + 6) ++ label ++ " = " ++ rendered

-- | Functions that carry interfaces are printed on fresh lines: a header line
-- naming the construct, followed by one line per interface (printed with
-- 'show'). All other functions are printed with 'show' on the current line.
instance PrP Function where
  prP sc fun = case fun of
      IfThenElse thenI elseI -> block "IfThenElse" [thenI, elseI]
      Parallel ixf           -> block "Parallel " [ixf]
      While cont body        -> block "While" [cont, body]
      NoInline name iface    -> block ("NoInline \"" ++ name ++ "\" ") [iface]
      _                      -> show fun
    where
      -- Line break, indented header, then the indented interfaces on
      -- separate lines.
      block header ifaces =
        "\n" ++ tab (sc + 1) ++ header ++ "\n"
          ++ intercalate "\n" [tab (sc + 2) ++ show iface | iface <- ifaces]

-- | Prints the hierarchy followed by the interface, each on its own line
-- below its field name. The interface is printed with 'show'.
instance PrP HierarchicalGraph where
  prP sc hgr = concat
    [ "HierGraph {\n"
    , tab (sc + 1), "graphHierarchy =\n"
    , tab (sc + 2), prP (sc + 2) (graphHierarchy hgr)
    , ",\n"
    , tab (sc + 1), "hierGraphInterface =\n"
    , tab (sc + 2), show (hierGraphInterface hgr)
    , "\n}"
    ]

-- | Prints the entries of the hierarchy between brackets, one indentation
-- level deeper than the hierarchy itself.
instance PrP Hierarchy where
  prP sc (Hierarchy entries) =
    concat ["Hierarchy [\n", prP (sc + 1) entries, "\n", tab sc, "]"]

-- | Prints the entries one after another at the same indentation level, with
-- a comma and a line break between consecutive entries.
instance PrP [(Node, [Hierarchy])] where
  prP sc entries = intercalate ",\n" (map (prP sc) entries)

-- | Prints a node together with the hierarchies of its sub-functions as a
-- pair: the node on the first line(s), the bracketed list of hierarchies
-- below it.
instance PrP (Node, [Hierarchy]) where
  prP sc (node, subs) = concat
    [ tab sc, "(", prP (sc + 1) node
    , ",\n"
    , tab sc, "[", prP (sc + 1) subs, "])"
    ]


-- | Prints the hierarchies one after another; consecutive hierarchies are
-- separated by a comma, a line break and the current indentation.
instance PrP [Hierarchy] where
  prP sc hiers = intercalate (",\n" ++ tab sc) (map (prP sc) hiers)