{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE LambdaCase #-}

module GHC.StgToJS.Sinker (sinkPgm) where

import GHC.Prelude
import GHC.Types.Unique.Set
import GHC.Types.Unique.FM
import GHC.Types.Var.Set
import GHC.Stg.Syntax
import GHC.Types.Id
import GHC.Types.Name
import GHC.Unit.Module
import GHC.Types.Literal
import GHC.Data.Graph.Directed

import GHC.StgToJS.CoreUtils

import Data.Char
import Data.Either
import Data.List (partition)
import Data.Maybe


-- | Unfloat some unexported top-level bindings.
--
-- GHC floats constants to the top level. That is harmless for native code,
-- but in JavaScript every top-level binding occupies a global variable name.
-- Some unexported bindings can therefore be sunk back into their use sites:
--
-- - constructor applications, as long as they are referenced exactly once, by
--   another top-level constructor application, and are not part of a
--   recursive binding group
-- - literals (small literals may be sunk even when they are used more than
--   once)
--
-- Returns the map from each sunk binder to the expression replacing it,
-- together with the remaining top-level bindings. Only 'StgTopLifted'
-- bindings are considered; every other top-level binding (the string
-- literals) is passed through unchanged and placed after the lifted ones.
sinkPgm :: Module
        -> [CgStgTopBinding]
        -> (UniqFM Id CgStgExpr, [CgStgTopBinding])
sinkPgm m pgm = (sunk, lifted ++ others)
  where
    -- Split ordinary lifted bindings from everything else.
    (liftedBinds, others) = partitionEithers (map classify pgm)

    classify :: CgStgTopBinding -> Either CgStgBinding CgStgTopBinding
    classify top = case top of
      StgTopLifted bind -> Left bind
      _                 -> Right top

    (sunk, kept) = sinkPgm' m liftedBinds
    lifted       = map StgTopLifted kept

-- | Worker for 'sinkPgm' operating on the lifted bindings only.
--
-- A binding is sunk when 'alwaysSinkable' accepts it, or when 'onceSinkable'
-- accepts it and its binder is in 'collectUsedOnce'. Sunk bindings are removed
-- from the returned list; their replacements are returned in the map.
sinkPgm'
  :: Module
       -- ^ the module being compiled; passed on to 'onceSinkable' and
       -- 'topSortDecls', which currently ignore it
  -> [CgStgBinding]
       -- ^ the lifted top-level bindings
  -> (UniqFM Id CgStgExpr, [CgStgBinding])
       -- ^ a map from each sunk binder to the expression replacing it (for
       -- where the replacement does not fit in the 'StgBinding' AST), and the
       -- remaining bindings, ordered by 'topSortDecls'
sinkPgm' m pgm = (sinkables, remaining)
  where
    -- Binders referenced exactly once, and only as an argument of a
    -- top-level constructor application.
    usedOnce :: IdSet
    usedOnce = collectUsedOnce pgm

    -- Candidates that may be sunk only when they have a single such use.
    onceCandidates :: [(Id, CgStgExpr)]
    onceCandidates =
      [ cand
      | cand@(b, _) <- concatMap (onceSinkable m) pgm
      , b `elementOfUniqSet` usedOnce
      ]

    sinkables :: UniqFM Id CgStgExpr
    sinkables = listToUFM (concatMap alwaysSinkable pgm ++ onceCandidates)

    isSunk :: CgStgBinding -> Bool
    isSunk bind = case bind of
      StgNonRec b _ -> b `elemUFM` sinkables
      _             -> False

    remaining :: [CgStgBinding]
    remaining = filter (not . isSunk) (topSortDecls m pgm)

-- | Always-sinkable bindings: values that may be duplicated in the generated
-- code (e.g. small literals).
--
-- Only a non-recursive binding whose binder is 'isLocal' qualifies, and only
-- when its right-hand side is either a closure whose body is a small literal,
-- or an application of an unboxable constructor to a single small literal.
-- The result pairs the binder with the expression that replaces it.
alwaysSinkable :: CgStgBinding -> [(Id, CgStgExpr)]
alwaysSinkable (StgNonRec b rhs)
  | isLocal b
  , Just e <- duplicable rhs
  = [(b, e)]
  where
    duplicable :: CgStgRhs -> Maybe CgStgExpr
    duplicable r = case r of
      StgRhsClosure _ _ _ _ e@(StgLit l)
        | isSmallSinkableLit l
        -> Just e
      StgRhsCon _ccs dc cnum _ticks args@[StgLitArg l]
        | isSmallSinkableLit l
        , isUnboxableCon dc
        -> Just (StgConApp dc cnum args [])
      _ -> Nothing
alwaysSinkable _ = []

-- | Whether a literal is small enough to be duplicated at each use site:
-- characters whose code point is below 100000, and numeric literals whose
-- absolute value is below 100000. Every other literal is rejected.
isSmallSinkableLit :: Literal -> Bool
isSmallSinkableLit = \case
  LitChar c     -> ord c < 100000
  LitNumber _ n -> abs n < 100000
  _             -> False


-- | Once-sinkable bindings: they may be sunk, but duplicating them is not
-- acceptable, so the caller must ensure there is a single use site.
--
-- A non-recursive binding whose binder is 'isLocal' qualifies when its
-- right-hand side is a constructor application (turned into a 'StgConApp')
-- or a closure whose body is a literal. The module argument is unused.
onceSinkable :: Module -> CgStgBinding -> [(Id, CgStgExpr)]
onceSinkable _m bind = case bind of
  StgNonRec b (StgRhsCon _ccs dc cnum _ticks args)
    | isLocal b -> [(b, StgConApp dc cnum args [])]
  StgNonRec b (StgRhsClosure _ _ _ _ e@StgLit{})
    | isLocal b -> [(b, e)]
  _ -> []

-- | Collect the identifiers that occur exactly once as an argument of a
-- top-level constructor application (see 'collectArgsTop') and exactly once
-- among all occurrences in the program (see 'collectArgs'), i.e. identifiers
-- used in that single top-level argument position and nowhere else.
collectUsedOnce :: [CgStgBinding] -> IdSet
collectUsedOnce binds =
  singleUses (concatMap collectArgs binds)
    `intersectUniqSets` singleUses (concatMap collectArgsTop binds)
  where
    -- Identifiers that occur exactly once in the given list.
    singleUses :: [Id] -> IdSet
    singleUses ids = fst (foldr classify (emptyUniqSet, emptyUniqSet) ids)

    -- Accumulator: (seen exactly once so far, seen more than once).
    classify :: Id -> (IdSet, IdSet) -> (IdSet, IdSet)
    classify i acc@(once, many)
      | i `elementOfUniqSet` many = acc
      | i `elementOfUniqSet` once =
          (delOneFromUniqSet once i, addOneToUniqSet many i)
      | otherwise                 = (addOneToUniqSet once i, many)

-- | Identifiers passed as arguments to top-level constructor applications
-- ('StgRhsCon') in a binding group; closure right-hand sides contribute
-- nothing (see 'collectArgsTopRhs').
collectArgsTop :: CgStgBinding -> [Id]
collectArgsTop bind = concatMap collectArgsTopRhs (rhssOf bind)
  where
    rhssOf :: CgStgBinding -> [CgStgRhs]
    rhssOf (StgNonRec _ rhs) = [rhs]
    rhssOf (StgRec pairs)    = map snd pairs

-- | Variable arguments of a constructor right-hand side; a closure yields
-- no identifiers.
collectArgsTopRhs :: CgStgRhs -> [Id]
collectArgsTopRhs rhs = case rhs of
  StgRhsCon _ccs _dc _mu _ticks args -> [ i | StgVarArg i <- args ]
  StgRhsClosure {}                   -> []

-- | Every identifier occurrence in a binding group: each variable 'StgArg'
-- and the function of each 'StgApp', anywhere in the right-hand sides
-- (including nested lets and case alternatives). Duplicates are kept, one
-- entry per occurrence.
collectArgs :: CgStgBinding -> [Id]
collectArgs bind = case bind of
  StgNonRec _ rhs -> collectArgsR rhs
  StgRec pairs    -> concat [ collectArgsR rhs | (_, rhs) <- pairs ]

-- | Identifier occurrences in a right-hand side: the variable arguments of a
-- constructor application, or all occurrences in a closure body.
collectArgsR :: CgStgRhs -> [Id]
collectArgsR rhs = case rhs of
  StgRhsCon _ccs _con _mu _ticks args -> concatMap collectArgsA args
  StgRhsClosure _ _ _ _ body          -> collectArgsE body

-- | Identifier occurrences in the right-hand side of a case alternative.
collectArgsAlt :: CgStgAlt -> [Id]
collectArgsAlt = collectArgsE . alt_rhs

-- | Identifier occurrences in an expression: applied functions, variable
-- arguments of applications, and everything inside case scrutinees and
-- alternatives, let-bound right-hand sides, let bodies and ticked
-- expressions. Literals contribute nothing.
collectArgsE :: CgStgExpr -> [Id]
collectArgsE expr = case expr of
  StgApp f args              -> f : argIds args
  StgConApp _ _ args _       -> argIds args
  StgOpApp _ args _          -> argIds args
  StgCase scrut _ _ alts     -> collectArgsE scrut ++ concatMap collectArgsAlt alts
  StgLet _ bind body         -> collectArgs bind ++ collectArgsE body
  StgLetNoEscape _ bind body -> collectArgs bind ++ collectArgsE body
  StgTick _ body             -> collectArgsE body
  StgLit _                   -> []
  where
    argIds :: [StgArg] -> [Id]
    argIds = concatMap collectArgsA

-- | The identifier carried by an argument, if any; literals yield nothing.
collectArgsA :: StgArg -> [Id]
collectArgsA (StgVarArg i) = [i]
collectArgsA (StgLitArg _) = []

-- | An identifier is local when its name has no defining module
-- ('nameModule_maybe' gives 'Nothing') and it is not exported.
isLocal :: Id -> Bool
isLocal i = not (hasModule || isExportedId i)
  where
    hasModule = isJust (nameModule_maybe (idName i))

-- | Order the bindings for sequential initialisation.
--
-- Recursive groups come first, in their original relative order. They are
-- followed by the non-recursive bindings, topologically sorted over edges
-- that point from a referenced binder to the constructor binding that
-- references it. As a result, a binding is initialised after the bindings its
-- constructor arguments mention. The module argument is unused.
--
-- Calls 'error' if the non-recursive bindings form a dependency cycle.
topSortDecls :: Module -> [CgStgBinding] -> [CgStgBinding]
topSortDecls _m binds = rest ++ sortedNonRec
  where
    (nonRec, rest) = partition isNonRec binds

    isNonRec :: CgStgBinding -> Bool
    isNonRec StgNonRec{} = True
    isNonRec _           = False

    -- One graph node per non-recursive binding, keyed by its binder. Matching
    -- inside the comprehension keeps this total: no unreachable error branch
    -- for recursive groups, which 'partition' has already removed.
    nodes :: [Node Id CgStgBinding]
    nodes = [ DigraphNode bind b [] | bind@(StgNonRec b _) <- nonRec ]

    nonRecIds :: IdSet
    nonRecIds = mkUniqSet (map node_key nodes)

    -- Edge (i, b): the constructor application bound to b mentions i, and
    -- i is itself one of the non-recursive binders being sorted.
    edges :: [(Id, Id)]
    edges =
      [ (i, b)
      | StgNonRec b (StgRhsCon _cc _dc _cnum _ticks args) <- nonRec
      , StgVarArg i <- args
      , i `elementOfUniqSet` nonRecIds
      ]

    graph :: Graph (Node Id CgStgBinding)
    graph = graphFromVerticesAndAdjacency nodes edges

    sortedNonRec :: [CgStgBinding]
    sortedNonRec
      | any isCyclic (stronglyConnCompG graph)
      = error "topSortDecls: unexpected cycle"
      | otherwise
      = map node_payload (topologicalSortG graph)

    isCyclic (CyclicSCC _) = True
    isCyclic _             = False