{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Typed.Lens where
import Control.Applicative (liftA2)
import Control.Monad.State.Strict
import Control.Monad (forM)
import Data.Kind
import Data.Maybe (fromJust)
import Data.Proxy
import Data.Reflection hiding (D)
import Data.Type.Bool
import Data.Vector.Sized (Vector)
import GHC.Generics
import GHC.TypeLits
import System.IO.Unsafe
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.Functional as D hiding (select)
import qualified Torch.Functional.Internal as I
import qualified Torch.Internal.Managed.Type.TensorIndex as ATen
import Torch.Lens (Lens, Lens', Traversal, Traversal')
import qualified Torch.Tensor as T
import Torch.Typed.Auxiliary hiding (If)
import Torch.Typed.Tensor
-- | Shapes that contain the dimension tag @name@.
--
-- The default method needs the position of @name@ in @shape@ at compile
-- time ('NamedIdx'), which is exactly what the catch-all instance requires.
class HasName (name :: Type -> Type) shape where
  -- | Traverse the slices of a 'NamedTensor' taken along the dimension
  -- tagged @name@.
  --
  -- The tensor is split with 'I.unbind' along that dimension. Each slice,
  -- whose shape is @'DropName' name shape@, is handed to the visitor. The
  -- results are joined back together with 'D.stack' along the same
  -- dimension.
  --
  -- The visitor's effects run in the order in which 'I.unbind' returns the
  -- slices, so collecting the foci yields them in that order.
  --
  -- NOTE(review): when the named dimension has size 0 there are no slices
  -- and 'D.stack' receives an empty list; confirm how that behaves.
  name :: Traversal' (NamedTensor device dtype shape) (NamedTensor device dtype (DropName name shape))
  default name :: (KnownNat (NamedIdx name shape)) => Traversal' (NamedTensor device dtype shape) (NamedTensor device dtype (DropName name shape))
  -- 'traverse' runs the visitor's effects left to right and builds the
  -- result list in linear time, as a law-abiding, order-preserving
  -- traversal should.
  name func s = restack <$> traverse func slices
    where
      -- Index of @name@ in @shape@, resolved at the type level.
      dimension :: Int
      dimension = natValI @(NamedIdx name shape)

      -- The input split into one tensor per position of the named dimension.
      slices :: [NamedTensor device dtype (DropName name shape)]
      slices = map (fromUnnamed . UnsafeMkTensor) $ I.unbind (toDynamic s) dimension

      -- Inverse of 'slices': stack the (possibly updated) slices back along
      -- the named dimension.
      restack v = fromUnnamed . UnsafeMkTensor $ D.stack (D.Dim dimension) (map toDynamic v)

instance (KnownNat (NamedIdx name shape)) => HasName name shape
-- | Shapes having a dimension whose element type carries a record field
-- named @field@ (see 'GHasField').
class HasField (field :: Symbol) shape where
  -- | Lens onto the part of a 'NamedTensor' selected by @field@.
  --
  -- Viewing indexes the tensor with the per-dimension @[Maybe Int]@ built by
  -- 'fieldIdx'. Setting writes the new focus back at that same index via
  -- 'T.maskedFill'. The focus has shape @'DropField' field shape@.
  field :: Lens' (NamedTensor device dtype shape) (NamedTensor device dtype (DropField field shape))
  default field :: (FieldIdx field shape) => Lens' (NamedTensor device dtype shape) (NamedTensor device dtype (DropField field shape))
  field func s = writeBack <$> func focus
    where
      -- Untyped view of the whole input tensor.
      raw = toDynamic s

      -- One entry per dimension: 'Just' the field's position where the
      -- field lives, 'Nothing' elsewhere.
      selector = fieldIdx @field @shape Proxy

      -- The sub-tensor currently stored under @field@.
      focus :: NamedTensor device dtype (DropField field shape)
      focus = fromUnnamed . UnsafeMkTensor $ raw T.! selector

      -- Rebuild the full tensor with @v@ written at the field's position.
      writeBack :: NamedTensor device dtype (DropField field shape) -> NamedTensor device dtype shape
      writeBack v = fromUnnamed . UnsafeMkTensor $ T.maskedFill raw selector (toDynamic v)

instance {-# OVERLAPS #-} FieldIdx field shape => HasField field shape
-- | Type-level test: does a shape element (or a generic representation node)
-- contain a record selector named @field@?
--
-- Equations are tried top to bottom. 'Vector' dimensions never report a
-- field. Any other element @d@ that is not itself a representation node is
-- unfolded through @'Rep' (d ())@.
type family GHasField (field :: Symbol) f :: Bool where
  GHasField fld (S1 ('MetaSel ('Just fld) _ _ _) _) = 'True
  GHasField fld (S1 ('MetaSel _ _ _ _) _) = 'False
  GHasField fld (D1 _ inner) = GHasField fld inner
  GHasField fld (C1 _ inner) = GHasField fld inner
  GHasField fld (lhs :*: rhs) = GHasField fld lhs || GHasField fld rhs
  GHasField fld (lhs :+: rhs) = GHasField fld lhs || GHasField fld rhs
  GHasField fld (K1 _ _) = 'False
  GHasField fld U1 = 'False
  GHasField fld (Vector len) = 'False
  GHasField fld dim = GHasField fld (Rep (dim ()))
-- | Remove the first dimension of a shape whose element type has a record
-- field named @field@, as decided by 'GHasField'. If no dimension does, the
-- shape comes back unchanged.
--
-- NOTE(review): only the first matching dimension is dropped, whereas
-- 'fieldIdx' yields a 'Just' entry for every dimension carrying the field.
-- Shapes with the field on more than one dimension look inconsistent;
-- confirm this is intended.
type family DropField (field :: Symbol) (a :: [Type -> Type]) :: [Type -> Type] where
  DropField _ '[] = '[]
  DropField fld (d ': ds) = If (GHasField fld d) ds (d ': DropField fld ds)
-- | Remove the first occurrence of the dimension tag @name@ from a shape.
-- A shape that does not mention @name@ is returned unchanged.
type family DropName (name :: Type -> Type) (a :: [Type -> Type]) :: [Type -> Type] where
  DropName _ '[] = '[]
  DropName tag (tag ': rest) = rest
  DropName tag (d ': rest) = d ': DropName tag rest
-- | Index a tensor with a list holding one entry per dimension.
--
-- @Just i@ is pushed as an integer index. @Nothing@ is pushed as the slice
-- with start @0@, stop 'maxBound' and step @1@ (presumably the whole
-- dimension). The new raw indices are placed, in list order, in front of
-- the ones already accumulated.
instance {-# OVERLAPS #-} T.TensorIndex [Maybe Int] where
  pushIndex acc entries =
    -- Building the indices allocates ATen objects in 'IO'; this pure method
    -- runs that allocation with 'unsafePerformIO'.
    unsafePerformIO $ (++ acc) <$> traverse toRaw entries
    where
      toRaw =
        maybe
          (T.RawTensorIndex <$> ATen.newTensorIndexWithSlice 0 maxBound 1)
          (\i -> T.RawTensorIndex <$> ATen.newTensorIndexWithInt (fromIntegral i))
-- | Zero-based position of the first occurrence of the tag @name@ in
-- @shape@. A custom type error is reported when the tag is absent.
type family NamedIdx (name :: Type -> Type) (shape :: [Type -> Type]) :: Nat where
  NamedIdx _ '[] = TypeError (Text "There is not the name in the shape.")
  NamedIdx tag (tag ': _) = 0
  NamedIdx tag (_ ': rest) = NamedIdx tag rest + 1
-- | Per-dimension index list for @field@ over a whole shape.
--
-- There is one entry per shape element, computed by 'fieldId' on that
-- element applied to @()@.
class FieldIdx (field :: Symbol) (a :: [Type -> Type]) where
  fieldIdx :: Proxy a -> [Maybe Int]

-- | The empty shape contributes no entries.
instance FieldIdx field '[] where
  fieldIdx _ = []

-- | Look up the head dimension, then recurse on the remaining ones.
instance (FieldId field (d ()), FieldIdx field ds) => FieldIdx field (d ': ds) where
  fieldIdx _ = headEntry : tailEntries
    where
      headEntry = fieldId @field (Proxy :: Proxy (d ()))
      tailEntries = fieldIdx @field (Proxy :: Proxy ds)
-- | Position of the record field @field@ within a single shape element @a@.
-- Returns 'Nothing' when @a@ has no such field.
--
-- By default the position is found generically through @'Rep' a@.
class FieldId (field :: Symbol) a where
  fieldId :: Proxy a -> Maybe Int
  default fieldId :: (Generic a, GFieldId field (Rep a)) => Proxy a -> Maybe Int
  fieldId _ = gfieldId @field repProxy
    where
      repProxy = Proxy :: Proxy (Rep a)

-- | 'Vector' dimensions carry no named fields.
instance FieldId field (Vector n v) where
  fieldId _ = Nothing

-- | Every other 'Generic' type is searched through its representation.
instance {-# OVERLAPS #-} (Generic s, GFieldId field (Rep s)) => FieldId field s
-- | Generic search for a record selector named @field@.
--
-- 'gfieldId'' walks a generic representation and returns a pair:
--
-- * the zero-based position of the selector among the constructor's fields,
--   if found;
-- * the number of field slots the node occupies.
--
-- The slot count lets a product offset an index found on its right-hand
-- side.
class GFieldId (field :: Symbol) (a :: Type -> Type) where
  -- | Position of @field@ within the representation, if present.
  gfieldId :: Proxy a -> Maybe Int
  gfieldId p = fst $ gfieldId' @field @a p

  -- | Position of @field@ (if present) together with the node's slot count.
  gfieldId' :: Proxy a -> (Maybe Int, Int)

-- Datatype metadata: look through.
instance (GFieldId field f) => GFieldId field (M1 D t f) where
  gfieldId' _ = gfieldId' @field (Proxy :: Proxy f)

-- Constructor metadata: look through.
instance (GFieldId field f) => GFieldId field (M1 C t f) where
  gfieldId' _ = gfieldId' @field (Proxy :: Proxy f)

-- A named selector: position 0 if its name is @field@, otherwise not found.
-- Either way it occupies one slot.
instance (KnownSymbol field, KnownSymbol field_) => GFieldId field (S1 ('MetaSel ('Just field_) p f b) (Rec0 a)) where
  gfieldId' _ =
    if symbolVal (Proxy :: Proxy field) == symbolVal (Proxy :: Proxy field_)
      then (Just 0, 1)
      else (Nothing, 1)

-- A positional (unnamed) selector can never match a field name, but it still
-- occupies one slot. This instance lets types with positional fields
-- (e.g. tuples) resolve 'GFieldId' too, matching 'GHasField', which already
-- answers 'False for them.
instance GFieldId field (S1 ('MetaSel 'Nothing p f b) (Rec0 a)) where
  gfieldId' _ = (Nothing, 1)

instance GFieldId field (K1 c f) where
  gfieldId' _ = (Nothing, 1)

instance GFieldId field U1 where
  gfieldId' _ = (Nothing, 1)

-- Products: a match on the left wins. A match on the right is shifted by
-- the number of slots on the left.
instance (GFieldId field f, GFieldId field g) => GFieldId field (f :*: g) where
  gfieldId' _ =
    case (gfieldId' @field (Proxy :: Proxy f), gfieldId' @field (Proxy :: Proxy g)) of
      ((Nothing, t0), (Nothing, t1)) -> (Nothing, t0 + t1)
      ((Nothing, t0), (Just v1, t1)) -> (Just (v1 + t0), t1 + t0)
      ((Just v0, t0), (_, t1)) -> (Just v0, t0 + t1)

-- Sums: the first alternative that has the field wins; otherwise the right
-- alternative's result is returned unchanged.
instance (GFieldId field f, GFieldId field g) => GFieldId field (f :+: g) where
  gfieldId' _ =
    case (gfieldId' @field (Proxy :: Proxy f), gfieldId' @field (Proxy :: Proxy g)) of
      ((Nothing, _), a1) -> a1
      (a0@(Just _, _), _) -> a0