module Torch.Serialize where
import Control.Exception.Safe
( SomeException (..),
throwIO,
try,
)
import Control.Monad (when)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import Foreign.Marshal.Utils (copyBytes)
import qualified Foreign.ForeignPtr as F
import qualified Foreign.Ptr as F
import System.IO
import Torch.Autograd
import Torch.DType
import Torch.Functional
import Torch.Internal.Cast
import qualified Torch.Internal.Managed.Serialize as S
import Torch.NN
import Torch.Script hiding (clone, load, save)
import Torch.Tensor
-- | Write a list of tensors to a file.
--
-- Thin wrapper: the arguments are marshalled with 'cast2' and handed to
-- @S.save@, which decides the on-disk format.
save ::
  -- | tensors to store
  [Tensor] ->
  -- | destination path
  FilePath ->
  IO ()
save tensors file = cast2 S.save tensors file
-- | Read a list of tensors from a file.
--
-- Thin wrapper: the path is marshalled with 'cast1' and handed to @S.load@;
-- the resulting tensor list is converted back to @['Tensor']@.
load ::
  -- | source path
  FilePath ->
  IO [Tensor]
load file = cast1 S.load file
-- | Write an 'IValue' to a file.
--
-- Thin wrapper: the arguments are marshalled with 'cast2' and handed to
-- @S.pickleSave@, which decides the on-disk format.
pickleSave ::
  -- | value to store
  IValue ->
  -- | destination path
  FilePath ->
  IO ()
pickleSave value file = cast2 S.pickleSave value file
-- | Read an 'IValue' from a file.
--
-- Thin wrapper: the path is marshalled with 'cast1' and handed to
-- @S.pickleLoad@.
pickleLoad ::
  -- | source path
  FilePath ->
  IO IValue
pickleLoad file = cast1 S.pickleLoad file
-- | Store every parameter of a model in a file.
--
-- The model is flattened with 'flattenParameters', each parameter is turned
-- into a plain tensor with 'toDependent', and the resulting list is written
-- with 'save'.
saveParams ::
  Parameterized f =>
  -- | model whose parameters are written
  f ->
  -- | destination path
  FilePath ->
  IO ()
saveParams model filePath =
  save (toDependent <$> flattenParameters model) filePath
-- | Rebuild a model from parameters stored in a file.
--
-- The tensors are read with 'load', each one is wrapped directly in the
-- 'IndependentTensor' constructor, and the wrapped list is substituted into
-- the given model with 'replaceParameters'. The model argument only provides
-- the structure that receives the loaded parameters.
loadParams ::
  Parameterized b =>
  -- | model providing the structure to fill
  b ->
  -- | source path
  FilePath ->
  IO b
loadParams model filePath =
  replaceParameters model . map IndependentTensor <$> load filePath
-- | Values that can be read from and written to an open 'Handle' as raw
-- binary data.
class RawFile a where
  -- | Read a value from the handle. An existing value of the same type is
  -- passed in; the 'Tensor' instance uses it as a template that fixes the
  -- dtype and shape (and therefore the number of bytes to read).
  loadBinary ::
    Handle ->
    a ->
    IO a

  -- | Write a value to the handle.
  saveBinary ::
    Handle ->
    a ->
    IO ()
-- | Headerless binary (de)serialisation of a tensor.
--
-- Both methods move exactly @byteLength (dtype tensor) * product (shape tensor)@
-- bytes between the handle and the pointer handed out by 'withTensor'; no
-- dtype or shape information is read or written.
--
-- NOTE(review): this assumes the pointer provided by 'withTensor' addresses
-- host memory holding the elements contiguously -- confirm for strided or
-- device-resident tensors.
instance RawFile Tensor where
  -- loadBinary :: Handle -> Tensor -> IO Tensor
  --
  -- Reads as many bytes as the template tensor occupies and returns a clone
  -- of the template whose buffer has been overwritten with those bytes. The
  -- template itself is never written to.
  --
  -- Throws a 'userError' (an 'IOError') when the handle yields fewer bytes
  -- than the template needs.
  loadBinary handle tensor = do
    let expected = byteLength (dtype tensor) * product (shape tensor)
    bytes <- BS.hGet handle expected
    -- 'BSI.toForeignPtr' returns the backing buffer together with the offset
    -- at which this string's payload starts; the offset has to be applied
    -- before copying, otherwise bytes in front of the payload would be read.
    let (fptr, offset, actual) = BSI.toForeignPtr bytes
    -- Validate before cloning so that a short read does not pay for a copy.
    when (actual < expected) $
      throwIO $
        userError $
          "Read data's size is less than input tensor's one(" <> show expected <> ")."
    t <- clone tensor
    withTensor t $ \dst ->
      F.withForeignPtr fptr $ \src -> do
        -- 'BS.hGet' never returns more than requested, so after the check
        -- above exactly @expected@ bytes are available.
        copyBytes (F.castPtr dst) (src `F.plusPtr` offset) expected
        return t

  -- saveBinary :: Handle -> Tensor -> IO ()
  --
  -- Writes the tensor's bytes to the handle straight from the pointer given
  -- by 'withTensor'. The tensor is not cloned: the previous implementation
  -- cloned it but never used the clone, so the written bytes are unchanged.
  saveBinary handle tensor = do
    let len = byteLength (dtype tensor) * product (shape tensor)
    withTensor tensor $ \src ->
      hPutBuf handle (F.castPtr src) len