{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Lens where
import Control.Monad.Identity
import Control.Monad.State.Strict
import GHC.Generics
-- | A van Laarhoven-style lens.  Given a way to turn the focused @a@ into a
-- @b@ inside some 'Functor' @g@, it rebuilds the whole structure, turning an
-- @s@ into a @t@ inside that same functor.
type Lens s t a b =
  forall g. Functor g =>
  (a -> g b) -> s -> g t

-- | Type-preserving 'Lens': focus and structure keep their types.
type Lens' s a = Lens s s a a

-- | Same shape as 'Lens', but the functor must be 'Applicative', which lets
-- an implementation visit several foci and combine the results (e.g. with
-- '<*>').
type Traversal s t a b =
  forall g. Applicative g =>
  (a -> g b) -> s -> g t

-- | Type-preserving 'Traversal'; the form produced by 'types' and consumed by
-- 'over', 'flattenValues' and 'replaceValues'.
type Traversal' s a = Traversal s s a a
-- | Types @s@ that can expose every value of type @a@ they contain as a
-- 'Traversal''.
--
-- The default implementation goes through "GHC.Generics": the value is
-- converted to its generic representation with 'from', traversed with
-- 'gtypes', and rebuilt with 'to'.  The list and tuple instances in this
-- module override it.
--
-- NOTE(review): no instance in this section handles the case where @s@ is the
-- target type @a@ itself; that base case presumably lives in another module —
-- confirm.
class HasTypes s a where
  -- | Traverse every @a@ contained in an @s@, in the order defined by the
  -- instance.
  types_ :: Traversal' s a
  default types_ :: (Generic s, GHasTypes (Rep s) a) => Traversal' s a
  types_ func s = to <$> gtypes func (from s)
  {-# INLINE types_ #-}
-- | Generic fallback: any 'Generic' type whose representation has a
-- 'GHasTypes' instance gets 'types_' from the class's generic default.
-- The OVERLAPS pragma lets any more specific 'HasTypes' instance win when
-- both match.
instance {-# OVERLAPS #-}
  (Generic s, GHasTypes (Rep s) a) => HasTypes s a
-- | Apply a pure function to every target of a traversal.
--
-- The traversal is run in the 'Identity' functor, so no effects are
-- performed: the structure is rebuilt with each target @x@ replaced by
-- @f x@.
over :: Traversal' s a -> (a -> a) -> s -> s
over l f = runIdentity . l (Identity . f)
-- | Collect every target of a traversal into a list, in traversal order.
--
-- Each visited value is consed onto a 'State' accumulator.  The accumulator
-- therefore ends up in reverse visiting order and is reversed once at the
-- end.  The rebuilt structure produced by the traversal is discarded.
flattenValues :: forall a s. Traversal' s a -> s -> [a]
flattenValues func orgData = reverse $ execState (func push orgData) []
  where
    -- Record the value, then hand it back unchanged so the traversal can
    -- rebuild the structure.
    push :: a -> State [a] a
    push v = do
      modify' (v :)
      return v
-- | Rebuild @orgData@ with its traversal targets replaced, in traversal order,
-- by successive elements of @newValues@.
--
-- If @newValues@ runs out before every target has been visited, this calls
-- 'error' with "Not enough values supplied to replaceValues".  Surplus values
-- are ignored.
replaceValues :: forall a s. Traversal' s a -> s -> [a] -> s
replaceValues func orgData newValues = evalState (func pop orgData) newValues
  where
    -- Discard the current target and yield the next value from the
    -- remaining supply held in the state.
    pop :: a -> State [a] a
    pop _ = do
      d <- get
      case d of
        [] -> error "Not enough values supplied to replaceValues"
        x : xs -> do
          put xs
          return x
-- | Traverse every @a@ inside an @s@.
--
-- Equivalent to 'types_', but the type variables are ordered with @a@ first.
-- Callers can therefore choose the target type with a single type
-- application, e.g. @types \@Int@.
types :: forall a s. HasTypes s a => Traversal' s a
types = types_ @s @a
-- | Generic-representation counterpart of 'HasTypes': traverses every @a@
-- inside a "GHC.Generics" representation @s@, for any index @b@.
class GHasTypes s a where
  gtypes :: forall b. Traversal' (s b) a
-- | A constructor with no fields has no targets.  The visitor is never called
-- and the value is returned unchanged via 'pure'.
instance GHasTypes U1 a where
  gtypes _ = pure
  {-# INLINE gtypes #-}
-- | Sums (choice between constructors): traverse whichever side is present and
-- re-wrap the result in the same constructor ('L1' or 'R1').
instance (GHasTypes f a, GHasTypes g a) => GHasTypes (f :+: g) a where
  gtypes func (L1 x) = L1 <$> gtypes func x
  gtypes func (R1 x) = R1 <$> gtypes func x
  -- Added for consistency with the other generic instances, which all
  -- inline 'gtypes'.
  {-# INLINE gtypes #-}
-- | Products (fields side by side): traverse the left component, then the
-- right, and recombine them with ':*:'.
instance (GHasTypes f a, GHasTypes g a) => GHasTypes (f :*: g) a where
  gtypes func (x :*: y) = (:*:) <$> gtypes func x <*> gtypes func y
  {-# INLINE gtypes #-}
-- | A single field of type @s@.  Recursion returns to the non-generic level:
-- the field is traversed with the 'HasTypes' instance for its own type.
instance (HasTypes s a) => GHasTypes (K1 i s) a where
  gtypes func (K1 x) = K1 <$> types func x
  {-# INLINE gtypes #-}
-- | Metadata wrappers (datatype / constructor / selector): traverse the
-- wrapped representation and re-wrap it in 'M1'.
instance GHasTypes s a => GHasTypes (M1 i t s) a where
  gtypes func (M1 x) = M1 <$> gtypes func x
  {-# INLINE gtypes #-}
-- | Lists: traverse every element, head first, using the element type's
-- 'HasTypes' instance.  This is more specific than the generic fallback, so
-- it is chosen for list types.
instance {-# OVERLAPS #-} (HasTypes s a) => HasTypes [s] a where
  -- 'traverse' for lists is @(:) <$> f x <*> traverse f xs@, with
  -- @pure []@ for the empty list.
  types_ func = traverse (types_ func)
  {-# INLINE types_ #-}
-- | Pairs: traverse the first component, then the second, each with its own
-- 'HasTypes' instance.
instance {-# OVERLAPS #-} (HasTypes s0 a, HasTypes s1 a) => HasTypes (s0, s1) a where
  types_ func (s0, s1) = (,) <$> types_ func s0 <*> types_ func s1
  {-# INLINE types_ #-}
-- | Triples: traverse the components left to right, each with its own
-- 'HasTypes' instance.
instance {-# OVERLAPS #-} (HasTypes s0 a, HasTypes s1 a, HasTypes s2 a) => HasTypes (s0, s1, s2) a where
  types_ func (s0, s1, s2) =
    (,,) <$> types_ func s0 <*> types_ func s1 <*> types_ func s2
  {-# INLINE types_ #-}