{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE GHCForeignImportPrim #-}
{-# LANGUAGE UnliftedFFITypes #-}

{- |
   Module      : HeapSize
   Copyright   : (c) Michail Pardalos
   License     : 3-Clause BSD-style
   Maintainer  : mpardalos@gmail.com

   Based on GHC.Datasize by Dennis Felsing
 -}
module HeapSize (
  recursiveSize,
  recursiveSizeNoGC,
  recursiveSizeNF,
  closureSize
  )
  where

import Control.DeepSeq (NFData, force)

import GHC.Exts hiding (closureSize#)
import GHC.Arr
import GHC.Exts.Heap hiding (size)
import qualified Data.HashSet as H
import Data.IORef
import Data.Hashable

import Control.Monad

import System.Mem

foreign import prim "aToWordzh" aToWord# :: Any -> Word#
foreign import prim "unpackClosurePtrs" unpackClosurePtrs# :: Any -> Array# b
foreign import prim "closureSize" closureSize# :: Any -> Int#

-- | Get the /non-recursive/ size of a closure, in words.
--
--   The result is forced before it is returned. The heap is therefore read
--   while this action runs, not later when the caller happens to demand the
--   'Int'. By then the closure may have been evaluated or updated, and would
--   report a different size.
--
--   The argument is handed to the primitive as-is (via 'unsafeCoerce#'). This
--   function does not evaluate it first.
closureSize :: a -> IO Int
closureSize x = return $! I# (closureSize# (unsafeCoerce# x))

-- | Wrap each pointer returned by the @unpackClosurePtrs@ primitive for the
--   given closure in a 'Box', and collect them in an array indexed from 0.
--   NOTE(review): presumably these are the closure's pointer fields. The
--   primitive is defined outside this module.
getClosures :: a -> IO (Array Int Box)
getClosures x =
  case unpackClosurePtrs# (unsafeCoerce# x) of
    ptrs ->
      let count = I# (sizeofArray# ptrs)
          raw :: Array Int Any
          raw = Array 0 (count - 1) count ptrs
       in return (Box <$> raw)

-- | Calculate the recursive size of a GHC object: the closure itself plus
--   everything reachable from it. The closures actually present in memory
--   are counted, so shared values are only counted once.
--
--   The result is the sum of 'closureSize' over every reachable closure. It
--   is therefore in the same unit as 'closureSize' (machine words, not
--   bytes). Multiply by the word size (e.g. 8 on a 64-bit platform) to get
--   bytes.
--
--   The argument is not evaluated by this function. Call with
--   @
--    recursiveSize $! 2
--   @
--   to force evaluation to WHNF before calculating the size.
--
--   Call with
--   @
--    recursiveSize $!! \"foobar\"
--   @
--   ($!! from Control.DeepSeq) to force full evaluation before calculating the
--   size.
--
--   A garbage collection is performed before the size is calculated, because
--   the garbage collector would make heap walks difficult. Visited closures
--   are identified by their address, and a collection can move them.
--
--   This function works very quickly on small data structures, but can be slow
--   on large and complex ones. If speed is an issue it's probably possible to
--   get the exact size of a small portion of the data structure and then
--   estimate the total size from that.
recursiveSize :: a -> IO Int
recursiveSize x = performGC >> recursiveSizeNoGC x

-- | Same as 'recursiveSize' except without performing garbage collection
--   first. This is useful when measuring many objects in sequence: call
--   'performGC' once up front, then use this function to avoid repeated,
--   unnecessary garbage collections.
--
--   The result is the sum of 'closureSize' over every closure reachable from
--   the argument, so it is in the same unit as 'closureSize' (words). Each
--   closure is counted once, keyed by its address (see 'HashableBox').
recursiveSizeNoGC :: a -> IO Int
recursiveSizeNoGC x = do
  state <- newIORef (0, H.empty)
  go state (asBox x)
  -- The running total is forced on every update, so this is an evaluated Int.
  (total, _) <- readIORef state
  pure total
  where
    -- Depth-first walk. Count this closure unless it has already been seen,
    -- then recurse into the pointers returned by 'getClosures'.
    go :: IORef (Int, H.HashSet HashableBox) -> Box -> IO ()
    go state b@(Box y) = do
      let key = HashableBox b
      (_, closuresSeen) <- readIORef state
      unless (H.member key closuresSeen) $ do
        thisSize <- closureSize y
        -- Strict update with a strict sum: the total stays evaluated and the
        -- size of this closure is read now. A lazy update would build one
        -- (+) thunk per closure and postpone every size read until the
        -- caller demands the result.
        modifyIORef' state $ \(size, seen) ->
          let !size' = size + thisSize
           in (size', H.insert key seen)
        mapM_ (go state) =<< getClosures y

-- | Calculate the recursive size of a GHC object after fully evaluating it
--   with 'Control.DeepSeq.force', so that the data structure is measured in
--   normal form. Using this function requires an 'NFData' instance.
--
--   Like 'recursiveSize', the result is the sum of 'closureSize' over the
--   reachable closures, i.e. in words.
recursiveSizeNF :: NFData a => a -> IO Int
-- 'recursiveSize' does not evaluate its argument. Plain @recursiveSize . force@
-- would therefore measure the unevaluated @force x@ thunk, not the normal
-- form. '$!' runs 'force' (to NF) before the walk starts.
recursiveSizeNF x = recursiveSize $! force x

-- | A 'Box' wrapper used as the key type of the visited-closure set. Its
--   'Eq' and 'Hashable' instances in this module compare and hash the
--   address of the wrapped closure (read with 'aToWord#'), not its contents.
newtype HashableBox = HashableBox Box
  deriving newtype (Show)

-- | Pointer equality: two boxes are equal exactly when the addresses of the
--   closures they wrap, as read by 'aToWord#', are equal.
instance Eq HashableBox where
    HashableBox (Box p) == HashableBox (Box q) = address p == address q
      where
        address :: Any -> Word
        address c = W# (aToWord# c)

-- | Pointer hash: mixes the address of the wrapped closure (as read by
--   'aToWord#') into the salt. Boxes that are equal under the pointer
--   equality instance therefore hash alike.
instance Hashable HashableBox where
    hashWithSalt salt (HashableBox (Box c)) = salt `hashWithSalt` address
      where
        address = W# (aToWord# c)