{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Jit where

import Torch.Script
import Torch.Tensor
import Torch.NN
import Control.Concurrent.STM.TVar
import Control.Concurrent.STM (atomically)
import System.IO.Unsafe (unsafePerformIO)

-- | A mutable slot for a compiled 'ScriptModule'.
--
-- The wrapped 'TVar' holds 'Nothing' until a module has been stored in it,
-- and @'Just' m@ afterwards.
newtype ScriptCache = ScriptCache
  { unScriptCache :: TVar (Maybe ScriptModule)
    -- ^ The underlying transactional variable.
  }

-- | Allocate a fresh, empty 'ScriptCache' (its slot starts out as 'Nothing').
newScriptCache :: IO ScriptCache
newScriptCache = ScriptCache <$> newTVarIO Nothing

-- | Run a tensor function through a cached 'ScriptModule'.
--
-- When the cache is empty, @func@ is traced on @input@ (module name
-- @\"MyModule\"@, method name @\"forward\"@), converted with
-- 'toScriptModule', and stored in the cache. When the cache already holds a
-- module, that module is reused and @func@ is not consulted at all.
--
-- The selected module is then run with 'forwardStoch' on @input@ (each
-- tensor wrapped in 'IVTensor'), and the resulting tensor is returned as a
-- single-element list.
jitIO :: ScriptCache -> ([Tensor] -> IO [Tensor]) -> [Tensor] -> IO [Tensor]
jitIO (ScriptCache cache) func input = do
  -- NOTE: the read here and the write below are separate STM operations, so
  -- two concurrent callers that both observe 'Nothing' will each trace.
  v <- readTVarIO cache
  script <- case v of
    Just script' -> return script'
    Nothing -> do
      m <- trace "MyModule" "forward" func input
      script' <- toScriptModule m
      atomically $ writeTVar cache (Just script')
      return script'
  -- NOTE(review): this pattern bind only accepts an 'IVTensor' result; any
  -- other result shape presumably fails the match in 'IO' -- confirm against
  -- the result type of 'forwardStoch'.
  IVTensor r0 <- forwardStoch script (map IVTensor input)
  return [r0]

-- | Pure-looking wrapper around 'jitIO' for a pure tensor function.
--
-- @func@ is lifted into 'IO' with 'return' and handed to 'jitIO'; the
-- resulting action is executed with 'unsafePerformIO'. That action reads and
-- may write the 'ScriptCache', so this function is only as pure as the
-- cached module is faithful to @func@: once the cache is populated, the
-- cached module is used and @func@ is ignored (see 'jitIO').
jit :: ScriptCache -> ([Tensor] -> [Tensor]) -> [Tensor] -> [Tensor]
jit cache func input = unsafePerformIO $ jitIO cache (return . func) input
-- The 'unsafePerformIO' documentation recommends NOINLINE on any function
-- that calls it, so the embedded I/O is not duplicated by inlining.
{-# NOINLINE jit #-}