module DDC.Core.Flow.Transform.Rates.SeriesOfVector
        (seriesOfVectorModule
        ,seriesOfVectorFunction)
where
import DDC.Core.Collect
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Exp                           as DDC
import DDC.Core.Flow.Transform.Rates.Combinators   as Com
import DDC.Core.Flow.Transform.Rates.CnfFromExp
import DDC.Core.Flow.Transform.Rates.Fail
import DDC.Core.Flow.Transform.Rates.Graph
import qualified DDC.Core.Flow.Transform.Rates.SizeInference as SI
import DDC.Core.Flow.Transform.Rates.Clusters
import DDC.Core.Module
import DDC.Core.Transform.Annotate
import DDC.Core.Transform.Deannotate
import qualified DDC.Type.Env           as Env

import qualified Data.Map as Map
import           Data.Map   (Map)
import qualified Data.Set as Set
import Data.List   (partition)


-- | Apply the transform to the leading let-bindings of a module body.
--
--   The body is deannotated, the leading lets are split off and each group is
--   converted with 'seriesOfVectorLets'. The rebuilt body is annotated with
--   unit and stored back in the module. The second component collects the
--   failures reported by the conversion, tagged with a binding name.
seriesOfVectorModule :: ModuleF -> (ModuleF, [(Name,Fail)])
seriesOfVectorModule mm
 = (mm { moduleBody = rebuilt }, concat failures)
 where
  -- Work on the annotation-free form of the body.
  (groups, inner)       = splitXLets
                        $ deannotate (const Nothing) (moduleBody mm)

  -- One result per original group of lets.
  (converted, failures) = unzip (map seriesOfVectorLets groups)

  rebuilt               = annotate () (xLets (concat converted) inner)
       


-- | Convert a single group of let-bindings.
--
--   Each bound expression is passed through 'seriesOfVectorFunction'.
--   Process bindings extracted from an expression are emitted too: as
--   separate non-recursive lets in front of a plain @LLet@, or as additional
--   members of the group for an @LRec@.
--
--   Failures are paired with the name of the binding they were produced for.
--   In a recursive group, failures belonging to a binder without a name
--   cannot be tagged and are dropped.
--
--   A non-recursive let with an unnamed binder, and any other form of lets,
--   is returned unchanged with no failures.
seriesOfVectorLets :: LetsF -> ([LetsF], [(Name,Fail)])
seriesOfVectorLets ll
 | LLet b@(BName n _) x <- ll
 , (x',ls',errs)  <- seriesOfVectorFunction x
 = ( map (uncurry LLet) ls' ++ [LLet b x']
   , map (\f -> (n,f)) errs)

 | LRec bxs             <- ll
 , (bs,xs)              <- unzip bxs
 , (xs',ls', errss)     <- unzip3 $ map seriesOfVectorFunction xs
 = ( [LRec (concat ls' ++ (bs `zip` xs'))]
     -- Report the failures of every member of the group against its own
     -- binder, just like the non-recursive case does.
   , [ (n, f) | (BName n _, errs) <- bs `zip` errss, f <- errs ] )

 | otherwise
 = ([ll], [])


-- | Convert the body of a single function, which must be in a-normal form.
--
--   Returns the rewritten function, the process bindings that were extracted
--   from it, and any failures.
--
--   * If the body cannot be turned into combinator normal form, the function
--     is returned untouched together with a 'FailCannotConvert'.
--
--   * If size inference produces no result, the function is returned
--     untouched and no failure is reported.
--
--   * Otherwise the bindings are clustered and the function is rebuilt from
--     the clusters with 'reconstruct'.
seriesOfVectorFunction :: ExpF -> (ExpF, [(BindF,ExpF)], [Fail])
seriesOfVectorFunction fun
 | Left err        <- cnf
 = unchanged [FailCannotConvert err]

 | Right prog      <- cnf
 , Just (env, _s)  <- SI.generate prog
 = let graph          = graphOfBinds prog env
       parentsOf a b  = SI.parents prog env a b
       (fun', procs)  = reconstruct fun prog env (cluster graph parentsOf)
   in  (fun', procs, [])

 | otherwise
 = unchanged []
 where
  cnf            = cnfOfExp fun
  unchanged errs = (fun, [], errs)


-- | Rebuild a function from its clustered bindings.
--
--   The original lets of the function body are replaced by the lets generated
--   for each cluster (see 'mkLets'); the lambdas and the final expression of
--   the function are kept. Processes found in the generated lets are pulled
--   out by 'extractProcs' and returned separately as the second component.
reconstruct
        :: ExpF
        -> Program Name Name
        -> SI.Env Name
        -> [[CName Name Name]]
        -> (ExpF, [(BindF, ExpF)])
reconstruct fun prog env clusters
 = ( makeXLamFlags lams (xLets keptLets xx)
   , procs )
 where
  (lams, body)      = takeXLamFlags_safe fun
  (olds, xx)        = splitXLets body

  -- Types of everything bound by the old lets and by the lambdas.
  types             = takeTypes
                    $ concatMap valwitBindsOfLets olds ++ map snd lams

  (keptLets, procs) = extractProcs (concatMap letsOfCluster clusters) lams

  -- Walk all binds of the program and keep those in the cluster, rather
  -- than looking up each member of the cluster: this way the binds come
  -- out in program order. Each is flagged with whether it is an output.
  letsOfCluster c
   = mkLets types env (seriesInputsOfCluster prog c)
        [ (b, cnameOfBind b `elem` outs)
        | b <- _binds prog
        , cnameOfBind b `elem` c ]
   where
    outs = outputsOfCluster prog c


-- | Pull processes out of a list of lets so they can become separate bindings.
--
--   The second argument is the list of binders in scope, each flagged with
--   @True@ for a type binder and @False@ for a value binder. Every named
--   non-recursive let that is walked over is added to it as a value binder.
--
--   The worker argument of an @OpSeriesRunProcess@ application is replaced by
--   an application of a fresh name (the binding's name with a \"process\"
--   modifier) to the in-scope binders that the worker mentions. The worker
--   itself, abstracted over those binders, is returned in the second
--   component.
extractProcs :: [LetsF] -> [(Bool, DDC.Bind Name)] -> ([LetsF], [(BindF,ExpF)])
extractProcs lets env
 = walk lets env
 where
  walk [] _
   = ([], [])

  walk (l:rest) scope
   = case l of
      -- Actually, we know they're all LLets. Maybe it should just be (Name,ExpF)
      LLet b@(BName nm _) x
       -> pull b nm x scope `mappend` walk rest ((False, b) : scope)

      _
       -> ([l], []) `mappend` walk rest scope

  pull b nm x scope
   -- Look through an enclosing OpSeriesRateVecsOfVectors: recurse into the
   -- body of its final (lambda) argument with the lambda's binders in scope.
   | Just (op, args)                    <- takeXApps x
   , XVar (UPrim (NameOpSeries (OpSeriesRateVecsOfVectors n)) _)
                                        <- op
   , (front, [lam])                     <- splitAt (length args - 1) args
   , (lams, body)                       <- takeXLamFlags_safe lam
   , ([LLet b' body'], found)           <- pull b nm body (lams ++ scope)
   = ( [ LLet b'
          $ xApps (xVarOpSeries (OpSeriesRateVecsOfVectors n))
                  (front ++ [makeXLamFlags lams body']) ]
     , found )

   -- The process itself: abstract the worker over what it uses from scope.
   | Just (op, args)                    <- takeXApps x
   , XVar (UPrim (NameOpSeries OpSeriesRunProcess) _)
                                        <- op
   , (front, [lam])                     <- splitAt (length args - 1) args
   = let freeXs = freeX Env.empty lam
         freeTs = freeT Env.empty lam

         -- Is this in-scope binder free in the worker?
         used (isType, bnd)
          = case takeSubstBoundOfBind bnd of
             Just u
              | isType    -> Set.member u freeTs
              | otherwise -> Set.member u freeXs
             Nothing      -> False

         -- Type binders first, then value binders.
         (captT, captX) = partition fst (filter used scope)

         boundsOf bs    = takeSubstBoundsOfBinds (map snd bs)

         procName       = NameVarMod nm "process"

         call           = xApps op
                           (front ++
                             [ xApps (XVar $ UName procName)
                                 (  map (XType . TVar) (boundsOf captT)
                                 ++ map  XVar          (boundsOf captX)) ])

         worker         = makeXLamFlags (captT ++ captX) lam
     in  ( [LLet b call]
         , [(BName procName (tBot kData), worker)] )

   | otherwise
   = ([LLet b x], [])

-- | Make the lets for one cluster.
--
--   A cluster consisting of a single external binding is emitted as a plain
--   let of the external's expression. Any other cluster is handed to
--   'process'.
--
--   Calls 'error' for an empty cluster, and for a cluster that contains an
--   external but is not a single array or scalar external: both indicate a
--   bug in the clustering algorithm.
mkLets :: Map Name TypeF -> SI.Env Name -> [Name] -> [(Com.Bind Name Name, Bool)] -> [LetsF]
mkLets types env arrIns bs
 = case bs of
    -- We *could* just return an empty list in this case, but I don't think that's a good idea.
    []
     -> error ("ddc-core-flow:DDC.Core.Flow.Transform.Rates.SeriesOfVector impossible\n" ++
               "a cluster was created with no bindings.\n" ++
               "this means there must be a bug in the clustering algorithm.\n" ++
               show bs)

    [(Ext (NameArray  b) xx _, _)] -> external b xx
    [(Ext (NameScalar b) xx _, _)] -> external b xx

    _
     | any (isExt . fst) bs
     -> error ("ddc-core-flow:DDC.Core.Flow.Transform.Rates.SeriesOfVector impossible\n" ++
               "an external node has been clustered with another node.\n" ++
               "this means there must be a bug in the clustering algorithm.\n" ++
               show bs)

     | otherwise
     -> process types env arrIns (map toEither bs)

 where
  -- An external is bound as-is, at the type recorded for its name.
  external b xx = [LLet (BName b (types Map.! b)) xx]

  isExt (Ext{}) = True
  isExt _       = False

  -- Tag scalar binds as Left and array binds as Right, keeping the output flag.
  toEither (SBind s b, out) = ((s, Left  b), out)
  toEither (ABind a b, out) = ((a, Right b), out)
  toEither (Ext{},     _)   = error "ddc-core-flow.mkLets: impossible!"


-- | Create a process for a cluster of array and scalar bindings.
--
--   The cluster must contain no externals and must not be empty
--   (@outname@ and @processRate@ take the @head@ of the list).
--
--   The result is three groups of lets, in this order:
--
--   * allocations made before the process (@getPre@): a reference for every
--     fold, and a new vector for every array binding flagged as an output;
--
--   * one unit-typed binding, named after the first binding of the cluster
--     with a \"runproc\" modifier, holding the process expression;
--
--   * reads of the fold references after the process (@getPost@).
--
--   Arguments: the types of the known binders, the size inference
--   environment, the names of the cluster's input arrays, and the cluster's
--   bindings, each paired with a flag saying whether it is an output.
process :: Map Name TypeF
        -> SI.Env Name
        -> [Name]
        -> [((Name, Either (SBind Name Name) (ABind Name Name)), Bool)]
        -> [LetsF]
process types env arrIns bs
 = let pres  = concatMap  getPre  bs
       -- Built inside-out: the joined process expressions are wrapped in the
       -- input series lets, then the OpSeriesRunProcess application, then the
       -- wrappers for generates, and outermost the wrappers for the inputs.
       mid   = getRateVecs arrIns
             $ getGenerates bs
             $ runProcs   
             $ xLets (map getInSeries arrIns)
             $ mkProcs bs
       posts = concatMap  getPost bs
   in pres ++ [LLet (BName (NameVarMod outname "runproc") tUnit) mid] ++ posts
 where
  -- Bindings that must exist before the process is run.
  getPre b
   -- There is no point of having a reduce that isn't returned.
   -- A fold gets a reference (name modifier "ref") initialised with its zero.
   | ((s, Left (Fold _ (Scalar z _) _)), _)       <- b
   = [LLet (BName (NameVarMod s "ref") $ tRef $ tyOf s) (xNew (tyOf s) z)]

   -- Returned vectors
   | ((v, Right _), True)                       <- b
   = [LLet (BName v $ tVector $ sctyOf v) (xNewVector (sctyOf v) allocSize)]

   -- Otherwise, it's not returned or we needn't allocate anything
   | _                                          <- b
   = []


  -- Wrap the inner expression once for every Generate in the cluster whose
  -- size is a Scalar, binding the rate variable of the generated array with
  -- a type lambda. Bindings of any other form are skipped.
  getGenerates [] innerX
   = innerX
  getGenerates (b:rest) innerX
   | ((n, Right (Generate (Scalar sz _) _)), _) <- b
   , rest' <- getGenerates rest innerX
   = xApps (xVarOpSeries (OpSeriesRateVecsOfVectors 0))
           [ XType tUnit
           , sz
           , XLAM (BName (klokV n) kRate) rest' ]
   | otherwise
   = getGenerates rest innerX

  -- Wrap the inner expression in one OpSeriesRateVecsOfVectors application
  -- per group of input arrays. Inputs are grouped when the size inference
  -- environment gives them the same entry; the worker lambda binds the rate
  -- variable of the group and one "rv" binder per input in the group.
  getRateVecs [] innerX
   = innerX

  getRateVecs (a:as) innerX
   | Just t        <- SI.lookupV env a
   , (same,others) <- partition ((==Just t) . SI.lookupV env) as
   , rest'         <- getRateVecs others innerX

   , these         <- (a : same)
   , nums          <- length these

   , op            <- OpSeriesRateVecsOfVectors nums

   , flags         <- map (\n -> (False, BName (NameVarMod n "rv") (tRateVec (klokT a) (sctyOf n)))) these
   = xApps (xVarOpSeries op)
       (   map xsctyOf these ++ [XType tUnit]
       ++  map var     these
       ++[ makeXLamFlags ((True, BName (klokV a) kRate) : flags)
            rest' ]
       )

   -- Input array doesn't have a type..
   | otherwise
   = getRateVecs as innerX

  -- Bind the series (modifier "s") of an input array from its "rv" binder.
  getInSeries n
   = LLet (BName (NameVarMod n "s") (tSeries procT (klokT n) (sctyOf n)))
          (xApps (xVarOpSeries OpSeriesSeriesOfRateVec)
            [ procX, klokX n, xsctyOf n, var $ NameVarMod n "rv" ] )



  -- Bindings made after the process: a fold's own name is bound to the
  -- contents read back from its reference.
  getPost b
   | ((s, Left (Fold _ (Scalar _ _) _)), _)       <- b
   = [ LLet (BName s $ sctyOf s) (xRead (tyOf s) (var $ NameVarMod s "ref")) ]

   -- Ignore anything else
   | _   <- b
   = []

  -- Apply OpSeriesRunProcess at the process rate to a worker that binds the
  -- Proc type variable and takes one anonymous unit argument.
  runProcs body
   = let flags = [ (True,  BName procName kProc)
                 , (False, BNone tUnit) ]
     in  xApps (xVarOpSeries OpSeriesRunProcess)
               ([XType processRate, makeXLamFlags flags body])


  -- Translate the bindings in order into nested lets. Once the list is
  -- exhausted, the innermost expression is the join of the cluster's
  -- processes (see the @mkProcs []@ equation).
  mkProcs (b:rs)
   -- A fold becomes an OpSeriesReduce process over its input series,
   -- accumulating into the reference made by getPre.
   | ((s, Left (Fold (Fun xf _) (Scalar xs _) ain)), _)   <- b
   = let rest = mkProcs rs
     in  XLet (LLet   (BName (NameVarMod s "proc") $ tProcess procT $ klokT ain)
              $ xApps (xVarOpSeries OpSeriesReduce)
                      [ procX, klokX ain, xtyOf s, var (NameVarMod s "ref")
                      , xf, xs, var (NameVarMod ain "s")])
         rest

   -- An array binding defines a series (modifier "s"). If it is an output,
   -- it is followed by an OpSeriesFill process writing the series to the
   -- vector of the same name.
   | ((n, Right abind), out) <- b
   = let rest   = mkProcs rs
         n'proc = NameVarMod n "proc"
         n's    = NameVarMod n "s"
         n'flag = NameVarMod n "flags"
         n'sel  = NameVarMod n "sel"

         llet nm t x1
                = XLet (LLet (BName nm t) x1)

         -- What follows the definition of the series.
         go     | out
                = llet n'proc (tProcess procT $ klokT n)
                ( xApps (xVarOpSeries OpSeriesFill) [procX, klokX n, xsctyOf n, var $ n, var $ n's] )
                  rest
                | otherwise
                = rest

     in  case abind of
         MapN (Fun xf _) ains
          -> llet n's (tSeries procT (klokT n) $ sctyOf n)
           ( xApps (xVarOpSeries (OpSeriesMap (length ains)))
                   ([procX, klokX n] ++ (map xsctyOf  ains) ++ [xsctyOf n, xf] 
                        ++ map (var . flip NameVarMod "s") ains) )
             go

         -- A filter first maps the predicate over the input to get a series
         -- of flags, makes a selector from the flags, and packs the input
         -- with that selector. The remaining bindings go inside the worker
         -- of OpSeriesMkSel, which binds the rate variable of the result
         -- and the selector.
         Filter (Fun xf _) ain
          -> llet n'flag (tSeries procT (klokT ain) tBool)
           ( xApps (xVarOpSeries (OpSeriesMap 1))
                   [ procX, klokX ain, xsctyOf n, XType tBool, xf, var $ NameVarMod ain "s"] )
           $ xApps (xVarOpSeries $ OpSeriesMkSel 1)
                   [ procX, klokX ain, XType processRate, var n'flag
                   ,        XLAM (BName (klokV n) kRate)
                          $ XLam (BName n'sel (tSel1 procT (klokT ain) (klokT n)))
                          $ llet n's (tSeries procT (klokT n) $ sctyOf n)
                          ( xApps (xVarOpSeries OpSeriesPack)
                                  [ procX, klokX ain, klokX n, xsctyOf n
                                  , var n'sel, var $ NameVarMod ain "s"] )
                            go ]

         Generate _sz (Fun xf _)
          -> llet n's (tSeries procT (klokT n) $ sctyOf n)
           ( xApps (xVarOpSeries OpSeriesGenerate)
                   [ procX, klokX n, xsctyOf n, xf ])
             go

         -- A gather reads from the "rv" binder of the vector, using the
         -- series of indices.
         Gather v ix
          -> llet n's (tSeries procT (klokT n) $ sctyOf n)
           ( xApps (xVarOpSeries OpSeriesGather)
                   ([ procX, klokX v, klokX ix, xsctyOf v
                    , var $ NameVarMod v "rv", var $ NameVarMod ix "s"]) )
             go
         Cross _a _b
          -> error "ddc-core-flow.process: Cross combinator not implemented yet"

   -- All cases are handled above
   | otherwise
   = error "ddc-core-flow.process: impossible!"



  -- End of the bindings: join the processes of the whole cluster. Note that
  -- this looks at all of @bs@, not at the (empty) list being matched.
  mkProcs []
   -- bs cannot be empty: there's no point of an empty cluster.
   = let procs = concatMap getProc bs
     in  case procs of
         (_:_)  -> foldl1 mkJoin $ concatMap getProc bs
         []     -> error "ddc-core-flow.process: cluster with no outputs?"

  mkJoin p q
   = xApps (xVarOpSeries OpSeriesJoin) [p, q]

  -- The process variables that take part in the join: one for every fold
  -- and one for every array binding flagged as an output.
  getProc b@((s, Left _), _)
   = [resizeProc b $ var $ NameVarMod s "proc"]
  getProc b@((a, _), True)
   = [resizeProc b $ var $ NameVarMod a "proc"]
  getProc _
   = []

  -- Wrap a process expression in resizes, following the chain of inputs of
  -- the binding backwards through the cluster.
  resizeProc b v
   = goResize (fst $ fst b) v (reverse bs)

  -- Search the remaining bindings (latest first) for the definition of @n@
  -- and continue with its input. Only a Filter adds a wrapper; a Generate
  -- ends the walk, as does running out of bindings.
  goResize _ v []
   = v
  goResize n v (((n',b),_):rest)
   | n == n'
   = case b of
      Left (Fold _ _ ain)
       -> goResize ain v rest
      Right (MapN _ (i:_))
       -> goResize i   v rest
      Right (MapN _ [])
       -> error "ddc-core-flow.process: Map with no inputs."

      Right (Filter _ ain)
       -> goResize ain
        ( xApps (xVarOpSeries OpSeriesResizeProc)
          [ procX, klokX n, klokX ain
          , xApps (xVarOpSeries OpSeriesResizeSel1)
                [ procX, klokX n, klokX ain, klokX n
                , var $ NameVarMod n "sel"
                , xApps (xVarOpSeries OpSeriesResizeId)
                        [ procX, klokX n ]
                ]
         , v ]) rest

      Right (Generate _ _)
       -> v

      Right (Gather _ ain)
       -> goResize ain v rest

      Right (Cross a _b)
       -> goResize a v rest

   | otherwise
   = goResize n v rest

     
  -- Size expression used when allocating the output vectors in getPre.
  allocSize
   -- If there are any inputs, use the size of one of those
   | (i:_) <- arrIns
   = xApps (xVarOpVector OpVectorLength) [xsctyOf i, var i]
   -- Or if it's a generate, find the size of the generate expression.
   | (Scalar sz _:_) <- concatMap findGenerateSize bs
   = sz
   -- XXX Otherwise it must be an external, and this won't be called...
   | otherwise
   = error ("ddc-core-flow: allocSize, but no size known" ++ show arrIns ++ "\n" ++ show bs)

  -- Find the loop rate of the process.
  -- Since we don't have appends, it's just the rate of the first binding
  processRate
   = bindRate (head bs)

  -- Rate a binding runs at: the rate of the input for Fold, Filter and
  -- Cross (first input), and the binding's own rate otherwise.
  bindRate b
   = let k = klokT (fst $ fst b)
     in case snd $ fst b of
        Left (Fold _ _ ain)
         -> klokT ain
        Right (MapN _ _)
         -> k
        Right (Filter _ ain)
         -> klokT ain
        Right (Generate _ _)
         -> k
        Right (Gather _ _)
         -> k
        Right (Cross ain _)
         -> klokT ain

  -- We just need a name for the Proc type
  procName = NameVarMod outname "PROC"
  procT    = TVar  $ UName $ procName
  procX    = XType $ procT

  -- Rate variable of a binding, as a name, a type, and a type argument.
  klokV = getKlok env
  klokT = TVar . UName . klokV
  klokX = XType . klokT

  findGenerateSize ((_, Right (Generate sz _)), _)
   = [sz]
  findGenerateSize _
   = []


  -- We just need to find the name of any binding
  -- This head is safe because @mkLets@ will not call @process@ with an empty cluster.
  outname
   = fst $ fst $ head bs

  -- Type lookups. @tyOf@ uses @Map.!@, so it fails for a name missing from
  -- the map. The "sc" variants go through 'getScalarType' first.
  tyOf n  = types Map.! n
  sctyOf  = getScalarType . tyOf
  xtyOf   = XType . tyOf
  xsctyOf = XType . sctyOf

  var n   = XVar $ UName n


-- | Variable expression referring to a series primitive, carrying the type
--   given by 'typeOpSeries'.
xVarOpSeries op
 = XVar $ UPrim (NameOpSeries op) (typeOpSeries op)

-- | Variable expression referring to a vector primitive, carrying the type
--   given by 'typeOpVector'.
xVarOpVector op
 = XVar $ UPrim (NameOpVector op) (typeOpVector op)


-- | Element type of a vector type.
--   Any type that is not the vector type constructor applied to exactly one
--   argument is returned as it is.
getScalarType :: TypeF -> TypeF
getScalarType tt
 | Just (NameTyConFlow TyConFlowVector, [sc]) <- takePrimTyConApps tt
 = sc

 | otherwise
 = tt


-- | Map the name of every named binder to the type on the binder.
--   Binders that are not a @BName@ are left out.
takeTypes :: [DDC.Bind Name] -> Map Name TypeF
takeTypes binds
 = Map.fromList [ (n, t) | BName n t <- binds ]


-- | Name of the rate variable for a binding.
--
--   When the size inference environment has an entry for the binding, the
--   name is derived from that entry. Otherwise it is derived from the
--   binding's own name. Either way the result carries a \"k\" modifier.
getKlok :: SI.Env Name -> Name -> Name
getKlok e n
 = NameVarMod base "k"
 where
  base = maybe n ofType (SI.lookupV e n)

  ofType (SI.TVar kv)
   = ofKlok kv
  -- This doesn't matter much..
  ofType (SI.TCross ta tb)
   = NameVarMod (ofType ta) (show tb)

  -- A primed klok adds a "'" modifier to the name of the klok beneath it.
  ofKlok (SI.KV v)
   = v
  ofKlok (SI.K' kv)
   = NameVarMod (ofKlok kv) "'"