--
-- Copyright (c) 2009-2010, ERICSSON AB All rights reserved.
-- 
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
-- 
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
-- 
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-- ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-- BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
-- OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-- SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-- INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-- CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-- ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-- THE POSSIBILITY OF SUCH DAMAGE.
--

{-# LANGUAGE FlexibleInstances #-}

module Feldspar.Compiler.Transformation.Lifting where

import Feldspar.Core.Graph
import Feldspar.Core.Types hiding (typeOf)
import Feldspar.Compiler.Transformation.GraphUtils
import Data.List


-- | Entry point of the lifting pass: rewrite the graph's hierarchy with
-- 'replaceNoInlinesHr'. The graph's interface is carried over unchanged.
replaceNoInlines :: HierarchicalGraph -> HierarchicalGraph
replaceNoInlines g = HierGraph liftedHierarchy (hierGraphInterface g)
  where
    liftedHierarchy = replaceNoInlinesHr (graphHierarchy g)

-- | Apply 'replaceNoInlinesHr' to each hierarchy of a list.
replaceNoInlinesHrList :: [Hierarchy] -> [Hierarchy]
replaceNoInlinesHrList = map replaceNoInlinesHr

-- | Rewrite every (node, sub-hierarchies) entry of a hierarchy with
-- 'replaceNoInlinesNode'.
replaceNoInlinesHr :: Hierarchy -> Hierarchy
replaceNoInlinesHr (Hierarchy entries) =
  Hierarchy [replaceNoInlinesNode entry | entry <- entries]

-- | Rewrite one hierarchy entry.
--
-- For a @NoInline@ node, collect every variable that is referenced inside the
-- node but whose node id does not occur anywhere in the node's hierarchies
-- (see 'mustChange'). Such a variable can be referenced in the node's
-- interface outputs, or as an input or nested interface output of a node in
-- its hierarchies.
--
-- If any are found:
--
--   * they are appended to the node's inputs with 'changeInp', and
--   * references are rewritten with 'replaceVars' (defined elsewhere).
--
-- The rewrite table maps the original input's variables through
-- 'inpVarsChange' (leading @0@ in the path). Each lifted variable is mapped
-- through 'varChange' to its new position on the interface input node.
--
-- In every case the sub-hierarchies are then processed recursively.
replaceNoInlinesNode :: (Node,[Hierarchy]) -> (Node,[Hierarchy])
replaceNoInlinesNode (n, hs) =
  case function n of
    NoInline _ iface -> liftNoInline iface
    _                -> (n, replaceNoInlinesHrList hs)
  where
    liftNoInline iface
      | null replaceList = (n, replaceNoInlinesHrList hs)
      | otherwise        = (nNew, replaceNoInlinesHrList hsNew)
      where
        inpId = interfaceInput iface
        -- Interface outputs first, then every hierarchy in order. foldl'
        -- keeps the accumulated change list from becoming a thunk chain.
        replaceList = foldl' (collectChangesHr (inpId, hs))
                             (collectChangesInterface iface hs)
                             hs
        (nNew_, hsNew_) = changeInp inpId replaceList (n, hs)
        -- The entry keyed on (inpId, []) stands for the input node itself.
        -- NOTE(review): how keys are matched against longer paths is decided
        -- by replaceVars (defined elsewhere); confirm it is a prefix match.
        fullReplaceList = ((inpId, []), inpVarsChange) : map fst replaceList
        nNew  = nNew_ { function = replaceVars fullReplaceList (function nNew_) }
        hsNew = map (replaceVars fullReplaceList) hsNew_


-- | Widen a @NoInline@ node so that it receives the lifted variables as extra
-- inputs.
--
-- @changeInp inpNode changes (node, hs)@:
--
--   * appends the changed variables (first component of each change) to the
--     node's 'input' tuple, after the original input;
--   * appends the change types to the node's 'inputType' and to the
--     'interfaceInputType' of its 'NoInline' interface;
--   * appends the same types to the 'outputType' of every node with id
--     @inpNode@ found at the top level of each hierarchy in @hs@ (nested
--     sub-hierarchies are not searched).
--
-- In each widened tuple the original value stays at position 0.
-- Only 'NoInline' nodes are accepted; any other function is reported with an
-- explicit error instead of an anonymous pattern-match failure.
changeInp :: NodeId -> [((Variable, Variable -> Variable), Tuple StorableType)] -> (Node, [Hierarchy]) -> (Node, [Hierarchy])
changeInp inpNode chLs (node, hs) = (newNode, newHs)
  where
    newTyps    = map snd chLs
    liftedVars = map (fst . fst) chLs
    newNode = Node (nodeId node)
                   (addIfcInpTypes (function node))
                   (Tup (input node : map (One . Variable) liftedVars))
                   (Tup (inputType node : newTyps))
                   (outputType node)
    newHs = map addOutTypesHr hs

    addIfcInpTypes (NoInline str ifc@(Interface {interfaceInputType = ifcType}))
      = NoInline str ifc {interfaceInputType = Tup (ifcType : newTyps)}
    addIfcInpTypes _
      = error "Lifting.changeInp: expected a NoInline node"

    addOutTypesHr :: Hierarchy -> Hierarchy
    addOutTypesHr (Hierarchy entries) = Hierarchy (map addOutTypesNode entries)

    addOutTypesNode :: (Node, [Hierarchy]) -> (Node, [Hierarchy])
    addOutTypesNode (nd, subHs)
      | nodeId nd == inpNode = (nd {outputType = Tup (outputType nd : newTyps)}, subHs)
      | otherwise            = (nd, subHs)
      
-- | Collect the interface outputs of a @NoInline@ node that refer to
-- variables whose node id does not occur in the given hierarchies.
--
-- Each such output becomes a change entry (see 'genChange') targeting the
-- interface input node. Entries are numbered from 1 upward.
collectChangesInterface :: Interface -> [Hierarchy] -> [((Variable, Variable -> Variable), Tuple StorableType)]
collectChangesInterface iface hs = zipWith (curry (genChange inpId)) [1 ..] escaping
  where
    inpId    = interfaceInput iface
    outputs  = tupleZipList (interfaceOutput iface, interfaceOutputType iface)
    escaping = filter (mustChange hs . fst) outputs
   
  
-- | Build the change entry for one lifted source.
--
-- @genChange inpId (index, (src, typ))@ pairs the variable referenced by
-- @src@ with a rewrite that redirects it to @(inpId, [index])@ (see
-- 'varChange'). The variable's type is wrapped in 'One'.
--
-- Sources reaching this function in this module have been filtered with
-- 'mustChange', which only accepts 'Variable' sources. Anything else is
-- reported with an explicit error.
genChange :: NodeId -> (NodeId, (Source, StorableType)) -> ((Variable, Variable -> Variable), Tuple StorableType)
genChange inpId (index, (Variable (srcId, path), typ)) =
  (((srcId, path), varChange inpId index), One typ)
genChange _ _ =
  error "Lifting.genChange: only Variable sources can be lifted"

-- | Whether a source has to be lifted. This is 'True' exactly for a
-- 'Variable' whose node id does not occur anywhere in the given hierarchies.
-- Every other kind of source is never lifted.
mustChange :: [Hierarchy] -> Source -> Bool
mustChange hs (Variable (srcId, _)) = notInHr srcId hs
mustChange _  _                     = False

-- | Prefix a variable's path with @0@. This is used for the original input,
-- which becomes component 0 once 'changeInp' widens the input tuple.
inpVarsChange :: Variable -> Variable
inpVarsChange (nid, path) = (nid, 0 : path)

-- | Rewrite that ignores the old variable and always yields
-- @(nid, [index])@, i.e. position @index@ on node @nid@.
varChange :: NodeId -> Int -> Variable -> Variable
varChange nid index = const (nid, [index])



-- | Structures that can be scanned for variables that must be lifted into a
-- @NoInline@ node's inputs.
--
-- @collectChangesHr (inpId, hs) acc x@ returns @acc@ followed by the change
-- entries found in @x@ (possibly none). A variable qualifies when its node id
-- does not occur in @hs@ (see 'mustChange').
class CollectChangesHr t where
   collectChangesHr
     :: (NodeId, [Hierarchy])
        -- ^ interface input node id, and the hierarchies whose nodes count as local
     -> [((Variable, Variable -> Variable), Tuple StorableType)]
        -- ^ changes collected so far
     -> t
        -- ^ value to scan
     -> [((Variable, Variable -> Variable), Tuple StorableType)]

-- | Scan each (node, sub-hierarchies) entry of the hierarchy in order.
-- The strict left fold avoids building a chain of accumulator thunks.
instance CollectChangesHr Hierarchy where
  collectChangesHr nhs changesList (Hierarchy nodeHsList) =
    foldl' (collectChangesHr nhs) changesList nodeHsList

-- | Scan the node itself first, then each of its sub-hierarchies in order.
-- The strict left fold avoids building a chain of accumulator thunks.
instance CollectChangesHr (Node, [Hierarchy]) where
  collectChangesHr nhs changesList (node, hsList) =
    foldl' (collectChangesHr nhs) (collectChangesHr nhs changesList node) hsList

-- | Scan a node: first its inputs whose variables must be lifted, then the
-- interfaces carried by its function.
instance CollectChangesHr Node where
  collectChangesHr nhs changesList node = collectChangesHr nhs afterInputs (function node)
    where
      typedInputs    = tupleZipList (input node, inputType node)
      externalInputs = filter (mustChange (snd nhs) . fst) typedInputs
      afterInputs    = collectChangesHr nhs changesList externalInputs

                
-- | Append one change entry per source. Numbering continues after the
-- entries already collected. The sources are not filtered here.
instance CollectChangesHr [(Source,StorableType)] where
  collectChangesHr (inpId, _) changesList sourceList =
      changesList ++ zipWith (curry (genChange inpId)) [firstIndex ..] sourceList
    where
      firstIndex = length changesList + 1

-- | Scan the interfaces a function carries. A two-interface function has its
-- first interface scanned before its second. Functions without interfaces
-- add nothing.
instance CollectChangesHr Function where
  collectChangesHr nhs acc fun = case fun of
      NoInline _ ifc       -> scan acc ifc
      Parallel ifc         -> scan acc ifc
      IfThenElse ifcA ifcB -> scan (scan acc ifcA) ifcB
      While ifcA ifcB      -> scan (scan acc ifcA) ifcB
      -- Nothing to scan; the (node id, hierarchies) pair is still evaluated.
      _                    -> nhs `seq` acc
    where
      scan soFar ifc = collectChangesHr nhs soFar (ifc :: Interface)

-- | Collect the interface outputs that must be lifted. Numbering continues
-- after the entries already collected; the list instance does the numbering.
instance CollectChangesHr Interface where
  collectChangesHr nhs@(_, hs) changesList ifc = collectChangesHr nhs changesList escaping
    where
      outputs  = tupleZipList (interfaceOutput ifc, interfaceOutputType ifc)
      escaping = filter (mustChange hs . fst) outputs

-- | Structures that can be searched for a node id.
-- @notInHr nid x@ is 'True' when no node in @x@ has id @nid@.
class NotInHr t where
    notInHr :: NodeId -> t -> Bool

-- | The id is absent from a list of hierarchies when it is absent from each.
instance NotInHr [Hierarchy] where
  notInHr nid = all (notInHr nid)

-- | The id is absent from a hierarchy when it is absent from every entry.
instance NotInHr Hierarchy where
  notInHr nid (Hierarchy entries) = all (notInHr nid) entries

-- | Check the node itself first, then (only if it differs) its sub-hierarchies.
instance NotInHr (Node, [Hierarchy]) where
  notInHr nid (node, subHs)
    | notInHr nid node = notInHr nid subHs
    | otherwise        = False

-- | A single node does not contain the id when its own id differs.
instance NotInHr Node where
  notInHr nid = (nid /=) . nodeId