{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Tensor where
import Control.Exception.Safe (throwIO)
import Control.Monad (forM, forM_)
import Numeric.Half
import Data.Complex
import Data.Int (Int16, Int64)
import Data.List (intercalate)
import Data.Proxy
import Data.Reflection
import qualified Data.Vector as V
import Data.Word (Word8)
import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Ptr
import Foreign.Storable
import GHC.Generics
import Numeric
import System.IO.Unsafe
import Torch.DType
import Torch.Device
import Torch.Internal.Cast
import Torch.Internal.Class (Castable (..), CppTuple2 (..), CppTuple3 (..), CppTuple4 (..))
import qualified Torch.Internal.Const as ATen
import qualified Torch.Internal.Managed.Cast as ATen
import qualified Torch.Internal.Managed.Native as ATen
import qualified Torch.Internal.Managed.TensorFactories as LibTorch
import qualified Torch.Internal.Managed.Type.Context as ATen
import qualified Torch.Internal.Managed.Type.StdArray as ATen
import qualified Torch.Internal.Managed.Type.StdString as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Managed.Type.TensorIndex as ATen
import qualified Torch.Internal.Managed.Type.TensorOptions as ATen
import qualified Torch.Internal.Managed.Type.Extra as ATen
import qualified Torch.Internal.Type as ATen
import qualified Torch.Internal.Unmanaged.Type.Tensor as Unmanaged (tensor_data_ptr)
import Torch.Lens
import Torch.TensorOptions
-- | Managed pointer to an underlying ATen tensor object.
type ATenTensor = ForeignPtr ATen.Tensor

-- | Haskell-side handle wrapping an 'ATenTensor'.
--
-- Conversions between the handle and the raw pointer go through the
-- 'Castable' instance that follows.
newtype Tensor = Unsafe ATenTensor

-- | Both directions just unwrap or wrap the pointer and pass it to the
-- continuation; no copying happens.
instance Castable Tensor ATenTensor where
  cast (Unsafe ptr) k = k ptr
  uncast ptr k = k (Unsafe ptr)
-- | Wrapper that marks a 'Tensor' as mutable; created with
-- 'newMutableTensor' and unwrapped with 'toImmutable'.
newtype MutableTensor = MutableTensor Tensor
  deriving (Show)
-- | Wrap the result of @ATen.detach_t@ on the given tensor as a
-- 'MutableTensor'.
--
-- NOTE(review): ATen's @detach@ is documented to share storage with its
-- input, so in-place writes through the result may be observable via the
-- original 'Tensor' — confirm this aliasing is intended.
newMutableTensor :: Tensor -> IO MutableTensor
newMutableTensor input = do
  detached <- cast1 ATen.detach_t input
  pure (MutableTensor detached)
-- | Turn a 'MutableTensor' back into an ordinary 'Tensor' by applying
-- @ATen.detach_t@ to the wrapped tensor.
toImmutable :: MutableTensor -> IO Tensor
toImmutable mutable =
  case mutable of
    MutableTensor inner -> cast1 ATen.detach_t inner
-- | Number of elements in a tensor, as reported by @ATen.tensor_numel@.
numel ::
  -- | tensor to inspect
  Tensor ->
  -- | element count
  Int
numel input = unsafePerformIO (cast1 ATen.tensor_numel input)
-- | Size of a single dimension of a tensor, via @ATen.tensor_size_l@.
size ::
  -- | dimension to query
  Int ->
  -- | tensor to inspect
  Tensor ->
  -- | size along that dimension
  Int
size axis input = unsafePerformIO query
  where
    query = cast2 ATen.tensor_size_l input axis
-- | Sizes of all dimensions of a tensor, as reported by @ATen.tensor_sizes@.
shape ::
  -- | tensor to inspect
  Tensor ->
  -- | per-dimension sizes
  [Int]
shape input = unsafePerformIO (cast1 ATen.tensor_sizes input)
-- | Number of dimensions of a tensor, as reported by @ATen.tensor_dim@.
dim ::
  -- | tensor to inspect
  Tensor ->
  -- | dimension count
  Int
dim input = unsafePerformIO query
  where
    query = cast1 ATen.tensor_dim input
-- | Number of dimensions of a tensor, queried through
-- @ATen.tensor_dim_unsafe@ instead of @ATen.tensor_dim@.
--
-- NOTE(review): how the "unsafe" binding differs from the regular one is
-- not visible from this module — confirm before relying on it.
dimUnsafe ::
  -- | tensor to inspect
  Tensor ->
  -- | dimension count
  Int
dimUnsafe input = unsafePerformIO (cast1 ATen.tensor_dim_unsafe input)
-- | Number of dimensions of a tensor, queried through
-- @ATen.tensor_dim_c_unsafe@.
--
-- NOTE(review): the semantic difference from 'dim' / 'dimUnsafe' lives in
-- the C++ binding and is not visible here — confirm.
dimCUnsafe ::
  -- | tensor to inspect
  Tensor ->
  -- | dimension count
  Int
dimCUnsafe input = unsafePerformIO query
  where
    query = cast1 ATen.tensor_dim_c_unsafe input
-- | Device a tensor currently resides on.
--
-- When CUDA is available, the tensor is asked whether it is a CUDA tensor
-- and, if so, its index comes from @ATen.tensor_get_device@; MPS is not
-- consulted in that case. Otherwise, when MPS is available, the tensor is
-- asked whether it is an MPS tensor. Every other case reports CPU. CPU and
-- MPS devices are always reported with index 0.
device ::
  -- | tensor to inspect
  Tensor ->
  -- | its current device
  Device
device input = unsafePerformIO $ do
  cudaAvailable <- cast0 ATen.hasCUDA :: IO Bool
  if cudaAvailable then cudaOrCpu else mpsOrCpu
  where
    -- CUDA is present: the tensor is either on some CUDA device or on CPU.
    cudaOrCpu :: IO Device
    cudaOrCpu = do
      onCuda <- cast1 ATen.tensor_is_cuda input :: IO Bool
      if onCuda
        then cudaDevice <$> cast1 ATen.tensor_get_device input
        else pure cpuDevice

    -- No CUDA: fall back to checking MPS, and to CPU if MPS is absent too.
    mpsOrCpu :: IO Device
    mpsOrCpu = do
      mpsAvailable <- cast0 ATen.hasMPS :: IO Bool
      if not mpsAvailable
        then pure cpuDevice
        else do
          onMps <- cast1 ATen.tensor_is_mps input :: IO Bool
          pure (if onMps then mpsDevice else cpuDevice)

    cpuDevice :: Device
    cpuDevice = Device {deviceType = CPU, deviceIndex = 0}

    mpsDevice :: Device
    mpsDevice = Device {deviceType = MPS, deviceIndex = 0}

    cudaDevice :: Int -> Device
    cudaDevice idx = Device {deviceType = CUDA, deviceIndex = fromIntegral idx}
-- | Element type of a tensor, as reported by @ATen.tensor_scalar_type@.
dtype ::
  -- | tensor to inspect
  Tensor ->
  -- | its element type
  DType
dtype input = unsafePerformIO (cast1 ATen.tensor_scalar_type input)
-- | Read a tensor's value as a 'Complex' 'Double'.
--
-- * 'ComplexHalf', 'ComplexFloat' and 'ComplexDouble' tensors: element 0 is
--   read directly from the tensor's data pointer and widened to 'Double'.
-- * Any other dtype: the value comes from @ATen.tensor_item_double@ and the
--   imaginary part is 0.
--
-- 'peekElemOff' dereferences the pointer on the host, which is only valid
-- for CPU-resident data. The direct-read path therefore goes through
-- 'toCPU' first. Previously a CUDA/MPS complex tensor was read in place,
-- which dereferences a device address.
toComplex :: Tensor -> Complex Double
toComplex t = unsafePerformIO $
  case dtype t of
    ComplexHalf -> do
      r :+ i <- peekFirst :: IO (Complex Half)
      return (realToFrac r :+ realToFrac i)
    ComplexFloat -> do
      r :+ i <- peekFirst :: IO (Complex Float)
      return (realToFrac r :+ realToFrac i)
    ComplexDouble -> peekFirst
    _ -> (:+ 0) <$> cast1 ATen.tensor_item_double t
  where
    -- Read element 0, at the requested element type, from a CPU-resident
    -- view of the tensor.
    peekFirst :: Storable a => IO a
    peekFirst = withTensor (toCPU t) $ \ptr -> peekElemOff (castPtr ptr) 0
-- | Extract a tensor's value as a 'Double' via @ATen.tensor_item_double@.
--
-- NOTE(review): presumably the tensor must hold exactly one element (ATen
-- @item@ semantics) — confirm against the C++ binding.
toDouble :: Tensor -> Double
toDouble input = unsafePerformIO query
  where
    query = cast1 ATen.tensor_item_double input
-- | Extract a tensor's value as an 'Int' via @ATen.tensor_item_int64_t@.
--
-- NOTE(review): presumably the tensor must hold exactly one element (ATen
-- @item@ semantics) — confirm against the C++ binding.
toInt :: Tensor -> Int
toInt input = unsafePerformIO query
  where
    query = cast1 ATen.tensor_item_int64_t input
-- | Convert a single tensor to another element type via
-- @ATen.tensor_toType_s@.
_toType ::
  -- | target element type
  DType ->
  -- | tensor to convert
  Tensor ->
  -- | converted tensor
  Tensor
_toType target input = unsafePerformIO (cast2 ATen.tensor_toType_s input target)
-- | A 'Tensor' visits exactly one 'Tensor': itself.
instance HasTypes Tensor Tensor where
  types_ f = f

-- | Functions are leaves: traversing them for 'Tensor's visits nothing.
instance HasTypes (a -> a) Tensor where
  types_ _ x = pure x

-- | Scalar leaves contain no 'Tensor'; traversal returns them unchanged.
instance HasTypes Int Tensor where
  types_ _ x = pure x

instance HasTypes Double Tensor where
  types_ _ x = pure x

instance HasTypes Float Tensor where
  types_ _ x = pure x

instance HasTypes Bool Tensor where
  types_ _ x = pure x

-- | Each scalar type visits exactly itself.
instance HasTypes Int Int where
  types_ f = f

instance HasTypes Float Float where
  types_ f = f

instance HasTypes Double Double where
  types_ f = f

instance HasTypes Bool Bool where
  types_ f = f
-- | Convert every 'Tensor' reachable inside a value (as enumerated by its
-- 'HasTypes' traversal) to the given element type.
toType :: forall a. HasTypes a Tensor => DType -> a -> a
toType target = over (types @Tensor @a) (_toType target)
-- | Move every 'Tensor' reachable inside a value (as enumerated by its
-- 'HasTypes' traversal) to the given device.
toDevice :: forall a. HasTypes a Tensor => Device -> a -> a
toDevice target = over (types @Tensor @a) (_toDevice target)
-- | Move a single tensor to the given device and verify where it ended up.
--
-- Transitions handled (source -> destination):
--
--   * any device to itself (same type and index): returned unchanged;
--   * CUDA:i -> CUDA:j with i /= j;
--   * CPU:0 -> CUDA:j with j >= 0;
--   * CUDA:i -> CPU:0 with i >= 0;
--   * CPU:0 -> MPS:0 and MPS:0 -> CPU:0.
--
-- Every transition except the no-op also requires the destination backend
-- to be available (@ATen.hasCUDA@ / @ATen.hasMPS@; CPU always counts as
-- available). Any other request calls 'error' with a "cannot move tensor"
-- message. After a move, the device reported for the result is compared
-- with the request, and a mismatch also calls 'error'.
_toDevice ::
  -- | destination device
  Device ->
  -- | tensor to move
  Tensor ->
  -- | tensor on the destination device
  Tensor
_toDevice target t = unsafePerformIO $ do
  available <- backendAvailable (deviceType target)
  let source = Torch.Tensor.device t
  moved <-
    move
      (deviceType source)
      (deviceType target)
      (deviceIndex source)
      (deviceIndex target)
      available
  let actual = Torch.Tensor.device moved
  check
    (deviceType target)
    (deviceType actual)
    (deviceIndex target)
    (deviceIndex actual)
  pure moved
  where
    -- Whether the destination backend can be used at all.
    backendAvailable :: DeviceType -> IO Bool
    backendAvailable CPU = pure True
    backendAvailable CUDA = cast0 ATen.hasCUDA
    backendAvailable MPS = cast0 ATen.hasMPS

    -- Do the transfer. Arguments: source type, destination type,
    -- source index, destination index, destination-backend availability.
    -- Clauses are tried in order; a failing guard falls through.
    move :: DeviceType -> DeviceType -> Int16 -> Int16 -> Bool -> IO Tensor
    move srcType dstType srcIdx dstIdx _
      | srcType == dstType && srcIdx == dstIdx = pure t
    move CUDA CUDA srcIdx dstIdx True
      | srcIdx /= dstIdx = retarget (withDeviceIndex dstIdx)
    move CPU CUDA 0 dstIdx True
      | dstIdx >= 0 = retarget (withDeviceIndex dstIdx)
    move CUDA CPU srcIdx 0 True
      | srcIdx >= 0 = retarget (withDeviceType CPU)
    move CPU MPS 0 0 True = retarget (withDeviceType MPS)
    move MPS CPU 0 0 True = retarget (withDeviceType CPU)
    move srcType dstType srcIdx dstIdx _ =
      error $
        concat
          [ "cannot move tensor from \""
          , show srcType
          , ":"
          , show srcIdx
          , "\" to \""
          , show dstType
          , ":"
          , show dstIdx
          , "\""
          ]

    -- Take the tensor's current options, adjust them, and convert with them.
    retarget :: (TensorOptions -> IO TensorOptions) -> IO Tensor
    retarget adjust = getOpts t >>= adjust >>= applyOptions t

    getOpts :: Tensor -> IO TensorOptions
    getOpts = cast1 ATen.tensor_options

    withDeviceType :: DeviceType -> TensorOptions -> IO TensorOptions
    withDeviceType dt opts = cast2 ATen.tensorOptions_device_D opts dt

    -- NOTE(review): the CPU -> CUDA path only sets the index, so it relies
    -- on a device index also selecting CUDA — confirm against libtorch's
    -- TensorOptions::device_index.
    withDeviceIndex :: Int16 -> TensorOptions -> IO TensorOptions
    withDeviceIndex di opts = cast2 ATen.tensorOptions_device_index_s opts di

    -- Convert with the given options; both trailing flags
    -- (non-blocking, copy) are passed as False.
    applyOptions :: Tensor -> TensorOptions -> IO Tensor
    applyOptions src opts = cast4 ATen.tensor_to_obb src opts nonBlocking forceCopy
      where
        nonBlocking = False
        forceCopy = False

    -- Fail loudly when the requested and reported devices differ.
    check :: DeviceType -> DeviceType -> Int16 -> Int16 -> IO ()
    check wantType gotType wantIdx gotIdx
      | wantType == gotType && wantIdx == gotIdx = pure ()
      | otherwise =
          error $
            concat
              [ "moving of tensor failed: device should have been \""
              , show wantType
              , ":"
              , show wantIdx
              , "\" but is \""
              , show gotType
              , ":"
              , show gotIdx
              , "\""
              ]
-- | Call @ATen.tensor_to_device@ with @reference@ and @input@.
--
-- NOTE(review): judging by the argument names, this places @input@ on the
-- device of @reference@ — confirm against the C++ binding.
toDeviceWithTensor :: Tensor -> Tensor -> Tensor
toDeviceWithTensor reference input =
  unsafePerformIO (cast2 ATen.tensor_to_device reference input)
-- | Select a single position along a dimension via @ATen.tensor_select_ll@.
select ::
  -- | dimension to select along
  Int ->
  -- | position within that dimension
  Int ->
  -- | tensor to select from
  Tensor ->
  -- | selected slice
  Tensor
select axis position input = unsafePerformIO query
  where
    query = cast3 ATen.tensor_select_ll input axis position
-- | Gather entries along a dimension using a tensor of indices, via
-- @ATen.index_select_tlt@.
indexSelect ::
  -- | dimension to index along
  Int ->
  -- | index tensor
  Tensor ->
  -- | tensor to select from
  Tensor ->
  -- | gathered result
  Tensor
indexSelect axis indices input =
  unsafePerformIO (cast3 ATen.index_select_tlt input axis indices)
-- | Like 'indexSelect', but takes the indices as a Haskell list.
--
-- The list is turned into a tensor with 'asTensor' and moved to the input
-- tensor's device before the lookup.
indexSelect' ::
  -- | dimension to index along
  Int ->
  -- | indices
  [Int] ->
  -- | tensor to select from
  Tensor ->
  -- | gathered result
  Tensor
indexSelect' axis indices input =
  let indexTensor = _toDevice (device input) (asTensor indices)
   in unsafePerformIO $ cast3 ATen.index_select_tlt input axis indexTensor
-- | Slice a tensor along one dimension with start, end and step, via
-- @ATen.slice_tllll@.
sliceDim ::
  -- | dimension to slice
  Int ->
  -- | start
  Int ->
  -- | end
  Int ->
  -- | step
  Int ->
  -- | tensor to slice
  Tensor ->
  -- | sliced tensor
  Tensor
sliceDim axis start stop step input = unsafePerformIO query
  where
    query = cast5 ATen.slice_tllll input axis start stop step
-- | Whether the tensor is contiguous, per @ATen.tensor_is_contiguous@.
isContiguous ::
  -- | tensor to inspect
  Tensor ->
  -- | contiguity flag
  Bool
isContiguous input = unsafePerformIO (cast1 ATen.tensor_is_contiguous input)
-- | Contiguous version of a tensor, via @ATen.tensor_contiguous@.
contiguous ::
  -- | source tensor
  Tensor ->
  -- | contiguous result
  Tensor
contiguous input = unsafePerformIO query
  where
    query = cast1 ATen.tensor_contiguous input
-- | Give a tensor a new shape, via @ATen.reshape_tl@.
reshape ::
  -- | target shape
  [Int] ->
  -- | tensor to reshape
  Tensor ->
  -- | reshaped tensor
  Tensor
reshape newShape input = unsafePerformIO (cast2 ATen.reshape_tl input newShape)
-- | Convert a tensor with @ATen.tensor_to_sparse_l@, passing the tensor's
-- dimension count (from 'dimCUnsafe') as the integer argument.
--
-- NOTE(review): that argument is presumably the number of sparse
-- dimensions, which would make every dimension sparse — confirm.
toSparse :: Tensor -> Tensor
toSparse input = unsafePerformIO query
  where
    sparseDims = dimCUnsafe input
    query = cast2 ATen.tensor_to_sparse_l input sparseDims
-- | Convert a tensor via @ATen.tensor_to_dense@.
toDense :: Tensor -> Tensor
toDense input = unsafePerformIO (cast1 ATen.tensor_to_dense input)
-- | Convert a tensor via @ATen.tensor_to_mkldnn@ (presumably into the
-- MKL-DNN layout — confirm against the C++ binding).
toMKLDNN :: Tensor -> Tensor
toMKLDNN input = unsafePerformIO query
  where
    query = cast1 ATen.tensor_to_mkldnn input
-- | Move a tensor to the CPU via @ATen.tensor_cpu@.
toCPU :: Tensor -> Tensor
toCPU input = unsafePerformIO (cast1 ATen.tensor_cpu input)
-- | Move a tensor to CUDA via @ATen.tensor_cuda@.
toCUDA :: Tensor -> Tensor
toCUDA input = unsafePerformIO query
  where
    query = cast1 ATen.tensor_cuda input
-- | Move a tensor to MPS via @ATen.tensor_mps@.
toMPS :: Tensor -> Tensor
toMPS input = unsafePerformIO query
  where
    query = cast1 ATen.tensor_mps input
-- | Managed pointer to an ATen @StdVector@ of tensor indices.
newtype RawTensorIndexList
  = RawTensorIndexList (ForeignPtr (ATen.StdVector ATen.TensorIndex))

-- | Managed pointer to a single ATen tensor index.
newtype RawTensorIndex
  = RawTensorIndex (ForeignPtr ATen.TensorIndex)
-- | Index a tensor with any 'TensorIndex' value.
--
-- The index is flattened with 'pushIndex', each raw index is appended to a
-- fresh ATen index list, and the list is passed to @ATen.index@.
(!) :: TensorIndex a => Tensor -> a -> Tensor
(Unsafe raw) ! idx = unsafePerformIO $ do
  indexList <- ATen.newTensorIndexList
  mapM_
    (\(RawTensorIndex i) -> ATen.tensorIndexList_push_back indexList i)
    (pushIndex [] idx)
  Unsafe <$> ATen.index raw indexList
-- | Return a copy of the tensor in which the positions selected by the index
-- are overwritten with the given value.
--
-- The source tensor is cloned with @ATen.clone_t@ before the in-place
-- @ATen.index_put_@, so the argument itself is not mutated. The value is
-- converted with 'asTensor' (default options).
-- NOTE(review): presumably the value's dtype/device must be compatible with
-- the target tensor for @index_put_@ to succeed — confirm.
maskedFill :: (TensorIndex a, TensorLike t) => Tensor -> a -> t -> Tensor
maskedFill (Unsafe src) idx value = unsafePerformIO $ do
  out <- ATen.clone_t src
  indexList <- ATen.newTensorIndexList
  mapM_
    (\(RawTensorIndex i) -> ATen.tensorIndexList_push_back indexList i)
    (pushIndex [] idx)
  -- index_put_ mutates 'out' in place; its returned handle is not needed.
  _ <- ATen.index_put_ out indexList valuePtr
  pure (Unsafe out)
  where
    Unsafe valuePtr = asTensor value
-- | Index marker. Used on its own, it is pushed as a libtorch "None" index
-- (@ATen.newTensorIndexWithNone@). Inside a 'Slice' tuple, it stands for an
-- omitted bound.
data None = None
  deriving (Eq, Show)
-- | Index marker pushed as a libtorch "Ellipsis" index
-- (@ATen.newTensorIndexWithEllipsis@).
data Ellipsis = Ellipsis
  deriving (Eq, Show)
-- | Marks its payload as a slice specification rather than a point index.
-- The payload can be a start value, a @(start, end)@ or @(start, end, step)@
-- tuple (with 'None' for omitted bounds), or @()@ for the full range.
newtype Slice a = Slice a
  deriving (Eq, Show)
-- | 'RawTensorIndex' is a plain newtype over the managed pointer, so both
-- directions just wrap or unwrap and hand the result to the continuation.
instance Castable RawTensorIndex (ForeignPtr ATen.TensorIndex) where
  cast (RawTensorIndex ptr) k = k ptr
  uncast ptr k = k (RawTensorIndex ptr)
-- | Types usable as an index for '(!)' and 'maskedFill'.
class TensorIndex a where
  -- | Prepend the libtorch index object(s) for this index onto the list.
  -- Callers start from @[]@ and consume the result front to back.
  pushIndex :: [RawTensorIndex] -> a -> [RawTensorIndex]

  -- | Lens onto the part of a tensor selected by the index. The getter reads
  -- it with '(!)' and converts it with 'asValue'. The setter converts the new
  -- value with 'asTensor' and writes it back with 'maskedFill', which returns
  -- an updated copy.
  toLens :: TensorLike b => a -> Lens' Tensor b
  default toLens :: TensorLike b => a -> Lens' Tensor b
  toLens idx focus s =
    fmap (maskedFill s idx . asTensor) (focus (asValue (s ! idx)))
-- | A bare 'None' pushes a libtorch "None" index.
instance {-# OVERLAPS #-} TensorIndex None where
  pushIndex vec _ =
    unsafePerformIO ((: vec) . RawTensorIndex <$> ATen.newTensorIndexWithNone)
-- | 'Ellipsis' pushes a libtorch "Ellipsis" index.
instance {-# OVERLAPS #-} TensorIndex Ellipsis where
  pushIndex vec _ =
    unsafePerformIO ((: vec) . RawTensorIndex <$> ATen.newTensorIndexWithEllipsis)
-- | A 'Bool' pushes a libtorch boolean index, passed to the binding as a
-- 'CBool' (1 for 'True', 0 for 'False').
instance {-# OVERLAPS #-} TensorIndex Bool where
  pushIndex vec b =
    unsafePerformIO ((: vec) . RawTensorIndex <$> ATen.newTensorIndexWithBool flag)
    where
      flag :: CBool
      flag
        | b = 1
        | otherwise = 0
-- | @Slice (start, end)@: a slice from @start@ to @end@ with step 1
-- (presumably libtorch semantics: start inclusive, end exclusive).
--
-- The binding only accepts 'CInt'. Bounds are therefore saturated to the
-- 'CInt' range instead of wrapped, so an out-of-range bound still means
-- "before the start" or "past the end" rather than aliasing an unrelated
-- position.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, a)) where
  pushIndex vec (Slice (start, end)) = unsafePerformIO $ do
      idx <- ATen.newTensorIndexWithSlice (clampToCInt start) (clampToCInt end) 1
      return (RawTensorIndex idx : vec)
    where
      -- A bare 'fromIntegral' wraps modulo 2^32: an end of 2^32 would become
      -- 0 and silently produce an empty slice.
      clampToCInt :: Integral i => i -> CInt
      clampToCInt x = fromInteger (max lo (min hi (toInteger x)))
        where
          lo = toInteger (minBound :: CInt)
          hi = toInteger (maxBound :: CInt)
-- | @Slice (start, end, step)@: a slice with explicit bounds and step.
--
-- All three values are saturated to the 'CInt' range accepted by the binding
-- instead of wrapped. Without this, a step of 2^32 would wrap to 0 and a
-- huge end to an unrelated (possibly negative) position.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, a, a)) where
  pushIndex vec (Slice (start, end, step)) = unsafePerformIO $ do
      idx <-
        ATen.newTensorIndexWithSlice
          (clampToCInt start)
          (clampToCInt end)
          (clampToCInt step)
      return (RawTensorIndex idx : vec)
    where
      -- Saturating conversion; a bare 'fromIntegral' wraps modulo 2^32.
      clampToCInt :: Integral i => i -> CInt
      clampToCInt x = fromInteger (max lo (min hi (toInteger x)))
        where
          lo = toInteger (minBound :: CInt)
          hi = toInteger (maxBound :: CInt)
-- | @Slice (None, None, step)@: the whole dimension with the given step.
-- The omitted bounds become @0@ and @maxBound :: CInt@.
--
-- The step is saturated to the 'CInt' range rather than wrapped; a step of
-- 2^32 would otherwise wrap to 0.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (None, None, a)) where
  pushIndex vec (Slice (_, _, step)) = unsafePerformIO $ do
      idx <- ATen.newTensorIndexWithSlice 0 (maxBound :: CInt) (clampToCInt step)
      return (RawTensorIndex idx : vec)
    where
      -- Saturating conversion; a bare 'fromIntegral' wraps modulo 2^32.
      clampToCInt :: Integral i => i -> CInt
      clampToCInt x = fromInteger (max lo (min hi (toInteger x)))
        where
          lo = toInteger (minBound :: CInt)
          hi = toInteger (maxBound :: CInt)
-- | @Slice start@: from @start@ to the end of the dimension (end
-- @maxBound :: CInt@, step 1).
--
-- The start is saturated to the 'CInt' range rather than wrapped, so a huge
-- start still selects nothing instead of aliasing a small position.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice a) where
  pushIndex vec (Slice start) = unsafePerformIO $ do
      idx <- ATen.newTensorIndexWithSlice (clampToCInt start) (maxBound :: CInt) 1
      return (RawTensorIndex idx : vec)
    where
      -- Saturating conversion; a bare 'fromIntegral' wraps modulo 2^32.
      clampToCInt :: Integral i => i -> CInt
      clampToCInt x = fromInteger (max lo (min hi (toInteger x)))
        where
          lo = toInteger (minBound :: CInt)
          hi = toInteger (maxBound :: CInt)
-- | @Slice (start, None)@: from @start@ to the end of the dimension (end
-- @maxBound :: CInt@, step 1).
--
-- The start is saturated to the 'CInt' range rather than wrapped.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, None)) where
  pushIndex vec (Slice (start, _)) = unsafePerformIO $ do
      idx <- ATen.newTensorIndexWithSlice (clampToCInt start) (maxBound :: CInt) 1
      return (RawTensorIndex idx : vec)
    where
      -- Saturating conversion; a bare 'fromIntegral' wraps modulo 2^32.
      clampToCInt :: Integral i => i -> CInt
      clampToCInt x = fromInteger (max lo (min hi (toInteger x)))
        where
          lo = toInteger (minBound :: CInt)
          hi = toInteger (maxBound :: CInt)
-- | @Slice (start, None, step)@: from @start@ to the end of the dimension
-- (end @maxBound :: CInt@) with the given step.
--
-- Start and step are saturated to the 'CInt' range rather than wrapped.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (a, None, a)) where
  pushIndex vec (Slice (start, _, step)) = unsafePerformIO $ do
      idx <-
        ATen.newTensorIndexWithSlice
          (clampToCInt start)
          (maxBound :: CInt)
          (clampToCInt step)
      return (RawTensorIndex idx : vec)
    where
      -- Saturating conversion; a bare 'fromIntegral' wraps modulo 2^32.
      clampToCInt :: Integral i => i -> CInt
      clampToCInt x = fromInteger (max lo (min hi (toInteger x)))
        where
          lo = toInteger (minBound :: CInt)
          hi = toInteger (maxBound :: CInt)
-- | @Slice (None, end, step)@: from @0@ up to @end@ with the given step.
--
-- End and step are saturated to the 'CInt' range rather than wrapped, so a
-- huge end still means "past the end" instead of becoming 0 or negative.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (None, a, a)) where
  pushIndex vec (Slice (_, end, step)) = unsafePerformIO $ do
      idx <- ATen.newTensorIndexWithSlice 0 (clampToCInt end) (clampToCInt step)
      return (RawTensorIndex idx : vec)
    where
      -- Saturating conversion; a bare 'fromIntegral' wraps modulo 2^32.
      clampToCInt :: Integral i => i -> CInt
      clampToCInt x = fromInteger (max lo (min hi (toInteger x)))
        where
          lo = toInteger (minBound :: CInt)
          hi = toInteger (maxBound :: CInt)
-- | @Slice (None, end)@: from @0@ up to @end@ with step 1.
--
-- The end is saturated to the 'CInt' range rather than wrapped; an end of
-- 2^32 would otherwise wrap to 0 and yield an empty slice.
instance {-# OVERLAPS #-} (Integral a) => TensorIndex (Slice (None, a)) where
  pushIndex vec (Slice (_, end)) = unsafePerformIO $ do
      idx <- ATen.newTensorIndexWithSlice 0 (clampToCInt end) 1
      return (RawTensorIndex idx : vec)
    where
      -- Saturating conversion; a bare 'fromIntegral' wraps modulo 2^32.
      clampToCInt :: Integral i => i -> CInt
      clampToCInt x = fromInteger (max lo (min hi (toInteger x)))
        where
          lo = toInteger (minBound :: CInt)
          hi = toInteger (maxBound :: CInt)
-- | @Slice ()@: the full range of a dimension
-- (start 0, end @maxBound :: CInt@, step 1).
instance {-# OVERLAPS #-} TensorIndex (Slice ()) where
  pushIndex vec (Slice ()) =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice 0 (maxBound :: CInt) 1
-- | An 'Int' selects a single position along one dimension.
--
-- The binding only accepts a 'CInt'. A bare 'fromIntegral' used to wrap
-- out-of-range values silently (e.g. 2^32 became 0 and selected the wrong
-- element). Such values are now rejected with an 'IOError', raised when the
-- resulting index list is forced.
instance TensorIndex Int where
  pushIndex vec v = unsafePerformIO $ do
      ci <- checkedCInt v
      idx <- ATen.newTensorIndexWithInt ci
      return (RawTensorIndex idx : vec)
    where
      checkedCInt :: Int -> IO CInt
      checkedCInt n
        | inRange = pure (fromIntegral n)
        | otherwise =
            throwIO . userError $
              "TensorIndex Int: index " ++ show n ++ " does not fit in a C int"
        where
          inRange =
            toInteger n >= toInteger (minBound :: CInt)
              && toInteger n <= toInteger (maxBound :: CInt)
-- | An 'Integer' selects a single position along one dimension.
--
-- The binding only accepts a 'CInt'. A bare 'fromIntegral' used to wrap
-- out-of-range values silently, selecting an unrelated element. Such values
-- are now rejected with an 'IOError', raised when the resulting index list
-- is forced.
instance TensorIndex Integer where
  pushIndex vec v = unsafePerformIO $ do
      ci <- checkedCInt v
      idx <- ATen.newTensorIndexWithInt ci
      return (RawTensorIndex idx : vec)
    where
      checkedCInt :: Integer -> IO CInt
      checkedCInt n
        | inRange = pure (fromIntegral n)
        | otherwise =
            throwIO . userError $
              "TensorIndex Integer: index " ++ show n ++ " does not fit in a C int"
        where
          inRange =
            n >= toInteger (minBound :: CInt)
              && n <= toInteger (maxBound :: CInt)
-- | A 'Tensor' used as an index (e.g. a mask or index tensor) is wrapped
-- with @ATen.newTensorIndexWithTensor@.
instance TensorIndex Tensor where
  pushIndex vec v =
    unsafePerformIO ((: vec) <$> cast1 ATen.newTensorIndexWithTensor v)
-- | @()@: shorthand for the full range of a dimension
-- (start 0, end @maxBound :: CInt@, step 1).
instance TensorIndex () where
  pushIndex vec _ =
    unsafePerformIO $
      (: vec) . RawTensorIndex
        <$> ATen.newTensorIndexWithSlice 0 (maxBound :: CInt) 1
-- | A pair applies its components in order: the first component's index
-- objects come first in the resulting list, followed by the second's.
instance (TensorIndex a, TensorIndex b) => TensorIndex (a, b) where
  -- 'pushIndex' prepends, so the last component is pushed first.
  pushIndex vec (i1, i2) = pushIndex (pushIndex vec i2) i1
-- | A triple applies its components in order: the first component's index
-- objects come first in the resulting list.
instance (TensorIndex a, TensorIndex b, TensorIndex c) => TensorIndex (a, b, c) where
  -- 'pushIndex' prepends, so components are pushed last-to-first.
  pushIndex vec (i1, i2, i3) = pushIndex (pushIndex (pushIndex vec i3) i2) i1
-- | A 4-tuple applies its components in order: the first component's index
-- objects come first in the resulting list.
instance (TensorIndex a, TensorIndex b, TensorIndex c, TensorIndex d) => TensorIndex (a, b, c, d) where
  -- 'pushIndex' prepends, so components are pushed last-to-first.
  pushIndex vec (i1, i2, i3, i4) =
    let afterD = pushIndex vec i4
        afterC = pushIndex afterD i3
        afterB = pushIndex afterC i2
     in pushIndex afterB i1
-- | A 5-tuple applies its components in order: the first component's index
-- objects come first in the resulting list.
instance (TensorIndex a, TensorIndex b, TensorIndex c, TensorIndex d, TensorIndex e) => TensorIndex (a, b, c, d, e) where
  -- 'pushIndex' prepends, so components are pushed last-to-first.
  pushIndex vec (i1, i2, i3, i4, i5) =
    let afterE = pushIndex vec i5
        afterD = pushIndex afterE i4
        afterC = pushIndex afterD i3
        afterB = pushIndex afterC i2
     in pushIndex afterB i1
-- | Read a tensor back into a Haskell value.
--
-- The tensor is first moved to CPU with 'toCPU' (unless it is already on
-- @Device CPU 0@). It is then made contiguous if needed, and the result is
-- decoded with '_asValue'.
asValue :: TensorLike a => Tensor -> a
asValue t = _asValue onCpuContiguous
  where
    onCpu
      | device t == Device CPU 0 = t
      | otherwise = toCPU t
    onCpuContiguous
      | isContiguous onCpu = onCpu
      | otherwise = contiguous onCpu
-- | Values that describe how to convert a tensor. Instances exist here for
-- 'TensorOptions' and for another 'Tensor'.
class TensorOptionLike a where
  -- | Convert the first tensor according to the given description.
  withTensorOptions :: Tensor -> a -> Tensor
-- | Convert using explicit 'TensorOptions' via @ATen.tensor_to_obb@
-- (presumably libtorch's @Tensor::to(options, non_blocking, copy)@).
instance TensorOptionLike TensorOptions where
  withTensorOptions t opts =
    unsafePerformIO $
      cast4
        ATen.tensor_to_obb
        t
        opts
        False -- non-blocking: off
        False -- copy: off
-- | Convert using another tensor as the template via @ATen.tensor_to_tbb@
-- (presumably libtorch's @Tensor::to(other, non_blocking, copy)@, which
-- matches the other tensor's dtype and device — confirm).
instance TensorOptionLike Tensor where
  withTensorOptions t opts =
    unsafePerformIO $
      cast4
        ATen.tensor_to_tbb
        t
        opts
        False -- non-blocking: off
        False -- copy: off
-- | Haskell values that can be converted to and from a 'Tensor'.
class TensorLike a where
  -- | Build a tensor with 'asTensor', then convert it with
  -- 'withTensorOptions'.
  asTensor' :: TensorOptionLike opt => a -> opt -> Tensor
  -- | Build a tensor from the value.
  asTensor :: a -> Tensor
  -- | Decode a tensor into the value. Prefer 'asValue', which first moves
  -- the tensor to CPU and makes it contiguous.
  _asValue :: Tensor -> a
  -- | The 'DType' associated with @a@ (presumably the element dtype — confirm).
  _dtype :: DType
  -- | Shape of the value (NOTE(review): exact meaning defined by instances).
  _dims :: a -> [Int]
  -- | Shape of the value, when it can be determined (NOTE(review): see
  -- instances).
  _deepDims :: a -> Maybe [Int]
  -- | Read a value from raw memory at an offset, given the dims.
  _peekElemOff :: Ptr () -> Int -> [Int] -> IO a
  -- | Write a value into raw memory at an offset.
  _pokeElemOff :: Ptr () -> Int -> a -> IO ()

  asTensor' v = withTensorOptions (asTensor v)
-- | 'defaultOpts' with the dtype set to 'Bool'.
bool_opts :: TensorOptions
bool_opts = withDType Bool $ defaultOpts
-- | 'defaultOpts' with the dtype set to 'UInt8'.
uint8_opts :: TensorOptions
uint8_opts = withDType UInt8 $ defaultOpts
-- | 'defaultOpts' with the dtype set to 'Int64'.
int64_opts :: TensorOptions
int64_opts = withDType Int64 $ defaultOpts
-- | 'defaultOpts' with the dtype set to 'Float'.
float_opts :: TensorOptions
float_opts = withDType Float $ defaultOpts
-- | 'defaultOpts' with the dtype set to 'Double'.
double_opts :: TensorOptions
double_opts = withDType Double $ defaultOpts
-- | Run an action on the raw data pointer of a contiguous version of the
-- tensor. The tensor is made contiguous first if it is not already.
--
-- The pointer comes from @tensor_data_ptr@ inside 'withForeignPtr'. It is
-- only guaranteed to stay valid for the duration of the callback.
withTensor :: Tensor -> (Ptr () -> IO a) -> IO a
withTensor t fn = cast contiguousT $ \fptr ->
  withForeignPtr fptr $ \rawPtr -> do
    dataPtr <- Unmanaged.tensor_data_ptr rawPtr
    fn dataPtr
  where
    contiguousT
      | isContiguous t = t
      | otherwise = contiguous t
-- | Run an action on the raw data pointer of a tensor as-is, without any
-- contiguity check ('withTensor' is the variant that makes it contiguous).
--
-- The pointer is fetched with 'Unmanaged.tensor_data_ptr' inside
-- 'withForeignPtr', which keeps the tensor alive while the action runs.
-- It presumably must not escape the action.
_withTensor :: Tensor -> (Ptr () -> IO a) -> IO a
_withTensor t fn =
  cast t $ \foreignT ->
    withForeignPtr foreignT $ \rawT -> do
      dataPtr <- Unmanaged.tensor_data_ptr rawT
      fn dataPtr
-- | Scalars: any 'Storable' type whose 'DType' is available through
-- 'Reifies' is stored as a tensor of shape @[]@ holding a single element.
instance {-# OVERLAPPING #-} (Reifies a DType, Storable a) => TensorLike a where
  -- Allocate a shape-[] tensor with this type's dtype, then write the value
  -- into element 0 of its data buffer.
  asTensor v = unsafePerformIO $ do
    let newEmpty :: [Int] -> TensorOptions -> IO Tensor
        newEmpty = cast2 ATen.new_empty_tensor
        opts = withDType (_dtype @a) defaultOpts
    scalar <- newEmpty [] opts
    _withTensor scalar (\p -> _pokeElemOff p 0 v)
    pure scalar

  -- Read element 0 back. If the tensor's dtype differs from this type's
  -- dtype, an IOError is thrown instead.
  _asValue t =
    unsafePerformIO $
      if wanted == actual
        then withTensor t (\p -> _peekElemOff p 0 [])
        else
          throwIO . userError $
            concat
              [ "The infered DType of asValue is ",
                show wanted,
                ", but the DType of tensor on memory is ",
                show actual,
                "."
              ]
    where
      wanted = _dtype @a
      actual = dtype t

  _dtype = reflect (Proxy :: Proxy a)

  _dims _ = []

  _deepDims _ = Just []

  _peekElemOff ptr offset _ = peekElemOff (castPtr ptr) offset

  _pokeElemOff ptr offset v = pokeElemOff (castPtr ptr) offset v
-- | 'Bool' occupies one byte ('Word8') per element. 'True' is written as 1
-- and 'False' as 0; any non-zero byte reads back as 'True'.
instance {-# OVERLAPPING #-} TensorLike Bool where
  -- Allocate a shape-[] tensor with the Bool dtype and write the value into
  -- element 0.
  asTensor v = unsafePerformIO $ do
    let newEmpty :: [Int] -> TensorOptions -> IO Tensor
        newEmpty = cast2 ATen.new_empty_tensor
    scalar <- newEmpty [] (withDType (_dtype @Bool) defaultOpts)
    _withTensor scalar (\p -> _pokeElemOff p 0 v)
    pure scalar

  -- Read element 0 back, or throw an IOError when the dtypes disagree.
  _asValue t =
    unsafePerformIO $
      if wanted == actual
        then withTensor t (\p -> _peekElemOff p 0 [])
        else
          throwIO . userError $
            concat
              [ "The infered DType of asValue is ",
                show wanted,
                ", but the DType of tensor on memory is ",
                show actual,
                "."
              ]
    where
      wanted = _dtype @Bool
      actual = dtype t

  _dtype = reflect (Proxy :: Proxy Bool)

  _dims _ = []

  _deepDims _ = Just []

  _peekElemOff ptr offset _ = do
    byte <- peekElemOff (castPtr ptr :: Ptr Word8) offset
    pure (byte /= 0)

  _pokeElemOff ptr offset v =
    pokeElemOff (castPtr ptr :: Ptr Word8) offset (if v then 1 else 0)
-- | A 'Tensor' is trivially tensor-like: conversion in either direction
-- returns the tensor unchanged. The low-level layout helpers have no
-- meaning here and raise an error when used.
instance {-# OVERLAPPING #-} TensorLike Tensor where
  asTensor' = withTensorOptions
  asTensor t = t
  _asValue t = t
  _dtype = error "Not implemented for Tensor-type"
  _dims _ = error "Not implemented for Tensor-type"
  _deepDims _ = error "Not implemented for Tensor-type"
  _peekElemOff = error "Not implemented for Tensor-type"
  _pokeElemOff = error "Not implemented for Tensor-type"
-- | A homogeneous pair is converted through the list instance as @[a, b]@.
-- Only 'asTensor' and '_asValue' are supported; the low-level helpers are
-- not implemented for tuples and raise an error.
instance {-# OVERLAPPING #-} TensorLike a => TensorLike (a, a) where
  asTensor (a, b) = asTensor [a, b]

  -- Previously a partial irrefutable pattern (@let [a, b] = ...@), which
  -- failed with an opaque pattern-match error when the tensor did not hold
  -- exactly two elements along its first dimension. Report the actual
  -- length instead. The list is now forced as soon as the pair is.
  _asValue v =
    case _asValue v of
      [a, b] -> (a, b)
      xs ->
        error $
          "_asValue: a pair needs a tensor with exactly 2 elements along "
            ++ "its first dimension, but got "
            ++ show (length xs)

  _dtype = error "Not implemented for tuple-type"
  _dims _ = error "Not implemented for tuple-type"
  _deepDims _ = error "Not implemented for tuple-type"
  _peekElemOff = error "Not implemented for tuple-type"
  _pokeElemOff = error "Not implemented for tuple-type"
-- | Nested lists map to tensors whose rank equals the nesting depth.
--
-- The tensor shape comes from '_dims', which inspects only the first
-- element of each level. When writing ('_pokeElemOff'), every sibling must
-- have exactly the same '_dims' as the first one; otherwise an 'IOError'
-- ("There are lists having different length.") is thrown.
instance {-# OVERLAPPING #-} TensorLike a => TensorLike [a] where
  -- Allocate a tensor of the inferred shape with 'ATen.new_empty_tensor'
  -- and fill it from the nested list.
  asTensor v = unsafePerformIO $ do
    let newEmpty :: [Int] -> TensorOptions -> IO Tensor
        newEmpty = cast2 ATen.new_empty_tensor
    t <- newEmpty (_dims v) (withDType (_dtype @a) defaultOpts)
    _withTensor t $ \ptr -> _pokeElemOff ptr 0 v
    return t

  -- Read the whole tensor back using its shape, or throw an IOError when
  -- the element dtype does not match.
  _asValue t = unsafePerformIO $ do
    if _dtype @a == dtype t
      then withTensor t $ \ptr -> _peekElemOff ptr 0 (shape t)
      else
        throwIO $
          userError $
            "The infered DType of asValue is "
              ++ show (_dtype @a)
              ++ ", but the DType of tensor on memory is "
              ++ show (dtype t)
              ++ "."

  _dtype = _dtype @a

  _dims [] = [0]
  _dims v@(x : _) = length v : _dims x

  -- Unlike '_dims', this checks every element and yields 'Nothing' if any
  -- sibling's shape differs from the first one's.
  _deepDims [] = Just [0]
  _deepDims v@(x : xs) = do
    deepDimsX <- _deepDims x
    deepDimsXs <- traverse _deepDims xs
    if all (deepDimsX ==) deepDimsXs
      then return (length v : deepDimsX)
      else Nothing

  _peekElemOff _ _ [] = return []
  _peekElemOff ptr offset (d : dims) =
    -- Each of the d sub-elements occupies `product dims` scalars.
    let width = product dims
     in forM [0 .. d - 1] $ \i ->
          _peekElemOff ptr (offset + i * width) dims

  _pokeElemOff _ _ [] = return ()
  _pokeElemOff ptr offset v@(x : _) =
    forM_ (zip [0 ..] v) $ \(i, d) ->
      if _dims d == expected
        then (_pokeElemOff @a) ptr (offset + i * width) d
        else throwIO $ userError "There are lists having different length."
    where
      -- Every sibling must share the first element's full shape. Comparing
      -- only element counts (as before) accepted inputs such as
      -- [[[1,2],[3,4]], [[5,6,7,8]]], whose siblings both hold 4 scalars,
      -- and silently reinterpreted them in the first sibling's layout.
      expected = _dims x
      width = product expected
-- | Types that can be flattened into a vector of tensors.
--
-- The default implementation walks the type's 'Generic' representation
-- through 'GAsTensors'.
class AsTensors as where
  toTensors :: as -> V.Vector Tensor
  default toTensors :: (Generic as, GAsTensors (Rep as)) => as -> V.Vector Tensor
  toTensors = gToTensors . from
-- | Any 'TensorLike' value becomes a one-element vector holding its tensor.
instance TensorLike a => AsTensors a where
  toTensors x = V.singleton (asTensor x)
-- | Generic worker behind the default 'toTensors'. It collects the tensors
-- found in a 'Generic' representation type @f@.
class GAsTensors f where
  gToTensors :: f p -> V.Vector Tensor
-- | Products: the left operand's tensors come first, followed by the right
-- operand's.
instance (GAsTensors ls, GAsTensors rs) => GAsTensors (ls :*: rs) where
  gToTensors (left :*: right) = (V.++) (gToTensors left) (gToTensors right)
-- | Sums: only the constructor that is actually present contributes
-- tensors.
instance (GAsTensors ls, GAsTensors rs) => GAsTensors (ls :+: rs) where
  gToTensors sumRep =
    case sumRep of
      L1 l -> gToTensors l
      R1 r -> gToTensors r
-- | Metadata wrappers (datatype/constructor/selector) are unwrapped
-- transparently.
instance (GAsTensors ls) => GAsTensors (M1 i c ls) where
  gToTensors = gToTensors . unM1
-- | A single field becomes exactly one tensor, produced by 'asTensor'.
instance (TensorLike ls) => GAsTensors (K1 i ls) where
  gToTensors = V.singleton . asTensor . unK1
-- | Pretty-prints a tensor as a header (@Tensor \<dtype\> \<shape\> @)
-- followed by its elements in nested brackets. A tensor not on
-- @Device CPU 0@ is first copied with 'toCPU'.
instance Show Tensor where
  show original = header ++ body
    where
      cpu :: Tensor
      cpu
        | device original == Device CPU 0 = original
        | otherwise = toCPU original

      body :: String
      body =
        case dim cpu of
          0 -> renderScalar cpu
          1 -> renderVector cpu
          n -> renderNd n 0 cpu

      -- The header's width is reused to indent the rows of nested output.
      header :: String
      header = "Tensor " ++ show (dtype cpu) ++ " " ++ show (shape cpu) ++ " "

      indent :: String
      indent = replicate (length header) ' '

      -- Render every slice along dimension 0, separated by sep, in brackets.
      bracketed :: (Tensor -> String) -> String -> Tensor -> String
      bracketed renderItem sep x =
        "[" ++ intercalate sep [renderItem (x ! i) | i <- [0 .. size 0 x - 1]] ++ "]"

      -- Leading space for non-negative values, so signs line up.
      padNonNegative :: (Ord a, Num a) => a -> ShowS
      padNonNegative v s
        | v >= 0 = ' ' : s
        | otherwise = s

      -- Trailing space for values whose magnitude is at least 0.1.
      padMagnitude :: (Ord a, Fractional a) => a -> ShowS
      padMagnitude v s
        | abs v >= 0.1 = s ++ " "
        | otherwise = s

      renderReal :: Double -> String
      renderReal v = padMagnitude v (padNonNegative v (showGFloat (Just 4) v ""))

      renderScalar :: Tensor -> String
      renderScalar x
        | isIntegral ty = let k = toInt x in padNonNegative k (show k)
        | isComplex ty =
            let re :+ im = toComplex x
             in renderReal re ++ " + i" ++ renderReal im
        | otherwise = renderReal (toDouble x)
        where
          ty = dtype cpu

      renderVector :: Tensor -> String
      renderVector = bracketed renderScalar ", "

      -- n is the remaining rank (>= 2); depth adds one space per nesting level.
      renderNd :: Int -> Int -> Tensor -> String
      renderNd n depth = bracketed inner (",\n " ++ indent ++ replicate depth ' ')
        where
          inner
            | n == 2 = renderVector
            | otherwise = renderNd (n - 1) (depth + 1)
-- | Marshals a list of 'Tensor's to and from an ATen @TensorList@.
--
-- Each element goes through the @Castable Tensor ATenTensor@ instance.
-- The list of raw handles is handled by the 'Castable' instance for
-- @[ForeignPtr ATen.Tensor]@, which is provided elsewhere.
instance Castable [Tensor] (ForeignPtr ATen.TensorList) where
  -- Unwrap every tensor to its ATen handle, then marshal the handle list.
  cast tensors k = traverse toPtr tensors >>= \ptrs -> cast ptrs k
    where
      toPtr :: Tensor -> IO (ForeignPtr ATen.Tensor)
      toPtr tensor = cast tensor pure

  -- Unmarshal into raw ATen handles, wrap each one back into a 'Tensor',
  -- and pass the rebuilt list to the continuation.
  uncast listPtr k =
    uncast listPtr $ \ptrs -> traverse fromPtr ptrs >>= k
    where
      fromPtr :: ForeignPtr ATen.Tensor -> IO Tensor
      fromPtr ptr = uncast ptr pure
-- | Marshals a list of 'Tensor's to and from a @c10::List@ of tensors.
--
-- Each element goes through the @Castable Tensor ATenTensor@ instance.
-- The list of raw handles is handled by the 'Castable' instance for
-- @[ForeignPtr ATen.Tensor]@, which is provided elsewhere.
instance Castable [Tensor] (ForeignPtr (ATen.C10List ATen.Tensor)) where
  -- Unwrap every tensor to its ATen handle, then marshal the handle list.
  cast tensors k = traverse toPtr tensors >>= \ptrs -> cast ptrs k
    where
      toPtr :: Tensor -> IO (ForeignPtr ATen.Tensor)
      toPtr tensor = cast tensor pure

  -- Unmarshal into raw ATen handles, wrap each one back into a 'Tensor',
  -- and pass the rebuilt list to the continuation.
  uncast listPtr k =
    uncast listPtr $ \ptrs -> traverse fromPtr ptrs >>= k
    where
      fromPtr :: ForeignPtr ATen.Tensor -> IO Tensor
      fromPtr ptr = uncast ptr pure
-- | Marshals a list of 'Tensor's to and from a @c10::List@ of optional
-- tensors.
--
-- Each element goes through the @Castable Tensor ATenTensor@ instance.
-- The list of raw handles is handled by the 'Castable' instance for
-- @[ForeignPtr ATen.Tensor]@, which is provided elsewhere.
-- NOTE(review): how absent (None) entries are treated on 'uncast' depends
-- on that instance, which is not visible here. Confirm before relying on it.
instance Castable [Tensor] (ForeignPtr (ATen.C10List (ATen.C10Optional ATen.Tensor))) where
  -- Unwrap every tensor to its ATen handle, then marshal the handle list.
  cast tensors k = traverse toPtr tensors >>= \ptrs -> cast ptrs k
    where
      toPtr :: Tensor -> IO (ForeignPtr ATen.Tensor)
      toPtr tensor = cast tensor pure

  -- Unmarshal into raw ATen handles, wrap each one back into a 'Tensor',
  -- and pass the rebuilt list to the continuation.
  uncast listPtr k =
    uncast listPtr $ \ptrs -> traverse fromPtr ptrs >>= k
    where
      fromPtr :: ForeignPtr ATen.Tensor -> IO Tensor
      fromPtr ptr = uncast ptr pure