{-# LANGUAGE CPP #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE OverloadedStrings #-}

module Torch.Internal.GC where

import Control.Concurrent (threadDelay)
import Control.Concurrent.Async
import Control.Exception.Safe (Exception, MonadThrow, Typeable, catch, throwIO, throwM)
import Control.Monad (when)
import Data.List (isPrefixOf)
import Foreign.C.Types
import GHC.ExecutionStack
import Language.C.Inline.Cpp.Exception
import System.Environment (lookupEnv)
import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe (unsafePerformIO)
import System.Mem (performGC)
import System.SysInfo
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Data.Text as T
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B


-- | Raw FFI binding to @showWeakPtrList@ from @hasktorch_finalizer.h@.
-- 'dumpLibtorchObjects' wraps it, passing its @age@ threshold through as the
-- 'CInt' argument. Imported @unsafe@, so the call must not call back into
-- Haskell.
foreign import ccall unsafe "hasktorch_finalizer.h showWeakPtrList"
  c_showWeakPtrList :: CInt -> IO ()

-- | Ask the C allocator to release free heap memory back to the OS by calling
-- glibc's @malloc_trim(pad)@. @pad@ is the number of bytes to leave
-- untrimmed at the top of the heap and should be non-negative. The @int@
-- status returned by glibc is discarded.
--
-- @malloc_trim@ is a glibc function and does not exist on macOS. When
-- @ENABLE_DUMMY_MALLOC_TRIM@ is defined, this is a no-op.
mallocTrim :: CInt -> IO ()
#ifdef ENABLE_DUMMY_MALLOC_TRIM
mallocTrim _ = return ()
#else
mallocTrim pad = c_mallocTrim (fromIntegral pad) >> return ()

-- glibc declares @int malloc_trim(size_t pad)@. Import it with exactly those
-- C types so the foreign call matches the platform ABI. The previous
-- @CInt -> IO ()@ import passed a 32-bit value where a @size_t@ is expected.
foreign import ccall unsafe "malloc.h malloc_trim"
  c_mallocTrim :: CSize -> IO CInt
#endif

-- | Dump the libtorch objects tracked on the C side, by delegating to
-- @showWeakPtrList@ from @hasktorch_finalizer.h@.
--
-- According to the original notes, every call increases each tracked
-- object's age by one. Only objects whose age is greater than or equal to
-- the given threshold are dumped. That bookkeeping happens inside the C
-- helper and is not visible from Haskell.
dumpLibtorchObjects ::
  -- | minimum age of the objects to dump
  Int ->
  -- | output is produced by the C helper
  IO ()
dumpLibtorchObjects = c_showWeakPtrList . fromIntegral

-- | A libtorch (C++) failure re-raised on the Haskell side with a
-- human-readable message. 'unsafeThrowableIO' throws it for C++
-- @std::exception@s.
newtype HasktorchException
  = HasktorchException String
  deriving (Show)

instance Exception HasktorchException

-- | Decode UTF-8 bytes into a 'String'. Decoding is lenient: malformed
-- byte sequences are replaced with U+FFFD instead of raising an error.
bsToChars :: ByteString -> String
bsToChars bytes = T.unpack (T.decodeUtf8With T.lenientDecode bytes)

-- | Run an 'IO' action that may raise a C++ exception and lift its outcome
-- into any 'MonadThrow' @m@.
--
-- * Success: the result is returned with 'pure'.
-- * A C++ @std::exception@ ('CppStdException'): it is rethrown in @m@ via
--   'throwM' as a 'HasktorchException'. The message is @"Exception: "@
--   followed by the leniently UTF-8-decoded C++ message.
-- * Any other 'CppException' constructor: it is rethrown in @m@ unchanged.
--   Previously an incomplete lambda pattern turned these into a
--   pattern-match failure.
--
-- NOTE(review): this relies on 'unsafePerformIO'. The action runs when the
-- returned value is forced, not at a point fixed by any surrounding monadic
-- sequence. Only use it for actions whose effects are safe to run at an
-- unspecified time (i.e. effectively pure computations).
unsafeThrowableIO :: forall a m. MonadThrow m => IO a -> m a
unsafeThrowableIO a = unsafePerformIO $ (pure <$> a) `catch` handler
  where
    handler :: CppException -> IO (m a)
    handler (CppStdException _ msg _) =
      pure . throwM $ HasktorchException ("Exception: " <> bsToChars msg)
    handler other = pure (throwM other)

-- | Run an action. If it raises a C++ @std::exception@, write diagnostics to
-- stderr and then rethrow the same exception.
--
-- The diagnostics are the Haskell execution stack ("Cannot show stacktrace"
-- if unavailable) followed by the raw C++ message. They are printed unless
-- the @HASKTORCH_DEBUG@ environment variable is exactly @"0"@.
--
-- Any other 'CppException' constructor is rethrown unchanged without
-- diagnostics. Previously an incomplete lambda pattern turned these into a
-- pattern-match failure.
prettyException :: IO a -> IO a
prettyException func =
  func `catch` \e -> case e of
    CppStdException _ message _ -> do
      flag <- lookupEnv "HASKTORCH_DEBUG"
      when (flag /= Just "0") $ do
        mst <- showStackTrace
        case mst of
          Just st -> hPutStrLn stderr st
          Nothing -> hPutStrLn stderr "Cannot show stacktrace"
        B.hPutStr stderr message
      throwIO e
    _ -> throwIO e
{-# INLINE prettyException #-}

-- | Run an action, retrying it after a garbage collection whenever it fails
-- with a backend out-of-memory C++ exception.
--
-- When a 'CppStdException' message starts with the platform's OOM marker,
-- this runs 'performGC', @malloc_trim(0)@ and a 1 ms delay, then retries.
-- It retries at most @count@ more times. Once the budget is exhausted, it
-- throws an 'IOError' ('userError') whose text is
-- @"Too many calls to performGC, "@ followed by the C++ message. Every other
-- exception, including non-std 'CppException's, is rethrown unchanged.
--
-- The retry is issued after the handler has returned, not from inside it.
-- As a result, @func@ is never re-run with asynchronous exceptions masked,
-- and handler frames do not nest across retries.
retryWithGC' :: Int -> IO a -> IO a
retryWithGC' count func = do
  outcome <- (Right <$> func) `catch` \(e :: CppException) -> pure (Left e)
  case outcome of
    Right result -> pure result
    Left (CppStdException _ message _)
      | isOutOfMemory message ->
          if count <= 0
            then
              throwIO . userError . bsToChars $
                "Too many calls to performGC, " <> message
            else do
              performGC
              mallocTrim 0
              threadDelay 1000 -- We need delta delay(1ms) to wait GC.
              retryWithGC' (count - 1) func
    Left e -> throwIO e
  where
    -- The per-platform markers historically disagreed on an "Exception: "
    -- prefix (CUDA had it, MPS did not), so accept the marker with or
    -- without it.
    isOutOfMemory :: ByteString -> Bool
    isOutOfMemory msg =
      any (`B.isPrefixOf` msg) [msgOutOfMemory, "Exception: " <> msgOutOfMemory]
#ifdef darwin_HOST_OS
    msgOutOfMemory :: ByteString
    msgOutOfMemory = "MPS backend out of memory"
#else
    msgOutOfMemory :: ByteString
    msgOutOfMemory = "CUDA out of memory."
#endif
{-# INLINE retryWithGC' #-}

-- | Run an action with out-of-memory recovery and error reporting.
-- Backend OOM failures get up to ten GC-and-retry rounds via 'retryWithGC''.
-- Any C++ @std::exception@ that still escapes passes through
-- 'prettyException', which may print diagnostics before rethrowing it.
retryWithGC :: IO a -> IO a
retryWithGC func = prettyException (retryWithGC' maxRetries func)
  where
    maxRetries = 10
{-# INLINE retryWithGC #-}

-- | Background memory watchdog that never returns.
--
-- Every 500 ms it queries @sysinfo@. When the free-RAM fraction
-- (@freeram / totalram@) is at or below 0.5, it runs 'performGC' followed by
-- @malloc_trim(0)@. If the @sysinfo@ query fails, that round is skipped and
-- polling continues.
checkOSMemoryWithGC :: IO ()
checkOSMemoryWithGC = pollForever
  where
    pollForever = do
      info <- sysInfo
      either (const (return ())) collectIfLow info
      threadDelay (500 * 1000) -- wait 500msec
      pollForever

    collectIfLow stat =
      let freeFraction =
            fromIntegral (freeram stat) / fromIntegral (totalram stat) :: Double
       in when (freeFraction <= 0.5) $ do
            performGC
            mallocTrim 0

-- | Run an action while the 'checkOSMemoryWithGC' watchdog runs beside it.
-- The two are raced: the first to finish cancels the other. Because the
-- watchdog loops forever, in practice this returns once the action
-- completes, or rethrows if either side throws. The race result is
-- discarded.
monitorMemory :: IO () -> IO ()
monitorMemory func = () <$ race func checkOSMemoryWithGC