{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} module Torch.Jit where import Torch.Script import Torch.Tensor import Torch.NN import Control.Concurrent.STM.TVar import Control.Concurrent.STM (atomically) import System.IO.Unsafe (unsafePerformIO) newtype ScriptCache = ScriptCache { ScriptCache -> TVar (Maybe ScriptModule) unScriptCache :: TVar (Maybe ScriptModule) } newScriptCache :: IO ScriptCache newScriptCache :: IO ScriptCache newScriptCache = TVar (Maybe ScriptModule) -> ScriptCache ScriptCache (TVar (Maybe ScriptModule) -> ScriptCache) -> IO (TVar (Maybe ScriptModule)) -> IO ScriptCache forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b <$> Maybe ScriptModule -> IO (TVar (Maybe ScriptModule)) forall a. a -> IO (TVar a) newTVarIO Maybe ScriptModule forall a. Maybe a Nothing jitIO :: ScriptCache -> ([Tensor] -> IO [Tensor]) -> [Tensor] -> IO [Tensor] jitIO :: ScriptCache -> ([Tensor] -> IO [Tensor]) -> [Tensor] -> IO [Tensor] jitIO (ScriptCache TVar (Maybe ScriptModule) cache) [Tensor] -> IO [Tensor] func [Tensor] input = do Maybe ScriptModule v <- TVar (Maybe ScriptModule) -> IO (Maybe ScriptModule) forall a. TVar a -> IO a readTVarIO TVar (Maybe ScriptModule) cache ScriptModule script <- case Maybe ScriptModule v of Just ScriptModule script' -> ScriptModule -> IO ScriptModule forall a. a -> IO a forall (m :: * -> *) a. Monad m => a -> m a return ScriptModule script' Maybe ScriptModule Nothing -> do RawModule m <- String -> String -> ([Tensor] -> IO [Tensor]) -> [Tensor] -> IO RawModule trace String "MyModule" String "forward" [Tensor] -> IO [Tensor] func [Tensor] input ScriptModule script' <- RawModule -> IO ScriptModule toScriptModule RawModule m STM () -> IO () forall a. STM a -> IO a atomically (STM () -> IO ()) -> STM () -> IO () forall a b. (a -> b) -> a -> b $ TVar (Maybe ScriptModule) -> Maybe ScriptModule -> STM () forall a. TVar a -> a -> STM () writeTVar TVar (Maybe ScriptModule) cache (ScriptModule -> Maybe ScriptModule forall a. a -> Maybe a Just ScriptModule script') ScriptModule -> IO ScriptModule forall a. a -> IO a forall (m :: * -> *) a. Monad m => a -> m a return ScriptModule script' IVTensor Tensor r0 <- ScriptModule -> [IValue] -> IO IValue forall f a b. HasForward f a b => f -> a -> IO b forwardStoch ScriptModule script ((Tensor -> IValue) -> [Tensor] -> [IValue] forall a b. (a -> b) -> [a] -> [b] map Tensor -> IValue IVTensor [Tensor] input) [Tensor] -> IO [Tensor] forall a. a -> IO a forall (m :: * -> *) a. Monad m => a -> m a return [Tensor r0] jit :: ScriptCache -> ([Tensor] -> [Tensor]) -> [Tensor] -> [Tensor] jit :: ScriptCache -> ([Tensor] -> [Tensor]) -> [Tensor] -> [Tensor] jit ScriptCache cache [Tensor] -> [Tensor] func [Tensor] input = IO [Tensor] -> [Tensor] forall a. IO a -> a unsafePerformIO (IO [Tensor] -> [Tensor]) -> IO [Tensor] -> [Tensor] forall a b. (a -> b) -> a -> b $ ScriptCache -> ([Tensor] -> IO [Tensor]) -> [Tensor] -> IO [Tensor] jitIO ScriptCache cache ([Tensor] -> IO [Tensor] forall a. a -> IO a forall (m :: * -> *) a. Monad m => a -> m a return ([Tensor] -> IO [Tensor]) -> ([Tensor] -> [Tensor]) -> [Tensor] -> IO [Tensor] forall b c a. (b -> c) -> (a -> b) -> a -> c . [Tensor] -> [Tensor] func) [Tensor] input