{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE OverloadedLists     #-}
{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}
{-# LANGUAGE TupleSections       #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Trafo.Sharing
-- Copyright   : [2008..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- This module implements HOAS to de Bruijn conversion of array expressions
-- while incorporating sharing information.
--

module Data.Array.Accelerate.Trafo.Sharing (

  -- * HOAS to de Bruijn conversion
  convertAcc, convertAccWith,

  Afunction, AfunctionR, ArraysFunctionR, AfunctionRepr(..), afunctionRepr,
  convertAfun, convertAfunWith,

  Function, FunctionR, EltFunctionR, FunctionRepr(..), functionRepr,
  convertExp, convertExpWith,
  convertFun, convertFunWith,

  -- convertSeq

) where

import Data.Array.Accelerate.AST                                    hiding ( PreOpenAcc(..), OpenAcc(..), Acc, OpenExp(..), Exp, Boundary(..), HasArraysR(..), showPreAccOp )
import Data.Array.Accelerate.AST.Environment
import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.AST.LeftHandSide
import Data.Array.Accelerate.AST.Var
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Debug.Flags                            as Debug
import Data.Array.Accelerate.Debug.Trace                            as Debug
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Representation.Array                   ( Array, ArraysR, ArrayR(..), showArraysR )
import Data.Array.Accelerate.Representation.Shape                   hiding ( zip )
import Data.Array.Accelerate.Representation.Stencil
import Data.Array.Accelerate.Representation.Tag
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Smart                                  as Smart hiding ( StencilR )
import Data.Array.Accelerate.Sugar.Array                            hiding ( Array, ArraysR, (!!) )
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Trafo.Config
import Data.Array.Accelerate.Trafo.Substitution
import Data.Array.Accelerate.Trafo.Var
import Data.Array.Accelerate.Type
import Data.BitSet                                                  ( (\\), member )
import qualified Data.Array.Accelerate.AST                          as AST
import qualified Data.Array.Accelerate.Representation.Stencil       as R
import qualified Data.Array.Accelerate.Sugar.Array                  as Sugar

import Control.Applicative                                          hiding ( Const )
import Control.Lens                                                 ( over, mapped, _1, _2 )
import Control.Monad.Fix
import Data.Function                                                ( on )
import Data.Hashable
import Data.List                                                    ( elemIndex, findIndex, groupBy, intercalate, partition )
import Data.Maybe
import Data.Monoid                                                  ( Any(..) )
import System.IO.Unsafe                                             ( unsafePerformIO )
import System.Mem.StableName
import Text.Printf
import qualified Data.HashMap.Strict                                as Map
import qualified Data.HashSet                                       as Set
import qualified Data.HashTable.IO                                  as Hash
import qualified Data.IntMap                                        as IntMap
import Prelude


-- Layouts
-- -------

-- A layout of an environment has an entry for each entry of the environment.
-- Each entry in the layout holds the de Bruijn index that refers to the
-- corresponding entry in the environment.
--
data Layout s env env' where
  EmptyLayout :: Layout s env ()
  PushLayout  :: Layout s env env1
              -> LeftHandSide s t env1 env2
              -> Vars s env t
              -> Layout s env env2

type ELayout = Layout ScalarType
type ArrayLayout = Layout ArrayR


-- Project the nth index out of an environment layout.
--
-- The first argument provides context information for error messages in the
-- case of failure.
--
prjIdx :: forall s t env env1. HasCallStack
       => String
       -> (forall t'. TupR s t' -> ShowS)
       -> (forall u v. TupR s u -> TupR s v -> Maybe (u :~: v))
       -> TupR s t
       -> Int
       -> Layout s env env1
       -> Vars s env t
prjIdx context showTp matchTp tp = go
  where
    go :: forall env'. HasCallStack => Int -> Layout s env env' -> Vars s env t
    go _ EmptyLayout                        = no "environment does not contain index"
    go 0 (PushLayout _ lhs vars)
      | Just Refl <- matchTp tp tp'         = vars
      | otherwise                           = no $ printf "couldn't match expected type `%s' with actual type `%s'"
                                                          (showTp tp  "")
                                                          (showTp tp' "")
      where
        tp' = lhsToTupR lhs
    go n (PushLayout l _ _)                 = go (n-1) l

    no :: HasCallStack => String -> a
    no reason = internalError (printf "%s\nin the context: %s" reason context)

-- Add an entry to a layout, incrementing all indices
--
incLayout :: env1 :> env2 -> Layout s env1 env' -> Layout s env2 env'
incLayout _ EmptyLayout            = EmptyLayout
incLayout k (PushLayout lyt lhs v) = PushLayout (incLayout k lyt) lhs (weakenVars k v)

sizeLayout :: Layout s env env' -> Int
sizeLayout EmptyLayout          = 0
sizeLayout (PushLayout lyt _ _) = 1 + sizeLayout lyt

-- Conversion from HOAS to de Bruijn computation AST
-- =================================================

-- Array computations
-- ------------------

-- | Convert a closed array expression to de Bruijn form while also incorporating sharing
-- information.
--
convertAcc :: HasCallStack => Acc arrs -> AST.Acc (Sugar.ArraysR arrs)
convertAcc = convertAccWith defaultOptions

convertAccWith :: HasCallStack => Config -> Acc arrs -> AST.Acc (Sugar.ArraysR arrs)
convertAccWith config (Acc acc) = convertOpenAcc config EmptyLayout acc


-- | Convert a closed function over array computations, while incorporating
-- sharing information.
--
convertAfun :: HasCallStack => Afunction f => f -> AST.Afun (ArraysFunctionR f)
convertAfun = convertAfunWith defaultOptions

convertAfunWith :: HasCallStack => Afunction f => Config -> f -> AST.Afun (ArraysFunctionR f)
convertAfunWith config = convertOpenAfun config EmptyLayout

data AfunctionRepr f ar areprr where
  AfunctionReprBody
    :: Arrays b => AfunctionRepr (Acc b) b (Sugar.ArraysR b)

  AfunctionReprLam
    :: Arrays a
    => AfunctionRepr b br breprr
    -> AfunctionRepr (Acc a -> b) (a -> br) (Sugar.ArraysR a -> breprr)

-- Convert a HOAS fragment into de Bruijn form, binding variables into the typed
-- environment layout one binder at a time.
--
-- NOTE: Because we convert one binder at a time left-to-right, the bound
--       variables ('vars') will have de Bruijn index _zero_ as the outermost
--       binding, and thus go to the end of the list.
--
class Afunction f where
  type AfunctionR f
  type ArraysFunctionR f
  afunctionRepr   :: HasCallStack => AfunctionRepr f (AfunctionR f) (ArraysFunctionR f)
  convertOpenAfun :: HasCallStack => Config -> ArrayLayout aenv aenv -> f -> AST.OpenAfun aenv (ArraysFunctionR f)

instance (Arrays a, Afunction r) => Afunction (Acc a -> r) where
  type AfunctionR      (Acc a -> r) = a -> AfunctionR r
  type ArraysFunctionR (Acc a -> r) = Sugar.ArraysR a -> ArraysFunctionR r

  afunctionRepr = AfunctionReprLam $ afunctionRepr @r
  convertOpenAfun config alyt f
    | repr <- Sugar.arraysR @a
    , DeclareVars lhs k value <- declareVars repr
    = let
        a     = Acc $ SmartAcc $ Atag repr $ sizeLayout alyt
        alyt' = PushLayout (incLayout k alyt) lhs (value weakenId)
      in
        Alam lhs $ convertOpenAfun config alyt' $ f a

instance Arrays b => Afunction (Acc b) where
  type AfunctionR      (Acc b) = b
  type ArraysFunctionR (Acc b) = Sugar.ArraysR b
  afunctionRepr = AfunctionReprBody
  convertOpenAfun config alyt (Acc body) = Abody $ convertOpenAcc config alyt body

convertSmartAfun1
    :: HasCallStack
    => Config
    -> ArraysR a
    -> (SmartAcc a -> SmartAcc b)
    -> AST.Afun (a -> b)
convertSmartAfun1 config repr f
  | DeclareVars lhs _ value <- declareVars repr
  = let
      a     = SmartAcc $ Atag repr 0
      alyt' = PushLayout EmptyLayout lhs (value weakenId)
    in
      Alam lhs $ Abody $ convertOpenAcc config alyt' $ f a

-- | Convert an open array expression to de Bruijn form while also incorporating sharing
-- information.
--
convertOpenAcc
    :: HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> SmartAcc arrs
    -> AST.OpenAcc aenv arrs
convertOpenAcc config alyt acc =
  let lvl                      = sizeLayout alyt
      fvs                      = [lvl-1, lvl-2 .. 0]
      (sharingAcc, initialEnv) = recoverSharingAcc config lvl fvs acc
  in
  convertSharingAcc config alyt initialEnv sharingAcc


-- | Convert an array expression with given array environment layout and sharing information into
-- de Bruijn form while recovering sharing at the same time (by introducing appropriate let
-- bindings).  The latter implements the third phase of sharing recovery.
--
-- The sharing environment 'env' keeps track of all currently bound sharing variables, keeping them
-- in reverse chronological order (outermost variable is at the end of the list).
--
convertSharingAcc
    :: forall aenv arrs. HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> [StableSharingAcc]
    -> ScopedAcc arrs
    -> AST.OpenAcc aenv arrs
convertSharingAcc _ alyt aenv (ScopedAcc lams (AvarSharing sa repr))
  | Just i <- findIndex (matchStableAcc sa) aenv'
  = avarsIn AST.OpenAcc
  $ prjIdx (ctxt ++ "; i = " ++ show i) showArraysR matchArraysR repr i alyt
  | null aenv'
  = error $ "Cyclic definition of a value of type 'Acc' (sa = " ++ show (hashStableNameHeight sa) ++ ")"
  | otherwise
  = internalError err
  where
    aenv' = lams ++ aenv
    ctxt = "shared 'Acc' tree with stable name " ++ show (hashStableNameHeight sa)
    err  = "inconsistent valuation @ " ++ ctxt ++ ";\n  aenv = " ++ show aenv'

convertSharingAcc config alyt aenv (ScopedAcc lams (AletSharing sa@(StableSharingAcc (_ :: StableAccName as) boundAcc) bodyAcc))
  = case declareVars $ AST.arraysR bound of
      DeclareVars lhs k value ->
        let
          alyt' = PushLayout (incLayout k alyt) lhs (value weakenId)
        in
          AST.OpenAcc $ AST.Alet
            lhs
            bound
            (convertSharingAcc config alyt' (sa:aenv') bodyAcc)
  where
    aenv' = lams ++ aenv
    bound = convertSharingAcc config alyt aenv' (ScopedAcc [] boundAcc)

convertSharingAcc config alyt aenv (ScopedAcc lams (AccSharing _ preAcc))
  = AST.OpenAcc
  $ let aenv' = lams ++ aenv

        cvtA :: ScopedAcc a -> AST.OpenAcc aenv a
        cvtA = convertSharingAcc config alyt aenv'

        cvtE :: ScopedExp t -> AST.Exp aenv t
        cvtE = convertSharingExp config EmptyLayout alyt [] aenv'

        cvtF1 :: TypeR a -> (SmartExp a -> ScopedExp b) -> AST.Fun aenv (a -> b)
        cvtF1 = convertSharingFun1 config alyt aenv'

        cvtF2 :: TypeR a -> TypeR b -> (SmartExp a -> SmartExp b -> ScopedExp c) -> AST.Fun aenv (a -> b -> c)
        cvtF2 = convertSharingFun2 config alyt aenv'

        cvtAfun1 :: ArraysR a -> (SmartAcc a -> ScopedAcc b) -> AST.OpenAfun aenv (a -> b)
        cvtAfun1 = convertSharingAfun1 config alyt aenv'

        cvtAprj :: forall a b c. PairIdx (a, b) c -> ScopedAcc (a, b) -> AST.OpenAcc aenv c
        cvtAprj ix a = cvtAprj' ix $ cvtA a

        cvtAprj' :: forall a b c aenv1. PairIdx (a, b) c -> AST.OpenAcc aenv1 (a, b) -> AST.OpenAcc aenv1 c
        cvtAprj' PairIdxLeft  (AST.OpenAcc (AST.Apair a _)) = a
        cvtAprj' PairIdxRight (AST.OpenAcc (AST.Apair _ b)) = b
        cvtAprj' ix a = case declareVars $ AST.arraysR a of
          DeclareVars lhs _ value ->
            AST.OpenAcc $ AST.Alet lhs a $ cvtAprj' ix $ avarsIn AST.OpenAcc $ value weakenId
    in
    case preAcc of

      Atag repr i
        -> let AST.OpenAcc a = avarsIn AST.OpenAcc $ prjIdx ("de Bruijn conversion tag " ++ show i) showArraysR matchArraysR repr i alyt
           in  a

      Pipe reprA reprB reprC (afun1 :: SmartAcc as -> ScopedAcc bs) (afun2 :: SmartAcc bs -> ScopedAcc cs) acc
        | DeclareVars lhs k value <- declareVars reprB ->
          let
            noStableSharing = StableSharingAcc noStableAccName (undefined :: SharingAcc acc exp ())
            boundAcc = AST.Apply reprB (cvtAfun1 reprA afun1) (cvtA acc)
            alyt'   = PushLayout (incLayout k alyt) lhs (value weakenId)
            bodyAcc = AST.Apply reprC
                        (convertSharingAfun1 config alyt' (noStableSharing : aenv') reprB afun2)
                        (avarsIn AST.OpenAcc $ value weakenId)
          in AST.Alet lhs (AST.OpenAcc boundAcc) (AST.OpenAcc bodyAcc)

      Aforeign repr ff afun acc
        -> AST.Aforeign repr ff (convertSmartAfun1 config (Smart.arraysR acc) afun) (cvtA acc)

      Acond b acc1 acc2           -> AST.Acond (cvtE b) (cvtA acc1) (cvtA acc2)
      Awhile reprA pred iter init -> AST.Awhile (cvtAfun1 reprA pred) (cvtAfun1 reprA iter) (cvtA init)
      Anil                        -> AST.Anil
      Apair acc1 acc2             -> AST.Apair (cvtA acc1) (cvtA acc2)
      Aprj ix a                   -> let AST.OpenAcc a' = cvtAprj ix a
                                     in a'
      Use repr array              -> AST.Use repr array
      Unit tp e                   -> AST.Unit tp (cvtE e)
      Generate repr@(ArrayR shr _) sh f
                                  -> AST.Generate repr (cvtE sh) (cvtF1 (shapeType shr) f)
      Reshape shr e acc           -> AST.Reshape shr (cvtE e) (cvtA acc)
      Replicate si ix acc         -> AST.Replicate si (cvtE ix) (cvtA acc)
      Slice si acc ix             -> AST.Slice si (cvtA acc) (cvtE ix)
      Map t1 t2 f acc             -> AST.Map t2 (cvtF1 t1 f) (cvtA acc)
      ZipWith t1 t2 t3 f acc1 acc2
                                  -> AST.ZipWith t3 (cvtF2 t1 t2 f) (cvtA acc1) (cvtA acc2)
      Fold tp f e acc             -> AST.Fold (cvtF2 tp tp f) (cvtE <$> e) (cvtA acc)
      FoldSeg i tp f e acc1 acc2  -> AST.FoldSeg i (cvtF2 tp tp f) (cvtE <$> e) (cvtA acc1) (cvtA acc2)
      Scan  d tp f e acc          -> AST.Scan  d (cvtF2 tp tp f) (cvtE <$> e) (cvtA acc)
      Scan' d tp f e acc          -> AST.Scan' d (cvtF2 tp tp f) (cvtE e)     (cvtA acc)
      Permute (ArrayR shr tp) f dftAcc perm acc
                                  -> AST.Permute (cvtF2 tp tp f) (cvtA dftAcc) (cvtF1 (shapeType shr) perm) (cvtA acc)
      Backpermute shr newDim perm acc
                                  -> AST.Backpermute shr (cvtE newDim) (cvtF1 (shapeType shr) perm) (cvtA acc)
      Stencil stencil tp f boundary acc
        -> AST.Stencil stencil
                       tp
                       (convertSharingStencilFun1 config alyt aenv' stencil f)
                       (convertSharingBoundary config alyt aenv' (stencilShapeR stencil) boundary)
                       (cvtA acc)
      Stencil2 stencil1 stencil2 tp f bndy1 acc1 bndy2 acc2
        | shr <- stencilShapeR stencil1
        -> AST.Stencil2 stencil1
                        stencil2
                        tp
                        (convertSharingStencilFun2 config alyt aenv' stencil1 stencil2 f)
                        (convertSharingBoundary config alyt aenv' shr bndy1)
                        (cvtA acc1)
                        (convertSharingBoundary config alyt aenv' shr bndy2)
                        (cvtA acc2)
      -- Collect seq -> AST.Collect (convertSharingSeq config alyt EmptyLayout aenv' [] seq)

{--
-- Sequence expressions
-- --------------------

-- | Convert a closed sequence expression to de Bruijn form while incorporating
-- sharing information.
--
convertSeq
    :: Typeable s
    => Bool             -- ^ recover sharing of array computations ?
    -> Bool             -- ^ recover sharing of scalar expressions ?
    -> Bool             -- ^ recover sharing of sequence computations ?
    -> Bool             -- ^ always float array computations out of expressions?
    -> Seq s            -- ^ computation to be converted
    -> AST.Seq s
convertSeq shareAcc shareExp shareSeq floatAcc seq
  = let config = Config shareAcc shareExp shareSeq floatAcc
        (sharingSeq, initialEnv) = recoverSharingSeq config seq
    in
    convertSharingSeq config EmptyLayout EmptyLayout [] initialEnv sharingSeq

convertSharingSeq
    :: forall aenv senv arrs.
       Config
    -> Layout aenv aenv
    -> Layout senv senv
    -> [StableSharingAcc]
    -> [StableSharingSeq]
    -> ScopedSeq arrs
    -> AST.PreOpenSeq AST.OpenAcc aenv senv arrs
convertSharingSeq _ _ slyt _ senv (ScopedSeq (SvarSharing sn))
  | Just i <- findIndex (matchStableSeq sn) senv
  = AST.Reify $ prjIdx (ctxt ++ "; i = " ++ show i) i slyt
  | null senv
  = error $ "Cyclic definition of a value of type 'Seq' (sa = " ++
            show (hashStableNameHeight sn) ++ ")"
  | otherwise
  = $internalError "convertSharingSeq" err
  where
    ctxt = "shared 'Seq' tree with stable name " ++ show (hashStableNameHeight sn)
    err  = "inconsistent valuation @ " ++ ctxt ++ ";\n  senv = " ++ show senv
convertSharingSeq config alyt slyt aenv senv (ScopedSeq (SletSharing sa@(StableSharingSeq _ (SeqSharing _ boundSeq)) bodySeq))
  = convSeq boundSeq bodySeq
  where
    convSeq :: forall bnd body.
               PreSeq ScopedAcc ScopedSeq ScopedExp bnd
            -> ScopedSeq body
            -> AST.PreOpenSeq AST.OpenAcc aenv senv body
    convSeq bnd body =
      case bnd of
        StreamIn arrs               -> producer $ AST.StreamIn arrs
        ToSeq slix acc              -> producer $ mkToSeq slix (cvtA acc)
        MapSeq afun x               -> producer $ AST.MapSeq (cvtAF1 afun) (asIdx x)
        ZipWithSeq afun x y         -> producer $ AST.ZipWithSeq (cvtAF2 afun) (asIdx x) (asIdx y)
        ScanSeq fun e x             -> producer $ AST.ScanSeq (cvtF2 fun) (cvtE e) (asIdx x)
        _                           -> $internalError "convertSharingSeq:convSeq" "Consumer appears to have been let bound"
      where
        producer :: Arrays a
                 => AST.Producer AST.OpenAcc aenv senv a
                 -> AST.PreOpenSeq AST.OpenAcc aenv senv body
        producer p = AST.Producer p $ convertSharingSeq config alyt slyt' aenv (sa:senv) body
          where
            slyt' = incLayout slyt `PushLayout` ZeroIdx

        asIdx :: (HasCallStack, Arrays a)
              => ScopedSeq [a]
              -> Idx senv a
        asIdx (ScopedSeq (SvarSharing sn))
          | Just i <- findIndex (matchStableSeq sn) senv
          = prjIdx (ctxt ++ "; i = " ++ show i) i slyt
          | null senv
          = error $ "Cyclic definition of a value of type 'Seq' (sa = " ++
                    show (hashStableNameHeight sn) ++ ")"
          | otherwise
          = $internalError "convertSharingSeq" err
          where
            ctxt = "shared 'Seq' tree with stable name " ++ show (hashStableNameHeight sn)
            err  = "inconsistent valuation @ " ++ ctxt ++ ";\n  senv = " ++ show senv
        asIdx _
          = $internalError "convertSharingSeq:asIdx" "Sequence computation not in A-normal form"

        cvtA :: forall a. Arrays a => ScopedAcc a -> AST.OpenAcc aenv a
        cvtA acc = convertSharingAcc config alyt aenv acc

        cvtE :: forall t. Elt t => ScopedExp t -> AST.Exp aenv t
        cvtE = convertSharingExp config EmptyLayout alyt [] aenv

        cvtF2 :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> ScopedExp c) -> AST.Fun aenv (a -> b -> c)
        cvtF2 = convertSharingFun2 config alyt aenv

        cvtAF1 :: forall a b. (Arrays a, Arrays b) => (Acc a -> ScopedAcc b) -> OpenAfun aenv (a -> b)
        cvtAF1 afun = convertSharingAfun1 config alyt aenv afun

        cvtAF2 :: forall a b c. (Arrays a, Arrays b, Arrays c) => (Acc a -> Acc b -> ScopedAcc c) -> OpenAfun aenv (a -> b -> c)
        cvtAF2 afun = convertSharingAfun2 config alyt aenv afun

convertSharingSeq _ _ _ _ _ (ScopedSeq (SletSharing _ _))
 = $internalError "convertSharingSeq" "Sequence computation not in A-normal form"

convertSharingSeq config alyt slyt aenv senv s
  = cvtC s
  where
    cvtC :: ScopedSeq a -> AST.PreOpenSeq AST.OpenAcc aenv senv a
    cvtC (ScopedSeq (SeqSharing _ s)) =
      case s of
        FoldSeq fun e x                    -> AST.Consumer $ AST.FoldSeq (cvtF2 fun) (cvtE e) (asIdx x)
        FoldSeqFlatten afun acc x          -> AST.Consumer $ AST.FoldSeqFlatten (cvtAF3 afun) (cvtA acc) (asIdx x)
        Stuple t                           -> AST.Consumer $ AST.Stuple (cvtST t)
        _                                  -> $internalError "convertSharingSeq" "Producer has not been let bound"
    cvtC _ = $internalError "convertSharingSeq" "Unreachable"

    asIdx :: Arrays a
          => ScopedSeq [a]
          -> Idx senv a
    asIdx (ScopedSeq (SvarSharing sn))
      | Just i <- findIndex (matchStableSeq sn) senv
      = prjIdx (ctxt ++ "; i = " ++ show i) i slyt
      | null senv
      = error $ "Cyclic definition of a value of type 'Seq' (sa = " ++
                show (hashStableNameHeight sn) ++ ")"
      | otherwise
      = $internalError "convertSharingSeq" err
      where
        ctxt = "shared 'Seq' tree with stable name " ++ show (hashStableNameHeight sn)
        err  = "inconsistent valuation @ " ++ ctxt ++ ";\n  senv = " ++ show senv
    asIdx _
      = $internalError "convertSharingSeq:asIdx" "Sequence computation not in A-normal form"

    cvtA :: forall a. Arrays a => ScopedAcc a -> AST.OpenAcc aenv a
    cvtA acc = convertSharingAcc config alyt aenv acc

    cvtE :: forall t. Elt t => ScopedExp t -> AST.Exp aenv t
    cvtE = convertSharingExp config EmptyLayout alyt [] aenv

    cvtF2 :: (Elt a, Elt b, Elt c) => (Exp a -> Exp b -> ScopedExp c) -> AST.Fun aenv (a -> b -> c)
    cvtF2 = convertSharingFun2 config alyt aenv

    cvtAF3 :: forall a b c d. (Arrays a, Arrays b, Arrays c, Arrays d) => (Acc a -> Acc b -> Acc c -> ScopedAcc d) -> OpenAfun aenv (a -> b -> c -> d)
    cvtAF3 afun = convertSharingAfun3 config alyt aenv afun

    cvtST :: Atuple ScopedSeq t -> Atuple (AST.Consumer AST.OpenAcc aenv senv) t
    cvtST NilAtup        = NilAtup
    cvtST (SnocAtup t c) | AST.Consumer c' <- cvtC c
                         = SnocAtup (cvtST t) c'
                         | otherwise
                         = $internalError "convertSharingSeq" "Unreachable"
--}

convertSharingAfun1
    :: forall aenv a b. HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> [StableSharingAcc]
    -> ArraysR a
    -> (SmartAcc a -> ScopedAcc b)
    -> OpenAfun aenv (a -> b)
convertSharingAfun1 config alyt aenv reprA f
  | DeclareVars lhs k value <- declareVars reprA
  = let
      alyt' = PushLayout (incLayout k alyt) lhs (value weakenId)
      body = f undefined
    in
      Alam lhs (Abody (convertSharingAcc config alyt' aenv body))

-- | Convert a boundary condition
--
convertSharingBoundary
    :: forall aenv sh e. HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> [StableSharingAcc]
    -> ShapeR sh
    -> PreBoundary ScopedAcc ScopedExp (Array sh e)
    -> AST.Boundary aenv (Array sh e)
convertSharingBoundary config alyt aenv shr = cvt
  where
    cvt :: PreBoundary ScopedAcc ScopedExp (Array sh e) -> AST.Boundary aenv (Array sh e)
    cvt bndy =
      case bndy of
        Clamp       -> AST.Clamp
        Mirror      -> AST.Mirror
        Wrap        -> AST.Wrap
        Constant v  -> AST.Constant v
        Function f  -> AST.Function $ convertSharingFun1 config alyt aenv (shapeType shr) f


-- mkToSeq :: forall slsix slix e aenv senv. (Division slsix, DivisionSlice slsix ~ slix, Elt e, Elt slix, Slice slix)
--         => slsix
--         -> AST.OpenAcc              aenv (Array (FullShape  slix) e)
--         -> AST.Producer AST.OpenAcc aenv senv (Array (SliceShape slix) e)
-- mkToSeq _ = AST.ToSeq (sliceIndex slix) (Proxy :: Proxy slix)
--   where
--     slix = undefined :: slix


-- Scalar functions
-- ----------------

-- | Convert a closed scalar function to de Bruijn form while incorporating
-- sharing information.
--
-- The current design requires all free variables to be bound at the outermost
-- level --- we have no general apply term, and so lambdas are always outermost.
-- In higher-order abstract syntax, this represents an n-ary, polyvariadic
-- function.
--
convertFun :: (HasCallStack, Function f) => f -> AST.Fun () (EltFunctionR f)
convertFun
  = convertFunWith
  $ defaultOptions { options = options defaultOptions \\ [seq_sharing, acc_sharing] }

convertFunWith :: (HasCallStack, Function f) => Config -> f -> AST.Fun () (EltFunctionR f)
convertFunWith config = convertOpenFun config EmptyLayout

data FunctionRepr f r reprr where
  FunctionReprBody
    :: Elt b => FunctionRepr (Exp b) b (EltR b)

  FunctionReprLam
    :: Elt a
    => FunctionRepr b br breprr
    -> FunctionRepr (Exp a -> b) (a -> br) (EltR a -> breprr)

class Function f where
  type FunctionR f
  type EltFunctionR f

  functionRepr   :: HasCallStack => FunctionRepr f (FunctionR f) (EltFunctionR f)
  convertOpenFun :: HasCallStack => Config -> ELayout env env -> f -> AST.OpenFun env () (EltFunctionR f)

instance (Elt a, Function r) => Function (Exp a -> r) where
  type FunctionR (Exp a -> r) = a -> FunctionR r
  type EltFunctionR (Exp a -> r) = EltR a -> EltFunctionR r

  functionRepr = FunctionReprLam $ functionRepr @r
  convertOpenFun config lyt f
    | tp <- eltR @a
    , DeclareVars lhs k value <- declareVars tp
    = let
        e    = Exp $ SmartExp $ Tag tp $ sizeLayout lyt
        lyt' = PushLayout (incLayout k lyt) lhs (value weakenId)
      in
        Lam lhs $ convertOpenFun config lyt' $ f e

instance Elt b => Function (Exp b) where
  type FunctionR (Exp b) = b
  type EltFunctionR (Exp b) = EltR b

  functionRepr = FunctionReprBody
  convertOpenFun config lyt (Exp body) = Body $ convertOpenExp config lyt body

convertSmartFun
    :: HasCallStack
    => Config
    -> TypeR a
    -> (SmartExp a -> SmartExp b)
    -> AST.Fun () (a -> b)
convertSmartFun config tp f
  | DeclareVars lhs _ value <- declareVars tp
  = let
      e    = SmartExp $ Tag tp 0
      lyt' = PushLayout EmptyLayout lhs (value weakenId)
    in
      Lam lhs $ Body $ convertOpenExp config lyt' $ f e

-- Scalar expressions
-- ------------------

-- | Convert a closed scalar expression to de Bruijn form while incorporating
-- sharing information.
--
convertExp
    :: HasCallStack
    => Exp e
    -> AST.Exp () (EltR e)
convertExp
  = convertExpWith
  $ defaultOptions { options = options defaultOptions \\ [seq_sharing, acc_sharing] }

convertExpWith
      :: HasCallStack
      => Config
      -> Exp e
      -> AST.Exp () (EltR e)
convertExpWith config (Exp e) = convertOpenExp config EmptyLayout e

convertOpenExp
    :: HasCallStack
    => Config
    -> ELayout env env
    -> SmartExp e
    -> AST.OpenExp env () e
convertOpenExp config lyt exp =
  let lvl                      = sizeLayout lyt
      fvs                      = [lvl-1, lvl-2 .. 0]
      (sharingExp, initialEnv) = recoverSharingExp config lvl fvs exp
  in
  convertSharingExp config lyt EmptyLayout initialEnv [] sharingExp


-- | Convert an open expression with given environment layouts and sharing information into
-- de Bruijn form while recovering sharing at the same time (by introducing appropriate let
-- bindings).  The latter implements the third phase of sharing recovery.
--
-- The sharing environments 'env' and 'aenv' keep track of all currently bound sharing variables,
-- keeping them in reverse chronological order (outermost variable is at the end of the list).
--
convertSharingExp
    :: forall t env aenv. HasCallStack
    => Config
    -> ELayout env env          -- scalar environment
    -> ArrayLayout aenv aenv    -- array environment
    -> [StableSharingExp]       -- currently bound sharing variables of expressions
    -> [StableSharingAcc]       -- currently bound sharing variables of array computations
    -> ScopedExp t              -- expression to be converted
    -> AST.OpenExp env aenv t
convertSharingExp config lyt alyt env aenv exp@(ScopedExp lams _) = cvt exp
  where
    -- scalar environment with any lambda bound variables this expression is rooted in
    env' = lams ++ env

    cvt :: HasCallStack => ScopedExp t' -> AST.OpenExp env aenv t'
    cvt (ScopedExp _ (VarSharing se tp))
      | Just i <- findIndex (matchStableExp se) env' = expVars (prjIdx (ctx i) shows matchTypeR tp i lyt)
      | otherwise                                    = internalError msg
      where
        ctx i = printf "shared 'Exp' tree with stable name %d; i=%d" (hashStableNameHeight se) i
        msg   = unlines
          [ if null env'
               then printf "cyclic definition of a value of type 'Exp' (sa=%d)" (hashStableNameHeight se)
               else printf "inconsistent valuation at shared 'Exp' tree (sa=%d; env=%s)" (hashStableNameHeight se) (show env')
          , ""
          , "Note that this error usually arises due to the presence of nested data"
          , "parallelism; when a parallel computation attempts to initiate new parallel"
          , "work _which depends on_ a scalar variable given by the first computation."
          , ""
          , "For example, suppose we wish to sum the columns of a two-dimensional array."
          , "You might think to do this in the following (incorrect) way: by constructing"
          , "a vector using 'generate' where at each index we 'slice' out the"
          , "corresponding column of the matrix and 'sum' it:"
          , ""
          , "> sum_columns_ndp :: Num a => Acc (Matrix a) -> Acc (Vector a)"
          , "> sum_columns_ndp mat ="
          , ">   let I2 rows cols = shape mat"
          , ">   in  generate (I1 cols)"
          , ">                (\\(I1 col) -> the $ sum (slice mat (lift (Z :. All :. col))))"
          , ""
          , "However, since both 'generate' and 'slice' are data-parallel operators, and"
          , "moreover that 'slice' _depends on_ the argument 'col' given to it by the"
          , "'generate' function, this operation requires nested parallelism and is thus"
          , "not (at this time) permitted. The clue that this definition is invalid is"
          , "that in order to create a program which will be accepted by the type checker,"
          , "we had to use the function 'the' to retrieve the result of the parallel"
          , "'sum', effectively concealing that this is a collective operation in order to"
          , "match the type expected by 'generate'."
          , ""
          , "To solve this particular example, we can make use of the fact that (most)"
          , "collective operations in Accelerate are _rank polymorphic_. The 'sum'"
          , "operation reduces along the innermost dimension of an array of arbitrary"
          , "rank, reducing the dimensionality of the array by one. To reduce the array"
          , "column-wise then, we first need to simply 'transpose' the array:"
          , ""
          , "> sum_columns :: Num a => Acc (Matrix a) -> Acc (Vector a)"
          , "> sum_columns = sum . transpose"
          , ""
          , "If you feel like this is not the cause of your error, or you would like some"
          , "advice locating the problem and perhaps with a workaround, feel free to"
          , "submit an issue at the above URL."
          ]

    cvt (ScopedExp _ (LetSharing se@(StableSharingExp _ boundExp) bodyExp))
      | DeclareVars lhs k value <- declareVars $ typeR boundExp
      = let
          lyt' = PushLayout (incLayout k lyt) lhs (value weakenId)
        in
          AST.Let lhs (cvt (ScopedExp [] boundExp)) (convertSharingExp config lyt' alyt (se:env') aenv bodyExp)
    cvt (ScopedExp _ (ExpSharing _ pexp))
      = case pexp of
          Tag tp i              -> expVars $ prjIdx ("de Bruijn conversion tag " ++ show i) shows matchTypeR tp i lyt
          Match _ e             -> cvt e  -- XXX: this should probably be an error
          Const tp v            -> AST.Const tp v
          Undef tp              -> AST.Undef tp
          Prj idx e             -> cvtPrj idx (cvt e)
          Nil                   -> AST.Nil
          Pair e1 e2            -> AST.Pair (cvt e1) (cvt e2)
          VecPack   vec e       -> AST.VecPack   vec (cvt e)
          VecUnpack vec e       -> AST.VecUnpack vec (cvt e)
          ToIndex shr sh ix     -> AST.ToIndex shr (cvt sh) (cvt ix)
          FromIndex shr sh e    -> AST.FromIndex shr (cvt sh) (cvt e)
          Case e rhs            -> cvtCase (cvt e) (over (mapped . _2) cvt rhs)
          Cond e1 e2 e3         -> AST.Cond (cvt e1) (cvt e2) (cvt e3)
          While tp p it i       -> AST.While (cvtFun1 tp p) (cvtFun1 tp it) (cvt i)
          PrimConst c           -> AST.PrimConst c
          PrimApp f e           -> cvtPrimFun f (cvt e)
          Index _ a e           -> AST.Index (cvtAvar a) (cvt e)
          LinearIndex _ a i     -> AST.LinearIndex (cvtAvar a) (cvt i)
          Shape _ a             -> AST.Shape (cvtAvar a)
          ShapeSize shr e       -> AST.ShapeSize shr (cvt e)
          Foreign repr ff f e   -> AST.Foreign repr ff (convertSmartFun config (typeR e) f) (cvt e)
          Coerce t1 t2 e        -> AST.Coerce t1 t2 (cvt e)

    cvtPrj :: forall a b c env1 aenv1. PairIdx (a, b) c -> AST.OpenExp env1 aenv1 (a, b) -> AST.OpenExp env1 aenv1 c
    cvtPrj PairIdxLeft  (AST.Pair a _) = a
    cvtPrj PairIdxRight (AST.Pair _ b) = b
    cvtPrj ix a
      | DeclareVars lhs _ value <- declareVars $ AST.expType a
      = AST.Let lhs a (cvtPrj ix (expVars (value weakenId)))

    cvtA :: HasCallStack => ScopedAcc a -> AST.OpenAcc aenv a
    cvtA = convertSharingAcc config alyt aenv

    cvtAvar :: HasCallStack => ScopedAcc a -> AST.ArrayVar aenv a
    cvtAvar a = case cvtA a of
      AST.OpenAcc (AST.Avar var) -> var
      _                          -> internalError "Expected array computation in expression to be floated out"

    cvtFun1 :: HasCallStack => TypeR a -> (SmartExp a -> ScopedExp b) -> AST.OpenFun env aenv (a -> b)
    cvtFun1 tp f
      | DeclareVars lhs k value <- declareVars tp
      = let
          lyt' = PushLayout (incLayout k lyt) lhs (value weakenId)
          body = f undefined
        in
          Lam lhs $ Body $ convertSharingExp config lyt' alyt env' aenv body

    -- Push primitive function applications down through let bindings so that
    -- they are adjacent to their arguments. It looks a bit nicer this way.
    --
    cvtPrimFun :: HasCallStack => AST.PrimFun (a -> r) -> AST.OpenExp env' aenv' a -> AST.OpenExp env' aenv' r
    cvtPrimFun f e = case e of
      AST.Let lhs bnd body -> AST.Let lhs bnd (cvtPrimFun f body)
      x                    -> AST.PrimApp f x

    -- Convert the flat list of equations into nested case statement
    -- directly on the tag variables.
    --
    cvtCase :: HasCallStack => AST.OpenExp env' aenv' a -> [(TagR a, AST.OpenExp env' aenv' b)] -> AST.OpenExp env' aenv' b
    cvtCase s es
      | AST.Pair{} <- s
      = nested s es
      | DeclareVars lhs _ value <- declareVars (AST.expType s)
      = AST.Let lhs s $ nested (expVars (value weakenId)) (over (mapped . _2) (weakenE (weakenWithLHS lhs)) es)
      where
        nested :: HasCallStack => AST.OpenExp env' aenv' a -> [(TagR a, AST.OpenExp env' aenv' b)] -> AST.OpenExp env' aenv' b
        nested _ [(_,r)] = r
        nested s rs      =
          let groups = groupBy (eqT `on` fst) rs
              tags   = map (firstT . fst . head) groups
              e      = prjT (fst (head rs)) s
              rhs    = map (nested s . map (over _1 ignore)) groups
          in
          AST.Case e (zip tags rhs) Nothing

        -- Extract the variable representing this particular tag from the
        -- scrutinee. This is safe because we let-bind the argument first.
        prjT :: TagR a -> AST.OpenExp env' aenv' a -> AST.OpenExp env' aenv' TAG
        prjT = fromJust $$ go
          where
            go :: TagR a -> AST.OpenExp env' aenv' a -> Maybe (AST.OpenExp env' aenv' TAG)
            go TagRtag{}        (AST.Pair l _) = Just l
            go (TagRpair ta tb) (AST.Pair l r) =
              case go ta l of
                Just t  -> Just t
                Nothing -> go tb r
            go _ _ = Nothing

        -- Equality up to the first constructor tag encountered
        eqT :: TagR a -> TagR a -> Bool
        eqT a b = snd $ go a b
          where
            go :: TagR a -> TagR a -> (Any, Bool)
            go TagRunit          TagRunit          = no True
            go TagRsingle{}      TagRsingle{}      = no True
            go TagRundef{}       TagRundef{}       = no True
            go (TagRtag v1 _)    (TagRtag v2 _)    = yes (v1 == v2)
            go (TagRpair a1 b1)  (TagRpair a2 b2)  =
              let (Any r, s) = go a1 a2
               in case r of
                    True  -> yes s
                    False -> go b1 b2
            go _ _ = no False

        firstT :: TagR a -> TAG
        firstT = fromJust . go
          where
            go :: TagR a -> Maybe TAG
            go (TagRtag v _)  = Just v
            go (TagRpair a b) =
              case go a of
                Just t  -> Just t
                Nothing -> go b
            go _ = Nothing

        -- Replace the first constructor tag encountered with a regular
        -- scalar tag, so that that tag will be ignored in the recursive
        -- case.
        ignore = snd . go
          where
            go :: TagR a -> (Any, TagR a)
            go TagRunit         = no  $ TagRunit
            go (TagRsingle t)   = no  $ TagRsingle t
            go (TagRundef t)    = no  $ TagRundef t
            go (TagRtag _ a)    = yes $ TagRpair (TagRundef scalarType) a
            go (TagRpair a1 a2) =
              let (Any r, a1') = go a1
               in case r of
                    True  -> yes $ TagRpair a1' a2
                    False -> TagRpair a1' <$> go a2

        yes :: x -> (Any, x)
        yes e = (Any True, e)

        no :: x -> (Any, x)
        no = pure


-- | Convert a unary functions
--
convertSharingFun1
    :: HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> [StableSharingAcc]       -- currently bound array sharing-variables
    -> TypeR a
    -> (SmartExp a -> ScopedExp b)
    -> AST.Fun aenv (a -> b)
convertSharingFun1 config alyt aenv tp f
  | DeclareVars lhs _ value <- declareVars tp
  = let
      a               = SmartExp undefined             -- the 'tag' was already embedded in Phase 1
      lyt             = PushLayout EmptyLayout lhs (value weakenId)
      openF           = convertSharingExp config lyt alyt [] aenv (f a)
    in
      Lam lhs (Body openF)

-- | Convert a binary functions
--
convertSharingFun2
    :: HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> [StableSharingAcc]       -- currently bound array sharing-variables
    -> TypeR a
    -> TypeR b
    -> (SmartExp a -> SmartExp b -> ScopedExp c)
    -> AST.Fun aenv (a -> b -> c)
convertSharingFun2 config alyt aenv ta tb f
  | DeclareVars lhs1 _  value1 <- declareVars ta
  , DeclareVars lhs2 k2 value2 <- declareVars tb
  = let
      a               = SmartExp undefined
      b               = SmartExp undefined
      lyt1            = PushLayout EmptyLayout lhs1 (value1 k2)
      lyt2            = PushLayout lyt1        lhs2 (value2 weakenId)
      openF           = convertSharingExp config lyt2 alyt [] aenv (f a b)
    in
      Lam lhs1 $ Lam lhs2 $ Body openF

-- | Convert a unary stencil function
--
convertSharingStencilFun1
    :: HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> [StableSharingAcc]               -- currently bound array sharing-variables
    -> R.StencilR sh a stencil
    -> (SmartExp stencil -> ScopedExp b)
    -> AST.Fun aenv (stencil -> b)
convertSharingStencilFun1 config alyt aenv sR1 stencil =
  convertSharingFun1 config alyt aenv (R.stencilR sR1) stencil

-- | Convert a binary stencil function
--
convertSharingStencilFun2
    :: HasCallStack
    => Config
    -> ArrayLayout aenv aenv
    -> [StableSharingAcc]               -- currently bound array sharing-variables
    -> R.StencilR sh a stencil1
    -> R.StencilR sh b stencil2
    -> (SmartExp stencil1 -> SmartExp stencil2 -> ScopedExp c)
    -> AST.Fun aenv (stencil1 -> stencil2 -> c)
convertSharingStencilFun2 config alyt aenv sR1 sR2 stencil =
  convertSharingFun2 config alyt aenv (R.stencilR sR1) (R.stencilR sR2) stencil


-- Sharing recovery
-- ================

-- Sharing recovery proceeds in two phases:
--
-- /Phase One: build the occurrence map/
--
-- This is a top-down traversal of the AST that computes a map from AST nodes to the number of
-- occurrences of that AST node in the overall Accelerate program.  An occurrences count of two or
-- more indicates sharing.
--
-- IMPORTANT: To avoid unfolding the sharing, we do not descent into subtrees that we have
--   previously encountered.  Hence, the complexity is proportional to the number of nodes in the
--   tree /with/ sharing.  Consequently, the occurrence count is that in the tree with sharing
--   as well.
--
-- During computation of the occurrences, the tree is annotated with stable names on every node
-- using 'AccSharing' constructors and all but the first occurrence of shared subtrees are pruned
-- using 'AvarSharing' constructors (see 'SharingAcc' below).  This phase is impure as it is based
-- on stable names.
--
-- We use a hash table (instead of 'Data.Map') as computing stable names forces us to live in IO
-- anyway.  Once, the computation of occurrence counts is complete, we freeze the hash table into
-- a 'Data.Map'.
--
-- (Implemented by 'makeOccMap*'.)
--
-- /Phase Two: determine scopes and inject sharing information/
--
-- This is a bottom-up traversal that determines the scope for every binding to be introduced
-- to share a subterm.  It uses the occurrence map to determine, for every shared subtree, the
-- lowest AST node at which the binding for that shared subtree can be placed (using a
-- 'AletSharing' constructor)— it's the meet of all the shared subtree occurrences.
--
-- The second phase is also replacing the first occurrence of each shared subtree with a
-- 'AvarSharing' node and floats the shared subtree up to its binding point.
--
--  (Implemented by 'determineScopes*'.)
--
-- /Sharing recovery for expressions/
--
-- We recover sharing for each expression (including function bodies) independently of any other
-- expression — i.e., we cannot share scalar expressions across array computations.  Hence, during
-- Phase One, we mark all scalar expression nodes with a stable name and compute one occurrence map
-- for every scalar expression (including functions) that occurs in an array computation.  These
-- occurrence maps are added to the root of scalar expressions using 'RootExp'.
--
-- NB: We do not need to worry sharing recovery will try to float a shared subexpression past a
--     binder that occurs in that subexpression.  Why?  Otherwise, the binder would already occur
--     out of scope in the original source program.
--
-- /Lambda bound variables/
--
-- During sharing recovery, lambda bound variables appear in the form of 'Atag' and 'Tag' data
-- constructors.  The tag values are determined during Phase One of sharing recovery by computing
-- the /level/ of each variable at its binding occurrence.  The level at the root of the AST is 0
-- and increases by one with each lambda on each path through the AST.

-- Stable names
-- ------------

-- Opaque stable name for AST nodes — used to key the occurrence map.
--
data StableASTName c where
  StableASTName :: StableName (c t) -> StableASTName c

instance Show (StableASTName c) where
  show (StableASTName sn) = show $ hashStableName sn

instance Eq (StableASTName c) where
  StableASTName sn1 == StableASTName sn2 = eqStableName sn1 sn2

instance Hashable (StableASTName c) where
  hashWithSalt s (StableASTName sn) = hashWithSalt s sn

makeStableAST :: c t -> IO (StableName (c t))
makeStableAST e = e `seq` makeStableName e

-- Stable name for an AST node including the height of the AST representing the array computation.
--
data StableNameHeight t = StableNameHeight (StableName t) Int

instance Eq (StableNameHeight t) where
  (StableNameHeight sn1 _) == (StableNameHeight sn2 _) = eqStableName sn1 sn2

higherSNH :: StableNameHeight t1 -> StableNameHeight t2 -> Bool
StableNameHeight _ h1 `higherSNH` StableNameHeight _ h2 = h1 > h2

hashStableNameHeight :: StableNameHeight t -> Int
hashStableNameHeight (StableNameHeight sn _) = hashStableName sn

-- Mutable occurrence map
-- ----------------------

-- Hash table keyed on the stable names of array computations.
--
type HashTable key val = Hash.BasicHashTable key val
type ASTHashTable c v  = HashTable (StableASTName c) v

-- Mutable hashtable version of the occurrence map, which associates each AST node with an
-- occurrence count and the height of the AST.
--
type OccMapHash c = ASTHashTable c (Int, Int)

-- Create a new hash table keyed on AST nodes.
--
newASTHashTable :: IO (ASTHashTable c v)
newASTHashTable = Hash.new

-- Enter one AST node occurrence into an occurrence map.  Returns 'Just h' if this is a repeated
-- occurrence and the height of the repeatedly occurring AST is 'h'.
--
-- If this is the first occurrence, the 'height' *argument* must provide the height of the AST;
-- otherwise, the height will be *extracted* from the occurrence map.  In the latter case, this
-- function yields the AST height.
--
enterOcc :: OccMapHash c -> StableASTName c -> Int -> IO (Maybe Int)
enterOcc occMap sa height
  = Hash.mutate occMap sa
  $ \case
      Nothing           -> (Just (1,   height),  Nothing)
      Just (n, heightS) -> (Just (n+1, heightS), Just heightS)


-- Immutable occurrence map
-- ------------------------

-- Immutable version of the occurrence map (storing the occurrence count only, not the height).  We
-- use the 'StableName' hash to index an 'IntMap' and disambiguate 'StableName's with identical
-- hashes explicitly, storing them in a list in the 'IntMap'.
--
type OccMap c = IntMap.IntMap [(StableASTName c, Int)]

-- Turn a mutable into an immutable occurrence map.
--
freezeOccMap :: OccMapHash c -> IO (OccMap c)
freezeOccMap oc
  = do
      ocl <- Hash.toList oc
      traceChunk "OccMap" (show ocl)

      return . IntMap.fromList
             . map (\kvs -> (key (head kvs), kvs))
             . groupBy sameKey
             . map dropHeight
             $ ocl
  where
    key (StableASTName sn, _) = hashStableName sn
    sameKey kv1 kv2           = key kv1 == key kv2
    dropHeight (k, (cnt, _))  = (k, cnt)

-- Look up the occurrence map keyed by array computations using a stable name.  If the key does
-- not exist in the map, return an occurrence count of '1'.
--
lookupWithASTName :: OccMap c -> StableASTName c -> Int
lookupWithASTName oc sa@(StableASTName sn)
  = fromMaybe 1 $ IntMap.lookup (hashStableName sn) oc >>= Prelude.lookup sa

-- Look up the occurrence map keyed by array computations using a sharing array computation.  If an
-- the key does not exist in the map, return an occurrence count of '1'.
--
lookupWithSharingAcc :: OccMap SmartAcc -> StableSharingAcc -> Int
lookupWithSharingAcc oc (StableSharingAcc (StableNameHeight sn _) _)
  = lookupWithASTName oc (StableASTName sn)

-- Look up the occurrence map keyed by scalar expressions using a sharing expression.  If an
-- the key does not exist in the map, return an occurrence count of '1'.
--
lookupWithSharingExp :: OccMap SmartExp -> StableSharingExp -> Int
lookupWithSharingExp oc (StableSharingExp (StableNameHeight sn _) _)
  = lookupWithASTName oc (StableASTName sn)


-- Stable 'SmartAcc' nodes
-- ------------------

-- Stable name for 'SmartAcc' nodes including the height of the AST.
--
type StableAccName t = StableNameHeight (SmartAcc t)

-- Interleave sharing annotations into an array computation AST.  Subtrees can be marked as being
-- represented by variable (binding a shared subtree) using 'AvarSharing' and as being prefixed by
-- a let binding (for a shared subtree) using 'AletSharing'.
--
data SharingAcc acc exp arrs where
  AvarSharing :: StableAccName arrs -> ArraysR arrs             -> SharingAcc acc exp arrs
  AletSharing :: StableSharingAcc -> acc arrs                   -> SharingAcc acc exp arrs
  AccSharing  :: StableAccName arrs -> PreSmartAcc acc exp arrs -> SharingAcc acc exp arrs

instance HasArraysR acc => HasArraysR (SharingAcc acc exp) where
  arraysR (AvarSharing _ repr) = repr
  arraysR (AletSharing _ acc)  = Smart.arraysR acc
  arraysR (AccSharing  _ acc)  = Smart.arraysR acc


-- Array expression with sharing but shared values have not been scoped; i.e. no let bindings. If
-- the expression is rooted in a function, the list contains the tags of the variables bound by the
-- immediate surrounding lambdas.
data UnscopedAcc t = UnscopedAcc [Int] (SharingAcc UnscopedAcc RootExp t)

instance HasArraysR UnscopedAcc where
  arraysR (UnscopedAcc _ acc) = Smart.arraysR acc


-- Array expression with sharing. For expressions rooted in functions the list holds a sorted
-- environment corresponding to the variables bound in the immediate surounding lambdas.
data ScopedAcc t = ScopedAcc [StableSharingAcc] (SharingAcc ScopedAcc ScopedExp t)

instance HasArraysR ScopedAcc where
  arraysR (ScopedAcc _ acc) = Smart.arraysR acc


-- Stable name for an array computation associated with its sharing-annotated version.
--
data StableSharingAcc where
  StableSharingAcc :: StableAccName arrs
                   -> SharingAcc ScopedAcc ScopedExp arrs
                   -> StableSharingAcc

instance Show StableSharingAcc where
  show (StableSharingAcc sn _) = show $ hashStableNameHeight sn

instance Eq StableSharingAcc where
  StableSharingAcc (StableNameHeight sn1 _) _ == StableSharingAcc (StableNameHeight sn2 _) _
    = eqStableName sn1 sn2

higherSSA :: StableSharingAcc -> StableSharingAcc -> Bool
StableSharingAcc sn1 _ `higherSSA` StableSharingAcc sn2 _ = sn1 `higherSNH` sn2

-- Test whether the given stable names matches an array computation with sharing.
--
matchStableAcc :: StableAccName arrs -> StableSharingAcc -> Bool
matchStableAcc (StableNameHeight sn1 _) (StableSharingAcc (StableNameHeight sn2 _) _)
  = eqStableName sn1 sn2

-- Dummy entry for environments to be used for unused variables.
--
{-# NOINLINE noStableAccName #-}
noStableAccName :: StableAccName arrs
noStableAccName = unsafePerformIO $ StableNameHeight <$> makeStableName undefined <*> pure 0

-- Stable 'Exp' nodes
-- ------------------

-- Stable name for 'Exp' nodes including the height of the AST.
--
type StableExpName t = StableNameHeight (SmartExp t)

-- Interleave sharing annotations into a scalar expressions AST in the same manner as 'SharingAcc'
-- do for array computations.
--
data SharingExp acc exp t where
  VarSharing :: StableExpName t -> TypeR t               -> SharingExp acc exp t
  LetSharing :: StableSharingExp -> exp t                -> SharingExp acc exp t
  ExpSharing :: StableExpName t -> PreSmartExp acc exp t -> SharingExp acc exp t

instance HasTypeR exp => HasTypeR (SharingExp acc exp) where
  typeR (VarSharing _ tp)  = tp
  typeR (LetSharing _ exp) = Smart.typeR exp
  typeR (ExpSharing _ exp) = Smart.typeR exp

-- Specifies a scalar expression AST with sharing annotations but no scoping; i.e. no LetSharing
-- constructors. If the expression is rooted in a function, the list contains the tags of the
-- variables bound by the immediate surrounding lambdas.
data UnscopedExp t = UnscopedExp [Int] (SharingExp UnscopedAcc UnscopedExp t)

instance HasTypeR UnscopedExp where
  typeR (UnscopedExp _ exp) = Smart.typeR exp

-- Specifies a scalar expression AST with sharing. For expressions rooted in functions the list
-- holds a sorted environment corresponding to the variables bound in the immediate surounding
-- lambdas.
data ScopedExp t = ScopedExp [StableSharingExp] (SharingExp ScopedAcc ScopedExp t)

instance HasTypeR ScopedExp where
  typeR (ScopedExp _ exp) = Smart.typeR exp

-- Expressions rooted in 'SmartAcc' computations.
--
-- * When counting occurrences, the root of every expression embedded in an 'SmartAcc' is annotated by
--   an occurrence map for that one expression (excluding any subterms that are rooted in embedded
--   'SmartAcc's.)
--
data RootExp t = RootExp (OccMap SmartExp) (UnscopedExp t)

-- Stable name for an expression associated with its sharing-annotated version.
--
data StableSharingExp where
  StableSharingExp :: StableExpName t -> SharingExp ScopedAcc ScopedExp t -> StableSharingExp

instance Show StableSharingExp where
  show (StableSharingExp sn _) = show $ hashStableNameHeight sn

instance Eq StableSharingExp where
  StableSharingExp (StableNameHeight sn1 _) _ == StableSharingExp (StableNameHeight sn2 _) _ =
    eqStableName sn1 sn2

higherSSE :: StableSharingExp -> StableSharingExp -> Bool
StableSharingExp sn1 _ `higherSSE` StableSharingExp sn2 _ = sn1 `higherSNH` sn2

-- Test whether the given stable names matches an expression with sharing.
--
matchStableExp :: StableExpName t -> StableSharingExp -> Bool
matchStableExp (StableNameHeight sn1 _) (StableSharingExp (StableNameHeight sn2 _) _) = eqStableName sn1 sn2

-- Dummy entry for environments to be used for unused variables.
--
{-# NOINLINE noStableExpName #-}
noStableExpName :: StableExpName t
noStableExpName = unsafePerformIO $ StableNameHeight <$> makeStableName undefined <*> pure 0


{--
-- Stable 'Seq' nodes
-- ------------------

-- Stable name for 'Seq' nodes including the height of the AST.
--
type StableSeqName arrs = StableNameHeight (Seq arrs)

-- Interleave sharing annotations into an sequence computation AST in the same manner as SharingAcc
-- and SharingExp
--
data SharingSeq acc seq exp arrs where
  SvarSharing :: (Typeable arrs, Arrays arrs)
              => StableSeqName [arrs]                       -> SharingSeq acc seq exp [arrs]
  SletSharing :: StableSharingSeq -> seq t                  -> SharingSeq acc seq exp t
  SeqSharing  :: Typeable arrs
              => StableSeqName arrs -> PreSeq acc seq exp arrs -> SharingSeq acc seq exp arrs

-- Array expression with sharing but shared values have not been scoped; i.e. no let bindings. If
-- the expression is rooted in a function, the list contains the tags of the variables bound by the
-- immediate surrounding lambdas.
data UnscopedSeq t = UnscopedSeq (SharingSeq UnscopedAcc UnscopedSeq RootExp t)

-- Array expression with sharing. For expressions rooted in functions the list holds a sorted
-- environment corresponding to the variables bound in the immediate surounding lambdas.
data ScopedSeq t = ScopedSeq (SharingSeq ScopedAcc ScopedSeq ScopedExp t)

-- Sequences rooted in 'Acc' computations.
--
-- * When counting occurrences, the root of every sequence embedded in an 'Acc' is annotated by
--   an occurrence map for that one expression (excluding any subterms that are rooted in embedded
--   'Acc's.)
--
data RootSeq t = RootSeq (OccMap Seq) (UnscopedSeq t)

-- Stable name for an array computation associated with its sharing-annotated version.
--
data StableSharingSeq where
  StableSharingSeq :: Typeable arrs
                   => StableSeqName arrs
                   -> SharingSeq ScopedAcc ScopedSeq ScopedExp arrs
                   -> StableSharingSeq

instance Show StableSharingSeq where
  show (StableSharingSeq sn _) = show $ hashStableNameHeight sn

instance Eq StableSharingSeq where
  StableSharingSeq sn1 _ == StableSharingSeq sn2 _
    | Just sn1' <- gcast sn1 = sn1' == sn2
    | otherwise              = False

higherSSS :: StableSharingSeq -> StableSharingSeq -> Bool
StableSharingSeq sn1 _ `higherSSS` StableSharingSeq sn2 _ = sn1 `higherSNH` sn2

-- Test whether the given stable names matches an array computation with sharing.
--
matchStableSeq :: Typeable arrs => StableSeqName arrs -> StableSharingSeq -> Bool
matchStableSeq sn1 (StableSharingSeq sn2 _)
  | Just sn1' <- gcast sn1 = sn1' == sn2
  | otherwise              = False
--}


-- Occurrence counting
-- ===================

-- Compute the 'SmartAcc' occurrence map, marks all nodes (both 'Seq' and 'Exp' nodes) with stable names,
-- and drop repeated occurrences of shared 'SmartAcc' and 'Exp' subtrees (Phase One).
--
-- We compute a single 'SmartAcc' occurrence map for the whole AST, but one 'Exp' occurrence map for each
-- sub-expression rooted in an 'SmartAcc' operation.  This is as we cannot float 'Exp' subtrees across
-- 'SmartAcc' operations, but we can float 'SmartAcc' subtrees out of 'Exp' expressions.
--
-- Note [Traversing functions and side effects]
-- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-- We need to descent into function bodies to build the 'OccMap' with all occurrences in the
-- function bodies.  Due to the side effects in the construction of the occurrence map and, more
-- importantly, the dependence of the second phase on /global/ occurrence information, we may not
-- delay the body traversals by putting them under a lambda.  Hence, we apply each function, to
-- traverse its body and use a /dummy abstraction/ of the result.
--
-- For example, given a function 'f', we traverse 'f (Tag 0)', which yields a transformed body 'e'.
-- As the result of the traversal of the overall function, we use 'const e'.  Hence, it is crucial
-- that the 'Tag' supplied during the initial traversal is already the one required by the HOAS to
-- de Bruijn conversion in 'convertSharingAcc' — any subsequent application of 'const e' will only
-- yield 'e' with the embedded 'Tag 0' of the original application.  During sharing recovery, we
-- float /all/ free variables ('Atag' and 'Tag') out to construct the initial environment for
-- producing de Bruijn indices, which replaces them by 'AvarSharing' or 'VarSharing' nodes.  Hence,
-- the tag values only serve the purpose of determining the ordering in that initial environment.
-- They are /not/ directly used to compute the de Brujin indices.
--
makeOccMapAcc
    :: HasCallStack
    => Config
    -> Level
    -> SmartAcc arrs
    -> IO (UnscopedAcc arrs, OccMap SmartAcc)
makeOccMapAcc config lvl acc = do
  traceLine "makeOccMapAcc" "Enter"
  accOccMap             <- newASTHashTable
  (acc', _)             <- makeOccMapSharingAcc config accOccMap lvl acc
  frozenAccOccMap       <- freezeOccMap accOccMap
  traceLine "makeOccMapAcc" "Exit"
  return (acc', frozenAccOccMap)


makeOccMapSharingAcc
    :: HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> Level
    -> SmartAcc arrs
    -> IO (UnscopedAcc arrs, Int)
makeOccMapSharingAcc config accOccMap = traverseAcc
  where
    traverseFun1
        :: HasCallStack
        => Level
        -> TypeR a
        -> (SmartExp a -> SmartExp b)
        -> IO (SmartExp a -> RootExp b, Int)
    traverseFun1 = makeOccMapFun1 config accOccMap

    traverseFun2
        :: HasCallStack
        => Level
        -> TypeR a
        -> TypeR b
        -> (SmartExp a -> SmartExp b -> SmartExp c)
        -> IO (SmartExp a -> SmartExp b -> RootExp c, Int)
    traverseFun2 = makeOccMapFun2 config accOccMap

    traverseAfun1
        :: HasCallStack
        => Level
        -> ArraysR a
        -> (SmartAcc a -> SmartAcc b)
        -> IO (SmartAcc a -> UnscopedAcc b, Int)
    traverseAfun1 = makeOccMapAfun1 config accOccMap

    traverseExp
      :: HasCallStack
      => Level
      -> SmartExp e
      -> IO (RootExp e, Int)
    traverseExp = makeOccMapExp config accOccMap

    traverseBoundary
        :: HasCallStack
        => Level
        -> ShapeR sh
        -> PreBoundary SmartAcc SmartExp (Array sh e)
        -> IO (PreBoundary UnscopedAcc RootExp (Array sh e), Int)
    traverseBoundary lvl shr bndy =
      case bndy of
        Clamp      -> return (Clamp, 0)
        Mirror     -> return (Mirror, 0)
        Wrap       -> return (Wrap, 0)
        Constant v -> return (Constant v, 0)
        Function f -> do
          (f', h) <- traverseFun1 lvl (shapeType shr) f
          return (Function f', h)

    -- traverseSeq :: forall arrs. Typeable arrs
    --             => Level -> Seq arrs
    --             -> IO (RootSeq arrs, Int)
    -- traverseSeq = makeOccMapRootSeq config accOccMap

    traverseAcc
        :: forall arrs. HasCallStack
        => Level
        -> SmartAcc arrs
        -> IO (UnscopedAcc arrs, Int)
    traverseAcc lvl acc@(SmartAcc pacc)
      = mfix $ \ ~(_, height) -> do
          -- Compute stable name and enter it into the occurrence map
          --
          sn                         <- makeStableAST acc
          heightIfRepeatedOccurrence <- enterOcc accOccMap (StableASTName sn) height

          traceLine (showPreAccOp pacc) $ do
            let hash = show (hashStableName sn)
            case heightIfRepeatedOccurrence of
              Just height -> "REPEATED occurrence (sn = " ++ hash ++ "; height = " ++ show height ++ ")"
              Nothing     -> "first occurrence (sn = " ++ hash ++ ")"

          -- Reconstruct the computation in shared form.
          --
          -- In case of a repeated occurrence, the height comes from the occurrence map; otherwise
          -- it is computed by the traversal function passed in 'newAcc'. See also 'enterOcc'.
          --
          let reconstruct :: IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
                          -> IO (UnscopedAcc arrs, Int)
              reconstruct newAcc
                = case heightIfRepeatedOccurrence of
                    Just height | acc_sharing `member` options config
                      -> return (UnscopedAcc [] (AvarSharing (StableNameHeight sn height) (Smart.arraysR pacc)), height)
                    _ -> do (acc, height) <- newAcc
                            return (UnscopedAcc [] (AccSharing (StableNameHeight sn height) acc), height)

          reconstruct $ case pacc of
            Atag repr i                 -> return (Atag repr i, 0)           -- height is 0!
            Pipe repr1 repr2 repr3 afun1 afun2 acc
                                        -> do
                                             (afun1', h1) <- traverseAfun1 lvl repr1 afun1
                                             (afun2', h2) <- traverseAfun1 lvl repr2 afun2
                                             (acc', h3)   <- traverseAcc lvl acc
                                             return (Pipe repr1 repr2 repr3 afun1' afun2' acc'
                                                    , h1 `max` h2 `max` h3 + 1)
            Aforeign repr ff afun acc   -> travA (Aforeign repr ff afun) acc
            Acond e acc1 acc2           -> do
                                             (e'   , h1) <- traverseExp lvl e
                                             (acc1', h2) <- traverseAcc lvl acc1
                                             (acc2', h3) <- traverseAcc lvl acc2
                                             return (Acond e' acc1' acc2', h1 `max` h2 `max` h3 + 1)
            Awhile repr pred iter init  -> do
                                             (pred', h1) <- traverseAfun1 lvl repr pred
                                             (iter', h2) <- traverseAfun1 lvl repr iter
                                             (init', h3) <- traverseAcc lvl init
                                             return (Awhile repr pred' iter' init'
                                                    , h1 `max` h2 `max` h3 + 1)

            Anil                        -> return (Anil, 0)
            Apair acc1 acc2             -> do
                                             (a', h1) <- traverseAcc lvl acc1
                                             (b', h2) <- traverseAcc lvl acc2
                                             return (Apair a' b', h1 `max` h2 + 1)
            Aprj ix a                   -> travA (Aprj ix) a

            Use repr arr                -> return (Use repr arr, 1)
            Unit tp e                   -> do
                                             (e', h) <- traverseExp lvl e
                                             return (Unit tp e', h + 1)
            Generate repr@(ArrayR shr _) e f
                                        -> do
                                             (e', h1) <- traverseExp lvl e
                                             (f', h2) <- traverseFun1 lvl (shapeType shr) f
                                             return (Generate repr e' f', h1 `max` h2 + 1)
            Reshape shr e acc           -> travEA (Reshape shr) e acc
            Replicate si e acc          -> travEA (Replicate si) e acc
            Slice si acc e              -> travEA (flip $ Slice si) e acc
            Map t1 t2 f acc             -> do
                                             (f'  , h1) <- traverseFun1 lvl t1 f
                                             (acc', h2) <- traverseAcc lvl acc
                                             return (Map t1 t2 f' acc', h1 `max` h2 + 1)
            ZipWith t1 t2 t3 f acc1 acc2
                                        -> travF2A2 (ZipWith t1 t2 t3) t1 t2 f acc1 acc2
            Fold tp f e acc             -> travF2MEA (Fold tp) tp tp f e acc
            FoldSeg i tp f e acc1 acc2  -> do
                                             (f'   , h1) <- traverseFun2 lvl tp tp f
                                             (e'   , h2) <- travME e
                                             (acc1', h3) <- traverseAcc lvl acc1
                                             (acc2', h4) <- traverseAcc lvl acc2
                                             return (FoldSeg i tp f' e' acc1' acc2',
                                                     h1 `max` h2 `max` h3 `max` h4 + 1)
            Scan  d tp f e acc          -> travF2MEA (Scan  d tp) tp tp f e acc
            Scan' d tp f e acc          -> travF2EA (Scan' d tp) tp tp f e acc
            Permute repr@(ArrayR shr tp) c acc1 p acc2
                                        -> do
                                             (c'   , h1) <- traverseFun2 lvl tp tp c
                                             (p'   , h2) <- traverseFun1 lvl (shapeType shr) p
                                             (acc1', h3) <- traverseAcc lvl acc1
                                             (acc2', h4) <- traverseAcc lvl acc2
                                             return (Permute repr c' acc1' p' acc2',
                                                     h1 `max` h2 `max` h3 `max` h4 + 1)
            Backpermute shr e p acc     -> do
                                             (e'  , h1) <- traverseExp lvl e
                                             (p'  , h2) <- traverseFun1 lvl (shapeType shr) p
                                             (acc', h3) <- traverseAcc lvl acc
                                             return (Backpermute shr e' p' acc', h1 `max` h2 `max` h3 + 1)
            Stencil s tp f bnd acc      -> do
                                             (f'  , h1) <- makeOccMapStencil1 config accOccMap s lvl f
                                             (bnd', h2) <- traverseBoundary lvl (stencilShapeR s) bnd
                                             (acc', h3) <- traverseAcc lvl acc
                                             return (Stencil s tp f' bnd' acc', h1 `max` h2 `max` h3 + 1)
            Stencil2 s1 s2 tp f bnd1 acc1
                              bnd2 acc2 -> do
                                             let shr = stencilShapeR s1
                                             (f'   , h1) <- makeOccMapStencil2 config accOccMap s1 s2 lvl f
                                             (bnd1', h2) <- traverseBoundary lvl shr bnd1
                                             (acc1', h3) <- traverseAcc lvl acc1
                                             (bnd2', h4) <- traverseBoundary lvl shr bnd2
                                             (acc2', h5) <- traverseAcc lvl acc2
                                             return (Stencil2 s1 s2 tp f' bnd1' acc1' bnd2' acc2',
                                                     h1 `max` h2 `max` h3 `max` h4 `max` h5 + 1)
            -- Collect s                   -> do
            --                                  (s', h) <- traverseSeq lvl s
            --                                  return (Collect s', h + 1)


      where
        travA :: HasCallStack
              => (UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs)
              -> SmartAcc arrs'
              -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
        travA c acc
          = do
              (acc', h) <- traverseAcc lvl acc
              return (c acc', h + 1)

        travEA :: HasCallStack
               => (RootExp b -> UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs)
               -> SmartExp b
               -> SmartAcc arrs'
               -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
        travEA c exp acc
          = do
              (exp', h1) <- traverseExp lvl exp
              (acc', h2) <- traverseAcc lvl acc
              return (c exp' acc', h1 `max` h2 + 1)

        travF2EA
            :: HasCallStack
            => ((SmartExp b -> SmartExp c -> RootExp d) -> RootExp e -> UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs)
            -> TypeR b
            -> TypeR c
            -> (SmartExp b -> SmartExp c -> SmartExp d)
            -> SmartExp e
            -> SmartAcc arrs'
            -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
        travF2EA c t1 t2 fun exp acc
          = do
              (fun', h1) <- traverseFun2 lvl t1 t2 fun
              (exp', h2) <- traverseExp lvl exp
              (acc', h3) <- traverseAcc lvl acc
              return (c fun' exp' acc', h1 `max` h2 `max` h3 + 1)

        travF2MEA
            :: HasCallStack
            => ((SmartExp b -> SmartExp c -> RootExp d) -> Maybe (RootExp e) -> UnscopedAcc arrs' -> PreSmartAcc UnscopedAcc RootExp arrs)
            -> TypeR b
            -> TypeR c
            -> (SmartExp b -> SmartExp c -> SmartExp d)
            -> Maybe (SmartExp e)
            -> SmartAcc arrs'
            -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
        travF2MEA c t1 t2 fun exp acc
          = do
              (fun', h1) <- traverseFun2 lvl t1 t2 fun
              (exp', h2) <- travME exp
              (acc', h3) <- traverseAcc lvl acc
              return (c fun' exp' acc', h1 `max` h2 `max` h3 + 1)

        travME :: HasCallStack => Maybe (SmartExp t) -> IO (Maybe (RootExp t), Int)
        travME Nothing  = return (Nothing, 0)
        travME (Just e) = do
          (e', c) <- traverseExp lvl e
          return (Just e', c)

        travF2A2
            :: HasCallStack
            => ((SmartExp b -> SmartExp c -> RootExp d) -> UnscopedAcc arrs1 -> UnscopedAcc arrs2 -> PreSmartAcc UnscopedAcc RootExp arrs)
            -> TypeR b
            -> TypeR c
            -> (SmartExp b -> SmartExp c -> SmartExp d)
            -> SmartAcc arrs1
            -> SmartAcc arrs2
            -> IO (PreSmartAcc UnscopedAcc RootExp arrs, Int)
        travF2A2 c t1 t2 fun acc1 acc2
          = do
              (fun' , h1) <- traverseFun2 lvl t1 t2 fun
              (acc1', h2) <- traverseAcc lvl acc1
              (acc2', h3) <- traverseAcc lvl acc2
              return (c fun' acc1' acc2', h1 `max` h2 `max` h3 + 1)

makeOccMapAfun1
    :: HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> Level
    -> ArraysR a
    -> (SmartAcc a -> SmartAcc b)
    -> IO (SmartAcc a -> UnscopedAcc b, Int)
makeOccMapAfun1 config accOccMap lvl repr f = do
  let x = SmartAcc (Atag repr lvl)
  --
  (UnscopedAcc [] body, height) <- makeOccMapSharingAcc config accOccMap (lvl+1) (f x)
  return (const (UnscopedAcc [lvl] body), height)

{--
makeOccMapAfun2 :: (Arrays a, Arrays b, Typeable c)
                => Config
                -> OccMapHash Acc
                -> Level
                -> (Acc a -> Acc b -> Acc c)
                -> IO (Acc a -> Acc b -> UnscopedAcc c, Int)
makeOccMapAfun2 config accOccMap lvl f = do
  let x = Acc (Atag (lvl + 1))
      y = Acc (Atag (lvl + 0))
  --
  (UnscopedAcc [] body, height) <- makeOccMapSharingAcc config accOccMap (lvl+2) (f x y)
  return (\ _ _ -> (UnscopedAcc [lvl, lvl+1] body), height)

makeOccMapAfun3 :: (Arrays a, Arrays b, Arrays c, Typeable d)
                => Config
                -> OccMapHash Acc
                -> Level
                -> (Acc a -> Acc b -> Acc c -> Acc d)
                -> IO (Acc a -> Acc b -> Acc c -> UnscopedAcc d, Int)
makeOccMapAfun3 config accOccMap lvl f = do
  let x = Acc (Atag (lvl + 2))
      y = Acc (Atag (lvl + 1))
      z = Acc (Atag (lvl + 0))
  --
  (UnscopedAcc [] body, height) <- makeOccMapSharingAcc config accOccMap (lvl+3) (f x y z)
  return (\ _ _ _ -> (UnscopedAcc [lvl, lvl+1, lvl+2] body), height)
--}

-- Generate occupancy information for scalar functions and expressions. Helper
-- functions wrapping around 'makeOccMapRootExp' with more specific types.
--
-- See Note [Traversing functions and side effects]
--
makeOccMapExp
    :: HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> Level
    -> SmartExp e
    -> IO (RootExp e, Int)
makeOccMapExp config accOccMap lvl = makeOccMapRootExp config accOccMap lvl []

makeOccMapFun1
    :: HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> Level
    -> TypeR a
    -> (SmartExp a -> SmartExp b)
    -> IO (SmartExp a -> RootExp b, Int)
makeOccMapFun1 config accOccMap lvl tp f = do
  let x = SmartExp (Tag tp lvl)
  --
  (body, height) <- makeOccMapRootExp config accOccMap (lvl+1) [lvl] (f x)
  return (const body, height)

makeOccMapFun2
    :: HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> Level
    -> TypeR a
    -> TypeR b
    -> (SmartExp a -> SmartExp b -> SmartExp c)
    -> IO (SmartExp a -> SmartExp b -> RootExp c, Int)
makeOccMapFun2 config accOccMap lvl t1 t2 f = do
  let x = SmartExp (Tag t1 (lvl+1))
      y = SmartExp (Tag t2 lvl)
  --
  (body, height) <- makeOccMapRootExp config accOccMap (lvl+2) [lvl, lvl+1] (f x y)
  return (\_ _ -> body, height)

makeOccMapStencil1
    :: forall sh a b stencil. HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> R.StencilR sh a stencil
    -> Level
    -> (SmartExp stencil -> SmartExp b)
    -> IO (SmartExp stencil -> RootExp b, Int)
makeOccMapStencil1 config accOccMap s lvl stencil = do
  let x = SmartExp (Tag (R.stencilR s) lvl)
  --
  (body, height) <- makeOccMapRootExp config accOccMap (lvl+1) [lvl] (stencil x)
  return (const body, height)

makeOccMapStencil2
    :: forall sh a b c stencil1 stencil2. HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> R.StencilR sh a stencil1
    -> R.StencilR sh b stencil2
    -> Level
    -> (SmartExp stencil1 -> SmartExp stencil2 -> SmartExp c)
    -> IO (SmartExp stencil1 -> SmartExp stencil2 -> RootExp c, Int)
makeOccMapStencil2 config accOccMap sR1 sR2 lvl stencil = do
  let x = SmartExp (Tag (R.stencilR sR1) (lvl+1))
      y = SmartExp (Tag (R.stencilR sR2) lvl)
  --
  (body, height) <- makeOccMapRootExp config accOccMap (lvl+2) [lvl, lvl+1] (stencil x y)
  return (\_ _ -> body, height)


-- Generate sharing information for expressions embedded in Acc computations.
-- Expressions are annotated with:
--
--  1) the tags of free scalar variables (for scalar functions)
--  2) a local occurrence map for that expression.
--
makeOccMapRootExp
    :: HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> Level                            -- The level of currently bound scalar variables
    -> [Int]                            -- The tags of newly introduced free scalar variables in this expression
    -> SmartExp e
    -> IO (RootExp e, Int)
makeOccMapRootExp config accOccMap lvl fvs exp = do
  traceLine "makeOccMapRootExp" "Enter"
  expOccMap                     <- newASTHashTable
  (UnscopedExp [] exp', height) <- makeOccMapSharingExp config accOccMap expOccMap lvl exp
  frozenExpOccMap               <- freezeOccMap expOccMap
  traceLine "makeOccMapRootExp" "Exit"
  return (RootExp frozenExpOccMap (UnscopedExp fvs exp'), height)


-- Generate sharing information for an open scalar expression.
--
makeOccMapSharingExp
    :: HasCallStack
    => Config
    -> OccMapHash SmartAcc
    -> OccMapHash SmartExp
    -> Level                            -- The level of currently bound variables
    -> SmartExp e
    -> IO (UnscopedExp e, Int)
makeOccMapSharingExp config accOccMap expOccMap = travE
  where
    travE :: forall a. HasCallStack => Level -> SmartExp a -> IO (UnscopedExp a, Int)
    travE lvl exp@(SmartExp pexp)
      = mfix $ \ ~(_, height) -> do
          -- Compute stable name and enter it into the occurrence map
          --
          sn                         <- makeStableAST exp
          heightIfRepeatedOccurrence <- enterOcc expOccMap (StableASTName sn) height

          traceLine (showPreExpOp pexp) $ do
            let hash = show (hashStableName sn)
            case heightIfRepeatedOccurrence of
              Just height -> "REPEATED occurrence (sn = " ++ hash ++ "; height = " ++ show height ++ ")"
              Nothing     -> "first occurrence (sn = " ++ hash ++ ")"

          -- Reconstruct the computation in shared form.
          --
          -- In case of a repeated occurrence, the height comes from the occurrence map; otherwise
          -- it is computed by the traversal function passed in 'newExp'.  See also 'enterOcc'.
          --
          let reconstruct :: IO (PreSmartExp UnscopedAcc UnscopedExp a, Int)
                          -> IO (UnscopedExp a, Int)
              reconstruct newExp
                = case heightIfRepeatedOccurrence of
                    Just height | exp_sharing `member` options config
                      -> return (UnscopedExp [] (VarSharing (StableNameHeight sn height) (typeR pexp)), height)
                    _ -> do (exp, height) <- newExp
                            return (UnscopedExp [] (ExpSharing (StableNameHeight sn height) exp), height)

          reconstruct $ case pexp of
            Tag tp i            -> return (Tag tp i, 0)      -- height is 0!
            Const tp c          -> return (Const tp c, 1)
            Undef tp            -> return (Undef tp, 1)
            Nil                 -> return (Nil, 1)
            Pair e1 e2          -> travE2 Pair e1 e2
            Prj i e             -> travE1 (Prj i) e
            VecPack   vec e     -> travE1 (VecPack   vec) e
            VecUnpack vec e     -> travE1 (VecUnpack vec) e
            ToIndex shr sh ix   -> travE2 (ToIndex shr) sh ix
            FromIndex shr sh e  -> travE2 (FromIndex shr) sh e
            Match t e           -> travE1 (Match t) e
            Case e rhs          -> do
                                     (e',   h1) <- travE lvl e
                                     (rhs', h2) <- unzip <$> sequence [ travE1 (t,) c | (t,c) <- rhs ]
                                     return (Case e' rhs', h1 `max` maximum h2 + 1)
            Cond e1 e2 e3       -> travE3 Cond e1 e2 e3
            While t p iter init -> do
                                     (p'   , h1) <- traverseFun1 lvl t p
                                     (iter', h2) <- traverseFun1 lvl t iter
                                     (init', h3) <- travE lvl init
                                     return (While t p' iter' init', h1 `max` h2 `max` h3 + 1)
            PrimConst c         -> return (PrimConst c, 1)
            PrimApp p e         -> travE1 (PrimApp p) e
            Index tp a e        -> travAE (Index tp) a e
            LinearIndex tp a i  -> travAE (LinearIndex tp) a i
            Shape shr a         -> travA (Shape shr) a
            ShapeSize shr e     -> travE1 (ShapeSize shr) e
            Foreign tp ff f e   -> do
                                      (e', h) <- travE lvl e
                                      return  (Foreign tp ff f e', h+1)
            Coerce t1 t2 e      -> travE1 (Coerce t1 t2) e

      where
        traverseAcc :: HasCallStack => Level -> SmartAcc arrs -> IO (UnscopedAcc arrs, Int)
        traverseAcc = makeOccMapSharingAcc config accOccMap

        traverseFun1
            :: HasCallStack
            => Level
            -> TypeR a
            -> (SmartExp a -> SmartExp b)
            -> IO (SmartExp a -> UnscopedExp b, Int)
        traverseFun1 lvl tp f
          = do
              let x = SmartExp (Tag tp lvl)
              (UnscopedExp [] body, height) <- travE (lvl+1) (f x)
              return (const (UnscopedExp [lvl] body), height + 1)


        travE1 :: HasCallStack => (UnscopedExp b -> r) -> SmartExp b -> IO (r, Int)
        travE1 c e
          = do
              (e', h) <- travE lvl e
              return (c e', h + 1)

        travE2 :: HasCallStack
               => (UnscopedExp b -> UnscopedExp c -> r)
               -> SmartExp b
               -> SmartExp c
               -> IO (r, Int)
        travE2 c e1 e2
          = do
              (e1', h1) <- travE lvl e1
              (e2', h2) <- travE lvl e2
              return (c e1' e2', h1 `max` h2 + 1)

        travE3 :: HasCallStack
               => (UnscopedExp b -> UnscopedExp c -> UnscopedExp d -> r)
               -> SmartExp b
               -> SmartExp c
               -> SmartExp d
               -> IO (r, Int)
        travE3 c e1 e2 e3
          = do
              (e1', h1) <- travE lvl e1
              (e2', h2) <- travE lvl e2
              (e3', h3) <- travE lvl e3
              return (c e1' e2' e3', h1 `max` h2 `max` h3 + 1)

        travA :: HasCallStack => (UnscopedAcc b -> r) -> SmartAcc b -> IO (r, Int)
        travA c acc
          = do
              (acc', h) <- traverseAcc lvl acc
              return (c acc', h + 1)

        travAE :: HasCallStack
               => (UnscopedAcc b -> UnscopedExp c -> r)
               -> SmartAcc b
               -> SmartExp c
               -> IO (r, Int)
        travAE c acc e
          = do
              (acc', h1) <- traverseAcc lvl acc
              (e'  , h2) <- travE lvl e
              return (c acc' e', h1 `max` h2 + 1)

{--
makeOccMapRootSeq
    :: Typeable arrs
    => Config
    -> OccMapHash Acc
    -> Level
    -> Seq arrs
    -> IO (RootSeq arrs, Int)
makeOccMapRootSeq config accOccMap lvl seq = do
  traceLine "makeOccMapRootSeq" "Enter"
  seqOccMap       <- newASTHashTable
  (seq', height)  <- makeOccMapSharingSeq config accOccMap seqOccMap lvl seq
  frozenSeqOccMap <- freezeOccMap seqOccMap
  traceLine "makeOccMapRootSeq" "Exit"
  return (RootSeq frozenSeqOccMap seq', height)

-- Generate sharing information for an open sequence expression.
--
makeOccMapSharingSeq
    :: Typeable e
    => Config
    -> OccMapHash Acc
    -> OccMapHash Seq
    -> Level                            -- The level of currently bound variables
    -> Seq e
    -> IO (UnscopedSeq e, Int)
makeOccMapSharingSeq config accOccMap seqOccMap = traverseSeq
  where
    traverseAcc :: Typeable arrs => Level -> Acc arrs -> IO (UnscopedAcc arrs, Int)
    traverseAcc = makeOccMapSharingAcc config accOccMap

    traverseAfun1 :: (Arrays a, Typeable b) => Level -> (Acc a -> Acc b) -> IO (Acc a -> UnscopedAcc b, Int)
    traverseAfun1 = makeOccMapAfun1 config accOccMap

    traverseAfun2 :: (Arrays a, Arrays b, Typeable c) => Level -> (Acc a -> Acc b -> Acc c) -> IO (Acc a -> Acc b -> UnscopedAcc c, Int)
    traverseAfun2 = makeOccMapAfun2 config accOccMap

    traverseAfun3 :: (Arrays a, Arrays b, Arrays c, Typeable d) => Level -> (Acc a -> Acc b -> Acc c -> Acc d) -> IO (Acc a -> Acc b -> Acc c -> UnscopedAcc d, Int)
    traverseAfun3 = makeOccMapAfun3 config accOccMap

    traverseExp :: Typeable e => Level -> Exp e -> IO (RootExp e, Int)
    traverseExp = makeOccMapExp config accOccMap

    traverseFun2 :: (Elt a, Elt b, Typeable c)
                 => Level
                 -> (Exp a -> Exp b -> Exp c)
                 -> IO (Exp a -> Exp b -> RootExp c, Int)
    traverseFun2 = makeOccMapFun2 config accOccMap

    traverseTup :: Level -> Atuple Seq tup -> IO (Atuple UnscopedSeq tup, Int)
    traverseTup _   NilAtup          = return (NilAtup, 1)
    traverseTup lvl (SnocAtup tup s) = do
                                        (tup', h1) <- traverseTup lvl tup
                                        (s'  , h2) <- traverseSeq lvl s
                                        return (SnocAtup tup' s', h1 `max` h2 + 1)

    traverseSeq :: forall arrs. Typeable arrs => Level -> Seq arrs -> IO (UnscopedSeq arrs, Int)
    traverseSeq lvl acc@(Seq seq)
      = mfix $ \ ~(_, height) -> do
          -- Compute stable name and enter it into the occurrence map
          --
          sn                         <- makeStableAST acc
          heightIfRepeatedOccurrence <- enterOcc seqOccMap (StableASTName sn) height

          traceLine (showPreSeqOp seq) $ do
            let hash = show (hashStableName sn)
            case heightIfRepeatedOccurrence of
              Just height -> "REPEATED occurrence (sn = " ++ hash ++ "; height = " ++ show height ++ ")"
              Nothing     -> "first occurrence (sn = " ++ hash ++ ")"

          -- Reconstruct the computation in shared form.
          --
          -- In case of a repeated occurrence, the height comes from the occurrence map; otherwise
          -- it is computed by the traversal function passed in 'newAcc'. See also 'enterOcc'.
          --
          -- NB: This function can only be used in the case alternatives below; outside of the
          --     case we cannot discharge the 'Arrays arrs' constraint.
          --
          let producer :: (arrs ~ [a], Arrays a)
                       => IO (PreSeq UnscopedAcc UnscopedSeq RootExp arrs, Int)
                       -> IO (UnscopedSeq arrs, Int)
              producer newSeq
                = case heightIfRepeatedOccurrence of
                    Just height | recoverSeqSharing config
                      -> return (UnscopedSeq (SvarSharing (StableNameHeight sn height)), height)
                    _ -> do (seq, height) <- newSeq
                            return (UnscopedSeq (SeqSharing (StableNameHeight sn height) seq), height)

          let consumer :: IO (PreSeq UnscopedAcc UnscopedSeq RootExp arrs, Int)
                       -> IO (UnscopedSeq arrs, Int)
              consumer newSeq
                = do (seq, height) <- newSeq
                     return (UnscopedSeq (SeqSharing (StableNameHeight sn height) seq), height)

          case seq of
            StreamIn arrs -> producer $ return (StreamIn arrs, 1)
            ToSeq sl acc -> producer $ do
              (acc', h1) <- traverseAcc lvl acc
              return (ToSeq sl acc', h1 + 1)
            MapSeq afun s -> producer $ do
              (afun', h1) <- traverseAfun1 lvl afun
              (s'   , h2) <- traverseSeq lvl s
              return (MapSeq afun' s', h1 `max` h2 + 1)
            ZipWithSeq afun s1 s2 -> producer $ do
              (afun', h1) <- traverseAfun2 lvl afun
              (s1'  , h2) <- traverseSeq lvl s1
              (s2'  , h3) <- traverseSeq lvl s2
              return (ZipWithSeq afun' s1' s2', h1 `max` h2 `max` h3 + 1)
            ScanSeq fun e s -> producer $ do
              (fun', h1) <- traverseFun2 lvl fun
              (e',  h2) <- traverseExp lvl e
              (s'   , h3) <- traverseSeq lvl s
              return (ScanSeq fun' e' s', h1 `max` h2 `max` h3 + 1)
            FoldSeq fun e s -> consumer $ do
              (fun', h1) <- traverseFun2 lvl fun
              (e'  , h2) <- traverseExp lvl e
              (s'  , h3) <- traverseSeq lvl s
              return (FoldSeq fun' e' s', h1 `max` h2 `max` h3 + 1)
            FoldSeqFlatten afun acc s -> consumer $ do
              (afun', h1) <- traverseAfun3 lvl afun
              (acc',  h2) <- traverseAcc lvl acc
              (s'   , h3) <- traverseSeq lvl s
              return (FoldSeqFlatten afun' acc' s', h1 `max` h2 `max` h3 + 1)
            Stuple t -> consumer $ do
              (t', h1) <- traverseTup lvl t
              return (Stuple t', h1 + 1)
--}


-- Type used to maintain how often each shared subterm, so far, occurred during a bottom-up sweep,
-- as well as the relation between subterms. It is comprised of a list of terms and a graph giving
-- their relation.
--
--   Invariants of the list:
--   - If one shared term 's' is itself a subterm of another shared term 't', then 's' must occur
--     *after* 't' in the list.
--   - No shared term occurs twice.
--   - A term may have a final occurrence count of only 1 iff it is either a free variable ('Atag'
--     or 'Tag') or an array computation lifted out of an expression.
--   - All 'Exp' node counts precede all 'SmartAcc' node counts as we don't share 'Exp' nodes across 'SmartAcc'
--     nodes. Similarly, all 'Seq' nodes precede 'SmartAcc' nodes and 'Exp' nodes precede 'Seq' nodes.
--
-- We determine the subterm property by using the tree height in 'StableNameHeight'.  Trees get
-- smaller towards the end of a 'NodeCounts' list.  The height of free variables ('Atag' or 'Tag')
-- is 0, whereas other leaves have height 1.  This guarantees that all free variables are at the end
-- of the 'NodeCounts' list.
--
-- The graph is represented as a map where a stable name 'a' is mapped to a set of stables names 'b'
-- such that if there exists a edge from 'a' to 'c' that 'c' is contained within 'b'.
--
--  Properties of the graph:
--  - There exists an edge from 'a' to 'b' if the term 'a' names is a subterm of the term named by
--    'b'.
--
-- To ensure the list invariant and the graph properties are preserved over merging node counts from
-- sibling subterms, the function '(+++)' must be used.
--
type NodeCounts = ([NodeCount], Map.HashMap NodeName (Set.HashSet NodeName))

data NodeName where
  NodeName :: StableName a -> NodeName

instance Eq NodeName where
  (NodeName sn1) == (NodeName sn2) = eqStableName sn1 sn2

instance Hashable NodeName where
  hashWithSalt hash (NodeName sn1) = hash + hashStableName sn1

instance Show NodeName where
  show (NodeName sn) = show (hashStableName sn)

data NodeCount = AccNodeCount StableSharingAcc Int
               | ExpNodeCount StableSharingExp Int
               -- SeqNodeCount StableSharingSeq Int
               deriving Show

-- Empty node counts
--
noNodeCounts :: NodeCounts
noNodeCounts = ([], Map.empty)

-- Insert an Acc node into the node counts, assuming that it is a superterm of the all the existing
-- nodes.
--
-- TODO: Perform cycle detection here.
--
insertAccNode :: StableSharingAcc -> NodeCounts -> NodeCounts
insertAccNode ssa@(StableSharingAcc (StableNameHeight sn _) _) (subterms,g)
  = ([AccNodeCount ssa 1], g') +++ (subterms,g)
  where
    k  = NodeName sn
    hs = map nodeName subterms
    g' = Map.fromList $ (k, Set.empty) : [(h, Set.singleton k) | h <- hs]

-- Insert an Exp node into the node counts, assuming that it is a superterm of the all the existing
-- nodes.
--
-- TODO: Perform cycle detection here.
--
insertExpNode :: StableSharingExp -> NodeCounts -> NodeCounts
insertExpNode ssa@(StableSharingExp (StableNameHeight sn _) _) (subterms,g)
  = ([ExpNodeCount ssa 1], g') +++ (subterms,g)
  where
    k  = NodeName sn
    hs = map nodeName subterms
    g' = Map.fromList $ (k, Set.empty) : [(h, Set.singleton k) | h <- hs]

{--
-- Insert an Seq node into the node counts, assuming that it is a superterm of the all the existing
-- nodes.
--
-- TODO: Perform cycle detection here.
--
insertSeqNode :: StableSharingSeq -> NodeCounts -> NodeCounts
insertSeqNode ssa@(StableSharingSeq (StableNameHeight sn _) _) (subterms,g)
  = ([SeqNodeCount ssa 1], g') +++ (subterms,g)
  where
    k  = NodeName sn
    hs = map nodeName subterms
    g' = Map.fromList $ (k, Set.empty) : [(h, Set.singleton k) | h <- hs]
--}

-- Remove nodes that aren't in the list from the graph.
--
-- RCE: This is no longer necessary when NDP is supported.
--
cleanCounts :: NodeCounts -> NodeCounts
cleanCounts (ns, g) = (ns, Map.fromList [(h, Set.filter (flip elem hs) (g Map.! h)) | h <- hs ])
  where
    hs = map nodeName ns

nodeName :: NodeCount -> NodeName
nodeName (AccNodeCount (StableSharingAcc (StableNameHeight sn _) _) _) = NodeName sn
nodeName (ExpNodeCount (StableSharingExp (StableNameHeight sn _) _) _) = NodeName sn
-- nodeName (SeqNodeCount (StableSharingSeq (StableNameHeight sn _) _) _) = NodeName sn


-- Combine node counts that belong to the same node.
--
-- * We assume that the list invariant —subterms follow their parents— holds for both arguments and
--   guarantee that it still holds for the result.
--
-- * In the same manner, we assume that all 'Exp' node counts precede 'SmartAcc' node counts and
--   guarantee that this also hold for the result.
--
(+++) :: NodeCounts -> NodeCounts -> NodeCounts
(ns1, g1) +++ (ns2, g2) = (cleanup $ merge ns1 ns2, Map.unionWith Set.union g1 g2)
  where
    merge [] x = x
    merge x [] = x
    merge (x@(AccNodeCount sa1 count1):xs) (y@(AccNodeCount sa2 count2):ys)
     | sa1 == sa2          = AccNodeCount (sa1 `pickNoneAvar` sa2) (count1 + count2) : merge xs ys
     | sa1 `higherSSA` sa2 = x : merge xs (y:ys)
     | otherwise           = y : merge (x:xs) ys
    merge (x@(ExpNodeCount se1 count1):xs) (y@(ExpNodeCount se2 count2):ys)
     | se1 == se2          = ExpNodeCount (se1 `pickNoneVar` se2) (count1 + count2) : merge xs ys
     | se1 `higherSSE` se2 = x : merge xs (y:ys)
     | otherwise           = y : merge (x:xs) ys
    merge (x@(AccNodeCount _ _):xs) (y@(ExpNodeCount _ _):ys) = y : merge (x:xs) ys
    merge (x@(ExpNodeCount _ _):xs) (y@(AccNodeCount _ _):ys) = x : merge xs (y:ys)

    (StableSharingAcc _ (AvarSharing _ _)) `pickNoneAvar` sa2  = sa2
    sa1                                    `pickNoneAvar` _sa2 = sa1

    (StableSharingExp _ (VarSharing _ _))  `pickNoneVar`  sa2  = sa2
    sa1                                    `pickNoneVar`  _sa2 = sa1

    -- As the StableSharingAccs do not pose a strict ordering, this cleanup
    -- step is needed. In this step, all pairs of AccNodes and ExpNodes
    -- that are of the same height are compared against each other. Without
    -- this step, duplicates may arise.
    --
    -- Note that while (+++) is morally symmetric, replacing `merge [x] y'
    -- with `merge y [x]' inside of `cleanup' won't check all required
    -- possibilities.
    --
    cleanup = concatMap (foldr (\x y -> merge [x] y) []) . groupBy sameHeight
    sameHeight (AccNodeCount sa1 _) (AccNodeCount sa2 _) = not (sa1 `higherSSA` sa2) && not (sa2 `higherSSA` sa1)
    sameHeight (ExpNodeCount se1 _) (ExpNodeCount se2 _) = not (se1 `higherSSE` se2) && not (se2 `higherSSE` se1)
    sameHeight _ _ = False


-- Build an initial environment for the tag values given in the first argument for traversing an
-- array expression.  The 'StableSharingAcc's for all tags /actually used/ in the expressions are
-- in the second argument. (Tags are not used if a bound variable has no usage occurrence.)
--
-- Bail out if any tag occurs multiple times as this indicates that the sharing of an argument
-- variable was not preserved and we cannot build an appropriate initial environment (c.f., comments
-- at 'determineScopesAcc'.
--
buildInitialEnvAcc
    :: HasCallStack
    => [Level]
    -> [StableSharingAcc]
    -> [StableSharingAcc]
buildInitialEnvAcc tags sas = map (lookupSA sas) tags
  where
    lookupSA sas tag1
      = case filter hasTag sas of
          []   -> noStableSharing    -- tag is not used in the analysed expression
          [sa] -> sa                 -- tag has a unique occurrence
          sas2 -> internalError ("Encountered duplicate 'ATag's\n  " ++ intercalate ", " (map showSA sas2))
      where
        hasTag (StableSharingAcc _ (AccSharing _ (Atag _ tag2))) = tag1 == tag2
        hasTag sa
          = internalError ("Encountered a node that is not a plain 'Atag'\n  " ++ showSA sa)

        noStableSharing :: StableSharingAcc
        noStableSharing = StableSharingAcc noStableAccName (undefined :: SharingAcc acc exp ())

    showSA (StableSharingAcc _ (AccSharing  sn acc)) = show (hashStableNameHeight sn) ++ ": " ++
                                                       showPreAccOp acc
    showSA (StableSharingAcc _ (AvarSharing sn _))   = "AvarSharing " ++ show (hashStableNameHeight sn)
    showSA (StableSharingAcc _ (AletSharing sa _))   = "AletSharing " ++ show sa ++ "..."

-- Build an initial environment for the tag values given in the first argument for traversing a
-- scalar expression.  The 'StableSharingExp's for all tags /actually used/ in the expressions are
-- in the second argument. (Tags are not used if a bound variable has no usage occurrence.)
--
-- Bail out if any tag occurs multiple times as this indicates that the sharing of an argument
-- variable was not preserved and we cannot build an appropriate initial environment (c.f., comments
-- at 'determineScopesAcc'.
--
buildInitialEnvExp
    :: HasCallStack
    => [Level]
    -> [StableSharingExp]
    -> [StableSharingExp]
buildInitialEnvExp tags ses = map (lookupSE ses) tags
  where
    lookupSE ses tag1
      = case filter hasTag ses of
          []   -> noStableSharing    -- tag is not used in the analysed expression
          [se] -> se                 -- tag has a unique occurrence
          ses2 -> internalError ("Encountered a duplicate 'Tag'\n  " ++ intercalate ", " (map showSE ses2))
      where
        hasTag (StableSharingExp _ (ExpSharing _ (Tag _ tag2))) = tag1 == tag2
        hasTag se
          = internalError ("Encountered a node that is not a plain 'Tag'\n  " ++ showSE se)

        noStableSharing :: StableSharingExp
        noStableSharing = StableSharingExp noStableExpName (undefined :: SharingExp acc exp ())

    showSE (StableSharingExp _ (ExpSharing sn exp)) = show (hashStableNameHeight sn) ++ ": " ++
                                                      showPreExpOp exp
    showSE (StableSharingExp _ (VarSharing sn _ ))  = "VarSharing " ++ show (hashStableNameHeight sn)
    showSE (StableSharingExp _ (LetSharing se _ ))  = "LetSharing " ++ show se ++ "..."

-- Determine whether a 'NodeCount' is for an 'Atag' or 'Tag', which represent free variables.
--
isFreeVar :: NodeCount -> Bool
isFreeVar (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag _ _))) _) = True
isFreeVar (ExpNodeCount (StableSharingExp _ (ExpSharing _ (Tag  _ _))) _) = True
isFreeVar _                                                               = False


-- Determine scope of shared subterms
-- ==================================

-- Determine the scopes of all variables representing shared subterms (Phase Two) in a bottom-up
-- sweep.  The first argument determines whether array computations are floated out of expressions
-- irrespective of whether they are shared or not — 'True' implies floating them out.
--
-- In addition to the AST with sharing information, yield the 'StableSharingAcc's for all free
-- variables of 'rootAcc', which are represented by 'Atag' leaves in the tree. They are in order of
-- the tag values — i.e., in the same order that they need to appear in an environment to use the
-- tag for indexing into that environment.
--
-- Precondition: there are only 'AvarSharing' and 'AccSharing' nodes in the argument.
--
determineScopesAcc
    :: HasCallStack
    => Config
    -> [Level]
    -> OccMap SmartAcc
    -> UnscopedAcc a
    -> (ScopedAcc a, [StableSharingAcc])
determineScopesAcc config fvs accOccMap rootAcc
  = let (sharingAcc, (counts, _)) = determineScopesSharingAcc config accOccMap rootAcc
        unboundTrees              = filter (not . isFreeVar) counts
    in
    if all isFreeVar counts
       then (sharingAcc, buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- counts])
       else internalError ("unbound shared subtrees" ++ show unboundTrees)


determineScopesSharingAcc
    :: HasCallStack
    => Config
    -> OccMap SmartAcc
    -> UnscopedAcc a
    -> (ScopedAcc a, NodeCounts)
determineScopesSharingAcc config accOccMap = scopesAcc
  where
    scopesAcc :: forall arrs. HasCallStack => UnscopedAcc arrs -> (ScopedAcc arrs, NodeCounts)
    scopesAcc (UnscopedAcc _ (AletSharing _ _))
      = internalError "unexpected 'AletSharing'"

    scopesAcc (UnscopedAcc _ (AvarSharing sn tp))
      = (ScopedAcc [] (AvarSharing sn tp), StableSharingAcc sn (AvarSharing sn tp) `insertAccNode` noNodeCounts)

    scopesAcc (UnscopedAcc _ (AccSharing sn pacc))
      = case pacc of
          Atag tp i               -> reconstruct (Atag tp i) noNodeCounts
          Pipe repr1 repr2 repr3 afun1 afun2 acc
                                  -> let
                                       (afun1', accCount1) = scopesAfun1 afun1
                                       (afun2', accCount2) = scopesAfun1 afun2
                                       (acc', accCount3)   = scopesAcc acc
                                     in
                                     reconstruct (Pipe repr1 repr2 repr3 afun1' afun2' acc')
                                                 (accCount1 +++ accCount2 +++ accCount3)

          Aforeign r ff afun acc  -> let
                                       (acc', accCount) = scopesAcc acc
                                     in
                                     reconstruct (Aforeign r ff afun acc') accCount
          Acond e acc1 acc2       -> let
                                       (e'   , accCount1) = scopesExp e
                                       (acc1', accCount2) = scopesAcc acc1
                                       (acc2', accCount3) = scopesAcc acc2
                                     in
                                     reconstruct (Acond e' acc1' acc2')
                                                 (accCount1 +++ accCount2 +++ accCount3)

          Awhile repr pred iter init
                                  -> let
                                       (pred', accCount1) = scopesAfun1 pred
                                       (iter', accCount2) = scopesAfun1 iter
                                       (init', accCount3) = scopesAcc init
                                     in
                                     reconstruct (Awhile repr pred' iter' init')
                                                 (accCount1 +++ accCount2 +++ accCount3)

          Anil                    -> reconstruct Anil noNodeCounts
          Apair a1 a2             -> let
                                       (a1', accCount1) = scopesAcc a1
                                       (a2', accCount2) = scopesAcc a2
                                     in
                                       reconstruct (Apair a1' a2') (accCount1 +++ accCount2)
          Aprj ix a               -> travA (Aprj ix) a

          Use repr arr            -> reconstruct (Use repr arr) noNodeCounts
          Unit tp e               -> let
                                       (e', accCount) = scopesExp e
                                     in
                                     reconstruct (Unit tp e') accCount
          Generate repr sh f      -> let
                                       (sh', accCount1) = scopesExp sh
                                       (f' , accCount2) = scopesFun1 f
                                     in
                                     reconstruct (Generate repr sh' f') (accCount1 +++ accCount2)
          Reshape shr sh acc      -> travEA (Reshape shr) sh acc
          Replicate si n acc      -> travEA (Replicate si) n acc
          Slice si acc i          -> travEA (flip $ Slice si) i acc
          Map t1 t2 f acc         -> let
                                       (f'  , accCount1) = scopesFun1 f
                                       (acc', accCount2) = scopesAcc  acc
                                     in
                                     reconstruct (Map t1 t2 f' acc') (accCount1 +++ accCount2)
          ZipWith t1 t2 t3 f acc1 acc2
                                  -> travF2A2 (ZipWith t1 t2 t3) f acc1 acc2
          Fold tp f z acc         -> travF2MEA (Fold tp) f z acc
          FoldSeg i tp f z acc1 acc2 -> let
                                       (f'   , accCount1)  = scopesFun2 f
                                       (z'   , accCount2)  = travME z
                                       (acc1', accCount3)  = scopesAcc  acc1
                                       (acc2', accCount4)  = scopesAcc  acc2
                                     in
                                     reconstruct (FoldSeg i tp f' z' acc1' acc2')
                                       (accCount1 +++ accCount2 +++ accCount3 +++ accCount4)
          Scan d tp f z acc       -> travF2MEA (Scan d tp) f z acc
          Scan' d tp f z acc      -> travF2EA (Scan' d tp) f z acc
          Permute repr fc acc1 fp acc2
                                  -> let
                                       (fc'  , accCount1) = scopesFun2 fc
                                       (acc1', accCount2) = scopesAcc  acc1
                                       (fp'  , accCount3) = scopesFun1 fp
                                       (acc2', accCount4) = scopesAcc  acc2
                                     in
                                     reconstruct (Permute repr fc' acc1' fp' acc2')
                                       (accCount1 +++ accCount2 +++ accCount3 +++ accCount4)
          Backpermute shr sh fp acc
                                  -> let
                                       (sh' , accCount1) = scopesExp  sh
                                       (fp' , accCount2) = scopesFun1 fp
                                       (acc', accCount3) = scopesAcc  acc
                                     in
                                     reconstruct (Backpermute shr sh' fp' acc')
                                       (accCount1 +++ accCount2 +++ accCount3)
          Stencil sr tp st bnd acc      -> let
                                       (st' , accCount1) = scopesStencil1 acc st
                                       (bnd', accCount2) = scopesBoundary bnd
                                       (acc', accCount3) = scopesAcc acc
                                     in
                                     reconstruct (Stencil sr tp st' bnd' acc') (accCount1 +++ accCount2 +++ accCount3)
          Stencil2 s1 s2 tp st bnd1 acc1 bnd2 acc2
                                  -> let
                                       (st'  , accCount1) = scopesStencil2 acc1 acc2 st
                                       (bnd1', accCount2) = scopesBoundary bnd1
                                       (acc1', accCount3) = scopesAcc acc1
                                       (bnd2', accCount4) = scopesBoundary bnd2
                                       (acc2', accCount5) = scopesAcc acc2
                                     in
                                     reconstruct (Stencil2 s1 s2 tp st' bnd1' acc1' bnd2' acc2')
                                       (accCount1 +++ accCount2 +++ accCount3 +++ accCount4 +++ accCount5)
          -- Collect seq             -> let
          --                              (seq', accCount1) = scopesSeq seq
          --                            in
          --                            reconstruct (Collect seq') accCount1

      where
        travEA :: HasCallStack
               => (ScopedExp e -> ScopedAcc arrs' -> PreSmartAcc ScopedAcc ScopedExp arrs)
               -> RootExp e
               -> UnscopedAcc arrs'
               -> (ScopedAcc arrs, NodeCounts)
        travEA c e acc = reconstruct (c e' acc') (accCount1 +++ accCount2)
          where
            (e'  , accCount1) = scopesExp e
            (acc', accCount2) = scopesAcc acc

        travF2EA
            :: HasCallStack
            => ((SmartExp a -> SmartExp b -> ScopedExp c) -> ScopedExp e -> ScopedAcc arrs' -> PreSmartAcc ScopedAcc ScopedExp arrs)
            -> (SmartExp a -> SmartExp b -> RootExp c)
            -> RootExp e
            -> UnscopedAcc arrs'
            -> (ScopedAcc arrs, NodeCounts)
        travF2EA c f e acc = reconstruct (c f' e' acc') (accCount1 +++ accCount2 +++ accCount3)
          where
            (f'  , accCount1) = scopesFun2 f
            (e'  , accCount2) = scopesExp  e
            (acc', accCount3) = scopesAcc  acc

        travF2MEA
            :: HasCallStack
            => ((SmartExp a -> SmartExp b -> ScopedExp c) -> Maybe (ScopedExp e) -> ScopedAcc arrs' -> PreSmartAcc ScopedAcc ScopedExp arrs)
            -> (SmartExp a -> SmartExp b -> RootExp c)
            -> Maybe (RootExp e)
            -> UnscopedAcc arrs'
            -> (ScopedAcc arrs, NodeCounts)
        travF2MEA c f e acc = reconstruct (c f' e' acc') (accCount1 +++ accCount2 +++ accCount3)
          where
            (f'  , accCount1) = scopesFun2 f
            (e'  , accCount2) = travME e
            (acc', accCount3) = scopesAcc  acc

        travME :: HasCallStack => Maybe (RootExp e) -> (Maybe (ScopedExp e), NodeCounts)
        travME Nothing  = (Nothing, noNodeCounts)
        travME (Just e) = (Just e', c)
          where (e', c) = scopesExp e

        travF2A2
            :: HasCallStack
            => ((SmartExp a -> SmartExp b -> ScopedExp c) -> ScopedAcc arrs1 -> ScopedAcc arrs2 -> PreSmartAcc ScopedAcc ScopedExp arrs)
            -> (SmartExp a -> SmartExp b -> RootExp c)
            -> UnscopedAcc arrs1
            -> UnscopedAcc arrs2
            -> (ScopedAcc arrs, NodeCounts)
        travF2A2 c f acc1 acc2 = reconstruct (c f' acc1' acc2')
                                             (accCount1 +++ accCount2 +++ accCount3)
          where
            (f'   , accCount1) = scopesFun2 f
            (acc1', accCount2) = scopesAcc  acc1
            (acc2', accCount3) = scopesAcc  acc2

        travA :: HasCallStack
              => (ScopedAcc arrs' -> PreSmartAcc ScopedAcc ScopedExp arrs)
              -> UnscopedAcc arrs'
              -> (ScopedAcc arrs, NodeCounts)
        travA c acc = reconstruct (c acc') accCount
          where
            (acc', accCount) = scopesAcc acc

          -- Occurrence count of the currently processed node
        accOccCount = let StableNameHeight sn' _ = sn
                      in
                      lookupWithASTName accOccMap (StableASTName sn')

        -- Reconstruct the current tree node.
        --
        -- * If the current node is being shared ('accOccCount > 1'), replace it by a 'AvarSharing'
        --   node and float the shared subtree out wrapped in a 'NodeCounts' value.
        -- * If the current node is not shared, reconstruct it in place.
        -- * Special case for free variables ('Atag'): Replace the tree by a sharing variable and
        --   float the 'Atag' out in a 'NodeCounts' value.  This is independent of the number of
        --   occurrences.
        --
        -- In either case, any completed 'NodeCounts' are injected as bindings using 'AletSharing'
        -- node.
        --
        reconstruct
            :: HasCallStack
            => PreSmartAcc ScopedAcc ScopedExp arrs
            -> NodeCounts
            -> (ScopedAcc arrs, NodeCounts)
        reconstruct newAcc@(Atag tp _) _subCount
              -- free variable => replace by a sharing variable regardless of the number of
              -- occurrences
          = let thisCount = StableSharingAcc sn (AccSharing sn newAcc) `insertAccNode` noNodeCounts
            in
            tracePure "FREE" (show thisCount)
            (ScopedAcc [] (AvarSharing sn tp), thisCount)
        reconstruct newAcc subCount
              -- shared subtree => replace by a sharing variable (if 'recoverAccSharing' enabled)
          | accOccCount > 1 && acc_sharing `member` options config
          = let allCount = (StableSharingAcc sn sharingAcc `insertAccNode` newCount)
            in
            tracePure ("SHARED" ++ completed) (show allCount)
            (ScopedAcc [] (AvarSharing sn $ Smart.arraysR newAcc), allCount)
              -- neither shared nor free variable => leave it as it is
          | otherwise
          = tracePure ("Normal" ++ completed) (show newCount)
            (ScopedAcc [] sharingAcc, newCount)
          where
              -- Determine the bindings that need to be attached to the current node...
            (newCount, bindHere) = filterCompleted subCount

              -- ...and wrap them in 'AletSharing' constructors
            lets       = foldl (flip (.)) id . map (\x y -> AletSharing x (ScopedAcc [] y)) $ bindHere
            sharingAcc = lets $ AccSharing sn newAcc

              -- trace support
            completed | null bindHere = ""
                      | otherwise     = "(" ++ show (length bindHere) ++ " lets)"

        -- Extract *leading* nodes that have a complete node count (i.e., their node count is equal
        -- to the number of occurrences of that node in the overall expression).
        --
        -- Nodes with a completed node count should be let bound at the currently processed node.
        --
        -- NB: Only extract leading nodes (i.e., the longest run at the *front* of the list that is
        --     complete).  Otherwise, we would let-bind subterms before their parents, which leads
        --     scope errors.
        --
        filterCompleted :: NodeCounts -> (NodeCounts, [StableSharingAcc])
        filterCompleted (ns, graph)
          = let bindable     = map (isBindable bindable (map nodeName ns)) ns
                (bind, rest) = partition fst $ zip bindable ns
            in ((map snd rest, graph), [sa | AccNodeCount sa _ <- map snd bind])
          where
            -- a node is not yet complete while the node count 'n' is below the overall number
            -- of occurrences for that node in the whole program, with the exception that free
            -- variables are never complete
            isCompleted nc@(AccNodeCount sa n) | not . isFreeVar $ nc = lookupWithSharingAcc accOccMap sa == n
            isCompleted _                                             = False

            isBindable :: [Bool] -> [NodeName] -> NodeCount -> Bool
            isBindable bindable nodes nc@(AccNodeCount _ _) =
              let superTerms = Set.toList $ graph Map.! nodeName nc
                  unbound    = mapMaybe (`elemIndex` nodes) superTerms
              in    isCompleted nc
                 && all (bindable !!) unbound
            isBindable _ _ (ExpNodeCount _ _) = False
            -- isBindable _ _ (SeqNodeCount _ _) = False

    -- scopesSeq :: forall arrs. RootSeq arrs -> (ScopedSeq arrs, NodeCounts)
    -- scopesSeq = determineScopesSeq config accOccMap

    scopesExp
        :: HasCallStack
        => RootExp t
        -> (ScopedExp t, NodeCounts)
    scopesExp = determineScopesExp config accOccMap

    -- The lambda bound variable is at this point already irrelevant; for details, see
    -- Note [Traversing functions and side effects]
    --
    scopesAfun1
        :: HasCallStack
        => (SmartAcc a1 -> UnscopedAcc a2)
        -> (SmartAcc a1 -> ScopedAcc a2, NodeCounts)
    scopesAfun1 f = (const (ScopedAcc ssa body'), (counts', graph))
      where
        body@(UnscopedAcc fvs _)             = f undefined
        (ScopedAcc [] body', (counts,graph)) = scopesAcc body
        (freeCounts, counts')                = partition isBoundHere counts
        ssa                                  = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts]

        isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag _ i))) _) = i `elem` fvs
        isBoundHere _                                                               = False

    -- The lambda bound variable is at this point already irrelevant; for details, see
    -- Note [Traversing functions and side effects]
    --
    scopesFun1
        :: HasCallStack
        => (SmartExp e1 -> RootExp e2)
        -> (SmartExp e1 -> ScopedExp e2, NodeCounts)
    scopesFun1 f = (const body, counts)
      where
        (body, counts) = scopesExp (f undefined)

    -- The lambda bound variable is at this point already irrelevant; for details, see
    -- Note [Traversing functions and side effects]
    --
    scopesFun2
        :: HasCallStack
        => (SmartExp e1 -> SmartExp e2 -> RootExp e3)
        -> (SmartExp e1 -> SmartExp e2 -> ScopedExp e3, NodeCounts)
    scopesFun2 f = (\_ _ -> body, counts)
      where
        (body, counts) = scopesExp (f undefined undefined)

    -- The lambda bound variable is at this point already irrelevant; for details, see
    -- Note [Traversing functions and side effects]
    --
    scopesStencil1
        :: forall sh e1 e2 stencil. HasCallStack
        => UnscopedAcc (Array sh e1){-dummy-}
        -> (stencil -> RootExp e2)
        -> (stencil -> ScopedExp e2, NodeCounts)
    scopesStencil1 _ stencilFun = (const body, counts)
      where
        (body, counts) = scopesExp (stencilFun undefined)

    -- The lambda bound variable is at this point already irrelevant; for details, see
    -- Note [Traversing functions and side effects]
    --
    scopesStencil2
        :: forall sh e1 e2 e3 stencil1 stencil2. HasCallStack
        => UnscopedAcc (Array sh e1){-dummy-}
        -> UnscopedAcc (Array sh e2){-dummy-}
        -> (stencil1 -> stencil2 -> RootExp e3)
        -> (stencil1 -> stencil2 -> ScopedExp e3, NodeCounts)
    scopesStencil2 _ _ stencilFun = (\_ _ -> body, counts)
      where
        (body, counts) = scopesExp (stencilFun undefined undefined)

    scopesBoundary
        :: HasCallStack
        => PreBoundary UnscopedAcc RootExp t
        -> (PreBoundary ScopedAcc ScopedExp t, NodeCounts)
    scopesBoundary bndy =
      case bndy of
        Clamp      -> (Clamp, noNodeCounts)
        Mirror     -> (Mirror, noNodeCounts)
        Wrap       -> (Wrap, noNodeCounts)
        Constant v -> (Constant v, noNodeCounts)
        Function f -> let (body, counts) = scopesFun1 f
                      in  (Function body, counts)


determineScopesExp
    :: HasCallStack
    => Config
    -> OccMap SmartAcc
    -> RootExp t
    -> (ScopedExp t, NodeCounts)          -- Root (closed) expression plus Acc node counts
determineScopesExp config accOccMap (RootExp expOccMap exp@(UnscopedExp fvs _))
  = let
        (ScopedExp [] expWithScopes, (nodeCounts,graph)) = determineScopesSharingExp config accOccMap expOccMap exp
        (expCounts, accCounts)                           = partition isExpNodeCount nodeCounts

        isExpNodeCount ExpNodeCount{} = True
        isExpNodeCount _              = False
    in
    (ScopedExp (buildInitialEnvExp fvs [se | ExpNodeCount se _ <- expCounts]) expWithScopes, cleanCounts (accCounts,graph))


determineScopesSharingExp
    :: HasCallStack
    => Config
    -> OccMap SmartAcc
    -> OccMap SmartExp
    -> UnscopedExp t
    -> (ScopedExp t, NodeCounts)
determineScopesSharingExp config accOccMap expOccMap = scopesExp
  where
    scopesAcc
        :: HasCallStack
        => UnscopedAcc a
        -> (ScopedAcc a, NodeCounts)
    scopesAcc = determineScopesSharingAcc config accOccMap

    scopesFun1
        :: HasCallStack
        => (SmartExp a -> UnscopedExp b)
        -> (SmartExp a -> ScopedExp b, NodeCounts)
    scopesFun1 f = tracePure ("LAMBDA " ++ show ssa) (show counts) (const (ScopedExp ssa body'), (counts',graph))
      where
        body@(UnscopedExp fvs _)              = f undefined
        (ScopedExp [] body', (counts, graph)) = scopesExp body
        (freeCounts, counts')                 = partition isBoundHere counts
        ssa                                   = buildInitialEnvExp fvs [se | ExpNodeCount se _ <- freeCounts]

        isBoundHere (ExpNodeCount (StableSharingExp _ (ExpSharing _ (Tag _ i))) _) = i `elem` fvs
        isBoundHere _                                                              = False

    scopesExp
        :: forall t. HasCallStack
        => UnscopedExp t
        -> (ScopedExp t, NodeCounts)
    scopesExp (UnscopedExp _ (LetSharing _ _))
      = internalError "unexpected 'LetSharing'"

    scopesExp (UnscopedExp _ (VarSharing sn tp))
      = (ScopedExp [] (VarSharing sn tp), StableSharingExp sn (VarSharing sn tp) `insertExpNode` noNodeCounts)

    scopesExp (UnscopedExp _ (ExpSharing sn pexp))
      = case pexp of
          Tag tp i              -> reconstruct (Tag tp i) noNodeCounts
          Const tp c            -> reconstruct (Const tp c) noNodeCounts
          Undef tp              -> reconstruct (Undef tp) noNodeCounts
          Pair e1 e2            -> travE2 Pair e1 e2
          Nil                   -> reconstruct Nil noNodeCounts
          Prj i e               -> travE1 (Prj i) e
          VecPack   vec e       -> travE1 (VecPack   vec) e
          VecUnpack vec e       -> travE1 (VecUnpack vec) e
          ToIndex shr sh ix     -> travE2 (ToIndex shr) sh ix
          FromIndex shr sh e    -> travE2 (FromIndex shr) sh e
          Match t e             -> travE1 (Match t) e
          Case e rhs            -> let (e',   accCount1) = scopesExp e
                                       (rhs', accCount2) = unzip [ ((t,c'), counts)| (t,c) <- rhs, let (c', counts) = scopesExp c ]
                                    in reconstruct (Case e' rhs') (foldr (+++) accCount1 accCount2)
          Cond e1 e2 e3         -> travE3 Cond e1 e2 e3
          While tp p it i       -> let (p' , accCount1) = scopesFun1 p
                                       (it', accCount2) = scopesFun1 it
                                       (i' , accCount3) = scopesExp i
                                    in reconstruct (While tp p' it' i') (accCount1 +++ accCount2 +++ accCount3)
          PrimConst c           -> reconstruct (PrimConst c) noNodeCounts
          PrimApp p e           -> travE1 (PrimApp p) e
          Index tp a e          -> travAE (Index tp) a e
          LinearIndex tp a e    -> travAE (LinearIndex tp) a e
          Shape shr a           -> travA (Shape shr) a
          ShapeSize shr e       -> travE1 (ShapeSize shr) e
          Foreign tp ff f e     -> travE1 (Foreign tp ff f) e
          Coerce t1 t2 e        -> travE1 (Coerce t1 t2) e
      where
        travE1 :: HasCallStack
               => (ScopedExp a -> PreSmartExp ScopedAcc ScopedExp t)
               -> UnscopedExp a
               -> (ScopedExp t, NodeCounts)
        travE1 c e = reconstruct (c e') accCount
          where
            (e', accCount) = scopesExp e

        travE2 :: HasCallStack
               => (ScopedExp a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t)
               -> UnscopedExp a
               -> UnscopedExp b
               -> (ScopedExp t, NodeCounts)
        travE2 c e1 e2 = reconstruct (c e1' e2') (accCount1 +++ accCount2)
          where
            (e1', accCount1) = scopesExp e1
            (e2', accCount2) = scopesExp e2

        travE3 :: HasCallStack
               => (ScopedExp a -> ScopedExp b -> ScopedExp c -> PreSmartExp ScopedAcc ScopedExp t)
               -> UnscopedExp a
               -> UnscopedExp b
               -> UnscopedExp c
               -> (ScopedExp t, NodeCounts)
        travE3 c e1 e2 e3 = reconstruct (c e1' e2' e3') (accCount1 +++ accCount2 +++ accCount3)
          where
            (e1', accCount1) = scopesExp e1
            (e2', accCount2) = scopesExp e2
            (e3', accCount3) = scopesExp e3

        travA :: HasCallStack
              => (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t) -> UnscopedAcc a
              -> (ScopedExp t, NodeCounts)
        travA c acc = floatOutAcc c acc' accCount
          where
            (acc', accCount)  = scopesAcc acc

        travAE :: HasCallStack
               => (ScopedAcc a -> ScopedExp b -> PreSmartExp ScopedAcc ScopedExp t)
               -> UnscopedAcc a
               -> UnscopedExp b
               -> (ScopedExp t, NodeCounts)
        travAE c acc e = floatOutAcc (`c` e') acc' (accCountA +++ accCountE)
          where
            (acc', accCountA) = scopesAcc acc
            (e'  , accCountE) = scopesExp e

        floatOutAcc
            :: HasCallStack
            => (ScopedAcc a -> PreSmartExp ScopedAcc ScopedExp t)
            -> ScopedAcc a
            -> NodeCounts
            -> (ScopedExp t, NodeCounts)
        floatOutAcc c acc@(ScopedAcc _ (AvarSharing _ _)) accCount        -- nothing to float out
          = reconstruct (c acc) accCount
        floatOutAcc c acc accCount
          = reconstruct (c var) ((stableAcc `insertAccNode` noNodeCounts) +++ accCount)
          where
             (var, stableAcc) = abstract acc (\(ScopedAcc _ s) -> s)

        abstract
            :: HasCallStack
            => ScopedAcc a
            -> (ScopedAcc a -> SharingAcc ScopedAcc ScopedExp a)
            -> (ScopedAcc a, StableSharingAcc)
        abstract (ScopedAcc _   (AvarSharing _ _))     _    = internalError "AvarSharing"
        abstract (ScopedAcc ssa (AletSharing sa acc))  lets = abstract acc (lets . ScopedAcc ssa . AletSharing sa)
        abstract acc@(ScopedAcc ssa (AccSharing sn a)) lets = (ScopedAcc ssa (AvarSharing sn $ Smart.arraysR a), StableSharingAcc sn (lets acc))

        -- Occurrence count of the currently processed node
        expOccCount = let StableNameHeight sn' _ = sn
                       in lookupWithASTName expOccMap (StableASTName sn')

        -- Reconstruct the current tree node.
        --
        -- * If the current node is being shared ('expOccCount > 1'), replace it by a 'VarSharing'
        --   node and float the shared subtree out wrapped in a 'NodeCounts' value.
        -- * If the current node is not shared, reconstruct it in place.
        -- * Special case for free variables ('Tag'): Replace the tree by a sharing variable and
        --   float the 'Tag' out in a 'NodeCounts' value.  This is independent of the number of
        --   occurrences.
        --
        -- In either case, any completed 'NodeCounts' are injected as bindings using 'LetSharing'
        -- node.
        --
        reconstruct
            :: HasCallStack
            => PreSmartExp ScopedAcc ScopedExp t
            -> NodeCounts
            -> (ScopedExp t, NodeCounts)
        reconstruct newExp@(Tag tp _) _subCount
              -- free variable => replace by a sharing variable regardless of the number of
              -- occurrences
          = let thisCount = StableSharingExp sn (ExpSharing sn newExp) `insertExpNode` noNodeCounts
            in
            tracePure "FREE" (show thisCount)
            (ScopedExp [] (VarSharing sn tp), thisCount)
        reconstruct newExp subCount
              -- shared subtree => replace by a sharing variable (if 'recoverExpSharing' enabled)
          | expOccCount > 1 && exp_sharing `member` options config
          = let allCount = StableSharingExp sn sharingExp `insertExpNode` newCount
            in
            tracePure ("SHARED" ++ completed) (show allCount)
            (ScopedExp [] (VarSharing sn $ typeR newExp), allCount)
              -- neither shared nor free variable => leave it as it is
          | otherwise
          = tracePure ("Normal" ++ completed) (show newCount)
            (ScopedExp [] sharingExp, newCount)
          where
              -- Determine the bindings that need to be attached to the current node...
            (newCount, bindHere) = filterCompleted subCount

              -- ...and wrap them in 'LetSharing' constructors
            lets       = foldl (flip (.)) id . map (\x y -> LetSharing x (ScopedExp [] y)) $ bindHere
            sharingExp = lets $ ExpSharing sn newExp

              -- trace support
            completed | null bindHere = ""
                      | otherwise     = "(" ++ show (length bindHere) ++ " lets)"

        -- Extract *leading* nodes that have a complete node count (i.e., their node count is equal
        -- to the number of occurrences of that node in the overall expression).
        --
        -- Nodes with a completed node count should be let bound at the currently processed node.
        --
        -- NB: Only extract leading nodes (i.e., the longest run at the *front* of the list that is
        --     complete).  Otherwise, we would let-bind subterms before their parents, which leads
        --     scope errors.
        --
        filterCompleted :: HasCallStack => NodeCounts -> (NodeCounts, [StableSharingExp])
        filterCompleted (ns,graph)
          = let bindable       = map (isBindable bindable (map nodeName ns)) ns
                (bind, unbind) = partition fst $ zip bindable ns
            in ((map snd unbind, graph), [se | ExpNodeCount se _ <- map snd bind])
          where
            -- a node is not yet complete while the node count 'n' is below the overall number
            -- of occurrences for that node in the whole program, with the exception that free
            -- variables are never complete
            isCompleted nc@(ExpNodeCount sa n) | not . isFreeVar $ nc = lookupWithSharingExp expOccMap sa == n
            isCompleted _                                             = False

            isBindable :: [Bool] -> [NodeName] -> NodeCount -> Bool
            isBindable bindable nodes nc@(ExpNodeCount _ _) =
              let superTerms = Set.toList $ graph Map.! nodeName nc
                  unbound    = mapMaybe (`elemIndex` nodes) superTerms
              in    isCompleted nc
                 && all (bindable !!) unbound
            isBindable _ _ (AccNodeCount _ _) = False
            -- isBindable _ _ (SeqNodeCount _ _) = False

{--
determineScopesSeq
    :: Config
    -> OccMap Acc
    -> RootSeq t
    -> (ScopedSeq t, NodeCounts)          -- Root (closed) expression plus Acc node counts
determineScopesSeq config accOccMap (RootSeq seqOccMap seq)
  = let
        (ScopedSeq seqWithScopes, (nodeCounts,graph)) = determineScopesSharingSeq config accOccMap seqOccMap seq
        binds      = [s | SeqNodeCount s _ <- nodeCounts]
        lets       = foldl (flip (.)) id . map (\x y -> SletSharing x (ScopedSeq y)) $ binds
        sharingSeq = lets seqWithScopes
        newCounts  = filter (not . isSeqCount) nodeCounts
        isSeqCount SeqNodeCount{} = True
        isSeqCount _              = False
    in
    (ScopedSeq sharingSeq, cleanCounts (newCounts,graph))

determineScopesSharingSeq
  :: Config
  -> OccMap Acc
  -> OccMap Seq
  -> UnscopedSeq t
  -> (ScopedSeq t, NodeCounts)
determineScopesSharingSeq config accOccMap _seqOccMap = scopesSeq
  where
    scopesAcc :: UnscopedAcc a -> (ScopedAcc a, NodeCounts)
    scopesAcc = determineScopesSharingAcc config accOccMap

    scopesExp :: RootExp t -> (ScopedExp t, NodeCounts)
    scopesExp = determineScopesExp config accOccMap

    scopesFun2 :: (Elt e1, Elt e2)
               => (Exp e1 -> Exp e2 -> RootExp e3)
               -> (Exp e1 -> Exp e2 -> ScopedExp e3, NodeCounts)
    scopesFun2 f = (\_ _ -> body, counts)
      where
        (body, counts) = scopesExp (f undefined undefined)

    -- The lambda bound variable is at this point already irrelevant; for details, see
    -- Note [Traversing functions and side effects]
    --
    scopesAfun1 :: Arrays a1 => (Acc a1 -> UnscopedAcc a2) -> (Acc a1 -> ScopedAcc a2, NodeCounts)
    scopesAfun1 f = (const (ScopedAcc ssa body'), (counts',graph))
      where
        body@(UnscopedAcc fvs _) = f undefined
        ((ScopedAcc [] body'), (counts,graph)) = scopesAcc body
        ssa     = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts]
        (freeCounts, counts') = partition isBoundHere counts

        isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag i))) _) = i `elem` fvs
        isBoundHere _                                                             = False

    scopesAfun2 :: (Arrays a1, Arrays a2) => (Acc a1 -> Acc a2 -> UnscopedAcc a3) -> (Acc a1 -> Acc a2 -> ScopedAcc a3, NodeCounts)
    scopesAfun2 f = (\ _ _ -> (ScopedAcc ssa body'), (counts',graph))
      where
        body@(UnscopedAcc fvs _) = f undefined undefined
        ((ScopedAcc [] body'), (counts,graph)) = scopesAcc body
        ssa     = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts]
        (freeCounts, counts') = partition isBoundHere counts

        isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag i))) _) = i `elem` fvs
        isBoundHere _                                                             = False

    scopesAfun3 :: (Arrays a1, Arrays a2, Arrays a3) => (Acc a1 -> Acc a2 -> Acc a3 -> UnscopedAcc a4) -> (Acc a1 -> Acc a2 -> Acc a3 -> ScopedAcc a4, NodeCounts)
    scopesAfun3 f = (\ _ _ _ -> (ScopedAcc ssa body'), (counts',graph))
      where
        body@(UnscopedAcc fvs _) = f undefined undefined undefined
        ((ScopedAcc [] body'), (counts,graph)) = scopesAcc body
        ssa     = buildInitialEnvAcc fvs [sa | AccNodeCount sa _ <- freeCounts]
        (freeCounts, counts') = partition isBoundHere counts

        isBoundHere (AccNodeCount (StableSharingAcc _ (AccSharing _ (Atag i))) _) = i `elem` fvs
        isBoundHere _                                                             = False

    scopesTup :: Atuple UnscopedSeq tup -> (Atuple ScopedSeq tup, NodeCounts)
    scopesTup NilAtup          = (NilAtup, noNodeCounts)
    scopesTup (SnocAtup tup s) = let
                                   (tup', accCountT) = scopesTup tup
                                   (s'  , accCountS) = scopesSeq s
                                 in
                                 (SnocAtup tup' s', accCountT +++ accCountS)

    scopesSeq :: forall t. UnscopedSeq t -> (ScopedSeq t, NodeCounts)
    scopesSeq (UnscopedSeq (SletSharing _ _))
      = $internalError "determineScopesSharingSeq: scopesSeq" "unexpected 'LetSharing'"
    scopesSeq (UnscopedSeq (SvarSharing sn))
      = (ScopedSeq (SvarSharing sn), StableSharingSeq sn (SvarSharing sn) `insertSeqNode` noNodeCounts)

    scopesSeq (UnscopedSeq (SeqSharing sn s)) =
      case s of
        StreamIn arrs -> producer (StreamIn arrs) noNodeCounts
        ToSeq sl acc   -> let
                            (acc', accCount1) = scopesAcc acc
                          in producer (ToSeq sl acc') accCount1
        MapSeq     afun s'  -> let
                                 (afun', accCount1) = scopesAfun1 afun
                                 (s''  , accCount2) = scopesSeq s'
                               in producer (MapSeq afun' s'') (accCount1 +++ accCount2)
        ZipWithSeq afun s1 s2 -> let
                                   (afun', accCount1) = scopesAfun2 afun
                                   (s1'  , accCount2) = scopesSeq s1
                                   (s2'  , accCount3) = scopesSeq s2
                                 in producer (ZipWithSeq afun' s1' s2') (accCount1 +++ accCount2 +++ accCount3)
        ScanSeq fun e s' -> let
                              (fun', accCount1) = scopesFun2 fun
                              (e'  , accCount2) = scopesExp e
                              (s'' , accCount3) = scopesSeq s'
                            in producer (ScanSeq fun' e' s'') (accCount1 +++ accCount2 +++ accCount3)
        FoldSeq fun e s' -> let
                              (fun', accCount1) = scopesFun2 fun
                              (e'  , accCount2) = scopesExp e
                              (s'' , accCount3) = scopesSeq s'
                            in consumer (FoldSeq fun' e' s'') (accCount1 +++ accCount2 +++ accCount3)
        FoldSeqFlatten afun acc s' ->
                               let
                                 (afun', accCount1) = scopesAfun3 afun
                                 (acc' , accCount2) = scopesAcc acc
                                 (s''  , accCount3) = scopesSeq s'
                               in consumer (FoldSeqFlatten afun' acc' s'') (accCount1 +++ accCount2 +++ accCount3)
        Stuple tup          -> let
                                 (tup', accCount1) = scopesTup tup
                               in consumer (Stuple tup') accCount1
      where
        -- All producers must be replaced by sharing variables
        --
        producer :: (t ~ [a], Arrays a)
                 => PreSeq ScopedAcc ScopedSeq ScopedExp t
                 -> NodeCounts
                 -> (ScopedSeq t, NodeCounts)
        producer newSeq subCount
          = let allCount = StableSharingSeq sn (SeqSharing sn newSeq) `insertSeqNode` subCount
            in
            tracePure "Producer" (show allCount)
            (ScopedSeq (SvarSharing sn), allCount)

        -- Consumers cannot be shared.
        --
        consumer :: PreSeq ScopedAcc ScopedSeq ScopedExp t
                 -> NodeCounts
                 -> (ScopedSeq t, NodeCounts)
        consumer newSeq subCount
          = tracePure "Consumer" (show subCount)
            (ScopedSeq (SeqSharing sn newSeq), subCount)
--}

-- |Recover sharing information and annotate the HOAS AST with variable and let binding
-- annotations.  The first argument determines whether array computations are floated out of
-- expressions irrespective of whether they are shared or not — 'True' implies floating them out.
--
-- Also returns the 'StableSharingAcc's of all 'Atag' leaves in environment order — they represent
-- the free variables of the AST.
--
-- NB: Strictly speaking, this function is not deterministic, as it uses stable pointers to
--     determine the sharing of subterms.  The stable pointer API does not guarantee its
--     completeness; i.e., it may miss some equalities, which implies that we may fail to discover
--     some sharing.  However, sharing does not affect the denotational meaning of an array
--     computation; hence, we do not compromise denotational correctness.
--
--     There is one caveat: We currently rely on the 'Atag' and 'Tag' leaves representing free
--     variables to be shared if any of them is used more than once.  If one is duplicated, the
--     environment for de Bruijn conversion will have a duplicate entry, and hence, be of the wrong
--     size, which is fatal. (The 'buildInitialEnv*' functions will already bail out.)
--
{-# NOINLINE recoverSharingAcc #-}
recoverSharingAcc
    :: HasCallStack
    => Config
    -> Level            -- The level of currently bound array variables
    -> [Level]          -- The tags of newly introduced free array variables
    -> SmartAcc a
    -> (ScopedAcc a, [StableSharingAcc])
recoverSharingAcc config alvl avars acc
  = let (acc', occMap)
          = unsafePerformIO             -- to enable stable pointers; this is safe as explained above
          $ makeOccMapAcc config alvl acc
    in
    determineScopesAcc config avars occMap acc'


{-# NOINLINE recoverSharingExp #-}
recoverSharingExp
    :: HasCallStack
    => Config
    -> Level            -- The level of currently bound scalar variables
    -> [Level]          -- The tags of newly introduced free scalar variables
    -> SmartExp e
    -> (ScopedExp e, [StableSharingExp])
recoverSharingExp config lvl fvar exp
  = let
        (rootExp, accOccMap) = unsafePerformIO $ do
          accOccMap       <- newASTHashTable
          (exp', _)       <- makeOccMapRootExp config accOccMap lvl fvar exp
          frozenAccOccMap <- freezeOccMap accOccMap

          return (exp', frozenAccOccMap)

        (ScopedExp sse sharingExp, _) =
          determineScopesExp config accOccMap rootExp
    in
    (ScopedExp [] sharingExp, sse)


{--
{-# NOINLINE recoverSharingSeq #-}
recoverSharingSeq
    :: Config
    -> Seq e
    -> (ScopedSeq e, [StableSharingSeq])
recoverSharingSeq config seq
  = let
        (rootSeq, accOccMap) = unsafePerformIO $ do
          accOccMap       <- newASTHashTable
          (seq', _)       <- makeOccMapRootSeq config accOccMap 0 seq
          frozenAccOccMap <- freezeOccMap accOccMap

          return (seq', frozenAccOccMap)

        (ScopedSeq sharingSeq, (ns, _)) =
          determineScopesSeq config accOccMap rootSeq
    in
    (ScopedSeq sharingSeq, [a | SeqNodeCount a _ <- ns])
--}


-- Debugging
-- ---------

traceLine :: String -> String -> IO ()
traceLine header msg
  = Debug.traceIO Debug.dump_sharing
  $ header ++ ": " ++ msg

traceChunk :: String -> String -> IO ()
traceChunk header msg
  = Debug.traceIO Debug.dump_sharing
  $ header ++ "\n      " ++ msg

tracePure :: String -> String -> a -> a
tracePure header msg
  = Debug.trace Debug.dump_sharing
  $ header ++ ": " ++ msg