module Flite.RedCompile where

-- Parameterise app-length, spine-length and num apps per template,
-- but not arity limit (for now).

import Flite.Syntax
import Flite.Flatten
import Flite.RedFrontend
import Data.List
import Flite.Traversals
import Flite.WriterState
import Flite.Inline
import Flite.Predex
import qualified Flite.RedSyntax as R

import Flite.Pretty
import Debug.Trace

-- Splits applications so that they contain no more than one 'Alts' node.

splitCase :: App -> Bind App
splitCase app
  | length is <= 1 = return app
  | otherwise = do i <- freshId ; write (i, app0) ; splitCase (Var i:rest)
  where
    is = findIndices isAlts app
    (app0, rest) = splitAt (is !! 1) app

-- Splits an application so that it has maximum length 'n'.

splitApp :: Int -> App -> Bind App
splitApp n app
  | length app <= n = return app
  | otherwise = do i <- freshId ; write (i, app0) ; splitApp n (Var i:rest)
  where (app0, rest) = splitAt n app

-- Splits a group of applications so that they each have maximum
-- length 'n' and no more than one 'Alts' node.

splitApps :: Int -> [(Id, App)] -> [(Id, App)]
splitApps n apps = cs ++ ds
  where
    (i, as, bs) = runWS (mapM splitCase' apps) 0
    (j, cs, ds) = runWS (mapM splitApp' (as ++ bs)) i
    splitCase' (v, app) = (,) v `fmap` (splitCase app)
    splitApp' (v, app) = (,) v `fmap` (splitApp n app)

splitSpine :: Int -> [(Id, App)] -> (App, [(Id, App)], [Exp])
splitSpine n ((v, app):rest)
  | length spine <= n = (spine, rest, luts)
  | otherwise = -- Needed????
      ( Var v:takeBack (n-1) spine
      , (v, dropBack (n-1) spine):rest
      , luts
      )
  where
    spine = filter (not . isAlts) app
    luts = filter isAlts app

-- Translates a program to Reduceron syntax.  Takes the max
-- application length and max spine length as arguments.

translate :: InlineFlag -> Int -> Int -> Int -> Prog -> R.Prog
translate i n m nregs p = map (trDefn n m nregs p2) p2
  where
    p0 = frontend nregs i (force01:force0:force1:p)
    p1 = [ (f, map getVar args, flatten $ removePredexSpine rhs)
         | Func f args rhs <- p0
         ]
    p2 = lift "main" p1

trDefn n m nregs p (f, args, xs) =
  (f, length args, luts, pushs', apps')
  where
    (spine, body, ls) = splitSpine m xs
    body' = predexReorder nregs $ splitApps n body
    d = (f, args, spine, body')
    luts = map (indexOf p) $ map getAlts ls
    apps = map (trApp p d . snd) body'
    pushs = map (tr p d) $ filter (not . isAlts) spine
    (pushs', apps') = predex nregs (pushs, apps)

trApp p d app
   | isPrimitiveApp app = R.PRIM (-1) rest
  -- - | isPrimitiveApp app = R.PRIM (-1) (reverse rest) {- PV STACK -}
  | null luts = R.APP (isNormal rest) rest
  | otherwise = R.CASE (head luts) rest
  where
    app' = force app
    --app' = app {- PV STACK -}
    luts = map (indexOf p) $ map getAlts $ filter isAlts app'
    rest = map (tr p d) $ filter (not . isAlts) app'

force app@[Prim p,y,Int _] = Fun "!force0" : app
force app@[Prim p,Int i,y] = Fun "!force1" : app
force app
  | isPrimitiveApp app = Fun "!force01" : app
  | otherwise = app

indexOf p f =
  case [i | ((g, args, rhs), i) <- zip p [0..], f == g] of
    [] -> error "RedCompile: indexOf"
    i:_ -> i

isNormal (R.CON n c:rest) = length rest <= n
isNormal (R.FUN b n f:rest) = length rest < n
isNormal _ = False

tr p d (Int i) = R.INT i
tr p d (Prim f) = R.PRI 2 f
tr p d (Fun f) =
  case xs of
    [] -> R.PRI 2 f
    (i, args):_ -> R.FUN False (length args) i
  where xs = [(i, args) | ((g, args, rhs), i) <- zip p [0..], f == g]
tr p (f, args, spine, body) (Var v) =
  case v `elemIndex` args of
    Nothing -> R.VAR shared idx
    Just i -> R.ARG shared i
  where
    shared = (length $ filter (== v)
                     $ concatMap (concatMap vars) (spine : map snd body)) > 1
    idx = case [i | ((w, _), i) <- zip body [0..], v == w] of
            [] -> error ("Unbound variable: " ++ v)
            i:_ -> i
tr p d (Ctr c n i) = R.CON n i
tr p d Bottom = R.INT 0

-- Set boolean 'original' flag on funtions; if true, function was
-- originally defined, and if false, function was introduced in
-- Reduceron compilation process.

flagFuns :: Int -> R.Prog -> R.Prog
flagFuns i p = map flag p
  where
    flag (f, pop, luts, push, apps) =
      (f, pop, luts, map fl push, map (mapAtoms fl) apps)
    fl (R.FUN _ n f) = R.FUN (f < i) n f
    fl a = a

-- Fragment a program such that: (1) each template contains at most
-- 'n' applications; (2) each template contains at most 'm' LUTs; (3)
-- each template pushes a maximum of 'm' atoms; (4) if a template
-- pushes more than one atom, then it contains at most 'n-1'
-- applications; (5) the first atom pushed by the final template does
-- not refer to any of that template's applications (the 'refers
-- check').

fragment :: Int -> Int -> R.Prog -> R.Prog
fragment n m p = flagFuns (length p) (p' ++ ts')
  where
    (_, ts, p') = runWS (mapM (frag n m) p) (length p)
    ts' = map snd (sortBy cmp ts)
    cmp (a, b) (c, d) = compare a c

sub n m = m-n

frag n m (f, pop, luts, push, apps)
  | length apps >= n || any isPRIM apps = fr n m (f, pop, luts, push, apps)
  | length luts > m =
      do x <- newId
         t <- frag n m (f, pop, dropBack m luts, push, apps)
         write (x, t)
         return (f, 0, takeBack m luts, [R.FUN False 0 x], [])
  | refersCheck (head push) = fr n m (f, pop, luts, push, apps)
  | otherwise = return (f, pop, luts, push, apps)

fr n m (f, pop, luts, push, apps) =
  do x <- newId
     let offset = length (take n apps0)
     let apps' = map (relocate (sub offset)) (drop n apps0 ++ apps1)
     let push' = map (reloc (sub offset)) push
     t <- frag n m (f, pop, dropBack m luts, push', apps')
     write (x, t)
     return (f, 0, takeBack m luts, [R.FUN False 0 x], take n apps0)
  where
    (apps0, apps1) = splitPredexes apps

relocate f app = mapAtoms (reloc f) app

reloc f (R.VAR sh i) = R.VAR sh (f i)
reloc f x = x

refersCheck (R.VAR sh i) = i >= 0
refersCheck _ = False

-- Top-level compilation

redCompile :: InlineFlag -> Int -> Int -> Int -> Int -> Int -> Prog -> R.Prog
redCompile i slen alen napps nluts nregs =
  fragment napps nluts . translate i alen slen nregs

-- Auxiliary functions

takeBack n xs = reverse $ take n $ reverse xs

dropBack n xs = reverse $ drop n $ reverse xs

getVar :: Exp -> String
getVar (Var v) = v

vars :: Exp -> [Id]
vars (Var v) = [v]
vars e = []

isAlts :: Exp -> Bool
isAlts (Alts fs n) = True
isAlts e = False

getAlts :: Exp -> Id
getAlts (Alts fs n)
  | null fs = error "RedCompile: getAlts"
  | otherwise = head fs

lift f p = xs ++ ys
  where (xs, ys) = partition (\(g, _, _) -> f == g) p

type Bind a = WriterState (Id, [Exp]) Int a

freshId :: Bind Id
freshId = do n <- get ; set (n+1) ; return ("new_bind_" ++ show n)

type Define a = WriterState (Int, R.Template) Int a

newId :: Define Int
newId = do n <- get ; set (n+1) ; return n