{-# OPTIONS_GHC -Wno-orphans #-}
-- | The implementation of reverse derivative and forward derivative
-- calculation for an objective function on values of complicated types,
-- e.g., nested tuples of tensors.
--
-- The objective function can be defined as a sufficiently polymorphic
-- Haskell function that uses numeric classes as well as the multi-dimensional
-- tensor operations listed in "HordeAd.OpsTensor". To obtain symbolic
-- derivatives (derivative code that can be executed many times without
-- performing AD again), the user needs an objective function polymorphic
-- enough so that it can be instantiated to the 'HordeAd.Core.Ast.AstTensor'
-- type (nested in tuples, etc., for some extra flexibility).
-- For non-symbolic derivatives, the ability to instantiate to the
-- `HordeAd.Core.CarriersADVal.ADVal` type of dual numbers is enough.
-- See the classes these types are instances of to gauge the breadth
-- of the offered respective APIs.
module HordeAd.ADEngine
  ( -- * Symbolic reverse derivative adaptors
    grad, vjp
  , gradArtifact, vjpArtifact
  , gradInterpretArtifact, vjpInterpretArtifact
    -- * Symbolic forward derivative adaptors
  , jvp, jvpArtifact, jvpInterpretArtifact
    -- * Non-symbolic reverse derivative adaptors
  , cgrad, cvjp
    -- * Non-symbolic forward derivative adaptors
  , cjvp
    -- * Internal machinery for symbolic adaptors
  , IncomingCotangentHandling(..)
  , revArtifactAdapt, revArtifactDelta
  , revProduceArtifactWithoutInterpretation, revInterpretArtifact
  , fwdArtifactAdapt, fwdArtifactDelta, fwdInterpretArtifact
    -- * Internal machinery for non-symbolic adaptors
  , cfwdBoth
  ) where

import Prelude

import HordeAd.AstEngine
import HordeAd.Core.Adaptor
import HordeAd.Core.Ast
import HordeAd.Core.AstEnv
import HordeAd.Core.AstInterpret
import HordeAd.Core.CarriersADVal
import HordeAd.Core.CarriersAst
import HordeAd.Core.CarriersConcrete
import HordeAd.Core.Delta
import HordeAd.Core.DeltaEval
import HordeAd.Core.Ops
import HordeAd.Core.OpsADVal
import HordeAd.Core.OpsAst
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
import HordeAd.Core.Unwind

-- * Symbolic reverse derivative adaptors

-- | This simplified version of the symbolic reverse derivative operation
-- sets the incoming cotangent @dt@ to be 1 and assumes the codomain
-- of the function to be differentiated is a scalar.
--
-- We don't enforce (e.g., by quantification) that the objective function
-- is closed, because we evaluate the result of the differentiation
-- down to concrete arrays and so there's no risk of "perturbation confusion"
-- between different levels of differentiation if it's done multiple times.
-- For simplicity of the type signature, the resulting value is converted from
-- the type of concrete cotangents to the type of concrete input parameters.
grad
  :: forall src r tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan (TKScalar r) )
  => (src -> tgt)  -- ^ the objective function
  -> Value src
  -> Value src  -- morally Value (ADTensorKind src)
{-# INLINE grad #-}
-- No explicit cotangent is passed; the shared helper then binds
-- the cotangent variable to a tensor of ones of the codomain's shape.
grad objective inputs = revMaybe objective inputs Nothing

-- | This version of the symbolic reverse derivative operation
-- explicitly takes the sensitivity parameter (the incoming cotangent).
-- It also permits an arbitrary (nested tuple+) type of the domain
-- and arbitrary (nested pair) tensor kind of the codomain
-- of the function to be differentiated. The downside of the generality
-- is that if the function doesn't have an explicit type signature,
-- the type to which this operation is instantiated often has to be spelled
-- in full via explicit type applications to aid type reconstruction.
-- For simplicity of the type signature, the resulting value is converted from
-- the type of concrete cotangents to the type of concrete input parameters.
vjp
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> Value src
  -> Concrete (ADTensorKind ztgt)
  -> Value src  -- morally Value (ADTensorKind src)
{-# INLINE vjp #-}
-- The caller-supplied cotangent is forwarded to the shared helper, whose
-- interpreter checks its shape against the codomain before using it.
vjp objective inputs cotangent = revMaybe objective inputs (Just cotangent)

-- | Compute the reverse derivative not for a specific input, but as a symbolic
-- function from inputs to the gradient value.
-- The function is represented as an "artifact", which is the gradient
-- AST term together with the variable corresponding to the input.
gradArtifact
  :: forall src r tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan (TKScalar r) )
  => (src -> tgt)  -- ^ the objective function
  -> Value src
  -> AstArtifactRev (X src) (TKScalar r)
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE gradArtifact #-}
gradArtifact f vals0 = revArtifactAdapt IgnoreIncomingCotangent f domainFtk
  where
    -- Only the full shape of the sample input is used to build the
    -- artifact; its concrete values are not passed on.
    domainFtk = tftkG (knownSTK @(X src)) (unConcrete (toTarget vals0))

-- | Compute the reverse derivative not for a specific input, but as a symbolic
-- function from inputs and incoming cotangents to the gradient value.
-- The function is represented as an "artifact", which is the gradient
-- AST term together with variables corresponding to the input and cotangent.
vjpArtifact
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> Value src
  -> AstArtifactRev (X src) ztgt
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE vjpArtifact #-}
vjpArtifact f vals0 = revArtifactAdapt UseIncomingCotangent f domainFtk
  where
    -- Only the full shape of the sample input is used to build the
    -- artifact; its concrete values are not passed on.
    domainFtk = tftkG (knownSTK @(X src)) (unConcrete (toTarget vals0))

-- | Interpret the "artifact" as a function from a concrete tensor
-- to a concrete tensor (possibly adapted, e.g., from horde-ad nested pairs
-- to Haskell n-tuples).
gradInterpretArtifact
  :: forall x r avals.
     (X avals ~ ADTensorKind x, AdaptableTarget Concrete avals)
  => AstArtifactRev x (TKScalar r)
       -- ^ the artifact containing the symbolic code of the derivative
  -> Concrete x
  -> avals
{-# INLINE gradInterpretArtifact #-}
gradInterpretArtifact AstArtifactRev{..} parameters
  | tftkG (ftkToSTK domainFtk) (unConcrete parameters) == domainFtk =
      fromTarget $ interpretAstPrimal env artDerivativeRev
  | otherwise =
      error "gradInterpretArtifact: reverse derivative parameters must have the same shape as the domain of the objective function"
  where
    domainFtk = varNameToFTK artVarDomainRev
    -- STKScalar @(ADTensorScalar r) or STKScalar @Z1
    dtFtk = varNameToFTK artVarDtRev
    -- The cotangent variable is bound to a tensor of ones of its own shape,
    -- and the domain variable to the given parameters.
    env = extendEnv artVarDtRev (treplTarget 1 dtFtk)
          $ extendEnv artVarDomainRev parameters emptyEnv

-- | Interpret the "artifact" as a function from concrete tensors
-- to a concrete tensor (possibly adapted, e.g., from horde-ad nested pairs
-- to Haskell n-tuples).
vjpInterpretArtifact
  :: forall x z avals.
     (X avals ~ ADTensorKind x, AdaptableTarget Concrete avals)
  => AstArtifactRev x z
       -- ^ the artifact containing the symbolic code of the derivative
  -> Concrete x
  -> Concrete (ADTensorKind z)
  -> avals
{-# INLINE vjpInterpretArtifact #-}
vjpInterpretArtifact AstArtifactRev{..} parameters dt
  -- The parameters are validated first, then the cotangent.
  | not parametersFit =
      error "vjpInterpretArtifact: reverse derivative parameters must have the same shape as the domain of the objective function"
  | not cotangentFits =
      error "vjpInterpretArtifact: reverse derivative incoming cotangent must have the same shape as the codomain of the objective function"
  | otherwise =
      fromTarget $ interpretAstPrimal env artDerivativeRev
  where
    domainFtk = varNameToFTK artVarDomainRev
    dtFtk = varNameToFTK artVarDtRev
    parametersFit = tftkG (ftkToSTK domainFtk) (unConcrete parameters) == domainFtk
    cotangentFits = tftkG (ftkToSTK dtFtk) (unConcrete dt) == dtFtk
    env = extendEnv artVarDtRev dt
          $ extendEnv artVarDomainRev parameters emptyEnv


-- * Symbolic reverse derivative adaptors' internal machinery

-- | Shared implementation of 'grad' and 'vjp'. It builds the reverse
-- artifact for the objective function at the shape of the given input,
-- simplifies it with 'simplifyArtifactGradient', and interprets it at the
-- input and optional cotangent. Finally it converts the gradient back to
-- the type of the input.
revMaybe
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> Value src
  -> Maybe (Concrete (ADTensorKind ztgt))
  -> Value src  -- morally Value (ADTensorKind src)
{-# INLINE revMaybe #-}
revMaybe f vals0 mdt =
  fromTarget
  $ fromADTensorKindShared (ftkToSTK xftk)
  $ fst
  $ revInterpretArtifact (simplifyArtifactGradient rawArtifact) valsTarget mdt
  where
    valsTarget = toTarget vals0
    xftk = tftkG (knownSTK @(X src)) $ unConcrete valsTarget
    -- The cotangent-handling mode follows whether a cotangent was supplied.
    -- With none, revInterpretArtifact binds the cotangent variable to ones.
    cotangentHandling = case mdt of
      Nothing -> IgnoreIncomingCotangent
      Just _ -> UseIncomingCotangent
    rawArtifact = revArtifactAdapt cotangentHandling f xftk

-- | Produce a reverse-derivative artifact for an objective function over
-- an adaptable domain of the given full shape. The function is lifted to
-- take its whole domain as a single AST term. That term is let-bound with
-- 'ttlet' before 'fromTarget' splits it up, and the result is simplified
-- with 'simplifyInline'.
revArtifactAdapt
  :: forall src ztgt tgt.
     ( AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => IncomingCotangentHandling
  -> (src -> tgt)  -- ^ the objective function
  -> FullShapeTK (X src)
  -> AstArtifactRev (X src) ztgt
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE revArtifactAdapt #-}
revArtifactAdapt cotangentHandling f xftk =
  revProduceArtifact cotangentHandling wrapped emptyEnv xftk
  where
    wrapped :: AstTensor AstMethodLet FullSpan (X src) -> tgt
    -- fromTarget requires duplicable
    wrapped !arg = simplifyInline (ttlet arg (f . fromTarget))

-- | Interpret a reverse-derivative artifact at concrete parameters,
-- returning the gradient together with the primal value.
--
-- If no cotangent is given, the cotangent variable is bound to a tensor
-- of ones of the codomain's shape. A supplied cotangent must match that
-- shape. Like the other artifact interpreters in this module, this function
-- also requires the parameters to match the shape of the artifact's domain.
-- Otherwise it fails with an error instead of producing a malformed result.
revInterpretArtifact
  :: forall x z.
     AstArtifactRev x z
       -- ^ the artifact containing the symbolic code of the derivative
  -> Concrete x
  -> Maybe (Concrete (ADTensorKind z))
  -> (Concrete (ADTensorKind x), Concrete z)
{-# INLINE revInterpretArtifact #-}
revInterpretArtifact AstArtifactRev{..} parameters mdt =
  let xftk = varNameToFTK artVarDomainRev
      azftk = varNameToFTK artVarDtRev
      env = extendEnv artVarDomainRev parameters emptyEnv
      envDt = case mdt of
        Nothing ->
          let oneAtF = treplTarget 1 azftk
          in extendEnv artVarDtRev oneAtF env
        Just dt ->
          if tftkG (ftkToSTK azftk) (unConcrete dt) == azftk
          then extendEnv artVarDtRev dt env
          else error "revInterpretArtifact: reverse derivative incoming cotangent must have the same shape as the codomain of the objective function"
      gradient = interpretAstPrimal envDt artDerivativeRev
      -- The primal does not depend on the cotangent, so it uses 'env'.
      primal = interpretAstPrimal env artPrimalRev
  in if tftkG (ftkToSTK xftk) (unConcrete parameters) == xftk
     then (gradient, primal)
     else error "revInterpretArtifact: reverse derivative parameters must have the same shape as the domain of the objective function"


-- * Symbolic reverse derivative adaptors' testing-only internal machinery

-- | Like 'revArtifactAdapt', but without 'simplifyInline'. It also returns
-- the delta expression produced by the forward pass. The forward pass is
-- computed by interpreting the lifted objective function.
revArtifactDelta
  :: forall src ztgt tgt.
     ( AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => IncomingCotangentHandling
  -> (src -> tgt)  -- ^ the objective function
  -> FullShapeTK (X src)
  -> (AstArtifactRev (X src) ztgt, Delta (AstRaw PrimalSpan) ztgt)
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE revArtifactDelta #-}
revArtifactDelta cotangentHandling f xftk =
  revArtifactFromForwardPass
    cotangentHandling (forwardPassByInterpretation wrapped emptyEnv) xftk
  where
    wrapped :: AstTensor AstMethodLet FullSpan (X src) -> tgt
    wrapped !arg = ttlet arg (f . fromTarget)

-- | Produce a reverse-derivative artifact, plus its delta expression, by
-- applying the given dual-number function directly to the input instead of
-- interpreting an AST.
revProduceArtifactWithoutInterpretation
  :: forall x z.
     IncomingCotangentHandling
  -> (ADVal (AstRaw PrimalSpan) x -> ADVal (AstRaw PrimalSpan) z)
  -> FullShapeTK x
  -> (AstArtifactRev x z, Delta (AstRaw PrimalSpan) z)
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE revProduceArtifactWithoutInterpretation #-}
revProduceArtifactWithoutInterpretation handling f domainFtk =
  -- No simplification performed to let individual tests decide.
  revArtifactFromForwardPass handling (forwardPassByApplication f) domainFtk

-- | A forward pass that builds the dual-number input from the primal AST
-- and the delta inputs generated for the variable's full shape, then
-- applies the given function to it. The full-span AST argument is unused.
forwardPassByApplication
  :: forall x z.
     (ADVal (AstRaw PrimalSpan) x -> ADVal (AstRaw PrimalSpan) z)
  -> AstTensor AstMethodShare PrimalSpan x
  -> AstVarName FullSpan x
  -> AstTensor AstMethodLet FullSpan x
  -> ADVal (AstRaw PrimalSpan) z
{-# INLINE forwardPassByApplication #-}
forwardPassByApplication g astVarPrimal var _astVar =
  g (dDnotShared (AstRaw astVarPrimal)
                 (generateDeltaInputs (varNameToFTK var)))


-- * Symbolic forward derivative adaptors

-- | The forward derivative operation takes the perturbation parameter
-- by convention. It permits an arbitrary (nested tuple+)
-- type of the domain and arbitrary (nested pair) tensor kind of the codomain
-- of the function to be differentiated. The generality sometimes makes it
-- necessary to supply type hints when applying this operation.
jvp
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> Value src
  -> Value src  -- morally (ADTensorKind src)
  -> Concrete (ADTensorKind ztgt)
{-# INLINE jvp #-}
jvp f vals0 ds =
  fst $ fwdInterpretArtifact artifact valsTarget adPerturbation
  where
    valsTarget = toTarget vals0
    xftk = tftkG (knownSTK @(X src)) $ unConcrete valsTarget
    artifact = simplifyArtifactDerivative (fwdArtifactAdapt f xftk)
    -- The shapes of vals0 vs ds are checked in fwdInterpretArtifact.
    adPerturbation = toADTensorKindShared xftk (toTarget ds)

-- | Compute the forward derivative not for a specific input, but as a symbolic
-- function from inputs and perturbation to the derivative value.
-- The function is represented as an "artifact", which is the derivative
-- AST term together with variables corresponding to the input and perturbation.
jvpArtifact
  :: forall src ztgt tgt.
     ( X src ~ X (Value src), KnownSTK (X src)
     , AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , AdaptableTarget Concrete (Value src)
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> Value src
  -> AstArtifactFwd (X src) ztgt
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE jvpArtifact #-}
jvpArtifact f vals0 = fwdArtifactAdapt f domainFtk
  where
    -- Only the full shape of the sample input is used to build the
    -- artifact; its concrete values are not passed on.
    domainFtk = tftkG (knownSTK @(X src)) (unConcrete (toTarget vals0))

-- | Interpret the "artifact" as a function from concrete tensors
-- to a concrete tensor.
jvpInterpretArtifact
  :: forall x z.
     AstArtifactFwd x z
       -- ^ the artifact containing the symbolic code of the derivative
  -> Concrete x
  -> Concrete (ADTensorKind x)
  -> Concrete (ADTensorKind z)
{-# INLINE jvpInterpretArtifact #-}
jvpInterpretArtifact art parameters = \perturbation ->
  -- The shapes of parameters vs perturbation are checked in
  -- fwdInterpretArtifact; only the derivative component is kept.
  fst (fwdInterpretArtifact art parameters perturbation)


-- * Symbolic forward derivative adaptors' internal machinery

-- | Produce a forward-derivative artifact for an objective function over
-- an adaptable domain of the given full shape. The function is lifted to
-- take its whole domain as a single AST term. That term is let-bound with
-- 'ttlet' before 'fromTarget' splits it up, and the result is simplified
-- with 'simplifyInline'.
fwdArtifactAdapt
  :: forall src ztgt tgt.
     ( AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> FullShapeTK (X src)
  -> AstArtifactFwd (X src) ztgt
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE fwdArtifactAdapt #-}
fwdArtifactAdapt f xftk = fwdProduceArtifact wrapped emptyEnv xftk
  where
    wrapped :: AstTensor AstMethodLet FullSpan (X src) -> tgt
    -- fromTarget requires duplicable
    wrapped !arg = simplifyInline (ttlet arg (f . fromTarget))

-- | Interpret a forward-derivative artifact at concrete parameters and a
-- concrete perturbation, returning the derivative and the primal value.
-- Both arguments are checked against the shape of the artifact's domain
-- (the perturbation against its AD-tensor-kind variant).
fwdInterpretArtifact
  :: forall x z.
     AstArtifactFwd x z
       -- ^ the artifact containing the symbolic code of the derivative
  -> Concrete x
  -> Concrete (ADTensorKind x)
  -> (Concrete (ADTensorKind z), Concrete z)
{-# INLINE fwdInterpretArtifact #-}
fwdInterpretArtifact AstArtifactFwd{..} parameters ds
  -- The input is validated first, then the perturbation.
  | not inputFits =
      error "fwdInterpretArtifact: forward derivative input must have the same shape as the domain of the objective function"
  | not perturbationFits =
      error "fwdInterpretArtifact: forward derivative perturbation must have the same shape as the domain of the objective function"
  | otherwise =
      ( interpretAstPrimal envWithDs artDerivativeFwd
      , interpretAstPrimal envInput artPrimalFwd )
  where
    domainFtk = varNameToFTK artVarDomainFwd
    domainStk = ftkToSTK domainFtk
    inputFits = tftkG domainStk (unConcrete parameters) == domainFtk
    perturbationFits =
      tftkG (adSTK domainStk) (unConcrete ds) == adFTK domainFtk
    envInput = extendEnv artVarDomainFwd parameters emptyEnv
    envWithDs = extendEnv artVarDsFwd ds envInput


-- * Symbolic forward derivative adaptors' testing-only internal machinery

-- | Like 'fwdArtifactAdapt', but without 'simplifyInline'. It also returns
-- the delta expression produced by the forward pass. The forward pass is
-- computed by interpreting the lifted objective function.
fwdArtifactDelta
  :: forall src ztgt tgt.
     ( AdaptableTarget (AstTensor AstMethodLet FullSpan) src
     , tgt ~ AstTensor AstMethodLet FullSpan ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> FullShapeTK (X src)
  -> (AstArtifactFwd (X src) ztgt, Delta (AstRaw PrimalSpan) ztgt)
       -- ^ the artifact containing the symbolic code of the derivative
{-# INLINE fwdArtifactDelta #-}
fwdArtifactDelta f xftk =
  fwdArtifactFromForwardPass (forwardPassByInterpretation wrapped emptyEnv) xftk
  where
    wrapped :: AstTensor AstMethodLet FullSpan (X src) -> tgt
    wrapped !arg = ttlet arg (f . fromTarget)


-- * Non-symbolic reverse derivative adaptors

-- We are inlining these functions because they take function arguments
-- and are not too large. However, because they are called in many places,
-- we break the inline chain not far from the top, to avoid exe blowup.
--
-- | This simplified version of the concrete (non-symbolic)
-- reverse derivative operation sets the incoming cotangent @dt@ to be 1
-- and assumes the codomain of the function to be differentiated is a scalar.
cgrad
  :: forall src r tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete (TKScalar r) )
  => (src -> tgt)  -- ^ the objective function
  -> DValue src
  -> DValue src  -- morally DValue (ADTensorKind src)
{-# INLINE cgrad #-}
-- No explicit cotangent is passed to the shared concrete helper.
cgrad objective inputs = crevMaybe objective inputs Nothing

-- | This more general version of the concrete (non-symbolic)
-- reverse derivative operation additionally takes the sensitivity parameter
-- (the incoming cotangent).
cvjp
  :: forall src ztgt tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> DValue src
  -> Concrete (ADTensorKind ztgt)
  -> DValue src  -- morally DValue (ADTensorKind src)
{-# INLINE cvjp #-}
-- The caller-supplied cotangent is forwarded to the shared concrete helper.
cvjp objective inputs cotangent = crevMaybe objective inputs (Just cotangent)


-- * Non-symbolic reverse derivative adaptors' internal machinery

-- | Shared implementation of 'cgrad' and 'cvjp'. It runs 'crevOnParams' on
-- the objective function, lifted to take the adapted input as a single
-- 'ADVal' 'Concrete' value. It keeps the first component of the result and
-- converts it back to the type of the input.
crevMaybe
  :: forall src ztgt tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> DValue src
  -> Maybe (Concrete (ADTensorKind ztgt))
  -> DValue src  -- morally DValue (ADTensorKind src)
{-# INLINE crevMaybe #-}
crevMaybe f vals0 mdt =
  fromTarget
  $ fromADTensorKindShared (ftkToSTK xftk)
  $ fst
  $ crevOnParams mdt lifted xftk valsTarget
  where
    valsTarget = toTarget vals0
    xftk = tftkG (knownSTK @(X src)) $ unConcrete valsTarget
    lifted :: ADVal Concrete (X src) -> tgt
    lifted = f . fromTarget


-- * Non-symbolic forward derivative adaptors

-- | Concrete (non-symbolic) forward derivative operation. It always takes
-- the perturbation parameter, by convention.
cjvp
  :: forall src ztgt tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> DValue src
  -> DValue src  -- morally (ADTensorKind src)
  -> Concrete (ADTensorKind ztgt)
{-# INLINE cjvp #-}
-- Only the derivative component of the pair computed by cfwdBoth is kept.
cjvp objective inputs perturbation = fst (cfwdBoth objective inputs perturbation)


-- * Non-symbolic forward derivative adaptors' internal machinery

-- | Concrete (non-symbolic) forward derivative. It returns the derivative
-- paired with a 'Concrete' @ztgt@ result, presumably the primal value of
-- the objective (confirm against 'cfwdOnParams'). The perturbation must have
-- the same full shape as the input.
cfwdBoth
  :: forall src ztgt tgt.
     ( X src ~ X (DValue src), KnownSTK (X src)
     , AdaptableTarget (ADVal Concrete) src
     , AdaptableTarget Concrete (DValue src)
     , tgt ~ ADVal Concrete ztgt )
  => (src -> tgt)  -- ^ the objective function
  -> DValue src
  -> DValue src  -- morally (ADTensorKind src)
  -> (Concrete (ADTensorKind ztgt), Concrete ztgt)
{-# INLINE cfwdBoth #-}
cfwdBoth f vals0 ds
  | tftkG (ftkToSTK xftk) (unConcrete dsTarget) == xftk =
      cfwdOnParams xftk valsTarget lifted
      $ toADTensorKindShared xftk dsTarget
  | otherwise =
      error "cfwdBoth: forward derivative input must have the same shape as the perturbation argument"
  where
    valsTarget = toTarget vals0
    xftk = tftkG (knownSTK @(X src)) $ unConcrete valsTarget
    lifted :: ADVal Concrete (X src) -> tgt
    lifted = f . fromTarget
    dsTarget = toTarget ds





-- This specialization is not possible where the functions are defined,
-- due to dependency cycles, but it's possible here.
-- The pragmas below specialize the reverse-pass functions to the
-- 'Concrete' carrier (Delta Concrete / EvalState Concrete).
{-# SPECIALIZE gradientFromDelta :: FullShapeTK x -> FullShapeTK z -> Concrete (ADTensorKind z) -> Delta Concrete z -> Concrete (ADTensorKind x) #-}
{-# SPECIALIZE evalRev :: FullShapeTK y -> EvalState Concrete -> Concrete (ADTensorKind y) -> Delta Concrete y -> EvalState Concrete #-}
{-# SPECIALIZE evalRevFTK :: EvalState Concrete -> Concrete (ADTensorKind y) -> Delta Concrete y -> EvalState Concrete #-}
-- RULE left-hand side too complicated to desugar:
-- {-# SPECIALIZE evalRevSame :: y ~ ADTensorKind y => EvalState Concrete -> Concrete (ADTensorKind y) -> Delta Concrete y -> EvalState Concrete #-}
{-# SPECIALIZE evalRevFromnMap :: EvalState Concrete -> EvalState Concrete #-}

-- Because the general evalRevSame specialization above cannot be expressed,
-- evalRevSame is instead specialized at fixed Double and Float element
-- types for scalars and for ranked (TKR), shaped (TKS) and mixed (TKX)
-- tensors.
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKScalar Double) -> Delta Concrete (TKScalar Double) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKScalar Float) -> Delta Concrete (TKScalar Float) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKR n Double) -> Delta Concrete (TKR n Double) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKR n Float) -> Delta Concrete (TKR n Float) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKS sh Double) -> Delta Concrete (TKS sh Double) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKS sh Float) -> Delta Concrete (TKS sh Float) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKX sh Double) -> Delta Concrete (TKX sh Double) -> EvalState Concrete #-}
{-# SPECIALIZE evalRevSame :: EvalState Concrete -> Concrete (TKX sh Float) -> Delta Concrete (TKX sh Float) -> EvalState Concrete #-}


-- These and all other SPECIALIZE pragmas are needed due to issues #21286
-- and others (already mostly fixed), even if only to compare the output
-- with and without them.
-- interpretAst is specialized for all three AstSpan values (PrimalSpan,
-- DualSpan, FullSpan) to handle recursive calls from interpretAstDual, etc.
-- Each span is specialized at three environments: ADVal Concrete,
-- ADVal (AstRaw PrimalSpan) and Concrete.
{-# SPECIALIZE interpretAst
  :: AstEnv (ADVal Concrete)
  -> AstTensor AstMethodLet PrimalSpan y
  -> ADVal Concrete y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv (ADVal (AstRaw PrimalSpan))
  -> AstTensor AstMethodLet PrimalSpan y
  -> ADVal (AstRaw PrimalSpan) y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv Concrete
  -> AstTensor AstMethodLet PrimalSpan y
  -> Concrete y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv (ADVal Concrete)
  -> AstTensor AstMethodLet DualSpan y
  -> ADVal Concrete y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv (ADVal (AstRaw PrimalSpan))
  -> AstTensor AstMethodLet DualSpan y
  -> ADVal (AstRaw PrimalSpan) y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv Concrete
  -> AstTensor AstMethodLet DualSpan y
  -> Concrete y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv (ADVal Concrete)
  -> AstTensor AstMethodLet FullSpan y
  -> ADVal Concrete y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv (ADVal (AstRaw PrimalSpan))
  -> AstTensor AstMethodLet FullSpan y
  -> ADVal (AstRaw PrimalSpan) y #-}
{-# SPECIALIZE interpretAst
  :: AstEnv Concrete
  -> AstTensor AstMethodLet FullSpan y
  -> Concrete y #-}

-- interpretAstPrimal takes only PrimalSpan terms and is specialized at
-- the same three environments.
{-# SPECIALIZE interpretAstPrimal
  :: AstEnv (ADVal Concrete)
  -> AstTensor AstMethodLet PrimalSpan y
  -> Concrete y #-}
{-# SPECIALIZE interpretAstPrimal
  :: AstEnv (ADVal (AstRaw PrimalSpan))
  -> AstTensor AstMethodLet PrimalSpan y
  -> AstRaw PrimalSpan y #-}
{-# SPECIALIZE interpretAstPrimal
  :: AstEnv Concrete
  -> AstTensor AstMethodLet PrimalSpan y
  -> Concrete y #-}

-- interpretAstBool is specialized at the same three environments. The
-- result is a plain Bool for the Concrete-based environments and an
-- AstBool AstMethodShare for the AstRaw one.
{-# SPECIALIZE interpretAstBool
  :: AstEnv (ADVal Concrete)
  -> AstBool AstMethodLet
  -> Bool #-}
{-# SPECIALIZE interpretAstBool
  :: AstEnv (ADVal (AstRaw PrimalSpan))
  -> AstBool AstMethodLet
  -> AstBool AstMethodShare #-}
{-# SPECIALIZE interpretAstBool
  :: AstEnv Concrete
  -> AstBool AstMethodLet
  -> Bool #-}
