{-# LANGUAGE CPP #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE OverloadedStrings #-}
module Torch.Internal.GC where
import Control.Concurrent (threadDelay)
import Control.Concurrent.Async
import Control.Exception.Safe (Exception, MonadThrow, Typeable, catch, throwIO, throwM)
import Control.Monad (when)
import Data.List (isPrefixOf)
import Foreign.C.Types
import GHC.ExecutionStack
import Language.C.Inline.Cpp.Exception
import System.Environment (lookupEnv)
import System.IO (hPutStrLn, stderr)
import System.IO.Unsafe (unsafePerformIO)
import System.Mem (performGC)
import System.SysInfo
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Data.Text as T
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
-- | Raw FFI binding to @showWeakPtrList@ declared in @hasktorch_finalizer.h@.
--
-- Imported as an @unsafe@ ccall, so the C side must not call back into
-- Haskell. What the 'CInt' argument selects is defined on the C side and is
-- not visible in this module; the only caller here forwards an @age@ value
-- (see 'dumpLibtorchObjects').
foreign import ccall unsafe "hasktorch_finalizer.h showWeakPtrList"
  c_showWeakPtrList :: CInt -> IO ()
#ifdef ENABLE_DUMMY_MALLOC_TRIM
-- | No-op stand-in, selected when ENABLE_DUMMY_MALLOC_TRIM is defined
-- (presumably for libcs without @malloc_trim@; confirm against the build
-- configuration). The pad argument is ignored.
mallocTrim :: CInt -> IO ()
mallocTrim _ = return ()
#else
-- | glibc @int malloc_trim(size_t pad)@, bound with its real C argument and
-- result types so the FFI declaration matches the C prototype.
foreign import ccall unsafe "malloc.h malloc_trim"
  c_mallocTrim :: CSize -> IO CInt

-- | Ask the C allocator to release free heap memory back to the OS, keeping
-- @pad@ bytes of slack at the top of the heap.
--
-- The public type stays @CInt -> IO ()@ for existing callers. The pad is
-- converted to @size_t@; negative values are clamped to 0 because they would
-- otherwise wrap to an enormous pad. The C return value (1 if memory was
-- released, 0 otherwise) is discarded.
mallocTrim :: CInt -> IO ()
mallocTrim pad = () <$ c_mallocTrim (fromIntegral (max 0 pad))
#endif
-- | Forward @age@ to the C helper @showWeakPtrList@.
--
-- NOTE(review): from the names, this presumably dumps the libtorch objects
-- tracked by the finalizer weak-pointer list. The exact output and what @age@
-- filters on live in C; confirm in @hasktorch_finalizer@.
-- The 'Int' is narrowed to 'CInt' with 'fromIntegral', which wraps on
-- overflow.
dumpLibtorchObjects :: Int -> IO ()
dumpLibtorchObjects = c_showWeakPtrList . fromIntegral
-- | Exception used to re-raise C++ @std::exception@ failures on the Haskell
-- side (see 'unsafeThrowableIO'). The payload is a human-readable message.
newtype HasktorchException = HasktorchException String
  deriving Show

-- | Default 'Exception' methods; no custom exception hierarchy.
instance Exception HasktorchException
-- | Decode UTF-8 bytes into a 'String'. Undecodable input is replaced with
-- U+FFFD ('T.lenientDecode') instead of raising an error.
bsToChars :: ByteString -> String
bsToChars bytes = T.unpack (T.decodeUtf8With T.lenientDecode bytes)
-- | Run an 'IO' action as if it were pure, reporting C++ failures through
-- 'MonadThrow' in the result monad instead of as an 'IO' exception.
--
-- * A 'CppStdException' becomes a 'HasktorchException' whose message is
--   @"Exception: "@ followed by the C++ exception message.
-- * Any other 'CppException' constructor is rethrown unchanged via 'throwM'.
--   It must not fall through a partial pattern in the handler, which would
--   replace it with a pattern-match failure.
--
-- NOTE(review): this relies on 'unsafePerformIO'. The wrapped action runs
-- when the returned @m a@ is forced, and GHC may share or duplicate the call
-- after inlining. Only pass actions that are safe to treat as pure.
unsafeThrowableIO :: forall a m. MonadThrow m => IO a -> m a
unsafeThrowableIO action = unsafePerformIO $ (pure <$> action) `catch` handler
  where
    handler :: CppException -> IO (m a)
    handler (CppStdException _ msg _) =
      pure . throwM $ HasktorchException ("Exception: " <> bsToChars msg)
    handler other = pure (throwM other)
-- | Run @func@. If it throws a C++ @std::exception@, print diagnostics to
-- stderr and rethrow the very same exception.
--
-- The diagnostics are a Haskell execution stack trace (or
-- @"Cannot show stacktrace"@ when 'showStackTrace' yields nothing) followed
-- by the raw exception message, written without a trailing newline. They are
-- printed unless the environment variable HASKTORCH_DEBUG is exactly @"0"@,
-- so they are on by default.
--
-- Other 'CppException' constructors are rethrown untouched, without
-- diagnostics.
prettyException :: forall a. IO a -> IO a
prettyException func = func `catch` report
  where
    report :: CppException -> IO a
    report e@(CppStdException _ message _) = do
      flag <- lookupEnv "HASKTORCH_DEBUG"
      -- Only an explicit HASKTORCH_DEBUG=0 silences the diagnostics.
      when (flag /= Just "0") $ do
        mst <- showStackTrace
        case mst of
          Just st -> hPutStrLn stderr st
          Nothing -> hPutStrLn stderr "Cannot show stacktrace"
        B.hPutStr stderr message
      throwIO e
    -- Any other constructor: rethrow as-is. A partial lambda pattern here
    -- would instead raise a pattern-match failure and lose the original error.
    report e = throwIO e
{-# INLINE prettyException #-}
-- | Run @func@, retrying after a forced GC when it fails with an
-- out-of-memory C++ exception from the accelerator backend.
--
-- A 'CppStdException' counts as out-of-memory when its message starts with
-- @msgOutOfMemory@: the MPS message on Darwin, the CUDA message elsewhere.
--
-- * OOM with @count > 0@: run 'performGC', @mallocTrim 0@, sleep 1 ms, then
--   retry with @count - 1@.
-- * OOM with @count <= 0@: throw an 'IOError' ('userError') whose text is
--   @"Too many calls to performGC, "@ followed by the C++ message.
-- * Any other 'CppException', OOM or not, is rethrown unchanged.
retryWithGC' :: forall a. Int -> IO a -> IO a
retryWithGC' count func = func `catch` onCppException
  where
    onCppException :: CppException -> IO a
    onCppException (CppStdException _ message _)
      | msgOutOfMemory `B.isPrefixOf` message =
          if count <= 0
            then
              throwIO . userError . bsToChars $
                "Too many calls to performGC, " <> message
            else do
              -- Presumably lets finalizers of unreachable libtorch objects
              -- free device memory; then returns free C heap pages to the OS.
              performGC
              mallocTrim 0
              threadDelay 1000
              retryWithGC' (count - 1) func
    -- Non-OOM std exceptions fall through the guard to here, as do all other
    -- constructors; a partial pattern would turn them into a match failure.
    onCppException e = throwIO e

    -- Message prefix identifying a backend out-of-memory error. The signature
    -- is shared by both platform branches.
    msgOutOfMemory :: ByteString
#ifdef darwin_HOST_OS
    msgOutOfMemory = "MPS backend out of memory"
#else
    msgOutOfMemory = "Exception: CUDA out of memory."
#endif
{-# INLINE retryWithGC' #-}
-- | Run an action with up to 10 GC-and-retry rounds on backend
-- out-of-memory errors (delegated to the primed retry helper). Any C++
-- exception that still escapes is reported by 'prettyException' before it
-- is rethrown.
retryWithGC :: IO a -> IO a
retryWithGC action = prettyException (retryWithGC' 10 action)
{-# INLINE retryWithGC #-}
-- | Watchdog loop that never returns.
--
-- On each tick it reads system memory statistics via 'sysInfo'. When
-- @freeram / totalram <= 0.5@, it forces a Haskell GC and trims the C heap.
-- It then sleeps 500 ms and repeats. A failed 'sysInfo' call just skips
-- that tick.
checkOSMemoryWithGC :: IO ()
checkOSMemoryWithGC = do
  stats <- sysInfo
  either (const (pure ())) trimIfLow stats
  threadDelay (500 * 1000)
  checkOSMemoryWithGC
  where
    trimIfLow stat =
      when (freeFraction stat <= 0.5) $ do
        performGC
        mallocTrim 0

    freeFraction :: SysInfo -> Double
    freeFraction stat =
      fromIntegral (freeram stat) / fromIntegral (totalram stat)
-- | Run @func@ alongside the 'checkOSMemoryWithGC' watchdog.
--
-- 'race' cancels the watchdog as soon as @func@ finishes, and rethrows if
-- either side throws. Which side finished is irrelevant, so the 'Either'
-- result is discarded.
monitorMemory :: IO () -> IO ()
monitorMemory func = () <$ race func checkOSMemoryWithGC