-- Copyright (c) 2009, ERICSSON AB
-- All rights reserved.
--
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
--
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
--
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

{-# LANGUAGE IncoherentInstances #-}

-- | This module represents core programs as typed expressions (see 'Expr' /
-- 'Data'). The idea is for programmers to use an interface based on 'Data',
-- while back-end tools use the 'Graph' representation. The function 'toGraph'
-- is used to convert between the two representations.

module Feldspar.Core.Expr where



import Control.Monad.State
import Control.Monad.Writer
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe
import Data.Unique

import Types.Data.Num

import Feldspar.Core.Ref (Ref)
import qualified Feldspar.Core.Ref as Ref
import Feldspar.Core.Types
import Feldspar.Core.Graph hiding (function, Function (..))
import qualified Feldspar.Core.Graph as Graph
import Feldspar.Core.Show



-- * Expressions

-- | A wrapper around 'Expr' to allow observable sharing (see
-- "Feldspar.Core.Ref").
--
-- The constructor packs a @Typeable a@ context, so pattern matching on
-- @Data _@ brings that context into scope (as done by e.g. 'sourceNode' and
-- 'buildGraph', which match on @Data _@ without using the reference).
data Data a = Typeable a => Data (Ref (Expr a))

-- | Equality of 'Data' values delegates to the 'Eq' instance of the wrapped
-- 'Ref' (presumably reference identity rather than structural equality of the
-- expressions -- see "Feldspar.Core.Ref").
instance Eq (Data a)
  where
    (==) (Data r1) (Data r2) = r1 == r2

-- | Ordering of 'Data' values delegates to the ordering of the wrapped 'Ref'.
instance Ord (Data a)
  where
    compare (Data r1) (Data r2) = compare r1 r2



-- | Wraps an expression in a new reference (created by 'Ref.ref').
ref :: Typeable a => Expr a -> Data a
ref = Data . Ref.ref

-- | The identity of the reference inside a 'Data' value.
refId :: Data a -> Unique
refId d = case d of
    Data r -> Ref.refId r

-- | The expression stored behind a 'Data' reference.
deref :: Data a -> Expr a
deref d = case d of
    Data r -> Ref.deref r

-- | The storable type representation of a 'Data' value's type. The argument
-- itself is never inspected; only its type is used.
typeOfData :: forall a . Typeable a => Data a -> Tuple StorableType
typeOfData = const (typeOf (T :: T a))



-- | Typed core language expressions. A value of type @Expr a@ can be thought of
-- as a representation of a program that computes a value of type @a@.
--
-- Sub-expressions are held as 'Data' values, i.e. behind references. This is
-- what lets the graph conversion ('buildGraph') recognize shared nodes.
data Expr a
  where
    -- Stands for the argument of a sub-function while its graph is built
    -- (see 'buildSubFun'). 'evalE' refuses to evaluate it.
    Input :: Expr a  -- XXX Risky to rely on observable sharing?
    -- A constant. Primitive constants become graph constants; other storable
    -- values (e.g. arrays) become source nodes (see 'buildGraph').
    Value :: Storable a => a -> Expr a

    -- Tuple construction, and projection of the component at the type-level
    -- index @n@.
    Tuple2   :: Data a -> Data b -> Expr (a,b)
    Tuple3   :: Data a -> Data b -> Data c -> Expr (a,b,c)
    Tuple4   :: Data a -> Data b -> Data c -> Data d -> Expr (a,b,c,d)
    GetTuple :: GetTuple n a => T n -> Data a -> Expr (Part n a)

    -- A named primitive function together with its evaluation semantics.
    -- Multi-argument functions take a single tuple argument (see 'function2').
    Function :: String -> (a -> b) -> (Data a -> Expr b)

    -- A named sub-function that should be kept as a separate sub-graph
    -- instead of being inlined (see 'noInline').
    NoInline
      :: (Typeable a, Typeable b)
      => String -> Ref.Ref (Data a -> Data b) -> (Data a -> Expr b)

    -- Selects between two sub-functions applied to the same argument.
    IfThenElse
      :: (Typeable a, Typeable b)
      => Data Bool           -- Condition
      -> (Data a -> Data b)  -- If branch
      -> (Data a -> Data b)  -- Else branch
      -> (Data a -> Expr b)

    -- Applies the body to the state for as long as the condition holds.
    While
      :: Typeable a
      => (Data a -> Data Bool)  -- Continue?
      -> (Data a -> Data a)     -- Body
      -> Data a                 -- Initial state
      -> Expr a                 -- Final state

    -- Computes the array elements from their indices, for every index below
    -- the dynamic size.
    Parallel
      :: (NaturalT n, Storable a)
      => Data Int              -- Dynamic size (must be <= array size)
      -> (Data Int -> Data a)  -- Index mapping
      -> Expr (n :> a)         -- Result vector

  -- XXX Some Typeable constraints are needed because the sub-functions need to
  --     be applied to input. Perhaps it's better to scrap the hidden context in
  --     Data and put Typeable context on all Expr constructors instead?



-- | Computable types. A computable value completely represents a core program,
-- in such a way that @internalize . externalize@ preserves semantics, but not
-- necessarily syntax.
--
-- The terminology used in this class comes from thinking of the 'Data' type as
-- the \"internal core language\" and the core API as the \"external core
-- language\".
--
-- This module gives instances for 'Data' itself and for tuples (of up to four
-- components) of computable types.
class Typeable (Internal a) => Computable a
  where
    -- | The internal representation of the type @a@ (without the 'Data'
    -- constructor).
    type Internal a

    -- | Convert to internal representation
    internalize :: a -> Data (Internal a)

    -- | Convert to external representation
    externalize :: Data (Internal a) -> a

-- | A 'Data' value is already in internal form, so both conversions return
-- their argument unchanged.
instance Storable a => Computable (Data a)
  where
    type Internal (Data a) = a

    internalize d = d
    externalize d = d

-- | Pairs are internalized as a single 'Tuple2' node. They are externalized by
-- projecting out each component with 'GetTuple'.
instance (Computable a, Computable b) => Computable (a,b)
  where
    type Internal (a,b) = (Internal a, Internal b)

    internalize (x,y) = ref (Tuple2 (internalize x) (internalize y))

    externalize pair =
        ( externalize (ref (GetTuple (T::T D0) pair))
        , externalize (ref (GetTuple (T::T D1) pair))
        )

-- | Triples are internalized as a single 'Tuple3' node. They are externalized
-- by projecting out each component with 'GetTuple'.
instance (Computable a, Computable b, Computable c) => Computable (a,b,c)
  where
    type Internal (a,b,c) = (Internal a, Internal b, Internal c)

    internalize (x,y,z) =
        ref (Tuple3 (internalize x) (internalize y) (internalize z))

    externalize triple =
        ( externalize (ref (GetTuple (T::T D0) triple))
        , externalize (ref (GetTuple (T::T D1) triple))
        , externalize (ref (GetTuple (T::T D2) triple))
        )

-- | Quadruples are internalized as a single 'Tuple4' node. They are
-- externalized by projecting out each component with 'GetTuple'.
instance
    ( Computable a
    , Computable b
    , Computable c
    , Computable d
    ) =>
      Computable (a,b,c,d)
  where
    type Internal (a,b,c,d) = (Internal a, Internal b, Internal c, Internal d)

    internalize (w,x,y,z) = ref (Tuple4
        (internalize w)
        (internalize x)
        (internalize y)
        (internalize z))

    externalize quad =
        ( externalize (ref (GetTuple (T::T D0) quad))
        , externalize (ref (GetTuple (T::T D1) quad))
        , externalize (ref (GetTuple (T::T D2) quad))
        , externalize (ref (GetTuple (T::T D3) quad))
        )



-- | Lifts a function on external representations to one on internal
-- representations. The argument is externalized, @f@ is applied, and the
-- result is internalized again.
wrap :: (Computable a, Computable b) =>
    (a -> b) -> (Data (Internal a) -> Data (Internal b))
wrap f x = internalize (f (externalize x))

-- | The counterpart of 'wrap'. It lifts a function on internal
-- representations to one on external representations.
unwrap :: (Computable a, Computable b) =>
    (Data (Internal a) -> Data (Internal b)) -> (a -> b)
unwrap f x = externalize (f (internalize x))



-- | The semantics of expressions: computes the value an expression stands for.
-- Evaluating 'Input' is an error, because it only stands for the argument of
-- a sub-function while that sub-function's graph is being built.
evalE :: Expr a -> a

evalE Input     = error "evaluating Input"
evalE (Value x) = x

evalE (Tuple2 x y)      = (evalD x, evalD y)
evalE (Tuple3 x y z)    = (evalD x, evalD y, evalD z)
evalE (Tuple4 x y z w)  = (evalD x, evalD y, evalD z, evalD w)
evalE (GetTuple ix tup) = getTup ix (evalD tup)

evalE (Function _ sem arg)  = sem (evalD arg)
evalE (NoInline _ fRef arg) = evalD (Ref.deref fRef arg)
evalE (IfThenElse cond thn els arg)
    | evalD cond = evalD (thn arg)
    | otherwise  = evalD (els arg)

-- The condition is tested before every step, so the body may run zero times.
evalE (While keepGoing step start) =
    evalD (until (not . evalD . keepGoing) step start)

-- Each element is computed from its index, for all indices below the dynamic
-- size.
evalE (Parallel sz ixf) =
    mapArray (evalD . ixf . value) (fromList [0 .. len - 1])
  where
    len = evalD sz :: Int



-- | Evaluation of 'Data': dereferences and evaluates the wrapped expression.
evalD :: Data a -> a
evalD d = evalE (deref d)

-- | Evaluation of any 'Computable' type, via its internal representation.
eval :: Computable a => a -> Internal a
eval x = evalD (internalize x)



-- | A placeholder 'Show' instance. It never inspects its argument and shows
-- every value as the same fixed string. Needed for the @Num@ instance.
instance Primitive a => Show (Data a)
  where
    show _ = "... :: Data a"

-- | Arithmetic on 'Data'. Each operation is constant-folded when all of its
-- arguments are 'Value' nodes; otherwise it builds a named primitive function
-- node (see 'functionFold' / 'functionFold2').
instance (Num n, Primitive n) => Num (Data n)
  where
    x + y       = functionFold2 "(+)" (+) x y
    x - y       = functionFold2 "(-)" (-) x y
    x * y       = functionFold2 "(*)" (*) x y
    abs x       = functionFold "abs" abs x
    signum x    = functionFold "signum" signum x
    fromInteger = value . fromInteger



-- | Division on floating-point 'Data'. It is constant-folded when both
-- operands are 'Value' nodes.
instance Fractional (Data Float)
  where
    x / y        = functionFold2 "(/)" (/) x y
    fromRational = value . fromRational



-- | Internal function for constructing storable values, including
-- non-primitive ones such as arrays (see 'array').
value_ :: Storable a => a -> Data a
value_ x = ref (Value x)

-- | A primitive value (a program that computes a constant value)
value :: Primitive a => a -> Data a
value x = value_ x

-- | The unit constant.
unit :: Data ()
unit = value ()

-- | The boolean constants.
true, false :: Data Bool
true  = value True
false = value False

-- | Constructs a constant array from a (possibly nested) list. For example,
--
-- > array [[1,2,3],[4,5]] :: Data (D2 :> D4 :> Int)
--
-- is a 2x4-element array of @Int@s, with the first row initialized to @[1,2,3]@
-- and the second row to @[4,5]@.
array :: (NaturalT n, Storable a) => ListBased (n :> a) -> Data (n :> a)
array xs = value_ (fromList xs)

-- | Returns the size of each level of a multi-dimensional array, starting with
-- the outermost level. Only the array's type is consulted, not its contents.
size :: (NaturalT n, Storable a) => Data (n :> a) -> [Int]
size arr =
    let One (StorableType dims _) = typeOfData arr
    in  dims



-- | A one-argument primitive function. The first argument is the name of the
-- function, and the second argument gives its evaluation semantics. No
-- constant folding is done (see 'functionFold' for that).
function :: (Storable a, Storable b) => String -> (a -> b) -> (Data a -> Data b)
function fun f a = ref (Function fun f a)



-- | A two-argument primitive function. The arguments are packed into a single
-- 'Tuple2' node, and the semantics are adapted to take that pair.
function2
    :: ( Storable a
       , Storable b
       , Storable c
       )
    => String -> (a -> b -> c) -> (Data a -> Data b -> Data c)

function2 fun f a b = ref (Function fun sem (ref (Tuple2 a b)))
  where
    sem (x,y) = f x y



-- | A three-argument primitive function. The arguments are packed into a
-- single 'Tuple3' node, and the semantics are adapted to take that triple.
function3
    :: ( Storable a
       , Storable b
       , Storable c
       , Storable d
       )
    => String -> (a -> b -> c -> d) -> (Data a -> Data b -> Data c -> Data d)

function3 fun f a b c = ref (Function fun sem (ref (Tuple3 a b c)))
  where
    sem (x,y,z) = f x y z



-- | A four-argument primitive function. The arguments are packed into a single
-- 'Tuple4' node, and the semantics are adapted to take that quadruple.
function4
    :: ( Storable a
       , Storable b
       , Storable c
       , Storable d
       , Storable e
       )
    => String
    -> (a -> b -> c -> d -> e)
    -> (Data a -> Data b -> Data c -> Data d -> Data e)

function4 fun f a b c d = ref (Function fun sem (ref (Tuple4 a b c d)))
  where
    sem (w,x,y,z) = f w x y z



-- | A one-argument function with constant folding. If the argument is a
-- 'Value', the function is applied right away; otherwise a 'function' node is
-- built.
functionFold
    :: (Storable a, Storable b) => String -> (a -> b) -> (Data a -> Data b)

functionFold fun f a =
    case deref a of
      Value x -> value_ (f x)
      _       -> function fun f a



-- | A two-argument function with constant folding. It folds only when both
-- arguments are 'Value' nodes; otherwise a 'function2' node is built.
functionFold2
    :: ( Storable a
       , Storable b
       , Storable c
       )
    => String -> (a -> b -> c) -> (Data a -> Data b -> Data c)

functionFold2 fun f a b =
    case (deref a, deref b) of
      (Value x, Value y) -> value_ (f x y)
      _                  -> function2 fun f a b



-- | A three-argument function with constant folding. It folds only when all
-- arguments are 'Value' nodes; otherwise a 'function3' node is built.
functionFold3
    :: ( Storable a
       , Storable b
       , Storable c
       , Storable d
       )
    => String -> (a -> b -> c -> d) -> (Data a -> Data b -> Data c -> Data d)

functionFold3 fun f a b c =
    case (deref a, deref b, deref c) of
      (Value x, Value y, Value z) -> value_ (f x y z)
      _                           -> function3 fun f a b c



-- | A four-argument function with constant folding. It folds only when all
-- arguments are 'Value' nodes; otherwise a 'function4' node is built.
functionFold4
    :: ( Storable a
       , Storable b
       , Storable c
       , Storable d
       , Storable e
       )
    => String -> (a -> b -> c -> d -> e)
    -> (Data a -> Data b -> Data c -> Data d -> Data e)

functionFold4 fun f a b c d =
    case (deref a, deref b, deref c, deref d) of
      (Value w, Value x, Value y, Value z) -> value_ (f w x y z)
      _                                    -> function4 fun f a b c d



-- | Look up an index in an array.
--
-- The semantic function is used both for constant folding (when the array and
-- the index are 'Value' nodes) and for evaluation. When applied, it raises an
-- error if the index is outside the static size @n@, or if the index is beyond
-- the initialized part of the array.
getIx :: forall n a . (NaturalT n, Storable a) =>
    Data (n :> a) -> Data Int -> Data a

getIx = functionFold2 "(!)" f
  where
    -- Explicit signature (as in 'setIx') ties the scoped @n@ used below to
    -- the array's type.
    f :: (n :> a) -> Int -> a
    f (ArrayList as) i
        | i >= n || i < 0 = error "getIx: index out of bounds"
        | i >= l          = error "getIx: reading garbage"
        | otherwise       = as !! i
      where
        n = fromIntegerT (undefined :: n)
        l = length as



-- | @setIx arr i a@:
--
-- Replaces the value at index @i@ in the array @arr@ with the value @a@.
-- Writing at the first uninitialized position (@i@ equal to the initialized
-- length) appends @a@. Writing further out, or outside the static size, is an
-- error.
setIx :: forall n a . (NaturalT n, Storable a) =>
    Data (n :> a) -> Data Int -> Data a -> Data (n :> a)

setIx = functionFold3 "setIx" update
  where
    update :: (n :> a) -> Int -> a -> (n :> a)
    update (ArrayList xs) ix x
        | ix >= cap || ix < 0 = error "setIx: index out of bounds"
        | ix > used           = error "setIx: writing past initialized area"
        | otherwise           = ArrayList (before ++ x : drop 1 after)
      where
        cap             = fromIntegerT (undefined :: n)
        used            = length xs
        (before, after) = splitAt ix xs



-- | Structures that can be indexed by a dynamic @'Data' 'Int'@ index.
class RandomAccess a
  where
    -- | The type returned by indexing.
    type Elem a

    -- | Index lookup in random access structures
    (!) :: a -> Data Int -> Elem a

-- | Core arrays are indexed with 'getIx'.
instance (NaturalT n, Storable a) => RandomAccess (Data (n :> a))
  where
    type Elem (Data (n :> a)) = Data a
    arr ! i = getIx arr i



-- | Constructs a non-primitive, non-inlined function.
--
-- The normal way to make a non-primitive function is to use an ordinary Haskell
-- function, for example:
--
-- > myFunc x = x * 4 + 5
--
-- However, such functions are inevitably inlined into the program expression
-- when applied. @noInline@ can be thought of as a way to protect a function
-- against inlining (but later transformations may choose to inline anyway).
--
-- Ideally, it should be possible to reuse such a function several times, but
-- at the moment this does not work. Every application of a @noInline@ function
-- results in a new copy of the function in the core program.
--
-- Reading the composition right to left:
--
-- 1. the function is converted to the internal representation ('wrap');
--
-- 2. it is stored behind a reference ('Ref.ref');
--
-- 3. every application builds a new 'NoInline' node (via 'ref');
--
-- 4. the result is converted back to the external representation ('unwrap').
noInline :: (Computable a, Computable b) => String -> (a -> b) -> (a -> b)
noInline fun = unwrap . (ref .) . NoInline fun . Ref.ref . wrap



-- | @ifThenElse cond thenFunc elseFunc@:
--
-- Selects between the two functions @thenFunc@ and @elseFunc@ depending on
-- whether the condition @cond@ is true or false.
--
-- If @cond@ is a constant ('Value'), the choice is made immediately and the
-- selected function is returned unchanged. Otherwise an 'IfThenElse' node is
-- built, with both branches converted to the internal representation.
ifThenElse
    :: (Computable a, Computable b)
    => Data Bool -> (a -> b) -> (a -> b) -> (a -> b)

ifThenElse cond t e = case deref cond of
    Value True         -> t
    Value False        -> e
--     Function "not" _ c -> ifThenElse c e t
-- XXX Not possible...
--     NOTE(review): presumably because the argument of a 'Function' node has
--     an existentially hidden type, so @c@ cannot be used as a @Data Bool@.
    _ -> unwrap $ (ref .) $ IfThenElse cond (wrap t) (wrap e)



-- | @while cont body@:
--
-- A while-loop. The condition @cont@ determines whether the loop should
-- continue one more iteration. @body@ computes the next state. The result is a
-- function from initial state to final state.
--
-- The condition is tested before each iteration (see 'evalE'), so the body may
-- run zero times. Both @cont@ and @body@ are converted to the internal
-- representation before being stored in the 'While' node.
while
    :: Computable a
    => (a -> Data Bool)
    -> (a -> a)
    -> (a -> a)

while cont = unwrap . (ref .) . While (cont . externalize) . wrap



-- | @parallel sz ixf@:
--
-- Parallel tiling. Computes the elements of a vector. @sz@ is the dynamic size,
-- i.e. how many of the allocated elements should be computed. The function
-- @ixf@ maps each index to its value.
--
-- Since there are no dependencies between the elements, the compiler is free to
-- compute the elements in parallel (or any other order).
parallel :: (NaturalT n, Storable a) =>
    Data Int -> (Data Int -> Data a) -> Data (n :> a)
parallel sz ixf = ref (Parallel sz ixf)



-- * Graph conversion

-- | State threaded through graph construction.
data Info = Info
  { -- | Next id
    index   :: NodeId
    -- | Visited references mapped to their id
  , visited :: Map Unique NodeId
  }

-- | Monad for making graph building easier. The writer collects the emitted
-- 'Node's; the state holds the 'Info' (next id and visited references).
type GraphBuilder a = WriterT [Node] (State Info) a

-- | The initial builder state: ids start at 0 and nothing has been visited.
startInfo :: Info
startInfo = Info { index = 0, visited = Map.empty }

-- | Runs a graph builder from the given state. It returns the result together
-- with the emitted nodes and the final state.
runGraph :: GraphBuilder a -> Info -> (a, ([Node], Info))
runGraph graph info =
    let ((result, nodes), info') = runState (runWriterT graph) info
    in  (result, (nodes, info'))

-- | Allocates a node id: returns the current counter and advances it.
newIndex :: GraphBuilder NodeId
newIndex = do
    i <- gets index
    modify $ \info -> info { index = succ i }
    return i

-- | Records that the reference inside @dat@ has been given node id @i@.
remember :: Data a -> NodeId -> GraphBuilder ()
remember dat i = modify record
  where
    record info = info { visited = Map.insert (refId dat) i (visited info) }

-- | The node id previously recorded for the reference inside @dat@, if any.
checkNode :: Data a -> GraphBuilder (Maybe NodeId)
checkNode dat = gets (Map.lookup (refId dat) . visited)

-- | Pairs node id @i@ with each path in the tuple structure of type @a@.
tupleBind :: Typeable a => NodeId -> T a -> Tuple Variable
tupleBind i t = fmap ((,) i) (tuplePath (typeOf t))



-- | Declare a node. It allocates a node id for @dat@, records that id as
-- visited, and emits the node. The output type is derived from the type of
-- @dat@.
node
    :: forall a . Typeable a
    => Data a
    -> Graph.Function
    -> Tuple Source
    -> Tuple StorableType
    -> GraphBuilder ()

node dat fun inTup inType = do
    nid <- newIndex
    remember dat nid
    tell [Node nid fun inTup inType (typeOf (T::T a))]



-- | Declare a source node (one with no inputs). Both the input tuple and the
-- input type are empty. Matching on 'Data' brings the 'Typeable' context
-- required by 'node' into scope.
sourceNode :: Data a -> Graph.Function -> GraphBuilder ()
sourceNode d fun = case d of
    Data _ -> node d fun (Tup []) (Tup [])



-- | Creates a 'Source' referring to a 'Data' value.
--
-- * Tuple projections ('GetTuple') are followed back to the underlying tuple.
--   The projection indices are accumulated in @path@, with the outermost
--   projection last.
--
-- * Primitive constants become 'Constant' sources and need no node.
--
-- * Anything else must already have been visited (i.e. have a node id). It
--   becomes a 'Variable' referring to that node and @path@. If it has not been
--   visited, an explicit error is raised. A failing @Just i <- ...@ pattern
--   would instead rely on 'fail', which the builder monad does not support
--   ('MonadFail').
source :: forall a . [Int] -> Data a -> GraphBuilder Source
source path a = case deref a of

    GetTuple n tup -> source (numberT n : path) tup

    Value v | isPrimitive (T::T a) ->
      let PrimitiveData v' = toData v
       in return $ Constant v'

    _ -> do
      mi <- checkNode a
      case mi of
        Just i  -> return $ Variable (i,path)
        Nothing -> error "source: node has not been visited"



-- | Traces the tuple structure of a 'Data' value. Tuple constructors become
-- nested 'Tup's (components traced left to right), and every non-tuple leaf
-- becomes a single 'Source' (see 'source').
traceTuple :: Data a -> GraphBuilder (Tuple Source)
traceTuple a = case deref a of
    Tuple2 b c     -> liftM2 tup2 (traceTuple b) (traceTuple c)
    Tuple3 b c d   -> liftM3 tup3 (traceTuple b) (traceTuple c) (traceTuple d)
    Tuple4 b c d e ->
      liftM4 tup4 (traceTuple b) (traceTuple c) (traceTuple d) (traceTuple e)
    _              -> liftM One (source [] a)
  where
    tup2 x y     = Tup [x,y]
    tup3 x y z   = Tup [x,y,z]
    tup4 w x y z = Tup [w,x,y,z]



-- | Emits the nodes needed to compute a 'Data' value, unless its reference has
-- already been visited. Skipping visited references is what preserves sharing.
-- The inputs of a function node are built before the node itself.
buildGraph :: forall a . Data a -> GraphBuilder ()
buildGraph dat@(Data _) = do
    idat <- checkNode dat
    unless (isJust idat) $ list (deref dat)
  where
    -- Builds the input, then declares a node for 'dat' that applies 'fun' to
    -- it.
    funcNode fun inp@(Data _) = do
      buildGraph inp
      inTup <- traceTuple inp
      node dat fun inTup (typeOfData inp)

    -- Emits the node(s) for each kind of expression.
    list :: Expr a -> GraphBuilder ()

    list Input = sourceNode dat Graph.Input

    -- Primitive constants get no node; they are used directly as 'Constant'
    -- sources (see 'source'). Other storable values become array sources.
    list (Value a)
      | isPrimitive (T::T a) = return ()
      | otherwise            = sourceNode dat $ Graph.Array $ toData a

    -- Tuples and projections get no node of their own. Their components are
    -- referenced through tuple paths (see 'traceTuple' and 'source').
    list (Tuple2 a b)     = buildGraph a >> buildGraph b
    list (Tuple3 a b c)   = buildGraph a >> buildGraph b >> buildGraph c
    list (Tuple4 a b c d) =
        buildGraph a >> buildGraph b >> buildGraph c >> buildGraph d

    list (GetTuple _ a) = buildGraph a

    list (Function fun _ a) = funcNode (Graph.Function fun) a

    list (NoInline fun f a) = do
      iface <- buildSubFun (Ref.deref f)
      funcNode (Graph.NoInline fun iface) a
      -- XXX Sub-graph is not shared at the moment.

    -- The condition and the argument are passed together as one tuple input.
    list (IfThenElse cond t e a) = do
      ifaceThen <- buildSubFun t
      ifaceElse <- buildSubFun e
      funcNode (Graph.IfThenElse ifaceThen ifaceElse) (ref $ Tuple2 cond a)

    list (While cont body a) = do
      ifaceCont <- buildSubFun cont
      ifaceBody <- buildSubFun body
      funcNode (Graph.While ifaceCont ifaceBody) a

    -- The static (allocated) length @n@ is read from the result type; the
    -- dynamic size is the node's input.
    list (Parallel sz ixf) = do
        iface <- buildSubFun ixf
        funcNode (Graph.Parallel n iface) sz
      where
        One (StorableType (n:_) _) = typeOfData dat



-- | Builds the graph of a sub-function and returns its interface. The function
-- is applied to a fresh 'Input' placeholder, and the node id assigned to that
-- placeholder becomes the interface's input.
buildSubFun :: forall a b . (Typeable a, Typeable b) =>
    (Data a -> Data b) -> GraphBuilder Interface

buildSubFun f = do
    let arg = ref Input :: Data a
        res = f arg
    buildGraph arg  -- Needed in case input is not used
    buildGraph res
    resTup <- traceTuple res
    argId  <- gets ((Map.! refId arg) . visited)
    return $ Interface argId resTup (typeOf (T::T a)) (typeOf (T::T b))



-- | Converts a function on 'Data' to a 'Graph', starting from an empty
-- builder state.
toGraphD :: (Typeable a, Typeable b) => (Data a -> Data b) -> Graph
toGraphD f =
    let (iface, (nodes, _)) = runGraph (buildSubFun f) startInfo
    in  Graph nodes iface



-- | Types that represent core language programs
class Program a
  where
    -- | Converts a program to a Graph
    toGraph :: a -> Graph

    -- | Returns whether or not the program has an argument. This is needed
    -- because the 'Graph' type always assumes the existence of an input. So
    -- for programs without input, the 'Graph' representation will have a
    -- \"dummy\" input, which is indistinguishable from a real input.
    -- The argument is only a type proxy.
    hasArg :: T a -> Bool

-- | A computable value is a program without an argument. Its graph gets a
-- dummy @()@ input that is ignored.
instance Computable a => Program a
  where
    toGraph prog = toGraphD body
      where
        body :: Data () -> Data (Internal a)
        body = const (internalize prog)
    hasArg _ = False

-- | A one-argument function is converted through its internal form ('wrap').
instance (Computable a, Computable b) => Program (a -> b)
  where
    toGraph f = toGraphD (wrap f)
    hasArg _  = True

-- | A two-argument function is converted as a function of a pair.
instance (Computable a, Computable b, Computable c)
      => Program (a -> b -> c)
  where
    toGraph f = toGraph (uncurry f)
    hasArg _  = True

-- | A three-argument function is converted as a function of a triple.
instance (Computable a, Computable b, Computable c, Computable d)
      => Program (a -> b -> c -> d)
  where
    toGraph f = toGraph (\(x,y,z) -> f x y z)
    hasArg _  = True

-- | A four-argument function is converted as a function of a quadruple.
instance
    ( Computable a
    , Computable b
    , Computable c
    , Computable d
    , Computable e
    ) =>
      Program (a -> b -> c -> d -> e)
  where
    toGraph f = toGraph (\(w,x,y,z) -> f w x y z)
    hasArg _  = True



-- | Shows the core code generated by a program.
showCore :: forall a . Program a => a -> String
showCore prog = showGraph "program" (hasArg (T::T a)) (toGraph prog)

-- | @printCore = putStrLn . showCore@
printCore :: Program a => a -> IO ()
printCore prog = putStrLn (showCore prog)