{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
module Torch.TensorOptions where
import Data.Int
import Foreign.ForeignPtr
import System.IO.Unsafe
import Torch.DType
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Type.Context as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import qualified Torch.Internal.Type as ATen
import Torch.Layout
-- | Raw libtorch @TensorOptions@ object, owned through a 'ForeignPtr'.
type ATenTensorOptions = ForeignPtr ATen.TensorOptions

-- | Haskell handle for a libtorch tensor-options object (dtype, device and
-- layout are configured through 'withDType', 'withDevice' and 'withLayout').
--
-- Those setters return a new handle wrapping the pointer produced by the FFI
-- call rather than rewrapping the argument.
newtype TensorOptions
  = TensorOptions ATenTensorOptions
  deriving (Show)
-- | Bridge between the 'TensorOptions' newtype and the raw pointer consumed by
-- the FFI layer. Both directions are a plain wrap/unwrap; the continuation is
-- called exactly once with the converted value.
instance Castable TensorOptions ATenTensorOptions where
  cast (TensorOptions rawPtr) k = k rawPtr
  uncast rawPtr k = k (TensorOptions rawPtr)
-- | Starting point for building tensor options: a fresh options object created
-- by @newTensorOptions_s@ with scalar type 'ATen.kFloat'.
--
-- Other fields are whatever libtorch initialises them to
-- (NOTE(review): presumably libtorch's defaults — confirm).
defaultOpts :: TensorOptions
defaultOpts = unsafePerformIO (TensorOptions <$> ATen.newTensorOptions_s ATen.kFloat)
-- | Derive options from @opts@ with the element type replaced by @dtype@.
--
-- The result wraps the pointer returned by the @tensorOptions_dtype_s@ FFI
-- call; the effect is run through 'unsafePerformIO'.
withDType :: DType -> TensorOptions -> TensorOptions
withDType dtype opts = unsafePerformIO setDType
  where
    setDType :: IO TensorOptions
    setDType = cast2 ATen.tensorOptions_dtype_s opts dtype
-- | Derive options from @opts@ that target the given 'Device'.
--
-- * 'CPU': the device type is set to CPU. The device index is ignored.
-- * 'CUDA': requires 'ATen.hasCUDA' to report support at runtime and a
--   non-negative device index. The index is applied via
--   @tensorOptions_device_index_s@.
-- * 'MPS': requires 'ATen.hasMPS' to report support at runtime and device
--   index @0@.
--
-- Any other combination (e.g. CUDA requested but unavailable, or a negative
-- CUDA index) calls 'error' with a "cannot move tensor to" message when the
-- result is forced.
withDevice :: Device -> TensorOptions -> TensorOptions
withDevice Device {..} opts = unsafePerformIO $
  case deviceType of
    -- Always set the device explicitly. Returning @opts@ unchanged here would
    -- keep any device installed by an earlier 'withDevice' call, so moving
    -- options back to the CPU would silently have no effect.
    CPU -> withDeviceType CPU opts
    CUDA -> do
      hasCUDA <- cast0 ATen.hasCUDA
      withDevice' deviceType deviceIndex hasCUDA opts
    MPS -> do
      hasMPS <- cast0 ATen.hasMPS
      withDevice' deviceType deviceIndex hasMPS opts
  where
    -- Set the device type on the options (FFI: tensorOptions_device_D).
    withDeviceType :: DeviceType -> TensorOptions -> IO TensorOptions
    withDeviceType dt o = cast2 ATen.tensorOptions_device_D o dt

    -- Set the device index on the options (FFI: tensorOptions_device_index_s).
    -- Only used for CUDA below. NOTE(review): presumably the binding selects
    -- the CUDA device type together with the index — confirm.
    withDeviceIndex :: Int16 -> TensorOptions -> IO TensorOptions
    withDeviceIndex di o = cast2 ATen.tensorOptions_device_index_s o di

    -- Validate an accelerator request against the runtime availability flag
    -- and apply it; anything not explicitly accepted is an error.
    withDevice' :: DeviceType -> Int16 -> Bool -> TensorOptions -> IO TensorOptions
    withDevice' CUDA di True o | di >= 0 = withDeviceIndex di o
    withDevice' MPS 0 True o = withDeviceType MPS o
    withDevice' dt di _ _ =
      error $ "cannot move tensor to \"" <> show dt <> ":" <> show di <> "\""
-- | Derive options from @opts@ with the memory layout replaced by @layout@.
--
-- The result wraps the pointer returned by the @tensorOptions_layout_L@ FFI
-- call; the effect is run through 'unsafePerformIO'.
withLayout :: Layout -> TensorOptions -> TensorOptions
withLayout layout opts = unsafePerformIO setLayout
  where
    setLayout :: IO TensorOptions
    setLayout = cast2 ATen.tensorOptions_layout_L opts layout