{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}

module Torch.TensorOptions where

import Data.Int
import Foreign.ForeignPtr
import System.IO.Unsafe
import Torch.DType
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Type.Context as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import qualified Torch.Internal.Type as ATen
import Torch.Layout

type ATenTensorOptions = ForeignPtr ATen.TensorOptions

-- | Haskell-side handle to a libtorch @TensorOptions@ object, held through
-- a 'ForeignPtr' (see 'ATenTensorOptions').
newtype TensorOptions = TensorOptions ATenTensorOptions
  deriving (Show)

-- | 'TensorOptions' is a newtype over the ATen foreign pointer, so
-- marshalling in either direction only unwraps or re-wraps the pointer
-- and hands it to the continuation.
instance Castable TensorOptions ATenTensorOptions where
  cast (TensorOptions raw) k = k raw
  uncast raw k = k (TensorOptions raw)

-- | Starting point for the @with*@ modifiers in this module.
--
-- Built by calling @ATen.newTensorOptions_s@ with 'ATen.kFloat'.
-- Presumably this yields float32 options with every other field at
-- libtorch's defaults — TODO confirm against the C++ binding.
defaultOpts :: TensorOptions
defaultOpts =
  unsafePerformIO (TensorOptions <$> ATen.newTensorOptions_s ATen.kFloat)

-- | Apply the given 'DType' to @opts@.
--
-- Delegates to @ATen.tensorOptions_dtype_s@ through 'cast2'. 'cast2'
-- marshals @opts@ and @dtype@ to their ATen representations and converts
-- the returned pointer back into a 'TensorOptions'.
withDType :: DType -> TensorOptions -> TensorOptions
withDType dtype opts = unsafePerformIO setDType
  where
    setDType :: IO TensorOptions
    setDType = cast2 ATen.tensorOptions_dtype_s opts dtype

-- | Point @opts@ at the given 'Device'.
--
-- * 'CPU': the device type is set to CPU. The device index is not used.
-- * 'CUDA': requires @ATen.hasCUDA@ to report availability and a
--   non-negative index. The index is applied with
--   @ATen.tensorOptions_device_index_s@.
-- * 'MPS': requires @ATen.hasMPS@ to report availability and index 0.
--
-- Any other request, such as CUDA when CUDA is unavailable or MPS with a
-- non-zero index, calls 'error' with a message of the form
-- @cannot move tensor to "CUDA:1"@.
withDevice :: Device -> TensorOptions -> TensorOptions
withDevice Device {..} opts = unsafePerformIO $
  case deviceType of
    -- Set the CPU device type explicitly. Returning @opts@ untouched
    -- would silently keep a previously selected CUDA/MPS device.
    CPU -> withDeviceType CPU opts
    CUDA -> do
      hasCUDA <- cast0 ATen.hasCUDA
      withDevice' deviceType deviceIndex hasCUDA opts
    MPS -> do
      hasMPS <- cast0 ATen.hasMPS
      withDevice' deviceType deviceIndex hasMPS opts
  where
    -- Set only the device type.
    withDeviceType :: DeviceType -> TensorOptions -> IO TensorOptions
    withDeviceType dt o = cast2 ATen.tensorOptions_device_D o dt

    -- Set a device index. Careful: this binding implies deviceType = CUDA.
    -- libtorch's TensorOptions::device_index selects a CUDA device
    -- (TODO confirm for the libtorch version bound here), so this helper
    -- is used for the CUDA case only.
    withDeviceIndex :: Int16 -> TensorOptions -> IO TensorOptions
    withDeviceIndex di o = cast2 ATen.tensorOptions_device_index_s o di

    -- Check an accelerator request against its availability flag.
    -- Apply the request if it is valid; otherwise fail with a
    -- descriptive error.
    withDevice' ::
      DeviceType -> Int16 -> Bool -> TensorOptions -> IO TensorOptions
    withDevice' CUDA di True o | di >= 0 = withDeviceIndex di o
    withDevice' MPS 0 True o = withDeviceType MPS o
    withDevice' dt di _ _ =
      error $ "cannot move tensor to \"" <> show dt <> ":" <> show di <> "\""

-- | Apply the given 'Layout' to @opts@.
--
-- Delegates to @ATen.tensorOptions_layout_L@ through 'cast2'. 'cast2'
-- marshals @opts@ and @layout@ and converts the returned pointer back into
-- a 'TensorOptions'.
withLayout :: Layout -> TensorOptions -> TensorOptions
withLayout layout opts = unsafePerformIO setLayout
  where
    setLayout :: IO TensorOptions
    setLayout = cast2 ATen.tensorOptions_layout_L opts layout