-- Copyright (c) 2009-2010, ERICSSON AB
-- All rights reserved.
--
-- Redistribution and use in source and binary forms, with or without
-- modification, are permitted provided that the following conditions are met:
--
--     * Redistributions of source code must retain the above copyright notice,
--       this list of conditions and the following disclaimer.
--     * Redistributions in binary form must reproduce the above copyright
--       notice, this list of conditions and the following disclaimer in the
--       documentation and/or other materials provided with the distribution.
--     * Neither the name of the ERICSSON AB nor the names of its contributors
--       may be used to endorse or promote products derived from this software
--       without specific prior written permission.
--
-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-- AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-- IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

{-# LANGUAGE UndecidableInstances #-}

-- | This module gives a representation of core programs as typed expressions
-- (see 'Expr' / 'Data').

module Feldspar.Core.Expr where



import Data.Monoid
import Data.Unique

import Feldspar.Range
import Feldspar.Core.Types
import Feldspar.Core.Ref



-- | Typed core language expressions. A value of type @`Expr` a@ is a
-- representation of a program that computes a value of type @a@.
--
-- Sub-programs are wrapped in 'Data' (a size plus a reference to an
-- expression), and the functions used by control constructs are represented
-- as sub-functions of type @a ':->' b@.
data Expr a
  where
    -- Placeholder for the argument of a sub-function (created by 'mkSubFun').
    -- It carries a size but no value; 'evalE' raises an error on it.
    Input :: Size a -> Expr a
      -- XXX Risky to rely on observable sharing?

    -- A constant together with its size (created by 'value' and 'array').
    Value :: Storable a => Size a -> a -> Expr a

    -- Tuple construction from component programs.
    Tuple2 :: Data a -> Data b -> Expr (a,b)
    Tuple3 :: Data a -> Data b -> Data c -> Expr (a,b,c)
    Tuple4 :: Data a -> Data b -> Data c -> Data d -> Expr (a,b,c,d)
      -- XXX Tuple construction should be generalized.

    -- Tuple projections: GetNi selects component i of an N-tuple.
    Get21 :: Data (a,b) -> Expr a
    Get22 :: Data (a,b) -> Expr b

    Get31 :: Data (a,b,c) -> Expr a
    Get32 :: Data (a,b,c) -> Expr b
    Get33 :: Data (a,b,c) -> Expr c

    Get41 :: Data (a,b,c,d) -> Expr a
    Get42 :: Data (a,b,c,d) -> Expr b
    Get43 :: Data (a,b,c,d) -> Expr c
    Get44 :: Data (a,b,c,d) -> Expr d
      -- XXX Tuple projection should be generalized.

    -- Application of a named primitive function: name, precomputed output
    -- size, evaluation semantics, and the argument (see 'function').
    Function :: String -> Size b -> (a -> b) -> (Data a -> Expr b)

    -- Application of a named sub-function that is kept out of line (see
    -- 'noInline'). The sub-function is held behind a 'Ref'.
    NoInline :: String -> Ref (a :-> b) -> (Data a -> Expr b)

    -- Applies one of two sub-functions to the argument depending on the
    -- condition (see 'ifThenElse').
    IfThenElse
      :: Data Bool           -- Condition
      -> (a :-> b)           -- If branch
      -> (a :-> b)           -- Else branch
      -> (Data a -> Expr b)

    -- Repeatedly applies the body to the state while the predicate holds
    -- (see 'while' / 'evalE').
    While
      :: (a :-> Bool)  -- Continue?
      -> (a :-> a)     -- Body
      -> Data a        -- Initial state
      -> Expr a        -- Final state

    -- Array of the given length whose elements are computed from their index
    -- (see 'parallel').
    Parallel
      :: Storable a
      => Data Length
      -> (Int :-> a)  -- Index mapping
      -> Expr [a]     -- Result vector



-- | A sub-function from @a@ to @b@, as used by the control constructs of
-- 'Expr'. @SubFunction f inp outp@ holds the Haskell function @f@, an input
-- program @inp@ and the output program @outp@. As built by 'mkSubFun', @inp@
-- is an 'Input' placeholder and @outp@ is @f inp@.
data a :-> b = SubFunction (Data a -> Data b) (Data a) (Data b)



-- | A wrapper around 'Expr' to allow observable sharing (see
-- "Feldspar.Core.Ref") and for memoizing size information.
--
-- The first field is the memoized size of the program and the second a
-- reference to its expression. The 'Eq' and 'Ord' instances compare only the
-- reference, not the size.
data Data a = Typeable a => Data (Size a) (Ref (Expr a))

-- Programs are equal exactly when they share the same expression reference;
-- the memoized sizes play no part in the comparison.
instance Eq (Data a)
  where
    (==) (Data _ r1) (Data _ r2) = r1 == r2

-- Programs are ordered by their expression reference; the memoized sizes play
-- no part in the comparison.
instance Ord (Data a)
  where
    compare (Data _ r1) (Data _ r2) = compare r1 r2



-- | The memoized size of a program.
dataSize :: Data a -> Size a
dataSize d = case d of
    Data s _ -> s

-- | The type representation of a program's value, obtained from its size via
-- 'typeOf'.
dataType :: forall a . Data a -> Tuple StorableType
dataType d = case d of
    Data sz _ -> typeOf sz (T :: T a)

-- | The identifier of the expression reference held by a program.
dataId :: Data a -> Unique
dataId d = case d of
    Data _ r -> refId r

-- | The expression that a program refers to.
dataToExpr :: Data a -> Expr a
dataToExpr d = case d of
    Data _ r -> deref r

-- | The size of a sub-function's output program.
subFunSize :: (a :-> b) -> Size b
subFunSize (SubFunction _ _ result) = dataSize result

-- | The Haskell function underlying a sub-function, i.e. how to apply it to an
-- argument program.
subAp :: (a :-> b) -> (Data a -> Data b)
subAp sf = case sf of
    SubFunction fun _ _ -> fun

-- | Wrap an expression as a program: its size is computed with 'exprSize' and
-- the expression itself is placed behind a reference (see 'ref').
exprToData :: Typeable a => Expr a -> Data a
exprToData e = Data sz (ref e)
  where
    sz = exprSize e



-- | The size of an expression, derived from the sizes of its sub-programs and
-- sub-functions.
exprSize :: forall a . Typeable a => Expr a -> Size a

exprSize (Input sz)   = sz
exprSize (Value sz _) = sz

exprSize (Tuple2 a b)     = (dataSize a, dataSize b)
exprSize (Tuple3 a b c)   = (dataSize a, dataSize b, dataSize c)
exprSize (Tuple4 a b c d) = (dataSize a, dataSize b, dataSize c, dataSize d)

-- A projection has the size of the selected component of the tuple size.
exprSize (Get21 p) = fst (dataSize p)
exprSize (Get22 p) = snd (dataSize p)

exprSize (Get31 p) = let (s, _, _) = dataSize p in s
exprSize (Get32 p) = let (_, s, _) = dataSize p in s
exprSize (Get33 p) = let (_, _, s) = dataSize p in s

exprSize (Get41 p) = let (s, _, _, _) = dataSize p in s
exprSize (Get42 p) = let (_, s, _, _) = dataSize p in s
exprSize (Get43 p) = let (_, _, s, _) = dataSize p in s
exprSize (Get44 p) = let (_, _, _, s) = dataSize p in s

-- Function-like constructs take their size from the (sub-)function output.
exprSize (Function _ sz _ _)  = sz
exprSize (NoInline _ f _)     = subFunSize (deref f)
exprSize (IfThenElse _ t e _) = mappend (subFunSize t) (subFunSize e)
exprSize (While _ b i)        = mappend (dataSize i) (subFunSize b)
exprSize (Parallel l ixf)     =
    mapMonotonic fromIntegral (dataSize l) :> subFunSize ixf



-- | Computable types. A computable value completely represents a core program,
-- in such a way that @`internalize` `.` `externalize`@ preserves semantics, but
-- not necessarily syntax.
--
-- The terminology used in this class comes from thinking of the 'Data' type as
-- the \"internal\" core language and the "Feldspar.Core" API as the
-- \"external\" core language.
--
-- This module gives instances for 'Data' itself and for tuples (of up to four
-- elements) of 'Computable' types.
class Typeable (Internal a) => Computable a
  where
    -- | @`Data` (`Internal` a)@ is the internal representation of the type @a@.
    type Internal a

    -- | Convert to internal representation
    internalize :: a -> Data (Internal a)

    -- | Convert to external representation
    externalize :: Data (Internal a) -> a

-- A 'Data' program already is its own internal representation, so both
-- conversions leave it unchanged.
instance Storable a => Computable (Data a)
  where
    type Internal (Data a) = a

    internalize d = d
    externalize d = d

-- Pairs are internalized as a 'Tuple2' node and externalized through the
-- 'Get21' / 'Get22' projections.
instance (Computable a, Computable b) => Computable (a,b)
  where
    type Internal (a,b) = (Internal a, Internal b)

    internalize (x, y) = exprToData (Tuple2 (internalize x) (internalize y))

    externalize p = (externalizeE (Get21 p), externalizeE (Get22 p))

-- Triples are internalized as a 'Tuple3' node and externalized through the
-- 'Get31' .. 'Get33' projections.
instance (Computable a, Computable b, Computable c) => Computable (a,b,c)
  where
    type Internal (a,b,c) = (Internal a, Internal b, Internal c)

    internalize (x, y, z) =
        exprToData (Tuple3 (internalize x) (internalize y) (internalize z))

    externalize t =
        (externalizeE (Get31 t), externalizeE (Get32 t), externalizeE (Get33 t))

-- Quadruples are internalized as a 'Tuple4' node and externalized through the
-- 'Get41' .. 'Get44' projections.
instance (Computable a, Computable b, Computable c, Computable d) =>
    Computable (a,b,c,d)
  where
    type Internal (a,b,c,d) = (Internal a, Internal b, Internal c, Internal d)

    internalize (w, x, y, z) = exprToData $
        Tuple4 (internalize w) (internalize x) (internalize y) (internalize z)

    externalize q =
        ( externalizeE (Get41 q)
        , externalizeE (Get42 q)
        , externalizeE (Get43 q)
        , externalizeE (Get44 q)
        )



-- | Externalize an expression by first wrapping it as a program.
externalizeE :: Computable a => Expr (Internal a) -> a
externalizeE e = externalize (exprToData e)

-- | Lower a function to operate on internal representation: the argument is
-- externalized, the function applied, and the result internalized again.
lowerFun :: (Computable a, Computable b) =>
    (a -> b) -> (Data (Internal a) -> Data (Internal b))
lowerFun f x = internalize (f (externalize x))

-- | Lift a function to operate on external representation: the argument is
-- internalized, the function applied, and the result externalized again.
liftFun :: (Computable a, Computable b) =>
    (Data (Internal a) -> Data (Internal b)) -> (a -> b)
liftFun f x = externalize (f (internalize x))



-- | The semantics of expressions: evaluates an 'Expr' to the value it
-- represents. 'Input' placeholders have no value and raise an error.
evalE :: Expr a -> a

evalE (Input _)   = error "evaluating Input"
evalE (Value _ v) = v

evalE (Tuple2 a b)     = (evalD a, evalD b)
evalE (Tuple3 a b c)   = (evalD a, evalD b, evalD c)
evalE (Tuple4 a b c d) = (evalD a, evalD b, evalD c, evalD d)

-- Projections select the corresponding component of the evaluated tuple.
evalE (Get21 p) = fst (evalD p)
evalE (Get22 p) = snd (evalD p)

evalE (Get31 p) = let (x, _, _) = evalD p in x
evalE (Get32 p) = let (_, x, _) = evalD p in x
evalE (Get33 p) = let (_, _, x) = evalD p in x

evalE (Get41 p) = let (x, _, _, _) = evalD p in x
evalE (Get42 p) = let (_, x, _, _) = evalD p in x
evalE (Get43 p) = let (_, _, x, _) = evalD p in x
evalE (Get44 p) = let (_, _, _, x) = evalD p in x

evalE (Function _ _ f x)   = f (evalD x)
evalE (NoInline _ f x)     = evalD (subAp (deref f) x)
evalE (IfThenElse c t e x) = evalD (subAp (if evalD c then t else e) x)

evalE (While cont body start) = go start
  where
    -- Step the state while the predicate holds, then evaluate the final state.
    go st
      | evalD (subAp cont st) = go (subAp body st)
      | otherwise             = evalD st

evalE (Parallel l ixf) = [ evalD (subAp ixf (value i)) | i <- [0 .. evalD l - 1] ]



-- | The semantics of 'Data': evaluate the expression the program refers to.
evalD :: Data a -> a
evalD d = evalE (dataToExpr d)

-- | The semantics of any 'Computable' type: internalize, then evaluate.
eval :: Computable a => a -> Internal a
eval x = evalD (internalize x)



-- | A program that computes a constant value. Its size is taken from the value
-- itself via 'storableSize'.
value :: Storable a => a -> Data a
value x = exprToData constant
  where
    constant = Value (storableSize x) x

-- | Like 'value' but with an extra 'Size' argument that can be used to increase
-- the size beyond the given data.
--
-- Example 1:
--
-- > array (10 :> 20 :> universal) [] :: Data [[Int]]
--
-- gives an uninitialized 10x20 array of 'Int' elements.
--
-- Example 2:
--
-- > array (10 :> 20 :> universal) [[1,2,3]] :: Data [[Int]]
--
-- gives a 10x20 array whose first row is initialized to @[1,2,3]@.
array :: Storable a => Size a -> a -> Data a
array sz a = exprToData (Value combined a)
  where
    -- The requested size is combined with the size of the initial data.
    combined = mappend sz (storableSize a)

-- | Like 'array', but the size of the outermost level is the size of the
-- length program @len@. Inner levels are 'universal'. As in 'array', this size
-- is combined with the size of the initial data.
arrayLen :: Storable a => Data Length -> [a] -> Data [a]
arrayLen len = array sz
  where
    -- Convert the length range with 'fromIntegral', as the 'Parallel' case of
    -- 'exprSize' does. 'fromInteger' only type-checks when 'Length' is
    -- 'Integer', but 'parallel' uses it as an 'Int' range.
    sz = mapMonotonic fromIntegral (dataSize len) :> universal
  -- XXX This function is a temporary solution.

-- | The constant program for the unit value.
unit :: Data ()
unit = value ()

-- | The constant program for 'True'.
true :: Data Bool
true = value True

-- | The constant program for 'False'.
false :: Data Bool
false = value False

-- | The size range of every level of a (possibly multi-dimensional) array,
-- listed from the outermost level inwards.
size :: forall a . Storable a => Data [a] -> [Range Length]
size arr = listSize (T :: T [a]) (dataSize arr)

-- | Narrow the memoized size of a program by combining it with the given range
-- using @/\\@ (presumably range intersection; see "Feldspar.Range"). The
-- expression reference is kept unchanged.
cap :: (Storable a, Size a ~ Range b, Ord b) => Range b -> Data a -> Data a
cap bound (Data sz r) = Data (sz /\ bound) r
  -- XXX Should really have the type
  --       cap :: Storable a => Size a -> Data a -> Data a



-- | Constructs a one-argument primitive function.
--
-- @`function` fun szf f@:
--
--   * @fun@ is the name of the function.
--
--   * @szf@ computes the output size from the input size.
--
--   * @f@ gives the evaluation semantics.
--
-- Applied to a constant argument, the result is folded into a constant;
-- otherwise a 'Function' node is built.
function
    :: (Storable a, Storable b)
    => String -> (Size a -> Size b) -> (a -> b) -> (Data a -> Data b)

function fun sizeProp f a = case dataToExpr a of
    Value _ x -> Data outSize (ref (Value outSize (f x)))
    _         -> exprToData (Function fun outSize f a)
  where
    outSize = sizeProp (dataSize a)



-- | A two-argument primitive function. Both arguments constant: the result is
-- folded into a constant. Otherwise a 'Function' node is applied to the
-- paired arguments.
function2
    :: ( Storable a
       , Storable b
       , Storable c
       )
    => String
    -> (Size a -> Size b -> Size c)
    -> (a -> b -> c)
    -> (Data a -> Data b -> Data c)

function2 fun sizeProp f a b =
    case (dataToExpr a, dataToExpr b) of
      (Value _ x, Value _ y) -> Data outSize (ref (Value outSize (f x y)))
      _ -> exprToData (Function fun outSize (\(x, y) -> f x y) args)
  where
    outSize = sizeProp (dataSize a) (dataSize b)
    args    = exprToData (Tuple2 a b)



-- | A three-argument primitive function. All arguments constant: the result
-- is folded into a constant. Otherwise a 'Function' node is applied to the
-- tupled arguments.
function3
    :: ( Storable a
       , Storable b
       , Storable c
       , Storable d
       )
    => String
    -> (Size a -> Size b -> Size c -> Size d)
    -> (a -> b -> c -> d)
    -> (Data a -> Data b -> Data c -> Data d)

function3 fun sizeProp f a b c =
    case (dataToExpr a, dataToExpr b, dataToExpr c) of
      (Value _ x, Value _ y, Value _ z) ->
          Data outSize (ref (Value outSize (f x y z)))
      _ ->
          exprToData (Function fun outSize (\(x, y, z) -> f x y z) args)
  where
    outSize = sizeProp (dataSize a) (dataSize b) (dataSize c)
    args    = exprToData (Tuple3 a b c)



-- | A four-argument primitive function. All arguments constant: the result is
-- folded into a constant. Otherwise a 'Function' node is applied to the
-- tupled arguments.
function4
    :: ( Storable a
       , Storable b
       , Storable c
       , Storable d
       , Storable e
       )
    => String
    -> (Size a -> Size b -> Size c -> Size d -> Size e)
    -> (a -> b -> c -> d -> e)
    -> (Data a -> Data b -> Data c -> Data d -> Data e)

function4 fun sizeProp f a b c d =
    case (dataToExpr a, dataToExpr b, dataToExpr c, dataToExpr d) of
      (Value _ w, Value _ x, Value _ y, Value _ z) ->
          Data outSize (ref (Value outSize (f w x y z)))
      _ ->
          exprToData (Function fun outSize (\(w, x, y, z) -> f w x y z) args)
  where
    outSize = sizeProp (dataSize a) (dataSize b) (dataSize c) (dataSize d)
    args    = exprToData (Tuple4 a b c d)



-- Needed for the 'Num' instance. Every program is shown as the same
-- placeholder text; the expression is not printed.
instance Show (Data a)
  where
    show = const "... :: Data a"

-- Arithmetic on programs: each operator becomes a primitive function whose
-- size is computed by the same operator on sizes. Literals become constants.
instance Numeric a => Num (Data a)
  where
    (+)           = function2 "(+)"    (+)    (+)
    (-)           = function2 "(-)"    (-)    (-)
    (*)           = function2 "(*)"    (*)    (*)
    abs           = function  "abs"    abs    abs
    signum        = function  "signum" signum signum
    fromInteger n = value (fromInteger n)

-- Division on 'Float' programs; the result size is currently not tracked.
instance Fractional (Data Float)
  where
    (/)            = function2 "(/)" (\_ _ -> fullRange) (/)  -- XXX Improve range
    fromRational r = value (fromRational r)



-- | Look up an index in an array (see also '!'). The element size is the size
-- of the array's inner level. Evaluation fails if the index is outside the
-- array's length range, or points past the initialized elements.
getIx :: Storable a => Data [a] -> Data Int -> Data a
getIx arr = function2 "(!)" elemSize lookupIx arr
  where
    elemSize (_ :> sz) _ = sz

    lookupIx xs i
        | not (inRange i bounds) = error "getIx: index out of bounds"
        | i >= length xs         = error "getIx: reading garbage"
        | otherwise              = xs !! i
      where
        len :> _ = dataSize arr
        bounds   = rangeByRange 0 (len - 1)



-- | @`setIx` arr i a@:
--
-- Replaces the value at index @i@ in the array @arr@ with the value @a@.
-- Writing exactly one past the last initialized element extends the array.
-- Evaluation fails for indices outside the length range or further past the
-- initialized area.
setIx :: Storable a => Data [a] -> Data Int -> Data a -> Data [a]
setIx arr = function3 "setIx" resultSize update arr
  where
    -- Same outer length; element size covers both old and new elements.
    resultSize (l :> old) _ new = l :> mappend old new

    update xs i x
        | not (inRange i bounds) = error "setIx: index out of bounds"
        | i > length xs          = error "setIx: writing past initialized area"
        | otherwise              = take i xs ++ x : drop (i + 1) xs
      where
        len :> _ = dataSize arr
        bounds   = rangeByRange 0 (len - 1)

-- Same fixity as the Prelude list indexing operator '!!'.
infixl 9 !

-- | Structures whose elements can be looked up with a @'Data' 'Int'@ index.
class RandomAccess a
  where
    -- | The type of elements in a random access structure
    type Element a

    -- | Index lookup in a random access structure
    (!) :: a -> Data Int -> Element a

-- Array programs are indexed with 'getIx'.
instance Storable a => RandomAccess (Data [a])
  where
    type Element (Data [a]) = Data a
    arr ! i = getIx arr i



-- | Build a sub-function from a Haskell function. The function is applied once
-- to an 'Input' placeholder of the given size, and that input and output are
-- stored alongside the function.
mkSubFun :: Typeable a => Size a -> (Data a -> Data b) -> (a :-> b)
mkSubFun sz f = SubFunction f input output
  where
    input  = exprToData (Input sz)
    output = f input


-- | Constructs a non-primitive, non-inlined function.
--
-- The usual way to define a non-primitive function is as an ordinary Haskell
-- function, for example:
--
-- > myFunc x = x * 4 + 5
--
-- Such functions are always inlined into the program expression when applied.
-- @noInline@ instead keeps the function as a named 'NoInline' node (although
-- later transformations may still choose to inline it).
--
-- Ideally, it should be possible to reuse such a function several times, but
-- at the moment this does not work. Every application of a @noInline@ function
-- results in a new copy of the function in the core program.
noInline :: (Computable a, Computable b) => String -> (a -> b) -> (a -> b)
noInline fun f a = liftFun (exprToData . NoInline fun (ref body)) a
  where
    argSize = dataSize (internalize a)
    body    = mkSubFun argSize (lowerFun f)



-- | @`ifThenElse` cond thenFunc elseFunc@:
--
-- Selects between the two functions @thenFunc@ and @elseFunc@ depending on
-- whether the condition @cond@ is true or false. A constant condition is
-- resolved immediately; otherwise an 'IfThenElse' node is built.
ifThenElse
    :: (Computable a, Computable b)
    => Data Bool -> (a -> b) -> (a -> b) -> (a -> b)

ifThenElse cond t e a = case dataToExpr cond of
    Value _ b -> if b then t a else e a
    -- XXX Swapping the branches of a negated condition (a 'Function' named
    -- "not") is not possible here.
    _         -> liftFun (exprToData . IfThenElse cond thenFun elseFun) a
  where
    argSize = dataSize (internalize a)
    thenFun = mkSubFun argSize (lowerFun t)
    elseFun = mkSubFun argSize (lowerFun e)



-- | Like 'while', but with an explicit size for the loop state. This size is
-- used for the input placeholder of both the condition and the body
-- sub-functions.
whileSized
    :: Computable state
    => Size (Internal state)
    -> (state -> Data Bool)
    -> (state -> state)
    -> (state -> state)

whileSized sz cont body = liftFun (exprToData . While condFun stepFun)
  where
    condFun = mkSubFun sz (lowerFun cont)
    stepFun = mkSubFun sz (lowerFun body)



-- | While-loop
--
-- @while cont body :: state -> state@:
--
--   * @state@ is the type of the state.
--
--   * @cont@ determines whether or not to continue based on the current state.
--
--   * @body@ computes the next state from the current state.
--
--   * The result is a function from initial state to final state.
--
-- The size of the loop state is 'universal' (see 'whileSized').
while
    :: Computable state
    => (state -> Data Bool)
    -> (state -> state)
    -> (state -> state)

while cont body = whileSized universal cont body



-- | Parallel array
--
-- @parallel l ixf@:
--
--   * @l@ is the length of the resulting array (outermost level).
--
--   * @ixf@ is a function that maps each index in the range @[0 .. l-1]@ to its
--     element.
--
-- Since there are no dependencies between the elements, the compiler is free to
-- compute the elements in any order, or even in parallel.
parallel :: Storable a => Data Length -> (Data Int -> Data a) -> Data [a]
parallel l ixf = exprToData (Parallel l (mkSubFun ixRange ixf))
  where
    -- Indices range over 0 .. (length - 1).
    ixRange = rangeByRange 0 (dataSize l - 1)