{-# LANGUAGE ScopedTypeVariables, Rank2Types #-}

-- Note : for now, the initial state is computed during the first tick

-- | Transform the Copilot specification into an Atom one, and then compile that one.
module Language.Copilot.Compiler(copilotToAtom, tmpSampleStr) where

import Language.Copilot.Core

import Data.Maybe
import Data.Map as M
import Data.List

import qualified Language.Atom as A

-- | Compiles a /Copilot/ specification to an /Atom/ one.
--
-- The period is given as a 'Maybe': with 'Nothing' the period computed by
-- 'getOptimalPeriod' is used; with @Just i@ the requested period @i@ is used,
-- provided it is not smaller than that computed period (otherwise an error is
-- raised).  The result pairs the period actually used with the Atom program.
copilotToAtom :: LangElems -> Maybe Period -> (Period, A.Atom ())
copilotToAtom (LangElems streams sends triggers) p =
  (chosenPeriod, A.period chosenPeriod program)
  where
    -- Smallest period that accommodates every sampling and send phase.
    minPeriod = getOptimalPeriod streams sends

    chosenPeriod =
      case p of
        Just requested
          | requested >= minPeriod -> requested
          | otherwise ->
              error $ "Copilot error: the period is too short, "
                   ++ "it should be at least " ++ show minPeriod ++ " ticks."
        Nothing -> minPeriod

    program :: A.Atom ()
    program = do
      -- Per-stream state: prophecy arrays and output variables.
      prophArrs <- mapStreamableMapsM initProphArr streams
      outputs <- mapStreamableMapsM initOutput streams

      -- Index variables, only for the streams that own a prophecy array.
      updateIndexes <- foldStreamableMaps makeUpdateIndex prophArrs (return M.empty)
      outputIndexes <- foldStreamableMaps makeOutputIndex prophArrs (return M.empty)

      -- Temporaries holding the sampled external variables and arrays.
      tmpSamples <-
        foldStreamableMaps
          (\_ -> initExtSamples streams prophArrs outputIndexes)
          streams
          (return emptyTmpSamples)

      -- One group of atom rules for each stream.
      foldStreamableMaps
        (makeRule streams outputs prophArrs tmpSamples updateIndexes outputIndexes)
        streams
        (return ())

      M.fold (makeTrigger streams prophArrs tmpSamples outputIndexes)
             (return ())
             triggers

      foldStreamableMaps (makeSend outputs) sends (return ())

      -- Sampling of the external variables.  Several streams may refer to
      -- the same sample, so keep only the first rule for each sample name.
      let samplers = foldStreamableMaps (\_ -> sampleExts tmpSamples) streams []
          sameSample (x, _) (y, _) = x == y
      sequence_ [ sampler | (_, sampler) <- nubBy sameSample samplers ]

-- | Allocates the prophecy array of a stream.
--
-- The initial values are the elements prepended by the leading 'Append'
-- constructors of the specification.  When there is at least one, an Atom
-- array named @prophVal__\<stream\>@ is created that holds them plus one
-- extra slot; otherwise no array is allocated.  In both cases the number of
-- initial values is recorded in the resulting 'BoundedArray'.
initProphArr :: forall a. Streamable a => Var -> Spec a -> A.Atom (BoundedArray a)
initProphArr v s
    | len > 0 = do
        -- The trailing unit is replaced by the real value during the
        -- first tick.
        arr <- A.array arrName (prefix ++ [unit])
        return (B len (Just arr))
    | otherwise = return (B len Nothing)
  where
    prefix = leadingValues s
    arrName = "prophVal__" ++ normalizeVar v
    len = genericLength prefix

    -- Concatenates the value lists of the outermost run of Appends.
    leadingValues (Append ls rest) = ls ++ leadingValues rest
    leadingValues _ = []

-- | Declares the output variable of a stream, named after the normalised
-- stream name and initialised with 'unit'.  The specification is only used
-- to fix the element type.
initOutput :: forall a. Streamable a => Var -> Spec a -> A.Atom (A.V a)
initOutput v _ = atomConstructor (normalizeVar v) (unit :: a)

-- | Prefix of the names of the temporary variables that hold sampled
-- external values; 'initExtSamples' prepends it to the normalised sample name.
tmpSampleStr :: String
tmpSampleStr = "tmpSampleVal__"

-- | Walks a stream specification and makes sure that every external variable
-- ('PVar') and external array access ('PArr') it mentions has a temporary
-- sample variable registered in the 'TmpSamples' accumulator.
--
-- The accumulator is itself an Atom action, so the variables are declared as
-- a side effect of running the result.  A sample that is already registered
-- (same key, same element type) is not declared a second time.
initExtSamples :: forall a. Streamable a 
               => StreamableMaps Spec -> ProphArrs -> Indexes -> Spec a 
                  -> A.Atom TmpSamples -> A.Atom TmpSamples
initExtSamples streams prophArrs outputIndexes s tmpSamples = do
    case s of
        -- Leaves that do not sample anything.
        Const _ -> tmpSamples
        Var _ ->   tmpSamples
        -- Compound specifications: recurse into every sub-specification,
        -- threading the accumulator through.
        Drop _ s0 -> initExtSamples' s0 tmpSamples
        Append _ s0 -> initExtSamples' s0 tmpSamples
        F _ _ s0 -> initExtSamples' s0 tmpSamples
        F2 _ _ s0 s1 -> initExtSamples' s0 $
                           initExtSamples' s1 tmpSamples
        F3 _ _ s0 s1 s2 -> initExtSamples' s0 $ initExtSamples' s1 $
                             initExtSamples' s2 tmpSamples
        -- External variable sampled at phase ph.
        PVar _ v ph -> 
            do  -- checkVar v 
                ts <- tmpSamples
                -- The key combines the variable name and the phase.
                let v' = tmpVarName v ph
                    vts = tmpVars ts
                    maybeElem = getMaybeElem v' vts::Maybe (PhasedValueVar a)
                    name = tmpSampleStr ++ normalizeVar v'
                case maybeElem of
                    Nothing -> 
                        -- First occurrence: declare the temporary (initialised
                        -- with unit) and record it together with its phase.
                        do  val <- atomConstructor name (unit::a)
                            let m' = M.insert v' (PhV ph val) (getSubMap vts)
                            return $ ts {tmpVars = updateSubMap (\_ -> m') vts}
                    Just _ -> return ts
        -- External array element sampled at phase ph.
        PArr _ (arr, idx) ph -> 
            do  --checkVar arr
                ts <- tmpSamples
                -- The key combines the array name, the phase and the shown index.
                let arr' = tmpArrName arr ph (show idx)
                    arrts = tmpArrs ts
                    idxts = tmpIdxs ts
                    maybeElem = getMaybeElem arr' arrts::Maybe (PhasedValueArr a)
                    name = tmpSampleStr ++ normalizeVar arr'
                case maybeElem of 
                  Nothing -> -- if the array isn't in the map, neither is the index
                      do val <- atomConstructor name (unit::a)
                         -- The index is either a constant or the current value
                         -- of a stream variable; any other spec is rejected.
                         let i = case idx of
                                   Const e -> PhIdx $ A.Const e
                                   Var _ -> 
--                                     let B initLen maybeArr = getElem v prophArrs 
--                                     let B initLen maybeArr = 
--                                           case getMaybeElem v prophArrs of
--                                             Nothing -> 
--                                               error "Error in function initExtSamples."
--                                             Just x -> x
                                     -- NOTE(review): undefined is passed in the
                                     -- argument position where makeRule passes
                                     -- the TmpSamples; presumably nextSt never
                                     -- inspects it for a Var spec -- confirm.
                                     PhIdx $ nextSt streams prophArrs 
                                                undefined outputIndexes idx 0
                                   _    -> error "Unexpected Spec in initExtSamples."
                         -- Record the sampled value and its index under the same key.
                         let m' = M.insert arr' (PhA ph val) (getSubMap arrts)
                         let m'' = M.insert arr' i (getSubMap idxts)
                         return $ ts { tmpArrs = updateSubMap (\_ -> m') arrts
                                     , tmpIdxs = updateSubMap (\_ -> m'') idxts
                                     }
                  Just _ -> return ts
    where 
    --      checkVar v = when (normalizeVar v /= v)
    --                    (error $ "Copilot: external variable " ++ v ++ " is not "
    --                            ++ "a valid C99 variable.")
          -- Recursive call with the fixed environment; the explicit signature
          -- lets it be used on sub-specifications of other element types.
          initExtSamples' :: Streamable b
                          => Spec b -> A.Atom TmpSamples -> A.Atom TmpSamples
          initExtSamples' = initExtSamples streams prophArrs outputIndexes

-- | Declares the update index of a stream and adds it to the accumulated
-- index map.  Streams without a prophecy array get no index.  The index
-- variable is named @updateIndex__\<stream\>@ and starts at the number of
-- initial values stored in the 'BoundedArray'.
makeUpdateIndex :: Var -> BoundedArray a -> A.Atom Indexes -> A.Atom Indexes
makeUpdateIndex _ (B _ Nothing) indexes = indexes
makeUpdateIndex v (B n (Just _)) indexes = do
    known <- indexes
    ix <- atomConstructor ("updateIndex__" ++ normalizeVar v) n
    return (M.insert v ix known)

-- | Declares the output index of a stream and adds it to the accumulated
-- index map.  Streams without a prophecy array get no index.  The index
-- variable is named @outputIndex__\<stream\>@ and starts at 0.
makeOutputIndex :: Var -> BoundedArray a -> A.Atom Indexes -> A.Atom Indexes
makeOutputIndex _ (B _ Nothing) indexes = indexes
makeOutputIndex v (B _ (Just _)) indexes = do
    known <- indexes
    ix <- atomConstructor ("outputIndex__" ++ normalizeVar v) 0
    return (M.insert v ix known)

-- | Emits the Atom rules that drive one stream, after running the
-- accumulated action @r@.
--
-- * Without a prophecy array a single rule, at phase 0, writes the next
--   value of the stream directly into its output variable.
--
-- * With a prophecy array three rules are emitted: at phase 0 the next value
--   is stored in the array at the update index; at phase 1 the output
--   variable is loaded from the array at the output index and the output
--   index is advanced; a third rule advances the update index.  Both indexes
--   wrap modulo @n + 1@, where @n@ is the count stored in the 'BoundedArray'.
--
-- Raises an error naming the stream if it has a prophecy array but no entry
-- in one of the two index maps.
makeRule :: forall a. Streamable a => 
    StreamableMaps Spec -> Outputs -> ProphArrs -> TmpSamples -> 
    Indexes -> Indexes -> Var -> Spec a -> A.Atom () -> A.Atom ()
makeRule streams outputs prophArrs tmpSamples updateIndexes outputIndexes v s r = do
    r
    let B n maybeArr = getElem v prophArrs :: BoundedArray a
    case maybeArr of
        Nothing ->
            -- Fuse the update and the output together if the prophecy array
            -- doesn't exist (i.e. if it would only have held the output value).
            A.exactPhase 0 $ A.atom ("updateOutput__" ++ normalizeVar v) $ do
                output A.<== nextSt'

        Just arr -> do
            let updateIndex = indexFor "update" updateIndexes
                outputIndex = indexFor "output" outputIndexes

            A.exactPhase 0 $ A.atom ("update__" ++ normalizeVar v) $ do
                arr A.! (A.VRef updateIndex) A.<== nextSt'

            A.exactPhase 1 $ A.atom ("output__" ++ normalizeVar v) $ do
                output A.<== arr A.!. (A.VRef outputIndex)
                outputIndex A.<== (A.VRef outputIndex + A.Const 1) `A.mod_` A.Const (n + 1)

            -- Spread these out evenly across the remaining phases, starting
            -- no earlier than phase 1, and after the last phase at which an
            -- external array is sampled using this stream as its index.
            A.phase ((maxSampleDep v streams) + 1)
              $ A.atom ("incrUpdateIndex__" ++ normalizeVar v) $ do
                updateIndex A.<== (A.VRef updateIndex + A.Const 1) `A.mod_` A.Const (n + 1)
  where
    -- Expression computing the next value of the stream.
    nextSt' = nextSt streams prophArrs tmpSamples outputIndexes s 0

    -- Output variable of this stream.
    output = getElem v outputs :: A.V a

    -- Looks the stream up in an index map, failing with a message that says
    -- which index is missing rather than with an anonymous fromJust error.
    indexFor kind indexes =
        fromMaybe
            (error $ "Copilot error: no " ++ kind ++ " index registered for stream "
                  ++ v ++ " in makeRule.")
            (M.lookup v indexes)
             
-- | Finds the maximum phase at which an external array sampling depends on
-- the given stream, i.e. uses that stream variable as its index.  Every
-- stream specification is searched.  Returns zero when there is no such
-- sampling.
maxSampleDep :: Var -> StreamableMaps Spec -> Int
maxSampleDep v streams =
  foldStreamableMaps (\_ spec acc -> depOf spec acc) streams 0
  where
    -- Folds one specification into the running maximum.
    depOf :: Streamable b => Spec b -> Int -> Int
    depOf (Var _) acc = acc
    depOf (Const _) acc = acc
    depOf (PVar _ _ _) acc = acc
    depOf (PArr _ (_, Var idxVar) ph) acc
      | v == idxVar = max ph acc
      | otherwise   = acc
    depOf (PArr _ _ _) acc = acc
    depOf (F _ _ s0) acc = depOf s0 acc
    depOf (F2 _ _ s0 s1) acc = depOf s0 (depOf s1 acc)
    depOf (F3 _ _ s0 s1 s2) acc = depOf s0 (depOf s1 (depOf s2 acc))
    depOf (Append _ s0) acc = depOf s0 acc
    depOf (Drop _ s0) acc = depOf s0 acc

-- makeSend :: forall a. Sendable a => Outputs -> Var -> Send a -> A.Atom () -> A.Atom ()
-- makeSend outputs name (Send (v, ph, port)) r = do
--         r 
--         A.exactPhase ph $ A.atom ("__send_" ++ name) $
--             send ((A.value (getElem v outputs))::(A.E a)) port
-- | Emits the rule that sends the value of a stream on a port.  The rule is
-- named @__send_\<name\>@ and runs exactly at the phase stored in the 'Send';
-- it is sequenced after the accumulated action @r@.
makeSend :: forall a. Streamable a 
         => Outputs -> String -> Send a -> A.Atom () -> A.Atom ()
makeSend outputs name (Send v ph port portName) r =
    r >> A.exactPhase ph
           (A.atom ruleName
              (mkSend
                 (A.value (notVarErr v (\var -> getElem var outputs)) :: A.E a)
                 port
                 portName))
  where
    ruleName = "__send_" ++ name

-- What we really should be doing is just folding over the TmpSamples, since
-- that data should contain all the info we need to construct external variable
-- and external array samples.  However, there is the issue that for array
-- samples, the type of the index may differ from the type of the array, and
-- having the spec available provides typing coercion.  We could fold over the
-- TmpSamples, passing streams in, and extract the appropriate Spec a.

-- | Collects, for every external variable and external array access found in
-- a specification, the Atom rule that samples it into its temporary variable.
--
-- Each rule is paired with the name of the sample so that the caller can
-- drop duplicates.  The rule is named @sample__\<name\>@ and runs exactly at
-- the phase attached to the sample.  New entries are consed onto the
-- accumulator @a@.
--
-- Raises an error naming the array if an external array access has no
-- temporary variable registered in the 'TmpSamples'.
sampleExts :: forall a. Streamable a 
           => TmpSamples -> Spec a -> [(Var, A.Atom ())] -> [(Var, A.Atom ())]
sampleExts ts s a =
  case s of
    Var _ -> a
    Const _ -> a
    PVar _ v ph ->
      let v' = tmpVarName v ph
          PhV _ var = getElem v' (tmpVars ts) :: PhasedValueVar a
      in (v', A.exactPhase ph $
                A.atom ("sample__" ++ v') $
                  var A.<== (A.value $ externalAtomConstructor v)
         ) : a

    PArr _ (arr, idx) ph ->
      let arr' = tmpArrName arr ph (show idx)
          -- The index type may differ from the element type; getIdx
          -- recovers it from the index spec.
          PhIdx i = getIdx arr' idx (tmpIdxs ts)
          PhA _ arrV =
            fromMaybe
              (error $ "Error in function sampleExts: no temporary sample "
                    ++ "variable for external array access " ++ arr' ++ ".")
              (getMaybeElem arr' (tmpArrs ts) :: Maybe (PhasedValueArr a))
      in (arr', A.exactPhase ph $
                  A.atom ("sample__" ++ arr') $
                    arrV A.<== A.array' arr (atomType (unit::a)) A.!. i
         ) : a

    F _ _ s0 -> sampleExts ts s0 a
    F2 _ _ s0 s1 -> sampleExts ts s0 $ sampleExts ts s1 a
    F3 _ _ s0 s1 s2 -> sampleExts ts s0 $ sampleExts ts s1 $
                         sampleExts ts s2 a
    Append _ s0 -> sampleExts ts s0 a
    Drop _ s0 -> sampleExts ts s0 a

-- | Looks up the index used by an external array access.
--
-- For a stream-variable index the previously registered index expression is
-- fetched from the map under the array's sample name; a constant index is
-- turned into a constant expression directly.  Any other index spec is an
-- error, as is a variable index that was never registered.
getIdx :: forall a. (Streamable a, A.IntegralE a) 
       => Var -> Spec a -> StreamableMaps PhasedValueIdx -> PhasedValueIdx a
getIdx arr s ts =
  case s of
    Var _   -> fromMaybe
                 (error $ "Error in function getIdx: no index registered "
                       ++ "for the external array access " ++ arr ++ ".")
                 (getMaybeElem arr ts)
    Const e -> PhIdx $ A.Const e
    _       -> error $ "Expecting either a variable or constant for the index "
                 ++ "in the external array access for array " ++ arr ++ "."

-- XXX bound min, max send phases

-- | Computes the smallest period that leaves room for every sampling phase
-- and every send phase of the specification.  The result is never below 2.
getOptimalPeriod :: StreamableMaps Spec -> StreamableMaps Send -> Period
getOptimalPeriod streams sends = max samplingBound sendBound
  where
    samplingBound = foldStreamableMaps samplingPhase streams 2
    sendBound = foldStreamableMaps sendPhase sends 0

    -- Raises the running bound so that it covers every sample taken inside
    -- a specification.  The variable name is not used.
    samplingPhase :: Var -> Spec a -> Period -> Period
    samplingPhase _ (PVar _ _ ph) acc = max (ph + 1) acc
    -- An array indexed by a stream variable needs one more phase: if that
    -- variable has a prophecy array, it must be updated after the index is
    -- taken.
    samplingPhase _ (PArr _ (_, Var _) ph) acc = max (ph + 2) acc
    samplingPhase _ (PArr _ _ ph) acc = max (ph + 1) acc
    samplingPhase _ (F _ _ s0) acc = samplingPhase "" s0 acc
    samplingPhase _ (F2 _ _ s0 s1) acc =
      max acc (max (samplingPhase "" s0 acc) (samplingPhase "" s1 acc))
    samplingPhase _ (F3 _ _ s0 s1 s2) acc =
      max acc
          (max (samplingPhase "" s0 acc)
               (max (samplingPhase "" s1 acc) (samplingPhase "" s2 acc)))
    samplingPhase _ (Drop _ s0) acc = samplingPhase "" s0 acc
    samplingPhase _ (Append _ s0) acc = samplingPhase "" s0 acc
    samplingPhase _ _ acc = acc

    -- A send at phase ph needs a period of at least ph + 1.
    sendPhase :: Var -> Send a -> Period -> Period
    sendPhase _ (Send _ ph _ _) acc = max (ph + 1) acc