{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Tensor where

import Control.Exception.Safe (throwIO)
import Control.Monad (forM, forM_)
import Numeric.Half
import Data.Complex
import Data.Int (Int16, Int64)
import Data.List (intercalate)
import Data.Proxy
import Data.Reflection
import qualified Data.Vector as V
import Data.Word (Word8)
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import GHC.Generics
import Numeric
import System.IO.Unsafe
import Torch.DType
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..), CppTuple2 (..), CppTuple3 (..), CppTuple4 (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Cast as ATen
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.TensorFactories as LibTorch
import qualified Torch.Internal.Managed.Type.Context as ATen
import qualified Torch.Internal.Managed.Type.StdArray as ATen
import qualified Torch.Internal.Managed.Type.StdString as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Managed.Type.TensorIndex as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import qualified Torch.Internal.Managed.Type.Extra as ATen
import qualified Torch.Internal.Type as ATen
import qualified Torch.Internal.Unmanaged.Type.Tensor as Unmanaged (tensor_data_ptr)
import Torch.Lens
import Torch.TensorOptions

-- | Foreign pointer to an ATen tensor object.
type ATenTensor = ForeignPtr ATen.Tensor

-- | A handle to an ATen tensor. On the Haskell side it is treated as an
-- immutable value: the pure accessors in this module query it through
-- 'unsafePerformIO'.
--
-- Do not use the 'Unsafe' constructor directly; convert through the
-- 'Castable' instance or the functions in this module instead.
newtype Tensor = Unsafe ATenTensor

-- | Converts between the 'Tensor' wrapper and the raw 'ATenTensor' pointer.
-- Both directions just wrap or unwrap the 'Unsafe' constructor and hand the
-- result to the continuation.
instance Castable Tensor ATenTensor where
  cast (Unsafe atenTensor) k = k atenTensor
  uncast atenTensor k = k (Unsafe atenTensor)

-- | A 'Tensor' wrapped to mark it as mutable. Created by 'newMutableTensor'
-- and converted back to a plain 'Tensor' with 'toImmutable'.
newtype MutableTensor = MutableTensor Tensor deriving (Show)

-- | Wraps a fresh handle to the input tensor, obtained from
-- @ATen.detach_t@, as a 'MutableTensor'.
--
-- NOTE(review): if @detach_t@ is ATen's @detach@, the result shares storage
-- with the input, so in-place writes through the 'MutableTensor' would be
-- visible in the original (supposedly immutable) 'Tensor' — confirm this
-- aliasing is intended.
newMutableTensor :: Tensor -> IO MutableTensor
newMutableTensor tensor = MutableTensor <$> cast1 ATen.detach_t tensor

-- | Produces a plain 'Tensor' handle for a 'MutableTensor' via
-- @ATen.detach_t@.
--
-- NOTE(review): if @detach_t@ is ATen's @detach@, the returned 'Tensor'
-- shares storage with the 'MutableTensor', so later in-place updates would
-- show through — confirm whether a copy was intended.
toImmutable :: MutableTensor -> IO Tensor
toImmutable (MutableTensor tensor) = cast1 ATen.detach_t tensor

--------------------------------------------------------------------------------
-- Basic tensor properties
--------------------------------------------------------------------------------

-- | Returns the total number of elements in the input tensor.
numel ::
  -- | input
  Tensor ->
  -- | number of elements in tensor
  Int
numel t = unsafePerformIO $ cast1 ATen.tensor_numel t

-- | Returns the size of a given dimension of the input tensor.
size ::
  -- | dimension
  Int ->
  -- | input
  Tensor ->
  Int
size axis t = unsafePerformIO $ cast2 ATen.tensor_size_l t axis

-- | Returns the shape of the tensor, i.e. the size of every dimension in
-- order.
shape ::
  -- | input
  Tensor ->
  -- | list of integers representing the shape of the tensor
  [Int]
shape t = unsafePerformIO $ cast1 ATen.tensor_sizes t

-- | Returns the number of dimensions of the input tensor.
dim ::
  -- | input
  Tensor ->
  -- | output
  Int
dim t = unsafePerformIO $ cast1 ATen.tensor_dim t

-- | Returns the number of dimensions of the input tensor, like 'dim', but
-- through the @ATen.tensor_dim_unsafe@ binding.
--
-- NOTE(review): presumably that binding is an @unsafe@ FFI call (cheaper,
-- but it cannot be interrupted while running) — confirm in Torch.Internal.
dimUnsafe ::
  -- | input
  Tensor ->
  -- | output
  Int
dimUnsafe t = unsafePerformIO $ cast1 ATen.tensor_dim_unsafe t

-- | Returns the number of dimensions of the input tensor, like 'dim', but
-- through the @ATen.tensor_dim_c_unsafe@ binding.
--
-- NOTE(review): presumably that binding is an @unsafe@ call into a plain C
-- wrapper — confirm in Torch.Internal.
dimCUnsafe ::
  -- | input
  Tensor ->
  -- | output
  Int
dimCUnsafe t = unsafePerformIO $ cast1 ATen.tensor_dim_c_unsafe t

-- | Returns the device on which the tensor is currently allocated.
--
-- CUDA placement is queried only when the build reports CUDA support;
-- otherwise MPS placement is queried only when the build reports MPS
-- support. Every other case is reported as @cpu:0@.
device ::
  -- | input
  Tensor ->
  -- | object representing the device
  Device
device t = unsafePerformIO $ do
  hasCUDA <- cast0 ATen.hasCUDA :: IO Bool
  if hasCUDA
    then do
      isCUDA <- cast1 ATen.tensor_is_cuda t :: IO Bool
      if isCUDA
        then cuda <$> cast1 ATen.tensor_get_device t
        else pure cpu
    else do
      hasMPS <- cast0 ATen.hasMPS :: IO Bool
      if hasMPS
        then do
          isMPS <- cast1 ATen.tensor_is_mps t :: IO Bool
          pure (if isMPS then mps else cpu)
        else pure cpu
  where
    cpu :: Device
    cpu = Device {deviceType = CPU, deviceIndex = 0}
    cuda :: Int -> Device
    cuda di = Device {deviceType = CUDA, deviceIndex = fromIntegral di}
    mps :: Device
    mps = Device {deviceType = MPS, deviceIndex = 0}

-- | Returns the data type of the input tensor.
dtype ::
  -- | input
  Tensor ->
  -- | data type of the input tensor
  DType
dtype t = unsafePerformIO $ cast1 ATen.tensor_scalar_type t

-- | Reads the input tensor as a 'Complex' 'Double'.
--
-- For the complex dtypes ('ComplexHalf', 'ComplexFloat', 'ComplexDouble')
-- the first element is read straight from the tensor's data buffer and
-- widened to 'Double' components. For every other dtype the real part comes
-- from @ATen.tensor_item_double@ and the imaginary part is 0.
toComplex :: Tensor -> Complex Double
toComplex t = unsafePerformIO $
  case dtype t of
    ComplexHalf -> widen <$> (peekFirst :: IO (Complex Half))
    ComplexFloat -> widen <$> (peekFirst :: IO (Complex Float))
    ComplexDouble -> peekFirst
    _ -> (:+ 0) <$> cast1 ATen.tensor_item_double t
  where
    -- The complex branches dereference the raw data pointer from the host,
    -- so the data must be in CPU memory (a CUDA data pointer is a device
    -- address). For a tensor already on cpu:0 this move is a no-op.
    hostT :: Tensor
    hostT = _toDevice (Device {deviceType = CPU, deviceIndex = 0}) t
    -- First element of the (host) data buffer, at the given component type.
    peekFirst :: Storable a => IO (Complex a)
    peekFirst = withTensor hostT $ \ptr -> peekElemOff (castPtr ptr) 0
    widen :: Real a => Complex a -> Complex Double
    widen (r :+ i) = realToFrac r :+ realToFrac i

-- | Reads the value of the input tensor as a 'Double' via
-- @ATen.tensor_item_double@.
--
-- NOTE(review): as an ATen @item@ call this presumably requires a
-- single-element tensor — confirm.
toDouble :: Tensor -> Double
toDouble t = unsafePerformIO $ cast1 ATen.tensor_item_double t

-- | Reads the value of the input tensor as an 'Int' via
-- @ATen.tensor_item_int64_t@.
--
-- NOTE(review): as an ATen @item@ call this presumably requires a
-- single-element tensor — confirm.
toInt :: Tensor -> Int
toInt t = unsafePerformIO $ cast1 ATen.tensor_item_int64_t t

-- | Casts the input tensor to the given data type.
_toType ::
  -- | data type to cast input to
  DType ->
  -- | input
  Tensor ->
  -- | output
  Tensor
_toType targetType t = unsafePerformIO $ cast2 ATen.tensor_toType_s t targetType

-- | A 'Tensor' is itself the single 'Tensor' it contains.
instance HasTypes Tensor Tensor where
  types_ = id

-- | Endofunctions are treated as containing no tensors: the traversal
-- leaves them untouched.
instance HasTypes (a -> a) Tensor where
  types_ _ = pure

-- | An 'Int' contains no tensors.
instance HasTypes Int Tensor where
  types_ _ = pure

-- | A 'Double' contains no tensors.
instance HasTypes Double Tensor where
  types_ _ = pure

-- | A 'Float' contains no tensors.
instance HasTypes Float Tensor where
  types_ _ = pure

-- | A 'Bool' contains no tensors.
instance HasTypes Bool Tensor where
  types_ _ = pure

-- | An 'Int' is itself the single 'Int' it contains.
instance HasTypes Int Int where
  types_ = id

-- | A 'Float' is itself the single 'Float' it contains.
instance HasTypes Float Float where
  types_ = id

-- | A 'Double' is itself the single 'Double' it contains.
instance HasTypes Double Double where
  types_ = id

-- | A 'Bool' is itself the single 'Bool' it contains.
instance HasTypes Bool Bool where
  types_ = id

-- | Casts every 'Tensor' reachable through the 'HasTypes' traversal of the
-- input to the given data type, using '_toType'.
toType :: forall a. HasTypes a Tensor => DType -> a -> a
toType targetType = over (types @Tensor @a) (_toType targetType)

-- | Moves every 'Tensor' reachable through the 'HasTypes' traversal of the
-- input to the given device, using '_toDevice'.
toDevice :: forall a. HasTypes a Tensor => Device -> a -> a
toDevice device' = over (types @Tensor @a) (_toDevice device')

-- | Casts the input tensor to given device.
--
-- Supported moves:
--
-- * same device (returns the input unchanged)
-- * @cuda:i -> cuda:j@
-- * @cpu:0 -> cuda:j@
-- * @cuda:i -> cpu:0@
-- * @cpu:0 -> mps:0@
-- * @mps:0 -> cpu:0@
--
-- For all but the no-op case, the target backend must be reported as
-- available by the build. Any other combination raises an 'error'. So does
-- a copy whose result is not on the requested device.
_toDevice ::
  -- | device to cast input to
  Device ->
  -- | input
  Tensor ->
  -- | output
  Tensor
_toDevice device' t = unsafePerformIO $ do
  hasDevice <- case deviceType device' of
    CPU -> pure True
    CUDA -> cast0 ATen.hasCUDA
    MPS -> cast0 ATen.hasMPS
  let srcDevice = Torch.Tensor.device t
  t' <-
    toDevice'
      (deviceType srcDevice)
      (deviceType device')
      (deviceIndex srcDevice)
      (deviceIndex device')
      hasDevice
  -- Verify that the copy really landed where it was asked to.
  let dstDevice = Torch.Tensor.device t'
  check
    (deviceType device')
    (deviceType dstDevice)
    (deviceIndex device')
    (deviceIndex dstDevice)
  pure t'
  where
    -- Arguments: source type, target type, source index, target index,
    -- whether the target backend is available.
    toDevice' :: DeviceType -> DeviceType -> Int16 -> Int16 -> Bool -> IO Tensor
    toDevice' dt dt' di di' _ | dt == dt' && di == di' = pure t -- do nothing
    toDevice' CUDA CUDA di di' True | di /= di' = getOpts t >>= withDeviceIndex di' >>= to t -- copy from di to di'
    toDevice' CPU CUDA 0 di' True | di' >= 0 = getOpts t >>= withDeviceIndex di' >>= to t -- copy from cpu:0 to cuda:di'
    toDevice' CUDA CPU di 0 True | di >= 0 = getOpts t >>= withDeviceType CPU >>= to t -- copy from cuda:di to cpu:0
    toDevice' CPU MPS 0 0 True = getOpts t >>= withDeviceType MPS >>= to t -- copy from cpu:0 to mps:0
    toDevice' MPS CPU 0 0 True = getOpts t >>= withDeviceType CPU >>= to t -- copy from mps:0 to cpu:0
    toDevice' dt dt' di di' _ =
      error $
        "cannot move tensor from \""
          <> show dt
          <> ":"
          <> show di
          <> "\" to \""
          <> show dt'
          <> ":"
          <> show di'
          <> "\""
    getOpts :: Tensor -> IO TensorOptions
    getOpts = cast1 ATen.tensor_options
    withDeviceType :: DeviceType -> TensorOptions -> IO TensorOptions
    withDeviceType dt opts = cast2 ATen.tensorOptions_device_D opts dt
    -- careful, setting the device index implies setting the device type to CUDA!
    withDeviceIndex :: Int16 -> TensorOptions -> IO TensorOptions
    withDeviceIndex di opts = cast2 ATen.tensorOptions_device_index_s opts di
    to :: Tensor -> TensorOptions -> IO Tensor
    to src opts = cast4 ATen.tensor_to_obb src opts nonBlocking copy
      where
        nonBlocking :: Bool
        nonBlocking = False
        copy :: Bool
        copy = False
    -- Arguments: expected type, actual type, expected index, actual index.
    check :: DeviceType -> DeviceType -> Int16 -> Int16 -> IO ()
    check dt dt' di di'
      | dt == dt' && di == di' = pure ()
      | otherwise =
          error $
            "moving of tensor failed: device should have been \""
              <> show dt
              <> ":"
              <> show di
              <> "\" but is \""
              <> show dt'
              <> ":"
              <> show di'
              <> "\""

-- | Calls @ATen.tensor_to_device@ with a reference tensor and an input
-- tensor.
--
-- NOTE(review): presumably this moves @input@ onto the device of
-- @reference@ — confirm against the binding's C++ implementation.
toDeviceWithTensor :: Tensor -> Tensor -> Tensor
toDeviceWithTensor reference input =
  unsafePerformIO $ cast2 ATen.tensor_to_device reference input

-- | Slices the input tensor along the selected dimension at the given index.
select ::
  -- | dimension to slice along
  Int ->
  -- | index in the given dimension
  Int ->
  -- | input
  Tensor ->
  -- | output
  Tensor
select axis idx t = unsafePerformIO $ cast3 ATen.tensor_select_ll t axis idx

-- | Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a LongTensor.
indexSelect ::
  -- | dim
  Int ->
  -- | indexTensor
  Tensor ->
  -- | input
  Tensor ->
  -- | output
  Tensor
indexSelect axis indexTensor t =
  unsafePerformIO $ cast3 ATen.index_select_tlt t axis indexTensor

-- | Like 'indexSelect', but takes the indices as a Haskell list. The list is
-- converted with 'asTensor' and moved to the input tensor's device before
-- indexing.
indexSelect' ::
  -- | dim
  Int ->
  -- | indexList
  [Int] ->
  -- | input
  Tensor ->
  -- | output
  Tensor
indexSelect' axis indexList t =
  indexSelect axis (_toDevice (device t) (asTensor indexList)) t

-- | Slices the input tensor along the selected dimension at the given range.
sliceDim ::
  -- | dim
  Int ->
  -- | start
  Int ->
  -- | end
  Int ->
  -- | step
  Int ->
  -- | input
  Tensor ->
  Tensor
sliceDim axis start end step t =
  unsafePerformIO $ cast5 ATen.slice_tllll t axis start end step

-- | Reports whether the input tensor is contiguous, as answered by
-- @ATen.tensor_is_contiguous@.
isContiguous ::
  Tensor ->
  Bool
isContiguous t = unsafePerformIO $ cast1 ATen.tensor_is_contiguous t

-- | A contiguous version of the tensor, obtained from the
-- 'ATen.tensor_contiguous' binding.
contiguous :: Tensor -> Tensor
contiguous input = unsafePerformIO (cast1 ATen.tensor_contiguous input)

-- | Returns a tensor with the same data and number of elements as input, but with the specified shape.
--
-- The shape list is passed straight to the 'ATen.reshape_tl' binding.
reshape ::
  -- | target shape
  [Int] ->
  -- | input
  Tensor ->
  Tensor
reshape newShape input = unsafePerformIO (cast2 ATen.reshape_tl input newShape)

--------------------------------------------------------------------------------
-- Move backend
--------------------------------------------------------------------------------

-- | Convert the tensor with the 'ATen.tensor_to_sparse_l' binding.
--
-- The sparse-dimension argument is 'dimCUnsafe' of the input (presumably its
-- number of dimensions, i.e. every dimension is made sparse — confirm).
toSparse :: Tensor -> Tensor
toSparse input = unsafePerformIO (cast2 ATen.tensor_to_sparse_l input sparseDims)
  where
    sparseDims = dimCUnsafe input

-- | Convert the tensor with the 'ATen.tensor_to_dense' binding.
toDense :: Tensor -> Tensor
toDense input = unsafePerformIO (cast1 ATen.tensor_to_dense input)

-- | Convert the tensor with the 'ATen.tensor_to_mkldnn' binding.
toMKLDNN :: Tensor -> Tensor
toMKLDNN input = unsafePerformIO (cast1 ATen.tensor_to_mkldnn input)

-- | Move the tensor to the CPU using the 'ATen.tensor_cpu' binding.
toCPU :: Tensor -> Tensor
toCPU input = unsafePerformIO (cast1 ATen.tensor_cpu input)

-- | Move the tensor to CUDA using the 'ATen.tensor_cuda' binding.
toCUDA :: Tensor -> Tensor
toCUDA input = unsafePerformIO (cast1 ATen.tensor_cuda input)

-- | Move the tensor to the MPS backend using the 'ATen.tensor_mps' binding.
toMPS :: Tensor -> Tensor
toMPS input = unsafePerformIO (cast1 ATen.tensor_mps input)

--------------------------------------------------------------------------------
-- Indexing support
--------------------------------------------------------------------------------

-- TensorIndex mirrors PyTorch's tensor indexing (the @tensor[...]@ subscript syntax).
--
-- There is one-to-one correspondence between Pytorch and Hasktorch tensor index types:
-- Pytorch                 | Hasktorch
-- -----------------------------------------------------
-- `None`                  | `None`
-- `Ellipsis`              | `Ellipsis`
-- `...`                   | `Ellipsis`
-- `123`                   | `123`
-- `True` / `False`        | `True` / `False`
-- `:`                     | `Slice ()`
-- `::`                    | `Slice ()`
-- `1:`                    | `Slice (1, None)`
-- `1::`                   | `Slice (1, None)`
-- `:3`                    | `Slice (None, 3)`
-- `:3:`                   | `Slice (None, 3)`
-- `::2`                   | `Slice (None, None, 2)`
-- `1:3`                   | `Slice (1, 3)`
-- `1::2`                  | `Slice (1, None, 2)`
-- `:3:2`                  | `Slice (None, 3, 2)`
-- `1:3:2`                 | `Slice (1, 3, 2)`
-- `torch.tensor([1, 2])`  | `asTensor [1, 2 :: Int]`

newtype RawTensorIndexList = RawTensorIndexList (ForeignPtr (ATen.StdVector ATen.TensorIndex))

newtype RawTensorIndex = RawTensorIndex (ForeignPtr ATen.TensorIndex)

-- | Index into a tensor, mirroring PyTorch's @tensor[idx]@.
--
-- The index may be any 'TensorIndex' instance: an integer, a 'Slice',
-- 'None', 'Ellipsis', a 'Bool', a 'Tensor', @()@, or a tuple of these.
(!) :: TensorIndex a => Tensor -> a -> Tensor
(Unsafe self) ! idx = unsafePerformIO $ do
  indexList <- ATen.newTensorIndexList
  -- Copy the raw indices into the libtorch list in order.
  mapM_
    (\(RawTensorIndex raw) -> ATen.tensorIndexList_push_back indexList raw)
    (pushIndex [] idx)
  Unsafe <$> ATen.index self indexList

-- | Return a copy of the tensor in which the region selected by the index is
-- overwritten with the given value (converted with 'asTensor').
--
-- The input tensor is left untouched: the in-place write goes to a clone.
maskedFill :: (TensorIndex a, TensorLike t) => Tensor -> a -> t -> Tensor
maskedFill (Unsafe src) idx value = unsafePerformIO $ do
  let Unsafe rhs = asTensor value
  dst <- ATen.clone_t src
  indexList <- ATen.newTensorIndexList
  mapM_
    (\(RawTensorIndex raw) -> ATen.tensorIndexList_push_back indexList raw)
    (pushIndex [] idx)
  _ <- ATen.index_put_ dst indexList rhs
  pure (Unsafe dst)

-- | Counterpart of PyTorch's @None@ inside an index expression.
data None = None
  deriving (Eq, Show)

-- | Counterpart of PyTorch's @...@ / @Ellipsis@ inside an index expression.
data Ellipsis = Ellipsis
  deriving (Eq, Show)

-- | A slice index (PyTorch's @start:end:step@).
--
-- The wrapped value selects the form: @()@ for @:@, a single integer for
-- @start:@, or a 2-/3-tuple of integers and 'None' for the other forms.
newtype Slice a = Slice a
  deriving (Eq, Show)

-- | 'RawTensorIndex' is just a wrapper around the foreign pointer, so
-- marshalling in either direction only wraps or unwraps the newtype.
instance Castable RawTensorIndex (ForeignPtr ATen.TensorIndex) where
  cast (RawTensorIndex ptr) k = k ptr
  uncast ptr k = k (RawTensorIndex ptr)

-- | Types usable as an index into a 'Tensor' (see '(!)' and 'maskedFill').
class TensorIndex a where
  -- | Prepend the libtorch index object(s) for a value onto the list.
  pushIndex :: [RawTensorIndex] -> a -> [RawTensorIndex]

  -- | A lens onto the region selected by an index.
  --
  -- The getter reads with '(!)' and 'asValue'. The setter writes with
  -- 'maskedFill', which operates on a copy of the tensor.
  toLens :: TensorLike b => a -> Lens' Tensor b
  default toLens :: TensorLike b => a -> Lens' Tensor b
  toLens idx f s = fmap (maskedFill s idx . asTensor) (f (asValue (s ! idx)))

-- | @None@ index.
instance {-# OVERLAPS #-} TensorIndex None where
  pushIndex vec _ =
    unsafePerformIO $ (: vec) . RawTensorIndex <$> ATen.newTensorIndexWithNone

-- | @...@ index.
instance {-# OVERLAPS #-} TensorIndex Ellipsis where
  pushIndex vec _ =
    unsafePerformIO $ (: vec) . RawTensorIndex <$> ATen.newTensorIndexWithEllipsis

-- | @True@ / @False@ index.
instance {-# OVERLAPS #-} TensorIndex Bool where
  pushIndex vec b =
    unsafePerformIO $ (: vec) . RawTensorIndex <$> ATen.newTensorIndexWithBool flag
    where
      -- The binding takes a C boolean: 1 for True, 0 for False.
      flag = if b then 1 else 0

-- | @start:end@ with an implicit step of 1.
--
-- NOTE(review): bounds are narrowed to 'CInt' with 'fromIntegral', which wraps
-- for values outside the C int range.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, a)) where
  pushIndex vec (Slice (start, end)) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice (fromIntegral start) (fromIntegral end) 1

-- | @start:end:step@.
--
-- NOTE(review): all three values are narrowed to 'CInt' with 'fromIntegral',
-- which wraps for values outside the C int range.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, a, a)) where
  pushIndex vec (Slice (start, end, step)) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice
          (fromIntegral start)
          (fromIntegral end)
          (fromIntegral step)

-- | @::step@: starts at 0; the open end is encoded as @maxBound :: CInt@.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (None, None, a)) where
  pushIndex vec (Slice (_, _, step)) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice 0 (maxBound :: CInt) (fromIntegral step)

-- | @start:@ given as a bare integer: open end (@maxBound :: CInt@), step 1.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice a) where
  pushIndex vec (Slice start) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice (fromIntegral start) (maxBound :: CInt) 1

-- | @start:@ written as @(start, None)@: open end (@maxBound :: CInt@), step 1.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, None)) where
  pushIndex vec (Slice (start, _)) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice (fromIntegral start) (maxBound :: CInt) 1

-- | @start::step@: open end encoded as @maxBound :: CInt@.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, None, a)) where
  pushIndex vec (Slice (start, _, step)) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice
          (fromIntegral start)
          (maxBound :: CInt)
          (fromIntegral step)

-- | @:end:step@: starts at 0.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (None, a, a)) where
  pushIndex vec (Slice (_, end, step)) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice 0 (fromIntegral end) (fromIntegral step)

-- | @:end@: starts at 0, step 1.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (None, a)) where
  pushIndex vec (Slice (_, end)) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice 0 (fromIntegral end) 1

-- | @:@ — the full range: start 0, open end (@maxBound :: CInt@), step 1.
instance {-# OVERLAPS #-} TensorIndex (Slice ()) where
  pushIndex vec (Slice ()) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice 0 (maxBound :: CInt) 1

-- | A single integer index (PyTorch's @t[123]@). Negative values are passed
-- through unchanged.
instance TensorIndex Int where
  pushIndex vec v = unsafePerformIO $ do
    -- The libtorch binding takes a C int. A bare 'fromIntegral' would silently
    -- wrap an out-of-range value onto a different index, so reject such values.
    if toInteger v < toInteger (minBound :: CInt) || toInteger v > toInteger (maxBound :: CInt)
      then throwIO (userError ("TensorIndex Int: index " ++ show v ++ " does not fit in a C int"))
      else do
        idx <- ATen.newTensorIndexWithInt (fromIntegral v)
        return (RawTensorIndex idx : vec)

-- | A single integer index given as an 'Integer'. Negative values are passed
-- through unchanged.
instance TensorIndex Integer where
  pushIndex vec v = unsafePerformIO $ do
    -- The libtorch binding takes a C int. An unchecked 'fromIntegral' would
    -- wrap large Integers onto an unrelated index instead of failing.
    if v < toInteger (minBound :: CInt) || v > toInteger (maxBound :: CInt)
      then throwIO (userError ("TensorIndex Integer: index " ++ show v ++ " does not fit in a C int"))
      else do
        idx <- ATen.newTensorIndexWithInt (fromIntegral v)
        return (RawTensorIndex idx : vec)

-- | A tensor used as an index (PyTorch's @t[torch.tensor([...])]@).
instance TensorIndex Tensor where
  pushIndex vec v =
    unsafePerformIO $ (: vec) <$> cast1 ATen.newTensorIndexWithTensor v

-- | @()@ indexes the full range, exactly like @Slice ()@ (PyTorch's @:@).
instance TensorIndex () where
  pushIndex vec _ = pushIndex vec (Slice ())

-- | Multi-axis index @(a, b)@ (PyTorch's @t[a, b]@).
--
-- 'pushIndex' prepends, so the last component is pushed first; the result
-- then lists the components in left-to-right order.
instance (TensorIndex a, TensorIndex b) => TensorIndex (a, b) where
  pushIndex vec (a, b) = pushIndex (pushIndex vec b) a

-- | Multi-axis index @(a, b, c)@.
--
-- Handled as @(a, (b, c))@: the pair instance pushes the nested pair before
-- @a@, which yields the same left-to-right order as pushing @c@, @b@, @a@.
instance (TensorIndex a, TensorIndex b, TensorIndex c) => TensorIndex (a, b, c) where
  pushIndex vec (a, b, c) = pushIndex vec (a, (b, c))

-- | Multi-axis index @(a, b, c, d)@, handled as @(a, (b, c, d))@ so the
-- components keep their left-to-right order.
instance (TensorIndex a, TensorIndex b, TensorIndex c, TensorIndex d) => TensorIndex (a, b, c, d) where
  pushIndex vec (a, b, c, d) = pushIndex vec (a, (b, c, d))

-- | Multi-axis index @(a, b, c, d, e)@, handled as @(a, (b, c, d, e))@ so the
-- components keep their left-to-right order.
instance (TensorIndex a, TensorIndex b, TensorIndex c, TensorIndex d, TensorIndex e) => TensorIndex (a, b, c, d, e) where
  pushIndex vec (a, b, c, d, e) = pushIndex vec (a, (b, c, d, e))

--------------------------------------------------------------------------------
-- Scalar <-> Tensor promotion
--------------------------------------------------------------------------------

-- | Read a tensor back into a Haskell value.
--
-- Unless the tensor is already on @Device CPU 0@ it is first moved with
-- 'toCPU'. It is then made contiguous if needed and passed to '_asValue'.
asValue :: TensorLike a => Tensor -> a
asValue t = _asValue (ensureContiguous (ensureCPU t))
  where
    ensureCPU x
      | device x == Device CPU 0 = x
      | otherwise = toCPU x
    ensureContiguous x
      | isContiguous x = x
      | otherwise = contiguous x

-- | Types that describe how to re-target a tensor, for example a
-- 'TensorOptions' value or another 'Tensor' used as a template.
class TensorOptionLike opts where
  -- | Return the tensor converted according to the given options.
  withTensorOptions :: Tensor -> opts -> Tensor

-- | Convert a tensor according to an explicit 'TensorOptions' value.
--
-- Calls ATen's @tensor_to_obb@ with both boolean flags off
-- (non-blocking = 'False', copy = 'False').
instance TensorOptionLike TensorOptions where
  withTensorOptions tensor options =
    unsafePerformIO (cast4 ATen.tensor_to_obb tensor options False False)

-- | Convert a tensor using another tensor as the options template.
--
-- Calls ATen's @tensor_to_tbb@ with both boolean flags off
-- (non-blocking = 'False', copy = 'False'). This is presumably the
-- @Tensor::to(other, non_blocking, copy)@ overload, which adopts the
-- template's dtype/device; confirm against the binding.
instance TensorOptionLike Tensor where
  withTensorOptions tensor template =
    unsafePerformIO (cast4 ATen.tensor_to_tbb tensor template False False)

-- | Haskell values that can be converted to and from a 'Tensor'.
--
-- Methods whose names start with an underscore are internal plumbing
-- (element dtype, shape and raw-memory access) used by the instances.
class TensorLike a where
  -- | Build a tensor from the value, then apply the given options to it.
  asTensor' :: TensorOptionLike opt => a -> opt -> Tensor
  asTensor' = withTensorOptions . asTensor

  -- | Build a tensor from the value.
  asTensor :: a -> Tensor

  -- | Decode a tensor into a value. 'asValue' wraps this after moving the
  -- tensor to CPU and making it contiguous.
  _asValue :: Tensor -> a

  -- Internal functions (like "_xxx") are below. Do not use them directly.

  -- | Element dtype of the value.
  _dtype :: DType

  -- | Shape of the value.
  _dims :: a -> [Int]

  -- | Shape of the value, or 'Nothing' when nested parts disagree in shape.
  _deepDims :: a -> Maybe [Int]

  -- | Read a value from a raw data pointer, starting at the given element
  -- offset, given the shape still to be read.
  _peekElemOff :: Ptr () -> Int -> [Int] -> IO a

  -- | Write a value to a raw data pointer, starting at the given element
  -- offset.
  _pokeElemOff :: Ptr () -> Int -> a -> IO ()

-- | 'defaultOpts' with the dtype set to 'Bool'.
bool_opts :: TensorOptions
bool_opts = withDType Bool defaultOpts

-- | 'defaultOpts' with the dtype set to 'UInt8'.
uint8_opts :: TensorOptions
uint8_opts = withDType UInt8 defaultOpts

-- | 'defaultOpts' with the dtype set to 'Int64'.
int64_opts :: TensorOptions
int64_opts = withDType Int64 defaultOpts

-- | 'defaultOpts' with the dtype set to 'Float'.
float_opts :: TensorOptions
float_opts = withDType Float defaultOpts

-- | 'defaultOpts' with the dtype set to 'Double'.
double_opts :: TensorOptions
double_opts = withDType Double defaultOpts

-- | Run an action with a raw pointer to the tensor's data.
--
-- A non-contiguous tensor is first replaced by @contiguous t@, so the action
-- always sees a contiguous layout.
-- NOTE(review): in that case the pointer belongs to the result of
-- 'contiguous', so writes through it presumably do not reach @t@. Confirm
-- before writing through the pointer.
withTensor :: Tensor -> (Ptr () -> IO a) -> IO a
withTensor t fn = _withTensor laidOut fn
  where
    laidOut
      | isContiguous t = t
      | otherwise = contiguous t

-- | The internal worker behind 'withTensor': run an action with a raw
-- pointer to the tensor's data, WITHOUT checking or fixing the memory
-- layout (the tensor may be non-contiguous).
--
-- The pointer is obtained inside 'withForeignPtr', which keeps the
-- underlying tensor alive while the action runs. Do not let the pointer
-- escape the callback.
_withTensor :: Tensor -> (Ptr () -> IO a) -> IO a
_withTensor t fn =
  cast t $ \foreignTensor ->
    withForeignPtr foreignTensor $ \tensorPtr -> do
      dataPtr <- Unmanaged.tensor_data_ptr tensorPtr
      fn dataPtr

-- | Scalars: any 'Storable' type whose 'DType' is reified via 'Reifies'.
--
-- A scalar becomes a tensor created with an empty shape (@[]@). It is
-- written and read with a single 'pokeElemOff'/'peekElemOff' at the given
-- element offset.
instance {-# OVERLAPPING #-} (Reifies a DType, Storable a) => TensorLike a where
  asTensor v = unsafePerformIO $ do
    let allocate :: [Int] -> TensorOptions -> IO Tensor
        allocate = cast2 ATen.new_empty_tensor
    t <- allocate [] (withDType (_dtype @a) defaultOpts)
    _withTensor t $ \ptr -> _pokeElemOff ptr 0 v
    pure t

  -- Refuses to reinterpret memory of a different dtype.
  _asValue t = unsafePerformIO $
    if expected == actual
      then withTensor t $ \ptr -> _peekElemOff ptr 0 []
      else
        throwIO . userError $
          concat
            [ "The infered DType of asValue is ",
              show expected,
              ", but the DType of tensor on memory is ",
              show actual,
              "."
            ]
    where
      expected = _dtype @a
      actual = dtype t

  _dtype = reflect (Proxy @a)

  _dims _ = []

  _deepDims _ = Just []

  _peekElemOff ptr offset _ = peekElemOff (castPtr ptr) offset

  _pokeElemOff ptr = pokeElemOff (castPtr ptr)

-- | Booleans are stored as one 'Word8' per element: 'True' is written as 1
-- and 'False' as 0. When reading, any non-zero byte decodes to 'True'.
instance {-# OVERLAPPING #-} TensorLike Bool where
  asTensor b = unsafePerformIO $ do
    let allocate :: [Int] -> TensorOptions -> IO Tensor
        allocate = cast2 ATen.new_empty_tensor
    t <- allocate [] (withDType (_dtype @Bool) defaultOpts)
    _withTensor t $ \ptr -> _pokeElemOff ptr 0 b
    pure t

  -- Refuses to reinterpret memory of a different dtype.
  _asValue t = unsafePerformIO $
    if expected == actual
      then withTensor t $ \ptr -> _peekElemOff ptr 0 []
      else
        throwIO . userError $
          concat
            [ "The infered DType of asValue is ",
              show expected,
              ", but the DType of tensor on memory is ",
              show actual,
              "."
            ]
    where
      expected = _dtype @Bool
      actual = dtype t

  _dtype = reflect (Proxy @Bool)

  _dims _ = []

  _deepDims _ = Just []

  _peekElemOff ptr offset _ = do
    byte <- peekElemOff (castPtr ptr) offset :: IO Word8
    pure (byte /= 0)

  _pokeElemOff ptr offset b =
    pokeElemOff (castPtr ptr) offset (fromIntegral (fromEnum b) :: Word8)

-- | A tensor is trivially tensor-like: conversion in either direction is the
-- identity, and options are applied directly with 'withTensorOptions'.
-- The internal element-level methods are not supported and raise an error.
instance {-# OVERLAPPING #-} TensorLike Tensor where
  asTensor' = withTensorOptions
  asTensor = id
  _asValue = id
  _dtype = error "Not implemented for Tensor-type"
  _dims _ = error "Not implemented for Tensor-type"
  _deepDims _ = error "Not implemented for Tensor-type"
  _peekElemOff = error "Not implemented for Tensor-type"
  _pokeElemOff = error "Not implemented for Tensor-type"

-- | A homogeneous pair is converted through a two-element list.
--
-- Only 'asTensor' and '_asValue' are supported. The remaining internal
-- methods raise an error.
instance {-# OVERLAPPING #-} TensorLike a => TensorLike (a, a) where
  asTensor (x, y) = asTensor [x, y]

  -- Decode through the list instance and require exactly two entries.
  -- A tensor of any other length is reported with an explicit message
  -- instead of an opaque irrefutable-pattern failure.
  _asValue v =
    case (_asValue v :: [a]) of
      [x, y] -> (x, y)
      xs ->
        error $
          "Torch.Tensor._asValue: a pair needs exactly 2 entries along the first dimension, but got "
            ++ show (length xs)
            ++ "."

  _dtype = error "Not implemented for tuple-type"
  _dims _ = error "Not implemented for tuple-type"
  _deepDims _ = error "Not implemented for tuple-type"
  _peekElemOff = error "Not implemented for tuple-type"
  _pokeElemOff = error "Not implemented for tuple-type"

-- | Nested lists become multi-dimensional tensors.
--
-- The shape comes from '_dims': the length of the list and, recursively, of
-- its first element. While the data is written, every sibling element is
-- checked against the first element's full shape. Mismatching (ragged)
-- input is rejected with an 'IOError'.
instance {-# OVERLAPPING #-} TensorLike a => TensorLike [a] where
  asTensor v = unsafePerformIO $ do
    let allocate :: [Int] -> TensorOptions -> IO Tensor
        allocate = cast2 ATen.new_empty_tensor
    t <- allocate (_dims v) (withDType (_dtype @a) defaultOpts)
    _withTensor t $ \ptr -> _pokeElemOff ptr 0 v
    return t

  -- Refuses to reinterpret memory of a different dtype; otherwise reads the
  -- data back following the tensor's own shape.
  _asValue t = unsafePerformIO $
    if _dtype @a == dtype t
      then withTensor t $ \ptr -> _peekElemOff ptr 0 (shape t)
      else
        throwIO $
          userError $
            "The infered DType of asValue is "
              ++ show (_dtype @a)
              ++ ", but the DType of tensor on memory is "
              ++ show (dtype t)
              ++ "."

  _dtype = _dtype @a

  -- An empty list contributes a 0-length dimension.
  _dims [] = [0]
  _dims v@(x : _) = length v : _dims x

  -- Like '_dims', but inspects every element. Returns 'Nothing' if any
  -- sibling sub-values disagree in shape.
  _deepDims [] = Just [0]
  _deepDims v@(x : xs) = do
    deepDimsX <- _deepDims x
    deepDimsXs <- traverse _deepDims xs
    if all (== deepDimsX) deepDimsXs
      then Just (length v : deepDimsX)
      else Nothing

  -- Read @d@ sub-values laid out back to back, each @product dims@
  -- elements wide.
  _peekElemOff _ _ [] = return []
  _peekElemOff ptr offset (d : dims) =
    let width = product dims
     in forM [0 .. d - 1] $ \i ->
          _peekElemOff ptr (offset + i * width) dims

  -- Write the sub-values back to back, each as wide as the first one.
  _pokeElemOff _ _ [] = return ()
  _pokeElemOff ptr offset v@(x : _) =
    let expected = _dims x
        width = product expected
     in forM_ (zip [0 ..] v) $ \(i, d) ->
          -- Compare full shapes, not just element counts. A 2x2 and a 1x4
          -- sibling have the same count, so a count-only check accepted
          -- them silently and produced a tensor whose layout did not match
          -- the input. This validation may be slow.
          if _dims d == expected
            then (_pokeElemOff @a) ptr (offset + i * width) d
            else throwIO $ userError "There are lists having different length."

-- | Values that can be flattened into a vector of tensors.
--
-- The default implementation walks the value's 'Generic' representation
-- with 'GAsTensors'.
class AsTensors as where
  toTensors :: as -> V.Vector Tensor
  default toTensors :: (Generic as, GAsTensors (Rep as)) => as -> V.Vector Tensor
  toTensors = gToTensors . from

-- | Any tensor-like value flattens to a one-element vector.
instance TensorLike a => AsTensors a where
  toTensors x = V.singleton (asTensor x)

-- | Generic-representation worker for 'AsTensors': collects one tensor per
-- field, in field order.
class GAsTensors record where
  gToTensors :: record p -> V.Vector Tensor

-- | Products: tensors from the left fields, followed by those from the right.
instance (GAsTensors ls, GAsTensors rs) => GAsTensors (ls :*: rs) where
  gToTensors (left :*: right) = gToTensors left <> gToTensors right

-- | Sums: tensors from whichever constructor is present.
instance (GAsTensors ls, GAsTensors rs) => GAsTensors (ls :+: rs) where
  gToTensors sumRep = case sumRep of
    L1 left -> gToTensors left
    R1 right -> gToTensors right

-- | Metadata wrappers (datatype/constructor/selector) are transparent.
instance (GAsTensors ls) => GAsTensors (M1 i c ls) where
  gToTensors = gToTensors . unM1

-- | A single field contributes exactly one tensor, built with 'asTensor'.
instance (TensorLike ls) => GAsTensors (K1 i ls) where
  gToTensors = V.singleton . asTensor . unK1

--------------------------------------------------------------------------------
-- Show
--------------------------------------------------------------------------------

-- | Pretty-printer for tensors.
--
-- Output is a header (@"Tensor "@, the dtype, the shape, and a trailing
-- space) followed by the elements:
--
-- * a 0-dim tensor prints one value;
-- * a 1-dim tensor prints @[a, b, ...]@;
-- * higher ranks print nested brackets, one sub-tensor per line, with
--   continuation lines indented past the header.
--
-- A tensor not on CPU device 0 is first copied there with 'toCPU'.
instance Show Tensor where
  show original = header ++ body
    where
      t :: Tensor
      t
        | device original == Device CPU 0 = original
        | otherwise = toCPU original

      body :: String
      body = case dim t of
        0 -> show0d t
        1 -> show1d t
        n -> shownd n 0 t

      header :: String
      header = "Tensor " ++ show (dtype t) ++ " " ++ show (shape t) ++ " "

      -- Blank string as wide as the header; indents continuation lines.
      padding :: String
      padding = replicate (length header) ' '

      -- TODO: this is obviously not the right way to do it,
      -- and will be terribly slow, so please fix it.
      showElems :: (Tensor -> String) -> String -> Tensor -> String
      showElems render sep x =
        "[" ++ intercalate sep [render (x ! i) | i <- [0 .. size 0 x - 1]] ++ "]"

      -- Non-negative values get a leading space so they line up with
      -- negative values, which start with '-'.
      padPositive :: (Ord a, Num a) => a -> String -> String
      padPositive x s
        | x >= 0 = ' ' : s
        | otherwise = s

      -- TODO: this assumes that scientific notation only uses one-digit exponents, which is not
      --       true in general
      padLarge :: (Ord a, Fractional a) => a -> String -> String
      padLarge x s
        | abs x >= 0.1 = s ++ "   "
        | otherwise = s

      -- One real component: 4 significant digits via 'showGFloat', then padded.
      formatReal :: Double -> String
      formatReal v = padLarge v (padPositive v (showGFloat (Just 4) v ""))

      -- A single element; the (CPU) tensor's dtype decides the format.
      show0d :: Tensor -> String
      show0d x
        | isIntegral (dtype t) = let n = toInt x in padPositive n (show n)
        | isComplex (dtype t) =
            let re :+ im = toComplex x
             in formatReal re ++ " + i" ++ formatReal im
        | otherwise = formatReal (toDouble x)

      show1d :: Tensor -> String
      show1d = showElems show0d ", "

      -- Rank-n tensor: each row is a rank-(n-1) tensor. Rows are separated
      -- by ",\n " plus the header-width padding plus @offset@ spaces, and
      -- @offset@ grows by one per nesting level.
      shownd :: Int -> Int -> Tensor -> String
      shownd n offset = showElems inner (",\n " ++ padding ++ replicate offset ' ')
        where
          inner
            | n == 2 = show1d
            | otherwise = shownd (n - 1) (offset + 1)

--------------------------------------------------------------------------------

-- Castable instances
--------------------------------------------------------------------------------

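-- NB: the ATen bindings only define these list conversions at the pointer
-- level (e.g. Castable [ForeignPtr ATen.Tensor] (ForeignPtr ATen.TensorList));
-- the instances below lift them to [Tensor] by unwrapping each element first.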
-- | Convert between a list of 'Tensor's and an ATen @TensorList@.
--
-- Each 'Tensor' is unwrapped to a @ForeignPtr ATen.Tensor@ through the
-- element-level 'Castable' instance. The resulting pointer list is then
-- passed to the pointer-level list instance. 'uncast' performs the same
-- two steps in reverse.
instance Castable [Tensor] (ForeignPtr ATen.TensorList) where
  cast tensors k = traverse toPtr tensors >>= \ptrs -> cast ptrs k
    where
      -- Using 'pure' as the continuation extracts the raw pointer.
      toPtr :: Tensor -> IO (ForeignPtr ATen.Tensor)
      toPtr t = cast t pure
  uncast list k = uncast list $ \ptrs -> traverse fromPtr ptrs >>= k
    where
      -- The signature fixes the element type, which selects the
      -- pointer-level instance for the outer 'uncast'.
      fromPtr :: ForeignPtr ATen.Tensor -> IO Tensor
      fromPtr p = uncast p pure

-- | Convert between a list of 'Tensor's and an ATen @C10List@ of tensors.
--
-- Each element is unwrapped to a @ForeignPtr ATen.Tensor@ first. The pointer
-- list is then converted by the pointer-level
-- @[ForeignPtr ATen.Tensor]@ / @ForeignPtr (ATen.C10List ATen.Tensor)@
-- instance. 'uncast' mirrors this.
instance Castable [Tensor] (ForeignPtr (ATen.C10List ATen.Tensor)) where
  cast tensors k = traverse toPtr tensors >>= \ptrs -> cast ptrs k
    where
      -- Using 'pure' as the continuation extracts the raw pointer.
      toPtr :: Tensor -> IO (ForeignPtr ATen.Tensor)
      toPtr t = cast t pure
  uncast list k = uncast list $ \ptrs -> traverse fromPtr ptrs >>= k
    where
      -- The signature fixes the element type, which selects the
      -- pointer-level instance for the outer 'uncast'.
      fromPtr :: ForeignPtr ATen.Tensor -> IO Tensor
      fromPtr p = uncast p pure

-- | Convert between a list of 'Tensor's and an ATen @C10List@ of optional
-- tensors.
--
-- Each element is unwrapped to a @ForeignPtr ATen.Tensor@ first. The pointer
-- list is then converted by the pointer-level
-- @[ForeignPtr ATen.Tensor]@ /
-- @ForeignPtr (ATen.C10List (ATen.C10Optional ATen.Tensor))@ instance.
-- 'uncast' mirrors this.
-- NOTE(review): how that pointer-level instance handles absent (empty
-- optional) entries is not visible here; confirm before relying on it.
instance Castable [Tensor] (ForeignPtr (ATen.C10List (ATen.C10Optional ATen.Tensor))) where
  cast tensors k = traverse toPtr tensors >>= \ptrs -> cast ptrs k
    where
      -- Using 'pure' as the continuation extracts the raw pointer.
      toPtr :: Tensor -> IO (ForeignPtr ATen.Tensor)
      toPtr t = cast t pure
  uncast list k = uncast list $ \ptrs -> traverse fromPtr ptrs >>= k
    where
      -- The signature fixes the element type, which selects the
      -- pointer-level instance for the outer 'uncast'.
      fromPtr :: ForeignPtr ATen.Tensor -> IO Tensor
      fromPtr p = uncast p pure