{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}
module Torch.Typed.Tensor where
import Control.Arrow
import Control.Category
import qualified Numeric.Half as N
import Data.Complex
import Data.Finite
import Data.Kind
( Constraint,
Type,
)
import Data.Maybe
import Data.Proxy
import Data.Reflection
import Data.Vector.Sized (Vector)
import qualified Data.Vector.Sized as V
import Foreign.ForeignPtr
import Foreign.Storable
import GHC.Exts
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.Functional as D hiding (select)
import Torch.HList
import Torch.Internal.Cast
import Torch.Internal.Class
( Castable (..),
CppTuple2 (..),
CppTuple3 (..),
CppTuple4 (..),
)
import qualified Torch.Internal.Type as ATen
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import Torch.Typed.Auxiliary
import Prelude hiding (id, (.))
-- | Type-level shapes whose dimensions can be reflected to the term level.
--
-- @shapeVal \@shape@ lists the dimensions in the same order as the
-- type-level list.
class KnownShape (shape :: [Nat]) where
  shapeVal :: [Int]

-- | The empty shape has no dimensions.
instance KnownShape '[] where
  shapeVal = mempty

-- | A non-empty shape reflects to its head dimension followed by the
-- reflection of its tail.
instance (KnownNat h, KnownShape t) => KnownShape (h ': t) where
  shapeVal = (:) (natValI @h) (shapeVal @t)
-- | Convert a 'Finite' index to an 'Int'.
--
-- 'getFinite' yields an 'Integer', which is then narrowed with
-- 'fromIntegral'; values outside the 'Int' range would wrap.
getFiniteI :: Finite n -> Int
getFiniteI idx = fromIntegral (getFinite idx)
-- | Type-level 'D.DType's that can be reflected to their runtime value.
class KnownDType (dtype :: D.DType) where
  dtypeVal :: D.DType

-- Each promoted 'D.DType' constructor reflects to the same runtime constructor.
instance KnownDType 'D.Bool where dtypeVal = D.Bool
instance KnownDType 'D.UInt8 where dtypeVal = D.UInt8
instance KnownDType 'D.Int8 where dtypeVal = D.Int8
instance KnownDType 'D.Int16 where dtypeVal = D.Int16
instance KnownDType 'D.Int32 where dtypeVal = D.Int32
instance KnownDType 'D.Int64 where dtypeVal = D.Int64
instance KnownDType 'D.Half where dtypeVal = D.Half
instance KnownDType 'D.Float where dtypeVal = D.Float
instance KnownDType 'D.Double where dtypeVal = D.Double
instance KnownDType 'D.ComplexHalf where dtypeVal = D.ComplexHalf
instance KnownDType 'D.ComplexFloat where dtypeVal = D.ComplexFloat
instance KnownDType 'D.ComplexDouble where dtypeVal = D.ComplexDouble
-- | Map either a Haskell element type or an already-promoted 'D.DType' to the
-- corresponding 'D.DType'.
--
-- The family is poly-kinded in its argument, so for example both
-- @ComputeDType Float@ and @ComputeDType D.Float@ reduce to @D.Float@.
-- Any other argument is rejected with a custom type error.
type family ComputeDType (dtype' :: dtype) :: D.DType where
  -- booleans
  ComputeDType Bool = D.Bool
  ComputeDType D.Bool = D.Bool
  -- integral dtypes (Haskell 'Int' is mapped to 64-bit)
  ComputeDType D.UInt8 = D.UInt8
  ComputeDType D.Int8 = D.Int8
  ComputeDType D.Int16 = D.Int16
  ComputeDType D.Int32 = D.Int32
  ComputeDType Int = D.Int64
  ComputeDType D.Int64 = D.Int64
  -- floating-point dtypes
  ComputeDType N.Half = D.Half
  ComputeDType D.Half = D.Half
  ComputeDType Float = D.Float
  ComputeDType D.Float = D.Float
  ComputeDType Double = D.Double
  ComputeDType D.Double = D.Double
  -- complex dtypes
  ComputeDType (Complex N.Half) = D.ComplexHalf
  ComputeDType D.ComplexHalf = D.ComplexHalf
  ComputeDType (Complex Float) = D.ComplexFloat
  ComputeDType D.ComplexFloat = D.ComplexFloat
  ComputeDType (Complex Double) = D.ComplexDouble
  ComputeDType D.ComplexDouble = D.ComplexDouble
  -- fallback: must stay last
  ComputeDType dtype' = TypeError (Text "Unsupported tensor type " :<>: ShowType dtype')
-- | Type-level devices, written as a (device type, device index) pair, that
-- can be reflected to a runtime 'D.Device'.
class KnownDevice (device :: (D.DeviceType, Nat)) where
  deviceVal :: D.Device

-- | A CPU device with any index @n@.
instance (KnownNat n) => KnownDevice '( 'D.CPU, n) where
  deviceVal = D.Device D.CPU (natValInt16 @n)

-- | A CUDA device with index @n@.
instance (KnownNat n) => KnownDevice '( 'D.CUDA, n) where
  deviceVal = D.Device D.CUDA (natValInt16 @n)

-- | An MPS device with index @n@.
instance (KnownNat n) => KnownDevice '( 'D.MPS, n) where
  deviceVal = D.Device D.MPS (natValInt16 @n)
-- | A dimension described by a type constructor of kind @Type -> Type@.
type Size = Type -> Type

-- | A shape described as a list of 'Size' constructors.
type Shape = [Type -> Type]

-- | Compute the 'Nat' size described by a 'Size' constructor.
--
-- Generic representations are walked structurally:
--
-- * products add up the sizes of their components;
-- * sums take the larger of their alternatives;
-- * 'Vector' fields and 'Vector' itself contribute their length @n@;
-- * any other field, and nullary constructors, count as 1.
--
-- Any other constructor @a@ is unfolded through @Rep (a ())@. Equation order
-- matters here: the specific 'K1' / 'Vector' cases must precede the
-- fallbacks below them.
type family ToNat (shape :: Size) :: Nat where
  ToNat (S1 ('MetaSel _ _ _ _) f) = ToNat f
  ToNat (D1 _ f) = ToNat f
  ToNat (C1 _ f) = ToNat f
  ToNat (l :*: r) = ToNat l + ToNat r
  ToNat (l :+: r) = If (ToNat l <=? ToNat r) (ToNat r) (ToNat l)
  ToNat (K1 R (Vector n _)) = n
  ToNat (K1 _ _) = 1
  ToNat U1 = 1
  ToNat (Vector n) = n
  ToNat a = ToNat (Rep (a ()))
-- | Apply 'ToNat' to every dimension of a 'Shape'.
type family ToNats (shape :: Shape) :: [Nat] where
  ToNats '[] = '[]
  ToNats (x ': xs) = ToNat x ': ToNats xs

-- | Describe a dimension of size @n@ by the sized 'Vector' constructor.
type family FromNat (shape :: Nat) :: Size where
  FromNat n = Vector n

-- | Apply 'FromNat' to every dimension of a list of sizes.
type family FromNats (shape :: [Nat]) :: Shape where
  FromNats '[] = '[]
  FromNats (x ': xs) = FromNat x ': FromNats xs
-- | Types that can be converted to and from a typed 'Tensor', and lowered to
-- an untyped 'D.Tensor'.
--
-- The associated type families expose the shape, device and dtype of the
-- corresponding typed tensor; 'IsUnnamed' ties a concrete
-- @device@/@dtype@/@shape@ to them.
class Unnamed t where
  -- | Shape of the corresponding typed tensor.
  type UTShape t :: [Nat]

  -- | Device of the corresponding typed tensor.
  type UTDevice t :: (D.DeviceType, Nat)

  -- | Dtype of the corresponding typed tensor.
  type UTDType t :: D.DType

  -- | Convert the value to a typed 'Tensor'.
  toUnnamed ::
    forall device dtype shape.
    IsUnnamed t device dtype shape =>
    t ->
    Tensor device dtype shape

  -- | Rebuild the value from a typed 'Tensor'.
  fromUnnamed ::
    forall device dtype shape.
    IsUnnamed t device dtype shape =>
    Tensor device dtype shape ->
    t

  -- | Forget all type-level information and return the untyped tensor.
  toDynamic ::
    t -> D.Tensor
-- | Constraint stating that @t@ is 'Unnamed' and that the given @device@,
-- @dtype@ and @shape@ equal its associated 'UTDevice', 'UTDType' and
-- 'UTShape'.
type family IsUnnamed t (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (shape :: [Nat]) :: Constraint where
  IsUnnamed t device dtype shape =
    ( Unnamed t,
      device ~ UTDevice t,
      dtype ~ UTDType t,
      shape ~ UTShape t
    )
-- | A typed 'Tensor' is trivially 'Unnamed': its associated families are its
-- own type indices, and converting to/from a typed tensor is the identity.
instance Unnamed (Tensor device dtype shape) where
  type UTShape (Tensor device dtype shape) = shape
  type UTDevice (Tensor device dtype shape) = device
  type UTDType (Tensor device dtype shape) = dtype

  toUnnamed tensor = tensor

  fromUnnamed tensor = tensor

  -- Unwrap the untyped tensor stored in the constructor.
  toDynamic tensor = case tensor of
    UnsafeMkTensor dynamic -> dynamic
-- | A tensor indexed at the type level by its device, element dtype and shape.
--
-- The constructor only wraps an untyped 'D.Tensor'. It performs no check that
-- the runtime tensor agrees with the type-level indices, so callers of
-- 'UnsafeMkTensor' are responsible for that consistency.
data Tensor (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (shape :: [Nat]) where
  UnsafeMkTensor ::
    forall device dtype shape.
    D.Tensor ->
    Tensor device dtype shape
-- | Typed tensor on CPU device index 0.
type CPUTensor = Tensor '( 'D.CPU, 0)

-- | Typed tensor on the CUDA device with index @deviceIndex@.
type CUDATensor deviceIndex = Tensor '( 'D.CUDA, deviceIndex)

-- | Typed tensor on the MPS device with index @deviceIndex@.
--
-- As with 'CUDATensor', the index is passed through to the type rather than
-- being silently replaced by 0. Note that the dtype-validity constraint
-- families in this module only accept MPS device index 0.
type MPSTensor deviceIndex = Tensor '( 'D.MPS, deviceIndex)
-- | A typed 'Tensor' whose shape is hidden existentially; only its device and
-- dtype remain visible in the type.
data UnknownShapeTensor (device :: (D.DeviceType, Nat)) (dtype :: D.DType)
  = forall shape. UnknownShapeTensor (Tensor device dtype shape)
-- | The Haskell scalar type used for the elements of a tensor with the given
-- 'D.DType' when it is converted to or from nested Haskell lists.
--
-- Only Bool, Int64, Float and Double are supported; any other dtype is
-- rejected with a custom type error.
type family ComputeHaskellType (dtype :: D.DType) :: Type where
  ComputeHaskellType D.Bool = Bool
  ComputeHaskellType D.Int64 = Int
  ComputeHaskellType D.Float = Float
  ComputeHaskellType D.Double = Double
  -- fallback: must stay last
  ComputeHaskellType dtype = TypeError (Text "Unsupported tensor type " :<>: ShowType dtype)
-- | The element type of the outermost list when a tensor of the given shape
-- is represented as nested Haskell lists of @ty@.
--
-- For a shape with k >= 1 dimensions this is @ty@ wrapped in k - 1 lists.
-- The empty shape (a scalar) is rejected with a custom type error.
type family ComputeItemType (ty :: Type) (shape :: [Nat]) :: Type where
  ComputeItemType _ '[] = TypeError (Text "Scalars are not supported")
  ComputeItemType ty (_ ': '[]) = ty
  ComputeItemType ty (_ ': h ': t) = [ComputeItemType ty (h ': t)]
-- | Build typed tensors from nested Haskell lists, and read them back.
--
-- 'fromList' returns 'Nothing' in two cases: when 'D._deepDims' cannot
-- determine the dimensions of the nested list, or when those dimensions
-- differ from the type-level @shape@. Otherwise the list is converted with
-- 'D.asTensor' and moved to @device@.
--
-- 'toList' maps 'Nothing' to the empty list. For 'Just', it copies the tensor
-- to CPU device 0 before reading its values with 'D.asValue'.
instance
  ( D.TensorLike [ComputeItemType (ComputeHaskellType dtype) shape],
    KnownDevice device,
    KnownShape shape
  ) =>
  IsList (Maybe (Tensor device dtype shape))
  where
  type Item (Maybe (Tensor device dtype shape)) = ComputeItemType (ComputeHaskellType dtype) shape

  fromList items = case D._deepDims items of
    Just dims
      | dims == shapeVal @shape ->
          Just (UnsafeMkTensor (D.toDevice (deviceVal @device) (D.asTensor items)))
    _ -> Nothing

  toList = maybe [] (D.asValue . D.toDevice (D.Device D.CPU 0) . toDynamic)
-- | Element-wise arithmetic on typed tensors that share device, dtype and
-- shape. Each method delegates to the 'Num' instance of the untyped
-- 'D.Tensor'.
--
-- NOTE(review): 'fromInteger' builds its tensor from a single 'Int'
-- regardless of @dtype@ and @shape@. The runtime dtype/shape of a literal
-- therefore presumably differ from the type-level ones, relying on runtime
-- promotion/broadcasting when combined with other tensors. Confirm this is
-- intended.
instance KnownDevice device => Num (Tensor device dtype shape) where
  a + b = UnsafeMkTensor (toDynamic a + toDynamic b)
  a - b = UnsafeMkTensor (toDynamic a - toDynamic b)
  a * b = UnsafeMkTensor (toDynamic a * toDynamic b)
  negate t = UnsafeMkTensor (negate (toDynamic t))
  abs t = UnsafeMkTensor (abs (toDynamic t))
  signum t = UnsafeMkTensor (signum (toDynamic t))
  fromInteger n =
    let scalar = D.asTensor (fromInteger n :: Int)
     in UnsafeMkTensor (D.toDevice (deviceVal @device) scalar)
-- | Element-wise division and reciprocal on typed tensors. Each method
-- delegates to the 'Fractional' instance of the untyped 'D.Tensor'.
--
-- NOTE(review): 'fromRational' builds its tensor from a single 'Float'
-- regardless of @dtype@ and @shape@. The runtime dtype/shape of a literal
-- therefore presumably differ from the type-level ones. Confirm this is
-- intended.
instance KnownDevice device => Fractional (Tensor device dtype shape) where
  a / b = UnsafeMkTensor (toDynamic a / toDynamic b)
  recip t = UnsafeMkTensor (recip (toDynamic t))
  fromRational r =
    let scalar = D.asTensor (fromRational r :: Float)
     in UnsafeMkTensor (D.toDevice (deviceVal @device) scalar)
-- | A typed tensor is shown exactly like the untyped 'D.Tensor' it wraps; the
-- type-level indices are not printed.
instance Show (Tensor device dtype shape) where
  show tensor = case tensor of
    UnsafeMkTensor dynamic -> show dynamic
-- | Reflect a type-level shape, dtype and device to their runtime values.
--
-- Instances recurse over the shape list. The dtype and device are resolved
-- at the empty shape via 'KnownDType' and 'KnownDevice'.
class TensorOptions (shape :: [Nat]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) where
  optionsRuntimeShape :: [Int]
  optionsRuntimeDType :: D.DType
  optionsRuntimeDevice :: D.Device

-- | Base case: no dimensions left; read dtype and device directly.
instance (KnownDType dtype, KnownDevice device) => TensorOptions '[] dtype device where
  optionsRuntimeShape = []
  optionsRuntimeDType = dtypeVal @dtype
  optionsRuntimeDevice = deviceVal @device

-- | Step case: emit the head dimension; defer everything else to the tail.
instance (KnownNat h, TensorOptions t dtype device) => TensorOptions (h ': t) dtype device where
  optionsRuntimeShape =
    let rest = optionsRuntimeShape @t @dtype @device
     in natValI @h : rest
  optionsRuntimeDType = optionsRuntimeDType @t @dtype @device
  optionsRuntimeDevice = optionsRuntimeDevice @t @dtype @device
-- | Require a unary class constraint @pred@ for every element of the
-- type-level list @l@.
type family All (pred :: a -> Constraint) (l :: [a]) :: Constraint where
  All _ '[] = ()
  All pred (h ': t) = (pred h, All pred t)
-- | A type-level shape hidden existentially, together with its 'KnownShape'
-- evidence.
data SomeShape where
  SomeShape :: forall (shape :: [Nat]). KnownShape shape => Proxy shape -> SomeShape

-- | Promote a runtime list of dimensions to a type-level shape.
--
-- Calls 'error' if a dimension is negative. Each dimension is checked before
-- the remaining dimensions are processed.
someShape :: [Int] -> SomeShape
someShape = foldr prepend (SomeShape (Proxy @'[]))
  where
    prepend :: Int -> SomeShape -> SomeShape
    prepend dim rest = case someNatVal (fromIntegral dim) of
      Nothing -> error "Negative dimension in someShape!"
      Just (SomeNat (Proxy :: Proxy d)) -> case rest of
        SomeShape (Proxy :: Proxy ds) -> SomeShape (Proxy @(d ': ds))
-- | A type-level 'D.DType' hidden existentially, together with its
-- 'KnownDType' evidence.
data SomeDType where
  SomeDType :: forall (dtype :: D.DType). KnownDType dtype => Proxy dtype -> SomeDType

-- | Promote a runtime 'D.DType' to a type-level one.
--
-- Covers every dtype that has a 'KnownDType' instance in this module,
-- including the complex dtypes. This lets 'withTensor' also be used on
-- complex tensors instead of failing with a pattern-match error.
someDType :: D.DType -> SomeDType
someDType D.Bool = SomeDType (Proxy @D.Bool)
someDType D.UInt8 = SomeDType (Proxy @D.UInt8)
someDType D.Int8 = SomeDType (Proxy @D.Int8)
someDType D.Int16 = SomeDType (Proxy @D.Int16)
someDType D.Int32 = SomeDType (Proxy @D.Int32)
someDType D.Int64 = SomeDType (Proxy @D.Int64)
someDType D.Half = SomeDType (Proxy @D.Half)
someDType D.Float = SomeDType (Proxy @D.Float)
someDType D.Double = SomeDType (Proxy @D.Double)
someDType D.ComplexHalf = SomeDType (Proxy @D.ComplexHalf)
someDType D.ComplexFloat = SomeDType (Proxy @D.ComplexFloat)
someDType D.ComplexDouble = SomeDType (Proxy @D.ComplexDouble)
-- | A type-level device hidden existentially, together with its
-- 'KnownDevice' evidence.
data SomeDevice where
  SomeDevice :: forall (device :: (D.DeviceType, Nat)). KnownDevice device => Proxy device -> SomeDevice

-- | Promote a runtime 'D.Device' to a type-level (device type, index) pair.
--
-- Calls 'error' if the device index is negative. CPU, CUDA and MPS device
-- types are handled.
someDevice :: D.Device -> SomeDevice
someDevice (D.Device devType devIndex) =
  maybe (error "Negative device index in someDevice!") withIndex (someNatVal (fromIntegral devIndex))
  where
    withIndex :: SomeNat -> SomeDevice
    withIndex (SomeNat (Proxy :: Proxy n)) = case devType of
      D.CPU -> SomeDevice (Proxy @'( 'D.CPU, n))
      D.CUDA -> SomeDevice (Proxy @'( 'D.CUDA, n))
      D.MPS -> SomeDevice (Proxy @'( 'D.MPS, n))
-- | Run a continuation on an untyped 'D.Tensor' after recovering its shape,
-- dtype and device at the type level.
--
-- The shape is reflected first, then the dtype, then the device, via
-- 'someShape', 'someDType' and 'someDevice'. Any runtime error raised by
-- those functions propagates to the caller.
withTensor ::
  D.Tensor ->
  ( forall shape dtype device.
    ( KnownDevice device,
      KnownDType dtype,
      KnownShape shape
    ) =>
    Tensor device dtype shape ->
    r
  ) ->
  r
withTensor dyn k =
  case someShape (D.shape dyn) of
    SomeShape (Proxy :: Proxy shape) ->
      case someDType (D.dtype dyn) of
        SomeDType (Proxy :: Proxy dtype) ->
          case someDevice (D.device dyn) of
            SomeDevice (Proxy :: Proxy device) ->
              k (UnsafeMkTensor @device @dtype @shape dyn)
-- | Run a continuation on an untyped 'D.Tensor' after recovering only its
-- shape at the type level; @device@ and @dtype@ are chosen by the caller.
--
-- Only 'D.shape' of the tensor is inspected. Nothing checks that its
-- runtime device and dtype agree with the requested @device@ and @dtype@.
withTensorShape ::
  forall device dtype r.
  ( KnownDevice device,
    KnownDType dtype
  ) =>
  D.Tensor ->
  ( forall shape.
    KnownShape shape =>
    Tensor device dtype shape ->
    r
  ) ->
  r
withTensorShape dyn k =
  case someShape (D.shape dyn) of
    SomeShape (Proxy :: Proxy shape) ->
      k (UnsafeMkTensor @device @dtype @shape dyn)
-- | Broadcast two shapes that are given in reversed order (last dimension
-- first).
--
-- Rules:
--
-- * A pair of dimensions is compatible when the two are equal, or when
--   either one is 1; the other dimension is kept.
-- * Once one shape is exhausted, the remainder of the other is kept.
-- * An incompatible pair yields 'Nothing'.
--
-- 'AppendToMaybe' is assumed to prepend onto a 'Just' list and propagate
-- 'Nothing'. Equation order matters; the catch-all must stay last.
type family ComputeBroadcast (reversedShape :: [Nat]) (reversedShape' :: [Nat]) :: Maybe [Nat] where
  ComputeBroadcast '[] reversedShape = Just reversedShape
  ComputeBroadcast reversedShape '[] = Just reversedShape
  ComputeBroadcast (h ': t) (h ': t2) = AppendToMaybe h (ComputeBroadcast t t2)
  ComputeBroadcast (h ': t) (1 ': t2) = AppendToMaybe h (ComputeBroadcast t t2)
  ComputeBroadcast (1 ': t) (h ': t2) = AppendToMaybe h (ComputeBroadcast t t2)
  ComputeBroadcast _ _ = Nothing
-- | Turn a 'ComputeBroadcast' result back into a shape by reversing it, or
-- raise a custom type error naming the two original shapes when
-- broadcasting failed.
type family CheckBroadcast (shape :: [Nat]) (shape' :: [Nat]) (result :: Maybe [Nat]) :: [Nat] where
  CheckBroadcast shape shape' Nothing =
    TypeError
      ( Text "The shapes "
          :<>: ShowType shape
          :<>: Text " and "
          :<>: ShowType shape'
          :<>: Text " cannot be broadcast"
      )
  CheckBroadcast _ _ (Just result) = Reverse result
-- | The shape obtained by broadcasting @shape@ with @shape'@.
--
-- The shapes are aligned from their trailing dimensions by reversing them
-- before 'ComputeBroadcast'. A custom type error is raised by
-- 'CheckBroadcast' if they are incompatible.
type Broadcast shape shape' =
  CheckBroadcast shape shape' (ComputeBroadcast (Reverse shape) (Reverse shape'))
-- | Which dtypes basic arithmetic accepts on a given device.
--
-- * CPU device 0: the dtype must be neither Bool nor Half.
-- * Any CUDA device, and MPS device 0: every dtype is accepted.
-- * Any other device: 'UnsupportedDTypeForDevice' is reported.
type family BasicArithmeticDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  BasicArithmeticDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsNotBool '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  BasicArithmeticDTypeIsValid '( 'D.CUDA, _) dtype = ()
  BasicArithmeticDTypeIsValid '( 'D.MPS, 0) dtype = ()
  BasicArithmeticDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype
-- | Element-wise binary arithmetic with dtype promotion and shape
-- broadcasting.
--
-- The result dtype is @DTypePromotion dtype dtype'@ and the result shape is
-- @Broadcast shape shape'@. The two input dtypes and the result dtype must
-- all satisfy 'BasicArithmeticDTypeIsValid' on @device@.
add, sub, mul, div ::
  forall shape'' shape shape' dtype dtype' dtype'' device.
  ( dtype'' ~ DTypePromotion dtype dtype',
    shape'' ~ Broadcast shape shape',
    BasicArithmeticDTypeIsValid device dtype,
    BasicArithmeticDTypeIsValid device dtype',
    BasicArithmeticDTypeIsValid device dtype''
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype' shape' ->
  Tensor device dtype'' shape''
-- Each operation lowers both operands to untyped tensors and re-wraps the
-- result of the matching untyped function.
add a b = UnsafeMkTensor (D.add (toDynamic a) (toDynamic b))
sub a b = UnsafeMkTensor (D.sub (toDynamic a) (toDynamic b))
mul a b = UnsafeMkTensor (D.mul (toDynamic a) (toDynamic b))
div a b = UnsafeMkTensor (D.div (toDynamic a) (toDynamic b))
-- | Dtype constraint for the element-wise comparison operators, per device.
--
-- * CPU (index 0): requires 'DTypeIsNotBool' and 'DTypeIsNotHalf'.
-- * CUDA (any index) and MPS (index 0): no constraint.
-- * Any other device: 'UnsupportedDTypeForDevice'.
--
-- Equations are tried in order; the last one is the catch-all.
type family ComparisonDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  ComparisonDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsNotBool '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  ComparisonDTypeIsValid '( 'D.CUDA, _) dtype = ()
  ComparisonDTypeIsValid '( 'D.MPS, 0) dtype = ()
  ComparisonDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype
-- | Element-wise comparisons with broadcasting.
--
-- Both operands live on the same @device@ but may have different dtypes
-- (each checked by 'ComparisonDTypeIsValid'). The result is a 'D.Bool'
-- tensor of the broadcast shape @'Broadcast' shape shape'@.
--
-- The operators '>.', '<.', '>=.', '<=.', '==.' and '/=.' are aliases for
-- 'gt', 'lt', 'ge', 'le', 'eq' and 'ne' respectively.
gt,
  lt,
  ge,
  le,
  eq,
  ne,
  (>.),
  (<.),
  (>=.),
  (<=.),
  (==.),
  (/=.) ::
  forall shape'' shape shape' dtype dtype' device.
  ( shape'' ~ Broadcast shape shape',
    ComparisonDTypeIsValid device dtype,
    ComparisonDTypeIsValid device dtype'
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype' shape' ->
  Tensor device 'D.Bool shape''
-- Each comparison drops to the untyped operation and re-wraps the result;
-- the shape and dtype bookkeeping lives entirely in the signature above.
gt a b = UnsafeMkTensor $ D.gt (toDynamic a) (toDynamic b)
lt a b = UnsafeMkTensor $ D.lt (toDynamic a) (toDynamic b)
ge a b = UnsafeMkTensor $ D.ge (toDynamic a) (toDynamic b)
le a b = UnsafeMkTensor $ D.le (toDynamic a) (toDynamic b)
eq a b = UnsafeMkTensor $ D.eq (toDynamic a) (toDynamic b)
ne a b = UnsafeMkTensor $ D.ne (toDynamic a) (toDynamic b)

(>.) = gt
(<.) = lt
(>=.) = ge
(<=.) = le
(==.) = eq
(/=.) = ne
-- | Result shape of a (batched) matrix multiplication.
--
-- It works on /reversed/ shapes, so the contracted dimension @k@ and the
-- output dimensions sit at the heads of the lists. Cases covered:
-- vector·vector, vector·matrix, matrix·vector and matrix·matrix. Any leading
-- batch dimensions are combined via 'ComputeBroadcast'. The resulting shape
-- is itself reversed; 'CheckMatMul' reverses it back.
--
-- The final catch-all maps every other combination to 'Nothing'. This covers
-- mismatched inner dimensions and 0-d operands, so 'CheckMatMul' reports its
-- custom type error instead of this family getting stuck unreduced.
type family ComputeMatMul (reversedShape :: [Nat]) (reversedShape' :: [Nat]) :: Maybe [Nat] where
  ComputeMatMul (k ': '[]) (k ': '[]) = Just '[]
  ComputeMatMul (k ': '[]) (m ': k ': reversedBroadcastShape') = AppendToMaybe m (ComputeBroadcast '[] reversedBroadcastShape')
  ComputeMatMul (k ': n ': reversedBroadcastShape) (k ': '[]) = AppendToMaybe n (ComputeBroadcast '[] reversedBroadcastShape)
  ComputeMatMul (k ': n ': reversedBroadcastShape) (m ': k ': reversedBroadcastShape') = AppendToMaybe m (AppendToMaybe n (ComputeBroadcast reversedBroadcastShape reversedBroadcastShape'))
  ComputeMatMul _ _ = Nothing
-- | Turn the (reversed) result of 'ComputeMatMul' back into a shape.
-- A 'Nothing' result raises a custom type error naming both operand shapes.
type family CheckMatMul (shape :: [Nat]) (shape' :: [Nat]) (result :: Maybe [Nat]) :: [Nat] where
  CheckMatMul shape shape' Nothing =
    TypeError
      ( Text "The shapes "
          :<>: ShowType shape
          :<>: Text " and "
          :<>: ShowType shape'
          :<>: Text " are not compatible with matrix multiplication"
      )
  CheckMatMul _ _ (Just result) = Reverse result

-- | Output shape of 'matmul' for operands of shapes @shape@ and @shape'@.
type MatMul shape shape' = CheckMatMul shape shape' (ComputeMatMul (Reverse shape) (Reverse shape'))
-- | Dtype constraint for 'matmul', per device.
--
-- * CPU (index 0): requires 'DTypeIsNotBool' and 'DTypeIsNotHalf'.
-- * CUDA (any index) and MPS (index 0): requires 'DTypeIsFloatingPoint'.
-- * Any other device: 'UnsupportedDTypeForDevice'.
type family MatMulDTypeIsValid (device :: (D.DeviceType, Nat)) (dtype :: D.DType) :: Constraint where
  MatMulDTypeIsValid '( 'D.CPU, 0) dtype =
    ( DTypeIsNotBool '( 'D.CPU, 0) dtype,
      DTypeIsNotHalf '( 'D.CPU, 0) dtype
    )
  MatMulDTypeIsValid '( 'D.CUDA, deviceIndex) dtype = DTypeIsFloatingPoint '( 'D.CUDA, deviceIndex) dtype
  MatMulDTypeIsValid '( 'D.MPS, 0) dtype = DTypeIsFloatingPoint '( 'D.MPS, 0) dtype
  MatMulDTypeIsValid '(deviceType, _) dtype = UnsupportedDTypeForDevice deviceType dtype
-- | Matrix product of two tensors on the same device with the same dtype.
--
-- This is a thin wrapper around the untyped 'D.matmul'. The output shape is
-- computed at the type level by 'MatMul', which rejects incompatible shapes
-- with a custom type error.
matmul ::
  forall shape'' shape shape' dtype device.
  ( shape'' ~ MatMul shape shape',
    MatMulDTypeIsValid device dtype
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape' ->
  Tensor device dtype shape''
matmul a b = UnsafeMkTensor $ D.matmul (toDynamic a) (toDynamic b)
-- | Take the slice at index @idx@ along dimension @dim@, removing that
-- dimension from the shape (@shape' ~ 'Remove' shape dim@).
--
-- Both @dim@ and @idx@ are type-level naturals, typically given by type
-- application. 'InRange' presumably validates them against @shape@.
select ::
  forall dim idx shape' shape dtype device.
  ( KnownNat dim,
    KnownNat idx,
    InRange shape dim idx,
    shape' ~ Remove shape dim
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
select t = UnsafeMkTensor $ D.select (natValI @dim) (natValI @idx) (toDynamic t)
-- | Like 'select', but the index is a runtime 'Finite' value.
--
-- The index is bounded by the size @n ~ 'Index' shape dim@ of the selected
-- dimension. The dimension @dim@ is still chosen at the type level.
selectIdx ::
  forall dim n shape' shape dtype device.
  ( KnownNat dim,
    n ~ Index shape dim,
    shape' ~ Remove shape dim
  ) =>
  Tensor device dtype shape ->
  Finite n ->
  Tensor device dtype shape'
selectIdx t idx = UnsafeMkTensor $ D.select (natValI @dim) (getFiniteI idx) (toDynamic t)
-- | Number of elements of a tensor of the given shape: the product of its
-- dimensions, and 1 for the empty (scalar) shape.
type family Numel (shape :: [Nat]) :: Nat where
  Numel '[] = 1
  Numel (h ': t) = h * Numel t
-- | Reinterpret a tensor under a new shape with the same element count.
--
-- The element count is checked at compile time via 'Numel'. The target
-- @shape'@ is the first type variable, so it can be given by type
-- application.
reshape ::
  forall shape' shape dtype device.
  ( KnownShape shape',
    Numel shape ~ Numel shape'
  ) =>
  Tensor device dtype shape ->
  Tensor device dtype shape'
reshape t = UnsafeMkTensor $ D.reshape (shapeVal @shape') (toDynamic t)
-- | Plain newtype wrapper.
--
-- It lets any 'Unnamed' tensor type be given a 'Castable' instance to
-- 'D.ATenTensor' without writing one instance per concrete type.
newtype Wrap a = Wrap {unWrap :: a}
-- | Marshal any wrapped 'Unnamed' tensor type to and from a raw
-- 'D.ATenTensor'. Both directions go through the dynamic 'D.Tensor'
-- representation ('toDynamic' / 'fromUnnamed').
instance {-# OVERLAPS #-} Unnamed t => Castable (Wrap t) D.ATenTensor where
  cast t f =
    let (D.Unsafe aten_tensor) = toDynamic (unWrap t)
     in f aten_tensor
  uncast aten_tensor f = f $ Wrap $ fromUnnamed $ UnsafeMkTensor (D.Unsafe aten_tensor)
-- | Marshal a 'NamedTensor' to and from a raw 'D.ATenTensor'.
-- This only adds or strips the 'FromTensor' / 'UnsafeMkTensor' /
-- 'D.Unsafe' wrappers.
instance Castable (NamedTensor device dtype shape) D.ATenTensor where
  cast (FromTensor (UnsafeMkTensor (D.Unsafe aten_tensor))) f = f aten_tensor
  uncast aten_tensor f = f (FromTensor (UnsafeMkTensor (D.Unsafe aten_tensor)))
-- | Marshal a typed 'Tensor' to and from a raw 'D.ATenTensor'.
-- This only adds or strips the 'UnsafeMkTensor' / 'D.Unsafe' wrappers.
instance Castable (Tensor device dtype shape) D.ATenTensor where
  cast (UnsafeMkTensor (D.Unsafe aten_tensor)) f = f aten_tensor
  uncast aten_tensor f = f $ UnsafeMkTensor (D.Unsafe aten_tensor)
-- | Marshal a list of typed tensors to and from an ATen @TensorList@.
-- Each element is cast individually through 'D.ATenTensor'.
instance Castable [Tensor device dtype shape] (ForeignPtr ATen.TensorList) where
  cast xs f = do
    -- The annotation selects the per-element 'Castable' instance.
    ptr_list <- mapM (\x -> (cast x return :: IO (ForeignPtr ATen.Tensor))) xs
    cast ptr_list f
  uncast xs f = uncast xs $ \ptr_list -> do
    tensor_list <- mapM (\(x :: ForeignPtr ATen.Tensor) -> uncast x return) ptr_list
    f tensor_list
-- | Marshal a length-indexed vector of typed tensors to and from an ATen
-- @TensorList@. Each element is cast individually through 'D.ATenTensor'.
--
-- 'uncast' fails in 'IO' (via 'fail') when the list coming back from ATen
-- does not contain exactly @n@ tensors.
instance KnownNat n => Castable (Vector n (Tensor device dtype shape)) (ForeignPtr ATen.TensorList) where
  cast xs f = do
    ptr_list <- V.toList <$> mapM (\x -> (cast x return :: IO (ForeignPtr ATen.Tensor))) xs
    cast ptr_list f
  uncast xs f = uncast xs $ \ptr_list -> do
    tensor_list <- mapM (\(x :: ForeignPtr ATen.Tensor) -> uncast x return) ptr_list
    -- 'V.fromList' only succeeds on an exact length match.
    -- 'V.fromListN' would silently drop surplus tensors and fail with an
    -- opaque pattern-match error on a short list.
    case V.fromList tensor_list of
      Just vec -> f vec
      Nothing ->
        fail $
          "uncast: expected a TensorList of length "
            ++ show (natValI @n)
            ++ ", got "
            ++ show (length tensor_list)
-- | Tag for folding an 'HList' (via 'hfoldrM') into a list of
-- 'D.ATenTensor's, one per element.
data TensorListFold = TensorListFold

instance (Castable x D.ATenTensor) => Apply' TensorListFold (x, IO [D.ATenTensor]) (IO [D.ATenTensor]) where
  -- The accumulated tail is run first, then the current element is cast and
  -- consed on.
  apply' _ (x, mxs) = do
    xs <- mxs
    x' <- cast x return
    return (x' : xs)
-- | Tag for rebuilding an 'HList' (via 'hunfoldrM') from a list of
-- 'D.ATenTensor's, casting one tensor per element.
data TensorListUnfold = TensorListUnfold

-- Terminal step: presumably chosen once the target 'HList' type is
-- exhausted. Any tensors still left over indicate a length mismatch.
instance Apply TensorListUnfold [D.ATenTensor] (IO HNothing) where
  apply _ [] = pure HNothing
  apply _ _ = fail "TensorListUnfold: more tensors than HList elements"

-- Step for one 'HList' element: cast the head and keep the rest for later
-- steps. An empty list here means there are too few tensors.
instance (Castable x D.ATenTensor) => Apply TensorListUnfold [D.ATenTensor] (IO (HJust (x, [D.ATenTensor]))) where
  apply _ (x : xs) = do
    x' <- uncast x return
    return $ HJust (x', xs)
  apply _ [] = fail "TensorListUnfold: fewer tensors than HList elements"
-- | Marshal a heterogeneous list of tensor-like values to and from a list of
-- 'D.ATenTensor's, one entry per element.
-- 'TensorListFold' is used for 'cast' and 'TensorListUnfold' for 'uncast'.
instance
  ( HFoldrM IO TensorListFold [D.ATenTensor] l [D.ATenTensor],
    Apply TensorListUnfold [D.ATenTensor] res,
    HUnfoldM IO TensorListUnfold res l,
    res ~ (HUnfoldMRes IO [D.ATenTensor] l)
  ) =>
  Castable (HList l) [D.ATenTensor]
  where
  cast xs f = f =<< go xs
    where
      go :: HList l -> IO [D.ATenTensor]
      go hs = hfoldrM TensorListFold ([] :: [D.ATenTensor]) hs
  uncast xs f = f =<< go xs
    where
      go :: [D.ATenTensor] -> IO (HList l)
      go ts = hunfoldrM TensorListUnfold ts
-- | Marshal an 'HList' to and from an ATen @TensorList@.
-- It goes through the intermediate @[ForeignPtr ATen.Tensor]@
-- representation provided by the @Castable (HList l) [D.ATenTensor]@
-- instance.
instance Castable (HList l) [D.ATenTensor] => Castable (HList l) (ForeignPtr ATen.TensorList) where
  cast xs f = do
    ts <- cast xs return :: IO [ForeignPtr ATen.Tensor]
    cast ts f
  uncast xs f = uncast xs $ \(ptrList :: [ForeignPtr ATen.Tensor]) -> do
    ts <- uncast ptrList return :: IO (HList l)
    f ts
-- | Convert to a sparse layout via the untyped 'D.toSparse'.
-- Device, dtype and shape are unchanged.
toSparse :: Tensor device dtype shape -> Tensor device dtype shape
toSparse t = UnsafeMkTensor $ D.toSparse (toDynamic t)
-- | Convert to a dense layout via the untyped 'D.toDense'.
-- Device, dtype and shape are unchanged.
toDense :: Tensor device dtype shape -> Tensor device dtype shape
toDense t = UnsafeMkTensor $ D.toDense (toDynamic t)
-- | Move a tensor from any device to the CPU via the untyped 'D.toCPU'.
-- Dtype and shape are unchanged.
toCPU ::
  forall device shape dtype.
  Tensor device dtype shape ->
  CPUTensor dtype shape
toCPU input = UnsafeMkTensor $ D.toCPU (toDynamic input)
-- | Move a tensor to CUDA device 0 via the untyped 'D.toCUDA'.
-- Dtype and shape are unchanged.
--
-- NOTE(review): @device'@ does not occur in the type. Keep it anyway:
-- removing it would shift the order of type-application arguments for
-- existing callers.
toCUDA ::
  forall device' device shape dtype.
  Tensor device dtype shape ->
  CUDATensor 0 dtype shape
toCUDA t = UnsafeMkTensor $ D.toCUDA (toDynamic t)
-- | Move a tensor to MPS device 0 via the untyped 'D.toMPS'.
-- Dtype and shape are unchanged.
--
-- NOTE(review): @device'@ does not occur in the type. Keep it anyway:
-- removing it would shift the order of type-application arguments for
-- existing callers.
toMPS ::
  forall device' device shape dtype.
  Tensor device dtype shape ->
  MPSTensor 0 dtype shape
toMPS t = UnsafeMkTensor $ D.toMPS (toDynamic t)
-- | Move any 'Unnamed' tensor-like value to the device @device'@.
--
-- @device'@ is the first quantified type variable, so it can be supplied
-- with a type application. The runtime target is obtained from 'deviceVal'
-- at @device'@. The result type is the input type with its device replaced
-- (see @ReplaceDevice''@); dtype and shape are unchanged.
toDevice ::
  forall device' device dtype shape t t'.
  ( KnownDevice device',
    IsUnnamed t device dtype shape,
    Unnamed t',
    t' ~ ReplaceDevice'' t device'
  ) =>
  t ->
  t'
toDevice input =
  let moved = D.toDevice (deviceVal @device') (toDynamic input)
   in fromUnnamed (UnsafeMkTensor moved)
-- | Convert any 'Unnamed' tensor-like value to the dtype @dtype'@.
--
-- @dtype'@ is the first quantified type variable, so it can be supplied
-- with a type application. The runtime dtype is obtained from 'dtypeVal'
-- at @dtype'@ and applied with 'D.toType'. The result type is the input
-- type with its dtype replaced (see @ReplaceDType''@).
toDType ::
  forall dtype' dtype device shape t t'.
  ( KnownDType dtype',
    IsUnnamed t device dtype shape,
    Unnamed t',
    t' ~ ReplaceDType'' t dtype'
  ) =>
  t ->
  t'
toDType input = fromUnnamed (UnsafeMkTensor converted)
  where
    -- Conversion happens on the untyped tensor; the typed wrapper is
    -- rebuilt afterwards with the new dtype in its type.
    converted = D.toType (dtypeVal @dtype') (toDynamic input)
-- | Number of dimensions of a tensor-like value.
--
-- Computed purely from type-level information via 'optionsRuntimeShape'.
-- The argument only fixes the type variables and is never inspected.
dim ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  Int
dim _ = length (optionsRuntimeShape @shape @dtype @device)
-- | Shape of a tensor-like value as a list of dimension sizes.
--
-- Read from the type via 'optionsRuntimeShape'. The argument value is
-- ignored.
shape ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  [Int]
shape = const (optionsRuntimeShape @shape @dtype @device)
-- | Runtime 'D.DType' of a tensor-like value.
--
-- Read from the type via 'optionsRuntimeDType'. The argument value is
-- ignored.
dtype ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  D.DType
dtype = const (optionsRuntimeDType @shape @dtype @device)
-- | Runtime 'D.Device' of a tensor-like value.
--
-- Read from the type via 'optionsRuntimeDevice'. The argument value is
-- ignored.
device ::
  forall device dtype shape t.
  ( TensorOptions shape dtype device,
    IsUnnamed t device dtype shape
  ) =>
  t ->
  D.Device
device = const (optionsRuntimeDevice @shape @dtype @device)
-- | Extract an 'Int' from a tensor using 'D.toInt'.
--
-- NOTE(review): the shape is unconstrained here. Presumably 'D.toInt'
-- expects a single-element tensor; confirm against its documentation.
toInt ::
  Tensor device dtype shape ->
  Int
toInt = D.toInt . toDynamic
-- | Read the value of a scalar (shape @'[]@) 'D.Float' tensor.
--
-- The tensor is first moved to the CPU with 'toCPU', then converted with
-- 'D.asValue'.
toFloat :: forall device. Tensor device 'D.Float '[] -> Float
toFloat = D.asValue . toDynamic . toCPU
-- | Read the value of a scalar (shape @'[]@) 'D.Double' tensor.
--
-- The tensor is first moved to the CPU with 'toCPU', then converted with
-- 'D.asValue'.
toDouble :: forall device. Tensor device 'D.Double '[] -> Double
toDouble = D.asValue . toDynamic . toCPU
-- | Read the value of a scalar (shape @'[]@) 'D.Bool' tensor.
--
-- The tensor is first moved to the CPU with 'toCPU', then converted with
-- 'D.asValue'.
toBool :: forall device. Tensor device 'D.Bool '[] -> Bool
toBool = D.asValue . toDynamic . toCPU
-- | Map a Haskell element type to its 'D.DType'.
--
-- Type applications @f a@ are peeled off, so the dtype is determined by the
-- innermost scalar type ('Bool', 'Int', 'Float' or 'Double').
type family ToDType a :: D.DType where
  ToDType Bool   = 'D.Bool
  ToDType Int    = 'D.Int64
  ToDType Float  = 'D.Float
  ToDType Double = 'D.Double
  -- Wrapper: ignore the outer constructor and recurse into the element.
  ToDType (f a)  = ToDType a
-- | Collect the type constructors wrapping a scalar into a 'Shape'.
--
-- Each application @f a@ contributes @f@ as the next entry, outermost
-- first. The scalar types 'Bool', 'Int', 'Float' and 'Double' end the list.
type family ToShape a :: Shape where
  ToShape Bool   = '[]
  ToShape Int    = '[]
  ToShape Float  = '[]
  ToShape Double = '[]
  -- Wrapper: record the outer constructor, then continue inward.
  ToShape (f a)  = f ': ToShape a
-- | Zero-based position of the dimension @a@ within @shape@.
--
-- The first equation matches the leftmost occurrence of @a@. Every
-- non-matching entry in front of it adds one to the index. If @a@ does not
-- occur, type checking fails with a custom 'TypeError'.
type family FindDim (a :: Size) (shape :: Shape) :: Nat where
  FindDim a (a ': _) = 0
  FindDim a (b ': ax) = 1 + FindDim a ax
  -- Reached only when @shape@ is not a cons cell, i.e. the search ran off
  -- the end of the list without finding @a@.
  FindDim a _ =
    TypeError
      ( 'Text "Could not find dimension "
          ':<>: 'ShowType a
          ':<>: 'Text " in the shape."
      )
-- | A typed tensor whose shape index is a 'Shape' (a type-level list of
-- 'Size' entries) rather than a list of 'Nat's.
--
-- 'FromTensor' wraps an ordinary 'Tensor' whose 'Nat' shape equals
-- @ToNats shape'@.
data NamedTensor (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (shape :: Shape) where
  FromTensor ::
    forall device dtype shape' shape.
    shape ~ ToNats shape' =>
    Tensor device dtype shape ->
    NamedTensor device dtype shape'
-- | Converting a 'NamedTensor' to or from its unnamed form only adds or
-- removes the 'FromTensor' constructor.
instance Unnamed (NamedTensor device dtype shape) where
  type UTShape (NamedTensor device dtype shape) = ToNats shape
  type UTDevice (NamedTensor device dtype shape) = device
  type UTDType (NamedTensor device dtype shape) = dtype

  -- Unwrap the constructor. The GADT equality in 'FromTensor' makes the
  -- inner shape equal to @ToNats shape@.
  toUnnamed named = case named of
    FromTensor inner -> inner

  fromUnnamed inner = FromTensor inner

  -- Strip both wrappers to reach the untyped tensor.
  toDynamic named = case named of
    FromTensor (UnsafeMkTensor raw) -> raw
-- | Arithmetic on 'NamedTensor' delegates to the 'Num' instance of the
-- underlying 'Tensor'. Operands are unwrapped with 'toUnnamed', the
-- operation is applied, and the result is rewrapped with 'fromUnnamed'.
instance (KnownDevice device) => Num (NamedTensor device dtype shape) where
  a + b = fromUnnamed (toUnnamed a + toUnnamed b)
  a - b = fromUnnamed (toUnnamed a - toUnnamed b)
  a * b = fromUnnamed (toUnnamed a * toUnnamed b)
  negate x = fromUnnamed (negate (toUnnamed x))
  abs x = fromUnnamed (abs (toUnnamed x))
  signum x = fromUnnamed (signum (toUnnamed x))
  -- Literal is built as an unnamed 'Tensor' first, then wrapped.
  fromInteger n = fromUnnamed (fromInteger n)
-- | Division on 'NamedTensor' delegates to the 'Fractional' instance of the
-- underlying 'Tensor', unwrapping with 'toUnnamed' and rewrapping with
-- 'fromUnnamed'.
instance KnownDevice device => Fractional (NamedTensor device dtype shape) where
  a / b = fromUnnamed (toUnnamed a / toUnnamed b)
  recip x = fromUnnamed (recip (toUnnamed x))
  -- Literal is built as an unnamed 'Tensor' first, then wrapped.
  fromRational r = fromUnnamed (fromRational r)
-- | A 'NamedTensor' is rendered exactly like its underlying unnamed
-- 'Tensor'.
instance Show (NamedTensor device dtype shape) where
  show named = show (toUnnamed named)
-- | Replace the type-level device of a 'Tensor' or 'NamedTensor' type,
-- leaving dtype and shape unchanged. Used in the result type of 'toDevice'.
type family ReplaceDevice'' (tensor :: t) (device :: (D.DeviceType, Nat)) :: t where
  ReplaceDevice'' (Tensor oldDevice dtype shape) newDevice =
    Tensor newDevice dtype shape
  ReplaceDevice'' (NamedTensor oldDevice dtype shape) newDevice =
    NamedTensor newDevice dtype shape
-- | Replace the type-level dtype of a 'Tensor' or 'NamedTensor' type,
-- leaving device and shape unchanged. Used in the result type of 'toDType'.
type family ReplaceDType'' (tensor :: t) (dtype :: D.DType) :: t where
  ReplaceDType'' (Tensor device oldDType shape) newDType =
    Tensor device newDType shape
  ReplaceDType'' (NamedTensor device oldDType shape) newDType =
    NamedTensor device newDType shape