module Feldspar.Core.Representation where

import Data.List
import Data.Typeable hiding (TypeRep)

import Data.Tagged
import Data.Proxy

import Feldspar.DSL.Expression hiding (Eval)
import qualified Feldspar.DSL.Expression as E
import Feldspar.DSL.Lambda
import Feldspar.DSL.Sharing
import Feldspar.DSL.Network
import Feldspar.Set
import Feldspar.Core.Types

-- * Feldspar expressions

-- | Feldspar-specific expressions
--
-- Each constructor is one kind of node. The first index (@role@) appears to
-- describe, per argument and result, whether the position is an input ('In')
-- or an output ('Out') of the node (NOTE(review): confirm against
-- "Feldspar.DSL.Network"). The second index is the Haskell type of the node's
-- meaning; the reference semantics is given by the 'E.Eval' instance.
data Feldspar role a where
    -- A constant value; evaluates to the value itself.
    Literal :: (Type a, MetaType () a) => a -> Feldspar (Out ()) a

    -- A named Haskell function of one argument; evaluates to the stored
    -- function. Equality ('exprEq') only looks at the name and the type.
    Function :: (Typeable (a -> b), MetaType () b) =>
      String -> (a -> b) -> Feldspar (In ra -> Out ()) (a -> b)

    -- Tuple construction; evaluates to '(,)'.
    Pair :: (Type a, Type b, MetaType () (a,b)) => Feldspar (In () -> In () -> Out ()) (a -> b -> (a,b))

    -- @condition c t e@: @t@ when @c@ is true, otherwise @e@.
    Condition :: MetaType ra a => Feldspar
      (In () -> In ra -> In ra -> Out ra)
      (Bool  -> a     -> a     -> a)

    -- @parallel l ixf cont@: the elements @ixf 0 .. ixf (l-1)@ followed by
    -- @cont@ (see 'evalParallel').
    Parallel :: (Type a, MetaType () [a]) => Feldspar
      (In ()  -> (Out () -> In ()) -> In () -> Out ())
      (Length -> (Index  -> a)     -> [a]   -> [a])

    -- @sequential l st step cont@: threads a state through @step@ for the
    -- indices @0 .. l-1@, emitting one element per index, then appends
    -- @cont@ applied to the final state (see 'evalSequential').
    Sequential :: (Type a, MetaType () [a], MetaType rst st) => Feldspar
      (In ()  -> In rst -> (Out () -> Out rst -> In ((),rst)) -> (Out rst -> In ()) -> Out ())
      (Length -> st     -> (Index  -> st      -> (a,st))      -> (st      -> [a])   -> [a])

    -- @forLoop l st body@: folds @body@ over the indices @0 .. l-1@,
    -- starting from @st@ (see 'evalForLoop').
    ForLoop :: MetaType rst st => Feldspar
      (In ()  -> In rst -> (Out () -> Out rst -> In rst) -> Out rst)
      (Length -> st     -> (Index  -> st      -> st)     -> st)

    -- A named marker around a function; the evaluator treats it as the
    -- identity. Presumably a hint that the function should not be inlined.
    NoInline :: MetaType rb b =>
      String -> Feldspar ((Out ra -> In rb) -> (In ra -> Out rb)) ((a -> b) -> (a -> b))

    -- Truncate an array to the given length; it is an error if the array is
    -- shorter (see 'evalSetLength').
    SetLength :: Type a =>
      Feldspar (In () -> In () -> Out ()) (Length -> [a] -> [a])

    -- Replace one element of an array; it is an error if the index is out of
    -- range (see 'evalSetIx').
    SetIx :: (Type a) => Feldspar
      (In () -> In () -> In () -> Out ())
      (Index -> a     -> [a]   -> [a])

-- TODO Missing support for writing to several indices at once in 'Parallel' and
--      'Sequential'.

-- | Equality of individual 'Feldspar' nodes.
--
-- Nodes without a payload are equal exactly when they are the same
-- constructor. Literals are compared by type and value ('eqLiteral').
-- Functions are compared only by name and type ('sameType'), because the
-- functions themselves cannot be compared. 'NoInline' markers are compared by
-- name.
instance ExprEq Feldspar where
    exprEq (Literal a)      (Literal b)      = eqLiteral a b
    exprEq (Function n1 f1) (Function n2 f2) = n1==n2 && sameType f1 f2
    exprEq Pair             Pair             = True
    exprEq Condition        Condition        = True
    exprEq Parallel         Parallel         = True
    exprEq Sequential       Sequential       = True
    exprEq ForLoop          ForLoop          = True
    exprEq (NoInline n1)    (NoInline n2)    = n1 == n2
    exprEq SetLength        SetLength        = True
    exprEq SetIx            SetIx            = True
    exprEq _ _                               = False

-- | Compare two values that may have different types. They are equal only
-- when both have the same type and are equal at that type.
eqLiteral :: (Typeable a, Typeable b, Eq b) => a -> b -> Bool
eqLiteral x y = maybe False (== y) (cast x)

-- | 'True' exactly when both arguments have the same type. Only the types of
-- the arguments matter, not their values.
sameType :: forall a b . (Typeable a, Typeable b) => a -> b -> Bool
sameType x _ = maybe False (const True) (cast x :: Maybe b)

-- | Reference semantics: every 'Feldspar' node denotes a plain Haskell value
-- or function.
instance E.Eval Feldspar where
    eval (Literal a)    = a
    eval (Function _ f) = f
    eval Pair           = (,)
    eval Condition      = \cond t e -> if cond then t else e
    eval Parallel       = evalParallel
    eval Sequential     = evalSequential
    eval ForLoop        = evalForLoop
    eval (NoInline _)   = id  -- the marker has no semantic effect
    eval SetLength      = evalSetLength
    eval SetIx          = evalSetIx

-- | Evaluate 'Parallel': the elements @ixf 0, ..., ixf (len-1)@ followed by
-- the continuation array.
evalParallel :: Length -> (Index -> a) -> [a] -> [a]
evalParallel len ixf rest
    | len == 0  = rest  -- 0-1 would be a huge number, so skip the range
    | otherwise = map ixf [0 .. len-1] ++ rest

-- | Evaluate 'Sequential': run @step@ over the indices @0 .. l-1@, threading
-- the state and emitting one element per index, then append @cont@ applied to
-- the final state.
evalSequential :: Length -> st -> (Index -> st -> (a,st)) -> (st -> [a]) -> [a]
evalSequential 0 st0 _    cont = cont st0
  -- Need a special case for l==0 because 0-1 is a huge number (same as
  -- 'evalParallel' and 'evalForLoop')
evalSequential l st0 step cont = start ++ cont stFinal
  where
    (stFinal, start) = mapAccumL evalStep st0 [0 .. l-1]
    -- 'step' returns (element, state), but 'mapAccumL' wants (state, element)
    evalStep st i    = (st', a) where (a, st') = step i st

-- | Evaluate 'ForLoop': fold @body@ over the indices @0 .. l-1@, starting from
-- the initial state.
--
-- The fold is strict ('foldl''), so a long loop does not build up a chain of
-- unevaluated states. Each intermediate state is forced to weak head normal
-- form.
evalForLoop :: Length -> st -> (Index -> st -> st) -> st
evalForLoop 0 st0 _    = st0
  -- Need a special case for l==0 because 0-1 is a huge number
evalForLoop l st0 body = foldl' (flip body) st0 [0 .. l-1]

-- | Evaluate 'SetLength': keep exactly the first @n@ elements of the array.
-- It is an error if the array has fewer than @n@ elements.
evalSetLength :: Length -> [a] -> [a]
evalSetLength n xs
    | n == 0    = []
    | otherwise = case xs of
        y:ys -> y : evalSetLength (n-1) ys
        []   -> error "setLength: reading past the end of an array"

-- | Evaluate 'SetIx': replace the element at index @ix@ with @val@.
-- It is an error if @ix@ is not smaller than the length of the array.
evalSetIx :: Index -> a -> [a] -> [a]
evalSetIx ix val xs
    | ix >= len = error $ "setIx: assigning index (" ++ show ix ++
                          ") past the end of an array of length " ++
                          show len
    | otherwise = before ++ val : drop 1 after
  where
    len             = genericLength xs
    (before, after) = genericSplitAt ix xs

-- | Node names used when rendering expressions. Literals are shown by value,
-- functions by their name, and 'NoInline' markers together with their
-- (quoted) name.
instance ExprShow Feldspar where
    exprShow (Literal a)      = show a
    exprShow (Function fun _) = fun
    exprShow Pair             = "pair"
    exprShow Condition        = "condition"
    exprShow Parallel         = "parallel"
    exprShow Sequential       = "sequential"
    exprShow ForLoop          = "forLoop"
    exprShow (NoInline n)     = "noinline " ++ show n
    exprShow SetLength        = "setLength"
    exprShow SetIx            = "setIx"

-- * Feldspar networks

-- | A wrapper around 'Size' to make it look like an expression. The 'Type'
-- constraint ensures that edges in a 'FeldNetwork' always have supported types.
--
-- The constraints are stored in the constructor. Pattern matching on
-- 'EdgeSize' therefore brings @Type a@, @Eq (Size a)@ and @Show (Size a)@
-- into scope. For example, the 'ExprShow' instance shows the size without a
-- 'Show' constraint of its own. The single field 'edgeSize' is the wrapped
-- size information.
data EdgeSize role a = (Type a, Eq (Size a), Show (Size a)) =>
    EdgeSize { edgeSize :: Size a }

-- | An edge is shown as its size information.
instance ExprShow EdgeSize where
    exprShow (EdgeSize sz) = show sz
      -- 'Show (Size a)' comes from the constraint stored in 'EdgeSize'

-- | Two edges are equal when their size information is equal.
instance Eq (Size a) => Eq (EdgeSize role a) where
    (==) (EdgeSize lhs) (EdgeSize rhs) = lhs == rhs

-- | The 'Set' structure of sizes, lifted pointwise through the 'EdgeSize'
-- wrapper.
instance Type a => Set (EdgeSize role a) where
    empty                        = EdgeSize empty
    universal                    = EdgeSize universal
    EdgeSize sz1 \/ EdgeSize sz2 = EdgeSize (sz1 \/ sz2)
    EdgeSize sz1 /\ EdgeSize sz2 = EdgeSize (sz1 /\ sz2)

-- | 'Network' of 'Feldspar' expressions. 'EdgeSize' is the information
-- attached to edges (cf. the 'EdgeInfo' instance for 'Data').
type FeldNetwork = Network EdgeSize Feldspar

-- | A Feldspar program computing a value of type @a@
--
-- The program is represented as an input edge (@In ()@) of a 'FeldNetwork'.
-- Equality is the derived equality of that network.
newtype Data a = Data { unData :: FeldNetwork (In ()) a }
  deriving (Eq)

-- | A program is shown as its underlying network.
instance Show (Data a) where
    show = show . unData

-- | The edge information of a program is the 'EdgeSize' of its network.
instance EdgeInfo (Data a) where
    type Info (Data a) = EdgeSize () a
    edgeInfo           = edgeInfo . unData

-- | A single program is a single edge with role @()@ carrying a value of type
-- @a@. All operations delegate to the wrapped network.
instance Type a => MultiEdge (Data a) Feldspar EdgeSize where
    type Role     (Data a) = ()
    type Internal (Data a) = a
    toEdge                 = toEdge . unData
    fromInEdge             = Data . fromInEdge
    fromOutEdge info       = Data . fromOutEdge info

-- | 'Syntactic' is a specialization of the 'MultiEdge' class for 'Feldspar'
-- programs. The class has no methods; it only bundles the superclass
-- constraints below under one name.
class
    ( MultiEdge a Feldspar EdgeSize
    , Set (Info a)
    , Type (Internal a)
    , MetaType (Role a) (Internal a)
    ) => Syntactic a

-- TODO The constraint 'Type (Internal a)' is suspicious. It is really only
--      needed when 'Role a ~ ()', but having it on every 'Syntactic' type
--      happens to work.

-- 'Syntactic' types: single programs ('Data') and tuples of up to four
-- 'Syntactic' components. The instances have no bodies; they only require the
-- superclass constraints to hold.
instance Type a => Syntactic (Data a)
instance (Syntactic a, Syntactic b) => Syntactic (a,b)
instance (Syntactic a, Syntactic b, Syntactic c) => Syntactic (a,b,c)
instance (Syntactic a, Syntactic b, Syntactic c, Syntactic d) => Syntactic (a,b,c,d)

-- | The type representation of the value carried by an edge. It is computed
-- from the edge's size information, tagged with the value type @a@.
--
-- The signature has no constraints of its own. Matching on the 'EdgeSize'
-- constructor brings the stored constraints (including @Type a@) into scope.
-- NOTE(review): presumably 'typeRep' relies on them, so this should stay a
-- pattern match rather than use the 'edgeSize' selector — confirm.
edgeType :: forall a . EdgeSize () a -> TypeRep
edgeType (EdgeSize sz) = typeRep (Tagged sz :: Tagged a (Size a))

-- | The size information attached to a program.
dataSize :: Type a => Data a -> Size a
dataSize prog = edgeSize (edgeInfo (unData prog))

-- | The output node underlying a program (via 'undoEdge').
dataNode :: Data a -> FeldNetwork (Out ()) a
dataNode (Data net) = undoEdge net

-- | Wrap an output node as a program with the given size information.
nodeData :: Type a => Size a -> FeldNetwork (Out ()) a -> Data a
nodeData sz node = fromOutEdge (EdgeSize sz) node

-- | The edge information of a program ('edgeInfo' restricted to 'Syntactic'
-- types).
getInfo :: Syntactic a => a -> Info a
getInfo x = edgeInfo x

-- | Replace the size information of a program, keeping its node.
resizeData :: Type a => Size a -> Data a -> Data a
resizeData sz prog = nodeData sz (dataNode prog)

-- | A program that refers to the variable with the given identifier and
-- carries the given edge information.
variable :: Syntactic a => Info a -> Ident -> a
variable info ident = fromOutEdge info (Variable ident)

-- | Turn a Haskell function between programs into a network 'Lambda'. The
-- lambda's argument is passed to the function with edge information @info@.
lambda :: (Syntactic a, Syntactic b)
    => Info a
    -> (a -> b)
    -> FeldNetwork (Out (Role a) -> In (Role b)) (Internal a -> Internal b)
lambda info f = Lambda (\v -> toEdge (f (fromOutEdge info v)))

-- | Forcing computation. Implemented by 'edgeCast'.
force :: Syntactic a => a -> a
force x = edgeCast x

-- | Evaluation of Feldspar programs: evaluate the program's network with the
-- reference semantics ('E.eval').
eval :: Syntactic a => a -> Internal a
eval prog = E.eval (toEdge prog)

-- | Yield the value of a constant program. If the value is not known
-- statically, the result is 'Nothing'.
--
-- Every output edge of the program is inspected with 'lit', which only
-- recognizes a 'Literal' node.
viewLiteral :: Syntactic a => a -> Maybe (Internal a)
viewLiteral = mapEdge (\_ a -> lit (undoEdge a)) . toEdge
  where
    lit :: FeldNetwork (Out ()) a -> Maybe a
    lit (Inject (Node (Literal a))) = Just a
    lit _                           = Nothing

-- | The result types of an expression with output role @ra@ and value type
-- @a@, computed from the 'MetaType' instance alone. The expression argument
-- only fixes the types and is never inspected. The @[Int]@ component is
-- presumably the position of each result within the role — TODO confirm.
metaTypes :: forall a ra expr .
    MetaType ra a => expr (Out ra) a -> [([Int], TypeRep)]
metaTypes _ = listTypes [] roleProxy valueProxy
  where
    roleProxy  = Proxy :: Proxy ra
    valueProxy = Proxy :: Proxy a

-- | List the types of the results produced by a 'Feldspar' expression
--
-- Only fully applied 'Feldspar' nodes, and 'Let' bindings around them, are
-- recognized. For a node, the result types come from the 'MetaType'
-- information exposed by matching the constructor, via 'metaTypes'. The node's
-- arguments are not inspected. Any other shape of network raises 'error'.
resTypes :: FeldNetwork ra a -> [([Int], TypeRep)]  -- TODO Should use (Out ra)
resTypes a = case a of
    Inject (Node (Literal _))                        -> metaTypes a
    Inject (Node (Function _ _)) :$: _               -> metaTypes a
    Inject (Node Pair) :$: _ :$: _                   -> metaTypes a
    Inject (Node Condition) :$: _ :$: _ :$: _        -> metaTypes a
    Inject (Node Parallel) :$: _ :$: _ :$: _         -> metaTypes a
    Inject (Node Sequential) :$: _ :$: _ :$: _ :$: _ -> metaTypes a
    Inject (Node ForLoop) :$: _ :$: _ :$: _          -> metaTypes a
    Inject (Node (NoInline n)) :$: _ :$: _           -> metaTypes a
    Inject (Node SetLength) :$: _ :$: _              -> metaTypes a
    Inject (Node SetIx) :$: _ :$: _ :$: _            -> metaTypes a
    -- Look through a 'Let' to its body. NOTE(review): 'ph' is defined outside
    -- this section; presumably a placeholder argument for the bound variable.
    Let _ :$: _ :$: (Lambda f)                       -> resTypes (f ph)
    _                                                -> error $ "Representation.resTypes: " ++ show a

-- | Is the expression a node that produces more than one result?
isMulti :: FeldNetwork ra a -> Bool
isMulti net
    | isNode net = length (resTypes net) > 1
    | otherwise  = False

-- | Is the expression an application of the indexing function @(!)@?
isElem :: FeldNetwork ra a -> Bool
isElem net = case net of
    Inject (Node (Function name _)) :$: _ -> name == "(!)"
    _                                     -> False

-- | Is the expression an application of a tuple selector (@getFst@ or
-- @getSnd@)?
isSelector :: FeldNetwork ra a -> Bool
isSelector net = case net of
    Inject (Node (Function fun _)) :$: _ -> fun == "getFst" || fun == "getSnd"
    _                                    -> False

-- | Is the expression a literal whose data representation is an array?
isArrayLit :: FeldNetwork ra a -> Bool
isArrayLit net = case net of
    Inject (Node (Literal v)) | ArrayData _ <- dataRep v -> True
    _                                                    -> False

-- | Is the expression a literal array with no elements?
--
-- Uses 'null' rather than comparing 'length' with 0. 'null' only inspects the
-- first constructor instead of traversing the whole array.
isEmpty :: FeldNetwork ra a -> Bool
isEmpty (Inject (Node (Literal a)))
    | ArrayData as <- dataRep a = null as
isEmpty _ = False

-- | Introduce sharing into a Feldspar network using 'sharing' with
-- Feldspar-specific parameters.
--
-- * @necessary@: only genuine nodes are candidates. Functions, variables,
--   indexing ('isElem') and tuple selectors ('isSelector') are excluded.
--
-- * @sufficient@: nodes with several results ('isMulti'). Presumably these are
--   always shared — NOTE(review): confirm against "Feldspar.DSL.Sharing".
--
-- * @sharingPoint@: anything that is not a function.
feldSharing :: (Typeable ra, Typeable a) => FeldNetwork ra a -> FeldNetwork ra a
feldSharing = sharing Params
    { necessary    = \(SomeLam a) -> isNode a && not (isFunction a || isVar a || isElem a || isSelector a)
    , sufficient   = \(SomeLam a) -> isMulti a
    , sharingPoint = \(SomeLam a) -> not (isFunction a)
        -- To avoid introducing 'Let' in the middle of a nested lambda (e.g. the
        -- body of a 'ForLoop')
    }

-- | Render a program as a tree, after introducing sharing.
showExprTree :: Syntactic a => a -> String
showExprTree prog = showLamTree (feldSharing (toEdge prog))

-- | Like 'showExprTree', but for a function between programs. The function's
-- argument is given 'universal' size information.
showExprTree2 :: (Syntactic a, Syntactic b) => (a -> b) -> String
showExprTree2 f = showLamTree (feldSharing (lambda universal f))
  -- TODO Only temporary...

-- | Draw a program (with sharing introduced) using 'drawLambda'.
drawExpr :: Syntactic a => a -> IO ()
drawExpr prog = drawLambda (feldSharing (toEdge prog))

-- | Like 'drawExpr', but for a function between programs. The function's
-- argument is given 'universal' size information.
drawExpr2 :: (Syntactic a, Syntactic b) => (a -> b) -> IO ()
drawExpr2 f = drawLambda (feldSharing (lambda universal f))
  -- TODO Only temporary...