| Safe Haskell | None |
| --- | --- |
| Language | Haskell2010 |
Torch.Typed.NN.DataParallel
Documentation
data ForwardConcurrentlyF Source #
Constructors
ForwardConcurrentlyF
ForwardConcurrentlyStochF
Instances
HasForward model input output => Apply' ForwardConcurrentlyF (model, input) (Concurrently output) Source #
Defined in Torch.Typed.NN.DataParallel
Methods
apply' :: ForwardConcurrentlyF -> (model, input) -> Concurrently output Source #
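
The Apply' instance above is what HZipWithM uses to pair each model replica with its input shard: ForwardConcurrentlyF wraps the ordinary forward pass in async's Concurrently, and ForwardConcurrentlyStochF does the same for the stochastic forward pass. Below is a minimal sketch of the instance in isolation, assuming a typed Linear layer on the CPU, that Concurrently/runConcurrently come from Control.Concurrent.Async, and that Apply' lives in Torch.HList; the layer sizes and batch shape are illustrative only.

```haskell
{-# LANGUAGE DataKinds #-}

import Control.Concurrent.Async (runConcurrently)
import Torch.HList (Apply' (apply'))
import Torch.Typed
import Torch.Typed.NN.DataParallel (ForwardConcurrentlyF (..))

-- Run a single (model, input) pair through the Apply' instance;
-- runConcurrently unwraps the Concurrently wrapper around the output.
runOnePair ::
  Linear 10 5 'Float '( 'CPU, 0) ->
  Tensor '( 'CPU, 0) 'Float '[8, 10] ->
  IO (Tensor '( 'CPU, 0) 'Float '[8, 5])
runOnePair model input =
  runConcurrently $ apply' ForwardConcurrentlyF (model, input)
```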
forwardConcurrently' :: forall {k} (devices' :: [(DeviceType, Nat)]) (device' :: k) (device :: (DeviceType, Nat)) model input output (models :: [Type]) (inputs :: [Type]) (outputs :: [Type]). ('Just device ~ GetDevice model, 'Just device ~ GetDevice input, HasScatter devices' device input inputs, HasReplicate devices' device model models, HZipWithM Concurrently ForwardConcurrentlyF models inputs outputs, HasGather device' devices' outputs output) => model -> input -> IO output Source #
forwardConcurrentlyStoch' :: forall {k} (devices' :: [(DeviceType, Nat)]) (device' :: k) (device :: (DeviceType, Nat)) model input output (models :: [Type]) (inputs :: [Type]) (outputs :: [Type]). ('Just device ~ GetDevice model, 'Just device ~ GetDevice input, HasScatter devices' device input inputs, HasReplicate devices' device model models, HZipWithM Concurrently ForwardConcurrentlyF models inputs outputs, HasGather device' devices' outputs output) => model -> input -> IO output Source #
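
forwardConcurrently' (and its stochastic counterpart forwardConcurrentlyStoch', which runs dropout and other stochastic layers live) is the convenience entry point for typed data parallelism: judging from the constraints, it scatters the input over devices', replicates the model onto those devices, runs the forward passes concurrently via ForwardConcurrentlyF, and gathers the results on device'. Below is a hedged sketch, assuming two CUDA devices, a typed Linear layer, and the usual Torch.Typed re-exports (LinearSpec, sample, ones, shape); the sizes and device choices are illustrative.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}

import Torch.Typed
import Torch.Typed.NN.DataParallel (forwardConcurrently')

main :: IO ()
main = do
  -- a model and a batch of 8 inputs, both on the first GPU
  model <- sample (LinearSpec @10 @5 @'Float @'( 'CUDA, 0))
  let input = ones @'[8, 10] @'Float @'( 'CUDA, 0)
  -- scatter the batch over both GPUs, replicate the model, run the two
  -- forward passes concurrently, and gather the result back on GPU 0
  output <-
    forwardConcurrently'
      @'[ '( 'CUDA, 0), '( 'CUDA, 1)] -- devices': where the shards are processed
      @'( 'CUDA, 0)                   -- device': where the outputs are gathered
      model
      input
  print (shape output) -- expected: [8,5]
```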
forwardConcurrently :: forall {k} (models :: [k]) (inputs :: [k]) (outputs :: [k]). HZipWithM Concurrently ForwardConcurrentlyF models inputs outputs => HList models -> HList inputs -> Concurrently (HList outputs) Source #
forwardConcurrentlyStoch :: forall {k} (models :: [k]) (inputs :: [k]) (outputs :: [k]). HZipWithM Concurrently ForwardConcurrentlyF models inputs outputs => HList models -> HList inputs -> Concurrently (HList outputs) Source #
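
The unprimed variants skip the scatter/replicate/gather bookkeeping and operate directly on heterogeneous lists: the model at each position of models is applied to the input at the same position of inputs, and all forward passes run under the Concurrently applicative. A sketch, assuming two independently sampled Linear layers on the CPU (real data-parallel use would pair replicas of one model with shards of one batch) and that the HList constructors (:.) and HNil are exported by Torch.HList:

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}

import Control.Concurrent.Async (runConcurrently)
import Torch.HList
import Torch.Typed
import Torch.Typed.NN.DataParallel (forwardConcurrently)

main :: IO ()
main = do
  modelA <- sample (LinearSpec @10 @5 @'Float @'( 'CPU, 0))
  modelB <- sample (LinearSpec @10 @5 @'Float @'( 'CPU, 0))
  let inputA = ones @'[4, 10] @'Float @'( 'CPU, 0)
      inputB = ones @'[4, 10] @'Float @'( 'CPU, 0)
  -- both forward passes run at the same time; the result is an HList of outputs
  outputs <-
    runConcurrently $
      forwardConcurrently (modelA :. modelB :. HNil) (inputA :. inputB :. HNil)
  case outputs of
    outA :. outB :. HNil -> print (shape outA, shape outB) -- ([4,5],[4,5])
```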
class HasGradConcurrently (device' :: k) (devices :: k1) (parameters :: [k2]) (losses :: [k3]) (gradients :: [k4]) | device' devices parameters losses -> gradients where Source #
Methods
gradConcurrently :: HList parameters -> HList losses -> Concurrently (HList gradients) Source #
Instances
(HZipWithM Concurrently GradConcurrentlyF parameters losses gradients', ReduceGradients device' devices gradients' gradients) => HasGradConcurrently (device' :: (DeviceType, Nat)) (devices :: [(DeviceType, Nat)]) (parameters :: [k1]) (losses :: [k1]) (gradients :: [k2]) Source #
Defined in Torch.Typed.NN.DataParallel
Methods
gradConcurrently :: HList parameters -> HList losses -> Concurrently (HList gradients) Source #
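
gradConcurrently takes one flattened parameter list and one scalar loss per replica, differentiates each loss concurrently on its own device (via GradConcurrentlyF below), and then sums the per-device gradients and moves them to device' (via ReduceGradients below). The following is a hedged sketch, assuming two CUDA devices, independently sampled Linear layers standing in for true model replicas, that sumAll and flattenParameters behave as in current hasktorch, and that device' and devices can be pinned with visible type applications in the order they appear in the class head.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}

import Control.Concurrent.Async (runConcurrently)
import Torch.HList
import Torch.Typed
import Torch.Typed.NN.DataParallel (gradConcurrently)

main :: IO ()
main = do
  -- one replica and one input shard per GPU (sampled independently here for brevity)
  replica0 <- sample (LinearSpec @10 @5 @'Float @'( 'CUDA, 0))
  replica1 <- sample (LinearSpec @10 @5 @'Float @'( 'CUDA, 1))
  let shard0 = ones @'[4, 10] @'Float @'( 'CUDA, 0)
      shard1 = ones @'[4, 10] @'Float @'( 'CUDA, 1)
      -- one scalar loss per device; any differentiable scalar will do here
      loss0 = sumAll (forward replica0 shard0)
      loss1 = sumAll (forward replica1 shard1)
  -- differentiate both losses concurrently, sum the gradients, move them to GPU 0
  grads <-
    runConcurrently $
      gradConcurrently @'( 'CUDA, 0) @'[ '( 'CUDA, 0), '( 'CUDA, 1)]
        (flattenParameters replica0 :. flattenParameters replica1 :. HNil)
        (loss0 :. loss1 :. HNil)
  case grads of
    gradW :. gradB :. HNil -> print (shape gradW, shape gradB) -- ([5,10],[5])
```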
data GradConcurrentlyF Source #
Constructors
GradConcurrentlyF
Instances
(HasGrad (HList parameters) (HList gradients), Castable (HList gradients) [ATenTensor]) => Apply' GradConcurrentlyF (HList parameters, Loss device dtype) (Concurrently (HList gradients)) Source #
Defined in Torch.Typed.NN.DataParallel
Methods
apply' :: GradConcurrentlyF -> (HList parameters, Loss device dtype) -> Concurrently (HList gradients) Source #
class ReduceGradients (device' :: (DeviceType, Nat)) (devices :: [(DeviceType, Nat)]) (xxs :: [k]) (ys :: [k1]) | device' devices xxs -> ys where Source #
Methods
reduceGradients :: HList xxs -> HList ys Source #
Instances
HasToDevice device' device (HList xs) (HList ys) => ReduceGradients device' '[device] ('[HList xs] :: [Type]) (ys :: [k]) Source #
Defined in Torch.Typed.NN.DataParallel
(HasToDevice device' device (HList xs) (HList ys), ReduceGradients device' devices xxs ys, HZipWith SumF ys ys ys, 1 <= ListLength xxs) => ReduceGradients device' (device ': devices) (HList xs ': xxs :: [Type]) (ys :: [k]) Source #
Defined in Torch.Typed.NN.DataParallel
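
reduceGradients is the reduction step behind HasGradConcurrently: given one gradient list per device, it sums the lists elementwise (SumF) and moves the result to device'. A small sketch under the same assumptions as above (two CUDA devices, type applications for device' and devices); the single '[5]-shaped gradient per device is illustrative only.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}

import Torch.HList
import Torch.Typed
import Torch.Typed.NN.DataParallel (ReduceGradients (reduceGradients))

main :: IO ()
main = do
  -- one single-tensor gradient list per device
  let grads0 = ones @'[5] @'Float @'( 'CUDA, 0) :. HNil
      grads1 = ones @'[5] @'Float @'( 'CUDA, 1) :. HNil
      -- sum the two lists elementwise and move the result to the CPU
      reduced =
        reduceGradients @'( 'CPU, 0) @'[ '( 'CUDA, 0), '( 'CUDA, 1)]
          (grads0 :. grads1 :. HNil)
  case reduced of
    g :. HNil -> print (shape g) -- [5]; every element is 1 + 1 = 2
```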