module Torch.Serialize where

import Control.Exception.Safe
  ( SomeException (..),
    throwIO,
    try,
  )
import Control.Monad (when)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import Foreign.Marshal.Utils (copyBytes)
import qualified Foreign.ForeignPtr as F
import qualified Foreign.Ptr as F
import System.IO
import Torch.Autograd
import Torch.DType
import Torch.Functional
import Torch.Internal.Cast
import qualified Torch.Internal.Managed.Serialize as S
import Torch.NN
import Torch.Script hiding (clone, load, save)
import Torch.Tensor

-- | Write a list of tensors to a file by delegating to the managed binding
-- 'S.save'; 'cast2' converts the Haskell-side arguments to the form the
-- binding expects.
save ::
  -- | inputs
  [Tensor] ->
  -- | file
  FilePath ->
  -- | output
  IO ()
save = cast2 S.save

-- | Read a list of tensors from a file by delegating to the managed binding
-- 'S.load'; 'cast1' converts the argument and the returned tensor list.
load ::
  -- | file
  FilePath ->
  -- | output
  IO [Tensor]
load = cast1 S.load

-- | Save a state_dict.
--
-- Writes the given 'IValue' to a file by delegating to the managed binding
-- 'S.pickleSave'.
pickleSave ::
  -- | inputs
  IValue ->
  -- | file
  FilePath ->
  -- | output
  IO ()
pickleSave = cast2 S.pickleSave

-- | Load a state_dict file.
--
-- On the PyTorch side, wrap the state_dict in a plain @dict@ before saving it,
-- as follows:
--
-- > torch.save(dict(model.state_dict()), "state_dict.pth")
--
-- The file is read by delegating to the managed binding 'S.pickleLoad'.
pickleLoad ::
  -- | file
  FilePath ->
  -- | output
  IO IValue
pickleLoad = cast1 S.pickleLoad

-- | Save the parameters of a model to a file.
--
-- The model is flattened with 'flattenParameters', every 'Parameter' is turned
-- into a plain 'Tensor' with 'toDependent', and the resulting list is written
-- with 'save'.
saveParams ::
  Parameterized f =>
  -- | model
  f ->
  -- | filepath
  FilePath ->
  -- | output
  IO ()
saveParams model filePath =
  save (map toDependent (flattenParameters model)) filePath

-- | Load parameters from a file into a model.
--
-- The tensors are read with 'load', each one is wrapped in
-- 'IndependentTensor', and the wrapped list is substituted into the given
-- model with 'replaceParameters'. The given model is only used as the
-- structure to fill in; the updated model is returned.
--
-- NOTE(review): nothing here checks that the number of tensors in the file
-- matches the number of parameters in the model; what 'replaceParameters'
-- does on a mismatch should be confirmed.
loadParams ::
  Parameterized b =>
  -- | model
  b ->
  -- | filepath
  FilePath ->
  -- | output
  IO b
loadParams model filePath = do
  tensors <- load filePath
  pure $ replaceParameters model (map IndependentTensor tensors)

-- | Types whose contents can be read from and written to a 'Handle' as raw
-- binary data.
class RawFile a where
  -- | Read a value from the handle. An existing value of the same type is
  -- passed in; the 'Tensor' instance uses it to decide how many bytes to read
  -- and returns a copy of it filled with the data that was read.
  loadBinary ::
    -- | handle to read from
    Handle ->
    -- | existing value
    a ->
    -- | loaded value
    IO a

  -- | Write a value to the handle.
  saveBinary ::
    -- | handle to write to
    Handle ->
    -- | value to write
    a ->
    -- | output
    IO ()

-- | Raw binary (de)serialization of a tensor's data buffer.
--
-- Both methods transfer exactly @byteLength (dtype tensor) * product (shape tensor)@
-- bytes; no dtype or shape information is written to or read from the handle.
--
-- NOTE(review): this assumes the pointer handed out by 'withTensor' addresses
-- at least that many contiguous bytes of host memory -- confirm for
-- non-contiguous or non-CPU tensors.
instance RawFile Tensor where
  -- Reads the tensor's byte size from the handle and returns a clone of the
  -- input tensor whose buffer has been overwritten with those bytes. The input
  -- tensor itself is not modified. Throws a 'userError' if fewer bytes than
  -- required could be read.
  loadBinary handle tensor = do
    let len = byteLength (dtype tensor) * product (shape tensor)
    bytes <- BS.hGet handle len
    -- Validate the read before cloning so that a short read does not pay for
    -- a tensor copy that is immediately discarded.
    when (BS.length bytes < len) $
      throwIO . userError $
        "Read data's size is less than input tensor's one(" <> show len <> ")."
    t <- clone tensor
    withTensor t $ \dst -> do
      -- Honour the ByteString's offset instead of assuming the payload starts
      -- at the beginning of the underlying buffer.
      let (fptr, off, _) = BSI.toForeignPtr bytes
      F.withForeignPtr fptr $ \src ->
        copyBytes dst (src `F.plusPtr` off) len
      return t

  -- Writes the tensor's data buffer to the handle directly from the tensor's
  -- own memory; no intermediate copy of the tensor is made.
  saveBinary handle tensor = do
    let len = byteLength (dtype tensor) * product (shape tensor)
    withTensor tensor $ \src ->
      hPutBuf handle src len