--
-- Copyright (c) 2009-2010, ERICSSON AB All rights reserved.
-- 
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
-- 
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
-- 
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-- ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
-- BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
-- OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-- SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-- INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-- CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-- ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-- THE POSSIBILITY OF SUCH DAMAGE.
--

{-# LANGUAGE UndecidableInstances #-}

-- | This module gives a representation of core programs as typed expressions
-- (see 'Expr' / 'Data').

module Feldspar.Core.Expr where



import Data.Function
import Data.Monoid
import Data.Unique

import Feldspar.Range
import Feldspar.Core.Types
import Feldspar.Core.Ref



-- | Typed core language expressions. A value of type @`Expr` a@ is a
-- representation of a program that computes a value of type @a@.
--
-- The meaning of each constructor is given by 'evalE'.
data Expr a
  where
    -- | An already-evaluated Haskell value. 'evalF' wraps the concrete
    -- argument of a lambda in this constructor before evaluating the body.
    Val         :: a -> Expr a  -- XXX Temporary, only used by evalF
    -- | A bound variable (created by 'freshVar'). It carries no name, so it
    -- can only be told apart from other variables by the reference of the
    -- enclosing 'Data'. Evaluating it directly is an error.
    Variable    :: Expr a  -- XXX Risky to rely on obs. sharing for bound variables.
    -- | A constant of a storable type (built e.g. by 'value' and 'array').
    Value       :: Storable a => a -> Expr a
    -- | A named primitive function together with its evaluation semantics.
    Function    :: String -> (a -> b) -> Expr (a -> b)
    -- | Application of a (possibly partially applied) function expression to
    -- one more argument.
    Application :: Expr (a -> b) -> Data a -> Expr b
    -- | A named function, kept as a separate lambda rather than inlined,
    -- applied to an argument (see 'noInline').
    NoInline    :: String -> Ref (a :-> b) -> (Data a -> Expr b)

    -- | Apply one of two functions to the argument, depending on the
    -- condition (see 'ifThenElse').
    IfThenElse
      :: Data Bool           -- Condition
      -> (a :-> b)           -- If branch
      -> (a :-> b)           -- Else branch
      -> (Data a -> Expr b)

    -- | Starting from the initial state, repeatedly apply the body while the
    -- continue-predicate holds. The result is the first state for which the
    -- predicate is false (see 'while').
    While
      :: (a :-> Bool)  -- Continue?
      -> (a :-> a)     -- Body
      -> Data a        -- Initial state
      -> Expr a        -- Final state

    -- | A list of the given length. Element @i@ is obtained by applying the
    -- index mapping to @i@, for @i@ in @[0 .. length-1]@ (see 'parallel').
    Parallel
      :: Storable a
      => Data Length
      -> (Int :-> a)  -- Index mapping
      -> Expr [a]     -- Result vector



-- | A wrapper around 'Expr' to allow observable sharing (see
-- "Feldspar.Core.Ref") and for memoizing size information.
--
-- The constructor stores a 'Typeable' dictionary, so pattern matching on
-- 'Data' brings @Typeable a@ into scope (as 'dataType' and 'cap' do).
data Data a = Typeable a => Data
  { dataSize :: Size a
      -- ^ Size information memoized for this expression.
  , dataRef  :: Ref (Expr a)
      -- ^ Reference to the underlying expression. The 'Eq' and 'Ord'
      --   instances compare 'Data' values by this field, and 'dataId'
      --   returns its 'refId'.
  }

-- | Two 'Data' values are equal exactly when their references are equal.
-- The expressions themselves are never compared structurally.
instance Eq (Data a)
  where
    x == y = dataRef x == dataRef y

-- | 'Data' values are ordered by their references. This is consistent with
-- the 'Eq' instance, so they can be used as keys in ordered containers.
instance Ord (Data a)
  where
    compare x y = compare (dataRef x) (dataRef y)

-- | A core-language function from @a@ to @b@.
--
-- @Lambda f var body@ holds three things: the Haskell function @f@ that
-- builds a body for a given argument, the bound variable @var@, and the body
-- obtained as @f var@ (see 'lambda').
data a :-> b = Typeable a =>  -- Typeable needed by evalF
    Lambda (Data a -> Data b) (Data a) (Data b)



-- | The storable type(s) of the value computed by a 'Data' expression. It is
-- obtained from 'typeOf', applied to the memoized size and a proxy for @a@.
dataType :: forall a . Data a -> Tuple StorableType
-- Matching on the constructor brings the stored @Typeable a@ into scope.
dataType d@Data{} = typeOf (dataSize d) (T :: T a)

-- | The unique identifier of the reference behind a 'Data' value.
dataId :: Data a -> Unique
dataId d = refId (dataRef d)

-- | The expression that a 'Data' value refers to.
dataToExpr :: Data a -> Expr a
dataToExpr d = deref (dataRef d)

-- | Wrap an expression and its size into a 'Data' value. The reference is
-- obtained with 'ref'.
--
-- NOTE(review): the NOINLINE pragma presumably prevents GHC from inlining or
-- sharing calls to 'ref' (which observable sharing depends on, see
-- "Feldspar.Core.Ref"). Do not remove it without confirming.
{-# NOINLINE exprToData #-}
exprToData :: Typeable a => Size a -> Expr a -> Data a
exprToData sz a = Data sz (ref a)

-- | A new bound variable of the given size: a 'Variable' node wrapped by
-- 'exprToData'.
--
-- NOTE(review): 'Variable' carries no name, so distinct variables can only be
-- told apart by their references. NOINLINE presumably keeps GHC from merging
-- separate calls into one node.
{-# NOINLINE freshVar #-}
freshVar :: Typeable a => Size a -> Data a
freshVar sz = exprToData sz Variable

-- | Turn a Haskell function on 'Data' into a core-language function.
--
-- A fresh variable of size @sz@ is created and @f@ is applied to it. The
-- result becomes the lambda's body, and @f@ itself is kept so that 'apply'
-- can use it.
{-# NOINLINE lambda #-}
lambda :: Typeable a => Size a -> (Data a -> Data b) -> (a :-> b)
lambda sz f = Lambda f var (f var)
  where
    var = freshVar sz
  -- XXX It's assumed that `f` is only going to be applied to an argument whose
  --     size is `sz`.

-- | Apply a core function to an argument by running the Haskell function it
-- was built from. The stored variable and body are not used.
apply :: (a :-> b) -> Data a -> Data b
apply lam = case lam of
    Lambda f _ _ -> f

-- | The size of a lambda's body, i.e. the size of its result.
resultSize :: (a :-> b) -> Size b
resultSize lam = case lam of
    Lambda _ _ body -> dataSize body



-- | Operator form of 'Application'. Chains such as @f |$| a |$| b@ supply
-- arguments one at a time, from left to right.
(|$|) :: Expr (a -> b) -> Data a -> Expr b
(|$|) = Application

-- | Build a call node for a one-argument primitive. Unlike 'function', no
-- constant folding is attempted.
--
-- @`_function` fun sizeProp f a@ creates an 'Application' of the named
-- 'Function' @fun@ (whose evaluation semantics is @f@) to @a@. The result's
-- size is @sizeProp@ applied to the size of @a@.
--
-- Ordinary primitives (built with 'function' etc.) only go through this
-- one-argument form. The multi-argument variants '_function2' etc. are only
-- used to construct tuples.
_function
    :: Typeable b
    => String -> (Size a -> Size b) -> (a -> b) -> (Data a -> Data b)
_function fun sizeProp f a =
    exprToData (sizeProp (dataSize a)) (Application (Function fun f) a)

-- | Two-argument variant of '_function'. It applies the named 'Function' to
-- @a@ and then to @b@. The result's size is @sizeProp@ of both argument
-- sizes. Used by 'tup2'.
_function2
    :: Typeable c
    => String
    -> (Size a -> Size b -> Size c)
    -> (a -> b -> c)
    -> (Data a -> Data b -> Data c)
_function2 fun sizeProp f a b = exprToData outSize call
  where
    call    = Application (Application (Function fun f) a) b
    outSize = sizeProp (dataSize a) (dataSize b)

-- | Three-argument variant of '_function'. It applies the named 'Function'
-- to @a@, @b@ and @c@ in turn. Used by 'tup3'.
_function3
    :: Typeable d
    => String -> (Size a -> Size b -> Size c -> Size d)
    -> (a -> b -> c -> d)
    -> (Data a -> Data b -> Data c -> Data d)
_function3 fun sizeProp f a b c = exprToData outSize call
  where
    call    = Application (Application (Application (Function fun f) a) b) c
    outSize = sizeProp (dataSize a) (dataSize b) (dataSize c)

-- | Four-argument variant of '_function'. It applies the named 'Function'
-- to @a@, @b@, @c@ and @d@ in turn. Used by 'tup4'.
_function4
    :: Typeable e
    => String
    -> (Size a -> Size b -> Size c -> Size d -> Size e)
    -> (a -> b -> c -> d -> e)
    -> (Data a -> Data b -> Data c -> Data d -> Data e)
_function4 fun sizeProp f a b c d = exprToData outSize call
  where
    call    = Function fun f |$| a |$| b |$| c |$| d
    outSize = sizeProp (dataSize a) (dataSize b) (dataSize c) (dataSize d)



-- | Pair two 'Data' values. The size is the pair of the component sizes.
tup2 :: (Typeable a, Typeable b) => Data a -> Data b -> Data (a,b)
tup2 a b = _function2 "tup2" (,) (,) a b

-- | Combine three 'Data' values into a triple. The size is the triple of the
-- component sizes.
tup3 :: (Typeable a, Typeable b, Typeable c) =>
    Data a -> Data b -> Data c -> Data (a,b,c)
tup3 a b c = _function3 "tup3" (,,) (,,) a b c

-- | Combine four 'Data' values into a 4-tuple. The size is the 4-tuple of
-- the component sizes.
tup4 :: (Typeable a, Typeable b, Typeable c, Typeable d) =>
    Data a -> Data b -> Data c -> Data d -> Data (a,b,c,d)
tup4 a b c d = _function4 "tup4" (,,,) (,,,) a b c d

-- | First component of a 'Data' pair. The same projection is applied to the
-- size.
get21 :: Typeable a => Data (a,b) -> Data a
get21 = _function "getTup21" fst fst

-- | Second component of a 'Data' pair. The same projection is applied to
-- the size.
get22 :: Typeable b => Data (a,b) -> Data b
get22 = _function "getTup22" snd snd

-- | First component of a 'Data' triple (also projected on the size).
get31 :: Typeable a => Data (a,b,c) -> Data a
get31 = _function "getTup31" sel sel
  where
    sel (x,_,_) = x

-- | Second component of a 'Data' triple (also projected on the size).
get32 :: Typeable b => Data (a,b,c) -> Data b
get32 = _function "getTup32" sel sel
  where
    sel (_,y,_) = y

-- | Third component of a 'Data' triple (also projected on the size).
get33 :: Typeable c => Data (a,b,c) -> Data c
get33 = _function "getTup33" sel sel
  where
    sel (_,_,z) = z

-- | First component of a 'Data' 4-tuple (also projected on the size).
get41 :: Typeable a => Data (a,b,c,d) -> Data a
get41 = _function "getTup41" sel sel
  where
    sel (w,_,_,_) = w

-- | Second component of a 'Data' 4-tuple (also projected on the size).
get42 :: Typeable b => Data (a,b,c,d) -> Data b
get42 = _function "getTup42" sel sel
  where
    sel (_,x,_,_) = x

-- | Third component of a 'Data' 4-tuple (also projected on the size).
get43 :: Typeable c => Data (a,b,c,d) -> Data c
get43 = _function "getTup43" sel sel
  where
    sel (_,_,y,_) = y

-- | Fourth component of a 'Data' 4-tuple (also projected on the size).
get44 :: Typeable d => Data (a,b,c,d) -> Data d
get44 = _function "getTup44" sel sel
  where
    sel (_,_,_,z) = z



-- | Computable types. A computable value completely represents a core program,
-- in such a way that @`internalize` `.` `externalize`@ preserves semantics, but
-- not necessarily syntax.
--
-- The terminology used in this class comes from thinking of the 'Data' type as
-- the \"internal\" core language and the "Feldspar.Core" API as the
-- \"external\" core language.
--
-- This module provides instances for 'Data' itself (where both conversions
-- are the identity) and for tuples of two to four computable components.
-- Tuples are converted component by component, using the @tupN@ and @getNM@
-- helpers.
class Typeable (Internal a) => Computable a
  where
    -- | @`Data` (`Internal` a)@ is the internal representation of the type @a@.
    type Internal a

    -- | Convert to internal representation
    internalize :: a -> Data (Internal a)

    -- | Convert to external representation
    externalize :: Data (Internal a) -> a

-- | A 'Data' value is already in internal form, so both conversions leave it
-- untouched.
instance Storable a => Computable (Data a)
  where
    type Internal (Data a) = a

    internalize d = d
    externalize d = d

-- | A pair is represented internally as a 'Data' pair of the components'
-- internal representations.
instance (Computable a, Computable b) => Computable (a,b)
  where
    type Internal (a,b) = (Internal a, Internal b)

    internalize (x, y) = tup2 ix iy
      where
        ix = internalize x
        iy = internalize y

    externalize p = (externalize (get21 p), externalize (get22 p))

-- | A triple is represented internally as a 'Data' triple of the components'
-- internal representations.
instance (Computable a, Computable b, Computable c) => Computable (a,b,c)
  where
    type Internal (a,b,c) = (Internal a, Internal b, Internal c)

    internalize (x, y, z) =
        tup3 (internalize x) (internalize y) (internalize z)

    externalize t =
        (externalize (get31 t), externalize (get32 t), externalize (get33 t))

-- | A 4-tuple is represented internally as a 'Data' 4-tuple of the
-- components' internal representations.
instance (Computable a, Computable b, Computable c, Computable d)
      => Computable (a,b,c,d)
  where
    type Internal (a,b,c,d) = (Internal a, Internal b, Internal c, Internal d)

    internalize (w, x, y, z) =
        tup4 (internalize w) (internalize x) (internalize y) (internalize z)

    externalize q =
        ( externalize (get41 q)
        , externalize (get42 q)
        , externalize (get43 q)
        , externalize (get44 q)
        )



-- | Lower a function to operate on internal representation. The argument is
-- externalized, @f@ is applied, and the result is internalized again.
lowerFun :: (Computable a, Computable b) =>
    (a -> b) -> (Data (Internal a) -> Data (Internal b))
lowerFun f x = internalize (f (externalize x))

-- | Lift a function to operate on external representation. The argument is
-- internalized, @f@ is applied, and the result is externalized again.
liftFun :: (Computable a, Computable b) =>
    (Data (Internal a) -> Data (Internal b)) -> (a -> b)
liftFun f x = externalize (f (internalize x))



-- | The semantics of expressions: the Haskell value an 'Expr' denotes.
-- Evaluating a free 'Variable' is an error.
evalE :: Expr a -> a

evalE (Val x)           = x
evalE Variable          = error "evaluating free variable"
evalE (Value x)         = x
evalE (Function _ f)    = f
evalE (Application g x) = evalE g (evalD x)
evalE (NoInline _ fr x) = evalD (apply (deref fr) x)

evalE (IfThenElse c t e x) = evalD (apply (if evalD c then t else e) x)

-- The result is the first state, starting from the initial one, for which
-- the continue-predicate is false.
evalE (While cont body start) =
    until (not . evalF cont) (evalF body) (evalD start)

evalE (Parallel len ixf) = [ evalF ixf i | i <- [0 .. evalD len - 1] ]



-- | The semantics of 'Data': evaluate the expression it refers to.
evalD :: Data a -> a
evalD d = evalE (dataToExpr d)

-- | The semantics of a core function. The concrete argument is wrapped in a
-- 'Val' node, sized like the lambda's variable, and passed to the Haskell
-- function the lambda was built from.
evalF :: (a :-> b) -> (a -> b)
evalF (Lambda f var _) = \x -> evalD (f (exprToData (dataSize var) (Val x)))

-- | The semantics of any 'Computable' type: internalize the value, then
-- evaluate it.
eval :: Computable a => a -> Internal a
eval x = evalD (internalize x)



-- | A program that computes a constant value. Its size is the value's own
-- 'storableSize'.
value :: Storable a => a -> Data a
value x = exprToData sz (Value x)
  where
    sz = storableSize x

-- | Like 'value', but with an extra 'Size' argument. The given size is
-- combined ('mappend') with the size of the data, so it can make the result
-- larger than the data alone.
--
-- Example 1:
--
-- > array (10 :> 20 :> universal) [] :: Data [[Int]]
--
-- gives an uninitialized 10x20 array of 'Int' elements.
--
-- Example 2:
--
-- > array (10 :> 20 :> universal) [[1,2,3]] :: Data [[Int]]
--
-- gives a 10x20 array whose first row is initialized to @[1,2,3]@.
array :: Storable a => Size a -> a -> Data a
array sz a = exprToData totalSize (Value a)
  where
    totalSize = sz `mappend` storableSize a

-- | Like 'array'. The outermost length is taken from the /size/ of @len@,
-- not from its value, and all inner levels are 'universal'.
--
-- XXX This function is a temporary solution.
arrayLen :: Storable a => Data Length -> [a] -> Data [a]
arrayLen len xs = array (outer :> universal) xs
  where
    outer = mapMonotonic fromInteger (dataSize len)

-- | The constant @()@.
unit :: Data ()
unit = value ()

-- | The constant 'True'.
true :: Data Bool
true = value True

-- | The constant 'False'.
false :: Data Bool
false = value False

-- | The size of each level of a multi-dimensional array, outermost level
-- first. This is read from the memoized size; nothing is evaluated.
size :: forall a . Storable a => Data [a] -> [Range Length]
size arr = listSize (T :: T [a]) (dataSize arr)

-- | Narrow the memoized size of a 'Data' value by combining it with the
-- given range using '/\'. The expression reference itself is kept as is.
cap :: (Storable a, Size a ~ Range b, Ord b) => Range b -> Data a -> Data a
cap bound d = case d of
    Data sz r -> Data (sz /\ bound) r
  -- XXX Should really have the type
  --       cap :: Storable a => Size a -> Data a -> Data a



-- | Constructs a one-argument primitive function.
--
-- @`function` fun szf f@:
--
--   * @fun@ is the name of the function.
--
--   * @szf@ computes the output size from the input size.
--
--   * @f@   gives the evaluation semantics.
--
-- If the argument is a constant ('Value'), the result is folded into a
-- constant right away. Otherwise a call node is built with '_function'.
function
    :: (Storable a, Storable b)
    => String -> (Size a -> Size b) -> (a -> b) -> (Data a -> Data b)

function fun sizeProp f a =
    case dataToExpr a of
      Value x -> folded x
      _       -> _function fun sizeProp f a
  where
    folded x = exprToData (sizeProp (dataSize a)) (Value (f x))



-- | A two-argument primitive function.
--
-- If both arguments are constants, the result is folded into a constant.
-- Otherwise the arguments are paired with 'tup2', and the uncurried function
-- is applied through the one-argument '_function'.
function2
    :: ( Storable a
       , Storable b
       , Storable c
       )
    => String
    -> (Size a -> Size b -> Size c)
    -> (a -> b -> c)
    -> (Data a -> Data b -> Data c)

function2 fun sizeProp f a b =
    case (dataToExpr a, dataToExpr b) of
      (Value x, Value y) -> exprToData outSize (Value (f x y))
      _ -> _function fun (uncurry sizeProp) (uncurry f) (tup2 a b)
      -- XXX Should perhaps use @_function2 fun sizeProp f a b@ instead.
  where
    outSize = sizeProp (dataSize a) (dataSize b)



-- | A three-argument primitive function.
--
-- If all arguments are constants, the result is folded into a constant.
-- Otherwise the arguments are packed with 'tup3', and the uncurried function
-- is applied through '_function'.
function3
    :: ( Storable a
       , Storable b
       , Storable c
       , Storable d
       )
    => String
    -> (Size a -> Size b -> Size c -> Size d)
    -> (a -> b -> c -> d)
    -> (Data a -> Data b -> Data c -> Data d)

function3 fun sizeProp f a b c =
    case (dataToExpr a, dataToExpr b, dataToExpr c) of
      (Value x, Value y, Value z) -> exprToData outSize (Value (f x y z))
      _ -> _function fun (uncurry3 sizeProp) (uncurry3 f) (tup3 a b c)
  where
    outSize = sizeProp (dataSize a) (dataSize b) (dataSize c)
    -- Used both on sizes and on values.
    uncurry3 g (p, q, r) = g p q r



-- | A four-argument primitive function.
--
-- If all arguments are constants, the result is folded into a constant.
-- Otherwise the arguments are packed with 'tup4', and the uncurried function
-- is applied through '_function'.
function4
    :: ( Storable a
       , Storable b
       , Storable c
       , Storable d
       , Storable e
       )
    => String
    -> (Size a -> Size b -> Size c -> Size d -> Size e)
    -> (a -> b -> c -> d -> e)
    -> (Data a -> Data b -> Data c -> Data d -> Data e)

function4 fun sizeProp f a b c d =
    case (dataToExpr a, dataToExpr b, dataToExpr c, dataToExpr d) of
      (Value w, Value x, Value y, Value z) ->
          exprToData outSize (Value (f w x y z))
      _ -> _function fun (uncurry4 sizeProp) (uncurry4 f) (tup4 a b c d)
  where
    outSize = sizeProp (dataSize a) (dataSize b) (dataSize c) (dataSize d)
    -- Used both on sizes and on values.
    uncurry4 g (p, q, r, s) = g p q r s


-- | Look up an index in an array (see also '!').
--
-- The result's size is the element size of the array. During evaluation it
-- is an error if the index is outside the range derived from the array's
-- length size, or if it is not below the actual number of elements.
getIx :: Storable a => Data [a] -> Data Int -> Data a
getIx arr = function2 "(!)" elemSize lookupIx arr
  where
    elemSize (_ :> inner) _ = inner

    (len :> _) = dataSize arr
    valid      = rangeByRange 0 (len - 1)

    lookupIx xs ix
        | not (ix `inRange` valid) = error "getIx: index out of bounds"
        | ix >= length xs          = error "getIx: reading garbage"
        | otherwise                = xs !! ix



-- | @`setIx` arr i a@:
--
-- Replaces the value at index @i@ in the array @arr@ with the value @a@.
--
-- The outer length size is kept. The element size becomes the combination
-- ('mappend') of the old element size and the size of @a@. Writing at the
-- index just past the initialized elements extends the array. Writing any
-- further is an error, as is an index outside the range derived from the
-- length size.
setIx :: Storable a => Data [a] -> Data Int -> Data a -> Data [a]
setIx arr = function3 "setIx" newSize update arr
  where
    newSize (len :> inner) _ inner' = len :> (inner `mappend` inner')

    (bound :> _) = dataSize arr
    valid        = rangeByRange 0 (bound - 1)

    update xs ix x
        | not (ix `inRange` valid) = error "setIx: index out of bounds"
        | ix > length xs           = error "setIx: writing past initialized area"
        | otherwise                = take ix xs ++ x : drop (ix + 1) xs



-- Same fixity as the Prelude list-indexing operator '!!'.
infixl 9 !

-- | Structures that can be indexed with a 'Data' 'Int'.
class RandomAccess a
  where
    -- | The type of elements in a random access structure
    type Element a

    -- | Index lookup in a random access structure
    (!) :: a -> Data Int -> Element a

-- | 'Data' arrays are indexed with 'getIx'.
instance Storable a => RandomAccess (Data [a])
  where
    type Element (Data [a]) = Data a
    arr ! i = getIx arr i



-- | Constructs a non-primitive, non-inlined function.
--
-- The normal way to make a non-primitive function is to use an ordinary Haskell
-- function, for example:
--
-- > myFunc x = x * 4 + 5
--
-- However, such functions are inevitably inlined into the program expression
-- when applied. @noInline@ can be thought of as a way to protect a function
-- against inlining (but later transformations may choose to inline anyway).
--
-- Ideally, it should be possible to reuse such a function several times, but
-- at the moment this does not work. Every application of a @noInline@ function
-- results in a new copy of the function in the core program.
--
-- @fun@ is the name stored in the 'NoInline' node. The function is lowered to
-- the internal representation and wrapped in a lambda whose variable has the
-- size of the internalized argument. The call's size is that lambda's
-- result size.
noInline :: (Computable a, Computable b) => String -> (a -> b) -> (a -> b)
noInline fun f a = liftFun (exprToData sz . NoInline fun (ref fLam)) a
  where
    -- fLam depends on the size of @a@, so a new lambda (and a new 'ref') is
    -- created for every application.
    fLam = lambda (dataSize $ internalize a) (lowerFun f)
    sz   = resultSize fLam



-- | @`ifThenElse` cond thenFunc elseFunc@:
--
-- Selects between the two functions @thenFunc@ and @elseFunc@ depending on
-- whether the condition @cond@ is true or false.
--
-- If @cond@ is a constant, the chosen function is applied directly.
-- Otherwise an 'IfThenElse' node is built. Both branches are turned into
-- lambdas over the argument's size, and the result size combines
-- ('mappend') the result sizes of both branches.
ifThenElse
    :: (Computable a, Computable b)
    => Data Bool -> (a -> b) -> (a -> b) -> (a -> b)

ifThenElse cond t e a =
    case dataToExpr cond of
      Value b -> (if b then t else e) a
      _       -> liftFun mkNode a
  where
    argSize = dataSize $ internalize a
    thenLam = lambda argSize (lowerFun t)
    elseLam = lambda argSize (lowerFun e)
    resSize = resultSize thenLam `mappend` resultSize elseLam
    mkNode  = exprToData resSize . IfThenElse cond thenLam elseLam



-- | Like 'while', but with explicit sizes for the variables of the
-- continue-predicate (@szInitCont@) and of the body (@szInitBody@).
-- The size of the final state is always 'universal'.
whileSized
    :: Computable state
    => Size (Internal state)
    -> Size (Internal state)
    -> (state -> Data Bool)
    -> (state -> state)
    -> (state -> state)

whileSized szInitCont szInitBody cont body = liftFun loopNode
  where
    contLam  = lambda szInitCont (lowerFun cont)
    bodyLam  = lambda szInitBody (lowerFun body)
    -- XXX 'universal' is the best we can do at the moment for the final size.
    loopNode = exprToData universal . While contLam bodyLam



-- | While-loop
--
-- @while cont body :: state -> state@:
--
--   * @state@ is the type of the state.
--
--   * @cont@ determines whether or not to continue based on the current state.
--
--   * @body@ computes the next state from the current state.
--
--   * The result is a function from initial state to final state.
--
-- This is 'whileSized' with 'universal' sizes for both lambda variables.
while
    :: Computable state
    => (state -> Data Bool)
    -> (state -> state)
    -> (state -> state)

while cont body = whileSized universal universal cont body



-- | Parallel array
--
-- @parallel l ixf@:
--
--   * @l@ is the length of the resulting array (outermost level).
--
--   * @ixf@ is a function that maps each index in the range @[0 .. l-1]@ to its
--     element.
--
-- Since there are no dependencies between the elements, the compiler is free to
-- compute the elements in any order, or even in parallel.
--
-- The outermost size of the result comes from the size of @l@. The element
-- size is the result size of @ixf@, applied to an index variable whose size
-- is derived from the size of @l@.
parallel :: Storable a => Data Length -> (Data Int -> Data a) -> Data [a]
parallel l ixf = exprToData szPar $ Parallel l ixfLam
  where
    szl    = dataSize l
    -- Index variable size: presumably 0 up to one less than the largest
    -- possible length (see 'rangeByRange').
    ixfLam = lambda (rangeByRange 0 (szl-1)) ixf
    szPar  = mapMonotonic fromIntegral szl :> resultSize ixfLam