{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Jit where

import Torch.Script
import Torch.Tensor
import Torch.NN
import Control.Concurrent.STM.TVar
import Control.Concurrent.STM (atomically)
import System.IO.Unsafe (unsafePerformIO)

-- | A mutable slot holding at most one compiled 'ScriptModule'.
--
-- The slot starts out as 'Nothing' (see 'newScriptCache'). 'jitIO' fills it
-- the first time it finds it empty, and later calls reuse the stored module.
newtype ScriptCache = ScriptCache
  { unScriptCache :: TVar (Maybe ScriptModule)
    -- ^ The cached module, or 'Nothing' before the first trace.
  }

-- | Allocate a fresh, empty 'ScriptCache' (its slot holds 'Nothing').
newScriptCache :: IO ScriptCache
newScriptCache = do
  emptySlot <- newTVarIO Nothing
  pure (ScriptCache emptySlot)

-- | Run @func@ through a TorchScript module stored in @cache@, tracing it on
-- first use.
--
-- On a cache miss, @func@ is traced with 'trace' (module name @"MyModule"@,
-- method @"forward"@) using @input@ as the example input. The result is
-- converted with 'toScriptModule' and stored in the cache. Every call,
-- including the first, then runs the cached module via 'forwardStoch' on
-- @input@.
--
-- The cache is /not/ keyed on @func@. Once it is populated, later calls with
-- the same 'ScriptCache' reuse the stored module regardless of which @func@
-- they pass.
--
-- Only a single-tensor result is supported. Any other result raises an
-- 'IOError' (via 'fail') naming this function.
jitIO :: ScriptCache -> ([Tensor] -> IO [Tensor]) -> [Tensor] -> IO [Tensor]
jitIO (ScriptCache cache) func input = do
  script <- readTVarIO cache >>= maybe traceAndStore pure
  result <- forwardStoch script (map IVTensor input)
  -- Explicit match instead of a partial do-pattern, so an unexpected result
  -- shape produces a descriptive error rather than a generic one.
  case result of
    IVTensor r0 -> pure [r0]
    _ ->
      fail
        "Torch.Jit.jitIO: expected the scripted module to return a single tensor"
  where
    -- Trace @func@, compile it, and publish it to the cache.
    traceAndStore = do
      m <- trace "MyModule" "forward" func input
      fresh <- toScriptModule m
      -- Another caller may have filled the cache while we were tracing. Keep
      -- whichever module landed first so every caller shares the same one.
      atomically $ do
        current <- readTVar cache
        case current of
          Just existing -> pure existing
          Nothing -> do
            writeTVar cache (Just fresh)
            pure fresh

-- | Pure-looking wrapper around 'jitIO' for a pure tensor function.
--
-- @func@ is lifted into 'IO' with 'pure', and the 'jitIO' action is run with
-- 'unsafePerformIO'. The first call therefore traces @func@ and fills @cache@
-- as a hidden side effect.
--
-- NOTE(review): because of 'unsafePerformIO', GHC may share or duplicate
-- evaluations of this expression. If 'forwardStoch' is non-deterministic
-- (e.g. dropout), callers cannot rely on how many times it runs. Confirm this
-- is acceptable.
jit :: ScriptCache -> ([Tensor] -> [Tensor]) -> [Tensor] -> [Tensor]
jit cache func input =
  unsafePerformIO (jitIO cache (\tensors -> pure (func tensors)) input)